From 717fa7a44ade794194a431bcff097e9c9d015f0d Mon Sep 17 00:00:00 2001 From: ParthSareen Date: Thu, 15 May 2025 14:23:14 -0700 Subject: [PATCH] Add sentinel errors, remove redundant calls --- server/routes.go | 9 ++- tools/tools.go | 159 ++++++++++++++++++++------------------------ tools/tools_test.go | 47 +++++++------ tools/utils.go | 14 ++-- tools/utils_test.go | 6 +- 5 files changed, 116 insertions(+), 119 deletions(-) diff --git a/server/routes.go b/server/routes.go index c1868ff89..76039f79d 100644 --- a/server/routes.go +++ b/server/routes.go @@ -1485,7 +1485,7 @@ func (s *Server) ChatHandler(c *gin.Context) { slog.Debug("chat request", "images", len(images), "prompt", prompt) - var toolParser *tools.Parser + var toolParser tools.Parser if len(req.Tools) > 0 { toolParser, err = tools.NewParser(m.Template.Template) if err != nil { @@ -1525,6 +1525,9 @@ func (s *Server) ChatHandler(c *gin.Context) { } if len(req.Tools) > 0 && !toolParser.Done { + if r.Content == "" { + return + } toolCalls, content, err := toolParser.Add(r.Content) if err == nil { if len(content) > 0 { @@ -1533,9 +1536,9 @@ func (s *Server) ChatHandler(c *gin.Context) { } else if len(toolCalls) > 0 { res.Message.ToolCalls = toolCalls res.Message.Content = "" - } else { - return } + } else if errors.Is(err, tools.ErrAccumulateMore) { + return } } ch <- res diff --git a/tools/tools.go b/tools/tools.go index 1e8a397cf..c36b2a03a 100644 --- a/tools/tools.go +++ b/tools/tools.go @@ -14,25 +14,35 @@ import ( "github.com/ollama/ollama/template" ) +// Sentinel errors for parsing states +var ( + ErrPartialPrefix = errors.New("partial prefix detected") + + ErrPrefixNotFound = errors.New("prefix not found") + + ErrInvalidToolCall = errors.New("invalid tool call format") + + ErrAccumulateMore = errors.New("need to accumulate more content") +) + type Parser struct { - greedyParse bool - prefixFound bool - prefixPartial bool - tmpl *gotmpl.Template - sb *strings.Builder - prefix string - index int - name 
string - arguments string - Done bool + greedyParse bool + prefixFound bool + tmpl gotmpl.Template + sb strings.Builder + prefix string + index int + name string + arguments string + Done bool } -// parseJSONToolCalls attempts to parse a JSON string into a slice of ToolCalls. +// parseJSONToolCalls attempts to parse a JSON string into a slice of ToolCalls. // It first tries to incrementally decode the JSON to handle partial inputs. // Returns: // - []api.ToolCall: The parsed tool calls if successful -// - bool: True if JSON is incomplete and needs more input -func (p *Parser) parseJSONToolCalls(s string) ([]api.ToolCall, bool) { +// - error: ErrAccumulateMore if JSON is incomplete, ErrInvalidToolCall if invalid, or nil if successful +func (p *Parser) parseJSONToolCalls(s string) ([]api.ToolCall, error) { // First try incremental decoding to handle partial JSON dec := jsontext.NewDecoder(strings.NewReader(s)) if got, err := dec.ReadValue(); err == nil { @@ -41,22 +51,18 @@ func (p *Parser) parseJSONToolCalls(s string) ([]api.ToolCall, bool) { // Attempt full unmarshal of the JSON var resp any - err := jsonv2.Unmarshal([]byte(s), &resp) - if err != nil { - // Handle incomplete JSON cases - if errors.Is(err, io.ErrUnexpectedEOF) || err.Error() == "unexpected end of JSON input" { - slog.Debug("incomplete JSON detected", "input", s) - return nil, true - } + if err := jsonv2.Unmarshal([]byte(s), &resp); errors.Is(err, io.ErrUnexpectedEOF) { + slog.Debug("incomplete JSON detected", "input", s) + return nil, ErrAccumulateMore + } else if err != nil { slog.Debug("failed to unmarshal response", "error", err) - return nil, false + return nil, ErrInvalidToolCall } // Collect all nested objects that could contain tool calls - var objs []map[string]any - objs = append(objs, collect(resp)...) 
+ objs := collect(resp) if len(objs) == 0 { - return nil, false + return nil, ErrInvalidToolCall } var toolCalls []api.ToolCall @@ -75,59 +81,56 @@ func (p *Parser) parseJSONToolCalls(s string) ([]api.ToolCall, bool) { // Valid JSON, no tool calls found if len(toolCalls) == 0 { - return nil, false + return nil, ErrInvalidToolCall } - return toolCalls, false + return toolCalls, nil } // checkPrefix processes a string to find and handle a prefix pattern. // // Returns: // - The processed string with prefix removed if found -// - Whether the prefix was found at the start of the string -// - Whether to continue parsing -func (p *Parser) checkPrefix(s string) (string, bool, bool) { +// - error: ErrAccumulateMore if the prefix overlaps the end of the input or appears mid-string (the remainder is buffered for the next call), or nil otherwise +func (p *Parser) checkPrefix(s string) (string, error) { // Keep original for overlap checks original := s s = strings.TrimSpace(s) if s == "" { - return "", false, true + return "", nil } // If no prefix defined, just return trimmed string if p.prefix == "" { - return s, false, true + return s, nil } // Check for prefix at start of string if processedStr, hasPrefix := strings.CutPrefix(s, p.prefix); hasPrefix { // Found prefix at start - accumulate for potential tool - return processedStr, true, true + p.prefixFound = true + return processedStr, nil } // Check if prefix overlaps end of string if overlap := suffixOverlap(original, p.prefix); overlap > 0 { - p.prefixPartial = true // Return everything except overlapping portion p.sb.Reset() p.sb.WriteString(original[len(original)-overlap:]) - return original[0 : len(original)-overlap], false, false + return original[0 : len(original)-overlap], ErrAccumulateMore } // Check if prefix appears in middle of string if idx := strings.Index(original, p.prefix); idx != -1 { - p.prefixPartial = true // Save remainder starting at prefix for next pass p.sb.Reset() p.sb.WriteString(strings.TrimSpace(original[idx:])) // Return everything before 
prefix - return original[:idx], false, false + return original[:idx], ErrAccumulateMore } - // No prefix found - p.prefixPartial = false - return s, false, true + // No partial prefix found + return s, nil } // Add processes a string input to parse tool calls and content. @@ -136,57 +139,45 @@ func (p *Parser) checkPrefix(s string) (string, bool, bool) { // Returns: // - tools: Any parsed tool calls // - content: Non-tool call content -// - err: Error if parsing failed +// - error: One of the sentinel errors or nil if successful func (p *Parser) Add(s string) (tools []api.ToolCall, content string, err error) { - if len(s) == 0 { - return nil, "", nil - } - p.sb.WriteString(s) s = p.sb.String() // Check for prefix pattern in input - s, prefixFound, shouldContinue := p.checkPrefix(s) - if !shouldContinue { + s, err = p.checkPrefix(s) + if err != nil { if s != "" { // Return content before prefix return nil, s, nil } // Need more input to complete prefix - return nil, "", nil - } - - // Update prefix found state - if prefixFound { - p.prefixFound = true + return nil, "", ErrAccumulateMore } // Exit if prefix exists in template, greedy parsing is off, and prefix not found if !p.greedyParse && !p.prefixFound { p.sb.Reset() - return nil, "", errors.New("prefix not found") + return nil, "", ErrPrefixNotFound } - toolCalls, isPartial := p.parseJSONToolCalls(s) - if isPartial { - // Need more input to complete JSON - return nil, "", nil - } - - // Do not try greedy parsing if partial JSON not found - p.greedyParse = false - - // Handle invalid tool call format - if len(toolCalls) == 0 { - p.sb.Reset() - if p.prefix == "" { - p.Done = true + toolCalls, err := p.parseJSONToolCalls(s) + if err != nil { + if errors.Is(err, ErrAccumulateMore) { + return nil, "", err + } else { + p.sb.Reset() + // Do not try greedy parsing if JSON not found + p.greedyParse = false + if p.prefix == "" { + p.Done = true + } + if p.prefixFound { + // Drop tokens since prefix was found + return nil, 
"", ErrAccumulateMore + } + return nil, s, nil } - if p.prefixFound { - // Drop tokens since prefix was found - return nil, "", nil - } - return nil, s, nil } for _, tc := range toolCalls { @@ -207,21 +198,15 @@ func (p *Parser) Add(s string) (tools []api.ToolCall, content string, err error) // prefix, and field names from the template to use for parsing tool calls from model output. // // Returns an error if the template does not contain valid tool call formatting. -func NewParser(templateToProcess *gotmpl.Template) (*Parser, error) { +func NewParser(templateToProcess *gotmpl.Template) (Parser, error) { parsed, err := template.Parse(templateToProcess.Root.String()) if err != nil { - return nil, err - } - if parsed == nil { - return nil, errors.New("failed to parse template") + return Parser{}, err } - tt, tc := toolTemplate(parsed) - if !tc { - return nil, errors.New("failed to find tool calls in template") - } - if tt == nil { - return nil, errors.New("failed to find tool template") + tt, err := toolTemplate(parsed) + if err != nil { + return Parser{}, err } tp := toolPrefix(templateToProcess) @@ -229,12 +214,12 @@ func NewParser(templateToProcess *gotmpl.Template) (*Parser, error) { name, arguments, err := extractToolArgs(tt) if err != nil { - return nil, err + return Parser{}, err } - return &Parser{ - tmpl: tt, - sb: &strings.Builder{}, + return Parser{ + tmpl: *tt, + sb: strings.Builder{}, prefix: tp, greedyParse: true, name: name, diff --git a/tools/tools_test.go b/tools/tools_test.go index bc436f838..50b3afc85 100644 --- a/tools/tools_test.go +++ b/tools/tools_test.go @@ -3,6 +3,7 @@ package tools import ( "bytes" "encoding/json" + "errors" "fmt" "os" "path/filepath" @@ -196,6 +197,13 @@ func TestParseToolCalls(t *testing.T) { expectedToolCall: []api.ToolCall{t1, t2}, expectedTokens: "some tokens after call", }, + { + name: "qwen2.5-coder tool calls with initial text", + model: "qwen2.5-coder", + output: `some tokens before call [{"name": 
"get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}, {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, + expectedToolCall: []api.ToolCall{}, + expectedTokens: `some tokens before call [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}, {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, + }, { name: "qwen2.5 tool calls with prefix and trailing text", model: "qwen2.5-coder", @@ -203,6 +211,13 @@ func TestParseToolCalls(t *testing.T) { expectedToolCall: []api.ToolCall{t1, t2}, expectedTokens: "", }, + { + name: "qwen2.5 tool calls with prefix and initial text", + model: "qwen2.5-coder", + output: `some tokens before call [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}, {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}] `, + expectedToolCall: []api.ToolCall{t1, t2}, + expectedTokens: "some tokens before call", + }, { name: "qwen2.5 tool calls without prefix and valid tool call", model: "qwen2.5-coder", @@ -356,9 +371,9 @@ func TestParseToolCalls(t *testing.T) { } else if len(toolCalls) > 0 { got = append(got, toolCalls...) 
add = false - } else { - add = false } + } else if errors.Is(err, ErrAccumulateMore) { + add = false } } if add { @@ -388,8 +403,7 @@ func TestParseJSONToolCalls(t *testing.T) { input string parser *Parser wantToolCalls []api.ToolCall - wantPartial bool - wantValid bool + wantErr error }{ { name: "valid single tool call", @@ -405,32 +419,28 @@ func TestParseJSONToolCalls(t *testing.T) { }, }, }, - wantPartial: false, - wantValid: true, + wantErr: nil, }, { name: "incomplete JSON", input: `{"name": "test_tool", "arguments": {"arg1": `, parser: &Parser{name: "name", arguments: "arguments"}, wantToolCalls: nil, - wantPartial: true, - wantValid: false, + wantErr: ErrAccumulateMore, }, { name: "invalid JSON", input: `not json at all`, parser: &Parser{name: "name", arguments: "arguments"}, wantToolCalls: nil, - wantPartial: false, - wantValid: false, + wantErr: ErrInvalidToolCall, }, { name: "missing required fields", input: `{"other": "field"}`, parser: &Parser{name: "name", arguments: "arguments"}, wantToolCalls: nil, - wantPartial: false, - wantValid: false, + wantErr: ErrInvalidToolCall, }, { name: "multiple tool calls in array", @@ -457,21 +467,20 @@ func TestParseJSONToolCalls(t *testing.T) { }, }, }, - wantPartial: false, - wantValid: true, + wantErr: nil, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - gotCalls, gotPartial := tt.parser.parseJSONToolCalls(tt.input) + gotCalls, err := tt.parser.parseJSONToolCalls(tt.input) - if gotPartial != tt.wantPartial { - t.Errorf("parseJSONToolCalls() partial = %v, want %v", gotPartial, tt.wantPartial) + if err != tt.wantErr { + t.Errorf("parseJSONToolCalls() error = %v, want %v", err, tt.wantErr) } - if len(gotCalls) != 0 != tt.wantValid { - t.Errorf("parseJSONToolCalls() valid = %v, want %v", len(gotCalls) == 0, tt.wantValid) + if len(gotCalls) != 0 && tt.wantErr != nil { + t.Errorf("parseJSONToolCalls() valid = %v, want %v", len(gotCalls) == 0, tt.wantErr == nil) } if diff := cmp.Diff(gotCalls, 
tt.wantToolCalls); diff != "" { diff --git a/tools/utils.go b/tools/utils.go index 64f88658c..5af7818e5 100644 --- a/tools/utils.go +++ b/tools/utils.go @@ -142,8 +142,8 @@ func toolPrefix(tmpl *gotmpl.Template) string { // // Returns: // - *gotmpl.Template: The subtree containing the .ToolCalls range -// - bool: Whether a .ToolCalls range was found in the template -func toolTemplate(t *template.Template) (*gotmpl.Template, bool) { +// - error: Non-nil if no .ToolCalls range was found in the template +func toolTemplate(t *template.Template) (*gotmpl.Template, error) { tmpl := t.Subtree(func(n parse.Node) bool { if t, ok := n.(*parse.RangeNode); ok { return slices.Contains(template.Identifiers(t.Pipe), "ToolCalls") @@ -153,20 +153,20 @@ func toolTemplate(t *template.Template) (*gotmpl.Template, bool) { }) if tmpl == nil { - return nil, false + return nil, errors.New("failed to find tool template") } - return tmpl, true + return tmpl, nil } // suffixOverlap returns the length of the longest suffix overlap between two strings // // Returns: // - int: The length of the longest suffix overlap -func suffixOverlap(s, delim string) int { - max := min(len(delim), len(s)) +func suffixOverlap(s, prefix string) int { + max := min(len(prefix), len(s)) for i := max; i > 0; i-- { - if strings.HasSuffix(s, delim[:i]) { + if strings.HasSuffix(s, prefix[:i]) { return i } } diff --git a/tools/utils_test.go b/tools/utils_test.go index c082fde02..51fcff0e2 100644 --- a/tools/utils_test.go +++ b/tools/utils_test.go @@ -192,9 +192,9 @@ func TestToolTemplate(t *testing.T) { t.Fatalf("failed to parse template: %v", err) } - _, got := toolTemplate(parsed) - if got != tt.want { - t.Errorf("toolTemplate() = %v; want %v", got, tt.want) + _, err = toolTemplate(parsed) + if err != nil && tt.want { + t.Errorf("toolTemplate() = %v; want %v", err, tt.want) } }) }