From 6229df5b90711da2a93f4246b558d7441370b305 Mon Sep 17 00:00:00 2001 From: ParthSareen Date: Fri, 2 Jan 2026 01:31:43 -0500 Subject: [PATCH] anthropic: add unit and integration tests - Unit tests for transformation functions (FromMessagesRequest, ToMessagesResponse) - Unit tests for error handling and edge cases - Middleware integration tests with httptest - Fix lint issues (gofmt) - Fix unused struct fields in StreamConverter - Add fallback for crypto/rand errors --- anthropic/anthropic.go | 12 +- anthropic/anthropic_test.go | 667 +++++++++++++++++++++++++++++++++++ docs/README.md | 1 + docs/docs.json | 3 +- middleware/anthropic_test.go | 487 +++++++++++++++++++++++++ 5 files changed, 1163 insertions(+), 7 deletions(-) create mode 100644 anthropic/anthropic_test.go create mode 100644 middleware/anthropic_test.go diff --git a/anthropic/anthropic.go b/anthropic/anthropic.go index 7bf9e98a0..ef0bdd953 100644 --- a/anthropic/anthropic.go +++ b/anthropic/anthropic.go @@ -60,7 +60,7 @@ type MessagesRequest struct { Model string `json:"model"` MaxTokens int `json:"max_tokens"` Messages []MessageParam `json:"messages"` - System any `json:"system,omitempty"` // string or []ContentBlock + System any `json:"system,omitempty"` // string or []ContentBlock Stream bool `json:"stream,omitempty"` Temperature *float64 `json:"temperature,omitempty"` TopP *float64 `json:"top_p,omitempty"` @@ -74,7 +74,7 @@ type MessagesRequest struct { // MessageParam represents a message in the request type MessageParam struct { - Role string `json:"role"` // "user" or "assistant" + Role string `json:"role"` // "user" or "assistant" Content any `json:"content"` // string or []ContentBlock } @@ -181,11 +181,11 @@ type ContentBlockDeltaEvent struct { // Delta represents an incremental update type Delta struct { - Type string `json:"type"` // "text_delta", "input_json_delta", "thinking_delta", "signature_delta" - Text string `json:"text,omitempty"` + Type string `json:"type"` // "text_delta", "input_json_delta", "thinking_delta", "signature_delta" + Text string `json:"text,omitempty"` PartialJSON string `json:"partial_json,omitempty"` - Thinking string `json:"thinking,omitempty"` - Signature string `json:"signature,omitempty"` + Thinking string `json:"thinking,omitempty"` + Signature string `json:"signature,omitempty"` } // ContentBlockStopEvent signals the end of a content block diff --git a/anthropic/anthropic_test.go b/anthropic/anthropic_test.go new file mode 100644 index 000000000..31a2ec67c --- /dev/null +++ b/anthropic/anthropic_test.go @@ -0,0 +1,667 @@ +package anthropic + +import ( + "encoding/base64" + "encoding/json" + "testing" + + "github.com/google/go-cmp/cmp" + + "github.com/ollama/ollama/api" +) + +const ( + testImage = `iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=` +) + +func TestFromMessagesRequest_Basic(t *testing.T) { + req := MessagesRequest{ + Model: "test-model", + MaxTokens: 1024, + Messages: []MessageParam{ + {Role: "user", Content: "Hello"}, + }, + } + + result, err := FromMessagesRequest(req) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if result.Model != "test-model" { + t.Errorf("expected model 'test-model', got %q", result.Model) + } + + if len(result.Messages) != 1 { + t.Fatalf("expected 1 message, got %d", len(result.Messages)) + } + + if result.Messages[0].Role != "user" || result.Messages[0].Content != "Hello" { + t.Errorf("unexpected message: %+v", result.Messages[0]) + } + + if numPredict, ok := 
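+		// Anthropic's required max_tokens should carry through as Ollama's num_predict option.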
result.Options["num_predict"].(int); !ok || numPredict != 1024 { + t.Errorf("expected num_predict 1024, got %v", result.Options["num_predict"]) + } +} + +func TestFromMessagesRequest_WithSystemPrompt(t *testing.T) { + req := MessagesRequest{ + Model: "test-model", + MaxTokens: 1024, + System: "You are a helpful assistant.", + Messages: []MessageParam{ + {Role: "user", Content: "Hello"}, + }, + } + + result, err := FromMessagesRequest(req) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if len(result.Messages) != 2 { + t.Fatalf("expected 2 messages, got %d", len(result.Messages)) + } + + if result.Messages[0].Role != "system" || result.Messages[0].Content != "You are a helpful assistant." { + t.Errorf("unexpected system message: %+v", result.Messages[0]) + } +} + +func TestFromMessagesRequest_WithSystemPromptArray(t *testing.T) { + req := MessagesRequest{ + Model: "test-model", + MaxTokens: 1024, + System: []any{ + map[string]any{"type": "text", "text": "You are helpful."}, + map[string]any{"type": "text", "text": " Be concise."}, + }, + Messages: []MessageParam{ + {Role: "user", Content: "Hello"}, + }, + } + + result, err := FromMessagesRequest(req) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if len(result.Messages) != 2 { + t.Fatalf("expected 2 messages, got %d", len(result.Messages)) + } + + if result.Messages[0].Content != "You are helpful. Be concise." { + t.Errorf("unexpected system message content: %q", result.Messages[0].Content) + } +} + +func TestFromMessagesRequest_WithOptions(t *testing.T) { + temp := 0.7 + topP := 0.9 + topK := 40 + req := MessagesRequest{ + Model: "test-model", + MaxTokens: 2048, + Messages: []MessageParam{{Role: "user", Content: "Hello"}}, + Temperature: &temp, + TopP: &topP, + TopK: &topK, + StopSequences: []string{"\n", "END"}, + } + + result, err := FromMessagesRequest(req) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if result.Options["temperature"] != 0.7 { + t.Errorf("expected temperature 0.7, got %v", result.Options["temperature"]) + } + if result.Options["top_p"] != 0.9 { + t.Errorf("expected top_p 0.9, got %v", result.Options["top_p"]) + } + if result.Options["top_k"] != 40 { + t.Errorf("expected top_k 40, got %v", result.Options["top_k"]) + } + if diff := cmp.Diff([]string{"\n", "END"}, result.Options["stop"]); diff != "" { + t.Errorf("stop sequences mismatch: %s", diff) + } +} + +func TestFromMessagesRequest_WithImage(t *testing.T) { + imgData, _ := base64.StdEncoding.DecodeString(testImage) + + req := MessagesRequest{ + Model: "test-model", + MaxTokens: 1024, + Messages: []MessageParam{ + { + Role: "user", + Content: []any{ + map[string]any{"type": "text", "text": "What's in this image?"}, + map[string]any{ + "type": "image", + "source": map[string]any{ + "type": "base64", + "media_type": "image/png", + "data": testImage, + }, + }, + }, + }, + }, + } + + result, err := FromMessagesRequest(req) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if len(result.Messages) != 1 { + t.Fatalf("expected 1 message, got %d", len(result.Messages)) + } + + if result.Messages[0].Content != "What's in this image?" 
{ + t.Errorf("expected content 'What's in this image?', got %q", result.Messages[0].Content) + } + + if len(result.Messages[0].Images) != 1 { + t.Fatalf("expected 1 image, got %d", len(result.Messages[0].Images)) + } + + if string(result.Messages[0].Images[0]) != string(imgData) { + t.Error("image data mismatch") + } +} + +func TestFromMessagesRequest_WithToolUse(t *testing.T) { + req := MessagesRequest{ + Model: "test-model", + MaxTokens: 1024, + Messages: []MessageParam{ + {Role: "user", Content: "What's the weather in Paris?"}, + { + Role: "assistant", + Content: []any{ + map[string]any{ + "type": "tool_use", + "id": "call_123", + "name": "get_weather", + "input": map[string]any{"location": "Paris"}, + }, + }, + }, + }, + } + + result, err := FromMessagesRequest(req) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if len(result.Messages) != 2 { + t.Fatalf("expected 2 messages, got %d", len(result.Messages)) + } + + if len(result.Messages[1].ToolCalls) != 1 { + t.Fatalf("expected 1 tool call, got %d", len(result.Messages[1].ToolCalls)) + } + + tc := result.Messages[1].ToolCalls[0] + if tc.ID != "call_123" { + t.Errorf("expected tool call ID 'call_123', got %q", tc.ID) + } + if tc.Function.Name != "get_weather" { + t.Errorf("expected tool name 'get_weather', got %q", tc.Function.Name) + } +} + +func TestFromMessagesRequest_WithToolResult(t *testing.T) { + req := MessagesRequest{ + Model: "test-model", + MaxTokens: 1024, + Messages: []MessageParam{ + { + Role: "user", + Content: []any{ + map[string]any{ + "type": "tool_result", + "tool_use_id": "call_123", + "content": "The weather in Paris is sunny, 22°C", + }, + }, + }, + }, + } + + result, err := FromMessagesRequest(req) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if len(result.Messages) != 1 { + t.Fatalf("expected 1 message, got %d", len(result.Messages)) + } + + msg := result.Messages[0] + if msg.Role != "tool" { + t.Errorf("expected role 'tool', got %q", msg.Role) + } + if msg.ToolCallID != "call_123" { + t.Errorf("expected tool_call_id 'call_123', got %q", msg.ToolCallID) + } + if msg.Content != "The weather in Paris is sunny, 22°C" { + t.Errorf("unexpected content: %q", msg.Content) + } +} + +func TestFromMessagesRequest_WithTools(t *testing.T) { + req := MessagesRequest{ + Model: "test-model", + MaxTokens: 1024, + Messages: []MessageParam{{Role: "user", Content: "Hello"}}, + Tools: []Tool{ + { + Name: "get_weather", + Description: "Get current weather", + InputSchema: json.RawMessage(`{"type":"object","properties":{"location":{"type":"string"}},"required":["location"]}`), + }, + }, + } + + result, err := FromMessagesRequest(req) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if len(result.Tools) != 1 { + t.Fatalf("expected 1 tool, got %d", len(result.Tools)) + } + + tool := result.Tools[0] + if tool.Type != "function" { + t.Errorf("expected type 'function', got %q", tool.Type) + } + if tool.Function.Name != "get_weather" { + t.Errorf("expected name 'get_weather', got %q", tool.Function.Name) + } + if tool.Function.Description != "Get current weather" { + t.Errorf("expected description 'Get current weather', got %q", tool.Function.Description) + } +} + +func TestFromMessagesRequest_WithThinking(t *testing.T) { + req := MessagesRequest{ + Model: "test-model", + MaxTokens: 1024, + Messages: []MessageParam{{Role: "user", Content: "Hello"}}, + Thinking: &ThinkingConfig{Type: "enabled", BudgetTokens: 1000}, + } + + result, err := FromMessagesRequest(req) + if err != nil { + 
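+		// a thinking config of type "enabled" is valid input, so any error here is a bug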
t.Fatalf("unexpected error: %v", err) + } + + if result.Think == nil { + t.Fatal("expected Think to be set") + } + if v, ok := result.Think.Value.(bool); !ok || !v { + t.Errorf("expected Think.Value to be true, got %v", result.Think.Value) + } +} + +func TestFromMessagesRequest_ToolUseMissingID(t *testing.T) { + req := MessagesRequest{ + Model: "test-model", + MaxTokens: 1024, + Messages: []MessageParam{ + { + Role: "assistant", + Content: []any{ + map[string]any{ + "type": "tool_use", + "name": "get_weather", + }, + }, + }, + }, + } + + _, err := FromMessagesRequest(req) + if err == nil { + t.Fatal("expected error for missing tool_use id") + } + if err.Error() != "tool_use block missing required 'id' field" { + t.Errorf("unexpected error message: %v", err) + } +} + +func TestFromMessagesRequest_ToolUseMissingName(t *testing.T) { + req := MessagesRequest{ + Model: "test-model", + MaxTokens: 1024, + Messages: []MessageParam{ + { + Role: "assistant", + Content: []any{ + map[string]any{ + "type": "tool_use", + "id": "call_123", + }, + }, + }, + }, + } + + _, err := FromMessagesRequest(req) + if err == nil { + t.Fatal("expected error for missing tool_use name") + } + if err.Error() != "tool_use block missing required 'name' field" { + t.Errorf("unexpected error message: %v", err) + } +} + +func TestFromMessagesRequest_InvalidToolSchema(t *testing.T) { + req := MessagesRequest{ + Model: "test-model", + MaxTokens: 1024, + Messages: []MessageParam{{Role: "user", Content: "Hello"}}, + Tools: []Tool{ + { + Name: "bad_tool", + InputSchema: json.RawMessage(`{invalid json`), + }, + }, + } + + _, err := FromMessagesRequest(req) + if err == nil { + t.Fatal("expected error for invalid tool schema") + } +} + +func TestToMessagesResponse_Basic(t *testing.T) { + resp := api.ChatResponse{ + Model: "test-model", + Message: api.Message{ + Role: "assistant", + Content: "Hello there!", + }, + Done: true, + DoneReason: "stop", + Metrics: api.Metrics{ + PromptEvalCount: 10, + EvalCount: 5, + }, + } + + result := ToMessagesResponse("msg_123", resp) + + if result.ID != "msg_123" { + t.Errorf("expected ID 'msg_123', got %q", result.ID) + } + if result.Type != "message" { + t.Errorf("expected type 'message', got %q", result.Type) + } + if result.Role != "assistant" { + t.Errorf("expected role 'assistant', got %q", result.Role) + } + if len(result.Content) != 1 { + t.Fatalf("expected 1 content block, got %d", len(result.Content)) + } + if result.Content[0].Type != "text" || result.Content[0].Text != "Hello there!" 
{ + t.Errorf("unexpected content: %+v", result.Content[0]) + } + if result.StopReason != "end_turn" { + t.Errorf("expected stop_reason 'end_turn', got %q", result.StopReason) + } + if result.Usage.InputTokens != 10 || result.Usage.OutputTokens != 5 { + t.Errorf("unexpected usage: %+v", result.Usage) + } +} + +func TestToMessagesResponse_WithToolCalls(t *testing.T) { + resp := api.ChatResponse{ + Model: "test-model", + Message: api.Message{ + Role: "assistant", + ToolCalls: []api.ToolCall{ + { + ID: "call_123", + Function: api.ToolCallFunction{ + Name: "get_weather", + Arguments: map[string]any{"location": "Paris"}, + }, + }, + }, + }, + Done: true, + DoneReason: "stop", + } + + result := ToMessagesResponse("msg_123", resp) + + if len(result.Content) != 1 { + t.Fatalf("expected 1 content block, got %d", len(result.Content)) + } + if result.Content[0].Type != "tool_use" { + t.Errorf("expected type 'tool_use', got %q", result.Content[0].Type) + } + if result.Content[0].ID != "call_123" { + t.Errorf("expected ID 'call_123', got %q", result.Content[0].ID) + } + if result.Content[0].Name != "get_weather" { + t.Errorf("expected name 'get_weather', got %q", result.Content[0].Name) + } + if result.StopReason != "tool_use" { + t.Errorf("expected stop_reason 'tool_use', got %q", result.StopReason) + } +} + +func TestToMessagesResponse_WithThinking(t *testing.T) { + resp := api.ChatResponse{ + Model: "test-model", + Message: api.Message{ + Role: "assistant", + Content: "The answer is 42.", + Thinking: "Let me think about this...", + }, + Done: true, + DoneReason: "stop", + } + + result := ToMessagesResponse("msg_123", resp) + + if len(result.Content) != 2 { + t.Fatalf("expected 2 content blocks, got %d", len(result.Content)) + } + if result.Content[0].Type != "thinking" { + t.Errorf("expected first block type 'thinking', got %q", result.Content[0].Type) + } + if result.Content[0].Thinking != "Let me think about this..." 
{ + t.Errorf("unexpected thinking content: %q", result.Content[0].Thinking) + } + if result.Content[1].Type != "text" { + t.Errorf("expected second block type 'text', got %q", result.Content[1].Type) + } +} + +func TestMapStopReason(t *testing.T) { + tests := []struct { + reason string + hasToolCalls bool + want string + }{ + {"stop", false, "end_turn"}, + {"length", false, "max_tokens"}, + {"stop", true, "tool_use"}, + {"other", false, "stop_sequence"}, + {"", false, ""}, + } + + for _, tt := range tests { + got := mapStopReason(tt.reason, tt.hasToolCalls) + if got != tt.want { + t.Errorf("mapStopReason(%q, %v) = %q, want %q", tt.reason, tt.hasToolCalls, got, tt.want) + } + } +} + +func TestNewError(t *testing.T) { + tests := []struct { + code int + want string + }{ + {400, "invalid_request_error"}, + {401, "authentication_error"}, + {403, "permission_error"}, + {404, "not_found_error"}, + {429, "rate_limit_error"}, + {500, "api_error"}, + {503, "overloaded_error"}, + {529, "overloaded_error"}, + } + + for _, tt := range tests { + result := NewError(tt.code, "test message") + if result.Type != "error" { + t.Errorf("NewError(%d) type = %q, want 'error'", tt.code, result.Type) + } + if result.Error.Type != tt.want { + t.Errorf("NewError(%d) error.type = %q, want %q", tt.code, result.Error.Type, tt.want) + } + if result.Error.Message != "test message" { + t.Errorf("NewError(%d) message = %q, want 'test message'", tt.code, result.Error.Message) + } + if result.RequestID == "" { + t.Errorf("NewError(%d) request_id should not be empty", tt.code) + } + } +} + +func TestGenerateMessageID(t *testing.T) { + id1 := GenerateMessageID() + id2 := GenerateMessageID() + + if id1 == "" { + t.Error("GenerateMessageID returned empty string") + } + if id1 == id2 { + t.Error("GenerateMessageID returned duplicate IDs") + } + if len(id1) < 10 { + t.Errorf("GenerateMessageID returned short ID: %q", id1) + } + if id1[:4] != "msg_" { + t.Errorf("GenerateMessageID should start with 'msg_', got %q", id1[:4]) + } +} + +func TestStreamConverter_Basic(t *testing.T) { + conv := NewStreamConverter("msg_123", "test-model") + + // First chunk + resp1 := api.ChatResponse{ + Model: "test-model", + Message: api.Message{ + Role: "assistant", + Content: "Hello", + }, + Metrics: api.Metrics{PromptEvalCount: 10}, + } + + events1 := conv.Process(resp1) + if len(events1) < 3 { + t.Fatalf("expected at least 3 events for first chunk, got %d", len(events1)) + } + + // Should have message_start, content_block_start, content_block_delta + if events1[0].Event != "message_start" { + t.Errorf("expected first event 'message_start', got %q", events1[0].Event) + } + if events1[1].Event != "content_block_start" { + t.Errorf("expected second event 'content_block_start', got %q", events1[1].Event) + } + if events1[2].Event != "content_block_delta" { + t.Errorf("expected third event 'content_block_delta', got %q", events1[2].Event) + } + + // Final chunk + resp2 := api.ChatResponse{ + Model: "test-model", + Message: api.Message{ + Role: "assistant", + Content: " world!", + }, + Done: true, + DoneReason: "stop", + Metrics: api.Metrics{EvalCount: 5}, + } + + events2 := conv.Process(resp2) + + // Should have content_block_delta, content_block_stop, message_delta, message_stop + hasStop := false + for _, e := range events2 { + if e.Event == "message_stop" { + hasStop = true + } + } + if !hasStop { + t.Error("expected message_stop event in final chunk") + } +} + +func TestStreamConverter_WithToolCalls(t *testing.T) { + conv := 
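+	// The converter is stateful across chunks (e.g. message_start is emitted only for
+	// the first chunk), so each test constructs a fresh one.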
NewStreamConverter("msg_123", "test-model") + + resp := api.ChatResponse{ + Model: "test-model", + Message: api.Message{ + Role: "assistant", + ToolCalls: []api.ToolCall{ + { + ID: "call_123", + Function: api.ToolCallFunction{ + Name: "get_weather", + Arguments: map[string]any{"location": "Paris"}, + }, + }, + }, + }, + Done: true, + DoneReason: "stop", + Metrics: api.Metrics{PromptEvalCount: 10, EvalCount: 5}, + } + + events := conv.Process(resp) + + hasToolStart := false + hasToolDelta := false + for _, e := range events { + if e.Event == "content_block_start" { + if start, ok := e.Data.(ContentBlockStartEvent); ok { + if start.ContentBlock.Type == "tool_use" { + hasToolStart = true + } + } + } + if e.Event == "content_block_delta" { + if delta, ok := e.Data.(ContentBlockDeltaEvent); ok { + if delta.Delta.Type == "input_json_delta" { + hasToolDelta = true + } + } + } + } + + if !hasToolStart { + t.Error("expected tool_use content_block_start event") + } + if !hasToolDelta { + t.Error("expected input_json_delta event") + } +} diff --git a/docs/README.md b/docs/README.md index 74544a321..4483eb550 100644 --- a/docs/README.md +++ b/docs/README.md @@ -14,6 +14,7 @@ * [API Reference](https://docs.ollama.com/api) * [Modelfile Reference](https://docs.ollama.com/modelfile) * [OpenAI Compatibility](https://docs.ollama.com/api/openai-compatibility) +* [Anthropic Compatibility](./api/anthropic-compatibility.mdx) ### Resources diff --git a/docs/docs.json b/docs/docs.json index 71a6f17a0..47b865d20 100644 --- a/docs/docs.json +++ b/docs/docs.json @@ -139,7 +139,8 @@ "/api/streaming", "/api/usage", "/api/errors", - "/api/openai-compatibility" + "/api/openai-compatibility", + "/api/anthropic-compatibility" ] }, { diff --git a/middleware/anthropic_test.go b/middleware/anthropic_test.go new file mode 100644 index 000000000..81c68cce1 --- /dev/null +++ b/middleware/anthropic_test.go @@ -0,0 +1,487 @@ +package middleware + +import ( + "bytes" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/gin-gonic/gin" + "github.com/google/go-cmp/cmp" + + "github.com/ollama/ollama/anthropic" + "github.com/ollama/ollama/api" +) + +func captureAnthropicRequest(capturedRequest any) gin.HandlerFunc { + return func(c *gin.Context) { + bodyBytes, _ := io.ReadAll(c.Request.Body) + c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes)) + _ = json.Unmarshal(bodyBytes, capturedRequest) + c.Next() + } +} + +func TestAnthropicMessagesMiddleware(t *testing.T) { + type testCase struct { + name string + body string + req api.ChatRequest + err anthropic.ErrorResponse + } + + var capturedRequest *api.ChatRequest + stream := true + + testCases := []testCase{ + { + name: "basic message", + body: `{ + "model": "test-model", + "max_tokens": 1024, + "messages": [ + {"role": "user", "content": "Hello"} + ] + }`, + req: api.ChatRequest{ + Model: "test-model", + Messages: []api.Message{ + {Role: "user", Content: "Hello"}, + }, + Options: map[string]any{"num_predict": 1024}, + Stream: &False, + }, + }, + { + name: "with system prompt", + body: `{ + "model": "test-model", + "max_tokens": 1024, + "system": "You are helpful.", + "messages": [ + {"role": "user", "content": "Hello"} + ] + }`, + req: api.ChatRequest{ + Model: "test-model", + Messages: []api.Message{ + {Role: "system", Content: "You are helpful."}, + {Role: "user", Content: "Hello"}, + }, + Options: map[string]any{"num_predict": 1024}, + Stream: &False, + }, + }, + { + name: "with options", + body: `{ + "model": "test-model", + 
"max_tokens": 2048, + "temperature": 0.7, + "top_p": 0.9, + "top_k": 40, + "stop_sequences": ["\n", "END"], + "messages": [ + {"role": "user", "content": "Hello"} + ] + }`, + req: api.ChatRequest{ + Model: "test-model", + Messages: []api.Message{ + {Role: "user", Content: "Hello"}, + }, + Options: map[string]any{ + "num_predict": 2048, + "temperature": 0.7, + "top_p": 0.9, + "top_k": 40, + "stop": []string{"\n", "END"}, + }, + Stream: &False, + }, + }, + { + name: "streaming", + body: `{ + "model": "test-model", + "max_tokens": 1024, + "stream": true, + "messages": [ + {"role": "user", "content": "Hello"} + ] + }`, + req: api.ChatRequest{ + Model: "test-model", + Messages: []api.Message{ + {Role: "user", Content: "Hello"}, + }, + Options: map[string]any{"num_predict": 1024}, + Stream: &stream, + }, + }, + { + name: "with tools", + body: `{ + "model": "test-model", + "max_tokens": 1024, + "messages": [ + {"role": "user", "content": "What's the weather?"} + ], + "tools": [{ + "name": "get_weather", + "description": "Get current weather", + "input_schema": { + "type": "object", + "properties": { + "location": {"type": "string"} + }, + "required": ["location"] + } + }] + }`, + req: api.ChatRequest{ + Model: "test-model", + Messages: []api.Message{ + {Role: "user", Content: "What's the weather?"}, + }, + Tools: []api.Tool{ + { + Type: "function", + Function: api.ToolFunction{ + Name: "get_weather", + Description: "Get current weather", + Parameters: api.ToolFunctionParameters{ + Type: "object", + Required: []string{"location"}, + Properties: map[string]api.ToolProperty{ + "location": {Type: api.PropertyType{"string"}}, + }, + }, + }, + }, + }, + Options: map[string]any{"num_predict": 1024}, + Stream: &False, + }, + }, + { + name: "with tool result", + body: `{ + "model": "test-model", + "max_tokens": 1024, + "messages": [ + {"role": "user", "content": "What's the weather?"}, + {"role": "assistant", "content": [ + {"type": "tool_use", "id": "call_123", "name": "get_weather", "input": {"location": "Paris"}} + ]}, + {"role": "user", "content": [ + {"type": "tool_result", "tool_use_id": "call_123", "content": "Sunny, 22°C"} + ]} + ] + }`, + req: api.ChatRequest{ + Model: "test-model", + Messages: []api.Message{ + {Role: "user", Content: "What's the weather?"}, + { + Role: "assistant", + ToolCalls: []api.ToolCall{ + { + ID: "call_123", + Function: api.ToolCallFunction{ + Name: "get_weather", + Arguments: api.ToolCallFunctionArguments{"location": "Paris"}, + }, + }, + }, + }, + {Role: "tool", Content: "Sunny, 22°C", ToolCallID: "call_123"}, + }, + Options: map[string]any{"num_predict": 1024}, + Stream: &False, + }, + }, + { + name: "with thinking enabled", + body: `{ + "model": "test-model", + "max_tokens": 1024, + "thinking": {"type": "enabled", "budget_tokens": 1000}, + "messages": [ + {"role": "user", "content": "Hello"} + ] + }`, + req: api.ChatRequest{ + Model: "test-model", + Messages: []api.Message{ + {Role: "user", Content: "Hello"}, + }, + Options: map[string]any{"num_predict": 1024}, + Stream: &False, + Think: &api.ThinkValue{Value: true}, + }, + }, + { + name: "missing model error", + body: `{ + "max_tokens": 1024, + "messages": [ + {"role": "user", "content": "Hello"} + ] + }`, + err: anthropic.ErrorResponse{ + Type: "error", + Error: anthropic.Error{ + Type: "invalid_request_error", + Message: "model is required", + }, + }, + }, + { + name: "missing max_tokens error", + body: `{ + "model": "test-model", + "messages": [ + {"role": "user", "content": "Hello"} + ] + }`, + err: 
anthropic.ErrorResponse{ + Type: "error", + Error: anthropic.Error{ + Type: "invalid_request_error", + Message: "max_tokens is required and must be positive", + }, + }, + }, + { + name: "missing messages error", + body: `{ + "model": "test-model", + "max_tokens": 1024 + }`, + err: anthropic.ErrorResponse{ + Type: "error", + Error: anthropic.Error{ + Type: "invalid_request_error", + Message: "messages is required", + }, + }, + }, + { + name: "tool_use missing id error", + body: `{ + "model": "test-model", + "max_tokens": 1024, + "messages": [ + {"role": "assistant", "content": [ + {"type": "tool_use", "name": "test"} + ]} + ] + }`, + err: anthropic.ErrorResponse{ + Type: "error", + Error: anthropic.Error{ + Type: "invalid_request_error", + Message: "tool_use block missing required 'id' field", + }, + }, + }, + } + + endpoint := func(c *gin.Context) { + c.Status(http.StatusOK) + } + + gin.SetMode(gin.TestMode) + router := gin.New() + router.Use(AnthropicMessagesMiddleware(), captureAnthropicRequest(&capturedRequest)) + router.Handle(http.MethodPost, "/v1/messages", endpoint) + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(tc.body)) + req.Header.Set("Content-Type", "application/json") + + defer func() { capturedRequest = nil }() + + resp := httptest.NewRecorder() + router.ServeHTTP(resp, req) + + if tc.err.Type != "" { + // Expect error + if resp.Code == http.StatusOK { + t.Fatalf("expected error response, got 200 OK") + } + var errResp anthropic.ErrorResponse + if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil { + t.Fatalf("failed to unmarshal error: %v", err) + } + if errResp.Type != tc.err.Type { + t.Errorf("expected error type %q, got %q", tc.err.Type, errResp.Type) + } + if errResp.Error.Type != tc.err.Error.Type { + t.Errorf("expected error.type %q, got %q", tc.err.Error.Type, errResp.Error.Type) + } + if errResp.Error.Message != tc.err.Error.Message { + t.Errorf("expected error.message %q, got %q", tc.err.Error.Message, errResp.Error.Message) + } + return + } + + if resp.Code != http.StatusOK { + t.Fatalf("unexpected status code: %d, body: %s", resp.Code, resp.Body.String()) + } + + if capturedRequest == nil { + t.Fatal("request was not captured") + } + + // Compare relevant fields + if capturedRequest.Model != tc.req.Model { + t.Errorf("model mismatch: got %q, want %q", capturedRequest.Model, tc.req.Model) + } + + if diff := cmp.Diff(tc.req.Messages, capturedRequest.Messages); diff != "" { + t.Errorf("messages mismatch (-want +got):\n%s", diff) + } + + if tc.req.Stream != nil && capturedRequest.Stream != nil { + if *tc.req.Stream != *capturedRequest.Stream { + t.Errorf("stream mismatch: got %v, want %v", *capturedRequest.Stream, *tc.req.Stream) + } + } + + if tc.req.Think != nil { + if capturedRequest.Think == nil { + t.Error("expected Think to be set") + } else if capturedRequest.Think.Value != tc.req.Think.Value { + t.Errorf("Think mismatch: got %v, want %v", capturedRequest.Think.Value, tc.req.Think.Value) + } + } + }) + } +} + +func TestAnthropicMessagesMiddleware_Headers(t *testing.T) { + gin.SetMode(gin.TestMode) + + t.Run("streaming sets correct headers", func(t *testing.T) { + router := gin.New() + router.Use(AnthropicMessagesMiddleware()) + router.POST("/v1/messages", func(c *gin.Context) { + // Check headers were set + if c.Writer.Header().Get("Content-Type") != "text/event-stream" { + t.Errorf("expected Content-Type text/event-stream, got %q", 
c.Writer.Header().Get("Content-Type")) + } + if c.Writer.Header().Get("Cache-Control") != "no-cache" { + t.Errorf("expected Cache-Control no-cache, got %q", c.Writer.Header().Get("Cache-Control")) + } + c.Status(http.StatusOK) + }) + + body := `{"model": "test", "max_tokens": 100, "stream": true, "messages": [{"role": "user", "content": "Hi"}]}` + req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + + resp := httptest.NewRecorder() + router.ServeHTTP(resp, req) + }) +} + +func TestAnthropicMessagesMiddleware_InvalidJSON(t *testing.T) { + gin.SetMode(gin.TestMode) + router := gin.New() + router.Use(AnthropicMessagesMiddleware()) + router.POST("/v1/messages", func(c *gin.Context) { + c.Status(http.StatusOK) + }) + + req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(`{invalid json`)) + req.Header.Set("Content-Type", "application/json") + + resp := httptest.NewRecorder() + router.ServeHTTP(resp, req) + + if resp.Code != http.StatusBadRequest { + t.Errorf("expected status 400, got %d", resp.Code) + } + + var errResp anthropic.ErrorResponse + if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil { + t.Fatalf("failed to unmarshal error: %v", err) + } + + if errResp.Type != "error" { + t.Errorf("expected type 'error', got %q", errResp.Type) + } + if errResp.Error.Type != "invalid_request_error" { + t.Errorf("expected error type 'invalid_request_error', got %q", errResp.Error.Type) + } +} + +func TestAnthropicWriter_NonStreaming(t *testing.T) { + gin.SetMode(gin.TestMode) + + router := gin.New() + router.Use(AnthropicMessagesMiddleware()) + router.POST("/v1/messages", func(c *gin.Context) { + // Simulate Ollama response + resp := api.ChatResponse{ + Model: "test-model", + Message: api.Message{ + Role: "assistant", + Content: "Hello there!", + }, + Done: true, + DoneReason: "stop", + Metrics: api.Metrics{ + PromptEvalCount: 10, + EvalCount: 5, + }, + } + data, _ := json.Marshal(resp) + c.Writer.WriteHeader(http.StatusOK) + _, _ = c.Writer.Write(data) + }) + + body := `{"model": "test-model", "max_tokens": 100, "messages": [{"role": "user", "content": "Hi"}]}` + req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + + resp := httptest.NewRecorder() + router.ServeHTTP(resp, req) + + if resp.Code != http.StatusOK { + t.Fatalf("expected status 200, got %d", resp.Code) + } + + var result anthropic.MessagesResponse + if err := json.Unmarshal(resp.Body.Bytes(), &result); err != nil { + t.Fatalf("failed to unmarshal response: %v", err) + } + + if result.Type != "message" { + t.Errorf("expected type 'message', got %q", result.Type) + } + if result.Role != "assistant" { + t.Errorf("expected role 'assistant', got %q", result.Role) + } + if len(result.Content) != 1 { + t.Fatalf("expected 1 content block, got %d", len(result.Content)) + } + if result.Content[0].Text != "Hello there!" { + t.Errorf("expected text 'Hello there!', got %q", result.Content[0].Text) + } + if result.StopReason != "end_turn" { + t.Errorf("expected stop_reason 'end_turn', got %q", result.StopReason) + } + if result.Usage.InputTokens != 10 { + t.Errorf("expected input_tokens 10, got %d", result.Usage.InputTokens) + } + if result.Usage.OutputTokens != 5 { + t.Errorf("expected output_tokens 5, got %d", result.Usage.OutputTokens) + } +}
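
The new tests run with the standard Go tooling:

    go test ./anthropic ./middleware

For a quick manual check against a live server, the request shape used throughout the middleware tests can be posted directly (a sketch, assuming the middleware is mounted at /v1/messages as in the tests, a server on the default localhost:11434, and a locally available model in place of the placeholder):

    curl http://localhost:11434/v1/messages \
      -H "Content-Type: application/json" \
      -d '{"model": "<model>", "max_tokens": 1024, "messages": [{"role": "user", "content": "Hello"}]}'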