diff --git a/.github/workflows/release.yaml b/.github/workflows/release.yaml index 40871e644..902fa9ccc 100644 --- a/.github/workflows/release.yaml +++ b/.github/workflows/release.yaml @@ -67,12 +67,21 @@ jobs: install: https://developer.download.nvidia.com/compute/cuda/12.8.0/local_installers/cuda_12.8.0_571.96_windows.exe cuda-version: '12.8' flags: '' + runner_dir: 'cuda_v12' + - os: windows + arch: amd64 + preset: 'CUDA 13' + install: https://developer.download.nvidia.com/compute/cuda/13.0.0/local_installers/cuda_13.0.0_windows.exe + cuda-version: '13.0' + flags: '' + runner_dir: 'cuda_v13' - os: windows arch: amd64 preset: 'ROCm 6' install: https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-24.Q4-WinSvr2022-For-HIP.exe rocm-version: '6.2' flags: '-DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_C_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma" -DCMAKE_CXX_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma"' + runner_dir: '' runs-on: ${{ matrix.arch == 'arm64' && format('{0}-{1}', matrix.os, matrix.arch) || matrix.os }} environment: release env: @@ -138,7 +147,7 @@ jobs: run: | Import-Module 'C:\Program Files\Microsoft Visual Studio\2022\Enterprise\Common7\Tools\Microsoft.VisualStudio.DevShell.dll' Enter-VsDevShell -VsInstallPath 'C:\Program Files\Microsoft Visual Studio\2022\Enterprise' -SkipAutomaticLocation -DevCmdArguments '-arch=x64 -no_logo' - cmake --preset "${{ matrix.preset }}" ${{ matrix.flags }} + cmake --preset "${{ matrix.preset }}" ${{ matrix.flags }} -DOLLAMA_RUNNER_DIR="${{ matrix.runner_dir }}" cmake --build --parallel --preset "${{ matrix.preset }}" cmake --install build --component "${{ startsWith(matrix.preset, 'CUDA ') && 'CUDA' || startsWith(matrix.preset, 'ROCm ') && 'HIP' || 'CPU' }}" --strip --parallel 8 env: @@ -232,7 +241,7 @@ jobs: case "$COMPONENT" in bin/ollama) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;; lib/ollama/*.so*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;; - lib/ollama/cuda_sbsa) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;; + lib/ollama/cuda_v*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;; lib/ollama/cuda_jetpack5) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-jetpack5.tar.in ;; lib/ollama/cuda_jetpack6) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-jetpack6.tar.in ;; lib/ollama/rocm) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-rocm.tar.in ;; diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 4d8cf773c..a10ad37a9 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -46,7 +46,7 @@ jobs: include: - preset: CPU - preset: CUDA - container: nvidia/cuda:12.8.1-devel-ubuntu22.04 + container: nvidia/cuda:13.0.0-devel-ubuntu22.04 flags: '-DCMAKE_CUDA_ARCHITECTURES=87' - preset: ROCm container: rocm/dev-ubuntu-22.04:6.1.2 @@ -78,7 +78,7 @@ jobs: include: - preset: CPU - preset: CUDA - install: https://developer.download.nvidia.com/compute/cuda/12.8.0/local_installers/cuda_12.8.0_571.96_windows.exe + install: https://developer.download.nvidia.com/compute/cuda/13.0.0/local_installers/cuda_13.0.0_windows.exe flags: '-DCMAKE_CUDA_ARCHITECTURES=80' - preset: ROCm install: https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-24.Q4-WinSvr2022-For-HIP.exe @@ -102,7 +102,7 @@ jobs: $ErrorActionPreference = "Stop" if ("${{ steps.cache-install.outputs.cache-hit }}" 
-ne 'true') { Invoke-WebRequest -Uri "${{ matrix.install }}" -OutFile "install.exe" - Start-Process -FilePath .\install.exe -ArgumentList (@("-s", "cudart_12.8", "nvcc_12.8", "cublas_12.8", "cublas_dev_12.8")) -NoNewWindow -Wait + Start-Process -FilePath .\install.exe -ArgumentList (@("-s", "cudart_13.0", "nvcc_13.0", "cublas_13.0", "cublas_dev_13.0")) -NoNewWindow -Wait } $cudaPath = (Resolve-Path "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\*").path diff --git a/CMakeLists.txt b/CMakeLists.txt index af93af460..8503aa80e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -25,7 +25,7 @@ set(GGML_LLAMAFILE ON) set(GGML_CUDA_PEER_MAX_BATCH_SIZE 128) set(GGML_CUDA_GRAPHS ON) set(GGML_CUDA_FA ON) -set(GGML_CUDA_COMPRESSION_MODE default) +set(GGML_CUDA_COMPRESSION_MODE size) if((CMAKE_OSX_ARCHITECTURES AND NOT CMAKE_OSX_ARCHITECTURES MATCHES "arm64") OR (NOT CMAKE_OSX_ARCHITECTURES AND NOT CMAKE_SYSTEM_PROCESSOR MATCHES "arm|aarch64|ARM64|ARMv[0-9]+")) @@ -38,7 +38,7 @@ if (CMAKE_OSX_ARCHITECTURES MATCHES "x86_64") endif() set(OLLAMA_BUILD_DIR ${CMAKE_BINARY_DIR}/lib/ollama) -set(OLLAMA_INSTALL_DIR ${CMAKE_INSTALL_PREFIX}/lib/ollama) +set(OLLAMA_INSTALL_DIR ${CMAKE_INSTALL_PREFIX}/lib/ollama/${OLLAMA_RUNNER_DIR}) set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${OLLAMA_BUILD_DIR}) set(CMAKE_RUNTIME_OUTPUT_DIRECTORY_DEBUG ${OLLAMA_BUILD_DIR}) @@ -81,7 +81,7 @@ if(CMAKE_CUDA_COMPILER) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-cuda) install(TARGETS ggml-cuda RUNTIME_DEPENDENCIES - DIRECTORIES ${CUDAToolkit_BIN_DIR} ${CUDAToolkit_LIBRARY_DIR} + DIRECTORIES ${CUDAToolkit_BIN_DIR} ${CUDAToolkit_BIN_DIR}/x64 ${CUDAToolkit_LIBRARY_DIR} PRE_INCLUDE_REGEXES cublas cublasLt cudart PRE_EXCLUDE_REGEXES ".*" RUNTIME DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT CUDA diff --git a/CMakePresets.json b/CMakePresets.json index 82da950bc..79fa2e7da 100644 --- a/CMakePresets.json +++ b/CMakePresets.json @@ -18,6 +18,14 @@ "name": "CUDA", "inherits": [ "Default" ] }, + { + "name": "CUDA 11", + "inherits": [ "CUDA" ], + "cacheVariables": { + "CMAKE_CUDA_ARCHITECTURES": "50-virtual;60-virtual;61-virtual;70-virtual;75-virtual;80-virtual;86-virtual;87-virtual;89-virtual;90-virtual", + "CMAKE_CUDA_FLAGS": "-Wno-deprecated-gpu-targets -t 2" + } + }, { "name": "CUDA 12", "inherits": [ "CUDA" ], @@ -26,6 +34,14 @@ "CMAKE_CUDA_FLAGS": "-Wno-deprecated-gpu-targets -t 2" } }, + { + "name": "CUDA 13", + "inherits": [ "CUDA" ], + "cacheVariables": { + "CMAKE_CUDA_ARCHITECTURES": "75-virtual;80-virtual;86-virtual;87-virtual;89-virtual;90-virtual;90a-virtual;100-virtual;110-virtual;120-virtual;121-virtual", + "CMAKE_CUDA_FLAGS": "-t 2" + } + }, { "name": "JetPack 5", "inherits": [ "CUDA" ], @@ -76,11 +92,21 @@ "configurePreset": "CUDA", "targets": [ "ggml-cuda" ] }, + { + "name": "CUDA 11", + "inherits": [ "CUDA" ], + "configurePreset": "CUDA 11" + }, { "name": "CUDA 12", "inherits": [ "CUDA" ], "configurePreset": "CUDA 12" }, + { + "name": "CUDA 13", + "inherits": [ "CUDA" ], + "configurePreset": "CUDA 13" + }, { "name": "JetPack 5", "inherits": [ "CUDA" ], diff --git a/Dockerfile b/Dockerfile index 83e2f89e4..58b946592 100644 --- a/Dockerfile +++ b/Dockerfile @@ -50,15 +50,35 @@ RUN --mount=type=cache,target=/root/.ccache \ && cmake --build --parallel --preset 'CPU' \ && cmake --install build --component CPU --strip --parallel 8 +FROM base AS cuda-11 +ARG CUDA11VERSION=11.8 +RUN dnf install -y cuda-toolkit-${CUDA11VERSION//./-} +ENV PATH=/usr/local/cuda-11/bin:$PATH +RUN 
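Note on the layout change above: `-DOLLAMA_RUNNER_DIR` makes each GPU backend install into its own subdirectory of `lib/ollama` (`cuda_v11`, `cuda_v12`, `cuda_v13`), while builds configured with an empty value keep the flat layout. A minimal sketch of how a loader could resolve that layout; the helper name and fallback behaviour are assumptions for illustration, not the lookup ollama actually performs:

```go
package main

import (
	"fmt"
	"os"
	"path/filepath"
)

// runnerLibDir is a hypothetical helper showing the layout produced by
// OLLAMA_RUNNER_DIR: each CUDA generation installs into lib/ollama/cuda_vNN
// instead of a flat lib/ollama (or the old cuda_sbsa directory).
func runnerLibDir(libRoot, cudaVariant string) string {
	if cudaVariant == "" {
		return libRoot // builds with an empty OLLAMA_RUNNER_DIR keep the flat layout
	}
	dir := filepath.Join(libRoot, "cuda_"+cudaVariant) // e.g. /usr/lib/ollama/cuda_v13
	if _, err := os.Stat(dir); err != nil {
		return libRoot // fall back if that runner was not shipped
	}
	return dir
}

func main() {
	fmt.Println(runnerLibDir("/usr/lib/ollama", "v13"))
}
```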
--mount=type=cache,target=/root/.ccache \ + cmake --preset 'CUDA 11' -DOLLAMA_RUNNER_DIR="cuda_v11" \ + && cmake --build --parallel --preset 'CUDA 11' \ + && cmake --install build --component CUDA --strip --parallel 8 + FROM base AS cuda-12 ARG CUDA12VERSION=12.8 RUN dnf install -y cuda-toolkit-${CUDA12VERSION//./-} ENV PATH=/usr/local/cuda-12/bin:$PATH RUN --mount=type=cache,target=/root/.ccache \ - cmake --preset 'CUDA 12' \ + cmake --preset 'CUDA 12' -DOLLAMA_RUNNER_DIR="cuda_v12"\ && cmake --build --parallel --preset 'CUDA 12' \ && cmake --install build --component CUDA --strip --parallel 8 + +FROM base AS cuda-13 +ARG CUDA13VERSION=13.0 +RUN dnf install -y cuda-toolkit-${CUDA13VERSION//./-} +ENV PATH=/usr/local/cuda-13/bin:$PATH +RUN --mount=type=cache,target=/root/.ccache \ + cmake --preset 'CUDA 13' -DOLLAMA_RUNNER_DIR="cuda_v13" \ + && cmake --build --parallel --preset 'CUDA 13' \ + && cmake --install build --component CUDA --strip --parallel 8 + + FROM base AS rocm-6 ENV PATH=/opt/rocm/hcc/bin:/opt/rocm/hip/bin:/opt/rocm/bin:/opt/rocm/hcc/bin:$PATH RUN --mount=type=cache,target=/root/.ccache \ @@ -110,11 +130,15 @@ RUN --mount=type=cache,target=/root/.cache/go-build \ go build -trimpath -buildmode=pie -o /bin/ollama . FROM --platform=linux/amd64 scratch AS amd64 -COPY --from=cuda-12 dist/lib/ollama /lib/ollama +# COPY --from=cuda-11 dist/lib/ollama/ /lib/ollama/ +COPY --from=cuda-12 dist/lib/ollama /lib/ollama/ +COPY --from=cuda-13 dist/lib/ollama/ /lib/ollama/ COPY --from=vulkan dist/lib/ollama/vulkan /lib/ollama/vulkan FROM --platform=linux/arm64 scratch AS arm64 -COPY --from=cuda-12 dist/lib/ollama /lib/ollama/cuda_sbsa +# COPY --from=cuda-11 dist/lib/ollama/ /lib/ollama/ +COPY --from=cuda-12 dist/lib/ollama /lib/ollama/ +COPY --from=cuda-13 dist/lib/ollama/ /lib/ollama/ COPY --from=jetpack-5 dist/lib/ollama /lib/ollama/cuda_jetpack5 COPY --from=jetpack-6 dist/lib/ollama /lib/ollama/cuda_jetpack6 diff --git a/README.md b/README.md index 26742e8a0..5962f5b28 100644 --- a/README.md +++ b/README.md @@ -413,6 +413,8 @@ See the [API documentation](./docs/api.md) for all endpoints. - [Mayan EDMS](https://gitlab.com/mayan-edms/mayan-edms) (Open source document management system to organize, tag, search, and automate your files with powerful Ollama driven workflows.) - [Serene Pub](https://github.com/doolijb/serene-pub) (Beginner friendly, open source AI Roleplaying App for Windows, Mac OS and Linux. Search, download and use models with Ollama all inside the app.) - [Andes](https://github.com/aqerd/andes) (A Visual Studio Code extension that provides a local UI interface for Ollama models) +- [Clueless](https://github.com/KashyapTan/clueless) (Open Source & Local Cluely: A desktop application LLM assistant to help you talk to anything on your screen using locally served Ollama models. 
Also undetectable to screenshare) +- [ollama-co2](https://github.com/carbonatedWaterOrg/ollama-co2) (FastAPI web interface for monitoring and managing local and remote Ollama servers with real-time model monitoring and concurrent downloads) ### Cloud diff --git a/discover/cuda_common.go b/discover/cuda_common.go index b539f6b32..ca008af63 100644 --- a/discover/cuda_common.go +++ b/discover/cuda_common.go @@ -43,14 +43,15 @@ func cudaVariant(gpuInfo CudaGPUInfo) string { } } } - return "sbsa" } - // driver 12.0 has problems with the cuda v12 library, so run v11 on those older drivers - if gpuInfo.DriverMajor < 12 || (gpuInfo.DriverMajor == 12 && gpuInfo.DriverMinor == 0) { - // The detected driver is older than Feb 2023 - slog.Warn("old CUDA driver detected - please upgrade to a newer driver", "version", fmt.Sprintf("%d.%d", gpuInfo.DriverMajor, gpuInfo.DriverMinor)) - return "v11" + if gpuInfo.DriverMajor < 13 { + // The detected driver is older than 580 (Aug 2025) + // Warn if their CC is compatible with v13 and they should upgrade their driver to get better performance + if gpuInfo.computeMajor > 7 || (gpuInfo.computeMajor == 7 && gpuInfo.computeMinor >= 5) { + slog.Warn("old CUDA driver detected - please upgrade to a newer driver for best performance", "version", fmt.Sprintf("%d.%d", gpuInfo.DriverMajor, gpuInfo.DriverMinor)) + } + return "v12" } - return "v12" + return "v13" } diff --git a/docs/linux.md b/docs/linux.md index 9a156d1dc..ce5ed860b 100644 --- a/docs/linux.md +++ b/docs/linux.md @@ -11,12 +11,13 @@ curl -fsSL https://ollama.com/install.sh | sh ## Manual install > [!NOTE] -> If you are upgrading from a prior version, you should remove the old libraries with `sudo rm -rf /usr/lib/ollama` first. +> If you are upgrading from a prior version, you **MUST** remove the old libraries with `sudo rm -rf /usr/lib/ollama` first. Download and extract the package: ```shell curl -LO https://ollama.com/download/ollama-linux-amd64.tgz +sudo rm -rf /usr/lib/ollama sudo tar -C /usr -xzf ollama-linux-amd64.tgz ``` diff --git a/docs/troubleshooting.md b/docs/troubleshooting.md index 6fdd3e85b..7647b12f9 100644 --- a/docs/troubleshooting.md +++ b/docs/troubleshooting.md @@ -92,6 +92,9 @@ If none of those resolve the problem, gather additional information and file an - Set `CUDA_ERROR_LEVEL=50` and try again to get more diagnostic logs - Check dmesg for any errors `sudo dmesg | grep -i nvrm` and `sudo dmesg | grep -i nvidia` +You may get more details for initialization failures by enabling debug prints in the uvm driver. 
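The `cudaVariant` change above drops the old `v11`/`sbsa` split: everything now resolves to either the CUDA 12 or CUDA 13 runner based on the driver major version, warning only when the GPU's compute capability (7.5+) could actually benefit from a driver upgrade. A simplified sketch of that mapping (illustrative helper, not the real `cudaVariant` signature, which also handles the ARM/Jetpack cases first):

```go
package main

import "fmt"

// pickCudaRunner mirrors the new branch in cudaVariant: drivers older than
// the 580 series (CUDA 13) fall back to the v12 runner, and only warn about
// an upgrade when the GPU's compute capability could actually use v13.
func pickCudaRunner(driverMajor, computeMajor, computeMinor int) string {
	if driverMajor < 13 {
		if computeMajor > 7 || (computeMajor == 7 && computeMinor >= 5) {
			fmt.Println("old CUDA driver detected - please upgrade for best performance")
		}
		return "v12"
	}
	return "v13"
}

func main() {
	fmt.Println(pickCudaRunner(12, 8, 9)) // v12, with an upgrade warning
	fmt.Println(pickCudaRunner(13, 7, 5)) // v13
}
```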
You should only use this temporarily while troubleshooting +- `sudo rmmod nvidia_uvm` then `sudo modprobe nvidia_uvm uvm_debug_prints=1` + ## AMD GPU Discovery diff --git a/fs/ggml/ggml.go b/fs/ggml/ggml.go index 3f4374cd0..6b582b499 100644 --- a/fs/ggml/ggml.go +++ b/fs/ggml/ggml.go @@ -57,10 +57,28 @@ func (kv KV) EmbeddingLength() uint64 { return uint64(kv.Uint("embedding_length")) } +func (kv KV) HeadCount() []uint64 { + headCountDefault := uint32(1) + headCount := kv.UintOrArrayValueAsArray("attention.head_count", headCountDefault) + if len(headCount) == 1 { + headCountDefault = headCount[0] + } + nLayers := int(kv.BlockCount()) + if len(headCount) > nLayers { + slog.Warn("got more elements of attention.head_count than layers", "len(headCount)", len(headCount), "layers", nLayers) + } + out := make([]uint64, nLayers) + for i := range nLayers { + if i >= len(headCount) { + out[i] = uint64(headCountDefault) + } else { + out[i] = uint64(headCount[i]) + } + } + return out +} + func (kv KV) HeadCountMax() uint64 { - // TODO(drifkin): using the max value can cause an overestimation. In the - // future if array values become more popular, we can adapt the more invasive - // return uint64(kv.UintOrMaxArrayValue("attention.head_count", 1)) } @@ -68,6 +86,27 @@ func (kv KV) HeadCountMin() uint64 { return uint64(kv.UintOrMinArrayValue("attention.head_count", 1)) } +func (kv KV) HeadCountKV() []uint64 { + headCountKVDefault := uint32(1) + headCountKV := kv.UintOrArrayValueAsArray("attention.head_count_kv", headCountKVDefault) + if len(headCountKV) == 1 { + headCountKVDefault = headCountKV[0] + } + nLayers := int(kv.BlockCount()) + if len(headCountKV) > nLayers { + slog.Warn("got more elements of attention.head_count than layers", "len(headCountKV)", len(headCountKV), "layers", nLayers) + } + out := make([]uint64, nLayers) + for i := range nLayers { + if i >= len(headCountKV) { + out[i] = uint64(headCountKVDefault) + } else { + out[i] = uint64(headCountKV[i]) + } + } + return out +} + func (kv KV) HeadCountKVMax() uint64 { return uint64(kv.UintOrMaxArrayValue("attention.head_count_kv", 1)) } @@ -100,6 +139,26 @@ func (kv KV) ChatTemplate() string { return kv.String("tokenizer.chat_template") } +// ssm architecture parameters + +func (kv KV) SSMConvKernel() uint64 { + return uint64(kv.Uint("ssm.conv_kernel")) +} + +func (kv KV) SSMInnerSize() uint64 { + return uint64(kv.Uint("ssm.inner_size")) +} + +func (kv KV) SSMStateSize() uint64 { + return uint64(kv.Uint("ssm.state_size")) +} + +func (kv KV) SSMGroupCount() uint64 { + return uint64(kv.Uint("ssm.group_count")) +} + +// general types + func (kv KV) String(key string, defaultValue ...string) string { val, _ := keyValue(kv, key, append(defaultValue, "")...) 
return val @@ -131,22 +190,27 @@ func (kv KV) UintOrMinArrayValue(key string, defaultValue uint32) uint32 { } func (kv KV) UintOrArrayValue(key string, defaultValue uint32) (uint32, uint32) { + arrVal := kv.UintOrArrayValueAsArray(key, defaultValue) + return slices.Min(arrVal), slices.Max(arrVal) +} + +func (kv KV) UintOrArrayValueAsArray(key string, defaultValue uint32) []uint32 { if u32, ok := keyValue(kv, key, uint32(0)); ok { - return u32, u32 + return []uint32{u32} } else if u32s, ok := keyValue(kv, key, &array[uint32]{}); ok { - min := slices.Min(u32s.values) - max := slices.Max(u32s.values) - return min, max + return u32s.values } else if i32s, ok := keyValue(kv, key, &array[int32]{}); ok { - min := slices.Min(i32s.values) - max := slices.Max(i32s.values) - if min < 0 || max < 0 { - slog.Warn("array values are unexpectedly negative", "key", key, "min", min, "max", max) + dst := make([]uint32, len(i32s.values)) + for i, v := range i32s.values { + if v < 0 { + slog.Warn("array values are unexpectedly negative", "key", key, "i", i, "v", v) + } + dst[i] = uint32(v) } - return uint32(min), uint32(max) + return dst } - return defaultValue, defaultValue + return []uint32{defaultValue} } func (kv KV) Strings(key string, defaultValue ...[]string) []string { @@ -486,7 +550,9 @@ func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType stri embedding := f.KV().EmbeddingLength() heads := f.KV().HeadCountMax() + headsArr := f.KV().HeadCount() headsKV := f.KV().HeadCountKVMax() + headsKVArr := f.KV().HeadCountKV() vocab := uint64(f.KV()["tokenizer.ggml.tokens"].(*array[string]).size) embeddingHeads := f.KV().EmbeddingHeadCountMax() @@ -496,12 +562,51 @@ func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType stri layers := f.Tensors().GroupLayers() bytesPerElement := kvCacheBytesPerElement(kvCacheType) + + // Default for models unless special-cased below. These defaults mirror the + // cache usage in llama.cpp under the assumption that models without special + // cases below will use the llamarunner and caching will be handled by the + // llama.cpp layer. + // + // This also assumes that a layer without heads or headsKV set is recurrent + // which is usually the case. Some models (eg nemotronh) use "blocks" in + // place of layers where some are MLP blocks that don't have any cache. + // Models like this will need a special case below to be accurately + // estimated. 
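To make the estimate concrete, here is a worked example of the recurrent branch in the loop that follows, using made-up SSM hyperparameters (real values come from the model's `ssm.*` GGUF keys); unlike the attention branch, this per-layer cost does not scale with context length:

```go
package main

import "fmt"

func main() {
	// Assumed (hypothetical) SSM hyperparameters, not taken from a real model:
	// ssm.conv_kernel=4, ssm.inner_size=8192, ssm.group_count=8, ssm.state_size=128.
	const dConv, dInner, nGroups, dState = 4, 8192, 8, 128

	nEmbdR := (dConv - 1) * (dInner + 2*nGroups*dState) // conv state: 30720 elements
	nEmbdS := dState * dInner                           // ssm state:  1048576 elements

	// recurrent state is always kept in F32 (4 bytes per element)
	bytes := (nEmbdR + nEmbdS) * 4
	fmt.Printf("%.2f MiB per recurrent layer (context-independent)\n",
		float64(bytes)/(1024*1024)) // ~4.12 MiB
}
```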
var kvTotal uint64 kv = make([]uint64, f.KV().BlockCount()) + kvSizeAttn := uint64(0) + kvSizeRecurrent := uint64(0) for i := range kv { - kv[i] = uint64(float64(context*(embeddingHeadsK+embeddingHeadsV)*headsKV) * bytesPerElement) + headsL := headsArr[i] + headsKVL := headsKVArr[i] + if headsL > 0 && headsKVL > 0 { + // full attention layer + // NOTE: Assumes uniform values for all attn layers + kv[i] = uint64(float64(context*(embeddingHeadsK+embeddingHeadsV)*headsKVL) * bytesPerElement) + kvSizeAttn += kv[i] + } else { + // recurrent layer + ssmDConv := f.KV().SSMConvKernel() + ssmDState := f.KV().SSMStateSize() + ssmDInner := f.KV().SSMInnerSize() + ssmNGroups := f.KV().SSMGroupCount() + nEmbdR := uint64(0) + if ssmDConv > 0 { + nEmbdR = (ssmDConv - 1) * (ssmDInner + 2*ssmNGroups*ssmDState) + } + nEmbdS := ssmDState * ssmDInner + + // recurrent always uses F32 in llama.cpp backend + // https://github.com/ggml-org/llama.cpp/blob/master/src/llama-model.cpp#L18644 + bytesPerElementRecurrent := kvCacheBytesPerElement("f32") + + kv[i] = (nEmbdR + nEmbdS) * uint64(bytesPerElementRecurrent) + kvSizeRecurrent += kv[i] + } kvTotal += kv[i] } + slog.Debug("default cache size estimate", "attention MiB", float32(kvSizeAttn)/(1024.*1024.), "attention bytes", kvSizeAttn, "recurrent MiB", float32(kvSizeRecurrent)/(1024.*1024.), "recurrent bytes", kvSizeRecurrent) switch f.KV().Architecture() { case "llama", "llama4": @@ -759,12 +864,16 @@ func (llm GGML) VisionGraphSize() (weights, graphSize uint64) { // SupportsKVCacheType checks if the requested cache type is supported func (f GGML) SupportsKVCacheType(cacheType string) bool { + if cacheType == "" || cacheType == "f16" { + return true + } + if arch := f.KV().Architecture(); slices.Contains([]string{"gptoss", "gpt-oss"}, arch) { // gpt-oss uses attention with sinks which does not support quantized cache types - slog.Warn("model only supports non-quantized cache types ", "mode", arch) - return cacheType == "f16" + slog.Warn("model only supports non-quantized cache types", "model", arch) + return false } - return slices.Contains([]string{"f16", "q8_0", "q4_0"}, cacheType) + return slices.Contains([]string{"q8_0", "q4_0"}, cacheType) } // SupportsFlashAttention checks if the model supports flash attention @@ -774,6 +883,10 @@ func (f GGML) SupportsFlashAttention() bool { return false } + if arch := f.KV().Architecture(); slices.Contains([]string{"gemma2"}, arch) { + return false + } + // Check head counts match and are non-zero headCountK := f.KV().EmbeddingHeadCountK() headCountV := f.KV().EmbeddingHeadCountV() @@ -794,6 +907,8 @@ func kvCacheBytesPerElement(cacheType string) float64 { return 1 // 1/2 of fp16 case "q4_0": return 0.5 // 1/4 of fp16 + case "f32": + return 4 // f32 (default for recurrent) default: return 2 // f16 (default) } diff --git a/harmony/harmonyparser.go b/harmony/harmonyparser.go index a51819dda..3ec2c21f1 100644 --- a/harmony/harmonyparser.go +++ b/harmony/harmonyparser.go @@ -3,15 +3,29 @@ package harmony import ( "fmt" "log/slog" + "slices" "strings" "unicode" "github.com/ollama/ollama/api" "github.com/ollama/ollama/logutil" + "github.com/ollama/ollama/template" ) type harmonyParserState int +func ShouldUseHarmony(modelFamily string, template *template.Template) bool { + if slices.Contains([]string{"gptoss", "gpt-oss"}, modelFamily) { + // heuristic to check whether the template expects to be parsed via harmony: + // search for harmony tags that are nearly always used + if template.Contains("<|start|>") && 
template.Contains("<|end|>") { + return true + } + } + + return false +} + const ( harmonyParserState_LookingForMessageStart harmonyParserState = iota harmonyParserState_ParsingHeader @@ -75,18 +89,28 @@ func (s *HarmonyParser) AddImplicitStart() { s.acc.WriteString("<|start|>assistant") } -func (s *HarmonyParser) AddImplicitStartOrPrefill(lastMessage *api.Message) { - if lastMessage != nil && lastMessage.Role == "assistant" { - // handle prefilling conditions - if lastMessage.Content != "" { - s.acc.WriteString("<|start|>assistant<|channel|>final<|message|>") - return - } else if lastMessage.Thinking != "" { - s.acc.WriteString("<|start|>assistant<|channel|>analysis<|message|>") - return - } +func Prefill(lastMessage api.Message) string { + if lastMessage.Role != "assistant" { + return "" + } + + switch { + case strings.TrimSpace(lastMessage.Content) != "": + return "<|start|>assistant<|channel|>final<|message|>" + case strings.TrimSpace(lastMessage.Thinking) != "": + return "<|start|>assistant<|channel|>analysis<|message|>" + default: + return "" + } +} + +// AddImplicitStartOrPrefill adds an implicit start tag or prefill string if provided +func (s *HarmonyParser) AddImplicitStartOrPrefill(prefillString string) { + if strings.TrimSpace(prefillString) != "" { + s.acc.WriteString(prefillString) + } else { + s.AddImplicitStart() } - s.AddImplicitStart() } func (s *HarmonyParser) AddContent(content string) []HarmonyEvent { @@ -265,6 +289,7 @@ type HarmonyMessageHandler struct { state harmonyMessageState HarmonyParser *HarmonyParser FunctionNameMap *FunctionNameMap + ToolParser *HarmonyToolCallAccumulator } // NewHarmonyMessageHandler creates a new message handler @@ -277,12 +302,16 @@ func NewHarmonyMessageHandler() *HarmonyMessageHandler { HeaderEndTag: "<|message|>", }, FunctionNameMap: NewFunctionNameMap(), + ToolParser: &HarmonyToolCallAccumulator{ + state: harmonyToolCallState_Normal, + currentToolName: nil, + }, } } // AddContent processes the content and returns the content, thinking, and tool content. 
// content and thinking are already fully parsed, but tool content still needs to be passed to the tool parser -func (h *HarmonyMessageHandler) AddContent(content string, toolParser *HarmonyToolCallAccumulator) (string, string, string) { +func (h *HarmonyMessageHandler) AddContent(content string) (string, string, string) { contentSb := strings.Builder{} thinkingSb := strings.Builder{} toolContentSb := strings.Builder{} @@ -299,14 +328,14 @@ func (h *HarmonyMessageHandler) AddContent(content string, toolParser *HarmonyTo // event.Header.Recipient is the tool name, something like // "browser.search" for a built-in, or "functions.calc" for a // custom one - toolParser.SetToolName(event.Header.Recipient) + h.ToolParser.SetToolName(event.Header.Recipient) } else { h.state = harmonyMessageState_Thinking } case "commentary": if event.Header.Recipient != "" { h.state = harmonyMessageState_ToolCalling - toolParser.SetToolName(event.Header.Recipient) + h.ToolParser.SetToolName(event.Header.Recipient) } else { h.state = harmonyMessageState_Normal } @@ -329,13 +358,6 @@ func (h *HarmonyMessageHandler) AddContent(content string, toolParser *HarmonyTo return contentSb.String(), thinkingSb.String(), toolContentSb.String() } -func (h *HarmonyMessageHandler) CreateToolParser() *HarmonyToolCallAccumulator { - return &HarmonyToolCallAccumulator{ - state: harmonyToolCallState_Normal, - currentToolName: nil, - } -} - type harmonyToolCallState int const ( diff --git a/harmony/harmonyparser_test.go b/harmony/harmonyparser_test.go index b988a018f..82bf5b2de 100644 --- a/harmony/harmonyparser_test.go +++ b/harmony/harmonyparser_test.go @@ -3,6 +3,7 @@ package harmony import ( "fmt" "reflect" + "strings" "testing" ) @@ -535,3 +536,202 @@ func TestFunctionConvertAndAdd(t *testing.T) { }) } } + +func TestHarmonyMessageHandlerStreamingScenarios(t *testing.T) { + t.Run("thinking_then_content_streams", func(t *testing.T) { + handler := NewHarmonyMessageHandler() + handler.HarmonyParser.AddImplicitStart() + tp := handler.ToolParser + type step struct { + in string + wantContent string + wantThinking string + } + steps := []step{ + {in: "<|channel|>analysis<|message|>Thinking...", wantThinking: "Thinking..."}, + {in: "<|end|>", wantThinking: ""}, + {in: "<|start|>assistant<|message|>Answer", wantContent: "Answer"}, + {in: "<|end|>", wantContent: ""}, + } + for i, s := range steps { + content, thinking, tool := handler.AddContent(s.in) + if tool != "" { + tp.Add(tool) + } + if content != s.wantContent || thinking != s.wantThinking { + t.Fatalf("step %d: got (content=%q thinking=%q), want (content=%q thinking=%q)", i, content, thinking, s.wantContent, s.wantThinking) + } + } + }) + + t.Run("content_streams_as_it_arrives", func(t *testing.T) { + handler := NewHarmonyMessageHandler() + handler.HarmonyParser.AddImplicitStart() + tp := handler.ToolParser + inputs := []string{ + "<|start|>assistant<|message|>Hello", + ", world", + "!<|end|>", + } + var got []string + for _, in := range inputs { + content, thinking, tool := handler.AddContent(in) + if tool != "" { + tp.Add(tool) + } + if thinking != "" { + t.Fatalf("unexpected thinking %q", thinking) + } + if content != "" { + got = append(got, content) + } + } + want := []string{"Hello", ", world", "!"} + if !reflect.DeepEqual(got, want) { + t.Fatalf("content pieces mismatch: got %v want %v", got, want) + } + }) + + t.Run("thinking_streams_separately_from_content", func(t *testing.T) { + handler := NewHarmonyMessageHandler() + handler.HarmonyParser.AddImplicitStart() + tp := 
handler.ToolParser + inputs := []string{ + "<|channel|>analysis<|message|>Thinking...", + "<|end|>", + "<|start|>assistant<|message|>Answer", + "<|end|>", + } + var got []string + for _, in := range inputs { + content, thinking, tool := handler.AddContent(in) + if tool != "" { + tp.Add(tool) + } + if thinking != "" { + got = append(got, thinking) + } + if content != "" { + got = append(got, content) + } + } + want := []string{"Thinking...", "Answer"} + if !reflect.DeepEqual(got, want) { + t.Fatalf("content pieces mismatch: got %v want %v", got, want) + } + }) + + t.Run("partial_tags_buffer_until_complete", func(t *testing.T) { + handler := NewHarmonyMessageHandler() + handler.HarmonyParser.AddImplicitStart() + tp := handler.ToolParser + inputs := []string{ + "<|chan", + "nel|>analysis<|mess", + "age|>Deep ", + "thought", + "<|end|>", + "<|start|>assistant<|message|>Done", + "<|end|>", + } + var thinkingPieces []string + var contentPieces []string + for _, in := range inputs { + content, thinking, tool := handler.AddContent(in) + if tool != "" { + tp.Add(tool) + } + if thinking != "" { + thinkingPieces = append(thinkingPieces, thinking) + } + if content != "" { + contentPieces = append(contentPieces, content) + } + } + if want := []string{"Deep ", "thought"}; !reflect.DeepEqual(thinkingPieces, want) { + t.Fatalf("thinking pieces mismatch: got %v want %v", thinkingPieces, want) + } + if want := []string{"Done"}; !reflect.DeepEqual(contentPieces, want) { + t.Fatalf("content pieces mismatch: got %v want %v", contentPieces, want) + } + }) + + t.Run("simple_assistant_after_analysis", func(t *testing.T) { + handler := NewHarmonyMessageHandler() + handler.HarmonyParser.AddImplicitStart() + tp := handler.ToolParser + inputs := []string{ + "<|channel|>analysis<|message|>Think", + "<|end|>", + "<|start|>assistant<|message|>Answer", + "<|end|>", + } + var contentSb, thinkingSb strings.Builder + for _, in := range inputs { + content, thinking, tool := handler.AddContent(in) + if tool != "" { + tp.Add(tool) + } + contentSb.WriteString(content) + thinkingSb.WriteString(thinking) + } + if contentSb.String() != "Answer" { + t.Fatalf("content mismatch: got %q want %q", contentSb.String(), "Answer") + } + if thinkingSb.String() != "Think" { + t.Fatalf("thinking mismatch: got %q want %q", thinkingSb.String(), "Think") + } + }) + + t.Run("tool_call_parsed_and_returned_correctly", func(t *testing.T) { + handler := NewHarmonyMessageHandler() + handler.HarmonyParser.AddImplicitStart() + tp := handler.ToolParser + inputs := []string{ + "<|channel|>commentary to=functions.calculate<|message|>{\"expression\":\"2+2\"}<|end|>", + } + for _, in := range inputs { + content, thinking, tool := handler.AddContent(in) + if content != "" || thinking != "" { + continue + } + if tool != "" { + tp.Add(tool) + } + } + name, args := tp.Drain() + if name == nil || *name != "functions.calculate" { + t.Fatalf("unexpected tool name: %v", name) + } + if got, want := args, "{\"expression\":\"2+2\"}"; got != want { + t.Fatalf("unexpected tool args: got %s want %s", got, want) + } + }) + + t.Run("tool_call_across_chunks", func(t *testing.T) { + handler := NewHarmonyMessageHandler() + handler.HarmonyParser.AddImplicitStart() + tp := handler.ToolParser + inputs := []string{ + "<|channel|>commentary to=functions.calculate<|message|>{\"expression\":\"2+", + "2\"}", + "<|end|>", + } + for _, in := range inputs { + content, thinking, tool := handler.AddContent(in) + if content != "" || thinking != "" { + continue + } + if tool != "" { + 
tp.Add(tool) + } + } + name, args := tp.Drain() + if name == nil || *name != "functions.calculate" { + t.Fatalf("unexpected tool name: %v", name) + } + if got, want := args, "{\"expression\":\"2+2\"}"; got != want { + t.Fatalf("unexpected tool args: got %s want %s", got, want) + } + }) +} diff --git a/integration/api_test.go b/integration/api_test.go index 0baba8827..c39192c99 100644 --- a/integration/api_test.go +++ b/integration/api_test.go @@ -410,3 +410,99 @@ func TestAPIEmbeddings(t *testing.T) { t.Errorf("zero length embedding response") } } + +func TestAPIToolCalling(t *testing.T) { + initialTimeout := 60 * time.Second + streamTimeout := 30 * time.Second + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute) + defer cancel() + + client, _, cleanup := InitServerConnection(ctx, t) + defer cleanup() + + modelName := "qwen3:0.6b" + if err := PullIfMissing(ctx, client, modelName); err != nil { + t.Fatalf("pull failed %s", err) + } + + tools := []api.Tool{ + { + Type: "function", + Function: api.ToolFunction{ + Name: "get_weather", + Description: "Get the current weather in a given location", + Parameters: api.ToolFunctionParameters{ + Type: "object", + Required: []string{"location"}, + Properties: map[string]api.ToolProperty{ + "location": { + Type: api.PropertyType{"string"}, + Description: "The city and state, e.g. San Francisco, CA", + }, + }, + }, + }, + }, + } + + req := api.ChatRequest{ + Model: modelName, + Messages: []api.Message{ + { + Role: "user", + Content: "Call get_weather with location set to San Francisco.", + }, + }, + Tools: tools, + Options: map[string]any{ + "temperature": 0, + }, + } + + stallTimer := time.NewTimer(initialTimeout) + var gotToolCall bool + var lastToolCall api.ToolCall + + fn := func(response api.ChatResponse) error { + if len(response.Message.ToolCalls) > 0 { + gotToolCall = true + lastToolCall = response.Message.ToolCalls[len(response.Message.ToolCalls)-1] + } + if !stallTimer.Reset(streamTimeout) { + return fmt.Errorf("stall was detected while streaming response, aborting") + } + return nil + } + + stream := true + req.Stream = &stream + done := make(chan int) + var genErr error + go func() { + genErr = client.Chat(ctx, &req, fn) + done <- 0 + }() + + select { + case <-stallTimer.C: + t.Errorf("tool-calling chat never started. 
Timed out after: %s", initialTimeout.String()) + case <-done: + if genErr != nil { + t.Fatalf("chat failed: %v", genErr) + } + + if !gotToolCall { + t.Fatalf("expected at least one tool call, got none") + } + + if lastToolCall.Function.Name != "get_weather" { + t.Errorf("unexpected tool called: got %q want %q", lastToolCall.Function.Name, "get_weather") + } + + if _, ok := lastToolCall.Function.Arguments["location"]; !ok { + t.Errorf("expected tool arguments to include 'location', got: %s", lastToolCall.Function.Arguments.String()) + } + case <-ctx.Done(): + t.Error("outer test context done while waiting for tool-calling chat") + } +} diff --git a/integration/concurrency_test.go b/integration/concurrency_test.go index 331bb6e75..3104eacca 100644 --- a/integration/concurrency_test.go +++ b/integration/concurrency_test.go @@ -121,6 +121,7 @@ func TestMultiModelStress(t *testing.T) { // The intent is to go 1 over what can fit so we force the scheduler to thrash targetLoadCount := 0 slog.Info("Loading models to find how many can fit in VRAM before overflowing") +chooseModels: for i, model := range chosenModels { req := &api.GenerateRequest{Model: model} slog.Info("loading", "model", model) @@ -142,6 +143,13 @@ func TestMultiModelStress(t *testing.T) { slog.Info("found model load capacity", "target", targetLoadCount, "current", loaded, "chosen", chosenModels[:targetLoadCount]) break } + // Effectively limit model count to 2 on CPU only systems to avoid thrashing and timeouts + for _, m := range models.Models { + if m.SizeVRAM == 0 { + slog.Info("model running on CPU", "name", m.Name, "target", targetLoadCount, "chosen", chosenModels[:targetLoadCount]) + break chooseModels + } + } } } if targetLoadCount == len(chosenModels) { diff --git a/integration/context_test.go b/integration/context_test.go index 24c57dcf2..ca6f16087 100644 --- a/integration/context_test.go +++ b/integration/context_test.go @@ -36,7 +36,7 @@ func TestLongInputContext(t *testing.T) { if err := PullIfMissing(ctx, client, req.Model); err != nil { t.Fatalf("PullIfMissing failed: %v", err) } - DoGenerate(ctx, t, client, req, []string{"russia", "germany", "france", "england", "austria", "prussia", "individuals", "coalition", "conflict"}, 120*time.Second, 10*time.Second) + DoGenerate(ctx, t, client, req, []string{"russia", "germany", "france", "england", "austria", "prussia", "europe", "individuals", "coalition", "conflict"}, 120*time.Second, 10*time.Second) } func TestContextExhaustion(t *testing.T) { diff --git a/integration/embed_test.go b/integration/embed_test.go index 09369dbb4..eb00f4ba6 100644 --- a/integration/embed_test.go +++ b/integration/embed_test.go @@ -38,8 +38,9 @@ func TestAllMiniLMEmbeddings(t *testing.T) { defer cleanup() req := api.EmbeddingRequest{ - Model: "all-minilm", - Prompt: "why is the sky blue?", + Model: "all-minilm", + Prompt: "why is the sky blue?", + KeepAlive: &api.Duration{Duration: 10 * time.Second}, } res, err := embeddingTestHelper(ctx, client, t, req) diff --git a/integration/utils_test.go b/integration/utils_test.go index 2bb6a1570..ec74b2e3d 100644 --- a/integration/utils_test.go +++ b/integration/utils_test.go @@ -502,6 +502,22 @@ func DoGenerate(ctx context.Context, t *testing.T, client *api.Client, genReq ap done <- 0 }() + var response string + verify := func() { + // Verify the response contains the expected data + response = buf.String() + atLeastOne := false + for _, resp := range anyResp { + if strings.Contains(strings.ToLower(response), resp) { + atLeastOne = true + break + } + } 
+ if !atLeastOne { + t.Fatalf("%s: none of %v found in %s", genReq.Model, anyResp, response) + } + } + select { case <-stallTimer.C: if buf.Len() == 0 { @@ -517,21 +533,14 @@ func DoGenerate(ctx context.Context, t *testing.T, client *api.Client, genReq ap if genErr != nil { t.Fatalf("%s failed with %s request prompt %s", genErr, genReq.Model, genReq.Prompt) } - // Verify the response contains the expected data - response := buf.String() - atLeastOne := false - for _, resp := range anyResp { - if strings.Contains(strings.ToLower(response), resp) { - atLeastOne = true - break - } - } - if !atLeastOne { - t.Fatalf("%s: none of %v found in %s", genReq.Model, anyResp, response) - } + verify() slog.Info("test pass", "model", genReq.Model, "prompt", genReq.Prompt, "contains", anyResp, "response", response) case <-ctx.Done(): - t.Error("outer test context done while waiting for generate") + // On slow systems, we might timeout before some models finish rambling, so check what we have so far to see + // if it's considered a pass - the stallTimer will detect hangs, but we want to consider slow systems a pass + // if they are still generating valid responses + slog.Warn("outer test context done while waiting for generate") + verify() } return context } @@ -599,6 +608,22 @@ func DoChat(ctx context.Context, t *testing.T, client *api.Client, req api.ChatR done <- 0 }() + var response string + verify := func() { + // Verify the response contains the expected data + response = buf.String() + atLeastOne := false + for _, resp := range anyResp { + if strings.Contains(strings.ToLower(response), resp) { + atLeastOne = true + break + } + } + if !atLeastOne { + t.Fatalf("%s: none of %v found in \"%s\" -- request was:%v", req.Model, anyResp, response, req.Messages) + } + } + select { case <-stallTimer.C: if buf.Len() == 0 { @@ -614,23 +639,14 @@ func DoChat(ctx context.Context, t *testing.T, client *api.Client, req api.ChatR if genErr != nil { t.Fatalf("%s failed with %s request prompt %v", genErr, req.Model, req.Messages) } - - // Verify the response contains the expected data - response := buf.String() - atLeastOne := false - for _, resp := range anyResp { - if strings.Contains(strings.ToLower(response), resp) { - atLeastOne = true - break - } - } - if !atLeastOne { - t.Fatalf("%s: none of %v found in \"%s\" -- request was:%v", req.Model, anyResp, response, req.Messages) - } - + verify() slog.Info("test pass", "model", req.Model, "messages", req.Messages, "contains", anyResp, "response", response) case <-ctx.Done(): - t.Error("outer test context done while waiting for generate") + // On slow systems, we might timeout before some models finish rambling, so check what we have so far to see + // if it's considered a pass - the stallTimer will detect hangs, but we want to consider slow systems a pass + // if they are still generating valid responses + slog.Warn("outer test context done while waiting for chat") + verify() } return &api.Message{Role: role, Content: buf.String()} } diff --git a/llm/memory.go b/llm/memory.go index ce128eb58..7a87b28fe 100644 --- a/llm/memory.go +++ b/llm/memory.go @@ -202,7 +202,7 @@ func estimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []strin var kvct string if useFlashAttention { requested := strings.ToLower(envconfig.KvCacheType()) - if requested != "" && f.SupportsKVCacheType(requested) { + if f.SupportsKVCacheType(requested) { kvct = requested } } diff --git a/llm/server.go b/llm/server.go index e7e8b4da8..09987f6f6 100644 --- a/llm/server.go +++ 
b/llm/server.go @@ -35,6 +35,7 @@ import ( "github.com/ollama/ollama/logutil" "github.com/ollama/ollama/ml" "github.com/ollama/ollama/model" + "github.com/ollama/ollama/parser" ) type filteredEnv []string @@ -173,6 +174,8 @@ func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, a opts.NumCtx = int(trainCtx) } + opts.NumBatch = min(opts.NumBatch, opts.NumCtx) + loadRequest := LoadRequest{LoraPath: adapters, KvSize: opts.NumCtx * numParallel, BatchSize: opts.NumBatch, Parallel: numParallel, MultiUserCache: envconfig.MultiUserCache()} defaultThreads := discover.GetSystemInfo().GetOptimalThreadCount() @@ -218,7 +221,7 @@ func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, a // Flash Attention also supports kv cache quantization // Enable if the requested and kv cache type is supported by the model - if kvct != "" && f.SupportsKVCacheType(kvct) { + if f.SupportsKVCacheType(kvct) { loadRequest.KvCacheType = kvct } else { slog.Warn("kv cache type not supported by model", "type", kvct) @@ -1348,7 +1351,9 @@ type CompletionRequest struct { Images []ImageData Options *api.Options - Grammar string // set before sending the request to the subprocess + Grammar string // set before sending the request to the subprocess + ParserType parser.TokenParserType + PrefillString string } // DoneReason represents the reason why a completion response is done @@ -1375,13 +1380,15 @@ func (d DoneReason) String() string { } type CompletionResponse struct { - Content string `json:"content"` - DoneReason DoneReason `json:"done_reason"` - Done bool `json:"done"` - PromptEvalCount int `json:"prompt_eval_count"` - PromptEvalDuration time.Duration `json:"prompt_eval_duration"` - EvalCount int `json:"eval_count"` - EvalDuration time.Duration `json:"eval_duration"` + Content string `json:"content"` + Thinking string `json:"thinking"` + ToolCalls []api.ToolCall `json:"tool_calls"` + DoneReason DoneReason `json:"done_reason"` + Done bool `json:"done"` + PromptEvalCount int `json:"prompt_eval_count"` + PromptEvalDuration time.Duration `json:"prompt_eval_duration"` + EvalCount int `json:"eval_count"` + EvalDuration time.Duration `json:"eval_duration"` } func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error { @@ -1499,7 +1506,8 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu return fmt.Errorf("error unmarshalling llm prediction response: %v", err) } switch { - case strings.TrimSpace(c.Content) == lastToken: + // TODO(parthsareen): token repeat limit is now handled in the runner, this currently support legacy model and can be removed in the future + case strings.TrimSpace(c.Content) == lastToken && c.Content != "": tokenRepeat++ default: lastToken = strings.TrimSpace(c.Content) @@ -1512,16 +1520,14 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu return ctx.Err() } - if c.Content != "" { - fn(CompletionResponse{ - Content: c.Content, - }) - } - if c.Done { fn(c) return nil } + + if c.Content != "" || c.Thinking != "" || len(c.ToolCalls) > 0 { + fn(c) + } } } diff --git a/parser/token_parser.go b/parser/token_parser.go new file mode 100644 index 000000000..812458299 --- /dev/null +++ b/parser/token_parser.go @@ -0,0 +1,126 @@ +package parser + +import ( + "encoding/json" + "errors" + "strings" + + "github.com/ollama/ollama/api" + "github.com/ollama/ollama/harmony" +) + +type TokenParserType int + +const ( + TokenParserTypeDefault TokenParserType = iota + 
TokenParserTypeHarmony +) + +type TokenParser struct { + messageHandler MessageHandler + parserEngine ParserInternals + toolParser ToolParser + lastToken string + tokenRepeat int + repeatLimit int +} + +const defaultTokenRepeatLimit = 30 + +type MessageHandler interface { + AddContent(token string) (content, thinking string, toolContent string) +} + +type ParserInternals interface { + AddImplicitStartOrPrefill(prefillString string) +} + +type ToolParser interface { + Add(token string) + Drain() (toolName *string, toolContent string) +} + +// Default implementation for the TokenParser interface as a no-op passthrough +type defaultMessageHandler struct{} + +func (defaultMessageHandler) AddContent(token string) (string, string, string) { + return token, "", "" +} + +type defaultEngine struct{} + +func (defaultEngine) AddImplicitStartOrPrefill(prefillString string) {} + +type defaultToolParser struct{} + +func (defaultToolParser) Add(token string) {} + +func (defaultToolParser) Drain() (*string, string) { return nil, "" } + +func NewTokenParser(parserType TokenParserType, prefillString string) TokenParser { + switch parserType { + case TokenParserTypeHarmony: + harmonyMessageHandler := harmony.NewHarmonyMessageHandler() + harmonyMessageHandler.HarmonyParser.AddImplicitStartOrPrefill(prefillString) + return TokenParser{ + messageHandler: harmonyMessageHandler, + parserEngine: harmonyMessageHandler.HarmonyParser, + toolParser: harmonyMessageHandler.ToolParser, + repeatLimit: defaultTokenRepeatLimit, + } + + default: + return TokenParser{ + messageHandler: defaultMessageHandler{}, + parserEngine: defaultEngine{}, + toolParser: defaultToolParser{}, + repeatLimit: 30, + } + } +} + +func (p *TokenParser) AddContent(token string) (string, string, error) { + if p.repeatLimitReached(token) { + return "", "", errors.New("token repeat limit reached") + } + content, thinking, toolContent := p.messageHandler.AddContent(token) + p.toolParser.Add(toolContent) + return content, thinking, nil +} + +// repeatLimitReached updates repeat counters and returns true if the repeat limit is reached. 
+func (p *TokenParser) repeatLimitReached(token string) bool { + if p == nil { + return false + } + trimmed := strings.TrimSpace(token) + if trimmed == p.lastToken { + p.tokenRepeat++ + } else { + p.tokenRepeat = 0 + } + p.lastToken = trimmed + + return p.tokenRepeat >= p.repeatLimit +} + +// TODO: update to work with multiple toolcalls - unmarshalling should also happen on parser level +func (p *TokenParser) Drain() []api.ToolCall { + toolName, toolContent := p.toolParser.Drain() + if toolName != nil { + *toolName = strings.TrimPrefix(*toolName, "functions.") + var args api.ToolCallFunctionArguments + if err := json.Unmarshal([]byte(toolContent), &args); err != nil { + return nil + } + return []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: *toolName, + Arguments: args, + }, + }, + } + } + return nil +} diff --git a/runner/ollamarunner/cache.go b/runner/ollamarunner/cache.go index 955ef9b3d..f558f7b87 100644 --- a/runner/ollamarunner/cache.go +++ b/runner/ollamarunner/cache.go @@ -34,8 +34,8 @@ type InputCache struct { func NewInputCache(model model.Model, kvCacheType string, kvSize int32, numSlots int, batchSize int, multiUserCache bool) (*InputCache, error) { numCtx := kvSize / int32(numSlots) - if numCtx < 1 { - return nil, fmt.Errorf("must have at least one kv cache entry per parallel sequence (kv: %v parallel: %v)", kvSize, numSlots) + if int(numCtx) < batchSize { + return nil, fmt.Errorf("kv size must be at least as large as batch size * parallel (kv: %v batch: %v parallel: %v)", kvSize, batchSize, numSlots) } slots := make([]InputCacheSlot, numSlots) @@ -70,11 +70,9 @@ func kvCacheTypeFromStr(s string) ml.DType { } func (c *InputCache) Close() { - if c == nil { - return + if c != nil && c.cache != nil { + c.cache.Close() } - - c.cache.Close() } // Locking: Operations on InputCacheSlot (including finding one diff --git a/runner/ollamarunner/runner.go b/runner/ollamarunner/runner.go index df3ce1d9f..201d55a16 100644 --- a/runner/ollamarunner/runner.go +++ b/runner/ollamarunner/runner.go @@ -35,6 +35,7 @@ import ( "github.com/ollama/ollama/ml" "github.com/ollama/ollama/model" "github.com/ollama/ollama/model/input" + "github.com/ollama/ollama/parser" "github.com/ollama/ollama/runner/common" "github.com/ollama/ollama/sample" @@ -781,6 +782,8 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) { return } + tokenParser := parser.NewTokenParser(req.ParserType, req.PrefillString) + if req.Options == nil { opts := api.DefaultOptions() req.Options = &opts @@ -871,8 +874,18 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) { return case content, ok := <-seq.responses: if ok { + var thinking string + var err error + content, thinking, err = tokenParser.AddContent(content) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + close(seq.quit) + return + } + if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{ - Content: content, + Content: content, + Thinking: thinking, }); err != nil { http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError) close(seq.quit) @@ -881,7 +894,9 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) { flusher.Flush() } else { + toolCalls := tokenParser.Drain() if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{ + ToolCalls: toolCalls, Done: true, DoneReason: seq.doneReason, PromptEvalCount: seq.numPromptInputs, diff --git a/scripts/build_windows.ps1 b/scripts/build_windows.ps1 index 057b64c90..ff853b5a3 100644 --- 
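For reference, the new `parser` package is driven once per request, as shown in `runner.go` above; a self-contained sketch of that flow (input tokens invented for illustration):

```go
package main

import (
	"fmt"

	"github.com/ollama/ollama/parser"
)

// One TokenParser per completion request: tokens are fed as they stream from
// the model, and accumulated tool calls are drained at the end.
func main() {
	p := parser.NewTokenParser(parser.TokenParserTypeHarmony, "") // no assistant prefill

	for _, tok := range []string{
		"<|channel|>analysis<|message|>Thinking...", "<|end|>",
		"<|start|>assistant<|message|>Answer", "<|end|>",
	} {
		content, thinking, err := p.AddContent(tok)
		if err != nil {
			// returned once the repeat limit (30 identical trimmed tokens) is hit
			fmt.Println("aborting:", err)
			return
		}
		fmt.Printf("content=%q thinking=%q\n", content, thinking)
	}

	fmt.Println("tool calls:", p.Drain())
}
```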
a/scripts/build_windows.ps1 +++ b/scripts/build_windows.ps1 @@ -78,7 +78,7 @@ function checkEnv() { } -function buildOllama() { +function buildCPU() { mkdir -Force -path "${script:DIST_DIR}\" if ($script:ARCH -ne "arm64") { Remove-Item -ea 0 -recurse -force -path "${script:SRC_DIR}\dist\windows-${script:ARCH}" @@ -90,20 +90,72 @@ function buildOllama() { if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)} & cmake --install build --component CPU --strip if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)} + } +} +function buildCUDA11() { + # CUDA v11 claims to be compatible with MSVC 2022, but the latest updates are no longer compatible + # 19.40 is the last compiler version that works, but recent udpates are 19.43 + # So this pins to MSVC 2019 for best compatibility + mkdir -Force -path "${script:DIST_DIR}\" + if ($script:ARCH -ne "arm64") { $hashEnv = @{} Get-ChildItem env: | foreach { $hashEnv[$_.Name] = $_.Value } - if ("$script:CUDA_DIRS".Contains("v12")) { - $hashEnv.Keys | foreach { if ($_.Contains("CUDA_PATH_V12")) { $v12="$_" }} - $env:CUDAToolkit_ROOT=$hashEnv[$v12] - write-host "Building CUDA v12 backend libraries" - & cmake --fresh --preset "CUDA 12" --install-prefix $script:DIST_DIR + if ("$script:CUDA_DIRS".Contains("v11")) { + $hashEnv.Keys | foreach { if ($_.Contains("CUDA_PATH_V11")) { $x=$hashEnv[$_]; if (test-path -literalpath "$x\bin\nvcc.exe" ) { $cuda=$x} }} + write-host "Building CUDA v11 backend libraries $cuda" + $env:CUDAToolkit_ROOT=$cuda + & cmake --fresh --preset "CUDA 11" -T cuda="$cuda" -DCMAKE_CUDA_COMPILER="$cuda\bin\nvcc.exe" -G "Visual Studio 16 2019" --install-prefix $script:DIST_DIR -DOLLAMA_RUNNER_DIR="cuda_v11" + if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)} + & cmake --build --preset "CUDA 11" --config Release --parallel $script:JOBS + if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)} + & cmake --install build --component "CUDA" --strip + if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)} + } + } +} + +function buildCUDA12() { + mkdir -Force -path "${script:DIST_DIR}\" + if ($script:ARCH -ne "arm64") { + $hashEnv = @{} + Get-ChildItem env: | foreach { $hashEnv[$_.Name] = $_.Value } + if ("$script:CUDA_DIRS".Contains("v12.8")) { + $hashEnv.Keys | foreach { if ($_.Contains("CUDA_PATH_V12_8")) { $x=$hashEnv[$_]; if (test-path -literalpath "$x\bin\nvcc.exe" ) { $cuda=$x} }} + write-host "Building CUDA v12 backend libraries $cuda" + $env:CUDAToolkit_ROOT=$cuda + & cmake --fresh --preset "CUDA 12" -T cuda="$cuda" --install-prefix $script:DIST_DIR -DOLLAMA_RUNNER_DIR="cuda_v12" if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)} & cmake --build --preset "CUDA 12" --config Release --parallel $script:JOBS if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)} & cmake --install build --component "CUDA" --strip if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)} } + } +} + +function buildCUDA13() { + mkdir -Force -path "${script:DIST_DIR}\" + if ($script:ARCH -ne "arm64") { + $hashEnv = @{} + Get-ChildItem env: | foreach { $hashEnv[$_.Name] = $_.Value } + if ("$script:CUDA_DIRS".Contains("v13")) { + $hashEnv.Keys | foreach { if ($_.Contains("CUDA_PATH_V13")) { $x=$hashEnv[$_]; if (test-path -literalpath "$x\bin\nvcc.exe" ) { $cuda=$x} }} + $env:CUDAToolkit_ROOT=$cuda + write-host "Building CUDA v13 backend libraries $cuda" + & cmake --fresh --preset "CUDA 13" -T cuda="$cuda" --install-prefix $script:DIST_DIR -DOLLAMA_RUNNER_DIR="cuda_v13" + if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)} + & cmake --build --preset "CUDA 13" --config Release --parallel $script:JOBS + if ($LASTEXITCODE -ne 
0) { exit($LASTEXITCODE)} + & cmake --install build --component "CUDA" --strip + if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)} + } + } +} + +function buildROCm() { + mkdir -Force -path "${script:DIST_DIR}\" + if ($script:ARCH -ne "arm64") { if ($env:HIP_PATH) { write-host "Building ROCm backend libraries" if (-Not (get-command -ErrorAction silent ninja)) { @@ -138,6 +190,10 @@ function buildOllama() { if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)} } } +} + +function buildOllama() { + mkdir -Force -path "${script:DIST_DIR}\" write-host "Building ollama CLI" & go build -trimpath -ldflags "-s -w -X=github.com/ollama/ollama/version.Version=$script:VERSION -X=github.com/ollama/ollama/server.mode=release" . if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)} @@ -245,6 +301,10 @@ function distZip() { checkEnv try { if ($($args.count) -eq 0) { + buildCPU + buildCUDA12 + buildCUDA13 + buildROCm buildOllama buildApp gatherDependencies diff --git a/server/routes.go b/server/routes.go index e6e4e2c47..ac4df4a46 100644 --- a/server/routes.go +++ b/server/routes.go @@ -36,6 +36,7 @@ import ( "github.com/ollama/ollama/llm" "github.com/ollama/ollama/logutil" "github.com/ollama/ollama/openai" + "github.com/ollama/ollama/parser" "github.com/ollama/ollama/server/internal/client/ollama" "github.com/ollama/ollama/server/internal/registry" "github.com/ollama/ollama/template" @@ -46,18 +47,6 @@ import ( "github.com/ollama/ollama/version" ) -func shouldUseHarmony(model *Model) bool { - if slices.Contains([]string{"gptoss", "gpt-oss"}, model.Config.ModelFamily) { - // heuristic to check whether the template expects to be parsed via harmony: - // search for harmony tags that are nearly always used - if model.Template.Contains("<|start|>") && model.Template.Contains("<|end|>") { - return true - } - } - - return false -} - func experimentEnabled(name string) bool { return slices.Contains(strings.Split(os.Getenv("OLLAMA_EXPERIMENT"), ","), name) } @@ -207,13 +196,17 @@ func (s *Server) GenerateHandler(c *gin.Context) { return } - useHarmony := shouldUseHarmony(m) && !req.Raw - var harmonyMessageHandler *harmony.HarmonyMessageHandler - var harmonyToolParser *harmony.HarmonyToolCallAccumulator + useHarmony := harmony.ShouldUseHarmony(m.Config.ModelFamily, m.Template) && !req.Raw + var parserType parser.TokenParserType if useHarmony { - harmonyMessageHandler = harmony.NewHarmonyMessageHandler() - harmonyMessageHandler.HarmonyParser.AddImplicitStart() - harmonyToolParser = harmonyMessageHandler.CreateToolParser() + parserType = parser.TokenParserTypeHarmony + } else { + parserType = parser.TokenParserTypeDefault + } + var functionNameMap *harmony.FunctionNameMap + + if useHarmony { + functionNameMap = harmony.NewFunctionNameMap() } // Validate Think value: string values currently only allowed for gptoss models @@ -357,16 +350,19 @@ func (s *Server) GenerateHandler(c *gin.Context) { var sb strings.Builder defer close(ch) if err := r.Completion(c.Request.Context(), llm.CompletionRequest{ - Prompt: prompt, - Images: images, - Format: req.Format, - Options: opts, + Prompt: prompt, + Images: images, + Format: req.Format, + Options: opts, + ParserType: parserType, }, func(cr llm.CompletionResponse) { res := api.GenerateResponse{ Model: req.Model, CreatedAt: time.Now().UTC(), Response: cr.Content, Done: cr.Done, + Thinking: cr.Thinking, + ToolCalls: cr.ToolCalls, Metrics: api.Metrics{ PromptEvalCount: cr.PromptEvalCount, PromptEvalDuration: cr.PromptEvalDuration, @@ -375,12 +371,22 @@ func (s *Server) GenerateHandler(c 
*gin.Context) { }, } + if res.Done { + res.DoneReason = cr.DoneReason.String() + res.TotalDuration = time.Since(checkpointStart) + res.LoadDuration = checkpointLoaded.Sub(checkpointStart) + } + if useHarmony { - content, thinking, toolContent := harmonyMessageHandler.AddContent(cr.Content, harmonyToolParser) - res.Response = content - res.Thinking = thinking - harmonyToolParser.Add(toolContent) - } else if thinkingState != nil { + for i, tool := range res.ToolCalls { + res.ToolCalls[i].Function.Name = functionNameMap.OriginalFromConverted(tool.Function.Name) + } + if res.Response != "" || res.Thinking != "" || len(res.ToolCalls) > 0 || res.Done { + ch <- res + } + return + } + if thinkingState != nil { thinking, content := thinkingState.AddContent(cr.Content) res.Thinking = thinking res.Response = content @@ -391,30 +397,6 @@ func (s *Server) GenerateHandler(c *gin.Context) { } if cr.Done { - if useHarmony { - toolName, toolContent := harmonyToolParser.Drain() - if toolName != nil { - *toolName = strings.TrimPrefix(*toolName, "functions.") - var args api.ToolCallFunctionArguments - if err := json.Unmarshal([]byte(toolContent), &args); err != nil { - errStr := fmt.Sprintf("error parsing tool call: raw='%s', err=%s", toolContent, err.Error()) - ch <- gin.H{"error": errStr} - return - } - - res.ToolCalls = append(res.ToolCalls, api.ToolCall{ - Function: api.ToolCallFunction{ - Name: *toolName, - Arguments: args, - }, - }) - } - } - - res.DoneReason = cr.DoneReason.String() - res.TotalDuration = time.Since(checkpointStart) - res.LoadDuration = checkpointLoaded.Sub(checkpointStart) - if !req.Raw { tokens, err := r.Tokenize(c.Request.Context(), prompt+sb.String()) if err != nil { @@ -1616,27 +1598,27 @@ func (s *Server) ChatHandler(c *gin.Context) { } msgs = filterThinkTags(msgs, m) - var harmonyMessageHandler *harmony.HarmonyMessageHandler - var harmonyToolParser *harmony.HarmonyToolCallAccumulator - - useHarmony := shouldUseHarmony(m) + useHarmony := harmony.ShouldUseHarmony(m.Config.ModelFamily, m.Template) + var parserType parser.TokenParserType + if useHarmony { + parserType = parser.TokenParserTypeHarmony + } else { + parserType = parser.TokenParserTypeDefault + } processedTools := req.Tools + var functionNameMap *harmony.FunctionNameMap + var prefillString string + // TODO(parthsareen): this can be abstracted to not be model specific and potentially moved to the runner if useHarmony { - harmonyMessageHandler = harmony.NewHarmonyMessageHandler() - var lastMessage *api.Message - if len(msgs) > 0 { - lastMessage = &msgs[len(msgs)-1] - } - harmonyMessageHandler.HarmonyParser.AddImplicitStartOrPrefill(lastMessage) - harmonyToolParser = harmonyMessageHandler.CreateToolParser() - + prefillString = harmony.Prefill(msgs[len(msgs)-1]) + functionNameMap = harmony.NewFunctionNameMap() // make a copy of tools to pass to the chat prompt. Function names may be // renamed to be valid Harmony function names. 
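The handlers now own the name mapping directly: tool names are converted before templating and mapped back as tool calls stream out of the runner. A small sketch of that round-trip (the example tool name is made up):

```go
package main

import (
	"fmt"

	"github.com/ollama/ollama/harmony"
)

// Tool names are rewritten to Harmony-safe names before prompting, then the
// original names are restored on the tool calls returned by the runner.
func main() {
	m := harmony.NewFunctionNameMap()

	safe := m.ConvertAndAdd("get weather.v2") // rewritten to a valid Harmony function name
	fmt.Println("prompt uses:", safe)

	// when a tool call for `safe` comes back, report the caller's original name
	fmt.Println("response reports:", m.OriginalFromConverted(safe))
}
```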
         processedTools = make([]api.Tool, len(req.Tools))
         copy(processedTools, req.Tools)
         for i, tool := range processedTools {
-            processedTools[i].Function.Name = harmonyMessageHandler.FunctionNameMap.ConvertAndAdd(tool.Function.Name)
+            processedTools[i].Function.Name = functionNameMap.ConvertAndAdd(tool.Function.Name)
         }
     }
 
@@ -1689,15 +1671,17 @@ func (s *Server) ChatHandler(c *gin.Context) {
         defer close(ch)
 
         if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
-            Prompt:  prompt,
-            Images:  images,
-            Format:  req.Format,
-            Options: opts,
+            Prompt:        prompt,
+            Images:        images,
+            Format:        req.Format,
+            Options:       opts,
+            ParserType:    parserType,
+            PrefillString: prefillString,
         }, func(r llm.CompletionResponse) {
             res := api.ChatResponse{
                 Model:     req.Model,
                 CreatedAt: time.Now().UTC(),
-                Message:   api.Message{Role: "assistant", Content: r.Content},
+                Message:   api.Message{Role: "assistant", Content: r.Content, Thinking: r.Thinking, ToolCalls: r.ToolCalls},
                 Done:      r.Done,
                 Metrics: api.Metrics{
                     PromptEvalCount: r.PromptEvalCount,
@@ -1713,31 +1697,13 @@ func (s *Server) ChatHandler(c *gin.Context) {
             }
 
             if useHarmony {
-                content, thinking, toolContent := harmonyMessageHandler.AddContent(r.Content, harmonyToolParser)
-                res.Message.Content = content
-                res.Message.Thinking = thinking
-                harmonyToolParser.Add(toolContent)
-
-                if r.Done {
-                    toolName, toolContent := harmonyToolParser.Drain()
-                    if toolName != nil {
-                        *toolName = strings.TrimPrefix(*toolName, "functions.")
-                        *toolName = harmonyMessageHandler.FunctionNameMap.OriginalFromConverted(*toolName)
-                        var args api.ToolCallFunctionArguments
-                        if err := json.Unmarshal([]byte(toolContent), &args); err != nil {
-                            errStr := fmt.Sprintf("error parsing tool call: raw='%s', err=%s", toolContent, err.Error())
-                            ch <- gin.H{"error": errStr}
-                            return
-                        }
-                        res.Message.ToolCalls = []api.ToolCall{{Function: api.ToolCallFunction{Name: *toolName, Arguments: args}}}
-                    }
+                for i, tool := range res.Message.ToolCalls {
+                    res.Message.ToolCalls[i].Function.Name = functionNameMap.OriginalFromConverted(tool.Function.Name)
                 }
-                // only send messages with meaningful content (empty messages confuse clients)
                 if res.Message.Content != "" || res.Message.Thinking != "" || len(res.Message.ToolCalls) > 0 || res.Done {
                     ch <- res
                 }
-
                 return
             }
diff --git a/server/routes_harmony_streaming_test.go b/server/routes_harmony_streaming_test.go
index b1ede4e39..bcb020886 100644
--- a/server/routes_harmony_streaming_test.go
+++ b/server/routes_harmony_streaming_test.go
@@ -7,7 +7,6 @@ import (
     "bytes"
     "context"
     "encoding/json"
-    "net/http"
     "strings"
     "testing"
     "time"
@@ -118,7 +117,7 @@ func TestChatHarmonyParserStreamingRealtime(t *testing.T) {
             name: "content streams as it arrives",
             steps: []step{
                 {
-                    input:       llm.CompletionResponse{Content: "<|message|>Hello", Done: false},
+                    input:       llm.CompletionResponse{Content: "Hello", Done: false},
                     wantContent: "Hello",
                 },
                 {
@@ -126,7 +125,7 @@ func TestChatHarmonyParserStreamingRealtime(t *testing.T) {
                     wantContent: ", world",
                 },
                 {
-                    input:       llm.CompletionResponse{Content: "!<|end|>", Done: true, DoneReason: llm.DoneReasonStop},
+                    input:       llm.CompletionResponse{Content: "!", Done: true, DoneReason: llm.DoneReasonStop},
                     wantContent: "!",
                 },
             },
@@ -135,20 +134,15 @@ func TestChatHarmonyParserStreamingRealtime(t *testing.T) {
             name: "thinking streams separately from content",
             steps: []step{
                 {
-                    input:        llm.CompletionResponse{Content: "<|channel|>analysis<|message|>Thinking...", Done: false},
+                    input:        llm.CompletionResponse{Thinking: "Thinking...", Done: false},
wantThinking: "Thinking...", }, { - input: llm.CompletionResponse{Content: "<|end|>", Done: false}, - // No output expected - just closes the analysis message and resets state to normal + input: llm.CompletionResponse{Content: "Answer", Done: false}, + wantContent: "Answer", }, { - input: llm.CompletionResponse{Content: "<|start|>assistant<|message|>Answer", Done: false}, - wantContent: "Answer", // After message end, state is reset to normal - }, - { - input: llm.CompletionResponse{Content: "<|end|>", Done: true, DoneReason: llm.DoneReasonStop}, - // No output expected - just closes the assistant message + input: llm.CompletionResponse{Done: true, DoneReason: llm.DoneReasonStop}, }, }, }, @@ -156,24 +150,16 @@ func TestChatHarmonyParserStreamingRealtime(t *testing.T) { name: "partial tags buffer until complete", steps: []step{ { - input: llm.CompletionResponse{Content: "<|chan", Done: false}, - // No output - partial tag - }, - { - input: llm.CompletionResponse{Content: "nel|>analysis<|mess", Done: false}, - // No output - still building tags - }, - { - input: llm.CompletionResponse{Content: "age|>Deep ", Done: false}, + input: llm.CompletionResponse{Thinking: "Deep ", Done: false}, wantThinking: "Deep ", }, { - input: llm.CompletionResponse{Content: "thought<|end|>", Done: false}, + input: llm.CompletionResponse{Thinking: "thought", Done: false}, wantThinking: "thought", }, { - input: llm.CompletionResponse{Content: "<|start|>assistant<|message|>Done<|end|>", Done: true, DoneReason: llm.DoneReasonStop}, - wantContent: "Done", // After message end, state is reset to normal + input: llm.CompletionResponse{Content: "Done", Done: true, DoneReason: llm.DoneReasonStop}, + wantContent: "Done", }, }, }, @@ -181,7 +167,7 @@ func TestChatHarmonyParserStreamingRealtime(t *testing.T) { name: "simple assistant after analysis", steps: []step{ { - input: llm.CompletionResponse{Content: "<|channel|>analysis<|message|>Think<|end|><|start|>assistant<|message|>Answer<|end|>", Done: true, DoneReason: llm.DoneReasonStop}, + input: llm.CompletionResponse{Thinking: "Think", Content: "Answer", Done: true, DoneReason: llm.DoneReasonStop}, wantContent: "Answer", wantThinking: "Think", }, @@ -191,7 +177,7 @@ func TestChatHarmonyParserStreamingRealtime(t *testing.T) { name: "tool call parsed and returned correctly", steps: []step{ { - input: llm.CompletionResponse{Content: "<|channel|>commentary to=functions.get_weather<|message|>{\"location\":\"San Francisco\"}<|end|><|start|>assistant<|message|>The weather is sunny<|end|>", Done: true, DoneReason: llm.DoneReasonStop}, + input: llm.CompletionResponse{Content: "The weather is sunny", ToolCalls: []api.ToolCall{{Function: api.ToolCallFunction{Name: "get_weather", Arguments: api.ToolCallFunctionArguments{"location": "San Francisco"}}}}, Done: true, DoneReason: llm.DoneReasonStop}, wantContent: "The weather is sunny", wantToolCalls: []api.ToolCall{ { @@ -210,15 +196,10 @@ func TestChatHarmonyParserStreamingRealtime(t *testing.T) { name: "tool call with streaming JSON across chunks", steps: []step{ { - input: llm.CompletionResponse{Content: "<|channel|>commentary to=functions.calculate<|message|>{\"expr", Done: false}, - // No output yet - incomplete JSON + input: llm.CompletionResponse{Done: false}, }, { - input: llm.CompletionResponse{Content: "ession\":\"2+", Done: false}, - // Still no output - incomplete JSON - }, - { - input: llm.CompletionResponse{Content: "2\"}", Done: true}, + input: llm.CompletionResponse{ToolCalls: []api.ToolCall{{Function: 
api.ToolCallFunction{Name: "calculate", Arguments: api.ToolCallFunctionArguments{"expression": "2+2"}}}}, Done: true}, wantToolCalls: []api.ToolCall{ { Function: api.ToolCallFunction{ @@ -400,9 +381,9 @@ func TestChatHarmonyParserStreamingSimple(t *testing.T) { gin.SetMode(gin.TestMode) mockResponses := []llm.CompletionResponse{ - {Content: "<|message|>First ", Done: false}, + {Content: "First ", Done: false}, {Content: "chunk ", Done: false}, - {Content: "here<|end|>", Done: true, DoneReason: llm.DoneReasonStop}, + {Content: "here", Done: true, DoneReason: llm.DoneReasonStop}, } mock := mockRunner{ @@ -507,189 +488,3 @@ func TestChatHarmonyParserStreamingSimple(t *testing.T) { t.Errorf("expected at least 2 content chunks for streaming, got %d", contentChunks) } } - -func TestChatHarmonyParserStreaming(t *testing.T) { - gin.SetMode(gin.TestMode) - - type expectedChunk struct { - afterResponse int // Which mock response this chunk should appear after - content string // Expected content in this chunk - thinking string // Expected thinking in this chunk - } - - testCases := []struct { - name string - mockResponses []llm.CompletionResponse - expectedChunks []expectedChunk - wantContent string - wantThinking string - }{ - { - name: "simple message without thinking", - mockResponses: []llm.CompletionResponse{ - {Content: "<|start|>assistant<|message|>Hello, ", Done: false}, - {Content: "how can I help?", Done: false}, - {Content: "<|end|>", Done: true, DoneReason: llm.DoneReasonStop}, - }, - expectedChunks: []expectedChunk{ - {afterResponse: 1, content: "Hello, "}, - {afterResponse: 2, content: "how can I help?"}, - }, - wantContent: "Hello, how can I help?", - }, - { - name: "message with analysis channel for thinking", - mockResponses: []llm.CompletionResponse{ - {Content: "<|channel|>analysis<|message|>", Done: false}, - {Content: "Let me think ", Done: false}, - {Content: "about this problem...", Done: false}, - {Content: "<|end|>", Done: false}, - {Content: "<|start|>assistant<|message|>", Done: false}, - {Content: "The answer ", Done: false}, - {Content: "is 42", Done: false}, - {Content: "<|end|>", Done: true, DoneReason: llm.DoneReasonStop}, - }, - expectedChunks: []expectedChunk{ - {afterResponse: 2, thinking: "Let me think "}, - {afterResponse: 3, thinking: "about this problem..."}, - {afterResponse: 6, content: "The answer "}, - {afterResponse: 7, content: "is 42"}, - }, - wantContent: "The answer is 42", - wantThinking: "Let me think about this problem...", - }, - { - name: "streaming with partial tags across boundaries", - mockResponses: []llm.CompletionResponse{ - {Content: "<|chan", Done: false}, - {Content: "nel|>analy", Done: false}, - {Content: "sis<|mess", Done: false}, - {Content: "age|>Think", Done: false}, - {Content: "ing deeply...<|end|>", Done: false}, - {Content: "<|start|>assi", Done: false}, - {Content: "stant<|message|>Result ", Done: false}, - {Content: "computed<|e", Done: false}, - {Content: "nd|>", Done: true, DoneReason: llm.DoneReasonStop}, - }, - expectedChunks: []expectedChunk{ - {afterResponse: 4, thinking: "Think"}, - {afterResponse: 5, thinking: "ing deeply..."}, - {afterResponse: 7, content: "Result "}, - {afterResponse: 8, content: "computed"}, - }, - wantContent: "Result computed", - wantThinking: "Thinking deeply...", - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - // Channel to synchronize mock responses with chunk verification - responsesSent := make(chan int, len(tc.mockResponses)) - - mock := mockRunner{ - 
-                CompletionFn: func(ctx context.Context, r llm.CompletionRequest, fn func(llm.CompletionResponse)) error {
-                    // Send mock responses one at a time, notifying when each is sent
-                    for i, resp := range tc.mockResponses {
-                        fn(resp)
-                        responsesSent <- i + 1
-                    }
-                    close(responsesSent)
-                    return nil
-                },
-            }
-
-            s := Server{
-                sched: &Scheduler{
-                    pendingReqCh:  make(chan *LlmRequest, 1),
-                    finishedReqCh: make(chan *LlmRequest, 1),
-                    expiredCh:     make(chan *runnerRef, 1),
-                    unloadedCh:    make(chan any, 1),
-                    loaded:        make(map[string]*runnerRef),
-                    newServerFn:   newMockServer(&mock),
-                    getGpuFn:      discover.GetGPUInfo,
-                    getCpuFn:      discover.GetCPUInfo,
-                    reschedDelay:  250 * time.Millisecond,
-                    loadFn: func(req *LlmRequest, _ *ggml.GGML, _ discover.GpuInfoList, _ bool) bool {
-                        req.successCh <- &runnerRef{
-                            llama: &mock,
-                        }
-                        return false
-                    },
-                },
-            }
-
-            go s.sched.Run(t.Context())
-
-            // Create a minimal model
-            _, digest := createHarmonyTestModel(t)
-
-            // Create model with passthrough template
-            stream := false
-            w := createRequest(t, s.CreateHandler, api.CreateRequest{
-                Model:    "harmony-test",
-                Files:    map[string]string{"file.gguf": digest},
-                Template: `<|start|><|end|>{{ with .Tools }}{{ end }}{{ .Prompt }}`,
-                Stream:   &stream,
-            })
-
-            if w.Code != http.StatusOK {
-                t.Fatalf("failed to create model: %d", w.Code)
-            }
-
-            // Test chat endpoint with streaming
-            streamTrue := true
-            w = createRequest(t, s.ChatHandler, api.ChatRequest{
-                Model:    "harmony-test",
-                Messages: []api.Message{{Role: "user", Content: "Hello"}},
-                Stream:   &streamTrue,
-                Tools:    getTestTools(),
-            })
-
-            if w.Code != http.StatusOK {
-                t.Fatalf("chat request failed: %d - %s", w.Code, w.Body.String())
-            }
-
-            // Parse streaming response
-            var chunks []api.ChatResponse
-            var content, thinking strings.Builder
-
-            decoder := json.NewDecoder(w.Body)
-            for decoder.More() {
-                var chunk api.ChatResponse
-                if err := decoder.Decode(&chunk); err != nil {
-                    t.Fatalf("failed to decode chunk: %v", err)
-                }
-                chunks = append(chunks, chunk)
-
-                // Accumulate content and thinking from each chunk
-                content.WriteString(chunk.Message.Content)
-                thinking.WriteString(chunk.Message.Thinking)
-
-                // Debug output
-                t.Logf("Chunk %d: content=%q thinking=%q done=%v", len(chunks), chunk.Message.Content, chunk.Message.Thinking, chunk.Done)
-            }
-
-            // Verify we got streaming chunks
-            if len(chunks) == 0 {
-                t.Fatal("expected streaming chunks, got none")
-            }
-
-            gotContent := content.String()
-            gotThinking := thinking.String()
-
-            if gotContent != tc.wantContent {
-                t.Errorf("content mismatch: got %q, want %q", gotContent, tc.wantContent)
-            }
-            if gotThinking != tc.wantThinking {
-                t.Errorf("thinking mismatch: got %q, want %q", gotThinking, tc.wantThinking)
-            }
-
-            // Verify last chunk has done=true
-            lastChunk := chunks[len(chunks)-1]
-            if !lastChunk.Done {
-                t.Error("expected last chunk to have done=true")
-            }
-        })
-    }
-}
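Note on the tool-name handling in the routes.go changes above: the server converts tool names with `functionNameMap.ConvertAndAdd` before templating the request, and maps the runner-parsed tool calls back with `OriginalFromConverted` on the way out. The sketch below is only meant to illustrate that round trip; the `nameMap` type, its sanitization rule, and the collision handling are assumptions for illustration, not the actual `harmony.FunctionNameMap` implementation.

```go
package main

import (
	"fmt"
	"regexp"
	"strings"
)

// nameMap is a hypothetical stand-in for the round-trip mapping used above
// (ConvertAndAdd on the request path, OriginalFromConverted on the response path).
type nameMap struct {
	toOriginal map[string]string
}

// invalid matches characters assumed not to be acceptable in a Harmony function name.
var invalid = regexp.MustCompile(`[^a-zA-Z0-9_-]`)

func newNameMap() *nameMap {
	return &nameMap{toOriginal: map[string]string{}}
}

// ConvertAndAdd rewrites a caller-supplied tool name into a sanitized form and
// records how to recover the original later in the same request.
func (m *nameMap) ConvertAndAdd(original string) string {
	converted := invalid.ReplaceAllString(original, "_")
	if converted == "" {
		converted = "fn"
	}
	// Disambiguate collisions such as "get.weather" vs "get weather".
	base, i := converted, 2
	for {
		if prev, ok := m.toOriginal[converted]; !ok || prev == original {
			break
		}
		converted = fmt.Sprintf("%s_%d", base, i)
		i++
	}
	m.toOriginal[converted] = original
	return converted
}

// OriginalFromConverted restores the caller-facing name for a parsed tool call.
func (m *nameMap) OriginalFromConverted(converted string) string {
	converted = strings.TrimPrefix(converted, "functions.")
	if original, ok := m.toOriginal[converted]; ok {
		return original
	}
	return converted
}

func main() {
	m := newNameMap()
	sent := m.ConvertAndAdd("web.search/v1")
	fmt.Println(sent, "->", m.OriginalFromConverted(sent)) // round-trips back to "web.search/v1"
}
```

The only property the handlers above depend on is that the two methods are inverses for every name registered during a single request; everything else about the mapping is an internal detail of the harmony package.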
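A second note on the streaming changes: because the runner now returns `llm.CompletionResponse` chunks with `Thinking` and `ToolCalls` already parsed, some chunks carry nothing user-visible (for example when the parser consumed only control tokens), and the Harmony branch drops them so clients never receive empty messages while still always delivering the final `Done` chunk. The sketch below restates that gating condition with a simplified stand-in struct; it is not the real `llm.CompletionResponse` type.

```go
package main

import "fmt"

// chunk mirrors just the fields of the completion response that the handlers
// above inspect; it is a simplified stand-in for illustration.
type chunk struct {
	Content   string
	Thinking  string
	ToolCalls []string
	Done      bool
}

// shouldForward reproduces the condition used in the Harmony branch of
// GenerateHandler/ChatHandler: forward a chunk only if it carries content,
// thinking, or tool calls, or if it is the final Done chunk.
func shouldForward(c chunk) bool {
	return c.Content != "" || c.Thinking != "" || len(c.ToolCalls) > 0 || c.Done
}

func main() {
	stream := []chunk{
		{Thinking: "considering..."},
		{},                 // nothing visible; dropped
		{Content: "Hello"},
		{Done: true},       // final chunk always delivered
	}
	for i, c := range stream {
		fmt.Printf("chunk %d forward=%v\n", i, shouldForward(c))
	}
}
```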