package anthropic

import (
	"encoding/base64"
	"encoding/json"
	"strings"
	"testing"

	"github.com/google/go-cmp/cmp"
	"github.com/ollama/ollama/api"
)

// testImage is a tiny base64-encoded PNG used as image input in request tests.
const (
	testImage = `iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=`
)

// TestFromMessagesRequest_Basic checks that a minimal request maps the model,
// a single user message, and MaxTokens (as the "num_predict" option).
func TestFromMessagesRequest_Basic(t *testing.T) {
	req := MessagesRequest{
		Model:     "test-model",
		MaxTokens: 1024,
		Messages: []MessageParam{
			{Role: "user", Content: "Hello"},
		},
	}

	result, err := FromMessagesRequest(req)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	if result.Model != "test-model" {
		t.Errorf("expected model 'test-model', got %q", result.Model)
	}

	if len(result.Messages) != 1 {
		t.Fatalf("expected 1 message, got %d", len(result.Messages))
	}

	if result.Messages[0].Role != "user" || result.Messages[0].Content != "Hello" {
		t.Errorf("unexpected message: %+v", result.Messages[0])
	}

	if numPredict, ok := result.Options["num_predict"].(int); !ok || numPredict != 1024 {
		t.Errorf("expected num_predict 1024, got %v", result.Options["num_predict"])
	}
}

// TestFromMessagesRequest_WithSystemPrompt checks that a string System prompt
// becomes a leading "system" message.
func TestFromMessagesRequest_WithSystemPrompt(t *testing.T) {
	req := MessagesRequest{
		Model:     "test-model",
		MaxTokens: 1024,
		System:    "You are a helpful assistant.",
		Messages: []MessageParam{
			{Role: "user", Content: "Hello"},
		},
	}

	result, err := FromMessagesRequest(req)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	if len(result.Messages) != 2 {
		t.Fatalf("expected 2 messages, got %d", len(result.Messages))
	}

	if result.Messages[0].Role != "system" || result.Messages[0].Content != "You are a helpful assistant." {
		t.Errorf("unexpected system message: %+v", result.Messages[0])
	}
}

// TestFromMessagesRequest_WithSystemPromptArray checks that an array of text
// blocks in System is concatenated into one system message.
func TestFromMessagesRequest_WithSystemPromptArray(t *testing.T) {
	req := MessagesRequest{
		Model:     "test-model",
		MaxTokens: 1024,
		System: []any{
			map[string]any{"type": "text", "text": "You are helpful."},
			map[string]any{"type": "text", "text": " Be concise."},
		},
		Messages: []MessageParam{
			{Role: "user", Content: "Hello"},
		},
	}

	result, err := FromMessagesRequest(req)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	if len(result.Messages) != 2 {
		t.Fatalf("expected 2 messages, got %d", len(result.Messages))
	}

	if result.Messages[0].Content != "You are helpful. Be concise." {
		t.Errorf("unexpected system message content: %q", result.Messages[0].Content)
	}
}

// TestFromMessagesRequest_WithOptions checks that sampling parameters and stop
// sequences are copied into the Options map under their Ollama names.
func TestFromMessagesRequest_WithOptions(t *testing.T) {
	temp := 0.7
	topP := 0.9
	topK := 40
	req := MessagesRequest{
		Model:         "test-model",
		MaxTokens:     2048,
		Messages:      []MessageParam{{Role: "user", Content: "Hello"}},
		Temperature:   &temp,
		TopP:          &topP,
		TopK:          &topK,
		StopSequences: []string{"\n", "END"},
	}

	result, err := FromMessagesRequest(req)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	// Comparing against untyped constants: 0.7/0.9 compare as float64, 40 as int.
	if result.Options["temperature"] != 0.7 {
		t.Errorf("expected temperature 0.7, got %v", result.Options["temperature"])
	}
	if result.Options["top_p"] != 0.9 {
		t.Errorf("expected top_p 0.9, got %v", result.Options["top_p"])
	}
	if result.Options["top_k"] != 40 {
		t.Errorf("expected top_k 40, got %v", result.Options["top_k"])
	}
	if diff := cmp.Diff([]string{"\n", "END"}, result.Options["stop"]); diff != "" {
		t.Errorf("stop sequences mismatch: %s", diff)
	}
}

// TestFromMessagesRequest_WithImage checks that a base64 image block is decoded
// into the message's Images and that the text block becomes its Content.
func TestFromMessagesRequest_WithImage(t *testing.T) {
	// Fail loudly if the fixture itself is broken; otherwise the mismatch
	// below would be reported as a conversion bug.
	imgData, err := base64.StdEncoding.DecodeString(testImage)
	if err != nil {
		t.Fatalf("decoding test image: %v", err)
	}

	req := MessagesRequest{
		Model:     "test-model",
		MaxTokens: 1024,
		Messages: []MessageParam{
			{
				Role: "user",
				Content: []any{
					map[string]any{"type": "text", "text": "What's in this image?"},
					map[string]any{
						"type": "image",
						"source": map[string]any{
							"type":       "base64",
							"media_type": "image/png",
							"data":       testImage,
						},
					},
				},
			},
		},
	}

	result, err := FromMessagesRequest(req)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	if len(result.Messages) != 1 {
		t.Fatalf("expected 1 message, got %d", len(result.Messages))
	}

	if result.Messages[0].Content != "What's in this image?" {
		t.Errorf("expected content 'What's in this image?', got %q", result.Messages[0].Content)
	}

	if len(result.Messages[0].Images) != 1 {
		t.Fatalf("expected 1 image, got %d", len(result.Messages[0].Images))
	}

	if string(result.Messages[0].Images[0]) != string(imgData) {
		t.Error("image data mismatch")
	}
}

// TestFromMessagesRequest_WithToolUse checks that an assistant tool_use block is
// converted into a ToolCall carrying the block's id and name.
func TestFromMessagesRequest_WithToolUse(t *testing.T) {
	req := MessagesRequest{
		Model:     "test-model",
		MaxTokens: 1024,
		Messages: []MessageParam{
			{Role: "user", Content: "What's the weather in Paris?"},
			{
				Role: "assistant",
				Content: []any{
					map[string]any{
						"type":  "tool_use",
						"id":    "call_123",
						"name":  "get_weather",
						"input": map[string]any{"location": "Paris"},
					},
				},
			},
		},
	}

	result, err := FromMessagesRequest(req)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	if len(result.Messages) != 2 {
		t.Fatalf("expected 2 messages, got %d", len(result.Messages))
	}

	if len(result.Messages[1].ToolCalls) != 1 {
		t.Fatalf("expected 1 tool call, got %d", len(result.Messages[1].ToolCalls))
	}

	tc := result.Messages[1].ToolCalls[0]
	if tc.ID != "call_123" {
		t.Errorf("expected tool call ID 'call_123', got %q", tc.ID)
	}
	if tc.Function.Name != "get_weather" {
		t.Errorf("expected tool name 'get_weather', got %q", tc.Function.Name)
	}
}

// TestFromMessagesRequest_WithToolResult checks that a user tool_result block
// becomes a "tool" message with the matching ToolCallID and content.
func TestFromMessagesRequest_WithToolResult(t *testing.T) {
	req := MessagesRequest{
		Model:     "test-model",
		MaxTokens: 1024,
		Messages: []MessageParam{
			{
				Role: "user",
				Content: []any{
					map[string]any{
						"type":        "tool_result",
						"tool_use_id": "call_123",
						"content":     "The weather in Paris is sunny, 22°C",
					},
				},
			},
		},
	}

	result, err := FromMessagesRequest(req)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	if len(result.Messages) != 1 {
		t.Fatalf("expected 1 message, got %d", len(result.Messages))
	}

	msg := result.Messages[0]
	if msg.Role != "tool" {
		t.Errorf("expected role 'tool', got %q", msg.Role)
	}
	if msg.ToolCallID != "call_123" {
		t.Errorf("expected tool_call_id 'call_123', got %q", msg.ToolCallID)
	}
	if msg.Content != "The weather in Paris is sunny, 22°C" {
		t.Errorf("unexpected content: %q", msg.Content)
	}
}

// TestFromMessagesRequest_WithTools checks that tool definitions are converted
// to "function" tools with the same name and description.
func TestFromMessagesRequest_WithTools(t *testing.T) {
	req := MessagesRequest{
		Model:     "test-model",
		MaxTokens: 1024,
		Messages:  []MessageParam{{Role: "user", Content: "Hello"}},
		Tools: []Tool{
			{
				Name:        "get_weather",
				Description: "Get current weather",
				InputSchema: json.RawMessage(`{"type":"object","properties":{"location":{"type":"string"}},"required":["location"]}`),
			},
		},
	}

	result, err := FromMessagesRequest(req)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	if len(result.Tools) != 1 {
		t.Fatalf("expected 1 tool, got %d", len(result.Tools))
	}

	tool := result.Tools[0]
	if tool.Type != "function" {
		t.Errorf("expected type 'function', got %q", tool.Type)
	}
	if tool.Function.Name != "get_weather" {
		t.Errorf("expected name 'get_weather', got %q", tool.Function.Name)
	}
	if tool.Function.Description != "Get current weather" {
		t.Errorf("expected description 'Get current weather', got %q", tool.Function.Description)
	}
}

// TestFromMessagesRequest_WithThinking checks that an enabled ThinkingConfig
// sets Think to a boolean true value.
func TestFromMessagesRequest_WithThinking(t *testing.T) {
	req := MessagesRequest{
		Model:     "test-model",
		MaxTokens: 1024,
		Messages:  []MessageParam{{Role: "user", Content: "Hello"}},
		Thinking:  &ThinkingConfig{Type: "enabled", BudgetTokens: 1000},
	}

	result, err := FromMessagesRequest(req)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	if result.Think == nil {
		t.Fatal("expected Think to be set")
	}

	if v, ok := result.Think.Value.(bool); !ok || !v {
		t.Errorf("expected Think.Value to be true, got %v", result.Think.Value)
	}
}

// TestFromMessagesRequest_ToolUseMissingID checks that a tool_use block without
// an "id" is rejected with a specific error message.
func TestFromMessagesRequest_ToolUseMissingID(t *testing.T) {
	req := MessagesRequest{
		Model:     "test-model",
		MaxTokens: 1024,
		Messages: []MessageParam{
			{
				Role: "assistant",
				Content: []any{
					map[string]any{
						"type": "tool_use",
						"name": "get_weather",
					},
				},
			},
		},
	}

	_, err := FromMessagesRequest(req)
	if err == nil {
		t.Fatal("expected error for missing tool_use id")
	}
	if err.Error() != "tool_use block missing required 'id' field" {
		t.Errorf("unexpected error message: %v", err)
	}
}

// TestFromMessagesRequest_ToolUseMissingName checks that a tool_use block
// without a "name" is rejected with a specific error message.
func TestFromMessagesRequest_ToolUseMissingName(t *testing.T) {
	req := MessagesRequest{
		Model:     "test-model",
		MaxTokens: 1024,
		Messages: []MessageParam{
			{
				Role: "assistant",
				Content: []any{
					map[string]any{
						"type": "tool_use",
						"id":   "call_123",
					},
				},
			},
		},
	}

	_, err := FromMessagesRequest(req)
	if err == nil {
		t.Fatal("expected error for missing tool_use name")
	}
	if err.Error() != "tool_use block missing required 'name' field" {
		t.Errorf("unexpected error message: %v", err)
	}
}

// TestFromMessagesRequest_InvalidToolSchema checks that malformed JSON in a
// tool's InputSchema produces an error.
func TestFromMessagesRequest_InvalidToolSchema(t *testing.T) {
	req := MessagesRequest{
		Model:     "test-model",
		MaxTokens: 1024,
		Messages:  []MessageParam{{Role: "user", Content: "Hello"}},
		Tools: []Tool{
			{
				Name:        "bad_tool",
				InputSchema: json.RawMessage(`{invalid json`),
			},
		},
	}

	_, err := FromMessagesRequest(req)
	if err == nil {
		t.Fatal("expected error for invalid tool schema")
	}
}

// TestToMessagesResponse_Basic checks the envelope fields, the single text
// block, the "stop" -> "end_turn" mapping, and token usage.
func TestToMessagesResponse_Basic(t *testing.T) {
	resp := api.ChatResponse{
		Model: "test-model",
		Message: api.Message{
			Role:    "assistant",
			Content: "Hello there!",
		},
		Done:       true,
		DoneReason: "stop",
		Metrics: api.Metrics{
			PromptEvalCount: 10,
			EvalCount:       5,
		},
	}

	result := ToMessagesResponse("msg_123", resp)

	if result.ID != "msg_123" {
		t.Errorf("expected ID 'msg_123', got %q", result.ID)
	}
	if result.Type != "message" {
		t.Errorf("expected type 'message', got %q", result.Type)
	}
	if result.Role != "assistant" {
		t.Errorf("expected role 'assistant', got %q", result.Role)
	}
	if len(result.Content) != 1 {
		t.Fatalf("expected 1 content block, got %d", len(result.Content))
	}
	if result.Content[0].Type != "text" || result.Content[0].Text != "Hello there!" {
		t.Errorf("unexpected content: %+v", result.Content[0])
	}
	if result.StopReason != "end_turn" {
		t.Errorf("expected stop_reason 'end_turn', got %q", result.StopReason)
	}
	if result.Usage.InputTokens != 10 || result.Usage.OutputTokens != 5 {
		t.Errorf("unexpected usage: %+v", result.Usage)
	}
}

// TestToMessagesResponse_WithToolCalls checks that tool calls become tool_use
// blocks and that the stop reason becomes "tool_use".
func TestToMessagesResponse_WithToolCalls(t *testing.T) {
	resp := api.ChatResponse{
		Model: "test-model",
		Message: api.Message{
			Role: "assistant",
			ToolCalls: []api.ToolCall{
				{
					ID: "call_123",
					Function: api.ToolCallFunction{
						Name:      "get_weather",
						Arguments: map[string]any{"location": "Paris"},
					},
				},
			},
		},
		Done:       true,
		DoneReason: "stop",
	}

	result := ToMessagesResponse("msg_123", resp)

	if len(result.Content) != 1 {
		t.Fatalf("expected 1 content block, got %d", len(result.Content))
	}
	if result.Content[0].Type != "tool_use" {
		t.Errorf("expected type 'tool_use', got %q", result.Content[0].Type)
	}
	if result.Content[0].ID != "call_123" {
		t.Errorf("expected ID 'call_123', got %q", result.Content[0].ID)
	}
	if result.Content[0].Name != "get_weather" {
		t.Errorf("expected name 'get_weather', got %q", result.Content[0].Name)
	}
	if result.StopReason != "tool_use" {
		t.Errorf("expected stop_reason 'tool_use', got %q", result.StopReason)
	}
}

// TestToMessagesResponse_WithThinking checks that thinking output is emitted as
// a "thinking" block before the "text" block and that both carry their content.
func TestToMessagesResponse_WithThinking(t *testing.T) {
	resp := api.ChatResponse{
		Model: "test-model",
		Message: api.Message{
			Role:     "assistant",
			Content:  "The answer is 42.",
			Thinking: "Let me think about this...",
		},
		Done:       true,
		DoneReason: "stop",
	}

	result := ToMessagesResponse("msg_123", resp)

	if len(result.Content) != 2 {
		t.Fatalf("expected 2 content blocks, got %d", len(result.Content))
	}
	if result.Content[0].Type != "thinking" {
		t.Errorf("expected first block type 'thinking', got %q", result.Content[0].Type)
	}
	if result.Content[0].Thinking != "Let me think about this..." {
		t.Errorf("unexpected thinking content: %q", result.Content[0].Thinking)
	}
	if result.Content[1].Type != "text" {
		t.Errorf("expected second block type 'text', got %q", result.Content[1].Type)
	}
	if result.Content[1].Text != "The answer is 42." {
		t.Errorf("unexpected text content: %q", result.Content[1].Text)
	}
}

// TestMapStopReason table-tests the Ollama done-reason -> stop_reason mapping,
// including the tool-call override and the empty-reason case.
func TestMapStopReason(t *testing.T) {
	tests := []struct {
		reason       string
		hasToolCalls bool
		want         string
	}{
		{"stop", false, "end_turn"},
		{"length", false, "max_tokens"},
		{"stop", true, "tool_use"},
		{"other", false, "stop_sequence"},
		{"", false, ""},
	}

	for _, tt := range tests {
		got := mapStopReason(tt.reason, tt.hasToolCalls)
		if got != tt.want {
			t.Errorf("mapStopReason(%q, %v) = %q, want %q", tt.reason, tt.hasToolCalls, got, tt.want)
		}
	}
}

// TestNewError table-tests the HTTP status -> error type mapping and checks the
// envelope type, message passthrough, and a non-empty request ID.
func TestNewError(t *testing.T) {
	tests := []struct {
		code int
		want string
	}{
		{400, "invalid_request_error"},
		{401, "authentication_error"},
		{403, "permission_error"},
		{404, "not_found_error"},
		{429, "rate_limit_error"},
		{500, "api_error"},
		{503, "overloaded_error"},
		{529, "overloaded_error"},
	}

	for _, tt := range tests {
		result := NewError(tt.code, "test message")
		if result.Type != "error" {
			t.Errorf("NewError(%d) type = %q, want 'error'", tt.code, result.Type)
		}
		if result.Error.Type != tt.want {
			t.Errorf("NewError(%d) error.type = %q, want %q", tt.code, result.Error.Type, tt.want)
		}
		if result.Error.Message != "test message" {
			t.Errorf("NewError(%d) message = %q, want 'test message'", tt.code, result.Error.Message)
		}
		if result.RequestID == "" {
			t.Errorf("NewError(%d) request_id should not be empty", tt.code)
		}
	}
}

// TestGenerateMessageID checks that generated IDs are non-empty, distinct
// across calls, reasonably long, and prefixed with "msg_".
func TestGenerateMessageID(t *testing.T) {
	id1 := GenerateMessageID()
	id2 := GenerateMessageID()

	// Fatal rather than Error: the remaining checks are meaningless on "".
	if id1 == "" {
		t.Fatal("GenerateMessageID returned empty string")
	}
	if id1 == id2 {
		t.Error("GenerateMessageID returned duplicate IDs")
	}
	if len(id1) < 10 {
		t.Errorf("GenerateMessageID returned short ID: %q", id1)
	}
	// HasPrefix instead of id1[:4] so a too-short ID fails cleanly rather
	// than panicking with a slice-bounds error.
	if !strings.HasPrefix(id1, "msg_") {
		t.Errorf("GenerateMessageID should start with 'msg_', got %q", id1)
	}
}

// TestStreamConverter_Basic checks the event sequence for a two-chunk text
// stream: the first chunk opens the message and content block, and the final
// (Done) chunk emits message_stop.
func TestStreamConverter_Basic(t *testing.T) {
	conv := NewStreamConverter("msg_123", "test-model")

	// First chunk
	resp1 := api.ChatResponse{
		Model: "test-model",
		Message: api.Message{
			Role:    "assistant",
			Content: "Hello",
		},
		Metrics: api.Metrics{PromptEvalCount: 10},
	}

	events1 := conv.Process(resp1)
	if len(events1) < 3 {
		t.Fatalf("expected at least 3 events for first chunk, got %d", len(events1))
	}

	// Should have message_start, content_block_start, content_block_delta
	if events1[0].Event != "message_start" {
		t.Errorf("expected first event 'message_start', got %q", events1[0].Event)
	}
	if events1[1].Event != "content_block_start" {
		t.Errorf("expected second event 'content_block_start', got %q", events1[1].Event)
	}
	if events1[2].Event != "content_block_delta" {
		t.Errorf("expected third event 'content_block_delta', got %q", events1[2].Event)
	}

	// Final chunk
	resp2 := api.ChatResponse{
		Model: "test-model",
		Message: api.Message{
			Role:    "assistant",
			Content: " world!",
		},
		Done:       true,
		DoneReason: "stop",
		Metrics:    api.Metrics{EvalCount: 5},
	}

	events2 := conv.Process(resp2)

	// Only message_stop is asserted here; the other final-chunk events are
	// not checked by this test.
	hasStop := false
	for _, e := range events2 {
		if e.Event == "message_stop" {
			hasStop = true
		}
	}
	if !hasStop {
		t.Error("expected message_stop event in final chunk")
	}
}

// TestStreamConverter_WithToolCalls checks that a tool call in a stream chunk
// produces a tool_use content_block_start and an input_json_delta.
func TestStreamConverter_WithToolCalls(t *testing.T) {
	conv := NewStreamConverter("msg_123", "test-model")

	resp := api.ChatResponse{
		Model: "test-model",
		Message: api.Message{
			Role: "assistant",
			ToolCalls: []api.ToolCall{
				{
					ID: "call_123",
					Function: api.ToolCallFunction{
						Name:      "get_weather",
						Arguments: map[string]any{"location": "Paris"},
					},
				},
			},
		},
		Done:       true,
		DoneReason: "stop",
		Metrics:    api.Metrics{PromptEvalCount: 10, EvalCount: 5},
	}

	events := conv.Process(resp)

	hasToolStart := false
	hasToolDelta := false
	for _, e := range events {
		if e.Event == "content_block_start" {
			if start, ok := e.Data.(ContentBlockStartEvent); ok {
				if start.ContentBlock.Type == "tool_use" {
					hasToolStart = true
				}
			}
		}
		if e.Event == "content_block_delta" {
			if delta, ok := e.Data.(ContentBlockDeltaEvent); ok {
				if delta.Delta.Type == "input_json_delta" {
					hasToolDelta = true
				}
			}
		}
	}

	if !hasToolStart {
		t.Error("expected tool_use content_block_start event")
	}
	if !hasToolDelta {
		t.Error("expected input_json_delta event")
	}
}

// TestStreamConverter_ToolCallWithUnmarshalableArgs checks that tool-call
// arguments which cannot be JSON-marshaled (a channel) do not panic and do not
// open a tool_use block.
func TestStreamConverter_ToolCallWithUnmarshalableArgs(t *testing.T) {
	conv := NewStreamConverter("msg_123", "test-model")

	// A channel cannot be JSON-marshaled.
	unmarshalable := make(chan int)

	resp := api.ChatResponse{
		Model: "test-model",
		Message: api.Message{
			Role: "assistant",
			ToolCalls: []api.ToolCall{
				{
					ID: "call_bad",
					Function: api.ToolCallFunction{
						Name:      "bad_function",
						Arguments: map[string]any{"channel": unmarshalable},
					},
				},
			},
		},
		Done:       true,
		DoneReason: "stop",
	}

	// Should not panic and should skip the unmarshalable tool call.
	events := conv.Process(resp)

	// No tool_use block should have been started, since marshaling fails
	// before the block start would be emitted.
	hasToolStart := false
	for _, e := range events {
		if e.Event == "content_block_start" {
			if start, ok := e.Data.(ContentBlockStartEvent); ok {
				if start.ContentBlock.Type == "tool_use" {
					hasToolStart = true
				}
			}
		}
	}

	if hasToolStart {
		t.Error("expected no tool_use block when arguments cannot be marshaled")
	}
}

// TestStreamConverter_MultipleToolCallsWithMixedValidity checks that a valid
// tool call is still streamed when a sibling call has unmarshalable arguments.
func TestStreamConverter_MultipleToolCallsWithMixedValidity(t *testing.T) {
	conv := NewStreamConverter("msg_123", "test-model")

	unmarshalable := make(chan int)

	resp := api.ChatResponse{
		Model: "test-model",
		Message: api.Message{
			Role: "assistant",
			ToolCalls: []api.ToolCall{
				{
					ID: "call_good",
					Function: api.ToolCallFunction{
						Name:      "good_function",
						Arguments: map[string]any{"location": "Paris"},
					},
				},
				{
					ID: "call_bad",
					Function: api.ToolCallFunction{
						Name:      "bad_function",
						Arguments: map[string]any{"channel": unmarshalable},
					},
				},
			},
		},
		Done:       true,
		DoneReason: "stop",
	}

	events := conv.Process(resp)

	// Count tool_use blocks: only the valid call should produce one.
	toolStartCount := 0
	toolDeltaCount := 0
	for _, e := range events {
		if e.Event == "content_block_start" {
			if start, ok := e.Data.(ContentBlockStartEvent); ok {
				if start.ContentBlock.Type == "tool_use" {
					toolStartCount++
					if start.ContentBlock.Name != "good_function" {
						t.Errorf("expected tool name 'good_function', got %q", start.ContentBlock.Name)
					}
				}
			}
		}
		if e.Event == "content_block_delta" {
			if delta, ok := e.Data.(ContentBlockDeltaEvent); ok {
				if delta.Delta.Type == "input_json_delta" {
					toolDeltaCount++
				}
			}
		}
	}

	if toolStartCount != 1 {
		t.Errorf("expected 1 tool_use block, got %d", toolStartCount)
	}
	if toolDeltaCount != 1 {
		t.Errorf("expected 1 input_json_delta, got %d", toolDeltaCount)
	}
}