// Package openai provides core transformation logic for partial compatibility with the OpenAI REST API.
package openai

import (
	"bytes"
	"encoding/base64"
	"encoding/binary"
	"encoding/json"
	"errors"
	"fmt"
	"log/slog"
	"net/http"
	"slices"
	"strings"
	"time"

	"github.com/ollama/ollama/api"
	"github.com/ollama/ollama/types/model"
)

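// finishReasonToolCalls is the OpenAI finish_reason reported when a response
// contains tool calls; it is a package-level var so streaming chunks can take
// its address (see ToChunk).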
var finishReasonToolCalls = "tool_calls"

type Error struct {
	Message string  `json:"message"`
	Type    string  `json:"type"`
	Param   any     `json:"param"`
	Code    *string `json:"code"`
}

type ErrorResponse struct {
	Error Error `json:"error"`
}

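// Message is an OpenAI-compatible chat message. Content is typed as any
// because the API accepts either a plain string or an array of typed content
// parts (for example text and image_url parts).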
type Message struct {
	Role       string     `json:"role"`
	Content    any        `json:"content"`
	Reasoning  string     `json:"reasoning,omitempty"`
	ToolCalls  []ToolCall `json:"tool_calls,omitempty"`
	Name       string     `json:"name,omitempty"`
	ToolCallID string     `json:"tool_call_id,omitempty"`
}

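// ChoiceLogprobs wraps the per-token log probabilities attached to a choice.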
type ChoiceLogprobs struct {
	Content []api.Logprob `json:"content"`
}

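// Choice is a single message choice in a non-streaming chat completion.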
type Choice struct {
	Index        int             `json:"index"`
	Message      Message         `json:"message"`
	FinishReason *string         `json:"finish_reason"`
	Logprobs     *ChoiceLogprobs `json:"logprobs,omitempty"`
}

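// ChunkChoice is a single delta choice in a streaming chat completion chunk.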
type ChunkChoice struct {
	Index        int             `json:"index"`
	Delta        Message         `json:"delta"`
	FinishReason *string         `json:"finish_reason"`
	Logprobs     *ChoiceLogprobs `json:"logprobs,omitempty"`
}

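// CompleteChunkChoice is a single choice in a legacy text completion response.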
type CompleteChunkChoice struct {
	Text         string          `json:"text"`
	Index        int             `json:"index"`
	FinishReason *string         `json:"finish_reason"`
	Logprobs     *ChoiceLogprobs `json:"logprobs,omitempty"`
}

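// Usage reports prompt and completion token counts in the OpenAI format.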
type Usage struct {
	PromptTokens     int `json:"prompt_tokens"`
	CompletionTokens int `json:"completion_tokens"`
	TotalTokens      int `json:"total_tokens"`
}

type ResponseFormat struct {
	Type       string      `json:"type"`
	JsonSchema *JsonSchema `json:"json_schema,omitempty"`
}

type JsonSchema struct {
	Schema json.RawMessage `json:"schema"`
}

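// EmbedRequest mirrors the OpenAI embeddings request body. Input is typed as
// any so a request can carry either a single string or, as the OpenAI API
// allows, a list of inputs.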
type EmbedRequest struct {
	Input          any    `json:"input"`
	Model          string `json:"model"`
	Dimensions     int    `json:"dimensions,omitempty"`
	EncodingFormat string `json:"encoding_format,omitempty"` // "float" or "base64"
}

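// StreamOptions mirrors the OpenAI stream_options object; IncludeUsage
// requests a final usage chunk when streaming.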
type StreamOptions struct {
	IncludeUsage bool `json:"include_usage"`
}

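// Reasoning carries the nested form of the reasoning effort setting
// ("high", "medium", "low", or "none").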
type Reasoning struct {
	Effort string `json:"effort,omitempty"`
}

type ToolChoiceFunctionRef struct {
	Name string `json:"name"`
}

type ToolChoiceAllowedTool struct {
	Type string `json:"type"` // "function"
	Name string `json:"name"`
}

type ToolChoiceObject struct {
	Type     string                  `json:"type,omitempty"`
	Name     string                  `json:"name,omitempty"`
	Function *ToolChoiceFunctionRef  `json:"function,omitempty"`
	Mode     string                  `json:"mode,omitempty"`
	Tools    []ToolChoiceAllowedTool `json:"tools,omitempty"`
}

// ToolChoice represents the tool_choice field, which is either a bare string
// mode ("auto", "none", "required") or an object form.
type ToolChoice struct {
	Mode   string
	Object *ToolChoiceObject
}

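// Accepted tool_choice shapes include, for example (the function name here is
// illustrative, not part of the API):
//
//	"auto" | "none" | "required"
//	{"type": "function", "function": {"name": "get_weather"}}
//	{"type": "allowed_tools", "mode": "required",
//	 "tools": [{"type": "function", "name": "get_weather"}]}
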
// UnmarshalJSON handles both string and object forms of tool_choice.
func (tc *ToolChoice) UnmarshalJSON(data []byte) error {
	// Try string first
	var s string
	if err := json.Unmarshal(data, &s); err == nil {
		tc.Mode = s
		tc.Object = nil
		return nil
	}

	// Try object
	var obj ToolChoiceObject
	if err := json.Unmarshal(data, &obj); err != nil {
		return errors.New("invalid tool_choice: must be string or object")
	}
	tc.Object = &obj
	return nil
}

// MarshalJSON serializes ToolChoice back to JSON.
func (tc ToolChoice) MarshalJSON() ([]byte, error) {
	if tc.Object != nil {
		return json.Marshal(tc.Object)
	}
	return json.Marshal(tc.Mode)
}

// IsNone reports whether tool_choice is "none".
func (tc *ToolChoice) IsNone() bool {
	return tc != nil && tc.Mode == "none"
}

// IsRequired reports whether tool_choice is "required".
func (tc *ToolChoice) IsRequired() bool {
	return tc != nil && tc.Mode == "required"
}

// IsAuto reports whether tool_choice is "auto" or not specified.
func (tc *ToolChoice) IsAuto() bool {
	if tc == nil {
		return true
	}
	// If there's an object, check if it's allowed_tools with auto mode
	if tc.Object != nil {
		// allowed_tools with "auto" mode is still considered auto
		if tc.Object.Type == "allowed_tools" && (tc.Object.Mode == "" || tc.Object.Mode == "auto") {
			return true
		}
		// Any other object (forced function, allowed_tools with required) is not auto
		return false
	}
	return tc.Mode == "" || tc.Mode == "auto"
}

// IsForcedFunction reports whether a specific function is forced.
func (tc *ToolChoice) IsForcedFunction() bool {
	if tc == nil || tc.Object == nil {
		return false
	}
	return tc.Object.Type == "function" || tc.Object.Name != "" || tc.Object.Function != nil
}

// GetForcedFunctionName returns the name of the forced function, if any.
func (tc *ToolChoice) GetForcedFunctionName() string {
	if tc == nil || tc.Object == nil {
		return ""
	}
	if tc.Object.Name != "" {
		return tc.Object.Name
	}
	if tc.Object.Function != nil {
		return tc.Object.Function.Name
	}
	return ""
}

// IsAllowedTools reports whether tool_choice uses allowed_tools mode.
func (tc *ToolChoice) IsAllowedTools() bool {
	return tc != nil && tc.Object != nil && tc.Object.Type == "allowed_tools"
}

// GetAllowedToolNames returns the list of allowed tool names.
func (tc *ToolChoice) GetAllowedToolNames() []string {
	if !tc.IsAllowedTools() || tc.Object.Tools == nil {
		return nil
	}
	names := make([]string, len(tc.Object.Tools))
	for i, t := range tc.Object.Tools {
		names[i] = t.Name
	}
	return names
}

// GetAllowedToolsMode returns the mode for allowed_tools ("auto" or "required").
func (tc *ToolChoice) GetAllowedToolsMode() string {
	if !tc.IsAllowedTools() {
		return ""
	}
	if tc.Object.Mode == "" {
		return "auto"
	}
	return tc.Object.Mode
}

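// ChatCompletionRequest mirrors the OpenAI chat completions request body.
// DebugRenderOnly is an Ollama-specific extension carried in the
// underscore-prefixed _debug_render_only field.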
type ChatCompletionRequest struct {
	Model            string          `json:"model"`
	Messages         []Message       `json:"messages"`
	Stream           bool            `json:"stream"`
	StreamOptions    *StreamOptions  `json:"stream_options"`
	MaxTokens        *int            `json:"max_tokens"`
	Seed             *int            `json:"seed"`
	Stop             any             `json:"stop"`
	Temperature      *float64        `json:"temperature"`
	FrequencyPenalty *float64        `json:"frequency_penalty"`
	PresencePenalty  *float64        `json:"presence_penalty"`
	TopP             *float64        `json:"top_p"`
	ResponseFormat   *ResponseFormat `json:"response_format"`
	Tools            []api.Tool      `json:"tools"`
	ToolChoice       *ToolChoice     `json:"tool_choice,omitempty"`
	Reasoning        *Reasoning      `json:"reasoning,omitempty"`
	ReasoningEffort  *string         `json:"reasoning_effort,omitempty"`
	Logprobs         *bool           `json:"logprobs"`
	TopLogprobs      int             `json:"top_logprobs"`
	DebugRenderOnly  bool            `json:"_debug_render_only"`
}

type ChatCompletion struct {
	Id                string         `json:"id"`
	Object            string         `json:"object"`
	Created           int64          `json:"created"`
	Model             string         `json:"model"`
	SystemFingerprint string         `json:"system_fingerprint"`
	Choices           []Choice       `json:"choices"`
	Usage             Usage          `json:"usage,omitempty"`
	DebugInfo         *api.DebugInfo `json:"_debug_info,omitempty"`
}

type ChatCompletionChunk struct {
	Id                string        `json:"id"`
	Object            string        `json:"object"`
	Created           int64         `json:"created"`
	Model             string        `json:"model"`
	SystemFingerprint string        `json:"system_fingerprint"`
	Choices           []ChunkChoice `json:"choices"`
	Usage             *Usage        `json:"usage,omitempty"`
}

// TODO (https://github.com/ollama/ollama/issues/5259): support []string, []int and [][]int
type CompletionRequest struct {
	Model            string         `json:"model"`
	Prompt           string         `json:"prompt"`
	FrequencyPenalty float32        `json:"frequency_penalty"`
	MaxTokens        *int           `json:"max_tokens"`
	PresencePenalty  float32        `json:"presence_penalty"`
	Seed             *int           `json:"seed"`
	Stop             any            `json:"stop"`
	Stream           bool           `json:"stream"`
	StreamOptions    *StreamOptions `json:"stream_options"`
	Temperature      *float32       `json:"temperature"`
	TopP             float32        `json:"top_p"`
	Suffix           string         `json:"suffix"`
	Logprobs         *int           `json:"logprobs"`
	DebugRenderOnly  bool           `json:"_debug_render_only"`
}

type Completion struct {
	Id                string                `json:"id"`
	Object            string                `json:"object"`
	Created           int64                 `json:"created"`
	Model             string                `json:"model"`
	SystemFingerprint string                `json:"system_fingerprint"`
	Choices           []CompleteChunkChoice `json:"choices"`
	Usage             Usage                 `json:"usage,omitempty"`
}

type CompletionChunk struct {
	Id                string                `json:"id"`
	Object            string                `json:"object"`
	Created           int64                 `json:"created"`
	Choices           []CompleteChunkChoice `json:"choices"`
	Model             string                `json:"model"`
	SystemFingerprint string                `json:"system_fingerprint"`
	Usage             *Usage                `json:"usage,omitempty"`
}

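// ToolCall is the OpenAI wire format for a tool call. Function.Arguments is a
// JSON-encoded string rather than a structured object.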
type ToolCall struct {
	ID       string `json:"id"`
	Index    int    `json:"index"`
	Type     string `json:"type"`
	Function struct {
		Name      string `json:"name"`
		Arguments string `json:"arguments"`
	} `json:"function"`
}

type Model struct {
	Id      string `json:"id"`
	Object  string `json:"object"`
	Created int64  `json:"created"`
	OwnedBy string `json:"owned_by"`
}

type Embedding struct {
	Object    string `json:"object"`
	Embedding any    `json:"embedding"` // Can be []float32 (float format) or string (base64 format)
	Index     int    `json:"index"`
}

type ListCompletion struct {
	Object string  `json:"object"`
	Data   []Model `json:"data"`
}

type EmbeddingList struct {
	Object string         `json:"object"`
	Data   []Embedding    `json:"data"`
	Model  string         `json:"model"`
	Usage  EmbeddingUsage `json:"usage,omitempty"`
}

type EmbeddingUsage struct {
	PromptTokens int `json:"prompt_tokens"`
	TotalTokens  int `json:"total_tokens"`
}

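// NewError maps an HTTP status code and message to an OpenAI-style
// ErrorResponse.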
func NewError(code int, message string) ErrorResponse {
	var etype string
	switch code {
	case http.StatusBadRequest:
		etype = "invalid_request_error"
	case http.StatusNotFound:
		etype = "not_found_error"
	default:
		etype = "api_error"
	}

	return ErrorResponse{Error{Type: etype, Message: message}}
}

// ToUsage converts an api.ChatResponse to Usage.
func ToUsage(r api.ChatResponse) Usage {
	return Usage{
		PromptTokens:     r.Metrics.PromptEvalCount,
		CompletionTokens: r.Metrics.EvalCount,
		TotalTokens:      r.Metrics.PromptEvalCount + r.Metrics.EvalCount,
	}
}

// ToToolCalls converts api.ToolCall to OpenAI ToolCall format.
func ToToolCalls(tc []api.ToolCall) []ToolCall {
	toolCalls := make([]ToolCall, len(tc))
	for i, t := range tc {
		toolCalls[i].ID = t.ID
		toolCalls[i].Type = "function"
		toolCalls[i].Function.Name = t.Function.Name
		toolCalls[i].Index = t.Function.Index

		args, err := json.Marshal(t.Function.Arguments)
		if err != nil {
			slog.Error("could not marshal function arguments to json", "error", err)
			continue
		}

		toolCalls[i].Function.Arguments = string(args)
	}
	return toolCalls
}

// ToChatCompletion converts an api.ChatResponse to ChatCompletion.
func ToChatCompletion(id string, r api.ChatResponse) ChatCompletion {
	toolCalls := ToToolCalls(r.Message.ToolCalls)

	var logprobs *ChoiceLogprobs
	if len(r.Logprobs) > 0 {
		logprobs = &ChoiceLogprobs{Content: r.Logprobs}
	}

	return ChatCompletion{
		Id:                id,
		Object:            "chat.completion",
		Created:           r.CreatedAt.Unix(),
		Model:             r.Model,
		SystemFingerprint: "fp_ollama",
		Choices: []Choice{{
			Index:   0,
			Message: Message{Role: r.Message.Role, Content: r.Message.Content, ToolCalls: toolCalls, Reasoning: r.Message.Thinking},
			FinishReason: func(reason string) *string {
				if len(toolCalls) > 0 {
					reason = "tool_calls"
				}
				if len(reason) > 0 {
					return &reason
				}
				return nil
			}(r.DoneReason),
			Logprobs: logprobs,
		}},
		Usage:     ToUsage(r),
		DebugInfo: r.DebugInfo,
	}
}

// ToChunk converts an api.ChatResponse to ChatCompletionChunk.
func ToChunk(id string, r api.ChatResponse, toolCallSent bool) ChatCompletionChunk {
	toolCalls := ToToolCalls(r.Message.ToolCalls)

	var logprobs *ChoiceLogprobs
	if len(r.Logprobs) > 0 {
		logprobs = &ChoiceLogprobs{Content: r.Logprobs}
	}

	return ChatCompletionChunk{
		Id:                id,
		Object:            "chat.completion.chunk",
		Created:           time.Now().Unix(),
		Model:             r.Model,
		SystemFingerprint: "fp_ollama",
		Choices: []ChunkChoice{{
			Index: 0,
			Delta: Message{Role: "assistant", Content: r.Message.Content, ToolCalls: toolCalls, Reasoning: r.Message.Thinking},
			FinishReason: func(reason string) *string {
				if len(reason) > 0 {
					if toolCallSent || len(toolCalls) > 0 {
						return &finishReasonToolCalls
					}
					return &reason
				}
				return nil
			}(r.DoneReason),
			Logprobs: logprobs,
		}},
	}
}

// ToUsageGenerate converts an api.GenerateResponse to Usage.
func ToUsageGenerate(r api.GenerateResponse) Usage {
	return Usage{
		PromptTokens:     r.Metrics.PromptEvalCount,
		CompletionTokens: r.Metrics.EvalCount,
		TotalTokens:      r.Metrics.PromptEvalCount + r.Metrics.EvalCount,
	}
}

// ToCompletion converts an api.GenerateResponse to Completion.
func ToCompletion(id string, r api.GenerateResponse) Completion {
	return Completion{
		Id:                id,
		Object:            "text_completion",
		Created:           r.CreatedAt.Unix(),
		Model:             r.Model,
		SystemFingerprint: "fp_ollama",
		Choices: []CompleteChunkChoice{{
			Text:  r.Response,
			Index: 0,
			FinishReason: func(reason string) *string {
				if len(reason) > 0 {
					return &reason
				}
				return nil
			}(r.DoneReason),
		}},
		Usage: ToUsageGenerate(r),
	}
}

// ToCompleteChunk converts an api.GenerateResponse to CompletionChunk.
func ToCompleteChunk(id string, r api.GenerateResponse) CompletionChunk {
	return CompletionChunk{
		Id:                id,
		Object:            "text_completion",
		Created:           time.Now().Unix(),
		Model:             r.Model,
		SystemFingerprint: "fp_ollama",
		Choices: []CompleteChunkChoice{{
			Text:  r.Response,
			Index: 0,
			FinishReason: func(reason string) *string {
				if len(reason) > 0 {
					return &reason
				}
				return nil
			}(r.DoneReason),
		}},
	}
}

// ToListCompletion converts an api.ListResponse to ListCompletion.
func ToListCompletion(r api.ListResponse) ListCompletion {
	var data []Model
	for _, m := range r.Models {
		data = append(data, Model{
			Id:      m.Name,
			Object:  "model",
			Created: m.ModifiedAt.Unix(),
			OwnedBy: model.ParseName(m.Name).Namespace,
		})
	}

	return ListCompletion{
		Object: "list",
		Data:   data,
	}
}

// ToEmbeddingList converts an api.EmbedResponse to EmbeddingList.
// encodingFormat can be "float", "base64", or empty (defaults to "float").
func ToEmbeddingList(model string, r api.EmbedResponse, encodingFormat string) EmbeddingList {
	if r.Embeddings != nil {
		var data []Embedding
		for i, e := range r.Embeddings {
			var embedding any
			if strings.EqualFold(encodingFormat, "base64") {
				embedding = floatsToBase64(e)
			} else {
				embedding = e
			}

			data = append(data, Embedding{
				Object:    "embedding",
				Embedding: embedding,
				Index:     i,
			})
		}

		return EmbeddingList{
			Object: "list",
			Data:   data,
			Model:  model,
			Usage: EmbeddingUsage{
				PromptTokens: r.PromptEvalCount,
				TotalTokens:  r.PromptEvalCount,
			},
		}
	}

	return EmbeddingList{}
}

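// The base64 encoding packs the raw little-endian float32 bytes of the
// vector. A client could recover the floats with something like this
// (illustrative sketch, not part of this package):
//
//	raw, _ := base64.StdEncoding.DecodeString(s)
//	vec := make([]float32, len(raw)/4)
//	_ = binary.Read(bytes.NewReader(raw), binary.LittleEndian, vec)
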
// floatsToBase64 encodes a []float32 to a base64 string.
func floatsToBase64(floats []float32) string {
	var buf bytes.Buffer
	// Writing fixed-size values into an in-memory buffer cannot fail.
	_ = binary.Write(&buf, binary.LittleEndian, floats)
	return base64.StdEncoding.EncodeToString(buf.Bytes())
}

// ToModel converts an api.ShowResponse to Model.
func ToModel(r api.ShowResponse, m string) Model {
	return Model{
		Id:      m,
		Object:  "model",
		Created: r.ModifiedAt.Unix(),
		OwnedBy: model.ParseName(m).Namespace,
	}
}

// filterToolsByNames returns only the tools that match the given names.
func filterToolsByNames(tools []api.Tool, names []string) []api.Tool {
	if len(names) == 0 {
		return tools
	}
	nameSet := make(map[string]bool)
	for _, name := range names {
		nameSet[name] = true
	}
	var filtered []api.Tool
	for _, tool := range tools {
		if nameSet[tool.Function.Name] {
			filtered = append(filtered, tool)
		}
	}
	return filtered
}

// findToolByName returns the tool with the given name, or nil if not found.
func findToolByName(tools []api.Tool, name string) *api.Tool {
	for i := range tools {
		if tools[i].Function.Name == name {
			return &tools[i]
		}
	}
	return nil
}

// generateToolCallSchema creates a JSON schema that constrains output to valid tool calls.
func generateToolCallSchema(tools []api.Tool) json.RawMessage {
	if len(tools) == 0 {
		return nil
	}

	// Build a schema that allows any of the tools, using oneOf with each
	// tool's specific parameter schema.
	var oneOfSchemas []map[string]any
	for _, tool := range tools {
		toolSchema := map[string]any{
			"type": "object",
			"properties": map[string]any{
				"name": map[string]any{
					"type":  "string",
					"const": tool.Function.Name,
				},
				"arguments": tool.Function.Parameters,
			},
			"required":             []string{"name", "arguments"},
			"additionalProperties": false,
		}
		oneOfSchemas = append(oneOfSchemas, toolSchema)
	}

	var schema map[string]any
	if len(oneOfSchemas) == 1 {
		// Single tool - use its schema directly
		schema = oneOfSchemas[0]
	} else {
		// Multiple tools - use oneOf
		schema = map[string]any{
			"oneOf": oneOfSchemas,
		}
	}

	b, err := json.Marshal(schema)
	if err != nil {
		slog.Error("failed to marshal tool call schema", "error", err)
		return nil
	}
	return b
}

// generateForcedFunctionSchema creates a JSON schema for a specific forced function.
func generateForcedFunctionSchema(tool api.Tool) json.RawMessage {
	schema := map[string]any{
		"type": "object",
		"properties": map[string]any{
			"name": map[string]any{
				"type":  "string",
				"const": tool.Function.Name,
			},
			"arguments": tool.Function.Parameters,
		},
		"required":             []string{"name", "arguments"},
		"additionalProperties": false,
	}

	b, err := json.Marshal(schema)
	if err != nil {
		slog.Error("failed to marshal forced function schema", "error", err)
		return nil
	}
	return b
}

// ApplyToolChoice processes tool_choice and returns filtered tools and an optional format schema.
// Returns:
//   - filteredTools: the tools to pass to the model (may be empty for "none")
//   - format: JSON schema to constrain output (for "required" or a forced function)
//   - forcedToolCall: true if the response should be parsed as a tool call
//   - error: if tool_choice references a non-existent tool
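//
// For example (variable names are illustrative), a caller might do:
//
//	tools, format, forced, err := ApplyToolChoice(reqTools, reqToolChoice)
//
// where tool_choice "none" yields no tools, and "required" yields all tools
// plus a schema constraining the output to a tool call.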
func ApplyToolChoice(tools []api.Tool, toolChoice *ToolChoice) (filteredTools []api.Tool, format json.RawMessage, forcedToolCall bool, err error) {
	// Default: auto mode, return all tools without format constraint
	if toolChoice == nil || toolChoice.IsAuto() {
		// Check for allowed_tools with auto mode
		if toolChoice != nil && toolChoice.IsAllowedTools() {
			allowedNames := toolChoice.GetAllowedToolNames()
			filteredTools = filterToolsByNames(tools, allowedNames)
			if toolChoice.GetAllowedToolsMode() == "required" {
				format = generateToolCallSchema(filteredTools)
				forcedToolCall = true
			}
			return filteredTools, format, forcedToolCall, nil
		}
		return tools, nil, false, nil
	}

	// "none" mode: don't pass any tools
	if toolChoice.IsNone() {
		return nil, nil, false, nil
	}

	// "required" mode: must call at least one tool
	if toolChoice.IsRequired() {
		format = generateToolCallSchema(tools)
		return tools, format, true, nil
	}

	// Forced function mode
	if toolChoice.IsForcedFunction() {
		funcName := toolChoice.GetForcedFunctionName()
		if funcName == "" {
			return nil, nil, false, errors.New("tool_choice function name is required")
		}

		tool := findToolByName(tools, funcName)
		if tool == nil {
			return nil, nil, false, fmt.Errorf("tool_choice references unknown function: %s", funcName)
		}

		format = generateForcedFunctionSchema(*tool)
		return []api.Tool{*tool}, format, true, nil
	}

	// allowed_tools mode (already handled in the IsAuto check, but handle an explicit type here)
	if toolChoice.IsAllowedTools() {
		allowedNames := toolChoice.GetAllowedToolNames()
		filteredTools = filterToolsByNames(tools, allowedNames)
		if toolChoice.GetAllowedToolsMode() == "required" {
			format = generateToolCallSchema(filteredTools)
			forcedToolCall = true
		}
		return filteredTools, format, forcedToolCall, nil
	}

	// Unknown mode, default to auto
	return tools, nil, false, nil
}

// FromChatRequest converts a ChatCompletionRequest to api.ChatRequest.
func FromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
	var messages []api.Message
	for _, msg := range r.Messages {
		toolName := ""
		if strings.ToLower(msg.Role) == "tool" {
			toolName = msg.Name
			if toolName == "" && msg.ToolCallID != "" {
				toolName = nameFromToolCallID(r.Messages, msg.ToolCallID)
			}
		}
		switch content := msg.Content.(type) {
		case string:
			toolCalls, err := FromCompletionToolCall(msg.ToolCalls)
			if err != nil {
				return nil, err
			}
			messages = append(messages, api.Message{Role: msg.Role, Content: content, Thinking: msg.Reasoning, ToolCalls: toolCalls, ToolName: toolName, ToolCallID: msg.ToolCallID})
		case []any:
			for _, c := range content {
				data, ok := c.(map[string]any)
				if !ok {
					return nil, errors.New("invalid message format")
				}
				switch data["type"] {
				case "text":
					text, ok := data["text"].(string)
					if !ok {
						return nil, errors.New("invalid message format")
					}
					messages = append(messages, api.Message{Role: msg.Role, Content: text})
				case "image_url":
					var url string
					if urlMap, ok := data["image_url"].(map[string]any); ok {
						if url, ok = urlMap["url"].(string); !ok {
							return nil, errors.New("invalid message format")
						}
					} else {
						if url, ok = data["image_url"].(string); !ok {
							return nil, errors.New("invalid message format")
						}
					}

					img, err := decodeImageURL(url)
					if err != nil {
						return nil, err
					}

					messages = append(messages, api.Message{Role: msg.Role, Images: []api.ImageData{img}})
				default:
					return nil, errors.New("invalid message format")
				}
			}
			// since we might have added multiple messages above, if we have tool
			// calls we'll add them to the last message
			if len(messages) > 0 && len(msg.ToolCalls) > 0 {
				toolCalls, err := FromCompletionToolCall(msg.ToolCalls)
				if err != nil {
					return nil, err
				}
				messages[len(messages)-1].ToolCalls = toolCalls
				messages[len(messages)-1].ToolName = toolName
				messages[len(messages)-1].ToolCallID = msg.ToolCallID
				messages[len(messages)-1].Thinking = msg.Reasoning
			}
		default:
			// content is only optional if tool calls are present
			if msg.ToolCalls == nil {
				return nil, fmt.Errorf("invalid message content type: %T", content)
			}

			toolCalls, err := FromCompletionToolCall(msg.ToolCalls)
			if err != nil {
				return nil, err
			}
			messages = append(messages, api.Message{Role: msg.Role, Thinking: msg.Reasoning, ToolCalls: toolCalls, ToolCallID: msg.ToolCallID})
		}
	}

	options := make(map[string]any)

	switch stop := r.Stop.(type) {
	case string:
		options["stop"] = []string{stop}
	case []any:
		var stops []string
		for _, s := range stop {
			if str, ok := s.(string); ok {
				stops = append(stops, str)
			}
		}
		options["stop"] = stops
	}

	if r.MaxTokens != nil {
		options["num_predict"] = *r.MaxTokens
	}

	if r.Temperature != nil {
		options["temperature"] = *r.Temperature
	} else {
		options["temperature"] = 1.0
	}

	if r.Seed != nil {
		options["seed"] = *r.Seed
	}

	if r.FrequencyPenalty != nil {
		options["frequency_penalty"] = *r.FrequencyPenalty
	}

	if r.PresencePenalty != nil {
		options["presence_penalty"] = *r.PresencePenalty
	}

	if r.TopP != nil {
		options["top_p"] = *r.TopP
	} else {
		options["top_p"] = 1.0
	}

	var format json.RawMessage
	if r.ResponseFormat != nil {
		switch strings.ToLower(strings.TrimSpace(r.ResponseFormat.Type)) {
		// Support the old "json_object" type for OpenAI compatibility
		case "json_object":
			format = json.RawMessage(`"json"`)
		case "json_schema":
			if r.ResponseFormat.JsonSchema != nil {
				format = r.ResponseFormat.JsonSchema.Schema
			}
		}
	}

	// Apply tool_choice to filter tools and potentially set a format constraint
	filteredTools, toolChoiceFormat, _, err := ApplyToolChoice(r.Tools, r.ToolChoice)
	if err != nil {
		return nil, err
	}

	// If tool_choice requires a format constraint and no explicit
	// response_format was set, apply the tool call schema
	if toolChoiceFormat != nil && format == nil {
		format = toolChoiceFormat
	}

	var think *api.ThinkValue
	var effort string

	if r.Reasoning != nil {
		effort = r.Reasoning.Effort
	} else if r.ReasoningEffort != nil {
		effort = *r.ReasoningEffort
	}

	if effort != "" {
		if !slices.Contains([]string{"high", "medium", "low", "none"}, effort) {
			return nil, fmt.Errorf("invalid reasoning value: '%s' (must be \"high\", \"medium\", \"low\", or \"none\")", effort)
		}

		if effort == "none" {
			think = &api.ThinkValue{Value: false}
		} else {
			think = &api.ThinkValue{Value: effort}
		}
	}

	return &api.ChatRequest{
		Model:           r.Model,
		Messages:        messages,
		Format:          format,
		Options:         options,
		Stream:          &r.Stream,
		Tools:           filteredTools,
		Think:           think,
		Logprobs:        r.Logprobs != nil && *r.Logprobs,
		TopLogprobs:     r.TopLogprobs,
		DebugRenderOnly: r.DebugRenderOnly,
	}, nil
}

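// nameFromToolCallID resolves a tool name by looking up the assistant tool
// call whose ID matches toolCallID in the given messages.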
func nameFromToolCallID(messages []Message, toolCallID string) string {
	// iterate backwards to be more resilient to duplicate tool call IDs (this
	// follows "last one wins")
	for i := len(messages) - 1; i >= 0; i-- {
		msg := messages[i]
		for _, tc := range msg.ToolCalls {
			if tc.ID == toolCallID {
				return tc.Function.Name
			}
		}
	}
	return ""
}

// decodeImageURL decodes a base64 data URI into raw image bytes.
func decodeImageURL(url string) (api.ImageData, error) {
	types := []string{"jpeg", "jpg", "png", "webp"}

	// Support a blank mime type to match /api/chat's behavior of accepting unadorned base64
	if strings.HasPrefix(url, "data:;base64,") {
		url = strings.TrimPrefix(url, "data:;base64,")
	} else {
		valid := false
		for _, t := range types {
			prefix := "data:image/" + t + ";base64,"
			if strings.HasPrefix(url, prefix) {
				url = strings.TrimPrefix(url, prefix)
				valid = true
				break
			}
		}
		if !valid {
			return nil, errors.New("invalid image input")
		}
	}

	img, err := base64.StdEncoding.DecodeString(url)
	if err != nil {
		return nil, errors.New("invalid image input")
	}
	return img, nil
}

// FromCompletionToolCall converts OpenAI ToolCall format to api.ToolCall.
func FromCompletionToolCall(toolCalls []ToolCall) ([]api.ToolCall, error) {
	apiToolCalls := make([]api.ToolCall, len(toolCalls))
	for i, tc := range toolCalls {
		apiToolCalls[i].ID = tc.ID
		apiToolCalls[i].Function.Name = tc.Function.Name
		err := json.Unmarshal([]byte(tc.Function.Arguments), &apiToolCalls[i].Function.Arguments)
		if err != nil {
			return nil, errors.New("invalid tool call arguments")
		}
	}

	return apiToolCalls, nil
}

// FromCompleteRequest converts a CompletionRequest to api.GenerateRequest.
func FromCompleteRequest(r CompletionRequest) (api.GenerateRequest, error) {
	options := make(map[string]any)

	switch stop := r.Stop.(type) {
	case string:
		options["stop"] = []string{stop}
	case []any:
		var stops []string
		for _, s := range stop {
			if str, ok := s.(string); ok {
				stops = append(stops, str)
			} else {
				return api.GenerateRequest{}, fmt.Errorf("invalid type for 'stop' field: %T", s)
			}
		}
		options["stop"] = stops
	}

	if r.MaxTokens != nil {
		options["num_predict"] = *r.MaxTokens
	}

	if r.Temperature != nil {
		options["temperature"] = *r.Temperature
	} else {
		options["temperature"] = 1.0
	}

	if r.Seed != nil {
		options["seed"] = *r.Seed
	}

	options["frequency_penalty"] = r.FrequencyPenalty

	options["presence_penalty"] = r.PresencePenalty

	if r.TopP != 0.0 {
		options["top_p"] = r.TopP
	} else {
		options["top_p"] = 1.0
	}

	var logprobs bool
	var topLogprobs int
	if r.Logprobs != nil && *r.Logprobs > 0 {
		logprobs = true
		topLogprobs = *r.Logprobs
	}

	return api.GenerateRequest{
		Model:           r.Model,
		Prompt:          r.Prompt,
		Options:         options,
		Stream:          &r.Stream,
		Suffix:          r.Suffix,
		Logprobs:        logprobs,
		TopLogprobs:     topLogprobs,
		DebugRenderOnly: r.DebugRenderOnly,
	}, nil
}