update tests

This commit is contained in:
ParthSareen 2025-11-14 16:00:34 -08:00
parent 471cbbe95a
commit 0103a3a89b
1 changed files with 270 additions and 245 deletions

View File

@ -3,13 +3,18 @@
package integration package integration
import ( import (
"bytes"
"context" "context"
"encoding/json"
"fmt" "fmt"
"io"
"net/http"
"strings" "strings"
"testing" "testing"
"time" "time"
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
"github.com/ollama/ollama/openai"
) )
var libraryToolsModels = []string{ var libraryToolsModels = []string{
@ -29,302 +34,322 @@ var libraryToolsModels = []string{
"granite3.3", "granite3.3",
} }
func TestAPIToolCalling(t *testing.T) { func float64Ptr(v float64) *float64 {
initialTimeout := 60 * time.Second return &v
streamTimeout := 60 * time.Second
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute)
defer cancel()
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
for _, model := range libraryToolsModels {
t.Run(model, func(t *testing.T) {
if v, ok := minVRAM[model]; ok {
skipUnderMinVRAM(t, v)
}
if err := PullIfMissing(ctx, client, model); err != nil {
t.Fatalf("pull failed %s", err)
}
tools := []api.Tool{
{
Type: "function",
Function: api.ToolFunction{
Name: "get_weather",
Description: "Get the current weather in a given location",
Parameters: api.ToolFunctionParameters{
Type: "object",
Required: []string{"location"},
Properties: map[string]api.ToolProperty{
"location": {
Type: api.PropertyType{"string"},
Description: "The city and state, e.g. San Francisco, CA",
},
},
},
},
},
}
req := api.ChatRequest{
Model: model,
Messages: []api.Message{
{
Role: "user",
Content: "Call get_weather with location set to San Francisco.",
},
},
Tools: tools,
Options: map[string]any{
"temperature": 0,
},
}
stallTimer := time.NewTimer(initialTimeout)
var gotToolCall bool
var lastToolCall api.ToolCall
fn := func(response api.ChatResponse) error {
if len(response.Message.ToolCalls) > 0 {
gotToolCall = true
lastToolCall = response.Message.ToolCalls[len(response.Message.ToolCalls)-1]
}
if !stallTimer.Reset(streamTimeout) {
return fmt.Errorf("stall was detected while streaming response, aborting")
}
return nil
}
stream := true
req.Stream = &stream
done := make(chan int)
var genErr error
go func() {
genErr = client.Chat(ctx, &req, fn)
done <- 0
}()
select {
case <-stallTimer.C:
t.Errorf("tool-calling chat never started. Timed out after: %s", initialTimeout.String())
case <-done:
if genErr != nil {
t.Fatalf("chat failed: %v", genErr)
}
if !gotToolCall {
t.Fatalf("expected at least one tool call, got none")
}
if lastToolCall.Function.Name != "get_weather" {
t.Errorf("unexpected tool called: got %q want %q", lastToolCall.Function.Name, "get_weather")
}
if _, ok := lastToolCall.Function.Arguments["location"]; !ok {
t.Errorf("expected tool arguments to include 'location', got: %s", lastToolCall.Function.Arguments.String())
}
case <-ctx.Done():
t.Error("outer test context done while waiting for tool-calling chat")
}
})
}
} }
func TestAPIToolCallingMultiTurn(t *testing.T) { func sendOpenAIChatRequest(ctx context.Context, endpoint string, req openai.ChatCompletionRequest) (*openai.ChatCompletion, error) {
jsonData, err := json.Marshal(req)
if err != nil {
return nil, fmt.Errorf("failed to marshal request: %w", err)
}
httpReq, err := http.NewRequestWithContext(ctx, "POST", endpoint+"/v1/chat/completions", bytes.NewBuffer(jsonData))
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
httpReq.Header.Set("Content-Type", "application/json")
client := &http.Client{
Timeout: 10 * time.Minute,
}
resp, err := client.Do(httpReq)
if err != nil {
return nil, fmt.Errorf("failed to send request: %w", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read response: %w", err)
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("API error: status=%d, body=%s", resp.StatusCode, string(body))
}
var chatResp openai.ChatCompletion
if err := json.Unmarshal(body, &chatResp); err != nil {
return nil, fmt.Errorf("failed to unmarshal response: %w, body: %s", err, string(body))
}
return &chatResp, nil
}
func sendOpenAIChatStreamRequest(ctx context.Context, endpoint string, req openai.ChatCompletionRequest, fn func(openai.ChatCompletionChunk) error) error {
jsonData, err := json.Marshal(req)
if err != nil {
return fmt.Errorf("failed to marshal request: %w", err)
}
httpReq, err := http.NewRequestWithContext(ctx, "POST", endpoint+"/v1/chat/completions", bytes.NewBuffer(jsonData))
if err != nil {
return fmt.Errorf("failed to create request: %w", err)
}
httpReq.Header.Set("Content-Type", "application/json")
httpReq.Header.Set("Accept", "text/event-stream")
client := &http.Client{
Timeout: 0, // No timeout for streaming
}
resp, err := client.Do(httpReq)
if err != nil {
return fmt.Errorf("failed to send request: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
return fmt.Errorf("API error: status=%d, body=%s", resp.StatusCode, string(body))
}
decoder := resp.Body
reader := bytes.NewBuffer([]byte{})
buf := make([]byte, 4096)
for {
select {
case <-ctx.Done():
return ctx.Err()
default:
n, err := decoder.Read(buf)
if n > 0 {
reader.Write(buf[:n])
// Process complete lines
for {
line, err := reader.ReadString('\n')
if err != nil {
// Not a complete line yet
reader.WriteString(line)
break
}
line = strings.TrimSpace(line)
if strings.HasPrefix(line, "data: ") {
data := strings.TrimPrefix(line, "data: ")
if data == "[DONE]" {
return nil
}
var streamResp openai.ChatCompletionChunk
if err := json.Unmarshal([]byte(data), &streamResp); err != nil {
return fmt.Errorf("failed to unmarshal stream response: %w", err)
}
if err := fn(streamResp); err != nil {
return err
}
}
}
}
if err != nil {
if err != io.EOF {
return fmt.Errorf("error reading stream: %w", err)
}
break
}
}
}
return nil
}
// TestToolCallingAllAPIs tests both Ollama and OpenAI APIs with shared model loading
func TestToolCallingAllAPIs(t *testing.T) {
initialTimeout := 60 * time.Second initialTimeout := 60 * time.Second
streamTimeout := 60 * time.Second streamTimeout := 60 * time.Second
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute) ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute)
defer cancel() defer cancel()
client, _, cleanup := InitServerConnection(ctx, t) client, endpoint, cleanup := InitServerConnection(ctx, t)
defer cleanup() defer cleanup()
for _, model := range libraryToolsModels { for _, model := range libraryToolsModels {
t.Run(model, func(t *testing.T) { t.Run(model, func(t *testing.T) {
// Skip if insufficient VRAM
if v, ok := minVRAM[model]; ok { if v, ok := minVRAM[model]; ok {
skipUnderMinVRAM(t, v) skipUnderMinVRAM(t, v)
} }
// Pull model if missing - only do this once per model
if err := PullIfMissing(ctx, client, model); err != nil { if err := PullIfMissing(ctx, client, model); err != nil {
t.Fatalf("pull failed %s", err) t.Fatalf("pull failed %s", err)
} }
tools := []api.Tool{ t.Run("OllamaAPI", func(t *testing.T) {
{ tools := []api.Tool{
Type: "function", {
Function: api.ToolFunction{ Type: "function",
Name: "get_weather", Function: api.ToolFunction{
Description: "Get the current weather in a given location", Name: "get_weather",
Parameters: api.ToolFunctionParameters{ Description: "Get the current weather in a given location",
Type: "object", Parameters: api.ToolFunctionParameters{
Required: []string{"location"}, Type: "object",
Properties: map[string]api.ToolProperty{ Required: []string{"location"},
"location": { Properties: map[string]api.ToolProperty{
Type: api.PropertyType{"string"}, "location": {
Description: "The city and state, e.g. San Francisco, CA", Type: api.PropertyType{"string"},
Description: "The city and state, e.g. San Francisco, CA",
},
}, },
}, },
}, },
}, },
},
}
userMessage := api.Message{
Role: "user",
Content: "What's the weather like in San Francisco?",
}
req := api.ChatRequest{
Model: model,
Messages: []api.Message{userMessage},
Tools: tools,
Options: map[string]any{
"temperature": 0,
},
}
stallTimer := time.NewTimer(initialTimeout)
var assistantMessage api.Message
var gotToolCall bool
var toolCallID string
fn := func(response api.ChatResponse) error {
if response.Message.Content != "" {
assistantMessage.Content += response.Message.Content
assistantMessage.Role = "assistant"
}
if len(response.Message.ToolCalls) > 0 {
gotToolCall = true
assistantMessage.ToolCalls = response.Message.ToolCalls
assistantMessage.Role = "assistant"
toolCallID = response.Message.ToolCalls[0].ID
}
if !stallTimer.Reset(streamTimeout) {
return fmt.Errorf("stall was detected while streaming response, aborting")
}
return nil
}
stream := true
req.Stream = &stream
done := make(chan int)
var genErr error
go func() {
genErr = client.Chat(ctx, &req, fn)
done <- 0
}()
select {
case <-stallTimer.C:
t.Fatalf("first turn chat never started. Timed out after: %s", initialTimeout.String())
case <-done:
if genErr != nil {
t.Fatalf("first turn chat failed: %v", genErr)
} }
if !gotToolCall { req := api.ChatRequest{
t.Fatalf("expected at least one tool call in first turn, got none") Model: model,
} Messages: []api.Message{
{
if len(assistantMessage.ToolCalls) == 0 { Role: "user",
t.Fatalf("expected tool calls in assistant message, got none") Content: "Call get_weather with location set to San Francisco.",
} },
},
firstToolCall := assistantMessage.ToolCalls[0] Tools: tools,
if firstToolCall.Function.Name != "get_weather" {
t.Errorf("unexpected tool called: got %q want %q", firstToolCall.Function.Name, "get_weather")
}
location, ok := firstToolCall.Function.Arguments["location"]
if !ok {
t.Fatalf("expected tool arguments to include 'location', got: %s", firstToolCall.Function.Arguments.String())
}
toolResult := `{"temperature": 72, "condition": "sunny", "humidity": 65}`
toolMessage := api.Message{
Role: "tool",
Content: toolResult,
ToolName: "get_weather",
ToolCallID: toolCallID,
}
messages := []api.Message{
userMessage,
assistantMessage,
toolMessage,
}
req2 := api.ChatRequest{
Model: model,
Messages: messages,
Tools: tools,
Options: map[string]any{ Options: map[string]any{
"temperature": 0, "temperature": 0,
}, },
} }
stallTimer2 := time.NewTimer(initialTimeout) stallTimer := time.NewTimer(initialTimeout)
var finalResponse string var gotToolCall bool
var gotSecondToolCall bool var lastToolCall api.ToolCall
fn2 := func(response api.ChatResponse) error { fn := func(response api.ChatResponse) error {
if len(response.Message.ToolCalls) > 0 { if len(response.Message.ToolCalls) > 0 {
gotSecondToolCall = true gotToolCall = true
lastToolCall = response.Message.ToolCalls[len(response.Message.ToolCalls)-1]
} }
if response.Message.Content != "" { if !stallTimer.Reset(streamTimeout) {
finalResponse += response.Message.Content
}
if !stallTimer2.Reset(streamTimeout) {
return fmt.Errorf("stall was detected while streaming response, aborting") return fmt.Errorf("stall was detected while streaming response, aborting")
} }
return nil return nil
} }
req2.Stream = &stream stream := true
done2 := make(chan int) req.Stream = &stream
var genErr2 error done := make(chan int)
var genErr error
go func() { go func() {
genErr2 = client.Chat(ctx, &req2, fn2) genErr = client.Chat(ctx, &req, fn)
done2 <- 0 done <- 0
}() }()
select { select {
case <-stallTimer2.C: case <-stallTimer.C:
t.Fatalf("second turn chat never started. Timed out after: %s", initialTimeout.String()) t.Errorf("tool-calling chat never started. Timed out after: %s", initialTimeout.String())
case <-done2: case <-done:
if genErr2 != nil { if genErr != nil {
t.Fatalf("second turn chat failed: %v", genErr2) t.Fatalf("chat failed: %v", genErr)
} }
if gotSecondToolCall { if !gotToolCall {
t.Errorf("expected no tool calls in second turn, but got tool calls. Model should respond with natural language after receiving tool result.") t.Fatalf("expected at least one tool call, got none")
} }
if finalResponse == "" { if lastToolCall.Function.Name != "get_weather" {
t.Fatalf("expected natural language response in second turn, got empty response") t.Errorf("unexpected tool called: got %q want %q", lastToolCall.Function.Name, "get_weather")
} }
responseLower := strings.ToLower(finalResponse) if _, ok := lastToolCall.Function.Arguments["location"]; !ok {
expectedKeywords := []string{"72", "sunny", "temperature", "weather", "san francisco", "fahrenheit"} t.Errorf("expected tool arguments to include 'location', got: %s", lastToolCall.Function.Arguments.String())
foundKeyword := false
for _, keyword := range expectedKeywords {
if strings.Contains(responseLower, strings.ToLower(keyword)) {
foundKeyword = true
break
}
}
if !foundKeyword {
t.Logf("response: %s", finalResponse)
t.Logf("location from tool call: %v", location)
} }
case <-ctx.Done(): case <-ctx.Done():
t.Error("outer test context done while waiting for second turn") t.Error("outer test context done while waiting for tool-calling chat")
} }
case <-ctx.Done(): })
t.Error("outer test context done while waiting for first turn")
} t.Run("OpenAIAPI", func(t *testing.T) {
tools := []api.Tool{
{
Type: "function",
Function: api.ToolFunction{
Name: "get_weather",
Description: "Get the current weather in a given location",
Parameters: api.ToolFunctionParameters{
Type: "object",
Required: []string{"location"},
Properties: map[string]api.ToolProperty{
"location": {
Type: api.PropertyType{"string"},
Description: "The city and state, e.g. San Francisco, CA",
},
},
},
},
},
}
req := openai.ChatCompletionRequest{
Model: model,
Messages: []openai.Message{
{
Role: "user",
Content: "Call get_weather with location set to San Francisco.",
},
},
Tools: tools,
Stream: true,
Temperature: float64Ptr(0),
}
stallTimer := time.NewTimer(initialTimeout)
var gotToolCall bool
var lastToolCall openai.ToolCall
fn := func(response openai.ChatCompletionChunk) error {
if len(response.Choices) > 0 && len(response.Choices[0].Delta.ToolCalls) > 0 {
gotToolCall = true
toolCalls := response.Choices[0].Delta.ToolCalls
lastToolCall = toolCalls[len(toolCalls)-1]
}
if !stallTimer.Reset(streamTimeout) {
return fmt.Errorf("stall was detected while streaming response, aborting")
}
return nil
}
done := make(chan int)
var genErr error
go func() {
genErr = sendOpenAIChatStreamRequest(ctx, "http://"+endpoint, req, fn)
done <- 0
}()
select {
case <-stallTimer.C:
t.Errorf("tool-calling chat never started. Timed out after: %s", initialTimeout.String())
case <-done:
if genErr != nil {
t.Fatalf("chat failed: %v", genErr)
}
if !gotToolCall {
t.Fatalf("expected at least one tool call, got none")
}
if lastToolCall.Function.Name != "get_weather" {
t.Errorf("unexpected tool called: got %q want %q", lastToolCall.Function.Name, "get_weather")
}
if !strings.Contains(lastToolCall.Function.Arguments, "location") {
t.Errorf("expected tool arguments to include 'location', got: %s", lastToolCall.Function.Arguments)
}
if !strings.Contains(lastToolCall.Function.Arguments, "San Francisco") {
t.Errorf("expected tool arguments to include 'San Francisco', got: %s", lastToolCall.Function.Arguments)
}
case <-ctx.Done():
t.Error("outer test context done while waiting for tool-calling chat")
}
})
}) })
} }
} }