ollama/tools/deepseek_tools_test.go

87 lines
2.8 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package tools
import (
"fmt"
"path/filepath"
"testing"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/template"
"github.com/stretchr/testify/assert"
)
func TestDeepSeekToolParser(t *testing.T) {
p := filepath.Join("testdata")
t1 := api.ToolCall{
Function: api.ToolCallFunction{
Name: "get_current_weather",
Arguments: map[string]any{
"format": "fahrenheit",
"location": "San Francisco, CA",
},
Index: 0,
},
}
// t2 := api.ToolCall{
// Function: api.ToolCallFunction{
// Name: "get_current_weather",
// Arguments: map[string]any{
// "format": "celsius",
// "location": "Toronto, Canada",
// },
// Index: 1,
// },
// }
tests := []struct {
name string
template string
output string
expectedToolCall []api.ToolCall
expectedTokens string
}{
{
name: "single tool call",
output: `<tool▁calls▁begin><tool▁call▁begin>function<tool▁sep>get_current_weather
` + "```json\n" + `{"format":"fahrenheit","location":"San Francisco, CA"}` + "\n```" + `<tool▁call▁end>`,
expectedToolCall: []api.ToolCall{t1},
expectedTokens: "",
},
// {
// name: "multiple tool calls",
// template: `"<tool▁call▁begin>function<tool▁sep>get_current_weather\n` + "```json\n" + `{"format":"fahrenheit","location":"San Francisco, CA"}` + "\n```" + `<tool▁call▁end>"`,
// output: `<tool▁call▁begin>function<tool▁sep>get_current_weather
// ` + "```json\n" + `{"format":"fahrenheit","location":"San Francisco, CA"}` + "\n```" + `<tool▁call▁end>
// <tool▁call▁begin>function<tool▁sep>get_current_weather
// ` + "```json\n" + `{"format":"celsius","location":"Toronto, Canada"}` + "\n```" + `<tool▁call▁end>`,
// expectedToolCall: []api.ToolCall{t1, t2},
// expectedTokens: "",
// },
// {
// name: "invalid tool call format",
// template: `{{"<tool▁call▁begin>function<tool▁sep>get_current_weather\n` + "```json\n" + `{"format":"fahrenheit","location":"San Francisco, CA"}` + "\n```" + `<tool▁call▁end>"}}`,
// output: "This is just some text without a tool call",
// expectedToolCall: nil,
// expectedTokens: "This is just some text without a tool call",
// },
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tmpl, err := template.Parse(readFile(t, p, "deepseek-r1.gotmpl").String())
if err != nil {
t.Fatal(err)
}
fmt.Println(tmpl.Template.Root.String())
parser, err := NewDeepSeekToolParser(tmpl.Template)
assert.NoError(t, err)
tools, content := parser.Add(tt.output)
assert.Equal(t, tt.expectedToolCall, tools)
assert.Equal(t, tt.expectedTokens, content)
})
}
}