From 08e9ebc7f633bae173257d57f04658ae0e59fe55 Mon Sep 17 00:00:00 2001 From: ParthSareen Date: Tue, 21 Oct 2025 19:01:12 -0700 Subject: [PATCH] routes/types: add tool call id --- api/types.go | 1 + openai/openai.go | 19 +++++---- openai/openai_test.go | 70 ++++++++++++++++++++++++++++++++++ server/routes.go | 17 +++++++++ server/routes_generate_test.go | 30 ++++++++++++++- 5 files changed, 126 insertions(+), 11 deletions(-) diff --git a/api/types.go b/api/types.go index 1483c844f..95b054f7e 100644 --- a/api/types.go +++ b/api/types.go @@ -200,6 +200,7 @@ func (m *Message) UnmarshalJSON(b []byte) error { } type ToolCall struct { + ID string `json:"id,omitempty"` Function ToolCallFunction `json:"function"` } diff --git a/openai/openai.go b/openai/openai.go index 23e9522f0..8411c1533 100644 --- a/openai/openai.go +++ b/openai/openai.go @@ -7,7 +7,6 @@ import ( "errors" "fmt" "log/slog" - "math/rand" "net/http" "slices" "strings" @@ -226,20 +225,20 @@ func ToUsage(r api.ChatResponse) Usage { } } -func toolCallId() string { - const letterBytes = "abcdefghijklmnopqrstuvwxyz0123456789" - b := make([]byte, 8) - for i := range b { - b[i] = letterBytes[rand.Intn(len(letterBytes))] - } - return "call_" + strings.ToLower(string(b)) -} +// func toolCallId() string { +// const letterBytes = "abcdefghijklmnopqrstuvwxyz0123456789" +// b := make([]byte, 8) +// for i := range b { +// b[i] = letterBytes[rand.Intn(len(letterBytes))] +// } +// return "call_" + strings.ToLower(string(b)) +// } // ToToolCalls converts api.ToolCall to OpenAI ToolCall format func ToToolCalls(tc []api.ToolCall) []ToolCall { toolCalls := make([]ToolCall, len(tc)) for i, tc := range tc { - toolCalls[i].ID = toolCallId() + toolCalls[i].ID = tc.ID toolCalls[i].Type = "function" toolCalls[i].Function.Name = tc.Function.Name toolCalls[i].Index = tc.Function.Index diff --git a/openai/openai_test.go b/openai/openai_test.go index 0f1a877f4..b054bdb74 100644 --- a/openai/openai_test.go +++ b/openai/openai_test.go @@ -4,7 +4,10 @@ import ( "encoding/base64" "testing" + "github.com/google/go-cmp/cmp" + "github.com/ollama/ollama/api" + "golang.org/x/exp/slices" ) const ( @@ -148,3 +151,70 @@ func TestNewError(t *testing.T) { } } } + +func TestToToolCallsPreservesIDs(t *testing.T) { + original := []api.ToolCall{ + { + ID: "call_abc123", + Function: api.ToolCallFunction{ + Index: 2, + Name: "get_weather", + Arguments: api.ToolCallFunctionArguments{ + "location": "Seattle", + }, + }, + }, + { + ID: "call_def456", + Function: api.ToolCallFunction{ + Index: 7, + Name: "get_time", + Arguments: api.ToolCallFunctionArguments{ + "timezone": "UTC", + }, + }, + }, + } + + toolCalls := slices.Clone(original) + got := ToToolCalls(toolCalls) + + if len(got) != len(original) { + t.Fatalf("expected %d tool calls, got %d", len(original), len(got)) + } + + expected := []ToolCall{ + { + ID: "call_abc123", + Type: "function", + Index: 2, + Function: struct { + Name string `json:"name"` + Arguments string `json:"arguments"` + }{ + Name: "get_weather", + Arguments: `{"location":"Seattle"}`, + }, + }, + { + ID: "call_def456", + Type: "function", + Index: 7, + Function: struct { + Name string `json:"name"` + Arguments string `json:"arguments"` + }{ + Name: "get_time", + Arguments: `{"timezone":"UTC"}`, + }, + }, + } + + if diff := cmp.Diff(expected, got); diff != "" { + t.Errorf("tool calls mismatch (-want +got):\n%s", diff) + } + + if diff := cmp.Diff(original, toolCalls); diff != "" { + t.Errorf("input tool calls mutated (-want +got):\n%s", diff) + } +} diff --git a/server/routes.go b/server/routes.go index 80c00cb69..8e1c294e3 100644 --- a/server/routes.go +++ b/server/routes.go @@ -13,6 +13,7 @@ import ( "io/fs" "log/slog" "math" + "math/rand" "net" "net/http" "net/netip" @@ -1803,6 +1804,15 @@ func (s *Server) PsHandler(c *gin.Context) { c.JSON(http.StatusOK, api.ProcessResponse{Models: models}) } +func toolCallId() string { + const letterBytes = "abcdefghijklmnopqrstuvwxyz0123456789" + b := make([]byte, 8) + for i := range b { + b[i] = letterBytes[rand.Intn(len(letterBytes))] + } + return "call_" + strings.ToLower(string(b)) +} + func (s *Server) ChatHandler(c *gin.Context) { checkpointStart := time.Now() @@ -2111,6 +2121,9 @@ func (s *Server) ChatHandler(c *gin.Context) { res.Message.Content = content res.Message.Thinking = thinking + for i := range toolCalls { + toolCalls[i].ID = toolCallId() + } res.Message.ToolCalls = toolCalls tb.WriteString(thinking) @@ -2155,8 +2168,12 @@ func (s *Server) ChatHandler(c *gin.Context) { if len(content) > 0 { res.Message.Content = content } else if len(toolCalls) > 0 { + for i := range toolCalls { + toolCalls[i].ID = toolCallId() + } res.Message.ToolCalls = toolCalls res.Message.Content = "" + } else if res.Message.Thinking != "" { // don't return } else { diff --git a/server/routes_generate_test.go b/server/routes_generate_test.go index 75d4f012e..54588bcd6 100644 --- a/server/routes_generate_test.go +++ b/server/routes_generate_test.go @@ -472,6 +472,14 @@ func TestGenerateChat(t *testing.T) { t.Error("expected tool calls, got nil") } + gotToolCall := resp.Message.ToolCalls[0] + if gotToolCall.ID == "" { + t.Error("expected tool call ID to be populated") + } + if !strings.HasPrefix(gotToolCall.ID, "call_") { + t.Errorf("expected tool call ID to have call_ prefix, got %q", gotToolCall.ID) + } + expectedToolCall := api.ToolCall{ Function: api.ToolCallFunction{ Name: "get_weather", @@ -482,7 +490,8 @@ func TestGenerateChat(t *testing.T) { }, } - if diff := cmp.Diff(resp.Message.ToolCalls[0], expectedToolCall); diff != "" { + expectedToolCall.ID = gotToolCall.ID + if diff := cmp.Diff(gotToolCall, expectedToolCall); diff != "" { t.Errorf("tool call mismatch (-got +want):\n%s", diff) } }) @@ -587,6 +596,17 @@ func TestGenerateChat(t *testing.T) { t.Fatal(err) } + if len(resp.Message.ToolCalls) > 0 { + for _, call := range resp.Message.ToolCalls { + if call.ID == "" { + t.Fatal("expected streaming tool call to have an ID") + } + if !strings.HasPrefix(call.ID, "call_") { + t.Fatalf("expected streaming tool call ID to have call_ prefix, got %q", call.ID) + } + } + } + if resp.Done { if len(resp.Message.ToolCalls) != 1 { t.Errorf("expected 1 tool call in final response, got %d", len(resp.Message.ToolCalls)) @@ -605,6 +625,14 @@ func TestGenerateChat(t *testing.T) { }, } + if finalToolCall.ID == "" { + t.Fatal("expected final tool call to have an ID") + } + if !strings.HasPrefix(finalToolCall.ID, "call_") { + t.Fatalf("expected final tool call ID to have call_ prefix, got %q", finalToolCall.ID) + } + + expectedToolCall.ID = finalToolCall.ID if diff := cmp.Diff(finalToolCall, expectedToolCall); diff != "" { t.Errorf("final tool call mismatch (-got +want):\n%s", diff) }