integration: improve tool calling tests for local and cloud models
This commit is contained in:
parent
12b174b10e
commit
bd834dcbe3
5
go.mod
5
go.mod
|
|
@ -27,6 +27,7 @@ require (
|
|||
github.com/emirpasic/gods/v2 v2.0.0-alpha
|
||||
github.com/mattn/go-runewidth v0.0.14
|
||||
github.com/nlpodyssey/gopickle v0.3.0
|
||||
github.com/openai/openai-go/v3 v3.8.1
|
||||
github.com/pdevine/tensor v0.0.0-20240510204454-f88f4562727c
|
||||
github.com/tkrajina/typescriptify-golang-structs v0.2.0
|
||||
golang.org/x/image v0.22.0
|
||||
|
|
@ -48,6 +49,10 @@ require (
|
|||
github.com/pkg/errors v0.9.1 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/rivo/uniseg v0.2.0 // indirect
|
||||
github.com/tidwall/gjson v1.14.4 // indirect
|
||||
github.com/tidwall/match v1.1.1 // indirect
|
||||
github.com/tidwall/pretty v1.2.1 // indirect
|
||||
github.com/tidwall/sjson v1.2.5 // indirect
|
||||
github.com/tkrajina/go-reflector v0.5.5 // indirect
|
||||
github.com/xtgo/set v1.0.0 // indirect
|
||||
go4.org/unsafe/assume-no-moving-gc v0.0.0-20231121144256-b99613f794b6 // indirect
|
||||
|
|
|
|||
12
go.sum
12
go.sum
|
|
@ -159,6 +159,8 @@ github.com/nlpodyssey/gopickle v0.3.0 h1:BLUE5gxFLyyNOPzlXxt6GoHEMMxD0qhsE4p0CIQ
|
|||
github.com/nlpodyssey/gopickle v0.3.0/go.mod h1:f070HJ/yR+eLi5WmM1OXJEGaTpuJEUiib19olXgYha0=
|
||||
github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec=
|
||||
github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY=
|
||||
github.com/openai/openai-go/v3 v3.8.1 h1:b+YWsmwqXnbpSHWQEntZAkKciBZ5CJXwL68j+l59UDg=
|
||||
github.com/openai/openai-go/v3 v3.8.1/go.mod h1:UOpNxkqC9OdNXNUfpNByKOtB4jAL0EssQXq5p8gO0Xs=
|
||||
github.com/pdevine/tensor v0.0.0-20240510204454-f88f4562727c h1:GwiUUjKefgvSNmv3NCvI/BL0kDebW6Xa+kcdpdc1mTY=
|
||||
github.com/pdevine/tensor v0.0.0-20240510204454-f88f4562727c/go.mod h1:PSojXDXF7TbgQiD6kkd98IHOS0QqTyUEaWRiS8+BLu8=
|
||||
github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM=
|
||||
|
|
@ -199,6 +201,16 @@ github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o
|
|||
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
|
||||
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
|
||||
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||
github.com/tidwall/gjson v1.14.4 h1:uo0p8EbA09J7RQaflQ1aBRffTR7xedD2bcIVSYxLnkM=
|
||||
github.com/tidwall/gjson v1.14.4/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
|
||||
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
|
||||
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
|
||||
github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4=
|
||||
github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
|
||||
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
|
||||
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
|
||||
github.com/tkrajina/go-reflector v0.5.5 h1:gwoQFNye30Kk7NrExj8zm3zFtrGPqOkzFMLuQZg1DtQ=
|
||||
github.com/tkrajina/go-reflector v0.5.5/go.mod h1:ECbqLgccecY5kPmPmXg1MrHW585yMcDkVl6IvJe64T4=
|
||||
github.com/tkrajina/typescriptify-golang-structs v0.2.0 h1:ZedWk82egydDspGTryAatbX0/1NZDQbdiZLoCbOk4f8=
|
||||
|
|
|
|||
|
|
@ -14,4 +14,59 @@ The integration tests have 2 modes of operating.
|
|||
> Before running the tests locally without the "test existing" setting, compile ollama from the top of the source tree `go build .` in addition to GPU support with cmake if applicable on your platform. The integration tests expect to find an ollama binary at the top of the tree.
|
||||
|
||||
|
||||
Many tests use a default small model suitable to run on many systems. You can override this default model by setting `OLLAMA_TEST_DEFAULT_MODEL`
|
||||
|
||||
## Tool Calling Tests
|
||||
|
||||
The tool calling tests are split into two files:
|
||||
|
||||
- **`tools_test.go`** - Tests using the native Ollama API (`api.Tool`)
|
||||
- **`tools_openai_test.go`** - Tests using the OpenAI-compatible API format
|
||||
|
||||
### Running Tool Calling Tests
|
||||
|
||||
Run all tool calling tests:
|
||||
```bash
|
||||
go test -tags=integration -v -run Test.*Tool.* ./integration
|
||||
```
|
||||
|
||||
Run only OpenAI-compatible tests:
|
||||
```bash
|
||||
go test -tags=integration -v -run TestOpenAI ./integration
|
||||
```
|
||||
|
||||
Run only native API tests:
|
||||
```bash
|
||||
go test -tags=integration -v -run TestAPIToolCalling ./integration
|
||||
```
|
||||
|
||||
### Parallel Execution
|
||||
|
||||
The OpenAI-compatible tests (`tools_openai_test.go`) support parallel execution for cloud models. Run with parallel execution:
|
||||
```bash
|
||||
go test -tags=integration -v -run TestOpenAI -parallel 3 ./integration
|
||||
```
|
||||
|
||||
Cloud models (models ending with `-cloud`) will run in parallel, while local models run sequentially. This significantly speeds up test execution when testing against external endpoints.
|
||||
|
||||
### Testing Specific Models
|
||||
|
||||
To test a specific model, set the `OPENAI_TEST_MODELS` environment variable:
|
||||
```bash
|
||||
OPENAI_TEST_MODELS="gpt-oss:120b-cloud" go test -tags=integration -v -run TestOpenAI ./integration
|
||||
```
|
||||
|
||||
### External Endpoints
|
||||
|
||||
To test against an external OpenAI-compatible endpoint (e.g., Ollama Cloud):
|
||||
```bash
|
||||
OPENAI_BASE_URL="https://ollama.com/v1" OLLAMA_API_KEY="your-key" go test -tags=integration -v -run TestOpenAI ./integration
|
||||
```
|
||||
|
||||
### Environment Variables
|
||||
|
||||
The tool calling tests support the following environment variables:
|
||||
|
||||
- **`OPENAI_BASE_URL`** - When set, tests will run against an external OpenAI-compatible endpoint instead of a local server. If set, `OLLAMA_API_KEY` must also be provided.
|
||||
- **`OLLAMA_API_KEY`** - API key for authenticating with external endpoints (required when `OPENAI_BASE_URL` is set).
|
||||
- **`OPENAI_TEST_MODELS`** - Override the default model list and test only the specified model(s). Can be a single model or comma-separated list.
|
||||
|
|
@ -0,0 +1,834 @@
|
|||
//go:build integration
|
||||
|
||||
package integration
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"slices"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/openai/openai-go/v3"
|
||||
"github.com/openai/openai-go/v3/option"
|
||||
"github.com/openai/openai-go/v3/shared"
|
||||
)
|
||||
|
||||
var agenticModels = []string{
|
||||
"gpt-oss:20b",
|
||||
"gpt-oss:120b",
|
||||
"qwen3-coder:30b",
|
||||
"qwen3:4b",
|
||||
"qwen3:8b",
|
||||
}
|
||||
|
||||
var cloudModels = []string{
|
||||
"gpt-oss:120b-cloud",
|
||||
"gpt-oss:20b-cloud",
|
||||
"qwen3-vl:235b-cloud",
|
||||
"qwen3-coder:480b-cloud",
|
||||
"kimi-k2-thinking:cloud",
|
||||
"kimi-k2:1t-cloud",
|
||||
}
|
||||
|
||||
// validateBashCommand validates a bash command with flexible matching.
// The first whitespace-separated field must equal expectedCmd exactly,
// and every entry of requiredArgs must appear somewhere in the joined
// remainder of the command (order-independent substring match).
func validateBashCommand(cmd string, expectedCmd string, requiredArgs []string) error {
	fields := strings.Fields(cmd)
	if len(fields) == 0 {
		return fmt.Errorf("empty command")
	}

	if got := fields[0]; got != expectedCmd {
		return fmt.Errorf("expected command '%s', got '%s'", expectedCmd, got)
	}

	// Join the arguments back together so required values can match
	// regardless of which argument position they landed in.
	rest := strings.Join(fields[1:], " ")
	for _, want := range requiredArgs {
		if !strings.Contains(rest, want) {
			return fmt.Errorf("missing required argument: %s", want)
		}
	}

	return nil
}
|
||||
|
||||
// validateBashCommandFlexible validates a bash command with flexible matching.
// It accepts alternative command forms (e.g. find vs ls): the first
// whitespace-separated field must be one of allowedCommands, and every
// entry of requiredPatterns must appear (case-insensitively) in the
// joined remainder of the command.
func validateBashCommandFlexible(cmd string, allowedCommands []string, requiredPatterns []string) error {
	parts := strings.Fields(cmd)
	if len(parts) == 0 {
		return fmt.Errorf("empty command")
	}

	// Idiomatic membership test instead of a hand-rolled flag loop.
	actualCmd := parts[0]
	if !slices.Contains(allowedCommands, actualCmd) {
		return fmt.Errorf("expected one of commands %v, got '%s'", allowedCommands, actualCmd)
	}

	// Lowercase both sides so patterns match regardless of case.
	cmdStr := strings.ToLower(strings.Join(parts[1:], " "))
	for _, pattern := range requiredPatterns {
		if !strings.Contains(cmdStr, strings.ToLower(pattern)) {
			return fmt.Errorf("missing required pattern: %s", pattern)
		}
	}

	return nil
}
|
||||
|
||||
func TestOpenAIToolCallingMultiStep(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
var baseURL string
|
||||
var apiKey string
|
||||
var modelsToTest []string
|
||||
var cleanup func()
|
||||
|
||||
if openaiBaseURL := os.Getenv("OPENAI_BASE_URL"); openaiBaseURL != "" {
|
||||
baseURL = openaiBaseURL
|
||||
apiKey = os.Getenv("OLLAMA_API_KEY")
|
||||
if apiKey == "" {
|
||||
t.Fatal("OPENAI_API_KEY must be set when using OPENAI_BASE_URL")
|
||||
}
|
||||
|
||||
// only test cloud models unless OPENAI_TEST_MODELS is set
|
||||
modelsToTest = cloudModels
|
||||
if modelsEnv := os.Getenv("OPENAI_TEST_MODELS"); modelsEnv != "" {
|
||||
modelsToTest = []string{modelsEnv}
|
||||
}
|
||||
cleanup = func() {}
|
||||
} else {
|
||||
_, testEndpoint, cleanupFn := InitServerConnection(ctx, t)
|
||||
cleanup = cleanupFn
|
||||
baseURL = fmt.Sprintf("http://%s/v1", testEndpoint)
|
||||
apiKey = "ollama"
|
||||
modelsToTest = append(agenticModels, cloudModels...)
|
||||
}
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
opts := []option.RequestOption{
|
||||
option.WithBaseURL(baseURL),
|
||||
option.WithAPIKey(apiKey),
|
||||
}
|
||||
openaiClient := openai.NewClient(opts...)
|
||||
|
||||
var ollamaClient *api.Client
|
||||
if baseURL == "" {
|
||||
ollamaClient, _, _ = InitServerConnection(ctx, t)
|
||||
}
|
||||
|
||||
for _, model := range modelsToTest {
|
||||
t.Run(model, func(t *testing.T) {
|
||||
testCtx := ctx
|
||||
if slices.Contains(cloudModels, model) {
|
||||
t.Parallel()
|
||||
// Create a new context for parallel tests to avoid cancellation
|
||||
var cancel context.CancelFunc
|
||||
testCtx, cancel = context.WithTimeout(context.Background(), 10*time.Minute)
|
||||
defer cancel()
|
||||
}
|
||||
if v, ok := minVRAM[model]; ok {
|
||||
skipUnderMinVRAM(t, v)
|
||||
}
|
||||
|
||||
if ollamaClient != nil {
|
||||
if err := PullIfMissing(testCtx, ollamaClient, model); err != nil {
|
||||
t.Fatalf("pull failed %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
tools := []openai.ChatCompletionToolUnionParam{
|
||||
openai.ChatCompletionFunctionTool(shared.FunctionDefinitionParam{
|
||||
Name: "list_files",
|
||||
Description: openai.Opt("List all files in a directory"),
|
||||
Parameters: shared.FunctionParameters{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"path": map[string]any{
|
||||
"type": "string",
|
||||
"description": "The directory path to list files from",
|
||||
},
|
||||
},
|
||||
"required": []string{"path"},
|
||||
},
|
||||
}),
|
||||
openai.ChatCompletionFunctionTool(shared.FunctionDefinitionParam{
|
||||
Name: "read_file",
|
||||
Description: openai.Opt("Read the contents of a file"),
|
||||
Parameters: shared.FunctionParameters{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"path": map[string]any{
|
||||
"type": "string",
|
||||
"description": "The file path to read",
|
||||
},
|
||||
},
|
||||
"required": []string{"path"},
|
||||
},
|
||||
}),
|
||||
}
|
||||
|
||||
mockFileContents := "line 1\nline 2\nline 3\nline 4\nline 5"
|
||||
userContent := "Find the file named 'config.json' in /tmp and read its contents"
|
||||
userMessage := openai.UserMessage(userContent)
|
||||
|
||||
messages := []openai.ChatCompletionMessageParamUnion{
|
||||
userMessage,
|
||||
}
|
||||
stepCount := 0
|
||||
maxSteps := 10
|
||||
|
||||
normalizePath := func(path string) string {
|
||||
if path != "" && path[0] != '/' {
|
||||
return "/" + path
|
||||
}
|
||||
return path
|
||||
}
|
||||
|
||||
expectedSteps := []struct {
|
||||
functionName string
|
||||
validateArgs func(map[string]any) error
|
||||
result string
|
||||
}{
|
||||
{
|
||||
functionName: "list_files",
|
||||
validateArgs: func(args map[string]any) error {
|
||||
path, ok := args["path"]
|
||||
if !ok {
|
||||
return fmt.Errorf("missing required argument 'path'")
|
||||
}
|
||||
pathStr, ok := path.(string)
|
||||
if !ok {
|
||||
return fmt.Errorf("expected 'path' to be string, got %T", path)
|
||||
}
|
||||
normalizedPath := normalizePath(pathStr)
|
||||
if normalizedPath != "/tmp" {
|
||||
return fmt.Errorf("expected list_files(\"/tmp\"), got list_files(%q)", pathStr)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
result: `["config.json", "other.txt", "data.log"]`,
|
||||
},
|
||||
{
|
||||
functionName: "read_file",
|
||||
validateArgs: func(args map[string]any) error {
|
||||
path, ok := args["path"]
|
||||
if !ok {
|
||||
return fmt.Errorf("missing required argument 'path'")
|
||||
}
|
||||
pathStr, ok := path.(string)
|
||||
if !ok {
|
||||
return fmt.Errorf("expected 'path' to be string, got %T", path)
|
||||
}
|
||||
normalizedPath := normalizePath(pathStr)
|
||||
if normalizedPath != "/tmp/config.json" {
|
||||
return fmt.Errorf("expected read_file(\"/tmp/config.json\"), got read_file(%q)", pathStr)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
result: mockFileContents,
|
||||
},
|
||||
}
|
||||
|
||||
for stepCount < maxSteps {
|
||||
req := openai.ChatCompletionNewParams{
|
||||
Model: shared.ChatModel(model),
|
||||
Messages: messages,
|
||||
Tools: tools,
|
||||
Temperature: openai.Opt(0.0),
|
||||
}
|
||||
|
||||
completion, err := openaiClient.Chat.Completions.New(testCtx, req)
|
||||
if err != nil {
|
||||
t.Fatalf("step %d chat failed: %v", stepCount+1, err)
|
||||
}
|
||||
|
||||
if len(completion.Choices) == 0 {
|
||||
t.Fatalf("step %d: no choices in response", stepCount+1)
|
||||
}
|
||||
|
||||
choice := completion.Choices[0]
|
||||
message := choice.Message
|
||||
|
||||
toolCalls := message.ToolCalls
|
||||
content := message.Content
|
||||
gotToolCall := len(toolCalls) > 0
|
||||
var toolCallID string
|
||||
if gotToolCall && toolCalls[0].ID != "" {
|
||||
toolCallID = toolCalls[0].ID
|
||||
}
|
||||
|
||||
var assistantMessage openai.ChatCompletionMessageParamUnion
|
||||
if gotToolCall {
|
||||
toolCallsJSON, err := json.Marshal(toolCalls)
|
||||
if err != nil {
|
||||
t.Fatalf("step %d: failed to marshal tool calls: %v", stepCount+1, err)
|
||||
}
|
||||
var toolCallParams []openai.ChatCompletionMessageToolCallUnionParam
|
||||
if err := json.Unmarshal(toolCallsJSON, &toolCallParams); err != nil {
|
||||
t.Fatalf("step %d: failed to unmarshal tool calls: %v", stepCount+1, err)
|
||||
}
|
||||
contentUnion := openai.ChatCompletionAssistantMessageParamContentUnion{
|
||||
OfString: openai.Opt(content),
|
||||
}
|
||||
assistantMsg := openai.ChatCompletionAssistantMessageParam{
|
||||
Content: contentUnion,
|
||||
ToolCalls: toolCallParams,
|
||||
}
|
||||
assistantMessage = openai.ChatCompletionMessageParamUnion{
|
||||
OfAssistant: &assistantMsg,
|
||||
}
|
||||
} else {
|
||||
assistantMessage = openai.AssistantMessage(content)
|
||||
}
|
||||
|
||||
if !gotToolCall && content != "" {
|
||||
if stepCount < len(expectedSteps) {
|
||||
t.Logf("EXPECTED: Step %d should call '%s'", stepCount+1, expectedSteps[stepCount].functionName)
|
||||
t.Logf("ACTUAL: Model stopped with content: %s", content)
|
||||
t.Fatalf("model stopped making tool calls after %d steps, expected %d steps. Final response: %s", stepCount, len(expectedSteps), content)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if !gotToolCall || len(toolCalls) == 0 {
|
||||
if stepCount < len(expectedSteps) {
|
||||
expectedStep := expectedSteps[stepCount]
|
||||
t.Logf("EXPECTED: Step %d should call '%s'", stepCount+1, expectedStep.functionName)
|
||||
t.Logf("ACTUAL: No tool call, got content: %s", content)
|
||||
t.Fatalf("step %d: expected tool call but got none. Response: %s", stepCount+1, content)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if stepCount >= len(expectedSteps) {
|
||||
actualCallJSON, _ := json.MarshalIndent(toolCalls[0], "", " ")
|
||||
t.Logf("EXPECTED: All %d steps completed", len(expectedSteps))
|
||||
t.Logf("ACTUAL: Extra step %d with tool call:\n%s", stepCount+1, string(actualCallJSON))
|
||||
funcName := "unknown"
|
||||
if toolCalls[0].Function.Name != "" {
|
||||
funcName = toolCalls[0].Function.Name
|
||||
}
|
||||
t.Fatalf("model made more tool calls than expected. Expected %d steps, got step %d with tool call: %s", len(expectedSteps), stepCount+1, funcName)
|
||||
}
|
||||
|
||||
expectedStep := expectedSteps[stepCount]
|
||||
firstToolCall := toolCalls[0]
|
||||
funcCall := firstToolCall.Function
|
||||
if funcCall.Name == "" {
|
||||
t.Fatalf("step %d: tool call missing function name", stepCount+1)
|
||||
}
|
||||
|
||||
funcName := funcCall.Name
|
||||
|
||||
var args map[string]any
|
||||
if funcCall.Arguments != "" {
|
||||
if err := json.Unmarshal([]byte(funcCall.Arguments), &args); err != nil {
|
||||
t.Fatalf("step %d: failed to parse tool call arguments: %v", stepCount+1, err)
|
||||
}
|
||||
}
|
||||
|
||||
if funcName != expectedStep.functionName {
|
||||
t.Logf("DIFF: Function name mismatch")
|
||||
t.Logf(" Expected: %s", expectedStep.functionName)
|
||||
t.Logf(" Got: %s", funcName)
|
||||
t.Logf(" Arguments: %v", args)
|
||||
t.Fatalf("step %d: expected tool call '%s', got '%s'. Arguments: %v", stepCount+1, expectedStep.functionName, funcName, args)
|
||||
}
|
||||
|
||||
if err := expectedStep.validateArgs(args); err != nil {
|
||||
expectedArgsForDisplay := map[string]any{}
|
||||
if expectedStep.functionName == "list_files" {
|
||||
expectedArgsForDisplay = map[string]any{"path": "/tmp"}
|
||||
} else if expectedStep.functionName == "read_file" {
|
||||
expectedArgsForDisplay = map[string]any{"path": "/tmp/config.json"}
|
||||
}
|
||||
if diff := cmp.Diff(expectedArgsForDisplay, args); diff != "" {
|
||||
t.Logf("DIFF: Arguments mismatch for function '%s' (-want +got):\n%s", expectedStep.functionName, diff)
|
||||
}
|
||||
t.Logf("Error: %v", err)
|
||||
t.Fatalf("step %d: tool call '%s' has invalid arguments: %v. Arguments: %v", stepCount+1, expectedStep.functionName, err, args)
|
||||
}
|
||||
|
||||
toolMessage := openai.ToolMessage(expectedStep.result, toolCallID)
|
||||
messages = append(messages, assistantMessage, toolMessage)
|
||||
stepCount++
|
||||
}
|
||||
|
||||
if stepCount < len(expectedSteps) {
|
||||
t.Fatalf("test exceeded max steps (%d) before completing all expected steps (%d)", maxSteps, len(expectedSteps))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAIToolCallingBash(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
var baseURL string
|
||||
var apiKey string
|
||||
var modelsToTest []string
|
||||
var cleanup func()
|
||||
|
||||
if openaiBaseURL := os.Getenv("OPENAI_BASE_URL"); openaiBaseURL != "" {
|
||||
baseURL = openaiBaseURL
|
||||
apiKey = os.Getenv("OLLAMA_API_KEY")
|
||||
if apiKey == "" {
|
||||
t.Fatal("OPENAI_API_KEY must be set when using OPENAI_BASE_URL")
|
||||
}
|
||||
modelsToTest = cloudModels
|
||||
if modelsEnv := os.Getenv("OPENAI_TEST_MODELS"); modelsEnv != "" {
|
||||
modelsToTest = []string{modelsEnv}
|
||||
}
|
||||
cleanup = func() {}
|
||||
} else {
|
||||
_, testEndpoint, cleanupFn := InitServerConnection(ctx, t)
|
||||
cleanup = cleanupFn
|
||||
baseURL = fmt.Sprintf("http://%s/v1", testEndpoint)
|
||||
apiKey = "ollama"
|
||||
modelsToTest = append(agenticModels, cloudModels...)
|
||||
}
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
opts := []option.RequestOption{
|
||||
option.WithBaseURL(baseURL),
|
||||
option.WithAPIKey(apiKey),
|
||||
}
|
||||
openaiClient := openai.NewClient(opts...)
|
||||
|
||||
var ollamaClient *api.Client
|
||||
if baseURL == "" {
|
||||
ollamaClient, _, _ = InitServerConnection(ctx, t)
|
||||
}
|
||||
|
||||
for _, model := range modelsToTest {
|
||||
t.Run(model, func(t *testing.T) {
|
||||
testCtx := ctx
|
||||
if slices.Contains(cloudModels, model) {
|
||||
t.Parallel()
|
||||
// Create a new context for parallel tests to avoid cancellation
|
||||
var cancel context.CancelFunc
|
||||
testCtx, cancel = context.WithTimeout(context.Background(), 10*time.Minute)
|
||||
defer cancel()
|
||||
}
|
||||
if v, ok := minVRAM[model]; ok {
|
||||
skipUnderMinVRAM(t, v)
|
||||
}
|
||||
|
||||
if ollamaClient != nil {
|
||||
if err := PullIfMissing(testCtx, ollamaClient, model); err != nil {
|
||||
t.Fatalf("pull failed %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
tools := []openai.ChatCompletionToolUnionParam{
|
||||
openai.ChatCompletionFunctionTool(shared.FunctionDefinitionParam{
|
||||
Name: "execute_bash",
|
||||
Description: openai.Opt("Execute a bash/shell command and return stdout, stderr, and exit code"),
|
||||
Parameters: shared.FunctionParameters{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"command": map[string]any{
|
||||
"type": "string",
|
||||
"description": "The bash command to execute",
|
||||
},
|
||||
"working_directory": map[string]any{
|
||||
"type": "string",
|
||||
"description": "Optional working directory for command execution",
|
||||
},
|
||||
},
|
||||
"required": []string{"command"},
|
||||
},
|
||||
}),
|
||||
}
|
||||
|
||||
userContent := "List all files in /tmp directory"
|
||||
userMessage := openai.UserMessage(userContent)
|
||||
|
||||
req := openai.ChatCompletionNewParams{
|
||||
Model: shared.ChatModel(model),
|
||||
Messages: []openai.ChatCompletionMessageParamUnion{userMessage},
|
||||
Tools: tools,
|
||||
Temperature: openai.Opt(0.0),
|
||||
}
|
||||
|
||||
completion, err := openaiClient.Chat.Completions.New(testCtx, req)
|
||||
if err != nil {
|
||||
t.Fatalf("chat failed: %v", err)
|
||||
}
|
||||
|
||||
if len(completion.Choices) == 0 {
|
||||
t.Fatalf("no choices in response")
|
||||
}
|
||||
|
||||
choice := completion.Choices[0]
|
||||
message := choice.Message
|
||||
|
||||
if len(message.ToolCalls) == 0 {
|
||||
finishReason := choice.FinishReason
|
||||
if finishReason == "" {
|
||||
finishReason = "unknown"
|
||||
}
|
||||
content := message.Content
|
||||
if content == "" {
|
||||
content = "(empty)"
|
||||
}
|
||||
t.Logf("User prompt: %q", userContent)
|
||||
t.Logf("Finish reason: %s", finishReason)
|
||||
t.Logf("Message content: %q", content)
|
||||
t.Logf("Tool calls count: %d", len(message.ToolCalls))
|
||||
if messageJSON, err := json.MarshalIndent(message, "", " "); err == nil {
|
||||
t.Logf("Full message: %s", string(messageJSON))
|
||||
}
|
||||
t.Fatalf("expected at least one tool call, got none. Finish reason: %s, Content: %q", finishReason, content)
|
||||
}
|
||||
|
||||
firstToolCall := message.ToolCalls[0]
|
||||
if firstToolCall.Function.Name != "execute_bash" {
|
||||
t.Fatalf("unexpected tool called: got %q want %q", firstToolCall.Function.Name, "execute_bash")
|
||||
}
|
||||
|
||||
var args map[string]any
|
||||
if firstToolCall.Function.Arguments != "" {
|
||||
if err := json.Unmarshal([]byte(firstToolCall.Function.Arguments), &args); err != nil {
|
||||
t.Fatalf("failed to parse tool call arguments: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
command, ok := args["command"]
|
||||
if !ok {
|
||||
t.Fatalf("expected tool arguments to include 'command', got: %v", args)
|
||||
}
|
||||
|
||||
cmdStr, ok := command.(string)
|
||||
if !ok {
|
||||
t.Fatalf("expected command to be string, got %T", command)
|
||||
}
|
||||
|
||||
if err := validateBashCommand(cmdStr, "ls", []string{"/tmp"}); err != nil {
|
||||
t.Errorf("bash command validation failed: %v. Command: %q", err, cmdStr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAIToolCallingBashMultiStep(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
var baseURL string
|
||||
var apiKey string
|
||||
var modelsToTest []string
|
||||
var cleanup func()
|
||||
|
||||
if openaiBaseURL := os.Getenv("OPENAI_BASE_URL"); openaiBaseURL != "" {
|
||||
baseURL = openaiBaseURL
|
||||
apiKey = os.Getenv("OLLAMA_API_KEY")
|
||||
if apiKey == "" {
|
||||
t.Fatal("OPENAI_API_KEY must be set when using OPENAI_BASE_URL")
|
||||
}
|
||||
modelsToTest = cloudModels
|
||||
if modelsEnv := os.Getenv("OPENAI_TEST_MODELS"); modelsEnv != "" {
|
||||
modelsToTest = []string{modelsEnv}
|
||||
}
|
||||
cleanup = func() {}
|
||||
} else {
|
||||
_, testEndpoint, cleanupFn := InitServerConnection(ctx, t)
|
||||
cleanup = cleanupFn
|
||||
baseURL = fmt.Sprintf("http://%s/v1", testEndpoint)
|
||||
apiKey = "ollama"
|
||||
modelsToTest = append(agenticModels, cloudModels...)
|
||||
}
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
opts := []option.RequestOption{
|
||||
option.WithBaseURL(baseURL),
|
||||
option.WithAPIKey(apiKey),
|
||||
}
|
||||
openaiClient := openai.NewClient(opts...)
|
||||
|
||||
var ollamaClient *api.Client
|
||||
if baseURL == "" {
|
||||
ollamaClient, _, _ = InitServerConnection(ctx, t)
|
||||
}
|
||||
|
||||
for _, model := range modelsToTest {
|
||||
t.Run(model, func(t *testing.T) {
|
||||
testCtx := ctx
|
||||
if slices.Contains(cloudModels, model) {
|
||||
t.Parallel()
|
||||
// Create a new context for parallel tests to avoid cancellation
|
||||
var cancel context.CancelFunc
|
||||
testCtx, cancel = context.WithTimeout(context.Background(), 10*time.Minute)
|
||||
defer cancel()
|
||||
}
|
||||
if v, ok := minVRAM[model]; ok {
|
||||
skipUnderMinVRAM(t, v)
|
||||
}
|
||||
|
||||
if ollamaClient != nil {
|
||||
if err := PullIfMissing(testCtx, ollamaClient, model); err != nil {
|
||||
t.Fatalf("pull failed %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
tools := []openai.ChatCompletionToolUnionParam{
|
||||
openai.ChatCompletionFunctionTool(shared.FunctionDefinitionParam{
|
||||
Name: "execute_bash",
|
||||
Description: openai.Opt("Execute a bash/shell command and return stdout, stderr, and exit code"),
|
||||
Parameters: shared.FunctionParameters{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"command": map[string]any{
|
||||
"type": "string",
|
||||
"description": "The bash command to execute",
|
||||
},
|
||||
"working_directory": map[string]any{
|
||||
"type": "string",
|
||||
"description": "Optional working directory for command execution",
|
||||
},
|
||||
},
|
||||
"required": []string{"command"},
|
||||
},
|
||||
}),
|
||||
}
|
||||
|
||||
userContent := "Find all log files in /tmp. use the bash tool"
|
||||
userMessage := openai.UserMessage(userContent)
|
||||
|
||||
req := openai.ChatCompletionNewParams{
|
||||
Model: shared.ChatModel(model),
|
||||
Messages: []openai.ChatCompletionMessageParamUnion{userMessage},
|
||||
Tools: tools,
|
||||
Temperature: openai.Opt(0.0),
|
||||
}
|
||||
|
||||
completion, err := openaiClient.Chat.Completions.New(testCtx, req)
|
||||
if err != nil {
|
||||
t.Fatalf("chat failed: %v", err)
|
||||
}
|
||||
|
||||
if len(completion.Choices) == 0 {
|
||||
t.Fatalf("no choices in response")
|
||||
}
|
||||
|
||||
choice := completion.Choices[0]
|
||||
message := choice.Message
|
||||
|
||||
if len(message.ToolCalls) == 0 {
|
||||
finishReason := choice.FinishReason
|
||||
if finishReason == "" {
|
||||
finishReason = "unknown"
|
||||
}
|
||||
content := message.Content
|
||||
if content == "" {
|
||||
content = "(empty)"
|
||||
}
|
||||
t.Logf("User prompt: %q", userContent)
|
||||
t.Logf("Finish reason: %s", finishReason)
|
||||
t.Logf("Message content: %q", content)
|
||||
t.Logf("Tool calls count: %d", len(message.ToolCalls))
|
||||
if messageJSON, err := json.MarshalIndent(message, "", " "); err == nil {
|
||||
t.Logf("Full message: %s", string(messageJSON))
|
||||
}
|
||||
t.Fatalf("expected at least one tool call, got none. Finish reason: %s, Content: %q", finishReason, content)
|
||||
}
|
||||
|
||||
firstToolCall := message.ToolCalls[0]
|
||||
if firstToolCall.Function.Name != "execute_bash" {
|
||||
t.Fatalf("unexpected tool called: got %q want %q", firstToolCall.Function.Name, "execute_bash")
|
||||
}
|
||||
|
||||
var args map[string]any
|
||||
if firstToolCall.Function.Arguments != "" {
|
||||
if err := json.Unmarshal([]byte(firstToolCall.Function.Arguments), &args); err != nil {
|
||||
t.Fatalf("failed to parse tool call arguments: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
command, ok := args["command"]
|
||||
if !ok {
|
||||
t.Fatalf("expected tool arguments to include 'command', got: %v", args)
|
||||
}
|
||||
|
||||
cmdStr, ok := command.(string)
|
||||
if !ok {
|
||||
t.Fatalf("expected command to be string, got %T", command)
|
||||
}
|
||||
|
||||
if err := validateBashCommandFlexible(cmdStr, []string{"find", "ls"}, []string{"/tmp"}); err != nil {
|
||||
t.Errorf("bash command validation failed: %v. Command: %q", err, cmdStr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAIToolCallingBashAmpersand(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
var baseURL string
|
||||
var apiKey string
|
||||
var modelsToTest []string
|
||||
var cleanup func()
|
||||
|
||||
if openaiBaseURL := os.Getenv("OPENAI_BASE_URL"); openaiBaseURL != "" {
|
||||
baseURL = openaiBaseURL
|
||||
apiKey = os.Getenv("OLLAMA_API_KEY")
|
||||
if apiKey == "" {
|
||||
t.Fatal("OPENAI_API_KEY must be set when using OPENAI_BASE_URL")
|
||||
}
|
||||
modelsToTest = cloudModels
|
||||
if modelsEnv := os.Getenv("OPENAI_TEST_MODELS"); modelsEnv != "" {
|
||||
modelsToTest = []string{modelsEnv}
|
||||
}
|
||||
cleanup = func() {}
|
||||
} else {
|
||||
_, testEndpoint, cleanupFn := InitServerConnection(ctx, t)
|
||||
cleanup = cleanupFn
|
||||
baseURL = fmt.Sprintf("http://%s/v1", testEndpoint)
|
||||
apiKey = "ollama"
|
||||
modelsToTest = append(agenticModels, cloudModels...)
|
||||
}
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
opts := []option.RequestOption{
|
||||
option.WithBaseURL(baseURL),
|
||||
option.WithAPIKey(apiKey),
|
||||
}
|
||||
openaiClient := openai.NewClient(opts...)
|
||||
|
||||
var ollamaClient *api.Client
|
||||
if baseURL == "" {
|
||||
ollamaClient, _, _ = InitServerConnection(ctx, t)
|
||||
}
|
||||
|
||||
for _, model := range modelsToTest {
|
||||
t.Run(model, func(t *testing.T) {
|
||||
testCtx := ctx
|
||||
if slices.Contains(cloudModels, model) {
|
||||
t.Parallel()
|
||||
// Create a new context for parallel tests to avoid cancellation
|
||||
var cancel context.CancelFunc
|
||||
testCtx, cancel = context.WithTimeout(context.Background(), 10*time.Minute)
|
||||
defer cancel()
|
||||
}
|
||||
if v, ok := minVRAM[model]; ok {
|
||||
skipUnderMinVRAM(t, v)
|
||||
}
|
||||
|
||||
if ollamaClient != nil {
|
||||
if err := PullIfMissing(testCtx, ollamaClient, model); err != nil {
|
||||
t.Fatalf("pull failed %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
tools := []openai.ChatCompletionToolUnionParam{
|
||||
openai.ChatCompletionFunctionTool(shared.FunctionDefinitionParam{
|
||||
Name: "execute_bash",
|
||||
Description: openai.Opt("Execute a bash/shell command and return stdout, stderr, and exit code"),
|
||||
Parameters: shared.FunctionParameters{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"command": map[string]any{
|
||||
"type": "string",
|
||||
"description": "The bash command to execute",
|
||||
},
|
||||
"working_directory": map[string]any{
|
||||
"type": "string",
|
||||
"description": "Optional working directory for command execution",
|
||||
},
|
||||
},
|
||||
"required": []string{"command"},
|
||||
},
|
||||
}),
|
||||
}
|
||||
|
||||
userContent := "Echo the text 'A & B' using bash with the bash tool"
|
||||
userMessage := openai.UserMessage(userContent)
|
||||
|
||||
req := openai.ChatCompletionNewParams{
|
||||
Model: shared.ChatModel(model),
|
||||
Messages: []openai.ChatCompletionMessageParamUnion{userMessage},
|
||||
Tools: tools,
|
||||
Temperature: openai.Opt(0.0),
|
||||
}
|
||||
|
||||
completion, err := openaiClient.Chat.Completions.New(testCtx, req)
|
||||
if err != nil {
|
||||
t.Fatalf("chat failed: %v", err)
|
||||
}
|
||||
|
||||
if len(completion.Choices) == 0 {
|
||||
t.Fatalf("no choices in response")
|
||||
}
|
||||
|
||||
choice := completion.Choices[0]
|
||||
message := choice.Message
|
||||
|
||||
if len(message.ToolCalls) == 0 {
|
||||
finishReason := choice.FinishReason
|
||||
if finishReason == "" {
|
||||
finishReason = "unknown"
|
||||
}
|
||||
content := message.Content
|
||||
if content == "" {
|
||||
content = "(empty)"
|
||||
}
|
||||
t.Logf("User prompt: %q", userContent)
|
||||
t.Logf("Finish reason: %s", finishReason)
|
||||
t.Logf("Message content: %q", content)
|
||||
t.Logf("Tool calls count: %d", len(message.ToolCalls))
|
||||
if messageJSON, err := json.MarshalIndent(message, "", " "); err == nil {
|
||||
t.Logf("Full message: %s", string(messageJSON))
|
||||
}
|
||||
t.Fatalf("expected at least one tool call, got none. Finish reason: %s, Content: %q", finishReason, content)
|
||||
}
|
||||
|
||||
firstToolCall := message.ToolCalls[0]
|
||||
if firstToolCall.Function.Name != "execute_bash" {
|
||||
t.Fatalf("unexpected tool called: got %q want %q", firstToolCall.Function.Name, "execute_bash")
|
||||
}
|
||||
|
||||
var args map[string]any
|
||||
if firstToolCall.Function.Arguments != "" {
|
||||
if err := json.Unmarshal([]byte(firstToolCall.Function.Arguments), &args); err != nil {
|
||||
t.Fatalf("failed to parse tool call arguments: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
command, ok := args["command"]
|
||||
if !ok {
|
||||
t.Fatalf("expected tool arguments to include 'command', got: %v", args)
|
||||
}
|
||||
|
||||
cmdStr, ok := command.(string)
|
||||
if !ok {
|
||||
t.Fatalf("expected command to be string, got %T", command)
|
||||
}
|
||||
|
||||
if !strings.Contains(cmdStr, "&") {
|
||||
t.Errorf("expected command to contain '&' character for parsing test, got: %q", cmdStr)
|
||||
}
|
||||
|
||||
if !strings.Contains(cmdStr, "echo") && !strings.Contains(cmdStr, "printf") {
|
||||
t.Errorf("expected command to use echo or printf, got: %q", cmdStr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
@ -5,12 +5,30 @@ package integration
|
|||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
// libraryToolsModels lists the ollama.com library models exercised by the
// tool-calling integration tests; each test iterates over this set, one
// subtest per model.
var libraryToolsModels = []string{
	"qwen3-vl",
	"gpt-oss:20b",
	"gpt-oss:120b",
	"qwen3",
	"llama3.1",
	"llama3.2",
	"mistral",
	"qwen2.5",
	"qwen2",
	"mistral-nemo",
	"mistral-small",
	"mixtral:8x22b",
	"qwq",
	"granite3.3",
}
|
||||
|
||||
func TestAPIToolCalling(t *testing.T) {
|
||||
initialTimeout := 60 * time.Second
|
||||
streamTimeout := 60 * time.Second
|
||||
|
|
@ -20,23 +38,6 @@ func TestAPIToolCalling(t *testing.T) {
|
|||
client, _, cleanup := InitServerConnection(ctx, t)
|
||||
defer cleanup()
|
||||
|
||||
minVRAM := map[string]uint64{
|
||||
"qwen3-vl": 16,
|
||||
"gpt-oss:20b": 16,
|
||||
"gpt-oss:120b": 70,
|
||||
"qwen3": 6,
|
||||
"llama3.1": 8,
|
||||
"llama3.2": 4,
|
||||
"mistral": 6,
|
||||
"qwen2.5": 6,
|
||||
"qwen2": 6,
|
||||
"mistral-nemo": 9,
|
||||
"mistral-small": 16,
|
||||
"mixtral:8x22b": 80,
|
||||
"qwq": 20,
|
||||
"granite3.3": 7,
|
||||
}
|
||||
|
||||
for _, model := range libraryToolsModels {
|
||||
t.Run(model, func(t *testing.T) {
|
||||
if v, ok := minVRAM[model]; ok {
|
||||
|
|
@ -130,3 +131,210 @@ func TestAPIToolCalling(t *testing.T) {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAPIToolCallingMultiTurn(t *testing.T) {
|
||||
initialTimeout := 60 * time.Second
|
||||
streamTimeout := 60 * time.Second
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
client, _, cleanup := InitServerConnection(ctx, t)
|
||||
defer cleanup()
|
||||
|
||||
for _, model := range libraryToolsModels {
|
||||
t.Run(model, func(t *testing.T) {
|
||||
if v, ok := minVRAM[model]; ok {
|
||||
skipUnderMinVRAM(t, v)
|
||||
}
|
||||
|
||||
if err := PullIfMissing(ctx, client, model); err != nil {
|
||||
t.Fatalf("pull failed %s", err)
|
||||
}
|
||||
|
||||
tools := []api.Tool{
|
||||
{
|
||||
Type: "function",
|
||||
Function: api.ToolFunction{
|
||||
Name: "get_weather",
|
||||
Description: "Get the current weather in a given location",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Required: []string{"location"},
|
||||
Properties: map[string]api.ToolProperty{
|
||||
"location": {
|
||||
Type: api.PropertyType{"string"},
|
||||
Description: "The city and state, e.g. San Francisco, CA",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// First turn: User asks for weather, model should make a tool call
|
||||
userMessage := api.Message{
|
||||
Role: "user",
|
||||
Content: "What's the weather like in San Francisco?",
|
||||
}
|
||||
|
||||
req := api.ChatRequest{
|
||||
Model: model,
|
||||
Messages: []api.Message{userMessage},
|
||||
Tools: tools,
|
||||
Options: map[string]any{
|
||||
"temperature": 0,
|
||||
},
|
||||
}
|
||||
|
||||
stallTimer := time.NewTimer(initialTimeout)
|
||||
var assistantMessage api.Message
|
||||
var gotToolCall bool
|
||||
var toolCallID string
|
||||
|
||||
fn := func(response api.ChatResponse) error {
|
||||
// Accumulate assistant message content
|
||||
if response.Message.Content != "" {
|
||||
assistantMessage.Content += response.Message.Content
|
||||
assistantMessage.Role = "assistant"
|
||||
}
|
||||
// Capture tool calls whenever they appear
|
||||
if len(response.Message.ToolCalls) > 0 {
|
||||
gotToolCall = true
|
||||
assistantMessage.ToolCalls = response.Message.ToolCalls
|
||||
assistantMessage.Role = "assistant"
|
||||
// Capture the tool call ID if available
|
||||
toolCallID = response.Message.ToolCalls[0].ID
|
||||
}
|
||||
if !stallTimer.Reset(streamTimeout) {
|
||||
return fmt.Errorf("stall was detected while streaming response, aborting")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
stream := true
|
||||
req.Stream = &stream
|
||||
done := make(chan int)
|
||||
var genErr error
|
||||
go func() {
|
||||
genErr = client.Chat(ctx, &req, fn)
|
||||
done <- 0
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-stallTimer.C:
|
||||
t.Fatalf("first turn chat never started. Timed out after: %s", initialTimeout.String())
|
||||
case <-done:
|
||||
if genErr != nil {
|
||||
t.Fatalf("first turn chat failed: %v", genErr)
|
||||
}
|
||||
|
||||
if !gotToolCall {
|
||||
t.Fatalf("expected at least one tool call in first turn, got none")
|
||||
}
|
||||
|
||||
if len(assistantMessage.ToolCalls) == 0 {
|
||||
t.Fatalf("expected tool calls in assistant message, got none")
|
||||
}
|
||||
|
||||
firstToolCall := assistantMessage.ToolCalls[0]
|
||||
if firstToolCall.Function.Name != "get_weather" {
|
||||
t.Errorf("unexpected tool called: got %q want %q", firstToolCall.Function.Name, "get_weather")
|
||||
}
|
||||
|
||||
location, ok := firstToolCall.Function.Arguments["location"]
|
||||
if !ok {
|
||||
t.Fatalf("expected tool arguments to include 'location', got: %s", firstToolCall.Function.Arguments.String())
|
||||
}
|
||||
|
||||
// Second turn: Feed back the tool result and expect a natural language response
|
||||
toolResult := `{"temperature": 72, "condition": "sunny", "humidity": 65}`
|
||||
toolMessage := api.Message{
|
||||
Role: "tool",
|
||||
Content: toolResult,
|
||||
ToolName: "get_weather",
|
||||
ToolCallID: toolCallID,
|
||||
}
|
||||
|
||||
// Build conversation history: user -> assistant (with tool call) -> tool (result) -> user (follow-up)
|
||||
messages := []api.Message{
|
||||
userMessage,
|
||||
assistantMessage,
|
||||
toolMessage,
|
||||
}
|
||||
|
||||
req2 := api.ChatRequest{
|
||||
Model: model,
|
||||
Messages: messages,
|
||||
Tools: tools,
|
||||
Options: map[string]any{
|
||||
"temperature": 0,
|
||||
},
|
||||
}
|
||||
|
||||
stallTimer2 := time.NewTimer(initialTimeout)
|
||||
var finalResponse string
|
||||
var gotSecondToolCall bool
|
||||
|
||||
fn2 := func(response api.ChatResponse) error {
|
||||
if len(response.Message.ToolCalls) > 0 {
|
||||
gotSecondToolCall = true
|
||||
}
|
||||
if response.Message.Content != "" {
|
||||
finalResponse += response.Message.Content
|
||||
}
|
||||
if !stallTimer2.Reset(streamTimeout) {
|
||||
return fmt.Errorf("stall was detected while streaming response, aborting")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
req2.Stream = &stream
|
||||
done2 := make(chan int)
|
||||
var genErr2 error
|
||||
go func() {
|
||||
genErr2 = client.Chat(ctx, &req2, fn2)
|
||||
done2 <- 0
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-stallTimer2.C:
|
||||
t.Fatalf("second turn chat never started. Timed out after: %s", initialTimeout.String())
|
||||
case <-done2:
|
||||
if genErr2 != nil {
|
||||
t.Fatalf("second turn chat failed: %v", genErr2)
|
||||
}
|
||||
|
||||
if gotSecondToolCall {
|
||||
t.Errorf("expected no tool calls in second turn, but got tool calls. Model should respond with natural language after receiving tool result.")
|
||||
}
|
||||
|
||||
if finalResponse == "" {
|
||||
t.Fatalf("expected natural language response in second turn, got empty response")
|
||||
}
|
||||
|
||||
// Verify the response mentions something about the weather (temperature, condition, etc.)
|
||||
responseLower := strings.ToLower(finalResponse)
|
||||
expectedKeywords := []string{"72", "sunny", "temperature", "weather", "san francisco", "fahrenheit"}
|
||||
foundKeyword := false
|
||||
for _, keyword := range expectedKeywords {
|
||||
if strings.Contains(responseLower, strings.ToLower(keyword)) {
|
||||
foundKeyword = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !foundKeyword {
|
||||
t.Logf("response: %s", finalResponse)
|
||||
t.Logf("location from tool call: %v", location)
|
||||
// Don't fail, just log - the model might phrase it differently
|
||||
}
|
||||
|
||||
t.Logf("Successfully completed multi-turn tool calling test. First turn made tool call, second turn responded with: %s", finalResponse)
|
||||
case <-ctx.Done():
|
||||
t.Error("outer test context done while waiting for second turn")
|
||||
}
|
||||
case <-ctx.Done():
|
||||
t.Error("outer test context done while waiting for first turn")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -260,22 +260,6 @@ var (
|
|||
"snowflake-arctic-embed",
|
||||
"snowflake-arctic-embed2",
|
||||
}
|
||||
libraryToolsModels = []string{
|
||||
"qwen3-vl",
|
||||
"gpt-oss:20b",
|
||||
"gpt-oss:120b",
|
||||
"qwen3",
|
||||
"llama3.1",
|
||||
"llama3.2",
|
||||
"mistral",
|
||||
"qwen2.5",
|
||||
"qwen2",
|
||||
"mistral-nemo",
|
||||
"mistral-small",
|
||||
"mixtral:8x22b",
|
||||
"qwq",
|
||||
"granite3.3",
|
||||
}
|
||||
|
||||
blueSkyPrompt = "why is the sky blue? Be brief but factual in your reply"
|
||||
blueSkyExpected = []string{"rayleigh", "scatter", "atmosphere", "nitrogen", "oxygen", "wavelength", "interact"}
|
||||
|
|
@ -747,6 +731,23 @@ func skipUnderMinVRAM(t *testing.T, gb uint64) {
|
|||
}
|
||||
}
|
||||
|
||||
// minVRAM maps model names to the minimum VRAM (in GB, per the gb parameter
// of skipUnderMinVRAM) needed to run them; tests consult this map to skip a
// model when the host does not have enough GPU memory.
var minVRAM = map[string]uint64{
	"qwen3-vl":      16,
	"gpt-oss:20b":   16,
	"gpt-oss:120b":  70,
	"qwen3":         6,
	"llama3.1":      8,
	"llama3.2":      4,
	"mistral":       6,
	"qwen2.5":       6,
	"qwen2":         6,
	"mistral-nemo":  9,
	"mistral-small": 16,
	"mixtral:8x22b": 80,
	"qwq":           20,
	"granite3.3":    7,
}
|
||||
|
||||
// Skip if the target model isn't X% GPU loaded to avoid excessive runtime
|
||||
func skipIfNotGPULoaded(ctx context.Context, t *testing.T, client *api.Client, model string, minPercent int) {
|
||||
gpuPercent := getGPUPercent(ctx, t, client, model)
|
||||
|
|
|
|||
Loading…
Reference in New Issue