Compare commits
1 commit: pdevine/au...parth/opt-
Commit beaa0e82f3

.github/workflows/release.yaml (vendored, 186 lines changed)
@@ -54,6 +54,48 @@ jobs:
|
||||
name: build-${{ matrix.os }}-${{ matrix.arch }}
|
||||
path: dist/*
|
||||
|
||||
darwin-sign:
|
||||
runs-on: macos-13
|
||||
environment: release
|
||||
needs: darwin-build
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- run: |
|
||||
echo $MACOS_SIGNING_KEY | base64 --decode > certificate.p12
|
||||
security create-keychain -p password build.keychain
|
||||
security default-keychain -s build.keychain
|
||||
security unlock-keychain -p password build.keychain
|
||||
security import certificate.p12 -k build.keychain -P $MACOS_SIGNING_KEY_PASSWORD -T /usr/bin/codesign
|
||||
security set-key-partition-list -S apple-tool:,apple:,codesign: -s -k password build.keychain
|
||||
security set-keychain-settings -lut 3600 build.keychain
|
||||
env:
|
||||
MACOS_SIGNING_KEY: ${{ secrets.MACOS_SIGNING_KEY }}
|
||||
MACOS_SIGNING_KEY_PASSWORD: ${{ secrets.MACOS_SIGNING_KEY_PASSWORD }}
|
||||
- uses: actions/download-artifact@v4
|
||||
with:
|
||||
name: build-darwin-amd64
|
||||
path: dist/darwin-amd64
|
||||
- uses: actions/download-artifact@v4
|
||||
with:
|
||||
name: build-darwin-arm64
|
||||
path: dist/darwin-arm64
|
||||
- run: |
|
||||
export VERSION=${GITHUB_REF_NAME#v}
|
||||
./scripts/build_darwin.sh sign macapp
|
||||
env:
|
||||
APPLE_IDENTITY: ${{ secrets.APPLE_IDENTITY }}
|
||||
APPLE_PASSWORD: ${{ secrets.APPLE_PASSWORD }}
|
||||
APPLE_TEAM_ID: ${{ vars.APPLE_TEAM_ID }}
|
||||
APPLE_ID: ${{ vars.APPLE_ID }}
|
||||
SDKROOT: /Applications/Xcode_14.1.0.app/Contents/Developer/Platforms/MacOSX.platform/Developer/SDKs/MacOSX.sdk
|
||||
DEVELOPER_DIR: /Applications/Xcode_14.1.0.app/Contents/Developer
|
||||
- uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: dist-darwin
|
||||
path: |
|
||||
dist/Ollama-darwin.zip
|
||||
dist/ollama-darwin.tgz
|
||||
|
||||
windows-depends:
|
||||
strategy:
|
||||
matrix:
|
||||
@@ -61,18 +103,21 @@ jobs:
|
||||
arch: [amd64]
|
||||
preset: ['CPU']
|
||||
include:
|
||||
- os: windows
|
||||
arch: amd64
|
||||
preset: 'CUDA 11'
|
||||
install: https://developer.download.nvidia.com/compute/cuda/11.3.1/local_installers/cuda_11.3.1_465.89_win10.exe
|
||||
cuda-version: '11.3'
|
||||
- os: windows
|
||||
arch: amd64
|
||||
preset: 'CUDA 12'
|
||||
install: https://developer.download.nvidia.com/compute/cuda/12.8.0/local_installers/cuda_12.8.0_571.96_windows.exe
|
||||
cuda-version: '12.8'
|
||||
flags: ''
|
||||
- os: windows
|
||||
arch: amd64
|
||||
preset: 'ROCm 6'
|
||||
install: https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-24.Q4-WinSvr2022-For-HIP.exe
|
||||
rocm-version: '6.2'
|
||||
flags: '-DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_C_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma" -DCMAKE_CXX_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma"'
|
||||
runs-on: ${{ matrix.arch == 'arm64' && format('{0}-{1}', matrix.os, matrix.arch) || matrix.os }}
|
||||
environment: release
|
||||
env:
|
||||
@@ -115,9 +160,6 @@ jobs:
|
||||
echo "$hipPath\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
|
||||
echo "CC=$hipPath\bin\clang.exe" | Out-File -FilePath $env:GITHUB_ENV -Append
|
||||
echo "CXX=$hipPath\bin\clang++.exe" | Out-File -FilePath $env:GITHUB_ENV -Append
|
||||
echo "HIPCXX=$hipPath\bin\clang++.exe" | Out-File -FilePath $env:GITHUB_ENV -Append
|
||||
echo "HIP_PLATFORM=amd" | Out-File -FilePath $env:GITHUB_ENV -Append
|
||||
echo "CMAKE_PREFIX_PATH=$hipPath" | Out-File -FilePath $env:GITHUB_ENV -Append
|
||||
- if: matrix.preset == 'CPU'
|
||||
run: |
|
||||
echo "CC=clang.exe" | Out-File -FilePath $env:GITHUB_ENV -Append
|
||||
@@ -136,9 +178,9 @@ jobs:
|
||||
key: ccache-${{ matrix.os }}-${{ matrix.arch }}-${{ matrix.preset }}
|
||||
- name: Build target "${{ matrix.preset }}"
|
||||
run: |
|
||||
Import-Module 'C:\Program Files\Microsoft Visual Studio\2022\Enterprise\Common7\Tools\Microsoft.VisualStudio.DevShell.dll'
|
||||
Enter-VsDevShell -VsInstallPath 'C:\Program Files\Microsoft Visual Studio\2022\Enterprise' -SkipAutomaticLocation -DevCmdArguments '-arch=x64 -no_logo'
|
||||
cmake --preset "${{ matrix.preset }}" ${{ matrix.flags }}
|
||||
Import-Module 'C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\Common7\Tools\Microsoft.VisualStudio.DevShell.dll'
|
||||
Enter-VsDevShell -VsInstallPath 'C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise' -SkipAutomaticLocation -DevCmdArguments '-arch=x64 -no_logo'
|
||||
cmake --preset "${{ matrix.preset }}"
|
||||
cmake --build --parallel --preset "${{ matrix.preset }}"
|
||||
cmake --install build --component "${{ startsWith(matrix.preset, 'CUDA ') && 'CUDA' || startsWith(matrix.preset, 'ROCm ') && 'HIP' || 'CPU' }}" --strip --parallel 8
|
||||
env:
|
||||
@@ -188,11 +230,61 @@ jobs:
|
||||
go-version-file: go.mod
|
||||
- run: |
|
||||
go build -o dist/${{ matrix.os }}-${{ matrix.arch }}/ .
|
||||
- if: matrix.arch == 'arm64'
|
||||
run: |
|
||||
Invoke-WebRequest -Uri "https://aka.ms/vs/17/release/vc_redist.arm64.exe" -OutFile "dist\windows-arm64\vc_redist.arm64.exe"
|
||||
- run: |
|
||||
$env:VERSION='${{ github.ref_name }}' -Replace "v(.*)", '$1'
|
||||
& .\scripts\build_windows.ps1 buildApp
|
||||
env:
|
||||
VCToolsRedistDir: stub
|
||||
- uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: build-${{ matrix.os }}-${{ matrix.arch }}
|
||||
path: |
|
||||
dist\${{ matrix.os }}-${{ matrix.arch }}\*.exe
|
||||
dist\${{ matrix.os }}-${{ matrix.arch }}-app.exe
|
||||
|
||||
windows-sign:
|
||||
runs-on: windows-2022
|
||||
environment: release
|
||||
needs: [windows-depends, windows-build]
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: google-github-actions/auth@v2
|
||||
with:
|
||||
project_id: ollama
|
||||
credentials_json: ${{ secrets.GOOGLE_SIGNING_CREDENTIALS }}
|
||||
- run: |
|
||||
$ErrorActionPreference = "Stop"
|
||||
Invoke-WebRequest -Uri "https://go.microsoft.com/fwlink/p/?LinkId=323507" -OutFile "${{ runner.temp }}\sdksetup.exe"
|
||||
Start-Process "${{ runner.temp }}\sdksetup.exe" -ArgumentList @("/q") -NoNewWindow -Wait
|
||||
|
||||
Invoke-WebRequest -Uri "https://github.com/GoogleCloudPlatform/kms-integrations/releases/download/cng-v1.0/kmscng-1.0-windows-amd64.zip" -OutFile "${{ runner.temp }}\plugin.zip"
|
||||
Expand-Archive -Path "${{ runner.temp }}\plugin.zip" -DestinationPath "${{ runner.temp }}\plugin\"
|
||||
& "${{ runner.temp }}\plugin\*\kmscng.msi" /quiet
|
||||
|
||||
echo "${{ vars.OLLAMA_CERT }}" >ollama_inc.crt
|
||||
- uses: actions/download-artifact@v4
|
||||
with:
|
||||
pattern: build-windows-*
|
||||
path: dist\
|
||||
merge-multiple: true
|
||||
- uses: actions/download-artifact@v4
|
||||
with:
|
||||
pattern: depends-windows-amd64-*
|
||||
path: dist\windows-amd64\
|
||||
merge-multiple: true
|
||||
- run: |
|
||||
& .\scripts\build_windows.ps1 gatherDependencies sign buildInstaller distZip
|
||||
env:
|
||||
KEY_CONTAINER: ${{ vars.KEY_CONTAINER }}
|
||||
- uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: dist-windows
|
||||
path: |
|
||||
dist\OllamaSetup.exe
|
||||
dist\ollama-windows-*.zip
|
||||
|
||||
linux-build:
|
||||
strategy:
|
||||
@@ -225,26 +317,21 @@ jobs:
|
||||
CGO_CFLAGS=${{ env.CGO_CFLAGS }}
|
||||
CGO_CXXFLAGS=${{ env.CGO_CXXFLAGS }}
|
||||
outputs: type=local,dest=dist/${{ matrix.os }}-${{ matrix.arch }}
|
||||
cache-from: type=registry,ref=${{ vars.DOCKER_REPO }}:latest
|
||||
cache-from: type=registry,ref=ollama/ollama:latest
|
||||
cache-to: type=inline
|
||||
- run: |
|
||||
for COMPONENT in bin/* lib/ollama/*; do
|
||||
case "$COMPONENT" in
|
||||
bin/ollama) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
|
||||
lib/ollama/*.so*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
|
||||
lib/ollama/cuda_sbsa) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
|
||||
lib/ollama/cuda_jetpack5) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-jetpack5.tar.in ;;
|
||||
lib/ollama/cuda_jetpack6) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-jetpack6.tar.in ;;
|
||||
lib/ollama/rocm) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-rocm.tar.in ;;
|
||||
bin/ollama) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
|
||||
lib/ollama/*.so) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
|
||||
lib/ollama/cuda_v11) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
|
||||
lib/ollama/cuda_v12) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
|
||||
lib/ollama/cuda_jetpack5) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-jetpack5.tar.in ;;
|
||||
lib/ollama/cuda_jetpack6) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-jetpack6.tar.in ;;
|
||||
lib/ollama/rocm) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-rocm.tar.in ;;
|
||||
esac
|
||||
done
|
||||
working-directory: dist/${{ matrix.os }}-${{ matrix.arch }}
|
||||
- run: |
|
||||
echo "Manifests"
|
||||
for ARCHIVE in dist/${{ matrix.os }}-${{ matrix.arch }}/*.tar.in ; do
|
||||
echo $ARCHIVE
|
||||
cat $ARCHIVE
|
||||
done
|
||||
- run: |
|
||||
for ARCHIVE in dist/${{ matrix.os }}-${{ matrix.arch }}/*.tar.in; do
|
||||
tar c -C dist/${{ matrix.os }}-${{ matrix.arch }} -T $ARCHIVE --owner 0 --group 0 | pigz -9vc >$(basename ${ARCHIVE//.*/}.tgz);
|
||||
@@ -298,8 +385,8 @@ jobs:
|
||||
context: .
|
||||
platforms: ${{ matrix.os }}/${{ matrix.arch }}
|
||||
build-args: ${{ matrix.build-args }}
|
||||
outputs: type=image,name=${{ vars.DOCKER_REPO }},push-by-digest=true,name-canonical=true,push=true
|
||||
cache-from: type=registry,ref=${{ vars.DOCKER_REPO }}:latest
|
||||
outputs: type=image,name=ollama/ollama,push-by-digest=true,name-canonical=true,push=true
|
||||
cache-from: type=registry,ref=ollama/ollama:latest
|
||||
cache-to: type=inline
|
||||
- run: |
|
||||
mkdir -p ${{ matrix.os }}-${{ matrix.arch }}
|
||||
@@ -331,7 +418,7 @@ jobs:
|
||||
latest=false
|
||||
suffix=${{ matrix.suffix }}
|
||||
images: |
|
||||
${{ vars.DOCKER_REPO }}
|
||||
ollama/ollama
|
||||
tags: |
|
||||
type=ref,enable=true,priority=600,prefix=pr-,event=pr
|
||||
type=semver,pattern={{version}}
|
||||
@@ -341,24 +428,56 @@ jobs:
|
||||
path: ${{ runner.temp }}
|
||||
merge-multiple: true
|
||||
- run: |
|
||||
docker buildx imagetools create $(echo '${{ steps.metadata.outputs.json }}' | jq -cr '.tags | map("-t", .) | join(" ")') $(cat *-${{ matrix.suffix }}.txt | xargs printf '${{ vars.DOCKER_REPO }}@%s ')
|
||||
docker buildx imagetools inspect ${{ vars.DOCKER_REPO }}:${{ steps.metadata.outputs.version }}
|
||||
docker buildx imagetools create $(echo '${{ steps.metadata.outputs.json }}' | jq -cr '.tags | map("-t", .) | join(" ")') $(cat *-${{ matrix.suffix }}.txt | xargs printf 'ollama/ollama@%s ')
|
||||
docker buildx imagetools inspect ollama/ollama:${{ steps.metadata.outputs.version }}
|
||||
working-directory: ${{ runner.temp }}
|
||||
|
||||
# Trigger downstream release process
|
||||
trigger:
|
||||
runs-on: ubuntu-latest
|
||||
environment: release
|
||||
needs: [darwin-build, windows-build, windows-depends, linux-build]
|
||||
needs: [darwin-build, windows-build, windows-depends]
|
||||
steps:
|
||||
- name: Trigger downstream release process
|
||||
run: |
|
||||
curl -L \
|
||||
-X POST \
|
||||
-H "Accept: application/vnd.github+json" \
|
||||
-H "Authorization: Bearer ${{ secrets.RELEASE_TOKEN }}" \
|
||||
-H "X-GitHub-Api-Version: 2022-11-28" \
|
||||
https://api.github.com/repos/ollama/${{ vars.RELEASE_REPO }}/dispatches \
|
||||
-d "{\"event_type\": \"trigger-workflow\", \"client_payload\": {\"run_id\": \"${GITHUB_RUN_ID}\", \"version\": \"${GITHUB_REF_NAME#v}\"}}"
|
||||
|
||||
# Aggregate all the assets and ship a release
|
||||
release:
|
||||
needs: [darwin-sign, windows-sign, linux-build]
|
||||
runs-on: linux
|
||||
environment: release
|
||||
permissions:
|
||||
contents: write
|
||||
env:
|
||||
GH_TOKEN: ${{ github.token }}
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Create or update Release for tag
|
||||
- uses: actions/download-artifact@v4
|
||||
with:
|
||||
name: dist-darwin
|
||||
path: dist
|
||||
- uses: actions/download-artifact@v4
|
||||
with:
|
||||
name: dist-windows
|
||||
path: dist
|
||||
- uses: actions/download-artifact@v4
|
||||
with:
|
||||
pattern: dist-linux-*
|
||||
path: dist
|
||||
merge-multiple: true
|
||||
- run: find . -type f -not -name 'sha256sum.txt' | xargs sha256sum | tee sha256sum.txt
|
||||
working-directory: dist
|
||||
- name: Create or update Release
|
||||
run: |
|
||||
RELEASE_VERSION="$(echo ${GITHUB_REF_NAME} | cut -f1 -d-)"
|
||||
|
||||
echo "Looking for existing release for ${RELEASE_VERSION}"
|
||||
OLD_TAG=$(gh release ls --json name,tagName | jq -r ".[] | select(.name == \"${RELEASE_VERSION}\") | .tagName")
|
||||
if [ -n "$OLD_TAG" ]; then
|
||||
@@ -372,12 +491,5 @@ jobs:
|
||||
--generate-notes \
|
||||
--prerelease
|
||||
fi
|
||||
- name: Trigger downstream release process
|
||||
run: |
|
||||
curl -L \
|
||||
-X POST \
|
||||
-H "Accept: application/vnd.github+json" \
|
||||
-H "Authorization: Bearer ${{ secrets.RELEASE_TOKEN }}" \
|
||||
-H "X-GitHub-Api-Version: 2022-11-28" \
|
||||
https://api.github.com/repos/ollama/${{ vars.RELEASE_REPO }}/dispatches \
|
||||
-d "{\"event_type\": \"trigger-workflow\", \"client_payload\": {\"run_id\": \"${GITHUB_RUN_ID}\", \"version\": \"${GITHUB_REF_NAME#v}\", \"origin\": \"${GITHUB_REPOSITORY}\", \"publish\": \"1\"}}"
|
||||
echo "Uploading artifacts for tag ${GITHUB_REF_NAME}"
|
||||
gh release upload ${GITHUB_REF_NAME} dist/* --clobber
|
||||
|
||||
.github/workflows/test.yaml (vendored, 17 lines changed)
@@ -36,7 +36,7 @@ jobs:
|
||||
| xargs python3 -c "import sys; from pathlib import Path; print(any(Path(x).match(glob) for x in sys.argv[1:] for glob in '$*'.split(' ')))"
|
||||
}
|
||||
|
||||
echo changed=$(changed 'llama/llama.cpp/**/*' 'ml/backend/ggml/ggml/**/*') | tee -a $GITHUB_OUTPUT
|
||||
echo changed=$(changed 'llama/llama.cpp/**' 'ml/backend/ggml/ggml/**') | tee -a $GITHUB_OUTPUT
|
||||
|
||||
linux:
|
||||
needs: [changes]
|
||||
@@ -46,7 +46,7 @@ jobs:
|
||||
include:
|
||||
- preset: CPU
|
||||
- preset: CUDA
|
||||
container: nvidia/cuda:12.8.1-devel-ubuntu22.04
|
||||
container: nvidia/cuda:11.8.0-devel-ubuntu22.04
|
||||
flags: '-DCMAKE_CUDA_ARCHITECTURES=87'
|
||||
- preset: ROCm
|
||||
container: rocm/dev-ubuntu-22.04:6.1.2
|
||||
@@ -78,11 +78,11 @@ jobs:
|
||||
include:
|
||||
- preset: CPU
|
||||
- preset: CUDA
|
||||
install: https://developer.download.nvidia.com/compute/cuda/12.8.0/local_installers/cuda_12.8.0_571.96_windows.exe
|
||||
install: https://developer.download.nvidia.com/compute/cuda/11.3.1/local_installers/cuda_11.3.1_465.89_win10.exe
|
||||
flags: '-DCMAKE_CUDA_ARCHITECTURES=80'
|
||||
- preset: ROCm
|
||||
install: https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-24.Q4-WinSvr2022-For-HIP.exe
|
||||
flags: '-DAMDGPU_TARGETS=gfx1010 -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_C_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma" -DCMAKE_CXX_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma"'
|
||||
flags: '-DAMDGPU_TARGETS=gfx1010'
|
||||
runs-on: windows
|
||||
steps:
|
||||
- run: |
|
||||
@@ -102,7 +102,7 @@ jobs:
|
||||
$ErrorActionPreference = "Stop"
|
||||
if ("${{ steps.cache-install.outputs.cache-hit }}" -ne 'true') {
|
||||
Invoke-WebRequest -Uri "${{ matrix.install }}" -OutFile "install.exe"
|
||||
Start-Process -FilePath .\install.exe -ArgumentList (@("-s", "cudart_12.8", "nvcc_12.8", "cublas_12.8", "cublas_dev_12.8")) -NoNewWindow -Wait
|
||||
Start-Process -FilePath .\install.exe -ArgumentList (@("-s", "cudart_11.3", "nvcc_11.3", "cublas_11.3", "cublas_dev_11.3")) -NoNewWindow -Wait
|
||||
}
|
||||
|
||||
$cudaPath = (Resolve-Path "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\*").path
|
||||
@@ -120,9 +120,6 @@ jobs:
|
||||
echo "$hipPath\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
|
||||
echo "CC=$hipPath\bin\clang.exe" | Out-File -FilePath $env:GITHUB_ENV -Append
|
||||
echo "CXX=$hipPath\bin\clang++.exe" | Out-File -FilePath $env:GITHUB_ENV -Append
|
||||
echo "HIPCXX=$hipPath\bin\clang++.exe" | Out-File -FilePath $env:GITHUB_ENV -Append
|
||||
echo "HIP_PLATFORM=amd" | Out-File -FilePath $env:GITHUB_ENV -Append
|
||||
echo "CMAKE_PREFIX_PATH=$hipPath" | Out-File -FilePath $env:GITHUB_ENV -Append
|
||||
- if: ${{ !cancelled() && steps.cache-install.outputs.cache-hit != 'true' }}
|
||||
uses: actions/cache/save@v4
|
||||
with:
|
||||
@@ -136,8 +133,8 @@ jobs:
|
||||
path: ${{ github.workspace }}\.ccache
|
||||
key: ccache-${{ runner.os }}-${{ runner.arch }}-${{ matrix.preset }}
|
||||
- run: |
|
||||
Import-Module 'C:\Program Files\Microsoft Visual Studio\2022\Enterprise\Common7\Tools\Microsoft.VisualStudio.DevShell.dll'
|
||||
Enter-VsDevShell -VsInstallPath 'C:\Program Files\Microsoft Visual Studio\2022\Enterprise' -SkipAutomaticLocation -DevCmdArguments '-arch=x64 -no_logo'
|
||||
Import-Module 'C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\Common7\Tools\Microsoft.VisualStudio.DevShell.dll'
|
||||
Enter-VsDevShell -VsInstallPath 'C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise' -SkipAutomaticLocation -DevCmdArguments '-arch=x64 -no_logo'
|
||||
cmake --preset "${{ matrix.preset }}" ${{ matrix.flags }}
|
||||
cmake --build --parallel --preset "${{ matrix.preset }}"
|
||||
env:
|
||||
|
||||
@@ -78,13 +78,14 @@ if(CMAKE_CUDA_COMPILER)
|
||||
|
||||
find_package(CUDAToolkit)
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-cuda)
|
||||
set(OLLAMA_CUDA_INSTALL_DIR ${OLLAMA_INSTALL_DIR}/cuda_v${CUDAToolkit_VERSION_MAJOR})
|
||||
install(TARGETS ggml-cuda
|
||||
RUNTIME_DEPENDENCIES
|
||||
DIRECTORIES ${CUDAToolkit_BIN_DIR} ${CUDAToolkit_LIBRARY_DIR}
|
||||
PRE_INCLUDE_REGEXES cublas cublasLt cudart
|
||||
PRE_EXCLUDE_REGEXES ".*"
|
||||
RUNTIME DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT CUDA
|
||||
LIBRARY DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT CUDA
|
||||
RUNTIME DESTINATION ${OLLAMA_CUDA_INSTALL_DIR} COMPONENT CUDA
|
||||
LIBRARY DESTINATION ${OLLAMA_CUDA_INSTALL_DIR} COMPONENT CUDA
|
||||
)
|
||||
endif()
|
||||
|
||||
@@ -115,11 +116,7 @@ if(CMAKE_HIP_COMPILER)
|
||||
|
||||
set(OLLAMA_HIP_INSTALL_DIR ${OLLAMA_INSTALL_DIR}/rocm)
|
||||
install(TARGETS ggml-hip
|
||||
RUNTIME_DEPENDENCY_SET rocm
|
||||
RUNTIME DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT HIP
|
||||
LIBRARY DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT HIP
|
||||
)
|
||||
install(RUNTIME_DEPENDENCY_SET rocm
|
||||
RUNTIME_DEPENDENCIES
|
||||
DIRECTORIES ${HIP_BIN_INSTALL_DIR} ${HIP_LIB_INSTALL_DIR}
|
||||
PRE_INCLUDE_REGEXES hipblas rocblas amdhip64 rocsolver amd_comgr hsa-runtime64 rocsparse tinfo rocprofiler-register drm drm_amdgpu numa elf
|
||||
PRE_EXCLUDE_REGEXES ".*"
|
||||
|
||||
@@ -17,12 +17,20 @@
|
||||
"name": "CUDA",
|
||||
"inherits": [ "Default" ]
|
||||
},
|
||||
{
|
||||
"name": "CUDA 11",
|
||||
"inherits": [ "CUDA" ],
|
||||
"cacheVariables": {
|
||||
"CMAKE_CUDA_ARCHITECTURES": "50;52;53;60;61;70;75;80;86",
|
||||
"CMAKE_CUDA_FLAGS": "-Wno-deprecated-gpu-targets"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "CUDA 12",
|
||||
"inherits": [ "CUDA" ],
|
||||
"cacheVariables": {
|
||||
"CMAKE_CUDA_ARCHITECTURES": "50;60;61;70;75;80;86;87;89;90;90a;120",
|
||||
"CMAKE_CUDA_FLAGS": "-Wno-deprecated-gpu-targets -t 2"
|
||||
"CMAKE_CUDA_FLAGS": "-Wno-deprecated-gpu-targets"
|
||||
}
|
||||
},
|
||||
{
|
||||
@@ -50,7 +58,6 @@
|
||||
"name": "ROCm 6",
|
||||
"inherits": [ "ROCm" ],
|
||||
"cacheVariables": {
|
||||
"CMAKE_HIP_FLAGS": "-parallel-jobs=4",
|
||||
"AMDGPU_TARGETS": "gfx900;gfx940;gfx941;gfx942;gfx1010;gfx1012;gfx1030;gfx1100;gfx1101;gfx1102;gfx1151;gfx1200;gfx1201;gfx906:xnack-;gfx908:xnack-;gfx90a:xnack+;gfx90a:xnack-"
|
||||
}
|
||||
}
|
||||
@@ -71,6 +78,11 @@
|
||||
"configurePreset": "CUDA",
|
||||
"targets": [ "ggml-cuda" ]
|
||||
},
|
||||
{
|
||||
"name": "CUDA 11",
|
||||
"inherits": [ "CUDA" ],
|
||||
"configurePreset": "CUDA 11"
|
||||
},
|
||||
{
|
||||
"name": "CUDA 12",
|
||||
"inherits": [ "CUDA" ],
|
||||
|
||||
@@ -65,7 +65,7 @@ continuation of the sentence:
|
||||
Examples:
|
||||
|
||||
llm/backend/mlx: support the llama architecture
|
||||
CONTRIBUTING: provide clarity on good commit messages, and bad
|
||||
CONTRIBUTING: provide clairity on good commit messages, and bad
|
||||
|
||||
Bad Examples:
|
||||
|
||||
|
||||
Dockerfile (26 lines changed)
@@ -7,13 +7,12 @@ ARG JETPACK5VERSION=r35.4.1
|
||||
ARG JETPACK6VERSION=r36.4.0
|
||||
ARG CMAKEVERSION=3.31.2
|
||||
|
||||
# We require gcc v10 minimum. v10.3 has regressions, so the rockylinux 8.5 AppStream has the latest compatible version
|
||||
# CUDA v11 requires gcc v10. v10.3 has regressions, so the rockylinux 8.5 AppStream has the latest compatible version
|
||||
FROM --platform=linux/amd64 rocm/dev-almalinux-8:${ROCMVERSION}-complete AS base-amd64
|
||||
RUN yum install -y yum-utils \
|
||||
&& yum-config-manager --add-repo https://dl.rockylinux.org/vault/rocky/8.5/AppStream/\$basearch/os/ \
|
||||
&& rpm --import https://dl.rockylinux.org/pub/rocky/RPM-GPG-KEY-Rocky-8 \
|
||||
&& dnf install -y yum-utils ccache gcc-toolset-10-gcc-10.2.1-8.2.el8 gcc-toolset-10-gcc-c++-10.2.1-8.2.el8 gcc-toolset-10-binutils-2.35-11.el8 \
|
||||
&& dnf install -y ccache \
|
||||
&& yum-config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel8/x86_64/cuda-rhel8.repo
|
||||
ENV PATH=/opt/rh/gcc-toolset-10/root/usr/bin:$PATH
|
||||
|
||||
@@ -39,6 +38,15 @@ RUN --mount=type=cache,target=/root/.ccache \
|
||||
&& cmake --build --parallel --preset 'CPU' \
|
||||
&& cmake --install build --component CPU --strip --parallel 8
|
||||
|
||||
FROM base AS cuda-11
|
||||
ARG CUDA11VERSION=11.3
|
||||
RUN dnf install -y cuda-toolkit-${CUDA11VERSION//./-}
|
||||
ENV PATH=/usr/local/cuda-11/bin:$PATH
|
||||
RUN --mount=type=cache,target=/root/.ccache \
|
||||
cmake --preset 'CUDA 11' \
|
||||
&& cmake --build --parallel --preset 'CUDA 11' \
|
||||
&& cmake --install build --component CUDA --strip --parallel 8
|
||||
|
||||
FROM base AS cuda-12
|
||||
ARG CUDA12VERSION=12.8
|
||||
RUN dnf install -y cuda-toolkit-${CUDA12VERSION//./-}
|
||||
@@ -90,21 +98,23 @@ RUN --mount=type=cache,target=/root/.cache/go-build \
|
||||
go build -trimpath -buildmode=pie -o /bin/ollama .
|
||||
|
||||
FROM --platform=linux/amd64 scratch AS amd64
|
||||
COPY --from=cuda-12 dist/lib/ollama /lib/ollama
|
||||
COPY --from=cuda-11 dist/lib/ollama/cuda_v11 /lib/ollama/cuda_v11
|
||||
COPY --from=cuda-12 dist/lib/ollama/cuda_v12 /lib/ollama/cuda_v12
|
||||
|
||||
FROM --platform=linux/arm64 scratch AS arm64
|
||||
COPY --from=cuda-12 dist/lib/ollama /lib/ollama/cuda_sbsa
|
||||
COPY --from=jetpack-5 dist/lib/ollama /lib/ollama/cuda_jetpack5
|
||||
COPY --from=jetpack-6 dist/lib/ollama /lib/ollama/cuda_jetpack6
|
||||
COPY --from=cuda-11 dist/lib/ollama/cuda_v11 /lib/ollama/cuda_v11
|
||||
COPY --from=cuda-12 dist/lib/ollama/cuda_v12 /lib/ollama/cuda_v12
|
||||
COPY --from=jetpack-5 dist/lib/ollama/cuda_v11 /lib/ollama/cuda_jetpack5
|
||||
COPY --from=jetpack-6 dist/lib/ollama/cuda_v12 /lib/ollama/cuda_jetpack6
|
||||
|
||||
FROM scratch AS rocm
|
||||
COPY --from=rocm-6 dist/lib/ollama /lib/ollama
|
||||
COPY --from=rocm-6 dist/lib/ollama/rocm /lib/ollama/rocm
|
||||
|
||||
FROM ${FLAVOR} AS archive
|
||||
COPY --from=cpu dist/lib/ollama /lib/ollama
|
||||
COPY --from=build /bin/ollama /bin/ollama
|
||||
|
||||
FROM ubuntu:24.04
|
||||
FROM ubuntu:20.04
|
||||
RUN apt-get update \
|
||||
&& apt-get install -y ca-certificates \
|
||||
&& apt-get clean \
|
||||
|
||||
README.md (13 lines changed)
@@ -1,6 +1,6 @@
|
||||
<div align="center">
|
||||
<a href="https://ollama.com">
|
||||
<img alt="ollama" width="240" src="https://github.com/ollama/ollama/assets/3325447/0d0b44e2-8f4a-4e99-9b52-a5c1c741c8f7">
|
||||
<img alt="ollama" height="200px" src="https://github.com/ollama/ollama/assets/3325447/0d0b44e2-8f4a-4e99-9b52-a5c1c741c8f7">
|
||||
</a>
|
||||
</div>
|
||||
|
||||
@@ -10,7 +10,7 @@ Get up and running with large language models.
|
||||
|
||||
### macOS
|
||||
|
||||
[Download](https://ollama.com/download/Ollama.dmg)
|
||||
[Download](https://ollama.com/download/Ollama-darwin.zip)
|
||||
|
||||
### Windows
|
||||
|
||||
@@ -360,7 +360,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
|
||||
- [Tkinter-based client](https://github.com/chyok/ollama-gui) (Python tkinter-based Client for Ollama)
|
||||
- [LLMChat](https://github.com/trendy-design/llmchat) (Privacy focused, 100% local, intuitive all-in-one chat interface)
|
||||
- [Local Multimodal AI Chat](https://github.com/Leon-Sander/Local-Multimodal-AI-Chat) (Ollama-based LLM Chat with support for multiple features, including PDF RAG, voice chat, image-based interactions, and integration with OpenAI.)
|
||||
- [ARGO](https://github.com/xark-argo/argo) (Locally download and run Ollama and Huggingface models with RAG and deep research on Mac/Windows/Linux)
|
||||
- [ARGO](https://github.com/xark-argo/argo) (Locally download and run Ollama and Huggingface models with RAG on Mac/Windows/Linux)
|
||||
- [OrionChat](https://github.com/EliasPereirah/OrionChat) - OrionChat is a web interface for chatting with different AI providers
|
||||
- [G1](https://github.com/bklieger-groq/g1) (Prototype of using prompting strategies to improve the LLM's reasoning through o1-like reasoning chains.)
|
||||
- [Web management](https://github.com/lemonit-eric-mao/ollama-web-management) (Web management page)
|
||||
@@ -409,8 +409,6 @@ See the [API documentation](./docs/api.md) for all endpoints.
|
||||
- [macLlama (macOS native)](https://github.com/hellotunamayo/macLlama) (A native macOS GUI application for interacting with Ollama models, featuring a chat interface.)
|
||||
- [GPTranslate](https://github.com/philberndt/GPTranslate) (A fast and lightweight, AI powered desktop translation application written with Rust and Tauri. Features real-time translation with OpenAI/Azure/Ollama.)
|
||||
- [ollama launcher](https://github.com/NGC13009/ollama-launcher) (A launcher for Ollama, aiming to provide users with convenient functions such as ollama server launching, management, or configuration.)
|
||||
- [ai-hub](https://github.com/Aj-Seven/ai-hub) (AI Hub supports multiple models via API keys and Chat support via Ollama API.)
|
||||
- [Mayan EDMS](https://gitlab.com/mayan-edms/mayan-edms) (Open source document management system to organize, tag, search, and automate your files with powerful Ollama driven workflows.)
|
||||
|
||||
### Cloud
|
||||
|
||||
@@ -456,7 +454,6 @@ See the [API documentation](./docs/api.md) for all endpoints.
|
||||
- [GGUF-to-Ollama](https://github.com/jonathanhecl/gguf-to-ollama) - Importing GGUF to Ollama made easy (multiplatform)
|
||||
- [AWS-Strands-With-Ollama](https://github.com/rapidarchitect/ollama_strands) - AWS Strands Agents with Ollama Examples
|
||||
- [ollama-multirun](https://github.com/attogram/ollama-multirun) - A bash shell script to run a single prompt against any or all of your locally installed ollama models, saving the output and performance statistics as easily navigable web pages. ([Demo](https://attogram.github.io/ai_test_zone/))
|
||||
- [ollama-bash-toolshed](https://github.com/attogram/ollama-bash-toolshed) - Bash scripts to chat with tool using models. Add new tools to your shed with ease. Runs on Ollama.
|
||||
|
||||
### Apple Vision Pro
|
||||
|
||||
@@ -595,12 +592,10 @@ See the [API documentation](./docs/api.md) for all endpoints.
|
||||
- [mcp-llm](https://github.com/sammcj/mcp-llm) (MCP Server to allow LLMs to call other LLMs)
|
||||
- [SimpleOllamaUnity](https://github.com/HardCodeDev777/SimpleOllamaUnity) (Unity Engine extension for communicating with Ollama in a few lines of code. Also works at runtime)
|
||||
- [UnityCodeLama](https://github.com/HardCodeDev777/UnityCodeLama) (Unity Edtior tool to analyze scripts via Ollama)
|
||||
- [NativeMind](https://github.com/NativeMindBrowser/NativeMindExtension) (Private, on-device AI Assistant, no cloud dependencies)
|
||||
- [GMAI - Gradle Managed AI](https://gmai.premex.se/) (Gradle plugin for automated Ollama lifecycle management during build phases)
|
||||
|
||||
### Supported backends
|
||||
|
||||
- [llama.cpp](https://github.com/ggml-org/llama.cpp) project founded by Georgi Gerganov.
|
||||
- [llama.cpp](https://github.com/ggerganov/llama.cpp) project founded by Georgi Gerganov.
|
||||
|
||||
### Observability
|
||||
- [Opik](https://www.comet.com/docs/opik/cookbook/ollama) is an open-source platform to debug, evaluate, and monitor your LLM applications, RAG systems, and agentic workflows with comprehensive tracing, automated evaluations, and production-ready dashboards. Opik supports native intergration to Ollama.
|
||||
|
||||
@@ -42,23 +42,6 @@ type Client struct {
|
||||
|
||||
func checkError(resp *http.Response, body []byte) error {
|
||||
if resp.StatusCode < http.StatusBadRequest {
|
||||
if len(body) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// streams can contain error message even with StatusOK
|
||||
var errorResponse struct {
|
||||
Error string `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(body, &errorResponse); err != nil {
|
||||
return fmt.Errorf("unmarshal: %w", err)
|
||||
}
|
||||
|
||||
if errorResponse.Error != "" {
|
||||
return errors.New(errorResponse.Error)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -230,9 +213,25 @@ func (c *Client) stream(ctx context.Context, method, path string, data any, fn f
|
||||
scanBuf := make([]byte, 0, maxBufferSize)
|
||||
scanner.Buffer(scanBuf, maxBufferSize)
|
||||
for scanner.Scan() {
|
||||
var errorResponse struct {
|
||||
Error string `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
bts := scanner.Bytes()
|
||||
if err := checkError(response, bts); err != nil {
|
||||
return err
|
||||
if err := json.Unmarshal(bts, &errorResponse); err != nil {
|
||||
return fmt.Errorf("unmarshal: %w", err)
|
||||
}
|
||||
|
||||
if errorResponse.Error != "" {
|
||||
return errors.New(errorResponse.Error)
|
||||
}
|
||||
|
||||
if response.StatusCode >= http.StatusBadRequest {
|
||||
return StatusError{
|
||||
StatusCode: response.StatusCode,
|
||||
Status: response.Status,
|
||||
ErrorMessage: errorResponse.Error,
|
||||
}
|
||||
}
|
||||
|
||||
if err := fn(bts); err != nil {
|
||||
|
||||
api/types.go (18 lines changed)
@@ -143,7 +143,6 @@ type Message struct {
|
||||
Thinking string `json:"thinking,omitempty"`
|
||||
Images []ImageData `json:"images,omitempty"`
|
||||
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
||||
ToolName string `json:"tool_name,omitempty"`
|
||||
}
|
||||
|
||||
func (m *Message) UnmarshalJSON(b []byte) error {
|
||||
@@ -286,6 +285,7 @@ type Options struct {
|
||||
PresencePenalty float32 `json:"presence_penalty,omitempty"`
|
||||
FrequencyPenalty float32 `json:"frequency_penalty,omitempty"`
|
||||
Stop []string `json:"stop,omitempty"`
|
||||
ShiftContext bool `json:"shift_context,omitempty"`
|
||||
}
|
||||
|
||||
// Runner options which must be set when the model is loaded into memory
|
||||
@@ -468,14 +468,13 @@ type ListModelResponse struct {
|
||||
|
||||
// ProcessModelResponse is a single model description in [ProcessResponse].
|
||||
type ProcessModelResponse struct {
|
||||
Name string `json:"name"`
|
||||
Model string `json:"model"`
|
||||
Size int64 `json:"size"`
|
||||
Digest string `json:"digest"`
|
||||
Details ModelDetails `json:"details,omitempty"`
|
||||
ExpiresAt time.Time `json:"expires_at"`
|
||||
SizeVRAM int64 `json:"size_vram"`
|
||||
ContextLength int `json:"context_length"`
|
||||
Name string `json:"name"`
|
||||
Model string `json:"model"`
|
||||
Size int64 `json:"size"`
|
||||
Digest string `json:"digest"`
|
||||
Details ModelDetails `json:"details,omitempty"`
|
||||
ExpiresAt time.Time `json:"expires_at"`
|
||||
SizeVRAM int64 `json:"size_vram"`
|
||||
}
|
||||
|
||||
type TokenResponse struct {
|
||||
@@ -665,6 +664,7 @@ func DefaultOptions() Options {
|
||||
PresencePenalty: 0.0,
|
||||
FrequencyPenalty: 0.0,
|
||||
Seed: -1,
|
||||
ShiftContext: true,
|
||||
|
||||
Runner: Runner{
|
||||
// options set when the model is loaded
|
||||
|
||||
auth/auth.go (35 lines changed)
@@ -18,8 +18,6 @@ import (
|
||||
|
||||
const defaultPrivateKey = "id_ed25519"
|
||||
|
||||
var ErrInvalidToken = errors.New("invalid token")
|
||||
|
||||
func keyPath() (string, error) {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
@@ -29,39 +27,6 @@ func keyPath() (string, error) {
|
||||
return filepath.Join(home, ".ollama", defaultPrivateKey), nil
|
||||
}
|
||||
|
||||
func parseToken(token string) (key, sig []byte, _ error) {
|
||||
keyData, sigData, ok := strings.Cut(token, ":")
|
||||
if !ok {
|
||||
return nil, nil, fmt.Errorf("identity: parseToken: %w", ErrInvalidToken)
|
||||
}
|
||||
sig, err := base64.StdEncoding.DecodeString(sigData)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("identity: parseToken: base64 decoding signature: %w", err)
|
||||
}
|
||||
return []byte(keyData), sig, nil
|
||||
}
|
||||
|
||||
func Authenticate(token, checkData string) (ssh.PublicKey, error) {
|
||||
keyShort, sigBytes, err := parseToken(token)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
keyLong := append([]byte("ssh-ed25519 "), keyShort...)
|
||||
pub, _, _, _, err := ssh.ParseAuthorizedKey(keyLong)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := pub.Verify([]byte(checkData), &ssh.Signature{
|
||||
Format: pub.Type(),
|
||||
Blob: sigBytes,
|
||||
}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return pub, nil
|
||||
}
|
||||
|
||||
func GetPublicKey() (string, error) {
|
||||
keyPath, err := keyPath()
|
||||
if err != nil {
|
||||
|
||||
@@ -1,254 +0,0 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
type KeyEntry struct {
|
||||
Name string
|
||||
PublicKey string
|
||||
Endpoints []string
|
||||
}
|
||||
|
||||
type KeyPermission struct {
|
||||
Name string
|
||||
Endpoints []string
|
||||
}
|
||||
|
||||
type APIPermissions struct {
|
||||
permissions map[string]*KeyPermission
|
||||
lastModified time.Time
|
||||
mutex sync.RWMutex
|
||||
}
|
||||
|
||||
var ws = regexp.MustCompile(`\s+`)
|
||||
|
||||
func authkeyPath() (string, error) {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return filepath.Join(home, ".ollama", "authorized_keys"), nil
|
||||
}
|
||||
|
||||
func NewAPIPermissions() *APIPermissions {
|
||||
return &APIPermissions{
|
||||
permissions: make(map[string]*KeyPermission),
|
||||
mutex: sync.RWMutex{},
|
||||
}
|
||||
}
|
||||
|
||||
func (ap *APIPermissions) ReloadIfNeeded() error {
|
||||
ap.mutex.Lock()
|
||||
defer ap.mutex.Unlock()
|
||||
|
||||
filename, err := authkeyPath()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
fileInfo, err := os.Stat(filename)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to stat file: %v", err)
|
||||
}
|
||||
|
||||
if !fileInfo.ModTime().After(ap.lastModified) {
|
||||
return nil
|
||||
}
|
||||
|
||||
file, err := os.Open(filename)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to open file: %v", err)
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
ap.lastModified = fileInfo.ModTime()
|
||||
return ap.parse(file)
|
||||
}
|
||||
|
||||
func (ap *APIPermissions) parse(r io.Reader) error {
|
||||
ap.permissions = make(map[string]*KeyPermission)
|
||||
|
||||
scanner := bufio.NewScanner(r)
|
||||
var cnt int
|
||||
for scanner.Scan() {
|
||||
cnt += 1
|
||||
line := strings.TrimSpace(scanner.Text())
|
||||
|
||||
if line == "" || strings.HasPrefix(line, "#") {
|
||||
continue
|
||||
}
|
||||
|
||||
line = ws.ReplaceAllString(line, " ")
|
||||
|
||||
entry, err := ap.parseLine(line)
|
||||
if err != nil {
|
||||
slog.Warn(fmt.Sprintf("authorized_keys line %d: skipping invalid line: %v\n", cnt, err))
|
||||
continue
|
||||
}
|
||||
|
||||
var pubKeyStr string
|
||||
|
||||
if entry.PublicKey == "*" {
|
||||
pubKeyStr = "*"
|
||||
} else {
|
||||
pubKey, err := ap.validateAndDecodeKey(entry)
|
||||
if err != nil {
|
||||
slog.Warn(fmt.Sprintf("authorized_keys line %d: invalid key for %s: %v\n", cnt, entry.Name, err))
|
||||
continue
|
||||
}
|
||||
pubKeyStr = pubKey
|
||||
}
|
||||
|
||||
if perm, exists := ap.permissions[pubKeyStr]; exists {
|
||||
if perm.Name == "default" {
|
||||
perm.Name = entry.Name
|
||||
}
|
||||
if len(perm.Endpoints) == 1 && perm.Endpoints[0] == "*" {
|
||||
// skip redundant entries
|
||||
continue
|
||||
} else if len(entry.Endpoints) == 1 && entry.Endpoints[0] == "*" {
|
||||
// overwrite redundant entries
|
||||
perm.Endpoints = entry.Endpoints
|
||||
} else {
|
||||
perm.Endpoints = append(perm.Endpoints, entry.Endpoints...)
|
||||
}
|
||||
} else {
|
||||
ap.permissions[pubKeyStr] = &KeyPermission{
|
||||
Name: entry.Name,
|
||||
Endpoints: entry.Endpoints,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return scanner.Err()
|
||||
}
|
||||
|
||||
func (ap *APIPermissions) parseLine(line string) (*KeyEntry, error) {
|
||||
parts := strings.SplitN(line, " ", 4)
|
||||
if len(parts) < 2 {
|
||||
return nil, fmt.Errorf("key type and public key not found")
|
||||
}
|
||||
|
||||
kind, b64Key := parts[0], parts[1]
|
||||
name := "default"
|
||||
eps := "*"
|
||||
|
||||
if len(parts) >= 3 && parts[2] != "" {
|
||||
if parts[2] != "*" {
|
||||
name = parts[2]
|
||||
}
|
||||
}
|
||||
|
||||
if len(parts) == 4 && parts[3] != "" {
|
||||
eps = parts[3]
|
||||
}
|
||||
|
||||
if kind != "ssh-ed25519" && kind != "*" {
|
||||
return nil, fmt.Errorf("unsupported key type %s", kind)
|
||||
}
|
||||
|
||||
if kind == "*" && b64Key != "*" {
|
||||
return nil, fmt.Errorf("unsupported key type")
|
||||
}
|
||||
|
||||
var endpoints []string
|
||||
if eps == "*" {
|
||||
endpoints = []string{"*"}
|
||||
} else {
|
||||
for _, e := range strings.Split(eps, ",") {
|
||||
e = strings.TrimSpace(e)
|
||||
if e == "" {
|
||||
return nil, fmt.Errorf("empty endpoint in list")
|
||||
} else if e == "*" {
|
||||
endpoints = []string{"*"}
|
||||
break
|
||||
}
|
||||
endpoints = append(endpoints, e)
|
||||
}
|
||||
}
|
||||
|
||||
return &KeyEntry{
|
||||
PublicKey: b64Key,
|
||||
Name: name,
|
||||
Endpoints: endpoints,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (ap *APIPermissions) validateAndDecodeKey(entry *KeyEntry) (string, error) {
|
||||
keyBlob, err := base64.StdEncoding.DecodeString(entry.PublicKey)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("base64 decode: %w", err)
|
||||
}
|
||||
pub, err := ssh.ParsePublicKey(keyBlob)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("parse key: %w", err)
|
||||
}
|
||||
if pub.Type() != ssh.KeyAlgoED25519 {
|
||||
return "", fmt.Errorf("key is not Ed25519")
|
||||
}
|
||||
|
||||
return entry.PublicKey, nil
|
||||
}
|
||||
|
||||
func (ap *APIPermissions) Authorize(pubKey ssh.PublicKey, endpoint string) (bool, string, error) {
|
||||
if err := ap.ReloadIfNeeded(); err != nil {
|
||||
return false, "unknown", err
|
||||
}
|
||||
|
||||
ap.mutex.RLock()
|
||||
defer ap.mutex.RUnlock()
|
||||
|
||||
if wildcardPerm, exists := ap.permissions["*"]; exists {
|
||||
if len(wildcardPerm.Endpoints) == 1 && wildcardPerm.Endpoints[0] == "*" {
|
||||
return true, wildcardPerm.Name, nil
|
||||
}
|
||||
|
||||
for _, allowedEndpoint := range wildcardPerm.Endpoints {
|
||||
if allowedEndpoint == endpoint {
|
||||
return true, wildcardPerm.Name, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
keyString := string(ssh.MarshalAuthorizedKey(pubKey))
|
||||
parts := strings.SplitN(keyString, " ", 2)
|
||||
var base64Key string
|
||||
if len(parts) > 1 {
|
||||
base64Key = parts[1]
|
||||
} else {
|
||||
base64Key = parts[0]
|
||||
}
|
||||
|
||||
base64Key = strings.TrimSpace(base64Key)
|
||||
|
||||
perm, exists := ap.permissions[base64Key]
|
||||
if !exists {
|
||||
return false, "unknown", nil
|
||||
}
|
||||
|
||||
if len(perm.Endpoints) == 1 && perm.Endpoints[0] == "*" {
|
||||
return true, perm.Name, nil
|
||||
}
|
||||
|
||||
for _, allowedEndpoint := range perm.Endpoints {
|
||||
if allowedEndpoint == endpoint {
|
||||
return true, perm.Name, nil
|
||||
}
|
||||
}
|
||||
|
||||
return false, "unknown", nil
|
||||
}
|
||||
@@ -1,133 +0,0 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
const validB64 = "AAAAC3NzaC1lZDI1NTE5AAAAICy1v/Sn0kGhu1LXzCsnx3wlk5ESdncS66JWo13yeJod"
|
||||
|
||||
func TestParse(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
file string
|
||||
want map[string]*KeyPermission
|
||||
}{
|
||||
{
|
||||
name: "two fields only defaults",
|
||||
file: "ssh-ed25519 " + validB64 + "\n",
|
||||
want: map[string]*KeyPermission{
|
||||
validB64: {
|
||||
Name: "default",
|
||||
Endpoints: []string{"*"},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "extra whitespace collapsed and default endpoints",
|
||||
file: "ssh-ed25519 " + validB64 + " alice\n",
|
||||
want: map[string]*KeyPermission{
|
||||
validB64: {
|
||||
Name: "alice",
|
||||
Endpoints: []string{"*"},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "four fields full",
|
||||
file: "ssh-ed25519 " + validB64 + " bob /api/foo,/api/bar\n",
|
||||
want: map[string]*KeyPermission{
|
||||
validB64: {
|
||||
Name: "bob",
|
||||
Endpoints: []string{"/api/foo", "/api/bar"},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "comment lines ignored and multiple entries",
|
||||
file: "# header\n\nssh-ed25519 " + validB64 + " user1\nssh-ed25519 " + validB64 + " user2 /api/x\n",
|
||||
want: map[string]*KeyPermission{
|
||||
validB64: {
|
||||
Name: "user1",
|
||||
Endpoints: []string{"*"},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "three entries variety",
|
||||
file: "ssh-ed25519 " + validB64 + "\nssh-ed25519 " + validB64 + " alice /api/a,/api/b\nssh-ed25519 " + validB64 + " bob /api/c\n",
|
||||
want: map[string]*KeyPermission{
|
||||
validB64: {
|
||||
Name: "alice",
|
||||
Endpoints: []string{"*"},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "two entries w/ wildcard",
|
||||
file: "ssh-ed25519 " + validB64 + " alice /api/a\n* * * /api/b\n",
|
||||
want: map[string]*KeyPermission{
|
||||
validB64: {
|
||||
Name: "alice",
|
||||
Endpoints: []string{"/api/a"},
|
||||
},
|
||||
"*": {
|
||||
Name: "default",
|
||||
Endpoints: []string{"/api/b"},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "tags for everyone",
|
||||
file: "* * * /api/tags",
|
||||
want: map[string]*KeyPermission{
|
||||
"*": {
|
||||
Name: "default",
|
||||
Endpoints: []string{"/api/tags"},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "default name",
|
||||
file: "* * somename",
|
||||
want: map[string]*KeyPermission{
|
||||
"*": {
|
||||
Name: "somename",
|
||||
Endpoints: []string{"*"},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "unsupported key type",
|
||||
file: "ssh-rsa AAAAB3Nza...\n",
|
||||
want: map[string]*KeyPermission{},
|
||||
},
|
||||
{
|
||||
name: "bad base64",
|
||||
file: "ssh-ed25519 invalid@@@\n",
|
||||
want: map[string]*KeyPermission{},
|
||||
},
|
||||
{
|
||||
name: "just an asterix",
|
||||
file: "*\n",
|
||||
want: map[string]*KeyPermission{},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
perms := NewAPIPermissions()
|
||||
err := perms.parse(bytes.NewBufferString(tc.file))
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if len(perms.permissions) != len(tc.want) {
|
||||
t.Fatalf("got %d entries, want %d", len(perms.permissions), len(tc.want))
|
||||
}
|
||||
if !reflect.DeepEqual(perms.permissions, tc.want) {
|
||||
t.Errorf("got %+v, want %+v", perms.permissions, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
cmd/cmd.go (20 lines changed)
@@ -583,13 +583,12 @@ func ListRunningHandler(cmd *cobra.Command, args []string) error {
|
||||
} else {
|
||||
until = format.HumanTime(m.ExpiresAt, "Never")
|
||||
}
|
||||
ctxStr := strconv.Itoa(m.ContextLength)
|
||||
data = append(data, []string{m.Name, m.Digest[:12], format.HumanBytes(m.Size), procStr, ctxStr, until})
|
||||
data = append(data, []string{m.Name, m.Digest[:12], format.HumanBytes(m.Size), procStr, until})
|
||||
}
|
||||
}
|
||||
|
||||
table := tablewriter.NewWriter(os.Stdout)
|
||||
table.SetHeader([]string{"NAME", "ID", "SIZE", "PROCESSOR", "CONTEXT", "UNTIL"})
|
||||
table.SetHeader([]string{"NAME", "ID", "SIZE", "PROCESSOR", "UNTIL"})
|
||||
table.SetHeaderAlignment(tablewriter.ALIGN_LEFT)
|
||||
table.SetAlignment(tablewriter.ALIGN_LEFT)
|
||||
table.SetHeaderLine(false)
|
||||
@@ -1080,11 +1079,10 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
|
||||
var state *displayResponseState = &displayResponseState{}
|
||||
var latest api.ChatResponse
|
||||
var fullResponse strings.Builder
|
||||
var role string
|
||||
var thinkTagOpened bool = false
|
||||
var thinkTagClosed bool = false
|
||||
|
||||
role := "assistant"
|
||||
|
||||
fn := func(response api.ChatResponse) error {
|
||||
if response.Message.Content != "" || !opts.HideThinking {
|
||||
p.StopAndClear()
|
||||
@@ -1137,14 +1135,6 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
|
||||
if errors.Is(err, context.Canceled) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// this error should ideally be wrapped properly by the client
|
||||
if strings.Contains(err.Error(), "upstream error") {
|
||||
p.StopAndClear()
|
||||
fmt.Println("An error occurred while processing your message. Please try again.")
|
||||
fmt.Println()
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -1426,13 +1416,13 @@ func NewCLI() *cobra.Command {
|
||||
|
||||
createCmd := &cobra.Command{
|
||||
Use: "create MODEL",
|
||||
Short: "Create a model",
|
||||
Short: "Create a model from a Modelfile",
|
||||
Args: cobra.ExactArgs(1),
|
||||
PreRunE: checkServerHeartbeat,
|
||||
RunE: CreateHandler,
|
||||
}
|
||||
|
||||
createCmd.Flags().StringP("file", "f", "", "Name of the Modelfile (default \"Modelfile\")")
|
||||
createCmd.Flags().StringP("file", "f", "", "Name of the Modelfile (default \"Modelfile\"")
|
||||
createCmd.Flags().StringP("quantize", "q", "", "Quantize model to this level (e.g. q4_K_M)")
|
||||
|
||||
showCmd := &cobra.Command{
|
||||
|
||||
@@ -385,21 +385,18 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
|
||||
case "modelfile":
|
||||
fmt.Println(resp.Modelfile)
|
||||
case "parameters":
|
||||
fmt.Println("Model defined parameters:")
|
||||
if resp.Parameters == "" {
|
||||
fmt.Println(" No additional parameters were specified for this model.")
|
||||
fmt.Println("No parameters were specified for this model.")
|
||||
} else {
|
||||
for _, l := range strings.Split(resp.Parameters, "\n") {
|
||||
fmt.Printf(" %s\n", l)
|
||||
if len(opts.Options) > 0 {
|
||||
fmt.Println("User defined parameters:")
|
||||
for k, v := range opts.Options {
|
||||
fmt.Printf("%-*s %v\n", 30, k, v)
|
||||
}
|
||||
fmt.Println()
|
||||
}
|
||||
}
|
||||
fmt.Println()
|
||||
if len(opts.Options) > 0 {
|
||||
fmt.Println("User defined parameters:")
|
||||
for k, v := range opts.Options {
|
||||
fmt.Printf(" %-*s %v\n", 30, k, v)
|
||||
}
|
||||
fmt.Println()
|
||||
fmt.Println("Model defined parameters:")
|
||||
fmt.Println(resp.Parameters)
|
||||
}
|
||||
case "system":
|
||||
switch {
|
||||
|
||||
@@ -190,8 +190,6 @@ func ConvertModel(fsys fs.FS, f *os.File) error {
|
||||
conv = &gemma2Model{}
|
||||
case "Gemma3ForCausalLM", "Gemma3ForConditionalGeneration":
|
||||
conv = &gemma3Model{Architecture: p.Architectures[0]}
|
||||
case "Gemma3nForConditionalGeneration":
|
||||
conv = &gemma3nModel{}
|
||||
case "Phi3ForCausalLM":
|
||||
conv = &phi3Model{}
|
||||
case "Qwen2ForCausalLM":
|
||||
|
||||
@@ -1,165 +0,0 @@
|
||||
package convert
|
||||
|
||||
import (
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
"github.com/pdevine/tensor"
|
||||
"github.com/pdevine/tensor/native"
|
||||
"gonum.org/v1/gonum/stat/distuv"
|
||||
)
|
||||
|
||||
type gemma3nModel struct {
|
||||
ModelParameters
|
||||
|
||||
TextModel struct {
|
||||
ActivationSparsityPattern []float32 `json:"activation_sparsity_pattern"`
|
||||
AltupActiveIdx uint32 `json:"altup_active_idx"`
|
||||
AltupCoefClip float32 `json:"altup_coef_clip"`
|
||||
AltupCorrectScale bool `json:"altup_correct_scale"`
|
||||
AltupLRMultiplier float32 `json:"altup_lr_multiplier"`
|
||||
AltupNumInputs uint32 `json:"altup_num_inputs"`
|
||||
HeadDim uint32 `json:"head_dim"`
|
||||
HiddenSize uint32 `json:"hidden_size"`
|
||||
HiddenSizePerLayerInput uint32 `json:"hidden_size_per_layer_input"`
|
||||
IntermediateSize uint32 `json:"intermediate_size"`
|
||||
MaxPositionEmbeddings uint32 `json:"max_position_embeddings"`
|
||||
NumAttentionHeads uint32 `json:"num_attention_heads"`
|
||||
NumHiddenLayers uint32 `json:"num_hidden_layers"`
|
||||
NumKeyValueHeads uint32 `json:"num_key_value_heads"`
|
||||
NumKVSharedLayers uint32 `json:"num_kv_shared_layers"`
|
||||
RMSNormEPS float32 `json:"rms_norm_eps"`
|
||||
RopeLocalBaseFreq float32 `json:"rope_local_base_freq"`
|
||||
RopeTheta float32 `json:"rope_theta"`
|
||||
SlidingWindow uint32 `json:"sliding_window"`
|
||||
LayerTypes []string `json:"layer_types"`
|
||||
} `json:"text_config"`
|
||||
VisionModel struct{} `json:"vision_config"`
|
||||
}
|
||||
|
||||
func (m *gemma3nModel) KV(t *Tokenizer) ggml.KV {
|
||||
kv := m.ModelParameters.KV(t)
|
||||
kv["general.architecture"] = "gemma3n"
|
||||
kv["gemma3n.activation_sparsity_scale"] = slices.Collect(func(yield func(float32) bool) {
|
||||
norm := distuv.Normal{Mu: 0, Sigma: 1}
|
||||
for _, v := range m.TextModel.ActivationSparsityPattern {
|
||||
if !yield(float32(norm.Quantile(float64(v)))) {
|
||||
break
|
||||
}
|
||||
}
|
||||
})
|
||||
kv["gemma3n.altup.active_idx"] = m.TextModel.AltupActiveIdx
|
||||
kv["gemma3n.altup.correct_scale"] = m.TextModel.AltupCorrectScale
|
||||
kv["gemma3n.altup.lr_multiplier"] = m.TextModel.AltupLRMultiplier
|
||||
kv["gemma3n.altup.num_inputs"] = m.TextModel.AltupNumInputs
|
||||
kv["gemma3n.attention.head_count_kv"] = m.TextModel.NumKeyValueHeads
|
||||
kv["gemma3n.attention.head_count"] = m.TextModel.NumAttentionHeads
|
||||
kv["gemma3n.attention.layer_norm_rms_epsilon"] = m.TextModel.RMSNormEPS
|
||||
kv["gemma3n.attention.sliding_window"] = m.TextModel.SlidingWindow
|
||||
kv["gemma3n.attention.sliding_window_pattern"] = slices.Collect(func(yield func(bool) bool) {
|
||||
for _, t := range m.TextModel.LayerTypes {
|
||||
if !yield(t == "sliding_attention") {
|
||||
break
|
||||
}
|
||||
}
|
||||
})
|
||||
kv["gemma3n.attention.shared_kv_layers"] = m.TextModel.NumKVSharedLayers
|
||||
kv["gemma3n.block_count"] = m.TextModel.NumHiddenLayers
|
||||
kv["gemma3n.context_length"] = m.TextModel.MaxPositionEmbeddings
|
||||
kv["gemma3n.embedding_length_per_layer_input"] = m.TextModel.HiddenSizePerLayerInput
|
||||
kv["gemma3n.embedding_length"] = m.TextModel.HiddenSize
|
||||
kv["gemma3n.feed_forward_length"] = m.TextModel.IntermediateSize
|
||||
kv["gemma3n.head_dim"] = m.TextModel.HeadDim
|
||||
kv["gemma3n.rope.freq_base_local"] = m.TextModel.RopeLocalBaseFreq
|
||||
kv["gemma3n.rope.freq_base"] = m.TextModel.RopeTheta
|
||||
return kv
|
||||
}
|
||||
|
||||
func (m *gemma3nModel) Tensors(ts []Tensor) []*ggml.Tensor {
|
||||
out, ts := mergeTensors(ts,
|
||||
merge{"altup_proj.*.weight", "altup_proj.weight"},
|
||||
merge{"altup_unembd_proj.*.weight", "altup_unembd_proj.weight"},
|
||||
)
|
||||
|
||||
for _, t := range ts {
|
||||
switch {
|
||||
case strings.Contains(t.Name(), "audio_tower"),
|
||||
strings.Contains(t.Name(), "embed_audio"),
|
||||
strings.Contains(t.Name(), "vision_tower"),
|
||||
strings.Contains(t.Name(), "embed_vision"):
|
||||
// TODO: handle audio and vision towers
|
||||
continue
|
||||
case strings.Contains(t.Name(), "altup_predict_coef"),
|
||||
strings.Contains(t.Name(), "altup_correct_coef"):
|
||||
if m.TextModel.AltupCoefClip > 0 {
|
||||
t.SetRepacker(func(name string, data []float32, shape []uint64) (_ []float32, err error) {
|
||||
dims := make([]int, len(shape))
|
||||
for i := range shape {
|
||||
dims[i] = int(shape[i])
|
||||
}
|
||||
|
||||
var t tensor.Tensor = tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data))
|
||||
|
||||
t, err = tensor.Clamp(t, -m.TextModel.AltupCoefClip, m.TextModel.AltupCoefClip)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := t.Reshape(t.Shape().TotalSize()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return native.VectorF32(t.(*tensor.Dense))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: t.Name(),
|
||||
Kind: t.Kind(),
|
||||
Shape: t.Shape(),
|
||||
WriterTo: t,
|
||||
})
|
||||
}
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
func (m *gemma3nModel) Replacements() []string {
|
||||
return []string{
|
||||
"model.language_model.embed_tokens_per_layer", "per_layer_token_embd",
|
||||
"model.language_model.embed_tokens", "token_embd",
|
||||
"model.language_model.per_layer_model_projection", "per_layer_model_proj",
|
||||
"model.language_model.per_layer_projection_norm", "per_layer_proj_norm", "model.language_model.altup_projections", "altup_proj",
|
||||
"model.language_model.altup_unembed_projections", "altup_unembd_proj",
|
||||
"model.language_model.norm", "output_norm",
|
||||
"model.language_model.layers", "blk",
|
||||
|
||||
"input_layernorm", "attn_norm",
|
||||
"self_attn.q_proj", "attn_q",
|
||||
"self_attn.q_norm", "attn_q_norm",
|
||||
"self_attn.k_proj", "attn_k",
|
||||
"self_attn.k_norm", "attn_k_norm",
|
||||
"self_attn.v_proj", "attn_v",
|
||||
"self_attn.o_proj", "attn_output",
|
||||
"post_attention_layernorm", "post_attention_norm",
|
||||
"pre_feedforward_layernorm", "ffn_norm",
|
||||
"mlp.gate_proj", "ffn_gate",
|
||||
"mlp.up_proj", "ffn_up",
|
||||
"mlp.down_proj", "ffn_down",
|
||||
"post_feedforward_layernorm", "post_ffw_norm",
|
||||
"per_layer_input_gate", "inp_gate",
|
||||
"per_layer_projection", "proj",
|
||||
"post_per_layer_input_norm", "post_norm",
|
||||
"altup.", "altup_",
|
||||
"modality_router", "router",
|
||||
"prediction_coefs", "predict_coef",
|
||||
"correction_coefs", "correct_coef",
|
||||
"correct_output_scale", "correct_scale.weight",
|
||||
"laurel.", "laurel_",
|
||||
"linear_left", "l",
|
||||
"linear_right", "r",
|
||||
"post_laurel_norm", "post_norm",
|
||||
}
|
||||
}
|
||||
@@ -2,6 +2,9 @@ package convert
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
)
|
||||
@@ -27,38 +30,65 @@ func (p *mixtralModel) KV(t *Tokenizer) ggml.KV {
|
||||
}
|
||||
|
||||
func (p *mixtralModel) Tensors(ts []Tensor) []*ggml.Tensor {
|
||||
merges := make([]merge, 0, p.NumHiddenLayers*6)
|
||||
for i := range p.NumHiddenLayers {
|
||||
merges = append(merges, merge{
|
||||
fmt.Sprintf("blk.%d.*.w1.weight", i),
|
||||
fmt.Sprintf("blk.%d.ffn_gate_exps.weight", i),
|
||||
}, merge{
|
||||
fmt.Sprintf("blk.%d.*.w1.bias", i),
|
||||
fmt.Sprintf("blk.%d.ffn_gate_exps.bias", i),
|
||||
}, merge{
|
||||
fmt.Sprintf("blk.%d.*.w2.weight", i),
|
||||
fmt.Sprintf("blk.%d.ffn_up_exps.weight", i),
|
||||
}, merge{
|
||||
fmt.Sprintf("blk.%d.*.w2.bias", i),
|
||||
fmt.Sprintf("blk.%d.ffn_up_exps.bias", i),
|
||||
}, merge{
|
||||
fmt.Sprintf("blk.%d.*.w3.weight", i),
|
||||
fmt.Sprintf("blk.%d.ffn_down_exps.weight", i),
|
||||
}, merge{
|
||||
fmt.Sprintf("blk.%d.*.w3.bias", i),
|
||||
fmt.Sprintf("blk.%d.ffn_down_exps.bias", i),
|
||||
oldnew := []string{
|
||||
"model.layers", "blk",
|
||||
"w1", "ffn_gate_exps",
|
||||
"w2", "ffn_down_exps",
|
||||
"w3", "ffn_up_exps",
|
||||
}
|
||||
|
||||
for i := range p.NumLocalExperts {
|
||||
oldnew = append(oldnew, fmt.Sprintf(".block_sparse_moe.experts.%d.", i), ".")
|
||||
}
|
||||
|
||||
// group experts of the same layer (model.layers.%d) and type (w[123]) into a single tensor
|
||||
namer := strings.NewReplacer(oldnew...)
|
||||
experts := make(map[string]experts)
|
||||
|
||||
// merge experts into a single tensor while removing them from ts
|
||||
ts = slices.DeleteFunc(ts, func(t Tensor) bool {
|
||||
if !strings.Contains(t.Name(), ".block_sparse_moe.experts.") {
|
||||
return false
|
||||
}
|
||||
|
||||
name := namer.Replace(t.Name())
|
||||
experts[name] = append(experts[name], t)
|
||||
return true
|
||||
})
|
||||
|
||||
var out []*ggml.Tensor
|
||||
for n, e := range experts {
|
||||
// TODO(mxyng): sanity check experts
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: n,
|
||||
Kind: e[0].Kind(),
|
||||
Shape: append([]uint64{uint64(len(e))}, e[0].Shape()...),
|
||||
WriterTo: e,
|
||||
})
|
||||
}
|
||||
|
||||
out, ts := mergeTensors(ts, merges...)
|
||||
return append(out, p.llamaModel.Tensors(ts)...)
|
||||
}
|
||||
|
||||
func (p *mixtralModel) Replacements() []string {
|
||||
return append(
|
||||
p.llamaModel.Replacements(),
|
||||
"model.layers", "blk",
|
||||
"block_sparse_moe.gate", "ffn_gate_inp",
|
||||
"block_sparse_moe.experts.", ".",
|
||||
)
|
||||
}
|
||||
|
||||
type experts []Tensor
|
||||
|
||||
func (e experts) WriteTo(w io.Writer) (int64, error) {
|
||||
// TODO(mxyng): experts _should_ be numerically sorted by expert but this should check
|
||||
for _, t := range e {
|
||||
// the canonical merged experts tensor stacks all experts along a new, 0 axis,
|
||||
// e.g. `tensor.Stack(0, e[0], e[1:]...)`, which requires allocating temporary buffers
|
||||
// this accomplishes the same thing by writing each expert tensor in sequence
|
||||
if _, err := t.WriteTo(w); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
@@ -11,13 +11,14 @@ import (
|
||||
"io"
|
||||
"io/fs"
|
||||
"log/slog"
|
||||
"maps"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"golang.org/x/exp/maps"
|
||||
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
)
|
||||
|
||||
@@ -136,7 +137,9 @@ func TestConvertModel(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
for _, k := range slices.Sorted(maps.Keys(expect)) {
|
||||
keys := maps.Keys(expect)
|
||||
slices.Sort(keys)
|
||||
for _, k := range keys {
|
||||
if v, ok := actual[k]; !ok {
|
||||
t.Errorf("missing %s", k)
|
||||
} else if v != expect[k] {
|
||||
@@ -340,7 +343,9 @@ func TestConvertAdapter(t *testing.T) {
|
||||
|
||||
actual := generateResultsJSON(t, r, m.KV(), m.Tensors())
|
||||
|
||||
for _, k := range slices.Sorted(maps.Keys(c.Expected)) {
|
||||
keys := maps.Keys(c.Expected)
|
||||
slices.Sort(keys)
|
||||
for _, k := range keys {
|
||||
if v, ok := actual[k]; !ok {
|
||||
t.Errorf("missing %s", k)
|
||||
} else if v != c.Expected[k] {
|
||||
|
||||
@@ -8,12 +8,12 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"io/fs"
|
||||
"maps"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/d4l3k/go-bfloat16"
|
||||
"github.com/x448/float16"
|
||||
"golang.org/x/exp/maps"
|
||||
)
|
||||
|
||||
type safetensorMetadata struct {
|
||||
@@ -46,7 +46,8 @@ func parseSafetensors(fsys fs.FS, replacer *strings.Replacer, ps ...string) ([]T
|
||||
return nil, err
|
||||
}
|
||||
|
||||
keys := slices.Sorted(maps.Keys(headers))
|
||||
keys := maps.Keys(headers)
|
||||
slices.Sort(keys)
|
||||
|
||||
names := make(map[string]struct{}, len(keys))
|
||||
|
||||
|
||||
@@ -2,9 +2,7 @@ package convert
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"io"
|
||||
"iter"
|
||||
"path"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
@@ -76,54 +74,3 @@ func splitDim(t Tensor, dim int, splits ...split) iter.Seq[*ggml.Tensor] {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type merge struct {
|
||||
pattern, name string
|
||||
}
|
||||
|
||||
// mergeTensors merges tensors that match a given pattern into a single tensor.
|
||||
func mergeTensors(unmatched []Tensor, merges ...merge) (out []*ggml.Tensor, _ []Tensor) {
|
||||
var matched []Tensor
|
||||
for i := range merges {
|
||||
matched, unmatched = slicesSplitFunc(unmatched, func(t Tensor) bool {
|
||||
matched, _ := path.Match(merges[i].pattern, t.Name())
|
||||
return matched
|
||||
})
|
||||
|
||||
if len(matched) > 0 {
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: merges[i].name,
|
||||
Kind: matched[0].Kind(),
|
||||
Shape: append([]uint64{uint64(len(matched))}, matched[0].Shape()...),
|
||||
WriterTo: mergeGroup(matched),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return out, unmatched
|
||||
}
|
||||
|
||||
// slicesSplitFunc splits a slice into two slices based on a predicate function.
|
||||
func slicesSplitFunc[S ~[]E, E comparable](s S, fn func(e E) bool) (matched, unmatched S) {
|
||||
for _, e := range s {
|
||||
if fn(e) {
|
||||
matched = append(matched, e)
|
||||
} else {
|
||||
unmatched = append(unmatched, e)
|
||||
}
|
||||
}
|
||||
|
||||
return matched, unmatched
|
||||
}
|
||||
|
||||
type mergeGroup []Tensor
|
||||
|
||||
func (g mergeGroup) WriteTo(w io.Writer) (int64, error) {
|
||||
for _, t := range g {
|
||||
if _, err := t.WriteTo(w); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
@@ -9,8 +9,6 @@ import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
"github.com/pdevine/tensor"
|
||||
)
|
||||
|
||||
@@ -304,99 +302,3 @@ func TestSplitDim(t *testing.T) {
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestMerge(t *testing.T) {
|
||||
unmatched := []Tensor{
|
||||
&fakeTensor{
|
||||
name: "a.0.b",
|
||||
shape: []uint64{5, 2},
|
||||
data: []float32{10, 11, 12, 13, 14, 15, 16, 17, 18, 19},
|
||||
},
|
||||
&fakeTensor{
|
||||
name: "a.1.b",
|
||||
shape: []uint64{5, 2},
|
||||
data: []float32{20, 21, 22, 23, 24, 25, 26, 27, 28, 29},
|
||||
},
|
||||
&fakeTensor{
|
||||
name: "c.0.d",
|
||||
shape: []uint64{5, 2},
|
||||
data: []float32{30, 31, 32, 33, 34, 35, 36, 37, 38, 39},
|
||||
},
|
||||
&fakeTensor{
|
||||
name: "c.1.d",
|
||||
shape: []uint64{5, 2},
|
||||
data: []float32{40, 41, 42, 43, 44, 45, 46, 47, 48, 49},
|
||||
},
|
||||
&fakeTensor{
|
||||
name: "e.0.f",
|
||||
shape: []uint64{5, 2},
|
||||
data: []float32{50, 51, 52, 53, 54, 55, 56, 57, 58, 59},
|
||||
},
|
||||
}
|
||||
|
||||
checkMatched := func(t *testing.T, n int, matched []*ggml.Tensor) {
|
||||
for i := range n {
|
||||
got := matched[i]
|
||||
if diff := cmp.Diff([]uint64{2, 5, 2}, got.Shape); diff != "" {
|
||||
t.Errorf("unexpected (-want +got):\n%s", diff)
|
||||
}
|
||||
|
||||
var b bytes.Buffer
|
||||
if _, err := got.WriteTo(&b); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
f32s := make([]float32, 20)
|
||||
if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
offset := 10 + (i * 20)
|
||||
want := make([]float32, 20)
|
||||
for j := range 20 {
|
||||
want[j] = float32(offset + j)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(want, f32s); diff != "" {
|
||||
t.Errorf("unexpected data (-want +got):\n%s", diff)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
t.Run("single merge", func(t *testing.T) {
|
||||
matched, unmatched := mergeTensors(unmatched, merge{"a.*.b", "a.b"})
|
||||
if len(unmatched) != 3 {
|
||||
t.Error("expected 3 remaining tensors, got", len(unmatched))
|
||||
}
|
||||
|
||||
if len(matched) != 1 {
|
||||
t.Error("expected 1 merged tensor, got", len(matched))
|
||||
}
|
||||
|
||||
checkMatched(t, 1, matched)
|
||||
})
|
||||
|
||||
t.Run("multiple merges", func(t *testing.T) {
|
||||
matched, unmatched := mergeTensors(unmatched, merge{"a.*.b", "a.b"}, merge{"c.*.d", "c.d"})
|
||||
if len(unmatched) != 1 {
|
||||
t.Error("expected 1 remaining tensors, got", len(unmatched))
|
||||
}
|
||||
|
||||
if len(matched) != 2 {
|
||||
t.Error("expected 2 merged tensor, got", len(matched))
|
||||
}
|
||||
|
||||
checkMatched(t, 2, matched)
|
||||
})
|
||||
|
||||
t.Run("no match", func(t *testing.T) {
|
||||
matched, unmatched := mergeTensors(unmatched, merge{"x.*.y", "x.y"})
|
||||
if len(unmatched) != 5 {
|
||||
t.Error("expected 5 remaining tensors, got", len(unmatched))
|
||||
}
|
||||
|
||||
if len(matched) != 0 {
|
||||
t.Error("expected no merged tensors, got", len(matched))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -8,10 +8,11 @@ import (
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"log/slog"
|
||||
"maps"
|
||||
"os"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/exp/maps"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -259,8 +260,11 @@ func parseVocabularyFromTokenizer(fsys fs.FS) (*Vocabulary, error) {
|
||||
tokens[token.ID] = token
|
||||
}
|
||||
|
||||
keys := maps.Keys(tokens)
|
||||
slices.Sort(keys)
|
||||
|
||||
v := Vocabulary{Model: "gpt2"}
|
||||
for _, k := range slices.Sorted(maps.Keys(tokens)) {
|
||||
for _, k := range keys {
|
||||
token := tokens[k]
|
||||
v.Tokens = append(v.Tokens, token.Content)
|
||||
v.Scores = append(v.Scores, float32(token.ID))
|
||||
|
||||
@@ -3,7 +3,6 @@
|
||||
package discover
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os"
|
||||
"regexp"
|
||||
@@ -56,13 +55,10 @@ func cudaVariant(gpuInfo CudaGPUInfo) string {
|
||||
}
|
||||
}
|
||||
}
|
||||
return "sbsa"
|
||||
}
|
||||
|
||||
// driver 12.0 has problems with the cuda v12 library, so run v11 on those older drivers
|
||||
if gpuInfo.DriverMajor < 12 || (gpuInfo.DriverMajor == 12 && gpuInfo.DriverMinor == 0) {
|
||||
// The detected driver is older than Feb 2023
|
||||
slog.Warn("old CUDA driver detected - please upgrade to a newer driver", "version", fmt.Sprintf("%d.%d", gpuInfo.DriverMajor, gpuInfo.DriverMinor))
|
||||
return "v11"
|
||||
}
|
||||
return "v12"
|
||||
|
||||
@@ -12,7 +12,7 @@ import (
|
||||
// '../lib/ollama' on Linux and the executable's directory on macOS
|
||||
// note: distribution builds, additional GPU-specific libraries are
|
||||
// found in subdirectories of the returned path, such as
|
||||
// 'cuda_v12', 'rocm', etc.
|
||||
// 'cuda_v11', 'cuda_v12', 'rocm', etc.
|
||||
var LibOllamaPath string = func() string {
|
||||
exe, err := os.Executable()
|
||||
if err != nil {
|
||||
|
||||
@@ -4,7 +4,6 @@
|
||||
* [Quickstart](../README.md#quickstart)
|
||||
* [Examples](./examples.md)
|
||||
* [Importing models](./import.md)
|
||||
* [MacOS Documentation](./macos.md)
|
||||
* [Linux Documentation](./linux.md)
|
||||
* [Windows Documentation](./windows.md)
|
||||
* [Docker Documentation](./docker.md)
|
||||
|
||||
244
docs/api.md
@@ -500,30 +500,21 @@ The `message` object has the following fields:
|
||||
- `thinking`: (for thinking models) the model's thinking process
|
||||
- `images` (optional): a list of images to include in the message (for multimodal models such as `llava`)
|
||||
- `tool_calls` (optional): a list of tools in JSON that the model wants to use
|
||||
- `tool_name` (optional): add the name of the tool that was executed to inform the model of the result
|
||||
|
||||
Advanced parameters (optional):
|
||||
|
||||
- `format`: the format to return a response in. Format can be `json` or a JSON schema.
|
||||
- `format`: the format to return a response in. Format can be `json` or a JSON schema.
|
||||
- `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.md#valid-parameters-and-values) such as `temperature`
|
||||
- `stream`: if `false` the response will be returned as a single response object, rather than a stream of objects
|
||||
- `keep_alive`: controls how long the model will stay loaded into memory following the request (default: `5m`)
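
Taken together, the message fields and advanced parameters above map onto a small set of request types. The sketch below is illustrative only: the field names follow this page, but the concrete Go types and the `omitempty` choices are assumptions, not the canonical `ollama/api` definitions.

```go
package main

import (
	"encoding/json"
	"fmt"
)

// Message mirrors the message object documented above; the json tags match
// the field names on this page.
type Message struct {
	Role     string   `json:"role"`
	Content  string   `json:"content"`
	Thinking string   `json:"thinking,omitempty"`
	Images   []string `json:"images,omitempty"`    // base64-encoded images for multimodal models
	ToolName string   `json:"tool_name,omitempty"` // set when role is "tool"
}

// ChatRequest mirrors the request body and the advanced parameters above.
type ChatRequest struct {
	Model     string         `json:"model"`
	Messages  []Message      `json:"messages"`
	Format    any            `json:"format,omitempty"`     // "json" or a JSON schema
	Options   map[string]any `json:"options,omitempty"`    // e.g. temperature
	Stream    *bool          `json:"stream,omitempty"`
	KeepAlive string         `json:"keep_alive,omitempty"` // e.g. "5m"
}

func main() {
	stream := false
	req := ChatRequest{
		Model:     "llama3.2",
		Messages:  []Message{{Role: "user", Content: "why is the sky blue?"}},
		Options:   map[string]any{"temperature": 0.7},
		Stream:    &stream,
		KeepAlive: "10m",
	}

	// this is the JSON body that would be POSTed to /api/chat
	body, _ := json.MarshalIndent(req, "", "  ")
	fmt.Println(string(body))
}
```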
|
||||
|
||||
### Tool calling
|
||||
|
||||
Tool calling is supported by providing a list of tools in the `tools` parameter. The model will generate a response that includes a list of tool calls. See the [Chat request (Streaming with tools)](#chat-request-streaming-with-tools) example below.
|
||||
|
||||
Models can also explain the result of the tool call in the response. See the [Chat request (With history, with tools)](#chat-request-with-history-with-tools) example below.
|
||||
|
||||
[See models with tool calling capabilities](https://ollama.com/search?c=tool).
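
As a rough illustration of that flow, the Go sketch below (standard library only, endpoint and field names taken from this page) sends a request with one tool, executes any `tool_calls` it gets back using a placeholder `localWeather` helper, and returns the result to the model as a `tool` role message with `tool_name` set. It assumes a tool-capable model is available locally; error handling is kept minimal.

```go
package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"net/http"
)

// chat POSTs a request body to the local /api/chat endpoint and decodes the reply.
func chat(req map[string]any) (map[string]any, error) {
	body, err := json.Marshal(req)
	if err != nil {
		return nil, err
	}
	resp, err := http.Post("http://localhost:11434/api/chat", "application/json", bytes.NewReader(body))
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()

	var out map[string]any
	if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
		return nil, err
	}
	return out, nil
}

// localWeather stands in for whatever the tool actually does.
func localWeather(city string) string { return "11 degrees celsius in " + city }

func main() {
	tools := []map[string]any{{
		"type": "function",
		"function": map[string]any{
			"name":        "get_weather",
			"description": "Get the weather in a given city",
			"parameters": map[string]any{
				"type":       "object",
				"properties": map[string]any{"city": map[string]any{"type": "string"}},
				"required":   []string{"city"},
			},
		},
	}}

	messages := []map[string]any{{"role": "user", "content": "what is the weather in tokyo?"}}

	first, err := chat(map[string]any{"model": "llama3.2", "messages": messages, "tools": tools, "stream": false})
	if err != nil {
		panic(err)
	}

	msg := first["message"].(map[string]any)
	messages = append(messages, msg) // keep the assistant turn in the history

	// run each requested tool and feed the result back as a "tool" message
	calls, _ := msg["tool_calls"].([]any)
	for _, tc := range calls {
		fn := tc.(map[string]any)["function"].(map[string]any)
		city := fn["arguments"].(map[string]any)["city"].(string)
		messages = append(messages, map[string]any{
			"role":      "tool",
			"content":   localWeather(city),
			"tool_name": fn["name"].(string),
		})
	}

	second, err := chat(map[string]any{"model": "llama3.2", "messages": messages, "tools": tools, "stream": false})
	if err != nil {
		panic(err)
	}
	fmt.Println(second["message"].(map[string]any)["content"])
}
```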
|
||||
|
||||
### Structured outputs
|
||||
|
||||
Structured outputs are supported by providing a JSON schema in the `format` parameter. The model will generate a response that matches the schema. See the [Chat request (Structured outputs)](#chat-request-structured-outputs) example below.
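
A minimal Go sketch of the same idea, assuming the server is running locally: the JSON schema goes in `format`, and the `content` of the returned message is itself JSON text that can be unmarshalled into a matching struct. The `Country` shape here is just an example.

```go
package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"net/http"
)

// Country is the shape we want the model to fill in.
type Country struct {
	Name       string  `json:"name"`
	Capital    string  `json:"capital"`
	Population float64 `json:"population"`
}

func main() {
	// the JSON schema passed via "format" constrains the response to this shape
	schema := map[string]any{
		"type": "object",
		"properties": map[string]any{
			"name":       map[string]any{"type": "string"},
			"capital":    map[string]any{"type": "string"},
			"population": map[string]any{"type": "number"},
		},
		"required": []string{"name", "capital", "population"},
	}

	reqBody, _ := json.Marshal(map[string]any{
		"model":    "llama3.2",
		"messages": []map[string]any{{"role": "user", "content": "Tell me about Canada."}},
		"format":   schema,
		"stream":   false,
	})

	resp, err := http.Post("http://localhost:11434/api/chat", "application/json", bytes.NewReader(reqBody))
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	var out struct {
		Message struct {
			Content string `json:"content"`
		} `json:"message"`
	}
	if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
		panic(err)
	}

	// the content field holds JSON text matching the schema above
	var c Country
	if err := json.Unmarshal([]byte(out.Message.Content), &c); err != nil {
		panic(err)
	}
	fmt.Printf("%+v\n", c)
}
```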
|
||||
|
||||
### Examples
|
||||
|
||||
#### Chat request (Streaming)
|
||||
#### Chat Request (Streaming)
|
||||
|
||||
##### Request
|
||||
|
||||
@@ -578,88 +569,6 @@ Final response:
|
||||
}
|
||||
```
|
||||
|
||||
#### Chat request (Streaming with tools)
|
||||
|
||||
##### Request
|
||||
|
||||
```shell
|
||||
curl http://localhost:11434/api/chat -d '{
|
||||
"model": "llama3.2",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "what is the weather in tokyo?"
|
||||
}
|
||||
],
|
||||
"tools": [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"description": "Get the weather in a given city",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"city": {
|
||||
"type": "string",
|
||||
"description": "The city to get the weather for"
|
||||
}
|
||||
},
|
||||
"required": ["city"]
|
||||
}
|
||||
}
|
||||
}
|
||||
],
|
||||
"stream": true
|
||||
}'
|
||||
```
|
||||
|
||||
##### Response
|
||||
|
||||
A stream of JSON objects is returned:
|
||||
```json
|
||||
{
|
||||
"model": "llama3.2",
|
||||
"created_at": "2025-07-07T20:22:19.184789Z",
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": "",
|
||||
"tool_calls": [
|
||||
{
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"arguments": {
|
||||
"city": "Tokyo"
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
},
|
||||
"done": false
|
||||
}
|
||||
```
|
||||
|
||||
Final response:
|
||||
|
||||
```json
|
||||
{
|
||||
"model":"llama3.2",
|
||||
"created_at":"2025-07-07T20:22:19.19314Z",
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": ""
|
||||
},
|
||||
"done_reason": "stop",
|
||||
"done": true,
|
||||
"total_duration": 182242375,
|
||||
"load_duration": 41295167,
|
||||
"prompt_eval_count": 169,
|
||||
"prompt_eval_duration": 24573166,
|
||||
"eval_count": 15,
|
||||
"eval_duration": 115959084
|
||||
}
|
||||
```
|
||||
|
||||
#### Chat request (No streaming)
|
||||
|
||||
##### Request
|
||||
@@ -697,74 +606,6 @@ curl http://localhost:11434/api/chat -d '{
|
||||
}
|
||||
```
|
||||
|
||||
#### Chat request (No streaming, with tools)
|
||||
|
||||
##### Request
|
||||
|
||||
|
||||
```shell
|
||||
curl http://localhost:11434/api/chat -d '{
|
||||
"model": "llama3.2",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "what is the weather in tokyo?"
|
||||
}
|
||||
],
|
||||
"tools": [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"description": "Get the weather in a given city",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"city": {
|
||||
"type": "string",
|
||||
"description": "The city to get the weather for"
|
||||
}
|
||||
},
|
||||
"required": ["city"]
|
||||
}
|
||||
}
|
||||
}
|
||||
],
|
||||
"stream": false
|
||||
}'
|
||||
```
|
||||
|
||||
##### Response
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "llama3.2",
|
||||
"created_at": "2025-07-07T20:32:53.844124Z",
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": "",
|
||||
"tool_calls": [
|
||||
{
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"arguments": {
|
||||
"city": "Tokyo"
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
},
|
||||
"done_reason": "stop",
|
||||
"done": true,
|
||||
"total_duration": 3244883583,
|
||||
"load_duration": 2969184542,
|
||||
"prompt_eval_count": 169,
|
||||
"prompt_eval_duration": 141656333,
|
||||
"eval_count": 18,
|
||||
"eval_duration": 133293625
|
||||
}
|
||||
```
|
||||
|
||||
#### Chat request (Structured outputs)
|
||||
|
||||
##### Request
|
||||
@@ -871,87 +712,6 @@ Final response:
|
||||
}
|
||||
```
|
||||
|
||||
|
||||
#### Chat request (With history, with tools)
|
||||
|
||||
##### Request
|
||||
|
||||
```shell
|
||||
curl http://localhost:11434/api/chat -d '{
|
||||
"model": "llama3.2",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "what is the weather in Toronto?"
|
||||
},
|
||||
// the message from the model appended to history
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "",
|
||||
"tool_calls": [
|
||||
{
|
||||
"function": {
|
||||
"name": "get_temperature",
|
||||
"arguments": {
|
||||
"city": "Toronto"
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
},
|
||||
// the tool call result appended to history
|
||||
{
|
||||
"role": "tool",
|
||||
"content": "11 degrees celsius",
|
||||
"tool_name": "get_temperature",
|
||||
}
|
||||
],
|
||||
"stream": false,
|
||||
"tools": [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"description": "Get the weather in a given city",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"city": {
|
||||
"type": "string",
|
||||
"description": "The city to get the weather for"
|
||||
}
|
||||
},
|
||||
"required": ["city"]
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}'
|
||||
```
|
||||
|
||||
##### Response
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "llama3.2",
|
||||
"created_at": "2025-07-07T20:43:37.688511Z",
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": "The current temperature in Toronto is 11°C."
|
||||
},
|
||||
"done_reason": "stop",
|
||||
"done": true,
|
||||
"total_duration": 890771750,
|
||||
"load_duration": 707634750,
|
||||
"prompt_eval_count": 94,
|
||||
"prompt_eval_duration": 91703208,
|
||||
"eval_count": 11,
|
||||
"eval_duration": 90282125
|
||||
}
|
||||
|
||||
```
|
||||
|
||||
|
||||
#### Chat request (with images)
|
||||
|
||||
##### Request
|
||||
|
||||
@@ -118,7 +118,7 @@ To run tests, use `go test`:
|
||||
go test ./...
|
||||
```
|
||||
|
||||
> NOTE: In rare circumstances, you may need to change a package using the new
|
||||
> NOTE: In rare cirumstances, you may need to change a package using the new
|
||||
> "synctest" package in go1.24.
|
||||
>
|
||||
> If you do not have the "synctest" package enabled, you will not see build or
|
||||
|
||||
17
docs/faq.md
@@ -292,7 +292,7 @@ If too many requests are sent to the server, it will respond with a 503 error in
|
||||
|
||||
## How does Ollama handle concurrent requests?
|
||||
|
||||
Ollama supports two levels of concurrent processing. If your system has sufficient available memory (system memory when using CPU inference, or VRAM for GPU inference) then multiple models can be loaded at the same time. For a given model, if there is sufficient available memory when the model is loaded, it can be configured to allow parallel request processing.
|
||||
Ollama supports two levels of concurrent processing. If your system has sufficient available memory (system memory when using CPU inference, or VRAM for GPU inference) then multiple models can be loaded at the same time. For a given model, if there is sufficient available memory when the model is loaded, it is configured to allow parallel request processing.
|
||||
|
||||
If there is insufficient available memory to load a new model request while one or more models are already loaded, all new requests will be queued until the new model can be loaded. As prior models become idle, one or more will be unloaded to make room for the new model. Queued requests will be processed in order. When using GPU inference new models must be able to completely fit in VRAM to allow concurrent model loads.
|
||||
|
||||
@@ -301,7 +301,7 @@ Parallel request processing for a given model results in increasing the context
|
||||
The following server settings may be used to adjust how Ollama handles concurrent requests on most platforms:
|
||||
|
||||
- `OLLAMA_MAX_LOADED_MODELS` - The maximum number of models that can be loaded concurrently provided they fit in available memory. The default is 3 * the number of GPUs or 3 for CPU inference.
|
||||
- `OLLAMA_NUM_PARALLEL` - The maximum number of parallel requests each model will process at the same time. The default is 1, and will handle 1 request per model at a time.
|
||||
- `OLLAMA_NUM_PARALLEL` - The maximum number of parallel requests each model will process at the same time. The default will auto-select either 4 or 1 based on available memory.
|
||||
- `OLLAMA_MAX_QUEUE` - The maximum number of requests Ollama will queue when busy before rejecting additional requests. The default is 512
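
These settings are ordinary environment variables, so they can also be applied by whatever process supervises `ollama serve`. A minimal Go sketch under that assumption; the values are purely illustrative, not recommendations:

```go
package main

import (
	"os"
	"os/exec"
)

func main() {
	cmd := exec.Command("ollama", "serve")
	// add the concurrency settings described above to the inherited environment
	cmd.Env = append(os.Environ(),
		"OLLAMA_MAX_LOADED_MODELS=2",
		"OLLAMA_NUM_PARALLEL=4",
		"OLLAMA_MAX_QUEUE=512",
	)
	cmd.Stdout = os.Stdout
	cmd.Stderr = os.Stderr
	if err := cmd.Run(); err != nil {
		panic(err)
	}
}
```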
|
||||
|
||||
Note: Windows with Radeon GPUs currently default to 1 model maximum due to limitations in ROCm v5.7 for available VRAM reporting. Once ROCm v6.2 is available, Windows Radeon will follow the defaults above. You may enable concurrent model loads on Radeon on Windows, but ensure you don't load more models than will fit into your GPUs' VRAM.
|
||||
@@ -333,16 +333,3 @@ The currently available K/V cache quantization types are:
|
||||
How much the cache quantization impacts the model's response quality will depend on the model and the task. Models that have a high GQA count (e.g. Qwen2) may see a larger impact on precision from quantization than models with a low GQA count.
|
||||
|
||||
You may need to experiment with different quantization types to find the best balance between memory usage and quality.
|
||||
|
||||
## How can I stop Ollama from starting when I login to my computer
|
||||
|
||||
Ollama for Windows and macOS register as a login item during installation. You can disable this if you prefer not to have Ollama automatically start. Ollama will respect this setting across upgrades, unless you uninstall the application.
|
||||
|
||||
**Windows**
|
||||
- Remove `%APPDATA%\Microsoft\Windows\Start Menu\Programs\Startup\Ollama.lnk`
|
||||
|
||||
**MacOS Monterey (v12)**
|
||||
- Open `Settings` -> `Users & Groups` -> `Login Items` and find the `Ollama` entry, then click the `-` (minus) to remove
|
||||
|
||||
**MacOS Ventura (v13) and later**
|
||||
- Open `Settings` and search for "Login Items", find the `Ollama` entry under "Allow in the Background", then click the slider to disable.
|
||||
@@ -1,14 +1,12 @@
|
||||
# GPU
|
||||
## Nvidia
|
||||
Ollama supports Nvidia GPUs with compute capability 5.0+ and driver version 531 and newer.
|
||||
Ollama supports Nvidia GPUs with compute capability 5.0+.
|
||||
|
||||
Check your compute compatibility to see if your card is supported:
|
||||
[https://developer.nvidia.com/cuda-gpus](https://developer.nvidia.com/cuda-gpus)
|
||||
|
||||
| Compute Capability | Family | Cards |
|
||||
| ------------------ | ------------------- | ----------------------------------------------------------------------------------------------------------- |
|
||||
| 12.0 | GeForce RTX 50xx | `RTX 5060` `RTX 5060 Ti` `RTX 5070` `RTX 5070 Ti` `RTX 5080` `RTX 5090` |
|
||||
| | NVIDIA Professional | `RTX PRO 4000 Blackwell` `RTX PRO 4500 Blackwell` `RTX PRO 5000 Blackwell` `RTX PRO 6000 Blackwell` |
|
||||
| 9.0 | NVIDIA | `H200` `H100` |
|
||||
| 8.9 | GeForce RTX 40xx | `RTX 4090` `RTX 4080 SUPER` `RTX 4080` `RTX 4070 Ti SUPER` `RTX 4070 Ti` `RTX 4070 SUPER` `RTX 4070` `RTX 4060 Ti` `RTX 4060` |
|
||||
| | NVIDIA Professional | `L4` `L40` `RTX 6000` |
|
||||
|
||||
@@ -53,8 +53,6 @@ FROM /path/to/safetensors/directory
|
||||
|
||||
If you create the Modelfile in the same directory as the weights, you can use the command `FROM .`.
|
||||
|
||||
If you do not create the Modelfile, ollama will act as if there was a Modelfile with the command `FROM .`.
|
||||
|
||||
Now run the `ollama create` command from the directory where you created the `Modelfile`:
|
||||
|
||||
```shell
|
||||
|
||||
@@ -16,7 +16,7 @@ curl -fsSL https://ollama.com/install.sh | sh
|
||||
Download and extract the package:
|
||||
|
||||
```shell
|
||||
curl -LO https://ollama.com/download/ollama-linux-amd64.tgz
|
||||
curl -L https://ollama.com/download/ollama-linux-amd64.tgz -o ollama-linux-amd64.tgz
|
||||
sudo tar -C /usr -xzf ollama-linux-amd64.tgz
|
||||
```
|
||||
|
||||
|
||||
@@ -1,42 +0,0 @@
|
||||
# Ollama for macOS
|
||||
|
||||
## System Requirements
|
||||
|
||||
* MacOS Monterey (v12) or newer
|
||||
* Apple M series (CPU and GPU support) or x86 (CPU only)
|
||||
|
||||
|
||||
## Filesystem Requirements
|
||||
|
||||
The preferred method of installation is to mount the `ollama.dmg` and drag-and-drop the Ollama application to the system-wide `Applications` folder. Upon startup, the Ollama app will verify the `ollama` CLI is present in your PATH, and if not detected, will prompt for permission to create a link in `/usr/local/bin`
|
||||
|
||||
Once you've installed Ollama, you'll need additional space for storing the Large Language models, which can be tens to hundreds of GB in size. If your home directory doesn't have enough space, you can change where the binaries are installed, and where the models are stored.
|
||||
|
||||
### Changing Install Location
|
||||
|
||||
To install the Ollama application somewhere other than `Applications`, place the Ollama application in the desired location, and ensure the CLI `Ollama.app/Contents/Resources/ollama` or a sym-link to the CLI can be found in your path. Upon first start decline the "Move to Applications?" request.
|
||||
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
Ollama on MacOS stores files in a few different locations.
|
||||
- `~/.ollama` contains models and configuration
|
||||
- `~/.ollama/logs` contains logs
|
||||
- *app.log* contains most recent logs from the GUI application
|
||||
- *server.log* contains the most recent server logs
|
||||
- `<install location>/Ollama.app/Contents/Resources/ollama` the CLI binary
|
||||
|
||||
## Uninstall
|
||||
|
||||
To fully remove Ollama from your system, remove the following files and folders:
|
||||
|
||||
```
|
||||
sudo rm -rf /Applications/Ollama.app
|
||||
sudo rm /usr/local/bin/ollama
|
||||
rm -rf "~/Library/Application Support/Ollama"
|
||||
rm -rf "~/Library/Saved Application State/com.electron.ollama.savedState"
|
||||
rm -rf ~/Library/Caches/com.electron.ollama/
|
||||
rm -rf ~/Library/Caches/ollama
|
||||
rm -rf ~/Library/WebKit/com.electron.ollama
|
||||
rm -rf ~/.ollama
|
||||
```
|
||||
@@ -150,7 +150,7 @@ PARAMETER <parameter> <parametervalue>
|
||||
|
||||
| Parameter | Description | Value Type | Example Usage |
|
||||
| -------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ---------- | -------------------- |
|
||||
| num_ctx | Sets the size of the context window used to generate the next token. (Default: 4096) | int | num_ctx 4096 |
|
||||
| num_ctx | Sets the size of the context window used to generate the next token. (Default: 2048) | int | num_ctx 4096 |
|
||||
| repeat_last_n | Sets how far back for the model to look back to prevent repetition. (Default: 64, 0 = disabled, -1 = num_ctx) | int | repeat_last_n 64 |
|
||||
| repeat_penalty | Sets how strongly to penalize repetitions. A higher value (e.g., 1.5) will penalize repetitions more strongly, while a lower value (e.g., 0.9) will be more lenient. (Default: 1.1) | float | repeat_penalty 1.1 |
|
||||
| temperature | The temperature of the model. Increasing the temperature will make the model answer more creatively. (Default: 0.8) | float | temperature 0.7 |
|
||||
|
||||
@@ -72,7 +72,7 @@ client = OpenAI(base_url="http://localhost:11434/v1", api_key="ollama")
|
||||
# Define the schema for the response
|
||||
class FriendInfo(BaseModel):
|
||||
name: str
|
||||
age: int
|
||||
age: int
|
||||
is_available: bool
|
||||
|
||||
class FriendList(BaseModel):
|
||||
|
||||
@@ -9,7 +9,7 @@ cat ~/.ollama/logs/server.log
|
||||
On **Linux** systems with systemd, the logs can be found with this command:
|
||||
|
||||
```shell
|
||||
journalctl -u ollama --no-pager --follow --pager-end
|
||||
journalctl -u ollama --no-pager --follow --pager-end
|
||||
```
|
||||
|
||||
When you run Ollama in a **container**, the logs go to stdout/stderr in the container:
|
||||
@@ -23,7 +23,7 @@ docker logs <container-name>
|
||||
If manually running `ollama serve` in a terminal, the logs will be on that terminal.
|
||||
|
||||
When you run Ollama on **Windows**, there are a few different locations. You can view them in the explorer window by hitting `<cmd>+R` and type in:
|
||||
- `explorer %LOCALAPPDATA%\Ollama` to view logs. The most recent server logs will be in `server.log` and older logs will be in `server-#.log`
|
||||
- `explorer %LOCALAPPDATA%\Ollama` to view logs. The most recent server logs will be in `server.log` and older logs will be in `server-#.log`
|
||||
- `explorer %LOCALAPPDATA%\Programs\Ollama` to browse the binaries (The installer adds this to your user PATH)
|
||||
- `explorer %HOMEPATH%\.ollama` to browse where models and configuration is stored
|
||||
|
||||
@@ -38,12 +38,12 @@ Join the [Discord](https://discord.gg/ollama) for help interpreting the logs.
|
||||
|
||||
## LLM libraries
|
||||
|
||||
Ollama includes multiple LLM libraries compiled for different GPUs and CPU vector features. Ollama tries to pick the best one based on the capabilities of your system. If this autodetection has problems, or you run into other problems (e.g. crashes in your GPU) you can workaround this by forcing a specific LLM library. `cpu_avx2` will perform the best, followed by `cpu_avx` and the slowest but most compatible is `cpu`. Rosetta emulation under MacOS will work with the `cpu` library.
|
||||
Ollama includes multiple LLM libraries compiled for different GPUs and CPU vector features. Ollama tries to pick the best one based on the capabilities of your system. If this autodetection has problems, or you run into other problems (e.g. crashes in your GPU) you can workaround this by forcing a specific LLM library. `cpu_avx2` will perform the best, followed by `cpu_avx` an the slowest but most compatible is `cpu`. Rosetta emulation under MacOS will work with the `cpu` library.
|
||||
|
||||
In the server log, you will see a message that looks something like this (varies from release to release):
|
||||
|
||||
```
|
||||
Dynamic LLM libraries [rocm_v6 cpu cpu_avx cpu_avx2 cuda_v12 rocm_v5]
|
||||
Dynamic LLM libraries [rocm_v6 cpu cpu_avx cpu_avx2 cuda_v11 rocm_v5]
|
||||
```
|
||||
|
||||
**Experimental LLM Library Override**
|
||||
@@ -97,7 +97,7 @@ If none of those resolve the problem, gather additional information and file an
|
||||
|
||||
On linux, AMD GPU access typically requires `video` and/or `render` group membership to access the `/dev/kfd` device. If permissions are not set up correctly, Ollama will detect this and report an error in the server log.
|
||||
|
||||
When running in a container, in some Linux distributions and container runtimes, the ollama process may be unable to access the GPU. Use `ls -lnd /dev/kfd /dev/dri /dev/dri/*` on the host system to determine the **numeric** group IDs on your system, and pass additional `--group-add ...` arguments to the container so it can access the required devices. For example, in the following output `crw-rw---- 1 0 44 226, 0 Sep 16 16:55 /dev/dri/card0` the group ID column is `44`
|
||||
When running in a container, in some Linux distributions and container runtimes, the ollama process may be unable to access the GPU. Use `ls -lnd /dev/kfd /dev/dri /dev/dri/*` on the host system to determine the **numeric** group IDs on your system, and pass additional `--group-add ...` arguments to the container so it can access the required devices. For example, in the following output `crw-rw---- 1 0 44 226, 0 Sep 16 16:55 /dev/dri/card0` the group ID column is `44`
|
||||
|
||||
If you are experiencing problems getting Ollama to correctly discover or use your GPU for inference, the following may help isolate the failure.
|
||||
- `AMD_LOG_LEVEL=3` Enable info log levels in the AMD HIP/ROCm libraries. This can help show more detailed error codes that can help troubleshoot problems
|
||||
|
||||
@@ -30,6 +30,20 @@ To install the Ollama application in a location different than your home directo
|
||||
OllamaSetup.exe /DIR="d:\some\location"
|
||||
```
|
||||
|
||||
### Changing Model Location
|
||||
|
||||
To change where Ollama stores the downloaded models instead of using your home directory, set the environment variable `OLLAMA_MODELS` in your user account.
|
||||
|
||||
1. Start the Settings (Windows 11) or Control Panel (Windows 10) application and search for _environment variables_.
|
||||
|
||||
2. Click on _Edit environment variables for your account_.
|
||||
|
||||
3. Edit or create a new variable for your user account for `OLLAMA_MODELS` where you want the models stored
|
||||
|
||||
4. Click OK/Apply to save.
|
||||
|
||||
If Ollama is already running, quit the tray application and relaunch it from the Start menu, or from a new terminal started after you saved the environment variables.
|
||||
|
||||
## API Access
|
||||
|
||||
Here's a quick example showing API access from `powershell`
|
||||
|
||||
@@ -219,7 +219,7 @@ func Uint(key string, defaultValue uint) func() uint {
|
||||
|
||||
var (
|
||||
// NumParallel sets the number of parallel model requests. NumParallel can be configured via the OLLAMA_NUM_PARALLEL environment variable.
|
||||
NumParallel = Uint("OLLAMA_NUM_PARALLEL", 1)
|
||||
NumParallel = Uint("OLLAMA_NUM_PARALLEL", 0)
|
||||
// MaxRunners sets the maximum number of loaded models. MaxRunners can be configured via the OLLAMA_MAX_LOADED_MODELS environment variable.
|
||||
MaxRunners = Uint("OLLAMA_MAX_LOADED_MODELS", 0)
|
||||
// MaxQueue sets the maximum number of queued requests. MaxQueue can be configured via the OLLAMA_MAX_QUEUE environment variable.
|
||||
|
||||
@@ -10,5 +10,4 @@ type Config interface {
|
||||
Strings(string, ...[]string) []string
|
||||
Ints(string, ...[]int32) []int32
|
||||
Floats(string, ...[]float32) []float32
|
||||
Bools(string, ...[]bool) []bool
|
||||
}
|
||||
|
||||
116
fs/ggml/ggml.go
@@ -34,8 +34,7 @@ func (kv KV) Kind() string {
|
||||
}
|
||||
|
||||
func (kv KV) ParameterCount() uint64 {
|
||||
val, _ := keyValue(kv, "general.parameter_count", uint64(0))
|
||||
return val
|
||||
return keyValue(kv, "general.parameter_count", uint64(0))
|
||||
}
|
||||
|
||||
func (kv KV) FileType() FileType {
|
||||
@@ -54,27 +53,16 @@ func (kv KV) EmbeddingLength() uint64 {
|
||||
return uint64(kv.Uint("embedding_length"))
|
||||
}
|
||||
|
||||
func (kv KV) HeadCountMax() uint64 {
|
||||
// TODO(drifkin): using the max value can cause an overestimation. In the
|
||||
// future if array values become more popular, we can adapt the more invasive
|
||||
// <https://github.com/ollama/ollama/pull/10225>
|
||||
return uint64(kv.UintOrMaxArrayValue("attention.head_count", 1))
|
||||
func (kv KV) HeadCount() uint64 {
|
||||
return uint64(kv.Uint("attention.head_count"))
|
||||
}
|
||||
|
||||
func (kv KV) HeadCountMin() uint64 {
|
||||
return uint64(kv.UintOrMinArrayValue("attention.head_count", 1))
|
||||
func (kv KV) HeadCountKV() uint64 {
|
||||
return uint64(kv.Uint("attention.head_count_kv", 1))
|
||||
}
|
||||
|
||||
func (kv KV) HeadCountKVMax() uint64 {
|
||||
return uint64(kv.UintOrMaxArrayValue("attention.head_count_kv", 1))
|
||||
}
|
||||
|
||||
func (kv KV) HeadCountKVMin() uint64 {
|
||||
return uint64(kv.UintOrMinArrayValue("attention.head_count_kv", 1))
|
||||
}
|
||||
|
||||
func (kv KV) EmbeddingHeadCountMax() uint64 {
|
||||
if heads := kv.HeadCountMin(); heads > 0 {
|
||||
func (kv KV) EmbeddingHeadCount() uint64 {
|
||||
if heads := kv.HeadCount(); heads > 0 {
|
||||
return kv.EmbeddingLength() / heads
|
||||
}
|
||||
|
||||
@@ -82,11 +70,15 @@ func (kv KV) EmbeddingHeadCountMax() uint64 {
|
||||
}
|
||||
|
||||
func (kv KV) EmbeddingHeadCountK() uint64 {
|
||||
return uint64(kv.Uint("attention.key_length", uint32(kv.EmbeddingHeadCountMax())))
|
||||
return uint64(kv.Uint("attention.key_length", uint32(kv.EmbeddingHeadCount())))
|
||||
}
|
||||
|
||||
func (kv KV) EmbeddingHeadCountV() uint64 {
|
||||
return uint64(kv.Uint("attention.value_length", uint32(kv.EmbeddingHeadCountMax())))
|
||||
return uint64(kv.Uint("attention.value_length", uint32(kv.EmbeddingHeadCount())))
|
||||
}
|
||||
|
||||
func (kv KV) GQA() uint64 {
|
||||
return kv.HeadCount() / kv.HeadCountKV()
|
||||
}
|
||||
|
||||
func (kv KV) ContextLength() uint64 {
|
||||
@@ -98,83 +90,40 @@ func (kv KV) ChatTemplate() string {
|
||||
}
|
||||
|
||||
func (kv KV) String(key string, defaultValue ...string) string {
|
||||
val, _ := keyValue(kv, key, append(defaultValue, "")...)
|
||||
return val
|
||||
return keyValue(kv, key, append(defaultValue, "")...)
|
||||
}
|
||||
|
||||
func (kv KV) Uint(key string, defaultValue ...uint32) uint32 {
|
||||
val, _ := keyValue(kv, key, append(defaultValue, 0)...)
|
||||
return val
|
||||
return keyValue(kv, key, append(defaultValue, 0)...)
|
||||
}
|
||||
|
||||
func (kv KV) Float(key string, defaultValue ...float32) float32 {
|
||||
val, _ := keyValue(kv, key, append(defaultValue, 0)...)
|
||||
return val
|
||||
return keyValue(kv, key, append(defaultValue, 0)...)
|
||||
}
|
||||
|
||||
func (kv KV) Bool(key string, defaultValue ...bool) bool {
|
||||
val, _ := keyValue(kv, key, append(defaultValue, false)...)
|
||||
return val
|
||||
}
|
||||
|
||||
func (kv KV) UintOrMaxArrayValue(key string, defaultValue uint32) uint32 {
|
||||
_, max := kv.UintOrArrayValue(key, defaultValue)
|
||||
return max
|
||||
}
|
||||
|
||||
func (kv KV) UintOrMinArrayValue(key string, defaultValue uint32) uint32 {
|
||||
min, _ := kv.UintOrArrayValue(key, defaultValue)
|
||||
return min
|
||||
}
|
||||
|
||||
func (kv KV) UintOrArrayValue(key string, defaultValue uint32) (uint32, uint32) {
|
||||
if u32, ok := keyValue(kv, key, uint32(0)); ok {
|
||||
return u32, u32
|
||||
} else if u32s, ok := keyValue(kv, key, &array[uint32]{}); ok {
|
||||
min := slices.Min(u32s.values)
|
||||
max := slices.Max(u32s.values)
|
||||
return min, max
|
||||
} else if i32s, ok := keyValue(kv, key, &array[int32]{}); ok {
|
||||
min := slices.Min(i32s.values)
|
||||
max := slices.Max(i32s.values)
|
||||
if min < 0 || max < 0 {
|
||||
slog.Warn("array values are unexpectedly negative", "key", key, "min", min, "max", max)
|
||||
}
|
||||
return uint32(min), uint32(max)
|
||||
}
|
||||
|
||||
return defaultValue, defaultValue
|
||||
return keyValue(kv, key, append(defaultValue, false)...)
|
||||
}
|
||||
|
||||
func (kv KV) Strings(key string, defaultValue ...[]string) []string {
|
||||
val, _ := keyValue(kv, key, &array[string]{values: append(defaultValue, []string(nil))[0]})
|
||||
return val.values
|
||||
return keyValue(kv, key, &array[string]{values: append(defaultValue, []string(nil))[0]}).values
|
||||
}
|
||||
|
||||
func (kv KV) Ints(key string, defaultValue ...[]int32) []int32 {
|
||||
val, _ := keyValue(kv, key, &array[int32]{values: append(defaultValue, []int32(nil))[0]})
|
||||
return val.values
|
||||
return keyValue(kv, key, &array[int32]{values: append(defaultValue, []int32(nil))[0]}).values
|
||||
}
|
||||
|
||||
func (kv KV) Uints(key string, defaultValue ...[]uint32) []uint32 {
|
||||
val, _ := keyValue(kv, key, &array[uint32]{values: append(defaultValue, []uint32(nil))[0]})
|
||||
return val.values
|
||||
return keyValue(kv, key, &array[uint32]{values: append(defaultValue, []uint32(nil))[0]}).values
|
||||
}
|
||||
|
||||
func (kv KV) Floats(key string, defaultValue ...[]float32) []float32 {
|
||||
val, _ := keyValue(kv, key, &array[float32]{values: append(defaultValue, []float32(nil))[0]})
|
||||
return val.values
|
||||
}
|
||||
|
||||
func (kv KV) Bools(key string, defaultValue ...[]bool) []bool {
|
||||
val, _ := keyValue(kv, key, &array[bool]{values: append(defaultValue, []bool(nil))[0]})
|
||||
return val.values
|
||||
return keyValue(kv, key, &array[float32]{values: append(defaultValue, []float32(nil))[0]}).values
|
||||
}
|
||||
|
||||
func (kv KV) OllamaEngineRequired() bool {
|
||||
return slices.Contains([]string{
|
||||
"gemma3",
|
||||
"gemma3n",
|
||||
"mistral3",
|
||||
"llama4",
|
||||
"mllama",
|
||||
@@ -194,17 +143,17 @@ type arrayValueTypes interface {
|
||||
*array[string] | *array[float32] | *array[float64] | *array[bool]
|
||||
}
|
||||
|
||||
func keyValue[T valueTypes | arrayValueTypes](kv KV, key string, defaultValue ...T) (T, bool) {
|
||||
func keyValue[T valueTypes | arrayValueTypes](kv KV, key string, defaultValue ...T) T {
|
||||
if !strings.HasPrefix(key, "tokenizer.") && !strings.HasPrefix(key, "general.") {
|
||||
key = kv.Architecture() + "." + key
|
||||
}
|
||||
|
||||
if val, ok := kv[key].(T); ok {
|
||||
return val, true
|
||||
if val, ok := kv[key]; ok {
|
||||
return val.(T)
|
||||
}
|
||||
|
||||
slog.Debug("key with type not found", "key", key, "default", defaultValue[0])
|
||||
return defaultValue[0], false
|
||||
slog.Debug("key not found", "key", key, "default", defaultValue[0])
|
||||
return defaultValue[0]
|
||||
}
|
||||
|
||||
type Tensors struct {
|
||||
@@ -476,11 +425,11 @@ func Decode(rs io.ReadSeeker, maxArraySize int) (*GGML, error) {
|
||||
|
||||
func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType string) (kv []uint64, partialOffload, fullOffload uint64) {
|
||||
embedding := f.KV().EmbeddingLength()
|
||||
heads := f.KV().HeadCountMax()
|
||||
headsKV := f.KV().HeadCountKVMax()
|
||||
heads := f.KV().HeadCount()
|
||||
headsKV := f.KV().HeadCountKV()
|
||||
vocab := uint64(f.KV()["tokenizer.ggml.tokens"].(*array[string]).size)
|
||||
|
||||
embeddingHeads := f.KV().EmbeddingHeadCountMax()
|
||||
embeddingHeads := f.KV().EmbeddingHeadCount()
|
||||
embeddingHeadsK := f.KV().EmbeddingHeadCountK()
|
||||
embeddingHeadsV := f.KV().EmbeddingHeadCountV()
|
||||
|
||||
@@ -555,7 +504,7 @@ func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType stri
|
||||
// vocab graph
|
||||
4*batch*(embedding+vocab)+embedding*vocab*105/128,
|
||||
)
|
||||
case "gemma", "gemma2", "gemma3", "gemma3n":
|
||||
case "gemma", "gemma2", "gemma3":
|
||||
fullOffload = max(
|
||||
4*batch*(embedding+vocab),
|
||||
4*batch*(2+context+context*heads+2*embedding+2*embeddingHeadsK*heads),
|
||||
@@ -568,11 +517,6 @@ func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType stri
|
||||
embedding*embeddingHeadsK*heads*9/16,
|
||||
)
|
||||
|
||||
if f.KV().Architecture() == "gemma3n" {
|
||||
fullOffload *= 4
|
||||
partialOffload *= 4
|
||||
}
|
||||
|
||||
// Gemma2 also has sliding window attention but we only have an optimized implementation in the Ollama
|
||||
// engine. Gemma3 always uses the Ollama engine.
|
||||
if f.KV().Architecture() == "gemma3" {
|
||||
|
||||
@@ -269,33 +269,3 @@ func TestKeyValue(t *testing.T) {
|
||||
t.Errorf("unexpected uint8s (-got +want):\n%s", diff)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHeadCount(t *testing.T) {
|
||||
valuesArray := []int32{1, 5, 3, 4}
|
||||
cases := []struct {
|
||||
kv KV
|
||||
want uint64
|
||||
}{
|
||||
{
|
||||
kv: KV{
|
||||
"general.architecture": "abc",
|
||||
"abc.attention.head_count": &array[int32]{values: valuesArray, size: len(valuesArray)},
|
||||
},
|
||||
want: uint64(5),
|
||||
},
|
||||
{
|
||||
kv: KV{
|
||||
"general.architecture": "abc",
|
||||
"abc.attention.head_count": uint32(3),
|
||||
},
|
||||
want: uint64(3),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range cases {
|
||||
got := tt.kv.HeadCountMax()
|
||||
if got != tt.want {
|
||||
t.Errorf("unexpected max value: got=%d want=%d", got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -609,10 +609,6 @@ func ggufWriteKV(ws io.WriteSeeker, k string, v any) error {
|
||||
err = writeGGUFArray(ws, ggufTypeString, v)
|
||||
case *array[string]:
|
||||
err = writeGGUFArray(ws, ggufTypeString, v.values)
|
||||
case []bool:
|
||||
err = writeGGUFArray(ws, ggufTypeBool, v)
|
||||
case *array[bool]:
|
||||
err = writeGGUFArray(ws, ggufTypeBool, v.values)
|
||||
default:
|
||||
return fmt.Errorf("improper type for '%s'", k)
|
||||
}
|
||||
|
||||
347
fs/gguf/gguf.go
@@ -1,347 +0,0 @@
|
||||
package gguf
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"cmp"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"iter"
|
||||
"os"
|
||||
"slices"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
typeUint8 uint32 = iota
|
||||
typeInt8
|
||||
typeUint16
|
||||
typeInt16
|
||||
typeUint32
|
||||
typeInt32
|
||||
typeFloat32
|
||||
typeBool
|
||||
typeString
|
||||
typeArray
|
||||
typeUint64
|
||||
typeInt64
|
||||
typeFloat64
|
||||
)
|
||||
|
||||
var ErrUnsupported = errors.New("unsupported")
|
||||
|
||||
type File struct {
|
||||
Magic [4]byte
|
||||
Version uint32
|
||||
|
||||
keyValues *lazy[KeyValue]
|
||||
tensors *lazy[TensorInfo]
|
||||
offset int64
|
||||
|
||||
file *os.File
|
||||
reader *bufferedReader
|
||||
bts []byte
|
||||
}
|
||||
|
||||
func Open(path string) (f *File, err error) {
|
||||
f = &File{bts: make([]byte, 4096)}
|
||||
f.file, err = os.Open(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
f.reader = newBufferedReader(f.file, 32<<10)
|
||||
|
||||
if err := binary.Read(f.reader, binary.LittleEndian, &f.Magic); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if bytes.Equal(f.Magic[:], []byte("gguf")) {
|
||||
return nil, fmt.Errorf("%w file type %v", ErrUnsupported, f.Magic)
|
||||
}
|
||||
|
||||
if err := binary.Read(f.reader, binary.LittleEndian, &f.Version); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if f.Version < 2 {
|
||||
return nil, fmt.Errorf("%w version %v", ErrUnsupported, f.Version)
|
||||
}
|
||||
|
||||
f.tensors, err = newLazy(f, f.readTensor)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
f.tensors.successFunc = func() error {
|
||||
offset := f.reader.offset
|
||||
|
||||
alignment := cmp.Or(f.KeyValue("general.alignment").Int(), 32)
|
||||
f.offset = offset + (alignment-offset%alignment)%alignment
|
||||
return nil
|
||||
}
|
||||
|
||||
f.keyValues, err = newLazy(f, f.readKeyValue)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return f, nil
|
||||
}
|
||||
|
||||
func (f *File) readTensor() (TensorInfo, error) {
|
||||
name, err := readString(f)
|
||||
if err != nil {
|
||||
return TensorInfo{}, err
|
||||
}
|
||||
|
||||
dims, err := read[uint32](f)
|
||||
if err != nil {
|
||||
return TensorInfo{}, err
|
||||
}
|
||||
|
||||
shape := make([]uint64, dims)
|
||||
for i := range dims {
|
||||
shape[i], err = read[uint64](f)
|
||||
if err != nil {
|
||||
return TensorInfo{}, err
|
||||
}
|
||||
}
|
||||
|
||||
type_, err := read[uint32](f)
|
||||
if err != nil {
|
||||
return TensorInfo{}, err
|
||||
}
|
||||
|
||||
offset, err := read[uint64](f)
|
||||
if err != nil {
|
||||
return TensorInfo{}, err
|
||||
}
|
||||
|
||||
return TensorInfo{
|
||||
Name: name,
|
||||
Offset: offset,
|
||||
Shape: shape,
|
||||
Type: TensorType(type_),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (f *File) readKeyValue() (KeyValue, error) {
|
||||
key, err := readString(f)
|
||||
if err != nil {
|
||||
return KeyValue{}, err
|
||||
}
|
||||
|
||||
t, err := read[uint32](f)
|
||||
if err != nil {
|
||||
return KeyValue{}, err
|
||||
}
|
||||
|
||||
value, err := func() (any, error) {
|
||||
switch t {
|
||||
case typeUint8:
|
||||
return read[uint8](f)
|
||||
case typeInt8:
|
||||
return read[int8](f)
|
||||
case typeUint16:
|
||||
return read[uint16](f)
|
||||
case typeInt16:
|
||||
return read[int16](f)
|
||||
case typeUint32:
|
||||
return read[uint32](f)
|
||||
case typeInt32:
|
||||
return read[int32](f)
|
||||
case typeUint64:
|
||||
return read[uint64](f)
|
||||
case typeInt64:
|
||||
return read[int64](f)
|
||||
case typeFloat32:
|
||||
return read[float32](f)
|
||||
case typeFloat64:
|
||||
return read[float64](f)
|
||||
case typeBool:
|
||||
return read[bool](f)
|
||||
case typeString:
|
||||
return readString(f)
|
||||
case typeArray:
|
||||
return readArray(f)
|
||||
default:
|
||||
return nil, fmt.Errorf("%w type %d", ErrUnsupported, t)
|
||||
}
|
||||
}()
|
||||
if err != nil {
|
||||
return KeyValue{}, err
|
||||
}
|
||||
|
||||
return KeyValue{
|
||||
Key: key,
|
||||
Value: Value{value},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func read[T any](f *File) (t T, err error) {
|
||||
err = binary.Read(f.reader, binary.LittleEndian, &t)
|
||||
return t, err
|
||||
}
|
||||
|
||||
func readString(f *File) (string, error) {
|
||||
n, err := read[uint64](f)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if int(n) > len(f.bts) {
|
||||
f.bts = make([]byte, n)
|
||||
}
|
||||
|
||||
bts := f.bts[:n]
|
||||
if _, err := io.ReadFull(f.reader, bts); err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer clear(bts)
|
||||
|
||||
return string(bts), nil
|
||||
}
|
||||
|
||||
func readArray(f *File) (any, error) {
|
||||
t, err := read[uint32](f)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
n, err := read[uint64](f)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
switch t {
|
||||
case typeUint8:
|
||||
return readArrayData[uint8](f, n)
|
||||
case typeInt8:
|
||||
return readArrayData[int8](f, n)
|
||||
case typeUint16:
|
||||
return readArrayData[uint16](f, n)
|
||||
case typeInt16:
|
||||
return readArrayData[int16](f, n)
|
||||
case typeUint32:
|
||||
return readArrayData[uint32](f, n)
|
||||
case typeInt32:
|
||||
return readArrayData[int32](f, n)
|
||||
case typeUint64:
|
||||
return readArrayData[uint64](f, n)
|
||||
case typeInt64:
|
||||
return readArrayData[int64](f, n)
|
||||
case typeFloat32:
|
||||
return readArrayData[float32](f, n)
|
||||
case typeFloat64:
|
||||
return readArrayData[float64](f, n)
|
||||
case typeBool:
|
||||
return readArrayData[bool](f, n)
|
||||
case typeString:
|
||||
return readArrayString(f, n)
|
||||
default:
|
||||
return nil, fmt.Errorf("%w type %d", ErrUnsupported, t)
|
||||
}
|
||||
}
|
||||
|
||||
func readArrayData[T any](f *File, n uint64) (s []T, err error) {
|
||||
s = make([]T, n)
|
||||
for i := range n {
|
||||
e, err := read[T](f)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
s[i] = e
|
||||
}
|
||||
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func readArrayString(f *File, n uint64) (s []string, err error) {
|
||||
s = make([]string, n)
|
||||
for i := range n {
|
||||
e, err := readString(f)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
s[i] = e
|
||||
}
|
||||
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (f *File) Close() error {
|
||||
f.keyValues.stop()
|
||||
f.tensors.stop()
|
||||
return f.file.Close()
|
||||
}
|
||||
|
||||
func (f *File) KeyValue(key string) KeyValue {
|
||||
if !strings.HasPrefix(key, "general.") && !strings.HasPrefix(key, "tokenizer.") {
|
||||
key = f.KeyValue("general.architecture").String() + "." + key
|
||||
}
|
||||
|
||||
if index := slices.IndexFunc(f.keyValues.values, func(kv KeyValue) bool {
|
||||
return kv.Key == key
|
||||
}); index >= 0 {
|
||||
return f.keyValues.values[index]
|
||||
}
|
||||
|
||||
for keyValue, ok := f.keyValues.next(); ok; keyValue, ok = f.keyValues.next() {
|
||||
if keyValue.Key == key {
|
||||
return keyValue
|
||||
}
|
||||
}
|
||||
|
||||
return KeyValue{}
|
||||
}
|
||||
|
||||
func (f *File) NumKeyValues() int {
|
||||
return int(f.keyValues.count)
|
||||
}
|
||||
|
||||
func (f *File) KeyValues() iter.Seq2[int, KeyValue] {
|
||||
return f.keyValues.All()
|
||||
}
|
||||
|
||||
func (f *File) TensorInfo(name string) TensorInfo {
|
||||
if index := slices.IndexFunc(f.tensors.values, func(t TensorInfo) bool {
|
||||
return t.Name == name
|
||||
}); index >= 0 {
|
||||
return f.tensors.values[index]
|
||||
}
|
||||
|
||||
// fast-forward through key values if we haven't already
|
||||
_ = f.keyValues.rest()
|
||||
for tensor, ok := f.tensors.next(); ok; tensor, ok = f.tensors.next() {
|
||||
if tensor.Name == name {
|
||||
return tensor
|
||||
}
|
||||
}
|
||||
|
||||
return TensorInfo{}
|
||||
}
|
||||
|
||||
func (f *File) NumTensors() int {
|
||||
return int(f.tensors.count)
|
||||
}
|
||||
|
||||
func (f *File) TensorInfos() iter.Seq2[int, TensorInfo] {
|
||||
// fast forward through key values if we haven't already
|
||||
f.keyValues.rest()
|
||||
return f.tensors.All()
|
||||
}
|
||||
|
||||
func (f *File) TensorReader(name string) (TensorInfo, io.Reader, error) {
|
||||
t := f.TensorInfo(name)
|
||||
if t.NumBytes() == 0 {
|
||||
return TensorInfo{}, nil, fmt.Errorf("tensor %s not found", name)
|
||||
}
|
||||
|
||||
// fast forward through tensor info if we haven't already
|
||||
_ = f.tensors.rest()
|
||||
return t, io.NewSectionReader(f.file, f.offset+int64(t.Offset), t.NumBytes()), nil
|
||||
}
|
||||
@@ -1,249 +0,0 @@
|
||||
package gguf_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/google/go-cmp/cmp/cmpopts"
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
"github.com/ollama/ollama/fs/gguf"
|
||||
)
|
||||
|
||||
func createBinFile(tb testing.TB) string {
|
||||
tb.Helper()
|
||||
f, err := os.CreateTemp(tb.TempDir(), "")
|
||||
if err != nil {
|
||||
tb.Fatal(err)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
kv := ggml.KV{
|
||||
"general.architecture": "llama",
|
||||
"llama.block_count": uint32(8),
|
||||
"llama.embedding_length": uint32(3),
|
||||
"llama.attention.head_count": uint32(2),
|
||||
"llama.attention.head_count_kv": uint32(2),
|
||||
"llama.attention.key_length": uint32(3),
|
||||
"llama.rope.dimension_count": uint32(4),
|
||||
"llama.rope.freq_base": float32(10000.0),
|
||||
"llama.rope.freq_scale": float32(1.0),
|
||||
"llama.attention.layer_norm_rms_epsilon": float32(1e-6),
|
||||
"tokenizer.ggml.eos_token_id": uint32(0),
|
||||
"tokenizer.ggml.eos_token_ids": []int32{1, 2, 3},
|
||||
"tokenizer.ggml.tokens": []string{"hello", "world"},
|
||||
"tokenizer.ggml.scores": []float32{0, 1},
|
||||
}
|
||||
|
||||
tensors := []*ggml.Tensor{
|
||||
{
|
||||
Name: "token_embd.weight",
|
||||
Kind: 0,
|
||||
Shape: []uint64{2, 3},
|
||||
WriterTo: bytes.NewBuffer(make([]byte, 4*2*3)),
|
||||
},
|
||||
{
|
||||
Name: "output.weight",
|
||||
Kind: 0,
|
||||
Shape: []uint64{3, 2},
|
||||
WriterTo: bytes.NewBuffer(make([]byte, 4*3*2)),
|
||||
},
|
||||
}
|
||||
|
||||
for i := range 8 {
|
||||
tensors = append(tensors, &ggml.Tensor{
|
||||
Name: "blk." + strconv.Itoa(i) + ".attn_q.weight",
|
||||
Kind: 0,
|
||||
Shape: []uint64{3, 3},
|
||||
WriterTo: bytes.NewBuffer(make([]byte, 4*3*3)),
|
||||
}, &ggml.Tensor{
|
||||
Name: "blk." + strconv.Itoa(i) + ".attn_k.weight",
|
||||
Kind: 0,
|
||||
Shape: []uint64{3, 3},
|
||||
WriterTo: bytes.NewBuffer(make([]byte, 4*3*3)),
|
||||
}, &ggml.Tensor{
|
||||
Name: "blk." + strconv.Itoa(i) + ".attn_v.weight",
|
||||
Kind: 0,
|
||||
Shape: []uint64{3, 3},
|
||||
WriterTo: bytes.NewBuffer(make([]byte, 4*3*3)),
|
||||
}, &ggml.Tensor{
|
||||
Name: "blk." + strconv.Itoa(i) + ".attn_output.weight",
|
||||
Kind: 0,
|
||||
Shape: []uint64{3, 3},
|
||||
WriterTo: bytes.NewBuffer(make([]byte, 4*3*3)),
|
||||
})
|
||||
}
|
||||
|
||||
if err := ggml.WriteGGUF(f, kv, tensors); err != nil {
|
||||
tb.Fatal(err)
|
||||
}
|
||||
|
||||
return f.Name()
|
||||
}
|
||||
|
||||
func TestRead(t *testing.T) {
|
||||
f, err := gguf.Open(createBinFile(t))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
if got := f.KeyValue("does.not.exist").Valid(); got {
|
||||
t.Errorf(`KeyValue("does.not.exist").Exists() = %v, want false`, got)
|
||||
}
|
||||
|
||||
if got := f.KeyValue("general.architecture").String(); got != "llama" {
|
||||
t.Errorf(`KeyValue("general.architecture").String() = %q, want %q`, got, "llama")
|
||||
}
|
||||
|
||||
if got := f.TensorInfo("token_embd.weight"); got.Name != "token_embd.weight" {
|
||||
t.Errorf(`TensorInfo("token_embd.weight").Name = %q, want %q`, got.Name, "token_embd.weight")
|
||||
} else if diff := cmp.Diff(got.Shape, []uint64{2, 3}); diff != "" {
|
||||
t.Errorf(`TensorInfo("token_embd.weight").Shape mismatch (-got +want):\n%s`, diff)
|
||||
} else if got.Type != gguf.TensorTypeF32 {
|
||||
t.Errorf(`TensorInfo("token_embd.weight").Type = %d, want %d`, got.Type, gguf.TensorTypeF32)
|
||||
}
|
||||
|
||||
if got := f.KeyValue("block_count").Uint(); got != 8 {
|
||||
t.Errorf(`KeyValue("block_count").Uint() = %d, want %d`, got, 8)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(f.KeyValue("tokenizer.ggml.tokens").Strings(), []string{"hello", "world"}); diff != "" {
|
||||
t.Errorf("KeyValue(\"tokenizer.ggml.tokens\").Strings() mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(f.KeyValue("tokenizer.ggml.scores").Floats(), []float64{0, 1}); diff != "" {
|
||||
t.Errorf("KeyValue(\"tokenizer.ggml.scores\").Floats() mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
|
||||
var kvs []string
|
||||
for _, kv := range f.KeyValues() {
|
||||
if !kv.Valid() {
|
||||
t.Error("found invalid key-value pair:", kv)
|
||||
}
|
||||
|
||||
kvs = append(kvs, kv.Key)
|
||||
}
|
||||
|
||||
if len(kvs) != f.NumKeyValues() {
|
||||
t.Errorf("iterated key count = %d, want %d", len(kvs), f.NumKeyValues())
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(kvs, []string{
|
||||
"general.architecture",
|
||||
"llama.block_count",
|
||||
"llama.embedding_length",
|
||||
"llama.attention.head_count",
|
||||
"llama.attention.head_count_kv",
|
||||
"llama.attention.key_length",
|
||||
"llama.rope.dimension_count",
|
||||
"llama.rope.freq_base",
|
||||
"llama.rope.freq_scale",
|
||||
"llama.attention.layer_norm_rms_epsilon",
|
||||
"tokenizer.ggml.eos_token_id",
|
||||
"tokenizer.ggml.eos_token_ids",
|
||||
"tokenizer.ggml.tokens",
|
||||
"tokenizer.ggml.scores",
|
||||
}, cmpopts.SortSlices(strings.Compare)); diff != "" {
|
||||
t.Errorf("KeyValues() mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
|
||||
var tis []string
|
||||
for _, ti := range f.TensorInfos() {
|
||||
if !ti.Valid() {
|
||||
t.Error("found invalid tensor info:", ti)
|
||||
}
|
||||
|
||||
tis = append(tis, ti.Name)
|
||||
}
|
||||
|
||||
if len(tis) != f.NumTensors() {
|
||||
t.Errorf("iterated tensor count = %d, want %d", len(tis), f.NumTensors())
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(tis, []string{
|
||||
"token_embd.weight",
|
||||
"output.weight",
|
||||
"blk.0.attn_q.weight",
|
||||
"blk.0.attn_k.weight",
|
||||
"blk.0.attn_v.weight",
|
||||
"blk.0.attn_output.weight",
|
||||
"blk.1.attn_q.weight",
|
||||
"blk.1.attn_k.weight",
|
||||
"blk.1.attn_v.weight",
|
||||
"blk.1.attn_output.weight",
|
||||
"blk.2.attn_q.weight",
|
||||
"blk.2.attn_k.weight",
|
||||
"blk.2.attn_v.weight",
|
||||
"blk.2.attn_output.weight",
|
||||
"blk.3.attn_q.weight",
|
||||
"blk.3.attn_k.weight",
|
||||
"blk.3.attn_v.weight",
|
||||
"blk.3.attn_output.weight",
|
||||
"blk.4.attn_q.weight",
|
||||
"blk.4.attn_k.weight",
|
||||
"blk.4.attn_v.weight",
|
||||
"blk.4.attn_output.weight",
|
||||
"blk.5.attn_q.weight",
|
||||
"blk.5.attn_k.weight",
|
||||
"blk.5.attn_v.weight",
|
||||
"blk.5.attn_output.weight",
|
||||
"blk.6.attn_q.weight",
|
||||
"blk.6.attn_k.weight",
|
||||
"blk.6.attn_v.weight",
|
||||
"blk.6.attn_output.weight",
|
||||
"blk.7.attn_q.weight",
|
||||
"blk.7.attn_k.weight",
|
||||
"blk.7.attn_v.weight",
|
||||
"blk.7.attn_output.weight",
|
||||
}, cmpopts.SortSlices(strings.Compare)); diff != "" {
|
||||
t.Errorf("TensorInfos() mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
|
||||
ti, r, err := f.TensorReader("output.weight")
|
||||
if err != nil {
|
||||
t.Fatalf(`TensorReader("output.weight") error: %v`, err)
|
||||
}
|
||||
|
||||
if ti.Name != "output.weight" {
|
||||
t.Errorf(`TensorReader("output.weight").Name = %q, want %q`, ti.Name, "output.weight")
|
||||
} else if diff := cmp.Diff(ti.Shape, []uint64{3, 2}); diff != "" {
|
||||
t.Errorf("TensorReader(\"output.weight\").Shape mismatch (-got +want):\n%s", diff)
|
||||
} else if ti.Type != gguf.TensorTypeF32 {
|
||||
t.Errorf(`TensorReader("output.weight").Type = %d, want %d`, ti.Type, gguf.TensorTypeF32)
|
||||
}
|
||||
|
||||
var b bytes.Buffer
|
||||
if _, err := b.ReadFrom(r); err != nil {
|
||||
t.Fatalf(`ReadFrom TensorReader("output.weight") error: %v`, err)
|
||||
}
|
||||
|
||||
if b.Len() != int(ti.NumBytes()) {
|
||||
t.Errorf(`ReadFrom TensorReader("output.weight") length = %d, want %d`, b.Len(), ti.NumBytes())
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkRead(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
|
||||
p := createBinFile(b)
|
||||
for b.Loop() {
|
||||
f, err := gguf.Open(p)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
if got := f.KeyValue("general.architecture").String(); got != "llama" {
|
||||
b.Errorf("got = %q, want %q", got, "llama")
|
||||
}
|
||||
|
||||
// Iterate through some tensors
|
||||
for range f.TensorInfos() {
|
||||
}
|
||||
|
||||
f.Close()
|
||||
}
|
||||
}
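// Illustrative sketch, not part of the diff: the reader API exercised by the
// test above, with a hypothetical model path and error handling trimmed.
package main

import (
	"fmt"
	"log"

	"github.com/ollama/ollama/fs/gguf"
)

func main() {
	f, err := gguf.Open("model.gguf") // placeholder path
	if err != nil {
		log.Fatal(err)
	}
	defer f.Close()

	// Typed accessors return zero values when a key is missing or its type
	// does not match, so no per-lookup error handling is needed.
	fmt.Println(f.KeyValue("general.architecture").String())

	// Tensor metadata is available without reading tensor data.
	for _, ti := range f.TensorInfos() {
		fmt.Println(ti.Name, ti.Shape, ti.Type)
	}
}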
|
||||
@@ -1,90 +0,0 @@
|
||||
package gguf
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"slices"
|
||||
)
|
||||
|
||||
type KeyValue struct {
|
||||
Key string
|
||||
Value
|
||||
}
|
||||
|
||||
func (kv KeyValue) Valid() bool {
|
||||
return kv.Key != "" && kv.Value.value != nil
|
||||
}
|
||||
|
||||
type Value struct {
|
||||
value any
|
||||
}
|
||||
|
||||
func value[T any](v Value, kinds ...reflect.Kind) (t T) {
|
||||
vv := reflect.ValueOf(v.value)
|
||||
if slices.Contains(kinds, vv.Kind()) {
|
||||
t = vv.Convert(reflect.TypeOf(t)).Interface().(T)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func values[T any](v Value, kinds ...reflect.Kind) (ts []T) {
|
||||
switch vv := reflect.ValueOf(v.value); vv.Kind() {
|
||||
case reflect.Slice:
|
||||
if slices.Contains(kinds, vv.Type().Elem().Kind()) {
|
||||
ts = make([]T, vv.Len())
|
||||
for i := range vv.Len() {
|
||||
ts[i] = vv.Index(i).Convert(reflect.TypeOf(ts[i])).Interface().(T)
|
||||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Int returns Value as a signed integer. If it is not a signed integer, it returns 0.
|
||||
func (v Value) Int() int64 {
|
||||
return value[int64](v, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64)
|
||||
}
|
||||
|
||||
// Ints returns Value as a signed integer slice. If it is not a signed integer slice, it returns nil.
|
||||
func (v Value) Ints() (i64s []int64) {
|
||||
return values[int64](v, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64)
|
||||
}
|
||||
|
||||
// Uint converts an unsigned integer value to uint64. If the value is not a unsigned integer, it returns 0.
|
||||
func (v Value) Uint() uint64 {
|
||||
return value[uint64](v, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64)
|
||||
}
|
||||
|
||||
// Uints returns Value as a unsigned integer slice. If it is not a unsigned integer slice, it returns nil.
|
||||
func (v Value) Uints() (u64s []uint64) {
|
||||
return values[uint64](v, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64)
|
||||
}
|
||||
|
||||
// Float returns Value as a float. If it is not a float, it returns 0.
|
||||
func (v Value) Float() float64 {
|
||||
return value[float64](v, reflect.Float32, reflect.Float64)
|
||||
}
|
||||
|
||||
// Floats returns Value as a float slice. If it is not a float slice, it returns nil.
|
||||
func (v Value) Floats() (f64s []float64) {
|
||||
return values[float64](v, reflect.Float32, reflect.Float64)
|
||||
}
|
||||
|
||||
// Bool returns Value as a boolean. If it is not a boolean, it returns false.
|
||||
func (v Value) Bool() bool {
|
||||
return value[bool](v, reflect.Bool)
|
||||
}
|
||||
|
||||
// Bools returns Value as a boolean slice. If it is not a boolean slice, it returns nil.
|
||||
func (v Value) Bools() (bools []bool) {
|
||||
return values[bool](v, reflect.Bool)
|
||||
}
|
||||
|
||||
// String returns Value as a string. If it is not a string, it returns an empty string.
|
||||
func (v Value) String() string {
|
||||
return value[string](v, reflect.String)
|
||||
}
|
||||
|
||||
// Strings returns Value as a string slice. If it is not a string slice, it returns nil.
|
||||
func (v Value) Strings() (strings []string) {
|
||||
return values[string](v, reflect.String)
|
||||
}
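// Sketch only (package gguf scope, mirroring keyvalue_test.go below):
// the typed accessors convert compatible kinds and return zero values on a
// kind mismatch. The key and value here are illustrative.
func exampleKeyValueAccessors() (uint64, int64, string) {
	kv := KeyValue{"llama.block_count", Value{uint32(8)}}
	return kv.Uint(), kv.Int(), kv.String() // 8, 0 (not a signed kind), "" (not a string)
}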
|
||||
@@ -1,208 +0,0 @@
|
||||
package gguf
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
)
|
||||
|
||||
func split(name string, values map[string][]any) (matched []any, unmatched []any) {
|
||||
for key, value := range values {
|
||||
if key == name {
|
||||
matched = value
|
||||
} else {
|
||||
unmatched = append(unmatched, value...)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func TestValue(t *testing.T) {
|
||||
values := map[string][]any{
|
||||
"int64": {int(42), int8(42), int16(42), int32(42), int64(42)},
|
||||
"uint64": {uint(42), uint8(42), uint16(42), uint32(42), uint64(42)},
|
||||
"float64": {float32(42), float64(42)},
|
||||
"string": {"42", "hello"},
|
||||
"bool": {true, false},
|
||||
}
|
||||
|
||||
t.Run("int64", func(t *testing.T) {
|
||||
matched, unmatched := split("int64", values)
|
||||
for _, v := range matched {
|
||||
kv := KeyValue{"key", Value{v}}
|
||||
if i64 := kv.Int(); i64 != 42 {
|
||||
t.Errorf("expected 42, got %d", i64)
|
||||
}
|
||||
}
|
||||
|
||||
for _, v := range unmatched {
|
||||
kv := KeyValue{"key", Value{v}}
|
||||
if i64 := kv.Int(); i64 != 0 {
|
||||
t.Errorf("expected 0, got %d", i64)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("uint64", func(t *testing.T) {
|
||||
matched, unmatched := split("uint64", values)
|
||||
for _, v := range matched {
|
||||
kv := KeyValue{"key", Value{v}}
|
||||
if u64 := kv.Uint(); u64 != 42 {
|
||||
t.Errorf("expected 42, got %d", u64)
|
||||
}
|
||||
}
|
||||
|
||||
for _, v := range unmatched {
|
||||
kv := KeyValue{"key", Value{v}}
|
||||
if u64 := kv.Uint(); u64 != 0 {
|
||||
t.Errorf("expected 0, got %d", u64)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("float64", func(t *testing.T) {
|
||||
matched, unmatched := split("float64", values)
|
||||
for _, v := range matched {
|
||||
kv := KeyValue{"key", Value{v}}
|
||||
if f64 := kv.Float(); f64 != 42 {
|
||||
t.Errorf("expected 42, got %f", f64)
|
||||
}
|
||||
}
|
||||
|
||||
for _, v := range unmatched {
|
||||
kv := KeyValue{"key", Value{v}}
|
||||
if f64 := kv.Float(); f64 != 0 {
|
||||
t.Errorf("expected 0, got %f", f64)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("string", func(t *testing.T) {
|
||||
matched, unmatched := split("string", values)
|
||||
for _, v := range matched {
|
||||
kv := KeyValue{"key", Value{v}}
|
||||
if s := kv.String(); s != v {
|
||||
t.Errorf("expected %v, got %s", v, s)
|
||||
}
|
||||
}
|
||||
|
||||
for _, v := range unmatched {
|
||||
kv := KeyValue{"key", Value{v}}
|
||||
if s := kv.String(); s != "" {
|
||||
t.Errorf("expected empty string, got %q", s)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("bool", func(t *testing.T) {
|
||||
matched, unmatched := split("bool", values)
|
||||
for _, v := range matched {
|
||||
kv := KeyValue{"key", Value{v}}
|
||||
if b := kv.Bool(); b != v {
|
||||
t.Errorf("expected %v, got %v", v, b)
|
||||
}
|
||||
}
|
||||
|
||||
for _, v := range unmatched {
|
||||
kv := KeyValue{"key", Value{v}}
|
||||
if b := kv.Bool(); b != false {
|
||||
t.Errorf("expected false, got %v", b)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestValues(t *testing.T) {
|
||||
values := map[string][]any{
|
||||
"int64s": {[]int{42}, []int8{42}, []int16{42}, []int32{42}, []int64{42}},
|
||||
"uint64s": {[]uint{42}, []uint8{42}, []uint16{42}, []uint32{42}, []uint64{42}},
|
||||
"float64s": {[]float32{42}, []float64{42}},
|
||||
"strings": {[]string{"42"}, []string{"hello"}},
|
||||
"bools": {[]bool{true}, []bool{false}},
|
||||
}
|
||||
|
||||
t.Run("int64s", func(t *testing.T) {
|
||||
matched, unmatched := split("int64s", values)
|
||||
for _, v := range matched {
|
||||
kv := KeyValue{"key", Value{v}}
|
||||
if diff := cmp.Diff(kv.Ints(), []int64{42}); diff != "" {
|
||||
t.Errorf("diff: %s", diff)
|
||||
}
|
||||
}
|
||||
|
||||
for _, v := range unmatched {
|
||||
kv := KeyValue{"key", Value{v}}
|
||||
if i64s := kv.Ints(); i64s != nil {
|
||||
t.Errorf("expected nil, got %v", i64s)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("uint64s", func(t *testing.T) {
|
||||
matched, unmatched := split("uint64s", values)
|
||||
for _, v := range matched {
|
||||
kv := KeyValue{"key", Value{v}}
|
||||
if diff := cmp.Diff(kv.Uints(), []uint64{42}); diff != "" {
|
||||
t.Errorf("diff: %s", diff)
|
||||
}
|
||||
}
|
||||
|
||||
for _, v := range unmatched {
|
||||
kv := KeyValue{"key", Value{v}}
|
||||
if u64s := kv.Uints(); u64s != nil {
|
||||
t.Errorf("expected nil, got %v", u64s)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("float64s", func(t *testing.T) {
|
||||
matched, unmatched := split("float64s", values)
|
||||
for _, v := range matched {
|
||||
kv := KeyValue{"key", Value{v}}
|
||||
if diff := cmp.Diff(kv.Floats(), []float64{42}); diff != "" {
|
||||
t.Errorf("diff: %s", diff)
|
||||
}
|
||||
}
|
||||
|
||||
for _, v := range unmatched {
|
||||
kv := KeyValue{"key", Value{v}}
|
||||
if f64s := kv.Floats(); f64s != nil {
|
||||
t.Errorf("expected nil, got %v", f64s)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("strings", func(t *testing.T) {
|
||||
matched, unmatched := split("strings", values)
|
||||
for _, v := range matched {
|
||||
kv := KeyValue{"key", Value{v}}
|
||||
if diff := cmp.Diff(kv.Strings(), v); diff != "" {
|
||||
t.Errorf("diff: %s", diff)
|
||||
}
|
||||
}
|
||||
|
||||
for _, v := range unmatched {
|
||||
kv := KeyValue{"key", Value{v}}
|
||||
if s := kv.Strings(); s != nil {
|
||||
t.Errorf("expected nil, got %v", s)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("bools", func(t *testing.T) {
|
||||
matched, unmatched := split("bools", values)
|
||||
for _, v := range matched {
|
||||
kv := KeyValue{"key", Value{v}}
|
||||
if diff := cmp.Diff(kv.Bools(), v); diff != "" {
|
||||
t.Errorf("diff: %s", diff)
|
||||
}
|
||||
}
|
||||
|
||||
for _, v := range unmatched {
|
||||
kv := KeyValue{"key", Value{v}}
|
||||
if b := kv.Bools(); b != nil {
|
||||
t.Errorf("expected nil, got %v", b)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -1,89 +0,0 @@
|
||||
package gguf
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"iter"
|
||||
"log/slog"
|
||||
)
|
||||
|
||||
type lazy[T any] struct {
|
||||
count uint64
|
||||
next func() (T, bool)
|
||||
stop func()
|
||||
values []T
|
||||
|
||||
// successFunc is called when all values have been successfully read.
|
||||
successFunc func() error
|
||||
}
|
||||
|
||||
func newLazy[T any](f *File, fn func() (T, error)) (*lazy[T], error) {
|
||||
it := lazy[T]{}
|
||||
if err := binary.Read(f.reader, binary.LittleEndian, &it.count); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
it.values = make([]T, 0)
|
||||
it.next, it.stop = iter.Pull(func(yield func(T) bool) {
|
||||
for i := range it.count {
|
||||
t, err := fn()
|
||||
if err != nil {
|
||||
slog.Error("error reading tensor", "index", i, "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
it.values = append(it.values, t)
|
||||
if !yield(t) {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if it.successFunc != nil {
|
||||
it.successFunc()
|
||||
}
|
||||
})
|
||||
|
||||
return &it, nil
|
||||
}
|
||||
|
||||
func (g *lazy[T]) Values() iter.Seq[T] {
|
||||
return func(yield func(T) bool) {
|
||||
for _, v := range g.All() {
|
||||
if !yield(v) {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (g *lazy[T]) All() iter.Seq2[int, T] {
|
||||
return func(yield func(int, T) bool) {
|
||||
for i := range int(g.count) {
|
||||
if i < len(g.values) {
|
||||
if !yield(i, g.values[i]) {
|
||||
break
|
||||
}
|
||||
} else {
|
||||
t, ok := g.next()
|
||||
if !ok {
|
||||
break
|
||||
}
|
||||
|
||||
if !yield(i, t) {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (g *lazy[T]) rest() (collected bool) {
|
||||
for {
|
||||
_, ok := g.next()
|
||||
collected = collected || ok
|
||||
if !ok {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return collected
|
||||
}
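// Sketch of the iter.Pull pattern that lazy builds on (standard library iter
// package, already imported in this file); the sequence here is illustrative
// and not part of the diff.
func examplePull() []int {
	next, stop := iter.Pull(func(yield func(int) bool) {
		for i := range 3 {
			if !yield(i) {
				return
			}
		}
	})
	defer stop()

	var got []int
	for {
		v, ok := next()
		if !ok {
			break
		}
		got = append(got, v) // values are produced on demand, so callers can stop early
	}
	return got // [0 1 2]
}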
|
||||
@@ -1,23 +0,0 @@
|
||||
package gguf
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"io"
|
||||
)
|
||||
|
||||
type bufferedReader struct {
|
||||
offset int64
|
||||
*bufio.Reader
|
||||
}
|
||||
|
||||
func newBufferedReader(rs io.ReadSeeker, size int) *bufferedReader {
|
||||
return &bufferedReader{
|
||||
Reader: bufio.NewReaderSize(rs, size),
|
||||
}
|
||||
}
|
||||
|
||||
func (rs *bufferedReader) Read(p []byte) (n int, err error) {
|
||||
n, err = rs.Reader.Read(p)
|
||||
rs.offset += int64(n)
|
||||
return n, err
|
||||
}
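// Minimal sketch (package gguf scope; assumes a bytes import is added): the
// wrapper adds offset tracking on top of bufio.Reader so callers can recover
// the absolute position of the data consumed so far.
func exampleBufferedReaderOffset() int64 {
	br := newBufferedReader(bytes.NewReader(make([]byte, 64)), 32)
	p := make([]byte, 16)
	br.Read(p)       // consumes 16 bytes from the front
	return br.offset // 16
}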
|
||||
@@ -1,288 +0,0 @@
|
||||
package gguf
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type TensorInfo struct {
|
||||
Name string
|
||||
Offset uint64
|
||||
Shape []uint64
|
||||
Type TensorType
|
||||
}
|
||||
|
||||
func (ti TensorInfo) Valid() bool {
|
||||
return ti.Name != "" && ti.NumBytes() > 0
|
||||
}
|
||||
|
||||
func (ti TensorInfo) NumValues() int64 {
|
||||
var numItems int64 = 1
|
||||
for _, dim := range ti.Shape {
|
||||
numItems *= int64(dim)
|
||||
}
|
||||
return numItems
|
||||
}
|
||||
|
||||
// NumBytes returns the number of bytes in the tensor.
|
||||
func (ti TensorInfo) NumBytes() int64 {
|
||||
return int64(float64(ti.NumValues()) * ti.Type.NumBytes())
|
||||
}
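// Illustrative sketch (package gguf scope): a 2x3 F32 tensor, like
// token_embd.weight in the test fixture above, holds 6 values at 4 bytes each.
func exampleTensorInfoNumBytes() (int64, int64) {
	ti := TensorInfo{Name: "token_embd.weight", Shape: []uint64{2, 3}, Type: TensorTypeF32}
	return ti.NumValues(), ti.NumBytes() // 6, 24
}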
|
||||
|
||||
func (ti TensorInfo) LogValue() slog.Value {
|
||||
return slog.GroupValue(
|
||||
slog.String("name", ti.Name),
|
||||
slog.Int64("offset", int64(ti.Offset)),
|
||||
slog.Any("shape", ti.Shape),
|
||||
slog.Int64("num_values", ti.NumValues()),
|
||||
slog.Int64("num_bytes", ti.NumBytes()),
|
||||
slog.Any("type", ti.Type),
|
||||
)
|
||||
}
|
||||
|
||||
type TensorType uint32
|
||||
|
||||
const (
|
||||
TensorTypeF32 TensorType = iota
|
||||
TensorTypeF16
|
||||
TensorTypeQ4_0
|
||||
TensorTypeQ4_1
|
||||
|
||||
// unexported // unused in gguf
|
||||
tensorTypeQ4_2
|
||||
tensorTypeQ4_3
|
||||
|
||||
TensorTypeQ5_0
|
||||
TensorTypeQ5_1
|
||||
TensorTypeQ8_0
|
||||
TensorTypeQ8_1
|
||||
TensorTypeQ2_K
|
||||
TensorTypeQ3_K
|
||||
TensorTypeQ4_K
|
||||
TensorTypeQ5_K
|
||||
TensorTypeQ6_K
|
||||
TensorTypeQ8_K
|
||||
|
||||
// unexported // unquantizable by ollama
|
||||
tensorTypeIQ2_XXS
|
||||
tensorTypeIQ2_XS
|
||||
tensorTypeIQ3_XXS
|
||||
tensorTypeIQ1_S
|
||||
tensorTypeIQ4_NL
|
||||
tensorTypeIQ3_S
|
||||
tensorTypeIQ2_S
|
||||
tensorTypeIQ4_XS
|
||||
|
||||
TensorTypeI8
|
||||
TensorTypeI16
|
||||
TensorTypeI32
|
||||
TensorTypeI64
|
||||
TensorTypeF64
|
||||
|
||||
// unexported // unquantizable by ollama
|
||||
tensorTypeIQ1_M
|
||||
|
||||
TensorTypeBF16
|
||||
|
||||
// unexported // unused in gguf
|
||||
tensorTypeQ4_0_4_4
|
||||
tensorTypeQ4_0_4_8
|
||||
tensorTypeQ4_0_8_8
|
||||
|
||||
// unexported // unquantizable by ollama
|
||||
tensorTypeTQ1_0
|
||||
tensorTypeTQ2_0
|
||||
|
||||
// unexported // unused in gguf
|
||||
tensorTypeIQ4_NL_4_4
|
||||
tensorTypeIQ4_NL_4_8
|
||||
tensorTypeIQ4_NL_8_8
|
||||
)
|
||||
|
||||
func (tt TensorType) NumBytes() float64 {
|
||||
return float64(tt.typeSize()) / float64(tt.blockSize())
|
||||
}
|
||||
|
||||
func (tt TensorType) typeSize() int64 {
|
||||
switch tt {
|
||||
case TensorTypeF32:
|
||||
return 4
|
||||
case TensorTypeF16:
|
||||
return 2
|
||||
case TensorTypeQ4_0:
|
||||
return 2 + tt.blockSize()/2
|
||||
case TensorTypeQ4_1:
|
||||
return 2 + 2 + tt.blockSize()/2
|
||||
case TensorTypeQ5_0:
|
||||
return 2 + 4 + tt.blockSize()/2
|
||||
case TensorTypeQ5_1:
|
||||
return 2 + 2 + 4 + tt.blockSize()/2
|
||||
case TensorTypeQ8_0:
|
||||
return 2 + tt.blockSize()
|
||||
case TensorTypeQ8_1:
|
||||
return 2 + 2 + tt.blockSize()
|
||||
case TensorTypeQ2_K:
|
||||
return tt.blockSize()/16 + tt.blockSize()/4 + 2 + 2
|
||||
case TensorTypeQ3_K:
|
||||
return tt.blockSize()/8 + tt.blockSize()/4 + 12 + 2
|
||||
case TensorTypeQ4_K:
|
||||
return 2 + 2 + 12 + tt.blockSize()/2
|
||||
case TensorTypeQ5_K:
|
||||
return 2 + 2 + 12 + tt.blockSize()/8 + tt.blockSize()/2
|
||||
case TensorTypeQ6_K:
|
||||
return tt.blockSize()/2 + tt.blockSize()/4 + tt.blockSize()/16 + 2
|
||||
case TensorTypeQ8_K:
|
||||
return 4 + tt.blockSize() + 2*tt.blockSize()/16
|
||||
case tensorTypeIQ2_XXS:
|
||||
return 2 + 2*tt.blockSize()/8
|
||||
case tensorTypeIQ2_XS:
|
||||
return 2 + 2*tt.blockSize()/8 + tt.blockSize()/32
|
||||
case tensorTypeIQ3_XXS:
|
||||
return 2 + tt.blockSize()/4 + tt.blockSize()/8
|
||||
case tensorTypeIQ1_S:
|
||||
return 2 + tt.blockSize()/8 + tt.blockSize()/16
|
||||
case tensorTypeIQ4_NL:
|
||||
return 2 + tt.blockSize()/2
|
||||
case tensorTypeIQ3_S:
|
||||
return 2 + tt.blockSize()/4 + tt.blockSize()/8 + tt.blockSize()/32 + 4
|
||||
case tensorTypeIQ2_S:
|
||||
return 2 + tt.blockSize()/4 + tt.blockSize()/16
|
||||
case tensorTypeIQ4_XS:
|
||||
return 2 + 2 + tt.blockSize()/2 + tt.blockSize()/64
|
||||
case TensorTypeI8:
|
||||
return 1
|
||||
case TensorTypeI16:
|
||||
return 2
|
||||
case TensorTypeI32:
|
||||
return 4
|
||||
case TensorTypeI64:
|
||||
return 8
|
||||
case TensorTypeF64:
|
||||
return 8
|
||||
case tensorTypeIQ1_M:
|
||||
return tt.blockSize()/8 + tt.blockSize()/16 + tt.blockSize()/32
|
||||
case TensorTypeBF16:
|
||||
return 2
|
||||
default:
|
||||
return 0
|
||||
}
|
||||
}
|
||||
|
||||
func (tt TensorType) blockSize() int64 {
|
||||
switch tt {
|
||||
case TensorTypeF32,
|
||||
TensorTypeF16,
|
||||
TensorTypeI8,
|
||||
TensorTypeI16,
|
||||
TensorTypeI32,
|
||||
TensorTypeI64,
|
||||
TensorTypeF64,
|
||||
TensorTypeBF16:
|
||||
return 1
|
||||
case TensorTypeQ4_0,
|
||||
TensorTypeQ4_1,
|
||||
TensorTypeQ5_0,
|
||||
TensorTypeQ5_1,
|
||||
TensorTypeQ8_0,
|
||||
TensorTypeQ8_1,
|
||||
tensorTypeIQ4_NL:
|
||||
return 32
|
||||
default:
|
||||
return 256
|
||||
}
|
||||
}
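// Illustrative arithmetic only: Q4_0 packs a block of 32 values into 2 bytes
// of scale plus 16 bytes of 4-bit quants, so typeSize() = 18, blockSize() = 32,
// and NumBytes() reports 18.0/32 = 0.5625 bytes per value.
func exampleQ4_0BytesPerValue() float64 {
	return TensorTypeQ4_0.NumBytes() // 0.5625
}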
|
||||
|
||||
func (tt TensorType) String() string {
|
||||
switch tt {
|
||||
case TensorTypeF32:
|
||||
return "f32"
|
||||
case TensorTypeF16:
|
||||
return "f16"
|
||||
case TensorTypeQ4_0:
|
||||
return "q4_0"
|
||||
case TensorTypeQ4_1:
|
||||
return "q4_1"
|
||||
case tensorTypeQ4_2:
|
||||
return "q4_2"
|
||||
case tensorTypeQ4_3:
|
||||
return "q4_3"
|
||||
case TensorTypeQ5_0:
|
||||
return "q5_0"
|
||||
case TensorTypeQ5_1:
|
||||
return "q5_1"
|
||||
case TensorTypeQ8_0:
|
||||
return "q8_0"
|
||||
case TensorTypeQ8_1:
|
||||
return "q8_1"
|
||||
case TensorTypeQ2_K:
|
||||
return "q2_k"
|
||||
case TensorTypeQ3_K:
|
||||
return "q3_k"
|
||||
case TensorTypeQ4_K:
|
||||
return "q4_k"
|
||||
case TensorTypeQ5_K:
|
||||
return "q5_k"
|
||||
case TensorTypeQ6_K:
|
||||
return "q6_k"
|
||||
case TensorTypeQ8_K:
|
||||
return "q8_k"
|
||||
case tensorTypeIQ2_XXS:
|
||||
return "iq2_xxs"
|
||||
case tensorTypeIQ2_XS:
|
||||
return "iq2_xs"
|
||||
case tensorTypeIQ3_XXS:
|
||||
return "iq3_xxs"
|
||||
case tensorTypeIQ1_S:
|
||||
return "iq1_s"
|
||||
case tensorTypeIQ4_NL:
|
||||
return "iq4_nl"
|
||||
case tensorTypeIQ3_S:
|
||||
return "iq3_s"
|
||||
case tensorTypeIQ2_S:
|
||||
return "iq2_s"
|
||||
case tensorTypeIQ4_XS:
|
||||
return "iq4_xs"
|
||||
case TensorTypeI8:
|
||||
return "i8"
|
||||
case TensorTypeI16:
|
||||
return "i16"
|
||||
case TensorTypeI32:
|
||||
return "i32"
|
||||
case TensorTypeI64:
|
||||
return "i64"
|
||||
case TensorTypeF64:
|
||||
return "f64"
|
||||
case tensorTypeIQ1_M:
|
||||
return "iq1_m"
|
||||
case TensorTypeBF16:
|
||||
return "bf16"
|
||||
case tensorTypeQ4_0_4_4:
|
||||
return "q4_0_4_4"
|
||||
case tensorTypeQ4_0_4_8:
|
||||
return "q4_0_4_8"
|
||||
case tensorTypeQ4_0_8_8:
|
||||
return "q4_0_8_8"
|
||||
case tensorTypeTQ1_0:
|
||||
return "tq1_0"
|
||||
case tensorTypeTQ2_0:
|
||||
return "tq2_0"
|
||||
case tensorTypeIQ4_NL_4_4:
|
||||
return "iq4_nl_4_4"
|
||||
case tensorTypeIQ4_NL_4_8:
|
||||
return "iq4_nl_4_8"
|
||||
case tensorTypeIQ4_NL_8_8:
|
||||
return "iq4_nl_8_8"
|
||||
default:
|
||||
return "unknown"
|
||||
}
|
||||
}
|
||||
|
||||
func (tt TensorType) LogValue() slog.Value {
|
||||
return slog.GroupValue(
|
||||
slog.Uint64("value", uint64(tt)),
|
||||
slog.String("name", strings.ToUpper(tt.String())),
|
||||
slog.Int64("size", tt.typeSize()),
|
||||
slog.Int64("block_size", tt.blockSize()),
|
||||
slog.Float64("num_bytes", tt.NumBytes()),
|
||||
)
|
||||
}
|
||||
6
go.mod
@@ -19,13 +19,12 @@ require (
|
||||
github.com/d4l3k/go-bfloat16 v0.0.0-20211005043715-690c3bdd05f1
|
||||
github.com/dlclark/regexp2 v1.11.4
|
||||
github.com/emirpasic/gods/v2 v2.0.0-alpha
|
||||
github.com/google/go-cmp v0.7.0
|
||||
github.com/google/go-cmp v0.6.0
|
||||
github.com/mattn/go-runewidth v0.0.14
|
||||
github.com/nlpodyssey/gopickle v0.3.0
|
||||
github.com/pdevine/tensor v0.0.0-20240510204454-f88f4562727c
|
||||
golang.org/x/image v0.22.0
|
||||
golang.org/x/tools v0.30.0
|
||||
gonum.org/v1/gonum v0.15.0
|
||||
)
|
||||
|
||||
require (
|
||||
@@ -45,6 +44,7 @@ require (
|
||||
github.com/xtgo/set v1.0.0 // indirect
|
||||
go4.org/unsafe/assume-no-moving-gc v0.0.0-20231121144256-b99613f794b6 // indirect
|
||||
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect
|
||||
gonum.org/v1/gonum v0.15.0 // indirect
|
||||
gorgonia.org/vecf32 v0.9.0 // indirect
|
||||
gorgonia.org/vecf64 v0.9.0 // indirect
|
||||
)
|
||||
@@ -71,7 +71,7 @@ require (
|
||||
github.com/ugorji/go/codec v1.2.12 // indirect
|
||||
golang.org/x/arch v0.8.0 // indirect
|
||||
golang.org/x/crypto v0.36.0
|
||||
golang.org/x/exp v0.0.0-20250218142911-aa4b98e5adaa // indirect
|
||||
golang.org/x/exp v0.0.0-20250218142911-aa4b98e5adaa
|
||||
golang.org/x/net v0.38.0 // indirect
|
||||
golang.org/x/sys v0.31.0
|
||||
golang.org/x/term v0.30.0
|
||||
|
||||
4
go.sum
@@ -112,8 +112,8 @@ github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/
|
||||
github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||
github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
|
||||
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
|
||||
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
|
||||
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
||||
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
||||
github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
|
||||
@@ -1,57 +0,0 @@
|
||||
//go:build integration && library
|
||||
|
||||
package integration
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
// First run of this scenario on a target system will take a long time to download
|
||||
// ~1.5TB of models. Set a sufficiently large -timeout for your network speed
|
||||
func TestLibraryModelsGenerate(t *testing.T) {
|
||||
softTimeout, hardTimeout := getTimeouts(t)
|
||||
slog.Info("Setting timeouts", "soft", softTimeout, "hard", hardTimeout)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), hardTimeout)
|
||||
defer cancel()
|
||||
client, _, cleanup := InitServerConnection(ctx, t)
|
||||
defer cleanup()
|
||||
|
||||
chatModels := libraryChatModels
|
||||
for _, model := range chatModels {
|
||||
t.Run(model, func(t *testing.T) {
|
||||
if time.Now().Sub(started) > softTimeout {
|
||||
t.Skip("skipping remaining tests to avoid excessive runtime")
|
||||
}
|
||||
if err := PullIfMissing(ctx, client, model); err != nil {
|
||||
t.Fatalf("pull failed %s", err)
|
||||
}
|
||||
req := api.GenerateRequest{
|
||||
Model: model,
|
||||
Prompt: "why is the sky blue?",
|
||||
KeepAlive: &api.Duration{Duration: 10 * time.Second},
|
||||
Options: map[string]interface{}{
|
||||
"temperature": 0.1,
|
||||
"seed": 123,
|
||||
},
|
||||
}
|
||||
anyResp := []string{"rayleigh", "scatter", "atmosphere", "nitrogen", "oxygen", "wavelength"}
|
||||
// Special cases
|
||||
if model == "duckdb-nsql" {
|
||||
anyResp = []string{"select", "from"}
|
||||
} else if model == "granite3-guardian" || model == "shieldgemma" || model == "llama-guard3" || model == "bespoke-minicheck" {
|
||||
anyResp = []string{"yes", "no", "safe", "unsafe"}
|
||||
} else if model == "openthinker" || model == "nexusraven" {
|
||||
anyResp = []string{"plugin", "im_sep", "components", "function call"}
|
||||
} else if model == "starcoder" || model == "starcoder2" || model == "magicoder" || model == "deepseek-coder" {
|
||||
req.Prompt = "def fibonacci():"
|
||||
anyResp = []string{"f(n)", "sequence", "n-1", "main()", "__main__", "while"}
|
||||
}
|
||||
DoGenerate(ctx, t, client, req, anyResp, 120*time.Second, 30*time.Second)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -19,6 +19,35 @@ import (
|
||||
"github.com/ollama/ollama/format"
|
||||
)
|
||||
|
||||
var (
|
||||
started = time.Now()
|
||||
chatModels = []string{
|
||||
"granite3-moe:latest",
|
||||
"granite-code:latest",
|
||||
"nemotron-mini:latest",
|
||||
"command-r:latest",
|
||||
"gemma2:latest",
|
||||
"gemma:latest",
|
||||
"internlm2:latest",
|
||||
"phi3.5:latest",
|
||||
"phi3:latest",
|
||||
// "phi:latest", // flaky, sometimes generates no response on first query
|
||||
"stablelm2:latest", // Predictions are off, crashes on small VRAM GPUs
|
||||
"falcon:latest",
|
||||
"falcon2:latest",
|
||||
"minicpm-v:latest",
|
||||
"mistral:latest",
|
||||
"orca-mini:latest",
|
||||
"llama2:latest",
|
||||
"llama3.1:latest",
|
||||
"llama3.2:latest",
|
||||
"llama3.2-vision:latest",
|
||||
"qwen2.5-coder:latest",
|
||||
"qwen:latest",
|
||||
"solar-pro:latest",
|
||||
}
|
||||
)
|
||||
|
||||
func TestModelsGenerate(t *testing.T) {
|
||||
softTimeout, hardTimeout := getTimeouts(t)
|
||||
slog.Info("Setting timeouts", "soft", softTimeout, "hard", hardTimeout)
|
||||
@@ -39,13 +68,6 @@ func TestModelsGenerate(t *testing.T) {
|
||||
slog.Warn("No VRAM info available, testing all models, so larger ones might timeout...")
|
||||
}
|
||||
|
||||
var chatModels []string
|
||||
if s := os.Getenv("OLLAMA_NEW_ENGINE"); s != "" {
|
||||
chatModels = ollamaEngineChatModels
|
||||
} else {
|
||||
chatModels = append(ollamaEngineChatModels, llamaRunnerChatModels...)
|
||||
}
|
||||
|
||||
for _, model := range chatModels {
|
||||
t.Run(model, func(t *testing.T) {
|
||||
if time.Now().Sub(started) > softTimeout {
|
||||
|
||||
@@ -1,266 +0,0 @@
|
||||
//go:build integration && perf
|
||||
|
||||
package integration
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"log/slog"
|
||||
"math"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/format"
|
||||
)
|
||||
|
||||
var (
|
||||
// Models that don't work reliably with the large context prompt in this test case
|
||||
longContextFlakes = []string{
|
||||
"granite-code:latest",
|
||||
"nemotron-mini:latest",
|
||||
"falcon:latest", // 2k model
|
||||
"falcon2:latest", // 2k model
|
||||
"minicpm-v:latest",
|
||||
"qwen:latest",
|
||||
"solar-pro:latest",
|
||||
}
|
||||
)
|
||||
|
||||
// Note: this test case can take a long time to run, particularly on models with
|
||||
// large contexts. Run with -timeout set to a large value to get reasonable coverage
|
||||
// Example usage:
|
||||
//
|
||||
// go test --tags=integration,perf -count 1 ./integration -v -timeout 90m -run TestModelsPerf 2>&1 | tee int.log
|
||||
// cat int.log | grep MODEL_PERF_HEADER | head -1 | cut -f2- -d: > perf.csv
|
||||
// cat int.log | grep MODEL_PERF_DATA | cut -f2- -d: >> perf.csv
|
||||
func TestModelsPerf(t *testing.T) {
|
||||
softTimeout, hardTimeout := getTimeouts(t)
|
||||
slog.Info("Setting timeouts", "soft", softTimeout, "hard", hardTimeout)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), hardTimeout)
|
||||
defer cancel()
|
||||
client, _, cleanup := InitServerConnection(ctx, t)
|
||||
defer cleanup()
|
||||
|
||||
// TODO use info API eventually
|
||||
var maxVram uint64
|
||||
var err error
|
||||
if s := os.Getenv("OLLAMA_MAX_VRAM"); s != "" {
|
||||
maxVram, err = strconv.ParseUint(s, 10, 64)
|
||||
if err != nil {
|
||||
t.Fatalf("invalid OLLAMA_MAX_VRAM %v", err)
|
||||
}
|
||||
} else {
|
||||
slog.Warn("No VRAM info available, testing all models, so larger ones might timeout...")
|
||||
}
|
||||
|
||||
data, err := ioutil.ReadFile(filepath.Join("testdata", "shakespeare.txt"))
|
||||
if err != nil {
|
||||
t.Fatalf("failed to open test data file: %s", err)
|
||||
}
|
||||
longPrompt := "summarize the following: " + string(data)
|
||||
|
||||
var chatModels []string
|
||||
if s := os.Getenv("OLLAMA_NEW_ENGINE"); s != "" {
|
||||
chatModels = ollamaEngineChatModels
|
||||
} else {
|
||||
chatModels = append(ollamaEngineChatModels, llamaRunnerChatModels...)
|
||||
}
|
||||
|
||||
for _, model := range chatModels {
|
||||
t.Run(model, func(t *testing.T) {
|
||||
if time.Now().Sub(started) > softTimeout {
|
||||
t.Skip("skipping remaining tests to avoid excessive runtime")
|
||||
}
|
||||
if err := PullIfMissing(ctx, client, model); err != nil {
|
||||
t.Fatalf("pull failed %s", err)
|
||||
}
|
||||
var maxContext int
|
||||
|
||||
resp, err := client.Show(ctx, &api.ShowRequest{Model: model})
|
||||
if err != nil {
|
||||
t.Fatalf("show failed: %s", err)
|
||||
}
|
||||
arch := resp.ModelInfo["general.architecture"].(string)
|
||||
maxContext = int(resp.ModelInfo[fmt.Sprintf("%s.context_length", arch)].(float64))
|
||||
|
||||
if maxVram > 0 {
|
||||
resp, err := client.List(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("list models failed %v", err)
|
||||
}
|
||||
for _, m := range resp.Models {
|
||||
// For these tests we want to exercise some amount of overflow onto the CPU
|
||||
if m.Name == model && float32(m.Size)*0.75 > float32(maxVram) {
|
||||
t.Skipf("model %s (%s) is too large for available VRAM %s", model, format.HumanBytes(m.Size), format.HumanBytes(int64(maxVram)))
|
||||
}
|
||||
}
|
||||
}
|
||||
slog.Info("scenario", "model", model, "max_context", maxContext)
|
||||
loaded := false
|
||||
defer func() {
|
||||
// best effort unload once we're done with the model
|
||||
if loaded {
|
||||
client.Generate(ctx, &api.GenerateRequest{Model: model, KeepAlive: &api.Duration{Duration: 0}}, func(rsp api.GenerateResponse) error { return nil })
|
||||
}
|
||||
}()
|
||||
|
||||
// Some models don't handle the long context data well so skip them to avoid flaky test results
|
||||
longContextFlake := false
|
||||
for _, flake := range longContextFlakes {
|
||||
if model == flake {
|
||||
longContextFlake = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// iterate through a few context sizes for coverage without excessive runtime
|
||||
var contexts []int
|
||||
keepGoing := true
|
||||
if maxContext > 16384 {
|
||||
contexts = []int{4096, 8192, 16384, maxContext}
|
||||
} else if maxContext > 8192 {
|
||||
contexts = []int{4096, 8192, maxContext}
|
||||
} else if maxContext > 4096 {
|
||||
contexts = []int{4096, maxContext}
|
||||
} else if maxContext > 0 {
|
||||
contexts = []int{maxContext}
|
||||
} else {
|
||||
t.Fatal("unknown max context size")
|
||||
}
|
||||
for _, numCtx := range contexts {
|
||||
if !keepGoing && numCtx > 8192 { // Always try up to 8k before bailing out
|
||||
break
|
||||
}
|
||||
skipLongPrompt := false
|
||||
|
||||
// Workaround bug 11172 temporarily...
|
||||
maxPrompt := longPrompt
|
||||
// If we fill the context too full with the prompt, many models
|
||||
// quickly hit context shifting and the output degrades.
|
||||
if len(maxPrompt) > numCtx*2 { // typically yields ~1/2 full context
|
||||
maxPrompt = maxPrompt[:numCtx*2]
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
prompt string
|
||||
anyResp []string
|
||||
}{
|
||||
{"why is the sky blue?", []string{"rayleigh", "scattering", "atmosphere", "nitrogen", "oxygen"}},
|
||||
{maxPrompt, []string{"shakespeare", "oppression", "sorrows", "gutenberg", "child", "license", "sonnet", "melancholy"}},
|
||||
}
|
||||
var gpuPercent int
|
||||
for _, tc := range testCases {
|
||||
if len(tc.prompt) > 100 && (longContextFlake || skipLongPrompt) {
|
||||
slog.Info("skipping long prompt", "model", model, "num_ctx", numCtx, "gpu_percent", gpuPercent)
|
||||
continue
|
||||
}
|
||||
req := api.GenerateRequest{
|
||||
Model: model,
|
||||
Prompt: tc.prompt,
|
||||
KeepAlive: &api.Duration{Duration: 20 * time.Second}, // long enough to ensure a ps returns
|
||||
Options: map[string]interface{}{
|
||||
"temperature": 0,
|
||||
"seed": 123,
|
||||
"num_ctx": numCtx,
|
||||
},
|
||||
}
|
||||
atLeastOne := false
|
||||
var resp api.GenerateResponse
|
||||
|
||||
stream := false
|
||||
req.Stream = &stream
|
||||
|
||||
// Avoid potentially getting stuck indefinitely
|
||||
limit := 5 * time.Minute
|
||||
genCtx, cancel := context.WithDeadlineCause(
|
||||
ctx,
|
||||
time.Now().Add(limit),
|
||||
fmt.Errorf("generate on model %s with ctx %d took longer than %v", model, numCtx, limit),
|
||||
)
|
||||
defer cancel()
|
||||
|
||||
err = client.Generate(genCtx, &req, func(rsp api.GenerateResponse) error {
|
||||
resp = rsp
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
// Avoid excessive runtime, but don't treat a timeout at very large contexts as a failure
|
||||
if numCtx > 16384 && strings.Contains(err.Error(), "took longer") {
|
||||
slog.Warn("max context was taking too long, skipping", "error", err)
|
||||
keepGoing = false
|
||||
skipLongPrompt = true
|
||||
continue
|
||||
}
|
||||
t.Fatalf("generate error: ctx:%d err:%s", numCtx, err)
|
||||
}
|
||||
loaded = true
|
||||
for _, expResp := range tc.anyResp {
|
||||
if strings.Contains(strings.ToLower(resp.Response), expResp) {
|
||||
atLeastOne = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !atLeastOne {
|
||||
t.Fatalf("response didn't contain expected values: ctx:%d expected:%v response:%s ", numCtx, tc.anyResp, resp.Response)
|
||||
}
|
||||
models, err := client.ListRunning(ctx)
|
||||
if err != nil {
|
||||
slog.Warn("failed to list running models", "error", err)
|
||||
continue
|
||||
}
|
||||
if len(models.Models) > 1 {
|
||||
slog.Warn("multiple models loaded, may impact performance results", "loaded", models.Models)
|
||||
}
|
||||
for _, m := range models.Models {
|
||||
if m.Name == model {
|
||||
if m.SizeVRAM == 0 {
|
||||
slog.Info("Model fully loaded into CPU")
|
||||
gpuPercent = 0
|
||||
keepGoing = false
|
||||
skipLongPrompt = true
|
||||
} else if m.SizeVRAM == m.Size {
|
||||
slog.Info("Model fully loaded into GPU")
|
||||
gpuPercent = 100
|
||||
} else {
|
||||
sizeCPU := m.Size - m.SizeVRAM
|
||||
cpuPercent := math.Round(float64(sizeCPU) / float64(m.Size) * 100)
|
||||
gpuPercent = int(100 - cpuPercent)
|
||||
slog.Info("Model split between CPU/GPU", "CPU", cpuPercent, "GPU", gpuPercent)
|
||||
keepGoing = false
|
||||
|
||||
// Heuristic to avoid excessive test run time
|
||||
if gpuPercent < 90 {
|
||||
skipLongPrompt = true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, "MODEL_PERF_HEADER:%s,%s,%s,%s,%s,%s,%s\n",
|
||||
"MODEL",
|
||||
"CONTEXT",
|
||||
"GPU PERCENT",
|
||||
"PROMPT COUNT",
|
||||
"LOAD TIME",
|
||||
"PROMPT EVAL TPS",
|
||||
"EVAL TPS",
|
||||
)
|
||||
fmt.Fprintf(os.Stderr, "MODEL_PERF_DATA:%s,%d,%d,%d,%0.2f,%0.2f,%0.2f\n",
|
||||
model,
|
||||
numCtx,
|
||||
gpuPercent,
|
||||
resp.PromptEvalCount,
|
||||
float64(resp.LoadDuration)/1000000000.0,
|
||||
float64(resp.PromptEvalCount)/(float64(resp.PromptEvalDuration)/1000000000.0),
|
||||
float64(resp.EvalCount)/(float64(resp.EvalDuration)/1000000000.0),
|
||||
)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
124456
integration/testdata/shakespeare.txt
vendored
File diff suppressed because it is too large
@@ -32,229 +32,6 @@ const (
|
||||
smol = "llama3.2:1b"
|
||||
)
|
||||
|
||||
var (
|
||||
started = time.Now()
|
||||
|
||||
// Note: add newer models at the top of the list to test them first
|
||||
ollamaEngineChatModels = []string{
|
||||
"gemma3n:e2b",
|
||||
"mistral-small3.2:latest",
|
||||
"deepseek-r1:1.5b",
|
||||
"llama3.2-vision:latest",
|
||||
"qwen2.5-coder:latest",
|
||||
"qwen2.5vl:3b",
|
||||
"qwen3:0.6b", // dense
|
||||
"qwen3:30b", // MOE
|
||||
"gemma3:1b",
|
||||
"llama3.1:latest",
|
||||
"llama3.2:latest",
|
||||
"gemma2:latest",
|
||||
"minicpm-v:latest", // arch=qwen2
|
||||
"granite-code:latest", // arch=llama
|
||||
}
|
||||
llamaRunnerChatModels = []string{
|
||||
"mistral:latest",
|
||||
"falcon3:latest",
|
||||
"granite3-moe:latest",
|
||||
"command-r:latest",
|
||||
"nemotron-mini:latest",
|
||||
"phi3.5:latest",
|
||||
"solar-pro:latest",
|
||||
"internlm2:latest",
|
||||
"codellama:latest", // arch=llama
|
||||
"phi3:latest",
|
||||
"falcon2:latest",
|
||||
"gemma:latest",
|
||||
"llama2:latest",
|
||||
"nous-hermes:latest",
|
||||
"orca-mini:latest",
|
||||
"qwen:latest",
|
||||
"stablelm2:latest", // Predictions are off, crashes on small VRAM GPUs
|
||||
"falcon:latest",
|
||||
}
|
||||
|
||||
// Some library models are quite large - ensure large VRAM and sufficient disk space
|
||||
// before running scenarios based on this set
|
||||
libraryChatModels = []string{
|
||||
"alfred",
|
||||
"athene-v2",
|
||||
"aya-expanse",
|
||||
"aya",
|
||||
"bakllava",
|
||||
"bespoke-minicheck",
|
||||
"codebooga",
|
||||
"codegeex4",
|
||||
"codegemma",
|
||||
"codellama",
|
||||
"codeqwen",
|
||||
"codestral",
|
||||
"codeup",
|
||||
"cogito",
|
||||
"command-a",
|
||||
"command-r-plus",
|
||||
"command-r",
|
||||
"command-r7b-arabic",
|
||||
"command-r7b",
|
||||
"dbrx",
|
||||
"deepcoder",
|
||||
"deepscaler",
|
||||
"deepseek-coder-v2",
|
||||
"deepseek-coder",
|
||||
"deepseek-llm",
|
||||
"deepseek-r1",
|
||||
// "deepseek-v2.5", // requires 155 GB VRAM
|
||||
"deepseek-v2",
|
||||
// "deepseek-v3", // requires 482 GB VRAM
|
||||
"devstral",
|
||||
"dolphin-llama3",
|
||||
"dolphin-mistral",
|
||||
"dolphin-mixtral",
|
||||
"dolphin-phi",
|
||||
"dolphin3",
|
||||
"dolphincoder",
|
||||
"duckdb-nsql",
|
||||
"everythinglm",
|
||||
"exaone-deep",
|
||||
"exaone3.5",
|
||||
"falcon",
|
||||
"falcon2",
|
||||
"falcon3",
|
||||
"firefunction-v2",
|
||||
"gemma",
|
||||
"gemma2",
|
||||
"gemma3",
|
||||
"gemma3n",
|
||||
"glm4",
|
||||
"goliath",
|
||||
"granite-code",
|
||||
"granite3-dense",
|
||||
"granite3-guardian",
|
||||
"granite3-moe",
|
||||
"granite3.1-dense",
|
||||
"granite3.1-moe",
|
||||
"granite3.2-vision",
|
||||
"granite3.2",
|
||||
"granite3.3",
|
||||
"hermes3",
|
||||
"internlm2",
|
||||
"llama-guard3",
|
||||
"llama-pro",
|
||||
"llama2-chinese",
|
||||
"llama2-uncensored",
|
||||
"llama2",
|
||||
"llama3-chatqa",
|
||||
"llama3-gradient",
|
||||
"llama3-groq-tool-use",
|
||||
"llama3.1",
|
||||
"llama3.2-vision",
|
||||
"llama3.2",
|
||||
"llama3.3",
|
||||
"llama3",
|
||||
"llama4",
|
||||
"llava-llama3",
|
||||
"llava-phi3",
|
||||
"llava",
|
||||
"magicoder",
|
||||
"magistral",
|
||||
"marco-o1",
|
||||
"mathstral",
|
||||
"meditron",
|
||||
"medllama2",
|
||||
"megadolphin",
|
||||
"minicpm-v",
|
||||
"mistral-large",
|
||||
"mistral-nemo",
|
||||
"mistral-openorca",
|
||||
"mistral-small",
|
||||
"mistral-small3.1",
|
||||
"mistral-small3.2",
|
||||
"mistral",
|
||||
"mistrallite",
|
||||
"mixtral",
|
||||
"moondream",
|
||||
"nemotron-mini",
|
||||
"nemotron",
|
||||
"neural-chat",
|
||||
"nexusraven",
|
||||
"notus",
|
||||
"nous-hermes",
|
||||
"nous-hermes2-mixtral",
|
||||
"nous-hermes2",
|
||||
"nuextract",
|
||||
"olmo2",
|
||||
"open-orca-platypus2",
|
||||
"openchat",
|
||||
"opencoder",
|
||||
"openhermes",
|
||||
"openthinker",
|
||||
"orca-mini",
|
||||
"orca2",
|
||||
// "phi", // unreliable
|
||||
"phi3.5",
|
||||
"phi3",
|
||||
"phi4-mini-reasoning",
|
||||
"phi4-mini",
|
||||
"phi4-reasoning",
|
||||
"phi4",
|
||||
"phind-codellama",
|
||||
"qwen",
|
||||
"qwen2-math",
|
||||
"qwen2.5-coder",
|
||||
"qwen2.5",
|
||||
"qwen2.5vl",
|
||||
"qwen2",
|
||||
"qwen3:0.6b", // dense
|
||||
"qwen3:30b", // MOE
|
||||
"qwq",
|
||||
"r1-1776",
|
||||
"reader-lm",
|
||||
"reflection",
|
||||
"sailor2",
|
||||
"samantha-mistral",
|
||||
"shieldgemma",
|
||||
"smallthinker",
|
||||
"smollm",
|
||||
"smollm2",
|
||||
"solar-pro",
|
||||
"solar",
|
||||
"sqlcoder",
|
||||
"stable-beluga",
|
||||
"stable-code",
|
||||
"stablelm-zephyr",
|
||||
"stablelm2",
|
||||
"starcoder",
|
||||
"starcoder2",
|
||||
"starling-lm",
|
||||
"tinydolphin",
|
||||
"tinyllama",
|
||||
"tulu3",
|
||||
"vicuna",
|
||||
"wizard-math",
|
||||
"wizard-vicuna-uncensored",
|
||||
"wizard-vicuna",
|
||||
"wizardcoder",
|
||||
"wizardlm-uncensored",
|
||||
"wizardlm2",
|
||||
"xwinlm",
|
||||
"yarn-llama2",
|
||||
"yarn-mistral",
|
||||
"yi-coder",
|
||||
"yi",
|
||||
"zephyr",
|
||||
}
|
||||
libraryEmbedModels = []string{
|
||||
"all-minilm",
|
||||
"bge-large",
|
||||
"bge-m3",
|
||||
"granite-embedding",
|
||||
"mxbai-embed-large",
|
||||
"nomic-embed-text",
|
||||
"paraphrase-multilingual",
|
||||
"snowflake-arctic-embed",
|
||||
"snowflake-arctic-embed2",
|
||||
}
|
||||
)
|
||||
|
||||
func Init() {
|
||||
lifecycle.InitLogging()
|
||||
}
|
||||
@@ -494,10 +271,6 @@ func DoGenerate(ctx context.Context, t *testing.T, client *api.Client, genReq ap
|
||||
t.Errorf("generate stalled. Response so far:%s", buf.String())
|
||||
}
|
||||
case <-done:
|
||||
if genErr != nil && strings.Contains(genErr.Error(), "model requires more system memory") {
|
||||
slog.Warn("model is too large for the target test system", "model", genReq.Model, "error", genErr)
|
||||
return
|
||||
}
|
||||
require.NoError(t, genErr, "failed with %s request prompt %s ", genReq.Model, genReq.Prompt)
|
||||
// Verify the response contains the expected data
|
||||
response := buf.String()
|
||||
|
||||
@@ -25,9 +25,6 @@ type Causal struct {
|
||||
|
||||
opts CausalOptions
|
||||
|
||||
// maxBatch is the largest batch that we might receive
|
||||
maxBatch int
|
||||
|
||||
// config controls mostly backend-specific optimizations
|
||||
config *ml.CacheConfig
|
||||
|
||||
@@ -150,7 +147,6 @@ func (c *Causal) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity
|
||||
c.DType = dtype
|
||||
c.cellRanges = make(map[int]cellRange)
|
||||
c.backend = backend
|
||||
c.maxBatch = maxBatch
|
||||
}
|
||||
|
||||
func (c *Causal) SetConfig(config ml.CacheConfig) {
|
||||
@@ -643,64 +639,48 @@ func (c *Causal) shift(seq int, beginIndex, offset int32) error {
|
||||
return ErrNotSupported
|
||||
}
|
||||
|
||||
ctx := c.backend.NewContext()
|
||||
defer ctx.Close()
|
||||
|
||||
seqRange := c.cellRanges[seq]
|
||||
size := seqRange.max - seqRange.min + 1
|
||||
|
||||
for start := seqRange.min; start <= seqRange.max; start += c.maxBatch {
|
||||
size := min(seqRange.max-start+1, c.maxBatch)
|
||||
offsets := make([]int32, size)
|
||||
offsets := make([]int32, size)
|
||||
for i := range offsets {
|
||||
cell := c.cells[seqRange.min+i]
|
||||
|
||||
var batchFirst, batchLast int
|
||||
|
||||
batchFirst = -1
|
||||
for i := range offsets {
|
||||
cell := c.cells[start+i]
|
||||
|
||||
if slices.Contains(cell.sequences, seq) && cell.pos >= beginIndex {
|
||||
offsets[i] = offset
|
||||
if batchFirst < 0 {
|
||||
batchFirst = i
|
||||
}
|
||||
batchLast = i
|
||||
}
|
||||
if slices.Contains(cell.sequences, seq) && cell.pos >= beginIndex {
|
||||
offsets[i] = offset
|
||||
}
|
||||
}
|
||||
|
||||
if batchFirst < 0 {
|
||||
kShift := ctx.Input().FromIntSlice(offsets, len(offsets))
|
||||
|
||||
for i, key := range c.keys {
|
||||
if key == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
offsets = offsets[batchFirst : batchLast+1]
|
||||
kHeadDim := key.Dim(0)
|
||||
numKVHeads := key.Dim(1)
|
||||
rowSize := key.Stride(2)
|
||||
|
||||
ctx := c.backend.NewContext()
|
||||
kShift := ctx.Input().FromIntSlice(offsets, len(offsets))
|
||||
key = key.View(ctx, rowSize*seqRange.min,
|
||||
kHeadDim, key.Stride(1),
|
||||
numKVHeads, key.Stride(2),
|
||||
size,
|
||||
)
|
||||
|
||||
for i, key := range c.keys {
|
||||
if key == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
kHeadDim := key.Dim(0)
|
||||
numKVHeads := key.Dim(1)
|
||||
rowSize := key.Stride(2)
|
||||
|
||||
key = key.View(ctx, rowSize*(start+batchFirst),
|
||||
kHeadDim, key.Stride(1),
|
||||
numKVHeads, key.Stride(2),
|
||||
len(offsets),
|
||||
)
|
||||
|
||||
roped, err := c.shiftFn(ctx, i, key, kShift)
|
||||
if err != nil {
|
||||
ctx.Close()
|
||||
return err
|
||||
}
|
||||
|
||||
ctx.Forward(roped.Copy(ctx, key))
|
||||
roped, err := c.shiftFn(ctx, i, key, kShift)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ctx.Compute()
|
||||
ctx.Close()
|
||||
ctx.Forward(roped.Copy(ctx, key))
|
||||
}
|
||||
|
||||
ctx.Compute()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -150,7 +150,7 @@ index 4cce5166..7f6617fa 100644
|
||||
llama_model_loader::llama_model_loader(
|
||||
const std::string & fname,
|
||||
diff --git a/src/llama-model.cpp b/src/llama-model.cpp
|
||||
index 3a4e72a3..db62973f 100644
|
||||
index 3a4e72a3..831b68c0 100644
|
||||
--- a/src/llama-model.cpp
|
||||
+++ b/src/llama-model.cpp
|
||||
@@ -1402,6 +1402,21 @@ void llama_model::load_hparams(llama_model_loader & ml) {
|
||||
|
||||
@@ -22,10 +22,10 @@ multiple batches of processing until everything is complete.
|
||||
4 files changed, 59 insertions(+), 79 deletions(-)
|
||||
|
||||
diff --git a/src/llama-context.cpp b/src/llama-context.cpp
|
||||
index dca22d8b..1f3a3956 100644
|
||||
index c22687e4..c5948e8f 100644
|
||||
--- a/src/llama-context.cpp
|
||||
+++ b/src/llama-context.cpp
|
||||
@@ -947,9 +947,12 @@ int llama_context::decode(llama_batch & inp_batch) {
|
||||
@@ -950,9 +950,12 @@ int llama_context::decode(llama_batch & inp_batch) {
|
||||
|
||||
// find KV slot
|
||||
if (!kv_self->find_slot(ubatch)) {
|
||||
@@ -41,7 +41,7 @@ index dca22d8b..1f3a3956 100644
|
||||
}
|
||||
|
||||
ggml_backend_sched_reset(sched.get());
|
||||
@@ -1965,9 +1968,12 @@ void llama_context::opt_epoch_iter(
|
||||
@@ -1967,9 +1970,12 @@ void llama_context::opt_epoch_iter(
|
||||
|
||||
// TODO: not sure if this is needed
|
||||
if (!kv_self->find_slot(ubatch)) {
|
||||
|
||||
@@ -10,10 +10,10 @@ Subject: [PATCH] add argsort and cuda copy for i32
|
||||
3 files changed, 192 insertions(+), 2 deletions(-)
|
||||
|
||||
diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp
|
||||
index 955fec59..654e2f28 100644
|
||||
index becdae07..7a44b6cf 100644
|
||||
--- a/ggml/src/ggml-cpu/ops.cpp
|
||||
+++ b/ggml/src/ggml-cpu/ops.cpp
|
||||
@@ -6822,6 +6822,45 @@ static void ggml_compute_forward_argsort_f32(
|
||||
@@ -6890,6 +6890,45 @@ static void ggml_compute_forward_argsort_f32(
|
||||
}
|
||||
}
|
||||
|
||||
@@ -59,7 +59,7 @@ index 955fec59..654e2f28 100644
|
||||
void ggml_compute_forward_argsort(
|
||||
const ggml_compute_params * params,
|
||||
ggml_tensor * dst) {
|
||||
@@ -6833,6 +6872,10 @@ void ggml_compute_forward_argsort(
|
||||
@@ -6901,6 +6940,10 @@ void ggml_compute_forward_argsort(
|
||||
{
|
||||
ggml_compute_forward_argsort_f32(params, dst);
|
||||
} break;
|
||||
@@ -195,7 +195,7 @@ index 607ded85..53b02634 100644
|
||||
+ }
|
||||
}
|
||||
diff --git a/ggml/src/ggml-cuda/cpy.cu b/ggml/src/ggml-cuda/cpy.cu
|
||||
index d027271f..4abd01d7 100644
|
||||
index 2d46176e..47383486 100644
|
||||
--- a/ggml/src/ggml-cuda/cpy.cu
|
||||
+++ b/ggml/src/ggml-cuda/cpy.cu
|
||||
@@ -38,6 +38,13 @@ static __device__ void cpy_1_f16_f32(const char * cxi, char * cdsti) {
|
||||
@@ -257,7 +257,7 @@ index d027271f..4abd01d7 100644
|
||||
static __device__ void cpy_blck_f32_q8_0(const char * cxi, char * cdsti) {
|
||||
const float * xi = (const float *) cxi;
|
||||
block_q8_0 * dsti = (block_q8_0 *) cdsti;
|
||||
@@ -633,6 +678,8 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
|
||||
@@ -631,6 +676,8 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
|
||||
ggml_cpy_f16_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
|
||||
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
|
||||
ggml_cpy_f16_f32_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
|
||||
@@ -266,7 +266,7 @@ index d027271f..4abd01d7 100644
|
||||
} else {
|
||||
GGML_ABORT("%s: unsupported type combination (%s to %s)\n", __func__,
|
||||
ggml_type_name(src0->type), ggml_type_name(src1->type));
|
||||
@@ -688,6 +735,8 @@ void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
|
||||
@@ -686,6 +733,8 @@ void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
|
||||
return (void*) cpy_f32_f16<cpy_1_f32_f16>;
|
||||
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
|
||||
return (void*) cpy_f32_f16<cpy_1_f16_f32>;
|
||||
|
||||
@@ -7,31 +7,31 @@ This enables matching up devices and information reported by the backend
|
||||
with tools (e.g. nvidia-smi) and system management libraries (e.g. nvml).
---
ggml/include/ggml-backend.h | 1 +
ggml/src/ggml-cuda/ggml-cuda.cu | 39 ++++++++++++++++++++++++++++++++
ggml/src/ggml-cuda/ggml-cuda.cu | 33 ++++++++++++++++++++++++++++++++
ggml/src/ggml-metal/ggml-metal.m | 1 +
3 files changed, 41 insertions(+)
3 files changed, 35 insertions(+)

diff --git a/ggml/include/ggml-backend.h b/ggml/include/ggml-backend.h
index 74e46716..48839339 100644
index 74e46716..a880df33 100644
--- a/ggml/include/ggml-backend.h
+++ b/ggml/include/ggml-backend.h
@@ -152,6 +152,7 @@ extern "C" {
struct ggml_backend_dev_props {
const char * name;
const char * description;
+ const char * id;
+ const char * uuid;
size_t memory_free;
size_t memory_total;
enum ggml_backend_dev_type type;
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index cb0d8528..d6960174 100644
index cb0d8528..4c829153 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -2884,6 +2884,7 @@ struct ggml_backend_cuda_device_context {
int device;
std::string name;
std::string description;
+ std::string id;
+ std::string uuid;
};

static const char * ggml_backend_cuda_device_get_name(ggml_backend_dev_t dev) {
@@ -39,9 +39,9 @@ index cb0d8528..d6960174 100644
return ctx->description.c_str();
}

+static const char * ggml_backend_cuda_device_get_id(ggml_backend_dev_t dev) {
+static const char * ggml_backend_cuda_device_get_uuid(ggml_backend_dev_t dev) {
+ ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context;
+ return ctx->id.c_str();
+ return ctx->uuid.c_str();
+}
+
static void ggml_backend_cuda_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
@@ -51,17 +51,17 @@ index cb0d8528..d6960174 100644
static void ggml_backend_cuda_device_get_props(ggml_backend_dev_t dev, ggml_backend_dev_props * props) {
props->name = ggml_backend_cuda_device_get_name(dev);
props->description = ggml_backend_cuda_device_get_description(dev);
+ props->id = ggml_backend_cuda_device_get_id(dev);
+ props->uuid = ggml_backend_cuda_device_get_uuid(dev);
props->type = ggml_backend_cuda_device_get_type(dev);
ggml_backend_cuda_device_get_memory(dev, &props->memory_free, &props->memory_total);

@@ -3458,6 +3465,38 @@ ggml_backend_reg_t ggml_backend_cuda_reg() {
@@ -3458,6 +3465,32 @@ ggml_backend_reg_t ggml_backend_cuda_reg() {
CUDA_CHECK(cudaGetDeviceProperties(&prop, i));
dev_ctx->description = prop.name;

+ #if !defined(GGML_USE_HIP)
+ char id[64];
+ snprintf(id, sizeof(id),
+ char uuid[64];
+ snprintf(uuid, sizeof(uuid),
+ "GPU-%02x%02x%02x%02x-%02x%02x-%02x%02x-%02x%02x-%02x%02x%02x%02x%02x%02x",
+ (unsigned char)prop.uuid.bytes[0],
+ (unsigned char)prop.uuid.bytes[1],
@@ -80,29 +80,23 @@ index cb0d8528..d6960174 100644
+ (unsigned char)prop.uuid.bytes[14],
+ (unsigned char)prop.uuid.bytes[15]
+ );
+ dev_ctx->id = id;
+ dev_ctx->uuid = uuid;
+ #else
+ #ifdef _WIN32
+ char id[16];
+ snprintf(id, sizeof(id), "%d", i);
+ dev_ctx->id = id;
+ #else
+ dev_ctx->id = "GPU-" + std::string(prop.uuid.bytes, 16);
+ #endif
+ dev_ctx->uuid = "GPU-" + std::string(prop.uuid.bytes, 16);
+ #endif
+
ggml_backend_dev_t dev = new ggml_backend_device {
/* .iface = */ ggml_backend_cuda_device_interface,
/* .reg = */ &reg,
diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m
index 1b56f858..a9eeebc6 100644
index 1b56f858..ee4f2dcb 100644
--- a/ggml/src/ggml-metal/ggml-metal.m
+++ b/ggml/src/ggml-metal/ggml-metal.m
@@ -5703,6 +5703,7 @@ static enum ggml_backend_dev_type ggml_backend_metal_device_get_type(ggml_backen
static void ggml_backend_metal_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) {
props->name = ggml_backend_metal_device_get_name(dev);
props->description = ggml_backend_metal_device_get_description(dev);
+ props->id = "0";
+ props->uuid = "0";
props->type = ggml_backend_metal_device_get_type(dev);
ggml_backend_metal_device_get_memory(dev, &props->memory_free, &props->memory_total);
props->caps = (struct ggml_backend_dev_caps) {

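For reference, the `uuid` string assembled in the hunk above uses the canonical 8-4-4-4-12 layout with a `GPU-` prefix, which is the same spelling nvidia-smi and NVML report, so the two can be compared directly. A minimal Go sketch of that formatting and comparison (the helper names are illustrative, not part of this change):

```go
package main

import (
	"fmt"
	"strings"
)

// formatGPUUUID mirrors the snprintf in the patch above: it renders the
// 16 raw UUID bytes reported by the CUDA runtime in the canonical
// 8-4-4-4-12 form, prefixed with "GPU-".
func formatGPUUUID(b [16]byte) string {
	return fmt.Sprintf("GPU-%02x%02x%02x%02x-%02x%02x-%02x%02x-%02x%02x-%02x%02x%02x%02x%02x%02x",
		b[0], b[1], b[2], b[3], b[4], b[5], b[6], b[7],
		b[8], b[9], b[10], b[11], b[12], b[13], b[14], b[15])
}

// sameDevice is a case-insensitive comparison against a UUID string taken
// from a management tool such as `nvidia-smi --query-gpu=uuid`.
func sameDevice(propsUUID, smiUUID string) bool {
	return strings.EqualFold(strings.TrimSpace(propsUUID), strings.TrimSpace(smiUUID))
}

func main() {
	var raw [16]byte
	for i := range raw {
		raw[i] = byte(i)
	}
	u := formatGPUUUID(raw)
	fmt.Println(u, sameDevice(u, strings.ToUpper(u))) // GPU-00010203-0405-0607-0809-0a0b0c0d0e0f true
}
```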
@@ -1,32 +0,0 @@
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: Daniel Hiltgen <daniel@ollama.com>
Date: Sun, 22 Jun 2025 09:22:05 -0700
Subject: [PATCH] temporary prevent rocm+cuda mixed loading

---
ggml/src/ggml-backend-reg.cpp | 12 ++++++++++--
1 file changed, 10 insertions(+), 2 deletions(-)

diff --git a/ggml/src/ggml-backend-reg.cpp b/ggml/src/ggml-backend-reg.cpp
index 4e67d243..8f49f084 100644
--- a/ggml/src/ggml-backend-reg.cpp
+++ b/ggml/src/ggml-backend-reg.cpp
@@ -573,8 +573,16 @@ void ggml_backend_load_all_from_path(const char * dir_path) {

ggml_backend_load_best("blas", silent, dir_path);
ggml_backend_load_best("cann", silent, dir_path);
- ggml_backend_load_best("cuda", silent, dir_path);
- ggml_backend_load_best("hip", silent, dir_path);
+
+ // Avoid mixed hip+cuda configurations
+ const char * hip_devices = std::getenv("HIP_VISIBLE_DEVICES");
+ const char * rocr_devices = std::getenv("ROCR_VISIBLE_DEVICES");
+ if (!hip_devices && !rocr_devices) {
+ ggml_backend_load_best("cuda", silent, dir_path);
+ } else {
+ ggml_backend_load_best("hip", silent, dir_path);
+ }
+
ggml_backend_load_best("kompute", silent, dir_path);
ggml_backend_load_best("metal", silent, dir_path);
ggml_backend_load_best("rpc", silent, dir_path);
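The guard in this patch keys purely off environment variables: probe only the CUDA backend when neither `HIP_VISIBLE_DEVICES` nor `ROCR_VISIBLE_DEVICES` is set, otherwise probe only HIP. A rough Go equivalent of that selection, with `loadBest` standing in as a placeholder for `ggml_backend_load_best`:

```go
package main

import (
	"fmt"
	"os"
)

// loadBest is a stand-in for ggml_backend_load_best; it only logs here.
func loadBest(name string) { fmt.Println("loading backend:", name) }

// loadGPUBackends reproduces the guard above: prefer CUDA unless the caller
// has explicitly scoped ROCm devices via environment variables.
func loadGPUBackends() {
	_, hipSet := os.LookupEnv("HIP_VISIBLE_DEVICES")
	_, rocrSet := os.LookupEnv("ROCR_VISIBLE_DEVICES")
	if !hipSet && !rocrSet {
		loadBest("cuda")
	} else {
		loadBest("hip")
	}
}

func main() { loadGPUBackends() }
```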
@@ -1,169 +0,0 @@
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov <ggerganov@gmail.com>
Date: Thu, 19 Jun 2025 08:05:21 +0300
Subject: [PATCH] metal : add mean kernel (#14267)

* metal : add mean kernel

ggml-ci

* cont : dedup implementation

ggml-ci
---
ggml/src/ggml-metal/ggml-metal.m | 33 ++++++++++++++++---
ggml/src/ggml-metal/ggml-metal.metal | 48 ++++++++++++++++++++++------
2 files changed, 67 insertions(+), 14 deletions(-)

diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m
index a9eeebc6..110c9ece 100644
--- a/ggml/src/ggml-metal/ggml-metal.m
+++ b/ggml/src/ggml-metal/ggml-metal.m
@@ -489,6 +489,7 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_COS,
GGML_METAL_KERNEL_TYPE_NEG,
GGML_METAL_KERNEL_TYPE_SUM_ROWS,
+ GGML_METAL_KERNEL_TYPE_MEAN,
GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,
GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32,
GGML_METAL_KERNEL_TYPE_ARGMAX,
@@ -1436,6 +1437,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COS, cos, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NEG, neg, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MEAN, mean, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGMAX, argmax, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32, pool_2d_avg_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32, pool_2d_max_f32, true);
@@ -1634,6 +1636,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
case GGML_OP_LOG:
return false; // TODO: implement
case GGML_OP_SUM_ROWS:
+ case GGML_OP_MEAN:
case GGML_OP_SOFT_MAX:
case GGML_OP_GROUP_NORM:
return has_simdgroup_reduction && ggml_is_contiguous(op->src[0]);
@@ -2362,11 +2365,30 @@ static bool ggml_metal_encode_node(
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} break;
case GGML_OP_SUM_ROWS:
+ case GGML_OP_MEAN:
{
GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));

- id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUM_ROWS].pipeline;
+ id<MTLComputePipelineState> pipeline = nil;
+
+ switch (dst->op) {
+ case GGML_OP_SUM_ROWS:
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUM_ROWS].pipeline;
+ break;
+ case GGML_OP_MEAN:
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MEAN].pipeline;
+ break;
+ default:
+ GGML_ABORT("fatal error");
+ }
+
+ int nth = 32; // SIMD width
+
+ while (nth < ne00 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) {
+ nth *= 2;
+ }

+ nth = MIN(nth, ne00);

ggml_metal_kargs_sum_rows args = {
/*.ne00 =*/ ne00,
@@ -2396,11 +2418,12 @@ static bool ggml_metal_encode_node(
};

[encoder setComputePipelineState:pipeline];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
- [encoder setBytes:&args length:sizeof(args) atIndex:2];
+ [encoder setBytes:&args length:sizeof(args) atIndex:0];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+ [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];

- [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
} break;
case GGML_OP_SOFT_MAX:
{
diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal
index 9cfddf45..08e8d807 100644
--- a/ggml/src/ggml-metal/ggml-metal.metal
+++ b/ggml/src/ggml-metal/ggml-metal.metal
@@ -956,31 +956,61 @@ kernel void kernel_neg(
dst[tpig] = -src0[tpig];
}

+template <bool norm>
kernel void kernel_sum_rows(
+ constant ggml_metal_kargs_sum_rows & args,
device const float * src0,
device float * dst,
- constant ggml_metal_kargs_sum_rows & args,
- uint3 tpig[[thread_position_in_grid]]) {
- int64_t i3 = tpig.z;
- int64_t i2 = tpig.y;
- int64_t i1 = tpig.x;
+ threadgroup float * shmem_f32 [[threadgroup(0)]],
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ ushort3 tpitg[[thread_position_in_threadgroup]],
+ ushort sgitg[[simdgroup_index_in_threadgroup]],
+ ushort tiisg[[thread_index_in_simdgroup]],
+ ushort3 ntg[[threads_per_threadgroup]]) {
+ int64_t i3 = tgpig.z;
+ int64_t i2 = tgpig.y;
+ int64_t i1 = tgpig.x;

if (i3 >= args.ne03 || i2 >= args.ne02 || i1 >= args.ne01) {
return;
}

+ if (sgitg == 0) {
+ shmem_f32[tiisg] = 0.0f;
+ }
+
device const float * src_row = (device const float *) ((device const char *) src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03);
device float * dst_row = (device float *) ((device char *) dst + i1*args.nb1 + i2*args.nb2 + i3*args.nb3);

- float row_sum = 0;
+ float sumf = 0;

- for (int64_t i0 = 0; i0 < args.ne00; i0++) {
- row_sum += src_row[i0];
+ for (int64_t i0 = tpitg.x; i0 < args.ne00; i0 += ntg.x) {
+ sumf += src_row[i0];
}

- dst_row[0] = row_sum;
+ sumf = simd_sum(sumf);
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ if (tiisg == 0) {
+ shmem_f32[sgitg] = sumf;
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ sumf = shmem_f32[tiisg];
+ sumf = simd_sum(sumf);
+
+ if (tpitg.x == 0) {
+ dst_row[0] = norm ? sumf / args.ne00 : sumf;
+ }
}

+typedef decltype(kernel_sum_rows<false>) kernel_sum_rows_t;
+
+template [[host_name("kernel_sum_rows")]] kernel kernel_sum_rows_t kernel_sum_rows<false>;
+template [[host_name("kernel_mean")]] kernel kernel_sum_rows_t kernel_sum_rows<true>;
+
template<typename T>
kernel void kernel_soft_max(
device const char * src0,

File diff suppressed because it is too large
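The dedup in the patch above folds the mean kernel into the row-sum kernel behind a single `norm` template flag: the mean of a row is just its sum divided by the row length. A plain Go sketch of the same reduction, assuming nothing beyond the standard library:

```go
package main

import "fmt"

// reduceRows mirrors the templated kernel_sum_rows<norm> above on the CPU:
// one reduction per row, divided by the row length when norm is true.
func reduceRows(rows [][]float32, norm bool) []float32 {
	out := make([]float32, len(rows))
	for i, row := range rows {
		var sum float32
		for _, v := range row {
			sum += v
		}
		if norm && len(row) > 0 {
			sum /= float32(len(row))
		}
		out[i] = sum
	}
	return out
}

func main() {
	x := [][]float32{{1, 2, 3, 4}, {10, 20}}
	fmt.Println(reduceRows(x, false)) // [10 30]  (sum_rows)
	fmt.Println(reduceRows(x, true))  // [2.5 15] (mean)
}
```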
@@ -1,50 +0,0 @@
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: Oliver Simons <osimons@nvidia.com>
Date: Tue, 22 Jul 2025 11:02:28 +0200
Subject: [PATCH] Enable CUDA Graphs for gemma3n.

Similar to
https://github.com/ggml-org/llama.cpp/pull/14741,
though ollama has a slightly different model graph
than llama.cpp which requires different workaround
checks.
---
ggml/src/ggml-cuda/ggml-cuda.cu | 16 ++++++++++++----
1 file changed, 12 insertions(+), 4 deletions(-)

diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index 2b9fabf4..28ccf4be 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -2474,6 +2474,9 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud
// Loop over nodes in GGML graph to obtain info needed for CUDA graph
cuda_ctx->cuda_graph->cpy_dest_ptrs.clear();

+ const std::string gemma3n_per_layer_proj_src1_name = " (reshaped)";
+ const std::string gemma3n_node_name = "node_";
+
for (int i = 0; i < cgraph->n_nodes; i++) {
ggml_tensor * node = cgraph->nodes[i];

@@ -2495,12 +2498,17 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud
#endif
}

- if (node->op == GGML_OP_ADD && node->src[1] && node->src[1]->ne[1] > 1) {
- // disable CUDA graphs for batch size > 1 for now.
- // Changes in batch size or context size can cause changes to the grid size of some kernels.
+ // workarounds to exclude Gemma3n's `project_per_layer_input` operation from the batch-size heuristic, specific to ollama's implementation of gemma3n
+ // number of layers is different for per_layer_proj between gemma3n:2b and gemma3n:4b, which is why we don't check that value here
+ if (node->op == GGML_OP_ADD && node->src[1] && node->src[1]->ne[1] > 1 && !(node->ne[0] == 256
+ && node->ne[2] == 1
+ && node->ne[3] == 1
+ && node->src[0] ? std::string(node->src[0]->name).find(gemma3n_node_name) != std::string::npos : false
+ && node->src[1] ? node->src[1]->name == gemma3n_per_layer_proj_src1_name : false)) {
+ // Generally, changes in batch size or context size can cause changes to the grid size of some kernels.
use_cuda_graph = false;
#ifndef NDEBUG
- GGML_LOG_DEBUG("%s: disabling CUDA graphs due to batch size > 1 [%s] [%ld %ld %ld %ld]\n", __func__, node->name, node->ne[0], node->ne[1], node->ne[2], node->ne[3]);
+ GGML_LOG_INFO("%s: disabling CUDA graphs due to batch size > 1 [%s] [%ld %ld %ld %ld]\n", __func__, node->name, node->ne[0], node->ne[1], node->ne[2], node->ne[3]);
#endif
}

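The workaround in this patch exempts one specific ADD node from the batch-size heuristic: a 256-wide tensor with singleton third and fourth dimensions whose first source comes from a `node_*` tensor and whose second source is the ` (reshaped)` per-layer projection. A Go approximation of that predicate (the type and function names here are hypothetical, for illustration only):

```go
package main

import (
	"fmt"
	"strings"
)

// node is a stripped-down stand-in for a ggml graph node; only the fields
// used by the check are modeled here.
type node struct {
	ne       [4]int64 // tensor dimensions
	src0Name string
	src1Name string
}

// isGemma3nPerLayerProjAdd approximates the exemption applied before
// disabling CUDA graphs for batched ADD nodes: a 256-wide tensor with
// singleton third and fourth dims, produced from a "node_*" source and a
// " (reshaped)" per-layer projection.
func isGemma3nPerLayerProjAdd(n node) bool {
	return n.ne[0] == 256 &&
		n.ne[2] == 1 &&
		n.ne[3] == 1 &&
		strings.Contains(n.src0Name, "node_") &&
		n.src1Name == " (reshaped)"
}

func main() {
	n := node{ne: [4]int64{256, 8, 1, 1}, src0Name: "node_42", src1Name: " (reshaped)"}
	fmt.Println(isGemma3nPerLayerProjAdd(n)) // true: keep CUDA graphs enabled
}
```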
@@ -151,12 +151,7 @@ func EstimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []strin
}

if graphPartialOffload == 0 {
headsKV := f.KV().HeadCountKVMin()
if headsKV == 0 {
headsKV = 1
}
gqa := f.KV().HeadCountMax() / headsKV
graphPartialOffload = gqa * kvTotal / 6
graphPartialOffload = f.KV().GQA() * kvTotal / 6
}
if graphFullOffload == 0 {
graphFullOffload = graphPartialOffload

@@ -139,13 +139,6 @@ func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, a
gpus = discover.GetCPUInfo()
}

// Verify the requested context size is <= the model training size
trainCtx := f.KV().ContextLength()
if opts.NumCtx/numParallel > int(trainCtx) && trainCtx > 0 {
slog.Warn("requested context size too large for model", "num_ctx", opts.NumCtx, "num_parallel", numParallel, "n_ctx_train", trainCtx)
opts.NumCtx = int(trainCtx) * numParallel
}

estimate := EstimateGPULayers(gpus, f, projectors, opts, numParallel)
if len(gpus) > 1 || gpus[0].Library != "cpu" {
switch {
@@ -318,7 +311,7 @@ func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, a
params = append(params, "--mmproj", projectors[0])
}

// iterate through compatible GPU libraries such as 'cuda_v12', 'rocm', etc.
// iterate through compatible GPU libraries such as 'cuda_v12', 'cuda_v11', 'rocm', etc.
// adding each library's respective path to the LD_LIBRARY_PATH, until finally running
// without any LD_LIBRARY_PATH flags
for {
@@ -707,6 +700,8 @@ const (
DoneReasonStop DoneReason = iota
// DoneReasonLength indicates the completion stopped due to length limits
DoneReasonLength
// DoneReasonContextShift indicates the completion stopped due to context shift
DoneReasonContextShift
// DoneReasonConnectionClosed indicates the completion stopped due to the connection being closed
DoneReasonConnectionClosed
)
@@ -717,6 +712,8 @@ func (d DoneReason) String() string {
return "length"
case DoneReasonStop:
return "stop"
case DoneReasonContextShift:
return "context_limit_reached"
default:
return "" // closed
}

@@ -124,9 +124,9 @@ type DeviceMemory struct {
// may not be persistent across instances of the runner.
Name string

// ID is an identifier for the device for matching with system
// management libraries.
ID string
// UUID is a unique persistent identifier for the device for matching
// with system management libraries
UUID string

// Weights is the per-layer memory needed for the model weights.
Weights []Memory
@@ -156,8 +156,8 @@ func (m DeviceMemory) LogValue() slog.Value {
attrs = append(attrs, slog.Any("Graph", m.Graph))
}

if len(attrs) > 0 && m.ID != "" {
attrs = append([]slog.Attr{slog.String("ID", m.ID)}, attrs...)
if len(attrs) > 0 && m.UUID != "" {
attrs = append([]slog.Attr{slog.String("UUID", m.UUID)}, attrs...)
}

return slog.GroupValue(attrs...)
@@ -253,7 +253,6 @@ type Tensor interface {

Neg(ctx Context) Tensor
Add(ctx Context, t2 Tensor) Tensor
Sub(ctx Context, t2 Tensor) Tensor
Mul(ctx Context, t2 Tensor) Tensor
Div(ctx Context, t2 Tensor) Tensor

@@ -277,7 +276,6 @@ type Tensor interface {
Tanh(ctx Context) Tensor
GELU(ctx Context) Tensor
SILU(ctx Context) Tensor
RELU(ctx Context) Tensor
Sigmoid(ctx Context) Tensor

Reshape(ctx Context, shape ...int) Tensor
@@ -299,12 +297,6 @@ type Tensor interface {

TopK(ctx Context, k int) Tensor
Argsort(ctx Context) Tensor
Mean(ctx Context) Tensor
Variance(ctx Context) Tensor
Stddev(ctx Context) Tensor
Sqr(ctx Context) Tensor
Sqrt(ctx Context) Tensor
Clamp(ctx Context, min, max float32) Tensor
}

// ScaledDotProductAttention implements a fused attention

@@ -138,7 +138,7 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) {
requiredMemory.CPU.Name = C.GoString(C.ggml_backend_dev_name(cpuDeviceBufferType.d))
var props C.struct_ggml_backend_dev_props
C.ggml_backend_dev_get_props(cpuDeviceBufferType.d, &props)
requiredMemory.CPU.ID = C.GoString(props.id)
requiredMemory.CPU.UUID = C.GoString(props.uuid)
requiredMemory.CPU.Weights = make([]ml.Memory, blocks+1)
requiredMemory.CPU.Cache = make([]ml.Memory, blocks+1)

@@ -155,7 +155,7 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) {
requiredMemory.GPUs[i].Name = C.GoString(C.ggml_backend_dev_name(d))
var props C.struct_ggml_backend_dev_props
C.ggml_backend_dev_get_props(d, &props)
requiredMemory.GPUs[i].ID = C.GoString(props.id)
requiredMemory.GPUs[i].UUID = C.GoString(props.uuid)
requiredMemory.GPUs[i].Weights = make([]ml.Memory, blocks+1)
requiredMemory.GPUs[i].Cache = make([]ml.Memory, blocks+1)
}
@@ -297,9 +297,7 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) {
if _, ok := meta.Tensors().GroupLayers()["output"]; !ok && t.Name == "token_embd.weight" {
createTensor(tensor{source: t, target: "output.weight"}, output.bts, blocks)
}
case contains(t.Name, "cls", "output", "output_norm",
"altup_proj", "altup_unembd_proj",
"per_layer_token_embd", "per_layer_model_proj", "per_layer_proj_norm"):
case contains(t.Name, "cls", "output", "output_norm"):
createTensor(tensor{source: t}, output.bts, blocks)
case strings.HasPrefix(t.Name, "v.") || strings.HasPrefix(t.Name, "mm."):
// TODO: assign vision tensors to the gpu if possible
@@ -355,26 +353,6 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) {
bbs[c] = b
}

// Mimic llama runner logs summarizing layers and memory
gpuLayers := 0
for _, layer := range layers {
if C.ggml_backend_dev_type(layer.d) == C.GGML_BACKEND_DEVICE_TYPE_GPU {
gpuLayers++
}
}
slog.Info(fmt.Sprintf("offloading %d repeating layers to GPU", gpuLayers))

switch C.ggml_backend_dev_type(output.d) {
case C.GGML_BACKEND_DEVICE_TYPE_CPU:
slog.Info("offloading output layer to CPU")
case C.GGML_BACKEND_DEVICE_TYPE_GPU:
slog.Info("offloading output layer to GPU")
gpuLayers++
case C.GGML_BACKEND_DEVICE_TYPE_ACCEL:
slog.Info("offloading output layer to ACCEL")
}
slog.Info(fmt.Sprintf("offloaded %d/%d layers to GPU", gpuLayers, len(layers)+1))

for bs := range maps.Values(bbs) {
slog.Info("model weights", "buffer", C.GoString(C.ggml_backend_buffer_name(bs)), "size", format.HumanBytes2(uint64(C.ggml_backend_buffer_get_size(bs))))
}
@@ -420,7 +398,7 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) {
(*C.ggml_backend_buffer_type_t)(unsafe.Pointer(&schedBufts[0])),
C.int(len(schedBackends)),
C.size_t(maxGraphNodes),
C._Bool(false),
C._Bool(len(gpus) > 1 && slices.Contains(gpus, output.d)),
C._Bool(false),
),
schedBackends: schedBackends,
@@ -624,9 +602,7 @@ func (c *Context) Forward(tensors ...ml.Tensor) ml.Context {
}

func (c *Context) Compute(tensors ...ml.Tensor) {
if status := C.ggml_backend_sched_graph_compute_async(c.b.sched, c.graph); status != C.GGML_STATUS_SUCCESS {
panic(fmt.Errorf("error computing ggml graph: %v", status))
}
C.ggml_backend_sched_graph_compute_async(c.b.sched, c.graph)
C.ggml_backend_sched_reset(c.b.sched)

needSync := true
@@ -915,13 +891,6 @@ func (t *Tensor) Add(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
}
}

func (t *Tensor) Sub(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_sub(ctx.(*Context).ctx, t.t, t2.(*Tensor).t),
}
}

func (t *Tensor) Repeat(ctx ml.Context, dim, n int) ml.Tensor {
if dim < 0 || dim >= C.GGML_MAX_DIMS {
panic("invalid dimension")
@@ -1229,13 +1198,6 @@ func (t *Tensor) SILU(ctx ml.Context) ml.Tensor {
}
}

func (t *Tensor) RELU(ctx ml.Context) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_relu_inplace(ctx.(*Context).ctx, t.t),
}
}

func (t *Tensor) Conv2D(ctx ml.Context, t2 ml.Tensor, s0, s1, p0, p1, d0, d1 int) ml.Tensor {
return &Tensor{
b: t.b,
@@ -1311,42 +1273,3 @@ func (t *Tensor) Argsort(ctx ml.Context) ml.Tensor {
t: C.ggml_argsort(ctx.(*Context).ctx, t.t, C.GGML_SORT_ORDER_ASC),
}
}

func (t *Tensor) Mean(ctx ml.Context) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_mean(ctx.(*Context).ctx, t.t),
}
}

func (t *Tensor) Variance(ctx ml.Context) ml.Tensor {
return t.Add(ctx, t.Mean(ctx).Scale(ctx, -1)).
Sqr(ctx).
SumRows(ctx).
Scale(ctx, 1/float64(t.Dim(0)))
}

func (t *Tensor) Stddev(ctx ml.Context) ml.Tensor {
return t.Variance(ctx).Sqrt(ctx)
}

func (t *Tensor) Sqr(ctx ml.Context) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_sqr(ctx.(*Context).ctx, t.t),
}
}

func (t *Tensor) Sqrt(ctx ml.Context) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_sqrt(ctx.(*Context).ctx, t.t),
}
}

func (t *Tensor) Clamp(ctx ml.Context, min, max float32) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_clamp(ctx.(*Context).ctx, t.t, C.float(min), C.float(max)),
}
}

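The `Variance` and `Stddev` helpers in the hunk above compute the population variance, the mean of squared deviations from the mean, and its square root. The same arithmetic in plain Go, independent of the tensor API:

```go
package main

import (
	"fmt"
	"math"
)

// meanOf returns the arithmetic mean of xs.
func meanOf(xs []float32) float32 {
	var sum float32
	for _, v := range xs {
		sum += v
	}
	return sum / float32(len(xs))
}

// variance follows the Tensor.Variance helper above: the mean of squared
// deviations from the mean (population variance over the first dimension).
func variance(xs []float32) float32 {
	m := meanOf(xs)
	var sum float32
	for _, v := range xs {
		d := v - m
		sum += d * d
	}
	return sum / float32(len(xs))
}

// stddev mirrors Tensor.Stddev: the square root of the variance.
func stddev(xs []float32) float32 {
	return float32(math.Sqrt(float64(variance(xs))))
}

func main() {
	xs := []float32{2, 4, 4, 4, 5, 5, 7, 9}
	fmt.Println(meanOf(xs), variance(xs), stddev(xs)) // 5 4 2
}
```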
2
ml/backend/ggml/ggml/include/ggml-backend.h
vendored
@@ -152,7 +152,7 @@ extern "C" {
struct ggml_backend_dev_props {
const char * name;
const char * description;
const char * id;
const char * uuid;
size_t memory_free;
size_t memory_total;
enum ggml_backend_dev_type type;

12
ml/backend/ggml/ggml/src/ggml-backend-reg.cpp
vendored
@@ -573,16 +573,8 @@ void ggml_backend_load_all_from_path(const char * dir_path) {

ggml_backend_load_best("blas", silent, dir_path);
ggml_backend_load_best("cann", silent, dir_path);

// Avoid mixed hip+cuda configurations
const char * hip_devices = std::getenv("HIP_VISIBLE_DEVICES");
const char * rocr_devices = std::getenv("ROCR_VISIBLE_DEVICES");
if (!hip_devices && !rocr_devices) {
ggml_backend_load_best("cuda", silent, dir_path);
} else {
ggml_backend_load_best("hip", silent, dir_path);
}

ggml_backend_load_best("cuda", silent, dir_path);
ggml_backend_load_best("hip", silent, dir_path);
ggml_backend_load_best("kompute", silent, dir_path);
ggml_backend_load_best("metal", silent, dir_path);
ggml_backend_load_best("rpc", silent, dir_path);

20
ml/backend/ggml/ggml/src/ggml-cuda/common.cuh
vendored
@@ -362,26 +362,6 @@ static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
#endif // FP16_AVAILABLE
}

// Row reduction kernel template - compute sum (norm=false) or mean (norm=true)
template<bool norm>
static __global__ void reduce_rows_f32(const float * x, float * dst, const int ncols) {
const int row = blockIdx.x;
const int col = threadIdx.x;

float sum = 0.0f;
for (int i = col; i < ncols; i += blockDim.x) {
sum += x[row * ncols + i];
}

sum = warp_reduce_sum(sum);

if (col != 0) {
return;
}

dst[row] = norm ? sum / ncols : sum;
}

template<int width = WARP_SIZE>
static __device__ __forceinline__ float warp_reduce_max(float x) {
#pragma unroll

43
ml/backend/ggml/ggml/src/ggml-cuda/ggml-cuda.cu
vendored
@@ -35,7 +35,6 @@
|
||||
#include "ggml-cuda/ssm-scan.cuh"
|
||||
#include "ggml-cuda/sum.cuh"
|
||||
#include "ggml-cuda/sumrows.cuh"
|
||||
#include "ggml-cuda/mean.cuh"
|
||||
#include "ggml-cuda/tsembd.cuh"
|
||||
#include "ggml-cuda/unary.cuh"
|
||||
#include "ggml-cuda/upscale.cuh"
|
||||
@@ -2323,9 +2322,6 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
|
||||
case GGML_OP_SUM_ROWS:
|
||||
ggml_cuda_op_sum_rows(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_MEAN:
|
||||
ggml_cuda_op_mean(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_SSM_CONV:
|
||||
ggml_cuda_op_ssm_conv(ctx, dst);
|
||||
break;
|
||||
@@ -2474,9 +2470,6 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud
|
||||
// Loop over nodes in GGML graph to obtain info needed for CUDA graph
|
||||
cuda_ctx->cuda_graph->cpy_dest_ptrs.clear();
|
||||
|
||||
const std::string gemma3n_per_layer_proj_src1_name = " (reshaped)";
|
||||
const std::string gemma3n_node_name = "node_";
|
||||
|
||||
for (int i = 0; i < cgraph->n_nodes; i++) {
|
||||
ggml_tensor * node = cgraph->nodes[i];
|
||||
|
||||
@@ -2498,17 +2491,12 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud
|
||||
#endif
|
||||
}
|
||||
|
||||
// workarounds to exclude Gemma3n's `project_per_layer_input` operation from the batch-size heuristic, specific to ollama's implementation of gemma3n
|
||||
// number of layers is different for per_layer_proj between gemma3n:2b and gemma3n:4b, which is why we don't check that value here
|
||||
if (node->op == GGML_OP_ADD && node->src[1] && node->src[1]->ne[1] > 1 && !(node->ne[0] == 256
|
||||
&& node->ne[2] == 1
|
||||
&& node->ne[3] == 1
|
||||
&& node->src[0] ? std::string(node->src[0]->name).find(gemma3n_node_name) != std::string::npos : false
|
||||
&& node->src[1] ? node->src[1]->name == gemma3n_per_layer_proj_src1_name : false)) {
|
||||
// Generally, changes in batch size or context size can cause changes to the grid size of some kernels.
|
||||
if (node->op == GGML_OP_ADD && node->src[1] && node->src[1]->ne[1] > 1) {
|
||||
// disable CUDA graphs for batch size > 1 for now.
|
||||
// Changes in batch size or context size can cause changes to the grid size of some kernels.
|
||||
use_cuda_graph = false;
|
||||
#ifndef NDEBUG
|
||||
GGML_LOG_INFO("%s: disabling CUDA graphs due to batch size > 1 [%s] [%ld %ld %ld %ld]\n", __func__, node->name, node->ne[0], node->ne[1], node->ne[2], node->ne[3]);
|
||||
GGML_LOG_DEBUG("%s: disabling CUDA graphs due to batch size > 1 [%s] [%ld %ld %ld %ld]\n", __func__, node->name, node->ne[0], node->ne[1], node->ne[2], node->ne[3]);
|
||||
#endif
|
||||
}
|
||||
|
||||
@@ -2896,7 +2884,7 @@ struct ggml_backend_cuda_device_context {
|
||||
int device;
|
||||
std::string name;
|
||||
std::string description;
|
||||
std::string id;
|
||||
std::string uuid;
|
||||
};
|
||||
|
||||
static const char * ggml_backend_cuda_device_get_name(ggml_backend_dev_t dev) {
|
||||
@@ -2909,9 +2897,9 @@ static const char * ggml_backend_cuda_device_get_description(ggml_backend_dev_t
|
||||
return ctx->description.c_str();
|
||||
}
|
||||
|
||||
static const char * ggml_backend_cuda_device_get_id(ggml_backend_dev_t dev) {
|
||||
static const char * ggml_backend_cuda_device_get_uuid(ggml_backend_dev_t dev) {
|
||||
ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context;
|
||||
return ctx->id.c_str();
|
||||
return ctx->uuid.c_str();
|
||||
}
|
||||
|
||||
static void ggml_backend_cuda_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
|
||||
@@ -2928,7 +2916,7 @@ static enum ggml_backend_dev_type ggml_backend_cuda_device_get_type(ggml_backend
|
||||
static void ggml_backend_cuda_device_get_props(ggml_backend_dev_t dev, ggml_backend_dev_props * props) {
|
||||
props->name = ggml_backend_cuda_device_get_name(dev);
|
||||
props->description = ggml_backend_cuda_device_get_description(dev);
|
||||
props->id = ggml_backend_cuda_device_get_id(dev);
|
||||
props->uuid = ggml_backend_cuda_device_get_uuid(dev);
|
||||
props->type = ggml_backend_cuda_device_get_type(dev);
|
||||
ggml_backend_cuda_device_get_memory(dev, &props->memory_free, &props->memory_total);
|
||||
|
||||
@@ -3223,7 +3211,6 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
||||
case GGML_OP_POOL_2D:
|
||||
case GGML_OP_SUM:
|
||||
case GGML_OP_SUM_ROWS:
|
||||
case GGML_OP_MEAN:
|
||||
case GGML_OP_ARGSORT:
|
||||
case GGML_OP_ACC:
|
||||
return true;
|
||||
@@ -3479,8 +3466,8 @@ ggml_backend_reg_t ggml_backend_cuda_reg() {
|
||||
dev_ctx->description = prop.name;
|
||||
|
||||
#if !defined(GGML_USE_HIP)
|
||||
char id[64];
|
||||
snprintf(id, sizeof(id),
|
||||
char uuid[64];
|
||||
snprintf(uuid, sizeof(uuid),
|
||||
"GPU-%02x%02x%02x%02x-%02x%02x-%02x%02x-%02x%02x-%02x%02x%02x%02x%02x%02x",
|
||||
(unsigned char)prop.uuid.bytes[0],
|
||||
(unsigned char)prop.uuid.bytes[1],
|
||||
@@ -3499,15 +3486,9 @@ ggml_backend_reg_t ggml_backend_cuda_reg() {
|
||||
(unsigned char)prop.uuid.bytes[14],
|
||||
(unsigned char)prop.uuid.bytes[15]
|
||||
);
|
||||
dev_ctx->id = id;
|
||||
dev_ctx->uuid = uuid;
|
||||
#else
|
||||
#ifdef _WIN32
|
||||
char id[16];
|
||||
snprintf(id, sizeof(id), "%d", i);
|
||||
dev_ctx->id = id;
|
||||
#else
|
||||
dev_ctx->id = "GPU-" + std::string(prop.uuid.bytes, 16);
|
||||
#endif
|
||||
dev_ctx->uuid = "GPU-" + std::string(prop.uuid.bytes, 16);
|
||||
#endif
|
||||
|
||||
ggml_backend_dev_t dev = new ggml_backend_device {
|
||||
|
||||
19
ml/backend/ggml/ggml/src/ggml-cuda/mean.cu
vendored
@@ -1,19 +0,0 @@
|
||||
#include "mean.cuh"
|
||||
|
||||
void ggml_cuda_op_mean(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
const float * src0_d = (const float *) src0->data;
|
||||
float * dst_d = (float *) dst->data;
|
||||
cudaStream_t stream = ctx.stream();
|
||||
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(ggml_is_contiguous(src0));
|
||||
|
||||
const int64_t ncols = src0->ne[0];
|
||||
const int64_t nrows = ggml_nrows(src0);
|
||||
|
||||
const dim3 block_dims(WARP_SIZE, 1, 1);
|
||||
const dim3 block_nums(nrows, 1, 1);
|
||||
reduce_rows_f32</*norm*/ true><<<block_nums, block_dims, 0, stream>>>(src0_d, dst_d, ncols);
|
||||
}
|
||||
3
ml/backend/ggml/ggml/src/ggml-cuda/mean.cuh
vendored
@@ -1,3 +0,0 @@
#include "common.cuh"

void ggml_cuda_op_mean(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
23
ml/backend/ggml/ggml/src/ggml-cuda/sumrows.cu
vendored
@@ -1,9 +1,25 @@
|
||||
#include "sumrows.cuh"
|
||||
|
||||
static __global__ void k_sum_rows_f32(const float * x, float * dst, const int ncols) {
|
||||
const int row = blockIdx.x;
|
||||
const int col = threadIdx.x;
|
||||
|
||||
float sum = 0.0f;
|
||||
for (int i = col; i < ncols; i += blockDim.x) {
|
||||
sum += x[row * ncols + i];
|
||||
}
|
||||
|
||||
sum = warp_reduce_sum(sum);
|
||||
|
||||
if (col == 0) {
|
||||
dst[row] = sum;
|
||||
}
|
||||
}
|
||||
|
||||
void sum_rows_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
const dim3 block_dims(WARP_SIZE, 1, 1);
|
||||
const dim3 block_nums(nrows, 1, 1);
|
||||
reduce_rows_f32</*norm*/false><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols);
|
||||
k_sum_rows_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols);
|
||||
}
|
||||
|
||||
void ggml_cuda_op_sum_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
@@ -19,8 +35,5 @@ void ggml_cuda_op_sum_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const int64_t ncols = src0->ne[0];
|
||||
const int64_t nrows = ggml_nrows(src0);
|
||||
|
||||
const dim3 block_dims(WARP_SIZE, 1, 1);
|
||||
const dim3 block_nums(nrows, 1, 1);
|
||||
|
||||
reduce_rows_f32</*norm=*/false><<<block_nums, block_dims, 0, stream>>>(src0_d, dst_d, ncols);
|
||||
sum_rows_f32_cuda(src0_d, dst_d, ncols, nrows, stream);
|
||||
}
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
#include "common.cuh"
|
||||
|
||||
void sum_rows_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream);
|
||||
|
||||
void ggml_cuda_op_sum_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
|
||||
@@ -3434,61 +3434,31 @@ kernel void kernel_neg(
|
||||
dst[tpig] = -src0[tpig];
|
||||
}
|
||||
|
||||
template <bool norm>
|
||||
kernel void kernel_sum_rows(
|
||||
constant ggml_metal_kargs_sum_rows & args,
|
||||
device const float * src0,
|
||||
device float * dst,
|
||||
threadgroup float * shmem_f32 [[threadgroup(0)]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort3 tpitg[[thread_position_in_threadgroup]],
|
||||
ushort sgitg[[simdgroup_index_in_threadgroup]],
|
||||
ushort tiisg[[thread_index_in_simdgroup]],
|
||||
ushort3 ntg[[threads_per_threadgroup]]) {
|
||||
int64_t i3 = tgpig.z;
|
||||
int64_t i2 = tgpig.y;
|
||||
int64_t i1 = tgpig.x;
|
||||
constant ggml_metal_kargs_sum_rows & args,
|
||||
uint3 tpig[[thread_position_in_grid]]) {
|
||||
int64_t i3 = tpig.z;
|
||||
int64_t i2 = tpig.y;
|
||||
int64_t i1 = tpig.x;
|
||||
|
||||
if (i3 >= args.ne03 || i2 >= args.ne02 || i1 >= args.ne01) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (sgitg == 0) {
|
||||
shmem_f32[tiisg] = 0.0f;
|
||||
}
|
||||
|
||||
device const float * src_row = (device const float *) ((device const char *) src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03);
|
||||
device float * dst_row = (device float *) ((device char *) dst + i1*args.nb1 + i2*args.nb2 + i3*args.nb3);
|
||||
|
||||
float sumf = 0;
|
||||
float row_sum = 0;
|
||||
|
||||
for (int64_t i0 = tpitg.x; i0 < args.ne00; i0 += ntg.x) {
|
||||
sumf += src_row[i0];
|
||||
for (int64_t i0 = 0; i0 < args.ne00; i0++) {
|
||||
row_sum += src_row[i0];
|
||||
}
|
||||
|
||||
sumf = simd_sum(sumf);
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
if (tiisg == 0) {
|
||||
shmem_f32[sgitg] = sumf;
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
sumf = shmem_f32[tiisg];
|
||||
sumf = simd_sum(sumf);
|
||||
|
||||
if (tpitg.x == 0) {
|
||||
dst_row[0] = norm ? sumf / args.ne00 : sumf;
|
||||
}
|
||||
dst_row[0] = row_sum;
|
||||
}
|
||||
|
||||
typedef decltype(kernel_sum_rows<false>) kernel_sum_rows_t;
|
||||
|
||||
template [[host_name("kernel_sum_rows")]] kernel kernel_sum_rows_t kernel_sum_rows<false>;
|
||||
template [[host_name("kernel_mean")]] kernel kernel_sum_rows_t kernel_sum_rows<true>;
|
||||
|
||||
template<typename T>
|
||||
kernel void kernel_soft_max(
|
||||
device const char * src0,
|
||||
|
||||
35
ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.m
vendored
@@ -489,7 +489,6 @@ enum ggml_metal_kernel_type {
|
||||
GGML_METAL_KERNEL_TYPE_COS,
|
||||
GGML_METAL_KERNEL_TYPE_NEG,
|
||||
GGML_METAL_KERNEL_TYPE_SUM_ROWS,
|
||||
GGML_METAL_KERNEL_TYPE_MEAN,
|
||||
GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,
|
||||
GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32,
|
||||
GGML_METAL_KERNEL_TYPE_ARGMAX,
|
||||
@@ -1437,7 +1436,6 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COS, cos, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NEG, neg, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MEAN, mean, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGMAX, argmax, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32, pool_2d_avg_f32, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32, pool_2d_max_f32, true);
|
||||
@@ -1636,7 +1634,6 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
|
||||
case GGML_OP_LOG:
|
||||
return false; // TODO: implement
|
||||
case GGML_OP_SUM_ROWS:
|
||||
case GGML_OP_MEAN:
|
||||
case GGML_OP_SOFT_MAX:
|
||||
case GGML_OP_GROUP_NORM:
|
||||
return has_simdgroup_reduction && ggml_is_contiguous(op->src[0]);
|
||||
@@ -2365,30 +2362,11 @@ static bool ggml_metal_encode_node(
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
||||
} break;
|
||||
case GGML_OP_SUM_ROWS:
|
||||
case GGML_OP_MEAN:
|
||||
{
|
||||
GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
|
||||
|
||||
id<MTLComputePipelineState> pipeline = nil;
|
||||
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUM_ROWS].pipeline;
|
||||
|
||||
switch (dst->op) {
|
||||
case GGML_OP_SUM_ROWS:
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUM_ROWS].pipeline;
|
||||
break;
|
||||
case GGML_OP_MEAN:
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MEAN].pipeline;
|
||||
break;
|
||||
default:
|
||||
GGML_ABORT("fatal error");
|
||||
}
|
||||
|
||||
int nth = 32; // SIMD width
|
||||
|
||||
while (nth < ne00 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) {
|
||||
nth *= 2;
|
||||
}
|
||||
|
||||
nth = MIN(nth, ne00);
|
||||
|
||||
ggml_metal_kargs_sum_rows args = {
|
||||
/*.ne00 =*/ ne00,
|
||||
@@ -2418,12 +2396,11 @@ static bool ggml_metal_encode_node(
|
||||
};
|
||||
|
||||
[encoder setComputePipelineState:pipeline];
|
||||
[encoder setBytes:&args length:sizeof(args) atIndex:0];
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
||||
[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
||||
[encoder setBytes:&args length:sizeof(args) atIndex:2];
|
||||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
||||
} break;
|
||||
case GGML_OP_SOFT_MAX:
|
||||
{
|
||||
@@ -5726,7 +5703,7 @@ static enum ggml_backend_dev_type ggml_backend_metal_device_get_type(ggml_backen
|
||||
static void ggml_backend_metal_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) {
|
||||
props->name = ggml_backend_metal_device_get_name(dev);
|
||||
props->description = ggml_backend_metal_device_get_description(dev);
|
||||
props->id = "0";
|
||||
props->uuid = "0";
|
||||
props->type = ggml_backend_metal_device_get_type(dev);
|
||||
ggml_backend_metal_device_get_memory(dev, &props->memory_free, &props->memory_total);
|
||||
props->caps = (struct ggml_backend_dev_caps) {
|
||||
|
||||
@@ -956,61 +956,31 @@ kernel void kernel_neg(
|
||||
dst[tpig] = -src0[tpig];
|
||||
}
|
||||
|
||||
template <bool norm>
|
||||
kernel void kernel_sum_rows(
|
||||
constant ggml_metal_kargs_sum_rows & args,
|
||||
device const float * src0,
|
||||
device float * dst,
|
||||
threadgroup float * shmem_f32 [[threadgroup(0)]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort3 tpitg[[thread_position_in_threadgroup]],
|
||||
ushort sgitg[[simdgroup_index_in_threadgroup]],
|
||||
ushort tiisg[[thread_index_in_simdgroup]],
|
||||
ushort3 ntg[[threads_per_threadgroup]]) {
|
||||
int64_t i3 = tgpig.z;
|
||||
int64_t i2 = tgpig.y;
|
||||
int64_t i1 = tgpig.x;
|
||||
constant ggml_metal_kargs_sum_rows & args,
|
||||
uint3 tpig[[thread_position_in_grid]]) {
|
||||
int64_t i3 = tpig.z;
|
||||
int64_t i2 = tpig.y;
|
||||
int64_t i1 = tpig.x;
|
||||
|
||||
if (i3 >= args.ne03 || i2 >= args.ne02 || i1 >= args.ne01) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (sgitg == 0) {
|
||||
shmem_f32[tiisg] = 0.0f;
|
||||
}
|
||||
|
||||
device const float * src_row = (device const float *) ((device const char *) src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03);
|
||||
device float * dst_row = (device float *) ((device char *) dst + i1*args.nb1 + i2*args.nb2 + i3*args.nb3);
|
||||
|
||||
float sumf = 0;
|
||||
float row_sum = 0;
|
||||
|
||||
for (int64_t i0 = tpitg.x; i0 < args.ne00; i0 += ntg.x) {
|
||||
sumf += src_row[i0];
|
||||
for (int64_t i0 = 0; i0 < args.ne00; i0++) {
|
||||
row_sum += src_row[i0];
|
||||
}
|
||||
|
||||
sumf = simd_sum(sumf);
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
if (tiisg == 0) {
|
||||
shmem_f32[sgitg] = sumf;
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
sumf = shmem_f32[tiisg];
|
||||
sumf = simd_sum(sumf);
|
||||
|
||||
if (tpitg.x == 0) {
|
||||
dst_row[0] = norm ? sumf / args.ne00 : sumf;
|
||||
}
|
||||
dst_row[0] = row_sum;
|
||||
}
|
||||
|
||||
typedef decltype(kernel_sum_rows<false>) kernel_sum_rows_t;
|
||||
|
||||
template [[host_name("kernel_sum_rows")]] kernel kernel_sum_rows_t kernel_sum_rows<false>;
|
||||
template [[host_name("kernel_mean")]] kernel kernel_sum_rows_t kernel_sum_rows<true>;
|
||||
|
||||
template<typename T>
|
||||
kernel void kernel_soft_max(
|
||||
device const char * src0,
|
||||
|
||||
@@ -4,6 +4,6 @@ package metal

//go:generate sh -c "{ echo // Code generated by 'go generate'. DO NOT EDIT.; sed -e '/__embed_ggml-common.h__/r ../ggml-common.h' -e '/__embed_ggml-common.h__/d' -e '/#include \"ggml-metal-impl.h\"/r ggml-metal-impl.h' -e '/#include \"ggml-metal-impl.h\"/d' ggml-metal.metal; } >ggml-metal-embed.metal"

// #cgo CPPFLAGS: -DGGML_METAL_NDEBUG -DGGML_METAL_EMBED_LIBRARY -DGGML_METAL_USE_BF16 -I.. -I../../include
// #cgo CPPFLAGS: -DGGML_METAL_NDEBUG -DGGML_METAL_EMBED_LIBRARY -I.. -I../../include
// #cgo LDFLAGS: -framework Metal -framework MetalKit
import "C"

@@ -1,51 +0,0 @@
|
||||
package gemma3n
|
||||
|
||||
import (
|
||||
"github.com/ollama/ollama/fs"
|
||||
"github.com/ollama/ollama/kvcache"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
)
|
||||
|
||||
type Model struct {
|
||||
model.Base
|
||||
model.SentencePieceModel
|
||||
|
||||
*TextModel
|
||||
}
|
||||
|
||||
// Forward implements model.Model.
|
||||
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||
return m.TextModel.Forward(ctx, batch, m.Cache)
|
||||
}
|
||||
|
||||
func New(c fs.Config) (model.Model, error) {
|
||||
m := Model{
|
||||
TextModel: newTextModel(c),
|
||||
SentencePieceModel: model.NewSentencePieceModel(
|
||||
&model.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Scores: c.Floats("tokenizer.ggml.scores"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
|
||||
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
|
||||
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
|
||||
EOS: append(
|
||||
[]int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))},
|
||||
c.Ints("tokenizer.ggml.eos_token_ids")...,
|
||||
),
|
||||
},
|
||||
),
|
||||
}
|
||||
|
||||
m.Cache = kvcache.NewWrapperCache(
|
||||
kvcache.NewCausalCache(m.Shift),
|
||||
kvcache.NewSWACache(int32(c.Uint("attention.sliding_window")), m.Shift),
|
||||
)
|
||||
return &m, nil
|
||||
}
|
||||
|
||||
func init() {
|
||||
model.Register("gemma3n", New)
|
||||
}
|
||||
@@ -1,359 +0,0 @@
|
||||
package gemma3n
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/fs"
|
||||
"github.com/ollama/ollama/kvcache"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
"github.com/ollama/ollama/ml/nn/fast"
|
||||
"github.com/ollama/ollama/ml/nn/rope"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
)
|
||||
|
||||
type TextModel struct {
|
||||
TokenEmbedding *TextScaledWordEmbedding `gguf:"token_embd"`
|
||||
|
||||
*PerLayerProjector
|
||||
|
||||
AltupEmbd *nn.Linear `gguf:"altup_proj"`
|
||||
AltupUnembd *nn.Linear `gguf:"altup_unembd_proj"`
|
||||
|
||||
TextLayers []TextLayer `gguf:"blk"`
|
||||
OutputNorm *nn.RMSNorm `gguf:"output_norm"`
|
||||
Output *nn.Linear `gguf:"output,alt:token_embd"`
|
||||
|
||||
TextOptions
|
||||
}
|
||||
|
||||
func (m *TextModel) Forward(ctx ml.Context, batch input.Batch, cache kvcache.Cache) (ml.Tensor, error) {
|
||||
positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
|
||||
// Create a tensor of a single float32 value of 1.0 to use for altup correction
|
||||
one := ctx.Input().FromFloatSlice([]float32{1.0}, 1)
|
||||
|
||||
inputs := m.TokenEmbedding.Forward(ctx, batch.Inputs, math.Sqrt(float64(m.hiddenSize)))
|
||||
inputsPerLayer := m.PerLayerProjector.Forward(ctx, batch, inputs, &m.TextOptions)
|
||||
|
||||
targetMagnitude := inputs.Sqr(ctx).Mean(ctx).Sqrt(ctx)
|
||||
targetMagnitude = targetMagnitude.Repeat(ctx, 2, m.altupInputs-1)
|
||||
|
||||
hiddenState := inputs.Repeat(ctx, 2, m.altupInputs-1)
|
||||
altupProj := m.AltupEmbd.Forward(ctx, hiddenState)
|
||||
altupProj = altupProj.Mul(ctx, targetMagnitude.Div(ctx, altupProj.Sqr(ctx).Mean(ctx).Sqrt(ctx)))
|
||||
|
||||
hiddenStates := inputs.Concat(ctx, altupProj, 2)
|
||||
|
||||
firstSharedKeyValue := m.hiddenLayers - m.sharedKeyValueLayers
|
||||
for i, layer := range m.TextLayers {
|
||||
if i < firstSharedKeyValue {
|
||||
cache.SetLayer(i)
|
||||
} else if m.isLocal(i) {
|
||||
cache.SetLayer(firstSharedKeyValue - 2)
|
||||
} else {
|
||||
cache.SetLayer(firstSharedKeyValue - 1)
|
||||
}
|
||||
|
||||
var layerType int
|
||||
ropeBase := m.ropeBase
|
||||
if m.isLocal(i) {
|
||||
layerType = 1
|
||||
ropeBase = m.ropeBaseLocal
|
||||
}
|
||||
|
||||
cache.(*kvcache.WrapperCache).SetLayerType(layerType)
|
||||
|
||||
// inputPerLayer = inputsPerLayer[:, i, :]
|
||||
inputPerLayer := inputsPerLayer.View(ctx, i*inputsPerLayer.Stride(1), inputsPerLayer.Dim(0), inputsPerLayer.Stride(2), inputsPerLayer.Dim(2))
|
||||
hiddenStates = layer.Forward(ctx, hiddenStates, inputPerLayer, positions, one, cache, i >= firstSharedKeyValue, ropeBase, float64(m.activationSparsityScale[i]), &m.TextOptions)
|
||||
}
|
||||
|
||||
// hiddenStates = hiddenStates[:, :, 0]
|
||||
hiddenStates0 := hiddenStates.View(ctx, 0, hiddenStates.Dim(0), hiddenStates.Stride(1), hiddenStates.Dim(1))
|
||||
targetMagnitude = hiddenStates0.Sqr(ctx).Mean(ctx).Sqrt(ctx)
|
||||
targetMagnitude = targetMagnitude.Repeat(ctx, 2, m.altupInputs-1)
|
||||
|
||||
// hiddenState = hiddenStates[:, :, 1:]
|
||||
hiddenState = hiddenStates.View(ctx, hiddenStates.Stride(2), hiddenStates.Dim(0), hiddenStates.Stride(1), hiddenStates.Dim(1), hiddenStates.Stride(2), m.altupInputs-1)
|
||||
altupUnembdProj := m.AltupUnembd.Forward(ctx, hiddenState)
|
||||
altupUnembdProj = altupUnembdProj.Mul(ctx, targetMagnitude.Div(ctx, altupUnembdProj.Sqr(ctx).Mean(ctx).Sqrt(ctx)))
|
||||
|
||||
hiddenStates = hiddenStates0.Concat(ctx, altupUnembdProj, 2)
|
||||
|
||||
hiddenStates = hiddenStates.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx).Mean(ctx)
|
||||
hiddenStates = hiddenStates.Permute(ctx, 2, 0, 1, 3).Contiguous(ctx)
|
||||
hiddenStates = hiddenStates.Rows(ctx, ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs)))
|
||||
|
||||
hiddenStates = m.OutputNorm.Forward(ctx, hiddenStates, m.eps)
|
||||
return m.Output.Forward(ctx, hiddenStates), nil
|
||||
}
|
||||
|
||||
func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
|
||||
ropeBase := m.ropeBase
|
||||
if m.isLocal(layer) {
|
||||
ropeBase = m.ropeBaseLocal
|
||||
}
|
||||
|
||||
return fast.RoPE(ctx, key, shift, m.headDim(), ropeBase, m.ropeScale, rope.WithTypeNeoX()), nil
|
||||
}
|
||||
|
||||
type TextScaledWordEmbedding struct {
|
||||
*nn.Embedding
|
||||
}
|
||||
|
||||
func (e TextScaledWordEmbedding) Forward(ctx ml.Context, inputIDs ml.Tensor, scale float64) ml.Tensor {
|
||||
return e.Embedding.Forward(ctx, inputIDs).Scale(ctx, scale)
|
||||
}
|
||||
|
||||
type PerLayerProjector struct {
|
||||
TokenEmbedding *TextScaledWordEmbedding `gguf:"per_layer_token_embd"`
|
||||
Projector *nn.Linear `gguf:"per_layer_model_proj"`
|
||||
Norm *nn.RMSNorm `gguf:"per_layer_proj_norm"`
|
||||
}
|
||||
|
||||
func (p PerLayerProjector) Forward(ctx ml.Context, batch input.Batch, inputs ml.Tensor, opts *TextOptions) ml.Tensor {
|
||||
inputsPerLayer := p.TokenEmbedding.Forward(ctx, batch.Inputs, math.Sqrt(float64(opts.hiddenSizePerLayerInput)))
|
||||
inputsPerLayer = inputsPerLayer.Reshape(ctx, opts.hiddenSizePerLayerInput, opts.hiddenLayers, batch.Inputs.Dim(0), batch.Inputs.Dim(1))
|
||||
|
||||
perLayerProjection := p.Projector.Forward(ctx, inputs)
|
||||
perLayerProjection = perLayerProjection.Scale(ctx, math.Sqrt(float64(opts.hiddenSize)))
|
||||
perLayerProjection = perLayerProjection.Reshape(ctx, opts.hiddenSizePerLayerInput, opts.hiddenLayers, inputs.Dim(1))
|
||||
perLayerProjection = p.Norm.Forward(ctx, perLayerProjection, opts.eps)
|
||||
|
||||
if inputsPerLayer != nil {
|
||||
perLayerProjection = perLayerProjection.Add(ctx, inputsPerLayer)
|
||||
perLayerProjection = perLayerProjection.Scale(ctx, 1/math.Sqrt(2))
|
||||
}
|
||||
|
||||
return perLayerProjection
|
||||
}
|
||||
|
||||
type TextLayer struct {
|
||||
*AltUp
|
||||
*Laurel
|
||||
|
||||
AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
|
||||
Attention *TextAttention
|
||||
PostAttentionNorm *nn.RMSNorm `gguf:"post_attention_norm"`
|
||||
|
||||
MLPNorm *nn.RMSNorm `gguf:"ffn_norm"`
|
||||
MLP *TextMLP
|
||||
PostMLPNorm *nn.RMSNorm `gguf:"post_ffw_norm"`
|
||||
|
||||
PerLayerInputGate *nn.Linear `gguf:"inp_gate"`
|
||||
PerLayerProjection *nn.Linear `gguf:"proj"`
|
||||
PostPerLayerNorm *nn.RMSNorm `gguf:"post_norm"`
|
||||
}
|
||||
|
||||
func (d TextLayer) Forward(ctx ml.Context, hiddenStates, perLayerInput, positions, one ml.Tensor, cache kvcache.Cache, sharedKV bool, ropeBase float32, activationSparsityScale float64, opts *TextOptions) ml.Tensor {
|
||||
predictions := d.Predict(ctx, hiddenStates, opts)
|
||||
active := opts.altupActive(ctx, predictions)
|
||||
|
||||
attn := d.AttentionNorm.Forward(ctx, active, opts.eps)
|
||||
laurel := d.Laurel.Forward(ctx, attn, opts)
|
||||
|
||||
attn = d.Attention.Forward(ctx, attn, positions, cache, sharedKV, ropeBase, opts)
|
||||
attn = d.PostAttentionNorm.Forward(ctx, attn, opts.eps)
|
||||
attn = active.Add(ctx, attn)
|
||||
attn = attn.Add(ctx, laurel).Scale(ctx, 1/math.Sqrt(2))
|
||||
|
||||
mlp := d.MLPNorm.Forward(ctx, attn, opts.eps)
|
||||
mlp = d.MLP.Forward(ctx, mlp, activationSparsityScale)
|
||||
mlp = d.PostMLPNorm.Forward(ctx, mlp, opts.eps)
|
||||
mlp = attn.Add(ctx, mlp)
|
||||
|
||||
predictions = d.Correct(ctx, predictions, mlp, one, opts)
|
||||
active = opts.altupActive(ctx, predictions)
|
||||
if opts.altupCorrectScale {
|
||||
active = d.ScaleCorrectedOutput(ctx, active)
|
||||
}
|
||||
|
||||
active = d.PerLayerInputGate.Forward(ctx, active)
|
||||
active = active.GELU(ctx)
|
||||
active = active.Mul(ctx, perLayerInput)
|
||||
|
||||
active = d.PerLayerProjection.Forward(ctx, active)
|
||||
active = d.PostPerLayerNorm.Forward(ctx, active, opts.eps)
|
||||
|
||||
// inactive := predictions[:, :, 1:]
|
||||
inactive := predictions.View(ctx, predictions.Stride(2), predictions.Dim(0), predictions.Stride(1), predictions.Dim(1), predictions.Stride(2), predictions.Dim(2)-1)
|
||||
active = inactive.Add(ctx, active)
|
||||
|
||||
predictions0 := predictions.View(ctx, 0, predictions.Dim(0), predictions.Stride(1), predictions.Dim(1))
|
||||
return predictions0.Concat(ctx, active, 2)
|
||||
}
|
||||
|
||||
type AltUp struct {
|
||||
CorrectionScale ml.Tensor `gguf:"altup_correct_scale.weight"`
|
||||
PredictionCoefficient *nn.Linear `gguf:"altup_predict_coef"`
|
||||
CorrectionCoefficient *nn.Linear `gguf:"altup_correct_coef"`
|
||||
Router *nn.Linear `gguf:"altup_router"`
|
||||
RouterNorm *nn.RMSNorm `gguf:"altup_router_norm"`
|
||||
}
|
||||
|
||||
func (a AltUp) computeRouterModalities(ctx ml.Context, hiddenStates ml.Tensor, opts *TextOptions) ml.Tensor {
|
||||
routerInputs := a.RouterNorm.Forward(ctx, hiddenStates, opts.eps).Scale(ctx, 1.0/float64(opts.hiddenSize))
|
||||
return a.Router.Forward(ctx, routerInputs).Tanh(ctx)
|
||||
}
|
||||
|
||||
func (a AltUp) Predict(ctx ml.Context, hiddenStates ml.Tensor, opts *TextOptions) ml.Tensor {
|
||||
modalities := a.computeRouterModalities(ctx, opts.altupActive(ctx, hiddenStates), opts)
|
||||
|
||||
coefficients := a.PredictionCoefficient.Forward(ctx, modalities)
|
||||
coefficients = coefficients.Reshape(ctx, opts.altupInputs, opts.altupInputs, coefficients.Dim(1), coefficients.Dim(2))
|
||||
|
||||
predictions := coefficients.Mulmat(ctx, hiddenStates.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx))
|
||||
predictions = predictions.Permute(ctx, 2, 0, 1, 3).Contiguous(ctx)
|
||||
return predictions.Add(ctx, hiddenStates)
|
||||
}
|
||||
|
||||
func (a AltUp) Correct(ctx ml.Context, predictions, activated, one ml.Tensor, opts *TextOptions) ml.Tensor {
|
||||
innovation := activated.Sub(ctx, opts.altupActive(ctx, predictions))
|
||||
innovation = innovation.Repeat(ctx, 2, opts.altupInputs)
|
||||
|
||||
modalities := a.computeRouterModalities(ctx, activated, opts)
|
||||
coefficients := a.CorrectionCoefficient.Forward(ctx, modalities)
|
||||
coefficients = coefficients.Add(ctx, one)
|
||||
|
||||
coefficients = coefficients.Reshape(ctx, 1, coefficients.Dim(0), coefficients.Dim(1))
|
||||
coefficients = coefficients.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
|
||||
|
||||
corrected := innovation.Mul(ctx, coefficients)
|
||||
corrected = corrected.Add(ctx, predictions)
|
||||
return corrected
|
||||
}
|
||||
|
||||
func (a AltUp) ScaleCorrectedOutput(ctx ml.Context, predictions ml.Tensor) ml.Tensor {
|
||||
return predictions.Mul(ctx, a.CorrectionScale)
|
||||
}
|
||||
|
||||
type Laurel struct {
|
||||
LinearLeft *nn.Linear `gguf:"laurel_l"`
|
||||
LinearRight *nn.Linear `gguf:"laurel_r"`
|
||||
PostLaurelNorm *nn.RMSNorm `gguf:"laurel_post_norm"`
|
||||
}
|
||||
|
||||
func (l Laurel) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *TextOptions) ml.Tensor {
|
||||
residual := hiddenStates
|
||||
hiddenStates = l.LinearLeft.Forward(ctx, hiddenStates)
|
||||
hiddenStates = l.LinearRight.Forward(ctx, hiddenStates)
|
||||
hiddenStates = l.PostLaurelNorm.Forward(ctx, hiddenStates, opts.eps)
|
||||
return hiddenStates.Add(ctx, residual)
|
||||
}
|
||||
|
||||
type TextAttention struct {
|
||||
Query *nn.Linear `gguf:"attn_q"`
|
||||
QueryNorm *nn.RMSNorm `gguf:"attn_q_norm"`
|
||||
Key *nn.Linear `gguf:"attn_k"`
|
||||
KeyNorm *nn.RMSNorm `gguf:"attn_k_norm"`
|
||||
Value *nn.Linear `gguf:"attn_v"`
|
||||
Output *nn.Linear `gguf:"attn_output"`
|
||||
}
|
||||
|
||||
func (attn TextAttention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor, cache kvcache.Cache, sharedKV bool, ropeBase float32, opts *TextOptions) ml.Tensor {
	batchSize := hiddenStates.Dim(1)

	query := attn.Query.Forward(ctx, hiddenStates)
	query = query.Reshape(ctx, opts.headDim(), opts.numHeads, batchSize)
	query = attn.QueryNorm.Forward(ctx, query, opts.eps)
	query = fast.RoPE(ctx, query, positions, opts.headDim(), ropeBase, opts.ropeScale, rope.WithTypeNeoX())

	var key, value ml.Tensor
	if !sharedKV {
		key = attn.Key.Forward(ctx, hiddenStates)
		key = key.Reshape(ctx, opts.headDim(), opts.numKVHeads, batchSize)
		key = attn.KeyNorm.Forward(ctx, key, opts.eps)
		key = fast.RoPE(ctx, key, positions, opts.headDim(), ropeBase, opts.ropeScale, rope.WithTypeNeoX())

		value = attn.Value.Forward(ctx, hiddenStates)
		value = value.Reshape(ctx, opts.headDim(), opts.numKVHeads, batchSize)
		value = value.RMSNorm(ctx, nil, opts.eps)
	}

	attention := nn.Attention(ctx, query, key, value, 1., cache)
	attention = attention.Reshape(ctx, attention.Dim(0)*attention.Dim(1), batchSize)
	return attn.Output.Forward(ctx, attention)
}

type TextMLP struct {
	Gate *nn.Linear `gguf:"ffn_gate"`
	Up   *nn.Linear `gguf:"ffn_up"`
	Down *nn.Linear `gguf:"ffn_down"`
}

func (mlp TextMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, activationSparsityScale float64) ml.Tensor {
	upStates := mlp.Up.Forward(ctx, hiddenStates)
	hiddenStates = mlp.Gate.Forward(ctx, hiddenStates)
	if activationSparsityScale > 0 {
		mean := hiddenStates.Mean(ctx)
		std := hiddenStates.Stddev(ctx).Scale(ctx, activationSparsityScale)
		cutoff := mean.Add(ctx, std)
		hiddenStates = hiddenStates.Sub(ctx, cutoff).RELU(ctx)
	}

	hiddenStates = hiddenStates.GELU(ctx).Mul(ctx, upStates)
	hiddenStates = mlp.Down.Forward(ctx, hiddenStates)
	return hiddenStates
}

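Note: the gate activations are sparsified with a statistical cutoff: values below mean + scale·stddev are driven to zero by subtracting the cutoff and applying ReLU before the GELU gating. A minimal sketch of that cutoff on a plain float64 slice follows (hypothetical helper, not the ml.Tensor API; it assumes population standard deviation).

package main

import (
	"fmt"
	"math"
)

// sparsify zeroes every activation below mean + scale*stddev by shifting the
// values down by that cutoff and clamping at zero, mirroring the
// activationSparsityScale branch in TextMLP.Forward above.
func sparsify(x []float64, scale float64) []float64 {
	var mean, variance float64
	for _, v := range x {
		mean += v
	}
	mean /= float64(len(x))
	for _, v := range x {
		variance += (v - mean) * (v - mean)
	}
	std := math.Sqrt(variance / float64(len(x)))

	cutoff := mean + scale*std
	out := make([]float64, len(x))
	for i, v := range x {
		out[i] = math.Max(0, v-cutoff)
	}
	return out
}

func main() {
	fmt.Println(sparsify([]float64{-0.5, 0.1, 0.7, 1.9, 0.3}, 1.0))
}
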
type TextOptions struct {
	hiddenLayers            int
	hiddenSize              int
	hiddenSizePerLayerInput int
	numHeads, numKVHeads    int
	keyLength, valueLength  int
	sharedKeyValueLayers    int

	altupActiveIndex  int
	altupInputs       int
	altupCorrectScale bool

	eps           float32
	ropeBase      float32
	ropeBaseLocal float32
	ropeScale     float32

	slidingWindowPattern    []bool
	activationSparsityScale []float32
}

func (o *TextOptions) altupActive(ctx ml.Context, t ml.Tensor) ml.Tensor {
	// t[:, :, o.altupActiveIndex]
	return t.View(ctx, o.altupActiveIndex*t.Stride(2), t.Dim(0), t.Stride(1), t.Dim(1))
}

func (o *TextOptions) headDim() int {
	return cmp.Or(o.keyLength, o.valueLength, o.hiddenSize/o.numHeads)
}

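Note: headDim relies on cmp.Or, which returns the first non-zero argument, so an explicit attention.key_length or attention.value_length from the model metadata wins and hiddenSize/numHeads is only the last-resort default. A tiny standalone illustration (the numbers are made up):

package main

import (
	"cmp"
	"fmt"
)

func main() {
	// An explicit key length wins over the hiddenSize/numHeads fallback.
	fmt.Println(cmp.Or(256, 0, 2048/16)) // 256
	// With keyLength and valueLength unset (zero), the fallback applies.
	fmt.Println(cmp.Or(0, 0, 2048/16)) // 128
}
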
func (o *TextOptions) isLocal(i int) bool {
	return o.slidingWindowPattern[i]
}

func newTextModel(c fs.Config) *TextModel {
	return &TextModel{
		TextLayers: make([]TextLayer, c.Uint("block_count")),
		TextOptions: TextOptions{
			hiddenLayers:            int(c.Uint("block_count")),
			hiddenSize:              int(c.Uint("embedding_length")),
			hiddenSizePerLayerInput: int(c.Uint("embedding_length_per_layer_input")),
			numHeads:                int(c.Uint("attention.head_count")),
			numKVHeads:              int(c.Uint("attention.head_count_kv")),
			keyLength:               int(c.Uint("attention.key_length")),
			valueLength:             int(c.Uint("attention.value_length")),
			sharedKeyValueLayers:    int(c.Uint("attention.shared_kv_layers")),

			altupActiveIndex: int(c.Uint("altup.active_idx")),
			altupInputs:      int(c.Uint("altup.num_inputs")),

			eps:           c.Float("attention.layer_norm_rms_epsilon", 1e-06),
			ropeBase:      c.Float("rope.freq_base", 1_000_000),
			ropeBaseLocal: c.Float("rope.freq_base_local", 10_000),
			ropeScale:     c.Float("rope.freq_scale", 1.0),

			slidingWindowPattern:    c.Bools("attention.sliding_window_pattern"),
			activationSparsityScale: c.Floats("activation_sparsity_scale"),
		},
	}
}

@@ -2,7 +2,6 @@ package llama

import (
	"cmp"
	"fmt"
	"math"

	"github.com/ollama/ollama/fs"
@@ -34,14 +33,6 @@ type Model struct {
}

func New(c fs.Config) (model.Model, error) {
	// This model currently only supports the gpt2 tokenizer
	if c.String("tokenizer.ggml.model") == "llama" {
		return nil, fmt.Errorf("unsupported tokenizer: llama")
	}
	// Best effort detection of library/deepseek-coder model(s) which are incompatible
	if c.String("general.name") == "deepseek-ai" {
		return nil, fmt.Errorf("unsupported model: %s", c.String("general.name"))
	}
	m := Model{
		BytePairEncoding: model.NewBytePairEncoding(
			c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),

@@ -3,7 +3,6 @@ package models
import (
	_ "github.com/ollama/ollama/model/models/gemma2"
	_ "github.com/ollama/ollama/model/models/gemma3"
	_ "github.com/ollama/ollama/model/models/gemma3n"
	_ "github.com/ollama/ollama/model/models/llama"
	_ "github.com/ollama/ollama/model/models/llama4"
	_ "github.com/ollama/ollama/model/models/mistral3"

@@ -2,9 +2,7 @@ package qwen2

import (
	"cmp"
	"fmt"
	"math"
	"strings"

	"github.com/ollama/ollama/fs"
	"github.com/ollama/ollama/kvcache"
@@ -128,14 +126,6 @@ func (m Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor
}

func New(c fs.Config) (model.Model, error) {
	// This model currently only supports the gpt2 tokenizer
	if c.String("tokenizer.ggml.model") == "llama" {
		return nil, fmt.Errorf("unsupported tokenizer: llama")
	}
	// detect library/qwen model(s) which are incompatible
	if strings.HasPrefix(c.String("general.name"), "Qwen2-beta") {
		return nil, fmt.Errorf("unsupported model: %s", c.String("general.name"))
	}
	m := Model{
		Layers:           make([]DecoderLayer, c.Uint("block_count")),
		BytePairEncoding: model.NewBytePairEncoding(

@@ -423,7 +423,7 @@ func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
			}
		}

			types := []string{"jpeg", "jpg", "png", "webp"}
			types := []string{"jpeg", "jpg", "png"}
			valid := false
			for _, t := range types {
				prefix := "data:image/" + t + ";base64,"

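Note: this hunk only changes the list of accepted image MIME types; the rest of the validation loop is cut off above. For orientation, a minimal sketch of how such a data-URL check typically finishes (hypothetical helper and error text, not the actual openai package code):

package main

import (
	"encoding/base64"
	"fmt"
	"strings"
)

// decodeImageDataURL matches one of the accepted MIME types, strips the
// data-URL prefix, then base64-decodes the payload.
func decodeImageDataURL(url string, types []string) ([]byte, error) {
	for _, t := range types {
		prefix := "data:image/" + t + ";base64,"
		if strings.HasPrefix(url, prefix) {
			data, err := base64.StdEncoding.DecodeString(strings.TrimPrefix(url, prefix))
			if err != nil {
				return nil, fmt.Errorf("invalid image input")
			}
			return data, nil
		}
	}
	return nil, fmt.Errorf("invalid image input")
}

func main() {
	img, err := decodeImageDataURL("data:image/png;base64,aGVsbG8=", []string{"jpeg", "jpg", "png", "webp"})
	fmt.Println(img, err)
}
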
@@ -80,6 +80,9 @@ type Sequence struct {
|
||||
// true if an embedding are to be returned instead of text generation
|
||||
embeddingOnly bool
|
||||
|
||||
// true if context shifting should be enabled
|
||||
shiftContext bool
|
||||
|
||||
doneReason llm.DoneReason
|
||||
|
||||
// Metrics
|
||||
@@ -90,11 +93,12 @@ type Sequence struct {
|
||||
}
|
||||
|
||||
type NewSequenceParams struct {
|
||||
numPredict int
|
||||
stop []string
|
||||
numKeep int
|
||||
samplingParams *llama.SamplingParams
|
||||
embedding bool
|
||||
numPredict int
|
||||
stop []string
|
||||
numKeep int
|
||||
samplingParams *llama.SamplingParams
|
||||
embedding bool
|
||||
enableContextShift bool
|
||||
}
|
||||
|
||||
func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSequenceParams) (*Sequence, error) {
|
||||
@@ -120,7 +124,7 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
|
||||
// Ensure that at least 1 input can be discarded during shift
|
||||
params.numKeep = min(params.numKeep, s.cache.numCtx-1)
|
||||
|
||||
if len(inputs) > s.cache.numCtx {
|
||||
if len(inputs) > s.cache.numCtx && params.enableContextShift {
|
||||
discard := len(inputs) - s.cache.numCtx
|
||||
newInputs := inputs[:params.numKeep]
|
||||
newInputs = append(newInputs, inputs[params.numKeep+discard:]...)
|
||||
@@ -155,6 +159,7 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
|
||||
embeddingOnly: params.embedding,
|
||||
stop: params.stop,
|
||||
numKeep: params.numKeep,
|
||||
shiftContext: params.enableContextShift,
|
||||
}, nil
|
||||
}
|
||||
|
||||
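Note: with the enableContextShift parameter above, a prompt that exceeds the context window is only truncated when shifting is enabled; otherwise the inputs are left untouched and the sequence is later ended with DoneReasonContextShift in processBatch. A standalone sketch of that truncation decision follows (hypothetical function and toy inputs, not the Server method itself):

package main

import "fmt"

// truncateForContext mirrors the gate added to NewSequence above: when the
// prompt is longer than the context window and shifting is enabled, keep the
// first numKeep inputs plus the most recent tail; otherwise leave the inputs
// alone so processBatch can stop the sequence instead.
func truncateForContext(inputs []int, numCtx, numKeep int, enableShift bool) []int {
	if len(inputs) <= numCtx || !enableShift {
		return inputs
	}
	discard := len(inputs) - numCtx
	out := append([]int{}, inputs[:numKeep]...)
	return append(out, inputs[numKeep+discard:]...)
}

func main() {
	inputs := make([]int, 12)
	for i := range inputs {
		inputs[i] = i
	}
	fmt.Println(truncateForContext(inputs, 8, 2, true))  // [0 1 6 7 8 9 10 11]
	fmt.Println(truncateForContext(inputs, 8, 2, false)) // unchanged: all 12 inputs
}
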
@@ -300,13 +305,26 @@ func flushPending(seq *Sequence) bool {
|
||||
|
||||
func (s *Server) removeSequence(seqIndex int, reason llm.DoneReason) {
|
||||
seq := s.seqs[seqIndex]
|
||||
if seq == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Mark the sequence as being removed to prevent further processing
|
||||
s.seqs[seqIndex] = nil
|
||||
|
||||
if seq.cache != nil {
|
||||
seq.cache.InUse = false
|
||||
}
|
||||
|
||||
if len(seq.pendingResponses) > 0 {
|
||||
flushPending(seq)
|
||||
}
|
||||
|
||||
flushPending(seq)
|
||||
seq.doneReason = reason
|
||||
|
||||
close(seq.responses)
|
||||
close(seq.embedding)
|
||||
seq.cache.InUse = false
|
||||
s.seqs[seqIndex] = nil
|
||||
|
||||
s.seqsSem.Release(1)
|
||||
}
|
||||
|
||||
@@ -340,7 +358,7 @@ func (s *Server) run(ctx context.Context) {
|
||||
default:
|
||||
err := s.processBatch(tokenBatch, embedBatch)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
slog.Error("error processing batch", "error", err)
|
||||
}
|
||||
|
||||
tokenBatch.Clear()
|
||||
@@ -382,6 +400,10 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
|
||||
|
||||
for i, input := range seq.inputs {
|
||||
if len(seq.cache.Inputs)+len(seq.pendingInputs)+1 > s.cache.numCtx {
|
||||
if !seq.shiftContext {
|
||||
s.removeSequence(seqIdx, llm.DoneReasonContextShift)
|
||||
continue
|
||||
}
|
||||
if len(seq.pendingInputs) == 0 {
|
||||
err := s.cache.ShiftCacheSlot(seq.cache, seq.numKeep)
|
||||
if err != nil {
|
||||
@@ -573,11 +595,12 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{
|
||||
numPredict: req.Options.NumPredict,
|
||||
stop: req.Options.Stop,
|
||||
numKeep: req.Options.NumKeep,
|
||||
samplingParams: &samplingParams,
|
||||
embedding: false,
|
||||
numPredict: req.Options.NumPredict,
|
||||
stop: req.Options.Stop,
|
||||
numKeep: req.Options.NumKeep,
|
||||
samplingParams: &samplingParams,
|
||||
embedding: false,
|
||||
enableContextShift: req.Options.ShiftContext,
|
||||
})
|
||||
if err != nil {
|
||||
http.Error(w, fmt.Sprintf("Failed to create new sequence: %v", err), http.StatusInternalServerError)
|
||||
|
||||
152 runner/llamarunner/runner_test.go (new file)
@@ -0,0 +1,152 @@
|
||||
package llamarunner
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/llm"
|
||||
)
|
||||
|
||||
func TestContextShiftLogic(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
enableContextShift bool
|
||||
contextLength int32
|
||||
cacheInputs int
|
||||
pendingInputs int
|
||||
minBatch int
|
||||
expectedDoneReason llm.DoneReason
|
||||
shouldRemove bool
|
||||
}{
|
||||
{
|
||||
name: "context shifting enabled - should shift",
|
||||
enableContextShift: true,
|
||||
contextLength: 100,
|
||||
cacheInputs: 80,
|
||||
pendingInputs: 0,
|
||||
minBatch: 30,
|
||||
expectedDoneReason: llm.DoneReasonStop,
|
||||
shouldRemove: false,
|
||||
},
|
||||
{
|
||||
name: "context shifting disabled - should remove",
|
||||
enableContextShift: false,
|
||||
contextLength: 100,
|
||||
cacheInputs: 80,
|
||||
pendingInputs: 0,
|
||||
minBatch: 30,
|
||||
expectedDoneReason: llm.DoneReasonContextShift,
|
||||
shouldRemove: true,
|
||||
},
|
||||
{
|
||||
name: "context shifting disabled - within limits",
|
||||
enableContextShift: false,
|
||||
contextLength: 100,
|
||||
cacheInputs: 50,
|
||||
pendingInputs: 0,
|
||||
minBatch: 30,
|
||||
expectedDoneReason: llm.DoneReasonStop,
|
||||
shouldRemove: false,
|
||||
},
|
||||
{
|
||||
name: "pending inputs - should break batch",
|
||||
enableContextShift: true,
|
||||
contextLength: 100,
|
||||
cacheInputs: 50,
|
||||
pendingInputs: 20,
|
||||
minBatch: 30,
|
||||
expectedDoneReason: llm.DoneReasonStop,
|
||||
shouldRemove: false,
|
||||
},
|
||||
{
|
||||
name: "no pending inputs - should shift",
|
||||
enableContextShift: true,
|
||||
contextLength: 100,
|
||||
cacheInputs: 80,
|
||||
pendingInputs: 0,
|
||||
minBatch: 30,
|
||||
expectedDoneReason: llm.DoneReasonStop,
|
||||
shouldRemove: false,
|
||||
},
|
||||
{
|
||||
name: "long prompt with context shifting disabled - will be handled at runtime",
|
||||
enableContextShift: false,
|
||||
contextLength: 100,
|
||||
cacheInputs: 0,
|
||||
pendingInputs: 0,
|
||||
minBatch: 150, // Simulates a long prompt
|
||||
expectedDoneReason: llm.DoneReasonContextShift,
|
||||
shouldRemove: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Test the core logic from processBatch
|
||||
if int32(tt.cacheInputs+tt.pendingInputs+tt.minBatch) > tt.contextLength {
|
||||
if tt.pendingInputs != 0 {
|
||||
// Should break batch
|
||||
if tt.shouldRemove {
|
||||
t.Error("should not remove sequence when pending inputs exist")
|
||||
}
|
||||
} else if !tt.enableContextShift {
|
||||
// Should remove with DoneReasonContextShift
|
||||
if !tt.shouldRemove {
|
||||
t.Error("should remove sequence when context shifting disabled")
|
||||
}
|
||||
if tt.expectedDoneReason != llm.DoneReasonContextShift {
|
||||
t.Errorf("expected DoneReason %v, got %v", llm.DoneReasonContextShift, tt.expectedDoneReason)
|
||||
}
|
||||
} else {
|
||||
// Should shift context
|
||||
if tt.shouldRemove {
|
||||
t.Error("should not remove sequence when context shifting enabled")
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPredictLimitLogic(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
numPredict int
|
||||
numPredicted int
|
||||
expectRemove bool
|
||||
}{
|
||||
{
|
||||
name: "predict limit not reached",
|
||||
numPredict: 5,
|
||||
numPredicted: 3,
|
||||
expectRemove: false,
|
||||
},
|
||||
{
|
||||
name: "predict limit reached",
|
||||
numPredict: 5,
|
||||
numPredicted: 5,
|
||||
expectRemove: true,
|
||||
},
|
||||
{
|
||||
name: "predict limit exceeded",
|
||||
numPredict: 5,
|
||||
numPredicted: 6,
|
||||
expectRemove: true,
|
||||
},
|
||||
{
|
||||
name: "no predict limit",
|
||||
numPredict: 0,
|
||||
numPredicted: 100,
|
||||
expectRemove: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Test the core logic from processBatch
|
||||
shouldRemove := tt.numPredict > 0 && tt.numPredicted >= tt.numPredict
|
||||
if shouldRemove != tt.expectRemove {
|
||||
t.Errorf("expected remove=%v, got %v", tt.expectRemove, shouldRemove)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -85,6 +85,9 @@ type Sequence struct {
|
||||
// true if an embedding are to be returned instead of text generation
|
||||
embeddingOnly bool
|
||||
|
||||
// true if context shifting should be enabled
|
||||
shiftContext bool
|
||||
|
||||
doneReason llm.DoneReason
|
||||
|
||||
// Metrics
|
||||
@@ -95,11 +98,12 @@ type Sequence struct {
|
||||
}
|
||||
|
||||
type NewSequenceParams struct {
|
||||
numPredict int
|
||||
stop []string
|
||||
numKeep int32
|
||||
sampler sample.Sampler
|
||||
embedding bool
|
||||
numPredict int
|
||||
stop []string
|
||||
numKeep int32
|
||||
sampler sample.Sampler
|
||||
embedding bool
|
||||
enableContextShift bool
|
||||
}
|
||||
|
||||
func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSequenceParams) (*Sequence, error) {
|
||||
@@ -121,7 +125,7 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
|
||||
// Ensure that at least 1 input can be discarded during shift
|
||||
params.numKeep = min(params.numKeep, s.cache.numCtx-1)
|
||||
|
||||
if int32(len(inputs)) > s.cache.numCtx {
|
||||
if int32(len(inputs)) > s.cache.numCtx && params.enableContextShift {
|
||||
discard := int32(len(inputs)) - s.cache.numCtx
|
||||
promptStart := params.numKeep + discard
|
||||
|
||||
@@ -175,6 +179,7 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
|
||||
embeddingOnly: params.embedding,
|
||||
stop: params.stop,
|
||||
numKeep: params.numKeep,
|
||||
shiftContext: params.enableContextShift,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -341,13 +346,25 @@ func flushPending(seq *Sequence) bool {
|
||||
|
||||
func (s *Server) removeSequence(seqIndex int, reason llm.DoneReason) {
|
||||
seq := s.seqs[seqIndex]
|
||||
if seq == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Mark the sequence as being removed to prevent further processing
|
||||
s.seqs[seqIndex] = nil
|
||||
|
||||
if seq.cache != nil {
|
||||
seq.cache.InUse = false
|
||||
}
|
||||
|
||||
if len(seq.pendingResponses) > 0 {
|
||||
flushPending(seq)
|
||||
}
|
||||
|
||||
flushPending(seq)
|
||||
seq.doneReason = reason
|
||||
|
||||
close(seq.responses)
|
||||
close(seq.embedding)
|
||||
seq.cache.InUse = false
|
||||
s.seqs[seqIndex] = nil
|
||||
s.seqsSem.Release(1)
|
||||
}
|
||||
|
||||
@@ -431,6 +448,11 @@ func (s *Server) processBatch() error {
|
||||
break
|
||||
}
|
||||
|
||||
if !seq.shiftContext {
|
||||
s.removeSequence(seqIdx, llm.DoneReasonContextShift)
|
||||
continue
|
||||
}
|
||||
|
||||
err := s.cache.ShiftCacheSlot(seq.cache, seq.numKeep)
|
||||
if err != nil {
|
||||
var reprocess *ErrReprocessInputs
|
||||
@@ -629,11 +651,12 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
||||
)
|
||||
|
||||
seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{
|
||||
numPredict: req.Options.NumPredict,
|
||||
stop: req.Options.Stop,
|
||||
numKeep: int32(req.Options.NumKeep),
|
||||
sampler: sampler,
|
||||
embedding: false,
|
||||
numPredict: req.Options.NumPredict,
|
||||
stop: req.Options.Stop,
|
||||
numKeep: int32(req.Options.NumKeep),
|
||||
sampler: sampler,
|
||||
embedding: false,
|
||||
enableContextShift: req.Options.ShiftContext,
|
||||
})
|
||||
if err != nil {
|
||||
http.Error(w, fmt.Sprintf("Failed to create new sequence: %v", err), http.StatusInternalServerError)
|
||||
|
||||
167 runner/ollamarunner/runner_test.go (new file)
@@ -0,0 +1,167 @@
|
||||
package ollamarunner
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/llm"
|
||||
)
|
||||
|
||||
func TestEnableContextShiftLogic(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
enableContextShift bool
|
||||
contextLength int32
|
||||
cacheInputs int
|
||||
pendingInputs int
|
||||
minBatch int
|
||||
expectedDoneReason llm.DoneReason
|
||||
shouldRemove bool
|
||||
}{
|
||||
{
|
||||
name: "context shifting enabled - should shift",
|
||||
enableContextShift: true,
|
||||
contextLength: 100,
|
||||
cacheInputs: 80,
|
||||
pendingInputs: 0,
|
||||
minBatch: 30,
|
||||
expectedDoneReason: llm.DoneReasonStop,
|
||||
shouldRemove: false,
|
||||
},
|
||||
{
|
||||
name: "context shifting disabled - should remove with DoneReasonContextShift",
|
||||
enableContextShift: false,
|
||||
contextLength: 100,
|
||||
cacheInputs: 80,
|
||||
pendingInputs: 0,
|
||||
minBatch: 30,
|
||||
expectedDoneReason: llm.DoneReasonContextShift,
|
||||
shouldRemove: true,
|
||||
},
|
||||
{
|
||||
name: "context shifting disabled - within limits",
|
||||
enableContextShift: false,
|
||||
contextLength: 100,
|
||||
cacheInputs: 50,
|
||||
pendingInputs: 0,
|
||||
minBatch: 30,
|
||||
expectedDoneReason: llm.DoneReasonStop,
|
||||
shouldRemove: false,
|
||||
},
|
||||
{
|
||||
name: "context shifting disabled - exact limit",
|
||||
enableContextShift: false,
|
||||
contextLength: 100,
|
||||
cacheInputs: 100,
|
||||
pendingInputs: 0,
|
||||
minBatch: 1,
|
||||
expectedDoneReason: llm.DoneReasonContextShift,
|
||||
shouldRemove: true,
|
||||
},
|
||||
{
|
||||
name: "pending inputs - should break batch",
|
||||
enableContextShift: true,
|
||||
contextLength: 100,
|
||||
cacheInputs: 50,
|
||||
pendingInputs: 20,
|
||||
minBatch: 30,
|
||||
expectedDoneReason: llm.DoneReasonStop,
|
||||
shouldRemove: false,
|
||||
},
|
||||
{
|
||||
name: "no pending inputs - should shift",
|
||||
enableContextShift: true,
|
||||
contextLength: 100,
|
||||
cacheInputs: 80,
|
||||
pendingInputs: 0,
|
||||
minBatch: 30,
|
||||
expectedDoneReason: llm.DoneReasonStop,
|
||||
shouldRemove: false,
|
||||
},
|
||||
{
|
||||
name: "long prompt with context shifting disabled - will be handled at runtime",
|
||||
enableContextShift: false,
|
||||
contextLength: 100,
|
||||
cacheInputs: 0,
|
||||
pendingInputs: 0,
|
||||
minBatch: 150, // Simulates a long prompt
|
||||
expectedDoneReason: llm.DoneReasonContextShift,
|
||||
shouldRemove: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Test the core logic from processBatch - matches actual implementation
|
||||
if int32(tt.cacheInputs+tt.pendingInputs+tt.minBatch) > tt.contextLength {
|
||||
if tt.pendingInputs != 0 {
|
||||
// Should break batch - don't remove sequence
|
||||
if tt.shouldRemove {
|
||||
t.Error("should not remove sequence when pending inputs exist")
|
||||
}
|
||||
} else if !tt.enableContextShift {
|
||||
// Should remove with DoneReasonContextShift
|
||||
if !tt.shouldRemove {
|
||||
t.Error("should remove sequence when context shifting disabled")
|
||||
}
|
||||
if tt.expectedDoneReason != llm.DoneReasonContextShift {
|
||||
t.Errorf("expected DoneReason %v, got %v", llm.DoneReasonContextShift, tt.expectedDoneReason)
|
||||
}
|
||||
} else {
|
||||
// Should shift context - don't remove sequence
|
||||
if tt.shouldRemove {
|
||||
t.Error("should not remove sequence when context shifting enabled")
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Within limits - should not remove
|
||||
if tt.shouldRemove {
|
||||
t.Errorf("should not remove sequence when within context limits")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPredictLimitLogic(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
numPredict int
|
||||
numPredicted int
|
||||
expectRemove bool
|
||||
}{
|
||||
{
|
||||
name: "predict limit not reached",
|
||||
numPredict: 5,
|
||||
numPredicted: 3,
|
||||
expectRemove: false,
|
||||
},
|
||||
{
|
||||
name: "predict limit reached",
|
||||
numPredict: 5,
|
||||
numPredicted: 5,
|
||||
expectRemove: true,
|
||||
},
|
||||
{
|
||||
name: "predict limit exceeded",
|
||||
numPredict: 5,
|
||||
numPredicted: 6,
|
||||
expectRemove: true,
|
||||
},
|
||||
{
|
||||
name: "no predict limit",
|
||||
numPredict: 0,
|
||||
numPredicted: 100,
|
||||
expectRemove: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Test the core logic from processBatch
|
||||
shouldRemove := tt.numPredict > 0 && tt.numPredicted >= tt.numPredict
|
||||
if shouldRemove != tt.expectRemove {
|
||||
t.Errorf("expected remove=%v, got %v", tt.expectRemove, shouldRemove)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -27,6 +27,7 @@ function checkEnv() {
|
||||
$env:VCToolsRedistDir=(get-item "${MSVC_INSTALL}\VC\Redist\MSVC\*")[0]
|
||||
}
|
||||
# Locate CUDA versions
|
||||
# Note: this assumes every version found will be built
|
||||
$cudaList=(get-item "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v*\bin\" -ea 'silentlycontinue')
|
||||
if ($cudaList.length -eq 0) {
|
||||
$d=(get-command -ea 'silentlycontinue' nvcc).path
|
||||
@@ -93,6 +94,19 @@ function buildOllama() {
|
||||
|
||||
$hashEnv = @{}
|
||||
Get-ChildItem env: | foreach { $hashEnv[$_.Name] = $_.Value }
|
||||
if ("$script:CUDA_DIRS".Contains("v11")) {
|
||||
$hashEnv.Keys | foreach { if ($_.Contains("CUDA_PATH_V11")) { $v11="$_" }}
|
||||
$env:CUDAToolkit_ROOT=$hashEnv[$v11]
|
||||
write-host "Building CUDA v11 backend libraries"
|
||||
# Note: cuda v11 requires msvc 2019 so force the older generator
|
||||
# to avoid 2022 (or newer) from being used as the default
|
||||
& cmake --fresh --preset "CUDA 11" -G "Visual Studio 16 2019" --install-prefix $script:DIST_DIR
|
||||
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
|
||||
& cmake --build --preset "CUDA 11" --config Release --parallel $script:JOBS
|
||||
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
|
||||
& cmake --install build --component "CUDA" --strip
|
||||
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
|
||||
}
|
||||
if ("$script:CUDA_DIRS".Contains("v12")) {
|
||||
$hashEnv.Keys | foreach { if ($_.Contains("CUDA_PATH_V12")) { $v12="$_" }}
|
||||
$env:CUDAToolkit_ROOT=$hashEnv[$v12]
|
||||
@@ -113,17 +127,12 @@ function buildOllama() {
|
||||
$env:HIPCXX="${env:HIP_PATH}\bin\clang++.exe"
|
||||
$env:HIP_PLATFORM="amd"
|
||||
$env:CMAKE_PREFIX_PATH="${env:HIP_PATH}"
|
||||
& cmake --fresh --preset "ROCm 6" -G Ninja `
|
||||
-DCMAKE_C_COMPILER=clang `
|
||||
-DCMAKE_CXX_COMPILER=clang++ `
|
||||
-DCMAKE_C_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma" `
|
||||
-DCMAKE_CXX_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma" `
|
||||
--install-prefix $script:DIST_DIR
|
||||
& cmake --fresh --preset "ROCm 6" -G Ninja -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ --install-prefix $script:DIST_DIR
|
||||
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
|
||||
$env:HIPCXX=""
|
||||
$env:HIP_PLATFORM=""
|
||||
$env:CMAKE_PREFIX_PATH=""
|
||||
& cmake --build --preset "ROCm 6" --config Release --parallel $script:JOBS
|
||||
& cmake --build --preset "ROCm" --config Release --parallel $script:JOBS
|
||||
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
|
||||
& cmake --install build --component "HIP" --strip
|
||||
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
|
||||
|
||||
@@ -10,7 +10,9 @@ OLLAMA_COMMON_BUILD_ARGS="--build-arg=VERSION \
    --build-arg=GOFLAGS \
    --build-arg=OLLAMA_CUSTOM_CPU_DEFS \
    --build-arg=OLLAMA_SKIP_CUDA_GENERATE \
    --build-arg=OLLAMA_SKIP_CUDA_11_GENERATE \
    --build-arg=OLLAMA_SKIP_CUDA_12_GENERATE \
    --build-arg=CUDA_V11_ARCHITECTURES \
    --build-arg=CUDA_V12_ARCHITECTURES \
    --build-arg=OLLAMA_SKIP_ROCM_GENERATE \
    --build-arg=OLLAMA_FAST_BUILD \

@@ -23,7 +23,7 @@ import (
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"github.com/ollama/ollama/fs/gguf"
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
"github.com/ollama/ollama/parser"
|
||||
"github.com/ollama/ollama/template"
|
||||
"github.com/ollama/ollama/thinking"
|
||||
@@ -73,18 +73,22 @@ func (m *Model) Capabilities() []model.Capability {
|
||||
capabilities := []model.Capability{}
|
||||
|
||||
// Check for completion capability
|
||||
f, err := gguf.Open(m.ModelPath)
|
||||
r, err := os.Open(m.ModelPath)
|
||||
if err == nil {
|
||||
defer f.Close()
|
||||
defer r.Close()
|
||||
|
||||
if f.KeyValue("pooling_type").Valid() {
|
||||
capabilities = append(capabilities, model.CapabilityEmbedding)
|
||||
f, err := ggml.Decode(r, 1024)
|
||||
if err == nil {
|
||||
if _, ok := f.KV()[fmt.Sprintf("%s.pooling_type", f.KV().Architecture())]; ok {
|
||||
capabilities = append(capabilities, model.CapabilityEmbedding)
|
||||
} else {
|
||||
capabilities = append(capabilities, model.CapabilityCompletion)
|
||||
}
|
||||
if _, ok := f.KV()[fmt.Sprintf("%s.vision.block_count", f.KV().Architecture())]; ok {
|
||||
capabilities = append(capabilities, model.CapabilityVision)
|
||||
}
|
||||
} else {
|
||||
// If no embedding is specified, we assume the model supports completion
|
||||
capabilities = append(capabilities, model.CapabilityCompletion)
|
||||
}
|
||||
if f.KeyValue("vision.block_count").Valid() {
|
||||
capabilities = append(capabilities, model.CapabilityVision)
|
||||
slog.Error("couldn't decode ggml", "error", err)
|
||||
}
|
||||
} else {
|
||||
slog.Error("couldn't open model file", "error", err)
|
||||
|
||||
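Note: the reworked Capabilities path above decodes the GGUF header and keys off "<arch>.pooling_type" (embedding) and "<arch>.vision.block_count" (vision), falling back to completion when no pooling type is present. A sketch of that decision as a pure function over an already-decoded KV map (illustrative helper only; the real code goes through ggml.Decode and logs decode errors):

package main

import "fmt"

// capabilitiesFromKV mirrors the branch above: a pooling_type key marks an
// embedding model, otherwise the model is assumed to support completion, and
// a vision.block_count key adds the vision capability.
func capabilitiesFromKV(arch string, kv map[string]any) []string {
	caps := []string{}
	if _, ok := kv[arch+".pooling_type"]; ok {
		caps = append(caps, "embedding")
	} else {
		caps = append(caps, "completion")
	}
	if _, ok := kv[arch+".vision.block_count"]; ok {
		caps = append(caps, "vision")
	}
	return caps
}

func main() {
	fmt.Println(capabilitiesFromKV("bert", map[string]any{"bert.pooling_type": uint32(1)}))
	fmt.Println(capabilitiesFromKV("llama", map[string]any{"llama.vision.block_count": uint32(1)}))
}
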
@@ -1,42 +1,123 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
"github.com/ollama/ollama/template"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
)
|
||||
|
||||
// Constants for GGUF magic bytes and version
|
||||
var (
|
||||
ggufMagic = []byte{0x47, 0x47, 0x55, 0x46} // "GGUF"
|
||||
ggufVer = uint32(3) // Version 3
|
||||
)
|
||||
|
||||
// Helper function to create mock GGUF data
|
||||
func createMockGGUFData(architecture string, vision bool) []byte {
|
||||
var buf bytes.Buffer
|
||||
|
||||
// Write GGUF header
|
||||
buf.Write(ggufMagic)
|
||||
binary.Write(&buf, binary.LittleEndian, ggufVer)
|
||||
|
||||
// Write tensor count (0 for our test)
|
||||
var numTensors uint64 = 0
|
||||
binary.Write(&buf, binary.LittleEndian, numTensors)
|
||||
|
||||
// Calculate number of metadata entries
|
||||
numMetaEntries := uint64(1) // architecture entry
|
||||
if vision {
|
||||
numMetaEntries++
|
||||
}
|
||||
// Add embedding entry if architecture is "bert"
|
||||
if architecture == "bert" {
|
||||
numMetaEntries++
|
||||
}
|
||||
binary.Write(&buf, binary.LittleEndian, numMetaEntries)
|
||||
|
||||
// Write architecture metadata
|
||||
archKey := "general.architecture"
|
||||
keyLen := uint64(len(archKey))
|
||||
binary.Write(&buf, binary.LittleEndian, keyLen)
|
||||
buf.WriteString(archKey)
|
||||
|
||||
// String type (8)
|
||||
var strType uint32 = 8
|
||||
binary.Write(&buf, binary.LittleEndian, strType)
|
||||
|
||||
// String length
|
||||
strLen := uint64(len(architecture))
|
||||
binary.Write(&buf, binary.LittleEndian, strLen)
|
||||
buf.WriteString(architecture)
|
||||
|
||||
if vision {
|
||||
visionKey := architecture + ".vision.block_count"
|
||||
keyLen = uint64(len(visionKey))
|
||||
binary.Write(&buf, binary.LittleEndian, keyLen)
|
||||
buf.WriteString(visionKey)
|
||||
|
||||
// uint32 type (4)
|
||||
var uint32Type uint32 = 4
|
||||
binary.Write(&buf, binary.LittleEndian, uint32Type)
|
||||
|
||||
// uint32 value (1)
|
||||
var countVal uint32 = 1
|
||||
binary.Write(&buf, binary.LittleEndian, countVal)
|
||||
}
|
||||
// Write embedding metadata if architecture is "bert"
|
||||
if architecture == "bert" {
|
||||
poolKey := architecture + ".pooling_type"
|
||||
keyLen = uint64(len(poolKey))
|
||||
binary.Write(&buf, binary.LittleEndian, keyLen)
|
||||
buf.WriteString(poolKey)
|
||||
|
||||
// uint32 type (4)
|
||||
var uint32Type uint32 = 4
|
||||
binary.Write(&buf, binary.LittleEndian, uint32Type)
|
||||
|
||||
// uint32 value (1)
|
||||
var poolingVal uint32 = 1
|
||||
binary.Write(&buf, binary.LittleEndian, poolingVal)
|
||||
}
|
||||
|
||||
return buf.Bytes()
|
||||
}
|
||||
|
||||
func TestModelCapabilities(t *testing.T) {
|
||||
// Create completion model (llama architecture without vision)
|
||||
completionModelPath, _ := createBinFile(t, ggml.KV{
|
||||
"general.architecture": "llama",
|
||||
}, []*ggml.Tensor{})
|
||||
// Create a temporary directory for test files
|
||||
tempDir := t.TempDir()
|
||||
|
||||
// Create vision model (llama architecture with vision block count)
|
||||
visionModelPath, _ := createBinFile(t, ggml.KV{
|
||||
"general.architecture": "llama",
|
||||
"llama.vision.block_count": uint32(1),
|
||||
}, []*ggml.Tensor{})
|
||||
// Create different types of mock model files
|
||||
completionModelPath := filepath.Join(tempDir, "model.bin")
|
||||
visionModelPath := filepath.Join(tempDir, "vision_model.bin")
|
||||
embeddingModelPath := filepath.Join(tempDir, "embedding_model.bin")
|
||||
// Create a simple model file for tests that don't depend on GGUF content
|
||||
simpleModelPath := filepath.Join(tempDir, "simple_model.bin")
|
||||
|
||||
// Create embedding model (bert architecture with pooling type)
|
||||
embeddingModelPath, _ := createBinFile(t, ggml.KV{
|
||||
"general.architecture": "bert",
|
||||
"bert.pooling_type": uint32(1),
|
||||
}, []*ggml.Tensor{})
|
||||
if err := errors.Join(
|
||||
os.WriteFile(completionModelPath, createMockGGUFData("llama", false), 0o644),
|
||||
os.WriteFile(visionModelPath, createMockGGUFData("llama", true), 0o644),
|
||||
os.WriteFile(embeddingModelPath, createMockGGUFData("bert", false), 0o644),
|
||||
os.WriteFile(simpleModelPath, []byte("dummy model data"), 0o644),
|
||||
); err != nil {
|
||||
t.Fatalf("Failed to create model files: %v", err)
|
||||
}
|
||||
|
||||
toolsInsertTemplate, err := template.Parse("{{ .prompt }}{{ if .tools }}{{ .tools }}{{ end }}{{ if .suffix }}{{ .suffix }}{{ end }}")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse template: %v", err)
|
||||
}
|
||||
|
||||
chatTemplate, err := template.Parse("{{ .prompt }}")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse template: %v", err)
|
||||
}
|
||||
|
||||
toolsTemplate, err := template.Parse("{{ .prompt }}{{ if .tools }}{{ .tools }}{{ end }}")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse template: %v", err)
|
||||
@@ -64,13 +145,21 @@ func TestModelCapabilities(t *testing.T) {
|
||||
},
|
||||
expectedCaps: []model.Capability{model.CapabilityCompletion, model.CapabilityTools, model.CapabilityInsert},
|
||||
},
|
||||
{
|
||||
name: "model with tools and insert capability",
|
||||
model: Model{
|
||||
ModelPath: simpleModelPath,
|
||||
Template: toolsInsertTemplate,
|
||||
},
|
||||
expectedCaps: []model.Capability{model.CapabilityTools, model.CapabilityInsert},
|
||||
},
|
||||
{
|
||||
name: "model with tools capability",
|
||||
model: Model{
|
||||
ModelPath: completionModelPath,
|
||||
ModelPath: simpleModelPath,
|
||||
Template: toolsTemplate,
|
||||
},
|
||||
expectedCaps: []model.Capability{model.CapabilityCompletion, model.CapabilityTools},
|
||||
expectedCaps: []model.Capability{model.CapabilityTools},
|
||||
},
|
||||
{
|
||||
name: "model with vision capability",
|
||||
@@ -135,33 +224,29 @@ func TestModelCapabilities(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestModelCheckCapabilities(t *testing.T) {
|
||||
// Create simple model file for tests that don't depend on GGUF content
|
||||
completionModelPath, _ := createBinFile(t, ggml.KV{
|
||||
"general.architecture": "llama",
|
||||
}, []*ggml.Tensor{})
|
||||
// Create a temporary directory for test files
|
||||
tempDir := t.TempDir()
|
||||
|
||||
// Create vision model (llama architecture with vision block count)
|
||||
visionModelPath, _ := createBinFile(t, ggml.KV{
|
||||
"general.architecture": "llama",
|
||||
"llama.vision.block_count": uint32(1),
|
||||
}, []*ggml.Tensor{})
|
||||
visionModelPath := filepath.Join(tempDir, "vision_model.bin")
|
||||
simpleModelPath := filepath.Join(tempDir, "model.bin")
|
||||
embeddingModelPath := filepath.Join(tempDir, "embedding_model.bin")
|
||||
|
||||
// Create embedding model (bert architecture with pooling type)
|
||||
embeddingModelPath, _ := createBinFile(t, ggml.KV{
|
||||
"general.architecture": "bert",
|
||||
"bert.pooling_type": uint32(1),
|
||||
}, []*ggml.Tensor{})
|
||||
if err := errors.Join(
|
||||
os.WriteFile(simpleModelPath, []byte("dummy model data"), 0o644),
|
||||
os.WriteFile(visionModelPath, createMockGGUFData("llama", true), 0o644),
|
||||
os.WriteFile(embeddingModelPath, createMockGGUFData("bert", false), 0o644),
|
||||
); err != nil {
|
||||
t.Fatalf("Failed to create model files: %v", err)
|
||||
}
|
||||
|
||||
toolsInsertTemplate, err := template.Parse("{{ .prompt }}{{ if .tools }}{{ .tools }}{{ end }}{{ if .suffix }}{{ .suffix }}{{ end }}")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse template: %v", err)
|
||||
}
|
||||
|
||||
chatTemplate, err := template.Parse("{{ .prompt }}")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse template: %v", err)
|
||||
}
|
||||
|
||||
toolsTemplate, err := template.Parse("{{ .prompt }}{{ if .tools }}{{ .tools }}{{ end }}")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to parse template: %v", err)
|
||||
@@ -176,7 +261,7 @@ func TestModelCheckCapabilities(t *testing.T) {
|
||||
{
|
||||
name: "completion model without tools capability",
|
||||
model: Model{
|
||||
ModelPath: completionModelPath,
|
||||
ModelPath: simpleModelPath,
|
||||
Template: chatTemplate,
|
||||
},
|
||||
checkCaps: []model.Capability{model.CapabilityTools},
|
||||
@@ -185,7 +270,7 @@ func TestModelCheckCapabilities(t *testing.T) {
|
||||
{
|
||||
name: "model with all needed capabilities",
|
||||
model: Model{
|
||||
ModelPath: completionModelPath,
|
||||
ModelPath: simpleModelPath,
|
||||
Template: toolsInsertTemplate,
|
||||
},
|
||||
checkCaps: []model.Capability{model.CapabilityTools, model.CapabilityInsert},
|
||||
@@ -193,7 +278,7 @@ func TestModelCheckCapabilities(t *testing.T) {
|
||||
{
|
||||
name: "model missing insert capability",
|
||||
model: Model{
|
||||
ModelPath: completionModelPath,
|
||||
ModelPath: simpleModelPath,
|
||||
Template: toolsTemplate,
|
||||
},
|
||||
checkCaps: []model.Capability{model.CapabilityInsert},
|
||||
@@ -202,7 +287,7 @@ func TestModelCheckCapabilities(t *testing.T) {
|
||||
{
|
||||
name: "model missing vision capability",
|
||||
model: Model{
|
||||
ModelPath: completionModelPath,
|
||||
ModelPath: simpleModelPath,
|
||||
Template: toolsTemplate,
|
||||
},
|
||||
checkCaps: []model.Capability{model.CapabilityVision},
|
||||
@@ -227,7 +312,7 @@ func TestModelCheckCapabilities(t *testing.T) {
|
||||
{
|
||||
name: "unknown capability",
|
||||
model: Model{
|
||||
ModelPath: completionModelPath,
|
||||
ModelPath: simpleModelPath,
|
||||
Template: chatTemplate,
|
||||
},
|
||||
checkCaps: []model.Capability{"unknown"},
|
||||
|
||||
@@ -63,6 +63,9 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
		}

		if ctxLen > opts.NumCtx {
			if !opts.ShiftContext {
				return "", nil, fmt.Errorf("context length of %d tokens exceeded, context shifting is disabled", opts.NumCtx)
			}
			slog.Debug("truncating input messages which exceed context length", "truncated", len(msgs[i:]))
			break
		} else {

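Note: with the check above, the server no longer silently truncates a chat that exceeds the context window when shifting is disabled; callers opt out per request through the options payload, which the tests below serialize as enable_context_shift. A sketch of building such a request body with the api package (it assumes api.DefaultOptions and the ShiftContext field introduced in this diff; the model name is a placeholder):

package main

import (
	"encoding/json"
	"fmt"

	"github.com/ollama/ollama/api"
)

func main() {
	// Disable context shifting for this request; with the change above, a
	// prompt that exceeds NumCtx then fails instead of being truncated.
	opts := api.DefaultOptions()
	opts.ShiftContext = false
	opts.NumCtx = 2048

	body, err := json.Marshal(map[string]any{
		"model":   "llama3", // placeholder model name
		"prompt":  "Summarize the design doc...",
		"options": opts,
	})
	if err != nil {
		panic(err)
	}
	fmt.Println(string(body))
}
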
@@ -2,6 +2,7 @@ package server
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
@@ -56,7 +57,7 @@ func TestChatPrompt(t *testing.T) {
|
||||
{Role: "user", Content: "A test. And a thumping good one at that, I'd wager."},
|
||||
},
|
||||
expect: expect{
|
||||
prompt: "A test. And a thumping good one at that, I'd wager. ",
|
||||
error: fmt.Errorf("context length of 1 tokens exceeded, context shifting is disabled"),
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -69,10 +70,7 @@ func TestChatPrompt(t *testing.T) {
|
||||
{Role: "user", Content: "A test. And a thumping good one at that, I'd wager.", Images: []api.ImageData{[]byte("something")}},
|
||||
},
|
||||
expect: expect{
|
||||
prompt: "[img-0]A test. And a thumping good one at that, I'd wager. ",
|
||||
images: [][]byte{
|
||||
[]byte("something"),
|
||||
},
|
||||
error: fmt.Errorf("context length of 64 tokens exceeded, context shifting is disabled"),
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -85,10 +83,7 @@ func TestChatPrompt(t *testing.T) {
|
||||
{Role: "user", Content: "A test. And a thumping good one at that, I'd wager.", Images: []api.ImageData{[]byte("somethingelse")}},
|
||||
},
|
||||
expect: expect{
|
||||
prompt: "[img-0]A test. And a thumping good one at that, I'd wager. ",
|
||||
images: [][]byte{
|
||||
[]byte("somethingelse"),
|
||||
},
|
||||
error: fmt.Errorf("context length of 64 tokens exceeded, context shifting is disabled"),
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -156,10 +151,7 @@ func TestChatPrompt(t *testing.T) {
|
||||
{Role: "user", Content: "A test. And a thumping good one at that, I'd wager."},
|
||||
},
|
||||
expect: expect{
|
||||
prompt: "[img-0] I-I'm a what? A test. And a thumping good one at that, I'd wager. ",
|
||||
images: [][]byte{
|
||||
[]byte("somethingelse"),
|
||||
},
|
||||
error: fmt.Errorf("context length of 1024 tokens exceeded, context shifting is disabled"),
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -208,12 +200,25 @@ func TestChatPrompt(t *testing.T) {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
model := tt.model
|
||||
opts := api.Options{Runner: api.Runner{NumCtx: tt.limit}}
|
||||
|
||||
// For truncation tests, disable context shifting to test the truncation behavior
|
||||
if tt.name == "truncate messages" ||
|
||||
tt.name == "truncate messages with image" ||
|
||||
tt.name == "truncate messages with images" ||
|
||||
tt.name == "truncate message with interleaved images" {
|
||||
opts.ShiftContext = false
|
||||
}
|
||||
|
||||
think := false
|
||||
prompt, images, err := chatPrompt(t.Context(), &model, mockRunner{}.Tokenize, &opts, tt.msgs, nil, &think)
|
||||
if tt.error == nil && err != nil {
|
||||
t.Fatal(err)
|
||||
} else if tt.error != nil && err != tt.error {
|
||||
t.Fatalf("expected err '%q', got '%q'", tt.error, err)
|
||||
} else if tt.error != nil && err != nil {
|
||||
if err.Error() != tt.error.Error() {
|
||||
t.Fatalf("expected err '%q', got '%q'", tt.error, err)
|
||||
}
|
||||
} else if tt.error != nil && err == nil {
|
||||
t.Fatalf("expected err '%q', got nil", tt.error)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(prompt, tt.prompt); diff != "" {
|
||||
|
||||
@@ -231,8 +231,6 @@ func newType(t *fsggml.Tensor, kv fsggml.KV, qs *quantizeState, ftype fsggml.Fil
	// do not quantize relative position bias (T5)
	quantize = quantize && !strings.Contains(name, "attn_rel_b.weight")

	quantize = quantize && !strings.Contains(name, "per_layer_token_embd.weight")

	newType := fsggml.TensorType(t.Kind)
	if quantize {
		// get more optimal quantization type based on the tensor shape, layer, etc.

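Note: the hunk above adds per_layer_token_embd.weight to the tensors that are excluded from quantization by name. A small illustrative predicate collecting those name-based exclusions (sketch only; newType chains them inline rather than using a helper):

package main

import (
	"fmt"
	"strings"
)

// skipQuantize reports whether a tensor should keep its original type
// instead of being quantized, based on substring matches of its name.
func skipQuantize(name string) bool {
	for _, s := range []string{"attn_rel_b.weight", "per_layer_token_embd.weight"} {
		if strings.Contains(name, s) {
			return true
		}
	}
	return false
}

func main() {
	fmt.Println(skipQuantize("per_layer_token_embd.weight")) // true
	fmt.Println(skipQuantize("blk.0.ffn_down.weight"))       // false
}
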
@@ -257,8 +257,16 @@ func TestQuantizeModel(t *testing.T) {
|
||||
|
||||
for _, tt := range cases {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
p, _ := createBinFile(t, tt.kv, tt.tensors)
|
||||
fp, err := os.Open(p)
|
||||
f, err := os.CreateTemp(t.TempDir(), tt.name)
|
||||
if err != nil {
|
||||
t.Fatal(err.Error())
|
||||
}
|
||||
defer f.Close()
|
||||
err = fsggml.WriteGGUF(f, tt.kv, tt.tensors)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create initial model: %s", err)
|
||||
}
|
||||
fp, err := os.Open(f.Name())
|
||||
if err != nil {
|
||||
t.Fatal(err.Error())
|
||||
}
|
||||
|
||||
@@ -28,7 +28,6 @@ import (
|
||||
"golang.org/x/sync/errgroup"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/auth"
|
||||
"github.com/ollama/ollama/discover"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
@@ -56,8 +55,6 @@ var mode string = gin.DebugMode
|
||||
type Server struct {
|
||||
addr net.Addr
|
||||
sched *Scheduler
|
||||
|
||||
perms *auth.APIPermissions
|
||||
}
|
||||
|
||||
func init() {
|
||||
@@ -72,38 +69,6 @@ func init() {
|
||||
gin.SetMode(mode)
|
||||
}
|
||||
|
||||
func loggedFormatter(param gin.LogFormatterParams) string {
|
||||
var statusColor, methodColor, resetColor string
|
||||
if param.IsOutputColor() {
|
||||
statusColor = param.StatusCodeColor()
|
||||
methodColor = param.MethodColor()
|
||||
resetColor = param.ResetColor()
|
||||
}
|
||||
|
||||
if param.Latency > time.Minute {
|
||||
param.Latency = param.Latency.Truncate(time.Second)
|
||||
}
|
||||
|
||||
username := "default"
|
||||
if userVal, exists := param.Keys["username"]; exists {
|
||||
if name, ok := userVal.(string); ok {
|
||||
username = name
|
||||
}
|
||||
}
|
||||
|
||||
return fmt.Sprintf(
|
||||
"[Ollama] %s |%s %3d %s| %13v | %15s | %-20s |%s %-7s %s %#v\n%s",
|
||||
param.TimeStamp.Format("2006/01/02 - 15:04:05"),
|
||||
statusColor, param.StatusCode, resetColor,
|
||||
param.Latency,
|
||||
param.ClientIP,
|
||||
username,
|
||||
methodColor, param.Method, resetColor,
|
||||
param.Path,
|
||||
param.ErrorMessage,
|
||||
)
|
||||
}
|
||||
|
||||
var (
|
||||
errRequired = errors.New("is required")
|
||||
errBadTemplate = errors.New("template error")
|
||||
@@ -877,11 +842,8 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
|
||||
}
|
||||
resp.Parameters = strings.Join(params, "\n")
|
||||
|
||||
if len(req.Options) > 0 {
|
||||
if m.Options == nil {
|
||||
m.Options = make(map[string]any)
|
||||
}
|
||||
for k, v := range req.Options {
|
||||
for k, v := range req.Options {
|
||||
if _, ok := req.Options[k]; ok {
|
||||
m.Options[k] = v
|
||||
}
|
||||
}
|
||||
@@ -1146,43 +1108,6 @@ func allowedHost(host string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func allowedEndpointsMiddleware(perms *auth.APIPermissions) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
if !envconfig.UseAuth() || (c.Request.Method == "HEAD" && c.Request.URL.Path == "/") {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
token := strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer ")
|
||||
if token == "" {
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "unauthorized"})
|
||||
return
|
||||
}
|
||||
|
||||
pubKey, err := auth.Authenticate(token, fmt.Sprintf("%s,%s", c.Request.Method, c.Request.RequestURI))
|
||||
if err != nil {
|
||||
slog.Error("authentication error", "error", err)
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "unauthorized"})
|
||||
return
|
||||
}
|
||||
|
||||
authorized, name, err := perms.Authorize(pubKey, c.Request.URL.Path)
|
||||
c.Set("username", name)
|
||||
if err != nil {
|
||||
slog.Error("authorization error", "error", err)
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "unauthorized"})
|
||||
return
|
||||
}
|
||||
|
||||
if !authorized {
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "unauthorized"})
|
||||
return
|
||||
}
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func allowedHostsMiddleware(addr net.Addr) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
if addr == nil {
|
||||
@@ -1249,13 +1174,10 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
|
||||
}
|
||||
corsConfig.AllowOrigins = envconfig.AllowedOrigins()
|
||||
|
||||
r := gin.New()
|
||||
r := gin.Default()
|
||||
r.HandleMethodNotAllowed = true
|
||||
r.Use(
|
||||
gin.LoggerWithFormatter(loggedFormatter),
|
||||
gin.Recovery(),
|
||||
cors.New(corsConfig),
|
||||
allowedEndpointsMiddleware(s.perms),
|
||||
allowedHostsMiddleware(s.addr),
|
||||
)
|
||||
|
||||
@@ -1265,7 +1187,7 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
|
||||
r.HEAD("/api/version", func(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"version": version.Version}) })
|
||||
r.GET("/api/version", func(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"version": version.Version}) })
|
||||
|
||||
// Local model cache management
|
||||
// Local model cache management (new implementation is at end of function)
|
||||
r.POST("/api/pull", s.PullHandler)
|
||||
r.POST("/api/push", s.PushHandler)
|
||||
r.HEAD("/api/tags", s.ListHandler)
|
||||
@@ -1297,7 +1219,7 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
|
||||
// wrap old with new
|
||||
rs := ®istry.Local{
|
||||
Client: rc,
|
||||
Logger: slog.Default(),
|
||||
Logger: slog.Default(), // TODO(bmizerany): Take a logger, do not use slog.Default()
|
||||
Fallback: r,
|
||||
|
||||
Prune: PruneLayers,
|
||||
@@ -1342,12 +1264,6 @@ func Serve(ln net.Listener) error {
|
||||
|
||||
s := &Server{addr: ln.Addr()}
|
||||
|
||||
if envconfig.UseAuth() {
|
||||
perms := auth.NewAPIPermissions()
|
||||
perms.ReloadIfNeeded()
|
||||
s.perms = perms
|
||||
}
|
||||
|
||||
var rc *ollama.Registry
|
||||
if useClient2 {
|
||||
var err error
|
||||
@@ -1488,9 +1404,6 @@ func (s *Server) PsHandler(c *gin.Context) {
|
||||
Details: modelDetails,
|
||||
ExpiresAt: v.expiresAt,
|
||||
}
|
||||
if v.Options != nil {
|
||||
mr.ContextLength = v.Options.NumCtx / v.numParallel
|
||||
}
|
||||
// The scheduler waits to set expiresAt, so if a model is loading it's
|
||||
// possible that it will be set to the unix epoch. For those cases, just
|
||||
// calculate the time w/ the sessionDuration instead.
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/fs"
|
||||
@@ -16,7 +17,6 @@ import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
"slices"
|
||||
"sort"
|
||||
"strings"
|
||||
"testing"
|
||||
@@ -26,6 +26,7 @@ import (
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
"github.com/ollama/ollama/llm"
|
||||
"github.com/ollama/ollama/openai"
|
||||
"github.com/ollama/ollama/server/internal/client/ollama"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
@@ -83,6 +84,19 @@ func createTestFile(t *testing.T, name string) (string, string) {
|
||||
return f.Name(), digest
|
||||
}
|
||||
|
||||
// equalStringSlices checks if two slices of strings are equal.
|
||||
func equalStringSlices(a, b []string) bool {
|
||||
if len(a) != len(b) {
|
||||
return false
|
||||
}
|
||||
for i := range a {
|
||||
if a[i] != b[i] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
type panicTransport struct{}
|
||||
|
||||
func (t *panicTransport) RoundTrip(r *http.Request) (*http.Response, error) {
|
||||
@@ -435,7 +449,7 @@ func TestRoutes(t *testing.T) {
|
||||
"stop \"foo\"",
|
||||
"top_p 0.9",
|
||||
}
|
||||
if !slices.Equal(params, expectedParams) {
|
||||
if !equalStringSlices(params, expectedParams) {
|
||||
t.Errorf("expected parameters %v, got %v", expectedParams, params)
|
||||
}
|
||||
paramCount, ok := showResp.ModelInfo["general.parameter_count"].(float64)
|
||||
@@ -956,3 +970,154 @@ func TestWaitForStream(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnableContextShiftNonStreamingResponse(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
enableContextShift bool
|
||||
responses []llm.CompletionResponse
|
||||
expectedDone bool
|
||||
expectedDoneReason string
|
||||
}{
|
||||
{
|
||||
name: "context shifting disabled - should have DoneReasonLength",
|
||||
enableContextShift: false,
|
||||
responses: []llm.CompletionResponse{
|
||||
{Content: "Hello", Done: false},
|
||||
{Content: " world", Done: false},
|
||||
{Content: "", Done: true, DoneReason: llm.DoneReasonLength},
|
||||
},
|
||||
expectedDone: true,
|
||||
expectedDoneReason: "length",
|
||||
},
|
||||
{
|
||||
name: "context shifting enabled - should have DoneReasonStop",
|
||||
enableContextShift: true,
|
||||
responses: []llm.CompletionResponse{
|
||||
{Content: "Hello", Done: false},
|
||||
{Content: " world", Done: false},
|
||||
{Content: "", Done: true, DoneReason: llm.DoneReasonStop},
|
||||
},
|
||||
expectedDone: true,
|
||||
expectedDoneReason: "stop",
|
||||
},
|
||||
{
|
||||
name: "no final response with Done=true",
|
||||
enableContextShift: false,
|
||||
responses: []llm.CompletionResponse{
|
||||
{Content: "Hello", Done: false},
|
||||
{Content: " world", Done: false},
|
||||
},
|
||||
expectedDone: false,
|
||||
expectedDoneReason: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// The last response in the channel will naturally be the final state
|
||||
lastResponse := tt.responses[len(tt.responses)-1]
|
||||
|
||||
if lastResponse.Done != tt.expectedDone {
|
||||
t.Errorf("expected Done=%v, got %v", tt.expectedDone, lastResponse.Done)
|
||||
}
|
||||
|
||||
if tt.expectedDoneReason != "" {
|
||||
if lastResponse.DoneReason.String() != tt.expectedDoneReason {
|
||||
t.Errorf("expected DoneReason=%s, got %s", tt.expectedDoneReason, lastResponse.DoneReason.String())
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleScheduleError(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
errorMessage string
|
||||
expectedStatus int
|
||||
}{
|
||||
{
|
||||
name: "context length exceeded error",
|
||||
errorMessage: "context length of 100 tokens exceeded, context shifting is disabled",
|
||||
expectedStatus: http.StatusInternalServerError,
|
||||
},
|
||||
{
|
||||
name: "other error",
|
||||
errorMessage: "some other error",
|
||||
expectedStatus: http.StatusInternalServerError,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
|
||||
err := errors.New(tt.errorMessage)
|
||||
|
||||
handleScheduleError(c, "test-model", err)
|
||||
|
||||
if w.Code != tt.expectedStatus {
|
||||
t.Errorf("expected status %d, got %d", tt.expectedStatus, w.Code)
|
||||
}
|
||||
|
||||
var response map[string]any
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil {
|
||||
t.Fatalf("failed to unmarshal response: %v", err)
|
||||
}
|
||||
|
||||
if errorMsg, ok := response["error"].(string); !ok || errorMsg != tt.errorMessage {
|
||||
t.Errorf("expected error message '%s', got '%s'", tt.errorMessage, errorMsg)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnableContextShiftOptions(t *testing.T) {
|
||||
t.Run("default options have enableContextShift=true", func(t *testing.T) {
|
||||
opts := api.DefaultOptions()
|
||||
if !opts.ShiftContext {
|
||||
t.Errorf("expected EnableContextShift=true by default, got %v", opts.ShiftContext)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("can set enableContextShift to false", func(t *testing.T) {
|
||||
opts := api.DefaultOptions()
|
||||
opts.ShiftContext = false
|
||||
if opts.ShiftContext {
|
||||
t.Errorf("expected EnableContextShift=false after setting, got %v", opts.ShiftContext)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("JSON serialization omits false values", func(t *testing.T) {
|
||||
opts := api.DefaultOptions()
|
||||
opts.ShiftContext = false
|
||||
|
||||
data, err := json.Marshal(opts)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to marshal options: %v", err)
|
||||
}
|
||||
|
||||
// Check that enable_context_shift is not in the JSON when false
|
||||
if bytes.Contains(data, []byte("enable_context_shift")) {
|
||||
t.Errorf("expected enable_context_shift to be omitted from JSON when false, but found it in: %s", string(data))
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("JSON serialization includes true values", func(t *testing.T) {
|
||||
opts := api.DefaultOptions()
|
||||
opts.ShiftContext = true
|
||||
|
||||
data, err := json.Marshal(opts)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to marshal options: %v", err)
|
||||
}
|
||||
|
||||
// Check that enable_context_shift is in the JSON when true
|
||||
if !bytes.Contains(data, []byte("enable_context_shift")) {
|
||||
t.Errorf("expected enable_context_shift to be in JSON when true, but not found in: %s", string(data))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||