Compare commits

..

1 Commits

Author SHA1 Message Date
Bruce MacDonald
715952705e model: framework for testing forward pass 2025-05-08 09:25:12 -07:00
324 changed files with 12621 additions and 150172 deletions

View File

@@ -23,7 +23,7 @@ jobs:
echo GOFLAGS="'-ldflags=-w -s \"-X=github.com/ollama/ollama/version.Version=${GITHUB_REF_NAME#v}\" \"-X=github.com/ollama/ollama/server.mode=release\"'" >>$GITHUB_OUTPUT echo GOFLAGS="'-ldflags=-w -s \"-X=github.com/ollama/ollama/version.Version=${GITHUB_REF_NAME#v}\" \"-X=github.com/ollama/ollama/server.mode=release\"'" >>$GITHUB_OUTPUT
darwin-build: darwin-build:
runs-on: macos-13-xlarge runs-on: macos-13
environment: release environment: release
needs: setup-environment needs: setup-environment
strategy: strategy:
@@ -54,6 +54,48 @@ jobs:
name: build-${{ matrix.os }}-${{ matrix.arch }} name: build-${{ matrix.os }}-${{ matrix.arch }}
path: dist/* path: dist/*
darwin-sign:
runs-on: macos-13
environment: release
needs: darwin-build
steps:
- uses: actions/checkout@v4
- run: |
echo $MACOS_SIGNING_KEY | base64 --decode > certificate.p12
security create-keychain -p password build.keychain
security default-keychain -s build.keychain
security unlock-keychain -p password build.keychain
security import certificate.p12 -k build.keychain -P $MACOS_SIGNING_KEY_PASSWORD -T /usr/bin/codesign
security set-key-partition-list -S apple-tool:,apple:,codesign: -s -k password build.keychain
security set-keychain-settings -lut 3600 build.keychain
env:
MACOS_SIGNING_KEY: ${{ secrets.MACOS_SIGNING_KEY }}
MACOS_SIGNING_KEY_PASSWORD: ${{ secrets.MACOS_SIGNING_KEY_PASSWORD }}
- uses: actions/download-artifact@v4
with:
name: build-darwin-amd64
path: dist/darwin-amd64
- uses: actions/download-artifact@v4
with:
name: build-darwin-arm64
path: dist/darwin-arm64
- run: |
export VERSION=${GITHUB_REF_NAME#v}
./scripts/build_darwin.sh sign macapp
env:
APPLE_IDENTITY: ${{ secrets.APPLE_IDENTITY }}
APPLE_PASSWORD: ${{ secrets.APPLE_PASSWORD }}
APPLE_TEAM_ID: ${{ vars.APPLE_TEAM_ID }}
APPLE_ID: ${{ vars.APPLE_ID }}
SDKROOT: /Applications/Xcode_14.1.0.app/Contents/Developer/Platforms/MacOSX.platform/Developer/SDKs/MacOSX.sdk
DEVELOPER_DIR: /Applications/Xcode_14.1.0.app/Contents/Developer
- uses: actions/upload-artifact@v4
with:
name: dist-darwin
path: |
dist/Ollama-darwin.zip
dist/ollama-darwin.tgz
windows-depends: windows-depends:
strategy: strategy:
matrix: matrix:
@@ -66,13 +108,11 @@ jobs:
preset: 'CUDA 12' preset: 'CUDA 12'
install: https://developer.download.nvidia.com/compute/cuda/12.8.0/local_installers/cuda_12.8.0_571.96_windows.exe install: https://developer.download.nvidia.com/compute/cuda/12.8.0/local_installers/cuda_12.8.0_571.96_windows.exe
cuda-version: '12.8' cuda-version: '12.8'
flags: ''
- os: windows - os: windows
arch: amd64 arch: amd64
preset: 'ROCm 6' preset: 'ROCm 6'
install: https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-24.Q4-WinSvr2022-For-HIP.exe install: https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-24.Q4-WinSvr2022-For-HIP.exe
rocm-version: '6.2' rocm-version: '6.2'
flags: '-DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_C_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma" -DCMAKE_CXX_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma"'
runs-on: ${{ matrix.arch == 'arm64' && format('{0}-{1}', matrix.os, matrix.arch) || matrix.os }} runs-on: ${{ matrix.arch == 'arm64' && format('{0}-{1}', matrix.os, matrix.arch) || matrix.os }}
environment: release environment: release
env: env:
@@ -115,9 +155,6 @@ jobs:
echo "$hipPath\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append echo "$hipPath\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
echo "CC=$hipPath\bin\clang.exe" | Out-File -FilePath $env:GITHUB_ENV -Append echo "CC=$hipPath\bin\clang.exe" | Out-File -FilePath $env:GITHUB_ENV -Append
echo "CXX=$hipPath\bin\clang++.exe" | Out-File -FilePath $env:GITHUB_ENV -Append echo "CXX=$hipPath\bin\clang++.exe" | Out-File -FilePath $env:GITHUB_ENV -Append
echo "HIPCXX=$hipPath\bin\clang++.exe" | Out-File -FilePath $env:GITHUB_ENV -Append
echo "HIP_PLATFORM=amd" | Out-File -FilePath $env:GITHUB_ENV -Append
echo "CMAKE_PREFIX_PATH=$hipPath" | Out-File -FilePath $env:GITHUB_ENV -Append
- if: matrix.preset == 'CPU' - if: matrix.preset == 'CPU'
run: | run: |
echo "CC=clang.exe" | Out-File -FilePath $env:GITHUB_ENV -Append echo "CC=clang.exe" | Out-File -FilePath $env:GITHUB_ENV -Append
@@ -136,9 +173,9 @@ jobs:
key: ccache-${{ matrix.os }}-${{ matrix.arch }}-${{ matrix.preset }} key: ccache-${{ matrix.os }}-${{ matrix.arch }}-${{ matrix.preset }}
- name: Build target "${{ matrix.preset }}" - name: Build target "${{ matrix.preset }}"
run: | run: |
Import-Module 'C:\Program Files\Microsoft Visual Studio\2022\Enterprise\Common7\Tools\Microsoft.VisualStudio.DevShell.dll' Import-Module 'C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\Common7\Tools\Microsoft.VisualStudio.DevShell.dll'
Enter-VsDevShell -VsInstallPath 'C:\Program Files\Microsoft Visual Studio\2022\Enterprise' -SkipAutomaticLocation -DevCmdArguments '-arch=x64 -no_logo' Enter-VsDevShell -VsInstallPath 'C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise' -SkipAutomaticLocation -DevCmdArguments '-arch=x64 -no_logo'
cmake --preset "${{ matrix.preset }}" ${{ matrix.flags }} cmake --preset "${{ matrix.preset }}"
cmake --build --parallel --preset "${{ matrix.preset }}" cmake --build --parallel --preset "${{ matrix.preset }}"
cmake --install build --component "${{ startsWith(matrix.preset, 'CUDA ') && 'CUDA' || startsWith(matrix.preset, 'ROCm ') && 'HIP' || 'CPU' }}" --strip --parallel 8 cmake --install build --component "${{ startsWith(matrix.preset, 'CUDA ') && 'CUDA' || startsWith(matrix.preset, 'ROCm ') && 'HIP' || 'CPU' }}" --strip --parallel 8
env: env:
@@ -188,11 +225,61 @@ jobs:
go-version-file: go.mod go-version-file: go.mod
- run: | - run: |
go build -o dist/${{ matrix.os }}-${{ matrix.arch }}/ . go build -o dist/${{ matrix.os }}-${{ matrix.arch }}/ .
- if: matrix.arch == 'arm64'
run: |
Invoke-WebRequest -Uri "https://aka.ms/vs/17/release/vc_redist.arm64.exe" -OutFile "dist\windows-arm64\vc_redist.arm64.exe"
- run: |
$env:VERSION='${{ github.ref_name }}' -Replace "v(.*)", '$1'
& .\scripts\build_windows.ps1 buildApp
env:
VCToolsRedistDir: stub
- uses: actions/upload-artifact@v4 - uses: actions/upload-artifact@v4
with: with:
name: build-${{ matrix.os }}-${{ matrix.arch }} name: build-${{ matrix.os }}-${{ matrix.arch }}
path: | path: |
dist\${{ matrix.os }}-${{ matrix.arch }}\*.exe dist\${{ matrix.os }}-${{ matrix.arch }}\*.exe
dist\${{ matrix.os }}-${{ matrix.arch }}-app.exe
windows-sign:
runs-on: windows-2022
environment: release
needs: [windows-depends, windows-build]
steps:
- uses: actions/checkout@v4
- uses: google-github-actions/auth@v2
with:
project_id: ollama
credentials_json: ${{ secrets.GOOGLE_SIGNING_CREDENTIALS }}
- run: |
$ErrorActionPreference = "Stop"
Invoke-WebRequest -Uri "https://go.microsoft.com/fwlink/p/?LinkId=323507" -OutFile "${{ runner.temp }}\sdksetup.exe"
Start-Process "${{ runner.temp }}\sdksetup.exe" -ArgumentList @("/q") -NoNewWindow -Wait
Invoke-WebRequest -Uri "https://github.com/GoogleCloudPlatform/kms-integrations/releases/download/cng-v1.0/kmscng-1.0-windows-amd64.zip" -OutFile "${{ runner.temp }}\plugin.zip"
Expand-Archive -Path "${{ runner.temp }}\plugin.zip" -DestinationPath "${{ runner.temp }}\plugin\"
& "${{ runner.temp }}\plugin\*\kmscng.msi" /quiet
echo "${{ vars.OLLAMA_CERT }}" >ollama_inc.crt
- uses: actions/download-artifact@v4
with:
pattern: build-windows-*
path: dist\
merge-multiple: true
- uses: actions/download-artifact@v4
with:
pattern: depends-windows-amd64-*
path: dist\windows-amd64\
merge-multiple: true
- run: |
& .\scripts\build_windows.ps1 gatherDependencies sign buildInstaller distZip
env:
KEY_CONTAINER: ${{ vars.KEY_CONTAINER }}
- uses: actions/upload-artifact@v4
with:
name: dist-windows
path: |
dist\OllamaSetup.exe
dist\ollama-windows-*.zip
linux-build: linux-build:
strategy: strategy:
@@ -225,26 +312,20 @@ jobs:
CGO_CFLAGS=${{ env.CGO_CFLAGS }} CGO_CFLAGS=${{ env.CGO_CFLAGS }}
CGO_CXXFLAGS=${{ env.CGO_CXXFLAGS }} CGO_CXXFLAGS=${{ env.CGO_CXXFLAGS }}
outputs: type=local,dest=dist/${{ matrix.os }}-${{ matrix.arch }} outputs: type=local,dest=dist/${{ matrix.os }}-${{ matrix.arch }}
cache-from: type=registry,ref=${{ vars.DOCKER_REPO }}:latest cache-from: type=registry,ref=ollama/ollama:latest
cache-to: type=inline cache-to: type=inline
- run: | - run: |
for COMPONENT in bin/* lib/ollama/*; do for COMPONENT in bin/* lib/ollama/*; do
case "$COMPONENT" in case "$COMPONENT" in
bin/ollama) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;; bin/ollama) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
lib/ollama/*.so*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;; lib/ollama/*.so) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
lib/ollama/cuda_sbsa) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;; lib/ollama/cuda_v12) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
lib/ollama/cuda_jetpack5) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-jetpack5.tar.in ;; lib/ollama/cuda_jetpack5) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-jetpack5.tar.in ;;
lib/ollama/cuda_jetpack6) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-jetpack6.tar.in ;; lib/ollama/cuda_jetpack6) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-jetpack6.tar.in ;;
lib/ollama/rocm) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-rocm.tar.in ;; lib/ollama/rocm) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-rocm.tar.in ;;
esac esac
done done
working-directory: dist/${{ matrix.os }}-${{ matrix.arch }} working-directory: dist/${{ matrix.os }}-${{ matrix.arch }}
- run: |
echo "Manifests"
for ARCHIVE in dist/${{ matrix.os }}-${{ matrix.arch }}/*.tar.in ; do
echo $ARCHIVE
cat $ARCHIVE
done
- run: | - run: |
for ARCHIVE in dist/${{ matrix.os }}-${{ matrix.arch }}/*.tar.in; do for ARCHIVE in dist/${{ matrix.os }}-${{ matrix.arch }}/*.tar.in; do
tar c -C dist/${{ matrix.os }}-${{ matrix.arch }} -T $ARCHIVE --owner 0 --group 0 | pigz -9vc >$(basename ${ARCHIVE//.*/}.tgz); tar c -C dist/${{ matrix.os }}-${{ matrix.arch }} -T $ARCHIVE --owner 0 --group 0 | pigz -9vc >$(basename ${ARCHIVE//.*/}.tgz);
@@ -298,8 +379,8 @@ jobs:
context: . context: .
platforms: ${{ matrix.os }}/${{ matrix.arch }} platforms: ${{ matrix.os }}/${{ matrix.arch }}
build-args: ${{ matrix.build-args }} build-args: ${{ matrix.build-args }}
outputs: type=image,name=${{ vars.DOCKER_REPO }},push-by-digest=true,name-canonical=true,push=true outputs: type=image,name=ollama/ollama,push-by-digest=true,name-canonical=true,push=true
cache-from: type=registry,ref=${{ vars.DOCKER_REPO }}:latest cache-from: type=registry,ref=ollama/ollama:latest
cache-to: type=inline cache-to: type=inline
- run: | - run: |
mkdir -p ${{ matrix.os }}-${{ matrix.arch }} mkdir -p ${{ matrix.os }}-${{ matrix.arch }}
@@ -331,7 +412,7 @@ jobs:
latest=false latest=false
suffix=${{ matrix.suffix }} suffix=${{ matrix.suffix }}
images: | images: |
${{ vars.DOCKER_REPO }} ollama/ollama
tags: | tags: |
type=ref,enable=true,priority=600,prefix=pr-,event=pr type=ref,enable=true,priority=600,prefix=pr-,event=pr
type=semver,pattern={{version}} type=semver,pattern={{version}}
@@ -341,24 +422,56 @@ jobs:
path: ${{ runner.temp }} path: ${{ runner.temp }}
merge-multiple: true merge-multiple: true
- run: | - run: |
docker buildx imagetools create $(echo '${{ steps.metadata.outputs.json }}' | jq -cr '.tags | map("-t", .) | join(" ")') $(cat *-${{ matrix.suffix }}.txt | xargs printf '${{ vars.DOCKER_REPO }}@%s ') docker buildx imagetools create $(echo '${{ steps.metadata.outputs.json }}' | jq -cr '.tags | map("-t", .) | join(" ")') $(cat *-${{ matrix.suffix }}.txt | xargs printf 'ollama/ollama@%s ')
docker buildx imagetools inspect ${{ vars.DOCKER_REPO }}:${{ steps.metadata.outputs.version }} docker buildx imagetools inspect ollama/ollama:${{ steps.metadata.outputs.version }}
working-directory: ${{ runner.temp }} working-directory: ${{ runner.temp }}
# Trigger downstream release process # Trigger downstream release process
trigger: trigger:
runs-on: ubuntu-latest runs-on: ubuntu-latest
environment: release environment: release
needs: [darwin-build, windows-build, windows-depends, linux-build] needs: [darwin-build, windows-build, windows-depends]
steps:
- name: Trigger downstream release process
run: |
curl -L \
-X POST \
-H "Accept: application/vnd.github+json" \
-H "Authorization: Bearer ${{ secrets.RELEASE_TOKEN }}" \
-H "X-GitHub-Api-Version: 2022-11-28" \
https://api.github.com/repos/ollama/${{ vars.RELEASE_REPO }}/dispatches \
-d "{\"event_type\": \"trigger-workflow\", \"client_payload\": {\"run_id\": \"${GITHUB_RUN_ID}\", \"version\": \"${GITHUB_REF_NAME#v}\"}}"
# Aggregate all the assets and ship a release
release:
needs: [darwin-sign, windows-sign, linux-build]
runs-on: linux
environment: release
permissions: permissions:
contents: write contents: write
env: env:
GH_TOKEN: ${{ github.token }} GH_TOKEN: ${{ github.token }}
steps: steps:
- uses: actions/checkout@v4 - uses: actions/checkout@v4
- name: Create or update Release for tag - uses: actions/download-artifact@v4
with:
name: dist-darwin
path: dist
- uses: actions/download-artifact@v4
with:
name: dist-windows
path: dist
- uses: actions/download-artifact@v4
with:
pattern: dist-linux-*
path: dist
merge-multiple: true
- run: find . -type f -not -name 'sha256sum.txt' | xargs sha256sum | tee sha256sum.txt
working-directory: dist
- name: Create or update Release
run: | run: |
RELEASE_VERSION="$(echo ${GITHUB_REF_NAME} | cut -f1 -d-)" RELEASE_VERSION="$(echo ${GITHUB_REF_NAME} | cut -f1 -d-)"
echo "Looking for existing release for ${RELEASE_VERSION}" echo "Looking for existing release for ${RELEASE_VERSION}"
OLD_TAG=$(gh release ls --json name,tagName | jq -r ".[] | select(.name == \"${RELEASE_VERSION}\") | .tagName") OLD_TAG=$(gh release ls --json name,tagName | jq -r ".[] | select(.name == \"${RELEASE_VERSION}\") | .tagName")
if [ -n "$OLD_TAG" ]; then if [ -n "$OLD_TAG" ]; then
@@ -372,12 +485,5 @@ jobs:
--generate-notes \ --generate-notes \
--prerelease --prerelease
fi fi
- name: Trigger downstream release process echo "Uploading artifacts for tag ${GITHUB_REF_NAME}"
run: | gh release upload ${GITHUB_REF_NAME} dist/* --clobber
curl -L \
-X POST \
-H "Accept: application/vnd.github+json" \
-H "Authorization: Bearer ${{ secrets.RELEASE_TOKEN }}" \
-H "X-GitHub-Api-Version: 2022-11-28" \
https://api.github.com/repos/ollama/${{ vars.RELEASE_REPO }}/dispatches \
-d "{\"event_type\": \"trigger-workflow\", \"client_payload\": {\"run_id\": \"${GITHUB_RUN_ID}\", \"version\": \"${GITHUB_REF_NAME#v}\", \"origin\": \"${GITHUB_REPOSITORY}\", \"publish\": \"1\"}}"

View File

@@ -36,7 +36,7 @@ jobs:
| xargs python3 -c "import sys; from pathlib import Path; print(any(Path(x).match(glob) for x in sys.argv[1:] for glob in '$*'.split(' ')))" | xargs python3 -c "import sys; from pathlib import Path; print(any(Path(x).match(glob) for x in sys.argv[1:] for glob in '$*'.split(' ')))"
} }
echo changed=$(changed 'llama/llama.cpp/**/*' 'ml/backend/ggml/ggml/**/*') | tee -a $GITHUB_OUTPUT echo changed=$(changed 'llama/llama.cpp/**' 'ml/backend/ggml/ggml/**') | tee -a $GITHUB_OUTPUT
linux: linux:
needs: [changes] needs: [changes]
@@ -82,7 +82,7 @@ jobs:
flags: '-DCMAKE_CUDA_ARCHITECTURES=80' flags: '-DCMAKE_CUDA_ARCHITECTURES=80'
- preset: ROCm - preset: ROCm
install: https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-24.Q4-WinSvr2022-For-HIP.exe install: https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-24.Q4-WinSvr2022-For-HIP.exe
flags: '-DAMDGPU_TARGETS=gfx1010 -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_C_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma" -DCMAKE_CXX_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma"' flags: '-DAMDGPU_TARGETS=gfx1010'
runs-on: windows runs-on: windows
steps: steps:
- run: | - run: |
@@ -120,9 +120,6 @@ jobs:
echo "$hipPath\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append echo "$hipPath\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
echo "CC=$hipPath\bin\clang.exe" | Out-File -FilePath $env:GITHUB_ENV -Append echo "CC=$hipPath\bin\clang.exe" | Out-File -FilePath $env:GITHUB_ENV -Append
echo "CXX=$hipPath\bin\clang++.exe" | Out-File -FilePath $env:GITHUB_ENV -Append echo "CXX=$hipPath\bin\clang++.exe" | Out-File -FilePath $env:GITHUB_ENV -Append
echo "HIPCXX=$hipPath\bin\clang++.exe" | Out-File -FilePath $env:GITHUB_ENV -Append
echo "HIP_PLATFORM=amd" | Out-File -FilePath $env:GITHUB_ENV -Append
echo "CMAKE_PREFIX_PATH=$hipPath" | Out-File -FilePath $env:GITHUB_ENV -Append
- if: ${{ !cancelled() && steps.cache-install.outputs.cache-hit != 'true' }} - if: ${{ !cancelled() && steps.cache-install.outputs.cache-hit != 'true' }}
uses: actions/cache/save@v4 uses: actions/cache/save@v4
with: with:
@@ -136,8 +133,8 @@ jobs:
path: ${{ github.workspace }}\.ccache path: ${{ github.workspace }}\.ccache
key: ccache-${{ runner.os }}-${{ runner.arch }}-${{ matrix.preset }} key: ccache-${{ runner.os }}-${{ runner.arch }}-${{ matrix.preset }}
- run: | - run: |
Import-Module 'C:\Program Files\Microsoft Visual Studio\2022\Enterprise\Common7\Tools\Microsoft.VisualStudio.DevShell.dll' Import-Module 'C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\Common7\Tools\Microsoft.VisualStudio.DevShell.dll'
Enter-VsDevShell -VsInstallPath 'C:\Program Files\Microsoft Visual Studio\2022\Enterprise' -SkipAutomaticLocation -DevCmdArguments '-arch=x64 -no_logo' Enter-VsDevShell -VsInstallPath 'C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise' -SkipAutomaticLocation -DevCmdArguments '-arch=x64 -no_logo'
cmake --preset "${{ matrix.preset }}" ${{ matrix.flags }} cmake --preset "${{ matrix.preset }}" ${{ matrix.flags }}
cmake --build --parallel --preset "${{ matrix.preset }}" cmake --build --parallel --preset "${{ matrix.preset }}"
env: env:

1
.gitignore vendored
View File

@@ -14,3 +14,4 @@ __debug_bin*
llama/build llama/build
llama/vendor llama/vendor
/ollama /ollama
/model/testdata/forward

View File

@@ -19,8 +19,8 @@ linters:
- nolintlint - nolintlint
- nosprintfhostport - nosprintfhostport
- staticcheck - staticcheck
- tenv
- unconvert - unconvert
- usetesting
- wastedassign - wastedassign
- whitespace - whitespace
disable: disable:

View File

@@ -51,8 +51,6 @@ include_directories(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/include
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-cpu) include_directories(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-cpu)
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-cpu/amx) include_directories(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-cpu/amx)
add_compile_definitions(NDEBUG)
set(GGML_CPU ON) set(GGML_CPU ON)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src)
set_property(TARGET ggml PROPERTY EXCLUDE_FROM_ALL TRUE) set_property(TARGET ggml PROPERTY EXCLUDE_FROM_ALL TRUE)
@@ -78,13 +76,14 @@ if(CMAKE_CUDA_COMPILER)
find_package(CUDAToolkit) find_package(CUDAToolkit)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-cuda) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-cuda)
set(OLLAMA_CUDA_INSTALL_DIR ${OLLAMA_INSTALL_DIR}/cuda_v${CUDAToolkit_VERSION_MAJOR})
install(TARGETS ggml-cuda install(TARGETS ggml-cuda
RUNTIME_DEPENDENCIES RUNTIME_DEPENDENCIES
DIRECTORIES ${CUDAToolkit_BIN_DIR} ${CUDAToolkit_LIBRARY_DIR} DIRECTORIES ${CUDAToolkit_BIN_DIR} ${CUDAToolkit_LIBRARY_DIR}
PRE_INCLUDE_REGEXES cublas cublasLt cudart PRE_INCLUDE_REGEXES cublas cublasLt cudart
PRE_EXCLUDE_REGEXES ".*" PRE_EXCLUDE_REGEXES ".*"
RUNTIME DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT CUDA RUNTIME DESTINATION ${OLLAMA_CUDA_INSTALL_DIR} COMPONENT CUDA
LIBRARY DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT CUDA LIBRARY DESTINATION ${OLLAMA_CUDA_INSTALL_DIR} COMPONENT CUDA
) )
endif() endif()
@@ -115,11 +114,7 @@ if(CMAKE_HIP_COMPILER)
set(OLLAMA_HIP_INSTALL_DIR ${OLLAMA_INSTALL_DIR}/rocm) set(OLLAMA_HIP_INSTALL_DIR ${OLLAMA_INSTALL_DIR}/rocm)
install(TARGETS ggml-hip install(TARGETS ggml-hip
RUNTIME_DEPENDENCY_SET rocm RUNTIME_DEPENDENCIES
RUNTIME DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT HIP
LIBRARY DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT HIP
)
install(RUNTIME_DEPENDENCY_SET rocm
DIRECTORIES ${HIP_BIN_INSTALL_DIR} ${HIP_LIB_INSTALL_DIR} DIRECTORIES ${HIP_BIN_INSTALL_DIR} ${HIP_LIB_INSTALL_DIR}
PRE_INCLUDE_REGEXES hipblas rocblas amdhip64 rocsolver amd_comgr hsa-runtime64 rocsparse tinfo rocprofiler-register drm drm_amdgpu numa elf PRE_INCLUDE_REGEXES hipblas rocblas amdhip64 rocsolver amd_comgr hsa-runtime64 rocsparse tinfo rocprofiler-register drm drm_amdgpu numa elf
PRE_EXCLUDE_REGEXES ".*" PRE_EXCLUDE_REGEXES ".*"

View File

@@ -22,7 +22,7 @@
"inherits": [ "CUDA" ], "inherits": [ "CUDA" ],
"cacheVariables": { "cacheVariables": {
"CMAKE_CUDA_ARCHITECTURES": "50;60;61;70;75;80;86;87;89;90;90a;120", "CMAKE_CUDA_ARCHITECTURES": "50;60;61;70;75;80;86;87;89;90;90a;120",
"CMAKE_CUDA_FLAGS": "-Wno-deprecated-gpu-targets -t 2" "CMAKE_CUDA_FLAGS": "-Wno-deprecated-gpu-targets"
} }
}, },
{ {
@@ -50,7 +50,6 @@
"name": "ROCm 6", "name": "ROCm 6",
"inherits": [ "ROCm" ], "inherits": [ "ROCm" ],
"cacheVariables": { "cacheVariables": {
"CMAKE_HIP_FLAGS": "-parallel-jobs=4",
"AMDGPU_TARGETS": "gfx900;gfx940;gfx941;gfx942;gfx1010;gfx1012;gfx1030;gfx1100;gfx1101;gfx1102;gfx1151;gfx1200;gfx1201;gfx906:xnack-;gfx908:xnack-;gfx90a:xnack+;gfx90a:xnack-" "AMDGPU_TARGETS": "gfx900;gfx940;gfx941;gfx942;gfx1010;gfx1012;gfx1030;gfx1100;gfx1101;gfx1102;gfx1151;gfx1200;gfx1201;gfx906:xnack-;gfx908:xnack-;gfx90a:xnack+;gfx90a:xnack-"
} }
} }

View File

@@ -65,7 +65,7 @@ continuation of the sentence:
Examples: Examples:
llm/backend/mlx: support the llama architecture llm/backend/mlx: support the llama architecture
CONTRIBUTING: provide clarity on good commit messages, and bad CONTRIBUTING: provide clairity on good commit messages, and bad
Bad Examples: Bad Examples:

View File

@@ -7,15 +7,10 @@ ARG JETPACK5VERSION=r35.4.1
ARG JETPACK6VERSION=r36.4.0 ARG JETPACK6VERSION=r36.4.0
ARG CMAKEVERSION=3.31.2 ARG CMAKEVERSION=3.31.2
# We require gcc v10 minimum. v10.3 has regressions, so the rockylinux 8.5 AppStream has the latest compatible version
FROM --platform=linux/amd64 rocm/dev-almalinux-8:${ROCMVERSION}-complete AS base-amd64 FROM --platform=linux/amd64 rocm/dev-almalinux-8:${ROCMVERSION}-complete AS base-amd64
RUN yum install -y yum-utils \ RUN yum install -y yum-utils \
&& yum-config-manager --add-repo https://dl.rockylinux.org/vault/rocky/8.5/AppStream/\$basearch/os/ \
&& rpm --import https://dl.rockylinux.org/pub/rocky/RPM-GPG-KEY-Rocky-8 \
&& dnf install -y yum-utils ccache gcc-toolset-10-gcc-10.2.1-8.2.el8 gcc-toolset-10-gcc-c++-10.2.1-8.2.el8 gcc-toolset-10-binutils-2.35-11.el8 \
&& dnf install -y ccache \ && dnf install -y ccache \
&& yum-config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel8/x86_64/cuda-rhel8.repo && yum-config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel8/x86_64/cuda-rhel8.repo
ENV PATH=/opt/rh/gcc-toolset-10/root/usr/bin:$PATH
FROM --platform=linux/arm64 almalinux:8 AS base-arm64 FROM --platform=linux/arm64 almalinux:8 AS base-arm64
# install epel-release for ccache # install epel-release for ccache
@@ -90,21 +85,21 @@ RUN --mount=type=cache,target=/root/.cache/go-build \
go build -trimpath -buildmode=pie -o /bin/ollama . go build -trimpath -buildmode=pie -o /bin/ollama .
FROM --platform=linux/amd64 scratch AS amd64 FROM --platform=linux/amd64 scratch AS amd64
COPY --from=cuda-12 dist/lib/ollama /lib/ollama COPY --from=cuda-12 dist/lib/ollama/cuda_v12 /lib/ollama/cuda_v12
FROM --platform=linux/arm64 scratch AS arm64 FROM --platform=linux/arm64 scratch AS arm64
COPY --from=cuda-12 dist/lib/ollama /lib/ollama/cuda_sbsa COPY --from=cuda-12 dist/lib/ollama/cuda_v12 /lib/ollama/cuda_v12
COPY --from=jetpack-5 dist/lib/ollama /lib/ollama/cuda_jetpack5 COPY --from=jetpack-5 dist/lib/ollama/cuda_v11 /lib/ollama/cuda_jetpack5
COPY --from=jetpack-6 dist/lib/ollama /lib/ollama/cuda_jetpack6 COPY --from=jetpack-6 dist/lib/ollama/cuda_v12 /lib/ollama/cuda_jetpack6
FROM scratch AS rocm FROM scratch AS rocm
COPY --from=rocm-6 dist/lib/ollama /lib/ollama COPY --from=rocm-6 dist/lib/ollama/rocm /lib/ollama/rocm
FROM ${FLAVOR} AS archive FROM ${FLAVOR} AS archive
COPY --from=cpu dist/lib/ollama /lib/ollama COPY --from=cpu dist/lib/ollama /lib/ollama
COPY --from=build /bin/ollama /bin/ollama COPY --from=build /bin/ollama /bin/ollama
FROM ubuntu:24.04 FROM ubuntu:20.04
RUN apt-get update \ RUN apt-get update \
&& apt-get install -y ca-certificates \ && apt-get install -y ca-certificates \
&& apt-get clean \ && apt-get clean \

View File

@@ -1,6 +1,6 @@
UPSTREAM=https://github.com/ggerganov/llama.cpp.git UPSTREAM=https://github.com/ggerganov/llama.cpp.git
WORKDIR=llama/vendor WORKDIR=llama/vendor
FETCH_HEAD=de4c07f93783a1a96456a44dc16b9db538ee1618 FETCH_HEAD=e1e8e0991ffd9e99a445c6812bb519d5bac9f4b5
.PHONY: help .PHONY: help
help: help:
@@ -15,13 +15,11 @@ help:
@echo " make -f $(lastword $(MAKEFILE_LIST)) clean sync" @echo " make -f $(lastword $(MAKEFILE_LIST)) clean sync"
.PHONY: sync .PHONY: sync
sync: llama/build-info.cpp ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-embed.metal sync: llama/build-info.cpp llama/llama.cpp ml/backend/ggml/ggml
llama/build-info.cpp: llama/build-info.cpp.in llama/llama.cpp .PHONY: llama/build-info.cpp
sed -e 's|@FETCH_HEAD@|$(FETCH_HEAD)|' <$< >$@ llama/build-info.cpp: llama/build-info.cpp.in
sed -e 's|@FETCH_HEAD@|$(FETCH_HEAD)|' $< > $@
ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-embed.metal: ml/backend/ggml/ggml
go generate ./$(@D)
.PHONY: llama/llama.cpp .PHONY: llama/llama.cpp
llama/llama.cpp: llama/vendor/ llama/llama.cpp: llama/vendor/
@@ -32,13 +30,12 @@ ml/backend/ggml/ggml: llama/vendor/ggml/
rsync -arvzc -f "merge $@/.rsync-filter" $< $@ rsync -arvzc -f "merge $@/.rsync-filter" $< $@
PATCHES=$(wildcard llama/patches/*.patch) PATCHES=$(wildcard llama/patches/*.patch)
PATCHED=$(join $(dir $(PATCHES)), $(addsuffix ed, $(addprefix ., $(notdir $(PATCHES)))))
.PHONY: apply-patches .PHONY: apply-patches
.NOTPARALLEL: .NOTPARALLEL:
apply-patches: $(PATCHED) apply-patches: $(addsuffix ed, $(PATCHES))
llama/patches/.%.patched: llama/patches/%.patch %.patched: %.patch
@if git -c user.name=nobody -c 'user.email=<>' -C $(WORKDIR) am -3 $(realpath $<); then touch $@; else git -C $(WORKDIR) am --abort; exit 1; fi @if git -c user.name=nobody -c 'user.email=<>' -C $(WORKDIR) am -3 $(realpath $<); then touch $@; else git -C $(WORKDIR) am --abort; exit 1; fi
.PHONY: checkout .PHONY: checkout
@@ -60,4 +57,4 @@ format-patches: llama/patches
.PHONE: clean .PHONE: clean
clean: checkout clean: checkout
$(RM) llama/patches/.*.patched $(RM) $(addsuffix ed, $(PATCHES))

View File

@@ -1,6 +1,6 @@
<div align="center"> <div align="center">
  <a href="https://ollama.com">   <a href="https://ollama.com">
<img alt="ollama" width="240" src="https://github.com/ollama/ollama/assets/3325447/0d0b44e2-8f4a-4e99-9b52-a5c1c741c8f7"> <img alt="ollama" height="200px" src="https://github.com/ollama/ollama/assets/3325447/0d0b44e2-8f4a-4e99-9b52-a5c1c741c8f7">
</a> </a>
</div> </div>
@@ -10,7 +10,7 @@ Get up and running with large language models.
### macOS ### macOS
[Download](https://ollama.com/download/Ollama.dmg) [Download](https://ollama.com/download/Ollama-darwin.zip)
### Windows ### Windows
@@ -40,10 +40,10 @@ The official [Ollama Docker image](https://hub.docker.com/r/ollama/ollama) `olla
## Quickstart ## Quickstart
To run and chat with [Gemma 3](https://ollama.com/library/gemma3): To run and chat with [Llama 3.2](https://ollama.com/library/llama3.2):
```shell ```shell
ollama run gemma3 ollama run llama3.2
``` ```
## Model library ## Model library
@@ -315,7 +315,6 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [Ollama-chats RPG](https://github.com/drazdra/ollama-chats) - [Ollama-chats RPG](https://github.com/drazdra/ollama-chats)
- [IntelliBar](https://intellibar.app/) (AI-powered assistant for macOS) - [IntelliBar](https://intellibar.app/) (AI-powered assistant for macOS)
- [Jirapt](https://github.com/AliAhmedNada/jirapt) (Jira Integration to generate issues, tasks, epics) - [Jirapt](https://github.com/AliAhmedNada/jirapt) (Jira Integration to generate issues, tasks, epics)
- [ojira](https://github.com/AliAhmedNada/ojira) (Jira chrome plugin to easily generate descriptions for tasks)
- [QA-Pilot](https://github.com/reid41/QA-Pilot) (Interactive chat tool that can leverage Ollama models for rapid understanding and navigation of GitHub code repositories) - [QA-Pilot](https://github.com/reid41/QA-Pilot) (Interactive chat tool that can leverage Ollama models for rapid understanding and navigation of GitHub code repositories)
- [ChatOllama](https://github.com/sugarforever/chat-ollama) (Open Source Chatbot based on Ollama with Knowledge Bases) - [ChatOllama](https://github.com/sugarforever/chat-ollama) (Open Source Chatbot based on Ollama with Knowledge Bases)
- [CRAG Ollama Chat](https://github.com/Nagi-ovo/CRAG-Ollama-Chat) (Simple Web Search with Corrective RAG) - [CRAG Ollama Chat](https://github.com/Nagi-ovo/CRAG-Ollama-Chat) (Simple Web Search with Corrective RAG)
@@ -360,7 +359,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [Tkinter-based client](https://github.com/chyok/ollama-gui) (Python tkinter-based Client for Ollama) - [Tkinter-based client](https://github.com/chyok/ollama-gui) (Python tkinter-based Client for Ollama)
- [LLMChat](https://github.com/trendy-design/llmchat) (Privacy focused, 100% local, intuitive all-in-one chat interface) - [LLMChat](https://github.com/trendy-design/llmchat) (Privacy focused, 100% local, intuitive all-in-one chat interface)
- [Local Multimodal AI Chat](https://github.com/Leon-Sander/Local-Multimodal-AI-Chat) (Ollama-based LLM Chat with support for multiple features, including PDF RAG, voice chat, image-based interactions, and integration with OpenAI.) - [Local Multimodal AI Chat](https://github.com/Leon-Sander/Local-Multimodal-AI-Chat) (Ollama-based LLM Chat with support for multiple features, including PDF RAG, voice chat, image-based interactions, and integration with OpenAI.)
- [ARGO](https://github.com/xark-argo/argo) (Locally download and run Ollama and Huggingface models with RAG and deep research on Mac/Windows/Linux) - [ARGO](https://github.com/xark-argo/argo) (Locally download and run Ollama and Huggingface models with RAG on Mac/Windows/Linux)
- [OrionChat](https://github.com/EliasPereirah/OrionChat) - OrionChat is a web interface for chatting with different AI providers - [OrionChat](https://github.com/EliasPereirah/OrionChat) - OrionChat is a web interface for chatting with different AI providers
- [G1](https://github.com/bklieger-groq/g1) (Prototype of using prompting strategies to improve the LLM's reasoning through o1-like reasoning chains.) - [G1](https://github.com/bklieger-groq/g1) (Prototype of using prompting strategies to improve the LLM's reasoning through o1-like reasoning chains.)
- [Web management](https://github.com/lemonit-eric-mao/ollama-web-management) (Web management page) - [Web management](https://github.com/lemonit-eric-mao/ollama-web-management) (Web management page)
@@ -405,12 +404,6 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [Writeopia](https://github.com/Writeopia/Writeopia) (Text editor with integration with Ollama) - [Writeopia](https://github.com/Writeopia/Writeopia) (Text editor with integration with Ollama)
- [AppFlowy](https://github.com/AppFlowy-IO/AppFlowy) (AI collaborative workspace with Ollama, cross-platform and self-hostable) - [AppFlowy](https://github.com/AppFlowy-IO/AppFlowy) (AI collaborative workspace with Ollama, cross-platform and self-hostable)
- [Lumina](https://github.com/cushydigit/lumina.git) (A lightweight, minimal React.js frontend for interacting with Ollama servers) - [Lumina](https://github.com/cushydigit/lumina.git) (A lightweight, minimal React.js frontend for interacting with Ollama servers)
- [Tiny Notepad](https://pypi.org/project/tiny-notepad) (A lightweight, notepad-like interface to chat with ollama available on PyPI)
- [macLlama (macOS native)](https://github.com/hellotunamayo/macLlama) (A native macOS GUI application for interacting with Ollama models, featuring a chat interface.)
- [GPTranslate](https://github.com/philberndt/GPTranslate) (A fast and lightweight, AI powered desktop translation application written with Rust and Tauri. Features real-time translation with OpenAI/Azure/Ollama.)
- [ollama launcher](https://github.com/NGC13009/ollama-launcher) (A launcher for Ollama, aiming to provide users with convenient functions such as ollama server launching, management, or configuration.)
- [ai-hub](https://github.com/Aj-Seven/ai-hub) (AI Hub supports multiple models via API keys and Chat support via Ollama API.)
- [Mayan EDMS](https://gitlab.com/mayan-edms/mayan-edms) (Open source document management system to organize, tag, search, and automate your files with powerful Ollama driven workflows.)
### Cloud ### Cloud
@@ -454,9 +447,6 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [orbiton](https://github.com/xyproto/orbiton) Configuration-free text editor and IDE with support for tab completion with Ollama. - [orbiton](https://github.com/xyproto/orbiton) Configuration-free text editor and IDE with support for tab completion with Ollama.
- [orca-cli](https://github.com/molbal/orca-cli) Ollama Registry CLI Application - Browse, pull, and download models from Ollama Registry in your terminal. - [orca-cli](https://github.com/molbal/orca-cli) Ollama Registry CLI Application - Browse, pull, and download models from Ollama Registry in your terminal.
- [GGUF-to-Ollama](https://github.com/jonathanhecl/gguf-to-ollama) - Importing GGUF to Ollama made easy (multiplatform) - [GGUF-to-Ollama](https://github.com/jonathanhecl/gguf-to-ollama) - Importing GGUF to Ollama made easy (multiplatform)
- [AWS-Strands-With-Ollama](https://github.com/rapidarchitect/ollama_strands) - AWS Strands Agents with Ollama Examples
- [ollama-multirun](https://github.com/attogram/ollama-multirun) - A bash shell script to run a single prompt against any or all of your locally installed ollama models, saving the output and performance statistics as easily navigable web pages. ([Demo](https://attogram.github.io/ai_test_zone/))
- [ollama-bash-toolshed](https://github.com/attogram/ollama-bash-toolshed) - Bash scripts to chat with tool using models. Add new tools to your shed with ease. Runs on Ollama.
### Apple Vision Pro ### Apple Vision Pro
@@ -536,7 +526,6 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [Abso](https://github.com/lunary-ai/abso) (OpenAI-compatible TypeScript SDK for any LLM provider) - [Abso](https://github.com/lunary-ai/abso) (OpenAI-compatible TypeScript SDK for any LLM provider)
- [Nichey](https://github.com/goodreasonai/nichey) is a Python package for generating custom wikis for your research topic - [Nichey](https://github.com/goodreasonai/nichey) is a Python package for generating custom wikis for your research topic
- [Ollama for D](https://github.com/kassane/ollama-d) - [Ollama for D](https://github.com/kassane/ollama-d)
- [OllamaPlusPlus](https://github.com/HardCodeDev777/OllamaPlusPlus) (Very simple C++ library for Ollama)
### Mobile ### Mobile
@@ -593,14 +582,10 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [Simple-Discord-AI](https://github.com/zyphixor/simple-discord-ai) - [Simple-Discord-AI](https://github.com/zyphixor/simple-discord-ai)
- [LLM Telegram Bot](https://github.com/innightwolfsleep/llm_telegram_bot) (telegram bot, primary for RP. Oobabooga-like buttons, [A1111](https://github.com/AUTOMATIC1111/stable-diffusion-webui) API integration e.t.c) - [LLM Telegram Bot](https://github.com/innightwolfsleep/llm_telegram_bot) (telegram bot, primary for RP. Oobabooga-like buttons, [A1111](https://github.com/AUTOMATIC1111/stable-diffusion-webui) API integration e.t.c)
- [mcp-llm](https://github.com/sammcj/mcp-llm) (MCP Server to allow LLMs to call other LLMs) - [mcp-llm](https://github.com/sammcj/mcp-llm) (MCP Server to allow LLMs to call other LLMs)
- [SimpleOllamaUnity](https://github.com/HardCodeDev777/SimpleOllamaUnity) (Unity Engine extension for communicating with Ollama in a few lines of code. Also works at runtime)
- [UnityCodeLama](https://github.com/HardCodeDev777/UnityCodeLama) (Unity Edtior tool to analyze scripts via Ollama)
- [NativeMind](https://github.com/NativeMindBrowser/NativeMindExtension) (Private, on-device AI Assistant, no cloud dependencies)
- [GMAI - Gradle Managed AI](https://gmai.premex.se/) (Gradle plugin for automated Ollama lifecycle management during build phases)
### Supported backends ### Supported backends
- [llama.cpp](https://github.com/ggml-org/llama.cpp) project founded by Georgi Gerganov. - [llama.cpp](https://github.com/ggerganov/llama.cpp) project founded by Georgi Gerganov.
### Observability ### Observability
- [Opik](https://www.comet.com/docs/opik/cookbook/ollama) is an open-source platform to debug, evaluate, and monitor your LLM applications, RAG systems, and agentic workflows with comprehensive tracing, automated evaluations, and production-ready dashboards. Opik supports native intergration to Ollama. - [Opik](https://www.comet.com/docs/opik/cookbook/ollama) is an open-source platform to debug, evaluate, and monitor your LLM applications, RAG systems, and agentic workflows with comprehensive tracing, automated evaluations, and production-ready dashboards. Opik supports native intergration to Ollama.

View File

@@ -24,10 +24,7 @@ import (
"net/http" "net/http"
"net/url" "net/url"
"runtime" "runtime"
"strconv"
"time"
"github.com/ollama/ollama/auth"
"github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/format" "github.com/ollama/ollama/format"
"github.com/ollama/ollama/version" "github.com/ollama/ollama/version"
@@ -79,14 +76,6 @@ func NewClient(base *url.URL, http *http.Client) *Client {
} }
} }
func getAuthorizationToken(ctx context.Context, challenge string) (string, error) {
token, err := auth.Sign(ctx, []byte(challenge))
if err != nil {
return "", err
}
return token, nil
}
func (c *Client) do(ctx context.Context, method, path string, reqData, respData any) error { func (c *Client) do(ctx context.Context, method, path string, reqData, respData any) error {
var reqBody io.Reader var reqBody io.Reader
var data []byte var data []byte
@@ -108,21 +97,6 @@ func (c *Client) do(ctx context.Context, method, path string, reqData, respData
} }
requestURL := c.base.JoinPath(path) requestURL := c.base.JoinPath(path)
var token string
if envconfig.UseAuth() || c.base.Hostname() == "ollama.com" {
now := strconv.FormatInt(time.Now().Unix(), 10)
chal := fmt.Sprintf("%s,%s?ts=%s", method, path, now)
token, err = getAuthorizationToken(ctx, chal)
if err != nil {
return err
}
q := requestURL.Query()
q.Set("ts", now)
requestURL.RawQuery = q.Encode()
}
request, err := http.NewRequestWithContext(ctx, method, requestURL.String(), reqBody) request, err := http.NewRequestWithContext(ctx, method, requestURL.String(), reqBody)
if err != nil { if err != nil {
return err return err
@@ -132,10 +106,6 @@ func (c *Client) do(ctx context.Context, method, path string, reqData, respData
request.Header.Set("Accept", "application/json") request.Header.Set("Accept", "application/json")
request.Header.Set("User-Agent", fmt.Sprintf("ollama/%s (%s %s) Go/%s", version.Version, runtime.GOARCH, runtime.GOOS, runtime.Version())) request.Header.Set("User-Agent", fmt.Sprintf("ollama/%s (%s %s) Go/%s", version.Version, runtime.GOARCH, runtime.GOOS, runtime.Version()))
if token != "" {
request.Header.Set("Authorization", token)
}
respObj, err := c.http.Do(request) respObj, err := c.http.Do(request)
if err != nil { if err != nil {
return err return err
@@ -173,22 +143,6 @@ func (c *Client) stream(ctx context.Context, method, path string, data any, fn f
} }
requestURL := c.base.JoinPath(path) requestURL := c.base.JoinPath(path)
var token string
if envconfig.UseAuth() || c.base.Hostname() == "ollama.com" {
var err error
now := strconv.FormatInt(time.Now().Unix(), 10)
chal := fmt.Sprintf("%s,%s?ts=%s", method, path, now)
token, err = getAuthorizationToken(ctx, chal)
if err != nil {
return err
}
q := requestURL.Query()
q.Set("ts", now)
requestURL.RawQuery = q.Encode()
}
request, err := http.NewRequestWithContext(ctx, method, requestURL.String(), buf) request, err := http.NewRequestWithContext(ctx, method, requestURL.String(), buf)
if err != nil { if err != nil {
return err return err
@@ -198,10 +152,6 @@ func (c *Client) stream(ctx context.Context, method, path string, data any, fn f
request.Header.Set("Accept", "application/x-ndjson") request.Header.Set("Accept", "application/x-ndjson")
request.Header.Set("User-Agent", fmt.Sprintf("ollama/%s (%s %s) Go/%s", version.Version, runtime.GOARCH, runtime.GOOS, runtime.Version())) request.Header.Set("User-Agent", fmt.Sprintf("ollama/%s (%s %s) Go/%s", version.Version, runtime.GOARCH, runtime.GOOS, runtime.Version()))
if token != "" {
request.Header.Set("Authorization", token)
}
response, err := c.http.Do(request) response, err := c.http.Do(request)
if err != nil { if err != nil {
return err return err
@@ -222,6 +172,10 @@ func (c *Client) stream(ctx context.Context, method, path string, data any, fn f
return fmt.Errorf("unmarshal: %w", err) return fmt.Errorf("unmarshal: %w", err)
} }
if errorResponse.Error != "" {
return errors.New(errorResponse.Error)
}
if response.StatusCode >= http.StatusBadRequest { if response.StatusCode >= http.StatusBadRequest {
return StatusError{ return StatusError{
StatusCode: response.StatusCode, StatusCode: response.StatusCode,
@@ -230,10 +184,6 @@ func (c *Client) stream(ctx context.Context, method, path string, data any, fn f
} }
} }
if errorResponse.Error != "" {
return errors.New(errorResponse.Error)
}
if err := fn(bts); err != nil { if err := fn(bts); err != nil {
return err return err
} }

View File

@@ -1,6 +1,7 @@
package api package api
import ( import (
"context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"net/http" "net/http"
@@ -89,16 +90,6 @@ func TestClientStream(t *testing.T) {
}, },
wantErr: "mid-stream error", wantErr: "mid-stream error",
}, },
{
name: "http status error takes precedence over general error",
responses: []any{
testError{
message: "custom error message",
statusCode: http.StatusInternalServerError,
},
},
wantErr: "500",
},
{ {
name: "successful stream completion", name: "successful stream completion",
responses: []any{ responses: []any{
@@ -146,7 +137,7 @@ func TestClientStream(t *testing.T) {
client := NewClient(&url.URL{Scheme: "http", Host: ts.Listener.Addr().String()}, http.DefaultClient) client := NewClient(&url.URL{Scheme: "http", Host: ts.Listener.Addr().String()}, http.DefaultClient)
var receivedChunks []ChatResponse var receivedChunks []ChatResponse
err := client.stream(t.Context(), http.MethodPost, "/v1/chat", nil, func(chunk []byte) error { err := client.stream(context.Background(), http.MethodPost, "/v1/chat", nil, func(chunk []byte) error {
var resp ChatResponse var resp ChatResponse
if err := json.Unmarshal(chunk, &resp); err != nil { if err := json.Unmarshal(chunk, &resp); err != nil {
return fmt.Errorf("failed to unmarshal chunk: %w", err) return fmt.Errorf("failed to unmarshal chunk: %w", err)
@@ -232,7 +223,7 @@ func TestClientDo(t *testing.T) {
ID string `json:"id"` ID string `json:"id"`
Success bool `json:"success"` Success bool `json:"success"`
} }
err := client.do(t.Context(), http.MethodPost, "/v1/messages", nil, &resp) err := client.do(context.Background(), http.MethodPost, "/v1/messages", nil, &resp)
if tc.wantErr != "" { if tc.wantErr != "" {
if err == nil { if err == nil {

View File

@@ -83,12 +83,6 @@ type GenerateRequest struct {
// Options lists model-specific options. For example, temperature can be // Options lists model-specific options. For example, temperature can be
// set through this field, if the model supports it. // set through this field, if the model supports it.
Options map[string]any `json:"options"` Options map[string]any `json:"options"`
// Think controls whether thinking/reasoning models will think before
// responding. Needs to be a pointer so we can distinguish between false
// (request that thinking _not_ be used) and unset (use the old behavior
// before this option was introduced)
Think *bool `json:"think,omitempty"`
} }
// ChatRequest describes a request sent by [Client.Chat]. // ChatRequest describes a request sent by [Client.Chat].
@@ -114,10 +108,6 @@ type ChatRequest struct {
// Options lists model-specific options. // Options lists model-specific options.
Options map[string]any `json:"options"` Options map[string]any `json:"options"`
// Think controls whether thinking/reasoning models will think before
// responding
Think *bool `json:"think,omitempty"`
} }
type Tools []Tool type Tools []Tool
@@ -136,14 +126,10 @@ func (t Tool) String() string {
// role ("system", "user", or "assistant"), the content and an optional list // role ("system", "user", or "assistant"), the content and an optional list
// of images. // of images.
type Message struct { type Message struct {
Role string `json:"role"` Role string `json:"role"`
Content string `json:"content"` Content string `json:"content"`
// Thinking contains the text that was inside thinking tags in the
// original model output when ChatRequest.Think is enabled.
Thinking string `json:"thinking,omitempty"`
Images []ImageData `json:"images,omitempty"` Images []ImageData `json:"images,omitempty"`
ToolCalls []ToolCall `json:"tool_calls,omitempty"` ToolCalls []ToolCall `json:"tool_calls,omitempty"`
ToolName string `json:"tool_name,omitempty"`
} }
func (m *Message) UnmarshalJSON(b []byte) error { func (m *Message) UnmarshalJSON(b []byte) error {
@@ -468,14 +454,13 @@ type ListModelResponse struct {
// ProcessModelResponse is a single model description in [ProcessResponse]. // ProcessModelResponse is a single model description in [ProcessResponse].
type ProcessModelResponse struct { type ProcessModelResponse struct {
Name string `json:"name"` Name string `json:"name"`
Model string `json:"model"` Model string `json:"model"`
Size int64 `json:"size"` Size int64 `json:"size"`
Digest string `json:"digest"` Digest string `json:"digest"`
Details ModelDetails `json:"details,omitempty"` Details ModelDetails `json:"details,omitempty"`
ExpiresAt time.Time `json:"expires_at"` ExpiresAt time.Time `json:"expires_at"`
SizeVRAM int64 `json:"size_vram"` SizeVRAM int64 `json:"size_vram"`
ContextLength int `json:"context_length"`
} }
type TokenResponse struct { type TokenResponse struct {
@@ -493,10 +478,6 @@ type GenerateResponse struct {
// Response is the textual response itself. // Response is the textual response itself.
Response string `json:"response"` Response string `json:"response"`
// Thinking contains the text that was inside thinking tags in the
// original model output when ChatRequest.Think is enabled.
Thinking string `json:"thinking,omitempty"`
// Done specifies if the response is complete. // Done specifies if the response is complete.
Done bool `json:"done"` Done bool `json:"done"`

View File

@@ -372,50 +372,3 @@ func TestPropertyType_MarshalJSON(t *testing.T) {
}) })
} }
} }
func TestThinking_UnmarshalJSON(t *testing.T) {
trueVal := true
falseVal := false
tests := []struct {
name string
input string
expectedThinking *bool
expectedError bool
}{
{
name: "true",
input: `{ "think": true }`,
expectedThinking: &trueVal,
},
{
name: "false",
input: `{ "think": false }`,
expectedThinking: &falseVal,
},
{
name: "unset",
input: `{ }`,
expectedThinking: nil,
},
{
name: "invalid",
input: `{ "think": "true" }`,
expectedThinking: nil,
expectedError: true,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
var req GenerateRequest
err := json.Unmarshal([]byte(test.input), &req)
if test.expectedError {
require.Error(t, err)
} else {
require.NoError(t, err)
assert.Equal(t, test.expectedThinking, req.Think)
}
})
}
}

View File

@@ -4,14 +4,20 @@ import (
"fmt" "fmt"
"log/slog" "log/slog"
"os" "os"
"path/filepath"
"strconv" "strconv"
"strings" "strings"
"github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/logutil"
) )
func InitLogging() { func InitLogging() {
level := slog.LevelInfo
if envconfig.Debug() {
level = slog.LevelDebug
}
var logFile *os.File var logFile *os.File
var err error var err error
// Detect if we're a GUI app on windows, and if not, send logs to console // Detect if we're a GUI app on windows, and if not, send logs to console
@@ -27,8 +33,20 @@ func InitLogging() {
return return
} }
} }
handler := slog.NewTextHandler(logFile, &slog.HandlerOptions{
Level: level,
AddSource: true,
ReplaceAttr: func(_ []string, attr slog.Attr) slog.Attr {
if attr.Key == slog.SourceKey {
source := attr.Value.Any().(*slog.Source)
source.File = filepath.Base(source.File)
}
return attr
},
})
slog.SetDefault(slog.New(handler))
slog.SetDefault(logutil.NewLogger(logFile, envconfig.LogLevel()))
slog.Info("ollama app started") slog.Info("ollama app started")
} }

View File

@@ -0,0 +1,178 @@
package benchmark
import (
"context"
"flag"
"fmt"
"testing"
"time"
"github.com/ollama/ollama/api"
)
// Command line flags
var modelFlag string
func init() {
flag.StringVar(&modelFlag, "m", "", "Name of the model to benchmark")
flag.Lookup("m").DefValue = "model"
}
// modelName returns the model name from flags, failing the test if not set
func modelName(b *testing.B) string {
if modelFlag == "" {
b.Fatal("Error: -m flag is required for benchmark tests")
}
return modelFlag
}
type TestCase struct {
name string
prompt string
maxTokens int
}
// runGenerateBenchmark contains the common generate and metrics logic
func runGenerateBenchmark(b *testing.B, ctx context.Context, client *api.Client, req *api.GenerateRequest) {
start := time.Now()
var ttft time.Duration
var metrics api.Metrics
err := client.Generate(ctx, req, func(resp api.GenerateResponse) error {
if ttft == 0 && resp.Response != "" {
ttft = time.Since(start)
}
if resp.Done {
metrics = resp.Metrics
}
return nil
})
// Report custom metrics as part of the benchmark results
b.ReportMetric(float64(ttft.Milliseconds()), "ttft_ms")
b.ReportMetric(float64(metrics.LoadDuration.Milliseconds()), "load_ms")
// Token throughput metrics
promptThroughput := float64(metrics.PromptEvalCount) / metrics.PromptEvalDuration.Seconds()
genThroughput := float64(metrics.EvalCount) / metrics.EvalDuration.Seconds()
b.ReportMetric(promptThroughput, "prompt_tok/s")
b.ReportMetric(genThroughput, "gen_tok/s")
// Token counts
b.ReportMetric(float64(metrics.PromptEvalCount), "prompt_tokens")
b.ReportMetric(float64(metrics.EvalCount), "gen_tokens")
if err != nil {
b.Fatal(err)
}
}
// BenchmarkColdStart runs benchmarks with model loading from cold state
func BenchmarkColdStart(b *testing.B) {
client := setup(b)
tests := []TestCase{
{"short_prompt", "Write a long story", 100},
{"medium_prompt", "Write a detailed economic analysis", 500},
{"long_prompt", "Write a comprehensive AI research paper", 1000},
}
m := modelName(b)
for _, tt := range tests {
b.Run(fmt.Sprintf("%s/cold/%s", m, tt.name), func(b *testing.B) {
ctx := context.Background()
// Set number of tokens as our throughput metric
b.SetBytes(int64(tt.maxTokens))
for b.Loop() {
b.StopTimer()
// Ensure model is unloaded before each iteration
unload(client, m, b)
b.StartTimer()
req := &api.GenerateRequest{
Model: m,
Prompt: tt.prompt,
Options: map[string]any{"num_predict": tt.maxTokens, "temperature": 0.1},
}
runGenerateBenchmark(b, ctx, client, req)
}
})
}
}
// BenchmarkWarmStart runs benchmarks with pre-loaded model
func BenchmarkWarmStart(b *testing.B) {
client := setup(b)
tests := []TestCase{
{"short_prompt", "Write a long story", 100},
{"medium_prompt", "Write a detailed economic analysis", 500},
{"long_prompt", "Write a comprehensive AI research paper", 1000},
}
m := modelName(b)
for _, tt := range tests {
b.Run(fmt.Sprintf("%s/warm/%s", m, tt.name), func(b *testing.B) {
ctx := context.Background()
// Pre-warm the model
warmup(client, m, tt.prompt, b)
// Set number of tokens as our throughput metric
b.SetBytes(int64(tt.maxTokens))
for b.Loop() {
req := &api.GenerateRequest{
Model: m,
Prompt: tt.prompt,
Options: map[string]any{"num_predict": tt.maxTokens, "temperature": 0.1},
}
runGenerateBenchmark(b, ctx, client, req)
}
})
}
}
// setup verifies server and model availability
func setup(b *testing.B) *api.Client {
client, err := api.ClientFromEnvironment()
if err != nil {
b.Fatal(err)
}
if _, err := client.Show(context.Background(), &api.ShowRequest{Model: modelName(b)}); err != nil {
b.Fatalf("Model unavailable: %v", err)
}
return client
}
// warmup ensures the model is loaded and warmed up
func warmup(client *api.Client, model string, prompt string, b *testing.B) {
for range 3 {
err := client.Generate(
context.Background(),
&api.GenerateRequest{
Model: model,
Prompt: prompt,
Options: map[string]any{"num_predict": 50, "temperature": 0.1},
},
func(api.GenerateResponse) error { return nil },
)
if err != nil {
b.Logf("Error during model warm-up: %v", err)
}
}
}
// unload forces model unloading using KeepAlive: 0 parameter
func unload(client *api.Client, model string, b *testing.B) {
req := &api.GenerateRequest{
Model: model,
KeepAlive: &api.Duration{Duration: 0},
}
if err := client.Generate(context.Background(), req, func(api.GenerateResponse) error { return nil }); err != nil {
b.Logf("Unload error: %v", err)
}
time.Sleep(1 * time.Second)
}

View File

@@ -39,7 +39,6 @@ import (
"github.com/ollama/ollama/format" "github.com/ollama/ollama/format"
"github.com/ollama/ollama/parser" "github.com/ollama/ollama/parser"
"github.com/ollama/ollama/progress" "github.com/ollama/ollama/progress"
"github.com/ollama/ollama/readline"
"github.com/ollama/ollama/runner" "github.com/ollama/ollama/runner"
"github.com/ollama/ollama/server" "github.com/ollama/ollama/server"
"github.com/ollama/ollama/types/model" "github.com/ollama/ollama/types/model"
@@ -47,23 +46,6 @@ import (
"github.com/ollama/ollama/version" "github.com/ollama/ollama/version"
) )
// ensureThinkingSupport emits a warning if the model does not advertise thinking support
func ensureThinkingSupport(ctx context.Context, client *api.Client, name string) {
if name == "" {
return
}
resp, err := client.Show(ctx, &api.ShowRequest{Model: name})
if err != nil {
return
}
for _, cap := range resp.Capabilities {
if cap == model.CapabilityThinking {
return
}
}
fmt.Fprintf(os.Stderr, "warning: model %q does not support thinking output\n", name)
}
var errModelfileNotFound = errors.New("specified Modelfile wasn't found") var errModelfileNotFound = errors.New("specified Modelfile wasn't found")
func getModelfileName(cmd *cobra.Command) (string, error) { func getModelfileName(cmd *cobra.Command) (string, error) {
@@ -283,9 +265,6 @@ func loadOrUnloadModel(cmd *cobra.Command, opts *runOptions) error {
req := &api.GenerateRequest{ req := &api.GenerateRequest{
Model: opts.Model, Model: opts.Model,
KeepAlive: opts.KeepAlive, KeepAlive: opts.KeepAlive,
// pass Think here so we fail before getting to the chat prompt if the model doesn't support it
Think: opts.Think,
} }
return client.Generate(cmd.Context(), req, func(api.GenerateResponse) error { return nil }) return client.Generate(cmd.Context(), req, func(api.GenerateResponse) error { return nil })
@@ -320,22 +299,6 @@ func RunHandler(cmd *cobra.Command, args []string) error {
} }
opts.Format = format opts.Format = format
thinkFlag := cmd.Flags().Lookup("think")
if thinkFlag.Changed {
think, err := cmd.Flags().GetBool("think")
if err != nil {
return err
}
opts.Think = &think
} else {
opts.Think = nil
}
hidethinking, err := cmd.Flags().GetBool("hidethinking")
if err != nil {
return err
}
opts.HideThinking = hidethinking
keepAlive, err := cmd.Flags().GetString("keepalive") keepAlive, err := cmd.Flags().GetString("keepalive")
if err != nil { if err != nil {
return err return err
@@ -399,11 +362,6 @@ func RunHandler(cmd *cobra.Command, args []string) error {
return err return err
} }
opts.Think, err = inferThinkingOption(&info.Capabilities, &opts, thinkFlag.Changed)
if err != nil {
return err
}
opts.MultiModal = slices.Contains(info.Capabilities, model.CapabilityVision) opts.MultiModal = slices.Contains(info.Capabilities, model.CapabilityVision)
// TODO: remove the projector info and vision info checks below, // TODO: remove the projector info and vision info checks below,
@@ -583,13 +541,12 @@ func ListRunningHandler(cmd *cobra.Command, args []string) error {
} else { } else {
until = format.HumanTime(m.ExpiresAt, "Never") until = format.HumanTime(m.ExpiresAt, "Never")
} }
ctxStr := strconv.Itoa(m.ContextLength) data = append(data, []string{m.Name, m.Digest[:12], format.HumanBytes(m.Size), procStr, until})
data = append(data, []string{m.Name, m.Digest[:12], format.HumanBytes(m.Size), procStr, ctxStr, until})
} }
} }
table := tablewriter.NewWriter(os.Stdout) table := tablewriter.NewWriter(os.Stdout)
table.SetHeader([]string{"NAME", "ID", "SIZE", "PROCESSOR", "CONTEXT", "UNTIL"}) table.SetHeader([]string{"NAME", "ID", "SIZE", "PROCESSOR", "UNTIL"})
table.SetHeaderAlignment(tablewriter.ALIGN_LEFT) table.SetHeaderAlignment(tablewriter.ALIGN_LEFT)
table.SetAlignment(tablewriter.ALIGN_LEFT) table.SetAlignment(tablewriter.ALIGN_LEFT)
table.SetHeaderLine(false) table.SetHeaderLine(false)
@@ -790,38 +747,11 @@ func showInfo(resp *api.ShowResponse, verbose bool, w io.Writer) error {
case float64: case float64:
v = fmt.Sprintf("%g", vData) v = fmt.Sprintf("%g", vData)
case []any: case []any:
targetWidth := 10 // Small width where we are displaying the data in a column n := 3
if len(vData) < n {
var itemsToShow int n = len(vData)
totalWidth := 1 // Start with 1 for opening bracket
// Find how many we can fit
for i := range vData {
itemStr := fmt.Sprintf("%v", vData[i])
width := runewidth.StringWidth(itemStr)
// Add separator width (", ") for all items except the first
if i > 0 {
width += 2
}
// Check if adding this item would exceed our width limit
if totalWidth+width > targetWidth && i > 0 {
break
}
totalWidth += width
itemsToShow++
}
// Format the output
if itemsToShow < len(vData) {
v = fmt.Sprintf("%v", vData[:itemsToShow])
v = strings.TrimSuffix(v, "]")
v += fmt.Sprintf(" ...+%d more]", len(vData)-itemsToShow)
} else {
v = fmt.Sprintf("%v", vData)
} }
v = fmt.Sprintf("%v", vData[:n])
default: default:
v = fmt.Sprintf("%T", vData) v = fmt.Sprintf("%T", vData)
} }
@@ -842,19 +772,10 @@ func showInfo(resp *api.ShowResponse, verbose bool, w io.Writer) error {
head := func(s string, n int) (rows [][]string) { head := func(s string, n int) (rows [][]string) {
scanner := bufio.NewScanner(strings.NewReader(s)) scanner := bufio.NewScanner(strings.NewReader(s))
count := 0 for scanner.Scan() && (len(rows) < n || n < 0) {
for scanner.Scan() { if text := scanner.Text(); text != "" {
text := strings.TrimSpace(scanner.Text()) rows = append(rows, []string{"", strings.TrimSpace(text)})
if text == "" {
continue
} }
count++
if n < 0 || count <= n {
rows = append(rows, []string{"", text})
}
}
if n >= 0 && count > n {
rows = append(rows, []string{"", "..."})
} }
return return
} }
@@ -966,19 +887,17 @@ func PullHandler(cmd *cobra.Command, args []string) error {
type generateContextKey string type generateContextKey string
type runOptions struct { type runOptions struct {
Model string Model string
ParentModel string ParentModel string
Prompt string Prompt string
Messages []api.Message Messages []api.Message
WordWrap bool WordWrap bool
Format string Format string
System string System string
Images []api.ImageData Images []api.ImageData
Options map[string]any Options map[string]any
MultiModal bool MultiModal bool
KeepAlive *api.Duration KeepAlive *api.Duration
Think *bool
HideThinking bool
} }
type displayResponseState struct { type displayResponseState struct {
@@ -1034,26 +953,6 @@ func displayResponse(content string, wordWrap bool, state *displayResponseState)
} }
} }
func thinkingOutputOpeningText(plainText bool) string {
text := "Thinking...\n"
if plainText {
return text
}
return readline.ColorGrey + readline.ColorBold + text + readline.ColorDefault + readline.ColorGrey
}
func thinkingOutputClosingText(plainText bool) string {
text := "...done thinking.\n\n"
if plainText {
return text
}
return readline.ColorGrey + readline.ColorBold + text + readline.ColorDefault
}
func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) { func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
client, err := api.ClientFromEnvironment() client, err := api.ClientFromEnvironment()
if err != nil { if err != nil {
@@ -1080,36 +979,15 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
var state *displayResponseState = &displayResponseState{} var state *displayResponseState = &displayResponseState{}
var latest api.ChatResponse var latest api.ChatResponse
var fullResponse strings.Builder var fullResponse strings.Builder
var thinkTagOpened bool = false var role string
var thinkTagClosed bool = false
role := "assistant"
fn := func(response api.ChatResponse) error { fn := func(response api.ChatResponse) error {
if response.Message.Content != "" || !opts.HideThinking { p.StopAndClear()
p.StopAndClear()
}
latest = response latest = response
role = response.Message.Role role = response.Message.Role
if response.Message.Thinking != "" && !opts.HideThinking {
if !thinkTagOpened {
fmt.Print(thinkingOutputOpeningText(false))
thinkTagOpened = true
}
displayResponse(response.Message.Thinking, opts.WordWrap, state)
}
content := response.Message.Content content := response.Message.Content
if thinkTagOpened && !thinkTagClosed && content != "" {
fmt.Print(thinkingOutputClosingText(false))
thinkTagClosed = true
}
// purposefully not putting thinking blocks in the response, which would
// only be needed if we later added tool calling to the cli (they get
// filtered out anyway since current models don't expect them unless you're
// about to finish some tool calls)
fullResponse.WriteString(content) fullResponse.WriteString(content)
displayResponse(content, opts.WordWrap, state) displayResponse(content, opts.WordWrap, state)
@@ -1126,7 +1004,6 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
Messages: opts.Messages, Messages: opts.Messages,
Format: json.RawMessage(opts.Format), Format: json.RawMessage(opts.Format),
Options: opts.Options, Options: opts.Options,
Think: opts.Think,
} }
if opts.KeepAlive != nil { if opts.KeepAlive != nil {
@@ -1137,14 +1014,6 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
if errors.Is(err, context.Canceled) { if errors.Is(err, context.Canceled) {
return nil, nil return nil, nil
} }
// this error should ideally be wrapped properly by the client
if strings.Contains(err.Error(), "upstream error") {
p.StopAndClear()
fmt.Println("An error occurred while processing your message. Please try again.")
fmt.Println()
return nil, nil
}
return nil, err return nil, err
} }
@@ -1196,32 +1065,13 @@ func generate(cmd *cobra.Command, opts runOptions) error {
}() }()
var state *displayResponseState = &displayResponseState{} var state *displayResponseState = &displayResponseState{}
var thinkTagOpened bool = false
var thinkTagClosed bool = false
plainText := !term.IsTerminal(int(os.Stdout.Fd()))
fn := func(response api.GenerateResponse) error { fn := func(response api.GenerateResponse) error {
p.StopAndClear()
latest = response latest = response
content := response.Response content := response.Response
if response.Response != "" || !opts.HideThinking {
p.StopAndClear()
}
if response.Thinking != "" && !opts.HideThinking {
if !thinkTagOpened {
fmt.Print(thinkingOutputOpeningText(plainText))
thinkTagOpened = true
}
displayResponse(response.Thinking, opts.WordWrap, state)
}
if thinkTagOpened && !thinkTagClosed && content != "" {
fmt.Print(thinkingOutputClosingText(plainText))
thinkTagClosed = true
}
displayResponse(content, opts.WordWrap, state) displayResponse(content, opts.WordWrap, state)
return nil return nil
@@ -1247,7 +1097,6 @@ func generate(cmd *cobra.Command, opts runOptions) error {
System: opts.System, System: opts.System,
Options: opts.Options, Options: opts.Options,
KeepAlive: opts.KeepAlive, KeepAlive: opts.KeepAlive,
Think: opts.Think,
} }
if err := client.Generate(ctx, &request, fn); err != nil { if err := client.Generate(ctx, &request, fn); err != nil {
@@ -1351,11 +1200,11 @@ func checkServerHeartbeat(cmd *cobra.Command, _ []string) error {
return err return err
} }
if err := client.Heartbeat(cmd.Context()); err != nil { if err := client.Heartbeat(cmd.Context()); err != nil {
if !(strings.Contains(err.Error(), " refused") || strings.Contains(err.Error(), "could not connect")) { if !strings.Contains(err.Error(), " refused") {
return err return err
} }
if err := startApp(cmd.Context(), client); err != nil { if err := startApp(cmd.Context(), client); err != nil {
return fmt.Errorf("ollama server not responding - %w", err) return errors.New("could not connect to ollama app, is it running?")
} }
} }
return nil return nil
@@ -1426,14 +1275,14 @@ func NewCLI() *cobra.Command {
createCmd := &cobra.Command{ createCmd := &cobra.Command{
Use: "create MODEL", Use: "create MODEL",
Short: "Create a model", Short: "Create a model from a Modelfile",
Args: cobra.ExactArgs(1), Args: cobra.ExactArgs(1),
PreRunE: checkServerHeartbeat, PreRunE: checkServerHeartbeat,
RunE: CreateHandler, RunE: CreateHandler,
} }
createCmd.Flags().StringP("file", "f", "", "Name of the Modelfile (default \"Modelfile\")") createCmd.Flags().StringP("file", "f", "", "Name of the Modelfile (default \"Modelfile\"")
createCmd.Flags().StringP("quantize", "q", "", "Quantize model to this level (e.g. q4_K_M)") createCmd.Flags().StringP("quantize", "q", "", "Quantize model to this level (e.g. q4_0)")
showCmd := &cobra.Command{ showCmd := &cobra.Command{
Use: "show MODEL", Use: "show MODEL",
@@ -1463,8 +1312,6 @@ func NewCLI() *cobra.Command {
runCmd.Flags().Bool("insecure", false, "Use an insecure registry") runCmd.Flags().Bool("insecure", false, "Use an insecure registry")
runCmd.Flags().Bool("nowordwrap", false, "Don't wrap words to the next line automatically") runCmd.Flags().Bool("nowordwrap", false, "Don't wrap words to the next line automatically")
runCmd.Flags().String("format", "", "Response format (e.g. json)") runCmd.Flags().String("format", "", "Response format (e.g. json)")
runCmd.Flags().Bool("think", false, "Whether to use thinking mode for supported models")
runCmd.Flags().Bool("hidethinking", false, "Hide thinking output (if provided)")
stopCmd := &cobra.Command{ stopCmd := &cobra.Command{
Use: "stop MODEL", Use: "stop MODEL",
@@ -1516,6 +1363,7 @@ func NewCLI() *cobra.Command {
PreRunE: checkServerHeartbeat, PreRunE: checkServerHeartbeat,
RunE: ListRunningHandler, RunE: ListRunningHandler,
} }
copyCmd := &cobra.Command{ copyCmd := &cobra.Command{
Use: "cp SOURCE DESTINATION", Use: "cp SOURCE DESTINATION",
Short: "Copy a model", Short: "Copy a model",
@@ -1604,45 +1452,3 @@ func NewCLI() *cobra.Command {
return rootCmd return rootCmd
} }
// If the user has explicitly set thinking options, either through the CLI or
// through the `/set think` or `set nothink` interactive options, then we
// respect them. Otherwise, we check model capabilities to see if the model
// supports thinking. If the model does support thinking, we enable it.
// Otherwise, we unset the thinking option (which is different than setting it
// to false).
//
// If capabilities are not provided, we fetch them from the server.
func inferThinkingOption(caps *[]model.Capability, runOpts *runOptions, explicitlySetByUser bool) (*bool, error) {
if explicitlySetByUser {
return runOpts.Think, nil
}
if caps == nil {
client, err := api.ClientFromEnvironment()
if err != nil {
return nil, err
}
ret, err := client.Show(context.Background(), &api.ShowRequest{
Model: runOpts.Model,
})
if err != nil {
return nil, err
}
caps = &ret.Capabilities
}
thinkingSupported := false
for _, cap := range *caps {
if cap == model.CapabilityThinking {
thinkingSupported = true
}
}
if thinkingSupported {
thinking := true
return &thinking, nil
}
return nil, nil
}

View File

@@ -2,6 +2,7 @@ package cmd
import ( import (
"bytes" "bytes"
"context"
"encoding/json" "encoding/json"
"io" "io"
"net/http" "net/http"
@@ -225,7 +226,6 @@ Weigh anchor!
System System
You are a pirate! You are a pirate!
Ahoy, matey! Ahoy, matey!
...
` `
if diff := cmp.Diff(expect, b.String()); diff != "" { if diff := cmp.Diff(expect, b.String()); diff != "" {
@@ -337,7 +337,7 @@ func TestDeleteHandler(t *testing.T) {
t.Cleanup(mockServer.Close) t.Cleanup(mockServer.Close)
cmd := &cobra.Command{} cmd := &cobra.Command{}
cmd.SetContext(t.Context()) cmd.SetContext(context.TODO())
if err := DeleteHandler(cmd, []string{"test-model"}); err != nil { if err := DeleteHandler(cmd, []string{"test-model"}); err != nil {
t.Fatalf("DeleteHandler failed: %v", err) t.Fatalf("DeleteHandler failed: %v", err)
} }
@@ -399,6 +399,11 @@ func TestGetModelfileName(t *testing.T) {
var expectedFilename string var expectedFilename string
if tt.fileExists { if tt.fileExists {
tempDir, err := os.MkdirTemp("", "modelfiledir")
defer os.RemoveAll(tempDir)
if err != nil {
t.Fatalf("temp modelfile dir creation failed: %v", err)
}
var fn string var fn string
if tt.modelfileName != "" { if tt.modelfileName != "" {
fn = tt.modelfileName fn = tt.modelfileName
@@ -406,7 +411,7 @@ func TestGetModelfileName(t *testing.T) {
fn = "Modelfile" fn = "Modelfile"
} }
tempFile, err := os.CreateTemp(t.TempDir(), fn) tempFile, err := os.CreateTemp(tempDir, fn)
if err != nil { if err != nil {
t.Fatalf("temp modelfile creation failed: %v", err) t.Fatalf("temp modelfile creation failed: %v", err)
} }
@@ -525,7 +530,7 @@ func TestPushHandler(t *testing.T) {
cmd := &cobra.Command{} cmd := &cobra.Command{}
cmd.Flags().Bool("insecure", false, "") cmd.Flags().Bool("insecure", false, "")
cmd.SetContext(t.Context()) cmd.SetContext(context.TODO())
// Redirect stderr to capture progress output // Redirect stderr to capture progress output
oldStderr := os.Stderr oldStderr := os.Stderr
@@ -630,7 +635,7 @@ func TestListHandler(t *testing.T) {
t.Setenv("OLLAMA_HOST", mockServer.URL) t.Setenv("OLLAMA_HOST", mockServer.URL)
cmd := &cobra.Command{} cmd := &cobra.Command{}
cmd.SetContext(t.Context()) cmd.SetContext(context.TODO())
// Capture stdout // Capture stdout
oldStdout := os.Stdout oldStdout := os.Stdout
@@ -725,7 +730,7 @@ func TestCreateHandler(t *testing.T) {
})) }))
t.Setenv("OLLAMA_HOST", mockServer.URL) t.Setenv("OLLAMA_HOST", mockServer.URL)
t.Cleanup(mockServer.Close) t.Cleanup(mockServer.Close)
tempFile, err := os.CreateTemp(t.TempDir(), "modelfile") tempFile, err := os.CreateTemp("", "modelfile")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -745,7 +750,7 @@ func TestCreateHandler(t *testing.T) {
} }
cmd.Flags().Bool("insecure", false, "") cmd.Flags().Bool("insecure", false, "")
cmd.SetContext(t.Context()) cmd.SetContext(context.TODO())
// Redirect stderr to capture progress output // Redirect stderr to capture progress output
oldStderr := os.Stderr oldStderr := os.Stderr

View File

@@ -44,7 +44,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
fmt.Fprintln(os.Stderr, "Use \"\"\" to begin a multi-line message.") fmt.Fprintln(os.Stderr, "Use \"\"\" to begin a multi-line message.")
if opts.MultiModal { if opts.MultiModal {
fmt.Fprintf(os.Stderr, "Use %s to include .jpg, .png, or .webp images.\n", filepath.FromSlash("/path/to/file")) fmt.Fprintf(os.Stderr, "Use %s to include .jpg or .png images.\n", filepath.FromSlash("/path/to/file"))
} }
fmt.Fprintln(os.Stderr, "") fmt.Fprintln(os.Stderr, "")
@@ -62,8 +62,6 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
fmt.Fprintln(os.Stderr, " /set noformat Disable formatting") fmt.Fprintln(os.Stderr, " /set noformat Disable formatting")
fmt.Fprintln(os.Stderr, " /set verbose Show LLM stats") fmt.Fprintln(os.Stderr, " /set verbose Show LLM stats")
fmt.Fprintln(os.Stderr, " /set quiet Disable LLM stats") fmt.Fprintln(os.Stderr, " /set quiet Disable LLM stats")
fmt.Fprintln(os.Stderr, " /set think Enable thinking")
fmt.Fprintln(os.Stderr, " /set nothink Disable thinking")
fmt.Fprintln(os.Stderr, "") fmt.Fprintln(os.Stderr, "")
} }
@@ -130,7 +128,6 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
var sb strings.Builder var sb strings.Builder
var multiline MultilineState var multiline MultilineState
var thinkExplicitlySet bool = opts.Think != nil
for { for {
line, err := scanner.Readline() line, err := scanner.Readline()
@@ -198,19 +195,11 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
opts.Model = args[1] opts.Model = args[1]
opts.Messages = []api.Message{} opts.Messages = []api.Message{}
fmt.Printf("Loading model '%s'\n", opts.Model) fmt.Printf("Loading model '%s'\n", opts.Model)
opts.Think, err = inferThinkingOption(nil, &opts, thinkExplicitlySet)
if err != nil {
return err
}
if err := loadOrUnloadModel(cmd, &opts); err != nil { if err := loadOrUnloadModel(cmd, &opts); err != nil {
if strings.Contains(err.Error(), "not found") { if strings.Contains(err.Error(), "not found") {
fmt.Printf("error: %v\n", err) fmt.Printf("error: %v\n", err)
continue continue
} }
if strings.Contains(err.Error(), "does not support thinking") {
fmt.Printf("error: %v\n", err)
continue
}
return err return err
} }
continue continue
@@ -271,22 +260,6 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
return err return err
} }
fmt.Println("Set 'quiet' mode.") fmt.Println("Set 'quiet' mode.")
case "think":
think := true
opts.Think = &think
thinkExplicitlySet = true
if client, err := api.ClientFromEnvironment(); err == nil {
ensureThinkingSupport(cmd.Context(), client, opts.Model)
}
fmt.Println("Set 'think' mode.")
case "nothink":
think := false
opts.Think = &think
thinkExplicitlySet = true
if client, err := api.ClientFromEnvironment(); err == nil {
ensureThinkingSupport(cmd.Context(), client, opts.Model)
}
fmt.Println("Set 'nothink' mode.")
case "format": case "format":
if len(args) < 3 || args[2] != "json" { if len(args) < 3 || args[2] != "json" {
fmt.Println("Invalid or missing format. For 'json' mode use '/set format json'") fmt.Println("Invalid or missing format. For 'json' mode use '/set format json'")
@@ -385,21 +358,18 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
case "modelfile": case "modelfile":
fmt.Println(resp.Modelfile) fmt.Println(resp.Modelfile)
case "parameters": case "parameters":
fmt.Println("Model defined parameters:")
if resp.Parameters == "" { if resp.Parameters == "" {
fmt.Println(" No additional parameters were specified for this model.") fmt.Println("No parameters were specified for this model.")
} else { } else {
for _, l := range strings.Split(resp.Parameters, "\n") { if len(opts.Options) > 0 {
fmt.Printf(" %s\n", l) fmt.Println("User defined parameters:")
for k, v := range opts.Options {
fmt.Printf("%-*s %v\n", 30, k, v)
}
fmt.Println()
} }
} fmt.Println("Model defined parameters:")
fmt.Println() fmt.Println(resp.Parameters)
if len(opts.Options) > 0 {
fmt.Println("User defined parameters:")
for k, v := range opts.Options {
fmt.Printf(" %-*s %v\n", 30, k, v)
}
fmt.Println()
} }
case "system": case "system":
switch { switch {
@@ -478,11 +448,6 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
assistant, err := chat(cmd, opts) assistant, err := chat(cmd, opts)
if err != nil { if err != nil {
if strings.Contains(err.Error(), "does not support thinking") {
fmt.Printf("error: %v\n", err)
sb.Reset()
continue
}
return err return err
} }
if assistant != nil { if assistant != nil {
@@ -546,7 +511,7 @@ func extractFileNames(input string) []string {
// Regex to match file paths starting with optional drive letter, / ./ \ or .\ and include escaped or unescaped spaces (\ or %20) // Regex to match file paths starting with optional drive letter, / ./ \ or .\ and include escaped or unescaped spaces (\ or %20)
// and followed by more characters and a file extension // and followed by more characters and a file extension
// This will capture non filename strings, but we'll check for file existence to remove mismatches // This will capture non filename strings, but we'll check for file existence to remove mismatches
regexPattern := `(?:[a-zA-Z]:)?(?:\./|/|\\)[\S\\ ]+?\.(?i:jpg|jpeg|png|webp)\b` regexPattern := `(?:[a-zA-Z]:)?(?:\./|/|\\)[\S\\ ]+?\.(?i:jpg|jpeg|png)\b`
re := regexp.MustCompile(regexPattern) re := regexp.MustCompile(regexPattern)
return re.FindAllString(input, -1) return re.FindAllString(input, -1)
@@ -566,8 +531,6 @@ func extractFileData(input string) (string, []api.ImageData, error) {
return "", imgs, err return "", imgs, err
} }
fmt.Fprintf(os.Stderr, "Added image '%s'\n", nfp) fmt.Fprintf(os.Stderr, "Added image '%s'\n", nfp)
input = strings.ReplaceAll(input, "'"+nfp+"'", "")
input = strings.ReplaceAll(input, "'"+fp+"'", "")
input = strings.ReplaceAll(input, fp, "") input = strings.ReplaceAll(input, fp, "")
imgs = append(imgs, data) imgs = append(imgs, data)
} }
@@ -588,7 +551,7 @@ func getImageData(filePath string) ([]byte, error) {
} }
contentType := http.DetectContentType(buf) contentType := http.DetectContentType(buf)
allowedTypes := []string{"image/jpeg", "image/jpg", "image/png", "image/webp"} allowedTypes := []string{"image/jpeg", "image/jpg", "image/png"}
if !slices.Contains(allowedTypes, contentType) { if !slices.Contains(allowedTypes, contentType) {
return nil, fmt.Errorf("invalid image type: %s", contentType) return nil, fmt.Errorf("invalid image type: %s", contentType)
} }

View File

@@ -1,8 +1,6 @@
package cmd package cmd
import ( import (
"os"
"path/filepath"
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
@@ -12,17 +10,14 @@ func TestExtractFilenames(t *testing.T) {
// Unix style paths // Unix style paths
input := ` some preamble input := ` some preamble
./relative\ path/one.png inbetween1 ./not a valid two.jpg inbetween2 ./1.svg ./relative\ path/one.png inbetween1 ./not a valid two.jpg inbetween2 ./1.svg
/unescaped space /three.jpeg inbetween3 /valid\ path/dir/four.png "./quoted with spaces/five.JPG /unescaped space /three.jpeg inbetween3 /valid\ path/dir/four.png "./quoted with spaces/five.JPG`
/unescaped space /six.webp inbetween6 /valid\ path/dir/seven.WEBP`
res := extractFileNames(input) res := extractFileNames(input)
assert.Len(t, res, 7) assert.Len(t, res, 5)
assert.Contains(t, res[0], "one.png") assert.Contains(t, res[0], "one.png")
assert.Contains(t, res[1], "two.jpg") assert.Contains(t, res[1], "two.jpg")
assert.Contains(t, res[2], "three.jpeg") assert.Contains(t, res[2], "three.jpeg")
assert.Contains(t, res[3], "four.png") assert.Contains(t, res[3], "four.png")
assert.Contains(t, res[4], "five.JPG") assert.Contains(t, res[4], "five.JPG")
assert.Contains(t, res[5], "six.webp")
assert.Contains(t, res[6], "seven.WEBP")
assert.NotContains(t, res[4], '"') assert.NotContains(t, res[4], '"')
assert.NotContains(t, res, "inbetween1") assert.NotContains(t, res, "inbetween1")
assert.NotContains(t, res, "./1.svg") assert.NotContains(t, res, "./1.svg")
@@ -33,12 +28,10 @@ func TestExtractFilenames(t *testing.T) {
/absolute/nospace/three.jpeg inbetween3 /absolute/with space/four.png inbetween4 /absolute/nospace/three.jpeg inbetween3 /absolute/with space/four.png inbetween4
./relative\ path/five.JPG inbetween5 "./relative with/spaces/six.png inbetween6 ./relative\ path/five.JPG inbetween5 "./relative with/spaces/six.png inbetween6
d:\path with\spaces\seven.JPEG inbetween7 c:\users\jdoe\eight.png inbetween8 d:\path with\spaces\seven.JPEG inbetween7 c:\users\jdoe\eight.png inbetween8
d:\program files\someplace\nine.png inbetween9 "E:\program files\someplace\ten.PNG d:\program files\someplace\nine.png inbetween9 "E:\program files\someplace\ten.PNG some ending
c:/users/jdoe/eleven.webp inbetween11 c:/program files/someplace/twelve.WebP inbetween12
d:\path with\spaces\thirteen.WEBP some ending
` `
res = extractFileNames(input) res = extractFileNames(input)
assert.Len(t, res, 13) assert.Len(t, res, 10)
assert.NotContains(t, res, "inbetween2") assert.NotContains(t, res, "inbetween2")
assert.Contains(t, res[0], "one.png") assert.Contains(t, res[0], "one.png")
assert.Contains(t, res[0], "c:") assert.Contains(t, res[0], "c:")
@@ -56,31 +49,4 @@ d:\path with\spaces\thirteen.WEBP some ending
assert.Contains(t, res[8], "d:") assert.Contains(t, res[8], "d:")
assert.Contains(t, res[9], "ten.PNG") assert.Contains(t, res[9], "ten.PNG")
assert.Contains(t, res[9], "E:") assert.Contains(t, res[9], "E:")
assert.Contains(t, res[10], "eleven.webp")
assert.Contains(t, res[10], "c:")
assert.Contains(t, res[11], "twelve.WebP")
assert.Contains(t, res[11], "c:")
assert.Contains(t, res[12], "thirteen.WEBP")
assert.Contains(t, res[12], "d:")
}
// Ensure that file paths wrapped in single quotes are removed with the quotes.
func TestExtractFileDataRemovesQuotedFilepath(t *testing.T) {
dir := t.TempDir()
fp := filepath.Join(dir, "img.jpg")
data := make([]byte, 600)
copy(data, []byte{
0xff, 0xd8, 0xff, 0xe0, 0x00, 0x10, 'J', 'F', 'I', 'F',
0x00, 0x01, 0x01, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0xff, 0xd9,
})
if err := os.WriteFile(fp, data, 0o600); err != nil {
t.Fatalf("failed to write test image: %v", err)
}
input := "before '" + fp + "' after"
cleaned, imgs, err := extractFileData(input)
assert.NoError(t, err)
assert.Len(t, imgs, 1)
assert.Equal(t, cleaned, "before after")
} }

View File

@@ -5,7 +5,7 @@ import (
"errors" "errors"
"os" "os"
"os/exec" "os/exec"
"regexp" "strings"
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
) )
@@ -19,12 +19,11 @@ func startApp(ctx context.Context, client *api.Client) error {
if err != nil { if err != nil {
return err return err
} }
r := regexp.MustCompile(`^.*/Ollama\s?\d*.app`) if !strings.Contains(link, "Ollama.app") {
m := r.FindStringSubmatch(link)
if len(m) != 1 {
return errors.New("could not find ollama app") return errors.New("could not find ollama app")
} }
if err := exec.Command("/usr/bin/open", "-j", "-a", m[0], "--args", "--fast-startup").Run(); err != nil { path := strings.Split(link, "Ollama.app")
if err := exec.Command("/usr/bin/open", "-a", path[0]+"Ollama.app").Run(); err != nil {
return err return err
} }
return waitForServer(ctx, client) return waitForServer(ctx, client)

View File

@@ -4,27 +4,17 @@ import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
"log/slog"
"os" "os"
"os/exec" "os/exec"
"path"
"path/filepath" "path/filepath"
"strings" "strings"
"syscall" "syscall"
"unsafe"
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
"golang.org/x/sys/windows"
)
const (
Installer = "OllamaSetup.exe"
) )
func startApp(ctx context.Context, client *api.Client) error { func startApp(ctx context.Context, client *api.Client) error {
if len(isProcRunning(Installer)) > 0 { // log.Printf("XXX Attempting to find and start ollama app")
return fmt.Errorf("upgrade in progress...")
}
AppName := "ollama app.exe" AppName := "ollama app.exe"
exe, err := os.Executable() exe, err := os.Executable()
if err != nil { if err != nil {
@@ -45,11 +35,14 @@ func startApp(ctx context.Context, client *api.Client) error {
} }
} }
} }
// log.Printf("XXX attempting to start app %s", appExe)
cmd_path := "c:\\Windows\\system32\\cmd.exe" cmd_path := "c:\\Windows\\system32\\cmd.exe"
cmd := exec.Command(cmd_path, "/c", appExe, "--hide", "--fast-startup") cmd := exec.Command(cmd_path, "/c", appExe)
// TODO - these hide flags aren't working - still pops up a command window for some reason
cmd.SysProcAttr = &syscall.SysProcAttr{CreationFlags: 0x08000000, HideWindow: true} cmd.SysProcAttr = &syscall.SysProcAttr{CreationFlags: 0x08000000, HideWindow: true}
// TODO this didn't help either...
cmd.Stdin = strings.NewReader("") cmd.Stdin = strings.NewReader("")
cmd.Stdout = os.Stdout cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr cmd.Stderr = os.Stderr
@@ -63,50 +56,3 @@ func startApp(ctx context.Context, client *api.Client) error {
} }
return waitForServer(ctx, client) return waitForServer(ctx, client)
} }
func isProcRunning(procName string) []uint32 {
pids := make([]uint32, 2048)
var ret uint32
if err := windows.EnumProcesses(pids, &ret); err != nil || ret == 0 {
slog.Debug("failed to check for running installers", "error", err)
return nil
}
if ret > uint32(len(pids)) {
pids = make([]uint32, ret+10)
if err := windows.EnumProcesses(pids, &ret); err != nil || ret == 0 {
slog.Debug("failed to check for running installers", "error", err)
return nil
}
}
if ret < uint32(len(pids)) {
pids = pids[:ret]
}
var matches []uint32
for _, pid := range pids {
if pid == 0 {
continue
}
hProcess, err := windows.OpenProcess(windows.PROCESS_QUERY_INFORMATION|windows.PROCESS_VM_READ, false, pid)
if err != nil {
continue
}
defer windows.CloseHandle(hProcess)
var module windows.Handle
var cbNeeded uint32
cb := (uint32)(unsafe.Sizeof(module))
if err := windows.EnumProcessModules(hProcess, &module, cb, &cbNeeded); err != nil {
continue
}
var sz uint32 = 1024 * 8
moduleName := make([]uint16, sz)
cb = uint32(len(moduleName)) * (uint32)(unsafe.Sizeof(uint16(0)))
if err := windows.GetModuleBaseName(hProcess, module, &moduleName[0], cb); err != nil && err != syscall.ERROR_INSUFFICIENT_BUFFER {
continue
}
exeFile := path.Base(strings.ToLower(syscall.UTF16ToString(moduleName)))
if strings.EqualFold(exeFile, procName) {
matches = append(matches, pid)
}
}
return matches
}

View File

@@ -1,63 +0,0 @@
package cmd
import (
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"os"
"strings"
"testing"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/types/model"
)
// Test that a warning is printed when thinking is requested but not supported.
func TestWarnMissingThinking(t *testing.T) {
cases := []struct {
capabilities []model.Capability
expectWarn bool
}{
{capabilities: []model.Capability{model.CapabilityThinking}, expectWarn: false},
{capabilities: []model.Capability{}, expectWarn: true},
}
for _, tc := range cases {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/api/show" || r.Method != http.MethodPost {
t.Fatalf("unexpected request to %s %s", r.URL.Path, r.Method)
}
var req api.ShowRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
t.Fatalf("decode request: %v", err)
}
resp := api.ShowResponse{Capabilities: tc.capabilities}
if err := json.NewEncoder(w).Encode(resp); err != nil {
t.Fatalf("encode response: %v", err)
}
}))
defer srv.Close()
t.Setenv("OLLAMA_HOST", srv.URL)
client, err := api.ClientFromEnvironment()
if err != nil {
t.Fatal(err)
}
oldStderr := os.Stderr
r, w, _ := os.Pipe()
os.Stderr = w
ensureThinkingSupport(t.Context(), client, "m")
w.Close()
os.Stderr = oldStderr
out, _ := io.ReadAll(r)
warned := strings.Contains(string(out), "warning:")
if tc.expectWarn && !warned {
t.Errorf("expected warning, got none")
}
if !tc.expectWarn && warned {
t.Errorf("did not expect warning, got: %s", string(out))
}
}
}

View File

@@ -1,7 +1,6 @@
package convert package convert
import ( import (
"cmp"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
@@ -15,12 +14,13 @@ import (
) )
type ModelParameters struct { type ModelParameters struct {
Architectures []string `json:"architectures"` Architectures []string `json:"architectures"`
VocabSize uint32 `json:"vocab_size"` VocabSize uint32 `json:"vocab_size"`
TextModel TextParameters `json:"text_config"`
}
TextModel struct { type TextParameters struct {
VocabSize uint32 `json:"vocab_size"` VocabSize uint32 `json:"vocab_size"`
} `json:"text_config"`
} }
type AdapterParameters struct { type AdapterParameters struct {
@@ -53,11 +53,8 @@ func (ModelParameters) KV(t *Tokenizer) ggml.KV {
} }
for _, sv := range t.SpecialVocabulary { for _, sv := range t.SpecialVocabulary {
kv[fmt.Sprintf("tokenizer.ggml.add_%s_token", sv.Key())] = sv.AddToken
kv[fmt.Sprintf("tokenizer.ggml.%s_token_id", sv.Key())] = uint32(sv.ID) kv[fmt.Sprintf("tokenizer.ggml.%s_token_id", sv.Key())] = uint32(sv.ID)
if len(sv.IDs) > 0 { kv[fmt.Sprintf("tokenizer.ggml.add_%s_token", sv.Key())] = sv.AddToken
kv[fmt.Sprintf("tokenizer.ggml.%s_token_ids", sv.Key())] = sv.IDs
}
} }
return kv return kv
@@ -176,8 +173,6 @@ func ConvertModel(fsys fs.FS, f *os.File) error {
switch p.Architectures[0] { switch p.Architectures[0] {
case "LlamaForCausalLM": case "LlamaForCausalLM":
conv = &llamaModel{} conv = &llamaModel{}
case "MllamaForConditionalGeneration":
conv = &mllamaModel{}
case "Llama4ForConditionalGeneration": case "Llama4ForConditionalGeneration":
conv = &llama4Model{} conv = &llama4Model{}
case "Mistral3ForConditionalGeneration": case "Mistral3ForConditionalGeneration":
@@ -190,14 +185,10 @@ func ConvertModel(fsys fs.FS, f *os.File) error {
conv = &gemma2Model{} conv = &gemma2Model{}
case "Gemma3ForCausalLM", "Gemma3ForConditionalGeneration": case "Gemma3ForCausalLM", "Gemma3ForConditionalGeneration":
conv = &gemma3Model{Architecture: p.Architectures[0]} conv = &gemma3Model{Architecture: p.Architectures[0]}
case "Gemma3nForConditionalGeneration":
conv = &gemma3nModel{}
case "Phi3ForCausalLM": case "Phi3ForCausalLM":
conv = &phi3Model{} conv = &phi3Model{}
case "Qwen2ForCausalLM": case "Qwen2ForCausalLM":
conv = &qwen2Model{} conv = &qwen2Model{}
case "Qwen2_5_VLForConditionalGeneration":
conv = &qwen25VLModel{}
case "BertModel": case "BertModel":
conv = &bertModel{} conv = &bertModel{}
case "CohereForCausalLM": case "CohereForCausalLM":
@@ -221,22 +212,24 @@ func ConvertModel(fsys fs.FS, f *os.File) error {
return err return err
} }
vocabSize := int(cmp.Or(p.VocabSize, p.TextModel.VocabSize)) vocabSize := int(p.VocabSize)
if vocabSize == 0 {
tVocabSize := int(p.TextModel.VocabSize)
vocabSize = tVocabSize
}
switch { switch {
case vocabSize == 0: case vocabSize == 0:
slog.Debug("vocabulary size was not explicitly set by the model", "default size", len(t.Vocabulary.Tokens)) slog.Warn("vocabulary size was not explicitly set by the model", "default size", len(t.Vocabulary.Tokens))
case vocabSize > len(t.Vocabulary.Tokens): case vocabSize > len(t.Vocabulary.Tokens):
slog.Debug("vocabulary is smaller than expected, padding with dummy tokens", "expect", vocabSize, "actual", len(t.Vocabulary.Tokens)) slog.Warn("vocabulary is smaller than expected, padding with dummy tokens", "expect", vocabSize, "actual", len(t.Vocabulary.Tokens))
for i := range vocabSize - len(t.Vocabulary.Tokens) { for i := range vocabSize - len(t.Vocabulary.Tokens) {
t.Vocabulary.Tokens = append(t.Vocabulary.Tokens, fmt.Sprintf("[PAD%d]", i)) t.Vocabulary.Tokens = append(t.Vocabulary.Tokens, fmt.Sprintf("[PAD%d]", i))
t.Vocabulary.Scores = append(t.Vocabulary.Scores, -1) t.Vocabulary.Scores = append(t.Vocabulary.Scores, -1)
t.Vocabulary.Types = append(t.Vocabulary.Types, tokenTypeUserDefined) t.Vocabulary.Types = append(t.Vocabulary.Types, tokenTypeUserDefined)
} }
case vocabSize < len(t.Vocabulary.Tokens): case vocabSize < len(t.Vocabulary.Tokens):
slog.Debug("vocabulary is larger than expected", "want", vocabSize, "got", len(t.Vocabulary.Tokens)) return fmt.Errorf("vocabulary is larger than expected '%d' instead of '%d'", len(t.Vocabulary.Tokens), vocabSize)
p.VocabSize = uint32(len(t.Vocabulary.Tokens))
p.TextModel.VocabSize = uint32(len(t.Vocabulary.Tokens))
default: default:
slog.Debug("vocabulary", "size", len(t.Vocabulary.Tokens)) slog.Debug("vocabulary", "size", len(t.Vocabulary.Tokens))
} }

View File

@@ -1,165 +0,0 @@
package convert
import (
"slices"
"strings"
"github.com/ollama/ollama/fs/ggml"
"github.com/pdevine/tensor"
"github.com/pdevine/tensor/native"
"gonum.org/v1/gonum/stat/distuv"
)
type gemma3nModel struct {
ModelParameters
TextModel struct {
ActivationSparsityPattern []float32 `json:"activation_sparsity_pattern"`
AltupActiveIdx uint32 `json:"altup_active_idx"`
AltupCoefClip float32 `json:"altup_coef_clip"`
AltupCorrectScale bool `json:"altup_correct_scale"`
AltupLRMultiplier float32 `json:"altup_lr_multiplier"`
AltupNumInputs uint32 `json:"altup_num_inputs"`
HeadDim uint32 `json:"head_dim"`
HiddenSize uint32 `json:"hidden_size"`
HiddenSizePerLayerInput uint32 `json:"hidden_size_per_layer_input"`
IntermediateSize uint32 `json:"intermediate_size"`
MaxPositionEmbeddings uint32 `json:"max_position_embeddings"`
NumAttentionHeads uint32 `json:"num_attention_heads"`
NumHiddenLayers uint32 `json:"num_hidden_layers"`
NumKeyValueHeads uint32 `json:"num_key_value_heads"`
NumKVSharedLayers uint32 `json:"num_kv_shared_layers"`
RMSNormEPS float32 `json:"rms_norm_eps"`
RopeLocalBaseFreq float32 `json:"rope_local_base_freq"`
RopeTheta float32 `json:"rope_theta"`
SlidingWindow uint32 `json:"sliding_window"`
LayerTypes []string `json:"layer_types"`
} `json:"text_config"`
VisionModel struct{} `json:"vision_config"`
}
func (m *gemma3nModel) KV(t *Tokenizer) ggml.KV {
kv := m.ModelParameters.KV(t)
kv["general.architecture"] = "gemma3n"
kv["gemma3n.activation_sparsity_scale"] = slices.Collect(func(yield func(float32) bool) {
norm := distuv.Normal{Mu: 0, Sigma: 1}
for _, v := range m.TextModel.ActivationSparsityPattern {
if !yield(float32(norm.Quantile(float64(v)))) {
break
}
}
})
kv["gemma3n.altup.active_idx"] = m.TextModel.AltupActiveIdx
kv["gemma3n.altup.correct_scale"] = m.TextModel.AltupCorrectScale
kv["gemma3n.altup.lr_multiplier"] = m.TextModel.AltupLRMultiplier
kv["gemma3n.altup.num_inputs"] = m.TextModel.AltupNumInputs
kv["gemma3n.attention.head_count_kv"] = m.TextModel.NumKeyValueHeads
kv["gemma3n.attention.head_count"] = m.TextModel.NumAttentionHeads
kv["gemma3n.attention.layer_norm_rms_epsilon"] = m.TextModel.RMSNormEPS
kv["gemma3n.attention.sliding_window"] = m.TextModel.SlidingWindow
kv["gemma3n.attention.sliding_window_pattern"] = slices.Collect(func(yield func(bool) bool) {
for _, t := range m.TextModel.LayerTypes {
if !yield(t == "sliding_attention") {
break
}
}
})
kv["gemma3n.attention.shared_kv_layers"] = m.TextModel.NumKVSharedLayers
kv["gemma3n.block_count"] = m.TextModel.NumHiddenLayers
kv["gemma3n.context_length"] = m.TextModel.MaxPositionEmbeddings
kv["gemma3n.embedding_length_per_layer_input"] = m.TextModel.HiddenSizePerLayerInput
kv["gemma3n.embedding_length"] = m.TextModel.HiddenSize
kv["gemma3n.feed_forward_length"] = m.TextModel.IntermediateSize
kv["gemma3n.head_dim"] = m.TextModel.HeadDim
kv["gemma3n.rope.freq_base_local"] = m.TextModel.RopeLocalBaseFreq
kv["gemma3n.rope.freq_base"] = m.TextModel.RopeTheta
return kv
}
func (m *gemma3nModel) Tensors(ts []Tensor) []*ggml.Tensor {
out, ts := mergeTensors(ts,
merge{"altup_proj.*.weight", "altup_proj.weight"},
merge{"altup_unembd_proj.*.weight", "altup_unembd_proj.weight"},
)
for _, t := range ts {
switch {
case strings.Contains(t.Name(), "audio_tower"),
strings.Contains(t.Name(), "embed_audio"),
strings.Contains(t.Name(), "vision_tower"),
strings.Contains(t.Name(), "embed_vision"):
// TODO: handle audio and vision towers
continue
case strings.Contains(t.Name(), "altup_predict_coef"),
strings.Contains(t.Name(), "altup_correct_coef"):
if m.TextModel.AltupCoefClip > 0 {
t.SetRepacker(func(name string, data []float32, shape []uint64) (_ []float32, err error) {
dims := make([]int, len(shape))
for i := range shape {
dims[i] = int(shape[i])
}
var t tensor.Tensor = tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data))
t, err = tensor.Clamp(t, -m.TextModel.AltupCoefClip, m.TextModel.AltupCoefClip)
if err != nil {
return nil, err
}
if err := t.Reshape(t.Shape().TotalSize()); err != nil {
return nil, err
}
return native.VectorF32(t.(*tensor.Dense))
})
}
}
out = append(out, &ggml.Tensor{
Name: t.Name(),
Kind: t.Kind(),
Shape: t.Shape(),
WriterTo: t,
})
}
return out
}
func (m *gemma3nModel) Replacements() []string {
return []string{
"model.language_model.embed_tokens_per_layer", "per_layer_token_embd",
"model.language_model.embed_tokens", "token_embd",
"model.language_model.per_layer_model_projection", "per_layer_model_proj",
"model.language_model.per_layer_projection_norm", "per_layer_proj_norm", "model.language_model.altup_projections", "altup_proj",
"model.language_model.altup_unembed_projections", "altup_unembd_proj",
"model.language_model.norm", "output_norm",
"model.language_model.layers", "blk",
"input_layernorm", "attn_norm",
"self_attn.q_proj", "attn_q",
"self_attn.q_norm", "attn_q_norm",
"self_attn.k_proj", "attn_k",
"self_attn.k_norm", "attn_k_norm",
"self_attn.v_proj", "attn_v",
"self_attn.o_proj", "attn_output",
"post_attention_layernorm", "post_attention_norm",
"pre_feedforward_layernorm", "ffn_norm",
"mlp.gate_proj", "ffn_gate",
"mlp.up_proj", "ffn_up",
"mlp.down_proj", "ffn_down",
"post_feedforward_layernorm", "post_ffw_norm",
"per_layer_input_gate", "inp_gate",
"per_layer_projection", "proj",
"post_per_layer_input_norm", "post_norm",
"altup.", "altup_",
"modality_router", "router",
"prediction_coefs", "predict_coef",
"correction_coefs", "correct_coef",
"correct_output_scale", "correct_scale.weight",
"laurel.", "laurel_",
"linear_left", "l",
"linear_right", "r",
"post_laurel_norm", "post_norm",
}
}

View File

@@ -139,8 +139,7 @@ func (p *llamaModel) Tensors(ts []Tensor) []*ggml.Tensor {
} }
for _, t := range ts { for _, t := range ts {
if strings.HasSuffix(t.Name(), "attn_q.weight") || strings.HasSuffix(t.Name(), "attn_k.weight") || if strings.HasSuffix(t.Name(), "attn_q.weight") || strings.HasSuffix(t.Name(), "attn_k.weight") {
strings.HasSuffix(t.Name(), "attn_q_proj.weight") || strings.HasSuffix(t.Name(), "attn_k_proj.weight") {
if !p.skipRepack { if !p.skipRepack {
t.SetRepacker(p.repack) t.SetRepacker(p.repack)
} }
@@ -182,9 +181,9 @@ func (p *llamaModel) repack(name string, data []float32, shape []uint64) ([]floa
} }
var heads uint32 var heads uint32
if strings.HasSuffix(name, "attn_q.weight") || strings.HasSuffix(name, "attn_q_proj.weight") { if strings.HasSuffix(name, "attn_q.weight") {
heads = p.NumAttentionHeads heads = p.NumAttentionHeads
} else if strings.HasSuffix(name, "attn_k.weight") || strings.HasSuffix(name, "attn_k_proj.weight") { } else if strings.HasSuffix(name, "attn_k.weight") {
heads = cmp.Or(p.NumKeyValueHeads, p.NumAttentionHeads) heads = cmp.Or(p.NumKeyValueHeads, p.NumAttentionHeads)
} else { } else {
return nil, fmt.Errorf("unknown tensor for repack: %s", name) return nil, fmt.Errorf("unknown tensor for repack: %s", name)

View File

@@ -2,6 +2,9 @@ package convert
import ( import (
"fmt" "fmt"
"io"
"slices"
"strings"
"github.com/ollama/ollama/fs/ggml" "github.com/ollama/ollama/fs/ggml"
) )
@@ -27,38 +30,65 @@ func (p *mixtralModel) KV(t *Tokenizer) ggml.KV {
} }
func (p *mixtralModel) Tensors(ts []Tensor) []*ggml.Tensor { func (p *mixtralModel) Tensors(ts []Tensor) []*ggml.Tensor {
merges := make([]merge, 0, p.NumHiddenLayers*6) oldnew := []string{
for i := range p.NumHiddenLayers { "model.layers", "blk",
merges = append(merges, merge{ "w1", "ffn_gate_exps",
fmt.Sprintf("blk.%d.*.w1.weight", i), "w2", "ffn_down_exps",
fmt.Sprintf("blk.%d.ffn_gate_exps.weight", i), "w3", "ffn_up_exps",
}, merge{ }
fmt.Sprintf("blk.%d.*.w1.bias", i),
fmt.Sprintf("blk.%d.ffn_gate_exps.bias", i), for i := range p.NumLocalExperts {
}, merge{ oldnew = append(oldnew, fmt.Sprintf(".block_sparse_moe.experts.%d.", i), ".")
fmt.Sprintf("blk.%d.*.w2.weight", i), }
fmt.Sprintf("blk.%d.ffn_up_exps.weight", i),
}, merge{ // group experts of the same layer (model.layers.%d) and type (w[123]) into a single tensor
fmt.Sprintf("blk.%d.*.w2.bias", i), namer := strings.NewReplacer(oldnew...)
fmt.Sprintf("blk.%d.ffn_up_exps.bias", i), experts := make(map[string]experts)
}, merge{
fmt.Sprintf("blk.%d.*.w3.weight", i), // merge experts into a single tensor while removing them from ts
fmt.Sprintf("blk.%d.ffn_down_exps.weight", i), ts = slices.DeleteFunc(ts, func(t Tensor) bool {
}, merge{ if !strings.Contains(t.Name(), ".block_sparse_moe.experts.") {
fmt.Sprintf("blk.%d.*.w3.bias", i), return false
fmt.Sprintf("blk.%d.ffn_down_exps.bias", i), }
name := namer.Replace(t.Name())
experts[name] = append(experts[name], t)
return true
})
var out []*ggml.Tensor
for n, e := range experts {
// TODO(mxyng): sanity check experts
out = append(out, &ggml.Tensor{
Name: n,
Kind: e[0].Kind(),
Shape: append([]uint64{uint64(len(e))}, e[0].Shape()...),
WriterTo: e,
}) })
} }
out, ts := mergeTensors(ts, merges...)
return append(out, p.llamaModel.Tensors(ts)...) return append(out, p.llamaModel.Tensors(ts)...)
} }
func (p *mixtralModel) Replacements() []string { func (p *mixtralModel) Replacements() []string {
return append( return append(
p.llamaModel.Replacements(), p.llamaModel.Replacements(),
"model.layers", "blk",
"block_sparse_moe.gate", "ffn_gate_inp", "block_sparse_moe.gate", "ffn_gate_inp",
"block_sparse_moe.experts.", ".",
) )
} }
type experts []Tensor
func (e experts) WriteTo(w io.Writer) (int64, error) {
// TODO(mxyng): experts _should_ be numerically sorted by expert but this should check
for _, t := range e {
// the canonical merged experts tensor stacks all experts along a new, 0 axis,
// e.g. `tensor.Stack(0, e[0], e[1:]...)`, which requires allocating temporary buffers
// this accomplishes the same thing by writing each expert tensor in sequence
if _, err := t.WriteTo(w); err != nil {
return 0, err
}
}
return 0, nil
}

View File

@@ -1,179 +0,0 @@
package convert
import (
"strings"
"github.com/ollama/ollama/fs/ggml"
"github.com/pdevine/tensor"
"github.com/pdevine/tensor/native"
)
type mllamaModel struct {
ModelParameters
TextModel struct {
llamaModel
CrossAttentionLayers []int32 `json:"cross_attention_layers"`
} `json:"text_config"`
VisionModel struct {
NumHiddenLayers uint32 `json:"num_hidden_layers"`
NumGlobalLayers uint32 `json:"num_global_layers"`
IntermediateLayersIndices []int32 `json:"intermediate_layers_indices"`
HiddenSize uint32 `json:"hidden_size"`
IntermediateSize uint32 `json:"intermediate_size"`
AttentionHeads uint32 `json:"attention_heads"`
ImageSize uint32 `json:"image_size"`
PatchSize uint32 `json:"patch_size"`
NumChannels uint32 `json:"num_channels"`
MaxNumTiles uint32 `json:"max_num_tiles"`
NormEpsilon float32 `json:"norm_eps"`
RopeTheta float32 `json:"rope.freq_base"`
} `json:"vision_config"`
}
func (m *mllamaModel) KV(t *Tokenizer) ggml.KV {
kv := m.ModelParameters.KV(t)
kv["general.architecture"] = "mllama"
for k, v := range m.TextModel.KV(t) {
if strings.HasPrefix(k, "llama.") {
kv[strings.ReplaceAll(k, "llama.", "mllama.")] = v
}
}
kv["mllama.attention.cross_attention_layers"] = m.TextModel.CrossAttentionLayers
kv["mllama.vision.block_count"] = m.VisionModel.NumHiddenLayers
kv["mllama.vision.global.block_count"] = m.VisionModel.NumGlobalLayers
kv["mllama.vision.intermediate_layers_indices"] = m.VisionModel.IntermediateLayersIndices
kv["mllama.vision.embedding_length"] = m.VisionModel.HiddenSize
kv["mllama.vision.feed_forward_length"] = m.VisionModel.IntermediateSize
kv["mllama.vision.attention.head_count"] = m.VisionModel.AttentionHeads
kv["mllama.vision.attention.layer_norm_epsilon"] = m.VisionModel.NormEpsilon
kv["mllama.vision.image_size"] = m.VisionModel.ImageSize
kv["mllama.vision.patch_size"] = m.VisionModel.PatchSize
kv["mllama.vision.max_num_tiles"] = m.VisionModel.MaxNumTiles
kv["mllama.vision.num_channels"] = m.VisionModel.NumChannels
return kv
}
func (m *mllamaModel) Replacements() []string {
return append(
m.TextModel.Replacements(),
"language_model.", "",
"gate_attn", "attn_gate",
"gate_ffn", "ffn_gate",
"cross_attn.", "cross_attn_",
"vision_model", "v",
"class_embedding", "class_embd",
"patch_embedding", "patch_embd",
"gated_positional_embedding.tile_embedding", "tile_position_embd",
"gated_positional_embedding.embedding", "position_embd.weight",
"gated_positional_embedding", "position_embd",
"embedding.weight", "weight",
"pre_tile_positional_embedding", "pre_tile_position_embd",
"post_tile_positional_embedding", "post_tile_position_embd",
"layernorm_pre", "pre_ln",
"layernorm_post", "post_ln",
"global_transformer.layers", "global.blk",
"transformer.layers", "blk",
"mlp.fc1", "ffn_up",
"mlp.fc2", "ffn_down",
"multi_modal_projector", "mm.0",
)
}
func (m *mllamaModel) Tensors(ts []Tensor) []*ggml.Tensor {
var out []*ggml.Tensor
var text []Tensor
for _, t := range ts {
if !strings.HasPrefix(t.Name(), "v.") && !strings.HasPrefix(t.Name(), "mm.") {
text = append(text, t)
} else if t.Name() == "v.position_embd.gate" {
for _, name := range []string{"v.position_embd.gate", "v.tile_position_embd.gate"} {
tt := t.Clone()
tt.SetRepacker(m.repack(name))
out = append(out, &ggml.Tensor{
Name: name,
Kind: t.Kind(),
Shape: t.Shape(),
WriterTo: tt,
})
}
} else {
if t.Name() == "v.pre_tile_position_embd.gate" || t.Name() == "v.post_tile_position_embd.gate" {
t.SetRepacker(m.repack(t.Name()))
} else if strings.HasSuffix(t.Name(), "attn_q.weight") || strings.HasSuffix(t.Name(), "attn_k.weight") {
t.SetRepacker(m.repack(t.Name()))
} else if strings.HasSuffix(t.Name(), "attn_gate") || strings.HasSuffix(t.Name(), "ffn_gate") {
t.SetRepacker(m.repack(t.Name()))
}
out = append(out, &ggml.Tensor{
Name: t.Name(),
Kind: t.Kind(),
Shape: t.Shape(),
WriterTo: t,
})
}
}
return append(out, m.TextModel.Tensors(text)...)
}
func (m *mllamaModel) repack(name string) Repacker {
return func(_ string, data []float32, shape []uint64) (_ []float32, err error) {
dims := make([]int, len(shape))
for i, dim := range shape {
dims[i] = int(dim)
}
var t tensor.Tensor = tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data))
if strings.HasSuffix(name, "attn_q.weight") || strings.HasSuffix(name, "attn_k.weight") {
heads := m.VisionModel.AttentionHeads
if err := t.Reshape(append([]int{int(heads), 2, dims[0] / int(heads) / 2}, dims[1:]...)...); err != nil {
return nil, err
}
if err := t.T(0, 2, 1, 3); err != nil {
return nil, err
}
if err := t.Reshape(dims...); err != nil {
return nil, err
}
if err := t.Transpose(); err != nil {
return nil, err
}
} else {
t, err = tensor.Tanh(t)
if err != nil {
return nil, err
}
if name == "v.position_embd.gate" {
t, err = tensor.Sub(float32(1), t)
if err != nil {
return nil, err
}
}
}
t = tensor.Materialize(t)
// flatten tensor so it can be return as a vector
if err := t.Reshape(t.Shape().TotalSize()); err != nil {
return nil, err
}
return native.VectorF32(t.(*tensor.Dense))
}
}

View File

@@ -15,7 +15,6 @@ type qwen2Model struct {
Type string `json:"type"` Type string `json:"type"`
Factor ropeFactor `json:"factor"` Factor ropeFactor `json:"factor"`
OriginalMaxPositionEmbeddings uint32 `json:"original_max_position_embeddings"` OriginalMaxPositionEmbeddings uint32 `json:"original_max_position_embeddings"`
MropeSection []int32 `json:"mrope_section"`
} `json:"rope_scaling"` } `json:"rope_scaling"`
RMSNormEPS float32 `json:"rms_norm_eps"` RMSNormEPS float32 `json:"rms_norm_eps"`
} }
@@ -40,8 +39,6 @@ func (q *qwen2Model) KV(t *Tokenizer) ggml.KV {
case "yarn": case "yarn":
kv["qwen2.rope.scaling.type"] = q.RopeScaling.Type kv["qwen2.rope.scaling.type"] = q.RopeScaling.Type
kv["qwen2.rope.scaling.factor"] = q.RopeScaling.Factor kv["qwen2.rope.scaling.factor"] = q.RopeScaling.Factor
case "mrope", "default":
kv["qwen2.rope.mrope_section"] = q.RopeScaling.MropeSection
default: default:
panic("unknown rope scaling type") panic("unknown rope scaling type")
} }

View File

@@ -1,102 +0,0 @@
package convert
import (
"cmp"
"slices"
"strings"
"github.com/ollama/ollama/fs/ggml"
)
type qwen25VLModel struct {
qwen2Model
VisionModel struct {
Depth uint32 `json:"depth"`
HiddenSize uint32 `json:"hidden_size"`
NumHeads uint32 `json:"num_heads"`
InChannels uint32 `json:"in_chans"`
PatchSize uint32 `json:"patch_size"`
SpatialMergeSize uint32 `json:"spatial_merge_size"`
SpatialPatchSize uint32 `json:"spatial_patch_size"`
WindowSize uint32 `json:"window_size"`
RMSNormEps float32 `json:"layer_norm_epsilon"`
RopeTheta float32 `json:"rope_theta"`
FullAttentionBlocks []int32 `json:"fullatt_block_indexes"`
TemporalPatchSize uint32 `json:"temporal_patch_size"`
} `json:"vision_config"`
}
var _ ModelConverter = (*qwen25VLModel)(nil)
func (q *qwen25VLModel) KV(t *Tokenizer) ggml.KV {
kv := q.ModelParameters.KV(t)
kv["general.architecture"] = "qwen25vl"
for k, v := range q.qwen2Model.KV(t) {
if strings.HasPrefix(k, "qwen2.") {
kv[strings.Replace(k, "qwen2.", "qwen25vl.", 1)] = v
}
}
if q.VisionModel.FullAttentionBlocks == nil {
kv["qwen25vl.vision.fullatt_block_indexes"] = []int32{7, 15, 23, 31}
}
kv["qwen25vl.vision.block_count"] = cmp.Or(q.VisionModel.Depth, 32)
kv["qwen25vl.vision.embedding_length"] = q.VisionModel.HiddenSize
kv["qwen25vl.vision.attention.head_count"] = cmp.Or(q.VisionModel.NumHeads, 16)
kv["qwen25vl.vision.num_channels"] = q.VisionModel.InChannels
kv["qwen25vl.vision.patch_size"] = cmp.Or(q.VisionModel.PatchSize, 14)
kv["qwen25vl.vision.spatial_merge_size"] = cmp.Or(q.VisionModel.SpatialMergeSize, 2)
kv["qwen25vl.vision.spatial_patch_size"] = q.VisionModel.SpatialPatchSize
kv["qwen25vl.vision.window_size"] = cmp.Or(q.VisionModel.WindowSize, 112)
kv["qwen25vl.vision.attention.layer_norm_epsilon"] = cmp.Or(q.VisionModel.RMSNormEps, 1e-6)
kv["qwen25vl.vision.rope.freq_base"] = cmp.Or(q.VisionModel.RopeTheta, 1e4)
kv["qwen25vl.vision.fullatt_block_indexes"] = q.VisionModel.FullAttentionBlocks
kv["qwen25vl.vision.temporal_patch_size"] = cmp.Or(q.VisionModel.TemporalPatchSize, 2)
return kv
}
func (q *qwen25VLModel) Tensors(ts []Tensor) []*ggml.Tensor {
var out []*ggml.Tensor
for _, t := range ts {
if strings.Contains(t.Name(), "patch_embed.proj") {
for t := range splitDim(t, 2,
split{Replacer: strings.NewReplacer("patch_embed.proj", "patch_embd_0")},
split{Replacer: strings.NewReplacer("patch_embed.proj", "patch_embd_1")},
) {
t.Shape = slices.DeleteFunc(t.Shape, func(i uint64) bool { return i == 1 })
out = append(out, t)
}
} else if strings.Contains(t.Name(), "attn.qkv") {
out = append(out, slices.Collect(splitDim(t, 0,
split{Replacer: strings.NewReplacer("attn.qkv", "attn_q")},
split{Replacer: strings.NewReplacer("attn.qkv", "attn_k")},
split{Replacer: strings.NewReplacer("attn.qkv", "attn_v")},
))...)
} else {
out = append(out, &ggml.Tensor{
Name: t.Name(),
Kind: t.Kind(),
Shape: t.Shape(),
WriterTo: t,
})
}
}
return out
}
func (p *qwen25VLModel) Replacements() []string {
return append(
p.qwen2Model.Replacements(),
"visual", "v",
"blocks", "blk",
"attn.proj", "attn_out",
"norm1", "ln1",
"norm2", "ln2",
)
}

View File

@@ -11,13 +11,14 @@ import (
"io" "io"
"io/fs" "io/fs"
"log/slog" "log/slog"
"maps"
"os" "os"
"path/filepath" "path/filepath"
"slices" "slices"
"strings" "strings"
"testing" "testing"
"golang.org/x/exp/maps"
"github.com/ollama/ollama/fs/ggml" "github.com/ollama/ollama/fs/ggml"
) )
@@ -46,7 +47,7 @@ func convertFull(t *testing.T, fsys fs.FS) (*os.File, ggml.KV, ggml.Tensors) {
} }
t.Cleanup(func() { r.Close() }) t.Cleanup(func() { r.Close() })
m, err := ggml.Decode(r, -1) m, _, err := ggml.Decode(r, -1)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -136,7 +137,9 @@ func TestConvertModel(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
for _, k := range slices.Sorted(maps.Keys(expect)) { keys := maps.Keys(expect)
slices.Sort(keys)
for _, k := range keys {
if v, ok := actual[k]; !ok { if v, ok := actual[k]; !ok {
t.Errorf("missing %s", k) t.Errorf("missing %s", k)
} else if v != expect[k] { } else if v != expect[k] {
@@ -329,7 +332,7 @@ func TestConvertAdapter(t *testing.T) {
} }
defer r.Close() defer r.Close()
m, err := ggml.Decode(r, -1) m, _, err := ggml.Decode(r, -1)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -340,7 +343,9 @@ func TestConvertAdapter(t *testing.T) {
actual := generateResultsJSON(t, r, m.KV(), m.Tensors()) actual := generateResultsJSON(t, r, m.KV(), m.Tensors())
for _, k := range slices.Sorted(maps.Keys(c.Expected)) { keys := maps.Keys(c.Expected)
slices.Sort(keys)
for _, k := range keys {
if v, ok := actual[k]; !ok { if v, ok := actual[k]; !ok {
t.Errorf("missing %s", k) t.Errorf("missing %s", k)
} else if v != c.Expected[k] { } else if v != c.Expected[k] {

58
convert/fs.go Normal file
View File

@@ -0,0 +1,58 @@
package convert
import (
"archive/zip"
"errors"
"io"
"io/fs"
"os"
"path/filepath"
)
type ZipReader struct {
r *zip.Reader
p string
// limit is the maximum size of a file that can be read directly
// from the zip archive. Files larger than this size will be extracted
limit int64
}
func NewZipReader(r *zip.Reader, p string, limit int64) fs.FS {
return &ZipReader{r, p, limit}
}
func (z *ZipReader) Open(name string) (fs.File, error) {
r, err := z.r.Open(name)
if err != nil {
return nil, err
}
defer r.Close()
if fi, err := r.Stat(); err != nil {
return nil, err
} else if fi.Size() < z.limit {
return r, nil
}
if !filepath.IsLocal(name) {
return nil, zip.ErrInsecurePath
}
n := filepath.Join(z.p, name)
if _, err := os.Stat(n); errors.Is(err, os.ErrNotExist) {
w, err := os.Create(n)
if err != nil {
return nil, err
}
defer w.Close()
if _, err := io.Copy(w, r); err != nil {
return nil, err
}
} else if err != nil {
return nil, err
}
return os.Open(n)
}

View File

@@ -38,10 +38,7 @@ const (
func (t tensorBase) Kind() uint32 { func (t tensorBase) Kind() uint32 {
if strings.HasSuffix(t.name, ".ffn_gate_inp.weight") || if strings.HasSuffix(t.name, ".ffn_gate_inp.weight") ||
t.name == "token_types.weight" || t.name == "token_types.weight" ||
t.name == "v.positional_embedding_vlm" || t.name == "v.positional_embedding_vlm" {
t.name == "v.tile_position_embd.weight" ||
t.name == "v.pre_tile_position_embd.weight" ||
t.name == "v.post_tile_position_embd.weight" {
// these tensors are always F32 // these tensors are always F32
return 0 return 0
} }

View File

@@ -8,12 +8,12 @@ import (
"fmt" "fmt"
"io" "io"
"io/fs" "io/fs"
"maps"
"slices" "slices"
"strings" "strings"
"github.com/d4l3k/go-bfloat16" "github.com/d4l3k/go-bfloat16"
"github.com/x448/float16" "github.com/x448/float16"
"golang.org/x/exp/maps"
) )
type safetensorMetadata struct { type safetensorMetadata struct {
@@ -46,7 +46,8 @@ func parseSafetensors(fsys fs.FS, replacer *strings.Replacer, ps ...string) ([]T
return nil, err return nil, err
} }
keys := slices.Sorted(maps.Keys(headers)) keys := maps.Keys(headers)
slices.Sort(keys)
names := make(map[string]struct{}, len(keys)) names := make(map[string]struct{}, len(keys))

View File

@@ -1,129 +0,0 @@
package convert
import (
"cmp"
"io"
"iter"
"path"
"slices"
"strings"
"github.com/pdevine/tensor"
"github.com/pdevine/tensor/native"
"github.com/ollama/ollama/fs/ggml"
)
type split struct {
*strings.Replacer
dim int
// fn is an optional function to apply to the tensor after slicing
fn func(tensor.Tensor) (tensor.Tensor, error)
}
// splitDim splits a tensor along a specified dimension into multiple tensors. The dimension
// is split evenly based on the number of replacers provided unless a specific count is given.
func splitDim(t Tensor, dim int, splits ...split) iter.Seq[*ggml.Tensor] {
return func(yield func(*ggml.Tensor) bool) {
var offset int
for _, split := range splits {
t := t.Clone()
shape := slices.Clone(t.Shape())
shape[dim] = cmp.Or(uint64(split.dim), shape[dim]/uint64(len(splits)))
slice := slices.Repeat([]tensor.Slice{nil}, len(shape))
slice[dim] = tensor.S(offset, offset+int(shape[dim]))
offset += int(shape[dim])
t.SetRepacker(func(_ string, data []float32, shape []uint64) ([]float32, error) {
dims := make([]int, len(shape))
for i := range shape {
dims[i] = int(shape[i])
}
var tt tensor.Tensor = tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data))
tt, err := tt.Slice(slice...)
if err != nil {
return nil, err
}
tt = tensor.Materialize(tt)
if split.fn != nil {
tt, err = split.fn(tt)
if err != nil {
return nil, err
}
}
// flatten tensor so it can be written as a vector
if err := tt.Reshape(tt.Shape().TotalSize()); err != nil {
return nil, err
}
return native.VectorF32(tt.(*tensor.Dense))
})
if !yield(&ggml.Tensor{
Name: split.Replace(t.Name()),
Kind: t.Kind(),
Shape: shape,
WriterTo: t,
}) {
break
}
}
}
}
type merge struct {
pattern, name string
}
// mergeTensors merges tensors that match a given pattern into a single tensor.
func mergeTensors(unmatched []Tensor, merges ...merge) (out []*ggml.Tensor, _ []Tensor) {
var matched []Tensor
for i := range merges {
matched, unmatched = slicesSplitFunc(unmatched, func(t Tensor) bool {
matched, _ := path.Match(merges[i].pattern, t.Name())
return matched
})
if len(matched) > 0 {
out = append(out, &ggml.Tensor{
Name: merges[i].name,
Kind: matched[0].Kind(),
Shape: append([]uint64{uint64(len(matched))}, matched[0].Shape()...),
WriterTo: mergeGroup(matched),
})
}
}
return out, unmatched
}
// slicesSplitFunc splits a slice into two slices based on a predicate function.
func slicesSplitFunc[S ~[]E, E comparable](s S, fn func(e E) bool) (matched, unmatched S) {
for _, e := range s {
if fn(e) {
matched = append(matched, e)
} else {
unmatched = append(unmatched, e)
}
}
return matched, unmatched
}
type mergeGroup []Tensor
func (g mergeGroup) WriteTo(w io.Writer) (int64, error) {
for _, t := range g {
if _, err := t.WriteTo(w); err != nil {
return 0, err
}
}
return 0, nil
}

View File

@@ -1,402 +0,0 @@
package convert
import (
"bytes"
"encoding/binary"
"io"
"iter"
"slices"
"strings"
"testing"
"github.com/google/go-cmp/cmp"
"github.com/ollama/ollama/fs/ggml"
"github.com/pdevine/tensor"
)
type fakeTensor struct {
name string
shape []uint64
data []float32
repacker Repacker
}
func (f fakeTensor) Name() string {
return f.name
}
func (f fakeTensor) Shape() []uint64 {
return f.shape
}
func (f fakeTensor) Kind() uint32 {
return 0
}
func (f *fakeTensor) SetRepacker(fn Repacker) {
f.repacker = fn
}
func (f fakeTensor) Clone() Tensor {
return &fakeTensor{
name: f.name,
shape: slices.Clone(f.shape),
data: slices.Clone(f.data),
repacker: f.repacker,
}
}
func (f fakeTensor) WriteTo(w io.Writer) (n int64, err error) {
data := f.data
if f.repacker != nil {
data, err = f.repacker(f.name, data, f.shape)
if err != nil {
return 0, err
}
}
if err := binary.Write(w, binary.LittleEndian, data); err != nil {
return 0, err
}
return int64(len(data) * 4), nil
}
func mul(shape []uint64) int {
n := 1
for _, dim := range shape {
n *= int(dim)
}
return n
}
func TestSplitDim(t *testing.T) {
r := fakeTensor{
name: "a.b",
shape: []uint64{3, 4},
data: []float32{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11},
}
t.Run("no split", func(t *testing.T) {
for tt := range splitDim(&r, 0, split{Replacer: strings.NewReplacer("a", "x")}) {
if tt.Name != "x.b" {
t.Fatalf("expected name 'x', got '%s'", tt.Name)
}
if !slices.Equal(tt.Shape, []uint64{3, 4}) {
t.Fatalf("expected shape [3, 4], got %v", tt.Shape)
}
var b bytes.Buffer
if _, err := tt.WriteTo(&b); err != nil {
t.Fatal(err)
}
f32s := make([]float32, mul(tt.Shape))
if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
t.Fatal(err)
}
if !slices.Equal(f32s, []float32{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}) {
t.Fatalf("expected data [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], got %v", f32s)
}
}
})
t.Run("even split", func(t *testing.T) {
next, stop := iter.Pull(splitDim(&r, 1,
split{Replacer: strings.NewReplacer("a", "x")},
split{Replacer: strings.NewReplacer("b", "y")},
))
defer stop()
{
tt, ok := next()
if !ok {
t.Fatal("expected at least one split")
}
if tt.Name != "x.b" {
t.Fatal("expected name 'x.b', got", tt.Name)
}
if !slices.Equal(tt.Shape, []uint64{3, 2}) {
t.Fatal("expected shape [3, 2], got", tt.Shape)
}
var b bytes.Buffer
if _, err := tt.WriteTo(&b); err != nil {
t.Fatal(err)
}
f32s := make([]float32, mul(tt.Shape))
if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
t.Fatal(err)
}
if !slices.Equal(f32s, []float32{0, 1, 4, 5, 8, 9}) {
t.Fatal("expected data [0, 1, 4, 5, 8, 9], got", f32s)
}
}
{
tt, ok := next()
if !ok {
t.Fatal("expected at least one split")
}
if tt.Name != "a.y" {
t.Fatal("expected name 'a.y', got", tt.Name)
}
if !slices.Equal(tt.Shape, []uint64{3, 2}) {
t.Fatal("expected shape [3, 2], got", tt.Shape)
}
var b bytes.Buffer
if _, err := tt.WriteTo(&b); err != nil {
t.Fatal(err)
}
f32s := make([]float32, mul(tt.Shape))
if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
t.Fatal(err)
}
if !slices.Equal(f32s, []float32{2, 3, 6, 7, 10, 11}) {
t.Fatal("expected data [2, 3, 6, 7, 10, 11], got", f32s)
}
}
})
t.Run("uneven split", func(t *testing.T) {
next, stop := iter.Pull(splitDim(&r, 0,
split{Replacer: strings.NewReplacer("a", "x"), dim: 2},
split{Replacer: strings.NewReplacer("b", "y"), dim: 1},
))
defer stop()
{
tt, ok := next()
if !ok {
t.Fatal("expected at least one split")
}
if tt.Name != "x.b" {
t.Fatal("expected name 'x.b', got", tt.Name)
}
if !slices.Equal(tt.Shape, []uint64{2, 4}) {
t.Fatal("expected shape [2, 4], got", tt.Shape)
}
var b bytes.Buffer
if _, err := tt.WriteTo(&b); err != nil {
t.Fatal(err)
}
f32s := make([]float32, mul(tt.Shape))
if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
t.Fatal(err)
}
if !slices.Equal(f32s, []float32{0, 1, 2, 3, 4, 5, 6, 7}) {
t.Fatal("expected data [0, 1, 2, 3, 4, 5, 6, 7], got", f32s)
}
}
{
tt, ok := next()
if !ok {
t.Fatal("expected at least one split")
}
if tt.Name != "a.y" {
t.Fatal("expected name 'a.y', got", tt.Name)
}
if !slices.Equal(tt.Shape, []uint64{1, 4}) {
t.Fatal("expected shape [1, 4], got", tt.Shape)
}
var b bytes.Buffer
if _, err := tt.WriteTo(&b); err != nil {
t.Fatal(err)
}
f32s := make([]float32, mul(tt.Shape))
if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
t.Fatal(err)
}
if !slices.Equal(f32s, []float32{8, 9, 10, 11}) {
t.Fatal("expected data [8, 9, 10, 11], got", f32s)
}
}
})
t.Run("split with transpose", func(t *testing.T) {
next, stop := iter.Pull(splitDim(&r, 1,
split{Replacer: strings.NewReplacer("a", "x")},
split{Replacer: strings.NewReplacer("b", "y"), fn: func(tt tensor.Tensor) (tensor.Tensor, error) {
return tensor.Transpose(tt, 1, 0)
}},
))
defer stop()
{
tt, ok := next()
if !ok {
t.Fatal("expected at least one split")
}
if tt.Name != "x.b" {
t.Fatal("expected name 'x.b', got", tt.Name)
}
if !slices.Equal(tt.Shape, []uint64{3, 2}) {
t.Fatal("expected shape [3, 2], got", tt.Shape)
}
var b bytes.Buffer
if _, err := tt.WriteTo(&b); err != nil {
t.Fatal(err)
}
f32s := make([]float32, mul(tt.Shape))
if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
t.Fatal(err)
}
if !slices.Equal(f32s, []float32{0, 1, 4, 5, 8, 9}) {
t.Fatal("expected data [0, 1, 4, 5, 8, 9], got", f32s)
}
}
{
tt, ok := next()
if !ok {
t.Fatal("expected at least one split")
}
if tt.Name != "a.y" {
t.Fatal("expected name 'a.y', got", tt.Name)
}
if !slices.Equal(tt.Shape, []uint64{3, 2}) {
t.Fatal("expected shape [3, 2], got", tt.Shape)
}
var b bytes.Buffer
if _, err := tt.WriteTo(&b); err != nil {
t.Fatal(err)
}
f32s := make([]float32, mul(tt.Shape))
if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
t.Fatal(err)
}
if !slices.Equal(f32s, []float32{2, 6, 10, 3, 7, 11}) {
t.Fatal("expected data [2, 6, 10, 3, 7, 11], got", f32s)
}
}
})
}
func TestMerge(t *testing.T) {
unmatched := []Tensor{
&fakeTensor{
name: "a.0.b",
shape: []uint64{5, 2},
data: []float32{10, 11, 12, 13, 14, 15, 16, 17, 18, 19},
},
&fakeTensor{
name: "a.1.b",
shape: []uint64{5, 2},
data: []float32{20, 21, 22, 23, 24, 25, 26, 27, 28, 29},
},
&fakeTensor{
name: "c.0.d",
shape: []uint64{5, 2},
data: []float32{30, 31, 32, 33, 34, 35, 36, 37, 38, 39},
},
&fakeTensor{
name: "c.1.d",
shape: []uint64{5, 2},
data: []float32{40, 41, 42, 43, 44, 45, 46, 47, 48, 49},
},
&fakeTensor{
name: "e.0.f",
shape: []uint64{5, 2},
data: []float32{50, 51, 52, 53, 54, 55, 56, 57, 58, 59},
},
}
checkMatched := func(t *testing.T, n int, matched []*ggml.Tensor) {
for i := range n {
got := matched[i]
if diff := cmp.Diff([]uint64{2, 5, 2}, got.Shape); diff != "" {
t.Errorf("unexpected (-want +got):\n%s", diff)
}
var b bytes.Buffer
if _, err := got.WriteTo(&b); err != nil {
t.Fatal(err)
}
f32s := make([]float32, 20)
if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
t.Fatal(err)
}
offset := 10 + (i * 20)
want := make([]float32, 20)
for j := range 20 {
want[j] = float32(offset + j)
}
if diff := cmp.Diff(want, f32s); diff != "" {
t.Errorf("unexpected data (-want +got):\n%s", diff)
}
}
}
t.Run("single merge", func(t *testing.T) {
matched, unmatched := mergeTensors(unmatched, merge{"a.*.b", "a.b"})
if len(unmatched) != 3 {
t.Error("expected 3 remaining tensors, got", len(unmatched))
}
if len(matched) != 1 {
t.Error("expected 1 merged tensor, got", len(matched))
}
checkMatched(t, 1, matched)
})
t.Run("multiple merges", func(t *testing.T) {
matched, unmatched := mergeTensors(unmatched, merge{"a.*.b", "a.b"}, merge{"c.*.d", "c.d"})
if len(unmatched) != 1 {
t.Error("expected 1 remaining tensors, got", len(unmatched))
}
if len(matched) != 2 {
t.Error("expected 2 merged tensor, got", len(matched))
}
checkMatched(t, 2, matched)
})
t.Run("no match", func(t *testing.T) {
matched, unmatched := mergeTensors(unmatched, merge{"x.*.y", "x.y"})
if len(unmatched) != 5 {
t.Error("expected 5 remaining tensors, got", len(unmatched))
}
if len(matched) != 0 {
t.Error("expected no merged tensors, got", len(matched))
}
})
}

View File

@@ -8,10 +8,11 @@ import (
"fmt" "fmt"
"io/fs" "io/fs"
"log/slog" "log/slog"
"maps"
"os" "os"
"slices" "slices"
"strings" "strings"
"golang.org/x/exp/maps"
) )
const ( const (
@@ -109,7 +110,6 @@ func parseTokenizer(fsys fs.FS, specialTokenTypes []string) (*Tokenizer, error)
} }
if f, err := fsys.Open("tokenizer_config.json"); errors.Is(err, os.ErrNotExist) { if f, err := fsys.Open("tokenizer_config.json"); errors.Is(err, os.ErrNotExist) {
// noop
} else if err != nil { } else if err != nil {
return nil, err return nil, err
} else { } else {
@@ -171,34 +171,6 @@ func parseTokenizer(fsys fs.FS, specialTokenTypes []string) (*Tokenizer, error)
} }
} }
if f, err := fsys.Open("generation_config.json"); errors.Is(err, os.ErrNotExist) {
} else if err != nil {
return nil, err
} else {
defer f.Close()
var p map[string]json.RawMessage
if err := json.NewDecoder(f).Decode(&p); err != nil {
return nil, err
}
for _, st := range specialTokenTypes {
if bts, ok := p[fmt.Sprintf("%s_token_id", st)]; ok {
var ids []int32
if err := json.Unmarshal(bts, &ids); err != nil {
// value is not a list so the existing ID is used
continue
}
if i := slices.IndexFunc(t.SpecialVocabulary, func(sv *SpecialVocabulary) bool {
return sv.Type == st
}); i >= 0 {
t.SpecialVocabulary[i].IDs = ids
}
}
}
}
return t, nil return t, nil
} }
@@ -259,8 +231,11 @@ func parseVocabularyFromTokenizer(fsys fs.FS) (*Vocabulary, error) {
tokens[token.ID] = token tokens[token.ID] = token
} }
keys := maps.Keys(tokens)
slices.Sort(keys)
v := Vocabulary{Model: "gpt2"} v := Vocabulary{Model: "gpt2"}
for _, k := range slices.Sorted(maps.Keys(tokens)) { for _, k := range keys {
token := tokens[k] token := tokens[k]
v.Tokens = append(v.Tokens, token.Content) v.Tokens = append(v.Tokens, token.Content)
v.Scores = append(v.Scores, float32(token.ID)) v.Scores = append(v.Scores, float32(token.ID))
@@ -305,9 +280,6 @@ type SpecialVocabulary struct {
ID int ID int
Content string Content string
AddToken bool AddToken bool
// IDs is populated by generation_config.json
IDs []int32
} }
func (sv SpecialVocabulary) Key() string { func (sv SpecialVocabulary) Key() string {

View File

@@ -247,67 +247,6 @@ func TestParseTokenizer(t *testing.T) {
Pre: "default", Pre: "default",
}, },
}, },
{
name: "generation config eos token ids",
fsys: createTokenizerFS(t, t.TempDir(), map[string]io.Reader{
"tokenizer.json": strings.NewReader(`{
"added_tokens": [
{
"id": 0,
"content": "<bos>",
"special": true
},
{
"id": 1,
"content": "<eos>",
"special": true
},
{
"id": 2,
"content": "<eot>",
"special": true
},
{
"id": 3,
"content": "<eom>",
"special": true
}
],
"model": {
"vocab": {
"<bos>": 0,
"<eos>": 1,
"<eot>": 2,
"<eom>": 3
}
}
}`),
"tokenizer_config.json": strings.NewReader(`{
"add_bos_token": true,
"add_eos_token": false,
"bos_token": "<bos>",
"eos_token": "<eos>"
}`),
"generation_config.json": strings.NewReader(`{
"bos_token_id": 0,
"eos_token_id": [1, 2, 3]
}`),
}),
specialTokenTypes: []string{"pad", "eos", "bos", "unk"},
want: &Tokenizer{
Vocabulary: &Vocabulary{
Model: "gpt2",
Tokens: []string{"<bos>", "<eos>", "<eot>", "<eom>"},
Scores: []float32{0, 1, 2, 3},
Types: []int32{3, 3, 3, 3},
},
SpecialVocabulary: []*SpecialVocabulary{
{Type: "eos", Content: "<eos>", ID: 1, IDs: []int32{1, 2, 3}, AddToken: false},
{Type: "bos", Content: "<bos>", ID: 0, AddToken: true},
},
Pre: "default",
},
},
} }
for _, tt := range cases { for _, tt := range cases {

View File

@@ -58,7 +58,7 @@ func AMDGetGPUInfo() ([]RocmGPUInfo, error) {
driverMajor, driverMinor, err := AMDDriverVersion() driverMajor, driverMinor, err := AMDDriverVersion()
if err != nil { if err != nil {
// TODO - if we see users crash and burn with the upstreamed kernel this can be adjusted to hard-fail rocm support and fallback to CPU // TODO - if we see users crash and burn with the upstreamed kernel this can be adjusted to hard-fail rocm support and fallback to CPU
slog.Warn("ollama recommends running the https://www.amd.com/en/support/download/linux-drivers.html", "error", err) slog.Warn("ollama recommends running the https://www.amd.com/en/support/linux-drivers", "error", err)
} }
// Determine if the user has already pre-selected which GPUs to look at, then ignore the others // Determine if the user has already pre-selected which GPUs to look at, then ignore the others

View File

@@ -56,7 +56,6 @@ func cudaVariant(gpuInfo CudaGPUInfo) string {
} }
} }
} }
return "sbsa"
} }
// driver 12.0 has problems with the cuda v12 library, so run v11 on those older drivers // driver 12.0 has problems with the cuda v12 library, so run v11 on those older drivers

View File

@@ -670,7 +670,7 @@ func loadOneapiMgmt(oneapiLibPaths []string) (int, *C.oneapi_handle_t, string, e
} }
func getVerboseState() C.uint16_t { func getVerboseState() C.uint16_t {
if envconfig.LogLevel() < slog.LevelInfo { if envconfig.Debug() {
return C.uint16_t(1) return C.uint16_t(1)
} }
return C.uint16_t(0) return C.uint16_t(0)

View File

@@ -4,7 +4,6 @@
* [Quickstart](../README.md#quickstart) * [Quickstart](../README.md#quickstart)
* [Examples](./examples.md) * [Examples](./examples.md)
* [Importing models](./import.md) * [Importing models](./import.md)
* [MacOS Documentation](./macos.md)
* [Linux Documentation](./linux.md) * [Linux Documentation](./linux.md)
* [Windows Documentation](./windows.md) * [Windows Documentation](./windows.md)
* [Docker Documentation](./docker.md) * [Docker Documentation](./docker.md)

View File

@@ -19,7 +19,7 @@
### Model names ### Model names
Model names follow a `model:tag` format, where `model` can have an optional namespace such as `example/model`. Some examples are `orca-mini:3b-q8_0` and `llama3:70b`. The tag is optional and, if not provided, will default to `latest`. The tag is used to identify a specific version. Model names follow a `model:tag` format, where `model` can have an optional namespace such as `example/model`. Some examples are `orca-mini:3b-q4_1` and `llama3:70b`. The tag is optional and, if not provided, will default to `latest`. The tag is used to identify a specific version.
### Durations ### Durations
@@ -43,7 +43,6 @@ Generate a response for a given prompt with a provided model. This is a streamin
- `prompt`: the prompt to generate a response for - `prompt`: the prompt to generate a response for
- `suffix`: the text after the model response - `suffix`: the text after the model response
- `images`: (optional) a list of base64-encoded images (for multimodal models such as `llava`) - `images`: (optional) a list of base64-encoded images (for multimodal models such as `llava`)
- `think`: (for thinking models) should the model think before responding?
Advanced parameters (optional): Advanced parameters (optional):
@@ -491,39 +490,28 @@ Generate the next message in a chat with a provided model. This is a streaming e
- `model`: (required) the [model name](#model-names) - `model`: (required) the [model name](#model-names)
- `messages`: the messages of the chat, this can be used to keep a chat memory - `messages`: the messages of the chat, this can be used to keep a chat memory
- `tools`: list of tools in JSON for the model to use if supported - `tools`: list of tools in JSON for the model to use if supported
- `think`: (for thinking models) should the model think before responding?
The `message` object has the following fields: The `message` object has the following fields:
- `role`: the role of the message, either `system`, `user`, `assistant`, or `tool` - `role`: the role of the message, either `system`, `user`, `assistant`, or `tool`
- `content`: the content of the message - `content`: the content of the message
- `thinking`: (for thinking models) the model's thinking process
- `images` (optional): a list of images to include in the message (for multimodal models such as `llava`) - `images` (optional): a list of images to include in the message (for multimodal models such as `llava`)
- `tool_calls` (optional): a list of tools in JSON that the model wants to use - `tool_calls` (optional): a list of tools in JSON that the model wants to use
- `tool_name` (optional): add the name of the tool that was executed to inform the model of the result
Advanced parameters (optional): Advanced parameters (optional):
- `format`: the format to return a response in. Format can be `json` or a JSON schema. - `format`: the format to return a response in. Format can be `json` or a JSON schema.
- `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.md#valid-parameters-and-values) such as `temperature` - `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.md#valid-parameters-and-values) such as `temperature`
- `stream`: if `false` the response will be returned as a single response object, rather than a stream of objects - `stream`: if `false` the response will be returned as a single response object, rather than a stream of objects
- `keep_alive`: controls how long the model will stay loaded into memory following the request (default: `5m`) - `keep_alive`: controls how long the model will stay loaded into memory following the request (default: `5m`)
### Tool calling
Tool calling is supported by providing a list of tools in the `tools` parameter. The model will generate a response that includes a list of tool calls. See the [Chat request (Streaming with tools)](#chat-request-streaming-with-tools) example below.
Models can also explain the result of the tool call in the response. See the [Chat request (With history, with tools)](#chat-request-with-history-with-tools) example below.
[See models with tool calling capabilities](https://ollama.com/search?c=tool).
### Structured outputs ### Structured outputs
Structured outputs are supported by providing a JSON schema in the `format` parameter. The model will generate a response that matches the schema. See the [Chat request (Structured outputs)](#chat-request-structured-outputs) example below. Structured outputs are supported by providing a JSON schema in the `format` parameter. The model will generate a response that matches the schema. See the [Chat request (Structured outputs)](#chat-request-structured-outputs) example below.
### Examples ### Examples
#### Chat request (Streaming) #### Chat Request (Streaming)
##### Request ##### Request
@@ -578,88 +566,6 @@ Final response:
} }
``` ```
#### Chat request (Streaming with tools)
##### Request
```shell
curl http://localhost:11434/api/chat -d '{
"model": "llama3.2",
"messages": [
{
"role": "user",
"content": "what is the weather in tokyo?"
}
],
"tools": [
{
"type": "function",
"function": {
"name": "get_weather",
"description": "Get the weather in a given city",
"parameters": {
"type": "object",
"properties": {
"city": {
"type": "string",
"description": "The city to get the weather for"
}
},
"required": ["city"]
}
}
}
],
"stream": true
}'
```
##### Response
A stream of JSON objects is returned:
```json
{
"model": "llama3.2",
"created_at": "2025-07-07T20:22:19.184789Z",
"message": {
"role": "assistant",
"content": "",
"tool_calls": [
{
"function": {
"name": "get_weather",
"arguments": {
"city": "Tokyo"
}
},
}
]
},
"done": false
}
```
Final response:
```json
{
"model":"llama3.2",
"created_at":"2025-07-07T20:22:19.19314Z",
"message": {
"role": "assistant",
"content": ""
},
"done_reason": "stop",
"done": true,
"total_duration": 182242375,
"load_duration": 41295167,
"prompt_eval_count": 169,
"prompt_eval_duration": 24573166,
"eval_count": 15,
"eval_duration": 115959084
}
```
#### Chat request (No streaming) #### Chat request (No streaming)
##### Request ##### Request
@@ -697,74 +603,6 @@ curl http://localhost:11434/api/chat -d '{
} }
``` ```
#### Chat request (No streaming, with tools)
##### Request
```shell
curl http://localhost:11434/api/chat -d '{
"model": "llama3.2",
"messages": [
{
"role": "user",
"content": "what is the weather in tokyo?"
}
],
"tools": [
{
"type": "function",
"function": {
"name": "get_weather",
"description": "Get the weather in a given city",
"parameters": {
"type": "object",
"properties": {
"city": {
"type": "string",
"description": "The city to get the weather for"
}
},
"required": ["city"]
}
}
}
],
"stream": false
}'
```
##### Response
```json
{
"model": "llama3.2",
"created_at": "2025-07-07T20:32:53.844124Z",
"message": {
"role": "assistant",
"content": "",
"tool_calls": [
{
"function": {
"name": "get_weather",
"arguments": {
"city": "Tokyo"
}
},
}
]
},
"done_reason": "stop",
"done": true,
"total_duration": 3244883583,
"load_duration": 2969184542,
"prompt_eval_count": 169,
"prompt_eval_duration": 141656333,
"eval_count": 18,
"eval_duration": 133293625
}
```
#### Chat request (Structured outputs) #### Chat request (Structured outputs)
##### Request ##### Request
@@ -871,87 +709,6 @@ Final response:
} }
``` ```
#### Chat request (With history, with tools)
##### Request
```shell
curl http://localhost:11434/api/chat -d '{
"model": "llama3.2",
"messages": [
{
"role": "user",
"content": "what is the weather in Toronto?"
},
// the message from the model appended to history
{
"role": "assistant",
"content": "",
"tool_calls": [
{
"function": {
"name": "get_temperature",
"arguments": {
"city": "Toronto"
}
},
}
]
},
// the tool call result appended to history
{
"role": "tool",
"content": "11 degrees celsius",
"tool_name": "get_temperature",
}
],
"stream": false,
"tools": [
{
"type": "function",
"function": {
"name": "get_weather",
"description": "Get the weather in a given city",
"parameters": {
"type": "object",
"properties": {
"city": {
"type": "string",
"description": "The city to get the weather for"
}
},
"required": ["city"]
}
}
}
]
}'
```
##### Response
```json
{
"model": "llama3.2",
"created_at": "2025-07-07T20:43:37.688511Z",
"message": {
"role": "assistant",
"content": "The current temperature in Toronto is 11°C."
},
"done_reason": "stop",
"done": true,
"total_duration": 890771750,
"load_duration": 707634750,
"prompt_eval_count": 94,
"prompt_eval_duration": 91703208,
"eval_count": 11,
"eval_duration": 90282125
}
```
#### Chat request (with images) #### Chat request (with images)
##### Request ##### Request
@@ -1195,8 +952,19 @@ If you are creating a model from a safetensors directory or from a GGUF file, yo
| Type | Recommended | | Type | Recommended |
| --- | :-: | | --- | :-: |
| q2_K | |
| q3_K_L | |
| q3_K_M | |
| q3_K_S | |
| q4_0 | |
| q4_1 | |
| q4_K_M | * | | q4_K_M | * |
| q4_K_S | | | q4_K_S | |
| q5_0 | |
| q5_1 | |
| q5_K_M | |
| q5_K_S | |
| q6_K | |
| q8_0 | * | | q8_0 | * |
### Examples ### Examples
@@ -1241,8 +1009,8 @@ Quantize a non-quantized model.
```shell ```shell
curl http://localhost:11434/api/create -d '{ curl http://localhost:11434/api/create -d '{
"model": "llama3.2:quantized", "model": "llama3.1:quantized",
"from": "llama3.2:3b-instruct-fp16", "from": "llama3.1:8b-instruct-fp16",
"quantize": "q4_K_M" "quantize": "q4_K_M"
}' }'
``` ```
@@ -1252,14 +1020,12 @@ curl http://localhost:11434/api/create -d '{
A stream of JSON objects is returned: A stream of JSON objects is returned:
```json ```json
{"status":"quantizing F16 model to Q4_K_M","digest":"0","total":6433687776,"completed":12302} {"status":"quantizing F16 model to Q4_K_M"}
{"status":"quantizing F16 model to Q4_K_M","digest":"0","total":6433687776,"completed":6433687552} {"status":"creating new layer sha256:667b0c1932bc6ffc593ed1d03f895bf2dc8dc6df21db3042284a6f4416b06a29"}
{"status":"verifying conversion"} {"status":"using existing layer sha256:11ce4ee3e170f6adebac9a991c22e22ab3f8530e154ee669954c4bc73061c258"}
{"status":"creating new layer sha256:fb7f4f211b89c6c4928ff4ddb73db9f9c0cfca3e000c3e40d6cf27ddc6ca72eb"} {"status":"using existing layer sha256:0ba8f0e314b4264dfd19df045cde9d4c394a52474bf92ed6a3de22a4ca31a177"}
{"status":"using existing layer sha256:966de95ca8a62200913e3f8bfbf84c8494536f1b94b49166851e76644e966396"}
{"status":"using existing layer sha256:fcc5a6bec9daf9b561a68827b67ab6088e1dba9d1fa2a50d7bbcc8384e0a265d"}
{"status":"using existing layer sha256:a70ff7e570d97baaf4e62ac6e6ad9975e04caa6d900d3742d37698494479e0cd"}
{"status":"using existing layer sha256:56bb8bd477a519ffa694fc449c2413c6f0e1d3b1c88fa7e3c9d88d3ae49d4dcb"} {"status":"using existing layer sha256:56bb8bd477a519ffa694fc449c2413c6f0e1d3b1c88fa7e3c9d88d3ae49d4dcb"}
{"status":"creating new layer sha256:455f34728c9b5dd3376378bfb809ee166c145b0b4c1f1a6feca069055066ef9a"}
{"status":"writing manifest"} {"status":"writing manifest"}
{"status":"success"} {"status":"success"}
``` ```
@@ -1397,37 +1163,29 @@ A single JSON object will be returned.
{ {
"models": [ "models": [
{ {
"name": "deepseek-r1:latest", "name": "codellama:13b",
"model": "deepseek-r1:latest", "modified_at": "2023-11-04T14:56:49.277302595-07:00",
"modified_at": "2025-05-10T08:06:48.639712648-07:00", "size": 7365960935,
"size": 4683075271, "digest": "9f438cb9cd581fc025612d27f7c1a6669ff83a8bb0ed86c94fcf4c5440555697",
"digest": "0a8c266910232fd3291e71e5ba1e058cc5af9d411192cf88b6d30e92b6e73163",
"details": { "details": {
"parent_model": "",
"format": "gguf", "format": "gguf",
"family": "qwen2", "family": "llama",
"families": [ "families": null,
"qwen2" "parameter_size": "13B",
], "quantization_level": "Q4_0"
"parameter_size": "7.6B",
"quantization_level": "Q4_K_M"
} }
}, },
{ {
"name": "llama3.2:latest", "name": "llama3:latest",
"model": "llama3.2:latest", "modified_at": "2023-12-07T09:32:18.757212583-08:00",
"modified_at": "2025-05-04T17:37:44.706015396-07:00", "size": 3825819519,
"size": 2019393189, "digest": "fe938a131f40e6f6d40083c9f0f430a515233eb2edaa6d72eb85c50d64f2300e",
"digest": "a80c4f17acd55265feec403c7aef86be0c25983ab279d83f3bcd3abbcb5b8b72",
"details": { "details": {
"parent_model": "",
"format": "gguf", "format": "gguf",
"family": "llama", "family": "llama",
"families": [ "families": null,
"llama" "parameter_size": "7B",
], "quantization_level": "Q4_0"
"parameter_size": "3.2B",
"quantization_level": "Q4_K_M"
} }
} }
] ]

59
docs/benchmark.md Normal file
View File

@@ -0,0 +1,59 @@
# Benchmark
Go benchmark tests that measure end-to-end performance of a running Ollama server. Run these tests to evaluate model inference performance on your hardware and measure the impact of code changes.
## When to use
Run these benchmarks when:
- Making changes to the model inference engine
- Modifying model loading/unloading logic
- Changing prompt processing or token generation code
- Implementing a new model architecture
- Testing performance across different hardware setups
## Prerequisites
- Ollama server running locally with `ollama serve` on `127.0.0.1:11434`
## Usage and Examples
>[!NOTE]
>All commands must be run from the root directory of the Ollama project.
Basic syntax:
```bash
go test -bench=. ./benchmark/... -m $MODEL_NAME
```
Required flags:
- `-bench=.`: Run all benchmarks
- `-m`: Model name to benchmark
Optional flags:
- `-count N`: Number of times to run the benchmark (useful for statistical analysis)
- `-timeout T`: Maximum time for the benchmark to run (e.g. "10m" for 10 minutes)
Common usage patterns:
Single benchmark run with a model specified:
```bash
go test -bench=. ./benchmark/... -m llama3.3
```
## Output metrics
The benchmark reports several key metrics:
- `gen_tok/s`: Generated tokens per second
- `prompt_tok/s`: Prompt processing tokens per second
- `ttft_ms`: Time to first token in milliseconds
- `load_ms`: Model load time in milliseconds
- `gen_tokens`: Total tokens generated
- `prompt_tokens`: Total prompt tokens processed
Each benchmark runs two scenarios:
- Cold start: Model is loaded from disk for each test
- Warm start: Model is pre-loaded in memory
Three prompt lengths are tested for each scenario:
- Short prompt (100 tokens)
- Medium prompt (500 tokens)
- Long prompt (1000 tokens)

View File

@@ -118,7 +118,7 @@ To run tests, use `go test`:
go test ./... go test ./...
``` ```
> NOTE: In rare circumstances, you may need to change a package using the new > NOTE: In rare cirumstances, you may nedd to change a package using the new
> "synctest" package in go1.24. > "synctest" package in go1.24.
> >
> If you do not have the "synctest" package enabled, you will not see build or > If you do not have the "synctest" package enabled, you will not see build or

View File

@@ -292,7 +292,7 @@ If too many requests are sent to the server, it will respond with a 503 error in
## How does Ollama handle concurrent requests? ## How does Ollama handle concurrent requests?
Ollama supports two levels of concurrent processing. If your system has sufficient available memory (system memory when using CPU inference, or VRAM for GPU inference) then multiple models can be loaded at the same time. For a given model, if there is sufficient available memory when the model is loaded, it can be configured to allow parallel request processing. Ollama supports two levels of concurrent processing. If your system has sufficient available memory (system memory when using CPU inference, or VRAM for GPU inference) then multiple models can be loaded at the same time. For a given model, if there is sufficient available memory when the model is loaded, it is configured to allow parallel request processing.
If there is insufficient available memory to load a new model request while one or more models are already loaded, all new requests will be queued until the new model can be loaded. As prior models become idle, one or more will be unloaded to make room for the new model. Queued requests will be processed in order. When using GPU inference new models must be able to completely fit in VRAM to allow concurrent model loads. If there is insufficient available memory to load a new model request while one or more models are already loaded, all new requests will be queued until the new model can be loaded. As prior models become idle, one or more will be unloaded to make room for the new model. Queued requests will be processed in order. When using GPU inference new models must be able to completely fit in VRAM to allow concurrent model loads.
@@ -301,7 +301,7 @@ Parallel request processing for a given model results in increasing the context
The following server settings may be used to adjust how Ollama handles concurrent requests on most platforms: The following server settings may be used to adjust how Ollama handles concurrent requests on most platforms:
- `OLLAMA_MAX_LOADED_MODELS` - The maximum number of models that can be loaded concurrently provided they fit in available memory. The default is 3 * the number of GPUs or 3 for CPU inference. - `OLLAMA_MAX_LOADED_MODELS` - The maximum number of models that can be loaded concurrently provided they fit in available memory. The default is 3 * the number of GPUs or 3 for CPU inference.
- `OLLAMA_NUM_PARALLEL` - The maximum number of parallel requests each model will process at the same time. The default is 1, and will handle 1 request per model at a time. - `OLLAMA_NUM_PARALLEL` - The maximum number of parallel requests each model will process at the same time. The default will auto-select either 4 or 1 based on available memory.
- `OLLAMA_MAX_QUEUE` - The maximum number of requests Ollama will queue when busy before rejecting additional requests. The default is 512 - `OLLAMA_MAX_QUEUE` - The maximum number of requests Ollama will queue when busy before rejecting additional requests. The default is 512
Note: Windows with Radeon GPUs currently default to 1 model maximum due to limitations in ROCm v5.7 for available VRAM reporting. Once ROCm v6.2 is available, Windows Radeon will follow the defaults above. You may enable concurrent model loads on Radeon on Windows, but ensure you don't load more models than will fit into your GPUs VRAM. Note: Windows with Radeon GPUs currently default to 1 model maximum due to limitations in ROCm v5.7 for available VRAM reporting. Once ROCm v6.2 is available, Windows Radeon will follow the defaults above. You may enable concurrent model loads on Radeon on Windows, but ensure you don't load more models than will fit into your GPUs VRAM.
@@ -333,16 +333,3 @@ The currently available K/V cache quantization types are:
How much the cache quantization impacts the model's response quality will depend on the model and the task. Models that have a high GQA count (e.g. Qwen2) may see a larger impact on precision from quantization than models with a low GQA count. How much the cache quantization impacts the model's response quality will depend on the model and the task. Models that have a high GQA count (e.g. Qwen2) may see a larger impact on precision from quantization than models with a low GQA count.
You may need to experiment with different quantization types to find the best balance between memory usage and quality. You may need to experiment with different quantization types to find the best balance between memory usage and quality.
## How can I stop Ollama from starting when I login to my computer
Ollama for Windows and macOS register as a login item during installation. You can disable this if you prefer not to have Ollama automatically start. Ollama will respect this setting across upgrades, unless you uninstall the application.
**Windows**
- Remove `%APPDATA%\Microsoft\Windows\Start Menu\Programs\Startup\Ollama.lnk`
**MacOS Monterey (v12)**
- Open `Settings` -> `Users & Groups` -> `Login Items` and find the `Ollama` entry, then click the `-` (minus) to remove
**MacOS Ventura (v13) and later**
- Open `Settings` and search for "Login Items", find the `Ollama` entry under "Allow in the Background`, then click the slider to disable.

View File

@@ -7,8 +7,6 @@ Check your compute compatibility to see if your card is supported:
| Compute Capability | Family | Cards | | Compute Capability | Family | Cards |
| ------------------ | ------------------- | ----------------------------------------------------------------------------------------------------------- | | ------------------ | ------------------- | ----------------------------------------------------------------------------------------------------------- |
| 12.0 | GeForce RTX 50xx | `RTX 5060` `RTX 5060 Ti` `RTX 5070` `RTX 5070 Ti` `RTX 5080` `RTX 5090` |
| | NVIDIA Professioal | `RTX PRO 4000 Blackwell` `RTX PRO 4500 Blackwell` `RTX PRO 5000 Blackwell` `RTX PRO 6000 Blackwell` |
| 9.0 | NVIDIA | `H200` `H100` | | 9.0 | NVIDIA | `H200` `H100` |
| 8.9 | GeForce RTX 40xx | `RTX 4090` `RTX 4080 SUPER` `RTX 4080` `RTX 4070 Ti SUPER` `RTX 4070 Ti` `RTX 4070 SUPER` `RTX 4070` `RTX 4060 Ti` `RTX 4060` | | 8.9 | GeForce RTX 40xx | `RTX 4090` `RTX 4080 SUPER` `RTX 4080` `RTX 4070 Ti SUPER` `RTX 4070 Ti` `RTX 4070 SUPER` `RTX 4070` `RTX 4060 Ti` `RTX 4060` |
| | NVIDIA Professional | `L4` `L40` `RTX 6000` | | | NVIDIA Professional | `L4` `L40` `RTX 6000` |

View File

@@ -53,8 +53,6 @@ FROM /path/to/safetensors/directory
If you create the Modelfile in the same directory as the weights, you can use the command `FROM .`. If you create the Modelfile in the same directory as the weights, you can use the command `FROM .`.
If you do not create the Modelfile, ollama will act as if there was a Modelfile with the command `FROM .`.
Now run the `ollama create` command from the directory where you created the `Modelfile`: Now run the `ollama create` command from the directory where you created the `Modelfile`:
```shell ```shell
@@ -134,12 +132,22 @@ success
### Supported Quantizations ### Supported Quantizations
- `q4_0`
- `q4_1`
- `q5_0`
- `q5_1`
- `q8_0` - `q8_0`
#### K-means Quantizations #### K-means Quantizations
- `q3_K_S`
- `q3_K_M`
- `q3_K_L`
- `q4_K_S` - `q4_K_S`
- `q4_K_M` - `q4_K_M`
- `q5_K_S`
- `q5_K_M`
- `q6_K`
## Sharing your model on ollama.com ## Sharing your model on ollama.com

View File

@@ -16,7 +16,7 @@ curl -fsSL https://ollama.com/install.sh | sh
Download and extract the package: Download and extract the package:
```shell ```shell
curl -LO https://ollama.com/download/ollama-linux-amd64.tgz curl -L https://ollama.com/download/ollama-linux-amd64.tgz -o ollama-linux-amd64.tgz
sudo tar -C /usr -xzf ollama-linux-amd64.tgz sudo tar -C /usr -xzf ollama-linux-amd64.tgz
``` ```
@@ -112,8 +112,8 @@ sudo systemctl status ollama
> While AMD has contributed the `amdgpu` driver upstream to the official linux > While AMD has contributed the `amdgpu` driver upstream to the official linux
> kernel source, the version is older and may not support all ROCm features. We > kernel source, the version is older and may not support all ROCm features. We
> recommend you install the latest driver from > recommend you install the latest driver from
> [AMD](https://www.amd.com/en/support/download/linux-drivers.html) for best support > https://www.amd.com/en/support/linux-drivers for best support of your Radeon
> of your Radeon GPU. > GPU.
## Customizing ## Customizing

View File

@@ -1,42 +0,0 @@
# Ollama for macOS
## System Requirements
* MacOS Monterey (v12) or newer
* Apple M series (CPU and GPU support) or x86 (CPU only)
## Filesystem Requirements
The preferred method of installation is to mount the `ollama.dmg` and drag-and-drop the Ollama application to the system-wide `Applications` folder. Upon startup, the Ollama app will verify the `ollama` CLI is present in your PATH, and if not detected, will prompt for permission to create a link in `/usr/local/bin`
Once you've installed Ollama, you'll need additional space for storing the Large Language models, which can be tens to hundreds of GB in size. If your home directory doesn't have enough space, you can change where the binaries are installed, and where the models are stored.
### Changing Install Location
To install the Ollama application somewhere other than `Applications`, place the Ollama application in the desired location, and ensure the CLI `Ollama.app/Contents/Resources/ollama` or a sym-link to the CLI can be found in your path. Upon first start decline the "Move to Applications?" request.
## Troubleshooting
Ollama on MacOS stores files in a few different locations.
- `~/.ollama` contains models and configuration
- `~/.ollama/logs` contains logs
- *app.log* contains most recent logs from the GUI application
- *server.log* contains the most recent server logs
- `<install location>/Ollama.app/Contents/Resources/ollama` the CLI binary
## Uninstall
To fully remove Ollama from your system, remove the following files and folders:
```
sudo rm -rf /Applications/Ollama.app
sudo rm /usr/local/bin/ollama
rm -rf "~/Library/Application Support/Ollama"
rm -rf "~/Library/Saved Application State/com.electron.ollama.savedState"
rm -rf ~/Library/Caches/com.electron.ollama/
rm -rf ~/Library/Caches/ollama
rm -rf ~/Library/WebKit/com.electron.ollama
rm -rf ~/.ollama
```

View File

@@ -150,7 +150,7 @@ PARAMETER <parameter> <parametervalue>
| Parameter | Description | Value Type | Example Usage | | Parameter | Description | Value Type | Example Usage |
| -------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ---------- | -------------------- | | -------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ---------- | -------------------- |
| num_ctx | Sets the size of the context window used to generate the next token. (Default: 4096) | int | num_ctx 4096 | | num_ctx | Sets the size of the context window used to generate the next token. (Default: 2048) | int | num_ctx 4096 |
| repeat_last_n | Sets how far back for the model to look back to prevent repetition. (Default: 64, 0 = disabled, -1 = num_ctx) | int | repeat_last_n 64 | | repeat_last_n | Sets how far back for the model to look back to prevent repetition. (Default: 64, 0 = disabled, -1 = num_ctx) | int | repeat_last_n 64 |
| repeat_penalty | Sets how strongly to penalize repetitions. A higher value (e.g., 1.5) will penalize repetitions more strongly, while a lower value (e.g., 0.9) will be more lenient. (Default: 1.1) | float | repeat_penalty 1.1 | | repeat_penalty | Sets how strongly to penalize repetitions. A higher value (e.g., 1.5) will penalize repetitions more strongly, while a lower value (e.g., 0.9) will be more lenient. (Default: 1.1) | float | repeat_penalty 1.1 |
| temperature | The temperature of the model. Increasing the temperature will make the model answer more creatively. (Default: 0.8) | float | temperature 0.7 | | temperature | The temperature of the model. Increasing the temperature will make the model answer more creatively. (Default: 0.8) | float | temperature 0.7 |

View File

@@ -72,7 +72,7 @@ client = OpenAI(base_url="http://localhost:11434/v1", api_key="ollama")
# Define the schema for the response # Define the schema for the response
class FriendInfo(BaseModel): class FriendInfo(BaseModel):
name: str name: str
age: int age: int
is_available: bool is_available: bool
class FriendList(BaseModel): class FriendList(BaseModel):

View File

@@ -9,7 +9,7 @@ cat ~/.ollama/logs/server.log
On **Linux** systems with systemd, the logs can be found with this command: On **Linux** systems with systemd, the logs can be found with this command:
```shell ```shell
journalctl -u ollama --no-pager --follow --pager-end journalctl -u ollama --no-pager --follow --pager-end
``` ```
When you run Ollama in a **container**, the logs go to stdout/stderr in the container: When you run Ollama in a **container**, the logs go to stdout/stderr in the container:
@@ -23,7 +23,7 @@ docker logs <container-name>
If manually running `ollama serve` in a terminal, the logs will be on that terminal. If manually running `ollama serve` in a terminal, the logs will be on that terminal.
When you run Ollama on **Windows**, there are a few different locations. You can view them in the explorer window by hitting `<cmd>+R` and type in: When you run Ollama on **Windows**, there are a few different locations. You can view them in the explorer window by hitting `<cmd>+R` and type in:
- `explorer %LOCALAPPDATA%\Ollama` to view logs. The most recent server logs will be in `server.log` and older logs will be in `server-#.log` - `explorer %LOCALAPPDATA%\Ollama` to view logs. The most recent server logs will be in `server.log` and older logs will be in `server-#.log`
- `explorer %LOCALAPPDATA%\Programs\Ollama` to browse the binaries (The installer adds this to your user PATH) - `explorer %LOCALAPPDATA%\Programs\Ollama` to browse the binaries (The installer adds this to your user PATH)
- `explorer %HOMEPATH%\.ollama` to browse where models and configuration is stored - `explorer %HOMEPATH%\.ollama` to browse where models and configuration is stored
@@ -38,7 +38,7 @@ Join the [Discord](https://discord.gg/ollama) for help interpreting the logs.
## LLM libraries ## LLM libraries
Ollama includes multiple LLM libraries compiled for different GPUs and CPU vector features. Ollama tries to pick the best one based on the capabilities of your system. If this autodetection has problems, or you run into other problems (e.g. crashes in your GPU) you can workaround this by forcing a specific LLM library. `cpu_avx2` will perform the best, followed by `cpu_avx` and the slowest but most compatible is `cpu`. Rosetta emulation under MacOS will work with the `cpu` library. Ollama includes multiple LLM libraries compiled for different GPUs and CPU vector features. Ollama tries to pick the best one based on the capabilities of your system. If this autodetection has problems, or you run into other problems (e.g. crashes in your GPU) you can workaround this by forcing a specific LLM library. `cpu_avx2` will perform the best, followed by `cpu_avx` an the slowest but most compatible is `cpu`. Rosetta emulation under MacOS will work with the `cpu` library.
In the server log, you will see a message that looks something like this (varies from release to release): In the server log, you will see a message that looks something like this (varies from release to release):
@@ -97,7 +97,7 @@ If none of those resolve the problem, gather additional information and file an
On linux, AMD GPU access typically requires `video` and/or `render` group membership to access the `/dev/kfd` device. If permissions are not set up correctly, Ollama will detect this and report an error in the server log. On linux, AMD GPU access typically requires `video` and/or `render` group membership to access the `/dev/kfd` device. If permissions are not set up correctly, Ollama will detect this and report an error in the server log.
When running in a container, in some Linux distributions and container runtimes, the ollama process may be unable to access the GPU. Use `ls -lnd /dev/kfd /dev/dri /dev/dri/*` on the host system to determine the **numeric** group IDs on your system, and pass additional `--group-add ...` arguments to the container so it can access the required devices. For example, in the following output `crw-rw---- 1 0 44 226, 0 Sep 16 16:55 /dev/dri/card0` the group ID column is `44` When running in a container, in some Linux distributions and container runtimes, the ollama process may be unable to access the GPU. Use `ls -lnd /dev/kfd /dev/dri /dev/dri/*` on the host system to determine the **numeric** group IDs on your system, and pass additional `--group-add ...` arguments to the container so it can access the required devices. For example, in the following output `crw-rw---- 1 0 44 226, 0 Sep 16 16:55 /dev/dri/card0` the group ID column is `44`
If you are experiencing problems getting Ollama to correctly discover or use your GPU for inference, the following may help isolate the failure. If you are experiencing problems getting Ollama to correctly discover or use your GPU for inference, the following may help isolate the failure.
- `AMD_LOG_LEVEL=3` Enable info log levels in the AMD HIP/ROCm libraries. This can help show more detailed error codes that can help troubleshoot problems - `AMD_LOG_LEVEL=3` Enable info log levels in the AMD HIP/ROCm libraries. This can help show more detailed error codes that can help troubleshoot problems

View File

@@ -30,6 +30,20 @@ To install the Ollama application in a location different than your home directo
OllamaSetup.exe /DIR="d:\some\location" OllamaSetup.exe /DIR="d:\some\location"
``` ```
### Changing Model Location
To change where Ollama stores the downloaded models instead of using your home directory, set the environment variable `OLLAMA_MODELS` in your user account.
1. Start the Settings (Windows 11) or Control Panel (Windows 10) application and search for _environment variables_.
2. Click on _Edit environment variables for your account_.
3. Edit or create a new variable for your user account for `OLLAMA_MODELS` where you want the models stored
4. Click OK/Apply to save.
If Ollama is already running, Quit the tray application and relaunch it from the Start menu, or a new terminal started after you saved the environment variables.
## API Access ## API Access
Here's a quick example showing API access from `powershell` Here's a quick example showing API access from `powershell`

View File

@@ -149,22 +149,9 @@ func Bool(k string) func() bool {
} }
} }
// LogLevel returns the log level for the application.
// Values are 0 or false INFO (Default), 1 or true DEBUG, 2 TRACE
func LogLevel() slog.Level {
level := slog.LevelInfo
if s := Var("OLLAMA_DEBUG"); s != "" {
if b, _ := strconv.ParseBool(s); b {
level = slog.LevelDebug
} else if i, _ := strconv.ParseInt(s, 10, 64); i != 0 {
level = slog.Level(i * -4)
}
}
return level
}
var ( var (
// Debug enabled additional debug information.
Debug = Bool("OLLAMA_DEBUG")
// FlashAttention enables the experimental flash attention feature. // FlashAttention enables the experimental flash attention feature.
FlashAttention = Bool("OLLAMA_FLASH_ATTENTION") FlashAttention = Bool("OLLAMA_FLASH_ATTENTION")
// KvCacheType is the quantization type for the K/V cache. // KvCacheType is the quantization type for the K/V cache.
@@ -183,8 +170,6 @@ var (
NewEngine = Bool("OLLAMA_NEW_ENGINE") NewEngine = Bool("OLLAMA_NEW_ENGINE")
// ContextLength sets the default context length // ContextLength sets the default context length
ContextLength = Uint("OLLAMA_CONTEXT_LENGTH", 4096) ContextLength = Uint("OLLAMA_CONTEXT_LENGTH", 4096)
// Auth enables authentication between the Ollama client and server
UseAuth = Bool("OLLAMA_AUTH")
) )
func String(s string) func() string { func String(s string) func() string {
@@ -219,11 +204,13 @@ func Uint(key string, defaultValue uint) func() uint {
var ( var (
// NumParallel sets the number of parallel model requests. NumParallel can be configured via the OLLAMA_NUM_PARALLEL environment variable. // NumParallel sets the number of parallel model requests. NumParallel can be configured via the OLLAMA_NUM_PARALLEL environment variable.
NumParallel = Uint("OLLAMA_NUM_PARALLEL", 1) NumParallel = Uint("OLLAMA_NUM_PARALLEL", 0)
// MaxRunners sets the maximum number of loaded models. MaxRunners can be configured via the OLLAMA_MAX_LOADED_MODELS environment variable. // MaxRunners sets the maximum number of loaded models. MaxRunners can be configured via the OLLAMA_MAX_LOADED_MODELS environment variable.
MaxRunners = Uint("OLLAMA_MAX_LOADED_MODELS", 0) MaxRunners = Uint("OLLAMA_MAX_LOADED_MODELS", 0)
// MaxQueue sets the maximum number of queued requests. MaxQueue can be configured via the OLLAMA_MAX_QUEUE environment variable. // MaxQueue sets the maximum number of queued requests. MaxQueue can be configured via the OLLAMA_MAX_QUEUE environment variable.
MaxQueue = Uint("OLLAMA_MAX_QUEUE", 512) MaxQueue = Uint("OLLAMA_MAX_QUEUE", 512)
// MaxVRAM sets a maximum VRAM override in bytes. MaxVRAM can be configured via the OLLAMA_MAX_VRAM environment variable.
MaxVRAM = Uint("OLLAMA_MAX_VRAM", 0)
) )
func Uint64(key string, defaultValue uint64) func() uint64 { func Uint64(key string, defaultValue uint64) func() uint64 {
@@ -251,7 +238,7 @@ type EnvVar struct {
func AsMap() map[string]EnvVar { func AsMap() map[string]EnvVar {
ret := map[string]EnvVar{ ret := map[string]EnvVar{
"OLLAMA_DEBUG": {"OLLAMA_DEBUG", LogLevel(), "Show additional debug information (e.g. OLLAMA_DEBUG=1)"}, "OLLAMA_DEBUG": {"OLLAMA_DEBUG", Debug(), "Show additional debug information (e.g. OLLAMA_DEBUG=1)"},
"OLLAMA_FLASH_ATTENTION": {"OLLAMA_FLASH_ATTENTION", FlashAttention(), "Enabled flash attention"}, "OLLAMA_FLASH_ATTENTION": {"OLLAMA_FLASH_ATTENTION", FlashAttention(), "Enabled flash attention"},
"OLLAMA_KV_CACHE_TYPE": {"OLLAMA_KV_CACHE_TYPE", KvCacheType(), "Quantization type for the K/V cache (default: f16)"}, "OLLAMA_KV_CACHE_TYPE": {"OLLAMA_KV_CACHE_TYPE", KvCacheType(), "Quantization type for the K/V cache (default: f16)"},
"OLLAMA_GPU_OVERHEAD": {"OLLAMA_GPU_OVERHEAD", GpuOverhead(), "Reserve a portion of VRAM per GPU (bytes)"}, "OLLAMA_GPU_OVERHEAD": {"OLLAMA_GPU_OVERHEAD", GpuOverhead(), "Reserve a portion of VRAM per GPU (bytes)"},

View File

@@ -1,13 +1,11 @@
package envconfig package envconfig
import ( import (
"log/slog"
"math" "math"
"testing" "testing"
"time" "time"
"github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp"
"github.com/ollama/ollama/logutil"
) )
func TestHost(t *testing.T) { func TestHost(t *testing.T) {
@@ -294,34 +292,3 @@ func TestContextLength(t *testing.T) {
}) })
} }
} }
func TestLogLevel(t *testing.T) {
cases := map[string]slog.Level{
// Default to INFO
"": slog.LevelInfo,
"false": slog.LevelInfo,
"f": slog.LevelInfo,
"0": slog.LevelInfo,
// True values enable Debug
"true": slog.LevelDebug,
"t": slog.LevelDebug,
// Positive values increase verbosity
"1": slog.LevelDebug,
"2": logutil.LevelTrace,
// Negative values decrease verbosity
"-1": slog.LevelWarn,
"-2": slog.LevelError,
}
for k, v := range cases {
t.Run(k, func(t *testing.T) {
t.Setenv("OLLAMA_DEBUG", k)
if i := LogLevel(); i != v {
t.Errorf("%s: expected %d, got %d", k, v, i)
}
})
}
}

View File

@@ -10,5 +10,4 @@ type Config interface {
Strings(string, ...[]string) []string Strings(string, ...[]string) []string
Ints(string, ...[]int32) []int32 Ints(string, ...[]int32) []int32
Floats(string, ...[]float32) []float32 Floats(string, ...[]float32) []float32
Bools(string, ...[]bool) []bool
} }

View File

@@ -15,7 +15,6 @@ import (
type GGML struct { type GGML struct {
container container
model model
Length int64
} }
type model interface { type model interface {
@@ -34,8 +33,7 @@ func (kv KV) Kind() string {
} }
func (kv KV) ParameterCount() uint64 { func (kv KV) ParameterCount() uint64 {
val, _ := keyValue(kv, "general.parameter_count", uint64(0)) return keyValue(kv, "general.parameter_count", uint64(0))
return val
} }
func (kv KV) FileType() FileType { func (kv KV) FileType() FileType {
@@ -54,27 +52,16 @@ func (kv KV) EmbeddingLength() uint64 {
return uint64(kv.Uint("embedding_length")) return uint64(kv.Uint("embedding_length"))
} }
func (kv KV) HeadCountMax() uint64 { func (kv KV) HeadCount() uint64 {
// TODO(drifkin): using the max value can cause an overestimation. In the return uint64(kv.Uint("attention.head_count"))
// future if array values become more popular, we can adapt the more invasive
// <https://github.com/ollama/ollama/pull/10225>
return uint64(kv.UintOrMaxArrayValue("attention.head_count", 1))
} }
func (kv KV) HeadCountMin() uint64 { func (kv KV) HeadCountKV() uint64 {
return uint64(kv.UintOrMinArrayValue("attention.head_count", 1)) return uint64(kv.Uint("attention.head_count_kv", 1))
} }
func (kv KV) HeadCountKVMax() uint64 { func (kv KV) EmbeddingHeadCount() uint64 {
return uint64(kv.UintOrMaxArrayValue("attention.head_count_kv", 1)) if heads := kv.HeadCount(); heads > 0 {
}
func (kv KV) HeadCountKVMin() uint64 {
return uint64(kv.UintOrMinArrayValue("attention.head_count_kv", 1))
}
func (kv KV) EmbeddingHeadCountMax() uint64 {
if heads := kv.HeadCountMin(); heads > 0 {
return kv.EmbeddingLength() / heads return kv.EmbeddingLength() / heads
} }
@@ -82,11 +69,15 @@ func (kv KV) EmbeddingHeadCountMax() uint64 {
} }
func (kv KV) EmbeddingHeadCountK() uint64 { func (kv KV) EmbeddingHeadCountK() uint64 {
return uint64(kv.Uint("attention.key_length", uint32(kv.EmbeddingHeadCountMax()))) return uint64(kv.Uint("attention.key_length", uint32(kv.EmbeddingHeadCount())))
} }
func (kv KV) EmbeddingHeadCountV() uint64 { func (kv KV) EmbeddingHeadCountV() uint64 {
return uint64(kv.Uint("attention.value_length", uint32(kv.EmbeddingHeadCountMax()))) return uint64(kv.Uint("attention.value_length", uint32(kv.EmbeddingHeadCount())))
}
func (kv KV) GQA() uint64 {
return kv.HeadCount() / kv.HeadCountKV()
} }
func (kv KV) ContextLength() uint64 { func (kv KV) ContextLength() uint64 {
@@ -98,87 +89,42 @@ func (kv KV) ChatTemplate() string {
} }
func (kv KV) String(key string, defaultValue ...string) string { func (kv KV) String(key string, defaultValue ...string) string {
val, _ := keyValue(kv, key, append(defaultValue, "")...) return keyValue(kv, key, append(defaultValue, "")...)
return val
} }
func (kv KV) Uint(key string, defaultValue ...uint32) uint32 { func (kv KV) Uint(key string, defaultValue ...uint32) uint32 {
val, _ := keyValue(kv, key, append(defaultValue, 0)...) return keyValue(kv, key, append(defaultValue, 0)...)
return val
} }
func (kv KV) Float(key string, defaultValue ...float32) float32 { func (kv KV) Float(key string, defaultValue ...float32) float32 {
val, _ := keyValue(kv, key, append(defaultValue, 0)...) return keyValue(kv, key, append(defaultValue, 0)...)
return val
} }
func (kv KV) Bool(key string, defaultValue ...bool) bool { func (kv KV) Bool(key string, defaultValue ...bool) bool {
val, _ := keyValue(kv, key, append(defaultValue, false)...) return keyValue(kv, key, append(defaultValue, false)...)
return val
}
func (kv KV) UintOrMaxArrayValue(key string, defaultValue uint32) uint32 {
_, max := kv.UintOrArrayValue(key, defaultValue)
return max
}
func (kv KV) UintOrMinArrayValue(key string, defaultValue uint32) uint32 {
min, _ := kv.UintOrArrayValue(key, defaultValue)
return min
}
func (kv KV) UintOrArrayValue(key string, defaultValue uint32) (uint32, uint32) {
if u32, ok := keyValue(kv, key, uint32(0)); ok {
return u32, u32
} else if u32s, ok := keyValue(kv, key, &array[uint32]{}); ok {
min := slices.Min(u32s.values)
max := slices.Max(u32s.values)
return min, max
} else if i32s, ok := keyValue(kv, key, &array[int32]{}); ok {
min := slices.Min(i32s.values)
max := slices.Max(i32s.values)
if min < 0 || max < 0 {
slog.Warn("array values are unexpectedly negative", "key", key, "min", min, "max", max)
}
return uint32(min), uint32(max)
}
return defaultValue, defaultValue
} }
func (kv KV) Strings(key string, defaultValue ...[]string) []string { func (kv KV) Strings(key string, defaultValue ...[]string) []string {
val, _ := keyValue(kv, key, &array[string]{values: append(defaultValue, []string(nil))[0]}) return keyValue(kv, key, &array[string]{values: append(defaultValue, []string(nil))[0]}).values
return val.values
} }
func (kv KV) Ints(key string, defaultValue ...[]int32) []int32 { func (kv KV) Ints(key string, defaultValue ...[]int32) []int32 {
val, _ := keyValue(kv, key, &array[int32]{values: append(defaultValue, []int32(nil))[0]}) return keyValue(kv, key, &array[int32]{values: append(defaultValue, []int32(nil))[0]}).values
return val.values
} }
func (kv KV) Uints(key string, defaultValue ...[]uint32) []uint32 { func (kv KV) Uints(key string, defaultValue ...[]uint32) []uint32 {
val, _ := keyValue(kv, key, &array[uint32]{values: append(defaultValue, []uint32(nil))[0]}) return keyValue(kv, key, &array[uint32]{values: append(defaultValue, []uint32(nil))[0]}).values
return val.values
} }
func (kv KV) Floats(key string, defaultValue ...[]float32) []float32 { func (kv KV) Floats(key string, defaultValue ...[]float32) []float32 {
val, _ := keyValue(kv, key, &array[float32]{values: append(defaultValue, []float32(nil))[0]}) return keyValue(kv, key, &array[float32]{values: append(defaultValue, []float32(nil))[0]}).values
return val.values
}
func (kv KV) Bools(key string, defaultValue ...[]bool) []bool {
val, _ := keyValue(kv, key, &array[bool]{values: append(defaultValue, []bool(nil))[0]})
return val.values
} }
func (kv KV) OllamaEngineRequired() bool { func (kv KV) OllamaEngineRequired() bool {
return slices.Contains([]string{ return slices.Contains([]string{
"gemma3", "gemma3",
"gemma3n",
"mistral3", "mistral3",
"llama4", "llama4",
"mllama",
"qwen25vl",
}, kv.Architecture()) }, kv.Architecture())
} }
@@ -194,17 +140,17 @@ type arrayValueTypes interface {
*array[string] | *array[float32] | *array[float64] | *array[bool] *array[string] | *array[float32] | *array[float64] | *array[bool]
} }
func keyValue[T valueTypes | arrayValueTypes](kv KV, key string, defaultValue ...T) (T, bool) { func keyValue[T valueTypes | arrayValueTypes](kv KV, key string, defaultValue ...T) T {
if !strings.HasPrefix(key, "tokenizer.") && !strings.HasPrefix(key, "general.") { if !strings.HasPrefix(key, "tokenizer.") && !strings.HasPrefix(key, "general.") {
key = kv.Architecture() + "." + key key = kv.Architecture() + "." + key
} }
if val, ok := kv[key].(T); ok { if val, ok := kv[key]; ok {
return val, true return val.(T)
} }
slog.Debug("key with type not found", "key", key, "default", defaultValue[0]) slog.Debug("key not found", "key", key, "default", defaultValue[0])
return defaultValue[0], false return defaultValue[0]
} }
type Tensors struct { type Tensors struct {
@@ -438,12 +384,12 @@ func DetectContentType(b []byte) string {
// //
// It collects array values for arrays with a size less than or equal to // It collects array values for arrays with a size less than or equal to
// maxArraySize. If the maxArraySize is negative, all arrays are collected. // maxArraySize. If the maxArraySize is negative, all arrays are collected.
func Decode(rs io.ReadSeeker, maxArraySize int) (*GGML, error) { func Decode(rs io.ReadSeeker, maxArraySize int) (*GGML, int64, error) {
rs = bufioutil.NewBufferedSeeker(rs, 32<<10) rs = bufioutil.NewBufferedSeeker(rs, 32<<10)
var magic uint32 var magic uint32
if err := binary.Read(rs, binary.LittleEndian, &magic); err != nil { if err := binary.Read(rs, binary.LittleEndian, &magic); err != nil {
return nil, err return nil, 0, err
} }
var c container var c container
@@ -453,34 +399,33 @@ func Decode(rs io.ReadSeeker, maxArraySize int) (*GGML, error) {
case FILE_MAGIC_GGUF_BE: case FILE_MAGIC_GGUF_BE:
c = &containerGGUF{ByteOrder: binary.BigEndian, maxArraySize: maxArraySize} c = &containerGGUF{ByteOrder: binary.BigEndian, maxArraySize: maxArraySize}
default: default:
return nil, errors.New("invalid file magic") return nil, 0, errors.New("invalid file magic")
} }
model, err := c.Decode(rs) model, err := c.Decode(rs)
if err != nil { if err != nil {
return nil, err return nil, 0, err
} }
offset, err := rs.Seek(0, io.SeekCurrent) offset, err := rs.Seek(0, io.SeekCurrent)
if err != nil { if err != nil {
return nil, err return nil, 0, err
} }
// final model type // final model type
return &GGML{ return &GGML{
container: c, container: c,
model: model, model: model,
Length: offset, }, offset, nil
}, nil
} }
func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType string) (kv []uint64, partialOffload, fullOffload uint64) { func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType string) (kv []uint64, partialOffload, fullOffload uint64) {
embedding := f.KV().EmbeddingLength() embedding := f.KV().EmbeddingLength()
heads := f.KV().HeadCountMax() heads := f.KV().HeadCount()
headsKV := f.KV().HeadCountKVMax() headsKV := f.KV().HeadCountKV()
vocab := uint64(f.KV()["tokenizer.ggml.tokens"].(*array[string]).size) vocab := uint64(f.KV()["tokenizer.ggml.tokens"].(*array[string]).size)
embeddingHeads := f.KV().EmbeddingHeadCountMax() embeddingHeads := f.KV().EmbeddingHeadCount()
embeddingHeadsK := f.KV().EmbeddingHeadCountK() embeddingHeadsK := f.KV().EmbeddingHeadCountK()
embeddingHeadsV := f.KV().EmbeddingHeadCountV() embeddingHeadsV := f.KV().EmbeddingHeadCountV()
@@ -555,7 +500,7 @@ func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType stri
// vocab graph // vocab graph
4*batch*(embedding+vocab)+embedding*vocab*105/128, 4*batch*(embedding+vocab)+embedding*vocab*105/128,
) )
case "gemma", "gemma2", "gemma3", "gemma3n": case "gemma", "gemma2", "gemma3":
fullOffload = max( fullOffload = max(
4*batch*(embedding+vocab), 4*batch*(embedding+vocab),
4*batch*(2+context+context*heads+2*embedding+2*embeddingHeadsK*heads), 4*batch*(2+context+context*heads+2*embedding+2*embeddingHeadsK*heads),
@@ -568,11 +513,6 @@ func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType stri
embedding*embeddingHeadsK*heads*9/16, embedding*embeddingHeadsK*heads*9/16,
) )
if f.KV().Architecture() == "gemma3n" {
fullOffload *= 4
partialOffload *= 4
}
// Gemma2 also has sliding window attention but we only have an optimized implementation in the Ollama // Gemma2 also has sliding window attention but we only have an optimized implementation in the Ollama
// engine. Gemma3 always uses the Ollama engine. // engine. Gemma3 always uses the Ollama engine.
if f.KV().Architecture() == "gemma3" { if f.KV().Architecture() == "gemma3" {
@@ -708,20 +648,6 @@ func (llm GGML) VisionGraphSize() (weights, graphSize uint64) {
graphSize = 4 * (imageSize*imageSize*numChannels + graphSize = 4 * (imageSize*imageSize*numChannels +
embeddingLength*patchSize + embeddingLength*patchSize +
numPatches*numPatches*headCount) numPatches*numPatches*headCount)
case "qwen25vl":
maxPixels := uint64(llm.KV().Uint("vision.max_pixels", 28*28*1280))
numPatches := maxPixels / (patchSize * patchSize)
graphSize = 4 * (maxPixels*numChannels + // Original image storage
// Normalized pixels
maxPixels*numChannels +
// Patches storage (numPatches * channels * patchSize^2)
numPatches*numChannels*patchSize*patchSize +
// Self-attention calculations
numPatches*numPatches*headCount +
// Additional buffer for processing
embeddingLength*numPatches)
case "llama4": case "llama4":
// vision graph is computed independently in the same schedule // vision graph is computed independently in the same schedule
// and is negligible compared to the worst case text graph // and is negligible compared to the worst case text graph

View File

@@ -269,33 +269,3 @@ func TestKeyValue(t *testing.T) {
t.Errorf("unexpected uint8s (-got +want):\n%s", diff) t.Errorf("unexpected uint8s (-got +want):\n%s", diff)
} }
} }
func TestHeadCount(t *testing.T) {
valuesArray := []int32{1, 5, 3, 4}
cases := []struct {
kv KV
want uint64
}{
{
kv: KV{
"general.architecture": "abc",
"abc.attention.head_count": &array[int32]{values: valuesArray, size: len(valuesArray)},
},
want: uint64(5),
},
{
kv: KV{
"general.architecture": "abc",
"abc.attention.head_count": uint32(3),
},
want: uint64(3),
},
}
for _, tt := range cases {
got := tt.kv.HeadCountMax()
if got != tt.want {
t.Errorf("unexpected max value: got=%d want=%d", got, tt.want)
}
}
}

View File

@@ -527,17 +527,23 @@ func WriteGGUF(f *os.File, kv KV, ts []*Tensor) error {
return err return err
} }
for _, key := range slices.Sorted(maps.Keys(kv)) { keys := slices.Collect(maps.Keys(kv))
slices.Sort(keys)
for _, key := range keys {
if err := ggufWriteKV(f, key, kv[key]); err != nil { if err := ggufWriteKV(f, key, kv[key]); err != nil {
return err return err
} }
} }
slices.SortStableFunc(ts, func(a, b *Tensor) int { slices.SortStableFunc(ts, func(a, b *Tensor) int {
if i, j := a.block(), b.block(); i > 0 && j > 0 { if i, j := a.block(), b.block(); i < 0 && j > 0 {
return 1
} else if i > 0 && j < 0 {
return -1
} else {
return cmp.Compare(i, j) return cmp.Compare(i, j)
} }
return cmp.Compare(a.Name, b.Name)
}) })
var s uint64 var s uint64
@@ -609,10 +615,6 @@ func ggufWriteKV(ws io.WriteSeeker, k string, v any) error {
err = writeGGUFArray(ws, ggufTypeString, v) err = writeGGUFArray(ws, ggufTypeString, v)
case *array[string]: case *array[string]:
err = writeGGUFArray(ws, ggufTypeString, v.values) err = writeGGUFArray(ws, ggufTypeString, v.values)
case []bool:
err = writeGGUFArray(ws, ggufTypeBool, v)
case *array[bool]:
err = writeGGUFArray(ws, ggufTypeBool, v.values)
default: default:
return fmt.Errorf("improper type for '%s'", k) return fmt.Errorf("improper type for '%s'", k)
} }

View File

@@ -2,82 +2,62 @@ package ggml
import ( import (
"bytes" "bytes"
"math/rand/v2"
"os" "os"
"strings" "slices"
"testing" "testing"
"github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp"
) )
func TestWriteGGUF(t *testing.T) { func TestWriteGGUF(t *testing.T) {
r := rand.New(rand.NewPCG(0, 0)) w, err := os.CreateTemp(t.TempDir(), "*.bin")
for range 8 { if err != nil {
t.Run("shuffle", func(t *testing.T) { t.Fatal(err)
t.Parallel() }
defer w.Close()
ts := []*Tensor{ if err := WriteGGUF(w, KV{
{Name: "token_embd.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(make([]byte, 2*3))}, "general.alignment": uint32(16),
{Name: "blk.0.attn_norm.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(make([]byte, 2*3))}, }, []*Tensor{
{Name: "blk.1.attn_norm.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(make([]byte, 2*3))}, {Name: "test.0", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(slices.Repeat([]byte{0}, 2*3*4))},
{Name: "blk.2.attn_norm.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(make([]byte, 2*3))}, {Name: "test.1", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(slices.Repeat([]byte{0}, 2*3*4))},
{Name: "blk.3.attn_norm.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(make([]byte, 2*3))}, {Name: "test.2", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(slices.Repeat([]byte{0}, 2*3*4))},
{Name: "blk.4.attn_norm.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(make([]byte, 2*3))}, {Name: "test.3", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(slices.Repeat([]byte{0}, 2*3*4))},
{Name: "blk.5.attn_norm.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(make([]byte, 2*3))}, {Name: "test.4", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(slices.Repeat([]byte{0}, 2*3*4))},
{Name: "output_norm.weight", Shape: []uint64{3, 2}, WriterTo: bytes.NewBuffer(make([]byte, 3*2))}, {Name: "test.5", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(slices.Repeat([]byte{0}, 2*3*4))},
{Name: "output.weight", Shape: []uint64{3, 2}, WriterTo: bytes.NewBuffer(make([]byte, 3*2))}, }); err != nil {
} t.Fatal(err)
}
r.Shuffle(len(ts), func(i, j int) { r, err := os.Open(w.Name())
ts[i], ts[j] = ts[j], ts[i] if err != nil {
}) t.Fatal(err)
}
defer r.Close()
w, err := os.CreateTemp(t.TempDir(), strings.ReplaceAll(t.Name(), "/", "_")+"*.bin") ff, _, err := Decode(r, 0)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
defer w.Close()
if err := WriteGGUF(w, KV{ if diff := cmp.Diff(ff.KV(), KV{
"general.alignment": uint32(16), "general.alignment": uint32(16),
}, ts); err != nil { "general.parameter_count": uint64(36),
t.Fatal(err) }); diff != "" {
} t.Errorf("Mismatch (-want +got):\n%s", diff)
}
r, err := os.Open(w.Name()) if diff := cmp.Diff(ff.Tensors(), Tensors{
if err != nil { Offset: 336,
t.Fatal(err) items: []*Tensor{
} {Name: "test.0", Offset: 0, Shape: []uint64{2, 3}},
defer r.Close() {Name: "test.1", Offset: 32, Shape: []uint64{2, 3}},
{Name: "test.2", Offset: 64, Shape: []uint64{2, 3}},
ff, err := Decode(r, 0) {Name: "test.3", Offset: 96, Shape: []uint64{2, 3}},
if err != nil { {Name: "test.4", Offset: 128, Shape: []uint64{2, 3}},
t.Fatal(err) {Name: "test.5", Offset: 160, Shape: []uint64{2, 3}},
} },
}, cmp.AllowUnexported(Tensors{})); diff != "" {
if diff := cmp.Diff(KV{ t.Errorf("Mismatch (-want +got):\n%s", diff)
"general.alignment": uint32(16),
"general.parameter_count": uint64(54),
}, ff.KV()); diff != "" {
t.Errorf("Mismatch (-want +got):\n%s", diff)
}
if diff := cmp.Diff(Tensors{
Offset: 608,
items: []*Tensor{
{Name: "blk.0.attn_norm.weight", Offset: 0, Shape: []uint64{2, 3}},
{Name: "blk.1.attn_norm.weight", Offset: 32, Shape: []uint64{2, 3}},
{Name: "blk.2.attn_norm.weight", Offset: 64, Shape: []uint64{2, 3}},
{Name: "blk.3.attn_norm.weight", Offset: 96, Shape: []uint64{2, 3}},
{Name: "blk.4.attn_norm.weight", Offset: 128, Shape: []uint64{2, 3}},
{Name: "blk.5.attn_norm.weight", Offset: 160, Shape: []uint64{2, 3}},
{Name: "output.weight", Offset: 192, Shape: []uint64{3, 2}},
{Name: "output_norm.weight", Offset: 224, Shape: []uint64{3, 2}},
{Name: "token_embd.weight", Offset: 256, Shape: []uint64{2, 3}},
},
}, ff.Tensors(), cmp.AllowUnexported(Tensors{})); diff != "" {
t.Errorf("Mismatch (-want +got):\n%s", diff)
}
})
} }
} }

View File

@@ -12,42 +12,42 @@ type FileType uint32
const ( const (
FileTypeF32 FileType = iota FileTypeF32 FileType = iota
FileTypeF16 FileTypeF16
fileTypeQ4_0 FileTypeQ4_0
fileTypeQ4_1 FileTypeQ4_1
fileTypeQ4_1_F16 // unused by GGML fileTypeQ4_1_F16 // unused by GGML
fileTypeQ4_2 // unused by GGML fileTypeQ4_2 // unused by GGML
fileTypeQ4_3 // unused by GGML fileTypeQ4_3 // unused by GGML
FileTypeQ8_0 FileTypeQ8_0
fileTypeQ5_0 FileTypeQ5_0
fileTypeQ5_1 FileTypeQ5_1
fileTypeQ2_K FileTypeQ2_K
fileTypeQ3_K_S FileTypeQ3_K_S
fileTypeQ3_K_M FileTypeQ3_K_M
fileTypeQ3_K_L FileTypeQ3_K_L
FileTypeQ4_K_S FileTypeQ4_K_S
FileTypeQ4_K_M FileTypeQ4_K_M
fileTypeQ5_K_S FileTypeQ5_K_S
fileTypeQ5_K_M FileTypeQ5_K_M
fileTypeQ6_K FileTypeQ6_K
fileTypeIQ2_XXS fileTypeIQ2_XXS // not supported by ollama
fileTypeIQ2_XS fileTypeIQ2_XS // not supported by ollama
fileTypeQ2_K_S FileTypeQ2_K_S
fileTypeIQ3_XS fileTypeIQ3_XS // not supported by ollama
fileTypeIQ3_XXS fileTypeIQ3_XXS // not supported by ollama
fileTypeIQ1_S fileTypeIQ1_S // not supported by ollama
fileTypeIQ4_NL fileTypeIQ4_NL // not supported by ollama
fileTypeIQ3_S fileTypeIQ3_S // not supported by ollama
fileTypeIQ3_M fileTypeIQ3_M // not supported by ollama
fileTypeIQ2_S fileTypeIQ2_S // not supported by ollama
fileTypeIQ2_M fileTypeIQ2_M // not supported by ollama
fileTypeIQ4_XS fileTypeIQ4_XS // not supported by ollama
fileTypeIQ1_M fileTypeIQ1_M // not supported by ollama
FileTypeBF16 FileTypeBF16
fileTypeQ4_0_4_4 // unused by GGML fileTypeQ4_0_4_4 // unused by GGML
fileTypeQ4_0_4_8 // unused by GGML fileTypeQ4_0_4_8 // unused by GGML
fileTypeQ4_0_8_8 // unused by GGML fileTypeQ4_0_8_8 // unused by GGML
fileTypeTQ1_0 fileTypeTQ1_0 // not supported by ollama
fileTypeTQ2_0 fileTypeTQ2_0 // not supported by ollama
FileTypeUnknown = 1024 FileTypeUnknown = 1024
) )
@@ -60,12 +60,36 @@ func ParseFileType(s string) (FileType, error) {
return FileTypeF32, nil return FileTypeF32, nil
case "F16": case "F16":
return FileTypeF16, nil return FileTypeF16, nil
case "Q4_0":
return FileTypeQ4_0, nil
case "Q4_1":
return FileTypeQ4_1, nil
case "Q8_0": case "Q8_0":
return FileTypeQ8_0, nil return FileTypeQ8_0, nil
case "Q5_0":
return FileTypeQ5_0, nil
case "Q5_1":
return FileTypeQ5_1, nil
case "Q2_K":
return FileTypeQ2_K, nil
case "Q3_K_S":
return FileTypeQ3_K_S, nil
case "Q3_K_M":
return FileTypeQ3_K_M, nil
case "Q3_K_L":
return FileTypeQ3_K_L, nil
case "Q4_K_S": case "Q4_K_S":
return FileTypeQ4_K_S, nil return FileTypeQ4_K_S, nil
case "Q4_K_M", "Q4_K": case "Q4_K_M", "Q4_K":
return FileTypeQ4_K_M, nil return FileTypeQ4_K_M, nil
case "Q5_K_S":
return FileTypeQ5_K_S, nil
case "Q5_K_M", "Q5_K":
return FileTypeQ5_K_M, nil
case "Q6_K":
return FileTypeQ6_K, nil
case "Q2_K_S":
return FileTypeQ2_K_S, nil
case "BF16": case "BF16":
return FileTypeBF16, nil return FileTypeBF16, nil
default: default:
@@ -87,41 +111,40 @@ func ParseFileType(s string) (FileType, error) {
} }
func (t FileType) String() string { func (t FileType) String() string {
// Note: this routine will return a broader set of file types for existing models
switch t { switch t {
case FileTypeF32: case FileTypeF32:
return "F32" return "F32"
case FileTypeF16: case FileTypeF16:
return "F16" return "F16"
case fileTypeQ4_0: case FileTypeQ4_0:
return "Q4_0" return "Q4_0"
case fileTypeQ4_1: case FileTypeQ4_1:
return "Q4_1" return "Q4_1"
case FileTypeQ8_0: case FileTypeQ8_0:
return "Q8_0" return "Q8_0"
case fileTypeQ5_0: case FileTypeQ5_0:
return "Q5_0" return "Q5_0"
case fileTypeQ5_1: case FileTypeQ5_1:
return "Q5_1" return "Q5_1"
case fileTypeQ2_K: case FileTypeQ2_K:
return "Q2_K" return "Q2_K"
case fileTypeQ3_K_S: case FileTypeQ3_K_S:
return "Q3_K_S" return "Q3_K_S"
case fileTypeQ3_K_M: case FileTypeQ3_K_M:
return "Q3_K_M" return "Q3_K_M"
case fileTypeQ3_K_L: case FileTypeQ3_K_L:
return "Q3_K_L" return "Q3_K_L"
case FileTypeQ4_K_S: case FileTypeQ4_K_S:
return "Q4_K_S" return "Q4_K_S"
case FileTypeQ4_K_M: case FileTypeQ4_K_M:
return "Q4_K_M" return "Q4_K_M"
case fileTypeQ5_K_S: case FileTypeQ5_K_S:
return "Q5_K_S" return "Q5_K_S"
case fileTypeQ5_K_M: case FileTypeQ5_K_M:
return "Q5_K_M" return "Q5_K_M"
case fileTypeQ6_K: case FileTypeQ6_K:
return "Q6_K" return "Q6_K"
case fileTypeQ2_K_S: case FileTypeQ2_K_S:
return "Q2_K_S" return "Q2_K_S"
case FileTypeBF16: case FileTypeBF16:
return "BF16" return "BF16"
@@ -140,35 +163,35 @@ func (ftype FileType) ToTensorType() TensorType {
return TensorTypeF32 return TensorTypeF32
case FileTypeF16: case FileTypeF16:
return TensorTypeF16 return TensorTypeF16
case fileTypeQ4_0: case FileTypeQ4_0:
return TensorTypeQ4_0 return TensorTypeQ4_0
case fileTypeQ4_1: case FileTypeQ4_1:
return TensorTypeQ4_1 return TensorTypeQ4_1
case FileTypeQ8_0: case FileTypeQ8_0:
return TensorTypeQ8_0 return TensorTypeQ8_0
case fileTypeQ5_0: case FileTypeQ5_0:
return TensorTypeQ5_0 return TensorTypeQ5_0
case fileTypeQ5_1: case FileTypeQ5_1:
return TensorTypeQ5_1 return TensorTypeQ5_1
case fileTypeQ2_K: case FileTypeQ2_K:
return TensorTypeQ2_K return TensorTypeQ2_K
case fileTypeQ3_K_S: case FileTypeQ3_K_S:
return TensorTypeQ3_K return TensorTypeQ3_K
case fileTypeQ3_K_M: case FileTypeQ3_K_M:
return TensorTypeQ3_K return TensorTypeQ3_K
case fileTypeQ3_K_L: case FileTypeQ3_K_L:
return TensorTypeQ3_K return TensorTypeQ3_K
case FileTypeQ4_K_S: case FileTypeQ4_K_S:
return TensorTypeQ4_K return TensorTypeQ4_K
case FileTypeQ4_K_M: case FileTypeQ4_K_M:
return TensorTypeQ4_K return TensorTypeQ4_K
case fileTypeQ5_K_S: case FileTypeQ5_K_S:
return TensorTypeQ5_K return TensorTypeQ5_K
case fileTypeQ5_K_M: case FileTypeQ5_K_M:
return TensorTypeQ5_K return TensorTypeQ5_K
case fileTypeQ6_K: case FileTypeQ6_K:
return TensorTypeQ6_K return TensorTypeQ6_K
case fileTypeQ2_K_S: case FileTypeQ2_K_S:
return TensorTypeQ2_K return TensorTypeQ2_K
case FileTypeBF16: case FileTypeBF16:
return TensorTypeBF16 return TensorTypeBF16

View File

@@ -1,347 +0,0 @@
package gguf
import (
"bytes"
"cmp"
"encoding/binary"
"errors"
"fmt"
"io"
"iter"
"os"
"slices"
"strings"
)
const (
typeUint8 uint32 = iota
typeInt8
typeUint16
typeInt16
typeUint32
typeInt32
typeFloat32
typeBool
typeString
typeArray
typeUint64
typeInt64
typeFloat64
)
var ErrUnsupported = errors.New("unsupported")
type File struct {
Magic [4]byte
Version uint32
keyValues *lazy[KeyValue]
tensors *lazy[TensorInfo]
offset int64
file *os.File
reader *bufferedReader
bts []byte
}
func Open(path string) (f *File, err error) {
f = &File{bts: make([]byte, 4096)}
f.file, err = os.Open(path)
if err != nil {
return nil, err
}
f.reader = newBufferedReader(f.file, 32<<10)
if err := binary.Read(f.reader, binary.LittleEndian, &f.Magic); err != nil {
return nil, err
}
if bytes.Equal(f.Magic[:], []byte("gguf")) {
return nil, fmt.Errorf("%w file type %v", ErrUnsupported, f.Magic)
}
if err := binary.Read(f.reader, binary.LittleEndian, &f.Version); err != nil {
return nil, err
}
if f.Version < 2 {
return nil, fmt.Errorf("%w version %v", ErrUnsupported, f.Version)
}
f.tensors, err = newLazy(f, f.readTensor)
if err != nil {
return nil, err
}
f.tensors.successFunc = func() error {
offset := f.reader.offset
alignment := cmp.Or(f.KeyValue("general.alignment").Int(), 32)
f.offset = offset + (alignment-offset%alignment)%alignment
return nil
}
f.keyValues, err = newLazy(f, f.readKeyValue)
if err != nil {
return nil, err
}
return f, nil
}
func (f *File) readTensor() (TensorInfo, error) {
name, err := readString(f)
if err != nil {
return TensorInfo{}, err
}
dims, err := read[uint32](f)
if err != nil {
return TensorInfo{}, err
}
shape := make([]uint64, dims)
for i := range dims {
shape[i], err = read[uint64](f)
if err != nil {
return TensorInfo{}, err
}
}
type_, err := read[uint32](f)
if err != nil {
return TensorInfo{}, err
}
offset, err := read[uint64](f)
if err != nil {
return TensorInfo{}, err
}
return TensorInfo{
Name: name,
Offset: offset,
Shape: shape,
Type: TensorType(type_),
}, nil
}
func (f *File) readKeyValue() (KeyValue, error) {
key, err := readString(f)
if err != nil {
return KeyValue{}, err
}
t, err := read[uint32](f)
if err != nil {
return KeyValue{}, err
}
value, err := func() (any, error) {
switch t {
case typeUint8:
return read[uint8](f)
case typeInt8:
return read[int8](f)
case typeUint16:
return read[uint16](f)
case typeInt16:
return read[int16](f)
case typeUint32:
return read[uint32](f)
case typeInt32:
return read[int32](f)
case typeUint64:
return read[uint64](f)
case typeInt64:
return read[int64](f)
case typeFloat32:
return read[float32](f)
case typeFloat64:
return read[float64](f)
case typeBool:
return read[bool](f)
case typeString:
return readString(f)
case typeArray:
return readArray(f)
default:
return nil, fmt.Errorf("%w type %d", ErrUnsupported, t)
}
}()
if err != nil {
return KeyValue{}, err
}
return KeyValue{
Key: key,
Value: Value{value},
}, nil
}
func read[T any](f *File) (t T, err error) {
err = binary.Read(f.reader, binary.LittleEndian, &t)
return t, err
}
func readString(f *File) (string, error) {
n, err := read[uint64](f)
if err != nil {
return "", err
}
if int(n) > len(f.bts) {
f.bts = make([]byte, n)
}
bts := f.bts[:n]
if _, err := io.ReadFull(f.reader, bts); err != nil {
return "", err
}
defer clear(bts)
return string(bts), nil
}
func readArray(f *File) (any, error) {
t, err := read[uint32](f)
if err != nil {
return nil, err
}
n, err := read[uint64](f)
if err != nil {
return nil, err
}
switch t {
case typeUint8:
return readArrayData[uint8](f, n)
case typeInt8:
return readArrayData[int8](f, n)
case typeUint16:
return readArrayData[uint16](f, n)
case typeInt16:
return readArrayData[int16](f, n)
case typeUint32:
return readArrayData[uint32](f, n)
case typeInt32:
return readArrayData[int32](f, n)
case typeUint64:
return readArrayData[uint64](f, n)
case typeInt64:
return readArrayData[int64](f, n)
case typeFloat32:
return readArrayData[float32](f, n)
case typeFloat64:
return readArrayData[float64](f, n)
case typeBool:
return readArrayData[bool](f, n)
case typeString:
return readArrayString(f, n)
default:
return nil, fmt.Errorf("%w type %d", ErrUnsupported, t)
}
}
func readArrayData[T any](f *File, n uint64) (s []T, err error) {
s = make([]T, n)
for i := range n {
e, err := read[T](f)
if err != nil {
return nil, err
}
s[i] = e
}
return s, nil
}
func readArrayString(f *File, n uint64) (s []string, err error) {
s = make([]string, n)
for i := range n {
e, err := readString(f)
if err != nil {
return nil, err
}
s[i] = e
}
return s, nil
}
func (f *File) Close() error {
f.keyValues.stop()
f.tensors.stop()
return f.file.Close()
}
func (f *File) KeyValue(key string) KeyValue {
if !strings.HasPrefix(key, "general.") && !strings.HasPrefix(key, "tokenizer.") {
key = f.KeyValue("general.architecture").String() + "." + key
}
if index := slices.IndexFunc(f.keyValues.values, func(kv KeyValue) bool {
return kv.Key == key
}); index >= 0 {
return f.keyValues.values[index]
}
for keyValue, ok := f.keyValues.next(); ok; keyValue, ok = f.keyValues.next() {
if keyValue.Key == key {
return keyValue
}
}
return KeyValue{}
}
func (f *File) NumKeyValues() int {
return int(f.keyValues.count)
}
func (f *File) KeyValues() iter.Seq2[int, KeyValue] {
return f.keyValues.All()
}
func (f *File) TensorInfo(name string) TensorInfo {
if index := slices.IndexFunc(f.tensors.values, func(t TensorInfo) bool {
return t.Name == name
}); index >= 0 {
return f.tensors.values[index]
}
// fast-forward through key values if we haven't already
_ = f.keyValues.rest()
for tensor, ok := f.tensors.next(); ok; tensor, ok = f.tensors.next() {
if tensor.Name == name {
return tensor
}
}
return TensorInfo{}
}
func (f *File) NumTensors() int {
return int(f.tensors.count)
}
func (f *File) TensorInfos() iter.Seq2[int, TensorInfo] {
// fast forward through key values if we haven't already
f.keyValues.rest()
return f.tensors.All()
}
func (f *File) TensorReader(name string) (TensorInfo, io.Reader, error) {
t := f.TensorInfo(name)
if t.NumBytes() == 0 {
return TensorInfo{}, nil, fmt.Errorf("tensor %s not found", name)
}
// fast forward through tensor info if we haven't already
_ = f.tensors.rest()
return t, io.NewSectionReader(f.file, f.offset+int64(t.Offset), t.NumBytes()), nil
}

View File

@@ -1,249 +0,0 @@
package gguf_test
import (
"bytes"
"os"
"strconv"
"strings"
"testing"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/fs/gguf"
)
func createBinFile(tb testing.TB) string {
tb.Helper()
f, err := os.CreateTemp(tb.TempDir(), "")
if err != nil {
tb.Fatal(err)
}
defer f.Close()
kv := ggml.KV{
"general.architecture": "llama",
"llama.block_count": uint32(8),
"llama.embedding_length": uint32(3),
"llama.attention.head_count": uint32(2),
"llama.attention.head_count_kv": uint32(2),
"llama.attention.key_length": uint32(3),
"llama.rope.dimension_count": uint32(4),
"llama.rope.freq_base": float32(10000.0),
"llama.rope.freq_scale": float32(1.0),
"llama.attention.layer_norm_rms_epsilon": float32(1e-6),
"tokenizer.ggml.eos_token_id": uint32(0),
"tokenizer.ggml.eos_token_ids": []int32{1, 2, 3},
"tokenizer.ggml.tokens": []string{"hello", "world"},
"tokenizer.ggml.scores": []float32{0, 1},
}
tensors := []*ggml.Tensor{
{
Name: "token_embd.weight",
Kind: 0,
Shape: []uint64{2, 3},
WriterTo: bytes.NewBuffer(make([]byte, 4*2*3)),
},
{
Name: "output.weight",
Kind: 0,
Shape: []uint64{3, 2},
WriterTo: bytes.NewBuffer(make([]byte, 4*3*2)),
},
}
for i := range 8 {
tensors = append(tensors, &ggml.Tensor{
Name: "blk." + strconv.Itoa(i) + ".attn_q.weight",
Kind: 0,
Shape: []uint64{3, 3},
WriterTo: bytes.NewBuffer(make([]byte, 4*3*3)),
}, &ggml.Tensor{
Name: "blk." + strconv.Itoa(i) + ".attn_k.weight",
Kind: 0,
Shape: []uint64{3, 3},
WriterTo: bytes.NewBuffer(make([]byte, 4*3*3)),
}, &ggml.Tensor{
Name: "blk." + strconv.Itoa(i) + ".attn_v.weight",
Kind: 0,
Shape: []uint64{3, 3},
WriterTo: bytes.NewBuffer(make([]byte, 4*3*3)),
}, &ggml.Tensor{
Name: "blk." + strconv.Itoa(i) + ".attn_output.weight",
Kind: 0,
Shape: []uint64{3, 3},
WriterTo: bytes.NewBuffer(make([]byte, 4*3*3)),
})
}
if err := ggml.WriteGGUF(f, kv, tensors); err != nil {
tb.Fatal(err)
}
return f.Name()
}
func TestRead(t *testing.T) {
f, err := gguf.Open(createBinFile(t))
if err != nil {
t.Fatal(err)
}
defer f.Close()
if got := f.KeyValue("does.not.exist").Valid(); got {
t.Errorf(`KeyValue("does.not.exist").Exists() = %v, want false`, got)
}
if got := f.KeyValue("general.architecture").String(); got != "llama" {
t.Errorf(`KeyValue("general.architecture").String() = %q, want %q`, got, "llama")
}
if got := f.TensorInfo("token_embd.weight"); got.Name != "token_embd.weight" {
t.Errorf(`TensorInfo("token_embd.weight").Name = %q, want %q`, got.Name, "token_embd.weight")
} else if diff := cmp.Diff(got.Shape, []uint64{2, 3}); diff != "" {
t.Errorf(`TensorInfo("token_embd.weight").Shape mismatch (-got +want):\n%s`, diff)
} else if got.Type != gguf.TensorTypeF32 {
t.Errorf(`TensorInfo("token_embd.weight").Type = %d, want %d`, got.Type, gguf.TensorTypeF32)
}
if got := f.KeyValue("block_count").Uint(); got != 8 {
t.Errorf(`KeyValue("block_count").Uint() = %d, want %d`, got, 8)
}
if diff := cmp.Diff(f.KeyValue("tokenizer.ggml.tokens").Strings(), []string{"hello", "world"}); diff != "" {
t.Errorf("KeyValue(\"tokenizer.ggml.tokens\").Strings() mismatch (-got +want):\n%s", diff)
}
if diff := cmp.Diff(f.KeyValue("tokenizer.ggml.scores").Floats(), []float64{0, 1}); diff != "" {
t.Errorf("KeyValue(\"tokenizer.ggml.scores\").Ints() mismatch (-got +want):\n%s", diff)
}
var kvs []string
for _, kv := range f.KeyValues() {
if !kv.Valid() {
t.Error("found invalid key-value pair:", kv)
}
kvs = append(kvs, kv.Key)
}
if len(kvs) != f.NumKeyValues() {
t.Errorf("iterated key count = %d, want %d", len(kvs), f.NumKeyValues())
}
if diff := cmp.Diff(kvs, []string{
"general.architecture",
"llama.block_count",
"llama.embedding_length",
"llama.attention.head_count",
"llama.attention.head_count_kv",
"llama.attention.key_length",
"llama.rope.dimension_count",
"llama.rope.freq_base",
"llama.rope.freq_scale",
"llama.attention.layer_norm_rms_epsilon",
"tokenizer.ggml.eos_token_id",
"tokenizer.ggml.eos_token_ids",
"tokenizer.ggml.tokens",
"tokenizer.ggml.scores",
}, cmpopts.SortSlices(strings.Compare)); diff != "" {
t.Errorf("KeyValues() mismatch (-got +want):\n%s", diff)
}
var tis []string
for _, ti := range f.TensorInfos() {
if !ti.Valid() {
t.Error("found invalid tensor info:", ti)
}
tis = append(tis, ti.Name)
}
if len(tis) != f.NumTensors() {
t.Errorf("iterated tensor count = %d, want %d", len(tis), f.NumTensors())
}
if diff := cmp.Diff(tis, []string{
"token_embd.weight",
"output.weight",
"blk.0.attn_q.weight",
"blk.0.attn_k.weight",
"blk.0.attn_v.weight",
"blk.0.attn_output.weight",
"blk.1.attn_q.weight",
"blk.1.attn_k.weight",
"blk.1.attn_v.weight",
"blk.1.attn_output.weight",
"blk.2.attn_q.weight",
"blk.2.attn_k.weight",
"blk.2.attn_v.weight",
"blk.2.attn_output.weight",
"blk.3.attn_q.weight",
"blk.3.attn_k.weight",
"blk.3.attn_v.weight",
"blk.3.attn_output.weight",
"blk.4.attn_q.weight",
"blk.4.attn_k.weight",
"blk.4.attn_v.weight",
"blk.4.attn_output.weight",
"blk.5.attn_q.weight",
"blk.5.attn_k.weight",
"blk.5.attn_v.weight",
"blk.5.attn_output.weight",
"blk.6.attn_q.weight",
"blk.6.attn_k.weight",
"blk.6.attn_v.weight",
"blk.6.attn_output.weight",
"blk.7.attn_q.weight",
"blk.7.attn_k.weight",
"blk.7.attn_v.weight",
"blk.7.attn_output.weight",
}, cmpopts.SortSlices(strings.Compare)); diff != "" {
t.Errorf("TensorInfos() mismatch (-got +want):\n%s", diff)
}
ti, r, err := f.TensorReader("output.weight")
if err != nil {
t.Fatalf(`TensorReader("output.weight") error: %v`, err)
}
if ti.Name != "output.weight" {
t.Errorf(`TensorReader("output.weight").Name = %q, want %q`, ti.Name, "output.weight")
} else if diff := cmp.Diff(ti.Shape, []uint64{3, 2}); diff != "" {
t.Errorf(`TensorReader("output.weight").Shape mismatch (-got +want):\n%s`, diff)
} else if ti.Type != gguf.TensorTypeF32 {
t.Errorf(`TensorReader("output.weight").Type = %d, want %d`, ti.Type, gguf.TensorTypeF32)
}
var b bytes.Buffer
if _, err := b.ReadFrom(r); err != nil {
t.Fatalf(`ReadFrom TensorReader("output.weight") error: %v`, err)
}
if b.Len() != int(ti.NumBytes()) {
t.Errorf(`ReadFrom TensorReader("output.weight") length = %d, want %d`, b.Len(), ti.NumBytes())
}
}
func BenchmarkRead(b *testing.B) {
b.ReportAllocs()
p := createBinFile(b)
for b.Loop() {
f, err := gguf.Open(p)
if err != nil {
b.Fatal(err)
}
if got := f.KeyValue("general.architecture").String(); got != "llama" {
b.Errorf("got = %q, want %q", got, "llama")
}
// Iterate through some tensors
for range f.TensorInfos() {
}
f.Close()
}
}

View File

@@ -1,90 +0,0 @@
package gguf
import (
"reflect"
"slices"
)
type KeyValue struct {
Key string
Value
}
func (kv KeyValue) Valid() bool {
return kv.Key != "" && kv.Value.value != nil
}
type Value struct {
value any
}
func value[T any](v Value, kinds ...reflect.Kind) (t T) {
vv := reflect.ValueOf(v.value)
if slices.Contains(kinds, vv.Kind()) {
t = vv.Convert(reflect.TypeOf(t)).Interface().(T)
}
return
}
func values[T any](v Value, kinds ...reflect.Kind) (ts []T) {
switch vv := reflect.ValueOf(v.value); vv.Kind() {
case reflect.Slice:
if slices.Contains(kinds, vv.Type().Elem().Kind()) {
ts = make([]T, vv.Len())
for i := range vv.Len() {
ts[i] = vv.Index(i).Convert(reflect.TypeOf(ts[i])).Interface().(T)
}
}
}
return
}
// Int returns Value as a signed integer. If it is not a signed integer, it returns 0.
func (v Value) Int() int64 {
return value[int64](v, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64)
}
// Ints returns Value as a signed integer slice. If it is not a signed integer slice, it returns nil.
func (v Value) Ints() (i64s []int64) {
return values[int64](v, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64)
}
// Uint converts an unsigned integer value to uint64. If the value is not a unsigned integer, it returns 0.
func (v Value) Uint() uint64 {
return value[uint64](v, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64)
}
// Uints returns Value as a unsigned integer slice. If it is not a unsigned integer slice, it returns nil.
func (v Value) Uints() (u64s []uint64) {
return values[uint64](v, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64)
}
// Float returns Value as a float. If it is not a float, it returns 0.
func (v Value) Float() float64 {
return value[float64](v, reflect.Float32, reflect.Float64)
}
// Floats returns Value as a float slice. If it is not a float slice, it returns nil.
func (v Value) Floats() (f64s []float64) {
return values[float64](v, reflect.Float32, reflect.Float64)
}
// Bool returns Value as a boolean. If it is not a boolean, it returns false.
func (v Value) Bool() bool {
return value[bool](v, reflect.Bool)
}
// Bools returns Value as a boolean slice. If it is not a boolean slice, it returns nil.
func (v Value) Bools() (bools []bool) {
return values[bool](v, reflect.Bool)
}
// String returns Value as a string. If it is not a string, it returns an empty string.
func (v Value) String() string {
return value[string](v, reflect.String)
}
// Strings returns Value as a string slice. If it is not a string slice, it returns nil.
func (v Value) Strings() (strings []string) {
return values[string](v, reflect.String)
}

View File

@@ -1,208 +0,0 @@
package gguf
import (
"testing"
"github.com/google/go-cmp/cmp"
)
func split(name string, values map[string][]any) (matched []any, unmatched []any) {
for key, value := range values {
if key == name {
matched = value
} else {
unmatched = append(unmatched, value...)
}
}
return
}
func TestValue(t *testing.T) {
values := map[string][]any{
"int64": {int(42), int8(42), int16(42), int32(42), int64(42)},
"uint64": {uint(42), uint8(42), uint16(42), uint32(42), uint64(42)},
"float64": {float32(42), float64(42)},
"string": {"42", "hello"},
"bool": {true, false},
}
t.Run("int64", func(t *testing.T) {
matched, unmatched := split("int64", values)
for _, v := range matched {
kv := KeyValue{"key", Value{v}}
if i64 := kv.Int(); i64 != 42 {
t.Errorf("expected 42, got %d", i64)
}
}
for _, v := range unmatched {
kv := KeyValue{"key", Value{v}}
if i64 := kv.Int(); i64 != 0 {
t.Errorf("expected 42, got %d", i64)
}
}
})
t.Run("uint64", func(t *testing.T) {
matched, unmatched := split("uint64", values)
for _, v := range matched {
kv := KeyValue{"key", Value{v}}
if u64 := kv.Uint(); u64 != 42 {
t.Errorf("expected 42, got %d", u64)
}
}
for _, v := range unmatched {
kv := KeyValue{"key", Value{v}}
if u64 := kv.Uint(); u64 != 0 {
t.Errorf("expected 42, got %d", u64)
}
}
})
t.Run("float64", func(t *testing.T) {
matched, unmatched := split("float64", values)
for _, v := range matched {
kv := KeyValue{"key", Value{v}}
if f64 := kv.Float(); f64 != 42 {
t.Errorf("expected 42, got %f", f64)
}
}
for _, v := range unmatched {
kv := KeyValue{"key", Value{v}}
if f64 := kv.Float(); f64 != 0 {
t.Errorf("expected 42, got %f", f64)
}
}
})
t.Run("string", func(t *testing.T) {
matched, unmatched := split("string", values)
for _, v := range matched {
kv := KeyValue{"key", Value{v}}
if s := kv.String(); s != v {
t.Errorf("expected 42, got %s", s)
}
}
for _, v := range unmatched {
kv := KeyValue{"key", Value{v}}
if s := kv.String(); s != "" {
t.Errorf("expected 42, got %s", s)
}
}
})
t.Run("bool", func(t *testing.T) {
matched, unmatched := split("bool", values)
for _, v := range matched {
kv := KeyValue{"key", Value{v}}
if b := kv.Bool(); b != v {
t.Errorf("expected true, got %v", b)
}
}
for _, v := range unmatched {
kv := KeyValue{"key", Value{v}}
if b := kv.Bool(); b != false {
t.Errorf("expected false, got %v", b)
}
}
})
}
func TestValues(t *testing.T) {
values := map[string][]any{
"int64s": {[]int{42}, []int8{42}, []int16{42}, []int32{42}, []int64{42}},
"uint64s": {[]uint{42}, []uint8{42}, []uint16{42}, []uint32{42}, []uint64{42}},
"float64s": {[]float32{42}, []float64{42}},
"strings": {[]string{"42"}, []string{"hello"}},
"bools": {[]bool{true}, []bool{false}},
}
t.Run("int64s", func(t *testing.T) {
matched, unmatched := split("int64s", values)
for _, v := range matched {
kv := KeyValue{"key", Value{v}}
if diff := cmp.Diff(kv.Ints(), []int64{42}); diff != "" {
t.Errorf("diff: %s", diff)
}
}
for _, v := range unmatched {
kv := KeyValue{"key", Value{v}}
if i64s := kv.Ints(); i64s != nil {
t.Errorf("expected nil, got %v", i64s)
}
}
})
t.Run("uint64s", func(t *testing.T) {
matched, unmatched := split("uint64s", values)
for _, v := range matched {
kv := KeyValue{"key", Value{v}}
if diff := cmp.Diff(kv.Uints(), []uint64{42}); diff != "" {
t.Errorf("diff: %s", diff)
}
}
for _, v := range unmatched {
kv := KeyValue{"key", Value{v}}
if u64s := kv.Uints(); u64s != nil {
t.Errorf("expected nil, got %v", u64s)
}
}
})
t.Run("float64s", func(t *testing.T) {
matched, unmatched := split("float64s", values)
for _, v := range matched {
kv := KeyValue{"key", Value{v}}
if diff := cmp.Diff(kv.Floats(), []float64{42}); diff != "" {
t.Errorf("diff: %s", diff)
}
}
for _, v := range unmatched {
kv := KeyValue{"key", Value{v}}
if f64s := kv.Floats(); f64s != nil {
t.Errorf("expected nil, got %v", f64s)
}
}
})
t.Run("strings", func(t *testing.T) {
matched, unmatched := split("strings", values)
for _, v := range matched {
kv := KeyValue{"key", Value{v}}
if diff := cmp.Diff(kv.Strings(), v); diff != "" {
t.Errorf("diff: %s", diff)
}
}
for _, v := range unmatched {
kv := KeyValue{"key", Value{v}}
if s := kv.Strings(); s != nil {
t.Errorf("expected nil, got %v", s)
}
}
})
t.Run("bools", func(t *testing.T) {
matched, unmatched := split("bools", values)
for _, v := range matched {
kv := KeyValue{"key", Value{v}}
if diff := cmp.Diff(kv.Bools(), v); diff != "" {
t.Errorf("diff: %s", diff)
}
}
for _, v := range unmatched {
kv := KeyValue{"key", Value{v}}
if b := kv.Bools(); b != nil {
t.Errorf("expected nil, got %v", b)
}
}
})
}

View File

@@ -1,89 +0,0 @@
package gguf
import (
"encoding/binary"
"iter"
"log/slog"
)
type lazy[T any] struct {
count uint64
next func() (T, bool)
stop func()
values []T
// successFunc is called when all values have been successfully read.
successFunc func() error
}
func newLazy[T any](f *File, fn func() (T, error)) (*lazy[T], error) {
it := lazy[T]{}
if err := binary.Read(f.reader, binary.LittleEndian, &it.count); err != nil {
return nil, err
}
it.values = make([]T, 0)
it.next, it.stop = iter.Pull(func(yield func(T) bool) {
for i := range it.count {
t, err := fn()
if err != nil {
slog.Error("error reading tensor", "index", i, "error", err)
return
}
it.values = append(it.values, t)
if !yield(t) {
break
}
}
if it.successFunc != nil {
it.successFunc()
}
})
return &it, nil
}
func (g *lazy[T]) Values() iter.Seq[T] {
return func(yield func(T) bool) {
for _, v := range g.All() {
if !yield(v) {
break
}
}
}
}
func (g *lazy[T]) All() iter.Seq2[int, T] {
return func(yield func(int, T) bool) {
for i := range int(g.count) {
if i < len(g.values) {
if !yield(i, g.values[i]) {
break
}
} else {
t, ok := g.next()
if !ok {
break
}
if !yield(i, t) {
break
}
}
}
}
}
func (g *lazy[T]) rest() (collected bool) {
for {
_, ok := g.next()
collected = collected || ok
if !ok {
break
}
}
return collected
}

View File

@@ -1,23 +0,0 @@
package gguf
import (
"bufio"
"io"
)
type bufferedReader struct {
offset int64
*bufio.Reader
}
func newBufferedReader(rs io.ReadSeeker, size int) *bufferedReader {
return &bufferedReader{
Reader: bufio.NewReaderSize(rs, size),
}
}
func (rs *bufferedReader) Read(p []byte) (n int, err error) {
n, err = rs.Reader.Read(p)
rs.offset += int64(n)
return n, err
}

View File

@@ -1,288 +0,0 @@
package gguf
import (
"log/slog"
"strings"
)
type TensorInfo struct {
Name string
Offset uint64
Shape []uint64
Type TensorType
}
func (ti TensorInfo) Valid() bool {
return ti.Name != "" && ti.NumBytes() > 0
}
func (ti TensorInfo) NumValues() int64 {
var numItems int64 = 1
for _, dim := range ti.Shape {
numItems *= int64(dim)
}
return numItems
}
// NumBytes returns the number of bytes in the tensor.
func (ti TensorInfo) NumBytes() int64 {
return int64(float64(ti.NumValues()) * ti.Type.NumBytes())
}
func (ti TensorInfo) LogValue() slog.Value {
return slog.GroupValue(
slog.String("name", ti.Name),
slog.Int64("offset", int64(ti.Offset)),
slog.Any("shape", ti.Shape),
slog.Int64("num_values", ti.NumValues()),
slog.Int64("num_bytes", ti.NumBytes()),
slog.Any("type", ti.Type),
)
}
type TensorType uint32
const (
TensorTypeF32 TensorType = iota
TensorTypeF16
TensorTypeQ4_0
TensorTypeQ4_1
// unexported // unused in gguf
tensorTypeQ4_2
tensorTypeQ4_3
TensorTypeQ5_0
TensorTypeQ5_1
TensorTypeQ8_0
TensorTypeQ8_1
TensorTypeQ2_K
TensorTypeQ3_K
TensorTypeQ4_K
TensorTypeQ5_K
TensorTypeQ6_K
TensorTypeQ8_K
// unexported // unquantizable by ollama
tensorTypeIQ2_XXS
tensorTypeIQ2_XS
tensorTypeIQ3_XXS
tensorTypeIQ1_S
tensorTypeIQ4_NL
tensorTypeIQ3_S
tensorTypeIQ2_S
tensorTypeIQ4_XS
TensorTypeI8
TensorTypeI16
TensorTypeI32
TensorTypeI64
TensorTypeF64
// unexported // unquantizable by ollama
tensorTypeIQ1_M
TensorTypeBF16
// unexported // unused in gguf
tensorTypeQ4_0_4_4
tensorTypeQ4_0_4_8
tensorTypeQ4_0_8_8
// unexported // unquantizable by ollama
tensorTypeTQ1_0
tensorTypeTQ2_0
// unexported // unused in gguf
tensorTypeIQ4_NL_4_4
tensorTypeIQ4_NL_4_8
tensorTypeIQ4_NL_8_8
)
func (tt TensorType) NumBytes() float64 {
return float64(tt.typeSize()) / float64(tt.blockSize())
}
func (tt TensorType) typeSize() int64 {
switch tt {
case TensorTypeF32:
return 4
case TensorTypeF16:
return 2
case TensorTypeQ4_0:
return 2 + tt.blockSize()/2
case TensorTypeQ4_1:
return 2 + 2 + tt.blockSize()/2
case TensorTypeQ5_0:
return 2 + 4 + tt.blockSize()/2
case TensorTypeQ5_1:
return 2 + 2 + 4 + tt.blockSize()/2
case TensorTypeQ8_0:
return 2 + tt.blockSize()
case TensorTypeQ8_1:
return 2 + 2 + tt.blockSize()
case TensorTypeQ2_K:
return tt.blockSize()/16 + tt.blockSize()/4 + 2 + 2
case TensorTypeQ3_K:
return tt.blockSize()/8 + tt.blockSize()/4 + 12 + 2
case TensorTypeQ4_K:
return 2 + 2 + 12 + tt.blockSize()/2
case TensorTypeQ5_K:
return 2 + 2 + 12 + tt.blockSize()/8 + tt.blockSize()/2
case TensorTypeQ6_K:
return tt.blockSize()/2 + tt.blockSize()/4 + tt.blockSize()/16 + 2
case TensorTypeQ8_K:
return 4 + tt.blockSize() + 2*tt.blockSize()/16
case tensorTypeIQ2_XXS:
return 2 + 2*tt.blockSize()/8
case tensorTypeIQ2_XS:
return 2 + 2*tt.blockSize()/8 + tt.blockSize()/32
case tensorTypeIQ3_XXS:
return 2 + tt.blockSize()/4 + tt.blockSize()/8
case tensorTypeIQ1_S:
return 2 + tt.blockSize()/8 + tt.blockSize()/16
case tensorTypeIQ4_NL:
return 2 + tt.blockSize()/2
case tensorTypeIQ3_S:
return 2 + tt.blockSize()/4 + tt.blockSize()/8 + tt.blockSize()/32 + 4
case tensorTypeIQ2_S:
return 2 + tt.blockSize()/4 + tt.blockSize()/16
case tensorTypeIQ4_XS:
return 2 + 2 + tt.blockSize()/2 + tt.blockSize()/64
case TensorTypeI8:
return 1
case TensorTypeI16:
return 2
case TensorTypeI32:
return 4
case TensorTypeI64:
return 8
case TensorTypeF64:
return 8
case tensorTypeIQ1_M:
return tt.blockSize()/8 + tt.blockSize()/16 + tt.blockSize()/32
case TensorTypeBF16:
return 2
default:
return 0
}
}
func (tt TensorType) blockSize() int64 {
switch tt {
case TensorTypeF32,
TensorTypeF16,
TensorTypeI8,
TensorTypeI16,
TensorTypeI32,
TensorTypeI64,
TensorTypeF64,
TensorTypeBF16:
return 1
case TensorTypeQ4_0,
TensorTypeQ4_1,
TensorTypeQ5_0,
TensorTypeQ5_1,
TensorTypeQ8_0,
TensorTypeQ8_1,
tensorTypeIQ4_NL:
return 32
default:
return 256
}
}
func (tt TensorType) String() string {
switch tt {
case TensorTypeF32:
return "f32"
case TensorTypeF16:
return "f16"
case TensorTypeQ4_0:
return "q4_0"
case TensorTypeQ4_1:
return "q4_1"
case tensorTypeQ4_2:
return "q4_2"
case tensorTypeQ4_3:
return "q4_3"
case TensorTypeQ5_0:
return "q5_0"
case TensorTypeQ5_1:
return "q5_1"
case TensorTypeQ8_0:
return "q8_0"
case TensorTypeQ8_1:
return "q8_1"
case TensorTypeQ2_K:
return "q2_k"
case TensorTypeQ3_K:
return "q3_k"
case TensorTypeQ4_K:
return "q4_k"
case TensorTypeQ5_K:
return "q5_k"
case TensorTypeQ6_K:
return "q6_k"
case TensorTypeQ8_K:
return "q8_k"
case tensorTypeIQ2_XXS:
return "iq2_xxs"
case tensorTypeIQ2_XS:
return "iq2_xs"
case tensorTypeIQ3_XXS:
return "iq3_xxs"
case tensorTypeIQ1_S:
return "iq1_s"
case tensorTypeIQ4_NL:
return "iq4_nl"
case tensorTypeIQ3_S:
return "iq3_s"
case tensorTypeIQ2_S:
return "iq2_s"
case tensorTypeIQ4_XS:
return "iq4_xs"
case TensorTypeI8:
return "i8"
case TensorTypeI16:
return "i16"
case TensorTypeI32:
return "i32"
case TensorTypeI64:
return "i64"
case TensorTypeF64:
return "f64"
case tensorTypeIQ1_M:
return "iq1_m"
case TensorTypeBF16:
return "bf16"
case tensorTypeQ4_0_4_4:
return "q4_0_4_4"
case tensorTypeQ4_0_4_8:
return "q4_0_4_8"
case tensorTypeQ4_0_8_8:
return "q4_0_8_8"
case tensorTypeTQ1_0:
return "tq1_0"
case tensorTypeTQ2_0:
return "tq2_0"
case tensorTypeIQ4_NL_4_4:
return "iq4_nl_4_4"
case tensorTypeIQ4_NL_4_8:
return "iq4_nl_4_8"
case tensorTypeIQ4_NL_8_8:
return "iq4_nl_8_8"
default:
return "unknown"
}
}
func (tt TensorType) LogValue() slog.Value {
return slog.GroupValue(
slog.Uint64("value", uint64(tt)),
slog.String("name", strings.ToUpper(tt.String())),
slog.Int64("size", tt.typeSize()),
slog.Int64("block_size", tt.blockSize()),
slog.Float64("num_bytes", tt.NumBytes()),
)
}

6
go.mod
View File

@@ -19,13 +19,12 @@ require (
github.com/d4l3k/go-bfloat16 v0.0.0-20211005043715-690c3bdd05f1 github.com/d4l3k/go-bfloat16 v0.0.0-20211005043715-690c3bdd05f1
github.com/dlclark/regexp2 v1.11.4 github.com/dlclark/regexp2 v1.11.4
github.com/emirpasic/gods/v2 v2.0.0-alpha github.com/emirpasic/gods/v2 v2.0.0-alpha
github.com/google/go-cmp v0.7.0 github.com/google/go-cmp v0.6.0
github.com/mattn/go-runewidth v0.0.14 github.com/mattn/go-runewidth v0.0.14
github.com/nlpodyssey/gopickle v0.3.0 github.com/nlpodyssey/gopickle v0.3.0
github.com/pdevine/tensor v0.0.0-20240510204454-f88f4562727c github.com/pdevine/tensor v0.0.0-20240510204454-f88f4562727c
golang.org/x/image v0.22.0 golang.org/x/image v0.22.0
golang.org/x/tools v0.30.0 golang.org/x/tools v0.30.0
gonum.org/v1/gonum v0.15.0
) )
require ( require (
@@ -45,6 +44,7 @@ require (
github.com/xtgo/set v1.0.0 // indirect github.com/xtgo/set v1.0.0 // indirect
go4.org/unsafe/assume-no-moving-gc v0.0.0-20231121144256-b99613f794b6 // indirect go4.org/unsafe/assume-no-moving-gc v0.0.0-20231121144256-b99613f794b6 // indirect
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect
gonum.org/v1/gonum v0.15.0 // indirect
gorgonia.org/vecf32 v0.9.0 // indirect gorgonia.org/vecf32 v0.9.0 // indirect
gorgonia.org/vecf64 v0.9.0 // indirect gorgonia.org/vecf64 v0.9.0 // indirect
) )
@@ -71,7 +71,7 @@ require (
github.com/ugorji/go/codec v1.2.12 // indirect github.com/ugorji/go/codec v1.2.12 // indirect
golang.org/x/arch v0.8.0 // indirect golang.org/x/arch v0.8.0 // indirect
golang.org/x/crypto v0.36.0 golang.org/x/crypto v0.36.0
golang.org/x/exp v0.0.0-20250218142911-aa4b98e5adaa // indirect golang.org/x/exp v0.0.0-20250218142911-aa4b98e5adaa
golang.org/x/net v0.38.0 // indirect golang.org/x/net v0.38.0 // indirect
golang.org/x/sys v0.31.0 golang.org/x/sys v0.31.0
golang.org/x/term v0.30.0 golang.org/x/term v0.30.0

4
go.sum
View File

@@ -112,8 +112,8 @@ github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/
github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=

View File

@@ -1,57 +0,0 @@
//go:build integration && library
package integration
import (
"context"
"log/slog"
"testing"
"time"
"github.com/ollama/ollama/api"
)
// First run of this scenario on a target system will take a long time to download
// ~1.5TB of models. Set a sufficiently large -timeout for your network speed
func TestLibraryModelsGenerate(t *testing.T) {
softTimeout, hardTimeout := getTimeouts(t)
slog.Info("Setting timeouts", "soft", softTimeout, "hard", hardTimeout)
ctx, cancel := context.WithTimeout(context.Background(), hardTimeout)
defer cancel()
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
chatModels := libraryChatModels
for _, model := range chatModels {
t.Run(model, func(t *testing.T) {
if time.Now().Sub(started) > softTimeout {
t.Skip("skipping remaining tests to avoid excessive runtime")
}
if err := PullIfMissing(ctx, client, model); err != nil {
t.Fatalf("pull failed %s", err)
}
req := api.GenerateRequest{
Model: model,
Prompt: "why is the sky blue?",
KeepAlive: &api.Duration{Duration: 10 * time.Second},
Options: map[string]interface{}{
"temperature": 0.1,
"seed": 123,
},
}
anyResp := []string{"rayleigh", "scatter", "atmosphere", "nitrogen", "oxygen", "wavelength"}
// Special cases
if model == "duckdb-nsql" {
anyResp = []string{"select", "from"}
} else if model == "granite3-guardian" || model == "shieldgemma" || model == "llama-guard3" || model == "bespoke-minicheck" {
anyResp = []string{"yes", "no", "safe", "unsafe"}
} else if model == "openthinker" || model == "nexusraven" {
anyResp = []string{"plugin", "im_sep", "components", "function call"}
} else if model == "starcoder" || model == "starcoder2" || model == "magicoder" || model == "deepseek-coder" {
req.Prompt = "def fibonacci():"
anyResp = []string{"f(n)", "sequence", "n-1", "main()", "__main__", "while"}
}
DoGenerate(ctx, t, client, req, anyResp, 120*time.Second, 30*time.Second)
})
}
}

View File

@@ -19,7 +19,7 @@ func TestVisionModels(t *testing.T) {
} }
testCases := []testCase{ testCases := []testCase{
{ {
model: "qwen2.5vl", model: "llava:7b",
}, },
{ {
model: "llama3.2-vision", model: "llama3.2-vision",
@@ -60,7 +60,6 @@ func TestVisionModels(t *testing.T) {
} }
func TestIntegrationSplitBatch(t *testing.T) { func TestIntegrationSplitBatch(t *testing.T) {
skipUnderMinVRAM(t, 6)
image, err := base64.StdEncoding.DecodeString(imageEncoding) image, err := base64.StdEncoding.DecodeString(imageEncoding)
require.NoError(t, err) require.NoError(t, err)
req := api.GenerateRequest{ req := api.GenerateRequest{

View File

@@ -19,6 +19,35 @@ import (
"github.com/ollama/ollama/format" "github.com/ollama/ollama/format"
) )
var (
started = time.Now()
chatModels = []string{
"granite3-moe:latest",
"granite-code:latest",
"nemotron-mini:latest",
"command-r:latest",
"gemma2:latest",
"gemma:latest",
"internlm2:latest",
"phi3.5:latest",
"phi3:latest",
// "phi:latest", // flaky, sometimes generates no response on first query
"stablelm2:latest", // Predictions are off, crashes on small VRAM GPUs
"falcon:latest",
"falcon2:latest",
"minicpm-v:latest",
"mistral:latest",
"orca-mini:latest",
"llama2:latest",
"llama3.1:latest",
"llama3.2:latest",
"llama3.2-vision:latest",
"qwen2.5-coder:latest",
"qwen:latest",
"solar-pro:latest",
}
)
func TestModelsGenerate(t *testing.T) { func TestModelsGenerate(t *testing.T) {
softTimeout, hardTimeout := getTimeouts(t) softTimeout, hardTimeout := getTimeouts(t)
slog.Info("Setting timeouts", "soft", softTimeout, "hard", hardTimeout) slog.Info("Setting timeouts", "soft", softTimeout, "hard", hardTimeout)
@@ -39,13 +68,6 @@ func TestModelsGenerate(t *testing.T) {
slog.Warn("No VRAM info available, testing all models, so larger ones might timeout...") slog.Warn("No VRAM info available, testing all models, so larger ones might timeout...")
} }
var chatModels []string
if s := os.Getenv("OLLAMA_NEW_ENGINE"); s != "" {
chatModels = ollamaEngineChatModels
} else {
chatModels = append(ollamaEngineChatModels, llamaRunnerChatModels...)
}
for _, model := range chatModels { for _, model := range chatModels {
t.Run(model, func(t *testing.T) { t.Run(model, func(t *testing.T) {
if time.Now().Sub(started) > softTimeout { if time.Now().Sub(started) > softTimeout {

View File

@@ -1,266 +0,0 @@
//go:build integration && perf
package integration
import (
"context"
"fmt"
"io/ioutil"
"log/slog"
"math"
"os"
"path/filepath"
"strconv"
"strings"
"testing"
"time"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/format"
)
var (
// Models that don't work reliably with the large context prompt in this test case
longContextFlakes = []string{
"granite-code:latest",
"nemotron-mini:latest",
"falcon:latest", // 2k model
"falcon2:latest", // 2k model
"minicpm-v:latest",
"qwen:latest",
"solar-pro:latest",
}
)
// Note: this test case can take a long time to run, particularly on models with
// large contexts. Run with -timeout set to a large value to get reasonable coverage
// Example usage:
//
// go test --tags=integration,perf -count 1 ./integration -v -timeout 90m -run TestModelsPerf 2>&1 | tee int.log
// cat int.log | grep MODEL_PERF_HEADER | head -1| cut -f2- -d: > perf.csv
// cat int.log | grep MODEL_PERF_DATA | cut -f2- -d: >> perf.csv
func TestModelsPerf(t *testing.T) {
softTimeout, hardTimeout := getTimeouts(t)
slog.Info("Setting timeouts", "soft", softTimeout, "hard", hardTimeout)
ctx, cancel := context.WithTimeout(context.Background(), hardTimeout)
defer cancel()
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
// TODO use info API eventually
var maxVram uint64
var err error
if s := os.Getenv("OLLAMA_MAX_VRAM"); s != "" {
maxVram, err = strconv.ParseUint(s, 10, 64)
if err != nil {
t.Fatalf("invalid OLLAMA_MAX_VRAM %v", err)
}
} else {
slog.Warn("No VRAM info available, testing all models, so larger ones might timeout...")
}
data, err := ioutil.ReadFile(filepath.Join("testdata", "shakespeare.txt"))
if err != nil {
t.Fatalf("failed to open test data file: %s", err)
}
longPrompt := "summarize the following: " + string(data)
var chatModels []string
if s := os.Getenv("OLLAMA_NEW_ENGINE"); s != "" {
chatModels = ollamaEngineChatModels
} else {
chatModels = append(ollamaEngineChatModels, llamaRunnerChatModels...)
}
for _, model := range chatModels {
t.Run(model, func(t *testing.T) {
if time.Now().Sub(started) > softTimeout {
t.Skip("skipping remaining tests to avoid excessive runtime")
}
if err := PullIfMissing(ctx, client, model); err != nil {
t.Fatalf("pull failed %s", err)
}
var maxContext int
resp, err := client.Show(ctx, &api.ShowRequest{Model: model})
if err != nil {
t.Fatalf("show failed: %s", err)
}
arch := resp.ModelInfo["general.architecture"].(string)
maxContext = int(resp.ModelInfo[fmt.Sprintf("%s.context_length", arch)].(float64))
if maxVram > 0 {
resp, err := client.List(ctx)
if err != nil {
t.Fatalf("list models failed %v", err)
}
for _, m := range resp.Models {
// For these tests we want to exercise a some amount of overflow on the CPU
if m.Name == model && float32(m.Size)*0.75 > float32(maxVram) {
t.Skipf("model %s is too large %s for available VRAM %s", model, format.HumanBytes(m.Size), format.HumanBytes(int64(maxVram)))
}
}
}
slog.Info("scneario", "model", model, "max_context", maxContext)
loaded := false
defer func() {
// best effort unload once we're done with the model
if loaded {
client.Generate(ctx, &api.GenerateRequest{Model: model, KeepAlive: &api.Duration{Duration: 0}}, func(rsp api.GenerateResponse) error { return nil })
}
}()
// Some models don't handle the long context data well so skip them to avoid flaky test results
longContextFlake := false
for _, flake := range longContextFlakes {
if model == flake {
longContextFlake = true
break
}
}
// iterate through a few context sizes for coverage without excessive runtime
var contexts []int
keepGoing := true
if maxContext > 16384 {
contexts = []int{4096, 8192, 16384, maxContext}
} else if maxContext > 8192 {
contexts = []int{4096, 8192, maxContext}
} else if maxContext > 4096 {
contexts = []int{4096, maxContext}
} else if maxContext > 0 {
contexts = []int{maxContext}
} else {
t.Fatal("unknown max context size")
}
for _, numCtx := range contexts {
if !keepGoing && numCtx > 8192 { // Always try up to 8k before bailing out
break
}
skipLongPrompt := false
// Workaround bug 11172 temporarily...
maxPrompt := longPrompt
// If we fill the context too full with the prompt, many models
// quickly hit context shifting and go bad.
if len(maxPrompt) > numCtx*2 { // typically yields ~1/2 full context
maxPrompt = maxPrompt[:numCtx*2]
}
testCases := []struct {
prompt string
anyResp []string
}{
{"why is the sky blue?", []string{"rayleigh", "scattering", "atmosphere", "nitrogen", "oxygen"}},
{maxPrompt, []string{"shakespeare", "oppression", "sorrows", "gutenberg", "child", "license", "sonnet", "melancholy"}},
}
var gpuPercent int
for _, tc := range testCases {
if len(tc.prompt) > 100 && (longContextFlake || skipLongPrompt) {
slog.Info("skipping long prompt", "model", model, "num_ctx", numCtx, "gpu_percent", gpuPercent)
continue
}
req := api.GenerateRequest{
Model: model,
Prompt: tc.prompt,
KeepAlive: &api.Duration{Duration: 20 * time.Second}, // long enough to ensure a ps returns
Options: map[string]interface{}{
"temperature": 0,
"seed": 123,
"num_ctx": numCtx,
},
}
atLeastOne := false
var resp api.GenerateResponse
stream := false
req.Stream = &stream
// Avoid potentially getting stuck indefinitely
limit := 5 * time.Minute
genCtx, cancel := context.WithDeadlineCause(
ctx,
time.Now().Add(limit),
fmt.Errorf("generate on model %s with ctx %d took longer than %v", model, numCtx, limit),
)
defer cancel()
err = client.Generate(genCtx, &req, func(rsp api.GenerateResponse) error {
resp = rsp
return nil
})
if err != nil {
// Avoid excessive test runs, but don't consider a failure with massive context
if numCtx > 16384 && strings.Contains(err.Error(), "took longer") {
slog.Warn("max context was taking too long, skipping", "error", err)
keepGoing = false
skipLongPrompt = true
continue
}
t.Fatalf("generate error: ctx:%d err:%s", numCtx, err)
}
loaded = true
for _, expResp := range tc.anyResp {
if strings.Contains(strings.ToLower(resp.Response), expResp) {
atLeastOne = true
break
}
}
if !atLeastOne {
t.Fatalf("response didn't contain expected values: ctx:%d expected:%v response:%s ", numCtx, tc.anyResp, resp.Response)
}
models, err := client.ListRunning(ctx)
if err != nil {
slog.Warn("failed to list running models", "error", err)
continue
}
if len(models.Models) > 1 {
slog.Warn("multiple models loaded, may impact performance results", "loaded", models.Models)
}
for _, m := range models.Models {
if m.Name == model {
if m.SizeVRAM == 0 {
slog.Info("Model fully loaded into CPU")
gpuPercent = 0
keepGoing = false
skipLongPrompt = true
} else if m.SizeVRAM == m.Size {
slog.Info("Model fully loaded into GPU")
gpuPercent = 100
} else {
sizeCPU := m.Size - m.SizeVRAM
cpuPercent := math.Round(float64(sizeCPU) / float64(m.Size) * 100)
gpuPercent = int(100 - cpuPercent)
slog.Info("Model split between CPU/GPU", "CPU", cpuPercent, "GPU", gpuPercent)
keepGoing = false
// Heuristic to avoid excessive test run time
if gpuPercent < 90 {
skipLongPrompt = true
}
}
}
}
fmt.Fprintf(os.Stderr, "MODEL_PERF_HEADER:%s,%s,%s,%s,%s,%s,%s\n",
"MODEL",
"CONTEXT",
"GPU PERCENT",
"PROMPT COUNT",
"LOAD TIME",
"PROMPT EVAL TPS",
"EVAL TPS",
)
fmt.Fprintf(os.Stderr, "MODEL_PERF_DATA:%s,%d,%d,%d,%0.2f,%0.2f,%0.2f\n",
model,
numCtx,
gpuPercent,
resp.PromptEvalCount,
float64(resp.LoadDuration)/1000000000.0,
float64(resp.PromptEvalCount)/(float64(resp.PromptEvalDuration)/1000000000.0),
float64(resp.EvalCount)/(float64(resp.EvalDuration)/1000000000.0),
)
}
}
})
}
}

File diff suppressed because one or more lines are too long

File diff suppressed because it is too large Load Diff

View File

@@ -32,229 +32,6 @@ const (
smol = "llama3.2:1b" smol = "llama3.2:1b"
) )
var (
started = time.Now()
// Note: add newer models at the top of the list to test them first
ollamaEngineChatModels = []string{
"gemma3n:e2b",
"mistral-small3.2:latest",
"deepseek-r1:1.5b",
"llama3.2-vision:latest",
"qwen2.5-coder:latest",
"qwen2.5vl:3b",
"qwen3:0.6b", // dense
"qwen3:30b", // MOE
"gemma3:1b",
"llama3.1:latest",
"llama3.2:latest",
"gemma2:latest",
"minicpm-v:latest", // arch=qwen2
"granite-code:latest", // arch=llama
}
llamaRunnerChatModels = []string{
"mistral:latest",
"falcon3:latest",
"granite3-moe:latest",
"command-r:latest",
"nemotron-mini:latest",
"phi3.5:latest",
"solar-pro:latest",
"internlm2:latest",
"codellama:latest", // arch=llama
"phi3:latest",
"falcon2:latest",
"gemma:latest",
"llama2:latest",
"nous-hermes:latest",
"orca-mini:latest",
"qwen:latest",
"stablelm2:latest", // Predictions are off, crashes on small VRAM GPUs
"falcon:latest",
}
// Some library models are quite large - ensure large VRAM and sufficient disk space
// before running scenarios based on this set
libraryChatModels = []string{
"alfred",
"athene-v2",
"aya-expanse",
"aya",
"bakllava",
"bespoke-minicheck",
"codebooga",
"codegeex4",
"codegemma",
"codellama",
"codeqwen",
"codestral",
"codeup",
"cogito",
"command-a",
"command-r-plus",
"command-r",
"command-r7b-arabic",
"command-r7b",
"dbrx",
"deepcoder",
"deepscaler",
"deepseek-coder-v2",
"deepseek-coder",
"deepseek-llm",
"deepseek-r1",
// "deepseek-v2.5", // requires 155 GB VRAM
"deepseek-v2",
// "deepseek-v3", // requires 482 GB VRAM
"devstral",
"dolphin-llama3",
"dolphin-mistral",
"dolphin-mixtral",
"dolphin-phi",
"dolphin3",
"dolphincoder",
"duckdb-nsql",
"everythinglm",
"exaone-deep",
"exaone3.5",
"falcon",
"falcon2",
"falcon3",
"firefunction-v2",
"gemma",
"gemma2",
"gemma3",
"gemma3n",
"glm4",
"goliath",
"granite-code",
"granite3-dense",
"granite3-guardian",
"granite3-moe",
"granite3.1-dense",
"granite3.1-moe",
"granite3.2-vision",
"granite3.2",
"granite3.3",
"hermes3",
"internlm2",
"llama-guard3",
"llama-pro",
"llama2-chinese",
"llama2-uncensored",
"llama2",
"llama3-chatqa",
"llama3-gradient",
"llama3-groq-tool-use",
"llama3.1",
"llama3.2-vision",
"llama3.2",
"llama3.3",
"llama3",
"llama4",
"llava-llama3",
"llava-phi3",
"llava",
"magicoder",
"magistral",
"marco-o1",
"mathstral",
"meditron",
"medllama2",
"megadolphin",
"minicpm-v",
"mistral-large",
"mistral-nemo",
"mistral-openorca",
"mistral-small",
"mistral-small3.1",
"mistral-small3.2",
"mistral",
"mistrallite",
"mixtral",
"moondream",
"nemotron-mini",
"nemotron",
"neural-chat",
"nexusraven",
"notus",
"nous-hermes",
"nous-hermes2-mixtral",
"nous-hermes2",
"nuextract",
"olmo2",
"open-orca-platypus2",
"openchat",
"opencoder",
"openhermes",
"openthinker",
"orca-mini",
"orca2",
// "phi", // unreliable
"phi3.5",
"phi3",
"phi4-mini-reasoning",
"phi4-mini",
"phi4-reasoning",
"phi4",
"phind-codellama",
"qwen",
"qwen2-math",
"qwen2.5-coder",
"qwen2.5",
"qwen2.5vl",
"qwen2",
"qwen3:0.6b", // dense
"qwen3:30b", // MOE
"qwq",
"r1-1776",
"reader-lm",
"reflection",
"sailor2",
"samantha-mistral",
"shieldgemma",
"smallthinker",
"smollm",
"smollm2",
"solar-pro",
"solar",
"sqlcoder",
"stable-beluga",
"stable-code",
"stablelm-zephyr",
"stablelm2",
"starcoder",
"starcoder2",
"starling-lm",
"tinydolphin",
"tinyllama",
"tulu3",
"vicuna",
"wizard-math",
"wizard-vicuna-uncensored",
"wizard-vicuna",
"wizardcoder",
"wizardlm-uncensored",
"wizardlm2",
"xwinlm",
"yarn-llama2",
"yarn-mistral",
"yi-coder",
"yi",
"zephyr",
}
libraryEmbedModels = []string{
"all-minilm",
"bge-large",
"bge-m3",
"granite-embedding",
"mxbai-embed-large",
"nomic-embed-text",
"paraphrase-multilingual",
"snowflake-arctic-embed",
"snowflake-arctic-embed2",
}
)
func Init() { func Init() {
lifecycle.InitLogging() lifecycle.InitLogging()
} }
@@ -494,10 +271,6 @@ func DoGenerate(ctx context.Context, t *testing.T, client *api.Client, genReq ap
t.Errorf("generate stalled. Response so far:%s", buf.String()) t.Errorf("generate stalled. Response so far:%s", buf.String())
} }
case <-done: case <-done:
if genErr != nil && strings.Contains(genErr.Error(), "model requires more system memory") {
slog.Warn("model is too large for the target test system", "model", genReq.Model, "error", genErr)
return
}
require.NoError(t, genErr, "failed with %s request prompt %s ", genReq.Model, genReq.Prompt) require.NoError(t, genErr, "failed with %s request prompt %s ", genReq.Model, genReq.Prompt)
// Verify the response contains the expected data // Verify the response contains the expected data
response := buf.String() response := buf.String()

View File

@@ -25,19 +25,11 @@ type Causal struct {
opts CausalOptions opts CausalOptions
// maxBatch is the largest batch that we might receive
maxBatch int
// config controls mostly backend-specific optimizations // config controls mostly backend-specific optimizations
config *ml.CacheConfig config *ml.CacheConfig
// ** current forward pass ** // ** current forward pass **
// curReserve indicates that this forward pass is only for
// memory reservation and we should not update our metadata
// based on it.
curReserve bool
// the active layer for Get and Put // the active layer for Get and Put
curLayer int curLayer int
@@ -150,7 +142,6 @@ func (c *Causal) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity
c.DType = dtype c.DType = dtype
c.cellRanges = make(map[int]cellRange) c.cellRanges = make(map[int]cellRange)
c.backend = backend c.backend = backend
c.maxBatch = maxBatch
} }
func (c *Causal) SetConfig(config ml.CacheConfig) { func (c *Causal) SetConfig(config ml.CacheConfig) {
@@ -168,13 +159,12 @@ func (c *Causal) Close() {
} }
func (c *Causal) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error { func (c *Causal) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error {
c.curReserve = reserve
c.curBatchSize = len(batch.Positions) c.curBatchSize = len(batch.Positions)
c.curSequences = batch.Sequences c.curSequences = batch.Sequences
c.curPositions = batch.Positions c.curPositions = batch.Positions
c.opts.Except = nil c.opts.Except = nil
if !c.curReserve { if !reserve {
c.updateSlidingWindow() c.updateSlidingWindow()
var err error var err error
@@ -221,9 +211,10 @@ func (c *Causal) StartForward(ctx ml.Context, batch input.Batch, reserve bool) e
c.curCellRange.max = len(c.cells) - 1 c.curCellRange.max = len(c.cells) - 1
} }
c.curMask = c.buildMask(ctx) var err error
c.curMask, err = c.buildMask(ctx)
return nil return err
} }
func newRange() cellRange { func newRange() cellRange {
@@ -306,7 +297,7 @@ func roundUp(length, pad int) int {
// Builds a mask of history x batch indicating whether for each token in the batch the // Builds a mask of history x batch indicating whether for each token in the batch the
// token in the history should apply. This is based on both the sequence and causality (the // token in the history should apply. This is based on both the sequence and causality (the
// position of the history is not ahead of the token in the batch). // position of the history is not ahead of the token in the batch).
func (c *Causal) buildMask(ctx ml.Context) ml.Tensor { func (c *Causal) buildMask(ctx ml.Context) (ml.Tensor, error) {
// Align and pad the two dimensions as required by the backend // Align and pad the two dimensions as required by the backend
batchSize := roundUp(c.curBatchSize, c.config.MaskBatchPadding) batchSize := roundUp(c.curBatchSize, c.config.MaskBatchPadding)
@@ -314,11 +305,6 @@ func (c *Causal) buildMask(ctx ml.Context) ml.Tensor {
c.curCellRange.max = roundUp(c.curCellRange.max+1, c.config.CachePadding) - 1 c.curCellRange.max = roundUp(c.curCellRange.max+1, c.config.CachePadding) - 1
length := c.curCellRange.max - c.curCellRange.min + 1 length := c.curCellRange.max - c.curCellRange.min + 1
if c.curReserve {
return ctx.Input().Empty(c.config.MaskDType, length, batchSize)
}
mask := make([]float32, batchSize*length) mask := make([]float32, batchSize*length)
for i := range c.curBatchSize { for i := range c.curBatchSize {
@@ -339,7 +325,10 @@ func (c *Causal) buildMask(ctx ml.Context) ml.Tensor {
mask[i] = float32(math.Inf(-1)) mask[i] = float32(math.Inf(-1))
} }
maskTensor := ctx.Input().FromFloatSlice(mask, length, batchSize) maskTensor, err := ctx.Input().FromFloatSlice(mask, length, batchSize)
if err != nil {
return nil, err
}
if c.config.MaskDType != ml.DTypeF32 { if c.config.MaskDType != ml.DTypeF32 {
out := ctx.Input().Empty(c.config.MaskDType, maskTensor.Shape()...) out := ctx.Input().Empty(c.config.MaskDType, maskTensor.Shape()...)
@@ -347,7 +336,7 @@ func (c *Causal) buildMask(ctx ml.Context) ml.Tensor {
maskTensor = out maskTensor = out
} }
return maskTensor return maskTensor, nil
} }
func (c *Causal) moveCells(ctx ml.Context, src, dst, length int) { func (c *Causal) moveCells(ctx ml.Context, src, dst, length int) {
@@ -502,7 +491,12 @@ func (c *Causal) SetCausal(ctx ml.Context, opts CausalOptions) {
if !slices.Equal(c.opts.Except, opts.Except) { if !slices.Equal(c.opts.Except, opts.Except) {
c.opts = opts c.opts = opts
if ctx != nil { if ctx != nil {
c.curMask = c.buildMask(ctx) var err error
c.curMask, err = c.buildMask(ctx)
if err != nil {
// This error should never occur because we have previously built a mask with the same shape
panic(fmt.Errorf("SetCausal: %w", err))
}
} }
} }
} }
@@ -643,64 +637,51 @@ func (c *Causal) shift(seq int, beginIndex, offset int32) error {
return ErrNotSupported return ErrNotSupported
} }
ctx := c.backend.NewContext()
defer ctx.Close()
seqRange := c.cellRanges[seq] seqRange := c.cellRanges[seq]
size := seqRange.max - seqRange.min + 1
for start := seqRange.min; start <= seqRange.max; start += c.maxBatch { offsets := make([]int32, size)
size := min(seqRange.max-start+1, c.maxBatch) for i := range offsets {
offsets := make([]int32, size) cell := c.cells[seqRange.min+i]
var batchFirst, batchLast int if slices.Contains(cell.sequences, seq) && cell.pos >= beginIndex {
offsets[i] = offset
batchFirst = -1
for i := range offsets {
cell := c.cells[start+i]
if slices.Contains(cell.sequences, seq) && cell.pos >= beginIndex {
offsets[i] = offset
if batchFirst < 0 {
batchFirst = i
}
batchLast = i
}
} }
}
if batchFirst < 0 { kShift, err := ctx.Input().FromIntSlice(offsets, len(offsets))
if err != nil {
return err
}
for i, key := range c.keys {
if key == nil {
continue continue
} }
offsets = offsets[batchFirst : batchLast+1] kHeadDim := key.Dim(0)
numKVHeads := key.Dim(1)
rowSize := key.Stride(2)
ctx := c.backend.NewContext() key = key.View(ctx, rowSize*seqRange.min,
kShift := ctx.Input().FromIntSlice(offsets, len(offsets)) kHeadDim, key.Stride(1),
numKVHeads, key.Stride(2),
size,
)
for i, key := range c.keys { roped, err := c.shiftFn(ctx, i, key, kShift)
if key == nil { if err != nil {
continue return err
}
kHeadDim := key.Dim(0)
numKVHeads := key.Dim(1)
rowSize := key.Stride(2)
key = key.View(ctx, rowSize*(start+batchFirst),
kHeadDim, key.Stride(1),
numKVHeads, key.Stride(2),
len(offsets),
)
roped, err := c.shiftFn(ctx, i, key, kShift)
if err != nil {
ctx.Close()
return err
}
ctx.Forward(roped.Copy(ctx, key))
} }
ctx.Compute() ctx.Forward(roped.Copy(ctx, key))
ctx.Close()
} }
ctx.Compute()
return nil return nil
} }

View File

@@ -344,7 +344,7 @@ func testCache(t *testing.T, backend ml.Backend, cache Cache, tests []testCase)
} }
cache.SetLayer(0) cache.SetLayer(0)
tensor := context.FromFloatSlice(test.in, test.inShape...) tensor, _ := context.FromFloatSlice(test.in, test.inShape...)
cache.Put(context, tensor, tensor) cache.Put(context, tensor, tensor)
out, _, mask := cache.Get(context) out, _, mask := cache.Get(context)
@@ -386,7 +386,7 @@ func TestCanResume(t *testing.T) {
} }
cache.SetLayer(0) cache.SetLayer(0)
tensor := context.FromFloatSlice([]float32{1, 2, 3, 4}, 1, 1, 4) tensor, _ := context.FromFloatSlice([]float32{1, 2, 3, 4}, 1, 1, 4)
cache.Put(context, tensor, tensor) cache.Put(context, tensor, tensor)
// with window size 4, nothing has slid out of the window yet // with window size 4, nothing has slid out of the window yet
@@ -413,7 +413,7 @@ func TestCanResume(t *testing.T) {
} }
cache.SetLayer(0) cache.SetLayer(0)
tensor = context.FromFloatSlice([]float32{5, 6}, 1, 1, 2) tensor, _ = context.FromFloatSlice([]float32{5, 6}, 1, 1, 2)
cache.Put(context, tensor, tensor) cache.Put(context, tensor, tensor)
// only the latest position has overlapping windows // only the latest position has overlapping windows
@@ -470,24 +470,24 @@ func (c *testContext) Zeros(dtype ml.DType, shape ...int) ml.Tensor {
return c.Empty(dtype, shape...) return c.Empty(dtype, shape...)
} }
func (c *testContext) FromFloatSlice(s []float32, shape ...int) ml.Tensor { func (c *testContext) FromFloatSlice(s []float32, shape ...int) (ml.Tensor, error) {
t := c.Empty(ml.DTypeF32, shape...).(*testTensor) t := c.Empty(ml.DTypeF32, shape...).(*testTensor)
copy(t.data, s) copy(t.data, s)
return t return t, nil
} }
func (c *testContext) FromIntSlice(s []int32, shape ...int) ml.Tensor { func (c *testContext) FromIntSlice(s []int32, shape ...int) (ml.Tensor, error) {
f := make([]float32, len(s)) f := make([]float32, len(s))
for i := range f { for i := range f {
f[i] = float32(s[i]) f[i] = float32(s[i])
} }
out := c.FromFloatSlice(f, shape...) out, _ := c.FromFloatSlice(f, shape...)
out.(*testTensor).dtype = ml.DTypeI32 out.(*testTensor).dtype = ml.DTypeI32
return out return out, nil
} }
func (c *testContext) Arange(start, stop, step float32, dtype ml.DType) ml.Tensor { func (c *testContext) Arange(start, stop, step float32, dtype ml.DType) ml.Tensor {
@@ -496,7 +496,7 @@ func (c *testContext) Arange(start, stop, step float32, dtype ml.DType) ml.Tenso
s = append(s, i) s = append(s, i)
} }
out := c.FromFloatSlice(s, len(s)) out, _ := c.FromFloatSlice(s, len(s))
out.(*testTensor).dtype = dtype out.(*testTensor).dtype = dtype
return out return out
} }
@@ -508,7 +508,7 @@ func (c *testContext) Forward(...ml.Tensor) ml.Context { return c }
func (c *testContext) Compute(...ml.Tensor) {} func (c *testContext) Compute(...ml.Tensor) {}
func (c *testContext) Reserve() {} func (c *testContext) Reserve() error { return nil }
func (c *testContext) MaxGraphNodes() int { func (c *testContext) MaxGraphNodes() int {
return 10 return 10

2
llama/build-info.cpp generated vendored
View File

@@ -1,4 +1,4 @@
int LLAMA_BUILD_NUMBER = 0; int LLAMA_BUILD_NUMBER = 0;
char const *LLAMA_COMMIT = "de4c07f93783a1a96456a44dc16b9db538ee1618"; char const *LLAMA_COMMIT = "e1e8e0991ffd9e99a445c6812bb519d5bac9f4b5";
char const *LLAMA_COMPILER = ""; char const *LLAMA_COMPILER = "";
char const *LLAMA_BUILD_TARGET = ""; char const *LLAMA_BUILD_TARGET = "";

View File

@@ -10,11 +10,11 @@ include common/stb_image.*
include include/ include include/
include include/llama.* include include/llama.*
include include/llama-*.* include include/llama-*.*
include tools/ include examples/
include tools/mtmd/ include examples/llava/
include tools/mtmd/clip.* include examples/llava/clip.*
include tools/mtmd/clip-impl.* include examples/llava/clip-impl.*
include tools/mtmd/llava.* include examples/llava/llava.*
include src/ include src/
include src/llama.* include src/llama.*
include src/llama-*.* include src/llama-*.*

View File

@@ -1096,6 +1096,7 @@ struct llama_context_params common_context_params_to_llama(const common_params &
cparams.n_threads = params.cpuparams.n_threads; cparams.n_threads = params.cpuparams.n_threads;
cparams.n_threads_batch = params.cpuparams_batch.n_threads == -1 ? cparams.n_threads_batch = params.cpuparams_batch.n_threads == -1 ?
params.cpuparams.n_threads : params.cpuparams_batch.n_threads; params.cpuparams.n_threads : params.cpuparams_batch.n_threads;
cparams.logits_all = params.logits_all;
cparams.embeddings = params.embedding; cparams.embeddings = params.embedding;
cparams.rope_scaling_type = params.rope_scaling_type; cparams.rope_scaling_type = params.rope_scaling_type;
cparams.rope_freq_base = params.rope_freq_base; cparams.rope_freq_base = params.rope_freq_base;
@@ -1113,7 +1114,6 @@ struct llama_context_params common_context_params_to_llama(const common_params &
cparams.offload_kqv = !params.no_kv_offload; cparams.offload_kqv = !params.no_kv_offload;
cparams.flash_attn = params.flash_attn; cparams.flash_attn = params.flash_attn;
cparams.no_perf = params.no_perf; cparams.no_perf = params.no_perf;
cparams.op_offload = !params.no_op_offload;
if (params.reranking) { if (params.reranking) {
cparams.embeddings = true; cparams.embeddings = true;
@@ -1565,20 +1565,3 @@ common_control_vector_data common_control_vector_load(const std::vector<common_c
return result; return result;
} }
ggml_opt_dataset_t common_opt_dataset_init(struct llama_context * ctx, const std::vector<llama_token> & tokens, int64_t stride) {
const int64_t ne_datapoint = llama_n_ctx(ctx);
const int64_t ndata = (tokens.size() - ne_datapoint - 1) / stride;
ggml_opt_dataset_t result = ggml_opt_dataset_init(
GGML_TYPE_I32, GGML_TYPE_I32, ne_datapoint, ne_datapoint, ndata, /*ndata_shard =*/ 1);
llama_token * data = (llama_token *) ggml_opt_dataset_data(result)->data;
llama_token * labels = (llama_token *) ggml_opt_dataset_labels(result)->data;
for (int64_t idata = 0; idata < ndata; ++idata) {
memcpy(data + idata*ne_datapoint, tokens.data() + idata*stride + 0, ne_datapoint*sizeof(llama_token));
memcpy(labels + idata*ne_datapoint, tokens.data() + idata*stride + 1, ne_datapoint*sizeof(llama_token));
}
return result;
}

View File

@@ -66,6 +66,7 @@ enum llama_example {
LLAMA_EXAMPLE_COMMON, LLAMA_EXAMPLE_COMMON,
LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SPECULATIVE,
LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_MAIN,
LLAMA_EXAMPLE_INFILL,
LLAMA_EXAMPLE_EMBEDDING, LLAMA_EXAMPLE_EMBEDDING,
LLAMA_EXAMPLE_PERPLEXITY, LLAMA_EXAMPLE_PERPLEXITY,
LLAMA_EXAMPLE_RETRIEVAL, LLAMA_EXAMPLE_RETRIEVAL,
@@ -95,7 +96,6 @@ enum common_sampler_type {
COMMON_SAMPLER_TYPE_XTC = 8, COMMON_SAMPLER_TYPE_XTC = 8,
COMMON_SAMPLER_TYPE_INFILL = 9, COMMON_SAMPLER_TYPE_INFILL = 9,
COMMON_SAMPLER_TYPE_PENALTIES = 10, COMMON_SAMPLER_TYPE_PENALTIES = 10,
COMMON_SAMPLER_TYPE_TOP_N_SIGMA = 11,
}; };
// dimensionality reduction methods, used by cvector-generator // dimensionality reduction methods, used by cvector-generator
@@ -161,7 +161,6 @@ struct common_params_sampling {
std::vector<enum common_sampler_type> samplers = { std::vector<enum common_sampler_type> samplers = {
COMMON_SAMPLER_TYPE_PENALTIES, COMMON_SAMPLER_TYPE_PENALTIES,
COMMON_SAMPLER_TYPE_DRY, COMMON_SAMPLER_TYPE_DRY,
COMMON_SAMPLER_TYPE_TOP_N_SIGMA,
COMMON_SAMPLER_TYPE_TOP_K, COMMON_SAMPLER_TYPE_TOP_K,
COMMON_SAMPLER_TYPE_TYPICAL_P, COMMON_SAMPLER_TYPE_TYPICAL_P,
COMMON_SAMPLER_TYPE_TOP_P, COMMON_SAMPLER_TYPE_TOP_P,
@@ -324,6 +323,7 @@ struct common_params {
bool ctx_shift = true; // context shift on inifinite text generation bool ctx_shift = true; // context shift on inifinite text generation
bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix
bool logits_all = false; // return logits for all tokens in the batch
bool use_mmap = true; // use mmap for faster loads bool use_mmap = true; // use mmap for faster loads
bool use_mlock = false; // use mlock to keep model in memory bool use_mlock = false; // use mlock to keep model in memory
bool verbose_prompt = false; // print prompt tokens before generation bool verbose_prompt = false; // print prompt tokens before generation
@@ -332,7 +332,6 @@ struct common_params {
bool no_kv_offload = false; // disable KV offloading bool no_kv_offload = false; // disable KV offloading
bool warmup = true; // warmup run bool warmup = true; // warmup run
bool check_tensors = false; // validate tensor data bool check_tensors = false; // validate tensor data
bool no_op_offload = false; // globally disable offload host tensor operations to device
bool single_turn = false; // single turn chat conversation bool single_turn = false; // single turn chat conversation
@@ -341,7 +340,7 @@ struct common_params {
common_conversation_mode conversation_mode = COMMON_CONVERSATION_MODE_AUTO; common_conversation_mode conversation_mode = COMMON_CONVERSATION_MODE_AUTO;
// multimodal models (see tools/mtmd) // multimodal models (see examples/llava)
struct common_params_model mmproj; struct common_params_model mmproj;
bool mmproj_use_gpu = true; // use GPU for multimodal model bool mmproj_use_gpu = true; // use GPU for multimodal model
bool no_mmproj = false; // explicitly disable multimodal model bool no_mmproj = false; // explicitly disable multimodal model
@@ -410,14 +409,13 @@ struct common_params {
bool process_output = false; // collect data for the output tensor bool process_output = false; // collect data for the output tensor
bool compute_ppl = true; // whether to compute perplexity bool compute_ppl = true; // whether to compute perplexity
bool parse_special = false; // whether to parse special tokens during imatrix tokenization
// cvector-generator params // cvector-generator params
int n_pca_batch = 100; int n_pca_batch = 100;
int n_pca_iterations = 1000; int n_pca_iterations = 1000;
dimre_method cvector_dimre_method = DIMRE_METHOD_PCA; dimre_method cvector_dimre_method = DIMRE_METHOD_PCA;
std::string cvector_positive_file = "tools/cvector-generator/positive.txt"; std::string cvector_positive_file = "examples/cvector-generator/positive.txt";
std::string cvector_negative_file = "tools/cvector-generator/negative.txt"; std::string cvector_negative_file = "examples/cvector-generator/negative.txt";
bool spm_infill = false; // suffix/prefix/middle pattern for infill bool spm_infill = false; // suffix/prefix/middle pattern for infill
@@ -666,9 +664,3 @@ const char * const LLM_KV_SPLIT_COUNT = "split.count";
const char * const LLM_KV_SPLIT_TENSORS_COUNT = "split.tensors.count"; const char * const LLM_KV_SPLIT_TENSORS_COUNT = "split.tensors.count";
} }
//
// training utils
//
ggml_opt_dataset_t common_opt_dataset_init(struct llama_context * ctx, const std::vector<llama_token> & tokens, int64_t stride);

View File

@@ -1,7 +1,6 @@
#include "sampling.h" #include "sampling.h"
#include "common.h" #include "common.h"
#include "log.h"
#include <cmath> #include <cmath>
#include <unordered_map> #include <unordered_map>
@@ -230,48 +229,51 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
params.logit_bias.data())); params.logit_bias.data()));
if (params.mirostat == 0) { if (params.mirostat == 0) {
for (const auto & cnstr : params.samplers) { if (params.top_n_sigma >= 0) {
switch (cnstr) { llama_sampler_chain_add(result->chain, llama_sampler_init_top_k (params.top_k));
case COMMON_SAMPLER_TYPE_DRY: llama_sampler_chain_add(result->chain, llama_sampler_init_temp (params.temp));
{ llama_sampler_chain_add(result->chain, llama_sampler_init_top_n_sigma (params.top_n_sigma));
std::vector<const char *> c_breakers; } else {
c_breakers.reserve(params.dry_sequence_breakers.size()); for (const auto & cnstr : params.samplers) {
for (const auto & str : params.dry_sequence_breakers) { switch (cnstr) {
c_breakers.push_back(str.c_str()); case COMMON_SAMPLER_TYPE_DRY:
} {
std::vector<const char *> c_breakers;
c_breakers.reserve(params.dry_sequence_breakers.size());
for (const auto & str : params.dry_sequence_breakers) {
c_breakers.push_back(str.c_str());
}
llama_sampler_chain_add(result->chain, llama_sampler_init_dry (vocab, llama_model_n_ctx_train(model), params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size())); llama_sampler_chain_add(result->chain, llama_sampler_init_dry (vocab, llama_model_n_ctx_train(model), params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size()));
} }
break; break;
case COMMON_SAMPLER_TYPE_TOP_K: case COMMON_SAMPLER_TYPE_TOP_K:
llama_sampler_chain_add(result->chain, llama_sampler_init_top_k (params.top_k)); llama_sampler_chain_add(result->chain, llama_sampler_init_top_k (params.top_k));
break; break;
case COMMON_SAMPLER_TYPE_TOP_P: case COMMON_SAMPLER_TYPE_TOP_P:
llama_sampler_chain_add(result->chain, llama_sampler_init_top_p (params.top_p, params.min_keep)); llama_sampler_chain_add(result->chain, llama_sampler_init_top_p (params.top_p, params.min_keep));
break; break;
case COMMON_SAMPLER_TYPE_TOP_N_SIGMA: case COMMON_SAMPLER_TYPE_MIN_P:
llama_sampler_chain_add(result->chain, llama_sampler_init_top_n_sigma (params.top_n_sigma)); llama_sampler_chain_add(result->chain, llama_sampler_init_min_p (params.min_p, params.min_keep));
break; break;
case COMMON_SAMPLER_TYPE_MIN_P: case COMMON_SAMPLER_TYPE_XTC:
llama_sampler_chain_add(result->chain, llama_sampler_init_min_p (params.min_p, params.min_keep)); llama_sampler_chain_add(result->chain, llama_sampler_init_xtc (params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed));
break; break;
case COMMON_SAMPLER_TYPE_XTC: case COMMON_SAMPLER_TYPE_TYPICAL_P:
llama_sampler_chain_add(result->chain, llama_sampler_init_xtc (params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed)); llama_sampler_chain_add(result->chain, llama_sampler_init_typical (params.typ_p, params.min_keep));
break; break;
case COMMON_SAMPLER_TYPE_TYPICAL_P: case COMMON_SAMPLER_TYPE_TEMPERATURE:
llama_sampler_chain_add(result->chain, llama_sampler_init_typical (params.typ_p, params.min_keep)); llama_sampler_chain_add(result->chain, llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent));
break; break;
case COMMON_SAMPLER_TYPE_TEMPERATURE: case COMMON_SAMPLER_TYPE_INFILL:
llama_sampler_chain_add(result->chain, llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent)); llama_sampler_chain_add(result->chain, llama_sampler_init_infill (vocab));
break; break;
case COMMON_SAMPLER_TYPE_INFILL: case COMMON_SAMPLER_TYPE_PENALTIES:
llama_sampler_chain_add(result->chain, llama_sampler_init_infill (vocab)); llama_sampler_chain_add(result->chain, llama_sampler_init_penalties(params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present));
break; break;
case COMMON_SAMPLER_TYPE_PENALTIES: default:
llama_sampler_chain_add(result->chain, llama_sampler_init_penalties (params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present)); GGML_ASSERT(false && "unknown sampler type");
break; }
default:
GGML_ASSERT(false && "unknown sampler type");
} }
} }
llama_sampler_chain_add(result->chain, llama_sampler_init_dist(params.seed)); llama_sampler_chain_add(result->chain, llama_sampler_init_dist(params.seed));
@@ -473,7 +475,6 @@ char common_sampler_type_to_chr(enum common_sampler_type cnstr) {
case COMMON_SAMPLER_TYPE_TOP_K: return 'k'; case COMMON_SAMPLER_TYPE_TOP_K: return 'k';
case COMMON_SAMPLER_TYPE_TYPICAL_P: return 'y'; case COMMON_SAMPLER_TYPE_TYPICAL_P: return 'y';
case COMMON_SAMPLER_TYPE_TOP_P: return 'p'; case COMMON_SAMPLER_TYPE_TOP_P: return 'p';
case COMMON_SAMPLER_TYPE_TOP_N_SIGMA: return 's';
case COMMON_SAMPLER_TYPE_MIN_P: return 'm'; case COMMON_SAMPLER_TYPE_MIN_P: return 'm';
case COMMON_SAMPLER_TYPE_TEMPERATURE: return 't'; case COMMON_SAMPLER_TYPE_TEMPERATURE: return 't';
case COMMON_SAMPLER_TYPE_XTC: return 'x'; case COMMON_SAMPLER_TYPE_XTC: return 'x';
@@ -489,7 +490,6 @@ std::string common_sampler_type_to_str(enum common_sampler_type cnstr) {
case COMMON_SAMPLER_TYPE_TOP_K: return "top_k"; case COMMON_SAMPLER_TYPE_TOP_K: return "top_k";
case COMMON_SAMPLER_TYPE_TYPICAL_P: return "typ_p"; case COMMON_SAMPLER_TYPE_TYPICAL_P: return "typ_p";
case COMMON_SAMPLER_TYPE_TOP_P: return "top_p"; case COMMON_SAMPLER_TYPE_TOP_P: return "top_p";
case COMMON_SAMPLER_TYPE_TOP_N_SIGMA: return "top_n_sigma";
case COMMON_SAMPLER_TYPE_MIN_P: return "min_p"; case COMMON_SAMPLER_TYPE_MIN_P: return "min_p";
case COMMON_SAMPLER_TYPE_TEMPERATURE: return "temperature"; case COMMON_SAMPLER_TYPE_TEMPERATURE: return "temperature";
case COMMON_SAMPLER_TYPE_XTC: return "xtc"; case COMMON_SAMPLER_TYPE_XTC: return "xtc";
@@ -504,7 +504,6 @@ std::vector<common_sampler_type> common_sampler_types_from_names(const std::vect
{ "dry", COMMON_SAMPLER_TYPE_DRY }, { "dry", COMMON_SAMPLER_TYPE_DRY },
{ "top_k", COMMON_SAMPLER_TYPE_TOP_K }, { "top_k", COMMON_SAMPLER_TYPE_TOP_K },
{ "top_p", COMMON_SAMPLER_TYPE_TOP_P }, { "top_p", COMMON_SAMPLER_TYPE_TOP_P },
{ "top_n_sigma", COMMON_SAMPLER_TYPE_TOP_N_SIGMA },
{ "typ_p", COMMON_SAMPLER_TYPE_TYPICAL_P }, { "typ_p", COMMON_SAMPLER_TYPE_TYPICAL_P },
{ "min_p", COMMON_SAMPLER_TYPE_MIN_P }, { "min_p", COMMON_SAMPLER_TYPE_MIN_P },
{ "temperature", COMMON_SAMPLER_TYPE_TEMPERATURE }, { "temperature", COMMON_SAMPLER_TYPE_TEMPERATURE },
@@ -518,7 +517,6 @@ std::vector<common_sampler_type> common_sampler_types_from_names(const std::vect
std::unordered_map<std::string, common_sampler_type> sampler_alt_name_map { std::unordered_map<std::string, common_sampler_type> sampler_alt_name_map {
{ "top-k", COMMON_SAMPLER_TYPE_TOP_K }, { "top-k", COMMON_SAMPLER_TYPE_TOP_K },
{ "top-p", COMMON_SAMPLER_TYPE_TOP_P }, { "top-p", COMMON_SAMPLER_TYPE_TOP_P },
{ "top-n-sigma", COMMON_SAMPLER_TYPE_TOP_N_SIGMA },
{ "nucleus", COMMON_SAMPLER_TYPE_TOP_P }, { "nucleus", COMMON_SAMPLER_TYPE_TOP_P },
{ "typical-p", COMMON_SAMPLER_TYPE_TYPICAL_P }, { "typical-p", COMMON_SAMPLER_TYPE_TYPICAL_P },
{ "typical", COMMON_SAMPLER_TYPE_TYPICAL_P }, { "typical", COMMON_SAMPLER_TYPE_TYPICAL_P },
@@ -535,16 +533,14 @@ std::vector<common_sampler_type> common_sampler_types_from_names(const std::vect
auto sampler = sampler_canonical_name_map.find(name); auto sampler = sampler_canonical_name_map.find(name);
if (sampler != sampler_canonical_name_map.end()) { if (sampler != sampler_canonical_name_map.end()) {
samplers.push_back(sampler->second); samplers.push_back(sampler->second);
continue; } else {
} if (allow_alt_names) {
if (allow_alt_names) { sampler = sampler_alt_name_map.find(name);
sampler = sampler_alt_name_map.find(name); if (sampler != sampler_alt_name_map.end()) {
if (sampler != sampler_alt_name_map.end()) { samplers.push_back(sampler->second);
samplers.push_back(sampler->second); }
continue;
} }
} }
LOG_WRN("%s: unable to match sampler by name '%s'\n", __func__, name.c_str());
} }
return samplers; return samplers;
@@ -556,7 +552,6 @@ std::vector<common_sampler_type> common_sampler_types_from_chars(const std::stri
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TOP_K), COMMON_SAMPLER_TYPE_TOP_K }, { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TOP_K), COMMON_SAMPLER_TYPE_TOP_K },
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TYPICAL_P), COMMON_SAMPLER_TYPE_TYPICAL_P }, { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TYPICAL_P), COMMON_SAMPLER_TYPE_TYPICAL_P },
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TOP_P), COMMON_SAMPLER_TYPE_TOP_P }, { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TOP_P), COMMON_SAMPLER_TYPE_TOP_P },
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TOP_N_SIGMA), COMMON_SAMPLER_TYPE_TOP_N_SIGMA },
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_MIN_P), COMMON_SAMPLER_TYPE_MIN_P }, { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_MIN_P), COMMON_SAMPLER_TYPE_MIN_P },
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TEMPERATURE), COMMON_SAMPLER_TYPE_TEMPERATURE }, { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TEMPERATURE), COMMON_SAMPLER_TYPE_TEMPERATURE },
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_XTC), COMMON_SAMPLER_TYPE_XTC }, { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_XTC), COMMON_SAMPLER_TYPE_XTC },
@@ -571,8 +566,6 @@ std::vector<common_sampler_type> common_sampler_types_from_chars(const std::stri
const auto sampler = sampler_name_map.find(c); const auto sampler = sampler_name_map.find(c);
if (sampler != sampler_name_map.end()) { if (sampler != sampler_name_map.end()) {
samplers.push_back(sampler->second); samplers.push_back(sampler->second);
} else {
LOG_WRN("%s: unable to match sampler by char '%c'\n", __func__, c);
} }
} }

View File

@@ -31,7 +31,9 @@
#define KEY_FEATURE_LAYER "clip.vision.feature_layer" #define KEY_FEATURE_LAYER "clip.vision.feature_layer"
#define KEY_PROJ_SCALE_FACTOR "clip.vision.projector.scale_factor" #define KEY_PROJ_SCALE_FACTOR "clip.vision.projector.scale_factor"
#define KEY_PROJ_TYPE "clip.projector_type" #define KEY_PROJ_TYPE "clip.projector_type"
#define KEY_SPATIAL_MERGE_SIZE "clip.vision.spatial_merge_size"
#define KEY_USE_GLU_MLP "clip.use_glu_mlp" // for qwen2.5vl
#define KEY_USE_RMS_NORM "clip.use_rms_norm" // for qwen2.5vl
#define KEY_MM_PATCH_MERGE_TYPE "clip.vision.mm_patch_merge_type" #define KEY_MM_PATCH_MERGE_TYPE "clip.vision.mm_patch_merge_type"
#define KEY_IMAGE_GRID_PINPOINTS "clip.vision.image_grid_pinpoints" #define KEY_IMAGE_GRID_PINPOINTS "clip.vision.image_grid_pinpoints"
@@ -53,16 +55,12 @@
#define TN_ATTN_Q "%s.blk.%d.attn_q.%s" #define TN_ATTN_Q "%s.blk.%d.attn_q.%s"
#define TN_ATTN_V "%s.blk.%d.attn_v.%s" #define TN_ATTN_V "%s.blk.%d.attn_v.%s"
#define TN_ATTN_OUTPUT "%s.blk.%d.attn_out.%s" #define TN_ATTN_OUTPUT "%s.blk.%d.attn_out.%s"
#define TN_ATTN_K_NORM "%s.blk.%d.attn_k_norm.%s"
#define TN_ATTN_Q_NORM "%s.blk.%d.attn_q_norm.%s"
#define TN_FFN_DOWN "%s.blk.%d.ffn_down.%s" #define TN_FFN_DOWN "%s.blk.%d.ffn_down.%s"
#define TN_FFN_GATE "%s.blk.%d.ffn_gate.%s" #define TN_FFN_GATE "%s.blk.%d.ffn_gate.%s"
#define TN_FFN_UP "%s.blk.%d.ffn_up.%s" #define TN_FFN_UP "%s.blk.%d.ffn_up.%s"
#define TN_FFN_GATE "%s.blk.%d.ffn_gate.%s" #define TN_FFN_GATE "%s.blk.%d.ffn_gate.%s"
#define TN_LN_1 "%s.blk.%d.ln1.%s" // layer norm #define TN_LN_1 "%s.blk.%d.ln1.%s"
#define TN_LN_2 "%s.blk.%d.ln2.%s" // layer norm #define TN_LN_2 "%s.blk.%d.ln2.%s"
#define TN_LS_1 "%s.blk.%d.ls1.%s" // layer scale
#define TN_LS_2 "%s.blk.%d.ls2.%s" // layer scale
#define TN_LN_PRE "%s.pre_ln.%s" #define TN_LN_PRE "%s.pre_ln.%s"
#define TN_LN_POST "%s.post_ln.%s" #define TN_LN_POST "%s.post_ln.%s"
#define TN_LLAVA_PROJ "mm.%d.%s" #define TN_LLAVA_PROJ "mm.%d.%s"
@@ -70,14 +68,10 @@
#define TN_MVLM_PROJ_BLOCK "mm.model.mb_block.%d.block.%d.%s" #define TN_MVLM_PROJ_BLOCK "mm.model.mb_block.%d.block.%d.%s"
#define TN_MVLM_PROJ_PEG "mm.model.peg.%d.%s" #define TN_MVLM_PROJ_PEG "mm.model.peg.%d.%s"
#define TN_IMAGE_NEWLINE "model.image_newline" #define TN_IMAGE_NEWLINE "model.image_newline"
#define TN_MM_INP_NORM "mm.input_norm.weight"
#define TN_MM_INP_PROJ "mm.input_projection.weight" // gemma3 #define TN_MM_INP_PROJ "mm.input_projection.weight" // gemma3
#define TN_MM_SOFT_EMB_N "mm.soft_emb_norm.weight" // gemma3 #define TN_MM_SOFT_EMB_N "mm.soft_emb_norm.weight" // gemma3
#define TN_MM_PROJECTOR "mm.model.fc.weight" // idefics3 #define TN_MM_PROJECTOR "mm.model.fc.weight" // idefics3
#define TN_MM_PATCH_MERGER "mm.patch_merger.weight" // mistral small 3.1
#define TN_TOK_IMG_BREAK "v.token_embd.img_break" // pixtral #define TN_TOK_IMG_BREAK "v.token_embd.img_break" // pixtral
#define TN_TOK_GLM_BOI "adapter.boi" // glm-edge (these embeddings are not in text model)
#define TN_TOK_GLM_EOI "adapter.eoi" // glm-edge (these embeddings are not in text model)
// mimicpmv // mimicpmv
#define TN_MINICPMV_POS_EMBD_K "resampler.pos_embed_k" #define TN_MINICPMV_POS_EMBD_K "resampler.pos_embed_k"
@@ -94,9 +88,6 @@
#define TN_GLM_ADAPTER_GATE "adapter.linear.gate.%s" #define TN_GLM_ADAPTER_GATE "adapter.linear.gate.%s"
#define TN_GLM_ADAPTER_D_4H_2_H "adapter.linear.dense_4h_to_h.%s" #define TN_GLM_ADAPTER_D_4H_2_H "adapter.linear.dense_4h_to_h.%s"
// align x to upper multiple of n
#define CLIP_ALIGN(x, n) ((((x) + (n) - 1) / (n)) * (n))
enum projector_type { enum projector_type {
PROJECTOR_TYPE_MLP, PROJECTOR_TYPE_MLP,
PROJECTOR_TYPE_MLP_NORM, PROJECTOR_TYPE_MLP_NORM,
@@ -109,7 +100,6 @@ enum projector_type {
PROJECTOR_TYPE_IDEFICS3, PROJECTOR_TYPE_IDEFICS3,
PROJECTOR_TYPE_PIXTRAL, PROJECTOR_TYPE_PIXTRAL,
PROJECTOR_TYPE_QWEN25VL, PROJECTOR_TYPE_QWEN25VL,
PROJECTOR_TYPE_INTERNVL,
PROJECTOR_TYPE_UNKNOWN, PROJECTOR_TYPE_UNKNOWN,
}; };
@@ -124,7 +114,6 @@ static std::map<projector_type, std::string> PROJECTOR_TYPE_NAMES = {
{ PROJECTOR_TYPE_GEMMA3, "gemma3"}, { PROJECTOR_TYPE_GEMMA3, "gemma3"},
{ PROJECTOR_TYPE_IDEFICS3, "idefics3"}, { PROJECTOR_TYPE_IDEFICS3, "idefics3"},
{ PROJECTOR_TYPE_PIXTRAL, "pixtral"}, { PROJECTOR_TYPE_PIXTRAL, "pixtral"},
{ PROJECTOR_TYPE_INTERNVL, "internvl"},
}; };
static projector_type clip_projector_type_from_string(const std::string & str) { static projector_type clip_projector_type_from_string(const std::string & str) {
@@ -239,15 +228,6 @@ struct clip_image_u8_batch {
struct clip_image_f32_batch { struct clip_image_f32_batch {
std::vector<clip_image_f32_ptr> entries; std::vector<clip_image_f32_ptr> entries;
clip_image_f32_batch clone() const {
clip_image_f32_batch new_batch;
new_batch.entries.reserve(entries.size());
for (const auto & entry : entries) {
new_batch.entries.emplace_back(new clip_image_f32(*entry));
}
return new_batch;
}
}; };
// //

View File

@@ -78,10 +78,10 @@ CLIP_API int clip_uhd_num_image_embeds_col(struct clip_ctx * ctx_clip);
CLIP_API void clip_add_load_image_size(struct clip_ctx * ctx_clip, struct clip_image_size * load_image_size); CLIP_API void clip_add_load_image_size(struct clip_ctx * ctx_clip, struct clip_image_size * load_image_size);
CLIP_API struct clip_image_size * clip_get_load_image_size(struct clip_ctx * ctx_clip); CLIP_API struct clip_image_size * clip_get_load_image_size(struct clip_ctx * ctx_clip);
CLIP_API struct clip_image_size * clip_image_size_init(void); CLIP_API struct clip_image_size * clip_image_size_init();
CLIP_API struct clip_image_u8 * clip_image_u8_init (void); CLIP_API struct clip_image_u8 * clip_image_u8_init ();
CLIP_API struct clip_image_f32 * clip_image_f32_init(void); CLIP_API struct clip_image_f32 * clip_image_f32_init();
CLIP_API struct clip_image_f32_batch * clip_image_f32_batch_init(void); // only used by libllava CLIP_API struct clip_image_f32_batch * clip_image_f32_batch_init(); // only used by libllava
// nx, ny are the output image dimensions // nx, ny are the output image dimensions
CLIP_API unsigned char * clip_image_u8_get_data(struct clip_image_u8 * img, uint32_t * nx, uint32_t * ny); CLIP_API unsigned char * clip_image_u8_get_data(struct clip_image_u8 * img, uint32_t * nx, uint32_t * ny);

View File

@@ -2,7 +2,6 @@
#include "llava.h" #include "llava.h"
#include "llama.h" #include "llama.h"
#include "ggml-cpp.h"
#include <algorithm> #include <algorithm>
#include <cerrno> #include <cerrno>
@@ -210,11 +209,7 @@ static bool clip_llava_handle_patches(clip_ctx * ctx_clip, std::vector<float *>
struct ggml_tensor *flatten = ggml_view_2d(model.ctx, permuted_cont, clip_n_mmproj_embd(ctx_clip), num_patches_height * num_patches_width * num_patches_per_side * num_patches_per_side, size_ele * clip_n_mmproj_embd(ctx_clip), 0); struct ggml_tensor *flatten = ggml_view_2d(model.ctx, permuted_cont, clip_n_mmproj_embd(ctx_clip), num_patches_height * num_patches_width * num_patches_per_side * num_patches_per_side, size_ele * clip_n_mmproj_embd(ctx_clip), 0);
// ggml_tensor_printf(flatten,"flatten",__LINE__,false,false); // ggml_tensor_printf(flatten,"flatten",__LINE__,false,false);
ggml_build_forward_expand(gf, flatten); ggml_build_forward_expand(gf, flatten);
ggml_graph_compute_with_ctx(model.ctx, gf, 1);
ggml_backend_ptr backend { ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr) };
GGML_ASSERT(backend != nullptr && "failed to initialize CPU backend");
ggml_backend_graph_compute(backend.get(), gf);
struct ggml_tensor* result = ggml_graph_node(gf, -1); struct ggml_tensor* result = ggml_graph_node(gf, -1);
memcpy(image_embd_out, image_embd_v[0], clip_embd_nbytes(ctx_clip)); // main image as global context memcpy(image_embd_out, image_embd_v[0], clip_embd_nbytes(ctx_clip)); // main image as global context
@@ -462,7 +457,7 @@ struct llava_embd_batch {
std::vector<llama_seq_id *> seq_ids; std::vector<llama_seq_id *> seq_ids;
std::vector<int8_t> logits; std::vector<int8_t> logits;
llama_batch batch; llama_batch batch;
llava_embd_batch(float * embd, int32_t n_tokens, llama_pos pos_0, llama_seq_id seq_id) { llava_embd_batch(float * embd, int32_t n_embd, int32_t n_tokens, llama_pos pos_0, llama_seq_id seq_id) {
pos .resize(n_tokens); pos .resize(n_tokens);
n_seq_id.resize(n_tokens); n_seq_id.resize(n_tokens);
seq_ids .resize(n_tokens + 1); seq_ids .resize(n_tokens + 1);
@@ -474,6 +469,7 @@ struct llava_embd_batch {
/*n_tokens =*/ n_tokens, /*n_tokens =*/ n_tokens,
/*tokens =*/ nullptr, /*tokens =*/ nullptr,
/*embd =*/ embd, /*embd =*/ embd,
/*n_embd =*/ n_embd,
/*pos =*/ pos.data(), /*pos =*/ pos.data(),
/*n_seq_id =*/ n_seq_id.data(), /*n_seq_id =*/ n_seq_id.data(),
/*seq_id =*/ seq_ids.data(), /*seq_id =*/ seq_ids.data(),
@@ -497,7 +493,7 @@ bool llava_eval_image_embed(llama_context * ctx_llama, const struct llava_image_
n_eval = n_batch; n_eval = n_batch;
} }
float * embd = image_embed->embed+i*n_embd; float * embd = image_embed->embed+i*n_embd;
llava_embd_batch llava_batch = llava_embd_batch(embd, n_eval, *n_past, 0); llava_embd_batch llava_batch = llava_embd_batch(embd, n_embd, n_eval, *n_past, 0);
if (llama_decode(ctx_llama, llava_batch.batch)) { if (llama_decode(ctx_llama, llava_batch.batch)) {
LOG_ERR("%s : failed to eval\n", __func__); LOG_ERR("%s : failed to eval\n", __func__);
return false; return false;

View File

@@ -1,4 +1,4 @@
package mtmd package llava
// #cgo CXXFLAGS: -std=c++11 // #cgo CXXFLAGS: -std=c++11
// #cgo CPPFLAGS: -I${SRCDIR}/../../include -I${SRCDIR}/../../common // #cgo CPPFLAGS: -I${SRCDIR}/../../include -I${SRCDIR}/../../common

View File

@@ -4,7 +4,6 @@
#include "ggml.h" #include "ggml.h"
#include "ggml-cpu.h" #include "ggml-cpu.h"
#include "ggml-backend.h" #include "ggml-backend.h"
#include "ggml-opt.h"
#include <stddef.h> #include <stddef.h>
#include <stdint.h> #include <stdint.h>
@@ -113,7 +112,6 @@ extern "C" {
LLAMA_VOCAB_PRE_TYPE_BAILINGMOE = 32, LLAMA_VOCAB_PRE_TYPE_BAILINGMOE = 32,
LLAMA_VOCAB_PRE_TYPE_LLAMA4 = 33, LLAMA_VOCAB_PRE_TYPE_LLAMA4 = 33,
LLAMA_VOCAB_PRE_TYPE_PIXTRAL = 34, LLAMA_VOCAB_PRE_TYPE_PIXTRAL = 34,
LLAMA_VOCAB_PRE_TYPE_SEED_CODER = 35,
}; };
enum llama_rope_type { enum llama_rope_type {
@@ -258,6 +256,7 @@ extern "C" {
llama_token * token; llama_token * token;
float * embd; float * embd;
int32_t n_embd;
llama_pos * pos; llama_pos * pos;
int32_t * n_seq_id; int32_t * n_seq_id;
llama_seq_id ** seq_id; llama_seq_id ** seq_id;
@@ -353,18 +352,20 @@ extern "C" {
enum ggml_type type_k; // data type for K cache [EXPERIMENTAL] enum ggml_type type_k; // data type for K cache [EXPERIMENTAL]
enum ggml_type type_v; // data type for V cache [EXPERIMENTAL] enum ggml_type type_v; // data type for V cache [EXPERIMENTAL]
// Keep the booleans together and at the end of the struct to avoid misalignment during copy-by-value.
// TODO: move at the end of the struct
bool logits_all; // the llama_decode() call computes all logits, not just the last one (DEPRECATED - set llama_batch.logits instead)
bool embeddings; // if true, extract embeddings (together with logits)
bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU
bool flash_attn; // whether to use flash attention [EXPERIMENTAL]
bool no_perf; // whether to measure performance timings
bool cross_attn; // whether to use cross attention
// Abort callback // Abort callback
// if it returns true, execution of llama_decode() will be aborted // if it returns true, execution of llama_decode() will be aborted
// currently works only with CPU execution // currently works only with CPU execution
ggml_abort_callback abort_callback; ggml_abort_callback abort_callback;
void * abort_callback_data; void * abort_callback_data;
// Keep the booleans together and at the end of the struct to avoid misalignment during copy-by-value.
bool embeddings; // if true, extract embeddings (together with logits)
bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU
bool flash_attn; // whether to use flash attention [EXPERIMENTAL]
bool no_perf; // whether to measure performance timings
bool op_offload; // whether to offload host tensor operations to device
}; };
// model quantization parameters // model quantization parameters
@@ -446,10 +447,6 @@ extern "C" {
size_t n_paths, size_t n_paths,
struct llama_model_params params); struct llama_model_params params);
LLAMA_API void llama_model_save_to_file(
const struct llama_model * model,
const char * path_model);
DEPRECATED(LLAMA_API void llama_free_model(struct llama_model * model), DEPRECATED(LLAMA_API void llama_free_model(struct llama_model * model),
"use llama_model_free instead"); "use llama_model_free instead");
@@ -464,6 +461,10 @@ extern "C" {
struct llama_context_params params), struct llama_context_params params),
"use llama_init_from_model instead"); "use llama_init_from_model instead");
// TODO (jmorganca): this should most likely be passed in as part of a batch
// and not set on the context for all batches.
LLAMA_API void llama_set_cross_attention(struct llama_context * ctx, bool cross_attn_state);
// Frees all allocated memory // Frees all allocated memory
LLAMA_API void llama_free(struct llama_context * ctx); LLAMA_API void llama_free(struct llama_context * ctx);
@@ -929,19 +930,14 @@ extern "C" {
// Frees a batch of tokens allocated with llama_batch_init() // Frees a batch of tokens allocated with llama_batch_init()
LLAMA_API void llama_batch_free(struct llama_batch batch); LLAMA_API void llama_batch_free(struct llama_batch batch);
// Process a batch of tokens. // Processes a batch of tokens with the ecoder part of the encoder-decoder model.
// In contrast to llama_decode() - this call does not use KV cache. // Stores the encoder output internally for later use by the decoder cross-attention layers.
// For encode-decoder contexts, processes the batch using the encoder.
// Can store the encoder output internally for later use by the decoder's cross-attention layers.
// 0 - success // 0 - success
// < 0 - error. the KV cache state is restored to the state before this call // < 0 - error. the KV cache state is restored to the state before this call
LLAMA_API int32_t llama_encode( LLAMA_API int32_t llama_encode(
struct llama_context * ctx, struct llama_context * ctx,
struct llama_batch batch); struct llama_batch batch);
// Process a batch of tokens.
// Requires KV cache.
// For encode-decoder contexts, processes the batch using the decoder.
// Positive return values does not mean a fatal error, but rather a warning. // Positive return values does not mean a fatal error, but rather a warning.
// 0 - success // 0 - success
// 1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context) // 1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context)
@@ -1438,37 +1434,6 @@ extern "C" {
LLAMA_API void llama_perf_sampler_print(const struct llama_sampler * chain); LLAMA_API void llama_perf_sampler_print(const struct llama_sampler * chain);
LLAMA_API void llama_perf_sampler_reset( struct llama_sampler * chain); LLAMA_API void llama_perf_sampler_reset( struct llama_sampler * chain);
//
// training
//
// function that returns whether or not a given tensor contains trainable parameters
typedef bool (*llama_opt_param_filter)(const struct ggml_tensor * tensor, void * userdata);
// always returns true
LLAMA_API bool llama_opt_param_filter_all(const struct ggml_tensor * tensor, void * userdata);
struct llama_opt_params {
uint32_t n_ctx_train; // assumed context size post training, use context size specified in llama_context if 0
llama_opt_param_filter param_filter; // callback for determining which tensors contain trainable parameters
void * param_filter_ud; // userdata for determining which tensors contain trainable parameters
ggml_opt_get_optimizer_params get_opt_pars; // callback for calculating optimizer parameters
void * get_opt_pars_ud; // userdata for calculating optimizer parameters
};
LLAMA_API void llama_opt_init(struct llama_context * lctx, struct llama_model * model, struct llama_opt_params lopt_params);
LLAMA_API void llama_opt_epoch(
struct llama_context * lctx,
ggml_opt_dataset_t dataset,
ggml_opt_result_t result_train,
ggml_opt_result_t result_eval,
int64_t idata_split,
ggml_opt_epoch_callback callback_train,
ggml_opt_epoch_callback callback_eval);
#ifdef __cplusplus #ifdef __cplusplus
} }
#endif #endif

View File

@@ -253,9 +253,6 @@ static void llama_adapter_lora_init_impl(llama_model & model, const char * path_
std::vector<ggml_backend_buffer_type_t> buft_extra; std::vector<ggml_backend_buffer_type_t> buft_extra;
{ {
auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU); auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
if (!cpu_dev) {
throw std::runtime_error(format("%s: no CPU backend found", __func__));
}
auto * cpu_reg = ggml_backend_dev_backend_reg(cpu_dev); auto * cpu_reg = ggml_backend_dev_backend_reg(cpu_dev);
auto ggml_backend_dev_get_extra_bufts_fn = (ggml_backend_dev_get_extra_bufts_t) auto ggml_backend_dev_get_extra_bufts_fn = (ggml_backend_dev_get_extra_bufts_t)
@@ -294,9 +291,6 @@ static void llama_adapter_lora_init_impl(llama_model & model, const char * path_
LLAMA_LOG_WARN("%s: lora for '%s' cannot use buft '%s', fallback to CPU\n", __func__, model_tensor->name, ggml_backend_buft_name(buft)); LLAMA_LOG_WARN("%s: lora for '%s' cannot use buft '%s', fallback to CPU\n", __func__, model_tensor->name, ggml_backend_buft_name(buft));
auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU); auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
if (!cpu_dev) {
throw std::runtime_error(format("%s: no CPU backend found", __func__));
}
buft = ggml_backend_dev_buffer_type(cpu_dev); buft = ggml_backend_dev_buffer_type(cpu_dev);
break; break;

View File

@@ -6,6 +6,7 @@
static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = { static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_LLAMA, "llama" }, { LLM_ARCH_LLAMA, "llama" },
{ LLM_ARCH_MLLAMA, "mllama" },
{ LLM_ARCH_LLAMA4, "llama4" }, { LLM_ARCH_LLAMA4, "llama4" },
{ LLM_ARCH_DECI, "deci" }, { LLM_ARCH_DECI, "deci" },
{ LLM_ARCH_FALCON, "falcon" }, { LLM_ARCH_FALCON, "falcon" },
@@ -144,6 +145,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
{ LLM_KV_ATTENTION_SLIDING_WINDOW, "%s.attention.sliding_window" }, { LLM_KV_ATTENTION_SLIDING_WINDOW, "%s.attention.sliding_window" },
{ LLM_KV_ATTENTION_SCALE, "%s.attention.scale" }, { LLM_KV_ATTENTION_SCALE, "%s.attention.scale" },
{ LLM_KV_ATTENTION_BLOCK_SKIP_CONNECTION, "%s.attention.block_skip_connection" }, { LLM_KV_ATTENTION_BLOCK_SKIP_CONNECTION, "%s.attention.block_skip_connection" },
{ LLM_KV_ATTENTION_CROSS_ATTENTION_LAYERS, "%s.attention.cross_attention_layers" },
{ LLM_KV_ATTENTION_KEY_LENGTH_MLA, "%s.attention.key_length_mla" }, { LLM_KV_ATTENTION_KEY_LENGTH_MLA, "%s.attention.key_length_mla" },
{ LLM_KV_ATTENTION_VALUE_LENGTH_MLA, "%s.attention.value_length_mla" }, { LLM_KV_ATTENTION_VALUE_LENGTH_MLA, "%s.attention.value_length_mla" },
@@ -273,6 +275,40 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" }, { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
}, },
}, },
{
LLM_ARCH_MLLAMA,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ROPE_FREQS, "rope_freqs" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" },
{ LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
{ LLM_TENSOR_FFN_GATE_EXP, "blk.%d.ffn_gate.%d" },
{ LLM_TENSOR_FFN_DOWN_EXP, "blk.%d.ffn_down.%d" },
{ LLM_TENSOR_FFN_UP_EXP, "blk.%d.ffn_up.%d" },
{ LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
{ LLM_TENSOR_CROSS_ATTN_K_NORM, "blk.%d.cross_attn_k_norm" },
{ LLM_TENSOR_CROSS_ATTN_K_PROJ, "blk.%d.cross_attn_k_proj" },
{ LLM_TENSOR_CROSS_ATTN_O_PROJ, "blk.%d.cross_attn_o_proj" },
{ LLM_TENSOR_CROSS_ATTN_Q_NORM, "blk.%d.cross_attn_q_norm" },
{ LLM_TENSOR_CROSS_ATTN_Q_PROJ, "blk.%d.cross_attn_q_proj" },
{ LLM_TENSOR_CROSS_ATTN_V_PROJ, "blk.%d.cross_attn_v_proj" },
{ LLM_TENSOR_CROSS_ATTN_ATTN_GATE, "blk.%d.cross_attn_attn_gate" },
{ LLM_TENSOR_CROSS_ATTN_MLP_GATE, "blk.%d.cross_attn_mlp_gate" },
},
},
{ {
LLM_ARCH_DECI, LLM_ARCH_DECI,
{ {
@@ -1701,6 +1737,14 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
// this tensor is loaded for T5, but never used // this tensor is loaded for T5, but never used
{LLM_TENSOR_DEC_CROSS_ATTN_REL_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_NONE}}, {LLM_TENSOR_DEC_CROSS_ATTN_REL_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_NONE}},
{LLM_TENSOR_BSKCN_TV, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_BSKCN_TV, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
{LLM_TENSOR_CROSS_ATTN_K_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
{LLM_TENSOR_CROSS_ATTN_K_PROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_CROSS_ATTN_O_PROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_CROSS_ATTN_Q_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
{LLM_TENSOR_CROSS_ATTN_Q_PROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_CROSS_ATTN_V_PROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_CROSS_ATTN_ATTN_GATE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
{LLM_TENSOR_CROSS_ATTN_MLP_GATE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
{LLM_TENSOR_CONV1D, {LLM_TENSOR_LAYER_INPUT, GGML_OP_IM2COL}}, {LLM_TENSOR_CONV1D, {LLM_TENSOR_LAYER_INPUT, GGML_OP_IM2COL}},
{LLM_TENSOR_POS_NET_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_POS_NET_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
{LLM_TENSOR_POS_NET_NORM1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_POS_NET_NORM1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},

View File

@@ -11,6 +11,7 @@
enum llm_arch { enum llm_arch {
LLM_ARCH_LLAMA, LLM_ARCH_LLAMA,
LLM_ARCH_LLAMA4, LLM_ARCH_LLAMA4,
LLM_ARCH_MLLAMA,
LLM_ARCH_DECI, LLM_ARCH_DECI,
LLM_ARCH_FALCON, LLM_ARCH_FALCON,
LLM_ARCH_BAICHUAN, LLM_ARCH_BAICHUAN,
@@ -148,6 +149,7 @@ enum llm_kv {
LLM_KV_ATTENTION_SLIDING_WINDOW, LLM_KV_ATTENTION_SLIDING_WINDOW,
LLM_KV_ATTENTION_SCALE, LLM_KV_ATTENTION_SCALE,
LLM_KV_ATTENTION_BLOCK_SKIP_CONNECTION, LLM_KV_ATTENTION_BLOCK_SKIP_CONNECTION,
LLM_KV_ATTENTION_CROSS_ATTENTION_LAYERS,
LLM_KV_ATTENTION_KEY_LENGTH_MLA, LLM_KV_ATTENTION_KEY_LENGTH_MLA,
LLM_KV_ATTENTION_VALUE_LENGTH_MLA, LLM_KV_ATTENTION_VALUE_LENGTH_MLA,
@@ -349,6 +351,14 @@ enum llm_tensor {
LLM_TENSOR_CLS, LLM_TENSOR_CLS,
LLM_TENSOR_CLS_OUT, LLM_TENSOR_CLS_OUT,
LLM_TENSOR_BSKCN_TV, LLM_TENSOR_BSKCN_TV,
LLM_TENSOR_CROSS_ATTN_K_NORM,
LLM_TENSOR_CROSS_ATTN_K_PROJ,
LLM_TENSOR_CROSS_ATTN_O_PROJ,
LLM_TENSOR_CROSS_ATTN_Q_NORM,
LLM_TENSOR_CROSS_ATTN_Q_PROJ,
LLM_TENSOR_CROSS_ATTN_V_PROJ,
LLM_TENSOR_CROSS_ATTN_ATTN_GATE,
LLM_TENSOR_CROSS_ATTN_MLP_GATE,
LLM_TENSOR_CONV1D, LLM_TENSOR_CONV1D,
LLM_TENSOR_CONVNEXT_DW, LLM_TENSOR_CONVNEXT_DW,
LLM_TENSOR_CONVNEXT_NORM, LLM_TENSOR_CONVNEXT_NORM,

View File

@@ -189,7 +189,7 @@ llama_ubatch llama_sbatch::split_seq(size_t n_ubatch) {
return ubatch; return ubatch;
} }
llama_sbatch::llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple_split, bool logits_all) { void llama_sbatch::from_batch(const llama_batch & batch, size_t n_embd, bool simple_split, bool logits_all) {
GGML_ASSERT(batch.n_tokens >= 0); GGML_ASSERT(batch.n_tokens >= 0);
this->batch = &batch; this->batch = &batch;
this->n_embd = n_embd; this->n_embd = n_embd;
@@ -203,7 +203,6 @@ llama_sbatch::llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple
for (size_t i = 0; i < n_tokens; ++i) { for (size_t i = 0; i < n_tokens; ++i) {
ids[i] = i; ids[i] = i;
} }
if (simple_split) { if (simple_split) {
seq.resize(1); seq.resize(1);
llama_sbatch_seq & s = seq[0]; llama_sbatch_seq & s = seq[0];
@@ -213,7 +212,6 @@ llama_sbatch::llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple
s.length = n_tokens; s.length = n_tokens;
return; return;
} }
std::sort(ids.begin(), ids.end(), std::sort(ids.begin(), ids.end(),
[&batch](size_t a, size_t b) { [&batch](size_t a, size_t b) {
int32_t n_seq_a = batch.n_seq_id ? batch.n_seq_id[a] : 1; int32_t n_seq_a = batch.n_seq_id ? batch.n_seq_id[a] : 1;
@@ -241,7 +239,6 @@ llama_sbatch::llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple
return n_seq_a > n_seq_b; return n_seq_a > n_seq_b;
} }
); );
// init seq // init seq
llama_sbatch_seq * last_seq = nullptr; llama_sbatch_seq * last_seq = nullptr;
@@ -265,7 +262,6 @@ llama_sbatch::llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple
seq.push_back(new_seq); seq.push_back(new_seq);
last_seq = &seq.back(); last_seq = &seq.back();
} }
// keep shared prompts first at the end, then sort by length descending. // keep shared prompts first at the end, then sort by length descending.
std::sort(seq.begin(), seq.end(), std::sort(seq.begin(), seq.end(),
[](llama_sbatch_seq & a, llama_sbatch_seq & b) { [](llama_sbatch_seq & a, llama_sbatch_seq & b) {
@@ -320,6 +316,7 @@ struct llama_batch llama_batch_get_one(
/*n_tokens =*/ n_tokens, /*n_tokens =*/ n_tokens,
/*tokens =*/ tokens, /*tokens =*/ tokens,
/*embd =*/ nullptr, /*embd =*/ nullptr,
/*n_embd =*/ 0,
/*pos =*/ nullptr, /*pos =*/ nullptr,
/*n_seq_id =*/ nullptr, /*n_seq_id =*/ nullptr,
/*seq_id =*/ nullptr, /*seq_id =*/ nullptr,
@@ -332,6 +329,7 @@ struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_
/*n_tokens =*/ 0, /*n_tokens =*/ 0,
/*tokens =*/ nullptr, /*tokens =*/ nullptr,
/*embd =*/ nullptr, /*embd =*/ nullptr,
/*n_embd =*/ 0,
/*pos =*/ nullptr, /*pos =*/ nullptr,
/*n_seq_id =*/ nullptr, /*n_seq_id =*/ nullptr,
/*seq_id =*/ nullptr, /*seq_id =*/ nullptr,
@@ -340,6 +338,7 @@ struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_
if (embd) { if (embd) {
batch.embd = (float *) malloc(sizeof(float) * n_tokens_alloc * embd); batch.embd = (float *) malloc(sizeof(float) * n_tokens_alloc * embd);
batch.n_embd = embd;
} else { } else {
batch.token = (llama_token *) malloc(sizeof(llama_token) * n_tokens_alloc); batch.token = (llama_token *) malloc(sizeof(llama_token) * n_tokens_alloc);
} }

View File

@@ -70,8 +70,7 @@ struct llama_sbatch {
// sequence-wise split // sequence-wise split
llama_ubatch split_seq(size_t n_ubatch); llama_ubatch split_seq(size_t n_ubatch);
llama_sbatch() = default; void from_batch(const llama_batch & batch, size_t n_embd, bool simple_split = false, bool logits_all = false);
llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple_split = false, bool logits_all = false);
}; };
// temporary allocate memory for the input batch if needed // temporary allocate memory for the input batch if needed

View File

@@ -35,7 +35,6 @@ static const std::map<std::string, llm_chat_template> LLM_CHAT_TEMPLATES = {
{ "mistral-v3", LLM_CHAT_TEMPLATE_MISTRAL_V3 }, { "mistral-v3", LLM_CHAT_TEMPLATE_MISTRAL_V3 },
{ "mistral-v3-tekken", LLM_CHAT_TEMPLATE_MISTRAL_V3_TEKKEN }, { "mistral-v3-tekken", LLM_CHAT_TEMPLATE_MISTRAL_V3_TEKKEN },
{ "mistral-v7", LLM_CHAT_TEMPLATE_MISTRAL_V7 }, { "mistral-v7", LLM_CHAT_TEMPLATE_MISTRAL_V7 },
{ "mistral-v7-tekken", LLM_CHAT_TEMPLATE_MISTRAL_V7_TEKKEN },
{ "phi3", LLM_CHAT_TEMPLATE_PHI_3 }, { "phi3", LLM_CHAT_TEMPLATE_PHI_3 },
{ "phi4", LLM_CHAT_TEMPLATE_PHI_4 }, { "phi4", LLM_CHAT_TEMPLATE_PHI_4 },
{ "falcon3", LLM_CHAT_TEMPLATE_FALCON_3 }, { "falcon3", LLM_CHAT_TEMPLATE_FALCON_3 },
@@ -203,20 +202,19 @@ int32_t llm_chat_apply_template(
if (add_ass) { if (add_ass) {
ss << "<|im_start|>assistant\n"; ss << "<|im_start|>assistant\n";
} }
} else if (tmpl == LLM_CHAT_TEMPLATE_MISTRAL_V7 || tmpl == LLM_CHAT_TEMPLATE_MISTRAL_V7_TEKKEN) { } else if (tmpl == LLM_CHAT_TEMPLATE_MISTRAL_V7) {
// Official mistral 'v7' template // Official mistral 'v7' template
// See: https://huggingface.co/mistralai/Mistral-Large-Instruct-2411#basic-instruct-template-v7 // See: https://huggingface.co/mistralai/Mistral-Large-Instruct-2411#basic-instruct-template-v7
// https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Instruct-2503#basic-instruct-template-v7-tekken
const char * trailing_space = tmpl == LLM_CHAT_TEMPLATE_MISTRAL_V7 ? " " : "";
for (auto message : chat) { for (auto message : chat) {
std::string role(message->role); std::string role(message->role);
std::string content(message->content); std::string content(message->content);
if (role == "system") { if (role == "system") {
ss << "[SYSTEM_PROMPT]" << trailing_space << content << "[/SYSTEM_PROMPT]"; ss << "[SYSTEM_PROMPT] " << content << "[/SYSTEM_PROMPT]";
} else if (role == "user") { } else if (role == "user") {
ss << "[INST]" << trailing_space << content << "[/INST]"; ss << "[INST] " << content << "[/INST]";
} else { }
ss << trailing_space << content << "</s>"; else {
ss << " " << content << "</s>";
} }
} }
} else if (tmpl == LLM_CHAT_TEMPLATE_MISTRAL_V1 } else if (tmpl == LLM_CHAT_TEMPLATE_MISTRAL_V1
@@ -449,16 +447,8 @@ int32_t llm_chat_apply_template(
if (add_ass) { if (add_ass) {
ss << "<|assistant|>"; ss << "<|assistant|>";
} }
} else if (tmpl == LLM_CHAT_TEMPLATE_CHATGLM_4) { } else if (tmpl == LLM_CHAT_TEMPLATE_CHATGLM_4 || tmpl == LLM_CHAT_TEMPLATE_GLMEDGE) {
ss << "[gMASK]" << "<sop>"; ss << "[gMASK]" << "<sop>";
for (auto message : chat) {
std::string role(message->role);
ss << "<|" << role << "|>" << "\n" << message->content;
}
if (add_ass) {
ss << "<|assistant|>\n";
}
} else if (tmpl == LLM_CHAT_TEMPLATE_GLMEDGE) {
for (auto message : chat) { for (auto message : chat) {
std::string role(message->role); std::string role(message->role);
ss << "<|" << role << "|>" << "\n" << message->content; ss << "<|" << role << "|>" << "\n" << message->content;

View File

@@ -14,7 +14,6 @@ enum llm_chat_template {
LLM_CHAT_TEMPLATE_MISTRAL_V3, LLM_CHAT_TEMPLATE_MISTRAL_V3,
LLM_CHAT_TEMPLATE_MISTRAL_V3_TEKKEN, LLM_CHAT_TEMPLATE_MISTRAL_V3_TEKKEN,
LLM_CHAT_TEMPLATE_MISTRAL_V7, LLM_CHAT_TEMPLATE_MISTRAL_V7,
LLM_CHAT_TEMPLATE_MISTRAL_V7_TEKKEN,
LLM_CHAT_TEMPLATE_PHI_3, LLM_CHAT_TEMPLATE_PHI_3,
LLM_CHAT_TEMPLATE_PHI_4, LLM_CHAT_TEMPLATE_PHI_4,
LLM_CHAT_TEMPLATE_FALCON_3, LLM_CHAT_TEMPLATE_FALCON_3,

File diff suppressed because it is too large Load Diff

Some files were not shown because too many files have changed in this diff Show More