Compare commits

..

4 Commits

Author SHA1 Message Date
ParthSareen
b4cd1118ab checkpoint for vscode 2025-04-24 18:23:23 -07:00
ParthSareen
128c90d3ac checkpoint!!! 2025-04-24 16:57:54 -07:00
ParthSareen
f5872a097c checkpoint 2025-04-23 15:45:35 -07:00
ParthSareen
3ac5e0f102 model: update tool calling to use regex 2025-04-14 17:35:17 -07:00
421 changed files with 46011 additions and 73051 deletions

View File

@@ -103,18 +103,21 @@ jobs:
arch: [amd64] arch: [amd64]
preset: ['CPU'] preset: ['CPU']
include: include:
- os: windows
arch: amd64
preset: 'CUDA 11'
install: https://developer.download.nvidia.com/compute/cuda/11.3.1/local_installers/cuda_11.3.1_465.89_win10.exe
cuda-version: '11.3'
- os: windows - os: windows
arch: amd64 arch: amd64
preset: 'CUDA 12' preset: 'CUDA 12'
install: https://developer.download.nvidia.com/compute/cuda/12.8.0/local_installers/cuda_12.8.0_571.96_windows.exe install: https://developer.download.nvidia.com/compute/cuda/12.8.0/local_installers/cuda_12.8.0_571.96_windows.exe
cuda-version: '12.8' cuda-version: '12.8'
flags: ''
- os: windows - os: windows
arch: amd64 arch: amd64
preset: 'ROCm 6' preset: 'ROCm 6'
install: https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-24.Q4-WinSvr2022-For-HIP.exe install: https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-24.Q4-WinSvr2022-For-HIP.exe
rocm-version: '6.2' rocm-version: '6.2'
flags: '-DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_C_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma" -DCMAKE_CXX_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma"'
runs-on: ${{ matrix.arch == 'arm64' && format('{0}-{1}', matrix.os, matrix.arch) || matrix.os }} runs-on: ${{ matrix.arch == 'arm64' && format('{0}-{1}', matrix.os, matrix.arch) || matrix.os }}
environment: release environment: release
env: env:
@@ -157,9 +160,6 @@ jobs:
echo "$hipPath\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append echo "$hipPath\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
echo "CC=$hipPath\bin\clang.exe" | Out-File -FilePath $env:GITHUB_ENV -Append echo "CC=$hipPath\bin\clang.exe" | Out-File -FilePath $env:GITHUB_ENV -Append
echo "CXX=$hipPath\bin\clang++.exe" | Out-File -FilePath $env:GITHUB_ENV -Append echo "CXX=$hipPath\bin\clang++.exe" | Out-File -FilePath $env:GITHUB_ENV -Append
echo "HIPCXX=$hipPath\bin\clang++.exe" | Out-File -FilePath $env:GITHUB_ENV -Append
echo "HIP_PLATFORM=amd" | Out-File -FilePath $env:GITHUB_ENV -Append
echo "CMAKE_PREFIX_PATH=$hipPath" | Out-File -FilePath $env:GITHUB_ENV -Append
- if: matrix.preset == 'CPU' - if: matrix.preset == 'CPU'
run: | run: |
echo "CC=clang.exe" | Out-File -FilePath $env:GITHUB_ENV -Append echo "CC=clang.exe" | Out-File -FilePath $env:GITHUB_ENV -Append
@@ -178,9 +178,9 @@ jobs:
key: ccache-${{ matrix.os }}-${{ matrix.arch }}-${{ matrix.preset }} key: ccache-${{ matrix.os }}-${{ matrix.arch }}-${{ matrix.preset }}
- name: Build target "${{ matrix.preset }}" - name: Build target "${{ matrix.preset }}"
run: | run: |
Import-Module 'C:\Program Files\Microsoft Visual Studio\2022\Enterprise\Common7\Tools\Microsoft.VisualStudio.DevShell.dll' Import-Module 'C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\Common7\Tools\Microsoft.VisualStudio.DevShell.dll'
Enter-VsDevShell -VsInstallPath 'C:\Program Files\Microsoft Visual Studio\2022\Enterprise' -SkipAutomaticLocation -DevCmdArguments '-arch=x64 -no_logo' Enter-VsDevShell -VsInstallPath 'C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise' -SkipAutomaticLocation -DevCmdArguments '-arch=x64 -no_logo'
cmake --preset "${{ matrix.preset }}" ${{ matrix.flags }} cmake --preset "${{ matrix.preset }}"
cmake --build --parallel --preset "${{ matrix.preset }}" cmake --build --parallel --preset "${{ matrix.preset }}"
cmake --install build --component "${{ startsWith(matrix.preset, 'CUDA ') && 'CUDA' || startsWith(matrix.preset, 'ROCm ') && 'HIP' || 'CPU' }}" --strip --parallel 8 cmake --install build --component "${{ startsWith(matrix.preset, 'CUDA ') && 'CUDA' || startsWith(matrix.preset, 'ROCm ') && 'HIP' || 'CPU' }}" --strip --parallel 8
env: env:
@@ -246,7 +246,7 @@ jobs:
dist\${{ matrix.os }}-${{ matrix.arch }}-app.exe dist\${{ matrix.os }}-${{ matrix.arch }}-app.exe
windows-sign: windows-sign:
runs-on: windows runs-on: windows-2022
environment: release environment: release
needs: [windows-depends, windows-build] needs: [windows-depends, windows-build]
steps: steps:
@@ -322,21 +322,16 @@ jobs:
- run: | - run: |
for COMPONENT in bin/* lib/ollama/*; do for COMPONENT in bin/* lib/ollama/*; do
case "$COMPONENT" in case "$COMPONENT" in
bin/ollama) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;; bin/ollama) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
lib/ollama/*.so*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;; lib/ollama/*.so) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
lib/ollama/cuda_sbsa) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;; lib/ollama/cuda_v11) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
lib/ollama/cuda_jetpack5) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-jetpack5.tar.in ;; lib/ollama/cuda_v12) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
lib/ollama/cuda_jetpack6) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-jetpack6.tar.in ;; lib/ollama/cuda_jetpack5) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-jetpack5.tar.in ;;
lib/ollama/rocm) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-rocm.tar.in ;; lib/ollama/cuda_jetpack6) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-jetpack6.tar.in ;;
lib/ollama/rocm) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-rocm.tar.in ;;
esac esac
done done
working-directory: dist/${{ matrix.os }}-${{ matrix.arch }} working-directory: dist/${{ matrix.os }}-${{ matrix.arch }}
- run: |
echo "Manifests"
for ARCHIVE in dist/${{ matrix.os }}-${{ matrix.arch }}/*.tar.in ; do
echo $ARCHIVE
cat $ARCHIVE
done
- run: | - run: |
for ARCHIVE in dist/${{ matrix.os }}-${{ matrix.arch }}/*.tar.in; do for ARCHIVE in dist/${{ matrix.os }}-${{ matrix.arch }}/*.tar.in; do
tar c -C dist/${{ matrix.os }}-${{ matrix.arch }} -T $ARCHIVE --owner 0 --group 0 | pigz -9vc >$(basename ${ARCHIVE//.*/}.tgz); tar c -C dist/${{ matrix.os }}-${{ matrix.arch }} -T $ARCHIVE --owner 0 --group 0 | pigz -9vc >$(basename ${ARCHIVE//.*/}.tgz);
@@ -437,22 +432,6 @@ jobs:
docker buildx imagetools inspect ollama/ollama:${{ steps.metadata.outputs.version }} docker buildx imagetools inspect ollama/ollama:${{ steps.metadata.outputs.version }}
working-directory: ${{ runner.temp }} working-directory: ${{ runner.temp }}
# Trigger downstream release process
trigger:
runs-on: ubuntu-latest
environment: release
needs: [darwin-build, windows-build, windows-depends]
steps:
- name: Trigger downstream release process
run: |
curl -L \
-X POST \
-H "Accept: application/vnd.github+json" \
-H "Authorization: Bearer ${{ secrets.RELEASE_TOKEN }}" \
-H "X-GitHub-Api-Version: 2022-11-28" \
https://api.github.com/repos/ollama/${{ vars.RELEASE_REPO }}/dispatches \
-d "{\"event_type\": \"trigger-workflow\", \"client_payload\": {\"run_id\": \"${GITHUB_RUN_ID}\", \"version\": \"${GITHUB_REF_NAME#v}\"}}"
# Aggregate all the assets and ship a release # Aggregate all the assets and ship a release
release: release:
needs: [darwin-sign, windows-sign, linux-build] needs: [darwin-sign, windows-sign, linux-build]
@@ -475,18 +454,8 @@ jobs:
- uses: actions/download-artifact@v4 - uses: actions/download-artifact@v4
with: with:
pattern: dist-linux-* pattern: dist-linux-*
path: stage path: dist
merge-multiple: false merge-multiple: true
- name: Merge linux amd64 payload
working-directory: stage/dist-linux-amd64-archive
run: |
tar zxf ollama-linux-amd64.tgz
tar zxf ../dist-linux-amd64-rocm/ollama-linux-amd64.tgz
rm -f ollama-linux-amd64.tgz ../dist-linux-amd64-rocm/ollama-linux-amd64.tgz
tar -c -f- --owner 0 --group 0 . | pigz -9vc > ../ollama-linux-amd64.tgz
- name: Cleanup linux payloads
run: |
find stage -name ollama-linux\*.tgz -exec mv {} dist/ \;
- run: find . -type f -not -name 'sha256sum.txt' | xargs sha256sum | tee sha256sum.txt - run: find . -type f -not -name 'sha256sum.txt' | xargs sha256sum | tee sha256sum.txt
working-directory: dist working-directory: dist
- name: Create or update Release - name: Create or update Release

View File

@@ -36,7 +36,7 @@ jobs:
| xargs python3 -c "import sys; from pathlib import Path; print(any(Path(x).match(glob) for x in sys.argv[1:] for glob in '$*'.split(' ')))" | xargs python3 -c "import sys; from pathlib import Path; print(any(Path(x).match(glob) for x in sys.argv[1:] for glob in '$*'.split(' ')))"
} }
echo changed=$(changed 'llama/llama.cpp/**/*' 'ml/backend/ggml/ggml/**/*') | tee -a $GITHUB_OUTPUT echo changed=$(changed 'llama/llama.cpp/**' 'ml/backend/ggml/ggml/**') | tee -a $GITHUB_OUTPUT
linux: linux:
needs: [changes] needs: [changes]
@@ -46,7 +46,7 @@ jobs:
include: include:
- preset: CPU - preset: CPU
- preset: CUDA - preset: CUDA
container: nvidia/cuda:12.8.1-devel-ubuntu22.04 container: nvidia/cuda:11.8.0-devel-ubuntu22.04
flags: '-DCMAKE_CUDA_ARCHITECTURES=87' flags: '-DCMAKE_CUDA_ARCHITECTURES=87'
- preset: ROCm - preset: ROCm
container: rocm/dev-ubuntu-22.04:6.1.2 container: rocm/dev-ubuntu-22.04:6.1.2
@@ -78,11 +78,11 @@ jobs:
include: include:
- preset: CPU - preset: CPU
- preset: CUDA - preset: CUDA
install: https://developer.download.nvidia.com/compute/cuda/12.8.0/local_installers/cuda_12.8.0_571.96_windows.exe install: https://developer.download.nvidia.com/compute/cuda/11.3.1/local_installers/cuda_11.3.1_465.89_win10.exe
flags: '-DCMAKE_CUDA_ARCHITECTURES=80' flags: '-DCMAKE_CUDA_ARCHITECTURES=80'
- preset: ROCm - preset: ROCm
install: https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-24.Q4-WinSvr2022-For-HIP.exe install: https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-24.Q4-WinSvr2022-For-HIP.exe
flags: '-DAMDGPU_TARGETS=gfx1010 -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_C_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma" -DCMAKE_CXX_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma"' flags: '-DAMDGPU_TARGETS=gfx1010'
runs-on: windows runs-on: windows
steps: steps:
- run: | - run: |
@@ -102,7 +102,7 @@ jobs:
$ErrorActionPreference = "Stop" $ErrorActionPreference = "Stop"
if ("${{ steps.cache-install.outputs.cache-hit }}" -ne 'true') { if ("${{ steps.cache-install.outputs.cache-hit }}" -ne 'true') {
Invoke-WebRequest -Uri "${{ matrix.install }}" -OutFile "install.exe" Invoke-WebRequest -Uri "${{ matrix.install }}" -OutFile "install.exe"
Start-Process -FilePath .\install.exe -ArgumentList (@("-s", "cudart_12.8", "nvcc_12.8", "cublas_12.8", "cublas_dev_12.8")) -NoNewWindow -Wait Start-Process -FilePath .\install.exe -ArgumentList (@("-s", "cudart_11.3", "nvcc_11.3", "cublas_11.3", "cublas_dev_11.3")) -NoNewWindow -Wait
} }
$cudaPath = (Resolve-Path "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\*").path $cudaPath = (Resolve-Path "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\*").path
@@ -120,9 +120,6 @@ jobs:
echo "$hipPath\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append echo "$hipPath\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
echo "CC=$hipPath\bin\clang.exe" | Out-File -FilePath $env:GITHUB_ENV -Append echo "CC=$hipPath\bin\clang.exe" | Out-File -FilePath $env:GITHUB_ENV -Append
echo "CXX=$hipPath\bin\clang++.exe" | Out-File -FilePath $env:GITHUB_ENV -Append echo "CXX=$hipPath\bin\clang++.exe" | Out-File -FilePath $env:GITHUB_ENV -Append
echo "HIPCXX=$hipPath\bin\clang++.exe" | Out-File -FilePath $env:GITHUB_ENV -Append
echo "HIP_PLATFORM=amd" | Out-File -FilePath $env:GITHUB_ENV -Append
echo "CMAKE_PREFIX_PATH=$hipPath" | Out-File -FilePath $env:GITHUB_ENV -Append
- if: ${{ !cancelled() && steps.cache-install.outputs.cache-hit != 'true' }} - if: ${{ !cancelled() && steps.cache-install.outputs.cache-hit != 'true' }}
uses: actions/cache/save@v4 uses: actions/cache/save@v4
with: with:
@@ -136,8 +133,8 @@ jobs:
path: ${{ github.workspace }}\.ccache path: ${{ github.workspace }}\.ccache
key: ccache-${{ runner.os }}-${{ runner.arch }}-${{ matrix.preset }} key: ccache-${{ runner.os }}-${{ runner.arch }}-${{ matrix.preset }}
- run: | - run: |
Import-Module 'C:\Program Files\Microsoft Visual Studio\2022\Enterprise\Common7\Tools\Microsoft.VisualStudio.DevShell.dll' Import-Module 'C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\Common7\Tools\Microsoft.VisualStudio.DevShell.dll'
Enter-VsDevShell -VsInstallPath 'C:\Program Files\Microsoft Visual Studio\2022\Enterprise' -SkipAutomaticLocation -DevCmdArguments '-arch=x64 -no_logo' Enter-VsDevShell -VsInstallPath 'C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise' -SkipAutomaticLocation -DevCmdArguments '-arch=x64 -no_logo'
cmake --preset "${{ matrix.preset }}" ${{ matrix.flags }} cmake --preset "${{ matrix.preset }}" ${{ matrix.flags }}
cmake --build --parallel --preset "${{ matrix.preset }}" cmake --build --parallel --preset "${{ matrix.preset }}"
env: env:
@@ -240,5 +237,5 @@ jobs:
- uses: actions/checkout@v4 - uses: actions/checkout@v4
- name: Verify patches apply cleanly and do not change files - name: Verify patches apply cleanly and do not change files
run: | run: |
make -f Makefile.sync clean checkout apply-patches sync make -f Makefile.sync clean sync
git diff --compact-summary --exit-code git diff --compact-summary --exit-code

View File

@@ -19,8 +19,8 @@ linters:
- nolintlint - nolintlint
- nosprintfhostport - nosprintfhostport
- staticcheck - staticcheck
- tenv
- unconvert - unconvert
- usetesting
- wastedassign - wastedassign
- whitespace - whitespace
disable: disable:

View File

@@ -24,7 +24,6 @@ set(GGML_LLAMAFILE ON)
set(GGML_CUDA_PEER_MAX_BATCH_SIZE 128) set(GGML_CUDA_PEER_MAX_BATCH_SIZE 128)
set(GGML_CUDA_GRAPHS ON) set(GGML_CUDA_GRAPHS ON)
set(GGML_CUDA_FA ON) set(GGML_CUDA_FA ON)
set(GGML_CUDA_COMPRESSION_MODE default)
if((CMAKE_OSX_ARCHITECTURES AND NOT CMAKE_OSX_ARCHITECTURES MATCHES "arm64") if((CMAKE_OSX_ARCHITECTURES AND NOT CMAKE_OSX_ARCHITECTURES MATCHES "arm64")
OR (NOT CMAKE_OSX_ARCHITECTURES AND NOT CMAKE_SYSTEM_PROCESSOR MATCHES "arm|aarch64|ARM64|ARMv[0-9]+")) OR (NOT CMAKE_OSX_ARCHITECTURES AND NOT CMAKE_SYSTEM_PROCESSOR MATCHES "arm|aarch64|ARM64|ARMv[0-9]+"))
@@ -51,8 +50,6 @@ include_directories(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/include
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-cpu) include_directories(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-cpu)
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-cpu/amx) include_directories(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-cpu/amx)
add_compile_definitions(NDEBUG)
set(GGML_CPU ON) set(GGML_CPU ON)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src)
set_property(TARGET ggml PROPERTY EXCLUDE_FROM_ALL TRUE) set_property(TARGET ggml PROPERTY EXCLUDE_FROM_ALL TRUE)
@@ -78,13 +75,14 @@ if(CMAKE_CUDA_COMPILER)
find_package(CUDAToolkit) find_package(CUDAToolkit)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-cuda) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-cuda)
set(OLLAMA_CUDA_INSTALL_DIR ${OLLAMA_INSTALL_DIR}/cuda_v${CUDAToolkit_VERSION_MAJOR})
install(TARGETS ggml-cuda install(TARGETS ggml-cuda
RUNTIME_DEPENDENCIES RUNTIME_DEPENDENCIES
DIRECTORIES ${CUDAToolkit_BIN_DIR} ${CUDAToolkit_LIBRARY_DIR} DIRECTORIES ${CUDAToolkit_BIN_DIR} ${CUDAToolkit_LIBRARY_DIR}
PRE_INCLUDE_REGEXES cublas cublasLt cudart PRE_INCLUDE_REGEXES cublas cublasLt cudart
PRE_EXCLUDE_REGEXES ".*" PRE_EXCLUDE_REGEXES ".*"
RUNTIME DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT CUDA RUNTIME DESTINATION ${OLLAMA_CUDA_INSTALL_DIR} COMPONENT CUDA
LIBRARY DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT CUDA LIBRARY DESTINATION ${OLLAMA_CUDA_INSTALL_DIR} COMPONENT CUDA
) )
endif() endif()
@@ -115,11 +113,7 @@ if(CMAKE_HIP_COMPILER)
set(OLLAMA_HIP_INSTALL_DIR ${OLLAMA_INSTALL_DIR}/rocm) set(OLLAMA_HIP_INSTALL_DIR ${OLLAMA_INSTALL_DIR}/rocm)
install(TARGETS ggml-hip install(TARGETS ggml-hip
RUNTIME_DEPENDENCY_SET rocm RUNTIME_DEPENDENCIES
RUNTIME DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT HIP
LIBRARY DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT HIP
)
install(RUNTIME_DEPENDENCY_SET rocm
DIRECTORIES ${HIP_BIN_INSTALL_DIR} ${HIP_LIB_INSTALL_DIR} DIRECTORIES ${HIP_BIN_INSTALL_DIR} ${HIP_LIB_INSTALL_DIR}
PRE_INCLUDE_REGEXES hipblas rocblas amdhip64 rocsolver amd_comgr hsa-runtime64 rocsparse tinfo rocprofiler-register drm drm_amdgpu numa elf PRE_INCLUDE_REGEXES hipblas rocblas amdhip64 rocsolver amd_comgr hsa-runtime64 rocsparse tinfo rocprofiler-register drm drm_amdgpu numa elf
PRE_EXCLUDE_REGEXES ".*" PRE_EXCLUDE_REGEXES ".*"

View File

@@ -17,12 +17,18 @@
"name": "CUDA", "name": "CUDA",
"inherits": [ "Default" ] "inherits": [ "Default" ]
}, },
{
"name": "CUDA 11",
"inherits": [ "CUDA" ],
"cacheVariables": {
"CMAKE_CUDA_ARCHITECTURES": "50;52;53;60;61;70;75;80;86"
}
},
{ {
"name": "CUDA 12", "name": "CUDA 12",
"inherits": [ "CUDA" ], "inherits": [ "CUDA" ],
"cacheVariables": { "cacheVariables": {
"CMAKE_CUDA_ARCHITECTURES": "50;60;61;70;75;80;86;87;89;90;90a;120", "CMAKE_CUDA_ARCHITECTURES": "50;60;61;70;75;80;86;87;89;90;90a;120"
"CMAKE_CUDA_FLAGS": "-Wno-deprecated-gpu-targets -t 2"
} }
}, },
{ {
@@ -50,7 +56,6 @@
"name": "ROCm 6", "name": "ROCm 6",
"inherits": [ "ROCm" ], "inherits": [ "ROCm" ],
"cacheVariables": { "cacheVariables": {
"CMAKE_HIP_FLAGS": "-parallel-jobs=4",
"AMDGPU_TARGETS": "gfx900;gfx940;gfx941;gfx942;gfx1010;gfx1012;gfx1030;gfx1100;gfx1101;gfx1102;gfx1151;gfx1200;gfx1201;gfx906:xnack-;gfx908:xnack-;gfx90a:xnack+;gfx90a:xnack-" "AMDGPU_TARGETS": "gfx900;gfx940;gfx941;gfx942;gfx1010;gfx1012;gfx1030;gfx1100;gfx1101;gfx1102;gfx1151;gfx1200;gfx1201;gfx906:xnack-;gfx908:xnack-;gfx90a:xnack+;gfx90a:xnack-"
} }
} }
@@ -71,6 +76,11 @@
"configurePreset": "CUDA", "configurePreset": "CUDA",
"targets": [ "ggml-cuda" ] "targets": [ "ggml-cuda" ]
}, },
{
"name": "CUDA 11",
"inherits": [ "CUDA" ],
"configurePreset": "CUDA 11"
},
{ {
"name": "CUDA 12", "name": "CUDA 12",
"inherits": [ "CUDA" ], "inherits": [ "CUDA" ],

View File

@@ -7,13 +7,12 @@ ARG JETPACK5VERSION=r35.4.1
ARG JETPACK6VERSION=r36.4.0 ARG JETPACK6VERSION=r36.4.0
ARG CMAKEVERSION=3.31.2 ARG CMAKEVERSION=3.31.2
# We require gcc v10 minimum. v10.3 has regressions, so the rockylinux 8.5 AppStream has the latest compatible version # CUDA v11 requires gcc v10. v10.3 has regressions, so the rockylinux 8.5 AppStream has the latest compatible version
FROM --platform=linux/amd64 rocm/dev-almalinux-8:${ROCMVERSION}-complete AS base-amd64 FROM --platform=linux/amd64 rocm/dev-almalinux-8:${ROCMVERSION}-complete AS base-amd64
RUN yum install -y yum-utils \ RUN yum install -y yum-utils \
&& yum-config-manager --add-repo https://dl.rockylinux.org/vault/rocky/8.5/AppStream/\$basearch/os/ \ && yum-config-manager --add-repo https://dl.rockylinux.org/vault/rocky/8.5/AppStream/\$basearch/os/ \
&& rpm --import https://dl.rockylinux.org/pub/rocky/RPM-GPG-KEY-Rocky-8 \ && rpm --import https://dl.rockylinux.org/pub/rocky/RPM-GPG-KEY-Rocky-8 \
&& dnf install -y yum-utils ccache gcc-toolset-10-gcc-10.2.1-8.2.el8 gcc-toolset-10-gcc-c++-10.2.1-8.2.el8 gcc-toolset-10-binutils-2.35-11.el8 \ && dnf install -y yum-utils ccache gcc-toolset-10-gcc-10.2.1-8.2.el8 gcc-toolset-10-gcc-c++-10.2.1-8.2.el8 gcc-toolset-10-binutils-2.35-11.el8 \
&& dnf install -y ccache \
&& yum-config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel8/x86_64/cuda-rhel8.repo && yum-config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel8/x86_64/cuda-rhel8.repo
ENV PATH=/opt/rh/gcc-toolset-10/root/usr/bin:$PATH ENV PATH=/opt/rh/gcc-toolset-10/root/usr/bin:$PATH
@@ -39,6 +38,15 @@ RUN --mount=type=cache,target=/root/.ccache \
&& cmake --build --parallel --preset 'CPU' \ && cmake --build --parallel --preset 'CPU' \
&& cmake --install build --component CPU --strip --parallel 8 && cmake --install build --component CPU --strip --parallel 8
FROM base AS cuda-11
ARG CUDA11VERSION=11.3
RUN dnf install -y cuda-toolkit-${CUDA11VERSION//./-}
ENV PATH=/usr/local/cuda-11/bin:$PATH
RUN --mount=type=cache,target=/root/.ccache \
cmake --preset 'CUDA 11' \
&& cmake --build --parallel --preset 'CUDA 11' \
&& cmake --install build --component CUDA --strip --parallel 8
FROM base AS cuda-12 FROM base AS cuda-12
ARG CUDA12VERSION=12.8 ARG CUDA12VERSION=12.8
RUN dnf install -y cuda-toolkit-${CUDA12VERSION//./-} RUN dnf install -y cuda-toolkit-${CUDA12VERSION//./-}
@@ -90,15 +98,17 @@ RUN --mount=type=cache,target=/root/.cache/go-build \
go build -trimpath -buildmode=pie -o /bin/ollama . go build -trimpath -buildmode=pie -o /bin/ollama .
FROM --platform=linux/amd64 scratch AS amd64 FROM --platform=linux/amd64 scratch AS amd64
COPY --from=cuda-12 dist/lib/ollama /lib/ollama COPY --from=cuda-11 dist/lib/ollama/cuda_v11 /lib/ollama/cuda_v11
COPY --from=cuda-12 dist/lib/ollama/cuda_v12 /lib/ollama/cuda_v12
FROM --platform=linux/arm64 scratch AS arm64 FROM --platform=linux/arm64 scratch AS arm64
COPY --from=cuda-12 dist/lib/ollama /lib/ollama/cuda_sbsa COPY --from=cuda-11 dist/lib/ollama/cuda_v11 /lib/ollama/cuda_v11
COPY --from=jetpack-5 dist/lib/ollama /lib/ollama/cuda_jetpack5 COPY --from=cuda-12 dist/lib/ollama/cuda_v12 /lib/ollama/cuda_v12
COPY --from=jetpack-6 dist/lib/ollama /lib/ollama/cuda_jetpack6 COPY --from=jetpack-5 dist/lib/ollama/cuda_v11 /lib/ollama/cuda_jetpack5
COPY --from=jetpack-6 dist/lib/ollama/cuda_v12 /lib/ollama/cuda_jetpack6
FROM scratch AS rocm FROM scratch AS rocm
COPY --from=rocm-6 dist/lib/ollama /lib/ollama COPY --from=rocm-6 dist/lib/ollama/rocm /lib/ollama/rocm
FROM ${FLAVOR} AS archive FROM ${FLAVOR} AS archive
COPY --from=cpu dist/lib/ollama /lib/ollama COPY --from=cpu dist/lib/ollama /lib/ollama

View File

@@ -1,6 +1,6 @@
UPSTREAM=https://github.com/ggerganov/llama.cpp.git UPSTREAM=https://github.com/ggerganov/llama.cpp.git
WORKDIR=llama/vendor WORKDIR=llama/vendor
FETCH_HEAD=de4c07f93783a1a96456a44dc16b9db538ee1618 FETCH_HEAD=d7cfe1ffe0f435d0048a6058d529daf76e072d9c
.PHONY: help .PHONY: help
help: help:
@@ -15,30 +15,27 @@ help:
@echo " make -f $(lastword $(MAKEFILE_LIST)) clean sync" @echo " make -f $(lastword $(MAKEFILE_LIST)) clean sync"
.PHONY: sync .PHONY: sync
sync: llama/build-info.cpp ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-embed.metal sync: llama/build-info.cpp llama/llama.cpp ml/backend/ggml/ggml apply-patches
llama/build-info.cpp: llama/build-info.cpp.in llama/llama.cpp .PHONY: llama/build-info.cpp
sed -e 's|@FETCH_HEAD@|$(FETCH_HEAD)|' <$< >$@ llama/build-info.cpp: llama/build-info.cpp.in
sed -e 's|@FETCH_HEAD@|$(FETCH_HEAD)|' $< > $@
ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-embed.metal: ml/backend/ggml/ggml
go generate ./$(@D)
.PHONY: llama/llama.cpp .PHONY: llama/llama.cpp
llama/llama.cpp: llama/vendor/ llama/llama.cpp: llama/vendor/ apply-patches
rsync -arvzc -f "merge $@/.rsync-filter" $< $@ rsync -arvzc -f "merge $@/.rsync-filter" $< $@
.PHONY: ml/backend/ggml/ggml .PHONY: ml/backend/ggml/ggml apply-patches
ml/backend/ggml/ggml: llama/vendor/ggml/ ml/backend/ggml/ggml: llama/vendor/ggml/ apply-patches
rsync -arvzc -f "merge $@/.rsync-filter" $< $@ rsync -arvzc -f "merge $@/.rsync-filter" $< $@
PATCHES=$(wildcard llama/patches/*.patch) PATCHES=$(wildcard llama/patches/*.patch)
PATCHED=$(join $(dir $(PATCHES)), $(addsuffix ed, $(addprefix ., $(notdir $(PATCHES)))))
.PHONY: apply-patches .PHONY: apply-patches
.NOTPARALLEL: .NOTPARALLEL:
apply-patches: $(PATCHED) apply-patches: $(addsuffix ed, $(PATCHES))
llama/patches/.%.patched: llama/patches/%.patch %.patched: %.patch
@if git -c user.name=nobody -c 'user.email=<>' -C $(WORKDIR) am -3 $(realpath $<); then touch $@; else git -C $(WORKDIR) am --abort; exit 1; fi @if git -c user.name=nobody -c 'user.email=<>' -C $(WORKDIR) am -3 $(realpath $<); then touch $@; else git -C $(WORKDIR) am --abort; exit 1; fi
.PHONY: checkout .PHONY: checkout
@@ -60,4 +57,4 @@ format-patches: llama/patches
.PHONE: clean .PHONE: clean
clean: checkout clean: checkout
$(RM) llama/patches/.*.patched $(RM) $(addsuffix ed, $(PATCHES))

View File

@@ -40,10 +40,10 @@ The official [Ollama Docker image](https://hub.docker.com/r/ollama/ollama) `olla
## Quickstart ## Quickstart
To run and chat with [Gemma 3](https://ollama.com/library/gemma3): To run and chat with [Llama 3.2](https://ollama.com/library/llama3.2):
```shell ```shell
ollama run gemma3 ollama run llama3.2
``` ```
## Model library ## Model library
@@ -61,8 +61,6 @@ Here are some example models that can be downloaded:
| QwQ | 32B | 20GB | `ollama run qwq` | | QwQ | 32B | 20GB | `ollama run qwq` |
| DeepSeek-R1 | 7B | 4.7GB | `ollama run deepseek-r1` | | DeepSeek-R1 | 7B | 4.7GB | `ollama run deepseek-r1` |
| DeepSeek-R1 | 671B | 404GB | `ollama run deepseek-r1:671b` | | DeepSeek-R1 | 671B | 404GB | `ollama run deepseek-r1:671b` |
| Llama 4 | 109B | 67GB | `ollama run llama4:scout` |
| Llama 4 | 400B | 245GB | `ollama run llama4:maverick` |
| Llama 3.3 | 70B | 43GB | `ollama run llama3.3` | | Llama 3.3 | 70B | 43GB | `ollama run llama3.3` |
| Llama 3.2 | 3B | 2.0GB | `ollama run llama3.2` | | Llama 3.2 | 3B | 2.0GB | `ollama run llama3.2` |
| Llama 3.2 | 1B | 1.3GB | `ollama run llama3.2:1b` | | Llama 3.2 | 1B | 1.3GB | `ollama run llama3.2:1b` |
@@ -79,7 +77,7 @@ Here are some example models that can be downloaded:
| Code Llama | 7B | 3.8GB | `ollama run codellama` | | Code Llama | 7B | 3.8GB | `ollama run codellama` |
| Llama 2 Uncensored | 7B | 3.8GB | `ollama run llama2-uncensored` | | Llama 2 Uncensored | 7B | 3.8GB | `ollama run llama2-uncensored` |
| LLaVA | 7B | 4.5GB | `ollama run llava` | | LLaVA | 7B | 4.5GB | `ollama run llava` |
| Granite-3.3 | 8B | 4.9GB | `ollama run granite3.3` | | Granite-3.2 | 8B | 4.9GB | `ollama run granite3.2` |
> [!NOTE] > [!NOTE]
> You should have at least 8 GB of RAM available to run the 7B models, 16 GB to run the 13B models, and 32 GB to run the 33B models. > You should have at least 8 GB of RAM available to run the 7B models, 16 GB to run the 13B models, and 32 GB to run the 33B models.
@@ -287,13 +285,13 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [Bionic GPT](https://github.com/bionic-gpt/bionic-gpt) - [Bionic GPT](https://github.com/bionic-gpt/bionic-gpt)
- [HTML UI](https://github.com/rtcfirefly/ollama-ui) - [HTML UI](https://github.com/rtcfirefly/ollama-ui)
- [Saddle](https://github.com/jikkuatwork/saddle) - [Saddle](https://github.com/jikkuatwork/saddle)
- [TagSpaces](https://www.tagspaces.org) (A platform for file-based apps, [utilizing Ollama](https://docs.tagspaces.org/ai/) for the generation of tags and descriptions) - [TagSpaces](https://www.tagspaces.org) (A platform for file based apps, [utilizing Ollama](https://docs.tagspaces.org/ai/) for the generation of tags and descriptions)
- [Chatbot UI](https://github.com/ivanfioravanti/chatbot-ollama) - [Chatbot UI](https://github.com/ivanfioravanti/chatbot-ollama)
- [Chatbot UI v2](https://github.com/mckaywrigley/chatbot-ui) - [Chatbot UI v2](https://github.com/mckaywrigley/chatbot-ui)
- [Typescript UI](https://github.com/ollama-interface/Ollama-Gui?tab=readme-ov-file) - [Typescript UI](https://github.com/ollama-interface/Ollama-Gui?tab=readme-ov-file)
- [Minimalistic React UI for Ollama Models](https://github.com/richawo/minimal-llm-ui) - [Minimalistic React UI for Ollama Models](https://github.com/richawo/minimal-llm-ui)
- [Ollamac](https://github.com/kevinhermawan/Ollamac) - [Ollamac](https://github.com/kevinhermawan/Ollamac)
- [big-AGI](https://github.com/enricoros/big-AGI) - [big-AGI](https://github.com/enricoros/big-AGI)
- [Cheshire Cat assistant framework](https://github.com/cheshire-cat-ai/core) - [Cheshire Cat assistant framework](https://github.com/cheshire-cat-ai/core)
- [Amica](https://github.com/semperai/amica) - [Amica](https://github.com/semperai/amica)
- [chatd](https://github.com/BruceMacD/chatd) - [chatd](https://github.com/BruceMacD/chatd)
@@ -314,8 +312,6 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [Ollama Basic Chat: Uses HyperDiv Reactive UI](https://github.com/rapidarchitect/ollama_basic_chat) - [Ollama Basic Chat: Uses HyperDiv Reactive UI](https://github.com/rapidarchitect/ollama_basic_chat)
- [Ollama-chats RPG](https://github.com/drazdra/ollama-chats) - [Ollama-chats RPG](https://github.com/drazdra/ollama-chats)
- [IntelliBar](https://intellibar.app/) (AI-powered assistant for macOS) - [IntelliBar](https://intellibar.app/) (AI-powered assistant for macOS)
- [Jirapt](https://github.com/AliAhmedNada/jirapt) (Jira Integration to generate issues, tasks, epics)
- [ojira](https://github.com/AliAhmedNada/ojira) (Jira chrome plugin to easily generate descriptions for tasks)
- [QA-Pilot](https://github.com/reid41/QA-Pilot) (Interactive chat tool that can leverage Ollama models for rapid understanding and navigation of GitHub code repositories) - [QA-Pilot](https://github.com/reid41/QA-Pilot) (Interactive chat tool that can leverage Ollama models for rapid understanding and navigation of GitHub code repositories)
- [ChatOllama](https://github.com/sugarforever/chat-ollama) (Open Source Chatbot based on Ollama with Knowledge Bases) - [ChatOllama](https://github.com/sugarforever/chat-ollama) (Open Source Chatbot based on Ollama with Knowledge Bases)
- [CRAG Ollama Chat](https://github.com/Nagi-ovo/CRAG-Ollama-Chat) (Simple Web Search with Corrective RAG) - [CRAG Ollama Chat](https://github.com/Nagi-ovo/CRAG-Ollama-Chat) (Simple Web Search with Corrective RAG)
@@ -329,14 +325,14 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [RWKV-Runner](https://github.com/josStorer/RWKV-Runner) (RWKV offline LLM deployment tool, also usable as a client for ChatGPT and Ollama) - [RWKV-Runner](https://github.com/josStorer/RWKV-Runner) (RWKV offline LLM deployment tool, also usable as a client for ChatGPT and Ollama)
- [Ollama Grid Search](https://github.com/dezoito/ollama-grid-search) (app to evaluate and compare models) - [Ollama Grid Search](https://github.com/dezoito/ollama-grid-search) (app to evaluate and compare models)
- [Olpaka](https://github.com/Otacon/olpaka) (User-friendly Flutter Web App for Ollama) - [Olpaka](https://github.com/Otacon/olpaka) (User-friendly Flutter Web App for Ollama)
- [Casibase](https://casibase.org) (An open source AI knowledge base and dialogue system combining the latest RAG, SSO, ollama support, and multiple large language models.) - [Casibase](https://casibase.org) (An open source AI knowledge base and dialogue system combining the latest RAG, SSO, ollama support and multiple large language models.)
- [OllamaSpring](https://github.com/CrazyNeil/OllamaSpring) (Ollama Client for macOS) - [OllamaSpring](https://github.com/CrazyNeil/OllamaSpring) (Ollama Client for macOS)
- [LLocal.in](https://github.com/kartikm7/llocal) (Easy to use Electron Desktop Client for Ollama) - [LLocal.in](https://github.com/kartikm7/llocal) (Easy to use Electron Desktop Client for Ollama)
- [Shinkai Desktop](https://github.com/dcSpark/shinkai-apps) (Two click install Local AI using Ollama + Files + RAG) - [Shinkai Desktop](https://github.com/dcSpark/shinkai-apps) (Two click install Local AI using Ollama + Files + RAG)
- [AiLama](https://github.com/zeyoyt/ailama) (A Discord User App that allows you to interact with Ollama anywhere in Discord) - [AiLama](https://github.com/zeyoyt/ailama) (A Discord User App that allows you to interact with Ollama anywhere in discord )
- [Ollama with Google Mesop](https://github.com/rapidarchitect/ollama_mesop/) (Mesop Chat Client implementation with Ollama) - [Ollama with Google Mesop](https://github.com/rapidarchitect/ollama_mesop/) (Mesop Chat Client implementation with Ollama)
- [R2R](https://github.com/SciPhi-AI/R2R) (Open-source RAG engine) - [R2R](https://github.com/SciPhi-AI/R2R) (Open-source RAG engine)
- [Ollama-Kis](https://github.com/elearningshow/ollama-kis) (A simple easy-to-use GUI with sample custom LLM for Drivers Education) - [Ollama-Kis](https://github.com/elearningshow/ollama-kis) (A simple easy to use GUI with sample custom LLM for Drivers Education)
- [OpenGPA](https://opengpa.org) (Open-source offline-first Enterprise Agentic Application) - [OpenGPA](https://opengpa.org) (Open-source offline-first Enterprise Agentic Application)
- [Painting Droid](https://github.com/mateuszmigas/painting-droid) (Painting app with AI integrations) - [Painting Droid](https://github.com/mateuszmigas/painting-droid) (Painting app with AI integrations)
- [Kerlig AI](https://www.kerlig.com/) (AI writing assistant for macOS) - [Kerlig AI](https://www.kerlig.com/) (AI writing assistant for macOS)
@@ -345,16 +341,16 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [LLMStack](https://github.com/trypromptly/LLMStack) (No-code multi-agent framework to build LLM agents and workflows) - [LLMStack](https://github.com/trypromptly/LLMStack) (No-code multi-agent framework to build LLM agents and workflows)
- [BoltAI for Mac](https://boltai.com) (AI Chat Client for Mac) - [BoltAI for Mac](https://boltai.com) (AI Chat Client for Mac)
- [Harbor](https://github.com/av/harbor) (Containerized LLM Toolkit with Ollama as default backend) - [Harbor](https://github.com/av/harbor) (Containerized LLM Toolkit with Ollama as default backend)
- [PyGPT](https://github.com/szczyglis-dev/py-gpt) (AI desktop assistant for Linux, Windows, and Mac) - [PyGPT](https://github.com/szczyglis-dev/py-gpt) (AI desktop assistant for Linux, Windows and Mac)
- [Alpaca](https://github.com/Jeffser/Alpaca) (An Ollama client application for Linux and macOS made with GTK4 and Adwaita) - [Alpaca](https://github.com/Jeffser/Alpaca) (An Ollama client application for linux and macos made with GTK4 and Adwaita)
- [AutoGPT](https://github.com/Significant-Gravitas/AutoGPT/blob/master/docs/content/platform/ollama.md) (AutoGPT Ollama integration) - [AutoGPT](https://github.com/Significant-Gravitas/AutoGPT/blob/master/docs/content/platform/ollama.md) (AutoGPT Ollama integration)
- [Go-CREW](https://www.jonathanhecl.com/go-crew/) (Powerful Offline RAG in Golang) - [Go-CREW](https://www.jonathanhecl.com/go-crew/) (Powerful Offline RAG in Golang)
- [PartCAD](https://github.com/openvmp/partcad/) (CAD model generation with OpenSCAD and CadQuery) - [PartCAD](https://github.com/openvmp/partcad/) (CAD model generation with OpenSCAD and CadQuery)
- [Ollama4j Web UI](https://github.com/ollama4j/ollama4j-web-ui) - Java-based Web UI for Ollama built with Vaadin, Spring Boot, and Ollama4j - [Ollama4j Web UI](https://github.com/ollama4j/ollama4j-web-ui) - Java-based Web UI for Ollama built with Vaadin, Spring Boot and Ollama4j
- [PyOllaMx](https://github.com/kspviswa/pyOllaMx) - macOS application capable of chatting with both Ollama and Apple MLX models. - [PyOllaMx](https://github.com/kspviswa/pyOllaMx) - macOS application capable of chatting with both Ollama and Apple MLX models.
- [Cline](https://github.com/cline/cline) - Formerly known as Claude Dev is a VSCode extension for multi-file/whole-repo coding - [Cline](https://github.com/cline/cline) - Formerly known as Claude Dev is a VSCode extension for multi-file/whole-repo coding
- [Cherry Studio](https://github.com/kangfenmao/cherry-studio) (Desktop client with Ollama support) - [Cherry Studio](https://github.com/kangfenmao/cherry-studio) (Desktop client with Ollama support)
- [ConfiChat](https://github.com/1runeberg/confichat) (Lightweight, standalone, multi-platform, and privacy-focused LLM chat interface with optional encryption) - [ConfiChat](https://github.com/1runeberg/confichat) (Lightweight, standalone, multi-platform, and privacy focused LLM chat interface with optional encryption)
- [Archyve](https://github.com/nickthecook/archyve) (RAG-enabling document library) - [Archyve](https://github.com/nickthecook/archyve) (RAG-enabling document library)
- [crewAI with Mesop](https://github.com/rapidarchitect/ollama-crew-mesop) (Mesop Web Interface to run crewAI with Ollama) - [crewAI with Mesop](https://github.com/rapidarchitect/ollama-crew-mesop) (Mesop Web Interface to run crewAI with Ollama)
- [Tkinter-based client](https://github.com/chyok/ollama-gui) (Python tkinter-based Client for Ollama) - [Tkinter-based client](https://github.com/chyok/ollama-gui) (Python tkinter-based Client for Ollama)
@@ -372,7 +368,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [DualMind](https://github.com/tcsenpai/dualmind) (Experimental app allowing two models to talk to each other in the terminal or in a web interface) - [DualMind](https://github.com/tcsenpai/dualmind) (Experimental app allowing two models to talk to each other in the terminal or in a web interface)
- [ollamarama-matrix](https://github.com/h1ddenpr0cess20/ollamarama-matrix) (Ollama chatbot for the Matrix chat protocol) - [ollamarama-matrix](https://github.com/h1ddenpr0cess20/ollamarama-matrix) (Ollama chatbot for the Matrix chat protocol)
- [ollama-chat-app](https://github.com/anan1213095357/ollama-chat-app) (Flutter-based chat app) - [ollama-chat-app](https://github.com/anan1213095357/ollama-chat-app) (Flutter-based chat app)
- [Perfect Memory AI](https://www.perfectmemory.ai/) (Productivity AI assists personalized by what you have seen on your screen, heard, and said in the meetings) - [Perfect Memory AI](https://www.perfectmemory.ai/) (Productivity AI assists personalized by what you have seen on your screen, heard and said in the meetings)
- [Hexabot](https://github.com/hexastack/hexabot) (A conversational AI builder) - [Hexabot](https://github.com/hexastack/hexabot) (A conversational AI builder)
- [Reddit Rate](https://github.com/rapidarchitect/reddit_analyzer) (Search and Rate Reddit topics with a weighted summation) - [Reddit Rate](https://github.com/rapidarchitect/reddit_analyzer) (Search and Rate Reddit topics with a weighted summation)
- [OpenTalkGpt](https://github.com/adarshM84/OpenTalkGpt) (Chrome Extension to manage open-source models supported by Ollama, create custom models, and chat with models from a user-friendly UI) - [OpenTalkGpt](https://github.com/adarshM84/OpenTalkGpt) (Chrome Extension to manage open-source models supported by Ollama, create custom models, and chat with models from a user-friendly UI)
@@ -390,7 +386,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [ChibiChat](https://github.com/CosmicEventHorizon/ChibiChat) (Kotlin-based Android app to chat with Ollama and Koboldcpp API endpoints) - [ChibiChat](https://github.com/CosmicEventHorizon/ChibiChat) (Kotlin-based Android app to chat with Ollama and Koboldcpp API endpoints)
- [LocalLLM](https://github.com/qusaismael/localllm) (Minimal Web-App to run ollama models on it with a GUI) - [LocalLLM](https://github.com/qusaismael/localllm) (Minimal Web-App to run ollama models on it with a GUI)
- [Ollamazing](https://github.com/buiducnhat/ollamazing) (Web extension to run Ollama models) - [Ollamazing](https://github.com/buiducnhat/ollamazing) (Web extension to run Ollama models)
- [OpenDeepResearcher-via-searxng](https://github.com/benhaotang/OpenDeepResearcher-via-searxng) (A Deep Research equivalent endpoint with Ollama support for running locally) - [OpenDeepResearcher-via-searxng](https://github.com/benhaotang/OpenDeepResearcher-via-searxng) (A Deep Research equivent endpoint with Ollama support for running locally)
- [AntSK](https://github.com/AIDotNet/AntSK) (Out-of-the-box & Adaptable RAG Chatbot) - [AntSK](https://github.com/AIDotNet/AntSK) (Out-of-the-box & Adaptable RAG Chatbot)
- [MaxKB](https://github.com/1Panel-dev/MaxKB/) (Ready-to-use & flexible RAG Chatbot) - [MaxKB](https://github.com/1Panel-dev/MaxKB/) (Ready-to-use & flexible RAG Chatbot)
- [yla](https://github.com/danielekp/yla) (Web interface to freely interact with your customized models) - [yla](https://github.com/danielekp/yla) (Web interface to freely interact with your customized models)
@@ -398,18 +394,10 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [1Panel](https://github.com/1Panel-dev/1Panel/) (Web-based Linux Server Management Tool) - [1Panel](https://github.com/1Panel-dev/1Panel/) (Web-based Linux Server Management Tool)
- [AstrBot](https://github.com/Soulter/AstrBot/) (User-friendly LLM-based multi-platform chatbot with a WebUI, supporting RAG, LLM agents, and plugins integration) - [AstrBot](https://github.com/Soulter/AstrBot/) (User-friendly LLM-based multi-platform chatbot with a WebUI, supporting RAG, LLM agents, and plugins integration)
- [Reins](https://github.com/ibrahimcetin/reins) (Easily tweak parameters, customize system prompts per chat, and enhance your AI experiments with reasoning model support.) - [Reins](https://github.com/ibrahimcetin/reins) (Easily tweak parameters, customize system prompts per chat, and enhance your AI experiments with reasoning model support.)
- [Flufy](https://github.com/Aharon-Bensadoun/Flufy) (A beautiful chat interface for interacting with Ollama's API. Built with React, TypeScript, and Material-UI.)
- [Ellama](https://github.com/zeozeozeo/ellama) (Friendly native app to chat with an Ollama instance) - [Ellama](https://github.com/zeozeozeo/ellama) (Friendly native app to chat with an Ollama instance)
- [screenpipe](https://github.com/mediar-ai/screenpipe) Build agents powered by your screen history - [screenpipe](https://github.com/mediar-ai/screenpipe) Build agents powered by your screen history
- [Ollamb](https://github.com/hengkysteen/ollamb) (Simple yet rich in features, cross-platform built with Flutter and designed for Ollama. Try the [web demo](https://hengkysteen.github.io/demo/ollamb/).) - [Ollamb](https://github.com/hengkysteen/ollamb) (Simple yet rich in features, cross-platform built with Flutter and designed for Ollama. Try the [web demo](https://hengkysteen.github.io/demo/ollamb/).)
- [Writeopia](https://github.com/Writeopia/Writeopia) (Text editor with integration with Ollama) - [Writeopia](https://github.com/Writeopia/Writeopia) (Text editor with integration with Ollama)
- [AppFlowy](https://github.com/AppFlowy-IO/AppFlowy) (AI collaborative workspace with Ollama, cross-platform and self-hostable)
- [Lumina](https://github.com/cushydigit/lumina.git) (A lightweight, minimal React.js frontend for interacting with Ollama servers)
- [Tiny Notepad](https://pypi.org/project/tiny-notepad) (A lightweight, notepad-like interface to chat with ollama available on PyPI)
- [macLlama (macOS native)](https://github.com/hellotunamayo/macLlama) (A native macOS GUI application for interacting with Ollama models, featuring a chat interface.)
- [GPTranslate](https://github.com/philberndt/GPTranslate) (A fast and lightweight, AI powered desktop translation application written with Rust and Tauri. Features real-time translation with OpenAI/Azure/Ollama.)
- [ollama launcher](https://github.com/NGC13009/ollama-launcher) (A launcher for Ollama, aiming to provide users with convenient functions such as ollama server launching, management, or configuration.)
- [ai-hub](https://github.com/Aj-Seven/ai-hub) (AI Hub supports multiple models via API keys and Chat support via Ollama API.)
### Cloud ### Cloud
@@ -451,10 +439,8 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [PowershAI](https://github.com/rrg92/powershai) PowerShell module that brings AI to terminal on Windows, including support for Ollama - [PowershAI](https://github.com/rrg92/powershai) PowerShell module that brings AI to terminal on Windows, including support for Ollama
- [DeepShell](https://github.com/Abyss-c0re/deepshell) Your self-hosted AI assistant. Interactive Shell, Files and Folders analysis. - [DeepShell](https://github.com/Abyss-c0re/deepshell) Your self-hosted AI assistant. Interactive Shell, Files and Folders analysis.
- [orbiton](https://github.com/xyproto/orbiton) Configuration-free text editor and IDE with support for tab completion with Ollama. - [orbiton](https://github.com/xyproto/orbiton) Configuration-free text editor and IDE with support for tab completion with Ollama.
- [orca-cli](https://github.com/molbal/orca-cli) Ollama Registry CLI Application - Browse, pull, and download models from Ollama Registry in your terminal. - [orca-cli](https://github.com/molbal/orca-cli) Ollama Registry CLI Application - Browse, pull and download models from Ollama Registry in your terminal.
- [GGUF-to-Ollama](https://github.com/jonathanhecl/gguf-to-ollama) - Importing GGUF to Ollama made easy (multiplatform) - [GGUF-to-Ollama](https://github.com/jonathanhecl/gguf-to-ollama) - Importing GGUF to Ollama made easy (multiplatform)
- [AWS-Strands-With-Ollama](https://github.com/rapidarchitect/ollama_strands) - AWS Strands Agents with Ollama Examples
- [ollama-multirun](https://github.com/attogram/ollama-multirun) - A bash shell script to run a single prompt against any or all of your locally installed ollama models, saving the output and performance statistics as easily navigable web pages. ([Demo](https://attogram.github.io/ai_test_zone/))
### Apple Vision Pro ### Apple Vision Pro
@@ -481,7 +467,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
### Libraries ### Libraries
- [LangChain](https://python.langchain.com/docs/integrations/chat/ollama/) and [LangChain.js](https://js.langchain.com/docs/integrations/chat/ollama/) with [example](https://js.langchain.com/docs/tutorials/local_rag/) - [LangChain](https://python.langchain.com/docs/integrations/llms/ollama) and [LangChain.js](https://js.langchain.com/docs/integrations/chat/ollama/) with [example](https://js.langchain.com/docs/tutorials/local_rag/)
- [Firebase Genkit](https://firebase.google.com/docs/genkit/plugins/ollama) - [Firebase Genkit](https://firebase.google.com/docs/genkit/plugins/ollama)
- [crewAI](https://github.com/crewAIInc/crewAI) - [crewAI](https://github.com/crewAIInc/crewAI)
- [Yacana](https://remembersoftwares.github.io/yacana/) (User-friendly multi-agent framework for brainstorming and executing predetermined flows with built-in tool integration) - [Yacana](https://remembersoftwares.github.io/yacana/) (User-friendly multi-agent framework for brainstorming and executing predetermined flows with built-in tool integration)
@@ -528,21 +514,20 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [Swollama for Swift](https://github.com/marcusziade/Swollama) with [DocC](https://marcusziade.github.io/Swollama/documentation/swollama/) - [Swollama for Swift](https://github.com/marcusziade/Swollama) with [DocC](https://marcusziade.github.io/Swollama/documentation/swollama/)
- [GoLamify](https://github.com/prasad89/golamify) - [GoLamify](https://github.com/prasad89/golamify)
- [Ollama for Haskell](https://github.com/tusharad/ollama-haskell) - [Ollama for Haskell](https://github.com/tusharad/ollama-haskell)
- [multi-llm-ts](https://github.com/nbonamy/multi-llm-ts) (A Typescript/JavaScript library allowing access to different LLM in a unified API) - [multi-llm-ts](https://github.com/nbonamy/multi-llm-ts) (A Typescript/JavaScript library allowing access to different LLM in unified API)
- [LlmTornado](https://github.com/lofcz/llmtornado) (C# library providing a unified interface for major FOSS & Commercial inference APIs) - [LlmTornado](https://github.com/lofcz/llmtornado) (C# library providing a unified interface for major FOSS & Commercial inference APIs)
- [Ollama for Zig](https://github.com/dravenk/ollama-zig) - [Ollama for Zig](https://github.com/dravenk/ollama-zig)
- [Abso](https://github.com/lunary-ai/abso) (OpenAI-compatible TypeScript SDK for any LLM provider) - [Abso](https://github.com/lunary-ai/abso) (OpenAI-compatible TypeScript SDK for any LLM provider)
- [Nichey](https://github.com/goodreasonai/nichey) is a Python package for generating custom wikis for your research topic - [Nichey](https://github.com/goodreasonai/nichey) is a Python package for generating custom wikis for your research topic
- [Ollama for D](https://github.com/kassane/ollama-d) - [Ollama for D](https://github.com/kassane/ollama-d)
- [OllamaPlusPlus](https://github.com/HardCodeDev777/OllamaPlusPlus) (Very simple C++ library for Ollama)
### Mobile ### Mobile
- [SwiftChat](https://github.com/aws-samples/swift-chat) (Lightning-fast Cross-platform AI chat app with native UI for Android, iOS, and iPad) - [SwiftChat](https://github.com/aws-samples/swift-chat) (Lightning-fast Cross-platform AI chat app with native UI for Android, iOS and iPad)
- [Enchanted](https://github.com/AugustDev/enchanted) - [Enchanted](https://github.com/AugustDev/enchanted)
- [Maid](https://github.com/Mobile-Artificial-Intelligence/maid) - [Maid](https://github.com/Mobile-Artificial-Intelligence/maid)
- [Ollama App](https://github.com/JHubi1/ollama-app) (Modern and easy-to-use multi-platform client for Ollama) - [Ollama App](https://github.com/JHubi1/ollama-app) (Modern and easy-to-use multi-platform client for Ollama)
- [ConfiChat](https://github.com/1runeberg/confichat) (Lightweight, standalone, multi-platform, and privacy-focused LLM chat interface with optional encryption) - [ConfiChat](https://github.com/1runeberg/confichat) (Lightweight, standalone, multi-platform, and privacy focused LLM chat interface with optional encryption)
- [Ollama Android Chat](https://github.com/sunshine0523/OllamaServer) (No need for Termux, start the Ollama service with one click on an Android device) - [Ollama Android Chat](https://github.com/sunshine0523/OllamaServer) (No need for Termux, start the Ollama service with one click on an Android device)
- [Reins](https://github.com/ibrahimcetin/reins) (Easily tweak parameters, customize system prompts per chat, and enhance your AI experiments with reasoning model support.) - [Reins](https://github.com/ibrahimcetin/reins) (Easily tweak parameters, customize system prompts per chat, and enhance your AI experiments with reasoning model support.)
@@ -566,7 +551,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [Obsidian Local GPT plugin](https://github.com/pfrankov/obsidian-local-gpt) - [Obsidian Local GPT plugin](https://github.com/pfrankov/obsidian-local-gpt)
- [Open Interpreter](https://docs.openinterpreter.com/language-model-setup/local-models/ollama) - [Open Interpreter](https://docs.openinterpreter.com/language-model-setup/local-models/ollama)
- [Llama Coder](https://github.com/ex3ndr/llama-coder) (Copilot alternative using Ollama) - [Llama Coder](https://github.com/ex3ndr/llama-coder) (Copilot alternative using Ollama)
- [Ollama Copilot](https://github.com/bernardo-bruning/ollama-copilot) (Proxy that allows you to use Ollama as a copilot like GitHub Copilot) - [Ollama Copilot](https://github.com/bernardo-bruning/ollama-copilot) (Proxy that allows you to use ollama as a copilot like Github copilot)
- [twinny](https://github.com/rjmacarthy/twinny) (Copilot and Copilot chat alternative using Ollama) - [twinny](https://github.com/rjmacarthy/twinny) (Copilot and Copilot chat alternative using Ollama)
- [Wingman-AI](https://github.com/RussellCanfield/wingman-ai) (Copilot code and chat alternative using Ollama and Hugging Face) - [Wingman-AI](https://github.com/RussellCanfield/wingman-ai) (Copilot code and chat alternative using Ollama and Hugging Face)
- [Page Assist](https://github.com/n4ze3m/page-assist) (Chrome Extension) - [Page Assist](https://github.com/n4ze3m/page-assist) (Chrome Extension)
@@ -576,8 +561,8 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [Discord-Ollama Chat Bot](https://github.com/kevinthedang/discord-ollama) (Generalized TypeScript Discord Bot w/ Tuning Documentation) - [Discord-Ollama Chat Bot](https://github.com/kevinthedang/discord-ollama) (Generalized TypeScript Discord Bot w/ Tuning Documentation)
- [ChatGPTBox: All in one browser extension](https://github.com/josStorer/chatGPTBox) with [Integrating Tutorial](https://github.com/josStorer/chatGPTBox/issues/616#issuecomment-1975186467) - [ChatGPTBox: All in one browser extension](https://github.com/josStorer/chatGPTBox) with [Integrating Tutorial](https://github.com/josStorer/chatGPTBox/issues/616#issuecomment-1975186467)
- [Discord AI chat/moderation bot](https://github.com/rapmd73/Companion) Chat/moderation bot written in python. Uses Ollama to create personalities. - [Discord AI chat/moderation bot](https://github.com/rapmd73/Companion) Chat/moderation bot written in python. Uses Ollama to create personalities.
- [Headless Ollama](https://github.com/nischalj10/headless-ollama) (Scripts to automatically install ollama client & models on any OS for apps that depend on ollama server) - [Headless Ollama](https://github.com/nischalj10/headless-ollama) (Scripts to automatically install ollama client & models on any OS for apps that depends on ollama server)
- [Terraform AWS Ollama & Open WebUI](https://github.com/xuyangbocn/terraform-aws-self-host-llm) (A Terraform module to deploy on AWS a ready-to-use Ollama service, together with its front-end Open WebUI service.) - [Terraform AWS Ollama & Open WebUI](https://github.com/xuyangbocn/terraform-aws-self-host-llm) (A Terraform module to deploy on AWS a ready-to-use Ollama service, together with its front end Open WebUI service.)
- [node-red-contrib-ollama](https://github.com/jakubburkiewicz/node-red-contrib-ollama) - [node-red-contrib-ollama](https://github.com/jakubburkiewicz/node-red-contrib-ollama)
- [Local AI Helper](https://github.com/ivostoykov/localAI) (Chrome and Firefox extensions that enable interactions with the active tab and customisable API endpoints. Includes secure storage for user prompts.) - [Local AI Helper](https://github.com/ivostoykov/localAI) (Chrome and Firefox extensions that enable interactions with the active tab and customisable API endpoints. Includes secure storage for user prompts.)
- [vnc-lm](https://github.com/jake83741/vnc-lm) (Discord bot for messaging with LLMs through Ollama and LiteLLM. Seamlessly move between local and flagship models.) - [vnc-lm](https://github.com/jake83741/vnc-lm) (Discord bot for messaging with LLMs through Ollama and LiteLLM. Seamlessly move between local and flagship models.)
@@ -591,8 +576,6 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [Simple-Discord-AI](https://github.com/zyphixor/simple-discord-ai) - [Simple-Discord-AI](https://github.com/zyphixor/simple-discord-ai)
- [LLM Telegram Bot](https://github.com/innightwolfsleep/llm_telegram_bot) (telegram bot, primary for RP. Oobabooga-like buttons, [A1111](https://github.com/AUTOMATIC1111/stable-diffusion-webui) API integration e.t.c) - [LLM Telegram Bot](https://github.com/innightwolfsleep/llm_telegram_bot) (telegram bot, primary for RP. Oobabooga-like buttons, [A1111](https://github.com/AUTOMATIC1111/stable-diffusion-webui) API integration e.t.c)
- [mcp-llm](https://github.com/sammcj/mcp-llm) (MCP Server to allow LLMs to call other LLMs) - [mcp-llm](https://github.com/sammcj/mcp-llm) (MCP Server to allow LLMs to call other LLMs)
- [SimpleOllamaUnity](https://github.com/HardCodeDev777/SimpleOllamaUnity) (Unity Engine extension for communicating with Ollama in a few lines of code. Also works at runtime)
- [UnityCodeLama](https://github.com/HardCodeDev777/UnityCodeLama) (Unity Edtior tool to analyze scripts via Ollama)
### Supported backends ### Supported backends

View File

@@ -24,10 +24,7 @@ import (
"net/http" "net/http"
"net/url" "net/url"
"runtime" "runtime"
"strconv"
"time"
"github.com/ollama/ollama/auth"
"github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/format" "github.com/ollama/ollama/format"
"github.com/ollama/ollama/version" "github.com/ollama/ollama/version"
@@ -79,14 +76,6 @@ func NewClient(base *url.URL, http *http.Client) *Client {
} }
} }
func getAuthorizationToken(ctx context.Context, challenge string) (string, error) {
token, err := auth.Sign(ctx, []byte(challenge))
if err != nil {
return "", err
}
return token, nil
}
func (c *Client) do(ctx context.Context, method, path string, reqData, respData any) error { func (c *Client) do(ctx context.Context, method, path string, reqData, respData any) error {
var reqBody io.Reader var reqBody io.Reader
var data []byte var data []byte
@@ -108,21 +97,6 @@ func (c *Client) do(ctx context.Context, method, path string, reqData, respData
} }
requestURL := c.base.JoinPath(path) requestURL := c.base.JoinPath(path)
var token string
if envconfig.UseAuth() || c.base.Hostname() == "ollama.com" {
now := strconv.FormatInt(time.Now().Unix(), 10)
chal := fmt.Sprintf("%s,%s?ts=%s", method, path, now)
token, err = getAuthorizationToken(ctx, chal)
if err != nil {
return err
}
q := requestURL.Query()
q.Set("ts", now)
requestURL.RawQuery = q.Encode()
}
request, err := http.NewRequestWithContext(ctx, method, requestURL.String(), reqBody) request, err := http.NewRequestWithContext(ctx, method, requestURL.String(), reqBody)
if err != nil { if err != nil {
return err return err
@@ -132,10 +106,6 @@ func (c *Client) do(ctx context.Context, method, path string, reqData, respData
request.Header.Set("Accept", "application/json") request.Header.Set("Accept", "application/json")
request.Header.Set("User-Agent", fmt.Sprintf("ollama/%s (%s %s) Go/%s", version.Version, runtime.GOARCH, runtime.GOOS, runtime.Version())) request.Header.Set("User-Agent", fmt.Sprintf("ollama/%s (%s %s) Go/%s", version.Version, runtime.GOARCH, runtime.GOOS, runtime.Version()))
if token != "" {
request.Header.Set("Authorization", token)
}
respObj, err := c.http.Do(request) respObj, err := c.http.Do(request)
if err != nil { if err != nil {
return err return err
@@ -173,22 +143,6 @@ func (c *Client) stream(ctx context.Context, method, path string, data any, fn f
} }
requestURL := c.base.JoinPath(path) requestURL := c.base.JoinPath(path)
var token string
if envconfig.UseAuth() || c.base.Hostname() == "ollama.com" {
var err error
now := strconv.FormatInt(time.Now().Unix(), 10)
chal := fmt.Sprintf("%s,%s?ts=%s", method, path, now)
token, err = getAuthorizationToken(ctx, chal)
if err != nil {
return err
}
q := requestURL.Query()
q.Set("ts", now)
requestURL.RawQuery = q.Encode()
}
request, err := http.NewRequestWithContext(ctx, method, requestURL.String(), buf) request, err := http.NewRequestWithContext(ctx, method, requestURL.String(), buf)
if err != nil { if err != nil {
return err return err
@@ -198,10 +152,6 @@ func (c *Client) stream(ctx context.Context, method, path string, data any, fn f
request.Header.Set("Accept", "application/x-ndjson") request.Header.Set("Accept", "application/x-ndjson")
request.Header.Set("User-Agent", fmt.Sprintf("ollama/%s (%s %s) Go/%s", version.Version, runtime.GOARCH, runtime.GOOS, runtime.Version())) request.Header.Set("User-Agent", fmt.Sprintf("ollama/%s (%s %s) Go/%s", version.Version, runtime.GOARCH, runtime.GOOS, runtime.Version()))
if token != "" {
request.Header.Set("Authorization", token)
}
response, err := c.http.Do(request) response, err := c.http.Do(request)
if err != nil { if err != nil {
return err return err

View File

@@ -1,6 +1,7 @@
package api package api
import ( import (
"context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"net/http" "net/http"
@@ -136,7 +137,7 @@ func TestClientStream(t *testing.T) {
client := NewClient(&url.URL{Scheme: "http", Host: ts.Listener.Addr().String()}, http.DefaultClient) client := NewClient(&url.URL{Scheme: "http", Host: ts.Listener.Addr().String()}, http.DefaultClient)
var receivedChunks []ChatResponse var receivedChunks []ChatResponse
err := client.stream(t.Context(), http.MethodPost, "/v1/chat", nil, func(chunk []byte) error { err := client.stream(context.Background(), http.MethodPost, "/v1/chat", nil, func(chunk []byte) error {
var resp ChatResponse var resp ChatResponse
if err := json.Unmarshal(chunk, &resp); err != nil { if err := json.Unmarshal(chunk, &resp); err != nil {
return fmt.Errorf("failed to unmarshal chunk: %w", err) return fmt.Errorf("failed to unmarshal chunk: %w", err)
@@ -222,7 +223,7 @@ func TestClientDo(t *testing.T) {
ID string `json:"id"` ID string `json:"id"`
Success bool `json:"success"` Success bool `json:"success"`
} }
err := client.do(t.Context(), http.MethodPost, "/v1/messages", nil, &resp) err := client.do(context.Background(), http.MethodPost, "/v1/messages", nil, &resp)
if tc.wantErr != "" { if tc.wantErr != "" {
if err == nil { if err == nil {

View File

@@ -76,19 +76,13 @@ type GenerateRequest struct {
// this request. // this request.
KeepAlive *Duration `json:"keep_alive,omitempty"` KeepAlive *Duration `json:"keep_alive,omitempty"`
// Images is an optional list of raw image bytes accompanying this // Images is an optional list of base64-encoded images accompanying this
// request, for multimodal models. // request, for multimodal models.
Images []ImageData `json:"images,omitempty"` Images []ImageData `json:"images,omitempty"`
// Options lists model-specific options. For example, temperature can be // Options lists model-specific options. For example, temperature can be
// set through this field, if the model supports it. // set through this field, if the model supports it.
Options map[string]any `json:"options"` Options map[string]any `json:"options"`
// Think controls whether thinking/reasoning models will think before
// responding. Needs to be a pointer so we can distinguish between false
// (request that thinking _not_ be used) and unset (use the old behavior
// before this option was introduced)
Think *bool `json:"think,omitempty"`
} }
// ChatRequest describes a request sent by [Client.Chat]. // ChatRequest describes a request sent by [Client.Chat].
@@ -114,10 +108,6 @@ type ChatRequest struct {
// Options lists model-specific options. // Options lists model-specific options.
Options map[string]any `json:"options"` Options map[string]any `json:"options"`
// Think controls whether thinking/reasoning models will think before
// responding
Think *bool `json:"think,omitempty"`
} }
type Tools []Tool type Tools []Tool
@@ -136,11 +126,8 @@ func (t Tool) String() string {
// role ("system", "user", or "assistant"), the content and an optional list // role ("system", "user", or "assistant"), the content and an optional list
// of images. // of images.
type Message struct { type Message struct {
Role string `json:"role"` Role string `json:"role"`
Content string `json:"content"` Content string `json:"content"`
// Thinking contains the text that was inside thinking tags in the
// original model output when ChatRequest.Think is enabled.
Thinking string `json:"thinking,omitempty"`
Images []ImageData `json:"images,omitempty"` Images []ImageData `json:"images,omitempty"`
ToolCalls []ToolCall `json:"tool_calls,omitempty"` ToolCalls []ToolCall `json:"tool_calls,omitempty"`
} }
@@ -284,6 +271,9 @@ type Options struct {
RepeatPenalty float32 `json:"repeat_penalty,omitempty"` RepeatPenalty float32 `json:"repeat_penalty,omitempty"`
PresencePenalty float32 `json:"presence_penalty,omitempty"` PresencePenalty float32 `json:"presence_penalty,omitempty"`
FrequencyPenalty float32 `json:"frequency_penalty,omitempty"` FrequencyPenalty float32 `json:"frequency_penalty,omitempty"`
Mirostat int `json:"mirostat,omitempty"`
MirostatTau float32 `json:"mirostat_tau,omitempty"`
MirostatEta float32 `json:"mirostat_eta,omitempty"`
Stop []string `json:"stop,omitempty"` Stop []string `json:"stop,omitempty"`
} }
@@ -293,7 +283,12 @@ type Runner struct {
NumBatch int `json:"num_batch,omitempty"` NumBatch int `json:"num_batch,omitempty"`
NumGPU int `json:"num_gpu,omitempty"` NumGPU int `json:"num_gpu,omitempty"`
MainGPU int `json:"main_gpu,omitempty"` MainGPU int `json:"main_gpu,omitempty"`
LowVRAM bool `json:"low_vram,omitempty"`
F16KV bool `json:"f16_kv,omitempty"` // Deprecated: This option is ignored
LogitsAll bool `json:"logits_all,omitempty"`
VocabOnly bool `json:"vocab_only,omitempty"`
UseMMap *bool `json:"use_mmap,omitempty"` UseMMap *bool `json:"use_mmap,omitempty"`
UseMLock bool `json:"use_mlock,omitempty"`
NumThread int `json:"num_thread,omitempty"` NumThread int `json:"num_thread,omitempty"`
} }
@@ -476,6 +471,13 @@ type ProcessModelResponse struct {
SizeVRAM int64 `json:"size_vram"` SizeVRAM int64 `json:"size_vram"`
} }
type RetrieveModelResponse struct {
Id string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
OwnedBy string `json:"owned_by"`
}
type TokenResponse struct { type TokenResponse struct {
Token string `json:"token"` Token string `json:"token"`
} }
@@ -491,10 +493,6 @@ type GenerateResponse struct {
// Response is the textual response itself. // Response is the textual response itself.
Response string `json:"response"` Response string `json:"response"`
// Thinking contains the text that was inside thinking tags in the
// original model output when ChatRequest.Think is enabled.
Thinking string `json:"thinking,omitempty"`
// Done specifies if the response is complete. // Done specifies if the response is complete.
Done bool `json:"done"` Done bool `json:"done"`
@@ -662,6 +660,9 @@ func DefaultOptions() Options {
RepeatPenalty: 1.1, RepeatPenalty: 1.1,
PresencePenalty: 0.0, PresencePenalty: 0.0,
FrequencyPenalty: 0.0, FrequencyPenalty: 0.0,
Mirostat: 0,
MirostatTau: 5.0,
MirostatEta: 0.1,
Seed: -1, Seed: -1,
Runner: Runner{ Runner: Runner{
@@ -670,6 +671,8 @@ func DefaultOptions() Options {
NumBatch: 512, NumBatch: 512,
NumGPU: -1, // -1 here indicates that NumGPU should be set dynamically NumGPU: -1, // -1 here indicates that NumGPU should be set dynamically
NumThread: 0, // let the runtime decide NumThread: 0, // let the runtime decide
LowVRAM: false,
UseMLock: false,
UseMMap: nil, UseMMap: nil,
}, },
} }

View File

@@ -372,50 +372,3 @@ func TestPropertyType_MarshalJSON(t *testing.T) {
}) })
} }
} }
func TestThinking_UnmarshalJSON(t *testing.T) {
trueVal := true
falseVal := false
tests := []struct {
name string
input string
expectedThinking *bool
expectedError bool
}{
{
name: "true",
input: `{ "think": true }`,
expectedThinking: &trueVal,
},
{
name: "false",
input: `{ "think": false }`,
expectedThinking: &falseVal,
},
{
name: "unset",
input: `{ }`,
expectedThinking: nil,
},
{
name: "invalid",
input: `{ "think": "true" }`,
expectedThinking: nil,
expectedError: true,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
var req GenerateRequest
err := json.Unmarshal([]byte(test.input), &req)
if test.expectedError {
require.Error(t, err)
} else {
require.NoError(t, err)
assert.Equal(t, test.expectedThinking, req.Think)
}
})
}
}

View File

@@ -4,14 +4,20 @@ import (
"fmt" "fmt"
"log/slog" "log/slog"
"os" "os"
"path/filepath"
"strconv" "strconv"
"strings" "strings"
"github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/logutil"
) )
func InitLogging() { func InitLogging() {
level := slog.LevelInfo
if envconfig.Debug() {
level = slog.LevelDebug
}
var logFile *os.File var logFile *os.File
var err error var err error
// Detect if we're a GUI app on windows, and if not, send logs to console // Detect if we're a GUI app on windows, and if not, send logs to console
@@ -27,8 +33,20 @@ func InitLogging() {
return return
} }
} }
handler := slog.NewTextHandler(logFile, &slog.HandlerOptions{
Level: level,
AddSource: true,
ReplaceAttr: func(_ []string, attr slog.Attr) slog.Attr {
if attr.Key == slog.SourceKey {
source := attr.Value.Any().(*slog.Source)
source.File = filepath.Base(source.File)
}
return attr
},
})
slog.SetDefault(slog.New(handler))
slog.SetDefault(logutil.NewLogger(logFile, envconfig.LogLevel()))
slog.Info("ollama app started") slog.Info("ollama app started")
} }

View File

@@ -0,0 +1,178 @@
package benchmark
import (
"context"
"flag"
"fmt"
"testing"
"time"
"github.com/ollama/ollama/api"
)
// Command line flags
var modelFlag string
func init() {
flag.StringVar(&modelFlag, "m", "", "Name of the model to benchmark")
flag.Lookup("m").DefValue = "model"
}
// modelName returns the model name from flags, failing the test if not set
func modelName(b *testing.B) string {
if modelFlag == "" {
b.Fatal("Error: -m flag is required for benchmark tests")
}
return modelFlag
}
type TestCase struct {
name string
prompt string
maxTokens int
}
// runGenerateBenchmark contains the common generate and metrics logic
func runGenerateBenchmark(b *testing.B, ctx context.Context, client *api.Client, req *api.GenerateRequest) {
start := time.Now()
var ttft time.Duration
var metrics api.Metrics
err := client.Generate(ctx, req, func(resp api.GenerateResponse) error {
if ttft == 0 && resp.Response != "" {
ttft = time.Since(start)
}
if resp.Done {
metrics = resp.Metrics
}
return nil
})
// Report custom metrics as part of the benchmark results
b.ReportMetric(float64(ttft.Milliseconds()), "ttft_ms")
b.ReportMetric(float64(metrics.LoadDuration.Milliseconds()), "load_ms")
// Token throughput metrics
promptThroughput := float64(metrics.PromptEvalCount) / metrics.PromptEvalDuration.Seconds()
genThroughput := float64(metrics.EvalCount) / metrics.EvalDuration.Seconds()
b.ReportMetric(promptThroughput, "prompt_tok/s")
b.ReportMetric(genThroughput, "gen_tok/s")
// Token counts
b.ReportMetric(float64(metrics.PromptEvalCount), "prompt_tokens")
b.ReportMetric(float64(metrics.EvalCount), "gen_tokens")
if err != nil {
b.Fatal(err)
}
}
// BenchmarkColdStart runs benchmarks with model loading from cold state
func BenchmarkColdStart(b *testing.B) {
client := setup(b)
tests := []TestCase{
{"short_prompt", "Write a long story", 100},
{"medium_prompt", "Write a detailed economic analysis", 500},
{"long_prompt", "Write a comprehensive AI research paper", 1000},
}
m := modelName(b)
for _, tt := range tests {
b.Run(fmt.Sprintf("%s/cold/%s", m, tt.name), func(b *testing.B) {
ctx := context.Background()
// Set number of tokens as our throughput metric
b.SetBytes(int64(tt.maxTokens))
for b.Loop() {
b.StopTimer()
// Ensure model is unloaded before each iteration
unload(client, m, b)
b.StartTimer()
req := &api.GenerateRequest{
Model: m,
Prompt: tt.prompt,
Options: map[string]any{"num_predict": tt.maxTokens, "temperature": 0.1},
}
runGenerateBenchmark(b, ctx, client, req)
}
})
}
}
// BenchmarkWarmStart runs benchmarks with pre-loaded model
func BenchmarkWarmStart(b *testing.B) {
client := setup(b)
tests := []TestCase{
{"short_prompt", "Write a long story", 100},
{"medium_prompt", "Write a detailed economic analysis", 500},
{"long_prompt", "Write a comprehensive AI research paper", 1000},
}
m := modelName(b)
for _, tt := range tests {
b.Run(fmt.Sprintf("%s/warm/%s", m, tt.name), func(b *testing.B) {
ctx := context.Background()
// Pre-warm the model
warmup(client, m, tt.prompt, b)
// Set number of tokens as our throughput metric
b.SetBytes(int64(tt.maxTokens))
for b.Loop() {
req := &api.GenerateRequest{
Model: m,
Prompt: tt.prompt,
Options: map[string]any{"num_predict": tt.maxTokens, "temperature": 0.1},
}
runGenerateBenchmark(b, ctx, client, req)
}
})
}
}
// setup verifies server and model availability
func setup(b *testing.B) *api.Client {
client, err := api.ClientFromEnvironment()
if err != nil {
b.Fatal(err)
}
if _, err := client.Show(context.Background(), &api.ShowRequest{Model: modelName(b)}); err != nil {
b.Fatalf("Model unavailable: %v", err)
}
return client
}
// warmup ensures the model is loaded and warmed up
func warmup(client *api.Client, model string, prompt string, b *testing.B) {
for range 3 {
err := client.Generate(
context.Background(),
&api.GenerateRequest{
Model: model,
Prompt: prompt,
Options: map[string]any{"num_predict": 50, "temperature": 0.1},
},
func(api.GenerateResponse) error { return nil },
)
if err != nil {
b.Logf("Error during model warm-up: %v", err)
}
}
}
// unload forces model unloading using KeepAlive: 0 parameter
func unload(client *api.Client, model string, b *testing.B) {
req := &api.GenerateRequest{
Model: model,
KeepAlive: &api.Duration{Duration: 0},
}
if err := client.Generate(context.Background(), req, func(api.GenerateResponse) error { return nil }); err != nil {
b.Logf("Unload error: %v", err)
}
time.Sleep(1 * time.Second)
}

View File

@@ -31,7 +31,6 @@ import (
"github.com/olekukonko/tablewriter" "github.com/olekukonko/tablewriter"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
"golang.org/x/sync/errgroup"
"golang.org/x/term" "golang.org/x/term"
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
@@ -39,31 +38,12 @@ import (
"github.com/ollama/ollama/format" "github.com/ollama/ollama/format"
"github.com/ollama/ollama/parser" "github.com/ollama/ollama/parser"
"github.com/ollama/ollama/progress" "github.com/ollama/ollama/progress"
"github.com/ollama/ollama/readline"
"github.com/ollama/ollama/runner" "github.com/ollama/ollama/runner"
"github.com/ollama/ollama/server" "github.com/ollama/ollama/server"
"github.com/ollama/ollama/types/model" "github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/types/syncmap"
"github.com/ollama/ollama/version" "github.com/ollama/ollama/version"
) )
// ensureThinkingSupport emits a warning if the model does not advertise thinking support
func ensureThinkingSupport(ctx context.Context, client *api.Client, name string) {
if name == "" {
return
}
resp, err := client.Show(ctx, &api.ShowRequest{Model: name})
if err != nil {
return
}
for _, cap := range resp.Capabilities {
if cap == model.CapabilityThinking {
return
}
}
fmt.Fprintf(os.Stderr, "warning: model %q does not support thinking output\n", name)
}
var errModelfileNotFound = errors.New("specified Modelfile wasn't found") var errModelfileNotFound = errors.New("specified Modelfile wasn't found")
func getModelfileName(cmd *cobra.Command) (string, error) { func getModelfileName(cmd *cobra.Command) (string, error) {
@@ -126,7 +106,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
} }
spinner.Stop() spinner.Stop()
req.Model = args[0] req.Name = args[0]
quantize, _ := cmd.Flags().GetString("quantize") quantize, _ := cmd.Flags().GetString("quantize")
if quantize != "" { if quantize != "" {
req.Quantize = quantize req.Quantize = quantize
@@ -137,54 +117,34 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
return err return err
} }
var g errgroup.Group if len(req.Files) > 0 {
g.SetLimit(max(runtime.GOMAXPROCS(0)-1, 1)) fileMap := map[string]string{}
for f, digest := range req.Files {
files := syncmap.NewSyncMap[string, string]()
for f, digest := range req.Files {
g.Go(func() error {
if _, err := createBlob(cmd, client, f, digest, p); err != nil { if _, err := createBlob(cmd, client, f, digest, p); err != nil {
return err return err
} }
fileMap[filepath.Base(f)] = digest
// TODO: this is incorrect since the file might be in a subdirectory }
// instead this should take the path relative to the model directory req.Files = fileMap
// but the current implementation does not allow this
files.Store(filepath.Base(f), digest)
return nil
})
} }
adapters := syncmap.NewSyncMap[string, string]() if len(req.Adapters) > 0 {
for f, digest := range req.Adapters { fileMap := map[string]string{}
g.Go(func() error { for f, digest := range req.Adapters {
if _, err := createBlob(cmd, client, f, digest, p); err != nil { if _, err := createBlob(cmd, client, f, digest, p); err != nil {
return err return err
} }
fileMap[filepath.Base(f)] = digest
// TODO: same here }
adapters.Store(filepath.Base(f), digest) req.Adapters = fileMap
return nil
})
} }
if err := g.Wait(); err != nil {
return err
}
req.Files = files.Items()
req.Adapters = adapters.Items()
bars := make(map[string]*progress.Bar) bars := make(map[string]*progress.Bar)
fn := func(resp api.ProgressResponse) error { fn := func(resp api.ProgressResponse) error {
if resp.Digest != "" { if resp.Digest != "" {
bar, ok := bars[resp.Digest] bar, ok := bars[resp.Digest]
if !ok { if !ok {
msg := resp.Status bar = progress.NewBar(fmt.Sprintf("pulling %s...", resp.Digest[7:19]), resp.Total, resp.Completed)
if msg == "" {
msg = fmt.Sprintf("pulling %s...", resp.Digest[7:19])
}
bar = progress.NewBar(msg, resp.Total, resp.Completed)
bars[resp.Digest] = bar bars[resp.Digest] = bar
p.Add(resp.Digest, bar) p.Add(resp.Digest, bar)
} }
@@ -253,7 +213,7 @@ func createBlob(cmd *cobra.Command, client *api.Client, path string, digest stri
} }
}() }()
if err := client.CreateBlob(cmd.Context(), digest, io.TeeReader(bin, &pw)); err != nil { if err = client.CreateBlob(cmd.Context(), digest, io.TeeReader(bin, &pw)); err != nil {
return "", err return "", err
} }
return digest, nil return digest, nil
@@ -283,9 +243,6 @@ func loadOrUnloadModel(cmd *cobra.Command, opts *runOptions) error {
req := &api.GenerateRequest{ req := &api.GenerateRequest{
Model: opts.Model, Model: opts.Model,
KeepAlive: opts.KeepAlive, KeepAlive: opts.KeepAlive,
// pass Think here so we fail before getting to the chat prompt if the model doesn't support it
Think: opts.Think,
} }
return client.Generate(cmd.Context(), req, func(api.GenerateResponse) error { return nil }) return client.Generate(cmd.Context(), req, func(api.GenerateResponse) error { return nil })
@@ -320,22 +277,6 @@ func RunHandler(cmd *cobra.Command, args []string) error {
} }
opts.Format = format opts.Format = format
thinkFlag := cmd.Flags().Lookup("think")
if thinkFlag.Changed {
think, err := cmd.Flags().GetBool("think")
if err != nil {
return err
}
opts.Think = &think
} else {
opts.Think = nil
}
hidethinking, err := cmd.Flags().GetBool("hidethinking")
if err != nil {
return err
}
opts.HideThinking = hidethinking
keepAlive, err := cmd.Flags().GetString("keepalive") keepAlive, err := cmd.Flags().GetString("keepalive")
if err != nil { if err != nil {
return err return err
@@ -399,11 +340,6 @@ func RunHandler(cmd *cobra.Command, args []string) error {
return err return err
} }
opts.Think, err = inferThinkingOption(&info.Capabilities, &opts, thinkFlag.Changed)
if err != nil {
return err
}
opts.MultiModal = slices.Contains(info.Capabilities, model.CapabilityVision) opts.MultiModal = slices.Contains(info.Capabilities, model.CapabilityVision)
// TODO: remove the projector info and vision info checks below, // TODO: remove the projector info and vision info checks below,
@@ -789,38 +725,11 @@ func showInfo(resp *api.ShowResponse, verbose bool, w io.Writer) error {
case float64: case float64:
v = fmt.Sprintf("%g", vData) v = fmt.Sprintf("%g", vData)
case []any: case []any:
targetWidth := 10 // Small width where we are displaying the data in a column n := 3
if len(vData) < n {
var itemsToShow int n = len(vData)
totalWidth := 1 // Start with 1 for opening bracket
// Find how many we can fit
for i := range vData {
itemStr := fmt.Sprintf("%v", vData[i])
width := runewidth.StringWidth(itemStr)
// Add separator width (", ") for all items except the first
if i > 0 {
width += 2
}
// Check if adding this item would exceed our width limit
if totalWidth+width > targetWidth && i > 0 {
break
}
totalWidth += width
itemsToShow++
}
// Format the output
if itemsToShow < len(vData) {
v = fmt.Sprintf("%v", vData[:itemsToShow])
v = strings.TrimSuffix(v, "]")
v += fmt.Sprintf(" ...+%d more]", len(vData)-itemsToShow)
} else {
v = fmt.Sprintf("%v", vData)
} }
v = fmt.Sprintf("%v", vData[:n])
default: default:
v = fmt.Sprintf("%T", vData) v = fmt.Sprintf("%T", vData)
} }
@@ -841,19 +750,10 @@ func showInfo(resp *api.ShowResponse, verbose bool, w io.Writer) error {
head := func(s string, n int) (rows [][]string) { head := func(s string, n int) (rows [][]string) {
scanner := bufio.NewScanner(strings.NewReader(s)) scanner := bufio.NewScanner(strings.NewReader(s))
count := 0 for scanner.Scan() && (len(rows) < n || n < 0) {
for scanner.Scan() { if text := scanner.Text(); text != "" {
text := strings.TrimSpace(scanner.Text()) rows = append(rows, []string{"", strings.TrimSpace(text)})
if text == "" {
continue
} }
count++
if n < 0 || count <= n {
rows = append(rows, []string{"", text})
}
}
if n >= 0 && count > n {
rows = append(rows, []string{"", "..."})
} }
return return
} }
@@ -908,38 +808,13 @@ func PullHandler(cmd *cobra.Command, args []string) error {
fn := func(resp api.ProgressResponse) error { fn := func(resp api.ProgressResponse) error {
if resp.Digest != "" { if resp.Digest != "" {
if resp.Completed == 0 {
// This is the initial status update for the
// layer, which the server sends before
// beginning the download, for clients to
// compute total size and prepare for
// downloads, if needed.
//
// Skipping this here to avoid showing a 0%
// progress bar, which *should* clue the user
// into the fact that many things are being
// downloaded and that the current active
// download is not that last. However, in rare
// cases it seems to be triggering to some, and
// it isn't worth explaining, so just ignore
// and regress to the old UI that keeps giving
// you the "But wait, there is more!" after
// each "100% done" bar, which is "better."
return nil
}
if spinner != nil { if spinner != nil {
spinner.Stop() spinner.Stop()
} }
bar, ok := bars[resp.Digest] bar, ok := bars[resp.Digest]
if !ok { if !ok {
name, isDigest := strings.CutPrefix(resp.Digest, "sha256:") bar = progress.NewBar(fmt.Sprintf("pulling %s...", resp.Digest[7:19]), resp.Total, resp.Completed)
name = strings.TrimSpace(name)
if isDigest {
name = name[:min(12, len(name))]
}
bar = progress.NewBar(fmt.Sprintf("pulling %s:", name), resp.Total, resp.Completed)
bars[resp.Digest] = bar bars[resp.Digest] = bar
p.Add(resp.Digest, bar) p.Add(resp.Digest, bar)
} }
@@ -959,25 +834,27 @@ func PullHandler(cmd *cobra.Command, args []string) error {
} }
request := api.PullRequest{Name: args[0], Insecure: insecure} request := api.PullRequest{Name: args[0], Insecure: insecure}
return client.Pull(cmd.Context(), &request, fn) if err := client.Pull(cmd.Context(), &request, fn); err != nil {
return err
}
return nil
} }
type generateContextKey string type generateContextKey string
type runOptions struct { type runOptions struct {
Model string Model string
ParentModel string ParentModel string
Prompt string Prompt string
Messages []api.Message Messages []api.Message
WordWrap bool WordWrap bool
Format string Format string
System string System string
Images []api.ImageData Images []api.ImageData
Options map[string]any Options map[string]any
MultiModal bool MultiModal bool
KeepAlive *api.Duration KeepAlive *api.Duration
Think *bool
HideThinking bool
} }
type displayResponseState struct { type displayResponseState struct {
@@ -1033,26 +910,6 @@ func displayResponse(content string, wordWrap bool, state *displayResponseState)
} }
} }
func thinkingOutputOpeningText(plainText bool) string {
text := "Thinking...\n"
if plainText {
return text
}
return readline.ColorGrey + readline.ColorBold + text + readline.ColorDefault + readline.ColorGrey
}
func thinkingOutputClosingText(plainText bool) string {
text := "...done thinking.\n\n"
if plainText {
return text
}
return readline.ColorGrey + readline.ColorBold + text + readline.ColorDefault
}
func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) { func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
client, err := api.ClientFromEnvironment() client, err := api.ClientFromEnvironment()
if err != nil { if err != nil {
@@ -1080,34 +937,14 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
var latest api.ChatResponse var latest api.ChatResponse
var fullResponse strings.Builder var fullResponse strings.Builder
var role string var role string
var thinkTagOpened bool = false
var thinkTagClosed bool = false
fn := func(response api.ChatResponse) error { fn := func(response api.ChatResponse) error {
if response.Message.Content != "" || !opts.HideThinking { p.StopAndClear()
p.StopAndClear()
}
latest = response latest = response
role = response.Message.Role role = response.Message.Role
if response.Message.Thinking != "" && !opts.HideThinking {
if !thinkTagOpened {
fmt.Print(thinkingOutputOpeningText(false))
thinkTagOpened = true
}
displayResponse(response.Message.Thinking, opts.WordWrap, state)
}
content := response.Message.Content content := response.Message.Content
if thinkTagOpened && !thinkTagClosed && content != "" {
fmt.Print(thinkingOutputClosingText(false))
thinkTagClosed = true
}
// purposefully not putting thinking blocks in the response, which would
// only be needed if we later added tool calling to the cli (they get
// filtered out anyway since current models don't expect them unless you're
// about to finish some tool calls)
fullResponse.WriteString(content) fullResponse.WriteString(content)
displayResponse(content, opts.WordWrap, state) displayResponse(content, opts.WordWrap, state)
@@ -1124,7 +961,6 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
Messages: opts.Messages, Messages: opts.Messages,
Format: json.RawMessage(opts.Format), Format: json.RawMessage(opts.Format),
Options: opts.Options, Options: opts.Options,
Think: opts.Think,
} }
if opts.KeepAlive != nil { if opts.KeepAlive != nil {
@@ -1186,32 +1022,13 @@ func generate(cmd *cobra.Command, opts runOptions) error {
}() }()
var state *displayResponseState = &displayResponseState{} var state *displayResponseState = &displayResponseState{}
var thinkTagOpened bool = false
var thinkTagClosed bool = false
plainText := !term.IsTerminal(int(os.Stdout.Fd()))
fn := func(response api.GenerateResponse) error { fn := func(response api.GenerateResponse) error {
p.StopAndClear()
latest = response latest = response
content := response.Response content := response.Response
if response.Response != "" || !opts.HideThinking {
p.StopAndClear()
}
if response.Thinking != "" && !opts.HideThinking {
if !thinkTagOpened {
fmt.Print(thinkingOutputOpeningText(plainText))
thinkTagOpened = true
}
displayResponse(response.Thinking, opts.WordWrap, state)
}
if thinkTagOpened && !thinkTagClosed && content != "" {
fmt.Print(thinkingOutputClosingText(plainText))
thinkTagClosed = true
}
displayResponse(content, opts.WordWrap, state) displayResponse(content, opts.WordWrap, state)
return nil return nil
@@ -1237,7 +1054,6 @@ func generate(cmd *cobra.Command, opts runOptions) error {
System: opts.System, System: opts.System,
Options: opts.Options, Options: opts.Options,
KeepAlive: opts.KeepAlive, KeepAlive: opts.KeepAlive,
Think: opts.Think,
} }
if err := client.Generate(ctx, &request, fn); err != nil { if err := client.Generate(ctx, &request, fn); err != nil {
@@ -1341,11 +1157,11 @@ func checkServerHeartbeat(cmd *cobra.Command, _ []string) error {
return err return err
} }
if err := client.Heartbeat(cmd.Context()); err != nil { if err := client.Heartbeat(cmd.Context()); err != nil {
if !(strings.Contains(err.Error(), " refused") || strings.Contains(err.Error(), "could not connect")) { if !strings.Contains(err.Error(), " refused") {
return err return err
} }
if err := startApp(cmd.Context(), client); err != nil { if err := startApp(cmd.Context(), client); err != nil {
return fmt.Errorf("ollama server not responding - %w", err) return errors.New("could not connect to ollama app, is it running?")
} }
} }
return nil return nil
@@ -1423,7 +1239,7 @@ func NewCLI() *cobra.Command {
} }
createCmd.Flags().StringP("file", "f", "", "Name of the Modelfile (default \"Modelfile\"") createCmd.Flags().StringP("file", "f", "", "Name of the Modelfile (default \"Modelfile\"")
createCmd.Flags().StringP("quantize", "q", "", "Quantize model to this level (e.g. q4_K_M)") createCmd.Flags().StringP("quantize", "q", "", "Quantize model to this level (e.g. q4_0)")
showCmd := &cobra.Command{ showCmd := &cobra.Command{
Use: "show MODEL", Use: "show MODEL",
@@ -1453,8 +1269,6 @@ func NewCLI() *cobra.Command {
runCmd.Flags().Bool("insecure", false, "Use an insecure registry") runCmd.Flags().Bool("insecure", false, "Use an insecure registry")
runCmd.Flags().Bool("nowordwrap", false, "Don't wrap words to the next line automatically") runCmd.Flags().Bool("nowordwrap", false, "Don't wrap words to the next line automatically")
runCmd.Flags().String("format", "", "Response format (e.g. json)") runCmd.Flags().String("format", "", "Response format (e.g. json)")
runCmd.Flags().Bool("think", false, "Whether to use thinking mode for supported models")
runCmd.Flags().Bool("hidethinking", false, "Hide thinking output (if provided)")
stopCmd := &cobra.Command{ stopCmd := &cobra.Command{
Use: "stop MODEL", Use: "stop MODEL",
@@ -1506,6 +1320,7 @@ func NewCLI() *cobra.Command {
PreRunE: checkServerHeartbeat, PreRunE: checkServerHeartbeat,
RunE: ListRunningHandler, RunE: ListRunningHandler,
} }
copyCmd := &cobra.Command{ copyCmd := &cobra.Command{
Use: "cp SOURCE DESTINATION", Use: "cp SOURCE DESTINATION",
Short: "Copy a model", Short: "Copy a model",
@@ -1594,45 +1409,3 @@ func NewCLI() *cobra.Command {
return rootCmd return rootCmd
} }
// If the user has explicitly set thinking options, either through the CLI or
// through the `/set think` or `set nothink` interactive options, then we
// respect them. Otherwise, we check model capabilities to see if the model
// supports thinking. If the model does support thinking, we enable it.
// Otherwise, we unset the thinking option (which is different than setting it
// to false).
//
// If capabilities are not provided, we fetch them from the server.
func inferThinkingOption(caps *[]model.Capability, runOpts *runOptions, explicitlySetByUser bool) (*bool, error) {
if explicitlySetByUser {
return runOpts.Think, nil
}
if caps == nil {
client, err := api.ClientFromEnvironment()
if err != nil {
return nil, err
}
ret, err := client.Show(context.Background(), &api.ShowRequest{
Model: runOpts.Model,
})
if err != nil {
return nil, err
}
caps = &ret.Capabilities
}
thinkingSupported := false
for _, cap := range *caps {
if cap == model.CapabilityThinking {
thinkingSupported = true
}
}
if thinkingSupported {
thinking := true
return &thinking, nil
}
return nil, nil
}

View File

@@ -2,6 +2,7 @@ package cmd
import ( import (
"bytes" "bytes"
"context"
"encoding/json" "encoding/json"
"io" "io"
"net/http" "net/http"
@@ -225,7 +226,6 @@ Weigh anchor!
System System
You are a pirate! You are a pirate!
Ahoy, matey! Ahoy, matey!
...
` `
if diff := cmp.Diff(expect, b.String()); diff != "" { if diff := cmp.Diff(expect, b.String()); diff != "" {
@@ -337,7 +337,7 @@ func TestDeleteHandler(t *testing.T) {
t.Cleanup(mockServer.Close) t.Cleanup(mockServer.Close)
cmd := &cobra.Command{} cmd := &cobra.Command{}
cmd.SetContext(t.Context()) cmd.SetContext(context.TODO())
if err := DeleteHandler(cmd, []string{"test-model"}); err != nil { if err := DeleteHandler(cmd, []string{"test-model"}); err != nil {
t.Fatalf("DeleteHandler failed: %v", err) t.Fatalf("DeleteHandler failed: %v", err)
} }
@@ -399,6 +399,11 @@ func TestGetModelfileName(t *testing.T) {
var expectedFilename string var expectedFilename string
if tt.fileExists { if tt.fileExists {
tempDir, err := os.MkdirTemp("", "modelfiledir")
defer os.RemoveAll(tempDir)
if err != nil {
t.Fatalf("temp modelfile dir creation failed: %v", err)
}
var fn string var fn string
if tt.modelfileName != "" { if tt.modelfileName != "" {
fn = tt.modelfileName fn = tt.modelfileName
@@ -406,11 +411,10 @@ func TestGetModelfileName(t *testing.T) {
fn = "Modelfile" fn = "Modelfile"
} }
tempFile, err := os.CreateTemp(t.TempDir(), fn) tempFile, err := os.CreateTemp(tempDir, fn)
if err != nil { if err != nil {
t.Fatalf("temp modelfile creation failed: %v", err) t.Fatalf("temp modelfile creation failed: %v", err)
} }
defer tempFile.Close()
expectedFilename = tempFile.Name() expectedFilename = tempFile.Name()
err = cmd.Flags().Set("file", expectedFilename) err = cmd.Flags().Set("file", expectedFilename)
@@ -525,7 +529,7 @@ func TestPushHandler(t *testing.T) {
cmd := &cobra.Command{} cmd := &cobra.Command{}
cmd.Flags().Bool("insecure", false, "") cmd.Flags().Bool("insecure", false, "")
cmd.SetContext(t.Context()) cmd.SetContext(context.TODO())
// Redirect stderr to capture progress output // Redirect stderr to capture progress output
oldStderr := os.Stderr oldStderr := os.Stderr
@@ -630,7 +634,7 @@ func TestListHandler(t *testing.T) {
t.Setenv("OLLAMA_HOST", mockServer.URL) t.Setenv("OLLAMA_HOST", mockServer.URL)
cmd := &cobra.Command{} cmd := &cobra.Command{}
cmd.SetContext(t.Context()) cmd.SetContext(context.TODO())
// Capture stdout // Capture stdout
oldStdout := os.Stdout oldStdout := os.Stdout
@@ -685,7 +689,7 @@ func TestCreateHandler(t *testing.T) {
return return
} }
if req.Model != "test-model" { if req.Name != "test-model" {
t.Errorf("expected model name 'test-model', got %s", req.Name) t.Errorf("expected model name 'test-model', got %s", req.Name)
} }
@@ -725,7 +729,7 @@ func TestCreateHandler(t *testing.T) {
})) }))
t.Setenv("OLLAMA_HOST", mockServer.URL) t.Setenv("OLLAMA_HOST", mockServer.URL)
t.Cleanup(mockServer.Close) t.Cleanup(mockServer.Close)
tempFile, err := os.CreateTemp(t.TempDir(), "modelfile") tempFile, err := os.CreateTemp("", "modelfile")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -745,7 +749,7 @@ func TestCreateHandler(t *testing.T) {
} }
cmd.Flags().Bool("insecure", false, "") cmd.Flags().Bool("insecure", false, "")
cmd.SetContext(t.Context()) cmd.SetContext(context.TODO())
// Redirect stderr to capture progress output // Redirect stderr to capture progress output
oldStderr := os.Stderr oldStderr := os.Stderr

View File

@@ -44,7 +44,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
fmt.Fprintln(os.Stderr, "Use \"\"\" to begin a multi-line message.") fmt.Fprintln(os.Stderr, "Use \"\"\" to begin a multi-line message.")
if opts.MultiModal { if opts.MultiModal {
fmt.Fprintf(os.Stderr, "Use %s to include .jpg, .png, or .webp images.\n", filepath.FromSlash("/path/to/file")) fmt.Fprintf(os.Stderr, "Use %s to include .jpg or .png images.\n", filepath.FromSlash("/path/to/file"))
} }
fmt.Fprintln(os.Stderr, "") fmt.Fprintln(os.Stderr, "")
@@ -62,8 +62,6 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
fmt.Fprintln(os.Stderr, " /set noformat Disable formatting") fmt.Fprintln(os.Stderr, " /set noformat Disable formatting")
fmt.Fprintln(os.Stderr, " /set verbose Show LLM stats") fmt.Fprintln(os.Stderr, " /set verbose Show LLM stats")
fmt.Fprintln(os.Stderr, " /set quiet Disable LLM stats") fmt.Fprintln(os.Stderr, " /set quiet Disable LLM stats")
fmt.Fprintln(os.Stderr, " /set think Enable thinking")
fmt.Fprintln(os.Stderr, " /set nothink Disable thinking")
fmt.Fprintln(os.Stderr, "") fmt.Fprintln(os.Stderr, "")
} }
@@ -130,7 +128,6 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
var sb strings.Builder var sb strings.Builder
var multiline MultilineState var multiline MultilineState
var thinkExplicitlySet bool = opts.Think != nil
for { for {
line, err := scanner.Readline() line, err := scanner.Readline()
@@ -198,19 +195,11 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
opts.Model = args[1] opts.Model = args[1]
opts.Messages = []api.Message{} opts.Messages = []api.Message{}
fmt.Printf("Loading model '%s'\n", opts.Model) fmt.Printf("Loading model '%s'\n", opts.Model)
opts.Think, err = inferThinkingOption(nil, &opts, thinkExplicitlySet)
if err != nil {
return err
}
if err := loadOrUnloadModel(cmd, &opts); err != nil { if err := loadOrUnloadModel(cmd, &opts); err != nil {
if strings.Contains(err.Error(), "not found") { if strings.Contains(err.Error(), "not found") {
fmt.Printf("error: %v\n", err) fmt.Printf("error: %v\n", err)
continue continue
} }
if strings.Contains(err.Error(), "does not support thinking") {
fmt.Printf("error: %v\n", err)
continue
}
return err return err
} }
continue continue
@@ -271,22 +260,6 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
return err return err
} }
fmt.Println("Set 'quiet' mode.") fmt.Println("Set 'quiet' mode.")
case "think":
think := true
opts.Think = &think
thinkExplicitlySet = true
if client, err := api.ClientFromEnvironment(); err == nil {
ensureThinkingSupport(cmd.Context(), client, opts.Model)
}
fmt.Println("Set 'think' mode.")
case "nothink":
think := false
opts.Think = &think
thinkExplicitlySet = true
if client, err := api.ClientFromEnvironment(); err == nil {
ensureThinkingSupport(cmd.Context(), client, opts.Model)
}
fmt.Println("Set 'nothink' mode.")
case "format": case "format":
if len(args) < 3 || args[2] != "json" { if len(args) < 3 || args[2] != "json" {
fmt.Println("Invalid or missing format. For 'json' mode use '/set format json'") fmt.Println("Invalid or missing format. For 'json' mode use '/set format json'")
@@ -475,11 +448,6 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
assistant, err := chat(cmd, opts) assistant, err := chat(cmd, opts)
if err != nil { if err != nil {
if strings.Contains(err.Error(), "does not support thinking") {
fmt.Printf("error: %v\n", err)
sb.Reset()
continue
}
return err return err
} }
if assistant != nil { if assistant != nil {
@@ -535,7 +503,6 @@ func normalizeFilePath(fp string) string {
"\\\\", "\\", // Escaped backslash "\\\\", "\\", // Escaped backslash
"\\*", "*", // Escaped asterisk "\\*", "*", // Escaped asterisk
"\\?", "?", // Escaped question mark "\\?", "?", // Escaped question mark
"\\~", "~", // Escaped tilde
).Replace(fp) ).Replace(fp)
} }
@@ -543,7 +510,7 @@ func extractFileNames(input string) []string {
// Regex to match file paths starting with optional drive letter, / ./ \ or .\ and include escaped or unescaped spaces (\ or %20) // Regex to match file paths starting with optional drive letter, / ./ \ or .\ and include escaped or unescaped spaces (\ or %20)
// and followed by more characters and a file extension // and followed by more characters and a file extension
// This will capture non filename strings, but we'll check for file existence to remove mismatches // This will capture non filename strings, but we'll check for file existence to remove mismatches
regexPattern := `(?:[a-zA-Z]:)?(?:\./|/|\\)[\S\\ ]+?\.(?i:jpg|jpeg|png|webp)\b` regexPattern := `(?:[a-zA-Z]:)?(?:\./|/|\\)[\S\\ ]+?\.(?i:jpg|jpeg|png)\b`
re := regexp.MustCompile(regexPattern) re := regexp.MustCompile(regexPattern)
return re.FindAllString(input, -1) return re.FindAllString(input, -1)
@@ -563,8 +530,6 @@ func extractFileData(input string) (string, []api.ImageData, error) {
return "", imgs, err return "", imgs, err
} }
fmt.Fprintf(os.Stderr, "Added image '%s'\n", nfp) fmt.Fprintf(os.Stderr, "Added image '%s'\n", nfp)
input = strings.ReplaceAll(input, "'"+nfp+"'", "")
input = strings.ReplaceAll(input, "'"+fp+"'", "")
input = strings.ReplaceAll(input, fp, "") input = strings.ReplaceAll(input, fp, "")
imgs = append(imgs, data) imgs = append(imgs, data)
} }
@@ -585,7 +550,7 @@ func getImageData(filePath string) ([]byte, error) {
} }
contentType := http.DetectContentType(buf) contentType := http.DetectContentType(buf)
allowedTypes := []string{"image/jpeg", "image/jpg", "image/png", "image/webp"} allowedTypes := []string{"image/jpeg", "image/jpg", "image/png"}
if !slices.Contains(allowedTypes, contentType) { if !slices.Contains(allowedTypes, contentType) {
return nil, fmt.Errorf("invalid image type: %s", contentType) return nil, fmt.Errorf("invalid image type: %s", contentType)
} }

View File

@@ -1,8 +1,6 @@
package cmd package cmd
import ( import (
"os"
"path/filepath"
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
@@ -12,17 +10,14 @@ func TestExtractFilenames(t *testing.T) {
// Unix style paths // Unix style paths
input := ` some preamble input := ` some preamble
./relative\ path/one.png inbetween1 ./not a valid two.jpg inbetween2 ./1.svg ./relative\ path/one.png inbetween1 ./not a valid two.jpg inbetween2 ./1.svg
/unescaped space /three.jpeg inbetween3 /valid\ path/dir/four.png "./quoted with spaces/five.JPG /unescaped space /three.jpeg inbetween3 /valid\ path/dir/four.png "./quoted with spaces/five.JPG`
/unescaped space /six.webp inbetween6 /valid\ path/dir/seven.WEBP`
res := extractFileNames(input) res := extractFileNames(input)
assert.Len(t, res, 7) assert.Len(t, res, 5)
assert.Contains(t, res[0], "one.png") assert.Contains(t, res[0], "one.png")
assert.Contains(t, res[1], "two.jpg") assert.Contains(t, res[1], "two.jpg")
assert.Contains(t, res[2], "three.jpeg") assert.Contains(t, res[2], "three.jpeg")
assert.Contains(t, res[3], "four.png") assert.Contains(t, res[3], "four.png")
assert.Contains(t, res[4], "five.JPG") assert.Contains(t, res[4], "five.JPG")
assert.Contains(t, res[5], "six.webp")
assert.Contains(t, res[6], "seven.WEBP")
assert.NotContains(t, res[4], '"') assert.NotContains(t, res[4], '"')
assert.NotContains(t, res, "inbetween1") assert.NotContains(t, res, "inbetween1")
assert.NotContains(t, res, "./1.svg") assert.NotContains(t, res, "./1.svg")
@@ -33,12 +28,10 @@ func TestExtractFilenames(t *testing.T) {
/absolute/nospace/three.jpeg inbetween3 /absolute/with space/four.png inbetween4 /absolute/nospace/three.jpeg inbetween3 /absolute/with space/four.png inbetween4
./relative\ path/five.JPG inbetween5 "./relative with/spaces/six.png inbetween6 ./relative\ path/five.JPG inbetween5 "./relative with/spaces/six.png inbetween6
d:\path with\spaces\seven.JPEG inbetween7 c:\users\jdoe\eight.png inbetween8 d:\path with\spaces\seven.JPEG inbetween7 c:\users\jdoe\eight.png inbetween8
d:\program files\someplace\nine.png inbetween9 "E:\program files\someplace\ten.PNG d:\program files\someplace\nine.png inbetween9 "E:\program files\someplace\ten.PNG some ending
c:/users/jdoe/eleven.webp inbetween11 c:/program files/someplace/twelve.WebP inbetween12
d:\path with\spaces\thirteen.WEBP some ending
` `
res = extractFileNames(input) res = extractFileNames(input)
assert.Len(t, res, 13) assert.Len(t, res, 10)
assert.NotContains(t, res, "inbetween2") assert.NotContains(t, res, "inbetween2")
assert.Contains(t, res[0], "one.png") assert.Contains(t, res[0], "one.png")
assert.Contains(t, res[0], "c:") assert.Contains(t, res[0], "c:")
@@ -56,31 +49,4 @@ d:\path with\spaces\thirteen.WEBP some ending
assert.Contains(t, res[8], "d:") assert.Contains(t, res[8], "d:")
assert.Contains(t, res[9], "ten.PNG") assert.Contains(t, res[9], "ten.PNG")
assert.Contains(t, res[9], "E:") assert.Contains(t, res[9], "E:")
assert.Contains(t, res[10], "eleven.webp")
assert.Contains(t, res[10], "c:")
assert.Contains(t, res[11], "twelve.WebP")
assert.Contains(t, res[11], "c:")
assert.Contains(t, res[12], "thirteen.WEBP")
assert.Contains(t, res[12], "d:")
}
// Ensure that file paths wrapped in single quotes are removed with the quotes.
func TestExtractFileDataRemovesQuotedFilepath(t *testing.T) {
dir := t.TempDir()
fp := filepath.Join(dir, "img.jpg")
data := make([]byte, 600)
copy(data, []byte{
0xff, 0xd8, 0xff, 0xe0, 0x00, 0x10, 'J', 'F', 'I', 'F',
0x00, 0x01, 0x01, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0xff, 0xd9,
})
if err := os.WriteFile(fp, data, 0o600); err != nil {
t.Fatalf("failed to write test image: %v", err)
}
input := "before '" + fp + "' after"
cleaned, imgs, err := extractFileData(input)
assert.NoError(t, err)
assert.Len(t, imgs, 1)
assert.Equal(t, cleaned, "before after")
} }

View File

@@ -5,7 +5,7 @@ import (
"errors" "errors"
"os" "os"
"os/exec" "os/exec"
"regexp" "strings"
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
) )
@@ -19,12 +19,11 @@ func startApp(ctx context.Context, client *api.Client) error {
if err != nil { if err != nil {
return err return err
} }
r := regexp.MustCompile(`^.*/Ollama\s?\d*.app`) if !strings.Contains(link, "Ollama.app") {
m := r.FindStringSubmatch(link)
if len(m) != 1 {
return errors.New("could not find ollama app") return errors.New("could not find ollama app")
} }
if err := exec.Command("/usr/bin/open", "-j", "-a", m[0], "--args", "--fast-startup").Run(); err != nil { path := strings.Split(link, "Ollama.app")
if err := exec.Command("/usr/bin/open", "-a", path[0]+"Ollama.app").Run(); err != nil {
return err return err
} }
return waitForServer(ctx, client) return waitForServer(ctx, client)

View File

@@ -4,27 +4,17 @@ import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
"log/slog"
"os" "os"
"os/exec" "os/exec"
"path"
"path/filepath" "path/filepath"
"strings" "strings"
"syscall" "syscall"
"unsafe"
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
"golang.org/x/sys/windows"
)
const (
Installer = "OllamaSetup.exe"
) )
func startApp(ctx context.Context, client *api.Client) error { func startApp(ctx context.Context, client *api.Client) error {
if len(isProcRunning(Installer)) > 0 { // log.Printf("XXX Attempting to find and start ollama app")
return fmt.Errorf("upgrade in progress...")
}
AppName := "ollama app.exe" AppName := "ollama app.exe"
exe, err := os.Executable() exe, err := os.Executable()
if err != nil { if err != nil {
@@ -45,11 +35,14 @@ func startApp(ctx context.Context, client *api.Client) error {
} }
} }
} }
// log.Printf("XXX attempting to start app %s", appExe)
cmd_path := "c:\\Windows\\system32\\cmd.exe" cmd_path := "c:\\Windows\\system32\\cmd.exe"
cmd := exec.Command(cmd_path, "/c", appExe, "--hide", "--fast-startup") cmd := exec.Command(cmd_path, "/c", appExe)
// TODO - these hide flags aren't working - still pops up a command window for some reason
cmd.SysProcAttr = &syscall.SysProcAttr{CreationFlags: 0x08000000, HideWindow: true} cmd.SysProcAttr = &syscall.SysProcAttr{CreationFlags: 0x08000000, HideWindow: true}
// TODO this didn't help either...
cmd.Stdin = strings.NewReader("") cmd.Stdin = strings.NewReader("")
cmd.Stdout = os.Stdout cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr cmd.Stderr = os.Stderr
@@ -63,50 +56,3 @@ func startApp(ctx context.Context, client *api.Client) error {
} }
return waitForServer(ctx, client) return waitForServer(ctx, client)
} }
func isProcRunning(procName string) []uint32 {
pids := make([]uint32, 2048)
var ret uint32
if err := windows.EnumProcesses(pids, &ret); err != nil || ret == 0 {
slog.Debug("failed to check for running installers", "error", err)
return nil
}
if ret > uint32(len(pids)) {
pids = make([]uint32, ret+10)
if err := windows.EnumProcesses(pids, &ret); err != nil || ret == 0 {
slog.Debug("failed to check for running installers", "error", err)
return nil
}
}
if ret < uint32(len(pids)) {
pids = pids[:ret]
}
var matches []uint32
for _, pid := range pids {
if pid == 0 {
continue
}
hProcess, err := windows.OpenProcess(windows.PROCESS_QUERY_INFORMATION|windows.PROCESS_VM_READ, false, pid)
if err != nil {
continue
}
defer windows.CloseHandle(hProcess)
var module windows.Handle
var cbNeeded uint32
cb := (uint32)(unsafe.Sizeof(module))
if err := windows.EnumProcessModules(hProcess, &module, cb, &cbNeeded); err != nil {
continue
}
var sz uint32 = 1024 * 8
moduleName := make([]uint16, sz)
cb = uint32(len(moduleName)) * (uint32)(unsafe.Sizeof(uint16(0)))
if err := windows.GetModuleBaseName(hProcess, module, &moduleName[0], cb); err != nil && err != syscall.ERROR_INSUFFICIENT_BUFFER {
continue
}
exeFile := path.Base(strings.ToLower(syscall.UTF16ToString(moduleName)))
if strings.EqualFold(exeFile, procName) {
matches = append(matches, pid)
}
}
return matches
}

View File

@@ -1,63 +0,0 @@
package cmd
import (
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"os"
"strings"
"testing"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/types/model"
)
// Test that a warning is printed when thinking is requested but not supported.
func TestWarnMissingThinking(t *testing.T) {
cases := []struct {
capabilities []model.Capability
expectWarn bool
}{
{capabilities: []model.Capability{model.CapabilityThinking}, expectWarn: false},
{capabilities: []model.Capability{}, expectWarn: true},
}
for _, tc := range cases {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/api/show" || r.Method != http.MethodPost {
t.Fatalf("unexpected request to %s %s", r.URL.Path, r.Method)
}
var req api.ShowRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
t.Fatalf("decode request: %v", err)
}
resp := api.ShowResponse{Capabilities: tc.capabilities}
if err := json.NewEncoder(w).Encode(resp); err != nil {
t.Fatalf("encode response: %v", err)
}
}))
defer srv.Close()
t.Setenv("OLLAMA_HOST", srv.URL)
client, err := api.ClientFromEnvironment()
if err != nil {
t.Fatal(err)
}
oldStderr := os.Stderr
r, w, _ := os.Pipe()
os.Stderr = w
ensureThinkingSupport(t.Context(), client, "m")
w.Close()
os.Stderr = oldStderr
out, _ := io.ReadAll(r)
warned := strings.Contains(string(out), "warning:")
if tc.expectWarn && !warned {
t.Errorf("expected warning, got none")
}
if !tc.expectWarn && warned {
t.Errorf("did not expect warning, got: %s", string(out))
}
}
}

View File

@@ -1,26 +1,25 @@
package convert package convert
import ( import (
"cmp"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"io"
"io/fs" "io/fs"
"log/slog" "log/slog"
"os"
"slices"
"strings" "strings"
"github.com/ollama/ollama/fs/ggml" "github.com/ollama/ollama/fs/ggml"
) )
type ModelParameters struct { type ModelParameters struct {
Architectures []string `json:"architectures"` Architectures []string `json:"architectures"`
VocabSize uint32 `json:"vocab_size"` VocabSize uint32 `json:"vocab_size"`
TextModel TextParameters `json:"text_config"`
}
TextModel struct { type TextParameters struct {
VocabSize uint32 `json:"vocab_size"` VocabSize uint32 `json:"vocab_size"`
} `json:"text_config"`
} }
type AdapterParameters struct { type AdapterParameters struct {
@@ -53,11 +52,8 @@ func (ModelParameters) KV(t *Tokenizer) ggml.KV {
} }
for _, sv := range t.SpecialVocabulary { for _, sv := range t.SpecialVocabulary {
kv[fmt.Sprintf("tokenizer.ggml.add_%s_token", sv.Key())] = sv.AddToken
kv[fmt.Sprintf("tokenizer.ggml.%s_token_id", sv.Key())] = uint32(sv.ID) kv[fmt.Sprintf("tokenizer.ggml.%s_token_id", sv.Key())] = uint32(sv.ID)
if len(sv.IDs) > 0 { kv[fmt.Sprintf("tokenizer.ggml.add_%s_token", sv.Key())] = sv.AddToken
kv[fmt.Sprintf("tokenizer.ggml.%s_token_ids", sv.Key())] = sv.IDs
}
} }
return kv return kv
@@ -88,17 +84,27 @@ func (ModelParameters) specialTokenTypes() []string {
} }
} }
func (ModelParameters) writeFile(ws io.WriteSeeker, kv ggml.KV, ts []ggml.Tensor) error {
return ggml.WriteGGUF(ws, kv, ts)
}
func (AdapterParameters) writeFile(ws io.WriteSeeker, kv ggml.KV, ts []ggml.Tensor) error {
return ggml.WriteGGUF(ws, kv, ts)
}
type ModelConverter interface { type ModelConverter interface {
// KV maps parameters to LLM key-values // KV maps parameters to LLM key-values
KV(*Tokenizer) ggml.KV KV(*Tokenizer) ggml.KV
// Tensors maps input tensors to LLM tensors. Model specific modifications can be done here. // Tensors maps input tensors to LLM tensors. Model specific modifications can be done here.
Tensors([]Tensor) []*ggml.Tensor Tensors([]Tensor) []ggml.Tensor
// Replacements returns a list of string pairs to replace in tensor names. // Replacements returns a list of string pairs to replace in tensor names.
// See [strings.Replacer](https://pkg.go.dev/strings#Replacer) for details // See [strings.Replacer](https://pkg.go.dev/strings#Replacer) for details
Replacements() []string Replacements() []string
// specialTokenTypes returns any special token types the model uses // specialTokenTypes returns any special token types the model uses
specialTokenTypes() []string specialTokenTypes() []string
// writeFile writes the model to the provided io.WriteSeeker
writeFile(io.WriteSeeker, ggml.KV, []ggml.Tensor) error
} }
type moreParser interface { type moreParser interface {
@@ -109,13 +115,15 @@ type AdapterConverter interface {
// KV maps parameters to LLM key-values // KV maps parameters to LLM key-values
KV(ggml.KV) ggml.KV KV(ggml.KV) ggml.KV
// Tensors maps input tensors to LLM tensors. Adapter specific modifications can be done here. // Tensors maps input tensors to LLM tensors. Adapter specific modifications can be done here.
Tensors([]Tensor) []*ggml.Tensor Tensors([]Tensor) []ggml.Tensor
// Replacements returns a list of string pairs to replace in tensor names. // Replacements returns a list of string pairs to replace in tensor names.
// See [strings.Replacer](https://pkg.go.dev/strings#Replacer) for details // See [strings.Replacer](https://pkg.go.dev/strings#Replacer) for details
Replacements() []string Replacements() []string
writeFile(io.WriteSeeker, ggml.KV, []ggml.Tensor) error
} }
func ConvertAdapter(fsys fs.FS, f *os.File, baseKV ggml.KV) error { func ConvertAdapter(fsys fs.FS, ws io.WriteSeeker, baseKV ggml.KV) error {
bts, err := fs.ReadFile(fsys, "adapter_config.json") bts, err := fs.ReadFile(fsys, "adapter_config.json")
if err != nil { if err != nil {
return err return err
@@ -150,14 +158,14 @@ func ConvertAdapter(fsys fs.FS, f *os.File, baseKV ggml.KV) error {
return err return err
} }
return writeFile(f, conv.KV(baseKV), conv.Tensors(ts)) return conv.writeFile(ws, conv.KV(baseKV), conv.Tensors(ts))
} }
// Convert writes an Ollama compatible model to the provided io.WriteSeeker based on configurations // Convert writes an Ollama compatible model to the provided io.WriteSeeker based on configurations
// and files it finds in the input path. // and files it finds in the input path.
// Supported input model formats include safetensors. // Supported input model formats include safetensors.
// Supported input tokenizers files include tokenizer.json (preferred) and tokenizer.model. // Supported input tokenizers files include tokenizer.json (preferred) and tokenizer.model.
func ConvertModel(fsys fs.FS, f *os.File) error { func ConvertModel(fsys fs.FS, ws io.WriteSeeker) error {
bts, err := fs.ReadFile(fsys, "config.json") bts, err := fs.ReadFile(fsys, "config.json")
if err != nil { if err != nil {
return err return err
@@ -176,10 +184,6 @@ func ConvertModel(fsys fs.FS, f *os.File) error {
switch p.Architectures[0] { switch p.Architectures[0] {
case "LlamaForCausalLM": case "LlamaForCausalLM":
conv = &llamaModel{} conv = &llamaModel{}
case "MllamaForConditionalGeneration":
conv = &mllamaModel{}
case "Llama4ForConditionalGeneration":
conv = &llama4Model{}
case "Mistral3ForConditionalGeneration": case "Mistral3ForConditionalGeneration":
conv = &mistral3Model{} conv = &mistral3Model{}
case "MixtralForCausalLM": case "MixtralForCausalLM":
@@ -190,14 +194,10 @@ func ConvertModel(fsys fs.FS, f *os.File) error {
conv = &gemma2Model{} conv = &gemma2Model{}
case "Gemma3ForCausalLM", "Gemma3ForConditionalGeneration": case "Gemma3ForCausalLM", "Gemma3ForConditionalGeneration":
conv = &gemma3Model{Architecture: p.Architectures[0]} conv = &gemma3Model{Architecture: p.Architectures[0]}
case "Gemma3nForConditionalGeneration":
conv = &gemma3nModel{}
case "Phi3ForCausalLM": case "Phi3ForCausalLM":
conv = &phi3Model{} conv = &phi3Model{}
case "Qwen2ForCausalLM": case "Qwen2ForCausalLM":
conv = &qwen2Model{} conv = &qwen2Model{}
case "Qwen2_5_VLForConditionalGeneration":
conv = &qwen25VLModel{}
case "BertModel": case "BertModel":
conv = &bertModel{} conv = &bertModel{}
case "CohereForCausalLM": case "CohereForCausalLM":
@@ -221,22 +221,24 @@ func ConvertModel(fsys fs.FS, f *os.File) error {
return err return err
} }
vocabSize := int(cmp.Or(p.VocabSize, p.TextModel.VocabSize)) vocabSize := int(p.VocabSize)
if vocabSize == 0 {
tVocabSize := int(p.TextModel.VocabSize)
vocabSize = tVocabSize
}
switch { switch {
case vocabSize == 0: case vocabSize == 0:
slog.Debug("vocabulary size was not explicitly set by the model", "default size", len(t.Vocabulary.Tokens)) slog.Warn("vocabulary size was not explicitly set by the model", "default size", len(t.Vocabulary.Tokens))
case vocabSize > len(t.Vocabulary.Tokens): case vocabSize > len(t.Vocabulary.Tokens):
slog.Debug("vocabulary is smaller than expected, padding with dummy tokens", "expect", vocabSize, "actual", len(t.Vocabulary.Tokens)) slog.Warn("vocabulary is smaller than expected, padding with dummy tokens", "expect", vocabSize, "actual", len(t.Vocabulary.Tokens))
for i := range vocabSize - len(t.Vocabulary.Tokens) { for i := range vocabSize - len(t.Vocabulary.Tokens) {
t.Vocabulary.Tokens = append(t.Vocabulary.Tokens, fmt.Sprintf("[PAD%d]", i)) t.Vocabulary.Tokens = append(t.Vocabulary.Tokens, fmt.Sprintf("[PAD%d]", i))
t.Vocabulary.Scores = append(t.Vocabulary.Scores, -1) t.Vocabulary.Scores = append(t.Vocabulary.Scores, -1)
t.Vocabulary.Types = append(t.Vocabulary.Types, tokenTypeUserDefined) t.Vocabulary.Types = append(t.Vocabulary.Types, tokenTypeUserDefined)
} }
case vocabSize < len(t.Vocabulary.Tokens): case vocabSize < len(t.Vocabulary.Tokens):
slog.Debug("vocabulary is larger than expected", "want", vocabSize, "got", len(t.Vocabulary.Tokens)) return fmt.Errorf("vocabulary is larger than expected '%d' instead of '%d'", len(t.Vocabulary.Tokens), vocabSize)
p.VocabSize = uint32(len(t.Vocabulary.Tokens))
p.TextModel.VocabSize = uint32(len(t.Vocabulary.Tokens))
default: default:
slog.Debug("vocabulary", "size", len(t.Vocabulary.Tokens)) slog.Debug("vocabulary", "size", len(t.Vocabulary.Tokens))
} }
@@ -246,13 +248,5 @@ func ConvertModel(fsys fs.FS, f *os.File) error {
return err return err
} }
return writeFile(f, conv.KV(t), conv.Tensors(ts)) return conv.writeFile(ws, conv.KV(t), conv.Tensors(ts))
}
func writeFile(f *os.File, kv ggml.KV, ts []*ggml.Tensor) error {
for i := range ts {
ts[i].Shape = slices.Clone(ts[i].Shape)
slices.Reverse(ts[i].Shape)
}
return ggml.WriteGGUF(f, kv, ts)
} }

View File

@@ -132,8 +132,8 @@ func (p *bertModel) KV(t *Tokenizer) ggml.KV {
return kv return kv
} }
func (p *bertModel) Tensors(ts []Tensor) []*ggml.Tensor { func (p *bertModel) Tensors(ts []Tensor) []ggml.Tensor {
var out []*ggml.Tensor var out []ggml.Tensor
for _, t := range ts { for _, t := range ts {
if slices.Contains([]string{ if slices.Contains([]string{
"embeddings.position_ids", "embeddings.position_ids",
@@ -143,7 +143,7 @@ func (p *bertModel) Tensors(ts []Tensor) []*ggml.Tensor {
continue continue
} }
out = append(out, &ggml.Tensor{ out = append(out, ggml.Tensor{
Name: t.Name(), Name: t.Name(),
Kind: t.Kind(), Kind: t.Kind(),
Shape: t.Shape(), Shape: t.Shape(),

View File

@@ -43,10 +43,10 @@ func (p *commandrModel) KV(t *Tokenizer) ggml.KV {
return kv return kv
} }
func (p *commandrModel) Tensors(ts []Tensor) []*ggml.Tensor { func (p *commandrModel) Tensors(ts []Tensor) []ggml.Tensor {
var out []*ggml.Tensor var out []ggml.Tensor
for _, t := range ts { for _, t := range ts {
out = append(out, &ggml.Tensor{ out = append(out, ggml.Tensor{
Name: t.Name(), Name: t.Name(),
Kind: t.Kind(), Kind: t.Kind(),
Shape: t.Shape(), Shape: t.Shape(),

View File

@@ -42,14 +42,14 @@ func (p *gemmaModel) KV(t *Tokenizer) ggml.KV {
return kv return kv
} }
func (p *gemmaModel) Tensors(ts []Tensor) []*ggml.Tensor { func (p *gemmaModel) Tensors(ts []Tensor) []ggml.Tensor {
var out []*ggml.Tensor var out []ggml.Tensor
for _, t := range ts { for _, t := range ts {
if !strings.HasPrefix(t.Name(), "v.") && strings.HasSuffix(t.Name(), "_norm.weight") { if !strings.HasPrefix(t.Name(), "v.") && strings.HasSuffix(t.Name(), "_norm.weight") {
t.SetRepacker(p.addOne) t.SetRepacker(p.addOne)
} }
out = append(out, &ggml.Tensor{ out = append(out, ggml.Tensor{
Name: t.Name(), Name: t.Name(),
Kind: t.Kind(), Kind: t.Kind(),
Shape: t.Shape(), Shape: t.Shape(),

View File

@@ -21,8 +21,8 @@ func (p *gemma2Adapter) KV(baseKV ggml.KV) ggml.KV {
return kv return kv
} }
func (p *gemma2Adapter) Tensors(ts []Tensor) []*ggml.Tensor { func (p *gemma2Adapter) Tensors(ts []Tensor) []ggml.Tensor {
var out []*ggml.Tensor var out []ggml.Tensor
for _, t := range ts { for _, t := range ts {
shape := t.Shape() shape := t.Shape()
if (strings.HasSuffix(t.Name(), "weight.lora_a") && shape[0] > shape[1]) || if (strings.HasSuffix(t.Name(), "weight.lora_a") && shape[0] > shape[1]) ||
@@ -31,7 +31,7 @@ func (p *gemma2Adapter) Tensors(ts []Tensor) []*ggml.Tensor {
t.SetRepacker(p.repack) t.SetRepacker(p.repack)
} }
out = append(out, &ggml.Tensor{ out = append(out, ggml.Tensor{
Name: t.Name(), Name: t.Name(),
Kind: t.Kind(), Kind: t.Kind(),
Shape: t.Shape(), Shape: t.Shape(),

View File

@@ -1,168 +0,0 @@
package convert
import (
"slices"
"strings"
"github.com/ollama/ollama/fs/ggml"
"github.com/pdevine/tensor"
"github.com/pdevine/tensor/native"
"gonum.org/v1/gonum/stat/distuv"
)
type gemma3nModel struct {
ModelParameters
TextModel struct {
ActivationSparsityPattern []float32 `json:"activation_sparsity_pattern"`
AltupActiveIdx uint32 `json:"altup_active_idx"`
AltupCoefClip float32 `json:"altup_coef_clip"`
AltupCorrectScale bool `json:"altup_correct_scale"`
AltupLRMultiplier float32 `json:"altup_lr_multiplier"`
AltupNumInputs uint32 `json:"altup_num_inputs"`
HeadDim uint32 `json:"head_dim"`
HiddenSize uint32 `json:"hidden_size"`
HiddenSizePerLayerInput uint32 `json:"hidden_size_per_layer_input"`
IntermediateSize uint32 `json:"intermediate_size"`
LaurelRank uint32 `json:"laurel_rank"`
MaxPositionEmbeddings uint32 `json:"max_position_embeddings"`
NumAttentionHeads uint32 `json:"num_attention_heads"`
NumHiddenLayers uint32 `json:"num_hidden_layers"`
NumKeyValueHeads uint32 `json:"num_key_value_heads"`
NumKVSharedLayers uint32 `json:"num_kv_shared_layers"`
RMSNormEPS float32 `json:"rms_norm_eps"`
RopeLocalBaseFreq float32 `json:"rope_local_base_freq"`
RopeTheta float32 `json:"rope_theta"`
SlidingWindow uint32 `json:"sliding_window"`
LayerTypes []string `json:"layer_types"`
} `json:"text_config"`
VisionModel struct{} `json:"vision_config"`
}
func (m *gemma3nModel) KV(t *Tokenizer) ggml.KV {
kv := m.ModelParameters.KV(t)
kv["general.architecture"] = "gemma3n"
kv["gemma3n.activation_sparsity_scale"] = slices.Collect(func(yield func(float32) bool) {
norm := distuv.Normal{Mu: 0, Sigma: 1}
for _, v := range m.TextModel.ActivationSparsityPattern {
if !yield(float32(norm.Quantile(float64(v)))) {
break
}
}
})
kv["gemma3n.altup.active_idx"] = m.TextModel.AltupActiveIdx
kv["gemma3n.altup.correct_scale"] = m.TextModel.AltupCorrectScale
kv["gemma3n.altup.lr_multiplier"] = m.TextModel.AltupLRMultiplier
kv["gemma3n.altup.num_inputs"] = m.TextModel.AltupNumInputs
kv["gemma3n.attention.head_count_kv"] = m.TextModel.NumKeyValueHeads
kv["gemma3n.attention.head_count"] = m.TextModel.NumAttentionHeads
kv["gemma3n.attention.layer_norm_rms_epsilon"] = m.TextModel.RMSNormEPS
kv["gemma3n.attention.sliding_window"] = m.TextModel.SlidingWindow
kv["gemma3n.attention.sliding_window_pattern"] = slices.Collect(func(yield func(bool) bool) {
for _, t := range m.TextModel.LayerTypes {
if !yield(t == "sliding_attention") {
break
}
}
})
kv["gemma3n.attention.shared_kv_layers"] = m.TextModel.NumKVSharedLayers
kv["gemma3n.block_count"] = m.TextModel.NumHiddenLayers
kv["gemma3n.context_length"] = m.TextModel.MaxPositionEmbeddings
kv["gemma3n.embedding_length_per_layer_input"] = m.TextModel.HiddenSizePerLayerInput
kv["gemma3n.embedding_length"] = m.TextModel.HiddenSize
kv["gemma3n.feed_forward_length"] = m.TextModel.IntermediateSize
kv["gemma3n.head_dim"] = m.TextModel.HeadDim
kv["gemma3n.laurel_rank"] = m.TextModel.LaurelRank
kv["gemma3n.num_kv_shared_layers"] = m.TextModel.NumKVSharedLayers
kv["gemma3n.rope.freq_base_local"] = m.TextModel.RopeLocalBaseFreq
kv["gemma3n.rope.freq_base"] = m.TextModel.RopeTheta
return kv
}
func (m *gemma3nModel) Tensors(ts []Tensor) []*ggml.Tensor {
out, ts := mergeTensors(ts,
merge{"altup_proj.*.weight", "altup_proj.weight"},
merge{"altup_unembd_proj.*.weight", "altup_unembd_proj.weight"},
)
for _, t := range ts {
switch {
case strings.Contains(t.Name(), "audio_tower"),
strings.Contains(t.Name(), "embed_audio"),
strings.Contains(t.Name(), "vision_tower"),
strings.Contains(t.Name(), "embed_vision"):
// TODO: handle audio and vision towers
continue
case strings.Contains(t.Name(), "altup_predict_coef"),
strings.Contains(t.Name(), "altup_correct_coef"):
if m.TextModel.AltupCoefClip > 0 {
t.SetRepacker(func(name string, data []float32, shape []uint64) (_ []float32, err error) {
dims := make([]int, len(shape))
for i := range shape {
dims[i] = int(shape[i])
}
var t tensor.Tensor = tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data))
t, err = tensor.Clamp(t, -m.TextModel.AltupCoefClip, m.TextModel.AltupCoefClip)
if err != nil {
return nil, err
}
if err := t.Reshape(t.Shape().TotalSize()); err != nil {
return nil, err
}
return native.VectorF32(t.(*tensor.Dense))
})
}
}
out = append(out, &ggml.Tensor{
Name: t.Name(),
Kind: t.Kind(),
Shape: t.Shape(),
WriterTo: t,
})
}
return out
}
func (m *gemma3nModel) Replacements() []string {
return []string{
"model.language_model.embed_tokens_per_layer", "per_layer_token_embd",
"model.language_model.embed_tokens", "token_embd",
"model.language_model.per_layer_model_projection", "per_layer_model_proj",
"model.language_model.per_layer_projection_norm", "per_layer_proj_norm", "model.language_model.altup_projections", "altup_proj",
"model.language_model.altup_unembed_projections", "altup_unembd_proj",
"model.language_model.norm", "output_norm",
"model.language_model.layers", "blk",
"input_layernorm", "attn_norm",
"self_attn.q_proj", "attn_q",
"self_attn.q_norm", "attn_q_norm",
"self_attn.k_proj", "attn_k",
"self_attn.k_norm", "attn_k_norm",
"self_attn.v_proj", "attn_v",
"self_attn.o_proj", "attn_output",
"post_attention_layernorm", "post_attention_norm",
"pre_feedforward_layernorm", "ffn_norm",
"mlp.gate_proj", "ffn_gate",
"mlp.up_proj", "ffn_up",
"mlp.down_proj", "ffn_down",
"post_feedforward_layernorm", "post_ffw_norm",
"per_layer_input_gate", "inp_gate",
"per_layer_projection", "proj",
"post_per_layer_input_norm", "post_norm",
"altup.", "altup_",
"modality_router", "router",
"prediction_coefs", "predict_coef",
"correction_coefs", "correct_coef",
"correct_output_scale", "correct_scale.weight",
"laurel.", "laurel_",
"linear_left", "l",
"linear_right", "r",
"post_laurel_norm", "post_norm",
}
}

View File

@@ -28,12 +28,12 @@ type llamaModel struct {
NumKeyValueHeads uint32 `json:"num_key_value_heads"` NumKeyValueHeads uint32 `json:"num_key_value_heads"`
RopeTheta float32 `json:"rope_theta"` RopeTheta float32 `json:"rope_theta"`
RopeScaling struct { RopeScaling struct {
Type string `json:"type"` Type string `json:"type"`
RopeType string `json:"rope_type"` RopeType string `json:"rope_type"`
Factor float32 `json:"factor"` Factor float32 `json:"factor"`
LowFrequencyFactor float32 `json:"low_freq_factor"` LowFrequencyFactor float32 `json:"low_freq_factor"`
HighFrequencyFactor float32 `json:"high_freq_factor"` HighFrequencyFactor float32 `json:"high_freq_factor"`
OriginalMaxPositionEmbeddings uint32 `json:"original_max_position_embeddings"` OriginalMaxPositionalEmbeddings uint32 `json:"original_max_positional_embeddings"`
factors ropeFactor factors ropeFactor
} `json:"rope_scaling"` } `json:"rope_scaling"`
@@ -42,8 +42,6 @@ type llamaModel struct {
LayerNormEpsilon float32 `json:"layer_norm_epsilon"` LayerNormEpsilon float32 `json:"layer_norm_epsilon"`
NormEpsilon float32 `json:"norm_epsilon"` NormEpsilon float32 `json:"norm_epsilon"`
HeadDim uint32 `json:"head_dim"` HeadDim uint32 `json:"head_dim"`
skipRepack bool
} }
var _ ModelConverter = (*llamaModel)(nil) var _ ModelConverter = (*llamaModel)(nil)
@@ -72,10 +70,6 @@ func (p *llamaModel) KV(t *Tokenizer) ggml.KV {
kv["llama.rope.dimension_count"] = p.HiddenSize / headCount kv["llama.rope.dimension_count"] = p.HiddenSize / headCount
} }
if p.HeadDim > 0 {
kv["llama.attention.head_dim"] = p.HeadDim
}
if p.RopeTheta > 0 { if p.RopeTheta > 0 {
kv["llama.rope.freq_base"] = p.RopeTheta kv["llama.rope.freq_base"] = p.RopeTheta
} }
@@ -90,7 +84,7 @@ func (p *llamaModel) KV(t *Tokenizer) ggml.KV {
factorLow := cmp.Or(p.RopeScaling.LowFrequencyFactor, 1.0) factorLow := cmp.Or(p.RopeScaling.LowFrequencyFactor, 1.0)
factorHigh := cmp.Or(p.RopeScaling.HighFrequencyFactor, 4.0) factorHigh := cmp.Or(p.RopeScaling.HighFrequencyFactor, 4.0)
original := cmp.Or(p.RopeScaling.OriginalMaxPositionEmbeddings, 8192) original := cmp.Or(p.RopeScaling.OriginalMaxPositionalEmbeddings, 8192)
lambdaLow := float32(original) / factorLow lambdaLow := float32(original) / factorLow
lambdaHigh := float32(original) / factorHigh lambdaHigh := float32(original) / factorHigh
@@ -126,11 +120,11 @@ func (p *llamaModel) KV(t *Tokenizer) ggml.KV {
return kv return kv
} }
func (p *llamaModel) Tensors(ts []Tensor) []*ggml.Tensor { func (p *llamaModel) Tensors(ts []Tensor) []ggml.Tensor {
var out []*ggml.Tensor var out []ggml.Tensor
if p.RopeScaling.factors != nil { if p.RopeScaling.factors != nil {
out = append(out, &ggml.Tensor{ out = append(out, ggml.Tensor{
Name: "rope_freqs.weight", Name: "rope_freqs.weight",
Kind: 0, Kind: 0,
Shape: []uint64{uint64(len(p.RopeScaling.factors))}, Shape: []uint64{uint64(len(p.RopeScaling.factors))},
@@ -139,14 +133,12 @@ func (p *llamaModel) Tensors(ts []Tensor) []*ggml.Tensor {
} }
for _, t := range ts { for _, t := range ts {
if strings.HasSuffix(t.Name(), "attn_q.weight") || strings.HasSuffix(t.Name(), "attn_k.weight") || if strings.HasSuffix(t.Name(), "attn_q.weight") ||
strings.HasSuffix(t.Name(), "attn_q_proj.weight") || strings.HasSuffix(t.Name(), "attn_k_proj.weight") { strings.HasSuffix(t.Name(), "attn_k.weight") {
if !p.skipRepack { t.SetRepacker(p.repack)
t.SetRepacker(p.repack)
}
} }
out = append(out, &ggml.Tensor{ out = append(out, ggml.Tensor{
Name: t.Name(), Name: t.Name(),
Kind: t.Kind(), Kind: t.Kind(),
Shape: t.Shape(), Shape: t.Shape(),
@@ -182,9 +174,9 @@ func (p *llamaModel) repack(name string, data []float32, shape []uint64) ([]floa
} }
var heads uint32 var heads uint32
if strings.HasSuffix(name, "attn_q.weight") || strings.HasSuffix(name, "attn_q_proj.weight") { if strings.HasSuffix(name, "attn_q.weight") {
heads = p.NumAttentionHeads heads = p.NumAttentionHeads
} else if strings.HasSuffix(name, "attn_k.weight") || strings.HasSuffix(name, "attn_k_proj.weight") { } else if strings.HasSuffix(name, "attn_k.weight") {
heads = cmp.Or(p.NumKeyValueHeads, p.NumAttentionHeads) heads = cmp.Or(p.NumKeyValueHeads, p.NumAttentionHeads)
} else { } else {
return nil, fmt.Errorf("unknown tensor for repack: %s", name) return nil, fmt.Errorf("unknown tensor for repack: %s", name)

View File

@@ -1,169 +0,0 @@
package convert
import (
"slices"
"strings"
"github.com/pdevine/tensor"
"github.com/pdevine/tensor/native"
"github.com/ollama/ollama/fs/ggml"
)
type llama4Model struct {
ModelParameters
TextModel struct {
llamaModel
NumExpertsPerToken uint32 `json:"num_experts_per_tok"`
NumLocalExperts uint32 `json:"num_local_experts"`
InterleaveMOELayerStep uint32 `json:"interleave_moe_layer_step"`
UseQKNorm bool `json:"use_qk_norm"`
IntermediateSizeMLP uint32 `json:"intermediate_size_mlp"`
AttentionChunkSize uint32 `json:"attention_chunk_size"`
} `json:"text_config"`
VisionModel struct {
NumHiddenLayers uint32 `json:"num_hidden_layers"`
HiddenSize uint32 `json:"hidden_size"`
IntermediateSize uint32 `json:"intermediate_size"`
NumAttentionHeads uint32 `json:"num_attention_heads"`
ImageSize uint32 `json:"image_size"`
PatchSize uint32 `json:"patch_size"`
RopeTheta float32 `json:"rope_theta"`
NormEpsilon float32 `json:"norm_eps"`
PixelShuffleRatio float32 `json:"pixel_shuffle_ratio"`
} `json:"vision_config"`
}
// KV implements ModelConverter.
func (p *llama4Model) KV(t *Tokenizer) ggml.KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "llama4"
for k, v := range p.TextModel.KV(t) {
if strings.HasPrefix(k, "llama.") {
kv[strings.ReplaceAll(k, "llama.", "llama4.")] = v
}
}
kv["llama4.feed_forward_length"] = p.TextModel.IntermediateSizeMLP
kv["llama4.expert_feed_forward_length"] = p.TextModel.IntermediateSize
kv["llama4.expert_count"] = p.TextModel.NumLocalExperts
kv["llama4.expert_used_count"] = p.TextModel.NumExpertsPerToken
kv["llama4.interleave_moe_layer_step"] = p.TextModel.InterleaveMOELayerStep
kv["llama4.use_qk_norm"] = p.TextModel.UseQKNorm
kv["llama4.attention.chunk_size"] = p.TextModel.AttentionChunkSize
kv["llama4.vision.block_count"] = p.VisionModel.NumHiddenLayers
kv["llama4.vision.embedding_length"] = p.VisionModel.HiddenSize
kv["llama4.vision.feed_forward_length"] = p.VisionModel.IntermediateSize
kv["llama4.vision.attention.head_count"] = p.VisionModel.NumAttentionHeads
kv["llama4.vision.image_size"] = p.VisionModel.ImageSize
kv["llama4.vision.patch_size"] = p.VisionModel.PatchSize
kv["llama4.vision.rope.freq_base"] = p.VisionModel.RopeTheta
kv["llama4.vision.layer_norm_epsilon"] = p.VisionModel.NormEpsilon
kv["llama4.vision.pixel_shuffle_ratio"] = p.VisionModel.PixelShuffleRatio
return kv
}
// Replacements implements ModelConverter.
func (p *llama4Model) Replacements() []string {
return append(
p.TextModel.Replacements(),
"language_model.", "",
"vision_model", "v",
"multi_modal_projector", "mm",
"feed_forward.down_proj", "ffn_down",
"feed_forward.up_proj", "ffn_up",
"feed_forward.gate_proj", "ffn_gate",
"feed_forward.", "ffn_",
"shared_expert.down_proj", "down_shexp",
"shared_expert.gate_proj", "gate_shexp",
"shared_expert.up_proj", "up_shexp",
"experts.down_proj", "down_exps.weight",
"experts.gate_up_proj", "gate_up_exps.weight",
"router", "gate_inp",
"patch_embedding.linear", "patch_embedding",
)
}
// Tensors implements ModelConverter.
func (p *llama4Model) Tensors(ts []Tensor) []*ggml.Tensor {
var out []*ggml.Tensor
var textTensors []Tensor
for _, t := range ts {
if strings.HasPrefix(t.Name(), "v.") || strings.HasPrefix(t.Name(), "mm.") {
out = append(out, &ggml.Tensor{
Name: t.Name(),
Kind: t.Kind(),
Shape: t.Shape(),
WriterTo: t,
})
} else if strings.Contains(t.Name(), "ffn_gate_up_exps") {
// gate and up projectors are fused
// dims[1], dims[2] must be swapped
// [experts, hidden_size, intermediate_size * 2] --> [experts, intermediate_size, hidden_size]
halfDim := int(t.Shape()[2]) / 2
newShape := slices.Clone(t.Shape())
newShape[1], newShape[2] = newShape[2]/2, newShape[1]
for i, name := range []string{"ffn_gate_exps", "ffn_up_exps"} {
// clone tensor since we need separate repackers
tt := t.Clone()
tt.SetRepacker(p.repack(nil, nil, tensor.S(i*halfDim, (i+1)*halfDim)))
out = append(out, &ggml.Tensor{
Name: strings.ReplaceAll(tt.Name(), "ffn_gate_up_exps", name),
Kind: tt.Kind(),
Shape: newShape,
WriterTo: tt,
})
}
} else if strings.Contains(t.Name(), "ffn_down_exps") {
// dims[1], dims[2] must be swapped
// [experts, intermediate_size, hidden_size] --> [experts, hidden_size, intermediate_size]
t.SetRepacker(p.repack())
newShape := slices.Clone(t.Shape())
newShape[1], newShape[2] = newShape[2], newShape[1]
out = append(out, &ggml.Tensor{
Name: t.Name(),
Kind: t.Kind(),
Shape: newShape,
WriterTo: t,
})
} else {
textTensors = append(textTensors, t)
}
}
p.TextModel.skipRepack = true
out = append(out, p.TextModel.Tensors(textTensors)...)
return out
}
func (p *llama4Model) repack(slice ...tensor.Slice) Repacker {
return func(name string, data []float32, shape []uint64) ([]float32, error) {
dims := make([]int, len(shape))
for i, dim := range shape {
dims[i] = int(dim)
}
var t tensor.Tensor = tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data))
t, err := t.Slice(slice...)
if err != nil {
return nil, err
}
if err := t.T(0, 2, 1); err != nil {
return nil, err
}
t = tensor.Materialize(t)
// flatten tensor so it can be return as a vector
if err := t.Reshape(t.Shape().TotalSize()); err != nil {
return nil, err
}
return native.VectorF32(t.(*tensor.Dense))
}
}

View File

@@ -29,8 +29,8 @@ func (p *llamaAdapter) KV(baseKV ggml.KV) ggml.KV {
return kv return kv
} }
func (p *llamaAdapter) Tensors(ts []Tensor) []*ggml.Tensor { func (p *llamaAdapter) Tensors(ts []Tensor) []ggml.Tensor {
var out []*ggml.Tensor var out []ggml.Tensor
for _, t := range ts { for _, t := range ts {
shape := t.Shape() shape := t.Shape()
if (strings.HasSuffix(t.Name(), "weight.lora_a") && shape[0] > shape[1]) || if (strings.HasSuffix(t.Name(), "weight.lora_a") && shape[0] > shape[1]) ||
@@ -41,7 +41,7 @@ func (p *llamaAdapter) Tensors(ts []Tensor) []*ggml.Tensor {
t.SetRepacker(p.repack) t.SetRepacker(p.repack)
} }
out = append(out, &ggml.Tensor{ out = append(out, ggml.Tensor{
Name: t.Name(), Name: t.Name(),
Kind: t.Kind(), Kind: t.Kind(),
Shape: shape, Shape: shape,

View File

@@ -89,8 +89,8 @@ func (p *mistral3Model) KV(t *Tokenizer) ggml.KV {
return kv return kv
} }
func (p *mistral3Model) Tensors(ts []Tensor) []*ggml.Tensor { func (p *mistral3Model) Tensors(ts []Tensor) []ggml.Tensor {
var out []*ggml.Tensor var out []ggml.Tensor
for _, t := range ts { for _, t := range ts {
if !strings.HasPrefix(t.Name(), "v.") { if !strings.HasPrefix(t.Name(), "v.") {
@@ -100,7 +100,7 @@ func (p *mistral3Model) Tensors(ts []Tensor) []*ggml.Tensor {
} }
} }
out = append(out, &ggml.Tensor{ out = append(out, ggml.Tensor{
Name: t.Name(), Name: t.Name(),
Kind: t.Kind(), Kind: t.Kind(),
Shape: t.Shape(), Shape: t.Shape(),

View File

@@ -2,6 +2,9 @@ package convert
import ( import (
"fmt" "fmt"
"io"
"slices"
"strings"
"github.com/ollama/ollama/fs/ggml" "github.com/ollama/ollama/fs/ggml"
) )
@@ -26,39 +29,66 @@ func (p *mixtralModel) KV(t *Tokenizer) ggml.KV {
return kv return kv
} }
func (p *mixtralModel) Tensors(ts []Tensor) []*ggml.Tensor { func (p *mixtralModel) Tensors(ts []Tensor) []ggml.Tensor {
merges := make([]merge, 0, p.NumHiddenLayers*6) oldnew := []string{
for i := range p.NumHiddenLayers { "model.layers", "blk",
merges = append(merges, merge{ "w1", "ffn_gate_exps",
fmt.Sprintf("blk.%d.*.w1.weight", i), "w2", "ffn_down_exps",
fmt.Sprintf("blk.%d.ffn_gate_exps.weight", i), "w3", "ffn_up_exps",
}, merge{ }
fmt.Sprintf("blk.%d.*.w1.bias", i),
fmt.Sprintf("blk.%d.ffn_gate_exps.bias", i), for i := range p.NumLocalExperts {
}, merge{ oldnew = append(oldnew, fmt.Sprintf(".block_sparse_moe.experts.%d.", i), ".")
fmt.Sprintf("blk.%d.*.w2.weight", i), }
fmt.Sprintf("blk.%d.ffn_up_exps.weight", i),
}, merge{ // group experts of the same layer (model.layers.%d) and type (w[123]) into a single tensor
fmt.Sprintf("blk.%d.*.w2.bias", i), namer := strings.NewReplacer(oldnew...)
fmt.Sprintf("blk.%d.ffn_up_exps.bias", i), experts := make(map[string]experts)
}, merge{
fmt.Sprintf("blk.%d.*.w3.weight", i), // merge experts into a single tensor while removing them from ts
fmt.Sprintf("blk.%d.ffn_down_exps.weight", i), ts = slices.DeleteFunc(ts, func(t Tensor) bool {
}, merge{ if !strings.Contains(t.Name(), ".block_sparse_moe.experts.") {
fmt.Sprintf("blk.%d.*.w3.bias", i), return false
fmt.Sprintf("blk.%d.ffn_down_exps.bias", i), }
name := namer.Replace(t.Name())
experts[name] = append(experts[name], t)
return true
})
var out []ggml.Tensor
for n, e := range experts {
// TODO(mxyng): sanity check experts
out = append(out, ggml.Tensor{
Name: n,
Kind: e[0].Kind(),
Shape: append([]uint64{uint64(len(e))}, e[0].Shape()...),
WriterTo: e,
}) })
} }
out, ts := mergeTensors(ts, merges...)
return append(out, p.llamaModel.Tensors(ts)...) return append(out, p.llamaModel.Tensors(ts)...)
} }
func (p *mixtralModel) Replacements() []string { func (p *mixtralModel) Replacements() []string {
return append( return append(
p.llamaModel.Replacements(), p.llamaModel.Replacements(),
"model.layers", "blk",
"block_sparse_moe.gate", "ffn_gate_inp", "block_sparse_moe.gate", "ffn_gate_inp",
"block_sparse_moe.experts.", ".",
) )
} }
type experts []Tensor
func (e experts) WriteTo(w io.Writer) (int64, error) {
// TODO(mxyng): experts _should_ be numerically sorted by expert but this should check
for _, t := range e {
// the canonical merged experts tensor stacks all experts along a new, 0 axis,
// e.g. `tensor.Stack(0, e[0], e[1:]...)`, which requires allocating temporary buffers
// this accomplishes the same thing by writing each expert tensor in sequence
if _, err := t.WriteTo(w); err != nil {
return 0, err
}
}
return 0, nil
}

View File

@@ -1,179 +0,0 @@
package convert
import (
"strings"
"github.com/ollama/ollama/fs/ggml"
"github.com/pdevine/tensor"
"github.com/pdevine/tensor/native"
)
type mllamaModel struct {
ModelParameters
TextModel struct {
llamaModel
CrossAttentionLayers []int32 `json:"cross_attention_layers"`
} `json:"text_config"`
VisionModel struct {
NumHiddenLayers uint32 `json:"num_hidden_layers"`
NumGlobalLayers uint32 `json:"num_global_layers"`
IntermediateLayersIndices []int32 `json:"intermediate_layers_indices"`
HiddenSize uint32 `json:"hidden_size"`
IntermediateSize uint32 `json:"intermediate_size"`
AttentionHeads uint32 `json:"attention_heads"`
ImageSize uint32 `json:"image_size"`
PatchSize uint32 `json:"patch_size"`
NumChannels uint32 `json:"num_channels"`
MaxNumTiles uint32 `json:"max_num_tiles"`
NormEpsilon float32 `json:"norm_eps"`
RopeTheta float32 `json:"rope.freq_base"`
} `json:"vision_config"`
}
func (m *mllamaModel) KV(t *Tokenizer) ggml.KV {
kv := m.ModelParameters.KV(t)
kv["general.architecture"] = "mllama"
for k, v := range m.TextModel.KV(t) {
if strings.HasPrefix(k, "llama.") {
kv[strings.ReplaceAll(k, "llama.", "mllama.")] = v
}
}
kv["mllama.attention.cross_attention_layers"] = m.TextModel.CrossAttentionLayers
kv["mllama.vision.block_count"] = m.VisionModel.NumHiddenLayers
kv["mllama.vision.global.block_count"] = m.VisionModel.NumGlobalLayers
kv["mllama.vision.intermediate_layers_indices"] = m.VisionModel.IntermediateLayersIndices
kv["mllama.vision.embedding_length"] = m.VisionModel.HiddenSize
kv["mllama.vision.feed_forward_length"] = m.VisionModel.IntermediateSize
kv["mllama.vision.attention.head_count"] = m.VisionModel.AttentionHeads
kv["mllama.vision.attention.layer_norm_epsilon"] = m.VisionModel.NormEpsilon
kv["mllama.vision.image_size"] = m.VisionModel.ImageSize
kv["mllama.vision.patch_size"] = m.VisionModel.PatchSize
kv["mllama.vision.max_num_tiles"] = m.VisionModel.MaxNumTiles
kv["mllama.vision.num_channels"] = m.VisionModel.NumChannels
return kv
}
func (m *mllamaModel) Replacements() []string {
return append(
m.TextModel.Replacements(),
"language_model.", "",
"gate_attn", "attn_gate",
"gate_ffn", "ffn_gate",
"cross_attn.", "cross_attn_",
"vision_model", "v",
"class_embedding", "class_embd",
"patch_embedding", "patch_embd",
"gated_positional_embedding.tile_embedding", "tile_position_embd",
"gated_positional_embedding.embedding", "position_embd.weight",
"gated_positional_embedding", "position_embd",
"embedding.weight", "weight",
"pre_tile_positional_embedding", "pre_tile_position_embd",
"post_tile_positional_embedding", "post_tile_position_embd",
"layernorm_pre", "pre_ln",
"layernorm_post", "post_ln",
"global_transformer.layers", "global.blk",
"transformer.layers", "blk",
"mlp.fc1", "ffn_up",
"mlp.fc2", "ffn_down",
"multi_modal_projector", "mm.0",
)
}
func (m *mllamaModel) Tensors(ts []Tensor) []*ggml.Tensor {
var out []*ggml.Tensor
var text []Tensor
for _, t := range ts {
if !strings.HasPrefix(t.Name(), "v.") && !strings.HasPrefix(t.Name(), "mm.") {
text = append(text, t)
} else if t.Name() == "v.position_embd.gate" {
for _, name := range []string{"v.position_embd.gate", "v.tile_position_embd.gate"} {
tt := t.Clone()
tt.SetRepacker(m.repack(name))
out = append(out, &ggml.Tensor{
Name: name,
Kind: t.Kind(),
Shape: t.Shape(),
WriterTo: tt,
})
}
} else {
if t.Name() == "v.pre_tile_position_embd.gate" || t.Name() == "v.post_tile_position_embd.gate" {
t.SetRepacker(m.repack(t.Name()))
} else if strings.HasSuffix(t.Name(), "attn_q.weight") || strings.HasSuffix(t.Name(), "attn_k.weight") {
t.SetRepacker(m.repack(t.Name()))
} else if strings.HasSuffix(t.Name(), "attn_gate") || strings.HasSuffix(t.Name(), "ffn_gate") {
t.SetRepacker(m.repack(t.Name()))
}
out = append(out, &ggml.Tensor{
Name: t.Name(),
Kind: t.Kind(),
Shape: t.Shape(),
WriterTo: t,
})
}
}
return append(out, m.TextModel.Tensors(text)...)
}
func (m *mllamaModel) repack(name string) Repacker {
return func(_ string, data []float32, shape []uint64) (_ []float32, err error) {
dims := make([]int, len(shape))
for i, dim := range shape {
dims[i] = int(dim)
}
var t tensor.Tensor = tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data))
if strings.HasSuffix(name, "attn_q.weight") || strings.HasSuffix(name, "attn_k.weight") {
heads := m.VisionModel.AttentionHeads
if err := t.Reshape(append([]int{int(heads), 2, dims[0] / int(heads) / 2}, dims[1:]...)...); err != nil {
return nil, err
}
if err := t.T(0, 2, 1, 3); err != nil {
return nil, err
}
if err := t.Reshape(dims...); err != nil {
return nil, err
}
if err := t.Transpose(); err != nil {
return nil, err
}
} else {
t, err = tensor.Tanh(t)
if err != nil {
return nil, err
}
if name == "v.position_embd.gate" {
t, err = tensor.Sub(float32(1), t)
if err != nil {
return nil, err
}
}
}
t = tensor.Materialize(t)
// flatten tensor so it can be return as a vector
if err := t.Reshape(t.Shape().TotalSize()); err != nil {
return nil, err
}
return native.VectorF32(t.(*tensor.Dense))
}
}

View File

@@ -68,19 +68,19 @@ func (p *phi3Model) KV(t *Tokenizer) ggml.KV {
return kv return kv
} }
func (p *phi3Model) Tensors(ts []Tensor) []*ggml.Tensor { func (p *phi3Model) Tensors(ts []Tensor) []ggml.Tensor {
var addRopeFactors sync.Once var addRopeFactors sync.Once
out := make([]*ggml.Tensor, 0, len(ts)+2) out := make([]ggml.Tensor, 0, len(ts)+2)
for _, t := range ts { for _, t := range ts {
if strings.HasPrefix(t.Name(), "blk.0.") { if strings.HasPrefix(t.Name(), "blk.0.") {
addRopeFactors.Do(func() { addRopeFactors.Do(func() {
out = append(out, &ggml.Tensor{ out = append(out, ggml.Tensor{
Name: "rope_factors_long.weight", Name: "rope_factors_long.weight",
Kind: 0, Kind: 0,
Shape: []uint64{uint64(len(p.RopeScaling.LongFactor))}, Shape: []uint64{uint64(len(p.RopeScaling.LongFactor))},
WriterTo: p.RopeScaling.LongFactor, WriterTo: p.RopeScaling.LongFactor,
}, &ggml.Tensor{ }, ggml.Tensor{
Name: "rope_factors_short.weight", Name: "rope_factors_short.weight",
Kind: 0, Kind: 0,
Shape: []uint64{uint64(len(p.RopeScaling.ShortFactor))}, Shape: []uint64{uint64(len(p.RopeScaling.ShortFactor))},
@@ -89,7 +89,7 @@ func (p *phi3Model) Tensors(ts []Tensor) []*ggml.Tensor {
}) })
} }
out = append(out, &ggml.Tensor{ out = append(out, ggml.Tensor{
Name: t.Name(), Name: t.Name(),
Kind: t.Kind(), Kind: t.Kind(),
Shape: t.Shape(), Shape: t.Shape(),
@@ -118,5 +118,6 @@ func (p *phi3Model) Replacements() []string {
type ropeFactor []float32 type ropeFactor []float32
func (r ropeFactor) WriteTo(w io.Writer) (int64, error) { func (r ropeFactor) WriteTo(w io.Writer) (int64, error) {
return 0, binary.Write(w, binary.LittleEndian, r) err := binary.Write(w, binary.LittleEndian, r)
return 0, err
} }

View File

@@ -15,7 +15,6 @@ type qwen2Model struct {
Type string `json:"type"` Type string `json:"type"`
Factor ropeFactor `json:"factor"` Factor ropeFactor `json:"factor"`
OriginalMaxPositionEmbeddings uint32 `json:"original_max_position_embeddings"` OriginalMaxPositionEmbeddings uint32 `json:"original_max_position_embeddings"`
MropeSection []int32 `json:"mrope_section"`
} `json:"rope_scaling"` } `json:"rope_scaling"`
RMSNormEPS float32 `json:"rms_norm_eps"` RMSNormEPS float32 `json:"rms_norm_eps"`
} }
@@ -40,18 +39,16 @@ func (q *qwen2Model) KV(t *Tokenizer) ggml.KV {
case "yarn": case "yarn":
kv["qwen2.rope.scaling.type"] = q.RopeScaling.Type kv["qwen2.rope.scaling.type"] = q.RopeScaling.Type
kv["qwen2.rope.scaling.factor"] = q.RopeScaling.Factor kv["qwen2.rope.scaling.factor"] = q.RopeScaling.Factor
case "mrope", "default":
kv["qwen2.rope.mrope_section"] = q.RopeScaling.MropeSection
default: default:
panic("unknown rope scaling type") panic("unknown rope scaling type")
} }
return kv return kv
} }
func (q *qwen2Model) Tensors(ts []Tensor) []*ggml.Tensor { func (q *qwen2Model) Tensors(ts []Tensor) []ggml.Tensor {
var out []*ggml.Tensor var out []ggml.Tensor
for _, t := range ts { for _, t := range ts {
out = append(out, &ggml.Tensor{ out = append(out, ggml.Tensor{
Name: t.Name(), Name: t.Name(),
Kind: t.Kind(), Kind: t.Kind(),
Shape: t.Shape(), Shape: t.Shape(),

View File

@@ -1,102 +0,0 @@
package convert
import (
"cmp"
"slices"
"strings"
"github.com/ollama/ollama/fs/ggml"
)
type qwen25VLModel struct {
qwen2Model
VisionModel struct {
Depth uint32 `json:"depth"`
HiddenSize uint32 `json:"hidden_size"`
NumHeads uint32 `json:"num_heads"`
InChannels uint32 `json:"in_chans"`
PatchSize uint32 `json:"patch_size"`
SpatialMergeSize uint32 `json:"spatial_merge_size"`
SpatialPatchSize uint32 `json:"spatial_patch_size"`
WindowSize uint32 `json:"window_size"`
RMSNormEps float32 `json:"layer_norm_epsilon"`
RopeTheta float32 `json:"rope_theta"`
FullAttentionBlocks []int32 `json:"fullatt_block_indexes"`
TemporalPatchSize uint32 `json:"temporal_patch_size"`
} `json:"vision_config"`
}
var _ ModelConverter = (*qwen25VLModel)(nil)
func (q *qwen25VLModel) KV(t *Tokenizer) ggml.KV {
kv := q.ModelParameters.KV(t)
kv["general.architecture"] = "qwen25vl"
for k, v := range q.qwen2Model.KV(t) {
if strings.HasPrefix(k, "qwen2.") {
kv[strings.Replace(k, "qwen2.", "qwen25vl.", 1)] = v
}
}
if q.VisionModel.FullAttentionBlocks == nil {
kv["qwen25vl.vision.fullatt_block_indexes"] = []int32{7, 15, 23, 31}
}
kv["qwen25vl.vision.block_count"] = cmp.Or(q.VisionModel.Depth, 32)
kv["qwen25vl.vision.embedding_length"] = q.VisionModel.HiddenSize
kv["qwen25vl.vision.attention.head_count"] = cmp.Or(q.VisionModel.NumHeads, 16)
kv["qwen25vl.vision.num_channels"] = q.VisionModel.InChannels
kv["qwen25vl.vision.patch_size"] = cmp.Or(q.VisionModel.PatchSize, 14)
kv["qwen25vl.vision.spatial_merge_size"] = cmp.Or(q.VisionModel.SpatialMergeSize, 2)
kv["qwen25vl.vision.spatial_patch_size"] = q.VisionModel.SpatialPatchSize
kv["qwen25vl.vision.window_size"] = cmp.Or(q.VisionModel.WindowSize, 112)
kv["qwen25vl.vision.attention.layer_norm_epsilon"] = cmp.Or(q.VisionModel.RMSNormEps, 1e-6)
kv["qwen25vl.vision.rope.freq_base"] = cmp.Or(q.VisionModel.RopeTheta, 1e4)
kv["qwen25vl.vision.fullatt_block_indexes"] = q.VisionModel.FullAttentionBlocks
kv["qwen25vl.vision.temporal_patch_size"] = cmp.Or(q.VisionModel.TemporalPatchSize, 2)
return kv
}
func (q *qwen25VLModel) Tensors(ts []Tensor) []*ggml.Tensor {
var out []*ggml.Tensor
for _, t := range ts {
if strings.Contains(t.Name(), "patch_embed.proj") {
for t := range splitDim(t, 2,
split{Replacer: strings.NewReplacer("patch_embed.proj", "patch_embd_0")},
split{Replacer: strings.NewReplacer("patch_embed.proj", "patch_embd_1")},
) {
t.Shape = slices.DeleteFunc(t.Shape, func(i uint64) bool { return i == 1 })
out = append(out, t)
}
} else if strings.Contains(t.Name(), "attn.qkv") {
out = append(out, slices.Collect(splitDim(t, 0,
split{Replacer: strings.NewReplacer("attn.qkv", "attn_q")},
split{Replacer: strings.NewReplacer("attn.qkv", "attn_k")},
split{Replacer: strings.NewReplacer("attn.qkv", "attn_v")},
))...)
} else {
out = append(out, &ggml.Tensor{
Name: t.Name(),
Kind: t.Kind(),
Shape: t.Shape(),
WriterTo: t,
})
}
}
return out
}
func (p *qwen25VLModel) Replacements() []string {
return append(
p.qwen2Model.Replacements(),
"visual", "v",
"blocks", "blk",
"attn.proj", "attn_out",
"norm1", "ln1",
"norm2", "ln2",
)
}

View File

@@ -11,6 +11,7 @@ import (
"io" "io"
"io/fs" "io/fs"
"log/slog" "log/slog"
"math"
"os" "os"
"path/filepath" "path/filepath"
"slices" "slices"
@@ -47,7 +48,7 @@ func convertFull(t *testing.T, fsys fs.FS) (*os.File, ggml.KV, ggml.Tensors) {
} }
t.Cleanup(func() { r.Close() }) t.Cleanup(func() { r.Close() })
m, err := ggml.Decode(r, -1) m, _, err := ggml.Decode(r, math.MaxInt)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -130,7 +131,6 @@ func TestConvertModel(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
defer expectFile.Close()
var expect map[string]string var expect map[string]string
if err := json.NewDecoder(expectFile).Decode(&expect); err != nil { if err := json.NewDecoder(expectFile).Decode(&expect); err != nil {
@@ -332,7 +332,7 @@ func TestConvertAdapter(t *testing.T) {
} }
defer r.Close() defer r.Close()
m, err := ggml.Decode(r, -1) m, _, err := ggml.Decode(r, math.MaxInt)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

58
convert/fs.go Normal file
View File

@@ -0,0 +1,58 @@
package convert
import (
"archive/zip"
"errors"
"io"
"io/fs"
"os"
"path/filepath"
)
type ZipReader struct {
r *zip.Reader
p string
// limit is the maximum size of a file that can be read directly
// from the zip archive. Files larger than this size will be extracted
limit int64
}
func NewZipReader(r *zip.Reader, p string, limit int64) fs.FS {
return &ZipReader{r, p, limit}
}
func (z *ZipReader) Open(name string) (fs.File, error) {
r, err := z.r.Open(name)
if err != nil {
return nil, err
}
defer r.Close()
if fi, err := r.Stat(); err != nil {
return nil, err
} else if fi.Size() < z.limit {
return r, nil
}
if !filepath.IsLocal(name) {
return nil, zip.ErrInsecurePath
}
n := filepath.Join(z.p, name)
if _, err := os.Stat(n); errors.Is(err, os.ErrNotExist) {
w, err := os.Create(n)
if err != nil {
return nil, err
}
defer w.Close()
if _, err := io.Copy(w, r); err != nil {
return nil, err
}
} else if err != nil {
return nil, err
}
return os.Open(n)
}

View File

@@ -11,15 +11,14 @@ type Tensor interface {
Name() string Name() string
Shape() []uint64 Shape() []uint64
Kind() uint32 Kind() uint32
SetRepacker(Repacker) SetRepacker(repacker)
WriteTo(io.Writer) (int64, error) WriteTo(io.Writer) (int64, error)
Clone() Tensor
} }
type tensorBase struct { type tensorBase struct {
name string name string
shape []uint64 shape []uint64
repacker Repacker repacker
} }
func (t tensorBase) Name() string { func (t tensorBase) Name() string {
@@ -37,11 +36,7 @@ const (
func (t tensorBase) Kind() uint32 { func (t tensorBase) Kind() uint32 {
if strings.HasSuffix(t.name, ".ffn_gate_inp.weight") || if strings.HasSuffix(t.name, ".ffn_gate_inp.weight") ||
t.name == "token_types.weight" || t.name == "token_types.weight" {
t.name == "v.positional_embedding_vlm" ||
t.name == "v.tile_position_embd.weight" ||
t.name == "v.pre_tile_position_embd.weight" ||
t.name == "v.post_tile_position_embd.weight" {
// these tensors are always F32 // these tensors are always F32
return 0 return 0
} }
@@ -56,11 +51,11 @@ func (t tensorBase) Kind() uint32 {
} }
} }
func (t *tensorBase) SetRepacker(fn Repacker) { func (t *tensorBase) SetRepacker(fn repacker) {
t.repacker = fn t.repacker = fn
} }
type Repacker func(string, []float32, []uint64) ([]float32, error) type repacker func(string, []float32, []uint64) ([]float32, error)
func parseTensors(fsys fs.FS, replacer *strings.Replacer) ([]Tensor, error) { func parseTensors(fsys fs.FS, replacer *strings.Replacer) ([]Tensor, error) {
patterns := []struct { patterns := []struct {

View File

@@ -94,21 +94,6 @@ type safetensor struct {
*tensorBase *tensorBase
} }
func (st safetensor) Clone() Tensor {
return &safetensor{
fs: st.fs,
path: st.path,
dtype: st.dtype,
offset: st.offset,
size: st.size,
tensorBase: &tensorBase{
name: st.name,
repacker: st.repacker,
shape: slices.Clone(st.shape),
},
}
}
func (st safetensor) WriteTo(w io.Writer) (int64, error) { func (st safetensor) WriteTo(w io.Writer) (int64, error) {
f, err := st.fs.Open(st.path) f, err := st.fs.Open(st.path)
if err != nil { if err != nil {

View File

@@ -43,17 +43,6 @@ type torch struct {
*tensorBase *tensorBase
} }
func (t torch) Clone() Tensor {
return torch{
storage: t.storage,
tensorBase: &tensorBase{
name: t.name,
shape: t.shape,
repacker: t.repacker,
},
}
}
func (pt torch) WriteTo(w io.Writer) (int64, error) { func (pt torch) WriteTo(w io.Writer) (int64, error) {
return 0, nil return 0, nil
} }

View File

@@ -1,129 +0,0 @@
package convert
import (
"cmp"
"io"
"iter"
"path"
"slices"
"strings"
"github.com/pdevine/tensor"
"github.com/pdevine/tensor/native"
"github.com/ollama/ollama/fs/ggml"
)
type split struct {
*strings.Replacer
dim int
// fn is an optional function to apply to the tensor after slicing
fn func(tensor.Tensor) (tensor.Tensor, error)
}
// splitDim splits a tensor along a specified dimension into multiple tensors. The dimension
// is split evenly based on the number of replacers provided unless a specific count is given.
func splitDim(t Tensor, dim int, splits ...split) iter.Seq[*ggml.Tensor] {
return func(yield func(*ggml.Tensor) bool) {
var offset int
for _, split := range splits {
t := t.Clone()
shape := slices.Clone(t.Shape())
shape[dim] = cmp.Or(uint64(split.dim), shape[dim]/uint64(len(splits)))
slice := slices.Repeat([]tensor.Slice{nil}, len(shape))
slice[dim] = tensor.S(offset, offset+int(shape[dim]))
offset += int(shape[dim])
t.SetRepacker(func(_ string, data []float32, shape []uint64) ([]float32, error) {
dims := make([]int, len(shape))
for i := range shape {
dims[i] = int(shape[i])
}
var tt tensor.Tensor = tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data))
tt, err := tt.Slice(slice...)
if err != nil {
return nil, err
}
tt = tensor.Materialize(tt)
if split.fn != nil {
tt, err = split.fn(tt)
if err != nil {
return nil, err
}
}
// flatten tensor so it can be written as a vector
if err := tt.Reshape(tt.Shape().TotalSize()); err != nil {
return nil, err
}
return native.VectorF32(tt.(*tensor.Dense))
})
if !yield(&ggml.Tensor{
Name: split.Replace(t.Name()),
Kind: t.Kind(),
Shape: shape,
WriterTo: t,
}) {
break
}
}
}
}
type merge struct {
pattern, name string
}
// mergeTensors merges tensors that match a given pattern into a single tensor.
func mergeTensors(unmatched []Tensor, merges ...merge) (out []*ggml.Tensor, _ []Tensor) {
var matched []Tensor
for i := range merges {
matched, unmatched = slicesSplitFunc(unmatched, func(t Tensor) bool {
matched, _ := path.Match(merges[i].pattern, t.Name())
return matched
})
if len(matched) > 0 {
out = append(out, &ggml.Tensor{
Name: merges[i].name,
Kind: matched[0].Kind(),
Shape: append([]uint64{uint64(len(matched))}, matched[0].Shape()...),
WriterTo: mergeGroup(matched),
})
}
}
return out, unmatched
}
// slicesSplitFunc splits a slice into two slices based on a predicate function.
func slicesSplitFunc[S ~[]E, E comparable](s S, fn func(e E) bool) (matched, unmatched S) {
for _, e := range s {
if fn(e) {
matched = append(matched, e)
} else {
unmatched = append(unmatched, e)
}
}
return matched, unmatched
}
type mergeGroup []Tensor
func (g mergeGroup) WriteTo(w io.Writer) (int64, error) {
for _, t := range g {
if _, err := t.WriteTo(w); err != nil {
return 0, err
}
}
return 0, nil
}

View File

@@ -1,402 +0,0 @@
package convert
import (
"bytes"
"encoding/binary"
"io"
"iter"
"slices"
"strings"
"testing"
"github.com/google/go-cmp/cmp"
"github.com/ollama/ollama/fs/ggml"
"github.com/pdevine/tensor"
)
type fakeTensor struct {
name string
shape []uint64
data []float32
repacker Repacker
}
func (f fakeTensor) Name() string {
return f.name
}
func (f fakeTensor) Shape() []uint64 {
return f.shape
}
func (f fakeTensor) Kind() uint32 {
return 0
}
func (f *fakeTensor) SetRepacker(fn Repacker) {
f.repacker = fn
}
func (f fakeTensor) Clone() Tensor {
return &fakeTensor{
name: f.name,
shape: slices.Clone(f.shape),
data: slices.Clone(f.data),
repacker: f.repacker,
}
}
func (f fakeTensor) WriteTo(w io.Writer) (n int64, err error) {
data := f.data
if f.repacker != nil {
data, err = f.repacker(f.name, data, f.shape)
if err != nil {
return 0, err
}
}
if err := binary.Write(w, binary.LittleEndian, data); err != nil {
return 0, err
}
return int64(len(data) * 4), nil
}
func mul(shape []uint64) int {
n := 1
for _, dim := range shape {
n *= int(dim)
}
return n
}
func TestSplitDim(t *testing.T) {
r := fakeTensor{
name: "a.b",
shape: []uint64{3, 4},
data: []float32{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11},
}
t.Run("no split", func(t *testing.T) {
for tt := range splitDim(&r, 0, split{Replacer: strings.NewReplacer("a", "x")}) {
if tt.Name != "x.b" {
t.Fatalf("expected name 'x', got '%s'", tt.Name)
}
if !slices.Equal(tt.Shape, []uint64{3, 4}) {
t.Fatalf("expected shape [3, 4], got %v", tt.Shape)
}
var b bytes.Buffer
if _, err := tt.WriteTo(&b); err != nil {
t.Fatal(err)
}
f32s := make([]float32, mul(tt.Shape))
if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
t.Fatal(err)
}
if !slices.Equal(f32s, []float32{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}) {
t.Fatalf("expected data [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], got %v", f32s)
}
}
})
t.Run("even split", func(t *testing.T) {
next, stop := iter.Pull(splitDim(&r, 1,
split{Replacer: strings.NewReplacer("a", "x")},
split{Replacer: strings.NewReplacer("b", "y")},
))
defer stop()
{
tt, ok := next()
if !ok {
t.Fatal("expected at least one split")
}
if tt.Name != "x.b" {
t.Fatal("expected name 'x.b', got", tt.Name)
}
if !slices.Equal(tt.Shape, []uint64{3, 2}) {
t.Fatal("expected shape [3, 2], got", tt.Shape)
}
var b bytes.Buffer
if _, err := tt.WriteTo(&b); err != nil {
t.Fatal(err)
}
f32s := make([]float32, mul(tt.Shape))
if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
t.Fatal(err)
}
if !slices.Equal(f32s, []float32{0, 1, 4, 5, 8, 9}) {
t.Fatal("expected data [0, 1, 4, 5, 8, 9], got", f32s)
}
}
{
tt, ok := next()
if !ok {
t.Fatal("expected at least one split")
}
if tt.Name != "a.y" {
t.Fatal("expected name 'a.y', got", tt.Name)
}
if !slices.Equal(tt.Shape, []uint64{3, 2}) {
t.Fatal("expected shape [3, 2], got", tt.Shape)
}
var b bytes.Buffer
if _, err := tt.WriteTo(&b); err != nil {
t.Fatal(err)
}
f32s := make([]float32, mul(tt.Shape))
if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
t.Fatal(err)
}
if !slices.Equal(f32s, []float32{2, 3, 6, 7, 10, 11}) {
t.Fatal("expected data [2, 3, 6, 7, 10, 11], got", f32s)
}
}
})
t.Run("uneven split", func(t *testing.T) {
next, stop := iter.Pull(splitDim(&r, 0,
split{Replacer: strings.NewReplacer("a", "x"), dim: 2},
split{Replacer: strings.NewReplacer("b", "y"), dim: 1},
))
defer stop()
{
tt, ok := next()
if !ok {
t.Fatal("expected at least one split")
}
if tt.Name != "x.b" {
t.Fatal("expected name 'x.b', got", tt.Name)
}
if !slices.Equal(tt.Shape, []uint64{2, 4}) {
t.Fatal("expected shape [2, 4], got", tt.Shape)
}
var b bytes.Buffer
if _, err := tt.WriteTo(&b); err != nil {
t.Fatal(err)
}
f32s := make([]float32, mul(tt.Shape))
if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
t.Fatal(err)
}
if !slices.Equal(f32s, []float32{0, 1, 2, 3, 4, 5, 6, 7}) {
t.Fatal("expected data [0, 1, 2, 3, 4, 5, 6, 7], got", f32s)
}
}
{
tt, ok := next()
if !ok {
t.Fatal("expected at least one split")
}
if tt.Name != "a.y" {
t.Fatal("expected name 'a.y', got", tt.Name)
}
if !slices.Equal(tt.Shape, []uint64{1, 4}) {
t.Fatal("expected shape [1, 4], got", tt.Shape)
}
var b bytes.Buffer
if _, err := tt.WriteTo(&b); err != nil {
t.Fatal(err)
}
f32s := make([]float32, mul(tt.Shape))
if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
t.Fatal(err)
}
if !slices.Equal(f32s, []float32{8, 9, 10, 11}) {
t.Fatal("expected data [8, 9, 10, 11], got", f32s)
}
}
})
t.Run("split with transpose", func(t *testing.T) {
next, stop := iter.Pull(splitDim(&r, 1,
split{Replacer: strings.NewReplacer("a", "x")},
split{Replacer: strings.NewReplacer("b", "y"), fn: func(tt tensor.Tensor) (tensor.Tensor, error) {
return tensor.Transpose(tt, 1, 0)
}},
))
defer stop()
{
tt, ok := next()
if !ok {
t.Fatal("expected at least one split")
}
if tt.Name != "x.b" {
t.Fatal("expected name 'x.b', got", tt.Name)
}
if !slices.Equal(tt.Shape, []uint64{3, 2}) {
t.Fatal("expected shape [3, 2], got", tt.Shape)
}
var b bytes.Buffer
if _, err := tt.WriteTo(&b); err != nil {
t.Fatal(err)
}
f32s := make([]float32, mul(tt.Shape))
if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
t.Fatal(err)
}
if !slices.Equal(f32s, []float32{0, 1, 4, 5, 8, 9}) {
t.Fatal("expected data [0, 1, 4, 5, 8, 9], got", f32s)
}
}
{
tt, ok := next()
if !ok {
t.Fatal("expected at least one split")
}
if tt.Name != "a.y" {
t.Fatal("expected name 'a.y', got", tt.Name)
}
if !slices.Equal(tt.Shape, []uint64{3, 2}) {
t.Fatal("expected shape [3, 2], got", tt.Shape)
}
var b bytes.Buffer
if _, err := tt.WriteTo(&b); err != nil {
t.Fatal(err)
}
f32s := make([]float32, mul(tt.Shape))
if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
t.Fatal(err)
}
if !slices.Equal(f32s, []float32{2, 6, 10, 3, 7, 11}) {
t.Fatal("expected data [2, 6, 10, 3, 7, 11], got", f32s)
}
}
})
}
func TestMerge(t *testing.T) {
unmatched := []Tensor{
&fakeTensor{
name: "a.0.b",
shape: []uint64{5, 2},
data: []float32{10, 11, 12, 13, 14, 15, 16, 17, 18, 19},
},
&fakeTensor{
name: "a.1.b",
shape: []uint64{5, 2},
data: []float32{20, 21, 22, 23, 24, 25, 26, 27, 28, 29},
},
&fakeTensor{
name: "c.0.d",
shape: []uint64{5, 2},
data: []float32{30, 31, 32, 33, 34, 35, 36, 37, 38, 39},
},
&fakeTensor{
name: "c.1.d",
shape: []uint64{5, 2},
data: []float32{40, 41, 42, 43, 44, 45, 46, 47, 48, 49},
},
&fakeTensor{
name: "e.0.f",
shape: []uint64{5, 2},
data: []float32{50, 51, 52, 53, 54, 55, 56, 57, 58, 59},
},
}
checkMatched := func(t *testing.T, n int, matched []*ggml.Tensor) {
for i := range n {
got := matched[i]
if diff := cmp.Diff([]uint64{2, 5, 2}, got.Shape); diff != "" {
t.Errorf("unexpected (-want +got):\n%s", diff)
}
var b bytes.Buffer
if _, err := got.WriteTo(&b); err != nil {
t.Fatal(err)
}
f32s := make([]float32, 20)
if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
t.Fatal(err)
}
offset := 10 + (i * 20)
want := make([]float32, 20)
for j := range 20 {
want[j] = float32(offset + j)
}
if diff := cmp.Diff(want, f32s); diff != "" {
t.Errorf("unexpected data (-want +got):\n%s", diff)
}
}
}
t.Run("single merge", func(t *testing.T) {
matched, unmatched := mergeTensors(unmatched, merge{"a.*.b", "a.b"})
if len(unmatched) != 3 {
t.Error("expected 3 remaining tensors, got", len(unmatched))
}
if len(matched) != 1 {
t.Error("expected 1 merged tensor, got", len(matched))
}
checkMatched(t, 1, matched)
})
t.Run("multiple merges", func(t *testing.T) {
matched, unmatched := mergeTensors(unmatched, merge{"a.*.b", "a.b"}, merge{"c.*.d", "c.d"})
if len(unmatched) != 1 {
t.Error("expected 1 remaining tensors, got", len(unmatched))
}
if len(matched) != 2 {
t.Error("expected 2 merged tensor, got", len(matched))
}
checkMatched(t, 2, matched)
})
t.Run("no match", func(t *testing.T) {
matched, unmatched := mergeTensors(unmatched, merge{"x.*.y", "x.y"})
if len(unmatched) != 5 {
t.Error("expected 5 remaining tensors, got", len(unmatched))
}
if len(matched) != 0 {
t.Error("expected no merged tensors, got", len(matched))
}
})
}

View File

@@ -110,7 +110,6 @@ func parseTokenizer(fsys fs.FS, specialTokenTypes []string) (*Tokenizer, error)
} }
if f, err := fsys.Open("tokenizer_config.json"); errors.Is(err, os.ErrNotExist) { if f, err := fsys.Open("tokenizer_config.json"); errors.Is(err, os.ErrNotExist) {
// noop
} else if err != nil { } else if err != nil {
return nil, err return nil, err
} else { } else {
@@ -172,34 +171,6 @@ func parseTokenizer(fsys fs.FS, specialTokenTypes []string) (*Tokenizer, error)
} }
} }
if f, err := fsys.Open("generation_config.json"); errors.Is(err, os.ErrNotExist) {
} else if err != nil {
return nil, err
} else {
defer f.Close()
var p map[string]json.RawMessage
if err := json.NewDecoder(f).Decode(&p); err != nil {
return nil, err
}
for _, st := range specialTokenTypes {
if bts, ok := p[fmt.Sprintf("%s_token_id", st)]; ok {
var ids []int32
if err := json.Unmarshal(bts, &ids); err != nil {
// value is not a list so the existing ID is used
continue
}
if i := slices.IndexFunc(t.SpecialVocabulary, func(sv *SpecialVocabulary) bool {
return sv.Type == st
}); i >= 0 {
t.SpecialVocabulary[i].IDs = ids
}
}
}
}
return t, nil return t, nil
} }
@@ -309,9 +280,6 @@ type SpecialVocabulary struct {
ID int ID int
Content string Content string
AddToken bool AddToken bool
// IDs is populated by generation_config.json
IDs []int32
} }
func (sv SpecialVocabulary) Key() string { func (sv SpecialVocabulary) Key() string {

View File

@@ -247,67 +247,6 @@ func TestParseTokenizer(t *testing.T) {
Pre: "default", Pre: "default",
}, },
}, },
{
name: "generation config eos token ids",
fsys: createTokenizerFS(t, t.TempDir(), map[string]io.Reader{
"tokenizer.json": strings.NewReader(`{
"added_tokens": [
{
"id": 0,
"content": "<bos>",
"special": true
},
{
"id": 1,
"content": "<eos>",
"special": true
},
{
"id": 2,
"content": "<eot>",
"special": true
},
{
"id": 3,
"content": "<eom>",
"special": true
}
],
"model": {
"vocab": {
"<bos>": 0,
"<eos>": 1,
"<eot>": 2,
"<eom>": 3
}
}
}`),
"tokenizer_config.json": strings.NewReader(`{
"add_bos_token": true,
"add_eos_token": false,
"bos_token": "<bos>",
"eos_token": "<eos>"
}`),
"generation_config.json": strings.NewReader(`{
"bos_token_id": 0,
"eos_token_id": [1, 2, 3]
}`),
}),
specialTokenTypes: []string{"pad", "eos", "bos", "unk"},
want: &Tokenizer{
Vocabulary: &Vocabulary{
Model: "gpt2",
Tokens: []string{"<bos>", "<eos>", "<eot>", "<eom>"},
Scores: []float32{0, 1, 2, 3},
Types: []int32{3, 3, 3, 3},
},
SpecialVocabulary: []*SpecialVocabulary{
{Type: "eos", Content: "<eos>", ID: 1, IDs: []int32{1, 2, 3}, AddToken: false},
{Type: "bos", Content: "<bos>", ID: 0, AddToken: true},
},
Pre: "default",
},
},
} }
for _, tt := range cases { for _, tt := range cases {

View File

@@ -3,7 +3,6 @@
package discover package discover
import ( import (
"fmt"
"log/slog" "log/slog"
"os" "os"
"regexp" "regexp"
@@ -56,13 +55,10 @@ func cudaVariant(gpuInfo CudaGPUInfo) string {
} }
} }
} }
return "sbsa"
} }
// driver 12.0 has problems with the cuda v12 library, so run v11 on those older drivers // driver 12.0 has problems with the cuda v12 library, so run v11 on those older drivers
if gpuInfo.DriverMajor < 12 || (gpuInfo.DriverMajor == 12 && gpuInfo.DriverMinor == 0) { if gpuInfo.DriverMajor < 12 || (gpuInfo.DriverMajor == 12 && gpuInfo.DriverMinor == 0) {
// The detected driver is older than Feb 2023
slog.Warn("old CUDA driver detected - please upgrade to a newer driver", "version", fmt.Sprintf("%d.%d", gpuInfo.DriverMajor, gpuInfo.DriverMinor))
return "v11" return "v11"
} }
return "v12" return "v12"

View File

@@ -670,7 +670,7 @@ func loadOneapiMgmt(oneapiLibPaths []string) (int, *C.oneapi_handle_t, string, e
} }
func getVerboseState() C.uint16_t { func getVerboseState() C.uint16_t {
if envconfig.LogLevel() < slog.LevelInfo { if envconfig.Debug() {
return C.uint16_t(1) return C.uint16_t(1)
} }
return C.uint16_t(0) return C.uint16_t(0)

View File

@@ -27,14 +27,12 @@
#endif #endif
#ifndef LOG
#define LOG(verbose, ...) \ #define LOG(verbose, ...) \
do { \ do { \
if (verbose) { \ if (verbose) { \
fprintf(stderr, __VA_ARGS__); \ fprintf(stderr, __VA_ARGS__); \
} \ } \
} while (0) } while (0)
#endif
#ifdef __cplusplus #ifdef __cplusplus
extern "C" { extern "C" {

View File

@@ -1,7 +1,6 @@
#ifndef __APPLE__ // TODO - maybe consider nvidia support on intel macs? #ifndef __APPLE__ // TODO - maybe consider nvidia support on intel macs?
#include <string.h> #include <string.h>
#include <inttypes.h>
#include "gpu_info_cudart.h" #include "gpu_info_cudart.h"
void cudart_init(char *cudart_lib_path, cudart_init_resp_t *resp) { void cudart_init(char *cudart_lib_path, cudart_init_resp_t *resp) {
@@ -59,7 +58,7 @@ void cudart_init(char *cudart_lib_path, cudart_init_resp_t *resp) {
LOG(resp->ch.verbose, "cudaSetDevice err: %d\n", ret); LOG(resp->ch.verbose, "cudaSetDevice err: %d\n", ret);
UNLOAD_LIBRARY(resp->ch.handle); UNLOAD_LIBRARY(resp->ch.handle);
resp->ch.handle = NULL; resp->ch.handle = NULL;
if (ret == CUDART_ERROR_INSUFFICIENT_DRIVER) { if (ret == CUDA_ERROR_INSUFFICIENT_DRIVER) {
resp->err = strdup("your nvidia driver is too old or missing. If you have a CUDA GPU please upgrade to run ollama"); resp->err = strdup("your nvidia driver is too old or missing. If you have a CUDA GPU please upgrade to run ollama");
return; return;
} }
@@ -169,9 +168,9 @@ void cudart_bootstrap(cudart_handle_t h, int i, mem_info_t *resp) {
resp->free = memInfo.free; resp->free = memInfo.free;
resp->used = memInfo.used; resp->used = memInfo.used;
LOG(h.verbose, "[%s] CUDA totalMem %" PRId64 "\n", resp->gpu_id, resp->total); LOG(h.verbose, "[%s] CUDA totalMem %lu\n", resp->gpu_id, resp->total);
LOG(h.verbose, "[%s] CUDA freeMem %" PRId64 "\n", resp->gpu_id, resp->free); LOG(h.verbose, "[%s] CUDA freeMem %lu\n", resp->gpu_id, resp->free);
LOG(h.verbose, "[%s] CUDA usedMem %" PRId64 "\n", resp->gpu_id, resp->used); LOG(h.verbose, "[%s] CUDA usedMem %lu\n", resp->gpu_id, resp->used);
LOG(h.verbose, "[%s] Compute Capability %d.%d\n", resp->gpu_id, resp->major, resp->minor); LOG(h.verbose, "[%s] Compute Capability %d.%d\n", resp->gpu_id, resp->major, resp->minor);
} }
@@ -181,4 +180,4 @@ void cudart_release(cudart_handle_t h) {
h.handle = NULL; h.handle = NULL;
} }
#endif // __APPLE__ #endif // __APPLE__

View File

@@ -1,7 +1,6 @@
#ifndef __APPLE__ // TODO - maybe consider nvidia support on intel macs? #ifndef __APPLE__ // TODO - maybe consider nvidia support on intel macs?
#include <string.h> #include <string.h>
#include <inttypes.h>
#include "gpu_info_nvcuda.h" #include "gpu_info_nvcuda.h"
void nvcuda_init(char *nvcuda_lib_path, nvcuda_init_resp_t *resp) { void nvcuda_init(char *nvcuda_lib_path, nvcuda_init_resp_t *resp) {
@@ -194,8 +193,8 @@ void nvcuda_bootstrap(nvcuda_handle_t h, int i, mem_info_t *resp) {
resp->total = memInfo.total; resp->total = memInfo.total;
resp->free = memInfo.free; resp->free = memInfo.free;
LOG(h.verbose, "[%s] CUDA totalMem %" PRId64 "mb\n", resp->gpu_id, resp->total / 1024 / 1024); LOG(h.verbose, "[%s] CUDA totalMem %lu mb\n", resp->gpu_id, resp->total / 1024 / 1024);
LOG(h.verbose, "[%s] CUDA freeMem %" PRId64 "mb\n", resp->gpu_id, resp->free / 1024 / 1024); LOG(h.verbose, "[%s] CUDA freeMem %lu mb\n", resp->gpu_id, resp->free / 1024 / 1024);
LOG(h.verbose, "[%s] Compute Capability %d.%d\n", resp->gpu_id, resp->major, resp->minor); LOG(h.verbose, "[%s] Compute Capability %d.%d\n", resp->gpu_id, resp->major, resp->minor);
@@ -248,4 +247,4 @@ void nvcuda_release(nvcuda_handle_t h) {
h.handle = NULL; h.handle = NULL;
} }
#endif // __APPLE__ #endif // __APPLE__

View File

@@ -12,7 +12,7 @@ import (
// '../lib/ollama' on Linux and the executable's directory on macOS // '../lib/ollama' on Linux and the executable's directory on macOS
// note: distribution builds, additional GPU-specific libraries are // note: distribution builds, additional GPU-specific libraries are
// found in subdirectories of the returned path, such as // found in subdirectories of the returned path, such as
// 'cuda_v12', 'rocm', etc. // 'cuda_v11', 'cuda_v12', 'rocm', etc.
var LibOllamaPath string = func() string { var LibOllamaPath string = func() string {
exe, err := os.Executable() exe, err := os.Executable()
if err != nil { if err != nil {

View File

@@ -19,7 +19,7 @@
### Model names ### Model names
Model names follow a `model:tag` format, where `model` can have an optional namespace such as `example/model`. Some examples are `orca-mini:3b-q8_0` and `llama3:70b`. The tag is optional and, if not provided, will default to `latest`. The tag is used to identify a specific version. Model names follow a `model:tag` format, where `model` can have an optional namespace such as `example/model`. Some examples are `orca-mini:3b-q4_1` and `llama3:70b`. The tag is optional and, if not provided, will default to `latest`. The tag is used to identify a specific version.
### Durations ### Durations
@@ -43,7 +43,6 @@ Generate a response for a given prompt with a provided model. This is a streamin
- `prompt`: the prompt to generate a response for - `prompt`: the prompt to generate a response for
- `suffix`: the text after the model response - `suffix`: the text after the model response
- `images`: (optional) a list of base64-encoded images (for multimodal models such as `llava`) - `images`: (optional) a list of base64-encoded images (for multimodal models such as `llava`)
- `think`: (for thinking models) should the model think before responding?
Advanced parameters (optional): Advanced parameters (optional):
@@ -174,7 +173,7 @@ curl http://localhost:11434/api/generate -d '{
##### Response ##### Response
```json5 ```json
{ {
"model": "codellama:code", "model": "codellama:code",
"created_at": "2024-07-22T20:47:51.147561Z", "created_at": "2024-07-22T20:47:51.147561Z",
@@ -395,6 +394,9 @@ curl http://localhost:11434/api/generate -d '{
"repeat_penalty": 1.2, "repeat_penalty": 1.2,
"presence_penalty": 1.5, "presence_penalty": 1.5,
"frequency_penalty": 1.0, "frequency_penalty": 1.0,
"mirostat": 1,
"mirostat_tau": 0.8,
"mirostat_eta": 0.6,
"penalize_newline": true, "penalize_newline": true,
"stop": ["\n", "user:"], "stop": ["\n", "user:"],
"numa": false, "numa": false,
@@ -402,7 +404,10 @@ curl http://localhost:11434/api/generate -d '{
"num_batch": 2, "num_batch": 2,
"num_gpu": 1, "num_gpu": 1,
"main_gpu": 0, "main_gpu": 0,
"low_vram": false,
"vocab_only": false,
"use_mmap": true, "use_mmap": true,
"use_mlock": false,
"num_thread": 8 "num_thread": 8
} }
}' }'
@@ -491,13 +496,11 @@ Generate the next message in a chat with a provided model. This is a streaming e
- `model`: (required) the [model name](#model-names) - `model`: (required) the [model name](#model-names)
- `messages`: the messages of the chat, this can be used to keep a chat memory - `messages`: the messages of the chat, this can be used to keep a chat memory
- `tools`: list of tools in JSON for the model to use if supported - `tools`: list of tools in JSON for the model to use if supported
- `think`: (for thinking models) should the model think before responding?
The `message` object has the following fields: The `message` object has the following fields:
- `role`: the role of the message, either `system`, `user`, `assistant`, or `tool` - `role`: the role of the message, either `system`, `user`, `assistant`, or `tool`
- `content`: the content of the message - `content`: the content of the message
- `thinking`: (for thinking models) the model's thinking process
- `images` (optional): a list of images to include in the message (for multimodal models such as `llava`) - `images` (optional): a list of images to include in the message (for multimodal models such as `llava`)
- `tool_calls` (optional): a list of tools in JSON that the model wants to use - `tool_calls` (optional): a list of tools in JSON that the model wants to use
@@ -955,8 +958,19 @@ If you are creating a model from a safetensors directory or from a GGUF file, yo
| Type | Recommended | | Type | Recommended |
| --- | :-: | | --- | :-: |
| q2_K | |
| q3_K_L | |
| q3_K_M | |
| q3_K_S | |
| q4_0 | |
| q4_1 | |
| q4_K_M | * | | q4_K_M | * |
| q4_K_S | | | q4_K_S | |
| q5_0 | |
| q5_1 | |
| q5_K_M | |
| q5_K_S | |
| q6_K | |
| q8_0 | * | | q8_0 | * |
### Examples ### Examples
@@ -1001,8 +1015,8 @@ Quantize a non-quantized model.
```shell ```shell
curl http://localhost:11434/api/create -d '{ curl http://localhost:11434/api/create -d '{
"model": "llama3.2:quantized", "model": "llama3.1:quantized",
"from": "llama3.2:3b-instruct-fp16", "from": "llama3.1:8b-instruct-fp16",
"quantize": "q4_K_M" "quantize": "q4_K_M"
}' }'
``` ```
@@ -1012,14 +1026,12 @@ curl http://localhost:11434/api/create -d '{
A stream of JSON objects is returned: A stream of JSON objects is returned:
```json ```json
{"status":"quantizing F16 model to Q4_K_M","digest":"0","total":6433687776,"completed":12302} {"status":"quantizing F16 model to Q4_K_M"}
{"status":"quantizing F16 model to Q4_K_M","digest":"0","total":6433687776,"completed":6433687552} {"status":"creating new layer sha256:667b0c1932bc6ffc593ed1d03f895bf2dc8dc6df21db3042284a6f4416b06a29"}
{"status":"verifying conversion"} {"status":"using existing layer sha256:11ce4ee3e170f6adebac9a991c22e22ab3f8530e154ee669954c4bc73061c258"}
{"status":"creating new layer sha256:fb7f4f211b89c6c4928ff4ddb73db9f9c0cfca3e000c3e40d6cf27ddc6ca72eb"} {"status":"using existing layer sha256:0ba8f0e314b4264dfd19df045cde9d4c394a52474bf92ed6a3de22a4ca31a177"}
{"status":"using existing layer sha256:966de95ca8a62200913e3f8bfbf84c8494536f1b94b49166851e76644e966396"}
{"status":"using existing layer sha256:fcc5a6bec9daf9b561a68827b67ab6088e1dba9d1fa2a50d7bbcc8384e0a265d"}
{"status":"using existing layer sha256:a70ff7e570d97baaf4e62ac6e6ad9975e04caa6d900d3742d37698494479e0cd"}
{"status":"using existing layer sha256:56bb8bd477a519ffa694fc449c2413c6f0e1d3b1c88fa7e3c9d88d3ae49d4dcb"} {"status":"using existing layer sha256:56bb8bd477a519ffa694fc449c2413c6f0e1d3b1c88fa7e3c9d88d3ae49d4dcb"}
{"status":"creating new layer sha256:455f34728c9b5dd3376378bfb809ee166c145b0b4c1f1a6feca069055066ef9a"}
{"status":"writing manifest"} {"status":"writing manifest"}
{"status":"success"} {"status":"success"}
``` ```
@@ -1157,37 +1169,29 @@ A single JSON object will be returned.
{ {
"models": [ "models": [
{ {
"name": "deepseek-r1:latest", "name": "codellama:13b",
"model": "deepseek-r1:latest", "modified_at": "2023-11-04T14:56:49.277302595-07:00",
"modified_at": "2025-05-10T08:06:48.639712648-07:00", "size": 7365960935,
"size": 4683075271, "digest": "9f438cb9cd581fc025612d27f7c1a6669ff83a8bb0ed86c94fcf4c5440555697",
"digest": "0a8c266910232fd3291e71e5ba1e058cc5af9d411192cf88b6d30e92b6e73163",
"details": { "details": {
"parent_model": "",
"format": "gguf", "format": "gguf",
"family": "qwen2", "family": "llama",
"families": [ "families": null,
"qwen2" "parameter_size": "13B",
], "quantization_level": "Q4_0"
"parameter_size": "7.6B",
"quantization_level": "Q4_K_M"
} }
}, },
{ {
"name": "llama3.2:latest", "name": "llama3:latest",
"model": "llama3.2:latest", "modified_at": "2023-12-07T09:32:18.757212583-08:00",
"modified_at": "2025-05-04T17:37:44.706015396-07:00", "size": 3825819519,
"size": 2019393189, "digest": "fe938a131f40e6f6d40083c9f0f430a515233eb2edaa6d72eb85c50d64f2300e",
"digest": "a80c4f17acd55265feec403c7aef86be0c25983ab279d83f3bcd3abbcb5b8b72",
"details": { "details": {
"parent_model": "",
"format": "gguf", "format": "gguf",
"family": "llama", "family": "llama",
"families": [ "families": null,
"llama" "parameter_size": "7B",
], "quantization_level": "Q4_0"
"parameter_size": "3.2B",
"quantization_level": "Q4_K_M"
} }
} }
] ]
@@ -1219,7 +1223,7 @@ curl http://localhost:11434/api/show -d '{
#### Response #### Response
```json5 ```json
{ {
"modelfile": "# Modelfile generated by \"ollama show\"\n# To build a new Modelfile based on this one, replace the FROM line with:\n# FROM llava:latest\n\nFROM /Users/matt/.ollama/models/blobs/sha256:200765e1283640ffbd013184bf496e261032fa75b99498a9613be4e94d63ad52\nTEMPLATE \"\"\"{{ .System }}\nUSER: {{ .Prompt }}\nASSISTANT: \"\"\"\nPARAMETER num_ctx 4096\nPARAMETER stop \"\u003c/s\u003e\"\nPARAMETER stop \"USER:\"\nPARAMETER stop \"ASSISTANT:\"", "modelfile": "# Modelfile generated by \"ollama show\"\n# To build a new Modelfile based on this one, replace the FROM line with:\n# FROM llava:latest\n\nFROM /Users/matt/.ollama/models/blobs/sha256:200765e1283640ffbd013184bf496e261032fa75b99498a9613be4e94d63ad52\nTEMPLATE \"\"\"{{ .System }}\nUSER: {{ .Prompt }}\nASSISTANT: \"\"\"\nPARAMETER num_ctx 4096\nPARAMETER stop \"\u003c/s\u003e\"\nPARAMETER stop \"USER:\"\nPARAMETER stop \"ASSISTANT:\"",
"parameters": "num_keep 24\nstop \"<|start_header_id|>\"\nstop \"<|end_header_id|>\"\nstop \"<|eot_id|>\"", "parameters": "num_keep 24\nstop \"<|start_header_id|>\"\nstop \"<|end_header_id|>\"\nstop \"<|eot_id|>\"",

59
docs/benchmark.md Normal file
View File

@@ -0,0 +1,59 @@
# Benchmark
Go benchmark tests that measure end-to-end performance of a running Ollama server. Run these tests to evaluate model inference performance on your hardware and measure the impact of code changes.
## When to use
Run these benchmarks when:
- Making changes to the model inference engine
- Modifying model loading/unloading logic
- Changing prompt processing or token generation code
- Implementing a new model architecture
- Testing performance across different hardware setups
## Prerequisites
- Ollama server running locally with `ollama serve` on `127.0.0.1:11434`
## Usage and Examples
>[!NOTE]
>All commands must be run from the root directory of the Ollama project.
Basic syntax:
```bash
go test -bench=. ./benchmark/... -m $MODEL_NAME
```
Required flags:
- `-bench=.`: Run all benchmarks
- `-m`: Model name to benchmark
Optional flags:
- `-count N`: Number of times to run the benchmark (useful for statistical analysis)
- `-timeout T`: Maximum time for the benchmark to run (e.g. "10m" for 10 minutes)
Common usage patterns:
Single benchmark run with a model specified:
```bash
go test -bench=. ./benchmark/... -m llama3.3
```
## Output metrics
The benchmark reports several key metrics:
- `gen_tok/s`: Generated tokens per second
- `prompt_tok/s`: Prompt processing tokens per second
- `ttft_ms`: Time to first token in milliseconds
- `load_ms`: Model load time in milliseconds
- `gen_tokens`: Total tokens generated
- `prompt_tokens`: Total prompt tokens processed
Each benchmark runs two scenarios:
- Cold start: Model is loaded from disk for each test
- Warm start: Model is pre-loaded in memory
Three prompt lengths are tested for each scenario:
- Short prompt (100 tokens)
- Medium prompt (500 tokens)
- Long prompt (1000 tokens)

View File

@@ -118,7 +118,7 @@ To run tests, use `go test`:
go test ./... go test ./...
``` ```
> NOTE: In rare cirumstances, you may need to change a package using the new > NOTE: In rare cirumstances, you may nedd to change a package using the new
> "synctest" package in go1.24. > "synctest" package in go1.24.
> >
> If you do not have the "synctest" package enabled, you will not see build or > If you do not have the "synctest" package enabled, you will not see build or

View File

@@ -20,7 +20,7 @@ Please refer to the [GPU docs](./gpu.md).
## How can I specify the context window size? ## How can I specify the context window size?
By default, Ollama uses a context window size of 4096 tokens. By default, Ollama uses a context window size of 2048 tokens.
This can be overridden with the `OLLAMA_CONTEXT_LENGTH` environment variable. For example, to set the default context window to 8K, use: This can be overridden with the `OLLAMA_CONTEXT_LENGTH` environment variable. For example, to set the default context window to 8K, use:

View File

@@ -1,6 +1,6 @@
# GPU # GPU
## Nvidia ## Nvidia
Ollama supports Nvidia GPUs with compute capability 5.0+ and driver version 531 and newer. Ollama supports Nvidia GPUs with compute capability 5.0+.
Check your compute compatibility to see if your card is supported: Check your compute compatibility to see if your card is supported:
[https://developer.nvidia.com/cuda-gpus](https://developer.nvidia.com/cuda-gpus) [https://developer.nvidia.com/cuda-gpus](https://developer.nvidia.com/cuda-gpus)

View File

@@ -132,12 +132,22 @@ success
### Supported Quantizations ### Supported Quantizations
- `q4_0`
- `q4_1`
- `q5_0`
- `q5_1`
- `q8_0` - `q8_0`
#### K-means Quantizations #### K-means Quantizations
- `q3_K_S`
- `q3_K_M`
- `q3_K_L`
- `q4_K_S` - `q4_K_S`
- `q4_K_M` - `q4_K_M`
- `q5_K_S`
- `q5_K_M`
- `q6_K`
## Sharing your model on ollama.com ## Sharing your model on ollama.com

View File

@@ -112,8 +112,8 @@ sudo systemctl status ollama
> While AMD has contributed the `amdgpu` driver upstream to the official linux > While AMD has contributed the `amdgpu` driver upstream to the official linux
> kernel source, the version is older and may not support all ROCm features. We > kernel source, the version is older and may not support all ROCm features. We
> recommend you install the latest driver from > recommend you install the latest driver from
> [AMD](https://www.amd.com/en/support/download/linux-drivers.html) for best support > https://www.amd.com/en/support/linux-drivers for best support of your Radeon
> of your Radeon GPU. > GPU.
## Customizing ## Customizing

View File

@@ -150,6 +150,9 @@ PARAMETER <parameter> <parametervalue>
| Parameter | Description | Value Type | Example Usage | | Parameter | Description | Value Type | Example Usage |
| -------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ---------- | -------------------- | | -------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ---------- | -------------------- |
| mirostat | Enable Mirostat sampling for controlling perplexity. (default: 0, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0) | int | mirostat 0 |
| mirostat_eta | Influences how quickly the algorithm responds to feedback from the generated text. A lower learning rate will result in slower adjustments, while a higher learning rate will make the algorithm more responsive. (Default: 0.1) | float | mirostat_eta 0.1 |
| mirostat_tau | Controls the balance between coherence and diversity of the output. A lower value will result in more focused and coherent text. (Default: 5.0) | float | mirostat_tau 5.0 |
| num_ctx | Sets the size of the context window used to generate the next token. (Default: 2048) | int | num_ctx 4096 | | num_ctx | Sets the size of the context window used to generate the next token. (Default: 2048) | int | num_ctx 4096 |
| repeat_last_n | Sets how far back for the model to look back to prevent repetition. (Default: 64, 0 = disabled, -1 = num_ctx) | int | repeat_last_n 64 | | repeat_last_n | Sets how far back for the model to look back to prevent repetition. (Default: 64, 0 = disabled, -1 = num_ctx) | int | repeat_last_n 64 |
| repeat_penalty | Sets how strongly to penalize repetitions. A higher value (e.g., 1.5) will penalize repetitions more strongly, while a lower value (e.g., 0.9) will be more lenient. (Default: 1.1) | float | repeat_penalty 1.1 | | repeat_penalty | Sets how strongly to penalize repetitions. A higher value (e.g., 1.5) will penalize repetitions more strongly, while a lower value (e.g., 0.9) will be more lenient. (Default: 1.1) | float | repeat_penalty 1.1 |

View File

@@ -12,7 +12,7 @@ A basic Go template consists of three main parts:
Here's an example of a simple chat template: Here's an example of a simple chat template:
```go ```gotmpl
{{- range .Messages }} {{- range .Messages }}
{{ .Role }}: {{ .Content }} {{ .Role }}: {{ .Content }}
{{- end }} {{- end }}
@@ -162,6 +162,6 @@ CodeLlama [7B](https://ollama.com/library/codellama:7b-code) and [13B](https://o
Codestral [22B](https://ollama.com/library/codestral:22b) supports fill-in-middle. Codestral [22B](https://ollama.com/library/codestral:22b) supports fill-in-middle.
```go ```gotmpl
[SUFFIX]{{ .Suffix }}[PREFIX] {{ .Prompt }} [SUFFIX]{{ .Suffix }}[PREFIX] {{ .Prompt }}
``` ```

View File

@@ -43,7 +43,7 @@ Ollama includes multiple LLM libraries compiled for different GPUs and CPU vecto
In the server log, you will see a message that looks something like this (varies from release to release): In the server log, you will see a message that looks something like this (varies from release to release):
``` ```
Dynamic LLM libraries [rocm_v6 cpu cpu_avx cpu_avx2 cuda_v12 rocm_v5] Dynamic LLM libraries [rocm_v6 cpu cpu_avx cpu_avx2 cuda_v11 rocm_v5]
``` ```
**Experimental LLM Library Override** **Experimental LLM Library Override**

View File

@@ -149,22 +149,9 @@ func Bool(k string) func() bool {
} }
} }
// LogLevel returns the log level for the application.
// Values are 0 or false INFO (Default), 1 or true DEBUG, 2 TRACE
func LogLevel() slog.Level {
level := slog.LevelInfo
if s := Var("OLLAMA_DEBUG"); s != "" {
if b, _ := strconv.ParseBool(s); b {
level = slog.LevelDebug
} else if i, _ := strconv.ParseInt(s, 10, 64); i != 0 {
level = slog.Level(i * -4)
}
}
return level
}
var ( var (
// Debug enabled additional debug information.
Debug = Bool("OLLAMA_DEBUG")
// FlashAttention enables the experimental flash attention feature. // FlashAttention enables the experimental flash attention feature.
FlashAttention = Bool("OLLAMA_FLASH_ATTENTION") FlashAttention = Bool("OLLAMA_FLASH_ATTENTION")
// KvCacheType is the quantization type for the K/V cache. // KvCacheType is the quantization type for the K/V cache.
@@ -182,9 +169,7 @@ var (
// Enable the new Ollama engine // Enable the new Ollama engine
NewEngine = Bool("OLLAMA_NEW_ENGINE") NewEngine = Bool("OLLAMA_NEW_ENGINE")
// ContextLength sets the default context length // ContextLength sets the default context length
ContextLength = Uint("OLLAMA_CONTEXT_LENGTH", 4096) ContextLength = Uint("OLLAMA_CONTEXT_LENGTH", 2048)
// Auth enables authentication between the Ollama client and server
UseAuth = Bool("OLLAMA_AUTH")
) )
func String(s string) func() string { func String(s string) func() string {
@@ -224,6 +209,8 @@ var (
MaxRunners = Uint("OLLAMA_MAX_LOADED_MODELS", 0) MaxRunners = Uint("OLLAMA_MAX_LOADED_MODELS", 0)
// MaxQueue sets the maximum number of queued requests. MaxQueue can be configured via the OLLAMA_MAX_QUEUE environment variable. // MaxQueue sets the maximum number of queued requests. MaxQueue can be configured via the OLLAMA_MAX_QUEUE environment variable.
MaxQueue = Uint("OLLAMA_MAX_QUEUE", 512) MaxQueue = Uint("OLLAMA_MAX_QUEUE", 512)
// MaxVRAM sets a maximum VRAM override in bytes. MaxVRAM can be configured via the OLLAMA_MAX_VRAM environment variable.
MaxVRAM = Uint("OLLAMA_MAX_VRAM", 0)
) )
func Uint64(key string, defaultValue uint64) func() uint64 { func Uint64(key string, defaultValue uint64) func() uint64 {
@@ -251,7 +238,7 @@ type EnvVar struct {
func AsMap() map[string]EnvVar { func AsMap() map[string]EnvVar {
ret := map[string]EnvVar{ ret := map[string]EnvVar{
"OLLAMA_DEBUG": {"OLLAMA_DEBUG", LogLevel(), "Show additional debug information (e.g. OLLAMA_DEBUG=1)"}, "OLLAMA_DEBUG": {"OLLAMA_DEBUG", Debug(), "Show additional debug information (e.g. OLLAMA_DEBUG=1)"},
"OLLAMA_FLASH_ATTENTION": {"OLLAMA_FLASH_ATTENTION", FlashAttention(), "Enabled flash attention"}, "OLLAMA_FLASH_ATTENTION": {"OLLAMA_FLASH_ATTENTION", FlashAttention(), "Enabled flash attention"},
"OLLAMA_KV_CACHE_TYPE": {"OLLAMA_KV_CACHE_TYPE", KvCacheType(), "Quantization type for the K/V cache (default: f16)"}, "OLLAMA_KV_CACHE_TYPE": {"OLLAMA_KV_CACHE_TYPE", KvCacheType(), "Quantization type for the K/V cache (default: f16)"},
"OLLAMA_GPU_OVERHEAD": {"OLLAMA_GPU_OVERHEAD", GpuOverhead(), "Reserve a portion of VRAM per GPU (bytes)"}, "OLLAMA_GPU_OVERHEAD": {"OLLAMA_GPU_OVERHEAD", GpuOverhead(), "Reserve a portion of VRAM per GPU (bytes)"},
@@ -268,7 +255,7 @@ func AsMap() map[string]EnvVar {
"OLLAMA_ORIGINS": {"OLLAMA_ORIGINS", AllowedOrigins(), "A comma separated list of allowed origins"}, "OLLAMA_ORIGINS": {"OLLAMA_ORIGINS", AllowedOrigins(), "A comma separated list of allowed origins"},
"OLLAMA_SCHED_SPREAD": {"OLLAMA_SCHED_SPREAD", SchedSpread(), "Always schedule model across all GPUs"}, "OLLAMA_SCHED_SPREAD": {"OLLAMA_SCHED_SPREAD", SchedSpread(), "Always schedule model across all GPUs"},
"OLLAMA_MULTIUSER_CACHE": {"OLLAMA_MULTIUSER_CACHE", MultiUserCache(), "Optimize prompt caching for multi-user scenarios"}, "OLLAMA_MULTIUSER_CACHE": {"OLLAMA_MULTIUSER_CACHE", MultiUserCache(), "Optimize prompt caching for multi-user scenarios"},
"OLLAMA_CONTEXT_LENGTH": {"OLLAMA_CONTEXT_LENGTH", ContextLength(), "Context length to use unless otherwise specified (default: 4096)"}, "OLLAMA_CONTEXT_LENGTH": {"OLLAMA_CONTEXT_LENGTH", ContextLength(), "Context length to use unless otherwise specified (default: 2048)"},
"OLLAMA_NEW_ENGINE": {"OLLAMA_NEW_ENGINE", NewEngine(), "Enable the new Ollama engine"}, "OLLAMA_NEW_ENGINE": {"OLLAMA_NEW_ENGINE", NewEngine(), "Enable the new Ollama engine"},
// Informational // Informational

View File

@@ -1,13 +1,11 @@
package envconfig package envconfig
import ( import (
"log/slog"
"math" "math"
"testing" "testing"
"time" "time"
"github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp"
"github.com/ollama/ollama/logutil"
) )
func TestHost(t *testing.T) { func TestHost(t *testing.T) {
@@ -281,8 +279,8 @@ func TestVar(t *testing.T) {
func TestContextLength(t *testing.T) { func TestContextLength(t *testing.T) {
cases := map[string]uint{ cases := map[string]uint{
"": 4096, "": 2048,
"2048": 2048, "4096": 4096,
} }
for k, v := range cases { for k, v := range cases {
@@ -294,34 +292,3 @@ func TestContextLength(t *testing.T) {
}) })
} }
} }
func TestLogLevel(t *testing.T) {
cases := map[string]slog.Level{
// Default to INFO
"": slog.LevelInfo,
"false": slog.LevelInfo,
"f": slog.LevelInfo,
"0": slog.LevelInfo,
// True values enable Debug
"true": slog.LevelDebug,
"t": slog.LevelDebug,
// Positive values increase verbosity
"1": slog.LevelDebug,
"2": logutil.LevelTrace,
// Negative values decrease verbosity
"-1": slog.LevelWarn,
"-2": slog.LevelError,
}
for k, v := range cases {
t.Run(k, func(t *testing.T) {
t.Setenv("OLLAMA_DEBUG", k)
if i := LogLevel(); i != v {
t.Errorf("%s: expected %d, got %d", k, v, i)
}
})
}
}

View File

@@ -8,7 +8,6 @@ type Config interface {
Bool(string, ...bool) bool Bool(string, ...bool) bool
Strings(string, ...[]string) []string Strings(string, ...[]string) []string
Ints(string, ...[]int32) []int32 Uints(string, ...[]uint32) []uint32
Floats(string, ...[]float32) []float32 Floats(string, ...[]float32) []float32
Bools(string, ...[]bool) []bool
} }

View File

@@ -15,7 +15,6 @@ import (
type GGML struct { type GGML struct {
container container
model model
Length int64
} }
type model interface { type model interface {
@@ -34,16 +33,15 @@ func (kv KV) Kind() string {
} }
func (kv KV) ParameterCount() uint64 { func (kv KV) ParameterCount() uint64 {
val, _ := keyValue(kv, "general.parameter_count", uint64(0)) return keyValue[uint64](kv, "general.parameter_count")
return val
} }
func (kv KV) FileType() FileType { func (kv KV) FileType() fileType {
if t := kv.Uint("general.file_type"); t > 0 { if t := kv.Uint("general.file_type"); t > 0 {
return FileType(t) return fileType(t)
} }
return FileTypeUnknown return fileTypeUnknown
} }
func (kv KV) BlockCount() uint64 { func (kv KV) BlockCount() uint64 {
@@ -54,27 +52,16 @@ func (kv KV) EmbeddingLength() uint64 {
return uint64(kv.Uint("embedding_length")) return uint64(kv.Uint("embedding_length"))
} }
func (kv KV) HeadCountMax() uint64 { func (kv KV) HeadCount() uint64 {
// TODO(drifkin): using the max value can cause an overestimation. In the return uint64(kv.Uint("attention.head_count"))
// future if array values become more popular, we can adapt the more invasive
// <https://github.com/ollama/ollama/pull/10225>
return uint64(kv.UintOrMaxArrayValue("attention.head_count", 1))
} }
func (kv KV) HeadCountMin() uint64 { func (kv KV) HeadCountKV() uint64 {
return uint64(kv.UintOrMinArrayValue("attention.head_count", 1)) return uint64(kv.Uint("attention.head_count_kv", 1))
} }
func (kv KV) HeadCountKVMax() uint64 { func (kv KV) EmbeddingHeadCount() uint64 {
return uint64(kv.UintOrMaxArrayValue("attention.head_count_kv", 1)) if heads := kv.HeadCount(); heads > 0 {
}
func (kv KV) HeadCountKVMin() uint64 {
return uint64(kv.UintOrMinArrayValue("attention.head_count_kv", 1))
}
func (kv KV) EmbeddingHeadCountMax() uint64 {
if heads := kv.HeadCountMin(); heads > 0 {
return kv.EmbeddingLength() / heads return kv.EmbeddingLength() / heads
} }
@@ -82,11 +69,15 @@ func (kv KV) EmbeddingHeadCountMax() uint64 {
} }
func (kv KV) EmbeddingHeadCountK() uint64 { func (kv KV) EmbeddingHeadCountK() uint64 {
return uint64(kv.Uint("attention.key_length", uint32(kv.EmbeddingHeadCountMax()))) return uint64(kv.Uint("attention.key_length", uint32(kv.EmbeddingHeadCount())))
} }
func (kv KV) EmbeddingHeadCountV() uint64 { func (kv KV) EmbeddingHeadCountV() uint64 {
return uint64(kv.Uint("attention.value_length", uint32(kv.EmbeddingHeadCountMax()))) return uint64(kv.Uint("attention.value_length", uint32(kv.EmbeddingHeadCount())))
}
func (kv KV) GQA() uint64 {
return kv.HeadCount() / kv.HeadCountKV()
} }
func (kv KV) ContextLength() uint64 { func (kv KV) ContextLength() uint64 {
@@ -98,113 +89,68 @@ func (kv KV) ChatTemplate() string {
} }
func (kv KV) String(key string, defaultValue ...string) string { func (kv KV) String(key string, defaultValue ...string) string {
val, _ := keyValue(kv, key, append(defaultValue, "")...) return keyValue(kv, key, append(defaultValue, "")...)
return val
} }
func (kv KV) Uint(key string, defaultValue ...uint32) uint32 { func (kv KV) Uint(key string, defaultValue ...uint32) uint32 {
val, _ := keyValue(kv, key, append(defaultValue, 0)...) return keyValue(kv, key, append(defaultValue, 0)...)
return val
} }
func (kv KV) Float(key string, defaultValue ...float32) float32 { func (kv KV) Float(key string, defaultValue ...float32) float32 {
val, _ := keyValue(kv, key, append(defaultValue, 0)...) return keyValue(kv, key, append(defaultValue, 0)...)
return val
} }
func (kv KV) Bool(key string, defaultValue ...bool) bool { func (kv KV) Bool(key string, defaultValue ...bool) bool {
val, _ := keyValue(kv, key, append(defaultValue, false)...) return keyValue(kv, key, append(defaultValue, false)...)
return val
}
func (kv KV) UintOrMaxArrayValue(key string, defaultValue uint32) uint32 {
_, max := kv.UintOrArrayValue(key, defaultValue)
return max
}
func (kv KV) UintOrMinArrayValue(key string, defaultValue uint32) uint32 {
min, _ := kv.UintOrArrayValue(key, defaultValue)
return min
}
func (kv KV) UintOrArrayValue(key string, defaultValue uint32) (uint32, uint32) {
if u32, ok := keyValue(kv, key, uint32(0)); ok {
return u32, u32
} else if u32s, ok := keyValue(kv, key, &array[uint32]{}); ok {
min := slices.Min(u32s.values)
max := slices.Max(u32s.values)
return min, max
} else if i32s, ok := keyValue(kv, key, &array[int32]{}); ok {
min := slices.Min(i32s.values)
max := slices.Max(i32s.values)
if min < 0 || max < 0 {
slog.Warn("array values are unexpectedly negative", "key", key, "min", min, "max", max)
}
return uint32(min), uint32(max)
}
return defaultValue, defaultValue
} }
func (kv KV) Strings(key string, defaultValue ...[]string) []string { func (kv KV) Strings(key string, defaultValue ...[]string) []string {
val, _ := keyValue(kv, key, &array[string]{values: append(defaultValue, []string(nil))[0]}) r := keyValue(kv, key, &array{})
return val.values s := make([]string, r.size)
} for i := range r.size {
s[i] = r.values[i].(string)
}
func (kv KV) Ints(key string, defaultValue ...[]int32) []int32 { return s
val, _ := keyValue(kv, key, &array[int32]{values: append(defaultValue, []int32(nil))[0]})
return val.values
} }
func (kv KV) Uints(key string, defaultValue ...[]uint32) []uint32 { func (kv KV) Uints(key string, defaultValue ...[]uint32) []uint32 {
val, _ := keyValue(kv, key, &array[uint32]{values: append(defaultValue, []uint32(nil))[0]}) r := keyValue(kv, key, &array{})
return val.values s := make([]uint32, r.size)
for i := range r.size {
s[i] = uint32(r.values[i].(int32))
}
return s
} }
func (kv KV) Floats(key string, defaultValue ...[]float32) []float32 { func (kv KV) Floats(key string, defaultValue ...[]float32) []float32 {
val, _ := keyValue(kv, key, &array[float32]{values: append(defaultValue, []float32(nil))[0]}) r := keyValue(kv, key, &array{})
return val.values s := make([]float32, r.size)
} for i := range r.size {
s[i] = float32(r.values[i].(float32))
func (kv KV) Bools(key string, defaultValue ...[]bool) []bool { }
val, _ := keyValue(kv, key, &array[bool]{values: append(defaultValue, []bool(nil))[0]}) return s
return val.values
} }
func (kv KV) OllamaEngineRequired() bool { func (kv KV) OllamaEngineRequired() bool {
return slices.Contains([]string{ return slices.Contains([]string{
"gemma3", "gemma3",
"gemma3n",
"mistral3", "mistral3",
"llama4",
"mllama",
"qwen25vl",
}, kv.Architecture()) }, kv.Architecture())
} }
type valueTypes interface { func keyValue[T string | uint32 | uint64 | float32 | *array | bool](kv KV, key string, defaultValue ...T) T {
uint8 | int8 | uint16 | int16 |
uint32 | int32 | uint64 | int64 |
string | float32 | float64 | bool
}
type arrayValueTypes interface {
*array[uint8] | *array[int8] | *array[uint16] | *array[int16] |
*array[uint32] | *array[int32] | *array[uint64] | *array[int64] |
*array[string] | *array[float32] | *array[float64] | *array[bool]
}
func keyValue[T valueTypes | arrayValueTypes](kv KV, key string, defaultValue ...T) (T, bool) {
if !strings.HasPrefix(key, "tokenizer.") && !strings.HasPrefix(key, "general.") { if !strings.HasPrefix(key, "tokenizer.") && !strings.HasPrefix(key, "general.") {
key = kv.Architecture() + "." + key key = kv.Architecture() + "." + key
} }
if val, ok := kv[key].(T); ok { if val, ok := kv[key]; ok {
return val, true return val.(T)
} }
slog.Debug("key with type not found", "key", key, "default", defaultValue[0]) slog.Warn("key not found", "key", key, "default", defaultValue[0])
return defaultValue[0], false return defaultValue[0]
} }
type Tensors struct { type Tensors struct {
@@ -280,11 +226,7 @@ func (t Tensor) block() (n int) {
} }
func (t Tensor) blockSize() uint64 { func (t Tensor) blockSize() uint64 {
return (TensorType)(t.Kind).BlockSize() switch t.Kind {
}
func (t TensorType) BlockSize() uint64 {
switch t {
case case
0, // F32 0, // F32
1, // F16 1, // F16
@@ -310,77 +252,73 @@ func (t TensorType) BlockSize() uint64 {
} }
func (t Tensor) typeSize() uint64 { func (t Tensor) typeSize() uint64 {
return TensorType(t.Kind).TypeSize() blockSize := t.blockSize()
}
func (t TensorType) TypeSize() uint64 { switch t.Kind {
blockSize := t.BlockSize() case 0: // FP32
switch t {
case TensorTypeF32:
return 4 return 4
case TensorTypeF16: case 1: // FP16
return 2 return 2
case TensorTypeQ4_0: case 2: // Q4_0
return 2 + blockSize/2 return 2 + blockSize/2
case TensorTypeQ4_1: case 3: // Q4_1
return 2 + 2 + blockSize/2 return 2 + 2 + blockSize/2
case TensorTypeQ5_0: case 6: // Q5_0
return 2 + 4 + blockSize/2 return 2 + 4 + blockSize/2
case TensorTypeQ5_1: case 7: // Q5_1
return 2 + 2 + 4 + blockSize/2 return 2 + 2 + 4 + blockSize/2
case TensorTypeQ8_0: case 8: // Q8_0
return 2 + blockSize return 2 + blockSize
case TensorTypeQ8_1: case 9: // Q8_1
return 2 + 2 + blockSize return 2 + 2 + blockSize
case TensorTypeQ2_K: case 10: // Q2_K
return blockSize/16 + blockSize/4 + 2 + 2 return blockSize/16 + blockSize/4 + 2 + 2
case TensorTypeQ3_K: case 11: // Q3_K
return blockSize/8 + blockSize/4 + 12 + 2 return blockSize/8 + blockSize/4 + 12 + 2
case TensorTypeQ4_K: case 12: // Q4_K
return 2 + 2 + 12 + blockSize/2 return 2 + 2 + 12 + blockSize/2
case TensorTypeQ5_K: case 13: // Q5_K
return 2 + 2 + 12 + blockSize/8 + blockSize/2 return 2 + 2 + 12 + blockSize/8 + blockSize/2
case TensorTypeQ6_K: case 14: // Q6_K
return blockSize/2 + blockSize/4 + blockSize/16 + 2 return blockSize/2 + blockSize/4 + blockSize/16 + 2
case TensorTypeQ8_K: case 15: // Q8_K
return 4 + blockSize + 2*blockSize/16 return 4 + blockSize + 2*blockSize/16
case tensorTypeIQ2_XXS: case 16: // IQ2_XXS
return 2 + 2*blockSize/8 return 2 + 2*blockSize/8
case tensorTypeIQ2_XS: case 17: // IQ2_XS
return 2 + 2*blockSize/8 + blockSize/32 return 2 + 2*blockSize/8 + blockSize/32
case tensorTypeIQ3_XXS: case 18: // IQ3_XXS
return 2 + blockSize/4 + blockSize/8 return 2 + blockSize/4 + blockSize/8
case tensorTypeIQ1_S: case 19: // IQ1_S
return 2 + blockSize/8 + blockSize/16 return 2 + blockSize/8 + blockSize/16
case tensorTypeIQ4_NL: case 20: // IQ4_NL
return 2 + blockSize/2 return 2 + blockSize/2
case tensorTypeIQ3_S: case 21: // IQ3_S
return 2 + blockSize/4 + blockSize/8 + blockSize/32 + 4 return 2 + blockSize/4 + blockSize/8 + blockSize/32 + 4
case tensorTypeIQ2_S: case 22: // IQ2_S
return 2 + blockSize/4 + blockSize/16 return 2 + blockSize/4 + blockSize/16
case tensorTypeIQ4_XS: case 23: // IQ4_XS
return 2 + 2 + blockSize/2 + blockSize/64 return 2 + 2 + blockSize/2 + blockSize/64
case TensorTypeI8: case 24: // I8
return 1 return 1
case TensorTypeI16: case 25: // I16
return 2 return 2
case TensorTypeI32: case 26: // I32
return 4 return 4
case TensorTypeI64: case 27: // I64
return 8 return 8
case TensorTypeF64: case 28: // F64
return 8 return 8
case tensorTypeIQ1_M: case 29: // IQ1_M
return blockSize/8 + blockSize/16 + blockSize/32 return blockSize/8 + blockSize/16 + blockSize/32
case TensorTypeBF16: case 30: // BF16
return 2 return 2
default: default:
return 0 return 0
} }
} }
func (t Tensor) Elements() uint64 { func (t Tensor) parameters() uint64 {
var count uint64 = 1 var count uint64 = 1
for _, n := range t.Shape { for _, n := range t.Shape {
count *= n count *= n
@@ -389,11 +327,11 @@ func (t Tensor) Elements() uint64 {
} }
func (t Tensor) Size() uint64 { func (t Tensor) Size() uint64 {
return t.Elements() * t.typeSize() / t.blockSize() return t.parameters() * t.typeSize() / t.blockSize()
} }
func (t Tensor) Type() string { func (t Tensor) Type() string {
return TensorType(t.Kind).String() return fileType(t.Kind).String()
} }
type container interface { type container interface {
@@ -437,13 +375,18 @@ func DetectContentType(b []byte) string {
// Decode decodes a GGML model from the given reader. // Decode decodes a GGML model from the given reader.
// //
// It collects array values for arrays with a size less than or equal to // It collects array values for arrays with a size less than or equal to
// maxArraySize. If the maxArraySize is negative, all arrays are collected. // maxArraySize. If maxArraySize is 0, the default value of 1024 is used. If
func Decode(rs io.ReadSeeker, maxArraySize int) (*GGML, error) { // the maxArraySize is negative, all arrays are collected.
func Decode(rs io.ReadSeeker, maxArraySize int) (*GGML, int64, error) {
if maxArraySize == 0 {
maxArraySize = 1024
}
rs = bufioutil.NewBufferedSeeker(rs, 32<<10) rs = bufioutil.NewBufferedSeeker(rs, 32<<10)
var magic uint32 var magic uint32
if err := binary.Read(rs, binary.LittleEndian, &magic); err != nil { if err := binary.Read(rs, binary.LittleEndian, &magic); err != nil {
return nil, err return nil, 0, err
} }
var c container var c container
@@ -453,34 +396,33 @@ func Decode(rs io.ReadSeeker, maxArraySize int) (*GGML, error) {
case FILE_MAGIC_GGUF_BE: case FILE_MAGIC_GGUF_BE:
c = &containerGGUF{ByteOrder: binary.BigEndian, maxArraySize: maxArraySize} c = &containerGGUF{ByteOrder: binary.BigEndian, maxArraySize: maxArraySize}
default: default:
return nil, errors.New("invalid file magic") return nil, 0, errors.New("invalid file magic")
} }
model, err := c.Decode(rs) model, err := c.Decode(rs)
if err != nil { if err != nil {
return nil, err return nil, 0, err
} }
offset, err := rs.Seek(0, io.SeekCurrent) offset, err := rs.Seek(0, io.SeekCurrent)
if err != nil { if err != nil {
return nil, err return nil, 0, err
} }
// final model type // final model type
return &GGML{ return &GGML{
container: c, container: c,
model: model, model: model,
Length: offset, }, offset, nil
}, nil
} }
func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType string) (kv []uint64, partialOffload, fullOffload uint64) { func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType string) (kv []uint64, partialOffload, fullOffload uint64) {
embedding := f.KV().EmbeddingLength() embedding := f.KV().EmbeddingLength()
heads := f.KV().HeadCountMax() heads := f.KV().HeadCount()
headsKV := f.KV().HeadCountKVMax() headsKV := f.KV().HeadCountKV()
vocab := uint64(f.KV()["tokenizer.ggml.tokens"].(*array[string]).size) vocab := uint64(f.KV()["tokenizer.ggml.tokens"].(*array).size)
embeddingHeads := f.KV().EmbeddingHeadCountMax() embeddingHeads := f.KV().EmbeddingHeadCount()
embeddingHeadsK := f.KV().EmbeddingHeadCountK() embeddingHeadsK := f.KV().EmbeddingHeadCountK()
embeddingHeadsV := f.KV().EmbeddingHeadCountV() embeddingHeadsV := f.KV().EmbeddingHeadCountV()
@@ -493,7 +435,7 @@ func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType stri
} }
switch f.KV().Architecture() { switch f.KV().Architecture() {
case "llama", "llama4": case "llama":
fullOffload = max( fullOffload = max(
4*batch*(1+4*embedding+context*(1+heads)), 4*batch*(1+4*embedding+context*(1+heads)),
4*batch*(embedding+vocab), 4*batch*(embedding+vocab),
@@ -507,7 +449,7 @@ func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType stri
if ffnGateExpsWeight, ok := layers["blk.0"]["ffn_gate_exps.weight"]; ok { if ffnGateExpsWeight, ok := layers["blk.0"]["ffn_gate_exps.weight"]; ok {
// mixtral 8x22b // mixtral 8x22b
ff := uint64(f.KV().Uint("feed_forward_length")) ff := uint64(f.KV()["llama.feed_forward_length"].(uint32))
partialOffload = max( partialOffload = max(
3*ffnGateExpsWeight.Size()+4*batch*(2*ff+headsKV+embedding+context+embeddingHeads*headsKV), 3*ffnGateExpsWeight.Size()+4*batch*(2*ff+headsKV+embedding+context+embeddingHeads*headsKV),
4*(context*batch*heads+context*embeddingHeads*headsKV+batch*1024+embeddingHeads*headsKV*batch), 4*(context*batch*heads+context*embeddingHeads*headsKV+batch*1024+embeddingHeads*headsKV*batch),
@@ -524,9 +466,9 @@ func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType stri
case "mllama": case "mllama":
var visionTokens, tiles uint64 = 1601, 4 var visionTokens, tiles uint64 = 1601, 4
crossAttentionLayers := f.KV().Ints("attention.cross_attention_layers") crossAttentionLayers := f.KV().Uints("attention.cross_attention_layers")
for i := range kv { for i := range kv {
if slices.Contains(crossAttentionLayers, int32(i)) { if slices.Contains(crossAttentionLayers, uint32(i)) {
kv[i] = headsKV * (embeddingHeadsK + embeddingHeadsV) * kv[i] = headsKV * (embeddingHeadsK + embeddingHeadsV) *
4 * // sizeof(float32) 4 * // sizeof(float32)
visionTokens * visionTokens *
@@ -543,7 +485,7 @@ func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType stri
var ropeFreqsCount uint64 var ropeFreqsCount uint64
if ropeFreqs, ok := f.Tensors().GroupLayers()["rope_freqs"]; ok { if ropeFreqs, ok := f.Tensors().GroupLayers()["rope_freqs"]; ok {
if ropeFreqsWeights, ok := ropeFreqs["weights"]; ok { if ropeFreqsWeights, ok := ropeFreqs["weights"]; ok {
ropeFreqsCount = ropeFreqsWeights.Elements() ropeFreqsCount = ropeFreqsWeights.parameters()
} }
} }
@@ -555,7 +497,7 @@ func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType stri
// vocab graph // vocab graph
4*batch*(embedding+vocab)+embedding*vocab*105/128, 4*batch*(embedding+vocab)+embedding*vocab*105/128,
) )
case "gemma", "gemma2", "gemma3", "gemma3n": case "gemma", "gemma2", "gemma3":
fullOffload = max( fullOffload = max(
4*batch*(embedding+vocab), 4*batch*(embedding+vocab),
4*batch*(2+context+context*heads+2*embedding+2*embeddingHeadsK*heads), 4*batch*(2+context+context*heads+2*embedding+2*embeddingHeadsK*heads),
@@ -568,11 +510,6 @@ func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType stri
embedding*embeddingHeadsK*heads*9/16, embedding*embeddingHeadsK*heads*9/16,
) )
if f.KV().Architecture() == "gemma3n" {
fullOffload *= 4
partialOffload *= 4
}
// Gemma2 also has sliding window attention but we only have an optimized implementation in the Ollama // Gemma2 also has sliding window attention but we only have an optimized implementation in the Ollama
// engine. Gemma3 always uses the Ollama engine. // engine. Gemma3 always uses the Ollama engine.
if f.KV().Architecture() == "gemma3" { if f.KV().Architecture() == "gemma3" {
@@ -708,23 +645,6 @@ func (llm GGML) VisionGraphSize() (weights, graphSize uint64) {
graphSize = 4 * (imageSize*imageSize*numChannels + graphSize = 4 * (imageSize*imageSize*numChannels +
embeddingLength*patchSize + embeddingLength*patchSize +
numPatches*numPatches*headCount) numPatches*numPatches*headCount)
case "qwen25vl":
maxPixels := uint64(llm.KV().Uint("vision.max_pixels", 28*28*1280))
numPatches := maxPixels / (patchSize * patchSize)
graphSize = 4 * (maxPixels*numChannels + // Original image storage
// Normalized pixels
maxPixels*numChannels +
// Patches storage (numPatches * channels * patchSize^2)
numPatches*numChannels*patchSize*patchSize +
// Self-attention calculations
numPatches*numPatches*headCount +
// Additional buffer for processing
embeddingLength*numPatches)
case "llama4":
// vision graph is computed independently in the same schedule
// and is negligible compared to the worst case text graph
} }
return weights, graphSize return weights, graphSize

View File

@@ -2,7 +2,6 @@ package ggml
import ( import (
"maps" "maps"
"math"
"slices" "slices"
"strconv" "strconv"
"strings" "strings"
@@ -211,91 +210,3 @@ func TestTensorTypes(t *testing.T) {
}) })
} }
} }
func TestKeyValue(t *testing.T) {
kv := KV{
"general.architecture": "test",
"test.strings": &array[string]{size: 3, values: []string{"a", "b", "c"}},
"test.float32s": &array[float32]{size: 3, values: []float32{1.0, 2.0, 3.0}},
"test.int32s": &array[int32]{size: 3, values: []int32{1, 2, 3}},
"test.uint32s": &array[uint32]{size: 3, values: []uint32{1, 2, 3}},
}
if diff := cmp.Diff(kv.Strings("strings"), []string{"a", "b", "c"}); diff != "" {
t.Errorf("unexpected strings (-got +want):\n%s", diff)
}
if diff := cmp.Diff(kv.Strings("nonexistent.strings"), []string(nil)); diff != "" {
t.Errorf("unexpected strings (-got +want):\n%s", diff)
}
if diff := cmp.Diff(kv.Strings("default.strings", []string{"ollama"}), []string{"ollama"}); diff != "" {
t.Errorf("unexpected strings (-got +want):\n%s", diff)
}
if diff := cmp.Diff(kv.Floats("float32s"), []float32{1.0, 2.0, 3.0}); diff != "" {
t.Errorf("unexpected float32s (-got +want):\n%s", diff)
}
if diff := cmp.Diff(kv.Floats("nonexistent.float32s"), []float32(nil)); diff != "" {
t.Errorf("unexpected float32s (-got +want):\n%s", diff)
}
if diff := cmp.Diff(kv.Floats("default.float32s", []float32{math.MaxFloat32}), []float32{math.MaxFloat32}); diff != "" {
t.Errorf("unexpected float32s (-got +want):\n%s", diff)
}
if diff := cmp.Diff(kv.Ints("int32s"), []int32{1, 2, 3}); diff != "" {
t.Errorf("unexpected int8s (-got +want):\n%s", diff)
}
if diff := cmp.Diff(kv.Ints("nonexistent.int32s"), []int32(nil)); diff != "" {
t.Errorf("unexpected int8s (-got +want):\n%s", diff)
}
if diff := cmp.Diff(kv.Ints("default.int32s", []int32{math.MaxInt32}), []int32{math.MaxInt32}); diff != "" {
t.Errorf("unexpected int8s (-got +want):\n%s", diff)
}
if diff := cmp.Diff(kv.Uints("uint32s"), []uint32{1, 2, 3}); diff != "" {
t.Errorf("unexpected uint8s (-got +want):\n%s", diff)
}
if diff := cmp.Diff(kv.Uints("nonexistent.uint32s"), []uint32(nil)); diff != "" {
t.Errorf("unexpected uint8s (-got +want):\n%s", diff)
}
if diff := cmp.Diff(kv.Uints("default.uint32s", []uint32{math.MaxUint32}), []uint32{math.MaxUint32}); diff != "" {
t.Errorf("unexpected uint8s (-got +want):\n%s", diff)
}
}
func TestHeadCount(t *testing.T) {
valuesArray := []int32{1, 5, 3, 4}
cases := []struct {
kv KV
want uint64
}{
{
kv: KV{
"general.architecture": "abc",
"abc.attention.head_count": &array[int32]{values: valuesArray, size: len(valuesArray)},
},
want: uint64(5),
},
{
kv: KV{
"general.architecture": "abc",
"abc.attention.head_count": uint32(3),
},
want: uint64(3),
},
}
for _, tt := range cases {
got := tt.kv.HeadCountMax()
if got != tt.want {
t.Errorf("unexpected max value: got=%d want=%d", got, tt.want)
}
}
}

View File

@@ -9,12 +9,8 @@ import (
"io" "io"
"log/slog" "log/slog"
"maps" "maps"
"os"
"runtime"
"slices" "slices"
"strings" "strings"
"golang.org/x/sync/errgroup"
) )
type containerGGUF struct { type containerGGUF struct {
@@ -40,6 +36,10 @@ type containerGGUF struct {
maxArraySize int maxArraySize int
} }
func (c *containerGGUF) canCollectArray(size int) bool {
return c.maxArraySize < 0 || size <= c.maxArraySize
}
func (c *containerGGUF) Name() string { func (c *containerGGUF) Name() string {
return "gguf" return "gguf"
} }
@@ -229,13 +229,16 @@ func (llm *gguf) Decode(rs io.ReadSeeker) error {
} }
llm.tensors = append(llm.tensors, &tensor) llm.tensors = append(llm.tensors, &tensor)
llm.parameters += tensor.Elements() llm.parameters += tensor.parameters()
} }
// patch KV with parameter count // patch KV with parameter count
llm.kv["general.parameter_count"] = llm.parameters llm.kv["general.parameter_count"] = llm.parameters
alignment := llm.kv.Uint("general.alignment", 32) alignment, ok := llm.kv["general.alignment"].(uint32)
if !ok {
alignment = 32
}
offset, err := rs.Seek(0, io.SeekCurrent) offset, err := rs.Seek(0, io.SeekCurrent)
if err != nil { if err != nil {
@@ -295,23 +298,6 @@ func readGGUFV1String(llm *gguf, r io.Reader) (string, error) {
return b.String(), nil return b.String(), nil
} }
func readGGUFV1StringsData(llm *gguf, r io.Reader, a *array[string]) (any, error) {
for i := range a.size {
if a.values != nil {
e, err := readGGUFV1String(llm, r)
if err != nil {
return nil, err
}
a.values[i] = e
} else {
discardGGUFString(llm, r)
}
}
return a, nil
}
func discardGGUFString(llm *gguf, r io.Reader) error { func discardGGUFString(llm *gguf, r io.Reader) error {
buf := llm.scratch[:8] buf := llm.scratch[:8]
_, err := io.ReadFull(r, buf) _, err := io.ReadFull(r, buf)
@@ -369,44 +355,78 @@ func writeGGUFString(w io.Writer, s string) error {
return err return err
} }
func readGGUFStringsData(llm *gguf, r io.Reader, a *array[string]) (any, error) { type array struct {
for i := range a.size { size int
if a.values != nil { values []any
e, err := readGGUFString(llm, r) }
if err != nil {
return nil, err
}
func (a *array) MarshalJSON() ([]byte, error) {
return json.Marshal(a.values)
}
func readGGUFV1Array(llm *gguf, r io.Reader) (*array, error) {
t, err := readGGUF[uint32](llm, r)
if err != nil {
return nil, err
}
n, err := readGGUF[uint32](llm, r)
if err != nil {
return nil, err
}
a := &array{size: int(n)}
if llm.canCollectArray(int(n)) {
a.values = make([]any, 0, int(n))
}
for i := range n {
var e any
switch t {
case ggufTypeUint8:
e, err = readGGUF[uint8](llm, r)
case ggufTypeInt8:
e, err = readGGUF[int8](llm, r)
case ggufTypeUint16:
e, err = readGGUF[uint16](llm, r)
case ggufTypeInt16:
e, err = readGGUF[int16](llm, r)
case ggufTypeUint32:
e, err = readGGUF[uint32](llm, r)
case ggufTypeInt32:
e, err = readGGUF[int32](llm, r)
case ggufTypeUint64:
e, err = readGGUF[uint64](llm, r)
case ggufTypeInt64:
e, err = readGGUF[int64](llm, r)
case ggufTypeFloat32:
e, err = readGGUF[float32](llm, r)
case ggufTypeFloat64:
e, err = readGGUF[float64](llm, r)
case ggufTypeBool:
e, err = readGGUF[bool](llm, r)
case ggufTypeString:
e, err = readGGUFV1String(llm, r)
default:
return nil, fmt.Errorf("invalid array type: %d", t)
}
if err != nil {
return nil, err
}
if a.values != nil {
a.values[i] = e a.values[i] = e
} else {
discardGGUFString(llm, r)
} }
} }
return a, nil return a, nil
} }
type array[T any] struct { func readGGUFArray(llm *gguf, r io.Reader) (*array, error) {
// size is the actual size of the array if llm.Version == 1 {
size int return readGGUFV1Array(llm, r)
// values is the array of values. this is nil if the array is larger than configured maxSize
values []T
}
func (a *array[T]) MarshalJSON() ([]byte, error) {
return json.Marshal(a.values)
}
func newArray[T any](size, maxSize int) *array[T] {
a := array[T]{size: size}
if maxSize < 0 || size <= maxSize {
a.values = make([]T, size)
} }
return &a
}
func readGGUFArray(llm *gguf, r io.Reader) (any, error) {
t, err := readGGUF[uint32](llm, r) t, err := readGGUF[uint32](llm, r)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -417,55 +437,45 @@ func readGGUFArray(llm *gguf, r io.Reader) (any, error) {
return nil, err return nil, err
} }
switch t { a := &array{size: int(n)}
case ggufTypeUint8: if llm.canCollectArray(int(n)) {
a := newArray[uint8](int(n), llm.maxArraySize) a.values = make([]any, int(n))
return readGGUFArrayData(llm, r, a)
case ggufTypeInt8:
a := newArray[int8](int(n), llm.maxArraySize)
return readGGUFArrayData(llm, r, a)
case ggufTypeUint16:
a := newArray[uint16](int(n), llm.maxArraySize)
return readGGUFArrayData(llm, r, a)
case ggufTypeInt16:
a := newArray[int16](int(n), llm.maxArraySize)
return readGGUFArrayData(llm, r, a)
case ggufTypeUint32:
a := newArray[uint32](int(n), llm.maxArraySize)
return readGGUFArrayData(llm, r, a)
case ggufTypeInt32:
a := newArray[int32](int(n), llm.maxArraySize)
return readGGUFArrayData(llm, r, a)
case ggufTypeUint64:
a := newArray[uint64](int(n), llm.maxArraySize)
return readGGUFArrayData(llm, r, a)
case ggufTypeInt64:
a := newArray[int64](int(n), llm.maxArraySize)
return readGGUFArrayData(llm, r, a)
case ggufTypeFloat32:
a := newArray[float32](int(n), llm.maxArraySize)
return readGGUFArrayData(llm, r, a)
case ggufTypeFloat64:
a := newArray[float64](int(n), llm.maxArraySize)
return readGGUFArrayData(llm, r, a)
case ggufTypeBool:
a := newArray[bool](int(n), llm.maxArraySize)
return readGGUFArrayData(llm, r, a)
case ggufTypeString:
a := newArray[string](int(n), llm.maxArraySize)
if llm.Version == 1 {
return readGGUFV1StringsData(llm, r, a)
}
return readGGUFStringsData(llm, r, a)
default:
return nil, fmt.Errorf("invalid array type: %d", t)
} }
}
func readGGUFArrayData[T any](llm *gguf, r io.Reader, a *array[T]) (any, error) { for i := range n {
for i := range a.size { var e any
e, err := readGGUF[T](llm, r) switch t {
case ggufTypeUint8:
e, err = readGGUF[uint8](llm, r)
case ggufTypeInt8:
e, err = readGGUF[int8](llm, r)
case ggufTypeUint16:
e, err = readGGUF[uint16](llm, r)
case ggufTypeInt16:
e, err = readGGUF[int16](llm, r)
case ggufTypeUint32:
e, err = readGGUF[uint32](llm, r)
case ggufTypeInt32:
e, err = readGGUF[int32](llm, r)
case ggufTypeUint64:
e, err = readGGUF[uint64](llm, r)
case ggufTypeInt64:
e, err = readGGUF[int64](llm, r)
case ggufTypeFloat32:
e, err = readGGUF[float32](llm, r)
case ggufTypeFloat64:
e, err = readGGUF[float64](llm, r)
case ggufTypeBool:
e, err = readGGUF[bool](llm, r)
case ggufTypeString:
if a.values != nil {
e, err = readGGUFString(llm, r)
} else {
err = discardGGUFString(llm, r)
}
default:
return nil, fmt.Errorf("invalid array type: %d", t)
}
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -492,83 +502,62 @@ func writeGGUFArray[S ~[]E, E any](w io.Writer, t uint32, s S) error {
return err return err
} }
if t == ggufTypeString {
for _, e := range any(s).([]string) {
if err := binary.Write(w, binary.LittleEndian, uint64(len(e))); err != nil {
return err
}
if err := binary.Write(w, binary.LittleEndian, []byte(e)); err != nil {
return err
}
}
return nil
}
return binary.Write(w, binary.LittleEndian, s) return binary.Write(w, binary.LittleEndian, s)
} }
func WriteGGUF(f *os.File, kv KV, ts []*Tensor) error { func WriteGGUF(ws io.WriteSeeker, kv KV, ts []Tensor) error {
alignment := kv.Uint("general.alignment", 32) if err := binary.Write(ws, binary.LittleEndian, []byte("GGUF")); err != nil {
if err := binary.Write(f, binary.LittleEndian, []byte("GGUF")); err != nil {
return err return err
} }
if err := binary.Write(f, binary.LittleEndian, uint32(3)); err != nil { if err := binary.Write(ws, binary.LittleEndian, uint32(3)); err != nil {
return err return err
} }
if err := binary.Write(f, binary.LittleEndian, uint64(len(ts))); err != nil { if err := binary.Write(ws, binary.LittleEndian, uint64(len(ts))); err != nil {
return err return err
} }
if err := binary.Write(f, binary.LittleEndian, uint64(len(kv))); err != nil { if err := binary.Write(ws, binary.LittleEndian, uint64(len(kv))); err != nil {
return err return err
} }
for _, key := range slices.Sorted(maps.Keys(kv)) { keys := slices.Collect(maps.Keys(kv))
if err := ggufWriteKV(f, key, kv[key]); err != nil { slices.Sort(keys)
for _, key := range keys {
if err := ggufWriteKV(ws, key, kv[key]); err != nil {
return err return err
} }
} }
slices.SortStableFunc(ts, func(a, b *Tensor) int { slices.SortStableFunc(ts, func(a, b Tensor) int {
if i, j := a.block(), b.block(); i > 0 && j > 0 { if i, j := a.block(), b.block(); i < 0 && j > 0 {
return 1
} else if i > 0 && j < 0 {
return -1
} else {
return cmp.Compare(i, j) return cmp.Compare(i, j)
} }
return cmp.Compare(a.Name, b.Name)
}) })
var s uint64 var s uint64
for i := range ts { for _, t := range ts {
ts[i].Offset = s t.Offset = s
if err := ggufWriteTensorInfo(f, ts[i]); err != nil { if err := ggufWriteTensorInfo(ws, t); err != nil {
return err return err
} }
s += ts[i].Size() s += t.Size()
s += uint64(ggufPadding(int64(s), int64(alignment)))
} }
offset, err := f.Seek(0, io.SeekCurrent) var alignment int64 = 32
if err != nil {
return err
}
offset += ggufPadding(offset, int64(alignment))
var g errgroup.Group
g.SetLimit(runtime.GOMAXPROCS(0))
// TODO consider reducing if tensors size * gomaxprocs is larger than free memory
for _, t := range ts { for _, t := range ts {
t := t if err := ggufWriteTensor(ws, t, alignment); err != nil {
w := io.NewOffsetWriter(f, offset+int64(t.Offset))
g.Go(func() error {
_, err := t.WriteTo(w)
return err return err
}) }
} }
return g.Wait() return nil
} }
func ggufWriteKV(ws io.WriteSeeker, k string, v any) error { func ggufWriteKV(ws io.WriteSeeker, k string, v any) error {
@@ -583,10 +572,8 @@ func ggufWriteKV(ws io.WriteSeeker, k string, v any) error {
var err error var err error
switch v := v.(type) { switch v := v.(type) {
case uint32, FileType: case uint32:
err = writeGGUF(ws, ggufTypeUint32, v) err = writeGGUF(ws, ggufTypeUint32, v)
case uint64:
err = writeGGUF(ws, ggufTypeUint64, v)
case float32: case float32:
err = writeGGUF(ws, ggufTypeFloat32, v) err = writeGGUF(ws, ggufTypeFloat32, v)
case bool: case bool:
@@ -595,24 +582,32 @@ func ggufWriteKV(ws io.WriteSeeker, k string, v any) error {
err = writeGGUFString(ws, v) err = writeGGUFString(ws, v)
case []int32: case []int32:
err = writeGGUFArray(ws, ggufTypeInt32, v) err = writeGGUFArray(ws, ggufTypeInt32, v)
case *array[int32]:
err = writeGGUFArray(ws, ggufTypeInt32, v.values)
case []uint32: case []uint32:
err = writeGGUFArray(ws, ggufTypeUint32, v) err = writeGGUFArray(ws, ggufTypeUint32, v)
case *array[uint32]:
err = writeGGUFArray(ws, ggufTypeUint32, v.values)
case []float32: case []float32:
err = writeGGUFArray(ws, ggufTypeFloat32, v) err = writeGGUFArray(ws, ggufTypeFloat32, v)
case *array[float32]:
err = writeGGUFArray(ws, ggufTypeFloat32, v.values)
case []string: case []string:
err = writeGGUFArray(ws, ggufTypeString, v) if err := binary.Write(ws, binary.LittleEndian, ggufTypeArray); err != nil {
case *array[string]: return err
err = writeGGUFArray(ws, ggufTypeString, v.values) }
case []bool:
err = writeGGUFArray(ws, ggufTypeBool, v) if err := binary.Write(ws, binary.LittleEndian, ggufTypeString); err != nil {
case *array[bool]: return err
err = writeGGUFArray(ws, ggufTypeBool, v.values) }
if err := binary.Write(ws, binary.LittleEndian, uint64(len(v))); err != nil {
return err
}
for _, e := range v {
if err := binary.Write(ws, binary.LittleEndian, uint64(len(e))); err != nil {
return err
}
if err := binary.Write(ws, binary.LittleEndian, []byte(e)); err != nil {
return err
}
}
default: default:
return fmt.Errorf("improper type for '%s'", k) return fmt.Errorf("improper type for '%s'", k)
} }
@@ -620,7 +615,7 @@ func ggufWriteKV(ws io.WriteSeeker, k string, v any) error {
return err return err
} }
func ggufWriteTensorInfo(ws io.WriteSeeker, t *Tensor) error { func ggufWriteTensorInfo(ws io.WriteSeeker, t Tensor) error {
slog.Debug(t.Name, "kind", t.Kind, "shape", t.Shape, "offset", t.Offset) slog.Debug(t.Name, "kind", t.Kind, "shape", t.Shape, "offset", t.Offset)
if err := binary.Write(ws, binary.LittleEndian, uint64(len(t.Name))); err != nil { if err := binary.Write(ws, binary.LittleEndian, uint64(len(t.Name))); err != nil {
return err return err
@@ -634,8 +629,8 @@ func ggufWriteTensorInfo(ws io.WriteSeeker, t *Tensor) error {
return err return err
} }
for _, n := range t.Shape { for i := range len(t.Shape) {
if err := binary.Write(ws, binary.LittleEndian, n); err != nil { if err := binary.Write(ws, binary.LittleEndian, t.Shape[len(t.Shape)-i-1]); err != nil {
return err return err
} }
} }
@@ -647,6 +642,20 @@ func ggufWriteTensorInfo(ws io.WriteSeeker, t *Tensor) error {
return binary.Write(ws, binary.LittleEndian, t.Offset) return binary.Write(ws, binary.LittleEndian, t.Offset)
} }
func ggufWriteTensor(ws io.WriteSeeker, t Tensor, alignment int64) error {
offset, err := ws.Seek(0, io.SeekCurrent)
if err != nil {
return err
}
if err := binary.Write(ws, binary.LittleEndian, bytes.Repeat([]byte{0}, int(ggufPadding(offset, alignment)))); err != nil {
return err
}
_, err = t.WriteTo(ws)
return err
}
func ggufPadding(offset, align int64) int64 { func ggufPadding(offset, align int64) int64 {
return (align - offset%align) % align return (align - offset%align) % align
} }

View File

@@ -1,83 +0,0 @@
package ggml
import (
"bytes"
"math/rand/v2"
"os"
"strings"
"testing"
"github.com/google/go-cmp/cmp"
)
func TestWriteGGUF(t *testing.T) {
r := rand.New(rand.NewPCG(0, 0))
for range 8 {
t.Run("shuffle", func(t *testing.T) {
t.Parallel()
ts := []*Tensor{
{Name: "token_embd.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(make([]byte, 2*3))},
{Name: "blk.0.attn_norm.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(make([]byte, 2*3))},
{Name: "blk.1.attn_norm.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(make([]byte, 2*3))},
{Name: "blk.2.attn_norm.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(make([]byte, 2*3))},
{Name: "blk.3.attn_norm.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(make([]byte, 2*3))},
{Name: "blk.4.attn_norm.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(make([]byte, 2*3))},
{Name: "blk.5.attn_norm.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(make([]byte, 2*3))},
{Name: "output_norm.weight", Shape: []uint64{3, 2}, WriterTo: bytes.NewBuffer(make([]byte, 3*2))},
{Name: "output.weight", Shape: []uint64{3, 2}, WriterTo: bytes.NewBuffer(make([]byte, 3*2))},
}
r.Shuffle(len(ts), func(i, j int) {
ts[i], ts[j] = ts[j], ts[i]
})
w, err := os.CreateTemp(t.TempDir(), strings.ReplaceAll(t.Name(), "/", "_")+"*.bin")
if err != nil {
t.Fatal(err)
}
defer w.Close()
if err := WriteGGUF(w, KV{
"general.alignment": uint32(16),
}, ts); err != nil {
t.Fatal(err)
}
r, err := os.Open(w.Name())
if err != nil {
t.Fatal(err)
}
defer r.Close()
ff, err := Decode(r, 0)
if err != nil {
t.Fatal(err)
}
if diff := cmp.Diff(KV{
"general.alignment": uint32(16),
"general.parameter_count": uint64(54),
}, ff.KV()); diff != "" {
t.Errorf("Mismatch (-want +got):\n%s", diff)
}
if diff := cmp.Diff(Tensors{
Offset: 608,
items: []*Tensor{
{Name: "blk.0.attn_norm.weight", Offset: 0, Shape: []uint64{2, 3}},
{Name: "blk.1.attn_norm.weight", Offset: 32, Shape: []uint64{2, 3}},
{Name: "blk.2.attn_norm.weight", Offset: 64, Shape: []uint64{2, 3}},
{Name: "blk.3.attn_norm.weight", Offset: 96, Shape: []uint64{2, 3}},
{Name: "blk.4.attn_norm.weight", Offset: 128, Shape: []uint64{2, 3}},
{Name: "blk.5.attn_norm.weight", Offset: 160, Shape: []uint64{2, 3}},
{Name: "output.weight", Offset: 192, Shape: []uint64{3, 2}},
{Name: "output_norm.weight", Offset: 224, Shape: []uint64{3, 2}},
{Name: "token_embd.weight", Offset: 256, Shape: []uint64{2, 3}},
},
}, ff.Tensors(), cmp.AllowUnexported(Tensors{})); diff != "" {
t.Errorf("Mismatch (-want +got):\n%s", diff)
}
})
}
}

View File

@@ -1,31 +1,26 @@
package ggml package ggml
import ( import "fmt"
"fmt"
"log/slog"
"strings"
)
// FileType is the Go equivalent to llama_ftype used for gguf file typing type fileType uint32
type FileType uint32
const ( const (
FileTypeF32 FileType = iota fileTypeF32 fileType = iota
FileTypeF16 fileTypeF16
fileTypeQ4_0 fileTypeQ4_0
fileTypeQ4_1 fileTypeQ4_1
fileTypeQ4_1_F16 // unused by GGML fileTypeQ4_1_F16
fileTypeQ4_2 // unused by GGML fileTypeQ4_2 // unused
fileTypeQ4_3 // unused by GGML fileTypeQ4_3 // unused
FileTypeQ8_0 fileTypeQ8_0
fileTypeQ5_0 fileTypeQ5_0
fileTypeQ5_1 fileTypeQ5_1
fileTypeQ2_K fileTypeQ2_K
fileTypeQ3_K_S fileTypeQ3_K_S
fileTypeQ3_K_M fileTypeQ3_K_M
fileTypeQ3_K_L fileTypeQ3_K_L
FileTypeQ4_K_S fileTypeQ4_K_S
FileTypeQ4_K_M fileTypeQ4_K_M
fileTypeQ5_K_S fileTypeQ5_K_S
fileTypeQ5_K_M fileTypeQ5_K_M
fileTypeQ6_K fileTypeQ6_K
@@ -42,62 +37,93 @@ const (
fileTypeIQ2_M fileTypeIQ2_M
fileTypeIQ4_XS fileTypeIQ4_XS
fileTypeIQ1_M fileTypeIQ1_M
FileTypeBF16 fileTypeBF16
fileTypeQ4_0_4_4 // unused by GGML
fileTypeQ4_0_4_8 // unused by GGML
fileTypeQ4_0_8_8 // unused by GGML
fileTypeTQ1_0
fileTypeTQ2_0
FileTypeUnknown = 1024 fileTypeUnknown
) )
// ParseFileType parses the provided GGUF file type func ParseFileType(s string) (fileType, error) {
// Only Ollama supported types are considered valid
func ParseFileType(s string) (FileType, error) {
switch s { switch s {
case "F32": case "F32":
return FileTypeF32, nil return fileTypeF32, nil
case "F16": case "F16":
return FileTypeF16, nil return fileTypeF16, nil
case "Q4_0":
return fileTypeQ4_0, nil
case "Q4_1":
return fileTypeQ4_1, nil
case "Q4_1_F16":
return fileTypeQ4_1_F16, nil
case "Q8_0": case "Q8_0":
return FileTypeQ8_0, nil return fileTypeQ8_0, nil
case "Q5_0":
return fileTypeQ5_0, nil
case "Q5_1":
return fileTypeQ5_1, nil
case "Q2_K":
return fileTypeQ2_K, nil
case "Q3_K_S":
return fileTypeQ3_K_S, nil
case "Q3_K_M":
return fileTypeQ3_K_M, nil
case "Q3_K_L":
return fileTypeQ3_K_L, nil
case "Q4_K_S": case "Q4_K_S":
return FileTypeQ4_K_S, nil return fileTypeQ4_K_S, nil
case "Q4_K_M", "Q4_K": case "Q4_K_M":
return FileTypeQ4_K_M, nil return fileTypeQ4_K_M, nil
case "Q5_K_S":
return fileTypeQ5_K_S, nil
case "Q5_K_M":
return fileTypeQ5_K_M, nil
case "Q6_K":
return fileTypeQ6_K, nil
case "IQ2_XXS":
return fileTypeIQ2_XXS, nil
case "IQ2_XS":
return fileTypeIQ2_XS, nil
case "Q2_K_S":
return fileTypeQ2_K_S, nil
case "IQ3_XS":
return fileTypeIQ3_XS, nil
case "IQ3_XXS":
return fileTypeIQ3_XXS, nil
case "IQ1_S":
return fileTypeIQ1_S, nil
case "IQ4_NL":
return fileTypeIQ4_NL, nil
case "IQ3_S":
return fileTypeIQ3_S, nil
case "IQ3_M":
return fileTypeIQ3_M, nil
case "IQ2_S":
return fileTypeIQ2_S, nil
case "IQ2_M":
return fileTypeIQ2_M, nil
case "IQ4_XS":
return fileTypeIQ4_XS, nil
case "IQ1_M":
return fileTypeIQ1_M, nil
case "BF16": case "BF16":
return FileTypeBF16, nil return fileTypeBF16, nil
default: default:
supportedFileTypes := []FileType{ return fileTypeUnknown, fmt.Errorf("unknown fileType: %s", s)
FileTypeF32,
FileTypeF16,
FileTypeQ4_K_S,
FileTypeQ4_K_M,
FileTypeQ8_0,
// fsggml.FileTypeBF16, // TODO
}
strs := make([]string, len(supportedFileTypes))
for i := range supportedFileTypes {
strs[i] = supportedFileTypes[i].String()
}
return FileTypeUnknown, fmt.Errorf("unsupported quantization type %s - supported types are %s", s, strings.Join(strs, ", "))
} }
} }
func (t FileType) String() string { func (t fileType) String() string {
// Note: this routine will return a broader set of file types for existing models
switch t { switch t {
case FileTypeF32: case fileTypeF32:
return "F32" return "F32"
case FileTypeF16: case fileTypeF16:
return "F16" return "F16"
case fileTypeQ4_0: case fileTypeQ4_0:
return "Q4_0" return "Q4_0"
case fileTypeQ4_1: case fileTypeQ4_1:
return "Q4_1" return "Q4_1"
case FileTypeQ8_0: case fileTypeQ4_1_F16:
return "Q4_1_F16"
case fileTypeQ8_0:
return "Q8_0" return "Q8_0"
case fileTypeQ5_0: case fileTypeQ5_0:
return "Q5_0" return "Q5_0"
@@ -111,9 +137,9 @@ func (t FileType) String() string {
return "Q3_K_M" return "Q3_K_M"
case fileTypeQ3_K_L: case fileTypeQ3_K_L:
return "Q3_K_L" return "Q3_K_L"
case FileTypeQ4_K_S: case fileTypeQ4_K_S:
return "Q4_K_S" return "Q4_K_S"
case FileTypeQ4_K_M: case fileTypeQ4_K_M:
return "Q4_K_M" return "Q4_K_M"
case fileTypeQ5_K_S: case fileTypeQ5_K_S:
return "Q5_K_S" return "Q5_K_S"
@@ -121,198 +147,39 @@ func (t FileType) String() string {
return "Q5_K_M" return "Q5_K_M"
case fileTypeQ6_K: case fileTypeQ6_K:
return "Q6_K" return "Q6_K"
case fileTypeIQ2_XXS:
return "IQ2_XXS"
case fileTypeIQ2_XS:
return "IQ2_XS"
case fileTypeQ2_K_S: case fileTypeQ2_K_S:
return "Q2_K_S" return "Q2_K_S"
case FileTypeBF16: case fileTypeIQ3_XS:
return "IQ3_XS"
case fileTypeIQ3_XXS:
return "IQ3_XXS"
case fileTypeIQ1_S:
return "IQ1_S"
case fileTypeIQ4_NL:
return "IQ4_NL"
case fileTypeIQ3_S:
return "IQ3_S"
case fileTypeIQ3_M:
return "IQ3_M"
case fileTypeIQ2_S:
return "IQ2_S"
case fileTypeIQ4_XS:
return "IQ4_XS"
case fileTypeIQ2_M:
return "IQ2_M"
case fileTypeIQ1_M:
return "IQ1_M"
case fileTypeBF16:
return "BF16" return "BF16"
default: default:
return "unknown" return "unknown"
} }
} }
func (t FileType) Value() uint32 { func (t fileType) Value() uint32 {
return uint32(t) return uint32(t)
} }
func (ftype FileType) ToTensorType() TensorType {
switch ftype {
case FileTypeF32:
return TensorTypeF32
case FileTypeF16:
return TensorTypeF16
case fileTypeQ4_0:
return TensorTypeQ4_0
case fileTypeQ4_1:
return TensorTypeQ4_1
case FileTypeQ8_0:
return TensorTypeQ8_0
case fileTypeQ5_0:
return TensorTypeQ5_0
case fileTypeQ5_1:
return TensorTypeQ5_1
case fileTypeQ2_K:
return TensorTypeQ2_K
case fileTypeQ3_K_S:
return TensorTypeQ3_K
case fileTypeQ3_K_M:
return TensorTypeQ3_K
case fileTypeQ3_K_L:
return TensorTypeQ3_K
case FileTypeQ4_K_S:
return TensorTypeQ4_K
case FileTypeQ4_K_M:
return TensorTypeQ4_K
case fileTypeQ5_K_S:
return TensorTypeQ5_K
case fileTypeQ5_K_M:
return TensorTypeQ5_K
case fileTypeQ6_K:
return TensorTypeQ6_K
case fileTypeQ2_K_S:
return TensorTypeQ2_K
case FileTypeBF16:
return TensorTypeBF16
default:
slog.Warn("unsupported file type", "type", ftype)
return 0 // F32
}
}
// TensorType is equivalent to ggml_type for individual tensor types
// Note: these are not the same as FileType
type TensorType uint32
const (
TensorTypeF32 TensorType = iota
TensorTypeF16
TensorTypeQ4_0
TensorTypeQ4_1
tensorTypeQ4_2 // unused by GGML
tensorTypeQ4_3 // unused by GGML
TensorTypeQ5_0
TensorTypeQ5_1
TensorTypeQ8_0
TensorTypeQ8_1
TensorTypeQ2_K
TensorTypeQ3_K
TensorTypeQ4_K
TensorTypeQ5_K
TensorTypeQ6_K
TensorTypeQ8_K
tensorTypeIQ2_XXS // not supported by ollama
tensorTypeIQ2_XS // not supported by ollama
tensorTypeIQ3_XXS // not supported by ollama
tensorTypeIQ1_S // not supported by ollama
tensorTypeIQ4_NL // not supported by ollama
tensorTypeIQ3_S // not supported by ollama
tensorTypeIQ2_S // not supported by ollama
tensorTypeIQ4_XS // not supported by ollama
TensorTypeI8
TensorTypeI16
TensorTypeI32
TensorTypeI64
TensorTypeF64
tensorTypeIQ1_M // not supported by ollama
TensorTypeBF16
tensorTypeQ4_0_4_4 // unused by GGML
tensorTypeQ4_0_4_8 // unused by GGML
tensorTypeQ4_0_8_8 // unused by GGML
tensorTypeTQ1_0 // not supported by ollama
tensorTypeTQ2_0 // not supported by ollama
tensorTypeIQ4_NL_4_4 // unused by GGML
tensorTypeIQ4_NL_4_8 // unused by GGML
tensorTypeIQ4_NL_8_8 // unused by GGML
)
// ParseFileType parses the provided GGUF file type
// Only Ollama supported types are considered valid
func ParseTensorType(s string) (TensorType, error) {
switch s {
case "F32":
return TensorTypeF32, nil
case "F16":
return TensorTypeF16, nil
case "Q4_0":
return TensorTypeQ4_0, nil
case "Q4_1":
return TensorTypeQ4_1, nil
case "Q5_0":
return TensorTypeQ5_0, nil
case "Q5_1":
return TensorTypeQ5_1, nil
case "Q8_0":
return TensorTypeQ8_0, nil
case "Q8_1":
return TensorTypeQ8_1, nil
case "Q2_K":
return TensorTypeQ2_K, nil
case "Q3_K":
return TensorTypeQ3_K, nil
case "Q4_K":
return TensorTypeQ4_K, nil
case "Q5_K":
return TensorTypeQ5_K, nil
case "Q6_K":
return TensorTypeQ6_K, nil
case "Q8_K":
return TensorTypeQ8_K, nil
case "F64":
return TensorTypeF64, nil
case "BF16":
return TensorTypeBF16, nil
default:
return 0, fmt.Errorf("unsupported quantization type %s", s)
}
}
func (t TensorType) IsQuantized() bool {
switch t {
case TensorTypeF32, TensorTypeF16, TensorTypeBF16:
return false
default:
return true
}
}
func (t TensorType) RowSize(ne uint64) uint64 {
return t.TypeSize() * ne / t.BlockSize()
}
func (t TensorType) String() string {
switch t {
case TensorTypeF32:
return "F32"
case TensorTypeF16:
return "F16"
case TensorTypeQ4_0:
return "Q4_0"
case TensorTypeQ4_1:
return "Q4_1"
case TensorTypeQ5_0:
return "Q5_0"
case TensorTypeQ5_1:
return "Q5_1"
case TensorTypeQ8_0:
return "Q8_0"
case TensorTypeQ8_1:
return "Q8_1"
case TensorTypeQ2_K:
return "Q2_K"
case TensorTypeQ3_K:
return "Q3_K"
case TensorTypeQ4_K:
return "Q4_K"
case TensorTypeQ5_K:
return "Q5_K"
case TensorTypeQ6_K:
return "Q6_K"
case TensorTypeQ8_K:
return "Q8_K"
case TensorTypeF64:
return "F64"
case TensorTypeBF16:
return "BF16"
default:
return "unknown"
}
}

View File

@@ -1,347 +0,0 @@
package gguf
import (
"bytes"
"cmp"
"encoding/binary"
"errors"
"fmt"
"io"
"iter"
"os"
"slices"
"strings"
)
const (
typeUint8 uint32 = iota
typeInt8
typeUint16
typeInt16
typeUint32
typeInt32
typeFloat32
typeBool
typeString
typeArray
typeUint64
typeInt64
typeFloat64
)
var ErrUnsupported = errors.New("unsupported")
type File struct {
Magic [4]byte
Version uint32
keyValues *lazy[KeyValue]
tensors *lazy[TensorInfo]
offset int64
file *os.File
reader *bufferedReader
bts []byte
}
func Open(path string) (f *File, err error) {
f = &File{bts: make([]byte, 4096)}
f.file, err = os.Open(path)
if err != nil {
return nil, err
}
f.reader = newBufferedReader(f.file, 32<<10)
if err := binary.Read(f.reader, binary.LittleEndian, &f.Magic); err != nil {
return nil, err
}
if bytes.Equal(f.Magic[:], []byte("gguf")) {
return nil, fmt.Errorf("%w file type %v", ErrUnsupported, f.Magic)
}
if err := binary.Read(f.reader, binary.LittleEndian, &f.Version); err != nil {
return nil, err
}
if f.Version < 2 {
return nil, fmt.Errorf("%w version %v", ErrUnsupported, f.Version)
}
f.tensors, err = newLazy(f, f.readTensor)
if err != nil {
return nil, err
}
f.tensors.successFunc = func() error {
offset := f.reader.offset
alignment := cmp.Or(f.KeyValue("general.alignment").Int(), 32)
f.offset = offset + (alignment-offset%alignment)%alignment
return nil
}
f.keyValues, err = newLazy(f, f.readKeyValue)
if err != nil {
return nil, err
}
return f, nil
}
func (f *File) readTensor() (TensorInfo, error) {
name, err := readString(f)
if err != nil {
return TensorInfo{}, err
}
dims, err := read[uint32](f)
if err != nil {
return TensorInfo{}, err
}
shape := make([]uint64, dims)
for i := range dims {
shape[i], err = read[uint64](f)
if err != nil {
return TensorInfo{}, err
}
}
type_, err := read[uint32](f)
if err != nil {
return TensorInfo{}, err
}
offset, err := read[uint64](f)
if err != nil {
return TensorInfo{}, err
}
return TensorInfo{
Name: name,
Offset: offset,
Shape: shape,
Type: TensorType(type_),
}, nil
}
func (f *File) readKeyValue() (KeyValue, error) {
key, err := readString(f)
if err != nil {
return KeyValue{}, err
}
t, err := read[uint32](f)
if err != nil {
return KeyValue{}, err
}
value, err := func() (any, error) {
switch t {
case typeUint8:
return read[uint8](f)
case typeInt8:
return read[int8](f)
case typeUint16:
return read[uint16](f)
case typeInt16:
return read[int16](f)
case typeUint32:
return read[uint32](f)
case typeInt32:
return read[int32](f)
case typeUint64:
return read[uint64](f)
case typeInt64:
return read[int64](f)
case typeFloat32:
return read[float32](f)
case typeFloat64:
return read[float64](f)
case typeBool:
return read[bool](f)
case typeString:
return readString(f)
case typeArray:
return readArray(f)
default:
return nil, fmt.Errorf("%w type %d", ErrUnsupported, t)
}
}()
if err != nil {
return KeyValue{}, err
}
return KeyValue{
Key: key,
Value: Value{value},
}, nil
}
func read[T any](f *File) (t T, err error) {
err = binary.Read(f.reader, binary.LittleEndian, &t)
return t, err
}
func readString(f *File) (string, error) {
n, err := read[uint64](f)
if err != nil {
return "", err
}
if int(n) > len(f.bts) {
f.bts = make([]byte, n)
}
bts := f.bts[:n]
if _, err := io.ReadFull(f.reader, bts); err != nil {
return "", err
}
defer clear(bts)
return string(bts), nil
}
func readArray(f *File) (any, error) {
t, err := read[uint32](f)
if err != nil {
return nil, err
}
n, err := read[uint64](f)
if err != nil {
return nil, err
}
switch t {
case typeUint8:
return readArrayData[uint8](f, n)
case typeInt8:
return readArrayData[int8](f, n)
case typeUint16:
return readArrayData[uint16](f, n)
case typeInt16:
return readArrayData[int16](f, n)
case typeUint32:
return readArrayData[uint32](f, n)
case typeInt32:
return readArrayData[int32](f, n)
case typeUint64:
return readArrayData[uint64](f, n)
case typeInt64:
return readArrayData[int64](f, n)
case typeFloat32:
return readArrayData[float32](f, n)
case typeFloat64:
return readArrayData[float64](f, n)
case typeBool:
return readArrayData[bool](f, n)
case typeString:
return readArrayString(f, n)
default:
return nil, fmt.Errorf("%w type %d", ErrUnsupported, t)
}
}
func readArrayData[T any](f *File, n uint64) (s []T, err error) {
s = make([]T, n)
for i := range n {
e, err := read[T](f)
if err != nil {
return nil, err
}
s[i] = e
}
return s, nil
}
func readArrayString(f *File, n uint64) (s []string, err error) {
s = make([]string, n)
for i := range n {
e, err := readString(f)
if err != nil {
return nil, err
}
s[i] = e
}
return s, nil
}
func (f *File) Close() error {
f.keyValues.stop()
f.tensors.stop()
return f.file.Close()
}
func (f *File) KeyValue(key string) KeyValue {
if !strings.HasPrefix(key, "general.") && !strings.HasPrefix(key, "tokenizer.") {
key = f.KeyValue("general.architecture").String() + "." + key
}
if index := slices.IndexFunc(f.keyValues.values, func(kv KeyValue) bool {
return kv.Key == key
}); index >= 0 {
return f.keyValues.values[index]
}
for keyValue, ok := f.keyValues.next(); ok; keyValue, ok = f.keyValues.next() {
if keyValue.Key == key {
return keyValue
}
}
return KeyValue{}
}
func (f *File) NumKeyValues() int {
return int(f.keyValues.count)
}
func (f *File) KeyValues() iter.Seq2[int, KeyValue] {
return f.keyValues.All()
}
func (f *File) TensorInfo(name string) TensorInfo {
if index := slices.IndexFunc(f.tensors.values, func(t TensorInfo) bool {
return t.Name == name
}); index >= 0 {
return f.tensors.values[index]
}
// fast-forward through key values if we haven't already
_ = f.keyValues.rest()
for tensor, ok := f.tensors.next(); ok; tensor, ok = f.tensors.next() {
if tensor.Name == name {
return tensor
}
}
return TensorInfo{}
}
func (f *File) NumTensors() int {
return int(f.tensors.count)
}
func (f *File) TensorInfos() iter.Seq2[int, TensorInfo] {
// fast forward through key values if we haven't already
f.keyValues.rest()
return f.tensors.All()
}
func (f *File) TensorReader(name string) (TensorInfo, io.Reader, error) {
t := f.TensorInfo(name)
if t.NumBytes() == 0 {
return TensorInfo{}, nil, fmt.Errorf("tensor %s not found", name)
}
// fast forward through tensor info if we haven't already
_ = f.tensors.rest()
return t, io.NewSectionReader(f.file, f.offset+int64(t.Offset), t.NumBytes()), nil
}

View File

@@ -1,249 +0,0 @@
package gguf_test
import (
"bytes"
"os"
"strconv"
"strings"
"testing"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/fs/gguf"
)
func createBinFile(tb testing.TB) string {
tb.Helper()
f, err := os.CreateTemp(tb.TempDir(), "")
if err != nil {
tb.Fatal(err)
}
defer f.Close()
kv := ggml.KV{
"general.architecture": "llama",
"llama.block_count": uint32(8),
"llama.embedding_length": uint32(3),
"llama.attention.head_count": uint32(2),
"llama.attention.head_count_kv": uint32(2),
"llama.attention.key_length": uint32(3),
"llama.rope.dimension_count": uint32(4),
"llama.rope.freq_base": float32(10000.0),
"llama.rope.freq_scale": float32(1.0),
"llama.attention.layer_norm_rms_epsilon": float32(1e-6),
"tokenizer.ggml.eos_token_id": uint32(0),
"tokenizer.ggml.eos_token_ids": []int32{1, 2, 3},
"tokenizer.ggml.tokens": []string{"hello", "world"},
"tokenizer.ggml.scores": []float32{0, 1},
}
tensors := []*ggml.Tensor{
{
Name: "token_embd.weight",
Kind: 0,
Shape: []uint64{2, 3},
WriterTo: bytes.NewBuffer(make([]byte, 4*2*3)),
},
{
Name: "output.weight",
Kind: 0,
Shape: []uint64{3, 2},
WriterTo: bytes.NewBuffer(make([]byte, 4*3*2)),
},
}
for i := range 8 {
tensors = append(tensors, &ggml.Tensor{
Name: "blk." + strconv.Itoa(i) + ".attn_q.weight",
Kind: 0,
Shape: []uint64{3, 3},
WriterTo: bytes.NewBuffer(make([]byte, 4*3*3)),
}, &ggml.Tensor{
Name: "blk." + strconv.Itoa(i) + ".attn_k.weight",
Kind: 0,
Shape: []uint64{3, 3},
WriterTo: bytes.NewBuffer(make([]byte, 4*3*3)),
}, &ggml.Tensor{
Name: "blk." + strconv.Itoa(i) + ".attn_v.weight",
Kind: 0,
Shape: []uint64{3, 3},
WriterTo: bytes.NewBuffer(make([]byte, 4*3*3)),
}, &ggml.Tensor{
Name: "blk." + strconv.Itoa(i) + ".attn_output.weight",
Kind: 0,
Shape: []uint64{3, 3},
WriterTo: bytes.NewBuffer(make([]byte, 4*3*3)),
})
}
if err := ggml.WriteGGUF(f, kv, tensors); err != nil {
tb.Fatal(err)
}
return f.Name()
}
func TestRead(t *testing.T) {
f, err := gguf.Open(createBinFile(t))
if err != nil {
t.Fatal(err)
}
defer f.Close()
if got := f.KeyValue("does.not.exist").Valid(); got {
t.Errorf(`KeyValue("does.not.exist").Exists() = %v, want false`, got)
}
if got := f.KeyValue("general.architecture").String(); got != "llama" {
t.Errorf(`KeyValue("general.architecture").String() = %q, want %q`, got, "llama")
}
if got := f.TensorInfo("token_embd.weight"); got.Name != "token_embd.weight" {
t.Errorf(`TensorInfo("token_embd.weight").Name = %q, want %q`, got.Name, "token_embd.weight")
} else if diff := cmp.Diff(got.Shape, []uint64{2, 3}); diff != "" {
t.Errorf(`TensorInfo("token_embd.weight").Shape mismatch (-got +want):\n%s`, diff)
} else if got.Type != gguf.TensorTypeF32 {
t.Errorf(`TensorInfo("token_embd.weight").Type = %d, want %d`, got.Type, gguf.TensorTypeF32)
}
if got := f.KeyValue("block_count").Uint(); got != 8 {
t.Errorf(`KeyValue("block_count").Uint() = %d, want %d`, got, 8)
}
if diff := cmp.Diff(f.KeyValue("tokenizer.ggml.tokens").Strings(), []string{"hello", "world"}); diff != "" {
t.Errorf("KeyValue(\"tokenizer.ggml.tokens\").Strings() mismatch (-got +want):\n%s", diff)
}
if diff := cmp.Diff(f.KeyValue("tokenizer.ggml.scores").Floats(), []float64{0, 1}); diff != "" {
t.Errorf("KeyValue(\"tokenizer.ggml.scores\").Ints() mismatch (-got +want):\n%s", diff)
}
var kvs []string
for _, kv := range f.KeyValues() {
if !kv.Valid() {
t.Error("found invalid key-value pair:", kv)
}
kvs = append(kvs, kv.Key)
}
if len(kvs) != f.NumKeyValues() {
t.Errorf("iterated key count = %d, want %d", len(kvs), f.NumKeyValues())
}
if diff := cmp.Diff(kvs, []string{
"general.architecture",
"llama.block_count",
"llama.embedding_length",
"llama.attention.head_count",
"llama.attention.head_count_kv",
"llama.attention.key_length",
"llama.rope.dimension_count",
"llama.rope.freq_base",
"llama.rope.freq_scale",
"llama.attention.layer_norm_rms_epsilon",
"tokenizer.ggml.eos_token_id",
"tokenizer.ggml.eos_token_ids",
"tokenizer.ggml.tokens",
"tokenizer.ggml.scores",
}, cmpopts.SortSlices(strings.Compare)); diff != "" {
t.Errorf("KeyValues() mismatch (-got +want):\n%s", diff)
}
var tis []string
for _, ti := range f.TensorInfos() {
if !ti.Valid() {
t.Error("found invalid tensor info:", ti)
}
tis = append(tis, ti.Name)
}
if len(tis) != f.NumTensors() {
t.Errorf("iterated tensor count = %d, want %d", len(tis), f.NumTensors())
}
if diff := cmp.Diff(tis, []string{
"token_embd.weight",
"output.weight",
"blk.0.attn_q.weight",
"blk.0.attn_k.weight",
"blk.0.attn_v.weight",
"blk.0.attn_output.weight",
"blk.1.attn_q.weight",
"blk.1.attn_k.weight",
"blk.1.attn_v.weight",
"blk.1.attn_output.weight",
"blk.2.attn_q.weight",
"blk.2.attn_k.weight",
"blk.2.attn_v.weight",
"blk.2.attn_output.weight",
"blk.3.attn_q.weight",
"blk.3.attn_k.weight",
"blk.3.attn_v.weight",
"blk.3.attn_output.weight",
"blk.4.attn_q.weight",
"blk.4.attn_k.weight",
"blk.4.attn_v.weight",
"blk.4.attn_output.weight",
"blk.5.attn_q.weight",
"blk.5.attn_k.weight",
"blk.5.attn_v.weight",
"blk.5.attn_output.weight",
"blk.6.attn_q.weight",
"blk.6.attn_k.weight",
"blk.6.attn_v.weight",
"blk.6.attn_output.weight",
"blk.7.attn_q.weight",
"blk.7.attn_k.weight",
"blk.7.attn_v.weight",
"blk.7.attn_output.weight",
}, cmpopts.SortSlices(strings.Compare)); diff != "" {
t.Errorf("TensorInfos() mismatch (-got +want):\n%s", diff)
}
ti, r, err := f.TensorReader("output.weight")
if err != nil {
t.Fatalf(`TensorReader("output.weight") error: %v`, err)
}
if ti.Name != "output.weight" {
t.Errorf(`TensorReader("output.weight").Name = %q, want %q`, ti.Name, "output.weight")
} else if diff := cmp.Diff(ti.Shape, []uint64{3, 2}); diff != "" {
t.Errorf(`TensorReader("output.weight").Shape mismatch (-got +want):\n%s`, diff)
} else if ti.Type != gguf.TensorTypeF32 {
t.Errorf(`TensorReader("output.weight").Type = %d, want %d`, ti.Type, gguf.TensorTypeF32)
}
var b bytes.Buffer
if _, err := b.ReadFrom(r); err != nil {
t.Fatalf(`ReadFrom TensorReader("output.weight") error: %v`, err)
}
if b.Len() != int(ti.NumBytes()) {
t.Errorf(`ReadFrom TensorReader("output.weight") length = %d, want %d`, b.Len(), ti.NumBytes())
}
}
func BenchmarkRead(b *testing.B) {
b.ReportAllocs()
p := createBinFile(b)
for b.Loop() {
f, err := gguf.Open(p)
if err != nil {
b.Fatal(err)
}
if got := f.KeyValue("general.architecture").String(); got != "llama" {
b.Errorf("got = %q, want %q", got, "llama")
}
// Iterate through some tensors
for range f.TensorInfos() {
}
f.Close()
}
}

View File

@@ -1,90 +0,0 @@
package gguf
import (
"reflect"
"slices"
)
type KeyValue struct {
Key string
Value
}
func (kv KeyValue) Valid() bool {
return kv.Key != "" && kv.Value.value != nil
}
type Value struct {
value any
}
func value[T any](v Value, kinds ...reflect.Kind) (t T) {
vv := reflect.ValueOf(v.value)
if slices.Contains(kinds, vv.Kind()) {
t = vv.Convert(reflect.TypeOf(t)).Interface().(T)
}
return
}
func values[T any](v Value, kinds ...reflect.Kind) (ts []T) {
switch vv := reflect.ValueOf(v.value); vv.Kind() {
case reflect.Slice:
if slices.Contains(kinds, vv.Type().Elem().Kind()) {
ts = make([]T, vv.Len())
for i := range vv.Len() {
ts[i] = vv.Index(i).Convert(reflect.TypeOf(ts[i])).Interface().(T)
}
}
}
return
}
// Int returns Value as a signed integer. If it is not a signed integer, it returns 0.
func (v Value) Int() int64 {
return value[int64](v, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64)
}
// Ints returns Value as a signed integer slice. If it is not a signed integer slice, it returns nil.
func (v Value) Ints() (i64s []int64) {
return values[int64](v, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64)
}
// Uint converts an unsigned integer value to uint64. If the value is not a unsigned integer, it returns 0.
func (v Value) Uint() uint64 {
return value[uint64](v, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64)
}
// Uints returns Value as a unsigned integer slice. If it is not a unsigned integer slice, it returns nil.
func (v Value) Uints() (u64s []uint64) {
return values[uint64](v, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64)
}
// Float returns Value as a float. If it is not a float, it returns 0.
func (v Value) Float() float64 {
return value[float64](v, reflect.Float32, reflect.Float64)
}
// Floats returns Value as a float slice. If it is not a float slice, it returns nil.
func (v Value) Floats() (f64s []float64) {
return values[float64](v, reflect.Float32, reflect.Float64)
}
// Bool returns Value as a boolean. If it is not a boolean, it returns false.
func (v Value) Bool() bool {
return value[bool](v, reflect.Bool)
}
// Bools returns Value as a boolean slice. If it is not a boolean slice, it returns nil.
func (v Value) Bools() (bools []bool) {
return values[bool](v, reflect.Bool)
}
// String returns Value as a string. If it is not a string, it returns an empty string.
func (v Value) String() string {
return value[string](v, reflect.String)
}
// Strings returns Value as a string slice. If it is not a string slice, it returns nil.
func (v Value) Strings() (strings []string) {
return values[string](v, reflect.String)
}

View File

@@ -1,208 +0,0 @@
package gguf
import (
"testing"
"github.com/google/go-cmp/cmp"
)
func split(name string, values map[string][]any) (matched []any, unmatched []any) {
for key, value := range values {
if key == name {
matched = value
} else {
unmatched = append(unmatched, value...)
}
}
return
}
func TestValue(t *testing.T) {
values := map[string][]any{
"int64": {int(42), int8(42), int16(42), int32(42), int64(42)},
"uint64": {uint(42), uint8(42), uint16(42), uint32(42), uint64(42)},
"float64": {float32(42), float64(42)},
"string": {"42", "hello"},
"bool": {true, false},
}
t.Run("int64", func(t *testing.T) {
matched, unmatched := split("int64", values)
for _, v := range matched {
kv := KeyValue{"key", Value{v}}
if i64 := kv.Int(); i64 != 42 {
t.Errorf("expected 42, got %d", i64)
}
}
for _, v := range unmatched {
kv := KeyValue{"key", Value{v}}
if i64 := kv.Int(); i64 != 0 {
t.Errorf("expected 42, got %d", i64)
}
}
})
t.Run("uint64", func(t *testing.T) {
matched, unmatched := split("uint64", values)
for _, v := range matched {
kv := KeyValue{"key", Value{v}}
if u64 := kv.Uint(); u64 != 42 {
t.Errorf("expected 42, got %d", u64)
}
}
for _, v := range unmatched {
kv := KeyValue{"key", Value{v}}
if u64 := kv.Uint(); u64 != 0 {
t.Errorf("expected 42, got %d", u64)
}
}
})
t.Run("float64", func(t *testing.T) {
matched, unmatched := split("float64", values)
for _, v := range matched {
kv := KeyValue{"key", Value{v}}
if f64 := kv.Float(); f64 != 42 {
t.Errorf("expected 42, got %f", f64)
}
}
for _, v := range unmatched {
kv := KeyValue{"key", Value{v}}
if f64 := kv.Float(); f64 != 0 {
t.Errorf("expected 42, got %f", f64)
}
}
})
t.Run("string", func(t *testing.T) {
matched, unmatched := split("string", values)
for _, v := range matched {
kv := KeyValue{"key", Value{v}}
if s := kv.String(); s != v {
t.Errorf("expected 42, got %s", s)
}
}
for _, v := range unmatched {
kv := KeyValue{"key", Value{v}}
if s := kv.String(); s != "" {
t.Errorf("expected 42, got %s", s)
}
}
})
t.Run("bool", func(t *testing.T) {
matched, unmatched := split("bool", values)
for _, v := range matched {
kv := KeyValue{"key", Value{v}}
if b := kv.Bool(); b != v {
t.Errorf("expected true, got %v", b)
}
}
for _, v := range unmatched {
kv := KeyValue{"key", Value{v}}
if b := kv.Bool(); b != false {
t.Errorf("expected false, got %v", b)
}
}
})
}
func TestValues(t *testing.T) {
values := map[string][]any{
"int64s": {[]int{42}, []int8{42}, []int16{42}, []int32{42}, []int64{42}},
"uint64s": {[]uint{42}, []uint8{42}, []uint16{42}, []uint32{42}, []uint64{42}},
"float64s": {[]float32{42}, []float64{42}},
"strings": {[]string{"42"}, []string{"hello"}},
"bools": {[]bool{true}, []bool{false}},
}
t.Run("int64s", func(t *testing.T) {
matched, unmatched := split("int64s", values)
for _, v := range matched {
kv := KeyValue{"key", Value{v}}
if diff := cmp.Diff(kv.Ints(), []int64{42}); diff != "" {
t.Errorf("diff: %s", diff)
}
}
for _, v := range unmatched {
kv := KeyValue{"key", Value{v}}
if i64s := kv.Ints(); i64s != nil {
t.Errorf("expected nil, got %v", i64s)
}
}
})
t.Run("uint64s", func(t *testing.T) {
matched, unmatched := split("uint64s", values)
for _, v := range matched {
kv := KeyValue{"key", Value{v}}
if diff := cmp.Diff(kv.Uints(), []uint64{42}); diff != "" {
t.Errorf("diff: %s", diff)
}
}
for _, v := range unmatched {
kv := KeyValue{"key", Value{v}}
if u64s := kv.Uints(); u64s != nil {
t.Errorf("expected nil, got %v", u64s)
}
}
})
t.Run("float64s", func(t *testing.T) {
matched, unmatched := split("float64s", values)
for _, v := range matched {
kv := KeyValue{"key", Value{v}}
if diff := cmp.Diff(kv.Floats(), []float64{42}); diff != "" {
t.Errorf("diff: %s", diff)
}
}
for _, v := range unmatched {
kv := KeyValue{"key", Value{v}}
if f64s := kv.Floats(); f64s != nil {
t.Errorf("expected nil, got %v", f64s)
}
}
})
t.Run("strings", func(t *testing.T) {
matched, unmatched := split("strings", values)
for _, v := range matched {
kv := KeyValue{"key", Value{v}}
if diff := cmp.Diff(kv.Strings(), v); diff != "" {
t.Errorf("diff: %s", diff)
}
}
for _, v := range unmatched {
kv := KeyValue{"key", Value{v}}
if s := kv.Strings(); s != nil {
t.Errorf("expected nil, got %v", s)
}
}
})
t.Run("bools", func(t *testing.T) {
matched, unmatched := split("bools", values)
for _, v := range matched {
kv := KeyValue{"key", Value{v}}
if diff := cmp.Diff(kv.Bools(), v); diff != "" {
t.Errorf("diff: %s", diff)
}
}
for _, v := range unmatched {
kv := KeyValue{"key", Value{v}}
if b := kv.Bools(); b != nil {
t.Errorf("expected nil, got %v", b)
}
}
})
}

View File

@@ -1,89 +0,0 @@
package gguf
import (
"encoding/binary"
"iter"
"log/slog"
)
type lazy[T any] struct {
count uint64
next func() (T, bool)
stop func()
values []T
// successFunc is called when all values have been successfully read.
successFunc func() error
}
func newLazy[T any](f *File, fn func() (T, error)) (*lazy[T], error) {
it := lazy[T]{}
if err := binary.Read(f.reader, binary.LittleEndian, &it.count); err != nil {
return nil, err
}
it.values = make([]T, 0)
it.next, it.stop = iter.Pull(func(yield func(T) bool) {
for i := range it.count {
t, err := fn()
if err != nil {
slog.Error("error reading tensor", "index", i, "error", err)
return
}
it.values = append(it.values, t)
if !yield(t) {
break
}
}
if it.successFunc != nil {
it.successFunc()
}
})
return &it, nil
}
func (g *lazy[T]) Values() iter.Seq[T] {
return func(yield func(T) bool) {
for _, v := range g.All() {
if !yield(v) {
break
}
}
}
}
func (g *lazy[T]) All() iter.Seq2[int, T] {
return func(yield func(int, T) bool) {
for i := range int(g.count) {
if i < len(g.values) {
if !yield(i, g.values[i]) {
break
}
} else {
t, ok := g.next()
if !ok {
break
}
if !yield(i, t) {
break
}
}
}
}
}
func (g *lazy[T]) rest() (collected bool) {
for {
_, ok := g.next()
collected = collected || ok
if !ok {
break
}
}
return collected
}

View File

@@ -1,23 +0,0 @@
package gguf
import (
"bufio"
"io"
)
type bufferedReader struct {
offset int64
*bufio.Reader
}
func newBufferedReader(rs io.ReadSeeker, size int) *bufferedReader {
return &bufferedReader{
Reader: bufio.NewReaderSize(rs, size),
}
}
func (rs *bufferedReader) Read(p []byte) (n int, err error) {
n, err = rs.Reader.Read(p)
rs.offset += int64(n)
return n, err
}

View File

@@ -1,288 +0,0 @@
package gguf
import (
"log/slog"
"strings"
)
type TensorInfo struct {
Name string
Offset uint64
Shape []uint64
Type TensorType
}
func (ti TensorInfo) Valid() bool {
return ti.Name != "" && ti.NumBytes() > 0
}
func (ti TensorInfo) NumValues() int64 {
var numItems int64 = 1
for _, dim := range ti.Shape {
numItems *= int64(dim)
}
return numItems
}
// NumBytes returns the number of bytes in the tensor.
func (ti TensorInfo) NumBytes() int64 {
return int64(float64(ti.NumValues()) * ti.Type.NumBytes())
}
func (ti TensorInfo) LogValue() slog.Value {
return slog.GroupValue(
slog.String("name", ti.Name),
slog.Int64("offset", int64(ti.Offset)),
slog.Any("shape", ti.Shape),
slog.Int64("num_values", ti.NumValues()),
slog.Int64("num_bytes", ti.NumBytes()),
slog.Any("type", ti.Type),
)
}
type TensorType uint32
const (
TensorTypeF32 TensorType = iota
TensorTypeF16
TensorTypeQ4_0
TensorTypeQ4_1
// unexported // unused in gguf
tensorTypeQ4_2
tensorTypeQ4_3
TensorTypeQ5_0
TensorTypeQ5_1
TensorTypeQ8_0
TensorTypeQ8_1
TensorTypeQ2_K
TensorTypeQ3_K
TensorTypeQ4_K
TensorTypeQ5_K
TensorTypeQ6_K
TensorTypeQ8_K
// unexported // unquantizable by ollama
tensorTypeIQ2_XXS
tensorTypeIQ2_XS
tensorTypeIQ3_XXS
tensorTypeIQ1_S
tensorTypeIQ4_NL
tensorTypeIQ3_S
tensorTypeIQ2_S
tensorTypeIQ4_XS
TensorTypeI8
TensorTypeI16
TensorTypeI32
TensorTypeI64
TensorTypeF64
// unexported // unquantizable by ollama
tensorTypeIQ1_M
TensorTypeBF16
// unexported // unused in gguf
tensorTypeQ4_0_4_4
tensorTypeQ4_0_4_8
tensorTypeQ4_0_8_8
// unexported // unquantizable by ollama
tensorTypeTQ1_0
tensorTypeTQ2_0
// unexported // unused in gguf
tensorTypeIQ4_NL_4_4
tensorTypeIQ4_NL_4_8
tensorTypeIQ4_NL_8_8
)
func (tt TensorType) NumBytes() float64 {
return float64(tt.typeSize()) / float64(tt.blockSize())
}
func (tt TensorType) typeSize() int64 {
switch tt {
case TensorTypeF32:
return 4
case TensorTypeF16:
return 2
case TensorTypeQ4_0:
return 2 + tt.blockSize()/2
case TensorTypeQ4_1:
return 2 + 2 + tt.blockSize()/2
case TensorTypeQ5_0:
return 2 + 4 + tt.blockSize()/2
case TensorTypeQ5_1:
return 2 + 2 + 4 + tt.blockSize()/2
case TensorTypeQ8_0:
return 2 + tt.blockSize()
case TensorTypeQ8_1:
return 2 + 2 + tt.blockSize()
case TensorTypeQ2_K:
return tt.blockSize()/16 + tt.blockSize()/4 + 2 + 2
case TensorTypeQ3_K:
return tt.blockSize()/8 + tt.blockSize()/4 + 12 + 2
case TensorTypeQ4_K:
return 2 + 2 + 12 + tt.blockSize()/2
case TensorTypeQ5_K:
return 2 + 2 + 12 + tt.blockSize()/8 + tt.blockSize()/2
case TensorTypeQ6_K:
return tt.blockSize()/2 + tt.blockSize()/4 + tt.blockSize()/16 + 2
case TensorTypeQ8_K:
return 4 + tt.blockSize() + 2*tt.blockSize()/16
case tensorTypeIQ2_XXS:
return 2 + 2*tt.blockSize()/8
case tensorTypeIQ2_XS:
return 2 + 2*tt.blockSize()/8 + tt.blockSize()/32
case tensorTypeIQ3_XXS:
return 2 + tt.blockSize()/4 + tt.blockSize()/8
case tensorTypeIQ1_S:
return 2 + tt.blockSize()/8 + tt.blockSize()/16
case tensorTypeIQ4_NL:
return 2 + tt.blockSize()/2
case tensorTypeIQ3_S:
return 2 + tt.blockSize()/4 + tt.blockSize()/8 + tt.blockSize()/32 + 4
case tensorTypeIQ2_S:
return 2 + tt.blockSize()/4 + tt.blockSize()/16
case tensorTypeIQ4_XS:
return 2 + 2 + tt.blockSize()/2 + tt.blockSize()/64
case TensorTypeI8:
return 1
case TensorTypeI16:
return 2
case TensorTypeI32:
return 4
case TensorTypeI64:
return 8
case TensorTypeF64:
return 8
case tensorTypeIQ1_M:
return tt.blockSize()/8 + tt.blockSize()/16 + tt.blockSize()/32
case TensorTypeBF16:
return 2
default:
return 0
}
}
func (tt TensorType) blockSize() int64 {
switch tt {
case TensorTypeF32,
TensorTypeF16,
TensorTypeI8,
TensorTypeI16,
TensorTypeI32,
TensorTypeI64,
TensorTypeF64,
TensorTypeBF16:
return 1
case TensorTypeQ4_0,
TensorTypeQ4_1,
TensorTypeQ5_0,
TensorTypeQ5_1,
TensorTypeQ8_0,
TensorTypeQ8_1,
tensorTypeIQ4_NL:
return 32
default:
return 256
}
}
func (tt TensorType) String() string {
switch tt {
case TensorTypeF32:
return "f32"
case TensorTypeF16:
return "f16"
case TensorTypeQ4_0:
return "q4_0"
case TensorTypeQ4_1:
return "q4_1"
case tensorTypeQ4_2:
return "q4_2"
case tensorTypeQ4_3:
return "q4_3"
case TensorTypeQ5_0:
return "q5_0"
case TensorTypeQ5_1:
return "q5_1"
case TensorTypeQ8_0:
return "q8_0"
case TensorTypeQ8_1:
return "q8_1"
case TensorTypeQ2_K:
return "q2_k"
case TensorTypeQ3_K:
return "q3_k"
case TensorTypeQ4_K:
return "q4_k"
case TensorTypeQ5_K:
return "q5_k"
case TensorTypeQ6_K:
return "q6_k"
case TensorTypeQ8_K:
return "q8_k"
case tensorTypeIQ2_XXS:
return "iq2_xxs"
case tensorTypeIQ2_XS:
return "iq2_xs"
case tensorTypeIQ3_XXS:
return "iq3_xxs"
case tensorTypeIQ1_S:
return "iq1_s"
case tensorTypeIQ4_NL:
return "iq4_nl"
case tensorTypeIQ3_S:
return "iq3_s"
case tensorTypeIQ2_S:
return "iq2_s"
case tensorTypeIQ4_XS:
return "iq4_xs"
case TensorTypeI8:
return "i8"
case TensorTypeI16:
return "i16"
case TensorTypeI32:
return "i32"
case TensorTypeI64:
return "i64"
case TensorTypeF64:
return "f64"
case tensorTypeIQ1_M:
return "iq1_m"
case TensorTypeBF16:
return "bf16"
case tensorTypeQ4_0_4_4:
return "q4_0_4_4"
case tensorTypeQ4_0_4_8:
return "q4_0_4_8"
case tensorTypeQ4_0_8_8:
return "q4_0_8_8"
case tensorTypeTQ1_0:
return "tq1_0"
case tensorTypeTQ2_0:
return "tq2_0"
case tensorTypeIQ4_NL_4_4:
return "iq4_nl_4_4"
case tensorTypeIQ4_NL_4_8:
return "iq4_nl_4_8"
case tensorTypeIQ4_NL_8_8:
return "iq4_nl_8_8"
default:
return "unknown"
}
}
func (tt TensorType) LogValue() slog.Value {
return slog.GroupValue(
slog.Uint64("value", uint64(tt)),
slog.String("name", strings.ToUpper(tt.String())),
slog.Int64("size", tt.typeSize()),
slog.Int64("block_size", tt.blockSize()),
slog.Float64("num_bytes", tt.NumBytes()),
)
}

16
go.mod
View File

@@ -11,7 +11,7 @@ require (
github.com/spf13/cobra v1.7.0 github.com/spf13/cobra v1.7.0
github.com/stretchr/testify v1.9.0 github.com/stretchr/testify v1.9.0
github.com/x448/float16 v0.8.4 github.com/x448/float16 v0.8.4
golang.org/x/sync v0.12.0 golang.org/x/sync v0.11.0
) )
require ( require (
@@ -19,13 +19,12 @@ require (
github.com/d4l3k/go-bfloat16 v0.0.0-20211005043715-690c3bdd05f1 github.com/d4l3k/go-bfloat16 v0.0.0-20211005043715-690c3bdd05f1
github.com/dlclark/regexp2 v1.11.4 github.com/dlclark/regexp2 v1.11.4
github.com/emirpasic/gods/v2 v2.0.0-alpha github.com/emirpasic/gods/v2 v2.0.0-alpha
github.com/google/go-cmp v0.7.0 github.com/google/go-cmp v0.6.0
github.com/mattn/go-runewidth v0.0.14 github.com/mattn/go-runewidth v0.0.14
github.com/nlpodyssey/gopickle v0.3.0 github.com/nlpodyssey/gopickle v0.3.0
github.com/pdevine/tensor v0.0.0-20240510204454-f88f4562727c github.com/pdevine/tensor v0.0.0-20240510204454-f88f4562727c
golang.org/x/image v0.22.0 golang.org/x/image v0.22.0
golang.org/x/tools v0.30.0 golang.org/x/tools v0.30.0
gonum.org/v1/gonum v0.15.0
) )
require ( require (
@@ -45,6 +44,7 @@ require (
github.com/xtgo/set v1.0.0 // indirect github.com/xtgo/set v1.0.0 // indirect
go4.org/unsafe/assume-no-moving-gc v0.0.0-20231121144256-b99613f794b6 // indirect go4.org/unsafe/assume-no-moving-gc v0.0.0-20231121144256-b99613f794b6 // indirect
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect
gonum.org/v1/gonum v0.15.0 // indirect
gorgonia.org/vecf32 v0.9.0 // indirect gorgonia.org/vecf32 v0.9.0 // indirect
gorgonia.org/vecf64 v0.9.0 // indirect gorgonia.org/vecf64 v0.9.0 // indirect
) )
@@ -70,12 +70,12 @@ require (
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/ugorji/go/codec v1.2.12 // indirect github.com/ugorji/go/codec v1.2.12 // indirect
golang.org/x/arch v0.8.0 // indirect golang.org/x/arch v0.8.0 // indirect
golang.org/x/crypto v0.36.0 golang.org/x/crypto v0.33.0
golang.org/x/exp v0.0.0-20250218142911-aa4b98e5adaa golang.org/x/exp v0.0.0-20250218142911-aa4b98e5adaa
golang.org/x/net v0.38.0 // indirect golang.org/x/net v0.35.0 // indirect
golang.org/x/sys v0.31.0 golang.org/x/sys v0.30.0
golang.org/x/term v0.30.0 golang.org/x/term v0.29.0
golang.org/x/text v0.23.0 golang.org/x/text v0.22.0
google.golang.org/protobuf v1.34.1 google.golang.org/protobuf v1.34.1
gopkg.in/yaml.v3 v3.0.1 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect
) )

28
go.sum
View File

@@ -112,8 +112,8 @@ github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/
github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
@@ -214,8 +214,8 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk
golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34= golang.org/x/crypto v0.33.0 h1:IOBPskki6Lysi0lo9qQvbxiQ+FvsCC/YWOecCHAixus=
golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc= golang.org/x/crypto v0.33.0/go.mod h1:bVdXmD7IV/4GdElGPozy6U7lWdRXA4qyRVGJV57uQ5M=
golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20180807140117-3d87b88a115f/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20180807140117-3d87b88a115f/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
@@ -257,8 +257,8 @@ golang.org/x/net v0.0.0-20200822124328-c89045814202/go.mod h1:/O7V0waA8r7cgGh81R
golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM=
golang.org/x/net v0.0.0-20210614182718-04defd469f4e/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20210614182718-04defd469f4e/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/net v0.38.0 h1:vRMAPTMaeGqVhG5QyLJHqNDwecKTomGeqbnfZyKlBI8= golang.org/x/net v0.35.0 h1:T5GQRQb2y08kTAByq9L4/bz8cipCdA8FbRTXewonqY8=
golang.org/x/net v0.38.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8= golang.org/x/net v0.35.0/go.mod h1:EglIi67kWsHKlRzzVMUD93VMSWGFOMSZgxFjparz1Qk=
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
@@ -268,8 +268,8 @@ golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJ
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.12.0 h1:MHc5BpPuC30uJk597Ri8TV3CNZcTLu6B6z4lJy+g6Jw= golang.org/x/sync v0.11.0 h1:GGz8+XQP4FvTTrjZPzNKTMFtSXH80RAzG+5ghFPgK9w=
golang.org/x/sync v0.12.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= golang.org/x/sync v0.11.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
@@ -285,17 +285,17 @@ golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik= golang.org/x/sys v0.30.0 h1:QjkSwP/36a20jFYWkSue1YwXzLmsV5Gfq7Eiy72C1uc=
golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= golang.org/x/sys v0.30.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.30.0 h1:PQ39fJZ+mfadBm0y5WlL4vlM7Sx1Hgf13sMIY2+QS9Y= golang.org/x/term v0.29.0 h1:L6pJp37ocefwRRtYPKSWOWzOtWSxVajvz2ldH/xi3iU=
golang.org/x/term v0.30.0/go.mod h1:NYYFdzHoI5wRh/h5tDMdMqCqPJZEuNqVR5xJLd/n67g= golang.org/x/term v0.29.0/go.mod h1:6bl4lRlvVuDgSf3179VpIxBF0o10JUpXWOnI7nErv7s=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.5/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.5/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.23.0 h1:D71I7dUrlY+VX0gQShAThNGHFxZ13dGLBHQLVl1mJlY= golang.org/x/text v0.22.0 h1:bofq7m3/HAFvbF51jz3Q9wLg3jkvSPuiZu/pD1XwgtM=
golang.org/x/text v0.23.0/go.mod h1:/BLNzu4aZCJ1+kcD0DNRotWKage4q2rGVAg4o22unh4= golang.org/x/text v0.22.0/go.mod h1:YRoo4H8PVmsu+E3Ou7cqLVH8oXWIHVoX0jqUWALQhfY=
golang.org/x/tools v0.0.0-20180525024113-a5b4c53f6e8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20180525024113-a5b4c53f6e8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=

View File

@@ -1,412 +0,0 @@
//go:build integration
package integration
import (
"bytes"
"context"
"fmt"
"math/rand"
"strings"
"testing"
"time"
"github.com/ollama/ollama/api"
)
func TestAPIGenerate(t *testing.T) {
initialTimeout := 60 * time.Second
streamTimeout := 30 * time.Second
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Minute)
defer cancel()
// Set up the test data
req := api.GenerateRequest{
Model: smol,
Prompt: "why is the sky blue? be brief",
Options: map[string]interface{}{
"temperature": 0,
"seed": 123,
},
}
anyResp := []string{"rayleigh", "scattering"}
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
if err := PullIfMissing(ctx, client, req.Model); err != nil {
t.Fatalf("pull failed %s", err)
}
tests := []struct {
name string
stream bool
}{
{
name: "stream",
stream: true,
},
{
name: "no_stream",
stream: false,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
stallTimer := time.NewTimer(initialTimeout)
var buf bytes.Buffer
fn := func(response api.GenerateResponse) error {
// Fields that must always be present
if response.Model == "" {
t.Errorf("response missing model: %#v", response)
}
if response.Done {
// Required fields for final updates:
if response.DoneReason == "" && *req.Stream {
// TODO - is the lack of done reason on non-stream a bug?
t.Errorf("final response missing done_reason: %#v", response)
}
if response.Metrics.TotalDuration == 0 {
t.Errorf("final response missing total_duration: %#v", response)
}
if response.Metrics.LoadDuration == 0 {
t.Errorf("final response missing load_duration: %#v", response)
}
if response.Metrics.PromptEvalDuration == 0 {
t.Errorf("final response missing prompt_eval_duration: %#v", response)
}
if response.Metrics.EvalCount == 0 {
t.Errorf("final response missing eval_count: %#v", response)
}
if response.Metrics.EvalDuration == 0 {
t.Errorf("final response missing eval_duration: %#v", response)
}
if len(response.Context) == 0 {
t.Errorf("final response missing context: %#v", response)
}
// Note: caching can result in no prompt eval count, so this can't be verified reliably
// if response.Metrics.PromptEvalCount == 0 {
// t.Errorf("final response missing prompt_eval_count: %#v", response)
// }
} // else incremental response, nothing to check right now...
buf.Write([]byte(response.Response))
if !stallTimer.Reset(streamTimeout) {
return fmt.Errorf("stall was detected while streaming response, aborting")
}
return nil
}
done := make(chan int)
var genErr error
go func() {
req.Stream = &test.stream
req.Options["seed"] = rand.Int() // bust cache for prompt eval results
genErr = client.Generate(ctx, &req, fn)
done <- 0
}()
select {
case <-stallTimer.C:
if buf.Len() == 0 {
t.Errorf("generate never started. Timed out after :%s", initialTimeout.String())
} else {
t.Errorf("generate stalled. Response so far:%s", buf.String())
}
case <-done:
if genErr != nil {
t.Fatalf("failed with %s request prompt %s ", req.Model, req.Prompt)
}
// Verify the response contains the expected data
response := buf.String()
atLeastOne := false
for _, resp := range anyResp {
if strings.Contains(strings.ToLower(response), resp) {
atLeastOne = true
break
}
}
if !atLeastOne {
t.Errorf("none of %v found in %s", anyResp, response)
}
case <-ctx.Done():
t.Error("outer test context done while waiting for generate")
}
})
}
// Validate PS while we're at it...
resp, err := client.ListRunning(ctx)
if err != nil {
t.Fatalf("list models API error: %s", err)
}
if resp == nil || len(resp.Models) == 0 {
t.Fatalf("list models API returned empty list while model should still be loaded")
}
// Find the model we just loaded and verify some attributes
found := false
for _, model := range resp.Models {
if strings.Contains(model.Name, req.Model) {
found = true
if model.Model == "" {
t.Errorf("model field omitted: %#v", model)
}
if model.Size == 0 {
t.Errorf("size omitted: %#v", model)
}
if model.Digest == "" {
t.Errorf("digest omitted: %#v", model)
}
verifyModelDetails(t, model.Details)
var nilTime time.Time
if model.ExpiresAt == nilTime {
t.Errorf("expires_at omitted: %#v", model)
}
// SizeVRAM could be zero.
}
}
if !found {
t.Errorf("unable to locate running model: %#v", resp)
}
}
func TestAPIChat(t *testing.T) {
initialTimeout := 60 * time.Second
streamTimeout := 30 * time.Second
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Minute)
defer cancel()
// Set up the test data
req := api.ChatRequest{
Model: smol,
Messages: []api.Message{
{
Role: "user",
Content: "why is the sky blue? be brief",
},
},
Options: map[string]interface{}{
"temperature": 0,
"seed": 123,
},
}
anyResp := []string{"rayleigh", "scattering"}
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
if err := PullIfMissing(ctx, client, req.Model); err != nil {
t.Fatalf("pull failed %s", err)
}
tests := []struct {
name string
stream bool
}{
{
name: "stream",
stream: true,
},
{
name: "no_stream",
stream: false,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
stallTimer := time.NewTimer(initialTimeout)
var buf bytes.Buffer
fn := func(response api.ChatResponse) error {
// Fields that must always be present
if response.Model == "" {
t.Errorf("response missing model: %#v", response)
}
if response.Done {
// Required fields for final updates:
var nilTime time.Time
if response.CreatedAt == nilTime {
t.Errorf("final response missing total_duration: %#v", response)
}
if response.DoneReason == "" {
t.Errorf("final response missing done_reason: %#v", response)
}
if response.Metrics.TotalDuration == 0 {
t.Errorf("final response missing total_duration: %#v", response)
}
if response.Metrics.LoadDuration == 0 {
t.Errorf("final response missing load_duration: %#v", response)
}
if response.Metrics.PromptEvalDuration == 0 {
t.Errorf("final response missing prompt_eval_duration: %#v", response)
}
if response.Metrics.EvalCount == 0 {
t.Errorf("final response missing eval_count: %#v", response)
}
if response.Metrics.EvalDuration == 0 {
t.Errorf("final response missing eval_duration: %#v", response)
}
if response.Metrics.PromptEvalCount == 0 {
t.Errorf("final response missing prompt_eval_count: %#v", response)
}
} // else incremental response, nothing to check right now...
buf.Write([]byte(response.Message.Content))
if !stallTimer.Reset(streamTimeout) {
return fmt.Errorf("stall was detected while streaming response, aborting")
}
return nil
}
done := make(chan int)
var genErr error
go func() {
req.Stream = &test.stream
req.Options["seed"] = rand.Int() // bust cache for prompt eval results
genErr = client.Chat(ctx, &req, fn)
done <- 0
}()
select {
case <-stallTimer.C:
if buf.Len() == 0 {
t.Errorf("chat never started. Timed out after :%s", initialTimeout.String())
} else {
t.Errorf("chat stalled. Response so far:%s", buf.String())
}
case <-done:
if genErr != nil {
t.Fatalf("failed with %s request prompt %v", req.Model, req.Messages)
}
// Verify the response contains the expected data
response := buf.String()
atLeastOne := false
for _, resp := range anyResp {
if strings.Contains(strings.ToLower(response), resp) {
atLeastOne = true
break
}
}
if !atLeastOne {
t.Errorf("none of %v found in %s", anyResp, response)
}
case <-ctx.Done():
t.Error("outer test context done while waiting for chat")
}
})
}
}
func TestAPIListModels(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
// Make sure we have at least one model so an empty list can be considered a failure
if err := PullIfMissing(ctx, client, smol); err != nil {
t.Fatalf("pull failed %s", err)
}
resp, err := client.List(ctx)
if err != nil {
t.Fatalf("unable to list models: %s", err)
}
if len(resp.Models) == 0 {
t.Fatalf("list should not be empty")
}
model := resp.Models[0]
if model.Name == "" {
t.Errorf("first model name empty: %#v", model)
}
var nilTime time.Time
if model.ModifiedAt == nilTime {
t.Errorf("first model modified_at empty: %#v", model)
}
if model.Size == 0 {
t.Errorf("first model size empty: %#v", model)
}
if model.Digest == "" {
t.Errorf("first model digest empty: %#v", model)
}
verifyModelDetails(t, model.Details)
}
func verifyModelDetails(t *testing.T, details api.ModelDetails) {
if details.Format == "" {
t.Errorf("first model details.format empty: %#v", details)
}
if details.Family == "" {
t.Errorf("first model details.family empty: %#v", details)
}
if details.ParameterSize == "" {
t.Errorf("first model details.parameter_size empty: %#v", details)
}
if details.QuantizationLevel == "" {
t.Errorf("first model details.quantization_level empty: %#v", details)
}
}
func TestAPIShowModel(t *testing.T) {
modelName := "llama3.2"
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Minute)
defer cancel()
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
if err := PullIfMissing(ctx, client, modelName); err != nil {
t.Fatalf("pull failed %s", err)
}
resp, err := client.Show(ctx, &api.ShowRequest{Name: modelName})
if err != nil {
t.Fatalf("unable to show model: %s", err)
}
if resp.License == "" {
t.Errorf("%s missing license: %#v", modelName, resp)
}
if resp.Modelfile == "" {
t.Errorf("%s missing modelfile: %#v", modelName, resp)
}
if resp.Parameters == "" {
t.Errorf("%s missing parameters: %#v", modelName, resp)
}
if resp.Template == "" {
t.Errorf("%s missing template: %#v", modelName, resp)
}
// llama3 omits system
verifyModelDetails(t, resp.Details)
// llama3 ommits messages
if len(resp.ModelInfo) == 0 {
t.Errorf("%s missing model_info: %#v", modelName, resp)
}
// llama3 omits projectors
var nilTime time.Time
if resp.ModifiedAt == nilTime {
t.Errorf("%s missing modified_at: %#v", modelName, resp)
}
}
func TestAPIEmbeddings(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Minute)
defer cancel()
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
req := api.EmbeddingRequest{
Model: "orca-mini",
Prompt: "why is the sky blue?",
Options: map[string]interface{}{
"temperature": 0,
"seed": 123,
},
}
if err := PullIfMissing(ctx, client, req.Model); err != nil {
t.Fatalf("pull failed %s", err)
}
resp, err := client.Embeddings(ctx, &req)
if err != nil {
t.Fatalf("embeddings call failed %s", err)
}
if len(resp.Embedding) == 0 {
t.Errorf("zero length embedding response")
}
}

View File

@@ -14,12 +14,12 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
func TestBlueSky(t *testing.T) { func TestOrcaMiniBlueSky(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute) ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
defer cancel() defer cancel()
// Set up the test data // Set up the test data
req := api.GenerateRequest{ req := api.GenerateRequest{
Model: smol, Model: "orca-mini",
Prompt: "why is the sky blue?", Prompt: "why is the sky blue?",
Stream: &stream, Stream: &stream,
Options: map[string]any{ Options: map[string]any{
@@ -31,7 +31,6 @@ func TestBlueSky(t *testing.T) {
} }
func TestUnicode(t *testing.T) { func TestUnicode(t *testing.T) {
skipUnderMinVRAM(t, 6)
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute) ctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute)
defer cancel() defer cancel()
// Set up the test data // Set up the test data
@@ -94,7 +93,7 @@ func TestUnicodeModelDir(t *testing.T) {
defer cancel() defer cancel()
req := api.GenerateRequest{ req := api.GenerateRequest{
Model: smol, Model: "orca-mini",
Prompt: "why is the sky blue?", Prompt: "why is the sky blue?",
Stream: &stream, Stream: &stream,
Options: map[string]any{ Options: map[string]any{

View File

@@ -21,7 +21,7 @@ func TestMultiModelConcurrency(t *testing.T) {
var ( var (
req = [2]api.GenerateRequest{ req = [2]api.GenerateRequest{
{ {
Model: "llama3.2:1b", Model: "orca-mini",
Prompt: "why is the ocean blue?", Prompt: "why is the ocean blue?",
Stream: &stream, Stream: &stream,
KeepAlive: &api.Duration{Duration: 10 * time.Second}, KeepAlive: &api.Duration{Duration: 10 * time.Second},
@@ -67,7 +67,7 @@ func TestMultiModelConcurrency(t *testing.T) {
wg.Wait() wg.Wait()
} }
func TestIntegrationConcurrentPredict(t *testing.T) { func TestIntegrationConcurrentPredictOrcaMini(t *testing.T) {
req, resp := GenerateRequests() req, resp := GenerateRequests()
reqLimit := len(req) reqLimit := len(req)
iterLimit := 5 iterLimit := 5
@@ -117,9 +117,6 @@ func TestMultiModelStress(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if maxVram < 2*format.GibiByte {
t.Skip("VRAM less than 2G, skipping model stress tests")
}
type model struct { type model struct {
name string name string
@@ -128,8 +125,8 @@ func TestMultiModelStress(t *testing.T) {
smallModels := []model{ smallModels := []model{
{ {
name: "llama3.2:1b", name: "orca-mini",
size: 2876 * format.MebiByte, size: 2992 * format.MebiByte,
}, },
{ {
name: "phi", name: "phi",

View File

@@ -34,15 +34,13 @@ func cosineSimilarity[V float32 | float64](v1, v2 []V) V {
func TestAllMiniLMEmbeddings(t *testing.T) { func TestAllMiniLMEmbeddings(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute) ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
defer cancel() defer cancel()
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
req := api.EmbeddingRequest{ req := api.EmbeddingRequest{
Model: "all-minilm", Model: "all-minilm",
Prompt: "why is the sky blue?", Prompt: "why is the sky blue?",
} }
res, err := embeddingTestHelper(ctx, client, t, req) res, err := embeddingTestHelper(ctx, t, req)
if err != nil { if err != nil {
t.Fatalf("error: %v", err) t.Fatalf("error: %v", err)
@@ -64,15 +62,13 @@ func TestAllMiniLMEmbeddings(t *testing.T) {
func TestAllMiniLMEmbed(t *testing.T) { func TestAllMiniLMEmbed(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute) ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
defer cancel() defer cancel()
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
req := api.EmbedRequest{ req := api.EmbedRequest{
Model: "all-minilm", Model: "all-minilm",
Input: "why is the sky blue?", Input: "why is the sky blue?",
} }
res, err := embedTestHelper(ctx, client, t, req) res, err := embedTestHelper(ctx, t, req)
if err != nil { if err != nil {
t.Fatalf("error: %v", err) t.Fatalf("error: %v", err)
@@ -102,15 +98,13 @@ func TestAllMiniLMEmbed(t *testing.T) {
func TestAllMiniLMBatchEmbed(t *testing.T) { func TestAllMiniLMBatchEmbed(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute) ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
defer cancel() defer cancel()
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
req := api.EmbedRequest{ req := api.EmbedRequest{
Model: "all-minilm", Model: "all-minilm",
Input: []string{"why is the sky blue?", "why is the grass green?"}, Input: []string{"why is the sky blue?", "why is the grass green?"},
} }
res, err := embedTestHelper(ctx, client, t, req) res, err := embedTestHelper(ctx, t, req)
if err != nil { if err != nil {
t.Fatalf("error: %v", err) t.Fatalf("error: %v", err)
@@ -150,8 +144,6 @@ func TestAllMiniLMBatchEmbed(t *testing.T) {
func TestAllMiniLMEmbedTruncate(t *testing.T) { func TestAllMiniLMEmbedTruncate(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute) ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
defer cancel() defer cancel()
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
truncTrue, truncFalse := true, false truncTrue, truncFalse := true, false
@@ -190,7 +182,7 @@ func TestAllMiniLMEmbedTruncate(t *testing.T) {
res := make(map[string]*api.EmbedResponse) res := make(map[string]*api.EmbedResponse)
for _, req := range reqs { for _, req := range reqs {
response, err := embedTestHelper(ctx, client, t, req.Request) response, err := embedTestHelper(ctx, t, req.Request)
if err != nil { if err != nil {
t.Fatalf("error: %v", err) t.Fatalf("error: %v", err)
} }
@@ -206,7 +198,7 @@ func TestAllMiniLMEmbedTruncate(t *testing.T) {
} }
// check that truncate set to false returns an error if context length is exceeded // check that truncate set to false returns an error if context length is exceeded
_, err := embedTestHelper(ctx, client, t, api.EmbedRequest{ _, err := embedTestHelper(ctx, t, api.EmbedRequest{
Model: "all-minilm", Model: "all-minilm",
Input: "why is the sky blue?", Input: "why is the sky blue?",
Truncate: &truncFalse, Truncate: &truncFalse,
@@ -218,7 +210,9 @@ func TestAllMiniLMEmbedTruncate(t *testing.T) {
} }
} }
func embeddingTestHelper(ctx context.Context, client *api.Client, t *testing.T, req api.EmbeddingRequest) (*api.EmbeddingResponse, error) { func embeddingTestHelper(ctx context.Context, t *testing.T, req api.EmbeddingRequest) (*api.EmbeddingResponse, error) {
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
if err := PullIfMissing(ctx, client, req.Model); err != nil { if err := PullIfMissing(ctx, client, req.Model); err != nil {
t.Fatalf("failed to pull model %s: %v", req.Model, err) t.Fatalf("failed to pull model %s: %v", req.Model, err)
} }
@@ -232,7 +226,9 @@ func embeddingTestHelper(ctx context.Context, client *api.Client, t *testing.T,
return response, nil return response, nil
} }
func embedTestHelper(ctx context.Context, client *api.Client, t *testing.T, req api.EmbedRequest) (*api.EmbedResponse, error) { func embedTestHelper(ctx context.Context, t *testing.T, req api.EmbedRequest) (*api.EmbedResponse, error) {
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
if err := PullIfMissing(ctx, client, req.Model); err != nil { if err := PullIfMissing(ctx, client, req.Model); err != nil {
t.Fatalf("failed to pull model %s: %v", req.Model, err) t.Fatalf("failed to pull model %s: %v", req.Model, err)
} }

View File

@@ -12,55 +12,61 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
func TestVisionModels(t *testing.T) { func TestIntegrationLlava(t *testing.T) {
skipUnderMinVRAM(t, 6) image, err := base64.StdEncoding.DecodeString(imageEncoding)
type testCase struct { require.NoError(t, err)
model string req := api.GenerateRequest{
} Model: "llava:7b",
testCases := []testCase{ Prompt: "what does the text in this image say?",
{ Stream: &stream,
model: "qwen2.5vl", Options: map[string]any{
"seed": 42,
"temperature": 0.0,
}, },
{ Images: []api.ImageData{
model: "llama3.2-vision", image,
},
{
model: "gemma3",
}, },
} }
for _, v := range testCases { // Note: sometimes it returns "the ollamas" sometimes "the ollams"
t.Run(v.model, func(t *testing.T) { resp := "the ollam"
image, err := base64.StdEncoding.DecodeString(imageEncoding) ctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute)
require.NoError(t, err) defer cancel()
req := api.GenerateRequest{ client, _, cleanup := InitServerConnection(ctx, t)
Model: v.model, defer cleanup()
Prompt: "what does the text in this image say?", require.NoError(t, PullIfMissing(ctx, client, req.Model))
Stream: &stream, // llava models on CPU can be quite slow to start,
Options: map[string]any{ DoGenerate(ctx, t, client, req, []string{resp}, 120*time.Second, 30*time.Second)
"seed": 42, }
"temperature": 0.0,
},
Images: []api.ImageData{
image,
},
}
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
defer cancel()
client, _, cleanup := InitServerConnection(ctx, t)
// Note: sometimes it returns "the ollamas" sometimes "the ollams" func TestIntegrationMllama(t *testing.T) {
resp := "the ollam" image, err := base64.StdEncoding.DecodeString(imageEncoding)
defer cleanup() require.NoError(t, err)
require.NoError(t, PullIfMissing(ctx, client, req.Model)) req := api.GenerateRequest{
// llava models on CPU can be quite slow to start // TODO fix up once we publish the final image
DoGenerate(ctx, t, client, req, []string{resp}, 240*time.Second, 30*time.Second) Model: "x/llama3.2-vision",
}) Prompt: "what does the text in this image say?",
Stream: &stream,
Options: map[string]any{
"seed": 42,
"temperature": 0.0,
},
Images: []api.ImageData{
image,
},
} }
resp := "the ollamas"
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
defer cancel()
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
require.NoError(t, PullIfMissing(ctx, client, req.Model))
// mllama models on CPU can be quite slow to start,
DoGenerate(ctx, t, client, req, []string{resp}, 240*time.Second, 30*time.Second)
} }
func TestIntegrationSplitBatch(t *testing.T) { func TestIntegrationSplitBatch(t *testing.T) {
skipUnderMinVRAM(t, 6)
image, err := base64.StdEncoding.DecodeString(imageEncoding) image, err := base64.StdEncoding.DecodeString(imageEncoding)
require.NoError(t, err) require.NoError(t, err)
req := api.GenerateRequest{ req := api.GenerateRequest{

View File

@@ -17,7 +17,7 @@ var (
stream = false stream = false
req = [2]api.GenerateRequest{ req = [2]api.GenerateRequest{
{ {
Model: smol, Model: "orca-mini",
Prompt: "why is the ocean blue?", Prompt: "why is the ocean blue?",
Stream: &stream, Stream: &stream,
Options: map[string]any{ Options: map[string]any{
@@ -25,7 +25,7 @@ var (
"temperature": 0.0, "temperature": 0.0,
}, },
}, { }, {
Model: smol, Model: "orca-mini",
Prompt: "what is the origin of the us thanksgiving holiday?", Prompt: "what is the origin of the us thanksgiving holiday?",
Stream: &stream, Stream: &stream,
Options: map[string]any{ Options: map[string]any{
@@ -35,12 +35,12 @@ var (
}, },
} }
resp = [2][]string{ resp = [2][]string{
{"sunlight", "scattering", "interact"}, {"sunlight"},
{"england", "english", "massachusetts", "pilgrims"}, {"england", "english", "massachusetts", "pilgrims"},
} }
) )
func TestIntegrationSimple(t *testing.T) { func TestIntegrationSimpleOrcaMini(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*120) ctx, cancel := context.WithTimeout(context.Background(), time.Second*120)
defer cancel() defer cancel()
GenerateTestHelper(ctx, t, req[0], resp[0]) GenerateTestHelper(ctx, t, req[0], resp[0])

View File

@@ -30,7 +30,7 @@ func TestMaxQueue(t *testing.T) {
t.Setenv("OLLAMA_MAX_QUEUE", strconv.Itoa(threadCount)) t.Setenv("OLLAMA_MAX_QUEUE", strconv.Itoa(threadCount))
req := api.GenerateRequest{ req := api.GenerateRequest{
Model: smol, Model: "orca-mini",
Prompt: "write a long historical fiction story about christopher columbus. use at least 10 facts from his actual journey", Prompt: "write a long historical fiction story about christopher columbus. use at least 10 facts from his actual journey",
Options: map[string]any{ Options: map[string]any{
"seed": 42, "seed": 42,
@@ -61,7 +61,7 @@ func TestMaxQueue(t *testing.T) {
}() }()
// Give the generate a chance to get started before we start hammering on embed requests // Give the generate a chance to get started before we start hammering on embed requests
time.Sleep(10 * time.Millisecond) time.Sleep(5 * time.Millisecond)
threadCount += 10 // Add a few extra to ensure we push the queue past its limit threadCount += 10 // Add a few extra to ensure we push the queue past its limit
busyCount := 0 busyCount := 0

View File

@@ -1,186 +0,0 @@
//go:build integration && models
package integration
import (
"context"
"encoding/json"
"fmt"
"io/ioutil"
"log/slog"
"os"
"path/filepath"
"strconv"
"strings"
"testing"
"time"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/format"
)
var (
started = time.Now()
chatModels = []string{
"granite3-moe:latest",
"granite-code:latest",
"nemotron-mini:latest",
"command-r:latest",
"gemma2:latest",
"gemma:latest",
"internlm2:latest",
"phi3.5:latest",
"phi3:latest",
// "phi:latest", // flaky, sometimes generates no response on first query
"stablelm2:latest", // Predictions are off, crashes on small VRAM GPUs
"falcon:latest",
"falcon2:latest",
"minicpm-v:latest",
"mistral:latest",
"orca-mini:latest",
"llama2:latest",
"llama3.1:latest",
"llama3.2:latest",
"llama3.2-vision:latest",
"qwen2.5-coder:latest",
"qwen:latest",
"solar-pro:latest",
"codellama:latest",
"nous-hermes:latest",
}
)
func TestModelsGenerate(t *testing.T) {
softTimeout, hardTimeout := getTimeouts(t)
slog.Info("Setting timeouts", "soft", softTimeout, "hard", hardTimeout)
ctx, cancel := context.WithTimeout(context.Background(), hardTimeout)
defer cancel()
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
// TODO use info API eventually
var maxVram uint64
var err error
if s := os.Getenv("OLLAMA_MAX_VRAM"); s != "" {
maxVram, err = strconv.ParseUint(s, 10, 64)
if err != nil {
t.Fatalf("invalid OLLAMA_MAX_VRAM %v", err)
}
} else {
slog.Warn("No VRAM info available, testing all models, so larger ones might timeout...")
}
for _, model := range chatModels {
t.Run(model, func(t *testing.T) {
if time.Now().Sub(started) > softTimeout {
t.Skip("skipping remaining tests to avoid excessive runtime")
}
if err := PullIfMissing(ctx, client, model); err != nil {
t.Fatalf("pull failed %s", err)
}
if maxVram > 0 {
resp, err := client.List(ctx)
if err != nil {
t.Fatalf("list models failed %v", err)
}
for _, m := range resp.Models {
if m.Name == model && float32(m.Size)*1.2 > float32(maxVram) {
t.Skipf("model %s is too large for available VRAM: %s > %s", model, format.HumanBytes(m.Size), format.HumanBytes(int64(maxVram)))
}
}
}
// TODO - fiddle with context size
req := api.GenerateRequest{
Model: model,
Prompt: "why is the sky blue?",
Options: map[string]interface{}{
"temperature": 0,
"seed": 123,
},
}
anyResp := []string{"rayleigh", "scattering", "atmosphere", "nitrogen", "oxygen"}
DoGenerate(ctx, t, client, req, anyResp, 120*time.Second, 30*time.Second)
})
}
}
func TestModelsEmbed(t *testing.T) {
softTimeout, hardTimeout := getTimeouts(t)
ctx, cancel := context.WithTimeout(context.Background(), hardTimeout)
defer cancel()
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
// TODO use info API eventually
var maxVram uint64
var err error
if s := os.Getenv("OLLAMA_MAX_VRAM"); s != "" {
maxVram, err = strconv.ParseUint(s, 10, 64)
if err != nil {
t.Fatalf("invalid OLLAMA_MAX_VRAM %v", err)
}
} else {
slog.Warn("No VRAM info available, testing all models, so larger ones might timeout...")
}
data, err := ioutil.ReadFile(filepath.Join("testdata", "embed.json"))
if err != nil {
t.Fatalf("failed to open test data file: %s", err)
}
testCase := map[string][]float64{}
err = json.Unmarshal(data, &testCase)
if err != nil {
t.Fatalf("failed to load test data: %s", err)
}
for model, expected := range testCase {
t.Run(model, func(t *testing.T) {
if time.Now().Sub(started) > softTimeout {
t.Skip("skipping remaining tests to avoid excessive runtime")
}
if err := PullIfMissing(ctx, client, model); err != nil {
t.Fatalf("pull failed %s", err)
}
if maxVram > 0 {
resp, err := client.List(ctx)
if err != nil {
t.Fatalf("list models failed %v", err)
}
for _, m := range resp.Models {
if m.Name == model && float32(m.Size)*1.2 > float32(maxVram) {
t.Skipf("model %s is too large for available VRAM: %s > %s", model, format.HumanBytes(m.Size), format.HumanBytes(int64(maxVram)))
}
}
}
req := api.EmbeddingRequest{
Model: model,
Prompt: "why is the sky blue?",
Options: map[string]interface{}{
"temperature": 0,
"seed": 123,
},
}
resp, err := client.Embeddings(ctx, &req)
if err != nil {
t.Fatalf("embeddings call failed %s", err)
}
if len(resp.Embedding) == 0 {
t.Errorf("zero length embedding response")
}
if len(expected) != len(resp.Embedding) {
expStr := make([]string, len(resp.Embedding))
for i, v := range resp.Embedding {
expStr[i] = fmt.Sprintf("%0.6f", v)
}
// When adding new models, use this output to populate the testdata/embed.json
fmt.Printf("expected\n%s\n", strings.Join(expStr, ", "))
t.Fatalf("expected %d, got %d", len(expected), len(resp.Embedding))
}
sim := cosineSimilarity(resp.Embedding, expected)
if sim < 0.99 {
t.Fatalf("expected %v, got %v (similarity: %f)", expected[0:5], resp.Embedding[0:5], sim)
}
})
}
}

View File

@@ -1,130 +0,0 @@
//go:build integration && models
package integration
import (
"bytes"
"context"
"fmt"
"log/slog"
"strings"
"testing"
"time"
"github.com/ollama/ollama/api"
)
func TestQuantization(t *testing.T) {
sourceModels := []string{
"qwen2.5:0.5b-instruct-fp16",
}
quantizations := []string{
"Q8_0",
"Q4_K_S",
"Q4_K_M",
"Q4_K",
}
softTimeout, hardTimeout := getTimeouts(t)
started := time.Now()
slog.Info("Setting timeouts", "soft", softTimeout, "hard", hardTimeout)
ctx, cancel := context.WithTimeout(context.Background(), hardTimeout)
defer cancel()
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
for _, base := range sourceModels {
if err := PullIfMissing(ctx, client, base); err != nil {
t.Fatalf("pull failed %s", err)
}
for _, quant := range quantizations {
newName := fmt.Sprintf("%s__%s", base, quant)
t.Run(newName, func(t *testing.T) {
if time.Now().Sub(started) > softTimeout {
t.Skip("skipping remaining tests to avoid excessive runtime")
}
req := &api.CreateRequest{
Model: newName,
Quantization: quant,
From: base,
}
fn := func(resp api.ProgressResponse) error {
// fmt.Print(".")
return nil
}
t.Logf("quantizing: %s -> %s", base, quant)
if err := client.Create(ctx, req, fn); err != nil {
t.Fatalf("create failed %s", err)
}
defer func() {
req := &api.DeleteRequest{
Model: newName,
}
t.Logf("deleting: %s -> %s", base, quant)
if err := client.Delete(ctx, req); err != nil {
t.Logf("failed to clean up %s: %s", req.Model, err)
}
}()
// Check metadata on the model
resp, err := client.Show(ctx, &api.ShowRequest{Name: newName})
if err != nil {
t.Fatalf("unable to show model: %s", err)
}
if !strings.Contains(resp.Details.QuantizationLevel, quant) {
t.Fatalf("unexpected quantization for %s:\ngot: %s", newName, resp.Details.QuantizationLevel)
}
stream := true
genReq := api.GenerateRequest{
Model: newName,
Prompt: "why is the sky blue?",
KeepAlive: &api.Duration{Duration: 3 * time.Second},
Options: map[string]any{
"seed": 42,
"temperature": 0.0,
},
Stream: &stream,
}
t.Logf("verifying: %s -> %s", base, quant)
// Some smaller quantizations can cause models to have poor quality
// or get stuck in repetition loops, so we stop as soon as we have any matches
anyResp := []string{"rayleigh", "scattering", "day", "sun", "moon", "color", "nitrogen", "oxygen"}
reqCtx, reqCancel := context.WithCancel(ctx)
atLeastOne := false
var buf bytes.Buffer
genfn := func(response api.GenerateResponse) error {
buf.Write([]byte(response.Response))
fullResp := strings.ToLower(buf.String())
for _, resp := range anyResp {
if strings.Contains(fullResp, resp) {
atLeastOne = true
t.Log(fullResp)
reqCancel()
break
}
}
return nil
}
done := make(chan int)
var genErr error
go func() {
genErr = client.Generate(reqCtx, &genReq, genfn)
done <- 0
}()
select {
case <-done:
if genErr != nil && !atLeastOne {
t.Fatalf("failed with %s request prompt %s ", genReq.Model, genReq.Prompt)
}
case <-ctx.Done():
t.Error("outer test context done while waiting for generate")
}
t.Logf("passed")
})
}
}
}

File diff suppressed because one or more lines are too long

View File

@@ -24,14 +24,9 @@ import (
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
"github.com/ollama/ollama/app/lifecycle" "github.com/ollama/ollama/app/lifecycle"
"github.com/ollama/ollama/format"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
const (
smol = "llama3.2:1b"
)
func Init() { func Init() {
lifecycle.InitLogging() lifecycle.InitLogging()
} }
@@ -145,7 +140,7 @@ func PullIfMissing(ctx context.Context, client *api.Client, modelName string) er
showCtx, cancel := context.WithDeadlineCause( showCtx, cancel := context.WithDeadlineCause(
ctx, ctx,
time.Now().Add(20*time.Second), time.Now().Add(10*time.Second),
fmt.Errorf("show for existing model %s took too long", modelName), fmt.Errorf("show for existing model %s took too long", modelName),
) )
defer cancel() defer cancel()
@@ -162,7 +157,7 @@ func PullIfMissing(ctx context.Context, client *api.Client, modelName string) er
} }
slog.Info("model missing", "model", modelName) slog.Info("model missing", "model", modelName)
stallDuration := 60 * time.Second // This includes checksum verification, which can take a while on larger models, and slower systems stallDuration := 30 * time.Second // This includes checksum verification, which can take a while on larger models
stallTimer := time.NewTimer(stallDuration) stallTimer := time.NewTimer(stallDuration)
fn := func(resp api.ProgressResponse) error { fn := func(resp api.ProgressResponse) error {
// fmt.Print(".") // fmt.Print(".")
@@ -217,7 +212,6 @@ func InitServerConnection(ctx context.Context, t *testing.T) (*api.Client, strin
slog.Error("failed to open server log", "logfile", lifecycle.ServerLogFile, "error", err) slog.Error("failed to open server log", "logfile", lifecycle.ServerLogFile, "error", err)
return return
} }
defer fp.Close()
data, err := io.ReadAll(fp) data, err := io.ReadAll(fp)
if err != nil { if err != nil {
slog.Error("failed to read server log", "logfile", lifecycle.ServerLogFile, "error", err) slog.Error("failed to read server log", "logfile", lifecycle.ServerLogFile, "error", err)
@@ -289,11 +283,11 @@ func DoGenerate(ctx context.Context, t *testing.T, client *api.Client, genReq ap
} }
// Generate a set of requests // Generate a set of requests
// By default each request uses llama3.2 as the model // By default each request uses orca-mini as the model
func GenerateRequests() ([]api.GenerateRequest, [][]string) { func GenerateRequests() ([]api.GenerateRequest, [][]string) {
return []api.GenerateRequest{ return []api.GenerateRequest{
{ {
Model: smol, Model: "orca-mini",
Prompt: "why is the ocean blue?", Prompt: "why is the ocean blue?",
Stream: &stream, Stream: &stream,
KeepAlive: &api.Duration{Duration: 10 * time.Second}, KeepAlive: &api.Duration{Duration: 10 * time.Second},
@@ -302,7 +296,7 @@ func GenerateRequests() ([]api.GenerateRequest, [][]string) {
"temperature": 0.0, "temperature": 0.0,
}, },
}, { }, {
Model: smol, Model: "orca-mini",
Prompt: "why is the color of dirt brown?", Prompt: "why is the color of dirt brown?",
Stream: &stream, Stream: &stream,
KeepAlive: &api.Duration{Duration: 10 * time.Second}, KeepAlive: &api.Duration{Duration: 10 * time.Second},
@@ -311,7 +305,7 @@ func GenerateRequests() ([]api.GenerateRequest, [][]string) {
"temperature": 0.0, "temperature": 0.0,
}, },
}, { }, {
Model: smol, Model: "orca-mini",
Prompt: "what is the origin of the us thanksgiving holiday?", Prompt: "what is the origin of the us thanksgiving holiday?",
Stream: &stream, Stream: &stream,
KeepAlive: &api.Duration{Duration: 10 * time.Second}, KeepAlive: &api.Duration{Duration: 10 * time.Second},
@@ -320,7 +314,7 @@ func GenerateRequests() ([]api.GenerateRequest, [][]string) {
"temperature": 0.0, "temperature": 0.0,
}, },
}, { }, {
Model: smol, Model: "orca-mini",
Prompt: "what is the origin of independence day?", Prompt: "what is the origin of independence day?",
Stream: &stream, Stream: &stream,
KeepAlive: &api.Duration{Duration: 10 * time.Second}, KeepAlive: &api.Duration{Duration: 10 * time.Second},
@@ -329,7 +323,7 @@ func GenerateRequests() ([]api.GenerateRequest, [][]string) {
"temperature": 0.0, "temperature": 0.0,
}, },
}, { }, {
Model: smol, Model: "orca-mini",
Prompt: "what is the composition of air?", Prompt: "what is the composition of air?",
Stream: &stream, Stream: &stream,
KeepAlive: &api.Duration{Duration: 10 * time.Second}, KeepAlive: &api.Duration{Duration: 10 * time.Second},
@@ -347,26 +341,3 @@ func GenerateRequests() ([]api.GenerateRequest, [][]string) {
{"nitrogen", "oxygen", "carbon", "dioxide"}, {"nitrogen", "oxygen", "carbon", "dioxide"},
} }
} }
func skipUnderMinVRAM(t *testing.T, gb uint64) {
// TODO use info API in the future
if s := os.Getenv("OLLAMA_MAX_VRAM"); s != "" {
maxVram, err := strconv.ParseUint(s, 10, 64)
require.NoError(t, err)
// Don't hammer on small VRAM cards...
if maxVram < gb*format.GibiByte {
t.Skip("skipping with small VRAM to avoid timeouts")
}
}
}
func getTimeouts(t *testing.T) (soft time.Duration, hard time.Duration) {
deadline, hasDeadline := t.Deadline()
if !hasDeadline {
return 8 * time.Minute, 10 * time.Minute
} else if deadline.Compare(time.Now().Add(2*time.Minute)) <= 0 {
t.Skip("too little time")
return time.Duration(0), time.Duration(0)
}
return -time.Since(deadline.Add(-2 * time.Minute)), -time.Since(deadline.Add(-20 * time.Second))
}

View File

@@ -21,7 +21,6 @@ type shiftFn func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, e
type Causal struct { type Causal struct {
DType ml.DType DType ml.DType
windowSize int32 windowSize int32
chunkSize int32
opts CausalOptions opts CausalOptions
@@ -30,11 +29,6 @@ type Causal struct {
// ** current forward pass ** // ** current forward pass **
// curReserve indicates that this forward pass is only for
// memory reservation and we should not update our metadata
// based on it.
curReserve bool
// the active layer for Get and Put // the active layer for Get and Put
curLayer int curLayer int
@@ -103,17 +97,6 @@ func NewSWACache(windowSize int32, shift shiftFn) *Causal {
} }
} }
func NewChunkedAttentionCache(chunkSize int32, shift shiftFn) *Causal {
return &Causal{
windowSize: math.MaxInt32,
chunkSize: chunkSize,
shiftFn: shift,
ctxs: make(map[int]ml.Context),
keys: make(map[int]ml.Tensor),
values: make(map[int]ml.Tensor),
}
}
func (c *Causal) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) { func (c *Causal) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
if c.config == nil { if c.config == nil {
var config ml.CacheConfig var config ml.CacheConfig
@@ -164,13 +147,12 @@ func (c *Causal) Close() {
} }
func (c *Causal) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error { func (c *Causal) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error {
c.curReserve = reserve
c.curBatchSize = len(batch.Positions) c.curBatchSize = len(batch.Positions)
c.curSequences = batch.Sequences c.curSequences = batch.Sequences
c.curPositions = batch.Positions c.curPositions = batch.Positions
c.opts.Except = nil c.opts.Except = nil
if !c.curReserve { if !reserve {
c.updateSlidingWindow() c.updateSlidingWindow()
var err error var err error
@@ -217,9 +199,10 @@ func (c *Causal) StartForward(ctx ml.Context, batch input.Batch, reserve bool) e
c.curCellRange.max = len(c.cells) - 1 c.curCellRange.max = len(c.cells) - 1
} }
c.curMask = c.buildMask(ctx) var err error
c.curMask, err = c.buildMask(ctx)
return nil return err
} }
func newRange() cellRange { func newRange() cellRange {
@@ -244,7 +227,7 @@ func (c *Causal) findStartLoc() (int, error) {
} }
} }
return 0, fmt.Errorf("%w (cache: %v batch: %v)", ErrKvCacheFull, len(c.cells), c.curBatchSize) return 0, fmt.Errorf("%w (length: %v)", ErrKvCacheFull, len(c.cells))
} }
func (c *Causal) updateSlidingWindow() { func (c *Causal) updateSlidingWindow() {
@@ -302,7 +285,7 @@ func roundUp(length, pad int) int {
// Builds a mask of history x batch indicating whether for each token in the batch the // Builds a mask of history x batch indicating whether for each token in the batch the
// token in the history should apply. This is based on both the sequence and causality (the // token in the history should apply. This is based on both the sequence and causality (the
// position of the history is not ahead of the token in the batch). // position of the history is not ahead of the token in the batch).
func (c *Causal) buildMask(ctx ml.Context) ml.Tensor { func (c *Causal) buildMask(ctx ml.Context) (ml.Tensor, error) {
// Align and pad the two dimensions as required by the backend // Align and pad the two dimensions as required by the backend
batchSize := roundUp(c.curBatchSize, c.config.MaskBatchPadding) batchSize := roundUp(c.curBatchSize, c.config.MaskBatchPadding)
@@ -310,11 +293,6 @@ func (c *Causal) buildMask(ctx ml.Context) ml.Tensor {
c.curCellRange.max = roundUp(c.curCellRange.max+1, c.config.CachePadding) - 1 c.curCellRange.max = roundUp(c.curCellRange.max+1, c.config.CachePadding) - 1
length := c.curCellRange.max - c.curCellRange.min + 1 length := c.curCellRange.max - c.curCellRange.min + 1
if c.curReserve {
return ctx.Input().Empty(c.config.MaskDType, length, batchSize)
}
mask := make([]float32, batchSize*length) mask := make([]float32, batchSize*length)
for i := range c.curBatchSize { for i := range c.curBatchSize {
@@ -322,7 +300,6 @@ func (c *Causal) buildMask(ctx ml.Context) ml.Tensor {
for j := c.curCellRange.min; j <= c.curCellRange.max; j++ { for j := c.curCellRange.min; j <= c.curCellRange.max; j++ {
if !slices.Contains(c.cells[j].sequences, c.curSequences[i]) || if !slices.Contains(c.cells[j].sequences, c.curSequences[i]) ||
(enabled && c.cells[j].pos > c.curPositions[i]) || (enabled && c.cells[j].pos > c.curPositions[i]) ||
c.chunkSize > 0 && c.cells[j].pos < c.curPositions[i]-c.curPositions[i]%c.chunkSize ||
c.cells[j].pos < c.curPositions[i]-c.windowSize { c.cells[j].pos < c.curPositions[i]-c.windowSize {
mask[i*length+(j-c.curCellRange.min)] = float32(math.Inf(-1)) mask[i*length+(j-c.curCellRange.min)] = float32(math.Inf(-1))
} }
@@ -335,7 +312,10 @@ func (c *Causal) buildMask(ctx ml.Context) ml.Tensor {
mask[i] = float32(math.Inf(-1)) mask[i] = float32(math.Inf(-1))
} }
maskTensor := ctx.Input().FromFloatSlice(mask, length, batchSize) maskTensor, err := ctx.Input().FromFloatSlice(mask, length, batchSize)
if err != nil {
return nil, err
}
if c.config.MaskDType != ml.DTypeF32 { if c.config.MaskDType != ml.DTypeF32 {
out := ctx.Input().Empty(c.config.MaskDType, maskTensor.Shape()...) out := ctx.Input().Empty(c.config.MaskDType, maskTensor.Shape()...)
@@ -343,7 +323,7 @@ func (c *Causal) buildMask(ctx ml.Context) ml.Tensor {
maskTensor = out maskTensor = out
} }
return maskTensor return maskTensor, nil
} }
func (c *Causal) moveCells(ctx ml.Context, src, dst, length int) { func (c *Causal) moveCells(ctx ml.Context, src, dst, length int) {
@@ -498,7 +478,12 @@ func (c *Causal) SetCausal(ctx ml.Context, opts CausalOptions) {
if !slices.Equal(c.opts.Except, opts.Except) { if !slices.Equal(c.opts.Except, opts.Except) {
c.opts = opts c.opts = opts
if ctx != nil { if ctx != nil {
c.curMask = c.buildMask(ctx) var err error
c.curMask, err = c.buildMask(ctx)
if err != nil {
// This error should never occur because we have previously built a mask with the same shape
panic(fmt.Errorf("SetCausal: %w", err))
}
} }
} }
} }
@@ -654,7 +639,10 @@ func (c *Causal) shift(seq int, beginIndex, offset int32) error {
} }
} }
kShift := ctx.Input().FromIntSlice(offsets, len(offsets)) kShift, err := ctx.Input().FromIntSlice(offsets, len(offsets))
if err != nil {
return err
}
for i, key := range c.keys { for i, key := range c.keys {
if key == nil { if key == nil {

View File

@@ -86,64 +86,6 @@ func TestSWA(t *testing.T) {
testCache(t, backend, cache, tests) testCache(t, backend, cache, tests)
} }
func TestChunkedAttention(t *testing.T) {
cache := NewChunkedAttentionCache(2, nil)
defer cache.Close()
var b testBackend
cache.Init(&b, ml.DTypeF16, 1, 16, 16)
x := float32(math.Inf(-1))
testCache(
t, &b, cache,
[]testCase{
{
name: "FirstBatch",
in: []float32{1, 2, 3, 4},
inShape: []int{1, 1, 4},
seqs: []int{0, 0, 0, 0},
pos: []int32{0, 1, 2, 3},
expected: []float32{1, 2, 3, 4},
expectedShape: []int{1, 1, 4},
expectedMask: []float32{
0, x, x, x,
0, 0, x, x,
x, x, 0, x,
x, x, 0, 0,
},
},
{
name: "SecondBatch",
in: []float32{5, 6, 7},
inShape: []int{1, 1, 3},
seqs: []int{0, 0, 0},
pos: []int32{4, 5, 6},
expected: []float32{1, 2, 3, 4, 5, 6, 7},
expectedShape: []int{1, 1, 7},
expectedMask: []float32{
x, x, x, x, 0, x, x,
x, x, x, x, 0, 0, x,
x, x, x, x, x, x, 0,
},
},
{
name: "ThirdBatch",
in: []float32{8, 9},
inShape: []int{1, 1, 2},
seqs: []int{0, 0},
pos: []int32{7, 8},
expected: []float32{1, 2, 3, 4, 5, 6, 7, 8, 9},
expectedShape: []int{1, 1, 9},
expectedMask: []float32{
x, x, x, x, x, x, 0, 0, x,
x, x, x, x, x, x, x, x, 0,
},
},
},
)
}
func TestSequences(t *testing.T) { func TestSequences(t *testing.T) {
backend := &testBackend{} backend := &testBackend{}
cache := NewCausalCache(nil) cache := NewCausalCache(nil)
@@ -344,23 +286,15 @@ func testCache(t *testing.T, backend ml.Backend, cache Cache, tests []testCase)
} }
cache.SetLayer(0) cache.SetLayer(0)
tensor := context.FromFloatSlice(test.in, test.inShape...) tensor, _ := context.FromFloatSlice(test.in, test.inShape...)
cache.Put(context, tensor, tensor) cache.Put(context, tensor, tensor)
out, _, mask := cache.Get(context) out, _, mask := cache.Get(context)
context.Forward(out, mask).Compute(out, mask) context.Forward(out, mask).Compute(out, mask)
if !slices.Equal(out.Floats(), test.expected) { if !slices.Equal(out.Floats(), test.expected) || !slices.Equal(out.Shape(), test.expectedShape) || !slices.Equal(mask.Floats(), test.expectedMask) {
t.Errorf("TestCache: have %v; want %v", out.Floats(), test.expected) t.Errorf("TestCache: have %v (shape %v); want %v (shape %v); mask: have %v (shape %v) want %v", out.Floats(), out.Shape(), test.expected, test.expectedShape, mask.Floats(), mask.Shape(), test.expectedMask)
}
if !slices.Equal(out.Shape(), test.expectedShape) {
t.Errorf("TestCache: has shape %v; want %v", out.Shape(), test.expectedShape)
}
if !slices.Equal(mask.Floats(), test.expectedMask) {
t.Errorf("TestCache: have mask: have %v want %v", mask.Floats(), test.expectedMask)
} }
}) })
} }
@@ -386,7 +320,7 @@ func TestCanResume(t *testing.T) {
} }
cache.SetLayer(0) cache.SetLayer(0)
tensor := context.FromFloatSlice([]float32{1, 2, 3, 4}, 1, 1, 4) tensor, _ := context.FromFloatSlice([]float32{1, 2, 3, 4}, 1, 1, 4)
cache.Put(context, tensor, tensor) cache.Put(context, tensor, tensor)
// with window size 4, nothing has slid out of the window yet // with window size 4, nothing has slid out of the window yet
@@ -413,7 +347,7 @@ func TestCanResume(t *testing.T) {
} }
cache.SetLayer(0) cache.SetLayer(0)
tensor = context.FromFloatSlice([]float32{5, 6}, 1, 1, 2) tensor, _ = context.FromFloatSlice([]float32{5, 6}, 1, 1, 2)
cache.Put(context, tensor, tensor) cache.Put(context, tensor, tensor)
// only the latest position has overlapping windows // only the latest position has overlapping windows
@@ -470,35 +404,24 @@ func (c *testContext) Zeros(dtype ml.DType, shape ...int) ml.Tensor {
return c.Empty(dtype, shape...) return c.Empty(dtype, shape...)
} }
func (c *testContext) FromFloatSlice(s []float32, shape ...int) ml.Tensor { func (c *testContext) FromFloatSlice(s []float32, shape ...int) (ml.Tensor, error) {
t := c.Empty(ml.DTypeF32, shape...).(*testTensor) t := c.Empty(ml.DTypeF32, shape...).(*testTensor)
copy(t.data, s) copy(t.data, s)
return t return t, nil
} }
func (c *testContext) FromIntSlice(s []int32, shape ...int) ml.Tensor { func (c *testContext) FromIntSlice(s []int32, shape ...int) (ml.Tensor, error) {
f := make([]float32, len(s)) f := make([]float32, len(s))
for i := range f { for i := range f {
f[i] = float32(s[i]) f[i] = float32(s[i])
} }
out := c.FromFloatSlice(f, shape...) out, _ := c.FromFloatSlice(f, shape...)
out.(*testTensor).dtype = ml.DTypeI32 out.(*testTensor).dtype = ml.DTypeI32
return out return out, nil
}
func (c *testContext) Arange(start, stop, step float32, dtype ml.DType) ml.Tensor {
s := make([]float32, 0, int((stop-start)/step))
for i := start; i < stop; i += step {
s = append(s, i)
}
out := c.FromFloatSlice(s, len(s))
out.(*testTensor).dtype = dtype
return out
} }
func (c *testContext) Input() ml.Context { return c } func (c *testContext) Input() ml.Context { return c }
@@ -508,7 +431,7 @@ func (c *testContext) Forward(...ml.Tensor) ml.Context { return c }
func (c *testContext) Compute(...ml.Tensor) {} func (c *testContext) Compute(...ml.Tensor) {}
func (c *testContext) Reserve() {} func (c *testContext) Reserve() error { return nil }
func (c *testContext) MaxGraphNodes() int { func (c *testContext) MaxGraphNodes() int {
return 10 return 10

2
llama/build-info.cpp generated vendored
View File

@@ -1,4 +1,4 @@
int LLAMA_BUILD_NUMBER = 0; int LLAMA_BUILD_NUMBER = 0;
char const *LLAMA_COMMIT = "de4c07f93783a1a96456a44dc16b9db538ee1618"; char const *LLAMA_COMMIT = "d7cfe1ffe0f435d0048a6058d529daf76e072d9c";
char const *LLAMA_COMPILER = ""; char const *LLAMA_COMPILER = "";
char const *LLAMA_BUILD_TARGET = ""; char const *LLAMA_BUILD_TARGET = "";

View File

@@ -10,11 +10,10 @@ include common/stb_image.*
include include/ include include/
include include/llama.* include include/llama.*
include include/llama-*.* include include/llama-*.*
include tools/ include examples/
include tools/mtmd/ include examples/llava/
include tools/mtmd/clip.* include examples/llava/clip.*
include tools/mtmd/clip-impl.* include examples/llava/llava.*
include tools/mtmd/llava.*
include src/ include src/
include src/llama.* include src/llama.*
include src/llama-*.* include src/llama-*.*

View File

@@ -7,6 +7,10 @@
#include "common.h" #include "common.h"
#include "log.h" #include "log.h"
// Change JSON_ASSERT from assert() to GGML_ASSERT:
#define JSON_ASSERT GGML_ASSERT
#include "json.hpp"
#include "json-schema-to-grammar.h"
#include "llama.h" #include "llama.h"
#include <algorithm> #include <algorithm>
@@ -48,11 +52,47 @@
#include <sys/stat.h> #include <sys/stat.h>
#include <unistd.h> #include <unistd.h>
#endif #endif
#if defined(LLAMA_USE_CURL)
#include <curl/curl.h>
#include <curl/easy.h>
#include <future>
#endif
#if defined(_MSC_VER) #if defined(_MSC_VER)
#pragma warning(disable: 4244 4267) // possible loss of data #pragma warning(disable: 4244 4267) // possible loss of data
#endif #endif
#if defined(LLAMA_USE_CURL)
#ifdef __linux__
#include <linux/limits.h>
#elif defined(_WIN32)
# if !defined(PATH_MAX)
# define PATH_MAX MAX_PATH
# endif
#else
#include <sys/syslimits.h>
#endif
#define LLAMA_CURL_MAX_URL_LENGTH 2084 // Maximum URL Length in Chrome: 2083
//
// CURL utils
//
using curl_ptr = std::unique_ptr<CURL, decltype(&curl_easy_cleanup)>;
// cannot use unique_ptr for curl_slist, because we cannot update without destroying the old one
struct curl_slist_ptr {
struct curl_slist * ptr = nullptr;
~curl_slist_ptr() {
if (ptr) {
curl_slist_free_all(ptr);
}
}
};
#endif // LLAMA_USE_CURL
using json = nlohmann::ordered_json;
// //
// CPU utils // CPU utils
// //
@@ -443,11 +483,6 @@ void string_replace_all(std::string & s, const std::string & search, const std::
s = std::move(builder); s = std::move(builder);
} }
std::string regex_escape(const std::string & s) {
static const std::regex special_chars("[.^$|()*+?\\[\\]{}\\\\]");
return std::regex_replace(s, special_chars, "\\$0");
}
std::string string_join(const std::vector<std::string> & values, const std::string & separator) { std::string string_join(const std::vector<std::string> & values, const std::string & separator) {
std::ostringstream result; std::ostringstream result;
for (size_t i = 0; i < values.size(); ++i) { for (size_t i = 0; i < values.size(); ++i) {
@@ -830,7 +865,7 @@ std::string fs_get_cache_directory() {
if (getenv("LLAMA_CACHE")) { if (getenv("LLAMA_CACHE")) {
cache_directory = std::getenv("LLAMA_CACHE"); cache_directory = std::getenv("LLAMA_CACHE");
} else { } else {
#if defined(__linux__) || defined(__FreeBSD__) || defined(_AIX) #ifdef __linux__
if (std::getenv("XDG_CACHE_HOME")) { if (std::getenv("XDG_CACHE_HOME")) {
cache_directory = std::getenv("XDG_CACHE_HOME"); cache_directory = std::getenv("XDG_CACHE_HOME");
} else { } else {
@@ -840,9 +875,7 @@ std::string fs_get_cache_directory() {
cache_directory = std::getenv("HOME") + std::string("/Library/Caches/"); cache_directory = std::getenv("HOME") + std::string("/Library/Caches/");
#elif defined(_WIN32) #elif defined(_WIN32)
cache_directory = std::getenv("LOCALAPPDATA"); cache_directory = std::getenv("LOCALAPPDATA");
#else #endif // __linux__
# error Unknown architecture
#endif
cache_directory = ensure_trailing_slash(cache_directory); cache_directory = ensure_trailing_slash(cache_directory);
cache_directory += "llama.cpp"; cache_directory += "llama.cpp";
} }
@@ -863,14 +896,22 @@ std::string fs_get_cache_file(const std::string & filename) {
// //
// Model utils // Model utils
// //
struct common_init_result common_init_from_params(common_params & params) { struct common_init_result common_init_from_params(common_params & params) {
common_init_result iparams; common_init_result iparams;
auto mparams = common_model_params_to_llama(params); auto mparams = common_model_params_to_llama(params);
llama_model * model = llama_model_load_from_file(params.model.path.c_str(), mparams); llama_model * model = nullptr;
if (!params.hf_repo.empty() && !params.hf_file.empty()) {
model = common_load_model_from_hf(params.hf_repo, params.hf_file, params.model, params.hf_token, mparams);
} else if (!params.model_url.empty()) {
model = common_load_model_from_url(params.model_url, params.model, params.hf_token, mparams);
} else {
model = llama_model_load_from_file(params.model.c_str(), mparams);
}
if (model == NULL) { if (model == NULL) {
LOG_ERR("%s: failed to load model '%s'\n", __func__, params.model.path.c_str()); LOG_ERR("%s: failed to load model '%s'\n", __func__, params.model.c_str());
return iparams; return iparams;
} }
@@ -905,13 +946,13 @@ struct common_init_result common_init_from_params(common_params & params) {
llama_context * lctx = llama_init_from_model(model, cparams); llama_context * lctx = llama_init_from_model(model, cparams);
if (lctx == NULL) { if (lctx == NULL) {
LOG_ERR("%s: failed to create context with model '%s'\n", __func__, params.model.path.c_str()); LOG_ERR("%s: failed to create context with model '%s'\n", __func__, params.model.c_str());
llama_model_free(model); llama_model_free(model);
return iparams; return iparams;
} }
if (params.ctx_shift && !llama_kv_self_can_shift(lctx)) { if (params.ctx_shift && !llama_kv_cache_can_shift(lctx)) {
LOG_WRN("%s: KV cache shifting is not supported for this context, disabling KV cache shifting\n", __func__); LOG_WRN("%s: KV cache shifting is not supported for this model, disabling KV cache shifting\n", __func__);
params.ctx_shift = false; params.ctx_shift = false;
} }
@@ -988,8 +1029,6 @@ struct common_init_result common_init_from_params(common_params & params) {
if (params.warmup) { if (params.warmup) {
LOG_WRN("%s: warming up the model with an empty run - please wait ... (--no-warmup to disable)\n", __func__); LOG_WRN("%s: warming up the model with an empty run - please wait ... (--no-warmup to disable)\n", __func__);
llama_set_warmup(lctx, true);
std::vector<llama_token> tmp; std::vector<llama_token> tmp;
llama_token bos = llama_vocab_bos(vocab); llama_token bos = llama_vocab_bos(vocab);
llama_token eos = llama_vocab_eos(vocab); llama_token eos = llama_vocab_eos(vocab);
@@ -1017,10 +1056,9 @@ struct common_init_result common_init_from_params(common_params & params) {
if (llama_model_has_decoder(model)) { if (llama_model_has_decoder(model)) {
llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch))); llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch)));
} }
llama_kv_self_clear(lctx); llama_kv_cache_clear(lctx);
llama_synchronize(lctx); llama_synchronize(lctx);
llama_perf_context_reset(lctx); llama_perf_context_reset(lctx);
llama_set_warmup(lctx, false);
} }
iparams.model.reset(model); iparams.model.reset(model);
@@ -1029,19 +1067,6 @@ struct common_init_result common_init_from_params(common_params & params) {
return iparams; return iparams;
} }
std::string get_model_endpoint() {
const char * model_endpoint_env = getenv("MODEL_ENDPOINT");
// We still respect the use of environment-variable "HF_ENDPOINT" for backward-compatibility.
const char * hf_endpoint_env = getenv("HF_ENDPOINT");
const char * endpoint_env = model_endpoint_env ? model_endpoint_env : hf_endpoint_env;
std::string model_endpoint = "https://huggingface.co/";
if (endpoint_env) {
model_endpoint = endpoint_env;
if (model_endpoint.back() != '/') model_endpoint += '/';
}
return model_endpoint;
}
void common_set_adapter_lora(struct llama_context * ctx, std::vector<common_adapter_lora_info> & lora) { void common_set_adapter_lora(struct llama_context * ctx, std::vector<common_adapter_lora_info> & lora) {
llama_clear_adapter_lora(ctx); llama_clear_adapter_lora(ctx);
for (auto & la : lora) { for (auto & la : lora) {
@@ -1057,18 +1082,15 @@ struct llama_model_params common_model_params_to_llama(common_params & params) {
if (!params.devices.empty()) { if (!params.devices.empty()) {
mparams.devices = params.devices.data(); mparams.devices = params.devices.data();
} }
if (params.n_gpu_layers != -1) { if (params.n_gpu_layers != -1) {
mparams.n_gpu_layers = params.n_gpu_layers; mparams.n_gpu_layers = params.n_gpu_layers;
} }
mparams.main_gpu = params.main_gpu; mparams.main_gpu = params.main_gpu;
mparams.split_mode = params.split_mode; mparams.split_mode = params.split_mode;
mparams.tensor_split = params.tensor_split; mparams.tensor_split = params.tensor_split;
mparams.use_mmap = params.use_mmap; mparams.use_mmap = params.use_mmap;
mparams.use_mlock = params.use_mlock; mparams.use_mlock = params.use_mlock;
mparams.check_tensors = params.check_tensors; mparams.check_tensors = params.check_tensors;
if (params.kv_overrides.empty()) { if (params.kv_overrides.empty()) {
mparams.kv_overrides = NULL; mparams.kv_overrides = NULL;
} else { } else {
@@ -1076,13 +1098,6 @@ struct llama_model_params common_model_params_to_llama(common_params & params) {
mparams.kv_overrides = params.kv_overrides.data(); mparams.kv_overrides = params.kv_overrides.data();
} }
if (params.tensor_buft_overrides.empty()) {
mparams.tensor_buft_overrides = NULL;
} else {
GGML_ASSERT(params.tensor_buft_overrides.back().pattern == nullptr && "Tensor buffer overrides not terminated with empty pattern");
mparams.tensor_buft_overrides = params.tensor_buft_overrides.data();
}
return mparams; return mparams;
} }
@@ -1096,6 +1111,7 @@ struct llama_context_params common_context_params_to_llama(const common_params &
cparams.n_threads = params.cpuparams.n_threads; cparams.n_threads = params.cpuparams.n_threads;
cparams.n_threads_batch = params.cpuparams_batch.n_threads == -1 ? cparams.n_threads_batch = params.cpuparams_batch.n_threads == -1 ?
params.cpuparams.n_threads : params.cpuparams_batch.n_threads; params.cpuparams.n_threads : params.cpuparams_batch.n_threads;
cparams.logits_all = params.logits_all;
cparams.embeddings = params.embedding; cparams.embeddings = params.embedding;
cparams.rope_scaling_type = params.rope_scaling_type; cparams.rope_scaling_type = params.rope_scaling_type;
cparams.rope_freq_base = params.rope_freq_base; cparams.rope_freq_base = params.rope_freq_base;
@@ -1113,7 +1129,6 @@ struct llama_context_params common_context_params_to_llama(const common_params &
cparams.offload_kqv = !params.no_kv_offload; cparams.offload_kqv = !params.no_kv_offload;
cparams.flash_attn = params.flash_attn; cparams.flash_attn = params.flash_attn;
cparams.no_perf = params.no_perf; cparams.no_perf = params.no_perf;
cparams.op_offload = !params.no_op_offload;
if (params.reranking) { if (params.reranking) {
cparams.embeddings = true; cparams.embeddings = true;
@@ -1142,6 +1157,451 @@ struct ggml_threadpool_params ggml_threadpool_params_from_cpu_params(const cpu_p
return tpp; return tpp;
} }
#ifdef LLAMA_USE_CURL
#define CURL_MAX_RETRY 3
#define CURL_RETRY_DELAY_SECONDS 2
static bool curl_perform_with_retry(const std::string & url, CURL * curl, int max_attempts, int retry_delay_seconds) {
int remaining_attempts = max_attempts;
while (remaining_attempts > 0) {
LOG_INF("%s: Trying to download from %s (attempt %d of %d)...\n", __func__ , url.c_str(), max_attempts - remaining_attempts + 1, max_attempts);
CURLcode res = curl_easy_perform(curl);
if (res == CURLE_OK) {
return true;
}
int exponential_backoff_delay = std::pow(retry_delay_seconds, max_attempts - remaining_attempts) * 1000;
LOG_WRN("%s: curl_easy_perform() failed: %s, retrying after %d milliseconds...\n", __func__, curl_easy_strerror(res), exponential_backoff_delay);
remaining_attempts--;
std::this_thread::sleep_for(std::chrono::milliseconds(exponential_backoff_delay));
}
LOG_ERR("%s: curl_easy_perform() failed after %d attempts\n", __func__, max_attempts);
return false;
}
static bool common_download_file(const std::string & url, const std::string & path, const std::string & hf_token) {
// Initialize libcurl
curl_ptr curl(curl_easy_init(), &curl_easy_cleanup);
curl_slist_ptr http_headers;
if (!curl) {
LOG_ERR("%s: error initializing libcurl\n", __func__);
return false;
}
bool force_download = false;
// Set the URL, allow to follow http redirection
curl_easy_setopt(curl.get(), CURLOPT_URL, url.c_str());
curl_easy_setopt(curl.get(), CURLOPT_FOLLOWLOCATION, 1L);
// Check if hf-token or bearer-token was specified
if (!hf_token.empty()) {
std::string auth_header = "Authorization: Bearer " + hf_token;
http_headers.ptr = curl_slist_append(http_headers.ptr, auth_header.c_str());
curl_easy_setopt(curl.get(), CURLOPT_HTTPHEADER, http_headers.ptr);
}
#if defined(_WIN32)
// CURLSSLOPT_NATIVE_CA tells libcurl to use standard certificate store of
// operating system. Currently implemented under MS-Windows.
curl_easy_setopt(curl.get(), CURLOPT_SSL_OPTIONS, CURLSSLOPT_NATIVE_CA);
#endif
// Check if the file already exists locally
auto file_exists = std::filesystem::exists(path);
// If the file exists, check its JSON metadata companion file.
std::string metadata_path = path + ".json";
nlohmann::json metadata;
std::string etag;
std::string last_modified;
if (file_exists) {
// Try and read the JSON metadata file (note: stream autoclosed upon exiting this block).
std::ifstream metadata_in(metadata_path);
if (metadata_in.good()) {
try {
metadata_in >> metadata;
LOG_INF("%s: previous metadata file found %s: %s\n", __func__, metadata_path.c_str(), metadata.dump().c_str());
if (metadata.contains("url") && metadata.at("url").is_string()) {
auto previous_url = metadata.at("url").get<std::string>();
if (previous_url != url) {
LOG_ERR("%s: Model URL mismatch: %s != %s\n", __func__, url.c_str(), previous_url.c_str());
return false;
}
}
if (metadata.contains("etag") && metadata.at("etag").is_string()) {
etag = metadata.at("etag");
}
if (metadata.contains("lastModified") && metadata.at("lastModified").is_string()) {
last_modified = metadata.at("lastModified");
}
} catch (const nlohmann::json::exception & e) {
LOG_ERR("%s: error reading metadata file %s: %s\n", __func__, metadata_path.c_str(), e.what());
return false;
}
}
} else {
LOG_INF("%s: no previous model file found %s\n", __func__, path.c_str());
}
// Send a HEAD request to retrieve the etag and last-modified headers
struct common_load_model_from_url_headers {
std::string etag;
std::string last_modified;
};
common_load_model_from_url_headers headers;
{
typedef size_t(*CURLOPT_HEADERFUNCTION_PTR)(char *, size_t, size_t, void *);
auto header_callback = [](char * buffer, size_t /*size*/, size_t n_items, void * userdata) -> size_t {
common_load_model_from_url_headers * headers = (common_load_model_from_url_headers *) userdata;
static std::regex header_regex("([^:]+): (.*)\r\n");
static std::regex etag_regex("ETag", std::regex_constants::icase);
static std::regex last_modified_regex("Last-Modified", std::regex_constants::icase);
std::string header(buffer, n_items);
std::smatch match;
if (std::regex_match(header, match, header_regex)) {
const std::string & key = match[1];
const std::string & value = match[2];
if (std::regex_match(key, match, etag_regex)) {
headers->etag = value;
} else if (std::regex_match(key, match, last_modified_regex)) {
headers->last_modified = value;
}
}
return n_items;
};
curl_easy_setopt(curl.get(), CURLOPT_NOBODY, 1L); // will trigger the HEAD verb
curl_easy_setopt(curl.get(), CURLOPT_NOPROGRESS, 1L); // hide head request progress
curl_easy_setopt(curl.get(), CURLOPT_HEADERFUNCTION, static_cast<CURLOPT_HEADERFUNCTION_PTR>(header_callback));
curl_easy_setopt(curl.get(), CURLOPT_HEADERDATA, &headers);
bool was_perform_successful = curl_perform_with_retry(url, curl.get(), CURL_MAX_RETRY, CURL_RETRY_DELAY_SECONDS);
if (!was_perform_successful) {
return false;
}
long http_code = 0;
curl_easy_getinfo(curl.get(), CURLINFO_RESPONSE_CODE, &http_code);
if (http_code != 200) {
// HEAD not supported, we don't know if the file has changed
// force trigger downloading
force_download = true;
LOG_ERR("%s: HEAD invalid http status code received: %ld\n", __func__, http_code);
}
}
bool should_download = !file_exists || force_download;
if (!should_download) {
if (!etag.empty() && etag != headers.etag) {
LOG_WRN("%s: ETag header is different (%s != %s): triggering a new download\n", __func__, etag.c_str(), headers.etag.c_str());
should_download = true;
} else if (!last_modified.empty() && last_modified != headers.last_modified) {
LOG_WRN("%s: Last-Modified header is different (%s != %s): triggering a new download\n", __func__, last_modified.c_str(), headers.last_modified.c_str());
should_download = true;
}
}
if (should_download) {
std::string path_temporary = path + ".downloadInProgress";
if (file_exists) {
LOG_WRN("%s: deleting previous downloaded file: %s\n", __func__, path.c_str());
if (remove(path.c_str()) != 0) {
LOG_ERR("%s: unable to delete file: %s\n", __func__, path.c_str());
return false;
}
}
// Set the output file
struct FILE_deleter {
void operator()(FILE * f) const {
fclose(f);
}
};
std::unique_ptr<FILE, FILE_deleter> outfile(fopen(path_temporary.c_str(), "wb"));
if (!outfile) {
LOG_ERR("%s: error opening local file for writing: %s\n", __func__, path.c_str());
return false;
}
typedef size_t(*CURLOPT_WRITEFUNCTION_PTR)(void * data, size_t size, size_t nmemb, void * fd);
auto write_callback = [](void * data, size_t size, size_t nmemb, void * fd) -> size_t {
return fwrite(data, size, nmemb, (FILE *)fd);
};
curl_easy_setopt(curl.get(), CURLOPT_NOBODY, 0L);
curl_easy_setopt(curl.get(), CURLOPT_WRITEFUNCTION, static_cast<CURLOPT_WRITEFUNCTION_PTR>(write_callback));
curl_easy_setopt(curl.get(), CURLOPT_WRITEDATA, outfile.get());
// display download progress
curl_easy_setopt(curl.get(), CURLOPT_NOPROGRESS, 0L);
// helper function to hide password in URL
auto llama_download_hide_password_in_url = [](const std::string & url) -> std::string {
std::size_t protocol_pos = url.find("://");
if (protocol_pos == std::string::npos) {
return url; // Malformed URL
}
std::size_t at_pos = url.find('@', protocol_pos + 3);
if (at_pos == std::string::npos) {
return url; // No password in URL
}
return url.substr(0, protocol_pos + 3) + "********" + url.substr(at_pos);
};
// start the download
LOG_INF("%s: trying to download model from %s to %s (server_etag:%s, server_last_modified:%s)...\n", __func__,
llama_download_hide_password_in_url(url).c_str(), path.c_str(), headers.etag.c_str(), headers.last_modified.c_str());
bool was_perform_successful = curl_perform_with_retry(url, curl.get(), CURL_MAX_RETRY, CURL_RETRY_DELAY_SECONDS);
if (!was_perform_successful) {
return false;
}
long http_code = 0;
curl_easy_getinfo (curl.get(), CURLINFO_RESPONSE_CODE, &http_code);
if (http_code < 200 || http_code >= 400) {
LOG_ERR("%s: invalid http status code received: %ld\n", __func__, http_code);
return false;
}
// Causes file to be closed explicitly here before we rename it.
outfile.reset();
// Write the updated JSON metadata file.
metadata.update({
{"url", url},
{"etag", headers.etag},
{"lastModified", headers.last_modified}
});
std::ofstream(metadata_path) << metadata.dump(4);
LOG_INF("%s: file metadata saved: %s\n", __func__, metadata_path.c_str());
if (rename(path_temporary.c_str(), path.c_str()) != 0) {
LOG_ERR("%s: unable to rename file: %s to %s\n", __func__, path_temporary.c_str(), path.c_str());
return false;
}
}
return true;
}
struct llama_model * common_load_model_from_url(
const std::string & model_url,
const std::string & local_path,
const std::string & hf_token,
const struct llama_model_params & params) {
// Basic validation of the model_url
if (model_url.empty()) {
LOG_ERR("%s: invalid model_url\n", __func__);
return NULL;
}
if (!common_download_file(model_url, local_path, hf_token)) {
return NULL;
}
// check for additional GGUFs split to download
int n_split = 0;
{
struct gguf_init_params gguf_params = {
/*.no_alloc = */ true,
/*.ctx = */ NULL,
};
auto * ctx_gguf = gguf_init_from_file(local_path.c_str(), gguf_params);
if (!ctx_gguf) {
LOG_ERR("\n%s: failed to load input GGUF from %s\n", __func__, local_path.c_str());
return NULL;
}
auto key_n_split = gguf_find_key(ctx_gguf, LLM_KV_SPLIT_COUNT);
if (key_n_split >= 0) {
n_split = gguf_get_val_u16(ctx_gguf, key_n_split);
}
gguf_free(ctx_gguf);
}
if (n_split > 1) {
char split_prefix[PATH_MAX] = {0};
char split_url_prefix[LLAMA_CURL_MAX_URL_LENGTH] = {0};
// Verify the first split file format
// and extract split URL and PATH prefixes
{
if (!llama_split_prefix(split_prefix, sizeof(split_prefix), local_path.c_str(), 0, n_split)) {
LOG_ERR("\n%s: unexpected model file name: %s n_split=%d\n", __func__, local_path.c_str(), n_split);
return NULL;
}
if (!llama_split_prefix(split_url_prefix, sizeof(split_url_prefix), model_url.c_str(), 0, n_split)) {
LOG_ERR("\n%s: unexpected model url: %s n_split=%d\n", __func__, model_url.c_str(), n_split);
return NULL;
}
}
// Prepare download in parallel
std::vector<std::future<bool>> futures_download;
for (int idx = 1; idx < n_split; idx++) {
futures_download.push_back(std::async(std::launch::async, [&split_prefix, &split_url_prefix, &n_split, hf_token](int download_idx) -> bool {
char split_path[PATH_MAX] = {0};
llama_split_path(split_path, sizeof(split_path), split_prefix, download_idx, n_split);
char split_url[LLAMA_CURL_MAX_URL_LENGTH] = {0};
llama_split_path(split_url, sizeof(split_url), split_url_prefix, download_idx, n_split);
return common_download_file(split_url, split_path, hf_token);
}, idx));
}
// Wait for all downloads to complete
for (auto & f : futures_download) {
if (!f.get()) {
return NULL;
}
}
}
return llama_model_load_from_file(local_path.c_str(), params);
}
struct llama_model * common_load_model_from_hf(
const std::string & repo,
const std::string & remote_path,
const std::string & local_path,
const std::string & hf_token,
const struct llama_model_params & params) {
// construct hugging face model url:
//
// --repo ggml-org/models --file tinyllama-1.1b/ggml-model-f16.gguf
// https://huggingface.co/ggml-org/models/resolve/main/tinyllama-1.1b/ggml-model-f16.gguf
//
// --repo TheBloke/Mixtral-8x7B-v0.1-GGUF --file mixtral-8x7b-v0.1.Q4_K_M.gguf
// https://huggingface.co/TheBloke/Mixtral-8x7B-v0.1-GGUF/resolve/main/mixtral-8x7b-v0.1.Q4_K_M.gguf
//
std::string model_url = "https://huggingface.co/";
model_url += repo;
model_url += "/resolve/main/";
model_url += remote_path;
return common_load_model_from_url(model_url, local_path, hf_token, params);
}
/**
* Allow getting the HF file from the HF repo with tag (like ollama), for example:
* - bartowski/Llama-3.2-3B-Instruct-GGUF:q4
* - bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M
* - bartowski/Llama-3.2-3B-Instruct-GGUF:q5_k_s
* Tag is optional, default to "latest" (meaning it checks for Q4_K_M first, then Q4, then if not found, return the first GGUF file in repo)
*
* Return pair of <repo, file> (with "repo" already having tag removed)
*
* Note: we use the Ollama-compatible HF API, but not using the blobId. Instead, we use the special "ggufFile" field which returns the value for "hf_file". This is done to be backward-compatible with existing cache files.
*/
std::pair<std::string, std::string> common_get_hf_file(const std::string & hf_repo_with_tag, const std::string & hf_token) {
auto parts = string_split<std::string>(hf_repo_with_tag, ':');
std::string tag = parts.size() > 1 ? parts.back() : "latest";
std::string hf_repo = parts[0];
if (string_split<std::string>(hf_repo, '/').size() != 2) {
throw std::invalid_argument("error: invalid HF repo format, expected <user>/<model>[:quant]\n");
}
// fetch model info from Hugging Face Hub API
json model_info;
curl_ptr curl(curl_easy_init(), &curl_easy_cleanup);
curl_slist_ptr http_headers;
std::string res_str;
std::string url = "https://huggingface.co/v2/" + hf_repo + "/manifests/" + tag;
curl_easy_setopt(curl.get(), CURLOPT_URL, url.c_str());
curl_easy_setopt(curl.get(), CURLOPT_NOPROGRESS, 1L);
typedef size_t(*CURLOPT_WRITEFUNCTION_PTR)(void * ptr, size_t size, size_t nmemb, void * data);
auto write_callback = [](void * ptr, size_t size, size_t nmemb, void * data) -> size_t {
static_cast<std::string *>(data)->append((char * ) ptr, size * nmemb);
return size * nmemb;
};
curl_easy_setopt(curl.get(), CURLOPT_WRITEFUNCTION, static_cast<CURLOPT_WRITEFUNCTION_PTR>(write_callback));
curl_easy_setopt(curl.get(), CURLOPT_WRITEDATA, &res_str);
#if defined(_WIN32)
curl_easy_setopt(curl.get(), CURLOPT_SSL_OPTIONS, CURLSSLOPT_NATIVE_CA);
#endif
if (!hf_token.empty()) {
std::string auth_header = "Authorization: Bearer " + hf_token;
http_headers.ptr = curl_slist_append(http_headers.ptr, auth_header.c_str());
}
// Important: the User-Agent must be "llama-cpp" to get the "ggufFile" field in the response
http_headers.ptr = curl_slist_append(http_headers.ptr, "User-Agent: llama-cpp");
http_headers.ptr = curl_slist_append(http_headers.ptr, "Accept: application/json");
curl_easy_setopt(curl.get(), CURLOPT_HTTPHEADER, http_headers.ptr);
CURLcode res = curl_easy_perform(curl.get());
if (res != CURLE_OK) {
throw std::runtime_error("error: cannot make GET request to HF API");
}
long res_code;
curl_easy_getinfo(curl.get(), CURLINFO_RESPONSE_CODE, &res_code);
if (res_code == 200) {
model_info = json::parse(res_str);
} else if (res_code == 401) {
throw std::runtime_error("error: model is private or does not exist; if you are accessing a gated model, please provide a valid HF token");
} else {
throw std::runtime_error(string_format("error from HF API, response code: %ld, data: %s", res_code, res_str.c_str()));
}
// check response
if (!model_info.contains("ggufFile")) {
throw std::runtime_error("error: model does not have ggufFile");
}
json & gguf_file = model_info.at("ggufFile");
if (!gguf_file.contains("rfilename")) {
throw std::runtime_error("error: ggufFile does not have rfilename");
}
return std::make_pair(hf_repo, gguf_file.at("rfilename"));
}
#else
struct llama_model * common_load_model_from_url(
const std::string & /*model_url*/,
const std::string & /*local_path*/,
const std::string & /*hf_token*/,
const struct llama_model_params & /*params*/) {
LOG_WRN("%s: llama.cpp built without libcurl, downloading from an url not supported.\n", __func__);
return nullptr;
}
struct llama_model * common_load_model_from_hf(
const std::string & /*repo*/,
const std::string & /*remote_path*/,
const std::string & /*local_path*/,
const std::string & /*hf_token*/,
const struct llama_model_params & /*params*/) {
LOG_WRN("%s: llama.cpp built without libcurl, downloading from Hugging Face not supported.\n", __func__);
return nullptr;
}
std::pair<std::string, std::string> common_get_hf_file(const std::string &, const std::string &) {
LOG_WRN("%s: llama.cpp built without libcurl, downloading from Hugging Face not supported.\n", __func__);
return std::make_pair("", "");
}
#endif // LLAMA_USE_CURL
// //
// Batch utils // Batch utils
// //
@@ -1566,19 +2026,3 @@ common_control_vector_data common_control_vector_load(const std::vector<common_c
return result; return result;
} }
ggml_opt_dataset_t common_opt_dataset_init(struct llama_context * ctx, const std::vector<llama_token> & tokens, int64_t stride) {
const int64_t ne_datapoint = llama_n_ctx(ctx);
const int64_t ndata = (tokens.size() - ne_datapoint - 1) / stride;
ggml_opt_dataset_t result = ggml_opt_dataset_init(
GGML_TYPE_I32, GGML_TYPE_I32, ne_datapoint, ne_datapoint, ndata, /*ndata_shard =*/ 1);
llama_token * data = (llama_token *) ggml_opt_dataset_data(result)->data;
llama_token * labels = (llama_token *) ggml_opt_dataset_labels(result)->data;
for (int64_t idata = 0; idata < ndata; ++idata) {
memcpy(data + idata*ne_datapoint, tokens.data() + idata*stride + 0, ne_datapoint*sizeof(llama_token));
memcpy(labels + idata*ne_datapoint, tokens.data() + idata*stride + 1, ne_datapoint*sizeof(llama_token));
}
return result;
}

View File

@@ -66,6 +66,7 @@ enum llama_example {
LLAMA_EXAMPLE_COMMON, LLAMA_EXAMPLE_COMMON,
LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SPECULATIVE,
LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_MAIN,
LLAMA_EXAMPLE_INFILL,
LLAMA_EXAMPLE_EMBEDDING, LLAMA_EXAMPLE_EMBEDDING,
LLAMA_EXAMPLE_PERPLEXITY, LLAMA_EXAMPLE_PERPLEXITY,
LLAMA_EXAMPLE_RETRIEVAL, LLAMA_EXAMPLE_RETRIEVAL,
@@ -95,7 +96,6 @@ enum common_sampler_type {
COMMON_SAMPLER_TYPE_XTC = 8, COMMON_SAMPLER_TYPE_XTC = 8,
COMMON_SAMPLER_TYPE_INFILL = 9, COMMON_SAMPLER_TYPE_INFILL = 9,
COMMON_SAMPLER_TYPE_PENALTIES = 10, COMMON_SAMPLER_TYPE_PENALTIES = 10,
COMMON_SAMPLER_TYPE_TOP_N_SIGMA = 11,
}; };
// dimensionality reduction methods, used by cvector-generator // dimensionality reduction methods, used by cvector-generator
@@ -110,17 +110,9 @@ enum common_conversation_mode {
COMMON_CONVERSATION_MODE_AUTO = 2, COMMON_CONVERSATION_MODE_AUTO = 2,
}; };
enum common_grammar_trigger_type {
COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN,
COMMON_GRAMMAR_TRIGGER_TYPE_WORD,
COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN,
COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START,
};
struct common_grammar_trigger { struct common_grammar_trigger {
common_grammar_trigger_type type; std::string word;
std::string value; bool at_start;
llama_token token = LLAMA_TOKEN_NULL;
}; };
// sampling parameters // sampling parameters
@@ -161,7 +153,6 @@ struct common_params_sampling {
std::vector<enum common_sampler_type> samplers = { std::vector<enum common_sampler_type> samplers = {
COMMON_SAMPLER_TYPE_PENALTIES, COMMON_SAMPLER_TYPE_PENALTIES,
COMMON_SAMPLER_TYPE_DRY, COMMON_SAMPLER_TYPE_DRY,
COMMON_SAMPLER_TYPE_TOP_N_SIGMA,
COMMON_SAMPLER_TYPE_TOP_K, COMMON_SAMPLER_TYPE_TOP_K,
COMMON_SAMPLER_TYPE_TYPICAL_P, COMMON_SAMPLER_TYPE_TYPICAL_P,
COMMON_SAMPLER_TYPE_TOP_P, COMMON_SAMPLER_TYPE_TOP_P,
@@ -172,7 +163,8 @@ struct common_params_sampling {
std::string grammar; // optional BNF-like grammar to constrain sampling std::string grammar; // optional BNF-like grammar to constrain sampling
bool grammar_lazy = false; bool grammar_lazy = false;
std::vector<common_grammar_trigger> grammar_triggers; // optional triggers (for lazy grammars) std::vector<common_grammar_trigger> grammar_trigger_words; // optional trigger words to trigger lazy grammar
std::vector<llama_token> grammar_trigger_tokens; // optional trigger tokens to trigger lazy grammar and print trigger special tokens.
std::set<llama_token> preserved_tokens; std::set<llama_token> preserved_tokens;
std::vector<llama_logit_bias> logit_bias; // logit biases to apply std::vector<llama_logit_bias> logit_bias; // logit biases to apply
@@ -181,13 +173,6 @@ struct common_params_sampling {
std::string print() const; std::string print() const;
}; };
struct common_params_model {
std::string path = ""; // model local path // NOLINT
std::string url = ""; // model url to download // NOLINT
std::string hf_repo = ""; // HF repo // NOLINT
std::string hf_file = ""; // HF file // NOLINT
};
struct common_params_speculative { struct common_params_speculative {
std::vector<ggml_backend_dev_t> devices; // devices to use for offloading std::vector<ggml_backend_dev_t> devices; // devices to use for offloading
@@ -201,13 +186,19 @@ struct common_params_speculative {
struct cpu_params cpuparams; struct cpu_params cpuparams;
struct cpu_params cpuparams_batch; struct cpu_params cpuparams_batch;
struct common_params_model model; std::string hf_repo = ""; // HF repo // NOLINT
std::string hf_file = ""; // HF file // NOLINT
std::string model = ""; // draft model for speculative decoding // NOLINT
std::string model_url = ""; // model url to download // NOLINT
}; };
struct common_params_vocoder { struct common_params_vocoder {
struct common_params_model model; std::string hf_repo = ""; // HF repo // NOLINT
std::string hf_file = ""; // HF file // NOLINT
std::string speaker_file = ""; // speaker file path // NOLINT std::string model = ""; // model path // NOLINT
std::string model_url = ""; // model url to download // NOLINT
bool use_guide_tokens = false; // enable guide tokens to improve TTS accuracy // NOLINT bool use_guide_tokens = false; // enable guide tokens to improve TTS accuracy // NOLINT
}; };
@@ -263,12 +254,13 @@ struct common_params {
struct common_params_speculative speculative; struct common_params_speculative speculative;
struct common_params_vocoder vocoder; struct common_params_vocoder vocoder;
struct common_params_model model; std::string model = ""; // model path // NOLINT
std::string model_alias = ""; // model alias // NOLINT std::string model_alias = ""; // model alias // NOLINT
std::string model_url = ""; // model url to download // NOLINT
std::string hf_token = ""; // HF token // NOLINT std::string hf_token = ""; // HF token // NOLINT
std::string hf_repo = ""; // HF repo // NOLINT
std::string hf_file = ""; // HF file // NOLINT
std::string prompt = ""; // NOLINT std::string prompt = ""; // NOLINT
std::string system_prompt = ""; // NOLINT
std::string prompt_file = ""; // store the external prompt file name // NOLINT std::string prompt_file = ""; // store the external prompt file name // NOLINT
std::string path_prompt_cache = ""; // path to file for saving/loading prompt eval state // NOLINT std::string path_prompt_cache = ""; // path to file for saving/loading prompt eval state // NOLINT
std::string input_prefix = ""; // string to prefix user inputs with // NOLINT std::string input_prefix = ""; // string to prefix user inputs with // NOLINT
@@ -280,7 +272,6 @@ struct common_params {
std::vector<std::string> in_files; // all input files std::vector<std::string> in_files; // all input files
std::vector<std::string> antiprompt; // strings upon which more user input is prompted (a.k.a. reverse prompts) std::vector<std::string> antiprompt; // strings upon which more user input is prompted (a.k.a. reverse prompts)
std::vector<llama_model_kv_override> kv_overrides; std::vector<llama_model_kv_override> kv_overrides;
std::vector<llama_model_tensor_buft_override> tensor_buft_overrides;
bool lora_init_without_apply = false; // only load lora to memory, but do not apply it to ctx (user can manually apply lora later using llama_adapter_lora_apply) bool lora_init_without_apply = false; // only load lora to memory, but do not apply it to ctx (user can manually apply lora later using llama_adapter_lora_apply)
std::vector<common_adapter_lora_info> lora_adapters; // lora adapter path with user defined scale std::vector<common_adapter_lora_info> lora_adapters; // lora adapter path with user defined scale
@@ -324,6 +315,7 @@ struct common_params {
bool ctx_shift = true; // context shift on inifinite text generation bool ctx_shift = true; // context shift on inifinite text generation
bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix
bool logits_all = false; // return logits for all tokens in the batch
bool use_mmap = true; // use mmap for faster loads bool use_mmap = true; // use mmap for faster loads
bool use_mlock = false; // use mlock to keep model in memory bool use_mlock = false; // use mlock to keep model in memory
bool verbose_prompt = false; // print prompt tokens before generation bool verbose_prompt = false; // print prompt tokens before generation
@@ -332,19 +324,14 @@ struct common_params {
bool no_kv_offload = false; // disable KV offloading bool no_kv_offload = false; // disable KV offloading
bool warmup = true; // warmup run bool warmup = true; // warmup run
bool check_tensors = false; // validate tensor data bool check_tensors = false; // validate tensor data
bool no_op_offload = false; // globally disable offload host tensor operations to device
bool single_turn = false; // single turn chat conversation
ggml_type cache_type_k = GGML_TYPE_F16; // KV cache data type for the K ggml_type cache_type_k = GGML_TYPE_F16; // KV cache data type for the K
ggml_type cache_type_v = GGML_TYPE_F16; // KV cache data type for the V ggml_type cache_type_v = GGML_TYPE_F16; // KV cache data type for the V
common_conversation_mode conversation_mode = COMMON_CONVERSATION_MODE_AUTO; common_conversation_mode conversation_mode = COMMON_CONVERSATION_MODE_AUTO;
// multimodal models (see tools/mtmd) // multimodal models (see examples/llava)
struct common_params_model mmproj; std::string mmproj = ""; // path to multimodal projector // NOLINT
bool mmproj_use_gpu = true; // use GPU for multimodal model
bool no_mmproj = false; // explicitly disable multimodal model
std::vector<std::string> image; // path to image file(s) std::vector<std::string> image; // path to image file(s)
// embedding // embedding
@@ -404,28 +391,29 @@ struct common_params {
int32_t i_pos = -1; // position of the passkey in the junk text int32_t i_pos = -1; // position of the passkey in the junk text
// imatrix params // imatrix params
std::string out_file = "imatrix.dat"; // save the resulting imatrix to this file
int32_t n_out_freq = 10; // output the imatrix every n_out_freq iterations int32_t n_out_freq = 10; // output the imatrix every n_out_freq iterations
int32_t n_save_freq = 0; // save the imatrix every n_save_freq iterations int32_t n_save_freq = 0; // save the imatrix every n_save_freq iterations
int32_t i_chunk = 0; // start processing from this chunk int32_t i_chunk = 0; // start processing from this chunk
bool process_output = false; // collect data for the output tensor bool process_output = false; // collect data for the output tensor
bool compute_ppl = true; // whether to compute perplexity bool compute_ppl = true; // whether to compute perplexity
bool parse_special = false; // whether to parse special tokens during imatrix tokenization
// cvector-generator params // cvector-generator params
int n_pca_batch = 100; int n_pca_batch = 100;
int n_pca_iterations = 1000; int n_pca_iterations = 1000;
dimre_method cvector_dimre_method = DIMRE_METHOD_PCA; dimre_method cvector_dimre_method = DIMRE_METHOD_PCA;
std::string cvector_positive_file = "tools/cvector-generator/positive.txt"; std::string cvector_outfile = "control_vector.gguf";
std::string cvector_negative_file = "tools/cvector-generator/negative.txt"; std::string cvector_positive_file = "examples/cvector-generator/positive.txt";
std::string cvector_negative_file = "examples/cvector-generator/negative.txt";
bool spm_infill = false; // suffix/prefix/middle pattern for infill bool spm_infill = false; // suffix/prefix/middle pattern for infill
std::string lora_outfile = "ggml-lora-merged-f16.gguf";
// batched-bench params // batched-bench params
bool batched_bench_output_jsonl = false; bool batched_bench_output_jsonl = false;
// common params
std::string out_file; // output filename for all example programs
}; };
// call once at the start of a program if it uses libcommon // call once at the start of a program if it uses libcommon
@@ -465,8 +453,6 @@ std::string string_repeat(const std::string & str, size_t n);
void string_replace_all(std::string & s, const std::string & search, const std::string & replace); void string_replace_all(std::string & s, const std::string & search, const std::string & replace);
std::string regex_escape(const std::string & s);
template<class T> template<class T>
static std::vector<T> string_split(const std::string & str, char delim) { static std::vector<T> string_split(const std::string & str, char delim) {
static_assert(!std::is_same<T, std::string>::value, "Please use the specialized version for std::string"); static_assert(!std::is_same<T, std::string>::value, "Please use the specialized version for std::string");
@@ -544,11 +530,26 @@ struct llama_model_params common_model_params_to_llama ( common_params
struct llama_context_params common_context_params_to_llama(const common_params & params); struct llama_context_params common_context_params_to_llama(const common_params & params);
struct ggml_threadpool_params ggml_threadpool_params_from_cpu_params(const cpu_params & params); struct ggml_threadpool_params ggml_threadpool_params_from_cpu_params(const cpu_params & params);
struct llama_model * common_load_model_from_url(
const std::string & model_url,
const std::string & local_path,
const std::string & hf_token,
const struct llama_model_params & params);
struct llama_model * common_load_model_from_hf(
const std::string & repo,
const std::string & remote_path,
const std::string & local_path,
const std::string & hf_token,
const struct llama_model_params & params);
std::pair<std::string, std::string> common_get_hf_file(
const std::string & hf_repo_with_tag,
const std::string & hf_token);
// clear LoRA adapters from context, then apply new list of adapters // clear LoRA adapters from context, then apply new list of adapters
void common_set_adapter_lora(struct llama_context * ctx, std::vector<common_adapter_lora_info> & lora); void common_set_adapter_lora(struct llama_context * ctx, std::vector<common_adapter_lora_info> & lora);
std::string get_model_endpoint();
// //
// Batch utils // Batch utils
// //
@@ -666,9 +667,3 @@ const char * const LLM_KV_SPLIT_COUNT = "split.count";
const char * const LLM_KV_SPLIT_TENSORS_COUNT = "split.tensors.count"; const char * const LLM_KV_SPLIT_TENSORS_COUNT = "split.tensors.count";
} }
//
// training utils
//
ggml_opt_dataset_t common_opt_dataset_init(struct llama_context * ctx, const std::vector<llama_token> & tokens, int64_t stride);

View File

@@ -16,9 +16,6 @@ using json = nlohmann::ordered_json;
static std::string build_repetition(const std::string & item_rule, int min_items, int max_items, const std::string & separator_rule = "") { static std::string build_repetition(const std::string & item_rule, int min_items, int max_items, const std::string & separator_rule = "") {
auto has_max = max_items != std::numeric_limits<int>::max(); auto has_max = max_items != std::numeric_limits<int>::max();
if (max_items == 0) {
return "";
}
if (min_items == 0 && max_items == 1) { if (min_items == 0 && max_items == 1) {
return item_rule + "?"; return item_rule + "?";
} }
@@ -267,7 +264,7 @@ static void _build_min_max_int(int min_value, int max_value, std::stringstream &
throw std::runtime_error("At least one of min_value or max_value must be set"); throw std::runtime_error("At least one of min_value or max_value must be set");
} }
const std::string SPACE_RULE = "| \" \" | \"\\n\"{1,2} [ \\t]{0,20}"; const std::string SPACE_RULE = "| \" \" | \"\\n\" [ \\t]{0,20}";
struct BuiltinRule { struct BuiltinRule {
std::string content; std::string content;
@@ -767,10 +764,11 @@ private:
public: public:
SchemaConverter( SchemaConverter(
const std::function<json(const std::string &)> & fetch_json, const std::function<json(const std::string &)> & fetch_json,
bool dotall) bool dotall,
bool compact_spaces)
: _fetch_json(fetch_json), _dotall(dotall) : _fetch_json(fetch_json), _dotall(dotall)
{ {
_rules["space"] = SPACE_RULE; _rules["space"] = compact_spaces ? "\" \"?" : SPACE_RULE;
} }
void resolve_refs(json & schema, const std::string & url) { void resolve_refs(json & schema, const std::string & url) {
@@ -1009,7 +1007,7 @@ std::string json_schema_to_grammar(const json & schema, bool force_gbnf) {
} }
std::string build_grammar(const std::function<void(const common_grammar_builder &)> & cb, const common_grammar_options & options) { std::string build_grammar(const std::function<void(const common_grammar_builder &)> & cb, const common_grammar_options & options) {
SchemaConverter converter([&](const std::string &) { return json(); }, options.dotall); SchemaConverter converter([&](const std::string &) { return json(); }, options.dotall, options.compact_spaces);
common_grammar_builder builder { common_grammar_builder builder {
/* .add_rule = */ [&](const std::string & name, const std::string & rule) { /* .add_rule = */ [&](const std::string & name, const std::string & rule) {
return converter._add_rule(name, rule); return converter._add_rule(name, rule);

View File

@@ -16,6 +16,7 @@ struct common_grammar_builder {
struct common_grammar_options { struct common_grammar_options {
bool dotall = false; bool dotall = false;
bool compact_spaces = false;
}; };
std::string build_grammar(const std::function<void(const common_grammar_builder &)> & cb, const common_grammar_options & options = {}); std::string build_grammar(const std::function<void(const common_grammar_builder &)> & cb, const common_grammar_options & options = {});

View File

@@ -1,11 +1,9 @@
#include "sampling.h" #include "sampling.h"
#include "common.h" #include "common.h"
#include "log.h"
#include <cmath> #include <cmath>
#include <unordered_map> #include <unordered_map>
#include <algorithm>
// the ring buffer works similarly to std::deque, but with a fixed capacity // the ring buffer works similarly to std::deque, but with a fixed capacity
// TODO: deduplicate with llama-impl.h // TODO: deduplicate with llama-impl.h
@@ -161,57 +159,17 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
GGML_ABORT("llguidance (cmake -DLLAMA_LLGUIDANCE=ON) is not enabled"); GGML_ABORT("llguidance (cmake -DLLAMA_LLGUIDANCE=ON) is not enabled");
#endif // LLAMA_USE_LLGUIDANCE #endif // LLAMA_USE_LLGUIDANCE
} else { } else {
std::vector<std::string> patterns_at_start; std::vector<const char *> trigger_words;
std::vector<std::string> patterns_anywhere; trigger_words.reserve(params.grammar_trigger_words.size());
std::vector<llama_token> trigger_tokens; for (const auto & str : params.grammar_trigger_words) {
for (const auto & trigger : params.grammar_triggers) { trigger_words.push_back(str.word.c_str());
switch (trigger.type) {
case COMMON_GRAMMAR_TRIGGER_TYPE_WORD:
{
const auto & word = trigger.value;
patterns_anywhere.push_back(regex_escape(word));
break;
}
case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN:
case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START:
{
const auto & pattern = trigger.value;
(trigger.type == COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START ? patterns_at_start : patterns_anywhere).push_back(pattern);
break;
}
case COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN:
{
const auto token = trigger.token;
trigger_tokens.push_back(token);
break;
}
default:
GGML_ASSERT(false && "unknown trigger type");
}
}
std::vector<std::string> trigger_patterns;
if (!patterns_at_start.empty()) {
trigger_patterns.push_back("^(" + string_join(patterns_at_start, "|") + ")[\\s\\S]*");
}
if (!patterns_anywhere.empty()) {
trigger_patterns.push_back("^[\\s\\S]*?(" + string_join(patterns_anywhere, "|") + ")[\\s\\S]*");
}
std::vector<const char *> trigger_patterns_c;
trigger_patterns_c.reserve(trigger_patterns.size());
for (const auto & regex : trigger_patterns) {
trigger_patterns_c.push_back(regex.c_str());
} }
grmr = params.grammar_lazy grmr = params.grammar_lazy
? llama_sampler_init_grammar_lazy_patterns(vocab, params.grammar.c_str(), "root", ? llama_sampler_init_grammar_lazy(vocab, params.grammar.c_str(), "root",
trigger_patterns_c.data(), trigger_patterns_c.size(), trigger_words.data(), trigger_words.size(),
trigger_tokens.data(), trigger_tokens.size()) params.grammar_trigger_tokens.data(), params.grammar_trigger_tokens.size())
: llama_sampler_init_grammar(vocab, params.grammar.c_str(), "root"); : llama_sampler_init_grammar(vocab, params.grammar.c_str(), "root");
if (!grmr) {
return nullptr;
}
} }
auto * result = new common_sampler { auto * result = new common_sampler {
@@ -230,48 +188,51 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
params.logit_bias.data())); params.logit_bias.data()));
if (params.mirostat == 0) { if (params.mirostat == 0) {
for (const auto & cnstr : params.samplers) { if (params.top_n_sigma >= 0) {
switch (cnstr) { llama_sampler_chain_add(result->chain, llama_sampler_init_top_k (params.top_k));
case COMMON_SAMPLER_TYPE_DRY: llama_sampler_chain_add(result->chain, llama_sampler_init_temp (params.temp));
{ llama_sampler_chain_add(result->chain, llama_sampler_init_top_n_sigma (params.top_n_sigma));
std::vector<const char *> c_breakers; } else {
c_breakers.reserve(params.dry_sequence_breakers.size()); for (const auto & cnstr : params.samplers) {
for (const auto & str : params.dry_sequence_breakers) { switch (cnstr) {
c_breakers.push_back(str.c_str()); case COMMON_SAMPLER_TYPE_DRY:
} {
std::vector<const char *> c_breakers;
c_breakers.reserve(params.dry_sequence_breakers.size());
for (const auto & str : params.dry_sequence_breakers) {
c_breakers.push_back(str.c_str());
}
llama_sampler_chain_add(result->chain, llama_sampler_init_dry (vocab, llama_model_n_ctx_train(model), params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size())); llama_sampler_chain_add(result->chain, llama_sampler_init_dry (vocab, llama_model_n_ctx_train(model), params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size()));
} }
break; break;
case COMMON_SAMPLER_TYPE_TOP_K: case COMMON_SAMPLER_TYPE_TOP_K:
llama_sampler_chain_add(result->chain, llama_sampler_init_top_k (params.top_k)); llama_sampler_chain_add(result->chain, llama_sampler_init_top_k (params.top_k));
break; break;
case COMMON_SAMPLER_TYPE_TOP_P: case COMMON_SAMPLER_TYPE_TOP_P:
llama_sampler_chain_add(result->chain, llama_sampler_init_top_p (params.top_p, params.min_keep)); llama_sampler_chain_add(result->chain, llama_sampler_init_top_p (params.top_p, params.min_keep));
break; break;
case COMMON_SAMPLER_TYPE_TOP_N_SIGMA: case COMMON_SAMPLER_TYPE_MIN_P:
llama_sampler_chain_add(result->chain, llama_sampler_init_top_n_sigma (params.top_n_sigma)); llama_sampler_chain_add(result->chain, llama_sampler_init_min_p (params.min_p, params.min_keep));
break; break;
case COMMON_SAMPLER_TYPE_MIN_P: case COMMON_SAMPLER_TYPE_XTC:
llama_sampler_chain_add(result->chain, llama_sampler_init_min_p (params.min_p, params.min_keep)); llama_sampler_chain_add(result->chain, llama_sampler_init_xtc (params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed));
break; break;
case COMMON_SAMPLER_TYPE_XTC: case COMMON_SAMPLER_TYPE_TYPICAL_P:
llama_sampler_chain_add(result->chain, llama_sampler_init_xtc (params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed)); llama_sampler_chain_add(result->chain, llama_sampler_init_typical (params.typ_p, params.min_keep));
break; break;
case COMMON_SAMPLER_TYPE_TYPICAL_P: case COMMON_SAMPLER_TYPE_TEMPERATURE:
llama_sampler_chain_add(result->chain, llama_sampler_init_typical (params.typ_p, params.min_keep)); llama_sampler_chain_add(result->chain, llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent));
break; break;
case COMMON_SAMPLER_TYPE_TEMPERATURE: case COMMON_SAMPLER_TYPE_INFILL:
llama_sampler_chain_add(result->chain, llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent)); llama_sampler_chain_add(result->chain, llama_sampler_init_infill (vocab));
break; break;
case COMMON_SAMPLER_TYPE_INFILL: case COMMON_SAMPLER_TYPE_PENALTIES:
llama_sampler_chain_add(result->chain, llama_sampler_init_infill (vocab)); llama_sampler_chain_add(result->chain, llama_sampler_init_penalties(params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present));
break; break;
case COMMON_SAMPLER_TYPE_PENALTIES: default:
llama_sampler_chain_add(result->chain, llama_sampler_init_penalties (params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present)); GGML_ASSERT(false && "unknown sampler type");
break; }
default:
GGML_ASSERT(false && "unknown sampler type");
} }
} }
llama_sampler_chain_add(result->chain, llama_sampler_init_dist(params.seed)); llama_sampler_chain_add(result->chain, llama_sampler_init_dist(params.seed));
@@ -473,7 +434,6 @@ char common_sampler_type_to_chr(enum common_sampler_type cnstr) {
case COMMON_SAMPLER_TYPE_TOP_K: return 'k'; case COMMON_SAMPLER_TYPE_TOP_K: return 'k';
case COMMON_SAMPLER_TYPE_TYPICAL_P: return 'y'; case COMMON_SAMPLER_TYPE_TYPICAL_P: return 'y';
case COMMON_SAMPLER_TYPE_TOP_P: return 'p'; case COMMON_SAMPLER_TYPE_TOP_P: return 'p';
case COMMON_SAMPLER_TYPE_TOP_N_SIGMA: return 's';
case COMMON_SAMPLER_TYPE_MIN_P: return 'm'; case COMMON_SAMPLER_TYPE_MIN_P: return 'm';
case COMMON_SAMPLER_TYPE_TEMPERATURE: return 't'; case COMMON_SAMPLER_TYPE_TEMPERATURE: return 't';
case COMMON_SAMPLER_TYPE_XTC: return 'x'; case COMMON_SAMPLER_TYPE_XTC: return 'x';
@@ -489,7 +449,6 @@ std::string common_sampler_type_to_str(enum common_sampler_type cnstr) {
case COMMON_SAMPLER_TYPE_TOP_K: return "top_k"; case COMMON_SAMPLER_TYPE_TOP_K: return "top_k";
case COMMON_SAMPLER_TYPE_TYPICAL_P: return "typ_p"; case COMMON_SAMPLER_TYPE_TYPICAL_P: return "typ_p";
case COMMON_SAMPLER_TYPE_TOP_P: return "top_p"; case COMMON_SAMPLER_TYPE_TOP_P: return "top_p";
case COMMON_SAMPLER_TYPE_TOP_N_SIGMA: return "top_n_sigma";
case COMMON_SAMPLER_TYPE_MIN_P: return "min_p"; case COMMON_SAMPLER_TYPE_MIN_P: return "min_p";
case COMMON_SAMPLER_TYPE_TEMPERATURE: return "temperature"; case COMMON_SAMPLER_TYPE_TEMPERATURE: return "temperature";
case COMMON_SAMPLER_TYPE_XTC: return "xtc"; case COMMON_SAMPLER_TYPE_XTC: return "xtc";
@@ -504,7 +463,6 @@ std::vector<common_sampler_type> common_sampler_types_from_names(const std::vect
{ "dry", COMMON_SAMPLER_TYPE_DRY }, { "dry", COMMON_SAMPLER_TYPE_DRY },
{ "top_k", COMMON_SAMPLER_TYPE_TOP_K }, { "top_k", COMMON_SAMPLER_TYPE_TOP_K },
{ "top_p", COMMON_SAMPLER_TYPE_TOP_P }, { "top_p", COMMON_SAMPLER_TYPE_TOP_P },
{ "top_n_sigma", COMMON_SAMPLER_TYPE_TOP_N_SIGMA },
{ "typ_p", COMMON_SAMPLER_TYPE_TYPICAL_P }, { "typ_p", COMMON_SAMPLER_TYPE_TYPICAL_P },
{ "min_p", COMMON_SAMPLER_TYPE_MIN_P }, { "min_p", COMMON_SAMPLER_TYPE_MIN_P },
{ "temperature", COMMON_SAMPLER_TYPE_TEMPERATURE }, { "temperature", COMMON_SAMPLER_TYPE_TEMPERATURE },
@@ -518,7 +476,6 @@ std::vector<common_sampler_type> common_sampler_types_from_names(const std::vect
std::unordered_map<std::string, common_sampler_type> sampler_alt_name_map { std::unordered_map<std::string, common_sampler_type> sampler_alt_name_map {
{ "top-k", COMMON_SAMPLER_TYPE_TOP_K }, { "top-k", COMMON_SAMPLER_TYPE_TOP_K },
{ "top-p", COMMON_SAMPLER_TYPE_TOP_P }, { "top-p", COMMON_SAMPLER_TYPE_TOP_P },
{ "top-n-sigma", COMMON_SAMPLER_TYPE_TOP_N_SIGMA },
{ "nucleus", COMMON_SAMPLER_TYPE_TOP_P }, { "nucleus", COMMON_SAMPLER_TYPE_TOP_P },
{ "typical-p", COMMON_SAMPLER_TYPE_TYPICAL_P }, { "typical-p", COMMON_SAMPLER_TYPE_TYPICAL_P },
{ "typical", COMMON_SAMPLER_TYPE_TYPICAL_P }, { "typical", COMMON_SAMPLER_TYPE_TYPICAL_P },
@@ -535,16 +492,14 @@ std::vector<common_sampler_type> common_sampler_types_from_names(const std::vect
auto sampler = sampler_canonical_name_map.find(name); auto sampler = sampler_canonical_name_map.find(name);
if (sampler != sampler_canonical_name_map.end()) { if (sampler != sampler_canonical_name_map.end()) {
samplers.push_back(sampler->second); samplers.push_back(sampler->second);
continue; } else {
} if (allow_alt_names) {
if (allow_alt_names) { sampler = sampler_alt_name_map.find(name);
sampler = sampler_alt_name_map.find(name); if (sampler != sampler_alt_name_map.end()) {
if (sampler != sampler_alt_name_map.end()) { samplers.push_back(sampler->second);
samplers.push_back(sampler->second); }
continue;
} }
} }
LOG_WRN("%s: unable to match sampler by name '%s'\n", __func__, name.c_str());
} }
return samplers; return samplers;
@@ -556,7 +511,6 @@ std::vector<common_sampler_type> common_sampler_types_from_chars(const std::stri
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TOP_K), COMMON_SAMPLER_TYPE_TOP_K }, { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TOP_K), COMMON_SAMPLER_TYPE_TOP_K },
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TYPICAL_P), COMMON_SAMPLER_TYPE_TYPICAL_P }, { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TYPICAL_P), COMMON_SAMPLER_TYPE_TYPICAL_P },
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TOP_P), COMMON_SAMPLER_TYPE_TOP_P }, { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TOP_P), COMMON_SAMPLER_TYPE_TOP_P },
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TOP_N_SIGMA), COMMON_SAMPLER_TYPE_TOP_N_SIGMA },
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_MIN_P), COMMON_SAMPLER_TYPE_MIN_P }, { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_MIN_P), COMMON_SAMPLER_TYPE_MIN_P },
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TEMPERATURE), COMMON_SAMPLER_TYPE_TEMPERATURE }, { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TEMPERATURE), COMMON_SAMPLER_TYPE_TEMPERATURE },
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_XTC), COMMON_SAMPLER_TYPE_XTC }, { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_XTC), COMMON_SAMPLER_TYPE_XTC },
@@ -571,8 +525,6 @@ std::vector<common_sampler_type> common_sampler_types_from_chars(const std::stri
const auto sampler = sampler_name_map.find(c); const auto sampler = sampler_name_map.find(c);
if (sampler != sampler_name_map.end()) { if (sampler != sampler_name_map.end()) {
samplers.push_back(sampler->second); samplers.push_back(sampler->second);
} else {
LOG_WRN("%s: unable to match sampler by char '%c'\n", __func__, c);
} }
} }

3032
llama/llama.cpp/examples/llava/clip.cpp vendored Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -1,7 +1,6 @@
#ifndef CLIP_H #ifndef CLIP_H
#define CLIP_H #define CLIP_H
#include "ggml.h"
#include <stddef.h> #include <stddef.h>
#include <stdint.h> #include <stdint.h>
@@ -30,28 +29,27 @@ struct clip_image_size {
int height; int height;
}; };
struct clip_image_f32; struct clip_image_u8_batch {
struct clip_image_u8_batch; struct clip_image_u8 * data;
struct clip_image_f32_batch; size_t size;
struct clip_context_params {
bool use_gpu;
enum ggml_log_level verbosity;
}; };
// deprecated, use clip_init struct clip_image_f32_batch {
CLIP_API struct clip_ctx * clip_model_load(const char * fname, int verbosity); struct clip_image_f32 * data;
size_t size;
};
CLIP_API struct clip_ctx * clip_init(const char * fname, struct clip_context_params ctx_params); CLIP_API struct clip_ctx * clip_model_load (const char * fname, int verbosity);
CLIP_API struct clip_ctx * clip_model_load_cpu(const char * fname, int verbosity);
CLIP_API void clip_free(struct clip_ctx * ctx); CLIP_API void clip_free(struct clip_ctx * ctx);
CLIP_API size_t clip_embd_nbytes(const struct clip_ctx * ctx); CLIP_API size_t clip_embd_nbytes(const struct clip_ctx * ctx);
CLIP_API size_t clip_embd_nbytes_by_img(const struct clip_ctx * ctx, int img_w, int img_h); CLIP_API size_t clip_embd_nbytes_by_img(const struct clip_ctx * ctx, int img_h, int img_w);
CLIP_API int32_t clip_get_image_size (const struct clip_ctx * ctx); CLIP_API int32_t clip_image_size (const struct clip_ctx * ctx);
CLIP_API int32_t clip_get_patch_size (const struct clip_ctx * ctx); CLIP_API int32_t clip_patch_size (const struct clip_ctx * ctx);
CLIP_API int32_t clip_get_hidden_size(const struct clip_ctx * ctx); CLIP_API int32_t clip_hidden_size(const struct clip_ctx * ctx);
// TODO: should be enum, not string // TODO: should be enum, not string
CLIP_API const char * clip_patch_merge_type(const struct clip_ctx * ctx); CLIP_API const char * clip_patch_merge_type(const struct clip_ctx * ctx);
@@ -59,49 +57,24 @@ CLIP_API const char * clip_patch_merge_type(const struct clip_ctx * ctx);
CLIP_API const int32_t * clip_image_grid(const struct clip_ctx * ctx); CLIP_API const int32_t * clip_image_grid(const struct clip_ctx * ctx);
CLIP_API size_t get_clip_image_grid_size(const struct clip_ctx * ctx); CLIP_API size_t get_clip_image_grid_size(const struct clip_ctx * ctx);
GGML_DEPRECATED(CLIP_API int clip_n_patches(const struct clip_ctx * ctx), CLIP_API int clip_n_patches (const struct clip_ctx * ctx);
"use clip_n_output_tokens instead"); CLIP_API int clip_n_patches_by_img (const struct clip_ctx * ctx, struct clip_image_f32 * img);
GGML_DEPRECATED(CLIP_API int clip_n_patches_by_img(const struct clip_ctx * ctx, struct clip_image_f32 * img), CLIP_API int clip_n_mmproj_embd (const struct clip_ctx * ctx);
"use clip_n_output_tokens instead");
CLIP_API int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * img);
// for M-RoPE, this will be the number of token positions in X and Y directions
// for other models, X will be the total number of tokens and Y will be 1
CLIP_API int clip_n_output_tokens_x(const struct clip_ctx * ctx, struct clip_image_f32 * img);
CLIP_API int clip_n_output_tokens_y(const struct clip_ctx * ctx, struct clip_image_f32 * img);
// this should be equal to the embedding dimension of the text model
CLIP_API int clip_n_mmproj_embd(const struct clip_ctx * ctx);
CLIP_API int clip_uhd_num_image_embeds_col(struct clip_ctx * ctx_clip); CLIP_API int clip_uhd_num_image_embeds_col(struct clip_ctx * ctx_clip);
CLIP_API void clip_add_load_image_size(struct clip_ctx * ctx_clip, struct clip_image_size * load_image_size); CLIP_API void clip_add_load_image_size(struct clip_ctx * ctx_clip, struct clip_image_size * load_image_size);
CLIP_API struct clip_image_size * clip_get_load_image_size(struct clip_ctx * ctx_clip); CLIP_API struct clip_image_size * clip_get_load_image_size(struct clip_ctx * ctx_clip);
CLIP_API struct clip_image_size * clip_image_size_init(void); CLIP_API struct clip_image_size * clip_image_size_init();
CLIP_API struct clip_image_u8 * clip_image_u8_init (void); CLIP_API struct clip_image_u8 * clip_image_u8_init ();
CLIP_API struct clip_image_f32 * clip_image_f32_init(void); CLIP_API struct clip_image_f32 * clip_image_f32_init();
CLIP_API struct clip_image_f32_batch * clip_image_f32_batch_init(void); // only used by libllava
// nx, ny are the output image dimensions
CLIP_API unsigned char * clip_image_u8_get_data(struct clip_image_u8 * img, uint32_t * nx, uint32_t * ny);
CLIP_API void clip_image_size_free (struct clip_image_size * img_size);
CLIP_API void clip_image_u8_free (struct clip_image_u8 * img); CLIP_API void clip_image_u8_free (struct clip_image_u8 * img);
CLIP_API void clip_image_f32_free(struct clip_image_f32 * img); CLIP_API void clip_image_f32_free(struct clip_image_f32 * img);
CLIP_API void clip_image_u8_batch_free (struct clip_image_u8_batch * batch); CLIP_API void clip_image_u8_batch_free (struct clip_image_u8_batch * batch);
CLIP_API void clip_image_f32_batch_free(struct clip_image_f32_batch * batch); CLIP_API void clip_image_f32_batch_free(struct clip_image_f32_batch * batch);
// use for accessing underlay data of clip_image_f32_batch /** build image from pixels decoded by other libraries instead of stb_image.h for better performance. The memory layout is RGBRGBRGB..., input buffer length must be 3*nx*ny bytes */
CLIP_API size_t clip_image_f32_batch_n_images(const struct clip_image_f32_batch * batch); // equivalent to batch->size()
CLIP_API size_t clip_image_f32_batch_nx(const struct clip_image_f32_batch * batch, int idx); // equivalent to batch[idx]->nx
CLIP_API size_t clip_image_f32_batch_ny(const struct clip_image_f32_batch * batch, int idx); // equivalent to batch[idx]->ny
CLIP_API struct clip_image_f32 * clip_image_f32_get_img(const struct clip_image_f32_batch * batch, int idx); // equivalent to batch[idx]->data
/**
* Build image from pixels decoded by other libraries instead of stb_image.h for better performance.
* The memory layout is RGBRGBRGB..., input buffer length must be 3*nx*ny bytes
*/
CLIP_API void clip_build_img_from_pixels(const unsigned char * rgb_pixels, int nx, int ny, struct clip_image_u8 * img); CLIP_API void clip_build_img_from_pixels(const unsigned char * rgb_pixels, int nx, int ny, struct clip_image_u8 * img);
CLIP_API bool clip_image_load_from_file(const char * fname, struct clip_image_u8 * img); CLIP_API bool clip_image_load_from_file(const char * fname, struct clip_image_u8 * img);
@@ -122,8 +95,8 @@ CLIP_API bool clip_model_quantize(const char * fname_inp, const char * fname_out
CLIP_API int clip_is_minicpmv(const struct clip_ctx * ctx); CLIP_API int clip_is_minicpmv(const struct clip_ctx * ctx);
CLIP_API bool clip_is_glm(const struct clip_ctx * ctx); CLIP_API bool clip_is_glm(const struct clip_ctx * ctx);
CLIP_API bool clip_is_qwen2vl(const struct clip_ctx * ctx); CLIP_API bool clip_is_qwen2vl(const struct clip_ctx * ctx);
CLIP_API bool clip_is_llava(const struct clip_ctx * ctx);
CLIP_API bool clip_is_gemma3(const struct clip_ctx * ctx); CLIP_API int get_deepest_feature_layer(const struct clip_ctx * ctx);
CLIP_API bool clip_encode_float_image (struct clip_ctx * ctx, int n_threads, float * img, int h, int w, float * vec); CLIP_API bool clip_encode_float_image (struct clip_ctx * ctx, int n_threads, float * img, int h, int w, float * vec);

Some files were not shown because too many files have changed in this diff Show More