Compare commits

..

8 Commits

Author SHA1 Message Date
ParthSareen
f30d01801d routes: update generate handler to use runner with harmony 2025-08-22 16:06:41 -07:00
ParthSareen
b08c7dad0a harmony: add harmony parsing to runner 2025-08-22 15:47:10 -07:00
ParthSareen
bc5ab5784b routes: ChatHandler to get parsed harmony from runner 2025-08-22 15:46:42 -07:00
ParthSareen
92a99e67c7 harmony: simplify prefill, add marshalling for functions, and update harmony check 2025-08-22 15:45:11 -07:00
ParthSareen
05cebf1f21 server: update completion request signature and update token repeat 2025-08-22 15:40:32 -07:00
ParthSareen
51a400ff0f server: add thinking and tool calls to CompletionResponse 2025-08-21 14:50:34 -07:00
ParthSareen
a865b50d9a harmony: move harmony parsing into a package 2025-08-21 12:42:48 -07:00
Daniel Hiltgen
31f64183dc perf: build graph for next batch in parallel to keep GPU busy
This refactors the main run loop of the ollama runner to perform the main GPU
intensive tasks (Compute+Floats) in a go routine so we can prepare the next
batch in parallel to reduce the amount of time the GPU stalls waiting for the
next batch of work.
2025-08-19 12:33:03 -07:00
79 changed files with 1070 additions and 2425 deletions

View File

@@ -65,36 +65,14 @@ jobs:
arch: amd64
preset: 'CUDA 12'
install: https://developer.download.nvidia.com/compute/cuda/12.8.0/local_installers/cuda_12.8.0_571.96_windows.exe
cuda-components:
- '"cudart"'
- '"nvcc"'
- '"cublas"'
- '"cublas_dev"'
cuda-version: '12.8'
flags: ''
runner_dir: 'cuda_v12'
- os: windows
arch: amd64
preset: 'CUDA 13'
install: https://developer.download.nvidia.com/compute/cuda/13.0.0/local_installers/cuda_13.0.0_windows.exe
cuda-components:
- '"cudart"'
- '"nvcc"'
- '"cublas"'
- '"cublas_dev"'
- '"crt"'
- '"nvvm"'
- '"nvptxcompiler"'
cuda-version: '13.0'
flags: ''
runner_dir: 'cuda_v13'
- os: windows
arch: amd64
preset: 'ROCm 6'
install: https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-24.Q4-WinSvr2022-For-HIP.exe
rocm-version: '6.2'
flags: '-DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_C_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma" -DCMAKE_CXX_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma"'
runner_dir: ''
runs-on: ${{ matrix.arch == 'arm64' && format('{0}-{1}', matrix.os, matrix.arch) || matrix.os }}
environment: release
env:
@@ -118,7 +96,7 @@ jobs:
$ErrorActionPreference = "Stop"
if ("${{ steps.cache-install.outputs.cache-hit }}" -ne 'true') {
Invoke-WebRequest -Uri "${{ matrix.install }}" -OutFile "install.exe"
$subpackages = @(${{ join(matrix.cuda-components, ', ') }}) | Foreach-Object {"${_}_${{ matrix.cuda-version }}"}
$subpackages = @("cudart", "nvcc", "cublas", "cublas_dev") | Foreach-Object {"${_}_${{ matrix.cuda-version }}"}
Start-Process -FilePath .\install.exe -ArgumentList (@("-s") + $subpackages) -NoNewWindow -Wait
}
@@ -160,7 +138,7 @@ jobs:
run: |
Import-Module 'C:\Program Files\Microsoft Visual Studio\2022\Enterprise\Common7\Tools\Microsoft.VisualStudio.DevShell.dll'
Enter-VsDevShell -VsInstallPath 'C:\Program Files\Microsoft Visual Studio\2022\Enterprise' -SkipAutomaticLocation -DevCmdArguments '-arch=x64 -no_logo'
cmake --preset "${{ matrix.preset }}" ${{ matrix.flags }} -DOLLAMA_RUNNER_DIR="${{ matrix.runner_dir }}"
cmake --preset "${{ matrix.preset }}" ${{ matrix.flags }}
cmake --build --parallel --preset "${{ matrix.preset }}"
cmake --install build --component "${{ startsWith(matrix.preset, 'CUDA ') && 'CUDA' || startsWith(matrix.preset, 'ROCm ') && 'HIP' || 'CPU' }}" --strip --parallel 8
env:
@@ -254,7 +232,7 @@ jobs:
case "$COMPONENT" in
bin/ollama) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
lib/ollama/*.so*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
lib/ollama/cuda_v*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
lib/ollama/cuda_sbsa) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
lib/ollama/cuda_jetpack5) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-jetpack5.tar.in ;;
lib/ollama/cuda_jetpack6) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-jetpack6.tar.in ;;
lib/ollama/rocm) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-rocm.tar.in ;;

View File

@@ -46,7 +46,7 @@ jobs:
include:
- preset: CPU
- preset: CUDA
container: nvidia/cuda:13.0.0-devel-ubuntu22.04
container: nvidia/cuda:12.8.1-devel-ubuntu22.04
flags: '-DCMAKE_CUDA_ARCHITECTURES=87'
- preset: ROCm
container: rocm/dev-ubuntu-22.04:6.1.2
@@ -78,17 +78,8 @@ jobs:
include:
- preset: CPU
- preset: CUDA
install: https://developer.download.nvidia.com/compute/cuda/13.0.0/local_installers/cuda_13.0.0_windows.exe
install: https://developer.download.nvidia.com/compute/cuda/12.8.0/local_installers/cuda_12.8.0_571.96_windows.exe
flags: '-DCMAKE_CUDA_ARCHITECTURES=80'
cuda-components:
- '"cudart"'
- '"nvcc"'
- '"cublas"'
- '"cublas_dev"'
- '"crt"'
- '"nvvm"'
- '"nvptxcompiler"'
cuda-version: '13.0'
- preset: ROCm
install: https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-24.Q4-WinSvr2022-For-HIP.exe
flags: '-DAMDGPU_TARGETS=gfx1010 -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_C_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma" -DCMAKE_CXX_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma"'
@@ -111,8 +102,7 @@ jobs:
$ErrorActionPreference = "Stop"
if ("${{ steps.cache-install.outputs.cache-hit }}" -ne 'true') {
Invoke-WebRequest -Uri "${{ matrix.install }}" -OutFile "install.exe"
$subpackages = @(${{ join(matrix.cuda-components, ', ') }}) | Foreach-Object {"${_}_${{ matrix.cuda-version }}"}
Start-Process -FilePath .\install.exe -ArgumentList (@("-s") + $subpackages) -NoNewWindow -Wait
Start-Process -FilePath .\install.exe -ArgumentList (@("-s", "cudart_12.8", "nvcc_12.8", "cublas_12.8", "cublas_dev_12.8")) -NoNewWindow -Wait
}
$cudaPath = (Resolve-Path "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\*").path

View File

@@ -38,7 +38,7 @@ if (CMAKE_OSX_ARCHITECTURES MATCHES "x86_64")
endif()
set(OLLAMA_BUILD_DIR ${CMAKE_BINARY_DIR}/lib/ollama)
set(OLLAMA_INSTALL_DIR ${CMAKE_INSTALL_PREFIX}/lib/ollama/${OLLAMA_RUNNER_DIR})
set(OLLAMA_INSTALL_DIR ${CMAKE_INSTALL_PREFIX}/lib/ollama)
set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${OLLAMA_BUILD_DIR})
set(CMAKE_RUNTIME_OUTPUT_DIRECTORY_DEBUG ${OLLAMA_BUILD_DIR})
@@ -81,7 +81,7 @@ if(CMAKE_CUDA_COMPILER)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-cuda)
install(TARGETS ggml-cuda
RUNTIME_DEPENDENCIES
DIRECTORIES ${CUDAToolkit_BIN_DIR} ${CUDAToolkit_BIN_DIR}/x64 ${CUDAToolkit_LIBRARY_DIR}
DIRECTORIES ${CUDAToolkit_BIN_DIR} ${CUDAToolkit_LIBRARY_DIR}
PRE_INCLUDE_REGEXES cublas cublasLt cudart
PRE_EXCLUDE_REGEXES ".*"
RUNTIME DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT CUDA

View File

@@ -18,14 +18,6 @@
"name": "CUDA",
"inherits": [ "Default" ]
},
{
"name": "CUDA 11",
"inherits": [ "CUDA" ],
"cacheVariables": {
"CMAKE_CUDA_ARCHITECTURES": "50-virtual;60-virtual;61-virtual;70-virtual;75-virtual;80-virtual;86-virtual;87-virtual;89-virtual;90-virtual",
"CMAKE_CUDA_FLAGS": "-Wno-deprecated-gpu-targets -t 2"
}
},
{
"name": "CUDA 12",
"inherits": [ "CUDA" ],
@@ -34,14 +26,6 @@
"CMAKE_CUDA_FLAGS": "-Wno-deprecated-gpu-targets -t 2"
}
},
{
"name": "CUDA 13",
"inherits": [ "CUDA" ],
"cacheVariables": {
"CMAKE_CUDA_ARCHITECTURES": "75-virtual;80-virtual;86-virtual;87-virtual;89-virtual;90-virtual;90a-virtual;100-virtual;110-virtual;120-virtual;121-virtual",
"CMAKE_CUDA_FLAGS": "-t 2"
}
},
{
"name": "JetPack 5",
"inherits": [ "CUDA" ],
@@ -88,21 +72,11 @@
"configurePreset": "CUDA",
"targets": [ "ggml-cuda" ]
},
{
"name": "CUDA 11",
"inherits": [ "CUDA" ],
"configurePreset": "CUDA 11"
},
{
"name": "CUDA 12",
"inherits": [ "CUDA" ],
"configurePreset": "CUDA 12"
},
{
"name": "CUDA 13",
"inherits": [ "CUDA" ],
"configurePreset": "CUDA 13"
},
{
"name": "JetPack 5",
"inherits": [ "CUDA" ],

View File

@@ -39,35 +39,15 @@ RUN --mount=type=cache,target=/root/.ccache \
&& cmake --build --parallel --preset 'CPU' \
&& cmake --install build --component CPU --strip --parallel 8
FROM base AS cuda-11
ARG CUDA11VERSION=11.8
RUN dnf install -y cuda-toolkit-${CUDA11VERSION//./-}
ENV PATH=/usr/local/cuda-11/bin:$PATH
RUN --mount=type=cache,target=/root/.ccache \
cmake --preset 'CUDA 11' -DOLLAMA_RUNNER_DIR="cuda_v11" \
&& cmake --build --parallel --preset 'CUDA 11' \
&& cmake --install build --component CUDA --strip --parallel 8
FROM base AS cuda-12
ARG CUDA12VERSION=12.8
RUN dnf install -y cuda-toolkit-${CUDA12VERSION//./-}
ENV PATH=/usr/local/cuda-12/bin:$PATH
RUN --mount=type=cache,target=/root/.ccache \
cmake --preset 'CUDA 12' -DOLLAMA_RUNNER_DIR="cuda_v12"\
cmake --preset 'CUDA 12' \
&& cmake --build --parallel --preset 'CUDA 12' \
&& cmake --install build --component CUDA --strip --parallel 8
FROM base AS cuda-13
ARG CUDA13VERSION=13.0
RUN dnf install -y cuda-toolkit-${CUDA13VERSION//./-}
ENV PATH=/usr/local/cuda-13/bin:$PATH
RUN --mount=type=cache,target=/root/.ccache \
cmake --preset 'CUDA 13' -DOLLAMA_RUNNER_DIR="cuda_v13" \
&& cmake --build --parallel --preset 'CUDA 13' \
&& cmake --install build --component CUDA --strip --parallel 8
FROM base AS rocm-6
ENV PATH=/opt/rocm/hcc/bin:/opt/rocm/hip/bin:/opt/rocm/bin:/opt/rocm/hcc/bin:$PATH
RUN --mount=type=cache,target=/root/.ccache \
@@ -112,14 +92,10 @@ RUN --mount=type=cache,target=/root/.cache/go-build \
go build -trimpath -buildmode=pie -o /bin/ollama .
FROM --platform=linux/amd64 scratch AS amd64
# COPY --from=cuda-11 dist/lib/ollama/ /lib/ollama/
COPY --from=cuda-12 dist/lib/ollama /lib/ollama/
COPY --from=cuda-13 dist/lib/ollama/ /lib/ollama/
COPY --from=cuda-12 dist/lib/ollama /lib/ollama
FROM --platform=linux/arm64 scratch AS arm64
# COPY --from=cuda-11 dist/lib/ollama/ /lib/ollama/
COPY --from=cuda-12 dist/lib/ollama /lib/ollama/
COPY --from=cuda-13 dist/lib/ollama/ /lib/ollama/
COPY --from=cuda-12 dist/lib/ollama /lib/ollama/cuda_sbsa
COPY --from=jetpack-5 dist/lib/ollama /lib/ollama/cuda_jetpack5
COPY --from=jetpack-6 dist/lib/ollama /lib/ollama/cuda_jetpack6

View File

@@ -413,8 +413,6 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [Mayan EDMS](https://gitlab.com/mayan-edms/mayan-edms) (Open source document management system to organize, tag, search, and automate your files with powerful Ollama driven workflows.)
- [Serene Pub](https://github.com/doolijb/serene-pub) (Beginner friendly, open source AI Roleplaying App for Windows, Mac OS and Linux. Search, download and use models with Ollama all inside the app.)
- [Andes](https://github.com/aqerd/andes) (A Visual Studio Code extension that provides a local UI interface for Ollama models)
- [Clueless](https://github.com/KashyapTan/clueless) (Open Source & Local Cluely: A desktop application LLM assistant to help you talk to anything on your screen using locally served Ollama models. Also undetectable to screenshare)
- [ollama-co2](https://github.com/carbonatedWaterOrg/ollama-co2) (FastAPI web interface for monitoring and managing local and remote Ollama servers with real-time model monitoring and concurrent downloads)
### Cloud
@@ -543,7 +541,6 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [OllamaPlusPlus](https://github.com/HardCodeDev777/OllamaPlusPlus) (Very simple C++ library for Ollama)
- [any-llm](https://github.com/mozilla-ai/any-llm) (A single interface to use different llm providers by [mozilla.ai](https://www.mozilla.ai/))
- [any-agent](https://github.com/mozilla-ai/any-agent) (A single interface to use and evaluate different agent frameworks by [mozilla.ai](https://www.mozilla.ai/))
- [Neuro SAN](https://github.com/cognizant-ai-lab/neuro-san-studio) (Data-driven multi-agent orchestration framework) with [example](https://github.com/cognizant-ai-lab/neuro-san-studio/blob/main/docs/user_guide.md#ollama)
### Mobile
@@ -604,7 +601,6 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [UnityCodeLama](https://github.com/HardCodeDev777/UnityCodeLama) (Unity Edtior tool to analyze scripts via Ollama)
- [NativeMind](https://github.com/NativeMindBrowser/NativeMindExtension) (Private, on-device AI Assistant, no cloud dependencies)
- [GMAI - Gradle Managed AI](https://gmai.premex.se/) (Gradle plugin for automated Ollama lifecycle management during build phases)
- [NOMYO Router](https://github.com/nomyo-ai/nomyo-router) (A transparent Ollama proxy with model deployment aware routing which auto-manages multiple Ollama instances in a given network)
### Supported backends

View File

@@ -286,23 +286,16 @@ func mapToTypeScriptType(jsonType string) string {
}
}
type ToolFunctionParameters struct {
Type string `json:"type"`
Defs any `json:"$defs,omitempty"`
Items any `json:"items,omitempty"`
Required []string `json:"required"`
Properties map[string]ToolProperty `json:"properties"`
}
func (t *ToolFunctionParameters) String() string {
bts, _ := json.Marshal(t)
return string(bts)
}
type ToolFunction struct {
Name string `json:"name"`
Description string `json:"description"`
Parameters ToolFunctionParameters `json:"parameters"`
Name string `json:"name"`
Description string `json:"description"`
Parameters struct {
Type string `json:"type"`
Defs any `json:"$defs,omitempty"`
Items any `json:"items,omitempty"`
Required []string `json:"required"`
Properties map[string]ToolProperty `json:"properties"`
} `json:"parameters"`
}
func (t *ToolFunction) String() string {
@@ -388,12 +381,8 @@ type EmbedRequest struct {
// this request.
KeepAlive *Duration `json:"keep_alive,omitempty"`
// Truncate truncates the input to fit the model's max sequence length.
Truncate *bool `json:"truncate,omitempty"`
// Dimensions truncates the output embedding to the specified dimension.
Dimensions int `json:"dimensions,omitempty"`
// Options lists model-specific options.
Options map[string]any `json:"options"`
}
@@ -892,7 +881,7 @@ func (d *Duration) UnmarshalJSON(b []byte) (err error) {
if t < 0 {
d.Duration = time.Duration(math.MaxInt64)
} else {
d.Duration = time.Duration(t * float64(time.Second))
d.Duration = time.Duration(int(t) * int(time.Second))
}
case string:
d.Duration, err = time.ParseDuration(t)

View File

@@ -17,11 +17,6 @@ func TestKeepAliveParsingFromJSON(t *testing.T) {
req string
exp *Duration
}{
{
name: "Unset",
req: `{ }`,
exp: nil,
},
{
name: "Positive Integer",
req: `{ "keep_alive": 42 }`,
@@ -30,7 +25,7 @@ func TestKeepAliveParsingFromJSON(t *testing.T) {
{
name: "Positive Float",
req: `{ "keep_alive": 42.5 }`,
exp: &Duration{42500 * time.Millisecond},
exp: &Duration{42 * time.Second},
},
{
name: "Positive Integer String",
@@ -441,50 +436,3 @@ func TestThinking_UnmarshalJSON(t *testing.T) {
})
}
}
func TestToolFunctionParameters_String(t *testing.T) {
tests := []struct {
name string
params ToolFunctionParameters
expected string
}{
{
name: "simple object with string property",
params: ToolFunctionParameters{
Type: "object",
Required: []string{"name"},
Properties: map[string]ToolProperty{
"name": {
Type: PropertyType{"string"},
Description: "The name of the person",
},
},
},
expected: `{"type":"object","required":["name"],"properties":{"name":{"type":"string","description":"The name of the person"}}}`,
},
{
name: "marshal failure returns empty string",
params: ToolFunctionParameters{
Type: "object",
Defs: func() any {
// Create a cycle that will cause json.Marshal to fail
type selfRef struct {
Self *selfRef
}
s := &selfRef{}
s.Self = s
return s
}(),
Properties: map[string]ToolProperty{},
},
expected: "",
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
result := test.params.String()
assert.Equal(t, test.expected, result)
})
}
}

View File

@@ -56,8 +56,10 @@ func ensureThinkingSupport(ctx context.Context, client *api.Client, name string)
if err != nil {
return
}
if slices.Contains(resp.Capabilities, model.CapabilityThinking) {
return
for _, cap := range resp.Capabilities {
if cap == model.CapabilityThinking {
return
}
}
fmt.Fprintf(os.Stderr, "warning: model %q does not support thinking output\n", name)
}

View File

@@ -15,24 +15,19 @@ import (
type gptossModel struct {
ModelParameters
HiddenLayers uint32 `json:"num_hidden_layers"`
MaxPositionEmbeddings uint32 `json:"max_position_embeddings"`
HiddenSize uint32 `json:"hidden_size"`
IntermediateSize uint32 `json:"intermediate_size"`
AttentionHeads uint32 `json:"num_attention_heads"`
KeyValueHeads uint32 `json:"num_key_value_heads"`
HeadDim uint32 `json:"head_dim"`
Experts uint32 `json:"num_experts"`
LocalExperts uint32 `json:"num_local_experts"`
ExpertsPerToken uint32 `json:"experts_per_token"`
RMSNormEpsilon float32 `json:"rms_norm_eps"`
InitialContextLength uint32 `json:"initial_context_length"`
RopeTheta float32 `json:"rope_theta"`
RopeScalingFactor float32 `json:"rope_scaling_factor"`
RopeScaling struct {
Factor float32 `json:"factor"`
} `json:"rope_scaling"`
SlidingWindow uint32 `json:"sliding_window"`
HiddenLayers uint32 `json:"num_hidden_layers"`
HiddenSize uint32 `json:"hidden_size"`
IntermediateSize uint32 `json:"intermediate_size"`
AttentionHeads uint32 `json:"num_attention_heads"`
KeyValueHeads uint32 `json:"num_key_value_heads"`
HeadDim uint32 `json:"head_dim"`
Experts uint32 `json:"num_experts"`
ExpertsPerToken uint32 `json:"experts_per_token"`
RMSNormEpsilon float32 `json:"rms_norm_eps"`
InitialContextLength uint32 `json:"initial_context_length"`
RopeTheta float32 `json:"rope_theta"`
RopeScalingFactor float32 `json:"rope_scaling_factor"`
SlidingWindow uint32 `json:"sliding_window"`
}
var _ ModelConverter = (*gptossModel)(nil)
@@ -41,11 +36,11 @@ func (m *gptossModel) KV(t *Tokenizer) ggml.KV {
kv := m.ModelParameters.KV(t)
kv["general.architecture"] = "gptoss"
kv["general.file_type"] = uint32(4)
kv["gptoss.context_length"] = cmp.Or(m.MaxPositionEmbeddings, uint32(m.RopeScalingFactor*float32(m.InitialContextLength)))
kv["gptoss.context_length"] = uint32(m.RopeScalingFactor * float32(m.InitialContextLength))
kv["gptoss.block_count"] = m.HiddenLayers
kv["gptoss.embedding_length"] = m.HiddenSize
kv["gptoss.feed_forward_length"] = m.IntermediateSize
kv["gptoss.expert_count"] = cmp.Or(m.Experts, m.LocalExperts)
kv["gptoss.expert_count"] = m.Experts
kv["gptoss.expert_used_count"] = m.ExpertsPerToken
kv["gptoss.attention.head_count"] = m.AttentionHeads
kv["gptoss.attention.head_count_kv"] = m.KeyValueHeads
@@ -54,7 +49,7 @@ func (m *gptossModel) KV(t *Tokenizer) ggml.KV {
kv["gptoss.attention.layer_norm_rms_epsilon"] = cmp.Or(m.RMSNormEpsilon, 1e-5)
kv["gptoss.attention.sliding_window"] = m.SlidingWindow
kv["gptoss.rope.freq_base"] = m.RopeTheta
kv["gptoss.rope.scaling.factor"] = cmp.Or(m.RopeScalingFactor, m.RopeScaling.Factor)
kv["gptoss.rope.scaling.factor"] = m.RopeScalingFactor
kv["gptoss.rope.scaling.original_context_length"] = m.InitialContextLength
kv["tokenizer.ggml.bos_token_id"] = uint32(199998) // <|startoftext|>
kv["tokenizer.ggml.add_bos_token"] = false
@@ -97,11 +92,6 @@ func (m *gptossModel) Tensors(ts []Tensor) []*ggml.Tensor {
for name, mxfp4 := range mxfp4s {
dims := mxfp4.blocks.Shape()
if !strings.HasSuffix(name, ".weight") {
name += ".weight"
}
out = append(out, &ggml.Tensor{
Name: name,
Kind: uint32(ggml.TensorTypeMXFP4),
@@ -114,47 +104,25 @@ func (m *gptossModel) Tensors(ts []Tensor) []*ggml.Tensor {
}
func (m *gptossModel) Replacements() []string {
var replacements []string
if m.MaxPositionEmbeddings > 0 {
// hf flavored model
replacements = []string{
"lm_head", "output",
"model.embed_tokens", "token_embd",
"model.layers", "blk",
"input_layernorm", "attn_norm",
"self_attn.q_proj", "attn_q",
"self_attn.k_proj", "attn_k",
"self_attn.v_proj", "attn_v",
"self_attn.o_proj", "attn_out",
"self_attn.sinks", "attn_sinks",
"post_attention_layernorm", "ffn_norm",
"mlp.router", "ffn_gate_inp",
"mlp.experts.gate_up_proj_", "ffn_gate_up_exps.",
"mlp.experts.down_proj_", "ffn_down_exps.",
"model.norm", "output_norm",
}
} else {
replacements = []string{
// noop replacements so other replacements will not be applied
".blocks", ".blocks",
".scales", ".scales",
// real replacements
"block", "blk",
"attn.norm", "attn_norm",
"attn.qkv", "attn_qkv",
"attn.sinks", "attn_sinks",
"attn.out", "attn_out",
"mlp.norm", "ffn_norm",
"mlp.gate", "ffn_gate_inp",
"mlp.mlp1_", "ffn_gate_up_exps.",
"mlp.mlp2_", "ffn_down_exps.",
"embedding", "token_embd",
"norm", "output_norm",
"unembedding", "output",
"scale", "weight",
}
return []string{
// noop replacements so other replacements will not be applied
".blocks", ".blocks",
".scales", ".scales",
// real replacements
"block", "blk",
"attn.norm", "attn_norm",
"attn.qkv", "attn_qkv",
"attn.sinks", "attn_sinks",
"attn.out", "attn_out",
"mlp.norm", "ffn_norm",
"mlp.gate", "ffn_gate_inp",
"mlp.mlp1_", "ffn_gate_up_exps.",
"mlp.mlp2_", "ffn_down_exps.",
"embedding", "token_embd",
"norm", "output_norm",
"unembedding", "output",
"scale", "weight",
}
return replacements
}
type mxfp4 struct {
@@ -172,20 +140,7 @@ func (m *mxfp4) WriteTo(w io.Writer) (int64, error) {
blocksDims[i] = int(d)
}
bts := b.Bytes()
var tmp [16]byte
for i := 0; i < b.Len(); i += 16 {
for j := range 8 {
// transform a1b2c3 ... x7y8z9 -> 71xa82yb93zc
a, b := bts[i+j], bts[i+j+8]
tmp[2*j+0] = (a & 0x0F) | (b << 4)
tmp[2*j+1] = (a >> 4) | (b & 0xF0)
}
copy(bts[i:i+16], tmp[:])
}
var blocks tensor.Tensor = tensor.New(tensor.WithShape(blocksDims...), tensor.WithBacking(bts))
var blocks tensor.Tensor = tensor.New(tensor.WithShape(blocksDims...), tensor.WithBacking(b.Bytes()))
var s bytes.Buffer
if _, err := m.scales.WriteTo(&s); err != nil {
@@ -219,5 +174,5 @@ func (m *mxfp4) WriteTo(w io.Writer) (int64, error) {
return 0, err
}
return int64(len(u8s)), nil
return 0, nil
}

View File

@@ -33,8 +33,8 @@ func (t tensorBase) Shape() []uint64 {
const (
tensorKindFP32 uint32 = iota
tensorKindFP16
tensorKindMXFP4 = 4
tensorKindBF16 = 30
tensorKindMXFP4 = 39
)
func (t tensorBase) Kind() uint32 {

View File

@@ -188,17 +188,17 @@ func (st safetensor) WriteTo(w io.Writer) (int64, error) {
switch st.Kind() {
case tensorKindFP32:
return int64(len(f32s) * 4), binary.Write(w, binary.LittleEndian, f32s)
return 0, binary.Write(w, binary.LittleEndian, f32s)
case tensorKindFP16:
f16s := make([]uint16, len(f32s))
for i := range f32s {
f16s[i] = float16.Fromfloat32(f32s[i]).Bits()
}
return int64(len(f16s) * 2), binary.Write(w, binary.LittleEndian, f16s)
return 0, binary.Write(w, binary.LittleEndian, f16s)
case tensorKindBF16:
u8s := bfloat16.EncodeFloat32(f32s)
return int64(len(u8s)), binary.Write(w, binary.LittleEndian, u8s)
return 0, binary.Write(w, binary.LittleEndian, u8s)
default:
return 0, fmt.Errorf("unknown storage type: %d", st.Kind())
}

View File

@@ -277,7 +277,6 @@ func AMDGetGPUInfo() ([]RocmGPUInfo, error) {
FreeMemory: (totalMemory - usedMemory),
},
ID: ID,
filterID: gpuOrdinalID,
Name: name,
Compute: fmt.Sprintf("gfx%d%x%x", major, minor, patch),
MinimumMemory: rocmMinimumMemory,
@@ -395,7 +394,7 @@ func AMDGetGPUInfo() ([]RocmGPUInfo, error) {
// Check for env var workarounds
if name == "1002:687f" { // Vega RX 56
gpuInfo.EnvWorkarounds = append(gpuInfo.EnvWorkarounds, "HSA_ENABLE_SDMA=0")
gpuInfo.EnvWorkarounds = append(gpuInfo.EnvWorkarounds, [2]string{"HSA_ENABLE_SDMA", "0"})
}
// The GPU has passed all the verification steps and is supported
@@ -524,26 +523,19 @@ func verifyKFDDriverAccess() error {
return nil
}
func rocmGetVisibleDevicesEnv(gpuInfo []GpuInfo) string {
func rocmGetVisibleDevicesEnv(gpuInfo []GpuInfo) (string, string) {
ids := []string{}
for _, info := range gpuInfo {
if info.Library != "rocm" {
// TODO shouldn't happen if things are wired correctly...
slog.Debug("rocmGetVisibleDevicesEnv skipping over non-rocm device", "library", info.Library)
continue
}
// If the devices requires a numeric ID, for filtering purposes, we use the unfiltered ID number
if _, err := strconv.Atoi(info.ID); err == nil {
ids = append(ids, fmt.Sprintf("%d", info.filterID))
} else {
ids = append(ids, info.ID)
}
ids = append(ids, info.ID)
}
if len(ids) == 0 {
return ""
}
// There are 3 potential env vars to use to select GPUs.
// ROCR_VISIBLE_DEVICES supports UUID or numeric so is our preferred on linux
// GPU_DEVICE_ORDINAL supports numeric IDs only
// HIP_VISIBLE_DEVICES supports numeric IDs only
return "ROCR_VISIBLE_DEVICES=" + strings.Join(ids, ",")
return "ROCR_VISIBLE_DEVICES", strings.Join(ids, ",")
}

View File

@@ -111,7 +111,6 @@ func AMDGetGPUInfo() ([]RocmGPUInfo, error) {
UnreliableFreeMemory: true,
ID: strconv.Itoa(i), // TODO this is probably wrong if we specify visible devices
filterID: i,
DependencyPath: []string{libDir},
MinimumMemory: rocmMinimumMemory,
Name: name,
@@ -201,26 +200,19 @@ func (gpus RocmGPUInfoList) RefreshFreeMemory() error {
return nil
}
func rocmGetVisibleDevicesEnv(gpuInfo []GpuInfo) string {
func rocmGetVisibleDevicesEnv(gpuInfo []GpuInfo) (string, string) {
ids := []string{}
for _, info := range gpuInfo {
if info.Library != "rocm" {
// TODO shouldn't happen if things are wired correctly...
slog.Debug("rocmGetVisibleDevicesEnv skipping over non-rocm device", "library", info.Library)
continue
}
// If the devices requires a numeric ID, for filtering purposes, we use the unfiltered ID number
if _, err := strconv.Atoi(info.ID); err == nil {
ids = append(ids, fmt.Sprintf("%d", info.filterID))
} else {
ids = append(ids, info.ID)
}
ids = append(ids, info.ID)
}
if len(ids) == 0 {
return ""
}
// There are 3 potential env vars to use to select GPUs.
// ROCR_VISIBLE_DEVICES supports UUID or numeric but does not work on Windows
// HIP_VISIBLE_DEVICES supports numeric IDs only
// GPU_DEVICE_ORDINAL supports numeric IDs only
return "HIP_VISIBLE_DEVICES=" + strings.Join(ids, ",")
return "HIP_VISIBLE_DEVICES", strings.Join(ids, ",")
}

View File

@@ -16,6 +16,19 @@ import (
// Included to drive logic for reducing Ollama-allocated overhead on L4T/Jetson devices.
var CudaTegra string = os.Getenv("JETSON_JETPACK")
func cudaGetVisibleDevicesEnv(gpuInfo []GpuInfo) (string, string) {
ids := []string{}
for _, info := range gpuInfo {
if info.Library != "cuda" {
// TODO shouldn't happen if things are wired correctly...
slog.Debug("cudaGetVisibleDevicesEnv skipping over non-cuda device", "library", info.Library)
continue
}
ids = append(ids, info.ID)
}
return "CUDA_VISIBLE_DEVICES", strings.Join(ids, ",")
}
func cudaVariant(gpuInfo CudaGPUInfo) string {
if runtime.GOARCH == "arm64" && runtime.GOOS == "linux" {
if CudaTegra != "" {
@@ -43,15 +56,14 @@ func cudaVariant(gpuInfo CudaGPUInfo) string {
}
}
}
return "sbsa"
}
if gpuInfo.DriverMajor < 13 {
// The detected driver is older than 580 (Aug 2025)
// Warn if their CC is compatible with v13 and they should upgrade their driver to get better performance
if gpuInfo.computeMajor > 7 || (gpuInfo.computeMajor == 7 && gpuInfo.computeMinor >= 5) {
slog.Warn("old CUDA driver detected - please upgrade to a newer driver for best performance", "version", fmt.Sprintf("%d.%d", gpuInfo.DriverMajor, gpuInfo.DriverMinor))
}
return "v12"
// driver 12.0 has problems with the cuda v12 library, so run v11 on those older drivers
if gpuInfo.DriverMajor < 12 || (gpuInfo.DriverMajor == 12 && gpuInfo.DriverMinor == 0) {
// The detected driver is older than Feb 2023
slog.Warn("old CUDA driver detected - please upgrade to a newer driver", "version", fmt.Sprintf("%d.%d", gpuInfo.DriverMajor, gpuInfo.DriverMinor))
return "v11"
}
return "v13"
return "v12"
}

View File

@@ -371,15 +371,6 @@ func GetGPUInfo() GpuInfoList {
}
rocmGPUs, err = AMDGetGPUInfo()
// The ID field is used in context of the filtered set of GPUS
// so we have to replace any of these numeric IDs with their
// placement in this set of GPUs
for i := range rocmGPUs {
if _, err := strconv.Atoi(rocmGPUs[i].ID); err == nil {
rocmGPUs[i].ID = strconv.Itoa(i)
}
}
if err != nil {
bootstrapErrors = append(bootstrapErrors, err)
}
@@ -689,16 +680,23 @@ func getVerboseState() C.uint16_t {
// Given the list of GPUs this instantiation is targeted for,
// figure out the visible devices environment variable
func (l GpuInfoList) GetVisibleDevicesEnv() []string {
//
// If different libraries are detected, the first one is what we use
func (l GpuInfoList) GetVisibleDevicesEnv() (string, string) {
if len(l) == 0 {
return nil
return "", ""
}
vd := []string{}
// Only filter the AMD GPUs at this level, let all NVIDIA devices through
if tmp := rocmGetVisibleDevicesEnv(l); tmp != "" {
vd = append(vd, tmp)
switch l[0].Library {
case "cuda":
return cudaGetVisibleDevicesEnv(l)
case "rocm":
return rocmGetVisibleDevicesEnv(l)
case "oneapi":
return oneapiGetVisibleDevicesEnv(l)
default:
slog.Debug("no filter required for library " + l[0].Library)
return "", ""
}
return vd
}
func GetSystemInfo() SystemInfo {

View File

@@ -62,9 +62,9 @@ func GetCPUMem() (memInfo, error) {
}, nil
}
func (l GpuInfoList) GetVisibleDevicesEnv() []string {
func (l GpuInfoList) GetVisibleDevicesEnv() (string, string) {
// No-op on darwin
return nil
return "", ""
}
func GetSystemInfo() SystemInfo {

21
discover/gpu_oneapi.go Normal file
View File

@@ -0,0 +1,21 @@
//go:build linux || windows
package discover
import (
"log/slog"
"strings"
)
func oneapiGetVisibleDevicesEnv(gpuInfo []GpuInfo) (string, string) {
ids := []string{}
for _, info := range gpuInfo {
if info.Library != "oneapi" {
// TODO shouldn't happen if things are wired correctly...
slog.Debug("oneapiGetVisibleDevicesEnv skipping over non-sycl device", "library", info.Library)
continue
}
ids = append(ids, info.ID)
}
return "ONEAPI_DEVICE_SELECTOR", "level_zero:" + strings.Join(ids, ",")
}

View File

@@ -27,8 +27,8 @@ type GpuInfo struct { // TODO better name maybe "InferenceProcessor"?
// Any extra PATH/LD_LIBRARY_PATH dependencies required for the Library to operate properly
DependencyPath []string `json:"lib_path,omitempty"`
// Extra environment variables specific to the GPU as list of [key=value]
EnvWorkarounds []string `json:"envs,omitempty"`
// Extra environment variables specific to the GPU as list of [key,value]
EnvWorkarounds [][2]string `json:"envs,omitempty"`
// Set to true if we can NOT reliably discover FreeMemory. A value of true indicates
// the FreeMemory is best effort, and may over or under report actual memory usage
@@ -36,10 +36,9 @@ type GpuInfo struct { // TODO better name maybe "InferenceProcessor"?
UnreliableFreeMemory bool
// GPU information
ID string `json:"gpu_id"` // string to use for selection of this specific GPU
filterID int //nolint:unused,nolintlint // AMD Workaround: The numeric ID of the device used to filter out other devices
Name string `json:"name"` // user friendly name if available
Compute string `json:"compute"` // Compute Capability or gfx
ID string `json:"gpu_id"` // string to use for selection of this specific GPU
Name string `json:"name"` // user friendly name if available
Compute string `json:"compute"` // Compute Capability or gfx
// Driver Information - TODO no need to put this on each GPU
DriverMajor int `json:"driver_major,omitempty"`

View File

@@ -1708,7 +1708,6 @@ Advanced parameters:
- `truncate`: truncates the end of each input to fit within context length. Returns error if `false` and context length is exceeded. Defaults to `true`
- `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.md#valid-parameters-and-values) such as `temperature`
- `keep_alive`: controls how long the model will stay loaded into memory following the request (default: `5m`)
- `dimensions`: number of dimensions for the embedding
### Examples

View File

@@ -11,13 +11,12 @@ curl -fsSL https://ollama.com/install.sh | sh
## Manual install
> [!NOTE]
> If you are upgrading from a prior version, you **MUST** remove the old libraries with `sudo rm -rf /usr/lib/ollama` first.
> If you are upgrading from a prior version, you should remove the old libraries with `sudo rm -rf /usr/lib/ollama` first.
Download and extract the package:
```shell
curl -LO https://ollama.com/download/ollama-linux-amd64.tgz
sudo rm -rf /usr/lib/ollama
sudo tar -C /usr -xzf ollama-linux-amd64.tgz
```

View File

@@ -92,9 +92,6 @@ If none of those resolve the problem, gather additional information and file an
- Set `CUDA_ERROR_LEVEL=50` and try again to get more diagnostic logs
- Check dmesg for any errors `sudo dmesg | grep -i nvrm` and `sudo dmesg | grep -i nvidia`
You may get more details for initialization failures by enabling debug prints in the uvm driver. You should only use this temporarily while troubleshooting
- `sudo rmmod nvidia_uvm` then `sudo modprobe nvidia_uvm uvm_debug_prints=1`
## AMD GPU Discovery

View File

@@ -185,6 +185,8 @@ var (
ContextLength = Uint("OLLAMA_CONTEXT_LENGTH", 4096)
// Auth enables authentication between the Ollama client and server
UseAuth = Bool("OLLAMA_AUTH")
// Enable the new memory estimation logic
NewMemoryEstimates = Bool("OLLAMA_NEW_ESTIMATES")
)
func String(s string) func() string {
@@ -270,6 +272,7 @@ func AsMap() map[string]EnvVar {
"OLLAMA_MULTIUSER_CACHE": {"OLLAMA_MULTIUSER_CACHE", MultiUserCache(), "Optimize prompt caching for multi-user scenarios"},
"OLLAMA_CONTEXT_LENGTH": {"OLLAMA_CONTEXT_LENGTH", ContextLength(), "Context length to use unless otherwise specified (default: 4096)"},
"OLLAMA_NEW_ENGINE": {"OLLAMA_NEW_ENGINE", NewEngine(), "Enable the new Ollama engine"},
"OLLAMA_NEW_ESTIMATES": {"OLLAMA_NEW_ESTIMATES", NewMemoryEstimates(), "Enable the new memory estimation logic"},
// Informational
"HTTP_PROXY": {"HTTP_PROXY", String("HTTP_PROXY")(), "HTTP proxy"},

View File

@@ -7,11 +7,9 @@ import (
"fmt"
"io"
"log/slog"
"math"
"slices"
"strings"
"github.com/ollama/ollama/format"
"github.com/ollama/ollama/fs/util/bufioutil"
)
@@ -57,28 +55,10 @@ func (kv KV) EmbeddingLength() uint64 {
return uint64(kv.Uint("embedding_length"))
}
func (kv KV) HeadCount() []uint64 {
headCountDefault := uint32(1)
headCount := kv.UintOrArrayValueAsArray("attention.head_count", headCountDefault)
if len(headCount) == 1 {
headCountDefault = headCount[0]
}
nLayers := int(kv.BlockCount())
if len(headCount) > nLayers {
slog.Warn("got more elements of attention.head_count than layers", "len(headCount)", len(headCount), "layers", nLayers)
}
out := make([]uint64, nLayers)
for i := range nLayers {
if i >= len(headCount) {
out[i] = uint64(headCountDefault)
} else {
out[i] = uint64(headCount[i])
}
}
return out
}
func (kv KV) HeadCountMax() uint64 {
// TODO(drifkin): using the max value can cause an overestimation. In the
// future if array values become more popular, we can adapt the more invasive
// <https://github.com/ollama/ollama/pull/10225>
return uint64(kv.UintOrMaxArrayValue("attention.head_count", 1))
}
@@ -86,27 +66,6 @@ func (kv KV) HeadCountMin() uint64 {
return uint64(kv.UintOrMinArrayValue("attention.head_count", 1))
}
func (kv KV) HeadCountKV() []uint64 {
headCountKVDefault := uint32(1)
headCountKV := kv.UintOrArrayValueAsArray("attention.head_count_kv", headCountKVDefault)
if len(headCountKV) == 1 {
headCountKVDefault = headCountKV[0]
}
nLayers := int(kv.BlockCount())
if len(headCountKV) > nLayers {
slog.Warn("got more elements of attention.head_count than layers", "len(headCountKV)", len(headCountKV), "layers", nLayers)
}
out := make([]uint64, nLayers)
for i := range nLayers {
if i >= len(headCountKV) {
out[i] = uint64(headCountKVDefault)
} else {
out[i] = uint64(headCountKV[i])
}
}
return out
}
func (kv KV) HeadCountKVMax() uint64 {
return uint64(kv.UintOrMaxArrayValue("attention.head_count_kv", 1))
}
@@ -139,26 +98,6 @@ func (kv KV) ChatTemplate() string {
return kv.String("tokenizer.chat_template")
}
// ssm architecture parameters
func (kv KV) SSMConvKernel() uint64 {
return uint64(kv.Uint("ssm.conv_kernel"))
}
func (kv KV) SSMInnerSize() uint64 {
return uint64(kv.Uint("ssm.inner_size"))
}
func (kv KV) SSMStateSize() uint64 {
return uint64(kv.Uint("ssm.state_size"))
}
func (kv KV) SSMGroupCount() uint64 {
return uint64(kv.Uint("ssm.group_count"))
}
// general types
func (kv KV) String(key string, defaultValue ...string) string {
val, _ := keyValue(kv, key, append(defaultValue, "")...)
return val
@@ -190,27 +129,22 @@ func (kv KV) UintOrMinArrayValue(key string, defaultValue uint32) uint32 {
}
func (kv KV) UintOrArrayValue(key string, defaultValue uint32) (uint32, uint32) {
arrVal := kv.UintOrArrayValueAsArray(key, defaultValue)
return slices.Min(arrVal), slices.Max(arrVal)
}
func (kv KV) UintOrArrayValueAsArray(key string, defaultValue uint32) []uint32 {
if u32, ok := keyValue(kv, key, uint32(0)); ok {
return []uint32{u32}
return u32, u32
} else if u32s, ok := keyValue(kv, key, &array[uint32]{}); ok {
return u32s.values
min := slices.Min(u32s.values)
max := slices.Max(u32s.values)
return min, max
} else if i32s, ok := keyValue(kv, key, &array[int32]{}); ok {
dst := make([]uint32, len(i32s.values))
for i, v := range i32s.values {
if v < 0 {
slog.Warn("array values are unexpectedly negative", "key", key, "i", i, "v", v)
}
dst[i] = uint32(v)
min := slices.Min(i32s.values)
max := slices.Max(i32s.values)
if min < 0 || max < 0 {
slog.Warn("array values are unexpectedly negative", "key", key, "min", min, "max", max)
}
return dst
return uint32(min), uint32(max)
}
return []uint32{defaultValue}
return defaultValue, defaultValue
}
func (kv KV) Strings(key string, defaultValue ...[]string) []string {
@@ -341,7 +275,7 @@ type Tensor struct {
func (t Tensor) block() (n int) {
if _, err := fmt.Sscanf(t.Name, "blk.%d.", &n); err != nil {
return math.MaxInt
return -1
}
return
@@ -354,24 +288,24 @@ func (t Tensor) blockSize() uint64 {
func (t TensorType) BlockSize() uint64 {
switch t {
case
TensorTypeF32,
TensorTypeF16,
TensorTypeI8,
TensorTypeI16,
TensorTypeI32,
TensorTypeI64,
TensorTypeF64,
TensorTypeBF16:
0, // F32
1, // F16
24, // I8
25, // I16
26, // I32
27, // I64
28, // F64
30: // BF16
return 1
case
TensorTypeQ4_0,
TensorTypeQ4_1,
TensorTypeQ5_0,
TensorTypeQ5_1,
TensorTypeQ8_0,
TensorTypeQ8_1,
tensorTypeIQ4_NL,
4, TensorTypeMXFP4:
2, // Q4_0
3, // Q4_1
4, // MXFP4
6, // Q5_0
7, // Q5_1
8, // Q8_0
9, // Q8_1
20: // IQ4_NL
return 32
default:
return 256
@@ -394,6 +328,8 @@ func (t TensorType) TypeSize() uint64 {
return 2 + blockSize/2
case TensorTypeQ4_1:
return 2 + 2 + blockSize/2
case TensorTypeMXFP4, 39:
return 1 + blockSize/2
case TensorTypeQ5_0:
return 2 + 4 + blockSize/2
case TensorTypeQ5_1:
@@ -444,8 +380,6 @@ func (t TensorType) TypeSize() uint64 {
return blockSize/8 + blockSize/16 + blockSize/32
case TensorTypeBF16:
return 2
case 4, TensorTypeMXFP4:
return 1 + blockSize/2
default:
return 0
}
@@ -545,14 +479,12 @@ func Decode(rs io.ReadSeeker, maxArraySize int) (*GGML, error) {
}, nil
}
func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType string, useFlashAttention bool) (kv []uint64, partialOffload, fullOffload uint64) {
func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType string) (kv []uint64, partialOffload, fullOffload uint64) {
context *= uint64(numParallel)
embedding := f.KV().EmbeddingLength()
heads := f.KV().HeadCountMax()
headsArr := f.KV().HeadCount()
headsKV := f.KV().HeadCountKVMax()
headsKVArr := f.KV().HeadCountKV()
vocab := uint64(f.KV()["tokenizer.ggml.tokens"].(*array[string]).size)
embeddingHeads := f.KV().EmbeddingHeadCountMax()
@@ -562,51 +494,12 @@ func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType stri
layers := f.Tensors().GroupLayers()
bytesPerElement := kvCacheBytesPerElement(kvCacheType)
// Default for models unless special-cased below. These defaults mirror the
// cache usage in llama.cpp under the assumption that models without special
// cases below will use the llamarunner and caching will be handled by the
// llama.cpp layer.
//
// This also assumes that a layer without heads or headsKV set is recurrent
// which is usually the case. Some models (eg nemotronh) use "blocks" in
// place of layers where some are MLP blocks that don't have any cache.
// Models like this will need a special case below to be accurately
// estimated.
var kvTotal uint64
kv = make([]uint64, f.KV().BlockCount())
kvSizeAttn := uint64(0)
kvSizeRecurrent := uint64(0)
for i := range kv {
headsL := headsArr[i]
headsKVL := headsKVArr[i]
if headsL > 0 && headsKVL > 0 {
// full attention layer
// NOTE: Assumes uniform values for all attn layers
kv[i] = uint64(float64(context*(embeddingHeadsK+embeddingHeadsV)*headsKVL) * bytesPerElement)
kvSizeAttn += kv[i]
} else {
// recurrent layer
ssmDConv := f.KV().SSMConvKernel()
ssmDState := f.KV().SSMStateSize()
ssmDInner := f.KV().SSMInnerSize()
ssmNGroups := f.KV().SSMGroupCount()
nEmbdR := uint64(0)
if ssmDConv > 0 {
nEmbdR = (ssmDConv - 1) * (ssmDInner + 2*ssmNGroups*ssmDState)
}
nEmbdS := ssmDState * ssmDInner
// recurrent always uses F32 in llama.cpp backend
// https://github.com/ggml-org/llama.cpp/blob/master/src/llama-model.cpp#L18644
bytesPerElementRecurrent := kvCacheBytesPerElement("f32")
kv[i] = (nEmbdR + nEmbdS) * uint64(bytesPerElementRecurrent)
kvSizeRecurrent += kv[i]
}
kv[i] = uint64(float64(context*(embeddingHeadsK+embeddingHeadsV)*headsKV) * bytesPerElement)
kvTotal += kv[i]
}
slog.Debug("default cache size estimate", "attention MiB", float32(kvSizeAttn)/(1024.*1024.), "attention bytes", kvSizeAttn, "recurrent MiB", float32(kvSizeRecurrent)/(1024.*1024.), "recurrent bytes", kvSizeRecurrent)
switch f.KV().Architecture() {
case "llama", "llama4":
@@ -784,12 +677,7 @@ func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType stri
kv[i] *= context
}
}
partialOffload = 2 * f.KV().HeadCountMax() / cmp.Or(f.KV().HeadCountKVMin(), 1) * kvTotal / 6
if useFlashAttention {
// rough estimate of graph size with flash attention on
partialOffload = (4*uint64(numParallel) + context>>10 + 110) * format.MebiByte
}
}
return
@@ -864,16 +752,12 @@ func (llm GGML) VisionGraphSize() (weights, graphSize uint64) {
// SupportsKVCacheType checks if the requested cache type is supported
func (f GGML) SupportsKVCacheType(cacheType string) bool {
if cacheType == "" || cacheType == "f16" {
return true
}
if arch := f.KV().Architecture(); slices.Contains([]string{"gptoss", "gpt-oss"}, arch) {
// gpt-oss uses attention with sinks which does not support quantized cache types
slog.Warn("model only supports non-quantized cache types", "model", arch)
return false
slog.Warn("model only supports non-quantized cache types ", "mode", arch)
return cacheType == "f16"
}
return slices.Contains([]string{"q8_0", "q4_0"}, cacheType)
return slices.Contains([]string{"f16", "q8_0", "q4_0"}, cacheType)
}
// SupportsFlashAttention checks if the model supports flash attention
@@ -883,23 +767,12 @@ func (f GGML) SupportsFlashAttention() bool {
return false
}
if arch := f.KV().Architecture(); slices.Contains([]string{"gemma2"}, arch) {
return false
}
// Check head counts match and are non-zero
headCountK := f.KV().EmbeddingHeadCountK()
headCountV := f.KV().EmbeddingHeadCountV()
return headCountK != 0 && headCountV != 0 && headCountK == headCountV
}
// FlashAttention checks if the model should enable flash attention
func (f GGML) FlashAttention() bool {
return slices.Contains([]string{
"gptoss", "gpt-oss",
}, f.KV().String("general.architecture"))
}
// kvCacheBytesPerElement returns the number of bytes per element for a given KV cache type
func kvCacheBytesPerElement(cacheType string) float64 {
switch cacheType {
@@ -907,8 +780,6 @@ func kvCacheBytesPerElement(cacheType string) float64 {
return 1 // 1/2 of fp16
case "q4_0":
return 0.5 // 1/4 of fp16
case "f32":
return 4 // f32 (default for recurrent)
default:
return 2 // f16 (default)
}

View File

@@ -533,15 +533,12 @@ func WriteGGUF(f *os.File, kv KV, ts []*Tensor) error {
}
}
slices.SortStableFunc(
ts,
func(a, b *Tensor) int {
return cmp.Or(
cmp.Compare(a.block(), b.block()),
cmp.Compare(a.Name, b.Name),
)
},
)
slices.SortStableFunc(ts, func(a, b *Tensor) int {
if i, j := a.block(), b.block(); i > 0 && j > 0 {
return cmp.Compare(i, j)
}
return cmp.Compare(a.Name, b.Name)
})
var s uint64
for i := range ts {

View File

@@ -11,24 +11,24 @@ import (
)
func TestWriteGGUF(t *testing.T) {
b := bytes.NewBuffer(make([]byte, 2*3))
r := rand.New(rand.NewPCG(0, 0))
for range 8 {
t.Run("shuffle", func(t *testing.T) {
t.Parallel()
ts := []*Tensor{
{Name: "token_embd.weight", Shape: []uint64{2, 3}, WriterTo: b},
{Name: "blk.0.ffn_norm.weight", Shape: []uint64{2, 3}, WriterTo: b},
{Name: "blk.0.attn_norm.weight", Shape: []uint64{2, 3}, WriterTo: b},
{Name: "blk.1.ffn_up.weight", Shape: []uint64{2, 3}, WriterTo: b},
{Name: "blk.2.ffn_norm.weight", Shape: []uint64{2, 3}, WriterTo: b},
{Name: "blk.1.ffn_down.weight", Shape: []uint64{2, 3}, WriterTo: b},
{Name: "blk.0.attn_k.weight", Shape: []uint64{2, 3}, WriterTo: b},
{Name: "output_norm.weight", Shape: []uint64{3, 2}, WriterTo: b},
{Name: "output.weight", Shape: []uint64{3, 2}, WriterTo: b},
{Name: "token_embd.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(make([]byte, 2*3))},
{Name: "blk.0.attn_norm.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(make([]byte, 2*3))},
{Name: "blk.1.attn_norm.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(make([]byte, 2*3))},
{Name: "blk.2.attn_norm.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(make([]byte, 2*3))},
{Name: "blk.3.attn_norm.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(make([]byte, 2*3))},
{Name: "blk.4.attn_norm.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(make([]byte, 2*3))},
{Name: "blk.5.attn_norm.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(make([]byte, 2*3))},
{Name: "output_norm.weight", Shape: []uint64{3, 2}, WriterTo: bytes.NewBuffer(make([]byte, 3*2))},
{Name: "output.weight", Shape: []uint64{3, 2}, WriterTo: bytes.NewBuffer(make([]byte, 3*2))},
}
rand.Shuffle(len(ts), func(i, j int) {
r.Shuffle(len(ts), func(i, j int) {
ts[i], ts[j] = ts[j], ts[i]
})
@@ -63,14 +63,14 @@ func TestWriteGGUF(t *testing.T) {
}
if diff := cmp.Diff(Tensors{
Offset: 592,
Offset: 608,
items: []*Tensor{
{Name: "blk.0.attn_k.weight", Offset: 0, Shape: []uint64{2, 3}},
{Name: "blk.0.attn_norm.weight", Offset: 32, Shape: []uint64{2, 3}},
{Name: "blk.0.ffn_norm.weight", Offset: 64, Shape: []uint64{2, 3}},
{Name: "blk.1.ffn_down.weight", Offset: 96, Shape: []uint64{2, 3}},
{Name: "blk.1.ffn_up.weight", Offset: 128, Shape: []uint64{2, 3}},
{Name: "blk.2.ffn_norm.weight", Offset: 160, Shape: []uint64{2, 3}},
{Name: "blk.0.attn_norm.weight", Offset: 0, Shape: []uint64{2, 3}},
{Name: "blk.1.attn_norm.weight", Offset: 32, Shape: []uint64{2, 3}},
{Name: "blk.2.attn_norm.weight", Offset: 64, Shape: []uint64{2, 3}},
{Name: "blk.3.attn_norm.weight", Offset: 96, Shape: []uint64{2, 3}},
{Name: "blk.4.attn_norm.weight", Offset: 128, Shape: []uint64{2, 3}},
{Name: "blk.5.attn_norm.weight", Offset: 160, Shape: []uint64{2, 3}},
{Name: "output.weight", Offset: 192, Shape: []uint64{3, 2}},
{Name: "output_norm.weight", Offset: 224, Shape: []uint64{3, 2}},
{Name: "token_embd.weight", Offset: 256, Shape: []uint64{2, 3}},

View File

@@ -146,6 +146,8 @@ func (ftype FileType) ToTensorType() TensorType {
return TensorTypeQ4_0
case fileTypeQ4_1:
return TensorTypeQ4_1
case fileTypeMXFP4:
return TensorTypeMXFP4 // Formerly unused tensorTypeQ4_2
case FileTypeQ8_0:
return TensorTypeQ8_0
case fileTypeQ5_0:
@@ -174,8 +176,6 @@ func (ftype FileType) ToTensorType() TensorType {
return TensorTypeQ2_K
case FileTypeBF16:
return TensorTypeBF16
case fileTypeMXFP4:
return TensorTypeMXFP4
default:
slog.Warn("unsupported file type", "type", ftype)
return 0 // F32
@@ -191,8 +191,8 @@ const (
TensorTypeF16
TensorTypeQ4_0
TensorTypeQ4_1
tensorTypeQ4_2
tensorTypeQ4_3 // unused by GGML
TensorTypeMXFP4 // Formerly unused tensorTypeQ4_2
tensorTypeQ4_3 // unused by GGML
TensorTypeQ5_0
TensorTypeQ5_1
TensorTypeQ8_0
@@ -226,7 +226,6 @@ const (
tensorTypeIQ4_NL_4_4 // unused by GGML
tensorTypeIQ4_NL_4_8 // unused by GGML
tensorTypeIQ4_NL_8_8 // unused by GGML
TensorTypeMXFP4
)
// ParseFileType parses the provided GGUF file type
@@ -319,7 +318,7 @@ func (t TensorType) String() string {
return "F64"
case TensorTypeBF16:
return "BF16"
case 4, TensorTypeMXFP4:
case TensorTypeMXFP4:
return "MXFP4"
default:
return "unknown"

View File

@@ -1,13 +1,15 @@
package harmony
import (
"context"
"encoding/json"
"fmt"
"log/slog"
"maps"
"slices"
"strings"
"unicode"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/logutil"
"github.com/ollama/ollama/template"
)
@@ -89,28 +91,19 @@ func (s *HarmonyParser) AddImplicitStart() {
s.acc.WriteString("<|start|>assistant")
}
func Prefill(lastMessage api.Message) string {
if lastMessage.Role != "assistant" {
return ""
// AddImplicitStartOrPrefill adds content or thinking to the accumulator else adds start tag
func (s *HarmonyParser) AddImplicitStartOrPrefill(prefillContentOrThinking *bool) {
if prefillContentOrThinking != nil {
if *prefillContentOrThinking {
s.acc.WriteString("<|start|>assistant<|channel|>final<|message|>")
return
} else {
s.acc.WriteString("<|start|>assistant<|channel|>analysis<|message|>")
return
}
}
switch {
case strings.TrimSpace(lastMessage.Content) != "":
return "<|start|>assistant<|channel|>final<|message|>"
case strings.TrimSpace(lastMessage.Thinking) != "":
return "<|start|>assistant<|channel|>analysis<|message|>"
default:
return ""
}
}
// AddImplicitStartOrPrefill adds an implicit start tag or prefill string if provided
func (s *HarmonyParser) AddImplicitStartOrPrefill(prefillString string) {
if strings.TrimSpace(prefillString) != "" {
s.acc.WriteString(prefillString)
} else {
s.AddImplicitStart()
}
s.AddImplicitStart()
}
func (s *HarmonyParser) AddContent(content string) []HarmonyEvent {
@@ -289,7 +282,6 @@ type HarmonyMessageHandler struct {
state harmonyMessageState
HarmonyParser *HarmonyParser
FunctionNameMap *FunctionNameMap
ToolParser *HarmonyToolCallAccumulator
}
// NewHarmonyMessageHandler creates a new message handler
@@ -302,16 +294,12 @@ func NewHarmonyMessageHandler() *HarmonyMessageHandler {
HeaderEndTag: "<|message|>",
},
FunctionNameMap: NewFunctionNameMap(),
ToolParser: &HarmonyToolCallAccumulator{
state: harmonyToolCallState_Normal,
currentToolName: nil,
},
}
}
// AddContent processes the content and returns the content, thinking, and tool content.
// content and thinking are already fully parsed, but tool content still needs to be passed to the tool parser
func (h *HarmonyMessageHandler) AddContent(content string) (string, string, string) {
func (h *HarmonyMessageHandler) AddContent(content string, toolParser *HarmonyToolCallAccumulator) (string, string, string) {
contentSb := strings.Builder{}
thinkingSb := strings.Builder{}
toolContentSb := strings.Builder{}
@@ -320,7 +308,7 @@ func (h *HarmonyMessageHandler) AddContent(content string) (string, string, stri
for _, event := range events {
switch event := event.(type) {
case HarmonyEventHeaderComplete:
logutil.Trace("harmony event header complete", "header", event.Header)
slog.Log(context.TODO(), logutil.LevelTrace, "harmony event header complete", "header", event.Header)
switch event.Header.Channel {
case "analysis":
if event.Header.Recipient != "" {
@@ -328,14 +316,14 @@ func (h *HarmonyMessageHandler) AddContent(content string) (string, string, stri
// event.Header.Recipient is the tool name, something like
// "browser.search" for a built-in, or "functions.calc" for a
// custom one
h.ToolParser.SetToolName(event.Header.Recipient)
toolParser.SetToolName(event.Header.Recipient)
} else {
h.state = harmonyMessageState_Thinking
}
case "commentary":
if event.Header.Recipient != "" {
h.state = harmonyMessageState_ToolCalling
h.ToolParser.SetToolName(event.Header.Recipient)
toolParser.SetToolName(event.Header.Recipient)
} else {
h.state = harmonyMessageState_Normal
}
@@ -343,7 +331,7 @@ func (h *HarmonyMessageHandler) AddContent(content string) (string, string, stri
h.state = harmonyMessageState_Normal
}
case HarmonyEventContentEmitted:
logutil.Trace("harmony event content", "content", event.Content, "state", h.state)
slog.Log(context.TODO(), logutil.LevelTrace, "harmony event content", "content", event.Content, "state", h.state)
if h.state == harmonyMessageState_Normal {
contentSb.WriteString(event.Content)
} else if h.state == harmonyMessageState_Thinking {
@@ -358,6 +346,13 @@ func (h *HarmonyMessageHandler) AddContent(content string) (string, string, stri
return contentSb.String(), thinkingSb.String(), toolContentSb.String()
}
func (h *HarmonyMessageHandler) CreateToolParser() *HarmonyToolCallAccumulator {
return &HarmonyToolCallAccumulator{
state: harmonyToolCallState_Normal,
currentToolName: nil,
}
}
type harmonyToolCallState int
const (
@@ -399,6 +394,38 @@ type FunctionNameMap struct {
harmonyToUser map[string]string
}
func (m FunctionNameMap) MarshalJSON() ([]byte, error) {
// necessary to avoid exposing map internals
type alias struct {
UserToHarmony map[string]string `json:"userToHarmony"`
HarmonyToUser map[string]string `json:"harmonyToUser"`
}
return json.Marshal(alias{
UserToHarmony: m.userToHarmony,
HarmonyToUser: m.harmonyToUser,
})
}
func (m *FunctionNameMap) UnmarshalJSON(b []byte) error {
type alias struct {
UserToHarmony map[string]string `json:"userToHarmony"`
HarmonyToUser map[string]string `json:"harmonyToUser"`
}
var a alias
if err := json.Unmarshal(b, &a); err != nil {
return err
}
if m.userToHarmony == nil {
m.userToHarmony = make(map[string]string)
}
if m.harmonyToUser == nil {
m.harmonyToUser = make(map[string]string)
}
maps.Copy(m.userToHarmony, a.UserToHarmony)
maps.Copy(m.harmonyToUser, a.HarmonyToUser)
return nil
}
func NewFunctionNameMap() *FunctionNameMap {
return &FunctionNameMap{
userToHarmony: make(map[string]string),

View File

@@ -3,7 +3,6 @@ package harmony
import (
"fmt"
"reflect"
"strings"
"testing"
)
@@ -536,202 +535,3 @@ func TestFunctionConvertAndAdd(t *testing.T) {
})
}
}
func TestHarmonyMessageHandlerStreamingScenarios(t *testing.T) {
t.Run("thinking_then_content_streams", func(t *testing.T) {
handler := NewHarmonyMessageHandler()
handler.HarmonyParser.AddImplicitStart()
tp := handler.ToolParser
type step struct {
in string
wantContent string
wantThinking string
}
steps := []step{
{in: "<|channel|>analysis<|message|>Thinking...", wantThinking: "Thinking..."},
{in: "<|end|>", wantThinking: ""},
{in: "<|start|>assistant<|message|>Answer", wantContent: "Answer"},
{in: "<|end|>", wantContent: ""},
}
for i, s := range steps {
content, thinking, tool := handler.AddContent(s.in)
if tool != "" {
tp.Add(tool)
}
if content != s.wantContent || thinking != s.wantThinking {
t.Fatalf("step %d: got (content=%q thinking=%q), want (content=%q thinking=%q)", i, content, thinking, s.wantContent, s.wantThinking)
}
}
})
t.Run("content_streams_as_it_arrives", func(t *testing.T) {
handler := NewHarmonyMessageHandler()
handler.HarmonyParser.AddImplicitStart()
tp := handler.ToolParser
inputs := []string{
"<|start|>assistant<|message|>Hello",
", world",
"!<|end|>",
}
var got []string
for _, in := range inputs {
content, thinking, tool := handler.AddContent(in)
if tool != "" {
tp.Add(tool)
}
if thinking != "" {
t.Fatalf("unexpected thinking %q", thinking)
}
if content != "" {
got = append(got, content)
}
}
want := []string{"Hello", ", world", "!"}
if !reflect.DeepEqual(got, want) {
t.Fatalf("content pieces mismatch: got %v want %v", got, want)
}
})
t.Run("thinking_streams_separately_from_content", func(t *testing.T) {
handler := NewHarmonyMessageHandler()
handler.HarmonyParser.AddImplicitStart()
tp := handler.ToolParser
inputs := []string{
"<|channel|>analysis<|message|>Thinking...",
"<|end|>",
"<|start|>assistant<|message|>Answer",
"<|end|>",
}
var got []string
for _, in := range inputs {
content, thinking, tool := handler.AddContent(in)
if tool != "" {
tp.Add(tool)
}
if thinking != "" {
got = append(got, thinking)
}
if content != "" {
got = append(got, content)
}
}
want := []string{"Thinking...", "Answer"}
if !reflect.DeepEqual(got, want) {
t.Fatalf("content pieces mismatch: got %v want %v", got, want)
}
})
t.Run("partial_tags_buffer_until_complete", func(t *testing.T) {
handler := NewHarmonyMessageHandler()
handler.HarmonyParser.AddImplicitStart()
tp := handler.ToolParser
inputs := []string{
"<|chan",
"nel|>analysis<|mess",
"age|>Deep ",
"thought",
"<|end|>",
"<|start|>assistant<|message|>Done",
"<|end|>",
}
var thinkingPieces []string
var contentPieces []string
for _, in := range inputs {
content, thinking, tool := handler.AddContent(in)
if tool != "" {
tp.Add(tool)
}
if thinking != "" {
thinkingPieces = append(thinkingPieces, thinking)
}
if content != "" {
contentPieces = append(contentPieces, content)
}
}
if want := []string{"Deep ", "thought"}; !reflect.DeepEqual(thinkingPieces, want) {
t.Fatalf("thinking pieces mismatch: got %v want %v", thinkingPieces, want)
}
if want := []string{"Done"}; !reflect.DeepEqual(contentPieces, want) {
t.Fatalf("content pieces mismatch: got %v want %v", contentPieces, want)
}
})
t.Run("simple_assistant_after_analysis", func(t *testing.T) {
handler := NewHarmonyMessageHandler()
handler.HarmonyParser.AddImplicitStart()
tp := handler.ToolParser
inputs := []string{
"<|channel|>analysis<|message|>Think",
"<|end|>",
"<|start|>assistant<|message|>Answer",
"<|end|>",
}
var contentSb, thinkingSb strings.Builder
for _, in := range inputs {
content, thinking, tool := handler.AddContent(in)
if tool != "" {
tp.Add(tool)
}
contentSb.WriteString(content)
thinkingSb.WriteString(thinking)
}
if contentSb.String() != "Answer" {
t.Fatalf("content mismatch: got %q want %q", contentSb.String(), "Answer")
}
if thinkingSb.String() != "Think" {
t.Fatalf("thinking mismatch: got %q want %q", thinkingSb.String(), "Think")
}
})
t.Run("tool_call_parsed_and_returned_correctly", func(t *testing.T) {
handler := NewHarmonyMessageHandler()
handler.HarmonyParser.AddImplicitStart()
tp := handler.ToolParser
inputs := []string{
"<|channel|>commentary to=functions.calculate<|message|>{\"expression\":\"2+2\"}<|end|>",
}
for _, in := range inputs {
content, thinking, tool := handler.AddContent(in)
if content != "" || thinking != "" {
continue
}
if tool != "" {
tp.Add(tool)
}
}
name, args := tp.Drain()
if name == nil || *name != "functions.calculate" {
t.Fatalf("unexpected tool name: %v", name)
}
if got, want := args, "{\"expression\":\"2+2\"}"; got != want {
t.Fatalf("unexpected tool args: got %s want %s", got, want)
}
})
t.Run("tool_call_across_chunks", func(t *testing.T) {
handler := NewHarmonyMessageHandler()
handler.HarmonyParser.AddImplicitStart()
tp := handler.ToolParser
inputs := []string{
"<|channel|>commentary to=functions.calculate<|message|>{\"expression\":\"2+",
"2\"}",
"<|end|>",
}
for _, in := range inputs {
content, thinking, tool := handler.AddContent(in)
if content != "" || thinking != "" {
continue
}
if tool != "" {
tp.Add(tool)
}
}
name, args := tp.Drain()
if name == nil || *name != "functions.calculate" {
t.Fatalf("unexpected tool name: %v", name)
}
if got, want := args, "{\"expression\":\"2+2\"}"; got != want {
t.Fatalf("unexpected tool args: got %s want %s", got, want)
}
})
}

View File

@@ -390,7 +390,7 @@ func TestAPIEmbeddings(t *testing.T) {
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
req := api.EmbeddingRequest{
Model: libraryEmbedModels[0],
Model: "orca-mini",
Prompt: "why is the sky blue?",
Options: map[string]interface{}{
"temperature": 0,
@@ -410,99 +410,3 @@ func TestAPIEmbeddings(t *testing.T) {
t.Errorf("zero length embedding response")
}
}
func TestAPIToolCalling(t *testing.T) {
initialTimeout := 60 * time.Second
streamTimeout := 30 * time.Second
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
defer cancel()
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
modelName := "qwen3:0.6b"
if err := PullIfMissing(ctx, client, modelName); err != nil {
t.Fatalf("pull failed %s", err)
}
tools := []api.Tool{
{
Type: "function",
Function: api.ToolFunction{
Name: "get_weather",
Description: "Get the current weather in a given location",
Parameters: api.ToolFunctionParameters{
Type: "object",
Required: []string{"location"},
Properties: map[string]api.ToolProperty{
"location": {
Type: api.PropertyType{"string"},
Description: "The city and state, e.g. San Francisco, CA",
},
},
},
},
},
}
req := api.ChatRequest{
Model: modelName,
Messages: []api.Message{
{
Role: "user",
Content: "Call get_weather with location set to San Francisco.",
},
},
Tools: tools,
Options: map[string]any{
"temperature": 0,
},
}
stallTimer := time.NewTimer(initialTimeout)
var gotToolCall bool
var lastToolCall api.ToolCall
fn := func(response api.ChatResponse) error {
if len(response.Message.ToolCalls) > 0 {
gotToolCall = true
lastToolCall = response.Message.ToolCalls[len(response.Message.ToolCalls)-1]
}
if !stallTimer.Reset(streamTimeout) {
return fmt.Errorf("stall was detected while streaming response, aborting")
}
return nil
}
stream := true
req.Stream = &stream
done := make(chan int)
var genErr error
go func() {
genErr = client.Chat(ctx, &req, fn)
done <- 0
}()
select {
case <-stallTimer.C:
t.Errorf("tool-calling chat never started. Timed out after: %s", initialTimeout.String())
case <-done:
if genErr != nil {
t.Fatalf("chat failed: %v", genErr)
}
if !gotToolCall {
t.Fatalf("expected at least one tool call, got none")
}
if lastToolCall.Function.Name != "get_weather" {
t.Errorf("unexpected tool called: got %q want %q", lastToolCall.Function.Name, "get_weather")
}
if _, ok := lastToolCall.Function.Arguments["location"]; !ok {
t.Errorf("expected tool arguments to include 'location', got: %s", lastToolCall.Function.Arguments.String())
}
case <-ctx.Done():
t.Error("outer test context done while waiting for tool-calling chat")
}
}

View File

@@ -11,6 +11,7 @@ import (
"time"
"github.com/ollama/ollama/api"
"github.com/stretchr/testify/require"
)
func TestBlueSky(t *testing.T) {
@@ -36,8 +37,8 @@ func TestUnicode(t *testing.T) {
// Set up the test data
req := api.GenerateRequest{
// DeepSeek has a Unicode tokenizer regex, making it a unicode torture test
Model: "deepseek-coder-v2:16b-lite-instruct-q2_K", // TODO is there an ollama-engine model we can switch to and keep the coverage?
Prompt: "天空为什么是蓝色的?", // Why is the sky blue?
Model: "deepseek-coder-v2:16b-lite-instruct-q2_K",
Prompt: "天空为什么是蓝色的?",
Stream: &stream,
Options: map[string]any{
"temperature": 0,
@@ -49,20 +50,8 @@ func TestUnicode(t *testing.T) {
}
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
if err := PullIfMissing(ctx, client, req.Model); err != nil {
t.Fatal(err)
}
slog.Info("loading", "model", req.Model)
err := client.Generate(ctx, &api.GenerateRequest{Model: req.Model}, func(response api.GenerateResponse) error { return nil })
if err != nil {
t.Fatalf("failed to load model %s: %s", req.Model, err)
}
skipIfNotGPULoaded(ctx, t, client, req.Model, 100)
DoGenerate(ctx, t, client, req, []string{
"散射", // scattering
"频率", // frequency
}, 120*time.Second, 120*time.Second)
require.NoError(t, PullIfMissing(ctx, client, req.Model))
DoGenerate(ctx, t, client, req, []string{"散射", "频率"}, 120*time.Second, 120*time.Second)
}
func TestExtendedUnicodeOutput(t *testing.T) {
@@ -80,9 +69,7 @@ func TestExtendedUnicodeOutput(t *testing.T) {
}
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
if err := PullIfMissing(ctx, client, req.Model); err != nil {
t.Fatal(err)
}
require.NoError(t, PullIfMissing(ctx, client, req.Model))
DoGenerate(ctx, t, client, req, []string{"😀", "😊", "😁", "😂", "😄", "😃"}, 120*time.Second, 120*time.Second)
}
@@ -97,9 +84,7 @@ func TestUnicodeModelDir(t *testing.T) {
}
modelDir, err := os.MkdirTemp("", "ollama_埃")
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
defer os.RemoveAll(modelDir)
slog.Info("unicode", "OLLAMA_MODELS", modelDir)

View File

@@ -14,6 +14,8 @@ import (
"testing"
"time"
"github.com/stretchr/testify/require"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/format"
@@ -77,21 +79,21 @@ func TestMultiModelStress(t *testing.T) {
t.Fatal(err)
}
// All models compatible with ollama-engine
smallModels := []string{
"llama3.2:1b",
"qwen3:0.6b",
"gemma2:2b",
"deepseek-r1:1.5b", // qwen2 arch
"gemma3:270m",
"gemma:2b",
"deepseek-r1:1.5b",
"starcoder2:3b",
}
mediumModels := []string{
"llama3.2:3b", // ~3.4G
"qwen3:8b", // ~6.6G
"gpt-oss:20b", // ~15G
"deepseek-r1:7b", // ~5.6G
"gemma3:4b", // ~5.8G
"gemma2:9b", // ~8.1G
"qwen3:8b",
"llama2",
"deepseek-r1:7b",
"mistral",
"dolphin-mistral",
"gemma:7b",
"codellama:7b",
}
var chosenModels []string
@@ -112,16 +114,13 @@ func TestMultiModelStress(t *testing.T) {
// Make sure all the models are pulled before we get started
for _, model := range chosenModels {
if err := PullIfMissing(ctx, client, model); err != nil {
t.Fatal(err)
}
require.NoError(t, PullIfMissing(ctx, client, model))
}
// Determine how many models we can load in parallel before we exceed VRAM
// The intent is to go 1 over what can fit so we force the scheduler to thrash
targetLoadCount := 0
slog.Info("Loading models to find how many can fit in VRAM before overflowing")
chooseModels:
for i, model := range chosenModels {
req := &api.GenerateRequest{Model: model}
slog.Info("loading", "model", model)
@@ -143,13 +142,6 @@ chooseModels:
slog.Info("found model load capacity", "target", targetLoadCount, "current", loaded, "chosen", chosenModels[:targetLoadCount])
break
}
// Effectively limit model count to 2 on CPU only systems to avoid thrashing and timeouts
for _, m := range models.Models {
if m.SizeVRAM == 0 {
slog.Info("model running on CPU", "name", m.Name, "target", targetLoadCount, "chosen", chosenModels[:targetLoadCount])
break chooseModels
}
}
}
}
if targetLoadCount == len(chosenModels) {

View File

@@ -22,7 +22,7 @@ func TestLongInputContext(t *testing.T) {
defer cancel()
// Set up the test data
req := api.GenerateRequest{
Model: smol,
Model: "llama2",
Prompt: "Oh, dont speak to me of Austria. Perhaps I dont understand things, but Austria never has wished, and does not wish, for war. She is betraying us! Russia alone must save Europe. Our gracious sovereign recognizes his high vocation and will be true to it. That is the one thing I have faith in! Our good and wonderful sovereign has to perform the noblest role on earth, and he is so virtuous and noble that God will not forsake him. He will fulfill his vocation and crush the hydra of revolution, which has become more terrible than ever in the person of this murderer and villain! We alone must avenge the blood of the just one.... Whom, I ask you, can we rely on?... England with her commercial spirit will not and cannot understand the Emperor Alexanders loftiness of soul. She has refused to evacuate Malta. She wanted to find, and still seeks, some secret motive in our actions. What answer did Novosíltsev get? None. The English have not understood and cannot understand the self-abnegation of our Emperor who wants nothing for himself, but only desires the good of mankind. And what have they promised? Nothing! And what little they have promised they will not perform! Prussia has always declared that Buonaparte is invincible, and that all Europe is powerless before him.... And I dont believe a word that Hardenburg says, or Haugwitz either. This famous Prussian neutrality is just a trap. I have faith only in God and the lofty destiny of our adored monarch. He will save Europe! What country is this referring to?",
Stream: &stream,
Options: map[string]any{
@@ -36,7 +36,7 @@ func TestLongInputContext(t *testing.T) {
if err := PullIfMissing(ctx, client, req.Model); err != nil {
t.Fatalf("PullIfMissing failed: %v", err)
}
DoGenerate(ctx, t, client, req, []string{"russia", "germany", "france", "england", "austria", "prussia", "europe", "individuals", "coalition", "conflict"}, 120*time.Second, 10*time.Second)
DoGenerate(ctx, t, client, req, []string{"russia", "germany", "france", "england", "austria", "prussia"}, 120*time.Second, 10*time.Second)
}
func TestContextExhaustion(t *testing.T) {
@@ -49,7 +49,7 @@ func TestContextExhaustion(t *testing.T) {
defer cancel()
// Set up the test data
req := api.GenerateRequest{
Model: smol,
Model: "llama2",
Prompt: "Write me a story with a ton of emojis?",
Stream: &stream,
Options: map[string]any{
@@ -63,7 +63,7 @@ func TestContextExhaustion(t *testing.T) {
if err := PullIfMissing(ctx, client, req.Model); err != nil {
t.Fatalf("PullIfMissing failed: %v", err)
}
DoGenerate(ctx, t, client, req, []string{"once", "upon", "lived", "sunny", "cloudy", "clear", "water"}, 120*time.Second, 10*time.Second)
DoGenerate(ctx, t, client, req, []string{"once", "upon", "lived"}, 120*time.Second, 10*time.Second)
}
// Send multiple generate requests with prior context and ensure the response is coherant and expected

View File

@@ -38,9 +38,8 @@ func TestAllMiniLMEmbeddings(t *testing.T) {
defer cleanup()
req := api.EmbeddingRequest{
Model: "all-minilm",
Prompt: "why is the sky blue?",
KeepAlive: &api.Duration{Duration: 10 * time.Second},
Model: "all-minilm",
Prompt: "why is the sky blue?",
}
res, err := embeddingTestHelper(ctx, client, t, req)

View File

@@ -9,6 +9,7 @@ import (
"time"
"github.com/ollama/ollama/api"
"github.com/stretchr/testify/require"
)
func TestVisionModels(t *testing.T) {
@@ -31,9 +32,7 @@ func TestVisionModels(t *testing.T) {
for _, v := range testCases {
t.Run(v.model, func(t *testing.T) {
image, err := base64.StdEncoding.DecodeString(imageEncoding)
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
req := api.GenerateRequest{
Model: v.model,
Prompt: "what does the text in this image say?",
@@ -53,9 +52,7 @@ func TestVisionModels(t *testing.T) {
// Note: sometimes it returns "the ollamas" sometimes "the ollams"
resp := "the ollam"
defer cleanup()
if err := PullIfMissing(ctx, client, req.Model); err != nil {
t.Fatal(err)
}
require.NoError(t, PullIfMissing(ctx, client, req.Model))
// llava models on CPU can be quite slow to start
DoGenerate(ctx, t, client, req, []string{resp}, 240*time.Second, 30*time.Second)
})
@@ -65,9 +62,7 @@ func TestVisionModels(t *testing.T) {
func TestIntegrationSplitBatch(t *testing.T) {
skipUnderMinVRAM(t, 6)
image, err := base64.StdEncoding.DecodeString(imageEncoding)
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
req := api.GenerateRequest{
Model: "gemma3:4b",
// Fill up a chunk of the batch so the image will partially spill over into the next one
@@ -89,9 +84,7 @@ func TestIntegrationSplitBatch(t *testing.T) {
defer cancel()
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
if err := PullIfMissing(ctx, client, req.Model); err != nil {
t.Fatal(err)
}
require.NoError(t, PullIfMissing(ctx, client, req.Model))
// llava models on CPU can be quite slow to start,
DoGenerate(ctx, t, client, req, []string{resp}, 120*time.Second, 30*time.Second)
}

47
integration/llm_test.go Normal file
View File

@@ -0,0 +1,47 @@
//go:build integration
package integration
import (
"context"
"testing"
"time"
"github.com/ollama/ollama/api"
)
// TODO - this would ideally be in the llm package, but that would require some refactoring of interfaces in the server
// package to avoid circular dependencies
var (
stream = false
req = [2]api.GenerateRequest{
{
Model: smol,
Prompt: "why is the ocean blue?",
Stream: &stream,
Options: map[string]any{
"seed": 42,
"temperature": 0.0,
},
}, {
Model: smol,
Prompt: "what is the origin of the us thanksgiving holiday?",
Stream: &stream,
Options: map[string]any{
"seed": 42,
"temperature": 0.0,
},
},
}
resp = [2][]string{
{"sunlight", "scattering", "interact"},
{"england", "english", "massachusetts", "pilgrims"},
}
)
func TestIntegrationSimple(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*120)
defer cancel()
GenerateTestHelper(ctx, t, req[0], resp[0])
}

View File

@@ -13,6 +13,8 @@ import (
"testing"
"time"
"github.com/stretchr/testify/require"
"github.com/ollama/ollama/api"
)
@@ -45,9 +47,7 @@ func TestMaxQueue(t *testing.T) {
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
if err := PullIfMissing(ctx, client, req.Model); err != nil {
t.Fatal(err)
}
require.NoError(t, PullIfMissing(ctx, client, req.Model))
// Context for the worker threads so we can shut them down
// embedCtx, embedCancel := context.WithCancel(ctx)
@@ -91,9 +91,7 @@ func TestMaxQueue(t *testing.T) {
switch {
case genErr == nil:
successCount++
if len(resp.Embedding) < 5 { // somewhat arbitrary, but sufficient to be reasonable
t.Fatalf("embeddings shorter than expected: %d", len(resp.Embedding))
}
require.Greater(t, len(resp.Embedding), 5) // somewhat arbitrary, but sufficient to be reasonable
case errors.Is(genErr, context.Canceled):
canceledCount++
case strings.Contains(genErr.Error(), "busy"):
@@ -101,9 +99,7 @@ func TestMaxQueue(t *testing.T) {
case strings.Contains(genErr.Error(), "connection reset by peer"):
resetByPeerCount++
default:
if genErr != nil {
t.Fatalf("%d request failed", i)
}
require.NoError(t, genErr, "%d request failed", i)
}
slog.Info("embed finished", "id", i)
@@ -114,13 +110,8 @@ func TestMaxQueue(t *testing.T) {
embedwg.Wait()
slog.Info("embeds completed", "success", successCount, "busy", busyCount, "reset", resetByPeerCount, "canceled", canceledCount)
if resetByPeerCount != 0 {
t.Fatalf("Connections reset by peer, have you updated your fd and socket limits? %d", resetByPeerCount)
}
if busyCount == 0 {
t.Fatalf("no requests hit busy error but some should have")
}
if canceledCount > 0 {
t.Fatalf("no requests should have been canceled due to timeout %d", canceledCount)
}
require.Equal(t, resetByPeerCount, 0, "Connections reset by peer, have you updated your fd and socket limits?")
require.True(t, busyCount > 0, "no requests hit busy error but some should have")
require.True(t, canceledCount == 0, "no requests should have been canceled due to timeout")
}

View File

@@ -9,7 +9,6 @@ import (
"fmt"
"io"
"log/slog"
"math"
"math/rand"
"net"
"net/http"
@@ -26,11 +25,11 @@ import (
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/app/lifecycle"
"github.com/ollama/ollama/format"
"github.com/stretchr/testify/require"
)
var (
smol = "llama3.2:1b"
stream = false
smol = "llama3.2:1b"
)
var (
@@ -436,9 +435,7 @@ func InitServerConnection(ctx context.Context, t *testing.T) (*api.Client, strin
}
lifecycle.ServerLogFile = fp.Name()
fp.Close()
if err := startServer(t, ctx, testEndpoint); err != nil {
t.Fatal(err)
}
require.NoError(t, startServer(t, ctx, testEndpoint))
}
return client, testEndpoint, func() {
@@ -471,9 +468,7 @@ func InitServerConnection(ctx context.Context, t *testing.T) (*api.Client, strin
func GenerateTestHelper(ctx context.Context, t *testing.T, genReq api.GenerateRequest, anyResp []string) {
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
if err := PullIfMissing(ctx, client, genReq.Model); err != nil {
t.Fatal(err)
}
require.NoError(t, PullIfMissing(ctx, client, genReq.Model))
DoGenerate(ctx, t, client, genReq, anyResp, 30*time.Second, 10*time.Second)
}
@@ -502,22 +497,6 @@ func DoGenerate(ctx context.Context, t *testing.T, client *api.Client, genReq ap
done <- 0
}()
var response string
verify := func() {
// Verify the response contains the expected data
response = buf.String()
atLeastOne := false
for _, resp := range anyResp {
if strings.Contains(strings.ToLower(response), resp) {
atLeastOne = true
break
}
}
if !atLeastOne {
t.Fatalf("%s: none of %v found in %s", genReq.Model, anyResp, response)
}
}
select {
case <-stallTimer.C:
if buf.Len() == 0 {
@@ -530,17 +509,20 @@ func DoGenerate(ctx context.Context, t *testing.T, client *api.Client, genReq ap
slog.Warn("model is too large for the target test system", "model", genReq.Model, "error", genErr)
return context
}
if genErr != nil {
t.Fatalf("%s failed with %s request prompt %s", genErr, genReq.Model, genReq.Prompt)
require.NoError(t, genErr, "failed with %s request prompt %s ", genReq.Model, genReq.Prompt)
// Verify the response contains the expected data
response := buf.String()
atLeastOne := false
for _, resp := range anyResp {
if strings.Contains(strings.ToLower(response), resp) {
atLeastOne = true
break
}
}
verify()
require.True(t, atLeastOne, "%s: none of %v found in %s", genReq.Model, anyResp, response)
slog.Info("test pass", "model", genReq.Model, "prompt", genReq.Prompt, "contains", anyResp, "response", response)
case <-ctx.Done():
// On slow systems, we might timeout before some models finish rambling, so check what we have so far to see
// if it's considered a pass - the stallTimer will detect hangs, but we want to consider slow systems a pass
// if they are still generating valid responses
slog.Warn("outer test context done while waiting for generate")
verify()
t.Error("outer test context done while waiting for generate")
}
return context
}
@@ -579,12 +561,29 @@ func GenerateRequests() ([]api.GenerateRequest, [][]string) {
[][]string{
{"sunlight", "scattering", "interact", "color", "surface", "depth", "red", "orange", "yellow", "absorbs", "wavelength"},
{"soil", "organic", "earth", "black", "tan", "chemical", "processes", "pigments", "particles", "iron oxide", "rust", "air", "water", "mixture", "mixing"},
{"england", "english", "massachusetts", "pilgrims", "colonists", "independence", "british", "feast", "family", "gatherings", "traditions", "turkey", "colonial", "period", "harvest", "agricultural", "european settlers", "american revolution", "civil war", "16th century", "17th century", "native american", "united states", "cultural", "hardship", "autumn", "festival"},
{"england", "english", "massachusetts", "pilgrims", "colonists", "independence", "british", "feast", "family", "gatherings", "traditions", "turkey", "colonial", "period", "harvest", "agricultural", "european settlers", "american revolution", "civil war", "16th century", "17th century", "native american", "united states"},
{"fourth", "july", "declaration", "independence"},
{"nitrogen", "oxygen", "carbon", "dioxide"},
}
}
func ChatRequests() ([]api.ChatRequest, [][]string) {
genReqs, results := GenerateRequests()
reqs := make([]api.ChatRequest, len(genReqs))
for i := range reqs {
reqs[i].Model = genReqs[i].Model
reqs[i].Stream = genReqs[i].Stream
reqs[i].KeepAlive = genReqs[i].KeepAlive
reqs[i].Messages = []api.Message{
{
Role: "user",
Content: genReqs[i].Prompt,
},
}
}
return reqs, results
}
func DoChat(ctx context.Context, t *testing.T, client *api.Client, req api.ChatRequest, anyResp []string, initialTimeout, streamTimeout time.Duration) *api.Message {
stallTimer := time.NewTimer(initialTimeout)
var buf bytes.Buffer
@@ -608,22 +607,6 @@ func DoChat(ctx context.Context, t *testing.T, client *api.Client, req api.ChatR
done <- 0
}()
var response string
verify := func() {
// Verify the response contains the expected data
response = buf.String()
atLeastOne := false
for _, resp := range anyResp {
if strings.Contains(strings.ToLower(response), resp) {
atLeastOne = true
break
}
}
if !atLeastOne {
t.Fatalf("%s: none of %v found in \"%s\" -- request was:%v", req.Model, anyResp, response, req.Messages)
}
}
select {
case <-stallTimer.C:
if buf.Len() == 0 {
@@ -636,47 +619,29 @@ func DoChat(ctx context.Context, t *testing.T, client *api.Client, req api.ChatR
slog.Warn("model is too large for the target test system", "model", req.Model, "error", genErr)
return nil
}
if genErr != nil {
t.Fatalf("%s failed with %s request prompt %v", genErr, req.Model, req.Messages)
require.NoError(t, genErr, "failed with %s request Messages %s ", req.Model, req.Messages)
// Verify the response contains the expected data
response := buf.String()
atLeastOne := false
for _, resp := range anyResp {
if strings.Contains(strings.ToLower(response), resp) {
atLeastOne = true
break
}
}
verify()
require.True(t, atLeastOne, "%s: none of %v found in \"%s\" -- request was:%v", req.Model, anyResp, response, req.Messages)
slog.Info("test pass", "model", req.Model, "messages", req.Messages, "contains", anyResp, "response", response)
case <-ctx.Done():
// On slow systems, we might timeout before some models finish rambling, so check what we have so far to see
// if it's considered a pass - the stallTimer will detect hangs, but we want to consider slow systems a pass
// if they are still generating valid responses
slog.Warn("outer test context done while waiting for chat")
verify()
t.Error("outer test context done while waiting for generate")
}
return &api.Message{Role: role, Content: buf.String()}
}
func ChatRequests() ([]api.ChatRequest, [][]string) {
genReqs, results := GenerateRequests()
reqs := make([]api.ChatRequest, len(genReqs))
// think := api.ThinkValue{Value: "low"}
for i := range reqs {
reqs[i].Model = genReqs[i].Model
reqs[i].Stream = genReqs[i].Stream
reqs[i].KeepAlive = genReqs[i].KeepAlive
// reqs[i].Think = &think
reqs[i].Messages = []api.Message{
{
Role: "user",
Content: genReqs[i].Prompt,
},
}
}
return reqs, results
}
func skipUnderMinVRAM(t *testing.T, gb uint64) {
// TODO use info API in the future
if s := os.Getenv("OLLAMA_MAX_VRAM"); s != "" {
maxVram, err := strconv.ParseUint(s, 10, 64)
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
// Don't hammer on small VRAM cards...
if maxVram < gb*format.GibiByte {
t.Skip("skipping with small VRAM to avoid timeouts")
@@ -684,39 +649,6 @@ func skipUnderMinVRAM(t *testing.T, gb uint64) {
}
}
// Skip if the target model isn't X% GPU loaded to avoid excessive runtime
func skipIfNotGPULoaded(ctx context.Context, t *testing.T, client *api.Client, model string, minPercent int) {
models, err := client.ListRunning(ctx)
if err != nil {
t.Fatalf("failed to list running models: %s", err)
}
loaded := []string{}
for _, m := range models.Models {
loaded = append(loaded, m.Name)
if m.Name != model {
continue
}
gpuPercent := 0
switch {
case m.SizeVRAM == 0:
gpuPercent = 0
case m.SizeVRAM == m.Size:
gpuPercent = 100
case m.SizeVRAM > m.Size || m.Size == 0:
t.Logf("unexpected size detected: %d", m.SizeVRAM)
default:
sizeCPU := m.Size - m.SizeVRAM
cpuPercent := math.Round(float64(sizeCPU) / float64(m.Size) * 110)
gpuPercent = int(100 - cpuPercent)
}
if gpuPercent < minPercent {
t.Skip(fmt.Sprintf("test requires minimum %d%% GPU load, but model %s only has %d%%", minPercent, model, gpuPercent))
}
return
}
t.Skip(fmt.Sprintf("model %s not loaded - actually loaded: %v", model, loaded))
}
func getTimeouts(t *testing.T) (soft time.Duration, hard time.Duration) {
deadline, hasDeadline := t.Deadline()
if !hasDeadline {

View File

@@ -378,7 +378,9 @@ func (c *Causal) buildMask(ctx ml.Context) ml.Tensor {
maskTensor := ctx.Input().FromFloatSlice(mask, length, batchSize)
if c.config.MaskDType != ml.DTypeF32 {
maskTensor = maskTensor.Cast(ctx, c.config.MaskDType)
out := ctx.Input().Empty(c.config.MaskDType, maskTensor.Shape()...)
ctx.Forward(maskTensor.Copy(ctx, out))
maskTensor = out
}
return maskTensor

View File

@@ -1,130 +0,0 @@
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: Jesse Gross <jesse@ollama.com>
Date: Wed, 27 Aug 2025 14:39:48 -0700
Subject: [PATCH] ggml: Enable resetting backend devices
Touching a CUDA device causes the allocation of a primary context
with CUDA data structures (~300 MB of VRAM). If a device is
unused then it can be reset to free these data structures.
---
ggml/include/ggml-backend.h | 1 +
ggml/src/ggml-backend-impl.h | 4 ++++
ggml/src/ggml-backend.cpp | 8 ++++++++
ggml/src/ggml-cuda/ggml-cuda.cu | 17 +++++++++++++++--
ggml/src/ggml-cuda/vendors/hip.h | 1 +
5 files changed, 29 insertions(+), 2 deletions(-)
diff --git a/ggml/include/ggml-backend.h b/ggml/include/ggml-backend.h
index b602a7c78..fda5ceb24 100644
--- a/ggml/include/ggml-backend.h
+++ b/ggml/include/ggml-backend.h
@@ -167,6 +167,7 @@ extern "C" {
GGML_API void ggml_backend_dev_get_props(ggml_backend_dev_t device, struct ggml_backend_dev_props * props);
GGML_API ggml_backend_reg_t ggml_backend_dev_backend_reg(ggml_backend_dev_t device);
GGML_API ggml_backend_t ggml_backend_dev_init(ggml_backend_dev_t device, const char * params);
+ GGML_API void ggml_backend_dev_reset(ggml_backend_dev_t device);
GGML_API ggml_backend_buffer_type_t ggml_backend_dev_buffer_type(ggml_backend_dev_t device);
GGML_API ggml_backend_buffer_type_t ggml_backend_dev_host_buffer_type(ggml_backend_dev_t device);
GGML_API ggml_backend_buffer_t ggml_backend_dev_buffer_from_host_ptr(ggml_backend_dev_t device, void * ptr, size_t size, size_t max_tensor_size);
diff --git a/ggml/src/ggml-backend-impl.h b/ggml/src/ggml-backend-impl.h
index 81749a5a3..6f10c353b 100644
--- a/ggml/src/ggml-backend-impl.h
+++ b/ggml/src/ggml-backend-impl.h
@@ -178,6 +178,10 @@ extern "C" {
ggml_backend_event_t (*event_new) (ggml_backend_dev_t dev);
void (*event_free) (ggml_backend_dev_t dev, ggml_backend_event_t event);
void (*event_synchronize) (ggml_backend_dev_t dev, ggml_backend_event_t event);
+
+ // (optional) reset device, clearing existing allocations and context
+ // the caller must ensure that there are no outstanding buffers, as these will become invalid
+ void (*reset)(ggml_backend_dev_t dev);
};
struct ggml_backend_device {
diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp
index 05a842ed5..6556943b0 100644
--- a/ggml/src/ggml-backend.cpp
+++ b/ggml/src/ggml-backend.cpp
@@ -477,6 +477,14 @@ ggml_backend_t ggml_backend_dev_init(ggml_backend_dev_t device, const char * par
return device->iface.init_backend(device, params);
}
+void ggml_backend_dev_reset(ggml_backend_dev_t device) {
+ if (device->iface.reset == NULL) {
+ return;
+ }
+
+ device->iface.reset(device);
+}
+
ggml_backend_buffer_type_t ggml_backend_dev_buffer_type(ggml_backend_dev_t device) {
return device->iface.get_buffer_type(device);
}
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index c7f9dc3a5..e43fde523 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -103,6 +103,11 @@ int ggml_cuda_get_device() {
return id;
}
+void ggml_cuda_reset_device(int device) {
+ ggml_cuda_set_device(device);
+ CUDA_CHECK(cudaDeviceReset());
+}
+
static cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device) {
ggml_cuda_set_device(device);
cudaError_t err;
@@ -3243,7 +3248,10 @@ static void ggml_backend_cuda_device_get_props(ggml_backend_dev_t dev, ggml_back
props->description = ggml_backend_cuda_device_get_description(dev);
props->id = ggml_backend_cuda_device_get_id(dev);
props->type = ggml_backend_cuda_device_get_type(dev);
- ggml_backend_cuda_device_get_memory(dev, &props->memory_free, &props->memory_total);
+
+ // Memory reporting is disabled to avoid allocation of a CUDA primary context (~300 MB per device).
+ // If you need the memory data, call ggml_backend_dev_memory() explicitly.
+ props->memory_total = props->memory_free = 0;
bool host_buffer = getenv("GGML_CUDA_NO_PINNED") == nullptr;
#ifdef GGML_CUDA_NO_PEER_COPY
@@ -3700,6 +3708,11 @@ static void ggml_backend_cuda_device_event_synchronize(ggml_backend_dev_t dev, g
CUDA_CHECK(cudaEventSynchronize((cudaEvent_t)event->context));
}
+static void ggml_backend_cuda_device_reset(ggml_backend_dev_t dev) {
+ ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context;
+ ggml_cuda_reset_device(ctx->device);
+}
+
static const ggml_backend_device_i ggml_backend_cuda_device_interface = {
/* .get_name = */ ggml_backend_cuda_device_get_name,
/* .get_description = */ ggml_backend_cuda_device_get_description,
@@ -3716,6 +3729,7 @@ static const ggml_backend_device_i ggml_backend_cuda_device_interface = {
/* .event_new = */ ggml_backend_cuda_device_event_new,
/* .event_free = */ ggml_backend_cuda_device_event_free,
/* .event_synchronize = */ ggml_backend_cuda_device_event_synchronize,
+ /* .reset = */ ggml_backend_cuda_device_reset,
};
// backend reg
@@ -3835,7 +3849,6 @@ ggml_backend_reg_t ggml_backend_cuda_reg() {
dev_ctx->device = i;
dev_ctx->name = GGML_CUDA_NAME + std::to_string(i);
- ggml_cuda_set_device(i);
cudaDeviceProp prop;
CUDA_CHECK(cudaGetDeviceProperties(&prop, i));
dev_ctx->description = prop.name;
diff --git a/ggml/src/ggml-cuda/vendors/hip.h b/ggml/src/ggml-cuda/vendors/hip.h
index c31f31923..cf22e60d2 100644
--- a/ggml/src/ggml-cuda/vendors/hip.h
+++ b/ggml/src/ggml-cuda/vendors/hip.h
@@ -40,6 +40,7 @@
#define cudaDeviceDisablePeerAccess hipDeviceDisablePeerAccess
#define cudaDeviceEnablePeerAccess hipDeviceEnablePeerAccess
#define cudaDeviceProp hipDeviceProp_t
+#define cudaDeviceReset hipDeviceReset
#define cudaDeviceSynchronize hipDeviceSynchronize
#define cudaError_t hipError_t
#define cudaErrorPeerAccessAlreadyEnabled hipErrorPeerAccessAlreadyEnabled

View File

@@ -1,28 +0,0 @@
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: Daniel Hiltgen <daniel@ollama.com>
Date: Fri, 29 Aug 2025 16:53:08 -0700
Subject: [PATCH] harden uncaught exception registration
---
ggml/src/ggml.cpp | 8 ++++++--
1 file changed, 6 insertions(+), 2 deletions(-)
diff --git a/ggml/src/ggml.cpp b/ggml/src/ggml.cpp
index 0d388d45..f5bcb446 100644
--- a/ggml/src/ggml.cpp
+++ b/ggml/src/ggml.cpp
@@ -19,8 +19,12 @@ static bool ggml_uncaught_exception_init = []{
return false;
}
const auto prev{std::get_terminate()};
- GGML_ASSERT(prev != ggml_uncaught_exception);
- previous_terminate_handler = prev;
+ // GGML_ASSERT(prev != ggml_uncaught_exception);
+ if (prev != ggml_uncaught_exception) {
+ previous_terminate_handler = prev;
+ } else {
+ GGML_LOG_WARN("%s double registration of ggml_uncaught_exception\n", __func__);
+ }
std::set_terminate(ggml_uncaught_exception);
return true;
}();

View File

@@ -30,7 +30,7 @@ func pickBestFullFitByLibrary(f *ggml.GGML, modelPath string, projectors []strin
// Try to pack into as few GPUs as possible, starting from 1 GPU
for numGPUs := 1; numGPUs <= len(sgl); numGPUs++ {
gpuSubset := sgl[:numGPUs]
ok, estimatedVRAM := predictServerFit(gpuSubset, f, adapters, projectors, opts, numParallel)
ok, estimatedVRAM := PredictServerFit(gpuSubset, f, adapters, projectors, opts, numParallel)
if ok {
slog.Info("new model will fit in available VRAM across minimum required GPUs, loading",
@@ -48,7 +48,7 @@ func pickBestFullFitByLibrary(f *ggml.GGML, modelPath string, projectors []strin
// - try subsets of GPUs instead of just falling back to 1 or all in a family
// Now try all the GPUS (OLLAMA_SCHED_SPREAD is set)
if ok, estimatedVRAM := predictServerFit(sgl, f, adapters, projectors, opts, numParallel); ok {
if ok, estimatedVRAM := PredictServerFit(sgl, f, adapters, projectors, opts, numParallel); ok {
slog.Info("new model will fit in available VRAM, loading",
"model", modelPath,
"library", sgl[0].Library,
@@ -71,7 +71,7 @@ func pickBestPartialFitByLibrary(f *ggml.GGML, projectors []string, adapters []s
var bestEstimate uint64
var bestFit int
for i, gl := range byLibrary {
_, estimatedVRAM := predictServerFit(gl, f, adapters, projectors, opts, numParallel)
_, estimatedVRAM := PredictServerFit(gl, f, adapters, projectors, opts, numParallel)
if estimatedVRAM > bestEstimate {
bestEstimate = estimatedVRAM
bestFit = i
@@ -81,7 +81,7 @@ func pickBestPartialFitByLibrary(f *ggml.GGML, projectors []string, adapters []s
}
// This algorithm looks for a complete fit to determine if we need to unload other models
func predictServerFit(allGpus discover.GpuInfoList, f *ggml.GGML, adapters, projectors []string, opts api.Options, numParallel int) (bool, uint64) {
func PredictServerFit(allGpus discover.GpuInfoList, f *ggml.GGML, adapters, projectors []string, opts api.Options, numParallel int) (bool, uint64) {
// Split up the GPUs by type and try them
var estimatedVRAM uint64
for _, gpus := range allGpus.ByLibrary() {
@@ -97,10 +97,6 @@ func predictServerFit(allGpus discover.GpuInfoList, f *ggml.GGML, adapters, proj
return true, estimatedVRAM
}
}
if len(gpus) == 1 && gpus[0].Library == "cpu" && estimate.TotalSize <= gpus[0].FreeMemory {
return true, estimatedVRAM
}
}
return false, estimatedVRAM
}
@@ -195,19 +191,17 @@ func estimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []strin
slog.Warn("model missing blk.0 layer size")
}
useFlashAttention := (envconfig.FlashAttention() || f.FlashAttention()) &&
discover.GetGPUInfo().FlashAttentionSupported() &&
f.SupportsFlashAttention()
var kvct string
if useFlashAttention {
if envconfig.FlashAttention() &&
discover.GetGPUInfo().FlashAttentionSupported() &&
f.SupportsFlashAttention() {
requested := strings.ToLower(envconfig.KvCacheType())
if f.SupportsKVCacheType(requested) {
if requested != "" && f.SupportsKVCacheType(requested) {
kvct = requested
}
}
kv, graphPartialOffload, graphFullOffload := f.GraphSize(uint64(opts.NumCtx), uint64(min(opts.NumCtx, opts.NumBatch)), numParallel, kvct, useFlashAttention)
kv, graphPartialOffload, graphFullOffload := f.GraphSize(uint64(opts.NumCtx), uint64(min(opts.NumCtx, opts.NumBatch)), numParallel, kvct)
if len(kv) > 0 {
layerSize += kv[0]

View File

@@ -31,11 +31,11 @@ import (
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/format"
"github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/harmony"
"github.com/ollama/ollama/llama"
"github.com/ollama/ollama/logutil"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/parser"
)
type filteredEnv []string
@@ -149,11 +149,7 @@ func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, a
var textProcessor model.TextProcessor
var err error
if envconfig.NewEngine() || f.KV().OllamaEngineRequired() {
if len(projectors) == 0 {
textProcessor, err = model.NewTextProcessor(modelPath)
} else {
err = errors.New("split vision models aren't supported")
}
textProcessor, err = model.NewTextProcessor(modelPath)
if err != nil {
// To prepare for opt-out mode, instead of treating this as an error, we fallback to the old runner
slog.Debug("model not yet supported by Ollama engine, switching to compatibility mode", "model", modelPath, "error", err)
@@ -166,6 +162,11 @@ func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, a
}
}
newEstimates := textProcessor != nil && envconfig.NewMemoryEstimates()
if newEstimates {
slog.Info("enabling new memory estimates")
}
// Verify the requested context size is <= the model training size
trainCtx := f.KV().ContextLength()
if opts.NumCtx > int(trainCtx) && trainCtx > 0 {
@@ -173,8 +174,6 @@ func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, a
opts.NumCtx = int(trainCtx)
}
opts.NumBatch = min(opts.NumBatch, opts.NumCtx)
loadRequest := LoadRequest{LoraPath: adapters, KvSize: opts.NumCtx * numParallel, BatchSize: opts.NumBatch, Parallel: numParallel, MultiUserCache: envconfig.MultiUserCache()}
defaultThreads := discover.GetSystemInfo().GetOptimalThreadCount()
@@ -197,11 +196,6 @@ func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, a
// This will disable flash attention unless all GPUs on the system support it, even if we end up selecting a subset
// that can handle it.
fa := envconfig.FlashAttention()
if f.FlashAttention() {
slog.Info("model wants flash attention")
fa = true
}
if fa && !gpus.FlashAttentionSupported() {
slog.Warn("flash attention enabled but not supported by gpu")
fa = false
@@ -220,7 +214,7 @@ func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, a
// Flash Attention also supports kv cache quantization
// Enable if the requested and kv cache type is supported by the model
if f.SupportsKVCacheType(kvct) {
if kvct != "" && f.SupportsKVCacheType(kvct) {
loadRequest.KvCacheType = kvct
} else {
slog.Warn("kv cache type not supported by model", "type", kvct)
@@ -362,28 +356,23 @@ func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, a
s.cmd.Env = append(s.cmd.Env, "OLLAMA_LIBRARY_PATH="+strings.Join(ggmlPaths, string(filepath.ListSeparator)))
envWorkarounds := []string{}
envWorkarounds := [][2]string{}
for _, gpu := range gpus {
envWorkarounds = append(envWorkarounds, gpu.EnvWorkarounds...)
}
// Always filter down the set of GPUs in case there are any unsupported devices that might crash
envWorkarounds = append(envWorkarounds, gpus.GetVisibleDevicesEnv()...)
pathEnvVal := strings.Join(libraryPaths, string(filepath.ListSeparator))
// Update or add the path variable with our adjusted version
pathNeeded := true
envWorkaroundDone := make([]bool, len(envWorkarounds))
for i := range s.cmd.Env {
cmp := strings.SplitN(s.cmd.Env[i], "=", 2)
if strings.EqualFold(cmp[0], pathEnv) {
s.cmd.Env[i] = pathEnv + "=" + pathEnvVal
pathNeeded = false
} else if len(envWorkarounds) != 0 {
for j, kv := range envWorkarounds {
tmp := strings.SplitN(kv, "=", 2)
if strings.EqualFold(cmp[0], tmp[0]) {
s.cmd.Env[i] = kv
envWorkaroundDone[j] = true
for _, kv := range envWorkarounds {
if strings.EqualFold(cmp[0], kv[0]) {
s.cmd.Env[i] = kv[0] + "=" + kv[1]
}
}
}
@@ -391,11 +380,6 @@ func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, a
if pathNeeded {
s.cmd.Env = append(s.cmd.Env, pathEnv+"="+pathEnvVal)
}
for i, done := range envWorkaroundDone {
if !done {
s.cmd.Env = append(s.cmd.Env, envWorkarounds[i])
}
}
slog.Info("starting runner", "cmd", s.cmd)
slog.Debug("subprocess", "", filteredEnv(s.cmd.Env))
@@ -433,7 +417,7 @@ func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, a
}
}()
if textProcessor != nil {
if newEstimates {
return &ollamaServer{llmServer: s}, nil
} else {
return &llamaServer{llmServer: s, ggml: f}, nil
@@ -509,7 +493,6 @@ func (s *llamaServer) Load(ctx context.Context, gpus discover.GpuInfoList, requi
if !requireFull {
g = pickBestPartialFitByLibrary(s.ggml, []string{s.loadRequest.ProjectorPath}, s.loadRequest.LoraPath, s.options, gpus, s.numParallel)
} else {
slog.Info("model requires more memory than is currently available, evicting a model to make space", "estimate", s.estimate)
return ErrLoadRequiredFull
}
}
@@ -542,6 +525,10 @@ func (s *llamaServer) Load(ctx context.Context, gpus discover.GpuInfoList, requi
}
}
if requireFull && len(gpus) == 1 && gpus[0].Library == "cpu" && s.estimate.TotalSize > gpus[0].FreeMemory {
return ErrLoadRequiredFull
}
slog.Info("offload", "", s.estimate)
s.gpus = gpus
@@ -680,12 +667,8 @@ func (s *ollamaServer) Load(ctx context.Context, gpus discover.GpuInfoList, requ
if !(len(gpus) == 1 && gpus[0].Library == "cpu") {
for _, gpu := range gpus {
available := gpu.FreeMemory - envconfig.GpuOverhead() - gpu.MinimumMemory
if gpu.FreeMemory < envconfig.GpuOverhead()+gpu.MinimumMemory {
available = 0
}
slog.Info("gpu memory", "id", gpu.ID,
"available", format.HumanBytes2(available),
"available", format.HumanBytes2(gpu.FreeMemory-envconfig.GpuOverhead()-gpu.MinimumMemory),
"free", format.HumanBytes2(gpu.FreeMemory),
"minimum", format.HumanBytes2(gpu.MinimumMemory),
"overhead", format.HumanBytes2(envconfig.GpuOverhead()))
@@ -867,7 +850,7 @@ func (s *ollamaServer) createLayout(systemInfo discover.SystemInfo, systemGPUs d
}
layers[i] += memory.CPU.Weights[i].Size
layers[i] += memory.CPU.Cache[i].Size
logutil.Trace("layer to assign", "layer", i, "size", format.HumanBytes2(layers[i]))
slog.Log(context.TODO(), logutil.LevelTrace, "layer to assign", "layer", i, "size", format.HumanBytes2(layers[i]))
}
gpuLayers := ml.GPULayersList{}
@@ -1349,9 +1332,9 @@ type CompletionRequest struct {
Images []ImageData
Options *api.Options
Grammar string // set before sending the request to the subprocess
ParserType parser.TokenParserType
PrefillString string
Grammar string // set before sending the request to the subprocess
FunctionNameMap *harmony.FunctionNameMap
PrefillContent *bool
}
// DoneReason represents the reason why a completion response is done
@@ -1504,8 +1487,7 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
return fmt.Errorf("error unmarshalling llm prediction response: %v", err)
}
switch {
// TODO(parthsareen): token repeat limit is now handled in the runner, this currently support legacy model and can be removed in the future
case strings.TrimSpace(c.Content) == lastToken && c.Content != "":
case lastToken != "" && (strings.TrimSpace(c.Content) == lastToken || strings.TrimSpace(c.Thinking) == lastToken):
tokenRepeat++
default:
lastToken = strings.TrimSpace(c.Content)

View File

@@ -1,7 +1,6 @@
package logutil
import (
"context"
"io"
"log/slog"
"path/filepath"
@@ -28,11 +27,3 @@ func NewLogger(w io.Writer, level slog.Level) *slog.Logger {
},
}))
}
func Trace(msg string, args ...any) {
slog.Log(context.TODO(), LevelTrace, msg, args...)
}
func TraceContext(ctx context.Context, msg string, args ...any) {
slog.Log(ctx, LevelTrace, msg, args...)
}

View File

@@ -266,7 +266,7 @@ func (m DeviceMemory) LogValue() slog.Value {
// allocation is guaranteed to be provided so that if it failed, the caller can
// accommodate that to make forward progress.
type BackendMemory struct {
// InputWeights are always located on the CPU and cannot be moved
// InputsWeights are always located on the CPU and cannot be moved
InputWeights Memory
// CPU model components are located in system memory. This does not
@@ -372,7 +372,6 @@ type Context interface {
Forward(...Tensor) Context
Compute(...Tensor)
ComputeWithNotify(func(), ...Tensor) // notify callback once compute has begun
// Reserve is analogous to Compute but rather than executing a
// graph, simply preallocates memory. Typically called with a
@@ -397,12 +396,11 @@ type Tensor interface {
Shape() []int
DType() DType
Cast(ctx Context, dtype DType) Tensor
Bytes() []byte
Floats() []float32
SetValueFromIntSlice(s []int32)
BackendSetFromIntSlice(s []int32)
Neg(ctx Context) Tensor
Add(ctx Context, t2 Tensor) Tensor

View File

@@ -271,7 +271,7 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) {
tt := C.ggml_new_tensor(ctxs[bt], kind, C.int(len(t.source.Shape)), (*C.int64_t)(unsafe.Pointer(&t.source.Shape[0])))
C.ggml_set_name(tt, cname)
logutil.Trace("created tensor", "name", name, "shape", t.source.Shape, "dtype", t.source.Kind, "buffer_type", C.GoString(C.ggml_backend_buft_name(bt)))
slog.Log(context.TODO(), logutil.LevelTrace, "created tensor", "name", name, "shape", t.source.Shape, "dtype", t.source.Kind, "buffer_type", C.GoString(C.ggml_backend_buft_name(bt)))
size := pad(C.ggml_backend_buft_get_alloc_size(bt, tt), C.ggml_backend_buft_get_alignment(bt))
if layer == -1 {
@@ -378,7 +378,7 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) {
}
for bs := range maps.Values(bbs) {
logutil.Trace("model weights", "buffer", C.GoString(C.ggml_backend_buffer_name(bs)),
slog.Log(context.TODO(), logutil.LevelTrace, "model weights", "buffer", C.GoString(C.ggml_backend_buffer_name(bs)),
"size", format.HumanBytes2(uint64(C.ggml_backend_buffer_get_size(bs))))
}
@@ -536,7 +536,6 @@ func (b *Backend) Load(ctx context.Context, progress func(float32)) error {
const BS = 17 // MXFP4 block size
bts := make([]byte, 8*BS*format.KibiByte) // ~128k block aligned
var s uint64
var tmp [16]byte
for s < t.Size() {
// Stop if either the parent context has been canceled or if any of the other tensors returned an error
if err := ctx.Err(); err != nil {
@@ -548,13 +547,37 @@ func (b *Backend) Load(ctx context.Context, progress func(float32)) error {
return err
}
for j := range n / BS {
for i := 1; i < 9; i++ {
// transform a1b2c3 ... x7y8z9 -> 71xa82yb93zc
a, b := bts[j*BS+i], bts[j*BS+i+8]
tmp[2*(i-1)] = (a & 0x0F) | (b << 4)
tmp[2*(i-1)+1] = (a >> 4) | (b & 0xF0)
for i := 1; i < BS; i++ {
// swap nibbles
t_lo := bts[j*BS+i] & 0x0F
t_hi := bts[j*BS+i] & 0xF0
bts[j*BS+i] = (t_lo << 4) | (t_hi >> 4)
}
// transform aaaa...bbbb... to abababab...
oi := 0
tmp := [16]byte{}
for i := 1; i < 9; i++ {
blk_a0 := bts[j*BS+i] & 0xF0
blk_a1 := bts[j*BS+i] << 4
blk_b0 := bts[j*BS+i+8] >> 4
blk_b1 := bts[j*BS+i+8] & 0x0F
// swap once more
out0 := blk_a0 | blk_b0
out1 := blk_a1 | blk_b1
out_h0 := out0 & 0xF0
out_l0 := out0 & 0x0F
out_h1 := out1 & 0xF0
out_l1 := out1 & 0x0F
out0 = (out_h0 >> 4) | (out_l0 << 4)
out1 = (out_h1 >> 4) | (out_l1 << 4)
tmp[oi] = out0
oi++
tmp[oi] = out1
oi++
}
for i := range tmp {
bts[j*BS+i+1] = tmp[i]
}
copy(bts[j*BS+1:j*BS+17], tmp[:])
}
for _, tt := range tts {
@@ -630,18 +653,6 @@ func (b *Backend) Load(ctx context.Context, progress func(float32)) error {
})
}
// Cleanup any backend state from devices that we didn't end up using
nextDevice:
for _, d := range append(gpus, append(accels, cpus...)...) {
for _, backend := range b.schedBackends {
if d == C.ggml_backend_get_device(backend) {
continue nextDevice
}
}
C.ggml_backend_dev_reset(d)
}
if err := g.Wait(); err != nil {
return err
}
@@ -759,15 +770,8 @@ func (c *Context) Forward(tensors ...ml.Tensor) ml.Context {
}
func (c *Context) Compute(tensors ...ml.Tensor) {
c.ComputeWithNotify(nil, tensors...)
}
func (c *Context) ComputeWithNotify(cb func(), tensors ...ml.Tensor) {
c.b.schedMu.Lock()
defer c.b.schedMu.Unlock()
if cb != nil {
go cb()
}
if status := C.ggml_backend_sched_graph_compute_async(c.b.sched, c.graph); status != C.GGML_STATUS_SUCCESS {
panic(fmt.Errorf("error computing ggml graph: %v", status))
}
@@ -811,7 +815,7 @@ func (c *Context) Reserve() {
}
}
logutil.Trace("compute graph", "backend", C.GoString(C.ggml_backend_name(c.b.schedBackends[i])),
slog.Log(context.TODO(), logutil.LevelTrace, "compute graph", "backend", C.GoString(C.ggml_backend_name(c.b.schedBackends[i])),
"buffer_type", C.GoString(C.ggml_backend_buft_name(c.b.schedBufts[i])), "size", format.HumanBytes2(uint64(bufferStatus.size)))
}
@@ -842,7 +846,23 @@ func (c *Context) newTensor(dtype ml.DType, shape []int) ml.Tensor {
panic("set Input or Layer before creating tensors")
}
cdtype := ggmlDType(dtype)
var cdtype uint32
switch dtype {
case ml.DTypeF32:
cdtype = C.GGML_TYPE_F32
case ml.DTypeF16:
cdtype = C.GGML_TYPE_F16
case ml.DTypeQ80:
cdtype = C.GGML_TYPE_Q8_0
case ml.DTypeQ40:
cdtype = C.GGML_TYPE_Q4_0
case ml.DTypeI32:
cdtype = C.GGML_TYPE_I32
case ml.DTypeMXFP4:
cdtype = C.GGML_TYPE_MXFP4
default:
panic("unsupported dtype")
}
if len(shape) < 1 || shape[0] == 0 {
var shape C.int64_t = 0
@@ -1020,7 +1040,7 @@ func (t *Tensor) Floats() (data []float32) {
return
}
func (t *Tensor) SetValueFromIntSlice(s []int32) {
func (t *Tensor) BackendSetFromIntSlice(s []int32) {
if len(s) > 0 {
C.ggml_backend_tensor_set(t.t, unsafe.Pointer(&s[0]), 0, C.ggml_nbytes(t.t))
}
@@ -1045,32 +1065,6 @@ func (t *Tensor) DType() ml.DType {
}
}
func ggmlDType(dtype ml.DType) uint32 {
switch dtype {
case ml.DTypeF32:
return C.GGML_TYPE_F32
case ml.DTypeF16:
return C.GGML_TYPE_F16
case ml.DTypeQ80:
return C.GGML_TYPE_Q8_0
case ml.DTypeQ40:
return C.GGML_TYPE_Q4_0
case ml.DTypeI32:
return C.GGML_TYPE_I32
case ml.DTypeMXFP4:
return C.GGML_TYPE_MXFP4
default:
panic("unsupported dtype")
}
}
func (t *Tensor) Cast(ctx ml.Context, dtype ml.DType) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_cast(ctx.(*Context).ctx, t.t, ggmlDType(dtype)),
}
}
func (t *Tensor) Neg(ctx ml.Context) ml.Tensor {
return &Tensor{
b: t.b,

View File

@@ -167,7 +167,6 @@ extern "C" {
GGML_API void ggml_backend_dev_get_props(ggml_backend_dev_t device, struct ggml_backend_dev_props * props);
GGML_API ggml_backend_reg_t ggml_backend_dev_backend_reg(ggml_backend_dev_t device);
GGML_API ggml_backend_t ggml_backend_dev_init(ggml_backend_dev_t device, const char * params);
GGML_API void ggml_backend_dev_reset(ggml_backend_dev_t device);
GGML_API ggml_backend_buffer_type_t ggml_backend_dev_buffer_type(ggml_backend_dev_t device);
GGML_API ggml_backend_buffer_type_t ggml_backend_dev_host_buffer_type(ggml_backend_dev_t device);
GGML_API ggml_backend_buffer_t ggml_backend_dev_buffer_from_host_ptr(ggml_backend_dev_t device, void * ptr, size_t size, size_t max_tensor_size);

View File

@@ -178,10 +178,6 @@ extern "C" {
ggml_backend_event_t (*event_new) (ggml_backend_dev_t dev);
void (*event_free) (ggml_backend_dev_t dev, ggml_backend_event_t event);
void (*event_synchronize) (ggml_backend_dev_t dev, ggml_backend_event_t event);
// (optional) reset device, clearing existing allocations and context
// the caller must ensure that there are no outstanding buffers, as these will become invalid
void (*reset)(ggml_backend_dev_t dev);
};
struct ggml_backend_device {

View File

@@ -477,14 +477,6 @@ ggml_backend_t ggml_backend_dev_init(ggml_backend_dev_t device, const char * par
return device->iface.init_backend(device, params);
}
void ggml_backend_dev_reset(ggml_backend_dev_t device) {
if (device->iface.reset == NULL) {
return;
}
device->iface.reset(device);
}
ggml_backend_buffer_type_t ggml_backend_dev_buffer_type(ggml_backend_dev_t device) {
return device->iface.get_buffer_type(device);
}

View File

@@ -103,11 +103,6 @@ int ggml_cuda_get_device() {
return id;
}
void ggml_cuda_reset_device(int device) {
ggml_cuda_set_device(device);
CUDA_CHECK(cudaDeviceReset());
}
static cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device) {
ggml_cuda_set_device(device);
cudaError_t err;
@@ -3248,10 +3243,7 @@ static void ggml_backend_cuda_device_get_props(ggml_backend_dev_t dev, ggml_back
props->description = ggml_backend_cuda_device_get_description(dev);
props->id = ggml_backend_cuda_device_get_id(dev);
props->type = ggml_backend_cuda_device_get_type(dev);
// Memory reporting is disabled to avoid allocation of a CUDA primary context (~300 MB per device).
// If you need the memory data, call ggml_backend_dev_memory() explicitly.
props->memory_total = props->memory_free = 0;
ggml_backend_cuda_device_get_memory(dev, &props->memory_free, &props->memory_total);
bool host_buffer = getenv("GGML_CUDA_NO_PINNED") == nullptr;
#ifdef GGML_CUDA_NO_PEER_COPY
@@ -3708,11 +3700,6 @@ static void ggml_backend_cuda_device_event_synchronize(ggml_backend_dev_t dev, g
CUDA_CHECK(cudaEventSynchronize((cudaEvent_t)event->context));
}
static void ggml_backend_cuda_device_reset(ggml_backend_dev_t dev) {
ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context;
ggml_cuda_reset_device(ctx->device);
}
static const ggml_backend_device_i ggml_backend_cuda_device_interface = {
/* .get_name = */ ggml_backend_cuda_device_get_name,
/* .get_description = */ ggml_backend_cuda_device_get_description,
@@ -3729,7 +3716,6 @@ static const ggml_backend_device_i ggml_backend_cuda_device_interface = {
/* .event_new = */ ggml_backend_cuda_device_event_new,
/* .event_free = */ ggml_backend_cuda_device_event_free,
/* .event_synchronize = */ ggml_backend_cuda_device_event_synchronize,
/* .reset = */ ggml_backend_cuda_device_reset,
};
// backend reg
@@ -3849,6 +3835,7 @@ ggml_backend_reg_t ggml_backend_cuda_reg() {
dev_ctx->device = i;
dev_ctx->name = GGML_CUDA_NAME + std::to_string(i);
ggml_cuda_set_device(i);
cudaDeviceProp prop;
CUDA_CHECK(cudaGetDeviceProperties(&prop, i));
dev_ctx->description = prop.name;

View File

@@ -40,7 +40,6 @@
#define cudaDeviceDisablePeerAccess hipDeviceDisablePeerAccess
#define cudaDeviceEnablePeerAccess hipDeviceEnablePeerAccess
#define cudaDeviceProp hipDeviceProp_t
#define cudaDeviceReset hipDeviceReset
#define cudaDeviceSynchronize hipDeviceSynchronize
#define cudaError_t hipError_t
#define cudaErrorPeerAccessAlreadyEnabled hipErrorPeerAccessAlreadyEnabled

View File

@@ -19,12 +19,8 @@ static bool ggml_uncaught_exception_init = []{
return false;
}
const auto prev{std::get_terminate()};
// GGML_ASSERT(prev != ggml_uncaught_exception);
if (prev != ggml_uncaught_exception) {
previous_terminate_handler = prev;
} else {
GGML_LOG_WARN("%s double registration of ggml_uncaught_exception\n", __func__);
}
GGML_ASSERT(prev != ggml_uncaught_exception);
previous_terminate_handler = prev;
std::set_terminate(ggml_uncaught_exception);
return true;
}();

View File

@@ -2,6 +2,7 @@ package model
import (
"cmp"
"context"
"fmt"
"iter"
"log/slog"
@@ -108,7 +109,7 @@ func (bpe BytePairEncoding) Encode(s string, addSpecial bool) ([]int32, error) {
r = 0x0143
case r <= 0x0020:
r = r + 0x0100
case r >= 0x007f && r <= 0x00a0:
case r >= 0x007e && r <= 0x00a0:
r = r + 0x00a2
}
@@ -201,11 +202,12 @@ func (bpe BytePairEncoding) Encode(s string, addSpecial bool) ([]int32, error) {
}
}
slog.Log(context.TODO(), logutil.LevelTrace, "encoded", "string", s, "ids", ids)
if addSpecial && len(ids) > 0 {
ids = bpe.vocab.addSpecials(ids)
}
logutil.Trace("encoded", "string", s, "ids", ids)
return ids, nil
}
@@ -241,6 +243,6 @@ func (bpe BytePairEncoding) Decode(ids []int32) (string, error) {
}
}
logutil.Trace("decoded", "string", sb.String(), "from", lazyIdsString{ids: ids})
slog.Log(context.TODO(), logutil.LevelTrace, "decoded", "string", sb.String(), "from", lazyIdsString{ids: ids})
return sb.String(), nil
}

View File

@@ -207,36 +207,6 @@ func TestLlama(t *testing.T) {
}
}
})
t.Run("roundtriping 0x00-0xFF", func(t *testing.T) {
t.Parallel()
for b := 0x00; b <= 0xFF; b++ {
input := string(rune(b))
ids, err := tokenizer.Encode(input, false)
if err != nil {
t.Errorf("failed to encode rune 0x%02X: %v", b, err)
continue
}
decoded, err := tokenizer.Decode(ids)
if err != nil {
t.Errorf("failed to decode rune 0x%02X: %v", b, err)
continue
}
if b == 0x00 {
if len(decoded) != 0 {
t.Errorf("Decode(Encode(0x00)) should be empty, got %v", ids)
}
continue
}
if decoded != input {
t.Errorf("rune 0x%02X failed roundtrip: got %q, want %q", b, decoded, input)
}
}
})
}
func BenchmarkBytePairEncoding(b *testing.B) {

View File

@@ -1,11 +1,12 @@
package model
import (
"context"
"errors"
"fmt"
_ "image/jpeg"
_ "image/png"
"math"
"log/slog"
"os"
"reflect"
"strconv"
@@ -104,10 +105,6 @@ func New(modelPath string, params ml.BackendParams) (Model, error) {
}
arch := b.Config().Architecture()
if b.Config().Uint("pooling_type", math.MaxUint32) != math.MaxUint32 {
arch = arch + "_embed"
}
f, ok := models[arch]
if !ok {
return nil, fmt.Errorf("unsupported model architecture %q", arch)
@@ -201,7 +198,7 @@ func populateFields(base Base, v reflect.Value, tags ...Tag) reflect.Value {
names := fn(tagsCopy)
for _, name := range names {
if tensor := base.Backend().Get(strings.Join(name, ".")); tensor != nil {
logutil.Trace("found tensor", "", tensor)
slog.Log(context.TODO(), logutil.LevelTrace, "found tensor", "", tensor)
vv.Set(reflect.ValueOf(tensor))
break
}
@@ -281,29 +278,31 @@ func canNil(t reflect.Type) bool {
t.Kind() == reflect.Slice
}
func Forward(ctx ml.Context, m Model, batch input.Batch) (ml.Tensor, error) {
func Forward(ctx ml.Context, m Model, inputs []int32, batch input.Batch) (ml.Tensor, ml.Tensor, error) {
if len(batch.Positions) != len(batch.Sequences) {
return nil, fmt.Errorf("length of positions (%v) must match length of seqs (%v)", len(batch.Positions), len(batch.Sequences))
return nil, nil, fmt.Errorf("length of positions (%v) must match length of seqs (%v)", len(batch.Positions), len(batch.Sequences))
}
if len(batch.Positions) < 1 {
return nil, errors.New("batch size cannot be less than 1")
return nil, nil, errors.New("batch size cannot be less than 1")
}
batch.Inputs = ctx.Input().FromIntSlice(inputs, len(inputs))
cache := m.Config().Cache
if cache != nil {
err := cache.StartForward(ctx, batch, false)
if err != nil {
return nil, err
return nil, nil, err
}
}
t, err := m.Forward(ctx, batch)
if err != nil {
return nil, err
return nil, nil, err
}
ctx.Forward(t)
return t, nil
return batch.Inputs, t, nil
}

View File

@@ -1,73 +0,0 @@
package gemma3
import (
"errors"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
)
type embedModel struct {
model.Base
model.SentencePieceModel
*TextModel
PoolingType uint32
Dense [2]*nn.Linear `gguf:"dense"`
}
func (m *embedModel) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
batch.Outputs = batch.Positions // return all positions
hiddenStates := m.TextModel.Forward(ctx, batch, m.Cache)
switch m.PoolingType {
case 0: // None
case 1: // Mean
hiddenStates = hiddenStates.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx).Mean(ctx)
hiddenStates = hiddenStates.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
default:
return nil, errors.New("unsupported pooling type")
}
for _, dense := range m.Dense {
hiddenStates = dense.Forward(ctx, hiddenStates)
}
return hiddenStates, nil
}
func newEmbedModel(c fs.Config) (model.Model, error) {
m := &embedModel{
SentencePieceModel: model.NewSentencePieceModel(
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
EOS: append(
[]int32{
int32(c.Uint("tokenizer.ggml.eos_token_id")),
int32(c.Uint("tokenizer.ggml.eot_token_id", 106)),
},
c.Ints("tokenizer.ggml.eos_token_ids")...,
),
},
),
TextModel: newTextModel(c),
PoolingType: c.Uint("pooling_type", 0),
}
m.Cache = kvcache.NewWrapperCache(
kvcache.NewSWACache(int32(c.Uint("attention.sliding_window")), m.Shift),
kvcache.NewCausalCache(m.Shift),
)
return m, nil
}

View File

@@ -18,7 +18,7 @@ type Model struct {
model.Base
model.SentencePieceModel
*VisionModel `gguf:"v"`
*VisionModel `gguf:"v,vision"`
*TextModel
*MultiModalProjector `gguf:"mm"`
@@ -141,11 +141,12 @@ func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
}
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
hiddenStates := m.TextModel.Forward(ctx, batch, m.Cache)
return m.Output.Forward(ctx, hiddenStates), nil
positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, batch, m.Cache), nil
}
func init() {
model.Register("gemma3", New)
model.Register("gemma3_embed", newEmbedModel)
}

View File

@@ -159,11 +159,8 @@ func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, positionIDs,
return hiddenState.Add(ctx, residual)
}
func (m *TextModel) Forward(ctx ml.Context, batch input.Batch, cache kvcache.Cache) ml.Tensor {
positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs)
func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor, batch input.Batch, cache kvcache.Cache) ml.Tensor {
hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.TextConfig.hiddenSize)))
// set image embeddings
@@ -201,5 +198,5 @@ func (m *TextModel) Forward(ctx ml.Context, batch input.Batch, cache kvcache.Cac
}
hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
return hiddenState
return m.Output.Forward(ctx, hiddenState)
}

View File

@@ -18,7 +18,7 @@ type Model struct {
model.BytePairEncoding
ImageProcessor
*VisionModel `gguf:"v"`
*VisionModel `gguf:"v,vision"`
*Projector `gguf:"mm"`
*TextModel
}

View File

@@ -18,7 +18,7 @@ type Model struct {
model.BytePairEncoding
*TextModel
*VisionModel `gguf:"v"`
*VisionModel `gguf:"v,vision"`
*MultiModalProjector `gguf:"mm"`
ImageProcessor

View File

@@ -17,7 +17,7 @@ type Model struct {
model.Base
model.BytePairEncoding
*VisionModel `gguf:"v"`
*VisionModel `gguf:"v,vision"`
*TextModel
Projector *nn.Linear `gguf:"mm.0"`

View File

@@ -18,7 +18,7 @@ type Model struct {
model.BytePairEncoding
*TextModel
*VisionModel `gguf:"v"`
*VisionModel `gguf:"v,vision"`
ImageProcessor
}

View File

@@ -2,6 +2,7 @@ package model
import (
"container/heap"
"context"
"fmt"
"log/slog"
"strconv"
@@ -24,7 +25,7 @@ func (spm SentencePieceModel) Vocabulary() *Vocabulary {
}
func NewSentencePieceModel(vocab *Vocabulary) SentencePieceModel {
logutil.Trace("Tokens", "num tokens", len(vocab.Values), "vals", vocab.Values[:5], "scores", vocab.Scores[:5], "types", vocab.Types[:5])
slog.Log(context.TODO(), logutil.LevelTrace, "Tokens", "num tokens", len(vocab.Values), "vals", vocab.Values[:5], "scores", vocab.Scores[:5], "types", vocab.Types[:5])
counter := map[int]int{}
var maxTokenLen int
@@ -38,7 +39,7 @@ func NewSentencePieceModel(vocab *Vocabulary) SentencePieceModel {
}
}
logutil.Trace("Token counts", "normal", counter[TOKEN_TYPE_NORMAL], "unknown", counter[TOKEN_TYPE_UNKNOWN], "control", counter[TOKEN_TYPE_CONTROL],
slog.Log(context.TODO(), logutil.LevelTrace, "Token counts", "normal", counter[TOKEN_TYPE_NORMAL], "unknown", counter[TOKEN_TYPE_UNKNOWN], "control", counter[TOKEN_TYPE_CONTROL],
"user defined", counter[TOKEN_TYPE_USER_DEFINED], "unused", counter[TOKEN_TYPE_UNUSED], "byte", counter[TOKEN_TYPE_BYTE],
"max token len", maxTokenLen)
@@ -181,11 +182,12 @@ func (spm SentencePieceModel) Encode(s string, addSpecial bool) ([]int32, error)
}
}
slog.Log(context.TODO(), logutil.LevelTrace, "encoded", "string", s, "ids", ids)
if addSpecial && len(ids) > 0 {
ids = spm.vocab.addSpecials(ids)
}
logutil.Trace("encoded", "string", s, "ids", ids)
return ids, nil
}
@@ -244,6 +246,6 @@ func (spm SentencePieceModel) Decode(ids []int32) (string, error) {
}
}
logutil.Trace("decoded", "ids", ids, "string", sb.String())
slog.Log(context.TODO(), logutil.LevelTrace, "decoded", "ids", ids, "string", sb.String())
return sb.String(), nil
}

View File

@@ -49,7 +49,7 @@ func (v *Vocabulary) addSpecials(ids []int32) []int32 {
slog.Warn("adding bos token to prompt which already has it", "id", v.BOS)
}
slog.Debug("adding bos token to prompt", "id", v.BOS[0])
slog.Debug("adding bos token to prompt", "id", v.BOS)
ids = append([]int32{v.BOS[0]}, ids...)
}
@@ -58,7 +58,7 @@ func (v *Vocabulary) addSpecials(ids []int32) []int32 {
slog.Warn("adding eos token to prompt which already has it", "id", v.EOS)
}
slog.Debug("adding eos token to prompt", "id", v.EOS[0])
slog.Debug("adding eos token to prompt", "id", v.EOS)
ids = append(ids, v.EOS[0])
}

View File

@@ -76,9 +76,8 @@ type JsonSchema struct {
}
type EmbedRequest struct {
Input any `json:"input"`
Model string `json:"model"`
Dimensions int `json:"dimensions,omitempty"`
Input any `json:"input"`
Model string `json:"model"`
}
type StreamOptions struct {
@@ -558,10 +557,12 @@ func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
var think *api.ThinkValue
if r.Reasoning != nil {
options["reasoning"] = *r.Reasoning.Effort
think = &api.ThinkValue{
Value: *r.Reasoning.Effort,
}
} else if r.ReasoningEffort != nil {
options["reasoning"] = *r.ReasoningEffort
think = &api.ThinkValue{
Value: *r.ReasoningEffort,
}
@@ -1006,7 +1007,7 @@ func EmbeddingsMiddleware() gin.HandlerFunc {
}
var b bytes.Buffer
if err := json.NewEncoder(&b).Encode(api.EmbedRequest{Model: req.Model, Input: req.Input, Dimensions: req.Dimensions}); err != nil {
if err := json.NewEncoder(&b).Encode(api.EmbedRequest{Model: req.Model, Input: req.Input}); err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
return
}

View File

@@ -246,7 +246,7 @@ func filesForModel(path string) ([]string, error) {
for _, match := range matches {
if ct, err := detectContentType(match); err != nil {
return nil, err
} else if len(contentType) > 0 && ct != contentType {
} else if ct != contentType {
return nil, fmt.Errorf("invalid content type: expected %s for %s", ct, match)
}
}
@@ -255,8 +255,7 @@ func filesForModel(path string) ([]string, error) {
}
var files []string
// some safetensors files do not properly match "application/octet-stream", so skip checking their contentType
if st, _ := glob(filepath.Join(path, "*.safetensors"), ""); len(st) > 0 {
if st, _ := glob(filepath.Join(path, "*.safetensors"), "application/octet-stream"); len(st) > 0 {
// safetensors files might be unresolved git lfs references; skip if they are
// covers model-x-of-y.safetensors, model.fp32-x-of-y.safetensors, model.safetensors
files = append(files, st...)

View File

@@ -1,126 +0,0 @@
package parser
import (
"encoding/json"
"errors"
"strings"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/harmony"
)
type TokenParserType int
const (
TokenParserTypeDefault TokenParserType = iota
TokenParserTypeHarmony
)
type TokenParser struct {
messageHandler MessageHandler
parserEngine ParserInternals
toolParser ToolParser
lastToken string
tokenRepeat int
repeatLimit int
}
const defaultTokenRepeatLimit = 30
type MessageHandler interface {
AddContent(token string) (content, thinking string, toolContent string)
}
type ParserInternals interface {
AddImplicitStartOrPrefill(prefillString string)
}
type ToolParser interface {
Add(token string)
Drain() (toolName *string, toolContent string)
}
// Default implementation for the TokenParser interface as a no-op passthrough
type defaultMessageHandler struct{}
func (defaultMessageHandler) AddContent(token string) (string, string, string) {
return token, "", ""
}
type defaultEngine struct{}
func (defaultEngine) AddImplicitStartOrPrefill(prefillString string) {}
type defaultToolParser struct{}
func (defaultToolParser) Add(token string) {}
func (defaultToolParser) Drain() (*string, string) { return nil, "" }
func NewTokenParser(parserType TokenParserType, prefillString string) TokenParser {
switch parserType {
case TokenParserTypeHarmony:
harmonyMessageHandler := harmony.NewHarmonyMessageHandler()
harmonyMessageHandler.HarmonyParser.AddImplicitStartOrPrefill(prefillString)
return TokenParser{
messageHandler: harmonyMessageHandler,
parserEngine: harmonyMessageHandler.HarmonyParser,
toolParser: harmonyMessageHandler.ToolParser,
repeatLimit: defaultTokenRepeatLimit,
}
default:
return TokenParser{
messageHandler: defaultMessageHandler{},
parserEngine: defaultEngine{},
toolParser: defaultToolParser{},
repeatLimit: 30,
}
}
}
func (p *TokenParser) AddContent(token string) (string, string, error) {
if p.repeatLimitReached(token) {
return "", "", errors.New("token repeat limit reached")
}
content, thinking, toolContent := p.messageHandler.AddContent(token)
p.toolParser.Add(toolContent)
return content, thinking, nil
}
// repeatLimitReached updates repeat counters and returns true if the repeat limit is reached.
func (p *TokenParser) repeatLimitReached(token string) bool {
if p == nil {
return false
}
trimmed := strings.TrimSpace(token)
if trimmed == p.lastToken {
p.tokenRepeat++
} else {
p.tokenRepeat = 0
}
p.lastToken = trimmed
return p.tokenRepeat >= p.repeatLimit
}
// TODO: update to work with multiple toolcalls - unmarshalling should also happen on parser level
func (p *TokenParser) Drain() []api.ToolCall {
toolName, toolContent := p.toolParser.Drain()
if toolName != nil {
*toolName = strings.TrimPrefix(*toolName, "functions.")
var args api.ToolCallFunctionArguments
if err := json.Unmarshal([]byte(toolContent), &args); err != nil {
return nil
}
return []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: *toolName,
Arguments: args,
},
},
}
}
return nil
}

View File

@@ -46,7 +46,7 @@ func NewInputCache(lc *llama.Context, kvSize int, numSlots int, multiUserCache b
}
// Locking: Operations on InputCacheSlot (including finding one
// through LoadCacheSlot) require a lock to be held that serializes
// through LoadCacheSlot) require a lock to be be held that serializes
// these operations with each other and llama.Decode
type InputCacheSlot struct {

View File

@@ -34,8 +34,8 @@ type InputCache struct {
func NewInputCache(model model.Model, kvCacheType string, kvSize int32, numSlots int, batchSize int, multiUserCache bool) (*InputCache, error) {
numCtx := kvSize / int32(numSlots)
if int(numCtx) < batchSize {
return nil, fmt.Errorf("kv size must be at least as large as batch size * parallel (kv: %v batch: %v parallel: %v)", kvSize, batchSize, numSlots)
if numCtx < 1 {
return nil, fmt.Errorf("must have at least one kv cache entry per parallel sequence (kv: %v parallel: %v)", kvSize, numSlots)
}
slots := make([]InputCacheSlot, numSlots)
@@ -70,13 +70,15 @@ func kvCacheTypeFromStr(s string) ml.DType {
}
func (c *InputCache) Close() {
if c != nil && c.cache != nil {
c.cache.Close()
if c == nil {
return
}
c.cache.Close()
}
// Locking: Operations on InputCacheSlot (including finding one
// through LoadCacheSlot) require a lock to be held that serializes
// through LoadCacheSlot) require a lock to be be held that serializes
// these operations with each other and processBatch
type InputCacheSlot struct {
@@ -93,7 +95,7 @@ type InputCacheSlot struct {
lastUsed time.Time
}
func (c *InputCache) LoadCacheSlot(prompt []*input.Input, cachePrompt bool) (*InputCacheSlot, []*input.Input, error) {
func (c *InputCache) LoadCacheSlot(prompt []*input.Input) (*InputCacheSlot, []*input.Input, error) {
var slot *InputCacheSlot
var numPast int32
var err error
@@ -111,10 +113,6 @@ func (c *InputCache) LoadCacheSlot(prompt []*input.Input, cachePrompt bool) (*In
return nil, nil, err
}
if !cachePrompt {
numPast = 0
}
slot.InUse = true
slot.lastUsed = time.Now()

View File

@@ -393,7 +393,7 @@ func TestLoadCacheSlot(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
slot, remainingPrompt, err := tt.cache.LoadCacheSlot(tt.prompt, true)
slot, remainingPrompt, err := tt.cache.LoadCacheSlot(tt.prompt)
// Check error state
if (err != nil) != tt.wantErr {

View File

@@ -11,13 +11,13 @@ import (
"image"
"log"
"log/slog"
"math"
"net"
"net/http"
"os"
"reflect"
"regexp"
"runtime"
"runtime/debug"
"strconv"
"strings"
"sync"
@@ -29,12 +29,12 @@ import (
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/harmony"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/logutil"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/parser"
"github.com/ollama/ollama/runner/common"
"github.com/ollama/ollama/sample"
@@ -88,6 +88,12 @@ type Sequence struct {
// true if an embedding are to be returned instead of text generation
embeddingOnly bool
// true if the sequence if finished and marked for removal on next pass
finished bool
// True if we have to skip this sequence to shift the cache
skipForShift bool
doneReason llm.DoneReason
// Metrics
@@ -262,24 +268,14 @@ func (s *Server) inputs(prompt string, images []llm.ImageData) ([]*input.Input,
}
type batchState struct {
// id provides a counter for trace logging batches
id int
// ctx holds the backend context used for this batch
ctx ml.Context
// modelOutput holds the outputs from this batch
id int
ctx ml.Context
modelInput ml.Tensor
modelOutput ml.Tensor
// batchInputs holds the input token pointers which may start as
// placeholders later filled in before calling ctx.Compute
batchInputs []*input.Input
// batch contains the inputs for a model forward pass
batch input.Batch
// full set of seqs at the time this batch was initiated
seqs []*Sequence
batch input.Batch
seqs []*Sequence // full set of seqs at the time this batch was initiated
initSeqIdx int // The initial value for the set of sequences evaluated (s.nextSeq - 1)
// Signaled when this batches inputs are ready and compute can proceed
inputsReadyCh chan struct{}
@@ -326,6 +322,10 @@ type Server struct {
// Used to signal a hard failure during async processing which will panic the runner
hardErrCh chan error
// A prior batch that's still being processed
// only read or written by forwardBatch
pendingBatch *batchState
// Simple counter used only for trace logging batches
batchID int
@@ -389,14 +389,25 @@ func flushPending(seq *Sequence) bool {
}
}
func (s *Server) removeSequence(seqIndex int, reason llm.DoneReason) {
func (s *Server) finishSequence(seqIndex int, reason llm.DoneReason) {
seq := s.seqs[seqIndex]
// finish could be called multiple times since we prepare 1 batch ahead
// and multiple scenarios can lead to finishing a sequence
// ensure only the first finish called is processed
if seq.finished {
return
}
flushPending(seq)
seq.doneReason = reason
seq.finished = true
close(seq.responses)
close(seq.embedding)
seq.cache.InUse = false
}
func (s *Server) removeFinishedSequence(seqIndex int) {
s.seqs[seqIndex] = nil
s.seqsSem.Release(1)
}
@@ -406,9 +417,7 @@ func (s *Server) removeSequence(seqIndex int, reason llm.DoneReason) {
func (s *Server) run(ctx context.Context) {
s.ready.Wait()
supportsAsync := s.model.Backend().Config().Uint("pooling_type", math.MaxUint32) == math.MaxUint32
var activeBatch batchState
var bs *batchState
for {
select {
case <-ctx.Done():
@@ -417,35 +426,35 @@ func (s *Server) run(ctx context.Context) {
panic(err)
default:
var err error
activeBatch, err = s.forwardBatch(activeBatch)
bs, err = s.forwardBatch()
if err != nil {
panic(err)
}
if supportsAsync {
go s.computeBatch(activeBatch)
} else {
s.computeBatch(activeBatch)
if bs == nil {
continue
}
go s.computeBatch(bs)
}
}
}
// forwardBatch will calculate a batch.
func (s *Server) forwardBatch(pendingBatch batchState) (nextBatch batchState, err error) {
func (s *Server) forwardBatch() (*batchState, error) {
inputsReady := false
var inputsReadyCh chan struct{}
// If we have a pending batch still processing, wait until Compute has started
// before setting up the next batch so the seqs inputs are ready to receive their
// token values and we get the correct input pointers for the batchInputs
if pendingBatch.ctx != nil {
logutil.Trace("forwardBatch waiting for compute to start", "pendingBatch.id", pendingBatch.id)
<-pendingBatch.computeStartedCh
logutil.Trace("forwardBatch compute started, setting up next batch", "pendingBatch.id", pendingBatch.id, "id", s.batchID)
nextBatch.inputsReadyCh = pendingBatch.outputsReadyCh // Chain the ouputs from the pending batch to the next inputs batch
if s.pendingBatch != nil {
slog.Log(context.TODO(), logutil.LevelTrace, "forwardBatch waiting for compute to start", "pendingBatch.id", s.pendingBatch.id)
<-s.pendingBatch.computeStartedCh
slog.Log(context.TODO(), logutil.LevelTrace, "forwardBatch compute started, setting up next batch", "pendingBatch.id", s.pendingBatch.id, "id", s.batchID)
inputsReadyCh = s.pendingBatch.outputsReadyCh // Chain the ouputs from the pending batch to the next inputs batch
} else {
logutil.Trace("forwardBatch no pending batch detected", "batchID", s.batchID)
// No pendingBatch, so the inputs will be ready in the seqs immediately
nextBatch.inputsReadyCh = make(chan struct{}, 1)
nextBatch.inputsReadyCh <- struct{}{}
slog.Log(context.TODO(), logutil.LevelTrace, "forwardBatch no pending batch detected", "batchID", s.batchID)
inputsReady = true // No pendingBatch, so the inputs will be ready in the seqs immediately
inputsReadyCh = make(chan struct{}, 1)
}
s.mu.Lock()
@@ -454,17 +463,55 @@ func (s *Server) forwardBatch(pendingBatch batchState) (nextBatch batchState, er
}
defer s.mu.Unlock()
nextBatch.ctx = s.model.Backend().NewContext()
defer func() {
if err != nil {
nextBatch.ctx.Close()
nextBatch.ctx = nil
// If new sequences have been added with an active batch we delay preparing the next batch
// until Compute has finished
if s.pendingBatch != nil {
for seqIdx := range s.seqs {
if s.seqs[seqIdx] != s.pendingBatch.seqs[seqIdx] {
slog.Log(context.TODO(), logutil.LevelTrace, "forwardBatch seqs changed, waiting for compute to finish to pick up new sequence(s)", "pendingBatch.id", s.pendingBatch.id)
s.mu.Unlock() // release the lock so computeBatch can finish up
<-s.pendingBatch.outputsReadyCh
slog.Log(context.TODO(), logutil.LevelTrace, "forwardBatch pending batch outputs ready", "pendingBatch.id", s.pendingBatch.id)
s.mu.Lock()
inputsReady = true // pendingBatch completed, so the inputs are ready in the seqs
break
}
}
}()
nextBatch.id = s.batchID
nextBatch.seqs = append([]*Sequence{}, s.seqs...)
nextBatch.computeStartedCh = make(chan struct{}, 1)
nextBatch.outputsReadyCh = make(chan struct{}, 1)
}
// Clear pending Batch - we'll set it if we have a batch with any inputs
s.pendingBatch = nil
// Remove any finished sequences before recording the active set of seqs in the batch
for seqIdx := range s.seqs {
seq := s.seqs[seqIdx]
if seq == nil {
continue
}
if seq.finished {
s.removeFinishedSequence(seqIdx)
continue
}
if seq.numPredict > 0 && seq.numPredicted >= seq.numPredict {
s.finishSequence(seqIdx, llm.DoneReasonLength)
s.removeFinishedSequence(seqIdx)
continue
}
}
// next batch
nb := &batchState{
id: s.batchID,
initSeqIdx: s.nextSeq - 1,
seqs: make([]*Sequence, len(s.seqs)),
inputsReadyCh: inputsReadyCh,
computeStartedCh: make(chan struct{}, 1),
outputsReadyCh: make(chan struct{}, 1),
}
ctx := s.model.Backend().NewContext()
nb.ctx = ctx
// Record the sequences at the time we create the batch so we can detect if new sequences are added on the next pass
copy(nb.seqs, s.seqs)
// Prepare the seqs and batch, but defer the input token values as we may not be ready yet
var batchInputs []*input.Input
@@ -479,13 +526,6 @@ func (s *Server) forwardBatch(pendingBatch batchState) (nextBatch batchState, er
continue
}
// if past the num predict limit
if seq.numPredict > 0 && seq.numPredicted >= seq.numPredict {
s.removeSequence(seqIdx, llm.DoneReasonLength)
nextBatch.seqs[seqIdx] = nil
continue
}
if !s.cache.enabled {
seq.inputs = append(seq.cache.Inputs, seq.inputs...)
seq.cache.Inputs = []*input.Input{}
@@ -521,28 +561,28 @@ func (s *Server) forwardBatch(pendingBatch batchState) (nextBatch batchState, er
break
}
err = s.cache.ShiftCacheSlot(seq.cache, seq.numKeep)
err := s.cache.ShiftCacheSlot(seq.cache, seq.numKeep)
if err != nil {
var reprocess *ErrReprocessInputs
if errors.As(err, &reprocess) {
// Prepend these inputs to the sequence's inputs queue for reprocessing
seq.inputs = append(reprocess.Inputs, seq.inputs...)
// Skip this sequence but continue processing the rest
nextBatch.seqs[seqIdx] = nil // clear this sequence for this batch
err = nil
seq.skipForShift = true // cleared in computeBatch below for the next batch
continue
} else {
return
ctx.Close()
return nil, err
}
}
}
batchInputs = append(batchInputs, seq.inputs[i])
if inp.Multimodal != nil {
var mm []input.Multimodal
mm, err = seq.mmStore.getMultimodal(s.model.Backend(), nextBatch.ctx, inp.Multimodal, false)
mm, err := seq.mmStore.getMultimodal(s.model.Backend(), ctx, inp.Multimodal, false)
if err != nil {
return
ctx.Close()
return nil, err
}
batch.Multimodal = append(batch.Multimodal, input.MultimodalIndex{Index: len(batchInputs) - 1, Multimodal: mm})
}
@@ -550,11 +590,13 @@ func (s *Server) forwardBatch(pendingBatch batchState) (nextBatch batchState, er
batch.Positions = append(batch.Positions, int32(len(seq.cache.Inputs)+len(seq.pendingInputs)))
batch.Sequences = append(batch.Sequences, seq.cache.Id)
// TODO BUG HERE!!!
// Somehow sometimes iBatch isn't set correctly
seq.iBatch = len(batch.Outputs)
if i+1 == len(seq.inputs) {
batch.Outputs = append(batch.Outputs, int32(len(batchInputs)-1))
}
logutil.Trace("forwardBatch iBatch", "batchID", s.batchID, "seqIdx", seqIdx, "seq.iBatch", seq.iBatch, "i+1", i+1, "len(seq.inputs)", len(seq.inputs))
slog.Log(context.TODO(), logutil.LevelTrace, "forwardBatch iBatch", "batchID", s.batchID, "seqIdx", seqIdx, "seq.iBatch", seq.iBatch, "i+1", i+1, "len(seq.inputs)", len(seq.inputs))
seq.pendingInputs = append(seq.pendingInputs, inp)
}
@@ -568,57 +610,66 @@ func (s *Server) forwardBatch(pendingBatch batchState) (nextBatch batchState, er
}
if len(batchInputs) == 0 {
logutil.Trace("forwardBatch no batchInputs, going idle", "batchID", s.batchID)
nextBatch.ctx.Close()
nextBatch.ctx = nil
return
slog.Log(context.TODO(), logutil.LevelTrace, "forwardBatch no batchInputs, going idle", "batchID", s.batchID)
ctx.Close()
return nil, nil
}
s.batchID++
// Actual batchInputs values will be injected into the batch.Inputs tensor before calling Compute
batch.Inputs = nextBatch.ctx.Input().Empty(ml.DTypeI32, len(batchInputs))
nextBatch.modelOutput, err = model.Forward(nextBatch.ctx, s.model, batch)
var err error
// Actual batchInputs values will be injected into the modelInput tensor before calling Compute
nb.modelInput, nb.modelOutput, err = model.Forward(ctx, s.model, make([]int32, len(batchInputs)), batch)
if err != nil {
err = fmt.Errorf("failed to build graph: %w", err)
return
ctx.Close()
return nil, fmt.Errorf("failed to build graph: %w", err)
}
nextBatch.batchInputs = batchInputs
nextBatch.batch = batch
nb.batchInputs = batchInputs
nb.batch = batch
return
// computeBatch will close the context in the batch upon completion
s.pendingBatch = nb
if inputsReady {
nb.inputsReadyCh <- struct{}{}
}
return nb, nil
}
// Async processing of the next batch
func (s *Server) computeBatch(activeBatch batchState) {
if activeBatch.ctx == nil {
func (s *Server) computeBatch(bs *batchState) {
if bs == nil || bs.ctx == nil {
// Nothing to compute
return
}
defer activeBatch.ctx.Close()
defer bs.ctx.Close()
// Wait until inputs are ready
logutil.Trace("computeBatch: waiting for inputs to be ready", "batchID", activeBatch.id)
<-activeBatch.inputsReadyCh
logutil.Trace("computeBatch: inputs are ready", "batchID", activeBatch.id)
slog.Log(context.TODO(), logutil.LevelTrace, "computeBatch: waiting for inputs to be ready", "batchID", bs.id)
<-bs.inputsReadyCh
slog.Log(context.TODO(), logutil.LevelTrace, "computeBatch: inputs are ready", "batchID", bs.id)
// Once we complete, signal the next batch of inputs are ready
// This will unblock the next computeBatch, or forwardBatch if new seqs come in
defer func() {
logutil.Trace("computeBatch: outputs are ready", "batchID", activeBatch.id)
activeBatch.outputsReadyCh <- struct{}{}
slog.Log(context.TODO(), logutil.LevelTrace, "computeBatch: outputs are ready", "batchID", bs.id)
bs.outputsReadyCh <- struct{}{}
}()
s.mu.Lock()
// Gather the actual input token values now that they're ready
batchInputs := make([]int32, len(activeBatch.batchInputs))
batchInputs := make([]int32, len(bs.batchInputs))
for i := range batchInputs {
batchInputs[i] = activeBatch.batchInputs[i].Token
batchInputs[i] = bs.batchInputs[i].Token
}
// TODO the following logic could be run in a go routine to possibly speed up getting to Compute
// Now we run part of the decoding algorithm to adjust the seq.inputs with placeholder tokens
// so that forwardBatch can build a batchInputs set which will eventually contain the actual
// decoded tokens.
promptProcessing := make([]bool, len(s.seqs)) // track seq's we skip
nextBatchTokens := make([]*input.Input, len(s.seqs))
iBatches := make([]int, len(s.seqs)) // Record the iBatch values before releasing the lock
for i, seq := range s.seqs {
@@ -626,26 +677,12 @@ func (s *Server) computeBatch(activeBatch batchState) {
if seq == nil {
continue
}
// Skip over any newly added or skipped sequences
if activeBatch.seqs[i] == nil {
// Skip over any newly added sequences
if bs.seqs[i] == nil {
continue
}
// Detect if the sequence we're processing has already been completed and replaced
// with a new sequence
if seq != activeBatch.seqs[i] {
logutil.Trace("computeBatch: sequence replaced, discarding its results", "batchID", activeBatch.id, "seqIdx", i)
continue
}
// Pending inputs will actually be in the cache after we call Compute.
// However, we have already resolved any placeholder tokens.
//
// It's possible for incoming sequences to look at the values that we've
// added to the cache here and start relying on them before we've done
// the computation. This is OK as long as we ensure that this batch's
// computation happens before any future batch's and we never fail
// (unless we take down the whole runner).
// After calling Forward, pending inputs are now in the cache
if len(seq.pendingInputs) > 0 {
seq.cache.Inputs = append(seq.cache.Inputs, seq.pendingInputs...)
seq.pendingInputs = []*input.Input{}
@@ -655,9 +692,10 @@ func (s *Server) computeBatch(activeBatch batchState) {
if len(seq.inputs) != 0 {
if !s.cache.enabled {
s.hardErrCh <- fmt.Errorf("caching disabled but unable to fit entire input in a batch")
s.mu.Unlock()
return
}
// Record so we can skip during Decode
promptProcessing[i] = true
continue
}
@@ -670,25 +708,37 @@ func (s *Server) computeBatch(activeBatch batchState) {
// At this point the seqs are ready for forwardBatch to move forward so unblock
s.mu.Unlock()
slog.Log(context.TODO(), logutil.LevelTrace, "computeBatch: signaling computeStartedCh", "batchID", bs.id)
bs.computeStartedCh <- struct{}{}
activeBatch.batch.Inputs.SetValueFromIntSlice(batchInputs)
activeBatch.ctx.ComputeWithNotify(
func() {
logutil.Trace("computeBatch: signaling computeStartedCh", "batchID", activeBatch.id)
activeBatch.computeStartedCh <- struct{}{}
},
activeBatch.modelOutput)
bs.modelInput.BackendSetFromIntSlice(batchInputs)
bs.ctx.Compute(bs.modelOutput)
logits := bs.modelOutput.Floats()
outputs := activeBatch.modelOutput.Floats()
logutil.Trace("computeBatch: logits ready", "batchID", activeBatch.id)
slog.Log(context.TODO(), logutil.LevelTrace, "computeBatch: logits ready", "batchID", bs.id)
s.mu.Lock()
defer s.mu.Unlock()
logutil.Trace("computeBatch: decoding", "batchID", activeBatch.id)
slog.Log(context.TODO(), logutil.LevelTrace, "computeBatch: decoding", "batchID", bs.id)
for i, seq := range s.seqs {
if seq == nil || nextBatchTokens[i] == nil {
if seq == nil {
continue
}
// Skip over any newly added sequences
if bs.seqs[i] == nil {
continue
}
// Detect if the sequence we're processing has already been completed and replaced
// with a new sequence
if seq != bs.seqs[i] {
slog.Log(context.TODO(), logutil.LevelTrace, "computeBatch: sequence replaced, discarding its results", "batchID", bs.id, "seqIdx", i)
continue
}
// don't sample prompt processing
if promptProcessing[i] {
continue
}
@@ -698,15 +748,16 @@ func (s *Server) computeBatch(activeBatch batchState) {
// if done processing the prompt, generate an embedding and return
if seq.embeddingOnly {
seq.embedding <- outputs
s.removeSequence(i, llm.DoneReasonStop)
// TODO(jessegross): Embedding support
slog.Warn("generation of embedding outputs not yet supported", "id", bs.id, "seqIdx", i)
s.finishSequence(i, llm.DoneReasonStop)
continue
}
// sample a token
vocabSize := len(outputs) / len(activeBatch.batch.Outputs)
logutil.Trace("computeBatch: vocab details", "batchID", activeBatch.id, "seqIdx", i, "len(logits)", len(outputs), "len(activeBatch.batch.Outputs)", len(activeBatch.batch.Outputs), "vocabSize", vocabSize, "iBatches", iBatches)
token, err := seq.sampler.Sample(outputs[iBatches[i]*vocabSize : (iBatches[i]+1)*vocabSize])
vocabSize := len(logits) / len(bs.batch.Outputs)
slog.Log(context.TODO(), logutil.LevelTrace, "computeBatch: vocab details", "batchID", bs.id, "seqIdx", i, "len(logits)", len(logits), "len(bs.batch.Outputs)", len(bs.batch.Outputs), "vocabSize", vocabSize, "seq.iBatch", seq.iBatch)
token, err := seq.sampler.Sample(logits[iBatches[i]*vocabSize : (iBatches[i]+1)*vocabSize])
if err != nil {
s.hardErrCh <- fmt.Errorf("failed to sample token: %w", err)
return
@@ -719,8 +770,8 @@ func (s *Server) computeBatch(activeBatch batchState) {
// TODO (jmorganca): we should send this back
// as it's important for the /api/generate context
// seq.responses <- piece
logutil.Trace("computeBatch: EOS", "batchID", activeBatch.id, "seqIdx", i)
s.removeSequence(i, llm.DoneReasonStop)
slog.Log(context.TODO(), logutil.LevelTrace, "computeBatch: EOS", "batchID", bs.id, "seqIdx", i)
s.finishSequence(i, llm.DoneReasonStop)
continue
}
@@ -730,6 +781,15 @@ func (s *Server) computeBatch(activeBatch batchState) {
return
}
if nextBatchTokens[i] == nil {
slog.Error("batch corrupted", "id", bs.id, "batch", bs.batch, "seqIdx", i, "seq", seq)
s.hardErrCh <- fmt.Errorf("expected a single token during decode")
return
}
// fill in the final selected token value to replace the placeholder in the next batch
// nextBatchTokensWritten++
seq.pendingResponses = append(seq.pendingResponses, piece)
sequence := strings.Join(seq.pendingResponses, "")
@@ -756,7 +816,7 @@ func (s *Server) computeBatch(activeBatch batchState) {
seq.cache.Inputs = seq.cache.Inputs[:tokenLen]
s.removeSequence(i, llm.DoneReasonStop)
s.finishSequence(i, llm.DoneReasonStop)
continue
}
@@ -769,7 +829,7 @@ func (s *Server) computeBatch(activeBatch batchState) {
}
if !flushPending(seq) {
s.removeSequence(i, llm.DoneReasonConnectionClosed)
s.finishSequence(i, llm.DoneReasonConnectionClosed)
}
}
}
@@ -781,7 +841,14 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
return
}
tokenParser := parser.NewTokenParser(req.ParserType, req.PrefillString)
var harmonyMessageHandler *harmony.HarmonyMessageHandler
var harmonyToolParser *harmony.HarmonyToolCallAccumulator
if req.FunctionNameMap != nil {
harmonyMessageHandler = harmony.NewHarmonyMessageHandler()
harmonyMessageHandler.FunctionNameMap = req.FunctionNameMap
harmonyMessageHandler.HarmonyParser.AddImplicitStartOrPrefill(req.PrefillContent)
harmonyToolParser = harmonyMessageHandler.CreateToolParser()
}
if req.Options == nil {
opts := api.DefaultOptions()
@@ -844,7 +911,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
found := false
for i, sq := range s.seqs {
if sq == nil {
seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, true)
seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs)
if err != nil {
s.mu.Unlock()
s.seqsSem.Release(1)
@@ -874,12 +941,10 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
case content, ok := <-seq.responses:
if ok {
var thinking string
var err error
content, thinking, err = tokenParser.AddContent(content)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
close(seq.quit)
return
if harmonyMessageHandler != nil {
var toolContent string
content, thinking, toolContent = harmonyMessageHandler.AddContent(content, harmonyToolParser)
harmonyToolParser.Add(toolContent)
}
if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{
@@ -893,7 +958,27 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
flusher.Flush()
} else {
toolCalls := tokenParser.Drain()
var toolCalls []api.ToolCall
if harmonyMessageHandler != nil {
toolName, toolContent := harmonyToolParser.Drain()
if toolName != nil {
*toolName = strings.TrimPrefix(*toolName, "functions.")
*toolName = harmonyMessageHandler.FunctionNameMap.OriginalFromConverted(*toolName)
var args api.ToolCallFunctionArguments
if err := json.Unmarshal([]byte(toolContent), &args); err != nil {
http.Error(w, fmt.Sprintf("failed to unmarshal tool call function arguments: %v", err), http.StatusInternalServerError)
close(seq.quit)
return
}
toolCalls = append(toolCalls, api.ToolCall{
Function: api.ToolCallFunction{
Name: *toolName,
Arguments: args,
},
})
}
}
if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{
ToolCalls: toolCalls,
Done: true,
@@ -912,67 +997,6 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
}
}
func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
if s.model.Backend().Config().Uint("pooling_type", math.MaxUint32) == math.MaxUint32 {
http.Error(w, "this model does not support embeddings", http.StatusNotImplemented)
return
}
var req llm.EmbeddingRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, fmt.Sprintf("bad request: %s", err), http.StatusBadRequest)
return
}
w.Header().Set("Content-Type", "application/json")
seq, err := s.NewSequence(req.Content, nil, NewSequenceParams{embedding: true})
if err != nil {
http.Error(w, fmt.Sprintf("failed to create new sequence: %v", err), http.StatusInternalServerError)
return
}
if err := s.seqsSem.Acquire(r.Context(), 1); err != nil {
if errors.Is(err, context.Canceled) {
slog.Info("aborting embedding request due to client closing the connection")
} else {
http.Error(w, fmt.Sprintf("failed to acquire semaphore: %v", err), http.StatusInternalServerError)
}
return
}
s.mu.Lock()
found := false
for i, sq := range s.seqs {
if sq == nil {
seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, false)
if err != nil {
s.mu.Unlock()
s.seqsSem.Release(1)
http.Error(w, fmt.Sprintf("failed to load cache: %v", err), http.StatusInternalServerError)
return
}
s.seqs[i] = seq
s.cond.Signal()
found = true
break
}
}
s.mu.Unlock()
if !found {
s.seqsSem.Release(1)
http.Error(w, "could not find an available sequence", http.StatusInternalServerError)
return
}
if err := json.NewEncoder(w).Encode(&llm.EmbeddingResponse{
Embedding: <-seq.embedding,
}); err != nil {
http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
}
}
func (s *Server) health(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(&llm.ServerStatusResponse{
@@ -1100,13 +1124,9 @@ func (s *Server) allocModel(
// Convert memory allocation panics to errors
defer func() {
if r := recover(); r != nil {
debug.PrintStack()
if err, ok := r.(error); ok {
var noMem ml.ErrNoMem
if errors.As(err, &noMem) {
panicErr = noMem
} else {
panic(r)
}
panicErr = err
} else {
panic(r)
}
@@ -1293,7 +1313,10 @@ func Execute(args []string) error {
mux := http.NewServeMux()
// TODO: support embeddings
mux.HandleFunc("POST /load", server.load)
mux.HandleFunc("POST /embedding", server.embeddings)
mux.HandleFunc("POST /embedding", func(w http.ResponseWriter, r *http.Request) {
http.Error(w, "this model does not support embeddings", http.StatusNotImplemented)
})
mux.HandleFunc("POST /completion", server.completion)
mux.HandleFunc("GET /health", server.health)

View File

@@ -78,7 +78,7 @@ function checkEnv() {
}
function buildCPU() {
function buildOllama() {
mkdir -Force -path "${script:DIST_DIR}\"
if ($script:ARCH -ne "arm64") {
Remove-Item -ea 0 -recurse -force -path "${script:SRC_DIR}\dist\windows-${script:ARCH}"
@@ -90,72 +90,20 @@ function buildCPU() {
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
& cmake --install build --component CPU --strip
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
}
}
function buildCUDA11() {
# CUDA v11 claims to be compatible with MSVC 2022, but the latest updates are no longer compatible
# 19.40 is the last compiler version that works, but recent udpates are 19.43
# So this pins to MSVC 2019 for best compatibility
mkdir -Force -path "${script:DIST_DIR}\"
if ($script:ARCH -ne "arm64") {
$hashEnv = @{}
Get-ChildItem env: | foreach { $hashEnv[$_.Name] = $_.Value }
if ("$script:CUDA_DIRS".Contains("v11")) {
$hashEnv.Keys | foreach { if ($_.Contains("CUDA_PATH_V11")) { $x=$hashEnv[$_]; if (test-path -literalpath "$x\bin\nvcc.exe" ) { $cuda=$x} }}
write-host "Building CUDA v11 backend libraries $cuda"
$env:CUDAToolkit_ROOT=$cuda
& cmake --fresh --preset "CUDA 11" -T cuda="$cuda" -DCMAKE_CUDA_COMPILER="$cuda\bin\nvcc.exe" -G "Visual Studio 16 2019" --install-prefix $script:DIST_DIR -DOLLAMA_RUNNER_DIR="cuda_v11"
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
& cmake --build --preset "CUDA 11" --config Release --parallel $script:JOBS
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
& cmake --install build --component "CUDA" --strip
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
}
}
}
function buildCUDA12() {
mkdir -Force -path "${script:DIST_DIR}\"
if ($script:ARCH -ne "arm64") {
$hashEnv = @{}
Get-ChildItem env: | foreach { $hashEnv[$_.Name] = $_.Value }
if ("$script:CUDA_DIRS".Contains("v12.8")) {
$hashEnv.Keys | foreach { if ($_.Contains("CUDA_PATH_V12_8")) { $x=$hashEnv[$_]; if (test-path -literalpath "$x\bin\nvcc.exe" ) { $cuda=$x} }}
write-host "Building CUDA v12 backend libraries $cuda"
$env:CUDAToolkit_ROOT=$cuda
& cmake --fresh --preset "CUDA 12" -T cuda="$cuda" --install-prefix $script:DIST_DIR -DOLLAMA_RUNNER_DIR="cuda_v12"
if ("$script:CUDA_DIRS".Contains("v12")) {
$hashEnv.Keys | foreach { if ($_.Contains("CUDA_PATH_V12")) { $v12="$_" }}
$env:CUDAToolkit_ROOT=$hashEnv[$v12]
write-host "Building CUDA v12 backend libraries"
& cmake --fresh --preset "CUDA 12" --install-prefix $script:DIST_DIR
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
& cmake --build --preset "CUDA 12" --config Release --parallel $script:JOBS
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
& cmake --install build --component "CUDA" --strip
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
}
}
}
function buildCUDA13() {
mkdir -Force -path "${script:DIST_DIR}\"
if ($script:ARCH -ne "arm64") {
$hashEnv = @{}
Get-ChildItem env: | foreach { $hashEnv[$_.Name] = $_.Value }
if ("$script:CUDA_DIRS".Contains("v13")) {
$hashEnv.Keys | foreach { if ($_.Contains("CUDA_PATH_V13")) { $x=$hashEnv[$_]; if (test-path -literalpath "$x\bin\nvcc.exe" ) { $cuda=$x} }}
$env:CUDAToolkit_ROOT=$cuda
write-host "Building CUDA v13 backend libraries $cuda"
& cmake --fresh --preset "CUDA 13" -T cuda="$cuda" --install-prefix $script:DIST_DIR -DOLLAMA_RUNNER_DIR="cuda_v13"
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
& cmake --build --preset "CUDA 13" --config Release --parallel $script:JOBS
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
& cmake --install build --component "CUDA" --strip
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
}
}
}
function buildROCm() {
mkdir -Force -path "${script:DIST_DIR}\"
if ($script:ARCH -ne "arm64") {
if ($env:HIP_PATH) {
write-host "Building ROCm backend libraries"
if (-Not (get-command -ErrorAction silent ninja)) {
@@ -181,10 +129,6 @@ function buildROCm() {
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
}
}
}
function buildOllama() {
mkdir -Force -path "${script:DIST_DIR}\"
write-host "Building ollama CLI"
& go build -trimpath -ldflags "-s -w -X=github.com/ollama/ollama/version.Version=$script:VERSION -X=github.com/ollama/ollama/server.mode=release" .
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
@@ -292,10 +236,6 @@ function distZip() {
checkEnv
try {
if ($($args.count) -eq 0) {
buildCPU
buildCUDA12
buildCUDA13
buildROCm
buildOllama
buildApp
gatherDependencies

View File

@@ -36,7 +36,6 @@ import (
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/logutil"
"github.com/ollama/ollama/openai"
"github.com/ollama/ollama/parser"
"github.com/ollama/ollama/server/internal/client/ollama"
"github.com/ollama/ollama/server/internal/registry"
"github.com/ollama/ollama/template"
@@ -178,7 +177,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
}
// expire the runner
if req.Prompt == "" && req.KeepAlive != nil && req.KeepAlive.Duration == 0 {
if req.Prompt == "" && req.KeepAlive != nil && int(req.KeepAlive.Seconds()) == 0 {
s.sched.expireRunner(m)
c.JSON(http.StatusOK, api.GenerateResponse{
@@ -197,17 +196,6 @@ func (s *Server) GenerateHandler(c *gin.Context) {
}
useHarmony := harmony.ShouldUseHarmony(m.Config.ModelFamily, m.Template) && !req.Raw
var parserType parser.TokenParserType
if useHarmony {
parserType = parser.TokenParserTypeHarmony
} else {
parserType = parser.TokenParserTypeDefault
}
var functionNameMap *harmony.FunctionNameMap
if useHarmony {
functionNameMap = harmony.NewFunctionNameMap()
}
// Validate Think value: string values currently only allowed for gptoss models
if req.Think != nil && req.Think.IsString() && !useHarmony {
@@ -350,19 +338,16 @@ func (s *Server) GenerateHandler(c *gin.Context) {
var sb strings.Builder
defer close(ch)
if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
Prompt: prompt,
Images: images,
Format: req.Format,
Options: opts,
ParserType: parserType,
Prompt: prompt,
Images: images,
Format: req.Format,
Options: opts,
}, func(cr llm.CompletionResponse) {
res := api.GenerateResponse{
Model: req.Model,
CreatedAt: time.Now().UTC(),
Response: cr.Content,
Done: cr.Done,
Thinking: cr.Thinking,
ToolCalls: cr.ToolCalls,
Metrics: api.Metrics{
PromptEvalCount: cr.PromptEvalCount,
PromptEvalDuration: cr.PromptEvalDuration,
@@ -371,22 +356,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
},
}
if res.Done {
res.DoneReason = cr.DoneReason.String()
res.TotalDuration = time.Since(checkpointStart)
res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
}
if useHarmony {
for i, tool := range res.ToolCalls {
res.ToolCalls[i].Function.Name = functionNameMap.OriginalFromConverted(tool.Function.Name)
}
if res.Response != "" || res.Thinking != "" || len(res.ToolCalls) > 0 || res.Done {
ch <- res
}
return
}
if thinkingState != nil {
if !useHarmony && thinkingState != nil {
thinking, content := thinkingState.AddContent(cr.Content)
res.Thinking = thinking
res.Response = content
@@ -397,6 +367,10 @@ func (s *Server) GenerateHandler(c *gin.Context) {
}
if cr.Done {
res.DoneReason = cr.DoneReason.String()
res.TotalDuration = time.Since(checkpointStart)
res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
if !req.Raw {
tokens, err := r.Tokenize(c.Request.Context(), prompt+sb.String())
if err != nil {
@@ -558,12 +532,7 @@ func (s *Server) EmbedHandler(c *gin.Context) {
if err != nil {
return err
}
// TODO: this first normalization should be done by the model
embedding = normalize(embedding)
if req.Dimensions > 0 && req.Dimensions < len(embedding) {
embedding = normalize(embedding[:req.Dimensions])
}
embeddings[i] = embedding
embeddings[i] = normalize(embedding)
return nil
})
}
@@ -589,7 +558,11 @@ func normalize(vec []float32) []float32 {
sum += v * v
}
norm := float32(1.0 / max(math.Sqrt(float64(sum)), 1e-12))
norm := float32(0.0)
if sum > 0 {
norm = float32(1.0 / math.Sqrt(float64(sum)))
}
for i := range vec {
vec[i] *= norm
}
@@ -1527,7 +1500,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
}
// expire the runner
if len(req.Messages) == 0 && req.KeepAlive != nil && req.KeepAlive.Duration == 0 {
if len(req.Messages) == 0 && req.KeepAlive != nil && int(req.KeepAlive.Seconds()) == 0 {
model, err := GetModel(req.Model)
if err != nil {
switch {
@@ -1600,20 +1573,29 @@ func (s *Server) ChatHandler(c *gin.Context) {
msgs = filterThinkTags(msgs, m)
useHarmony := harmony.ShouldUseHarmony(m.Config.ModelFamily, m.Template)
var parserType parser.TokenParserType
if useHarmony {
parserType = parser.TokenParserTypeHarmony
} else {
parserType = parser.TokenParserTypeDefault
}
processedTools := req.Tools
var functionNameMap *harmony.FunctionNameMap
var prefillString string
// TODO(parthsareen): this can be abstracted to not be model specific and potentially moved to the runner
var prefillContentOrThinking *bool
if useHarmony {
prefillString = harmony.Prefill(msgs[len(msgs)-1])
functionNameMap = harmony.NewFunctionNameMap()
var lastMessage *api.Message
if len(msgs) > 0 {
lastMessage = &msgs[len(msgs)-1]
}
// prefill content or thinking flag if the last message is an assistant message
if lastMessage != nil && lastMessage.Role == "assistant" {
if lastMessage.Content != "" {
trueVal := true
// true sets content to be prefilled
prefillContentOrThinking = &trueVal
} else if lastMessage.Thinking != "" {
// false sets thinking to be prefilled
falseVal := false
prefillContentOrThinking = &falseVal
}
}
// make a copy of tools to pass to the chat prompt. Function names may be
// renamed to be valid Harmony function names.
processedTools = make([]api.Tool, len(req.Tools))
@@ -1656,10 +1638,6 @@ func (s *Server) ChatHandler(c *gin.Context) {
OpeningTag: openingTag,
ClosingTag: closingTag,
}
if strings.HasSuffix(strings.TrimSpace(prompt), openingTag) {
thinkingState.AddContent(openingTag)
}
}
var toolParser *tools.Parser
@@ -1672,12 +1650,12 @@ func (s *Server) ChatHandler(c *gin.Context) {
defer close(ch)
if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
Prompt: prompt,
Images: images,
Format: req.Format,
Options: opts,
ParserType: parserType,
PrefillString: prefillString,
Prompt: prompt,
Images: images,
Format: req.Format,
Options: opts,
FunctionNameMap: functionNameMap,
PrefillContent: prefillContentOrThinking,
}, func(r llm.CompletionResponse) {
res := api.ChatResponse{
Model: req.Model,
@@ -1698,9 +1676,6 @@ func (s *Server) ChatHandler(c *gin.Context) {
}
if useHarmony {
for i, tool := range res.Message.ToolCalls {
res.Message.ToolCalls[i].Function.Name = functionNameMap.OriginalFromConverted(tool.Function.Name)
}
// only send messages with meaningful content (empty messages confuse clients)
if res.Message.Content != "" || res.Message.Thinking != "" || len(res.Message.ToolCalls) > 0 || res.Done {
ch <- res

View File

@@ -969,233 +969,3 @@ func TestGenerate(t *testing.T) {
}
})
}
func TestChatWithPromptEndingInThinkTag(t *testing.T) {
gin.SetMode(gin.TestMode)
// Helper to create a standard thinking test setup
setupThinkingTest := func(t *testing.T) (*mockRunner, *Server) {
mock := &mockRunner{
CompletionResponse: llm.CompletionResponse{
Done: true,
DoneReason: llm.DoneReasonStop,
PromptEvalCount: 1,
PromptEvalDuration: 1,
EvalCount: 1,
EvalDuration: 1,
},
}
s := &Server{
sched: &Scheduler{
pendingReqCh: make(chan *LlmRequest, 1),
finishedReqCh: make(chan *LlmRequest, 1),
expiredCh: make(chan *runnerRef, 1),
unloadedCh: make(chan any, 1),
loaded: make(map[string]*runnerRef),
newServerFn: newMockServer(mock),
getGpuFn: discover.GetGPUInfo,
getCpuFn: discover.GetCPUInfo,
reschedDelay: 250 * time.Millisecond,
loadFn: func(req *LlmRequest, _ *ggml.GGML, _ discover.GpuInfoList, _ bool) bool {
time.Sleep(time.Millisecond)
req.successCh <- &runnerRef{llama: mock}
return false
},
},
}
go s.sched.Run(t.Context())
// Create a model with thinking support
_, digest := createBinFile(t, ggml.KV{
"general.architecture": "llama",
"llama.block_count": uint32(1),
"llama.context_length": uint32(8192),
"llama.embedding_length": uint32(4096),
"llama.attention.head_count": uint32(32),
"llama.attention.head_count_kv": uint32(8),
"tokenizer.ggml.tokens": []string{""},
"tokenizer.ggml.scores": []float32{0},
"tokenizer.ggml.token_type": []int32{0},
}, []*ggml.Tensor{
{Name: "token_embd.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.ffn_down.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.ffn_gate.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.ffn_up.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.ffn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_k.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_q.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_v.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
})
// Create model with thinking template that adds <think> at the end
w := createRequest(t, s.CreateHandler, api.CreateRequest{
Model: "test-thinking",
Files: map[string]string{"file.gguf": digest},
Template: `{{- range .Messages }}
{{- if eq .Role "user" }}user: {{ .Content }}
{{ else if eq .Role "assistant" }}assistant: {{ if .Thinking }}<think>{{ .Thinking }}</think>{{ end }}{{ .Content }}
{{ end }}{{ end }}<think>`,
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d", w.Code)
}
return mock, s
}
mock, s := setupThinkingTest(t)
// Helper to test chat responses
testChatRequest := func(t *testing.T, name string, userContent string, modelResponse string, expectedThinking string, expectedContent string, think bool) {
t.Run(name, func(t *testing.T) {
mock.CompletionResponse = llm.CompletionResponse{
Content: modelResponse,
Done: true,
DoneReason: llm.DoneReasonStop,
PromptEvalCount: 1,
PromptEvalDuration: 1,
EvalCount: 1,
EvalDuration: 1,
}
mock.CompletionFn = nil
streamRequest := false
req := api.ChatRequest{
Model: "test-thinking",
Messages: []api.Message{
{Role: "user", Content: userContent},
},
Stream: &streamRequest,
}
if think {
req.Think = &api.ThinkValue{Value: think}
}
w := createRequest(t, s.ChatHandler, req)
if w.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d", w.Code)
}
var resp api.ChatResponse
if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
t.Fatal(err)
}
if resp.Message.Thinking != expectedThinking {
t.Errorf("expected thinking %q, got %q", expectedThinking, resp.Message.Thinking)
}
if resp.Message.Content != expectedContent {
t.Errorf("expected content %q, got %q", expectedContent, resp.Message.Content)
}
})
}
// Test cases - Note: Template adds <think> at the end, and leading whitespace after <think> is eaten by the parser
testChatRequest(t, "basic thinking response",
"Help me solve this problem",
" Let me think about this step by step... </think> The answer is 42.",
"Let me think about this step by step... ",
"The answer is 42.",
true)
testChatRequest(t, "thinking with multiple sentences",
"Explain quantum computing",
" First, I need to understand the basics. Quantum bits can be in superposition. </think> Quantum computing uses quantum mechanics principles.",
"First, I need to understand the basics. Quantum bits can be in superposition. ",
"Quantum computing uses quantum mechanics principles.",
true)
testChatRequest(t, "no thinking content",
"What is 2+2?",
"</think> The answer is 4.",
"",
"The answer is 4.",
true)
testChatRequest(t, "thinking disabled but template still adds think tag",
"Simple question",
" My thoughts </think> The answer.",
"",
" My thoughts </think> The answer.",
false)
// Test streaming response with template-added <think>
t.Run("streaming with thinking", func(t *testing.T) {
var wg sync.WaitGroup
wg.Add(1)
mock.CompletionFn = func(ctx context.Context, r llm.CompletionRequest, fn func(r llm.CompletionResponse)) error {
defer wg.Done()
// Verify the prompt ends with <think> due to template
if !strings.HasSuffix(r.Prompt, "<think>") {
t.Errorf("expected prompt to end with <think>, got: %q", r.Prompt)
}
// Simulate streaming chunks
responses := []llm.CompletionResponse{
{Content: " I need to consider", Done: false, PromptEvalCount: 1, PromptEvalDuration: 1},
{Content: " multiple factors here...", Done: false, PromptEvalCount: 1, PromptEvalDuration: 1},
{Content: " </think> Based on my analysis,", Done: false, PromptEvalCount: 1, PromptEvalDuration: 1},
{Content: " the solution is straightforward.", Done: true, DoneReason: llm.DoneReasonStop, PromptEvalCount: 1, PromptEvalDuration: 1, EvalCount: 1, EvalDuration: 1},
}
for _, resp := range responses {
select {
case <-ctx.Done():
return ctx.Err()
default:
fn(resp)
time.Sleep(10 * time.Millisecond)
}
}
return nil
}
think := true
w := createRequest(t, s.ChatHandler, api.ChatRequest{
Model: "test-thinking",
Messages: []api.Message{{Role: "user", Content: "Analyze this complex problem"}},
Think: &api.ThinkValue{Value: think},
Stream: &stream,
})
wg.Wait()
if w.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d", w.Code)
}
// Parse streaming responses
decoder := json.NewDecoder(w.Body)
var allThinking, allContent strings.Builder
for {
var resp api.ChatResponse
if err := decoder.Decode(&resp); err == io.EOF {
break
} else if err != nil {
t.Fatal(err)
}
allThinking.WriteString(resp.Message.Thinking)
allContent.WriteString(resp.Message.Content)
}
// Note: Leading whitespace after <think> is eaten by the parser
if got := allThinking.String(); got != "I need to consider multiple factors here... " {
t.Errorf("expected thinking %q, got %q", "I need to consider multiple factors here... ", got)
}
if got := allContent.String(); got != "Based on my analysis, the solution is straightforward." {
t.Errorf("expected content %q, got %q", "Based on my analysis, the solution is straightforward.", got)
}
})
}

View File

@@ -7,6 +7,7 @@ import (
"bytes"
"context"
"encoding/json"
"net/http"
"strings"
"testing"
"time"
@@ -117,7 +118,7 @@ func TestChatHarmonyParserStreamingRealtime(t *testing.T) {
name: "content streams as it arrives",
steps: []step{
{
input: llm.CompletionResponse{Content: "Hello", Done: false},
input: llm.CompletionResponse{Content: "<|message|>Hello", Done: false},
wantContent: "Hello",
},
{
@@ -125,7 +126,7 @@ func TestChatHarmonyParserStreamingRealtime(t *testing.T) {
wantContent: ", world",
},
{
input: llm.CompletionResponse{Content: "!", Done: true, DoneReason: llm.DoneReasonStop},
input: llm.CompletionResponse{Content: "!<|end|>", Done: true, DoneReason: llm.DoneReasonStop},
wantContent: "!",
},
},
@@ -134,15 +135,20 @@ func TestChatHarmonyParserStreamingRealtime(t *testing.T) {
name: "thinking streams separately from content",
steps: []step{
{
input: llm.CompletionResponse{Thinking: "Thinking...", Done: false},
input: llm.CompletionResponse{Content: "<|channel|>analysis<|message|>Thinking...", Done: false},
wantThinking: "Thinking...",
},
{
input: llm.CompletionResponse{Content: "Answer", Done: false},
wantContent: "Answer",
input: llm.CompletionResponse{Content: "<|end|>", Done: false},
// No output expected - just closes the analysis message and resets state to normal
},
{
input: llm.CompletionResponse{Done: true, DoneReason: llm.DoneReasonStop},
input: llm.CompletionResponse{Content: "<|start|>assistant<|message|>Answer", Done: false},
wantContent: "Answer", // After message end, state is reset to normal
},
{
input: llm.CompletionResponse{Content: "<|end|>", Done: true, DoneReason: llm.DoneReasonStop},
// No output expected - just closes the assistant message
},
},
},
@@ -150,16 +156,24 @@ func TestChatHarmonyParserStreamingRealtime(t *testing.T) {
name: "partial tags buffer until complete",
steps: []step{
{
input: llm.CompletionResponse{Thinking: "Deep ", Done: false},
input: llm.CompletionResponse{Content: "<|chan", Done: false},
// No output - partial tag
},
{
input: llm.CompletionResponse{Content: "nel|>analysis<|mess", Done: false},
// No output - still building tags
},
{
input: llm.CompletionResponse{Content: "age|>Deep ", Done: false},
wantThinking: "Deep ",
},
{
input: llm.CompletionResponse{Thinking: "thought", Done: false},
input: llm.CompletionResponse{Content: "thought<|end|>", Done: false},
wantThinking: "thought",
},
{
input: llm.CompletionResponse{Content: "Done", Done: true, DoneReason: llm.DoneReasonStop},
wantContent: "Done",
input: llm.CompletionResponse{Content: "<|start|>assistant<|message|>Done<|end|>", Done: true, DoneReason: llm.DoneReasonStop},
wantContent: "Done", // After message end, state is reset to normal
},
},
},
@@ -167,7 +181,7 @@ func TestChatHarmonyParserStreamingRealtime(t *testing.T) {
name: "simple assistant after analysis",
steps: []step{
{
input: llm.CompletionResponse{Thinking: "Think", Content: "Answer", Done: true, DoneReason: llm.DoneReasonStop},
input: llm.CompletionResponse{Content: "<|channel|>analysis<|message|>Think<|end|><|start|>assistant<|message|>Answer<|end|>", Done: true, DoneReason: llm.DoneReasonStop},
wantContent: "Answer",
wantThinking: "Think",
},
@@ -177,7 +191,7 @@ func TestChatHarmonyParserStreamingRealtime(t *testing.T) {
name: "tool call parsed and returned correctly",
steps: []step{
{
input: llm.CompletionResponse{Content: "The weather is sunny", ToolCalls: []api.ToolCall{{Function: api.ToolCallFunction{Name: "get_weather", Arguments: api.ToolCallFunctionArguments{"location": "San Francisco"}}}}, Done: true, DoneReason: llm.DoneReasonStop},
input: llm.CompletionResponse{Content: "<|channel|>commentary to=functions.get_weather<|message|>{\"location\":\"San Francisco\"}<|end|><|start|>assistant<|message|>The weather is sunny<|end|>", Done: true, DoneReason: llm.DoneReasonStop},
wantContent: "The weather is sunny",
wantToolCalls: []api.ToolCall{
{
@@ -196,10 +210,15 @@ func TestChatHarmonyParserStreamingRealtime(t *testing.T) {
name: "tool call with streaming JSON across chunks",
steps: []step{
{
input: llm.CompletionResponse{Done: false},
input: llm.CompletionResponse{Content: "<|channel|>commentary to=functions.calculate<|message|>{\"expr", Done: false},
// No output yet - incomplete JSON
},
{
input: llm.CompletionResponse{ToolCalls: []api.ToolCall{{Function: api.ToolCallFunction{Name: "calculate", Arguments: api.ToolCallFunctionArguments{"expression": "2+2"}}}}, Done: true},
input: llm.CompletionResponse{Content: "ession\":\"2+", Done: false},
// Still no output - incomplete JSON
},
{
input: llm.CompletionResponse{Content: "2\"}", Done: true},
wantToolCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
@@ -381,9 +400,9 @@ func TestChatHarmonyParserStreamingSimple(t *testing.T) {
gin.SetMode(gin.TestMode)
mockResponses := []llm.CompletionResponse{
{Content: "First ", Done: false},
{Content: "<|message|>First ", Done: false},
{Content: "chunk ", Done: false},
{Content: "here", Done: true, DoneReason: llm.DoneReasonStop},
{Content: "here<|end|>", Done: true, DoneReason: llm.DoneReasonStop},
}
mock := mockRunner{
@@ -488,3 +507,189 @@ func TestChatHarmonyParserStreamingSimple(t *testing.T) {
t.Errorf("expected at least 2 content chunks for streaming, got %d", contentChunks)
}
}
func TestChatHarmonyParserStreaming(t *testing.T) {
gin.SetMode(gin.TestMode)
type expectedChunk struct {
afterResponse int // Which mock response this chunk should appear after
content string // Expected content in this chunk
thinking string // Expected thinking in this chunk
}
testCases := []struct {
name string
mockResponses []llm.CompletionResponse
expectedChunks []expectedChunk
wantContent string
wantThinking string
}{
{
name: "simple message without thinking",
mockResponses: []llm.CompletionResponse{
{Content: "<|start|>assistant<|message|>Hello, ", Done: false},
{Content: "how can I help?", Done: false},
{Content: "<|end|>", Done: true, DoneReason: llm.DoneReasonStop},
},
expectedChunks: []expectedChunk{
{afterResponse: 1, content: "Hello, "},
{afterResponse: 2, content: "how can I help?"},
},
wantContent: "Hello, how can I help?",
},
{
name: "message with analysis channel for thinking",
mockResponses: []llm.CompletionResponse{
{Content: "<|channel|>analysis<|message|>", Done: false},
{Content: "Let me think ", Done: false},
{Content: "about this problem...", Done: false},
{Content: "<|end|>", Done: false},
{Content: "<|start|>assistant<|message|>", Done: false},
{Content: "The answer ", Done: false},
{Content: "is 42", Done: false},
{Content: "<|end|>", Done: true, DoneReason: llm.DoneReasonStop},
},
expectedChunks: []expectedChunk{
{afterResponse: 2, thinking: "Let me think "},
{afterResponse: 3, thinking: "about this problem..."},
{afterResponse: 6, content: "The answer "},
{afterResponse: 7, content: "is 42"},
},
wantContent: "The answer is 42",
wantThinking: "Let me think about this problem...",
},
{
name: "streaming with partial tags across boundaries",
mockResponses: []llm.CompletionResponse{
{Content: "<|chan", Done: false},
{Content: "nel|>analy", Done: false},
{Content: "sis<|mess", Done: false},
{Content: "age|>Think", Done: false},
{Content: "ing deeply...<|end|>", Done: false},
{Content: "<|start|>assi", Done: false},
{Content: "stant<|message|>Result ", Done: false},
{Content: "computed<|e", Done: false},
{Content: "nd|>", Done: true, DoneReason: llm.DoneReasonStop},
},
expectedChunks: []expectedChunk{
{afterResponse: 4, thinking: "Think"},
{afterResponse: 5, thinking: "ing deeply..."},
{afterResponse: 7, content: "Result "},
{afterResponse: 8, content: "computed"},
},
wantContent: "Result computed",
wantThinking: "Thinking deeply...",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
// Channel to synchronize mock responses with chunk verification
responsesSent := make(chan int, len(tc.mockResponses))
mock := mockRunner{
CompletionFn: func(ctx context.Context, r llm.CompletionRequest, fn func(llm.CompletionResponse)) error {
// Send mock responses one at a time, notifying when each is sent
for i, resp := range tc.mockResponses {
fn(resp)
responsesSent <- i + 1
}
close(responsesSent)
return nil
},
}
s := Server{
sched: &Scheduler{
pendingReqCh: make(chan *LlmRequest, 1),
finishedReqCh: make(chan *LlmRequest, 1),
expiredCh: make(chan *runnerRef, 1),
unloadedCh: make(chan any, 1),
loaded: make(map[string]*runnerRef),
newServerFn: newMockServer(&mock),
getGpuFn: discover.GetGPUInfo,
getCpuFn: discover.GetCPUInfo,
reschedDelay: 250 * time.Millisecond,
loadFn: func(req *LlmRequest, _ *ggml.GGML, _ discover.GpuInfoList, _ bool) bool {
req.successCh <- &runnerRef{
llama: &mock,
}
return false
},
},
}
go s.sched.Run(t.Context())
// Create a minimal model
_, digest := createHarmonyTestModel(t)
// Create model with passthrough template
stream := false
w := createRequest(t, s.CreateHandler, api.CreateRequest{
Model: "harmony-test",
Files: map[string]string{"file.gguf": digest},
Template: `<|start|><|end|>{{ with .Tools }}{{ end }}{{ .Prompt }}`,
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Fatalf("failed to create model: %d", w.Code)
}
// Test chat endpoint with streaming
streamTrue := true
w = createRequest(t, s.ChatHandler, api.ChatRequest{
Model: "harmony-test",
Messages: []api.Message{{Role: "user", Content: "Hello"}},
Stream: &streamTrue,
Tools: getTestTools(),
})
if w.Code != http.StatusOK {
t.Fatalf("chat request failed: %d - %s", w.Code, w.Body.String())
}
// Parse streaming response
var chunks []api.ChatResponse
var content, thinking strings.Builder
decoder := json.NewDecoder(w.Body)
for decoder.More() {
var chunk api.ChatResponse
if err := decoder.Decode(&chunk); err != nil {
t.Fatalf("failed to decode chunk: %v", err)
}
chunks = append(chunks, chunk)
// Accumulate content and thinking from each chunk
content.WriteString(chunk.Message.Content)
thinking.WriteString(chunk.Message.Thinking)
// Debug output
t.Logf("Chunk %d: content=%q thinking=%q done=%v", len(chunks), chunk.Message.Content, chunk.Message.Thinking, chunk.Done)
}
// Verify we got streaming chunks
if len(chunks) == 0 {
t.Fatal("expected streaming chunks, got none")
}
gotContent := content.String()
gotThinking := thinking.String()
if gotContent != tc.wantContent {
t.Errorf("content mismatch: got %q, want %q", gotContent, tc.wantContent)
}
if gotThinking != tc.wantThinking {
t.Errorf("thinking mismatch: got %q, want %q", gotThinking, tc.wantThinking)
}
// Verify last chunk has done=true
lastChunk := chunks[len(chunks)-1]
if !lastChunk.Done {
t.Error("expected last chunk to have done=true")
}
})
}
}

View File

@@ -103,9 +103,7 @@ func eat(s *Parser) (string, string, bool) {
// note that we use the original content, not the trimmed one because we
// don't want to eat any whitespace in the real content if there were no
// thinking tags
untrimmed := s.acc.String()
s.acc.Reset()
return "", untrimmed, false
return "", s.acc.String(), false
}
case thinkingState_ThinkingStartedEatingWhitespace:
trimmed := strings.TrimLeftFunc(s.acc.String(), unicode.IsSpace)

View File

@@ -58,15 +58,6 @@ func TestThinkingStreaming(t *testing.T) {
wantContent: " abc",
wantStateAfter: thinkingState_ThinkingDone,
},
// regression test for a bug where we were transitioning directly to
// ThinkingDone without clearing the buffer. This would cuase the first
// step to be outputted twice
{
input: "def",
wantThinking: "",
wantContent: "def",
wantStateAfter: thinkingState_ThinkingDone,
},
},
},
{

View File

@@ -224,45 +224,22 @@ func findArguments(buffer []byte) (map[string]any, int) {
return nil, 0
}
start := -1
var braces int
var inString, escaped bool
for i := range buffer {
c := buffer[i]
if escaped {
escaped = false
continue
}
if c == '\\' {
escaped = true
continue
}
if c == '"' {
inString = !inString
continue
}
if inString {
continue
}
var start int = -1
for i, c := range buffer {
if c == '{' {
if braces == 0 {
start = i
}
braces++
} else if c == '}' {
} else if c == '}' && braces > 0 {
braces--
if braces == 0 && start != -1 {
object := buffer[start : i+1]
var data map[string]any
if err := json.Unmarshal(object, &data); err != nil {
// not a valid object, keep looking
start = -1
continue
}
@@ -305,10 +282,6 @@ func findArguments(buffer []byte) (map[string]any, int) {
return data, i
}
if braces < 0 {
braces = 0
}
}
}

View File

@@ -1,7 +1,6 @@
package tools
import (
"strings"
"testing"
"text/template"
@@ -41,7 +40,13 @@ func TestParser(t *testing.T) {
Function: api.ToolFunction{
Name: "get_temperature",
Description: "Retrieve the temperature for a given location",
Parameters: api.ToolFunctionParameters{
Parameters: struct {
Type string `json:"type"`
Defs any `json:"$defs,omitempty"`
Items any `json:"items,omitempty"`
Required []string `json:"required"`
Properties map[string]api.ToolProperty `json:"properties"`
}{
Type: "object",
Required: []string{"city"},
Properties: map[string]api.ToolProperty{
@@ -63,7 +68,13 @@ func TestParser(t *testing.T) {
Function: api.ToolFunction{
Name: "get_conditions",
Description: "Retrieve the current weather conditions for a given location",
Parameters: api.ToolFunctionParameters{
Parameters: struct {
Type string `json:"type"`
Defs any `json:"$defs,omitempty"`
Items any `json:"items,omitempty"`
Required []string `json:"required"`
Properties map[string]api.ToolProperty `json:"properties"`
}{
Type: "object",
Properties: map[string]api.ToolProperty{
"location": {
@@ -93,7 +104,13 @@ func TestParser(t *testing.T) {
Function: api.ToolFunction{
Name: "get_address",
Description: "Get the address of a given location",
Parameters: api.ToolFunctionParameters{
Parameters: struct {
Type string `json:"type"`
Defs any `json:"$defs,omitempty"`
Items any `json:"items,omitempty"`
Required []string `json:"required"`
Properties map[string]api.ToolProperty `json:"properties"`
}{
Type: "object",
Properties: map[string]api.ToolProperty{
"location": {
@@ -109,7 +126,13 @@ func TestParser(t *testing.T) {
Function: api.ToolFunction{
Name: "add",
Description: "Add two numbers",
Parameters: api.ToolFunctionParameters{
Parameters: struct {
Type string `json:"type"`
Defs any `json:"$defs,omitempty"`
Items any `json:"items,omitempty"`
Required []string `json:"required"`
Properties map[string]api.ToolProperty `json:"properties"`
}{
Type: "object",
Properties: map[string]api.ToolProperty{
"a": {
@@ -1117,163 +1140,11 @@ func TestFindArguments(t *testing.T) {
},
{
name: "deepseek",
buffer: []byte(`"arguments": {"location": "Tokyo"}}</tool_call>`),
buffer: []byte(`", "arguments": {"location": "Tokyo"}}</tool_call>`),
want: map[string]any{
"location": "Tokyo",
},
},
{
name: "string with braces",
buffer: []byte(`{"name": "process_code", "arguments": {"code": "if (x > 0) { return true; }"}}`),
want: map[string]any{
"code": "if (x > 0) { return true; }",
},
},
{
name: "string with nested json",
buffer: []byte(`{"name": "send_data", "arguments": {"payload": "{\"nested\": {\"key\": \"value\"}}"}}`),
want: map[string]any{
"payload": `{"nested": {"key": "value"}}`,
},
},
{
name: "string with escaped quotes and braces",
buffer: []byte(`{"name": "analyze", "arguments": {"text": "The JSON is: {\"key\": \"val{ue}\"}"}}`),
want: map[string]any{
"text": `The JSON is: {"key": "val{ue}"}`,
},
},
{
name: "multiple objects with string containing braces",
buffer: []byte(`{"name": "test", "arguments": {"query": "find } in text"}} {"name": "other"}`),
want: map[string]any{
"query": "find } in text",
},
},
{
name: "unmatched closing brace in string",
buffer: []byte(`{"name": "search", "arguments": {"pattern": "regex: }"}}`),
want: map[string]any{
"pattern": "regex: }",
},
},
{
name: "complex nested with mixed braces",
buffer: []byte(`{"name": "analyze", "arguments": {"data": "{\"items\": [{\"value\": \"}\"}, {\"code\": \"if (x) { return y; }\"}]}"}}`),
want: map[string]any{
"data": `{"items": [{"value": "}"}, {"code": "if (x) { return y; }"}]}`,
},
},
{
name: "string with newline and braces",
buffer: []byte(`{"name": "format", "arguments": {"template": "{\n \"key\": \"value\"\n}"}}`),
want: map[string]any{
"template": "{\n \"key\": \"value\"\n}",
},
},
{
name: "string with unicode escape",
buffer: []byte(`{"name": "test", "arguments": {"text": "Unicode: \u007B and \u007D"}}`),
want: map[string]any{
"text": "Unicode: { and }",
},
},
{
name: "array arguments",
buffer: []byte(`{"name": "batch", "arguments": ["item1", "item2", "{\"nested\": true}"]}`),
want: nil, // This should return nil because arguments is not a map
},
{
name: "escaped backslash before quote",
buffer: []byte(`{"name": "path", "arguments": {"dir": "C:\\Program Files\\{App}\\"}}`),
want: map[string]any{
"dir": `C:\Program Files\{App}\`,
},
},
{
name: "single quotes not treated as string delimiters",
buffer: []byte(`{"name": "query", "arguments": {"sql": "SELECT * FROM users WHERE name = '{admin}'"}}`),
want: map[string]any{
"sql": "SELECT * FROM users WHERE name = '{admin}'",
},
},
{
name: "incomplete json at buffer end",
buffer: []byte(`{"name": "test", "arguments": {"data": "some {"`),
want: nil,
},
{
name: "multiple escaped quotes",
buffer: []byte(`{"name": "echo", "arguments": {"msg": "He said \"Hello {World}\" loudly"}}`),
want: map[string]any{
"msg": `He said "Hello {World}" loudly`,
},
},
{
name: "json with comments style string",
buffer: []byte(`{"name": "code", "arguments": {"snippet": "// This is a comment with { and }"}}`),
want: map[string]any{
"snippet": "// This is a comment with { and }",
},
},
{
name: "consecutive escaped backslashes",
buffer: []byte(`{"name": "test", "arguments": {"path": "C:\\\\{folder}\\\\"}}`),
want: map[string]any{
"path": `C:\\{folder}\\`,
},
},
{
name: "empty string with braces after",
buffer: []byte(`{"name": "test", "arguments": {"a": "", "b": "{value}"}}`),
want: map[string]any{
"a": "",
"b": "{value}",
},
},
{
name: "unicode in key names",
buffer: []byte(`{"name": "test", "arguments": {"key{": "value", "key}": "value2"}}`),
want: map[string]any{
"key{": "value",
"key}": "value2",
},
},
{
name: "very long string with braces",
buffer: []byte(`{"name": "test", "arguments": {"data": "` + strings.Repeat("a{b}c", 100) + `"}}`),
want: map[string]any{
"data": strings.Repeat("a{b}c", 100),
},
},
{
name: "tab characters and braces",
buffer: []byte(`{"name": "test", "arguments": {"code": "\tif (true) {\n\t\treturn;\n\t}"}}`),
want: map[string]any{
"code": "\tif (true) {\n\t\treturn;\n\t}",
},
},
{
name: "null byte in string",
buffer: []byte(`{"name": "test", "arguments": {"data": "before\u0000{after}"}}`),
want: map[string]any{
"data": "before\x00{after}",
},
},
{
name: "escaped quote at end of string",
buffer: []byte(`{"name": "test", "arguments": {"data": "text with quote at end\\\""}}`),
want: map[string]any{
"data": `text with quote at end\"`,
},
},
{
name: "mixed array and object in arguments",
buffer: []byte(`{"name": "test", "arguments": {"items": ["{", "}", {"key": "value"}]}}`),
want: map[string]any{
"items": []any{"{", "}", map[string]any{"key": "value"}},
},
},
}
for _, tt := range tests {