Compare commits
5 Commits
v0.9.3-rc0
...
brucemacd/
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f2a4d058f9 | ||
|
|
63e7634014 | ||
|
|
8d51d92f3b | ||
|
|
2348fef568 | ||
|
|
883f655dd6 |
7
.github/workflows/release.yaml
vendored
7
.github/workflows/release.yaml
vendored
@@ -103,6 +103,11 @@ jobs:
|
|||||||
arch: [amd64]
|
arch: [amd64]
|
||||||
preset: ['CPU']
|
preset: ['CPU']
|
||||||
include:
|
include:
|
||||||
|
- os: windows
|
||||||
|
arch: amd64
|
||||||
|
preset: 'CUDA 11'
|
||||||
|
install: https://developer.download.nvidia.com/compute/cuda/11.3.1/local_installers/cuda_11.3.1_465.89_win10.exe
|
||||||
|
cuda-version: '11.3'
|
||||||
- os: windows
|
- os: windows
|
||||||
arch: amd64
|
arch: amd64
|
||||||
preset: 'CUDA 12'
|
preset: 'CUDA 12'
|
||||||
@@ -319,6 +324,8 @@ jobs:
|
|||||||
case "$COMPONENT" in
|
case "$COMPONENT" in
|
||||||
bin/ollama) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
|
bin/ollama) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
|
||||||
lib/ollama/*.so) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
|
lib/ollama/*.so) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
|
||||||
|
lib/ollama/cuda_v11) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
|
||||||
|
lib/ollama/cuda_v12) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
|
||||||
lib/ollama/cuda_jetpack5) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-jetpack5.tar.in ;;
|
lib/ollama/cuda_jetpack5) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-jetpack5.tar.in ;;
|
||||||
lib/ollama/cuda_jetpack6) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-jetpack6.tar.in ;;
|
lib/ollama/cuda_jetpack6) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-jetpack6.tar.in ;;
|
||||||
lib/ollama/rocm) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-rocm.tar.in ;;
|
lib/ollama/rocm) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-rocm.tar.in ;;
|
||||||
|
|||||||
6
.github/workflows/test.yaml
vendored
6
.github/workflows/test.yaml
vendored
@@ -46,7 +46,7 @@ jobs:
|
|||||||
include:
|
include:
|
||||||
- preset: CPU
|
- preset: CPU
|
||||||
- preset: CUDA
|
- preset: CUDA
|
||||||
container: nvidia/cuda:12.8.1-devel-ubuntu22.04
|
container: nvidia/cuda:11.8.0-devel-ubuntu22.04
|
||||||
flags: '-DCMAKE_CUDA_ARCHITECTURES=87'
|
flags: '-DCMAKE_CUDA_ARCHITECTURES=87'
|
||||||
- preset: ROCm
|
- preset: ROCm
|
||||||
container: rocm/dev-ubuntu-22.04:6.1.2
|
container: rocm/dev-ubuntu-22.04:6.1.2
|
||||||
@@ -78,7 +78,7 @@ jobs:
|
|||||||
include:
|
include:
|
||||||
- preset: CPU
|
- preset: CPU
|
||||||
- preset: CUDA
|
- preset: CUDA
|
||||||
install: https://developer.download.nvidia.com/compute/cuda/12.8.0/local_installers/cuda_12.8.0_571.96_windows.exe
|
install: https://developer.download.nvidia.com/compute/cuda/11.3.1/local_installers/cuda_11.3.1_465.89_win10.exe
|
||||||
flags: '-DCMAKE_CUDA_ARCHITECTURES=80'
|
flags: '-DCMAKE_CUDA_ARCHITECTURES=80'
|
||||||
- preset: ROCm
|
- preset: ROCm
|
||||||
install: https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-24.Q4-WinSvr2022-For-HIP.exe
|
install: https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-24.Q4-WinSvr2022-For-HIP.exe
|
||||||
@@ -102,7 +102,7 @@ jobs:
|
|||||||
$ErrorActionPreference = "Stop"
|
$ErrorActionPreference = "Stop"
|
||||||
if ("${{ steps.cache-install.outputs.cache-hit }}" -ne 'true') {
|
if ("${{ steps.cache-install.outputs.cache-hit }}" -ne 'true') {
|
||||||
Invoke-WebRequest -Uri "${{ matrix.install }}" -OutFile "install.exe"
|
Invoke-WebRequest -Uri "${{ matrix.install }}" -OutFile "install.exe"
|
||||||
Start-Process -FilePath .\install.exe -ArgumentList (@("-s", "cudart_12.8", "nvcc_12.8", "cublas_12.8", "cublas_dev_12.8")) -NoNewWindow -Wait
|
Start-Process -FilePath .\install.exe -ArgumentList (@("-s", "cudart_11.3", "nvcc_11.3", "cublas_11.3", "cublas_dev_11.3")) -NoNewWindow -Wait
|
||||||
}
|
}
|
||||||
|
|
||||||
$cudaPath = (Resolve-Path "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\*").path
|
$cudaPath = (Resolve-Path "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\*").path
|
||||||
|
|||||||
@@ -78,13 +78,14 @@ if(CMAKE_CUDA_COMPILER)
|
|||||||
|
|
||||||
find_package(CUDAToolkit)
|
find_package(CUDAToolkit)
|
||||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-cuda)
|
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-cuda)
|
||||||
|
set(OLLAMA_CUDA_INSTALL_DIR ${OLLAMA_INSTALL_DIR}/cuda_v${CUDAToolkit_VERSION_MAJOR})
|
||||||
install(TARGETS ggml-cuda
|
install(TARGETS ggml-cuda
|
||||||
RUNTIME_DEPENDENCIES
|
RUNTIME_DEPENDENCIES
|
||||||
DIRECTORIES ${CUDAToolkit_BIN_DIR} ${CUDAToolkit_LIBRARY_DIR}
|
DIRECTORIES ${CUDAToolkit_BIN_DIR} ${CUDAToolkit_LIBRARY_DIR}
|
||||||
PRE_INCLUDE_REGEXES cublas cublasLt cudart
|
PRE_INCLUDE_REGEXES cublas cublasLt cudart
|
||||||
PRE_EXCLUDE_REGEXES ".*"
|
PRE_EXCLUDE_REGEXES ".*"
|
||||||
RUNTIME DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT CUDA
|
RUNTIME DESTINATION ${OLLAMA_CUDA_INSTALL_DIR} COMPONENT CUDA
|
||||||
LIBRARY DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT CUDA
|
LIBRARY DESTINATION ${OLLAMA_CUDA_INSTALL_DIR} COMPONENT CUDA
|
||||||
)
|
)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
@@ -115,11 +116,7 @@ if(CMAKE_HIP_COMPILER)
|
|||||||
|
|
||||||
set(OLLAMA_HIP_INSTALL_DIR ${OLLAMA_INSTALL_DIR}/rocm)
|
set(OLLAMA_HIP_INSTALL_DIR ${OLLAMA_INSTALL_DIR}/rocm)
|
||||||
install(TARGETS ggml-hip
|
install(TARGETS ggml-hip
|
||||||
RUNTIME_DEPENDENCY_SET rocm
|
RUNTIME_DEPENDENCIES
|
||||||
RUNTIME DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT HIP
|
|
||||||
LIBRARY DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT HIP
|
|
||||||
)
|
|
||||||
install(RUNTIME_DEPENDENCY_SET rocm
|
|
||||||
DIRECTORIES ${HIP_BIN_INSTALL_DIR} ${HIP_LIB_INSTALL_DIR}
|
DIRECTORIES ${HIP_BIN_INSTALL_DIR} ${HIP_LIB_INSTALL_DIR}
|
||||||
PRE_INCLUDE_REGEXES hipblas rocblas amdhip64 rocsolver amd_comgr hsa-runtime64 rocsparse tinfo rocprofiler-register drm drm_amdgpu numa elf
|
PRE_INCLUDE_REGEXES hipblas rocblas amdhip64 rocsolver amd_comgr hsa-runtime64 rocsparse tinfo rocprofiler-register drm drm_amdgpu numa elf
|
||||||
PRE_EXCLUDE_REGEXES ".*"
|
PRE_EXCLUDE_REGEXES ".*"
|
||||||
|
|||||||
@@ -17,12 +17,20 @@
|
|||||||
"name": "CUDA",
|
"name": "CUDA",
|
||||||
"inherits": [ "Default" ]
|
"inherits": [ "Default" ]
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"name": "CUDA 11",
|
||||||
|
"inherits": [ "CUDA" ],
|
||||||
|
"cacheVariables": {
|
||||||
|
"CMAKE_CUDA_ARCHITECTURES": "50;52;53;60;61;70;75;80;86",
|
||||||
|
"CMAKE_CUDA_FLAGS": "-Wno-deprecated-gpu-targets"
|
||||||
|
}
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"name": "CUDA 12",
|
"name": "CUDA 12",
|
||||||
"inherits": [ "CUDA" ],
|
"inherits": [ "CUDA" ],
|
||||||
"cacheVariables": {
|
"cacheVariables": {
|
||||||
"CMAKE_CUDA_ARCHITECTURES": "50;60;61;70;75;80;86;87;89;90;90a;120",
|
"CMAKE_CUDA_ARCHITECTURES": "50;60;61;70;75;80;86;87;89;90;90a;120",
|
||||||
"CMAKE_CUDA_FLAGS": "-Wno-deprecated-gpu-targets -t 2"
|
"CMAKE_CUDA_FLAGS": "-Wno-deprecated-gpu-targets"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -50,7 +58,6 @@
|
|||||||
"name": "ROCm 6",
|
"name": "ROCm 6",
|
||||||
"inherits": [ "ROCm" ],
|
"inherits": [ "ROCm" ],
|
||||||
"cacheVariables": {
|
"cacheVariables": {
|
||||||
"CMAKE_HIP_FLAGS": "-parallel-jobs=4",
|
|
||||||
"AMDGPU_TARGETS": "gfx900;gfx940;gfx941;gfx942;gfx1010;gfx1012;gfx1030;gfx1100;gfx1101;gfx1102;gfx1151;gfx1200;gfx1201;gfx906:xnack-;gfx908:xnack-;gfx90a:xnack+;gfx90a:xnack-"
|
"AMDGPU_TARGETS": "gfx900;gfx940;gfx941;gfx942;gfx1010;gfx1012;gfx1030;gfx1100;gfx1101;gfx1102;gfx1151;gfx1200;gfx1201;gfx906:xnack-;gfx908:xnack-;gfx90a:xnack+;gfx90a:xnack-"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -71,6 +78,11 @@
|
|||||||
"configurePreset": "CUDA",
|
"configurePreset": "CUDA",
|
||||||
"targets": [ "ggml-cuda" ]
|
"targets": [ "ggml-cuda" ]
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"name": "CUDA 11",
|
||||||
|
"inherits": [ "CUDA" ],
|
||||||
|
"configurePreset": "CUDA 11"
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"name": "CUDA 12",
|
"name": "CUDA 12",
|
||||||
"inherits": [ "CUDA" ],
|
"inherits": [ "CUDA" ],
|
||||||
|
|||||||
24
Dockerfile
24
Dockerfile
@@ -7,13 +7,12 @@ ARG JETPACK5VERSION=r35.4.1
|
|||||||
ARG JETPACK6VERSION=r36.4.0
|
ARG JETPACK6VERSION=r36.4.0
|
||||||
ARG CMAKEVERSION=3.31.2
|
ARG CMAKEVERSION=3.31.2
|
||||||
|
|
||||||
# We require gcc v10 minimum. v10.3 has regressions, so the rockylinux 8.5 AppStream has the latest compatible version
|
# CUDA v11 requires gcc v10. v10.3 has regressions, so the rockylinux 8.5 AppStream has the latest compatible version
|
||||||
FROM --platform=linux/amd64 rocm/dev-almalinux-8:${ROCMVERSION}-complete AS base-amd64
|
FROM --platform=linux/amd64 rocm/dev-almalinux-8:${ROCMVERSION}-complete AS base-amd64
|
||||||
RUN yum install -y yum-utils \
|
RUN yum install -y yum-utils \
|
||||||
&& yum-config-manager --add-repo https://dl.rockylinux.org/vault/rocky/8.5/AppStream/\$basearch/os/ \
|
&& yum-config-manager --add-repo https://dl.rockylinux.org/vault/rocky/8.5/AppStream/\$basearch/os/ \
|
||||||
&& rpm --import https://dl.rockylinux.org/pub/rocky/RPM-GPG-KEY-Rocky-8 \
|
&& rpm --import https://dl.rockylinux.org/pub/rocky/RPM-GPG-KEY-Rocky-8 \
|
||||||
&& dnf install -y yum-utils ccache gcc-toolset-10-gcc-10.2.1-8.2.el8 gcc-toolset-10-gcc-c++-10.2.1-8.2.el8 gcc-toolset-10-binutils-2.35-11.el8 \
|
&& dnf install -y yum-utils ccache gcc-toolset-10-gcc-10.2.1-8.2.el8 gcc-toolset-10-gcc-c++-10.2.1-8.2.el8 gcc-toolset-10-binutils-2.35-11.el8 \
|
||||||
&& dnf install -y ccache \
|
|
||||||
&& yum-config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel8/x86_64/cuda-rhel8.repo
|
&& yum-config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel8/x86_64/cuda-rhel8.repo
|
||||||
ENV PATH=/opt/rh/gcc-toolset-10/root/usr/bin:$PATH
|
ENV PATH=/opt/rh/gcc-toolset-10/root/usr/bin:$PATH
|
||||||
|
|
||||||
@@ -39,6 +38,15 @@ RUN --mount=type=cache,target=/root/.ccache \
|
|||||||
&& cmake --build --parallel --preset 'CPU' \
|
&& cmake --build --parallel --preset 'CPU' \
|
||||||
&& cmake --install build --component CPU --strip --parallel 8
|
&& cmake --install build --component CPU --strip --parallel 8
|
||||||
|
|
||||||
|
FROM base AS cuda-11
|
||||||
|
ARG CUDA11VERSION=11.3
|
||||||
|
RUN dnf install -y cuda-toolkit-${CUDA11VERSION//./-}
|
||||||
|
ENV PATH=/usr/local/cuda-11/bin:$PATH
|
||||||
|
RUN --mount=type=cache,target=/root/.ccache \
|
||||||
|
cmake --preset 'CUDA 11' \
|
||||||
|
&& cmake --build --parallel --preset 'CUDA 11' \
|
||||||
|
&& cmake --install build --component CUDA --strip --parallel 8
|
||||||
|
|
||||||
FROM base AS cuda-12
|
FROM base AS cuda-12
|
||||||
ARG CUDA12VERSION=12.8
|
ARG CUDA12VERSION=12.8
|
||||||
RUN dnf install -y cuda-toolkit-${CUDA12VERSION//./-}
|
RUN dnf install -y cuda-toolkit-${CUDA12VERSION//./-}
|
||||||
@@ -90,15 +98,17 @@ RUN --mount=type=cache,target=/root/.cache/go-build \
|
|||||||
go build -trimpath -buildmode=pie -o /bin/ollama .
|
go build -trimpath -buildmode=pie -o /bin/ollama .
|
||||||
|
|
||||||
FROM --platform=linux/amd64 scratch AS amd64
|
FROM --platform=linux/amd64 scratch AS amd64
|
||||||
COPY --from=cuda-12 dist/lib/ollama /lib/ollama
|
COPY --from=cuda-11 dist/lib/ollama/cuda_v11 /lib/ollama/cuda_v11
|
||||||
|
COPY --from=cuda-12 dist/lib/ollama/cuda_v12 /lib/ollama/cuda_v12
|
||||||
|
|
||||||
FROM --platform=linux/arm64 scratch AS arm64
|
FROM --platform=linux/arm64 scratch AS arm64
|
||||||
COPY --from=cuda-12 dist/lib/ollama /lib/ollama/cuda_sbsa
|
COPY --from=cuda-11 dist/lib/ollama/cuda_v11 /lib/ollama/cuda_v11
|
||||||
COPY --from=jetpack-5 dist/lib/ollama /lib/ollama/cuda_jetpack5
|
COPY --from=cuda-12 dist/lib/ollama/cuda_v12 /lib/ollama/cuda_v12
|
||||||
COPY --from=jetpack-6 dist/lib/ollama /lib/ollama/cuda_jetpack6
|
COPY --from=jetpack-5 dist/lib/ollama/cuda_v11 /lib/ollama/cuda_jetpack5
|
||||||
|
COPY --from=jetpack-6 dist/lib/ollama/cuda_v12 /lib/ollama/cuda_jetpack6
|
||||||
|
|
||||||
FROM scratch AS rocm
|
FROM scratch AS rocm
|
||||||
COPY --from=rocm-6 dist/lib/ollama /lib/ollama
|
COPY --from=rocm-6 dist/lib/ollama/rocm /lib/ollama/rocm
|
||||||
|
|
||||||
FROM ${FLAVOR} AS archive
|
FROM ${FLAVOR} AS archive
|
||||||
COPY --from=cpu dist/lib/ollama /lib/ollama
|
COPY --from=cpu dist/lib/ollama /lib/ollama
|
||||||
|
|||||||
@@ -409,7 +409,6 @@ See the [API documentation](./docs/api.md) for all endpoints.
|
|||||||
- [macLlama (macOS native)](https://github.com/hellotunamayo/macLlama) (A native macOS GUI application for interacting with Ollama models, featuring a chat interface.)
|
- [macLlama (macOS native)](https://github.com/hellotunamayo/macLlama) (A native macOS GUI application for interacting with Ollama models, featuring a chat interface.)
|
||||||
- [GPTranslate](https://github.com/philberndt/GPTranslate) (A fast and lightweight, AI powered desktop translation application written with Rust and Tauri. Features real-time translation with OpenAI/Azure/Ollama.)
|
- [GPTranslate](https://github.com/philberndt/GPTranslate) (A fast and lightweight, AI powered desktop translation application written with Rust and Tauri. Features real-time translation with OpenAI/Azure/Ollama.)
|
||||||
- [ollama launcher](https://github.com/NGC13009/ollama-launcher) (A launcher for Ollama, aiming to provide users with convenient functions such as ollama server launching, management, or configuration.)
|
- [ollama launcher](https://github.com/NGC13009/ollama-launcher) (A launcher for Ollama, aiming to provide users with convenient functions such as ollama server launching, management, or configuration.)
|
||||||
- [ai-hub](https://github.com/Aj-Seven/ai-hub) (AI Hub supports multiple models via API keys and Chat support via Ollama API.)
|
|
||||||
|
|
||||||
### Cloud
|
### Cloud
|
||||||
|
|
||||||
|
|||||||
178
benchmark/server_benchmark_test.go
Normal file
178
benchmark/server_benchmark_test.go
Normal file
@@ -0,0 +1,178 @@
|
|||||||
|
package benchmark
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"flag"
|
||||||
|
"fmt"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/api"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Command line flags
|
||||||
|
var modelFlag string
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
flag.StringVar(&modelFlag, "m", "", "Name of the model to benchmark")
|
||||||
|
flag.Lookup("m").DefValue = "model"
|
||||||
|
}
|
||||||
|
|
||||||
|
// modelName returns the model name from flags, failing the test if not set
|
||||||
|
func modelName(b *testing.B) string {
|
||||||
|
if modelFlag == "" {
|
||||||
|
b.Fatal("Error: -m flag is required for benchmark tests")
|
||||||
|
}
|
||||||
|
return modelFlag
|
||||||
|
}
|
||||||
|
|
||||||
|
type TestCase struct {
|
||||||
|
name string
|
||||||
|
prompt string
|
||||||
|
maxTokens int
|
||||||
|
}
|
||||||
|
|
||||||
|
// runGenerateBenchmark contains the common generate and metrics logic
|
||||||
|
func runGenerateBenchmark(b *testing.B, ctx context.Context, client *api.Client, req *api.GenerateRequest) {
|
||||||
|
start := time.Now()
|
||||||
|
var ttft time.Duration
|
||||||
|
var metrics api.Metrics
|
||||||
|
|
||||||
|
err := client.Generate(ctx, req, func(resp api.GenerateResponse) error {
|
||||||
|
if ttft == 0 && resp.Response != "" {
|
||||||
|
ttft = time.Since(start)
|
||||||
|
}
|
||||||
|
if resp.Done {
|
||||||
|
metrics = resp.Metrics
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
// Report custom metrics as part of the benchmark results
|
||||||
|
b.ReportMetric(float64(ttft.Milliseconds()), "ttft_ms")
|
||||||
|
b.ReportMetric(float64(metrics.LoadDuration.Milliseconds()), "load_ms")
|
||||||
|
|
||||||
|
// Token throughput metrics
|
||||||
|
promptThroughput := float64(metrics.PromptEvalCount) / metrics.PromptEvalDuration.Seconds()
|
||||||
|
genThroughput := float64(metrics.EvalCount) / metrics.EvalDuration.Seconds()
|
||||||
|
b.ReportMetric(promptThroughput, "prompt_tok/s")
|
||||||
|
b.ReportMetric(genThroughput, "gen_tok/s")
|
||||||
|
|
||||||
|
// Token counts
|
||||||
|
b.ReportMetric(float64(metrics.PromptEvalCount), "prompt_tokens")
|
||||||
|
b.ReportMetric(float64(metrics.EvalCount), "gen_tokens")
|
||||||
|
if err != nil {
|
||||||
|
b.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// BenchmarkColdStart runs benchmarks with model loading from cold state
|
||||||
|
func BenchmarkColdStart(b *testing.B) {
|
||||||
|
client := setup(b)
|
||||||
|
tests := []TestCase{
|
||||||
|
{"short_prompt", "Write a long story", 100},
|
||||||
|
{"medium_prompt", "Write a detailed economic analysis", 500},
|
||||||
|
{"long_prompt", "Write a comprehensive AI research paper", 1000},
|
||||||
|
}
|
||||||
|
m := modelName(b)
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
b.Run(fmt.Sprintf("%s/cold/%s", m, tt.name), func(b *testing.B) {
|
||||||
|
ctx := b.Context()
|
||||||
|
|
||||||
|
// Set number of tokens as our throughput metric
|
||||||
|
b.SetBytes(int64(tt.maxTokens))
|
||||||
|
|
||||||
|
for b.Loop() {
|
||||||
|
b.StopTimer()
|
||||||
|
// Ensure model is unloaded before each iteration
|
||||||
|
unload(client, m, b)
|
||||||
|
b.StartTimer()
|
||||||
|
|
||||||
|
req := &api.GenerateRequest{
|
||||||
|
Model: m,
|
||||||
|
Prompt: tt.prompt,
|
||||||
|
Options: map[string]any{"num_predict": tt.maxTokens, "temperature": 0.1},
|
||||||
|
}
|
||||||
|
|
||||||
|
runGenerateBenchmark(b, ctx, client, req)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// BenchmarkWarmStart runs benchmarks with pre-loaded model
|
||||||
|
func BenchmarkWarmStart(b *testing.B) {
|
||||||
|
client := setup(b)
|
||||||
|
tests := []TestCase{
|
||||||
|
{"short_prompt", "Write a long story", 100},
|
||||||
|
{"medium_prompt", "Write a detailed economic analysis", 500},
|
||||||
|
{"long_prompt", "Write a comprehensive AI research paper", 1000},
|
||||||
|
}
|
||||||
|
m := modelName(b)
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
b.Run(fmt.Sprintf("%s/warm/%s", m, tt.name), func(b *testing.B) {
|
||||||
|
ctx := b.Context()
|
||||||
|
|
||||||
|
// Pre-warm the model
|
||||||
|
warmup(client, m, tt.prompt, b)
|
||||||
|
|
||||||
|
// Set number of tokens as our throughput metric
|
||||||
|
b.SetBytes(int64(tt.maxTokens))
|
||||||
|
|
||||||
|
for b.Loop() {
|
||||||
|
req := &api.GenerateRequest{
|
||||||
|
Model: m,
|
||||||
|
Prompt: tt.prompt,
|
||||||
|
Options: map[string]any{"num_predict": tt.maxTokens, "temperature": 0.1},
|
||||||
|
}
|
||||||
|
|
||||||
|
runGenerateBenchmark(b, ctx, client, req)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// setup verifies server and model availability
|
||||||
|
func setup(b *testing.B) *api.Client {
|
||||||
|
client, err := api.ClientFromEnvironment()
|
||||||
|
if err != nil {
|
||||||
|
b.Fatal(err)
|
||||||
|
}
|
||||||
|
if _, err := client.Show(b.Context(), &api.ShowRequest{Model: modelName(b)}); err != nil {
|
||||||
|
b.Fatalf("Model unavailable: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return client
|
||||||
|
}
|
||||||
|
|
||||||
|
// warmup ensures the model is loaded and warmed up
|
||||||
|
func warmup(client *api.Client, model string, prompt string, b *testing.B) {
|
||||||
|
for range 3 {
|
||||||
|
err := client.Generate(
|
||||||
|
context.Background(),
|
||||||
|
&api.GenerateRequest{
|
||||||
|
Model: model,
|
||||||
|
Prompt: prompt,
|
||||||
|
Options: map[string]any{"num_predict": 50, "temperature": 0.1},
|
||||||
|
},
|
||||||
|
func(api.GenerateResponse) error { return nil },
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
b.Logf("Error during model warm-up: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// unload forces model unloading using KeepAlive: 0 parameter
|
||||||
|
func unload(client *api.Client, model string, b *testing.B) {
|
||||||
|
req := &api.GenerateRequest{
|
||||||
|
Model: model,
|
||||||
|
KeepAlive: &api.Duration{Duration: 0},
|
||||||
|
}
|
||||||
|
if err := client.Generate(context.Background(), req, func(api.GenerateResponse) error { return nil }); err != nil {
|
||||||
|
b.Logf("Unload error: %v", err)
|
||||||
|
}
|
||||||
|
time.Sleep(1 * time.Second)
|
||||||
|
}
|
||||||
@@ -2,6 +2,9 @@ package convert
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"slices"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"github.com/ollama/ollama/fs/ggml"
|
"github.com/ollama/ollama/fs/ggml"
|
||||||
)
|
)
|
||||||
@@ -27,38 +30,65 @@ func (p *mixtralModel) KV(t *Tokenizer) ggml.KV {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (p *mixtralModel) Tensors(ts []Tensor) []*ggml.Tensor {
|
func (p *mixtralModel) Tensors(ts []Tensor) []*ggml.Tensor {
|
||||||
merges := make([]merge, 0, p.NumHiddenLayers*6)
|
oldnew := []string{
|
||||||
for i := range p.NumHiddenLayers {
|
"model.layers", "blk",
|
||||||
merges = append(merges, merge{
|
"w1", "ffn_gate_exps",
|
||||||
fmt.Sprintf("blk.%d.*.w1.weight", i),
|
"w2", "ffn_down_exps",
|
||||||
fmt.Sprintf("blk.%d.ffn_gate_exps.weight", i),
|
"w3", "ffn_up_exps",
|
||||||
}, merge{
|
}
|
||||||
fmt.Sprintf("blk.%d.*.w1.bias", i),
|
|
||||||
fmt.Sprintf("blk.%d.ffn_gate_exps.bias", i),
|
for i := range p.NumLocalExperts {
|
||||||
}, merge{
|
oldnew = append(oldnew, fmt.Sprintf(".block_sparse_moe.experts.%d.", i), ".")
|
||||||
fmt.Sprintf("blk.%d.*.w2.weight", i),
|
}
|
||||||
fmt.Sprintf("blk.%d.ffn_up_exps.weight", i),
|
|
||||||
}, merge{
|
// group experts of the same layer (model.layers.%d) and type (w[123]) into a single tensor
|
||||||
fmt.Sprintf("blk.%d.*.w2.bias", i),
|
namer := strings.NewReplacer(oldnew...)
|
||||||
fmt.Sprintf("blk.%d.ffn_up_exps.bias", i),
|
experts := make(map[string]experts)
|
||||||
}, merge{
|
|
||||||
fmt.Sprintf("blk.%d.*.w3.weight", i),
|
// merge experts into a single tensor while removing them from ts
|
||||||
fmt.Sprintf("blk.%d.ffn_down_exps.weight", i),
|
ts = slices.DeleteFunc(ts, func(t Tensor) bool {
|
||||||
}, merge{
|
if !strings.Contains(t.Name(), ".block_sparse_moe.experts.") {
|
||||||
fmt.Sprintf("blk.%d.*.w3.bias", i),
|
return false
|
||||||
fmt.Sprintf("blk.%d.ffn_down_exps.bias", i),
|
}
|
||||||
|
|
||||||
|
name := namer.Replace(t.Name())
|
||||||
|
experts[name] = append(experts[name], t)
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
|
||||||
|
var out []*ggml.Tensor
|
||||||
|
for n, e := range experts {
|
||||||
|
// TODO(mxyng): sanity check experts
|
||||||
|
out = append(out, &ggml.Tensor{
|
||||||
|
Name: n,
|
||||||
|
Kind: e[0].Kind(),
|
||||||
|
Shape: append([]uint64{uint64(len(e))}, e[0].Shape()...),
|
||||||
|
WriterTo: e,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
out, ts := mergeTensors(ts, merges...)
|
|
||||||
return append(out, p.llamaModel.Tensors(ts)...)
|
return append(out, p.llamaModel.Tensors(ts)...)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *mixtralModel) Replacements() []string {
|
func (p *mixtralModel) Replacements() []string {
|
||||||
return append(
|
return append(
|
||||||
p.llamaModel.Replacements(),
|
p.llamaModel.Replacements(),
|
||||||
"model.layers", "blk",
|
|
||||||
"block_sparse_moe.gate", "ffn_gate_inp",
|
"block_sparse_moe.gate", "ffn_gate_inp",
|
||||||
"block_sparse_moe.experts.", ".",
|
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type experts []Tensor
|
||||||
|
|
||||||
|
func (e experts) WriteTo(w io.Writer) (int64, error) {
|
||||||
|
// TODO(mxyng): experts _should_ be numerically sorted by expert but this should check
|
||||||
|
for _, t := range e {
|
||||||
|
// the canonical merged experts tensor stacks all experts along a new, 0 axis,
|
||||||
|
// e.g. `tensor.Stack(0, e[0], e[1:]...)`, which requires allocating temporary buffers
|
||||||
|
// this accomplishes the same thing by writing each expert tensor in sequence
|
||||||
|
if _, err := t.WriteTo(w); err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -2,9 +2,7 @@ package convert
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"cmp"
|
"cmp"
|
||||||
"io"
|
|
||||||
"iter"
|
"iter"
|
||||||
"path"
|
|
||||||
"slices"
|
"slices"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
@@ -76,54 +74,3 @@ func splitDim(t Tensor, dim int, splits ...split) iter.Seq[*ggml.Tensor] {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
type merge struct {
|
|
||||||
pattern, name string
|
|
||||||
}
|
|
||||||
|
|
||||||
// mergeTensors merges tensors that match a given pattern into a single tensor.
|
|
||||||
func mergeTensors(unmatched []Tensor, merges ...merge) (out []*ggml.Tensor, _ []Tensor) {
|
|
||||||
var matched []Tensor
|
|
||||||
for i := range merges {
|
|
||||||
matched, unmatched = slicesSplitFunc(unmatched, func(t Tensor) bool {
|
|
||||||
matched, _ := path.Match(merges[i].pattern, t.Name())
|
|
||||||
return matched
|
|
||||||
})
|
|
||||||
|
|
||||||
if len(matched) > 0 {
|
|
||||||
out = append(out, &ggml.Tensor{
|
|
||||||
Name: merges[i].name,
|
|
||||||
Kind: matched[0].Kind(),
|
|
||||||
Shape: append([]uint64{uint64(len(matched))}, matched[0].Shape()...),
|
|
||||||
WriterTo: mergeGroup(matched),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return out, unmatched
|
|
||||||
}
|
|
||||||
|
|
||||||
// slicesSplitFunc splits a slice into two slices based on a predicate function.
|
|
||||||
func slicesSplitFunc[S ~[]E, E comparable](s S, fn func(e E) bool) (matched, unmatched S) {
|
|
||||||
for _, e := range s {
|
|
||||||
if fn(e) {
|
|
||||||
matched = append(matched, e)
|
|
||||||
} else {
|
|
||||||
unmatched = append(unmatched, e)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return matched, unmatched
|
|
||||||
}
|
|
||||||
|
|
||||||
type mergeGroup []Tensor
|
|
||||||
|
|
||||||
func (g mergeGroup) WriteTo(w io.Writer) (int64, error) {
|
|
||||||
for _, t := range g {
|
|
||||||
if _, err := t.WriteTo(w); err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return 0, nil
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -9,8 +9,6 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/google/go-cmp/cmp"
|
|
||||||
"github.com/ollama/ollama/fs/ggml"
|
|
||||||
"github.com/pdevine/tensor"
|
"github.com/pdevine/tensor"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -304,99 +302,3 @@ func TestSplitDim(t *testing.T) {
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestMerge(t *testing.T) {
|
|
||||||
unmatched := []Tensor{
|
|
||||||
&fakeTensor{
|
|
||||||
name: "a.0.b",
|
|
||||||
shape: []uint64{5, 2},
|
|
||||||
data: []float32{10, 11, 12, 13, 14, 15, 16, 17, 18, 19},
|
|
||||||
},
|
|
||||||
&fakeTensor{
|
|
||||||
name: "a.1.b",
|
|
||||||
shape: []uint64{5, 2},
|
|
||||||
data: []float32{20, 21, 22, 23, 24, 25, 26, 27, 28, 29},
|
|
||||||
},
|
|
||||||
&fakeTensor{
|
|
||||||
name: "c.0.d",
|
|
||||||
shape: []uint64{5, 2},
|
|
||||||
data: []float32{30, 31, 32, 33, 34, 35, 36, 37, 38, 39},
|
|
||||||
},
|
|
||||||
&fakeTensor{
|
|
||||||
name: "c.1.d",
|
|
||||||
shape: []uint64{5, 2},
|
|
||||||
data: []float32{40, 41, 42, 43, 44, 45, 46, 47, 48, 49},
|
|
||||||
},
|
|
||||||
&fakeTensor{
|
|
||||||
name: "e.0.f",
|
|
||||||
shape: []uint64{5, 2},
|
|
||||||
data: []float32{50, 51, 52, 53, 54, 55, 56, 57, 58, 59},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
checkMatched := func(t *testing.T, n int, matched []*ggml.Tensor) {
|
|
||||||
for i := range n {
|
|
||||||
got := matched[i]
|
|
||||||
if diff := cmp.Diff([]uint64{2, 5, 2}, got.Shape); diff != "" {
|
|
||||||
t.Errorf("unexpected (-want +got):\n%s", diff)
|
|
||||||
}
|
|
||||||
|
|
||||||
var b bytes.Buffer
|
|
||||||
if _, err := got.WriteTo(&b); err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
f32s := make([]float32, 20)
|
|
||||||
if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
offset := 10 + (i * 20)
|
|
||||||
want := make([]float32, 20)
|
|
||||||
for j := range 20 {
|
|
||||||
want[j] = float32(offset + j)
|
|
||||||
}
|
|
||||||
|
|
||||||
if diff := cmp.Diff(want, f32s); diff != "" {
|
|
||||||
t.Errorf("unexpected data (-want +got):\n%s", diff)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
t.Run("single merge", func(t *testing.T) {
|
|
||||||
matched, unmatched := mergeTensors(unmatched, merge{"a.*.b", "a.b"})
|
|
||||||
if len(unmatched) != 3 {
|
|
||||||
t.Error("expected 3 remaining tensors, got", len(unmatched))
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(matched) != 1 {
|
|
||||||
t.Error("expected 1 merged tensor, got", len(matched))
|
|
||||||
}
|
|
||||||
|
|
||||||
checkMatched(t, 1, matched)
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("multiple merges", func(t *testing.T) {
|
|
||||||
matched, unmatched := mergeTensors(unmatched, merge{"a.*.b", "a.b"}, merge{"c.*.d", "c.d"})
|
|
||||||
if len(unmatched) != 1 {
|
|
||||||
t.Error("expected 1 remaining tensors, got", len(unmatched))
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(matched) != 2 {
|
|
||||||
t.Error("expected 2 merged tensor, got", len(matched))
|
|
||||||
}
|
|
||||||
|
|
||||||
checkMatched(t, 2, matched)
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("no match", func(t *testing.T) {
|
|
||||||
matched, unmatched := mergeTensors(unmatched, merge{"x.*.y", "x.y"})
|
|
||||||
if len(unmatched) != 5 {
|
|
||||||
t.Error("expected 5 remaining tensors, got", len(unmatched))
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(matched) != 0 {
|
|
||||||
t.Error("expected no merged tensors, got", len(matched))
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -3,7 +3,6 @@
|
|||||||
package discover
|
package discover
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"os"
|
"os"
|
||||||
"regexp"
|
"regexp"
|
||||||
@@ -56,13 +55,10 @@ func cudaVariant(gpuInfo CudaGPUInfo) string {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return "sbsa"
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// driver 12.0 has problems with the cuda v12 library, so run v11 on those older drivers
|
// driver 12.0 has problems with the cuda v12 library, so run v11 on those older drivers
|
||||||
if gpuInfo.DriverMajor < 12 || (gpuInfo.DriverMajor == 12 && gpuInfo.DriverMinor == 0) {
|
if gpuInfo.DriverMajor < 12 || (gpuInfo.DriverMajor == 12 && gpuInfo.DriverMinor == 0) {
|
||||||
// The detected driver is older than Feb 2023
|
|
||||||
slog.Warn("old CUDA driver detected - please upgrade to a newer driver", "version", fmt.Sprintf("%d.%d", gpuInfo.DriverMajor, gpuInfo.DriverMinor))
|
|
||||||
return "v11"
|
return "v11"
|
||||||
}
|
}
|
||||||
return "v12"
|
return "v12"
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ import (
|
|||||||
// '../lib/ollama' on Linux and the executable's directory on macOS
|
// '../lib/ollama' on Linux and the executable's directory on macOS
|
||||||
// note: distribution builds, additional GPU-specific libraries are
|
// note: distribution builds, additional GPU-specific libraries are
|
||||||
// found in subdirectories of the returned path, such as
|
// found in subdirectories of the returned path, such as
|
||||||
// 'cuda_v12', 'rocm', etc.
|
// 'cuda_v11', 'cuda_v12', 'rocm', etc.
|
||||||
var LibOllamaPath string = func() string {
|
var LibOllamaPath string = func() string {
|
||||||
exe, err := os.Executable()
|
exe, err := os.Executable()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
59
docs/benchmark.md
Normal file
59
docs/benchmark.md
Normal file
@@ -0,0 +1,59 @@
|
|||||||
|
# Benchmark
|
||||||
|
|
||||||
|
Go benchmark tests that measure end-to-end performance of a running Ollama server. Run these tests to evaluate model inference performance on your hardware and measure the impact of code changes.
|
||||||
|
|
||||||
|
## When to use
|
||||||
|
|
||||||
|
Run these benchmarks when:
|
||||||
|
- Making changes to the model inference engine
|
||||||
|
- Modifying model loading/unloading logic
|
||||||
|
- Changing prompt processing or token generation code
|
||||||
|
- Implementing a new model architecture
|
||||||
|
- Testing performance across different hardware setups
|
||||||
|
|
||||||
|
## Prerequisites
|
||||||
|
- Ollama server running locally with `ollama serve` on `127.0.0.1:11434`
|
||||||
|
## Usage and Examples
|
||||||
|
|
||||||
|
>[!NOTE]
|
||||||
|
>All commands must be run from the root directory of the Ollama project.
|
||||||
|
|
||||||
|
Basic syntax:
|
||||||
|
```bash
|
||||||
|
go test -bench=. ./benchmark/... -m $MODEL_NAME
|
||||||
|
```
|
||||||
|
|
||||||
|
Required flags:
|
||||||
|
- `-bench=.`: Run all benchmarks
|
||||||
|
- `-m`: Model name to benchmark
|
||||||
|
|
||||||
|
Optional flags:
|
||||||
|
- `-count N`: Number of times to run the benchmark (useful for statistical analysis)
|
||||||
|
- `-timeout T`: Maximum time for the benchmark to run (e.g. "10m" for 10 minutes)
|
||||||
|
|
||||||
|
Common usage patterns:
|
||||||
|
|
||||||
|
Single benchmark run with a model specified:
|
||||||
|
```bash
|
||||||
|
go test -bench=. ./benchmark/... -m llama3.3
|
||||||
|
```
|
||||||
|
|
||||||
|
## Output metrics
|
||||||
|
|
||||||
|
The benchmark reports several key metrics:
|
||||||
|
|
||||||
|
- `gen_tok/s`: Generated tokens per second
|
||||||
|
- `prompt_tok/s`: Prompt processing tokens per second
|
||||||
|
- `ttft_ms`: Time to first token in milliseconds
|
||||||
|
- `load_ms`: Model load time in milliseconds
|
||||||
|
- `gen_tokens`: Total tokens generated
|
||||||
|
- `prompt_tokens`: Total prompt tokens processed
|
||||||
|
|
||||||
|
Each benchmark runs two scenarios:
|
||||||
|
- Cold start: Model is loaded from disk for each test
|
||||||
|
- Warm start: Model is pre-loaded in memory
|
||||||
|
|
||||||
|
Three prompt lengths are tested for each scenario:
|
||||||
|
- Short prompt (100 tokens)
|
||||||
|
- Medium prompt (500 tokens)
|
||||||
|
- Long prompt (1000 tokens)
|
||||||
@@ -1,6 +1,6 @@
|
|||||||
# GPU
|
# GPU
|
||||||
## Nvidia
|
## Nvidia
|
||||||
Ollama supports Nvidia GPUs with compute capability 5.0+ and driver version 531 and newer.
|
Ollama supports Nvidia GPUs with compute capability 5.0+.
|
||||||
|
|
||||||
Check your compute compatibility to see if your card is supported:
|
Check your compute compatibility to see if your card is supported:
|
||||||
[https://developer.nvidia.com/cuda-gpus](https://developer.nvidia.com/cuda-gpus)
|
[https://developer.nvidia.com/cuda-gpus](https://developer.nvidia.com/cuda-gpus)
|
||||||
|
|||||||
@@ -43,7 +43,7 @@ Ollama includes multiple LLM libraries compiled for different GPUs and CPU vecto
|
|||||||
In the server log, you will see a message that looks something like this (varies from release to release):
|
In the server log, you will see a message that looks something like this (varies from release to release):
|
||||||
|
|
||||||
```
|
```
|
||||||
Dynamic LLM libraries [rocm_v6 cpu cpu_avx cpu_avx2 cuda_v12 rocm_v5]
|
Dynamic LLM libraries [rocm_v6 cpu cpu_avx cpu_avx2 cuda_v11 rocm_v5]
|
||||||
```
|
```
|
||||||
|
|
||||||
**Experimental LLM Library Override**
|
**Experimental LLM Library Override**
|
||||||
|
|||||||
@@ -65,7 +65,7 @@ func Open(path string) (f *File, err error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if f.Version < 2 {
|
if f.Version != 3 {
|
||||||
return nil, fmt.Errorf("%w version %v", ErrUnsupported, f.Version)
|
return nil, fmt.Errorf("%w version %v", ErrUnsupported, f.Version)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -45,8 +45,6 @@ var (
|
|||||||
"qwen2.5-coder:latest",
|
"qwen2.5-coder:latest",
|
||||||
"qwen:latest",
|
"qwen:latest",
|
||||||
"solar-pro:latest",
|
"solar-pro:latest",
|
||||||
"codellama:latest",
|
|
||||||
"nous-hermes:latest",
|
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -1,32 +0,0 @@
|
|||||||
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
|
|
||||||
From: Daniel Hiltgen <daniel@ollama.com>
|
|
||||||
Date: Sun, 22 Jun 2025 09:22:05 -0700
|
|
||||||
Subject: [PATCH] temporary prevent rocm+cuda mixed loading
|
|
||||||
|
|
||||||
---
|
|
||||||
ggml/src/ggml-backend-reg.cpp | 12 ++++++++++--
|
|
||||||
1 file changed, 10 insertions(+), 2 deletions(-)
|
|
||||||
|
|
||||||
diff --git a/ggml/src/ggml-backend-reg.cpp b/ggml/src/ggml-backend-reg.cpp
|
|
||||||
index 4e67d243..8f49f084 100644
|
|
||||||
--- a/ggml/src/ggml-backend-reg.cpp
|
|
||||||
+++ b/ggml/src/ggml-backend-reg.cpp
|
|
||||||
@@ -573,8 +573,16 @@ void ggml_backend_load_all_from_path(const char * dir_path) {
|
|
||||||
|
|
||||||
ggml_backend_load_best("blas", silent, dir_path);
|
|
||||||
ggml_backend_load_best("cann", silent, dir_path);
|
|
||||||
- ggml_backend_load_best("cuda", silent, dir_path);
|
|
||||||
- ggml_backend_load_best("hip", silent, dir_path);
|
|
||||||
+
|
|
||||||
+ // Avoid mixed hip+cuda configurations
|
|
||||||
+ const char * hip_devices = std::getenv("HIP_VISIBLE_DEVICES");
|
|
||||||
+ const char * rocr_devices = std::getenv("ROCR_VISIBLE_DEVICES");
|
|
||||||
+ if (!hip_devices && !rocr_devices) {
|
|
||||||
+ ggml_backend_load_best("cuda", silent, dir_path);
|
|
||||||
+ } else {
|
|
||||||
+ ggml_backend_load_best("hip", silent, dir_path);
|
|
||||||
+ }
|
|
||||||
+
|
|
||||||
ggml_backend_load_best("kompute", silent, dir_path);
|
|
||||||
ggml_backend_load_best("metal", silent, dir_path);
|
|
||||||
ggml_backend_load_best("rpc", silent, dir_path);
|
|
||||||
@@ -139,13 +139,6 @@ func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, a
|
|||||||
gpus = discover.GetCPUInfo()
|
gpus = discover.GetCPUInfo()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Verify the requested context size is <= the model training size
|
|
||||||
trainCtx := f.KV().ContextLength()
|
|
||||||
if opts.NumCtx/numParallel > int(trainCtx) && trainCtx > 0 {
|
|
||||||
slog.Warn("requested context size too large for model", "num_ctx", opts.NumCtx, "num_parallel", numParallel, "n_ctx_train", trainCtx)
|
|
||||||
opts.NumCtx = int(trainCtx) * numParallel
|
|
||||||
}
|
|
||||||
|
|
||||||
estimate := EstimateGPULayers(gpus, f, projectors, opts, numParallel)
|
estimate := EstimateGPULayers(gpus, f, projectors, opts, numParallel)
|
||||||
if len(gpus) > 1 || gpus[0].Library != "cpu" {
|
if len(gpus) > 1 || gpus[0].Library != "cpu" {
|
||||||
switch {
|
switch {
|
||||||
@@ -318,7 +311,7 @@ func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, a
|
|||||||
params = append(params, "--mmproj", projectors[0])
|
params = append(params, "--mmproj", projectors[0])
|
||||||
}
|
}
|
||||||
|
|
||||||
// iterate through compatible GPU libraries such as 'cuda_v12', 'rocm', etc.
|
// iterate through compatible GPU libraries such as 'cuda_v12', 'cuda_v11', 'rocm', etc.
|
||||||
// adding each library's respective path to the LD_LIBRARY_PATH, until finally running
|
// adding each library's respective path to the LD_LIBRARY_PATH, until finally running
|
||||||
// without any LD_LIBRARY_PATH flags
|
// without any LD_LIBRARY_PATH flags
|
||||||
for {
|
for {
|
||||||
|
|||||||
@@ -602,9 +602,7 @@ func (c *Context) Forward(tensors ...ml.Tensor) ml.Context {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *Context) Compute(tensors ...ml.Tensor) {
|
func (c *Context) Compute(tensors ...ml.Tensor) {
|
||||||
if status := C.ggml_backend_sched_graph_compute_async(c.b.sched, c.graph); status != C.GGML_STATUS_SUCCESS {
|
C.ggml_backend_sched_graph_compute_async(c.b.sched, c.graph)
|
||||||
panic(fmt.Errorf("error computing ggml graph: %v", status))
|
|
||||||
}
|
|
||||||
C.ggml_backend_sched_reset(c.b.sched)
|
C.ggml_backend_sched_reset(c.b.sched)
|
||||||
|
|
||||||
needSync := true
|
needSync := true
|
||||||
|
|||||||
12
ml/backend/ggml/ggml/src/ggml-backend-reg.cpp
vendored
12
ml/backend/ggml/ggml/src/ggml-backend-reg.cpp
vendored
@@ -573,16 +573,8 @@ void ggml_backend_load_all_from_path(const char * dir_path) {
|
|||||||
|
|
||||||
ggml_backend_load_best("blas", silent, dir_path);
|
ggml_backend_load_best("blas", silent, dir_path);
|
||||||
ggml_backend_load_best("cann", silent, dir_path);
|
ggml_backend_load_best("cann", silent, dir_path);
|
||||||
|
ggml_backend_load_best("cuda", silent, dir_path);
|
||||||
// Avoid mixed hip+cuda configurations
|
ggml_backend_load_best("hip", silent, dir_path);
|
||||||
const char * hip_devices = std::getenv("HIP_VISIBLE_DEVICES");
|
|
||||||
const char * rocr_devices = std::getenv("ROCR_VISIBLE_DEVICES");
|
|
||||||
if (!hip_devices && !rocr_devices) {
|
|
||||||
ggml_backend_load_best("cuda", silent, dir_path);
|
|
||||||
} else {
|
|
||||||
ggml_backend_load_best("hip", silent, dir_path);
|
|
||||||
}
|
|
||||||
|
|
||||||
ggml_backend_load_best("kompute", silent, dir_path);
|
ggml_backend_load_best("kompute", silent, dir_path);
|
||||||
ggml_backend_load_best("metal", silent, dir_path);
|
ggml_backend_load_best("metal", silent, dir_path);
|
||||||
ggml_backend_load_best("rpc", silent, dir_path);
|
ggml_backend_load_best("rpc", silent, dir_path);
|
||||||
|
|||||||
@@ -87,7 +87,7 @@ func (v *Vocabulary) Decode(id int32) string {
|
|||||||
func (v *Vocabulary) SpecialVocabulary() []string {
|
func (v *Vocabulary) SpecialVocabulary() []string {
|
||||||
v.specialOnce.Do(func() {
|
v.specialOnce.Do(func() {
|
||||||
for i := range v.Values {
|
for i := range v.Values {
|
||||||
if v.Types[i] == TOKEN_TYPE_CONTROL || v.Types[i] == TOKEN_TYPE_USER_DEFINED {
|
if v.Types[i] == TOKEN_TYPE_CONTROL {
|
||||||
v.special = append(v.special, v.Values[i])
|
v.special = append(v.special, v.Values[i])
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,16 +0,0 @@
|
|||||||
package model
|
|
||||||
|
|
||||||
import "testing"
|
|
||||||
|
|
||||||
func TestVocabulary_SpecialVocabulary(t *testing.T) {
|
|
||||||
vocab := &Vocabulary{
|
|
||||||
Values: []string{"<|startoftext|>", "<|endoftext|>", "<|tool_call_start|>", "<|tool_call_end|>", "hi"},
|
|
||||||
Types: []int32{TOKEN_TYPE_CONTROL, TOKEN_TYPE_CONTROL, TOKEN_TYPE_USER_DEFINED, TOKEN_TYPE_USER_DEFINED, TOKEN_TYPE_NORMAL},
|
|
||||||
}
|
|
||||||
|
|
||||||
specialVocab := vocab.SpecialVocabulary()
|
|
||||||
|
|
||||||
if len(specialVocab) != 4 {
|
|
||||||
t.Errorf("expected 4 special tokens, got %d", len(specialVocab))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -27,6 +27,7 @@ function checkEnv() {
|
|||||||
$env:VCToolsRedistDir=(get-item "${MSVC_INSTALL}\VC\Redist\MSVC\*")[0]
|
$env:VCToolsRedistDir=(get-item "${MSVC_INSTALL}\VC\Redist\MSVC\*")[0]
|
||||||
}
|
}
|
||||||
# Locate CUDA versions
|
# Locate CUDA versions
|
||||||
|
# Note: this assumes every version found will be built
|
||||||
$cudaList=(get-item "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v*\bin\" -ea 'silentlycontinue')
|
$cudaList=(get-item "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v*\bin\" -ea 'silentlycontinue')
|
||||||
if ($cudaList.length -eq 0) {
|
if ($cudaList.length -eq 0) {
|
||||||
$d=(get-command -ea 'silentlycontinue' nvcc).path
|
$d=(get-command -ea 'silentlycontinue' nvcc).path
|
||||||
@@ -93,6 +94,19 @@ function buildOllama() {
|
|||||||
|
|
||||||
$hashEnv = @{}
|
$hashEnv = @{}
|
||||||
Get-ChildItem env: | foreach { $hashEnv[$_.Name] = $_.Value }
|
Get-ChildItem env: | foreach { $hashEnv[$_.Name] = $_.Value }
|
||||||
|
if ("$script:CUDA_DIRS".Contains("v11")) {
|
||||||
|
$hashEnv.Keys | foreach { if ($_.Contains("CUDA_PATH_V11")) { $v11="$_" }}
|
||||||
|
$env:CUDAToolkit_ROOT=$hashEnv[$v11]
|
||||||
|
write-host "Building CUDA v11 backend libraries"
|
||||||
|
# Note: cuda v11 requires msvc 2019 so force the older generator
|
||||||
|
# to avoid 2022 (or newer) from being used as the default
|
||||||
|
& cmake --fresh --preset "CUDA 11" -G "Visual Studio 16 2019" --install-prefix $script:DIST_DIR
|
||||||
|
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
|
||||||
|
& cmake --build --preset "CUDA 11" --config Release --parallel $script:JOBS
|
||||||
|
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
|
||||||
|
& cmake --install build --component "CUDA" --strip
|
||||||
|
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
|
||||||
|
}
|
||||||
if ("$script:CUDA_DIRS".Contains("v12")) {
|
if ("$script:CUDA_DIRS".Contains("v12")) {
|
||||||
$hashEnv.Keys | foreach { if ($_.Contains("CUDA_PATH_V12")) { $v12="$_" }}
|
$hashEnv.Keys | foreach { if ($_.Contains("CUDA_PATH_V12")) { $v12="$_" }}
|
||||||
$env:CUDAToolkit_ROOT=$hashEnv[$v12]
|
$env:CUDAToolkit_ROOT=$hashEnv[$v12]
|
||||||
|
|||||||
@@ -10,7 +10,9 @@ OLLAMA_COMMON_BUILD_ARGS="--build-arg=VERSION \
|
|||||||
--build-arg=GOFLAGS \
|
--build-arg=GOFLAGS \
|
||||||
--build-arg=OLLAMA_CUSTOM_CPU_DEFS \
|
--build-arg=OLLAMA_CUSTOM_CPU_DEFS \
|
||||||
--build-arg=OLLAMA_SKIP_CUDA_GENERATE \
|
--build-arg=OLLAMA_SKIP_CUDA_GENERATE \
|
||||||
|
--build-arg=OLLAMA_SKIP_CUDA_11_GENERATE \
|
||||||
--build-arg=OLLAMA_SKIP_CUDA_12_GENERATE \
|
--build-arg=OLLAMA_SKIP_CUDA_12_GENERATE \
|
||||||
|
--build-arg=CUDA_V11_ARCHITECTURES \
|
||||||
--build-arg=CUDA_V12_ARCHITECTURES \
|
--build-arg=CUDA_V12_ARCHITECTURES \
|
||||||
--build-arg=OLLAMA_SKIP_ROCM_GENERATE \
|
--build-arg=OLLAMA_SKIP_ROCM_GENERATE \
|
||||||
--build-arg=OLLAMA_FAST_BUILD \
|
--build-arg=OLLAMA_FAST_BUILD \
|
||||||
|
|||||||
115
server/cache/capabilities.go
vendored
Normal file
115
server/cache/capabilities.go
vendored
Normal file
@@ -0,0 +1,115 @@
|
|||||||
|
package cache
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"log/slog"
|
||||||
|
"os"
|
||||||
|
"slices"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/fs/ggml"
|
||||||
|
"github.com/ollama/ollama/template"
|
||||||
|
"github.com/ollama/ollama/thinking"
|
||||||
|
"github.com/ollama/ollama/types/model"
|
||||||
|
)
|
||||||
|
|
||||||
|
// cacheEntry stores capabilities and the modification time of the model file
|
||||||
|
type cacheEntry struct {
|
||||||
|
capabilities []model.Capability
|
||||||
|
modTime time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
// ggufCapabilities is a cache for gguf model capabilities
|
||||||
|
var ggufCapabilities = &sync.Map{}
|
||||||
|
|
||||||
|
// ModelInfo contains the minimal information needed to determine capabilities
|
||||||
|
type ModelInfo struct {
|
||||||
|
ModelPath string
|
||||||
|
ProjectorPaths []string
|
||||||
|
Template *template.Template
|
||||||
|
}
|
||||||
|
|
||||||
|
// Capabilities returns the capabilities that the model supports
|
||||||
|
func Capabilities(info ModelInfo) []model.Capability {
|
||||||
|
capabilities, err := ggufCapabilties(info.ModelPath)
|
||||||
|
if err != nil {
|
||||||
|
slog.Error("could not determine gguf capabilities", "error", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if info.Template == nil {
|
||||||
|
return capabilities
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for tools capability
|
||||||
|
if slices.Contains(info.Template.Vars(), "tools") {
|
||||||
|
capabilities = append(capabilities, model.CapabilityTools)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for insert capability
|
||||||
|
if slices.Contains(info.Template.Vars(), "suffix") {
|
||||||
|
capabilities = append(capabilities, model.CapabilityInsert)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for vision capability in projector-based models
|
||||||
|
if len(info.ProjectorPaths) > 0 {
|
||||||
|
capabilities = append(capabilities, model.CapabilityVision)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for thinking capability
|
||||||
|
openingTag, closingTag := thinking.InferTags(info.Template.Template)
|
||||||
|
if openingTag != "" && closingTag != "" {
|
||||||
|
capabilities = append(capabilities, model.CapabilityThinking)
|
||||||
|
}
|
||||||
|
|
||||||
|
return capabilities
|
||||||
|
}
|
||||||
|
|
||||||
|
func ggufCapabilties(modelPath string) ([]model.Capability, error) {
|
||||||
|
// Get file info to check modification time
|
||||||
|
fileInfo, err := os.Stat(modelPath)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
currentModTime := fileInfo.ModTime()
|
||||||
|
|
||||||
|
// Check if we have a cached entry
|
||||||
|
if cached, ok := ggufCapabilities.Load(modelPath); ok {
|
||||||
|
entry := cached.(cacheEntry)
|
||||||
|
// If the file hasn't been modified since we cached it, return the cached capabilities
|
||||||
|
if entry.modTime.Equal(currentModTime) {
|
||||||
|
return entry.capabilities, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If not cached or file was modified, read the model file to determine capabilities
|
||||||
|
capabilities := []model.Capability{}
|
||||||
|
|
||||||
|
r, err := os.Open(modelPath)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer r.Close()
|
||||||
|
|
||||||
|
f, err := ggml.Decode(r, 1024)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, ok := f.KV()[fmt.Sprintf("%s.pooling_type", f.KV().Architecture())]; ok {
|
||||||
|
capabilities = append(capabilities, model.CapabilityEmbedding)
|
||||||
|
} else {
|
||||||
|
capabilities = append(capabilities, model.CapabilityCompletion)
|
||||||
|
}
|
||||||
|
if _, ok := f.KV()[fmt.Sprintf("%s.vision.block_count", f.KV().Architecture())]; ok {
|
||||||
|
capabilities = append(capabilities, model.CapabilityVision)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Cache the capabilities with the modification time
|
||||||
|
ggufCapabilities.Store(modelPath, cacheEntry{
|
||||||
|
capabilities: capabilities,
|
||||||
|
modTime: currentModTime,
|
||||||
|
})
|
||||||
|
|
||||||
|
return capabilities, nil
|
||||||
|
}
|
||||||
211
server/cache/capabilities_test.go
vendored
Normal file
211
server/cache/capabilities_test.go
vendored
Normal file
@@ -0,0 +1,211 @@
|
|||||||
|
package cache
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"maps"
|
||||||
|
"os"
|
||||||
|
"slices"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/fs/ggml"
|
||||||
|
"github.com/ollama/ollama/template"
|
||||||
|
"github.com/ollama/ollama/types/model"
|
||||||
|
)
|
||||||
|
|
||||||
|
// testGGUF creates a temporary GGUF model file for testing with custom key-value pairs
|
||||||
|
func testGGUF(tb testing.TB, customKV ggml.KV) string {
|
||||||
|
tb.Helper()
|
||||||
|
f, err := os.CreateTemp(tb.TempDir(), "test*.gguf")
|
||||||
|
if err != nil {
|
||||||
|
tb.Fatal(err)
|
||||||
|
}
|
||||||
|
defer f.Close()
|
||||||
|
|
||||||
|
kv := ggml.KV{}
|
||||||
|
maps.Copy(kv, customKV)
|
||||||
|
|
||||||
|
tensors := []*ggml.Tensor{
|
||||||
|
{
|
||||||
|
Name: "token_embd.weight",
|
||||||
|
Kind: 0,
|
||||||
|
Shape: []uint64{1, 1},
|
||||||
|
WriterTo: bytes.NewBuffer(make([]byte, 4)),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := ggml.WriteGGUF(f, kv, tensors); err != nil {
|
||||||
|
tb.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return f.Name()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCapabilities(t *testing.T) {
|
||||||
|
ggufCapabilities.Range(func(key, value any) bool {
|
||||||
|
ggufCapabilities.Delete(key)
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
|
||||||
|
// Create test model paths
|
||||||
|
completionModelPath := testGGUF(t, ggml.KV{
|
||||||
|
"general.architecture": "llama",
|
||||||
|
})
|
||||||
|
|
||||||
|
visionModelPath := testGGUF(t, ggml.KV{
|
||||||
|
"general.architecture": "llama",
|
||||||
|
"llama.vision.block_count": uint32(1),
|
||||||
|
})
|
||||||
|
|
||||||
|
embeddingModelPath := testGGUF(t, ggml.KV{
|
||||||
|
"general.architecture": "bert",
|
||||||
|
"bert.pooling_type": uint32(1),
|
||||||
|
})
|
||||||
|
|
||||||
|
// Create templates
|
||||||
|
toolsInsertTemplate, err := template.Parse("{{ .prompt }}{{ if .tools }}{{ .tools }}{{ end }}{{ if .suffix }}{{ .suffix }}{{ end }}")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to parse template: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
chatTemplate, err := template.Parse("{{ .prompt }}")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to parse template: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
toolsTemplate, err := template.Parse("{{ .prompt }}{{ if .tools }}{{ .tools }}{{ end }}")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to parse template: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
model ModelInfo
|
||||||
|
expectedCaps []model.Capability
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "model with completion capability",
|
||||||
|
model: ModelInfo{
|
||||||
|
ModelPath: completionModelPath,
|
||||||
|
Template: chatTemplate,
|
||||||
|
},
|
||||||
|
expectedCaps: []model.Capability{model.CapabilityCompletion},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "model with completion, tools, and insert capability",
|
||||||
|
model: ModelInfo{
|
||||||
|
ModelPath: completionModelPath,
|
||||||
|
Template: toolsInsertTemplate,
|
||||||
|
},
|
||||||
|
expectedCaps: []model.Capability{model.CapabilityCompletion, model.CapabilityTools, model.CapabilityInsert},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "model with tools capability",
|
||||||
|
model: ModelInfo{
|
||||||
|
ModelPath: completionModelPath,
|
||||||
|
Template: toolsTemplate,
|
||||||
|
},
|
||||||
|
expectedCaps: []model.Capability{model.CapabilityCompletion, model.CapabilityTools},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "model with vision capability from gguf",
|
||||||
|
model: ModelInfo{
|
||||||
|
ModelPath: visionModelPath,
|
||||||
|
Template: chatTemplate,
|
||||||
|
},
|
||||||
|
expectedCaps: []model.Capability{model.CapabilityCompletion, model.CapabilityVision},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "model with vision capability from projector",
|
||||||
|
model: ModelInfo{
|
||||||
|
ModelPath: completionModelPath,
|
||||||
|
ProjectorPaths: []string{"/path/to/projector"},
|
||||||
|
Template: chatTemplate,
|
||||||
|
},
|
||||||
|
expectedCaps: []model.Capability{model.CapabilityCompletion, model.CapabilityVision},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "model with vision, tools, and insert capability",
|
||||||
|
model: ModelInfo{
|
||||||
|
ModelPath: visionModelPath,
|
||||||
|
Template: toolsInsertTemplate,
|
||||||
|
},
|
||||||
|
expectedCaps: []model.Capability{model.CapabilityCompletion, model.CapabilityVision, model.CapabilityTools, model.CapabilityInsert},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "model with embedding capability",
|
||||||
|
model: ModelInfo{
|
||||||
|
ModelPath: embeddingModelPath,
|
||||||
|
Template: chatTemplate,
|
||||||
|
},
|
||||||
|
expectedCaps: []model.Capability{model.CapabilityEmbedding},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
// First call - should read from file
|
||||||
|
caps := Capabilities(tc.model)
|
||||||
|
slices.Sort(caps)
|
||||||
|
slices.Sort(tc.expectedCaps)
|
||||||
|
if !slices.Equal(caps, tc.expectedCaps) {
|
||||||
|
t.Errorf("Expected capabilities %v, got %v", tc.expectedCaps, caps)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify caching for models that read from GGUF
|
||||||
|
if tc.model.ModelPath != "" {
|
||||||
|
// Check that entry is cached
|
||||||
|
_, ok := ggufCapabilities.Load(tc.model.ModelPath)
|
||||||
|
if !ok {
|
||||||
|
t.Error("Expected capabilities to be cached")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Second call - should use cache
|
||||||
|
caps2 := Capabilities(tc.model)
|
||||||
|
slices.Sort(caps2)
|
||||||
|
if !slices.Equal(caps, caps2) {
|
||||||
|
t.Errorf("Cached capabilities don't match original: expected %v, got %v", caps, caps2)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test cache invalidation on file modification
|
||||||
|
t.Run("cache invalidation", func(t *testing.T) {
|
||||||
|
// Use completion model for this test
|
||||||
|
info := ModelInfo{
|
||||||
|
ModelPath: completionModelPath,
|
||||||
|
Template: chatTemplate,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get initial cached entry
|
||||||
|
cached, ok := ggufCapabilities.Load(completionModelPath)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("Expected model to be cached from previous tests")
|
||||||
|
}
|
||||||
|
entry := cached.(cacheEntry)
|
||||||
|
|
||||||
|
// Modify the file's timestamp to the future
|
||||||
|
future := time.Now().Add(time.Hour)
|
||||||
|
err := os.Chtimes(completionModelPath, future, future)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to update file timestamp: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Call should re-read from file due to changed modtime
|
||||||
|
caps := Capabilities(info)
|
||||||
|
if len(caps) != 1 || caps[0] != model.CapabilityCompletion {
|
||||||
|
t.Errorf("Expected [CapabilityCompletion], got %v", caps)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check that cache was updated with new modtime
|
||||||
|
cached2, ok := ggufCapabilities.Load(completionModelPath)
|
||||||
|
if !ok {
|
||||||
|
t.Error("Expected capabilities to be cached after re-read")
|
||||||
|
}
|
||||||
|
entry2 := cached2.(cacheEntry)
|
||||||
|
if entry2.modTime.Equal(entry.modTime) {
|
||||||
|
t.Error("Expected cache entry to have updated modTime")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
@@ -23,10 +23,9 @@ import (
|
|||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
"github.com/ollama/ollama/envconfig"
|
"github.com/ollama/ollama/envconfig"
|
||||||
"github.com/ollama/ollama/fs/gguf"
|
|
||||||
"github.com/ollama/ollama/parser"
|
"github.com/ollama/ollama/parser"
|
||||||
|
"github.com/ollama/ollama/server/cache"
|
||||||
"github.com/ollama/ollama/template"
|
"github.com/ollama/ollama/template"
|
||||||
"github.com/ollama/ollama/thinking"
|
|
||||||
"github.com/ollama/ollama/types/model"
|
"github.com/ollama/ollama/types/model"
|
||||||
"github.com/ollama/ollama/version"
|
"github.com/ollama/ollama/version"
|
||||||
)
|
)
|
||||||
@@ -68,60 +67,14 @@ type Model struct {
|
|||||||
Template *template.Template
|
Template *template.Template
|
||||||
}
|
}
|
||||||
|
|
||||||
// Capabilities returns the capabilities that the model supports
|
|
||||||
func (m *Model) Capabilities() []model.Capability {
|
|
||||||
capabilities := []model.Capability{}
|
|
||||||
|
|
||||||
// Check for completion capability
|
|
||||||
f, err := gguf.Open(m.ModelPath)
|
|
||||||
if err == nil {
|
|
||||||
defer f.Close()
|
|
||||||
|
|
||||||
if f.KeyValue("pooling_type").Valid() {
|
|
||||||
capabilities = append(capabilities, model.CapabilityEmbedding)
|
|
||||||
} else {
|
|
||||||
// If no embedding is specified, we assume the model supports completion
|
|
||||||
capabilities = append(capabilities, model.CapabilityCompletion)
|
|
||||||
}
|
|
||||||
if f.KeyValue("vision.block_count").Valid() {
|
|
||||||
capabilities = append(capabilities, model.CapabilityVision)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
slog.Error("couldn't open model file", "error", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if m.Template == nil {
|
|
||||||
return capabilities
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check for tools capability
|
|
||||||
if slices.Contains(m.Template.Vars(), "tools") {
|
|
||||||
capabilities = append(capabilities, model.CapabilityTools)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check for insert capability
|
|
||||||
if slices.Contains(m.Template.Vars(), "suffix") {
|
|
||||||
capabilities = append(capabilities, model.CapabilityInsert)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check for vision capability in projector-based models
|
|
||||||
if len(m.ProjectorPaths) > 0 {
|
|
||||||
capabilities = append(capabilities, model.CapabilityVision)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check for thinking capability
|
|
||||||
openingTag, closingTag := thinking.InferTags(m.Template.Template)
|
|
||||||
if openingTag != "" && closingTag != "" {
|
|
||||||
capabilities = append(capabilities, model.CapabilityThinking)
|
|
||||||
}
|
|
||||||
|
|
||||||
return capabilities
|
|
||||||
}
|
|
||||||
|
|
||||||
// CheckCapabilities checks if the model has the specified capabilities returning an error describing
|
// CheckCapabilities checks if the model has the specified capabilities returning an error describing
|
||||||
// any missing or unknown capabilities
|
// any missing or unknown capabilities
|
||||||
func (m *Model) CheckCapabilities(want ...model.Capability) error {
|
func (m *Model) CheckCapabilities(want ...model.Capability) error {
|
||||||
available := m.Capabilities()
|
available := cache.Capabilities(cache.ModelInfo{
|
||||||
|
ModelPath: m.ModelPath,
|
||||||
|
ProjectorPaths: m.ProjectorPaths,
|
||||||
|
Template: m.Template,
|
||||||
|
})
|
||||||
var errs []error
|
var errs []error
|
||||||
|
|
||||||
// Map capabilities to their corresponding error
|
// Map capabilities to their corresponding error
|
||||||
|
|||||||
@@ -9,131 +9,6 @@ import (
|
|||||||
"github.com/ollama/ollama/types/model"
|
"github.com/ollama/ollama/types/model"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestModelCapabilities(t *testing.T) {
|
|
||||||
// Create completion model (llama architecture without vision)
|
|
||||||
completionModelPath, _ := createBinFile(t, ggml.KV{
|
|
||||||
"general.architecture": "llama",
|
|
||||||
}, []*ggml.Tensor{})
|
|
||||||
|
|
||||||
// Create vision model (llama architecture with vision block count)
|
|
||||||
visionModelPath, _ := createBinFile(t, ggml.KV{
|
|
||||||
"general.architecture": "llama",
|
|
||||||
"llama.vision.block_count": uint32(1),
|
|
||||||
}, []*ggml.Tensor{})
|
|
||||||
|
|
||||||
// Create embedding model (bert architecture with pooling type)
|
|
||||||
embeddingModelPath, _ := createBinFile(t, ggml.KV{
|
|
||||||
"general.architecture": "bert",
|
|
||||||
"bert.pooling_type": uint32(1),
|
|
||||||
}, []*ggml.Tensor{})
|
|
||||||
|
|
||||||
toolsInsertTemplate, err := template.Parse("{{ .prompt }}{{ if .tools }}{{ .tools }}{{ end }}{{ if .suffix }}{{ .suffix }}{{ end }}")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to parse template: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
chatTemplate, err := template.Parse("{{ .prompt }}")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to parse template: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
toolsTemplate, err := template.Parse("{{ .prompt }}{{ if .tools }}{{ .tools }}{{ end }}")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to parse template: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
testModels := []struct {
|
|
||||||
name string
|
|
||||||
model Model
|
|
||||||
expectedCaps []model.Capability
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "model with completion capability",
|
|
||||||
model: Model{
|
|
||||||
ModelPath: completionModelPath,
|
|
||||||
Template: chatTemplate,
|
|
||||||
},
|
|
||||||
expectedCaps: []model.Capability{model.CapabilityCompletion},
|
|
||||||
},
|
|
||||||
|
|
||||||
{
|
|
||||||
name: "model with completion, tools, and insert capability",
|
|
||||||
model: Model{
|
|
||||||
ModelPath: completionModelPath,
|
|
||||||
Template: toolsInsertTemplate,
|
|
||||||
},
|
|
||||||
expectedCaps: []model.Capability{model.CapabilityCompletion, model.CapabilityTools, model.CapabilityInsert},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "model with tools capability",
|
|
||||||
model: Model{
|
|
||||||
ModelPath: completionModelPath,
|
|
||||||
Template: toolsTemplate,
|
|
||||||
},
|
|
||||||
expectedCaps: []model.Capability{model.CapabilityCompletion, model.CapabilityTools},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "model with vision capability",
|
|
||||||
model: Model{
|
|
||||||
ModelPath: visionModelPath,
|
|
||||||
Template: chatTemplate,
|
|
||||||
},
|
|
||||||
expectedCaps: []model.Capability{model.CapabilityCompletion, model.CapabilityVision},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "model with vision, tools, and insert capability",
|
|
||||||
model: Model{
|
|
||||||
ModelPath: visionModelPath,
|
|
||||||
Template: toolsInsertTemplate,
|
|
||||||
},
|
|
||||||
expectedCaps: []model.Capability{model.CapabilityCompletion, model.CapabilityVision, model.CapabilityTools, model.CapabilityInsert},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "model with embedding capability",
|
|
||||||
model: Model{
|
|
||||||
ModelPath: embeddingModelPath,
|
|
||||||
Template: chatTemplate,
|
|
||||||
},
|
|
||||||
expectedCaps: []model.Capability{model.CapabilityEmbedding},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
// compare two slices of model.Capability regardless of order
|
|
||||||
compareCapabilities := func(a, b []model.Capability) bool {
|
|
||||||
if len(a) != len(b) {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
aCount := make(map[model.Capability]int)
|
|
||||||
for _, cap := range a {
|
|
||||||
aCount[cap]++
|
|
||||||
}
|
|
||||||
|
|
||||||
bCount := make(map[model.Capability]int)
|
|
||||||
for _, cap := range b {
|
|
||||||
bCount[cap]++
|
|
||||||
}
|
|
||||||
|
|
||||||
for cap, count := range aCount {
|
|
||||||
if bCount[cap] != count {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range testModels {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
// Test Capabilities method
|
|
||||||
caps := tt.model.Capabilities()
|
|
||||||
if !compareCapabilities(caps, tt.expectedCaps) {
|
|
||||||
t.Errorf("Expected capabilities %v, got %v", tt.expectedCaps, caps)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestModelCheckCapabilities(t *testing.T) {
|
func TestModelCheckCapabilities(t *testing.T) {
|
||||||
// Create simple model file for tests that don't depend on GGUF content
|
// Create simple model file for tests that don't depend on GGUF content
|
||||||
completionModelPath, _ := createBinFile(t, ggml.KV{
|
completionModelPath, _ := createBinFile(t, ggml.KV{
|
||||||
|
|||||||
2
server/internal/cache/blob/cache.go
vendored
2
server/internal/cache/blob/cache.go
vendored
@@ -59,7 +59,7 @@ type DiskCache struct {
|
|||||||
testHookBeforeFinalWrite func(f *os.File)
|
testHookBeforeFinalWrite func(f *os.File)
|
||||||
}
|
}
|
||||||
|
|
||||||
// PutBytes is a convenience function for c.Put(d, strings.NewReader(s), int64(len(s))).
|
// PutString is a convenience function for c.Put(d, strings.NewReader(s), int64(len(s))).
|
||||||
func PutBytes[S string | []byte](c *DiskCache, d Digest, data S) error {
|
func PutBytes[S string | []byte](c *DiskCache, d Digest, data S) error {
|
||||||
return c.Put(d, bytes.NewReader([]byte(data)), int64(len(data)))
|
return c.Put(d, bytes.NewReader([]byte(data)), int64(len(data)))
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -34,6 +34,7 @@ import (
|
|||||||
"github.com/ollama/ollama/llm"
|
"github.com/ollama/ollama/llm"
|
||||||
"github.com/ollama/ollama/logutil"
|
"github.com/ollama/ollama/logutil"
|
||||||
"github.com/ollama/ollama/openai"
|
"github.com/ollama/ollama/openai"
|
||||||
|
"github.com/ollama/ollama/server/cache"
|
||||||
"github.com/ollama/ollama/server/internal/client/ollama"
|
"github.com/ollama/ollama/server/internal/client/ollama"
|
||||||
"github.com/ollama/ollama/server/internal/registry"
|
"github.com/ollama/ollama/server/internal/registry"
|
||||||
"github.com/ollama/ollama/template"
|
"github.com/ollama/ollama/template"
|
||||||
@@ -819,13 +820,17 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
resp := &api.ShowResponse{
|
resp := &api.ShowResponse{
|
||||||
License: strings.Join(m.License, "\n"),
|
License: strings.Join(m.License, "\n"),
|
||||||
System: m.System,
|
System: m.System,
|
||||||
Template: m.Template.String(),
|
Template: m.Template.String(),
|
||||||
Details: modelDetails,
|
Details: modelDetails,
|
||||||
Messages: msgs,
|
Messages: msgs,
|
||||||
Capabilities: m.Capabilities(),
|
Capabilities: cache.Capabilities(cache.ModelInfo{
|
||||||
ModifiedAt: manifest.fi.ModTime(),
|
ModelPath: m.ModelPath,
|
||||||
|
Template: m.Template,
|
||||||
|
ProjectorPaths: m.ProjectorPaths,
|
||||||
|
}),
|
||||||
|
ModifiedAt: manifest.fi.ModTime(),
|
||||||
}
|
}
|
||||||
|
|
||||||
var params []string
|
var params []string
|
||||||
|
|||||||
@@ -18,8 +18,9 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type Parser struct {
|
type Parser struct {
|
||||||
tag string
|
tag string
|
||||||
tools []api.Tool
|
names []string
|
||||||
|
properties []string
|
||||||
|
|
||||||
state toolsState
|
state toolsState
|
||||||
buffer []byte
|
buffer []byte
|
||||||
@@ -33,10 +34,15 @@ func NewParser(tmpl *template.Template, tools []api.Tool) *Parser {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func NewParserWithTag(tools []api.Tool, tag string) *Parser {
|
func NewParserWithTag(tools []api.Tool, tag string) *Parser {
|
||||||
return &Parser{
|
var p Parser
|
||||||
tag: tag,
|
for _, t := range tools {
|
||||||
tools: tools,
|
p.names = append(p.names, t.Function.Name)
|
||||||
|
for r := range t.Function.Parameters.Properties {
|
||||||
|
p.properties = append(p.properties, r)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
p.tag = tag
|
||||||
|
return &p
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add processes a string input to parse tool calls and content that
|
// Add processes a string input to parse tool calls and content that
|
||||||
@@ -115,40 +121,36 @@ func (p *Parser) findTag() (int, bool) {
|
|||||||
// parseToolCall finds the next complete tool call in the buffer
|
// parseToolCall finds the next complete tool call in the buffer
|
||||||
// incrementing n and advancing the buffer.
|
// incrementing n and advancing the buffer.
|
||||||
func (p *Parser) parseToolCall() *api.ToolCall {
|
func (p *Parser) parseToolCall() *api.ToolCall {
|
||||||
var tool *api.Tool
|
var name string
|
||||||
|
var args map[string]any
|
||||||
var end int = len(p.buffer)
|
var end int = len(p.buffer)
|
||||||
var i int
|
|
||||||
|
|
||||||
// find tool name
|
// find tool name
|
||||||
for _, t := range p.tools {
|
var i int
|
||||||
n := t.Function.Name
|
for _, n := range p.names {
|
||||||
if i = bytes.Index(p.buffer, []byte(n)); i != -1 {
|
if i = bytes.Index(p.buffer, []byte(n)); i != -1 {
|
||||||
if i+len(n) < end {
|
if i+len(n) < end {
|
||||||
tool = &t
|
name = n
|
||||||
end = i + len(n)
|
end = i + len(n)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if tool == nil {
|
if name == "" {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// only look for arguments if the tool has parameters
|
if args, i = p.findArguments(); args == nil {
|
||||||
args := map[string]any{}
|
return nil
|
||||||
if len(tool.Function.Parameters.Properties) > 0 {
|
}
|
||||||
if args, i = p.findArguments(*tool); args == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if i > end {
|
if i > end {
|
||||||
end = i
|
end = i
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
tc := &api.ToolCall{
|
tc := &api.ToolCall{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: tool.Function.Name,
|
Name: name,
|
||||||
Arguments: args,
|
Arguments: args,
|
||||||
Index: p.n,
|
Index: p.n,
|
||||||
},
|
},
|
||||||
@@ -160,17 +162,13 @@ func (p *Parser) parseToolCall() *api.ToolCall {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// findArguments returns the first object that appears to be
|
// findArguments returns the first object that appears to be
|
||||||
// arguments for the provided tool, returning nil
|
// arguments and the position where the arguments end, returning nil and 0 if
|
||||||
func (p *Parser) findArguments(tool api.Tool) (map[string]any, int) {
|
// an invalid JSON object or non-arguments object is found first
|
||||||
|
func (p *Parser) findArguments() (map[string]any, int) {
|
||||||
if len(p.buffer) == 0 {
|
if len(p.buffer) == 0 {
|
||||||
return nil, 0
|
return nil, 0
|
||||||
}
|
}
|
||||||
|
|
||||||
// no arguments to parse
|
|
||||||
if len(tool.Function.Parameters.Properties) == 0 {
|
|
||||||
return nil, 0
|
|
||||||
}
|
|
||||||
|
|
||||||
var braces int
|
var braces int
|
||||||
var start int = -1
|
var start int = -1
|
||||||
var end int
|
var end int
|
||||||
@@ -186,13 +184,11 @@ func (p *Parser) findArguments(tool api.Tool) (map[string]any, int) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if c == '}' {
|
if c == '}' {
|
||||||
if start != -1 {
|
braces--
|
||||||
braces--
|
if braces == 0 && start != -1 {
|
||||||
if braces == 0 {
|
end = i + 1
|
||||||
end = i + 1
|
object = p.buffer[start:end]
|
||||||
object = p.buffer[start:end]
|
break
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -210,27 +206,24 @@ func (p *Parser) findArguments(tool api.Tool) (map[string]any, int) {
|
|||||||
|
|
||||||
var find func(obj any) map[string]any
|
var find func(obj any) map[string]any
|
||||||
find = func(obj any) map[string]any {
|
find = func(obj any) map[string]any {
|
||||||
switch obj := obj.(type) {
|
switch v := obj.(type) {
|
||||||
case map[string]any:
|
case map[string]any:
|
||||||
found := true
|
// check if the object keys are valid tool properties
|
||||||
for key := range obj {
|
// TODO (jmorganca): check only sets of properties that
|
||||||
if _, exists := tool.Function.Parameters.Properties[key]; !exists {
|
// go together instead of the entire set
|
||||||
found = false
|
for _, prop := range p.properties {
|
||||||
break
|
if _, exists := v[prop]; exists {
|
||||||
|
return v
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if found {
|
for _, value := range v {
|
||||||
return obj
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, value := range obj {
|
|
||||||
if result := find(value); result != nil {
|
if result := find(value); result != nil {
|
||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
case []any:
|
case []any:
|
||||||
for _, item := range obj {
|
for _, item := range v {
|
||||||
if result := find(item); result != nil {
|
if result := find(item); result != nil {
|
||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -104,13 +104,6 @@ func TestParser(t *testing.T) {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
|
||||||
Type: "function",
|
|
||||||
Function: api.ToolFunction{
|
|
||||||
Name: "say_hello",
|
|
||||||
Description: "Say hello",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
@@ -151,20 +144,6 @@ func TestParser(t *testing.T) {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
|
||||||
name: "invalid arguments",
|
|
||||||
inputs: []string{`<tool_call>{"name": "get_conditions", "arguments": {"city": "San Francisco"}}</tool_call>`},
|
|
||||||
content: "",
|
|
||||||
tmpl: qwen,
|
|
||||||
calls: nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "missing args",
|
|
||||||
inputs: []string{`<tool_call>{"name": "get_conditions"}</tool_call>`},
|
|
||||||
content: "",
|
|
||||||
tmpl: qwen,
|
|
||||||
calls: nil,
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
name: "text before tool call",
|
name: "text before tool call",
|
||||||
inputs: []string{`Let me check the weather. <tool_call>{"name": "get_temperature", "arguments": {"city": "New York"}}</tool_call>`},
|
inputs: []string{`Let me check the weather. <tool_call>{"name": "get_temperature", "arguments": {"city": "New York"}}</tool_call>`},
|
||||||
@@ -182,28 +161,6 @@ func TestParser(t *testing.T) {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
|
||||||
name: "qwen no args tool call",
|
|
||||||
inputs: []string{`Let me say hello to the user. I'll use the say_hello tool <tool_call>{"name": "say_hello"}</tool_call>`},
|
|
||||||
content: "Let me say hello to the user. I'll use the say_hello tool ",
|
|
||||||
tmpl: qwen,
|
|
||||||
calls: []api.ToolCall{
|
|
||||||
{
|
|
||||||
Function: api.ToolCallFunction{
|
|
||||||
Index: 0,
|
|
||||||
Name: "say_hello",
|
|
||||||
Arguments: api.ToolCallFunctionArguments{},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "qwen no args with text",
|
|
||||||
inputs: []string{"Let me say hello to the user. I'll use the say_hello tool. "},
|
|
||||||
content: "Let me say hello to the user. I'll use the say_hello tool. ",
|
|
||||||
tmpl: qwen,
|
|
||||||
calls: nil,
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
name: "two tool calls in a list",
|
name: "two tool calls in a list",
|
||||||
inputs: []string{`[TOOL_CALLS] [{"name": "get_temperature", "arguments": {"city": "London", "format": "fahrenheit"}}, {"name": "get_conditions", "arguments": {"location": "Tokyo"}}][/TOOL_CALLS]`},
|
inputs: []string{`[TOOL_CALLS] [{"name": "get_temperature", "arguments": {"city": "London", "format": "fahrenheit"}}, {"name": "get_conditions", "arguments": {"location": "Tokyo"}}][/TOOL_CALLS]`},
|
||||||
@@ -232,7 +189,7 @@ func TestParser(t *testing.T) {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "qwen two tool calls",
|
name: "two tool calls",
|
||||||
inputs: []string{`Okay, let's call both tools! <tool_call>{"name": "get_temperature", "arguments": {"city": "London", "format": "fahrenheit"}}</tool_call><tool_call>{"name": "get_conditions", "arguments": {"location": "Tokyo"}}</tool_call>`},
|
inputs: []string{`Okay, let's call both tools! <tool_call>{"name": "get_temperature", "arguments": {"city": "London", "format": "fahrenheit"}}</tool_call><tool_call>{"name": "get_conditions", "arguments": {"location": "Tokyo"}}</tool_call>`},
|
||||||
content: "Okay, let's call both tools! ",
|
content: "Okay, let's call both tools! ",
|
||||||
tmpl: qwen,
|
tmpl: qwen,
|
||||||
@@ -258,30 +215,6 @@ func TestParser(t *testing.T) {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
|
||||||
name: "qwen two tool calls one with no args",
|
|
||||||
inputs: []string{`Let me check the weather. <tool_call>{"name": "say_hello"}</tool_call><tool_call>{"name": "get_conditions", "arguments": {"location": "Tokyo"}}`},
|
|
||||||
content: "Let me check the weather. ",
|
|
||||||
tmpl: qwen,
|
|
||||||
calls: []api.ToolCall{
|
|
||||||
{
|
|
||||||
Function: api.ToolCallFunction{
|
|
||||||
Index: 0,
|
|
||||||
Name: "say_hello",
|
|
||||||
Arguments: api.ToolCallFunctionArguments{},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Function: api.ToolCallFunction{
|
|
||||||
Index: 1,
|
|
||||||
Name: "get_conditions",
|
|
||||||
Arguments: api.ToolCallFunctionArguments{
|
|
||||||
"location": "Tokyo",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
name: "deepseek",
|
name: "deepseek",
|
||||||
inputs: []string{"<think>Wait, I need to call a tool</think><|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>get_temperature\n```json\n{\"city\": \"Tokyo\"}\n```<|tool▁call▁end|><|tool▁calls▁end|><|end▁of▁sentence|>"},
|
inputs: []string{"<think>Wait, I need to call a tool</think><|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>get_temperature\n```json\n{\"city\": \"Tokyo\"}\n```<|tool▁call▁end|><|tool▁calls▁end|><|end▁of▁sentence|>"},
|
||||||
@@ -405,52 +338,6 @@ func TestParser(t *testing.T) {
|
|||||||
content: "for { fmt.Println(\"hello\") }",
|
content: "for { fmt.Println(\"hello\") }",
|
||||||
tmpl: json,
|
tmpl: json,
|
||||||
},
|
},
|
||||||
{
|
|
||||||
name: "json no args tool call",
|
|
||||||
inputs: []string{
|
|
||||||
"{\"name\": \"say_hello\"}",
|
|
||||||
},
|
|
||||||
content: "",
|
|
||||||
tmpl: json,
|
|
||||||
calls: []api.ToolCall{
|
|
||||||
{
|
|
||||||
Function: api.ToolCallFunction{
|
|
||||||
Index: 0,
|
|
||||||
Name: "say_hello",
|
|
||||||
Arguments: api.ToolCallFunctionArguments{},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "json no args no tool call",
|
|
||||||
inputs: []string{
|
|
||||||
"I'll use the say_hello tool to say hello to the user.",
|
|
||||||
},
|
|
||||||
content: "I'll use the say_hello tool to say hello to the user.",
|
|
||||||
tmpl: json,
|
|
||||||
calls: nil,
|
|
||||||
},
|
|
||||||
|
|
||||||
// TODO (jmorganca): this is a false positive, we should
|
|
||||||
// not be parsing this as a tool call
|
|
||||||
{
|
|
||||||
name: "json no args false positive",
|
|
||||||
inputs: []string{
|
|
||||||
`{say_hello!!!}`,
|
|
||||||
},
|
|
||||||
content: "",
|
|
||||||
tmpl: json,
|
|
||||||
calls: []api.ToolCall{
|
|
||||||
{
|
|
||||||
Function: api.ToolCallFunction{
|
|
||||||
Index: 0,
|
|
||||||
Name: "say_hello",
|
|
||||||
Arguments: api.ToolCallFunctionArguments{},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
name: "list multiple",
|
name: "list multiple",
|
||||||
inputs: []string{
|
inputs: []string{
|
||||||
@@ -493,30 +380,6 @@ func TestParser(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "list partial",
|
name: "list partial",
|
||||||
inputs: []string{
|
|
||||||
"[{",
|
|
||||||
"\"name\": \"get_conditions\", ",
|
|
||||||
"\"arguments\": {",
|
|
||||||
"\"location\": \"Tokyo\"",
|
|
||||||
"}",
|
|
||||||
"}",
|
|
||||||
},
|
|
||||||
content: "",
|
|
||||||
tmpl: list,
|
|
||||||
calls: []api.ToolCall{
|
|
||||||
{
|
|
||||||
Function: api.ToolCallFunction{
|
|
||||||
Index: 0,
|
|
||||||
Name: "get_conditions",
|
|
||||||
Arguments: api.ToolCallFunctionArguments{
|
|
||||||
"location": "Tokyo",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "list invalid",
|
|
||||||
inputs: []string{
|
inputs: []string{
|
||||||
"[",
|
"[",
|
||||||
"{",
|
"{",
|
||||||
@@ -530,33 +393,6 @@ func TestParser(t *testing.T) {
|
|||||||
tmpl: list,
|
tmpl: list,
|
||||||
calls: nil,
|
calls: nil,
|
||||||
},
|
},
|
||||||
{
|
|
||||||
name: "list trailing ]",
|
|
||||||
inputs: []string{
|
|
||||||
"[",
|
|
||||||
"{",
|
|
||||||
"\"name\": \"get_conditions\", ",
|
|
||||||
"\"arguments\": {",
|
|
||||||
"\"location\": \"Tokyo\"",
|
|
||||||
"}",
|
|
||||||
"}",
|
|
||||||
"]",
|
|
||||||
"]",
|
|
||||||
},
|
|
||||||
content: "",
|
|
||||||
tmpl: list,
|
|
||||||
calls: []api.ToolCall{
|
|
||||||
{
|
|
||||||
Function: api.ToolCallFunction{
|
|
||||||
Index: 0,
|
|
||||||
Name: "get_conditions",
|
|
||||||
Arguments: api.ToolCallFunctionArguments{
|
|
||||||
"location": "Tokyo",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
name: "list not a tool call",
|
name: "list not a tool call",
|
||||||
inputs: []string{
|
inputs: []string{
|
||||||
@@ -568,26 +404,6 @@ func TestParser(t *testing.T) {
|
|||||||
tmpl: list,
|
tmpl: list,
|
||||||
calls: nil,
|
calls: nil,
|
||||||
},
|
},
|
||||||
{
|
|
||||||
name: "list with no arguments",
|
|
||||||
inputs: []string{
|
|
||||||
"[",
|
|
||||||
"{",
|
|
||||||
"\"name\": \"say_hello\"",
|
|
||||||
"}",
|
|
||||||
},
|
|
||||||
content: "",
|
|
||||||
tmpl: list,
|
|
||||||
calls: []api.ToolCall{
|
|
||||||
{
|
|
||||||
Function: api.ToolCallFunction{
|
|
||||||
Index: 0,
|
|
||||||
Name: "say_hello",
|
|
||||||
Arguments: api.ToolCallFunctionArguments{},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
@@ -884,75 +700,25 @@ func TestFindTag(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestFindArguments(t *testing.T) {
|
func TestFindArguments(t *testing.T) {
|
||||||
tool := api.Tool{
|
|
||||||
Type: "function",
|
|
||||||
Function: api.ToolFunction{
|
|
||||||
Name: "get_temperature",
|
|
||||||
Description: "Retrieve the temperature for a given location",
|
|
||||||
Parameters: struct {
|
|
||||||
Type string `json:"type"`
|
|
||||||
Defs any `json:"$defs,omitempty"`
|
|
||||||
Items any `json:"items,omitempty"`
|
|
||||||
Required []string `json:"required"`
|
|
||||||
Properties map[string]struct {
|
|
||||||
Type api.PropertyType `json:"type"`
|
|
||||||
Items any `json:"items,omitempty"`
|
|
||||||
Description string `json:"description"`
|
|
||||||
Enum []any `json:"enum,omitempty"`
|
|
||||||
} `json:"properties"`
|
|
||||||
}{
|
|
||||||
Type: "object",
|
|
||||||
Properties: map[string]struct {
|
|
||||||
Type api.PropertyType `json:"type"`
|
|
||||||
Items any `json:"items,omitempty"`
|
|
||||||
Description string `json:"description"`
|
|
||||||
Enum []any `json:"enum,omitempty"`
|
|
||||||
}{
|
|
||||||
"format": {
|
|
||||||
Type: api.PropertyType{"string"},
|
|
||||||
Description: "The format to return the temperature in",
|
|
||||||
Enum: []any{"fahrenheit", "celsius"},
|
|
||||||
},
|
|
||||||
"location": {
|
|
||||||
Type: api.PropertyType{"string"},
|
|
||||||
Description: "The location to get the temperature for",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
tool2 := api.Tool{
|
|
||||||
Type: "function",
|
|
||||||
Function: api.ToolFunction{
|
|
||||||
Name: "say_hello",
|
|
||||||
Description: "Say hello to the user",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
buffer []byte
|
buffer []byte
|
||||||
want map[string]any
|
want map[string]any
|
||||||
tool api.Tool
|
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "empty string",
|
name: "empty string",
|
||||||
buffer: []byte{},
|
buffer: []byte{},
|
||||||
want: nil,
|
want: nil,
|
||||||
tool: tool,
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "whitespace only",
|
name: "whitespace only",
|
||||||
buffer: []byte(" \n\t "),
|
buffer: []byte(" \n\t "),
|
||||||
want: nil,
|
want: nil,
|
||||||
tool: tool,
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "unbalanced braces - missing closing",
|
name: "unbalanced braces - missing closing",
|
||||||
buffer: []byte(`{"format": "fahrenheit", "location": "San Francisco"`),
|
buffer: []byte(`{"format": "fahrenheit", "location": "San Francisco"`),
|
||||||
want: nil,
|
want: nil,
|
||||||
tool: tool,
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "unbalanced braces - extra closing",
|
name: "unbalanced braces - extra closing",
|
||||||
@@ -960,13 +726,11 @@ func TestFindArguments(t *testing.T) {
|
|||||||
want: map[string]any{
|
want: map[string]any{
|
||||||
"format": "fahrenheit",
|
"format": "fahrenheit",
|
||||||
},
|
},
|
||||||
tool: tool,
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "invalid JSON",
|
name: "invalid JSON",
|
||||||
buffer: []byte(`{format: fahrenheit, location: "San Francisco"}`),
|
buffer: []byte(`{format: fahrenheit, location: "San Francisco"}`),
|
||||||
want: nil,
|
want: nil,
|
||||||
tool: tool,
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "valid json",
|
name: "valid json",
|
||||||
@@ -975,7 +739,6 @@ func TestFindArguments(t *testing.T) {
|
|||||||
"format": "fahrenheit",
|
"format": "fahrenheit",
|
||||||
"location": "San Francisco, CA",
|
"location": "San Francisco, CA",
|
||||||
},
|
},
|
||||||
tool: tool,
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "valid arguments with special tokens",
|
name: "valid arguments with special tokens",
|
||||||
@@ -984,7 +747,6 @@ func TestFindArguments(t *testing.T) {
|
|||||||
"format": "fahrenheit",
|
"format": "fahrenheit",
|
||||||
"location": "San Francisco, CA",
|
"location": "San Francisco, CA",
|
||||||
},
|
},
|
||||||
tool: tool,
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "valid arguments in array",
|
name: "valid arguments in array",
|
||||||
@@ -993,7 +755,6 @@ func TestFindArguments(t *testing.T) {
|
|||||||
"format": "fahrenheit",
|
"format": "fahrenheit",
|
||||||
"location": "San Francisco, CA",
|
"location": "San Francisco, CA",
|
||||||
},
|
},
|
||||||
tool: tool,
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "nested deep",
|
name: "nested deep",
|
||||||
@@ -1002,49 +763,39 @@ func TestFindArguments(t *testing.T) {
|
|||||||
"format": "fahrenheit",
|
"format": "fahrenheit",
|
||||||
"location": "San Francisco, CA",
|
"location": "San Francisco, CA",
|
||||||
},
|
},
|
||||||
tool: tool,
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "one arg",
|
name: "one arg",
|
||||||
buffer: []byte(`get_temperature({"location": "San Francisco, CA"})`),
|
buffer: []byte(`get_weather({"location": "San Francisco, CA"})`),
|
||||||
want: map[string]any{
|
want: map[string]any{
|
||||||
"location": "San Francisco, CA",
|
"location": "San Francisco, CA",
|
||||||
},
|
},
|
||||||
tool: tool,
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "two args",
|
name: "two args",
|
||||||
buffer: []byte(`[{"name": "get_temperature", "arguments": {"location": "San Francisco, CA", "format": "fahrenheit"}}, {"name": "get_weather", "arguments": {"location": "San Francisco, CA", "format": "fahrenheit"}}]`),
|
buffer: []byte(`[{"name": "get_weather", "arguments": {"location": "San Francisco, CA", "format": "fahrenheit"}}, {"name": "get_weather", "arguments": {"location": "San Francisco, CA", "format": "fahrenheit"}}]`),
|
||||||
want: map[string]any{
|
want: map[string]any{
|
||||||
"location": "San Francisco, CA",
|
"location": "San Francisco, CA",
|
||||||
"format": "fahrenheit",
|
"format": "fahrenheit",
|
||||||
},
|
},
|
||||||
tool: tool,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "no args",
|
|
||||||
buffer: []byte(`{"name": "say_hello"}`),
|
|
||||||
want: nil,
|
|
||||||
tool: tool2,
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "deepseek",
|
name: "deepseek",
|
||||||
buffer: []byte("<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>get_temperature\n```json\n{\"location\": \"Tokyo\"}\n```<|tool▁call▁end|><|tool▁calls▁end|><|end▁of▁sentence|>"),
|
buffer: []byte("<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>get_current_weather\n```json\n{\"location\": \"Tokyo\"}\n```<|tool▁call▁end|><|tool▁calls▁end|><|end▁of▁sentence|>"),
|
||||||
want: map[string]any{
|
want: map[string]any{
|
||||||
"location": "Tokyo",
|
"location": "Tokyo",
|
||||||
},
|
},
|
||||||
tool: tool,
|
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
parser := &Parser{
|
parser := &Parser{
|
||||||
buffer: tt.buffer,
|
buffer: tt.buffer,
|
||||||
tools: []api.Tool{tool, tool2},
|
properties: []string{"format", "location"},
|
||||||
}
|
}
|
||||||
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
got, _ := parser.findArguments(tool)
|
got, _ := parser.findArguments()
|
||||||
|
|
||||||
if diff := cmp.Diff(got, tt.want); diff != "" {
|
if diff := cmp.Diff(got, tt.want); diff != "" {
|
||||||
t.Errorf("scanArguments() args mismatch (-got +want):\n%s", diff)
|
t.Errorf("scanArguments() args mismatch (-got +want):\n%s", diff)
|
||||||
|
|||||||
Reference in New Issue
Block a user