Compare commits

..

1 Commits

Author SHA1 Message Date
Devon Rifkin
c87b910232 WIP: stable ordering for tool args
Right now we deserialize tool call definitions' arguments into golang
maps. These purposefully don't have a predictable iteration order,
whereas we want to maintain the order the user originally provided.

Unstable rendering of arguments means that we break the kv cache, which
this change fixes.

There's no way to build this in a fully backwards compatible way when
executing existing templates exactly as they are. We get around this by
rewriting templates dynamically just before they're rendered. This is
fragile, but perhaps the least bad option?
2025-10-07 15:38:58 -07:00
37 changed files with 1172 additions and 1169 deletions

View File

@@ -3,6 +3,7 @@ package api
import (
"encoding/json"
"fmt"
"iter"
"log/slog"
"math"
"os"
@@ -12,6 +13,7 @@ import (
"time"
"github.com/google/uuid"
orderedmap "github.com/wk8/go-ordered-map/v2"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/types/model"
@@ -193,13 +195,70 @@ type ToolCallFunction struct {
Arguments ToolCallFunctionArguments `json:"arguments"`
}
type ToolCallFunctionArguments map[string]any
type ToolCallFunctionArguments struct {
om *orderedmap.OrderedMap[string, any]
}
// NewToolCallFunctionArguments returns an empty argument collection
// backed by an initialized ordered map, ready for Set calls.
func NewToolCallFunctionArguments() ToolCallFunctionArguments {
	return ToolCallFunctionArguments{om: orderedmap.New[string, any]()}
}
// Get looks up the value stored under key. It reports false on a nil
// receiver, an uninitialized map, or a missing key.
func (t *ToolCallFunctionArguments) Get(key string) (any, bool) {
	if t != nil && t.om != nil {
		return t.om.Get(key)
	}
	return nil, false
}
// Set stores value under key, keeping first-insertion order. A nil
// receiver is a no-op; the backing map is created lazily on first use.
func (t *ToolCallFunctionArguments) Set(key string, value any) {
	switch {
	case t == nil:
		return
	case t.om == nil:
		t.om = orderedmap.New[string, any]()
	}
	t.om.Set(key, value)
}
// Len reports how many arguments are stored; zero for a nil receiver
// or an uninitialized map.
func (t *ToolCallFunctionArguments) Len() int {
	if t != nil && t.om != nil {
		return t.om.Len()
	}
	return 0
}
// All returns an iterator over key/value pairs in insertion order.
// A nil receiver or uninitialized map yields nothing.
func (t *ToolCallFunctionArguments) All() iter.Seq2[string, any] {
	return func(yield func(string, any) bool) {
		if t == nil || t.om == nil {
			return
		}
		for p := t.om.Oldest(); p != nil; p = p.Next() {
			if !yield(p.Key, p.Value) {
				break
			}
		}
	}
}
// String renders the arguments as a JSON object string in insertion
// order. A nil receiver or uninitialized map renders as "{}".
// The marshal error is deliberately ignored: values come from JSON
// deserialization and are always marshalable.
func (t *ToolCallFunctionArguments) String() string {
	if t == nil || t.om == nil {
		return "{}"
	}
	bts, _ := json.Marshal(t.om)
	return string(bts)
}
// UnmarshalJSON decodes a JSON object into a freshly allocated ordered
// map, preserving the key order of the input document.
func (t *ToolCallFunctionArguments) UnmarshalJSON(data []byte) error {
	t.om = orderedmap.New[string, any]()
	return json.Unmarshal(data, &t.om)
}
// MarshalJSON serializes the arguments in insertion order. A zero
// value (nil inner map) marshals as JSON null.
func (t ToolCallFunctionArguments) MarshalJSON() ([]byte, error) {
	return json.Marshal(t.om)
}
type Tool struct {
Type string `json:"type"`
Items any `json:"items,omitempty"`
@@ -301,12 +360,114 @@ func mapToTypeScriptType(jsonType string) string {
}
}
// ToolProperties is an insertion-ordered map of property name to
// ToolProperty, used so that rendered tool schemas are stable.
type ToolProperties struct {
	om *orderedmap.OrderedMap[string, ToolProperty]
}
// NewToolProperties returns an empty property map backed by an
// initialized ordered map, ready for Set calls.
func NewToolProperties() *ToolProperties {
	return &ToolProperties{om: orderedmap.New[string, ToolProperty]()}
}
// Get looks up the property stored under key. It reports false on a
// nil receiver, an uninitialized map, or a missing key.
func (t *ToolProperties) Get(key string) (ToolProperty, bool) {
	if t != nil && t.om != nil {
		return t.om.Get(key)
	}
	return ToolProperty{}, false
}
// Set stores value under key, keeping first-insertion order. A nil
// receiver is a no-op; the backing map is created lazily on first use.
func (t *ToolProperties) Set(key string, value ToolProperty) {
	switch {
	case t == nil:
		return
	case t.om == nil:
		t.om = orderedmap.New[string, ToolProperty]()
	}
	t.om.Set(key, value)
}
// Len reports how many properties are stored; zero for a nil receiver
// or an uninitialized map.
func (t *ToolProperties) Len() int {
	if t != nil && t.om != nil {
		return t.om.Len()
	}
	return 0
}
// All returns an iterator over name/property pairs in insertion order.
// A nil receiver or uninitialized map yields nothing.
func (t *ToolProperties) All() iter.Seq2[string, ToolProperty] {
	return func(yield func(string, ToolProperty) bool) {
		if t == nil || t.om == nil {
			return
		}
		for p := t.om.Oldest(); p != nil; p = p.Next() {
			if !yield(p.Key, p.Value) {
				break
			}
		}
	}
}
// MarshalJSON serializes the properties in insertion order. A nil
// receiver or uninitialized map marshals as JSON null.
func (t *ToolProperties) MarshalJSON() ([]byte, error) {
	if t == nil || t.om == nil {
		return []byte("null"), nil
	}
	return json.Marshal(t.om)
}
// UnmarshalJSON decodes a JSON object into a freshly allocated ordered
// map, preserving the key order of the input document.
func (t *ToolProperties) UnmarshalJSON(data []byte) error {
	t.om = orderedmap.New[string, ToolProperty]()
	return json.Unmarshal(data, &t.om)
}
type ToolFunctionParameters struct {
Type string `json:"type"`
Defs any `json:"$defs,omitempty"`
Items any `json:"items,omitempty"`
Required []string `json:"required"`
Properties map[string]ToolProperty `json:"properties"`
Type string `json:"type"`
Defs any `json:"$defs,omitempty"`
Items any `json:"items,omitempty"`
Required []string `json:"required"`
properties *ToolProperties // unexported - accessed via Properties() method
}
// Properties returns an iterator over the schema's properties in
// insertion order, for template compatibility; templates can range
// over it directly: {{range $k, $v := .Properties}}. Nil properties
// yield an empty sequence.
func (t ToolFunctionParameters) Properties() iter.Seq2[string, ToolProperty] {
	if t.properties != nil {
		return t.properties.All()
	}
	return func(func(string, ToolProperty) bool) {}
}
// HasProperties reports whether at least one property is defined.
// Templates use it for conditional checks: {{if .HasProperties}}.
func (t ToolFunctionParameters) HasProperties() bool {
	if t.properties == nil {
		return false
	}
	return t.properties.Len() > 0
}
// Len reports the number of properties; zero when none are set.
// Templates use it as {{.Function.Parameters.Len}}.
func (t ToolFunctionParameters) Len() int {
	if t.properties != nil {
		return t.properties.Len()
	}
	return 0
}
// SetProperties replaces the property map (used by tests and internal
// code); callers may pass nil to clear it.
func (t *ToolFunctionParameters) SetProperties(props *ToolProperties) {
	t.properties = props
}
// NewToolFunctionParametersWithProps builds a ToolFunctionParameters
// with the given type, required list, and properties (helper for tests).
func NewToolFunctionParametersWithProps(typ string, required []string, props *ToolProperties) ToolFunctionParameters {
	p := ToolFunctionParameters{Type: typ, Required: required}
	p.properties = props
	return p
}
// GetProperties returns the underlying properties wrapper, which may
// be nil (used by renderers).
func (t *ToolFunctionParameters) GetProperties() *ToolProperties {
	return t.properties
}
func (t *ToolFunctionParameters) String() string {
@@ -314,6 +475,38 @@ func (t *ToolFunctionParameters) String() string {
return string(bts)
}
// MarshalJSON emits the schema with the unexported properties field
// routed through ToolProperties, so property key order is preserved.
// Note: "properties" is always emitted (no omitempty) to match the
// previous map-based serialization.
func (t *ToolFunctionParameters) MarshalJSON() ([]byte, error) {
	// The original declared `type Alias ToolFunctionParameters` but
	// never used it; an explicit shadow struct is emitted instead.
	return json.Marshal(&struct {
		Type       string          `json:"type"`
		Defs       any             `json:"$defs,omitempty"`
		Items      any             `json:"items,omitempty"`
		Required   []string        `json:"required"`
		Properties *ToolProperties `json:"properties"`
	}{
		Type:       t.Type,
		Defs:       t.Defs,
		Items:      t.Items,
		Required:   t.Required,
		Properties: t.properties,
	})
}
// UnmarshalJSON decodes the schema, routing "properties" into the
// unexported ordered field so key order is preserved. The local Alias
// type strips this method from the embedded struct, preventing
// infinite recursion during decoding.
func (t *ToolFunctionParameters) UnmarshalJSON(data []byte) error {
	type Alias ToolFunctionParameters
	aux := &struct {
		Properties *ToolProperties `json:"properties"`
		*Alias
	}{
		Alias: (*Alias)(t),
	}
	if err := json.Unmarshal(data, aux); err != nil {
		return err
	}
	t.properties = aux.Properties
	return nil
}
type ToolFunction struct {
Name string `json:"name"`
Description string `json:"description"`
@@ -936,7 +1129,7 @@ func (t *ThinkValue) UnmarshalJSON(data []byte) error {
return nil
}
return fmt.Errorf("think must be a boolean or string (\"high\", \"medium\", \"low\", true, or false)")
return fmt.Errorf("think must be a boolean or string (\"high\", \"medium\", \"low\")")
}
// MarshalJSON implements json.Marshaler

View File

@@ -4,6 +4,7 @@ import (
"encoding/json"
"errors"
"math"
"strings"
"testing"
"time"
@@ -450,23 +451,25 @@ func TestToolFunctionParameters_String(t *testing.T) {
}{
{
name: "simple object with string property",
params: ToolFunctionParameters{
Type: "object",
Required: []string{"name"},
Properties: map[string]ToolProperty{
"name": {
params: NewToolFunctionParametersWithProps(
"object",
[]string{"name"},
func() *ToolProperties {
om := NewToolProperties()
om.Set("name", ToolProperty{
Type: PropertyType{"string"},
Description: "The name of the person",
},
},
},
})
return om
}(),
),
expected: `{"type":"object","required":["name"],"properties":{"name":{"type":"string","description":"The name of the person"}}}`,
},
{
name: "marshal failure returns empty string",
params: ToolFunctionParameters{
Type: "object",
Defs: func() any {
params: func() ToolFunctionParameters {
p := NewToolFunctionParametersWithProps("object", nil, NewToolProperties())
p.Defs = func() any {
// Create a cycle that will cause json.Marshal to fail
type selfRef struct {
Self *selfRef
@@ -474,9 +477,9 @@ func TestToolFunctionParameters_String(t *testing.T) {
s := &selfRef{}
s.Self = s
return s
}(),
Properties: map[string]ToolProperty{},
},
}()
return p
}(),
expected: "",
},
}
@@ -488,3 +491,31 @@ func TestToolFunctionParameters_String(t *testing.T) {
})
}
}
// TestTemplateRenderingWithArguments verifies that
// ToolCallFunctionArguments renders as valid JSON via String() (as a
// template would) and that insertion order is preserved in the output.
func TestTemplateRenderingWithArguments(t *testing.T) {
	args := NewToolCallFunctionArguments()
	args.Set("location", "San Francisco")
	args.Set("unit", "fahrenheit")
	// Simulate what a template would do: convert to string.
	rendered := args.String()
	// Should produce valid JSON.
	var parsed map[string]any
	err := json.Unmarshal([]byte(rendered), &parsed)
	require.NoError(t, err, "Arguments should render as valid JSON")
	// Verify the values round-trip.
	assert.Equal(t, "San Francisco", parsed["location"])
	assert.Equal(t, "fahrenheit", parsed["unit"])
	// Verify insertion order by inspecting the raw JSON string: the
	// first Set was "location", so it must appear before "unit".
	assert.Contains(t, rendered, `"location":"San Francisco"`)
	assert.Contains(t, rendered, `"unit":"fahrenheit"`)
	locIndex := strings.Index(rendered, "location")
	unitIndex := strings.Index(rendered, "unit")
	assert.Less(t, locIndex, unitIndex, "insertion order should be preserved")
}

View File

@@ -2,12 +2,11 @@ package discover
import (
"context"
"fmt"
"log/slog"
"os"
"path/filepath"
"regexp"
"runtime"
"strconv"
"strings"
"github.com/ollama/ollama/format"
@@ -61,14 +60,17 @@ func devInfoToInfoList(devs []ml.DeviceInfo) GpuInfoList {
DependencyPath: dev.LibraryPath,
DriverMajor: dev.DriverMajor,
DriverMinor: dev.DriverMinor,
ComputeMajor: dev.ComputeMajor,
ComputeMinor: dev.ComputeMinor,
}
if dev.Library == "CUDA" || dev.Library == "ROCm" {
info.MinimumMemory = 457 * format.MebiByte
}
if dev.Library == "ROCm" && rocmDir != "" {
info.DependencyPath = append(info.DependencyPath, rocmDir)
if dev.Library == "ROCm" {
info.Compute = fmt.Sprintf("gfx%x%02x", dev.ComputeMajor, dev.ComputeMinor)
if rocmDir != "" {
info.DependencyPath = append(info.DependencyPath, rocmDir)
}
} else {
info.Compute = fmt.Sprintf("%d.%d", dev.ComputeMajor, dev.ComputeMinor)
}
resp = append(resp, info)
}
@@ -144,35 +146,3 @@ func GetSystemInfo() SystemInfo {
GPUs: gpus,
}
}
// cudaJetpack detects NVIDIA Jetson (Tegra) systems on linux/arm64 and
// returns the matching JetPack library suffix ("jetpack5", "jetpack6"),
// or "" when not on a Jetson or the L4T version is unrecognized.
func cudaJetpack() string {
	if runtime.GOARCH == "arm64" && runtime.GOOS == "linux" {
		// An explicit CudaTegra value (JETSON_JETPACK) overrides autodetection.
		if CudaTegra != "" {
			ver := strings.Split(CudaTegra, ".")
			if len(ver) > 0 {
				return "jetpack" + ver[0]
			}
		} else if data, err := os.ReadFile("/etc/nv_tegra_release"); err == nil {
			// Extract the L4T major release, e.g. " R35 " -> 35.
			r := regexp.MustCompile(` R(\d+) `)
			m := r.FindSubmatch(data)
			if len(m) != 2 {
				slog.Info("Unexpected format for /etc/nv_tegra_release. Set JETSON_JETPACK to select version")
			} else {
				if l4t, err := strconv.Atoi(string(m[1])); err == nil {
					// Note: mapping from L4t -> JP is inconsistent (can't just subtract 30)
					// https://developer.nvidia.com/embedded/jetpack-archive
					switch l4t {
					case 35:
						return "jetpack5"
					case 36:
						return "jetpack6"
					default:
						// Newer Jetson systems use the SBSU runtime
						slog.Debug("unrecognized L4T version", "nv_tegra_release", string(data))
					}
				}
			}
		}
	}
	return ""
}

View File

@@ -78,8 +78,6 @@ func GPUDevices(ctx context.Context, runners []FilteredRunnerDiscovery) []ml.Dev
}
slog.Info("discovering available GPUs...")
requested := envconfig.LLMLibrary()
jetpack := cudaJetpack()
// For our initial discovery pass, we gather all the known GPUs through
// all the libraries that were detected. This pass may include GPUs that
@@ -88,14 +86,6 @@ func GPUDevices(ctx context.Context, runners []FilteredRunnerDiscovery) []ml.Dev
// times concurrently leading to memory contention
for dir := range libDirs {
var dirs []string
if dir != "" {
if requested != "" && filepath.Base(dir) != requested {
slog.Debug("skipping available library at users request", "requested", requested, "libDir", dir)
continue
} else if jetpack != "" && filepath.Base(dir) != "cuda_"+jetpack {
continue
}
}
if dir == "" {
dirs = []string{LibOllamaPath}
} else {

View File

@@ -37,10 +37,9 @@ type GpuInfo struct { // TODO better name maybe "InferenceProcessor"?
UnreliableFreeMemory bool
// GPU information
filterID string // AMD Workaround: The numeric ID of the device used to filter out other devices
Name string `json:"name"` // user friendly name if available
ComputeMajor int `json:"compute_major"` // Compute Capability or gfx
ComputeMinor int `json:"compute_minor"`
filterID string // AMD Workaround: The numeric ID of the device used to filter out other devices
Name string `json:"name"` // user friendly name if available
Compute string `json:"compute"` // Compute Capability or gfx
// Driver Information - TODO no need to put this on each GPU
DriverMajor int `json:"driver_major,omitempty"`
@@ -174,7 +173,7 @@ func (l GpuInfoList) FlashAttentionSupported() bool {
for _, gpu := range l {
supportsFA := gpu.Library == "cpu" ||
gpu.Name == "Metal" || gpu.Library == "Metal" ||
(gpu.Library == "CUDA" && gpu.DriverMajor >= 7 && !(gpu.ComputeMajor == 7 && gpu.ComputeMinor == 2)) || // We don't have kernels for Jetson Xavier
(gpu.Library == "CUDA" && gpu.DriverMajor >= 7) ||
gpu.Library == "ROCm"
if !supportsFA {

View File

@@ -38,14 +38,26 @@ Join the [Discord](https://discord.gg/ollama) for help interpreting the logs.
## LLM libraries
Ollama includes multiple LLM libraries compiled for different GPU libraries and versions. Ollama tries to pick the best one based on the capabilities of your system. If this autodetection has problems, or you run into other problems (e.g. crashes in your GPU) you can workaround this by forcing a specific LLM library.
Ollama includes multiple LLM libraries compiled for different GPUs and CPU vector features. Ollama tries to pick the best one based on the capabilities of your system. If this autodetection has problems, or you run into other problems (e.g. crashes in your GPU) you can workaround this by forcing a specific LLM library. `cpu_avx2` will perform the best, followed by `cpu_avx` and the slowest but most compatible is `cpu`. Rosetta emulation under MacOS will work with the `cpu` library.
In the server log, you will see a message that looks something like this (varies from release to release):
```
Dynamic LLM libraries [rocm_v6 cpu cpu_avx cpu_avx2 cuda_v12 rocm_v5]
```
**Experimental LLM Library Override**
You can set OLLAMA_LLM_LIBRARY to any of the available LLM libraries to limit autodetection, so for example, if you have both CUDA and AMD GPUs, but want to force the CUDA v13 only, use:
You can set OLLAMA_LLM_LIBRARY to any of the available LLM libraries to bypass autodetection, so for example, if you have a CUDA card, but want to force the CPU LLM library with AVX2 vector support, use:
```shell
OLLAMA_LLM_LIBRARY="cuda_v13" ollama serve
OLLAMA_LLM_LIBRARY="cpu_avx2" ollama serve
```
You can see what features your CPU has with the following.
```shell
cat /proc/cpuinfo | grep flags | head -1
```
## Installing older or pre-release versions on Linux

4
go.mod
View File

@@ -23,6 +23,7 @@ require (
github.com/mattn/go-runewidth v0.0.14
github.com/nlpodyssey/gopickle v0.3.0
github.com/pdevine/tensor v0.0.0-20240510204454-f88f4562727c
github.com/wk8/go-ordered-map/v2 v2.1.8
golang.org/x/image v0.22.0
golang.org/x/tools v0.30.0
gonum.org/v1/gonum v0.15.0
@@ -30,6 +31,8 @@ require (
require (
github.com/apache/arrow/go/arrow v0.0.0-20211112161151-bc219186db40 // indirect
github.com/bahlo/generic-list-go v0.2.0 // indirect
github.com/buger/jsonparser v1.1.1 // indirect
github.com/bytedance/sonic/loader v0.1.1 // indirect
github.com/chewxy/hm v1.0.0 // indirect
github.com/chewxy/math32 v1.11.0 // indirect
@@ -39,6 +42,7 @@ require (
github.com/gogo/protobuf v1.3.2 // indirect
github.com/google/flatbuffers v24.3.25+incompatible // indirect
github.com/kr/text v0.2.0 // indirect
github.com/mailru/easyjson v0.7.7 // indirect
github.com/pkg/errors v0.9.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/rivo/uniseg v0.2.0 // indirect

9
go.sum
View File

@@ -12,7 +12,11 @@ github.com/apache/arrow/go/arrow v0.0.0-20211112161151-bc219186db40 h1:q4dksr6IC
github.com/apache/arrow/go/arrow v0.0.0-20211112161151-bc219186db40/go.mod h1:Q7yQnSMnLvcXlZ8RV+jwz/6y1rQTqbX6C82SndT52Zs=
github.com/arbovm/levenshtein v0.0.0-20160628152529-48b4e1c0c4d0 h1:jfIu9sQUG6Ig+0+Ap1h4unLjW6YQJpKZVmUzxsD4E/Q=
github.com/arbovm/levenshtein v0.0.0-20160628152529-48b4e1c0c4d0/go.mod h1:t2tdKJDJF9BV14lnkjHmOQgcvEKgtqs5a1N3LNdJhGE=
github.com/bahlo/generic-list-go v0.2.0 h1:5sz/EEAK+ls5wF+NeqDpk5+iNdMDXrh3z3nPnH1Wvgk=
github.com/bahlo/generic-list-go v0.2.0/go.mod h1:2KvAjgMlE5NNynlg/5iLrrCCZ2+5xWbdbCW3pNTGyYg=
github.com/boombuler/barcode v1.0.0/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8=
github.com/buger/jsonparser v1.1.1 h1:2PnMjfWD7wBILjqQbt530v576A/cAbQvEW9gGIpYMUs=
github.com/buger/jsonparser v1.1.1/go.mod h1:6RYKKt7H4d4+iWqouImQ9R2FZql3VbhNgx27UK13J/0=
github.com/bytedance/sonic v1.11.6 h1:oUp34TzMlL+OY1OUWxHqsdkgC/Zfc85zGqw9siXjrc0=
github.com/bytedance/sonic v1.11.6/go.mod h1:LysEHSvpvDySVdC2f87zGWf6CIKJcAvqab1ZaiQtds4=
github.com/bytedance/sonic/loader v0.1.1 h1:c+e5Pt1k/cy5wMveRDyk2X4B9hF4g7an8N3zCYjJFNM=
@@ -121,6 +125,7 @@ github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+
github.com/grpc-ecosystem/grpc-gateway v1.16.0/go.mod h1:BDjrQk3hbvj6Nolgz8mAMFbcEtjT1g+wF4CSlocrBnw=
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y=
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
github.com/jung-kurt/gofpdf v1.0.0/go.mod h1:7Id9E/uU8ce6rXgefFLlgrJj/GYY22cpxn+r32jIOes=
@@ -139,6 +144,8 @@ github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ=
github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI=
github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0=
github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-runewidth v0.0.9/go.mod h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m2gUSrubnMI=
@@ -197,6 +204,8 @@ github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE=
github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg=
github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/fJgbpc=
github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw=
github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM=
github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg=
github.com/xtgo/set v1.0.0 h1:6BCNBRv3ORNDQ7fyoJXRv+tstJz3m1JVFQErfeZz2pY=

View File

@@ -11,6 +11,8 @@ import (
"testing"
"time"
orderedmap "github.com/wk8/go-ordered-map/v2"
"github.com/ollama/ollama/api"
)
@@ -432,12 +434,14 @@ func TestAPIToolCalling(t *testing.T) {
Parameters: api.ToolFunctionParameters{
Type: "object",
Required: []string{"location"},
Properties: map[string]api.ToolProperty{
"location": {
Properties: func() *api.ToolProperties {
props := api.NewToolProperties()
props.Set("location", api.ToolProperty{
Type: api.PropertyType{"string"},
Description: "The city and state, e.g. San Francisco, CA",
},
},
})
return props
}(),
},
},
},
@@ -497,7 +501,7 @@ func TestAPIToolCalling(t *testing.T) {
t.Errorf("unexpected tool called: got %q want %q", lastToolCall.Function.Name, "get_weather")
}
if _, ok := lastToolCall.Function.Arguments["location"]; !ok {
if _, ok := lastToolCall.Function.Arguments.Get("location"); !ok {
t.Errorf("expected tool arguments to include 'location', got: %s", lastToolCall.Function.Arguments.String())
}
case <-ctx.Done():

View File

@@ -17,21 +17,16 @@ func TestBlueSky(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
defer cancel()
// Set up the test data
req := api.ChatRequest{
Model: smol,
Messages: []api.Message{
{
Role: "user",
Content: blueSkyPrompt,
},
},
req := api.GenerateRequest{
Model: smol,
Prompt: blueSkyPrompt,
Stream: &stream,
Options: map[string]any{
"temperature": 0,
"seed": 123,
},
}
ChatTestHelper(ctx, t, req, blueSkyExpected)
GenerateTestHelper(ctx, t, req, blueSkyExpected)
}
func TestUnicode(t *testing.T) {
@@ -39,15 +34,10 @@ func TestUnicode(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute)
defer cancel()
// Set up the test data
req := api.ChatRequest{
req := api.GenerateRequest{
// DeepSeek has a Unicode tokenizer regex, making it a unicode torture test
Model: "deepseek-coder-v2:16b-lite-instruct-q2_K", // TODO is there an ollama-engine model we can switch to and keep the coverage?
Messages: []api.Message{
{
Role: "user",
Content: "天空为什么是蓝色的?", // Why is the sky blue?
},
},
Model: "deepseek-coder-v2:16b-lite-instruct-q2_K", // TODO is there an ollama-engine model we can switch to and keep the coverage?
Prompt: "天空为什么是蓝色的?", // Why is the sky blue?
Stream: &stream,
Options: map[string]any{
"temperature": 0,
@@ -67,14 +57,9 @@ func TestUnicode(t *testing.T) {
if err != nil {
t.Fatalf("failed to load model %s: %s", req.Model, err)
}
defer func() {
// best effort unload once we're done with the model
client.Generate(ctx, &api.GenerateRequest{Model: req.Model, KeepAlive: &api.Duration{Duration: 0}}, func(rsp api.GenerateResponse) error { return nil })
}()
skipIfNotGPULoaded(ctx, t, client, req.Model, 100)
DoChat(ctx, t, client, req, []string{
DoGenerate(ctx, t, client, req, []string{
"散射", // scattering
"频率", // frequency
}, 120*time.Second, 120*time.Second)
@@ -84,14 +69,9 @@ func TestExtendedUnicodeOutput(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
defer cancel()
// Set up the test data
req := api.ChatRequest{
Model: "gemma2:2b",
Messages: []api.Message{
{
Role: "user",
Content: "Output some smily face emoji",
},
},
req := api.GenerateRequest{
Model: "gemma2:2b",
Prompt: "Output some smily face emoji",
Stream: &stream,
Options: map[string]any{
"temperature": 0,
@@ -103,7 +83,7 @@ func TestExtendedUnicodeOutput(t *testing.T) {
if err := PullIfMissing(ctx, client, req.Model); err != nil {
t.Fatal(err)
}
DoChat(ctx, t, client, req, []string{"😀", "😊", "😁", "😂", "😄", "😃"}, 120*time.Second, 120*time.Second)
DoGenerate(ctx, t, client, req, []string{"😀", "😊", "😁", "😂", "😄", "😃"}, 120*time.Second, 120*time.Second)
}
func TestUnicodeModelDir(t *testing.T) {
@@ -128,19 +108,14 @@ func TestUnicodeModelDir(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
defer cancel()
req := api.ChatRequest{
Model: smol,
Messages: []api.Message{
{
Role: "user",
Content: blueSkyPrompt,
},
},
req := api.GenerateRequest{
Model: smol,
Prompt: blueSkyPrompt,
Stream: &stream,
Options: map[string]any{
"temperature": 0,
"seed": 123,
},
}
ChatTestHelper(ctx, t, req, blueSkyExpected)
GenerateTestHelper(ctx, t, req, blueSkyExpected)
}

View File

@@ -20,9 +20,9 @@ import (
)
// Send multiple requests in parallel (concurrently) to a single model and ensure responses are expected
func TestConcurrentChat(t *testing.T) {
func TestConcurrentGenerate(t *testing.T) {
// Assumes all requests have the same model
req, resp := ChatRequests()
req, resp := GenerateRequests()
numParallel := int(envconfig.NumParallel() + 1)
iterLimit := 3
@@ -57,7 +57,7 @@ func TestConcurrentChat(t *testing.T) {
slog.Info("Starting", "thread", i, "iter", j)
// On slower GPUs it can take a while to process the concurrent requests
// so we allow a much longer initial timeout
DoChat(ctx, t, client, req[k], resp[k], 120*time.Second, 20*time.Second)
DoGenerate(ctx, t, client, req[k], resp[k], 120*time.Second, 20*time.Second)
}
}(i)
}
@@ -163,7 +163,7 @@ chooseModels:
wg.Add(1)
go func(i int) {
defer wg.Done()
reqs, resps := ChatRequests()
reqs, resps := GenerateRequests()
for j := 0; j < 3; j++ {
if time.Now().Sub(started) > softTimeout {
slog.Info("exceeded soft timeout, winding down test")
@@ -171,8 +171,8 @@ chooseModels:
}
k := r.Int() % len(reqs)
reqs[k].Model = chosenModels[i]
slog.Info("Starting", "model", reqs[k].Model, "iteration", j, "request", reqs[k].Messages[0].Content)
DoChat(ctx, t, client, reqs[k], resps[k],
slog.Info("Starting", "model", reqs[k].Model, "iteration", j, "request", reqs[k].Prompt)
DoGenerate(ctx, t, client, reqs[k], resps[k],
120*time.Second, // Be extra patient for the model to load initially
10*time.Second, // Once results start streaming, fail if they stall
)

View File

@@ -21,14 +21,9 @@ func TestLongInputContext(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
defer cancel()
// Set up the test data
req := api.ChatRequest{
Model: smol,
Messages: []api.Message{
{
Role: "user",
Content: "Oh, dont speak to me of Austria. Perhaps I dont understand things, but Austria never has wished, and does not wish, for war. She is betraying us! Russia alone must save Europe. Our gracious sovereign recognizes his high vocation and will be true to it. That is the one thing I have faith in! Our good and wonderful sovereign has to perform the noblest role on earth, and he is so virtuous and noble that God will not forsake him. He will fulfill his vocation and crush the hydra of revolution, which has become more terrible than ever in the person of this murderer and villain! We alone must avenge the blood of the just one.... Whom, I ask you, can we rely on?... England with her commercial spirit will not and cannot understand the Emperor Alexanders loftiness of soul. She has refused to evacuate Malta. She wanted to find, and still seeks, some secret motive in our actions. What answer did Novosíltsev get? None. The English have not understood and cannot understand the self-abnegation of our Emperor who wants nothing for himself, but only desires the good of mankind. And what have they promised? Nothing! And what little they have promised they will not perform! Prussia has always declared that Buonaparte is invincible, and that all Europe is powerless before him.... And I dont believe a word that Hardenburg says, or Haugwitz either. This famous Prussian neutrality is just a trap. I have faith only in God and the lofty destiny of our adored monarch. He will save Europe! What country is this referring to?",
},
},
req := api.GenerateRequest{
Model: smol,
Prompt: "Oh, dont speak to me of Austria. Perhaps I dont understand things, but Austria never has wished, and does not wish, for war. She is betraying us! Russia alone must save Europe. Our gracious sovereign recognizes his high vocation and will be true to it. That is the one thing I have faith in! Our good and wonderful sovereign has to perform the noblest role on earth, and he is so virtuous and noble that God will not forsake him. He will fulfill his vocation and crush the hydra of revolution, which has become more terrible than ever in the person of this murderer and villain! We alone must avenge the blood of the just one.... Whom, I ask you, can we rely on?... England with her commercial spirit will not and cannot understand the Emperor Alexanders loftiness of soul. She has refused to evacuate Malta. She wanted to find, and still seeks, some secret motive in our actions. What answer did Novosíltsev get? None. The English have not understood and cannot understand the self-abnegation of our Emperor who wants nothing for himself, but only desires the good of mankind. And what have they promised? Nothing! And what little they have promised they will not perform! Prussia has always declared that Buonaparte is invincible, and that all Europe is powerless before him.... And I dont believe a word that Hardenburg says, or Haugwitz either. This famous Prussian neutrality is just a trap. I have faith only in God and the lofty destiny of our adored monarch. He will save Europe! What country is this referring to?",
Stream: &stream,
Options: map[string]any{
"temperature": 0,
@@ -41,7 +36,7 @@ func TestLongInputContext(t *testing.T) {
if err := PullIfMissing(ctx, client, req.Model); err != nil {
t.Fatalf("PullIfMissing failed: %v", err)
}
DoChat(ctx, t, client, req, []string{"russia", "german", "france", "england", "austria", "prussia", "europe", "individuals", "coalition", "conflict"}, 120*time.Second, 10*time.Second)
DoGenerate(ctx, t, client, req, []string{"russia", "german", "france", "england", "austria", "prussia", "europe", "individuals", "coalition", "conflict"}, 120*time.Second, 10*time.Second)
}
func TestContextExhaustion(t *testing.T) {
@@ -53,14 +48,9 @@ func TestContextExhaustion(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
defer cancel()
// Set up the test data
req := api.ChatRequest{
Model: smol,
Messages: []api.Message{
{
Role: "user",
Content: "Write me a story in english with a lot of emojis",
},
},
req := api.GenerateRequest{
Model: smol,
Prompt: "Write me a story in english with a lot of emojis",
Stream: &stream,
Options: map[string]any{
"temperature": 0,
@@ -73,12 +63,12 @@ func TestContextExhaustion(t *testing.T) {
if err := PullIfMissing(ctx, client, req.Model); err != nil {
t.Fatalf("PullIfMissing failed: %v", err)
}
DoChat(ctx, t, client, req, []string{"once", "upon", "lived", "sunny", "cloudy", "clear", "water", "time", "travel", "world"}, 120*time.Second, 10*time.Second)
DoGenerate(ctx, t, client, req, []string{"once", "upon", "lived", "sunny", "cloudy", "clear", "water", "time", "travel", "world"}, 120*time.Second, 10*time.Second)
}
// Send multiple generate requests with prior context and ensure the response is coherant and expected
func TestParallelGenerateWithHistory(t *testing.T) {
modelOverride := "gpt-oss:20b"
modelOverride := ollamaEngineChatModels[0] // Most recent ollama engine model
req, resp := GenerateRequests()
numParallel := 2
iterLimit := 2
@@ -165,7 +155,7 @@ func TestGenerateWithHistory(t *testing.T) {
// Send multiple chat requests with prior context and ensure the response is coherant and expected
func TestParallelChatWithHistory(t *testing.T) {
modelOverride := "gpt-oss:20b"
modelOverride := ollamaEngineChatModels[0] // Most recent ollama engine model
req, resp := ChatRequests()
numParallel := 2
iterLimit := 2

View File

@@ -15,7 +15,7 @@ import (
// First run of this scenario on a target system will take a long time to download
// ~1.5TB of models. Set a sufficiently large -timeout for your network speed
func TestLibraryModelsChat(t *testing.T) {
func TestLibraryModelsGenerate(t *testing.T) {
softTimeout, hardTimeout := getTimeouts(t)
slog.Info("Setting timeouts", "soft", softTimeout, "hard", hardTimeout)
ctx, cancel := context.WithTimeout(context.Background(), hardTimeout)
@@ -43,14 +43,9 @@ func TestLibraryModelsChat(t *testing.T) {
t.Skip(fmt.Sprintf("Skipping %s architecture %s != %s", model, arch, targetArch))
}
}
req := api.ChatRequest{
Model: model,
Messages: []api.Message{
{
Role: "user",
Content: blueSkyPrompt,
},
},
req := api.GenerateRequest{
Model: model,
Prompt: blueSkyPrompt,
KeepAlive: &api.Duration{Duration: 10 * time.Second},
Options: map[string]interface{}{
"temperature": 0.1,
@@ -63,13 +58,13 @@ func TestLibraryModelsChat(t *testing.T) {
anyResp = []string{"select", "from"}
} else if model == "granite3-guardian" || model == "shieldgemma" || model == "llama-guard3" || model == "bespoke-minicheck" {
anyResp = []string{"yes", "no", "safe", "unsafe"}
} else if model == "openthinker" {
} else if model == "openthinker" || model == "nexusraven" {
anyResp = []string{"plugin", "im_sep", "components", "function call"}
} else if model == "starcoder" || model == "starcoder2" || model == "magicoder" || model == "deepseek-coder" {
req.Messages[0].Content = "def fibonacci():"
req.Prompt = "def fibonacci():"
anyResp = []string{"f(n)", "sequence", "n-1", "main()", "__main__", "while"}
}
DoChat(ctx, t, client, req, anyResp, 120*time.Second, 30*time.Second)
DoGenerate(ctx, t, client, req, anyResp, 120*time.Second, 30*time.Second)
})
}
}

View File

@@ -34,22 +34,17 @@ func TestVisionModels(t *testing.T) {
if err != nil {
t.Fatal(err)
}
req := api.ChatRequest{
Model: v.model,
Messages: []api.Message{
{
Role: "user",
Content: "what does the text in this image say?",
Images: []api.ImageData{
image,
},
},
},
req := api.GenerateRequest{
Model: v.model,
Prompt: "what does the text in this image say?",
Stream: &stream,
Options: map[string]any{
"seed": 42,
"temperature": 0.0,
},
Images: []api.ImageData{
image,
},
}
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
defer cancel()
@@ -61,15 +56,8 @@ func TestVisionModels(t *testing.T) {
if err := PullIfMissing(ctx, client, req.Model); err != nil {
t.Fatal(err)
}
// Preload to skip if we're less than 80% on GPU to avoid extremely slow tests
err = client.Generate(ctx, &api.GenerateRequest{Model: req.Model}, func(response api.GenerateResponse) error { return nil })
if err != nil {
t.Fatalf("failed to load model %s: %s", req.Model, err)
}
skipIfNotGPULoaded(ctx, t, client, req.Model, 80)
// llava models on CPU can be quite slow to start
DoChat(ctx, t, client, req, []string{resp}, 240*time.Second, 30*time.Second)
DoGenerate(ctx, t, client, req, []string{resp}, 240*time.Second, 30*time.Second)
})
}
}

View File

@@ -19,7 +19,7 @@ import (
"github.com/ollama/ollama/format"
)
func TestModelsChat(t *testing.T) {
func TestModelsGenerate(t *testing.T) {
softTimeout, hardTimeout := getTimeouts(t)
slog.Info("Setting timeouts", "soft", softTimeout, "hard", hardTimeout)
ctx, cancel := context.WithTimeout(context.Background(), hardTimeout)
@@ -66,23 +66,15 @@ func TestModelsChat(t *testing.T) {
}
}
// TODO - fiddle with context size
req := api.ChatRequest{
Model: model,
Messages: []api.Message{
{
Role: "user",
Content: blueSkyPrompt,
},
},
KeepAlive: &api.Duration{Duration: 10 * time.Second},
req := api.GenerateRequest{
Model: model,
Prompt: blueSkyPrompt,
Options: map[string]interface{}{
"temperature": 0,
"seed": 123,
},
}
DoChat(ctx, t, client, req, blueSkyExpected, 120*time.Second, 30*time.Second)
// best effort unload once we're done with the model
client.Generate(ctx, &api.GenerateRequest{Model: req.Model, KeepAlive: &api.Duration{Duration: 0}}, func(rsp api.GenerateResponse) error { return nil })
DoGenerate(ctx, t, client, req, blueSkyExpected, 120*time.Second, 30*time.Second)
})
}
}
@@ -136,9 +128,8 @@ func TestModelsEmbed(t *testing.T) {
}
}
req := api.EmbeddingRequest{
Model: model,
Prompt: "why is the sky blue?",
KeepAlive: &api.Duration{Duration: 10 * time.Second},
Model: model,
Prompt: "why is the sky blue?",
Options: map[string]interface{}{
"temperature": 0,
"seed": 123,
@@ -148,10 +139,6 @@ func TestModelsEmbed(t *testing.T) {
if err != nil {
t.Fatalf("embeddings call failed %s", err)
}
defer func() {
// best effort unload once we're done with the model
client.Generate(ctx, &api.GenerateRequest{Model: req.Model, KeepAlive: &api.Duration{Duration: 0}}, func(rsp api.GenerateResponse) error { return nil })
}()
if len(resp.Embedding) == 0 {
t.Errorf("zero length embedding response")
}

View File

@@ -173,14 +173,9 @@ func doModelPerfTest(t *testing.T, chatModels []string) {
slog.Info("skipping long prompt", "model", model, "num_ctx", numCtx, "gpu_percent", gpuPercent)
continue
}
req := api.ChatRequest{
Model: model,
Messages: []api.Message{
{
Role: "user",
Content: tc.prompt,
},
},
req := api.GenerateRequest{
Model: model,
Prompt: tc.prompt,
KeepAlive: &api.Duration{Duration: 20 * time.Second}, // long enough to ensure a ps returns
Options: map[string]interface{}{
"temperature": 0,
@@ -189,7 +184,7 @@ func doModelPerfTest(t *testing.T, chatModels []string) {
},
}
atLeastOne := false
var resp api.ChatResponse
var resp api.GenerateResponse
stream := false
req.Stream = &stream
@@ -203,7 +198,7 @@ func doModelPerfTest(t *testing.T, chatModels []string) {
)
defer cancel()
err = client.Chat(genCtx, &req, func(rsp api.ChatResponse) error {
err = client.Generate(genCtx, &req, func(rsp api.GenerateResponse) error {
resp = rsp
return nil
})
@@ -219,13 +214,13 @@ func doModelPerfTest(t *testing.T, chatModels []string) {
}
loaded = true
for _, expResp := range tc.anyResp {
if strings.Contains(strings.ToLower(resp.Message.Content), expResp) {
if strings.Contains(strings.ToLower(resp.Response), expResp) {
atLeastOne = true
break
}
}
if !atLeastOne {
t.Fatalf("response didn't contain expected values: ctx:%d expected:%v response:%s ", numCtx, tc.anyResp, resp.Message.Content)
t.Fatalf("response didn't contain expected values: ctx:%d expected:%v response:%s ", numCtx, tc.anyResp, resp.Response)
}
models, err := client.ListRunning(ctx)
if err != nil {

View File

@@ -74,14 +74,9 @@ func TestQuantization(t *testing.T) {
}
stream := true
chatReq := api.ChatRequest{
Model: newName,
Messages: []api.Message{
{
Role: "user",
Content: blueSkyPrompt,
},
},
genReq := api.GenerateRequest{
Model: newName,
Prompt: blueSkyPrompt,
KeepAlive: &api.Duration{Duration: 3 * time.Second},
Options: map[string]any{
"seed": 42,
@@ -96,8 +91,8 @@ func TestQuantization(t *testing.T) {
reqCtx, reqCancel := context.WithCancel(ctx)
atLeastOne := false
var buf bytes.Buffer
chatfn := func(response api.ChatResponse) error {
buf.Write([]byte(response.Message.Content))
genfn := func(response api.GenerateResponse) error {
buf.Write([]byte(response.Response))
fullResp := strings.ToLower(buf.String())
for _, resp := range blueSkyExpected {
if strings.Contains(fullResp, resp) {
@@ -113,14 +108,14 @@ func TestQuantization(t *testing.T) {
done := make(chan int)
var genErr error
go func() {
genErr = client.Chat(reqCtx, &chatReq, chatfn)
genErr = client.Generate(reqCtx, &genReq, genfn)
done <- 0
}()
select {
case <-done:
if genErr != nil && !atLeastOne {
t.Fatalf("failed with %s request prompt %s ", chatReq.Model, chatReq.Messages[0].Content)
t.Fatalf("failed with %s request prompt %s ", genReq.Model, genReq.Prompt)
}
case <-ctx.Done():
t.Error("outer test context done while waiting for generate")

View File

@@ -15,7 +15,6 @@ import (
"net/http"
"net/url"
"os"
"os/exec"
"path/filepath"
"runtime"
"strconv"
@@ -25,6 +24,7 @@ import (
"time"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/app/lifecycle"
"github.com/ollama/ollama/format"
)
@@ -38,7 +38,6 @@ var (
// Note: add newer models at the top of the list to test them first
ollamaEngineChatModels = []string{
"qwen3-coder:30b",
"gpt-oss:20b",
"gemma3n:e2b",
"mistral-small3.2:latest",
@@ -47,7 +46,6 @@ var (
"qwen2.5-coder:latest",
"qwen2.5vl:3b",
"qwen3:0.6b", // dense
"qwen3:1.7b", // dense
"qwen3:30b", // MOE
"gemma3:1b",
"llama3.1:latest",
@@ -267,16 +265,16 @@ var (
"Explain the physics involved in them. Be breif in your reply",
"Explain the chemistry involved in them. Be breif in your reply",
"What are common myths related to them? Be brief in your reply",
"What are common fairytales related to them? Be brief in your reply",
"Can they form if there is no rain? Be breif in your reply",
"Can they form if there are no clouds? Be breif in your reply",
"Do they happen on other planets? Be brief in your reply",
}
rainbowExpected = []string{"water", "droplet", "mist", "glow", "refract", "reflect", "scatter", "particles", "wave", "color", "spectrum", "raindrop", "atmosphere", "frequency", "shower", "sky", "shimmer", "light", "storm", "sunny", "sunburst", "phenomenon", "mars", "venus", "jupiter"}
rainbowExpected = []string{"water", "droplet", "mist", "glow", "refract", "reflect", "scatter", "wave", "color", "spectrum", "raindrop", "atmosphere", "frequency", "end", "gold", "fortune", "blessing", "prosperity", "magic", "shower", "sky", "shimmer", "light", "storm", "sunny"}
)
func init() {
logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug}))
slog.SetDefault(logger)
lifecycle.InitLogging()
custom := os.Getenv("OLLAMA_TEST_DEFAULT_MODEL")
if custom != "" {
slog.Info("setting default test model to " + custom)
@@ -337,7 +335,6 @@ func GetTestEndpoint() (*api.Client, string) {
var serverMutex sync.Mutex
var serverReady bool
var serverLogFile string
func startServer(t *testing.T, ctx context.Context, ollamaHost string) error {
// Make sure the server has been built
@@ -364,9 +361,8 @@ func startServer(t *testing.T, ctx context.Context, ollamaHost string) error {
t.Setenv("OLLAMA_HOST", ollamaHost)
}
logDir := t.TempDir()
slog.Info("starting server", "url", ollamaHost)
done, err := SpawnServer(ctx, "../ollama", logDir)
done, err := lifecycle.SpawnServer(ctx, "../ollama")
if err != nil {
return fmt.Errorf("failed to start server: %w", err)
}
@@ -389,36 +385,6 @@ func startServer(t *testing.T, ctx context.Context, ollamaHost string) error {
return nil
}
// SpawnServer launches `command serve` with stdout and stderr redirected to a
// fresh log file created in logDir, recording the file's path in the
// package-level serverLogFile so a failing test can dump the log afterwards.
// It returns a channel that receives the server's exit code once the process
// terminates; cancel ctx to stop the server.
func SpawnServer(ctx context.Context, command, logDir string) (chan int, error) {
	done := make(chan int)
	fp, err := os.CreateTemp(logDir, "ollama-server-*.log")
	if err != nil {
		return nil, fmt.Errorf("failed to create log file: %w", err)
	}
	serverLogFile = fp.Name()
	cmd := exec.CommandContext(ctx, command, "serve")
	cmd.Stderr = fp
	cmd.Stdout = fp
	go func() {
		// Close the log file once the server has exited so the file
		// descriptor isn't leaked across repeated spawns in one test run.
		defer fp.Close()
		slog.Info("starting server...")
		if err := cmd.Run(); err != nil {
			// "signal: killed" is expected when ctx is canceled
			if !strings.Contains(err.Error(), "signal") {
				slog.Info("failed to run server", "error", err)
			}
		}
		var code int
		if cmd.ProcessState != nil {
			code = cmd.ProcessState.ExitCode()
		}
		slog.Info("server exited")
		done <- code
	}()
	return done, nil
}
func PullIfMissing(ctx context.Context, client *api.Client, modelName string) error {
slog.Info("checking status of model", "model", modelName)
showReq := &api.ShowRequest{Name: modelName}
@@ -479,6 +445,12 @@ func InitServerConnection(ctx context.Context, t *testing.T) (*api.Client, strin
client, testEndpoint := GetTestEndpoint()
if os.Getenv("OLLAMA_TEST_EXISTING") == "" {
serverProcMutex.Lock()
fp, err := os.CreateTemp("", "ollama-server-*.log")
if err != nil {
t.Fatalf("failed to generate log file: %s", err)
}
lifecycle.ServerLogFile = fp.Name()
fp.Close()
if err := startServer(t, ctx, testEndpoint); err != nil {
t.Fatal(err)
}
@@ -506,32 +478,36 @@ func InitServerConnection(ctx context.Context, t *testing.T) (*api.Client, strin
if os.Getenv("OLLAMA_TEST_EXISTING") == "" {
defer serverProcMutex.Unlock()
if t.Failed() {
fp, err := os.Open(serverLogFile)
fp, err := os.Open(lifecycle.ServerLogFile)
if err != nil {
slog.Error("failed to open server log", "logfile", serverLogFile, "error", err)
slog.Error("failed to open server log", "logfile", lifecycle.ServerLogFile, "error", err)
return
}
defer fp.Close()
data, err := io.ReadAll(fp)
if err != nil {
slog.Error("failed to read server log", "logfile", serverLogFile, "error", err)
slog.Error("failed to read server log", "logfile", lifecycle.ServerLogFile, "error", err)
return
}
slog.Warn("SERVER LOG FOLLOWS")
os.Stderr.Write(data)
slog.Warn("END OF SERVER")
}
err := os.Remove(lifecycle.ServerLogFile)
if err != nil && !os.IsNotExist(err) {
slog.Warn("failed to cleanup", "logfile", lifecycle.ServerLogFile, "error", err)
}
}
}
}
func ChatTestHelper(ctx context.Context, t *testing.T, req api.ChatRequest, anyResp []string) {
func GenerateTestHelper(ctx context.Context, t *testing.T, genReq api.GenerateRequest, anyResp []string) {
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
if err := PullIfMissing(ctx, client, req.Model); err != nil {
if err := PullIfMissing(ctx, client, genReq.Model); err != nil {
t.Fatal(err)
}
DoChat(ctx, t, client, req, anyResp, 30*time.Second, 10*time.Second)
DoGenerate(ctx, t, client, genReq, anyResp, 30*time.Second, 10*time.Second)
}
func DoGenerate(ctx context.Context, t *testing.T, client *api.Client, genReq api.GenerateRequest, anyResp []string, initialTimeout, streamTimeout time.Duration) []int {
@@ -750,14 +726,8 @@ func skipIfNotGPULoaded(ctx context.Context, t *testing.T, client *api.Client, m
loaded := []string{}
for _, m := range models.Models {
loaded = append(loaded, m.Name)
if strings.Contains(model, ":") {
if m.Name != model {
continue
}
} else if strings.Contains(m.Name, ":") {
if !strings.HasPrefix(m.Name, model+":") {
continue
}
if m.Name != model {
continue
}
gpuPercent := 0
switch {

View File

@@ -160,15 +160,7 @@ func (c *Causal) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity
if c.swaMemorySize == 0 {
c.swaMemorySize = c.swaWindowSize
}
// We will allocate space in the cache for the stop token, which won't be part of a follow on
// sequence, so allocate an extra token of storage to ensure that we can jump back without
// causing a cache break. As an optimization, only do this when we have parallel sequences
// because the extra token will live in the batch buffer and won't get overwritten if we
// only have a single sequence.
if c.swaMemorySize != math.MaxInt32 && maxSequences > 1 {
c.swaMemorySize = max(c.swaMemorySize, c.swaWindowSize+1)
}
if int(c.swaMemorySize) >= capacity {
if int(c.swaMemorySize) > capacity {
c.swaMemorySize = math.MaxInt32
}
@@ -222,6 +214,7 @@ func (c *Causal) StartForward(ctx ml.Context, batch input.Batch, reserve bool) e
c.curLoc, err = c.findStartLoc()
}
if err != nil {
slog.Warn("unable to find a kv cache slot", "cache", c)
return err
}
@@ -295,44 +288,23 @@ func (c *Causal) updateSlidingWindow() {
return
}
type lowestPosition struct {
pos int32
curBatch bool
}
// create a map of unique sequences to the lowest position in that sequence
lowestPos := make(map[int]lowestPosition)
lowestPos := make(map[int]int32)
for i := range c.curPositions {
seq := c.curSequences[i]
lowest, ok := lowestPos[seq]
pos, ok := lowestPos[seq]
if !ok {
lowest = lowestPosition{pos: c.curPositions[i], curBatch: true}
} else if c.curPositions[i] < lowest.pos {
lowest.pos = c.curPositions[i]
pos = c.curPositions[i]
} else if c.curPositions[i] < pos {
pos = c.curPositions[i]
}
lowestPos[seq] = lowest
}
// for any sequences are not part of this batch, clean up any tokens
// that are no longer needed after the processing of the previous
// batch
for seq, seqRange := range c.cellRanges {
if _, ok := lowestPos[seq]; !ok {
var last int32
for i := seqRange.min; i <= seqRange.max; i++ {
if slices.Contains(c.cells[i].sequences, seq) {
last = max(last, c.cells[i].pos)
}
}
lowestPos[seq] = lowestPosition{pos: last + 1, curBatch: false}
}
lowestPos[seq] = pos
}
// delete any entries that are beyond the window of the oldest position in the sequence
for seq, lowest := range lowestPos {
for seq, pos := range lowestPos {
oldRange, ok := c.cellRanges[seq]
if !ok {
continue
@@ -342,13 +314,13 @@ func (c *Causal) updateSlidingWindow() {
for i := oldRange.min; i <= oldRange.max; i++ {
if slices.Contains(c.cells[i].sequences, seq) {
if c.cells[i].pos < lowest.pos-c.swaMemorySize {
if c.cells[i].pos < pos-c.swaMemorySize {
c.cells[i].sequences = slices.DeleteFunc(c.cells[i].sequences, func(s int) bool { return s == seq })
} else {
newRange.min = min(newRange.min, i)
newRange.max = max(newRange.max, i)
}
if lowest.curBatch && c.cells[i].pos >= lowest.pos-c.swaWindowSize {
if c.cells[i].pos >= pos-c.swaWindowSize {
c.curCellRange.min = min(c.curCellRange.min, i)
c.curCellRange.max = max(c.curCellRange.max, i)
}
@@ -685,11 +657,9 @@ func (c *Causal) CanResume(seq int, pos int32) bool {
// for sliding window, check that the window of the new sequence is contained in
// the window of what we are storing
var first int32 = math.MaxInt32
var last int32 = -1
for i := seqRange.min; i <= seqRange.max; i++ {
if slices.Contains(c.cells[i].sequences, seq) {
first = min(first, c.cells[i].pos)
last = max(last, c.cells[i].pos)
}
}
@@ -698,8 +668,10 @@ func (c *Causal) CanResume(seq int, pos int32) bool {
return false
}
lastWindowStart := max(0, last-c.swaMemorySize)
posWindowStart := max(0, pos-c.swaWindowSize)
return posWindowStart >= first && pos <= last+1
return posWindowStart >= lastWindowStart
}
func (c *Causal) shift(seq int, beginIndex, offset int32) error {

View File

@@ -96,86 +96,6 @@ func TestSWA(t *testing.T) {
testCache(t, backend, cache, tests)
}
// TestSWASeparateBatches feeds two sequences through a sliding-window
// attention cache in alternating two-token batches and checks the cached
// values and attention masks after each batch. In the expected masks, x
// (negative infinity) marks a cached position that must not be attended to.
func TestSWASeparateBatches(t *testing.T) {
	backend := &testBackend{}
	// Window size 1 — presumably the smallest window that still exercises
	// sliding behavior; confirm against NewSWACache semantics.
	cache := NewSWACache(1, nil)
	defer cache.Close()
	// 2 sequences, capacity 16 cells; the final argument is presumably the
	// max batch size — TODO confirm against Causal.Init's signature.
	cache.Init(backend, ml.DTypeF16, 2, 16, 2)
	x := float32(math.Inf(-1)) // -inf == masked out
	tests := []testCase{
		{
			// seq 0, positions 0-1: first batch, cache starts empty
			name:          "First seq 0",
			in:            []float32{1, 2},
			inShape:       []int{1, 1, 2},
			seqs:          []int{0, 0},
			pos:           []int32{0, 1},
			expected:      []float32{1, 2},
			expectedShape: []int{1, 1, 2},
			expectedMask: []float32{
				0, x,
				0, 0,
			},
		},
		{
			// seq 0, positions 2-3: position 0's value has been dropped
			// from the expected cache contents after the window slid
			name:          "Second seq 0",
			in:            []float32{3, 4},
			inShape:       []int{1, 1, 2},
			seqs:          []int{0, 0},
			pos:           []int32{2, 3},
			expected:      []float32{2, 3, 4},
			expectedShape: []int{1, 1, 3},
			expectedMask: []float32{
				0, 0, x,
				x, 0, 0,
			},
		},
		{
			// seq 1, positions 0-1: a second, independent sequence
			name:          "First seq 1",
			in:            []float32{5, 6},
			inShape:       []int{1, 1, 2},
			seqs:          []int{1, 1},
			pos:           []int32{0, 1},
			expected:      []float32{5, 6},
			expectedShape: []int{1, 1, 2},
			expectedMask: []float32{
				0, x,
				0, 0,
			},
		},
		{
			// seq 1, positions 2-3: the expected cache view interleaves
			// seq 1's surviving values with cells still held by seq 0;
			// the mask hides seq 0's entries from seq 1
			name:          "Second seq 1",
			in:            []float32{7, 8},
			inShape:       []int{1, 1, 2},
			seqs:          []int{1, 1},
			pos:           []int32{2, 3},
			expected:      []float32{6, 3, 4, 7, 8},
			expectedShape: []int{1, 1, 5},
			expectedMask: []float32{
				0, x, x, 0, x,
				x, x, x, 0, 0,
			},
		},
		{
			// seq 0 again, positions 4-5: resuming after seq 1 ran in
			// between; seq 0's earlier cells (3, 4) are still visible
			// per the mask rules encoded below
			name:          "Third seq 0",
			in:            []float32{9, 10},
			inShape:       []int{1, 1, 2},
			seqs:          []int{0, 0},
			pos:           []int32{4, 5},
			expected:      []float32{9, 10, 3, 4},
			expectedShape: []int{1, 1, 4},
			expectedMask: []float32{
				0, x, x, 0,
				0, 0, x, x,
			},
		},
	}
	testCache(t, backend, cache, tests)
}
func TestSWAMem(t *testing.T) {
backend := &testBackend{}
cache := NewSWAMemCache(1, 3, nil)
@@ -511,15 +431,15 @@ func TestCanResume(t *testing.T) {
defer context.Close()
err := cache.StartForward(context, input.Batch{
Positions: []int32{0, 1, 2, 3, 4},
Sequences: []int{0, 0, 0, 0, 0},
Positions: []int32{0, 1, 2, 3},
Sequences: []int{0, 0, 0, 0},
}, false)
if err != nil {
t.Fatalf("StartForward failed: %v", err)
}
cache.SetLayer(0)
tensor := context.FromFloatSlice([]float32{1, 2, 3, 4, 5}, 1, 1, 5)
tensor := context.FromFloatSlice([]float32{1, 2, 3, 4}, 1, 1, 4)
cache.Put(context, tensor, tensor)
// with window size 4, nothing has slid out of the window yet
@@ -535,21 +455,18 @@ func TestCanResume(t *testing.T) {
if !cache.CanResume(0, 3) {
t.Errorf("CanResume(0, 3) = false, want true (latest position)")
}
if !cache.CanResume(0, 4) {
t.Errorf("CanResume(0, 4) = false, want true (latest position)")
}
// shift window by adding position 5
// shift window by adding position 4
err = cache.StartForward(context, input.Batch{
Positions: []int32{5},
Sequences: []int{0},
Positions: []int32{4, 5},
Sequences: []int{0, 0},
}, false)
if err != nil {
t.Fatalf("StartForward failed: %v", err)
}
cache.SetLayer(0)
tensor = context.FromFloatSlice([]float32{6}, 1, 1, 1)
tensor = context.FromFloatSlice([]float32{5, 6}, 1, 1, 2)
cache.Put(context, tensor, tensor)
// only the latest position has overlapping windows
@@ -586,28 +503,28 @@ func TestCanResumeSWAMem(t *testing.T) {
defer context.Close()
err := cache.StartForward(context, input.Batch{
Positions: []int32{0, 1, 2, 3, 4, 5, 6},
Sequences: []int{0, 0, 0, 0, 0, 0, 0},
Positions: []int32{0, 1, 2, 3, 4, 5},
Sequences: []int{0, 0, 0, 0, 0, 0},
}, false)
if err != nil {
t.Fatalf("StartForward failed: %v", err)
}
cache.SetLayer(0)
tensor := context.FromFloatSlice([]float32{1, 2, 3, 4, 5, 6, 7}, 1, 1, 7)
tensor := context.FromFloatSlice([]float32{1, 2, 3, 4, 5, 6}, 1, 1, 6)
cache.Put(context, tensor, tensor)
// shift window by adding position 7
// shift window by adding position 6
err = cache.StartForward(context, input.Batch{
Positions: []int32{7},
Sequences: []int{0},
Positions: []int32{6, 7},
Sequences: []int{0, 0},
}, false)
if err != nil {
t.Fatalf("StartForward failed: %v", err)
}
cache.SetLayer(0)
tensor = context.FromFloatSlice([]float32{8}, 1, 1, 1)
tensor = context.FromFloatSlice([]float32{7, 8}, 1, 1, 2)
cache.Put(context, tensor, tensor)
// only the latest position has overlapping windows

View File

@@ -266,18 +266,11 @@ func estimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []strin
}
// Only include GPUs that can fit the graph, gpu minimum, the layer buffer and at least more layer
if gpus[i].FreeMemory < overhead+gzo+max(graphPartialOffload, graphFullOffload)+gpus[i].MinimumMemory+2*layerSize {
var compute string
if gpus[i].Library == "ROCm" {
compute = fmt.Sprintf("gfx%x%02x", gpus[i].ComputeMajor, gpus[i].ComputeMinor)
} else {
compute = fmt.Sprintf("%d.%d", gpus[i].ComputeMajor, gpus[i].ComputeMinor)
}
slog.Debug("gpu has too little memory to allocate any layers",
"id", gpus[i].ID,
"library", gpus[i].Library,
"variant", gpus[i].Variant,
"compute", compute,
"compute", gpus[i].Compute,
"driver", fmt.Sprintf("%d.%d", gpus[i].DriverMajor, gpus[i].DriverMinor),
"name", gpus[i].Name,
"total", format.HumanBytes2(gpus[i].TotalMemory),

View File

@@ -1486,10 +1486,7 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
serverReq.Header.Set("Content-Type", "application/json")
res, err := http.DefaultClient.Do(serverReq)
if err != nil && errors.Is(err, context.Canceled) {
// client closed connection
return err
} else if err != nil {
if err != nil {
slog.Error("post predict", "error", err)
return errors.New("model runner has unexpectedly stopped, this may be due to resource limitations or an internal error, check ollama server logs for details")
}

View File

@@ -14,6 +14,7 @@ import (
"github.com/gin-gonic/gin"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/openai"
@@ -29,6 +30,16 @@ var (
True = true
)
// makeArgs builds an api.ToolCallFunctionArguments from alternating
// key/value pairs, preserving the order the keys are given in. Keys must
// be strings and pairs must have even length; violations panic, which is
// acceptable for a test-only helper.
func makeArgs(pairs ...any) api.ToolCallFunctionArguments {
	if len(pairs)%2 != 0 {
		panic("makeArgs: pairs must be alternating key/value, got odd length")
	}
	args := api.NewToolCallFunctionArguments()
	for i := 0; i < len(pairs); i += 2 {
		// the type assertion panics with a clear message if a key isn't a string
		args.Set(pairs[i].(string), pairs[i+1])
	}
	return args
}
func captureRequestMiddleware(capturedRequest any) gin.HandlerFunc {
return func(c *gin.Context) {
bodyBytes, _ := io.ReadAll(c.Request.Body)
@@ -220,10 +231,7 @@ func TestChatMiddleware(t *testing.T) {
{
Function: api.ToolCallFunction{
Name: "get_current_weather",
Arguments: map[string]any{
"location": "Paris, France",
"format": "celsius",
},
Arguments: makeArgs("location", "Paris, France", "format", "celsius"),
},
},
},
@@ -259,10 +267,7 @@ func TestChatMiddleware(t *testing.T) {
{
Function: api.ToolCallFunction{
Name: "get_current_weather",
Arguments: map[string]any{
"location": "Paris, France",
"format": "celsius",
},
Arguments: makeArgs("location", "Paris, France", "format", "celsius"),
},
},
},
@@ -297,10 +302,7 @@ func TestChatMiddleware(t *testing.T) {
{
Function: api.ToolCallFunction{
Name: "get_current_weather",
Arguments: map[string]any{
"location": "Paris, France",
"format": "celsius",
},
Arguments: makeArgs("location", "Paris, France", "format", "celsius"),
},
},
},
@@ -336,10 +338,7 @@ func TestChatMiddleware(t *testing.T) {
{
Function: api.ToolCallFunction{
Name: "get_current_weather",
Arguments: map[string]any{
"location": "Paris, France",
"format": "celsius",
},
Arguments: makeArgs("location", "Paris, France", "format", "celsius"),
},
},
},
@@ -375,10 +374,7 @@ func TestChatMiddleware(t *testing.T) {
{
Function: api.ToolCallFunction{
Name: "get_current_weather",
Arguments: map[string]any{
"location": "Paris, France",
"format": "celsius",
},
Arguments: makeArgs("location", "Paris, France", "format", "celsius"),
},
},
},
@@ -419,10 +415,7 @@ func TestChatMiddleware(t *testing.T) {
{
Function: api.ToolCallFunction{
Name: "get_current_weather",
Arguments: map[string]any{
"location": "Paris, France",
"format": "celsius",
},
Arguments: makeArgs("location", "Paris, France", "format", "celsius"),
},
},
},
@@ -484,26 +477,22 @@ func TestChatMiddleware(t *testing.T) {
Function: api.ToolFunction{
Name: "get_weather",
Description: "Get the current weather",
Parameters: struct {
Type string `json:"type"`
Defs any `json:"$defs,omitempty"`
Items any `json:"items,omitempty"`
Required []string `json:"required"`
Properties map[string]api.ToolProperty `json:"properties"`
}{
Type: "object",
Required: []string{"location"},
Properties: map[string]api.ToolProperty{
"location": {
Parameters: api.NewToolFunctionParametersWithProps(
"object",
[]string{"location"},
func() *api.ToolProperties {
props := api.NewToolProperties()
props.Set("location", api.ToolProperty{
Type: api.PropertyType{"string"},
Description: "The city and state",
},
"unit": {
})
props.Set("unit", api.ToolProperty{
Type: api.PropertyType{"string"},
Enum: []any{"celsius", "fahrenheit"},
},
},
},
})
return props
}(),
),
},
},
},
@@ -557,7 +546,7 @@ func TestChatMiddleware(t *testing.T) {
}
return
}
if diff := cmp.Diff(&tc.req, capturedRequest); diff != "" {
if diff := cmp.Diff(&tc.req, capturedRequest, cmpopts.IgnoreUnexported(api.ToolCallFunctionArguments{}, api.ToolProperties{}, api.ToolFunctionParameters{})); diff != "" {
t.Fatalf("requests did not match: %+v", diff)
}
if diff := cmp.Diff(tc.err, errResp); diff != "" {

View File

@@ -268,17 +268,17 @@ func parseToolCall(raw qwenEventRawToolCall, tools []api.Tool) (api.ToolCall, er
}
}
toolCall.Function.Arguments = make(api.ToolCallFunctionArguments)
toolCall.Function.Arguments = api.NewToolCallFunctionArguments()
for _, parameter := range functionCall.Parameters {
// Look up the parameter type if we found the tool
var paramType api.PropertyType
if matchedTool != nil && matchedTool.Function.Parameters.Properties != nil {
if prop, ok := matchedTool.Function.Parameters.Properties[parameter.Name]; ok {
if matchedTool != nil && matchedTool.Function.Parameters.GetProperties() != nil {
if prop, ok := matchedTool.Function.Parameters.GetProperties().Get(parameter.Name); ok {
paramType = prop.Type
}
}
toolCall.Function.Arguments[parameter.Name] = parseValue(parameter.Value, paramType)
toolCall.Function.Arguments.Set(parameter.Name, parseValue(parameter.Value, paramType))
}
return toolCall, nil

View File

@@ -11,10 +11,25 @@ import (
func tool(name string, props map[string]api.ToolProperty) api.Tool {
t := api.Tool{Type: "function", Function: api.ToolFunction{Name: name}}
t.Function.Parameters.Type = "object"
t.Function.Parameters.Properties = props
p := api.NewToolProperties()
for k, v := range props {
p.Set(k, v)
}
t.Function.Parameters.SetProperties(p)
return t
}
// makeArgs builds an api.ToolCallFunctionArguments from alternating
// key/value pairs, preserving the order the keys are given in. Keys must
// be strings and pairs must have even length; violations panic, which is
// acceptable for a test-only helper.
func makeArgs(pairs ...any) api.ToolCallFunctionArguments {
	if len(pairs)%2 != 0 {
		panic("makeArgs: pairs must be alternating key/value, got odd length")
	}
	args := api.NewToolCallFunctionArguments()
	for i := 0; i < len(pairs); i += 2 {
		// the type assertion panics with a clear message if a key isn't a string
		args.Set(pairs[i].(string), pairs[i+1])
	}
	return args
}
func TestQwenParserStreaming(t *testing.T) {
type step struct {
input string
@@ -354,10 +369,7 @@ celsius
wantToolCall: api.ToolCall{
Function: api.ToolCallFunction{
Name: "get_current_temperature",
Arguments: map[string]any{
"location": "San Francisco",
"unit": "celsius",
},
Arguments: makeArgs("location", "San Francisco", "unit", "celsius"),
},
},
},
@@ -375,10 +387,10 @@ celsius
wantToolCall: api.ToolCall{
Function: api.ToolCallFunction{
Name: "get current temperature",
Arguments: map[string]any{
"location with spaces": "San Francisco",
"unit with spaces": "celsius",
},
Arguments: makeArgs(
"location with spaces", "San Francisco",
"unit with spaces", "celsius",
),
},
},
},
@@ -400,10 +412,10 @@ San Francisco
wantToolCall: api.ToolCall{
Function: api.ToolCallFunction{
Name: "\"get current temperature\"",
Arguments: map[string]any{
"\"location with spaces\"": "San Francisco",
"\"unit with spaces\"": "\"celsius\"",
},
Arguments: makeArgs(
"\"location with spaces\"", "San Francisco",
"\"unit with spaces\"", "\"celsius\"",
),
},
},
},
@@ -434,12 +446,12 @@ true
wantToolCall: api.ToolCall{
Function: api.ToolCallFunction{
Name: "calculate",
Arguments: map[string]any{
"x": 3.14,
"y": 42,
"enabled": true,
"items": []any{"a", "b", "c"},
},
Arguments: makeArgs(
"x", 3.14,
"y", 42,
"enabled", true,
"items", []any{"a", "b", "c"},
),
},
},
},
@@ -455,9 +467,7 @@ ls && echo "done"
wantToolCall: api.ToolCall{
Function: api.ToolCallFunction{
Name: "exec",
Arguments: map[string]any{
"command": "ls && echo \"done\"",
},
Arguments: makeArgs("command", "ls && echo \"done\""),
},
},
},
@@ -472,9 +482,7 @@ ls && echo "a > b and a < b"
wantToolCall: api.ToolCall{
Function: api.ToolCallFunction{
Name: "exec",
Arguments: map[string]any{
"command": "ls && echo \"a > b and a < b\"",
},
Arguments: makeArgs("command", "ls && echo \"a > b and a < b\""),
},
},
},
@@ -492,10 +500,7 @@ Hello! 你好! 🌟 مرحبا
wantToolCall: api.ToolCall{
Function: api.ToolCallFunction{
Name: "获取天气",
Arguments: map[string]any{
"城市": "北京",
"message": "Hello! 你好! 🌟 مرحبا",
},
Arguments: makeArgs("城市", "北京", "message", "Hello! 你好! 🌟 مرحبا"),
},
},
},

View File

@@ -94,26 +94,28 @@ func Qwen3CoderRenderer(messages []api.Message, tools []api.Tool, _ *api.ThinkVa
}
sb.WriteString("\n<parameters>")
for name, prop := range tool.Function.Parameters.Properties {
sb.WriteString("\n<parameter>")
sb.WriteString("\n<name>" + name + "</name>")
if tool.Function.Parameters.GetProperties() != nil {
for name, prop := range tool.Function.Parameters.Properties() {
sb.WriteString("\n<parameter>")
sb.WriteString("\n<name>" + name + "</name>")
if len(prop.Type) > 0 {
sb.WriteString("\n<type>" + formatToolDefinitionType(prop.Type) + "</type>")
if len(prop.Type) > 0 {
sb.WriteString("\n<type>" + formatToolDefinitionType(prop.Type) + "</type>")
}
if prop.Description != "" {
sb.WriteString("\n<description>" + prop.Description + "</description>")
}
// Render any additional keys not already handled
handledKeys := map[string]bool{
"type": true,
"description": true,
}
sb.WriteString(renderAdditionalKeys(prop, handledKeys))
sb.WriteString("\n</parameter>")
}
if prop.Description != "" {
sb.WriteString("\n<description>" + prop.Description + "</description>")
}
// Render any additional keys not already handled
handledKeys := map[string]bool{
"type": true,
"description": true,
}
sb.WriteString(renderAdditionalKeys(prop, handledKeys))
sb.WriteString("\n</parameter>")
}
// Render extra keys for parameters (everything except 'type' and 'properties')
@@ -145,7 +147,7 @@ func Qwen3CoderRenderer(messages []api.Message, tools []api.Tool, _ *api.ThinkVa
}
for _, toolCall := range message.ToolCalls {
sb.WriteString("\n<tool_call>\n<function=" + toolCall.Function.Name + ">")
for name, value := range toolCall.Function.Arguments {
for name, value := range toolCall.Function.Arguments.All() {
valueStr := formatToolCallArgument(value)
sb.WriteString("\n<parameter=" + name + ">\n" + valueStr + "\n</parameter>")
}

View File

@@ -4,9 +4,32 @@ import (
"testing"
"github.com/google/go-cmp/cmp"
"github.com/ollama/ollama/api"
)
// makeArgs builds an api.ToolCallFunctionArguments from alternating
// key/value pairs, preserving the order the keys are given in. Keys must
// be strings and pairs must have even length; violations panic, which is
// acceptable for a test-only helper.
func makeArgs(pairs ...any) api.ToolCallFunctionArguments {
	if len(pairs)%2 != 0 {
		panic("makeArgs: pairs must be alternating key/value, got odd length")
	}
	args := api.NewToolCallFunctionArguments()
	for i := 0; i < len(pairs); i += 2 {
		// the type assertion panics with a clear message if a key isn't a string
		args.Set(pairs[i].(string), pairs[i+1])
	}
	return args
}
// makeProps builds a *api.ToolProperties from alternating name/property
// pairs, preserving the order the names are given in. Names must be
// strings, values must be api.ToolProperty, and pairs must have even
// length; violations panic, which is acceptable for a test-only helper.
func makeProps(pairs ...any) *api.ToolProperties {
	if len(pairs)%2 != 0 {
		panic("makeProps: pairs must be alternating name/property, got odd length")
	}
	props := api.NewToolProperties()
	for i := 0; i < len(pairs); i += 2 {
		// the type assertions panic with clear messages on wrong types
		props.Set(pairs[i].(string), pairs[i+1].(api.ToolProperty))
	}
	return props
}
func TestQwen3CoderRenderer(t *testing.T) {
tests := []struct {
name string
@@ -38,10 +61,8 @@ Hello, how are you?<|im_end|>
ToolCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: map[string]any{
"unit": "fahrenheit",
},
Name: "get_weather",
Arguments: makeArgs("unit", "fahrenheit"),
},
},
},
@@ -53,18 +74,13 @@ Hello, how are you?<|im_end|>
{Function: api.ToolFunction{
Name: "get_weather",
Description: "Get the current weather in a given location",
Parameters: api.ToolFunctionParameters{
Required: []string{"unit"},
Properties: map[string]api.ToolProperty{
"unit": {Type: api.PropertyType{"string"}, Enum: []any{"celsius", "fahrenheit"}, Description: "The unit of temperature"},
// TODO(drifkin): add multiple params back once we have predictable
// order via some sort of ordered map type (see
// <https://github.com/ollama/ollama/issues/12244>)
/*
"location": {Type: api.PropertyType{"string"}, Description: "The city and state, e.g. San Francisco, CA"},
*/
},
},
Parameters: api.NewToolFunctionParametersWithProps(
"object",
[]string{"unit"},
makeProps(
"unit", api.ToolProperty{Type: api.PropertyType{"string"}, Enum: []any{"celsius", "fahrenheit"}, Description: "The unit of temperature"},
),
),
}},
},
expected: `<|im_start|>system
@@ -140,19 +156,19 @@ That sounds nice! What about New York?<|im_end|>
{Role: "system", Content: "You are a helpful assistant with access to tools."},
{Role: "user", Content: "call double(1) and triple(2)"},
{Role: "assistant", Content: "I'll call double(1) and triple(2) for you.", ToolCalls: []api.ToolCall{
{Function: api.ToolCallFunction{Name: "double", Arguments: map[string]any{"number": "1"}}},
{Function: api.ToolCallFunction{Name: "triple", Arguments: map[string]any{"number": "2"}}},
{Function: api.ToolCallFunction{Name: "double", Arguments: makeArgs("number", "1")}},
{Function: api.ToolCallFunction{Name: "triple", Arguments: makeArgs("number", "2")}},
}},
{Role: "tool", Content: "{\"number\": 2}", ToolName: "double"},
{Role: "tool", Content: "{\"number\": 6}", ToolName: "triple"},
},
tools: []api.Tool{
{Function: api.ToolFunction{Name: "double", Description: "Double a number", Parameters: api.ToolFunctionParameters{Properties: map[string]api.ToolProperty{
"number": {Type: api.PropertyType{"string"}, Description: "The number to double"},
}}}},
{Function: api.ToolFunction{Name: "triple", Description: "Triple a number", Parameters: api.ToolFunctionParameters{Properties: map[string]api.ToolProperty{
"number": {Type: api.PropertyType{"string"}, Description: "The number to triple"},
}}}},
{Function: api.ToolFunction{Name: "double", Description: "Double a number", Parameters: api.NewToolFunctionParametersWithProps("object", nil, makeProps(
"number", api.ToolProperty{Type: api.PropertyType{"string"}, Description: "The number to double"},
))}},
{Function: api.ToolFunction{Name: "triple", Description: "Triple a number", Parameters: api.NewToolFunctionParametersWithProps("object", nil, makeProps(
"number", api.ToolProperty{Type: api.PropertyType{"string"}, Description: "The number to triple"},
))}},
},
expected: `<|im_start|>system
You are a helpful assistant with access to tools.
@@ -258,10 +274,8 @@ I'll tell you something interesting about cats`,
{Role: "user", Content: "call tool"},
{Role: "assistant", ToolCalls: []api.ToolCall{
{Function: api.ToolCallFunction{
Name: "echo",
Arguments: map[string]any{
"payload": map[string]any{"foo": "bar"},
},
Name: "echo",
Arguments: makeArgs("payload", map[string]any{"foo": "bar"}),
}},
}},
{Role: "tool", Content: "{\"payload\": {\"foo\": \"bar\"}}", ToolName: "echo"},
@@ -368,3 +382,62 @@ func TestQwen3ToolDefinitionTypes(t *testing.T) {
})
}
}
// TestMultipleParametersNonDeterministic guards against the historical bug in
// which tools with multiple parameters were rendered in a random order,
// because tool arguments and properties were stored in Go maps (whose
// iteration order is intentionally unpredictable). With the ordered-map types
// the renderer must produce identical output on every run.
// See https://github.com/ollama/ollama/issues/12244
func TestMultipleParametersNonDeterministic(t *testing.T) {
	tools := []api.Tool{
		{Function: api.ToolFunction{
			Name:        "get_weather",
			Description: "Get the current weather",
			Parameters: api.NewToolFunctionParametersWithProps(
				"object",
				[]string{"location", "unit"},
				makeProps(
					"location", api.ToolProperty{Type: api.PropertyType{"string"}, Description: "The city and state"},
					"unit", api.ToolProperty{Type: api.PropertyType{"string"}, Description: "The temperature unit"},
					"format", api.ToolProperty{Type: api.PropertyType{"string"}, Description: "The output format"},
				),
			),
		}},
	}
	msgs := []api.Message{
		{Role: "user", Content: "What's the weather?"},
		{Role: "assistant", ToolCalls: []api.ToolCall{
			{Function: api.ToolCallFunction{
				Name: "get_weather",
				Arguments: makeArgs(
					"location", "San Francisco, CA",
					"unit", "fahrenheit",
					"format", "detailed",
				),
			}},
		}},
	}

	// Render repeatedly; a map-backed implementation would, with high
	// probability, produce several distinct orderings across this many runs.
	outputs := make(map[string]bool)
	for range 15 {
		rendered, err := Qwen3CoderRenderer(msgs, tools, nil)
		if err != nil {
			t.Fatal(err)
		}
		outputs[rendered] = true
	}

	if len(outputs) > 1 {
		// Log a couple of variants so the ordering difference is visible.
		count := 0
		for output := range outputs {
			if count < 2 {
				t.Logf("\nOutput variant %d:\n%s", count+1, output)
				count++
			}
		}
		t.Fatalf("Renderer produced %d different outputs across 15 runs (expected deterministic output)", len(outputs))
	}
}

View File

@@ -9,7 +9,6 @@ import (
"log/slog"
"math/rand"
"net/http"
"slices"
"strings"
"time"
@@ -83,7 +82,7 @@ type StreamOptions struct {
}
type Reasoning struct {
Effort string `json:"effort,omitempty"`
Effort *string `json:"effort,omitempty"`
}
type ChatCompletionRequest struct {
@@ -567,23 +566,13 @@ func FromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
}
var think *api.ThinkValue
var effort string
if r.Reasoning != nil {
effort = r.Reasoning.Effort
} else if r.ReasoningEffort != nil {
effort = *r.ReasoningEffort
}
if effort != "" {
if !slices.Contains([]string{"high", "medium", "low", "none"}, effort) {
return nil, fmt.Errorf("invalid reasoning value: '%s' (must be \"high\", \"medium\", \"low\", or \"none\")", effort)
think = &api.ThinkValue{
Value: *r.Reasoning.Effort,
}
if effort == "none" {
think = &api.ThinkValue{Value: false}
} else {
think = &api.ThinkValue{Value: effort}
} else if r.ReasoningEffort != nil {
think = &api.ThinkValue{
Value: *r.ReasoningEffort,
}
}

View File

@@ -330,16 +330,12 @@ func (s *Server) GenerateHandler(c *gin.Context) {
if req.Suffix != "" {
caps = append(caps, model.CapabilityInsert)
}
modelCaps := m.Capabilities()
if req.Think != nil {
if req.Think != nil && req.Think.Bool() {
caps = append(caps, model.CapabilityThinking)
} else {
// add thinking if the model supports it
if slices.Contains(modelCaps, model.CapabilityThinking) {
caps = append(caps, model.CapabilityThinking)
req.Think = &api.ThinkValue{Value: true}
}
// TODO(drifkin): consider adding a warning if it's false and the model
// doesn't support thinking. It's not strictly required, but it can be a
// hint that the user is on an older qwen3/r1 model that doesn't have an
// updated template supporting thinking
}
r, m, opts, err := s.scheduleRunner(c.Request.Context(), name.String(), caps, req.Options, req.KeepAlive)
@@ -1875,16 +1871,8 @@ func (s *Server) ChatHandler(c *gin.Context) {
if len(req.Tools) > 0 {
caps = append(caps, model.CapabilityTools)
}
modelCaps := m.Capabilities()
if req.Think != nil {
if req.Think != nil && req.Think.Bool() {
caps = append(caps, model.CapabilityThinking)
} else {
// add thinking if the model supports it
if slices.Contains(modelCaps, model.CapabilityThinking) {
caps = append(caps, model.CapabilityThinking)
req.Think = &api.ThinkValue{Value: true}
}
}
r, m, opts, err := s.scheduleRunner(c.Request.Context(), name.String(), caps, req.Options, req.KeepAlive)
@@ -1979,167 +1967,88 @@ func (s *Server) ChatHandler(c *gin.Context) {
toolParser = tools.NewParser(m.Template.Template, req.Tools)
}
type structuredOutputsState int
const (
structuredOutputsState_None structuredOutputsState = iota
structuredOutputsState_ReadyToApply
structuredOutputsState_Applying
)
ch := make(chan any)
go func() {
defer close(ch)
structuredOutputsState := structuredOutputsState_None
for {
var tb strings.Builder
currentFormat := req.Format
// structured outputs via double request is enabled when:
// 1. the model supports the thinking capability and
// 2. it uses a built-in parser or our generic thinking parser
// Note that the current approach does not work for (potential future)
// non-thinking models that emit anything before actual content. This
// current approach uses the transition from parsed thinking content to
// parsed non-thinking content as the signal to turn constraining on
if req.Format != nil && structuredOutputsState == structuredOutputsState_None && ((builtinParser != nil || thinkingState != nil) && slices.Contains(m.Capabilities(), model.CapabilityThinking)) {
currentFormat = nil
if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
Prompt: prompt,
Images: images,
Format: req.Format,
Options: opts,
}, func(r llm.CompletionResponse) {
res := api.ChatResponse{
Model: req.Model,
CreatedAt: time.Now().UTC(),
Message: api.Message{Role: "assistant", Content: r.Content},
Done: r.Done,
Metrics: api.Metrics{
PromptEvalCount: r.PromptEvalCount,
PromptEvalDuration: r.PromptEvalDuration,
EvalCount: r.EvalCount,
EvalDuration: r.EvalDuration,
},
}
if r.Done {
res.DoneReason = r.DoneReason.String()
res.TotalDuration = time.Since(checkpointStart)
res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
}
// sets up new context given parent context per request
ctx, cancel := context.WithCancel(c.Request.Context())
err := r.Completion(ctx, llm.CompletionRequest{
Prompt: prompt,
Images: images,
Format: currentFormat,
Options: opts,
}, func(r llm.CompletionResponse) {
res := api.ChatResponse{
Model: req.Model,
CreatedAt: time.Now().UTC(),
Message: api.Message{Role: "assistant", Content: r.Content},
Done: r.Done,
Metrics: api.Metrics{
PromptEvalCount: r.PromptEvalCount,
PromptEvalDuration: r.PromptEvalDuration,
EvalCount: r.EvalCount,
EvalDuration: r.EvalDuration,
},
}
if r.Done {
res.DoneReason = r.DoneReason.String()
res.TotalDuration = time.Since(checkpointStart)
res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
}
if builtinParser != nil {
slog.Log(context.TODO(), logutil.LevelTrace, "builtin parser input", "parser", m.Config.Parser, "content", r.Content)
if builtinParser != nil {
slog.Log(context.TODO(), logutil.LevelTrace, "builtin parser input", "parser", m.Config.Parser, "content", r.Content)
content, thinking, toolCalls, err := builtinParser.Add(r.Content, r.Done)
if err != nil {
ch <- gin.H{"error": err.Error()}
return
}
res.Message.Content = content
res.Message.Thinking = thinking
res.Message.ToolCalls = toolCalls
tb.WriteString(thinking)
// we are now receiving content from the model - we should start applying structured outputs
if structuredOutputsState == structuredOutputsState_None && req.Format != nil && tb.String() != "" && res.Message.Content != "" {
structuredOutputsState = structuredOutputsState_ReadyToApply
cancel()
return
}
if res.Message.Content != "" || res.Message.Thinking != "" || len(res.Message.ToolCalls) > 0 || r.Done {
slog.Log(context.TODO(), logutil.LevelTrace, "builtin parser output", "parser", m.Config.Parser, "content", content, "thinking", thinking, "toolCalls", toolCalls, "done", r.Done)
ch <- res
} else {
slog.Log(context.TODO(), logutil.LevelTrace, "builtin parser empty output", "parser", m.Config.Parser)
}
return
}
if thinkingState != nil {
thinkingContent, remainingContent := thinkingState.AddContent(res.Message.Content)
if thinkingContent == "" && remainingContent == "" && !r.Done {
// need to accumulate more to decide what to send
return
}
res.Message.Thinking = thinkingContent
tb.WriteString(thinkingContent)
// emit the collected thinking text before restarting with structured outputs and clear unstructured content
// to avoid leaking mixed tokens like "</think>Hello"
if structuredOutputsState == structuredOutputsState_None && req.Format != nil && tb.String() != "" && remainingContent != "" {
structuredOutputsState = structuredOutputsState_ReadyToApply
res.Message.Content = ""
ch <- res
cancel()
return
}
res.Message.Content = remainingContent
}
if len(req.Tools) > 0 {
toolCalls, content := toolParser.Add(res.Message.Content)
if len(content) > 0 {
res.Message.Content = content
} else if len(toolCalls) > 0 {
res.Message.ToolCalls = toolCalls
res.Message.Content = ""
} else if res.Message.Thinking != "" {
// don't return
} else {
if r.Done {
res.Message.Content = toolParser.Content()
ch <- res
}
return
}
}
ch <- res
})
if err != nil {
if structuredOutputsState == structuredOutputsState_ReadyToApply && strings.Contains(err.Error(), "context canceled") && c.Request.Context().Err() == nil {
// only ignores error if it's a context cancellation due to setting structured outputs
} else {
ch <- gin.H{"error": err.Error()}
return
}
}
// ignored structured outputs cancellation falls through to here, start a new request with the structured outputs and updated prompt. use the
if structuredOutputsState == structuredOutputsState_ReadyToApply {
structuredOutputsState = structuredOutputsState_Applying
msg := api.Message{
Role: "assistant",
Thinking: tb.String(),
}
msgs = append(msgs, msg)
prompt, _, err = chatPrompt(c.Request.Context(), m, r.Tokenize, opts, msgs, processedTools, req.Think)
content, thinking, toolCalls, err := builtinParser.Add(r.Content, r.Done)
if err != nil {
slog.Error("chat prompt error applying structured outputs", "error", err)
ch <- gin.H{"error": err.Error()}
return
}
// force constraining by terminating thinking header, the parser is already at this state
// when the last message is thinking, the rendered for gpt-oss cannot disambiguate between having the
// model continue thinking or ending thinking and outputting the final message.
// TODO(parthsareen): consider adding prefill disambiguation logic to the renderer for structured outputs.
if shouldUseHarmony(m) || (builtinParser != nil && m.Config.Parser == "harmony") {
prompt += "<|end|><|start|>assistant<|channel|>final<|message|>"
res.Message.Content = content
res.Message.Thinking = thinking
res.Message.ToolCalls = toolCalls
if res.Message.Content != "" || res.Message.Thinking != "" || len(res.Message.ToolCalls) > 0 || r.Done {
slog.Log(context.TODO(), logutil.LevelTrace, "builtin parser output", "parser", m.Config.Parser, "content", content, "thinking", thinking, "toolCalls", toolCalls, "done", r.Done)
ch <- res
} else {
slog.Log(context.TODO(), logutil.LevelTrace, "builtin parser empty output", "parser", m.Config.Parser)
}
continue
return
}
break
if thinkingState != nil {
thinkingContent, remainingContent := thinkingState.AddContent(res.Message.Content)
if thinkingContent == "" && remainingContent == "" && !r.Done {
// need to accumulate more to decide what to send
return
}
res.Message.Content = remainingContent
res.Message.Thinking = thinkingContent
}
if len(req.Tools) > 0 {
toolCalls, content := toolParser.Add(res.Message.Content)
if len(content) > 0 {
res.Message.Content = content
} else if len(toolCalls) > 0 {
res.Message.ToolCalls = toolCalls
res.Message.Content = ""
} else if res.Message.Thinking != "" {
// don't return
} else {
if r.Done {
res.Message.Content = toolParser.Content()
ch <- res
}
return
}
}
ch <- res
}); err != nil {
ch <- gin.H{"error": err.Error()}
}
}()

View File

@@ -13,6 +13,7 @@ import (
"github.com/gin-gonic/gin"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/discover"
@@ -46,6 +47,16 @@ func (mockRunner) Tokenize(_ context.Context, s string) (tokens []int, err error
return
}
// makeArgs builds ordered tool-call arguments for tests from alternating
// key/value pairs, preserving insertion order. Keys must be strings; a
// malformed pair list fails fast with a descriptive panic instead of an
// opaque index-out-of-range or type-assertion error.
func makeArgs(pairs ...any) api.ToolCallFunctionArguments {
	if len(pairs)%2 != 0 {
		panic("makeArgs: arguments must be key/value pairs")
	}
	args := api.NewToolCallFunctionArguments()
	for i := 0; i < len(pairs); i += 2 {
		key, ok := pairs[i].(string)
		if !ok {
			panic("makeArgs: keys must be strings")
		}
		args.Set(key, pairs[i+1])
	}
	return args
}
func newMockServer(mock *mockRunner) func(discover.GpuInfoList, string, *ggml.GGML, []string, []string, api.Options, int) (llm.LlamaServer, error) {
return func(_ discover.GpuInfoList, _ string, _ *ggml.GGML, _, _ []string, _ api.Options, _ int) (llm.LlamaServer, error) {
return mock, nil
@@ -389,24 +400,26 @@ func TestGenerateChat(t *testing.T) {
Name: "get_weather",
Description: "Get the current weather",
Parameters: struct {
Type string `json:"type"`
Defs any `json:"$defs,omitempty"`
Items any `json:"items,omitempty"`
Required []string `json:"required"`
Properties map[string]api.ToolProperty `json:"properties"`
Type string `json:"type"`
Defs any `json:"$defs,omitempty"`
Items any `json:"items,omitempty"`
Required []string `json:"required"`
Properties *api.ToolProperties `json:"properties"`
}{
Type: "object",
Required: []string{"location"},
Properties: map[string]api.ToolProperty{
"location": {
Properties: func() *api.ToolProperties {
props := api.NewToolProperties()
props.Set("location", api.ToolProperty{
Type: api.PropertyType{"string"},
Description: "The city and state",
},
"unit": {
})
props.Set("unit", api.ToolProperty{
Type: api.PropertyType{"string"},
Enum: []any{"celsius", "fahrenheit"},
},
},
})
return props
}(),
},
},
},
@@ -459,15 +472,12 @@ func TestGenerateChat(t *testing.T) {
expectedToolCall := api.ToolCall{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: api.ToolCallFunctionArguments{
"location": "Seattle, WA",
"unit": "celsius",
},
Name: "get_weather",
Arguments: makeArgs("location", "Seattle, WA", "unit", "celsius"),
},
}
if diff := cmp.Diff(resp.Message.ToolCalls[0], expectedToolCall); diff != "" {
if diff := cmp.Diff(resp.Message.ToolCalls[0], expectedToolCall, cmpopts.IgnoreUnexported(api.ToolCallFunctionArguments{}, api.ToolProperties{})); diff != "" {
t.Errorf("tool call mismatch (-got +want):\n%s", diff)
}
})
@@ -480,24 +490,26 @@ func TestGenerateChat(t *testing.T) {
Name: "get_weather",
Description: "Get the current weather",
Parameters: struct {
Type string `json:"type"`
Defs any `json:"$defs,omitempty"`
Items any `json:"items,omitempty"`
Required []string `json:"required"`
Properties map[string]api.ToolProperty `json:"properties"`
Type string `json:"type"`
Defs any `json:"$defs,omitempty"`
Items any `json:"items,omitempty"`
Required []string `json:"required"`
Properties *api.ToolProperties `json:"properties"`
}{
Type: "object",
Required: []string{"location"},
Properties: map[string]api.ToolProperty{
"location": {
Properties: func() *api.ToolProperties {
props := api.NewToolProperties()
props.Set("location", api.ToolProperty{
Type: api.PropertyType{"string"},
Description: "The city and state",
},
"unit": {
})
props.Set("unit", api.ToolProperty{
Type: api.PropertyType{"string"},
Enum: []any{"celsius", "fahrenheit"},
},
},
})
return props
}(),
},
},
},
@@ -582,15 +594,12 @@ func TestGenerateChat(t *testing.T) {
expectedToolCall := api.ToolCall{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: api.ToolCallFunctionArguments{
"location": "Seattle, WA",
"unit": "celsius",
},
Name: "get_weather",
Arguments: makeArgs("location", "Seattle, WA", "unit", "celsius"),
},
}
if diff := cmp.Diff(finalToolCall, expectedToolCall); diff != "" {
if diff := cmp.Diff(finalToolCall, expectedToolCall, cmpopts.IgnoreUnexported(api.ToolCallFunctionArguments{}, api.ToolProperties{})); diff != "" {
t.Errorf("final tool call mismatch (-got +want):\n%s", diff)
}
})
@@ -1120,6 +1129,13 @@ func TestChatWithPromptEndingInThinkTag(t *testing.T) {
"The answer is 4.",
true)
testChatRequest(t, "thinking disabled but template still adds think tag",
"Simple question",
" My thoughts </think> The answer.",
"",
" My thoughts </think> The answer.",
false)
// Test streaming response with template-added <think>
t.Run("streaming with thinking", func(t *testing.T) {
var wg sync.WaitGroup
@@ -1191,238 +1207,4 @@ func TestChatWithPromptEndingInThinkTag(t *testing.T) {
t.Errorf("expected content %q, got %q", "Based on my analysis, the solution is straightforward.", got)
}
})
t.Run("structured outputs restart non-stream", func(t *testing.T) {
var (
requestsMu sync.Mutex
requests []llm.CompletionRequest
wg sync.WaitGroup
)
wg.Add(2)
format := json.RawMessage(`{"type":"object","properties":{"answer":{"type":"string"}}}`)
mock.CompletionFn = func(ctx context.Context, r llm.CompletionRequest, fn func(r llm.CompletionResponse)) error {
defer wg.Done()
requestsMu.Lock()
requests = append(requests, r)
callNum := len(requests)
requestsMu.Unlock()
switch callNum {
case 1:
fn(llm.CompletionResponse{
Content: " I am thinking through this problem. </think> {\"answer\":\"42\"}",
Done: false,
PromptEvalCount: 1,
PromptEvalDuration: 1,
})
select {
case <-ctx.Done():
return ctx.Err()
case <-time.After(time.Second):
t.Fatalf("timeout waiting for structured outputs cancellation")
return nil
}
case 2:
fn(llm.CompletionResponse{
Content: `{"answer":"42"}`,
Done: true,
DoneReason: llm.DoneReasonStop,
PromptEvalCount: 1,
PromptEvalDuration: 1,
EvalCount: 1,
EvalDuration: 1,
})
return nil
default:
t.Fatalf("unexpected number of completion calls: %d", callNum)
return nil
}
}
think := true
streamRequest := false
w := createRequest(t, s.ChatHandler, api.ChatRequest{
Model: "test-thinking",
Messages: []api.Message{{Role: "user", Content: "Please respond in JSON."}},
Think: &api.ThinkValue{Value: think},
Stream: &streamRequest,
Format: format,
})
wg.Wait()
mock.CompletionFn = nil
if w.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d", w.Code)
}
if len(requests) != 2 {
t.Fatalf("expected two completion calls, got %d", len(requests))
}
if requests[0].Format != nil {
t.Errorf("expected first completion format to be nil, got %q", requests[0].Format)
}
if !bytes.Equal([]byte(format), []byte(requests[1].Format)) {
t.Errorf("expected second completion format to match original format")
}
var resp api.ChatResponse
if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
t.Fatal(err)
}
if resp.Message.Thinking != "I am thinking through this problem. " {
t.Errorf("expected thinking %q, got %q", "I am thinking through this problem. ", resp.Message.Thinking)
}
if resp.Message.Content != `{"answer":"42"}` {
t.Errorf("expected content %q, got %q", `{"answer":"42"}`, resp.Message.Content)
}
if !resp.Done {
t.Errorf("expected response to be done")
}
if resp.DoneReason != "stop" {
t.Errorf("expected done reason stop, got %s", resp.DoneReason)
}
})
t.Run("structured outputs restart streaming", func(t *testing.T) {
var (
requestsMu sync.Mutex
requests []llm.CompletionRequest
wg sync.WaitGroup
)
wg.Add(2)
format := json.RawMessage(`{"type":"object","properties":{"answer":{"type":"string"}}}`)
mock.CompletionFn = func(ctx context.Context, r llm.CompletionRequest, fn func(r llm.CompletionResponse)) error {
defer wg.Done()
requestsMu.Lock()
requests = append(requests, r)
callNum := len(requests)
requestsMu.Unlock()
switch callNum {
case 1:
fn(llm.CompletionResponse{
Content: " I am thinking through this problem. </think> {\"answer\":\"42\"}",
Done: false,
PromptEvalCount: 1,
PromptEvalDuration: 1,
})
select {
case <-ctx.Done():
return ctx.Err()
case <-time.After(time.Second):
t.Fatalf("timeout waiting for structured outputs cancellation")
return nil
}
case 2:
fn(llm.CompletionResponse{
Content: `{"answer":"42"}`,
Done: true,
DoneReason: llm.DoneReasonStop,
PromptEvalCount: 1,
PromptEvalDuration: 1,
EvalCount: 1,
EvalDuration: 1,
})
return nil
default:
t.Fatalf("unexpected number of completion calls: %d", callNum)
return nil
}
}
think := true
streamRequest := true
w := createRequest(t, s.ChatHandler, api.ChatRequest{
Model: "test-thinking",
Messages: []api.Message{{Role: "user", Content: "Please respond in JSON."}},
Think: &api.ThinkValue{Value: think},
Stream: &streamRequest,
Format: format,
})
wg.Wait()
mock.CompletionFn = nil
if w.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d", w.Code)
}
if len(requests) != 2 {
t.Fatalf("expected two completion calls, got %d", len(requests))
}
if requests[0].Format != nil {
t.Errorf("expected first completion format to be nil, got %q", requests[0].Format)
}
if !bytes.Equal([]byte(format), []byte(requests[1].Format)) {
t.Errorf("expected second completion format to match original format")
}
decoder := json.NewDecoder(w.Body)
var events []api.ChatResponse
for {
var event api.ChatResponse
if err := decoder.Decode(&event); err == io.EOF {
break
} else if err != nil {
t.Fatal(err)
}
events = append(events, event)
if event.Done {
break
}
}
if len(events) < 2 {
t.Fatalf("expected at least two streaming events, got %d", len(events))
}
first := events[0]
if first.Message.Thinking != "I am thinking through this problem. " {
t.Errorf("expected first event thinking %q, got %q", "I am thinking through this problem. ", first.Message.Thinking)
}
if first.Message.Content != "" {
t.Errorf("expected first event content to be empty, got %q", first.Message.Content)
}
if first.Done {
t.Error("expected first event to be non-terminal")
}
last := events[len(events)-1]
if last.Message.Thinking != "" {
t.Errorf("expected final event thinking to be empty, got %q", last.Message.Thinking)
}
if last.Message.Content != `{"answer":"42"}` {
t.Errorf("expected final event content %q, got %q", `{"answer":"42"}`, last.Message.Content)
}
if !last.Done {
t.Error("expected final event to be done")
}
if last.DoneReason != "stop" {
t.Errorf("expected final done reason stop, got %s", last.DoneReason)
}
})
}

View File

@@ -13,6 +13,7 @@ import (
"time"
"github.com/gin-gonic/gin"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/discover"
"github.com/ollama/ollama/fs/ggml"
@@ -27,20 +28,22 @@ func getTestTools() []api.Tool {
Name: "get_weather",
Description: "Get the current weather in a given location",
Parameters: struct {
Type string `json:"type"`
Defs any `json:"$defs,omitempty"`
Items any `json:"items,omitempty"`
Required []string `json:"required"`
Properties map[string]api.ToolProperty `json:"properties"`
Type string `json:"type"`
Defs any `json:"$defs,omitempty"`
Items any `json:"items,omitempty"`
Required []string `json:"required"`
Properties *api.ToolProperties `json:"properties"`
}{
Type: "object",
Required: []string{"location"},
Properties: map[string]api.ToolProperty{
"location": {
Properties: func() *api.ToolProperties {
props := api.NewToolProperties()
props.Set("location", api.ToolProperty{
Type: api.PropertyType{"string"},
Description: "The city and state, e.g. San Francisco, CA",
},
},
})
return props
}(),
},
},
},
@@ -50,20 +53,22 @@ func getTestTools() []api.Tool {
Name: "calculate",
Description: "Calculate a mathematical expression",
Parameters: struct {
Type string `json:"type"`
Defs any `json:"$defs,omitempty"`
Items any `json:"items,omitempty"`
Required []string `json:"required"`
Properties map[string]api.ToolProperty `json:"properties"`
Type string `json:"type"`
Defs any `json:"$defs,omitempty"`
Items any `json:"items,omitempty"`
Required []string `json:"required"`
Properties *api.ToolProperties `json:"properties"`
}{
Type: "object",
Required: []string{"expression"},
Properties: map[string]api.ToolProperty{
"expression": {
Properties: func() *api.ToolProperties {
props := api.NewToolProperties()
props.Set("expression", api.ToolProperty{
Type: api.PropertyType{"string"},
Description: "The mathematical expression to calculate",
},
},
})
return props
}(),
},
},
},
@@ -196,10 +201,8 @@ func TestChatHarmonyParserStreamingRealtime(t *testing.T) {
wantToolCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: api.ToolCallFunctionArguments{
"location": "San Francisco",
},
Name: "get_weather",
Arguments: makeArgs("location", "San Francisco"),
},
},
},
@@ -222,10 +225,8 @@ func TestChatHarmonyParserStreamingRealtime(t *testing.T) {
wantToolCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "calculate",
Arguments: api.ToolCallFunctionArguments{
"expression": "2+2",
},
Name: "calculate",
Arguments: makeArgs("expression", "2+2"),
},
},
},

View File

@@ -229,9 +229,8 @@ func (s *Scheduler) processPending(ctx context.Context) {
}
if runnerToExpire == nil {
// While we were performing load calculations, the loaded runner(s) unloaded in parallel
// so findRunnerToUnload returned no runners. We'll try again and the loadedCount should be zero
slog.Debug("runner to expire was nil, retrying")
// Shouldn't happen
slog.Error("runner to expire was nil!")
continue
}
// Trigger an expiration to unload once it's done

149
template/rewrite.go Normal file
View File

@@ -0,0 +1,149 @@
package template
import (
"text/template"
"text/template/parse"
)
// rewritePropertiesCheck walks the template AST and rewrites
// .Function.Parameters.Properties to .Function.Parameters.HasProperties in
// if/with conditions (and "len ...Properties" to ...Parameters.Len) to fix
// truthiness checking now that Properties is an ordered-map type. Range
// iteration over Properties is deliberately left untouched. This maintains
// backward compatibility with templates that check if Properties exist.
func rewritePropertiesCheck(tmpl *template.Template) {
	// A template that was declared but never parsed has a nil Tree;
	// dereferencing it would panic, so bail out early.
	if tmpl == nil || tmpl.Tree == nil {
		return
	}
	walk(tmpl.Tree.Root)
}
// walk recursively traverses a template parse tree, applying the Properties
// rewrite to the pipelines of action, if, and with nodes. Range pipelines are
// deliberately skipped: iterating still needs the real Properties value.
func walk(n parse.Node) {
	if n == nil {
		return
	}
	switch v := n.(type) {
	case *parse.ListNode:
		for _, child := range v.Nodes {
			walk(child)
		}
	case *parse.ActionNode:
		// Bare {{ ... }} actions can contain len calls on Properties.
		rewritePipeProperties(v.Pipe)
	case *parse.IfNode:
		rewritePipeProperties(v.Pipe)
		walk(&v.BranchNode)
	case *parse.WithNode:
		rewritePipeProperties(v.Pipe)
		walk(&v.BranchNode)
	case *parse.RangeNode:
		// Leave the range pipeline alone; only descend into its bodies.
		walk(&v.BranchNode)
	case *parse.BranchNode:
		// Guard against typed-nil lists: they would bypass the n == nil
		// check above and panic inside the ListNode case.
		if v.List != nil {
			walk(v.List)
		}
		if v.ElseList != nil {
			walk(v.ElseList)
		}
	}
}
// rewritePipeProperties applies the Properties rewrite to every command in a
// pipeline. A nil pipeline is a no-op.
func rewritePipeProperties(pipe *parse.PipeNode) {
	if pipe == nil {
		return
	}
	for i := range pipe.Cmds {
		rewriteCommand(pipe.Cmds[i])
	}
}
// rewriteCommand rewrites a single command node in place and recurses into
// any nested commands or pipelines it contains.
func rewriteCommand(cmd *parse.CommandNode) {
	// A whole-command "len .Function.Parameters.Properties" becomes a direct
	// .Len field access on the ordered-map type.
	if isLenPropertiesCall(cmd) {
		replaceLenWithLenMethod(cmd)
		return
	}
	for i, arg := range cmd.Args {
		switch a := arg.(type) {
		case *parse.FieldNode:
			// A bare .Properties reference used as a truthiness check is
			// redirected to .HasProperties.
			if isPropertiesField(a.Ident) {
				cmd.Args[i] = replaceWithHasProperties(a)
			}
		case *parse.CommandNode:
			// Nested commands, e.g. arguments of "and", "gt", etc.
			rewriteCommand(a)
		case *parse.PipeNode:
			// Template function arguments can be wrapped in pipelines.
			rewritePipeProperties(a)
		}
	}
}
// isLenPropertiesCall reports whether cmd is exactly a two-argument
// "len .Function.Parameters.Properties" invocation.
func isLenPropertiesCall(cmd *parse.CommandNode) bool {
	if len(cmd.Args) != 2 {
		return false
	}
	ident, ok := cmd.Args[0].(*parse.IdentifierNode)
	if !ok || ident.Ident != "len" {
		return false
	}
	field, ok := cmd.Args[1].(*parse.FieldNode)
	return ok && isPropertiesField(field.Ident)
}
// replaceLenWithLenMethod replaces "len .Function.Parameters.Properties" with ".Function.Parameters.Len"
func replaceLenWithLenMethod(cmd *parse.CommandNode) {
if len(cmd.Args) < 2 {
return
}
field, ok := cmd.Args[1].(*parse.FieldNode)
if !ok {
return
}
// Create new field node with .Len instead of .Properties
newIdent := make([]string, len(field.Ident))
copy(newIdent, field.Ident)
newIdent[len(newIdent)-1] = "Len"
newField := &parse.FieldNode{
NodeType: parse.NodeField,
Ident: newIdent,
Pos: field.Pos,
}
// Replace the command with just the field access (remove "len" function call)
cmd.Args = []parse.Node{newField}
}
// isPropertiesField reports whether the selector path ends in
// "Parameters.Properties". Requiring at least three path segments avoids
// rewriting unrelated fields that merely end in Properties.
func isPropertiesField(ident []string) bool {
	n := len(ident)
	if n < 3 {
		return false
	}
	return ident[n-2] == "Parameters" && ident[n-1] == "Properties"
}
func replaceWithHasProperties(field *parse.FieldNode) *parse.FieldNode {
// Clone the identifier slice and replace the last element
newIdent := make([]string, len(field.Ident))
copy(newIdent, field.Ident)
newIdent[len(newIdent)-1] = "HasProperties"
return &parse.FieldNode{
NodeType: parse.NodeField,
Ident: newIdent,
Pos: field.Pos,
}
}

131
template/rewrite_test.go Normal file
View File

@@ -0,0 +1,131 @@
package template
import (
"bytes"
"testing"
"text/template"
"github.com/ollama/ollama/api"
)
// TestRewritePropertiesCheck exercises the template AST rewrite applied by
// rewritePropertiesCheck. The rewrite exists so legacy templates keep working
// after tool properties moved off a plain map:
//   - ".Function.Parameters.Properties" in if-conditions is redirected to
//     ".HasProperties" (an emptiness/truthiness check),
//   - "len .Function.Parameters.Properties" is redirected to the ".Len" method,
//   - range statements are left alone so iteration still uses Properties.
func TestRewritePropertiesCheck(t *testing.T) {
	// makeToolWithProps builds a minimal api.Tool whose parameters carry the
	// given (possibly nil) ordered property set.
	makeToolWithProps := func(props *api.ToolProperties) api.Tool {
		return api.Tool{
			Type: "function",
			Function: api.ToolFunction{
				Name:        "test",
				Description: "test function",
				Parameters:  api.NewToolFunctionParametersWithProps("object", nil, props),
			},
		}
	}
	tests := []struct {
		name     string // subtest name
		template string // raw template text, parsed then rewritten
		data     interface{}
		expected string // rendered output after the rewrite
	}{
		{
			name:     "if statement with Properties gets rewritten to HasProperties",
			template: `{{if .Function.Parameters.Properties}}Has props{{else}}No props{{end}}`,
			data:     makeToolWithProps(nil),
			expected: "No props", // Should use HasProperties which returns false for empty
		},
		{
			name:     "if statement with Properties and non-empty properties",
			template: `{{if .Function.Parameters.Properties}}Has props{{else}}No props{{end}}`,
			data: makeToolWithProps(func() *api.ToolProperties {
				p := api.NewToolProperties()
				p.Set("test", api.ToolProperty{Type: api.PropertyType{"string"}})
				return p
			}()),
			expected: "Has props", // Should use HasProperties which returns true
		},
		{
			name:     "range over Properties should not be rewritten",
			template: `{{range $k, $v := .Function.Parameters.Properties}}{{$k}} {{end}}`,
			data: makeToolWithProps(func() *api.ToolProperties {
				p := api.NewToolProperties()
				p.Set("foo", api.ToolProperty{Type: api.PropertyType{"string"}})
				p.Set("bar", api.ToolProperty{Type: api.PropertyType{"number"}})
				return p
			}()),
			// Insertion order (foo before bar) must be preserved by the
			// ordered property container.
			expected: "foo bar ", // Should still use Properties() for ranging
		},
		{
			// Mixes both rewrites in one template: the if-condition becomes
			// HasProperties while the range keeps iterating Properties.
			// Continuation lines of the raw string intentionally start at
			// column 0 — any indentation would become template output.
			name: "complex template with both if and range",
			template: `{{if .Function.Parameters.Properties}}Args:
{{range $k, $v := .Function.Parameters.Properties}} {{$k}}
{{end}}{{else}}No args{{end}}`,
			data: makeToolWithProps(func() *api.ToolProperties {
				p := api.NewToolProperties()
				p.Set("location", api.ToolProperty{Type: api.PropertyType{"string"}})
				return p
			}()),
			expected: "Args:\n location\n",
		},
		{
			name:     "if with and condition",
			template: `{{if and .Function.Parameters.Properties (gt (len .Function.Parameters.Properties) 0)}}yes{{else}}no{{end}}`,
			data:     makeToolWithProps(nil),
			expected: "no", // Empty, so HasProperties returns false
		},
		{
			name:     "len function on Properties gets rewritten to Len method",
			template: `{{len .Function.Parameters.Properties}}`,
			data:     makeToolWithProps(nil),
			expected: "0", // Empty properties should have length 0
		},
		{
			name:     "len function on non-empty Properties",
			template: `{{len .Function.Parameters.Properties}}`,
			data: makeToolWithProps(func() *api.ToolProperties {
				p := api.NewToolProperties()
				p.Set("foo", api.ToolProperty{Type: api.PropertyType{"string"}})
				p.Set("bar", api.ToolProperty{Type: api.PropertyType{"number"}})
				return p
			}()),
			expected: "2", // Two properties
		},
		{
			// Mirrors the condition used by the gpt-oss template, where the
			// len call is nested inside and/gt rather than standing alone.
			name:     "nested len in and/gt (gpt-oss pattern)",
			template: `{{if and .Function.Parameters.Properties (gt (len .Function.Parameters.Properties) 0)}}has props{{else}}no props{{end}}`,
			data:     makeToolWithProps(nil),
			expected: "no props", // Empty, so both checks should be false
		},
		{
			name:     "nested len in and/gt with properties",
			template: `{{if and .Function.Parameters.Properties (gt (len .Function.Parameters.Properties) 0)}}has props{{else}}no props{{end}}`,
			data: makeToolWithProps(func() *api.ToolProperties {
				p := api.NewToolProperties()
				p.Set("test", api.ToolProperty{Type: api.PropertyType{"string"}})
				return p
			}()),
			expected: "has props", // Has properties, both checks should be true
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// Use text/template directly and call rewritePropertiesCheck
			tmpl, err := template.New("test").Parse(tt.template)
			if err != nil {
				t.Fatalf("Failed to parse template: %v", err)
			}
			// Apply the rewrite
			rewritePropertiesCheck(tmpl)
			var buf bytes.Buffer
			err = tmpl.Execute(&buf, tt.data)
			if err != nil {
				t.Fatalf("Failed to execute template: %v", err)
			}
			if buf.String() != tt.expected {
				t.Errorf("Expected %q, got %q", tt.expected, buf.String())
			}
		})
	}
}

View File

@@ -147,6 +147,10 @@ func Parse(s string) (*Template, error) {
return nil, err
}
// Rewrite .Function.Parameters.Properties to .Function.Parameters.HasProperties
// in if/with conditions for backward compatibility with templates
rewritePropertiesCheck(tmpl)
t := Template{Template: tmpl, raw: s}
vars, err := t.Vars()
if err != nil {

View File

@@ -124,16 +124,21 @@ func (p *Parser) parseToolCall() *api.ToolCall {
return nil
}
var args map[string]any
var argsMap map[string]any
if found, i := findArguments(p.buffer); found == nil {
return nil
} else {
args = found
argsMap = found
if i > end {
end = i
}
}
args := api.NewToolCallFunctionArguments()
for k, v := range argsMap {
args.Set(k, v)
}
tc := &api.ToolCall{
Function: api.ToolCallFunction{
Name: tool.Function.Name,

View File

@@ -6,9 +6,33 @@ import (
"text/template"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/ollama/ollama/api"
)
func makeArgs(pairs ...any) api.ToolCallFunctionArguments {
args := api.NewToolCallFunctionArguments()
for i := 0; i < len(pairs); i += 2 {
key := pairs[i].(string)
value := pairs[i+1]
args.Set(key, value)
}
return args
}
// helper to build ToolFunctionParameters with properties
func makeParams(typ string, required []string, propsFn func() *api.ToolProperties) api.ToolFunctionParameters {
params := api.ToolFunctionParameters{
Type: typ,
Required: required,
}
if propsFn != nil {
params.SetProperties(propsFn())
}
return params
}
func TestParser(t *testing.T) {
qwen, err := template.New("qwen").Parse(`{{if .ToolCalls}}<tool_call>{{range .ToolCalls}}{"name": "{{.Function.Name}}", "arguments": {{.Function.Arguments}}}{{end}}</tool_call>{{end}}`)
if err != nil {
@@ -41,21 +65,19 @@ func TestParser(t *testing.T) {
Function: api.ToolFunction{
Name: "get_temperature",
Description: "Retrieve the temperature for a given location",
Parameters: api.ToolFunctionParameters{
Type: "object",
Required: []string{"city"},
Properties: map[string]api.ToolProperty{
"format": {
Parameters: makeParams("object", []string{"city"}, func() *api.ToolProperties {
props := api.NewToolProperties()
props.Set("format", api.ToolProperty{
Type: api.PropertyType{"string"},
Description: "The format to return the temperature in",
Enum: []any{"fahrenheit", "celsius"},
},
"city": {
})
props.Set("city", api.ToolProperty{
Type: api.PropertyType{"string"},
Description: "The city to get the temperature for",
},
},
},
})
return props
}),
},
},
{
@@ -63,15 +85,14 @@ func TestParser(t *testing.T) {
Function: api.ToolFunction{
Name: "get_conditions",
Description: "Retrieve the current weather conditions for a given location",
Parameters: api.ToolFunctionParameters{
Type: "object",
Properties: map[string]api.ToolProperty{
"location": {
Parameters: makeParams("object", nil, func() *api.ToolProperties {
props := api.NewToolProperties()
props.Set("location", api.ToolProperty{
Type: api.PropertyType{"string"},
Description: "The location to get the weather conditions for",
},
},
},
})
return props
}),
},
},
{
@@ -93,15 +114,14 @@ func TestParser(t *testing.T) {
Function: api.ToolFunction{
Name: "get_address",
Description: "Get the address of a given location",
Parameters: api.ToolFunctionParameters{
Type: "object",
Properties: map[string]api.ToolProperty{
"location": {
Parameters: makeParams("object", nil, func() *api.ToolProperties {
props := api.NewToolProperties()
props.Set("location", api.ToolProperty{
Type: api.PropertyType{"string"},
Description: "The location to get the address for",
},
},
},
})
return props
}),
},
},
{
@@ -109,19 +129,18 @@ func TestParser(t *testing.T) {
Function: api.ToolFunction{
Name: "add",
Description: "Add two numbers",
Parameters: api.ToolFunctionParameters{
Type: "object",
Properties: map[string]api.ToolProperty{
"a": {
Parameters: makeParams("object", nil, func() *api.ToolProperties {
props := api.NewToolProperties()
props.Set("a", api.ToolProperty{
Type: api.PropertyType{"string"},
Description: "The first number to add",
},
"b": {
})
props.Set("b", api.ToolProperty{
Type: api.PropertyType{"string"},
Description: "The second number to add",
},
},
},
})
return props
}),
},
},
}
@@ -155,11 +174,9 @@ func TestParser(t *testing.T) {
calls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Index: 0,
Name: "get_conditions",
Arguments: api.ToolCallFunctionArguments{
"location": "San Francisco",
},
Index: 0,
Name: "get_conditions",
Arguments: makeArgs("location", "San Francisco"),
},
},
},
@@ -174,7 +191,7 @@ func TestParser(t *testing.T) {
Function: api.ToolCallFunction{
Index: 0,
Name: "get_conditions",
Arguments: api.ToolCallFunctionArguments{},
Arguments: makeArgs(),
},
},
},
@@ -187,11 +204,9 @@ func TestParser(t *testing.T) {
calls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Index: 0,
Name: "get_temperature",
Arguments: api.ToolCallFunctionArguments{
"city": "New York",
},
Index: 0,
Name: "get_temperature",
Arguments: makeArgs("city", "New York"),
},
},
},
@@ -211,21 +226,16 @@ func TestParser(t *testing.T) {
calls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Index: 0,
Name: "get_temperature",
Arguments: api.ToolCallFunctionArguments{
"city": "London",
"format": "fahrenheit",
},
Index: 0,
Name: "get_temperature",
Arguments: makeArgs("city", "London", "format", "fahrenheit"),
},
},
{
Function: api.ToolCallFunction{
Index: 1,
Name: "get_conditions",
Arguments: api.ToolCallFunctionArguments{
"location": "Tokyo",
},
Index: 1,
Name: "get_conditions",
Arguments: makeArgs("location", "Tokyo"),
},
},
},
@@ -238,21 +248,16 @@ func TestParser(t *testing.T) {
calls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Index: 0,
Name: "get_temperature",
Arguments: api.ToolCallFunctionArguments{
"city": "London",
"format": "fahrenheit",
},
Index: 0,
Name: "get_temperature",
Arguments: makeArgs("city", "London", "format", "fahrenheit"),
},
},
{
Function: api.ToolCallFunction{
Index: 1,
Name: "get_conditions",
Arguments: api.ToolCallFunctionArguments{
"location": "Tokyo",
},
Index: 1,
Name: "get_conditions",
Arguments: makeArgs("location", "Tokyo"),
},
},
},
@@ -267,17 +272,14 @@ func TestParser(t *testing.T) {
Function: api.ToolCallFunction{
Index: 0,
Name: "say_hello",
Arguments: api.ToolCallFunctionArguments{},
Arguments: makeArgs(),
},
},
{
Function: api.ToolCallFunction{
Index: 1,
Name: "get_temperature",
Arguments: api.ToolCallFunctionArguments{
"city": "London",
"format": "fahrenheit",
},
Index: 1,
Name: "get_temperature",
Arguments: makeArgs("city", "London", "format", "fahrenheit"),
},
},
},
@@ -292,16 +294,14 @@ func TestParser(t *testing.T) {
Function: api.ToolCallFunction{
Index: 0,
Name: "get_conditions",
Arguments: api.ToolCallFunctionArguments{},
Arguments: makeArgs(),
},
},
{
Function: api.ToolCallFunction{
Index: 1,
Name: "get_conditions",
Arguments: api.ToolCallFunctionArguments{
"location": "Tokyo",
},
Index: 1,
Name: "get_conditions",
Arguments: makeArgs("location", "Tokyo"),
},
},
},
@@ -314,11 +314,9 @@ func TestParser(t *testing.T) {
calls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Index: 0,
Name: "get_temperature",
Arguments: api.ToolCallFunctionArguments{
"city": "Tokyo",
},
Index: 0,
Name: "get_temperature",
Arguments: makeArgs("city", "Tokyo"),
},
},
},
@@ -345,11 +343,9 @@ func TestParser(t *testing.T) {
calls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Index: 0,
Name: "get_temperature",
Arguments: api.ToolCallFunctionArguments{
"city": "Tokyo",
},
Index: 0,
Name: "get_temperature",
Arguments: makeArgs("city", "Tokyo"),
},
},
},
@@ -369,11 +365,9 @@ func TestParser(t *testing.T) {
calls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Index: 0,
Name: "get_temperature",
Arguments: api.ToolCallFunctionArguments{
"city": "Tokyo",
},
Index: 0,
Name: "get_temperature",
Arguments: makeArgs("city", "Tokyo"),
},
},
},
@@ -451,20 +445,16 @@ func TestParser(t *testing.T) {
calls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Index: 0,
Name: "get_temperature",
Arguments: api.ToolCallFunctionArguments{
"city": "London",
},
Index: 0,
Name: "get_temperature",
Arguments: makeArgs("city", "London"),
},
},
{
Function: api.ToolCallFunction{
Index: 1,
Name: "get_conditions",
Arguments: api.ToolCallFunctionArguments{
"location": "Tokyo",
},
Index: 1,
Name: "get_conditions",
Arguments: makeArgs("location", "Tokyo"),
},
},
},
@@ -484,11 +474,9 @@ func TestParser(t *testing.T) {
calls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Index: 0,
Name: "get_conditions",
Arguments: api.ToolCallFunctionArguments{
"location": "Tokyo",
},
Index: 0,
Name: "get_conditions",
Arguments: makeArgs("location", "Tokyo"),
},
},
},
@@ -526,11 +514,9 @@ func TestParser(t *testing.T) {
calls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Index: 0,
Name: "get_conditions",
Arguments: api.ToolCallFunctionArguments{
"location": "Tokyo",
},
Index: 0,
Name: "get_conditions",
Arguments: makeArgs("location", "Tokyo"),
},
},
},
@@ -563,7 +549,7 @@ func TestParser(t *testing.T) {
Function: api.ToolCallFunction{
Index: 0,
Name: "say_hello_world",
Arguments: api.ToolCallFunctionArguments{},
Arguments: makeArgs(),
},
},
},
@@ -591,14 +577,14 @@ func TestParser(t *testing.T) {
Function: api.ToolCallFunction{
Index: 0,
Name: "say_hello_world",
Arguments: api.ToolCallFunctionArguments{},
Arguments: makeArgs(),
},
},
{
Function: api.ToolCallFunction{
Index: 1,
Name: "say_hello",
Arguments: api.ToolCallFunctionArguments{},
Arguments: makeArgs(),
},
},
},
@@ -624,14 +610,14 @@ func TestParser(t *testing.T) {
Function: api.ToolCallFunction{
Index: 0,
Name: "say_hello",
Arguments: api.ToolCallFunctionArguments{},
Arguments: makeArgs(),
},
},
{
Function: api.ToolCallFunction{
Index: 1,
Name: "say_hello_world",
Arguments: api.ToolCallFunctionArguments{},
Arguments: makeArgs(),
},
},
},
@@ -648,7 +634,7 @@ func TestParser(t *testing.T) {
Function: api.ToolCallFunction{
Index: 0,
Name: "say_hello",
Arguments: api.ToolCallFunctionArguments{},
Arguments: makeArgs(),
},
},
},
@@ -665,7 +651,7 @@ func TestParser(t *testing.T) {
Function: api.ToolCallFunction{
Index: 0,
Name: "say_hello_world",
Arguments: api.ToolCallFunctionArguments{},
Arguments: makeArgs(),
},
},
},
@@ -685,11 +671,9 @@ func TestParser(t *testing.T) {
calls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Index: 0,
Name: "get_address",
Arguments: api.ToolCallFunctionArguments{
"location": "London",
},
Index: 0,
Name: "get_address",
Arguments: makeArgs("location", "London"),
},
},
},
@@ -704,11 +688,9 @@ func TestParser(t *testing.T) {
calls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Index: 0,
Name: "get_address",
Arguments: api.ToolCallFunctionArguments{
"location": "London",
},
Index: 0,
Name: "get_address",
Arguments: makeArgs("location", "London"),
},
},
},
@@ -723,12 +705,9 @@ func TestParser(t *testing.T) {
calls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Index: 0,
Name: "add",
Arguments: api.ToolCallFunctionArguments{
"a": "5",
"b": "10",
},
Index: 0,
Name: "add",
Arguments: makeArgs("a", "5", "b", "10"),
},
},
},
@@ -756,7 +735,7 @@ func TestParser(t *testing.T) {
}
for i, want := range tt.calls {
if diff := cmp.Diff(calls[i], want); diff != "" {
if diff := cmp.Diff(calls[i], want, cmpopts.IgnoreUnexported(api.ToolCallFunctionArguments{})); diff != "" {
t.Errorf("Tool call %d mismatch (-got +want):\n%s", i, diff)
}
}