diff --git a/anthropic/anthropic.go b/anthropic/anthropic.go index f4d9a6e23..9cb2c75c4 100644 --- a/anthropic/anthropic.go +++ b/anthropic/anthropic.go @@ -377,7 +377,7 @@ func convertMessage(msg MessageParam) ([]api.Message, error) { }, } if input, ok := blockMap["input"].(map[string]any); ok { - tc.Function.Arguments = api.ToolCallFunctionArguments(input) + tc.Function.Arguments = mapToArgs(input) } toolCalls = append(toolCalls, tc) @@ -767,3 +767,12 @@ func GenerateMessageID() string { func ptr(s string) *string { return &s } + +// mapToArgs converts a map to ToolCallFunctionArguments +func mapToArgs(m map[string]any) api.ToolCallFunctionArguments { + args := api.NewToolCallFunctionArguments() + for k, v := range m { + args.Set(k, v) + } + return args +} diff --git a/anthropic/anthropic_test.go b/anthropic/anthropic_test.go index 8228bd37b..117d183c9 100644 --- a/anthropic/anthropic_test.go +++ b/anthropic/anthropic_test.go @@ -14,6 +14,15 @@ const ( testImage = `iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=` ) +// testArgs creates ToolCallFunctionArguments from a map (convenience function for tests) +func testArgs(m map[string]any) api.ToolCallFunctionArguments { + args := api.NewToolCallFunctionArguments() + for k, v := range m { + args.Set(k, v) + } + return args +} + func TestFromMessagesRequest_Basic(t *testing.T) { req := MessagesRequest{ Model: "test-model", @@ -468,7 +477,7 @@ func TestToMessagesResponse_WithToolCalls(t *testing.T) { ID: "call_123", Function: api.ToolCallFunction{ Name: "get_weather", - Arguments: map[string]any{"location": "Paris"}, + Arguments: testArgs(map[string]any{"location": "Paris"}), }, }, }, @@ -662,7 +671,7 @@ func TestStreamConverter_WithToolCalls(t *testing.T) { ID: "call_123", Function: api.ToolCallFunction{ Name: "get_weather", - Arguments: map[string]any{"location": "Paris"}, + Arguments: testArgs(map[string]any{"location": "Paris"}), }, }, }, @@ -708,6 +717,8 @@ func TestStreamConverter_ToolCallWithUnmarshalableArgs(t *testing.T) { // Create a channel which cannot be JSON marshaled unmarshalable := make(chan int) + badArgs := api.NewToolCallFunctionArguments() + badArgs.Set("channel", unmarshalable) resp := api.ChatResponse{ Model: "test-model", @@ -718,7 +729,7 @@ func TestStreamConverter_ToolCallWithUnmarshalableArgs(t *testing.T) { ID: "call_bad", Function: api.ToolCallFunction{ Name: "bad_function", - Arguments: map[string]any{"channel": unmarshalable}, + Arguments: badArgs, }, }, }, @@ -752,6 +763,8 @@ func TestStreamConverter_MultipleToolCallsWithMixedValidity(t *testing.T) { conv := NewStreamConverter("msg_123", "test-model") unmarshalable := make(chan int) + badArgs := api.NewToolCallFunctionArguments() + badArgs.Set("channel", unmarshalable) resp := api.ChatResponse{ Model: "test-model", @@ -762,14 +775,14 @@ func TestStreamConverter_MultipleToolCallsWithMixedValidity(t *testing.T) { ID: "call_good", Function: api.ToolCallFunction{ Name: "good_function", - Arguments: map[string]any{"location": "Paris"}, + Arguments: testArgs(map[string]any{"location": "Paris"}), }, }, { ID: "call_bad", Function: api.ToolCallFunction{ Name: "bad_function", - Arguments: map[string]any{"channel": unmarshalable}, + Arguments: badArgs, }, }, }, diff --git a/middleware/anthropic_test.go b/middleware/anthropic_test.go index b444e83ab..40df7fbb4 100644 --- a/middleware/anthropic_test.go +++ b/middleware/anthropic_test.go @@ -11,6 +11,7 @@ import ( "github.com/gin-gonic/gin" "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" "github.com/ollama/ollama/anthropic" "github.com/ollama/ollama/api" @@ -25,6 +26,15 @@ func captureAnthropicRequest(capturedRequest any) gin.HandlerFunc { } } +// testProps creates ToolPropertiesMap from a map (convenience function for tests) +func testProps(m map[string]api.ToolProperty) *api.ToolPropertiesMap { + props := api.NewToolPropertiesMap() + for k, v := range m { + props.Set(k, v) + } + return props +} + func TestAnthropicMessagesMiddleware(t *testing.T) { type testCase struct { name string @@ -156,9 +166,9 @@ func TestAnthropicMessagesMiddleware(t *testing.T) { Parameters: api.ToolFunctionParameters{ Type: "object", Required: []string{"location"}, - Properties: map[string]api.ToolProperty{ + Properties: testProps(map[string]api.ToolProperty{ "location": {Type: api.PropertyType{"string"}}, - }, + }), }, }, }, @@ -193,7 +203,7 @@ func TestAnthropicMessagesMiddleware(t *testing.T) { ID: "call_123", Function: api.ToolCallFunction{ Name: "get_weather", - Arguments: api.ToolCallFunctionArguments{"location": "Paris"}, + Arguments: testArgs(map[string]any{"location": "Paris"}), }, }, }, @@ -344,7 +354,8 @@ func TestAnthropicMessagesMiddleware(t *testing.T) { t.Errorf("model mismatch: got %q, want %q", capturedRequest.Model, tc.req.Model) } - if diff := cmp.Diff(tc.req.Messages, capturedRequest.Messages); diff != "" { + if diff := cmp.Diff(tc.req.Messages, capturedRequest.Messages, + cmpopts.IgnoreUnexported(api.ToolCallFunctionArguments{}, api.ToolPropertiesMap{})); diff != "" { t.Errorf("messages mismatch (-want +got):\n%s", diff) } @@ -492,11 +503,11 @@ func TestAnthropicWriter_ErrorFromRoutes(t *testing.T) { gin.SetMode(gin.TestMode) tests := []struct { - name string - statusCode int - errorPayload any - wantErrorType string - wantMessage string + name string + statusCode int + errorPayload any + wantErrorType string + wantMessage string }{ // routes.go sends errors without StatusCode in JSON, so we must use HTTP status {