Merge remote-tracking branch 'upstream/main' into vulkanV3

This commit is contained in:
Inforithmics 2025-09-11 18:30:21 +02:00
commit 69ed26c93b
25 changed files with 918 additions and 428 deletions

View File

@ -67,12 +67,21 @@ jobs:
install: https://developer.download.nvidia.com/compute/cuda/12.8.0/local_installers/cuda_12.8.0_571.96_windows.exe
cuda-version: '12.8'
flags: ''
runner_dir: 'cuda_v12'
- os: windows
arch: amd64
preset: 'CUDA 13'
install: https://developer.download.nvidia.com/compute/cuda/13.0.0/local_installers/cuda_13.0.0_windows.exe
cuda-version: '13.0'
flags: ''
runner_dir: 'cuda_v13'
- os: windows
arch: amd64
preset: 'ROCm 6'
install: https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-24.Q4-WinSvr2022-For-HIP.exe
rocm-version: '6.2'
flags: '-DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_C_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma" -DCMAKE_CXX_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma"'
runner_dir: ''
runs-on: ${{ matrix.arch == 'arm64' && format('{0}-{1}', matrix.os, matrix.arch) || matrix.os }}
environment: release
env:
@ -138,7 +147,7 @@ jobs:
run: |
Import-Module 'C:\Program Files\Microsoft Visual Studio\2022\Enterprise\Common7\Tools\Microsoft.VisualStudio.DevShell.dll'
Enter-VsDevShell -VsInstallPath 'C:\Program Files\Microsoft Visual Studio\2022\Enterprise' -SkipAutomaticLocation -DevCmdArguments '-arch=x64 -no_logo'
cmake --preset "${{ matrix.preset }}" ${{ matrix.flags }}
cmake --preset "${{ matrix.preset }}" ${{ matrix.flags }} -DOLLAMA_RUNNER_DIR="${{ matrix.runner_dir }}"
cmake --build --parallel --preset "${{ matrix.preset }}"
cmake --install build --component "${{ startsWith(matrix.preset, 'CUDA ') && 'CUDA' || startsWith(matrix.preset, 'ROCm ') && 'HIP' || 'CPU' }}" --strip --parallel 8
env:
@ -232,7 +241,7 @@ jobs:
case "$COMPONENT" in
bin/ollama) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
lib/ollama/*.so*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
lib/ollama/cuda_sbsa) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
lib/ollama/cuda_v*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
lib/ollama/cuda_jetpack5) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-jetpack5.tar.in ;;
lib/ollama/cuda_jetpack6) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-jetpack6.tar.in ;;
lib/ollama/rocm) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-rocm.tar.in ;;

View File

@ -46,7 +46,7 @@ jobs:
include:
- preset: CPU
- preset: CUDA
container: nvidia/cuda:12.8.1-devel-ubuntu22.04
container: nvidia/cuda:13.0.0-devel-ubuntu22.04
flags: '-DCMAKE_CUDA_ARCHITECTURES=87'
- preset: ROCm
container: rocm/dev-ubuntu-22.04:6.1.2
@ -78,7 +78,7 @@ jobs:
include:
- preset: CPU
- preset: CUDA
install: https://developer.download.nvidia.com/compute/cuda/12.8.0/local_installers/cuda_12.8.0_571.96_windows.exe
install: https://developer.download.nvidia.com/compute/cuda/13.0.0/local_installers/cuda_13.0.0_windows.exe
flags: '-DCMAKE_CUDA_ARCHITECTURES=80'
- preset: ROCm
install: https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-24.Q4-WinSvr2022-For-HIP.exe
@ -102,7 +102,7 @@ jobs:
$ErrorActionPreference = "Stop"
if ("${{ steps.cache-install.outputs.cache-hit }}" -ne 'true') {
Invoke-WebRequest -Uri "${{ matrix.install }}" -OutFile "install.exe"
Start-Process -FilePath .\install.exe -ArgumentList (@("-s", "cudart_12.8", "nvcc_12.8", "cublas_12.8", "cublas_dev_12.8")) -NoNewWindow -Wait
Start-Process -FilePath .\install.exe -ArgumentList (@("-s", "cudart_13.0", "nvcc_13.0", "cublas_13.0", "cublas_dev_13.0")) -NoNewWindow -Wait
}
$cudaPath = (Resolve-Path "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\*").path

View File

@ -25,7 +25,7 @@ set(GGML_LLAMAFILE ON)
set(GGML_CUDA_PEER_MAX_BATCH_SIZE 128)
set(GGML_CUDA_GRAPHS ON)
set(GGML_CUDA_FA ON)
set(GGML_CUDA_COMPRESSION_MODE default)
set(GGML_CUDA_COMPRESSION_MODE size)
if((CMAKE_OSX_ARCHITECTURES AND NOT CMAKE_OSX_ARCHITECTURES MATCHES "arm64")
OR (NOT CMAKE_OSX_ARCHITECTURES AND NOT CMAKE_SYSTEM_PROCESSOR MATCHES "arm|aarch64|ARM64|ARMv[0-9]+"))
@ -38,7 +38,7 @@ if (CMAKE_OSX_ARCHITECTURES MATCHES "x86_64")
endif()
set(OLLAMA_BUILD_DIR ${CMAKE_BINARY_DIR}/lib/ollama)
set(OLLAMA_INSTALL_DIR ${CMAKE_INSTALL_PREFIX}/lib/ollama)
set(OLLAMA_INSTALL_DIR ${CMAKE_INSTALL_PREFIX}/lib/ollama/${OLLAMA_RUNNER_DIR})
set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${OLLAMA_BUILD_DIR})
set(CMAKE_RUNTIME_OUTPUT_DIRECTORY_DEBUG ${OLLAMA_BUILD_DIR})
@ -81,7 +81,7 @@ if(CMAKE_CUDA_COMPILER)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-cuda)
install(TARGETS ggml-cuda
RUNTIME_DEPENDENCIES
DIRECTORIES ${CUDAToolkit_BIN_DIR} ${CUDAToolkit_LIBRARY_DIR}
DIRECTORIES ${CUDAToolkit_BIN_DIR} ${CUDAToolkit_BIN_DIR}/x64 ${CUDAToolkit_LIBRARY_DIR}
PRE_INCLUDE_REGEXES cublas cublasLt cudart
PRE_EXCLUDE_REGEXES ".*"
RUNTIME DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT CUDA

View File

@ -18,6 +18,14 @@
"name": "CUDA",
"inherits": [ "Default" ]
},
{
"name": "CUDA 11",
"inherits": [ "CUDA" ],
"cacheVariables": {
"CMAKE_CUDA_ARCHITECTURES": "50-virtual;60-virtual;61-virtual;70-virtual;75-virtual;80-virtual;86-virtual;87-virtual;89-virtual;90-virtual",
"CMAKE_CUDA_FLAGS": "-Wno-deprecated-gpu-targets -t 2"
}
},
{
"name": "CUDA 12",
"inherits": [ "CUDA" ],
@ -26,6 +34,14 @@
"CMAKE_CUDA_FLAGS": "-Wno-deprecated-gpu-targets -t 2"
}
},
{
"name": "CUDA 13",
"inherits": [ "CUDA" ],
"cacheVariables": {
"CMAKE_CUDA_ARCHITECTURES": "75-virtual;80-virtual;86-virtual;87-virtual;89-virtual;90-virtual;90a-virtual;100-virtual;110-virtual;120-virtual;121-virtual",
"CMAKE_CUDA_FLAGS": "-t 2"
}
},
{
"name": "JetPack 5",
"inherits": [ "CUDA" ],
@ -76,11 +92,21 @@
"configurePreset": "CUDA",
"targets": [ "ggml-cuda" ]
},
{
"name": "CUDA 11",
"inherits": [ "CUDA" ],
"configurePreset": "CUDA 11"
},
{
"name": "CUDA 12",
"inherits": [ "CUDA" ],
"configurePreset": "CUDA 12"
},
{
"name": "CUDA 13",
"inherits": [ "CUDA" ],
"configurePreset": "CUDA 13"
},
{
"name": "JetPack 5",
"inherits": [ "CUDA" ],

View File

@ -50,15 +50,35 @@ RUN --mount=type=cache,target=/root/.ccache \
&& cmake --build --parallel --preset 'CPU' \
&& cmake --install build --component CPU --strip --parallel 8
FROM base AS cuda-11
ARG CUDA11VERSION=11.8
RUN dnf install -y cuda-toolkit-${CUDA11VERSION//./-}
ENV PATH=/usr/local/cuda-11/bin:$PATH
RUN --mount=type=cache,target=/root/.ccache \
cmake --preset 'CUDA 11' -DOLLAMA_RUNNER_DIR="cuda_v11" \
&& cmake --build --parallel --preset 'CUDA 11' \
&& cmake --install build --component CUDA --strip --parallel 8
FROM base AS cuda-12
ARG CUDA12VERSION=12.8
RUN dnf install -y cuda-toolkit-${CUDA12VERSION//./-}
ENV PATH=/usr/local/cuda-12/bin:$PATH
RUN --mount=type=cache,target=/root/.ccache \
cmake --preset 'CUDA 12' \
cmake --preset 'CUDA 12' -DOLLAMA_RUNNER_DIR="cuda_v12"\
&& cmake --build --parallel --preset 'CUDA 12' \
&& cmake --install build --component CUDA --strip --parallel 8
FROM base AS cuda-13
ARG CUDA13VERSION=13.0
RUN dnf install -y cuda-toolkit-${CUDA13VERSION//./-}
ENV PATH=/usr/local/cuda-13/bin:$PATH
RUN --mount=type=cache,target=/root/.ccache \
cmake --preset 'CUDA 13' -DOLLAMA_RUNNER_DIR="cuda_v13" \
&& cmake --build --parallel --preset 'CUDA 13' \
&& cmake --install build --component CUDA --strip --parallel 8
FROM base AS rocm-6
ENV PATH=/opt/rocm/hcc/bin:/opt/rocm/hip/bin:/opt/rocm/bin:/opt/rocm/hcc/bin:$PATH
RUN --mount=type=cache,target=/root/.ccache \
@ -110,11 +130,15 @@ RUN --mount=type=cache,target=/root/.cache/go-build \
go build -trimpath -buildmode=pie -o /bin/ollama .
FROM --platform=linux/amd64 scratch AS amd64
COPY --from=cuda-12 dist/lib/ollama /lib/ollama
# COPY --from=cuda-11 dist/lib/ollama/ /lib/ollama/
COPY --from=cuda-12 dist/lib/ollama /lib/ollama/
COPY --from=cuda-13 dist/lib/ollama/ /lib/ollama/
COPY --from=vulkan dist/lib/ollama/vulkan /lib/ollama/vulkan
FROM --platform=linux/arm64 scratch AS arm64
COPY --from=cuda-12 dist/lib/ollama /lib/ollama/cuda_sbsa
# COPY --from=cuda-11 dist/lib/ollama/ /lib/ollama/
COPY --from=cuda-12 dist/lib/ollama /lib/ollama/
COPY --from=cuda-13 dist/lib/ollama/ /lib/ollama/
COPY --from=jetpack-5 dist/lib/ollama /lib/ollama/cuda_jetpack5
COPY --from=jetpack-6 dist/lib/ollama /lib/ollama/cuda_jetpack6

View File

@ -413,6 +413,8 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [Mayan EDMS](https://gitlab.com/mayan-edms/mayan-edms) (Open source document management system to organize, tag, search, and automate your files with powerful Ollama driven workflows.)
- [Serene Pub](https://github.com/doolijb/serene-pub) (Beginner friendly, open source AI Roleplaying App for Windows, Mac OS and Linux. Search, download and use models with Ollama all inside the app.)
- [Andes](https://github.com/aqerd/andes) (A Visual Studio Code extension that provides a local UI interface for Ollama models)
- [Clueless](https://github.com/KashyapTan/clueless) (Open Source & Local Cluely: A desktop application LLM assistant to help you talk to anything on your screen using locally served Ollama models. Also undetectable to screenshare)
- [ollama-co2](https://github.com/carbonatedWaterOrg/ollama-co2) (FastAPI web interface for monitoring and managing local and remote Ollama servers with real-time model monitoring and concurrent downloads)
### Cloud

View File

@ -43,14 +43,15 @@ func cudaVariant(gpuInfo CudaGPUInfo) string {
}
}
}
return "sbsa"
}
// driver 12.0 has problems with the cuda v12 library, so run v11 on those older drivers
if gpuInfo.DriverMajor < 12 || (gpuInfo.DriverMajor == 12 && gpuInfo.DriverMinor == 0) {
// The detected driver is older than Feb 2023
slog.Warn("old CUDA driver detected - please upgrade to a newer driver", "version", fmt.Sprintf("%d.%d", gpuInfo.DriverMajor, gpuInfo.DriverMinor))
return "v11"
if gpuInfo.DriverMajor < 13 {
// The detected driver is older than 580 (Aug 2025)
// Warn if their CC is compatible with v13 and they should upgrade their driver to get better performance
if gpuInfo.computeMajor > 7 || (gpuInfo.computeMajor == 7 && gpuInfo.computeMinor >= 5) {
slog.Warn("old CUDA driver detected - please upgrade to a newer driver for best performance", "version", fmt.Sprintf("%d.%d", gpuInfo.DriverMajor, gpuInfo.DriverMinor))
}
return "v12"
}
return "v12"
return "v13"
}

View File

@ -11,12 +11,13 @@ curl -fsSL https://ollama.com/install.sh | sh
## Manual install
> [!NOTE]
> If you are upgrading from a prior version, you should remove the old libraries with `sudo rm -rf /usr/lib/ollama` first.
> If you are upgrading from a prior version, you **MUST** remove the old libraries with `sudo rm -rf /usr/lib/ollama` first.
Download and extract the package:
```shell
curl -LO https://ollama.com/download/ollama-linux-amd64.tgz
sudo rm -rf /usr/lib/ollama
sudo tar -C /usr -xzf ollama-linux-amd64.tgz
```

View File

@ -92,6 +92,9 @@ If none of those resolve the problem, gather additional information and file an
- Set `CUDA_ERROR_LEVEL=50` and try again to get more diagnostic logs
- Check dmesg for any errors `sudo dmesg | grep -i nvrm` and `sudo dmesg | grep -i nvidia`
You may get more details for initialization failures by enabling debug prints in the uvm driver. You should only use this temporarily while troubleshooting
- `sudo rmmod nvidia_uvm` then `sudo modprobe nvidia_uvm uvm_debug_prints=1`
## AMD GPU Discovery

View File

@ -57,10 +57,28 @@ func (kv KV) EmbeddingLength() uint64 {
return uint64(kv.Uint("embedding_length"))
}
func (kv KV) HeadCount() []uint64 {
headCountDefault := uint32(1)
headCount := kv.UintOrArrayValueAsArray("attention.head_count", headCountDefault)
if len(headCount) == 1 {
headCountDefault = headCount[0]
}
nLayers := int(kv.BlockCount())
if len(headCount) > nLayers {
slog.Warn("got more elements of attention.head_count than layers", "len(headCount)", len(headCount), "layers", nLayers)
}
out := make([]uint64, nLayers)
for i := range nLayers {
if i >= len(headCount) {
out[i] = uint64(headCountDefault)
} else {
out[i] = uint64(headCount[i])
}
}
return out
}
func (kv KV) HeadCountMax() uint64 {
// TODO(drifkin): using the max value can cause an overestimation. In the
// future if array values become more popular, we can adapt the more invasive
// <https://github.com/ollama/ollama/pull/10225>
return uint64(kv.UintOrMaxArrayValue("attention.head_count", 1))
}
@ -68,6 +86,27 @@ func (kv KV) HeadCountMin() uint64 {
return uint64(kv.UintOrMinArrayValue("attention.head_count", 1))
}
func (kv KV) HeadCountKV() []uint64 {
headCountKVDefault := uint32(1)
headCountKV := kv.UintOrArrayValueAsArray("attention.head_count_kv", headCountKVDefault)
if len(headCountKV) == 1 {
headCountKVDefault = headCountKV[0]
}
nLayers := int(kv.BlockCount())
if len(headCountKV) > nLayers {
slog.Warn("got more elements of attention.head_count than layers", "len(headCountKV)", len(headCountKV), "layers", nLayers)
}
out := make([]uint64, nLayers)
for i := range nLayers {
if i >= len(headCountKV) {
out[i] = uint64(headCountKVDefault)
} else {
out[i] = uint64(headCountKV[i])
}
}
return out
}
func (kv KV) HeadCountKVMax() uint64 {
return uint64(kv.UintOrMaxArrayValue("attention.head_count_kv", 1))
}
@ -100,6 +139,26 @@ func (kv KV) ChatTemplate() string {
return kv.String("tokenizer.chat_template")
}
// ssm architecture parameters
func (kv KV) SSMConvKernel() uint64 {
return uint64(kv.Uint("ssm.conv_kernel"))
}
func (kv KV) SSMInnerSize() uint64 {
return uint64(kv.Uint("ssm.inner_size"))
}
func (kv KV) SSMStateSize() uint64 {
return uint64(kv.Uint("ssm.state_size"))
}
func (kv KV) SSMGroupCount() uint64 {
return uint64(kv.Uint("ssm.group_count"))
}
// general types
func (kv KV) String(key string, defaultValue ...string) string {
val, _ := keyValue(kv, key, append(defaultValue, "")...)
return val
@ -131,22 +190,27 @@ func (kv KV) UintOrMinArrayValue(key string, defaultValue uint32) uint32 {
}
func (kv KV) UintOrArrayValue(key string, defaultValue uint32) (uint32, uint32) {
arrVal := kv.UintOrArrayValueAsArray(key, defaultValue)
return slices.Min(arrVal), slices.Max(arrVal)
}
func (kv KV) UintOrArrayValueAsArray(key string, defaultValue uint32) []uint32 {
if u32, ok := keyValue(kv, key, uint32(0)); ok {
return u32, u32
return []uint32{u32}
} else if u32s, ok := keyValue(kv, key, &array[uint32]{}); ok {
min := slices.Min(u32s.values)
max := slices.Max(u32s.values)
return min, max
return u32s.values
} else if i32s, ok := keyValue(kv, key, &array[int32]{}); ok {
min := slices.Min(i32s.values)
max := slices.Max(i32s.values)
if min < 0 || max < 0 {
slog.Warn("array values are unexpectedly negative", "key", key, "min", min, "max", max)
dst := make([]uint32, len(i32s.values))
for i, v := range i32s.values {
if v < 0 {
slog.Warn("array values are unexpectedly negative", "key", key, "i", i, "v", v)
}
dst[i] = uint32(v)
}
return uint32(min), uint32(max)
return dst
}
return defaultValue, defaultValue
return []uint32{defaultValue}
}
func (kv KV) Strings(key string, defaultValue ...[]string) []string {
@ -486,7 +550,9 @@ func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType stri
embedding := f.KV().EmbeddingLength()
heads := f.KV().HeadCountMax()
headsArr := f.KV().HeadCount()
headsKV := f.KV().HeadCountKVMax()
headsKVArr := f.KV().HeadCountKV()
vocab := uint64(f.KV()["tokenizer.ggml.tokens"].(*array[string]).size)
embeddingHeads := f.KV().EmbeddingHeadCountMax()
@ -496,12 +562,51 @@ func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType stri
layers := f.Tensors().GroupLayers()
bytesPerElement := kvCacheBytesPerElement(kvCacheType)
// Default for models unless special-cased below. These defaults mirror the
// cache usage in llama.cpp under the assumption that models without special
// cases below will use the llamarunner and caching will be handled by the
// llama.cpp layer.
//
// This also assumes that a layer without heads or headsKV set is recurrent
// which is usually the case. Some models (eg nemotronh) use "blocks" in
// place of layers where some are MLP blocks that don't have any cache.
// Models like this will need a special case below to be accurately
// estimated.
var kvTotal uint64
kv = make([]uint64, f.KV().BlockCount())
kvSizeAttn := uint64(0)
kvSizeRecurrent := uint64(0)
for i := range kv {
kv[i] = uint64(float64(context*(embeddingHeadsK+embeddingHeadsV)*headsKV) * bytesPerElement)
headsL := headsArr[i]
headsKVL := headsKVArr[i]
if headsL > 0 && headsKVL > 0 {
// full attention layer
// NOTE: Assumes uniform values for all attn layers
kv[i] = uint64(float64(context*(embeddingHeadsK+embeddingHeadsV)*headsKVL) * bytesPerElement)
kvSizeAttn += kv[i]
} else {
// recurrent layer
ssmDConv := f.KV().SSMConvKernel()
ssmDState := f.KV().SSMStateSize()
ssmDInner := f.KV().SSMInnerSize()
ssmNGroups := f.KV().SSMGroupCount()
nEmbdR := uint64(0)
if ssmDConv > 0 {
nEmbdR = (ssmDConv - 1) * (ssmDInner + 2*ssmNGroups*ssmDState)
}
nEmbdS := ssmDState * ssmDInner
// recurrent always uses F32 in llama.cpp backend
// https://github.com/ggml-org/llama.cpp/blob/master/src/llama-model.cpp#L18644
bytesPerElementRecurrent := kvCacheBytesPerElement("f32")
kv[i] = (nEmbdR + nEmbdS) * uint64(bytesPerElementRecurrent)
kvSizeRecurrent += kv[i]
}
kvTotal += kv[i]
}
slog.Debug("default cache size estimate", "attention MiB", float32(kvSizeAttn)/(1024.*1024.), "attention bytes", kvSizeAttn, "recurrent MiB", float32(kvSizeRecurrent)/(1024.*1024.), "recurrent bytes", kvSizeRecurrent)
switch f.KV().Architecture() {
case "llama", "llama4":
@ -759,12 +864,16 @@ func (llm GGML) VisionGraphSize() (weights, graphSize uint64) {
// SupportsKVCacheType checks if the requested cache type is supported
func (f GGML) SupportsKVCacheType(cacheType string) bool {
if cacheType == "" || cacheType == "f16" {
return true
}
if arch := f.KV().Architecture(); slices.Contains([]string{"gptoss", "gpt-oss"}, arch) {
// gpt-oss uses attention with sinks which does not support quantized cache types
slog.Warn("model only supports non-quantized cache types ", "mode", arch)
return cacheType == "f16"
slog.Warn("model only supports non-quantized cache types", "model", arch)
return false
}
return slices.Contains([]string{"f16", "q8_0", "q4_0"}, cacheType)
return slices.Contains([]string{"q8_0", "q4_0"}, cacheType)
}
// SupportsFlashAttention checks if the model supports flash attention
@ -774,6 +883,10 @@ func (f GGML) SupportsFlashAttention() bool {
return false
}
if arch := f.KV().Architecture(); slices.Contains([]string{"gemma2"}, arch) {
return false
}
// Check head counts match and are non-zero
headCountK := f.KV().EmbeddingHeadCountK()
headCountV := f.KV().EmbeddingHeadCountV()
@ -794,6 +907,8 @@ func kvCacheBytesPerElement(cacheType string) float64 {
return 1 // 1/2 of fp16
case "q4_0":
return 0.5 // 1/4 of fp16
case "f32":
return 4 // f32 (default for recurrent)
default:
return 2 // f16 (default)
}

View File

@ -3,15 +3,29 @@ package harmony
import (
"fmt"
"log/slog"
"slices"
"strings"
"unicode"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/logutil"
"github.com/ollama/ollama/template"
)
type harmonyParserState int
func ShouldUseHarmony(modelFamily string, template *template.Template) bool {
if slices.Contains([]string{"gptoss", "gpt-oss"}, modelFamily) {
// heuristic to check whether the template expects to be parsed via harmony:
// search for harmony tags that are nearly always used
if template.Contains("<|start|>") && template.Contains("<|end|>") {
return true
}
}
return false
}
const (
harmonyParserState_LookingForMessageStart harmonyParserState = iota
harmonyParserState_ParsingHeader
@ -75,18 +89,28 @@ func (s *HarmonyParser) AddImplicitStart() {
s.acc.WriteString("<|start|>assistant")
}
func (s *HarmonyParser) AddImplicitStartOrPrefill(lastMessage *api.Message) {
if lastMessage != nil && lastMessage.Role == "assistant" {
// handle prefilling conditions
if lastMessage.Content != "" {
s.acc.WriteString("<|start|>assistant<|channel|>final<|message|>")
return
} else if lastMessage.Thinking != "" {
s.acc.WriteString("<|start|>assistant<|channel|>analysis<|message|>")
return
}
func Prefill(lastMessage api.Message) string {
if lastMessage.Role != "assistant" {
return ""
}
switch {
case strings.TrimSpace(lastMessage.Content) != "":
return "<|start|>assistant<|channel|>final<|message|>"
case strings.TrimSpace(lastMessage.Thinking) != "":
return "<|start|>assistant<|channel|>analysis<|message|>"
default:
return ""
}
}
// AddImplicitStartOrPrefill adds an implicit start tag or prefill string if provided
func (s *HarmonyParser) AddImplicitStartOrPrefill(prefillString string) {
if strings.TrimSpace(prefillString) != "" {
s.acc.WriteString(prefillString)
} else {
s.AddImplicitStart()
}
s.AddImplicitStart()
}
func (s *HarmonyParser) AddContent(content string) []HarmonyEvent {
@ -265,6 +289,7 @@ type HarmonyMessageHandler struct {
state harmonyMessageState
HarmonyParser *HarmonyParser
FunctionNameMap *FunctionNameMap
ToolParser *HarmonyToolCallAccumulator
}
// NewHarmonyMessageHandler creates a new message handler
@ -277,12 +302,16 @@ func NewHarmonyMessageHandler() *HarmonyMessageHandler {
HeaderEndTag: "<|message|>",
},
FunctionNameMap: NewFunctionNameMap(),
ToolParser: &HarmonyToolCallAccumulator{
state: harmonyToolCallState_Normal,
currentToolName: nil,
},
}
}
// AddContent processes the content and returns the content, thinking, and tool content.
// content and thinking are already fully parsed, but tool content still needs to be passed to the tool parser
func (h *HarmonyMessageHandler) AddContent(content string, toolParser *HarmonyToolCallAccumulator) (string, string, string) {
func (h *HarmonyMessageHandler) AddContent(content string) (string, string, string) {
contentSb := strings.Builder{}
thinkingSb := strings.Builder{}
toolContentSb := strings.Builder{}
@ -299,14 +328,14 @@ func (h *HarmonyMessageHandler) AddContent(content string, toolParser *HarmonyTo
// event.Header.Recipient is the tool name, something like
// "browser.search" for a built-in, or "functions.calc" for a
// custom one
toolParser.SetToolName(event.Header.Recipient)
h.ToolParser.SetToolName(event.Header.Recipient)
} else {
h.state = harmonyMessageState_Thinking
}
case "commentary":
if event.Header.Recipient != "" {
h.state = harmonyMessageState_ToolCalling
toolParser.SetToolName(event.Header.Recipient)
h.ToolParser.SetToolName(event.Header.Recipient)
} else {
h.state = harmonyMessageState_Normal
}
@ -329,13 +358,6 @@ func (h *HarmonyMessageHandler) AddContent(content string, toolParser *HarmonyTo
return contentSb.String(), thinkingSb.String(), toolContentSb.String()
}
func (h *HarmonyMessageHandler) CreateToolParser() *HarmonyToolCallAccumulator {
return &HarmonyToolCallAccumulator{
state: harmonyToolCallState_Normal,
currentToolName: nil,
}
}
type harmonyToolCallState int
const (

View File

@ -3,6 +3,7 @@ package harmony
import (
"fmt"
"reflect"
"strings"
"testing"
)
@ -535,3 +536,202 @@ func TestFunctionConvertAndAdd(t *testing.T) {
})
}
}
func TestHarmonyMessageHandlerStreamingScenarios(t *testing.T) {
t.Run("thinking_then_content_streams", func(t *testing.T) {
handler := NewHarmonyMessageHandler()
handler.HarmonyParser.AddImplicitStart()
tp := handler.ToolParser
type step struct {
in string
wantContent string
wantThinking string
}
steps := []step{
{in: "<|channel|>analysis<|message|>Thinking...", wantThinking: "Thinking..."},
{in: "<|end|>", wantThinking: ""},
{in: "<|start|>assistant<|message|>Answer", wantContent: "Answer"},
{in: "<|end|>", wantContent: ""},
}
for i, s := range steps {
content, thinking, tool := handler.AddContent(s.in)
if tool != "" {
tp.Add(tool)
}
if content != s.wantContent || thinking != s.wantThinking {
t.Fatalf("step %d: got (content=%q thinking=%q), want (content=%q thinking=%q)", i, content, thinking, s.wantContent, s.wantThinking)
}
}
})
t.Run("content_streams_as_it_arrives", func(t *testing.T) {
handler := NewHarmonyMessageHandler()
handler.HarmonyParser.AddImplicitStart()
tp := handler.ToolParser
inputs := []string{
"<|start|>assistant<|message|>Hello",
", world",
"!<|end|>",
}
var got []string
for _, in := range inputs {
content, thinking, tool := handler.AddContent(in)
if tool != "" {
tp.Add(tool)
}
if thinking != "" {
t.Fatalf("unexpected thinking %q", thinking)
}
if content != "" {
got = append(got, content)
}
}
want := []string{"Hello", ", world", "!"}
if !reflect.DeepEqual(got, want) {
t.Fatalf("content pieces mismatch: got %v want %v", got, want)
}
})
t.Run("thinking_streams_separately_from_content", func(t *testing.T) {
handler := NewHarmonyMessageHandler()
handler.HarmonyParser.AddImplicitStart()
tp := handler.ToolParser
inputs := []string{
"<|channel|>analysis<|message|>Thinking...",
"<|end|>",
"<|start|>assistant<|message|>Answer",
"<|end|>",
}
var got []string
for _, in := range inputs {
content, thinking, tool := handler.AddContent(in)
if tool != "" {
tp.Add(tool)
}
if thinking != "" {
got = append(got, thinking)
}
if content != "" {
got = append(got, content)
}
}
want := []string{"Thinking...", "Answer"}
if !reflect.DeepEqual(got, want) {
t.Fatalf("content pieces mismatch: got %v want %v", got, want)
}
})
t.Run("partial_tags_buffer_until_complete", func(t *testing.T) {
handler := NewHarmonyMessageHandler()
handler.HarmonyParser.AddImplicitStart()
tp := handler.ToolParser
inputs := []string{
"<|chan",
"nel|>analysis<|mess",
"age|>Deep ",
"thought",
"<|end|>",
"<|start|>assistant<|message|>Done",
"<|end|>",
}
var thinkingPieces []string
var contentPieces []string
for _, in := range inputs {
content, thinking, tool := handler.AddContent(in)
if tool != "" {
tp.Add(tool)
}
if thinking != "" {
thinkingPieces = append(thinkingPieces, thinking)
}
if content != "" {
contentPieces = append(contentPieces, content)
}
}
if want := []string{"Deep ", "thought"}; !reflect.DeepEqual(thinkingPieces, want) {
t.Fatalf("thinking pieces mismatch: got %v want %v", thinkingPieces, want)
}
if want := []string{"Done"}; !reflect.DeepEqual(contentPieces, want) {
t.Fatalf("content pieces mismatch: got %v want %v", contentPieces, want)
}
})
t.Run("simple_assistant_after_analysis", func(t *testing.T) {
handler := NewHarmonyMessageHandler()
handler.HarmonyParser.AddImplicitStart()
tp := handler.ToolParser
inputs := []string{
"<|channel|>analysis<|message|>Think",
"<|end|>",
"<|start|>assistant<|message|>Answer",
"<|end|>",
}
var contentSb, thinkingSb strings.Builder
for _, in := range inputs {
content, thinking, tool := handler.AddContent(in)
if tool != "" {
tp.Add(tool)
}
contentSb.WriteString(content)
thinkingSb.WriteString(thinking)
}
if contentSb.String() != "Answer" {
t.Fatalf("content mismatch: got %q want %q", contentSb.String(), "Answer")
}
if thinkingSb.String() != "Think" {
t.Fatalf("thinking mismatch: got %q want %q", thinkingSb.String(), "Think")
}
})
t.Run("tool_call_parsed_and_returned_correctly", func(t *testing.T) {
handler := NewHarmonyMessageHandler()
handler.HarmonyParser.AddImplicitStart()
tp := handler.ToolParser
inputs := []string{
"<|channel|>commentary to=functions.calculate<|message|>{\"expression\":\"2+2\"}<|end|>",
}
for _, in := range inputs {
content, thinking, tool := handler.AddContent(in)
if content != "" || thinking != "" {
continue
}
if tool != "" {
tp.Add(tool)
}
}
name, args := tp.Drain()
if name == nil || *name != "functions.calculate" {
t.Fatalf("unexpected tool name: %v", name)
}
if got, want := args, "{\"expression\":\"2+2\"}"; got != want {
t.Fatalf("unexpected tool args: got %s want %s", got, want)
}
})
t.Run("tool_call_across_chunks", func(t *testing.T) {
handler := NewHarmonyMessageHandler()
handler.HarmonyParser.AddImplicitStart()
tp := handler.ToolParser
inputs := []string{
"<|channel|>commentary to=functions.calculate<|message|>{\"expression\":\"2+",
"2\"}",
"<|end|>",
}
for _, in := range inputs {
content, thinking, tool := handler.AddContent(in)
if content != "" || thinking != "" {
continue
}
if tool != "" {
tp.Add(tool)
}
}
name, args := tp.Drain()
if name == nil || *name != "functions.calculate" {
t.Fatalf("unexpected tool name: %v", name)
}
if got, want := args, "{\"expression\":\"2+2\"}"; got != want {
t.Fatalf("unexpected tool args: got %s want %s", got, want)
}
})
}

View File

@ -410,3 +410,99 @@ func TestAPIEmbeddings(t *testing.T) {
t.Errorf("zero length embedding response")
}
}
func TestAPIToolCalling(t *testing.T) {
initialTimeout := 60 * time.Second
streamTimeout := 30 * time.Second
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
defer cancel()
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
modelName := "qwen3:0.6b"
if err := PullIfMissing(ctx, client, modelName); err != nil {
t.Fatalf("pull failed %s", err)
}
tools := []api.Tool{
{
Type: "function",
Function: api.ToolFunction{
Name: "get_weather",
Description: "Get the current weather in a given location",
Parameters: api.ToolFunctionParameters{
Type: "object",
Required: []string{"location"},
Properties: map[string]api.ToolProperty{
"location": {
Type: api.PropertyType{"string"},
Description: "The city and state, e.g. San Francisco, CA",
},
},
},
},
},
}
req := api.ChatRequest{
Model: modelName,
Messages: []api.Message{
{
Role: "user",
Content: "Call get_weather with location set to San Francisco.",
},
},
Tools: tools,
Options: map[string]any{
"temperature": 0,
},
}
stallTimer := time.NewTimer(initialTimeout)
var gotToolCall bool
var lastToolCall api.ToolCall
fn := func(response api.ChatResponse) error {
if len(response.Message.ToolCalls) > 0 {
gotToolCall = true
lastToolCall = response.Message.ToolCalls[len(response.Message.ToolCalls)-1]
}
if !stallTimer.Reset(streamTimeout) {
return fmt.Errorf("stall was detected while streaming response, aborting")
}
return nil
}
stream := true
req.Stream = &stream
done := make(chan int)
var genErr error
go func() {
genErr = client.Chat(ctx, &req, fn)
done <- 0
}()
select {
case <-stallTimer.C:
t.Errorf("tool-calling chat never started. Timed out after: %s", initialTimeout.String())
case <-done:
if genErr != nil {
t.Fatalf("chat failed: %v", genErr)
}
if !gotToolCall {
t.Fatalf("expected at least one tool call, got none")
}
if lastToolCall.Function.Name != "get_weather" {
t.Errorf("unexpected tool called: got %q want %q", lastToolCall.Function.Name, "get_weather")
}
if _, ok := lastToolCall.Function.Arguments["location"]; !ok {
t.Errorf("expected tool arguments to include 'location', got: %s", lastToolCall.Function.Arguments.String())
}
case <-ctx.Done():
t.Error("outer test context done while waiting for tool-calling chat")
}
}

View File

@ -121,6 +121,7 @@ func TestMultiModelStress(t *testing.T) {
// The intent is to go 1 over what can fit so we force the scheduler to thrash
targetLoadCount := 0
slog.Info("Loading models to find how many can fit in VRAM before overflowing")
chooseModels:
for i, model := range chosenModels {
req := &api.GenerateRequest{Model: model}
slog.Info("loading", "model", model)
@ -142,6 +143,13 @@ func TestMultiModelStress(t *testing.T) {
slog.Info("found model load capacity", "target", targetLoadCount, "current", loaded, "chosen", chosenModels[:targetLoadCount])
break
}
// Effectively limit model count to 2 on CPU only systems to avoid thrashing and timeouts
for _, m := range models.Models {
if m.SizeVRAM == 0 {
slog.Info("model running on CPU", "name", m.Name, "target", targetLoadCount, "chosen", chosenModels[:targetLoadCount])
break chooseModels
}
}
}
}
if targetLoadCount == len(chosenModels) {

View File

@ -36,7 +36,7 @@ func TestLongInputContext(t *testing.T) {
if err := PullIfMissing(ctx, client, req.Model); err != nil {
t.Fatalf("PullIfMissing failed: %v", err)
}
DoGenerate(ctx, t, client, req, []string{"russia", "germany", "france", "england", "austria", "prussia", "individuals", "coalition", "conflict"}, 120*time.Second, 10*time.Second)
DoGenerate(ctx, t, client, req, []string{"russia", "germany", "france", "england", "austria", "prussia", "europe", "individuals", "coalition", "conflict"}, 120*time.Second, 10*time.Second)
}
func TestContextExhaustion(t *testing.T) {

View File

@ -38,8 +38,9 @@ func TestAllMiniLMEmbeddings(t *testing.T) {
defer cleanup()
req := api.EmbeddingRequest{
Model: "all-minilm",
Prompt: "why is the sky blue?",
Model: "all-minilm",
Prompt: "why is the sky blue?",
KeepAlive: &api.Duration{Duration: 10 * time.Second},
}
res, err := embeddingTestHelper(ctx, client, t, req)

View File

@ -502,6 +502,22 @@ func DoGenerate(ctx context.Context, t *testing.T, client *api.Client, genReq ap
done <- 0
}()
var response string
verify := func() {
// Verify the response contains the expected data
response = buf.String()
atLeastOne := false
for _, resp := range anyResp {
if strings.Contains(strings.ToLower(response), resp) {
atLeastOne = true
break
}
}
if !atLeastOne {
t.Fatalf("%s: none of %v found in %s", genReq.Model, anyResp, response)
}
}
select {
case <-stallTimer.C:
if buf.Len() == 0 {
@ -517,21 +533,14 @@ func DoGenerate(ctx context.Context, t *testing.T, client *api.Client, genReq ap
if genErr != nil {
t.Fatalf("%s failed with %s request prompt %s", genErr, genReq.Model, genReq.Prompt)
}
// Verify the response contains the expected data
response := buf.String()
atLeastOne := false
for _, resp := range anyResp {
if strings.Contains(strings.ToLower(response), resp) {
atLeastOne = true
break
}
}
if !atLeastOne {
t.Fatalf("%s: none of %v found in %s", genReq.Model, anyResp, response)
}
verify()
slog.Info("test pass", "model", genReq.Model, "prompt", genReq.Prompt, "contains", anyResp, "response", response)
case <-ctx.Done():
t.Error("outer test context done while waiting for generate")
// On slow systems, we might timeout before some models finish rambling, so check what we have so far to see
// if it's considered a pass - the stallTimer will detect hangs, but we want to consider slow systems a pass
// if they are still generating valid responses
slog.Warn("outer test context done while waiting for generate")
verify()
}
return context
}
@ -599,6 +608,22 @@ func DoChat(ctx context.Context, t *testing.T, client *api.Client, req api.ChatR
done <- 0
}()
var response string
verify := func() {
// Verify the response contains the expected data
response = buf.String()
atLeastOne := false
for _, resp := range anyResp {
if strings.Contains(strings.ToLower(response), resp) {
atLeastOne = true
break
}
}
if !atLeastOne {
t.Fatalf("%s: none of %v found in \"%s\" -- request was:%v", req.Model, anyResp, response, req.Messages)
}
}
select {
case <-stallTimer.C:
if buf.Len() == 0 {
@ -614,23 +639,14 @@ func DoChat(ctx context.Context, t *testing.T, client *api.Client, req api.ChatR
if genErr != nil {
t.Fatalf("%s failed with %s request prompt %v", genErr, req.Model, req.Messages)
}
// Verify the response contains the expected data
response := buf.String()
atLeastOne := false
for _, resp := range anyResp {
if strings.Contains(strings.ToLower(response), resp) {
atLeastOne = true
break
}
}
if !atLeastOne {
t.Fatalf("%s: none of %v found in \"%s\" -- request was:%v", req.Model, anyResp, response, req.Messages)
}
verify()
slog.Info("test pass", "model", req.Model, "messages", req.Messages, "contains", anyResp, "response", response)
case <-ctx.Done():
t.Error("outer test context done while waiting for generate")
// On slow systems, we might timeout before some models finish rambling, so check what we have so far to see
// if it's considered a pass - the stallTimer will detect hangs, but we want to consider slow systems a pass
// if they are still generating valid responses
slog.Warn("outer test context done while waiting for chat")
verify()
}
return &api.Message{Role: role, Content: buf.String()}
}

View File

@ -202,7 +202,7 @@ func estimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []strin
var kvct string
if useFlashAttention {
requested := strings.ToLower(envconfig.KvCacheType())
if requested != "" && f.SupportsKVCacheType(requested) {
if f.SupportsKVCacheType(requested) {
kvct = requested
}
}

View File

@ -35,6 +35,7 @@ import (
"github.com/ollama/ollama/logutil"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/parser"
)
type filteredEnv []string
@ -173,6 +174,8 @@ func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, a
opts.NumCtx = int(trainCtx)
}
opts.NumBatch = min(opts.NumBatch, opts.NumCtx)
loadRequest := LoadRequest{LoraPath: adapters, KvSize: opts.NumCtx * numParallel, BatchSize: opts.NumBatch, Parallel: numParallel, MultiUserCache: envconfig.MultiUserCache()}
defaultThreads := discover.GetSystemInfo().GetOptimalThreadCount()
@ -218,7 +221,7 @@ func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, a
// Flash Attention also supports kv cache quantization
// Enable if the requested and kv cache type is supported by the model
if kvct != "" && f.SupportsKVCacheType(kvct) {
if f.SupportsKVCacheType(kvct) {
loadRequest.KvCacheType = kvct
} else {
slog.Warn("kv cache type not supported by model", "type", kvct)
@ -1348,7 +1351,9 @@ type CompletionRequest struct {
Images []ImageData
Options *api.Options
Grammar string // set before sending the request to the subprocess
Grammar string // set before sending the request to the subprocess
ParserType parser.TokenParserType
PrefillString string
}
// DoneReason represents the reason why a completion response is done
@ -1375,13 +1380,15 @@ func (d DoneReason) String() string {
}
type CompletionResponse struct {
Content string `json:"content"`
DoneReason DoneReason `json:"done_reason"`
Done bool `json:"done"`
PromptEvalCount int `json:"prompt_eval_count"`
PromptEvalDuration time.Duration `json:"prompt_eval_duration"`
EvalCount int `json:"eval_count"`
EvalDuration time.Duration `json:"eval_duration"`
Content string `json:"content"`
Thinking string `json:"thinking"`
ToolCalls []api.ToolCall `json:"tool_calls"`
DoneReason DoneReason `json:"done_reason"`
Done bool `json:"done"`
PromptEvalCount int `json:"prompt_eval_count"`
PromptEvalDuration time.Duration `json:"prompt_eval_duration"`
EvalCount int `json:"eval_count"`
EvalDuration time.Duration `json:"eval_duration"`
}
func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error {
@ -1499,7 +1506,8 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
return fmt.Errorf("error unmarshalling llm prediction response: %v", err)
}
switch {
case strings.TrimSpace(c.Content) == lastToken:
// TODO(parthsareen): token repeat limit is now handled in the runner, this currently support legacy model and can be removed in the future
case strings.TrimSpace(c.Content) == lastToken && c.Content != "":
tokenRepeat++
default:
lastToken = strings.TrimSpace(c.Content)
@ -1512,16 +1520,14 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
return ctx.Err()
}
if c.Content != "" {
fn(CompletionResponse{
Content: c.Content,
})
}
if c.Done {
fn(c)
return nil
}
if c.Content != "" || c.Thinking != "" || len(c.ToolCalls) > 0 {
fn(c)
}
}
}

126
parser/token_parser.go Normal file
View File

@ -0,0 +1,126 @@
package parser
import (
"encoding/json"
"errors"
"strings"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/harmony"
)
type TokenParserType int
const (
TokenParserTypeDefault TokenParserType = iota
TokenParserTypeHarmony
)
type TokenParser struct {
messageHandler MessageHandler
parserEngine ParserInternals
toolParser ToolParser
lastToken string
tokenRepeat int
repeatLimit int
}
const defaultTokenRepeatLimit = 30
type MessageHandler interface {
AddContent(token string) (content, thinking string, toolContent string)
}
type ParserInternals interface {
AddImplicitStartOrPrefill(prefillString string)
}
type ToolParser interface {
Add(token string)
Drain() (toolName *string, toolContent string)
}
// Default implementation for the TokenParser interface as a no-op passthrough
type defaultMessageHandler struct{}
func (defaultMessageHandler) AddContent(token string) (string, string, string) {
return token, "", ""
}
type defaultEngine struct{}
func (defaultEngine) AddImplicitStartOrPrefill(prefillString string) {}
type defaultToolParser struct{}
func (defaultToolParser) Add(token string) {}
func (defaultToolParser) Drain() (*string, string) { return nil, "" }
func NewTokenParser(parserType TokenParserType, prefillString string) TokenParser {
switch parserType {
case TokenParserTypeHarmony:
harmonyMessageHandler := harmony.NewHarmonyMessageHandler()
harmonyMessageHandler.HarmonyParser.AddImplicitStartOrPrefill(prefillString)
return TokenParser{
messageHandler: harmonyMessageHandler,
parserEngine: harmonyMessageHandler.HarmonyParser,
toolParser: harmonyMessageHandler.ToolParser,
repeatLimit: defaultTokenRepeatLimit,
}
default:
return TokenParser{
messageHandler: defaultMessageHandler{},
parserEngine: defaultEngine{},
toolParser: defaultToolParser{},
repeatLimit: 30,
}
}
}
func (p *TokenParser) AddContent(token string) (string, string, error) {
if p.repeatLimitReached(token) {
return "", "", errors.New("token repeat limit reached")
}
content, thinking, toolContent := p.messageHandler.AddContent(token)
p.toolParser.Add(toolContent)
return content, thinking, nil
}
// repeatLimitReached updates repeat counters and returns true if the repeat limit is reached.
func (p *TokenParser) repeatLimitReached(token string) bool {
if p == nil {
return false
}
trimmed := strings.TrimSpace(token)
if trimmed == p.lastToken {
p.tokenRepeat++
} else {
p.tokenRepeat = 0
}
p.lastToken = trimmed
return p.tokenRepeat >= p.repeatLimit
}
// TODO: update to work with multiple toolcalls - unmarshalling should also happen on parser level
func (p *TokenParser) Drain() []api.ToolCall {
toolName, toolContent := p.toolParser.Drain()
if toolName != nil {
*toolName = strings.TrimPrefix(*toolName, "functions.")
var args api.ToolCallFunctionArguments
if err := json.Unmarshal([]byte(toolContent), &args); err != nil {
return nil
}
return []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: *toolName,
Arguments: args,
},
},
}
}
return nil
}

View File

@ -34,8 +34,8 @@ type InputCache struct {
func NewInputCache(model model.Model, kvCacheType string, kvSize int32, numSlots int, batchSize int, multiUserCache bool) (*InputCache, error) {
numCtx := kvSize / int32(numSlots)
if numCtx < 1 {
return nil, fmt.Errorf("must have at least one kv cache entry per parallel sequence (kv: %v parallel: %v)", kvSize, numSlots)
if int(numCtx) < batchSize {
return nil, fmt.Errorf("kv size must be at least as large as batch size * parallel (kv: %v batch: %v parallel: %v)", kvSize, batchSize, numSlots)
}
slots := make([]InputCacheSlot, numSlots)
@ -70,11 +70,9 @@ func kvCacheTypeFromStr(s string) ml.DType {
}
func (c *InputCache) Close() {
if c == nil {
return
if c != nil && c.cache != nil {
c.cache.Close()
}
c.cache.Close()
}
// Locking: Operations on InputCacheSlot (including finding one

View File

@ -35,6 +35,7 @@ import (
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/parser"
"github.com/ollama/ollama/runner/common"
"github.com/ollama/ollama/sample"
@ -781,6 +782,8 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
return
}
tokenParser := parser.NewTokenParser(req.ParserType, req.PrefillString)
if req.Options == nil {
opts := api.DefaultOptions()
req.Options = &opts
@ -871,8 +874,18 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
return
case content, ok := <-seq.responses:
if ok {
var thinking string
var err error
content, thinking, err = tokenParser.AddContent(content)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
close(seq.quit)
return
}
if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{
Content: content,
Content: content,
Thinking: thinking,
}); err != nil {
http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
close(seq.quit)
@ -881,7 +894,9 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
flusher.Flush()
} else {
toolCalls := tokenParser.Drain()
if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{
ToolCalls: toolCalls,
Done: true,
DoneReason: seq.doneReason,
PromptEvalCount: seq.numPromptInputs,

View File

@ -78,7 +78,7 @@ function checkEnv() {
}
function buildOllama() {
function buildCPU() {
mkdir -Force -path "${script:DIST_DIR}\"
if ($script:ARCH -ne "arm64") {
Remove-Item -ea 0 -recurse -force -path "${script:SRC_DIR}\dist\windows-${script:ARCH}"
@ -90,20 +90,72 @@ function buildOllama() {
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
& cmake --install build --component CPU --strip
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
}
}
function buildCUDA11() {
# CUDA v11 claims to be compatible with MSVC 2022, but the latest updates are no longer compatible
# 19.40 is the last compiler version that works, but recent udpates are 19.43
# So this pins to MSVC 2019 for best compatibility
mkdir -Force -path "${script:DIST_DIR}\"
if ($script:ARCH -ne "arm64") {
$hashEnv = @{}
Get-ChildItem env: | foreach { $hashEnv[$_.Name] = $_.Value }
if ("$script:CUDA_DIRS".Contains("v12")) {
$hashEnv.Keys | foreach { if ($_.Contains("CUDA_PATH_V12")) { $v12="$_" }}
$env:CUDAToolkit_ROOT=$hashEnv[$v12]
write-host "Building CUDA v12 backend libraries"
& cmake --fresh --preset "CUDA 12" --install-prefix $script:DIST_DIR
if ("$script:CUDA_DIRS".Contains("v11")) {
$hashEnv.Keys | foreach { if ($_.Contains("CUDA_PATH_V11")) { $x=$hashEnv[$_]; if (test-path -literalpath "$x\bin\nvcc.exe" ) { $cuda=$x} }}
write-host "Building CUDA v11 backend libraries $cuda"
$env:CUDAToolkit_ROOT=$cuda
& cmake --fresh --preset "CUDA 11" -T cuda="$cuda" -DCMAKE_CUDA_COMPILER="$cuda\bin\nvcc.exe" -G "Visual Studio 16 2019" --install-prefix $script:DIST_DIR -DOLLAMA_RUNNER_DIR="cuda_v11"
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
& cmake --build --preset "CUDA 11" --config Release --parallel $script:JOBS
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
& cmake --install build --component "CUDA" --strip
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
}
}
}
function buildCUDA12() {
mkdir -Force -path "${script:DIST_DIR}\"
if ($script:ARCH -ne "arm64") {
$hashEnv = @{}
Get-ChildItem env: | foreach { $hashEnv[$_.Name] = $_.Value }
if ("$script:CUDA_DIRS".Contains("v12.8")) {
$hashEnv.Keys | foreach { if ($_.Contains("CUDA_PATH_V12_8")) { $x=$hashEnv[$_]; if (test-path -literalpath "$x\bin\nvcc.exe" ) { $cuda=$x} }}
write-host "Building CUDA v12 backend libraries $cuda"
$env:CUDAToolkit_ROOT=$cuda
& cmake --fresh --preset "CUDA 12" -T cuda="$cuda" --install-prefix $script:DIST_DIR -DOLLAMA_RUNNER_DIR="cuda_v12"
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
& cmake --build --preset "CUDA 12" --config Release --parallel $script:JOBS
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
& cmake --install build --component "CUDA" --strip
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
}
}
}
function buildCUDA13() {
mkdir -Force -path "${script:DIST_DIR}\"
if ($script:ARCH -ne "arm64") {
$hashEnv = @{}
Get-ChildItem env: | foreach { $hashEnv[$_.Name] = $_.Value }
if ("$script:CUDA_DIRS".Contains("v13")) {
$hashEnv.Keys | foreach { if ($_.Contains("CUDA_PATH_V13")) { $x=$hashEnv[$_]; if (test-path -literalpath "$x\bin\nvcc.exe" ) { $cuda=$x} }}
$env:CUDAToolkit_ROOT=$cuda
write-host "Building CUDA v13 backend libraries $cuda"
& cmake --fresh --preset "CUDA 13" -T cuda="$cuda" --install-prefix $script:DIST_DIR -DOLLAMA_RUNNER_DIR="cuda_v13"
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
& cmake --build --preset "CUDA 13" --config Release --parallel $script:JOBS
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
& cmake --install build --component "CUDA" --strip
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
}
}
}
function buildROCm() {
mkdir -Force -path "${script:DIST_DIR}\"
if ($script:ARCH -ne "arm64") {
if ($env:HIP_PATH) {
write-host "Building ROCm backend libraries"
if (-Not (get-command -ErrorAction silent ninja)) {
@ -138,6 +190,10 @@ function buildOllama() {
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
}
}
}
function buildOllama() {
mkdir -Force -path "${script:DIST_DIR}\"
write-host "Building ollama CLI"
& go build -trimpath -ldflags "-s -w -X=github.com/ollama/ollama/version.Version=$script:VERSION -X=github.com/ollama/ollama/server.mode=release" .
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
@ -245,6 +301,10 @@ function distZip() {
checkEnv
try {
if ($($args.count) -eq 0) {
buildCPU
buildCUDA12
buildCUDA13
buildROCm
buildOllama
buildApp
gatherDependencies

View File

@ -36,6 +36,7 @@ import (
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/logutil"
"github.com/ollama/ollama/openai"
"github.com/ollama/ollama/parser"
"github.com/ollama/ollama/server/internal/client/ollama"
"github.com/ollama/ollama/server/internal/registry"
"github.com/ollama/ollama/template"
@ -46,18 +47,6 @@ import (
"github.com/ollama/ollama/version"
)
func shouldUseHarmony(model *Model) bool {
if slices.Contains([]string{"gptoss", "gpt-oss"}, model.Config.ModelFamily) {
// heuristic to check whether the template expects to be parsed via harmony:
// search for harmony tags that are nearly always used
if model.Template.Contains("<|start|>") && model.Template.Contains("<|end|>") {
return true
}
}
return false
}
func experimentEnabled(name string) bool {
return slices.Contains(strings.Split(os.Getenv("OLLAMA_EXPERIMENT"), ","), name)
}
@ -207,13 +196,17 @@ func (s *Server) GenerateHandler(c *gin.Context) {
return
}
useHarmony := shouldUseHarmony(m) && !req.Raw
var harmonyMessageHandler *harmony.HarmonyMessageHandler
var harmonyToolParser *harmony.HarmonyToolCallAccumulator
useHarmony := harmony.ShouldUseHarmony(m.Config.ModelFamily, m.Template) && !req.Raw
var parserType parser.TokenParserType
if useHarmony {
harmonyMessageHandler = harmony.NewHarmonyMessageHandler()
harmonyMessageHandler.HarmonyParser.AddImplicitStart()
harmonyToolParser = harmonyMessageHandler.CreateToolParser()
parserType = parser.TokenParserTypeHarmony
} else {
parserType = parser.TokenParserTypeDefault
}
var functionNameMap *harmony.FunctionNameMap
if useHarmony {
functionNameMap = harmony.NewFunctionNameMap()
}
// Validate Think value: string values currently only allowed for gptoss models
@ -357,16 +350,19 @@ func (s *Server) GenerateHandler(c *gin.Context) {
var sb strings.Builder
defer close(ch)
if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
Prompt: prompt,
Images: images,
Format: req.Format,
Options: opts,
Prompt: prompt,
Images: images,
Format: req.Format,
Options: opts,
ParserType: parserType,
}, func(cr llm.CompletionResponse) {
res := api.GenerateResponse{
Model: req.Model,
CreatedAt: time.Now().UTC(),
Response: cr.Content,
Done: cr.Done,
Thinking: cr.Thinking,
ToolCalls: cr.ToolCalls,
Metrics: api.Metrics{
PromptEvalCount: cr.PromptEvalCount,
PromptEvalDuration: cr.PromptEvalDuration,
@ -375,12 +371,22 @@ func (s *Server) GenerateHandler(c *gin.Context) {
},
}
if res.Done {
res.DoneReason = cr.DoneReason.String()
res.TotalDuration = time.Since(checkpointStart)
res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
}
if useHarmony {
content, thinking, toolContent := harmonyMessageHandler.AddContent(cr.Content, harmonyToolParser)
res.Response = content
res.Thinking = thinking
harmonyToolParser.Add(toolContent)
} else if thinkingState != nil {
for i, tool := range res.ToolCalls {
res.ToolCalls[i].Function.Name = functionNameMap.OriginalFromConverted(tool.Function.Name)
}
if res.Response != "" || res.Thinking != "" || len(res.ToolCalls) > 0 || res.Done {
ch <- res
}
return
}
if thinkingState != nil {
thinking, content := thinkingState.AddContent(cr.Content)
res.Thinking = thinking
res.Response = content
@ -391,30 +397,6 @@ func (s *Server) GenerateHandler(c *gin.Context) {
}
if cr.Done {
if useHarmony {
toolName, toolContent := harmonyToolParser.Drain()
if toolName != nil {
*toolName = strings.TrimPrefix(*toolName, "functions.")
var args api.ToolCallFunctionArguments
if err := json.Unmarshal([]byte(toolContent), &args); err != nil {
errStr := fmt.Sprintf("error parsing tool call: raw='%s', err=%s", toolContent, err.Error())
ch <- gin.H{"error": errStr}
return
}
res.ToolCalls = append(res.ToolCalls, api.ToolCall{
Function: api.ToolCallFunction{
Name: *toolName,
Arguments: args,
},
})
}
}
res.DoneReason = cr.DoneReason.String()
res.TotalDuration = time.Since(checkpointStart)
res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
if !req.Raw {
tokens, err := r.Tokenize(c.Request.Context(), prompt+sb.String())
if err != nil {
@ -1616,27 +1598,27 @@ func (s *Server) ChatHandler(c *gin.Context) {
}
msgs = filterThinkTags(msgs, m)
var harmonyMessageHandler *harmony.HarmonyMessageHandler
var harmonyToolParser *harmony.HarmonyToolCallAccumulator
useHarmony := shouldUseHarmony(m)
useHarmony := harmony.ShouldUseHarmony(m.Config.ModelFamily, m.Template)
var parserType parser.TokenParserType
if useHarmony {
parserType = parser.TokenParserTypeHarmony
} else {
parserType = parser.TokenParserTypeDefault
}
processedTools := req.Tools
var functionNameMap *harmony.FunctionNameMap
var prefillString string
// TODO(parthsareen): this can be abstracted to not be model specific and potentially moved to the runner
if useHarmony {
harmonyMessageHandler = harmony.NewHarmonyMessageHandler()
var lastMessage *api.Message
if len(msgs) > 0 {
lastMessage = &msgs[len(msgs)-1]
}
harmonyMessageHandler.HarmonyParser.AddImplicitStartOrPrefill(lastMessage)
harmonyToolParser = harmonyMessageHandler.CreateToolParser()
prefillString = harmony.Prefill(msgs[len(msgs)-1])
functionNameMap = harmony.NewFunctionNameMap()
// make a copy of tools to pass to the chat prompt. Function names may be
// renamed to be valid Harmony function names.
processedTools = make([]api.Tool, len(req.Tools))
copy(processedTools, req.Tools)
for i, tool := range processedTools {
processedTools[i].Function.Name = harmonyMessageHandler.FunctionNameMap.ConvertAndAdd(tool.Function.Name)
processedTools[i].Function.Name = functionNameMap.ConvertAndAdd(tool.Function.Name)
}
}
@ -1689,15 +1671,17 @@ func (s *Server) ChatHandler(c *gin.Context) {
defer close(ch)
if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
Prompt: prompt,
Images: images,
Format: req.Format,
Options: opts,
Prompt: prompt,
Images: images,
Format: req.Format,
Options: opts,
ParserType: parserType,
PrefillString: prefillString,
}, func(r llm.CompletionResponse) {
res := api.ChatResponse{
Model: req.Model,
CreatedAt: time.Now().UTC(),
Message: api.Message{Role: "assistant", Content: r.Content},
Message: api.Message{Role: "assistant", Content: r.Content, Thinking: r.Thinking, ToolCalls: r.ToolCalls},
Done: r.Done,
Metrics: api.Metrics{
PromptEvalCount: r.PromptEvalCount,
@ -1713,31 +1697,13 @@ func (s *Server) ChatHandler(c *gin.Context) {
}
if useHarmony {
content, thinking, toolContent := harmonyMessageHandler.AddContent(r.Content, harmonyToolParser)
res.Message.Content = content
res.Message.Thinking = thinking
harmonyToolParser.Add(toolContent)
if r.Done {
toolName, toolContent := harmonyToolParser.Drain()
if toolName != nil {
*toolName = strings.TrimPrefix(*toolName, "functions.")
*toolName = harmonyMessageHandler.FunctionNameMap.OriginalFromConverted(*toolName)
var args api.ToolCallFunctionArguments
if err := json.Unmarshal([]byte(toolContent), &args); err != nil {
errStr := fmt.Sprintf("error parsing tool call: raw='%s', err=%s", toolContent, err.Error())
ch <- gin.H{"error": errStr}
return
}
res.Message.ToolCalls = []api.ToolCall{{Function: api.ToolCallFunction{Name: *toolName, Arguments: args}}}
}
for i, tool := range res.Message.ToolCalls {
res.Message.ToolCalls[i].Function.Name = functionNameMap.OriginalFromConverted(tool.Function.Name)
}
// only send messages with meaningful content (empty messages confuse clients)
if res.Message.Content != "" || res.Message.Thinking != "" || len(res.Message.ToolCalls) > 0 || res.Done {
ch <- res
}
return
}

View File

@ -7,7 +7,6 @@ import (
"bytes"
"context"
"encoding/json"
"net/http"
"strings"
"testing"
"time"
@ -118,7 +117,7 @@ func TestChatHarmonyParserStreamingRealtime(t *testing.T) {
name: "content streams as it arrives",
steps: []step{
{
input: llm.CompletionResponse{Content: "<|message|>Hello", Done: false},
input: llm.CompletionResponse{Content: "Hello", Done: false},
wantContent: "Hello",
},
{
@ -126,7 +125,7 @@ func TestChatHarmonyParserStreamingRealtime(t *testing.T) {
wantContent: ", world",
},
{
input: llm.CompletionResponse{Content: "!<|end|>", Done: true, DoneReason: llm.DoneReasonStop},
input: llm.CompletionResponse{Content: "!", Done: true, DoneReason: llm.DoneReasonStop},
wantContent: "!",
},
},
@ -135,20 +134,15 @@ func TestChatHarmonyParserStreamingRealtime(t *testing.T) {
name: "thinking streams separately from content",
steps: []step{
{
input: llm.CompletionResponse{Content: "<|channel|>analysis<|message|>Thinking...", Done: false},
input: llm.CompletionResponse{Thinking: "Thinking...", Done: false},
wantThinking: "Thinking...",
},
{
input: llm.CompletionResponse{Content: "<|end|>", Done: false},
// No output expected - just closes the analysis message and resets state to normal
input: llm.CompletionResponse{Content: "Answer", Done: false},
wantContent: "Answer",
},
{
input: llm.CompletionResponse{Content: "<|start|>assistant<|message|>Answer", Done: false},
wantContent: "Answer", // After message end, state is reset to normal
},
{
input: llm.CompletionResponse{Content: "<|end|>", Done: true, DoneReason: llm.DoneReasonStop},
// No output expected - just closes the assistant message
input: llm.CompletionResponse{Done: true, DoneReason: llm.DoneReasonStop},
},
},
},
@ -156,24 +150,16 @@ func TestChatHarmonyParserStreamingRealtime(t *testing.T) {
name: "partial tags buffer until complete",
steps: []step{
{
input: llm.CompletionResponse{Content: "<|chan", Done: false},
// No output - partial tag
},
{
input: llm.CompletionResponse{Content: "nel|>analysis<|mess", Done: false},
// No output - still building tags
},
{
input: llm.CompletionResponse{Content: "age|>Deep ", Done: false},
input: llm.CompletionResponse{Thinking: "Deep ", Done: false},
wantThinking: "Deep ",
},
{
input: llm.CompletionResponse{Content: "thought<|end|>", Done: false},
input: llm.CompletionResponse{Thinking: "thought", Done: false},
wantThinking: "thought",
},
{
input: llm.CompletionResponse{Content: "<|start|>assistant<|message|>Done<|end|>", Done: true, DoneReason: llm.DoneReasonStop},
wantContent: "Done", // After message end, state is reset to normal
input: llm.CompletionResponse{Content: "Done", Done: true, DoneReason: llm.DoneReasonStop},
wantContent: "Done",
},
},
},
@ -181,7 +167,7 @@ func TestChatHarmonyParserStreamingRealtime(t *testing.T) {
name: "simple assistant after analysis",
steps: []step{
{
input: llm.CompletionResponse{Content: "<|channel|>analysis<|message|>Think<|end|><|start|>assistant<|message|>Answer<|end|>", Done: true, DoneReason: llm.DoneReasonStop},
input: llm.CompletionResponse{Thinking: "Think", Content: "Answer", Done: true, DoneReason: llm.DoneReasonStop},
wantContent: "Answer",
wantThinking: "Think",
},
@ -191,7 +177,7 @@ func TestChatHarmonyParserStreamingRealtime(t *testing.T) {
name: "tool call parsed and returned correctly",
steps: []step{
{
input: llm.CompletionResponse{Content: "<|channel|>commentary to=functions.get_weather<|message|>{\"location\":\"San Francisco\"}<|end|><|start|>assistant<|message|>The weather is sunny<|end|>", Done: true, DoneReason: llm.DoneReasonStop},
input: llm.CompletionResponse{Content: "The weather is sunny", ToolCalls: []api.ToolCall{{Function: api.ToolCallFunction{Name: "get_weather", Arguments: api.ToolCallFunctionArguments{"location": "San Francisco"}}}}, Done: true, DoneReason: llm.DoneReasonStop},
wantContent: "The weather is sunny",
wantToolCalls: []api.ToolCall{
{
@ -210,15 +196,10 @@ func TestChatHarmonyParserStreamingRealtime(t *testing.T) {
name: "tool call with streaming JSON across chunks",
steps: []step{
{
input: llm.CompletionResponse{Content: "<|channel|>commentary to=functions.calculate<|message|>{\"expr", Done: false},
// No output yet - incomplete JSON
input: llm.CompletionResponse{Done: false},
},
{
input: llm.CompletionResponse{Content: "ession\":\"2+", Done: false},
// Still no output - incomplete JSON
},
{
input: llm.CompletionResponse{Content: "2\"}", Done: true},
input: llm.CompletionResponse{ToolCalls: []api.ToolCall{{Function: api.ToolCallFunction{Name: "calculate", Arguments: api.ToolCallFunctionArguments{"expression": "2+2"}}}}, Done: true},
wantToolCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
@ -400,9 +381,9 @@ func TestChatHarmonyParserStreamingSimple(t *testing.T) {
gin.SetMode(gin.TestMode)
mockResponses := []llm.CompletionResponse{
{Content: "<|message|>First ", Done: false},
{Content: "First ", Done: false},
{Content: "chunk ", Done: false},
{Content: "here<|end|>", Done: true, DoneReason: llm.DoneReasonStop},
{Content: "here", Done: true, DoneReason: llm.DoneReasonStop},
}
mock := mockRunner{
@ -507,189 +488,3 @@ func TestChatHarmonyParserStreamingSimple(t *testing.T) {
t.Errorf("expected at least 2 content chunks for streaming, got %d", contentChunks)
}
}
func TestChatHarmonyParserStreaming(t *testing.T) {
gin.SetMode(gin.TestMode)
type expectedChunk struct {
afterResponse int // Which mock response this chunk should appear after
content string // Expected content in this chunk
thinking string // Expected thinking in this chunk
}
testCases := []struct {
name string
mockResponses []llm.CompletionResponse
expectedChunks []expectedChunk
wantContent string
wantThinking string
}{
{
name: "simple message without thinking",
mockResponses: []llm.CompletionResponse{
{Content: "<|start|>assistant<|message|>Hello, ", Done: false},
{Content: "how can I help?", Done: false},
{Content: "<|end|>", Done: true, DoneReason: llm.DoneReasonStop},
},
expectedChunks: []expectedChunk{
{afterResponse: 1, content: "Hello, "},
{afterResponse: 2, content: "how can I help?"},
},
wantContent: "Hello, how can I help?",
},
{
name: "message with analysis channel for thinking",
mockResponses: []llm.CompletionResponse{
{Content: "<|channel|>analysis<|message|>", Done: false},
{Content: "Let me think ", Done: false},
{Content: "about this problem...", Done: false},
{Content: "<|end|>", Done: false},
{Content: "<|start|>assistant<|message|>", Done: false},
{Content: "The answer ", Done: false},
{Content: "is 42", Done: false},
{Content: "<|end|>", Done: true, DoneReason: llm.DoneReasonStop},
},
expectedChunks: []expectedChunk{
{afterResponse: 2, thinking: "Let me think "},
{afterResponse: 3, thinking: "about this problem..."},
{afterResponse: 6, content: "The answer "},
{afterResponse: 7, content: "is 42"},
},
wantContent: "The answer is 42",
wantThinking: "Let me think about this problem...",
},
{
name: "streaming with partial tags across boundaries",
mockResponses: []llm.CompletionResponse{
{Content: "<|chan", Done: false},
{Content: "nel|>analy", Done: false},
{Content: "sis<|mess", Done: false},
{Content: "age|>Think", Done: false},
{Content: "ing deeply...<|end|>", Done: false},
{Content: "<|start|>assi", Done: false},
{Content: "stant<|message|>Result ", Done: false},
{Content: "computed<|e", Done: false},
{Content: "nd|>", Done: true, DoneReason: llm.DoneReasonStop},
},
expectedChunks: []expectedChunk{
{afterResponse: 4, thinking: "Think"},
{afterResponse: 5, thinking: "ing deeply..."},
{afterResponse: 7, content: "Result "},
{afterResponse: 8, content: "computed"},
},
wantContent: "Result computed",
wantThinking: "Thinking deeply...",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
// Channel to synchronize mock responses with chunk verification
responsesSent := make(chan int, len(tc.mockResponses))
mock := mockRunner{
CompletionFn: func(ctx context.Context, r llm.CompletionRequest, fn func(llm.CompletionResponse)) error {
// Send mock responses one at a time, notifying when each is sent
for i, resp := range tc.mockResponses {
fn(resp)
responsesSent <- i + 1
}
close(responsesSent)
return nil
},
}
s := Server{
sched: &Scheduler{
pendingReqCh: make(chan *LlmRequest, 1),
finishedReqCh: make(chan *LlmRequest, 1),
expiredCh: make(chan *runnerRef, 1),
unloadedCh: make(chan any, 1),
loaded: make(map[string]*runnerRef),
newServerFn: newMockServer(&mock),
getGpuFn: discover.GetGPUInfo,
getCpuFn: discover.GetCPUInfo,
reschedDelay: 250 * time.Millisecond,
loadFn: func(req *LlmRequest, _ *ggml.GGML, _ discover.GpuInfoList, _ bool) bool {
req.successCh <- &runnerRef{
llama: &mock,
}
return false
},
},
}
go s.sched.Run(t.Context())
// Create a minimal model
_, digest := createHarmonyTestModel(t)
// Create model with passthrough template
stream := false
w := createRequest(t, s.CreateHandler, api.CreateRequest{
Model: "harmony-test",
Files: map[string]string{"file.gguf": digest},
Template: `<|start|><|end|>{{ with .Tools }}{{ end }}{{ .Prompt }}`,
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Fatalf("failed to create model: %d", w.Code)
}
// Test chat endpoint with streaming
streamTrue := true
w = createRequest(t, s.ChatHandler, api.ChatRequest{
Model: "harmony-test",
Messages: []api.Message{{Role: "user", Content: "Hello"}},
Stream: &streamTrue,
Tools: getTestTools(),
})
if w.Code != http.StatusOK {
t.Fatalf("chat request failed: %d - %s", w.Code, w.Body.String())
}
// Parse streaming response
var chunks []api.ChatResponse
var content, thinking strings.Builder
decoder := json.NewDecoder(w.Body)
for decoder.More() {
var chunk api.ChatResponse
if err := decoder.Decode(&chunk); err != nil {
t.Fatalf("failed to decode chunk: %v", err)
}
chunks = append(chunks, chunk)
// Accumulate content and thinking from each chunk
content.WriteString(chunk.Message.Content)
thinking.WriteString(chunk.Message.Thinking)
// Debug output
t.Logf("Chunk %d: content=%q thinking=%q done=%v", len(chunks), chunk.Message.Content, chunk.Message.Thinking, chunk.Done)
}
// Verify we got streaming chunks
if len(chunks) == 0 {
t.Fatal("expected streaming chunks, got none")
}
gotContent := content.String()
gotThinking := thinking.String()
if gotContent != tc.wantContent {
t.Errorf("content mismatch: got %q, want %q", gotContent, tc.wantContent)
}
if gotThinking != tc.wantThinking {
t.Errorf("thinking mismatch: got %q, want %q", gotThinking, tc.wantThinking)
}
// Verify last chunk has done=true
lastChunk := chunks[len(chunks)-1]
if !lastChunk.Done {
t.Error("expected last chunk to have done=true")
}
})
}
}