anthropic: fix ToolCallFunctionArguments type after rebase

Update tests and implementation to use the new ordered map-based
ToolCallFunctionArguments type which replaces the previous map[string]any.

- Add mapToArgs helper to convert map[string]any to ToolCallFunctionArguments
- Add testArgs and testProps helpers in tests
- Use cmpopts.IgnoreUnexported for cmp.Diff comparisons
This commit is contained in:
ParthSareen 2026-01-05 21:10:29 -08:00
parent bd4ab011ac
commit fceafefdce
3 changed files with 48 additions and 15 deletions

View File

@ -377,7 +377,7 @@ func convertMessage(msg MessageParam) ([]api.Message, error) {
}, },
} }
if input, ok := blockMap["input"].(map[string]any); ok { if input, ok := blockMap["input"].(map[string]any); ok {
tc.Function.Arguments = api.ToolCallFunctionArguments(input) tc.Function.Arguments = mapToArgs(input)
} }
toolCalls = append(toolCalls, tc) toolCalls = append(toolCalls, tc)
@ -767,3 +767,12 @@ func GenerateMessageID() string {
func ptr(s string) *string { func ptr(s string) *string {
return &s return &s
} }
// mapToArgs converts a map to ToolCallFunctionArguments
func mapToArgs(m map[string]any) api.ToolCallFunctionArguments {
args := api.NewToolCallFunctionArguments()
for k, v := range m {
args.Set(k, v)
}
return args
}

View File

@ -14,6 +14,15 @@ const (
testImage = `iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=` testImage = `iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=`
) )
// testArgs creates ToolCallFunctionArguments from a map (convenience function for tests)
func testArgs(m map[string]any) api.ToolCallFunctionArguments {
args := api.NewToolCallFunctionArguments()
for k, v := range m {
args.Set(k, v)
}
return args
}
func TestFromMessagesRequest_Basic(t *testing.T) { func TestFromMessagesRequest_Basic(t *testing.T) {
req := MessagesRequest{ req := MessagesRequest{
Model: "test-model", Model: "test-model",
@ -468,7 +477,7 @@ func TestToMessagesResponse_WithToolCalls(t *testing.T) {
ID: "call_123", ID: "call_123",
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "get_weather", Name: "get_weather",
Arguments: map[string]any{"location": "Paris"}, Arguments: testArgs(map[string]any{"location": "Paris"}),
}, },
}, },
}, },
@ -662,7 +671,7 @@ func TestStreamConverter_WithToolCalls(t *testing.T) {
ID: "call_123", ID: "call_123",
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "get_weather", Name: "get_weather",
Arguments: map[string]any{"location": "Paris"}, Arguments: testArgs(map[string]any{"location": "Paris"}),
}, },
}, },
}, },
@ -708,6 +717,8 @@ func TestStreamConverter_ToolCallWithUnmarshalableArgs(t *testing.T) {
// Create a channel which cannot be JSON marshaled // Create a channel which cannot be JSON marshaled
unmarshalable := make(chan int) unmarshalable := make(chan int)
badArgs := api.NewToolCallFunctionArguments()
badArgs.Set("channel", unmarshalable)
resp := api.ChatResponse{ resp := api.ChatResponse{
Model: "test-model", Model: "test-model",
@ -718,7 +729,7 @@ func TestStreamConverter_ToolCallWithUnmarshalableArgs(t *testing.T) {
ID: "call_bad", ID: "call_bad",
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "bad_function", Name: "bad_function",
Arguments: map[string]any{"channel": unmarshalable}, Arguments: badArgs,
}, },
}, },
}, },
@ -752,6 +763,8 @@ func TestStreamConverter_MultipleToolCallsWithMixedValidity(t *testing.T) {
conv := NewStreamConverter("msg_123", "test-model") conv := NewStreamConverter("msg_123", "test-model")
unmarshalable := make(chan int) unmarshalable := make(chan int)
badArgs := api.NewToolCallFunctionArguments()
badArgs.Set("channel", unmarshalable)
resp := api.ChatResponse{ resp := api.ChatResponse{
Model: "test-model", Model: "test-model",
@ -762,14 +775,14 @@ func TestStreamConverter_MultipleToolCallsWithMixedValidity(t *testing.T) {
ID: "call_good", ID: "call_good",
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "good_function", Name: "good_function",
Arguments: map[string]any{"location": "Paris"}, Arguments: testArgs(map[string]any{"location": "Paris"}),
}, },
}, },
{ {
ID: "call_bad", ID: "call_bad",
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "bad_function", Name: "bad_function",
Arguments: map[string]any{"channel": unmarshalable}, Arguments: badArgs,
}, },
}, },
}, },

View File

@ -11,6 +11,7 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/ollama/ollama/anthropic" "github.com/ollama/ollama/anthropic"
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
@ -25,6 +26,15 @@ func captureAnthropicRequest(capturedRequest any) gin.HandlerFunc {
} }
} }
// testProps creates ToolPropertiesMap from a map (convenience function for tests)
func testProps(m map[string]api.ToolProperty) *api.ToolPropertiesMap {
props := api.NewToolPropertiesMap()
for k, v := range m {
props.Set(k, v)
}
return props
}
func TestAnthropicMessagesMiddleware(t *testing.T) { func TestAnthropicMessagesMiddleware(t *testing.T) {
type testCase struct { type testCase struct {
name string name string
@ -156,9 +166,9 @@ func TestAnthropicMessagesMiddleware(t *testing.T) {
Parameters: api.ToolFunctionParameters{ Parameters: api.ToolFunctionParameters{
Type: "object", Type: "object",
Required: []string{"location"}, Required: []string{"location"},
Properties: map[string]api.ToolProperty{ Properties: testProps(map[string]api.ToolProperty{
"location": {Type: api.PropertyType{"string"}}, "location": {Type: api.PropertyType{"string"}},
}, }),
}, },
}, },
}, },
@ -193,7 +203,7 @@ func TestAnthropicMessagesMiddleware(t *testing.T) {
ID: "call_123", ID: "call_123",
Function: api.ToolCallFunction{ Function: api.ToolCallFunction{
Name: "get_weather", Name: "get_weather",
Arguments: api.ToolCallFunctionArguments{"location": "Paris"}, Arguments: testArgs(map[string]any{"location": "Paris"}),
}, },
}, },
}, },
@ -344,7 +354,8 @@ func TestAnthropicMessagesMiddleware(t *testing.T) {
t.Errorf("model mismatch: got %q, want %q", capturedRequest.Model, tc.req.Model) t.Errorf("model mismatch: got %q, want %q", capturedRequest.Model, tc.req.Model)
} }
if diff := cmp.Diff(tc.req.Messages, capturedRequest.Messages); diff != "" { if diff := cmp.Diff(tc.req.Messages, capturedRequest.Messages,
cmpopts.IgnoreUnexported(api.ToolCallFunctionArguments{}, api.ToolPropertiesMap{})); diff != "" {
t.Errorf("messages mismatch (-want +got):\n%s", diff) t.Errorf("messages mismatch (-want +got):\n%s", diff)
} }
@ -492,11 +503,11 @@ func TestAnthropicWriter_ErrorFromRoutes(t *testing.T) {
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
tests := []struct { tests := []struct {
name string name string
statusCode int statusCode int
errorPayload any errorPayload any
wantErrorType string wantErrorType string
wantMessage string wantMessage string
}{ }{
// routes.go sends errors without StatusCode in JSON, so we must use HTTP status // routes.go sends errors without StatusCode in JSON, so we must use HTTP status
{ {