Add sentinel errors, remove redundant calls
This commit is contained in:
parent
53f7946fb6
commit
717fa7a44a
|
|
@ -1485,7 +1485,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
|||
|
||||
slog.Debug("chat request", "images", len(images), "prompt", prompt)
|
||||
|
||||
var toolParser *tools.Parser
|
||||
var toolParser tools.Parser
|
||||
if len(req.Tools) > 0 {
|
||||
toolParser, err = tools.NewParser(m.Template.Template)
|
||||
if err != nil {
|
||||
|
|
@ -1525,6 +1525,9 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
|||
}
|
||||
|
||||
if len(req.Tools) > 0 && !toolParser.Done {
|
||||
if r.Content == "" {
|
||||
return
|
||||
}
|
||||
toolCalls, content, err := toolParser.Add(r.Content)
|
||||
if err == nil {
|
||||
if len(content) > 0 {
|
||||
|
|
@ -1533,9 +1536,9 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
|||
} else if len(toolCalls) > 0 {
|
||||
res.Message.ToolCalls = toolCalls
|
||||
res.Message.Content = ""
|
||||
} else {
|
||||
return
|
||||
}
|
||||
} else if errors.Is(err, tools.ErrAccumulateMore) {
|
||||
return
|
||||
}
|
||||
}
|
||||
ch <- res
|
||||
|
|
|
|||
159
tools/tools.go
159
tools/tools.go
|
|
@ -14,25 +14,35 @@ import (
|
|||
"github.com/ollama/ollama/template"
|
||||
)
|
||||
|
||||
// Sentinel errors for parsing states
|
||||
var (
|
||||
ErrPartialPrefix = errors.New("partial prefix detected")
|
||||
|
||||
ErrPrefixNotFound = errors.New("prefix not found")
|
||||
|
||||
ErrInvalidToolCall = errors.New("invalid tool call format")
|
||||
|
||||
ErrAccumulateMore = errors.New("need to accumulate more content")
|
||||
)
|
||||
|
||||
type Parser struct {
|
||||
greedyParse bool
|
||||
prefixFound bool
|
||||
prefixPartial bool
|
||||
tmpl *gotmpl.Template
|
||||
sb *strings.Builder
|
||||
prefix string
|
||||
index int
|
||||
name string
|
||||
arguments string
|
||||
Done bool
|
||||
greedyParse bool
|
||||
prefixFound bool
|
||||
tmpl gotmpl.Template
|
||||
sb strings.Builder
|
||||
prefix string
|
||||
index int
|
||||
name string
|
||||
arguments string
|
||||
Done bool
|
||||
}
|
||||
|
||||
// parseJSONToolCalls attempts to parse a JSON string into a slice of ToolCalls.
|
||||
// parseJSONToolCalls attempts to parse a JSON string into a slice of ToolCalls.
|
||||
// It first tries to incrementally decode the JSON to handle partial inputs.
|
||||
// Returns:
|
||||
// - []api.ToolCall: The parsed tool calls if successful
|
||||
// - bool: True if JSON is incomplete and needs more input
|
||||
func (p *Parser) parseJSONToolCalls(s string) ([]api.ToolCall, bool) {
|
||||
// - error: ErrAccumulateMore if JSON is incomplete, ErrInvalidToolCall if invalid or no tool calls are found, or nil if successful
|
||||
func (p *Parser) parseJSONToolCalls(s string) ([]api.ToolCall, error) {
|
||||
// First try incremental decoding to handle partial JSON
|
||||
dec := jsontext.NewDecoder(strings.NewReader(s))
|
||||
if got, err := dec.ReadValue(); err == nil {
|
||||
|
|
@ -41,22 +51,18 @@ func (p *Parser) parseJSONToolCalls(s string) ([]api.ToolCall, bool) {
|
|||
|
||||
// Attempt full unmarshal of the JSON
|
||||
var resp any
|
||||
err := jsonv2.Unmarshal([]byte(s), &resp)
|
||||
if err != nil {
|
||||
// Handle incomplete JSON cases
|
||||
if errors.Is(err, io.ErrUnexpectedEOF) || err.Error() == "unexpected end of JSON input" {
|
||||
slog.Debug("incomplete JSON detected", "input", s)
|
||||
return nil, true
|
||||
}
|
||||
if err := jsonv2.Unmarshal([]byte(s), &resp); errors.Is(err, io.ErrUnexpectedEOF) {
|
||||
slog.Debug("incomplete JSON detected", "input", s)
|
||||
return nil, ErrAccumulateMore
|
||||
} else if err != nil {
|
||||
slog.Debug("failed to unmarshal response", "error", err)
|
||||
return nil, false
|
||||
return nil, ErrInvalidToolCall
|
||||
}
|
||||
|
||||
// Collect all nested objects that could contain tool calls
|
||||
var objs []map[string]any
|
||||
objs = append(objs, collect(resp)...)
|
||||
objs := collect(resp)
|
||||
if len(objs) == 0 {
|
||||
return nil, false
|
||||
return nil, ErrInvalidToolCall
|
||||
}
|
||||
|
||||
var toolCalls []api.ToolCall
|
||||
|
|
@ -75,59 +81,56 @@ func (p *Parser) parseJSONToolCalls(s string) ([]api.ToolCall, bool) {
|
|||
|
||||
// Valid JSON, no tool calls found
|
||||
if len(toolCalls) == 0 {
|
||||
return nil, false
|
||||
return nil, ErrInvalidToolCall
|
||||
}
|
||||
|
||||
return toolCalls, false
|
||||
return toolCalls, nil
|
||||
}
|
||||
|
||||
// checkPrefix processes a string to find and handle a prefix pattern.
|
||||
//
|
||||
// Returns:
|
||||
// - The processed string with prefix removed if found
|
||||
// - Whether the prefix was found at the start of the string
|
||||
// - Whether to continue parsing
|
||||
func (p *Parser) checkPrefix(s string) (string, bool, bool) {
|
||||
// - error: ErrAccumulateMore if the prefix partially overlaps the end of the input or appears mid-string (more input is needed), or nil otherwise
|
||||
func (p *Parser) checkPrefix(s string) (string, error) {
|
||||
// Keep original for overlap checks
|
||||
original := s
|
||||
s = strings.TrimSpace(s)
|
||||
if s == "" {
|
||||
return "", false, true
|
||||
return "", nil
|
||||
}
|
||||
|
||||
// If no prefix defined, just return trimmed string
|
||||
if p.prefix == "" {
|
||||
return s, false, true
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// Check for prefix at start of string
|
||||
if processedStr, hasPrefix := strings.CutPrefix(s, p.prefix); hasPrefix {
|
||||
// Found prefix at start - accumulate for potential tool
|
||||
return processedStr, true, true
|
||||
p.prefixFound = true
|
||||
return processedStr, nil
|
||||
}
|
||||
|
||||
// Check if prefix overlaps end of string
|
||||
if overlap := suffixOverlap(original, p.prefix); overlap > 0 {
|
||||
p.prefixPartial = true
|
||||
// Return everything except overlapping portion
|
||||
p.sb.Reset()
|
||||
p.sb.WriteString(original[len(original)-overlap:])
|
||||
return original[0 : len(original)-overlap], false, false
|
||||
return original[0 : len(original)-overlap], ErrAccumulateMore
|
||||
}
|
||||
|
||||
// Check if prefix appears in middle of string
|
||||
if idx := strings.Index(original, p.prefix); idx != -1 {
|
||||
p.prefixPartial = true
|
||||
// Save remainder starting at prefix for next pass
|
||||
p.sb.Reset()
|
||||
p.sb.WriteString(strings.TrimSpace(original[idx:]))
|
||||
// Return everything before prefix
|
||||
return original[:idx], false, false
|
||||
return original[:idx], ErrAccumulateMore
|
||||
}
|
||||
|
||||
// No prefix found
|
||||
p.prefixPartial = false
|
||||
return s, false, true
|
||||
// No partial prefix found
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// Add processes a string input to parse tool calls and content.
|
||||
|
|
@ -136,57 +139,45 @@ func (p *Parser) checkPrefix(s string) (string, bool, bool) {
|
|||
// Returns:
|
||||
// - tools: Any parsed tool calls
|
||||
// - content: Non-tool call content
|
||||
// - err: Error if parsing failed
|
||||
// - error: One of the sentinel errors or nil if successful
|
||||
func (p *Parser) Add(s string) (tools []api.ToolCall, content string, err error) {
|
||||
if len(s) == 0 {
|
||||
return nil, "", nil
|
||||
}
|
||||
|
||||
p.sb.WriteString(s)
|
||||
s = p.sb.String()
|
||||
|
||||
// Check for prefix pattern in input
|
||||
s, prefixFound, shouldContinue := p.checkPrefix(s)
|
||||
if !shouldContinue {
|
||||
s, err = p.checkPrefix(s)
|
||||
if err != nil {
|
||||
if s != "" {
|
||||
// Return content before prefix
|
||||
return nil, s, nil
|
||||
}
|
||||
// Need more input to complete prefix
|
||||
return nil, "", nil
|
||||
}
|
||||
|
||||
// Update prefix found state
|
||||
if prefixFound {
|
||||
p.prefixFound = true
|
||||
return nil, "", ErrAccumulateMore
|
||||
}
|
||||
|
||||
// Exit if prefix exists in template, greedy parsing is off, and prefix not found
|
||||
if !p.greedyParse && !p.prefixFound {
|
||||
p.sb.Reset()
|
||||
return nil, "", errors.New("prefix not found")
|
||||
return nil, "", ErrPrefixNotFound
|
||||
}
|
||||
|
||||
toolCalls, isPartial := p.parseJSONToolCalls(s)
|
||||
if isPartial {
|
||||
// Need more input to complete JSON
|
||||
return nil, "", nil
|
||||
}
|
||||
|
||||
// Do not try greedy parsing if partial JSON not found
|
||||
p.greedyParse = false
|
||||
|
||||
// Handle invalid tool call format
|
||||
if len(toolCalls) == 0 {
|
||||
p.sb.Reset()
|
||||
if p.prefix == "" {
|
||||
p.Done = true
|
||||
toolCalls, err := p.parseJSONToolCalls(s)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrAccumulateMore) {
|
||||
return nil, "", err
|
||||
} else {
|
||||
p.sb.Reset()
|
||||
// Do not try greedy parsing if JSON not found
|
||||
p.greedyParse = false
|
||||
if p.prefix == "" {
|
||||
p.Done = true
|
||||
}
|
||||
if p.prefixFound {
|
||||
// Drop tokens since prefix was found
|
||||
return nil, "", ErrAccumulateMore
|
||||
}
|
||||
return nil, s, nil
|
||||
}
|
||||
if p.prefixFound {
|
||||
// Drop tokens since prefix was found
|
||||
return nil, "", nil
|
||||
}
|
||||
return nil, s, nil
|
||||
}
|
||||
|
||||
for _, tc := range toolCalls {
|
||||
|
|
@ -207,21 +198,15 @@ func (p *Parser) Add(s string) (tools []api.ToolCall, content string, err error)
|
|||
// prefix, and field names from the template to use for parsing tool calls from model output.
|
||||
//
|
||||
// Returns an error if the template does not contain valid tool call formatting.
|
||||
func NewParser(templateToProcess *gotmpl.Template) (*Parser, error) {
|
||||
func NewParser(templateToProcess *gotmpl.Template) (Parser, error) {
|
||||
parsed, err := template.Parse(templateToProcess.Root.String())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if parsed == nil {
|
||||
return nil, errors.New("failed to parse template")
|
||||
return Parser{}, err
|
||||
}
|
||||
|
||||
tt, tc := toolTemplate(parsed)
|
||||
if !tc {
|
||||
return nil, errors.New("failed to find tool calls in template")
|
||||
}
|
||||
if tt == nil {
|
||||
return nil, errors.New("failed to find tool template")
|
||||
tt, err := toolTemplate(parsed)
|
||||
if err != nil {
|
||||
return Parser{}, err
|
||||
}
|
||||
|
||||
tp := toolPrefix(templateToProcess)
|
||||
|
|
@ -229,12 +214,12 @@ func NewParser(templateToProcess *gotmpl.Template) (*Parser, error) {
|
|||
|
||||
name, arguments, err := extractToolArgs(tt)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return Parser{}, err
|
||||
}
|
||||
|
||||
return &Parser{
|
||||
tmpl: tt,
|
||||
sb: &strings.Builder{},
|
||||
return Parser{
|
||||
tmpl: *tt,
|
||||
sb: strings.Builder{},
|
||||
prefix: tp,
|
||||
greedyParse: true,
|
||||
name: name,
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ package tools
|
|||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
|
@ -196,6 +197,13 @@ func TestParseToolCalls(t *testing.T) {
|
|||
expectedToolCall: []api.ToolCall{t1, t2},
|
||||
expectedTokens: "some tokens after call",
|
||||
},
|
||||
{
|
||||
name: "qwen2.5-coder tool calls with initial text",
|
||||
model: "qwen2.5-coder",
|
||||
output: `some tokens before call [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}, {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
|
||||
expectedToolCall: []api.ToolCall{},
|
||||
expectedTokens: `some tokens before call [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}, {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
|
||||
},
|
||||
{
|
||||
name: "qwen2.5 tool calls with prefix and trailing text",
|
||||
model: "qwen2.5-coder",
|
||||
|
|
@ -203,6 +211,13 @@ func TestParseToolCalls(t *testing.T) {
|
|||
expectedToolCall: []api.ToolCall{t1, t2},
|
||||
expectedTokens: "",
|
||||
},
|
||||
{
|
||||
name: "qwen2.5 tool calls with prefix and initial text",
|
||||
model: "qwen2.5-coder",
|
||||
output: `some tokens before call <tool_call> [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}, {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}] </tool_call>`,
|
||||
expectedToolCall: []api.ToolCall{t1, t2},
|
||||
expectedTokens: "some tokens before call",
|
||||
},
|
||||
{
|
||||
name: "qwen2.5 tool calls without prefix and valid tool call",
|
||||
model: "qwen2.5-coder",
|
||||
|
|
@ -356,9 +371,9 @@ func TestParseToolCalls(t *testing.T) {
|
|||
} else if len(toolCalls) > 0 {
|
||||
got = append(got, toolCalls...)
|
||||
add = false
|
||||
} else {
|
||||
add = false
|
||||
}
|
||||
} else if errors.Is(err, ErrAccumulateMore) {
|
||||
add = false
|
||||
}
|
||||
}
|
||||
if add {
|
||||
|
|
@ -388,8 +403,7 @@ func TestParseJSONToolCalls(t *testing.T) {
|
|||
input string
|
||||
parser *Parser
|
||||
wantToolCalls []api.ToolCall
|
||||
wantPartial bool
|
||||
wantValid bool
|
||||
wantErr error
|
||||
}{
|
||||
{
|
||||
name: "valid single tool call",
|
||||
|
|
@ -405,32 +419,28 @@ func TestParseJSONToolCalls(t *testing.T) {
|
|||
},
|
||||
},
|
||||
},
|
||||
wantPartial: false,
|
||||
wantValid: true,
|
||||
wantErr: nil,
|
||||
},
|
||||
{
|
||||
name: "incomplete JSON",
|
||||
input: `{"name": "test_tool", "arguments": {"arg1": `,
|
||||
parser: &Parser{name: "name", arguments: "arguments"},
|
||||
wantToolCalls: nil,
|
||||
wantPartial: true,
|
||||
wantValid: false,
|
||||
wantErr: ErrAccumulateMore,
|
||||
},
|
||||
{
|
||||
name: "invalid JSON",
|
||||
input: `not json at all`,
|
||||
parser: &Parser{name: "name", arguments: "arguments"},
|
||||
wantToolCalls: nil,
|
||||
wantPartial: false,
|
||||
wantValid: false,
|
||||
wantErr: ErrInvalidToolCall,
|
||||
},
|
||||
{
|
||||
name: "missing required fields",
|
||||
input: `{"other": "field"}`,
|
||||
parser: &Parser{name: "name", arguments: "arguments"},
|
||||
wantToolCalls: nil,
|
||||
wantPartial: false,
|
||||
wantValid: false,
|
||||
wantErr: ErrInvalidToolCall,
|
||||
},
|
||||
{
|
||||
name: "multiple tool calls in array",
|
||||
|
|
@ -457,21 +467,20 @@ func TestParseJSONToolCalls(t *testing.T) {
|
|||
},
|
||||
},
|
||||
},
|
||||
wantPartial: false,
|
||||
wantValid: true,
|
||||
wantErr: nil,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
gotCalls, gotPartial := tt.parser.parseJSONToolCalls(tt.input)
|
||||
gotCalls, err := tt.parser.parseJSONToolCalls(tt.input)
|
||||
|
||||
if gotPartial != tt.wantPartial {
|
||||
t.Errorf("parseJSONToolCalls() partial = %v, want %v", gotPartial, tt.wantPartial)
|
||||
if err != tt.wantErr {
|
||||
t.Errorf("parseJSONToolCalls() error = %v, want %v", err, tt.wantErr)
|
||||
}
|
||||
|
||||
if len(gotCalls) != 0 != tt.wantValid {
|
||||
t.Errorf("parseJSONToolCalls() valid = %v, want %v", len(gotCalls) == 0, tt.wantValid)
|
||||
if len(gotCalls) != 0 && tt.wantErr != nil {
|
||||
t.Errorf("parseJSONToolCalls() valid = %v, want %v", len(gotCalls) == 0, tt.wantErr == nil)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(gotCalls, tt.wantToolCalls); diff != "" {
|
||||
|
|
|
|||
|
|
@ -142,8 +142,8 @@ func toolPrefix(tmpl *gotmpl.Template) string {
|
|||
//
|
||||
// Returns:
|
||||
// - *gotmpl.Template: The subtree containing the .ToolCalls range
|
||||
// - bool: Whether a .ToolCalls range was found in the template
|
||||
func toolTemplate(t *template.Template) (*gotmpl.Template, bool) {
|
||||
// - error: Non-nil if no .ToolCalls range (tool template) was found in the template
|
||||
func toolTemplate(t *template.Template) (*gotmpl.Template, error) {
|
||||
tmpl := t.Subtree(func(n parse.Node) bool {
|
||||
if t, ok := n.(*parse.RangeNode); ok {
|
||||
return slices.Contains(template.Identifiers(t.Pipe), "ToolCalls")
|
||||
|
|
@ -153,20 +153,20 @@ func toolTemplate(t *template.Template) (*gotmpl.Template, bool) {
|
|||
})
|
||||
|
||||
if tmpl == nil {
|
||||
return nil, false
|
||||
return nil, errors.New("failed to find tool template")
|
||||
}
|
||||
|
||||
return tmpl, true
|
||||
return tmpl, nil
|
||||
}
|
||||
|
||||
// suffixOverlap returns the length of the longest suffix overlap between two strings
|
||||
//
|
||||
// Returns:
|
||||
// - int: The length of the longest suffix overlap
|
||||
func suffixOverlap(s, delim string) int {
|
||||
max := min(len(delim), len(s))
|
||||
func suffixOverlap(s, prefix string) int {
|
||||
max := min(len(prefix), len(s))
|
||||
for i := max; i > 0; i-- {
|
||||
if strings.HasSuffix(s, delim[:i]) {
|
||||
if strings.HasSuffix(s, prefix[:i]) {
|
||||
return i
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -192,9 +192,9 @@ func TestToolTemplate(t *testing.T) {
|
|||
t.Fatalf("failed to parse template: %v", err)
|
||||
}
|
||||
|
||||
_, got := toolTemplate(parsed)
|
||||
if got != tt.want {
|
||||
t.Errorf("toolTemplate() = %v; want %v", got, tt.want)
|
||||
_, err = toolTemplate(parsed)
|
||||
if err != nil && tt.want {
|
||||
t.Errorf("toolTemplate() = %v; want %v", err, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue