Compare commits
22 Commits
v0.7.1-rc0
...
v0.9.0
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5f57b0ef42 | ||
|
|
aa25aff10d | ||
|
|
ea79003180 | ||
|
|
9239a254e0 | ||
|
|
066d0f4746 | ||
|
|
aea6fb9b58 | ||
|
|
012cf65340 | ||
|
|
a45231af47 | ||
|
|
2307fc2bcd | ||
|
|
6623898198 | ||
|
|
eda472df1b | ||
|
|
f18e0cb550 | ||
|
|
e8b981fa5d | ||
|
|
884d26093c | ||
|
|
1f371ea92f | ||
|
|
73d6a82cce | ||
|
|
6db8a3771c | ||
|
|
d950ff12c0 | ||
|
|
adff143bcd | ||
|
|
fbe6ae285a | ||
|
|
fdd4d479a3 | ||
|
|
61aeaf7e81 |
@@ -406,6 +406,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
|
||||
- [AppFlowy](https://github.com/AppFlowy-IO/AppFlowy) (AI collaborative workspace with Ollama, cross-platform and self-hostable)
|
||||
- [Lumina](https://github.com/cushydigit/lumina.git) (A lightweight, minimal React.js frontend for interacting with Ollama servers)
|
||||
- [Tiny Notepad](https://pypi.org/project/tiny-notepad) (A lightweight, notepad-like interface to chat with ollama available on PyPI)
|
||||
- [macLlama (macOS native)](https://github.com/hellotunamayo/macLlama) (A native macOS GUI application for interacting with Ollama models, featuring a chat interface.)
|
||||
|
||||
### Cloud
|
||||
|
||||
@@ -449,6 +450,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
|
||||
- [orbiton](https://github.com/xyproto/orbiton) Configuration-free text editor and IDE with support for tab completion with Ollama.
|
||||
- [orca-cli](https://github.com/molbal/orca-cli) Ollama Registry CLI Application - Browse, pull, and download models from Ollama Registry in your terminal.
|
||||
- [GGUF-to-Ollama](https://github.com/jonathanhecl/gguf-to-ollama) - Importing GGUF to Ollama made easy (multiplatform)
|
||||
- [AWS-Strands-With-Ollama](https://github.com/rapidarchitect/ollama_strands) - AWS Strands Agents with Ollama Examples
|
||||
|
||||
### Apple Vision Pro
|
||||
|
||||
|
||||
@@ -24,7 +24,10 @@ import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/auth"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"github.com/ollama/ollama/format"
|
||||
"github.com/ollama/ollama/version"
|
||||
@@ -76,6 +79,14 @@ func NewClient(base *url.URL, http *http.Client) *Client {
|
||||
}
|
||||
}
|
||||
|
||||
func getAuthorizationToken(ctx context.Context, challenge string) (string, error) {
|
||||
token, err := auth.Sign(ctx, []byte(challenge))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return token, nil
|
||||
}
|
||||
|
||||
func (c *Client) do(ctx context.Context, method, path string, reqData, respData any) error {
|
||||
var reqBody io.Reader
|
||||
var data []byte
|
||||
@@ -97,6 +108,21 @@ func (c *Client) do(ctx context.Context, method, path string, reqData, respData
|
||||
}
|
||||
|
||||
requestURL := c.base.JoinPath(path)
|
||||
|
||||
var token string
|
||||
if envconfig.UseAuth() || c.base.Hostname() == "ollama.com" {
|
||||
now := strconv.FormatInt(time.Now().Unix(), 10)
|
||||
chal := fmt.Sprintf("%s,%s?ts=%s", method, path, now)
|
||||
token, err = getAuthorizationToken(ctx, chal)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
q := requestURL.Query()
|
||||
q.Set("ts", now)
|
||||
requestURL.RawQuery = q.Encode()
|
||||
}
|
||||
|
||||
request, err := http.NewRequestWithContext(ctx, method, requestURL.String(), reqBody)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -106,6 +132,10 @@ func (c *Client) do(ctx context.Context, method, path string, reqData, respData
|
||||
request.Header.Set("Accept", "application/json")
|
||||
request.Header.Set("User-Agent", fmt.Sprintf("ollama/%s (%s %s) Go/%s", version.Version, runtime.GOARCH, runtime.GOOS, runtime.Version()))
|
||||
|
||||
if token != "" {
|
||||
request.Header.Set("Authorization", token)
|
||||
}
|
||||
|
||||
respObj, err := c.http.Do(request)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -143,6 +173,22 @@ func (c *Client) stream(ctx context.Context, method, path string, data any, fn f
|
||||
}
|
||||
|
||||
requestURL := c.base.JoinPath(path)
|
||||
|
||||
var token string
|
||||
if envconfig.UseAuth() || c.base.Hostname() == "ollama.com" {
|
||||
var err error
|
||||
now := strconv.FormatInt(time.Now().Unix(), 10)
|
||||
chal := fmt.Sprintf("%s,%s?ts=%s", method, path, now)
|
||||
token, err = getAuthorizationToken(ctx, chal)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
q := requestURL.Query()
|
||||
q.Set("ts", now)
|
||||
requestURL.RawQuery = q.Encode()
|
||||
}
|
||||
|
||||
request, err := http.NewRequestWithContext(ctx, method, requestURL.String(), buf)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -152,6 +198,10 @@ func (c *Client) stream(ctx context.Context, method, path string, data any, fn f
|
||||
request.Header.Set("Accept", "application/x-ndjson")
|
||||
request.Header.Set("User-Agent", fmt.Sprintf("ollama/%s (%s %s) Go/%s", version.Version, runtime.GOARCH, runtime.GOOS, runtime.Version()))
|
||||
|
||||
if token != "" {
|
||||
request.Header.Set("Authorization", token)
|
||||
}
|
||||
|
||||
response, err := c.http.Do(request)
|
||||
if err != nil {
|
||||
return err
|
||||
|
||||
21
api/types.go
21
api/types.go
@@ -83,6 +83,12 @@ type GenerateRequest struct {
|
||||
// Options lists model-specific options. For example, temperature can be
|
||||
// set through this field, if the model supports it.
|
||||
Options map[string]any `json:"options"`
|
||||
|
||||
// Think controls whether thinking/reasoning models will think before
|
||||
// responding. Needs to be a pointer so we can distinguish between false
|
||||
// (request that thinking _not_ be used) and unset (use the old behavior
|
||||
// before this option was introduced)
|
||||
Think *bool `json:"think,omitempty"`
|
||||
}
|
||||
|
||||
// ChatRequest describes a request sent by [Client.Chat].
|
||||
@@ -108,6 +114,10 @@ type ChatRequest struct {
|
||||
|
||||
// Options lists model-specific options.
|
||||
Options map[string]any `json:"options"`
|
||||
|
||||
// Think controls whether thinking/reasoning models will think before
|
||||
// responding
|
||||
Think *bool `json:"think,omitempty"`
|
||||
}
|
||||
|
||||
type Tools []Tool
|
||||
@@ -126,8 +136,11 @@ func (t Tool) String() string {
|
||||
// role ("system", "user", or "assistant"), the content and an optional list
|
||||
// of images.
|
||||
type Message struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
// Thinking contains the text that was inside thinking tags in the
|
||||
// original model output when ChatRequest.Think is enabled.
|
||||
Thinking string `json:"thinking,omitempty"`
|
||||
Images []ImageData `json:"images,omitempty"`
|
||||
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
||||
}
|
||||
@@ -478,6 +491,10 @@ type GenerateResponse struct {
|
||||
// Response is the textual response itself.
|
||||
Response string `json:"response"`
|
||||
|
||||
// Thinking contains the text that was inside thinking tags in the
|
||||
// original model output when ChatRequest.Think is enabled.
|
||||
Thinking string `json:"thinking,omitempty"`
|
||||
|
||||
// Done specifies if the response is complete.
|
||||
Done bool `json:"done"`
|
||||
|
||||
|
||||
@@ -372,3 +372,50 @@ func TestPropertyType_MarshalJSON(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestThinking_UnmarshalJSON(t *testing.T) {
|
||||
trueVal := true
|
||||
falseVal := false
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expectedThinking *bool
|
||||
expectedError bool
|
||||
}{
|
||||
{
|
||||
name: "true",
|
||||
input: `{ "think": true }`,
|
||||
expectedThinking: &trueVal,
|
||||
},
|
||||
{
|
||||
name: "false",
|
||||
input: `{ "think": false }`,
|
||||
expectedThinking: &falseVal,
|
||||
},
|
||||
{
|
||||
name: "unset",
|
||||
input: `{ }`,
|
||||
expectedThinking: nil,
|
||||
},
|
||||
{
|
||||
name: "invalid",
|
||||
input: `{ "think": "true" }`,
|
||||
expectedThinking: nil,
|
||||
expectedError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
var req GenerateRequest
|
||||
err := json.Unmarshal([]byte(test.input), &req)
|
||||
if test.expectedError {
|
||||
require.Error(t, err)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, test.expectedThinking, req.Think)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
178
cmd/cmd.go
178
cmd/cmd.go
@@ -39,6 +39,7 @@ import (
|
||||
"github.com/ollama/ollama/format"
|
||||
"github.com/ollama/ollama/parser"
|
||||
"github.com/ollama/ollama/progress"
|
||||
"github.com/ollama/ollama/readline"
|
||||
"github.com/ollama/ollama/runner"
|
||||
"github.com/ollama/ollama/server"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
@@ -46,6 +47,23 @@ import (
|
||||
"github.com/ollama/ollama/version"
|
||||
)
|
||||
|
||||
// ensureThinkingSupport emits a warning if the model does not advertise thinking support
|
||||
func ensureThinkingSupport(ctx context.Context, client *api.Client, name string) {
|
||||
if name == "" {
|
||||
return
|
||||
}
|
||||
resp, err := client.Show(ctx, &api.ShowRequest{Model: name})
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
for _, cap := range resp.Capabilities {
|
||||
if cap == model.CapabilityThinking {
|
||||
return
|
||||
}
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, "warning: model %q does not support thinking output\n", name)
|
||||
}
|
||||
|
||||
var errModelfileNotFound = errors.New("specified Modelfile wasn't found")
|
||||
|
||||
func getModelfileName(cmd *cobra.Command) (string, error) {
|
||||
@@ -265,6 +283,9 @@ func loadOrUnloadModel(cmd *cobra.Command, opts *runOptions) error {
|
||||
req := &api.GenerateRequest{
|
||||
Model: opts.Model,
|
||||
KeepAlive: opts.KeepAlive,
|
||||
|
||||
// pass Think here so we fail before getting to the chat prompt if the model doesn't support it
|
||||
Think: opts.Think,
|
||||
}
|
||||
|
||||
return client.Generate(cmd.Context(), req, func(api.GenerateResponse) error { return nil })
|
||||
@@ -299,6 +320,22 @@ func RunHandler(cmd *cobra.Command, args []string) error {
|
||||
}
|
||||
opts.Format = format
|
||||
|
||||
thinkFlag := cmd.Flags().Lookup("think")
|
||||
if thinkFlag.Changed {
|
||||
think, err := cmd.Flags().GetBool("think")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
opts.Think = &think
|
||||
} else {
|
||||
opts.Think = nil
|
||||
}
|
||||
hidethinking, err := cmd.Flags().GetBool("hidethinking")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
opts.HideThinking = hidethinking
|
||||
|
||||
keepAlive, err := cmd.Flags().GetString("keepalive")
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -362,6 +399,11 @@ func RunHandler(cmd *cobra.Command, args []string) error {
|
||||
return err
|
||||
}
|
||||
|
||||
opts.Think, err = inferThinkingOption(&info.Capabilities, &opts, thinkFlag.Changed)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
opts.MultiModal = slices.Contains(info.Capabilities, model.CapabilityVision)
|
||||
|
||||
// TODO: remove the projector info and vision info checks below,
|
||||
@@ -923,17 +965,19 @@ func PullHandler(cmd *cobra.Command, args []string) error {
|
||||
type generateContextKey string
|
||||
|
||||
type runOptions struct {
|
||||
Model string
|
||||
ParentModel string
|
||||
Prompt string
|
||||
Messages []api.Message
|
||||
WordWrap bool
|
||||
Format string
|
||||
System string
|
||||
Images []api.ImageData
|
||||
Options map[string]any
|
||||
MultiModal bool
|
||||
KeepAlive *api.Duration
|
||||
Model string
|
||||
ParentModel string
|
||||
Prompt string
|
||||
Messages []api.Message
|
||||
WordWrap bool
|
||||
Format string
|
||||
System string
|
||||
Images []api.ImageData
|
||||
Options map[string]any
|
||||
MultiModal bool
|
||||
KeepAlive *api.Duration
|
||||
Think *bool
|
||||
HideThinking bool
|
||||
}
|
||||
|
||||
type displayResponseState struct {
|
||||
@@ -989,6 +1033,26 @@ func displayResponse(content string, wordWrap bool, state *displayResponseState)
|
||||
}
|
||||
}
|
||||
|
||||
func thinkingOutputOpeningText(plainText bool) string {
|
||||
text := "Thinking...\n"
|
||||
|
||||
if plainText {
|
||||
return text
|
||||
}
|
||||
|
||||
return readline.ColorGrey + readline.ColorBold + text + readline.ColorDefault + readline.ColorGrey
|
||||
}
|
||||
|
||||
func thinkingOutputClosingText(plainText bool) string {
|
||||
text := "...done thinking.\n\n"
|
||||
|
||||
if plainText {
|
||||
return text
|
||||
}
|
||||
|
||||
return readline.ColorGrey + readline.ColorBold + text + readline.ColorDefault
|
||||
}
|
||||
|
||||
func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
@@ -1016,14 +1080,34 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
|
||||
var latest api.ChatResponse
|
||||
var fullResponse strings.Builder
|
||||
var role string
|
||||
var thinkTagOpened bool = false
|
||||
var thinkTagClosed bool = false
|
||||
|
||||
fn := func(response api.ChatResponse) error {
|
||||
p.StopAndClear()
|
||||
if response.Message.Content != "" || !opts.HideThinking {
|
||||
p.StopAndClear()
|
||||
}
|
||||
|
||||
latest = response
|
||||
|
||||
role = response.Message.Role
|
||||
if response.Message.Thinking != "" && !opts.HideThinking {
|
||||
if !thinkTagOpened {
|
||||
fmt.Print(thinkingOutputOpeningText(false))
|
||||
thinkTagOpened = true
|
||||
}
|
||||
displayResponse(response.Message.Thinking, opts.WordWrap, state)
|
||||
}
|
||||
|
||||
content := response.Message.Content
|
||||
if thinkTagOpened && !thinkTagClosed && content != "" {
|
||||
fmt.Print(thinkingOutputClosingText(false))
|
||||
thinkTagClosed = true
|
||||
}
|
||||
// purposefully not putting thinking blocks in the response, which would
|
||||
// only be needed if we later added tool calling to the cli (they get
|
||||
// filtered out anyway since current models don't expect them unless you're
|
||||
// about to finish some tool calls)
|
||||
fullResponse.WriteString(content)
|
||||
|
||||
displayResponse(content, opts.WordWrap, state)
|
||||
@@ -1040,6 +1124,7 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
|
||||
Messages: opts.Messages,
|
||||
Format: json.RawMessage(opts.Format),
|
||||
Options: opts.Options,
|
||||
Think: opts.Think,
|
||||
}
|
||||
|
||||
if opts.KeepAlive != nil {
|
||||
@@ -1101,13 +1186,32 @@ func generate(cmd *cobra.Command, opts runOptions) error {
|
||||
}()
|
||||
|
||||
var state *displayResponseState = &displayResponseState{}
|
||||
var thinkTagOpened bool = false
|
||||
var thinkTagClosed bool = false
|
||||
|
||||
plainText := !term.IsTerminal(int(os.Stdout.Fd()))
|
||||
|
||||
fn := func(response api.GenerateResponse) error {
|
||||
p.StopAndClear()
|
||||
|
||||
latest = response
|
||||
content := response.Response
|
||||
|
||||
if response.Response != "" || !opts.HideThinking {
|
||||
p.StopAndClear()
|
||||
}
|
||||
|
||||
if response.Thinking != "" && !opts.HideThinking {
|
||||
if !thinkTagOpened {
|
||||
fmt.Print(thinkingOutputOpeningText(plainText))
|
||||
thinkTagOpened = true
|
||||
}
|
||||
displayResponse(response.Thinking, opts.WordWrap, state)
|
||||
}
|
||||
|
||||
if thinkTagOpened && !thinkTagClosed && content != "" {
|
||||
fmt.Print(thinkingOutputClosingText(plainText))
|
||||
thinkTagClosed = true
|
||||
}
|
||||
|
||||
displayResponse(content, opts.WordWrap, state)
|
||||
|
||||
return nil
|
||||
@@ -1133,6 +1237,7 @@ func generate(cmd *cobra.Command, opts runOptions) error {
|
||||
System: opts.System,
|
||||
Options: opts.Options,
|
||||
KeepAlive: opts.KeepAlive,
|
||||
Think: opts.Think,
|
||||
}
|
||||
|
||||
if err := client.Generate(ctx, &request, fn); err != nil {
|
||||
@@ -1348,6 +1453,8 @@ func NewCLI() *cobra.Command {
|
||||
runCmd.Flags().Bool("insecure", false, "Use an insecure registry")
|
||||
runCmd.Flags().Bool("nowordwrap", false, "Don't wrap words to the next line automatically")
|
||||
runCmd.Flags().String("format", "", "Response format (e.g. json)")
|
||||
runCmd.Flags().Bool("think", false, "Whether to use thinking mode for supported models")
|
||||
runCmd.Flags().Bool("hidethinking", false, "Hide thinking output (if provided)")
|
||||
|
||||
stopCmd := &cobra.Command{
|
||||
Use: "stop MODEL",
|
||||
@@ -1399,7 +1506,6 @@ func NewCLI() *cobra.Command {
|
||||
PreRunE: checkServerHeartbeat,
|
||||
RunE: ListRunningHandler,
|
||||
}
|
||||
|
||||
copyCmd := &cobra.Command{
|
||||
Use: "cp SOURCE DESTINATION",
|
||||
Short: "Copy a model",
|
||||
@@ -1488,3 +1594,45 @@ func NewCLI() *cobra.Command {
|
||||
|
||||
return rootCmd
|
||||
}
|
||||
|
||||
// If the user has explicitly set thinking options, either through the CLI or
|
||||
// through the `/set think` or `set nothink` interactive options, then we
|
||||
// respect them. Otherwise, we check model capabilities to see if the model
|
||||
// supports thinking. If the model does support thinking, we enable it.
|
||||
// Otherwise, we unset the thinking option (which is different than setting it
|
||||
// to false).
|
||||
//
|
||||
// If capabilities are not provided, we fetch them from the server.
|
||||
func inferThinkingOption(caps *[]model.Capability, runOpts *runOptions, explicitlySetByUser bool) (*bool, error) {
|
||||
if explicitlySetByUser {
|
||||
return runOpts.Think, nil
|
||||
}
|
||||
|
||||
if caps == nil {
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ret, err := client.Show(context.Background(), &api.ShowRequest{
|
||||
Model: runOpts.Model,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
caps = &ret.Capabilities
|
||||
}
|
||||
|
||||
thinkingSupported := false
|
||||
for _, cap := range *caps {
|
||||
if cap == model.CapabilityThinking {
|
||||
thinkingSupported = true
|
||||
}
|
||||
}
|
||||
|
||||
if thinkingSupported {
|
||||
thinking := true
|
||||
return &thinking, nil
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
@@ -62,6 +62,8 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
|
||||
fmt.Fprintln(os.Stderr, " /set noformat Disable formatting")
|
||||
fmt.Fprintln(os.Stderr, " /set verbose Show LLM stats")
|
||||
fmt.Fprintln(os.Stderr, " /set quiet Disable LLM stats")
|
||||
fmt.Fprintln(os.Stderr, " /set think Enable thinking")
|
||||
fmt.Fprintln(os.Stderr, " /set nothink Disable thinking")
|
||||
fmt.Fprintln(os.Stderr, "")
|
||||
}
|
||||
|
||||
@@ -128,6 +130,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
|
||||
|
||||
var sb strings.Builder
|
||||
var multiline MultilineState
|
||||
var thinkExplicitlySet bool = opts.Think != nil
|
||||
|
||||
for {
|
||||
line, err := scanner.Readline()
|
||||
@@ -195,11 +198,19 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
|
||||
opts.Model = args[1]
|
||||
opts.Messages = []api.Message{}
|
||||
fmt.Printf("Loading model '%s'\n", opts.Model)
|
||||
opts.Think, err = inferThinkingOption(nil, &opts, thinkExplicitlySet)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := loadOrUnloadModel(cmd, &opts); err != nil {
|
||||
if strings.Contains(err.Error(), "not found") {
|
||||
fmt.Printf("error: %v\n", err)
|
||||
continue
|
||||
}
|
||||
if strings.Contains(err.Error(), "does not support thinking") {
|
||||
fmt.Printf("error: %v\n", err)
|
||||
continue
|
||||
}
|
||||
return err
|
||||
}
|
||||
continue
|
||||
@@ -260,6 +271,22 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
|
||||
return err
|
||||
}
|
||||
fmt.Println("Set 'quiet' mode.")
|
||||
case "think":
|
||||
think := true
|
||||
opts.Think = &think
|
||||
thinkExplicitlySet = true
|
||||
if client, err := api.ClientFromEnvironment(); err == nil {
|
||||
ensureThinkingSupport(cmd.Context(), client, opts.Model)
|
||||
}
|
||||
fmt.Println("Set 'think' mode.")
|
||||
case "nothink":
|
||||
think := false
|
||||
opts.Think = &think
|
||||
thinkExplicitlySet = true
|
||||
if client, err := api.ClientFromEnvironment(); err == nil {
|
||||
ensureThinkingSupport(cmd.Context(), client, opts.Model)
|
||||
}
|
||||
fmt.Println("Set 'nothink' mode.")
|
||||
case "format":
|
||||
if len(args) < 3 || args[2] != "json" {
|
||||
fmt.Println("Invalid or missing format. For 'json' mode use '/set format json'")
|
||||
@@ -448,6 +475,11 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
|
||||
|
||||
assistant, err := chat(cmd, opts)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "does not support thinking") {
|
||||
fmt.Printf("error: %v\n", err)
|
||||
sb.Reset()
|
||||
continue
|
||||
}
|
||||
return err
|
||||
}
|
||||
if assistant != nil {
|
||||
|
||||
63
cmd/warn_thinking_test.go
Normal file
63
cmd/warn_thinking_test.go
Normal file
@@ -0,0 +1,63 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
)
|
||||
|
||||
// Test that a warning is printed when thinking is requested but not supported.
|
||||
func TestWarnMissingThinking(t *testing.T) {
|
||||
cases := []struct {
|
||||
capabilities []model.Capability
|
||||
expectWarn bool
|
||||
}{
|
||||
{capabilities: []model.Capability{model.CapabilityThinking}, expectWarn: false},
|
||||
{capabilities: []model.Capability{}, expectWarn: true},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/api/show" || r.Method != http.MethodPost {
|
||||
t.Fatalf("unexpected request to %s %s", r.URL.Path, r.Method)
|
||||
}
|
||||
var req api.ShowRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
t.Fatalf("decode request: %v", err)
|
||||
}
|
||||
resp := api.ShowResponse{Capabilities: tc.capabilities}
|
||||
if err := json.NewEncoder(w).Encode(resp); err != nil {
|
||||
t.Fatalf("encode response: %v", err)
|
||||
}
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
t.Setenv("OLLAMA_HOST", srv.URL)
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
oldStderr := os.Stderr
|
||||
r, w, _ := os.Pipe()
|
||||
os.Stderr = w
|
||||
ensureThinkingSupport(t.Context(), client, "m")
|
||||
w.Close()
|
||||
os.Stderr = oldStderr
|
||||
out, _ := io.ReadAll(r)
|
||||
|
||||
warned := strings.Contains(string(out), "warning:")
|
||||
if tc.expectWarn && !warned {
|
||||
t.Errorf("expected warning, got none")
|
||||
}
|
||||
if !tc.expectWarn && warned {
|
||||
t.Errorf("did not expect warning, got: %s", string(out))
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -94,7 +94,9 @@ func (m *mllamaModel) Tensors(ts []Tensor) []*ggml.Tensor {
|
||||
var out []*ggml.Tensor
|
||||
var text []Tensor
|
||||
for _, t := range ts {
|
||||
if t.Name() == "v.position_embd.gate" {
|
||||
if !strings.HasPrefix(t.Name(), "v.") && !strings.HasPrefix(t.Name(), "mm.") {
|
||||
text = append(text, t)
|
||||
} else if t.Name() == "v.position_embd.gate" {
|
||||
for _, name := range []string{"v.position_embd.gate", "v.tile_position_embd.gate"} {
|
||||
tt := t.Clone()
|
||||
tt.SetRepacker(m.repack(name))
|
||||
@@ -105,23 +107,21 @@ func (m *mllamaModel) Tensors(ts []Tensor) []*ggml.Tensor {
|
||||
WriterTo: tt,
|
||||
})
|
||||
}
|
||||
} else if t.Name() == "v.pre_tile_position_embd.gate" || t.Name() == "v.post_tile_position_embd.gate" {
|
||||
t.SetRepacker(m.repack(t.Name()))
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: t.Name(),
|
||||
Kind: t.Kind(),
|
||||
Shape: t.Shape(),
|
||||
WriterTo: t,
|
||||
})
|
||||
} else if strings.HasPrefix(t.Name(), "v.") || strings.HasPrefix(t.Name(), "mm.") {
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: t.Name(),
|
||||
Kind: t.Kind(),
|
||||
Shape: t.Shape(),
|
||||
WriterTo: t,
|
||||
})
|
||||
} else {
|
||||
text = append(text, t)
|
||||
if t.Name() == "v.pre_tile_position_embd.gate" || t.Name() == "v.post_tile_position_embd.gate" {
|
||||
t.SetRepacker(m.repack(t.Name()))
|
||||
} else if strings.HasSuffix(t.Name(), "attn_q.weight") || strings.HasSuffix(t.Name(), "attn_k.weight") {
|
||||
t.SetRepacker(m.repack(t.Name()))
|
||||
} else if strings.HasSuffix(t.Name(), "attn_gate") || strings.HasSuffix(t.Name(), "ffn_gate") {
|
||||
t.SetRepacker(m.repack(t.Name()))
|
||||
}
|
||||
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: t.Name(),
|
||||
Kind: t.Kind(),
|
||||
Shape: t.Shape(),
|
||||
WriterTo: t,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -137,16 +137,35 @@ func (m *mllamaModel) repack(name string) Repacker {
|
||||
|
||||
var t tensor.Tensor = tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data))
|
||||
|
||||
t, err = tensor.Tanh(t)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if strings.HasSuffix(name, "attn_q.weight") || strings.HasSuffix(name, "attn_k.weight") {
|
||||
heads := m.VisionModel.AttentionHeads
|
||||
if err := t.Reshape(append([]int{int(heads), 2, dims[0] / int(heads) / 2}, dims[1:]...)...); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if name == "v.position_embd.gate" {
|
||||
t, err = tensor.Sub(float32(1), t)
|
||||
if err := t.T(0, 2, 1, 3); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := t.Reshape(dims...); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := t.Transpose(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
t, err = tensor.Tanh(t)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if name == "v.position_embd.gate" {
|
||||
t, err = tensor.Sub(float32(1), t)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
t = tensor.Materialize(t)
|
||||
|
||||
@@ -43,6 +43,7 @@ Generate a response for a given prompt with a provided model. This is a streamin
|
||||
- `prompt`: the prompt to generate a response for
|
||||
- `suffix`: the text after the model response
|
||||
- `images`: (optional) a list of base64-encoded images (for multimodal models such as `llava`)
|
||||
- `think`: (for thinking models) should the model think before responding?
|
||||
|
||||
Advanced parameters (optional):
|
||||
|
||||
@@ -490,11 +491,13 @@ Generate the next message in a chat with a provided model. This is a streaming e
|
||||
- `model`: (required) the [model name](#model-names)
|
||||
- `messages`: the messages of the chat, this can be used to keep a chat memory
|
||||
- `tools`: list of tools in JSON for the model to use if supported
|
||||
- `think`: (for thinking models) should the model think before responding?
|
||||
|
||||
The `message` object has the following fields:
|
||||
|
||||
- `role`: the role of the message, either `system`, `user`, `assistant`, or `tool`
|
||||
- `content`: the content of the message
|
||||
- `thinking`: (for thinking models) the model's thinking process
|
||||
- `images` (optional): a list of images to include in the message (for multimodal models such as `llava`)
|
||||
- `tool_calls` (optional): a list of tools in JSON that the model wants to use
|
||||
|
||||
|
||||
@@ -132,22 +132,12 @@ success
|
||||
|
||||
### Supported Quantizations
|
||||
|
||||
- `q4_0`
|
||||
- `q4_1`
|
||||
- `q5_0`
|
||||
- `q5_1`
|
||||
- `q8_0`
|
||||
|
||||
#### K-means Quantizations
|
||||
|
||||
- `q3_K_S`
|
||||
- `q3_K_M`
|
||||
- `q3_K_L`
|
||||
- `q4_K_S`
|
||||
- `q4_K_M`
|
||||
- `q5_K_S`
|
||||
- `q5_K_M`
|
||||
- `q6_K`
|
||||
|
||||
|
||||
## Sharing your model on ollama.com
|
||||
|
||||
@@ -183,6 +183,8 @@ var (
|
||||
NewEngine = Bool("OLLAMA_NEW_ENGINE")
|
||||
// ContextLength sets the default context length
|
||||
ContextLength = Uint("OLLAMA_CONTEXT_LENGTH", 4096)
|
||||
// Auth enables authentication between the Ollama client and server
|
||||
UseAuth = Bool("OLLAMA_AUTH")
|
||||
)
|
||||
|
||||
func String(s string) func() string {
|
||||
|
||||
@@ -19,7 +19,7 @@ func TestVisionModels(t *testing.T) {
|
||||
}
|
||||
testCases := []testCase{
|
||||
{
|
||||
model: "llava:7b",
|
||||
model: "qwen2.5vl",
|
||||
},
|
||||
{
|
||||
model: "llama3.2-vision",
|
||||
@@ -60,6 +60,7 @@ func TestVisionModels(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestIntegrationSplitBatch(t *testing.T) {
|
||||
skipUnderMinVRAM(t, 6)
|
||||
image, err := base64.StdEncoding.DecodeString(imageEncoding)
|
||||
require.NoError(t, err)
|
||||
req := api.GenerateRequest{
|
||||
|
||||
1
integration/testdata/embed.json
vendored
1
integration/testdata/embed.json
vendored
File diff suppressed because one or more lines are too long
@@ -30,6 +30,11 @@ type Causal struct {
|
||||
|
||||
// ** current forward pass **
|
||||
|
||||
// curReserve indicates that this forward pass is only for
|
||||
// memory reservation and we should not update our metadata
|
||||
// based on it.
|
||||
curReserve bool
|
||||
|
||||
// the active layer for Get and Put
|
||||
curLayer int
|
||||
|
||||
@@ -159,12 +164,13 @@ func (c *Causal) Close() {
|
||||
}
|
||||
|
||||
func (c *Causal) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error {
|
||||
c.curReserve = reserve
|
||||
c.curBatchSize = len(batch.Positions)
|
||||
c.curSequences = batch.Sequences
|
||||
c.curPositions = batch.Positions
|
||||
c.opts.Except = nil
|
||||
|
||||
if !reserve {
|
||||
if !c.curReserve {
|
||||
c.updateSlidingWindow()
|
||||
|
||||
var err error
|
||||
@@ -211,10 +217,9 @@ func (c *Causal) StartForward(ctx ml.Context, batch input.Batch, reserve bool) e
|
||||
c.curCellRange.max = len(c.cells) - 1
|
||||
}
|
||||
|
||||
var err error
|
||||
c.curMask, err = c.buildMask(ctx)
|
||||
c.curMask = c.buildMask(ctx)
|
||||
|
||||
return err
|
||||
return nil
|
||||
}
|
||||
|
||||
func newRange() cellRange {
|
||||
@@ -297,7 +302,7 @@ func roundUp(length, pad int) int {
|
||||
// Builds a mask of history x batch indicating whether for each token in the batch the
|
||||
// token in the history should apply. This is based on both the sequence and causality (the
|
||||
// position of the history is not ahead of the token in the batch).
|
||||
func (c *Causal) buildMask(ctx ml.Context) (ml.Tensor, error) {
|
||||
func (c *Causal) buildMask(ctx ml.Context) ml.Tensor {
|
||||
// Align and pad the two dimensions as required by the backend
|
||||
batchSize := roundUp(c.curBatchSize, c.config.MaskBatchPadding)
|
||||
|
||||
@@ -305,6 +310,11 @@ func (c *Causal) buildMask(ctx ml.Context) (ml.Tensor, error) {
|
||||
c.curCellRange.max = roundUp(c.curCellRange.max+1, c.config.CachePadding) - 1
|
||||
|
||||
length := c.curCellRange.max - c.curCellRange.min + 1
|
||||
|
||||
if c.curReserve {
|
||||
return ctx.Input().Empty(c.config.MaskDType, length, batchSize)
|
||||
}
|
||||
|
||||
mask := make([]float32, batchSize*length)
|
||||
|
||||
for i := range c.curBatchSize {
|
||||
@@ -325,10 +335,7 @@ func (c *Causal) buildMask(ctx ml.Context) (ml.Tensor, error) {
|
||||
mask[i] = float32(math.Inf(-1))
|
||||
}
|
||||
|
||||
maskTensor, err := ctx.Input().FromFloatSlice(mask, length, batchSize)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
maskTensor := ctx.Input().FromFloatSlice(mask, length, batchSize)
|
||||
|
||||
if c.config.MaskDType != ml.DTypeF32 {
|
||||
out := ctx.Input().Empty(c.config.MaskDType, maskTensor.Shape()...)
|
||||
@@ -336,7 +343,7 @@ func (c *Causal) buildMask(ctx ml.Context) (ml.Tensor, error) {
|
||||
maskTensor = out
|
||||
}
|
||||
|
||||
return maskTensor, nil
|
||||
return maskTensor
|
||||
}
|
||||
|
||||
func (c *Causal) moveCells(ctx ml.Context, src, dst, length int) {
|
||||
@@ -491,12 +498,7 @@ func (c *Causal) SetCausal(ctx ml.Context, opts CausalOptions) {
|
||||
if !slices.Equal(c.opts.Except, opts.Except) {
|
||||
c.opts = opts
|
||||
if ctx != nil {
|
||||
var err error
|
||||
c.curMask, err = c.buildMask(ctx)
|
||||
if err != nil {
|
||||
// This error should never occur because we have previously built a mask with the same shape
|
||||
panic(fmt.Errorf("SetCausal: %w", err))
|
||||
}
|
||||
c.curMask = c.buildMask(ctx)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -652,10 +654,7 @@ func (c *Causal) shift(seq int, beginIndex, offset int32) error {
|
||||
}
|
||||
}
|
||||
|
||||
kShift, err := ctx.Input().FromIntSlice(offsets, len(offsets))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
kShift := ctx.Input().FromIntSlice(offsets, len(offsets))
|
||||
|
||||
for i, key := range c.keys {
|
||||
if key == nil {
|
||||
|
||||
@@ -344,7 +344,7 @@ func testCache(t *testing.T, backend ml.Backend, cache Cache, tests []testCase)
|
||||
}
|
||||
|
||||
cache.SetLayer(0)
|
||||
tensor, _ := context.FromFloatSlice(test.in, test.inShape...)
|
||||
tensor := context.FromFloatSlice(test.in, test.inShape...)
|
||||
cache.Put(context, tensor, tensor)
|
||||
|
||||
out, _, mask := cache.Get(context)
|
||||
@@ -386,7 +386,7 @@ func TestCanResume(t *testing.T) {
|
||||
}
|
||||
|
||||
cache.SetLayer(0)
|
||||
tensor, _ := context.FromFloatSlice([]float32{1, 2, 3, 4}, 1, 1, 4)
|
||||
tensor := context.FromFloatSlice([]float32{1, 2, 3, 4}, 1, 1, 4)
|
||||
cache.Put(context, tensor, tensor)
|
||||
|
||||
// with window size 4, nothing has slid out of the window yet
|
||||
@@ -413,7 +413,7 @@ func TestCanResume(t *testing.T) {
|
||||
}
|
||||
|
||||
cache.SetLayer(0)
|
||||
tensor, _ = context.FromFloatSlice([]float32{5, 6}, 1, 1, 2)
|
||||
tensor = context.FromFloatSlice([]float32{5, 6}, 1, 1, 2)
|
||||
cache.Put(context, tensor, tensor)
|
||||
|
||||
// only the latest position has overlapping windows
|
||||
@@ -470,24 +470,24 @@ func (c *testContext) Zeros(dtype ml.DType, shape ...int) ml.Tensor {
|
||||
return c.Empty(dtype, shape...)
|
||||
}
|
||||
|
||||
func (c *testContext) FromFloatSlice(s []float32, shape ...int) (ml.Tensor, error) {
|
||||
func (c *testContext) FromFloatSlice(s []float32, shape ...int) ml.Tensor {
|
||||
t := c.Empty(ml.DTypeF32, shape...).(*testTensor)
|
||||
|
||||
copy(t.data, s)
|
||||
|
||||
return t, nil
|
||||
return t
|
||||
}
|
||||
|
||||
func (c *testContext) FromIntSlice(s []int32, shape ...int) (ml.Tensor, error) {
|
||||
func (c *testContext) FromIntSlice(s []int32, shape ...int) ml.Tensor {
|
||||
f := make([]float32, len(s))
|
||||
for i := range f {
|
||||
f[i] = float32(s[i])
|
||||
}
|
||||
|
||||
out, _ := c.FromFloatSlice(f, shape...)
|
||||
out := c.FromFloatSlice(f, shape...)
|
||||
out.(*testTensor).dtype = ml.DTypeI32
|
||||
|
||||
return out, nil
|
||||
return out
|
||||
}
|
||||
|
||||
func (c *testContext) Arange(start, stop, step float32, dtype ml.DType) ml.Tensor {
|
||||
@@ -496,7 +496,7 @@ func (c *testContext) Arange(start, stop, step float32, dtype ml.DType) ml.Tenso
|
||||
s = append(s, i)
|
||||
}
|
||||
|
||||
out, _ := c.FromFloatSlice(s, len(s))
|
||||
out := c.FromFloatSlice(s, len(s))
|
||||
out.(*testTensor).dtype = dtype
|
||||
return out
|
||||
}
|
||||
@@ -508,7 +508,7 @@ func (c *testContext) Forward(...ml.Tensor) ml.Context { return c }
|
||||
|
||||
func (c *testContext) Compute(...ml.Tensor) {}
|
||||
|
||||
func (c *testContext) Reserve() error { return nil }
|
||||
func (c *testContext) Reserve() {}
|
||||
|
||||
func (c *testContext) MaxGraphNodes() int {
|
||||
return 10
|
||||
|
||||
@@ -580,7 +580,7 @@ func SchemaToGrammar(schema []byte) []byte {
|
||||
defer C.free(unsafe.Pointer(cStr))
|
||||
|
||||
// Allocate buffer for grammar based on schema length but with upper bound
|
||||
maxLen := min(1024*1024, len(schema)*4)
|
||||
maxLen := max(32768, min(1024*1024, len(schema)*4))
|
||||
buf := make([]byte, maxLen)
|
||||
|
||||
// Call C function to convert schema to grammar
|
||||
|
||||
156
llama/patches/0016-graph-memory-reporting-on-failure.patch
Normal file
156
llama/patches/0016-graph-memory-reporting-on-failure.patch
Normal file
@@ -0,0 +1,156 @@
|
||||
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
|
||||
From: Jesse Gross <jesse@ollama.com>
|
||||
Date: Fri, 18 Apr 2025 15:58:19 -0700
|
||||
Subject: [PATCH] graph memory reporting on failure
|
||||
|
||||
---
|
||||
ggml/include/ggml-alloc.h | 6 ++++++
|
||||
ggml/include/ggml-backend.h | 6 ++++++
|
||||
ggml/src/ggml-alloc.c | 38 +++++++++++++++++++++++++++++++++----
|
||||
ggml/src/ggml-backend.cpp | 10 ++++++++++
|
||||
4 files changed, 56 insertions(+), 4 deletions(-)
|
||||
|
||||
diff --git a/ggml/include/ggml-alloc.h b/ggml/include/ggml-alloc.h
|
||||
index 2cb150fd..781b1e10 100644
|
||||
--- a/ggml/include/ggml-alloc.h
|
||||
+++ b/ggml/include/ggml-alloc.h
|
||||
@@ -66,6 +66,12 @@ GGML_API bool ggml_gallocr_alloc_graph(ggml_gallocr_t galloc, struct ggml_cgraph
|
||||
|
||||
GGML_API size_t ggml_gallocr_get_buffer_size(ggml_gallocr_t galloc, int buffer_id);
|
||||
|
||||
+struct ggml_allocr_buffer_status {
|
||||
+ size_t size;
|
||||
+ bool allocated;
|
||||
+};
|
||||
+GGML_API struct ggml_allocr_buffer_status ggml_gallocr_get_attempted_buffer_size(ggml_gallocr_t galloc, int buffer_id);
|
||||
+
|
||||
// Utils
|
||||
// Create a buffer and allocate all the tensors in a ggml_context
|
||||
GGML_API struct ggml_backend_buffer * ggml_backend_alloc_ctx_tensors_from_buft(struct ggml_context * ctx, ggml_backend_buffer_type_t buft);
|
||||
diff --git a/ggml/include/ggml-backend.h b/ggml/include/ggml-backend.h
|
||||
index 778927f6..74e46716 100644
|
||||
--- a/ggml/include/ggml-backend.h
|
||||
+++ b/ggml/include/ggml-backend.h
|
||||
@@ -304,6 +304,12 @@ extern "C" {
|
||||
|
||||
GGML_API size_t ggml_backend_sched_get_buffer_size(ggml_backend_sched_t sched, ggml_backend_t backend);
|
||||
|
||||
+ struct ggml_backend_buffer_status {
|
||||
+ size_t size;
|
||||
+ bool allocated;
|
||||
+ };
|
||||
+ GGML_API struct ggml_backend_buffer_status ggml_backend_sched_get_attempted_buffer_size(ggml_backend_sched_t sched, ggml_backend_t backend);
|
||||
+
|
||||
GGML_API void ggml_backend_sched_set_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node, ggml_backend_t backend);
|
||||
GGML_API ggml_backend_t ggml_backend_sched_get_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node);
|
||||
|
||||
diff --git a/ggml/src/ggml-alloc.c b/ggml/src/ggml-alloc.c
|
||||
index 5fd379f6..04812990 100644
|
||||
--- a/ggml/src/ggml-alloc.c
|
||||
+++ b/ggml/src/ggml-alloc.c
|
||||
@@ -364,6 +364,7 @@ struct node_alloc {
|
||||
struct ggml_gallocr {
|
||||
ggml_backend_buffer_type_t * bufts; // [n_buffers]
|
||||
ggml_backend_buffer_t * buffers; // [n_buffers]
|
||||
+ size_t *buffer_sizes; // [n_buffers]
|
||||
struct ggml_dyn_tallocr ** buf_tallocs; // [n_buffers]
|
||||
int n_buffers;
|
||||
|
||||
@@ -387,6 +388,9 @@ ggml_gallocr_t ggml_gallocr_new_n(ggml_backend_buffer_type_t * bufts, int n_bufs
|
||||
galloc->buffers = calloc(n_bufs, sizeof(ggml_backend_buffer_t));
|
||||
GGML_ASSERT(galloc->buffers != NULL);
|
||||
|
||||
+ galloc->buffer_sizes = calloc(n_bufs, sizeof(size_t));
|
||||
+ GGML_ASSERT(galloc->buffer_sizes != NULL);
|
||||
+
|
||||
galloc->buf_tallocs = calloc(n_bufs, sizeof(struct ggml_dyn_tallocr *));
|
||||
GGML_ASSERT(galloc->buf_tallocs != NULL);
|
||||
|
||||
@@ -453,6 +457,7 @@ void ggml_gallocr_free(ggml_gallocr_t galloc) {
|
||||
ggml_hash_set_free(&galloc->hash_set);
|
||||
free(galloc->hash_values);
|
||||
free(galloc->bufts);
|
||||
+ free(galloc->buffer_sizes);
|
||||
free(galloc->buffers);
|
||||
free(galloc->buf_tallocs);
|
||||
free(galloc->node_allocs);
|
||||
@@ -748,6 +753,8 @@ bool ggml_gallocr_reserve_n(ggml_gallocr_t galloc, struct ggml_cgraph * graph, c
|
||||
}
|
||||
}
|
||||
|
||||
+ bool success = true;
|
||||
+
|
||||
// reallocate buffers if needed
|
||||
for (int i = 0; i < galloc->n_buffers; i++) {
|
||||
// if the buffer type is used multiple times, we reuse the same buffer
|
||||
@@ -769,15 +776,20 @@ bool ggml_gallocr_reserve_n(ggml_gallocr_t galloc, struct ggml_cgraph * graph, c
|
||||
|
||||
ggml_backend_buffer_free(galloc->buffers[i]);
|
||||
galloc->buffers[i] = ggml_backend_buft_alloc_buffer(galloc->bufts[i], new_size);
|
||||
- if (galloc->buffers[i] == NULL) {
|
||||
+ if (galloc->buffers[i]) {
|
||||
+ galloc->buffer_sizes[i] = ggml_backend_buffer_get_size(galloc->buffers[i]);
|
||||
+ ggml_backend_buffer_set_usage(galloc->buffers[i], GGML_BACKEND_BUFFER_USAGE_COMPUTE);
|
||||
+ } else {
|
||||
GGML_LOG_ERROR("%s: failed to allocate %s buffer of size %zu\n", __func__, ggml_backend_buft_name(galloc->bufts[i]), new_size);
|
||||
- return false;
|
||||
+ galloc->buffer_sizes[i] = new_size;
|
||||
+ success = false;
|
||||
}
|
||||
- ggml_backend_buffer_set_usage(galloc->buffers[i], GGML_BACKEND_BUFFER_USAGE_COMPUTE);
|
||||
+ } else {
|
||||
+ galloc->buffer_sizes[i] = ggml_backend_buffer_get_size(galloc->buffers[i]);
|
||||
}
|
||||
}
|
||||
|
||||
- return true;
|
||||
+ return success;
|
||||
}
|
||||
|
||||
bool ggml_gallocr_reserve(ggml_gallocr_t galloc, struct ggml_cgraph *graph) {
|
||||
@@ -934,6 +946,24 @@ size_t ggml_gallocr_get_buffer_size(ggml_gallocr_t galloc, int buffer_id) {
|
||||
return ggml_backend_buffer_get_size(galloc->buffers[buffer_id]);
|
||||
}
|
||||
|
||||
+struct ggml_allocr_buffer_status ggml_gallocr_get_attempted_buffer_size(ggml_gallocr_t galloc, int buffer_id) {
|
||||
+ GGML_ASSERT(buffer_id >= 0 && buffer_id < galloc->n_buffers);
|
||||
+
|
||||
+ for (int i = 0; i < buffer_id; i++) {
|
||||
+ if (galloc->buf_tallocs[i] == galloc->buf_tallocs[buffer_id]) {
|
||||
+ // This buffer is the same as a previous one due to the same buffer type being used multiple times
|
||||
+ // (See above.) However, we need a different check because multiple buffers might be NULL in our
|
||||
+ // case and we still want to know the attempted size.
|
||||
+
|
||||
+ struct ggml_allocr_buffer_status status = {0, true};
|
||||
+ return status;
|
||||
+ }
|
||||
+ }
|
||||
+
|
||||
+ struct ggml_allocr_buffer_status status = {galloc->buffer_sizes[buffer_id], galloc->buffers[buffer_id] != NULL};
|
||||
+ return status;
|
||||
+}
|
||||
+
|
||||
// utils
|
||||
|
||||
static void free_buffers(ggml_backend_buffer_t ** buffers, const size_t * n_buffers) {
|
||||
diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp
|
||||
index 0ce73a99..be335e8c 100644
|
||||
--- a/ggml/src/ggml-backend.cpp
|
||||
+++ b/ggml/src/ggml-backend.cpp
|
||||
@@ -1629,6 +1629,16 @@ size_t ggml_backend_sched_get_buffer_size(ggml_backend_sched_t sched, ggml_backe
|
||||
return ggml_gallocr_get_buffer_size(sched->galloc, backend_index);
|
||||
}
|
||||
|
||||
+struct ggml_backend_buffer_status ggml_backend_sched_get_attempted_buffer_size(ggml_backend_sched_t sched, ggml_backend_t backend) {
|
||||
+ int backend_index = ggml_backend_sched_backend_id(sched, backend);
|
||||
+ GGML_ASSERT(backend_index >= 0 && backend_index < sched->n_backends);
|
||||
+
|
||||
+ struct ggml_allocr_buffer_status allocr_status = ggml_gallocr_get_attempted_buffer_size(sched->galloc, backend_index);
|
||||
+ struct ggml_backend_buffer_status status = {allocr_status.size, allocr_status.allocated};
|
||||
+
|
||||
+ return status;
|
||||
+}
|
||||
+
|
||||
void ggml_backend_sched_set_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node, ggml_backend_t backend) {
|
||||
int backend_index = ggml_backend_sched_backend_id(sched, backend);
|
||||
GGML_ASSERT(backend_index >= 0 && backend_index < sched->n_backends);
|
||||
124
ml/backend.go
124
ml/backend.go
@@ -5,6 +5,7 @@ import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"math"
|
||||
"slices"
|
||||
"strconv"
|
||||
@@ -15,6 +16,10 @@ import (
|
||||
|
||||
type Backend interface {
|
||||
Load(ctx context.Context, progress func(float32)) error
|
||||
|
||||
// BackendMemory returns the memory allocations that were made for this model
|
||||
BackendMemory() BackendMemory
|
||||
|
||||
Config() fs.Config
|
||||
Get(name string) Tensor
|
||||
NewContext() Context
|
||||
@@ -68,6 +73,119 @@ type BackendParams struct {
|
||||
FlashAttention bool
|
||||
}
|
||||
|
||||
// ErrNoMem is returned when panicing due to insufficient memory. It includes
|
||||
// the attempted memory allocation.
|
||||
type ErrNoMem struct {
|
||||
BackendMemory
|
||||
}
|
||||
|
||||
func (e ErrNoMem) Error() string {
|
||||
return fmt.Sprintf("insufficient memory - required allocations: %+v", e.BackendMemory)
|
||||
}
|
||||
|
||||
type AllocationStatus int
|
||||
|
||||
const (
|
||||
// Unallocated memory - have not yet attempted to allocate
|
||||
Unallocated AllocationStatus = iota
|
||||
|
||||
// Failed memory - tried to allocate the memory and did not succeed
|
||||
Failed
|
||||
|
||||
// Allocated memory = tried and succeeded to allocate memory
|
||||
Allocated
|
||||
)
|
||||
|
||||
// Memory is the size of an allocation and whether it was successful.
|
||||
type Memory struct {
|
||||
Size uint64
|
||||
Status AllocationStatus
|
||||
}
|
||||
|
||||
func (m Memory) String() string {
|
||||
s := fmt.Sprint(m.Size)
|
||||
|
||||
switch m.Status {
|
||||
case Unallocated:
|
||||
s += "U"
|
||||
case Failed:
|
||||
s += "F"
|
||||
case Allocated:
|
||||
s += "A"
|
||||
}
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
// DeviceMemory provides a breakdown of the memory needed
|
||||
// per device, such as a CPU or GPU.
|
||||
type DeviceMemory struct {
|
||||
// Name is the name of the device as labeled by the backend. It
|
||||
// may not be persistent across instances of the runner.
|
||||
Name string
|
||||
|
||||
// Weights is the per-layer memory needed for the model weights.
|
||||
Weights []Memory
|
||||
|
||||
// Cache is the per-layer memory needed for the KV cache.
|
||||
Cache []Memory
|
||||
|
||||
// Graph is the size of the compute graph. It is not per-layer.
|
||||
Graph Memory
|
||||
}
|
||||
|
||||
func memoryPresent(mem []Memory) bool {
|
||||
return slices.ContainsFunc(mem, func(m Memory) bool { return m.Size != 0 })
|
||||
}
|
||||
|
||||
func (m DeviceMemory) LogValue() slog.Value {
|
||||
var attrs []slog.Attr
|
||||
if memoryPresent(m.Weights) {
|
||||
attrs = append(attrs, slog.Any("Weights", m.Weights))
|
||||
}
|
||||
|
||||
if memoryPresent(m.Cache) {
|
||||
attrs = append(attrs, slog.Any("Cache", m.Cache))
|
||||
}
|
||||
|
||||
if m.Graph.Size != 0 {
|
||||
attrs = append(attrs, slog.Any("Graph", m.Graph))
|
||||
}
|
||||
|
||||
return slog.GroupValue(attrs...)
|
||||
}
|
||||
|
||||
// BackendMemory provides the amount of memory required to load the model
|
||||
// per device based on the BackendParams. In some cases, not all required
|
||||
// allocations will be known at this point. However, the size of the most recent
|
||||
// allocation is guaranteed to be provided so that if it failed, the caller can
|
||||
// accommodate that to make forward progress.
|
||||
type BackendMemory struct {
|
||||
// InputsWeights are always located on the CPU and cannot be moved
|
||||
InputWeights Memory
|
||||
|
||||
// CPU model components are located in system memory. This does not
|
||||
// include unified memory allocated through the GPU.
|
||||
CPU DeviceMemory
|
||||
|
||||
// GPU model components are located on one or more GPUs.
|
||||
GPUs []DeviceMemory
|
||||
}
|
||||
|
||||
func (m BackendMemory) LogValue() slog.Value {
|
||||
var attrs []slog.Attr
|
||||
if m.InputWeights.Size != 0 {
|
||||
attrs = append(attrs, slog.Any("InputWeights", m.InputWeights))
|
||||
}
|
||||
|
||||
attrs = append(attrs, slog.Any(m.CPU.Name, m.CPU))
|
||||
for _, g := range m.GPUs {
|
||||
attrs = append(attrs, slog.Any(g.Name, g))
|
||||
}
|
||||
|
||||
return slog.GroupValue(attrs...)
|
||||
}
|
||||
|
||||
var backends = make(map[string]func(string, BackendParams) (Backend, error))
|
||||
|
||||
func RegisterBackend(name string, f func(string, BackendParams) (Backend, error)) {
|
||||
@@ -89,8 +207,8 @@ func NewBackend(modelPath string, params BackendParams) (Backend, error) {
|
||||
type Context interface {
|
||||
Empty(dtype DType, shape ...int) Tensor
|
||||
Zeros(dtype DType, shape ...int) Tensor
|
||||
FromFloatSlice(s []float32, shape ...int) (Tensor, error)
|
||||
FromIntSlice(s []int32, shape ...int) (Tensor, error)
|
||||
FromFloatSlice(s []float32, shape ...int) Tensor
|
||||
FromIntSlice(s []int32, shape ...int) Tensor
|
||||
|
||||
// Arange creates a 1D tensor with values within an interval (start, stop] increased by step.
|
||||
Arange(start, stop, step float32, dtype DType) Tensor
|
||||
@@ -102,7 +220,7 @@ type Context interface {
|
||||
// graph, simply preallocates memory. Typically called with a
|
||||
// worst case graph to ensure all resources are available for
|
||||
// for future inference.
|
||||
Reserve() error
|
||||
Reserve()
|
||||
|
||||
MaxGraphNodes() int
|
||||
Close()
|
||||
|
||||
@@ -10,7 +10,6 @@ import "C"
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
@@ -66,6 +65,12 @@ type Backend struct {
|
||||
// layers is the backend used for repeating layers
|
||||
layers map[int]*C.struct_ggml_backend_buffer_type
|
||||
|
||||
// requiredMemory is the cumulative memory allocations needed by the backend
|
||||
requiredMemory *ml.BackendMemory
|
||||
|
||||
// btDeviceMemory maps from a buffer type to the memory allocations associated with that device
|
||||
btDeviceMemory map[*C.struct_ggml_backend_buffer_type]*ml.DeviceMemory
|
||||
|
||||
flashAttention bool
|
||||
|
||||
// maxGraphNodes is the maximum allowed number of graph nodes in this scheduler
|
||||
@@ -94,6 +99,9 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) {
|
||||
"num_key_values", len(meta.KV()),
|
||||
)
|
||||
|
||||
var requiredMemory ml.BackendMemory
|
||||
btDeviceMemory := make(map[*C.struct_ggml_backend_buffer_type]*ml.DeviceMemory)
|
||||
|
||||
type deviceBufferType struct {
|
||||
d *C.struct_ggml_backend_device
|
||||
bts []*C.struct_ggml_backend_buffer_type
|
||||
@@ -114,6 +122,8 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) {
|
||||
}
|
||||
}
|
||||
|
||||
blocks := int(meta.KV().BlockCount())
|
||||
|
||||
// create list of buffer types for the cpu
|
||||
cpuDeviceBufferType := deviceBufferType{d: C.ggml_backend_dev_by_type(C.GGML_BACKEND_DEVICE_TYPE_CPU)}
|
||||
for _, d := range append(accels, append(gpus, cpus...)...) {
|
||||
@@ -121,17 +131,27 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) {
|
||||
case C.GGML_BACKEND_DEVICE_TYPE_CPU,
|
||||
C.GGML_BACKEND_DEVICE_TYPE_ACCEL:
|
||||
cpuDeviceBufferType.bts = append(cpuDeviceBufferType.bts, C.ggml_backend_dev_buffer_type(d))
|
||||
btDeviceMemory[C.ggml_backend_dev_buffer_type(d)] = &requiredMemory.CPU
|
||||
}
|
||||
}
|
||||
|
||||
requiredMemory.CPU.Name = C.GoString(C.ggml_backend_dev_name(cpuDeviceBufferType.d))
|
||||
requiredMemory.CPU.Weights = make([]ml.Memory, blocks+1)
|
||||
requiredMemory.CPU.Cache = make([]ml.Memory, blocks+1)
|
||||
|
||||
// create list of buffer types for each gpu
|
||||
var gpuDeviceBufferTypes []deviceBufferType
|
||||
for _, d := range gpus {
|
||||
requiredMemory.GPUs = make([]ml.DeviceMemory, len(gpus))
|
||||
for i, d := range gpus {
|
||||
bt := C.ggml_backend_dev_buffer_type(d)
|
||||
gpuDeviceBufferTypes = append(gpuDeviceBufferTypes, deviceBufferType{
|
||||
d: d,
|
||||
bts: append([]*C.struct_ggml_backend_buffer_type{bt}, cpuDeviceBufferType.bts...),
|
||||
})
|
||||
btDeviceMemory[bt] = &requiredMemory.GPUs[i]
|
||||
requiredMemory.GPUs[i].Name = C.GoString(C.ggml_backend_dev_name(d))
|
||||
requiredMemory.GPUs[i].Weights = make([]ml.Memory, blocks+1)
|
||||
requiredMemory.GPUs[i].Cache = make([]ml.Memory, blocks+1)
|
||||
}
|
||||
|
||||
useDefaultSplit := true
|
||||
@@ -170,8 +190,6 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) {
|
||||
// inputs always use cpu
|
||||
input := cpuDeviceBufferType
|
||||
|
||||
blocks := int(meta.KV().BlockCount())
|
||||
|
||||
// define a range of gpu layers. anything outside of this range is assigned to the cpu
|
||||
gpuRangeStart := max(0, blocks-params.NumGPULayers)
|
||||
gpuRangeStop := min(gpuRangeStart+params.NumGPULayers, blocks+1)
|
||||
@@ -212,7 +230,7 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) {
|
||||
|
||||
// contexts are shared by tensors of the same buffer type
|
||||
ctxs := make(map[*C.struct_ggml_backend_buffer_type]*C.struct_ggml_context)
|
||||
createTensor := func(t tensor, bts []*C.struct_ggml_backend_buffer_type) *C.struct_ggml_tensor {
|
||||
createTensor := func(t tensor, bts []*C.struct_ggml_backend_buffer_type, layer int) *C.struct_ggml_tensor {
|
||||
for _, bt := range bts {
|
||||
if _, ok := ctxs[bt]; !ok {
|
||||
ctxs[bt] = C.ggml_init(C.struct_ggml_init_params{
|
||||
@@ -238,6 +256,16 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) {
|
||||
C.ggml_set_name(tt, cname)
|
||||
|
||||
slog.Log(context.TODO(), logutil.LevelTrace, "created tensor", "name", name, "shape", t.source.Shape, "dtype", t.source.Kind, "buffer_type", C.GoString(C.ggml_backend_buft_name(bt)))
|
||||
|
||||
size := pad(C.ggml_backend_buft_get_alloc_size(bt, tt), C.ggml_backend_buft_get_alignment(bt))
|
||||
if layer == -1 {
|
||||
// Assume that InputWeights can be allocated - they're always in system memory and can't be moved in any case
|
||||
requiredMemory.InputWeights.Status = ml.Allocated
|
||||
requiredMemory.InputWeights.Size += uint64(size)
|
||||
} else {
|
||||
btDeviceMemory[bt].Weights[layer].Size += uint64(size)
|
||||
}
|
||||
|
||||
//nolint:staticcheck // TODO: check if buffer type supports this tensor
|
||||
return tt
|
||||
}
|
||||
@@ -259,22 +287,22 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) {
|
||||
for _, t := range meta.Tensors().Items() {
|
||||
switch {
|
||||
case contains(t.Name, "position_embd", "token_embd", "token_norm_embd", "token_types"):
|
||||
createTensor(tensor{source: t}, input.bts)
|
||||
createTensor(tensor{source: t}, input.bts, -1)
|
||||
if _, ok := meta.Tensors().GroupLayers()["output"]; !ok && t.Name == "token_embd.weight" {
|
||||
createTensor(tensor{source: t, target: "output.weight"}, output.bts)
|
||||
createTensor(tensor{source: t, target: "output.weight"}, output.bts, blocks)
|
||||
}
|
||||
case contains(t.Name, "cls", "output", "output_norm"):
|
||||
createTensor(tensor{source: t}, output.bts)
|
||||
createTensor(tensor{source: t}, output.bts, blocks)
|
||||
case strings.HasPrefix(t.Name, "v.") || strings.HasPrefix(t.Name, "mm."):
|
||||
// TODO: assign vision tensors to the gpu if possible
|
||||
createTensor(tensor{source: t}, output.bts)
|
||||
createTensor(tensor{source: t}, output.bts, blocks)
|
||||
case contains(t.Name, "rope_freqs", "rope_factors_long", "rope_factors_short"):
|
||||
// these tensors should be repeated per layer
|
||||
for i, layer := range layers {
|
||||
createTensor(tensor{
|
||||
source: t,
|
||||
target: "blk." + strconv.Itoa(i) + "." + t.Name,
|
||||
}, layer.bts)
|
||||
}, layer.bts, i)
|
||||
}
|
||||
default:
|
||||
layerIndex := -1
|
||||
@@ -285,10 +313,10 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) {
|
||||
}
|
||||
|
||||
if layerIndex >= 0 {
|
||||
createTensor(tensor{source: t}, layers[layerIndex].bts)
|
||||
createTensor(tensor{source: t}, layers[layerIndex].bts, layerIndex)
|
||||
} else {
|
||||
// load all other tensors on the cpu
|
||||
createTensor(tensor{source: t}, input.bts)
|
||||
createTensor(tensor{source: t}, input.bts, -1)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -301,8 +329,18 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) {
|
||||
}
|
||||
|
||||
b := C.ggml_backend_alloc_ctx_tensors_from_buft(c, bt)
|
||||
for i := range btDeviceMemory[bt].Weights {
|
||||
if btDeviceMemory[bt].Weights[i].Size != 0 {
|
||||
if b != nil {
|
||||
btDeviceMemory[bt].Weights[i].Status = ml.Allocated
|
||||
} else {
|
||||
btDeviceMemory[bt].Weights[i].Status = ml.Failed
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if b == nil {
|
||||
return nil, fmt.Errorf("unable to allocate memory from device %v for model weights", C.GoString(C.ggml_backend_buft_name(bt)))
|
||||
panic(ml.ErrNoMem{BackendMemory: requiredMemory})
|
||||
}
|
||||
|
||||
C.ggml_backend_buffer_set_usage(b, C.GGML_BACKEND_BUFFER_USAGE_WEIGHTS)
|
||||
@@ -367,7 +405,9 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) {
|
||||
}
|
||||
return m
|
||||
}(),
|
||||
maxGraphNodes: maxGraphNodes,
|
||||
requiredMemory: &requiredMemory,
|
||||
btDeviceMemory: btDeviceMemory,
|
||||
maxGraphNodes: maxGraphNodes,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -446,6 +486,10 @@ func (b *Backend) Load(ctx context.Context, progress func(float32)) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *Backend) BackendMemory() ml.BackendMemory {
|
||||
return *b.requiredMemory
|
||||
}
|
||||
|
||||
func (b *Backend) Config() fs.Config {
|
||||
return b.meta.KV()
|
||||
}
|
||||
@@ -477,6 +521,7 @@ func (b *Backend) NewContextSize(n int) ml.Context {
|
||||
no_alloc: true,
|
||||
}),
|
||||
allocatedBuffers: &allocatedBuffers,
|
||||
layer: -1,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -503,6 +548,9 @@ type Context struct {
|
||||
|
||||
// maxGraphNodes is the maximum allowed number of graph nodes in this context
|
||||
maxGraphNodes int
|
||||
|
||||
// layer is the graph layer that this context is allocating for - assumed to be cache
|
||||
layer int
|
||||
}
|
||||
|
||||
func (c *Context) Input() ml.Context {
|
||||
@@ -513,6 +561,7 @@ func (c *Context) Input() ml.Context {
|
||||
buft: c.b.input,
|
||||
allocatedBuffers: c.allocatedBuffers,
|
||||
maxGraphNodes: c.maxGraphNodes,
|
||||
layer: -1,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -527,6 +576,7 @@ func (c *Context) Layer(i int) ml.Context {
|
||||
buft: buft,
|
||||
allocatedBuffers: c.allocatedBuffers,
|
||||
maxGraphNodes: c.maxGraphNodes,
|
||||
layer: i,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -564,22 +614,34 @@ func (c *Context) Compute(tensors ...ml.Tensor) {
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Context) Reserve() error {
|
||||
if !C.ggml_backend_sched_reserve(c.b.sched, c.graph) {
|
||||
C.ggml_backend_sched_reset(c.b.sched)
|
||||
return errors.New("failed to reserve graph")
|
||||
}
|
||||
func (c *Context) Reserve() {
|
||||
reserved := C.ggml_backend_sched_reserve(c.b.sched, c.graph)
|
||||
|
||||
slog.Debug("compute graph", "nodes", C.ggml_graph_n_nodes(c.graph), "splits", C.ggml_backend_sched_get_n_splits(c.b.sched))
|
||||
for i := range c.b.schedBackends {
|
||||
size := C.ggml_backend_sched_get_buffer_size(c.b.sched, c.b.schedBackends[i])
|
||||
slog.Info("compute graph", "backend", C.GoString(C.ggml_backend_name(c.b.schedBackends[i])), "buffer_type", C.GoString(C.ggml_backend_buft_name(c.b.schedBufts[i])),
|
||||
"size", format.HumanBytes2(uint64(size)))
|
||||
|
||||
// Reserve may get called multiple times for different graphs - we just want the last run, which will contain the max allocations
|
||||
for _, bt := range c.b.schedBufts {
|
||||
c.b.btDeviceMemory[bt].Graph = ml.Memory{}
|
||||
}
|
||||
|
||||
C.ggml_backend_sched_reset(c.b.sched)
|
||||
for i := range c.b.schedBackends {
|
||||
bufferStatus := C.ggml_backend_sched_get_attempted_buffer_size(c.b.sched, c.b.schedBackends[i])
|
||||
|
||||
return nil
|
||||
graph := &c.b.btDeviceMemory[c.b.schedBufts[i]].Graph
|
||||
graph.Size += uint64(bufferStatus.size)
|
||||
if bufferStatus.allocated && graph.Status != ml.Failed {
|
||||
graph.Status = ml.Allocated
|
||||
} else {
|
||||
graph.Status = ml.Failed
|
||||
}
|
||||
|
||||
slog.Info("compute graph", "backend", C.GoString(C.ggml_backend_name(c.b.schedBackends[i])), "buffer_type", C.GoString(C.ggml_backend_buft_name(c.b.schedBufts[i])),
|
||||
"size", format.HumanBytes2(uint64(bufferStatus.size)))
|
||||
}
|
||||
|
||||
if !reserved {
|
||||
panic(ml.ErrNoMem{BackendMemory: *c.b.requiredMemory})
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Context) MaxGraphNodes() int {
|
||||
@@ -599,7 +661,7 @@ func pad(length, pad C.size_t) C.size_t {
|
||||
return ((length + pad - 1) / pad) * pad
|
||||
}
|
||||
|
||||
func (c *Context) newTensor(dtype ml.DType, shape []int) (ml.Tensor, error) {
|
||||
func (c *Context) newTensor(dtype ml.DType, shape []int) ml.Tensor {
|
||||
if c.buft == nil {
|
||||
panic("set Input or Layer before creating tensors")
|
||||
}
|
||||
@@ -622,7 +684,7 @@ func (c *Context) newTensor(dtype ml.DType, shape []int) (ml.Tensor, error) {
|
||||
|
||||
if len(shape) < 1 || shape[0] == 0 {
|
||||
var shape C.int64_t = 0
|
||||
return &Tensor{b: c.b, t: C.ggml_new_tensor(c.ctx, cdtype, 1, &shape)}, nil
|
||||
return &Tensor{b: c.b, t: C.ggml_new_tensor(c.ctx, cdtype, 1, &shape)}
|
||||
} else if len(shape) > 4 {
|
||||
panic("unsupported number of dimensions")
|
||||
}
|
||||
@@ -635,40 +697,43 @@ func (c *Context) newTensor(dtype ml.DType, shape []int) (ml.Tensor, error) {
|
||||
|
||||
t := C.ggml_new_tensor(c.ctx, cdtype, C.int(len(shape)), shapeToGGML(shape))
|
||||
size := pad(C.ggml_backend_buft_get_alloc_size(c.buft, t), C.ggml_backend_buft_get_alignment(c.buft))
|
||||
b := C.ggml_backend_buft_alloc_buffer(c.buft, size)
|
||||
if b == nil {
|
||||
return nil, fmt.Errorf("unable to allocate %v from device %v for new tensor", format.HumanBytes2(uint64(size)), C.GoString(C.ggml_backend_buft_name(c.buft)))
|
||||
}
|
||||
*c.allocatedBuffers = append(*c.allocatedBuffers, b)
|
||||
|
||||
b := C.ggml_backend_buft_alloc_buffer(c.buft, size)
|
||||
if c.layer >= 0 {
|
||||
cache := &c.b.btDeviceMemory[c.buft].Cache[c.layer]
|
||||
|
||||
cache.Size += uint64(size)
|
||||
if b != nil {
|
||||
cache.Status = ml.Allocated
|
||||
} else {
|
||||
cache.Status = ml.Failed
|
||||
}
|
||||
}
|
||||
|
||||
if b == nil {
|
||||
panic(ml.ErrNoMem{BackendMemory: *c.b.requiredMemory})
|
||||
}
|
||||
|
||||
*c.allocatedBuffers = append(*c.allocatedBuffers, b)
|
||||
C.ggml_backend_tensor_alloc(b, t, C.ggml_backend_buffer_get_base(b))
|
||||
return &Tensor{b: c.b, t: t}, nil
|
||||
return &Tensor{b: c.b, t: t}
|
||||
}
|
||||
|
||||
func (c *Context) Empty(dtype ml.DType, shape ...int) ml.Tensor {
|
||||
t, err := c.newTensor(dtype, shape)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return t
|
||||
return c.newTensor(dtype, shape)
|
||||
}
|
||||
|
||||
func (c *Context) Zeros(dtype ml.DType, shape ...int) ml.Tensor {
|
||||
t, err := c.newTensor(dtype, shape)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
t := c.newTensor(dtype, shape)
|
||||
C.ggml_set_zero(t.(*Tensor).t)
|
||||
return t
|
||||
}
|
||||
|
||||
func checkShape[S ~[]E, E any](s S, shape ...int) error {
|
||||
func checkShape[S ~[]E, E any](s S, shape ...int) {
|
||||
n := len(s)
|
||||
|
||||
if n == 0 {
|
||||
return nil
|
||||
return
|
||||
}
|
||||
|
||||
for _, v := range shape {
|
||||
@@ -676,44 +741,32 @@ func checkShape[S ~[]E, E any](s S, shape ...int) error {
|
||||
}
|
||||
|
||||
if n != 1 {
|
||||
return fmt.Errorf("invalid shape: %v", shape)
|
||||
panic(fmt.Errorf("invalid shape: %v", shape))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Context) FromFloatSlice(s []float32, shape ...int) (ml.Tensor, error) {
|
||||
if err := checkShape(s, shape...); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
func (c *Context) FromFloatSlice(s []float32, shape ...int) ml.Tensor {
|
||||
checkShape(s, shape...)
|
||||
|
||||
t, err := c.newTensor(ml.DTypeF32, shape)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
t := c.newTensor(ml.DTypeF32, shape)
|
||||
|
||||
if len(s) > 0 {
|
||||
C.ggml_backend_tensor_set(t.(*Tensor).t, unsafe.Pointer(&s[0]), 0, C.ggml_nbytes(t.(*Tensor).t))
|
||||
}
|
||||
|
||||
return t, nil
|
||||
return t
|
||||
}
|
||||
|
||||
func (c *Context) FromIntSlice(s []int32, shape ...int) (ml.Tensor, error) {
|
||||
if err := checkShape(s, shape...); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
func (c *Context) FromIntSlice(s []int32, shape ...int) ml.Tensor {
|
||||
checkShape(s, shape...)
|
||||
|
||||
t, err := c.newTensor(ml.DTypeI32, shape)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
t := c.newTensor(ml.DTypeI32, shape)
|
||||
|
||||
if len(s) > 0 {
|
||||
C.ggml_backend_tensor_set(t.(*Tensor).t, unsafe.Pointer(&s[0]), 0, C.ggml_nbytes(t.(*Tensor).t))
|
||||
}
|
||||
|
||||
return t, nil
|
||||
return t
|
||||
}
|
||||
|
||||
func (c Context) Arange(start, stop, step float32, dtype ml.DType) ml.Tensor {
|
||||
@@ -731,12 +784,7 @@ func (c Context) Arange(start, stop, step float32, dtype ml.DType) ml.Tensor {
|
||||
arange = append(arange, int32(i))
|
||||
}
|
||||
|
||||
t, err := c.Input().FromIntSlice(arange, len(arange))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return t
|
||||
return c.Input().FromIntSlice(arange, len(arange))
|
||||
default:
|
||||
panic("unsupported dtype for arange")
|
||||
}
|
||||
|
||||
6
ml/backend/ggml/ggml/include/ggml-alloc.h
vendored
6
ml/backend/ggml/ggml/include/ggml-alloc.h
vendored
@@ -66,6 +66,12 @@ GGML_API bool ggml_gallocr_alloc_graph(ggml_gallocr_t galloc, struct ggml_cgraph
|
||||
|
||||
GGML_API size_t ggml_gallocr_get_buffer_size(ggml_gallocr_t galloc, int buffer_id);
|
||||
|
||||
struct ggml_allocr_buffer_status {
|
||||
size_t size;
|
||||
bool allocated;
|
||||
};
|
||||
GGML_API struct ggml_allocr_buffer_status ggml_gallocr_get_attempted_buffer_size(ggml_gallocr_t galloc, int buffer_id);
|
||||
|
||||
// Utils
|
||||
// Create a buffer and allocate all the tensors in a ggml_context
|
||||
GGML_API struct ggml_backend_buffer * ggml_backend_alloc_ctx_tensors_from_buft(struct ggml_context * ctx, ggml_backend_buffer_type_t buft);
|
||||
|
||||
6
ml/backend/ggml/ggml/include/ggml-backend.h
vendored
6
ml/backend/ggml/ggml/include/ggml-backend.h
vendored
@@ -304,6 +304,12 @@ extern "C" {
|
||||
|
||||
GGML_API size_t ggml_backend_sched_get_buffer_size(ggml_backend_sched_t sched, ggml_backend_t backend);
|
||||
|
||||
struct ggml_backend_buffer_status {
|
||||
size_t size;
|
||||
bool allocated;
|
||||
};
|
||||
GGML_API struct ggml_backend_buffer_status ggml_backend_sched_get_attempted_buffer_size(ggml_backend_sched_t sched, ggml_backend_t backend);
|
||||
|
||||
GGML_API void ggml_backend_sched_set_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node, ggml_backend_t backend);
|
||||
GGML_API ggml_backend_t ggml_backend_sched_get_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node);
|
||||
|
||||
|
||||
38
ml/backend/ggml/ggml/src/ggml-alloc.c
vendored
38
ml/backend/ggml/ggml/src/ggml-alloc.c
vendored
@@ -364,6 +364,7 @@ struct node_alloc {
|
||||
struct ggml_gallocr {
|
||||
ggml_backend_buffer_type_t * bufts; // [n_buffers]
|
||||
ggml_backend_buffer_t * buffers; // [n_buffers]
|
||||
size_t *buffer_sizes; // [n_buffers]
|
||||
struct ggml_dyn_tallocr ** buf_tallocs; // [n_buffers]
|
||||
int n_buffers;
|
||||
|
||||
@@ -387,6 +388,9 @@ ggml_gallocr_t ggml_gallocr_new_n(ggml_backend_buffer_type_t * bufts, int n_bufs
|
||||
galloc->buffers = calloc(n_bufs, sizeof(ggml_backend_buffer_t));
|
||||
GGML_ASSERT(galloc->buffers != NULL);
|
||||
|
||||
galloc->buffer_sizes = calloc(n_bufs, sizeof(size_t));
|
||||
GGML_ASSERT(galloc->buffer_sizes != NULL);
|
||||
|
||||
galloc->buf_tallocs = calloc(n_bufs, sizeof(struct ggml_dyn_tallocr *));
|
||||
GGML_ASSERT(galloc->buf_tallocs != NULL);
|
||||
|
||||
@@ -453,6 +457,7 @@ void ggml_gallocr_free(ggml_gallocr_t galloc) {
|
||||
ggml_hash_set_free(&galloc->hash_set);
|
||||
free(galloc->hash_values);
|
||||
free(galloc->bufts);
|
||||
free(galloc->buffer_sizes);
|
||||
free(galloc->buffers);
|
||||
free(galloc->buf_tallocs);
|
||||
free(galloc->node_allocs);
|
||||
@@ -748,6 +753,8 @@ bool ggml_gallocr_reserve_n(ggml_gallocr_t galloc, struct ggml_cgraph * graph, c
|
||||
}
|
||||
}
|
||||
|
||||
bool success = true;
|
||||
|
||||
// reallocate buffers if needed
|
||||
for (int i = 0; i < galloc->n_buffers; i++) {
|
||||
// if the buffer type is used multiple times, we reuse the same buffer
|
||||
@@ -769,15 +776,20 @@ bool ggml_gallocr_reserve_n(ggml_gallocr_t galloc, struct ggml_cgraph * graph, c
|
||||
|
||||
ggml_backend_buffer_free(galloc->buffers[i]);
|
||||
galloc->buffers[i] = ggml_backend_buft_alloc_buffer(galloc->bufts[i], new_size);
|
||||
if (galloc->buffers[i] == NULL) {
|
||||
if (galloc->buffers[i]) {
|
||||
galloc->buffer_sizes[i] = ggml_backend_buffer_get_size(galloc->buffers[i]);
|
||||
ggml_backend_buffer_set_usage(galloc->buffers[i], GGML_BACKEND_BUFFER_USAGE_COMPUTE);
|
||||
} else {
|
||||
GGML_LOG_ERROR("%s: failed to allocate %s buffer of size %zu\n", __func__, ggml_backend_buft_name(galloc->bufts[i]), new_size);
|
||||
return false;
|
||||
galloc->buffer_sizes[i] = new_size;
|
||||
success = false;
|
||||
}
|
||||
ggml_backend_buffer_set_usage(galloc->buffers[i], GGML_BACKEND_BUFFER_USAGE_COMPUTE);
|
||||
} else {
|
||||
galloc->buffer_sizes[i] = ggml_backend_buffer_get_size(galloc->buffers[i]);
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
return success;
|
||||
}
|
||||
|
||||
bool ggml_gallocr_reserve(ggml_gallocr_t galloc, struct ggml_cgraph *graph) {
|
||||
@@ -934,6 +946,24 @@ size_t ggml_gallocr_get_buffer_size(ggml_gallocr_t galloc, int buffer_id) {
|
||||
return ggml_backend_buffer_get_size(galloc->buffers[buffer_id]);
|
||||
}
|
||||
|
||||
struct ggml_allocr_buffer_status ggml_gallocr_get_attempted_buffer_size(ggml_gallocr_t galloc, int buffer_id) {
|
||||
GGML_ASSERT(buffer_id >= 0 && buffer_id < galloc->n_buffers);
|
||||
|
||||
for (int i = 0; i < buffer_id; i++) {
|
||||
if (galloc->buf_tallocs[i] == galloc->buf_tallocs[buffer_id]) {
|
||||
// This buffer is the same as a previous one due to the same buffer type being used multiple times
|
||||
// (See above.) However, we need a different check because multiple buffers might be NULL in our
|
||||
// case and we still want to know the attempted size.
|
||||
|
||||
struct ggml_allocr_buffer_status status = {0, true};
|
||||
return status;
|
||||
}
|
||||
}
|
||||
|
||||
struct ggml_allocr_buffer_status status = {galloc->buffer_sizes[buffer_id], galloc->buffers[buffer_id] != NULL};
|
||||
return status;
|
||||
}
|
||||
|
||||
// utils
|
||||
|
||||
static void free_buffers(ggml_backend_buffer_t ** buffers, const size_t * n_buffers) {
|
||||
|
||||
10
ml/backend/ggml/ggml/src/ggml-backend.cpp
vendored
10
ml/backend/ggml/ggml/src/ggml-backend.cpp
vendored
@@ -1629,6 +1629,16 @@ size_t ggml_backend_sched_get_buffer_size(ggml_backend_sched_t sched, ggml_backe
|
||||
return ggml_gallocr_get_buffer_size(sched->galloc, backend_index);
|
||||
}
|
||||
|
||||
struct ggml_backend_buffer_status ggml_backend_sched_get_attempted_buffer_size(ggml_backend_sched_t sched, ggml_backend_t backend) {
|
||||
int backend_index = ggml_backend_sched_backend_id(sched, backend);
|
||||
GGML_ASSERT(backend_index >= 0 && backend_index < sched->n_backends);
|
||||
|
||||
struct ggml_allocr_buffer_status allocr_status = ggml_gallocr_get_attempted_buffer_size(sched->galloc, backend_index);
|
||||
struct ggml_backend_buffer_status status = {allocr_status.size, allocr_status.allocated};
|
||||
|
||||
return status;
|
||||
}
|
||||
|
||||
void ggml_backend_sched_set_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node, ggml_backend_t backend) {
|
||||
int backend_index = ggml_backend_sched_backend_id(sched, backend);
|
||||
GGML_ASSERT(backend_index >= 0 && backend_index < sched->n_backends);
|
||||
|
||||
@@ -3,6 +3,7 @@ package model
|
||||
import (
|
||||
"cmp"
|
||||
"context"
|
||||
"fmt"
|
||||
"iter"
|
||||
"log/slog"
|
||||
"strings"
|
||||
@@ -210,6 +211,14 @@ func (bpe BytePairEncoding) Encode(s string, addSpecial bool) ([]int32, error) {
|
||||
return ids, nil
|
||||
}
|
||||
|
||||
type lazyIdsString struct {
|
||||
ids []int32
|
||||
}
|
||||
|
||||
func (l lazyIdsString) LogValue() slog.Value {
|
||||
return slog.AnyValue(fmt.Sprint(l.ids))
|
||||
}
|
||||
|
||||
func (bpe BytePairEncoding) Decode(ids []int32) (string, error) {
|
||||
var sb strings.Builder
|
||||
for _, id := range ids {
|
||||
@@ -234,6 +243,6 @@ func (bpe BytePairEncoding) Decode(ids []int32) (string, error) {
|
||||
}
|
||||
}
|
||||
|
||||
slog.Log(context.TODO(), logutil.LevelTrace, "decoded", "ids", ids, "string", sb.String())
|
||||
slog.Log(context.TODO(), logutil.LevelTrace, "decoded", "string", sb.String(), "from", lazyIdsString{ids: ids})
|
||||
return sb.String(), nil
|
||||
}
|
||||
|
||||
@@ -287,11 +287,7 @@ func Forward(ctx ml.Context, m Model, inputs []int32, batch input.Batch) (ml.Ten
|
||||
return nil, errors.New("batch size cannot be less than 1")
|
||||
}
|
||||
|
||||
var err error
|
||||
batch.Inputs, err = ctx.Input().FromIntSlice(inputs, len(inputs))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
batch.Inputs = ctx.Input().FromIntSlice(inputs, len(inputs))
|
||||
|
||||
cache := m.Config().Cache
|
||||
if cache != nil {
|
||||
|
||||
@@ -175,15 +175,8 @@ func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Ten
|
||||
}
|
||||
|
||||
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||
positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
outputs, err := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
|
||||
outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
|
||||
|
||||
hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs)
|
||||
hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.Options.hiddenSize)))
|
||||
|
||||
@@ -101,14 +101,11 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input
|
||||
return nil, err
|
||||
}
|
||||
|
||||
pixelValues, err := ctx.Input().FromFloatSlice(f32s,
|
||||
pixelValues := ctx.Input().FromFloatSlice(f32s,
|
||||
m.ImageProcessor.imageSize,
|
||||
m.ImageProcessor.imageSize,
|
||||
m.ImageProcessor.numChannels,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
visionOutputs := m.VisionModel.Forward(ctx, pixelValues)
|
||||
visionOutputs = m.MultiModalProjector.Forward(ctx, visionOutputs, m.imageSize, m.patchSize, m.VisionModel.eps)
|
||||
@@ -144,15 +141,8 @@ func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
|
||||
}
|
||||
|
||||
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||
positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
outputs, err := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
|
||||
outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
|
||||
|
||||
return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, batch, m.Cache), nil
|
||||
}
|
||||
|
||||
@@ -142,10 +142,7 @@ func (l *Layer) Forward(ctx ml.Context, hiddenState, positions, outputs ml.Tenso
|
||||
}
|
||||
|
||||
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||
positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
|
||||
|
||||
hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs)
|
||||
|
||||
@@ -154,10 +151,7 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||
|
||||
var outputs ml.Tensor
|
||||
if i == len(m.Layers)-1 {
|
||||
outputs, err = ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
outputs = ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
|
||||
}
|
||||
|
||||
hiddenState = layer.Forward(ctx, hiddenState, positions, outputs, m.Cache, m.Options)
|
||||
|
||||
@@ -77,10 +77,7 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tilesLocal, err := ctx.Input().FromFloatSlice(pixelsLocal, size.X, size.Y, m.numChannels)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tilesLocal := ctx.Input().FromFloatSlice(pixelsLocal, size.X, size.Y, m.numChannels)
|
||||
|
||||
ratioW, ratioH := size.X/m.imageSize, size.Y/m.imageSize
|
||||
|
||||
@@ -91,11 +88,7 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input
|
||||
pixelValues := tilesLocal
|
||||
|
||||
if len(pixelsGlobal) > 0 {
|
||||
tilesGlobal, err := ctx.Input().FromFloatSlice(pixelsGlobal, m.imageSize, m.imageSize, m.numChannels)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tilesGlobal := ctx.Input().FromFloatSlice(pixelsGlobal, m.imageSize, m.imageSize, m.numChannels)
|
||||
pixelValues = pixelValues.Concat(ctx, tilesGlobal, 3)
|
||||
}
|
||||
|
||||
@@ -182,15 +175,8 @@ func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
|
||||
}
|
||||
|
||||
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||
positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
outputs, err := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
|
||||
outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
|
||||
|
||||
return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, batch, m.Cache), nil
|
||||
}
|
||||
|
||||
@@ -223,11 +223,7 @@ func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor
|
||||
scales[i] = float32(math.Log(math.Floor(((float64(p)+1.0)/float64(m.attentionFloorScale))+1.0))*m.attentionScale + 1.0)
|
||||
}
|
||||
|
||||
var err error
|
||||
attentionScales, err = ctx.Input().FromFloatSlice(scales, 1, 1, len(scales))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
attentionScales = ctx.Input().FromFloatSlice(scales, 1, 1, len(scales))
|
||||
}
|
||||
|
||||
for i, layer := range m.Layers {
|
||||
|
||||
@@ -245,10 +245,7 @@ func (m *VisionModel) rotaryEmbedding(ctx ml.Context) (ml.Tensor, ml.Tensor) {
|
||||
}
|
||||
}
|
||||
|
||||
ropeFreqs, err := ctx.Input().FromFloatSlice(freqs, freqDim/2, numPatches, 2)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
ropeFreqs := ctx.Input().FromFloatSlice(freqs, freqDim/2, numPatches, 2)
|
||||
|
||||
ropeFreqs = ropeFreqs.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
|
||||
ropeFreqs = ropeFreqs.Reshape(ctx, freqDim, 1, numPatches)
|
||||
|
||||
@@ -114,10 +114,7 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input
|
||||
return nil, err
|
||||
}
|
||||
|
||||
pixelValues, err := ctx.Input().FromFloatSlice(f32s, size.X, size.Y, m.ImageProcessor.numChannels)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
pixelValues := ctx.Input().FromFloatSlice(f32s, size.X, size.Y, m.ImageProcessor.numChannels)
|
||||
|
||||
visionOutputs := m.VisionModel.Forward(ctx, pixelValues)
|
||||
features, size := m.MultiModalProjector.Forward(ctx, visionOutputs, size)
|
||||
@@ -161,15 +158,8 @@ func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
|
||||
}
|
||||
|
||||
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||
positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
outputs, err := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
|
||||
outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
|
||||
|
||||
return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, batch, m.Cache), nil
|
||||
}
|
||||
|
||||
@@ -110,15 +110,8 @@ func (m *VisionModel) positionalEmbedding(ctx ml.Context, positionIDs ml.Tensor)
|
||||
}
|
||||
}
|
||||
|
||||
h, err := ctx.Input().FromFloatSlice(frequenciesHeight, maxPatchesPerSide, frequencies/2)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
w, err := ctx.Input().FromFloatSlice(frequenciesWidth, maxPatchesPerSide, frequencies/2)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
h := ctx.Input().FromFloatSlice(frequenciesHeight, maxPatchesPerSide, frequencies/2)
|
||||
w := ctx.Input().FromFloatSlice(frequenciesWidth, maxPatchesPerSide, frequencies/2)
|
||||
|
||||
h = h.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
|
||||
w = w.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
|
||||
@@ -151,10 +144,7 @@ func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor) ml.Tensor {
|
||||
}
|
||||
}
|
||||
|
||||
positionIDs, err := ctx.Input().FromIntSlice(positions, len(positions))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
positionIDs := ctx.Input().FromIntSlice(positions, len(positions))
|
||||
|
||||
positionEmbedding := m.positionalEmbedding(ctx, positionIDs)
|
||||
cos, sin := positionEmbedding.Cos(ctx), positionEmbedding.Sin(ctx)
|
||||
|
||||
@@ -80,15 +80,8 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input
|
||||
f32s = f32s[:m.imageSize*m.imageSize*m.numChannels*m.maxNumTiles]
|
||||
}
|
||||
|
||||
pixelValues, err := ctx.Input().FromFloatSlice(f32s, m.imageSize, m.imageSize, m.numChannels, m.maxNumTiles)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
aspectRatio, err := ctx.Input().FromIntSlice([]int32{int32(ratio.rank)}, 1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
pixelValues := ctx.Input().FromFloatSlice(f32s, m.imageSize, m.imageSize, m.numChannels, m.maxNumTiles)
|
||||
aspectRatio := ctx.Input().FromIntSlice([]int32{int32(ratio.rank)}, 1)
|
||||
|
||||
positionIDs := ctx.Arange(0, 1601, 1, ml.DTypeI32)
|
||||
crossAttentionStates := m.VisionModel.Forward(ctx, pixelValues, positionIDs, aspectRatio)
|
||||
@@ -113,15 +106,8 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||
crossAttentionStates = batch.Multimodal[len(batch.Multimodal)-1].Multimodal[0].Tensor
|
||||
}
|
||||
|
||||
positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
outputs, err := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
|
||||
outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
|
||||
|
||||
// TODO: attention mask, cross attention mask
|
||||
return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, crossAttentionStates, nil, m.Cache.(*kvcache.WrapperCache)), nil
|
||||
|
||||
@@ -16,8 +16,6 @@ type VisionSelfAttention struct {
|
||||
Key *nn.Linear `gguf:"attn_k"`
|
||||
Value *nn.Linear `gguf:"attn_v"`
|
||||
Output *nn.Linear `gguf:"attn_output"`
|
||||
|
||||
Gate ml.Tensor `gguf:"attn_gate"`
|
||||
}
|
||||
|
||||
func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *VisionModelOptions) ml.Tensor {
|
||||
@@ -25,27 +23,16 @@ func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenState ml.Tensor, op
|
||||
|
||||
query := sa.Query.Forward(ctx, hiddenState)
|
||||
query = query.Reshape(ctx, headDim, opts.numHeads, query.Dim(1), batchSize)
|
||||
query = query.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
|
||||
|
||||
key := sa.Key.Forward(ctx, hiddenState)
|
||||
key = key.Reshape(ctx, headDim, opts.numHeads, key.Dim(1), batchSize)
|
||||
key = key.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
|
||||
|
||||
value := sa.Value.Forward(ctx, hiddenState)
|
||||
value = value.Reshape(ctx, headDim, opts.numHeads, value.Dim(1), batchSize)
|
||||
value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
|
||||
|
||||
scores := key.Mulmat(ctx, query)
|
||||
scores = scores.Scale(ctx, 1.0/math.Sqrt(float64(headDim)))
|
||||
scores = scores.Softmax(ctx)
|
||||
|
||||
attention := value.Mulmat(ctx, scores)
|
||||
attention = attention.Reshape(ctx, headDim, attention.Dim(1), opts.numHeads, batchSize)
|
||||
attention = attention.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
|
||||
attention := nn.Attention(ctx, query, key, value, 1./math.Sqrt(float64(headDim)), nil)
|
||||
attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2), batchSize)
|
||||
|
||||
hiddenState = sa.Output.Forward(ctx, attention)
|
||||
return hiddenState
|
||||
return sa.Output.Forward(ctx, attention)
|
||||
}
|
||||
|
||||
type VisionMLP struct {
|
||||
@@ -76,21 +63,18 @@ func (e *VisionEncoderLayer) Forward(ctx ml.Context, hiddenState ml.Tensor, opts
|
||||
// self attention
|
||||
hiddenState = e.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
|
||||
hiddenState = e.SelfAttention.Forward(ctx, hiddenState, opts)
|
||||
|
||||
if e.AttentionGate != nil {
|
||||
hiddenState = hiddenState.Mul(ctx, e.AttentionGate)
|
||||
}
|
||||
hiddenState = hiddenState.Add(ctx, residual)
|
||||
residual = hiddenState
|
||||
|
||||
// feed forward
|
||||
hiddenState = e.MLPNorm.Forward(ctx, hiddenState, opts.eps)
|
||||
hiddenState = e.MLP.Forward(ctx, hiddenState, opts)
|
||||
hiddenState = hiddenState.Add(ctx, residual)
|
||||
if e.MLPGate != nil {
|
||||
hiddenState = hiddenState.Mul(ctx, e.MLPGate)
|
||||
}
|
||||
|
||||
hiddenState = hiddenState.Add(ctx, residual)
|
||||
return hiddenState
|
||||
}
|
||||
|
||||
|
||||
@@ -100,10 +100,7 @@ type Model struct {
|
||||
|
||||
// Forward implements model.Model.
|
||||
func (m Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||
positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
|
||||
|
||||
hiddenStates := m.TokenEmbedding.Forward(ctx, batch.Inputs)
|
||||
|
||||
@@ -112,10 +109,7 @@ func (m Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||
|
||||
var outputs ml.Tensor
|
||||
if i == len(m.Layers)-1 {
|
||||
outputs, err = ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
outputs = ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
|
||||
}
|
||||
|
||||
hiddenStates = layer.Forward(ctx, hiddenStates, positions, outputs, m.Cache, &m.Options)
|
||||
|
||||
@@ -69,10 +69,7 @@ func (m *Model) PixelValues(ctx ml.Context, multimodalData []byte) (ml.Tensor, *
|
||||
m.ImageProcessor.patchSize * m.ImageProcessor.patchSize
|
||||
numPatches := grid.Temporal * grid.Height * grid.Width
|
||||
|
||||
pixelValues, err := ctx.Input().FromFloatSlice(f32s, patchDim, numPatches)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to create tensor from image: %w", err)
|
||||
}
|
||||
pixelValues := ctx.Input().FromFloatSlice(f32s, patchDim, numPatches)
|
||||
|
||||
return pixelValues, grid, nil
|
||||
}
|
||||
@@ -142,15 +139,8 @@ func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
|
||||
}
|
||||
|
||||
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||
positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
outputs, err := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
|
||||
outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
|
||||
|
||||
return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, batch, m.Cache)
|
||||
}
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package qwen25vl
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math"
|
||||
"slices"
|
||||
|
||||
@@ -44,10 +43,8 @@ func blockDiagonalMask(ctx ml.Context, seqLength int, bounds []int, numHeads int
|
||||
}
|
||||
}
|
||||
|
||||
mask, err := ctx.Input().FromFloatSlice(flat, seqLength, seqLength)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
mask := ctx.Input().FromFloatSlice(flat, seqLength, seqLength)
|
||||
|
||||
// Reshape to match [seqLength, seqLength, 1] for broadcasting
|
||||
mask = mask.Reshape(ctx, seqLength, seqLength, 1)
|
||||
|
||||
@@ -303,10 +300,7 @@ func (m *VisionModel) WindowIndex(ctx ml.Context, grid *Grid) (ml.Tensor, []int)
|
||||
}
|
||||
}
|
||||
|
||||
t, err := ctx.Input().FromIntSlice(index, len(index))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
t := ctx.Input().FromIntSlice(index, len(index))
|
||||
|
||||
return t, bounds
|
||||
}
|
||||
@@ -326,10 +320,7 @@ func (m *VisionModel) PositionalEmbedding(ctx ml.Context, grid *Grid) ml.Tensor
|
||||
freqVals[i*freq+j] = float32(i) / float32(math.Pow(theta, float64(j*2)/float64(dim)))
|
||||
}
|
||||
}
|
||||
freqs, err := ctx.Input().FromFloatSlice(freqVals, freq, maxGridSize)
|
||||
if err != nil {
|
||||
panic(fmt.Errorf("failed to create tensor from frequencies: %w", err))
|
||||
}
|
||||
freqs := ctx.Input().FromFloatSlice(freqVals, freq, maxGridSize)
|
||||
|
||||
// Create position coordinates (y,x pairs) for the grid
|
||||
// In PyTorch: Equivalent to generating position ids with torch.arange()
|
||||
@@ -339,10 +330,7 @@ func (m *VisionModel) PositionalEmbedding(ctx ml.Context, grid *Grid) ml.Tensor
|
||||
coords = append(coords, int32(y), int32(x))
|
||||
}
|
||||
}
|
||||
pos, err := ctx.Input().FromIntSlice(coords, 2, grid.Width, grid.Height)
|
||||
if err != nil {
|
||||
panic(fmt.Errorf("failed to create tensor from positions: %w", err))
|
||||
}
|
||||
pos := ctx.Input().FromIntSlice(coords, 2, grid.Width, grid.Height)
|
||||
|
||||
// Reshape and permute positions to match spatial merging pattern
|
||||
pos = pos.Reshape(ctx, 2, grid.Width, merge, grid.Height/merge)
|
||||
|
||||
@@ -156,10 +156,7 @@ type Model struct {
|
||||
|
||||
// Forward implements model.Model.
|
||||
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||
positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
|
||||
|
||||
hiddenStates := m.TokenEmbedding.Forward(ctx, batch.Inputs)
|
||||
|
||||
@@ -168,10 +165,7 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||
|
||||
var outputs ml.Tensor
|
||||
if i == len(m.Layers)-1 {
|
||||
outputs, err = ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
outputs = ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
|
||||
}
|
||||
|
||||
hiddenStates = layer.Forward(ctx, hiddenStates, positions, outputs, m.Cache, m.Options)
|
||||
|
||||
@@ -61,6 +61,8 @@ const (
|
||||
ColorGrey = Esc + "[38;5;245m"
|
||||
ColorDefault = Esc + "[0m"
|
||||
|
||||
ColorBold = Esc + "[1m"
|
||||
|
||||
StartBracketedPaste = Esc + "[?2004h"
|
||||
EndBracketedPaste = Esc + "[?2004l"
|
||||
)
|
||||
|
||||
@@ -95,17 +95,14 @@ func (m multimodalStore) getTensor(backend ml.Backend, ctx ml.Context, in ml.Ten
|
||||
}
|
||||
}
|
||||
} else {
|
||||
err := computeCtx.Reserve()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
computeCtx.Reserve()
|
||||
}
|
||||
}
|
||||
|
||||
for i, t := range entry.mm {
|
||||
if in == t.Tensor {
|
||||
if !reserve {
|
||||
return ctx.Input().FromFloatSlice(entry.data[i], t.Tensor.Shape()...)
|
||||
return ctx.Input().FromFloatSlice(entry.data[i], t.Tensor.Shape()...), nil
|
||||
} else {
|
||||
return ctx.Input().Empty(t.Tensor.DType(), t.Tensor.Shape()...), nil
|
||||
}
|
||||
|
||||
@@ -808,10 +808,7 @@ func (s *Server) reserveWorstCaseGraph() error {
|
||||
batch.Outputs[i] = int32(i)
|
||||
}
|
||||
|
||||
batch.Inputs, err = ctx.Input().FromIntSlice(batchInputs, len(batchInputs))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
batch.Inputs = ctx.Input().FromIntSlice(batchInputs, len(batchInputs))
|
||||
|
||||
cache := s.model.Config().Cache
|
||||
if cache != nil {
|
||||
@@ -826,16 +823,12 @@ func (s *Server) reserveWorstCaseGraph() error {
|
||||
return err
|
||||
}
|
||||
|
||||
err = ctx.Forward(t).Reserve()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ctx.Forward(t).Reserve()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) loadModel(
|
||||
ctx context.Context,
|
||||
func (s *Server) initModel(
|
||||
mpath string,
|
||||
params ml.BackendParams,
|
||||
lpath multiLPath,
|
||||
@@ -843,21 +836,21 @@ func (s *Server) loadModel(
|
||||
kvCacheType string,
|
||||
kvSize int,
|
||||
multiUserCache bool,
|
||||
) {
|
||||
) error {
|
||||
var err error
|
||||
s.model, err = model.New(mpath, params)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
return err
|
||||
}
|
||||
|
||||
// TODO(jessegross): LoRA loading
|
||||
if lpath.String() != "" {
|
||||
panic("loras are not yet implemented")
|
||||
return errors.New("loras are not yet implemented")
|
||||
}
|
||||
|
||||
s.cache, err = NewInputCache(s.model, kvCacheType, int32(kvSize), parallel, s.batchSize, multiUserCache)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
return err
|
||||
}
|
||||
|
||||
if !s.cache.enabled && parallel > 1 {
|
||||
@@ -869,11 +862,26 @@ func (s *Server) loadModel(
|
||||
s.seqs = make([]*Sequence, s.parallel)
|
||||
s.seqsSem = semaphore.NewWeighted(int64(s.parallel))
|
||||
|
||||
err = s.reserveWorstCaseGraph()
|
||||
return s.reserveWorstCaseGraph()
|
||||
}
|
||||
|
||||
func (s *Server) load(
|
||||
ctx context.Context,
|
||||
mpath string,
|
||||
params ml.BackendParams,
|
||||
lpath multiLPath,
|
||||
parallel int,
|
||||
kvCacheType string,
|
||||
kvSize int,
|
||||
multiUserCache bool,
|
||||
) {
|
||||
err := s.initModel(mpath, params, lpath, parallel, kvCacheType, kvSize, multiUserCache)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
slog.Debug("memory", "allocated", s.model.Backend().BackendMemory())
|
||||
|
||||
err = s.model.Backend().Load(ctx,
|
||||
func(progress float32) {
|
||||
s.progress = progress
|
||||
@@ -921,9 +929,14 @@ func Execute(args []string) error {
|
||||
status: llm.ServerStatusLoadingModel,
|
||||
}
|
||||
|
||||
server.cond = sync.NewCond(&server.mu)
|
||||
server.ready.Add(1)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
// TODO(jessegross): Parameters that need to be implemented:
|
||||
// no-mmap
|
||||
// mlock
|
||||
|
||||
var tensorSplitFloats []float32
|
||||
if *tensorSplit != "" {
|
||||
@@ -943,14 +956,7 @@ func Execute(args []string) error {
|
||||
FlashAttention: *flashAttention,
|
||||
}
|
||||
|
||||
server.ready.Add(1)
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
go server.loadModel(ctx, *mpath, params, lpaths, *parallel, *kvCacheType, *kvSize, *multiUserCache)
|
||||
|
||||
server.cond = sync.NewCond(&server.mu)
|
||||
|
||||
go server.load(ctx, *mpath, params, lpaths, *parallel, *kvCacheType, *kvSize, *multiUserCache)
|
||||
go server.run(ctx)
|
||||
|
||||
addr := "127.0.0.1:" + strconv.Itoa(*port)
|
||||
|
||||
@@ -501,48 +501,27 @@ func ggufLayers(digest string, fn func(resp api.ProgressResponse)) ([]*layerGGML
|
||||
return nil, errOnlyGGUFSupported
|
||||
}
|
||||
|
||||
stat, err := blob.Stat()
|
||||
f, err := ggml.Decode(blob, -1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var offset int64
|
||||
for offset < stat.Size() {
|
||||
f, err := ggml.Decode(blob, -1)
|
||||
if errors.Is(err, io.EOF) {
|
||||
break
|
||||
} else if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
mediatype := "application/vnd.ollama.image.model"
|
||||
if f.KV().Kind() == "adapter" {
|
||||
mediatype = "application/vnd.ollama.image.adapter"
|
||||
} else if _, ok := f.KV()[fmt.Sprintf("%s.vision.block_count", f.KV().Architecture())]; ok || f.KV().Kind() == "projector" {
|
||||
mediatype = "application/vnd.ollama.image.projector"
|
||||
}
|
||||
|
||||
var layer Layer
|
||||
if digest != "" && f.Length == stat.Size() && offset == 0 {
|
||||
layer, err = NewLayerFromLayer(digest, mediatype, blob.Name())
|
||||
if err != nil {
|
||||
slog.Debug("could not create new layer from layer", "error", err)
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback to creating layer from file copy (either NewLayerFromLayer failed, or digest empty/n != stat.Size())
|
||||
if layer.Digest == "" {
|
||||
layer, err = NewLayer(io.NewSectionReader(blob, offset, f.Length), mediatype)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
layers = append(layers, &layerGGML{layer, f})
|
||||
offset = f.Length
|
||||
mediatype := "application/vnd.ollama.image.model"
|
||||
if f.KV().Kind() == "adapter" {
|
||||
mediatype = "application/vnd.ollama.image.adapter"
|
||||
} else if (f.KV().Uint("block_count") == 0 && f.KV().Uint("vision.block_count") > 0) || f.KV().Kind() == "projector" {
|
||||
// if a model has vision.block_count but not block_count, it is a standalone vision model
|
||||
mediatype = "application/vnd.ollama.image.projector"
|
||||
}
|
||||
|
||||
layer, err := NewLayerFromLayer(digest, mediatype, blob.Name())
|
||||
if err != nil {
|
||||
slog.Debug("could not create new layer from layer", "error", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
layers = append(layers, &layerGGML{layer, f})
|
||||
|
||||
return detectChatTemplate(layers)
|
||||
}
|
||||
|
||||
|
||||
@@ -464,6 +464,10 @@ type downloadOpts struct {
|
||||
|
||||
// downloadBlob downloads a blob from the registry and stores it in the blobs directory
|
||||
func downloadBlob(ctx context.Context, opts downloadOpts) (cacheHit bool, _ error) {
|
||||
if opts.digest == "" {
|
||||
return false, fmt.Errorf(("%s: %s"), opts.mp.GetNamespaceRepository(), "digest is is empty")
|
||||
}
|
||||
|
||||
fp, err := GetBlobsPath(opts.digest)
|
||||
if err != nil {
|
||||
return false, err
|
||||
|
||||
@@ -37,6 +37,7 @@ var (
|
||||
errCapabilityInsert = errors.New("insert")
|
||||
errCapabilityVision = errors.New("vision")
|
||||
errCapabilityEmbedding = errors.New("embedding")
|
||||
errCapabilityThinking = errors.New("thinking")
|
||||
errInsecureProtocol = errors.New("insecure protocol http")
|
||||
)
|
||||
|
||||
@@ -111,6 +112,12 @@ func (m *Model) Capabilities() []model.Capability {
|
||||
capabilities = append(capabilities, model.CapabilityVision)
|
||||
}
|
||||
|
||||
// Check for thinking capability
|
||||
openingTag, closingTag := inferThinkingTags(m.Template.Template)
|
||||
if openingTag != "" && closingTag != "" {
|
||||
capabilities = append(capabilities, model.CapabilityThinking)
|
||||
}
|
||||
|
||||
return capabilities
|
||||
}
|
||||
|
||||
@@ -127,6 +134,7 @@ func (m *Model) CheckCapabilities(want ...model.Capability) error {
|
||||
model.CapabilityInsert: errCapabilityInsert,
|
||||
model.CapabilityVision: errCapabilityVision,
|
||||
model.CapabilityEmbedding: errCapabilityEmbedding,
|
||||
model.CapabilityThinking: errCapabilityThinking,
|
||||
}
|
||||
|
||||
for _, cap := range want {
|
||||
@@ -141,11 +149,19 @@ func (m *Model) CheckCapabilities(want ...model.Capability) error {
|
||||
}
|
||||
}
|
||||
|
||||
var err error
|
||||
if len(errs) > 0 {
|
||||
return fmt.Errorf("%w %w", errCapabilities, errors.Join(errs...))
|
||||
err = fmt.Errorf("%w %w", errCapabilities, errors.Join(errs...))
|
||||
}
|
||||
|
||||
return nil
|
||||
if slices.Contains(errs, errCapabilityThinking) {
|
||||
if m.Config.ModelFamily == "qwen3" || model.ParseName(m.Name).Model == "deepseek-r1" {
|
||||
// append a message to the existing error
|
||||
return fmt.Errorf("%w. Pull the model again to get the latest version with full thinking support", err)
|
||||
}
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (m *Model) String() string {
|
||||
|
||||
124
server/model.go
124
server/model.go
@@ -10,9 +10,6 @@ import (
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"os"
|
||||
"slices"
|
||||
"strings"
|
||||
"text/template/parse"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
@@ -128,124 +125,3 @@ func detectContentType(r io.Reader) (string, error) {
|
||||
|
||||
return "unknown", nil
|
||||
}
|
||||
|
||||
func parseObjects(s string) []map[string]any {
|
||||
var objs []map[string]any
|
||||
for offset := 0; offset < len(s); {
|
||||
var obj map[string]any
|
||||
decoder := json.NewDecoder(strings.NewReader(s[offset:]))
|
||||
if err := decoder.Decode(&obj); errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) {
|
||||
break
|
||||
} else if syntax := &(json.SyntaxError{}); errors.As(err, &syntax) {
|
||||
// skip over any syntax errors
|
||||
offset += int(syntax.Offset)
|
||||
} else if unmarshalType := &(json.UnmarshalTypeError{}); errors.As(err, &unmarshalType) {
|
||||
// skip over any unmarshalable types
|
||||
offset += int(unmarshalType.Offset)
|
||||
} else if err != nil {
|
||||
return nil
|
||||
} else {
|
||||
offset += int(decoder.InputOffset())
|
||||
objs = append(objs, obj)
|
||||
}
|
||||
}
|
||||
|
||||
return objs
|
||||
}
|
||||
|
||||
// parseToolCalls attempts to parse a JSON string into a slice of ToolCalls.
|
||||
// mxyng: this only really works if the input contains tool calls in some JSON format
|
||||
func (m *Model) parseToolCalls(s string) ([]api.ToolCall, bool) {
|
||||
// create a subtree from the node that ranges over .ToolCalls
|
||||
tmpl := m.Template.Subtree(func(n parse.Node) bool {
|
||||
if t, ok := n.(*parse.RangeNode); ok {
|
||||
return slices.Contains(template.Identifiers(t.Pipe), "ToolCalls")
|
||||
}
|
||||
|
||||
return false
|
||||
})
|
||||
|
||||
if tmpl == nil {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
var b bytes.Buffer
|
||||
if err := tmpl.Execute(&b, map[string][]api.ToolCall{
|
||||
"ToolCalls": {
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "@@name@@",
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
"@@argument@@": 1,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}); err != nil {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
templateObjects := parseObjects(b.String())
|
||||
if len(templateObjects) == 0 {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// find the keys that correspond to the name and arguments fields
|
||||
var name, arguments string
|
||||
for k, v := range templateObjects[0] {
|
||||
switch v.(type) {
|
||||
case string:
|
||||
name = k
|
||||
case map[string]any:
|
||||
arguments = k
|
||||
}
|
||||
}
|
||||
|
||||
if name == "" || arguments == "" {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
responseObjects := parseObjects(s)
|
||||
if len(responseObjects) == 0 {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// collect all nested objects
|
||||
var collect func(any) []map[string]any
|
||||
collect = func(obj any) (all []map[string]any) {
|
||||
switch o := obj.(type) {
|
||||
case map[string]any:
|
||||
all = append(all, o)
|
||||
for _, v := range o {
|
||||
all = append(all, collect(v)...)
|
||||
}
|
||||
case []any:
|
||||
for _, v := range o {
|
||||
all = append(all, collect(v)...)
|
||||
}
|
||||
}
|
||||
|
||||
return all
|
||||
}
|
||||
|
||||
var objs []map[string]any
|
||||
for _, p := range responseObjects {
|
||||
objs = append(objs, collect(p)...)
|
||||
}
|
||||
|
||||
var toolCalls []api.ToolCall
|
||||
for _, kv := range objs {
|
||||
n, nok := kv[name].(string)
|
||||
a, aok := kv[arguments].(map[string]any)
|
||||
if nok && aok {
|
||||
toolCalls = append(toolCalls, api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: n,
|
||||
Arguments: a,
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return toolCalls, len(toolCalls) > 0
|
||||
}
|
||||
|
||||
@@ -1,179 +0,0 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/template"
|
||||
)
|
||||
|
||||
func readFile(t *testing.T, base, name string) *bytes.Buffer {
|
||||
t.Helper()
|
||||
|
||||
bts, err := os.ReadFile(filepath.Join(base, name))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
return bytes.NewBuffer(bts)
|
||||
}
|
||||
|
||||
func TestExecuteWithTools(t *testing.T) {
|
||||
p := filepath.Join("testdata", "tools")
|
||||
cases := []struct {
|
||||
model string
|
||||
output string
|
||||
ok bool
|
||||
}{
|
||||
{"mistral", `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, true},
|
||||
{"mistral", `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]
|
||||
|
||||
The temperature in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.`, true},
|
||||
{"mistral", `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"To }]`, false},
|
||||
{"mistral", `I'm not aware of that information. However, I can suggest searching for the weather using the "get_current_weather" function:
|
||||
|
||||
[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, true},
|
||||
{"mistral", " The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.", false},
|
||||
{"command-r-plus", "Action: ```json" + `
|
||||
[
|
||||
{
|
||||
"tool_name": "get_current_weather",
|
||||
"parameters": {
|
||||
"format": "fahrenheit",
|
||||
"location": "San Francisco, CA"
|
||||
}
|
||||
},
|
||||
{
|
||||
"tool_name": "get_current_weather",
|
||||
"parameters": {
|
||||
"format": "celsius",
|
||||
"location": "Toronto, Canada"
|
||||
}
|
||||
}
|
||||
]
|
||||
` + "```", true},
|
||||
{"command-r-plus", " The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.", false},
|
||||
{"firefunction", ` functools[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, true},
|
||||
{"firefunction", " The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.", false},
|
||||
{"llama3-groq-tool-use", `<tool_call>
|
||||
{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}
|
||||
{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}
|
||||
</tool_call>`, true},
|
||||
{"xlam", `{"tool_calls": [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]}`, true},
|
||||
{"nemotron", `<toolcall>{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]} </toolcall>`, true},
|
||||
}
|
||||
|
||||
var tools []api.Tool
|
||||
if err := json.Unmarshal(readFile(t, p, "tools.json").Bytes(), &tools); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var messages []api.Message
|
||||
if err := json.Unmarshal(readFile(t, p, "messages.json").Bytes(), &messages); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
calls := []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_current_weather",
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
"format": "fahrenheit",
|
||||
"location": "San Francisco, CA",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_current_weather",
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
"format": "celsius",
|
||||
"location": "Toronto, Canada",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range cases {
|
||||
t.Run(tt.model, func(t *testing.T) {
|
||||
tmpl, err := template.Parse(readFile(t, p, fmt.Sprintf("%s.gotmpl", tt.model)).String())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
t.Run("template", func(t *testing.T) {
|
||||
var actual bytes.Buffer
|
||||
if err := tmpl.Execute(&actual, template.Values{Tools: tools, Messages: messages}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(actual.String(), readFile(t, p, fmt.Sprintf("%s.out", tt.model)).String()); diff != "" {
|
||||
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("parse", func(t *testing.T) {
|
||||
m := &Model{Template: tmpl}
|
||||
actual, ok := m.parseToolCalls(tt.output)
|
||||
if ok != tt.ok {
|
||||
t.Fatalf("expected %t, got %t", tt.ok, ok)
|
||||
}
|
||||
|
||||
if tt.ok {
|
||||
if diff := cmp.Diff(actual, calls); diff != "" {
|
||||
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseObjects(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
want []map[string]any
|
||||
}{
|
||||
{
|
||||
input: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
|
||||
want: []map[string]any{
|
||||
{"name": "get_current_weather", "arguments": map[string]any{"format": "fahrenheit", "location": "San Francisco, CA"}},
|
||||
{"name": "get_current_weather", "arguments": map[string]any{"format": "celsius", "location": "Toronto, Canada"}},
|
||||
},
|
||||
},
|
||||
{
|
||||
input: `<toolcall>{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </toolcall>`,
|
||||
want: []map[string]any{
|
||||
{"name": "get_current_weather", "arguments": map[string]any{"format": "fahrenheit", "location": "San Francisco, CA"}},
|
||||
},
|
||||
},
|
||||
{
|
||||
input: `<toolcall>{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </toolcall> <toolcall>{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, ON"}} </toolcall>`,
|
||||
want: []map[string]any{
|
||||
{"name": "get_current_weather", "arguments": map[string]any{"format": "fahrenheit", "location": "San Francisco, CA"}},
|
||||
{"name": "get_current_weather", "arguments": map[string]any{"format": "celsius", "location": "Toronto, ON"}},
|
||||
},
|
||||
},
|
||||
{
|
||||
input: `{"name": "get_current_weather", "arguments": `,
|
||||
want: nil,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.input, func(t *testing.T) {
|
||||
got := parseObjects(tc.input)
|
||||
|
||||
if diff := cmp.Diff(got, tc.want); diff != "" {
|
||||
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -116,7 +116,7 @@ func (mp ModelPath) BaseURL() *url.URL {
|
||||
func GetManifestPath() (string, error) {
|
||||
path := filepath.Join(envconfig.Models(), "manifests")
|
||||
if err := os.MkdirAll(path, 0o755); err != nil {
|
||||
return "", err
|
||||
return "", fmt.Errorf("%w: ensure path elements are traversable", err)
|
||||
}
|
||||
|
||||
return path, nil
|
||||
@@ -139,7 +139,7 @@ func GetBlobsPath(digest string) (string, error) {
|
||||
}
|
||||
|
||||
if err := os.MkdirAll(dirPath, 0o755); err != nil {
|
||||
return "", err
|
||||
return "", fmt.Errorf("%w: ensure path elements are traversable", err)
|
||||
}
|
||||
|
||||
return path, nil
|
||||
|
||||
@@ -19,7 +19,7 @@ type tokenizeFunc func(context.Context, string) ([]int, error)
|
||||
// chatPrompt accepts a list of messages and returns the prompt and images that should be used for the next chat turn.
|
||||
// chatPrompt truncates any messages that exceed the context window of the model, making sure to always include 1) the
|
||||
// latest message and 2) system messages
|
||||
func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.Options, msgs []api.Message, tools []api.Tool) (prompt string, images []llm.ImageData, _ error) {
|
||||
func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.Options, msgs []api.Message, tools []api.Tool, think *bool) (prompt string, images []llm.ImageData, _ error) {
|
||||
var system []api.Message
|
||||
|
||||
// TODO: Ideally we would compute this from the projector metadata but some pieces are implementation dependent
|
||||
@@ -41,8 +41,12 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
|
||||
}
|
||||
}
|
||||
|
||||
thinkVal := false
|
||||
if think != nil {
|
||||
thinkVal = *think
|
||||
}
|
||||
var b bytes.Buffer
|
||||
if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[i:]...), Tools: tools}); err != nil {
|
||||
if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[i:]...), Tools: tools, Think: thinkVal, IsThinkSet: think != nil}); err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
@@ -96,7 +100,11 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
|
||||
|
||||
// truncate any messages that do not fit into the context window
|
||||
var b bytes.Buffer
|
||||
if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[currMsgIdx:]...), Tools: tools}); err != nil {
|
||||
thinkVal := false
|
||||
if think != nil {
|
||||
thinkVal = *think
|
||||
}
|
||||
if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[currMsgIdx:]...), Tools: tools, Think: thinkVal, IsThinkSet: think != nil}); err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
|
||||
@@ -208,7 +208,8 @@ func TestChatPrompt(t *testing.T) {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
model := tt.model
|
||||
opts := api.Options{Runner: api.Runner{NumCtx: tt.limit}}
|
||||
prompt, images, err := chatPrompt(t.Context(), &model, mockRunner{}.Tokenize, &opts, tt.msgs, nil)
|
||||
think := false
|
||||
prompt, images, err := chatPrompt(t.Context(), &model, mockRunner{}.Tokenize, &opts, tt.msgs, nil, &think)
|
||||
if tt.error == nil && err != nil {
|
||||
t.Fatal(err)
|
||||
} else if tt.error != nil && err != tt.error {
|
||||
|
||||
@@ -120,14 +120,30 @@ func getTensorNewType(kv fsggml.KV, qs *quantizeState, newType fsggml.TensorType
|
||||
|
||||
if newType.IsQuantized() {
|
||||
nx := shape[0]
|
||||
ny := uint64(1)
|
||||
if len(shape) > 1 {
|
||||
ny = shape[1]
|
||||
}
|
||||
qk_k := newType.BlockSize()
|
||||
|
||||
// Check if first dimension is divisible by block size
|
||||
if nx%qk_k != 0 {
|
||||
slog.Warn(fmt.Sprintf("tensor cols %d x %d are not divisible by %d, required for %s. Falling back to quantization %s", nx, ny, qk_k, newType.String(), fsggml.TensorTypeF16.String()))
|
||||
newType = fsggml.TensorTypeF16
|
||||
// Store the original type for logging
|
||||
originalType := newType
|
||||
|
||||
// Select appropriate fallback based on original type
|
||||
switch newType {
|
||||
case fsggml.TensorTypeQ4_K:
|
||||
newType = fsggml.TensorTypeQ5_0
|
||||
case fsggml.TensorTypeQ5_K:
|
||||
newType = fsggml.TensorTypeQ5_1
|
||||
case fsggml.TensorTypeQ6_K:
|
||||
newType = fsggml.TensorTypeQ8_0
|
||||
}
|
||||
|
||||
// Final check - if still incompatible, fall back to F16
|
||||
if nx%newType.BlockSize() != 0 {
|
||||
newType = fsggml.TensorTypeF16
|
||||
}
|
||||
|
||||
slog.Warn(fmt.Sprintf("tensor cols %d are not divisible by %d, required for %s - using fallback quantization %s",
|
||||
nx, qk_k, originalType.String(), newType.String()))
|
||||
}
|
||||
}
|
||||
return newType
|
||||
|
||||
153
server/routes.go
153
server/routes.go
@@ -17,7 +17,6 @@ import (
|
||||
"net/netip"
|
||||
"os"
|
||||
"os/signal"
|
||||
"regexp"
|
||||
"slices"
|
||||
"strings"
|
||||
"syscall"
|
||||
@@ -38,6 +37,7 @@ import (
|
||||
"github.com/ollama/ollama/server/internal/client/ollama"
|
||||
"github.com/ollama/ollama/server/internal/registry"
|
||||
"github.com/ollama/ollama/template"
|
||||
"github.com/ollama/ollama/tools"
|
||||
"github.com/ollama/ollama/types/errtypes"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
"github.com/ollama/ollama/version"
|
||||
@@ -185,6 +185,13 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
||||
if req.Suffix != "" {
|
||||
caps = append(caps, model.CapabilityInsert)
|
||||
}
|
||||
if req.Think != nil && *req.Think {
|
||||
caps = append(caps, model.CapabilityThinking)
|
||||
// TODO(drifkin): consider adding a warning if it's false and the model
|
||||
// doesn't support thinking. It's not strictly required, but it can be a
|
||||
// hint that the user is on an older qwen3/r1 model that doesn't have an
|
||||
// updated template supporting thinking
|
||||
}
|
||||
|
||||
r, m, opts, err := s.scheduleRunner(c.Request.Context(), name.String(), caps, req.Options, req.KeepAlive)
|
||||
if errors.Is(err, errCapabilityCompletion) {
|
||||
@@ -253,6 +260,9 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
||||
values.Messages = append(msgs, api.Message{Role: "user", Content: req.Prompt})
|
||||
}
|
||||
|
||||
values.Think = req.Think != nil && *req.Think
|
||||
values.IsThinkSet = req.Think != nil
|
||||
|
||||
var b bytes.Buffer
|
||||
if req.Context != nil {
|
||||
slog.Warn("the context field is deprecated and will be removed in a future version of Ollama")
|
||||
@@ -272,6 +282,15 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
||||
prompt = b.String()
|
||||
}
|
||||
|
||||
var thinkingState *thinkingParser
|
||||
openingTag, closingTag := inferThinkingTags(m.Template.Template)
|
||||
if req.Think != nil && *req.Think && openingTag != "" && closingTag != "" {
|
||||
thinkingState = &thinkingParser{
|
||||
openingTag: openingTag,
|
||||
closingTag: closingTag,
|
||||
}
|
||||
}
|
||||
|
||||
ch := make(chan any)
|
||||
go func() {
|
||||
// TODO (jmorganca): avoid building the response twice both here and below
|
||||
@@ -296,6 +315,12 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
||||
},
|
||||
}
|
||||
|
||||
if thinkingState != nil {
|
||||
thinking, content := thinkingState.addContent(cr.Content)
|
||||
res.Thinking = thinking
|
||||
res.Response = content
|
||||
}
|
||||
|
||||
if _, err := sb.WriteString(cr.Content); err != nil {
|
||||
ch <- gin.H{"error": err.Error()}
|
||||
}
|
||||
@@ -323,11 +348,13 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
||||
|
||||
if req.Stream != nil && !*req.Stream {
|
||||
var r api.GenerateResponse
|
||||
var sb strings.Builder
|
||||
var sbThinking strings.Builder
|
||||
var sbContent strings.Builder
|
||||
for rr := range ch {
|
||||
switch t := rr.(type) {
|
||||
case api.GenerateResponse:
|
||||
sb.WriteString(t.Response)
|
||||
sbThinking.WriteString(t.Thinking)
|
||||
sbContent.WriteString(t.Response)
|
||||
r = t
|
||||
case gin.H:
|
||||
msg, ok := t["error"].(string)
|
||||
@@ -343,7 +370,9 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
r.Response = sb.String()
|
||||
r.Thinking = sbThinking.String()
|
||||
r.Response = sbContent.String()
|
||||
|
||||
c.JSON(http.StatusOK, r)
|
||||
return
|
||||
}
|
||||
@@ -1435,6 +1464,9 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
||||
if len(req.Tools) > 0 {
|
||||
caps = append(caps, model.CapabilityTools)
|
||||
}
|
||||
if req.Think != nil && *req.Think {
|
||||
caps = append(caps, model.CapabilityThinking)
|
||||
}
|
||||
|
||||
name := model.ParseName(req.Model)
|
||||
if !name.IsValid() {
|
||||
@@ -1475,18 +1507,36 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
||||
}
|
||||
msgs = filterThinkTags(msgs, m)
|
||||
|
||||
prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, msgs, req.Tools)
|
||||
prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, msgs, req.Tools, req.Think)
|
||||
if err != nil {
|
||||
slog.Error("chat prompt error", "error", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
var thinkingState *thinkingParser
|
||||
openingTag, closingTag := inferThinkingTags(m.Template.Template)
|
||||
if req.Think != nil && *req.Think && openingTag != "" && closingTag != "" {
|
||||
thinkingState = &thinkingParser{
|
||||
openingTag: openingTag,
|
||||
closingTag: closingTag,
|
||||
}
|
||||
}
|
||||
|
||||
var toolParser *tools.Parser
|
||||
if len(req.Tools) > 0 {
|
||||
toolParser, err = tools.NewParser(m.Template.Template)
|
||||
if err != nil {
|
||||
slog.Error("failed to create tool parser", "error", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
ch := make(chan any)
|
||||
go func() {
|
||||
defer close(ch)
|
||||
var sb strings.Builder
|
||||
var toolCallIndex int = 0
|
||||
|
||||
if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
|
||||
Prompt: prompt,
|
||||
Images: images,
|
||||
@@ -1506,43 +1556,40 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
||||
},
|
||||
}
|
||||
|
||||
if thinkingState != nil {
|
||||
thinkingContent, remainingContent := thinkingState.addContent(res.Message.Content)
|
||||
if thinkingContent == "" && remainingContent == "" && !r.Done {
|
||||
// need to accumulate more to decide what to send
|
||||
return
|
||||
}
|
||||
res.Message.Content = remainingContent
|
||||
res.Message.Thinking = thinkingContent
|
||||
}
|
||||
|
||||
if r.Done {
|
||||
res.DoneReason = r.DoneReason.String()
|
||||
res.TotalDuration = time.Since(checkpointStart)
|
||||
res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
|
||||
}
|
||||
|
||||
// TODO: tool call checking and filtering should be moved outside of this callback once streaming
|
||||
// however this was a simple change for now without reworking streaming logic of this (and other)
|
||||
// handlers
|
||||
if req.Stream != nil && !*req.Stream || len(req.Tools) == 0 {
|
||||
ch <- res
|
||||
return
|
||||
if len(req.Tools) > 0 {
|
||||
toolCalls, content := toolParser.Add(res.Message.Content)
|
||||
if len(content) > 0 {
|
||||
res.Message.Content = content
|
||||
} else if len(toolCalls) > 0 {
|
||||
res.Message.ToolCalls = toolCalls
|
||||
res.Message.Content = ""
|
||||
} else if res.Message.Thinking != "" {
|
||||
// don't return
|
||||
} else {
|
||||
if r.Done {
|
||||
ch <- res
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Streaming tool calls:
|
||||
// If tools are recognized, use a flag to track the sending of a tool downstream
|
||||
// This ensures that content is cleared from the message on the last chunk sent
|
||||
sb.WriteString(r.Content)
|
||||
if toolCalls, ok := m.parseToolCalls(sb.String()); ok {
|
||||
res.Message.ToolCalls = toolCalls
|
||||
for i := range toolCalls {
|
||||
toolCalls[i].Function.Index = toolCallIndex
|
||||
toolCallIndex++
|
||||
}
|
||||
res.Message.Content = ""
|
||||
sb.Reset()
|
||||
ch <- res
|
||||
return
|
||||
}
|
||||
|
||||
if r.Done {
|
||||
// Send any remaining content if no tool calls were detected
|
||||
if toolCallIndex == 0 {
|
||||
res.Message.Content = sb.String()
|
||||
}
|
||||
ch <- res
|
||||
}
|
||||
ch <- res
|
||||
}); err != nil {
|
||||
ch <- gin.H{"error": err.Error()}
|
||||
}
|
||||
@@ -1550,12 +1597,18 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
||||
|
||||
if req.Stream != nil && !*req.Stream {
|
||||
var resp api.ChatResponse
|
||||
var sb strings.Builder
|
||||
var toolCalls []api.ToolCall
|
||||
var sbThinking strings.Builder
|
||||
var sbContent strings.Builder
|
||||
for rr := range ch {
|
||||
switch t := rr.(type) {
|
||||
case api.ChatResponse:
|
||||
sb.WriteString(t.Message.Content)
|
||||
sbThinking.WriteString(t.Message.Thinking)
|
||||
sbContent.WriteString(t.Message.Content)
|
||||
resp = t
|
||||
if len(req.Tools) > 0 {
|
||||
toolCalls = append(toolCalls, t.Message.ToolCalls...)
|
||||
}
|
||||
case gin.H:
|
||||
msg, ok := t["error"].(string)
|
||||
if !ok {
|
||||
@@ -1570,13 +1623,11 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
resp.Message.Content = sb.String()
|
||||
resp.Message.Content = sbContent.String()
|
||||
resp.Message.Thinking = sbThinking.String()
|
||||
|
||||
if len(req.Tools) > 0 {
|
||||
if toolCalls, ok := m.parseToolCalls(sb.String()); ok {
|
||||
resp.Message.ToolCalls = toolCalls
|
||||
resp.Message.Content = ""
|
||||
}
|
||||
if len(toolCalls) > 0 {
|
||||
resp.Message.ToolCalls = toolCalls
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, resp)
|
||||
@@ -1601,8 +1652,6 @@ func handleScheduleError(c *gin.Context, name string, err error) {
|
||||
}
|
||||
}
|
||||
|
||||
var thinkTagRegexp = regexp.MustCompile(`<think>(?s).*?</think>(\n)*`)
|
||||
|
||||
func filterThinkTags(msgs []api.Message, m *Model) []api.Message {
|
||||
if m.Config.ModelFamily == "qwen3" || model.ParseName(m.Name).Model == "deepseek-r1" {
|
||||
finalUserIndex := -1
|
||||
@@ -1614,7 +1663,17 @@ func filterThinkTags(msgs []api.Message, m *Model) []api.Message {
|
||||
|
||||
for i, msg := range msgs {
|
||||
if msg.Role == "assistant" && i < finalUserIndex {
|
||||
msgs[i].Content = thinkTagRegexp.ReplaceAllString(msg.Content, "")
|
||||
// TODO(drifkin): this is from before we added proper thinking support.
|
||||
// However, even if thinking is not enabled (and therefore we shouldn't
|
||||
// change the user output), we should probably perform this filtering
|
||||
// for all thinking models (not just qwen3 & deepseek-r1) since it tends
|
||||
// to save tokens and improve quality.
|
||||
thinkingState := &thinkingParser{
|
||||
openingTag: "<think>",
|
||||
closingTag: "</think>",
|
||||
}
|
||||
_, content := thinkingState.addContent(msg.Content)
|
||||
msgs[i].Content = content
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -143,6 +143,25 @@ func TestGenerateChat(t *testing.T) {
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("missing thinking capability", func(t *testing.T) {
|
||||
think := true
|
||||
w := createRequest(t, s.ChatHandler, api.ChatRequest{
|
||||
Model: "test",
|
||||
Messages: []api.Message{
|
||||
{Role: "user", Content: "Hello!"},
|
||||
},
|
||||
Think: &think,
|
||||
})
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("expected status 400, got %d", w.Code)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(w.Body.String(), `{"error":"registry.ollama.ai/library/test:latest does not support thinking"}`); diff != "" {
|
||||
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("missing model", func(t *testing.T) {
|
||||
w := createRequest(t, s.ChatHandler, api.ChatRequest{})
|
||||
if w.Code != http.StatusBadRequest {
|
||||
|
||||
@@ -387,6 +387,17 @@ func (s *Scheduler) processCompleted(ctx context.Context) {
|
||||
s.loadedMu.Unlock()
|
||||
runner.refMu.Unlock()
|
||||
slog.Debug("duplicate expired event, ignoring", "runner", runner)
|
||||
} else if runner.pid != runnerToUnload.pid {
|
||||
// If the pids do not match, we likely had multiple load
|
||||
// failures for the same model in quick succession due to
|
||||
// request context canceled and are draining the queue of
|
||||
// events. Ensure the orphaned runner is properly shut down, but
|
||||
// do not delete the mismatched loaded runner, or wait for VRAM
|
||||
// convergence.
|
||||
slog.Debug("orphaned runner shutting down", "orphan", runner, "loaded", runnerToUnload)
|
||||
runner.unload()
|
||||
s.loadedMu.Unlock()
|
||||
runner.refMu.Unlock()
|
||||
} else {
|
||||
slog.Debug("starting background wait for VRAM recovery", "runner", runner)
|
||||
finished := runner.waitForVRAMRecovery()
|
||||
|
||||
300
server/thinking.go
Normal file
300
server/thinking.go
Normal file
@@ -0,0 +1,300 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"text/template"
|
||||
"text/template/parse"
|
||||
"unicode"
|
||||
)
|
||||
|
||||
type thinkingState int
|
||||
|
||||
const (
|
||||
// We're looking for the opening tag, but we haven't seen any non-whitespace
|
||||
// characters yet
|
||||
thinkingState_LookingForOpening thinkingState = iota
|
||||
// We've seen the opening tag, but we haven't seen any non-whitespace
|
||||
// characters yet (we want to eat any whitespace between the opening tag and
|
||||
// the thinking content)
|
||||
thinkingState_ThinkingStartedEatingWhitespace
|
||||
// We've seen non-whitespace characters after the opening tag, but we haven't
|
||||
// seen the closing tag yet
|
||||
thinkingState_Thinking
|
||||
// We've seen the closing tag, but we haven't seen any non-whitespace
|
||||
// characters after the closing tag yet (we want to eat any whitespace between
|
||||
// the closing tag and the content)
|
||||
thinkingState_ThinkingDoneEatingWhitespace
|
||||
// We've seen the closing tag and seen at least one non-whitespace character
|
||||
// after it
|
||||
thinkingState_ThinkingDone
|
||||
)
|
||||
|
||||
func (s thinkingState) String() string {
|
||||
switch s {
|
||||
case thinkingState_LookingForOpening:
|
||||
return "LookingForOpening"
|
||||
case thinkingState_ThinkingStartedEatingWhitespace:
|
||||
return "ThinkingStartedEatingWhitespace"
|
||||
case thinkingState_Thinking:
|
||||
return "Thinking"
|
||||
case thinkingState_ThinkingDoneEatingWhitespace:
|
||||
return "ThinkingDoneEatingWhitespace"
|
||||
case thinkingState_ThinkingDone:
|
||||
return "ThinkingDone"
|
||||
default:
|
||||
return "Unknown"
|
||||
}
|
||||
}
|
||||
|
||||
type thinkingParser struct {
|
||||
state thinkingState
|
||||
openingTag string
|
||||
closingTag string
|
||||
acc strings.Builder
|
||||
}
|
||||
|
||||
// addContent returns the thinking content and the non-thinking content that
|
||||
// should be immediately sent to the user. It will internally buffer if it needs
|
||||
// to see more raw content to disambiguate
|
||||
func (s *thinkingParser) addContent(content string) (string, string) {
|
||||
s.acc.WriteString(content)
|
||||
|
||||
var thinkingSb, remainingSb strings.Builder
|
||||
|
||||
var thinking, remaining string
|
||||
keepLooping := true
|
||||
// we loop because we might pass through multiple parsing states in a single
|
||||
// call to addContent, and we want to make sure callers don't have to wait for
|
||||
// data that's already unambiguous
|
||||
for keepLooping {
|
||||
thinking, remaining, keepLooping = eat(s)
|
||||
thinkingSb.WriteString(thinking)
|
||||
remainingSb.WriteString(remaining)
|
||||
}
|
||||
|
||||
return thinkingSb.String(), remainingSb.String()
|
||||
}
|
||||
|
||||
// the additional bool return is true iff we should continue eating
|
||||
func eat(s *thinkingParser) (string, string, bool) {
|
||||
switch s.state {
|
||||
case thinkingState_LookingForOpening:
|
||||
trimmed := strings.TrimLeftFunc(s.acc.String(), unicode.IsSpace)
|
||||
if strings.HasPrefix(trimmed, s.openingTag) {
|
||||
after := strings.Join(strings.Split(trimmed, s.openingTag)[1:], s.openingTag)
|
||||
after = strings.TrimLeftFunc(after, unicode.IsSpace)
|
||||
// after might contain more than just thinking tokens, so we continue
|
||||
// parsing instead of returning it as thinking tokens here
|
||||
s.acc.Reset()
|
||||
s.acc.WriteString(after)
|
||||
if after == "" {
|
||||
s.state = thinkingState_ThinkingStartedEatingWhitespace
|
||||
} else {
|
||||
s.state = thinkingState_Thinking
|
||||
}
|
||||
return "", "", true
|
||||
} else if strings.HasPrefix(s.openingTag, trimmed) {
|
||||
// partial opening seen, so let's keep accumulating
|
||||
return "", "", false
|
||||
} else if trimmed == "" {
|
||||
// saw whitespace only, so let's keep accumulating
|
||||
return "", "", false
|
||||
} else {
|
||||
// didn't see an opening tag, but we have content, so thinking was skipped
|
||||
s.state = thinkingState_ThinkingDone
|
||||
// note that we use the original content, not the trimmed one because we
|
||||
// don't want to eat any whitespace in the real content if there were no
|
||||
// thinking tags
|
||||
return "", s.acc.String(), false
|
||||
}
|
||||
case thinkingState_ThinkingStartedEatingWhitespace:
|
||||
trimmed := strings.TrimLeftFunc(s.acc.String(), unicode.IsSpace)
|
||||
s.acc.Reset()
|
||||
if trimmed == "" {
|
||||
return "", "", false
|
||||
} else {
|
||||
s.state = thinkingState_Thinking
|
||||
s.acc.WriteString(trimmed)
|
||||
return "", "", true
|
||||
}
|
||||
case thinkingState_Thinking:
|
||||
acc := s.acc.String()
|
||||
if strings.Contains(acc, s.closingTag) {
|
||||
split := strings.Split(acc, s.closingTag)
|
||||
thinking := split[0]
|
||||
remaining := strings.Join(split[1:], s.closingTag)
|
||||
remaining = strings.TrimLeftFunc(remaining, unicode.IsSpace)
|
||||
s.acc.Reset()
|
||||
if remaining == "" {
|
||||
s.state = thinkingState_ThinkingDoneEatingWhitespace
|
||||
} else {
|
||||
s.state = thinkingState_ThinkingDone
|
||||
}
|
||||
return thinking, remaining, false
|
||||
} else if overlapLen := overlap(acc, s.closingTag); overlapLen > 0 {
|
||||
thinking := acc[:len(acc)-overlapLen]
|
||||
remaining := acc[len(acc)-overlapLen:]
|
||||
s.acc.Reset()
|
||||
// keep track of the candidate closing tag. We have to buffer it until it
|
||||
// becomes disambiguated
|
||||
s.acc.WriteString(remaining)
|
||||
return thinking, "", false
|
||||
} else {
|
||||
// purely just thinking tokens, so we can return them
|
||||
s.acc.Reset()
|
||||
return acc, "", false
|
||||
}
|
||||
case thinkingState_ThinkingDoneEatingWhitespace:
|
||||
trimmed := strings.TrimLeftFunc(s.acc.String(), unicode.IsSpace)
|
||||
s.acc.Reset()
|
||||
// if we see non-whitespace, we're done eating the leading whitespace of the content
|
||||
if trimmed != "" {
|
||||
s.state = thinkingState_ThinkingDone
|
||||
}
|
||||
return "", trimmed, false
|
||||
case thinkingState_ThinkingDone:
|
||||
acc := s.acc.String()
|
||||
s.acc.Reset()
|
||||
return "", acc, false
|
||||
default:
|
||||
panic("unknown state")
|
||||
}
|
||||
}
|
||||
|
||||
// longest overlap between suffix of s and prefix of delim
|
||||
func overlap(s, delim string) int {
|
||||
max := min(len(delim), len(s))
|
||||
for i := max; i > 0; i-- {
|
||||
if strings.HasSuffix(s, delim[:i]) {
|
||||
return i
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func templateVisit(n parse.Node, enterFn func(parse.Node) bool, exitFn func(parse.Node)) {
|
||||
if n == nil {
|
||||
return
|
||||
}
|
||||
shouldContinue := enterFn(n)
|
||||
if !shouldContinue {
|
||||
return
|
||||
}
|
||||
switch x := n.(type) {
|
||||
case *parse.ListNode:
|
||||
for _, c := range x.Nodes {
|
||||
templateVisit(c, enterFn, exitFn)
|
||||
}
|
||||
case *parse.BranchNode:
|
||||
if x.Pipe != nil {
|
||||
templateVisit(x.Pipe, enterFn, exitFn)
|
||||
}
|
||||
if x.List != nil {
|
||||
templateVisit(x.List, enterFn, exitFn)
|
||||
}
|
||||
if x.ElseList != nil {
|
||||
templateVisit(x.ElseList, enterFn, exitFn)
|
||||
}
|
||||
case *parse.ActionNode:
|
||||
templateVisit(x.Pipe, enterFn, exitFn)
|
||||
case *parse.WithNode:
|
||||
templateVisit(&x.BranchNode, enterFn, exitFn)
|
||||
case *parse.RangeNode:
|
||||
templateVisit(&x.BranchNode, enterFn, exitFn)
|
||||
case *parse.IfNode:
|
||||
templateVisit(&x.BranchNode, enterFn, exitFn)
|
||||
case *parse.TemplateNode:
|
||||
templateVisit(x.Pipe, enterFn, exitFn)
|
||||
case *parse.PipeNode:
|
||||
for _, c := range x.Cmds {
|
||||
templateVisit(c, enterFn, exitFn)
|
||||
}
|
||||
case *parse.CommandNode:
|
||||
for _, a := range x.Args {
|
||||
templateVisit(a, enterFn, exitFn)
|
||||
}
|
||||
// text, field, number, etc. are leaves – nothing to recurse into
|
||||
}
|
||||
if exitFn != nil {
|
||||
exitFn(n)
|
||||
}
|
||||
}
|
||||
|
||||
// We use a heuristic to infer the tags that surround thinking traces:
|
||||
// We look for a range node that iterates over "Messages" and then look for a
|
||||
// reference to "Thinking" like `{{.Thinking}}`. We then go up to the nearest
|
||||
// ListNode and take the first and last TextNodes as the opening and closing
|
||||
// tags.
|
||||
func inferThinkingTags(t *template.Template) (string, string) {
|
||||
ancestors := []parse.Node{}
|
||||
|
||||
openingTag := ""
|
||||
closingTag := ""
|
||||
|
||||
enterFn := func(n parse.Node) bool {
|
||||
ancestors = append(ancestors, n)
|
||||
|
||||
switch x := n.(type) {
|
||||
case *parse.FieldNode:
|
||||
if len(x.Ident) > 0 && x.Ident[0] == "Thinking" {
|
||||
var mostRecentRange *parse.RangeNode
|
||||
for i := len(ancestors) - 1; i >= 0; i-- {
|
||||
if r, ok := ancestors[i].(*parse.RangeNode); ok {
|
||||
mostRecentRange = r
|
||||
break
|
||||
}
|
||||
}
|
||||
if mostRecentRange == nil || !rangeUsesField(mostRecentRange, "Messages") {
|
||||
return true
|
||||
}
|
||||
|
||||
// TODO(drifkin): to be more robust, check that it's in the action
|
||||
// part, not the `if`'s pipeline part. We do match on the nearest list
|
||||
// that starts and ends with text nodes, which makes this not strictly
|
||||
// necessary for our heuristic
|
||||
|
||||
// go up to the nearest ancestor that is a *parse.ListNode
|
||||
for i := len(ancestors) - 1; i >= 0; i-- {
|
||||
if l, ok := ancestors[i].(*parse.ListNode); ok {
|
||||
firstNode := l.Nodes[0]
|
||||
if t, ok := firstNode.(*parse.TextNode); ok {
|
||||
openingTag = strings.TrimSpace(t.String())
|
||||
}
|
||||
lastNode := l.Nodes[len(l.Nodes)-1]
|
||||
if t, ok := lastNode.(*parse.TextNode); ok {
|
||||
closingTag = strings.TrimSpace(t.String())
|
||||
}
|
||||
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
exitFn := func(n parse.Node) {
|
||||
ancestors = ancestors[:len(ancestors)-1]
|
||||
}
|
||||
|
||||
templateVisit(t.Root, enterFn, exitFn)
|
||||
|
||||
return openingTag, closingTag
|
||||
}
|
||||
|
||||
// checks to see if the given field name is present in the pipeline of the given range node
|
||||
func rangeUsesField(rangeNode *parse.RangeNode, field string) bool {
|
||||
found := false
|
||||
enterFn := func(n parse.Node) bool {
|
||||
switch x := n.(type) {
|
||||
case *parse.FieldNode:
|
||||
if x.Ident[0] == field {
|
||||
found = true
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
templateVisit(rangeNode.BranchNode.Pipe, enterFn, nil)
|
||||
return found
|
||||
}
|
||||
403
server/thinking_test.go
Normal file
403
server/thinking_test.go
Normal file
@@ -0,0 +1,403 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"text/template"
|
||||
)
|
||||
|
||||
func TestExtractThinking(t *testing.T) {
|
||||
tests := []struct {
|
||||
in, wantContent, wantThink string
|
||||
}{
|
||||
{
|
||||
in: "<think> internal </think> world",
|
||||
wantThink: "internal ",
|
||||
wantContent: "world",
|
||||
},
|
||||
{
|
||||
in: "<think>a</think><think>b</think>c",
|
||||
wantThink: "a",
|
||||
wantContent: "<think>b</think>c",
|
||||
},
|
||||
{
|
||||
in: "no think",
|
||||
wantThink: "",
|
||||
wantContent: "no think",
|
||||
},
|
||||
}
|
||||
for i, tt := range tests {
|
||||
parser := thinkingParser{
|
||||
openingTag: "<think>",
|
||||
closingTag: "</think>",
|
||||
}
|
||||
gotThinking, gotContent := parser.addContent(tt.in)
|
||||
if gotContent != tt.wantContent || gotThinking != tt.wantThink {
|
||||
t.Errorf("case %d: got (%q,%q), want (%q,%q)", i, gotThinking, gotContent, tt.wantThink, tt.wantContent)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestThinkingStreaming(t *testing.T) {
|
||||
type step struct {
|
||||
input string
|
||||
wantThinking string
|
||||
wantContent string
|
||||
wantStateAfter thinkingState
|
||||
}
|
||||
|
||||
cases := []struct {
|
||||
desc string
|
||||
skip bool
|
||||
steps []step
|
||||
}{
|
||||
{
|
||||
desc: "content without a thinking tag",
|
||||
steps: []step{
|
||||
{
|
||||
input: " abc",
|
||||
wantThinking: "",
|
||||
wantContent: " abc",
|
||||
wantStateAfter: thinkingState_ThinkingDone,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "content before a thinking tag nerfs the thinking tag",
|
||||
steps: []step{
|
||||
{
|
||||
input: " abc <think>def</think> ghi",
|
||||
wantThinking: "",
|
||||
wantContent: " abc <think>def</think> ghi",
|
||||
wantStateAfter: thinkingState_ThinkingDone,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "building up a thinking tag partially",
|
||||
steps: []step{
|
||||
{
|
||||
input: " <th",
|
||||
wantThinking: "",
|
||||
wantContent: "",
|
||||
wantStateAfter: thinkingState_LookingForOpening,
|
||||
},
|
||||
{
|
||||
input: "in",
|
||||
wantThinking: "",
|
||||
wantContent: "",
|
||||
wantStateAfter: thinkingState_LookingForOpening,
|
||||
},
|
||||
{
|
||||
input: "k>a",
|
||||
wantThinking: "a",
|
||||
wantContent: "",
|
||||
wantStateAfter: thinkingState_Thinking,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "partial closing tag",
|
||||
steps: []step{
|
||||
{
|
||||
input: "<think>abc</th",
|
||||
wantThinking: "abc",
|
||||
wantContent: "",
|
||||
wantStateAfter: thinkingState_Thinking,
|
||||
},
|
||||
{
|
||||
input: "ink>def",
|
||||
wantThinking: "",
|
||||
wantContent: "def",
|
||||
wantStateAfter: thinkingState_ThinkingDone,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "partial closing tag fakeout",
|
||||
steps: []step{
|
||||
{
|
||||
input: "<think>abc</th",
|
||||
wantThinking: "abc",
|
||||
wantContent: "",
|
||||
wantStateAfter: thinkingState_Thinking,
|
||||
},
|
||||
{
|
||||
input: "ing>def",
|
||||
wantThinking: "</thing>def",
|
||||
wantContent: "",
|
||||
wantStateAfter: thinkingState_Thinking,
|
||||
},
|
||||
{
|
||||
input: "ghi</thi",
|
||||
wantThinking: "ghi",
|
||||
wantContent: "",
|
||||
wantStateAfter: thinkingState_Thinking,
|
||||
},
|
||||
{
|
||||
input: "nk>jkl",
|
||||
wantThinking: "",
|
||||
wantContent: "jkl",
|
||||
wantStateAfter: thinkingState_ThinkingDone,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "whitespace after thinking tag",
|
||||
steps: []step{
|
||||
{
|
||||
input: " <think>abc</think>\n\ndef",
|
||||
wantThinking: "abc",
|
||||
wantContent: "def",
|
||||
wantStateAfter: thinkingState_ThinkingDone,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "whitespace after thinking tag (incremental)",
|
||||
steps: []step{
|
||||
{
|
||||
input: " <think>abc</think>",
|
||||
wantThinking: "abc",
|
||||
wantContent: "",
|
||||
wantStateAfter: thinkingState_ThinkingDoneEatingWhitespace,
|
||||
},
|
||||
{
|
||||
input: "\n\ndef",
|
||||
wantThinking: "",
|
||||
wantContent: "def",
|
||||
wantStateAfter: thinkingState_ThinkingDone,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "whitespace after thinking tag with content and more whitespace",
|
||||
steps: []step{
|
||||
{
|
||||
input: " <think>abc</think>\n\ndef ",
|
||||
wantThinking: "abc",
|
||||
wantContent: "def ",
|
||||
wantStateAfter: thinkingState_ThinkingDone,
|
||||
},
|
||||
{
|
||||
input: " ghi",
|
||||
wantThinking: "",
|
||||
wantContent: " ghi",
|
||||
wantStateAfter: thinkingState_ThinkingDone,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "token by token",
|
||||
steps: []step{
|
||||
{
|
||||
input: "<think>",
|
||||
wantThinking: "",
|
||||
wantContent: "",
|
||||
wantStateAfter: thinkingState_ThinkingStartedEatingWhitespace,
|
||||
},
|
||||
{
|
||||
input: "\n",
|
||||
wantThinking: "",
|
||||
wantContent: "",
|
||||
wantStateAfter: thinkingState_ThinkingStartedEatingWhitespace,
|
||||
},
|
||||
{
|
||||
input: "</think>",
|
||||
wantThinking: "",
|
||||
wantContent: "",
|
||||
wantStateAfter: thinkingState_ThinkingDoneEatingWhitespace,
|
||||
},
|
||||
{
|
||||
input: "\n\n",
|
||||
wantThinking: "",
|
||||
wantContent: "",
|
||||
wantStateAfter: thinkingState_ThinkingDoneEatingWhitespace,
|
||||
},
|
||||
{
|
||||
input: "Hi",
|
||||
wantThinking: "",
|
||||
wantContent: "Hi",
|
||||
wantStateAfter: thinkingState_ThinkingDone,
|
||||
},
|
||||
{
|
||||
input: " there",
|
||||
wantThinking: "",
|
||||
wantContent: " there",
|
||||
wantStateAfter: thinkingState_ThinkingDone,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "leading thinking whitespace",
|
||||
steps: []step{
|
||||
{
|
||||
input: " <think> \t ",
|
||||
wantThinking: "",
|
||||
wantContent: "",
|
||||
wantStateAfter: thinkingState_ThinkingStartedEatingWhitespace,
|
||||
},
|
||||
{
|
||||
input: " these are some ",
|
||||
wantThinking: "these are some ",
|
||||
wantContent: "",
|
||||
wantStateAfter: thinkingState_Thinking,
|
||||
},
|
||||
{
|
||||
input: "thoughts </think> ",
|
||||
wantThinking: "thoughts ",
|
||||
wantContent: "",
|
||||
wantStateAfter: thinkingState_ThinkingDoneEatingWhitespace,
|
||||
},
|
||||
{
|
||||
input: " more content",
|
||||
wantThinking: "",
|
||||
wantContent: "more content",
|
||||
wantStateAfter: thinkingState_ThinkingDone,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, c := range cases {
|
||||
parser := thinkingParser{
|
||||
openingTag: "<think>",
|
||||
closingTag: "</think>",
|
||||
}
|
||||
if c.skip {
|
||||
continue
|
||||
}
|
||||
for i, step := range c.steps {
|
||||
thinking, content := parser.addContent(step.input)
|
||||
if content != step.wantContent || thinking != step.wantThinking {
|
||||
t.Errorf("case %q (step %d): got (%q,%q), want (%q,%q)", c.desc, i, content, thinking, step.wantContent, step.wantThinking)
|
||||
}
|
||||
if parser.state != step.wantStateAfter {
|
||||
t.Errorf("case %q (step %d): got state %s, want %s", c.desc, i, parser.state, step.wantStateAfter)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestInferThinkingTags(t *testing.T) {
|
||||
cases := []struct {
|
||||
desc string
|
||||
tmplString string
|
||||
wantOpeningTag string
|
||||
wantClosingTag string
|
||||
}{
|
||||
{
|
||||
desc: "basic",
|
||||
tmplString: `
|
||||
{{ if .Thinking}}
|
||||
/think
|
||||
{{ end }}
|
||||
{{- range $i, $_ := .Messages }}
|
||||
{{- $last := eq (len (slice $.Messages $i)) 1 -}}
|
||||
{{ if and $last .Thinking }}
|
||||
<think>{{ .Thinking }}</think>
|
||||
{{ end }}
|
||||
{{ end }}
|
||||
`,
|
||||
wantOpeningTag: "<think>",
|
||||
wantClosingTag: "</think>",
|
||||
},
|
||||
{
|
||||
desc: "doubly nested range",
|
||||
tmplString: `
|
||||
{{ if .Thinking}}
|
||||
/think
|
||||
{{ end }}
|
||||
{{- range $i, $_ := .Messages }}
|
||||
{{- range $j, $_ := .NotMessages }}
|
||||
{{- $last := eq (len (slice $.Messages $i)) 1 -}}
|
||||
{{ if and $last .Thinking }}
|
||||
<think>{{ .Thinking }}</think>
|
||||
{{ end }}
|
||||
{{ end }}
|
||||
{{ end }}
|
||||
`,
|
||||
wantOpeningTag: "",
|
||||
wantClosingTag: "",
|
||||
},
|
||||
{
|
||||
desc: "whitespace is trimmed",
|
||||
tmplString: `
|
||||
{{ if .Thinking}}
|
||||
/think
|
||||
{{ end }}
|
||||
{{- range $i, $_ := .Messages }}
|
||||
{{- $last := eq (len (slice $.Messages $i)) 1 -}}
|
||||
{{ if and $last .Thinking }}
|
||||
Some text before {{ .Thinking }} Some text after
|
||||
{{ end }}
|
||||
{{ end }}
|
||||
`,
|
||||
wantOpeningTag: "Some text before",
|
||||
wantClosingTag: "Some text after",
|
||||
},
|
||||
{
|
||||
desc: "qwen3",
|
||||
tmplString: `
|
||||
{{- if or .System .Tools .Thinking }}<|im_start|>system
|
||||
{{- if .System }}
|
||||
{{ .System }}
|
||||
{{- end }}
|
||||
{{- if .Tools }}
|
||||
|
||||
# Tools
|
||||
|
||||
You may call one or more functions to assist with the user query.
|
||||
|
||||
You are provided with function signatures within <tools></tools> XML tags:
|
||||
<tools>
|
||||
{{- range .Tools }}
|
||||
{"type": "function", "function": {{ .Function }}}
|
||||
{{- end }}
|
||||
</tools>
|
||||
|
||||
For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:
|
||||
<tool_call>
|
||||
{"name": <function-name>, "arguments": <args-json-object>}
|
||||
</tool_call>
|
||||
{{- end }}
|
||||
{{- if .Thinking }}
|
||||
/think
|
||||
{{- else }}
|
||||
/no_think
|
||||
{{- end }}<|im_end|>
|
||||
{{ end }}
|
||||
{{- range $i, $_ := .Messages }}
|
||||
{{- $last := eq (len (slice $.Messages $i)) 1 -}}
|
||||
{{- if eq .Role "user" }}<|im_start|>user
|
||||
{{ .Content }}<|im_end|>
|
||||
{{ else if eq .Role "assistant" }}<|im_start|>assistant
|
||||
{{ if and $last .Thinking }}
|
||||
<think>{{ .Thinking }}</think>
|
||||
{{ end }}
|
||||
{{ if .Content }}{{ .Content }}
|
||||
{{- else if .ToolCalls }}<tool_call>
|
||||
{{ range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}}
|
||||
{{ end }}</tool_call>
|
||||
{{- end }}{{ if not $last }}<|im_end|>
|
||||
{{ end }}
|
||||
{{- else if eq .Role "tool" }}<|im_start|>user
|
||||
<tool_response>
|
||||
{{ .Content }}
|
||||
</tool_response><|im_end|>
|
||||
{{ end }}
|
||||
{{- if and (ne .Role "assistant") $last }}<|im_start|>assistant
|
||||
{{ end }}
|
||||
{{- end }}
|
||||
`,
|
||||
wantOpeningTag: "<think>",
|
||||
wantClosingTag: "</think>",
|
||||
},
|
||||
}
|
||||
for _, c := range cases {
|
||||
tmpl := template.Must(template.New("test").Parse(c.tmplString))
|
||||
openingTag, closingTag := inferThinkingTags(tmpl)
|
||||
if openingTag != c.wantOpeningTag || closingTag != c.wantClosingTag {
|
||||
t.Errorf("case %q: got (%q,%q), want (%q,%q)", c.desc, openingTag, closingTag, c.wantOpeningTag, c.wantClosingTag)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -167,6 +167,10 @@ type Values struct {
|
||||
api.Tools
|
||||
Prompt string
|
||||
Suffix string
|
||||
Think bool
|
||||
// whether or not the user explicitly set the thinking flag (vs. it being
|
||||
// implicitly false). Templates can't see whether `Think` is nil
|
||||
IsThinkSet bool
|
||||
|
||||
// forceLegacy is a flag used to test compatibility with legacy templates
|
||||
forceLegacy bool
|
||||
@@ -222,16 +226,20 @@ func (t *Template) Execute(w io.Writer, v Values) error {
|
||||
system, messages := collate(v.Messages)
|
||||
if v.Prompt != "" && v.Suffix != "" {
|
||||
return t.Template.Execute(w, map[string]any{
|
||||
"Prompt": v.Prompt,
|
||||
"Suffix": v.Suffix,
|
||||
"Response": "",
|
||||
"Prompt": v.Prompt,
|
||||
"Suffix": v.Suffix,
|
||||
"Response": "",
|
||||
"Think": v.Think,
|
||||
"IsThinkSet": v.IsThinkSet,
|
||||
})
|
||||
} else if !v.forceLegacy && slices.Contains(t.Vars(), "messages") {
|
||||
return t.Template.Execute(w, map[string]any{
|
||||
"System": system,
|
||||
"Messages": messages,
|
||||
"Tools": v.Tools,
|
||||
"Response": "",
|
||||
"System": system,
|
||||
"Messages": messages,
|
||||
"Tools": v.Tools,
|
||||
"Response": "",
|
||||
"Think": v.Think,
|
||||
"IsThinkSet": v.IsThinkSet,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -241,9 +249,11 @@ func (t *Template) Execute(w io.Writer, v Values) error {
|
||||
for _, m := range messages {
|
||||
execute := func() error {
|
||||
if err := t.Template.Execute(&b, map[string]any{
|
||||
"System": system,
|
||||
"Prompt": prompt,
|
||||
"Response": response,
|
||||
"System": system,
|
||||
"Prompt": prompt,
|
||||
"Response": response,
|
||||
"Think": v.Think,
|
||||
"IsThinkSet": v.IsThinkSet,
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -286,9 +296,11 @@ func (t *Template) Execute(w io.Writer, v Values) error {
|
||||
|
||||
tree := parse.Tree{Root: nodes.(*parse.ListNode)}
|
||||
if err := template.Must(template.New("").AddParseTree("", &tree)).Execute(&b, map[string]any{
|
||||
"System": system,
|
||||
"Prompt": prompt,
|
||||
"Response": response,
|
||||
"System": system,
|
||||
"Prompt": prompt,
|
||||
"Response": response,
|
||||
"Think": v.Think,
|
||||
"IsThinkSet": v.IsThinkSet,
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
44
tools/testdata/llama3.2.gotmpl
vendored
Normal file
44
tools/testdata/llama3.2.gotmpl
vendored
Normal file
@@ -0,0 +1,44 @@
|
||||
<|start_header_id|>system<|end_header_id|>
|
||||
|
||||
Cutting Knowledge Date: December 2023
|
||||
|
||||
{{ if .System }}{{ .System }}
|
||||
{{- end }}
|
||||
{{- if .Tools }}When you receive a tool call response, use the output to format an answer to the orginal user question.
|
||||
|
||||
You are a helpful assistant with tool calling capabilities.
|
||||
{{- end }}<|eot_id|>
|
||||
{{- range $i, $_ := .Messages }}
|
||||
{{- $last := eq (len (slice $.Messages $i)) 1 }}
|
||||
{{- if eq .Role "user" }}<|start_header_id|>user<|end_header_id|>
|
||||
{{- if and $.Tools $last }}
|
||||
|
||||
Given the following functions, please respond with a JSON for a function call with its proper arguments that best answers the given prompt.
|
||||
|
||||
Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}. Do not use variables.
|
||||
|
||||
{{ range $.Tools }}
|
||||
{{- . }}
|
||||
{{ end }}
|
||||
{{ .Content }}<|eot_id|>
|
||||
{{- else }}
|
||||
|
||||
{{ .Content }}<|eot_id|>
|
||||
{{- end }}{{ if $last }}<|start_header_id|>assistant<|end_header_id|>
|
||||
|
||||
{{ end }}
|
||||
{{- else if eq .Role "assistant" }}<|start_header_id|>assistant<|end_header_id|>
|
||||
{{- if .ToolCalls }}
|
||||
{{ range .ToolCalls }}
|
||||
{"name": "{{ .Function.Name }}", "parameters": {{ .Function.Arguments }}}{{ end }}
|
||||
{{- else }}
|
||||
|
||||
{{ .Content }}
|
||||
{{- end }}{{ if not $last }}<|eot_id|>{{ end }}
|
||||
{{- else if eq .Role "tool" }}<|start_header_id|>ipython<|end_header_id|>
|
||||
|
||||
{{ .Content }}<|eot_id|>{{ if $last }}<|start_header_id|>assistant<|end_header_id|>
|
||||
|
||||
{{ end }}
|
||||
{{- end }}
|
||||
{{- end }}
|
||||
24
tools/testdata/llama3.2.out
vendored
Normal file
24
tools/testdata/llama3.2.out
vendored
Normal file
@@ -0,0 +1,24 @@
|
||||
<|start_header_id|>system<|end_header_id|>
|
||||
|
||||
Cutting Knowledge Date: December 2023
|
||||
|
||||
You are a knowledgeable assistant. You can answer questions and perform tasks.When you receive a tool call response, use the output to format an answer to the orginal user question.
|
||||
|
||||
You are a helpful assistant with tool calling capabilities.<|eot_id|><|start_header_id|>user<|end_header_id|>
|
||||
|
||||
What's the weather like today in Paris?<|eot_id|><|start_header_id|>assistant<|end_header_id|>
|
||||
|
||||
{"name": "get_current_weather", "parameters": {"format":"celsius","location":"Paris, France"}}<|eot_id|><|start_header_id|>ipython<|end_header_id|>
|
||||
|
||||
22<|eot_id|><|start_header_id|>assistant<|end_header_id|>
|
||||
|
||||
The current temperature in Paris, France is 22 degrees Celsius.<|eot_id|><|start_header_id|>user<|end_header_id|>
|
||||
|
||||
Given the following functions, please respond with a JSON for a function call with its proper arguments that best answers the given prompt.
|
||||
|
||||
Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}. Do not use variables.
|
||||
|
||||
{"type":"function","function":{"name":"get_current_weather","description":"Get the current weather","parameters":{"type":"object","required":["location","format"],"properties":{"format":{"type":"string","description":"The temperature unit to use. Infer this from the user's location.","enum":["celsius","fahrenheit"]},"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"}}}}}
|
||||
|
||||
What's the weather like today in San Francisco and Toronto?<|eot_id|><|start_header_id|>assistant<|end_header_id|>
|
||||
|
||||
51
tools/testdata/qwen2.5.gotmpl
vendored
Normal file
51
tools/testdata/qwen2.5.gotmpl
vendored
Normal file
@@ -0,0 +1,51 @@
|
||||
{{- if .Suffix }}<|fim_prefix|>{{ .Prompt }}<|fim_suffix|>{{ .Suffix }}<|fim_middle|>
|
||||
{{- else if .Messages }}
|
||||
{{- if or .System .Tools }}<|im_start|>system
|
||||
{{- if .System }}
|
||||
{{ .System }}
|
||||
{{- end }}
|
||||
{{- if .Tools }}
|
||||
|
||||
# Tools
|
||||
|
||||
You may call one or more functions to assist with the user query.
|
||||
|
||||
You are provided with function signatures within <tools></tools> XML tags:
|
||||
<tools>
|
||||
{{- range .Tools }}
|
||||
{"type": "function", "function": {{ .Function }}}
|
||||
{{- end }}
|
||||
</tools>
|
||||
|
||||
For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:
|
||||
<tool_call>
|
||||
{"name": <function-name>, "arguments": <args-json-object>}
|
||||
</tool_call>
|
||||
{{- end }}<|im_end|>
|
||||
{{ end }}
|
||||
{{- range $i, $_ := .Messages }}
|
||||
{{- $last := eq (len (slice $.Messages $i)) 1 -}}
|
||||
{{- if eq .Role "user" }}<|im_start|>user
|
||||
{{ .Content }}<|im_end|>
|
||||
{{ else if eq .Role "assistant" }}<|im_start|>assistant
|
||||
{{ if .Content }}{{ .Content }}
|
||||
{{- else if .ToolCalls }}<tool_call>
|
||||
{{ range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}}
|
||||
{{ end }}</tool_call>
|
||||
{{- end }}{{ if not $last }}<|im_end|>
|
||||
{{ end }}
|
||||
{{- else if eq .Role "tool" }}<|im_start|>user
|
||||
<tool_response>
|
||||
{{ .Content }}
|
||||
</tool_response><|im_end|>
|
||||
{{ end }}
|
||||
{{- if and (ne .Role "assistant") $last }}<|im_start|>assistant
|
||||
{{ end }}
|
||||
{{- end }}
|
||||
{{- else }}
|
||||
{{- if .System }}<|im_start|>system
|
||||
{{ .System }}<|im_end|>
|
||||
{{ end }}{{ if .Prompt }}<|im_start|>user
|
||||
{{ .Prompt }}<|im_end|>
|
||||
{{ end }}<|im_start|>assistant
|
||||
{{ end }}{{ .Response }}{{ if .Response }}<|im_end|>{{ end }}
|
||||
31
tools/testdata/qwen2.5.out
vendored
Normal file
31
tools/testdata/qwen2.5.out
vendored
Normal file
@@ -0,0 +1,31 @@
|
||||
<|im_start|>system
|
||||
You are a knowledgeable assistant. You can answer questions and perform tasks.
|
||||
|
||||
# Tools
|
||||
|
||||
You may call one or more functions to assist with the user query.
|
||||
|
||||
You are provided with function signatures within <tools></tools> XML tags:
|
||||
<tools>
|
||||
{"type": "function", "function": {"name":"get_current_weather","description":"Get the current weather","parameters":{"type":"object","required":["location","format"],"properties":{"format":{"type":"string","description":"The temperature unit to use. Infer this from the user's location.","enum":["celsius","fahrenheit"]},"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"}}}}}
|
||||
</tools>
|
||||
|
||||
For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:
|
||||
<tool_call>
|
||||
{"name": <function-name>, "arguments": <args-json-object>}
|
||||
</tool_call><|im_end|>
|
||||
<|im_start|>user
|
||||
What's the weather like today in Paris?<|im_end|>
|
||||
<|im_start|>assistant
|
||||
<tool_call>
|
||||
{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Paris, France"}}
|
||||
</tool_call><|im_end|>
|
||||
<|im_start|>user
|
||||
<tool_response>
|
||||
22
|
||||
</tool_response><|im_end|>
|
||||
<|im_start|>assistant
|
||||
The current temperature in Paris, France is 22 degrees Celsius.<|im_end|>
|
||||
<|im_start|>user
|
||||
What's the weather like today in San Francisco and Toronto?<|im_end|>
|
||||
<|im_start|>assistant
|
||||
50
tools/testdata/qwen3.gotmpl
vendored
Normal file
50
tools/testdata/qwen3.gotmpl
vendored
Normal file
@@ -0,0 +1,50 @@
|
||||
{{- if .Messages }}
|
||||
{{- if or .System .Tools }}<|im_start|>system
|
||||
{{- if .System }}
|
||||
{{ .System }}
|
||||
{{- end }}
|
||||
{{- if .Tools }}
|
||||
|
||||
# Tools
|
||||
|
||||
You may call one or more functions to assist with the user query.
|
||||
|
||||
You are provided with function signatures within <tools></tools> XML tags:
|
||||
<tools>
|
||||
{{- range .Tools }}
|
||||
{"type": "function", "function": {{ .Function }}}
|
||||
{{- end }}
|
||||
</tools>
|
||||
|
||||
For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:
|
||||
<tool_call>
|
||||
{"name": <function-name>, "arguments": <args-json-object>}
|
||||
</tool_call>
|
||||
{{- end }}<|im_end|>
|
||||
{{ end }}
|
||||
{{- range $i, $_ := .Messages }}
|
||||
{{- $last := eq (len (slice $.Messages $i)) 1 -}}
|
||||
{{- if eq .Role "user" }}<|im_start|>user
|
||||
{{ .Content }}<|im_end|>
|
||||
{{ else if eq .Role "assistant" }}<|im_start|>assistant
|
||||
{{ if .Content }}{{ .Content }}
|
||||
{{- else if .ToolCalls }}<tool_call>
|
||||
{{ range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}}
|
||||
{{ end }}</tool_call>
|
||||
{{- end }}{{ if not $last }}<|im_end|>
|
||||
{{ end }}
|
||||
{{- else if eq .Role "tool" }}<|im_start|>user
|
||||
<tool_response>
|
||||
{{ .Content }}
|
||||
</tool_response><|im_end|>
|
||||
{{ end }}
|
||||
{{- if and (ne .Role "assistant") $last }}<|im_start|>assistant
|
||||
{{ end }}
|
||||
{{- end }}
|
||||
{{- else }}
|
||||
{{- if .System }}<|im_start|>system
|
||||
{{ .System }}<|im_end|>
|
||||
{{ end }}{{ if .Prompt }}<|im_start|>user
|
||||
{{ .Prompt }}<|im_end|>
|
||||
{{ end }}<|im_start|>assistant
|
||||
{{ end }}{{ .Response }}{{ if .Response }}<|im_end|>{{ end }}
|
||||
31
tools/testdata/qwen3.out
vendored
Normal file
31
tools/testdata/qwen3.out
vendored
Normal file
@@ -0,0 +1,31 @@
|
||||
<|im_start|>system
|
||||
You are a knowledgeable assistant. You can answer questions and perform tasks.
|
||||
|
||||
# Tools
|
||||
|
||||
You may call one or more functions to assist with the user query.
|
||||
|
||||
You are provided with function signatures within <tools></tools> XML tags:
|
||||
<tools>
|
||||
{"type": "function", "function": {"name":"get_current_weather","description":"Get the current weather","parameters":{"type":"object","required":["location","format"],"properties":{"format":{"type":"string","description":"The temperature unit to use. Infer this from the user's location.","enum":["celsius","fahrenheit"]},"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"}}}}}
|
||||
</tools>
|
||||
|
||||
For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:
|
||||
<tool_call>
|
||||
{"name": <function-name>, "arguments": <args-json-object>}
|
||||
</tool_call><|im_end|>
|
||||
<|im_start|>user
|
||||
What's the weather like today in Paris?<|im_end|>
|
||||
<|im_start|>assistant
|
||||
<tool_call>
|
||||
{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Paris, France"}}
|
||||
</tool_call><|im_end|>
|
||||
<|im_start|>user
|
||||
<tool_response>
|
||||
22
|
||||
</tool_response><|im_end|>
|
||||
<|im_start|>assistant
|
||||
The current temperature in Paris, France is 22 degrees Celsius.<|im_end|>
|
||||
<|im_start|>user
|
||||
What's the weather like today in San Francisco and Toronto?<|im_end|>
|
||||
<|im_start|>assistant
|
||||
253
tools/tools.go
Normal file
253
tools/tools.go
Normal file
@@ -0,0 +1,253 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"log/slog"
|
||||
"strings"
|
||||
gotmpl "text/template"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/template"
|
||||
)
|
||||
|
||||
var (
|
||||
errInvalidToolCall = errors.New("invalid tool call format")
|
||||
errAccumulateMore = errors.New("need to accumulate more content")
|
||||
)
|
||||
|
||||
type Parser struct {
|
||||
greedyParseJSON bool
|
||||
prefix string
|
||||
prefixFound bool
|
||||
tmpl gotmpl.Template
|
||||
sb strings.Builder
|
||||
index int
|
||||
name string
|
||||
arguments string
|
||||
}
|
||||
|
||||
// parseJSONToolCalls attempts to parse a JSON string into a slice of ToolCalls.
|
||||
//
|
||||
// Parameters:
|
||||
// - s: The string to parse
|
||||
// - name: The field name from template that identifies the tool call name
|
||||
// - arguments: The field name from template that identifies the tool call arguments
|
||||
//
|
||||
// Returns:
|
||||
// - []api.ToolCall: The parsed tool calls if successful
|
||||
// - error: ErrAccumulateMore if braces unbalanced, ErrInvalidToolCall if invalid, or nil if successful
|
||||
func parseJSONToolCalls(s string, name, arguments string, prefix string) ([]api.ToolCall, error) {
|
||||
// Check for balanced braces before attempting to parse
|
||||
braceCount := 0
|
||||
squareCount := 0
|
||||
startIndex := -1
|
||||
var rawToolCalls []string
|
||||
s = strings.TrimSpace(s)
|
||||
|
||||
// Only track these if we don't have a prefix as it will be cut off from the prefix. Also track in the parseLeadingJSON case.
|
||||
trackSquareBrackets := prefix == "" || !strings.HasSuffix(prefix, "[") || strings.HasPrefix(s, "[")
|
||||
for i, c := range s {
|
||||
switch c {
|
||||
case '{':
|
||||
braceCount++
|
||||
if startIndex == -1 {
|
||||
startIndex = i
|
||||
}
|
||||
case '}':
|
||||
braceCount--
|
||||
if braceCount == 0 {
|
||||
rawToolCalls = append(rawToolCalls, s[startIndex:i+1])
|
||||
startIndex = -1
|
||||
}
|
||||
case '[':
|
||||
if trackSquareBrackets {
|
||||
squareCount++
|
||||
}
|
||||
case ']':
|
||||
if trackSquareBrackets {
|
||||
squareCount--
|
||||
}
|
||||
}
|
||||
|
||||
// Negative means we have an extra closing brace/bracket
|
||||
if braceCount < 0 || squareCount < 0 {
|
||||
return nil, errInvalidToolCall
|
||||
}
|
||||
}
|
||||
|
||||
// If braces/brackets aren't balanced, need more input
|
||||
if braceCount > 0 || squareCount > 0 {
|
||||
return nil, errAccumulateMore
|
||||
}
|
||||
|
||||
t := strings.TrimSpace(s)
|
||||
if len(t) == 0 {
|
||||
return nil, errAccumulateMore
|
||||
}
|
||||
// If the input is a single square bracket, it's not a valid tool call
|
||||
if t[0] == '[' && len(t) == 1 {
|
||||
return nil, errAccumulateMore
|
||||
}
|
||||
|
||||
// Attempt full unmarshal of the JSON
|
||||
var toolCalls []api.ToolCall
|
||||
for _, rawToolCall := range rawToolCalls {
|
||||
var resp map[string]any
|
||||
if err := json.Unmarshal([]byte(rawToolCall), &resp); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// Collect nested objects that could contain tool calls
|
||||
objs := collect(resp)
|
||||
if len(objs) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
// Extract tool calls from objects
|
||||
for _, kv := range objs {
|
||||
n, nok := kv[name].(string)
|
||||
a, aok := kv[arguments].(map[string]any)
|
||||
if nok && aok {
|
||||
toolCalls = append(toolCalls, api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: n,
|
||||
Arguments: a,
|
||||
},
|
||||
})
|
||||
} else {
|
||||
slog.Debug("No valid tool call found in object.", "object", kv)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Valid JSON, no tool calls found
|
||||
if len(toolCalls) == 0 {
|
||||
slog.Debug("No valid tool calls found in any raw tool calls.", "rawToolCalls", rawToolCalls)
|
||||
return nil, errInvalidToolCall
|
||||
}
|
||||
|
||||
return toolCalls, nil
|
||||
}
|
||||
|
||||
// checkPrefix processes a string to find and handle a prefix pattern.
|
||||
//
|
||||
// Returns:
|
||||
// - The processed string with prefix removed if found
|
||||
// - error: ErrAccumulateMore if prefix is incomplete, or nil if successful
|
||||
func (p *Parser) checkPrefix(s string) (string, error) {
|
||||
if s == "" || p.prefix == "" {
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// Check for prefix at start of string
|
||||
if cut, hasPrefix := strings.CutPrefix(s, p.prefix); hasPrefix {
|
||||
// Found prefix at start - accumulate for potential tool
|
||||
p.prefixFound = true
|
||||
return cut, nil
|
||||
}
|
||||
|
||||
// Check if prefix overlaps end of string
|
||||
if idx := suffixOverlap(s, p.prefix); idx != -1 {
|
||||
// Return everything except overlapping portion
|
||||
p.sb.Reset()
|
||||
p.sb.WriteString(s[idx:])
|
||||
return s[:idx], errAccumulateMore
|
||||
}
|
||||
|
||||
// Check if prefix appears in middle of string
|
||||
if idx := strings.Index(s, p.prefix); idx != -1 {
|
||||
// Save remainder starting at prefix for next pass
|
||||
p.sb.Reset()
|
||||
p.sb.WriteString(strings.TrimSpace(s[idx:]))
|
||||
// Return everything before prefix
|
||||
return s[:idx], errAccumulateMore
|
||||
}
|
||||
|
||||
// No partial prefix found
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// Add processes a string input to parse tool calls and content.
|
||||
// It handles prefix detection and JSON parsing to extract tool calls.
|
||||
//
|
||||
// Returns:
|
||||
// - tools: Any parsed tool calls
|
||||
// - content: Non-tool call content
|
||||
func (p *Parser) Add(s string) (tools []api.ToolCall, content string) {
|
||||
p.sb.WriteString(s)
|
||||
s = p.sb.String()
|
||||
|
||||
// Check for prefix pattern in input
|
||||
s, err := p.checkPrefix(s)
|
||||
if err != nil {
|
||||
// Need more input to complete prefix
|
||||
return nil, s
|
||||
}
|
||||
|
||||
// Exit if prefix exists in template, greedy parsing is off, and prefix not found
|
||||
if !p.greedyParseJSON && !p.prefixFound {
|
||||
p.sb.Reset()
|
||||
return nil, s
|
||||
}
|
||||
|
||||
toolCalls, err := parseJSONToolCalls(s, p.name, p.arguments, p.prefix)
|
||||
if err != nil {
|
||||
if errors.Is(err, errAccumulateMore) {
|
||||
return nil, ""
|
||||
}
|
||||
p.sb.Reset()
|
||||
// Only do greedy JSON parsing if there is no prefix from template
|
||||
if p.prefix != "" {
|
||||
p.greedyParseJSON = false
|
||||
}
|
||||
if p.index != 0 && p.prefix == "" {
|
||||
return nil, ""
|
||||
}
|
||||
if p.prefixFound {
|
||||
// Drop tokens since prefix was found
|
||||
return nil, ""
|
||||
}
|
||||
return nil, s
|
||||
}
|
||||
|
||||
for _, tc := range toolCalls {
|
||||
tc.Function.Index = p.index
|
||||
p.index++
|
||||
}
|
||||
|
||||
p.sb.Reset()
|
||||
return toolCalls, ""
|
||||
}
|
||||
|
||||
// NewParser creates a new tool call parser from a template. It extracts the tool call format,
|
||||
// prefix, and field names from the template to use for parsing tool calls from model output.
|
||||
//
|
||||
// Returns an error if the template does not contain valid tool call formatting.
|
||||
func NewParser(templateToProcess *gotmpl.Template) (*Parser, error) {
|
||||
parsed, err := template.Parse(templateToProcess.Root.String())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tt, err := toolTemplate(parsed)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tp := toolPrefix(templateToProcess)
|
||||
|
||||
name, arguments, err := extractToolArgs(tt)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &Parser{
|
||||
tmpl: *tt,
|
||||
sb: strings.Builder{},
|
||||
prefix: tp,
|
||||
greedyParseJSON: true,
|
||||
name: name,
|
||||
arguments: arguments,
|
||||
}, nil
|
||||
}
|
||||
673
tools/tools_test.go
Normal file
673
tools/tools_test.go
Normal file
@@ -0,0 +1,673 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/template"
|
||||
)
|
||||
|
||||
func readFile(t *testing.T, base, name string) *bytes.Buffer {
|
||||
t.Helper()
|
||||
|
||||
bts, err := os.ReadFile(filepath.Join(base, name))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
return bytes.NewBuffer(bts)
|
||||
}
|
||||
|
||||
func TestParseJSONToolCalls(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
nameField string
|
||||
argsField string
|
||||
wantToolCalls []api.ToolCall
|
||||
wantErr error
|
||||
prefix string
|
||||
}{
|
||||
{
|
||||
name: "valid single tool call",
|
||||
input: `{"name": "test_tool", "arguments": {"arg1": "value1"}}`,
|
||||
nameField: "name",
|
||||
argsField: "arguments",
|
||||
wantToolCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "test_tool",
|
||||
Arguments: map[string]any{
|
||||
"arg1": "value1",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
wantErr: nil,
|
||||
prefix: "",
|
||||
},
|
||||
{
|
||||
name: "incomplete JSON",
|
||||
input: `{"name": "test_tool", "arguments": {"arg1": `,
|
||||
nameField: "name",
|
||||
argsField: "arguments",
|
||||
wantToolCalls: nil,
|
||||
wantErr: errAccumulateMore,
|
||||
prefix: "",
|
||||
},
|
||||
{
|
||||
name: "invalid JSON",
|
||||
input: `not json at all`,
|
||||
nameField: "name",
|
||||
argsField: "arguments",
|
||||
wantToolCalls: nil,
|
||||
wantErr: errInvalidToolCall,
|
||||
prefix: "",
|
||||
},
|
||||
{
|
||||
name: "missing required fields",
|
||||
input: `{"other": "field"}`,
|
||||
nameField: "name",
|
||||
argsField: "arguments",
|
||||
wantToolCalls: nil,
|
||||
wantErr: errInvalidToolCall,
|
||||
prefix: "",
|
||||
},
|
||||
{
|
||||
name: "multiple tool calls in array",
|
||||
input: `[
|
||||
{"name": "tool1", "arguments": {"arg1": 1}},
|
||||
{"name": "tool2", "arguments": {"arg2": "value"}}
|
||||
]`,
|
||||
nameField: "name",
|
||||
argsField: "arguments",
|
||||
wantToolCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "tool1",
|
||||
Arguments: map[string]any{
|
||||
"arg1": float64(1),
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "tool2",
|
||||
Arguments: map[string]any{
|
||||
"arg2": "value",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
wantErr: nil,
|
||||
prefix: "",
|
||||
},
|
||||
{
|
||||
name: "multiple tool calls without array",
|
||||
input: `
|
||||
{"name": "tool1", "arguments": {"arg1": 1}},
|
||||
{"name": "tool2", "arguments": {"arg2": "value"}}
|
||||
`,
|
||||
nameField: "name",
|
||||
argsField: "arguments",
|
||||
wantToolCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "tool1",
|
||||
Arguments: map[string]any{
|
||||
"arg1": float64(1),
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "tool2",
|
||||
Arguments: map[string]any{
|
||||
"arg2": "value",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
wantErr: nil,
|
||||
prefix: "",
|
||||
},
|
||||
{
|
||||
name: "multiple tool calls with text after",
|
||||
input: `
|
||||
{"name": "tool1", "arguments": {"arg1": 1}} text
|
||||
{"name": "tool2", "arguments": {"arg2": "value"}} text
|
||||
`,
|
||||
nameField: "name",
|
||||
argsField: "arguments",
|
||||
wantToolCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "tool1",
|
||||
Arguments: map[string]any{
|
||||
"arg1": float64(1),
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "tool2",
|
||||
Arguments: map[string]any{
|
||||
"arg2": "value",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
wantErr: nil,
|
||||
prefix: "",
|
||||
},
|
||||
{
|
||||
name: "second tool call in array",
|
||||
input: `
|
||||
, {"name": "tool2", "arguments": {"arg2": "value"}}
|
||||
`,
|
||||
nameField: "name",
|
||||
argsField: "arguments",
|
||||
wantToolCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "tool2",
|
||||
Arguments: map[string]any{
|
||||
"arg2": "value",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
wantErr: nil,
|
||||
prefix: "",
|
||||
},
|
||||
// a bad JSON would not return any tool calls or content as it would always accumulate more
|
||||
{
|
||||
name: "unbalanced square brackets",
|
||||
input: `[{"name": "tool1", "arguments": {"arg1": [1, 2}]`,
|
||||
nameField: "name",
|
||||
argsField: "arguments",
|
||||
wantToolCalls: nil,
|
||||
wantErr: errAccumulateMore,
|
||||
prefix: "",
|
||||
},
|
||||
{
|
||||
name: "incomplete square brackets",
|
||||
input: `[{"name": "tool1", "arguments": {"arg1": [1, 2, 3`,
|
||||
nameField: "name",
|
||||
argsField: "arguments",
|
||||
wantToolCalls: nil,
|
||||
wantErr: errAccumulateMore,
|
||||
prefix: "",
|
||||
},
|
||||
{
|
||||
name: "nested arrays in arguments",
|
||||
input: `{"name": "tool1", "arguments": {"arg1": [1, 2, ["nested", "array"]]}}`,
|
||||
nameField: "name",
|
||||
argsField: "arguments",
|
||||
wantToolCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "tool1",
|
||||
Arguments: map[string]any{
|
||||
"arg1": []any{float64(1), float64(2), []any{"nested", "array"}},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
wantErr: nil,
|
||||
prefix: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
gotCalls, err := parseJSONToolCalls(tt.input, tt.nameField, tt.argsField, tt.prefix)
|
||||
|
||||
if err != tt.wantErr {
|
||||
t.Errorf("parseJSONToolCalls() error = %v, want %v", err, tt.wantErr)
|
||||
}
|
||||
|
||||
if len(gotCalls) != 0 && tt.wantErr != nil {
|
||||
t.Errorf("parseJSONToolCalls() valid = %v, want %v", len(gotCalls) == 0, tt.wantErr == nil)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(gotCalls, tt.wantToolCalls); diff != "" {
|
||||
t.Errorf("parseJSONToolCalls() tool calls mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseToolCalls(t *testing.T) {
|
||||
p := filepath.Join("testdata")
|
||||
t1 := api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_current_weather",
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
"format": "fahrenheit",
|
||||
"location": "San Francisco, CA",
|
||||
},
|
||||
},
|
||||
}
|
||||
t2 := api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_current_weather",
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
"format": "celsius",
|
||||
"location": "Toronto, Canada",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
model string
|
||||
output string
|
||||
expectedToolCall []api.ToolCall
|
||||
expectedTokens string
|
||||
}{
|
||||
{
|
||||
name: "mistral malformed json with tool calls prefix",
|
||||
model: "mistral",
|
||||
output: `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_curren}]`,
|
||||
expectedToolCall: []api.ToolCall{t1},
|
||||
expectedTokens: "",
|
||||
},
|
||||
{
|
||||
name: "mistral multiple tool calls without prefix",
|
||||
model: "mistral",
|
||||
output: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}} ]`,
|
||||
expectedToolCall: []api.ToolCall{t1, t2},
|
||||
expectedTokens: "",
|
||||
},
|
||||
{
|
||||
name: "mistral tool calls with text between no prefix",
|
||||
model: "mistral",
|
||||
output: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]
|
||||
model outputs more tokens here and then [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
|
||||
expectedToolCall: []api.ToolCall{t1, t2},
|
||||
expectedTokens: `model outputs more tokens here and then [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
|
||||
},
|
||||
{
|
||||
name: "mistral valid json with tool calls prefix",
|
||||
model: "mistral",
|
||||
output: `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
|
||||
expectedToolCall: []api.ToolCall{t1, t2},
|
||||
expectedTokens: "",
|
||||
},
|
||||
{
|
||||
name: "mistral multiple tool calls with text between and prefix",
|
||||
model: "mistral",
|
||||
output: `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]
|
||||
model outputs more tokens here and then [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
|
||||
expectedToolCall: []api.ToolCall{t1, t2, t1, t2},
|
||||
expectedTokens: "",
|
||||
},
|
||||
{
|
||||
name: "mistral incomplete json with tool calls prefix",
|
||||
model: "mistral",
|
||||
output: `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, `,
|
||||
expectedToolCall: []api.ToolCall{},
|
||||
expectedTokens: "",
|
||||
},
|
||||
{
|
||||
name: "mistral invalid tool call with explanatory text no prefix",
|
||||
model: "mistral",
|
||||
output: `I'm not aware of that information. However, I can suggest searching for the weather using the "get_current_weather" function:
|
||||
|
||||
[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
|
||||
expectedToolCall: []api.ToolCall{},
|
||||
expectedTokens: `I'm not aware of that information. However, I can suggest searching for the weather using the "get_current_weather" function: [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
|
||||
},
|
||||
{
|
||||
name: "mistral tool calls without prefix",
|
||||
model: "mistral",
|
||||
output: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
|
||||
expectedToolCall: []api.ToolCall{t1, t2},
|
||||
expectedTokens: "",
|
||||
},
|
||||
{
|
||||
name: "command r plus tool calls with json block format",
|
||||
model: "command-r-plus",
|
||||
output: "Action: ```json" + `
|
||||
[
|
||||
{
|
||||
"tool_name": "get_current_weather",
|
||||
"parameters": {
|
||||
"format": "fahrenheit",
|
||||
"location": "San Francisco, CA"
|
||||
}
|
||||
},
|
||||
{
|
||||
"tool_name": "get_current_weather",
|
||||
"parameters": {
|
||||
"format": "celsius",
|
||||
"location": "Toronto, Canada"
|
||||
}
|
||||
}
|
||||
]
|
||||
` + "```",
|
||||
expectedToolCall: []api.ToolCall{t1, t2},
|
||||
expectedTokens: "",
|
||||
},
|
||||
{
|
||||
name: "firefunction tool calls with functools prefix",
|
||||
model: "firefunction",
|
||||
output: ` functools[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
|
||||
expectedToolCall: []api.ToolCall{t1, t2},
|
||||
expectedTokens: "",
|
||||
},
|
||||
{
|
||||
name: "llama3 groq single tool call with xml tags",
|
||||
model: "llama3-groq-tool-use",
|
||||
output: `<tool_call>
|
||||
{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}
|
||||
</tool_call>`,
|
||||
expectedToolCall: []api.ToolCall{t1},
|
||||
expectedTokens: "",
|
||||
},
|
||||
{
|
||||
name: "xlam tool calls with wrapper object",
|
||||
model: "xlam",
|
||||
output: `{"tool_calls": [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]}`,
|
||||
expectedToolCall: []api.ToolCall{t1, t2},
|
||||
expectedTokens: "",
|
||||
},
|
||||
{
|
||||
name: "qwen2.5 single tool call with prefix",
|
||||
model: "qwen2.5",
|
||||
output: `<tool_call>{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}</tool_call>`,
|
||||
expectedToolCall: []api.ToolCall{t1},
|
||||
expectedTokens: "",
|
||||
},
|
||||
{
|
||||
name: "qwen2.5 multiple tool calls with and without prefix",
|
||||
model: "qwen2.5",
|
||||
output: `{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} <tool_call>{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}</tool_call> <tool_call>{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}</tool_call>`,
|
||||
expectedToolCall: []api.ToolCall{t1, t1, t2},
|
||||
expectedTokens: "",
|
||||
},
|
||||
{
|
||||
name: "qwen2.5 plain text response no tool calls",
|
||||
model: "qwen2.5",
|
||||
output: "The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.",
|
||||
expectedToolCall: []api.ToolCall{},
|
||||
expectedTokens: "The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.",
|
||||
},
|
||||
{
|
||||
name: "qwen2.5 tool calls with trailing text",
|
||||
model: "qwen2.5",
|
||||
output: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}, {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}] some tokens after call`,
|
||||
expectedToolCall: []api.ToolCall{t1, t2},
|
||||
expectedTokens: "some tokens after call",
|
||||
},
|
||||
{
|
||||
name: "qwen2.5 tool calls with initial text",
|
||||
model: "qwen2.5",
|
||||
output: `some tokens before call [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}, {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
|
||||
expectedToolCall: []api.ToolCall{},
|
||||
expectedTokens: `some tokens before call [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}, {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
|
||||
},
|
||||
{
|
||||
name: "qwen2.5 tool calls with prefix and trailing text",
|
||||
model: "qwen2.5",
|
||||
output: `<tool_call> [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}, {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}] </tool_call> some tokens after call`,
|
||||
expectedToolCall: []api.ToolCall{t1, t2},
|
||||
expectedTokens: "",
|
||||
},
|
||||
{
|
||||
name: "qwen2.5 tool calls with prefix and initial text",
|
||||
model: "qwen2.5",
|
||||
output: `some tokens before call <tool_call> [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}, {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}] </tool_call>`,
|
||||
expectedToolCall: []api.ToolCall{t1, t2},
|
||||
expectedTokens: "some tokens before call",
|
||||
},
|
||||
{
|
||||
name: "qwen2.5 tool calls without and with prefix",
|
||||
model: "qwen2.5",
|
||||
output: `{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} <tool_call>{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}</tool_call>`,
|
||||
expectedToolCall: []api.ToolCall{t1, t2},
|
||||
expectedTokens: "",
|
||||
},
|
||||
{
|
||||
name: "qwen2.5 tool calls without and with prefix and text between",
|
||||
model: "qwen2.5",
|
||||
output: `{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} some tokens between <tool_call>{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}</tool_call> some tokens after call`,
|
||||
expectedToolCall: []api.ToolCall{t1, t2},
|
||||
expectedTokens: "some tokens between",
|
||||
},
|
||||
{
|
||||
name: "qwen2.5 tool calls without prefix and invalid tool call with other tokens",
|
||||
model: "qwen2.5",
|
||||
output: `hi [{"options": "foo"}]`,
|
||||
expectedToolCall: []api.ToolCall{},
|
||||
expectedTokens: `hi [{"options": "foo"}]`,
|
||||
},
|
||||
{
|
||||
name: "qwen2.5 tool calls with prefix and invalid tool call",
|
||||
model: "qwen2.5",
|
||||
output: `<tool_call> [{"options": "foo"}] </tool_call> `,
|
||||
expectedToolCall: []api.ToolCall{},
|
||||
expectedTokens: ``,
|
||||
},
|
||||
{
|
||||
name: "qwen3 tool call with think prefix and tool prefix (sent as a single token)",
|
||||
model: "qwen3",
|
||||
output: `<think>Okay, let me think what tool we should use...</think><tool_call>{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}</tool_call>`,
|
||||
expectedToolCall: []api.ToolCall{t1},
|
||||
expectedTokens: "<think>Okay, let me think what tool we should use...</think>",
|
||||
},
|
||||
{
|
||||
name: "qwen3 tool call with think prefix, tool prefix, and whitespace (sent as separate tokens)",
|
||||
model: "qwen3",
|
||||
output: `<think>Okay, let me think what tool we should use...</think> <tool_call>{ "name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </tool_call>`,
|
||||
expectedToolCall: []api.ToolCall{t1},
|
||||
expectedTokens: "<think>Okay, let me think what tool we should use...</think>",
|
||||
},
|
||||
{
|
||||
name: "qwen3 empty think prefix without tool prefix and invalid tool call",
|
||||
model: "qwen3",
|
||||
output: `<think></think> {"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </tool_call>`,
|
||||
expectedToolCall: []api.ToolCall{},
|
||||
expectedTokens: `<think></think> {"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </tool_call>`,
|
||||
},
|
||||
{
|
||||
name: "qwen3 empty think prefix with tool prefix and valid tool call",
|
||||
model: "qwen3",
|
||||
output: `<think></think><tool_call>{ "name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </tool_call>`,
|
||||
expectedToolCall: []api.ToolCall{t1},
|
||||
expectedTokens: `<think></think>`,
|
||||
},
|
||||
{
|
||||
name: "qwen3 invalid tool call with fake tool prefix (single rune suffix match)",
|
||||
model: "qwen3",
|
||||
output: `<think></think>< fakeout {"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </tool_call>`,
|
||||
expectedToolCall: []api.ToolCall{},
|
||||
expectedTokens: `<think></think>< fakeout {"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </tool_call>`,
|
||||
},
|
||||
{
|
||||
name: "qwen3 invalid tool call with partial tool prefix (multiple rune suffix match)",
|
||||
model: "qwen3",
|
||||
output: `<think></think><tool_c fakeout {"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </tool_call>`,
|
||||
expectedToolCall: []api.ToolCall{},
|
||||
expectedTokens: `<think></think><tool_c fakeout {"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </tool_call>`,
|
||||
},
|
||||
{
|
||||
name: "qwen3 invalid tool call with malformed tool prefix",
|
||||
model: "qwen3",
|
||||
output: `<think></think><tool_cfakeout {"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </tool_call>`,
|
||||
expectedToolCall: []api.ToolCall{},
|
||||
expectedTokens: `<think></think><tool_cfakeout {"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </tool_call>`,
|
||||
},
|
||||
{
|
||||
name: "model with prefix in template, no prefix in output",
|
||||
model: "qwen2.5",
|
||||
output: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
|
||||
expectedToolCall: []api.ToolCall{t1, t2},
|
||||
expectedTokens: "",
|
||||
},
|
||||
{
|
||||
name: "model with prefix in template, prefix in output",
|
||||
model: "qwen2.5",
|
||||
output: `<tool_call>[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]</tool_call>`,
|
||||
expectedToolCall: []api.ToolCall{t1, t2},
|
||||
expectedTokens: "",
|
||||
},
|
||||
{
|
||||
name: "model without prefix in template, no prefix in output",
|
||||
model: "llama3.2",
|
||||
output: `[{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "parameters": {"format":"celsius","location":"Toronto, Canada"}}]`,
|
||||
expectedToolCall: []api.ToolCall{t1, t2},
|
||||
expectedTokens: "",
|
||||
},
|
||||
{
|
||||
name: "model without prefix in template, no prefix in output, single tool call",
|
||||
model: "llama3.2",
|
||||
output: `{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}}`,
|
||||
expectedToolCall: []api.ToolCall{t1},
|
||||
expectedTokens: "",
|
||||
},
|
||||
{
|
||||
name: "model without prefix in template, prefix in output, multiple tool calls in list",
|
||||
model: "llama3.2",
|
||||
output: `<tool_call> [{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "parameters": {"format":"celsius","location":"Toronto, Canada"}}]</tool_call>`,
|
||||
expectedToolCall: []api.ToolCall{t1, t2},
|
||||
expectedTokens: `<tool_call>`,
|
||||
},
|
||||
{
|
||||
name: "model without prefix in template, prefix in output, individual tool calls",
|
||||
model: "llama3.2",
|
||||
output: `<tool_call> {"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "parameters": {"format":"celsius","location":"Toronto, Canada"}}`,
|
||||
expectedToolCall: []api.ToolCall{t1, t2},
|
||||
expectedTokens: `<tool_call>`,
|
||||
},
|
||||
{
|
||||
name: "model with prefix in template, no prefix in output, tokens before",
|
||||
model: "qwen2.5",
|
||||
output: `some tokens before [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
|
||||
expectedToolCall: []api.ToolCall{},
|
||||
expectedTokens: `some tokens before [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
|
||||
},
|
||||
{
|
||||
name: "model with prefix in template, prefix in output, tokens after",
|
||||
model: "qwen2.5",
|
||||
output: `<tool_call>[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]</tool_call> some tokens after`,
|
||||
expectedToolCall: []api.ToolCall{t1, t2},
|
||||
expectedTokens: "",
|
||||
},
|
||||
{
|
||||
name: "model without prefix in template, no prefix in output, tokens after",
|
||||
model: "llama3.2",
|
||||
output: `[{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "parameters": {"format":"celsius","location":"Toronto, Canada"}}]</tool_call> some tokens after`,
|
||||
expectedToolCall: []api.ToolCall{t1, t2},
|
||||
expectedTokens: "",
|
||||
},
|
||||
{
|
||||
name: "model without prefix in template, no prefix in output, tokens before",
|
||||
model: "llama3.2",
|
||||
output: `some tokens before [{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "parameters": {"format":"celsius","location":"Toronto, Canada"}}]`,
|
||||
expectedToolCall: []api.ToolCall{t1, t2},
|
||||
expectedTokens: `some tokens before`,
|
||||
},
|
||||
{
|
||||
name: "model without prefix in template, prefix in output, tokens after",
|
||||
model: "llama3.2",
|
||||
output: `<tool_call>
|
||||
[{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "parameters": {"format":"celsius","location":"Toronto, Canada"}}]</tool_call> some tokens after`,
|
||||
expectedToolCall: []api.ToolCall{t1, t2},
|
||||
expectedTokens: `<tool_call>`,
|
||||
},
|
||||
{
|
||||
name: "model without without prefix, match all jsons",
|
||||
model: "llama3.2",
|
||||
output: `model outputs some text [{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "parameters": {"format":"celsius","location":"Toronto, Canada"}}]</tool_call> some tokens after`,
|
||||
expectedToolCall: []api.ToolCall{t1, t2},
|
||||
expectedTokens: "model outputs some text",
|
||||
},
|
||||
{
|
||||
name: "model flushes tokens if tool call doesn't match",
|
||||
model: "llama3.2",
|
||||
output: `{ "user": {"id": 12345, "name": "Alice", "preferences": {"theme": "dark", "notifications": true}, "stats": {"points": 987, "level": 42}}}`,
|
||||
expectedToolCall: []api.ToolCall{},
|
||||
expectedTokens: `{ "user": {"id": 12345, "name": "Alice", "preferences": {"theme": "dark", "notifications": true}, "stats": {"points": 987, "level": 42}}}`,
|
||||
},
|
||||
{
|
||||
name: "model flushes tokens if tool call doesn't match array",
|
||||
model: "llama3.2",
|
||||
output: `[ { "user": {"id": 12345, "name": "Alice", "preferences": {"theme": "dark", "notifications": true}, "stats": {"points": 987, "level": 42}}}]`,
|
||||
expectedToolCall: []api.ToolCall{},
|
||||
expectedTokens: `[ { "user": {"id": 12345, "name": "Alice", "preferences": {"theme": "dark", "notifications": true}, "stats": {"points": 987, "level": 42}}}]`,
|
||||
},
|
||||
}
|
||||
|
||||
var tools []api.Tool
|
||||
if err := json.Unmarshal(readFile(t, p, "tools.json").Bytes(), &tools); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var messages []api.Message
|
||||
if err := json.Unmarshal(readFile(t, p, "messages.json").Bytes(), &messages); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
for _, tt := range cases {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tmpl, err := template.Parse(readFile(t, p, fmt.Sprintf("%s.gotmpl", tt.model)).String())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
t.Run("template", func(t *testing.T) {
|
||||
actual := &bytes.Buffer{} // Create new buffer for each test
|
||||
if err := tmpl.Execute(actual, template.Values{Tools: tools, Messages: messages}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(actual.String(), readFile(t, p, fmt.Sprintf("%s.out", tt.model)).String()); diff != "" {
|
||||
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("parse", func(t *testing.T) {
|
||||
tp, err := NewParser(tmpl.Template)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
got := []api.ToolCall{}
|
||||
var gotTokens strings.Builder
|
||||
|
||||
tokens := strings.Fields(tt.output)
|
||||
for _, tok := range tokens {
|
||||
s := " " + tok
|
||||
|
||||
toolCalls, content := tp.Add(s)
|
||||
if len(content) > 0 {
|
||||
gotTokens.WriteString(content)
|
||||
} else if len(toolCalls) > 0 {
|
||||
got = append(got, toolCalls...)
|
||||
}
|
||||
}
|
||||
|
||||
// Compare tool calls if we expect any
|
||||
if diff := cmp.Diff(got, tt.expectedToolCall); diff != "" {
|
||||
t.Errorf("tool calls mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
|
||||
// Compare tokens if we expect any
|
||||
stripped := strings.TrimSpace(gotTokens.String())
|
||||
if diff := cmp.Diff(stripped, tt.expectedTokens); diff != "" {
|
||||
t.Log("actualTokens", stripped, "expectedTokens", tt.expectedTokens)
|
||||
t.Errorf("tokens mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
227
tools/tools_utils.go
Normal file
227
tools/tools_utils.go
Normal file
@@ -0,0 +1,227 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"log/slog"
|
||||
"slices"
|
||||
"strings"
|
||||
gotmpl "text/template"
|
||||
"text/template/parse"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/template"
|
||||
)
|
||||
|
||||
// extractToolCallsFormat traverses a template AST to find text that follows a ".ToolCalls" condition.
|
||||
// It walks the template nodes looking for if-statements containing ".ToolCalls" and extracts any
|
||||
// immediate text nodes that follow. This is used to identify tool call prefixes and formatting.
|
||||
//
|
||||
// Returns:
|
||||
// - string: The extracted text following the first ".ToolCalls" condition found
|
||||
// - bool: Whether a ".ToolCalls" condition was found in the template
|
||||
func extractToolCallsFormat(tmpl *gotmpl.Template) (string, bool) {
|
||||
if tmpl == nil || tmpl.Tree == nil {
|
||||
slog.Debug("template or tree is nil")
|
||||
return "", false
|
||||
}
|
||||
|
||||
var result string
|
||||
var found bool
|
||||
|
||||
var walk func(nodes []parse.Node)
|
||||
walk = func(nodes []parse.Node) {
|
||||
for _, node := range nodes {
|
||||
if found {
|
||||
return
|
||||
}
|
||||
|
||||
switch n := node.(type) {
|
||||
case *parse.IfNode:
|
||||
if isToolCallsNode(n) {
|
||||
// Collect immediate TextNode(s) at start of IfNode's list
|
||||
var sb strings.Builder
|
||||
for _, innerNode := range n.List.Nodes {
|
||||
if tn, ok := innerNode.(*parse.TextNode); ok {
|
||||
sb.Write(tn.Text)
|
||||
} else {
|
||||
// Stop at first non-text node
|
||||
break
|
||||
}
|
||||
}
|
||||
result = sb.String()
|
||||
found = true
|
||||
return
|
||||
}
|
||||
// Recurse into child nodes
|
||||
walk(n.List.Nodes)
|
||||
if n.ElseList != nil {
|
||||
walk(n.ElseList.Nodes)
|
||||
}
|
||||
case *parse.ListNode:
|
||||
walk(n.Nodes)
|
||||
case *parse.RangeNode:
|
||||
walk(n.List.Nodes)
|
||||
if n.ElseList != nil {
|
||||
walk(n.ElseList.Nodes)
|
||||
}
|
||||
case *parse.WithNode:
|
||||
walk(n.List.Nodes)
|
||||
if n.ElseList != nil {
|
||||
walk(n.ElseList.Nodes)
|
||||
}
|
||||
default:
|
||||
// Continue to next node
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
walk(tmpl.Tree.Root.Nodes)
|
||||
return result, found
|
||||
}
|
||||
|
||||
// isToolCallsNode detects if a node's condition includes ".ToolCalls"
|
||||
func isToolCallsNode(n *parse.IfNode) bool {
|
||||
for _, cmd := range n.Pipe.Cmds {
|
||||
for _, arg := range cmd.Args {
|
||||
if field, ok := arg.(*parse.FieldNode); ok {
|
||||
if slices.Contains(field.Ident, "ToolCalls") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func toolPrefix(tmpl *gotmpl.Template) string {
|
||||
tokenText, ok := extractToolCallsFormat(tmpl)
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
tokenText = strings.TrimSpace(tokenText)
|
||||
tokenText = strings.ReplaceAll(tokenText, "\r", "")
|
||||
tokenText = strings.ReplaceAll(tokenText, "\n", " ")
|
||||
|
||||
return tokenText
|
||||
}
|
||||
|
||||
// toolTemplate creates a subtree from the node that ranges over .ToolCalls
|
||||
//
|
||||
// Returns:
|
||||
// - *gotmpl.Template: The subtree containing the .ToolCalls range
|
||||
// - error: Error if parsing failed
|
||||
func toolTemplate(t *template.Template) (*gotmpl.Template, error) {
|
||||
tmpl := t.Subtree(func(n parse.Node) bool {
|
||||
if t, ok := n.(*parse.RangeNode); ok {
|
||||
return slices.Contains(template.Identifiers(t.Pipe), "ToolCalls")
|
||||
}
|
||||
|
||||
return false
|
||||
})
|
||||
|
||||
if tmpl == nil {
|
||||
return nil, errors.New("failed to find tool template")
|
||||
}
|
||||
|
||||
return tmpl, nil
|
||||
}
|
||||
|
||||
// suffixOverlap returns the index in s where the longest suffix overlap with prefix begins
|
||||
//
|
||||
// Returns:
|
||||
// - int: The starting index in s where the suffix overlap begins
|
||||
func suffixOverlap(s, prefix string) int {
|
||||
max := min(len(prefix), len(s))
|
||||
for i := max; i > 0; i-- {
|
||||
if strings.HasSuffix(s, prefix[:i]) {
|
||||
return len(s) - i
|
||||
}
|
||||
}
|
||||
return -1
|
||||
}
|
||||
|
||||
// extractToolArgs executes a template with a known tool call format to extract the name and arguments
|
||||
//
|
||||
// Returns:
|
||||
// - string: The name of the tool call
|
||||
// - string: The arguments of the tool call
|
||||
// - error: Error if parsing failed
|
||||
func extractToolArgs(tmpl *gotmpl.Template) (name, arguments string, err error) {
|
||||
var b bytes.Buffer
|
||||
if err := tmpl.Execute(&b, map[string][]api.ToolCall{
|
||||
"ToolCalls": {
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "@@name@@",
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
"@@argument@@": 1,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}); err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
var obj any
|
||||
err = json.Unmarshal(b.Bytes(), &obj)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
var objs []map[string]any
|
||||
switch v := obj.(type) {
|
||||
case map[string]any:
|
||||
objs = []map[string]any{v}
|
||||
case []map[string]any:
|
||||
objs = v
|
||||
case []any:
|
||||
objs = collect(v)
|
||||
}
|
||||
if len(objs) == 0 {
|
||||
return "", "", errors.New("no template objects found")
|
||||
}
|
||||
|
||||
// find the keys that correspond to the name and arguments fields
|
||||
for k, v := range objs[0] {
|
||||
switch v.(type) {
|
||||
case string:
|
||||
name = k
|
||||
case map[string]any:
|
||||
arguments = k
|
||||
}
|
||||
}
|
||||
|
||||
if name == "" || arguments == "" {
|
||||
slog.Debug("missing required fields in tool call template", "name", name, "arguments", arguments)
|
||||
return "", "", errors.New("missing required fields in tool call template")
|
||||
}
|
||||
|
||||
return name, arguments, nil
|
||||
}
|
||||
|
||||
// collect recursively traverses an object to collect all nested maps
|
||||
//
|
||||
// Returns:
|
||||
// - []map[string]any: A slice of all nested maps found in the object
|
||||
func collect(obj any) []map[string]any {
|
||||
var all []map[string]any
|
||||
switch o := obj.(type) {
|
||||
case map[string]any:
|
||||
all = append(all, o)
|
||||
for _, v := range o {
|
||||
all = append(all, collect(v)...)
|
||||
}
|
||||
case []any:
|
||||
for _, v := range o {
|
||||
all = append(all, collect(v)...)
|
||||
}
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
|
||||
return all
|
||||
}
|
||||
464
tools/tools_utils_test.go
Normal file
464
tools/tools_utils_test.go
Normal file
@@ -0,0 +1,464 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"testing"
|
||||
gotmpl "text/template"
|
||||
|
||||
"github.com/ollama/ollama/template"
|
||||
)
|
||||
|
||||
func TestExtractToolCallsFormat(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
template string
|
||||
want string
|
||||
found bool
|
||||
}{
|
||||
{
|
||||
name: "nil template",
|
||||
template: "",
|
||||
want: "",
|
||||
found: false,
|
||||
},
|
||||
{
|
||||
name: "basic tool call with text",
|
||||
template: "{{if .ToolCalls}}Hello world{{end}}",
|
||||
want: "Hello world",
|
||||
found: true,
|
||||
},
|
||||
{
|
||||
name: "tool call with json format",
|
||||
template: "{{if .ToolCalls}}```json\n{{end}}",
|
||||
want: "```json\n",
|
||||
found: true,
|
||||
},
|
||||
{
|
||||
name: "tool call in range",
|
||||
template: "{{range .ToolCalls}}tool: {{.}}{{end}}",
|
||||
want: "",
|
||||
found: false,
|
||||
},
|
||||
{
|
||||
name: "tool call with multiple text nodes",
|
||||
template: "{{if .ToolCalls}}First text{{if .Something}}inner{{end}}Second text{{end}}",
|
||||
want: "First text",
|
||||
found: true,
|
||||
},
|
||||
{
|
||||
name: "nested if without tool calls",
|
||||
template: "{{if .Something}}{{if .OtherThing}}text{{end}}{{end}}",
|
||||
want: "",
|
||||
found: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
tmpl, err := gotmpl.New("test").Parse(tc.template)
|
||||
if err != nil && tc.template != "" {
|
||||
t.Fatalf("failed to parse template: %v", err)
|
||||
}
|
||||
|
||||
got, found := extractToolCallsFormat(tmpl)
|
||||
if got != tc.want {
|
||||
t.Errorf("got text %q, want %q", got, tc.want)
|
||||
}
|
||||
if found != tc.found {
|
||||
t.Errorf("got found %v, want %v", found, tc.found)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolPrefix(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
template string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "basic tool call with action prefix",
|
||||
template: "{{if .ToolCalls}}Action: ```json{{end}}",
|
||||
want: "Action: ```json",
|
||||
},
|
||||
{
|
||||
name: "incomplete functools bracket",
|
||||
template: "{{if .ToolCalls}}functools[{{end}}",
|
||||
want: "functools[",
|
||||
},
|
||||
{
|
||||
name: "tool call with angle brackets",
|
||||
template: "{{if .ToolCalls}}Hello, world! <tool_call>{{end}}",
|
||||
want: "Hello, world! <tool_call>",
|
||||
},
|
||||
{
|
||||
name: "multiple tool call formats",
|
||||
template: "{{if .ToolCalls}}[tool_call] <tool_call>{{end}}",
|
||||
want: "[tool_call] <tool_call>",
|
||||
},
|
||||
{
|
||||
name: "single angle bracket tool call",
|
||||
template: "{{if .ToolCalls}}<tool_call>{{end}}",
|
||||
want: "<tool_call>",
|
||||
},
|
||||
{
|
||||
name: "incomplete angle bracket after tool call",
|
||||
template: "{{if .ToolCalls}}[tool_call] <{{end}}",
|
||||
want: "[tool_call] <",
|
||||
},
|
||||
{
|
||||
name: "angle bracket prefix with tool call",
|
||||
template: "{{if .ToolCalls}}> <tool_call>{{end}}",
|
||||
want: "> <tool_call>",
|
||||
},
|
||||
{
|
||||
name: "uppercase tool call with incomplete bracket",
|
||||
template: "{{if .ToolCalls}}[TOOL_CALL] [{{end}}",
|
||||
want: "[TOOL_CALL] [",
|
||||
},
|
||||
{
|
||||
name: "uppercase tool call with adjacent bracket",
|
||||
template: "{{if .ToolCalls}}[TOOL_CALL][{{end}}",
|
||||
want: "[TOOL_CALL][",
|
||||
},
|
||||
{
|
||||
name: "tool call with pipe delimiters",
|
||||
template: "{{if .ToolCalls}}<|tool_call|>{{end}}",
|
||||
want: "<|tool_call|>",
|
||||
},
|
||||
{
|
||||
name: "tool with no prefix",
|
||||
template: "{{if .ToolCalls}}{{end}}",
|
||||
want: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range cases {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tmpl, err := gotmpl.New("test").Parse(tt.template)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to parse template: %v", err)
|
||||
}
|
||||
got := toolPrefix(tmpl)
|
||||
if got != tt.want {
|
||||
t.Errorf("ToolToken(%q) = %q; want %q", tt.template, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolTemplate(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
template string
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "basic tool call range",
|
||||
template: "{{range .ToolCalls}}test{{end}}",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "no tool calls",
|
||||
template: "{{range .Other}}test{{end}}",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "nested tool calls",
|
||||
template: "{{range .Outer}}{{range .ToolCalls}}test{{end}}{{end}}",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "empty template",
|
||||
template: "",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "tool calls in if statement",
|
||||
template: "{{if .ToolCalls}}test{{end}}",
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range cases {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tmpl, err := gotmpl.New("test").Parse(tt.template)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to parse template: %v", err)
|
||||
}
|
||||
|
||||
parsed, err := template.Parse(tmpl.Root.String())
|
||||
if err != nil {
|
||||
t.Fatalf("failed to parse template: %v", err)
|
||||
}
|
||||
|
||||
_, err = toolTemplate(parsed)
|
||||
if err != nil && tt.want {
|
||||
t.Errorf("toolTemplate() = %v; want %v", err, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSuffixOverlap(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
s string
|
||||
d string
|
||||
want int
|
||||
}{
|
||||
{
|
||||
name: "no overlap",
|
||||
s: "hello world",
|
||||
d: "<tool_call>",
|
||||
want: -1,
|
||||
},
|
||||
{
|
||||
name: "full overlap",
|
||||
s: "<tool_call>",
|
||||
d: "<tool_call>",
|
||||
want: 0,
|
||||
},
|
||||
{
|
||||
name: "partial overlap",
|
||||
s: "text <tool_call>",
|
||||
d: "<tool_call>",
|
||||
want: 5,
|
||||
},
|
||||
{
|
||||
name: "delimiter longer than string",
|
||||
s: "<tool>",
|
||||
d: "<tool_call>",
|
||||
want: -1,
|
||||
},
|
||||
{
|
||||
name: "empty string",
|
||||
s: "",
|
||||
d: "<tool_call>",
|
||||
want: -1,
|
||||
},
|
||||
{
|
||||
name: "empty delimiter",
|
||||
s: "<tool_call>",
|
||||
d: "",
|
||||
want: -1,
|
||||
},
|
||||
{
|
||||
name: "single char overlap",
|
||||
s: "test<",
|
||||
d: "<tool_call>",
|
||||
want: 4,
|
||||
},
|
||||
{
|
||||
name: "partial tool call",
|
||||
s: "hello <tool_",
|
||||
d: "<tool_call>",
|
||||
want: 6,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range cases {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := suffixOverlap(tt.s, tt.d)
|
||||
if got != tt.want {
|
||||
t.Errorf("suffixOverlap(%q, %q) = %d; want %d", tt.s, tt.d, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractToolArgs(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
template string
|
||||
want string
|
||||
ok bool
|
||||
}{
|
||||
{
|
||||
name: "basic tool call with text after",
|
||||
template: `{{if .ToolCalls}}tool response{{end}}`,
|
||||
want: "tool response",
|
||||
ok: true,
|
||||
},
|
||||
{
|
||||
name: "tool call with mixed content after",
|
||||
template: `{{if .ToolCalls}}<tool_call>{{.Something}}{{end}}`,
|
||||
want: "<tool_call>",
|
||||
ok: true,
|
||||
},
|
||||
{
|
||||
name: "tool call with no text after",
|
||||
template: `{{if .ToolCalls}}{{.Something}}{{end}}`,
|
||||
want: "",
|
||||
ok: true,
|
||||
},
|
||||
{
|
||||
name: "nested tool call",
|
||||
template: `{{if .Something}}{{if .ToolCalls}}[TOOL_CALL]{{end}}{{end}}`,
|
||||
want: "[TOOL_CALL]",
|
||||
ok: true,
|
||||
},
|
||||
{
|
||||
name: "no tool calls",
|
||||
template: `{{if .Something}}no tools here{{end}}`,
|
||||
want: "",
|
||||
ok: false,
|
||||
},
|
||||
{
|
||||
name: "empty template",
|
||||
template: ``,
|
||||
want: "",
|
||||
ok: false,
|
||||
},
|
||||
{
|
||||
name: "multiple tool calls sections",
|
||||
template: `{{if .ToolCalls}}first{{end}}{{if .ToolCalls}}second{{end}}`,
|
||||
want: "first",
|
||||
ok: true,
|
||||
},
|
||||
{
|
||||
name: "range over tool calls",
|
||||
template: `{{if .ToolCalls}}{{range .ToolCalls}}tool{{end}}{{end}}`,
|
||||
want: "",
|
||||
ok: true,
|
||||
},
|
||||
{
|
||||
name: "tool calls with pipe delimiters",
|
||||
template: `{{if .ToolCalls}}<|tool|>{{end}}`,
|
||||
want: "<|tool|>",
|
||||
ok: true,
|
||||
},
|
||||
{
|
||||
name: "tool calls with nested template",
|
||||
template: `{{if .ToolCalls}}{{template "tool" .}}{{end}}`,
|
||||
want: "",
|
||||
ok: true,
|
||||
},
|
||||
{
|
||||
name: "tool calls with whitespace variations",
|
||||
template: `{{if .ToolCalls}} tool {{end}}`,
|
||||
want: " tool ",
|
||||
ok: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range cases {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tmpl, err := gotmpl.New("test").Parse(tt.template)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to parse template: %v", err)
|
||||
}
|
||||
|
||||
got, ok := extractToolCallsFormat(tmpl)
|
||||
if got != tt.want {
|
||||
t.Errorf("TextAfterToolCalls() got = %q, want %q", got, tt.want)
|
||||
}
|
||||
if ok != tt.ok {
|
||||
t.Errorf("TextAfterToolCalls() ok = %v, want %v", ok, tt.ok)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCollect(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
obj any
|
||||
want []map[string]any
|
||||
}{
|
||||
{
|
||||
name: "simple map",
|
||||
obj: map[string]any{
|
||||
"key": "value",
|
||||
},
|
||||
want: []map[string]any{
|
||||
{"key": "value"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "nested map",
|
||||
obj: map[string]any{
|
||||
"outer": map[string]any{
|
||||
"inner": "value",
|
||||
},
|
||||
},
|
||||
want: []map[string]any{
|
||||
{"outer": map[string]any{"inner": "value"}},
|
||||
{"inner": "value"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "array of maps",
|
||||
obj: []any{
|
||||
map[string]any{"key1": "val1"},
|
||||
map[string]any{"key2": "val2"},
|
||||
},
|
||||
want: []map[string]any{
|
||||
{"key1": "val1"},
|
||||
{"key2": "val2"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "deeply nested",
|
||||
obj: map[string]any{
|
||||
"l1": map[string]any{
|
||||
"l2": map[string]any{
|
||||
"l3": "value",
|
||||
},
|
||||
},
|
||||
},
|
||||
want: []map[string]any{
|
||||
{"l1": map[string]any{"l2": map[string]any{"l3": "value"}}},
|
||||
{"l2": map[string]any{"l3": "value"}},
|
||||
{"l3": "value"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "non-map value",
|
||||
obj: "string",
|
||||
want: nil,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range cases {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := collect(tt.obj)
|
||||
if len(got) != len(tt.want) {
|
||||
t.Errorf("collect() got %d maps, want %d", len(got), len(tt.want))
|
||||
return
|
||||
}
|
||||
|
||||
// Compare each map in the result
|
||||
for i := range tt.want {
|
||||
if !mapsEqual(got[i], tt.want[i]) {
|
||||
t.Errorf("collect() map[%d] = %v, want %v", i, got[i], tt.want[i])
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// mapsEqual compares two maps for deep equality
|
||||
func mapsEqual(m1, m2 map[string]any) bool {
|
||||
if len(m1) != len(m2) {
|
||||
return false
|
||||
}
|
||||
for k, v1 := range m1 {
|
||||
v2, ok := m2[k]
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
switch val1 := v1.(type) {
|
||||
case map[string]any:
|
||||
val2, ok := v2.(map[string]any)
|
||||
if !ok || !mapsEqual(val1, val2) {
|
||||
return false
|
||||
}
|
||||
default:
|
||||
if v1 != v2 {
|
||||
return false
|
||||
}
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
@@ -8,6 +8,7 @@ const (
|
||||
CapabilityInsert = Capability("insert")
|
||||
CapabilityVision = Capability("vision")
|
||||
CapabilityEmbedding = Capability("embedding")
|
||||
CapabilityThinking = Capability("thinking")
|
||||
)
|
||||
|
||||
func (c Capability) String() string {
|
||||
|
||||
Reference in New Issue
Block a user