Compare commits
55 Commits
v0.11.6
...
pdevine/pa
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c10a40db99 | ||
|
|
93c64ea1b1 | ||
|
|
3f6642f6fc | ||
|
|
6f7117145f | ||
|
|
92b96d54ef | ||
|
|
9d56e63dbf | ||
|
|
053092185e | ||
|
|
44a6792873 | ||
|
|
e4ce68311a | ||
|
|
26214125e8 | ||
|
|
61fb912ca4 | ||
|
|
aba1575315 | ||
|
|
eb10390de9 | ||
|
|
feb18cd710 | ||
|
|
8a7e2055d2 | ||
|
|
29ddfc2cab | ||
|
|
71cb86af3e | ||
|
|
5198956372 | ||
|
|
17a023f34b | ||
|
|
8d6fffaead | ||
|
|
20b53eaa72 | ||
|
|
6745182885 | ||
|
|
f810ec741c | ||
|
|
e119783e66 | ||
|
|
1a558f98e2 | ||
|
|
7b91c9ce51 | ||
|
|
950d33aa30 | ||
|
|
9714e38dd0 | ||
|
|
4378ae4ffa | ||
|
|
5994e8e8fd | ||
|
|
b3e6120736 | ||
|
|
fb92b61754 | ||
|
|
8149a3c86e | ||
|
|
0cc90a8186 | ||
|
|
e42300f25b | ||
|
|
66e73809a1 | ||
|
|
517807cdf2 | ||
|
|
ead4a9a1d0 | ||
|
|
4383a3ab7a | ||
|
|
9d97e6a9f1 | ||
|
|
1081532430 | ||
|
|
59412fbb43 | ||
|
|
86834a2797 | ||
|
|
85ccf7354d | ||
|
|
30fb7e19f8 | ||
|
|
d3450dd52e | ||
|
|
4bcb04ad88 | ||
|
|
e3d5708754 | ||
|
|
4be4dc8717 | ||
|
|
109d4fc3b4 | ||
|
|
2cb0a580f3 | ||
|
|
7cce5aac76 | ||
|
|
4ae4f47b16 | ||
|
|
073fa31df5 | ||
|
|
91fc3c48e3 |
28
.github/workflows/release.yaml
vendored
28
.github/workflows/release.yaml
vendored
@@ -65,14 +65,36 @@ jobs:
|
|||||||
arch: amd64
|
arch: amd64
|
||||||
preset: 'CUDA 12'
|
preset: 'CUDA 12'
|
||||||
install: https://developer.download.nvidia.com/compute/cuda/12.8.0/local_installers/cuda_12.8.0_571.96_windows.exe
|
install: https://developer.download.nvidia.com/compute/cuda/12.8.0/local_installers/cuda_12.8.0_571.96_windows.exe
|
||||||
|
cuda-components:
|
||||||
|
- '"cudart"'
|
||||||
|
- '"nvcc"'
|
||||||
|
- '"cublas"'
|
||||||
|
- '"cublas_dev"'
|
||||||
cuda-version: '12.8'
|
cuda-version: '12.8'
|
||||||
flags: ''
|
flags: ''
|
||||||
|
runner_dir: 'cuda_v12'
|
||||||
|
- os: windows
|
||||||
|
arch: amd64
|
||||||
|
preset: 'CUDA 13'
|
||||||
|
install: https://developer.download.nvidia.com/compute/cuda/13.0.0/local_installers/cuda_13.0.0_windows.exe
|
||||||
|
cuda-components:
|
||||||
|
- '"cudart"'
|
||||||
|
- '"nvcc"'
|
||||||
|
- '"cublas"'
|
||||||
|
- '"cublas_dev"'
|
||||||
|
- '"crt"'
|
||||||
|
- '"nvvm"'
|
||||||
|
- '"nvptxcompiler"'
|
||||||
|
cuda-version: '13.0'
|
||||||
|
flags: ''
|
||||||
|
runner_dir: 'cuda_v13'
|
||||||
- os: windows
|
- os: windows
|
||||||
arch: amd64
|
arch: amd64
|
||||||
preset: 'ROCm 6'
|
preset: 'ROCm 6'
|
||||||
install: https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-24.Q4-WinSvr2022-For-HIP.exe
|
install: https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-24.Q4-WinSvr2022-For-HIP.exe
|
||||||
rocm-version: '6.2'
|
rocm-version: '6.2'
|
||||||
flags: '-DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_C_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma" -DCMAKE_CXX_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma"'
|
flags: '-DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_C_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma" -DCMAKE_CXX_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma"'
|
||||||
|
runner_dir: ''
|
||||||
runs-on: ${{ matrix.arch == 'arm64' && format('{0}-{1}', matrix.os, matrix.arch) || matrix.os }}
|
runs-on: ${{ matrix.arch == 'arm64' && format('{0}-{1}', matrix.os, matrix.arch) || matrix.os }}
|
||||||
environment: release
|
environment: release
|
||||||
env:
|
env:
|
||||||
@@ -96,7 +118,7 @@ jobs:
|
|||||||
$ErrorActionPreference = "Stop"
|
$ErrorActionPreference = "Stop"
|
||||||
if ("${{ steps.cache-install.outputs.cache-hit }}" -ne 'true') {
|
if ("${{ steps.cache-install.outputs.cache-hit }}" -ne 'true') {
|
||||||
Invoke-WebRequest -Uri "${{ matrix.install }}" -OutFile "install.exe"
|
Invoke-WebRequest -Uri "${{ matrix.install }}" -OutFile "install.exe"
|
||||||
$subpackages = @("cudart", "nvcc", "cublas", "cublas_dev") | Foreach-Object {"${_}_${{ matrix.cuda-version }}"}
|
$subpackages = @(${{ join(matrix.cuda-components, ', ') }}) | Foreach-Object {"${_}_${{ matrix.cuda-version }}"}
|
||||||
Start-Process -FilePath .\install.exe -ArgumentList (@("-s") + $subpackages) -NoNewWindow -Wait
|
Start-Process -FilePath .\install.exe -ArgumentList (@("-s") + $subpackages) -NoNewWindow -Wait
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -138,7 +160,7 @@ jobs:
|
|||||||
run: |
|
run: |
|
||||||
Import-Module 'C:\Program Files\Microsoft Visual Studio\2022\Enterprise\Common7\Tools\Microsoft.VisualStudio.DevShell.dll'
|
Import-Module 'C:\Program Files\Microsoft Visual Studio\2022\Enterprise\Common7\Tools\Microsoft.VisualStudio.DevShell.dll'
|
||||||
Enter-VsDevShell -VsInstallPath 'C:\Program Files\Microsoft Visual Studio\2022\Enterprise' -SkipAutomaticLocation -DevCmdArguments '-arch=x64 -no_logo'
|
Enter-VsDevShell -VsInstallPath 'C:\Program Files\Microsoft Visual Studio\2022\Enterprise' -SkipAutomaticLocation -DevCmdArguments '-arch=x64 -no_logo'
|
||||||
cmake --preset "${{ matrix.preset }}" ${{ matrix.flags }}
|
cmake --preset "${{ matrix.preset }}" ${{ matrix.flags }} -DOLLAMA_RUNNER_DIR="${{ matrix.runner_dir }}"
|
||||||
cmake --build --parallel --preset "${{ matrix.preset }}"
|
cmake --build --parallel --preset "${{ matrix.preset }}"
|
||||||
cmake --install build --component "${{ startsWith(matrix.preset, 'CUDA ') && 'CUDA' || startsWith(matrix.preset, 'ROCm ') && 'HIP' || 'CPU' }}" --strip --parallel 8
|
cmake --install build --component "${{ startsWith(matrix.preset, 'CUDA ') && 'CUDA' || startsWith(matrix.preset, 'ROCm ') && 'HIP' || 'CPU' }}" --strip --parallel 8
|
||||||
env:
|
env:
|
||||||
@@ -232,7 +254,7 @@ jobs:
|
|||||||
case "$COMPONENT" in
|
case "$COMPONENT" in
|
||||||
bin/ollama) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
|
bin/ollama) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
|
||||||
lib/ollama/*.so*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
|
lib/ollama/*.so*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
|
||||||
lib/ollama/cuda_sbsa) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
|
lib/ollama/cuda_v*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
|
||||||
lib/ollama/cuda_jetpack5) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-jetpack5.tar.in ;;
|
lib/ollama/cuda_jetpack5) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-jetpack5.tar.in ;;
|
||||||
lib/ollama/cuda_jetpack6) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-jetpack6.tar.in ;;
|
lib/ollama/cuda_jetpack6) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-jetpack6.tar.in ;;
|
||||||
lib/ollama/rocm) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-rocm.tar.in ;;
|
lib/ollama/rocm) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-rocm.tar.in ;;
|
||||||
|
|||||||
16
.github/workflows/test.yaml
vendored
16
.github/workflows/test.yaml
vendored
@@ -46,7 +46,7 @@ jobs:
|
|||||||
include:
|
include:
|
||||||
- preset: CPU
|
- preset: CPU
|
||||||
- preset: CUDA
|
- preset: CUDA
|
||||||
container: nvidia/cuda:12.8.1-devel-ubuntu22.04
|
container: nvidia/cuda:13.0.0-devel-ubuntu22.04
|
||||||
flags: '-DCMAKE_CUDA_ARCHITECTURES=87'
|
flags: '-DCMAKE_CUDA_ARCHITECTURES=87'
|
||||||
- preset: ROCm
|
- preset: ROCm
|
||||||
container: rocm/dev-ubuntu-22.04:6.1.2
|
container: rocm/dev-ubuntu-22.04:6.1.2
|
||||||
@@ -78,8 +78,17 @@ jobs:
|
|||||||
include:
|
include:
|
||||||
- preset: CPU
|
- preset: CPU
|
||||||
- preset: CUDA
|
- preset: CUDA
|
||||||
install: https://developer.download.nvidia.com/compute/cuda/12.8.0/local_installers/cuda_12.8.0_571.96_windows.exe
|
install: https://developer.download.nvidia.com/compute/cuda/13.0.0/local_installers/cuda_13.0.0_windows.exe
|
||||||
flags: '-DCMAKE_CUDA_ARCHITECTURES=80'
|
flags: '-DCMAKE_CUDA_ARCHITECTURES=80'
|
||||||
|
cuda-components:
|
||||||
|
- '"cudart"'
|
||||||
|
- '"nvcc"'
|
||||||
|
- '"cublas"'
|
||||||
|
- '"cublas_dev"'
|
||||||
|
- '"crt"'
|
||||||
|
- '"nvvm"'
|
||||||
|
- '"nvptxcompiler"'
|
||||||
|
cuda-version: '13.0'
|
||||||
- preset: ROCm
|
- preset: ROCm
|
||||||
install: https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-24.Q4-WinSvr2022-For-HIP.exe
|
install: https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-24.Q4-WinSvr2022-For-HIP.exe
|
||||||
flags: '-DAMDGPU_TARGETS=gfx1010 -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_C_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma" -DCMAKE_CXX_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma"'
|
flags: '-DAMDGPU_TARGETS=gfx1010 -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_C_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma" -DCMAKE_CXX_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma"'
|
||||||
@@ -102,7 +111,8 @@ jobs:
|
|||||||
$ErrorActionPreference = "Stop"
|
$ErrorActionPreference = "Stop"
|
||||||
if ("${{ steps.cache-install.outputs.cache-hit }}" -ne 'true') {
|
if ("${{ steps.cache-install.outputs.cache-hit }}" -ne 'true') {
|
||||||
Invoke-WebRequest -Uri "${{ matrix.install }}" -OutFile "install.exe"
|
Invoke-WebRequest -Uri "${{ matrix.install }}" -OutFile "install.exe"
|
||||||
Start-Process -FilePath .\install.exe -ArgumentList (@("-s", "cudart_12.8", "nvcc_12.8", "cublas_12.8", "cublas_dev_12.8")) -NoNewWindow -Wait
|
$subpackages = @(${{ join(matrix.cuda-components, ', ') }}) | Foreach-Object {"${_}_${{ matrix.cuda-version }}"}
|
||||||
|
Start-Process -FilePath .\install.exe -ArgumentList (@("-s") + $subpackages) -NoNewWindow -Wait
|
||||||
}
|
}
|
||||||
|
|
||||||
$cudaPath = (Resolve-Path "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\*").path
|
$cudaPath = (Resolve-Path "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\*").path
|
||||||
|
|||||||
@@ -38,7 +38,7 @@ if (CMAKE_OSX_ARCHITECTURES MATCHES "x86_64")
|
|||||||
endif()
|
endif()
|
||||||
|
|
||||||
set(OLLAMA_BUILD_DIR ${CMAKE_BINARY_DIR}/lib/ollama)
|
set(OLLAMA_BUILD_DIR ${CMAKE_BINARY_DIR}/lib/ollama)
|
||||||
set(OLLAMA_INSTALL_DIR ${CMAKE_INSTALL_PREFIX}/lib/ollama)
|
set(OLLAMA_INSTALL_DIR ${CMAKE_INSTALL_PREFIX}/lib/ollama/${OLLAMA_RUNNER_DIR})
|
||||||
|
|
||||||
set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${OLLAMA_BUILD_DIR})
|
set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${OLLAMA_BUILD_DIR})
|
||||||
set(CMAKE_RUNTIME_OUTPUT_DIRECTORY_DEBUG ${OLLAMA_BUILD_DIR})
|
set(CMAKE_RUNTIME_OUTPUT_DIRECTORY_DEBUG ${OLLAMA_BUILD_DIR})
|
||||||
@@ -81,7 +81,7 @@ if(CMAKE_CUDA_COMPILER)
|
|||||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-cuda)
|
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-cuda)
|
||||||
install(TARGETS ggml-cuda
|
install(TARGETS ggml-cuda
|
||||||
RUNTIME_DEPENDENCIES
|
RUNTIME_DEPENDENCIES
|
||||||
DIRECTORIES ${CUDAToolkit_BIN_DIR} ${CUDAToolkit_LIBRARY_DIR}
|
DIRECTORIES ${CUDAToolkit_BIN_DIR} ${CUDAToolkit_BIN_DIR}/x64 ${CUDAToolkit_LIBRARY_DIR}
|
||||||
PRE_INCLUDE_REGEXES cublas cublasLt cudart
|
PRE_INCLUDE_REGEXES cublas cublasLt cudart
|
||||||
PRE_EXCLUDE_REGEXES ".*"
|
PRE_EXCLUDE_REGEXES ".*"
|
||||||
RUNTIME DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT CUDA
|
RUNTIME DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT CUDA
|
||||||
|
|||||||
@@ -18,6 +18,14 @@
|
|||||||
"name": "CUDA",
|
"name": "CUDA",
|
||||||
"inherits": [ "Default" ]
|
"inherits": [ "Default" ]
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"name": "CUDA 11",
|
||||||
|
"inherits": [ "CUDA" ],
|
||||||
|
"cacheVariables": {
|
||||||
|
"CMAKE_CUDA_ARCHITECTURES": "50-virtual;60-virtual;61-virtual;70-virtual;75-virtual;80-virtual;86-virtual;87-virtual;89-virtual;90-virtual",
|
||||||
|
"CMAKE_CUDA_FLAGS": "-Wno-deprecated-gpu-targets -t 2"
|
||||||
|
}
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"name": "CUDA 12",
|
"name": "CUDA 12",
|
||||||
"inherits": [ "CUDA" ],
|
"inherits": [ "CUDA" ],
|
||||||
@@ -26,6 +34,14 @@
|
|||||||
"CMAKE_CUDA_FLAGS": "-Wno-deprecated-gpu-targets -t 2"
|
"CMAKE_CUDA_FLAGS": "-Wno-deprecated-gpu-targets -t 2"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"name": "CUDA 13",
|
||||||
|
"inherits": [ "CUDA" ],
|
||||||
|
"cacheVariables": {
|
||||||
|
"CMAKE_CUDA_ARCHITECTURES": "75-virtual;80-virtual;86-virtual;87-virtual;89-virtual;90-virtual;90a-virtual;100-virtual;110-virtual;120-virtual;121-virtual",
|
||||||
|
"CMAKE_CUDA_FLAGS": "-t 2"
|
||||||
|
}
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"name": "JetPack 5",
|
"name": "JetPack 5",
|
||||||
"inherits": [ "CUDA" ],
|
"inherits": [ "CUDA" ],
|
||||||
@@ -72,11 +88,21 @@
|
|||||||
"configurePreset": "CUDA",
|
"configurePreset": "CUDA",
|
||||||
"targets": [ "ggml-cuda" ]
|
"targets": [ "ggml-cuda" ]
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"name": "CUDA 11",
|
||||||
|
"inherits": [ "CUDA" ],
|
||||||
|
"configurePreset": "CUDA 11"
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"name": "CUDA 12",
|
"name": "CUDA 12",
|
||||||
"inherits": [ "CUDA" ],
|
"inherits": [ "CUDA" ],
|
||||||
"configurePreset": "CUDA 12"
|
"configurePreset": "CUDA 12"
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"name": "CUDA 13",
|
||||||
|
"inherits": [ "CUDA" ],
|
||||||
|
"configurePreset": "CUDA 13"
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"name": "JetPack 5",
|
"name": "JetPack 5",
|
||||||
"inherits": [ "CUDA" ],
|
"inherits": [ "CUDA" ],
|
||||||
|
|||||||
30
Dockerfile
30
Dockerfile
@@ -39,15 +39,35 @@ RUN --mount=type=cache,target=/root/.ccache \
|
|||||||
&& cmake --build --parallel --preset 'CPU' \
|
&& cmake --build --parallel --preset 'CPU' \
|
||||||
&& cmake --install build --component CPU --strip --parallel 8
|
&& cmake --install build --component CPU --strip --parallel 8
|
||||||
|
|
||||||
|
FROM base AS cuda-11
|
||||||
|
ARG CUDA11VERSION=11.8
|
||||||
|
RUN dnf install -y cuda-toolkit-${CUDA11VERSION//./-}
|
||||||
|
ENV PATH=/usr/local/cuda-11/bin:$PATH
|
||||||
|
RUN --mount=type=cache,target=/root/.ccache \
|
||||||
|
cmake --preset 'CUDA 11' -DOLLAMA_RUNNER_DIR="cuda_v11" \
|
||||||
|
&& cmake --build --parallel --preset 'CUDA 11' \
|
||||||
|
&& cmake --install build --component CUDA --strip --parallel 8
|
||||||
|
|
||||||
FROM base AS cuda-12
|
FROM base AS cuda-12
|
||||||
ARG CUDA12VERSION=12.8
|
ARG CUDA12VERSION=12.8
|
||||||
RUN dnf install -y cuda-toolkit-${CUDA12VERSION//./-}
|
RUN dnf install -y cuda-toolkit-${CUDA12VERSION//./-}
|
||||||
ENV PATH=/usr/local/cuda-12/bin:$PATH
|
ENV PATH=/usr/local/cuda-12/bin:$PATH
|
||||||
RUN --mount=type=cache,target=/root/.ccache \
|
RUN --mount=type=cache,target=/root/.ccache \
|
||||||
cmake --preset 'CUDA 12' \
|
cmake --preset 'CUDA 12' -DOLLAMA_RUNNER_DIR="cuda_v12"\
|
||||||
&& cmake --build --parallel --preset 'CUDA 12' \
|
&& cmake --build --parallel --preset 'CUDA 12' \
|
||||||
&& cmake --install build --component CUDA --strip --parallel 8
|
&& cmake --install build --component CUDA --strip --parallel 8
|
||||||
|
|
||||||
|
|
||||||
|
FROM base AS cuda-13
|
||||||
|
ARG CUDA13VERSION=13.0
|
||||||
|
RUN dnf install -y cuda-toolkit-${CUDA13VERSION//./-}
|
||||||
|
ENV PATH=/usr/local/cuda-13/bin:$PATH
|
||||||
|
RUN --mount=type=cache,target=/root/.ccache \
|
||||||
|
cmake --preset 'CUDA 13' -DOLLAMA_RUNNER_DIR="cuda_v13" \
|
||||||
|
&& cmake --build --parallel --preset 'CUDA 13' \
|
||||||
|
&& cmake --install build --component CUDA --strip --parallel 8
|
||||||
|
|
||||||
|
|
||||||
FROM base AS rocm-6
|
FROM base AS rocm-6
|
||||||
ENV PATH=/opt/rocm/hcc/bin:/opt/rocm/hip/bin:/opt/rocm/bin:/opt/rocm/hcc/bin:$PATH
|
ENV PATH=/opt/rocm/hcc/bin:/opt/rocm/hip/bin:/opt/rocm/bin:/opt/rocm/hcc/bin:$PATH
|
||||||
RUN --mount=type=cache,target=/root/.ccache \
|
RUN --mount=type=cache,target=/root/.ccache \
|
||||||
@@ -92,10 +112,14 @@ RUN --mount=type=cache,target=/root/.cache/go-build \
|
|||||||
go build -trimpath -buildmode=pie -o /bin/ollama .
|
go build -trimpath -buildmode=pie -o /bin/ollama .
|
||||||
|
|
||||||
FROM --platform=linux/amd64 scratch AS amd64
|
FROM --platform=linux/amd64 scratch AS amd64
|
||||||
COPY --from=cuda-12 dist/lib/ollama /lib/ollama
|
# COPY --from=cuda-11 dist/lib/ollama/ /lib/ollama/
|
||||||
|
COPY --from=cuda-12 dist/lib/ollama /lib/ollama/
|
||||||
|
COPY --from=cuda-13 dist/lib/ollama/ /lib/ollama/
|
||||||
|
|
||||||
FROM --platform=linux/arm64 scratch AS arm64
|
FROM --platform=linux/arm64 scratch AS arm64
|
||||||
COPY --from=cuda-12 dist/lib/ollama /lib/ollama/cuda_sbsa
|
# COPY --from=cuda-11 dist/lib/ollama/ /lib/ollama/
|
||||||
|
COPY --from=cuda-12 dist/lib/ollama /lib/ollama/
|
||||||
|
COPY --from=cuda-13 dist/lib/ollama/ /lib/ollama/
|
||||||
COPY --from=jetpack-5 dist/lib/ollama /lib/ollama/cuda_jetpack5
|
COPY --from=jetpack-5 dist/lib/ollama /lib/ollama/cuda_jetpack5
|
||||||
COPY --from=jetpack-6 dist/lib/ollama /lib/ollama/cuda_jetpack6
|
COPY --from=jetpack-6 dist/lib/ollama /lib/ollama/cuda_jetpack6
|
||||||
|
|
||||||
|
|||||||
@@ -413,6 +413,8 @@ See the [API documentation](./docs/api.md) for all endpoints.
|
|||||||
- [Mayan EDMS](https://gitlab.com/mayan-edms/mayan-edms) (Open source document management system to organize, tag, search, and automate your files with powerful Ollama driven workflows.)
|
- [Mayan EDMS](https://gitlab.com/mayan-edms/mayan-edms) (Open source document management system to organize, tag, search, and automate your files with powerful Ollama driven workflows.)
|
||||||
- [Serene Pub](https://github.com/doolijb/serene-pub) (Beginner friendly, open source AI Roleplaying App for Windows, Mac OS and Linux. Search, download and use models with Ollama all inside the app.)
|
- [Serene Pub](https://github.com/doolijb/serene-pub) (Beginner friendly, open source AI Roleplaying App for Windows, Mac OS and Linux. Search, download and use models with Ollama all inside the app.)
|
||||||
- [Andes](https://github.com/aqerd/andes) (A Visual Studio Code extension that provides a local UI interface for Ollama models)
|
- [Andes](https://github.com/aqerd/andes) (A Visual Studio Code extension that provides a local UI interface for Ollama models)
|
||||||
|
- [Clueless](https://github.com/KashyapTan/clueless) (Open Source & Local Cluely: A desktop application LLM assistant to help you talk to anything on your screen using locally served Ollama models. Also undetectable to screenshare)
|
||||||
|
- [ollama-co2](https://github.com/carbonatedWaterOrg/ollama-co2) (FastAPI web interface for monitoring and managing local and remote Ollama servers with real-time model monitoring and concurrent downloads)
|
||||||
|
|
||||||
### Cloud
|
### Cloud
|
||||||
|
|
||||||
@@ -541,6 +543,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
|
|||||||
- [OllamaPlusPlus](https://github.com/HardCodeDev777/OllamaPlusPlus) (Very simple C++ library for Ollama)
|
- [OllamaPlusPlus](https://github.com/HardCodeDev777/OllamaPlusPlus) (Very simple C++ library for Ollama)
|
||||||
- [any-llm](https://github.com/mozilla-ai/any-llm) (A single interface to use different llm providers by [mozilla.ai](https://www.mozilla.ai/))
|
- [any-llm](https://github.com/mozilla-ai/any-llm) (A single interface to use different llm providers by [mozilla.ai](https://www.mozilla.ai/))
|
||||||
- [any-agent](https://github.com/mozilla-ai/any-agent) (A single interface to use and evaluate different agent frameworks by [mozilla.ai](https://www.mozilla.ai/))
|
- [any-agent](https://github.com/mozilla-ai/any-agent) (A single interface to use and evaluate different agent frameworks by [mozilla.ai](https://www.mozilla.ai/))
|
||||||
|
- [Neuro SAN](https://github.com/cognizant-ai-lab/neuro-san-studio) (Data-driven multi-agent orchestration framework) with [example](https://github.com/cognizant-ai-lab/neuro-san-studio/blob/main/docs/user_guide.md#ollama)
|
||||||
|
|
||||||
### Mobile
|
### Mobile
|
||||||
|
|
||||||
@@ -601,6 +604,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
|
|||||||
- [UnityCodeLama](https://github.com/HardCodeDev777/UnityCodeLama) (Unity Edtior tool to analyze scripts via Ollama)
|
- [UnityCodeLama](https://github.com/HardCodeDev777/UnityCodeLama) (Unity Edtior tool to analyze scripts via Ollama)
|
||||||
- [NativeMind](https://github.com/NativeMindBrowser/NativeMindExtension) (Private, on-device AI Assistant, no cloud dependencies)
|
- [NativeMind](https://github.com/NativeMindBrowser/NativeMindExtension) (Private, on-device AI Assistant, no cloud dependencies)
|
||||||
- [GMAI - Gradle Managed AI](https://gmai.premex.se/) (Gradle plugin for automated Ollama lifecycle management during build phases)
|
- [GMAI - Gradle Managed AI](https://gmai.premex.se/) (Gradle plugin for automated Ollama lifecycle management during build phases)
|
||||||
|
- [NOMYO Router](https://github.com/nomyo-ai/nomyo-router) (A transparent Ollama proxy with model deployment aware routing which auto-manages multiple Ollama instances in a given network)
|
||||||
|
|
||||||
### Supported backends
|
### Supported backends
|
||||||
|
|
||||||
|
|||||||
31
api/types.go
31
api/types.go
@@ -286,16 +286,23 @@ func mapToTypeScriptType(jsonType string) string {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type ToolFunctionParameters struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
Defs any `json:"$defs,omitempty"`
|
||||||
|
Items any `json:"items,omitempty"`
|
||||||
|
Required []string `json:"required"`
|
||||||
|
Properties map[string]ToolProperty `json:"properties"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *ToolFunctionParameters) String() string {
|
||||||
|
bts, _ := json.Marshal(t)
|
||||||
|
return string(bts)
|
||||||
|
}
|
||||||
|
|
||||||
type ToolFunction struct {
|
type ToolFunction struct {
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
Description string `json:"description"`
|
Description string `json:"description"`
|
||||||
Parameters struct {
|
Parameters ToolFunctionParameters `json:"parameters"`
|
||||||
Type string `json:"type"`
|
|
||||||
Defs any `json:"$defs,omitempty"`
|
|
||||||
Items any `json:"items,omitempty"`
|
|
||||||
Required []string `json:"required"`
|
|
||||||
Properties map[string]ToolProperty `json:"properties"`
|
|
||||||
} `json:"parameters"`
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *ToolFunction) String() string {
|
func (t *ToolFunction) String() string {
|
||||||
@@ -381,8 +388,12 @@ type EmbedRequest struct {
|
|||||||
// this request.
|
// this request.
|
||||||
KeepAlive *Duration `json:"keep_alive,omitempty"`
|
KeepAlive *Duration `json:"keep_alive,omitempty"`
|
||||||
|
|
||||||
|
// Truncate truncates the input to fit the model's max sequence length.
|
||||||
Truncate *bool `json:"truncate,omitempty"`
|
Truncate *bool `json:"truncate,omitempty"`
|
||||||
|
|
||||||
|
// Dimensions truncates the output embedding to the specified dimension.
|
||||||
|
Dimensions int `json:"dimensions,omitempty"`
|
||||||
|
|
||||||
// Options lists model-specific options.
|
// Options lists model-specific options.
|
||||||
Options map[string]any `json:"options"`
|
Options map[string]any `json:"options"`
|
||||||
}
|
}
|
||||||
@@ -881,7 +892,7 @@ func (d *Duration) UnmarshalJSON(b []byte) (err error) {
|
|||||||
if t < 0 {
|
if t < 0 {
|
||||||
d.Duration = time.Duration(math.MaxInt64)
|
d.Duration = time.Duration(math.MaxInt64)
|
||||||
} else {
|
} else {
|
||||||
d.Duration = time.Duration(int(t) * int(time.Second))
|
d.Duration = time.Duration(t * float64(time.Second))
|
||||||
}
|
}
|
||||||
case string:
|
case string:
|
||||||
d.Duration, err = time.ParseDuration(t)
|
d.Duration, err = time.ParseDuration(t)
|
||||||
|
|||||||
@@ -17,6 +17,11 @@ func TestKeepAliveParsingFromJSON(t *testing.T) {
|
|||||||
req string
|
req string
|
||||||
exp *Duration
|
exp *Duration
|
||||||
}{
|
}{
|
||||||
|
{
|
||||||
|
name: "Unset",
|
||||||
|
req: `{ }`,
|
||||||
|
exp: nil,
|
||||||
|
},
|
||||||
{
|
{
|
||||||
name: "Positive Integer",
|
name: "Positive Integer",
|
||||||
req: `{ "keep_alive": 42 }`,
|
req: `{ "keep_alive": 42 }`,
|
||||||
@@ -25,7 +30,7 @@ func TestKeepAliveParsingFromJSON(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "Positive Float",
|
name: "Positive Float",
|
||||||
req: `{ "keep_alive": 42.5 }`,
|
req: `{ "keep_alive": 42.5 }`,
|
||||||
exp: &Duration{42 * time.Second},
|
exp: &Duration{42500 * time.Millisecond},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Positive Integer String",
|
name: "Positive Integer String",
|
||||||
@@ -436,3 +441,50 @@ func TestThinking_UnmarshalJSON(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestToolFunctionParameters_String(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
params ToolFunctionParameters
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "simple object with string property",
|
||||||
|
params: ToolFunctionParameters{
|
||||||
|
Type: "object",
|
||||||
|
Required: []string{"name"},
|
||||||
|
Properties: map[string]ToolProperty{
|
||||||
|
"name": {
|
||||||
|
Type: PropertyType{"string"},
|
||||||
|
Description: "The name of the person",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expected: `{"type":"object","required":["name"],"properties":{"name":{"type":"string","description":"The name of the person"}}}`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "marshal failure returns empty string",
|
||||||
|
params: ToolFunctionParameters{
|
||||||
|
Type: "object",
|
||||||
|
Defs: func() any {
|
||||||
|
// Create a cycle that will cause json.Marshal to fail
|
||||||
|
type selfRef struct {
|
||||||
|
Self *selfRef
|
||||||
|
}
|
||||||
|
s := &selfRef{}
|
||||||
|
s.Self = s
|
||||||
|
return s
|
||||||
|
}(),
|
||||||
|
Properties: map[string]ToolProperty{},
|
||||||
|
},
|
||||||
|
expected: "",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, test := range tests {
|
||||||
|
t.Run(test.name, func(t *testing.T) {
|
||||||
|
result := test.params.String()
|
||||||
|
assert.Equal(t, test.expected, result)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -56,10 +56,8 @@ func ensureThinkingSupport(ctx context.Context, client *api.Client, name string)
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
for _, cap := range resp.Capabilities {
|
if slices.Contains(resp.Capabilities, model.CapabilityThinking) {
|
||||||
if cap == model.CapabilityThinking {
|
return
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
fmt.Fprintf(os.Stderr, "warning: model %q does not support thinking output\n", name)
|
fmt.Fprintf(os.Stderr, "warning: model %q does not support thinking output\n", name)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -28,6 +28,7 @@ type bertModel struct {
|
|||||||
LayerNormEPS float32 `json:"layer_norm_eps"`
|
LayerNormEPS float32 `json:"layer_norm_eps"`
|
||||||
LayerNormEpsilon float32 `json:"layer_norm_epsilon"`
|
LayerNormEpsilon float32 `json:"layer_norm_epsilon"`
|
||||||
NormEpsilon float32 `json:"norm_epsilon"`
|
NormEpsilon float32 `json:"norm_epsilon"`
|
||||||
|
normalizeEmbeddings bool
|
||||||
|
|
||||||
PoolingType uint32
|
PoolingType uint32
|
||||||
}
|
}
|
||||||
@@ -54,9 +55,11 @@ func (p *bertModel) parseMore(fsys fs.FS) error {
|
|||||||
|
|
||||||
var pooling string
|
var pooling string
|
||||||
for _, m := range modules {
|
for _, m := range modules {
|
||||||
if m.Type == "sentence_transformers.models.Pooling" {
|
switch m.Type {
|
||||||
|
case "sentence_transformers.models.Pooling":
|
||||||
pooling = m.Path
|
pooling = m.Path
|
||||||
break
|
case "sentence_transformers.models.Normalize":
|
||||||
|
p.normalizeEmbeddings = true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -90,6 +93,7 @@ func (p *bertModel) KV(t *Tokenizer) ggml.KV {
|
|||||||
kv["general.architecture"] = "bert"
|
kv["general.architecture"] = "bert"
|
||||||
kv["bert.attention.causal"] = false
|
kv["bert.attention.causal"] = false
|
||||||
kv["bert.pooling_type"] = p.PoolingType
|
kv["bert.pooling_type"] = p.PoolingType
|
||||||
|
kv["bert.normalize_embeddings"] = p.normalizeEmbeddings
|
||||||
|
|
||||||
kv["bert.block_count"] = cmp.Or(p.NLayers, p.NumHiddenLayers, p.NLayer)
|
kv["bert.block_count"] = cmp.Or(p.NLayers, p.NumHiddenLayers, p.NLayer)
|
||||||
|
|
||||||
|
|||||||
@@ -15,19 +15,24 @@ import (
|
|||||||
|
|
||||||
type gptossModel struct {
|
type gptossModel struct {
|
||||||
ModelParameters
|
ModelParameters
|
||||||
HiddenLayers uint32 `json:"num_hidden_layers"`
|
HiddenLayers uint32 `json:"num_hidden_layers"`
|
||||||
HiddenSize uint32 `json:"hidden_size"`
|
MaxPositionEmbeddings uint32 `json:"max_position_embeddings"`
|
||||||
IntermediateSize uint32 `json:"intermediate_size"`
|
HiddenSize uint32 `json:"hidden_size"`
|
||||||
AttentionHeads uint32 `json:"num_attention_heads"`
|
IntermediateSize uint32 `json:"intermediate_size"`
|
||||||
KeyValueHeads uint32 `json:"num_key_value_heads"`
|
AttentionHeads uint32 `json:"num_attention_heads"`
|
||||||
HeadDim uint32 `json:"head_dim"`
|
KeyValueHeads uint32 `json:"num_key_value_heads"`
|
||||||
Experts uint32 `json:"num_experts"`
|
HeadDim uint32 `json:"head_dim"`
|
||||||
ExpertsPerToken uint32 `json:"experts_per_token"`
|
Experts uint32 `json:"num_experts"`
|
||||||
RMSNormEpsilon float32 `json:"rms_norm_eps"`
|
LocalExperts uint32 `json:"num_local_experts"`
|
||||||
InitialContextLength uint32 `json:"initial_context_length"`
|
ExpertsPerToken uint32 `json:"experts_per_token"`
|
||||||
RopeTheta float32 `json:"rope_theta"`
|
RMSNormEpsilon float32 `json:"rms_norm_eps"`
|
||||||
RopeScalingFactor float32 `json:"rope_scaling_factor"`
|
InitialContextLength uint32 `json:"initial_context_length"`
|
||||||
SlidingWindow uint32 `json:"sliding_window"`
|
RopeTheta float32 `json:"rope_theta"`
|
||||||
|
RopeScalingFactor float32 `json:"rope_scaling_factor"`
|
||||||
|
RopeScaling struct {
|
||||||
|
Factor float32 `json:"factor"`
|
||||||
|
} `json:"rope_scaling"`
|
||||||
|
SlidingWindow uint32 `json:"sliding_window"`
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ ModelConverter = (*gptossModel)(nil)
|
var _ ModelConverter = (*gptossModel)(nil)
|
||||||
@@ -36,11 +41,11 @@ func (m *gptossModel) KV(t *Tokenizer) ggml.KV {
|
|||||||
kv := m.ModelParameters.KV(t)
|
kv := m.ModelParameters.KV(t)
|
||||||
kv["general.architecture"] = "gptoss"
|
kv["general.architecture"] = "gptoss"
|
||||||
kv["general.file_type"] = uint32(4)
|
kv["general.file_type"] = uint32(4)
|
||||||
kv["gptoss.context_length"] = uint32(m.RopeScalingFactor * float32(m.InitialContextLength))
|
kv["gptoss.context_length"] = cmp.Or(m.MaxPositionEmbeddings, uint32(m.RopeScalingFactor*float32(m.InitialContextLength)))
|
||||||
kv["gptoss.block_count"] = m.HiddenLayers
|
kv["gptoss.block_count"] = m.HiddenLayers
|
||||||
kv["gptoss.embedding_length"] = m.HiddenSize
|
kv["gptoss.embedding_length"] = m.HiddenSize
|
||||||
kv["gptoss.feed_forward_length"] = m.IntermediateSize
|
kv["gptoss.feed_forward_length"] = m.IntermediateSize
|
||||||
kv["gptoss.expert_count"] = m.Experts
|
kv["gptoss.expert_count"] = cmp.Or(m.Experts, m.LocalExperts)
|
||||||
kv["gptoss.expert_used_count"] = m.ExpertsPerToken
|
kv["gptoss.expert_used_count"] = m.ExpertsPerToken
|
||||||
kv["gptoss.attention.head_count"] = m.AttentionHeads
|
kv["gptoss.attention.head_count"] = m.AttentionHeads
|
||||||
kv["gptoss.attention.head_count_kv"] = m.KeyValueHeads
|
kv["gptoss.attention.head_count_kv"] = m.KeyValueHeads
|
||||||
@@ -49,7 +54,7 @@ func (m *gptossModel) KV(t *Tokenizer) ggml.KV {
|
|||||||
kv["gptoss.attention.layer_norm_rms_epsilon"] = cmp.Or(m.RMSNormEpsilon, 1e-5)
|
kv["gptoss.attention.layer_norm_rms_epsilon"] = cmp.Or(m.RMSNormEpsilon, 1e-5)
|
||||||
kv["gptoss.attention.sliding_window"] = m.SlidingWindow
|
kv["gptoss.attention.sliding_window"] = m.SlidingWindow
|
||||||
kv["gptoss.rope.freq_base"] = m.RopeTheta
|
kv["gptoss.rope.freq_base"] = m.RopeTheta
|
||||||
kv["gptoss.rope.scaling.factor"] = m.RopeScalingFactor
|
kv["gptoss.rope.scaling.factor"] = cmp.Or(m.RopeScalingFactor, m.RopeScaling.Factor)
|
||||||
kv["gptoss.rope.scaling.original_context_length"] = m.InitialContextLength
|
kv["gptoss.rope.scaling.original_context_length"] = m.InitialContextLength
|
||||||
kv["tokenizer.ggml.bos_token_id"] = uint32(199998) // <|startoftext|>
|
kv["tokenizer.ggml.bos_token_id"] = uint32(199998) // <|startoftext|>
|
||||||
kv["tokenizer.ggml.add_bos_token"] = false
|
kv["tokenizer.ggml.add_bos_token"] = false
|
||||||
@@ -92,6 +97,11 @@ func (m *gptossModel) Tensors(ts []Tensor) []*ggml.Tensor {
|
|||||||
|
|
||||||
for name, mxfp4 := range mxfp4s {
|
for name, mxfp4 := range mxfp4s {
|
||||||
dims := mxfp4.blocks.Shape()
|
dims := mxfp4.blocks.Shape()
|
||||||
|
|
||||||
|
if !strings.HasSuffix(name, ".weight") {
|
||||||
|
name += ".weight"
|
||||||
|
}
|
||||||
|
|
||||||
out = append(out, &ggml.Tensor{
|
out = append(out, &ggml.Tensor{
|
||||||
Name: name,
|
Name: name,
|
||||||
Kind: uint32(ggml.TensorTypeMXFP4),
|
Kind: uint32(ggml.TensorTypeMXFP4),
|
||||||
@@ -104,25 +114,47 @@ func (m *gptossModel) Tensors(ts []Tensor) []*ggml.Tensor {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (m *gptossModel) Replacements() []string {
|
func (m *gptossModel) Replacements() []string {
|
||||||
return []string{
|
var replacements []string
|
||||||
// noop replacements so other replacements will not be applied
|
if m.MaxPositionEmbeddings > 0 {
|
||||||
".blocks", ".blocks",
|
// hf flavored model
|
||||||
".scales", ".scales",
|
replacements = []string{
|
||||||
// real replacements
|
"lm_head", "output",
|
||||||
"block", "blk",
|
"model.embed_tokens", "token_embd",
|
||||||
"attn.norm", "attn_norm",
|
"model.layers", "blk",
|
||||||
"attn.qkv", "attn_qkv",
|
"input_layernorm", "attn_norm",
|
||||||
"attn.sinks", "attn_sinks",
|
"self_attn.q_proj", "attn_q",
|
||||||
"attn.out", "attn_out",
|
"self_attn.k_proj", "attn_k",
|
||||||
"mlp.norm", "ffn_norm",
|
"self_attn.v_proj", "attn_v",
|
||||||
"mlp.gate", "ffn_gate_inp",
|
"self_attn.o_proj", "attn_out",
|
||||||
"mlp.mlp1_", "ffn_gate_up_exps.",
|
"self_attn.sinks", "attn_sinks",
|
||||||
"mlp.mlp2_", "ffn_down_exps.",
|
"post_attention_layernorm", "ffn_norm",
|
||||||
"embedding", "token_embd",
|
"mlp.router", "ffn_gate_inp",
|
||||||
"norm", "output_norm",
|
"mlp.experts.gate_up_proj_", "ffn_gate_up_exps.",
|
||||||
"unembedding", "output",
|
"mlp.experts.down_proj_", "ffn_down_exps.",
|
||||||
"scale", "weight",
|
"model.norm", "output_norm",
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
replacements = []string{
|
||||||
|
// noop replacements so other replacements will not be applied
|
||||||
|
".blocks", ".blocks",
|
||||||
|
".scales", ".scales",
|
||||||
|
// real replacements
|
||||||
|
"block", "blk",
|
||||||
|
"attn.norm", "attn_norm",
|
||||||
|
"attn.qkv", "attn_qkv",
|
||||||
|
"attn.sinks", "attn_sinks",
|
||||||
|
"attn.out", "attn_out",
|
||||||
|
"mlp.norm", "ffn_norm",
|
||||||
|
"mlp.gate", "ffn_gate_inp",
|
||||||
|
"mlp.mlp1_", "ffn_gate_up_exps.",
|
||||||
|
"mlp.mlp2_", "ffn_down_exps.",
|
||||||
|
"embedding", "token_embd",
|
||||||
|
"norm", "output_norm",
|
||||||
|
"unembedding", "output",
|
||||||
|
"scale", "weight",
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
return replacements
|
||||||
}
|
}
|
||||||
|
|
||||||
type mxfp4 struct {
|
type mxfp4 struct {
|
||||||
@@ -140,7 +172,20 @@ func (m *mxfp4) WriteTo(w io.Writer) (int64, error) {
|
|||||||
blocksDims[i] = int(d)
|
blocksDims[i] = int(d)
|
||||||
}
|
}
|
||||||
|
|
||||||
var blocks tensor.Tensor = tensor.New(tensor.WithShape(blocksDims...), tensor.WithBacking(b.Bytes()))
|
bts := b.Bytes()
|
||||||
|
var tmp [16]byte
|
||||||
|
for i := 0; i < b.Len(); i += 16 {
|
||||||
|
for j := range 8 {
|
||||||
|
// transform a1b2c3 ... x7y8z9 -> 71xa82yb93zc
|
||||||
|
a, b := bts[i+j], bts[i+j+8]
|
||||||
|
tmp[2*j+0] = (a & 0x0F) | (b << 4)
|
||||||
|
tmp[2*j+1] = (a >> 4) | (b & 0xF0)
|
||||||
|
}
|
||||||
|
|
||||||
|
copy(bts[i:i+16], tmp[:])
|
||||||
|
}
|
||||||
|
|
||||||
|
var blocks tensor.Tensor = tensor.New(tensor.WithShape(blocksDims...), tensor.WithBacking(bts))
|
||||||
|
|
||||||
var s bytes.Buffer
|
var s bytes.Buffer
|
||||||
if _, err := m.scales.WriteTo(&s); err != nil {
|
if _, err := m.scales.WriteTo(&s); err != nil {
|
||||||
@@ -174,5 +219,5 @@ func (m *mxfp4) WriteTo(w io.Writer) (int64, error) {
|
|||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return 0, nil
|
return int64(len(u8s)), nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -33,8 +33,8 @@ func (t tensorBase) Shape() []uint64 {
|
|||||||
const (
|
const (
|
||||||
tensorKindFP32 uint32 = iota
|
tensorKindFP32 uint32 = iota
|
||||||
tensorKindFP16
|
tensorKindFP16
|
||||||
tensorKindMXFP4 = 4
|
|
||||||
tensorKindBF16 = 30
|
tensorKindBF16 = 30
|
||||||
|
tensorKindMXFP4 = 39
|
||||||
)
|
)
|
||||||
|
|
||||||
func (t tensorBase) Kind() uint32 {
|
func (t tensorBase) Kind() uint32 {
|
||||||
|
|||||||
@@ -188,17 +188,17 @@ func (st safetensor) WriteTo(w io.Writer) (int64, error) {
|
|||||||
|
|
||||||
switch st.Kind() {
|
switch st.Kind() {
|
||||||
case tensorKindFP32:
|
case tensorKindFP32:
|
||||||
return 0, binary.Write(w, binary.LittleEndian, f32s)
|
return int64(len(f32s) * 4), binary.Write(w, binary.LittleEndian, f32s)
|
||||||
case tensorKindFP16:
|
case tensorKindFP16:
|
||||||
f16s := make([]uint16, len(f32s))
|
f16s := make([]uint16, len(f32s))
|
||||||
for i := range f32s {
|
for i := range f32s {
|
||||||
f16s[i] = float16.Fromfloat32(f32s[i]).Bits()
|
f16s[i] = float16.Fromfloat32(f32s[i]).Bits()
|
||||||
}
|
}
|
||||||
|
|
||||||
return 0, binary.Write(w, binary.LittleEndian, f16s)
|
return int64(len(f16s) * 2), binary.Write(w, binary.LittleEndian, f16s)
|
||||||
case tensorKindBF16:
|
case tensorKindBF16:
|
||||||
u8s := bfloat16.EncodeFloat32(f32s)
|
u8s := bfloat16.EncodeFloat32(f32s)
|
||||||
return 0, binary.Write(w, binary.LittleEndian, u8s)
|
return int64(len(u8s)), binary.Write(w, binary.LittleEndian, u8s)
|
||||||
default:
|
default:
|
||||||
return 0, fmt.Errorf("unknown storage type: %d", st.Kind())
|
return 0, fmt.Errorf("unknown storage type: %d", st.Kind())
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -277,6 +277,7 @@ func AMDGetGPUInfo() ([]RocmGPUInfo, error) {
|
|||||||
FreeMemory: (totalMemory - usedMemory),
|
FreeMemory: (totalMemory - usedMemory),
|
||||||
},
|
},
|
||||||
ID: ID,
|
ID: ID,
|
||||||
|
filterID: gpuOrdinalID,
|
||||||
Name: name,
|
Name: name,
|
||||||
Compute: fmt.Sprintf("gfx%d%x%x", major, minor, patch),
|
Compute: fmt.Sprintf("gfx%d%x%x", major, minor, patch),
|
||||||
MinimumMemory: rocmMinimumMemory,
|
MinimumMemory: rocmMinimumMemory,
|
||||||
@@ -394,7 +395,7 @@ func AMDGetGPUInfo() ([]RocmGPUInfo, error) {
|
|||||||
|
|
||||||
// Check for env var workarounds
|
// Check for env var workarounds
|
||||||
if name == "1002:687f" { // Vega RX 56
|
if name == "1002:687f" { // Vega RX 56
|
||||||
gpuInfo.EnvWorkarounds = append(gpuInfo.EnvWorkarounds, [2]string{"HSA_ENABLE_SDMA", "0"})
|
gpuInfo.EnvWorkarounds = append(gpuInfo.EnvWorkarounds, "HSA_ENABLE_SDMA=0")
|
||||||
}
|
}
|
||||||
|
|
||||||
// The GPU has passed all the verification steps and is supported
|
// The GPU has passed all the verification steps and is supported
|
||||||
@@ -523,19 +524,26 @@ func verifyKFDDriverAccess() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func rocmGetVisibleDevicesEnv(gpuInfo []GpuInfo) (string, string) {
|
func rocmGetVisibleDevicesEnv(gpuInfo []GpuInfo) string {
|
||||||
ids := []string{}
|
ids := []string{}
|
||||||
for _, info := range gpuInfo {
|
for _, info := range gpuInfo {
|
||||||
if info.Library != "rocm" {
|
if info.Library != "rocm" {
|
||||||
// TODO shouldn't happen if things are wired correctly...
|
|
||||||
slog.Debug("rocmGetVisibleDevicesEnv skipping over non-rocm device", "library", info.Library)
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
ids = append(ids, info.ID)
|
// If the devices requires a numeric ID, for filtering purposes, we use the unfiltered ID number
|
||||||
|
if _, err := strconv.Atoi(info.ID); err == nil {
|
||||||
|
ids = append(ids, fmt.Sprintf("%d", info.filterID))
|
||||||
|
} else {
|
||||||
|
ids = append(ids, info.ID)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
if len(ids) == 0 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
// There are 3 potential env vars to use to select GPUs.
|
// There are 3 potential env vars to use to select GPUs.
|
||||||
// ROCR_VISIBLE_DEVICES supports UUID or numeric so is our preferred on linux
|
// ROCR_VISIBLE_DEVICES supports UUID or numeric so is our preferred on linux
|
||||||
// GPU_DEVICE_ORDINAL supports numeric IDs only
|
// GPU_DEVICE_ORDINAL supports numeric IDs only
|
||||||
// HIP_VISIBLE_DEVICES supports numeric IDs only
|
// HIP_VISIBLE_DEVICES supports numeric IDs only
|
||||||
return "ROCR_VISIBLE_DEVICES", strings.Join(ids, ",")
|
return "ROCR_VISIBLE_DEVICES=" + strings.Join(ids, ",")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -111,6 +111,7 @@ func AMDGetGPUInfo() ([]RocmGPUInfo, error) {
|
|||||||
UnreliableFreeMemory: true,
|
UnreliableFreeMemory: true,
|
||||||
|
|
||||||
ID: strconv.Itoa(i), // TODO this is probably wrong if we specify visible devices
|
ID: strconv.Itoa(i), // TODO this is probably wrong if we specify visible devices
|
||||||
|
filterID: i,
|
||||||
DependencyPath: []string{libDir},
|
DependencyPath: []string{libDir},
|
||||||
MinimumMemory: rocmMinimumMemory,
|
MinimumMemory: rocmMinimumMemory,
|
||||||
Name: name,
|
Name: name,
|
||||||
@@ -200,19 +201,26 @@ func (gpus RocmGPUInfoList) RefreshFreeMemory() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func rocmGetVisibleDevicesEnv(gpuInfo []GpuInfo) (string, string) {
|
func rocmGetVisibleDevicesEnv(gpuInfo []GpuInfo) string {
|
||||||
ids := []string{}
|
ids := []string{}
|
||||||
for _, info := range gpuInfo {
|
for _, info := range gpuInfo {
|
||||||
if info.Library != "rocm" {
|
if info.Library != "rocm" {
|
||||||
// TODO shouldn't happen if things are wired correctly...
|
|
||||||
slog.Debug("rocmGetVisibleDevicesEnv skipping over non-rocm device", "library", info.Library)
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
ids = append(ids, info.ID)
|
// If the devices requires a numeric ID, for filtering purposes, we use the unfiltered ID number
|
||||||
|
if _, err := strconv.Atoi(info.ID); err == nil {
|
||||||
|
ids = append(ids, fmt.Sprintf("%d", info.filterID))
|
||||||
|
} else {
|
||||||
|
ids = append(ids, info.ID)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
if len(ids) == 0 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
// There are 3 potential env vars to use to select GPUs.
|
// There are 3 potential env vars to use to select GPUs.
|
||||||
// ROCR_VISIBLE_DEVICES supports UUID or numeric but does not work on Windows
|
// ROCR_VISIBLE_DEVICES supports UUID or numeric but does not work on Windows
|
||||||
// HIP_VISIBLE_DEVICES supports numeric IDs only
|
// HIP_VISIBLE_DEVICES supports numeric IDs only
|
||||||
// GPU_DEVICE_ORDINAL supports numeric IDs only
|
// GPU_DEVICE_ORDINAL supports numeric IDs only
|
||||||
return "HIP_VISIBLE_DEVICES", strings.Join(ids, ",")
|
return "HIP_VISIBLE_DEVICES=" + strings.Join(ids, ",")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -16,19 +16,6 @@ import (
|
|||||||
// Included to drive logic for reducing Ollama-allocated overhead on L4T/Jetson devices.
|
// Included to drive logic for reducing Ollama-allocated overhead on L4T/Jetson devices.
|
||||||
var CudaTegra string = os.Getenv("JETSON_JETPACK")
|
var CudaTegra string = os.Getenv("JETSON_JETPACK")
|
||||||
|
|
||||||
func cudaGetVisibleDevicesEnv(gpuInfo []GpuInfo) (string, string) {
|
|
||||||
ids := []string{}
|
|
||||||
for _, info := range gpuInfo {
|
|
||||||
if info.Library != "cuda" {
|
|
||||||
// TODO shouldn't happen if things are wired correctly...
|
|
||||||
slog.Debug("cudaGetVisibleDevicesEnv skipping over non-cuda device", "library", info.Library)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
ids = append(ids, info.ID)
|
|
||||||
}
|
|
||||||
return "CUDA_VISIBLE_DEVICES", strings.Join(ids, ",")
|
|
||||||
}
|
|
||||||
|
|
||||||
func cudaVariant(gpuInfo CudaGPUInfo) string {
|
func cudaVariant(gpuInfo CudaGPUInfo) string {
|
||||||
if runtime.GOARCH == "arm64" && runtime.GOOS == "linux" {
|
if runtime.GOARCH == "arm64" && runtime.GOOS == "linux" {
|
||||||
if CudaTegra != "" {
|
if CudaTegra != "" {
|
||||||
@@ -56,14 +43,15 @@ func cudaVariant(gpuInfo CudaGPUInfo) string {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return "sbsa"
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// driver 12.0 has problems with the cuda v12 library, so run v11 on those older drivers
|
if gpuInfo.DriverMajor < 13 {
|
||||||
if gpuInfo.DriverMajor < 12 || (gpuInfo.DriverMajor == 12 && gpuInfo.DriverMinor == 0) {
|
// The detected driver is older than 580 (Aug 2025)
|
||||||
// The detected driver is older than Feb 2023
|
// Warn if their CC is compatible with v13 and they should upgrade their driver to get better performance
|
||||||
slog.Warn("old CUDA driver detected - please upgrade to a newer driver", "version", fmt.Sprintf("%d.%d", gpuInfo.DriverMajor, gpuInfo.DriverMinor))
|
if gpuInfo.computeMajor > 7 || (gpuInfo.computeMajor == 7 && gpuInfo.computeMinor >= 5) {
|
||||||
return "v11"
|
slog.Warn("old CUDA driver detected - please upgrade to a newer driver for best performance", "version", fmt.Sprintf("%d.%d", gpuInfo.DriverMajor, gpuInfo.DriverMinor))
|
||||||
|
}
|
||||||
|
return "v12"
|
||||||
}
|
}
|
||||||
return "v12"
|
return "v13"
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -371,6 +371,15 @@ func GetGPUInfo() GpuInfoList {
|
|||||||
}
|
}
|
||||||
|
|
||||||
rocmGPUs, err = AMDGetGPUInfo()
|
rocmGPUs, err = AMDGetGPUInfo()
|
||||||
|
|
||||||
|
// The ID field is used in context of the filtered set of GPUS
|
||||||
|
// so we have to replace any of these numeric IDs with their
|
||||||
|
// placement in this set of GPUs
|
||||||
|
for i := range rocmGPUs {
|
||||||
|
if _, err := strconv.Atoi(rocmGPUs[i].ID); err == nil {
|
||||||
|
rocmGPUs[i].ID = strconv.Itoa(i)
|
||||||
|
}
|
||||||
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
bootstrapErrors = append(bootstrapErrors, err)
|
bootstrapErrors = append(bootstrapErrors, err)
|
||||||
}
|
}
|
||||||
@@ -680,23 +689,16 @@ func getVerboseState() C.uint16_t {
|
|||||||
|
|
||||||
// Given the list of GPUs this instantiation is targeted for,
|
// Given the list of GPUs this instantiation is targeted for,
|
||||||
// figure out the visible devices environment variable
|
// figure out the visible devices environment variable
|
||||||
//
|
func (l GpuInfoList) GetVisibleDevicesEnv() []string {
|
||||||
// If different libraries are detected, the first one is what we use
|
|
||||||
func (l GpuInfoList) GetVisibleDevicesEnv() (string, string) {
|
|
||||||
if len(l) == 0 {
|
if len(l) == 0 {
|
||||||
return "", ""
|
return nil
|
||||||
}
|
}
|
||||||
switch l[0].Library {
|
vd := []string{}
|
||||||
case "cuda":
|
// Only filter the AMD GPUs at this level, let all NVIDIA devices through
|
||||||
return cudaGetVisibleDevicesEnv(l)
|
if tmp := rocmGetVisibleDevicesEnv(l); tmp != "" {
|
||||||
case "rocm":
|
vd = append(vd, tmp)
|
||||||
return rocmGetVisibleDevicesEnv(l)
|
|
||||||
case "oneapi":
|
|
||||||
return oneapiGetVisibleDevicesEnv(l)
|
|
||||||
default:
|
|
||||||
slog.Debug("no filter required for library " + l[0].Library)
|
|
||||||
return "", ""
|
|
||||||
}
|
}
|
||||||
|
return vd
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetSystemInfo() SystemInfo {
|
func GetSystemInfo() SystemInfo {
|
||||||
|
|||||||
@@ -62,9 +62,9 @@ func GetCPUMem() (memInfo, error) {
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l GpuInfoList) GetVisibleDevicesEnv() (string, string) {
|
func (l GpuInfoList) GetVisibleDevicesEnv() []string {
|
||||||
// No-op on darwin
|
// No-op on darwin
|
||||||
return "", ""
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetSystemInfo() SystemInfo {
|
func GetSystemInfo() SystemInfo {
|
||||||
|
|||||||
@@ -1,21 +0,0 @@
|
|||||||
//go:build linux || windows
|
|
||||||
|
|
||||||
package discover
|
|
||||||
|
|
||||||
import (
|
|
||||||
"log/slog"
|
|
||||||
"strings"
|
|
||||||
)
|
|
||||||
|
|
||||||
func oneapiGetVisibleDevicesEnv(gpuInfo []GpuInfo) (string, string) {
|
|
||||||
ids := []string{}
|
|
||||||
for _, info := range gpuInfo {
|
|
||||||
if info.Library != "oneapi" {
|
|
||||||
// TODO shouldn't happen if things are wired correctly...
|
|
||||||
slog.Debug("oneapiGetVisibleDevicesEnv skipping over non-sycl device", "library", info.Library)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
ids = append(ids, info.ID)
|
|
||||||
}
|
|
||||||
return "ONEAPI_DEVICE_SELECTOR", "level_zero:" + strings.Join(ids, ",")
|
|
||||||
}
|
|
||||||
@@ -27,8 +27,8 @@ type GpuInfo struct { // TODO better name maybe "InferenceProcessor"?
|
|||||||
// Any extra PATH/LD_LIBRARY_PATH dependencies required for the Library to operate properly
|
// Any extra PATH/LD_LIBRARY_PATH dependencies required for the Library to operate properly
|
||||||
DependencyPath []string `json:"lib_path,omitempty"`
|
DependencyPath []string `json:"lib_path,omitempty"`
|
||||||
|
|
||||||
// Extra environment variables specific to the GPU as list of [key,value]
|
// Extra environment variables specific to the GPU as list of [key=value]
|
||||||
EnvWorkarounds [][2]string `json:"envs,omitempty"`
|
EnvWorkarounds []string `json:"envs,omitempty"`
|
||||||
|
|
||||||
// Set to true if we can NOT reliably discover FreeMemory. A value of true indicates
|
// Set to true if we can NOT reliably discover FreeMemory. A value of true indicates
|
||||||
// the FreeMemory is best effort, and may over or under report actual memory usage
|
// the FreeMemory is best effort, and may over or under report actual memory usage
|
||||||
@@ -36,9 +36,10 @@ type GpuInfo struct { // TODO better name maybe "InferenceProcessor"?
|
|||||||
UnreliableFreeMemory bool
|
UnreliableFreeMemory bool
|
||||||
|
|
||||||
// GPU information
|
// GPU information
|
||||||
ID string `json:"gpu_id"` // string to use for selection of this specific GPU
|
ID string `json:"gpu_id"` // string to use for selection of this specific GPU
|
||||||
Name string `json:"name"` // user friendly name if available
|
filterID int //nolint:unused,nolintlint // AMD Workaround: The numeric ID of the device used to filter out other devices
|
||||||
Compute string `json:"compute"` // Compute Capability or gfx
|
Name string `json:"name"` // user friendly name if available
|
||||||
|
Compute string `json:"compute"` // Compute Capability or gfx
|
||||||
|
|
||||||
// Driver Information - TODO no need to put this on each GPU
|
// Driver Information - TODO no need to put this on each GPU
|
||||||
DriverMajor int `json:"driver_major,omitempty"`
|
DriverMajor int `json:"driver_major,omitempty"`
|
||||||
|
|||||||
@@ -1708,6 +1708,7 @@ Advanced parameters:
|
|||||||
- `truncate`: truncates the end of each input to fit within context length. Returns error if `false` and context length is exceeded. Defaults to `true`
|
- `truncate`: truncates the end of each input to fit within context length. Returns error if `false` and context length is exceeded. Defaults to `true`
|
||||||
- `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.md#valid-parameters-and-values) such as `temperature`
|
- `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.md#valid-parameters-and-values) such as `temperature`
|
||||||
- `keep_alive`: controls how long the model will stay loaded into memory following the request (default: `5m`)
|
- `keep_alive`: controls how long the model will stay loaded into memory following the request (default: `5m`)
|
||||||
|
- `dimensions`: number of dimensions for the embedding
|
||||||
|
|
||||||
### Examples
|
### Examples
|
||||||
|
|
||||||
|
|||||||
@@ -11,6 +11,10 @@ Then build and run Ollama from the root directory of the repository:
|
|||||||
go run . serve
|
go run . serve
|
||||||
```
|
```
|
||||||
|
|
||||||
|
> [!NOTE]
|
||||||
|
> Ollama includes native code compiled with CGO. From time to time these data structures can change and CGO can get out of sync resulting in unexpected crashes. You can force a full build of the native code by running `go clean -cache` first.
|
||||||
|
|
||||||
|
|
||||||
## macOS (Apple Silicon)
|
## macOS (Apple Silicon)
|
||||||
|
|
||||||
macOS Apple Silicon supports Metal which is built-in to the Ollama binary. No additional steps are required.
|
macOS Apple Silicon supports Metal which is built-in to the Ollama binary. No additional steps are required.
|
||||||
|
|||||||
@@ -11,12 +11,13 @@ curl -fsSL https://ollama.com/install.sh | sh
|
|||||||
## Manual install
|
## Manual install
|
||||||
|
|
||||||
> [!NOTE]
|
> [!NOTE]
|
||||||
> If you are upgrading from a prior version, you should remove the old libraries with `sudo rm -rf /usr/lib/ollama` first.
|
> If you are upgrading from a prior version, you **MUST** remove the old libraries with `sudo rm -rf /usr/lib/ollama` first.
|
||||||
|
|
||||||
Download and extract the package:
|
Download and extract the package:
|
||||||
|
|
||||||
```shell
|
```shell
|
||||||
curl -LO https://ollama.com/download/ollama-linux-amd64.tgz
|
curl -LO https://ollama.com/download/ollama-linux-amd64.tgz
|
||||||
|
sudo rm -rf /usr/lib/ollama
|
||||||
sudo tar -C /usr -xzf ollama-linux-amd64.tgz
|
sudo tar -C /usr -xzf ollama-linux-amd64.tgz
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|||||||
@@ -92,6 +92,9 @@ If none of those resolve the problem, gather additional information and file an
|
|||||||
- Set `CUDA_ERROR_LEVEL=50` and try again to get more diagnostic logs
|
- Set `CUDA_ERROR_LEVEL=50` and try again to get more diagnostic logs
|
||||||
- Check dmesg for any errors `sudo dmesg | grep -i nvrm` and `sudo dmesg | grep -i nvidia`
|
- Check dmesg for any errors `sudo dmesg | grep -i nvrm` and `sudo dmesg | grep -i nvidia`
|
||||||
|
|
||||||
|
You may get more details for initialization failures by enabling debug prints in the uvm driver. You should only use this temporarily while troubleshooting
|
||||||
|
- `sudo rmmod nvidia_uvm` then `sudo modprobe nvidia_uvm uvm_debug_prints=1`
|
||||||
|
|
||||||
|
|
||||||
## AMD GPU Discovery
|
## AMD GPU Discovery
|
||||||
|
|
||||||
|
|||||||
@@ -185,8 +185,6 @@ var (
|
|||||||
ContextLength = Uint("OLLAMA_CONTEXT_LENGTH", 4096)
|
ContextLength = Uint("OLLAMA_CONTEXT_LENGTH", 4096)
|
||||||
// Auth enables authentication between the Ollama client and server
|
// Auth enables authentication between the Ollama client and server
|
||||||
UseAuth = Bool("OLLAMA_AUTH")
|
UseAuth = Bool("OLLAMA_AUTH")
|
||||||
// Enable the new memory estimation logic
|
|
||||||
NewMemoryEstimates = Bool("OLLAMA_NEW_ESTIMATES")
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func String(s string) func() string {
|
func String(s string) func() string {
|
||||||
@@ -272,7 +270,6 @@ func AsMap() map[string]EnvVar {
|
|||||||
"OLLAMA_MULTIUSER_CACHE": {"OLLAMA_MULTIUSER_CACHE", MultiUserCache(), "Optimize prompt caching for multi-user scenarios"},
|
"OLLAMA_MULTIUSER_CACHE": {"OLLAMA_MULTIUSER_CACHE", MultiUserCache(), "Optimize prompt caching for multi-user scenarios"},
|
||||||
"OLLAMA_CONTEXT_LENGTH": {"OLLAMA_CONTEXT_LENGTH", ContextLength(), "Context length to use unless otherwise specified (default: 4096)"},
|
"OLLAMA_CONTEXT_LENGTH": {"OLLAMA_CONTEXT_LENGTH", ContextLength(), "Context length to use unless otherwise specified (default: 4096)"},
|
||||||
"OLLAMA_NEW_ENGINE": {"OLLAMA_NEW_ENGINE", NewEngine(), "Enable the new Ollama engine"},
|
"OLLAMA_NEW_ENGINE": {"OLLAMA_NEW_ENGINE", NewEngine(), "Enable the new Ollama engine"},
|
||||||
"OLLAMA_NEW_ESTIMATES": {"OLLAMA_NEW_ESTIMATES", NewMemoryEstimates(), "Enable the new memory estimation logic"},
|
|
||||||
|
|
||||||
// Informational
|
// Informational
|
||||||
"HTTP_PROXY": {"HTTP_PROXY", String("HTTP_PROXY")(), "HTTP proxy"},
|
"HTTP_PROXY": {"HTTP_PROXY", String("HTTP_PROXY")(), "HTTP proxy"},
|
||||||
|
|||||||
203
fs/ggml/ggml.go
203
fs/ggml/ggml.go
@@ -7,9 +7,11 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
|
"math"
|
||||||
"slices"
|
"slices"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/format"
|
||||||
"github.com/ollama/ollama/fs/util/bufioutil"
|
"github.com/ollama/ollama/fs/util/bufioutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -55,10 +57,28 @@ func (kv KV) EmbeddingLength() uint64 {
|
|||||||
return uint64(kv.Uint("embedding_length"))
|
return uint64(kv.Uint("embedding_length"))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (kv KV) HeadCount() []uint64 {
|
||||||
|
headCountDefault := uint32(1)
|
||||||
|
headCount := kv.UintOrArrayValueAsArray("attention.head_count", headCountDefault)
|
||||||
|
if len(headCount) == 1 {
|
||||||
|
headCountDefault = headCount[0]
|
||||||
|
}
|
||||||
|
nLayers := int(kv.BlockCount())
|
||||||
|
if len(headCount) > nLayers {
|
||||||
|
slog.Warn("got more elements of attention.head_count than layers", "len(headCount)", len(headCount), "layers", nLayers)
|
||||||
|
}
|
||||||
|
out := make([]uint64, nLayers)
|
||||||
|
for i := range nLayers {
|
||||||
|
if i >= len(headCount) {
|
||||||
|
out[i] = uint64(headCountDefault)
|
||||||
|
} else {
|
||||||
|
out[i] = uint64(headCount[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
func (kv KV) HeadCountMax() uint64 {
|
func (kv KV) HeadCountMax() uint64 {
|
||||||
// TODO(drifkin): using the max value can cause an overestimation. In the
|
|
||||||
// future if array values become more popular, we can adapt the more invasive
|
|
||||||
// <https://github.com/ollama/ollama/pull/10225>
|
|
||||||
return uint64(kv.UintOrMaxArrayValue("attention.head_count", 1))
|
return uint64(kv.UintOrMaxArrayValue("attention.head_count", 1))
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -66,6 +86,27 @@ func (kv KV) HeadCountMin() uint64 {
|
|||||||
return uint64(kv.UintOrMinArrayValue("attention.head_count", 1))
|
return uint64(kv.UintOrMinArrayValue("attention.head_count", 1))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (kv KV) HeadCountKV() []uint64 {
|
||||||
|
headCountKVDefault := uint32(1)
|
||||||
|
headCountKV := kv.UintOrArrayValueAsArray("attention.head_count_kv", headCountKVDefault)
|
||||||
|
if len(headCountKV) == 1 {
|
||||||
|
headCountKVDefault = headCountKV[0]
|
||||||
|
}
|
||||||
|
nLayers := int(kv.BlockCount())
|
||||||
|
if len(headCountKV) > nLayers {
|
||||||
|
slog.Warn("got more elements of attention.head_count than layers", "len(headCountKV)", len(headCountKV), "layers", nLayers)
|
||||||
|
}
|
||||||
|
out := make([]uint64, nLayers)
|
||||||
|
for i := range nLayers {
|
||||||
|
if i >= len(headCountKV) {
|
||||||
|
out[i] = uint64(headCountKVDefault)
|
||||||
|
} else {
|
||||||
|
out[i] = uint64(headCountKV[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
func (kv KV) HeadCountKVMax() uint64 {
|
func (kv KV) HeadCountKVMax() uint64 {
|
||||||
return uint64(kv.UintOrMaxArrayValue("attention.head_count_kv", 1))
|
return uint64(kv.UintOrMaxArrayValue("attention.head_count_kv", 1))
|
||||||
}
|
}
|
||||||
@@ -98,6 +139,26 @@ func (kv KV) ChatTemplate() string {
|
|||||||
return kv.String("tokenizer.chat_template")
|
return kv.String("tokenizer.chat_template")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ssm architecture parameters
|
||||||
|
|
||||||
|
func (kv KV) SSMConvKernel() uint64 {
|
||||||
|
return uint64(kv.Uint("ssm.conv_kernel"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (kv KV) SSMInnerSize() uint64 {
|
||||||
|
return uint64(kv.Uint("ssm.inner_size"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (kv KV) SSMStateSize() uint64 {
|
||||||
|
return uint64(kv.Uint("ssm.state_size"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (kv KV) SSMGroupCount() uint64 {
|
||||||
|
return uint64(kv.Uint("ssm.group_count"))
|
||||||
|
}
|
||||||
|
|
||||||
|
// general types
|
||||||
|
|
||||||
func (kv KV) String(key string, defaultValue ...string) string {
|
func (kv KV) String(key string, defaultValue ...string) string {
|
||||||
val, _ := keyValue(kv, key, append(defaultValue, "")...)
|
val, _ := keyValue(kv, key, append(defaultValue, "")...)
|
||||||
return val
|
return val
|
||||||
@@ -129,22 +190,27 @@ func (kv KV) UintOrMinArrayValue(key string, defaultValue uint32) uint32 {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (kv KV) UintOrArrayValue(key string, defaultValue uint32) (uint32, uint32) {
|
func (kv KV) UintOrArrayValue(key string, defaultValue uint32) (uint32, uint32) {
|
||||||
|
arrVal := kv.UintOrArrayValueAsArray(key, defaultValue)
|
||||||
|
return slices.Min(arrVal), slices.Max(arrVal)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (kv KV) UintOrArrayValueAsArray(key string, defaultValue uint32) []uint32 {
|
||||||
if u32, ok := keyValue(kv, key, uint32(0)); ok {
|
if u32, ok := keyValue(kv, key, uint32(0)); ok {
|
||||||
return u32, u32
|
return []uint32{u32}
|
||||||
} else if u32s, ok := keyValue(kv, key, &array[uint32]{}); ok {
|
} else if u32s, ok := keyValue(kv, key, &array[uint32]{}); ok {
|
||||||
min := slices.Min(u32s.values)
|
return u32s.values
|
||||||
max := slices.Max(u32s.values)
|
|
||||||
return min, max
|
|
||||||
} else if i32s, ok := keyValue(kv, key, &array[int32]{}); ok {
|
} else if i32s, ok := keyValue(kv, key, &array[int32]{}); ok {
|
||||||
min := slices.Min(i32s.values)
|
dst := make([]uint32, len(i32s.values))
|
||||||
max := slices.Max(i32s.values)
|
for i, v := range i32s.values {
|
||||||
if min < 0 || max < 0 {
|
if v < 0 {
|
||||||
slog.Warn("array values are unexpectedly negative", "key", key, "min", min, "max", max)
|
slog.Warn("array values are unexpectedly negative", "key", key, "i", i, "v", v)
|
||||||
|
}
|
||||||
|
dst[i] = uint32(v)
|
||||||
}
|
}
|
||||||
return uint32(min), uint32(max)
|
return dst
|
||||||
}
|
}
|
||||||
|
|
||||||
return defaultValue, defaultValue
|
return []uint32{defaultValue}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (kv KV) Strings(key string, defaultValue ...[]string) []string {
|
func (kv KV) Strings(key string, defaultValue ...[]string) []string {
|
||||||
@@ -275,7 +341,7 @@ type Tensor struct {
|
|||||||
|
|
||||||
func (t Tensor) block() (n int) {
|
func (t Tensor) block() (n int) {
|
||||||
if _, err := fmt.Sscanf(t.Name, "blk.%d.", &n); err != nil {
|
if _, err := fmt.Sscanf(t.Name, "blk.%d.", &n); err != nil {
|
||||||
return -1
|
return math.MaxInt
|
||||||
}
|
}
|
||||||
|
|
||||||
return
|
return
|
||||||
@@ -288,24 +354,24 @@ func (t Tensor) blockSize() uint64 {
|
|||||||
func (t TensorType) BlockSize() uint64 {
|
func (t TensorType) BlockSize() uint64 {
|
||||||
switch t {
|
switch t {
|
||||||
case
|
case
|
||||||
0, // F32
|
TensorTypeF32,
|
||||||
1, // F16
|
TensorTypeF16,
|
||||||
24, // I8
|
TensorTypeI8,
|
||||||
25, // I16
|
TensorTypeI16,
|
||||||
26, // I32
|
TensorTypeI32,
|
||||||
27, // I64
|
TensorTypeI64,
|
||||||
28, // F64
|
TensorTypeF64,
|
||||||
30: // BF16
|
TensorTypeBF16:
|
||||||
return 1
|
return 1
|
||||||
case
|
case
|
||||||
2, // Q4_0
|
TensorTypeQ4_0,
|
||||||
3, // Q4_1
|
TensorTypeQ4_1,
|
||||||
4, // MXFP4
|
TensorTypeQ5_0,
|
||||||
6, // Q5_0
|
TensorTypeQ5_1,
|
||||||
7, // Q5_1
|
TensorTypeQ8_0,
|
||||||
8, // Q8_0
|
TensorTypeQ8_1,
|
||||||
9, // Q8_1
|
tensorTypeIQ4_NL,
|
||||||
20: // IQ4_NL
|
4, TensorTypeMXFP4:
|
||||||
return 32
|
return 32
|
||||||
default:
|
default:
|
||||||
return 256
|
return 256
|
||||||
@@ -328,8 +394,6 @@ func (t TensorType) TypeSize() uint64 {
|
|||||||
return 2 + blockSize/2
|
return 2 + blockSize/2
|
||||||
case TensorTypeQ4_1:
|
case TensorTypeQ4_1:
|
||||||
return 2 + 2 + blockSize/2
|
return 2 + 2 + blockSize/2
|
||||||
case TensorTypeMXFP4, 39:
|
|
||||||
return 1 + blockSize/2
|
|
||||||
case TensorTypeQ5_0:
|
case TensorTypeQ5_0:
|
||||||
return 2 + 4 + blockSize/2
|
return 2 + 4 + blockSize/2
|
||||||
case TensorTypeQ5_1:
|
case TensorTypeQ5_1:
|
||||||
@@ -380,6 +444,8 @@ func (t TensorType) TypeSize() uint64 {
|
|||||||
return blockSize/8 + blockSize/16 + blockSize/32
|
return blockSize/8 + blockSize/16 + blockSize/32
|
||||||
case TensorTypeBF16:
|
case TensorTypeBF16:
|
||||||
return 2
|
return 2
|
||||||
|
case 4, TensorTypeMXFP4:
|
||||||
|
return 1 + blockSize/2
|
||||||
default:
|
default:
|
||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
@@ -479,12 +545,14 @@ func Decode(rs io.ReadSeeker, maxArraySize int) (*GGML, error) {
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType string) (kv []uint64, partialOffload, fullOffload uint64) {
|
func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType string, useFlashAttention bool) (kv []uint64, partialOffload, fullOffload uint64) {
|
||||||
context *= uint64(numParallel)
|
context *= uint64(numParallel)
|
||||||
|
|
||||||
embedding := f.KV().EmbeddingLength()
|
embedding := f.KV().EmbeddingLength()
|
||||||
heads := f.KV().HeadCountMax()
|
heads := f.KV().HeadCountMax()
|
||||||
|
headsArr := f.KV().HeadCount()
|
||||||
headsKV := f.KV().HeadCountKVMax()
|
headsKV := f.KV().HeadCountKVMax()
|
||||||
|
headsKVArr := f.KV().HeadCountKV()
|
||||||
vocab := uint64(f.KV()["tokenizer.ggml.tokens"].(*array[string]).size)
|
vocab := uint64(f.KV()["tokenizer.ggml.tokens"].(*array[string]).size)
|
||||||
|
|
||||||
embeddingHeads := f.KV().EmbeddingHeadCountMax()
|
embeddingHeads := f.KV().EmbeddingHeadCountMax()
|
||||||
@@ -494,12 +562,51 @@ func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType stri
|
|||||||
layers := f.Tensors().GroupLayers()
|
layers := f.Tensors().GroupLayers()
|
||||||
|
|
||||||
bytesPerElement := kvCacheBytesPerElement(kvCacheType)
|
bytesPerElement := kvCacheBytesPerElement(kvCacheType)
|
||||||
|
|
||||||
|
// Default for models unless special-cased below. These defaults mirror the
|
||||||
|
// cache usage in llama.cpp under the assumption that models without special
|
||||||
|
// cases below will use the llamarunner and caching will be handled by the
|
||||||
|
// llama.cpp layer.
|
||||||
|
//
|
||||||
|
// This also assumes that a layer without heads or headsKV set is recurrent
|
||||||
|
// which is usually the case. Some models (eg nemotronh) use "blocks" in
|
||||||
|
// place of layers where some are MLP blocks that don't have any cache.
|
||||||
|
// Models like this will need a special case below to be accurately
|
||||||
|
// estimated.
|
||||||
var kvTotal uint64
|
var kvTotal uint64
|
||||||
kv = make([]uint64, f.KV().BlockCount())
|
kv = make([]uint64, f.KV().BlockCount())
|
||||||
|
kvSizeAttn := uint64(0)
|
||||||
|
kvSizeRecurrent := uint64(0)
|
||||||
for i := range kv {
|
for i := range kv {
|
||||||
kv[i] = uint64(float64(context*(embeddingHeadsK+embeddingHeadsV)*headsKV) * bytesPerElement)
|
headsL := headsArr[i]
|
||||||
|
headsKVL := headsKVArr[i]
|
||||||
|
if headsL > 0 && headsKVL > 0 {
|
||||||
|
// full attention layer
|
||||||
|
// NOTE: Assumes uniform values for all attn layers
|
||||||
|
kv[i] = uint64(float64(context*(embeddingHeadsK+embeddingHeadsV)*headsKVL) * bytesPerElement)
|
||||||
|
kvSizeAttn += kv[i]
|
||||||
|
} else {
|
||||||
|
// recurrent layer
|
||||||
|
ssmDConv := f.KV().SSMConvKernel()
|
||||||
|
ssmDState := f.KV().SSMStateSize()
|
||||||
|
ssmDInner := f.KV().SSMInnerSize()
|
||||||
|
ssmNGroups := f.KV().SSMGroupCount()
|
||||||
|
nEmbdR := uint64(0)
|
||||||
|
if ssmDConv > 0 {
|
||||||
|
nEmbdR = (ssmDConv - 1) * (ssmDInner + 2*ssmNGroups*ssmDState)
|
||||||
|
}
|
||||||
|
nEmbdS := ssmDState * ssmDInner
|
||||||
|
|
||||||
|
// recurrent always uses F32 in llama.cpp backend
|
||||||
|
// https://github.com/ggml-org/llama.cpp/blob/master/src/llama-model.cpp#L18644
|
||||||
|
bytesPerElementRecurrent := kvCacheBytesPerElement("f32")
|
||||||
|
|
||||||
|
kv[i] = (nEmbdR + nEmbdS) * uint64(bytesPerElementRecurrent)
|
||||||
|
kvSizeRecurrent += kv[i]
|
||||||
|
}
|
||||||
kvTotal += kv[i]
|
kvTotal += kv[i]
|
||||||
}
|
}
|
||||||
|
slog.Debug("default cache size estimate", "attention MiB", float32(kvSizeAttn)/(1024.*1024.), "attention bytes", kvSizeAttn, "recurrent MiB", float32(kvSizeRecurrent)/(1024.*1024.), "recurrent bytes", kvSizeRecurrent)
|
||||||
|
|
||||||
switch f.KV().Architecture() {
|
switch f.KV().Architecture() {
|
||||||
case "llama", "llama4":
|
case "llama", "llama4":
|
||||||
@@ -677,7 +784,12 @@ func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType stri
|
|||||||
kv[i] *= context
|
kv[i] *= context
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
partialOffload = 2 * f.KV().HeadCountMax() / cmp.Or(f.KV().HeadCountKVMin(), 1) * kvTotal / 6
|
partialOffload = 2 * f.KV().HeadCountMax() / cmp.Or(f.KV().HeadCountKVMin(), 1) * kvTotal / 6
|
||||||
|
if useFlashAttention {
|
||||||
|
// rough estimate of graph size with flash attention on
|
||||||
|
partialOffload = (4*uint64(numParallel) + context>>10 + 110) * format.MebiByte
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return
|
return
|
||||||
@@ -752,12 +864,16 @@ func (llm GGML) VisionGraphSize() (weights, graphSize uint64) {
|
|||||||
|
|
||||||
// SupportsKVCacheType checks if the requested cache type is supported
|
// SupportsKVCacheType checks if the requested cache type is supported
|
||||||
func (f GGML) SupportsKVCacheType(cacheType string) bool {
|
func (f GGML) SupportsKVCacheType(cacheType string) bool {
|
||||||
|
if cacheType == "" || cacheType == "f16" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
if arch := f.KV().Architecture(); slices.Contains([]string{"gptoss", "gpt-oss"}, arch) {
|
if arch := f.KV().Architecture(); slices.Contains([]string{"gptoss", "gpt-oss"}, arch) {
|
||||||
// gpt-oss uses attention with sinks which does not support quantized cache types
|
// gpt-oss uses attention with sinks which does not support quantized cache types
|
||||||
slog.Warn("model only supports non-quantized cache types ", "mode", arch)
|
slog.Warn("model only supports non-quantized cache types", "model", arch)
|
||||||
return cacheType == "f16"
|
return false
|
||||||
}
|
}
|
||||||
return slices.Contains([]string{"f16", "q8_0", "q4_0"}, cacheType)
|
return slices.Contains([]string{"q8_0", "q4_0"}, cacheType)
|
||||||
}
|
}
|
||||||
|
|
||||||
// SupportsFlashAttention checks if the model supports flash attention
|
// SupportsFlashAttention checks if the model supports flash attention
|
||||||
@@ -767,12 +883,23 @@ func (f GGML) SupportsFlashAttention() bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if arch := f.KV().Architecture(); slices.Contains([]string{"gemma2"}, arch) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
// Check head counts match and are non-zero
|
// Check head counts match and are non-zero
|
||||||
headCountK := f.KV().EmbeddingHeadCountK()
|
headCountK := f.KV().EmbeddingHeadCountK()
|
||||||
headCountV := f.KV().EmbeddingHeadCountV()
|
headCountV := f.KV().EmbeddingHeadCountV()
|
||||||
return headCountK != 0 && headCountV != 0 && headCountK == headCountV
|
return headCountK != 0 && headCountV != 0 && headCountK == headCountV
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// FlashAttention checks if the model should enable flash attention
|
||||||
|
func (f GGML) FlashAttention() bool {
|
||||||
|
return slices.Contains([]string{
|
||||||
|
"gptoss", "gpt-oss",
|
||||||
|
}, f.KV().String("general.architecture"))
|
||||||
|
}
|
||||||
|
|
||||||
// kvCacheBytesPerElement returns the number of bytes per element for a given KV cache type
|
// kvCacheBytesPerElement returns the number of bytes per element for a given KV cache type
|
||||||
func kvCacheBytesPerElement(cacheType string) float64 {
|
func kvCacheBytesPerElement(cacheType string) float64 {
|
||||||
switch cacheType {
|
switch cacheType {
|
||||||
@@ -780,6 +907,8 @@ func kvCacheBytesPerElement(cacheType string) float64 {
|
|||||||
return 1 // 1/2 of fp16
|
return 1 // 1/2 of fp16
|
||||||
case "q4_0":
|
case "q4_0":
|
||||||
return 0.5 // 1/4 of fp16
|
return 0.5 // 1/4 of fp16
|
||||||
|
case "f32":
|
||||||
|
return 4 // f32 (default for recurrent)
|
||||||
default:
|
default:
|
||||||
return 2 // f16 (default)
|
return 2 // f16 (default)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -533,12 +533,15 @@ func WriteGGUF(f *os.File, kv KV, ts []*Tensor) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
slices.SortStableFunc(ts, func(a, b *Tensor) int {
|
slices.SortStableFunc(
|
||||||
if i, j := a.block(), b.block(); i > 0 && j > 0 {
|
ts,
|
||||||
return cmp.Compare(i, j)
|
func(a, b *Tensor) int {
|
||||||
}
|
return cmp.Or(
|
||||||
return cmp.Compare(a.Name, b.Name)
|
cmp.Compare(a.block(), b.block()),
|
||||||
})
|
cmp.Compare(a.Name, b.Name),
|
||||||
|
)
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
var s uint64
|
var s uint64
|
||||||
for i := range ts {
|
for i := range ts {
|
||||||
|
|||||||
@@ -11,24 +11,24 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func TestWriteGGUF(t *testing.T) {
|
func TestWriteGGUF(t *testing.T) {
|
||||||
r := rand.New(rand.NewPCG(0, 0))
|
b := bytes.NewBuffer(make([]byte, 2*3))
|
||||||
for range 8 {
|
for range 8 {
|
||||||
t.Run("shuffle", func(t *testing.T) {
|
t.Run("shuffle", func(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
ts := []*Tensor{
|
ts := []*Tensor{
|
||||||
{Name: "token_embd.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(make([]byte, 2*3))},
|
{Name: "token_embd.weight", Shape: []uint64{2, 3}, WriterTo: b},
|
||||||
{Name: "blk.0.attn_norm.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(make([]byte, 2*3))},
|
{Name: "blk.0.ffn_norm.weight", Shape: []uint64{2, 3}, WriterTo: b},
|
||||||
{Name: "blk.1.attn_norm.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(make([]byte, 2*3))},
|
{Name: "blk.0.attn_norm.weight", Shape: []uint64{2, 3}, WriterTo: b},
|
||||||
{Name: "blk.2.attn_norm.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(make([]byte, 2*3))},
|
{Name: "blk.1.ffn_up.weight", Shape: []uint64{2, 3}, WriterTo: b},
|
||||||
{Name: "blk.3.attn_norm.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(make([]byte, 2*3))},
|
{Name: "blk.2.ffn_norm.weight", Shape: []uint64{2, 3}, WriterTo: b},
|
||||||
{Name: "blk.4.attn_norm.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(make([]byte, 2*3))},
|
{Name: "blk.1.ffn_down.weight", Shape: []uint64{2, 3}, WriterTo: b},
|
||||||
{Name: "blk.5.attn_norm.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(make([]byte, 2*3))},
|
{Name: "blk.0.attn_k.weight", Shape: []uint64{2, 3}, WriterTo: b},
|
||||||
{Name: "output_norm.weight", Shape: []uint64{3, 2}, WriterTo: bytes.NewBuffer(make([]byte, 3*2))},
|
{Name: "output_norm.weight", Shape: []uint64{3, 2}, WriterTo: b},
|
||||||
{Name: "output.weight", Shape: []uint64{3, 2}, WriterTo: bytes.NewBuffer(make([]byte, 3*2))},
|
{Name: "output.weight", Shape: []uint64{3, 2}, WriterTo: b},
|
||||||
}
|
}
|
||||||
|
|
||||||
r.Shuffle(len(ts), func(i, j int) {
|
rand.Shuffle(len(ts), func(i, j int) {
|
||||||
ts[i], ts[j] = ts[j], ts[i]
|
ts[i], ts[j] = ts[j], ts[i]
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -63,14 +63,14 @@ func TestWriteGGUF(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if diff := cmp.Diff(Tensors{
|
if diff := cmp.Diff(Tensors{
|
||||||
Offset: 608,
|
Offset: 592,
|
||||||
items: []*Tensor{
|
items: []*Tensor{
|
||||||
{Name: "blk.0.attn_norm.weight", Offset: 0, Shape: []uint64{2, 3}},
|
{Name: "blk.0.attn_k.weight", Offset: 0, Shape: []uint64{2, 3}},
|
||||||
{Name: "blk.1.attn_norm.weight", Offset: 32, Shape: []uint64{2, 3}},
|
{Name: "blk.0.attn_norm.weight", Offset: 32, Shape: []uint64{2, 3}},
|
||||||
{Name: "blk.2.attn_norm.weight", Offset: 64, Shape: []uint64{2, 3}},
|
{Name: "blk.0.ffn_norm.weight", Offset: 64, Shape: []uint64{2, 3}},
|
||||||
{Name: "blk.3.attn_norm.weight", Offset: 96, Shape: []uint64{2, 3}},
|
{Name: "blk.1.ffn_down.weight", Offset: 96, Shape: []uint64{2, 3}},
|
||||||
{Name: "blk.4.attn_norm.weight", Offset: 128, Shape: []uint64{2, 3}},
|
{Name: "blk.1.ffn_up.weight", Offset: 128, Shape: []uint64{2, 3}},
|
||||||
{Name: "blk.5.attn_norm.weight", Offset: 160, Shape: []uint64{2, 3}},
|
{Name: "blk.2.ffn_norm.weight", Offset: 160, Shape: []uint64{2, 3}},
|
||||||
{Name: "output.weight", Offset: 192, Shape: []uint64{3, 2}},
|
{Name: "output.weight", Offset: 192, Shape: []uint64{3, 2}},
|
||||||
{Name: "output_norm.weight", Offset: 224, Shape: []uint64{3, 2}},
|
{Name: "output_norm.weight", Offset: 224, Shape: []uint64{3, 2}},
|
||||||
{Name: "token_embd.weight", Offset: 256, Shape: []uint64{2, 3}},
|
{Name: "token_embd.weight", Offset: 256, Shape: []uint64{2, 3}},
|
||||||
|
|||||||
@@ -146,8 +146,6 @@ func (ftype FileType) ToTensorType() TensorType {
|
|||||||
return TensorTypeQ4_0
|
return TensorTypeQ4_0
|
||||||
case fileTypeQ4_1:
|
case fileTypeQ4_1:
|
||||||
return TensorTypeQ4_1
|
return TensorTypeQ4_1
|
||||||
case fileTypeMXFP4:
|
|
||||||
return TensorTypeMXFP4 // Formerly unused tensorTypeQ4_2
|
|
||||||
case FileTypeQ8_0:
|
case FileTypeQ8_0:
|
||||||
return TensorTypeQ8_0
|
return TensorTypeQ8_0
|
||||||
case fileTypeQ5_0:
|
case fileTypeQ5_0:
|
||||||
@@ -176,6 +174,8 @@ func (ftype FileType) ToTensorType() TensorType {
|
|||||||
return TensorTypeQ2_K
|
return TensorTypeQ2_K
|
||||||
case FileTypeBF16:
|
case FileTypeBF16:
|
||||||
return TensorTypeBF16
|
return TensorTypeBF16
|
||||||
|
case fileTypeMXFP4:
|
||||||
|
return TensorTypeMXFP4
|
||||||
default:
|
default:
|
||||||
slog.Warn("unsupported file type", "type", ftype)
|
slog.Warn("unsupported file type", "type", ftype)
|
||||||
return 0 // F32
|
return 0 // F32
|
||||||
@@ -191,8 +191,8 @@ const (
|
|||||||
TensorTypeF16
|
TensorTypeF16
|
||||||
TensorTypeQ4_0
|
TensorTypeQ4_0
|
||||||
TensorTypeQ4_1
|
TensorTypeQ4_1
|
||||||
TensorTypeMXFP4 // Formerly unused tensorTypeQ4_2
|
tensorTypeQ4_2
|
||||||
tensorTypeQ4_3 // unused by GGML
|
tensorTypeQ4_3 // unused by GGML
|
||||||
TensorTypeQ5_0
|
TensorTypeQ5_0
|
||||||
TensorTypeQ5_1
|
TensorTypeQ5_1
|
||||||
TensorTypeQ8_0
|
TensorTypeQ8_0
|
||||||
@@ -226,6 +226,7 @@ const (
|
|||||||
tensorTypeIQ4_NL_4_4 // unused by GGML
|
tensorTypeIQ4_NL_4_4 // unused by GGML
|
||||||
tensorTypeIQ4_NL_4_8 // unused by GGML
|
tensorTypeIQ4_NL_4_8 // unused by GGML
|
||||||
tensorTypeIQ4_NL_8_8 // unused by GGML
|
tensorTypeIQ4_NL_8_8 // unused by GGML
|
||||||
|
TensorTypeMXFP4
|
||||||
)
|
)
|
||||||
|
|
||||||
// ParseFileType parses the provided GGUF file type
|
// ParseFileType parses the provided GGUF file type
|
||||||
@@ -318,7 +319,7 @@ func (t TensorType) String() string {
|
|||||||
return "F64"
|
return "F64"
|
||||||
case TensorTypeBF16:
|
case TensorTypeBF16:
|
||||||
return "BF16"
|
return "BF16"
|
||||||
case TensorTypeMXFP4:
|
case 4, TensorTypeMXFP4:
|
||||||
return "MXFP4"
|
return "MXFP4"
|
||||||
default:
|
default:
|
||||||
return "unknown"
|
return "unknown"
|
||||||
|
|||||||
@@ -1,10 +1,8 @@
|
|||||||
package server
|
package harmony
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"slices"
|
|
||||||
"strings"
|
"strings"
|
||||||
"unicode"
|
"unicode"
|
||||||
|
|
||||||
@@ -20,18 +18,6 @@ const (
|
|||||||
harmonyParserState_ParsingContent
|
harmonyParserState_ParsingContent
|
||||||
)
|
)
|
||||||
|
|
||||||
func shouldUseHarmony(model Model) bool {
|
|
||||||
if slices.Contains([]string{"gptoss", "gpt-oss"}, model.Config.ModelFamily) {
|
|
||||||
// heuristic to check whether the template expects to be parsed via harmony:
|
|
||||||
// search for harmony tags that are nearly always used
|
|
||||||
if model.Template.Contains("<|start|>") && model.Template.Contains("<|end|>") {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s harmonyParserState) String() string {
|
func (s harmonyParserState) String() string {
|
||||||
switch s {
|
switch s {
|
||||||
// we're looking for the message start tag
|
// we're looking for the message start tag
|
||||||
@@ -277,20 +263,20 @@ const (
|
|||||||
// This is a higher level interface that maps harmony concepts into ollama concepts
|
// This is a higher level interface that maps harmony concepts into ollama concepts
|
||||||
type HarmonyMessageHandler struct {
|
type HarmonyMessageHandler struct {
|
||||||
state harmonyMessageState
|
state harmonyMessageState
|
||||||
harmonyParser *HarmonyParser
|
HarmonyParser *HarmonyParser
|
||||||
functionNameMap *FunctionNameMap
|
FunctionNameMap *FunctionNameMap
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewHarmonyMessageHandler creates a new message handler
|
// NewHarmonyMessageHandler creates a new message handler
|
||||||
func NewHarmonyMessageHandler() *HarmonyMessageHandler {
|
func NewHarmonyMessageHandler() *HarmonyMessageHandler {
|
||||||
return &HarmonyMessageHandler{
|
return &HarmonyMessageHandler{
|
||||||
state: harmonyMessageState_Normal,
|
state: harmonyMessageState_Normal,
|
||||||
harmonyParser: &HarmonyParser{
|
HarmonyParser: &HarmonyParser{
|
||||||
MessageStartTag: "<|start|>",
|
MessageStartTag: "<|start|>",
|
||||||
MessageEndTag: "<|end|>",
|
MessageEndTag: "<|end|>",
|
||||||
HeaderEndTag: "<|message|>",
|
HeaderEndTag: "<|message|>",
|
||||||
},
|
},
|
||||||
functionNameMap: NewFunctionNameMap(),
|
FunctionNameMap: NewFunctionNameMap(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -301,11 +287,11 @@ func (h *HarmonyMessageHandler) AddContent(content string, toolParser *HarmonyTo
|
|||||||
thinkingSb := strings.Builder{}
|
thinkingSb := strings.Builder{}
|
||||||
toolContentSb := strings.Builder{}
|
toolContentSb := strings.Builder{}
|
||||||
|
|
||||||
events := h.harmonyParser.AddContent(content)
|
events := h.HarmonyParser.AddContent(content)
|
||||||
for _, event := range events {
|
for _, event := range events {
|
||||||
switch event := event.(type) {
|
switch event := event.(type) {
|
||||||
case HarmonyEventHeaderComplete:
|
case HarmonyEventHeaderComplete:
|
||||||
slog.Log(context.TODO(), logutil.LevelTrace, "harmony event header complete", "header", event.Header)
|
logutil.Trace("harmony event header complete", "header", event.Header)
|
||||||
switch event.Header.Channel {
|
switch event.Header.Channel {
|
||||||
case "analysis":
|
case "analysis":
|
||||||
if event.Header.Recipient != "" {
|
if event.Header.Recipient != "" {
|
||||||
@@ -328,7 +314,7 @@ func (h *HarmonyMessageHandler) AddContent(content string, toolParser *HarmonyTo
|
|||||||
h.state = harmonyMessageState_Normal
|
h.state = harmonyMessageState_Normal
|
||||||
}
|
}
|
||||||
case HarmonyEventContentEmitted:
|
case HarmonyEventContentEmitted:
|
||||||
slog.Log(context.TODO(), logutil.LevelTrace, "harmony event content", "content", event.Content, "state", h.state)
|
logutil.Trace("harmony event content", "content", event.Content, "state", h.state)
|
||||||
if h.state == harmonyMessageState_Normal {
|
if h.state == harmonyMessageState_Normal {
|
||||||
contentSb.WriteString(event.Content)
|
contentSb.WriteString(event.Content)
|
||||||
} else if h.state == harmonyMessageState_Thinking {
|
} else if h.state == harmonyMessageState_Thinking {
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package server
|
package harmony
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
@@ -2,10 +2,13 @@
|
|||||||
|
|
||||||
This directory contains integration tests to exercise Ollama end-to-end to verify behavior
|
This directory contains integration tests to exercise Ollama end-to-end to verify behavior
|
||||||
|
|
||||||
By default, these tests are disabled so `go test ./...` will exercise only unit tests. To run integration tests you must pass the integration tag. `go test -tags=integration ./...`
|
By default, these tests are disabled so `go test ./...` will exercise only unit tests. To run integration tests you must pass the integration tag. `go test -tags=integration ./...` Some tests require additional tags to enable to allow scoped testing to keep the duration reasonable. For example, testing a broad set of models requires `-tags=integration,models` and a longer timeout (~60m or more depending on the speed of your GPU.). To view the current set of tag combinations use `find integration -type f | xargs grep "go:build"`
|
||||||
|
|
||||||
|
|
||||||
The integration tests have 2 modes of operating.
|
The integration tests have 2 modes of operating.
|
||||||
|
|
||||||
1. By default, they will start the server on a random port, run the tests, and then shutdown the server.
|
1. By default, they will start the server on a random port, run the tests, and then shutdown the server.
|
||||||
2. If `OLLAMA_TEST_EXISTING` is set to a non-empty string, the tests will run against an existing running server, which can be remote
|
2. If `OLLAMA_TEST_EXISTING` is set to a non-empty string, the tests will run against an existing running server, which can be remote based on your `OLLAMA_HOST` environment variable
|
||||||
|
|
||||||
|
> [!IMPORTANT]
|
||||||
|
> Before running the tests locally without the "test existing" setting, compile ollama from the top of the source tree `go build .` in addition to GPU support with cmake if applicable on your platform. The integration tests expect to find an ollama binary at the top of the tree.
|
||||||
|
|||||||
@@ -390,7 +390,7 @@ func TestAPIEmbeddings(t *testing.T) {
|
|||||||
client, _, cleanup := InitServerConnection(ctx, t)
|
client, _, cleanup := InitServerConnection(ctx, t)
|
||||||
defer cleanup()
|
defer cleanup()
|
||||||
req := api.EmbeddingRequest{
|
req := api.EmbeddingRequest{
|
||||||
Model: "orca-mini",
|
Model: libraryEmbedModels[0],
|
||||||
Prompt: "why is the sky blue?",
|
Prompt: "why is the sky blue?",
|
||||||
Options: map[string]interface{}{
|
Options: map[string]interface{}{
|
||||||
"temperature": 0,
|
"temperature": 0,
|
||||||
@@ -410,3 +410,99 @@ func TestAPIEmbeddings(t *testing.T) {
|
|||||||
t.Errorf("zero length embedding response")
|
t.Errorf("zero length embedding response")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestAPIToolCalling(t *testing.T) {
|
||||||
|
initialTimeout := 60 * time.Second
|
||||||
|
streamTimeout := 30 * time.Second
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
client, _, cleanup := InitServerConnection(ctx, t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
modelName := "qwen3:0.6b"
|
||||||
|
if err := PullIfMissing(ctx, client, modelName); err != nil {
|
||||||
|
t.Fatalf("pull failed %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
tools := []api.Tool{
|
||||||
|
{
|
||||||
|
Type: "function",
|
||||||
|
Function: api.ToolFunction{
|
||||||
|
Name: "get_weather",
|
||||||
|
Description: "Get the current weather in a given location",
|
||||||
|
Parameters: api.ToolFunctionParameters{
|
||||||
|
Type: "object",
|
||||||
|
Required: []string{"location"},
|
||||||
|
Properties: map[string]api.ToolProperty{
|
||||||
|
"location": {
|
||||||
|
Type: api.PropertyType{"string"},
|
||||||
|
Description: "The city and state, e.g. San Francisco, CA",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
req := api.ChatRequest{
|
||||||
|
Model: modelName,
|
||||||
|
Messages: []api.Message{
|
||||||
|
{
|
||||||
|
Role: "user",
|
||||||
|
Content: "Call get_weather with location set to San Francisco.",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Tools: tools,
|
||||||
|
Options: map[string]any{
|
||||||
|
"temperature": 0,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
stallTimer := time.NewTimer(initialTimeout)
|
||||||
|
var gotToolCall bool
|
||||||
|
var lastToolCall api.ToolCall
|
||||||
|
|
||||||
|
fn := func(response api.ChatResponse) error {
|
||||||
|
if len(response.Message.ToolCalls) > 0 {
|
||||||
|
gotToolCall = true
|
||||||
|
lastToolCall = response.Message.ToolCalls[len(response.Message.ToolCalls)-1]
|
||||||
|
}
|
||||||
|
if !stallTimer.Reset(streamTimeout) {
|
||||||
|
return fmt.Errorf("stall was detected while streaming response, aborting")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
stream := true
|
||||||
|
req.Stream = &stream
|
||||||
|
done := make(chan int)
|
||||||
|
var genErr error
|
||||||
|
go func() {
|
||||||
|
genErr = client.Chat(ctx, &req, fn)
|
||||||
|
done <- 0
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-stallTimer.C:
|
||||||
|
t.Errorf("tool-calling chat never started. Timed out after: %s", initialTimeout.String())
|
||||||
|
case <-done:
|
||||||
|
if genErr != nil {
|
||||||
|
t.Fatalf("chat failed: %v", genErr)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !gotToolCall {
|
||||||
|
t.Fatalf("expected at least one tool call, got none")
|
||||||
|
}
|
||||||
|
|
||||||
|
if lastToolCall.Function.Name != "get_weather" {
|
||||||
|
t.Errorf("unexpected tool called: got %q want %q", lastToolCall.Function.Name, "get_weather")
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, ok := lastToolCall.Function.Arguments["location"]; !ok {
|
||||||
|
t.Errorf("expected tool arguments to include 'location', got: %s", lastToolCall.Function.Arguments.String())
|
||||||
|
}
|
||||||
|
case <-ctx.Done():
|
||||||
|
t.Error("outer test context done while waiting for tool-calling chat")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -11,7 +11,6 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestBlueSky(t *testing.T) {
|
func TestBlueSky(t *testing.T) {
|
||||||
@@ -37,8 +36,8 @@ func TestUnicode(t *testing.T) {
|
|||||||
// Set up the test data
|
// Set up the test data
|
||||||
req := api.GenerateRequest{
|
req := api.GenerateRequest{
|
||||||
// DeepSeek has a Unicode tokenizer regex, making it a unicode torture test
|
// DeepSeek has a Unicode tokenizer regex, making it a unicode torture test
|
||||||
Model: "deepseek-coder-v2:16b-lite-instruct-q2_K",
|
Model: "deepseek-coder-v2:16b-lite-instruct-q2_K", // TODO is there an ollama-engine model we can switch to and keep the coverage?
|
||||||
Prompt: "天空为什么是蓝色的?",
|
Prompt: "天空为什么是蓝色的?", // Why is the sky blue?
|
||||||
Stream: &stream,
|
Stream: &stream,
|
||||||
Options: map[string]any{
|
Options: map[string]any{
|
||||||
"temperature": 0,
|
"temperature": 0,
|
||||||
@@ -50,8 +49,20 @@ func TestUnicode(t *testing.T) {
|
|||||||
}
|
}
|
||||||
client, _, cleanup := InitServerConnection(ctx, t)
|
client, _, cleanup := InitServerConnection(ctx, t)
|
||||||
defer cleanup()
|
defer cleanup()
|
||||||
require.NoError(t, PullIfMissing(ctx, client, req.Model))
|
if err := PullIfMissing(ctx, client, req.Model); err != nil {
|
||||||
DoGenerate(ctx, t, client, req, []string{"散射", "频率"}, 120*time.Second, 120*time.Second)
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
slog.Info("loading", "model", req.Model)
|
||||||
|
err := client.Generate(ctx, &api.GenerateRequest{Model: req.Model}, func(response api.GenerateResponse) error { return nil })
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to load model %s: %s", req.Model, err)
|
||||||
|
}
|
||||||
|
skipIfNotGPULoaded(ctx, t, client, req.Model, 100)
|
||||||
|
|
||||||
|
DoGenerate(ctx, t, client, req, []string{
|
||||||
|
"散射", // scattering
|
||||||
|
"频率", // frequency
|
||||||
|
}, 120*time.Second, 120*time.Second)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestExtendedUnicodeOutput(t *testing.T) {
|
func TestExtendedUnicodeOutput(t *testing.T) {
|
||||||
@@ -69,7 +80,9 @@ func TestExtendedUnicodeOutput(t *testing.T) {
|
|||||||
}
|
}
|
||||||
client, _, cleanup := InitServerConnection(ctx, t)
|
client, _, cleanup := InitServerConnection(ctx, t)
|
||||||
defer cleanup()
|
defer cleanup()
|
||||||
require.NoError(t, PullIfMissing(ctx, client, req.Model))
|
if err := PullIfMissing(ctx, client, req.Model); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
DoGenerate(ctx, t, client, req, []string{"😀", "😊", "😁", "😂", "😄", "😃"}, 120*time.Second, 120*time.Second)
|
DoGenerate(ctx, t, client, req, []string{"😀", "😊", "😁", "😂", "😄", "😃"}, 120*time.Second, 120*time.Second)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -84,7 +97,9 @@ func TestUnicodeModelDir(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
modelDir, err := os.MkdirTemp("", "ollama_埃")
|
modelDir, err := os.MkdirTemp("", "ollama_埃")
|
||||||
require.NoError(t, err)
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
defer os.RemoveAll(modelDir)
|
defer os.RemoveAll(modelDir)
|
||||||
slog.Info("unicode", "OLLAMA_MODELS", modelDir)
|
slog.Info("unicode", "OLLAMA_MODELS", modelDir)
|
||||||
|
|
||||||
|
|||||||
@@ -14,8 +14,6 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
"github.com/ollama/ollama/envconfig"
|
"github.com/ollama/ollama/envconfig"
|
||||||
"github.com/ollama/ollama/format"
|
"github.com/ollama/ollama/format"
|
||||||
@@ -79,21 +77,21 @@ func TestMultiModelStress(t *testing.T) {
|
|||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// All models compatible with ollama-engine
|
||||||
smallModels := []string{
|
smallModels := []string{
|
||||||
"llama3.2:1b",
|
"llama3.2:1b",
|
||||||
"qwen3:0.6b",
|
"qwen3:0.6b",
|
||||||
"gemma:2b",
|
"gemma2:2b",
|
||||||
"deepseek-r1:1.5b",
|
"deepseek-r1:1.5b", // qwen2 arch
|
||||||
"starcoder2:3b",
|
"gemma3:270m",
|
||||||
}
|
}
|
||||||
mediumModels := []string{
|
mediumModels := []string{
|
||||||
"qwen3:8b",
|
"llama3.2:3b", // ~3.4G
|
||||||
"llama2",
|
"qwen3:8b", // ~6.6G
|
||||||
"deepseek-r1:7b",
|
"gpt-oss:20b", // ~15G
|
||||||
"mistral",
|
"deepseek-r1:7b", // ~5.6G
|
||||||
"dolphin-mistral",
|
"gemma3:4b", // ~5.8G
|
||||||
"gemma:7b",
|
"gemma2:9b", // ~8.1G
|
||||||
"codellama:7b",
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var chosenModels []string
|
var chosenModels []string
|
||||||
@@ -114,13 +112,16 @@ func TestMultiModelStress(t *testing.T) {
|
|||||||
|
|
||||||
// Make sure all the models are pulled before we get started
|
// Make sure all the models are pulled before we get started
|
||||||
for _, model := range chosenModels {
|
for _, model := range chosenModels {
|
||||||
require.NoError(t, PullIfMissing(ctx, client, model))
|
if err := PullIfMissing(ctx, client, model); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Determine how many models we can load in parallel before we exceed VRAM
|
// Determine how many models we can load in parallel before we exceed VRAM
|
||||||
// The intent is to go 1 over what can fit so we force the scheduler to thrash
|
// The intent is to go 1 over what can fit so we force the scheduler to thrash
|
||||||
targetLoadCount := 0
|
targetLoadCount := 0
|
||||||
slog.Info("Loading models to find how many can fit in VRAM before overflowing")
|
slog.Info("Loading models to find how many can fit in VRAM before overflowing")
|
||||||
|
chooseModels:
|
||||||
for i, model := range chosenModels {
|
for i, model := range chosenModels {
|
||||||
req := &api.GenerateRequest{Model: model}
|
req := &api.GenerateRequest{Model: model}
|
||||||
slog.Info("loading", "model", model)
|
slog.Info("loading", "model", model)
|
||||||
@@ -142,6 +143,13 @@ func TestMultiModelStress(t *testing.T) {
|
|||||||
slog.Info("found model load capacity", "target", targetLoadCount, "current", loaded, "chosen", chosenModels[:targetLoadCount])
|
slog.Info("found model load capacity", "target", targetLoadCount, "current", loaded, "chosen", chosenModels[:targetLoadCount])
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
// Effectively limit model count to 2 on CPU only systems to avoid thrashing and timeouts
|
||||||
|
for _, m := range models.Models {
|
||||||
|
if m.SizeVRAM == 0 {
|
||||||
|
slog.Info("model running on CPU", "name", m.Name, "target", targetLoadCount, "chosen", chosenModels[:targetLoadCount])
|
||||||
|
break chooseModels
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if targetLoadCount == len(chosenModels) {
|
if targetLoadCount == len(chosenModels) {
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ func TestLongInputContext(t *testing.T) {
|
|||||||
defer cancel()
|
defer cancel()
|
||||||
// Set up the test data
|
// Set up the test data
|
||||||
req := api.GenerateRequest{
|
req := api.GenerateRequest{
|
||||||
Model: "llama2",
|
Model: smol,
|
||||||
Prompt: "Oh, don’t speak to me of Austria. Perhaps I don’t understand things, but Austria never has wished, and does not wish, for war. She is betraying us! Russia alone must save Europe. Our gracious sovereign recognizes his high vocation and will be true to it. That is the one thing I have faith in! Our good and wonderful sovereign has to perform the noblest role on earth, and he is so virtuous and noble that God will not forsake him. He will fulfill his vocation and crush the hydra of revolution, which has become more terrible than ever in the person of this murderer and villain! We alone must avenge the blood of the just one.... Whom, I ask you, can we rely on?... England with her commercial spirit will not and cannot understand the Emperor Alexander’s loftiness of soul. She has refused to evacuate Malta. She wanted to find, and still seeks, some secret motive in our actions. What answer did Novosíltsev get? None. The English have not understood and cannot understand the self-abnegation of our Emperor who wants nothing for himself, but only desires the good of mankind. And what have they promised? Nothing! And what little they have promised they will not perform! Prussia has always declared that Buonaparte is invincible, and that all Europe is powerless before him.... And I don’t believe a word that Hardenburg says, or Haugwitz either. This famous Prussian neutrality is just a trap. I have faith only in God and the lofty destiny of our adored monarch. He will save Europe! What country is this referring to?",
|
Prompt: "Oh, don’t speak to me of Austria. Perhaps I don’t understand things, but Austria never has wished, and does not wish, for war. She is betraying us! Russia alone must save Europe. Our gracious sovereign recognizes his high vocation and will be true to it. That is the one thing I have faith in! Our good and wonderful sovereign has to perform the noblest role on earth, and he is so virtuous and noble that God will not forsake him. He will fulfill his vocation and crush the hydra of revolution, which has become more terrible than ever in the person of this murderer and villain! We alone must avenge the blood of the just one.... Whom, I ask you, can we rely on?... England with her commercial spirit will not and cannot understand the Emperor Alexander’s loftiness of soul. She has refused to evacuate Malta. She wanted to find, and still seeks, some secret motive in our actions. What answer did Novosíltsev get? None. The English have not understood and cannot understand the self-abnegation of our Emperor who wants nothing for himself, but only desires the good of mankind. And what have they promised? Nothing! And what little they have promised they will not perform! Prussia has always declared that Buonaparte is invincible, and that all Europe is powerless before him.... And I don’t believe a word that Hardenburg says, or Haugwitz either. This famous Prussian neutrality is just a trap. I have faith only in God and the lofty destiny of our adored monarch. He will save Europe! What country is this referring to?",
|
||||||
Stream: &stream,
|
Stream: &stream,
|
||||||
Options: map[string]any{
|
Options: map[string]any{
|
||||||
@@ -36,7 +36,7 @@ func TestLongInputContext(t *testing.T) {
|
|||||||
if err := PullIfMissing(ctx, client, req.Model); err != nil {
|
if err := PullIfMissing(ctx, client, req.Model); err != nil {
|
||||||
t.Fatalf("PullIfMissing failed: %v", err)
|
t.Fatalf("PullIfMissing failed: %v", err)
|
||||||
}
|
}
|
||||||
DoGenerate(ctx, t, client, req, []string{"russia", "germany", "france", "england", "austria", "prussia"}, 120*time.Second, 10*time.Second)
|
DoGenerate(ctx, t, client, req, []string{"russia", "germany", "france", "england", "austria", "prussia", "europe", "individuals", "coalition", "conflict"}, 120*time.Second, 10*time.Second)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestContextExhaustion(t *testing.T) {
|
func TestContextExhaustion(t *testing.T) {
|
||||||
@@ -49,8 +49,8 @@ func TestContextExhaustion(t *testing.T) {
|
|||||||
defer cancel()
|
defer cancel()
|
||||||
// Set up the test data
|
// Set up the test data
|
||||||
req := api.GenerateRequest{
|
req := api.GenerateRequest{
|
||||||
Model: "llama2",
|
Model: smol,
|
||||||
Prompt: "Write me a story with a ton of emojis?",
|
Prompt: "Write me a story in english with a lot of emojis",
|
||||||
Stream: &stream,
|
Stream: &stream,
|
||||||
Options: map[string]any{
|
Options: map[string]any{
|
||||||
"temperature": 0,
|
"temperature": 0,
|
||||||
@@ -63,10 +63,10 @@ func TestContextExhaustion(t *testing.T) {
|
|||||||
if err := PullIfMissing(ctx, client, req.Model); err != nil {
|
if err := PullIfMissing(ctx, client, req.Model); err != nil {
|
||||||
t.Fatalf("PullIfMissing failed: %v", err)
|
t.Fatalf("PullIfMissing failed: %v", err)
|
||||||
}
|
}
|
||||||
DoGenerate(ctx, t, client, req, []string{"once", "upon", "lived"}, 120*time.Second, 10*time.Second)
|
DoGenerate(ctx, t, client, req, []string{"once", "upon", "lived", "sunny", "cloudy", "clear", "water"}, 120*time.Second, 10*time.Second)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Send multiple requests with prior context and ensure the response is coherant and expected
|
// Send multiple generate requests with prior context and ensure the response is coherant and expected
|
||||||
func TestGenerateWithHistory(t *testing.T) {
|
func TestGenerateWithHistory(t *testing.T) {
|
||||||
modelOverride := ollamaEngineChatModels[0] // Most recent ollama engine model
|
modelOverride := ollamaEngineChatModels[0] // Most recent ollama engine model
|
||||||
req, resp := GenerateRequests()
|
req, resp := GenerateRequests()
|
||||||
@@ -111,5 +111,56 @@ func TestGenerateWithHistory(t *testing.T) {
|
|||||||
}(i)
|
}(i)
|
||||||
}
|
}
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send multiple chat requests with prior context and ensure the response is coherant and expected
|
||||||
|
func TestChatWithHistory(t *testing.T) {
|
||||||
|
modelOverride := ollamaEngineChatModels[0] // Most recent ollama engine model
|
||||||
|
req, resp := ChatRequests()
|
||||||
|
numParallel := 2
|
||||||
|
iterLimit := 2
|
||||||
|
|
||||||
|
softTimeout, hardTimeout := getTimeouts(t)
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), hardTimeout)
|
||||||
|
defer cancel()
|
||||||
|
client, _, cleanup := InitServerConnection(ctx, t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
// Get the server running (if applicable) warm the model up with a single initial empty request
|
||||||
|
slog.Info("loading", "model", modelOverride)
|
||||||
|
err := client.Generate(ctx,
|
||||||
|
&api.GenerateRequest{Model: modelOverride, KeepAlive: &api.Duration{Duration: 10 * time.Second}},
|
||||||
|
func(response api.GenerateResponse) error { return nil },
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to load model %s: %s", modelOverride, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(numParallel)
|
||||||
|
for i := range numParallel {
|
||||||
|
go func(i int) {
|
||||||
|
defer wg.Done()
|
||||||
|
k := i % len(req)
|
||||||
|
req[k].Model = modelOverride
|
||||||
|
for j := 0; j < iterLimit; j++ {
|
||||||
|
if time.Now().Sub(started) > softTimeout {
|
||||||
|
slog.Info("exceeded soft timeout, winding down test")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
slog.Info("Starting", "thread", i, "iter", j)
|
||||||
|
// On slower GPUs it can take a while to process the concurrent requests
|
||||||
|
// so we allow a much longer initial timeout
|
||||||
|
assistant := DoChat(ctx, t, client, req[k], resp[k], 120*time.Second, 20*time.Second)
|
||||||
|
if assistant == nil {
|
||||||
|
t.Fatalf("didn't get an assistant response for context")
|
||||||
|
}
|
||||||
|
req[k].Messages = append(req[k].Messages,
|
||||||
|
*assistant,
|
||||||
|
api.Message{Role: "user", Content: "tell me more!"},
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}(i)
|
||||||
|
}
|
||||||
|
wg.Wait()
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -38,8 +38,9 @@ func TestAllMiniLMEmbeddings(t *testing.T) {
|
|||||||
defer cleanup()
|
defer cleanup()
|
||||||
|
|
||||||
req := api.EmbeddingRequest{
|
req := api.EmbeddingRequest{
|
||||||
Model: "all-minilm",
|
Model: "all-minilm",
|
||||||
Prompt: "why is the sky blue?",
|
Prompt: "why is the sky blue?",
|
||||||
|
KeepAlive: &api.Duration{Duration: 10 * time.Second},
|
||||||
}
|
}
|
||||||
|
|
||||||
res, err := embeddingTestHelper(ctx, client, t, req)
|
res, err := embeddingTestHelper(ctx, client, t, req)
|
||||||
|
|||||||
@@ -9,7 +9,6 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestVisionModels(t *testing.T) {
|
func TestVisionModels(t *testing.T) {
|
||||||
@@ -32,7 +31,9 @@ func TestVisionModels(t *testing.T) {
|
|||||||
for _, v := range testCases {
|
for _, v := range testCases {
|
||||||
t.Run(v.model, func(t *testing.T) {
|
t.Run(v.model, func(t *testing.T) {
|
||||||
image, err := base64.StdEncoding.DecodeString(imageEncoding)
|
image, err := base64.StdEncoding.DecodeString(imageEncoding)
|
||||||
require.NoError(t, err)
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
req := api.GenerateRequest{
|
req := api.GenerateRequest{
|
||||||
Model: v.model,
|
Model: v.model,
|
||||||
Prompt: "what does the text in this image say?",
|
Prompt: "what does the text in this image say?",
|
||||||
@@ -52,7 +53,9 @@ func TestVisionModels(t *testing.T) {
|
|||||||
// Note: sometimes it returns "the ollamas" sometimes "the ollams"
|
// Note: sometimes it returns "the ollamas" sometimes "the ollams"
|
||||||
resp := "the ollam"
|
resp := "the ollam"
|
||||||
defer cleanup()
|
defer cleanup()
|
||||||
require.NoError(t, PullIfMissing(ctx, client, req.Model))
|
if err := PullIfMissing(ctx, client, req.Model); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
// llava models on CPU can be quite slow to start
|
// llava models on CPU can be quite slow to start
|
||||||
DoGenerate(ctx, t, client, req, []string{resp}, 240*time.Second, 30*time.Second)
|
DoGenerate(ctx, t, client, req, []string{resp}, 240*time.Second, 30*time.Second)
|
||||||
})
|
})
|
||||||
@@ -62,7 +65,9 @@ func TestVisionModels(t *testing.T) {
|
|||||||
func TestIntegrationSplitBatch(t *testing.T) {
|
func TestIntegrationSplitBatch(t *testing.T) {
|
||||||
skipUnderMinVRAM(t, 6)
|
skipUnderMinVRAM(t, 6)
|
||||||
image, err := base64.StdEncoding.DecodeString(imageEncoding)
|
image, err := base64.StdEncoding.DecodeString(imageEncoding)
|
||||||
require.NoError(t, err)
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
req := api.GenerateRequest{
|
req := api.GenerateRequest{
|
||||||
Model: "gemma3:4b",
|
Model: "gemma3:4b",
|
||||||
// Fill up a chunk of the batch so the image will partially spill over into the next one
|
// Fill up a chunk of the batch so the image will partially spill over into the next one
|
||||||
@@ -84,7 +89,9 @@ func TestIntegrationSplitBatch(t *testing.T) {
|
|||||||
defer cancel()
|
defer cancel()
|
||||||
client, _, cleanup := InitServerConnection(ctx, t)
|
client, _, cleanup := InitServerConnection(ctx, t)
|
||||||
defer cleanup()
|
defer cleanup()
|
||||||
require.NoError(t, PullIfMissing(ctx, client, req.Model))
|
if err := PullIfMissing(ctx, client, req.Model); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
// llava models on CPU can be quite slow to start,
|
// llava models on CPU can be quite slow to start,
|
||||||
DoGenerate(ctx, t, client, req, []string{resp}, 120*time.Second, 30*time.Second)
|
DoGenerate(ctx, t, client, req, []string{resp}, 120*time.Second, 30*time.Second)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,47 +0,0 @@
|
|||||||
//go:build integration
|
|
||||||
|
|
||||||
package integration
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
|
||||||
)
|
|
||||||
|
|
||||||
// TODO - this would ideally be in the llm package, but that would require some refactoring of interfaces in the server
|
|
||||||
// package to avoid circular dependencies
|
|
||||||
|
|
||||||
var (
|
|
||||||
stream = false
|
|
||||||
req = [2]api.GenerateRequest{
|
|
||||||
{
|
|
||||||
Model: smol,
|
|
||||||
Prompt: "why is the ocean blue?",
|
|
||||||
Stream: &stream,
|
|
||||||
Options: map[string]any{
|
|
||||||
"seed": 42,
|
|
||||||
"temperature": 0.0,
|
|
||||||
},
|
|
||||||
}, {
|
|
||||||
Model: smol,
|
|
||||||
Prompt: "what is the origin of the us thanksgiving holiday?",
|
|
||||||
Stream: &stream,
|
|
||||||
Options: map[string]any{
|
|
||||||
"seed": 42,
|
|
||||||
"temperature": 0.0,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
resp = [2][]string{
|
|
||||||
{"sunlight", "scattering", "interact"},
|
|
||||||
{"england", "english", "massachusetts", "pilgrims"},
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestIntegrationSimple(t *testing.T) {
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*120)
|
|
||||||
defer cancel()
|
|
||||||
GenerateTestHelper(ctx, t, req[0], resp[0])
|
|
||||||
}
|
|
||||||
@@ -13,12 +13,12 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestMaxQueue(t *testing.T) {
|
func TestMaxQueue(t *testing.T) {
|
||||||
|
t.Skip("this test needs to be re-evaluated to use a proper embedding model")
|
||||||
|
|
||||||
if os.Getenv("OLLAMA_TEST_EXISTING") != "" {
|
if os.Getenv("OLLAMA_TEST_EXISTING") != "" {
|
||||||
t.Skip("Max Queue test requires spawning a local server so we can adjust the queue size")
|
t.Skip("Max Queue test requires spawning a local server so we can adjust the queue size")
|
||||||
return
|
return
|
||||||
@@ -45,7 +45,9 @@ func TestMaxQueue(t *testing.T) {
|
|||||||
client, _, cleanup := InitServerConnection(ctx, t)
|
client, _, cleanup := InitServerConnection(ctx, t)
|
||||||
defer cleanup()
|
defer cleanup()
|
||||||
|
|
||||||
require.NoError(t, PullIfMissing(ctx, client, req.Model))
|
if err := PullIfMissing(ctx, client, req.Model); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
// Context for the worker threads so we can shut them down
|
// Context for the worker threads so we can shut them down
|
||||||
// embedCtx, embedCancel := context.WithCancel(ctx)
|
// embedCtx, embedCancel := context.WithCancel(ctx)
|
||||||
@@ -89,7 +91,9 @@ func TestMaxQueue(t *testing.T) {
|
|||||||
switch {
|
switch {
|
||||||
case genErr == nil:
|
case genErr == nil:
|
||||||
successCount++
|
successCount++
|
||||||
require.Greater(t, len(resp.Embedding), 5) // somewhat arbitrary, but sufficient to be reasonable
|
if len(resp.Embedding) < 5 { // somewhat arbitrary, but sufficient to be reasonable
|
||||||
|
t.Fatalf("embeddings shorter than expected: %d", len(resp.Embedding))
|
||||||
|
}
|
||||||
case errors.Is(genErr, context.Canceled):
|
case errors.Is(genErr, context.Canceled):
|
||||||
canceledCount++
|
canceledCount++
|
||||||
case strings.Contains(genErr.Error(), "busy"):
|
case strings.Contains(genErr.Error(), "busy"):
|
||||||
@@ -97,7 +101,9 @@ func TestMaxQueue(t *testing.T) {
|
|||||||
case strings.Contains(genErr.Error(), "connection reset by peer"):
|
case strings.Contains(genErr.Error(), "connection reset by peer"):
|
||||||
resetByPeerCount++
|
resetByPeerCount++
|
||||||
default:
|
default:
|
||||||
require.NoError(t, genErr, "%d request failed", i)
|
if genErr != nil {
|
||||||
|
t.Fatalf("%d request failed", i)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
slog.Info("embed finished", "id", i)
|
slog.Info("embed finished", "id", i)
|
||||||
@@ -108,8 +114,13 @@ func TestMaxQueue(t *testing.T) {
|
|||||||
embedwg.Wait()
|
embedwg.Wait()
|
||||||
|
|
||||||
slog.Info("embeds completed", "success", successCount, "busy", busyCount, "reset", resetByPeerCount, "canceled", canceledCount)
|
slog.Info("embeds completed", "success", successCount, "busy", busyCount, "reset", resetByPeerCount, "canceled", canceledCount)
|
||||||
require.Equal(t, resetByPeerCount, 0, "Connections reset by peer, have you updated your fd and socket limits?")
|
if resetByPeerCount != 0 {
|
||||||
require.True(t, busyCount > 0, "no requests hit busy error but some should have")
|
t.Fatalf("Connections reset by peer, have you updated your fd and socket limits? %d", resetByPeerCount)
|
||||||
require.True(t, canceledCount == 0, "no requests should have been canceled due to timeout")
|
}
|
||||||
|
if busyCount == 0 {
|
||||||
|
t.Fatalf("no requests hit busy error but some should have")
|
||||||
|
}
|
||||||
|
if canceledCount > 0 {
|
||||||
|
t.Fatalf("no requests should have been canceled due to timeout %d", canceledCount)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
|
"math"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
@@ -25,11 +26,11 @@ import (
|
|||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
"github.com/ollama/ollama/app/lifecycle"
|
"github.com/ollama/ollama/app/lifecycle"
|
||||||
"github.com/ollama/ollama/format"
|
"github.com/ollama/ollama/format"
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
smol = "llama3.2:1b"
|
smol = "llama3.2:1b"
|
||||||
|
stream = false
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@@ -435,7 +436,9 @@ func InitServerConnection(ctx context.Context, t *testing.T) (*api.Client, strin
|
|||||||
}
|
}
|
||||||
lifecycle.ServerLogFile = fp.Name()
|
lifecycle.ServerLogFile = fp.Name()
|
||||||
fp.Close()
|
fp.Close()
|
||||||
require.NoError(t, startServer(t, ctx, testEndpoint))
|
if err := startServer(t, ctx, testEndpoint); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return client, testEndpoint, func() {
|
return client, testEndpoint, func() {
|
||||||
@@ -468,7 +471,9 @@ func InitServerConnection(ctx context.Context, t *testing.T) (*api.Client, strin
|
|||||||
func GenerateTestHelper(ctx context.Context, t *testing.T, genReq api.GenerateRequest, anyResp []string) {
|
func GenerateTestHelper(ctx context.Context, t *testing.T, genReq api.GenerateRequest, anyResp []string) {
|
||||||
client, _, cleanup := InitServerConnection(ctx, t)
|
client, _, cleanup := InitServerConnection(ctx, t)
|
||||||
defer cleanup()
|
defer cleanup()
|
||||||
require.NoError(t, PullIfMissing(ctx, client, genReq.Model))
|
if err := PullIfMissing(ctx, client, genReq.Model); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
DoGenerate(ctx, t, client, genReq, anyResp, 30*time.Second, 10*time.Second)
|
DoGenerate(ctx, t, client, genReq, anyResp, 30*time.Second, 10*time.Second)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -497,6 +502,22 @@ func DoGenerate(ctx context.Context, t *testing.T, client *api.Client, genReq ap
|
|||||||
done <- 0
|
done <- 0
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
var response string
|
||||||
|
verify := func() {
|
||||||
|
// Verify the response contains the expected data
|
||||||
|
response = buf.String()
|
||||||
|
atLeastOne := false
|
||||||
|
for _, resp := range anyResp {
|
||||||
|
if strings.Contains(strings.ToLower(response), resp) {
|
||||||
|
atLeastOne = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !atLeastOne {
|
||||||
|
t.Fatalf("%s: none of %v found in %s", genReq.Model, anyResp, response)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case <-stallTimer.C:
|
case <-stallTimer.C:
|
||||||
if buf.Len() == 0 {
|
if buf.Len() == 0 {
|
||||||
@@ -509,20 +530,17 @@ func DoGenerate(ctx context.Context, t *testing.T, client *api.Client, genReq ap
|
|||||||
slog.Warn("model is too large for the target test system", "model", genReq.Model, "error", genErr)
|
slog.Warn("model is too large for the target test system", "model", genReq.Model, "error", genErr)
|
||||||
return context
|
return context
|
||||||
}
|
}
|
||||||
require.NoError(t, genErr, "failed with %s request prompt %s ", genReq.Model, genReq.Prompt)
|
if genErr != nil {
|
||||||
// Verify the response contains the expected data
|
t.Fatalf("%s failed with %s request prompt %s", genErr, genReq.Model, genReq.Prompt)
|
||||||
response := buf.String()
|
|
||||||
atLeastOne := false
|
|
||||||
for _, resp := range anyResp {
|
|
||||||
if strings.Contains(strings.ToLower(response), resp) {
|
|
||||||
atLeastOne = true
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
require.True(t, atLeastOne, "%s: none of %v found in %s", genReq.Model, anyResp, response)
|
verify()
|
||||||
slog.Info("test pass", "model", genReq.Model, "prompt", genReq.Prompt, "contains", anyResp, "response", response)
|
slog.Info("test pass", "model", genReq.Model, "prompt", genReq.Prompt, "contains", anyResp, "response", response)
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
t.Error("outer test context done while waiting for generate")
|
// On slow systems, we might timeout before some models finish rambling, so check what we have so far to see
|
||||||
|
// if it's considered a pass - the stallTimer will detect hangs, but we want to consider slow systems a pass
|
||||||
|
// if they are still generating valid responses
|
||||||
|
slog.Warn("outer test context done while waiting for generate")
|
||||||
|
verify()
|
||||||
}
|
}
|
||||||
return context
|
return context
|
||||||
}
|
}
|
||||||
@@ -543,7 +561,7 @@ func GenerateRequests() ([]api.GenerateRequest, [][]string) {
|
|||||||
KeepAlive: &api.Duration{Duration: 10 * time.Second},
|
KeepAlive: &api.Duration{Duration: 10 * time.Second},
|
||||||
}, {
|
}, {
|
||||||
Model: smol,
|
Model: smol,
|
||||||
Prompt: "what is the origin of the US thanksgiving holiday? Be brief but factual in your reply",
|
Prompt: "how do rainbows form? Be brief but factual in your reply",
|
||||||
Stream: &stream,
|
Stream: &stream,
|
||||||
KeepAlive: &api.Duration{Duration: 10 * time.Second},
|
KeepAlive: &api.Duration{Duration: 10 * time.Second},
|
||||||
}, {
|
}, {
|
||||||
@@ -561,17 +579,104 @@ func GenerateRequests() ([]api.GenerateRequest, [][]string) {
|
|||||||
[][]string{
|
[][]string{
|
||||||
{"sunlight", "scattering", "interact", "color", "surface", "depth", "red", "orange", "yellow", "absorbs", "wavelength"},
|
{"sunlight", "scattering", "interact", "color", "surface", "depth", "red", "orange", "yellow", "absorbs", "wavelength"},
|
||||||
{"soil", "organic", "earth", "black", "tan", "chemical", "processes", "pigments", "particles", "iron oxide", "rust", "air", "water", "mixture", "mixing"},
|
{"soil", "organic", "earth", "black", "tan", "chemical", "processes", "pigments", "particles", "iron oxide", "rust", "air", "water", "mixture", "mixing"},
|
||||||
{"england", "english", "massachusetts", "pilgrims", "colonists", "independence", "british", "feast", "family", "gatherings", "traditions", "turkey", "colonial", "period", "harvest", "agricultural", "european settlers", "american revolution", "civil war", "16th century", "17th century", "native american", "united states"},
|
{"water", "droplet", "refracted", "reflect", "color", "spectrum"},
|
||||||
{"fourth", "july", "declaration", "independence"},
|
{"fourth", "july", "declaration", "independence"},
|
||||||
{"nitrogen", "oxygen", "carbon", "dioxide"},
|
{"nitrogen", "oxygen", "carbon", "dioxide", "water", "vapor"},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func DoChat(ctx context.Context, t *testing.T, client *api.Client, req api.ChatRequest, anyResp []string, initialTimeout, streamTimeout time.Duration) *api.Message {
|
||||||
|
stallTimer := time.NewTimer(initialTimeout)
|
||||||
|
var buf bytes.Buffer
|
||||||
|
role := "assistant"
|
||||||
|
fn := func(response api.ChatResponse) error {
|
||||||
|
// fmt.Print(".")
|
||||||
|
role = response.Message.Role
|
||||||
|
buf.Write([]byte(response.Message.Content))
|
||||||
|
if !stallTimer.Reset(streamTimeout) {
|
||||||
|
return errors.New("stall was detected while streaming response, aborting")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
stream := true
|
||||||
|
req.Stream = &stream
|
||||||
|
done := make(chan int)
|
||||||
|
var genErr error
|
||||||
|
go func() {
|
||||||
|
genErr = client.Chat(ctx, &req, fn)
|
||||||
|
done <- 0
|
||||||
|
}()
|
||||||
|
|
||||||
|
var response string
|
||||||
|
verify := func() {
|
||||||
|
// Verify the response contains the expected data
|
||||||
|
response = buf.String()
|
||||||
|
atLeastOne := false
|
||||||
|
for _, resp := range anyResp {
|
||||||
|
if strings.Contains(strings.ToLower(response), resp) {
|
||||||
|
atLeastOne = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !atLeastOne {
|
||||||
|
t.Fatalf("%s: none of %v found in \"%s\" -- request was:%v", req.Model, anyResp, response, req.Messages)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-stallTimer.C:
|
||||||
|
if buf.Len() == 0 {
|
||||||
|
t.Errorf("generate never started. Timed out after :%s", initialTimeout.String())
|
||||||
|
} else {
|
||||||
|
t.Errorf("generate stalled. Response so far:%s", buf.String())
|
||||||
|
}
|
||||||
|
case <-done:
|
||||||
|
if genErr != nil && strings.Contains(genErr.Error(), "model requires more system memory") {
|
||||||
|
slog.Warn("model is too large for the target test system", "model", req.Model, "error", genErr)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if genErr != nil {
|
||||||
|
t.Fatalf("%s failed with %s request prompt %v", genErr, req.Model, req.Messages)
|
||||||
|
}
|
||||||
|
verify()
|
||||||
|
slog.Info("test pass", "model", req.Model, "messages", req.Messages, "contains", anyResp, "response", response)
|
||||||
|
case <-ctx.Done():
|
||||||
|
// On slow systems, we might timeout before some models finish rambling, so check what we have so far to see
|
||||||
|
// if it's considered a pass - the stallTimer will detect hangs, but we want to consider slow systems a pass
|
||||||
|
// if they are still generating valid responses
|
||||||
|
slog.Warn("outer test context done while waiting for chat")
|
||||||
|
verify()
|
||||||
|
}
|
||||||
|
return &api.Message{Role: role, Content: buf.String()}
|
||||||
|
}
|
||||||
|
|
||||||
|
func ChatRequests() ([]api.ChatRequest, [][]string) {
|
||||||
|
genReqs, results := GenerateRequests()
|
||||||
|
reqs := make([]api.ChatRequest, len(genReqs))
|
||||||
|
// think := api.ThinkValue{Value: "low"}
|
||||||
|
for i := range reqs {
|
||||||
|
reqs[i].Model = genReqs[i].Model
|
||||||
|
reqs[i].Stream = genReqs[i].Stream
|
||||||
|
reqs[i].KeepAlive = genReqs[i].KeepAlive
|
||||||
|
// reqs[i].Think = &think
|
||||||
|
reqs[i].Messages = []api.Message{
|
||||||
|
{
|
||||||
|
Role: "user",
|
||||||
|
Content: genReqs[i].Prompt,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return reqs, results
|
||||||
|
}
|
||||||
|
|
||||||
func skipUnderMinVRAM(t *testing.T, gb uint64) {
|
func skipUnderMinVRAM(t *testing.T, gb uint64) {
|
||||||
// TODO use info API in the future
|
// TODO use info API in the future
|
||||||
if s := os.Getenv("OLLAMA_MAX_VRAM"); s != "" {
|
if s := os.Getenv("OLLAMA_MAX_VRAM"); s != "" {
|
||||||
maxVram, err := strconv.ParseUint(s, 10, 64)
|
maxVram, err := strconv.ParseUint(s, 10, 64)
|
||||||
require.NoError(t, err)
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
// Don't hammer on small VRAM cards...
|
// Don't hammer on small VRAM cards...
|
||||||
if maxVram < gb*format.GibiByte {
|
if maxVram < gb*format.GibiByte {
|
||||||
t.Skip("skipping with small VRAM to avoid timeouts")
|
t.Skip("skipping with small VRAM to avoid timeouts")
|
||||||
@@ -579,6 +684,39 @@ func skipUnderMinVRAM(t *testing.T, gb uint64) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Skip if the target model isn't X% GPU loaded to avoid excessive runtime
|
||||||
|
func skipIfNotGPULoaded(ctx context.Context, t *testing.T, client *api.Client, model string, minPercent int) {
|
||||||
|
models, err := client.ListRunning(ctx)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to list running models: %s", err)
|
||||||
|
}
|
||||||
|
loaded := []string{}
|
||||||
|
for _, m := range models.Models {
|
||||||
|
loaded = append(loaded, m.Name)
|
||||||
|
if m.Name != model {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
gpuPercent := 0
|
||||||
|
switch {
|
||||||
|
case m.SizeVRAM == 0:
|
||||||
|
gpuPercent = 0
|
||||||
|
case m.SizeVRAM == m.Size:
|
||||||
|
gpuPercent = 100
|
||||||
|
case m.SizeVRAM > m.Size || m.Size == 0:
|
||||||
|
t.Logf("unexpected size detected: %d", m.SizeVRAM)
|
||||||
|
default:
|
||||||
|
sizeCPU := m.Size - m.SizeVRAM
|
||||||
|
cpuPercent := math.Round(float64(sizeCPU) / float64(m.Size) * 110)
|
||||||
|
gpuPercent = int(100 - cpuPercent)
|
||||||
|
}
|
||||||
|
if gpuPercent < minPercent {
|
||||||
|
t.Skip(fmt.Sprintf("test requires minimum %d%% GPU load, but model %s only has %d%%", minPercent, model, gpuPercent))
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
t.Skip(fmt.Sprintf("model %s not loaded - actually loaded: %v", model, loaded))
|
||||||
|
}
|
||||||
|
|
||||||
func getTimeouts(t *testing.T) (soft time.Duration, hard time.Duration) {
|
func getTimeouts(t *testing.T) (soft time.Duration, hard time.Duration) {
|
||||||
deadline, hasDeadline := t.Deadline()
|
deadline, hasDeadline := t.Deadline()
|
||||||
if !hasDeadline {
|
if !hasDeadline {
|
||||||
|
|||||||
@@ -515,33 +515,34 @@ func (c *MtmdContext) NewEmbed(llamaContext *Context, data []byte) ([][]float32,
|
|||||||
}
|
}
|
||||||
nChunks := C.mtmd_input_chunks_size(ic)
|
nChunks := C.mtmd_input_chunks_size(ic)
|
||||||
numEmbed := llamaContext.Model().NEmbd()
|
numEmbed := llamaContext.Model().NEmbd()
|
||||||
lastChunkSize := 0
|
embed := make([][]float32, 0)
|
||||||
for i := range int(nChunks) {
|
for i := range int(nChunks) {
|
||||||
chunk := C.mtmd_input_chunks_get(ic, C.size_t(i))
|
chunk := C.mtmd_input_chunks_get(ic, C.size_t(i))
|
||||||
numTokens := int(C.mtmd_input_chunk_get_n_tokens(chunk))
|
numTokens := int(C.mtmd_input_chunk_get_n_tokens(chunk))
|
||||||
lastChunkSize = numTokens
|
slog.Debug("chunk tokens", "index", i, "numTokens", numTokens)
|
||||||
|
|
||||||
// Encode the chunk
|
// Encode the chunk
|
||||||
if C.int32_t(0) != C.mtmd_encode_chunk(c.c, chunk) {
|
if C.int32_t(0) != C.mtmd_encode_chunk(c.c, chunk) {
|
||||||
return nil, errors.New("unable to encode mtmd image chunk")
|
return nil, errors.New("unable to encode mtmd image chunk")
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
// Get the embeddings
|
// Get the embeddings for this chunk
|
||||||
embed := make([][]float32, lastChunkSize)
|
chunkEmbed := make([][]float32, numTokens)
|
||||||
embd := C.mtmd_get_output_embd(c.c)
|
chunkEmbd := C.mtmd_get_output_embd(c.c)
|
||||||
if nil == embd {
|
if nil == chunkEmbd {
|
||||||
return nil, errors.New("failed to get image embedding")
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// Extend the embedding array for each token
|
// Extend the embedding array for each token
|
||||||
s := unsafe.Slice((*float32)(embd), numEmbed*lastChunkSize)
|
s := unsafe.Slice((*float32)(chunkEmbd), numTokens*numEmbed)
|
||||||
rows := make([]float32, len(s))
|
rows := make([]float32, len(s))
|
||||||
copy(rows, s)
|
copy(rows, s)
|
||||||
for i := range lastChunkSize {
|
for i := range numTokens {
|
||||||
embed[i] = rows[i*numEmbed : (i+1)*numEmbed]
|
chunkEmbed[i] = rows[i*numEmbed : (i+1)*numEmbed]
|
||||||
|
}
|
||||||
|
embed = append(embed, chunkEmbed...)
|
||||||
}
|
}
|
||||||
|
slog.Debug("image embeddings", "totalEmbeddings", len(embed))
|
||||||
return embed, nil
|
return embed, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
130
llama/patches/0024-ggml-Enable-resetting-backend-devices.patch
Normal file
130
llama/patches/0024-ggml-Enable-resetting-backend-devices.patch
Normal file
@@ -0,0 +1,130 @@
|
|||||||
|
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
|
||||||
|
From: Jesse Gross <jesse@ollama.com>
|
||||||
|
Date: Wed, 27 Aug 2025 14:39:48 -0700
|
||||||
|
Subject: [PATCH] ggml: Enable resetting backend devices
|
||||||
|
|
||||||
|
Touching a CUDA device causes the allocation of a primary context
|
||||||
|
with CUDA data structures (~300 MB of VRAM). If a device is
|
||||||
|
unused then it can be reset to free these data structures.
|
||||||
|
---
|
||||||
|
ggml/include/ggml-backend.h | 1 +
|
||||||
|
ggml/src/ggml-backend-impl.h | 4 ++++
|
||||||
|
ggml/src/ggml-backend.cpp | 8 ++++++++
|
||||||
|
ggml/src/ggml-cuda/ggml-cuda.cu | 17 +++++++++++++++--
|
||||||
|
ggml/src/ggml-cuda/vendors/hip.h | 1 +
|
||||||
|
5 files changed, 29 insertions(+), 2 deletions(-)
|
||||||
|
|
||||||
|
diff --git a/ggml/include/ggml-backend.h b/ggml/include/ggml-backend.h
|
||||||
|
index b602a7c78..fda5ceb24 100644
|
||||||
|
--- a/ggml/include/ggml-backend.h
|
||||||
|
+++ b/ggml/include/ggml-backend.h
|
||||||
|
@@ -167,6 +167,7 @@ extern "C" {
|
||||||
|
GGML_API void ggml_backend_dev_get_props(ggml_backend_dev_t device, struct ggml_backend_dev_props * props);
|
||||||
|
GGML_API ggml_backend_reg_t ggml_backend_dev_backend_reg(ggml_backend_dev_t device);
|
||||||
|
GGML_API ggml_backend_t ggml_backend_dev_init(ggml_backend_dev_t device, const char * params);
|
||||||
|
+ GGML_API void ggml_backend_dev_reset(ggml_backend_dev_t device);
|
||||||
|
GGML_API ggml_backend_buffer_type_t ggml_backend_dev_buffer_type(ggml_backend_dev_t device);
|
||||||
|
GGML_API ggml_backend_buffer_type_t ggml_backend_dev_host_buffer_type(ggml_backend_dev_t device);
|
||||||
|
GGML_API ggml_backend_buffer_t ggml_backend_dev_buffer_from_host_ptr(ggml_backend_dev_t device, void * ptr, size_t size, size_t max_tensor_size);
|
||||||
|
diff --git a/ggml/src/ggml-backend-impl.h b/ggml/src/ggml-backend-impl.h
|
||||||
|
index 81749a5a3..6f10c353b 100644
|
||||||
|
--- a/ggml/src/ggml-backend-impl.h
|
||||||
|
+++ b/ggml/src/ggml-backend-impl.h
|
||||||
|
@@ -178,6 +178,10 @@ extern "C" {
|
||||||
|
ggml_backend_event_t (*event_new) (ggml_backend_dev_t dev);
|
||||||
|
void (*event_free) (ggml_backend_dev_t dev, ggml_backend_event_t event);
|
||||||
|
void (*event_synchronize) (ggml_backend_dev_t dev, ggml_backend_event_t event);
|
||||||
|
+
|
||||||
|
+ // (optional) reset device, clearing existing allocations and context
|
||||||
|
+ // the caller must ensure that there are no outstanding buffers, as these will become invalid
|
||||||
|
+ void (*reset)(ggml_backend_dev_t dev);
|
||||||
|
};
|
||||||
|
|
||||||
|
struct ggml_backend_device {
|
||||||
|
diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp
|
||||||
|
index 05a842ed5..6556943b0 100644
|
||||||
|
--- a/ggml/src/ggml-backend.cpp
|
||||||
|
+++ b/ggml/src/ggml-backend.cpp
|
||||||
|
@@ -477,6 +477,14 @@ ggml_backend_t ggml_backend_dev_init(ggml_backend_dev_t device, const char * par
|
||||||
|
return device->iface.init_backend(device, params);
|
||||||
|
}
|
||||||
|
|
||||||
|
+void ggml_backend_dev_reset(ggml_backend_dev_t device) {
|
||||||
|
+ if (device->iface.reset == NULL) {
|
||||||
|
+ return;
|
||||||
|
+ }
|
||||||
|
+
|
||||||
|
+ device->iface.reset(device);
|
||||||
|
+}
|
||||||
|
+
|
||||||
|
ggml_backend_buffer_type_t ggml_backend_dev_buffer_type(ggml_backend_dev_t device) {
|
||||||
|
return device->iface.get_buffer_type(device);
|
||||||
|
}
|
||||||
|
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
|
||||||
|
index c7f9dc3a5..e43fde523 100644
|
||||||
|
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
|
||||||
|
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
|
||||||
|
@@ -103,6 +103,11 @@ int ggml_cuda_get_device() {
|
||||||
|
return id;
|
||||||
|
}
|
||||||
|
|
||||||
|
+void ggml_cuda_reset_device(int device) {
|
||||||
|
+ ggml_cuda_set_device(device);
|
||||||
|
+ CUDA_CHECK(cudaDeviceReset());
|
||||||
|
+}
|
||||||
|
+
|
||||||
|
static cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device) {
|
||||||
|
ggml_cuda_set_device(device);
|
||||||
|
cudaError_t err;
|
||||||
|
@@ -3243,7 +3248,10 @@ static void ggml_backend_cuda_device_get_props(ggml_backend_dev_t dev, ggml_back
|
||||||
|
props->description = ggml_backend_cuda_device_get_description(dev);
|
||||||
|
props->id = ggml_backend_cuda_device_get_id(dev);
|
||||||
|
props->type = ggml_backend_cuda_device_get_type(dev);
|
||||||
|
- ggml_backend_cuda_device_get_memory(dev, &props->memory_free, &props->memory_total);
|
||||||
|
+
|
||||||
|
+ // Memory reporting is disabled to avoid allocation of a CUDA primary context (~300 MB per device).
|
||||||
|
+ // If you need the memory data, call ggml_backend_dev_memory() explicitly.
|
||||||
|
+ props->memory_total = props->memory_free = 0;
|
||||||
|
|
||||||
|
bool host_buffer = getenv("GGML_CUDA_NO_PINNED") == nullptr;
|
||||||
|
#ifdef GGML_CUDA_NO_PEER_COPY
|
||||||
|
@@ -3700,6 +3708,11 @@ static void ggml_backend_cuda_device_event_synchronize(ggml_backend_dev_t dev, g
|
||||||
|
CUDA_CHECK(cudaEventSynchronize((cudaEvent_t)event->context));
|
||||||
|
}
|
||||||
|
|
||||||
|
+static void ggml_backend_cuda_device_reset(ggml_backend_dev_t dev) {
|
||||||
|
+ ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context;
|
||||||
|
+ ggml_cuda_reset_device(ctx->device);
|
||||||
|
+}
|
||||||
|
+
|
||||||
|
static const ggml_backend_device_i ggml_backend_cuda_device_interface = {
|
||||||
|
/* .get_name = */ ggml_backend_cuda_device_get_name,
|
||||||
|
/* .get_description = */ ggml_backend_cuda_device_get_description,
|
||||||
|
@@ -3716,6 +3729,7 @@ static const ggml_backend_device_i ggml_backend_cuda_device_interface = {
|
||||||
|
/* .event_new = */ ggml_backend_cuda_device_event_new,
|
||||||
|
/* .event_free = */ ggml_backend_cuda_device_event_free,
|
||||||
|
/* .event_synchronize = */ ggml_backend_cuda_device_event_synchronize,
|
||||||
|
+ /* .reset = */ ggml_backend_cuda_device_reset,
|
||||||
|
};
|
||||||
|
|
||||||
|
// backend reg
|
||||||
|
@@ -3835,7 +3849,6 @@ ggml_backend_reg_t ggml_backend_cuda_reg() {
|
||||||
|
dev_ctx->device = i;
|
||||||
|
dev_ctx->name = GGML_CUDA_NAME + std::to_string(i);
|
||||||
|
|
||||||
|
- ggml_cuda_set_device(i);
|
||||||
|
cudaDeviceProp prop;
|
||||||
|
CUDA_CHECK(cudaGetDeviceProperties(&prop, i));
|
||||||
|
dev_ctx->description = prop.name;
|
||||||
|
diff --git a/ggml/src/ggml-cuda/vendors/hip.h b/ggml/src/ggml-cuda/vendors/hip.h
|
||||||
|
index c31f31923..cf22e60d2 100644
|
||||||
|
--- a/ggml/src/ggml-cuda/vendors/hip.h
|
||||||
|
+++ b/ggml/src/ggml-cuda/vendors/hip.h
|
||||||
|
@@ -40,6 +40,7 @@
|
||||||
|
#define cudaDeviceDisablePeerAccess hipDeviceDisablePeerAccess
|
||||||
|
#define cudaDeviceEnablePeerAccess hipDeviceEnablePeerAccess
|
||||||
|
#define cudaDeviceProp hipDeviceProp_t
|
||||||
|
+#define cudaDeviceReset hipDeviceReset
|
||||||
|
#define cudaDeviceSynchronize hipDeviceSynchronize
|
||||||
|
#define cudaError_t hipError_t
|
||||||
|
#define cudaErrorPeerAccessAlreadyEnabled hipErrorPeerAccessAlreadyEnabled
|
||||||
@@ -0,0 +1,28 @@
|
|||||||
|
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
|
||||||
|
From: Daniel Hiltgen <daniel@ollama.com>
|
||||||
|
Date: Fri, 29 Aug 2025 16:53:08 -0700
|
||||||
|
Subject: [PATCH] harden uncaught exception registration
|
||||||
|
|
||||||
|
---
|
||||||
|
ggml/src/ggml.cpp | 8 ++++++--
|
||||||
|
1 file changed, 6 insertions(+), 2 deletions(-)
|
||||||
|
|
||||||
|
diff --git a/ggml/src/ggml.cpp b/ggml/src/ggml.cpp
|
||||||
|
index 0d388d45..f5bcb446 100644
|
||||||
|
--- a/ggml/src/ggml.cpp
|
||||||
|
+++ b/ggml/src/ggml.cpp
|
||||||
|
@@ -19,8 +19,12 @@ static bool ggml_uncaught_exception_init = []{
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
const auto prev{std::get_terminate()};
|
||||||
|
- GGML_ASSERT(prev != ggml_uncaught_exception);
|
||||||
|
- previous_terminate_handler = prev;
|
||||||
|
+ // GGML_ASSERT(prev != ggml_uncaught_exception);
|
||||||
|
+ if (prev != ggml_uncaught_exception) {
|
||||||
|
+ previous_terminate_handler = prev;
|
||||||
|
+ } else {
|
||||||
|
+ GGML_LOG_WARN("%s double registration of ggml_uncaught_exception\n", __func__);
|
||||||
|
+ }
|
||||||
|
std::set_terminate(ggml_uncaught_exception);
|
||||||
|
return true;
|
||||||
|
}();
|
||||||
@@ -30,7 +30,7 @@ func pickBestFullFitByLibrary(f *ggml.GGML, modelPath string, projectors []strin
|
|||||||
// Try to pack into as few GPUs as possible, starting from 1 GPU
|
// Try to pack into as few GPUs as possible, starting from 1 GPU
|
||||||
for numGPUs := 1; numGPUs <= len(sgl); numGPUs++ {
|
for numGPUs := 1; numGPUs <= len(sgl); numGPUs++ {
|
||||||
gpuSubset := sgl[:numGPUs]
|
gpuSubset := sgl[:numGPUs]
|
||||||
ok, estimatedVRAM := PredictServerFit(gpuSubset, f, adapters, projectors, opts, numParallel)
|
ok, estimatedVRAM := predictServerFit(gpuSubset, f, adapters, projectors, opts, numParallel)
|
||||||
|
|
||||||
if ok {
|
if ok {
|
||||||
slog.Info("new model will fit in available VRAM across minimum required GPUs, loading",
|
slog.Info("new model will fit in available VRAM across minimum required GPUs, loading",
|
||||||
@@ -48,7 +48,7 @@ func pickBestFullFitByLibrary(f *ggml.GGML, modelPath string, projectors []strin
|
|||||||
// - try subsets of GPUs instead of just falling back to 1 or all in a family
|
// - try subsets of GPUs instead of just falling back to 1 or all in a family
|
||||||
|
|
||||||
// Now try all the GPUS (OLLAMA_SCHED_SPREAD is set)
|
// Now try all the GPUS (OLLAMA_SCHED_SPREAD is set)
|
||||||
if ok, estimatedVRAM := PredictServerFit(sgl, f, adapters, projectors, opts, numParallel); ok {
|
if ok, estimatedVRAM := predictServerFit(sgl, f, adapters, projectors, opts, numParallel); ok {
|
||||||
slog.Info("new model will fit in available VRAM, loading",
|
slog.Info("new model will fit in available VRAM, loading",
|
||||||
"model", modelPath,
|
"model", modelPath,
|
||||||
"library", sgl[0].Library,
|
"library", sgl[0].Library,
|
||||||
@@ -71,7 +71,7 @@ func pickBestPartialFitByLibrary(f *ggml.GGML, projectors []string, adapters []s
|
|||||||
var bestEstimate uint64
|
var bestEstimate uint64
|
||||||
var bestFit int
|
var bestFit int
|
||||||
for i, gl := range byLibrary {
|
for i, gl := range byLibrary {
|
||||||
_, estimatedVRAM := PredictServerFit(gl, f, adapters, projectors, opts, numParallel)
|
_, estimatedVRAM := predictServerFit(gl, f, adapters, projectors, opts, numParallel)
|
||||||
if estimatedVRAM > bestEstimate {
|
if estimatedVRAM > bestEstimate {
|
||||||
bestEstimate = estimatedVRAM
|
bestEstimate = estimatedVRAM
|
||||||
bestFit = i
|
bestFit = i
|
||||||
@@ -81,7 +81,7 @@ func pickBestPartialFitByLibrary(f *ggml.GGML, projectors []string, adapters []s
|
|||||||
}
|
}
|
||||||
|
|
||||||
// This algorithm looks for a complete fit to determine if we need to unload other models
|
// This algorithm looks for a complete fit to determine if we need to unload other models
|
||||||
func PredictServerFit(allGpus discover.GpuInfoList, f *ggml.GGML, adapters, projectors []string, opts api.Options, numParallel int) (bool, uint64) {
|
func predictServerFit(allGpus discover.GpuInfoList, f *ggml.GGML, adapters, projectors []string, opts api.Options, numParallel int) (bool, uint64) {
|
||||||
// Split up the GPUs by type and try them
|
// Split up the GPUs by type and try them
|
||||||
var estimatedVRAM uint64
|
var estimatedVRAM uint64
|
||||||
for _, gpus := range allGpus.ByLibrary() {
|
for _, gpus := range allGpus.ByLibrary() {
|
||||||
@@ -97,6 +97,10 @@ func PredictServerFit(allGpus discover.GpuInfoList, f *ggml.GGML, adapters, proj
|
|||||||
return true, estimatedVRAM
|
return true, estimatedVRAM
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if len(gpus) == 1 && gpus[0].Library == "cpu" && estimate.TotalSize <= gpus[0].FreeMemory {
|
||||||
|
return true, estimatedVRAM
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return false, estimatedVRAM
|
return false, estimatedVRAM
|
||||||
}
|
}
|
||||||
@@ -191,17 +195,19 @@ func estimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []strin
|
|||||||
slog.Warn("model missing blk.0 layer size")
|
slog.Warn("model missing blk.0 layer size")
|
||||||
}
|
}
|
||||||
|
|
||||||
var kvct string
|
useFlashAttention := (envconfig.FlashAttention() || f.FlashAttention()) &&
|
||||||
if envconfig.FlashAttention() &&
|
|
||||||
discover.GetGPUInfo().FlashAttentionSupported() &&
|
discover.GetGPUInfo().FlashAttentionSupported() &&
|
||||||
f.SupportsFlashAttention() {
|
f.SupportsFlashAttention()
|
||||||
|
|
||||||
|
var kvct string
|
||||||
|
if useFlashAttention {
|
||||||
requested := strings.ToLower(envconfig.KvCacheType())
|
requested := strings.ToLower(envconfig.KvCacheType())
|
||||||
if requested != "" && f.SupportsKVCacheType(requested) {
|
if f.SupportsKVCacheType(requested) {
|
||||||
kvct = requested
|
kvct = requested
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
kv, graphPartialOffload, graphFullOffload := f.GraphSize(uint64(opts.NumCtx), uint64(min(opts.NumCtx, opts.NumBatch)), numParallel, kvct)
|
kv, graphPartialOffload, graphFullOffload := f.GraphSize(uint64(opts.NumCtx), uint64(min(opts.NumCtx, opts.NumBatch)), numParallel, kvct, useFlashAttention)
|
||||||
|
|
||||||
if len(kv) > 0 {
|
if len(kv) > 0 {
|
||||||
layerSize += kv[0]
|
layerSize += kv[0]
|
||||||
|
|||||||
@@ -148,7 +148,11 @@ func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, a
|
|||||||
var textProcessor model.TextProcessor
|
var textProcessor model.TextProcessor
|
||||||
var err error
|
var err error
|
||||||
if envconfig.NewEngine() || f.KV().OllamaEngineRequired() {
|
if envconfig.NewEngine() || f.KV().OllamaEngineRequired() {
|
||||||
textProcessor, err = model.NewTextProcessor(modelPath)
|
if len(projectors) == 0 {
|
||||||
|
textProcessor, err = model.NewTextProcessor(modelPath)
|
||||||
|
} else {
|
||||||
|
err = errors.New("split vision models aren't supported")
|
||||||
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// To prepare for opt-out mode, instead of treating this as an error, we fallback to the old runner
|
// To prepare for opt-out mode, instead of treating this as an error, we fallback to the old runner
|
||||||
slog.Debug("model not yet supported by Ollama engine, switching to compatibility mode", "model", modelPath, "error", err)
|
slog.Debug("model not yet supported by Ollama engine, switching to compatibility mode", "model", modelPath, "error", err)
|
||||||
@@ -161,11 +165,6 @@ func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, a
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
newEstimates := textProcessor != nil && envconfig.NewMemoryEstimates()
|
|
||||||
if newEstimates {
|
|
||||||
slog.Info("enabling new memory estimates")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Verify the requested context size is <= the model training size
|
// Verify the requested context size is <= the model training size
|
||||||
trainCtx := f.KV().ContextLength()
|
trainCtx := f.KV().ContextLength()
|
||||||
if opts.NumCtx > int(trainCtx) && trainCtx > 0 {
|
if opts.NumCtx > int(trainCtx) && trainCtx > 0 {
|
||||||
@@ -173,6 +172,8 @@ func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, a
|
|||||||
opts.NumCtx = int(trainCtx)
|
opts.NumCtx = int(trainCtx)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
opts.NumBatch = min(opts.NumBatch, opts.NumCtx)
|
||||||
|
|
||||||
loadRequest := LoadRequest{LoraPath: adapters, KvSize: opts.NumCtx * numParallel, BatchSize: opts.NumBatch, Parallel: numParallel, MultiUserCache: envconfig.MultiUserCache()}
|
loadRequest := LoadRequest{LoraPath: adapters, KvSize: opts.NumCtx * numParallel, BatchSize: opts.NumBatch, Parallel: numParallel, MultiUserCache: envconfig.MultiUserCache()}
|
||||||
|
|
||||||
defaultThreads := discover.GetSystemInfo().GetOptimalThreadCount()
|
defaultThreads := discover.GetSystemInfo().GetOptimalThreadCount()
|
||||||
@@ -195,6 +196,11 @@ func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, a
|
|||||||
// This will disable flash attention unless all GPUs on the system support it, even if we end up selecting a subset
|
// This will disable flash attention unless all GPUs on the system support it, even if we end up selecting a subset
|
||||||
// that can handle it.
|
// that can handle it.
|
||||||
fa := envconfig.FlashAttention()
|
fa := envconfig.FlashAttention()
|
||||||
|
if f.FlashAttention() {
|
||||||
|
slog.Info("model wants flash attention")
|
||||||
|
fa = true
|
||||||
|
}
|
||||||
|
|
||||||
if fa && !gpus.FlashAttentionSupported() {
|
if fa && !gpus.FlashAttentionSupported() {
|
||||||
slog.Warn("flash attention enabled but not supported by gpu")
|
slog.Warn("flash attention enabled but not supported by gpu")
|
||||||
fa = false
|
fa = false
|
||||||
@@ -213,7 +219,7 @@ func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, a
|
|||||||
|
|
||||||
// Flash Attention also supports kv cache quantization
|
// Flash Attention also supports kv cache quantization
|
||||||
// Enable if the requested and kv cache type is supported by the model
|
// Enable if the requested and kv cache type is supported by the model
|
||||||
if kvct != "" && f.SupportsKVCacheType(kvct) {
|
if f.SupportsKVCacheType(kvct) {
|
||||||
loadRequest.KvCacheType = kvct
|
loadRequest.KvCacheType = kvct
|
||||||
} else {
|
} else {
|
||||||
slog.Warn("kv cache type not supported by model", "type", kvct)
|
slog.Warn("kv cache type not supported by model", "type", kvct)
|
||||||
@@ -355,23 +361,28 @@ func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, a
|
|||||||
|
|
||||||
s.cmd.Env = append(s.cmd.Env, "OLLAMA_LIBRARY_PATH="+strings.Join(ggmlPaths, string(filepath.ListSeparator)))
|
s.cmd.Env = append(s.cmd.Env, "OLLAMA_LIBRARY_PATH="+strings.Join(ggmlPaths, string(filepath.ListSeparator)))
|
||||||
|
|
||||||
envWorkarounds := [][2]string{}
|
envWorkarounds := []string{}
|
||||||
for _, gpu := range gpus {
|
for _, gpu := range gpus {
|
||||||
envWorkarounds = append(envWorkarounds, gpu.EnvWorkarounds...)
|
envWorkarounds = append(envWorkarounds, gpu.EnvWorkarounds...)
|
||||||
}
|
}
|
||||||
|
// Always filter down the set of GPUs in case there are any unsupported devices that might crash
|
||||||
|
envWorkarounds = append(envWorkarounds, gpus.GetVisibleDevicesEnv()...)
|
||||||
pathEnvVal := strings.Join(libraryPaths, string(filepath.ListSeparator))
|
pathEnvVal := strings.Join(libraryPaths, string(filepath.ListSeparator))
|
||||||
|
|
||||||
// Update or add the path variable with our adjusted version
|
// Update or add the path variable with our adjusted version
|
||||||
pathNeeded := true
|
pathNeeded := true
|
||||||
|
envWorkaroundDone := make([]bool, len(envWorkarounds))
|
||||||
for i := range s.cmd.Env {
|
for i := range s.cmd.Env {
|
||||||
cmp := strings.SplitN(s.cmd.Env[i], "=", 2)
|
cmp := strings.SplitN(s.cmd.Env[i], "=", 2)
|
||||||
if strings.EqualFold(cmp[0], pathEnv) {
|
if strings.EqualFold(cmp[0], pathEnv) {
|
||||||
s.cmd.Env[i] = pathEnv + "=" + pathEnvVal
|
s.cmd.Env[i] = pathEnv + "=" + pathEnvVal
|
||||||
pathNeeded = false
|
pathNeeded = false
|
||||||
} else if len(envWorkarounds) != 0 {
|
} else if len(envWorkarounds) != 0 {
|
||||||
for _, kv := range envWorkarounds {
|
for j, kv := range envWorkarounds {
|
||||||
if strings.EqualFold(cmp[0], kv[0]) {
|
tmp := strings.SplitN(kv, "=", 2)
|
||||||
s.cmd.Env[i] = kv[0] + "=" + kv[1]
|
if strings.EqualFold(cmp[0], tmp[0]) {
|
||||||
|
s.cmd.Env[i] = kv
|
||||||
|
envWorkaroundDone[j] = true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -379,6 +390,11 @@ func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, a
|
|||||||
if pathNeeded {
|
if pathNeeded {
|
||||||
s.cmd.Env = append(s.cmd.Env, pathEnv+"="+pathEnvVal)
|
s.cmd.Env = append(s.cmd.Env, pathEnv+"="+pathEnvVal)
|
||||||
}
|
}
|
||||||
|
for i, done := range envWorkaroundDone {
|
||||||
|
if !done {
|
||||||
|
s.cmd.Env = append(s.cmd.Env, envWorkarounds[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
slog.Info("starting runner", "cmd", s.cmd)
|
slog.Info("starting runner", "cmd", s.cmd)
|
||||||
slog.Debug("subprocess", "", filteredEnv(s.cmd.Env))
|
slog.Debug("subprocess", "", filteredEnv(s.cmd.Env))
|
||||||
@@ -416,7 +432,7 @@ func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, a
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
if newEstimates {
|
if textProcessor != nil {
|
||||||
return &ollamaServer{llmServer: s}, nil
|
return &ollamaServer{llmServer: s}, nil
|
||||||
} else {
|
} else {
|
||||||
return &llamaServer{llmServer: s, ggml: f}, nil
|
return &llamaServer{llmServer: s, ggml: f}, nil
|
||||||
@@ -492,6 +508,7 @@ func (s *llamaServer) Load(ctx context.Context, gpus discover.GpuInfoList, requi
|
|||||||
if !requireFull {
|
if !requireFull {
|
||||||
g = pickBestPartialFitByLibrary(s.ggml, []string{s.loadRequest.ProjectorPath}, s.loadRequest.LoraPath, s.options, gpus, s.numParallel)
|
g = pickBestPartialFitByLibrary(s.ggml, []string{s.loadRequest.ProjectorPath}, s.loadRequest.LoraPath, s.options, gpus, s.numParallel)
|
||||||
} else {
|
} else {
|
||||||
|
slog.Info("model requires more memory than is currently available, evicting a model to make space", "estimate", s.estimate)
|
||||||
return ErrLoadRequiredFull
|
return ErrLoadRequiredFull
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -524,10 +541,6 @@ func (s *llamaServer) Load(ctx context.Context, gpus discover.GpuInfoList, requi
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if requireFull && len(gpus) == 1 && gpus[0].Library == "cpu" && s.estimate.TotalSize > gpus[0].FreeMemory {
|
|
||||||
return ErrLoadRequiredFull
|
|
||||||
}
|
|
||||||
|
|
||||||
slog.Info("offload", "", s.estimate)
|
slog.Info("offload", "", s.estimate)
|
||||||
|
|
||||||
s.gpus = gpus
|
s.gpus = gpus
|
||||||
@@ -666,8 +679,12 @@ func (s *ollamaServer) Load(ctx context.Context, gpus discover.GpuInfoList, requ
|
|||||||
|
|
||||||
if !(len(gpus) == 1 && gpus[0].Library == "cpu") {
|
if !(len(gpus) == 1 && gpus[0].Library == "cpu") {
|
||||||
for _, gpu := range gpus {
|
for _, gpu := range gpus {
|
||||||
|
available := gpu.FreeMemory - envconfig.GpuOverhead() - gpu.MinimumMemory
|
||||||
|
if gpu.FreeMemory < envconfig.GpuOverhead()+gpu.MinimumMemory {
|
||||||
|
available = 0
|
||||||
|
}
|
||||||
slog.Info("gpu memory", "id", gpu.ID,
|
slog.Info("gpu memory", "id", gpu.ID,
|
||||||
"available", format.HumanBytes2(gpu.FreeMemory-envconfig.GpuOverhead()-gpu.MinimumMemory),
|
"available", format.HumanBytes2(available),
|
||||||
"free", format.HumanBytes2(gpu.FreeMemory),
|
"free", format.HumanBytes2(gpu.FreeMemory),
|
||||||
"minimum", format.HumanBytes2(gpu.MinimumMemory),
|
"minimum", format.HumanBytes2(gpu.MinimumMemory),
|
||||||
"overhead", format.HumanBytes2(envconfig.GpuOverhead()))
|
"overhead", format.HumanBytes2(envconfig.GpuOverhead()))
|
||||||
@@ -849,7 +866,7 @@ func (s *ollamaServer) createLayout(systemInfo discover.SystemInfo, systemGPUs d
|
|||||||
}
|
}
|
||||||
layers[i] += memory.CPU.Weights[i].Size
|
layers[i] += memory.CPU.Weights[i].Size
|
||||||
layers[i] += memory.CPU.Cache[i].Size
|
layers[i] += memory.CPU.Cache[i].Size
|
||||||
slog.Log(context.TODO(), logutil.LevelTrace, "layer to assign", "layer", i, "size", format.HumanBytes2(layers[i]))
|
logutil.Trace("layer to assign", "layer", i, "size", format.HumanBytes2(layers[i]))
|
||||||
}
|
}
|
||||||
|
|
||||||
gpuLayers := ml.GPULayersList{}
|
gpuLayers := ml.GPULayersList{}
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package logutil
|
package logutil
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"io"
|
"io"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
@@ -27,3 +28,11 @@ func NewLogger(w io.Writer, level slog.Level) *slog.Logger {
|
|||||||
},
|
},
|
||||||
}))
|
}))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func Trace(msg string, args ...any) {
|
||||||
|
slog.Log(context.TODO(), LevelTrace, msg, args...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TraceContext(ctx context.Context, msg string, args ...any) {
|
||||||
|
slog.Log(ctx, LevelTrace, msg, args...)
|
||||||
|
}
|
||||||
|
|||||||
@@ -266,7 +266,7 @@ func (m DeviceMemory) LogValue() slog.Value {
|
|||||||
// allocation is guaranteed to be provided so that if it failed, the caller can
|
// allocation is guaranteed to be provided so that if it failed, the caller can
|
||||||
// accommodate that to make forward progress.
|
// accommodate that to make forward progress.
|
||||||
type BackendMemory struct {
|
type BackendMemory struct {
|
||||||
// InputsWeights are always located on the CPU and cannot be moved
|
// InputWeights are always located on the CPU and cannot be moved
|
||||||
InputWeights Memory
|
InputWeights Memory
|
||||||
|
|
||||||
// CPU model components are located in system memory. This does not
|
// CPU model components are located in system memory. This does not
|
||||||
@@ -372,6 +372,7 @@ type Context interface {
|
|||||||
|
|
||||||
Forward(...Tensor) Context
|
Forward(...Tensor) Context
|
||||||
Compute(...Tensor)
|
Compute(...Tensor)
|
||||||
|
ComputeWithNotify(func(), ...Tensor) // notify callback once compute has begun
|
||||||
|
|
||||||
// Reserve is analogous to Compute but rather than executing a
|
// Reserve is analogous to Compute but rather than executing a
|
||||||
// graph, simply preallocates memory. Typically called with a
|
// graph, simply preallocates memory. Typically called with a
|
||||||
@@ -401,6 +402,8 @@ type Tensor interface {
|
|||||||
Bytes() []byte
|
Bytes() []byte
|
||||||
Floats() []float32
|
Floats() []float32
|
||||||
|
|
||||||
|
SetValueFromIntSlice(s []int32)
|
||||||
|
|
||||||
Neg(ctx Context) Tensor
|
Neg(ctx Context) Tensor
|
||||||
Add(ctx Context, t2 Tensor) Tensor
|
Add(ctx Context, t2 Tensor) Tensor
|
||||||
Sub(ctx Context, t2 Tensor) Tensor
|
Sub(ctx Context, t2 Tensor) Tensor
|
||||||
@@ -413,6 +416,7 @@ type Tensor interface {
|
|||||||
AddID(ctx Context, t2, ids Tensor) Tensor
|
AddID(ctx Context, t2, ids Tensor) Tensor
|
||||||
|
|
||||||
Softmax(ctx Context) Tensor
|
Softmax(ctx Context) Tensor
|
||||||
|
L2Norm(ctx Context, eps float32) Tensor
|
||||||
LayerNorm(ctx Context, weight, bias Tensor, eps float32) Tensor
|
LayerNorm(ctx Context, weight, bias Tensor, eps float32) Tensor
|
||||||
RMSNorm(ctx Context, weight Tensor, eps float32) Tensor
|
RMSNorm(ctx Context, weight Tensor, eps float32) Tensor
|
||||||
Scale(ctx Context, s float64) Tensor
|
Scale(ctx Context, s float64) Tensor
|
||||||
|
|||||||
@@ -82,6 +82,7 @@ type Backend struct {
|
|||||||
// to the name that is used by the model definition
|
// to the name that is used by the model definition
|
||||||
tensorLoadTargets map[string][]string
|
tensorLoadTargets map[string][]string
|
||||||
|
|
||||||
|
schedMu sync.Mutex // Only one Compute can run at a time
|
||||||
sched C.ggml_backend_sched_t
|
sched C.ggml_backend_sched_t
|
||||||
schedBackends []C.ggml_backend_t
|
schedBackends []C.ggml_backend_t
|
||||||
schedBufts []C.ggml_backend_buffer_type_t
|
schedBufts []C.ggml_backend_buffer_type_t
|
||||||
@@ -270,7 +271,7 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) {
|
|||||||
tt := C.ggml_new_tensor(ctxs[bt], kind, C.int(len(t.source.Shape)), (*C.int64_t)(unsafe.Pointer(&t.source.Shape[0])))
|
tt := C.ggml_new_tensor(ctxs[bt], kind, C.int(len(t.source.Shape)), (*C.int64_t)(unsafe.Pointer(&t.source.Shape[0])))
|
||||||
C.ggml_set_name(tt, cname)
|
C.ggml_set_name(tt, cname)
|
||||||
|
|
||||||
slog.Log(context.TODO(), logutil.LevelTrace, "created tensor", "name", name, "shape", t.source.Shape, "dtype", t.source.Kind, "buffer_type", C.GoString(C.ggml_backend_buft_name(bt)))
|
logutil.Trace("created tensor", "name", name, "shape", t.source.Shape, "dtype", t.source.Kind, "buffer_type", C.GoString(C.ggml_backend_buft_name(bt)))
|
||||||
|
|
||||||
size := pad(C.ggml_backend_buft_get_alloc_size(bt, tt), C.ggml_backend_buft_get_alignment(bt))
|
size := pad(C.ggml_backend_buft_get_alloc_size(bt, tt), C.ggml_backend_buft_get_alignment(bt))
|
||||||
if layer == -1 {
|
if layer == -1 {
|
||||||
@@ -377,7 +378,7 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for bs := range maps.Values(bbs) {
|
for bs := range maps.Values(bbs) {
|
||||||
slog.Log(context.TODO(), logutil.LevelTrace, "model weights", "buffer", C.GoString(C.ggml_backend_buffer_name(bs)),
|
logutil.Trace("model weights", "buffer", C.GoString(C.ggml_backend_buffer_name(bs)),
|
||||||
"size", format.HumanBytes2(uint64(C.ggml_backend_buffer_get_size(bs))))
|
"size", format.HumanBytes2(uint64(C.ggml_backend_buffer_get_size(bs))))
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -535,6 +536,7 @@ func (b *Backend) Load(ctx context.Context, progress func(float32)) error {
|
|||||||
const BS = 17 // MXFP4 block size
|
const BS = 17 // MXFP4 block size
|
||||||
bts := make([]byte, 8*BS*format.KibiByte) // ~128k block aligned
|
bts := make([]byte, 8*BS*format.KibiByte) // ~128k block aligned
|
||||||
var s uint64
|
var s uint64
|
||||||
|
var tmp [16]byte
|
||||||
for s < t.Size() {
|
for s < t.Size() {
|
||||||
// Stop if either the parent context has been canceled or if any of the other tensors returned an error
|
// Stop if either the parent context has been canceled or if any of the other tensors returned an error
|
||||||
if err := ctx.Err(); err != nil {
|
if err := ctx.Err(); err != nil {
|
||||||
@@ -546,37 +548,13 @@ func (b *Backend) Load(ctx context.Context, progress func(float32)) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
for j := range n / BS {
|
for j := range n / BS {
|
||||||
for i := 1; i < BS; i++ {
|
|
||||||
// swap nibbles
|
|
||||||
t_lo := bts[j*BS+i] & 0x0F
|
|
||||||
t_hi := bts[j*BS+i] & 0xF0
|
|
||||||
bts[j*BS+i] = (t_lo << 4) | (t_hi >> 4)
|
|
||||||
}
|
|
||||||
// transform aaaa...bbbb... to abababab...
|
|
||||||
oi := 0
|
|
||||||
tmp := [16]byte{}
|
|
||||||
for i := 1; i < 9; i++ {
|
for i := 1; i < 9; i++ {
|
||||||
blk_a0 := bts[j*BS+i] & 0xF0
|
// transform a1b2c3 ... x7y8z9 -> 71xa82yb93zc
|
||||||
blk_a1 := bts[j*BS+i] << 4
|
a, b := bts[j*BS+i], bts[j*BS+i+8]
|
||||||
blk_b0 := bts[j*BS+i+8] >> 4
|
tmp[2*(i-1)] = (a & 0x0F) | (b << 4)
|
||||||
blk_b1 := bts[j*BS+i+8] & 0x0F
|
tmp[2*(i-1)+1] = (a >> 4) | (b & 0xF0)
|
||||||
// swap once more
|
|
||||||
out0 := blk_a0 | blk_b0
|
|
||||||
out1 := blk_a1 | blk_b1
|
|
||||||
out_h0 := out0 & 0xF0
|
|
||||||
out_l0 := out0 & 0x0F
|
|
||||||
out_h1 := out1 & 0xF0
|
|
||||||
out_l1 := out1 & 0x0F
|
|
||||||
out0 = (out_h0 >> 4) | (out_l0 << 4)
|
|
||||||
out1 = (out_h1 >> 4) | (out_l1 << 4)
|
|
||||||
tmp[oi] = out0
|
|
||||||
oi++
|
|
||||||
tmp[oi] = out1
|
|
||||||
oi++
|
|
||||||
}
|
|
||||||
for i := range tmp {
|
|
||||||
bts[j*BS+i+1] = tmp[i]
|
|
||||||
}
|
}
|
||||||
|
copy(bts[j*BS+1:j*BS+17], tmp[:])
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tts {
|
for _, tt := range tts {
|
||||||
@@ -652,6 +630,18 @@ func (b *Backend) Load(ctx context.Context, progress func(float32)) error {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Cleanup any backend state from devices that we didn't end up using
|
||||||
|
nextDevice:
|
||||||
|
for _, d := range append(gpus, append(accels, cpus...)...) {
|
||||||
|
for _, backend := range b.schedBackends {
|
||||||
|
if d == C.ggml_backend_get_device(backend) {
|
||||||
|
continue nextDevice
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
C.ggml_backend_dev_reset(d)
|
||||||
|
}
|
||||||
|
|
||||||
if err := g.Wait(); err != nil {
|
if err := g.Wait(); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -769,6 +759,15 @@ func (c *Context) Forward(tensors ...ml.Tensor) ml.Context {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *Context) Compute(tensors ...ml.Tensor) {
|
func (c *Context) Compute(tensors ...ml.Tensor) {
|
||||||
|
c.ComputeWithNotify(nil, tensors...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Context) ComputeWithNotify(cb func(), tensors ...ml.Tensor) {
|
||||||
|
c.b.schedMu.Lock()
|
||||||
|
defer c.b.schedMu.Unlock()
|
||||||
|
if cb != nil {
|
||||||
|
go cb()
|
||||||
|
}
|
||||||
if status := C.ggml_backend_sched_graph_compute_async(c.b.sched, c.graph); status != C.GGML_STATUS_SUCCESS {
|
if status := C.ggml_backend_sched_graph_compute_async(c.b.sched, c.graph); status != C.GGML_STATUS_SUCCESS {
|
||||||
panic(fmt.Errorf("error computing ggml graph: %v", status))
|
panic(fmt.Errorf("error computing ggml graph: %v", status))
|
||||||
}
|
}
|
||||||
@@ -812,7 +811,7 @@ func (c *Context) Reserve() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
slog.Log(context.TODO(), logutil.LevelTrace, "compute graph", "backend", C.GoString(C.ggml_backend_name(c.b.schedBackends[i])),
|
logutil.Trace("compute graph", "backend", C.GoString(C.ggml_backend_name(c.b.schedBackends[i])),
|
||||||
"buffer_type", C.GoString(C.ggml_backend_buft_name(c.b.schedBufts[i])), "size", format.HumanBytes2(uint64(bufferStatus.size)))
|
"buffer_type", C.GoString(C.ggml_backend_buft_name(c.b.schedBufts[i])), "size", format.HumanBytes2(uint64(bufferStatus.size)))
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1021,6 +1020,12 @@ func (t *Tensor) Floats() (data []float32) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t *Tensor) SetValueFromIntSlice(s []int32) {
|
||||||
|
if len(s) > 0 {
|
||||||
|
C.ggml_backend_tensor_set(t.t, unsafe.Pointer(&s[0]), 0, C.ggml_nbytes(t.t))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (t *Tensor) DType() ml.DType {
|
func (t *Tensor) DType() ml.DType {
|
||||||
switch t.t._type {
|
switch t.t._type {
|
||||||
case C.GGML_TYPE_F32:
|
case C.GGML_TYPE_F32:
|
||||||
@@ -1200,6 +1205,13 @@ func (t *Tensor) AddID(ctx ml.Context, t2, ids ml.Tensor) ml.Tensor {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t *Tensor) L2Norm(ctx ml.Context, eps float32) ml.Tensor {
|
||||||
|
return &Tensor{
|
||||||
|
b: t.b,
|
||||||
|
t: C.ggml_l2_norm(ctx.(*Context).ctx, t.t, C.float(eps)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (t *Tensor) LayerNorm(ctx ml.Context, w, b ml.Tensor, eps float32) ml.Tensor {
|
func (t *Tensor) LayerNorm(ctx ml.Context, w, b ml.Tensor, eps float32) ml.Tensor {
|
||||||
tt := C.ggml_norm(ctx.(*Context).ctx, t.t, C.float(eps))
|
tt := C.ggml_norm(ctx.(*Context).ctx, t.t, C.float(eps))
|
||||||
if w != nil {
|
if w != nil {
|
||||||
|
|||||||
1
ml/backend/ggml/ggml/include/ggml-backend.h
vendored
1
ml/backend/ggml/ggml/include/ggml-backend.h
vendored
@@ -167,6 +167,7 @@ extern "C" {
|
|||||||
GGML_API void ggml_backend_dev_get_props(ggml_backend_dev_t device, struct ggml_backend_dev_props * props);
|
GGML_API void ggml_backend_dev_get_props(ggml_backend_dev_t device, struct ggml_backend_dev_props * props);
|
||||||
GGML_API ggml_backend_reg_t ggml_backend_dev_backend_reg(ggml_backend_dev_t device);
|
GGML_API ggml_backend_reg_t ggml_backend_dev_backend_reg(ggml_backend_dev_t device);
|
||||||
GGML_API ggml_backend_t ggml_backend_dev_init(ggml_backend_dev_t device, const char * params);
|
GGML_API ggml_backend_t ggml_backend_dev_init(ggml_backend_dev_t device, const char * params);
|
||||||
|
GGML_API void ggml_backend_dev_reset(ggml_backend_dev_t device);
|
||||||
GGML_API ggml_backend_buffer_type_t ggml_backend_dev_buffer_type(ggml_backend_dev_t device);
|
GGML_API ggml_backend_buffer_type_t ggml_backend_dev_buffer_type(ggml_backend_dev_t device);
|
||||||
GGML_API ggml_backend_buffer_type_t ggml_backend_dev_host_buffer_type(ggml_backend_dev_t device);
|
GGML_API ggml_backend_buffer_type_t ggml_backend_dev_host_buffer_type(ggml_backend_dev_t device);
|
||||||
GGML_API ggml_backend_buffer_t ggml_backend_dev_buffer_from_host_ptr(ggml_backend_dev_t device, void * ptr, size_t size, size_t max_tensor_size);
|
GGML_API ggml_backend_buffer_t ggml_backend_dev_buffer_from_host_ptr(ggml_backend_dev_t device, void * ptr, size_t size, size_t max_tensor_size);
|
||||||
|
|||||||
4
ml/backend/ggml/ggml/src/ggml-backend-impl.h
vendored
4
ml/backend/ggml/ggml/src/ggml-backend-impl.h
vendored
@@ -178,6 +178,10 @@ extern "C" {
|
|||||||
ggml_backend_event_t (*event_new) (ggml_backend_dev_t dev);
|
ggml_backend_event_t (*event_new) (ggml_backend_dev_t dev);
|
||||||
void (*event_free) (ggml_backend_dev_t dev, ggml_backend_event_t event);
|
void (*event_free) (ggml_backend_dev_t dev, ggml_backend_event_t event);
|
||||||
void (*event_synchronize) (ggml_backend_dev_t dev, ggml_backend_event_t event);
|
void (*event_synchronize) (ggml_backend_dev_t dev, ggml_backend_event_t event);
|
||||||
|
|
||||||
|
// (optional) reset device, clearing existing allocations and context
|
||||||
|
// the caller must ensure that there are no outstanding buffers, as these will become invalid
|
||||||
|
void (*reset)(ggml_backend_dev_t dev);
|
||||||
};
|
};
|
||||||
|
|
||||||
struct ggml_backend_device {
|
struct ggml_backend_device {
|
||||||
|
|||||||
8
ml/backend/ggml/ggml/src/ggml-backend.cpp
vendored
8
ml/backend/ggml/ggml/src/ggml-backend.cpp
vendored
@@ -477,6 +477,14 @@ ggml_backend_t ggml_backend_dev_init(ggml_backend_dev_t device, const char * par
|
|||||||
return device->iface.init_backend(device, params);
|
return device->iface.init_backend(device, params);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void ggml_backend_dev_reset(ggml_backend_dev_t device) {
|
||||||
|
if (device->iface.reset == NULL) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
device->iface.reset(device);
|
||||||
|
}
|
||||||
|
|
||||||
ggml_backend_buffer_type_t ggml_backend_dev_buffer_type(ggml_backend_dev_t device) {
|
ggml_backend_buffer_type_t ggml_backend_dev_buffer_type(ggml_backend_dev_t device) {
|
||||||
return device->iface.get_buffer_type(device);
|
return device->iface.get_buffer_type(device);
|
||||||
}
|
}
|
||||||
|
|||||||
17
ml/backend/ggml/ggml/src/ggml-cuda/ggml-cuda.cu
vendored
17
ml/backend/ggml/ggml/src/ggml-cuda/ggml-cuda.cu
vendored
@@ -103,6 +103,11 @@ int ggml_cuda_get_device() {
|
|||||||
return id;
|
return id;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void ggml_cuda_reset_device(int device) {
|
||||||
|
ggml_cuda_set_device(device);
|
||||||
|
CUDA_CHECK(cudaDeviceReset());
|
||||||
|
}
|
||||||
|
|
||||||
static cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device) {
|
static cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device) {
|
||||||
ggml_cuda_set_device(device);
|
ggml_cuda_set_device(device);
|
||||||
cudaError_t err;
|
cudaError_t err;
|
||||||
@@ -3243,7 +3248,10 @@ static void ggml_backend_cuda_device_get_props(ggml_backend_dev_t dev, ggml_back
|
|||||||
props->description = ggml_backend_cuda_device_get_description(dev);
|
props->description = ggml_backend_cuda_device_get_description(dev);
|
||||||
props->id = ggml_backend_cuda_device_get_id(dev);
|
props->id = ggml_backend_cuda_device_get_id(dev);
|
||||||
props->type = ggml_backend_cuda_device_get_type(dev);
|
props->type = ggml_backend_cuda_device_get_type(dev);
|
||||||
ggml_backend_cuda_device_get_memory(dev, &props->memory_free, &props->memory_total);
|
|
||||||
|
// Memory reporting is disabled to avoid allocation of a CUDA primary context (~300 MB per device).
|
||||||
|
// If you need the memory data, call ggml_backend_dev_memory() explicitly.
|
||||||
|
props->memory_total = props->memory_free = 0;
|
||||||
|
|
||||||
bool host_buffer = getenv("GGML_CUDA_NO_PINNED") == nullptr;
|
bool host_buffer = getenv("GGML_CUDA_NO_PINNED") == nullptr;
|
||||||
#ifdef GGML_CUDA_NO_PEER_COPY
|
#ifdef GGML_CUDA_NO_PEER_COPY
|
||||||
@@ -3700,6 +3708,11 @@ static void ggml_backend_cuda_device_event_synchronize(ggml_backend_dev_t dev, g
|
|||||||
CUDA_CHECK(cudaEventSynchronize((cudaEvent_t)event->context));
|
CUDA_CHECK(cudaEventSynchronize((cudaEvent_t)event->context));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static void ggml_backend_cuda_device_reset(ggml_backend_dev_t dev) {
|
||||||
|
ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context;
|
||||||
|
ggml_cuda_reset_device(ctx->device);
|
||||||
|
}
|
||||||
|
|
||||||
static const ggml_backend_device_i ggml_backend_cuda_device_interface = {
|
static const ggml_backend_device_i ggml_backend_cuda_device_interface = {
|
||||||
/* .get_name = */ ggml_backend_cuda_device_get_name,
|
/* .get_name = */ ggml_backend_cuda_device_get_name,
|
||||||
/* .get_description = */ ggml_backend_cuda_device_get_description,
|
/* .get_description = */ ggml_backend_cuda_device_get_description,
|
||||||
@@ -3716,6 +3729,7 @@ static const ggml_backend_device_i ggml_backend_cuda_device_interface = {
|
|||||||
/* .event_new = */ ggml_backend_cuda_device_event_new,
|
/* .event_new = */ ggml_backend_cuda_device_event_new,
|
||||||
/* .event_free = */ ggml_backend_cuda_device_event_free,
|
/* .event_free = */ ggml_backend_cuda_device_event_free,
|
||||||
/* .event_synchronize = */ ggml_backend_cuda_device_event_synchronize,
|
/* .event_synchronize = */ ggml_backend_cuda_device_event_synchronize,
|
||||||
|
/* .reset = */ ggml_backend_cuda_device_reset,
|
||||||
};
|
};
|
||||||
|
|
||||||
// backend reg
|
// backend reg
|
||||||
@@ -3835,7 +3849,6 @@ ggml_backend_reg_t ggml_backend_cuda_reg() {
|
|||||||
dev_ctx->device = i;
|
dev_ctx->device = i;
|
||||||
dev_ctx->name = GGML_CUDA_NAME + std::to_string(i);
|
dev_ctx->name = GGML_CUDA_NAME + std::to_string(i);
|
||||||
|
|
||||||
ggml_cuda_set_device(i);
|
|
||||||
cudaDeviceProp prop;
|
cudaDeviceProp prop;
|
||||||
CUDA_CHECK(cudaGetDeviceProperties(&prop, i));
|
CUDA_CHECK(cudaGetDeviceProperties(&prop, i));
|
||||||
dev_ctx->description = prop.name;
|
dev_ctx->description = prop.name;
|
||||||
|
|||||||
@@ -40,6 +40,7 @@
|
|||||||
#define cudaDeviceDisablePeerAccess hipDeviceDisablePeerAccess
|
#define cudaDeviceDisablePeerAccess hipDeviceDisablePeerAccess
|
||||||
#define cudaDeviceEnablePeerAccess hipDeviceEnablePeerAccess
|
#define cudaDeviceEnablePeerAccess hipDeviceEnablePeerAccess
|
||||||
#define cudaDeviceProp hipDeviceProp_t
|
#define cudaDeviceProp hipDeviceProp_t
|
||||||
|
#define cudaDeviceReset hipDeviceReset
|
||||||
#define cudaDeviceSynchronize hipDeviceSynchronize
|
#define cudaDeviceSynchronize hipDeviceSynchronize
|
||||||
#define cudaError_t hipError_t
|
#define cudaError_t hipError_t
|
||||||
#define cudaErrorPeerAccessAlreadyEnabled hipErrorPeerAccessAlreadyEnabled
|
#define cudaErrorPeerAccessAlreadyEnabled hipErrorPeerAccessAlreadyEnabled
|
||||||
|
|||||||
8
ml/backend/ggml/ggml/src/ggml.cpp
vendored
8
ml/backend/ggml/ggml/src/ggml.cpp
vendored
@@ -19,8 +19,12 @@ static bool ggml_uncaught_exception_init = []{
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
const auto prev{std::get_terminate()};
|
const auto prev{std::get_terminate()};
|
||||||
GGML_ASSERT(prev != ggml_uncaught_exception);
|
// GGML_ASSERT(prev != ggml_uncaught_exception);
|
||||||
previous_terminate_handler = prev;
|
if (prev != ggml_uncaught_exception) {
|
||||||
|
previous_terminate_handler = prev;
|
||||||
|
} else {
|
||||||
|
GGML_LOG_WARN("%s double registration of ggml_uncaught_exception\n", __func__);
|
||||||
|
}
|
||||||
std::set_terminate(ggml_uncaught_exception);
|
std::set_terminate(ggml_uncaught_exception);
|
||||||
return true;
|
return true;
|
||||||
}();
|
}();
|
||||||
|
|||||||
36
ml/nn/pooling/pooling.go
Normal file
36
ml/nn/pooling/pooling.go
Normal file
@@ -0,0 +1,36 @@
|
|||||||
|
package pooling
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/ollama/ollama/ml"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Type uint32
|
||||||
|
|
||||||
|
const (
|
||||||
|
TypeNone Type = iota
|
||||||
|
TypeMean
|
||||||
|
TypeCLS
|
||||||
|
TypeLast
|
||||||
|
TypeRank
|
||||||
|
|
||||||
|
TypeUnknown = 0xFFFFFFFE
|
||||||
|
TypeUnspecified = 0xFFFFFFFF
|
||||||
|
)
|
||||||
|
|
||||||
|
func Pooling(ctx ml.Context, hiddenStates ml.Tensor, poolingType Type) ml.Tensor {
|
||||||
|
switch poolingType {
|
||||||
|
case TypeNone:
|
||||||
|
return hiddenStates
|
||||||
|
case TypeMean:
|
||||||
|
hiddenStates = hiddenStates.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx).Mean(ctx)
|
||||||
|
return hiddenStates.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
|
||||||
|
case TypeCLS:
|
||||||
|
return hiddenStates.View(ctx, 0, hiddenStates.Dim(0))
|
||||||
|
case TypeLast:
|
||||||
|
panic("not implemented")
|
||||||
|
case TypeRank:
|
||||||
|
panic("not implemented")
|
||||||
|
default:
|
||||||
|
panic("not implemented")
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -2,7 +2,6 @@ package model
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"cmp"
|
"cmp"
|
||||||
"context"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"iter"
|
"iter"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
@@ -202,12 +201,11 @@ func (bpe BytePairEncoding) Encode(s string, addSpecial bool) ([]int32, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
slog.Log(context.TODO(), logutil.LevelTrace, "encoded", "string", s, "ids", ids)
|
|
||||||
|
|
||||||
if addSpecial && len(ids) > 0 {
|
if addSpecial && len(ids) > 0 {
|
||||||
ids = bpe.vocab.addSpecials(ids)
|
ids = bpe.vocab.addSpecials(ids)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
logutil.Trace("encoded", "string", s, "ids", ids)
|
||||||
return ids, nil
|
return ids, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -243,6 +241,6 @@ func (bpe BytePairEncoding) Decode(ids []int32) (string, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
slog.Log(context.TODO(), logutil.LevelTrace, "decoded", "string", sb.String(), "from", lazyIdsString{ids: ids})
|
logutil.Trace("decoded", "string", sb.String(), "from", lazyIdsString{ids: ids})
|
||||||
return sb.String(), nil
|
return sb.String(), nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -54,10 +54,9 @@ type Batch struct {
|
|||||||
// Inputs is the input tokens, including placeholders for multimodal inputs.
|
// Inputs is the input tokens, including placeholders for multimodal inputs.
|
||||||
Inputs ml.Tensor
|
Inputs ml.Tensor
|
||||||
|
|
||||||
// Multimodal is a set of multimodal embeddings previously created by
|
// Outputs are the set of indicies into Inputs for which output data should
|
||||||
// EncodeMultimodal, along with an index into Inputs. Unused for text-only
|
// be returned.
|
||||||
// models or for batches without multimodal elements.
|
Outputs ml.Tensor
|
||||||
Multimodal []MultimodalIndex
|
|
||||||
|
|
||||||
// Positions is the position for each Input, relative to its sequence. Equal
|
// Positions is the position for each Input, relative to its sequence. Equal
|
||||||
// in length to Inputs.
|
// in length to Inputs.
|
||||||
@@ -66,7 +65,8 @@ type Batch struct {
|
|||||||
// Sequences is the sequence for each Input. Equal in length to Inputs.
|
// Sequences is the sequence for each Input. Equal in length to Inputs.
|
||||||
Sequences []int
|
Sequences []int
|
||||||
|
|
||||||
// Outputs are the set of indicies into Inputs for which output data should
|
// Multimodal is a set of multimodal embeddings previously created by
|
||||||
// be returned.
|
// EncodeMultimodal, along with an index into Inputs. Unused for text-only
|
||||||
Outputs []int32
|
// models or for batches without multimodal elements.
|
||||||
|
Multimodal []MultimodalIndex
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,12 +1,11 @@
|
|||||||
package model
|
package model
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
_ "image/jpeg"
|
_ "image/jpeg"
|
||||||
_ "image/png"
|
_ "image/png"
|
||||||
"log/slog"
|
"math"
|
||||||
"os"
|
"os"
|
||||||
"reflect"
|
"reflect"
|
||||||
"strconv"
|
"strconv"
|
||||||
@@ -25,7 +24,11 @@ import (
|
|||||||
"github.com/ollama/ollama/model/input"
|
"github.com/ollama/ollama/model/input"
|
||||||
)
|
)
|
||||||
|
|
||||||
var ErrNoVisionModel = errors.New("this model is missing data required for image input")
|
var (
|
||||||
|
ErrNoVisionModel = errors.New("this model is missing data required for image input")
|
||||||
|
ErrUnsupportedModel = errors.New("model not supported")
|
||||||
|
ErrUnsupportedTokenizer = errors.New("tokenizer not supported")
|
||||||
|
)
|
||||||
|
|
||||||
// Model implements a specific model architecture, defining the forward pass and any model-specific configuration
|
// Model implements a specific model architecture, defining the forward pass and any model-specific configuration
|
||||||
type Model interface {
|
type Model interface {
|
||||||
@@ -64,7 +67,7 @@ type MultimodalProcessor interface {
|
|||||||
// This function is also responsible for updating MultimodalHash for any Multimodal
|
// This function is also responsible for updating MultimodalHash for any Multimodal
|
||||||
// that is modified to ensure that there is a unique hash value that accurately
|
// that is modified to ensure that there is a unique hash value that accurately
|
||||||
// represents the contents.
|
// represents the contents.
|
||||||
PostTokenize([]input.Input) ([]input.Input, error)
|
PostTokenize([]*input.Input) ([]*input.Input, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Base implements the common fields and methods for all models
|
// Base implements the common fields and methods for all models
|
||||||
@@ -105,6 +108,10 @@ func New(modelPath string, params ml.BackendParams) (Model, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
arch := b.Config().Architecture()
|
arch := b.Config().Architecture()
|
||||||
|
if b.Config().Uint("pooling_type", math.MaxUint32) != math.MaxUint32 {
|
||||||
|
arch = arch + "_embed"
|
||||||
|
}
|
||||||
|
|
||||||
f, ok := models[arch]
|
f, ok := models[arch]
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("unsupported model architecture %q", arch)
|
return nil, fmt.Errorf("unsupported model architecture %q", arch)
|
||||||
@@ -198,7 +205,7 @@ func populateFields(base Base, v reflect.Value, tags ...Tag) reflect.Value {
|
|||||||
names := fn(tagsCopy)
|
names := fn(tagsCopy)
|
||||||
for _, name := range names {
|
for _, name := range names {
|
||||||
if tensor := base.Backend().Get(strings.Join(name, ".")); tensor != nil {
|
if tensor := base.Backend().Get(strings.Join(name, ".")); tensor != nil {
|
||||||
slog.Log(context.TODO(), logutil.LevelTrace, "found tensor", "", tensor)
|
logutil.Trace("found tensor", "", tensor)
|
||||||
vv.Set(reflect.ValueOf(tensor))
|
vv.Set(reflect.ValueOf(tensor))
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
@@ -239,7 +246,7 @@ func setPointer(base Base, v reflect.Value, tags []Tag) {
|
|||||||
vv = vv.Elem()
|
vv = vv.Elem()
|
||||||
}
|
}
|
||||||
|
|
||||||
vv = vv.Elem()
|
vv = reflect.Indirect(vv)
|
||||||
if v.IsNil() {
|
if v.IsNil() {
|
||||||
vv = reflect.New(v.Type().Elem()).Elem()
|
vv = reflect.New(v.Type().Elem()).Elem()
|
||||||
}
|
}
|
||||||
@@ -278,7 +285,7 @@ func canNil(t reflect.Type) bool {
|
|||||||
t.Kind() == reflect.Slice
|
t.Kind() == reflect.Slice
|
||||||
}
|
}
|
||||||
|
|
||||||
func Forward(ctx ml.Context, m Model, inputs []int32, batch input.Batch) (ml.Tensor, error) {
|
func Forward(ctx ml.Context, m Model, batch input.Batch) (ml.Tensor, error) {
|
||||||
if len(batch.Positions) != len(batch.Sequences) {
|
if len(batch.Positions) != len(batch.Sequences) {
|
||||||
return nil, fmt.Errorf("length of positions (%v) must match length of seqs (%v)", len(batch.Positions), len(batch.Sequences))
|
return nil, fmt.Errorf("length of positions (%v) must match length of seqs (%v)", len(batch.Positions), len(batch.Sequences))
|
||||||
}
|
}
|
||||||
@@ -287,8 +294,6 @@ func Forward(ctx ml.Context, m Model, inputs []int32, batch input.Batch) (ml.Ten
|
|||||||
return nil, errors.New("batch size cannot be less than 1")
|
return nil, errors.New("batch size cannot be less than 1")
|
||||||
}
|
}
|
||||||
|
|
||||||
batch.Inputs = ctx.Input().FromIntSlice(inputs, len(inputs))
|
|
||||||
|
|
||||||
cache := m.Config().Cache
|
cache := m.Config().Cache
|
||||||
if cache != nil {
|
if cache != nil {
|
||||||
err := cache.StartForward(ctx, batch, false)
|
err := cache.StartForward(ctx, batch, false)
|
||||||
@@ -302,7 +307,7 @@ func Forward(ctx ml.Context, m Model, inputs []int32, batch input.Batch) (ml.Ten
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx.Forward(t).Compute(t)
|
ctx.Forward(t)
|
||||||
|
|
||||||
return t, nil
|
return t, nil
|
||||||
}
|
}
|
||||||
|
|||||||
181
model/models/bert/model.go
Normal file
181
model/models/bert/model.go
Normal file
@@ -0,0 +1,181 @@
|
|||||||
|
package bert
|
||||||
|
|
||||||
|
import (
|
||||||
|
"cmp"
|
||||||
|
"math"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/fs"
|
||||||
|
"github.com/ollama/ollama/ml"
|
||||||
|
"github.com/ollama/ollama/ml/nn"
|
||||||
|
"github.com/ollama/ollama/ml/nn/pooling"
|
||||||
|
"github.com/ollama/ollama/model"
|
||||||
|
"github.com/ollama/ollama/model/input"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Model struct {
|
||||||
|
model.Base
|
||||||
|
model.TextProcessor
|
||||||
|
|
||||||
|
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
|
||||||
|
TypeEmbedding *nn.Embedding `gguf:"token_types"`
|
||||||
|
PositionEmbedding *nn.Embedding `gguf:"position_embd"`
|
||||||
|
TokenEmbeddingNorm *nn.LayerNorm `gguf:"token_embd_norm"`
|
||||||
|
|
||||||
|
Layers []EncoderLayer `gguf:"blk"`
|
||||||
|
|
||||||
|
Options
|
||||||
|
}
|
||||||
|
|
||||||
|
// Forward implements model.Model.
|
||||||
|
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||||
|
hiddenStates := m.TokenEmbedding.Forward(ctx, batch.Inputs)
|
||||||
|
hiddenStates = hiddenStates.Add(ctx, m.TypeEmbedding.Weight.View(ctx, 0, m.hiddenSize))
|
||||||
|
hiddenStates = hiddenStates.Add(ctx, m.PositionEmbedding.Forward(ctx, ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))))
|
||||||
|
hiddenStates = m.TokenEmbeddingNorm.Forward(ctx, hiddenStates, m.eps)
|
||||||
|
|
||||||
|
for _, layer := range m.Layers {
|
||||||
|
hiddenStates = layer.Forward(ctx, hiddenStates, &m.Options)
|
||||||
|
}
|
||||||
|
|
||||||
|
hiddenStates = pooling.Pooling(ctx, hiddenStates, m.poolingType)
|
||||||
|
if m.normalize {
|
||||||
|
hiddenStates = hiddenStates.L2Norm(ctx, 1e-12)
|
||||||
|
}
|
||||||
|
|
||||||
|
return hiddenStates, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type EncoderLayer struct {
|
||||||
|
*Attention
|
||||||
|
AttentionNorm *nn.LayerNorm `gguf:"attn_output_norm"`
|
||||||
|
|
||||||
|
*MLP
|
||||||
|
MLPNorm *nn.LayerNorm `gguf:"layer_output_norm"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *EncoderLayer) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options) ml.Tensor {
|
||||||
|
// Attention
|
||||||
|
residual := hiddenStates
|
||||||
|
hiddenStates = e.Attention.Forward(ctx, hiddenStates, opts)
|
||||||
|
hiddenStates = hiddenStates.Add(ctx, residual)
|
||||||
|
hiddenStates = e.AttentionNorm.Forward(ctx, hiddenStates, opts.eps)
|
||||||
|
|
||||||
|
// MLP
|
||||||
|
residual = hiddenStates
|
||||||
|
hiddenStates = e.MLP.Forward(ctx, hiddenStates, opts)
|
||||||
|
hiddenStates = hiddenStates.Add(ctx, residual)
|
||||||
|
hiddenStates = e.MLPNorm.Forward(ctx, hiddenStates, opts.eps)
|
||||||
|
|
||||||
|
return hiddenStates
|
||||||
|
}
|
||||||
|
|
||||||
|
type Attention struct {
|
||||||
|
Query *nn.Linear `gguf:"attn_q"`
|
||||||
|
QueryNorm *nn.LayerNorm `gguf:"attn_q_norm"`
|
||||||
|
|
||||||
|
Key *nn.Linear `gguf:"attn_k"`
|
||||||
|
KeyNorm *nn.LayerNorm `gguf:"attn_k_norm"`
|
||||||
|
|
||||||
|
Value *nn.Linear `gguf:"attn_v"`
|
||||||
|
|
||||||
|
Output *nn.Linear `gguf:"attn_output"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Attention) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options) ml.Tensor {
|
||||||
|
batchSize := hiddenStates.Dim(1)
|
||||||
|
|
||||||
|
query := a.Query.Forward(ctx, hiddenStates)
|
||||||
|
if a.QueryNorm != nil {
|
||||||
|
query = a.QueryNorm.Forward(ctx, query, opts.eps)
|
||||||
|
}
|
||||||
|
query = query.Reshape(ctx, opts.headDim(), opts.numHeads, batchSize)
|
||||||
|
|
||||||
|
key := a.Key.Forward(ctx, hiddenStates)
|
||||||
|
if a.KeyNorm != nil {
|
||||||
|
key = a.KeyNorm.Forward(ctx, key, opts.eps)
|
||||||
|
}
|
||||||
|
key = key.Reshape(ctx, opts.headDim(), cmp.Or(opts.numKVHeads, opts.numHeads), batchSize)
|
||||||
|
|
||||||
|
value := a.Value.Forward(ctx, hiddenStates)
|
||||||
|
value = value.Reshape(ctx, opts.headDim(), cmp.Or(opts.numKVHeads, opts.numHeads), batchSize)
|
||||||
|
|
||||||
|
attention := nn.Attention(ctx, query, key, value, 1/math.Sqrt(float64(opts.headDim())), nil)
|
||||||
|
attention = attention.Reshape(ctx, opts.hiddenSize, batchSize)
|
||||||
|
return a.Output.Forward(ctx, attention)
|
||||||
|
}
|
||||||
|
|
||||||
|
type MLP struct {
|
||||||
|
Up *nn.Linear `gguf:"ffn_up"`
|
||||||
|
Down *nn.Linear `gguf:"ffn_down"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options) ml.Tensor {
|
||||||
|
return m.Down.Forward(ctx, m.Up.Forward(ctx, hiddenStates).GELU(ctx))
|
||||||
|
}
|
||||||
|
|
||||||
|
type Options struct {
|
||||||
|
hiddenSize,
|
||||||
|
numHeads,
|
||||||
|
numKVHeads,
|
||||||
|
keyLength,
|
||||||
|
valueLength int
|
||||||
|
poolingType pooling.Type
|
||||||
|
eps float32
|
||||||
|
normalize bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o Options) headDim() int {
|
||||||
|
return cmp.Or(o.keyLength, o.valueLength, o.hiddenSize/o.numHeads)
|
||||||
|
}
|
||||||
|
|
||||||
|
func New(c fs.Config) (model.Model, error) {
|
||||||
|
var processor model.TextProcessor
|
||||||
|
switch c.String("tokenizer.ggml.model", "bert") {
|
||||||
|
case "bert":
|
||||||
|
processor = model.NewWordPiece(
|
||||||
|
&model.Vocabulary{
|
||||||
|
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||||
|
Scores: c.Floats("tokenizer.ggml.scores"),
|
||||||
|
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||||
|
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
|
||||||
|
BOS: []int32{
|
||||||
|
int32(cmp.Or(
|
||||||
|
c.Uint("tokenizer.ggml.cls_token_id"),
|
||||||
|
c.Uint("tokenizer.ggml.bos_token_id"),
|
||||||
|
)),
|
||||||
|
},
|
||||||
|
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", true),
|
||||||
|
EOS: []int32{
|
||||||
|
int32(cmp.Or(
|
||||||
|
c.Uint("tokenizer.ggml.separator_token_id"),
|
||||||
|
//nolint:misspell
|
||||||
|
// NOTE: "seperator_token_id" is a typo in model metadata but we need to
|
||||||
|
// support it for compatibility.
|
||||||
|
c.Uint("tokenizer.ggml.seperator_token_id"),
|
||||||
|
c.Uint("tokenizer.ggml.eos_token_id"),
|
||||||
|
)),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
default:
|
||||||
|
return nil, model.ErrUnsupportedTokenizer
|
||||||
|
}
|
||||||
|
|
||||||
|
return &Model{
|
||||||
|
TextProcessor: processor,
|
||||||
|
Layers: make([]EncoderLayer, c.Uint("block_count")),
|
||||||
|
Options: Options{
|
||||||
|
hiddenSize: int(c.Uint("embedding_length")),
|
||||||
|
numHeads: int(c.Uint("attention.head_count")),
|
||||||
|
numKVHeads: int(c.Uint("attention.head_count_kv")),
|
||||||
|
eps: c.Float("attention.layer_norm_epsilon"),
|
||||||
|
poolingType: pooling.Type(c.Uint("pooling_type")),
|
||||||
|
normalize: c.Bool("normalize_embeddings", true),
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
model.Register("bert", New)
|
||||||
|
model.Register("bert_embed", New)
|
||||||
|
}
|
||||||
@@ -24,7 +24,7 @@ type Options struct {
|
|||||||
|
|
||||||
type Model struct {
|
type Model struct {
|
||||||
model.Base
|
model.Base
|
||||||
model.SentencePieceModel
|
model.SentencePiece
|
||||||
|
|
||||||
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
|
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
|
||||||
Layers []Layer `gguf:"blk"`
|
Layers []Layer `gguf:"blk"`
|
||||||
@@ -40,7 +40,7 @@ const (
|
|||||||
|
|
||||||
func New(c fs.Config) (model.Model, error) {
|
func New(c fs.Config) (model.Model, error) {
|
||||||
m := Model{
|
m := Model{
|
||||||
SentencePieceModel: model.NewSentencePieceModel(
|
SentencePiece: model.NewSentencePiece(
|
||||||
&model.Vocabulary{
|
&model.Vocabulary{
|
||||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||||
Scores: c.Floats("tokenizer.ggml.scores"),
|
Scores: c.Floats("tokenizer.ggml.scores"),
|
||||||
@@ -176,7 +176,6 @@ func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Ten
|
|||||||
|
|
||||||
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||||
positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
|
positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
|
||||||
outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
|
|
||||||
|
|
||||||
hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs)
|
hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs)
|
||||||
hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.Options.hiddenSize)))
|
hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.Options.hiddenSize)))
|
||||||
@@ -193,7 +192,7 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
|||||||
|
|
||||||
var lastLayerOutputs ml.Tensor
|
var lastLayerOutputs ml.Tensor
|
||||||
if i == len(m.Layers)-1 {
|
if i == len(m.Layers)-1 {
|
||||||
lastLayerOutputs = outputs
|
lastLayerOutputs = batch.Outputs
|
||||||
}
|
}
|
||||||
|
|
||||||
hiddenState = layer.Forward(ctx, hiddenState, positions, lastLayerOutputs, m.Cache, m.Options)
|
hiddenState = layer.Forward(ctx, hiddenState, positions, lastLayerOutputs, m.Cache, m.Options)
|
||||||
|
|||||||
62
model/models/gemma3/embed.go
Normal file
62
model/models/gemma3/embed.go
Normal file
@@ -0,0 +1,62 @@
|
|||||||
|
package gemma3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/ollama/ollama/fs"
|
||||||
|
"github.com/ollama/ollama/kvcache"
|
||||||
|
"github.com/ollama/ollama/ml"
|
||||||
|
"github.com/ollama/ollama/ml/nn"
|
||||||
|
"github.com/ollama/ollama/ml/nn/pooling"
|
||||||
|
"github.com/ollama/ollama/model"
|
||||||
|
"github.com/ollama/ollama/model/input"
|
||||||
|
)
|
||||||
|
|
||||||
|
type embedModel struct {
|
||||||
|
model.Base
|
||||||
|
model.SentencePiece
|
||||||
|
|
||||||
|
*TextModel
|
||||||
|
poolingType pooling.Type
|
||||||
|
|
||||||
|
Dense [2]*nn.Linear `gguf:"dense"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *embedModel) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||||
|
hiddenStates := m.TextModel.Forward(ctx, batch, m.Cache)
|
||||||
|
hiddenStates = pooling.Pooling(ctx, hiddenStates, m.poolingType)
|
||||||
|
for _, dense := range m.Dense {
|
||||||
|
hiddenStates = dense.Forward(ctx, hiddenStates)
|
||||||
|
}
|
||||||
|
hiddenStates = hiddenStates.L2Norm(ctx, 1e-12)
|
||||||
|
return hiddenStates, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func newEmbedModel(c fs.Config) (model.Model, error) {
|
||||||
|
m := &embedModel{
|
||||||
|
SentencePiece: model.NewSentencePiece(
|
||||||
|
&model.Vocabulary{
|
||||||
|
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||||
|
Scores: c.Floats("tokenizer.ggml.scores"),
|
||||||
|
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||||
|
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
|
||||||
|
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
|
||||||
|
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
|
||||||
|
EOS: append(
|
||||||
|
[]int32{
|
||||||
|
int32(c.Uint("tokenizer.ggml.eos_token_id")),
|
||||||
|
int32(c.Uint("tokenizer.ggml.eot_token_id", 106)),
|
||||||
|
},
|
||||||
|
c.Ints("tokenizer.ggml.eos_token_ids")...,
|
||||||
|
),
|
||||||
|
},
|
||||||
|
),
|
||||||
|
TextModel: newTextModel(c),
|
||||||
|
poolingType: pooling.Type(c.Uint("pooling_type", 0)),
|
||||||
|
}
|
||||||
|
|
||||||
|
m.Cache = kvcache.NewWrapperCache(
|
||||||
|
kvcache.NewSWACache(int32(c.Uint("attention.sliding_window")), m.Shift),
|
||||||
|
kvcache.NewCausalCache(m.Shift),
|
||||||
|
)
|
||||||
|
|
||||||
|
return m, nil
|
||||||
|
}
|
||||||
@@ -16,9 +16,9 @@ import (
|
|||||||
|
|
||||||
type Model struct {
|
type Model struct {
|
||||||
model.Base
|
model.Base
|
||||||
model.SentencePieceModel
|
model.SentencePiece
|
||||||
|
|
||||||
*VisionModel `gguf:"v,vision"`
|
*VisionModel `gguf:"v"`
|
||||||
*TextModel
|
*TextModel
|
||||||
|
|
||||||
*MultiModalProjector `gguf:"mm"`
|
*MultiModalProjector `gguf:"mm"`
|
||||||
@@ -55,7 +55,7 @@ func (p *MultiModalProjector) Forward(ctx ml.Context, visionOutputs ml.Tensor, i
|
|||||||
|
|
||||||
func New(c fs.Config) (model.Model, error) {
|
func New(c fs.Config) (model.Model, error) {
|
||||||
m := Model{
|
m := Model{
|
||||||
SentencePieceModel: model.NewSentencePieceModel(
|
SentencePiece: model.NewSentencePiece(
|
||||||
&model.Vocabulary{
|
&model.Vocabulary{
|
||||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||||
Scores: c.Floats("tokenizer.ggml.scores"),
|
Scores: c.Floats("tokenizer.ggml.scores"),
|
||||||
@@ -112,8 +112,8 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input
|
|||||||
return []input.Multimodal{{Tensor: visionOutputs}}, nil
|
return []input.Multimodal{{Tensor: visionOutputs}}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
|
func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
|
||||||
var result []input.Input
|
var result []*input.Input
|
||||||
|
|
||||||
for _, inp := range inputs {
|
for _, inp := range inputs {
|
||||||
if len(inp.Multimodal) == 0 {
|
if len(inp.Multimodal) == 0 {
|
||||||
@@ -122,17 +122,17 @@ func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
|
|||||||
inputMultimodal := inp.Multimodal[0].Tensor
|
inputMultimodal := inp.Multimodal[0].Tensor
|
||||||
|
|
||||||
result = append(result,
|
result = append(result,
|
||||||
input.Input{Token: 108, SameBatch: inputMultimodal.Dim(1) + 3}, // "\n\n"
|
&input.Input{Token: 108, SameBatch: inputMultimodal.Dim(1) + 3}, // "\n\n"
|
||||||
input.Input{Token: 255999}, // "<start_of_image>""
|
&input.Input{Token: 255999}, // "<start_of_image>""
|
||||||
input.Input{Multimodal: []input.Multimodal{{Tensor: inputMultimodal}}, MultimodalHash: inp.MultimodalHash}, // image data is on the first placeholder
|
&input.Input{Multimodal: []input.Multimodal{{Tensor: inputMultimodal}}, MultimodalHash: inp.MultimodalHash}, // image data is on the first placeholder
|
||||||
)
|
)
|
||||||
|
|
||||||
// add image token placeholders
|
// add image token placeholders
|
||||||
result = append(result, slices.Repeat([]input.Input{{Token: 0}}, inputMultimodal.Dim(1)-1)...)
|
result = append(result, slices.Repeat([]*input.Input{{Token: 0}}, inputMultimodal.Dim(1)-1)...)
|
||||||
|
|
||||||
result = append(result,
|
result = append(result,
|
||||||
input.Input{Token: 256000}, // <end_of_image>
|
&input.Input{Token: 256000}, // <end_of_image>
|
||||||
input.Input{Token: 108}, // "\n\n"
|
&input.Input{Token: 108}, // "\n\n"
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -141,12 +141,11 @@ func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||||
positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
|
hiddenStates := m.TextModel.Forward(ctx, batch, m.Cache)
|
||||||
outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
|
return m.Output.Forward(ctx, hiddenStates), nil
|
||||||
|
|
||||||
return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, batch, m.Cache), nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
model.Register("gemma3", New)
|
model.Register("gemma3", New)
|
||||||
|
model.Register("gemma3_embed", newEmbedModel)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -159,8 +159,10 @@ func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, positionIDs,
|
|||||||
return hiddenState.Add(ctx, residual)
|
return hiddenState.Add(ctx, residual)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor, batch input.Batch, cache kvcache.Cache) ml.Tensor {
|
func (m *TextModel) Forward(ctx ml.Context, batch input.Batch, cache kvcache.Cache) ml.Tensor {
|
||||||
hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
|
positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
|
||||||
|
|
||||||
|
hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs)
|
||||||
hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.TextConfig.hiddenSize)))
|
hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.TextConfig.hiddenSize)))
|
||||||
|
|
||||||
// set image embeddings
|
// set image embeddings
|
||||||
@@ -191,12 +193,12 @@ func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor
|
|||||||
|
|
||||||
var lastLayerOutputs ml.Tensor
|
var lastLayerOutputs ml.Tensor
|
||||||
if i == len(m.Layers)-1 {
|
if i == len(m.Layers)-1 {
|
||||||
lastLayerOutputs = outputs
|
lastLayerOutputs = batch.Outputs
|
||||||
}
|
}
|
||||||
|
|
||||||
hiddenState = layer.Forward(ctx, i, hiddenState, positions, lastLayerOutputs, cache, m.TextConfig)
|
hiddenState = layer.Forward(ctx, i, hiddenState, positions, lastLayerOutputs, cache, m.TextConfig)
|
||||||
}
|
}
|
||||||
|
|
||||||
hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
|
hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
|
||||||
return m.Output.Forward(ctx, hiddenState)
|
return hiddenState
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ import (
|
|||||||
|
|
||||||
type Model struct {
|
type Model struct {
|
||||||
model.Base
|
model.Base
|
||||||
model.SentencePieceModel
|
model.SentencePiece
|
||||||
|
|
||||||
*TextModel
|
*TextModel
|
||||||
}
|
}
|
||||||
@@ -23,7 +23,7 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
|||||||
func New(c fs.Config) (model.Model, error) {
|
func New(c fs.Config) (model.Model, error) {
|
||||||
m := Model{
|
m := Model{
|
||||||
TextModel: newTextModel(c),
|
TextModel: newTextModel(c),
|
||||||
SentencePieceModel: model.NewSentencePieceModel(
|
SentencePiece: model.NewSentencePiece(
|
||||||
&model.Vocabulary{
|
&model.Vocabulary{
|
||||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||||
Scores: c.Floats("tokenizer.ggml.scores"),
|
Scores: c.Floats("tokenizer.ggml.scores"),
|
||||||
|
|||||||
@@ -83,7 +83,7 @@ func (m *TextModel) Forward(ctx ml.Context, batch input.Batch, cache kvcache.Cac
|
|||||||
|
|
||||||
hiddenStates = hiddenStates.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx).Mean(ctx)
|
hiddenStates = hiddenStates.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx).Mean(ctx)
|
||||||
hiddenStates = hiddenStates.Permute(ctx, 2, 0, 1, 3).Contiguous(ctx)
|
hiddenStates = hiddenStates.Permute(ctx, 2, 0, 1, 3).Contiguous(ctx)
|
||||||
hiddenStates = hiddenStates.Rows(ctx, ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs)))
|
hiddenStates = hiddenStates.Rows(ctx, batch.Outputs)
|
||||||
|
|
||||||
hiddenStates = m.OutputNorm.Forward(ctx, hiddenStates, m.eps)
|
hiddenStates = m.OutputNorm.Forward(ctx, hiddenStates, m.eps)
|
||||||
return m.Output.Forward(ctx, hiddenStates), nil
|
return m.Output.Forward(ctx, hiddenStates), nil
|
||||||
|
|||||||
@@ -41,8 +41,8 @@ func (m *Transformer) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, err
|
|||||||
}
|
}
|
||||||
|
|
||||||
var outputs ml.Tensor
|
var outputs ml.Tensor
|
||||||
if len(batch.Outputs) > 0 && i == len(m.TransformerBlocks)-1 {
|
if i == len(m.TransformerBlocks)-1 {
|
||||||
outputs = ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
|
outputs = batch.Outputs
|
||||||
}
|
}
|
||||||
|
|
||||||
hiddenStates = block.Forward(ctx, hiddenStates, positions, outputs, one, m.Cache, &m.Options)
|
hiddenStates = block.Forward(ctx, hiddenStates, positions, outputs, one, m.Cache, &m.Options)
|
||||||
|
|||||||
@@ -160,7 +160,7 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
|||||||
|
|
||||||
var outputs ml.Tensor
|
var outputs ml.Tensor
|
||||||
if i == len(m.Layers)-1 {
|
if i == len(m.Layers)-1 {
|
||||||
outputs = ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
|
outputs = batch.Outputs
|
||||||
}
|
}
|
||||||
|
|
||||||
hiddenState = layer.Forward(ctx, hiddenState, positions, outputs, m.Cache, m.Options)
|
hiddenState = layer.Forward(ctx, hiddenState, positions, outputs, m.Cache, m.Options)
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ type Model struct {
|
|||||||
model.BytePairEncoding
|
model.BytePairEncoding
|
||||||
ImageProcessor
|
ImageProcessor
|
||||||
|
|
||||||
*VisionModel `gguf:"v,vision"`
|
*VisionModel `gguf:"v"`
|
||||||
*Projector `gguf:"mm"`
|
*Projector `gguf:"mm"`
|
||||||
*TextModel
|
*TextModel
|
||||||
}
|
}
|
||||||
@@ -134,16 +134,16 @@ type separator struct {
|
|||||||
y bool
|
y bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
|
func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
|
||||||
var result []input.Input
|
var result []*input.Input
|
||||||
for _, inp := range inputs {
|
for _, inp := range inputs {
|
||||||
if len(inp.Multimodal) == 0 {
|
if len(inp.Multimodal) == 0 {
|
||||||
result = append(result, inp)
|
result = append(result, inp)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
var imageInputs []input.Input
|
var imageInputs []*input.Input
|
||||||
imageInputs = append(imageInputs, input.Input{Token: 200080}) // <|image_start|>
|
imageInputs = append(imageInputs, &input.Input{Token: 200080}) // <|image_start|>
|
||||||
|
|
||||||
for i, mm := range inp.Multimodal {
|
for i, mm := range inp.Multimodal {
|
||||||
patchesPerChunk := mm.Tensor.Dim(1)
|
patchesPerChunk := mm.Tensor.Dim(1)
|
||||||
@@ -151,20 +151,20 @@ func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
|
|||||||
if i < len(inp.Multimodal)-1 {
|
if i < len(inp.Multimodal)-1 {
|
||||||
separator := mm.Data.(*separator)
|
separator := mm.Data.(*separator)
|
||||||
|
|
||||||
imageInputs = append(imageInputs, input.Input{Token: 200092, Multimodal: []input.Multimodal{{Tensor: mm.Tensor}}, MultimodalHash: inp.MultimodalHash, SameBatch: patchesPerChunk}) // <|patch|>
|
imageInputs = append(imageInputs, &input.Input{Token: 200092, Multimodal: []input.Multimodal{{Tensor: mm.Tensor}}, MultimodalHash: inp.MultimodalHash, SameBatch: patchesPerChunk}) // <|patch|>
|
||||||
imageInputs = append(imageInputs, slices.Repeat([]input.Input{{Token: 200092}}, patchesPerChunk-1)...)
|
imageInputs = append(imageInputs, slices.Repeat([]*input.Input{{Token: 200092}}, patchesPerChunk-1)...)
|
||||||
|
|
||||||
if separator.x {
|
if separator.x {
|
||||||
imageInputs = append(imageInputs, input.Input{Token: 200084}) // <|tile_x_separator|>
|
imageInputs = append(imageInputs, &input.Input{Token: 200084}) // <|tile_x_separator|>
|
||||||
}
|
}
|
||||||
if separator.y {
|
if separator.y {
|
||||||
imageInputs = append(imageInputs, input.Input{Token: 200085}) // <|tile_y_separator|>
|
imageInputs = append(imageInputs, &input.Input{Token: 200085}) // <|tile_y_separator|>
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
imageInputs = append(imageInputs, input.Input{Token: 200090}) // <|image|>
|
imageInputs = append(imageInputs, &input.Input{Token: 200090}) // <|image|>
|
||||||
imageInputs = append(imageInputs, input.Input{Token: 200092, Multimodal: []input.Multimodal{{Tensor: mm.Tensor}}, MultimodalHash: inp.MultimodalHash, SameBatch: patchesPerChunk}) // <|patch|>
|
imageInputs = append(imageInputs, &input.Input{Token: 200092, Multimodal: []input.Multimodal{{Tensor: mm.Tensor}}, MultimodalHash: inp.MultimodalHash, SameBatch: patchesPerChunk}) // <|patch|>
|
||||||
imageInputs = append(imageInputs, slices.Repeat([]input.Input{{Token: 200092}}, patchesPerChunk-1)...)
|
imageInputs = append(imageInputs, slices.Repeat([]*input.Input{{Token: 200092}}, patchesPerChunk-1)...)
|
||||||
imageInputs = append(imageInputs, input.Input{Token: 200080}) // <|image_end|>
|
imageInputs = append(imageInputs, &input.Input{Token: 200080}) // <|image_end|>
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -176,9 +176,7 @@ func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
|
|||||||
|
|
||||||
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||||
positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
|
positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
|
||||||
outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
|
return m.TextModel.Forward(ctx, batch.Inputs, positions, batch.Outputs, batch, m.Cache), nil
|
||||||
|
|
||||||
return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, batch, m.Cache), nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ type Model struct {
|
|||||||
model.BytePairEncoding
|
model.BytePairEncoding
|
||||||
|
|
||||||
*TextModel
|
*TextModel
|
||||||
*VisionModel `gguf:"v,vision"`
|
*VisionModel `gguf:"v"`
|
||||||
*MultiModalProjector `gguf:"mm"`
|
*MultiModalProjector `gguf:"mm"`
|
||||||
|
|
||||||
ImageProcessor
|
ImageProcessor
|
||||||
@@ -133,22 +133,22 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input
|
|||||||
// [IMG]...[IMG][IMG_BREAK][IMG]...[IMG][IMG_BREAK][IMG]...[IMG][IMG_END]
|
// [IMG]...[IMG][IMG_BREAK][IMG]...[IMG][IMG_BREAK][IMG]...[IMG][IMG_END]
|
||||||
// Each sequence of [IMG]...[IMG] is a set of patches of vision embeddings
|
// Each sequence of [IMG]...[IMG] is a set of patches of vision embeddings
|
||||||
// that can be processed together.
|
// that can be processed together.
|
||||||
func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
|
func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
|
||||||
var result []input.Input
|
var result []*input.Input
|
||||||
for _, inp := range inputs {
|
for _, inp := range inputs {
|
||||||
if len(inp.Multimodal) == 0 {
|
if len(inp.Multimodal) == 0 {
|
||||||
result = append(result, inp)
|
result = append(result, inp)
|
||||||
} else {
|
} else {
|
||||||
for i, row := range inp.Multimodal {
|
for i, row := range inp.Multimodal {
|
||||||
// [IMG]
|
// [IMG]
|
||||||
result = append(result, input.Input{Token: 10, Multimodal: []input.Multimodal{{Tensor: row.Tensor}}, MultimodalHash: inp.MultimodalHash, SameBatch: row.Tensor.Dim(1)})
|
result = append(result, &input.Input{Token: 10, Multimodal: []input.Multimodal{{Tensor: row.Tensor}}, MultimodalHash: inp.MultimodalHash, SameBatch: row.Tensor.Dim(1)})
|
||||||
result = append(result, slices.Repeat([]input.Input{{Token: 10}}, row.Tensor.Dim(1)-1)...)
|
result = append(result, slices.Repeat([]*input.Input{{Token: 10}}, row.Tensor.Dim(1)-1)...)
|
||||||
if i == len(inp.Multimodal)-1 {
|
if i == len(inp.Multimodal)-1 {
|
||||||
// [IMG_END]
|
// [IMG_END]
|
||||||
result = append(result, input.Input{Token: 13})
|
result = append(result, &input.Input{Token: 13})
|
||||||
} else {
|
} else {
|
||||||
// [IMG_BREAK]
|
// [IMG_BREAK]
|
||||||
result = append(result, input.Input{Token: 12})
|
result = append(result, &input.Input{Token: 12})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -159,9 +159,8 @@ func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
|
|||||||
|
|
||||||
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||||
positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
|
positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
|
||||||
outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
|
|
||||||
|
|
||||||
return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, batch, m.Cache), nil
|
return m.TextModel.Forward(ctx, batch.Inputs, positions, batch.Outputs, batch, m.Cache), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ type Model struct {
|
|||||||
model.Base
|
model.Base
|
||||||
model.BytePairEncoding
|
model.BytePairEncoding
|
||||||
|
|
||||||
*VisionModel `gguf:"v,vision"`
|
*VisionModel `gguf:"v"`
|
||||||
*TextModel
|
*TextModel
|
||||||
|
|
||||||
Projector *nn.Linear `gguf:"mm.0"`
|
Projector *nn.Linear `gguf:"mm.0"`
|
||||||
@@ -90,7 +90,7 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input
|
|||||||
return []input.Multimodal{{Tensor: projectedOutputs}}, nil
|
return []input.Multimodal{{Tensor: projectedOutputs}}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
|
func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
|
||||||
for i := range inputs {
|
for i := range inputs {
|
||||||
if inputs[i].Multimodal != nil {
|
if inputs[i].Multimodal != nil {
|
||||||
inputs[i].Token = 128256 // <|image|>
|
inputs[i].Token = 128256 // <|image|>
|
||||||
@@ -107,10 +107,9 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
|
positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
|
||||||
outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
|
|
||||||
|
|
||||||
// TODO: attention mask, cross attention mask
|
// TODO: attention mask, cross attention mask
|
||||||
return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, crossAttentionStates, nil, m.Cache.(*kvcache.WrapperCache)), nil
|
return m.TextModel.Forward(ctx, batch.Inputs, positions, batch.Outputs, crossAttentionStates, nil, m.Cache.(*kvcache.WrapperCache)), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package models
|
package models
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
_ "github.com/ollama/ollama/model/models/bert"
|
||||||
_ "github.com/ollama/ollama/model/models/gemma2"
|
_ "github.com/ollama/ollama/model/models/gemma2"
|
||||||
_ "github.com/ollama/ollama/model/models/gemma3"
|
_ "github.com/ollama/ollama/model/models/gemma3"
|
||||||
_ "github.com/ollama/ollama/model/models/gemma3n"
|
_ "github.com/ollama/ollama/model/models/gemma3n"
|
||||||
|
|||||||
@@ -111,7 +111,7 @@ func (m Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
|||||||
|
|
||||||
var outputs ml.Tensor
|
var outputs ml.Tensor
|
||||||
if i == len(m.Layers)-1 {
|
if i == len(m.Layers)-1 {
|
||||||
outputs = ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
|
outputs = batch.Outputs
|
||||||
}
|
}
|
||||||
|
|
||||||
hiddenStates = layer.Forward(ctx, hiddenStates, positions, outputs, m.Cache, &m.Options)
|
hiddenStates = layer.Forward(ctx, hiddenStates, positions, outputs, m.Cache, &m.Options)
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ type Model struct {
|
|||||||
model.BytePairEncoding
|
model.BytePairEncoding
|
||||||
|
|
||||||
*TextModel
|
*TextModel
|
||||||
*VisionModel `gguf:"v,vision"`
|
*VisionModel `gguf:"v"`
|
||||||
|
|
||||||
ImageProcessor
|
ImageProcessor
|
||||||
}
|
}
|
||||||
@@ -89,8 +89,8 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input
|
|||||||
}
|
}
|
||||||
|
|
||||||
// PostTokenize arranges Qwen-2.5-VL's inputs for the forward pass
|
// PostTokenize arranges Qwen-2.5-VL's inputs for the forward pass
|
||||||
func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
|
func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
|
||||||
var result []input.Input
|
var result []*input.Input
|
||||||
|
|
||||||
var (
|
var (
|
||||||
imageToken int32 = 151655
|
imageToken int32 = 151655
|
||||||
@@ -112,16 +112,16 @@ func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
|
|||||||
return nil, fmt.Errorf("failed to encode image prompt: %w", err)
|
return nil, fmt.Errorf("failed to encode image prompt: %w", err)
|
||||||
}
|
}
|
||||||
for i := range pre {
|
for i := range pre {
|
||||||
result = append(result, input.Input{Token: pre[i]})
|
result = append(result, &input.Input{Token: pre[i]})
|
||||||
}
|
}
|
||||||
|
|
||||||
patchesPerChunk := inp.Multimodal[0].Tensor.Dim(1)
|
patchesPerChunk := inp.Multimodal[0].Tensor.Dim(1)
|
||||||
|
|
||||||
// First add the vision start token
|
// First add the vision start token
|
||||||
result = append(result, input.Input{Token: visionStartToken})
|
result = append(result, &input.Input{Token: visionStartToken})
|
||||||
|
|
||||||
// Add the image token with the multimodal tensor data at the first position
|
// Add the image token with the multimodal tensor data at the first position
|
||||||
result = append(result, input.Input{
|
result = append(result, &input.Input{
|
||||||
Token: imageToken,
|
Token: imageToken,
|
||||||
Multimodal: inp.Multimodal,
|
Multimodal: inp.Multimodal,
|
||||||
MultimodalHash: inp.MultimodalHash,
|
MultimodalHash: inp.MultimodalHash,
|
||||||
@@ -129,9 +129,9 @@ func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
// Add the placeholder tokens for the remaining positions (tokensPerGrid-1)
|
// Add the placeholder tokens for the remaining positions (tokensPerGrid-1)
|
||||||
result = append(result, slices.Repeat([]input.Input{{Token: imageToken}}, patchesPerChunk-1)...)
|
result = append(result, slices.Repeat([]*input.Input{{Token: imageToken}}, patchesPerChunk-1)...)
|
||||||
|
|
||||||
result = append(result, input.Input{Token: visionEndToken})
|
result = append(result, &input.Input{Token: visionEndToken})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -140,9 +140,8 @@ func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
|
|||||||
|
|
||||||
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||||
positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
|
positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
|
||||||
outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
|
|
||||||
|
|
||||||
return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, batch, m.Cache)
|
return m.TextModel.Forward(ctx, batch.Inputs, positions, batch.Outputs, batch, m.Cache)
|
||||||
}
|
}
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
|
|||||||
@@ -165,7 +165,7 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
|||||||
|
|
||||||
var outputs ml.Tensor
|
var outputs ml.Tensor
|
||||||
if i == len(m.Layers)-1 {
|
if i == len(m.Layers)-1 {
|
||||||
outputs = ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
|
outputs = batch.Outputs
|
||||||
}
|
}
|
||||||
|
|
||||||
hiddenStates = layer.Forward(ctx, hiddenStates, positions, outputs, m.Cache, m.Options)
|
hiddenStates = layer.Forward(ctx, hiddenStates, positions, outputs, m.Cache, m.Options)
|
||||||
|
|||||||
@@ -2,7 +2,6 @@ package model
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"container/heap"
|
"container/heap"
|
||||||
"context"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"strconv"
|
"strconv"
|
||||||
@@ -13,19 +12,19 @@ import (
|
|||||||
|
|
||||||
const spmWhitespaceSep = "▁"
|
const spmWhitespaceSep = "▁"
|
||||||
|
|
||||||
type SentencePieceModel struct {
|
type SentencePiece struct {
|
||||||
maxTokenLen int
|
maxTokenLen int
|
||||||
vocab *Vocabulary
|
vocab *Vocabulary
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ TextProcessor = (*SentencePieceModel)(nil)
|
var _ TextProcessor = (*SentencePiece)(nil)
|
||||||
|
|
||||||
func (spm SentencePieceModel) Vocabulary() *Vocabulary {
|
func (spm SentencePiece) Vocabulary() *Vocabulary {
|
||||||
return spm.vocab
|
return spm.vocab
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewSentencePieceModel(vocab *Vocabulary) SentencePieceModel {
|
func NewSentencePiece(vocab *Vocabulary) SentencePiece {
|
||||||
slog.Log(context.TODO(), logutil.LevelTrace, "Tokens", "num tokens", len(vocab.Values), "vals", vocab.Values[:5], "scores", vocab.Scores[:5], "types", vocab.Types[:5])
|
logutil.Trace("Tokens", "num tokens", len(vocab.Values), "vals", vocab.Values[:5], "scores", vocab.Scores[:5], "types", vocab.Types[:5])
|
||||||
|
|
||||||
counter := map[int]int{}
|
counter := map[int]int{}
|
||||||
var maxTokenLen int
|
var maxTokenLen int
|
||||||
@@ -39,21 +38,21 @@ func NewSentencePieceModel(vocab *Vocabulary) SentencePieceModel {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
slog.Log(context.TODO(), logutil.LevelTrace, "Token counts", "normal", counter[TOKEN_TYPE_NORMAL], "unknown", counter[TOKEN_TYPE_UNKNOWN], "control", counter[TOKEN_TYPE_CONTROL],
|
logutil.Trace("Token counts", "normal", counter[TOKEN_TYPE_NORMAL], "unknown", counter[TOKEN_TYPE_UNKNOWN], "control", counter[TOKEN_TYPE_CONTROL],
|
||||||
"user defined", counter[TOKEN_TYPE_USER_DEFINED], "unused", counter[TOKEN_TYPE_UNUSED], "byte", counter[TOKEN_TYPE_BYTE],
|
"user defined", counter[TOKEN_TYPE_USER_DEFINED], "unused", counter[TOKEN_TYPE_UNUSED], "byte", counter[TOKEN_TYPE_BYTE],
|
||||||
"max token len", maxTokenLen)
|
"max token len", maxTokenLen)
|
||||||
|
|
||||||
return SentencePieceModel{
|
return SentencePiece{
|
||||||
maxTokenLen: maxTokenLen,
|
maxTokenLen: maxTokenLen,
|
||||||
vocab: vocab,
|
vocab: vocab,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (spm SentencePieceModel) Is(id int32, special Special) bool {
|
func (spm SentencePiece) Is(id int32, special Special) bool {
|
||||||
return spm.vocab.Is(id, special)
|
return spm.vocab.Is(id, special)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (spm SentencePieceModel) Encode(s string, addSpecial bool) ([]int32, error) {
|
func (spm SentencePiece) Encode(s string, addSpecial bool) ([]int32, error) {
|
||||||
fragments := []fragment{{value: s}}
|
fragments := []fragment{{value: s}}
|
||||||
for _, special := range spm.vocab.SpecialVocabulary() {
|
for _, special := range spm.vocab.SpecialVocabulary() {
|
||||||
id := spm.vocab.Encode(special)
|
id := spm.vocab.Encode(special)
|
||||||
@@ -182,12 +181,11 @@ func (spm SentencePieceModel) Encode(s string, addSpecial bool) ([]int32, error)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
slog.Log(context.TODO(), logutil.LevelTrace, "encoded", "string", s, "ids", ids)
|
|
||||||
|
|
||||||
if addSpecial && len(ids) > 0 {
|
if addSpecial && len(ids) > 0 {
|
||||||
ids = spm.vocab.addSpecials(ids)
|
ids = spm.vocab.addSpecials(ids)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
logutil.Trace("encoded", "string", s, "ids", ids)
|
||||||
return ids, nil
|
return ids, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -220,7 +218,7 @@ func (q *queue) Pop() interface{} {
|
|||||||
return item
|
return item
|
||||||
}
|
}
|
||||||
|
|
||||||
func (spm SentencePieceModel) Decode(ids []int32) (string, error) {
|
func (spm SentencePiece) Decode(ids []int32) (string, error) {
|
||||||
var sb strings.Builder
|
var sb strings.Builder
|
||||||
for _, id := range ids {
|
for _, id := range ids {
|
||||||
data := spm.vocab.Decode(id)
|
data := spm.vocab.Decode(id)
|
||||||
@@ -246,6 +244,6 @@ func (spm SentencePieceModel) Decode(ids []int32) (string, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
slog.Log(context.TODO(), logutil.LevelTrace, "decoded", "ids", ids, "string", sb.String())
|
logutil.Trace("decoded", "ids", ids, "string", sb.String())
|
||||||
return sb.String(), nil
|
return sb.String(), nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ import (
|
|||||||
"github.com/ollama/ollama/convert/sentencepiece"
|
"github.com/ollama/ollama/convert/sentencepiece"
|
||||||
)
|
)
|
||||||
|
|
||||||
func loadSentencePieceVocab(t *testing.T) SentencePieceModel {
|
func loadSentencePieceVocab(t *testing.T) SentencePiece {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
bts, err := os.ReadFile(filepath.Join("testdata", "gemma2", "tokenizer.model"))
|
bts, err := os.ReadFile(filepath.Join("testdata", "gemma2", "tokenizer.model"))
|
||||||
@@ -45,7 +45,7 @@ func loadSentencePieceVocab(t *testing.T) SentencePieceModel {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return NewSentencePieceModel(&v)
|
return NewSentencePiece(&v)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSentencePieceEncode(t *testing.T) {
|
func TestSentencePieceEncode(t *testing.T) {
|
||||||
@@ -115,7 +115,7 @@ func TestSentencePieceEncode(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSentencePieceModelDecodeByteTokens(t *testing.T) {
|
func TestSentencePieceDecodeByteTokens(t *testing.T) {
|
||||||
vocab := &Vocabulary{
|
vocab := &Vocabulary{
|
||||||
Values: []string{
|
Values: []string{
|
||||||
"normal",
|
"normal",
|
||||||
@@ -134,7 +134,7 @@ func TestSentencePieceModelDecodeByteTokens(t *testing.T) {
|
|||||||
Scores: []float32{0, 0, 0, 0, 0},
|
Scores: []float32{0, 0, 0, 0, 0},
|
||||||
}
|
}
|
||||||
|
|
||||||
spm := NewSentencePieceModel(vocab)
|
spm := NewSentencePiece(vocab)
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
|
|||||||
@@ -49,7 +49,7 @@ func (v *Vocabulary) addSpecials(ids []int32) []int32 {
|
|||||||
slog.Warn("adding bos token to prompt which already has it", "id", v.BOS)
|
slog.Warn("adding bos token to prompt which already has it", "id", v.BOS)
|
||||||
}
|
}
|
||||||
|
|
||||||
slog.Debug("adding bos token to prompt", "id", v.BOS)
|
slog.Debug("adding bos token to prompt", "id", v.BOS[0])
|
||||||
ids = append([]int32{v.BOS[0]}, ids...)
|
ids = append([]int32{v.BOS[0]}, ids...)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -58,7 +58,7 @@ func (v *Vocabulary) addSpecials(ids []int32) []int32 {
|
|||||||
slog.Warn("adding eos token to prompt which already has it", "id", v.EOS)
|
slog.Warn("adding eos token to prompt which already has it", "id", v.EOS)
|
||||||
}
|
}
|
||||||
|
|
||||||
slog.Debug("adding eos token to prompt", "id", v.EOS)
|
slog.Debug("adding eos token to prompt", "id", v.EOS[0])
|
||||||
ids = append(ids, v.EOS[0])
|
ids = append(ids, v.EOS[0])
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
167
model/wordpiece.go
Normal file
167
model/wordpiece.go
Normal file
@@ -0,0 +1,167 @@
|
|||||||
|
package model
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"iter"
|
||||||
|
"strings"
|
||||||
|
"unicode"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/logutil"
|
||||||
|
)
|
||||||
|
|
||||||
|
type WordPiece struct {
|
||||||
|
vocab *Vocabulary
|
||||||
|
}
|
||||||
|
|
||||||
|
// ggmlPrefix is the prefix used by GGML vocabularies to indicate word boundaries.
|
||||||
|
// this differs from original word piece which uses "##" to indicate subwords.
|
||||||
|
const ggmlPrefix = "▁"
|
||||||
|
|
||||||
|
var wordPieceReplacer = strings.NewReplacer(
|
||||||
|
" .", ".",
|
||||||
|
" ?", "?",
|
||||||
|
" !", "!",
|
||||||
|
" ,", ",",
|
||||||
|
" ' ", "'",
|
||||||
|
" n't", "n't",
|
||||||
|
" 'm", "'m",
|
||||||
|
" do not", " don't",
|
||||||
|
" 's", "'s",
|
||||||
|
" 've", "'ve",
|
||||||
|
" 're", "'re",
|
||||||
|
)
|
||||||
|
|
||||||
|
// Decode implements TextProcessor.
|
||||||
|
func (wpm WordPiece) Decode(ids []int32) (string, error) {
|
||||||
|
var sb strings.Builder
|
||||||
|
for i, id := range ids {
|
||||||
|
if id < 0 || int(id) >= len(wpm.vocab.Values) {
|
||||||
|
return "", fmt.Errorf("invalid token id: %d", id)
|
||||||
|
}
|
||||||
|
|
||||||
|
var separator string
|
||||||
|
piece := wpm.vocab.Values[id]
|
||||||
|
if i > 0 &&
|
||||||
|
(strings.HasPrefix(piece, ggmlPrefix) ||
|
||||||
|
(strings.HasPrefix(piece, "[") && strings.HasSuffix(piece, "]"))) {
|
||||||
|
separator = " "
|
||||||
|
}
|
||||||
|
|
||||||
|
sb.WriteString(wordPieceReplacer.Replace(separator + strings.TrimPrefix(piece, ggmlPrefix)))
|
||||||
|
}
|
||||||
|
|
||||||
|
return sb.String(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// words splits a string into words, treating CJK characters as separate words.
|
||||||
|
// TODO: this is specifically for BERT and may need to be adjusted or refactored for other models.
|
||||||
|
func (wpm WordPiece) words(s string) iter.Seq[string] {
|
||||||
|
return func(yield func(string) bool) {
|
||||||
|
runes := make([]rune, 0, len(s)*3)
|
||||||
|
for _, r := range s {
|
||||||
|
switch {
|
||||||
|
case r >= 0x4E00 && r <= 0x9FFF,
|
||||||
|
r >= 0x3400 && r <= 0x4DBF,
|
||||||
|
r >= 0x20000 && r <= 0x2A6DF,
|
||||||
|
r >= 0x2A700 && r <= 0x2B73F,
|
||||||
|
r >= 0x2B740 && r <= 0x2B81F,
|
||||||
|
r >= 0x2B820 && r <= 0x2CEAF,
|
||||||
|
r >= 0xF900 && r <= 0xFAFF,
|
||||||
|
r >= 0x2F800 && r <= 0x2FA1F:
|
||||||
|
runes = append(runes, ' ', r, ' ')
|
||||||
|
default:
|
||||||
|
runes = append(runes, r)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for w := range strings.FieldsFuncSeq(string(runes), unicode.IsSpace) {
|
||||||
|
// split on but keep punctuation
|
||||||
|
var start int
|
||||||
|
for start < len(w) {
|
||||||
|
end := strings.IndexFunc(w[start:], unicode.IsPunct)
|
||||||
|
if end < 0 {
|
||||||
|
end = len(w) - start
|
||||||
|
} else if end == 0 {
|
||||||
|
end = 1
|
||||||
|
}
|
||||||
|
|
||||||
|
if !yield(w[start : start+end]) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
start += end
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encode implements TextProcessor.
|
||||||
|
func (wpm WordPiece) Encode(s string, addSpecial bool) ([]int32, error) {
|
||||||
|
var ids []int32
|
||||||
|
|
||||||
|
// TODO: use [UNK] from config
|
||||||
|
unk := wpm.vocab.Encode("[UNK]")
|
||||||
|
for word := range wpm.words(s) {
|
||||||
|
var start int
|
||||||
|
var pieces []int32
|
||||||
|
for start < len(word) {
|
||||||
|
end := len(word)
|
||||||
|
|
||||||
|
var piece int32
|
||||||
|
for start < end {
|
||||||
|
subword := word[start:end]
|
||||||
|
if start == 0 {
|
||||||
|
subword = ggmlPrefix + subword
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: some models might not want [ToLower]
|
||||||
|
piece = wpm.vocab.Encode(strings.ToLower(subword))
|
||||||
|
if piece >= 0 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
end--
|
||||||
|
}
|
||||||
|
|
||||||
|
if piece < 0 {
|
||||||
|
// Unknown token
|
||||||
|
pieces = pieces[:0]
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
pieces = append(pieces, piece)
|
||||||
|
start = end
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(pieces) > 0 {
|
||||||
|
ids = append(ids, pieces...)
|
||||||
|
} else {
|
||||||
|
ids = append(ids, unk)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if addSpecial && len(ids) > 0 {
|
||||||
|
ids = wpm.vocab.addSpecials(ids)
|
||||||
|
}
|
||||||
|
|
||||||
|
logutil.Trace("encoded", "string", s, "ids", ids)
|
||||||
|
return ids, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Is implements TextProcessor.
|
||||||
|
func (wpm WordPiece) Is(id int32, special Special) bool {
|
||||||
|
return wpm.vocab.Is(id, special)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Vocabulary implements TextProcessor.
|
||||||
|
func (wpm WordPiece) Vocabulary() *Vocabulary {
|
||||||
|
return wpm.vocab
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ TextProcessor = (*WordPiece)(nil)
|
||||||
|
|
||||||
|
func NewWordPiece(vocab *Vocabulary) WordPiece {
|
||||||
|
return WordPiece{
|
||||||
|
vocab: vocab,
|
||||||
|
}
|
||||||
|
}
|
||||||
51
model/wordpiece_test.go
Normal file
51
model/wordpiece_test.go
Normal file
@@ -0,0 +1,51 @@
|
|||||||
|
package model
|
||||||
|
|
||||||
|
import (
|
||||||
|
"slices"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/google/go-cmp/cmp"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestWordPiece(t *testing.T) {
|
||||||
|
wpm := NewWordPiece(
|
||||||
|
&Vocabulary{
|
||||||
|
Values: []string{"[UNK]", "[CLS]", "[SEP]", "▁hello", "▁world", "s", "▁!", "▁@", "▁#"},
|
||||||
|
AddBOS: true,
|
||||||
|
AddEOS: true,
|
||||||
|
BOS: []int32{1},
|
||||||
|
EOS: []int32{2},
|
||||||
|
})
|
||||||
|
|
||||||
|
ids, err := wpm.Encode("Hello world!", true)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff([]int32{1, 3, 4, 6, 2}, ids); diff != "" {
|
||||||
|
t.Errorf("unexpected ids (-want +got):\n%s", diff)
|
||||||
|
}
|
||||||
|
|
||||||
|
words, err := wpm.Decode(ids)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff("[CLS] hello world! [SEP]", words); diff != "" {
|
||||||
|
t.Errorf("unexpected words (-want +got):\n%s", diff)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWordPieceWords(t *testing.T) {
|
||||||
|
var wpm WordPiece
|
||||||
|
|
||||||
|
basic := slices.Collect(wpm.words("Hey friend! How are you?!?"))
|
||||||
|
if diff := cmp.Diff([]string{"Hey", "friend", "!", "How", "are", "you", "?", "!", "?"}, basic); diff != "" {
|
||||||
|
t.Errorf("unexpected words (-want +got):\n%s", diff)
|
||||||
|
}
|
||||||
|
|
||||||
|
chinese := slices.Collect(wpm.words("野口里佳 Noguchi Rika"))
|
||||||
|
if diff := cmp.Diff([]string{"野", "口", "里", "佳", "Noguchi", "Rika"}, chinese); diff != "" {
|
||||||
|
t.Errorf("unexpected words (-want +got):\n%s", diff)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -76,8 +76,9 @@ type JsonSchema struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type EmbedRequest struct {
|
type EmbedRequest struct {
|
||||||
Input any `json:"input"`
|
Input any `json:"input"`
|
||||||
Model string `json:"model"`
|
Model string `json:"model"`
|
||||||
|
Dimensions int `json:"dimensions,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type StreamOptions struct {
|
type StreamOptions struct {
|
||||||
@@ -557,12 +558,10 @@ func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
|
|||||||
|
|
||||||
var think *api.ThinkValue
|
var think *api.ThinkValue
|
||||||
if r.Reasoning != nil {
|
if r.Reasoning != nil {
|
||||||
options["reasoning"] = *r.Reasoning.Effort
|
|
||||||
think = &api.ThinkValue{
|
think = &api.ThinkValue{
|
||||||
Value: *r.Reasoning.Effort,
|
Value: *r.Reasoning.Effort,
|
||||||
}
|
}
|
||||||
} else if r.ReasoningEffort != nil {
|
} else if r.ReasoningEffort != nil {
|
||||||
options["reasoning"] = *r.ReasoningEffort
|
|
||||||
think = &api.ThinkValue{
|
think = &api.ThinkValue{
|
||||||
Value: *r.ReasoningEffort,
|
Value: *r.ReasoningEffort,
|
||||||
}
|
}
|
||||||
@@ -1007,7 +1006,7 @@ func EmbeddingsMiddleware() gin.HandlerFunc {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var b bytes.Buffer
|
var b bytes.Buffer
|
||||||
if err := json.NewEncoder(&b).Encode(api.EmbedRequest{Model: req.Model, Input: req.Input}); err != nil {
|
if err := json.NewEncoder(&b).Encode(api.EmbedRequest{Model: req.Model, Input: req.Input, Dimensions: req.Dimensions}); err != nil {
|
||||||
c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
|
c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
115
parser/parser.go
115
parser/parser.go
@@ -62,14 +62,15 @@ func (f Modelfile) CreateRequest(relativeDir string) (*api.CreateRequest, error)
|
|||||||
for _, c := range f.Commands {
|
for _, c := range f.Commands {
|
||||||
switch c.Name {
|
switch c.Name {
|
||||||
case "model":
|
case "model":
|
||||||
path, err := expandPath(c.Args, relativeDir)
|
name := c.Args.(string)
|
||||||
|
path, err := expandPath(name, relativeDir)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
digestMap, err := fileDigestMap(path)
|
digestMap, err := fileDigestMap(path)
|
||||||
if errors.Is(err, os.ErrNotExist) {
|
if errors.Is(err, os.ErrNotExist) {
|
||||||
req.From = c.Args
|
req.From = name
|
||||||
continue
|
continue
|
||||||
} else if err != nil {
|
} else if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -83,7 +84,8 @@ func (f Modelfile) CreateRequest(relativeDir string) (*api.CreateRequest, error)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
case "adapter":
|
case "adapter":
|
||||||
path, err := expandPath(c.Args, relativeDir)
|
adapter := c.Args.(string)
|
||||||
|
path, err := expandPath(adapter, relativeDir)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -95,21 +97,25 @@ func (f Modelfile) CreateRequest(relativeDir string) (*api.CreateRequest, error)
|
|||||||
|
|
||||||
req.Adapters = digestMap
|
req.Adapters = digestMap
|
||||||
case "template":
|
case "template":
|
||||||
req.Template = c.Args
|
template := c.Args.(string)
|
||||||
|
req.Template = template
|
||||||
case "system":
|
case "system":
|
||||||
req.System = c.Args
|
system := c.Args.(string)
|
||||||
|
req.System = system
|
||||||
case "license":
|
case "license":
|
||||||
licenses = append(licenses, c.Args)
|
license := c.Args.(string)
|
||||||
|
licenses = append(licenses, license)
|
||||||
case "message":
|
case "message":
|
||||||
role, msg, _ := strings.Cut(c.Args, ": ")
|
msg := c.Args.(*Message)
|
||||||
messages = append(messages, api.Message{Role: role, Content: msg})
|
messages = append(messages, api.Message{Role: msg.Role, Content: msg.Content})
|
||||||
default:
|
case "parameter":
|
||||||
if slices.Contains(deprecatedParameters, c.Name) {
|
if slices.Contains(deprecatedParameters, c.Name) {
|
||||||
fmt.Printf("warning: parameter %s is deprecated\n", c.Name)
|
fmt.Printf("warning: parameter '%s' is deprecated\n", c.Name)
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
ps, err := api.FormatParams(map[string][]string{c.Name: {c.Args}})
|
param := c.Args.(*Parameter)
|
||||||
|
ps, err := api.FormatParams(map[string][]string{param.Name: {param.Value}})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -123,6 +129,8 @@ func (f Modelfile) CreateRequest(relativeDir string) (*api.CreateRequest, error)
|
|||||||
params[k] = v
|
params[k] = v
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("warning: unknown command '%s'", c.Name)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -246,7 +254,7 @@ func filesForModel(path string) ([]string, error) {
|
|||||||
for _, match := range matches {
|
for _, match := range matches {
|
||||||
if ct, err := detectContentType(match); err != nil {
|
if ct, err := detectContentType(match); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
} else if ct != contentType {
|
} else if len(contentType) > 0 && ct != contentType {
|
||||||
return nil, fmt.Errorf("invalid content type: expected %s for %s", ct, match)
|
return nil, fmt.Errorf("invalid content type: expected %s for %s", ct, match)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -255,7 +263,8 @@ func filesForModel(path string) ([]string, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var files []string
|
var files []string
|
||||||
if st, _ := glob(filepath.Join(path, "*.safetensors"), "application/octet-stream"); len(st) > 0 {
|
// some safetensors files do not properly match "application/octet-stream", so skip checking their contentType
|
||||||
|
if st, _ := glob(filepath.Join(path, "*.safetensors"), ""); len(st) > 0 {
|
||||||
// safetensors files might be unresolved git lfs references; skip if they are
|
// safetensors files might be unresolved git lfs references; skip if they are
|
||||||
// covers model-x-of-y.safetensors, model.fp32-x-of-y.safetensors, model.safetensors
|
// covers model-x-of-y.safetensors, model.fp32-x-of-y.safetensors, model.safetensors
|
||||||
files = append(files, st...)
|
files = append(files, st...)
|
||||||
@@ -311,7 +320,17 @@ func filesForModel(path string) ([]string, error) {
|
|||||||
|
|
||||||
type Command struct {
|
type Command struct {
|
||||||
Name string
|
Name string
|
||||||
Args string
|
Args any
|
||||||
|
}
|
||||||
|
|
||||||
|
type Parameter struct {
|
||||||
|
Name string
|
||||||
|
Value string
|
||||||
|
}
|
||||||
|
|
||||||
|
type Message struct {
|
||||||
|
Role string
|
||||||
|
Content string
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c Command) String() string {
|
func (c Command) String() string {
|
||||||
@@ -320,12 +339,16 @@ func (c Command) String() string {
|
|||||||
case "model":
|
case "model":
|
||||||
fmt.Fprintf(&sb, "FROM %s", c.Args)
|
fmt.Fprintf(&sb, "FROM %s", c.Args)
|
||||||
case "license", "template", "system", "adapter":
|
case "license", "template", "system", "adapter":
|
||||||
fmt.Fprintf(&sb, "%s %s", strings.ToUpper(c.Name), quote(c.Args))
|
data := c.Args.(string)
|
||||||
|
fmt.Fprintf(&sb, "%s %s", strings.ToUpper(c.Name), quote(data))
|
||||||
case "message":
|
case "message":
|
||||||
role, message, _ := strings.Cut(c.Args, ": ")
|
data := c.Args.(*Message)
|
||||||
fmt.Fprintf(&sb, "MESSAGE %s %s", role, quote(message))
|
fmt.Fprintf(&sb, "MESSAGE %s %s", data.Role, quote(data.Content))
|
||||||
|
case "parameter":
|
||||||
|
data := c.Args.(*Parameter)
|
||||||
|
fmt.Fprintf(&sb, "PARAMETER %s %s", data.Name, quote(data.Value))
|
||||||
default:
|
default:
|
||||||
fmt.Fprintf(&sb, "PARAMETER %s %s", c.Name, quote(c.Args))
|
fmt.Printf("unknown command '%s'\n", c.Name)
|
||||||
}
|
}
|
||||||
|
|
||||||
return sb.String()
|
return sb.String()
|
||||||
@@ -365,7 +388,6 @@ func ParseFile(r io.Reader) (*Modelfile, error) {
|
|||||||
var curr state
|
var curr state
|
||||||
var currLine int = 1
|
var currLine int = 1
|
||||||
var b bytes.Buffer
|
var b bytes.Buffer
|
||||||
var role string
|
|
||||||
|
|
||||||
var f Modelfile
|
var f Modelfile
|
||||||
|
|
||||||
@@ -412,6 +434,7 @@ func ParseFile(r io.Reader) (*Modelfile, error) {
|
|||||||
case "parameter":
|
case "parameter":
|
||||||
// transition to stateParameter which sets command name
|
// transition to stateParameter which sets command name
|
||||||
next = stateParameter
|
next = stateParameter
|
||||||
|
cmd.Name = s
|
||||||
case "message":
|
case "message":
|
||||||
// transition to stateMessage which validates the message role
|
// transition to stateMessage which validates the message role
|
||||||
next = stateMessage
|
next = stateMessage
|
||||||
@@ -420,16 +443,37 @@ func ParseFile(r io.Reader) (*Modelfile, error) {
|
|||||||
cmd.Name = s
|
cmd.Name = s
|
||||||
}
|
}
|
||||||
case stateParameter:
|
case stateParameter:
|
||||||
cmd.Name = b.String()
|
s, ok := unquote(strings.TrimSpace(b.String()))
|
||||||
|
if !ok || isSpace(r) {
|
||||||
|
if _, err := b.WriteRune(r); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
cmd.Args = &Parameter{
|
||||||
|
Name: s,
|
||||||
|
}
|
||||||
case stateMessage:
|
case stateMessage:
|
||||||
if !isValidMessageRole(b.String()) {
|
s, ok := unquote(strings.TrimSpace(b.String()))
|
||||||
|
if !ok || isSpace(r) {
|
||||||
|
if _, err := b.WriteRune(r); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if !isValidMessageRole(s) {
|
||||||
return nil, &ParserError{
|
return nil, &ParserError{
|
||||||
LineNumber: currLine,
|
LineNumber: currLine,
|
||||||
Msg: errInvalidMessageRole.Error(),
|
Msg: errInvalidMessageRole.Error(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
role = b.String()
|
cmd.Args = &Message{
|
||||||
|
Role: s,
|
||||||
|
}
|
||||||
case stateComment, stateNil:
|
case stateComment, stateNil:
|
||||||
// pass
|
// pass
|
||||||
case stateValue:
|
case stateValue:
|
||||||
@@ -442,12 +486,16 @@ func ParseFile(r io.Reader) (*Modelfile, error) {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if role != "" {
|
switch cmd.Name {
|
||||||
s = role + ": " + s
|
case "parameter":
|
||||||
role = ""
|
p := cmd.Args.(*Parameter)
|
||||||
|
p.Value = s
|
||||||
|
case "message":
|
||||||
|
m := cmd.Args.(*Message)
|
||||||
|
m.Content = s
|
||||||
|
default:
|
||||||
|
cmd.Args = s
|
||||||
}
|
}
|
||||||
|
|
||||||
cmd.Args = s
|
|
||||||
f.Commands = append(f.Commands, cmd)
|
f.Commands = append(f.Commands, cmd)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -472,11 +520,16 @@ func ParseFile(r io.Reader) (*Modelfile, error) {
|
|||||||
return nil, io.ErrUnexpectedEOF
|
return nil, io.ErrUnexpectedEOF
|
||||||
}
|
}
|
||||||
|
|
||||||
if role != "" {
|
switch cmd.Name {
|
||||||
s = role + ": " + s
|
case "parameter":
|
||||||
|
c := cmd.Args.(*Parameter)
|
||||||
|
c.Value = s
|
||||||
|
case "message":
|
||||||
|
c := cmd.Args.(*Message)
|
||||||
|
c.Content = s
|
||||||
|
default:
|
||||||
|
cmd.Args = s
|
||||||
}
|
}
|
||||||
|
|
||||||
cmd.Args = s
|
|
||||||
f.Commands = append(f.Commands, cmd)
|
f.Commands = append(f.Commands, cmd)
|
||||||
default:
|
default:
|
||||||
return nil, io.ErrUnexpectedEOF
|
return nil, io.ErrUnexpectedEOF
|
||||||
|
|||||||
@@ -47,8 +47,8 @@ TEMPLATE """{{ if .System }}<|start_header_id|>system<|end_header_id|>
|
|||||||
{Name: "model", Args: "model1"},
|
{Name: "model", Args: "model1"},
|
||||||
{Name: "adapter", Args: "adapter1"},
|
{Name: "adapter", Args: "adapter1"},
|
||||||
{Name: "license", Args: "MIT"},
|
{Name: "license", Args: "MIT"},
|
||||||
{Name: "param1", Args: "value1"},
|
{Name: "parameter", Args: &Parameter{"param1", "value1"}},
|
||||||
{Name: "param2", Args: "value2"},
|
{Name: "parameter", Args: &Parameter{"param2", "value2"}},
|
||||||
{Name: "template", Args: "{{ if .System }}<|start_header_id|>system<|end_header_id|>\n\n{{ .System }}<|eot_id|>{{ end }}{{ if .Prompt }}<|start_header_id|>user<|end_header_id|>\n\n{{ .Prompt }}<|eot_id|>{{ end }}<|start_header_id|>assistant<|end_header_id|>\n\n{{ .Response }}<|eot_id|>"},
|
{Name: "template", Args: "{{ if .System }}<|start_header_id|>system<|end_header_id|>\n\n{{ .System }}<|eot_id|>{{ end }}{{ if .Prompt }}<|start_header_id|>user<|end_header_id|>\n\n{{ .Prompt }}<|eot_id|>{{ end }}<|start_header_id|>assistant<|end_header_id|>\n\n{{ .Response }}<|eot_id|>"},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -80,8 +80,8 @@ TEMPLATE """ {{ if .System }}<|start_header_id|>system<|end_header_id|>
|
|||||||
{Name: "model", Args: " model 1"},
|
{Name: "model", Args: " model 1"},
|
||||||
{Name: "adapter", Args: "adapter3"},
|
{Name: "adapter", Args: "adapter3"},
|
||||||
{Name: "license", Args: "MIT "},
|
{Name: "license", Args: "MIT "},
|
||||||
{Name: "param1", Args: "value1"},
|
{Name: "parameter", Args: &Parameter{"param1", "value1"}},
|
||||||
{Name: "param2", Args: "value2"},
|
{Name: "parameter", Args: &Parameter{"param2", "value2"}},
|
||||||
{Name: "template", Args: " {{ if .System }}<|start_header_id|>system<|end_header_id|>\n\n{{ .System }}<|eot_id|>{{ end }}{{ if .Prompt }}<|start_header_id|>user<|end_header_id|>\n\n{{ .Prompt }}<|eot_id|>{{ end }}<|start_header_id|>assistant<|end_header_id|>\n\n{{ .Response }}<|eot_id|> "},
|
{Name: "template", Args: " {{ if .System }}<|start_header_id|>system<|end_header_id|>\n\n{{ .System }}<|eot_id|>{{ end }}{{ if .Prompt }}<|start_header_id|>user<|end_header_id|>\n\n{{ .Prompt }}<|eot_id|>{{ end }}<|start_header_id|>assistant<|end_header_id|>\n\n{{ .Response }}<|eot_id|> "},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -101,7 +101,7 @@ func TestParseFileFrom(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"FROM \"FOO BAR\"\nPARAMETER param1 value1",
|
"FROM \"FOO BAR\"\nPARAMETER param1 value1",
|
||||||
[]Command{{Name: "model", Args: "FOO BAR"}, {Name: "param1", Args: "value1"}},
|
[]Command{{Name: "model", Args: "FOO BAR"}, {Name: "parameter", Args: &Parameter{"param1", "value1"}}},
|
||||||
nil,
|
nil,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -149,12 +149,12 @@ func TestParseFileFrom(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"PARAMETER param1 value1\nFROM foo",
|
"PARAMETER param1 value1\nFROM foo",
|
||||||
[]Command{{Name: "param1", Args: "value1"}, {Name: "model", Args: "foo"}},
|
[]Command{{Name: "parameter", Args: &Parameter{"param1", "value1"}}, {Name: "model", Args: "foo"}},
|
||||||
nil,
|
nil,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"PARAMETER what the \nFROM lemons make lemonade ",
|
"PARAMETER what the \nFROM lemons make lemonade ",
|
||||||
[]Command{{Name: "what", Args: "the"}, {Name: "model", Args: "lemons make lemonade"}},
|
[]Command{{Name: "parameter", Args: &Parameter{"what", "the"}}, {Name: "model", Args: "lemons make lemonade"}},
|
||||||
nil,
|
nil,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -211,7 +211,7 @@ MESSAGE system You are a file parser. Always parse things.
|
|||||||
`,
|
`,
|
||||||
[]Command{
|
[]Command{
|
||||||
{Name: "model", Args: "foo"},
|
{Name: "model", Args: "foo"},
|
||||||
{Name: "message", Args: "system: You are a file parser. Always parse things."},
|
{Name: "message", Args: &Message{"system", "You are a file parser. Always parse things."}},
|
||||||
},
|
},
|
||||||
nil,
|
nil,
|
||||||
},
|
},
|
||||||
@@ -221,7 +221,7 @@ FROM foo
|
|||||||
MESSAGE system You are a file parser. Always parse things.`,
|
MESSAGE system You are a file parser. Always parse things.`,
|
||||||
[]Command{
|
[]Command{
|
||||||
{Name: "model", Args: "foo"},
|
{Name: "model", Args: "foo"},
|
||||||
{Name: "message", Args: "system: You are a file parser. Always parse things."},
|
{Name: "message", Args: &Message{"system", "You are a file parser. Always parse things."}},
|
||||||
},
|
},
|
||||||
nil,
|
nil,
|
||||||
},
|
},
|
||||||
@@ -234,9 +234,9 @@ MESSAGE assistant Hello, I want to parse all the things!
|
|||||||
`,
|
`,
|
||||||
[]Command{
|
[]Command{
|
||||||
{Name: "model", Args: "foo"},
|
{Name: "model", Args: "foo"},
|
||||||
{Name: "message", Args: "system: You are a file parser. Always parse things."},
|
{Name: "message", Args: &Message{"system", "You are a file parser. Always parse things."}},
|
||||||
{Name: "message", Args: "user: Hey there!"},
|
{Name: "message", Args: &Message{"user", "Hey there!"}},
|
||||||
{Name: "message", Args: "assistant: Hello, I want to parse all the things!"},
|
{Name: "message", Args: &Message{"assistant", "Hello, I want to parse all the things!"}},
|
||||||
},
|
},
|
||||||
nil,
|
nil,
|
||||||
},
|
},
|
||||||
@@ -244,12 +244,12 @@ MESSAGE assistant Hello, I want to parse all the things!
|
|||||||
`
|
`
|
||||||
FROM foo
|
FROM foo
|
||||||
MESSAGE system """
|
MESSAGE system """
|
||||||
You are a multiline file parser. Always parse things.
|
You are a multiline file "parser". Always parse things.
|
||||||
"""
|
"""
|
||||||
`,
|
`,
|
||||||
[]Command{
|
[]Command{
|
||||||
{Name: "model", Args: "foo"},
|
{Name: "model", Args: "foo"},
|
||||||
{Name: "message", Args: "system: \nYou are a multiline file parser. Always parse things.\n"},
|
{Name: "message", Args: &Message{"system", "\nYou are a multiline file \"parser\". Always parse things.\n"}},
|
||||||
},
|
},
|
||||||
nil,
|
nil,
|
||||||
},
|
},
|
||||||
@@ -514,7 +514,7 @@ func TestParseFileParameters(t *testing.T) {
|
|||||||
|
|
||||||
assert.Equal(t, []Command{
|
assert.Equal(t, []Command{
|
||||||
{Name: "model", Args: "foo"},
|
{Name: "model", Args: "foo"},
|
||||||
{Name: v.name, Args: v.value},
|
{Name: "parameter", Args: &Parameter{v.name, v.value}},
|
||||||
}, modelfile.Commands)
|
}, modelfile.Commands)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -617,8 +617,8 @@ SYSTEM You are a utf16 file.
|
|||||||
|
|
||||||
expected := []Command{
|
expected := []Command{
|
||||||
{Name: "model", Args: "bob"},
|
{Name: "model", Args: "bob"},
|
||||||
{Name: "param1", Args: "1"},
|
{Name: "parameter", Args: &Parameter{"param1", "1"}},
|
||||||
{Name: "param2", Args: "4096"},
|
{Name: "parameter", Args: &Parameter{"param2", "4096"}},
|
||||||
{Name: "system", Args: "You are a utf16 file."},
|
{Name: "system", Args: "You are a utf16 file."},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -46,7 +46,7 @@ func NewInputCache(lc *llama.Context, kvSize int, numSlots int, multiUserCache b
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Locking: Operations on InputCacheSlot (including finding one
|
// Locking: Operations on InputCacheSlot (including finding one
|
||||||
// through LoadCacheSlot) require a lock to be be held that serializes
|
// through LoadCacheSlot) require a lock to be held that serializes
|
||||||
// these operations with each other and llama.Decode
|
// these operations with each other and llama.Decode
|
||||||
|
|
||||||
type InputCacheSlot struct {
|
type InputCacheSlot struct {
|
||||||
|
|||||||
@@ -34,8 +34,8 @@ type InputCache struct {
|
|||||||
func NewInputCache(model model.Model, kvCacheType string, kvSize int32, numSlots int, batchSize int, multiUserCache bool) (*InputCache, error) {
|
func NewInputCache(model model.Model, kvCacheType string, kvSize int32, numSlots int, batchSize int, multiUserCache bool) (*InputCache, error) {
|
||||||
numCtx := kvSize / int32(numSlots)
|
numCtx := kvSize / int32(numSlots)
|
||||||
|
|
||||||
if numCtx < 1 {
|
if int(numCtx) < batchSize {
|
||||||
return nil, fmt.Errorf("must have at least one kv cache entry per parallel sequence (kv: %v parallel: %v)", kvSize, numSlots)
|
return nil, fmt.Errorf("kv size must be at least as large as batch size * parallel (kv: %v batch: %v parallel: %v)", kvSize, batchSize, numSlots)
|
||||||
}
|
}
|
||||||
|
|
||||||
slots := make([]InputCacheSlot, numSlots)
|
slots := make([]InputCacheSlot, numSlots)
|
||||||
@@ -70,15 +70,13 @@ func kvCacheTypeFromStr(s string) ml.DType {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *InputCache) Close() {
|
func (c *InputCache) Close() {
|
||||||
if c == nil {
|
if c != nil && c.cache != nil {
|
||||||
return
|
c.cache.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
c.cache.Close()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Locking: Operations on InputCacheSlot (including finding one
|
// Locking: Operations on InputCacheSlot (including finding one
|
||||||
// through LoadCacheSlot) require a lock to be be held that serializes
|
// through LoadCacheSlot) require a lock to be held that serializes
|
||||||
// these operations with each other and processBatch
|
// these operations with each other and processBatch
|
||||||
|
|
||||||
type InputCacheSlot struct {
|
type InputCacheSlot struct {
|
||||||
@@ -86,7 +84,7 @@ type InputCacheSlot struct {
|
|||||||
Id int
|
Id int
|
||||||
|
|
||||||
// Inputs that are stored in the KV cache
|
// Inputs that are stored in the KV cache
|
||||||
Inputs []input.Input
|
Inputs []*input.Input
|
||||||
|
|
||||||
// is this cache actively being processed as part of a sequence?
|
// is this cache actively being processed as part of a sequence?
|
||||||
InUse bool
|
InUse bool
|
||||||
@@ -95,7 +93,7 @@ type InputCacheSlot struct {
|
|||||||
lastUsed time.Time
|
lastUsed time.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *InputCache) LoadCacheSlot(prompt []input.Input) (*InputCacheSlot, []input.Input, error) {
|
func (c *InputCache) LoadCacheSlot(prompt []*input.Input, cachePrompt bool) (*InputCacheSlot, []*input.Input, error) {
|
||||||
var slot *InputCacheSlot
|
var slot *InputCacheSlot
|
||||||
var numPast int32
|
var numPast int32
|
||||||
var err error
|
var err error
|
||||||
@@ -113,6 +111,10 @@ func (c *InputCache) LoadCacheSlot(prompt []input.Input) (*InputCacheSlot, []inp
|
|||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if !cachePrompt {
|
||||||
|
numPast = 0
|
||||||
|
}
|
||||||
|
|
||||||
slot.InUse = true
|
slot.InUse = true
|
||||||
slot.lastUsed = time.Now()
|
slot.lastUsed = time.Now()
|
||||||
|
|
||||||
@@ -146,7 +148,7 @@ func (c *InputCache) LoadCacheSlot(prompt []input.Input) (*InputCacheSlot, []inp
|
|||||||
return slot, prompt, nil
|
return slot, prompt, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *InputCache) findLongestCacheSlot(prompt []input.Input) (*InputCacheSlot, int32, error) {
|
func (c *InputCache) findLongestCacheSlot(prompt []*input.Input) (*InputCacheSlot, int32, error) {
|
||||||
longest := int32(-1)
|
longest := int32(-1)
|
||||||
var longestSlot *InputCacheSlot
|
var longestSlot *InputCacheSlot
|
||||||
|
|
||||||
@@ -169,7 +171,7 @@ func (c *InputCache) findLongestCacheSlot(prompt []input.Input) (*InputCacheSlot
|
|||||||
return longestSlot, longest, nil
|
return longestSlot, longest, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *InputCache) findBestCacheSlot(prompt []input.Input) (*InputCacheSlot, int32, error) {
|
func (c *InputCache) findBestCacheSlot(prompt []*input.Input) (*InputCacheSlot, int32, error) {
|
||||||
oldest := time.Now()
|
oldest := time.Now()
|
||||||
var oldestSlot *InputCacheSlot
|
var oldestSlot *InputCacheSlot
|
||||||
|
|
||||||
@@ -205,7 +207,7 @@ func (c *InputCache) findBestCacheSlot(prompt []input.Input) (*InputCacheSlot, i
|
|||||||
if longest > 0 && longestSlot != oldestSlot {
|
if longest > 0 && longestSlot != oldestSlot {
|
||||||
slog.Debug("forking cache slot", "src", longestSlot.Id, "dst", oldestSlot.Id, "inputs", longest, "total",
|
slog.Debug("forking cache slot", "src", longestSlot.Id, "dst", oldestSlot.Id, "inputs", longest, "total",
|
||||||
len(longestSlot.Inputs))
|
len(longestSlot.Inputs))
|
||||||
oldestSlot.Inputs = make([]input.Input, longest)
|
oldestSlot.Inputs = make([]*input.Input, longest)
|
||||||
copy(oldestSlot.Inputs, longestSlot.Inputs[:longest])
|
copy(oldestSlot.Inputs, longestSlot.Inputs[:longest])
|
||||||
if c.cache != nil {
|
if c.cache != nil {
|
||||||
c.cache.CopyPrefix(longestSlot.Id, oldestSlot.Id, longest)
|
c.cache.CopyPrefix(longestSlot.Id, oldestSlot.Id, longest)
|
||||||
@@ -215,7 +217,7 @@ func (c *InputCache) findBestCacheSlot(prompt []input.Input) (*InputCacheSlot, i
|
|||||||
return oldestSlot, longest, nil
|
return oldestSlot, longest, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func countCommonPrefix(a []input.Input, b []input.Input) int32 {
|
func countCommonPrefix(a []*input.Input, b []*input.Input) int32 {
|
||||||
var count int32
|
var count int32
|
||||||
|
|
||||||
for i := range a {
|
for i := range a {
|
||||||
@@ -250,7 +252,7 @@ func (c *InputCache) ShiftDiscard(inputLen int32, numKeep int32) int32 {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type ErrReprocessInputs struct {
|
type ErrReprocessInputs struct {
|
||||||
Inputs []input.Input
|
Inputs []*input.Input
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *ErrReprocessInputs) Error() string {
|
func (e *ErrReprocessInputs) Error() string {
|
||||||
@@ -283,13 +285,13 @@ func (c *InputCache) ShiftCacheSlot(slot *InputCacheSlot, numKeep int32) error {
|
|||||||
"id", slot.Id, "error", err)
|
"id", slot.Id, "error", err)
|
||||||
|
|
||||||
// Create new input slice with preserved tokens (numKeep + remaining tokens after discard)
|
// Create new input slice with preserved tokens (numKeep + remaining tokens after discard)
|
||||||
newInputs := make([]input.Input, numKeep+inputLen-(numKeep+discard))
|
newInputs := make([]*input.Input, numKeep+inputLen-(numKeep+discard))
|
||||||
copy(newInputs[:numKeep], slot.Inputs[:numKeep])
|
copy(newInputs[:numKeep], slot.Inputs[:numKeep])
|
||||||
copy(newInputs[numKeep:], slot.Inputs[numKeep+discard:])
|
copy(newInputs[numKeep:], slot.Inputs[numKeep+discard:])
|
||||||
|
|
||||||
// Reset the cache
|
// Reset the cache
|
||||||
_ = c.cache.Remove(slot.Id, 0, math.MaxInt32)
|
_ = c.cache.Remove(slot.Id, 0, math.MaxInt32)
|
||||||
slot.Inputs = []input.Input{}
|
slot.Inputs = []*input.Input{}
|
||||||
|
|
||||||
// Return error with inputs that need to be reprocessed
|
// Return error with inputs that need to be reprocessed
|
||||||
return &ErrReprocessInputs{Inputs: newInputs}
|
return &ErrReprocessInputs{Inputs: newInputs}
|
||||||
|
|||||||
@@ -13,50 +13,50 @@ import (
|
|||||||
func TestCountCommon(t *testing.T) {
|
func TestCountCommon(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
t1 []input.Input
|
t1 []*input.Input
|
||||||
t2 []input.Input
|
t2 []*input.Input
|
||||||
expected int32
|
expected int32
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "Equal",
|
name: "Equal",
|
||||||
t1: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
|
t1: []*input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
|
||||||
t2: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
|
t2: []*input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
|
||||||
expected: 3,
|
expected: 3,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Prefix",
|
name: "Prefix",
|
||||||
t1: []input.Input{{Token: 1}},
|
t1: []*input.Input{{Token: 1}},
|
||||||
t2: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
|
t2: []*input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
|
||||||
expected: 1,
|
expected: 1,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Image Prefix",
|
name: "Image Prefix",
|
||||||
t1: []input.Input{{MultimodalHash: 1}},
|
t1: []*input.Input{{MultimodalHash: 1}},
|
||||||
t2: []input.Input{{MultimodalHash: 1}, {MultimodalHash: 2}, {MultimodalHash: 3}},
|
t2: []*input.Input{{MultimodalHash: 1}, {MultimodalHash: 2}, {MultimodalHash: 3}},
|
||||||
expected: 1,
|
expected: 1,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Mixed",
|
name: "Mixed",
|
||||||
t1: []input.Input{{Token: 1}, {MultimodalHash: 1}},
|
t1: []*input.Input{{Token: 1}, {MultimodalHash: 1}},
|
||||||
t2: []input.Input{{Token: 1}, {MultimodalHash: 1}, {Token: 5}},
|
t2: []*input.Input{{Token: 1}, {MultimodalHash: 1}, {Token: 5}},
|
||||||
expected: 2,
|
expected: 2,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Mixed, Same Length",
|
name: "Mixed, Same Length",
|
||||||
t1: []input.Input{{Token: 1}, {MultimodalHash: 1}},
|
t1: []*input.Input{{Token: 1}, {MultimodalHash: 1}},
|
||||||
t2: []input.Input{{Token: 1}, {MultimodalHash: 2}},
|
t2: []*input.Input{{Token: 1}, {MultimodalHash: 2}},
|
||||||
expected: 1,
|
expected: 1,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Empty",
|
name: "Empty",
|
||||||
t1: []input.Input{},
|
t1: []*input.Input{},
|
||||||
t2: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
|
t2: []*input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
|
||||||
expected: 0,
|
expected: 0,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Both Empty",
|
name: "Both Empty",
|
||||||
t1: []input.Input{},
|
t1: []*input.Input{},
|
||||||
t2: []input.Input{},
|
t2: []*input.Input{},
|
||||||
expected: 0,
|
expected: 0,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -80,7 +80,7 @@ func TestFindCacheSlot(t *testing.T) {
|
|||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
cache InputCache
|
cache InputCache
|
||||||
prompt []input.Input
|
prompt []*input.Input
|
||||||
longest expected
|
longest expected
|
||||||
best expected
|
best expected
|
||||||
}{
|
}{
|
||||||
@@ -89,18 +89,18 @@ func TestFindCacheSlot(t *testing.T) {
|
|||||||
cache: InputCache{slots: []InputCacheSlot{
|
cache: InputCache{slots: []InputCacheSlot{
|
||||||
{
|
{
|
||||||
Id: 0,
|
Id: 0,
|
||||||
Inputs: []input.Input{},
|
Inputs: []*input.Input{},
|
||||||
InUse: false,
|
InUse: false,
|
||||||
lastUsed: time.Time{},
|
lastUsed: time.Time{},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Id: 1,
|
Id: 1,
|
||||||
Inputs: []input.Input{},
|
Inputs: []*input.Input{},
|
||||||
InUse: false,
|
InUse: false,
|
||||||
lastUsed: time.Time{},
|
lastUsed: time.Time{},
|
||||||
},
|
},
|
||||||
}},
|
}},
|
||||||
prompt: []input.Input{{Token: 1}},
|
prompt: []*input.Input{{Token: 1}},
|
||||||
longest: expected{result: 0, len: 0},
|
longest: expected{result: 0, len: 0},
|
||||||
best: expected{result: 0, len: 0},
|
best: expected{result: 0, len: 0},
|
||||||
},
|
},
|
||||||
@@ -109,18 +109,18 @@ func TestFindCacheSlot(t *testing.T) {
|
|||||||
cache: InputCache{slots: []InputCacheSlot{
|
cache: InputCache{slots: []InputCacheSlot{
|
||||||
{
|
{
|
||||||
Id: 0,
|
Id: 0,
|
||||||
Inputs: []input.Input{{Token: 1}},
|
Inputs: []*input.Input{{Token: 1}},
|
||||||
InUse: false,
|
InUse: false,
|
||||||
lastUsed: time.Now().Add(-time.Second),
|
lastUsed: time.Now().Add(-time.Second),
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Id: 1,
|
Id: 1,
|
||||||
Inputs: []input.Input{{Token: 1}, {Token: 2}},
|
Inputs: []*input.Input{{Token: 1}, {Token: 2}},
|
||||||
InUse: false,
|
InUse: false,
|
||||||
lastUsed: time.Now().Add(-2 * time.Second),
|
lastUsed: time.Now().Add(-2 * time.Second),
|
||||||
},
|
},
|
||||||
}},
|
}},
|
||||||
prompt: []input.Input{{Token: 1}, {Token: 2}},
|
prompt: []*input.Input{{Token: 1}, {Token: 2}},
|
||||||
longest: expected{result: 1, len: 2},
|
longest: expected{result: 1, len: 2},
|
||||||
best: expected{result: 1, len: 2},
|
best: expected{result: 1, len: 2},
|
||||||
},
|
},
|
||||||
@@ -129,18 +129,18 @@ func TestFindCacheSlot(t *testing.T) {
|
|||||||
cache: InputCache{slots: []InputCacheSlot{
|
cache: InputCache{slots: []InputCacheSlot{
|
||||||
{
|
{
|
||||||
Id: 0,
|
Id: 0,
|
||||||
Inputs: []input.Input{{Token: 1}, {Token: 2}},
|
Inputs: []*input.Input{{Token: 1}, {Token: 2}},
|
||||||
InUse: false,
|
InUse: false,
|
||||||
lastUsed: time.Now().Add(-time.Second),
|
lastUsed: time.Now().Add(-time.Second),
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Id: 1,
|
Id: 1,
|
||||||
Inputs: []input.Input{},
|
Inputs: []*input.Input{},
|
||||||
InUse: false,
|
InUse: false,
|
||||||
lastUsed: time.Time{},
|
lastUsed: time.Time{},
|
||||||
},
|
},
|
||||||
}},
|
}},
|
||||||
prompt: []input.Input{{Token: 2}},
|
prompt: []*input.Input{{Token: 2}},
|
||||||
longest: expected{result: 0, len: 0},
|
longest: expected{result: 0, len: 0},
|
||||||
best: expected{result: 1, len: 0},
|
best: expected{result: 1, len: 0},
|
||||||
},
|
},
|
||||||
@@ -150,19 +150,19 @@ func TestFindCacheSlot(t *testing.T) {
|
|||||||
slots: []InputCacheSlot{
|
slots: []InputCacheSlot{
|
||||||
{
|
{
|
||||||
Id: 0,
|
Id: 0,
|
||||||
Inputs: []input.Input{{Token: 1}, {Token: 2}},
|
Inputs: []*input.Input{{Token: 1}, {Token: 2}},
|
||||||
InUse: false,
|
InUse: false,
|
||||||
lastUsed: time.Now().Add(-time.Second),
|
lastUsed: time.Now().Add(-time.Second),
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Id: 1,
|
Id: 1,
|
||||||
Inputs: []input.Input{},
|
Inputs: []*input.Input{},
|
||||||
InUse: false,
|
InUse: false,
|
||||||
lastUsed: time.Time{},
|
lastUsed: time.Time{},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
prompt: []input.Input{{Token: 1}},
|
prompt: []*input.Input{{Token: 1}},
|
||||||
longest: expected{result: 0, len: 1},
|
longest: expected{result: 0, len: 1},
|
||||||
best: expected{result: 1, len: 1},
|
best: expected{result: 1, len: 1},
|
||||||
},
|
},
|
||||||
@@ -171,18 +171,18 @@ func TestFindCacheSlot(t *testing.T) {
|
|||||||
cache: InputCache{slots: []InputCacheSlot{
|
cache: InputCache{slots: []InputCacheSlot{
|
||||||
{
|
{
|
||||||
Id: 0,
|
Id: 0,
|
||||||
Inputs: []input.Input{{Token: 1}},
|
Inputs: []*input.Input{{Token: 1}},
|
||||||
InUse: false,
|
InUse: false,
|
||||||
lastUsed: time.Now().Add(-time.Second),
|
lastUsed: time.Now().Add(-time.Second),
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Id: 1,
|
Id: 1,
|
||||||
Inputs: []input.Input{{Token: 1}, {Token: 2}},
|
Inputs: []*input.Input{{Token: 1}, {Token: 2}},
|
||||||
InUse: false,
|
InUse: false,
|
||||||
lastUsed: time.Now().Add(-2 * time.Second),
|
lastUsed: time.Now().Add(-2 * time.Second),
|
||||||
},
|
},
|
||||||
}},
|
}},
|
||||||
prompt: []input.Input{{Token: 2}, {Token: 3}},
|
prompt: []*input.Input{{Token: 2}, {Token: 3}},
|
||||||
longest: expected{result: 0, len: 0},
|
longest: expected{result: 0, len: 0},
|
||||||
best: expected{result: 1, len: 0},
|
best: expected{result: 1, len: 0},
|
||||||
},
|
},
|
||||||
@@ -191,18 +191,18 @@ func TestFindCacheSlot(t *testing.T) {
|
|||||||
cache: InputCache{slots: []InputCacheSlot{
|
cache: InputCache{slots: []InputCacheSlot{
|
||||||
{
|
{
|
||||||
Id: 0,
|
Id: 0,
|
||||||
Inputs: []input.Input{{Token: 1}, {Token: 2}},
|
Inputs: []*input.Input{{Token: 1}, {Token: 2}},
|
||||||
InUse: true,
|
InUse: true,
|
||||||
lastUsed: time.Now().Add(-time.Second),
|
lastUsed: time.Now().Add(-time.Second),
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Id: 1,
|
Id: 1,
|
||||||
Inputs: []input.Input{{Token: 1}},
|
Inputs: []*input.Input{{Token: 1}},
|
||||||
InUse: false,
|
InUse: false,
|
||||||
lastUsed: time.Now().Add(-2 * time.Second),
|
lastUsed: time.Now().Add(-2 * time.Second),
|
||||||
},
|
},
|
||||||
}},
|
}},
|
||||||
prompt: []input.Input{{Token: 1}, {Token: 2}},
|
prompt: []*input.Input{{Token: 1}, {Token: 2}},
|
||||||
longest: expected{result: 1, len: 1},
|
longest: expected{result: 1, len: 1},
|
||||||
best: expected{result: 1, len: 2},
|
best: expected{result: 1, len: 2},
|
||||||
},
|
},
|
||||||
@@ -300,7 +300,7 @@ func TestLoadCacheSlot(t *testing.T) {
|
|||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
cache InputCache
|
cache InputCache
|
||||||
prompt []input.Input
|
prompt []*input.Input
|
||||||
wantErr bool
|
wantErr bool
|
||||||
expectedSlotId int
|
expectedSlotId int
|
||||||
expectedPrompt int // expected length of remaining prompt
|
expectedPrompt int // expected length of remaining prompt
|
||||||
@@ -312,19 +312,19 @@ func TestLoadCacheSlot(t *testing.T) {
|
|||||||
slots: []InputCacheSlot{
|
slots: []InputCacheSlot{
|
||||||
{
|
{
|
||||||
Id: 0,
|
Id: 0,
|
||||||
Inputs: []input.Input{{Token: 1}, {Token: 2}},
|
Inputs: []*input.Input{{Token: 1}, {Token: 2}},
|
||||||
InUse: false,
|
InUse: false,
|
||||||
lastUsed: time.Now().Add(-time.Second),
|
lastUsed: time.Now().Add(-time.Second),
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Id: 1,
|
Id: 1,
|
||||||
Inputs: []input.Input{},
|
Inputs: []*input.Input{},
|
||||||
InUse: false,
|
InUse: false,
|
||||||
lastUsed: time.Now().Add(-2 * time.Second),
|
lastUsed: time.Now().Add(-2 * time.Second),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
prompt: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
|
prompt: []*input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
|
||||||
wantErr: false,
|
wantErr: false,
|
||||||
expectedSlotId: 0,
|
expectedSlotId: 0,
|
||||||
expectedPrompt: 1, // Only token 3 remains
|
expectedPrompt: 1, // Only token 3 remains
|
||||||
@@ -336,19 +336,19 @@ func TestLoadCacheSlot(t *testing.T) {
|
|||||||
slots: []InputCacheSlot{
|
slots: []InputCacheSlot{
|
||||||
{
|
{
|
||||||
Id: 0,
|
Id: 0,
|
||||||
Inputs: []input.Input{{Token: 1}, {Token: 2}},
|
Inputs: []*input.Input{{Token: 1}, {Token: 2}},
|
||||||
InUse: false,
|
InUse: false,
|
||||||
lastUsed: time.Now().Add(-time.Second),
|
lastUsed: time.Now().Add(-time.Second),
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Id: 1,
|
Id: 1,
|
||||||
Inputs: []input.Input{},
|
Inputs: []*input.Input{},
|
||||||
InUse: false,
|
InUse: false,
|
||||||
lastUsed: time.Now().Add(-2 * time.Second),
|
lastUsed: time.Now().Add(-2 * time.Second),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
prompt: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
|
prompt: []*input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
|
||||||
wantErr: false,
|
wantErr: false,
|
||||||
expectedSlotId: 0,
|
expectedSlotId: 0,
|
||||||
expectedPrompt: 1, // Only token 3 remains
|
expectedPrompt: 1, // Only token 3 remains
|
||||||
@@ -360,13 +360,13 @@ func TestLoadCacheSlot(t *testing.T) {
|
|||||||
slots: []InputCacheSlot{
|
slots: []InputCacheSlot{
|
||||||
{
|
{
|
||||||
Id: 0,
|
Id: 0,
|
||||||
Inputs: []input.Input{{Token: 1}, {Token: 2}},
|
Inputs: []*input.Input{{Token: 1}, {Token: 2}},
|
||||||
InUse: false,
|
InUse: false,
|
||||||
lastUsed: time.Now().Add(-time.Second),
|
lastUsed: time.Now().Add(-time.Second),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
prompt: []input.Input{{Token: 1}, {Token: 2}},
|
prompt: []*input.Input{{Token: 1}, {Token: 2}},
|
||||||
wantErr: false,
|
wantErr: false,
|
||||||
expectedSlotId: 0,
|
expectedSlotId: 0,
|
||||||
expectedPrompt: 1, // Should leave 1 token for sampling
|
expectedPrompt: 1, // Should leave 1 token for sampling
|
||||||
@@ -378,13 +378,13 @@ func TestLoadCacheSlot(t *testing.T) {
|
|||||||
slots: []InputCacheSlot{
|
slots: []InputCacheSlot{
|
||||||
{
|
{
|
||||||
Id: 0,
|
Id: 0,
|
||||||
Inputs: []input.Input{{Token: 1}, {Token: 2}},
|
Inputs: []*input.Input{{Token: 1}, {Token: 2}},
|
||||||
InUse: true,
|
InUse: true,
|
||||||
lastUsed: time.Now().Add(-time.Second),
|
lastUsed: time.Now().Add(-time.Second),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
prompt: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
|
prompt: []*input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
|
||||||
wantErr: true,
|
wantErr: true,
|
||||||
expectedSlotId: -1,
|
expectedSlotId: -1,
|
||||||
expectedPrompt: -1,
|
expectedPrompt: -1,
|
||||||
@@ -393,7 +393,7 @@ func TestLoadCacheSlot(t *testing.T) {
|
|||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
slot, remainingPrompt, err := tt.cache.LoadCacheSlot(tt.prompt)
|
slot, remainingPrompt, err := tt.cache.LoadCacheSlot(tt.prompt, true)
|
||||||
|
|
||||||
// Check error state
|
// Check error state
|
||||||
if (err != nil) != tt.wantErr {
|
if (err != nil) != tt.wantErr {
|
||||||
@@ -452,7 +452,7 @@ func TestShiftCacheSlot(t *testing.T) {
|
|||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
numCtx int32
|
numCtx int32
|
||||||
inputs []input.Input
|
inputs []*input.Input
|
||||||
numKeep int32
|
numKeep int32
|
||||||
cacheErr bool
|
cacheErr bool
|
||||||
wantErr any
|
wantErr any
|
||||||
@@ -461,7 +461,7 @@ func TestShiftCacheSlot(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "Normal shift",
|
name: "Normal shift",
|
||||||
numCtx: 10,
|
numCtx: 10,
|
||||||
inputs: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}, {Token: 4}, {Token: 5}, {Token: 6}, {Token: 7}, {Token: 8}, {Token: 9}, {Token: 10}},
|
inputs: []*input.Input{{Token: 1}, {Token: 2}, {Token: 3}, {Token: 4}, {Token: 5}, {Token: 6}, {Token: 7}, {Token: 8}, {Token: 9}, {Token: 10}},
|
||||||
numKeep: 2,
|
numKeep: 2,
|
||||||
cacheErr: false, // No error
|
cacheErr: false, // No error
|
||||||
wantErr: nil,
|
wantErr: nil,
|
||||||
@@ -470,7 +470,7 @@ func TestShiftCacheSlot(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "Cache removal fails",
|
name: "Cache removal fails",
|
||||||
numCtx: 10,
|
numCtx: 10,
|
||||||
inputs: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}, {Token: 4}, {Token: 5}, {Token: 6}, {Token: 7}, {Token: 8}, {Token: 9}, {Token: 10}},
|
inputs: []*input.Input{{Token: 1}, {Token: 2}, {Token: 3}, {Token: 4}, {Token: 5}, {Token: 6}, {Token: 7}, {Token: 8}, {Token: 9}, {Token: 10}},
|
||||||
numKeep: 2,
|
numKeep: 2,
|
||||||
cacheErr: true,
|
cacheErr: true,
|
||||||
wantErr: &ErrReprocessInputs{},
|
wantErr: &ErrReprocessInputs{},
|
||||||
@@ -487,7 +487,7 @@ func TestShiftCacheSlot(t *testing.T) {
|
|||||||
}
|
}
|
||||||
slot := &InputCacheSlot{
|
slot := &InputCacheSlot{
|
||||||
Id: 123,
|
Id: 123,
|
||||||
Inputs: make([]input.Input, len(tt.inputs)),
|
Inputs: make([]*input.Input, len(tt.inputs)),
|
||||||
}
|
}
|
||||||
copy(slot.Inputs, tt.inputs)
|
copy(slot.Inputs, tt.inputs)
|
||||||
|
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ import (
|
|||||||
"image"
|
"image"
|
||||||
"log"
|
"log"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
|
"math"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
@@ -51,10 +52,10 @@ type Sequence struct {
|
|||||||
iBatch int
|
iBatch int
|
||||||
|
|
||||||
// prompt inputs left to evaluate
|
// prompt inputs left to evaluate
|
||||||
inputs []input.Input
|
inputs []*input.Input
|
||||||
|
|
||||||
// inputs that have been added to a batch but not yet submitted to Forward
|
// inputs that have been added to a batch but not yet submitted to Forward
|
||||||
pendingInputs []input.Input
|
pendingInputs []*input.Input
|
||||||
|
|
||||||
// tokens that have been generated but not returned yet (e.g. for stop sequences)
|
// tokens that have been generated but not returned yet (e.g. for stop sequences)
|
||||||
pendingResponses []string
|
pendingResponses []string
|
||||||
@@ -182,8 +183,8 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
|
|||||||
// inputs processes the prompt and images into a list of inputs
|
// inputs processes the prompt and images into a list of inputs
|
||||||
// by splitting the prompt on [img-<n>] tags, tokenizing text and
|
// by splitting the prompt on [img-<n>] tags, tokenizing text and
|
||||||
// decoding images
|
// decoding images
|
||||||
func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input.Input, []ml.Context, multimodalStore, error) {
|
func (s *Server) inputs(prompt string, images []llm.ImageData) ([]*input.Input, []ml.Context, multimodalStore, error) {
|
||||||
var inputs []input.Input
|
var inputs []*input.Input
|
||||||
var ctxs []ml.Context
|
var ctxs []ml.Context
|
||||||
var mmStore multimodalStore
|
var mmStore multimodalStore
|
||||||
|
|
||||||
@@ -210,7 +211,7 @@ func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input.Input, [
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, t := range tokens {
|
for _, t := range tokens {
|
||||||
inputs = append(inputs, input.Input{Token: t})
|
inputs = append(inputs, &input.Input{Token: t})
|
||||||
}
|
}
|
||||||
|
|
||||||
// image - decode and store
|
// image - decode and store
|
||||||
@@ -243,7 +244,7 @@ func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input.Input, [
|
|||||||
|
|
||||||
mmStore.addMultimodal(imageEmbeddings)
|
mmStore.addMultimodal(imageEmbeddings)
|
||||||
|
|
||||||
inputs = append(inputs, input.Input{Multimodal: imageEmbeddings, MultimodalHash: imageHash})
|
inputs = append(inputs, &input.Input{Multimodal: imageEmbeddings, MultimodalHash: imageHash})
|
||||||
postTokenize = true
|
postTokenize = true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -259,6 +260,37 @@ func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input.Input, [
|
|||||||
return inputs, ctxs, mmStore, nil
|
return inputs, ctxs, mmStore, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type batchState struct {
|
||||||
|
// id provides a counter for trace logging batches
|
||||||
|
id int
|
||||||
|
|
||||||
|
// ctx holds the backend context used for this batch
|
||||||
|
ctx ml.Context
|
||||||
|
|
||||||
|
// modelOutput holds the outputs from this batch
|
||||||
|
modelOutput ml.Tensor
|
||||||
|
|
||||||
|
// batchInputs holds the input token pointers which may start as
|
||||||
|
// placeholders later filled in before calling ctx.Compute
|
||||||
|
batchInputs []*input.Input
|
||||||
|
|
||||||
|
// batch contains the inputs for a model forward pass
|
||||||
|
batch input.Batch
|
||||||
|
|
||||||
|
// full set of seqs at the time this batch was initiated
|
||||||
|
seqs []*Sequence
|
||||||
|
|
||||||
|
// Signaled when this batches inputs are ready and compute can proceed
|
||||||
|
inputsReadyCh chan struct{}
|
||||||
|
|
||||||
|
// Signaling when Compute is about to begin on this batch, and
|
||||||
|
// seqs have been updated to prepare for the next batch
|
||||||
|
computeStartedCh chan struct{}
|
||||||
|
|
||||||
|
// Signaled when this batches outputs are complete and the next batch can proceed
|
||||||
|
outputsReadyCh chan struct{}
|
||||||
|
}
|
||||||
|
|
||||||
type Server struct {
|
type Server struct {
|
||||||
// modelPath is the location of the model to be loaded
|
// modelPath is the location of the model to be loaded
|
||||||
modelPath string
|
modelPath string
|
||||||
@@ -290,6 +322,12 @@ type Server struct {
|
|||||||
// TODO (jmorganca): make this n_batch
|
// TODO (jmorganca): make this n_batch
|
||||||
batchSize int
|
batchSize int
|
||||||
|
|
||||||
|
// Used to signal a hard failure during async processing which will panic the runner
|
||||||
|
hardErrCh chan error
|
||||||
|
|
||||||
|
// Simple counter used only for trace logging batches
|
||||||
|
batchID int
|
||||||
|
|
||||||
// protects access to everything below this line
|
// protects access to everything below this line
|
||||||
// this is context state needed for decoding
|
// this is context state needed for decoding
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
@@ -362,33 +400,74 @@ func (s *Server) removeSequence(seqIndex int, reason llm.DoneReason) {
|
|||||||
s.seqsSem.Release(1)
|
s.seqsSem.Release(1)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// track batch state between forwardBatch, computeBatch and predictForwardBatch
|
||||||
|
|
||||||
func (s *Server) run(ctx context.Context) {
|
func (s *Server) run(ctx context.Context) {
|
||||||
s.ready.Wait()
|
s.ready.Wait()
|
||||||
|
|
||||||
|
supportsAsync := s.model.Backend().Config().Uint("pooling_type", math.MaxUint32) == math.MaxUint32
|
||||||
|
|
||||||
|
var activeBatch batchState
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return
|
return
|
||||||
|
case err := <-s.hardErrCh:
|
||||||
|
panic(err)
|
||||||
default:
|
default:
|
||||||
err := s.processBatch()
|
var err error
|
||||||
|
activeBatch, err = s.forwardBatch(activeBatch)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if supportsAsync {
|
||||||
|
go s.computeBatch(activeBatch)
|
||||||
|
} else {
|
||||||
|
s.computeBatch(activeBatch)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) processBatch() error {
|
// forwardBatch will calculate a batch.
|
||||||
|
func (s *Server) forwardBatch(pendingBatch batchState) (nextBatch batchState, err error) {
|
||||||
|
// If we have a pending batch still processing, wait until Compute has started
|
||||||
|
// before setting up the next batch so the seqs inputs are ready to receive their
|
||||||
|
// token values and we get the correct input pointers for the batchInputs
|
||||||
|
if pendingBatch.ctx != nil {
|
||||||
|
logutil.Trace("forwardBatch waiting for compute to start", "pendingBatch.id", pendingBatch.id)
|
||||||
|
<-pendingBatch.computeStartedCh
|
||||||
|
logutil.Trace("forwardBatch compute started, setting up next batch", "pendingBatch.id", pendingBatch.id, "id", s.batchID)
|
||||||
|
nextBatch.inputsReadyCh = pendingBatch.outputsReadyCh // Chain the ouputs from the pending batch to the next inputs batch
|
||||||
|
} else {
|
||||||
|
logutil.Trace("forwardBatch no pending batch detected", "batchID", s.batchID)
|
||||||
|
// No pendingBatch, so the inputs will be ready in the seqs immediately
|
||||||
|
nextBatch.inputsReadyCh = make(chan struct{}, 1)
|
||||||
|
nextBatch.inputsReadyCh <- struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
for s.allNil() {
|
for s.allNil() {
|
||||||
s.cond.Wait() // Wait until an item is added
|
s.cond.Wait() // Wait until an item is added
|
||||||
}
|
}
|
||||||
defer s.mu.Unlock()
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
ctx := s.model.Backend().NewContext()
|
nextBatch.ctx = s.model.Backend().NewContext()
|
||||||
defer ctx.Close()
|
defer func() {
|
||||||
|
if err != nil {
|
||||||
|
nextBatch.ctx.Close()
|
||||||
|
nextBatch.ctx = nil
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
nextBatch.id = s.batchID
|
||||||
|
nextBatch.seqs = append([]*Sequence{}, s.seqs...)
|
||||||
|
nextBatch.computeStartedCh = make(chan struct{}, 1)
|
||||||
|
nextBatch.outputsReadyCh = make(chan struct{}, 1)
|
||||||
|
|
||||||
var batchInputs []int32
|
// Prepare the seqs and batch, but defer the input token values as we may not be ready yet
|
||||||
|
var batchInputs []*input.Input
|
||||||
|
var batchOutputs []int32
|
||||||
var batch input.Batch
|
var batch input.Batch
|
||||||
|
|
||||||
resumeSeq := -1
|
resumeSeq := -1
|
||||||
@@ -396,7 +475,6 @@ func (s *Server) processBatch() error {
|
|||||||
for range s.seqs {
|
for range s.seqs {
|
||||||
seqIdx = (seqIdx + 1) % len(s.seqs)
|
seqIdx = (seqIdx + 1) % len(s.seqs)
|
||||||
seq := s.seqs[seqIdx]
|
seq := s.seqs[seqIdx]
|
||||||
|
|
||||||
if seq == nil {
|
if seq == nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@@ -404,12 +482,13 @@ func (s *Server) processBatch() error {
|
|||||||
// if past the num predict limit
|
// if past the num predict limit
|
||||||
if seq.numPredict > 0 && seq.numPredicted >= seq.numPredict {
|
if seq.numPredict > 0 && seq.numPredicted >= seq.numPredict {
|
||||||
s.removeSequence(seqIdx, llm.DoneReasonLength)
|
s.removeSequence(seqIdx, llm.DoneReasonLength)
|
||||||
|
nextBatch.seqs[seqIdx] = nil
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if !s.cache.enabled {
|
if !s.cache.enabled {
|
||||||
seq.inputs = append(seq.cache.Inputs, seq.inputs...)
|
seq.inputs = append(seq.cache.Inputs, seq.inputs...)
|
||||||
seq.cache.Inputs = []input.Input{}
|
seq.cache.Inputs = []*input.Input{}
|
||||||
}
|
}
|
||||||
|
|
||||||
batchSize := s.batchSize
|
batchSize := s.batchSize
|
||||||
@@ -442,25 +521,28 @@ func (s *Server) processBatch() error {
|
|||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
err := s.cache.ShiftCacheSlot(seq.cache, seq.numKeep)
|
err = s.cache.ShiftCacheSlot(seq.cache, seq.numKeep)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
var reprocess *ErrReprocessInputs
|
var reprocess *ErrReprocessInputs
|
||||||
if errors.As(err, &reprocess) {
|
if errors.As(err, &reprocess) {
|
||||||
// Prepend these inputs to the sequence's inputs queue for reprocessing
|
// Prepend these inputs to the sequence's inputs queue for reprocessing
|
||||||
seq.inputs = append(reprocess.Inputs, seq.inputs...)
|
seq.inputs = append(reprocess.Inputs, seq.inputs...)
|
||||||
// Skip this sequence but continue processing the rest
|
// Skip this sequence but continue processing the rest
|
||||||
|
nextBatch.seqs[seqIdx] = nil // clear this sequence for this batch
|
||||||
|
err = nil
|
||||||
continue
|
continue
|
||||||
} else {
|
} else {
|
||||||
return err
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
batchInputs = append(batchInputs, inp.Token)
|
batchInputs = append(batchInputs, seq.inputs[i])
|
||||||
if inp.Multimodal != nil {
|
if inp.Multimodal != nil {
|
||||||
mm, err := seq.mmStore.getMultimodal(s.model.Backend(), ctx, inp.Multimodal, false)
|
var mm []input.Multimodal
|
||||||
|
mm, err = seq.mmStore.getMultimodal(s.model.Backend(), nextBatch.ctx, inp.Multimodal, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return
|
||||||
}
|
}
|
||||||
batch.Multimodal = append(batch.Multimodal, input.MultimodalIndex{Index: len(batchInputs) - 1, Multimodal: mm})
|
batch.Multimodal = append(batch.Multimodal, input.MultimodalIndex{Index: len(batchInputs) - 1, Multimodal: mm})
|
||||||
}
|
}
|
||||||
@@ -468,10 +550,11 @@ func (s *Server) processBatch() error {
|
|||||||
batch.Positions = append(batch.Positions, int32(len(seq.cache.Inputs)+len(seq.pendingInputs)))
|
batch.Positions = append(batch.Positions, int32(len(seq.cache.Inputs)+len(seq.pendingInputs)))
|
||||||
batch.Sequences = append(batch.Sequences, seq.cache.Id)
|
batch.Sequences = append(batch.Sequences, seq.cache.Id)
|
||||||
|
|
||||||
seq.iBatch = len(batch.Outputs)
|
seq.iBatch = len(batchOutputs)
|
||||||
if i+1 == len(seq.inputs) {
|
if i+1 == len(seq.inputs) || seq.embeddingOnly {
|
||||||
batch.Outputs = append(batch.Outputs, int32(len(batchInputs)-1))
|
batchOutputs = append(batchOutputs, int32(len(batchInputs)-1))
|
||||||
}
|
}
|
||||||
|
logutil.Trace("forwardBatch iBatch", "batchID", s.batchID, "seqIdx", seqIdx, "seq.iBatch", seq.iBatch, "i+1", i+1, "len(seq.inputs)", len(seq.inputs))
|
||||||
seq.pendingInputs = append(seq.pendingInputs, inp)
|
seq.pendingInputs = append(seq.pendingInputs, inp)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -485,73 +568,169 @@ func (s *Server) processBatch() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if len(batchInputs) == 0 {
|
if len(batchInputs) == 0 {
|
||||||
return nil
|
logutil.Trace("forwardBatch no batchInputs, going idle", "batchID", s.batchID)
|
||||||
|
nextBatch.ctx.Close()
|
||||||
|
nextBatch.ctx = nil
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
s.batchID++
|
||||||
|
|
||||||
modelOutput, err := model.Forward(ctx, s.model, batchInputs, batch)
|
// Actual batchInputs values will be injected into the batch.Inputs tensor before calling Compute
|
||||||
|
batch.Inputs = nextBatch.ctx.Input().Empty(ml.DTypeI32, len(batchInputs))
|
||||||
|
batch.Outputs = nextBatch.ctx.Input().FromIntSlice(batchOutputs, len(batchOutputs))
|
||||||
|
nextBatch.modelOutput, err = model.Forward(nextBatch.ctx, s.model, batch)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to decode batch: %w", err)
|
err = fmt.Errorf("failed to build graph: %w", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
nextBatch.batchInputs = batchInputs
|
||||||
|
nextBatch.batch = batch
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Async processing of the next batch
|
||||||
|
func (s *Server) computeBatch(activeBatch batchState) {
|
||||||
|
if activeBatch.ctx == nil {
|
||||||
|
// Nothing to compute
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer activeBatch.ctx.Close()
|
||||||
|
|
||||||
|
// Wait until inputs are ready
|
||||||
|
logutil.Trace("computeBatch: waiting for inputs to be ready", "batchID", activeBatch.id)
|
||||||
|
<-activeBatch.inputsReadyCh
|
||||||
|
logutil.Trace("computeBatch: inputs are ready", "batchID", activeBatch.id)
|
||||||
|
|
||||||
|
// Once we complete, signal the next batch of inputs are ready
|
||||||
|
// This will unblock the next computeBatch, or forwardBatch if new seqs come in
|
||||||
|
defer func() {
|
||||||
|
logutil.Trace("computeBatch: outputs are ready", "batchID", activeBatch.id)
|
||||||
|
activeBatch.outputsReadyCh <- struct{}{}
|
||||||
|
}()
|
||||||
|
|
||||||
|
s.mu.Lock()
|
||||||
|
|
||||||
|
// Gather the actual input token values now that they're ready
|
||||||
|
batchInputs := make([]int32, len(activeBatch.batchInputs))
|
||||||
|
for i := range batchInputs {
|
||||||
|
batchInputs[i] = activeBatch.batchInputs[i].Token
|
||||||
}
|
}
|
||||||
|
|
||||||
logits := modelOutput.Floats()
|
// Now we run part of the decoding algorithm to adjust the seq.inputs with placeholder tokens
|
||||||
|
// so that forwardBatch can build a batchInputs set which will eventually contain the actual
|
||||||
|
// decoded tokens.
|
||||||
|
nextBatchTokens := make([]*input.Input, len(s.seqs))
|
||||||
|
iBatches := make([]int, len(s.seqs)) // Record the iBatch values before releasing the lock
|
||||||
for i, seq := range s.seqs {
|
for i, seq := range s.seqs {
|
||||||
|
iBatches[i] = -1
|
||||||
if seq == nil {
|
if seq == nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
// Skip over any newly added or skipped sequences
|
||||||
|
if activeBatch.seqs[i] == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
// After calling Forward, pending inputs are now in the cache
|
// Detect if the sequence we're processing has already been completed and replaced
|
||||||
|
// with a new sequence
|
||||||
|
if seq != activeBatch.seqs[i] {
|
||||||
|
logutil.Trace("computeBatch: sequence replaced, discarding its results", "batchID", activeBatch.id, "seqIdx", i)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Pending inputs will actually be in the cache after we call Compute.
|
||||||
|
// However, we have already resolved any placeholder tokens.
|
||||||
|
//
|
||||||
|
// It's possible for incoming sequences to look at the values that we've
|
||||||
|
// added to the cache here and start relying on them before we've done
|
||||||
|
// the computation. This is OK as long as we ensure that this batch's
|
||||||
|
// computation happens before any future batch's and we never fail
|
||||||
|
// (unless we take down the whole runner).
|
||||||
if len(seq.pendingInputs) > 0 {
|
if len(seq.pendingInputs) > 0 {
|
||||||
seq.cache.Inputs = append(seq.cache.Inputs, seq.pendingInputs...)
|
seq.cache.Inputs = append(seq.cache.Inputs, seq.pendingInputs...)
|
||||||
seq.pendingInputs = []input.Input{}
|
seq.pendingInputs = []*input.Input{}
|
||||||
}
|
}
|
||||||
|
|
||||||
// don't sample prompt processing
|
// don't sample prompt processing
|
||||||
if len(seq.inputs) != 0 {
|
if len(seq.inputs) != 0 {
|
||||||
if !s.cache.enabled {
|
if !s.cache.enabled {
|
||||||
return errors.New("caching disabled but unable to fit entire input in a batch")
|
s.hardErrCh <- fmt.Errorf("caching disabled but unable to fit entire input in a batch")
|
||||||
|
s.mu.Unlock()
|
||||||
|
return
|
||||||
}
|
}
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
seq.numPredicted++
|
seq.numPredicted++
|
||||||
|
nextToken := &input.Input{Token: 0} // placeholder we'll fill in after Compute/Floats
|
||||||
|
seq.inputs = []*input.Input{nextToken}
|
||||||
|
nextBatchTokens[i] = nextToken
|
||||||
|
iBatches[i] = seq.iBatch
|
||||||
|
}
|
||||||
|
|
||||||
|
// At this point the seqs are ready for forwardBatch to move forward so unblock
|
||||||
|
s.mu.Unlock()
|
||||||
|
|
||||||
|
activeBatch.batch.Inputs.SetValueFromIntSlice(batchInputs)
|
||||||
|
activeBatch.ctx.ComputeWithNotify(
|
||||||
|
func() {
|
||||||
|
logutil.Trace("computeBatch: signaling computeStartedCh", "batchID", activeBatch.id)
|
||||||
|
activeBatch.computeStartedCh <- struct{}{}
|
||||||
|
},
|
||||||
|
activeBatch.modelOutput)
|
||||||
|
|
||||||
|
outputs := activeBatch.modelOutput.Floats()
|
||||||
|
|
||||||
|
logutil.Trace("computeBatch: logits ready", "batchID", activeBatch.id)
|
||||||
|
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
logutil.Trace("computeBatch: decoding", "batchID", activeBatch.id)
|
||||||
|
for i, seq := range s.seqs {
|
||||||
|
if seq == nil || nextBatchTokens[i] == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
if seq.numPredicted == 1 {
|
if seq.numPredicted == 1 {
|
||||||
seq.startGenerationTime = time.Now()
|
seq.startGenerationTime = time.Now()
|
||||||
}
|
}
|
||||||
|
|
||||||
// if done processing the prompt, generate an embedding and return
|
// if done processing the prompt, generate an embedding and return
|
||||||
if seq.embeddingOnly {
|
if seq.embeddingOnly {
|
||||||
// TODO(jessegross): Embedding support
|
seq.embedding <- outputs
|
||||||
slog.Warn("generation of embedding outputs not yet supported")
|
|
||||||
s.removeSequence(i, llm.DoneReasonStop)
|
s.removeSequence(i, llm.DoneReasonStop)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// sample a token
|
// sample a token
|
||||||
vocabSize := len(logits) / len(batch.Outputs)
|
vocabSize := len(outputs) / activeBatch.batch.Outputs.Dim(0)
|
||||||
|
logutil.Trace("computeBatch: vocab details", "batchID", activeBatch.id, "seqIdx", i, "len(logits)", len(outputs), "len(activeBatch.batch.Outputs)", activeBatch.batch.Outputs.Dim(0), "vocabSize", vocabSize, "iBatches", iBatches)
|
||||||
token, err := seq.sampler.Sample(logits[seq.iBatch*vocabSize : (seq.iBatch+1)*vocabSize])
|
token, err := seq.sampler.Sample(outputs[iBatches[i]*vocabSize : (iBatches[i]+1)*vocabSize])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to sample token: %w", err)
|
s.hardErrCh <- fmt.Errorf("failed to sample token: %w", err)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
nextBatchTokens[i].Token = token
|
||||||
|
|
||||||
// if it's an end of sequence token, break
|
// if it's an end of sequence token, break
|
||||||
if s.model.(model.TextProcessor).Is(token, model.SpecialEOS) {
|
if s.model.(model.TextProcessor).Is(token, model.SpecialEOS) {
|
||||||
// TODO (jmorganca): we should send this back
|
// TODO (jmorganca): we should send this back
|
||||||
// as it's important for the /api/generate context
|
// as it's important for the /api/generate context
|
||||||
// seq.responses <- piece
|
// seq.responses <- piece
|
||||||
|
logutil.Trace("computeBatch: EOS", "batchID", activeBatch.id, "seqIdx", i)
|
||||||
s.removeSequence(i, llm.DoneReasonStop)
|
s.removeSequence(i, llm.DoneReasonStop)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
piece, err := s.model.(model.TextProcessor).Decode([]int32{token})
|
piece, err := s.model.(model.TextProcessor).Decode([]int32{token})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
s.hardErrCh <- fmt.Errorf("failed to decode token: %w", err)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
seq.inputs = []input.Input{{Token: token}}
|
|
||||||
|
|
||||||
seq.pendingResponses = append(seq.pendingResponses, piece)
|
seq.pendingResponses = append(seq.pendingResponses, piece)
|
||||||
sequence := strings.Join(seq.pendingResponses, "")
|
sequence := strings.Join(seq.pendingResponses, "")
|
||||||
|
|
||||||
@@ -575,6 +754,7 @@ func (s *Server) processBatch() error {
|
|||||||
if tokenTruncated || origLen == newLen {
|
if tokenTruncated || origLen == newLen {
|
||||||
tokenLen--
|
tokenLen--
|
||||||
}
|
}
|
||||||
|
|
||||||
seq.cache.Inputs = seq.cache.Inputs[:tokenLen]
|
seq.cache.Inputs = seq.cache.Inputs[:tokenLen]
|
||||||
|
|
||||||
s.removeSequence(i, llm.DoneReasonStop)
|
s.removeSequence(i, llm.DoneReasonStop)
|
||||||
@@ -593,8 +773,6 @@ func (s *Server) processBatch() error {
|
|||||||
s.removeSequence(i, llm.DoneReasonConnectionClosed)
|
s.removeSequence(i, llm.DoneReasonConnectionClosed)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
||||||
@@ -665,7 +843,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
|||||||
found := false
|
found := false
|
||||||
for i, sq := range s.seqs {
|
for i, sq := range s.seqs {
|
||||||
if sq == nil {
|
if sq == nil {
|
||||||
seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs)
|
seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.mu.Unlock()
|
s.mu.Unlock()
|
||||||
s.seqsSem.Release(1)
|
s.seqsSem.Release(1)
|
||||||
@@ -721,6 +899,67 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if s.model.Backend().Config().Uint("pooling_type", math.MaxUint32) == math.MaxUint32 {
|
||||||
|
http.Error(w, "this model does not support embeddings", http.StatusNotImplemented)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var req llm.EmbeddingRequest
|
||||||
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||||
|
http.Error(w, fmt.Sprintf("bad request: %s", err), http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
seq, err := s.NewSequence(req.Content, nil, NewSequenceParams{embedding: true})
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, fmt.Sprintf("failed to create new sequence: %v", err), http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := s.seqsSem.Acquire(r.Context(), 1); err != nil {
|
||||||
|
if errors.Is(err, context.Canceled) {
|
||||||
|
slog.Info("aborting embedding request due to client closing the connection")
|
||||||
|
} else {
|
||||||
|
http.Error(w, fmt.Sprintf("failed to acquire semaphore: %v", err), http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
s.mu.Lock()
|
||||||
|
found := false
|
||||||
|
for i, sq := range s.seqs {
|
||||||
|
if sq == nil {
|
||||||
|
seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, false)
|
||||||
|
if err != nil {
|
||||||
|
s.mu.Unlock()
|
||||||
|
s.seqsSem.Release(1)
|
||||||
|
http.Error(w, fmt.Sprintf("failed to load cache: %v", err), http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
s.seqs[i] = seq
|
||||||
|
s.cond.Signal()
|
||||||
|
found = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
s.mu.Unlock()
|
||||||
|
|
||||||
|
if !found {
|
||||||
|
s.seqsSem.Release(1)
|
||||||
|
http.Error(w, "could not find an available sequence", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := json.NewEncoder(w).Encode(&llm.EmbeddingResponse{
|
||||||
|
Embedding: <-seq.embedding,
|
||||||
|
}); err != nil {
|
||||||
|
http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (s *Server) health(w http.ResponseWriter, r *http.Request) {
|
func (s *Server) health(w http.ResponseWriter, r *http.Request) {
|
||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
if err := json.NewEncoder(w).Encode(&llm.ServerStatusResponse{
|
if err := json.NewEncoder(w).Encode(&llm.ServerStatusResponse{
|
||||||
@@ -736,7 +975,10 @@ func (s *Server) reserveWorstCaseGraph() error {
|
|||||||
defer ctx.Close()
|
defer ctx.Close()
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
inputs := make([]input.Input, s.batchSize)
|
inputs := make([]*input.Input, s.batchSize)
|
||||||
|
for i := range inputs {
|
||||||
|
inputs[i] = &input.Input{}
|
||||||
|
}
|
||||||
mmStore := newMultimodalStore()
|
mmStore := newMultimodalStore()
|
||||||
|
|
||||||
// Multimodal strategy:
|
// Multimodal strategy:
|
||||||
@@ -778,8 +1020,11 @@ func (s *Server) reserveWorstCaseGraph() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if len(inputs) < s.batchSize {
|
if len(inputs) < s.batchSize {
|
||||||
newInputs := make([]input.Input, s.batchSize)
|
newInputs := make([]*input.Input, s.batchSize)
|
||||||
copy(newInputs, inputs)
|
copy(newInputs, inputs)
|
||||||
|
for i := len(inputs); i < s.batchSize; i++ {
|
||||||
|
newInputs[i] = &input.Input{}
|
||||||
|
}
|
||||||
inputs = newInputs
|
inputs = newInputs
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -803,12 +1048,8 @@ func (s *Server) reserveWorstCaseGraph() error {
|
|||||||
batch.Positions[i] = int32(i)
|
batch.Positions[i] = int32(i)
|
||||||
}
|
}
|
||||||
|
|
||||||
batch.Outputs = make([]int32, s.parallel)
|
|
||||||
for i := range batch.Outputs {
|
|
||||||
batch.Outputs[i] = int32(i)
|
|
||||||
}
|
|
||||||
|
|
||||||
batch.Inputs = ctx.Input().FromIntSlice(batchInputs, len(batchInputs))
|
batch.Inputs = ctx.Input().FromIntSlice(batchInputs, len(batchInputs))
|
||||||
|
batch.Outputs = ctx.Input().Empty(ml.DTypeI32, s.parallel)
|
||||||
|
|
||||||
cache := s.model.Config().Cache
|
cache := s.model.Config().Cache
|
||||||
if cache != nil {
|
if cache != nil {
|
||||||
@@ -843,7 +1084,12 @@ func (s *Server) allocModel(
|
|||||||
defer func() {
|
defer func() {
|
||||||
if r := recover(); r != nil {
|
if r := recover(); r != nil {
|
||||||
if err, ok := r.(error); ok {
|
if err, ok := r.(error); ok {
|
||||||
panicErr = err
|
var noMem ml.ErrNoMem
|
||||||
|
if errors.As(err, &noMem) {
|
||||||
|
panicErr = noMem
|
||||||
|
} else {
|
||||||
|
panic(r)
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
panic(r)
|
panic(r)
|
||||||
}
|
}
|
||||||
@@ -1011,6 +1257,7 @@ func Execute(args []string) error {
|
|||||||
server := &Server{
|
server := &Server{
|
||||||
modelPath: *mpath,
|
modelPath: *mpath,
|
||||||
status: llm.ServerStatusLaunched,
|
status: llm.ServerStatusLaunched,
|
||||||
|
hardErrCh: make(chan error, 1),
|
||||||
}
|
}
|
||||||
|
|
||||||
server.cond = sync.NewCond(&server.mu)
|
server.cond = sync.NewCond(&server.mu)
|
||||||
@@ -1029,10 +1276,7 @@ func Execute(args []string) error {
|
|||||||
mux := http.NewServeMux()
|
mux := http.NewServeMux()
|
||||||
// TODO: support embeddings
|
// TODO: support embeddings
|
||||||
mux.HandleFunc("POST /load", server.load)
|
mux.HandleFunc("POST /load", server.load)
|
||||||
mux.HandleFunc("POST /embedding", func(w http.ResponseWriter, r *http.Request) {
|
mux.HandleFunc("POST /embedding", server.embeddings)
|
||||||
http.Error(w, "this model does not support embeddings", http.StatusNotImplemented)
|
|
||||||
})
|
|
||||||
|
|
||||||
mux.HandleFunc("POST /completion", server.completion)
|
mux.HandleFunc("POST /completion", server.completion)
|
||||||
mux.HandleFunc("GET /health", server.health)
|
mux.HandleFunc("GET /health", server.health)
|
||||||
|
|
||||||
|
|||||||
@@ -78,7 +78,7 @@ function checkEnv() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
function buildOllama() {
|
function buildCPU() {
|
||||||
mkdir -Force -path "${script:DIST_DIR}\"
|
mkdir -Force -path "${script:DIST_DIR}\"
|
||||||
if ($script:ARCH -ne "arm64") {
|
if ($script:ARCH -ne "arm64") {
|
||||||
Remove-Item -ea 0 -recurse -force -path "${script:SRC_DIR}\dist\windows-${script:ARCH}"
|
Remove-Item -ea 0 -recurse -force -path "${script:SRC_DIR}\dist\windows-${script:ARCH}"
|
||||||
@@ -90,20 +90,72 @@ function buildOllama() {
|
|||||||
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
|
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
|
||||||
& cmake --install build --component CPU --strip
|
& cmake --install build --component CPU --strip
|
||||||
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
|
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
function buildCUDA11() {
|
||||||
|
# CUDA v11 claims to be compatible with MSVC 2022, but the latest updates are no longer compatible
|
||||||
|
# 19.40 is the last compiler version that works, but recent udpates are 19.43
|
||||||
|
# So this pins to MSVC 2019 for best compatibility
|
||||||
|
mkdir -Force -path "${script:DIST_DIR}\"
|
||||||
|
if ($script:ARCH -ne "arm64") {
|
||||||
$hashEnv = @{}
|
$hashEnv = @{}
|
||||||
Get-ChildItem env: | foreach { $hashEnv[$_.Name] = $_.Value }
|
Get-ChildItem env: | foreach { $hashEnv[$_.Name] = $_.Value }
|
||||||
if ("$script:CUDA_DIRS".Contains("v12")) {
|
if ("$script:CUDA_DIRS".Contains("v11")) {
|
||||||
$hashEnv.Keys | foreach { if ($_.Contains("CUDA_PATH_V12")) { $v12="$_" }}
|
$hashEnv.Keys | foreach { if ($_.Contains("CUDA_PATH_V11")) { $x=$hashEnv[$_]; if (test-path -literalpath "$x\bin\nvcc.exe" ) { $cuda=$x} }}
|
||||||
$env:CUDAToolkit_ROOT=$hashEnv[$v12]
|
write-host "Building CUDA v11 backend libraries $cuda"
|
||||||
write-host "Building CUDA v12 backend libraries"
|
$env:CUDAToolkit_ROOT=$cuda
|
||||||
& cmake --fresh --preset "CUDA 12" --install-prefix $script:DIST_DIR
|
& cmake --fresh --preset "CUDA 11" -T cuda="$cuda" -DCMAKE_CUDA_COMPILER="$cuda\bin\nvcc.exe" -G "Visual Studio 16 2019" --install-prefix $script:DIST_DIR -DOLLAMA_RUNNER_DIR="cuda_v11"
|
||||||
|
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
|
||||||
|
& cmake --build --preset "CUDA 11" --config Release --parallel $script:JOBS
|
||||||
|
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
|
||||||
|
& cmake --install build --component "CUDA" --strip
|
||||||
|
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
function buildCUDA12() {
|
||||||
|
mkdir -Force -path "${script:DIST_DIR}\"
|
||||||
|
if ($script:ARCH -ne "arm64") {
|
||||||
|
$hashEnv = @{}
|
||||||
|
Get-ChildItem env: | foreach { $hashEnv[$_.Name] = $_.Value }
|
||||||
|
if ("$script:CUDA_DIRS".Contains("v12.8")) {
|
||||||
|
$hashEnv.Keys | foreach { if ($_.Contains("CUDA_PATH_V12_8")) { $x=$hashEnv[$_]; if (test-path -literalpath "$x\bin\nvcc.exe" ) { $cuda=$x} }}
|
||||||
|
write-host "Building CUDA v12 backend libraries $cuda"
|
||||||
|
$env:CUDAToolkit_ROOT=$cuda
|
||||||
|
& cmake --fresh --preset "CUDA 12" -T cuda="$cuda" --install-prefix $script:DIST_DIR -DOLLAMA_RUNNER_DIR="cuda_v12"
|
||||||
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
|
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
|
||||||
& cmake --build --preset "CUDA 12" --config Release --parallel $script:JOBS
|
& cmake --build --preset "CUDA 12" --config Release --parallel $script:JOBS
|
||||||
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
|
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
|
||||||
& cmake --install build --component "CUDA" --strip
|
& cmake --install build --component "CUDA" --strip
|
||||||
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
|
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
function buildCUDA13() {
|
||||||
|
mkdir -Force -path "${script:DIST_DIR}\"
|
||||||
|
if ($script:ARCH -ne "arm64") {
|
||||||
|
$hashEnv = @{}
|
||||||
|
Get-ChildItem env: | foreach { $hashEnv[$_.Name] = $_.Value }
|
||||||
|
if ("$script:CUDA_DIRS".Contains("v13")) {
|
||||||
|
$hashEnv.Keys | foreach { if ($_.Contains("CUDA_PATH_V13")) { $x=$hashEnv[$_]; if (test-path -literalpath "$x\bin\nvcc.exe" ) { $cuda=$x} }}
|
||||||
|
$env:CUDAToolkit_ROOT=$cuda
|
||||||
|
write-host "Building CUDA v13 backend libraries $cuda"
|
||||||
|
& cmake --fresh --preset "CUDA 13" -T cuda="$cuda" --install-prefix $script:DIST_DIR -DOLLAMA_RUNNER_DIR="cuda_v13"
|
||||||
|
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
|
||||||
|
& cmake --build --preset "CUDA 13" --config Release --parallel $script:JOBS
|
||||||
|
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
|
||||||
|
& cmake --install build --component "CUDA" --strip
|
||||||
|
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
function buildROCm() {
|
||||||
|
mkdir -Force -path "${script:DIST_DIR}\"
|
||||||
|
if ($script:ARCH -ne "arm64") {
|
||||||
if ($env:HIP_PATH) {
|
if ($env:HIP_PATH) {
|
||||||
write-host "Building ROCm backend libraries"
|
write-host "Building ROCm backend libraries"
|
||||||
if (-Not (get-command -ErrorAction silent ninja)) {
|
if (-Not (get-command -ErrorAction silent ninja)) {
|
||||||
@@ -129,6 +181,10 @@ function buildOllama() {
|
|||||||
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
|
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
function buildOllama() {
|
||||||
|
mkdir -Force -path "${script:DIST_DIR}\"
|
||||||
write-host "Building ollama CLI"
|
write-host "Building ollama CLI"
|
||||||
& go build -trimpath -ldflags "-s -w -X=github.com/ollama/ollama/version.Version=$script:VERSION -X=github.com/ollama/ollama/server.mode=release" .
|
& go build -trimpath -ldflags "-s -w -X=github.com/ollama/ollama/version.Version=$script:VERSION -X=github.com/ollama/ollama/server.mode=release" .
|
||||||
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
|
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
|
||||||
@@ -236,6 +292,10 @@ function distZip() {
|
|||||||
checkEnv
|
checkEnv
|
||||||
try {
|
try {
|
||||||
if ($($args.count) -eq 0) {
|
if ($($args.count) -eq 0) {
|
||||||
|
buildCPU
|
||||||
|
buildCUDA12
|
||||||
|
buildCUDA13
|
||||||
|
buildROCm
|
||||||
buildOllama
|
buildOllama
|
||||||
buildApp
|
buildApp
|
||||||
gatherDependencies
|
gatherDependencies
|
||||||
|
|||||||
@@ -32,6 +32,7 @@ import (
|
|||||||
"github.com/ollama/ollama/envconfig"
|
"github.com/ollama/ollama/envconfig"
|
||||||
"github.com/ollama/ollama/format"
|
"github.com/ollama/ollama/format"
|
||||||
"github.com/ollama/ollama/fs/ggml"
|
"github.com/ollama/ollama/fs/ggml"
|
||||||
|
"github.com/ollama/ollama/harmony"
|
||||||
"github.com/ollama/ollama/llm"
|
"github.com/ollama/ollama/llm"
|
||||||
"github.com/ollama/ollama/logutil"
|
"github.com/ollama/ollama/logutil"
|
||||||
"github.com/ollama/ollama/openai"
|
"github.com/ollama/ollama/openai"
|
||||||
@@ -45,6 +46,18 @@ import (
|
|||||||
"github.com/ollama/ollama/version"
|
"github.com/ollama/ollama/version"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func shouldUseHarmony(model *Model) bool {
|
||||||
|
if slices.Contains([]string{"gptoss", "gpt-oss"}, model.Config.ModelFamily) {
|
||||||
|
// heuristic to check whether the template expects to be parsed via harmony:
|
||||||
|
// search for harmony tags that are nearly always used
|
||||||
|
if model.Template.Contains("<|start|>") && model.Template.Contains("<|end|>") {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
func experimentEnabled(name string) bool {
|
func experimentEnabled(name string) bool {
|
||||||
return slices.Contains(strings.Split(os.Getenv("OLLAMA_EXPERIMENT"), ","), name)
|
return slices.Contains(strings.Split(os.Getenv("OLLAMA_EXPERIMENT"), ","), name)
|
||||||
}
|
}
|
||||||
@@ -176,7 +189,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// expire the runner
|
// expire the runner
|
||||||
if req.Prompt == "" && req.KeepAlive != nil && int(req.KeepAlive.Seconds()) == 0 {
|
if req.Prompt == "" && req.KeepAlive != nil && req.KeepAlive.Duration == 0 {
|
||||||
s.sched.expireRunner(m)
|
s.sched.expireRunner(m)
|
||||||
|
|
||||||
c.JSON(http.StatusOK, api.GenerateResponse{
|
c.JSON(http.StatusOK, api.GenerateResponse{
|
||||||
@@ -194,12 +207,12 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
useHarmony := shouldUseHarmony(*m) && !req.Raw
|
useHarmony := shouldUseHarmony(m) && !req.Raw
|
||||||
var harmonyMessageHandler *HarmonyMessageHandler
|
var harmonyMessageHandler *harmony.HarmonyMessageHandler
|
||||||
var harmonyToolParser *HarmonyToolCallAccumulator
|
var harmonyToolParser *harmony.HarmonyToolCallAccumulator
|
||||||
if useHarmony {
|
if useHarmony {
|
||||||
harmonyMessageHandler = NewHarmonyMessageHandler()
|
harmonyMessageHandler = harmony.NewHarmonyMessageHandler()
|
||||||
harmonyMessageHandler.harmonyParser.AddImplicitStart()
|
harmonyMessageHandler.HarmonyParser.AddImplicitStart()
|
||||||
harmonyToolParser = harmonyMessageHandler.CreateToolParser()
|
harmonyToolParser = harmonyMessageHandler.CreateToolParser()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -475,7 +488,6 @@ func (s *Server) EmbedHandler(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
truncate := true
|
truncate := true
|
||||||
|
|
||||||
if req.Truncate != nil && !*req.Truncate {
|
if req.Truncate != nil && !*req.Truncate {
|
||||||
truncate = false
|
truncate = false
|
||||||
}
|
}
|
||||||
@@ -542,7 +554,16 @@ func (s *Server) EmbedHandler(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if bos := kvData.Uint("tokenizer.ggml.bos_token_id"); tokens[0] != int(bos) && kvData.Bool("add_bos_token", true) {
|
||||||
|
ctxLen--
|
||||||
|
}
|
||||||
|
|
||||||
|
if eos := kvData.Uint("tokenizer.ggml.eos_token_id"); tokens[len(tokens)-1] != int(eos) && kvData.Bool("add_eos_token", true) {
|
||||||
|
ctxLen--
|
||||||
|
}
|
||||||
|
|
||||||
tokens = tokens[:ctxLen]
|
tokens = tokens[:ctxLen]
|
||||||
|
|
||||||
s, err = r.Detokenize(c.Request.Context(), tokens)
|
s, err = r.Detokenize(c.Request.Context(), tokens)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
@@ -563,7 +584,12 @@ func (s *Server) EmbedHandler(c *gin.Context) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
embeddings[i] = normalize(embedding)
|
// TODO: this first normalization should be done by the model
|
||||||
|
embedding = normalize(embedding)
|
||||||
|
if req.Dimensions > 0 && req.Dimensions < len(embedding) {
|
||||||
|
embedding = normalize(embedding[:req.Dimensions])
|
||||||
|
}
|
||||||
|
embeddings[i] = embedding
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -589,11 +615,7 @@ func normalize(vec []float32) []float32 {
|
|||||||
sum += v * v
|
sum += v * v
|
||||||
}
|
}
|
||||||
|
|
||||||
norm := float32(0.0)
|
norm := float32(1.0 / max(math.Sqrt(float64(sum)), 1e-12))
|
||||||
if sum > 0 {
|
|
||||||
norm = float32(1.0 / math.Sqrt(float64(sum)))
|
|
||||||
}
|
|
||||||
|
|
||||||
for i := range vec {
|
for i := range vec {
|
||||||
vec[i] *= norm
|
vec[i] *= norm
|
||||||
}
|
}
|
||||||
@@ -1531,7 +1553,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// expire the runner
|
// expire the runner
|
||||||
if len(req.Messages) == 0 && req.KeepAlive != nil && int(req.KeepAlive.Seconds()) == 0 {
|
if len(req.Messages) == 0 && req.KeepAlive != nil && req.KeepAlive.Duration == 0 {
|
||||||
model, err := GetModel(req.Model)
|
model, err := GetModel(req.Model)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
switch {
|
switch {
|
||||||
@@ -1603,19 +1625,19 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
msgs = filterThinkTags(msgs, m)
|
msgs = filterThinkTags(msgs, m)
|
||||||
|
|
||||||
var harmonyMessageHandler *HarmonyMessageHandler
|
var harmonyMessageHandler *harmony.HarmonyMessageHandler
|
||||||
var harmonyToolParser *HarmonyToolCallAccumulator
|
var harmonyToolParser *harmony.HarmonyToolCallAccumulator
|
||||||
|
|
||||||
useHarmony := shouldUseHarmony(*m)
|
useHarmony := shouldUseHarmony(m)
|
||||||
|
|
||||||
processedTools := req.Tools
|
processedTools := req.Tools
|
||||||
if useHarmony {
|
if useHarmony {
|
||||||
harmonyMessageHandler = NewHarmonyMessageHandler()
|
harmonyMessageHandler = harmony.NewHarmonyMessageHandler()
|
||||||
var lastMessage *api.Message
|
var lastMessage *api.Message
|
||||||
if len(msgs) > 0 {
|
if len(msgs) > 0 {
|
||||||
lastMessage = &msgs[len(msgs)-1]
|
lastMessage = &msgs[len(msgs)-1]
|
||||||
}
|
}
|
||||||
harmonyMessageHandler.harmonyParser.AddImplicitStartOrPrefill(lastMessage)
|
harmonyMessageHandler.HarmonyParser.AddImplicitStartOrPrefill(lastMessage)
|
||||||
harmonyToolParser = harmonyMessageHandler.CreateToolParser()
|
harmonyToolParser = harmonyMessageHandler.CreateToolParser()
|
||||||
|
|
||||||
// make a copy of tools to pass to the chat prompt. Function names may be
|
// make a copy of tools to pass to the chat prompt. Function names may be
|
||||||
@@ -1623,7 +1645,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
|||||||
processedTools = make([]api.Tool, len(req.Tools))
|
processedTools = make([]api.Tool, len(req.Tools))
|
||||||
copy(processedTools, req.Tools)
|
copy(processedTools, req.Tools)
|
||||||
for i, tool := range processedTools {
|
for i, tool := range processedTools {
|
||||||
processedTools[i].Function.Name = harmonyMessageHandler.functionNameMap.ConvertAndAdd(tool.Function.Name)
|
processedTools[i].Function.Name = harmonyMessageHandler.FunctionNameMap.ConvertAndAdd(tool.Function.Name)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1660,6 +1682,10 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
|||||||
OpeningTag: openingTag,
|
OpeningTag: openingTag,
|
||||||
ClosingTag: closingTag,
|
ClosingTag: closingTag,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if strings.HasSuffix(strings.TrimSpace(prompt), openingTag) {
|
||||||
|
thinkingState.AddContent(openingTag)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
var toolParser *tools.Parser
|
var toolParser *tools.Parser
|
||||||
@@ -1705,7 +1731,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
|||||||
toolName, toolContent := harmonyToolParser.Drain()
|
toolName, toolContent := harmonyToolParser.Drain()
|
||||||
if toolName != nil {
|
if toolName != nil {
|
||||||
*toolName = strings.TrimPrefix(*toolName, "functions.")
|
*toolName = strings.TrimPrefix(*toolName, "functions.")
|
||||||
*toolName = harmonyMessageHandler.functionNameMap.OriginalFromConverted(*toolName)
|
*toolName = harmonyMessageHandler.FunctionNameMap.OriginalFromConverted(*toolName)
|
||||||
var args api.ToolCallFunctionArguments
|
var args api.ToolCallFunctionArguments
|
||||||
if err := json.Unmarshal([]byte(toolContent), &args); err != nil {
|
if err := json.Unmarshal([]byte(toolContent), &args); err != nil {
|
||||||
errStr := fmt.Sprintf("error parsing tool call: raw='%s', err=%s", toolContent, err.Error())
|
errStr := fmt.Sprintf("error parsing tool call: raw='%s', err=%s", toolContent, err.Error())
|
||||||
|
|||||||
@@ -969,3 +969,233 @@ func TestGenerate(t *testing.T) {
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestChatWithPromptEndingInThinkTag(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
// Helper to create a standard thinking test setup
|
||||||
|
setupThinkingTest := func(t *testing.T) (*mockRunner, *Server) {
|
||||||
|
mock := &mockRunner{
|
||||||
|
CompletionResponse: llm.CompletionResponse{
|
||||||
|
Done: true,
|
||||||
|
DoneReason: llm.DoneReasonStop,
|
||||||
|
PromptEvalCount: 1,
|
||||||
|
PromptEvalDuration: 1,
|
||||||
|
EvalCount: 1,
|
||||||
|
EvalDuration: 1,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
s := &Server{
|
||||||
|
sched: &Scheduler{
|
||||||
|
pendingReqCh: make(chan *LlmRequest, 1),
|
||||||
|
finishedReqCh: make(chan *LlmRequest, 1),
|
||||||
|
expiredCh: make(chan *runnerRef, 1),
|
||||||
|
unloadedCh: make(chan any, 1),
|
||||||
|
loaded: make(map[string]*runnerRef),
|
||||||
|
newServerFn: newMockServer(mock),
|
||||||
|
getGpuFn: discover.GetGPUInfo,
|
||||||
|
getCpuFn: discover.GetCPUInfo,
|
||||||
|
reschedDelay: 250 * time.Millisecond,
|
||||||
|
loadFn: func(req *LlmRequest, _ *ggml.GGML, _ discover.GpuInfoList, _ bool) bool {
|
||||||
|
time.Sleep(time.Millisecond)
|
||||||
|
req.successCh <- &runnerRef{llama: mock}
|
||||||
|
return false
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
go s.sched.Run(t.Context())
|
||||||
|
|
||||||
|
// Create a model with thinking support
|
||||||
|
_, digest := createBinFile(t, ggml.KV{
|
||||||
|
"general.architecture": "llama",
|
||||||
|
"llama.block_count": uint32(1),
|
||||||
|
"llama.context_length": uint32(8192),
|
||||||
|
"llama.embedding_length": uint32(4096),
|
||||||
|
"llama.attention.head_count": uint32(32),
|
||||||
|
"llama.attention.head_count_kv": uint32(8),
|
||||||
|
"tokenizer.ggml.tokens": []string{""},
|
||||||
|
"tokenizer.ggml.scores": []float32{0},
|
||||||
|
"tokenizer.ggml.token_type": []int32{0},
|
||||||
|
}, []*ggml.Tensor{
|
||||||
|
{Name: "token_embd.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||||
|
{Name: "blk.0.attn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||||
|
{Name: "blk.0.ffn_down.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||||
|
{Name: "blk.0.ffn_gate.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||||
|
{Name: "blk.0.ffn_up.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||||
|
{Name: "blk.0.ffn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||||
|
{Name: "blk.0.attn_k.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||||
|
{Name: "blk.0.attn_output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||||
|
{Name: "blk.0.attn_q.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||||
|
{Name: "blk.0.attn_v.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||||
|
{Name: "output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||||
|
})
|
||||||
|
|
||||||
|
// Create model with thinking template that adds <think> at the end
|
||||||
|
w := createRequest(t, s.CreateHandler, api.CreateRequest{
|
||||||
|
Model: "test-thinking",
|
||||||
|
Files: map[string]string{"file.gguf": digest},
|
||||||
|
Template: `{{- range .Messages }}
|
||||||
|
{{- if eq .Role "user" }}user: {{ .Content }}
|
||||||
|
{{ else if eq .Role "assistant" }}assistant: {{ if .Thinking }}<think>{{ .Thinking }}</think>{{ end }}{{ .Content }}
|
||||||
|
{{ end }}{{ end }}<think>`,
|
||||||
|
Stream: &stream,
|
||||||
|
})
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected status 200, got %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
return mock, s
|
||||||
|
}
|
||||||
|
|
||||||
|
mock, s := setupThinkingTest(t)
|
||||||
|
|
||||||
|
// Helper to test chat responses
|
||||||
|
testChatRequest := func(t *testing.T, name string, userContent string, modelResponse string, expectedThinking string, expectedContent string, think bool) {
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
mock.CompletionResponse = llm.CompletionResponse{
|
||||||
|
Content: modelResponse,
|
||||||
|
Done: true,
|
||||||
|
DoneReason: llm.DoneReasonStop,
|
||||||
|
PromptEvalCount: 1,
|
||||||
|
PromptEvalDuration: 1,
|
||||||
|
EvalCount: 1,
|
||||||
|
EvalDuration: 1,
|
||||||
|
}
|
||||||
|
mock.CompletionFn = nil
|
||||||
|
|
||||||
|
streamRequest := false
|
||||||
|
req := api.ChatRequest{
|
||||||
|
Model: "test-thinking",
|
||||||
|
Messages: []api.Message{
|
||||||
|
{Role: "user", Content: userContent},
|
||||||
|
},
|
||||||
|
Stream: &streamRequest,
|
||||||
|
}
|
||||||
|
if think {
|
||||||
|
req.Think = &api.ThinkValue{Value: think}
|
||||||
|
}
|
||||||
|
|
||||||
|
w := createRequest(t, s.ChatHandler, req)
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected status 200, got %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
var resp api.ChatResponse
|
||||||
|
if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.Message.Thinking != expectedThinking {
|
||||||
|
t.Errorf("expected thinking %q, got %q", expectedThinking, resp.Message.Thinking)
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.Message.Content != expectedContent {
|
||||||
|
t.Errorf("expected content %q, got %q", expectedContent, resp.Message.Content)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test cases - Note: Template adds <think> at the end, and leading whitespace after <think> is eaten by the parser
|
||||||
|
testChatRequest(t, "basic thinking response",
|
||||||
|
"Help me solve this problem",
|
||||||
|
" Let me think about this step by step... </think> The answer is 42.",
|
||||||
|
"Let me think about this step by step... ",
|
||||||
|
"The answer is 42.",
|
||||||
|
true)
|
||||||
|
|
||||||
|
testChatRequest(t, "thinking with multiple sentences",
|
||||||
|
"Explain quantum computing",
|
||||||
|
" First, I need to understand the basics. Quantum bits can be in superposition. </think> Quantum computing uses quantum mechanics principles.",
|
||||||
|
"First, I need to understand the basics. Quantum bits can be in superposition. ",
|
||||||
|
"Quantum computing uses quantum mechanics principles.",
|
||||||
|
true)
|
||||||
|
|
||||||
|
testChatRequest(t, "no thinking content",
|
||||||
|
"What is 2+2?",
|
||||||
|
"</think> The answer is 4.",
|
||||||
|
"",
|
||||||
|
"The answer is 4.",
|
||||||
|
true)
|
||||||
|
|
||||||
|
testChatRequest(t, "thinking disabled but template still adds think tag",
|
||||||
|
"Simple question",
|
||||||
|
" My thoughts </think> The answer.",
|
||||||
|
"",
|
||||||
|
" My thoughts </think> The answer.",
|
||||||
|
false)
|
||||||
|
|
||||||
|
// Test streaming response with template-added <think>
|
||||||
|
t.Run("streaming with thinking", func(t *testing.T) {
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(1)
|
||||||
|
|
||||||
|
mock.CompletionFn = func(ctx context.Context, r llm.CompletionRequest, fn func(r llm.CompletionResponse)) error {
|
||||||
|
defer wg.Done()
|
||||||
|
|
||||||
|
// Verify the prompt ends with <think> due to template
|
||||||
|
if !strings.HasSuffix(r.Prompt, "<think>") {
|
||||||
|
t.Errorf("expected prompt to end with <think>, got: %q", r.Prompt)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Simulate streaming chunks
|
||||||
|
responses := []llm.CompletionResponse{
|
||||||
|
{Content: " I need to consider", Done: false, PromptEvalCount: 1, PromptEvalDuration: 1},
|
||||||
|
{Content: " multiple factors here...", Done: false, PromptEvalCount: 1, PromptEvalDuration: 1},
|
||||||
|
{Content: " </think> Based on my analysis,", Done: false, PromptEvalCount: 1, PromptEvalDuration: 1},
|
||||||
|
{Content: " the solution is straightforward.", Done: true, DoneReason: llm.DoneReasonStop, PromptEvalCount: 1, PromptEvalDuration: 1, EvalCount: 1, EvalDuration: 1},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, resp := range responses {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return ctx.Err()
|
||||||
|
default:
|
||||||
|
fn(resp)
|
||||||
|
time.Sleep(10 * time.Millisecond)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
think := true
|
||||||
|
w := createRequest(t, s.ChatHandler, api.ChatRequest{
|
||||||
|
Model: "test-thinking",
|
||||||
|
Messages: []api.Message{{Role: "user", Content: "Analyze this complex problem"}},
|
||||||
|
Think: &api.ThinkValue{Value: think},
|
||||||
|
Stream: &stream,
|
||||||
|
})
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Fatalf("expected status 200, got %d", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse streaming responses
|
||||||
|
decoder := json.NewDecoder(w.Body)
|
||||||
|
var allThinking, allContent strings.Builder
|
||||||
|
|
||||||
|
for {
|
||||||
|
var resp api.ChatResponse
|
||||||
|
if err := decoder.Decode(&resp); err == io.EOF {
|
||||||
|
break
|
||||||
|
} else if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
allThinking.WriteString(resp.Message.Thinking)
|
||||||
|
allContent.WriteString(resp.Message.Content)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Note: Leading whitespace after <think> is eaten by the parser
|
||||||
|
if got := allThinking.String(); got != "I need to consider multiple factors here... " {
|
||||||
|
t.Errorf("expected thinking %q, got %q", "I need to consider multiple factors here... ", got)
|
||||||
|
}
|
||||||
|
|
||||||
|
if got := allContent.String(); got != "Based on my analysis, the solution is straightforward." {
|
||||||
|
t.Errorf("expected content %q, got %q", "Based on my analysis, the solution is straightforward.", got)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|||||||
@@ -103,7 +103,9 @@ func eat(s *Parser) (string, string, bool) {
|
|||||||
// note that we use the original content, not the trimmed one because we
|
// note that we use the original content, not the trimmed one because we
|
||||||
// don't want to eat any whitespace in the real content if there were no
|
// don't want to eat any whitespace in the real content if there were no
|
||||||
// thinking tags
|
// thinking tags
|
||||||
return "", s.acc.String(), false
|
untrimmed := s.acc.String()
|
||||||
|
s.acc.Reset()
|
||||||
|
return "", untrimmed, false
|
||||||
}
|
}
|
||||||
case thinkingState_ThinkingStartedEatingWhitespace:
|
case thinkingState_ThinkingStartedEatingWhitespace:
|
||||||
trimmed := strings.TrimLeftFunc(s.acc.String(), unicode.IsSpace)
|
trimmed := strings.TrimLeftFunc(s.acc.String(), unicode.IsSpace)
|
||||||
|
|||||||
@@ -58,6 +58,15 @@ func TestThinkingStreaming(t *testing.T) {
|
|||||||
wantContent: " abc",
|
wantContent: " abc",
|
||||||
wantStateAfter: thinkingState_ThinkingDone,
|
wantStateAfter: thinkingState_ThinkingDone,
|
||||||
},
|
},
|
||||||
|
// regression test for a bug where we were transitioning directly to
|
||||||
|
// ThinkingDone without clearing the buffer. This would cuase the first
|
||||||
|
// step to be outputted twice
|
||||||
|
{
|
||||||
|
input: "def",
|
||||||
|
wantThinking: "",
|
||||||
|
wantContent: "def",
|
||||||
|
wantStateAfter: thinkingState_ThinkingDone,
|
||||||
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
|||||||
@@ -224,22 +224,45 @@ func findArguments(buffer []byte) (map[string]any, int) {
|
|||||||
return nil, 0
|
return nil, 0
|
||||||
}
|
}
|
||||||
|
|
||||||
|
start := -1
|
||||||
var braces int
|
var braces int
|
||||||
var start int = -1
|
var inString, escaped bool
|
||||||
|
|
||||||
|
for i := range buffer {
|
||||||
|
c := buffer[i]
|
||||||
|
|
||||||
|
if escaped {
|
||||||
|
escaped = false
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if c == '\\' {
|
||||||
|
escaped = true
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if c == '"' {
|
||||||
|
inString = !inString
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if inString {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
for i, c := range buffer {
|
|
||||||
if c == '{' {
|
if c == '{' {
|
||||||
if braces == 0 {
|
if braces == 0 {
|
||||||
start = i
|
start = i
|
||||||
}
|
}
|
||||||
braces++
|
braces++
|
||||||
} else if c == '}' && braces > 0 {
|
} else if c == '}' {
|
||||||
braces--
|
braces--
|
||||||
if braces == 0 && start != -1 {
|
if braces == 0 && start != -1 {
|
||||||
object := buffer[start : i+1]
|
object := buffer[start : i+1]
|
||||||
|
|
||||||
var data map[string]any
|
var data map[string]any
|
||||||
if err := json.Unmarshal(object, &data); err != nil {
|
if err := json.Unmarshal(object, &data); err != nil {
|
||||||
|
// not a valid object, keep looking
|
||||||
start = -1
|
start = -1
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@@ -282,6 +305,10 @@ func findArguments(buffer []byte) (map[string]any, int) {
|
|||||||
|
|
||||||
return data, i
|
return data, i
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if braces < 0 {
|
||||||
|
braces = 0
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package tools
|
package tools
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
"text/template"
|
"text/template"
|
||||||
|
|
||||||
@@ -40,13 +41,7 @@ func TestParser(t *testing.T) {
|
|||||||
Function: api.ToolFunction{
|
Function: api.ToolFunction{
|
||||||
Name: "get_temperature",
|
Name: "get_temperature",
|
||||||
Description: "Retrieve the temperature for a given location",
|
Description: "Retrieve the temperature for a given location",
|
||||||
Parameters: struct {
|
Parameters: api.ToolFunctionParameters{
|
||||||
Type string `json:"type"`
|
|
||||||
Defs any `json:"$defs,omitempty"`
|
|
||||||
Items any `json:"items,omitempty"`
|
|
||||||
Required []string `json:"required"`
|
|
||||||
Properties map[string]api.ToolProperty `json:"properties"`
|
|
||||||
}{
|
|
||||||
Type: "object",
|
Type: "object",
|
||||||
Required: []string{"city"},
|
Required: []string{"city"},
|
||||||
Properties: map[string]api.ToolProperty{
|
Properties: map[string]api.ToolProperty{
|
||||||
@@ -68,13 +63,7 @@ func TestParser(t *testing.T) {
|
|||||||
Function: api.ToolFunction{
|
Function: api.ToolFunction{
|
||||||
Name: "get_conditions",
|
Name: "get_conditions",
|
||||||
Description: "Retrieve the current weather conditions for a given location",
|
Description: "Retrieve the current weather conditions for a given location",
|
||||||
Parameters: struct {
|
Parameters: api.ToolFunctionParameters{
|
||||||
Type string `json:"type"`
|
|
||||||
Defs any `json:"$defs,omitempty"`
|
|
||||||
Items any `json:"items,omitempty"`
|
|
||||||
Required []string `json:"required"`
|
|
||||||
Properties map[string]api.ToolProperty `json:"properties"`
|
|
||||||
}{
|
|
||||||
Type: "object",
|
Type: "object",
|
||||||
Properties: map[string]api.ToolProperty{
|
Properties: map[string]api.ToolProperty{
|
||||||
"location": {
|
"location": {
|
||||||
@@ -104,13 +93,7 @@ func TestParser(t *testing.T) {
|
|||||||
Function: api.ToolFunction{
|
Function: api.ToolFunction{
|
||||||
Name: "get_address",
|
Name: "get_address",
|
||||||
Description: "Get the address of a given location",
|
Description: "Get the address of a given location",
|
||||||
Parameters: struct {
|
Parameters: api.ToolFunctionParameters{
|
||||||
Type string `json:"type"`
|
|
||||||
Defs any `json:"$defs,omitempty"`
|
|
||||||
Items any `json:"items,omitempty"`
|
|
||||||
Required []string `json:"required"`
|
|
||||||
Properties map[string]api.ToolProperty `json:"properties"`
|
|
||||||
}{
|
|
||||||
Type: "object",
|
Type: "object",
|
||||||
Properties: map[string]api.ToolProperty{
|
Properties: map[string]api.ToolProperty{
|
||||||
"location": {
|
"location": {
|
||||||
@@ -126,13 +109,7 @@ func TestParser(t *testing.T) {
|
|||||||
Function: api.ToolFunction{
|
Function: api.ToolFunction{
|
||||||
Name: "add",
|
Name: "add",
|
||||||
Description: "Add two numbers",
|
Description: "Add two numbers",
|
||||||
Parameters: struct {
|
Parameters: api.ToolFunctionParameters{
|
||||||
Type string `json:"type"`
|
|
||||||
Defs any `json:"$defs,omitempty"`
|
|
||||||
Items any `json:"items,omitempty"`
|
|
||||||
Required []string `json:"required"`
|
|
||||||
Properties map[string]api.ToolProperty `json:"properties"`
|
|
||||||
}{
|
|
||||||
Type: "object",
|
Type: "object",
|
||||||
Properties: map[string]api.ToolProperty{
|
Properties: map[string]api.ToolProperty{
|
||||||
"a": {
|
"a": {
|
||||||
@@ -1140,11 +1117,163 @@ func TestFindArguments(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "deepseek",
|
name: "deepseek",
|
||||||
buffer: []byte(`", "arguments": {"location": "Tokyo"}}</tool_call>`),
|
buffer: []byte(`"arguments": {"location": "Tokyo"}}</tool_call>`),
|
||||||
want: map[string]any{
|
want: map[string]any{
|
||||||
"location": "Tokyo",
|
"location": "Tokyo",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "string with braces",
|
||||||
|
buffer: []byte(`{"name": "process_code", "arguments": {"code": "if (x > 0) { return true; }"}}`),
|
||||||
|
want: map[string]any{
|
||||||
|
"code": "if (x > 0) { return true; }",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "string with nested json",
|
||||||
|
buffer: []byte(`{"name": "send_data", "arguments": {"payload": "{\"nested\": {\"key\": \"value\"}}"}}`),
|
||||||
|
want: map[string]any{
|
||||||
|
"payload": `{"nested": {"key": "value"}}`,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "string with escaped quotes and braces",
|
||||||
|
buffer: []byte(`{"name": "analyze", "arguments": {"text": "The JSON is: {\"key\": \"val{ue}\"}"}}`),
|
||||||
|
want: map[string]any{
|
||||||
|
"text": `The JSON is: {"key": "val{ue}"}`,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "multiple objects with string containing braces",
|
||||||
|
buffer: []byte(`{"name": "test", "arguments": {"query": "find } in text"}} {"name": "other"}`),
|
||||||
|
want: map[string]any{
|
||||||
|
"query": "find } in text",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "unmatched closing brace in string",
|
||||||
|
buffer: []byte(`{"name": "search", "arguments": {"pattern": "regex: }"}}`),
|
||||||
|
want: map[string]any{
|
||||||
|
"pattern": "regex: }",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "complex nested with mixed braces",
|
||||||
|
buffer: []byte(`{"name": "analyze", "arguments": {"data": "{\"items\": [{\"value\": \"}\"}, {\"code\": \"if (x) { return y; }\"}]}"}}`),
|
||||||
|
want: map[string]any{
|
||||||
|
"data": `{"items": [{"value": "}"}, {"code": "if (x) { return y; }"}]}`,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "string with newline and braces",
|
||||||
|
buffer: []byte(`{"name": "format", "arguments": {"template": "{\n \"key\": \"value\"\n}"}}`),
|
||||||
|
want: map[string]any{
|
||||||
|
"template": "{\n \"key\": \"value\"\n}",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "string with unicode escape",
|
||||||
|
buffer: []byte(`{"name": "test", "arguments": {"text": "Unicode: \u007B and \u007D"}}`),
|
||||||
|
want: map[string]any{
|
||||||
|
"text": "Unicode: { and }",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "array arguments",
|
||||||
|
buffer: []byte(`{"name": "batch", "arguments": ["item1", "item2", "{\"nested\": true}"]}`),
|
||||||
|
want: nil, // This should return nil because arguments is not a map
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "escaped backslash before quote",
|
||||||
|
buffer: []byte(`{"name": "path", "arguments": {"dir": "C:\\Program Files\\{App}\\"}}`),
|
||||||
|
want: map[string]any{
|
||||||
|
"dir": `C:\Program Files\{App}\`,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "single quotes not treated as string delimiters",
|
||||||
|
buffer: []byte(`{"name": "query", "arguments": {"sql": "SELECT * FROM users WHERE name = '{admin}'"}}`),
|
||||||
|
want: map[string]any{
|
||||||
|
"sql": "SELECT * FROM users WHERE name = '{admin}'",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "incomplete json at buffer end",
|
||||||
|
buffer: []byte(`{"name": "test", "arguments": {"data": "some {"`),
|
||||||
|
want: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "multiple escaped quotes",
|
||||||
|
buffer: []byte(`{"name": "echo", "arguments": {"msg": "He said \"Hello {World}\" loudly"}}`),
|
||||||
|
want: map[string]any{
|
||||||
|
"msg": `He said "Hello {World}" loudly`,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "json with comments style string",
|
||||||
|
buffer: []byte(`{"name": "code", "arguments": {"snippet": "// This is a comment with { and }"}}`),
|
||||||
|
want: map[string]any{
|
||||||
|
"snippet": "// This is a comment with { and }",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "consecutive escaped backslashes",
|
||||||
|
buffer: []byte(`{"name": "test", "arguments": {"path": "C:\\\\{folder}\\\\"}}`),
|
||||||
|
want: map[string]any{
|
||||||
|
"path": `C:\\{folder}\\`,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty string with braces after",
|
||||||
|
buffer: []byte(`{"name": "test", "arguments": {"a": "", "b": "{value}"}}`),
|
||||||
|
want: map[string]any{
|
||||||
|
"a": "",
|
||||||
|
"b": "{value}",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "unicode in key names",
|
||||||
|
buffer: []byte(`{"name": "test", "arguments": {"key{": "value", "key}": "value2"}}`),
|
||||||
|
want: map[string]any{
|
||||||
|
"key{": "value",
|
||||||
|
"key}": "value2",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "very long string with braces",
|
||||||
|
buffer: []byte(`{"name": "test", "arguments": {"data": "` + strings.Repeat("a{b}c", 100) + `"}}`),
|
||||||
|
want: map[string]any{
|
||||||
|
"data": strings.Repeat("a{b}c", 100),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "tab characters and braces",
|
||||||
|
buffer: []byte(`{"name": "test", "arguments": {"code": "\tif (true) {\n\t\treturn;\n\t}"}}`),
|
||||||
|
want: map[string]any{
|
||||||
|
"code": "\tif (true) {\n\t\treturn;\n\t}",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "null byte in string",
|
||||||
|
buffer: []byte(`{"name": "test", "arguments": {"data": "before\u0000{after}"}}`),
|
||||||
|
want: map[string]any{
|
||||||
|
"data": "before\x00{after}",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "escaped quote at end of string",
|
||||||
|
buffer: []byte(`{"name": "test", "arguments": {"data": "text with quote at end\\\""}}`),
|
||||||
|
want: map[string]any{
|
||||||
|
"data": `text with quote at end\"`,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "mixed array and object in arguments",
|
||||||
|
buffer: []byte(`{"name": "test", "arguments": {"items": ["{", "}", {"key": "value"}]}}`),
|
||||||
|
want: map[string]any{
|
||||||
|
"items": []any{"{", "}", map[string]any{"key": "value"}},
|
||||||
|
},
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
|
|||||||
Reference in New Issue
Block a user