This commit is contained in:
Baptiste Jamin 2026-01-05 11:24:12 +00:00 committed by GitHub
commit 1af0c84bbd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 1623 additions and 11 deletions

View File

@ -148,6 +148,13 @@ type ChatRequest struct {
// Tools is an optional list of tools the model has access to. // Tools is an optional list of tools the model has access to.
Tools `json:"tools,omitempty"` Tools `json:"tools,omitempty"`
// ToolChoice controls how the model uses tools. Can be:
// - "auto" (default): model decides whether to call tools
// - "none": model won't call any tools
// - "required": model must call at least one tool
// - ToolChoiceFunction{Name: "func_name"}: model must call this specific function
ToolChoice *ToolChoice `json:"tool_choice,omitempty"`
// Options lists model-specific options. // Options lists model-specific options.
Options map[string]any `json:"options"` Options map[string]any `json:"options"`
@ -184,6 +191,87 @@ func (t Tools) String() string {
return string(bts) return string(bts)
} }
// ToolChoice controls how the model uses tools.
// It can be a string ("auto", "none", "required") or a ToolChoiceFunction.
type ToolChoice struct {
// Mode is the tool choice mode: "auto", "none", or "required"
Mode string `json:"-"`
// Function specifies a specific function to call (when forcing a specific tool)
Function *ToolChoiceFunction `json:"-"`
}
// ToolChoiceFunction specifies a specific function that the model must call.
type ToolChoiceFunction struct {
Name string `json:"name"`
}
// UnmarshalJSON handles both string and object forms of tool_choice.
func (tc *ToolChoice) UnmarshalJSON(data []byte) error {
// Try string first: "auto", "none", "required"
var s string
if err := json.Unmarshal(data, &s); err == nil {
tc.Mode = s
tc.Function = nil
return nil
}
// Try object with function name: {"function": {"name": "func_name"}}
var obj struct {
Function *ToolChoiceFunction `json:"function"`
}
if err := json.Unmarshal(data, &obj); err == nil && obj.Function != nil {
tc.Function = obj.Function
tc.Mode = ""
return nil
}
// Try simple object with just name: {"name": "func_name"}
var simple ToolChoiceFunction
if err := json.Unmarshal(data, &simple); err == nil && simple.Name != "" {
tc.Function = &simple
tc.Mode = ""
return nil
}
return fmt.Errorf("invalid tool_choice: must be string or object with function name")
}
// MarshalJSON serializes ToolChoice back to JSON.
func (tc ToolChoice) MarshalJSON() ([]byte, error) {
if tc.Function != nil {
return json.Marshal(map[string]any{"function": tc.Function})
}
return json.Marshal(tc.Mode)
}
// IsNone returns true if tool_choice is "none".
func (tc *ToolChoice) IsNone() bool {
return tc != nil && tc.Mode == "none"
}
// IsRequired returns true if tool_choice is "required".
func (tc *ToolChoice) IsRequired() bool {
return tc != nil && tc.Mode == "required"
}
// IsAuto returns true if tool_choice is "auto" or not specified.
func (tc *ToolChoice) IsAuto() bool {
return tc == nil || tc.Mode == "" || tc.Mode == "auto"
}
// IsForcedFunction returns true if a specific function is forced.
func (tc *ToolChoice) IsForcedFunction() bool {
return tc != nil && tc.Function != nil && tc.Function.Name != ""
}
// GetForcedFunctionName returns the name of the forced function, if any.
func (tc *ToolChoice) GetForcedFunctionName() string {
if tc == nil || tc.Function == nil {
return ""
}
return tc.Function.Name
}
func (t Tool) String() string { func (t Tool) String() string {
bts, _ := json.Marshal(t) bts, _ := json.Marshal(t)
return string(bts) return string(bts)

View File

@ -651,3 +651,170 @@ func TestToolFunctionParameters_String(t *testing.T) {
}) })
} }
} }
func TestToolChoice_UnmarshalJSON(t *testing.T) {
tests := []struct {
name string
input string
expectedMode string
expectedFunc string
expectedError bool
}{
{
name: "auto string",
input: `{"tool_choice": "auto"}`,
expectedMode: "auto",
expectedFunc: "",
},
{
name: "none string",
input: `{"tool_choice": "none"}`,
expectedMode: "none",
expectedFunc: "",
},
{
name: "required string",
input: `{"tool_choice": "required"}`,
expectedMode: "required",
expectedFunc: "",
},
{
name: "function object with nested function",
input: `{"tool_choice": {"function": {"name": "get_weather"}}}`,
expectedMode: "",
expectedFunc: "get_weather",
},
{
name: "function object with direct name",
input: `{"tool_choice": {"name": "get_weather"}}`,
expectedMode: "",
expectedFunc: "get_weather",
},
{
name: "unset",
input: `{}`,
expectedMode: "",
expectedFunc: "",
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
var req ChatRequest
err := json.Unmarshal([]byte(test.input), &req)
if test.expectedError {
require.Error(t, err)
return
}
require.NoError(t, err)
if test.expectedMode == "" && test.expectedFunc == "" && test.name == "unset" {
assert.Nil(t, req.ToolChoice)
return
}
require.NotNil(t, req.ToolChoice)
if test.expectedMode != "" {
assert.Equal(t, test.expectedMode, req.ToolChoice.Mode)
}
if test.expectedFunc != "" {
require.NotNil(t, req.ToolChoice.Function)
assert.Equal(t, test.expectedFunc, req.ToolChoice.Function.Name)
}
})
}
}
func TestToolChoice_Methods(t *testing.T) {
t.Run("IsNone", func(t *testing.T) {
tc := &ToolChoice{Mode: "none"}
assert.True(t, tc.IsNone())
tc = &ToolChoice{Mode: "auto"}
assert.False(t, tc.IsNone())
var nilTc *ToolChoice
assert.False(t, nilTc.IsNone())
})
t.Run("IsRequired", func(t *testing.T) {
tc := &ToolChoice{Mode: "required"}
assert.True(t, tc.IsRequired())
tc = &ToolChoice{Mode: "auto"}
assert.False(t, tc.IsRequired())
var nilTc *ToolChoice
assert.False(t, nilTc.IsRequired())
})
t.Run("IsAuto", func(t *testing.T) {
tc := &ToolChoice{Mode: "auto"}
assert.True(t, tc.IsAuto())
tc = &ToolChoice{Mode: ""}
assert.True(t, tc.IsAuto())
var nilTc *ToolChoice
assert.True(t, nilTc.IsAuto())
tc = &ToolChoice{Mode: "required"}
assert.False(t, tc.IsAuto())
})
t.Run("IsForcedFunction", func(t *testing.T) {
tc := &ToolChoice{Function: &ToolChoiceFunction{Name: "get_weather"}}
assert.True(t, tc.IsForcedFunction())
tc = &ToolChoice{Mode: "required"}
assert.False(t, tc.IsForcedFunction())
tc = &ToolChoice{Function: &ToolChoiceFunction{Name: ""}}
assert.False(t, tc.IsForcedFunction())
var nilTc *ToolChoice
assert.False(t, nilTc.IsForcedFunction())
})
t.Run("GetForcedFunctionName", func(t *testing.T) {
tc := &ToolChoice{Function: &ToolChoiceFunction{Name: "get_weather"}}
assert.Equal(t, "get_weather", tc.GetForcedFunctionName())
tc = &ToolChoice{Mode: "required"}
assert.Equal(t, "", tc.GetForcedFunctionName())
var nilTc *ToolChoice
assert.Equal(t, "", nilTc.GetForcedFunctionName())
})
}
func TestToolChoice_MarshalJSON(t *testing.T) {
tests := []struct {
name string
tc ToolChoice
expected string
}{
{
name: "mode string",
tc: ToolChoice{Mode: "required"},
expected: `"required"`,
},
{
name: "function object",
tc: ToolChoice{Function: &ToolChoiceFunction{Name: "get_weather"}},
expected: `{"function":{"name":"get_weather"}}`,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
data, err := json.Marshal(test.tc)
require.NoError(t, err)
assert.Equal(t, test.expected, string(data))
})
}
}

View File

@ -493,6 +493,7 @@ Generate the next message in a chat with a provided model. This is a streaming e
- `model`: (required) the [model name](#model-names) - `model`: (required) the [model name](#model-names)
- `messages`: the messages of the chat, this can be used to keep a chat memory - `messages`: the messages of the chat, this can be used to keep a chat memory
- `tools`: list of tools in JSON for the model to use if supported - `tools`: list of tools in JSON for the model to use if supported
- `tool_choice`: controls how the model uses tools (see [Tool choice](#tool-choice) below)
- `think`: (for thinking models) should the model think before responding? - `think`: (for thinking models) should the model think before responding?
The `message` object has the following fields: The `message` object has the following fields:
@ -519,6 +520,39 @@ Models can also explain the result of the tool call in the response. See the [Ch
[See models with tool calling capabilities](https://ollama.com/search?c=tool). [See models with tool calling capabilities](https://ollama.com/search?c=tool).
### Tool choice
By default, the model will determine when and how many tools to use. You can control this behavior with the `tool_choice` parameter:
- `"auto"` (default): The model decides whether to call zero, one, or multiple tools.
- `"none"`: The model won't call any tools, even if they are provided.
- `"required"`: The model must call at least one tool. The output is constrained to produce a valid tool call.
- `{"function": {"name": "function_name"}}`: The model must call the specified function.
Example with `tool_choice: "required"`:
```json
{
"model": "qwen3",
"messages": [{"role": "user", "content": "What is the weather in Paris?"}],
"tools": [...],
"tool_choice": "required",
"stream": false
}
```
Example with a forced function:
```json
{
"model": "qwen3",
"messages": [{"role": "user", "content": "What is the weather in Paris?"}],
"tools": [...],
"tool_choice": {"function": {"name": "get_weather"}},
"stream": false
}
```
### Structured outputs ### Structured outputs
Structured outputs are supported by providing a JSON schema in the `format` parameter. The model will generate a response that matches the schema. See the [Chat request (Structured outputs)](#chat-request-structured-outputs) example below. Structured outputs are supported by providing a JSON schema in the `format` parameter. The model will generate a response that matches the schema. See the [Chat request (Structured outputs)](#chat-request-structured-outputs) example below.

View File

@ -103,6 +103,89 @@ curl -X POST http://localhost:11434/v1/responses \
</CodeGroup> </CodeGroup>
### Tool calling with `tool_choice`
The `tool_choice` parameter controls how the model uses tools:
<CodeGroup dropdown>
```python tool_choice.py
from openai import OpenAI
client = OpenAI(
base_url='http://localhost:11434/v1/',
api_key='ollama',
)
tools = [
{
"type": "function",
"function": {
"name": "get_weather",
"description": "Get the weather for a location",
"parameters": {
"type": "object",
"properties": {
"location": {"type": "string", "description": "City name"}
},
"required": ["location"]
}
}
}
]
# Force the model to call a tool (any tool)
response = client.chat.completions.create(
model="llama3.2",
messages=[{"role": "user", "content": "What's the weather in Paris?"}],
tools=tools,
tool_choice="required" # Must call at least one tool
)
# Force a specific function
response = client.chat.completions.create(
model="llama3.2",
messages=[{"role": "user", "content": "Tell me about Paris"}],
tools=tools,
tool_choice={"type": "function", "name": "get_weather"} # Must call get_weather
)
# Disable tool calling
response = client.chat.completions.create(
model="llama3.2",
messages=[{"role": "user", "content": "What's the weather?"}],
tools=tools,
tool_choice="none" # Don't call any tools
)
```
```shell tool_choice.sh
# Force the model to call a tool
curl http://localhost:11434/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"model": "llama3.2",
"messages": [{"role": "user", "content": "What is the weather in Paris?"}],
"tools": [{
"type": "function",
"function": {
"name": "get_weather",
"description": "Get weather for a location",
"parameters": {
"type": "object",
"properties": {
"location": {"type": "string"}
},
"required": ["location"]
}
}
}],
"tool_choice": "required"
}'
```
</CodeGroup>
### v1/chat/completions with vision example ### v1/chat/completions with vision example
<CodeGroup dropdown> <CodeGroup dropdown>
@ -207,7 +290,12 @@ curl -X POST http://localhost:11434/v1/chat/completions \
- [x] `top_p` - [x] `top_p`
- [x] `max_tokens` - [x] `max_tokens`
- [x] `tools` - [x] `tools`
- [ ] `tool_choice` - [x] `tool_choice`
- [x] `"auto"` (default)
- [x] `"none"`
- [x] `"required"`
- [x] `{"type": "function", "name": "function_name"}`
- [x] `{"type": "allowed_tools", "mode": "auto"|"required", "tools": [...]}`
- [ ] `logit_bias` - [ ] `logit_bias`
- [ ] `user` - [ ] `user`
- [ ] `n` - [ ] `n`

View File

@ -20,10 +20,12 @@ type BaseWriter struct {
} }
type ChatWriter struct { type ChatWriter struct {
stream bool stream bool
streamOptions *openai.StreamOptions streamOptions *openai.StreamOptions
id string id string
toolCallSent bool toolCallSent bool
forcedToolCall bool
forcedToolName string
BaseWriter BaseWriter
} }
@ -65,6 +67,40 @@ func (w *BaseWriter) writeError(data []byte) (int, error) {
return len(data), nil return len(data), nil
} }
func parseForcedToolCall(content string, forcedToolName string) *api.ToolCall {
if content == "" {
return nil
}
// Try to parse as tool call structure: {"name": "...", "arguments": {...}}
var toolCallJSON struct {
Name string `json:"name"`
Arguments map[string]any `json:"arguments"`
}
if err := json.Unmarshal([]byte(content), &toolCallJSON); err != nil {
return nil
}
// If a specific tool was forced, use that name
name := toolCallJSON.Name
if forcedToolName != "" {
name = forcedToolName
}
if name == "" {
return nil
}
return &api.ToolCall{
ID: fmt.Sprintf("call_%d", rand.Intn(999999)),
Function: api.ToolCallFunction{
Name: name,
Arguments: toolCallJSON.Arguments,
},
}
}
func (w *ChatWriter) writeResponse(data []byte) (int, error) { func (w *ChatWriter) writeResponse(data []byte) (int, error) {
var chatResponse api.ChatResponse var chatResponse api.ChatResponse
err := json.Unmarshal(data, &chatResponse) err := json.Unmarshal(data, &chatResponse)
@ -72,6 +108,15 @@ func (w *ChatWriter) writeResponse(data []byte) (int, error) {
return 0, err return 0, err
} }
// If tool_choice forced a tool call and we have content but no tool calls,
// try to parse the content as a tool call
if w.forcedToolCall && len(chatResponse.Message.ToolCalls) == 0 && chatResponse.Message.Content != "" {
if toolCall := parseForcedToolCall(chatResponse.Message.Content, w.forcedToolName); toolCall != nil {
chatResponse.Message.ToolCalls = []api.ToolCall{*toolCall}
chatResponse.Message.Content = ""
}
}
// chat chunk // chat chunk
if w.stream { if w.stream {
c := openai.ToChunk(w.id, chatResponse, w.toolCallSent) c := openai.ToChunk(w.id, chatResponse, w.toolCallSent)
@ -406,6 +451,16 @@ func ChatMiddleware() gin.HandlerFunc {
return return
} }
// Determine if tool_choice forces a tool call
var forcedToolCall bool
var forcedToolName string
if req.ToolChoice != nil && len(req.Tools) > 0 {
_, _, forcedToolCall, _ = openai.ApplyToolChoice(req.Tools, req.ToolChoice)
if req.ToolChoice.IsForcedFunction() {
forcedToolName = req.ToolChoice.GetForcedFunctionName()
}
}
var b bytes.Buffer var b bytes.Buffer
chatReq, err := openai.FromChatRequest(req) chatReq, err := openai.FromChatRequest(req)
@ -422,10 +477,12 @@ func ChatMiddleware() gin.HandlerFunc {
c.Request.Body = io.NopCloser(&b) c.Request.Body = io.NopCloser(&b)
w := &ChatWriter{ w := &ChatWriter{
BaseWriter: BaseWriter{ResponseWriter: c.Writer}, BaseWriter: BaseWriter{ResponseWriter: c.Writer},
stream: req.Stream, stream: req.Stream,
id: fmt.Sprintf("chatcmpl-%d", rand.Intn(999)), id: fmt.Sprintf("chatcmpl-%d", rand.Intn(999)),
streamOptions: req.StreamOptions, streamOptions: req.StreamOptions,
forcedToolCall: forcedToolCall,
forcedToolName: forcedToolName,
} }
c.Writer = w c.Writer = w

View File

@ -95,6 +95,132 @@ type Reasoning struct {
Effort string `json:"effort,omitempty"` Effort string `json:"effort,omitempty"`
} }
type ToolChoiceFunctionRef struct {
Name string `json:"name"`
}
type ToolChoiceAllowedTool struct {
Type string `json:"type"` // "function"
Name string `json:"name"`
}
type ToolChoiceObject struct {
Type string `json:"type,omitempty"`
Name string `json:"name,omitempty"`
Function *ToolChoiceFunctionRef `json:"function,omitempty"`
Mode string `json:"mode,omitempty"`
Tools []ToolChoiceAllowedTool `json:"tools,omitempty"`
}
type ToolChoice struct {
Mode string
Object *ToolChoiceObject
}
// UnmarshalJSON handles both string and object forms of tool_choice
func (tc *ToolChoice) UnmarshalJSON(data []byte) error {
// Try string first
var s string
if err := json.Unmarshal(data, &s); err == nil {
tc.Mode = s
tc.Object = nil
return nil
}
// Try object
var obj ToolChoiceObject
if err := json.Unmarshal(data, &obj); err != nil {
return fmt.Errorf("invalid tool_choice: must be string or object")
}
tc.Object = &obj
return nil
}
// MarshalJSON serializes ToolChoice back to JSON
func (tc ToolChoice) MarshalJSON() ([]byte, error) {
if tc.Object != nil {
return json.Marshal(tc.Object)
}
return json.Marshal(tc.Mode)
}
// IsNone returns true if tool_choice is "none"
func (tc *ToolChoice) IsNone() bool {
return tc != nil && tc.Mode == "none"
}
// IsRequired returns true if tool_choice is "required"
func (tc *ToolChoice) IsRequired() bool {
return tc != nil && tc.Mode == "required"
}
// IsAuto returns true if tool_choice is "auto" or not specified
func (tc *ToolChoice) IsAuto() bool {
if tc == nil {
return true
}
// If there's an object, check if it's allowed_tools with auto mode
if tc.Object != nil {
// allowed_tools with "auto" mode is still considered auto
if tc.Object.Type == "allowed_tools" && (tc.Object.Mode == "" || tc.Object.Mode == "auto") {
return true
}
// Any other object (forced function, allowed_tools with required) is not auto
return false
}
return tc.Mode == "" || tc.Mode == "auto"
}
// IsForcedFunction returns true if a specific function is forced
func (tc *ToolChoice) IsForcedFunction() bool {
if tc == nil || tc.Object == nil {
return false
}
return tc.Object.Type == "function" || tc.Object.Name != "" || tc.Object.Function != nil
}
// GetForcedFunctionName returns the name of the forced function, if any
func (tc *ToolChoice) GetForcedFunctionName() string {
if tc == nil || tc.Object == nil {
return ""
}
if tc.Object.Name != "" {
return tc.Object.Name
}
if tc.Object.Function != nil {
return tc.Object.Function.Name
}
return ""
}
// IsAllowedTools returns true if tool_choice uses allowed_tools mode
func (tc *ToolChoice) IsAllowedTools() bool {
return tc != nil && tc.Object != nil && tc.Object.Type == "allowed_tools"
}
// GetAllowedToolNames returns the list of allowed tool names
func (tc *ToolChoice) GetAllowedToolNames() []string {
if !tc.IsAllowedTools() || tc.Object.Tools == nil {
return nil
}
names := make([]string, len(tc.Object.Tools))
for i, t := range tc.Object.Tools {
names[i] = t.Name
}
return names
}
// GetAllowedToolsMode returns the mode for allowed_tools ("auto" or "required")
func (tc *ToolChoice) GetAllowedToolsMode() string {
if !tc.IsAllowedTools() {
return ""
}
if tc.Object.Mode == "" {
return "auto"
}
return tc.Object.Mode
}
type ChatCompletionRequest struct { type ChatCompletionRequest struct {
Model string `json:"model"` Model string `json:"model"`
Messages []Message `json:"messages"` Messages []Message `json:"messages"`
@ -109,6 +235,7 @@ type ChatCompletionRequest struct {
TopP *float64 `json:"top_p"` TopP *float64 `json:"top_p"`
ResponseFormat *ResponseFormat `json:"response_format"` ResponseFormat *ResponseFormat `json:"response_format"`
Tools []api.Tool `json:"tools"` Tools []api.Tool `json:"tools"`
ToolChoice *ToolChoice `json:"tool_choice,omitempty"`
Reasoning *Reasoning `json:"reasoning,omitempty"` Reasoning *Reasoning `json:"reasoning,omitempty"`
ReasoningEffort *string `json:"reasoning_effort,omitempty"` ReasoningEffort *string `json:"reasoning_effort,omitempty"`
Logprobs *bool `json:"logprobs"` Logprobs *bool `json:"logprobs"`
@ -444,6 +571,171 @@ func ToModel(r api.ShowResponse, m string) Model {
} }
} }
// filterToolsByNames returns only the tools that match the given names
func filterToolsByNames(tools []api.Tool, names []string) []api.Tool {
if len(names) == 0 {
return tools
}
nameSet := make(map[string]bool)
for _, name := range names {
nameSet[name] = true
}
var filtered []api.Tool
for _, tool := range tools {
if nameSet[tool.Function.Name] {
filtered = append(filtered, tool)
}
}
return filtered
}
// findToolByName returns the tool with the given name, or nil if not found
func findToolByName(tools []api.Tool, name string) *api.Tool {
for _, tool := range tools {
if tool.Function.Name == name {
return &tool
}
}
return nil
}
// generateToolCallSchema creates a JSON schema that constrains output to valid tool calls
func generateToolCallSchema(tools []api.Tool) json.RawMessage {
if len(tools) == 0 {
return nil
}
// Collect all tool names for the enum
toolNames := make([]string, len(tools))
for i, tool := range tools {
toolNames[i] = tool.Function.Name
}
// Build a schema that allows any of the tools
// Using oneOf for each tool with its specific parameter schema
var oneOfSchemas []map[string]any
for _, tool := range tools {
toolSchema := map[string]any{
"type": "object",
"properties": map[string]any{
"name": map[string]any{
"type": "string",
"const": tool.Function.Name,
},
"arguments": tool.Function.Parameters,
},
"required": []string{"name", "arguments"},
"additionalProperties": false,
}
oneOfSchemas = append(oneOfSchemas, toolSchema)
}
var schema map[string]any
if len(oneOfSchemas) == 1 {
// Single tool - use its schema directly
schema = oneOfSchemas[0]
} else {
// Multiple tools - use oneOf
schema = map[string]any{
"oneOf": oneOfSchemas,
}
}
bytes, err := json.Marshal(schema)
if err != nil {
slog.Error("failed to marshal tool call schema", "error", err)
return nil
}
return bytes
}
// generateForcedFunctionSchema creates a JSON schema for a specific forced function
func generateForcedFunctionSchema(tool api.Tool) json.RawMessage {
schema := map[string]any{
"type": "object",
"properties": map[string]any{
"name": map[string]any{
"type": "string",
"const": tool.Function.Name,
},
"arguments": tool.Function.Parameters,
},
"required": []string{"name", "arguments"},
"additionalProperties": false,
}
bytes, err := json.Marshal(schema)
if err != nil {
slog.Error("failed to marshal forced function schema", "error", err)
return nil
}
return bytes
}
// ApplyToolChoice processes tool_choice and returns filtered tools and optional format schema
// Returns:
// - filteredTools: the tools to pass to the model (may be empty for "none")
// - format: JSON schema to constrain output (for "required" or forced function)
// - forcedToolCall: true if the response should be parsed as a tool call
// - error: if tool_choice references a non-existent tool
func ApplyToolChoice(tools []api.Tool, toolChoice *ToolChoice) (filteredTools []api.Tool, format json.RawMessage, forcedToolCall bool, err error) {
// Default: auto mode, return all tools without format constraint
if toolChoice == nil || toolChoice.IsAuto() {
// Check for allowed_tools with auto mode
if toolChoice != nil && toolChoice.IsAllowedTools() {
allowedNames := toolChoice.GetAllowedToolNames()
filteredTools = filterToolsByNames(tools, allowedNames)
if toolChoice.GetAllowedToolsMode() == "required" {
format = generateToolCallSchema(filteredTools)
forcedToolCall = true
}
return filteredTools, format, forcedToolCall, nil
}
return tools, nil, false, nil
}
// "none" mode: don't pass any tools
if toolChoice.IsNone() {
return nil, nil, false, nil
}
// "required" mode: must call at least one tool
if toolChoice.IsRequired() {
format = generateToolCallSchema(tools)
return tools, format, true, nil
}
// Forced function mode
if toolChoice.IsForcedFunction() {
funcName := toolChoice.GetForcedFunctionName()
if funcName == "" {
return nil, nil, false, errors.New("tool_choice function name is required")
}
tool := findToolByName(tools, funcName)
if tool == nil {
return nil, nil, false, fmt.Errorf("tool_choice references unknown function: %s", funcName)
}
format = generateForcedFunctionSchema(*tool)
return []api.Tool{*tool}, format, true, nil
}
// allowed_tools mode (already handled in IsAuto check, but handle explicit type here)
if toolChoice.IsAllowedTools() {
allowedNames := toolChoice.GetAllowedToolNames()
filteredTools = filterToolsByNames(tools, allowedNames)
if toolChoice.GetAllowedToolsMode() == "required" {
format = generateToolCallSchema(filteredTools)
forcedToolCall = true
}
return filteredTools, format, forcedToolCall, nil
}
// Unknown mode, default to auto
return tools, nil, false, nil
}
// FromChatRequest converts a ChatCompletionRequest to api.ChatRequest // FromChatRequest converts a ChatCompletionRequest to api.ChatRequest
func FromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) { func FromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
var messages []api.Message var messages []api.Message
@ -579,6 +871,18 @@ func FromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
} }
} }
// Apply tool_choice to filter tools and potentially set format constraint
filteredTools, toolChoiceFormat, _, err := ApplyToolChoice(r.Tools, r.ToolChoice)
if err != nil {
return nil, err
}
// If tool_choice requires a format constraint and no explicit response_format was set,
// apply the tool call schema
if toolChoiceFormat != nil && format == nil {
format = toolChoiceFormat
}
var think *api.ThinkValue var think *api.ThinkValue
var effort string var effort string
@ -606,7 +910,7 @@ func FromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
Format: format, Format: format,
Options: options, Options: options,
Stream: &r.Stream, Stream: &r.Stream,
Tools: r.Tools, Tools: filteredTools,
Think: think, Think: think,
Logprobs: r.Logprobs != nil && *r.Logprobs, Logprobs: r.Logprobs != nil && *r.Logprobs,
TopLogprobs: r.TopLogprobs, TopLogprobs: r.TopLogprobs,

View File

@ -434,3 +434,480 @@ func TestFromChatRequest_TopLogprobsRange(t *testing.T) {
}) })
} }
} }
func TestToolChoice_UnmarshalJSON(t *testing.T) {
tests := []struct {
name string
json string
wantMode string
wantObj bool
}{
{
name: "string auto",
json: `"auto"`,
wantMode: "auto",
wantObj: false,
},
{
name: "string none",
json: `"none"`,
wantMode: "none",
wantObj: false,
},
{
name: "string required",
json: `"required"`,
wantMode: "required",
wantObj: false,
},
{
name: "object function with name",
json: `{"type": "function", "name": "get_weather"}`,
wantMode: "",
wantObj: true,
},
{
name: "object function with function.name",
json: `{"type": "function", "function": {"name": "get_weather"}}`,
wantMode: "",
wantObj: true,
},
{
name: "object allowed_tools",
json: `{"type": "allowed_tools", "mode": "required", "tools": [{"type": "function", "name": "get_weather"}]}`,
wantMode: "",
wantObj: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var tc ToolChoice
if err := tc.UnmarshalJSON([]byte(tt.json)); err != nil {
t.Fatalf("unexpected error: %v", err)
}
if tc.Mode != tt.wantMode {
t.Errorf("Mode = %q, want %q", tc.Mode, tt.wantMode)
}
if (tc.Object != nil) != tt.wantObj {
t.Errorf("Object = %v, wantObj = %v", tc.Object, tt.wantObj)
}
})
}
}
func TestToolChoice_Methods(t *testing.T) {
tests := []struct {
name string
toolChoice *ToolChoice
isNone bool
isRequired bool
isAuto bool
isForcedFunction bool
forcedFuncName string
isAllowedTools bool
allowedToolNames []string
allowedToolsMode string
}{
{
name: "nil",
toolChoice: nil,
isNone: false,
isRequired: false,
isAuto: true,
},
{
name: "auto string",
toolChoice: &ToolChoice{Mode: "auto"},
isNone: false,
isRequired: false,
isAuto: true,
},
{
name: "none string",
toolChoice: &ToolChoice{Mode: "none"},
isNone: true,
isRequired: false,
isAuto: false,
},
{
name: "required string",
toolChoice: &ToolChoice{Mode: "required"},
isNone: false,
isRequired: true,
isAuto: false,
},
{
name: "forced function with name",
toolChoice: &ToolChoice{
Object: &ToolChoiceObject{Type: "function", Name: "get_weather"},
},
isNone: false,
isRequired: false,
isAuto: false,
isForcedFunction: true,
forcedFuncName: "get_weather",
},
{
name: "forced function with function.name",
toolChoice: &ToolChoice{
Object: &ToolChoiceObject{
Type: "function",
Function: &ToolChoiceFunctionRef{Name: "search"},
},
},
isNone: false,
isRequired: false,
isAuto: false,
isForcedFunction: true,
forcedFuncName: "search",
},
{
name: "allowed_tools auto",
toolChoice: &ToolChoice{
Object: &ToolChoiceObject{
Type: "allowed_tools",
Mode: "auto",
Tools: []ToolChoiceAllowedTool{
{Type: "function", Name: "get_weather"},
{Type: "function", Name: "search"},
},
},
},
isNone: false,
isRequired: false,
isAuto: true,
isForcedFunction: false,
isAllowedTools: true,
allowedToolNames: []string{"get_weather", "search"},
allowedToolsMode: "auto",
},
{
name: "allowed_tools required",
toolChoice: &ToolChoice{
Object: &ToolChoiceObject{
Type: "allowed_tools",
Mode: "required",
Tools: []ToolChoiceAllowedTool{
{Type: "function", Name: "get_weather"},
},
},
},
isNone: false,
isRequired: false,
isAuto: false,
isForcedFunction: false,
isAllowedTools: true,
allowedToolNames: []string{"get_weather"},
allowedToolsMode: "required",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := tt.toolChoice.IsNone(); got != tt.isNone {
t.Errorf("IsNone() = %v, want %v", got, tt.isNone)
}
if got := tt.toolChoice.IsRequired(); got != tt.isRequired {
t.Errorf("IsRequired() = %v, want %v", got, tt.isRequired)
}
if got := tt.toolChoice.IsAuto(); got != tt.isAuto {
t.Errorf("IsAuto() = %v, want %v", got, tt.isAuto)
}
if got := tt.toolChoice.IsForcedFunction(); got != tt.isForcedFunction {
t.Errorf("IsForcedFunction() = %v, want %v", got, tt.isForcedFunction)
}
if got := tt.toolChoice.GetForcedFunctionName(); got != tt.forcedFuncName {
t.Errorf("GetForcedFunctionName() = %q, want %q", got, tt.forcedFuncName)
}
if got := tt.toolChoice.IsAllowedTools(); got != tt.isAllowedTools {
t.Errorf("IsAllowedTools() = %v, want %v", got, tt.isAllowedTools)
}
if tt.isAllowedTools {
if got := tt.toolChoice.GetAllowedToolNames(); !cmp.Equal(got, tt.allowedToolNames) {
t.Errorf("GetAllowedToolNames() = %v, want %v", got, tt.allowedToolNames)
}
if got := tt.toolChoice.GetAllowedToolsMode(); got != tt.allowedToolsMode {
t.Errorf("GetAllowedToolsMode() = %q, want %q", got, tt.allowedToolsMode)
}
}
})
}
}
func TestApplyToolChoice(t *testing.T) {
tools := []api.Tool{
{
Type: "function",
Function: api.ToolFunction{
Name: "get_weather",
Description: "Get weather",
Parameters: api.ToolFunctionParameters{
Type: "object",
Properties: map[string]api.ToolProperty{
"location": {Type: []string{"string"}},
},
Required: []string{"location"},
},
},
},
{
Type: "function",
Function: api.ToolFunction{
Name: "search",
Description: "Search the web",
Parameters: api.ToolFunctionParameters{
Type: "object",
Properties: map[string]api.ToolProperty{
"query": {Type: []string{"string"}},
},
Required: []string{"query"},
},
},
},
}
tests := []struct {
name string
toolChoice *ToolChoice
wantToolCount int
wantFormat bool
wantForced bool
wantError bool
wantToolNames []string
}{
{
name: "nil (auto)",
toolChoice: nil,
wantToolCount: 2,
wantFormat: false,
wantForced: false,
},
{
name: "auto string",
toolChoice: &ToolChoice{Mode: "auto"},
wantToolCount: 2,
wantFormat: false,
wantForced: false,
},
{
name: "none string",
toolChoice: &ToolChoice{Mode: "none"},
wantToolCount: 0,
wantFormat: false,
wantForced: false,
},
{
name: "required string",
toolChoice: &ToolChoice{Mode: "required"},
wantToolCount: 2,
wantFormat: true,
wantForced: true,
},
{
name: "forced function",
toolChoice: &ToolChoice{
Object: &ToolChoiceObject{Type: "function", Name: "get_weather"},
},
wantToolCount: 1,
wantFormat: true,
wantForced: true,
wantToolNames: []string{"get_weather"},
},
{
name: "forced unknown function",
toolChoice: &ToolChoice{
Object: &ToolChoiceObject{Type: "function", Name: "unknown_func"},
},
wantError: true,
},
{
name: "allowed_tools auto",
toolChoice: &ToolChoice{
Object: &ToolChoiceObject{
Type: "allowed_tools",
Mode: "auto",
Tools: []ToolChoiceAllowedTool{
{Type: "function", Name: "get_weather"},
},
},
},
wantToolCount: 1,
wantFormat: false,
wantForced: false,
wantToolNames: []string{"get_weather"},
},
{
name: "allowed_tools required",
toolChoice: &ToolChoice{
Object: &ToolChoiceObject{
Type: "allowed_tools",
Mode: "required",
Tools: []ToolChoiceAllowedTool{
{Type: "function", Name: "search"},
},
},
},
wantToolCount: 1,
wantFormat: true,
wantForced: true,
wantToolNames: []string{"search"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
filteredTools, format, forced, err := ApplyToolChoice(tools, tt.toolChoice)
if tt.wantError {
if err == nil {
t.Errorf("expected error, got nil")
}
return
}
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(filteredTools) != tt.wantToolCount {
t.Errorf("got %d tools, want %d", len(filteredTools), tt.wantToolCount)
}
if (format != nil) != tt.wantFormat {
t.Errorf("format = %v, wantFormat = %v", format != nil, tt.wantFormat)
}
if forced != tt.wantForced {
t.Errorf("forced = %v, want %v", forced, tt.wantForced)
}
if tt.wantToolNames != nil {
gotNames := make([]string, len(filteredTools))
for i, tool := range filteredTools {
gotNames[i] = tool.Function.Name
}
if !cmp.Equal(gotNames, tt.wantToolNames) {
t.Errorf("tool names = %v, want %v", gotNames, tt.wantToolNames)
}
}
})
}
}
func TestFromChatRequest_WithToolChoice(t *testing.T) {
tools := []api.Tool{
{
Type: "function",
Function: api.ToolFunction{
Name: "get_weather",
Description: "Get weather",
Parameters: api.ToolFunctionParameters{
Type: "object",
Properties: map[string]api.ToolProperty{
"location": {Type: []string{"string"}},
},
},
},
},
{
Type: "function",
Function: api.ToolFunction{
Name: "search",
Description: "Search",
Parameters: api.ToolFunctionParameters{
Type: "object",
Properties: map[string]api.ToolProperty{
"query": {Type: []string{"string"}},
},
},
},
},
}
t.Run("tool_choice none removes tools", func(t *testing.T) {
req := ChatCompletionRequest{
Model: "test-model",
Messages: []Message{{Role: "user", Content: "Hello"}},
Tools: tools,
ToolChoice: &ToolChoice{Mode: "none"},
}
result, err := FromChatRequest(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(result.Tools) != 0 {
t.Errorf("expected 0 tools, got %d", len(result.Tools))
}
})
t.Run("tool_choice required adds format", func(t *testing.T) {
req := ChatCompletionRequest{
Model: "test-model",
Messages: []Message{{Role: "user", Content: "Hello"}},
Tools: tools,
ToolChoice: &ToolChoice{Mode: "required"},
}
result, err := FromChatRequest(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(result.Tools) != 2 {
t.Errorf("expected 2 tools, got %d", len(result.Tools))
}
if result.Format == nil {
t.Error("expected format to be set for required tool_choice")
}
})
t.Run("tool_choice forced function filters and adds format", func(t *testing.T) {
req := ChatCompletionRequest{
Model: "test-model",
Messages: []Message{{Role: "user", Content: "Hello"}},
Tools: tools,
ToolChoice: &ToolChoice{
Object: &ToolChoiceObject{Type: "function", Name: "get_weather"},
},
}
result, err := FromChatRequest(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(result.Tools) != 1 {
t.Errorf("expected 1 tool, got %d", len(result.Tools))
}
if result.Tools[0].Function.Name != "get_weather" {
t.Errorf("expected tool 'get_weather', got %q", result.Tools[0].Function.Name)
}
if result.Format == nil {
t.Error("expected format to be set for forced function")
}
})
t.Run("tool_choice unknown function returns error", func(t *testing.T) {
req := ChatCompletionRequest{
Model: "test-model",
Messages: []Message{{Role: "user", Content: "Hello"}},
Tools: tools,
ToolChoice: &ToolChoice{
Object: &ToolChoiceObject{Type: "function", Name: "unknown"},
},
}
_, err := FromChatRequest(req)
if err == nil {
t.Error("expected error for unknown function")
}
})
}

View File

@ -1861,6 +1861,112 @@ func toolCallId() string {
return "call_" + strings.ToLower(string(b)) return "call_" + strings.ToLower(string(b))
} }
// findToolByName returns the tool with the given name, or nil if not found.
func findToolByName(tools []api.Tool, name string) *api.Tool {
for i := range tools {
if tools[i].Function.Name == name {
return &tools[i]
}
}
return nil
}
// generateToolCallSchema creates a JSON schema that constrains output to valid tool calls.
func generateToolCallSchema(tools []api.Tool) json.RawMessage {
if len(tools) == 0 {
return nil
}
var oneOfSchemas []map[string]any
for _, tool := range tools {
toolSchema := map[string]any{
"type": "object",
"properties": map[string]any{
"name": map[string]any{
"type": "string",
"const": tool.Function.Name,
},
"arguments": tool.Function.Parameters,
},
"required": []string{"name", "arguments"},
"additionalProperties": false,
}
oneOfSchemas = append(oneOfSchemas, toolSchema)
}
var schema map[string]any
if len(oneOfSchemas) == 1 {
schema = oneOfSchemas[0]
} else {
schema = map[string]any{
"oneOf": oneOfSchemas,
}
}
bytes, err := json.Marshal(schema)
if err != nil {
slog.Error("failed to marshal tool call schema", "error", err)
return nil
}
return bytes
}
// generateForcedFunctionSchema creates a JSON schema for a specific forced function.
func generateForcedFunctionSchema(tool api.Tool) json.RawMessage {
schema := map[string]any{
"type": "object",
"properties": map[string]any{
"name": map[string]any{
"type": "string",
"const": tool.Function.Name,
},
"arguments": tool.Function.Parameters,
},
"required": []string{"name", "arguments"},
"additionalProperties": false,
}
bytes, err := json.Marshal(schema)
if err != nil {
slog.Error("failed to marshal forced function schema", "error", err)
return nil
}
return bytes
}
// parseForcedToolCallContent parses content as a forced tool call JSON.
func parseForcedToolCallContent(content string, forcedToolName string) *api.ToolCall {
if content == "" {
return nil
}
var toolCallJSON struct {
Name string `json:"name"`
Arguments map[string]any `json:"arguments"`
}
if err := json.Unmarshal([]byte(content), &toolCallJSON); err != nil {
return nil
}
name := toolCallJSON.Name
if forcedToolName != "" {
name = forcedToolName
}
if name == "" {
return nil
}
return &api.ToolCall{
ID: toolCallId(),
Function: api.ToolCallFunction{
Name: name,
Arguments: toolCallJSON.Arguments,
},
}
}
func (s *Server) ChatHandler(c *gin.Context) { func (s *Server) ChatHandler(c *gin.Context) {
checkpointStart := time.Now() checkpointStart := time.Now()
@ -2075,6 +2181,36 @@ func (s *Server) ChatHandler(c *gin.Context) {
} }
} }
// Handle tool_choice
var forcedToolCall bool
var forcedToolName string
if req.ToolChoice != nil && len(req.Tools) > 0 {
if req.ToolChoice.IsNone() {
// "none" mode: don't pass any tools
processedTools = nil
} else if req.ToolChoice.IsRequired() {
// "required" mode: must call at least one tool, generate JSON schema
if req.Format == nil {
req.Format = generateToolCallSchema(req.Tools)
}
forcedToolCall = true
} else if req.ToolChoice.IsForcedFunction() {
// Forced function mode: must call this specific function
funcName := req.ToolChoice.GetForcedFunctionName()
tool := findToolByName(req.Tools, funcName)
if tool == nil {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("tool_choice references unknown function: %s", funcName)})
return
}
if req.Format == nil {
req.Format = generateForcedFunctionSchema(*tool)
}
processedTools = []api.Tool{*tool}
forcedToolCall = true
forcedToolName = funcName
}
}
truncate := req.Truncate == nil || *req.Truncate truncate := req.Truncate == nil || *req.Truncate
prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, msgs, processedTools, req.Think, truncate) prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, msgs, processedTools, req.Think, truncate)
if err != nil { if err != nil {
@ -2206,6 +2342,15 @@ func (s *Server) ChatHandler(c *gin.Context) {
return return
} }
// If tool_choice forced a tool call and builtinParser didn't find any,
// try to parse the content as a forced tool call JSON
if r.Done && forcedToolCall && len(res.Message.ToolCalls) == 0 && res.Message.Content != "" {
if toolCall := parseForcedToolCallContent(res.Message.Content, forcedToolName); toolCall != nil {
res.Message.ToolCalls = []api.ToolCall{*toolCall}
res.Message.Content = ""
}
}
if res.Message.Content != "" || res.Message.Thinking != "" || len(res.Message.ToolCalls) > 0 || r.Done || len(res.Logprobs) > 0 { if res.Message.Content != "" || res.Message.Thinking != "" || len(res.Message.ToolCalls) > 0 || r.Done || len(res.Logprobs) > 0 {
slog.Log(context.TODO(), logutil.LevelTrace, "builtin parser output", "parser", m.Config.Parser, "content", content, "thinking", thinking, "toolCalls", toolCalls, "done", r.Done) slog.Log(context.TODO(), logutil.LevelTrace, "builtin parser output", "parser", m.Config.Parser, "content", content, "thinking", thinking, "toolCalls", toolCalls, "done", r.Done)
ch <- res ch <- res
@ -2235,7 +2380,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
res.Message.Content = remainingContent res.Message.Content = remainingContent
} }
if len(req.Tools) > 0 { if len(req.Tools) > 0 && toolParser != nil {
toolCalls, content := toolParser.Add(res.Message.Content) toolCalls, content := toolParser.Add(res.Message.Content)
if len(content) > 0 { if len(content) > 0 {
res.Message.Content = content res.Message.Content = content
@ -2258,12 +2403,28 @@ func (s *Server) ChatHandler(c *gin.Context) {
if r.Done { if r.Done {
res.Message.Content = toolParser.Content() res.Message.Content = toolParser.Content()
// If tool_choice forced a tool call, try to parse the buffered content
if forcedToolCall && len(res.Message.ToolCalls) == 0 && res.Message.Content != "" {
if toolCall := parseForcedToolCallContent(res.Message.Content, forcedToolName); toolCall != nil {
res.Message.ToolCalls = []api.ToolCall{*toolCall}
res.Message.Content = ""
}
}
ch <- res ch <- res
} }
return return
} }
} }
// If tool_choice forced a tool call and we have content but no tool calls,
// try to parse the content as a tool call (used when format constraint was applied)
if r.Done && forcedToolCall && len(res.Message.ToolCalls) == 0 && res.Message.Content != "" {
if toolCall := parseForcedToolCallContent(res.Message.Content, forcedToolName); toolCall != nil {
res.Message.ToolCalls = []api.ToolCall{*toolCall}
res.Message.Content = ""
}
}
ch <- res ch <- res
}) })
if err != nil { if err != nil {
@ -2353,6 +2514,18 @@ func (s *Server) ChatHandler(c *gin.Context) {
if len(toolCalls) > 0 { if len(toolCalls) > 0 {
resp.Message.ToolCalls = toolCalls resp.Message.ToolCalls = toolCalls
// If we have tool calls from forced tool_choice, the "content" was actually
// the JSON that got parsed into tool calls, so clear it
if forcedToolCall {
resp.Message.Content = ""
}
} else if forcedToolCall && resp.Message.Content != "" {
// No tool calls were parsed in callbacks, but we have forced tool_choice.
// Try to parse the accumulated content as a tool call JSON.
if toolCall := parseForcedToolCallContent(resp.Message.Content, forcedToolName); toolCall != nil {
resp.Message.ToolCalls = []api.ToolCall{*toolCall}
resp.Message.Content = ""
}
} }
c.JSON(http.StatusOK, resp) c.JSON(http.StatusOK, resp)

View File

@ -978,3 +978,227 @@ func TestWaitForStream(t *testing.T) {
}) })
} }
} }
func TestFindToolByName(t *testing.T) {
tools := []api.Tool{
{
Type: "function",
Function: api.ToolFunction{
Name: "get_weather",
Description: "Get weather for a location",
},
},
{
Type: "function",
Function: api.ToolFunction{
Name: "search_web",
Description: "Search the web",
},
},
}
t.Run("found", func(t *testing.T) {
tool := findToolByName(tools, "get_weather")
if tool == nil {
t.Fatal("expected to find tool")
}
if tool.Function.Name != "get_weather" {
t.Errorf("expected get_weather, got %s", tool.Function.Name)
}
})
t.Run("not found", func(t *testing.T) {
tool := findToolByName(tools, "nonexistent")
if tool != nil {
t.Error("expected nil for nonexistent tool")
}
})
t.Run("empty tools", func(t *testing.T) {
tool := findToolByName([]api.Tool{}, "get_weather")
if tool != nil {
t.Error("expected nil for empty tools")
}
})
}
func TestGenerateToolCallSchema(t *testing.T) {
tools := []api.Tool{
{
Type: "function",
Function: api.ToolFunction{
Name: "get_weather",
Description: "Get weather",
Parameters: api.ToolFunctionParameters{
Type: "object",
Required: []string{"location"},
Properties: map[string]api.ToolProperty{
"location": {Type: api.PropertyType{"string"}},
},
},
},
},
}
t.Run("single tool", func(t *testing.T) {
schema := generateToolCallSchema(tools)
if schema == nil {
t.Fatal("expected schema, got nil")
}
var parsed map[string]any
if err := json.Unmarshal(schema, &parsed); err != nil {
t.Fatalf("failed to parse schema: %v", err)
}
// Should have properties with name and arguments
props, ok := parsed["properties"].(map[string]any)
if !ok {
t.Fatal("expected properties in schema")
}
if _, ok := props["name"]; !ok {
t.Error("expected name property")
}
if _, ok := props["arguments"]; !ok {
t.Error("expected arguments property")
}
})
t.Run("multiple tools", func(t *testing.T) {
multiTools := append(tools, api.Tool{
Type: "function",
Function: api.ToolFunction{
Name: "search_web",
},
})
schema := generateToolCallSchema(multiTools)
if schema == nil {
t.Fatal("expected schema, got nil")
}
var parsed map[string]any
if err := json.Unmarshal(schema, &parsed); err != nil {
t.Fatalf("failed to parse schema: %v", err)
}
// Should have oneOf for multiple tools
if _, ok := parsed["oneOf"]; !ok {
t.Error("expected oneOf for multiple tools")
}
})
t.Run("empty tools", func(t *testing.T) {
schema := generateToolCallSchema([]api.Tool{})
if schema != nil {
t.Error("expected nil for empty tools")
}
})
}
func TestGenerateForcedFunctionSchema(t *testing.T) {
tool := api.Tool{
Type: "function",
Function: api.ToolFunction{
Name: "get_weather",
Description: "Get weather",
Parameters: api.ToolFunctionParameters{
Type: "object",
Required: []string{"location"},
Properties: map[string]api.ToolProperty{
"location": {Type: api.PropertyType{"string"}},
},
},
},
}
schema := generateForcedFunctionSchema(tool)
if schema == nil {
t.Fatal("expected schema, got nil")
}
var parsed map[string]any
if err := json.Unmarshal(schema, &parsed); err != nil {
t.Fatalf("failed to parse schema: %v", err)
}
// Check that name has const constraint
props := parsed["properties"].(map[string]any)
nameSchema := props["name"].(map[string]any)
if nameSchema["const"] != "get_weather" {
t.Errorf("expected const get_weather, got %v", nameSchema["const"])
}
}
func TestParseForcedToolCallContent(t *testing.T) {
t.Run("valid tool call JSON", func(t *testing.T) {
content := `{"name": "get_weather", "arguments": {"location": "Paris"}}`
toolCall := parseForcedToolCallContent(content, "")
if toolCall == nil {
t.Fatal("expected tool call, got nil")
}
if toolCall.Function.Name != "get_weather" {
t.Errorf("expected get_weather, got %s", toolCall.Function.Name)
}
if toolCall.Function.Arguments["location"] != "Paris" {
t.Errorf("expected Paris, got %v", toolCall.Function.Arguments["location"])
}
if toolCall.ID == "" {
t.Error("expected non-empty tool call ID")
}
})
t.Run("forced tool name override", func(t *testing.T) {
content := `{"name": "other_tool", "arguments": {"location": "Paris"}}`
toolCall := parseForcedToolCallContent(content, "get_weather")
if toolCall == nil {
t.Fatal("expected tool call, got nil")
}
// Should use the forced name, not the one in JSON
if toolCall.Function.Name != "get_weather" {
t.Errorf("expected get_weather (forced), got %s", toolCall.Function.Name)
}
})
t.Run("empty content", func(t *testing.T) {
toolCall := parseForcedToolCallContent("", "")
if toolCall != nil {
t.Error("expected nil for empty content")
}
})
t.Run("invalid JSON", func(t *testing.T) {
toolCall := parseForcedToolCallContent("not json", "")
if toolCall != nil {
t.Error("expected nil for invalid JSON")
}
})
t.Run("missing name", func(t *testing.T) {
content := `{"arguments": {"location": "Paris"}}`
toolCall := parseForcedToolCallContent(content, "")
if toolCall != nil {
t.Error("expected nil when name is missing and not forced")
}
})
t.Run("missing name but forced", func(t *testing.T) {
content := `{"arguments": {"location": "Paris"}}`
toolCall := parseForcedToolCallContent(content, "get_weather")
if toolCall == nil {
t.Fatal("expected tool call with forced name")
}
if toolCall.Function.Name != "get_weather" {
t.Errorf("expected get_weather, got %s", toolCall.Function.Name)
}
})
}