Add sentinel errors, remove redundant calls

This commit is contained in:
ParthSareen 2025-05-15 14:23:14 -07:00
parent 53f7946fb6
commit 717fa7a44a
5 changed files with 116 additions and 119 deletions

View File

@ -1485,7 +1485,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
slog.Debug("chat request", "images", len(images), "prompt", prompt)
var toolParser *tools.Parser
var toolParser tools.Parser
if len(req.Tools) > 0 {
toolParser, err = tools.NewParser(m.Template.Template)
if err != nil {
@ -1525,6 +1525,9 @@ func (s *Server) ChatHandler(c *gin.Context) {
}
if len(req.Tools) > 0 && !toolParser.Done {
if r.Content == "" {
return
}
toolCalls, content, err := toolParser.Add(r.Content)
if err == nil {
if len(content) > 0 {
@ -1533,9 +1536,9 @@ func (s *Server) ChatHandler(c *gin.Context) {
} else if len(toolCalls) > 0 {
res.Message.ToolCalls = toolCalls
res.Message.Content = ""
} else {
return
}
} else if errors.Is(err, tools.ErrAccumulateMore) {
return
}
}
ch <- res

View File

@ -14,25 +14,35 @@ import (
"github.com/ollama/ollama/template"
)
// Sentinel errors for parsing states.
// They are plain errors.New values; compare against them with errors.Is.
var (
// ErrPartialPrefix describes a tool-call prefix that has only partially arrived.
// NOTE(review): no code visible in this change returns it — confirm it is still needed.
ErrPartialPrefix = errors.New("partial prefix detected")
// ErrPrefixNotFound is returned by Parser.Add when greedy parsing is off and
// the template's tool-call prefix has not been seen in the input.
ErrPrefixNotFound = errors.New("prefix not found")
// ErrInvalidToolCall is returned by parseJSONToolCalls when the input is not
// valid JSON or is valid JSON that yields no tool calls.
ErrInvalidToolCall = errors.New("invalid tool call format")
// ErrAccumulateMore signals that the parser needs more input before it can
// decide: the JSON is incomplete, or the prefix was only partially matched
// or found mid-string and buffered for the next call.
ErrAccumulateMore = errors.New("need to accumulate more content")
)
type Parser struct {
greedyParse bool
prefixFound bool
prefixPartial bool
tmpl *gotmpl.Template
sb *strings.Builder
prefix string
index int
name string
arguments string
Done bool
greedyParse bool
prefixFound bool
tmpl gotmpl.Template
sb strings.Builder
prefix string
index int
name string
arguments string
Done bool
}
// parseJSONToolCalls attempts to parse a JSON string into a slice of ToolCalls.
// parseJSONToolCalls attempts to parse a JSON string into a slice of ToolCalls.
// It first tries to incrementally decode the JSON to handle partial inputs.
// Returns:
// - []api.ToolCall: The parsed tool calls if successful
// - bool: True if JSON is incomplete and needs more input
func (p *Parser) parseJSONToolCalls(s string) ([]api.ToolCall, bool) {
// - error: ErrAccumulateMore if the JSON is incomplete, ErrInvalidToolCall if it is invalid or contains no tool calls, or nil if successful
func (p *Parser) parseJSONToolCalls(s string) ([]api.ToolCall, error) {
// First try incremental decoding to handle partial JSON
dec := jsontext.NewDecoder(strings.NewReader(s))
if got, err := dec.ReadValue(); err == nil {
@ -41,22 +51,18 @@ func (p *Parser) parseJSONToolCalls(s string) ([]api.ToolCall, bool) {
// Attempt full unmarshal of the JSON
var resp any
err := jsonv2.Unmarshal([]byte(s), &resp)
if err != nil {
// Handle incomplete JSON cases
if errors.Is(err, io.ErrUnexpectedEOF) || err.Error() == "unexpected end of JSON input" {
slog.Debug("incomplete JSON detected", "input", s)
return nil, true
}
if err := jsonv2.Unmarshal([]byte(s), &resp); errors.Is(err, io.ErrUnexpectedEOF) {
slog.Debug("incomplete JSON detected", "input", s)
return nil, ErrAccumulateMore
} else if err != nil {
slog.Debug("failed to unmarshal response", "error", err)
return nil, false
return nil, ErrInvalidToolCall
}
// Collect all nested objects that could contain tool calls
var objs []map[string]any
objs = append(objs, collect(resp)...)
objs := collect(resp)
if len(objs) == 0 {
return nil, false
return nil, ErrInvalidToolCall
}
var toolCalls []api.ToolCall
@ -75,59 +81,56 @@ func (p *Parser) parseJSONToolCalls(s string) ([]api.ToolCall, bool) {
// Valid JSON, no tool calls found
if len(toolCalls) == 0 {
return nil, false
return nil, ErrInvalidToolCall
}
return toolCalls, false
return toolCalls, nil
}
// checkPrefix processes a string to find and handle a prefix pattern.
//
// Returns:
// - The processed string with prefix removed if found
// - Whether the prefix was found at the start of the string
// - Whether to continue parsing
func (p *Parser) checkPrefix(s string) (string, bool, bool) {
// - error: ErrAccumulateMore if the prefix overlaps the end of the input or appears mid-string (the remainder is buffered for the next call), or nil otherwise
func (p *Parser) checkPrefix(s string) (string, error) {
// Keep original for overlap checks
original := s
s = strings.TrimSpace(s)
if s == "" {
return "", false, true
return "", nil
}
// If no prefix defined, just return trimmed string
if p.prefix == "" {
return s, false, true
return s, nil
}
// Check for prefix at start of string
if processedStr, hasPrefix := strings.CutPrefix(s, p.prefix); hasPrefix {
// Found prefix at start - accumulate for potential tool
return processedStr, true, true
p.prefixFound = true
return processedStr, nil
}
// Check if prefix overlaps end of string
if overlap := suffixOverlap(original, p.prefix); overlap > 0 {
p.prefixPartial = true
// Return everything except overlapping portion
p.sb.Reset()
p.sb.WriteString(original[len(original)-overlap:])
return original[0 : len(original)-overlap], false, false
return original[0 : len(original)-overlap], ErrAccumulateMore
}
// Check if prefix appears in middle of string
if idx := strings.Index(original, p.prefix); idx != -1 {
p.prefixPartial = true
// Save remainder starting at prefix for next pass
p.sb.Reset()
p.sb.WriteString(strings.TrimSpace(original[idx:]))
// Return everything before prefix
return original[:idx], false, false
return original[:idx], ErrAccumulateMore
}
// No prefix found
p.prefixPartial = false
return s, false, true
// No partial prefix found
return s, nil
}
// Add processes a string input to parse tool calls and content.
@ -136,57 +139,45 @@ func (p *Parser) checkPrefix(s string) (string, bool, bool) {
// Returns:
// - tools: Any parsed tool calls
// - content: Non-tool call content
// - err: Error if parsing failed
// - error: One of the sentinel errors or nil if successful
func (p *Parser) Add(s string) (tools []api.ToolCall, content string, err error) {
if len(s) == 0 {
return nil, "", nil
}
p.sb.WriteString(s)
s = p.sb.String()
// Check for prefix pattern in input
s, prefixFound, shouldContinue := p.checkPrefix(s)
if !shouldContinue {
s, err = p.checkPrefix(s)
if err != nil {
if s != "" {
// Return content before prefix
return nil, s, nil
}
// Need more input to complete prefix
return nil, "", nil
}
// Update prefix found state
if prefixFound {
p.prefixFound = true
return nil, "", ErrAccumulateMore
}
// Exit if prefix exists in template, greedy parsing is off, and prefix not found
if !p.greedyParse && !p.prefixFound {
p.sb.Reset()
return nil, "", errors.New("prefix not found")
return nil, "", ErrPrefixNotFound
}
toolCalls, isPartial := p.parseJSONToolCalls(s)
if isPartial {
// Need more input to complete JSON
return nil, "", nil
}
// Do not try greedy parsing if partial JSON not found
p.greedyParse = false
// Handle invalid tool call format
if len(toolCalls) == 0 {
p.sb.Reset()
if p.prefix == "" {
p.Done = true
toolCalls, err := p.parseJSONToolCalls(s)
if err != nil {
if errors.Is(err, ErrAccumulateMore) {
return nil, "", err
} else {
p.sb.Reset()
// Do not try greedy parsing if JSON not found
p.greedyParse = false
if p.prefix == "" {
p.Done = true
}
if p.prefixFound {
// Drop tokens since prefix was found
return nil, "", ErrAccumulateMore
}
return nil, s, nil
}
if p.prefixFound {
// Drop tokens since prefix was found
return nil, "", nil
}
return nil, s, nil
}
for _, tc := range toolCalls {
@ -207,21 +198,15 @@ func (p *Parser) Add(s string) (tools []api.ToolCall, content string, err error)
// prefix, and field names from the template to use for parsing tool calls from model output.
//
// Returns an error if the template does not contain valid tool call formatting.
func NewParser(templateToProcess *gotmpl.Template) (*Parser, error) {
func NewParser(templateToProcess *gotmpl.Template) (Parser, error) {
parsed, err := template.Parse(templateToProcess.Root.String())
if err != nil {
return nil, err
}
if parsed == nil {
return nil, errors.New("failed to parse template")
return Parser{}, err
}
tt, tc := toolTemplate(parsed)
if !tc {
return nil, errors.New("failed to find tool calls in template")
}
if tt == nil {
return nil, errors.New("failed to find tool template")
tt, err := toolTemplate(parsed)
if err != nil {
return Parser{}, err
}
tp := toolPrefix(templateToProcess)
@ -229,12 +214,12 @@ func NewParser(templateToProcess *gotmpl.Template) (*Parser, error) {
name, arguments, err := extractToolArgs(tt)
if err != nil {
return nil, err
return Parser{}, err
}
return &Parser{
tmpl: tt,
sb: &strings.Builder{},
return Parser{
tmpl: *tt,
sb: strings.Builder{},
prefix: tp,
greedyParse: true,
name: name,

View File

@ -3,6 +3,7 @@ package tools
import (
"bytes"
"encoding/json"
"errors"
"fmt"
"os"
"path/filepath"
@ -196,6 +197,13 @@ func TestParseToolCalls(t *testing.T) {
expectedToolCall: []api.ToolCall{t1, t2},
expectedTokens: "some tokens after call",
},
{
name: "qwen2.5-coder tool calls with initial text",
model: "qwen2.5-coder",
output: `some tokens before call [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}, {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
expectedToolCall: []api.ToolCall{},
expectedTokens: `some tokens before call [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}, {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
},
{
name: "qwen2.5 tool calls with prefix and trailing text",
model: "qwen2.5-coder",
@ -203,6 +211,13 @@ func TestParseToolCalls(t *testing.T) {
expectedToolCall: []api.ToolCall{t1, t2},
expectedTokens: "",
},
{
name: "qwen2.5 tool calls with prefix and initial text",
model: "qwen2.5-coder",
output: `some tokens before call <tool_call> [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}, {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}] </tool_call>`,
expectedToolCall: []api.ToolCall{t1, t2},
expectedTokens: "some tokens before call",
},
{
name: "qwen2.5 tool calls without prefix and valid tool call",
model: "qwen2.5-coder",
@ -356,9 +371,9 @@ func TestParseToolCalls(t *testing.T) {
} else if len(toolCalls) > 0 {
got = append(got, toolCalls...)
add = false
} else {
add = false
}
} else if errors.Is(err, ErrAccumulateMore) {
add = false
}
}
if add {
@ -388,8 +403,7 @@ func TestParseJSONToolCalls(t *testing.T) {
input string
parser *Parser
wantToolCalls []api.ToolCall
wantPartial bool
wantValid bool
wantErr error
}{
{
name: "valid single tool call",
@ -405,32 +419,28 @@ func TestParseJSONToolCalls(t *testing.T) {
},
},
},
wantPartial: false,
wantValid: true,
wantErr: nil,
},
{
name: "incomplete JSON",
input: `{"name": "test_tool", "arguments": {"arg1": `,
parser: &Parser{name: "name", arguments: "arguments"},
wantToolCalls: nil,
wantPartial: true,
wantValid: false,
wantErr: ErrAccumulateMore,
},
{
name: "invalid JSON",
input: `not json at all`,
parser: &Parser{name: "name", arguments: "arguments"},
wantToolCalls: nil,
wantPartial: false,
wantValid: false,
wantErr: ErrInvalidToolCall,
},
{
name: "missing required fields",
input: `{"other": "field"}`,
parser: &Parser{name: "name", arguments: "arguments"},
wantToolCalls: nil,
wantPartial: false,
wantValid: false,
wantErr: ErrInvalidToolCall,
},
{
name: "multiple tool calls in array",
@ -457,21 +467,20 @@ func TestParseJSONToolCalls(t *testing.T) {
},
},
},
wantPartial: false,
wantValid: true,
wantErr: nil,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gotCalls, gotPartial := tt.parser.parseJSONToolCalls(tt.input)
gotCalls, err := tt.parser.parseJSONToolCalls(tt.input)
if gotPartial != tt.wantPartial {
t.Errorf("parseJSONToolCalls() partial = %v, want %v", gotPartial, tt.wantPartial)
if err != tt.wantErr {
t.Errorf("parseJSONToolCalls() error = %v, want %v", err, tt.wantErr)
}
if len(gotCalls) != 0 != tt.wantValid {
t.Errorf("parseJSONToolCalls() valid = %v, want %v", len(gotCalls) == 0, tt.wantValid)
if len(gotCalls) != 0 && tt.wantErr != nil {
t.Errorf("parseJSONToolCalls() valid = %v, want %v", len(gotCalls) == 0, tt.wantErr == nil)
}
if diff := cmp.Diff(gotCalls, tt.wantToolCalls); diff != "" {

View File

@ -142,8 +142,8 @@ func toolPrefix(tmpl *gotmpl.Template) string {
//
// Returns:
// - *gotmpl.Template: The subtree containing the .ToolCalls range
// - bool: Whether a .ToolCalls range was found in the template
func toolTemplate(t *template.Template) (*gotmpl.Template, bool) {
// - error: Non-nil if no .ToolCalls range was found in the template
func toolTemplate(t *template.Template) (*gotmpl.Template, error) {
tmpl := t.Subtree(func(n parse.Node) bool {
if t, ok := n.(*parse.RangeNode); ok {
return slices.Contains(template.Identifiers(t.Pipe), "ToolCalls")
@ -153,20 +153,20 @@ func toolTemplate(t *template.Template) (*gotmpl.Template, bool) {
})
if tmpl == nil {
return nil, false
return nil, errors.New("failed to find tool template")
}
return tmpl, true
return tmpl, nil
}
// suffixOverlap returns the length of the longest suffix of s that equals a leading portion of the second argument
//
// Returns:
// - int: The length of the longest suffix overlap
func suffixOverlap(s, delim string) int {
max := min(len(delim), len(s))
func suffixOverlap(s, prefix string) int {
max := min(len(prefix), len(s))
for i := max; i > 0; i-- {
if strings.HasSuffix(s, delim[:i]) {
if strings.HasSuffix(s, prefix[:i]) {
return i
}
}

View File

@ -192,9 +192,9 @@ func TestToolTemplate(t *testing.T) {
t.Fatalf("failed to parse template: %v", err)
}
_, got := toolTemplate(parsed)
if got != tt.want {
t.Errorf("toolTemplate() = %v; want %v", got, tt.want)
_, err = toolTemplate(parsed)
if err != nil && tt.want {
t.Errorf("toolTemplate() = %v; want %v", err, tt.want)
}
})
}