From a56501316535beb01b4a5fad66c43f3cd219288f Mon Sep 17 00:00:00 2001 From: Baptiste Jamin Date: Mon, 5 Jan 2026 12:01:44 +0100 Subject: [PATCH 1/2] feat: allow enforced tool calling --- api/types.go | 88 ++++++ api/types_test.go | 167 +++++++++++ docs/api.md | 34 +++ docs/api/openai-compatibility.mdx | 90 +++++- middleware/openai.go | 73 ++++- openai/openai.go | 294 +++++++++++++++++- openai/openai_test.go | 477 ++++++++++++++++++++++++++++++ server/routes.go | 175 ++++++++++- server/routes_test.go | 224 ++++++++++++++ 9 files changed, 1611 insertions(+), 11 deletions(-) diff --git a/api/types.go b/api/types.go index 63b898975..79ab7a874 100644 --- a/api/types.go +++ b/api/types.go @@ -148,6 +148,13 @@ type ChatRequest struct { // Tools is an optional list of tools the model has access to. Tools `json:"tools,omitempty"` + // ToolChoice controls how the model uses tools. Can be: + // - "auto" (default): model decides whether to call tools + // - "none": model won't call any tools + // - "required": model must call at least one tool + // - ToolChoiceFunction{Name: "func_name"}: model must call this specific function + ToolChoice *ToolChoice `json:"tool_choice,omitempty"` + // Options lists model-specific options. Options map[string]any `json:"options"` @@ -184,6 +191,87 @@ func (t Tools) String() string { return string(bts) } +// ToolChoice controls how the model uses tools. +// It can be a string ("auto", "none", "required") or a ToolChoiceFunction. +type ToolChoice struct { + // Mode is the tool choice mode: "auto", "none", or "required" + Mode string `json:"-"` + // Function specifies a specific function to call (when forcing a specific tool) + Function *ToolChoiceFunction `json:"-"` +} + +// ToolChoiceFunction specifies a specific function that the model must call. +type ToolChoiceFunction struct { + Name string `json:"name"` +} + +// UnmarshalJSON handles both string and object forms of tool_choice. +func (tc *ToolChoice) UnmarshalJSON(data []byte) error { + // Try string first: "auto", "none", "required" + var s string + if err := json.Unmarshal(data, &s); err == nil { + tc.Mode = s + tc.Function = nil + return nil + } + + // Try object with function name: {"function": {"name": "func_name"}} + var obj struct { + Function *ToolChoiceFunction `json:"function"` + } + if err := json.Unmarshal(data, &obj); err == nil && obj.Function != nil { + tc.Function = obj.Function + tc.Mode = "" + return nil + } + + // Try simple object with just name: {"name": "func_name"} + var simple ToolChoiceFunction + if err := json.Unmarshal(data, &simple); err == nil && simple.Name != "" { + tc.Function = &simple + tc.Mode = "" + return nil + } + + return fmt.Errorf("invalid tool_choice: must be string or object with function name") +} + +// MarshalJSON serializes ToolChoice back to JSON. +func (tc ToolChoice) MarshalJSON() ([]byte, error) { + if tc.Function != nil { + return json.Marshal(map[string]any{"function": tc.Function}) + } + return json.Marshal(tc.Mode) +} + +// IsNone returns true if tool_choice is "none". +func (tc *ToolChoice) IsNone() bool { + return tc != nil && tc.Mode == "none" +} + +// IsRequired returns true if tool_choice is "required". +func (tc *ToolChoice) IsRequired() bool { + return tc != nil && tc.Mode == "required" +} + +// IsAuto returns true if tool_choice is "auto" or not specified. +func (tc *ToolChoice) IsAuto() bool { + return tc == nil || tc.Mode == "" || tc.Mode == "auto" +} + +// IsForcedFunction returns true if a specific function is forced. +func (tc *ToolChoice) IsForcedFunction() bool { + return tc != nil && tc.Function != nil && tc.Function.Name != "" +} + +// GetForcedFunctionName returns the name of the forced function, if any. +func (tc *ToolChoice) GetForcedFunctionName() string { + if tc == nil || tc.Function == nil { + return "" + } + return tc.Function.Name +} + func (t Tool) String() string { bts, _ := json.Marshal(t) return string(bts) diff --git a/api/types_test.go b/api/types_test.go index da1581f48..360e11f50 100644 --- a/api/types_test.go +++ b/api/types_test.go @@ -651,3 +651,170 @@ func TestToolFunctionParameters_String(t *testing.T) { }) } } + +func TestToolChoice_UnmarshalJSON(t *testing.T) { + tests := []struct { + name string + input string + expectedMode string + expectedFunc string + expectedError bool + }{ + { + name: "auto string", + input: `{"tool_choice": "auto"}`, + expectedMode: "auto", + expectedFunc: "", + }, + { + name: "none string", + input: `{"tool_choice": "none"}`, + expectedMode: "none", + expectedFunc: "", + }, + { + name: "required string", + input: `{"tool_choice": "required"}`, + expectedMode: "required", + expectedFunc: "", + }, + { + name: "function object with nested function", + input: `{"tool_choice": {"function": {"name": "get_weather"}}}`, + expectedMode: "", + expectedFunc: "get_weather", + }, + { + name: "function object with direct name", + input: `{"tool_choice": {"name": "get_weather"}}`, + expectedMode: "", + expectedFunc: "get_weather", + }, + { + name: "unset", + input: `{}`, + expectedMode: "", + expectedFunc: "", + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + var req ChatRequest + err := json.Unmarshal([]byte(test.input), &req) + + if test.expectedError { + require.Error(t, err) + return + } + + require.NoError(t, err) + + if test.expectedMode == "" && test.expectedFunc == "" && test.name == "unset" { + assert.Nil(t, req.ToolChoice) + return + } + + require.NotNil(t, req.ToolChoice) + + if test.expectedMode != "" { + assert.Equal(t, test.expectedMode, req.ToolChoice.Mode) + } + + if test.expectedFunc != "" { + require.NotNil(t, req.ToolChoice.Function) + assert.Equal(t, test.expectedFunc, req.ToolChoice.Function.Name) + } + }) + } +} + +func TestToolChoice_Methods(t *testing.T) { + t.Run("IsNone", func(t *testing.T) { + tc := &ToolChoice{Mode: "none"} + assert.True(t, tc.IsNone()) + + tc = &ToolChoice{Mode: "auto"} + assert.False(t, tc.IsNone()) + + var nilTc *ToolChoice + assert.False(t, nilTc.IsNone()) + }) + + t.Run("IsRequired", func(t *testing.T) { + tc := &ToolChoice{Mode: "required"} + assert.True(t, tc.IsRequired()) + + tc = &ToolChoice{Mode: "auto"} + assert.False(t, tc.IsRequired()) + + var nilTc *ToolChoice + assert.False(t, nilTc.IsRequired()) + }) + + t.Run("IsAuto", func(t *testing.T) { + tc := &ToolChoice{Mode: "auto"} + assert.True(t, tc.IsAuto()) + + tc = &ToolChoice{Mode: ""} + assert.True(t, tc.IsAuto()) + + var nilTc *ToolChoice + assert.True(t, nilTc.IsAuto()) + + tc = &ToolChoice{Mode: "required"} + assert.False(t, tc.IsAuto()) + }) + + t.Run("IsForcedFunction", func(t *testing.T) { + tc := &ToolChoice{Function: &ToolChoiceFunction{Name: "get_weather"}} + assert.True(t, tc.IsForcedFunction()) + + tc = &ToolChoice{Mode: "required"} + assert.False(t, tc.IsForcedFunction()) + + tc = &ToolChoice{Function: &ToolChoiceFunction{Name: ""}} + assert.False(t, tc.IsForcedFunction()) + + var nilTc *ToolChoice + assert.False(t, nilTc.IsForcedFunction()) + }) + + t.Run("GetForcedFunctionName", func(t *testing.T) { + tc := &ToolChoice{Function: &ToolChoiceFunction{Name: "get_weather"}} + assert.Equal(t, "get_weather", tc.GetForcedFunctionName()) + + tc = &ToolChoice{Mode: "required"} + assert.Equal(t, "", tc.GetForcedFunctionName()) + + var nilTc *ToolChoice + assert.Equal(t, "", nilTc.GetForcedFunctionName()) + }) +} + +func TestToolChoice_MarshalJSON(t *testing.T) { + tests := []struct { + name string + tc ToolChoice + expected string + }{ + { + name: "mode string", + tc: ToolChoice{Mode: "required"}, + expected: `"required"`, + }, + { + name: "function object", + tc: ToolChoice{Function: &ToolChoiceFunction{Name: "get_weather"}}, + expected: `{"function":{"name":"get_weather"}}`, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + data, err := json.Marshal(test.tc) + require.NoError(t, err) + assert.Equal(t, test.expected, string(data)) + }) + } +} diff --git a/docs/api.md b/docs/api.md index 7c32c9597..65cdd5c5f 100644 --- a/docs/api.md +++ b/docs/api.md @@ -493,6 +493,7 @@ Generate the next message in a chat with a provided model. This is a streaming e - `model`: (required) the [model name](#model-names) - `messages`: the messages of the chat, this can be used to keep a chat memory - `tools`: list of tools in JSON for the model to use if supported +- `tool_choice`: controls how the model uses tools (see [Tool choice](#tool-choice) below) - `think`: (for thinking models) should the model think before responding? The `message` object has the following fields: @@ -519,6 +520,39 @@ Models can also explain the result of the tool call in the response. See the [Ch [See models with tool calling capabilities](https://ollama.com/search?c=tool). +### Tool choice + +By default, the model will determine when and how many tools to use. You can control this behavior with the `tool_choice` parameter: + +- `"auto"` (default): The model decides whether to call zero, one, or multiple tools. +- `"none"`: The model won't call any tools, even if they are provided. +- `"required"`: The model must call at least one tool. The output is constrained to produce a valid tool call. +- `{"function": {"name": "function_name"}}`: The model must call the specified function. + +Example with `tool_choice: "required"`: + +```json +{ + "model": "qwen3", + "messages": [{"role": "user", "content": "What is the weather in Paris?"}], + "tools": [...], + "tool_choice": "required", + "stream": false +} +``` + +Example with a forced function: + +```json +{ + "model": "qwen3", + "messages": [{"role": "user", "content": "What is the weather in Paris?"}], + "tools": [...], + "tool_choice": {"function": {"name": "get_weather"}}, + "stream": false +} +``` + ### Structured outputs Structured outputs are supported by providing a JSON schema in the `format` parameter. The model will generate a response that matches the schema. See the [Chat request (Structured outputs)](#chat-request-structured-outputs) example below. diff --git a/docs/api/openai-compatibility.mdx b/docs/api/openai-compatibility.mdx index a0882053e..a49c8688d 100644 --- a/docs/api/openai-compatibility.mdx +++ b/docs/api/openai-compatibility.mdx @@ -103,6 +103,89 @@ curl -X POST http://localhost:11434/v1/responses \ +### Tool calling with `tool_choice` + +The `tool_choice` parameter controls how the model uses tools: + + + +```python tool_choice.py +from openai import OpenAI + +client = OpenAI( + base_url='http://localhost:11434/v1/', + api_key='ollama', +) + +tools = [ + { + "type": "function", + "function": { + "name": "get_weather", + "description": "Get the weather for a location", + "parameters": { + "type": "object", + "properties": { + "location": {"type": "string", "description": "City name"} + }, + "required": ["location"] + } + } + } +] + +# Force the model to call a tool (any tool) +response = client.chat.completions.create( + model="llama3.2", + messages=[{"role": "user", "content": "What's the weather in Paris?"}], + tools=tools, + tool_choice="required" # Must call at least one tool +) + +# Force a specific function +response = client.chat.completions.create( + model="llama3.2", + messages=[{"role": "user", "content": "Tell me about Paris"}], + tools=tools, + tool_choice={"type": "function", "name": "get_weather"} # Must call get_weather +) + +# Disable tool calling +response = client.chat.completions.create( + model="llama3.2", + messages=[{"role": "user", "content": "What's the weather?"}], + tools=tools, + tool_choice="none" # Don't call any tools +) +``` + +```shell tool_choice.sh +# Force the model to call a tool +curl http://localhost:11434/v1/chat/completions \ +-H "Content-Type: application/json" \ +-d '{ + "model": "llama3.2", + "messages": [{"role": "user", "content": "What is the weather in Paris?"}], + "tools": [{ + "type": "function", + "function": { + "name": "get_weather", + "description": "Get weather for a location", + "parameters": { + "type": "object", + "properties": { + "location": {"type": "string"} + }, + "required": ["location"] + } + } + }], + "tool_choice": "required" +}' +``` + + + ### v1/chat/completions with vision example @@ -207,7 +290,12 @@ curl -X POST http://localhost:11434/v1/chat/completions \ - [x] `top_p` - [x] `max_tokens` - [x] `tools` -- [ ] `tool_choice` +- [x] `tool_choice` + - [x] `"auto"` (default) + - [x] `"none"` + - [x] `"required"` + - [x] `{"type": "function", "name": "function_name"}` + - [x] `{"type": "allowed_tools", "mode": "auto"|"required", "tools": [...]}` - [ ] `logit_bias` - [ ] `user` - [ ] `n` diff --git a/middleware/openai.go b/middleware/openai.go index 5e526416e..84a8a6bd7 100644 --- a/middleware/openai.go +++ b/middleware/openai.go @@ -20,10 +20,12 @@ type BaseWriter struct { } type ChatWriter struct { - stream bool - streamOptions *openai.StreamOptions - id string - toolCallSent bool + stream bool + streamOptions *openai.StreamOptions + id string + toolCallSent bool + forcedToolCall bool + forcedToolName string BaseWriter } @@ -65,6 +67,40 @@ func (w *BaseWriter) writeError(data []byte) (int, error) { return len(data), nil } +func parseForcedToolCall(content string, forcedToolName string) *api.ToolCall { + if content == "" { + return nil + } + + // Try to parse as tool call structure: {"name": "...", "arguments": {...}} + var toolCallJSON struct { + Name string `json:"name"` + Arguments map[string]any `json:"arguments"` + } + + if err := json.Unmarshal([]byte(content), &toolCallJSON); err != nil { + return nil + } + + // If a specific tool was forced, use that name + name := toolCallJSON.Name + if forcedToolName != "" { + name = forcedToolName + } + + if name == "" { + return nil + } + + return &api.ToolCall{ + ID: fmt.Sprintf("call_%d", rand.Intn(999999)), + Function: api.ToolCallFunction{ + Name: name, + Arguments: toolCallJSON.Arguments, + }, + } +} + func (w *ChatWriter) writeResponse(data []byte) (int, error) { var chatResponse api.ChatResponse err := json.Unmarshal(data, &chatResponse) @@ -72,6 +108,15 @@ func (w *ChatWriter) writeResponse(data []byte) (int, error) { return 0, err } + // If tool_choice forced a tool call and we have content but no tool calls, + // try to parse the content as a tool call + if w.forcedToolCall && len(chatResponse.Message.ToolCalls) == 0 && chatResponse.Message.Content != "" { + if toolCall := parseForcedToolCall(chatResponse.Message.Content, w.forcedToolName); toolCall != nil { + chatResponse.Message.ToolCalls = []api.ToolCall{*toolCall} + chatResponse.Message.Content = "" + } + } + // chat chunk if w.stream { c := openai.ToChunk(w.id, chatResponse, w.toolCallSent) @@ -406,6 +451,16 @@ func ChatMiddleware() gin.HandlerFunc { return } + // Determine if tool_choice forces a tool call + var forcedToolCall bool + var forcedToolName string + if req.ToolChoice != nil && len(req.Tools) > 0 { + _, _, forcedToolCall, _ = openai.ApplyToolChoice(req.Tools, req.ToolChoice) + if req.ToolChoice.IsForcedFunction() { + forcedToolName = req.ToolChoice.GetForcedFunctionName() + } + } + var b bytes.Buffer chatReq, err := openai.FromChatRequest(req) @@ -422,10 +477,12 @@ func ChatMiddleware() gin.HandlerFunc { c.Request.Body = io.NopCloser(&b) w := &ChatWriter{ - BaseWriter: BaseWriter{ResponseWriter: c.Writer}, - stream: req.Stream, - id: fmt.Sprintf("chatcmpl-%d", rand.Intn(999)), - streamOptions: req.StreamOptions, + BaseWriter: BaseWriter{ResponseWriter: c.Writer}, + stream: req.Stream, + id: fmt.Sprintf("chatcmpl-%d", rand.Intn(999)), + streamOptions: req.StreamOptions, + forcedToolCall: forcedToolCall, + forcedToolName: forcedToolName, } c.Writer = w diff --git a/openai/openai.go b/openai/openai.go index 9dcba3000..6a1b54928 100644 --- a/openai/openai.go +++ b/openai/openai.go @@ -95,6 +95,120 @@ type Reasoning struct { Effort string `json:"effort,omitempty"` } +type ToolChoiceFunctionRef struct { + Name string `json:"name"` +} + +type ToolChoiceAllowedTool struct { + Type string `json:"type"` // "function" + Name string `json:"name"` +} + +type ToolChoiceObject struct { + Type string `json:"type,omitempty"` + Name string `json:"name,omitempty"` + Function *ToolChoiceFunctionRef `json:"function,omitempty"` + Mode string `json:"mode,omitempty"` + Tools []ToolChoiceAllowedTool `json:"tools,omitempty"` +} + +type ToolChoice struct { + Mode string + Object *ToolChoiceObject +} + +// UnmarshalJSON handles both string and object forms of tool_choice +func (tc *ToolChoice) UnmarshalJSON(data []byte) error { + // Try string first + var s string + if err := json.Unmarshal(data, &s); err == nil { + tc.Mode = s + tc.Object = nil + return nil + } + + // Try object + var obj ToolChoiceObject + if err := json.Unmarshal(data, &obj); err != nil { + return fmt.Errorf("invalid tool_choice: must be string or object") + } + tc.Object = &obj + return nil +} + +// MarshalJSON serializes ToolChoice back to JSON +func (tc ToolChoice) MarshalJSON() ([]byte, error) { + if tc.Object != nil { + return json.Marshal(tc.Object) + } + return json.Marshal(tc.Mode) +} + +// IsNone returns true if tool_choice is "none" +func (tc *ToolChoice) IsNone() bool { + return tc != nil && tc.Mode == "none" +} + +// IsRequired returns true if tool_choice is "required" +func (tc *ToolChoice) IsRequired() bool { + return tc != nil && tc.Mode == "required" +} + +// IsAuto returns true if tool_choice is "auto" or not specified +func (tc *ToolChoice) IsAuto() bool { + return tc == nil || tc.Mode == "" || tc.Mode == "auto" +} + +// IsForcedFunction returns true if a specific function is forced +func (tc *ToolChoice) IsForcedFunction() bool { + if tc == nil || tc.Object == nil { + return false + } + return tc.Object.Type == "function" || tc.Object.Name != "" || tc.Object.Function != nil +} + +// GetForcedFunctionName returns the name of the forced function, if any +func (tc *ToolChoice) GetForcedFunctionName() string { + if tc == nil || tc.Object == nil { + return "" + } + if tc.Object.Name != "" { + return tc.Object.Name + } + if tc.Object.Function != nil { + return tc.Object.Function.Name + } + return "" +} + +// IsAllowedTools returns true if tool_choice uses allowed_tools mode +func (tc *ToolChoice) IsAllowedTools() bool { + return tc != nil && tc.Object != nil && tc.Object.Type == "allowed_tools" +} + +// GetAllowedToolNames returns the list of allowed tool names +func (tc *ToolChoice) GetAllowedToolNames() []string { + if !tc.IsAllowedTools() || tc.Object.Tools == nil { + return nil + } + names := make([]string, len(tc.Object.Tools)) + for i, t := range tc.Object.Tools { + names[i] = t.Name + } + return names +} + +// GetAllowedToolsMode returns the mode for allowed_tools ("auto" or "required") +func (tc *ToolChoice) GetAllowedToolsMode() string { + if !tc.IsAllowedTools() { + return "" + } + if tc.Object.Mode == "" { + return "auto" + } + return tc.Object.Mode +} + type ChatCompletionRequest struct { Model string `json:"model"` Messages []Message `json:"messages"` @@ -109,6 +223,7 @@ type ChatCompletionRequest struct { TopP *float64 `json:"top_p"` ResponseFormat *ResponseFormat `json:"response_format"` Tools []api.Tool `json:"tools"` + ToolChoice *ToolChoice `json:"tool_choice,omitempty"` Reasoning *Reasoning `json:"reasoning,omitempty"` ReasoningEffort *string `json:"reasoning_effort,omitempty"` Logprobs *bool `json:"logprobs"` @@ -444,6 +559,171 @@ func ToModel(r api.ShowResponse, m string) Model { } } +// filterToolsByNames returns only the tools that match the given names +func filterToolsByNames(tools []api.Tool, names []string) []api.Tool { + if len(names) == 0 { + return tools + } + nameSet := make(map[string]bool) + for _, name := range names { + nameSet[name] = true + } + var filtered []api.Tool + for _, tool := range tools { + if nameSet[tool.Function.Name] { + filtered = append(filtered, tool) + } + } + return filtered +} + +// findToolByName returns the tool with the given name, or nil if not found +func findToolByName(tools []api.Tool, name string) *api.Tool { + for _, tool := range tools { + if tool.Function.Name == name { + return &tool + } + } + return nil +} + +// generateToolCallSchema creates a JSON schema that constrains output to valid tool calls +func generateToolCallSchema(tools []api.Tool) json.RawMessage { + if len(tools) == 0 { + return nil + } + + // Collect all tool names for the enum + toolNames := make([]string, len(tools)) + for i, tool := range tools { + toolNames[i] = tool.Function.Name + } + + // Build a schema that allows any of the tools + // Using oneOf for each tool with its specific parameter schema + var oneOfSchemas []map[string]any + for _, tool := range tools { + toolSchema := map[string]any{ + "type": "object", + "properties": map[string]any{ + "name": map[string]any{ + "type": "string", + "const": tool.Function.Name, + }, + "arguments": tool.Function.Parameters, + }, + "required": []string{"name", "arguments"}, + "additionalProperties": false, + } + oneOfSchemas = append(oneOfSchemas, toolSchema) + } + + var schema map[string]any + if len(oneOfSchemas) == 1 { + // Single tool - use its schema directly + schema = oneOfSchemas[0] + } else { + // Multiple tools - use oneOf + schema = map[string]any{ + "oneOf": oneOfSchemas, + } + } + + bytes, err := json.Marshal(schema) + if err != nil { + slog.Error("failed to marshal tool call schema", "error", err) + return nil + } + return bytes +} + +// generateForcedFunctionSchema creates a JSON schema for a specific forced function +func generateForcedFunctionSchema(tool api.Tool) json.RawMessage { + schema := map[string]any{ + "type": "object", + "properties": map[string]any{ + "name": map[string]any{ + "type": "string", + "const": tool.Function.Name, + }, + "arguments": tool.Function.Parameters, + }, + "required": []string{"name", "arguments"}, + "additionalProperties": false, + } + + bytes, err := json.Marshal(schema) + if err != nil { + slog.Error("failed to marshal forced function schema", "error", err) + return nil + } + return bytes +} + +// ApplyToolChoice processes tool_choice and returns filtered tools and optional format schema +// Returns: +// - filteredTools: the tools to pass to the model (may be empty for "none") +// - format: JSON schema to constrain output (for "required" or forced function) +// - forcedToolCall: true if the response should be parsed as a tool call +// - error: if tool_choice references a non-existent tool +func ApplyToolChoice(tools []api.Tool, toolChoice *ToolChoice) (filteredTools []api.Tool, format json.RawMessage, forcedToolCall bool, err error) { + // Default: auto mode, return all tools without format constraint + if toolChoice == nil || toolChoice.IsAuto() { + // Check for allowed_tools with auto mode + if toolChoice != nil && toolChoice.IsAllowedTools() { + allowedNames := toolChoice.GetAllowedToolNames() + filteredTools = filterToolsByNames(tools, allowedNames) + if toolChoice.GetAllowedToolsMode() == "required" { + format = generateToolCallSchema(filteredTools) + forcedToolCall = true + } + return filteredTools, format, forcedToolCall, nil + } + return tools, nil, false, nil + } + + // "none" mode: don't pass any tools + if toolChoice.IsNone() { + return nil, nil, false, nil + } + + // "required" mode: must call at least one tool + if toolChoice.IsRequired() { + format = generateToolCallSchema(tools) + return tools, format, true, nil + } + + // Forced function mode + if toolChoice.IsForcedFunction() { + funcName := toolChoice.GetForcedFunctionName() + if funcName == "" { + return nil, nil, false, errors.New("tool_choice function name is required") + } + + tool := findToolByName(tools, funcName) + if tool == nil { + return nil, nil, false, fmt.Errorf("tool_choice references unknown function: %s", funcName) + } + + format = generateForcedFunctionSchema(*tool) + return []api.Tool{*tool}, format, true, nil + } + + // allowed_tools mode (already handled in IsAuto check, but handle explicit type here) + if toolChoice.IsAllowedTools() { + allowedNames := toolChoice.GetAllowedToolNames() + filteredTools = filterToolsByNames(tools, allowedNames) + if toolChoice.GetAllowedToolsMode() == "required" { + format = generateToolCallSchema(filteredTools) + forcedToolCall = true + } + return filteredTools, format, forcedToolCall, nil + } + + // Unknown mode, default to auto + return tools, nil, false, nil +} + // FromChatRequest converts a ChatCompletionRequest to api.ChatRequest func FromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) { var messages []api.Message @@ -579,6 +859,18 @@ func FromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) { } } + // Apply tool_choice to filter tools and potentially set format constraint + filteredTools, toolChoiceFormat, _, err := ApplyToolChoice(r.Tools, r.ToolChoice) + if err != nil { + return nil, err + } + + // If tool_choice requires a format constraint and no explicit response_format was set, + // apply the tool call schema + if toolChoiceFormat != nil && format == nil { + format = toolChoiceFormat + } + var think *api.ThinkValue var effort string @@ -606,7 +898,7 @@ func FromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) { Format: format, Options: options, Stream: &r.Stream, - Tools: r.Tools, + Tools: filteredTools, Think: think, Logprobs: r.Logprobs != nil && *r.Logprobs, TopLogprobs: r.TopLogprobs, diff --git a/openai/openai_test.go b/openai/openai_test.go index 51e243dec..b551058e6 100644 --- a/openai/openai_test.go +++ b/openai/openai_test.go @@ -434,3 +434,480 @@ func TestFromChatRequest_TopLogprobsRange(t *testing.T) { }) } } + +func TestToolChoice_UnmarshalJSON(t *testing.T) { + tests := []struct { + name string + json string + wantMode string + wantObj bool + }{ + { + name: "string auto", + json: `"auto"`, + wantMode: "auto", + wantObj: false, + }, + { + name: "string none", + json: `"none"`, + wantMode: "none", + wantObj: false, + }, + { + name: "string required", + json: `"required"`, + wantMode: "required", + wantObj: false, + }, + { + name: "object function with name", + json: `{"type": "function", "name": "get_weather"}`, + wantMode: "", + wantObj: true, + }, + { + name: "object function with function.name", + json: `{"type": "function", "function": {"name": "get_weather"}}`, + wantMode: "", + wantObj: true, + }, + { + name: "object allowed_tools", + json: `{"type": "allowed_tools", "mode": "required", "tools": [{"type": "function", "name": "get_weather"}]}`, + wantMode: "", + wantObj: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var tc ToolChoice + if err := tc.UnmarshalJSON([]byte(tt.json)); err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if tc.Mode != tt.wantMode { + t.Errorf("Mode = %q, want %q", tc.Mode, tt.wantMode) + } + + if (tc.Object != nil) != tt.wantObj { + t.Errorf("Object = %v, wantObj = %v", tc.Object, tt.wantObj) + } + }) + } +} + +func TestToolChoice_Methods(t *testing.T) { + tests := []struct { + name string + toolChoice *ToolChoice + isNone bool + isRequired bool + isAuto bool + isForcedFunction bool + forcedFuncName string + isAllowedTools bool + allowedToolNames []string + allowedToolsMode string + }{ + { + name: "nil", + toolChoice: nil, + isNone: false, + isRequired: false, + isAuto: true, + }, + { + name: "auto string", + toolChoice: &ToolChoice{Mode: "auto"}, + isNone: false, + isRequired: false, + isAuto: true, + }, + { + name: "none string", + toolChoice: &ToolChoice{Mode: "none"}, + isNone: true, + isRequired: false, + isAuto: false, + }, + { + name: "required string", + toolChoice: &ToolChoice{Mode: "required"}, + isNone: false, + isRequired: true, + isAuto: false, + }, + { + name: "forced function with name", + toolChoice: &ToolChoice{ + Object: &ToolChoiceObject{Type: "function", Name: "get_weather"}, + }, + isNone: false, + isRequired: false, + isAuto: false, + isForcedFunction: true, + forcedFuncName: "get_weather", + }, + { + name: "forced function with function.name", + toolChoice: &ToolChoice{ + Object: &ToolChoiceObject{ + Type: "function", + Function: &ToolChoiceFunctionRef{Name: "search"}, + }, + }, + isNone: false, + isRequired: false, + isAuto: false, + isForcedFunction: true, + forcedFuncName: "search", + }, + { + name: "allowed_tools auto", + toolChoice: &ToolChoice{ + Object: &ToolChoiceObject{ + Type: "allowed_tools", + Mode: "auto", + Tools: []ToolChoiceAllowedTool{ + {Type: "function", Name: "get_weather"}, + {Type: "function", Name: "search"}, + }, + }, + }, + isNone: false, + isRequired: false, + isAuto: true, + isForcedFunction: false, + isAllowedTools: true, + allowedToolNames: []string{"get_weather", "search"}, + allowedToolsMode: "auto", + }, + { + name: "allowed_tools required", + toolChoice: &ToolChoice{ + Object: &ToolChoiceObject{ + Type: "allowed_tools", + Mode: "required", + Tools: []ToolChoiceAllowedTool{ + {Type: "function", Name: "get_weather"}, + }, + }, + }, + isNone: false, + isRequired: false, + isAuto: false, + isForcedFunction: false, + isAllowedTools: true, + allowedToolNames: []string{"get_weather"}, + allowedToolsMode: "required", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.toolChoice.IsNone(); got != tt.isNone { + t.Errorf("IsNone() = %v, want %v", got, tt.isNone) + } + if got := tt.toolChoice.IsRequired(); got != tt.isRequired { + t.Errorf("IsRequired() = %v, want %v", got, tt.isRequired) + } + if got := tt.toolChoice.IsAuto(); got != tt.isAuto { + t.Errorf("IsAuto() = %v, want %v", got, tt.isAuto) + } + if got := tt.toolChoice.IsForcedFunction(); got != tt.isForcedFunction { + t.Errorf("IsForcedFunction() = %v, want %v", got, tt.isForcedFunction) + } + if got := tt.toolChoice.GetForcedFunctionName(); got != tt.forcedFuncName { + t.Errorf("GetForcedFunctionName() = %q, want %q", got, tt.forcedFuncName) + } + if got := tt.toolChoice.IsAllowedTools(); got != tt.isAllowedTools { + t.Errorf("IsAllowedTools() = %v, want %v", got, tt.isAllowedTools) + } + if tt.isAllowedTools { + if got := tt.toolChoice.GetAllowedToolNames(); !cmp.Equal(got, tt.allowedToolNames) { + t.Errorf("GetAllowedToolNames() = %v, want %v", got, tt.allowedToolNames) + } + if got := tt.toolChoice.GetAllowedToolsMode(); got != tt.allowedToolsMode { + t.Errorf("GetAllowedToolsMode() = %q, want %q", got, tt.allowedToolsMode) + } + } + }) + } +} + +func TestApplyToolChoice(t *testing.T) { + tools := []api.Tool{ + { + Type: "function", + Function: api.ToolFunction{ + Name: "get_weather", + Description: "Get weather", + Parameters: api.ToolFunctionParameters{ + Type: "object", + Properties: map[string]api.Property{ + "location": {Type: []string{"string"}}, + }, + Required: []string{"location"}, + }, + }, + }, + { + Type: "function", + Function: api.ToolFunction{ + Name: "search", + Description: "Search the web", + Parameters: api.ToolFunctionParameters{ + Type: "object", + Properties: map[string]api.Property{ + "query": {Type: []string{"string"}}, + }, + Required: []string{"query"}, + }, + }, + }, + } + + tests := []struct { + name string + toolChoice *ToolChoice + wantToolCount int + wantFormat bool + wantForced bool + wantError bool + wantToolNames []string + }{ + { + name: "nil (auto)", + toolChoice: nil, + wantToolCount: 2, + wantFormat: false, + wantForced: false, + }, + { + name: "auto string", + toolChoice: &ToolChoice{Mode: "auto"}, + wantToolCount: 2, + wantFormat: false, + wantForced: false, + }, + { + name: "none string", + toolChoice: &ToolChoice{Mode: "none"}, + wantToolCount: 0, + wantFormat: false, + wantForced: false, + }, + { + name: "required string", + toolChoice: &ToolChoice{Mode: "required"}, + wantToolCount: 2, + wantFormat: true, + wantForced: true, + }, + { + name: "forced function", + toolChoice: &ToolChoice{ + Object: &ToolChoiceObject{Type: "function", Name: "get_weather"}, + }, + wantToolCount: 1, + wantFormat: true, + wantForced: true, + wantToolNames: []string{"get_weather"}, + }, + { + name: "forced unknown function", + toolChoice: &ToolChoice{ + Object: &ToolChoiceObject{Type: "function", Name: "unknown_func"}, + }, + wantError: true, + }, + { + name: "allowed_tools auto", + toolChoice: &ToolChoice{ + Object: &ToolChoiceObject{ + Type: "allowed_tools", + Mode: "auto", + Tools: []ToolChoiceAllowedTool{ + {Type: "function", Name: "get_weather"}, + }, + }, + }, + wantToolCount: 1, + wantFormat: false, + wantForced: false, + wantToolNames: []string{"get_weather"}, + }, + { + name: "allowed_tools required", + toolChoice: &ToolChoice{ + Object: &ToolChoiceObject{ + Type: "allowed_tools", + Mode: "required", + Tools: []ToolChoiceAllowedTool{ + {Type: "function", Name: "search"}, + }, + }, + }, + wantToolCount: 1, + wantFormat: true, + wantForced: true, + wantToolNames: []string{"search"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + filteredTools, format, forced, err := ApplyToolChoice(tools, tt.toolChoice) + + if tt.wantError { + if err == nil { + t.Errorf("expected error, got nil") + } + return + } + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if len(filteredTools) != tt.wantToolCount { + t.Errorf("got %d tools, want %d", len(filteredTools), tt.wantToolCount) + } + + if (format != nil) != tt.wantFormat { + t.Errorf("format = %v, wantFormat = %v", format != nil, tt.wantFormat) + } + + if forced != tt.wantForced { + t.Errorf("forced = %v, want %v", forced, tt.wantForced) + } + + if tt.wantToolNames != nil { + gotNames := make([]string, len(filteredTools)) + for i, tool := range filteredTools { + gotNames[i] = tool.Function.Name + } + if !cmp.Equal(gotNames, tt.wantToolNames) { + t.Errorf("tool names = %v, want %v", gotNames, tt.wantToolNames) + } + } + }) + } +} + +func TestFromChatRequest_WithToolChoice(t *testing.T) { + tools := []api.Tool{ + { + Type: "function", + Function: api.ToolFunction{ + Name: "get_weather", + Description: "Get weather", + Parameters: api.ToolFunctionParameters{ + Type: "object", + Properties: map[string]api.Property{ + "location": {Type: []string{"string"}}, + }, + }, + }, + }, + { + Type: "function", + Function: api.ToolFunction{ + Name: "search", + Description: "Search", + Parameters: api.ToolFunctionParameters{ + Type: "object", + Properties: map[string]api.Property{ + "query": {Type: []string{"string"}}, + }, + }, + }, + }, + } + + t.Run("tool_choice none removes tools", func(t *testing.T) { + req := ChatCompletionRequest{ + Model: "test-model", + Messages: []Message{{Role: "user", Content: "Hello"}}, + Tools: tools, + ToolChoice: &ToolChoice{Mode: "none"}, + } + + result, err := FromChatRequest(req) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if len(result.Tools) != 0 { + t.Errorf("expected 0 tools, got %d", len(result.Tools)) + } + }) + + t.Run("tool_choice required adds format", func(t *testing.T) { + req := ChatCompletionRequest{ + Model: "test-model", + Messages: []Message{{Role: "user", Content: "Hello"}}, + Tools: tools, + ToolChoice: &ToolChoice{Mode: "required"}, + } + + result, err := FromChatRequest(req) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if len(result.Tools) != 2 { + t.Errorf("expected 2 tools, got %d", len(result.Tools)) + } + + if result.Format == nil { + t.Error("expected format to be set for required tool_choice") + } + }) + + t.Run("tool_choice forced function filters and adds format", func(t *testing.T) { + req := ChatCompletionRequest{ + Model: "test-model", + Messages: []Message{{Role: "user", Content: "Hello"}}, + Tools: tools, + ToolChoice: &ToolChoice{ + Object: &ToolChoiceObject{Type: "function", Name: "get_weather"}, + }, + } + + result, err := FromChatRequest(req) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if len(result.Tools) != 1 { + t.Errorf("expected 1 tool, got %d", len(result.Tools)) + } + + if result.Tools[0].Function.Name != "get_weather" { + t.Errorf("expected tool 'get_weather', got %q", result.Tools[0].Function.Name) + } + + if result.Format == nil { + t.Error("expected format to be set for forced function") + } + }) + + t.Run("tool_choice unknown function returns error", func(t *testing.T) { + req := ChatCompletionRequest{ + Model: "test-model", + Messages: []Message{{Role: "user", Content: "Hello"}}, + Tools: tools, + ToolChoice: &ToolChoice{ + Object: &ToolChoiceObject{Type: "function", Name: "unknown"}, + }, + } + + _, err := FromChatRequest(req) + if err == nil { + t.Error("expected error for unknown function") + } + }) +} diff --git a/server/routes.go b/server/routes.go index 977a13ff2..5de1882cf 100644 --- a/server/routes.go +++ b/server/routes.go @@ -1861,6 +1861,112 @@ func toolCallId() string { return "call_" + strings.ToLower(string(b)) } +// findToolByName returns the tool with the given name, or nil if not found. +func findToolByName(tools []api.Tool, name string) *api.Tool { + for i := range tools { + if tools[i].Function.Name == name { + return &tools[i] + } + } + return nil +} + +// generateToolCallSchema creates a JSON schema that constrains output to valid tool calls. +func generateToolCallSchema(tools []api.Tool) json.RawMessage { + if len(tools) == 0 { + return nil + } + + var oneOfSchemas []map[string]any + for _, tool := range tools { + toolSchema := map[string]any{ + "type": "object", + "properties": map[string]any{ + "name": map[string]any{ + "type": "string", + "const": tool.Function.Name, + }, + "arguments": tool.Function.Parameters, + }, + "required": []string{"name", "arguments"}, + "additionalProperties": false, + } + oneOfSchemas = append(oneOfSchemas, toolSchema) + } + + var schema map[string]any + if len(oneOfSchemas) == 1 { + schema = oneOfSchemas[0] + } else { + schema = map[string]any{ + "oneOf": oneOfSchemas, + } + } + + bytes, err := json.Marshal(schema) + if err != nil { + slog.Error("failed to marshal tool call schema", "error", err) + return nil + } + return bytes +} + +// generateForcedFunctionSchema creates a JSON schema for a specific forced function. +func generateForcedFunctionSchema(tool api.Tool) json.RawMessage { + schema := map[string]any{ + "type": "object", + "properties": map[string]any{ + "name": map[string]any{ + "type": "string", + "const": tool.Function.Name, + }, + "arguments": tool.Function.Parameters, + }, + "required": []string{"name", "arguments"}, + "additionalProperties": false, + } + + bytes, err := json.Marshal(schema) + if err != nil { + slog.Error("failed to marshal forced function schema", "error", err) + return nil + } + return bytes +} + +// parseForcedToolCallContent parses content as a forced tool call JSON. +func parseForcedToolCallContent(content string, forcedToolName string) *api.ToolCall { + if content == "" { + return nil + } + + var toolCallJSON struct { + Name string `json:"name"` + Arguments map[string]any `json:"arguments"` + } + + if err := json.Unmarshal([]byte(content), &toolCallJSON); err != nil { + return nil + } + + name := toolCallJSON.Name + if forcedToolName != "" { + name = forcedToolName + } + + if name == "" { + return nil + } + + return &api.ToolCall{ + ID: toolCallId(), + Function: api.ToolCallFunction{ + Name: name, + Arguments: toolCallJSON.Arguments, + }, + } +} + func (s *Server) ChatHandler(c *gin.Context) { checkpointStart := time.Now() @@ -2075,6 +2181,36 @@ func (s *Server) ChatHandler(c *gin.Context) { } } + // Handle tool_choice + var forcedToolCall bool + var forcedToolName string + if req.ToolChoice != nil && len(req.Tools) > 0 { + if req.ToolChoice.IsNone() { + // "none" mode: don't pass any tools + processedTools = nil + } else if req.ToolChoice.IsRequired() { + // "required" mode: must call at least one tool, generate JSON schema + if req.Format == nil { + req.Format = generateToolCallSchema(req.Tools) + } + forcedToolCall = true + } else if req.ToolChoice.IsForcedFunction() { + // Forced function mode: must call this specific function + funcName := req.ToolChoice.GetForcedFunctionName() + tool := findToolByName(req.Tools, funcName) + if tool == nil { + c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("tool_choice references unknown function: %s", funcName)}) + return + } + if req.Format == nil { + req.Format = generateForcedFunctionSchema(*tool) + } + processedTools = []api.Tool{*tool} + forcedToolCall = true + forcedToolName = funcName + } + } + truncate := req.Truncate == nil || *req.Truncate prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, msgs, processedTools, req.Think, truncate) if err != nil { @@ -2206,6 +2342,15 @@ func (s *Server) ChatHandler(c *gin.Context) { return } + // If tool_choice forced a tool call and builtinParser didn't find any, + // try to parse the content as a forced tool call JSON + if r.Done && forcedToolCall && len(res.Message.ToolCalls) == 0 && res.Message.Content != "" { + if toolCall := parseForcedToolCallContent(res.Message.Content, forcedToolName); toolCall != nil { + res.Message.ToolCalls = []api.ToolCall{*toolCall} + res.Message.Content = "" + } + } + if res.Message.Content != "" || res.Message.Thinking != "" || len(res.Message.ToolCalls) > 0 || r.Done || len(res.Logprobs) > 0 { slog.Log(context.TODO(), logutil.LevelTrace, "builtin parser output", "parser", m.Config.Parser, "content", content, "thinking", thinking, "toolCalls", toolCalls, "done", r.Done) ch <- res @@ -2235,7 +2380,7 @@ func (s *Server) ChatHandler(c *gin.Context) { res.Message.Content = remainingContent } - if len(req.Tools) > 0 { + if len(req.Tools) > 0 && toolParser != nil { toolCalls, content := toolParser.Add(res.Message.Content) if len(content) > 0 { res.Message.Content = content @@ -2258,12 +2403,28 @@ func (s *Server) ChatHandler(c *gin.Context) { if r.Done { res.Message.Content = toolParser.Content() + // If tool_choice forced a tool call, try to parse the buffered content + if forcedToolCall && len(res.Message.ToolCalls) == 0 && res.Message.Content != "" { + if toolCall := parseForcedToolCallContent(res.Message.Content, forcedToolName); toolCall != nil { + res.Message.ToolCalls = []api.ToolCall{*toolCall} + res.Message.Content = "" + } + } ch <- res } return } } + // If tool_choice forced a tool call and we have content but no tool calls, + // try to parse the content as a tool call (used when format constraint was applied) + if r.Done && forcedToolCall && len(res.Message.ToolCalls) == 0 && res.Message.Content != "" { + if toolCall := parseForcedToolCallContent(res.Message.Content, forcedToolName); toolCall != nil { + res.Message.ToolCalls = []api.ToolCall{*toolCall} + res.Message.Content = "" + } + } + ch <- res }) if err != nil { @@ -2353,6 +2514,18 @@ func (s *Server) ChatHandler(c *gin.Context) { if len(toolCalls) > 0 { resp.Message.ToolCalls = toolCalls + // If we have tool calls from forced tool_choice, the "content" was actually + // the JSON that got parsed into tool calls, so clear it + if forcedToolCall { + resp.Message.Content = "" + } + } else if forcedToolCall && resp.Message.Content != "" { + // No tool calls were parsed in callbacks, but we have forced tool_choice. + // Try to parse the accumulated content as a tool call JSON. + if toolCall := parseForcedToolCallContent(resp.Message.Content, forcedToolName); toolCall != nil { + resp.Message.ToolCalls = []api.ToolCall{*toolCall} + resp.Message.Content = "" + } } c.JSON(http.StatusOK, resp) diff --git a/server/routes_test.go b/server/routes_test.go index e470b9384..10c03b255 100644 --- a/server/routes_test.go +++ b/server/routes_test.go @@ -978,3 +978,227 @@ func TestWaitForStream(t *testing.T) { }) } } + +func TestFindToolByName(t *testing.T) { + tools := []api.Tool{ + { + Type: "function", + Function: api.ToolFunction{ + Name: "get_weather", + Description: "Get weather for a location", + }, + }, + { + Type: "function", + Function: api.ToolFunction{ + Name: "search_web", + Description: "Search the web", + }, + }, + } + + t.Run("found", func(t *testing.T) { + tool := findToolByName(tools, "get_weather") + if tool == nil { + t.Fatal("expected to find tool") + } + if tool.Function.Name != "get_weather" { + t.Errorf("expected get_weather, got %s", tool.Function.Name) + } + }) + + t.Run("not found", func(t *testing.T) { + tool := findToolByName(tools, "nonexistent") + if tool != nil { + t.Error("expected nil for nonexistent tool") + } + }) + + t.Run("empty tools", func(t *testing.T) { + tool := findToolByName([]api.Tool{}, "get_weather") + if tool != nil { + t.Error("expected nil for empty tools") + } + }) +} + +func TestGenerateToolCallSchema(t *testing.T) { + tools := []api.Tool{ + { + Type: "function", + Function: api.ToolFunction{ + Name: "get_weather", + Description: "Get weather", + Parameters: api.ToolFunctionParameters{ + Type: "object", + Required: []string{"location"}, + Properties: map[string]api.ToolProperty{ + "location": {Type: api.PropertyType{"string"}}, + }, + }, + }, + }, + } + + t.Run("single tool", func(t *testing.T) { + schema := generateToolCallSchema(tools) + if schema == nil { + t.Fatal("expected schema, got nil") + } + + var parsed map[string]any + if err := json.Unmarshal(schema, &parsed); err != nil { + t.Fatalf("failed to parse schema: %v", err) + } + + // Should have properties with name and arguments + props, ok := parsed["properties"].(map[string]any) + if !ok { + t.Fatal("expected properties in schema") + } + + if _, ok := props["name"]; !ok { + t.Error("expected name property") + } + if _, ok := props["arguments"]; !ok { + t.Error("expected arguments property") + } + }) + + t.Run("multiple tools", func(t *testing.T) { + multiTools := append(tools, api.Tool{ + Type: "function", + Function: api.ToolFunction{ + Name: "search_web", + }, + }) + + schema := generateToolCallSchema(multiTools) + if schema == nil { + t.Fatal("expected schema, got nil") + } + + var parsed map[string]any + if err := json.Unmarshal(schema, &parsed); err != nil { + t.Fatalf("failed to parse schema: %v", err) + } + + // Should have oneOf for multiple tools + if _, ok := parsed["oneOf"]; !ok { + t.Error("expected oneOf for multiple tools") + } + }) + + t.Run("empty tools", func(t *testing.T) { + schema := generateToolCallSchema([]api.Tool{}) + if schema != nil { + t.Error("expected nil for empty tools") + } + }) +} + +func TestGenerateForcedFunctionSchema(t *testing.T) { + tool := api.Tool{ + Type: "function", + Function: api.ToolFunction{ + Name: "get_weather", + Description: "Get weather", + Parameters: api.ToolFunctionParameters{ + Type: "object", + Required: []string{"location"}, + Properties: map[string]api.ToolProperty{ + "location": {Type: api.PropertyType{"string"}}, + }, + }, + }, + } + + schema := generateForcedFunctionSchema(tool) + if schema == nil { + t.Fatal("expected schema, got nil") + } + + var parsed map[string]any + if err := json.Unmarshal(schema, &parsed); err != nil { + t.Fatalf("failed to parse schema: %v", err) + } + + // Check that name has const constraint + props := parsed["properties"].(map[string]any) + nameSchema := props["name"].(map[string]any) + if nameSchema["const"] != "get_weather" { + t.Errorf("expected const get_weather, got %v", nameSchema["const"]) + } +} + +func TestParseForcedToolCallContent(t *testing.T) { + t.Run("valid tool call JSON", func(t *testing.T) { + content := `{"name": "get_weather", "arguments": {"location": "Paris"}}` + toolCall := parseForcedToolCallContent(content, "") + + if toolCall == nil { + t.Fatal("expected tool call, got nil") + } + + if toolCall.Function.Name != "get_weather" { + t.Errorf("expected get_weather, got %s", toolCall.Function.Name) + } + + if toolCall.Function.Arguments["location"] != "Paris" { + t.Errorf("expected Paris, got %v", toolCall.Function.Arguments["location"]) + } + + if toolCall.ID == "" { + t.Error("expected non-empty tool call ID") + } + }) + + t.Run("forced tool name override", func(t *testing.T) { + content := `{"name": "other_tool", "arguments": {"location": "Paris"}}` + toolCall := parseForcedToolCallContent(content, "get_weather") + + if toolCall == nil { + t.Fatal("expected tool call, got nil") + } + + // Should use the forced name, not the one in JSON + if toolCall.Function.Name != "get_weather" { + t.Errorf("expected get_weather (forced), got %s", toolCall.Function.Name) + } + }) + + t.Run("empty content", func(t *testing.T) { + toolCall := parseForcedToolCallContent("", "") + if toolCall != nil { + t.Error("expected nil for empty content") + } + }) + + t.Run("invalid JSON", func(t *testing.T) { + toolCall := parseForcedToolCallContent("not json", "") + if toolCall != nil { + t.Error("expected nil for invalid JSON") + } + }) + + t.Run("missing name", func(t *testing.T) { + content := `{"arguments": {"location": "Paris"}}` + toolCall := parseForcedToolCallContent(content, "") + if toolCall != nil { + t.Error("expected nil when name is missing and not forced") + } + }) + + t.Run("missing name but forced", func(t *testing.T) { + content := `{"arguments": {"location": "Paris"}}` + toolCall := parseForcedToolCallContent(content, "get_weather") + + if toolCall == nil { + t.Fatal("expected tool call with forced name") + } + + if toolCall.Function.Name != "get_weather" { + t.Errorf("expected get_weather, got %s", toolCall.Function.Name) + } + }) +} From a9a62a07d6419904291bc8f399bb858d45494efa Mon Sep 17 00:00:00 2001 From: Baptiste Jamin Date: Mon, 5 Jan 2026 12:23:49 +0100 Subject: [PATCH 2/2] Fix tests --- openai/openai.go | 14 +++++++++++++- openai/openai_test.go | 8 ++++---- 2 files changed, 17 insertions(+), 5 deletions(-) diff --git a/openai/openai.go b/openai/openai.go index 6a1b54928..9d0f95e0c 100644 --- a/openai/openai.go +++ b/openai/openai.go @@ -156,7 +156,19 @@ func (tc *ToolChoice) IsRequired() bool { // IsAuto returns true if tool_choice is "auto" or not specified func (tc *ToolChoice) IsAuto() bool { - return tc == nil || tc.Mode == "" || tc.Mode == "auto" + if tc == nil { + return true + } + // If there's an object, check if it's allowed_tools with auto mode + if tc.Object != nil { + // allowed_tools with "auto" mode is still considered auto + if tc.Object.Type == "allowed_tools" && (tc.Object.Mode == "" || tc.Object.Mode == "auto") { + return true + } + // Any other object (forced function, allowed_tools with required) is not auto + return false + } + return tc.Mode == "" || tc.Mode == "auto" } // IsForcedFunction returns true if a specific function is forced diff --git a/openai/openai_test.go b/openai/openai_test.go index b551058e6..850a017d8 100644 --- a/openai/openai_test.go +++ b/openai/openai_test.go @@ -646,7 +646,7 @@ func TestApplyToolChoice(t *testing.T) { Description: "Get weather", Parameters: api.ToolFunctionParameters{ Type: "object", - Properties: map[string]api.Property{ + Properties: map[string]api.ToolProperty{ "location": {Type: []string{"string"}}, }, Required: []string{"location"}, @@ -660,7 +660,7 @@ func TestApplyToolChoice(t *testing.T) { Description: "Search the web", Parameters: api.ToolFunctionParameters{ Type: "object", - Properties: map[string]api.Property{ + Properties: map[string]api.ToolProperty{ "query": {Type: []string{"string"}}, }, Required: []string{"query"}, @@ -806,7 +806,7 @@ func TestFromChatRequest_WithToolChoice(t *testing.T) { Description: "Get weather", Parameters: api.ToolFunctionParameters{ Type: "object", - Properties: map[string]api.Property{ + Properties: map[string]api.ToolProperty{ "location": {Type: []string{"string"}}, }, }, @@ -819,7 +819,7 @@ func TestFromChatRequest_WithToolChoice(t *testing.T) { Description: "Search", Parameters: api.ToolFunctionParameters{ Type: "object", - Properties: map[string]api.Property{ + Properties: map[string]api.ToolProperty{ "query": {Type: []string{"string"}}, }, },