diff --git a/go.mod b/go.mod index 54e7942cd..59573eeee 100644 --- a/go.mod +++ b/go.mod @@ -27,6 +27,7 @@ require ( github.com/emirpasic/gods/v2 v2.0.0-alpha github.com/mattn/go-runewidth v0.0.14 github.com/nlpodyssey/gopickle v0.3.0 + github.com/openai/openai-go/v3 v3.8.1 github.com/pdevine/tensor v0.0.0-20240510204454-f88f4562727c github.com/tkrajina/typescriptify-golang-structs v0.2.0 golang.org/x/image v0.22.0 @@ -48,6 +49,10 @@ require ( github.com/pkg/errors v0.9.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/rivo/uniseg v0.2.0 // indirect + github.com/tidwall/gjson v1.14.4 // indirect + github.com/tidwall/match v1.1.1 // indirect + github.com/tidwall/pretty v1.2.1 // indirect + github.com/tidwall/sjson v1.2.5 // indirect github.com/tkrajina/go-reflector v0.5.5 // indirect github.com/xtgo/set v1.0.0 // indirect go4.org/unsafe/assume-no-moving-gc v0.0.0-20231121144256-b99613f794b6 // indirect diff --git a/go.sum b/go.sum index 464cd6fcc..82c8c028b 100644 --- a/go.sum +++ b/go.sum @@ -159,6 +159,8 @@ github.com/nlpodyssey/gopickle v0.3.0 h1:BLUE5gxFLyyNOPzlXxt6GoHEMMxD0qhsE4p0CIQ github.com/nlpodyssey/gopickle v0.3.0/go.mod h1:f070HJ/yR+eLi5WmM1OXJEGaTpuJEUiib19olXgYha0= github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec= github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY= +github.com/openai/openai-go/v3 v3.8.1 h1:b+YWsmwqXnbpSHWQEntZAkKciBZ5CJXwL68j+l59UDg= +github.com/openai/openai-go/v3 v3.8.1/go.mod h1:UOpNxkqC9OdNXNUfpNByKOtB4jAL0EssQXq5p8gO0Xs= github.com/pdevine/tensor v0.0.0-20240510204454-f88f4562727c h1:GwiUUjKefgvSNmv3NCvI/BL0kDebW6Xa+kcdpdc1mTY= github.com/pdevine/tensor v0.0.0-20240510204454-f88f4562727c/go.mod h1:PSojXDXF7TbgQiD6kkd98IHOS0QqTyUEaWRiS8+BLu8= github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM= @@ -199,6 +201,16 @@ github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/tidwall/gjson v1.14.4 h1:uo0p8EbA09J7RQaflQ1aBRffTR7xedD2bcIVSYxLnkM= +github.com/tidwall/gjson v1.14.4/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= +github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= +github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= +github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4= +github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= +github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= +github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= github.com/tkrajina/go-reflector v0.5.5 h1:gwoQFNye30Kk7NrExj8zm3zFtrGPqOkzFMLuQZg1DtQ= github.com/tkrajina/go-reflector v0.5.5/go.mod h1:ECbqLgccecY5kPmPmXg1MrHW585yMcDkVl6IvJe64T4= github.com/tkrajina/typescriptify-golang-structs v0.2.0 h1:ZedWk82egydDspGTryAatbX0/1NZDQbdiZLoCbOk4f8= diff --git a/integration/README.md b/integration/README.md index 5d2acc456..5fb171167 100644 --- a/integration/README.md +++ 
b/integration/README.md
@@ -14,4 +14,59 @@ The integration tests have 2 modes of operating.
 > Before running the tests locally without the "test existing" setting, compile ollama from the top of the source tree `go build .` in addition to GPU support with cmake if applicable on your platform. The integration tests expect to find an ollama binary at the top of the tree.
-Many tests use a default small model suitable to run on many systems. You can override this default model by setting `OLLAMA_TEST_DEFAULT_MODEL`
\ No newline at end of file
+Many tests use a default small model suitable to run on many systems. You can override this default model by setting `OLLAMA_TEST_DEFAULT_MODEL`.
+
+## Tool Calling Tests
+
+The tool calling tests are split into two files:
+
+- **`tools_test.go`** - Tests using the native Ollama API (`api.Tool`)
+- **`tools_openai_test.go`** - Tests using the OpenAI-compatible API format
+
+Minimal request sketches for both APIs are shown at the end of this section.
+
+### Running Tool Calling Tests
+
+Run all tool calling tests:
+```bash
+go test -tags=integration -v -run 'Test.*Tool.*' ./integration
+```
+
+Run only the OpenAI-compatible tests:
+```bash
+go test -tags=integration -v -run TestOpenAI ./integration
+```
+
+Run only the native API tests:
+```bash
+go test -tags=integration -v -run TestAPIToolCalling ./integration
+```
+
+### Parallel Execution
+
+The OpenAI-compatible tests (`tools_openai_test.go`) support parallel execution for cloud models. To run with parallel execution:
+```bash
+go test -tags=integration -v -run TestOpenAI -parallel 3 ./integration
+```
+
+Cloud models (those in the `cloudModels` list, such as `gpt-oss:120b-cloud`) run in parallel, while local models run sequentially. This significantly speeds up test execution against external endpoints.
+
+### Testing Specific Models
+
+To test a specific model, set the `OPENAI_TEST_MODELS` environment variable:
+```bash
+OPENAI_TEST_MODELS="gpt-oss:120b-cloud" go test -tags=integration -v -run TestOpenAI ./integration
+```
+
+### External Endpoints
+
+To test against an external OpenAI-compatible endpoint (e.g., Ollama Cloud):
+```bash
+OPENAI_BASE_URL="https://ollama.com/v1" OLLAMA_API_KEY="your-key" go test -tags=integration -v -run TestOpenAI ./integration
+```
+
+### Environment Variables
+
+The tool calling tests support the following environment variables:
+
+- **`OPENAI_BASE_URL`** - When set, tests run against an external OpenAI-compatible endpoint instead of a local server. If set, `OLLAMA_API_KEY` must also be provided.
+- **`OLLAMA_API_KEY`** - API key for authenticating with external endpoints (required when `OPENAI_BASE_URL` is set).
+- **`OPENAI_TEST_MODELS`** - Override the default model list and test only the specified model.
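+
+### Example: Native API Tool Call
+
+The sketch below mirrors the request shape that `tools_test.go` sends through the native API. It is illustrative rather than part of the test suite: the `qwen3:4b` model, the weather prompt, and the use of `api.ClientFromEnvironment` (which honors `OLLAMA_HOST`) are assumptions made for the example, not requirements of the tests.
+
+```go
+package main
+
+import (
+	"context"
+	"fmt"
+	"log"
+
+	"github.com/ollama/ollama/api"
+)
+
+func main() {
+	// Connect to the local Ollama server.
+	client, err := api.ClientFromEnvironment()
+	if err != nil {
+		log.Fatal(err)
+	}
+
+	// A single tool definition, matching the get_weather tool used by the tests.
+	tools := []api.Tool{{
+		Type: "function",
+		Function: api.ToolFunction{
+			Name:        "get_weather",
+			Description: "Get the current weather in a given location",
+			Parameters: api.ToolFunctionParameters{
+				Type:     "object",
+				Required: []string{"location"},
+				Properties: map[string]api.ToolProperty{
+					"location": {
+						Type:        api.PropertyType{"string"},
+						Description: "The city and state, e.g. San Francisco, CA",
+					},
+				},
+			},
+		},
+	}}
+
+	stream := false
+	req := api.ChatRequest{
+		Model:    "qwen3:4b",
+		Messages: []api.Message{{Role: "user", Content: "What's the weather like in San Francisco?"}},
+		Tools:    tools,
+		Stream:   &stream,
+		Options:  map[string]any{"temperature": 0},
+	}
+
+	// With streaming disabled the callback runs once with the full response;
+	// a tool-capable model should populate Message.ToolCalls here.
+	err = client.Chat(context.Background(), &req, func(resp api.ChatResponse) error {
+		for _, call := range resp.Message.ToolCalls {
+			fmt.Println(call.Function.Name, call.Function.Arguments)
+		}
+		return nil
+	})
+	if err != nil {
+		log.Fatal(err)
+	}
+}
+```
+
+The multi-turn test then feeds the result back as a `tool` role message (with `ToolName` and `ToolCallID` set) and expects a natural-language answer rather than another tool call.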
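+
+### Example: OpenAI-Compatible Tool Call
+
+This sketch shows an equivalent request through the OpenAI-compatible endpoint, using the same `openai-go/v3` calls that `tools_openai_test.go` makes. The `http://localhost:11434/v1` base URL, the placeholder API key, and the model name are illustrative assumptions.
+
+```go
+package main
+
+import (
+	"context"
+	"encoding/json"
+	"fmt"
+	"log"
+
+	"github.com/openai/openai-go/v3"
+	"github.com/openai/openai-go/v3/option"
+	"github.com/openai/openai-go/v3/shared"
+)
+
+func main() {
+	// Point the client at an Ollama server's OpenAI-compatible API.
+	client := openai.NewClient(
+		option.WithBaseURL("http://localhost:11434/v1"),
+		option.WithAPIKey("ollama"),
+	)
+
+	// One tool definition, matching the execute_bash tool used by the tests.
+	tools := []openai.ChatCompletionToolUnionParam{
+		openai.ChatCompletionFunctionTool(shared.FunctionDefinitionParam{
+			Name:        "execute_bash",
+			Description: openai.Opt("Execute a bash/shell command"),
+			Parameters: shared.FunctionParameters{
+				"type": "object",
+				"properties": map[string]any{
+					"command": map[string]any{
+						"type":        "string",
+						"description": "The bash command to execute",
+					},
+				},
+				"required": []string{"command"},
+			},
+		}),
+	}
+
+	completion, err := client.Chat.Completions.New(context.Background(), openai.ChatCompletionNewParams{
+		Model:       shared.ChatModel("qwen3:4b"),
+		Messages:    []openai.ChatCompletionMessageParamUnion{openai.UserMessage("List all files in /tmp")},
+		Tools:       tools,
+		Temperature: openai.Opt(0.0),
+	})
+	if err != nil {
+		log.Fatal(err)
+	}
+	if len(completion.Choices) == 0 {
+		log.Fatal("no choices in response")
+	}
+
+	// A tool-capable model should reply with a tool call rather than plain text.
+	for _, call := range completion.Choices[0].Message.ToolCalls {
+		var args map[string]any
+		if call.Function.Arguments != "" {
+			_ = json.Unmarshal([]byte(call.Function.Arguments), &args)
+		}
+		fmt.Println(call.Function.Name, args)
+	}
+}
+```
+
+The tests pass a placeholder key (`ollama`) when talking to the local server; external endpoints such as Ollama Cloud require a real `OLLAMA_API_KEY`.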
\ No newline at end of file diff --git a/integration/tools_openai_test.go b/integration/tools_openai_test.go new file mode 100644 index 000000000..ebabc2bc3 --- /dev/null +++ b/integration/tools_openai_test.go @@ -0,0 +1,834 @@ +//go:build integration + +package integration + +import ( + "context" + "encoding/json" + "fmt" + "os" + "slices" + "strings" + "testing" + "time" + + "github.com/google/go-cmp/cmp" + "github.com/ollama/ollama/api" + "github.com/openai/openai-go/v3" + "github.com/openai/openai-go/v3/option" + "github.com/openai/openai-go/v3/shared" +) + +var agenticModels = []string{ + "gpt-oss:20b", + "gpt-oss:120b", + "qwen3-coder:30b", + "qwen3:4b", + "qwen3:8b", +} + +var cloudModels = []string{ + "gpt-oss:120b-cloud", + "gpt-oss:20b-cloud", + "qwen3-vl:235b-cloud", + "qwen3-coder:480b-cloud", + "kimi-k2-thinking:cloud", + "kimi-k2:1t-cloud", +} + +// validateBashCommand validates a bash command with flexible matching +// It checks that the core command matches and required arguments are present +func validateBashCommand(cmd string, expectedCmd string, requiredArgs []string) error { + parts := strings.Fields(cmd) + if len(parts) == 0 { + return fmt.Errorf("empty command") + } + + actualCmd := parts[0] + if actualCmd != expectedCmd { + return fmt.Errorf("expected command '%s', got '%s'", expectedCmd, actualCmd) + } + + cmdStr := strings.Join(parts[1:], " ") + for _, arg := range requiredArgs { + if !strings.Contains(cmdStr, arg) { + return fmt.Errorf("missing required argument: %s", arg) + } + } + + return nil +} + +// validateBashCommandFlexible validates a bash command with flexible matching +// It accepts alternative command forms (e.g., find vs ls) and checks required patterns +func validateBashCommandFlexible(cmd string, allowedCommands []string, requiredPatterns []string) error { + parts := strings.Fields(cmd) + if len(parts) == 0 { + return fmt.Errorf("empty command") + } + + actualCmd := parts[0] + commandMatched := false + for _, allowedCmd := range allowedCommands { + if actualCmd == allowedCmd { + commandMatched = true + break + } + } + if !commandMatched { + return fmt.Errorf("expected one of commands %v, got '%s'", allowedCommands, actualCmd) + } + + cmdStr := strings.ToLower(strings.Join(parts[1:], " ")) + for _, pattern := range requiredPatterns { + if !strings.Contains(cmdStr, strings.ToLower(pattern)) { + return fmt.Errorf("missing required pattern: %s", pattern) + } + } + + return nil +} + +func TestOpenAIToolCallingMultiStep(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute) + defer cancel() + + var baseURL string + var apiKey string + var modelsToTest []string + var cleanup func() + + if openaiBaseURL := os.Getenv("OPENAI_BASE_URL"); openaiBaseURL != "" { + baseURL = openaiBaseURL + apiKey = os.Getenv("OLLAMA_API_KEY") + if apiKey == "" { + t.Fatal("OPENAI_API_KEY must be set when using OPENAI_BASE_URL") + } + + // only test cloud models unless OPENAI_TEST_MODELS is set + modelsToTest = cloudModels + if modelsEnv := os.Getenv("OPENAI_TEST_MODELS"); modelsEnv != "" { + modelsToTest = []string{modelsEnv} + } + cleanup = func() {} + } else { + _, testEndpoint, cleanupFn := InitServerConnection(ctx, t) + cleanup = cleanupFn + baseURL = fmt.Sprintf("http://%s/v1", testEndpoint) + apiKey = "ollama" + modelsToTest = append(agenticModels, cloudModels...) + } + t.Cleanup(cleanup) + + opts := []option.RequestOption{ + option.WithBaseURL(baseURL), + option.WithAPIKey(apiKey), + } + openaiClient := openai.NewClient(opts...) 
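+	// All chat requests below go through this OpenAI-compatible client, either
+	// against the external OPENAI_BASE_URL endpoint or the locally started
+	// test server's /v1 API.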
+ + var ollamaClient *api.Client + if baseURL == "" { + ollamaClient, _, _ = InitServerConnection(ctx, t) + } + + for _, model := range modelsToTest { + t.Run(model, func(t *testing.T) { + testCtx := ctx + if slices.Contains(cloudModels, model) { + t.Parallel() + // Create a new context for parallel tests to avoid cancellation + var cancel context.CancelFunc + testCtx, cancel = context.WithTimeout(context.Background(), 10*time.Minute) + defer cancel() + } + if v, ok := minVRAM[model]; ok { + skipUnderMinVRAM(t, v) + } + + if ollamaClient != nil { + if err := PullIfMissing(testCtx, ollamaClient, model); err != nil { + t.Fatalf("pull failed %s", err) + } + } + + tools := []openai.ChatCompletionToolUnionParam{ + openai.ChatCompletionFunctionTool(shared.FunctionDefinitionParam{ + Name: "list_files", + Description: openai.Opt("List all files in a directory"), + Parameters: shared.FunctionParameters{ + "type": "object", + "properties": map[string]any{ + "path": map[string]any{ + "type": "string", + "description": "The directory path to list files from", + }, + }, + "required": []string{"path"}, + }, + }), + openai.ChatCompletionFunctionTool(shared.FunctionDefinitionParam{ + Name: "read_file", + Description: openai.Opt("Read the contents of a file"), + Parameters: shared.FunctionParameters{ + "type": "object", + "properties": map[string]any{ + "path": map[string]any{ + "type": "string", + "description": "The file path to read", + }, + }, + "required": []string{"path"}, + }, + }), + } + + mockFileContents := "line 1\nline 2\nline 3\nline 4\nline 5" + userContent := "Find the file named 'config.json' in /tmp and read its contents" + userMessage := openai.UserMessage(userContent) + + messages := []openai.ChatCompletionMessageParamUnion{ + userMessage, + } + stepCount := 0 + maxSteps := 10 + + normalizePath := func(path string) string { + if path != "" && path[0] != '/' { + return "/" + path + } + return path + } + + expectedSteps := []struct { + functionName string + validateArgs func(map[string]any) error + result string + }{ + { + functionName: "list_files", + validateArgs: func(args map[string]any) error { + path, ok := args["path"] + if !ok { + return fmt.Errorf("missing required argument 'path'") + } + pathStr, ok := path.(string) + if !ok { + return fmt.Errorf("expected 'path' to be string, got %T", path) + } + normalizedPath := normalizePath(pathStr) + if normalizedPath != "/tmp" { + return fmt.Errorf("expected list_files(\"/tmp\"), got list_files(%q)", pathStr) + } + return nil + }, + result: `["config.json", "other.txt", "data.log"]`, + }, + { + functionName: "read_file", + validateArgs: func(args map[string]any) error { + path, ok := args["path"] + if !ok { + return fmt.Errorf("missing required argument 'path'") + } + pathStr, ok := path.(string) + if !ok { + return fmt.Errorf("expected 'path' to be string, got %T", path) + } + normalizedPath := normalizePath(pathStr) + if normalizedPath != "/tmp/config.json" { + return fmt.Errorf("expected read_file(\"/tmp/config.json\"), got read_file(%q)", pathStr) + } + return nil + }, + result: mockFileContents, + }, + } + + for stepCount < maxSteps { + req := openai.ChatCompletionNewParams{ + Model: shared.ChatModel(model), + Messages: messages, + Tools: tools, + Temperature: openai.Opt(0.0), + } + + completion, err := openaiClient.Chat.Completions.New(testCtx, req) + if err != nil { + t.Fatalf("step %d chat failed: %v", stepCount+1, err) + } + + if len(completion.Choices) == 0 { + t.Fatalf("step %d: no choices in response", stepCount+1) + } + + 
choice := completion.Choices[0] + message := choice.Message + + toolCalls := message.ToolCalls + content := message.Content + gotToolCall := len(toolCalls) > 0 + var toolCallID string + if gotToolCall && toolCalls[0].ID != "" { + toolCallID = toolCalls[0].ID + } + + var assistantMessage openai.ChatCompletionMessageParamUnion + if gotToolCall { + toolCallsJSON, err := json.Marshal(toolCalls) + if err != nil { + t.Fatalf("step %d: failed to marshal tool calls: %v", stepCount+1, err) + } + var toolCallParams []openai.ChatCompletionMessageToolCallUnionParam + if err := json.Unmarshal(toolCallsJSON, &toolCallParams); err != nil { + t.Fatalf("step %d: failed to unmarshal tool calls: %v", stepCount+1, err) + } + contentUnion := openai.ChatCompletionAssistantMessageParamContentUnion{ + OfString: openai.Opt(content), + } + assistantMsg := openai.ChatCompletionAssistantMessageParam{ + Content: contentUnion, + ToolCalls: toolCallParams, + } + assistantMessage = openai.ChatCompletionMessageParamUnion{ + OfAssistant: &assistantMsg, + } + } else { + assistantMessage = openai.AssistantMessage(content) + } + + if !gotToolCall && content != "" { + if stepCount < len(expectedSteps) { + t.Logf("EXPECTED: Step %d should call '%s'", stepCount+1, expectedSteps[stepCount].functionName) + t.Logf("ACTUAL: Model stopped with content: %s", content) + t.Fatalf("model stopped making tool calls after %d steps, expected %d steps. Final response: %s", stepCount, len(expectedSteps), content) + } + return + } + + if !gotToolCall || len(toolCalls) == 0 { + if stepCount < len(expectedSteps) { + expectedStep := expectedSteps[stepCount] + t.Logf("EXPECTED: Step %d should call '%s'", stepCount+1, expectedStep.functionName) + t.Logf("ACTUAL: No tool call, got content: %s", content) + t.Fatalf("step %d: expected tool call but got none. Response: %s", stepCount+1, content) + } + return + } + + if stepCount >= len(expectedSteps) { + actualCallJSON, _ := json.MarshalIndent(toolCalls[0], "", " ") + t.Logf("EXPECTED: All %d steps completed", len(expectedSteps)) + t.Logf("ACTUAL: Extra step %d with tool call:\n%s", stepCount+1, string(actualCallJSON)) + funcName := "unknown" + if toolCalls[0].Function.Name != "" { + funcName = toolCalls[0].Function.Name + } + t.Fatalf("model made more tool calls than expected. Expected %d steps, got step %d with tool call: %s", len(expectedSteps), stepCount+1, funcName) + } + + expectedStep := expectedSteps[stepCount] + firstToolCall := toolCalls[0] + funcCall := firstToolCall.Function + if funcCall.Name == "" { + t.Fatalf("step %d: tool call missing function name", stepCount+1) + } + + funcName := funcCall.Name + + var args map[string]any + if funcCall.Arguments != "" { + if err := json.Unmarshal([]byte(funcCall.Arguments), &args); err != nil { + t.Fatalf("step %d: failed to parse tool call arguments: %v", stepCount+1, err) + } + } + + if funcName != expectedStep.functionName { + t.Logf("DIFF: Function name mismatch") + t.Logf(" Expected: %s", expectedStep.functionName) + t.Logf(" Got: %s", funcName) + t.Logf(" Arguments: %v", args) + t.Fatalf("step %d: expected tool call '%s', got '%s'. 
Arguments: %v", stepCount+1, expectedStep.functionName, funcName, args) + } + + if err := expectedStep.validateArgs(args); err != nil { + expectedArgsForDisplay := map[string]any{} + if expectedStep.functionName == "list_files" { + expectedArgsForDisplay = map[string]any{"path": "/tmp"} + } else if expectedStep.functionName == "read_file" { + expectedArgsForDisplay = map[string]any{"path": "/tmp/config.json"} + } + if diff := cmp.Diff(expectedArgsForDisplay, args); diff != "" { + t.Logf("DIFF: Arguments mismatch for function '%s' (-want +got):\n%s", expectedStep.functionName, diff) + } + t.Logf("Error: %v", err) + t.Fatalf("step %d: tool call '%s' has invalid arguments: %v. Arguments: %v", stepCount+1, expectedStep.functionName, err, args) + } + + toolMessage := openai.ToolMessage(expectedStep.result, toolCallID) + messages = append(messages, assistantMessage, toolMessage) + stepCount++ + } + + if stepCount < len(expectedSteps) { + t.Fatalf("test exceeded max steps (%d) before completing all expected steps (%d)", maxSteps, len(expectedSteps)) + } + }) + } +} + +func TestOpenAIToolCallingBash(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute) + defer cancel() + + var baseURL string + var apiKey string + var modelsToTest []string + var cleanup func() + + if openaiBaseURL := os.Getenv("OPENAI_BASE_URL"); openaiBaseURL != "" { + baseURL = openaiBaseURL + apiKey = os.Getenv("OLLAMA_API_KEY") + if apiKey == "" { + t.Fatal("OPENAI_API_KEY must be set when using OPENAI_BASE_URL") + } + modelsToTest = cloudModels + if modelsEnv := os.Getenv("OPENAI_TEST_MODELS"); modelsEnv != "" { + modelsToTest = []string{modelsEnv} + } + cleanup = func() {} + } else { + _, testEndpoint, cleanupFn := InitServerConnection(ctx, t) + cleanup = cleanupFn + baseURL = fmt.Sprintf("http://%s/v1", testEndpoint) + apiKey = "ollama" + modelsToTest = append(agenticModels, cloudModels...) + } + t.Cleanup(cleanup) + + opts := []option.RequestOption{ + option.WithBaseURL(baseURL), + option.WithAPIKey(apiKey), + } + openaiClient := openai.NewClient(opts...) 
+ + var ollamaClient *api.Client + if baseURL == "" { + ollamaClient, _, _ = InitServerConnection(ctx, t) + } + + for _, model := range modelsToTest { + t.Run(model, func(t *testing.T) { + testCtx := ctx + if slices.Contains(cloudModels, model) { + t.Parallel() + // Create a new context for parallel tests to avoid cancellation + var cancel context.CancelFunc + testCtx, cancel = context.WithTimeout(context.Background(), 10*time.Minute) + defer cancel() + } + if v, ok := minVRAM[model]; ok { + skipUnderMinVRAM(t, v) + } + + if ollamaClient != nil { + if err := PullIfMissing(testCtx, ollamaClient, model); err != nil { + t.Fatalf("pull failed %s", err) + } + } + + tools := []openai.ChatCompletionToolUnionParam{ + openai.ChatCompletionFunctionTool(shared.FunctionDefinitionParam{ + Name: "execute_bash", + Description: openai.Opt("Execute a bash/shell command and return stdout, stderr, and exit code"), + Parameters: shared.FunctionParameters{ + "type": "object", + "properties": map[string]any{ + "command": map[string]any{ + "type": "string", + "description": "The bash command to execute", + }, + "working_directory": map[string]any{ + "type": "string", + "description": "Optional working directory for command execution", + }, + }, + "required": []string{"command"}, + }, + }), + } + + userContent := "List all files in /tmp directory" + userMessage := openai.UserMessage(userContent) + + req := openai.ChatCompletionNewParams{ + Model: shared.ChatModel(model), + Messages: []openai.ChatCompletionMessageParamUnion{userMessage}, + Tools: tools, + Temperature: openai.Opt(0.0), + } + + completion, err := openaiClient.Chat.Completions.New(testCtx, req) + if err != nil { + t.Fatalf("chat failed: %v", err) + } + + if len(completion.Choices) == 0 { + t.Fatalf("no choices in response") + } + + choice := completion.Choices[0] + message := choice.Message + + if len(message.ToolCalls) == 0 { + finishReason := choice.FinishReason + if finishReason == "" { + finishReason = "unknown" + } + content := message.Content + if content == "" { + content = "(empty)" + } + t.Logf("User prompt: %q", userContent) + t.Logf("Finish reason: %s", finishReason) + t.Logf("Message content: %q", content) + t.Logf("Tool calls count: %d", len(message.ToolCalls)) + if messageJSON, err := json.MarshalIndent(message, "", " "); err == nil { + t.Logf("Full message: %s", string(messageJSON)) + } + t.Fatalf("expected at least one tool call, got none. Finish reason: %s, Content: %q", finishReason, content) + } + + firstToolCall := message.ToolCalls[0] + if firstToolCall.Function.Name != "execute_bash" { + t.Fatalf("unexpected tool called: got %q want %q", firstToolCall.Function.Name, "execute_bash") + } + + var args map[string]any + if firstToolCall.Function.Arguments != "" { + if err := json.Unmarshal([]byte(firstToolCall.Function.Arguments), &args); err != nil { + t.Fatalf("failed to parse tool call arguments: %v", err) + } + } + + command, ok := args["command"] + if !ok { + t.Fatalf("expected tool arguments to include 'command', got: %v", args) + } + + cmdStr, ok := command.(string) + if !ok { + t.Fatalf("expected command to be string, got %T", command) + } + + if err := validateBashCommand(cmdStr, "ls", []string{"/tmp"}); err != nil { + t.Errorf("bash command validation failed: %v. 
Command: %q", err, cmdStr) + } + }) + } +} + +func TestOpenAIToolCallingBashMultiStep(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute) + defer cancel() + + var baseURL string + var apiKey string + var modelsToTest []string + var cleanup func() + + if openaiBaseURL := os.Getenv("OPENAI_BASE_URL"); openaiBaseURL != "" { + baseURL = openaiBaseURL + apiKey = os.Getenv("OLLAMA_API_KEY") + if apiKey == "" { + t.Fatal("OPENAI_API_KEY must be set when using OPENAI_BASE_URL") + } + modelsToTest = cloudModels + if modelsEnv := os.Getenv("OPENAI_TEST_MODELS"); modelsEnv != "" { + modelsToTest = []string{modelsEnv} + } + cleanup = func() {} + } else { + _, testEndpoint, cleanupFn := InitServerConnection(ctx, t) + cleanup = cleanupFn + baseURL = fmt.Sprintf("http://%s/v1", testEndpoint) + apiKey = "ollama" + modelsToTest = append(agenticModels, cloudModels...) + } + t.Cleanup(cleanup) + + opts := []option.RequestOption{ + option.WithBaseURL(baseURL), + option.WithAPIKey(apiKey), + } + openaiClient := openai.NewClient(opts...) + + var ollamaClient *api.Client + if baseURL == "" { + ollamaClient, _, _ = InitServerConnection(ctx, t) + } + + for _, model := range modelsToTest { + t.Run(model, func(t *testing.T) { + testCtx := ctx + if slices.Contains(cloudModels, model) { + t.Parallel() + // Create a new context for parallel tests to avoid cancellation + var cancel context.CancelFunc + testCtx, cancel = context.WithTimeout(context.Background(), 10*time.Minute) + defer cancel() + } + if v, ok := minVRAM[model]; ok { + skipUnderMinVRAM(t, v) + } + + if ollamaClient != nil { + if err := PullIfMissing(testCtx, ollamaClient, model); err != nil { + t.Fatalf("pull failed %s", err) + } + } + + tools := []openai.ChatCompletionToolUnionParam{ + openai.ChatCompletionFunctionTool(shared.FunctionDefinitionParam{ + Name: "execute_bash", + Description: openai.Opt("Execute a bash/shell command and return stdout, stderr, and exit code"), + Parameters: shared.FunctionParameters{ + "type": "object", + "properties": map[string]any{ + "command": map[string]any{ + "type": "string", + "description": "The bash command to execute", + }, + "working_directory": map[string]any{ + "type": "string", + "description": "Optional working directory for command execution", + }, + }, + "required": []string{"command"}, + }, + }), + } + + userContent := "Find all log files in /tmp. use the bash tool" + userMessage := openai.UserMessage(userContent) + + req := openai.ChatCompletionNewParams{ + Model: shared.ChatModel(model), + Messages: []openai.ChatCompletionMessageParamUnion{userMessage}, + Tools: tools, + Temperature: openai.Opt(0.0), + } + + completion, err := openaiClient.Chat.Completions.New(testCtx, req) + if err != nil { + t.Fatalf("chat failed: %v", err) + } + + if len(completion.Choices) == 0 { + t.Fatalf("no choices in response") + } + + choice := completion.Choices[0] + message := choice.Message + + if len(message.ToolCalls) == 0 { + finishReason := choice.FinishReason + if finishReason == "" { + finishReason = "unknown" + } + content := message.Content + if content == "" { + content = "(empty)" + } + t.Logf("User prompt: %q", userContent) + t.Logf("Finish reason: %s", finishReason) + t.Logf("Message content: %q", content) + t.Logf("Tool calls count: %d", len(message.ToolCalls)) + if messageJSON, err := json.MarshalIndent(message, "", " "); err == nil { + t.Logf("Full message: %s", string(messageJSON)) + } + t.Fatalf("expected at least one tool call, got none. 
Finish reason: %s, Content: %q", finishReason, content) + } + + firstToolCall := message.ToolCalls[0] + if firstToolCall.Function.Name != "execute_bash" { + t.Fatalf("unexpected tool called: got %q want %q", firstToolCall.Function.Name, "execute_bash") + } + + var args map[string]any + if firstToolCall.Function.Arguments != "" { + if err := json.Unmarshal([]byte(firstToolCall.Function.Arguments), &args); err != nil { + t.Fatalf("failed to parse tool call arguments: %v", err) + } + } + + command, ok := args["command"] + if !ok { + t.Fatalf("expected tool arguments to include 'command', got: %v", args) + } + + cmdStr, ok := command.(string) + if !ok { + t.Fatalf("expected command to be string, got %T", command) + } + + if err := validateBashCommandFlexible(cmdStr, []string{"find", "ls"}, []string{"/tmp"}); err != nil { + t.Errorf("bash command validation failed: %v. Command: %q", err, cmdStr) + } + }) + } +} + +func TestOpenAIToolCallingBashAmpersand(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute) + defer cancel() + + var baseURL string + var apiKey string + var modelsToTest []string + var cleanup func() + + if openaiBaseURL := os.Getenv("OPENAI_BASE_URL"); openaiBaseURL != "" { + baseURL = openaiBaseURL + apiKey = os.Getenv("OLLAMA_API_KEY") + if apiKey == "" { + t.Fatal("OPENAI_API_KEY must be set when using OPENAI_BASE_URL") + } + modelsToTest = cloudModels + if modelsEnv := os.Getenv("OPENAI_TEST_MODELS"); modelsEnv != "" { + modelsToTest = []string{modelsEnv} + } + cleanup = func() {} + } else { + _, testEndpoint, cleanupFn := InitServerConnection(ctx, t) + cleanup = cleanupFn + baseURL = fmt.Sprintf("http://%s/v1", testEndpoint) + apiKey = "ollama" + modelsToTest = append(agenticModels, cloudModels...) + } + t.Cleanup(cleanup) + + opts := []option.RequestOption{ + option.WithBaseURL(baseURL), + option.WithAPIKey(apiKey), + } + openaiClient := openai.NewClient(opts...) 
+ + var ollamaClient *api.Client + if baseURL == "" { + ollamaClient, _, _ = InitServerConnection(ctx, t) + } + + for _, model := range modelsToTest { + t.Run(model, func(t *testing.T) { + testCtx := ctx + if slices.Contains(cloudModels, model) { + t.Parallel() + // Create a new context for parallel tests to avoid cancellation + var cancel context.CancelFunc + testCtx, cancel = context.WithTimeout(context.Background(), 10*time.Minute) + defer cancel() + } + if v, ok := minVRAM[model]; ok { + skipUnderMinVRAM(t, v) + } + + if ollamaClient != nil { + if err := PullIfMissing(testCtx, ollamaClient, model); err != nil { + t.Fatalf("pull failed %s", err) + } + } + + tools := []openai.ChatCompletionToolUnionParam{ + openai.ChatCompletionFunctionTool(shared.FunctionDefinitionParam{ + Name: "execute_bash", + Description: openai.Opt("Execute a bash/shell command and return stdout, stderr, and exit code"), + Parameters: shared.FunctionParameters{ + "type": "object", + "properties": map[string]any{ + "command": map[string]any{ + "type": "string", + "description": "The bash command to execute", + }, + "working_directory": map[string]any{ + "type": "string", + "description": "Optional working directory for command execution", + }, + }, + "required": []string{"command"}, + }, + }), + } + + userContent := "Echo the text 'A & B' using bash with the bash tool" + userMessage := openai.UserMessage(userContent) + + req := openai.ChatCompletionNewParams{ + Model: shared.ChatModel(model), + Messages: []openai.ChatCompletionMessageParamUnion{userMessage}, + Tools: tools, + Temperature: openai.Opt(0.0), + } + + completion, err := openaiClient.Chat.Completions.New(testCtx, req) + if err != nil { + t.Fatalf("chat failed: %v", err) + } + + if len(completion.Choices) == 0 { + t.Fatalf("no choices in response") + } + + choice := completion.Choices[0] + message := choice.Message + + if len(message.ToolCalls) == 0 { + finishReason := choice.FinishReason + if finishReason == "" { + finishReason = "unknown" + } + content := message.Content + if content == "" { + content = "(empty)" + } + t.Logf("User prompt: %q", userContent) + t.Logf("Finish reason: %s", finishReason) + t.Logf("Message content: %q", content) + t.Logf("Tool calls count: %d", len(message.ToolCalls)) + if messageJSON, err := json.MarshalIndent(message, "", " "); err == nil { + t.Logf("Full message: %s", string(messageJSON)) + } + t.Fatalf("expected at least one tool call, got none. 
Finish reason: %s, Content: %q", finishReason, content) + } + + firstToolCall := message.ToolCalls[0] + if firstToolCall.Function.Name != "execute_bash" { + t.Fatalf("unexpected tool called: got %q want %q", firstToolCall.Function.Name, "execute_bash") + } + + var args map[string]any + if firstToolCall.Function.Arguments != "" { + if err := json.Unmarshal([]byte(firstToolCall.Function.Arguments), &args); err != nil { + t.Fatalf("failed to parse tool call arguments: %v", err) + } + } + + command, ok := args["command"] + if !ok { + t.Fatalf("expected tool arguments to include 'command', got: %v", args) + } + + cmdStr, ok := command.(string) + if !ok { + t.Fatalf("expected command to be string, got %T", command) + } + + if !strings.Contains(cmdStr, "&") { + t.Errorf("expected command to contain '&' character for parsing test, got: %q", cmdStr) + } + + if !strings.Contains(cmdStr, "echo") && !strings.Contains(cmdStr, "printf") { + t.Errorf("expected command to use echo or printf, got: %q", cmdStr) + } + }) + } +} diff --git a/integration/tools_test.go b/integration/tools_test.go index d6b8dfa54..5d04a6bf1 100644 --- a/integration/tools_test.go +++ b/integration/tools_test.go @@ -5,12 +5,30 @@ package integration import ( "context" "fmt" + "strings" "testing" "time" "github.com/ollama/ollama/api" ) +var libraryToolsModels = []string{ + "qwen3-vl", + "gpt-oss:20b", + "gpt-oss:120b", + "qwen3", + "llama3.1", + "llama3.2", + "mistral", + "qwen2.5", + "qwen2", + "mistral-nemo", + "mistral-small", + "mixtral:8x22b", + "qwq", + "granite3.3", +} + func TestAPIToolCalling(t *testing.T) { initialTimeout := 60 * time.Second streamTimeout := 60 * time.Second @@ -20,23 +38,6 @@ func TestAPIToolCalling(t *testing.T) { client, _, cleanup := InitServerConnection(ctx, t) defer cleanup() - minVRAM := map[string]uint64{ - "qwen3-vl": 16, - "gpt-oss:20b": 16, - "gpt-oss:120b": 70, - "qwen3": 6, - "llama3.1": 8, - "llama3.2": 4, - "mistral": 6, - "qwen2.5": 6, - "qwen2": 6, - "mistral-nemo": 9, - "mistral-small": 16, - "mixtral:8x22b": 80, - "qwq": 20, - "granite3.3": 7, - } - for _, model := range libraryToolsModels { t.Run(model, func(t *testing.T) { if v, ok := minVRAM[model]; ok { @@ -130,3 +131,210 @@ func TestAPIToolCalling(t *testing.T) { }) } } + +func TestAPIToolCallingMultiTurn(t *testing.T) { + initialTimeout := 60 * time.Second + streamTimeout := 60 * time.Second + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute) + defer cancel() + + client, _, cleanup := InitServerConnection(ctx, t) + defer cleanup() + + for _, model := range libraryToolsModels { + t.Run(model, func(t *testing.T) { + if v, ok := minVRAM[model]; ok { + skipUnderMinVRAM(t, v) + } + + if err := PullIfMissing(ctx, client, model); err != nil { + t.Fatalf("pull failed %s", err) + } + + tools := []api.Tool{ + { + Type: "function", + Function: api.ToolFunction{ + Name: "get_weather", + Description: "Get the current weather in a given location", + Parameters: api.ToolFunctionParameters{ + Type: "object", + Required: []string{"location"}, + Properties: map[string]api.ToolProperty{ + "location": { + Type: api.PropertyType{"string"}, + Description: "The city and state, e.g. 
San Francisco, CA", + }, + }, + }, + }, + }, + } + + // First turn: User asks for weather, model should make a tool call + userMessage := api.Message{ + Role: "user", + Content: "What's the weather like in San Francisco?", + } + + req := api.ChatRequest{ + Model: model, + Messages: []api.Message{userMessage}, + Tools: tools, + Options: map[string]any{ + "temperature": 0, + }, + } + + stallTimer := time.NewTimer(initialTimeout) + var assistantMessage api.Message + var gotToolCall bool + var toolCallID string + + fn := func(response api.ChatResponse) error { + // Accumulate assistant message content + if response.Message.Content != "" { + assistantMessage.Content += response.Message.Content + assistantMessage.Role = "assistant" + } + // Capture tool calls whenever they appear + if len(response.Message.ToolCalls) > 0 { + gotToolCall = true + assistantMessage.ToolCalls = response.Message.ToolCalls + assistantMessage.Role = "assistant" + // Capture the tool call ID if available + toolCallID = response.Message.ToolCalls[0].ID + } + if !stallTimer.Reset(streamTimeout) { + return fmt.Errorf("stall was detected while streaming response, aborting") + } + return nil + } + + stream := true + req.Stream = &stream + done := make(chan int) + var genErr error + go func() { + genErr = client.Chat(ctx, &req, fn) + done <- 0 + }() + + select { + case <-stallTimer.C: + t.Fatalf("first turn chat never started. Timed out after: %s", initialTimeout.String()) + case <-done: + if genErr != nil { + t.Fatalf("first turn chat failed: %v", genErr) + } + + if !gotToolCall { + t.Fatalf("expected at least one tool call in first turn, got none") + } + + if len(assistantMessage.ToolCalls) == 0 { + t.Fatalf("expected tool calls in assistant message, got none") + } + + firstToolCall := assistantMessage.ToolCalls[0] + if firstToolCall.Function.Name != "get_weather" { + t.Errorf("unexpected tool called: got %q want %q", firstToolCall.Function.Name, "get_weather") + } + + location, ok := firstToolCall.Function.Arguments["location"] + if !ok { + t.Fatalf("expected tool arguments to include 'location', got: %s", firstToolCall.Function.Arguments.String()) + } + + // Second turn: Feed back the tool result and expect a natural language response + toolResult := `{"temperature": 72, "condition": "sunny", "humidity": 65}` + toolMessage := api.Message{ + Role: "tool", + Content: toolResult, + ToolName: "get_weather", + ToolCallID: toolCallID, + } + + // Build conversation history: user -> assistant (with tool call) -> tool (result) -> user (follow-up) + messages := []api.Message{ + userMessage, + assistantMessage, + toolMessage, + } + + req2 := api.ChatRequest{ + Model: model, + Messages: messages, + Tools: tools, + Options: map[string]any{ + "temperature": 0, + }, + } + + stallTimer2 := time.NewTimer(initialTimeout) + var finalResponse string + var gotSecondToolCall bool + + fn2 := func(response api.ChatResponse) error { + if len(response.Message.ToolCalls) > 0 { + gotSecondToolCall = true + } + if response.Message.Content != "" { + finalResponse += response.Message.Content + } + if !stallTimer2.Reset(streamTimeout) { + return fmt.Errorf("stall was detected while streaming response, aborting") + } + return nil + } + + req2.Stream = &stream + done2 := make(chan int) + var genErr2 error + go func() { + genErr2 = client.Chat(ctx, &req2, fn2) + done2 <- 0 + }() + + select { + case <-stallTimer2.C: + t.Fatalf("second turn chat never started. 
Timed out after: %s", initialTimeout.String()) + case <-done2: + if genErr2 != nil { + t.Fatalf("second turn chat failed: %v", genErr2) + } + + if gotSecondToolCall { + t.Errorf("expected no tool calls in second turn, but got tool calls. Model should respond with natural language after receiving tool result.") + } + + if finalResponse == "" { + t.Fatalf("expected natural language response in second turn, got empty response") + } + + // Verify the response mentions something about the weather (temperature, condition, etc.) + responseLower := strings.ToLower(finalResponse) + expectedKeywords := []string{"72", "sunny", "temperature", "weather", "san francisco", "fahrenheit"} + foundKeyword := false + for _, keyword := range expectedKeywords { + if strings.Contains(responseLower, strings.ToLower(keyword)) { + foundKeyword = true + break + } + } + if !foundKeyword { + t.Logf("response: %s", finalResponse) + t.Logf("location from tool call: %v", location) + // Don't fail, just log - the model might phrase it differently + } + + t.Logf("Successfully completed multi-turn tool calling test. First turn made tool call, second turn responded with: %s", finalResponse) + case <-ctx.Done(): + t.Error("outer test context done while waiting for second turn") + } + case <-ctx.Done(): + t.Error("outer test context done while waiting for first turn") + } + }) + } +} diff --git a/integration/utils_test.go b/integration/utils_test.go index 8a362408e..97e72e469 100644 --- a/integration/utils_test.go +++ b/integration/utils_test.go @@ -260,22 +260,6 @@ var ( "snowflake-arctic-embed", "snowflake-arctic-embed2", } - libraryToolsModels = []string{ - "qwen3-vl", - "gpt-oss:20b", - "gpt-oss:120b", - "qwen3", - "llama3.1", - "llama3.2", - "mistral", - "qwen2.5", - "qwen2", - "mistral-nemo", - "mistral-small", - "mixtral:8x22b", - "qwq", - "granite3.3", - } blueSkyPrompt = "why is the sky blue? Be brief but factual in your reply" blueSkyExpected = []string{"rayleigh", "scatter", "atmosphere", "nitrogen", "oxygen", "wavelength", "interact"} @@ -747,6 +731,23 @@ func skipUnderMinVRAM(t *testing.T, gb uint64) { } } +var minVRAM = map[string]uint64{ + "qwen3-vl": 16, + "gpt-oss:20b": 16, + "gpt-oss:120b": 70, + "qwen3": 6, + "llama3.1": 8, + "llama3.2": 4, + "mistral": 6, + "qwen2.5": 6, + "qwen2": 6, + "mistral-nemo": 9, + "mistral-small": 16, + "mixtral:8x22b": 80, + "qwq": 20, + "granite3.3": 7, +} + // Skip if the target model isn't X% GPU loaded to avoid excessive runtime func skipIfNotGPULoaded(ctx context.Context, t *testing.T, client *api.Client, model string, minPercent int) { gpuPercent := getGPUPercent(ctx, t, client, model)