Merge a9a62a07d6 into d087e46bd1
This commit is contained in:
commit
1af0c84bbd
88
api/types.go
88
api/types.go
|
|
@ -148,6 +148,13 @@ type ChatRequest struct {
|
|||
// Tools is an optional list of tools the model has access to.
|
||||
Tools `json:"tools,omitempty"`
|
||||
|
||||
// ToolChoice controls how the model uses tools. Can be:
|
||||
// - "auto" (default): model decides whether to call tools
|
||||
// - "none": model won't call any tools
|
||||
// - "required": model must call at least one tool
|
||||
// - ToolChoiceFunction{Name: "func_name"}: model must call this specific function
|
||||
ToolChoice *ToolChoice `json:"tool_choice,omitempty"`
|
||||
|
||||
// Options lists model-specific options.
|
||||
Options map[string]any `json:"options"`
|
||||
|
||||
|
|
@ -184,6 +191,87 @@ func (t Tools) String() string {
|
|||
return string(bts)
|
||||
}
|
||||
|
||||
// ToolChoice controls how the model uses tools.
|
||||
// It can be a string ("auto", "none", "required") or a ToolChoiceFunction.
|
||||
type ToolChoice struct {
|
||||
// Mode is the tool choice mode: "auto", "none", or "required"
|
||||
Mode string `json:"-"`
|
||||
// Function specifies a specific function to call (when forcing a specific tool)
|
||||
Function *ToolChoiceFunction `json:"-"`
|
||||
}
|
||||
|
||||
// ToolChoiceFunction specifies a specific function that the model must call.
|
||||
type ToolChoiceFunction struct {
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
// UnmarshalJSON handles both string and object forms of tool_choice.
|
||||
func (tc *ToolChoice) UnmarshalJSON(data []byte) error {
|
||||
// Try string first: "auto", "none", "required"
|
||||
var s string
|
||||
if err := json.Unmarshal(data, &s); err == nil {
|
||||
tc.Mode = s
|
||||
tc.Function = nil
|
||||
return nil
|
||||
}
|
||||
|
||||
// Try object with function name: {"function": {"name": "func_name"}}
|
||||
var obj struct {
|
||||
Function *ToolChoiceFunction `json:"function"`
|
||||
}
|
||||
if err := json.Unmarshal(data, &obj); err == nil && obj.Function != nil {
|
||||
tc.Function = obj.Function
|
||||
tc.Mode = ""
|
||||
return nil
|
||||
}
|
||||
|
||||
// Try simple object with just name: {"name": "func_name"}
|
||||
var simple ToolChoiceFunction
|
||||
if err := json.Unmarshal(data, &simple); err == nil && simple.Name != "" {
|
||||
tc.Function = &simple
|
||||
tc.Mode = ""
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("invalid tool_choice: must be string or object with function name")
|
||||
}
|
||||
|
||||
// MarshalJSON serializes ToolChoice back to JSON.
|
||||
func (tc ToolChoice) MarshalJSON() ([]byte, error) {
|
||||
if tc.Function != nil {
|
||||
return json.Marshal(map[string]any{"function": tc.Function})
|
||||
}
|
||||
return json.Marshal(tc.Mode)
|
||||
}
|
||||
|
||||
// IsNone returns true if tool_choice is "none".
|
||||
func (tc *ToolChoice) IsNone() bool {
|
||||
return tc != nil && tc.Mode == "none"
|
||||
}
|
||||
|
||||
// IsRequired returns true if tool_choice is "required".
|
||||
func (tc *ToolChoice) IsRequired() bool {
|
||||
return tc != nil && tc.Mode == "required"
|
||||
}
|
||||
|
||||
// IsAuto returns true if tool_choice is "auto" or not specified.
|
||||
func (tc *ToolChoice) IsAuto() bool {
|
||||
return tc == nil || tc.Mode == "" || tc.Mode == "auto"
|
||||
}
|
||||
|
||||
// IsForcedFunction returns true if a specific function is forced.
|
||||
func (tc *ToolChoice) IsForcedFunction() bool {
|
||||
return tc != nil && tc.Function != nil && tc.Function.Name != ""
|
||||
}
|
||||
|
||||
// GetForcedFunctionName returns the name of the forced function, if any.
|
||||
func (tc *ToolChoice) GetForcedFunctionName() string {
|
||||
if tc == nil || tc.Function == nil {
|
||||
return ""
|
||||
}
|
||||
return tc.Function.Name
|
||||
}
|
||||
|
||||
func (t Tool) String() string {
|
||||
bts, _ := json.Marshal(t)
|
||||
return string(bts)
|
||||
|
|
|
|||
|
|
@ -651,3 +651,170 @@ func TestToolFunctionParameters_String(t *testing.T) {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolChoice_UnmarshalJSON(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expectedMode string
|
||||
expectedFunc string
|
||||
expectedError bool
|
||||
}{
|
||||
{
|
||||
name: "auto string",
|
||||
input: `{"tool_choice": "auto"}`,
|
||||
expectedMode: "auto",
|
||||
expectedFunc: "",
|
||||
},
|
||||
{
|
||||
name: "none string",
|
||||
input: `{"tool_choice": "none"}`,
|
||||
expectedMode: "none",
|
||||
expectedFunc: "",
|
||||
},
|
||||
{
|
||||
name: "required string",
|
||||
input: `{"tool_choice": "required"}`,
|
||||
expectedMode: "required",
|
||||
expectedFunc: "",
|
||||
},
|
||||
{
|
||||
name: "function object with nested function",
|
||||
input: `{"tool_choice": {"function": {"name": "get_weather"}}}`,
|
||||
expectedMode: "",
|
||||
expectedFunc: "get_weather",
|
||||
},
|
||||
{
|
||||
name: "function object with direct name",
|
||||
input: `{"tool_choice": {"name": "get_weather"}}`,
|
||||
expectedMode: "",
|
||||
expectedFunc: "get_weather",
|
||||
},
|
||||
{
|
||||
name: "unset",
|
||||
input: `{}`,
|
||||
expectedMode: "",
|
||||
expectedFunc: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
var req ChatRequest
|
||||
err := json.Unmarshal([]byte(test.input), &req)
|
||||
|
||||
if test.expectedError {
|
||||
require.Error(t, err)
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err)
|
||||
|
||||
if test.expectedMode == "" && test.expectedFunc == "" && test.name == "unset" {
|
||||
assert.Nil(t, req.ToolChoice)
|
||||
return
|
||||
}
|
||||
|
||||
require.NotNil(t, req.ToolChoice)
|
||||
|
||||
if test.expectedMode != "" {
|
||||
assert.Equal(t, test.expectedMode, req.ToolChoice.Mode)
|
||||
}
|
||||
|
||||
if test.expectedFunc != "" {
|
||||
require.NotNil(t, req.ToolChoice.Function)
|
||||
assert.Equal(t, test.expectedFunc, req.ToolChoice.Function.Name)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolChoice_Methods(t *testing.T) {
|
||||
t.Run("IsNone", func(t *testing.T) {
|
||||
tc := &ToolChoice{Mode: "none"}
|
||||
assert.True(t, tc.IsNone())
|
||||
|
||||
tc = &ToolChoice{Mode: "auto"}
|
||||
assert.False(t, tc.IsNone())
|
||||
|
||||
var nilTc *ToolChoice
|
||||
assert.False(t, nilTc.IsNone())
|
||||
})
|
||||
|
||||
t.Run("IsRequired", func(t *testing.T) {
|
||||
tc := &ToolChoice{Mode: "required"}
|
||||
assert.True(t, tc.IsRequired())
|
||||
|
||||
tc = &ToolChoice{Mode: "auto"}
|
||||
assert.False(t, tc.IsRequired())
|
||||
|
||||
var nilTc *ToolChoice
|
||||
assert.False(t, nilTc.IsRequired())
|
||||
})
|
||||
|
||||
t.Run("IsAuto", func(t *testing.T) {
|
||||
tc := &ToolChoice{Mode: "auto"}
|
||||
assert.True(t, tc.IsAuto())
|
||||
|
||||
tc = &ToolChoice{Mode: ""}
|
||||
assert.True(t, tc.IsAuto())
|
||||
|
||||
var nilTc *ToolChoice
|
||||
assert.True(t, nilTc.IsAuto())
|
||||
|
||||
tc = &ToolChoice{Mode: "required"}
|
||||
assert.False(t, tc.IsAuto())
|
||||
})
|
||||
|
||||
t.Run("IsForcedFunction", func(t *testing.T) {
|
||||
tc := &ToolChoice{Function: &ToolChoiceFunction{Name: "get_weather"}}
|
||||
assert.True(t, tc.IsForcedFunction())
|
||||
|
||||
tc = &ToolChoice{Mode: "required"}
|
||||
assert.False(t, tc.IsForcedFunction())
|
||||
|
||||
tc = &ToolChoice{Function: &ToolChoiceFunction{Name: ""}}
|
||||
assert.False(t, tc.IsForcedFunction())
|
||||
|
||||
var nilTc *ToolChoice
|
||||
assert.False(t, nilTc.IsForcedFunction())
|
||||
})
|
||||
|
||||
t.Run("GetForcedFunctionName", func(t *testing.T) {
|
||||
tc := &ToolChoice{Function: &ToolChoiceFunction{Name: "get_weather"}}
|
||||
assert.Equal(t, "get_weather", tc.GetForcedFunctionName())
|
||||
|
||||
tc = &ToolChoice{Mode: "required"}
|
||||
assert.Equal(t, "", tc.GetForcedFunctionName())
|
||||
|
||||
var nilTc *ToolChoice
|
||||
assert.Equal(t, "", nilTc.GetForcedFunctionName())
|
||||
})
|
||||
}
|
||||
|
||||
func TestToolChoice_MarshalJSON(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
tc ToolChoice
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "mode string",
|
||||
tc: ToolChoice{Mode: "required"},
|
||||
expected: `"required"`,
|
||||
},
|
||||
{
|
||||
name: "function object",
|
||||
tc: ToolChoice{Function: &ToolChoiceFunction{Name: "get_weather"}},
|
||||
expected: `{"function":{"name":"get_weather"}}`,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
data, err := json.Marshal(test.tc)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, test.expected, string(data))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
|||
34
docs/api.md
34
docs/api.md
|
|
@ -493,6 +493,7 @@ Generate the next message in a chat with a provided model. This is a streaming e
|
|||
- `model`: (required) the [model name](#model-names)
|
||||
- `messages`: the messages of the chat, this can be used to keep a chat memory
|
||||
- `tools`: list of tools in JSON for the model to use if supported
|
||||
- `tool_choice`: controls how the model uses tools (see [Tool choice](#tool-choice) below)
|
||||
- `think`: (for thinking models) should the model think before responding?
|
||||
|
||||
The `message` object has the following fields:
|
||||
|
|
@ -519,6 +520,39 @@ Models can also explain the result of the tool call in the response. See the [Ch
|
|||
|
||||
[See models with tool calling capabilities](https://ollama.com/search?c=tool).
|
||||
|
||||
### Tool choice
|
||||
|
||||
By default, the model will determine when and how many tools to use. You can control this behavior with the `tool_choice` parameter:
|
||||
|
||||
- `"auto"` (default): The model decides whether to call zero, one, or multiple tools.
|
||||
- `"none"`: The model won't call any tools, even if they are provided.
|
||||
- `"required"`: The model must call at least one tool. The output is constrained to produce a valid tool call.
|
||||
- `{"function": {"name": "function_name"}}`: The model must call the specified function.
|
||||
|
||||
Example with `tool_choice: "required"`:
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "qwen3",
|
||||
"messages": [{"role": "user", "content": "What is the weather in Paris?"}],
|
||||
"tools": [...],
|
||||
"tool_choice": "required",
|
||||
"stream": false
|
||||
}
|
||||
```
|
||||
|
||||
Example with a forced function:
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "qwen3",
|
||||
"messages": [{"role": "user", "content": "What is the weather in Paris?"}],
|
||||
"tools": [...],
|
||||
"tool_choice": {"function": {"name": "get_weather"}},
|
||||
"stream": false
|
||||
}
|
||||
```
|
||||
|
||||
### Structured outputs
|
||||
|
||||
Structured outputs are supported by providing a JSON schema in the `format` parameter. The model will generate a response that matches the schema. See the [Chat request (Structured outputs)](#chat-request-structured-outputs) example below.
|
||||
|
|
|
|||
|
|
@ -103,6 +103,89 @@ curl -X POST http://localhost:11434/v1/responses \
|
|||
|
||||
</CodeGroup>
|
||||
|
||||
### Tool calling with `tool_choice`
|
||||
|
||||
The `tool_choice` parameter controls how the model uses tools:
|
||||
|
||||
<CodeGroup dropdown>
|
||||
|
||||
```python tool_choice.py
|
||||
from openai import OpenAI
|
||||
|
||||
client = OpenAI(
|
||||
base_url='http://localhost:11434/v1/',
|
||||
api_key='ollama',
|
||||
)
|
||||
|
||||
tools = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"description": "Get the weather for a location",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {"type": "string", "description": "City name"}
|
||||
},
|
||||
"required": ["location"]
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
|
||||
# Force the model to call a tool (any tool)
|
||||
response = client.chat.completions.create(
|
||||
model="llama3.2",
|
||||
messages=[{"role": "user", "content": "What's the weather in Paris?"}],
|
||||
tools=tools,
|
||||
tool_choice="required" # Must call at least one tool
|
||||
)
|
||||
|
||||
# Force a specific function
|
||||
response = client.chat.completions.create(
|
||||
model="llama3.2",
|
||||
messages=[{"role": "user", "content": "Tell me about Paris"}],
|
||||
tools=tools,
|
||||
tool_choice={"type": "function", "name": "get_weather"} # Must call get_weather
|
||||
)
|
||||
|
||||
# Disable tool calling
|
||||
response = client.chat.completions.create(
|
||||
model="llama3.2",
|
||||
messages=[{"role": "user", "content": "What's the weather?"}],
|
||||
tools=tools,
|
||||
tool_choice="none" # Don't call any tools
|
||||
)
|
||||
```
|
||||
|
||||
```shell tool_choice.sh
|
||||
# Force the model to call a tool
|
||||
curl http://localhost:11434/v1/chat/completions \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"model": "llama3.2",
|
||||
"messages": [{"role": "user", "content": "What is the weather in Paris?"}],
|
||||
"tools": [{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"description": "Get weather for a location",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {"type": "string"}
|
||||
},
|
||||
"required": ["location"]
|
||||
}
|
||||
}
|
||||
}],
|
||||
"tool_choice": "required"
|
||||
}'
|
||||
```
|
||||
|
||||
</CodeGroup>
|
||||
|
||||
### v1/chat/completions with vision example
|
||||
|
||||
<CodeGroup dropdown>
|
||||
|
|
@ -207,7 +290,12 @@ curl -X POST http://localhost:11434/v1/chat/completions \
|
|||
- [x] `top_p`
|
||||
- [x] `max_tokens`
|
||||
- [x] `tools`
|
||||
- [ ] `tool_choice`
|
||||
- [x] `tool_choice`
|
||||
- [x] `"auto"` (default)
|
||||
- [x] `"none"`
|
||||
- [x] `"required"`
|
||||
- [x] `{"type": "function", "name": "function_name"}`
|
||||
- [x] `{"type": "allowed_tools", "mode": "auto"|"required", "tools": [...]}`
|
||||
- [ ] `logit_bias`
|
||||
- [ ] `user`
|
||||
- [ ] `n`
|
||||
|
|
|
|||
|
|
@ -20,10 +20,12 @@ type BaseWriter struct {
|
|||
}
|
||||
|
||||
type ChatWriter struct {
|
||||
stream bool
|
||||
streamOptions *openai.StreamOptions
|
||||
id string
|
||||
toolCallSent bool
|
||||
stream bool
|
||||
streamOptions *openai.StreamOptions
|
||||
id string
|
||||
toolCallSent bool
|
||||
forcedToolCall bool
|
||||
forcedToolName string
|
||||
BaseWriter
|
||||
}
|
||||
|
||||
|
|
@ -65,6 +67,40 @@ func (w *BaseWriter) writeError(data []byte) (int, error) {
|
|||
return len(data), nil
|
||||
}
|
||||
|
||||
func parseForcedToolCall(content string, forcedToolName string) *api.ToolCall {
|
||||
if content == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Try to parse as tool call structure: {"name": "...", "arguments": {...}}
|
||||
var toolCallJSON struct {
|
||||
Name string `json:"name"`
|
||||
Arguments map[string]any `json:"arguments"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal([]byte(content), &toolCallJSON); err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// If a specific tool was forced, use that name
|
||||
name := toolCallJSON.Name
|
||||
if forcedToolName != "" {
|
||||
name = forcedToolName
|
||||
}
|
||||
|
||||
if name == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &api.ToolCall{
|
||||
ID: fmt.Sprintf("call_%d", rand.Intn(999999)),
|
||||
Function: api.ToolCallFunction{
|
||||
Name: name,
|
||||
Arguments: toolCallJSON.Arguments,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (w *ChatWriter) writeResponse(data []byte) (int, error) {
|
||||
var chatResponse api.ChatResponse
|
||||
err := json.Unmarshal(data, &chatResponse)
|
||||
|
|
@ -72,6 +108,15 @@ func (w *ChatWriter) writeResponse(data []byte) (int, error) {
|
|||
return 0, err
|
||||
}
|
||||
|
||||
// If tool_choice forced a tool call and we have content but no tool calls,
|
||||
// try to parse the content as a tool call
|
||||
if w.forcedToolCall && len(chatResponse.Message.ToolCalls) == 0 && chatResponse.Message.Content != "" {
|
||||
if toolCall := parseForcedToolCall(chatResponse.Message.Content, w.forcedToolName); toolCall != nil {
|
||||
chatResponse.Message.ToolCalls = []api.ToolCall{*toolCall}
|
||||
chatResponse.Message.Content = ""
|
||||
}
|
||||
}
|
||||
|
||||
// chat chunk
|
||||
if w.stream {
|
||||
c := openai.ToChunk(w.id, chatResponse, w.toolCallSent)
|
||||
|
|
@ -406,6 +451,16 @@ func ChatMiddleware() gin.HandlerFunc {
|
|||
return
|
||||
}
|
||||
|
||||
// Determine if tool_choice forces a tool call
|
||||
var forcedToolCall bool
|
||||
var forcedToolName string
|
||||
if req.ToolChoice != nil && len(req.Tools) > 0 {
|
||||
_, _, forcedToolCall, _ = openai.ApplyToolChoice(req.Tools, req.ToolChoice)
|
||||
if req.ToolChoice.IsForcedFunction() {
|
||||
forcedToolName = req.ToolChoice.GetForcedFunctionName()
|
||||
}
|
||||
}
|
||||
|
||||
var b bytes.Buffer
|
||||
|
||||
chatReq, err := openai.FromChatRequest(req)
|
||||
|
|
@ -422,10 +477,12 @@ func ChatMiddleware() gin.HandlerFunc {
|
|||
c.Request.Body = io.NopCloser(&b)
|
||||
|
||||
w := &ChatWriter{
|
||||
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
|
||||
stream: req.Stream,
|
||||
id: fmt.Sprintf("chatcmpl-%d", rand.Intn(999)),
|
||||
streamOptions: req.StreamOptions,
|
||||
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
|
||||
stream: req.Stream,
|
||||
id: fmt.Sprintf("chatcmpl-%d", rand.Intn(999)),
|
||||
streamOptions: req.StreamOptions,
|
||||
forcedToolCall: forcedToolCall,
|
||||
forcedToolName: forcedToolName,
|
||||
}
|
||||
|
||||
c.Writer = w
|
||||
|
|
|
|||
306
openai/openai.go
306
openai/openai.go
|
|
@ -95,6 +95,132 @@ type Reasoning struct {
|
|||
Effort string `json:"effort,omitempty"`
|
||||
}
|
||||
|
||||
type ToolChoiceFunctionRef struct {
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
type ToolChoiceAllowedTool struct {
|
||||
Type string `json:"type"` // "function"
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
type ToolChoiceObject struct {
|
||||
Type string `json:"type,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Function *ToolChoiceFunctionRef `json:"function,omitempty"`
|
||||
Mode string `json:"mode,omitempty"`
|
||||
Tools []ToolChoiceAllowedTool `json:"tools,omitempty"`
|
||||
}
|
||||
|
||||
type ToolChoice struct {
|
||||
Mode string
|
||||
Object *ToolChoiceObject
|
||||
}
|
||||
|
||||
// UnmarshalJSON handles both string and object forms of tool_choice
|
||||
func (tc *ToolChoice) UnmarshalJSON(data []byte) error {
|
||||
// Try string first
|
||||
var s string
|
||||
if err := json.Unmarshal(data, &s); err == nil {
|
||||
tc.Mode = s
|
||||
tc.Object = nil
|
||||
return nil
|
||||
}
|
||||
|
||||
// Try object
|
||||
var obj ToolChoiceObject
|
||||
if err := json.Unmarshal(data, &obj); err != nil {
|
||||
return fmt.Errorf("invalid tool_choice: must be string or object")
|
||||
}
|
||||
tc.Object = &obj
|
||||
return nil
|
||||
}
|
||||
|
||||
// MarshalJSON serializes ToolChoice back to JSON
|
||||
func (tc ToolChoice) MarshalJSON() ([]byte, error) {
|
||||
if tc.Object != nil {
|
||||
return json.Marshal(tc.Object)
|
||||
}
|
||||
return json.Marshal(tc.Mode)
|
||||
}
|
||||
|
||||
// IsNone returns true if tool_choice is "none"
|
||||
func (tc *ToolChoice) IsNone() bool {
|
||||
return tc != nil && tc.Mode == "none"
|
||||
}
|
||||
|
||||
// IsRequired returns true if tool_choice is "required"
|
||||
func (tc *ToolChoice) IsRequired() bool {
|
||||
return tc != nil && tc.Mode == "required"
|
||||
}
|
||||
|
||||
// IsAuto returns true if tool_choice is "auto" or not specified
|
||||
func (tc *ToolChoice) IsAuto() bool {
|
||||
if tc == nil {
|
||||
return true
|
||||
}
|
||||
// If there's an object, check if it's allowed_tools with auto mode
|
||||
if tc.Object != nil {
|
||||
// allowed_tools with "auto" mode is still considered auto
|
||||
if tc.Object.Type == "allowed_tools" && (tc.Object.Mode == "" || tc.Object.Mode == "auto") {
|
||||
return true
|
||||
}
|
||||
// Any other object (forced function, allowed_tools with required) is not auto
|
||||
return false
|
||||
}
|
||||
return tc.Mode == "" || tc.Mode == "auto"
|
||||
}
|
||||
|
||||
// IsForcedFunction returns true if a specific function is forced
|
||||
func (tc *ToolChoice) IsForcedFunction() bool {
|
||||
if tc == nil || tc.Object == nil {
|
||||
return false
|
||||
}
|
||||
return tc.Object.Type == "function" || tc.Object.Name != "" || tc.Object.Function != nil
|
||||
}
|
||||
|
||||
// GetForcedFunctionName returns the name of the forced function, if any
|
||||
func (tc *ToolChoice) GetForcedFunctionName() string {
|
||||
if tc == nil || tc.Object == nil {
|
||||
return ""
|
||||
}
|
||||
if tc.Object.Name != "" {
|
||||
return tc.Object.Name
|
||||
}
|
||||
if tc.Object.Function != nil {
|
||||
return tc.Object.Function.Name
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// IsAllowedTools returns true if tool_choice uses allowed_tools mode
|
||||
func (tc *ToolChoice) IsAllowedTools() bool {
|
||||
return tc != nil && tc.Object != nil && tc.Object.Type == "allowed_tools"
|
||||
}
|
||||
|
||||
// GetAllowedToolNames returns the list of allowed tool names
|
||||
func (tc *ToolChoice) GetAllowedToolNames() []string {
|
||||
if !tc.IsAllowedTools() || tc.Object.Tools == nil {
|
||||
return nil
|
||||
}
|
||||
names := make([]string, len(tc.Object.Tools))
|
||||
for i, t := range tc.Object.Tools {
|
||||
names[i] = t.Name
|
||||
}
|
||||
return names
|
||||
}
|
||||
|
||||
// GetAllowedToolsMode returns the mode for allowed_tools ("auto" or "required")
|
||||
func (tc *ToolChoice) GetAllowedToolsMode() string {
|
||||
if !tc.IsAllowedTools() {
|
||||
return ""
|
||||
}
|
||||
if tc.Object.Mode == "" {
|
||||
return "auto"
|
||||
}
|
||||
return tc.Object.Mode
|
||||
}
|
||||
|
||||
type ChatCompletionRequest struct {
|
||||
Model string `json:"model"`
|
||||
Messages []Message `json:"messages"`
|
||||
|
|
@ -109,6 +235,7 @@ type ChatCompletionRequest struct {
|
|||
TopP *float64 `json:"top_p"`
|
||||
ResponseFormat *ResponseFormat `json:"response_format"`
|
||||
Tools []api.Tool `json:"tools"`
|
||||
ToolChoice *ToolChoice `json:"tool_choice,omitempty"`
|
||||
Reasoning *Reasoning `json:"reasoning,omitempty"`
|
||||
ReasoningEffort *string `json:"reasoning_effort,omitempty"`
|
||||
Logprobs *bool `json:"logprobs"`
|
||||
|
|
@ -444,6 +571,171 @@ func ToModel(r api.ShowResponse, m string) Model {
|
|||
}
|
||||
}
|
||||
|
||||
// filterToolsByNames returns only the tools that match the given names
|
||||
func filterToolsByNames(tools []api.Tool, names []string) []api.Tool {
|
||||
if len(names) == 0 {
|
||||
return tools
|
||||
}
|
||||
nameSet := make(map[string]bool)
|
||||
for _, name := range names {
|
||||
nameSet[name] = true
|
||||
}
|
||||
var filtered []api.Tool
|
||||
for _, tool := range tools {
|
||||
if nameSet[tool.Function.Name] {
|
||||
filtered = append(filtered, tool)
|
||||
}
|
||||
}
|
||||
return filtered
|
||||
}
|
||||
|
||||
// findToolByName returns the tool with the given name, or nil if not found
|
||||
func findToolByName(tools []api.Tool, name string) *api.Tool {
|
||||
for _, tool := range tools {
|
||||
if tool.Function.Name == name {
|
||||
return &tool
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// generateToolCallSchema creates a JSON schema that constrains output to valid tool calls
|
||||
func generateToolCallSchema(tools []api.Tool) json.RawMessage {
|
||||
if len(tools) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Collect all tool names for the enum
|
||||
toolNames := make([]string, len(tools))
|
||||
for i, tool := range tools {
|
||||
toolNames[i] = tool.Function.Name
|
||||
}
|
||||
|
||||
// Build a schema that allows any of the tools
|
||||
// Using oneOf for each tool with its specific parameter schema
|
||||
var oneOfSchemas []map[string]any
|
||||
for _, tool := range tools {
|
||||
toolSchema := map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"name": map[string]any{
|
||||
"type": "string",
|
||||
"const": tool.Function.Name,
|
||||
},
|
||||
"arguments": tool.Function.Parameters,
|
||||
},
|
||||
"required": []string{"name", "arguments"},
|
||||
"additionalProperties": false,
|
||||
}
|
||||
oneOfSchemas = append(oneOfSchemas, toolSchema)
|
||||
}
|
||||
|
||||
var schema map[string]any
|
||||
if len(oneOfSchemas) == 1 {
|
||||
// Single tool - use its schema directly
|
||||
schema = oneOfSchemas[0]
|
||||
} else {
|
||||
// Multiple tools - use oneOf
|
||||
schema = map[string]any{
|
||||
"oneOf": oneOfSchemas,
|
||||
}
|
||||
}
|
||||
|
||||
bytes, err := json.Marshal(schema)
|
||||
if err != nil {
|
||||
slog.Error("failed to marshal tool call schema", "error", err)
|
||||
return nil
|
||||
}
|
||||
return bytes
|
||||
}
|
||||
|
||||
// generateForcedFunctionSchema creates a JSON schema for a specific forced function
|
||||
func generateForcedFunctionSchema(tool api.Tool) json.RawMessage {
|
||||
schema := map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"name": map[string]any{
|
||||
"type": "string",
|
||||
"const": tool.Function.Name,
|
||||
},
|
||||
"arguments": tool.Function.Parameters,
|
||||
},
|
||||
"required": []string{"name", "arguments"},
|
||||
"additionalProperties": false,
|
||||
}
|
||||
|
||||
bytes, err := json.Marshal(schema)
|
||||
if err != nil {
|
||||
slog.Error("failed to marshal forced function schema", "error", err)
|
||||
return nil
|
||||
}
|
||||
return bytes
|
||||
}
|
||||
|
||||
// ApplyToolChoice processes tool_choice and returns filtered tools and optional format schema
|
||||
// Returns:
|
||||
// - filteredTools: the tools to pass to the model (may be empty for "none")
|
||||
// - format: JSON schema to constrain output (for "required" or forced function)
|
||||
// - forcedToolCall: true if the response should be parsed as a tool call
|
||||
// - error: if tool_choice references a non-existent tool
|
||||
func ApplyToolChoice(tools []api.Tool, toolChoice *ToolChoice) (filteredTools []api.Tool, format json.RawMessage, forcedToolCall bool, err error) {
|
||||
// Default: auto mode, return all tools without format constraint
|
||||
if toolChoice == nil || toolChoice.IsAuto() {
|
||||
// Check for allowed_tools with auto mode
|
||||
if toolChoice != nil && toolChoice.IsAllowedTools() {
|
||||
allowedNames := toolChoice.GetAllowedToolNames()
|
||||
filteredTools = filterToolsByNames(tools, allowedNames)
|
||||
if toolChoice.GetAllowedToolsMode() == "required" {
|
||||
format = generateToolCallSchema(filteredTools)
|
||||
forcedToolCall = true
|
||||
}
|
||||
return filteredTools, format, forcedToolCall, nil
|
||||
}
|
||||
return tools, nil, false, nil
|
||||
}
|
||||
|
||||
// "none" mode: don't pass any tools
|
||||
if toolChoice.IsNone() {
|
||||
return nil, nil, false, nil
|
||||
}
|
||||
|
||||
// "required" mode: must call at least one tool
|
||||
if toolChoice.IsRequired() {
|
||||
format = generateToolCallSchema(tools)
|
||||
return tools, format, true, nil
|
||||
}
|
||||
|
||||
// Forced function mode
|
||||
if toolChoice.IsForcedFunction() {
|
||||
funcName := toolChoice.GetForcedFunctionName()
|
||||
if funcName == "" {
|
||||
return nil, nil, false, errors.New("tool_choice function name is required")
|
||||
}
|
||||
|
||||
tool := findToolByName(tools, funcName)
|
||||
if tool == nil {
|
||||
return nil, nil, false, fmt.Errorf("tool_choice references unknown function: %s", funcName)
|
||||
}
|
||||
|
||||
format = generateForcedFunctionSchema(*tool)
|
||||
return []api.Tool{*tool}, format, true, nil
|
||||
}
|
||||
|
||||
// allowed_tools mode (already handled in IsAuto check, but handle explicit type here)
|
||||
if toolChoice.IsAllowedTools() {
|
||||
allowedNames := toolChoice.GetAllowedToolNames()
|
||||
filteredTools = filterToolsByNames(tools, allowedNames)
|
||||
if toolChoice.GetAllowedToolsMode() == "required" {
|
||||
format = generateToolCallSchema(filteredTools)
|
||||
forcedToolCall = true
|
||||
}
|
||||
return filteredTools, format, forcedToolCall, nil
|
||||
}
|
||||
|
||||
// Unknown mode, default to auto
|
||||
return tools, nil, false, nil
|
||||
}
|
||||
|
||||
// FromChatRequest converts a ChatCompletionRequest to api.ChatRequest
|
||||
func FromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
|
||||
var messages []api.Message
|
||||
|
|
@ -579,6 +871,18 @@ func FromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
|
|||
}
|
||||
}
|
||||
|
||||
// Apply tool_choice to filter tools and potentially set format constraint
|
||||
filteredTools, toolChoiceFormat, _, err := ApplyToolChoice(r.Tools, r.ToolChoice)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// If tool_choice requires a format constraint and no explicit response_format was set,
|
||||
// apply the tool call schema
|
||||
if toolChoiceFormat != nil && format == nil {
|
||||
format = toolChoiceFormat
|
||||
}
|
||||
|
||||
var think *api.ThinkValue
|
||||
var effort string
|
||||
|
||||
|
|
@ -606,7 +910,7 @@ func FromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
|
|||
Format: format,
|
||||
Options: options,
|
||||
Stream: &r.Stream,
|
||||
Tools: r.Tools,
|
||||
Tools: filteredTools,
|
||||
Think: think,
|
||||
Logprobs: r.Logprobs != nil && *r.Logprobs,
|
||||
TopLogprobs: r.TopLogprobs,
|
||||
|
|
|
|||
|
|
@ -434,3 +434,480 @@ func TestFromChatRequest_TopLogprobsRange(t *testing.T) {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolChoice_UnmarshalJSON(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
json string
|
||||
wantMode string
|
||||
wantObj bool
|
||||
}{
|
||||
{
|
||||
name: "string auto",
|
||||
json: `"auto"`,
|
||||
wantMode: "auto",
|
||||
wantObj: false,
|
||||
},
|
||||
{
|
||||
name: "string none",
|
||||
json: `"none"`,
|
||||
wantMode: "none",
|
||||
wantObj: false,
|
||||
},
|
||||
{
|
||||
name: "string required",
|
||||
json: `"required"`,
|
||||
wantMode: "required",
|
||||
wantObj: false,
|
||||
},
|
||||
{
|
||||
name: "object function with name",
|
||||
json: `{"type": "function", "name": "get_weather"}`,
|
||||
wantMode: "",
|
||||
wantObj: true,
|
||||
},
|
||||
{
|
||||
name: "object function with function.name",
|
||||
json: `{"type": "function", "function": {"name": "get_weather"}}`,
|
||||
wantMode: "",
|
||||
wantObj: true,
|
||||
},
|
||||
{
|
||||
name: "object allowed_tools",
|
||||
json: `{"type": "allowed_tools", "mode": "required", "tools": [{"type": "function", "name": "get_weather"}]}`,
|
||||
wantMode: "",
|
||||
wantObj: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var tc ToolChoice
|
||||
if err := tc.UnmarshalJSON([]byte(tt.json)); err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if tc.Mode != tt.wantMode {
|
||||
t.Errorf("Mode = %q, want %q", tc.Mode, tt.wantMode)
|
||||
}
|
||||
|
||||
if (tc.Object != nil) != tt.wantObj {
|
||||
t.Errorf("Object = %v, wantObj = %v", tc.Object, tt.wantObj)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolChoice_Methods(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
toolChoice *ToolChoice
|
||||
isNone bool
|
||||
isRequired bool
|
||||
isAuto bool
|
||||
isForcedFunction bool
|
||||
forcedFuncName string
|
||||
isAllowedTools bool
|
||||
allowedToolNames []string
|
||||
allowedToolsMode string
|
||||
}{
|
||||
{
|
||||
name: "nil",
|
||||
toolChoice: nil,
|
||||
isNone: false,
|
||||
isRequired: false,
|
||||
isAuto: true,
|
||||
},
|
||||
{
|
||||
name: "auto string",
|
||||
toolChoice: &ToolChoice{Mode: "auto"},
|
||||
isNone: false,
|
||||
isRequired: false,
|
||||
isAuto: true,
|
||||
},
|
||||
{
|
||||
name: "none string",
|
||||
toolChoice: &ToolChoice{Mode: "none"},
|
||||
isNone: true,
|
||||
isRequired: false,
|
||||
isAuto: false,
|
||||
},
|
||||
{
|
||||
name: "required string",
|
||||
toolChoice: &ToolChoice{Mode: "required"},
|
||||
isNone: false,
|
||||
isRequired: true,
|
||||
isAuto: false,
|
||||
},
|
||||
{
|
||||
name: "forced function with name",
|
||||
toolChoice: &ToolChoice{
|
||||
Object: &ToolChoiceObject{Type: "function", Name: "get_weather"},
|
||||
},
|
||||
isNone: false,
|
||||
isRequired: false,
|
||||
isAuto: false,
|
||||
isForcedFunction: true,
|
||||
forcedFuncName: "get_weather",
|
||||
},
|
||||
{
|
||||
name: "forced function with function.name",
|
||||
toolChoice: &ToolChoice{
|
||||
Object: &ToolChoiceObject{
|
||||
Type: "function",
|
||||
Function: &ToolChoiceFunctionRef{Name: "search"},
|
||||
},
|
||||
},
|
||||
isNone: false,
|
||||
isRequired: false,
|
||||
isAuto: false,
|
||||
isForcedFunction: true,
|
||||
forcedFuncName: "search",
|
||||
},
|
||||
{
|
||||
name: "allowed_tools auto",
|
||||
toolChoice: &ToolChoice{
|
||||
Object: &ToolChoiceObject{
|
||||
Type: "allowed_tools",
|
||||
Mode: "auto",
|
||||
Tools: []ToolChoiceAllowedTool{
|
||||
{Type: "function", Name: "get_weather"},
|
||||
{Type: "function", Name: "search"},
|
||||
},
|
||||
},
|
||||
},
|
||||
isNone: false,
|
||||
isRequired: false,
|
||||
isAuto: true,
|
||||
isForcedFunction: false,
|
||||
isAllowedTools: true,
|
||||
allowedToolNames: []string{"get_weather", "search"},
|
||||
allowedToolsMode: "auto",
|
||||
},
|
||||
{
|
||||
name: "allowed_tools required",
|
||||
toolChoice: &ToolChoice{
|
||||
Object: &ToolChoiceObject{
|
||||
Type: "allowed_tools",
|
||||
Mode: "required",
|
||||
Tools: []ToolChoiceAllowedTool{
|
||||
{Type: "function", Name: "get_weather"},
|
||||
},
|
||||
},
|
||||
},
|
||||
isNone: false,
|
||||
isRequired: false,
|
||||
isAuto: false,
|
||||
isForcedFunction: false,
|
||||
isAllowedTools: true,
|
||||
allowedToolNames: []string{"get_weather"},
|
||||
allowedToolsMode: "required",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := tt.toolChoice.IsNone(); got != tt.isNone {
|
||||
t.Errorf("IsNone() = %v, want %v", got, tt.isNone)
|
||||
}
|
||||
if got := tt.toolChoice.IsRequired(); got != tt.isRequired {
|
||||
t.Errorf("IsRequired() = %v, want %v", got, tt.isRequired)
|
||||
}
|
||||
if got := tt.toolChoice.IsAuto(); got != tt.isAuto {
|
||||
t.Errorf("IsAuto() = %v, want %v", got, tt.isAuto)
|
||||
}
|
||||
if got := tt.toolChoice.IsForcedFunction(); got != tt.isForcedFunction {
|
||||
t.Errorf("IsForcedFunction() = %v, want %v", got, tt.isForcedFunction)
|
||||
}
|
||||
if got := tt.toolChoice.GetForcedFunctionName(); got != tt.forcedFuncName {
|
||||
t.Errorf("GetForcedFunctionName() = %q, want %q", got, tt.forcedFuncName)
|
||||
}
|
||||
if got := tt.toolChoice.IsAllowedTools(); got != tt.isAllowedTools {
|
||||
t.Errorf("IsAllowedTools() = %v, want %v", got, tt.isAllowedTools)
|
||||
}
|
||||
if tt.isAllowedTools {
|
||||
if got := tt.toolChoice.GetAllowedToolNames(); !cmp.Equal(got, tt.allowedToolNames) {
|
||||
t.Errorf("GetAllowedToolNames() = %v, want %v", got, tt.allowedToolNames)
|
||||
}
|
||||
if got := tt.toolChoice.GetAllowedToolsMode(); got != tt.allowedToolsMode {
|
||||
t.Errorf("GetAllowedToolsMode() = %q, want %q", got, tt.allowedToolsMode)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyToolChoice(t *testing.T) {
|
||||
tools := []api.Tool{
|
||||
{
|
||||
Type: "function",
|
||||
Function: api.ToolFunction{
|
||||
Name: "get_weather",
|
||||
Description: "Get weather",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: map[string]api.ToolProperty{
|
||||
"location": {Type: []string{"string"}},
|
||||
},
|
||||
Required: []string{"location"},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Type: "function",
|
||||
Function: api.ToolFunction{
|
||||
Name: "search",
|
||||
Description: "Search the web",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: map[string]api.ToolProperty{
|
||||
"query": {Type: []string{"string"}},
|
||||
},
|
||||
Required: []string{"query"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
toolChoice *ToolChoice
|
||||
wantToolCount int
|
||||
wantFormat bool
|
||||
wantForced bool
|
||||
wantError bool
|
||||
wantToolNames []string
|
||||
}{
|
||||
{
|
||||
name: "nil (auto)",
|
||||
toolChoice: nil,
|
||||
wantToolCount: 2,
|
||||
wantFormat: false,
|
||||
wantForced: false,
|
||||
},
|
||||
{
|
||||
name: "auto string",
|
||||
toolChoice: &ToolChoice{Mode: "auto"},
|
||||
wantToolCount: 2,
|
||||
wantFormat: false,
|
||||
wantForced: false,
|
||||
},
|
||||
{
|
||||
name: "none string",
|
||||
toolChoice: &ToolChoice{Mode: "none"},
|
||||
wantToolCount: 0,
|
||||
wantFormat: false,
|
||||
wantForced: false,
|
||||
},
|
||||
{
|
||||
name: "required string",
|
||||
toolChoice: &ToolChoice{Mode: "required"},
|
||||
wantToolCount: 2,
|
||||
wantFormat: true,
|
||||
wantForced: true,
|
||||
},
|
||||
{
|
||||
name: "forced function",
|
||||
toolChoice: &ToolChoice{
|
||||
Object: &ToolChoiceObject{Type: "function", Name: "get_weather"},
|
||||
},
|
||||
wantToolCount: 1,
|
||||
wantFormat: true,
|
||||
wantForced: true,
|
||||
wantToolNames: []string{"get_weather"},
|
||||
},
|
||||
{
|
||||
name: "forced unknown function",
|
||||
toolChoice: &ToolChoice{
|
||||
Object: &ToolChoiceObject{Type: "function", Name: "unknown_func"},
|
||||
},
|
||||
wantError: true,
|
||||
},
|
||||
{
|
||||
name: "allowed_tools auto",
|
||||
toolChoice: &ToolChoice{
|
||||
Object: &ToolChoiceObject{
|
||||
Type: "allowed_tools",
|
||||
Mode: "auto",
|
||||
Tools: []ToolChoiceAllowedTool{
|
||||
{Type: "function", Name: "get_weather"},
|
||||
},
|
||||
},
|
||||
},
|
||||
wantToolCount: 1,
|
||||
wantFormat: false,
|
||||
wantForced: false,
|
||||
wantToolNames: []string{"get_weather"},
|
||||
},
|
||||
{
|
||||
name: "allowed_tools required",
|
||||
toolChoice: &ToolChoice{
|
||||
Object: &ToolChoiceObject{
|
||||
Type: "allowed_tools",
|
||||
Mode: "required",
|
||||
Tools: []ToolChoiceAllowedTool{
|
||||
{Type: "function", Name: "search"},
|
||||
},
|
||||
},
|
||||
},
|
||||
wantToolCount: 1,
|
||||
wantFormat: true,
|
||||
wantForced: true,
|
||||
wantToolNames: []string{"search"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
filteredTools, format, forced, err := ApplyToolChoice(tools, tt.toolChoice)
|
||||
|
||||
if tt.wantError {
|
||||
if err == nil {
|
||||
t.Errorf("expected error, got nil")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if len(filteredTools) != tt.wantToolCount {
|
||||
t.Errorf("got %d tools, want %d", len(filteredTools), tt.wantToolCount)
|
||||
}
|
||||
|
||||
if (format != nil) != tt.wantFormat {
|
||||
t.Errorf("format = %v, wantFormat = %v", format != nil, tt.wantFormat)
|
||||
}
|
||||
|
||||
if forced != tt.wantForced {
|
||||
t.Errorf("forced = %v, want %v", forced, tt.wantForced)
|
||||
}
|
||||
|
||||
if tt.wantToolNames != nil {
|
||||
gotNames := make([]string, len(filteredTools))
|
||||
for i, tool := range filteredTools {
|
||||
gotNames[i] = tool.Function.Name
|
||||
}
|
||||
if !cmp.Equal(gotNames, tt.wantToolNames) {
|
||||
t.Errorf("tool names = %v, want %v", gotNames, tt.wantToolNames)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFromChatRequest_WithToolChoice(t *testing.T) {
|
||||
tools := []api.Tool{
|
||||
{
|
||||
Type: "function",
|
||||
Function: api.ToolFunction{
|
||||
Name: "get_weather",
|
||||
Description: "Get weather",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: map[string]api.ToolProperty{
|
||||
"location": {Type: []string{"string"}},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Type: "function",
|
||||
Function: api.ToolFunction{
|
||||
Name: "search",
|
||||
Description: "Search",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: map[string]api.ToolProperty{
|
||||
"query": {Type: []string{"string"}},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
t.Run("tool_choice none removes tools", func(t *testing.T) {
|
||||
req := ChatCompletionRequest{
|
||||
Model: "test-model",
|
||||
Messages: []Message{{Role: "user", Content: "Hello"}},
|
||||
Tools: tools,
|
||||
ToolChoice: &ToolChoice{Mode: "none"},
|
||||
}
|
||||
|
||||
result, err := FromChatRequest(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if len(result.Tools) != 0 {
|
||||
t.Errorf("expected 0 tools, got %d", len(result.Tools))
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("tool_choice required adds format", func(t *testing.T) {
|
||||
req := ChatCompletionRequest{
|
||||
Model: "test-model",
|
||||
Messages: []Message{{Role: "user", Content: "Hello"}},
|
||||
Tools: tools,
|
||||
ToolChoice: &ToolChoice{Mode: "required"},
|
||||
}
|
||||
|
||||
result, err := FromChatRequest(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if len(result.Tools) != 2 {
|
||||
t.Errorf("expected 2 tools, got %d", len(result.Tools))
|
||||
}
|
||||
|
||||
if result.Format == nil {
|
||||
t.Error("expected format to be set for required tool_choice")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("tool_choice forced function filters and adds format", func(t *testing.T) {
|
||||
req := ChatCompletionRequest{
|
||||
Model: "test-model",
|
||||
Messages: []Message{{Role: "user", Content: "Hello"}},
|
||||
Tools: tools,
|
||||
ToolChoice: &ToolChoice{
|
||||
Object: &ToolChoiceObject{Type: "function", Name: "get_weather"},
|
||||
},
|
||||
}
|
||||
|
||||
result, err := FromChatRequest(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if len(result.Tools) != 1 {
|
||||
t.Errorf("expected 1 tool, got %d", len(result.Tools))
|
||||
}
|
||||
|
||||
if result.Tools[0].Function.Name != "get_weather" {
|
||||
t.Errorf("expected tool 'get_weather', got %q", result.Tools[0].Function.Name)
|
||||
}
|
||||
|
||||
if result.Format == nil {
|
||||
t.Error("expected format to be set for forced function")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("tool_choice unknown function returns error", func(t *testing.T) {
|
||||
req := ChatCompletionRequest{
|
||||
Model: "test-model",
|
||||
Messages: []Message{{Role: "user", Content: "Hello"}},
|
||||
Tools: tools,
|
||||
ToolChoice: &ToolChoice{
|
||||
Object: &ToolChoiceObject{Type: "function", Name: "unknown"},
|
||||
},
|
||||
}
|
||||
|
||||
_, err := FromChatRequest(req)
|
||||
if err == nil {
|
||||
t.Error("expected error for unknown function")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
|
|
|||
175
server/routes.go
175
server/routes.go
|
|
@ -1861,6 +1861,112 @@ func toolCallId() string {
|
|||
return "call_" + strings.ToLower(string(b))
|
||||
}
|
||||
|
||||
// findToolByName returns the tool with the given name, or nil if not found.
|
||||
func findToolByName(tools []api.Tool, name string) *api.Tool {
|
||||
for i := range tools {
|
||||
if tools[i].Function.Name == name {
|
||||
return &tools[i]
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// generateToolCallSchema creates a JSON schema that constrains output to valid tool calls.
|
||||
func generateToolCallSchema(tools []api.Tool) json.RawMessage {
|
||||
if len(tools) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
var oneOfSchemas []map[string]any
|
||||
for _, tool := range tools {
|
||||
toolSchema := map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"name": map[string]any{
|
||||
"type": "string",
|
||||
"const": tool.Function.Name,
|
||||
},
|
||||
"arguments": tool.Function.Parameters,
|
||||
},
|
||||
"required": []string{"name", "arguments"},
|
||||
"additionalProperties": false,
|
||||
}
|
||||
oneOfSchemas = append(oneOfSchemas, toolSchema)
|
||||
}
|
||||
|
||||
var schema map[string]any
|
||||
if len(oneOfSchemas) == 1 {
|
||||
schema = oneOfSchemas[0]
|
||||
} else {
|
||||
schema = map[string]any{
|
||||
"oneOf": oneOfSchemas,
|
||||
}
|
||||
}
|
||||
|
||||
bytes, err := json.Marshal(schema)
|
||||
if err != nil {
|
||||
slog.Error("failed to marshal tool call schema", "error", err)
|
||||
return nil
|
||||
}
|
||||
return bytes
|
||||
}
|
||||
|
||||
// generateForcedFunctionSchema creates a JSON schema for a specific forced function.
|
||||
func generateForcedFunctionSchema(tool api.Tool) json.RawMessage {
|
||||
schema := map[string]any{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"name": map[string]any{
|
||||
"type": "string",
|
||||
"const": tool.Function.Name,
|
||||
},
|
||||
"arguments": tool.Function.Parameters,
|
||||
},
|
||||
"required": []string{"name", "arguments"},
|
||||
"additionalProperties": false,
|
||||
}
|
||||
|
||||
bytes, err := json.Marshal(schema)
|
||||
if err != nil {
|
||||
slog.Error("failed to marshal forced function schema", "error", err)
|
||||
return nil
|
||||
}
|
||||
return bytes
|
||||
}
|
||||
|
||||
// parseForcedToolCallContent parses content as a forced tool call JSON.
|
||||
func parseForcedToolCallContent(content string, forcedToolName string) *api.ToolCall {
|
||||
if content == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
var toolCallJSON struct {
|
||||
Name string `json:"name"`
|
||||
Arguments map[string]any `json:"arguments"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal([]byte(content), &toolCallJSON); err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
name := toolCallJSON.Name
|
||||
if forcedToolName != "" {
|
||||
name = forcedToolName
|
||||
}
|
||||
|
||||
if name == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &api.ToolCall{
|
||||
ID: toolCallId(),
|
||||
Function: api.ToolCallFunction{
|
||||
Name: name,
|
||||
Arguments: toolCallJSON.Arguments,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) ChatHandler(c *gin.Context) {
|
||||
checkpointStart := time.Now()
|
||||
|
||||
|
|
@ -2075,6 +2181,36 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
|||
}
|
||||
}
|
||||
|
||||
// Handle tool_choice
|
||||
var forcedToolCall bool
|
||||
var forcedToolName string
|
||||
if req.ToolChoice != nil && len(req.Tools) > 0 {
|
||||
if req.ToolChoice.IsNone() {
|
||||
// "none" mode: don't pass any tools
|
||||
processedTools = nil
|
||||
} else if req.ToolChoice.IsRequired() {
|
||||
// "required" mode: must call at least one tool, generate JSON schema
|
||||
if req.Format == nil {
|
||||
req.Format = generateToolCallSchema(req.Tools)
|
||||
}
|
||||
forcedToolCall = true
|
||||
} else if req.ToolChoice.IsForcedFunction() {
|
||||
// Forced function mode: must call this specific function
|
||||
funcName := req.ToolChoice.GetForcedFunctionName()
|
||||
tool := findToolByName(req.Tools, funcName)
|
||||
if tool == nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("tool_choice references unknown function: %s", funcName)})
|
||||
return
|
||||
}
|
||||
if req.Format == nil {
|
||||
req.Format = generateForcedFunctionSchema(*tool)
|
||||
}
|
||||
processedTools = []api.Tool{*tool}
|
||||
forcedToolCall = true
|
||||
forcedToolName = funcName
|
||||
}
|
||||
}
|
||||
|
||||
truncate := req.Truncate == nil || *req.Truncate
|
||||
prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, msgs, processedTools, req.Think, truncate)
|
||||
if err != nil {
|
||||
|
|
@ -2206,6 +2342,15 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
|||
return
|
||||
}
|
||||
|
||||
// If tool_choice forced a tool call and builtinParser didn't find any,
|
||||
// try to parse the content as a forced tool call JSON
|
||||
if r.Done && forcedToolCall && len(res.Message.ToolCalls) == 0 && res.Message.Content != "" {
|
||||
if toolCall := parseForcedToolCallContent(res.Message.Content, forcedToolName); toolCall != nil {
|
||||
res.Message.ToolCalls = []api.ToolCall{*toolCall}
|
||||
res.Message.Content = ""
|
||||
}
|
||||
}
|
||||
|
||||
if res.Message.Content != "" || res.Message.Thinking != "" || len(res.Message.ToolCalls) > 0 || r.Done || len(res.Logprobs) > 0 {
|
||||
slog.Log(context.TODO(), logutil.LevelTrace, "builtin parser output", "parser", m.Config.Parser, "content", content, "thinking", thinking, "toolCalls", toolCalls, "done", r.Done)
|
||||
ch <- res
|
||||
|
|
@ -2235,7 +2380,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
|||
res.Message.Content = remainingContent
|
||||
}
|
||||
|
||||
if len(req.Tools) > 0 {
|
||||
if len(req.Tools) > 0 && toolParser != nil {
|
||||
toolCalls, content := toolParser.Add(res.Message.Content)
|
||||
if len(content) > 0 {
|
||||
res.Message.Content = content
|
||||
|
|
@ -2258,12 +2403,28 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
|||
|
||||
if r.Done {
|
||||
res.Message.Content = toolParser.Content()
|
||||
// If tool_choice forced a tool call, try to parse the buffered content
|
||||
if forcedToolCall && len(res.Message.ToolCalls) == 0 && res.Message.Content != "" {
|
||||
if toolCall := parseForcedToolCallContent(res.Message.Content, forcedToolName); toolCall != nil {
|
||||
res.Message.ToolCalls = []api.ToolCall{*toolCall}
|
||||
res.Message.Content = ""
|
||||
}
|
||||
}
|
||||
ch <- res
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// If tool_choice forced a tool call and we have content but no tool calls,
|
||||
// try to parse the content as a tool call (used when format constraint was applied)
|
||||
if r.Done && forcedToolCall && len(res.Message.ToolCalls) == 0 && res.Message.Content != "" {
|
||||
if toolCall := parseForcedToolCallContent(res.Message.Content, forcedToolName); toolCall != nil {
|
||||
res.Message.ToolCalls = []api.ToolCall{*toolCall}
|
||||
res.Message.Content = ""
|
||||
}
|
||||
}
|
||||
|
||||
ch <- res
|
||||
})
|
||||
if err != nil {
|
||||
|
|
@ -2353,6 +2514,18 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
|||
|
||||
if len(toolCalls) > 0 {
|
||||
resp.Message.ToolCalls = toolCalls
|
||||
// If we have tool calls from forced tool_choice, the "content" was actually
|
||||
// the JSON that got parsed into tool calls, so clear it
|
||||
if forcedToolCall {
|
||||
resp.Message.Content = ""
|
||||
}
|
||||
} else if forcedToolCall && resp.Message.Content != "" {
|
||||
// No tool calls were parsed in callbacks, but we have forced tool_choice.
|
||||
// Try to parse the accumulated content as a tool call JSON.
|
||||
if toolCall := parseForcedToolCallContent(resp.Message.Content, forcedToolName); toolCall != nil {
|
||||
resp.Message.ToolCalls = []api.ToolCall{*toolCall}
|
||||
resp.Message.Content = ""
|
||||
}
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, resp)
|
||||
|
|
|
|||
|
|
@ -978,3 +978,227 @@ func TestWaitForStream(t *testing.T) {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindToolByName(t *testing.T) {
|
||||
tools := []api.Tool{
|
||||
{
|
||||
Type: "function",
|
||||
Function: api.ToolFunction{
|
||||
Name: "get_weather",
|
||||
Description: "Get weather for a location",
|
||||
},
|
||||
},
|
||||
{
|
||||
Type: "function",
|
||||
Function: api.ToolFunction{
|
||||
Name: "search_web",
|
||||
Description: "Search the web",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
t.Run("found", func(t *testing.T) {
|
||||
tool := findToolByName(tools, "get_weather")
|
||||
if tool == nil {
|
||||
t.Fatal("expected to find tool")
|
||||
}
|
||||
if tool.Function.Name != "get_weather" {
|
||||
t.Errorf("expected get_weather, got %s", tool.Function.Name)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("not found", func(t *testing.T) {
|
||||
tool := findToolByName(tools, "nonexistent")
|
||||
if tool != nil {
|
||||
t.Error("expected nil for nonexistent tool")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("empty tools", func(t *testing.T) {
|
||||
tool := findToolByName([]api.Tool{}, "get_weather")
|
||||
if tool != nil {
|
||||
t.Error("expected nil for empty tools")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestGenerateToolCallSchema(t *testing.T) {
|
||||
tools := []api.Tool{
|
||||
{
|
||||
Type: "function",
|
||||
Function: api.ToolFunction{
|
||||
Name: "get_weather",
|
||||
Description: "Get weather",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Required: []string{"location"},
|
||||
Properties: map[string]api.ToolProperty{
|
||||
"location": {Type: api.PropertyType{"string"}},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
t.Run("single tool", func(t *testing.T) {
|
||||
schema := generateToolCallSchema(tools)
|
||||
if schema == nil {
|
||||
t.Fatal("expected schema, got nil")
|
||||
}
|
||||
|
||||
var parsed map[string]any
|
||||
if err := json.Unmarshal(schema, &parsed); err != nil {
|
||||
t.Fatalf("failed to parse schema: %v", err)
|
||||
}
|
||||
|
||||
// Should have properties with name and arguments
|
||||
props, ok := parsed["properties"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatal("expected properties in schema")
|
||||
}
|
||||
|
||||
if _, ok := props["name"]; !ok {
|
||||
t.Error("expected name property")
|
||||
}
|
||||
if _, ok := props["arguments"]; !ok {
|
||||
t.Error("expected arguments property")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("multiple tools", func(t *testing.T) {
|
||||
multiTools := append(tools, api.Tool{
|
||||
Type: "function",
|
||||
Function: api.ToolFunction{
|
||||
Name: "search_web",
|
||||
},
|
||||
})
|
||||
|
||||
schema := generateToolCallSchema(multiTools)
|
||||
if schema == nil {
|
||||
t.Fatal("expected schema, got nil")
|
||||
}
|
||||
|
||||
var parsed map[string]any
|
||||
if err := json.Unmarshal(schema, &parsed); err != nil {
|
||||
t.Fatalf("failed to parse schema: %v", err)
|
||||
}
|
||||
|
||||
// Should have oneOf for multiple tools
|
||||
if _, ok := parsed["oneOf"]; !ok {
|
||||
t.Error("expected oneOf for multiple tools")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("empty tools", func(t *testing.T) {
|
||||
schema := generateToolCallSchema([]api.Tool{})
|
||||
if schema != nil {
|
||||
t.Error("expected nil for empty tools")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestGenerateForcedFunctionSchema(t *testing.T) {
|
||||
tool := api.Tool{
|
||||
Type: "function",
|
||||
Function: api.ToolFunction{
|
||||
Name: "get_weather",
|
||||
Description: "Get weather",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Required: []string{"location"},
|
||||
Properties: map[string]api.ToolProperty{
|
||||
"location": {Type: api.PropertyType{"string"}},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
schema := generateForcedFunctionSchema(tool)
|
||||
if schema == nil {
|
||||
t.Fatal("expected schema, got nil")
|
||||
}
|
||||
|
||||
var parsed map[string]any
|
||||
if err := json.Unmarshal(schema, &parsed); err != nil {
|
||||
t.Fatalf("failed to parse schema: %v", err)
|
||||
}
|
||||
|
||||
// Check that name has const constraint
|
||||
props := parsed["properties"].(map[string]any)
|
||||
nameSchema := props["name"].(map[string]any)
|
||||
if nameSchema["const"] != "get_weather" {
|
||||
t.Errorf("expected const get_weather, got %v", nameSchema["const"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseForcedToolCallContent(t *testing.T) {
|
||||
t.Run("valid tool call JSON", func(t *testing.T) {
|
||||
content := `{"name": "get_weather", "arguments": {"location": "Paris"}}`
|
||||
toolCall := parseForcedToolCallContent(content, "")
|
||||
|
||||
if toolCall == nil {
|
||||
t.Fatal("expected tool call, got nil")
|
||||
}
|
||||
|
||||
if toolCall.Function.Name != "get_weather" {
|
||||
t.Errorf("expected get_weather, got %s", toolCall.Function.Name)
|
||||
}
|
||||
|
||||
if toolCall.Function.Arguments["location"] != "Paris" {
|
||||
t.Errorf("expected Paris, got %v", toolCall.Function.Arguments["location"])
|
||||
}
|
||||
|
||||
if toolCall.ID == "" {
|
||||
t.Error("expected non-empty tool call ID")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("forced tool name override", func(t *testing.T) {
|
||||
content := `{"name": "other_tool", "arguments": {"location": "Paris"}}`
|
||||
toolCall := parseForcedToolCallContent(content, "get_weather")
|
||||
|
||||
if toolCall == nil {
|
||||
t.Fatal("expected tool call, got nil")
|
||||
}
|
||||
|
||||
// Should use the forced name, not the one in JSON
|
||||
if toolCall.Function.Name != "get_weather" {
|
||||
t.Errorf("expected get_weather (forced), got %s", toolCall.Function.Name)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("empty content", func(t *testing.T) {
|
||||
toolCall := parseForcedToolCallContent("", "")
|
||||
if toolCall != nil {
|
||||
t.Error("expected nil for empty content")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("invalid JSON", func(t *testing.T) {
|
||||
toolCall := parseForcedToolCallContent("not json", "")
|
||||
if toolCall != nil {
|
||||
t.Error("expected nil for invalid JSON")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("missing name", func(t *testing.T) {
|
||||
content := `{"arguments": {"location": "Paris"}}`
|
||||
toolCall := parseForcedToolCallContent(content, "")
|
||||
if toolCall != nil {
|
||||
t.Error("expected nil when name is missing and not forced")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("missing name but forced", func(t *testing.T) {
|
||||
content := `{"arguments": {"location": "Paris"}}`
|
||||
toolCall := parseForcedToolCallContent(content, "get_weather")
|
||||
|
||||
if toolCall == nil {
|
||||
t.Fatal("expected tool call with forced name")
|
||||
}
|
||||
|
||||
if toolCall.Function.Name != "get_weather" {
|
||||
t.Errorf("expected get_weather, got %s", toolCall.Function.Name)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue