diff --git a/api/types.go b/api/types.go
index 63b898975..79ab7a874 100644
--- a/api/types.go
+++ b/api/types.go
@@ -148,6 +148,13 @@ type ChatRequest struct {
// Tools is an optional list of tools the model has access to.
Tools `json:"tools,omitempty"`
+ // ToolChoice controls how the model uses tools. Can be:
+ // - "auto" (default): model decides whether to call tools
+ // - "none": model won't call any tools
+ // - "required": model must call at least one tool
+ // - ToolChoiceFunction{Name: "func_name"}: model must call this specific function
+ ToolChoice *ToolChoice `json:"tool_choice,omitempty"`
+
// Options lists model-specific options.
Options map[string]any `json:"options"`
@@ -184,6 +191,87 @@ func (t Tools) String() string {
return string(bts)
}
+// ToolChoice controls how the model uses tools.
+// It can be a string ("auto", "none", "required") or a ToolChoiceFunction.
+type ToolChoice struct {
+ // Mode is the tool choice mode: "auto", "none", or "required"
+ Mode string `json:"-"`
+ // Function specifies a specific function to call (when forcing a specific tool)
+ Function *ToolChoiceFunction `json:"-"`
+}
+
+// ToolChoiceFunction specifies a specific function that the model must call.
+type ToolChoiceFunction struct {
+ Name string `json:"name"`
+}
+
+// UnmarshalJSON handles both string and object forms of tool_choice.
+func (tc *ToolChoice) UnmarshalJSON(data []byte) error {
+ // Try string first: "auto", "none", "required"
+ var s string
+ if err := json.Unmarshal(data, &s); err == nil {
+ tc.Mode = s
+ tc.Function = nil
+ return nil
+ }
+
+ // Try object with function name: {"function": {"name": "func_name"}}
+ var obj struct {
+ Function *ToolChoiceFunction `json:"function"`
+ }
+ if err := json.Unmarshal(data, &obj); err == nil && obj.Function != nil {
+ tc.Function = obj.Function
+ tc.Mode = ""
+ return nil
+ }
+
+ // Try simple object with just name: {"name": "func_name"}
+ var simple ToolChoiceFunction
+ if err := json.Unmarshal(data, &simple); err == nil && simple.Name != "" {
+ tc.Function = &simple
+ tc.Mode = ""
+ return nil
+ }
+
+ return fmt.Errorf("invalid tool_choice: must be string or object with function name")
+}
+
+// MarshalJSON serializes ToolChoice back to JSON.
+func (tc ToolChoice) MarshalJSON() ([]byte, error) {
+ if tc.Function != nil {
+ return json.Marshal(map[string]any{"function": tc.Function})
+ }
+ return json.Marshal(tc.Mode)
+}
+
+// IsNone returns true if tool_choice is "none".
+func (tc *ToolChoice) IsNone() bool {
+ return tc != nil && tc.Mode == "none"
+}
+
+// IsRequired returns true if tool_choice is "required".
+func (tc *ToolChoice) IsRequired() bool {
+ return tc != nil && tc.Mode == "required"
+}
+
+// IsAuto returns true if tool_choice is "auto" or not specified.
+func (tc *ToolChoice) IsAuto() bool {
+ return tc == nil || tc.Mode == "" || tc.Mode == "auto"
+}
+
+// IsForcedFunction returns true if a specific function is forced.
+func (tc *ToolChoice) IsForcedFunction() bool {
+ return tc != nil && tc.Function != nil && tc.Function.Name != ""
+}
+
+// GetForcedFunctionName returns the name of the forced function, if any.
+func (tc *ToolChoice) GetForcedFunctionName() string {
+ if tc == nil || tc.Function == nil {
+ return ""
+ }
+ return tc.Function.Name
+}
+
func (t Tool) String() string {
bts, _ := json.Marshal(t)
return string(bts)
diff --git a/api/types_test.go b/api/types_test.go
index da1581f48..360e11f50 100644
--- a/api/types_test.go
+++ b/api/types_test.go
@@ -651,3 +651,170 @@ func TestToolFunctionParameters_String(t *testing.T) {
})
}
}
+
+func TestToolChoice_UnmarshalJSON(t *testing.T) {
+ tests := []struct {
+ name string
+ input string
+ expectedMode string
+ expectedFunc string
+ expectedError bool
+ }{
+ {
+ name: "auto string",
+ input: `{"tool_choice": "auto"}`,
+ expectedMode: "auto",
+ expectedFunc: "",
+ },
+ {
+ name: "none string",
+ input: `{"tool_choice": "none"}`,
+ expectedMode: "none",
+ expectedFunc: "",
+ },
+ {
+ name: "required string",
+ input: `{"tool_choice": "required"}`,
+ expectedMode: "required",
+ expectedFunc: "",
+ },
+ {
+ name: "function object with nested function",
+ input: `{"tool_choice": {"function": {"name": "get_weather"}}}`,
+ expectedMode: "",
+ expectedFunc: "get_weather",
+ },
+ {
+ name: "function object with direct name",
+ input: `{"tool_choice": {"name": "get_weather"}}`,
+ expectedMode: "",
+ expectedFunc: "get_weather",
+ },
+ {
+ name: "unset",
+ input: `{}`,
+ expectedMode: "",
+ expectedFunc: "",
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ var req ChatRequest
+ err := json.Unmarshal([]byte(test.input), &req)
+
+ if test.expectedError {
+ require.Error(t, err)
+ return
+ }
+
+ require.NoError(t, err)
+
+ if test.expectedMode == "" && test.expectedFunc == "" && test.name == "unset" {
+ assert.Nil(t, req.ToolChoice)
+ return
+ }
+
+ require.NotNil(t, req.ToolChoice)
+
+ if test.expectedMode != "" {
+ assert.Equal(t, test.expectedMode, req.ToolChoice.Mode)
+ }
+
+ if test.expectedFunc != "" {
+ require.NotNil(t, req.ToolChoice.Function)
+ assert.Equal(t, test.expectedFunc, req.ToolChoice.Function.Name)
+ }
+ })
+ }
+}
+
+func TestToolChoice_Methods(t *testing.T) {
+ t.Run("IsNone", func(t *testing.T) {
+ tc := &ToolChoice{Mode: "none"}
+ assert.True(t, tc.IsNone())
+
+ tc = &ToolChoice{Mode: "auto"}
+ assert.False(t, tc.IsNone())
+
+ var nilTc *ToolChoice
+ assert.False(t, nilTc.IsNone())
+ })
+
+ t.Run("IsRequired", func(t *testing.T) {
+ tc := &ToolChoice{Mode: "required"}
+ assert.True(t, tc.IsRequired())
+
+ tc = &ToolChoice{Mode: "auto"}
+ assert.False(t, tc.IsRequired())
+
+ var nilTc *ToolChoice
+ assert.False(t, nilTc.IsRequired())
+ })
+
+ t.Run("IsAuto", func(t *testing.T) {
+ tc := &ToolChoice{Mode: "auto"}
+ assert.True(t, tc.IsAuto())
+
+ tc = &ToolChoice{Mode: ""}
+ assert.True(t, tc.IsAuto())
+
+ var nilTc *ToolChoice
+ assert.True(t, nilTc.IsAuto())
+
+ tc = &ToolChoice{Mode: "required"}
+ assert.False(t, tc.IsAuto())
+ })
+
+ t.Run("IsForcedFunction", func(t *testing.T) {
+ tc := &ToolChoice{Function: &ToolChoiceFunction{Name: "get_weather"}}
+ assert.True(t, tc.IsForcedFunction())
+
+ tc = &ToolChoice{Mode: "required"}
+ assert.False(t, tc.IsForcedFunction())
+
+ tc = &ToolChoice{Function: &ToolChoiceFunction{Name: ""}}
+ assert.False(t, tc.IsForcedFunction())
+
+ var nilTc *ToolChoice
+ assert.False(t, nilTc.IsForcedFunction())
+ })
+
+ t.Run("GetForcedFunctionName", func(t *testing.T) {
+ tc := &ToolChoice{Function: &ToolChoiceFunction{Name: "get_weather"}}
+ assert.Equal(t, "get_weather", tc.GetForcedFunctionName())
+
+ tc = &ToolChoice{Mode: "required"}
+ assert.Equal(t, "", tc.GetForcedFunctionName())
+
+ var nilTc *ToolChoice
+ assert.Equal(t, "", nilTc.GetForcedFunctionName())
+ })
+}
+
+func TestToolChoice_MarshalJSON(t *testing.T) {
+ tests := []struct {
+ name string
+ tc ToolChoice
+ expected string
+ }{
+ {
+ name: "mode string",
+ tc: ToolChoice{Mode: "required"},
+ expected: `"required"`,
+ },
+ {
+ name: "function object",
+ tc: ToolChoice{Function: &ToolChoiceFunction{Name: "get_weather"}},
+ expected: `{"function":{"name":"get_weather"}}`,
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ data, err := json.Marshal(test.tc)
+ require.NoError(t, err)
+ assert.Equal(t, test.expected, string(data))
+ })
+ }
+}
diff --git a/docs/api.md b/docs/api.md
index 7c32c9597..65cdd5c5f 100644
--- a/docs/api.md
+++ b/docs/api.md
@@ -493,6 +493,7 @@ Generate the next message in a chat with a provided model. This is a streaming e
- `model`: (required) the [model name](#model-names)
- `messages`: the messages of the chat, this can be used to keep a chat memory
- `tools`: list of tools in JSON for the model to use if supported
+- `tool_choice`: controls how the model uses tools (see [Tool choice](#tool-choice) below)
- `think`: (for thinking models) should the model think before responding?
The `message` object has the following fields:
@@ -519,6 +520,39 @@ Models can also explain the result of the tool call in the response. See the [Ch
[See models with tool calling capabilities](https://ollama.com/search?c=tool).
+### Tool choice
+
+By default, the model will determine when and how many tools to use. You can control this behavior with the `tool_choice` parameter:
+
+- `"auto"` (default): The model decides whether to call zero, one, or multiple tools.
+- `"none"`: The model won't call any tools, even if they are provided.
+- `"required"`: The model must call at least one tool. The output is constrained to produce a valid tool call.
+- `{"function": {"name": "function_name"}}`: The model must call the specified function.
+
+Example with `tool_choice: "required"`:
+
+```json
+{
+ "model": "qwen3",
+ "messages": [{"role": "user", "content": "What is the weather in Paris?"}],
+ "tools": [...],
+ "tool_choice": "required",
+ "stream": false
+}
+```
+
+Example with a forced function:
+
+```json
+{
+ "model": "qwen3",
+ "messages": [{"role": "user", "content": "What is the weather in Paris?"}],
+ "tools": [...],
+ "tool_choice": {"function": {"name": "get_weather"}},
+ "stream": false
+}
+```
+
### Structured outputs
Structured outputs are supported by providing a JSON schema in the `format` parameter. The model will generate a response that matches the schema. See the [Chat request (Structured outputs)](#chat-request-structured-outputs) example below.
diff --git a/docs/api/openai-compatibility.mdx b/docs/api/openai-compatibility.mdx
index a0882053e..a49c8688d 100644
--- a/docs/api/openai-compatibility.mdx
+++ b/docs/api/openai-compatibility.mdx
@@ -103,6 +103,89 @@ curl -X POST http://localhost:11434/v1/responses \
+### Tool calling with `tool_choice`
+
+The `tool_choice` parameter controls how the model uses tools:
+
+
+
+```python tool_choice.py
+from openai import OpenAI
+
+client = OpenAI(
+ base_url='http://localhost:11434/v1/',
+ api_key='ollama',
+)
+
+tools = [
+ {
+ "type": "function",
+ "function": {
+ "name": "get_weather",
+ "description": "Get the weather for a location",
+ "parameters": {
+ "type": "object",
+ "properties": {
+ "location": {"type": "string", "description": "City name"}
+ },
+ "required": ["location"]
+ }
+ }
+ }
+]
+
+# Force the model to call a tool (any tool)
+response = client.chat.completions.create(
+ model="llama3.2",
+ messages=[{"role": "user", "content": "What's the weather in Paris?"}],
+ tools=tools,
+ tool_choice="required" # Must call at least one tool
+)
+
+# Force a specific function
+response = client.chat.completions.create(
+ model="llama3.2",
+ messages=[{"role": "user", "content": "Tell me about Paris"}],
+ tools=tools,
+ tool_choice={"type": "function", "name": "get_weather"} # Must call get_weather
+)
+
+# Disable tool calling
+response = client.chat.completions.create(
+ model="llama3.2",
+ messages=[{"role": "user", "content": "What's the weather?"}],
+ tools=tools,
+ tool_choice="none" # Don't call any tools
+)
+```
+
+```shell tool_choice.sh
+# Force the model to call a tool
+curl http://localhost:11434/v1/chat/completions \
+-H "Content-Type: application/json" \
+-d '{
+ "model": "llama3.2",
+ "messages": [{"role": "user", "content": "What is the weather in Paris?"}],
+ "tools": [{
+ "type": "function",
+ "function": {
+ "name": "get_weather",
+ "description": "Get weather for a location",
+ "parameters": {
+ "type": "object",
+ "properties": {
+ "location": {"type": "string"}
+ },
+ "required": ["location"]
+ }
+ }
+ }],
+ "tool_choice": "required"
+}'
+```
+
+
+
### v1/chat/completions with vision example
@@ -207,7 +290,12 @@ curl -X POST http://localhost:11434/v1/chat/completions \
- [x] `top_p`
- [x] `max_tokens`
- [x] `tools`
-- [ ] `tool_choice`
+- [x] `tool_choice`
+ - [x] `"auto"` (default)
+ - [x] `"none"`
+ - [x] `"required"`
+ - [x] `{"type": "function", "name": "function_name"}`
+ - [x] `{"type": "allowed_tools", "mode": "auto"|"required", "tools": [...]}`
- [ ] `logit_bias`
- [ ] `user`
- [ ] `n`
diff --git a/middleware/openai.go b/middleware/openai.go
index 5e526416e..84a8a6bd7 100644
--- a/middleware/openai.go
+++ b/middleware/openai.go
@@ -20,10 +20,12 @@ type BaseWriter struct {
}
type ChatWriter struct {
- stream bool
- streamOptions *openai.StreamOptions
- id string
- toolCallSent bool
+ stream bool
+ streamOptions *openai.StreamOptions
+ id string
+ toolCallSent bool
+ forcedToolCall bool
+ forcedToolName string
BaseWriter
}
@@ -65,6 +67,40 @@ func (w *BaseWriter) writeError(data []byte) (int, error) {
return len(data), nil
}
+func parseForcedToolCall(content string, forcedToolName string) *api.ToolCall {
+ if content == "" {
+ return nil
+ }
+
+ // Try to parse as tool call structure: {"name": "...", "arguments": {...}}
+ var toolCallJSON struct {
+ Name string `json:"name"`
+ Arguments map[string]any `json:"arguments"`
+ }
+
+ if err := json.Unmarshal([]byte(content), &toolCallJSON); err != nil {
+ return nil
+ }
+
+ // If a specific tool was forced, use that name
+ name := toolCallJSON.Name
+ if forcedToolName != "" {
+ name = forcedToolName
+ }
+
+ if name == "" {
+ return nil
+ }
+
+ return &api.ToolCall{
+ ID: fmt.Sprintf("call_%d", rand.Intn(999999)),
+ Function: api.ToolCallFunction{
+ Name: name,
+ Arguments: toolCallJSON.Arguments,
+ },
+ }
+}
+
func (w *ChatWriter) writeResponse(data []byte) (int, error) {
var chatResponse api.ChatResponse
err := json.Unmarshal(data, &chatResponse)
@@ -72,6 +108,15 @@ func (w *ChatWriter) writeResponse(data []byte) (int, error) {
return 0, err
}
+ // If tool_choice forced a tool call and we have content but no tool calls,
+ // try to parse the content as a tool call
+ if w.forcedToolCall && len(chatResponse.Message.ToolCalls) == 0 && chatResponse.Message.Content != "" {
+ if toolCall := parseForcedToolCall(chatResponse.Message.Content, w.forcedToolName); toolCall != nil {
+ chatResponse.Message.ToolCalls = []api.ToolCall{*toolCall}
+ chatResponse.Message.Content = ""
+ }
+ }
+
// chat chunk
if w.stream {
c := openai.ToChunk(w.id, chatResponse, w.toolCallSent)
@@ -406,6 +451,16 @@ func ChatMiddleware() gin.HandlerFunc {
return
}
+ // Determine if tool_choice forces a tool call
+ var forcedToolCall bool
+ var forcedToolName string
+ if req.ToolChoice != nil && len(req.Tools) > 0 {
+ _, _, forcedToolCall, _ = openai.ApplyToolChoice(req.Tools, req.ToolChoice)
+ if req.ToolChoice.IsForcedFunction() {
+ forcedToolName = req.ToolChoice.GetForcedFunctionName()
+ }
+ }
+
var b bytes.Buffer
chatReq, err := openai.FromChatRequest(req)
@@ -422,10 +477,12 @@ func ChatMiddleware() gin.HandlerFunc {
c.Request.Body = io.NopCloser(&b)
w := &ChatWriter{
- BaseWriter: BaseWriter{ResponseWriter: c.Writer},
- stream: req.Stream,
- id: fmt.Sprintf("chatcmpl-%d", rand.Intn(999)),
- streamOptions: req.StreamOptions,
+ BaseWriter: BaseWriter{ResponseWriter: c.Writer},
+ stream: req.Stream,
+ id: fmt.Sprintf("chatcmpl-%d", rand.Intn(999)),
+ streamOptions: req.StreamOptions,
+ forcedToolCall: forcedToolCall,
+ forcedToolName: forcedToolName,
}
c.Writer = w
diff --git a/openai/openai.go b/openai/openai.go
index 9dcba3000..9d0f95e0c 100644
--- a/openai/openai.go
+++ b/openai/openai.go
@@ -95,6 +95,132 @@ type Reasoning struct {
Effort string `json:"effort,omitempty"`
}
+type ToolChoiceFunctionRef struct {
+ Name string `json:"name"`
+}
+
+type ToolChoiceAllowedTool struct {
+ Type string `json:"type"` // "function"
+ Name string `json:"name"`
+}
+
+type ToolChoiceObject struct {
+ Type string `json:"type,omitempty"`
+ Name string `json:"name,omitempty"`
+ Function *ToolChoiceFunctionRef `json:"function,omitempty"`
+ Mode string `json:"mode,omitempty"`
+ Tools []ToolChoiceAllowedTool `json:"tools,omitempty"`
+}
+
+type ToolChoice struct {
+ Mode string
+ Object *ToolChoiceObject
+}
+
+// UnmarshalJSON handles both string and object forms of tool_choice
+func (tc *ToolChoice) UnmarshalJSON(data []byte) error {
+ // Try string first
+ var s string
+ if err := json.Unmarshal(data, &s); err == nil {
+ tc.Mode = s
+ tc.Object = nil
+ return nil
+ }
+
+ // Try object
+ var obj ToolChoiceObject
+ if err := json.Unmarshal(data, &obj); err != nil {
+ return fmt.Errorf("invalid tool_choice: must be string or object")
+ }
+ tc.Object = &obj
+ return nil
+}
+
+// MarshalJSON serializes ToolChoice back to JSON
+func (tc ToolChoice) MarshalJSON() ([]byte, error) {
+ if tc.Object != nil {
+ return json.Marshal(tc.Object)
+ }
+ return json.Marshal(tc.Mode)
+}
+
+// IsNone returns true if tool_choice is "none"
+func (tc *ToolChoice) IsNone() bool {
+ return tc != nil && tc.Mode == "none"
+}
+
+// IsRequired returns true if tool_choice is "required"
+func (tc *ToolChoice) IsRequired() bool {
+ return tc != nil && tc.Mode == "required"
+}
+
+// IsAuto returns true if tool_choice is "auto" or not specified
+func (tc *ToolChoice) IsAuto() bool {
+ if tc == nil {
+ return true
+ }
+ // If there's an object, check if it's allowed_tools with auto mode
+ if tc.Object != nil {
+ // allowed_tools with "auto" mode is still considered auto
+ if tc.Object.Type == "allowed_tools" && (tc.Object.Mode == "" || tc.Object.Mode == "auto") {
+ return true
+ }
+ // Any other object (forced function, allowed_tools with required) is not auto
+ return false
+ }
+ return tc.Mode == "" || tc.Mode == "auto"
+}
+
+// IsForcedFunction returns true if a specific function is forced
+func (tc *ToolChoice) IsForcedFunction() bool {
+ if tc == nil || tc.Object == nil {
+ return false
+ }
+ return tc.Object.Type == "function" || tc.Object.Name != "" || tc.Object.Function != nil
+}
+
+// GetForcedFunctionName returns the name of the forced function, if any
+func (tc *ToolChoice) GetForcedFunctionName() string {
+ if tc == nil || tc.Object == nil {
+ return ""
+ }
+ if tc.Object.Name != "" {
+ return tc.Object.Name
+ }
+ if tc.Object.Function != nil {
+ return tc.Object.Function.Name
+ }
+ return ""
+}
+
+// IsAllowedTools returns true if tool_choice uses allowed_tools mode
+func (tc *ToolChoice) IsAllowedTools() bool {
+ return tc != nil && tc.Object != nil && tc.Object.Type == "allowed_tools"
+}
+
+// GetAllowedToolNames returns the list of allowed tool names
+func (tc *ToolChoice) GetAllowedToolNames() []string {
+ if !tc.IsAllowedTools() || tc.Object.Tools == nil {
+ return nil
+ }
+ names := make([]string, len(tc.Object.Tools))
+ for i, t := range tc.Object.Tools {
+ names[i] = t.Name
+ }
+ return names
+}
+
+// GetAllowedToolsMode returns the mode for allowed_tools ("auto" or "required")
+func (tc *ToolChoice) GetAllowedToolsMode() string {
+ if !tc.IsAllowedTools() {
+ return ""
+ }
+ if tc.Object.Mode == "" {
+ return "auto"
+ }
+ return tc.Object.Mode
+}
+
type ChatCompletionRequest struct {
Model string `json:"model"`
Messages []Message `json:"messages"`
@@ -109,6 +235,7 @@ type ChatCompletionRequest struct {
TopP *float64 `json:"top_p"`
ResponseFormat *ResponseFormat `json:"response_format"`
Tools []api.Tool `json:"tools"`
+ ToolChoice *ToolChoice `json:"tool_choice,omitempty"`
Reasoning *Reasoning `json:"reasoning,omitempty"`
ReasoningEffort *string `json:"reasoning_effort,omitempty"`
Logprobs *bool `json:"logprobs"`
@@ -444,6 +571,171 @@ func ToModel(r api.ShowResponse, m string) Model {
}
}
+// filterToolsByNames returns only the tools that match the given names
+func filterToolsByNames(tools []api.Tool, names []string) []api.Tool {
+ if len(names) == 0 {
+ return tools
+ }
+ nameSet := make(map[string]bool)
+ for _, name := range names {
+ nameSet[name] = true
+ }
+ var filtered []api.Tool
+ for _, tool := range tools {
+ if nameSet[tool.Function.Name] {
+ filtered = append(filtered, tool)
+ }
+ }
+ return filtered
+}
+
+// findToolByName returns the tool with the given name, or nil if not found
+func findToolByName(tools []api.Tool, name string) *api.Tool {
+ for _, tool := range tools {
+ if tool.Function.Name == name {
+ return &tool
+ }
+ }
+ return nil
+}
+
+// generateToolCallSchema creates a JSON schema that constrains output to valid tool calls
+func generateToolCallSchema(tools []api.Tool) json.RawMessage {
+ if len(tools) == 0 {
+ return nil
+ }
+
+ // Collect all tool names for the enum
+ toolNames := make([]string, len(tools))
+ for i, tool := range tools {
+ toolNames[i] = tool.Function.Name
+ }
+
+ // Build a schema that allows any of the tools
+ // Using oneOf for each tool with its specific parameter schema
+ var oneOfSchemas []map[string]any
+ for _, tool := range tools {
+ toolSchema := map[string]any{
+ "type": "object",
+ "properties": map[string]any{
+ "name": map[string]any{
+ "type": "string",
+ "const": tool.Function.Name,
+ },
+ "arguments": tool.Function.Parameters,
+ },
+ "required": []string{"name", "arguments"},
+ "additionalProperties": false,
+ }
+ oneOfSchemas = append(oneOfSchemas, toolSchema)
+ }
+
+ var schema map[string]any
+ if len(oneOfSchemas) == 1 {
+ // Single tool - use its schema directly
+ schema = oneOfSchemas[0]
+ } else {
+ // Multiple tools - use oneOf
+ schema = map[string]any{
+ "oneOf": oneOfSchemas,
+ }
+ }
+
+ bytes, err := json.Marshal(schema)
+ if err != nil {
+ slog.Error("failed to marshal tool call schema", "error", err)
+ return nil
+ }
+ return bytes
+}
+
+// generateForcedFunctionSchema creates a JSON schema for a specific forced function
+func generateForcedFunctionSchema(tool api.Tool) json.RawMessage {
+ schema := map[string]any{
+ "type": "object",
+ "properties": map[string]any{
+ "name": map[string]any{
+ "type": "string",
+ "const": tool.Function.Name,
+ },
+ "arguments": tool.Function.Parameters,
+ },
+ "required": []string{"name", "arguments"},
+ "additionalProperties": false,
+ }
+
+ bytes, err := json.Marshal(schema)
+ if err != nil {
+ slog.Error("failed to marshal forced function schema", "error", err)
+ return nil
+ }
+ return bytes
+}
+
+// ApplyToolChoice processes tool_choice and returns filtered tools and optional format schema
+// Returns:
+// - filteredTools: the tools to pass to the model (may be empty for "none")
+// - format: JSON schema to constrain output (for "required" or forced function)
+// - forcedToolCall: true if the response should be parsed as a tool call
+// - error: if tool_choice references a non-existent tool
+func ApplyToolChoice(tools []api.Tool, toolChoice *ToolChoice) (filteredTools []api.Tool, format json.RawMessage, forcedToolCall bool, err error) {
+ // Default: auto mode, return all tools without format constraint
+ if toolChoice == nil || toolChoice.IsAuto() {
+ // Check for allowed_tools with auto mode
+ if toolChoice != nil && toolChoice.IsAllowedTools() {
+ allowedNames := toolChoice.GetAllowedToolNames()
+ filteredTools = filterToolsByNames(tools, allowedNames)
+ if toolChoice.GetAllowedToolsMode() == "required" {
+ format = generateToolCallSchema(filteredTools)
+ forcedToolCall = true
+ }
+ return filteredTools, format, forcedToolCall, nil
+ }
+ return tools, nil, false, nil
+ }
+
+ // "none" mode: don't pass any tools
+ if toolChoice.IsNone() {
+ return nil, nil, false, nil
+ }
+
+ // "required" mode: must call at least one tool
+ if toolChoice.IsRequired() {
+ format = generateToolCallSchema(tools)
+ return tools, format, true, nil
+ }
+
+ // Forced function mode
+ if toolChoice.IsForcedFunction() {
+ funcName := toolChoice.GetForcedFunctionName()
+ if funcName == "" {
+ return nil, nil, false, errors.New("tool_choice function name is required")
+ }
+
+ tool := findToolByName(tools, funcName)
+ if tool == nil {
+ return nil, nil, false, fmt.Errorf("tool_choice references unknown function: %s", funcName)
+ }
+
+ format = generateForcedFunctionSchema(*tool)
+ return []api.Tool{*tool}, format, true, nil
+ }
+
+ // allowed_tools mode (already handled in IsAuto check, but handle explicit type here)
+ if toolChoice.IsAllowedTools() {
+ allowedNames := toolChoice.GetAllowedToolNames()
+ filteredTools = filterToolsByNames(tools, allowedNames)
+ if toolChoice.GetAllowedToolsMode() == "required" {
+ format = generateToolCallSchema(filteredTools)
+ forcedToolCall = true
+ }
+ return filteredTools, format, forcedToolCall, nil
+ }
+
+ // Unknown mode, default to auto
+ return tools, nil, false, nil
+}
+
// FromChatRequest converts a ChatCompletionRequest to api.ChatRequest
func FromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
var messages []api.Message
@@ -579,6 +871,18 @@ func FromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
}
}
+ // Apply tool_choice to filter tools and potentially set format constraint
+ filteredTools, toolChoiceFormat, _, err := ApplyToolChoice(r.Tools, r.ToolChoice)
+ if err != nil {
+ return nil, err
+ }
+
+ // If tool_choice requires a format constraint and no explicit response_format was set,
+ // apply the tool call schema
+ if toolChoiceFormat != nil && format == nil {
+ format = toolChoiceFormat
+ }
+
var think *api.ThinkValue
var effort string
@@ -606,7 +910,7 @@ func FromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
Format: format,
Options: options,
Stream: &r.Stream,
- Tools: r.Tools,
+ Tools: filteredTools,
Think: think,
Logprobs: r.Logprobs != nil && *r.Logprobs,
TopLogprobs: r.TopLogprobs,
diff --git a/openai/openai_test.go b/openai/openai_test.go
index 51e243dec..850a017d8 100644
--- a/openai/openai_test.go
+++ b/openai/openai_test.go
@@ -434,3 +434,480 @@ func TestFromChatRequest_TopLogprobsRange(t *testing.T) {
})
}
}
+
+func TestToolChoice_UnmarshalJSON(t *testing.T) {
+ tests := []struct {
+ name string
+ json string
+ wantMode string
+ wantObj bool
+ }{
+ {
+ name: "string auto",
+ json: `"auto"`,
+ wantMode: "auto",
+ wantObj: false,
+ },
+ {
+ name: "string none",
+ json: `"none"`,
+ wantMode: "none",
+ wantObj: false,
+ },
+ {
+ name: "string required",
+ json: `"required"`,
+ wantMode: "required",
+ wantObj: false,
+ },
+ {
+ name: "object function with name",
+ json: `{"type": "function", "name": "get_weather"}`,
+ wantMode: "",
+ wantObj: true,
+ },
+ {
+ name: "object function with function.name",
+ json: `{"type": "function", "function": {"name": "get_weather"}}`,
+ wantMode: "",
+ wantObj: true,
+ },
+ {
+ name: "object allowed_tools",
+ json: `{"type": "allowed_tools", "mode": "required", "tools": [{"type": "function", "name": "get_weather"}]}`,
+ wantMode: "",
+ wantObj: true,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ var tc ToolChoice
+ if err := tc.UnmarshalJSON([]byte(tt.json)); err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+
+ if tc.Mode != tt.wantMode {
+ t.Errorf("Mode = %q, want %q", tc.Mode, tt.wantMode)
+ }
+
+ if (tc.Object != nil) != tt.wantObj {
+ t.Errorf("Object = %v, wantObj = %v", tc.Object, tt.wantObj)
+ }
+ })
+ }
+}
+
+func TestToolChoice_Methods(t *testing.T) {
+ tests := []struct {
+ name string
+ toolChoice *ToolChoice
+ isNone bool
+ isRequired bool
+ isAuto bool
+ isForcedFunction bool
+ forcedFuncName string
+ isAllowedTools bool
+ allowedToolNames []string
+ allowedToolsMode string
+ }{
+ {
+ name: "nil",
+ toolChoice: nil,
+ isNone: false,
+ isRequired: false,
+ isAuto: true,
+ },
+ {
+ name: "auto string",
+ toolChoice: &ToolChoice{Mode: "auto"},
+ isNone: false,
+ isRequired: false,
+ isAuto: true,
+ },
+ {
+ name: "none string",
+ toolChoice: &ToolChoice{Mode: "none"},
+ isNone: true,
+ isRequired: false,
+ isAuto: false,
+ },
+ {
+ name: "required string",
+ toolChoice: &ToolChoice{Mode: "required"},
+ isNone: false,
+ isRequired: true,
+ isAuto: false,
+ },
+ {
+ name: "forced function with name",
+ toolChoice: &ToolChoice{
+ Object: &ToolChoiceObject{Type: "function", Name: "get_weather"},
+ },
+ isNone: false,
+ isRequired: false,
+ isAuto: false,
+ isForcedFunction: true,
+ forcedFuncName: "get_weather",
+ },
+ {
+ name: "forced function with function.name",
+ toolChoice: &ToolChoice{
+ Object: &ToolChoiceObject{
+ Type: "function",
+ Function: &ToolChoiceFunctionRef{Name: "search"},
+ },
+ },
+ isNone: false,
+ isRequired: false,
+ isAuto: false,
+ isForcedFunction: true,
+ forcedFuncName: "search",
+ },
+ {
+ name: "allowed_tools auto",
+ toolChoice: &ToolChoice{
+ Object: &ToolChoiceObject{
+ Type: "allowed_tools",
+ Mode: "auto",
+ Tools: []ToolChoiceAllowedTool{
+ {Type: "function", Name: "get_weather"},
+ {Type: "function", Name: "search"},
+ },
+ },
+ },
+ isNone: false,
+ isRequired: false,
+ isAuto: true,
+ isForcedFunction: false,
+ isAllowedTools: true,
+ allowedToolNames: []string{"get_weather", "search"},
+ allowedToolsMode: "auto",
+ },
+ {
+ name: "allowed_tools required",
+ toolChoice: &ToolChoice{
+ Object: &ToolChoiceObject{
+ Type: "allowed_tools",
+ Mode: "required",
+ Tools: []ToolChoiceAllowedTool{
+ {Type: "function", Name: "get_weather"},
+ },
+ },
+ },
+ isNone: false,
+ isRequired: false,
+ isAuto: false,
+ isForcedFunction: false,
+ isAllowedTools: true,
+ allowedToolNames: []string{"get_weather"},
+ allowedToolsMode: "required",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ if got := tt.toolChoice.IsNone(); got != tt.isNone {
+ t.Errorf("IsNone() = %v, want %v", got, tt.isNone)
+ }
+ if got := tt.toolChoice.IsRequired(); got != tt.isRequired {
+ t.Errorf("IsRequired() = %v, want %v", got, tt.isRequired)
+ }
+ if got := tt.toolChoice.IsAuto(); got != tt.isAuto {
+ t.Errorf("IsAuto() = %v, want %v", got, tt.isAuto)
+ }
+ if got := tt.toolChoice.IsForcedFunction(); got != tt.isForcedFunction {
+ t.Errorf("IsForcedFunction() = %v, want %v", got, tt.isForcedFunction)
+ }
+ if got := tt.toolChoice.GetForcedFunctionName(); got != tt.forcedFuncName {
+ t.Errorf("GetForcedFunctionName() = %q, want %q", got, tt.forcedFuncName)
+ }
+ if got := tt.toolChoice.IsAllowedTools(); got != tt.isAllowedTools {
+ t.Errorf("IsAllowedTools() = %v, want %v", got, tt.isAllowedTools)
+ }
+ if tt.isAllowedTools {
+ if got := tt.toolChoice.GetAllowedToolNames(); !cmp.Equal(got, tt.allowedToolNames) {
+ t.Errorf("GetAllowedToolNames() = %v, want %v", got, tt.allowedToolNames)
+ }
+ if got := tt.toolChoice.GetAllowedToolsMode(); got != tt.allowedToolsMode {
+ t.Errorf("GetAllowedToolsMode() = %q, want %q", got, tt.allowedToolsMode)
+ }
+ }
+ })
+ }
+}
+
+func TestApplyToolChoice(t *testing.T) {
+ tools := []api.Tool{
+ {
+ Type: "function",
+ Function: api.ToolFunction{
+ Name: "get_weather",
+ Description: "Get weather",
+ Parameters: api.ToolFunctionParameters{
+ Type: "object",
+ Properties: map[string]api.ToolProperty{
+ "location": {Type: []string{"string"}},
+ },
+ Required: []string{"location"},
+ },
+ },
+ },
+ {
+ Type: "function",
+ Function: api.ToolFunction{
+ Name: "search",
+ Description: "Search the web",
+ Parameters: api.ToolFunctionParameters{
+ Type: "object",
+ Properties: map[string]api.ToolProperty{
+ "query": {Type: []string{"string"}},
+ },
+ Required: []string{"query"},
+ },
+ },
+ },
+ }
+
+ tests := []struct {
+ name string
+ toolChoice *ToolChoice
+ wantToolCount int
+ wantFormat bool
+ wantForced bool
+ wantError bool
+ wantToolNames []string
+ }{
+ {
+ name: "nil (auto)",
+ toolChoice: nil,
+ wantToolCount: 2,
+ wantFormat: false,
+ wantForced: false,
+ },
+ {
+ name: "auto string",
+ toolChoice: &ToolChoice{Mode: "auto"},
+ wantToolCount: 2,
+ wantFormat: false,
+ wantForced: false,
+ },
+ {
+ name: "none string",
+ toolChoice: &ToolChoice{Mode: "none"},
+ wantToolCount: 0,
+ wantFormat: false,
+ wantForced: false,
+ },
+ {
+ name: "required string",
+ toolChoice: &ToolChoice{Mode: "required"},
+ wantToolCount: 2,
+ wantFormat: true,
+ wantForced: true,
+ },
+ {
+ name: "forced function",
+ toolChoice: &ToolChoice{
+ Object: &ToolChoiceObject{Type: "function", Name: "get_weather"},
+ },
+ wantToolCount: 1,
+ wantFormat: true,
+ wantForced: true,
+ wantToolNames: []string{"get_weather"},
+ },
+ {
+ name: "forced unknown function",
+ toolChoice: &ToolChoice{
+ Object: &ToolChoiceObject{Type: "function", Name: "unknown_func"},
+ },
+ wantError: true,
+ },
+ {
+ name: "allowed_tools auto",
+ toolChoice: &ToolChoice{
+ Object: &ToolChoiceObject{
+ Type: "allowed_tools",
+ Mode: "auto",
+ Tools: []ToolChoiceAllowedTool{
+ {Type: "function", Name: "get_weather"},
+ },
+ },
+ },
+ wantToolCount: 1,
+ wantFormat: false,
+ wantForced: false,
+ wantToolNames: []string{"get_weather"},
+ },
+ {
+ name: "allowed_tools required",
+ toolChoice: &ToolChoice{
+ Object: &ToolChoiceObject{
+ Type: "allowed_tools",
+ Mode: "required",
+ Tools: []ToolChoiceAllowedTool{
+ {Type: "function", Name: "search"},
+ },
+ },
+ },
+ wantToolCount: 1,
+ wantFormat: true,
+ wantForced: true,
+ wantToolNames: []string{"search"},
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ filteredTools, format, forced, err := ApplyToolChoice(tools, tt.toolChoice)
+
+ if tt.wantError {
+ if err == nil {
+ t.Errorf("expected error, got nil")
+ }
+ return
+ }
+
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+
+ if len(filteredTools) != tt.wantToolCount {
+ t.Errorf("got %d tools, want %d", len(filteredTools), tt.wantToolCount)
+ }
+
+ if (format != nil) != tt.wantFormat {
+ t.Errorf("format = %v, wantFormat = %v", format != nil, tt.wantFormat)
+ }
+
+ if forced != tt.wantForced {
+ t.Errorf("forced = %v, want %v", forced, tt.wantForced)
+ }
+
+ if tt.wantToolNames != nil {
+ gotNames := make([]string, len(filteredTools))
+ for i, tool := range filteredTools {
+ gotNames[i] = tool.Function.Name
+ }
+ if !cmp.Equal(gotNames, tt.wantToolNames) {
+ t.Errorf("tool names = %v, want %v", gotNames, tt.wantToolNames)
+ }
+ }
+ })
+ }
+}
+
+func TestFromChatRequest_WithToolChoice(t *testing.T) {
+ tools := []api.Tool{
+ {
+ Type: "function",
+ Function: api.ToolFunction{
+ Name: "get_weather",
+ Description: "Get weather",
+ Parameters: api.ToolFunctionParameters{
+ Type: "object",
+ Properties: map[string]api.ToolProperty{
+ "location": {Type: []string{"string"}},
+ },
+ },
+ },
+ },
+ {
+ Type: "function",
+ Function: api.ToolFunction{
+ Name: "search",
+ Description: "Search",
+ Parameters: api.ToolFunctionParameters{
+ Type: "object",
+ Properties: map[string]api.ToolProperty{
+ "query": {Type: []string{"string"}},
+ },
+ },
+ },
+ },
+ }
+
+ t.Run("tool_choice none removes tools", func(t *testing.T) {
+ req := ChatCompletionRequest{
+ Model: "test-model",
+ Messages: []Message{{Role: "user", Content: "Hello"}},
+ Tools: tools,
+ ToolChoice: &ToolChoice{Mode: "none"},
+ }
+
+ result, err := FromChatRequest(req)
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+
+ if len(result.Tools) != 0 {
+ t.Errorf("expected 0 tools, got %d", len(result.Tools))
+ }
+ })
+
+ t.Run("tool_choice required adds format", func(t *testing.T) {
+ req := ChatCompletionRequest{
+ Model: "test-model",
+ Messages: []Message{{Role: "user", Content: "Hello"}},
+ Tools: tools,
+ ToolChoice: &ToolChoice{Mode: "required"},
+ }
+
+ result, err := FromChatRequest(req)
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+
+ if len(result.Tools) != 2 {
+ t.Errorf("expected 2 tools, got %d", len(result.Tools))
+ }
+
+ if result.Format == nil {
+ t.Error("expected format to be set for required tool_choice")
+ }
+ })
+
+ t.Run("tool_choice forced function filters and adds format", func(t *testing.T) {
+ req := ChatCompletionRequest{
+ Model: "test-model",
+ Messages: []Message{{Role: "user", Content: "Hello"}},
+ Tools: tools,
+ ToolChoice: &ToolChoice{
+ Object: &ToolChoiceObject{Type: "function", Name: "get_weather"},
+ },
+ }
+
+ result, err := FromChatRequest(req)
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+
+ if len(result.Tools) != 1 {
+ t.Errorf("expected 1 tool, got %d", len(result.Tools))
+ }
+
+ if result.Tools[0].Function.Name != "get_weather" {
+ t.Errorf("expected tool 'get_weather', got %q", result.Tools[0].Function.Name)
+ }
+
+ if result.Format == nil {
+ t.Error("expected format to be set for forced function")
+ }
+ })
+
+ t.Run("tool_choice unknown function returns error", func(t *testing.T) {
+ req := ChatCompletionRequest{
+ Model: "test-model",
+ Messages: []Message{{Role: "user", Content: "Hello"}},
+ Tools: tools,
+ ToolChoice: &ToolChoice{
+ Object: &ToolChoiceObject{Type: "function", Name: "unknown"},
+ },
+ }
+
+ _, err := FromChatRequest(req)
+ if err == nil {
+ t.Error("expected error for unknown function")
+ }
+ })
+}
diff --git a/server/routes.go b/server/routes.go
index 977a13ff2..5de1882cf 100644
--- a/server/routes.go
+++ b/server/routes.go
@@ -1861,6 +1861,112 @@ func toolCallId() string {
return "call_" + strings.ToLower(string(b))
}
+// findToolByName returns the tool with the given name, or nil if not found.
+func findToolByName(tools []api.Tool, name string) *api.Tool {
+ for i := range tools {
+ if tools[i].Function.Name == name {
+ return &tools[i]
+ }
+ }
+ return nil
+}
+
+// generateToolCallSchema creates a JSON schema that constrains output to valid tool calls.
+func generateToolCallSchema(tools []api.Tool) json.RawMessage {
+ if len(tools) == 0 {
+ return nil
+ }
+
+ var oneOfSchemas []map[string]any
+ for _, tool := range tools {
+ toolSchema := map[string]any{
+ "type": "object",
+ "properties": map[string]any{
+ "name": map[string]any{
+ "type": "string",
+ "const": tool.Function.Name,
+ },
+ "arguments": tool.Function.Parameters,
+ },
+ "required": []string{"name", "arguments"},
+ "additionalProperties": false,
+ }
+ oneOfSchemas = append(oneOfSchemas, toolSchema)
+ }
+
+ var schema map[string]any
+ if len(oneOfSchemas) == 1 {
+ schema = oneOfSchemas[0]
+ } else {
+ schema = map[string]any{
+ "oneOf": oneOfSchemas,
+ }
+ }
+
+ bytes, err := json.Marshal(schema)
+ if err != nil {
+ slog.Error("failed to marshal tool call schema", "error", err)
+ return nil
+ }
+ return bytes
+}
+
+// generateForcedFunctionSchema creates a JSON schema for a specific forced function.
+func generateForcedFunctionSchema(tool api.Tool) json.RawMessage {
+ schema := map[string]any{
+ "type": "object",
+ "properties": map[string]any{
+ "name": map[string]any{
+ "type": "string",
+ "const": tool.Function.Name,
+ },
+ "arguments": tool.Function.Parameters,
+ },
+ "required": []string{"name", "arguments"},
+ "additionalProperties": false,
+ }
+
+ bytes, err := json.Marshal(schema)
+ if err != nil {
+ slog.Error("failed to marshal forced function schema", "error", err)
+ return nil
+ }
+ return bytes
+}
+
+// parseForcedToolCallContent parses content as a forced tool call JSON.
+func parseForcedToolCallContent(content string, forcedToolName string) *api.ToolCall {
+ if content == "" {
+ return nil
+ }
+
+ var toolCallJSON struct {
+ Name string `json:"name"`
+ Arguments map[string]any `json:"arguments"`
+ }
+
+ if err := json.Unmarshal([]byte(content), &toolCallJSON); err != nil {
+ return nil
+ }
+
+ name := toolCallJSON.Name
+ if forcedToolName != "" {
+ name = forcedToolName
+ }
+
+ if name == "" {
+ return nil
+ }
+
+ return &api.ToolCall{
+ ID: toolCallId(),
+ Function: api.ToolCallFunction{
+ Name: name,
+ Arguments: toolCallJSON.Arguments,
+ },
+ }
+}
+
func (s *Server) ChatHandler(c *gin.Context) {
checkpointStart := time.Now()
@@ -2075,6 +2181,36 @@ func (s *Server) ChatHandler(c *gin.Context) {
}
}
+ // Handle tool_choice
+ var forcedToolCall bool
+ var forcedToolName string
+ if req.ToolChoice != nil && len(req.Tools) > 0 {
+ if req.ToolChoice.IsNone() {
+ // "none" mode: don't pass any tools
+ processedTools = nil
+ } else if req.ToolChoice.IsRequired() {
+ // "required" mode: must call at least one tool, generate JSON schema
+ if req.Format == nil {
+ req.Format = generateToolCallSchema(req.Tools)
+ }
+ forcedToolCall = true
+ } else if req.ToolChoice.IsForcedFunction() {
+ // Forced function mode: must call this specific function
+ funcName := req.ToolChoice.GetForcedFunctionName()
+ tool := findToolByName(req.Tools, funcName)
+ if tool == nil {
+ c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("tool_choice references unknown function: %s", funcName)})
+ return
+ }
+ if req.Format == nil {
+ req.Format = generateForcedFunctionSchema(*tool)
+ }
+ processedTools = []api.Tool{*tool}
+ forcedToolCall = true
+ forcedToolName = funcName
+ }
+ }
+
truncate := req.Truncate == nil || *req.Truncate
prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, msgs, processedTools, req.Think, truncate)
if err != nil {
@@ -2206,6 +2342,15 @@ func (s *Server) ChatHandler(c *gin.Context) {
return
}
+ // If tool_choice forced a tool call and builtinParser didn't find any,
+ // try to parse the content as a forced tool call JSON
+ if r.Done && forcedToolCall && len(res.Message.ToolCalls) == 0 && res.Message.Content != "" {
+ if toolCall := parseForcedToolCallContent(res.Message.Content, forcedToolName); toolCall != nil {
+ res.Message.ToolCalls = []api.ToolCall{*toolCall}
+ res.Message.Content = ""
+ }
+ }
+
if res.Message.Content != "" || res.Message.Thinking != "" || len(res.Message.ToolCalls) > 0 || r.Done || len(res.Logprobs) > 0 {
slog.Log(context.TODO(), logutil.LevelTrace, "builtin parser output", "parser", m.Config.Parser, "content", content, "thinking", thinking, "toolCalls", toolCalls, "done", r.Done)
ch <- res
@@ -2235,7 +2380,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
res.Message.Content = remainingContent
}
- if len(req.Tools) > 0 {
+ if len(req.Tools) > 0 && toolParser != nil {
toolCalls, content := toolParser.Add(res.Message.Content)
if len(content) > 0 {
res.Message.Content = content
@@ -2258,12 +2403,28 @@ func (s *Server) ChatHandler(c *gin.Context) {
if r.Done {
res.Message.Content = toolParser.Content()
+ // If tool_choice forced a tool call, try to parse the buffered content
+ if forcedToolCall && len(res.Message.ToolCalls) == 0 && res.Message.Content != "" {
+ if toolCall := parseForcedToolCallContent(res.Message.Content, forcedToolName); toolCall != nil {
+ res.Message.ToolCalls = []api.ToolCall{*toolCall}
+ res.Message.Content = ""
+ }
+ }
ch <- res
}
return
}
}
+ // If tool_choice forced a tool call and we have content but no tool calls,
+ // try to parse the content as a tool call (used when format constraint was applied)
+ if r.Done && forcedToolCall && len(res.Message.ToolCalls) == 0 && res.Message.Content != "" {
+ if toolCall := parseForcedToolCallContent(res.Message.Content, forcedToolName); toolCall != nil {
+ res.Message.ToolCalls = []api.ToolCall{*toolCall}
+ res.Message.Content = ""
+ }
+ }
+
ch <- res
})
if err != nil {
@@ -2353,6 +2514,18 @@ func (s *Server) ChatHandler(c *gin.Context) {
if len(toolCalls) > 0 {
resp.Message.ToolCalls = toolCalls
+ // If we have tool calls from forced tool_choice, the "content" was actually
+ // the JSON that got parsed into tool calls, so clear it
+ if forcedToolCall {
+ resp.Message.Content = ""
+ }
+ } else if forcedToolCall && resp.Message.Content != "" {
+ // No tool calls were parsed in callbacks, but we have forced tool_choice.
+ // Try to parse the accumulated content as a tool call JSON.
+ if toolCall := parseForcedToolCallContent(resp.Message.Content, forcedToolName); toolCall != nil {
+ resp.Message.ToolCalls = []api.ToolCall{*toolCall}
+ resp.Message.Content = ""
+ }
}
c.JSON(http.StatusOK, resp)
diff --git a/server/routes_test.go b/server/routes_test.go
index e470b9384..10c03b255 100644
--- a/server/routes_test.go
+++ b/server/routes_test.go
@@ -978,3 +978,227 @@ func TestWaitForStream(t *testing.T) {
})
}
}
+
+func TestFindToolByName(t *testing.T) {
+ tools := []api.Tool{
+ {
+ Type: "function",
+ Function: api.ToolFunction{
+ Name: "get_weather",
+ Description: "Get weather for a location",
+ },
+ },
+ {
+ Type: "function",
+ Function: api.ToolFunction{
+ Name: "search_web",
+ Description: "Search the web",
+ },
+ },
+ }
+
+ t.Run("found", func(t *testing.T) {
+ tool := findToolByName(tools, "get_weather")
+ if tool == nil {
+ t.Fatal("expected to find tool")
+ }
+ if tool.Function.Name != "get_weather" {
+ t.Errorf("expected get_weather, got %s", tool.Function.Name)
+ }
+ })
+
+ t.Run("not found", func(t *testing.T) {
+ tool := findToolByName(tools, "nonexistent")
+ if tool != nil {
+ t.Error("expected nil for nonexistent tool")
+ }
+ })
+
+ t.Run("empty tools", func(t *testing.T) {
+ tool := findToolByName([]api.Tool{}, "get_weather")
+ if tool != nil {
+ t.Error("expected nil for empty tools")
+ }
+ })
+}
+
+func TestGenerateToolCallSchema(t *testing.T) {
+ tools := []api.Tool{
+ {
+ Type: "function",
+ Function: api.ToolFunction{
+ Name: "get_weather",
+ Description: "Get weather",
+ Parameters: api.ToolFunctionParameters{
+ Type: "object",
+ Required: []string{"location"},
+ Properties: map[string]api.ToolProperty{
+ "location": {Type: api.PropertyType{"string"}},
+ },
+ },
+ },
+ },
+ }
+
+ t.Run("single tool", func(t *testing.T) {
+ schema := generateToolCallSchema(tools)
+ if schema == nil {
+ t.Fatal("expected schema, got nil")
+ }
+
+ var parsed map[string]any
+ if err := json.Unmarshal(schema, &parsed); err != nil {
+ t.Fatalf("failed to parse schema: %v", err)
+ }
+
+ // Should have properties with name and arguments
+ props, ok := parsed["properties"].(map[string]any)
+ if !ok {
+ t.Fatal("expected properties in schema")
+ }
+
+ if _, ok := props["name"]; !ok {
+ t.Error("expected name property")
+ }
+ if _, ok := props["arguments"]; !ok {
+ t.Error("expected arguments property")
+ }
+ })
+
+ t.Run("multiple tools", func(t *testing.T) {
+ multiTools := append(tools, api.Tool{
+ Type: "function",
+ Function: api.ToolFunction{
+ Name: "search_web",
+ },
+ })
+
+ schema := generateToolCallSchema(multiTools)
+ if schema == nil {
+ t.Fatal("expected schema, got nil")
+ }
+
+ var parsed map[string]any
+ if err := json.Unmarshal(schema, &parsed); err != nil {
+ t.Fatalf("failed to parse schema: %v", err)
+ }
+
+ // Should have oneOf for multiple tools
+ if _, ok := parsed["oneOf"]; !ok {
+ t.Error("expected oneOf for multiple tools")
+ }
+ })
+
+ t.Run("empty tools", func(t *testing.T) {
+ schema := generateToolCallSchema([]api.Tool{})
+ if schema != nil {
+ t.Error("expected nil for empty tools")
+ }
+ })
+}
+
+func TestGenerateForcedFunctionSchema(t *testing.T) {
+ tool := api.Tool{
+ Type: "function",
+ Function: api.ToolFunction{
+ Name: "get_weather",
+ Description: "Get weather",
+ Parameters: api.ToolFunctionParameters{
+ Type: "object",
+ Required: []string{"location"},
+ Properties: map[string]api.ToolProperty{
+ "location": {Type: api.PropertyType{"string"}},
+ },
+ },
+ },
+ }
+
+ schema := generateForcedFunctionSchema(tool)
+ if schema == nil {
+ t.Fatal("expected schema, got nil")
+ }
+
+ var parsed map[string]any
+ if err := json.Unmarshal(schema, &parsed); err != nil {
+ t.Fatalf("failed to parse schema: %v", err)
+ }
+
+ // Check that name has const constraint
+ props := parsed["properties"].(map[string]any)
+ nameSchema := props["name"].(map[string]any)
+ if nameSchema["const"] != "get_weather" {
+ t.Errorf("expected const get_weather, got %v", nameSchema["const"])
+ }
+}
+
+func TestParseForcedToolCallContent(t *testing.T) {
+ t.Run("valid tool call JSON", func(t *testing.T) {
+ content := `{"name": "get_weather", "arguments": {"location": "Paris"}}`
+ toolCall := parseForcedToolCallContent(content, "")
+
+ if toolCall == nil {
+ t.Fatal("expected tool call, got nil")
+ }
+
+ if toolCall.Function.Name != "get_weather" {
+ t.Errorf("expected get_weather, got %s", toolCall.Function.Name)
+ }
+
+ if toolCall.Function.Arguments["location"] != "Paris" {
+ t.Errorf("expected Paris, got %v", toolCall.Function.Arguments["location"])
+ }
+
+ if toolCall.ID == "" {
+ t.Error("expected non-empty tool call ID")
+ }
+ })
+
+ t.Run("forced tool name override", func(t *testing.T) {
+ content := `{"name": "other_tool", "arguments": {"location": "Paris"}}`
+ toolCall := parseForcedToolCallContent(content, "get_weather")
+
+ if toolCall == nil {
+ t.Fatal("expected tool call, got nil")
+ }
+
+ // Should use the forced name, not the one in JSON
+ if toolCall.Function.Name != "get_weather" {
+ t.Errorf("expected get_weather (forced), got %s", toolCall.Function.Name)
+ }
+ })
+
+ t.Run("empty content", func(t *testing.T) {
+ toolCall := parseForcedToolCallContent("", "")
+ if toolCall != nil {
+ t.Error("expected nil for empty content")
+ }
+ })
+
+ t.Run("invalid JSON", func(t *testing.T) {
+ toolCall := parseForcedToolCallContent("not json", "")
+ if toolCall != nil {
+ t.Error("expected nil for invalid JSON")
+ }
+ })
+
+ t.Run("missing name", func(t *testing.T) {
+ content := `{"arguments": {"location": "Paris"}}`
+ toolCall := parseForcedToolCallContent(content, "")
+ if toolCall != nil {
+ t.Error("expected nil when name is missing and not forced")
+ }
+ })
+
+ t.Run("missing name but forced", func(t *testing.T) {
+ content := `{"arguments": {"location": "Paris"}}`
+ toolCall := parseForcedToolCallContent(content, "get_weather")
+
+ if toolCall == nil {
+ t.Fatal("expected tool call with forced name")
+ }
+
+ if toolCall.Function.Name != "get_weather" {
+ t.Errorf("expected get_weather, got %s", toolCall.Function.Name)
+ }
+ })
+}