update tests
This commit is contained in:
parent
471cbbe95a
commit
0103a3a89b
|
|
@ -3,13 +3,18 @@
|
|||
package integration
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/openai"
|
||||
)
|
||||
|
||||
var libraryToolsModels = []string{
|
||||
|
|
@ -29,302 +34,322 @@ var libraryToolsModels = []string{
|
|||
"granite3.3",
|
||||
}
|
||||
|
||||
func TestAPIToolCalling(t *testing.T) {
|
||||
initialTimeout := 60 * time.Second
|
||||
streamTimeout := 60 * time.Second
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
client, _, cleanup := InitServerConnection(ctx, t)
|
||||
defer cleanup()
|
||||
|
||||
for _, model := range libraryToolsModels {
|
||||
t.Run(model, func(t *testing.T) {
|
||||
if v, ok := minVRAM[model]; ok {
|
||||
skipUnderMinVRAM(t, v)
|
||||
}
|
||||
|
||||
if err := PullIfMissing(ctx, client, model); err != nil {
|
||||
t.Fatalf("pull failed %s", err)
|
||||
}
|
||||
|
||||
tools := []api.Tool{
|
||||
{
|
||||
Type: "function",
|
||||
Function: api.ToolFunction{
|
||||
Name: "get_weather",
|
||||
Description: "Get the current weather in a given location",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Required: []string{"location"},
|
||||
Properties: map[string]api.ToolProperty{
|
||||
"location": {
|
||||
Type: api.PropertyType{"string"},
|
||||
Description: "The city and state, e.g. San Francisco, CA",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
req := api.ChatRequest{
|
||||
Model: model,
|
||||
Messages: []api.Message{
|
||||
{
|
||||
Role: "user",
|
||||
Content: "Call get_weather with location set to San Francisco.",
|
||||
},
|
||||
},
|
||||
Tools: tools,
|
||||
Options: map[string]any{
|
||||
"temperature": 0,
|
||||
},
|
||||
}
|
||||
|
||||
stallTimer := time.NewTimer(initialTimeout)
|
||||
var gotToolCall bool
|
||||
var lastToolCall api.ToolCall
|
||||
|
||||
fn := func(response api.ChatResponse) error {
|
||||
if len(response.Message.ToolCalls) > 0 {
|
||||
gotToolCall = true
|
||||
lastToolCall = response.Message.ToolCalls[len(response.Message.ToolCalls)-1]
|
||||
}
|
||||
if !stallTimer.Reset(streamTimeout) {
|
||||
return fmt.Errorf("stall was detected while streaming response, aborting")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
stream := true
|
||||
req.Stream = &stream
|
||||
done := make(chan int)
|
||||
var genErr error
|
||||
go func() {
|
||||
genErr = client.Chat(ctx, &req, fn)
|
||||
done <- 0
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-stallTimer.C:
|
||||
t.Errorf("tool-calling chat never started. Timed out after: %s", initialTimeout.String())
|
||||
case <-done:
|
||||
if genErr != nil {
|
||||
t.Fatalf("chat failed: %v", genErr)
|
||||
}
|
||||
|
||||
if !gotToolCall {
|
||||
t.Fatalf("expected at least one tool call, got none")
|
||||
}
|
||||
|
||||
if lastToolCall.Function.Name != "get_weather" {
|
||||
t.Errorf("unexpected tool called: got %q want %q", lastToolCall.Function.Name, "get_weather")
|
||||
}
|
||||
|
||||
if _, ok := lastToolCall.Function.Arguments["location"]; !ok {
|
||||
t.Errorf("expected tool arguments to include 'location', got: %s", lastToolCall.Function.Arguments.String())
|
||||
}
|
||||
case <-ctx.Done():
|
||||
t.Error("outer test context done while waiting for tool-calling chat")
|
||||
}
|
||||
})
|
||||
}
|
||||
func float64Ptr(v float64) *float64 {
|
||||
return &v
|
||||
}
|
||||
|
||||
func TestAPIToolCallingMultiTurn(t *testing.T) {
|
||||
func sendOpenAIChatRequest(ctx context.Context, endpoint string, req openai.ChatCompletionRequest) (*openai.ChatCompletion, error) {
|
||||
jsonData, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal request: %w", err)
|
||||
}
|
||||
|
||||
httpReq, err := http.NewRequestWithContext(ctx, "POST", endpoint+"/v1/chat/completions", bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
|
||||
client := &http.Client{
|
||||
Timeout: 10 * time.Minute,
|
||||
}
|
||||
|
||||
resp, err := client.Do(httpReq)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to send request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read response: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("API error: status=%d, body=%s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var chatResp openai.ChatCompletion
|
||||
if err := json.Unmarshal(body, &chatResp); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal response: %w, body: %s", err, string(body))
|
||||
}
|
||||
|
||||
return &chatResp, nil
|
||||
}
|
||||
|
||||
func sendOpenAIChatStreamRequest(ctx context.Context, endpoint string, req openai.ChatCompletionRequest, fn func(openai.ChatCompletionChunk) error) error {
|
||||
jsonData, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal request: %w", err)
|
||||
}
|
||||
|
||||
httpReq, err := http.NewRequestWithContext(ctx, "POST", endpoint+"/v1/chat/completions", bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
httpReq.Header.Set("Accept", "text/event-stream")
|
||||
|
||||
client := &http.Client{
|
||||
Timeout: 0, // No timeout for streaming
|
||||
}
|
||||
|
||||
resp, err := client.Do(httpReq)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to send request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
return fmt.Errorf("API error: status=%d, body=%s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
decoder := resp.Body
|
||||
reader := bytes.NewBuffer([]byte{})
|
||||
buf := make([]byte, 4096)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
default:
|
||||
n, err := decoder.Read(buf)
|
||||
if n > 0 {
|
||||
reader.Write(buf[:n])
|
||||
|
||||
// Process complete lines
|
||||
for {
|
||||
line, err := reader.ReadString('\n')
|
||||
if err != nil {
|
||||
// Not a complete line yet
|
||||
reader.WriteString(line)
|
||||
break
|
||||
}
|
||||
|
||||
line = strings.TrimSpace(line)
|
||||
if strings.HasPrefix(line, "data: ") {
|
||||
data := strings.TrimPrefix(line, "data: ")
|
||||
|
||||
if data == "[DONE]" {
|
||||
return nil
|
||||
}
|
||||
|
||||
var streamResp openai.ChatCompletionChunk
|
||||
if err := json.Unmarshal([]byte(data), &streamResp); err != nil {
|
||||
return fmt.Errorf("failed to unmarshal stream response: %w", err)
|
||||
}
|
||||
|
||||
if err := fn(streamResp); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
if err != io.EOF {
|
||||
return fmt.Errorf("error reading stream: %w", err)
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// TestToolCallingAllAPIs tests both Ollama and OpenAI APIs with shared model loading
|
||||
func TestToolCallingAllAPIs(t *testing.T) {
|
||||
initialTimeout := 60 * time.Second
|
||||
streamTimeout := 60 * time.Second
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
client, _, cleanup := InitServerConnection(ctx, t)
|
||||
client, endpoint, cleanup := InitServerConnection(ctx, t)
|
||||
defer cleanup()
|
||||
|
||||
for _, model := range libraryToolsModels {
|
||||
t.Run(model, func(t *testing.T) {
|
||||
// Skip if insufficient VRAM
|
||||
if v, ok := minVRAM[model]; ok {
|
||||
skipUnderMinVRAM(t, v)
|
||||
}
|
||||
|
||||
// Pull model if missing - only do this once per model
|
||||
if err := PullIfMissing(ctx, client, model); err != nil {
|
||||
t.Fatalf("pull failed %s", err)
|
||||
}
|
||||
|
||||
tools := []api.Tool{
|
||||
{
|
||||
Type: "function",
|
||||
Function: api.ToolFunction{
|
||||
Name: "get_weather",
|
||||
Description: "Get the current weather in a given location",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Required: []string{"location"},
|
||||
Properties: map[string]api.ToolProperty{
|
||||
"location": {
|
||||
Type: api.PropertyType{"string"},
|
||||
Description: "The city and state, e.g. San Francisco, CA",
|
||||
t.Run("OllamaAPI", func(t *testing.T) {
|
||||
tools := []api.Tool{
|
||||
{
|
||||
Type: "function",
|
||||
Function: api.ToolFunction{
|
||||
Name: "get_weather",
|
||||
Description: "Get the current weather in a given location",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Required: []string{"location"},
|
||||
Properties: map[string]api.ToolProperty{
|
||||
"location": {
|
||||
Type: api.PropertyType{"string"},
|
||||
Description: "The city and state, e.g. San Francisco, CA",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
userMessage := api.Message{
|
||||
Role: "user",
|
||||
Content: "What's the weather like in San Francisco?",
|
||||
}
|
||||
|
||||
req := api.ChatRequest{
|
||||
Model: model,
|
||||
Messages: []api.Message{userMessage},
|
||||
Tools: tools,
|
||||
Options: map[string]any{
|
||||
"temperature": 0,
|
||||
},
|
||||
}
|
||||
|
||||
stallTimer := time.NewTimer(initialTimeout)
|
||||
var assistantMessage api.Message
|
||||
var gotToolCall bool
|
||||
var toolCallID string
|
||||
|
||||
fn := func(response api.ChatResponse) error {
|
||||
if response.Message.Content != "" {
|
||||
assistantMessage.Content += response.Message.Content
|
||||
assistantMessage.Role = "assistant"
|
||||
}
|
||||
if len(response.Message.ToolCalls) > 0 {
|
||||
gotToolCall = true
|
||||
assistantMessage.ToolCalls = response.Message.ToolCalls
|
||||
assistantMessage.Role = "assistant"
|
||||
toolCallID = response.Message.ToolCalls[0].ID
|
||||
}
|
||||
if !stallTimer.Reset(streamTimeout) {
|
||||
return fmt.Errorf("stall was detected while streaming response, aborting")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
stream := true
|
||||
req.Stream = &stream
|
||||
done := make(chan int)
|
||||
var genErr error
|
||||
go func() {
|
||||
genErr = client.Chat(ctx, &req, fn)
|
||||
done <- 0
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-stallTimer.C:
|
||||
t.Fatalf("first turn chat never started. Timed out after: %s", initialTimeout.String())
|
||||
case <-done:
|
||||
if genErr != nil {
|
||||
t.Fatalf("first turn chat failed: %v", genErr)
|
||||
}
|
||||
|
||||
if !gotToolCall {
|
||||
t.Fatalf("expected at least one tool call in first turn, got none")
|
||||
}
|
||||
|
||||
if len(assistantMessage.ToolCalls) == 0 {
|
||||
t.Fatalf("expected tool calls in assistant message, got none")
|
||||
}
|
||||
|
||||
firstToolCall := assistantMessage.ToolCalls[0]
|
||||
if firstToolCall.Function.Name != "get_weather" {
|
||||
t.Errorf("unexpected tool called: got %q want %q", firstToolCall.Function.Name, "get_weather")
|
||||
}
|
||||
|
||||
location, ok := firstToolCall.Function.Arguments["location"]
|
||||
if !ok {
|
||||
t.Fatalf("expected tool arguments to include 'location', got: %s", firstToolCall.Function.Arguments.String())
|
||||
}
|
||||
|
||||
toolResult := `{"temperature": 72, "condition": "sunny", "humidity": 65}`
|
||||
toolMessage := api.Message{
|
||||
Role: "tool",
|
||||
Content: toolResult,
|
||||
ToolName: "get_weather",
|
||||
ToolCallID: toolCallID,
|
||||
}
|
||||
|
||||
messages := []api.Message{
|
||||
userMessage,
|
||||
assistantMessage,
|
||||
toolMessage,
|
||||
}
|
||||
|
||||
req2 := api.ChatRequest{
|
||||
Model: model,
|
||||
Messages: messages,
|
||||
Tools: tools,
|
||||
req := api.ChatRequest{
|
||||
Model: model,
|
||||
Messages: []api.Message{
|
||||
{
|
||||
Role: "user",
|
||||
Content: "Call get_weather with location set to San Francisco.",
|
||||
},
|
||||
},
|
||||
Tools: tools,
|
||||
Options: map[string]any{
|
||||
"temperature": 0,
|
||||
},
|
||||
}
|
||||
|
||||
stallTimer2 := time.NewTimer(initialTimeout)
|
||||
var finalResponse string
|
||||
var gotSecondToolCall bool
|
||||
stallTimer := time.NewTimer(initialTimeout)
|
||||
var gotToolCall bool
|
||||
var lastToolCall api.ToolCall
|
||||
|
||||
fn2 := func(response api.ChatResponse) error {
|
||||
fn := func(response api.ChatResponse) error {
|
||||
if len(response.Message.ToolCalls) > 0 {
|
||||
gotSecondToolCall = true
|
||||
gotToolCall = true
|
||||
lastToolCall = response.Message.ToolCalls[len(response.Message.ToolCalls)-1]
|
||||
}
|
||||
if response.Message.Content != "" {
|
||||
finalResponse += response.Message.Content
|
||||
}
|
||||
if !stallTimer2.Reset(streamTimeout) {
|
||||
if !stallTimer.Reset(streamTimeout) {
|
||||
return fmt.Errorf("stall was detected while streaming response, aborting")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
req2.Stream = &stream
|
||||
done2 := make(chan int)
|
||||
var genErr2 error
|
||||
stream := true
|
||||
req.Stream = &stream
|
||||
done := make(chan int)
|
||||
var genErr error
|
||||
go func() {
|
||||
genErr2 = client.Chat(ctx, &req2, fn2)
|
||||
done2 <- 0
|
||||
genErr = client.Chat(ctx, &req, fn)
|
||||
done <- 0
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-stallTimer2.C:
|
||||
t.Fatalf("second turn chat never started. Timed out after: %s", initialTimeout.String())
|
||||
case <-done2:
|
||||
if genErr2 != nil {
|
||||
t.Fatalf("second turn chat failed: %v", genErr2)
|
||||
case <-stallTimer.C:
|
||||
t.Errorf("tool-calling chat never started. Timed out after: %s", initialTimeout.String())
|
||||
case <-done:
|
||||
if genErr != nil {
|
||||
t.Fatalf("chat failed: %v", genErr)
|
||||
}
|
||||
|
||||
if gotSecondToolCall {
|
||||
t.Errorf("expected no tool calls in second turn, but got tool calls. Model should respond with natural language after receiving tool result.")
|
||||
if !gotToolCall {
|
||||
t.Fatalf("expected at least one tool call, got none")
|
||||
}
|
||||
|
||||
if finalResponse == "" {
|
||||
t.Fatalf("expected natural language response in second turn, got empty response")
|
||||
if lastToolCall.Function.Name != "get_weather" {
|
||||
t.Errorf("unexpected tool called: got %q want %q", lastToolCall.Function.Name, "get_weather")
|
||||
}
|
||||
|
||||
responseLower := strings.ToLower(finalResponse)
|
||||
expectedKeywords := []string{"72", "sunny", "temperature", "weather", "san francisco", "fahrenheit"}
|
||||
foundKeyword := false
|
||||
for _, keyword := range expectedKeywords {
|
||||
if strings.Contains(responseLower, strings.ToLower(keyword)) {
|
||||
foundKeyword = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !foundKeyword {
|
||||
t.Logf("response: %s", finalResponse)
|
||||
t.Logf("location from tool call: %v", location)
|
||||
if _, ok := lastToolCall.Function.Arguments["location"]; !ok {
|
||||
t.Errorf("expected tool arguments to include 'location', got: %s", lastToolCall.Function.Arguments.String())
|
||||
}
|
||||
case <-ctx.Done():
|
||||
t.Error("outer test context done while waiting for second turn")
|
||||
t.Error("outer test context done while waiting for tool-calling chat")
|
||||
}
|
||||
case <-ctx.Done():
|
||||
t.Error("outer test context done while waiting for first turn")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("OpenAIAPI", func(t *testing.T) {
|
||||
tools := []api.Tool{
|
||||
{
|
||||
Type: "function",
|
||||
Function: api.ToolFunction{
|
||||
Name: "get_weather",
|
||||
Description: "Get the current weather in a given location",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Required: []string{"location"},
|
||||
Properties: map[string]api.ToolProperty{
|
||||
"location": {
|
||||
Type: api.PropertyType{"string"},
|
||||
Description: "The city and state, e.g. San Francisco, CA",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
req := openai.ChatCompletionRequest{
|
||||
Model: model,
|
||||
Messages: []openai.Message{
|
||||
{
|
||||
Role: "user",
|
||||
Content: "Call get_weather with location set to San Francisco.",
|
||||
},
|
||||
},
|
||||
Tools: tools,
|
||||
Stream: true,
|
||||
Temperature: float64Ptr(0),
|
||||
}
|
||||
|
||||
stallTimer := time.NewTimer(initialTimeout)
|
||||
var gotToolCall bool
|
||||
var lastToolCall openai.ToolCall
|
||||
|
||||
fn := func(response openai.ChatCompletionChunk) error {
|
||||
if len(response.Choices) > 0 && len(response.Choices[0].Delta.ToolCalls) > 0 {
|
||||
gotToolCall = true
|
||||
toolCalls := response.Choices[0].Delta.ToolCalls
|
||||
lastToolCall = toolCalls[len(toolCalls)-1]
|
||||
}
|
||||
if !stallTimer.Reset(streamTimeout) {
|
||||
return fmt.Errorf("stall was detected while streaming response, aborting")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
done := make(chan int)
|
||||
var genErr error
|
||||
go func() {
|
||||
genErr = sendOpenAIChatStreamRequest(ctx, "http://"+endpoint, req, fn)
|
||||
done <- 0
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-stallTimer.C:
|
||||
t.Errorf("tool-calling chat never started. Timed out after: %s", initialTimeout.String())
|
||||
case <-done:
|
||||
if genErr != nil {
|
||||
t.Fatalf("chat failed: %v", genErr)
|
||||
}
|
||||
|
||||
if !gotToolCall {
|
||||
t.Fatalf("expected at least one tool call, got none")
|
||||
}
|
||||
|
||||
if lastToolCall.Function.Name != "get_weather" {
|
||||
t.Errorf("unexpected tool called: got %q want %q", lastToolCall.Function.Name, "get_weather")
|
||||
}
|
||||
|
||||
if !strings.Contains(lastToolCall.Function.Arguments, "location") {
|
||||
t.Errorf("expected tool arguments to include 'location', got: %s", lastToolCall.Function.Arguments)
|
||||
}
|
||||
|
||||
if !strings.Contains(lastToolCall.Function.Arguments, "San Francisco") {
|
||||
t.Errorf("expected tool arguments to include 'San Francisco', got: %s", lastToolCall.Function.Arguments)
|
||||
}
|
||||
case <-ctx.Done():
|
||||
t.Error("outer test context done while waiting for tool-calling chat")
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue