From 478824045dcefa6bf13abd8663223cbd81549636 Mon Sep 17 00:00:00 2001 From: ParthSareen Date: Thu, 29 May 2025 15:26:37 -0700 Subject: [PATCH] temp --- server/routes.go | 16 +++- tools/deepseek_tools.go | 179 +++++++++++++++++++++++++++++++++++ tools/deepseek_tools_test.go | 86 +++++++++++++++++ tools/tools.go | 30 +++++- tools/tools_utils.go | 2 + 5 files changed, 310 insertions(+), 3 deletions(-) create mode 100644 tools/deepseek_tools.go create mode 100644 tools/deepseek_tools_test.go diff --git a/server/routes.go b/server/routes.go index 236f92e22..4a18aefcc 100644 --- a/server/routes.go +++ b/server/routes.go @@ -1523,8 +1523,20 @@ func (s *Server) ChatHandler(c *gin.Context) { } } - var toolParser *tools.Parser - if len(req.Tools) > 0 { + var toolParser tools.ToolParser + + fmt.Println("m.Config.ModelFamily", m.Config.ModelFamily) + if m.Config.ModelFamily == "qwen" { + slog.Info("using deepseek tool parser") + fmt.Println("m.Template.Template", m.Template.Template) + toolParser, err = tools.NewDeepSeekToolParser(m.Template.Template) + if err != nil { + slog.Error("failed to create tool parser", "error", err) + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + } else if len(req.Tools) > 0 { + slog.Info("using default tool parser") toolParser, err = tools.NewParser(m.Template.Template) if err != nil { slog.Error("failed to create tool parser", "error", err) diff --git a/tools/deepseek_tools.go b/tools/deepseek_tools.go new file mode 100644 index 000000000..015b4e479 --- /dev/null +++ b/tools/deepseek_tools.go @@ -0,0 +1,179 @@ +package tools + +import ( + "encoding/json" + "errors" + "fmt" + "strings" + gotmpl "text/template" + + "github.com/ollama/ollama/api" +) + +type DeepSeekToolParser struct { + parser *Parser // Embed the base parser as a field +} + +func (p *DeepSeekToolParser) Add(s string) (tools []api.ToolCall, content string) { + fmt.Println("prefix", p.parser.prefix) + fmt.Println("DeepSeekToolParser.Add: Starting with input:", s) + p.parser.sb.WriteString(s) + s = p.parser.sb.String() + fmt.Println("DeepSeekToolParser.Add: After StringBuilder:", s) + + // Check for prefix pattern in input + s, err := p.parser.checkPrefix(s) + fmt.Println("DeepSeekToolParser.Add: After checkPrefix:", s, "error:", err) + if err != nil { + // Need more input to complete prefix + return nil, s + } + + // Exit if prefix exists in template, greedy parsing is off, and prefix not found + if !p.parser.prefixFound { + fmt.Println("DeepSeekToolParser.Add: Prefix not found, resetting") + p.parser.sb.Reset() + return nil, s + } + + toolCalls, err := parseDeepSeekToolCalls(s) + fmt.Println("DeepSeekToolParser.Add: After parseDeepSeekToolCalls:", toolCalls, "error:", err) + if err != nil { + if errors.Is(err, errAccumulateMore) { + return nil, "" + } + p.parser.sb.Reset() + // Only do greedy JSON parsing if there is no prefix from template + if p.parser.prefix != "" { + fmt.Println("DeepSeekToolParser.Add: Disabling greedy parsing") + p.parser.greedyParseJSON = false + } + if p.parser.index != 0 && p.parser.prefix == "" { + return nil, "" + } + if p.parser.prefixFound { + fmt.Println("DeepSeekToolParser.Add: Prefix found but invalid tool call") + // Drop tokens since prefix was found + return nil, "" + } + return nil, s + } + + fmt.Println("DeepSeekToolParser.Add: Processing tool calls") + for _, tc := range toolCalls { + tc.Function.Index = p.parser.index + p.parser.index++ + } + + p.parser.sb.Reset() + fmt.Println("DeepSeekToolParser.Add: Returning tool calls:", toolCalls) + return toolCalls, "" +} + +func (p *DeepSeekToolParser) NewParser(templateToProcess *gotmpl.Template) (ToolParser, error) { + return NewDeepSeekToolParser(templateToProcess) +} + +func NewDeepSeekToolParser(templateToProcess *gotmpl.Template) (ToolParser, error) { + // Create base parser first + baseParser, err := NewParser(templateToProcess) + if err != nil { + return nil, fmt.Errorf("failed to create base parser: %w", err) + } + + return &DeepSeekToolParser{ + parser: baseParser, + }, nil +} + +func parseDeepSeekToolCalls(s string) ([]api.ToolCall, error) { + fmt.Println("parseDeepSeekToolCalls: Starting with input:", s) + fields := strings.Fields(s) + fmt.Println("parseDeepSeekToolCalls: Split fields:", fields) + + sep := "<|tool▁sep|>" + var functionNames []string + for _, field := range fields { + fmt.Println("parseDeepSeekToolCalls: Processing field:", field) + // TODO: check if brittle + if strings.Contains(field, "function") { + idx := strings.Index(field, "function") + if idx == -1 { + fmt.Println("parseDeepSeekToolCalls: No 'function' prefix found") + return nil, errAccumulateMore + } + functionName := field[idx+len("function"):] + // functionName, cut := strings.CutPrefix(field, "function") + // if !cut { + // fmt.Println("parseDeepSeekToolCalls: Failed to cut 'function' prefix") + // return nil, errAccumulateMore + // } + // pass through on this is fine as it doesn't always come down + functionName, _ = strings.CutPrefix(functionName, sep) + fmt.Println("parseDeepSeekToolCalls: Found function name:", functionName) + functionNames = append(functionNames, functionName) + } + } + + if len(functionNames) == 0 { + fmt.Println("parseDeepSeekToolCalls: No function names found") + return nil, errAccumulateMore + } + fmt.Println("parseDeepSeekToolCalls: Found function names:", functionNames) + + braceCount := 0 + startIndex := -1 + + var rawToolArgs []string + for i, c := range s { + switch c { + case '{': + braceCount++ + if startIndex == -1 { + startIndex = i + fmt.Printf("parseDeepSeekToolCalls: Found opening brace at index %d\n", i) + } + case '}': + braceCount-- + if braceCount == 0 { + rawToolArgs = append(rawToolArgs, s[startIndex:i+1]) + fmt.Printf("parseDeepSeekToolCalls: Found closing brace at index %d, captured: %s\n", i, s[startIndex:i+1]) + startIndex = -1 + } + } + } + fmt.Println("parseDeepSeekToolCalls: Raw tool arguments:", rawToolArgs) + + var toolCalls []api.ToolCall + // unmarshal args + var args map[string]any + for i, rawToolArg := range rawToolArgs { + fmt.Printf("parseDeepSeekToolCalls: Unmarshaling tool arg %d: %s\n", i, rawToolArg) + if err := json.Unmarshal([]byte(rawToolArg), &args); err != nil { + fmt.Println("parseDeepSeekToolCalls: Failed to unmarshal JSON:", err) + return nil, err + } + + toolCalls = append(toolCalls, api.ToolCall{ + Function: api.ToolCallFunction{ + Name: functionNames[i], + Arguments: args, + }, + }) + fmt.Printf("parseDeepSeekToolCalls: Created tool call %d with name %s and args %v\n", i, functionNames[i], args) + } + + if len(toolCalls) == 0 { + fmt.Println("parseDeepSeekToolCalls: No tool calls created") + // todo: check err here + return nil, errInvalidToolCall + } + + fmt.Println("parseDeepSeekToolCalls: Returning tool calls:", toolCalls) + return toolCalls, nil +} + +// ! use as prefix +// {{"<|tool▁call▁begin|> +// ! send to tc parser +// * function<|tool▁sep|>\n```json\n\n```<|tool▁call▁end|>"}} diff --git a/tools/deepseek_tools_test.go b/tools/deepseek_tools_test.go new file mode 100644 index 000000000..19aeed098 --- /dev/null +++ b/tools/deepseek_tools_test.go @@ -0,0 +1,86 @@ +package tools + +import ( + "fmt" + "path/filepath" + "testing" + + "github.com/ollama/ollama/api" + "github.com/ollama/ollama/template" + "github.com/stretchr/testify/assert" +) + +func TestDeepSeekToolParser(t *testing.T) { + p := filepath.Join("testdata") + t1 := api.ToolCall{ + Function: api.ToolCallFunction{ + Name: "get_current_weather", + Arguments: map[string]any{ + "format": "fahrenheit", + "location": "San Francisco, CA", + }, + Index: 0, + }, + } + + // t2 := api.ToolCall{ + // Function: api.ToolCallFunction{ + // Name: "get_current_weather", + // Arguments: map[string]any{ + // "format": "celsius", + // "location": "Toronto, Canada", + // }, + // Index: 1, + // }, + // } + + tests := []struct { + name string + template string + output string + expectedToolCall []api.ToolCall + expectedTokens string + }{ + { + name: "single tool call", + output: `<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>get_current_weather +` + "```json\n" + `{"format":"fahrenheit","location":"San Francisco, CA"}` + "\n```" + `<|tool▁call▁end|>`, + expectedToolCall: []api.ToolCall{t1}, + expectedTokens: "", + }, + // { + // name: "multiple tool calls", + // template: `"<|tool▁call▁begin|>function<|tool▁sep|>get_current_weather\n` + "```json\n" + `{"format":"fahrenheit","location":"San Francisco, CA"}` + "\n```" + `<|tool▁call▁end|>"`, + // output: `<|tool▁call▁begin|>function<|tool▁sep|>get_current_weather + // ` + "```json\n" + `{"format":"fahrenheit","location":"San Francisco, CA"}` + "\n```" + `<|tool▁call▁end|> + // <|tool▁call▁begin|>function<|tool▁sep|>get_current_weather + // ` + "```json\n" + `{"format":"celsius","location":"Toronto, Canada"}` + "\n```" + `<|tool▁call▁end|>`, + // expectedToolCall: []api.ToolCall{t1, t2}, + // expectedTokens: "", + // }, + // { + // name: "invalid tool call format", + // template: `{{"<|tool▁call▁begin|>function<|tool▁sep|>get_current_weather\n` + "```json\n" + `{"format":"fahrenheit","location":"San Francisco, CA"}` + "\n```" + `<|tool▁call▁end|>"}}`, + // output: "This is just some text without a tool call", + // expectedToolCall: nil, + // expectedTokens: "This is just some text without a tool call", + // }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tmpl, err := template.Parse(readFile(t, p, "deepseek-r1.gotmpl").String()) + if err != nil { + t.Fatal(err) + } + fmt.Println(tmpl.Template.Root.String()) + + parser, err := NewDeepSeekToolParser(tmpl.Template) + assert.NoError(t, err) + + tools, content := parser.Add(tt.output) + assert.Equal(t, tt.expectedToolCall, tools) + assert.Equal(t, tt.expectedTokens, content) + }) + } +} diff --git a/tools/tools.go b/tools/tools.go index 914a5eaf0..c58378024 100644 --- a/tools/tools.go +++ b/tools/tools.go @@ -3,6 +3,7 @@ package tools import ( "encoding/json" "errors" + "fmt" "log/slog" "strings" gotmpl "text/template" @@ -16,6 +17,11 @@ var ( errAccumulateMore = errors.New("need to accumulate more content") ) +type ToolParser interface { + Add(s string) (tools []api.ToolCall, content string) + NewParser(templateToProcess *gotmpl.Template) (ToolParser, error) +} + type Parser struct { greedyParseJSON bool prefix string @@ -104,6 +110,9 @@ func parseJSONToolCalls(s string, name, arguments string, prefix string) ([]api. continue } + fmt.Println("name", name) + fmt.Println("arguments", arguments) + fmt.Println("parseJSONToolCalls: Objects:", objs) // Extract tool calls from objects for _, kv := range objs { n, nok := kv[name].(string) @@ -123,7 +132,6 @@ func parseJSONToolCalls(s string, name, arguments string, prefix string) ([]api. // Valid JSON, no tool calls found if len(toolCalls) == 0 { - slog.Debug("No valid tool calls found in any raw tool calls.", "rawToolCalls", rawToolCalls) return nil, errInvalidToolCall } @@ -177,6 +185,7 @@ func (p *Parser) checkPrefix(s string) (string, error) { func (p *Parser) Add(s string) (tools []api.ToolCall, content string) { p.sb.WriteString(s) s = p.sb.String() + fmt.Println("Add: Starting with input:", s) // Check for prefix pattern in input s, err := p.checkPrefix(s) @@ -225,23 +234,37 @@ func (p *Parser) Add(s string) (tools []api.ToolCall, content string) { // // Returns an error if the template does not contain valid tool call formatting. func NewParser(templateToProcess *gotmpl.Template) (*Parser, error) { + fmt.Println("Checkpoint 1: Starting NewParser") parsed, err := template.Parse(templateToProcess.Root.String()) if err != nil { + fmt.Println("Checkpoint 2: Error parsing template:", err) return nil, err } + fmt.Println("Checkpoint 3: Getting tool template") tt, err := toolTemplate(parsed) + fmt.Println("Checkpoint 4: Tool template:", tt.Root.String()) if err != nil { + fmt.Println("Checkpoint 5: Error getting tool template:", err) return nil, err } + fmt.Println("Checkpoint 6: Getting tool prefix") tp := toolPrefix(templateToProcess) + fmt.Println("Checkpoint 7: Tool prefix:", tp) + fmt.Println("Checkpoint 8: Extracting tool args") name, arguments, err := extractToolArgs(tt) if err != nil { + fmt.Println("Checkpoint 9: Error extracting tool args:", err) return nil, err } + // name := "temp1" + // args := "temp2" + fmt.Println("Checkpoint 10: Tool name:", name, "arguments:", arguments) + + fmt.Println("Checkpoint 11: Creating parser") return &Parser{ tmpl: *tt, sb: strings.Builder{}, @@ -251,3 +274,8 @@ func NewParser(templateToProcess *gotmpl.Template) (*Parser, error) { arguments: arguments, }, nil } + +// NewParser implements the ToolParser interface +func (p *Parser) NewParser(templateToProcess *gotmpl.Template) (ToolParser, error) { + return NewParser(templateToProcess) +} diff --git a/tools/tools_utils.go b/tools/tools_utils.go index 48531b789..b08340010 100644 --- a/tools/tools_utils.go +++ b/tools/tools_utils.go @@ -4,6 +4,7 @@ import ( "bytes" "encoding/json" "errors" + "fmt" "log/slog" "slices" "strings" @@ -98,6 +99,7 @@ func isToolCallsNode(n *parse.IfNode) bool { func toolPrefix(tmpl *gotmpl.Template) string { tokenText, ok := extractToolCallsFormat(tmpl) + fmt.Println("tokenText", tokenText) if !ok { return "" }