diff --git a/api/types.go b/api/types.go index a7ddbc373..df3504c3b 100644 --- a/api/types.go +++ b/api/types.go @@ -313,10 +313,11 @@ func (t *ToolFunction) String() string { // ChatResponse is the response returned by [Client.Chat]. Its fields are // similar to [GenerateResponse]. type ChatResponse struct { - Model string `json:"model"` - CreatedAt time.Time `json:"created_at"` - Message Message `json:"message"` - DoneReason string `json:"done_reason,omitempty"` + Model string `json:"model"` + CreatedAt time.Time `json:"created_at"` + Message Message `json:"message"` + DoneReason string `json:"done_reason,omitempty"` + DebugInfo *DebugInfo `json:"_debug_info,omitempty"` Done bool `json:"done"` @@ -329,13 +330,6 @@ type DebugInfo struct { ImageCount int `json:"image_count,omitempty"` } -// DebugTemplateResponse is returned when _debug_render_only is set to true -type DebugTemplateResponse struct { - Model string `json:"model"` - CreatedAt time.Time `json:"created_at"` - DebugInfo DebugInfo `json:"_debug_info"` -} - type Metrics struct { TotalDuration time.Duration `json:"total_duration,omitempty"` LoadDuration time.Duration `json:"load_duration,omitempty"` @@ -443,6 +437,8 @@ type CreateRequest struct { System string `json:"system,omitempty"` Parameters map[string]any `json:"parameters,omitempty"` Messages []Message `json:"messages,omitempty"` + Renderer string `json:"renderer,omitempty"` + Parser string `json:"parser,omitempty"` // Deprecated: set the model name with Model instead Name string `json:"name"` @@ -480,6 +476,8 @@ type ShowResponse struct { Parameters string `json:"parameters,omitempty"` Template string `json:"template,omitempty"` System string `json:"system,omitempty"` + Renderer string `json:"renderer,omitempty"` + Parser string `json:"parser,omitempty"` Details ModelDetails `json:"details,omitempty"` Messages []Message `json:"messages,omitempty"` ModelInfo map[string]any `json:"model_info,omitempty"` @@ -592,6 +590,8 @@ type GenerateResponse struct { Metrics ToolCalls []ToolCall `json:"tool_calls,omitempty"` + + DebugInfo *DebugInfo `json:"_debug_info,omitempty"` } // ModelDetails provides details about a model. 
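The `DebugInfo` pointer added to `ChatResponse` (and, further down, to `GenerateResponse`) is tagged `omitempty`, so it only appears on the wire when a handler populates it; that makes the dedicated `DebugTemplateResponse` type unnecessary. A minimal sketch of that behavior — not part of the patch, and assuming only the `ImageCount` field visible in this hunk:

```go
package main

import (
	"encoding/json"
	"fmt"
	"time"

	"github.com/ollama/ollama/api"
)

func main() {
	resp := api.ChatResponse{Model: "m", CreatedAt: time.Now(), Done: true}
	b, _ := json.Marshal(resp)
	fmt.Println(string(b)) // no "_debug_info" key: the nil pointer is omitted

	// Populated only when debugging is requested; DebugInfo may have other
	// fields outside this hunk.
	resp.DebugInfo = &api.DebugInfo{ImageCount: 2}
	b, _ = json.Marshal(resp)
	fmt.Println(string(b)) // ..., "_debug_info":{"image_count":2}
}
```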
diff --git a/convert/convert_bert.go b/convert/convert_bert.go index a9f4b8a77..6b0d0030a 100644 --- a/convert/convert_bert.go +++ b/convert/convert_bert.go @@ -28,6 +28,7 @@ type bertModel struct { LayerNormEPS float32 `json:"layer_norm_eps"` LayerNormEpsilon float32 `json:"layer_norm_epsilon"` NormEpsilon float32 `json:"norm_epsilon"` + normalizeEmbeddings bool PoolingType uint32 } @@ -54,9 +55,11 @@ func (p *bertModel) parseMore(fsys fs.FS) error { var pooling string for _, m := range modules { - if m.Type == "sentence_transformers.models.Pooling" { + switch m.Type { + case "sentence_transformers.models.Pooling": pooling = m.Path - break + case "sentence_transformers.models.Normalize": + p.normalizeEmbeddings = true } } @@ -90,6 +93,7 @@ func (p *bertModel) KV(t *Tokenizer) ggml.KV { kv["general.architecture"] = "bert" kv["bert.attention.causal"] = false kv["bert.pooling_type"] = p.PoolingType + kv["bert.normalize_embeddings"] = p.normalizeEmbeddings kv["bert.block_count"] = cmp.Or(p.NLayers, p.NumHiddenLayers, p.NLayer) diff --git a/discover/cuda_common.go b/discover/cuda_common.go index ca008af63..3c7a92114 100644 --- a/discover/cuda_common.go +++ b/discover/cuda_common.go @@ -45,10 +45,18 @@ func cudaVariant(gpuInfo CudaGPUInfo) string { } } + // Check GPU compute capability FIRST + isOldGPU := gpuInfo.computeMajor < 7 || (gpuInfo.computeMajor == 7 && gpuInfo.computeMinor < 5) + if isOldGPU { + // GPU is pre-Turing (CC < 7.5, e.g. Pascal or Volta) - use CUDA v12 (supports CC 6.1) + return "v12" + } + + // GPU is Turing or newer (CC >= 7.5) - can use newer CUDA if gpuInfo.DriverMajor < 13 { // The detected driver is older than 580 (Aug 2025) // Warn if their CC is compatible with v13 and they should upgrade their driver to get better performance - if gpuInfo.computeMajor > 7 || (gpuInfo.computeMajor == 7 && gpuInfo.computeMinor >= 5) { + if !isOldGPU { slog.Warn("old CUDA driver detected - please upgrade to a newer driver for best performance", "version", fmt.Sprintf("%d.%d", gpuInfo.DriverMajor, gpuInfo.DriverMinor)) } return "v12" diff --git a/docs/development.md b/docs/development.md index 9726b5d91..ff07b5fb6 100644 --- a/docs/development.md +++ b/docs/development.md @@ -11,6 +11,10 @@ Then build and run Ollama from the root directory of the repository: go run . serve ``` +> [!NOTE] +> Ollama includes native code compiled with cgo. From time to time the data structures shared between Go and the native code change, and previously cached build artifacts can get out of sync, resulting in unexpected crashes. You can force a full rebuild of the native code by running `go clean -cache` first. + + ## macOS (Apple Silicon) macOS Apple Silicon supports Metal which is built-in to the Ollama binary. No additional steps are required.
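For illustration, the gating rule that `cudaVariant` now applies can be read as a small pure function. This is a sketch, not the patch's code: the hunk ends before the function's tail, so the final `v13` return is an assumption.

```go
package main

import "fmt"

// variantFor mirrors the logic above: GPUs below compute capability 7.5 are
// pinned to the v12 build regardless of driver version, and newer GPUs on
// pre-CUDA-13 drivers also fall back to v12 (with a warning in the real code).
func variantFor(ccMajor, ccMinor, driverMajor int) string {
	isOldGPU := ccMajor < 7 || (ccMajor == 7 && ccMinor < 5)
	if isOldGPU {
		return "v12"
	}
	if driverMajor < 13 {
		return "v12" // the real code logs an upgrade suggestion here
	}
	return "v13" // assumed fallthrough; not shown in the hunk
}

func main() {
	fmt.Println(variantFor(6, 1, 13)) // Pascal: always v12
	fmt.Println(variantFor(7, 5, 12)) // Turing on an old driver: v12
	fmt.Println(variantFor(8, 9, 13)) // newer GPU and driver: v13
}
```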
diff --git a/harmony/harmonyparser.go b/harmony/harmonyparser.go index 3ec2c21f1..a51819dda 100644 --- a/harmony/harmonyparser.go +++ b/harmony/harmonyparser.go @@ -3,29 +3,15 @@ package harmony import ( "fmt" "log/slog" - "slices" "strings" "unicode" "github.com/ollama/ollama/api" "github.com/ollama/ollama/logutil" - "github.com/ollama/ollama/template" ) type harmonyParserState int -func ShouldUseHarmony(modelFamily string, template *template.Template) bool { - if slices.Contains([]string{"gptoss", "gpt-oss"}, modelFamily) { - // heuristic to check whether the template expects to be parsed via harmony: - // search for harmony tags that are nearly always used - if template.Contains("<|start|>") && template.Contains("<|end|>") { - return true - } - } - - return false -} - const ( harmonyParserState_LookingForMessageStart harmonyParserState = iota harmonyParserState_ParsingHeader @@ -89,28 +75,18 @@ func (s *HarmonyParser) AddImplicitStart() { s.acc.WriteString("<|start|>assistant") } -func Prefill(lastMessage api.Message) string { - if lastMessage.Role != "assistant" { - return "" - } - - switch { - case strings.TrimSpace(lastMessage.Content) != "": - return "<|start|>assistant<|channel|>final<|message|>" - case strings.TrimSpace(lastMessage.Thinking) != "": - return "<|start|>assistant<|channel|>analysis<|message|>" - default: - return "" - } -} - -// AddImplicitStartOrPrefill adds an implicit start tag or prefill string if provided -func (s *HarmonyParser) AddImplicitStartOrPrefill(prefillString string) { - if strings.TrimSpace(prefillString) != "" { - s.acc.WriteString(prefillString) - } else { - s.AddImplicitStart() +func (s *HarmonyParser) AddImplicitStartOrPrefill(lastMessage *api.Message) { + if lastMessage != nil && lastMessage.Role == "assistant" { + // handle prefilling conditions + if lastMessage.Content != "" { + s.acc.WriteString("<|start|>assistant<|channel|>final<|message|>") + return + } else if lastMessage.Thinking != "" { + s.acc.WriteString("<|start|>assistant<|channel|>analysis<|message|>") + return + } } + s.AddImplicitStart() } func (s *HarmonyParser) AddContent(content string) []HarmonyEvent { @@ -289,7 +265,6 @@ type HarmonyMessageHandler struct { state harmonyMessageState HarmonyParser *HarmonyParser FunctionNameMap *FunctionNameMap - ToolParser *HarmonyToolCallAccumulator } // NewHarmonyMessageHandler creates a new message handler @@ -302,16 +277,12 @@ func NewHarmonyMessageHandler() *HarmonyMessageHandler { HeaderEndTag: "<|message|>", }, FunctionNameMap: NewFunctionNameMap(), - ToolParser: &HarmonyToolCallAccumulator{ - state: harmonyToolCallState_Normal, - currentToolName: nil, - }, } } // AddContent processes the content and returns the content, thinking, and tool content. 
// content and thinking are already fully parsed, but tool content still needs to be passed to the tool parser -func (h *HarmonyMessageHandler) AddContent(content string) (string, string, string) { +func (h *HarmonyMessageHandler) AddContent(content string, toolParser *HarmonyToolCallAccumulator) (string, string, string) { contentSb := strings.Builder{} thinkingSb := strings.Builder{} toolContentSb := strings.Builder{} @@ -328,14 +299,14 @@ func (h *HarmonyMessageHandler) AddContent(content string) (string, string, stri // event.Header.Recipient is the tool name, something like // "browser.search" for a built-in, or "functions.calc" for a // custom one - h.ToolParser.SetToolName(event.Header.Recipient) + toolParser.SetToolName(event.Header.Recipient) } else { h.state = harmonyMessageState_Thinking } case "commentary": if event.Header.Recipient != "" { h.state = harmonyMessageState_ToolCalling - h.ToolParser.SetToolName(event.Header.Recipient) + toolParser.SetToolName(event.Header.Recipient) } else { h.state = harmonyMessageState_Normal } @@ -358,6 +329,13 @@ func (h *HarmonyMessageHandler) AddContent(content string) (string, string, stri return contentSb.String(), thinkingSb.String(), toolContentSb.String() } +func (h *HarmonyMessageHandler) CreateToolParser() *HarmonyToolCallAccumulator { + return &HarmonyToolCallAccumulator{ + state: harmonyToolCallState_Normal, + currentToolName: nil, + } +} + type harmonyToolCallState int const ( diff --git a/harmony/harmonyparser_test.go b/harmony/harmonyparser_test.go index 82bf5b2de..b988a018f 100644 --- a/harmony/harmonyparser_test.go +++ b/harmony/harmonyparser_test.go @@ -3,7 +3,6 @@ package harmony import ( "fmt" "reflect" - "strings" "testing" ) @@ -536,202 +535,3 @@ func TestFunctionConvertAndAdd(t *testing.T) { }) } } - -func TestHarmonyMessageHandlerStreamingScenarios(t *testing.T) { - t.Run("thinking_then_content_streams", func(t *testing.T) { - handler := NewHarmonyMessageHandler() - handler.HarmonyParser.AddImplicitStart() - tp := handler.ToolParser - type step struct { - in string - wantContent string - wantThinking string - } - steps := []step{ - {in: "<|channel|>analysis<|message|>Thinking...", wantThinking: "Thinking..."}, - {in: "<|end|>", wantThinking: ""}, - {in: "<|start|>assistant<|message|>Answer", wantContent: "Answer"}, - {in: "<|end|>", wantContent: ""}, - } - for i, s := range steps { - content, thinking, tool := handler.AddContent(s.in) - if tool != "" { - tp.Add(tool) - } - if content != s.wantContent || thinking != s.wantThinking { - t.Fatalf("step %d: got (content=%q thinking=%q), want (content=%q thinking=%q)", i, content, thinking, s.wantContent, s.wantThinking) - } - } - }) - - t.Run("content_streams_as_it_arrives", func(t *testing.T) { - handler := NewHarmonyMessageHandler() - handler.HarmonyParser.AddImplicitStart() - tp := handler.ToolParser - inputs := []string{ - "<|start|>assistant<|message|>Hello", - ", world", - "!<|end|>", - } - var got []string - for _, in := range inputs { - content, thinking, tool := handler.AddContent(in) - if tool != "" { - tp.Add(tool) - } - if thinking != "" { - t.Fatalf("unexpected thinking %q", thinking) - } - if content != "" { - got = append(got, content) - } - } - want := []string{"Hello", ", world", "!"} - if !reflect.DeepEqual(got, want) { - t.Fatalf("content pieces mismatch: got %v want %v", got, want) - } - }) - - t.Run("thinking_streams_separately_from_content", func(t *testing.T) { - handler := NewHarmonyMessageHandler() - handler.HarmonyParser.AddImplicitStart() - tp := 
handler.ToolParser - inputs := []string{ - "<|channel|>analysis<|message|>Thinking...", - "<|end|>", - "<|start|>assistant<|message|>Answer", - "<|end|>", - } - var got []string - for _, in := range inputs { - content, thinking, tool := handler.AddContent(in) - if tool != "" { - tp.Add(tool) - } - if thinking != "" { - got = append(got, thinking) - } - if content != "" { - got = append(got, content) - } - } - want := []string{"Thinking...", "Answer"} - if !reflect.DeepEqual(got, want) { - t.Fatalf("content pieces mismatch: got %v want %v", got, want) - } - }) - - t.Run("partial_tags_buffer_until_complete", func(t *testing.T) { - handler := NewHarmonyMessageHandler() - handler.HarmonyParser.AddImplicitStart() - tp := handler.ToolParser - inputs := []string{ - "<|chan", - "nel|>analysis<|mess", - "age|>Deep ", - "thought", - "<|end|>", - "<|start|>assistant<|message|>Done", - "<|end|>", - } - var thinkingPieces []string - var contentPieces []string - for _, in := range inputs { - content, thinking, tool := handler.AddContent(in) - if tool != "" { - tp.Add(tool) - } - if thinking != "" { - thinkingPieces = append(thinkingPieces, thinking) - } - if content != "" { - contentPieces = append(contentPieces, content) - } - } - if want := []string{"Deep ", "thought"}; !reflect.DeepEqual(thinkingPieces, want) { - t.Fatalf("thinking pieces mismatch: got %v want %v", thinkingPieces, want) - } - if want := []string{"Done"}; !reflect.DeepEqual(contentPieces, want) { - t.Fatalf("content pieces mismatch: got %v want %v", contentPieces, want) - } - }) - - t.Run("simple_assistant_after_analysis", func(t *testing.T) { - handler := NewHarmonyMessageHandler() - handler.HarmonyParser.AddImplicitStart() - tp := handler.ToolParser - inputs := []string{ - "<|channel|>analysis<|message|>Think", - "<|end|>", - "<|start|>assistant<|message|>Answer", - "<|end|>", - } - var contentSb, thinkingSb strings.Builder - for _, in := range inputs { - content, thinking, tool := handler.AddContent(in) - if tool != "" { - tp.Add(tool) - } - contentSb.WriteString(content) - thinkingSb.WriteString(thinking) - } - if contentSb.String() != "Answer" { - t.Fatalf("content mismatch: got %q want %q", contentSb.String(), "Answer") - } - if thinkingSb.String() != "Think" { - t.Fatalf("thinking mismatch: got %q want %q", thinkingSb.String(), "Think") - } - }) - - t.Run("tool_call_parsed_and_returned_correctly", func(t *testing.T) { - handler := NewHarmonyMessageHandler() - handler.HarmonyParser.AddImplicitStart() - tp := handler.ToolParser - inputs := []string{ - "<|channel|>commentary to=functions.calculate<|message|>{\"expression\":\"2+2\"}<|end|>", - } - for _, in := range inputs { - content, thinking, tool := handler.AddContent(in) - if content != "" || thinking != "" { - continue - } - if tool != "" { - tp.Add(tool) - } - } - name, args := tp.Drain() - if name == nil || *name != "functions.calculate" { - t.Fatalf("unexpected tool name: %v", name) - } - if got, want := args, "{\"expression\":\"2+2\"}"; got != want { - t.Fatalf("unexpected tool args: got %s want %s", got, want) - } - }) - - t.Run("tool_call_across_chunks", func(t *testing.T) { - handler := NewHarmonyMessageHandler() - handler.HarmonyParser.AddImplicitStart() - tp := handler.ToolParser - inputs := []string{ - "<|channel|>commentary to=functions.calculate<|message|>{\"expression\":\"2+", - "2\"}", - "<|end|>", - } - for _, in := range inputs { - content, thinking, tool := handler.AddContent(in) - if content != "" || thinking != "" { - continue - } - if tool != "" { - 
tp.Add(tool) - } - } - name, args := tp.Drain() - if name == nil || *name != "functions.calculate" { - t.Fatalf("unexpected tool name: %v", name) - } - if got, want := args, "{\"expression\":\"2+2\"}"; got != want { - t.Fatalf("unexpected tool args: got %s want %s", got, want) - } - }) -} diff --git a/integration/context_test.go b/integration/context_test.go index ca6f16087..15c157858 100644 --- a/integration/context_test.go +++ b/integration/context_test.go @@ -50,7 +50,7 @@ func TestContextExhaustion(t *testing.T) { // Set up the test data req := api.GenerateRequest{ Model: smol, - Prompt: "Write me a story with a ton of emojis?", + Prompt: "Write me a story in english with a lot of emojis", Stream: &stream, Options: map[string]any{ "temperature": 0, diff --git a/integration/utils_test.go b/integration/utils_test.go index ec74b2e3d..7901fed3f 100644 --- a/integration/utils_test.go +++ b/integration/utils_test.go @@ -561,7 +561,7 @@ func GenerateRequests() ([]api.GenerateRequest, [][]string) { KeepAlive: &api.Duration{Duration: 10 * time.Second}, }, { Model: smol, - Prompt: "what is the origin of the US thanksgiving holiday? Be brief but factual in your reply", + Prompt: "how do rainbows form? Be brief but factual in your reply", Stream: &stream, KeepAlive: &api.Duration{Duration: 10 * time.Second}, }, { @@ -579,9 +579,9 @@ func GenerateRequests() ([]api.GenerateRequest, [][]string) { [][]string{ {"sunlight", "scattering", "interact", "color", "surface", "depth", "red", "orange", "yellow", "absorbs", "wavelength"}, {"soil", "organic", "earth", "black", "tan", "chemical", "processes", "pigments", "particles", "iron oxide", "rust", "air", "water", "mixture", "mixing"}, - {"england", "english", "massachusetts", "pilgrims", "colonists", "independence", "british", "feast", "family", "gatherings", "traditions", "turkey", "colonial", "period", "harvest", "agricultural", "european settlers", "american revolution", "civil war", "16th century", "17th century", "native american", "united states", "cultural", "hardship", "autumn", "festival"}, + {"water", "droplet", "refracted", "reflect", "color", "spectrum"}, {"fourth", "july", "declaration", "independence"}, - {"nitrogen", "oxygen", "carbon", "dioxide"}, + {"nitrogen", "oxygen", "carbon", "dioxide", "water", "vapor"}, } } diff --git a/llama/llama.go b/llama/llama.go index ac2c112c2..88672a033 100644 --- a/llama/llama.go +++ b/llama/llama.go @@ -515,33 +515,34 @@ func (c *MtmdContext) NewEmbed(llamaContext *Context, data []byte) ([][]float32, } nChunks := C.mtmd_input_chunks_size(ic) numEmbed := llamaContext.Model().NEmbd() - lastChunkSize := 0 + embed := make([][]float32, 0) for i := range int(nChunks) { chunk := C.mtmd_input_chunks_get(ic, C.size_t(i)) numTokens := int(C.mtmd_input_chunk_get_n_tokens(chunk)) - lastChunkSize = numTokens + slog.Debug("chunk tokens", "index", i, "numTokens", numTokens) // Encode the chunk if C.int32_t(0) != C.mtmd_encode_chunk(c.c, chunk) { return nil, errors.New("unable to encode mtmd image chunk") } - } - // Get the embeddings - embed := make([][]float32, lastChunkSize) - embd := C.mtmd_get_output_embd(c.c) - if nil == embd { - return nil, errors.New("failed to get image embedding") - } + // Get the embeddings for this chunk + chunkEmbed := make([][]float32, numTokens) + chunkEmbd := C.mtmd_get_output_embd(c.c) + if nil == chunkEmbd { + continue + } - // Extend the embedding array for each token - s := unsafe.Slice((*float32)(embd), numEmbed*lastChunkSize) - rows := make([]float32, len(s)) - copy(rows, s) - for i 
:= range lastChunkSize { - embed[i] = rows[i*numEmbed : (i+1)*numEmbed] + // Extend the embedding array for each token + s := unsafe.Slice((*float32)(chunkEmbd), numTokens*numEmbed) + rows := make([]float32, len(s)) + copy(rows, s) + for i := range numTokens { + chunkEmbed[i] = rows[i*numEmbed : (i+1)*numEmbed] + } + embed = append(embed, chunkEmbed...) } - + slog.Debug("image embeddings", "totalEmbeddings", len(embed)) return embed, nil } diff --git a/llm/server.go b/llm/server.go index 2af82fa04..7f7d68cd0 100644 --- a/llm/server.go +++ b/llm/server.go @@ -35,7 +35,6 @@ import ( "github.com/ollama/ollama/logutil" "github.com/ollama/ollama/ml" "github.com/ollama/ollama/model" - "github.com/ollama/ollama/parser" ) type filteredEnv []string @@ -1350,9 +1349,7 @@ type CompletionRequest struct { Images []ImageData Options *api.Options - Grammar string // set before sending the request to the subprocess - ParserType parser.TokenParserType - PrefillString string + Grammar string // set before sending the request to the subprocess } // DoneReason represents the reason why a completion response is done @@ -1379,15 +1376,13 @@ func (d DoneReason) String() string { } type CompletionResponse struct { - Content string `json:"content"` - Thinking string `json:"thinking"` - ToolCalls []api.ToolCall `json:"tool_calls"` - DoneReason DoneReason `json:"done_reason"` - Done bool `json:"done"` - PromptEvalCount int `json:"prompt_eval_count"` - PromptEvalDuration time.Duration `json:"prompt_eval_duration"` - EvalCount int `json:"eval_count"` - EvalDuration time.Duration `json:"eval_duration"` + Content string `json:"content"` + DoneReason DoneReason `json:"done_reason"` + Done bool `json:"done"` + PromptEvalCount int `json:"prompt_eval_count"` + PromptEvalDuration time.Duration `json:"prompt_eval_duration"` + EvalCount int `json:"eval_count"` + EvalDuration time.Duration `json:"eval_duration"` } func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error { @@ -1505,8 +1500,7 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu return fmt.Errorf("error unmarshalling llm prediction response: %v", err) } switch { - // TODO(parthsareen): token repeat limit is now handled in the runner, this currently support legacy model and can be removed in the future - case strings.TrimSpace(c.Content) == lastToken && c.Content != "": + case strings.TrimSpace(c.Content) == lastToken: tokenRepeat++ default: lastToken = strings.TrimSpace(c.Content) @@ -1519,14 +1513,16 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu return ctx.Err() } + if c.Content != "" { + fn(CompletionResponse{ + Content: c.Content, + }) + } + if c.Done { fn(c) return nil } - - if c.Content != "" || c.Thinking != "" || len(c.ToolCalls) > 0 { - fn(c) - } } } diff --git a/ml/backend.go b/ml/backend.go index 154a0f1b5..455715b0d 100644 --- a/ml/backend.go +++ b/ml/backend.go @@ -416,6 +416,7 @@ type Tensor interface { AddID(ctx Context, t2, ids Tensor) Tensor Softmax(ctx Context) Tensor + L2Norm(ctx Context, eps float32) Tensor LayerNorm(ctx Context, weight, bias Tensor, eps float32) Tensor RMSNorm(ctx Context, weight Tensor, eps float32) Tensor Scale(ctx Context, s float64) Tensor @@ -429,12 +430,13 @@ type Tensor interface { Sin(ctx Context) Tensor Cos(ctx Context) Tensor Tanh(ctx Context) Tensor - GELU(ctx Context) Tensor - QuickGELU(ctx Context) Tensor - SILU(ctx Context) Tensor - RELU(ctx Context) Tensor + GELU(ctx Context, up ...Tensor) 
Tensor + SILU(ctx Context, up ...Tensor) Tensor + RELU(ctx Context, up ...Tensor) Tensor Sigmoid(ctx Context) Tensor - SwiGLU(ctx Context, up Tensor, alpha, limit float32) Tensor + + // SILUAlphaLimit is a variant of SILU that clamps the input to the range [-limit, limit] + SILUAlphaLimit(ctx Context, up Tensor, alpha, limit float32) Tensor Reshape(ctx Context, shape ...int) Tensor View(ctx Context, offset int, shape ...int) Tensor diff --git a/ml/backend/ggml/ggml.go b/ml/backend/ggml/ggml.go index 931386d56..49dc3e1ab 100644 --- a/ml/backend/ggml/ggml.go +++ b/ml/backend/ggml/ggml.go @@ -1205,6 +1205,13 @@ func (t *Tensor) AddID(ctx ml.Context, t2, ids ml.Tensor) ml.Tensor { } } +func (t *Tensor) L2Norm(ctx ml.Context, eps float32) ml.Tensor { + return &Tensor{ + b: t.b, + t: C.ggml_l2_norm(ctx.(*Context).ctx, t.t, C.float(eps)), + } +} + func (t *Tensor) LayerNorm(ctx ml.Context, w, b ml.Tensor, eps float32) ml.Tensor { tt := C.ggml_norm(ctx.(*Context).ctx, t.t, C.float(eps)) if w != nil { @@ -1424,35 +1431,46 @@ func (t *Tensor) IM2Col(ctx ml.Context, t2 ml.Tensor, s0, s1, p0, p1, d0, d1 int } } -func (t *Tensor) GELU(ctx ml.Context) ml.Tensor { +func (t *Tensor) GELU(ctx ml.Context, t2 ...ml.Tensor) ml.Tensor { + if len(t2) > 0 { + return &Tensor{ + b: t.b, + t: C.ggml_geglu_split(ctx.(*Context).ctx, t.t, t2[0].(*Tensor).t), + } + } return &Tensor{ b: t.b, t: C.ggml_gelu_inplace(ctx.(*Context).ctx, t.t), } } -func (t *Tensor) QuickGELU(ctx ml.Context) ml.Tensor { - return &Tensor{ - b: t.b, - t: C.ggml_gelu_quick_inplace(ctx.(*Context).ctx, t.t), +func (t *Tensor) SILU(ctx ml.Context, t2 ...ml.Tensor) ml.Tensor { + if len(t2) > 0 { + return &Tensor{ + b: t.b, + t: C.ggml_swiglu_split(ctx.(*Context).ctx, t.t, t2[0].(*Tensor).t), + } } -} - -func (t *Tensor) SILU(ctx ml.Context) ml.Tensor { return &Tensor{ b: t.b, t: C.ggml_silu_inplace(ctx.(*Context).ctx, t.t), } } -func (t *Tensor) RELU(ctx ml.Context) ml.Tensor { +func (t *Tensor) RELU(ctx ml.Context, t2 ...ml.Tensor) ml.Tensor { + if len(t2) > 0 { + return &Tensor{ + b: t.b, + t: C.ggml_reglu_split(ctx.(*Context).ctx, t.t, t2[0].(*Tensor).t), + } + } return &Tensor{ b: t.b, t: C.ggml_relu_inplace(ctx.(*Context).ctx, t.t), } } -func (t *Tensor) SwiGLU(ctx ml.Context, up ml.Tensor, alpha, limit float32) ml.Tensor { +func (t *Tensor) SILUAlphaLimit(ctx ml.Context, up ml.Tensor, alpha, limit float32) ml.Tensor { return &Tensor{ b: t.b, t: C.ggml_swiglu_oai(ctx.(*Context).ctx, t.t, up.(*Tensor).t, C.float(alpha), C.float(limit)), diff --git a/ml/nn/attention.go b/ml/nn/attention.go index 21b4a28ae..94dbde0b0 100644 --- a/ml/nn/attention.go +++ b/ml/nn/attention.go @@ -26,6 +26,7 @@ func Attention(ctx ml.Context, query, key, value ml.Tensor, scale float64, cache } func AttentionWithSinks(ctx ml.Context, query, key, value, sinks ml.Tensor, scale float64, cache kvcache.Cache) ml.Tensor { + ctx.Forward(query) if key != nil && value != nil { if query.Dim(0) != key.Dim(0) { panic(fmt.Errorf("d_k in attention operation does not match between query(%v) and key(%v)", query.Dim(0), key.Dim(0))) } @@ -39,6 +40,7 @@ func AttentionWithSinks(ctx ml.Context, query, key, value, sinks ml.Tensor, scal panic(fmt.Errorf("seq_len_k in attention operation does not match between key(%v) and value(%v)", key.Dim(2), value.Dim(2))) } + ctx.Forward(key, value) if cache != nil { cache.Put(ctx, key, value) } diff --git a/ml/nn/pooling/pooling.go b/ml/nn/pooling/pooling.go new file mode 100644 index 000000000..63b63b3af --- /dev/null +++ b/ml/nn/pooling/pooling.go @@ 
-0,0 +1,42 @@ +package pooling + +import ( + "github.com/ollama/ollama/ml" +) + +type Type uint32 + +const ( + TypeNone Type = iota + TypeMean + TypeCLS + TypeLast +) + +func (t Type) String() string { + switch t { + case TypeMean: + return "Mean" + case TypeCLS: + return "CLS" + case TypeLast: + return "Last" + default: + return "Unknown" + } +} + +func (t Type) Forward(ctx ml.Context, hiddenStates ml.Tensor) ml.Tensor { + switch t { + case TypeMean: + hiddenStates = hiddenStates.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx).Mean(ctx) + return hiddenStates.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx) + case TypeCLS: + return hiddenStates.View(ctx, 0, hiddenStates.Dim(0)) + case TypeLast: + hiddenStates = hiddenStates.View(ctx, (hiddenStates.Dim(1)-1)*hiddenStates.Stride(1), hiddenStates.Dim(0)) + return hiddenStates + default: + panic("unknown pooling type") + } +} diff --git a/ml/nn/pooling/pooling_test.go b/ml/nn/pooling/pooling_test.go new file mode 100644 index 000000000..c80019459 --- /dev/null +++ b/ml/nn/pooling/pooling_test.go @@ -0,0 +1,79 @@ +package pooling_test + +import ( + "bytes" + "os" + "slices" + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/ollama/ollama/discover" + fsggml "github.com/ollama/ollama/fs/ggml" + "github.com/ollama/ollama/ml" + "github.com/ollama/ollama/ml/backend/ggml" + "github.com/ollama/ollama/ml/nn/pooling" +) + +func setup(tb testing.TB, n int) ml.Backend { + tb.Helper() + + f, err := os.CreateTemp(tb.TempDir(), "*.bin") + if err != nil { + tb.Fatal(err) + } + defer f.Close() + + if err := fsggml.WriteGGUF(f, fsggml.KV{ + "general.architecture": "test", + "test.block_count": uint32(1), + }, []*fsggml.Tensor{ + {Name: "blk.0.weight", Shape: []uint64{1}, WriterTo: bytes.NewBuffer(make([]byte, 4))}, + }); err != nil { + tb.Fatal(err) + } + + var gpuLayers ml.GPULayersList + if gpus := discover.GetGPUInfo(); len(gpus) > 0 { + gpuLayers = append(gpuLayers, ml.GPULayers{ + ID: gpus[0].ID, + Layers: slices.Collect(func(yield func(int) bool) { + for i := range n { + if !yield(i) { + return + } + } + }), + }) + } + b, err := ggml.New(f.Name(), ml.BackendParams{AllocMemory: true, GPULayers: gpuLayers}) + if err != nil { + tb.Fatal(err) + } + + return b +} + +func TestForward(t *testing.T) { + cases := map[pooling.Type][]float32{ + pooling.TypeMean: {4, 5, 6, 7, 8, 9, 10, 11}, + pooling.TypeCLS: {0, 1, 2, 3, 4, 5, 6, 7}, + pooling.TypeLast: {8, 9, 10, 11, 12, 13, 14, 15}, + } + for typ, want := range cases { + t.Run(typ.String(), func(t *testing.T) { + b := setup(t, 99) + defer b.Close() + + ctx := b.NewContext() + defer ctx.Close() + + tt := ctx.Input().Arange(0, 16, 1, ml.DTypeF32).Reshape(ctx, 8, 2) + tt = typ.Forward(ctx, tt) + + ctx.Forward(tt).Compute(tt) + if diff := cmp.Diff(want, tt.Floats()); diff != "" { + t.Error(diff) + } + }) + } +} diff --git a/model/input/input.go b/model/input/input.go index bd9b53ec6..35dc41b35 100644 --- a/model/input/input.go +++ b/model/input/input.go @@ -54,10 +54,9 @@ type Batch struct { // Inputs is the input tokens, including placeholders for multimodal inputs. Inputs ml.Tensor - // Multimodal is a set of multimodal embeddings previously created by - // EncodeMultimodal, along with an index into Inputs. Unused for text-only - // models or for batches without multimodal elements. - Multimodal []MultimodalIndex + // Outputs are the set of indices into Inputs for which output data should + // be returned. + Outputs ml.Tensor // Positions is the position for each Input, relative to its sequence.
Equal // in length to Inputs. @@ -66,7 +65,8 @@ type Batch struct { // Sequences is the sequence for each Input. Equal in length to Inputs. Sequences []int - // Outputs are the set of indicies into Inputs for which output data should - // be returned. - Outputs []int32 + // Multimodal is a set of multimodal embeddings previously created by + // EncodeMultimodal, along with an index into Inputs. Unused for text-only + // models or for batches without multimodal elements. + Multimodal []MultimodalIndex } diff --git a/model/model.go b/model/model.go index 3a72f09aa..5493a4e63 100644 --- a/model/model.go +++ b/model/model.go @@ -5,7 +5,6 @@ import ( "fmt" _ "image/jpeg" _ "image/png" - "math" "os" "reflect" "strconv" @@ -21,10 +20,15 @@ import ( "github.com/ollama/ollama/logutil" "github.com/ollama/ollama/ml" _ "github.com/ollama/ollama/ml/backend" + "github.com/ollama/ollama/ml/nn/pooling" "github.com/ollama/ollama/model/input" ) -var ErrNoVisionModel = errors.New("this model is missing data required for image input") +var ( + ErrNoVisionModel = errors.New("this model is missing data required for image input") + ErrUnsupportedModel = errors.New("model not supported") + ErrUnsupportedTokenizer = errors.New("tokenizer not supported") +) // Model implements a specific model architecture, defining the forward pass and any model-specific configuration type Model interface { @@ -104,7 +108,7 @@ func New(modelPath string, params ml.BackendParams) (Model, error) { } arch := b.Config().Architecture() - if b.Config().Uint("pooling_type", math.MaxUint32) != math.MaxUint32 { + if pooling.Type(b.Config().Uint("pooling_type")) != pooling.TypeNone { arch = arch + "_embed" } @@ -242,7 +246,7 @@ func setPointer(base Base, v reflect.Value, tags []Tag) { vv = vv.Elem() } - vv = vv.Elem() + vv = reflect.Indirect(vv) if v.IsNil() { vv = reflect.New(v.Type().Elem()).Elem() } diff --git a/model/models/bert/embed.go b/model/models/bert/embed.go new file mode 100644 index 000000000..166c11e13 --- /dev/null +++ b/model/models/bert/embed.go @@ -0,0 +1,181 @@ +package bert + +import ( + "cmp" + "math" + + "github.com/ollama/ollama/fs" + "github.com/ollama/ollama/ml" + "github.com/ollama/ollama/ml/nn" + "github.com/ollama/ollama/ml/nn/pooling" + "github.com/ollama/ollama/model" + "github.com/ollama/ollama/model/input" +) + +type Model struct { + model.Base + model.TextProcessor + + TokenEmbedding *nn.Embedding `gguf:"token_embd"` + TypeEmbedding *nn.Embedding `gguf:"token_types"` + PositionEmbedding *nn.Embedding `gguf:"position_embd"` + TokenEmbeddingNorm *nn.LayerNorm `gguf:"token_embd_norm"` + + Layers []EncoderLayer `gguf:"blk"` + + Options +} + +// Forward implements model.Model. 
+func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { + hiddenStates := m.TokenEmbedding.Forward(ctx, batch.Inputs) + hiddenStates = hiddenStates.Add(ctx, m.TypeEmbedding.Weight.View(ctx, 0, m.hiddenSize)) + hiddenStates = hiddenStates.Add(ctx, m.PositionEmbedding.Forward(ctx, ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions)))) + hiddenStates = m.TokenEmbeddingNorm.Forward(ctx, hiddenStates, m.eps) + + for _, layer := range m.Layers { + hiddenStates = layer.Forward(ctx, hiddenStates, &m.Options) + } + + hiddenStates = m.poolingType.Forward(ctx, hiddenStates) + if m.normalize { + hiddenStates = hiddenStates.L2Norm(ctx, 1e-12) + } + + return hiddenStates, nil +} + +type EncoderLayer struct { + *Attention + AttentionNorm *nn.LayerNorm `gguf:"attn_output_norm"` + + *MLP + MLPNorm *nn.LayerNorm `gguf:"layer_output_norm"` +} + +func (e *EncoderLayer) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options) ml.Tensor { + // Attention + residual := hiddenStates + hiddenStates = e.Attention.Forward(ctx, hiddenStates, opts) + hiddenStates = hiddenStates.Add(ctx, residual) + hiddenStates = e.AttentionNorm.Forward(ctx, hiddenStates, opts.eps) + + // MLP + residual = hiddenStates + hiddenStates = e.MLP.Forward(ctx, hiddenStates, opts) + hiddenStates = hiddenStates.Add(ctx, residual) + hiddenStates = e.MLPNorm.Forward(ctx, hiddenStates, opts.eps) + + return hiddenStates +} + +type Attention struct { + Query *nn.Linear `gguf:"attn_q"` + QueryNorm *nn.LayerNorm `gguf:"attn_q_norm"` + + Key *nn.Linear `gguf:"attn_k"` + KeyNorm *nn.LayerNorm `gguf:"attn_k_norm"` + + Value *nn.Linear `gguf:"attn_v"` + + Output *nn.Linear `gguf:"attn_output"` +} + +func (a *Attention) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options) ml.Tensor { + batchSize := hiddenStates.Dim(1) + + query := a.Query.Forward(ctx, hiddenStates) + if a.QueryNorm != nil { + query = a.QueryNorm.Forward(ctx, query, opts.eps) + } + query = query.Reshape(ctx, opts.headDim(), opts.numHeads, batchSize) + + key := a.Key.Forward(ctx, hiddenStates) + if a.KeyNorm != nil { + key = a.KeyNorm.Forward(ctx, key, opts.eps) + } + key = key.Reshape(ctx, opts.headDim(), cmp.Or(opts.numKVHeads, opts.numHeads), batchSize) + + value := a.Value.Forward(ctx, hiddenStates) + value = value.Reshape(ctx, opts.headDim(), cmp.Or(opts.numKVHeads, opts.numHeads), batchSize) + + attention := nn.Attention(ctx, query, key, value, 1/math.Sqrt(float64(opts.headDim())), nil) + attention = attention.Reshape(ctx, opts.hiddenSize, batchSize) + return a.Output.Forward(ctx, attention) +} + +type MLP struct { + Up *nn.Linear `gguf:"ffn_up"` + Down *nn.Linear `gguf:"ffn_down"` +} + +func (m *MLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options) ml.Tensor { + return m.Down.Forward(ctx, m.Up.Forward(ctx, hiddenStates).GELU(ctx)) +} + +type Options struct { + hiddenSize, + numHeads, + numKVHeads, + keyLength, + valueLength int + poolingType pooling.Type + eps float32 + normalize bool +} + +func (o Options) headDim() int { + return cmp.Or(o.keyLength, o.valueLength, o.hiddenSize/o.numHeads) +} + +func New(c fs.Config) (model.Model, error) { + var processor model.TextProcessor + switch c.String("tokenizer.ggml.model", "bert") { + case "bert": + processor = model.NewWordPiece( + &model.Vocabulary{ + Values: c.Strings("tokenizer.ggml.tokens"), + Scores: c.Floats("tokenizer.ggml.scores"), + Types: c.Ints("tokenizer.ggml.token_type"), + AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true), + BOS: []int32{ + int32(cmp.Or( + 
c.Uint("tokenizer.ggml.cls_token_id"), + c.Uint("tokenizer.ggml.bos_token_id"), + )), + }, + AddEOS: c.Bool("tokenizer.ggml.add_eos_token", true), + EOS: []int32{ + int32(cmp.Or( + c.Uint("tokenizer.ggml.separator_token_id"), + //nolint:misspell + // NOTE: "seperator_token_id" is a typo in model metadata but we need to + // support it for compatibility. + c.Uint("tokenizer.ggml.seperator_token_id"), + c.Uint("tokenizer.ggml.eos_token_id"), + )), + }, + }, + ) + default: + return nil, model.ErrUnsupportedTokenizer + } + + return &Model{ + TextProcessor: processor, + Layers: make([]EncoderLayer, c.Uint("block_count")), + Options: Options{ + hiddenSize: int(c.Uint("embedding_length")), + numHeads: int(c.Uint("attention.head_count")), + numKVHeads: int(c.Uint("attention.head_count_kv")), + eps: c.Float("attention.layer_norm_epsilon"), + poolingType: pooling.Type(c.Uint("pooling_type")), + normalize: c.Bool("normalize_embeddings", true), + }, + }, nil +} + +func init() { + model.Register("bert", New) + model.Register("bert_embed", New) +} diff --git a/model/models/gemma2/model.go b/model/models/gemma2/model.go index e621d03ae..96ace7c74 100644 --- a/model/models/gemma2/model.go +++ b/model/models/gemma2/model.go @@ -24,7 +24,7 @@ type Options struct { type Model struct { model.Base - model.SentencePieceModel + model.SentencePiece TokenEmbedding *nn.Embedding `gguf:"token_embd"` Layers []Layer `gguf:"blk"` @@ -40,7 +40,7 @@ const ( func New(c fs.Config) (model.Model, error) { m := Model{ - SentencePieceModel: model.NewSentencePieceModel( + SentencePiece: model.NewSentencePiece( &model.Vocabulary{ Values: c.Strings("tokenizer.ggml.tokens"), Scores: c.Floats("tokenizer.ggml.scores"), @@ -138,7 +138,7 @@ type MLP struct { } func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *Options) ml.Tensor { - hiddenState = mlp.Gate.Forward(ctx, hiddenState).GELU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState)) + hiddenState = mlp.Gate.Forward(ctx, hiddenState).GELU(ctx, mlp.Up.Forward(ctx, hiddenState)) return mlp.Down.Forward(ctx, hiddenState) } @@ -176,7 +176,6 @@ func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Ten func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions)) - outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs)) hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs) hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.Options.hiddenSize))) @@ -193,7 +192,7 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { var lastLayerOutputs ml.Tensor if i == len(m.Layers)-1 { - lastLayerOutputs = outputs + lastLayerOutputs = batch.Outputs } hiddenState = layer.Forward(ctx, hiddenState, positions, lastLayerOutputs, m.Cache, m.Options) diff --git a/model/models/gemma3/embed.go b/model/models/gemma3/embed.go index 16c299e22..525547767 100644 --- a/model/models/gemma3/embed.go +++ b/model/models/gemma3/embed.go @@ -1,49 +1,38 @@ package gemma3 import ( - "errors" - "github.com/ollama/ollama/fs" "github.com/ollama/ollama/kvcache" "github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml/nn" + "github.com/ollama/ollama/ml/nn/pooling" "github.com/ollama/ollama/model" "github.com/ollama/ollama/model/input" ) type embedModel struct { model.Base - model.SentencePieceModel + model.SentencePiece *TextModel - PoolingType uint32 + poolingType pooling.Type Dense [2]*nn.Linear `gguf:"dense"` } func (m *embedModel) 
Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { - batch.Outputs = batch.Positions // return all positions hiddenStates := m.TextModel.Forward(ctx, batch, m.Cache) - - switch m.PoolingType { - case 0: // None - case 1: // Mean - hiddenStates = hiddenStates.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx).Mean(ctx) - hiddenStates = hiddenStates.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx) - default: - return nil, errors.New("unsupported pooling type") - } - + hiddenStates = m.poolingType.Forward(ctx, hiddenStates) for _, dense := range m.Dense { hiddenStates = dense.Forward(ctx, hiddenStates) } - + hiddenStates = hiddenStates.L2Norm(ctx, 1e-12) return hiddenStates, nil } func newEmbedModel(c fs.Config) (model.Model, error) { m := &embedModel{ - SentencePieceModel: model.NewSentencePieceModel( + SentencePiece: model.NewSentencePiece( &model.Vocabulary{ Values: c.Strings("tokenizer.ggml.tokens"), Scores: c.Floats("tokenizer.ggml.scores"), @@ -61,7 +50,7 @@ func newEmbedModel(c fs.Config) (model.Model, error) { }, ), TextModel: newTextModel(c), - PoolingType: c.Uint("pooling_type", 0), + poolingType: pooling.Type(c.Uint("pooling_type", 0)), } m.Cache = kvcache.NewWrapperCache( diff --git a/model/models/gemma3/model.go b/model/models/gemma3/model.go index 5c92b6bf9..27da889e4 100644 --- a/model/models/gemma3/model.go +++ b/model/models/gemma3/model.go @@ -16,7 +16,7 @@ import ( type Model struct { model.Base - model.SentencePieceModel + model.SentencePiece *VisionModel `gguf:"v"` *TextModel @@ -55,7 +55,7 @@ func (p *MultiModalProjector) Forward(ctx ml.Context, visionOutputs ml.Tensor, i func New(c fs.Config) (model.Model, error) { m := Model{ - SentencePieceModel: model.NewSentencePieceModel( + SentencePiece: model.NewSentencePiece( &model.Vocabulary{ Values: c.Strings("tokenizer.ggml.tokens"), Scores: c.Floats("tokenizer.ggml.scores"), diff --git a/model/models/gemma3/model_text.go b/model/models/gemma3/model_text.go index 2a3b23939..d38746dc8 100644 --- a/model/models/gemma3/model_text.go +++ b/model/models/gemma3/model_text.go @@ -123,7 +123,7 @@ type TextMLP struct { } func (mlp *TextMLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextConfig) ml.Tensor { - hiddenState = mlp.Gate.Forward(ctx, hiddenState).GELU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState)) + hiddenState = mlp.Gate.Forward(ctx, hiddenState).GELU(ctx, mlp.Up.Forward(ctx, hiddenState)) return mlp.Down.Forward(ctx, hiddenState) } @@ -161,7 +161,6 @@ func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, positionIDs, func (m *TextModel) Forward(ctx ml.Context, batch input.Batch, cache kvcache.Cache) ml.Tensor { positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions)) - outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs)) hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs) hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.TextConfig.hiddenSize))) @@ -194,7 +193,7 @@ func (m *TextModel) Forward(ctx ml.Context, batch input.Batch, cache kvcache.Cac var lastLayerOutputs ml.Tensor if i == len(m.Layers)-1 { - lastLayerOutputs = outputs + lastLayerOutputs = batch.Outputs } hiddenState = layer.Forward(ctx, i, hiddenState, positions, lastLayerOutputs, cache, m.TextConfig) diff --git a/model/models/gemma3n/model.go b/model/models/gemma3n/model.go index 6e83a9724..e59e3193f 100644 --- a/model/models/gemma3n/model.go +++ b/model/models/gemma3n/model.go @@ -10,7 +10,7 @@ import ( type Model struct { model.Base - model.SentencePieceModel + model.SentencePiece 
*TextModel } @@ -23,7 +23,7 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { func New(c fs.Config) (model.Model, error) { m := Model{ TextModel: newTextModel(c), - SentencePieceModel: model.NewSentencePieceModel( + SentencePiece: model.NewSentencePiece( &model.Vocabulary{ Values: c.Strings("tokenizer.ggml.tokens"), Scores: c.Floats("tokenizer.ggml.scores"), diff --git a/model/models/gemma3n/model_text.go b/model/models/gemma3n/model_text.go index b75a2abb3..2682a45f7 100644 --- a/model/models/gemma3n/model_text.go +++ b/model/models/gemma3n/model_text.go @@ -83,7 +83,7 @@ func (m *TextModel) Forward(ctx ml.Context, batch input.Batch, cache kvcache.Cac hiddenStates = hiddenStates.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx).Mean(ctx) hiddenStates = hiddenStates.Permute(ctx, 2, 0, 1, 3).Contiguous(ctx) - hiddenStates = hiddenStates.Rows(ctx, ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))) + hiddenStates = hiddenStates.Rows(ctx, batch.Outputs) hiddenStates = m.OutputNorm.Forward(ctx, hiddenStates, m.eps) return m.Output.Forward(ctx, hiddenStates), nil @@ -170,8 +170,7 @@ func (d TextLayer) Forward(ctx ml.Context, hiddenStates, perLayerInput, position } active = d.PerLayerInputGate.Forward(ctx, active) - active = active.GELU(ctx) - active = active.Mul(ctx, perLayerInput) + active = active.GELU(ctx, perLayerInput) active = d.PerLayerProjection.Forward(ctx, active) active = d.PostPerLayerNorm.Forward(ctx, active, opts.eps) @@ -292,7 +291,7 @@ func (mlp TextMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, activationSpa hiddenStates = hiddenStates.Sub(ctx, cutoff).RELU(ctx) } - hiddenStates = hiddenStates.GELU(ctx).Mul(ctx, upStates) + hiddenStates = hiddenStates.GELU(ctx, upStates) hiddenStates = mlp.Down.Forward(ctx, hiddenStates) return hiddenStates } diff --git a/model/models/gptoss/model.go b/model/models/gptoss/model.go index 3ef078095..8456ea5f7 100644 --- a/model/models/gptoss/model.go +++ b/model/models/gptoss/model.go @@ -41,8 +41,8 @@ func (m *Transformer) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, err } var outputs ml.Tensor - if len(batch.Outputs) > 0 && i == len(m.TransformerBlocks)-1 { - outputs = ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs)) + if i == len(m.TransformerBlocks)-1 { + outputs = batch.Outputs } hiddenStates = block.Forward(ctx, hiddenStates, positions, outputs, one, m.Cache, &m.Options) @@ -210,7 +210,7 @@ func (mlp *MLPBlock) Forward(ctx ml.Context, hiddenStates, one ml.Tensor, opts * up = mlp.Up.Forward(ctx, hiddenStates, selectedExperts) } - hiddenStates = gate.SwiGLU(ctx, up, 1.702, 7) + hiddenStates = gate.SILUAlphaLimit(ctx, up, 1.702, 7) experts := mlp.Down.Forward(ctx, hiddenStates, selectedExperts) experts = experts.Mul(ctx, routingWeights) diff --git a/model/models/llama/model.go b/model/models/llama/model.go index 77d8f36d3..572c687a9 100644 --- a/model/models/llama/model.go +++ b/model/models/llama/model.go @@ -118,7 +118,7 @@ type MLP struct { } func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *Options) ml.Tensor { - hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState)) + hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx, mlp.Up.Forward(ctx, hiddenState)) return mlp.Down.Forward(ctx, hiddenState) } @@ -160,7 +160,7 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { var outputs ml.Tensor if i == len(m.Layers)-1 { - outputs = ctx.Input().FromIntSlice(batch.Outputs, 
len(batch.Outputs)) + outputs = batch.Outputs } hiddenState = layer.Forward(ctx, hiddenState, positions, outputs, m.Cache, m.Options) diff --git a/model/models/llama4/model.go b/model/models/llama4/model.go index 99a898d2d..9cb2efc87 100644 --- a/model/models/llama4/model.go +++ b/model/models/llama4/model.go @@ -176,9 +176,7 @@ func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) { func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions)) - outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs)) - - return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, batch, m.Cache), nil + return m.TextModel.Forward(ctx, batch.Inputs, positions, batch.Outputs, batch, m.Cache), nil } func init() { diff --git a/model/models/llama4/model_text.go b/model/models/llama4/model_text.go index 045ab403f..dbe6bba7a 100644 --- a/model/models/llama4/model_text.go +++ b/model/models/llama4/model_text.go @@ -58,14 +58,14 @@ type TextMLP struct { } func (mlp *TextMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *TextOptions) ml.Tensor { - hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenStates)) + hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx, mlp.Up.Forward(ctx, hiddenStates)) return mlp.Down.Forward(ctx, hiddenStates) } type TextExperts struct { - Gate *nn.Linear `gguf:"ffn_gate_exps"` - Up *nn.Linear `gguf:"ffn_up_exps"` - Down *nn.Linear `gguf:"ffn_down_exps"` + Gate *nn.LinearBatch `gguf:"ffn_gate_exps"` + Up *nn.LinearBatch `gguf:"ffn_up_exps"` + Down *nn.LinearBatch `gguf:"ffn_down_exps"` } func (e *TextExperts) Forward(ctx ml.Context, hiddenStates, routerLogits ml.Tensor, opts *TextOptions) ml.Tensor { @@ -76,9 +76,9 @@ func (e *TextExperts) Forward(ctx ml.Context, hiddenStates, routerLogits ml.Tens hiddenStates = hiddenStates.Repeat(ctx, 1, opts.numExpertsUsed) hiddenStates = hiddenStates.Mul(ctx, scores) - upStates := e.Up.Weight.MulmatID(ctx, hiddenStates, experts) - gateStates := e.Gate.Weight.MulmatID(ctx, hiddenStates, experts) - downStates := e.Down.Weight.MulmatID(ctx, upStates.Mul(ctx, gateStates.SILU(ctx)), experts) + upStates := e.Up.Forward(ctx, hiddenStates, experts) + gateStates := e.Gate.Forward(ctx, hiddenStates, experts) + downStates := e.Down.Forward(ctx, upStates.Mul(ctx, gateStates.SILU(ctx)), experts) nextStates := downStates.View(ctx, 0, hiddenStates.Dim(0), downStates.Stride(2), hiddenStates.Dim(2)) for i := 1; i < opts.numExpertsUsed; i++ { @@ -96,7 +96,7 @@ type TextSharedExpert struct { } func (mlp *TextSharedExpert) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *TextOptions) ml.Tensor { - hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenStates)) + hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx, mlp.Up.Forward(ctx, hiddenStates)) return mlp.Down.Forward(ctx, hiddenStates) } diff --git a/model/models/mistral3/model.go b/model/models/mistral3/model.go index 408e54d3d..435b1a304 100644 --- a/model/models/mistral3/model.go +++ b/model/models/mistral3/model.go @@ -159,9 +159,8 @@ func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) { func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions)) - outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs)) - return 
m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, batch, m.Cache), nil + return m.TextModel.Forward(ctx, batch.Inputs, positions, batch.Outputs, batch, m.Cache), nil } func init() { diff --git a/model/models/mistral3/model_text.go b/model/models/mistral3/model_text.go index 19c36f9fe..132d1756c 100644 --- a/model/models/mistral3/model_text.go +++ b/model/models/mistral3/model_text.go @@ -65,7 +65,7 @@ type MLP struct { } func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextOptions) ml.Tensor { - hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState)) + hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx, mlp.Up.Forward(ctx, hiddenState)) return mlp.Down.Forward(ctx, hiddenState) } diff --git a/model/models/mistral3/model_vision.go b/model/models/mistral3/model_vision.go index 65bdcff2a..3bfb8c90a 100644 --- a/model/models/mistral3/model_vision.go +++ b/model/models/mistral3/model_vision.go @@ -51,7 +51,7 @@ type VisionMLP struct { } func (mlp *VisionMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *VisionModelOptions) ml.Tensor { - hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenStates)) + hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx, mlp.Up.Forward(ctx, hiddenStates)) return mlp.Down.Forward(ctx, hiddenStates) } diff --git a/model/models/mllama/model.go b/model/models/mllama/model.go index d0ad4670e..239d999d5 100644 --- a/model/models/mllama/model.go +++ b/model/models/mllama/model.go @@ -107,10 +107,9 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { } positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions)) - outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs)) // TODO: attention mask, cross attention mask - return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, crossAttentionStates, nil, m.Cache.(*kvcache.WrapperCache)), nil + return m.TextModel.Forward(ctx, batch.Inputs, positions, batch.Outputs, crossAttentionStates, nil, m.Cache.(*kvcache.WrapperCache)), nil } func init() { diff --git a/model/models/mllama/model_text.go b/model/models/mllama/model_text.go index 47a518ced..cb18f0878 100644 --- a/model/models/mllama/model_text.go +++ b/model/models/mllama/model_text.go @@ -58,7 +58,7 @@ type TextMLP struct { } func (mlp *TextMLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextModelOptions) ml.Tensor { - hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState)) + hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx, mlp.Up.Forward(ctx, hiddenState)) return mlp.Down.Forward(ctx, hiddenState) } diff --git a/model/models/models.go b/model/models/models.go index c880a4720..cc9980789 100644 --- a/model/models/models.go +++ b/model/models/models.go @@ -1,6 +1,7 @@ package models import ( + _ "github.com/ollama/ollama/model/models/bert" _ "github.com/ollama/ollama/model/models/gemma2" _ "github.com/ollama/ollama/model/models/gemma3" _ "github.com/ollama/ollama/model/models/gemma3n" diff --git a/model/models/qwen2/model.go b/model/models/qwen2/model.go index 3c662f068..5a8bea29e 100644 --- a/model/models/qwen2/model.go +++ b/model/models/qwen2/model.go @@ -59,7 +59,7 @@ type MLP struct { } func (mlp MLP) Forward(ctx ml.Context, hiddenStates ml.Tensor) ml.Tensor { - hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenStates)) + hiddenStates = 
mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx, mlp.Up.Forward(ctx, hiddenStates)) return mlp.Down.Forward(ctx, hiddenStates) } @@ -111,7 +111,7 @@ func (m Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { var outputs ml.Tensor if i == len(m.Layers)-1 { - outputs = ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs)) + outputs = batch.Outputs } hiddenStates = layer.Forward(ctx, hiddenStates, positions, outputs, m.Cache, &m.Options) diff --git a/model/models/qwen25vl/model.go b/model/models/qwen25vl/model.go index d73f499d2..6c76305db 100644 --- a/model/models/qwen25vl/model.go +++ b/model/models/qwen25vl/model.go @@ -140,9 +140,8 @@ func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) { func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions)) - outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs)) - return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, batch, m.Cache) + return m.TextModel.Forward(ctx, batch.Inputs, positions, batch.Outputs, batch, m.Cache) } func init() { diff --git a/model/models/qwen25vl/model_text.go b/model/models/qwen25vl/model_text.go index 4b6bc1666..4f4e1effd 100644 --- a/model/models/qwen25vl/model_text.go +++ b/model/models/qwen25vl/model_text.go @@ -90,7 +90,7 @@ type MLP struct { func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextOptions) ml.Tensor { // Apply SwiGLU activation gating - hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState)) + hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx, mlp.Up.Forward(ctx, hiddenState)) // Project back to hidden dimension return mlp.Down.Forward(ctx, hiddenState) } diff --git a/model/models/qwen25vl/model_vision.go b/model/models/qwen25vl/model_vision.go index 4d7afaa14..3dd60e3ba 100644 --- a/model/models/qwen25vl/model_vision.go +++ b/model/models/qwen25vl/model_vision.go @@ -100,8 +100,7 @@ type VisionMLP struct { func (mlp *VisionMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *VisionModelOptions) ml.Tensor { // Using activation as specified in config (likely GELU or SiLU/Swish) gateOutput := mlp.Gate.Forward(ctx, hiddenStates) - upOutput := mlp.Up.Forward(ctx, hiddenStates) - hiddenStates = gateOutput.SILU(ctx).Mul(ctx, upOutput) + hiddenStates = gateOutput.SILU(ctx, mlp.Up.Forward(ctx, hiddenStates)) return mlp.Down.Forward(ctx, hiddenStates) } diff --git a/model/models/qwen3/model.go b/model/models/qwen3/model.go index 7a83e0d04..3f86d0236 100644 --- a/model/models/qwen3/model.go +++ b/model/models/qwen3/model.go @@ -30,10 +30,10 @@ func (o Options) headDim() int { } type Attention struct { - QueryNorm *nn.RMSNorm `gguf:"attn_q_norm"` Query *nn.Linear `gguf:"attn_q"` - KeyNorm *nn.RMSNorm `gguf:"attn_k_norm"` + QueryNorm *nn.RMSNorm `gguf:"attn_q_norm"` Key *nn.Linear `gguf:"attn_k"` + KeyNorm *nn.RMSNorm `gguf:"attn_k_norm"` Value *nn.Linear `gguf:"attn_v"` Output *nn.Linear `gguf:"attn_output"` } @@ -65,10 +65,10 @@ type MLP interface { } type sparse struct { - Router *nn.Linear `gguf:"ffn_gate_inp"` - Gate *nn.Linear `gguf:"ffn_gate_exps"` - Up *nn.Linear `gguf:"ffn_up_exps"` - Down *nn.Linear `gguf:"ffn_down_exps"` + Router *nn.Linear `gguf:"ffn_gate_inp"` + Gate *nn.LinearBatch `gguf:"ffn_gate_exps"` + Up *nn.LinearBatch `gguf:"ffn_up_exps"` + Down *nn.LinearBatch `gguf:"ffn_down_exps"` } func (mlp *sparse) Forward(ctx ml.Context, hiddenStates 
ml.Tensor, opts *Options) ml.Tensor { @@ -87,13 +87,9 @@ func (mlp *sparse) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0), 1, hiddenStates.Dim(1)) - upStates := mlp.Up.Weight.MulmatID(ctx, hiddenStates, selectedExperts) + hiddenStates = mlp.Gate.Forward(ctx, hiddenStates, selectedExperts).SILU(ctx, mlp.Up.Forward(ctx, hiddenStates, selectedExperts)) - hiddenStates = mlp.Gate.Weight.MulmatID(ctx, hiddenStates, selectedExperts) - hiddenStates = hiddenStates.SILU(ctx) - hiddenStates = hiddenStates.Mul(ctx, upStates) - - experts := mlp.Down.Weight.MulmatID(ctx, hiddenStates, selectedExperts) + experts := mlp.Down.Forward(ctx, hiddenStates, selectedExperts) experts = experts.Mul(ctx, routingWeights) nextStates := experts.View(ctx, 0, experts.Dim(0), experts.Stride(2), experts.Dim(2)) @@ -111,7 +107,8 @@ type dense struct { } func (mlp *dense) Forward(ctx ml.Context, hiddenStates ml.Tensor, _ *Options) ml.Tensor { - hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenStates)) + hiddenStates = mlp.Gate.Forward(ctx, hiddenStates). + SILU(ctx, mlp.Up.Forward(ctx, hiddenStates)) return mlp.Down.Forward(ctx, hiddenStates) } @@ -165,7 +162,7 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { var outputs ml.Tensor if i == len(m.Layers)-1 { - outputs = ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs)) + outputs = batch.Outputs } hiddenStates = layer.Forward(ctx, hiddenStates, positions, outputs, m.Cache, m.Options) diff --git a/model/parsers/parsers.go b/model/parsers/parsers.go new file mode 100644 index 000000000..e6dbd1f4f --- /dev/null +++ b/model/parsers/parsers.go @@ -0,0 +1,37 @@ +package parsers + +import ( + "github.com/ollama/ollama/api" +) + +type Parser interface { + Add(s string, tools []api.Tool) (content string, thinking string, calls []api.ToolCall, err error) + HasToolSupport() bool + HasThinkingSupport() bool +} + +func ParserForName(name string) Parser { + switch name { + case "qwen3-coder": + parser := &Qwen3CoderParser{} + return parser + case "passthrough": + return &PassthroughParser{} + default: + return nil + } +} + +type PassthroughParser struct{} + +func (p *PassthroughParser) Add(s string, tools []api.Tool) (content string, thinking string, calls []api.ToolCall, err error) { + return s, "", nil, nil +} + +func (p *PassthroughParser) HasToolSupport() bool { + return false +} + +func (p *PassthroughParser) HasThinkingSupport() bool { + return false +} diff --git a/model/parsers/qwen3coder.go b/model/parsers/qwen3coder.go new file mode 100644 index 000000000..b0e8ec48c --- /dev/null +++ b/model/parsers/qwen3coder.go @@ -0,0 +1,410 @@ +package parsers + +import ( + "context" + "encoding/json" + "encoding/xml" + "fmt" + "log/slog" + "math" + "regexp" + "strconv" + "strings" + "unicode" + + "github.com/ollama/ollama/api" + "github.com/ollama/ollama/logutil" +) + +type qwenParserState int + +const ( + toolOpenTag = "<tool_call>" + toolCloseTag = "</tool_call>" +) + +const ( + qwenParserState_LookingForToolStart qwenParserState = iota + qwenParserState_CollectingToolContent +) + +type Qwen3CoderParser struct { + state qwenParserState + acc strings.Builder +} + +func (p *Qwen3CoderParser) HasToolSupport() bool { + return true +} + +func (p *Qwen3CoderParser) HasThinkingSupport() bool { + return false +} + +func (p *Qwen3CoderParser) Add(s string, tools []api.Tool) (content string, thinking string, calls []api.ToolCall, err error) {
+func (p *Qwen3CoderParser) Add(s string, tools []api.Tool) (content string, thinking string, calls []api.ToolCall, err error) {
+	p.acc.WriteString(s)
+
+	events := p.parseEvents()
+
+	var toolCalls []api.ToolCall
+	var sb strings.Builder
+	for _, event := range events {
+		switch event := event.(type) {
+		case qwenEventRawToolCall:
+			toolCall, err := parseToolCall(event, tools)
+			if err != nil {
+				slog.Warn("qwen tool call parsing failed", "error", err)
+				return "", "", nil, err
+			}
+			toolCalls = append(toolCalls, toolCall)
+		case qwenEventContent:
+			// TODO(drifkin): if the same turn contains multiple interleaved content
+			// events, we naively append them together here. See the note below about
+			// `qwenEvent`s for more details
+			sb.WriteString(event.content)
+		}
+	}
+
+	return sb.String(), "", toolCalls, nil
+}
+
+func (p *Qwen3CoderParser) parseEvents() []qwenEvent {
+	var all []qwenEvent
+
+	keepLooping := true
+	for keepLooping {
+		var events []qwenEvent
+		events, keepLooping = eat(p)
+		if len(events) > 0 {
+			all = append(all, events...)
+		}
+	}
+
+	if len(all) > 0 {
+		slog.Log(context.TODO(), logutil.LevelTrace, "qwen events parsed", "events", all, "state", p.state, "acc", p.acc.String())
+	}
+
+	return all
+}
+
+// we use some internal event types in order to communicate between `Add` and
+// `eat`. We do this to support interleaving content and parallel tool calls in
+// the parser, even though qwen3-coder isn't supposed to do this. Our API
+// doesn't currently support models outputting multiple messages in a turn, so
+// we wouldn't be able to represent it yet, but there's no reason to prevent the
+// parser from supporting it, especially for future models if they end up using
+// a similar format.
+type qwenEvent interface {
+	isQwenEvent()
+}
+
+type qwenEventRawToolCall struct {
+	raw string
+}
+
+type qwenEventContent struct {
+	content string
+}
+
+func (qwenEventContent) isQwenEvent()     {}
+func (qwenEventRawToolCall) isQwenEvent() {}
+
+// eat consumes the parser's buffer, and returns a list of any unambiguous
+// events from the current parser state.
If the parser transitions to another
+// state, it may have additional events to emit on the next call, which is what
+// the second return value indicates
+func eat(p *Qwen3CoderParser) ([]qwenEvent, bool) {
+	var events []qwenEvent
+
+	switch p.state {
+	case qwenParserState_LookingForToolStart:
+		if strings.Contains(p.acc.String(), toolOpenTag) {
+			// we found a full tool open tag, so we can emit the content before the
+			// tag, being sure to trim any trailing whitespace
+			split := strings.SplitN(p.acc.String(), toolOpenTag, 2)
+			before := split[0]
+			before = strings.TrimRightFunc(before, unicode.IsSpace)
+			if len(before) > 0 {
+				events = append(events, qwenEventContent{content: before})
+			}
+			after := split[1]
+			p.acc.Reset()
+			p.acc.WriteString(after)
+			p.state = qwenParserState_CollectingToolContent
+			return events, true
+		} else if overlap := overlap(p.acc.String(), toolOpenTag); overlap > 0 {
+			// we found a partial tool open tag, so we can emit the unambiguous part,
+			// which is the (trailing-whitespace trimmed) content before the partial
+			// tool open tag
+			beforePartialTag := p.acc.String()[:len(p.acc.String())-overlap]
+			trailingWhitespaceLen := trailingWhitespaceLen(beforePartialTag)
+			ambiguousStart := len(beforePartialTag) - trailingWhitespaceLen
+			unambiguous := p.acc.String()[:ambiguousStart]
+			ambiguous := p.acc.String()[ambiguousStart:]
+			p.acc.Reset()
+			p.acc.WriteString(ambiguous)
+			events = append(events, qwenEventContent{content: unambiguous})
+			return events, false
+		} else {
+			// we found content that is entirely not a tool call. We should withhold
+			// any trailing whitespace in case this is the end of the content
+			whitespaceLen := trailingWhitespaceLen(p.acc.String())
+			ambiguousStart := len(p.acc.String()) - whitespaceLen
+			unambiguous := p.acc.String()[:ambiguousStart]
+			ambiguous := p.acc.String()[ambiguousStart:]
+			p.acc.Reset()
+			p.acc.WriteString(ambiguous)
+			if len(unambiguous) > 0 {
+				events = append(events, qwenEventContent{content: unambiguous})
+			}
+			return events, false
+		}
+	case qwenParserState_CollectingToolContent:
+		if strings.Contains(p.acc.String(), toolCloseTag) {
+			split := strings.SplitN(p.acc.String(), toolCloseTag, 2)
+			before := split[0]
+			if len(before) == 0 {
+				slog.Warn("qwen tool call closing tag found but no content before it")
+			}
+			// remove any whitespace between the tool call and any content after it
+			after := strings.TrimLeftFunc(split[1], unicode.IsSpace)
+			p.acc.Reset()
+			p.acc.WriteString(after)
+			events = append(events, qwenEventRawToolCall{raw: before})
+			p.state = qwenParserState_LookingForToolStart
+			return events, true
+		} else {
+			// note that we don't need to check the overlap here because we only plan
+			// on parsing the tool call once we see the full closing tag.
We don't
+			// stream back the unparsed tool content, so there's no need to be eager
+			// here
+			return events, false
+		}
+	default:
+		panic("unreachable")
+	}
+}
+
+// TODO(drifkin): move this to a shared location
+// longest overlap between suffix of s and prefix of delim
+func overlap(s, delim string) int {
+	max := min(len(delim), len(s))
+	for i := max; i > 0; i-- {
+		if strings.HasSuffix(s, delim[:i]) {
+			return i
+		}
+	}
+	return 0
+}
+
+func trailingWhitespaceLen(s string) int {
+	for i := len(s) - 1; i >= 0; i-- {
+		if !unicode.IsSpace(rune(s[i])) {
+			return len(s) - i - 1
+		}
+	}
+	return len(s)
+}
+
+type XMLFunctionCall struct {
+	XMLName    xml.Name       `xml:"function"`
+	Name       string         `xml:"name,attr"`
+	Parameters []XMLParameter `xml:"parameter"`
+}
+
+type XMLParameter struct {
+	Name  string `xml:"name,attr"`
+	Value string `xml:",chardata"`
+}
+
+// parseToolCall parses a raw tool call string into an api.ToolCall.
+// The raw string follows an xml-like format, here's an example:
+//
+//	<function=get_current_temperature>
+//	<parameter=location>
+//	San Francisco
+//	</parameter>
+//	<parameter=unit>
+//	celsius
+//	</parameter>
+//	</function>
+func parseToolCall(raw qwenEventRawToolCall, tools []api.Tool) (api.ToolCall, error) {
+	toolCall := api.ToolCall{}
+
+	xmlString := transformToXML(raw.raw)
+
+	var functionCall XMLFunctionCall
+	err := xml.Unmarshal([]byte(xmlString), &functionCall)
+	if err != nil {
+		return api.ToolCall{}, err
+	}
+
+	toolCall.Function = api.ToolCallFunction{
+		Name: functionCall.Name,
+	}
+
+	// Find the matching tool to get parameter types
+	var matchedTool *api.Tool
+	for i := range tools {
+		if tools[i].Function.Name == functionCall.Name {
+			matchedTool = &tools[i]
+			break
+		}
+	}
+
+	toolCall.Function.Arguments = make(api.ToolCallFunctionArguments)
+	for _, parameter := range functionCall.Parameters {
+		// Look up the parameter type if we found the tool
+		var paramType api.PropertyType
+		if matchedTool != nil && matchedTool.Function.Parameters.Properties != nil {
+			if prop, ok := matchedTool.Function.Parameters.Properties[parameter.Name]; ok {
+				paramType = prop.Type
+			}
+		}
+
+		toolCall.Function.Arguments[parameter.Name] = parseValue(parameter.Value, paramType)
+	}
+
+	return toolCall, nil
+}
+
+// parseValue converts a raw string value to the appropriate type based on the parameter type specification.
+//
+// For union types (multiple types in PropertyType, which we support, though the
+// reference parser doesn't seem to do type coercion with those types in mind)
+// we use a type precedence approach:
+//  1. null - checked first regardless of declared types (matches reference implementation)
+//  2. boolean - only "true"/"false" are valid booleans
+//  3. integer - must parse as a whole number
+//  4. number - must parse as numeric (returns int if no decimal part)
+//  5. array - must parse as valid JSON array
+//  6. object - must parse as valid JSON object
+//  7. string - always succeeds (least specific type)
+//
+// This precedence ensures we return the most specific type that successfully parses,
+// following the principle of least surprise. For example, with PropertyType{"string", "number"},
+// "123" becomes 123 (number), while "hello" becomes "hello" (string).
+func parseValue(raw string, paramType api.PropertyType) any {
+	// first remove a single leading newline, and a single trailing newline (if
+	// they exist).
This follows the reference implementation
+	raw = strings.TrimPrefix(raw, "\n")
+	raw = strings.TrimSuffix(raw, "\n")
+
+	// Check for null first (case-insensitive) - this takes precedence over any type
+	if strings.ToLower(raw) == "null" {
+		return nil
+	}
+
+	// If no type is specified, default to string
+	if len(paramType) == 0 {
+		return raw
+	}
+
+	// Check if any of the specified types match, using type precedence
+	// Order: boolean -> integer -> number -> array -> object -> string
+	typeSet := make(map[string]bool)
+	for _, t := range paramType {
+		typeSet[t] = true
+	}
+
+	// Try boolean first (most restrictive)
+	if typeSet["boolean"] {
+		lower := strings.ToLower(raw)
+		switch lower {
+		case "true":
+			return true
+		case "false":
+			return false
+		}
+		// If not a valid boolean but boolean is the only type, return false (matching reference)
+		if len(paramType) == 1 {
+			return false
+		}
+		// Otherwise try other types
+	}
+
+	// Try integer
+	if typeSet["integer"] {
+		if i, err := strconv.ParseInt(raw, 10, 64); err == nil {
+			// Return as int if it fits in int32, otherwise int64
+			if i >= math.MinInt32 && i <= math.MaxInt32 {
+				return int(i)
+			}
+			return i
+		}
+		// If integer is the only type and parsing failed, fall back to string
+		if len(paramType) == 1 {
+			return raw
+		}
+	}
+
+	// Try number (float)
+	if typeSet["number"] {
+		if f, err := strconv.ParseFloat(raw, 64); err == nil {
+			// If the number has no decimal part, return as int (matching reference)
+			if f == math.Trunc(f) {
+				i := int64(f)
+				if i >= math.MinInt32 && i <= math.MaxInt32 {
+					return int(i)
+				}
+				return i
+			}
+			return f
+		}
+		// If number is the only type and parsing failed, fall back to string
+		if len(paramType) == 1 {
+			return raw
+		}
+	}
+
+	// Try array
+	if typeSet["array"] {
+		var arr []interface{}
+		if err := json.Unmarshal([]byte(raw), &arr); err == nil {
+			return arr
+		}
+		// If array is the only type and parsing failed, fall back to string
+		if len(paramType) == 1 {
+			return raw
+		}
+	}
+
+	// Try object
+	if typeSet["object"] {
+		var obj map[string]interface{}
+		if err := json.Unmarshal([]byte(raw), &obj); err == nil {
+			return obj
+		}
+		// If object is the only type and parsing failed, fall back to string
+		if len(paramType) == 1 {
+			return raw
+		}
+	}
+
+	// String always succeeds (or if "string" is in the type set)
+	if typeSet["string"] {
+		return raw
+	}
+
+	// If we get here, none of the types matched and string wasn't an option.
+	// We return string as a fallback.
The reference implementation will attempt
+	// to parse the value as a python literal, but we purposefully don't support
+	// that
+	return raw
+}
+
+var qwenTagRegex = regexp.MustCompile(`<(\w+)=([^>]+)>`)
+
+// transformToXML transforms a raw qwen tool call with xml-like tags into valid
+// xml so that it can be parsed by any xml parser
+func transformToXML(raw string) string {
+	// take the form `<function=foo>` and transform it to `<function name="foo">`,
+	// taking care to properly escape the string that becomes the attribute value
+	return qwenTagRegex.ReplaceAllStringFunc(raw, func(match string) string {
+		groups := qwenTagRegex.FindStringSubmatch(match)
+		tag := groups[1]
+		var escapedValue strings.Builder
+		xml.EscapeText(&escapedValue, []byte(groups[2]))
+		return fmt.Sprintf(`<%s name="%s">`, tag, escapedValue.String())
+	})
+}
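The streaming logic above hinges on `overlap` to decide how much of the buffer tail might still grow into a `<tool_call>` tag. A quick illustrative check (the `main` wrapper is hypothetical; the `overlap` body is copied from the diff):

```go
package main

import (
	"fmt"
	"strings"
)

// overlap mirrors the helper above: the longest suffix of s that is also a
// prefix of delim, i.e. how many buffered bytes might still become the tag.
func overlap(s, delim string) int {
	max := min(len(delim), len(s))
	for i := max; i > 0; i-- {
		if strings.HasSuffix(s, delim[:i]) {
			return i
		}
	}
	return 0
}

func main() {
	fmt.Println(overlap("some text <tool", "<tool_call>")) // 5: "<tool" is ambiguous, withhold it
	fmt.Println(overlap("some text", "<tool_call>"))       // 0: safe to emit everything
}
```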
diff --git a/model/parsers/qwen3coder_test.go b/model/parsers/qwen3coder_test.go
new file mode 100644
index 000000000..2389c77b5
--- /dev/null
+++ b/model/parsers/qwen3coder_test.go
@@ -0,0 +1,830 @@
+package parsers
+
+import (
+	"reflect"
+	"testing"
+
+	"github.com/ollama/ollama/api"
+)
+
+// tool creates a test tool with the given name and properties
+func tool(name string, props map[string]api.ToolProperty) api.Tool {
+	t := api.Tool{Type: "function", Function: api.ToolFunction{Name: name}}
+	t.Function.Parameters.Type = "object"
+	t.Function.Parameters.Properties = props
+	return t
+}
+
+func TestQwenParserStreaming(t *testing.T) {
+	type step struct {
+		input      string
+		wantEvents []qwenEvent
+	}
+
+	cases := []struct {
+		desc  string
+		steps []step
+		only  bool
+	}{
+		{
+			desc: "simple message streamed word by word",
+			steps: []step{
+				{
+					input:      "hi",
+					wantEvents: []qwenEvent{qwenEventContent{content: "hi"}},
+				},
+				{
+					input:      " there",
+					wantEvents: []qwenEvent{qwenEventContent{content: " there"}},
+				},
+			},
+		},
+		{
+			desc: "content before tool call",
+			steps: []step{
+				{
+					input:      "hi there<tool_call>",
+					wantEvents: []qwenEvent{qwenEventContent{content: "hi there"}},
+				},
+			},
+		},
+		{
+			desc: "multiple tool calls in one message",
+			steps: []step{
+				{
+					input: "before1<tool_call>in tool call</tool_call>after1<tool_call>in tool call 2</tool_call>after2",
+					wantEvents: []qwenEvent{
+						qwenEventContent{content: "before1"},
+						qwenEventRawToolCall{raw: "in tool call"},
+						qwenEventContent{content: "after1"},
+						qwenEventRawToolCall{raw: "in tool call 2"},
+						qwenEventContent{content: "after2"},
+					},
+				},
+			},
+		},
+		{
+			desc: "tool calls with split tags",
+			steps: []step{
+				{
+					input:      "before<tool_",
+					wantEvents: []qwenEvent{qwenEventContent{content: "before"}},
+				},
+				{
+					input: "call>in tool call</tool_call>af",
+					wantEvents: []qwenEvent{
+						qwenEventRawToolCall{raw: "in tool call"},
+						qwenEventContent{content: "af"},
+					},
+				},
+				{
+					input: "ter",
+					wantEvents: []qwenEvent{
+						qwenEventContent{content: "ter"},
+					},
+				},
+			},
+		},
+		{
+			desc: "trailing whitespace between content and tool call",
+			steps: []step{
+				{
+					input: "abc\n<tool_call>def</tool_call>",
+					wantEvents: []qwenEvent{
+						qwenEventContent{content: "abc"},
+						qwenEventRawToolCall{raw: "def"},
+					},
+				},
+			},
+		},
+		{
+			desc: "trailing whitespace between tool call and content",
+			steps: []step{
+				{
+					input: "<tool_call>abc</tool_call>\ndef",
+					wantEvents: []qwenEvent{
+						qwenEventRawToolCall{raw: "abc"},
+						qwenEventContent{content: "def"},
+					},
+				},
+			},
+		},
+		{
+			desc: "empty content before tool call",
+			steps: []step{
+				{
+					input: "\n<tool_call>abc</tool_call>",
+					wantEvents: []qwenEvent{
+						qwenEventRawToolCall{raw: "abc"},
+					},
+				},
+			},
+		},
+		{
+			desc: "partial tool open tag fakeout",
+			steps: []step{
+				{
+					input:      "abc\n<tool",
+					wantEvents: []qwenEvent{qwenEventContent{content: "abc"}},
+				},
+				{
+					input:      "fakeout",
+					wantEvents: []qwenEvent{qwenEventContent{content: "\n<toolfakeout"}},
+				},
+			},
+		},
+	}
+
+	for _, tc := range cases {
+		t.Run(tc.desc, func(t *testing.T) {
+			parser := &Qwen3CoderParser{}
+			for i, step := range tc.steps {
+				parser.acc.WriteString(step.input)
+				gotEvents := parser.parseEvents()
+				if !reflect.DeepEqual(gotEvents, step.wantEvents) {
+					t.Errorf("step %d: got events %#v, want %#v", i, gotEvents, step.wantEvents)
+				}
+			}
+		})
+	}
+}
+
+func TestQwenToolCallParsing(t *testing.T) {
+	steps := []struct {
+		name         string
+		tools        []api.Tool
+		rawToolCall  string
+		wantToolCall api.ToolCall
+	}{
+		{
+			name: "simple tool call",
+			tools: []api.Tool{
+				tool("get_current_temperature", map[string]api.ToolProperty{
+					"location": {Type: api.PropertyType{"string"}},
+					"unit":     {Type: api.PropertyType{"string"}},
+				}),
+			},
+			rawToolCall: `<function=get_current_temperature>
+<parameter=location>
+San Francisco
+</parameter>
+<parameter=unit>
+celsius
+</parameter>
+</function>`,
+			wantToolCall: api.ToolCall{
+				Function: api.ToolCallFunction{
+					Name: "get_current_temperature",
+					Arguments: map[string]any{
+						"location": "San Francisco",
+						"unit":     "celsius",
+					},
+				},
+			},
+		},
+		{
+			name:  "names with spaces",
+			tools: []api.Tool{},
+			rawToolCall: `<function=get current temperature>
+<parameter=location with spaces>
+San Francisco
+</parameter>
+<parameter=unit with spaces>
+celsius
+</parameter>
+</function>`,
+			wantToolCall: api.ToolCall{
+				Function: api.ToolCallFunction{
+					Name: "get current temperature",
+					Arguments: map[string]any{
+						"location with spaces": "San Francisco",
+						"unit with spaces":     "celsius",
+					},
+				},
+			},
+		},
+		// this mirrors the reference implementation's behavior, but it's unclear if
+		// it ever happens. If it does, we should probably strip the quotes instead;
+		// this test just documents the current behavior and checks that we don't
+		// get xml errors
+		{
+			name:  "names with quotes",
+			tools: []api.Tool{},
+			rawToolCall: `<function="get current temperature">
+<parameter="location with spaces">
+San Francisco
+</parameter>
+<parameter="unit with spaces">
+"celsius"
+</parameter>
+</function>`,
+			wantToolCall: api.ToolCall{
+				Function: api.ToolCallFunction{
+					Name: "\"get current temperature\"",
+					Arguments: map[string]any{
+						"\"location with spaces\"": "San Francisco",
+						"\"unit with spaces\"":     "\"celsius\"",
+					},
+				},
+			},
+		},
+		{
+			name: "tool call with typed parameters",
+			tools: []api.Tool{
+				tool("calculate", map[string]api.ToolProperty{
+					"x":       {Type: api.PropertyType{"number"}},
+					"y":       {Type: api.PropertyType{"integer"}},
+					"enabled": {Type: api.PropertyType{"boolean"}},
+					"items":   {Type: api.PropertyType{"array"}},
+				}),
+			},
+			rawToolCall: `<function=calculate>
+<parameter=x>
+3.14
+</parameter>
+<parameter=y>
+42
+</parameter>
+<parameter=enabled>
+true
+</parameter>
+<parameter=items>
+["a", "b", "c"]
+</parameter>
+</function>`,
+			wantToolCall: api.ToolCall{
+				Function: api.ToolCallFunction{
+					Name: "calculate",
+					Arguments: map[string]any{
+						"x":       3.14,
+						"y":       42,
+						"enabled": true,
+						"items":   []any{"a", "b", "c"},
+					},
+				},
+			},
+		},
+	}
+
+	for i, step := range steps {
+		gotToolCall, err := parseToolCall(qwenEventRawToolCall{raw: step.rawToolCall}, step.tools)
+		if err != nil {
+			t.Errorf("step %d (%s): %v", i, step.name, err)
+		}
+		if !reflect.DeepEqual(gotToolCall, step.wantToolCall) {
+			t.Errorf("step %d (%s): got tool call %#v, want %#v", i, step.name, gotToolCall, step.wantToolCall)
+		}
+	}
+}
+
+func TestQwenToolCallValueParsing(t *testing.T) {
+	cases := []struct {
+		desc      string
+		raw       string
+		paramType api.PropertyType
+		want      any
+	}{
+		{
+			desc:      "default string value (no type specified)",
+			paramType: api.PropertyType{},
+			raw:       "some-string",
+			want:      "some-string",
+		},
+		{
+			desc:      "trim a single leading and trailing newline",
+			paramType: api.PropertyType{},
+			raw:       "\nsome-string\n",
+			want:      "some-string",
+		},
+		{
+			desc:      "trim at most one leading and trailing newline",
+			paramType: api.PropertyType{},
+			raw:       "\n\nsome-string\n\n",
+			want:      "\nsome-string\n",
+		},
+		{
+			desc:      "newline really has to be the first character to be trimmed",
+			paramType: api.PropertyType{},
+			raw:       " \nsome-string\n ",
+			want:      " \nsome-string\n ",
+		},
+		{
+			desc:      "numeric type",
+			paramType: api.PropertyType{"number"},
+			raw:       "123",
+			want:      123,
+		},
+		// Integer parsing tests
+		{
+			desc:      "integer type",
+			paramType: api.PropertyType{"integer"},
+			raw:       "42",
+			want:      42,
+		},
+		{
+			desc:      "negative integer",
+			paramType: api.PropertyType{"integer"},
+			raw:       "-100",
+			want:      -100,
+		},
+		{
+			desc:      "zero integer",
+			paramType: api.PropertyType{"integer"},
+			raw:       "0",
+			want:      0,
+		},
+		{
+			desc:      "integer with leading zeros",
+			paramType: api.PropertyType{"integer"},
+			raw:       "007",
+			want:      7,
+		},
+		{
+			desc:      "large integer",
+			paramType: api.PropertyType{"integer"},
+			raw:       "2147483648", // Just beyond int32 max
+			want:      int64(2147483648),
+		},
+		// Float/number parsing tests
+		{
+			desc:      "float type",
+			paramType: api.PropertyType{"number"},
+			raw:
"3.14", + want: 3.14, + }, + { + desc: "negative float", + paramType: api.PropertyType{"number"}, + raw: "-273.15", + want: -273.15, + }, + { + desc: "float without decimal part", + paramType: api.PropertyType{"number"}, + raw: "100.0", + want: 100, + }, + { + desc: "scientific notation positive", + paramType: api.PropertyType{"number"}, + raw: "1.23e5", + want: 123000, // Will be int since it has no decimal part + }, + { + desc: "scientific notation negative", + paramType: api.PropertyType{"number"}, + raw: "1.5e-3", + want: 0.0015, + }, + { + desc: "very small float", + paramType: api.PropertyType{"number"}, + raw: "0.00000001", + want: 0.00000001, + }, + // String parsing tests + { + desc: "explicit string type", + paramType: api.PropertyType{"string"}, + raw: "hello world", + want: "hello world", + }, + { + desc: "string with special characters", + paramType: api.PropertyType{"string"}, + raw: "/usr/local/bin/test-file_v2.0.sh", + want: "/usr/local/bin/test-file_v2.0.sh", + }, + { + desc: "string with quotes", + paramType: api.PropertyType{"string"}, + raw: `He said "hello" to me`, + want: `He said "hello" to me`, + }, + { + desc: "multiline string", + paramType: api.PropertyType{"string"}, + raw: "line one\nline two\nline three", + want: "line one\nline two\nline three", + }, + { + desc: "empty string", + paramType: api.PropertyType{"string"}, + raw: "", + want: "", + }, + { + desc: "string that looks like a number", + paramType: api.PropertyType{"string"}, + raw: "12345", + want: "12345", + }, + // Boolean parsing tests + { + desc: "boolean true", + paramType: api.PropertyType{"boolean"}, + raw: "true", + want: true, + }, + { + desc: "boolean false", + paramType: api.PropertyType{"boolean"}, + raw: "false", + want: false, + }, + { + desc: "boolean case insensitive true", + paramType: api.PropertyType{"boolean"}, + raw: "True", + want: true, + }, + { + desc: "boolean case insensitive false", + paramType: api.PropertyType{"boolean"}, + raw: "FALSE", + want: false, + }, + // Null parsing tests + { + desc: "null value lowercase", + paramType: api.PropertyType{"string"}, + raw: "null", + want: nil, + }, + { + desc: "null value case insensitive", + paramType: api.PropertyType{"integer"}, + raw: "NULL", + want: nil, + }, + // Array parsing tests + { + desc: "array of strings", + paramType: api.PropertyType{"array"}, + raw: `["foo", "bar", "baz"]`, + want: []any{"foo", "bar", "baz"}, + }, + { + desc: "array of numbers", + paramType: api.PropertyType{"array"}, + raw: `[1, 2.5, 3]`, + want: []any{float64(1), 2.5, float64(3)}, + }, + { + desc: "array of mixed types", + paramType: api.PropertyType{"array"}, + raw: `["string", 123, true, null]`, + want: []any{"string", float64(123), true, nil}, + }, + { + desc: "empty array", + paramType: api.PropertyType{"array"}, + raw: `[]`, + want: []any{}, + }, + // Object parsing tests + { + desc: "simple object", + paramType: api.PropertyType{"object"}, + raw: `{"key": "value", "number": 42}`, + want: map[string]any{"key": "value", "number": float64(42)}, + }, + { + desc: "nested object", + paramType: api.PropertyType{"object"}, + raw: `{"outer": {"inner": "value"}}`, + want: map[string]any{"outer": map[string]any{"inner": "value"}}, + }, + { + desc: "empty object", + paramType: api.PropertyType{"object"}, + raw: `{}`, + want: map[string]any{}, + }, + // Error cases and fallback behavior + { + desc: "invalid integer falls back to string", + paramType: api.PropertyType{"integer"}, + raw: "not-a-number", + want: "not-a-number", + }, + { + desc: "invalid 
float falls back to string", + paramType: api.PropertyType{"number"}, + raw: "3.14.159", + want: "3.14.159", + }, + { + desc: "invalid boolean falls back to false", + paramType: api.PropertyType{"boolean"}, + raw: "yes", + want: false, + }, + { + desc: "invalid JSON array falls back to string", + paramType: api.PropertyType{"array"}, + raw: "[1, 2, unclosed", + want: "[1, 2, unclosed", + }, + { + desc: "invalid JSON object falls back to string", + paramType: api.PropertyType{"object"}, + raw: `{"key": unclosed`, + want: `{"key": unclosed`, + }, + // Edge cases + { + desc: "integer overflow should use int64", + paramType: api.PropertyType{"integer"}, + raw: "2147483648", // Beyond int32 max + want: int64(2147483648), + }, + { + desc: "float with many decimal places", + paramType: api.PropertyType{"number"}, + raw: "3.141592653589793", + want: 3.141592653589793, + }, + { + desc: "string with JSON-like content", + paramType: api.PropertyType{"string"}, + raw: `{"this": "is", "just": "a string"}`, + want: `{"this": "is", "just": "a string"}`, + }, + { + desc: "whitespace-only string", + paramType: api.PropertyType{"string"}, + raw: " ", + want: " ", + }, + // Unknown parameter (no type specified in tools) + { + desc: "parameter not in tool definition defaults to string", + paramType: api.PropertyType{}, + raw: "some value", + want: "some value", + }, + // Union type tests + { + desc: "string or number union - valid number", + paramType: api.PropertyType{"string", "number"}, + raw: "42.5", + want: 42.5, + }, + { + desc: "string or number union - non-numeric string", + paramType: api.PropertyType{"string", "number"}, + raw: "hello", + want: "hello", + }, + { + desc: "number or string union - valid number (order shouldn't matter)", + paramType: api.PropertyType{"number", "string"}, + raw: "42.5", + want: 42.5, + }, + { + desc: "integer or null union - valid integer", + paramType: api.PropertyType{"integer", "null"}, + raw: "123", + want: 123, + }, + { + desc: "integer or null union - null value", + paramType: api.PropertyType{"integer", "null"}, + raw: "null", + want: nil, + }, + { + desc: "null or integer union - null value (order shouldn't matter)", + paramType: api.PropertyType{"null", "integer"}, + raw: "null", + want: nil, + }, + { + desc: "boolean or string union - valid boolean", + paramType: api.PropertyType{"boolean", "string"}, + raw: "true", + want: true, + }, + { + desc: "boolean or string union - non-boolean becomes string", + paramType: api.PropertyType{"boolean", "string"}, + raw: "yes", + want: "yes", + }, + { + desc: "string or boolean union - valid boolean (precedence test)", + paramType: api.PropertyType{"string", "boolean"}, + raw: "false", + want: false, // Should be boolean, not string "false" + }, + { + desc: "integer or number union - integer value", + paramType: api.PropertyType{"integer", "number"}, + raw: "42", + want: 42, + }, + { + desc: "integer or number union - float value", + paramType: api.PropertyType{"integer", "number"}, + raw: "42.5", + want: 42.5, + }, + { + desc: "number or integer union - integer value (precedence test)", + paramType: api.PropertyType{"number", "integer"}, + raw: "42", + want: 42, // Should try integer first due to precedence + }, + { + desc: "array or object union - valid array", + paramType: api.PropertyType{"array", "object"}, + raw: `[1, 2, 3]`, + want: []any{float64(1), float64(2), float64(3)}, + }, + { + desc: "array or object union - valid object", + paramType: api.PropertyType{"array", "object"}, + raw: `{"key": "value"}`, + want: 
map[string]any{"key": "value"}, + }, + { + desc: "object or array union - valid array (precedence test)", + paramType: api.PropertyType{"object", "array"}, + raw: `[1, 2, 3]`, + want: []any{float64(1), float64(2), float64(3)}, + }, + { + desc: "complex multi-type union - null", + paramType: api.PropertyType{"string", "number", "boolean", "null"}, + raw: "null", + want: nil, + }, + { + desc: "complex multi-type union - boolean", + paramType: api.PropertyType{"string", "number", "boolean", "null"}, + raw: "true", + want: true, + }, + { + desc: "complex multi-type union - number", + paramType: api.PropertyType{"string", "number", "boolean", "null"}, + raw: "3.14", + want: 3.14, + }, + { + desc: "complex multi-type union - string", + paramType: api.PropertyType{"string", "number", "boolean", "null"}, + raw: "hello", + want: "hello", + }, + { + desc: "integer string union - integer string becomes integer", + paramType: api.PropertyType{"integer", "string"}, + raw: "123", + want: 123, + }, + { + desc: "string integer union - integer string becomes integer (precedence)", + paramType: api.PropertyType{"string", "integer"}, + raw: "123", + want: 123, // Integer has higher precedence than string + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + got := parseValue(tc.raw, tc.paramType) + if !reflect.DeepEqual(got, tc.want) { + t.Errorf("got %v (type %T), want %v (type %T)", got, got, tc.want, tc.want) + } + }) + } +} + +func TestQwenXMLTransform(t *testing.T) { + cases := []struct { + desc string + raw string + want string + }{ + { + desc: "simple example", + raw: ` + +San Francisco + + +celsius + +`, + want: ` + +San Francisco + + +celsius + +`, + }, + // even though quotes aren't expected in these tags, we have these tests to + // make sure they're escaped so they don't blow up the xml parser in case + // they happen + { + desc: "names with quotes", + raw: ` + +San Francisco + + +celsius + +`, + want: ` + +San Francisco + + +celsius + +`, + }, + } + + for _, tc := range cases { + got := transformToXML(tc.raw) + if got != tc.want { + t.Errorf("got %q, want %q", got, tc.want) + } + } +} + +func TestTrailingWhitespaceLen(t *testing.T) { + cases := []struct { + desc string + s string + want int + }{ + {desc: "no whitespace", s: "abc", want: 0}, + {desc: "trailing whitespace", s: "abc ", want: 1}, + {desc: "trailing whitespace with newlines", s: "abc \n", want: 2}, + {desc: "only whitespace", s: " \n ", want: 4}, + {desc: "leading whitespace doesn't count", s: " \n abc", want: 0}, + } + + for _, tc := range cases { + got := trailingWhitespaceLen(tc.s) + if got != tc.want { + t.Errorf("got %d, want %d", got, tc.want) + } + } +} diff --git a/model/renderers/qwen3coder.go b/model/renderers/qwen3coder.go new file mode 100644 index 000000000..df3b3a45b --- /dev/null +++ b/model/renderers/qwen3coder.go @@ -0,0 +1,217 @@ +package renderers + +import ( + "encoding/json" + "fmt" + "reflect" + "strings" + + "github.com/ollama/ollama/api" +) + +var ( + imStartTag = "<|im_start|>" + imEndTag = "<|im_end|>" +) + +// renderAdditionalKeys renders all JSON fields except the ones in handledKeys +// This follows the same approach from the reference implementation, which gives +// a particular key ordering +func renderAdditionalKeys(obj any, handledKeys map[string]bool) string { + data, err := json.Marshal(obj) + if err != nil { + return "" + } + + var m map[string]any + if err := json.Unmarshal(data, &m); err != nil { + return "" + } + + var sb strings.Builder + for key, value := range m { + 
diff --git a/model/renderers/qwen3coder.go b/model/renderers/qwen3coder.go
new file mode 100644
index 000000000..df3b3a45b
--- /dev/null
+++ b/model/renderers/qwen3coder.go
@@ -0,0 +1,217 @@
+package renderers
+
+import (
+	"encoding/json"
+	"fmt"
+	"reflect"
+	"strings"
+
+	"github.com/ollama/ollama/api"
+)
+
+var (
+	imStartTag = "<|im_start|>"
+	imEndTag   = "<|im_end|>"
+)
+
+// renderAdditionalKeys renders all JSON fields except the ones in handledKeys.
+// This follows the same approach from the reference implementation, which gives
+// a particular key ordering
+func renderAdditionalKeys(obj any, handledKeys map[string]bool) string {
+	data, err := json.Marshal(obj)
+	if err != nil {
+		return ""
+	}
+
+	var m map[string]any
+	if err := json.Unmarshal(data, &m); err != nil {
+		return ""
+	}
+
+	var sb strings.Builder
+	for key, value := range m {
+		if handledKeys[key] {
+			continue
+		}
+
+		// Check if value is a map or array (needs JSON serialization)
+		switch v := value.(type) {
+		case map[string]any, []any:
+			jsonBytes, _ := json.Marshal(v)
+			// TODO(drifkin): it would be nice to format the JSON here similarly to
+			// python's default json.dumps behavior (spaces after commas and colons).
+			// This would let us be byte-for-byte compatible with the reference
+			// implementation for most common inputs
+			jsonStr := string(jsonBytes)
+			sb.WriteString("\n<" + key + ">" + jsonStr + "</" + key + ">")
+		case nil:
+			continue
+		default:
+			// Simple types, convert to string
+			sb.WriteString("\n<" + key + ">" + fmt.Sprintf("%v", value) + "</" + key + ">")
+		}
+	}
+
+	return sb.String()
+}
+
+func Qwen3CoderRenderer(messages []api.Message, tools []api.Tool, _ *api.ThinkValue) (string, error) {
+	var sb strings.Builder
+
+	// filter out system messages and choose the first (if any) to win
+	var systemMessage string
+	var filteredMessages []api.Message
+	for _, message := range messages {
+		if message.Role != "system" {
+			filteredMessages = append(filteredMessages, message)
+			continue
+		}
+
+		if systemMessage == "" {
+			systemMessage = message.Content
+		}
+	}
+
+	if systemMessage != "" || len(tools) > 0 {
+		sb.WriteString(imStartTag + "system\n")
+
+		// if we have tools but no system message, match the reference implementation
+		// by providing a default system message
+		if systemMessage == "" {
+			systemMessage = "You are Qwen, a helpful AI assistant that can interact with a computer to solve tasks."
+		}
+
+		sb.WriteString(systemMessage)
+
+		if len(tools) > 0 {
+			sb.WriteString("\n\n# Tools\n\nYou have access to the following functions:\n\n")
+			sb.WriteString("<tools>")
+			for _, tool := range tools {
+				sb.WriteString("\n<function>")
+				sb.WriteString("\n")
+				sb.WriteString("<name>" + tool.Function.Name + "</name>")
+				if tool.Function.Description != "" {
+					sb.WriteString("\n<description>" + tool.Function.Description + "</description>")
+				}
+				sb.WriteString("\n<parameters>")
+
+				for name, prop := range tool.Function.Parameters.Properties {
+					sb.WriteString("\n<parameter>")
+					sb.WriteString("\n<name>" + name + "</name>")
+
+					if len(prop.Type) > 0 {
+						// TODO(!!!)(drifkin): we should match the reference implementation for
+						// more complex types here instead of using this format
+						sb.WriteString("\n<type>" + prop.ToTypeScriptType() + "</type>")
+					}
+
+					if prop.Description != "" {
+						sb.WriteString("\n<description>" + prop.Description + "</description>")
+					}
+
+					// Render any additional keys not already handled
+					handledKeys := map[string]bool{
+						"type":        true,
+						"description": true,
+					}
+					sb.WriteString(renderAdditionalKeys(prop, handledKeys))
+
+					sb.WriteString("\n</parameter>")
+				}
+
+				// Render extra keys for parameters (everything except 'type' and 'properties')
+				paramHandledKeys := map[string]bool{
+					"type":       true,
+					"properties": true,
+				}
+				sb.WriteString(renderAdditionalKeys(tool.Function.Parameters, paramHandledKeys))
+
+				sb.WriteString("\n</parameters>")
+				sb.WriteString("\n</function>")
+			}
+			sb.WriteString("\n</tools>")
+			sb.WriteString("\n\nIf you choose to call a function ONLY reply in the following format with NO suffix:\n\n<tool_call>\n<function=example_function_name>\n<parameter=example_parameter_1>\nvalue_1\n</parameter>\n<parameter=example_parameter_2>\nThis is the value for the second parameter\nthat can span\nmultiple lines\n</parameter>\n</function>\n</tool_call>\n\n<IMPORTANT>\nReminder:\n- Function calls MUST follow the specified format: an inner <function=...></function> block must be nested within <tool_call></tool_call> XML tags\n- Required parameters MUST be specified\n- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after\n- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls\n</IMPORTANT>")
+		}
+		sb.WriteString(imEndTag + "\n")
+	}
+
+	for i, message := range filteredMessages {
+		lastMessage := i == len(filteredMessages)-1
+		prefill := lastMessage && message.Role == "assistant"
+		switch message.Role {
+		case "assistant":
+			if len(message.ToolCalls) > 0 {
+				sb.WriteString(imStartTag + "assistant\n")
+				if message.Content != "" {
+					sb.WriteString(message.Content + "\n")
+				}
+				for _, toolCall := range message.ToolCalls {
+					sb.WriteString("\n<tool_call>\n<function=" + toolCall.Function.Name + ">")
+					for name, value := range toolCall.Function.Arguments {
+						valueStr := formatToolCallArgument(value)
+						sb.WriteString("\n<parameter=" + name + ">\n" + valueStr + "\n</parameter>")
+					}
+					sb.WriteString("\n</function>\n</tool_call>")
+				}
+				sb.WriteString("<|im_end|>\n")
+			} else {
+				sb.WriteString(imStartTag + "assistant\n")
+				sb.WriteString(message.Content)
+				if !prefill {
+					sb.WriteString(imEndTag + "\n")
+				}
+			}
+		case "tool":
+			// consecutive tool responses should share a single `user` block, but
+			// each gets its own `<tool_response>` tags
+
+			// only start a new user block if this is the first tool response
+			if i == 0 || filteredMessages[i-1].Role != "tool" {
+				sb.WriteString(imStartTag + "user\n")
+			}
+
+			sb.WriteString("<tool_response>\n")
+			sb.WriteString(message.Content)
+			sb.WriteString("\n</tool_response>\n")
+
+			// close the user block only if this is the last tool response
+			if i == len(filteredMessages)-1 || filteredMessages[i+1].Role != "tool" {
+				sb.WriteString(imEndTag + "\n")
+			}
+		default:
+			sb.WriteString(imStartTag + message.Role + "\n")
+			sb.WriteString(message.Content)
+			sb.WriteString(imEndTag + "\n")
+		}
+
+		if lastMessage && !prefill {
+			sb.WriteString(imStartTag + "assistant\n")
+		}
+	}
+
+	return sb.String(), nil
+}
+
+func formatToolCallArgument(value any) string {
+	if value == nil {
+		return "null"
+	}
+
+	switch v := value.(type) {
+	case string:
+		return v
+	case []byte:
+		return string(v)
+	}
+
+	if reflect.TypeOf(value) != nil {
+		kind := reflect.TypeOf(value).Kind()
+		if kind == reflect.Map || kind == reflect.Slice || kind == reflect.Array {
+			if marshalled, err := json.Marshal(value); err == nil {
+				return string(marshalled)
+			}
+		}
+	}
+
+	return fmt.Sprintf("%v", value)
+}
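A minimal sketch of calling the renderer directly, mirroring the "basic" test case below; the `main` wrapper is illustrative, while `Qwen3CoderRenderer` and `api.Message` come from the diff:

```go
package main

import (
	"fmt"

	"github.com/ollama/ollama/api"
	"github.com/ollama/ollama/model/renderers"
)

func main() {
	msgs := []api.Message{
		{Role: "system", Content: "You are a helpful assistant."},
		{Role: "user", Content: "Hello!"},
	}

	// No tools and no think value: renders the system block, the user turn,
	// and a trailing assistant header for the model to complete.
	prompt, err := renderers.Qwen3CoderRenderer(msgs, nil, nil)
	if err != nil {
		panic(err)
	}
	fmt.Print(prompt)
}
```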
diff --git a/model/renderers/qwen3coder_test.go b/model/renderers/qwen3coder_test.go
new file mode 100644
index 000000000..4aaa066d6
--- /dev/null
+++ b/model/renderers/qwen3coder_test.go
@@ -0,0 +1,338 @@
+package renderers
+
+import (
+	"testing"
+
+	"github.com/google/go-cmp/cmp"
+	"github.com/ollama/ollama/api"
+)
+
+func TestQwen3CoderRenderer(t *testing.T) {
+	tests := []struct {
+		name     string
+		msgs     []api.Message
+		tools    []api.Tool
+		expected string
+	}{
+		{
+			name: "basic",
+			msgs: []api.Message{
+				{Role: "system", Content: "You are a helpful assistant."},
+				{Role: "user", Content: "Hello, how are you?"},
+			},
+			expected: `<|im_start|>system
+You are a helpful assistant.<|im_end|>
+<|im_start|>user
+Hello, how are you?<|im_end|>
+<|im_start|>assistant
+`,
+		},
+		{
+			name: "with tools and response",
+			msgs: []api.Message{
+				{Role: "system", Content: "You are a helpful assistant with access to tools."},
+				{Role: "user", Content: "What is the weather like in San Francisco?"},
+				{
+					Role:    "assistant",
+					Content: "I'll check the weather in San Francisco for you.",
+					ToolCalls: []api.ToolCall{
+						{
+							Function: api.ToolCallFunction{
+								Name: "get_weather",
+								Arguments: map[string]any{
+									"unit": "fahrenheit",
+								},
+							},
+						},
+					},
+				},
+				{Role: "tool", Content: "{\"location\": \"San Francisco, CA\", \"temperature\": 68, \"condition\": \"partly cloudy\", \"humidity\": 65, \"wind_speed\": 12}", ToolName: "get_weather"},
+				{Role: "user", Content: "That sounds nice! What about New York?"},
+			},
+			tools: []api.Tool{
+				{Function: api.ToolFunction{
+					Name:        "get_weather",
+					Description: "Get the current weather in a given location",
+					Parameters: api.ToolFunctionParameters{
+						Required: []string{"unit"},
+						Properties: map[string]api.ToolProperty{
+							"unit": {Type: api.PropertyType{"string"}, Enum: []any{"celsius", "fahrenheit"}, Description: "The unit of temperature"},
+							// TODO(drifkin): add multiple params back once we have predictable
+							// order via some sort of ordered map type (see
+							// )
+							/*
+								"location": {Type: api.PropertyType{"string"}, Description: "The city and state, e.g. San Francisco, CA"},
+							*/
+						},
+					},
+				}},
+			},
+			expected: `<|im_start|>system
+You are a helpful assistant with access to tools.
+
+# Tools
+
+You have access to the following functions:
+
+<tools>
+<function>
+<name>get_weather</name>
+<description>Get the current weather in a given location</description>
+<parameters>
+<parameter>
+<name>unit</name>
+<type>string</type>
+<description>The unit of temperature</description>
+<enum>["celsius","fahrenheit"]</enum>
+</parameter>
+<required>["unit"]</required>
+</parameters>
+</function>
+</tools>
+
+If you choose to call a function ONLY reply in the following format with NO suffix:
+
+<tool_call>
+<function=example_function_name>
+<parameter=example_parameter_1>
+value_1
+</parameter>
+<parameter=example_parameter_2>
+This is the value for the second parameter
+that can span
+multiple lines
+</parameter>
+</function>
+</tool_call>
+
+<IMPORTANT>
+Reminder:
+- Function calls MUST follow the specified format: an inner <function=...></function> block must be nested within <tool_call></tool_call> XML tags
+- Required parameters MUST be specified
+- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after
+- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls
+</IMPORTANT><|im_end|>
+<|im_start|>user
+What is the weather like in San Francisco?<|im_end|>
+<|im_start|>assistant
+I'll check the weather in San Francisco for you.
+
+<tool_call>
+<function=get_weather>
+<parameter=unit>
+fahrenheit
+</parameter>
+</function>
+</tool_call><|im_end|>
+<|im_start|>user
+<tool_response>
+{"location": "San Francisco, CA", "temperature": 68, "condition": "partly cloudy", "humidity": 65, "wind_speed": 12}
+</tool_response>
+<|im_end|>
+<|im_start|>user
+That sounds nice! What about New York?<|im_end|>
+<|im_start|>assistant
+`,
+		},
+		{
+			name: "parallel tool calls",
+			msgs: []api.Message{
+				{Role: "system", Content: "You are a helpful assistant with access to tools."},
+				{Role: "user", Content: "call double(1) and triple(2)"},
+				{Role: "assistant", Content: "I'll call double(1) and triple(2) for you.", ToolCalls: []api.ToolCall{
+					{Function: api.ToolCallFunction{Name: "double", Arguments: map[string]any{"number": "1"}}},
+					{Function: api.ToolCallFunction{Name: "triple", Arguments: map[string]any{"number": "2"}}},
+				}},
+				{Role: "tool", Content: "{\"number\": 2}", ToolName: "double"},
+				{Role: "tool", Content: "{\"number\": 6}", ToolName: "triple"},
+			},
+			tools: []api.Tool{
+				{Function: api.ToolFunction{Name: "double", Description: "Double a number", Parameters: api.ToolFunctionParameters{Properties: map[string]api.ToolProperty{
+					"number": {Type: api.PropertyType{"string"}, Description: "The number to double"},
+				}}}},
+				{Function: api.ToolFunction{Name: "triple", Description: "Triple a number", Parameters: api.ToolFunctionParameters{Properties: map[string]api.ToolProperty{
+					"number": {Type: api.PropertyType{"string"}, Description: "The number to triple"},
+				}}}},
+			},
+			expected: `<|im_start|>system
+You are a helpful assistant with access to tools.
+
+# Tools
+
+You have access to the following functions:
+
+<tools>
+<function>
+<name>double</name>
+<description>Double a number</description>
+<parameters>
+<parameter>
+<name>number</name>
+<type>string</type>
+<description>The number to double</description>
+</parameter>
+</parameters>
+</function>
+<function>
+<name>triple</name>
+<description>Triple a number</description>
+<parameters>
+<parameter>
+<name>number</name>
+<type>string</type>
+<description>The number to triple</description>
+</parameter>
+</parameters>
+</function>
+</tools>
+
+If you choose to call a function ONLY reply in the following format with NO suffix:
+
+<tool_call>
+<function=example_function_name>
+<parameter=example_parameter_1>
+value_1
+</parameter>
+<parameter=example_parameter_2>
+This is the value for the second parameter
+that can span
+multiple lines
+</parameter>
+</function>
+</tool_call>
+
+<IMPORTANT>
+Reminder:
+- Function calls MUST follow the specified format: an inner <function=...></function> block must be nested within <tool_call></tool_call> XML tags
+- Required parameters MUST be specified
+- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after
+- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls
+</IMPORTANT><|im_end|>
+<|im_start|>user
+call double(1) and triple(2)<|im_end|>
+<|im_start|>assistant
+I'll call double(1) and triple(2) for you.
+
+<tool_call>
+<function=double>
+<parameter=number>
+1
+</parameter>
+</function>
+</tool_call>
+<tool_call>
+<function=triple>
+<parameter=number>
+2
+</parameter>
+</function>
+</tool_call><|im_end|>
+<|im_start|>user
+<tool_response>
+{"number": 2}
+</tool_response>
+<tool_response>
+{"number": 6}
+</tool_response>
+<|im_end|>
+<|im_start|>assistant
+`,
+		},
+		{
+			name: "prefill",
+			msgs: []api.Message{
+				{Role: "system", Content: "You are a helpful assistant."},
+				{Role: "user", Content: "Tell me something interesting."},
+				{Role: "assistant", Content: "I'll tell you something interesting about cats"},
+			},
+			expected: `<|im_start|>system
+You are a helpful assistant.<|im_end|>
+<|im_start|>user
+Tell me something interesting.<|im_end|>
+<|im_start|>assistant
+I'll tell you something interesting about cats`,
+		},
+		{
+			name: "complex tool call arguments should remain json encoded",
+			msgs: []api.Message{
+				{Role: "user", Content: "call tool"},
+				{Role: "assistant", ToolCalls: []api.ToolCall{
+					{Function: api.ToolCallFunction{
+						Name: "echo",
+						Arguments: map[string]any{
+							"payload": map[string]any{"foo": "bar"},
+						},
+					}},
+				}},
+				{Role: "tool", Content: "{\"payload\": {\"foo\": \"bar\"}}", ToolName: "echo"},
+			},
+			expected: `<|im_start|>user
+call tool<|im_end|>
+<|im_start|>assistant
+
+<tool_call>
+<function=echo>
+<parameter=payload>
+{"foo":"bar"}
+</parameter>
+</function>
+</tool_call><|im_end|>
+<|im_start|>user
+<tool_response>
+{"payload": {"foo": "bar"}}
+</tool_response>
+<|im_end|>
+<|im_start|>assistant
+`,
+		},
+	}
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			rendered, err := Qwen3CoderRenderer(tt.msgs, tt.tools, nil)
+			if err != nil {
+				t.Fatal(err)
+			}
+			if diff := cmp.Diff(rendered, tt.expected); diff != "" {
+				t.Errorf("mismatch (-got +want):\n%s", diff)
+			}
+		})
+	}
+}
+
+func TestFormatToolCallArgument(t *testing.T) {
+	tests := []struct {
+		name     string
+		arg      any
+		expected string
+	}{
+		{
+			name: "string",
+			arg:  "foo",
+			// notice no quotes around the string
+			expected: "foo",
+		},
+		{
+			name:     "map",
+			arg:      map[string]any{"foo": "bar"},
+			expected: "{\"foo\":\"bar\"}",
+		},
+		{
+			name:     "number",
+			arg:      1,
+			expected: "1",
+		},
+		{
+			name:     "boolean",
+			arg:      true,
+			expected: "true",
+		},
+	}
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			got := formatToolCallArgument(tt.arg)
+			if got != tt.expected {
+				t.Errorf("formatToolCallArgument(%v) = %v, want %v", tt.arg, got, tt.expected)
+			}
+		})
+	}
+}
diff --git a/model/renderers/renderer.go b/model/renderers/renderer.go
new file mode 100644
index 000000000..2dfb51e49
--- /dev/null
+++ b/model/renderers/renderer.go
@@ -0,0 +1,26 @@
+package renderers
+
+import (
+	"fmt"
+
+	"github.com/ollama/ollama/api"
+)
+
+type rendererFunc func([]api.Message, []api.Tool, *api.ThinkValue) (string, error)
+
+func RenderWithRenderer(name string, msgs []api.Message, tools []api.Tool, think *api.ThinkValue) (string, error) {
+	renderer := rendererForName(name)
+	if renderer == nil {
+		return "", fmt.Errorf("unknown renderer %q", name)
+	}
+	return renderer(msgs, tools, think)
+}
+
+func rendererForName(name string) rendererFunc {
+	switch name {
+	case "qwen3-coder":
+		return Qwen3CoderRenderer
+	default:
+		return nil
+	}
+}
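Callers normally go through the registry rather than a concrete renderer, so unknown names surface as errors instead of a silent fallback. A small illustrative sketch (the `main` wrapper is hypothetical; `RenderWithRenderer` is from the diff):

```go
package main

import (
	"fmt"

	"github.com/ollama/ollama/api"
	"github.com/ollama/ollama/model/renderers"
)

func main() {
	msgs := []api.Message{{Role: "user", Content: "hi"}}

	// The renderer name is looked up via rendererForName; "qwen3-coder" is
	// currently the only registered name besides the nil default.
	prompt, err := renderers.RenderWithRenderer("qwen3-coder", msgs, nil, nil)
	if err != nil {
		fmt.Println("render failed:", err)
		return
	}
	fmt.Print(prompt)
}
```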
diff --git a/model/sentencepiece.go b/model/sentencepiece.go
index 827ce00d9..db07beee9 100644
--- a/model/sentencepiece.go
+++ b/model/sentencepiece.go
@@ -12,18 +12,18 @@ import (
 
 const spmWhitespaceSep = "▁"
 
-type SentencePieceModel struct {
+type SentencePiece struct {
 	maxTokenLen int
 	vocab       *Vocabulary
 }
 
-var _ TextProcessor = (*SentencePieceModel)(nil)
+var _ TextProcessor = (*SentencePiece)(nil)
 
-func (spm SentencePieceModel) Vocabulary() *Vocabulary {
+func (spm SentencePiece) Vocabulary() *Vocabulary {
 	return spm.vocab
 }
 
-func NewSentencePieceModel(vocab *Vocabulary) SentencePieceModel {
+func NewSentencePiece(vocab *Vocabulary) SentencePiece {
 	logutil.Trace("Tokens", "num tokens", len(vocab.Values), "vals", vocab.Values[:5], "scores", vocab.Scores[:5], "types", vocab.Types[:5])
 
 	counter := map[int]int{}
@@ -42,17 +42,17 @@ func NewSentencePieceModel(vocab *Vocabulary) SentencePieceModel {
 		"user defined", counter[TOKEN_TYPE_USER_DEFINED], "unused", counter[TOKEN_TYPE_UNUSED],
 		"byte", counter[TOKEN_TYPE_BYTE], "max token len", maxTokenLen)
 
-	return SentencePieceModel{
+	return SentencePiece{
 		maxTokenLen: maxTokenLen,
 		vocab:       vocab,
 	}
 }
 
-func (spm SentencePieceModel) Is(id int32, special Special) bool {
+func (spm SentencePiece) Is(id int32, special Special) bool {
 	return spm.vocab.Is(id, special)
 }
 
-func (spm SentencePieceModel) Encode(s string, addSpecial bool) ([]int32, error) {
+func (spm SentencePiece) Encode(s string, addSpecial bool) ([]int32, error) {
 	fragments := []fragment{{value: s}}
 	for _, special := range spm.vocab.SpecialVocabulary() {
 		id := spm.vocab.Encode(special)
@@ -218,7 +218,7 @@ func (q *queue) Pop() interface{} {
 	return item
 }
 
-func (spm SentencePieceModel) Decode(ids []int32) (string, error) {
+func (spm SentencePiece) Decode(ids []int32) (string, error) {
 	var sb strings.Builder
 	for _, id := range ids {
 		data := spm.vocab.Decode(id)
diff --git a/model/sentencepiece_test.go b/model/sentencepiece_test.go
index 50ac26787..8f4570c17 100644
--- a/model/sentencepiece_test.go
+++ b/model/sentencepiece_test.go
@@ -12,7 +12,7 @@ import (
 	"github.com/ollama/ollama/convert/sentencepiece"
 )
 
-func loadSentencePieceVocab(t *testing.T) SentencePieceModel {
+func loadSentencePieceVocab(t *testing.T) SentencePiece {
 	t.Helper()
 
 	bts, err := os.ReadFile(filepath.Join("testdata", "gemma2", "tokenizer.model"))
@@ -45,7 +45,7 @@ func loadSentencePieceVocab(t *testing.T) SentencePieceModel {
 		}
 	}
 
-	return NewSentencePieceModel(&v)
+	return NewSentencePiece(&v)
 }
 
 func TestSentencePieceEncode(t *testing.T) {
@@ -115,7 +115,7 @@ func TestSentencePieceEncode(t *testing.T) {
 	})
 }
 
-func TestSentencePieceModelDecodeByteTokens(t *testing.T) {
+func TestSentencePieceDecodeByteTokens(t *testing.T) {
 	vocab := &Vocabulary{
 		Values: []string{
 			"normal",
@@ -134,7 +134,7 @@ func TestSentencePieceModelDecodeByteTokens(t *testing.T) {
 		Scores: []float32{0, 0, 0, 0, 0},
 	}
 
-	spm := NewSentencePieceModel(vocab)
+	spm := NewSentencePiece(vocab)
 
 	tests := []struct {
 		name string
diff --git a/model/wordpiece.go b/model/wordpiece.go
new file mode 100644
index 000000000..e8d5e848a
--- /dev/null
+++ b/model/wordpiece.go
@@ -0,0 +1,167 @@
+package model
+import (
+	"fmt"
+	"iter"
+	"strings"
+	"unicode"
+
+	"github.com/ollama/ollama/logutil"
+)
+
+type WordPiece struct {
+	vocab *Vocabulary
+}
+
+// ggmlPrefix is the prefix used by GGML vocabularies to indicate word boundaries.
+// this differs from original word piece which uses "##" to indicate subwords.
+const ggmlPrefix = "▁"
+
+var wordPieceReplacer = strings.NewReplacer(
+	" .", ".",
+	" ?", "?",
+	" !", "!",
+	" ,", ",",
+	" ' ", "'",
+	" n't", "n't",
+	" 'm", "'m",
+	" do not", " don't",
+	" 's", "'s",
+	" 've", "'ve",
+	" 're", "'re",
+)
+
+// Decode implements TextProcessor.
+func (wpm WordPiece) Decode(ids []int32) (string, error) {
+	var sb strings.Builder
+	for i, id := range ids {
+		if id < 0 || int(id) >= len(wpm.vocab.Values) {
+			return "", fmt.Errorf("invalid token id: %d", id)
+		}
+
+		var separator string
+		piece := wpm.vocab.Values[id]
+		if i > 0 &&
+			(strings.HasPrefix(piece, ggmlPrefix) ||
+				(strings.HasPrefix(piece, "[") && strings.HasSuffix(piece, "]"))) {
+			separator = " "
+		}
+
+		sb.WriteString(wordPieceReplacer.Replace(separator + strings.TrimPrefix(piece, ggmlPrefix)))
+	}
+
+	return sb.String(), nil
+}
+
+// words splits a string into words, treating CJK characters as separate words.
+// TODO: this is specifically for BERT and may need to be adjusted or refactored for other models.
+func (wpm WordPiece) words(s string) iter.Seq[string] {
+	return func(yield func(string) bool) {
+		runes := make([]rune, 0, len(s)*3)
+		for _, r := range s {
+			switch {
+			case r >= 0x4E00 && r <= 0x9FFF,
+				r >= 0x3400 && r <= 0x4DBF,
+				r >= 0x20000 && r <= 0x2A6DF,
+				r >= 0x2A700 && r <= 0x2B73F,
+				r >= 0x2B740 && r <= 0x2B81F,
+				r >= 0x2B820 && r <= 0x2CEAF,
+				r >= 0xF900 && r <= 0xFAFF,
+				r >= 0x2F800 && r <= 0x2FA1F:
+				runes = append(runes, ' ', r, ' ')
+			default:
+				runes = append(runes, r)
+			}
+		}
+
+		for w := range strings.FieldsFuncSeq(string(runes), unicode.IsSpace) {
+			// split on punctuation, but keep it
+			var start int
+			for start < len(w) {
+				end := strings.IndexFunc(w[start:], unicode.IsPunct)
+				if end < 0 {
+					end = len(w) - start
+				} else if end == 0 {
+					end = 1
+				}
+
+				if !yield(w[start : start+end]) {
+					return
+				}
+
+				start += end
+			}
+		}
+	}
+}
+
+// Encode implements TextProcessor.
+func (wpm WordPiece) Encode(s string, addSpecial bool) ([]int32, error) {
+	var ids []int32
+
+	// TODO: use [UNK] from config
+	unk := wpm.vocab.Encode("[UNK]")
+	for word := range wpm.words(s) {
+		var start int
+		var pieces []int32
+		for start < len(word) {
+			end := len(word)
+
+			var piece int32
+			for start < end {
+				subword := word[start:end]
+				if start == 0 {
+					subword = ggmlPrefix + subword
+				}
+
+				// TODO: some models might not want [ToLower]
+				piece = wpm.vocab.Encode(strings.ToLower(subword))
+				if piece >= 0 {
+					break
+				}
+
+				end--
+			}
+
+			if piece < 0 {
+				// no subword matched; emit [UNK] for the whole word
+				pieces = pieces[:0]
+				break
+			}
+
+			pieces = append(pieces, piece)
+			start = end
+		}
+
+		if len(pieces) > 0 {
+			ids = append(ids, pieces...)
+		} else {
+			ids = append(ids, unk)
+		}
+	}
+
+	if addSpecial && len(ids) > 0 {
+		ids = wpm.vocab.addSpecials(ids)
+	}
+
+	logutil.Trace("encoded", "string", s, "ids", ids)
+	return ids, nil
+}
+
+// Is implements TextProcessor.
+func (wpm WordPiece) Is(id int32, special Special) bool {
+	return wpm.vocab.Is(id, special)
+}
+
+// Vocabulary implements TextProcessor.
+func (wpm WordPiece) Vocabulary() *Vocabulary {
+	return wpm.vocab
+}
+
+var _ TextProcessor = (*WordPiece)(nil)
+
+func NewWordPiece(vocab *Vocabulary) WordPiece {
+	return WordPiece{
+		vocab: vocab,
+	}
+}
diff --git a/model/wordpiece_test.go b/model/wordpiece_test.go
new file mode 100644
index 000000000..258fbffcb
--- /dev/null
+++ b/model/wordpiece_test.go
@@ -0,0 +1,51 @@
+package model
+
+import (
+	"slices"
+	"testing"
+
+	"github.com/google/go-cmp/cmp"
+)
+
+func TestWordPiece(t *testing.T) {
+	wpm := NewWordPiece(
+		&Vocabulary{
+			Values: []string{"[UNK]", "[CLS]", "[SEP]", "▁hello", "▁world", "s", "▁!", "▁@", "▁#"},
+			AddBOS: true,
+			AddEOS: true,
+			BOS:    []int32{1},
+			EOS:    []int32{2},
+		})
+
+	ids, err := wpm.Encode("Hello world!", true)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	if diff := cmp.Diff([]int32{1, 3, 4, 6, 2}, ids); diff != "" {
+		t.Errorf("unexpected ids (-want +got):\n%s", diff)
+	}
+
+	words, err := wpm.Decode(ids)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	if diff := cmp.Diff("[CLS] hello world! [SEP]", words); diff != "" {
+		t.Errorf("unexpected words (-want +got):\n%s", diff)
+	}
+}
+
+func TestWordPieceWords(t *testing.T) {
+	var wpm WordPiece
+
+	basic := slices.Collect(wpm.words("Hey friend! How are you?!?"))
+	if diff := cmp.Diff([]string{"Hey", "friend", "!", "How", "are", "you", "?", "!", "?"}, basic); diff != "" {
+		t.Errorf("unexpected words (-want +got):\n%s", diff)
+	}
+
+	chinese := slices.Collect(wpm.words("野口里佳 Noguchi Rika"))
+	if diff := cmp.Diff([]string{"野", "口", "里", "佳", "Noguchi", "Rika"}, chinese); diff != "" {
+		t.Errorf("unexpected words (-want +got):\n%s", diff)
+	}
+}
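A minimal usage sketch of the new WordPiece processor, assuming vocabulary lookups behave as in the unit test above (missing tokens encode to a negative id); the toy vocabulary and `main` wrapper are illustrative only:

```go
package main

import (
	"fmt"

	"github.com/ollama/ollama/model"
)

func main() {
	// A toy vocabulary in the GGML convention, where "▁" marks a word start;
	// "s" (without the prefix) can only match as a subword continuation.
	wpm := model.NewWordPiece(&model.Vocabulary{
		Values: []string{"[UNK]", "[CLS]", "[SEP]", "▁hello", "▁world", "s", "▁!"},
		AddBOS: true,
		AddEOS: true,
		BOS:    []int32{1},
		EOS:    []int32{2},
	})

	// "worlds" has no whole-word match, so the longest-prefix loop should
	// emit "▁world" followed by the "s" continuation piece.
	ids, err := wpm.Encode("Hello worlds!", true)
	if err != nil {
		panic(err)
	}
	fmt.Println(ids) // expected: [1 3 4 5 6 2], i.e. [CLS] ▁hello ▁world s ▁! [SEP]
}
```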
diff --git a/openai/openai.go b/openai/openai.go
index b6a8a95e2..7ef5ac6de 100644
--- a/openai/openai.go
+++ b/openai/openai.go
@@ -105,16 +105,18 @@ type ChatCompletionRequest struct {
 	Tools            []api.Tool `json:"tools"`
 	Reasoning        *Reasoning `json:"reasoning,omitempty"`
 	ReasoningEffort  *string    `json:"reasoning_effort,omitempty"`
+	DebugRenderOnly  bool       `json:"_debug_render_only"`
 }
 
 type ChatCompletion struct {
-	Id                string   `json:"id"`
-	Object            string   `json:"object"`
-	Created           int64    `json:"created"`
-	Model             string   `json:"model"`
-	SystemFingerprint string   `json:"system_fingerprint"`
-	Choices           []Choice `json:"choices"`
-	Usage             Usage    `json:"usage,omitempty"`
+	Id                string         `json:"id"`
+	Object            string         `json:"object"`
+	Created           int64          `json:"created"`
+	Model             string         `json:"model"`
+	SystemFingerprint string         `json:"system_fingerprint"`
+	Choices           []Choice       `json:"choices"`
+	Usage             Usage          `json:"usage,omitempty"`
+	DebugInfo         *api.DebugInfo `json:"_debug_info,omitempty"`
 }
 
 type ChatCompletionChunk struct {
@@ -141,6 +143,7 @@ type CompletionRequest struct {
 	Temperature *float32 `json:"temperature"`
 	TopP        float32  `json:"top_p"`
 	Suffix      string   `json:"suffix"`
+	DebugRenderOnly bool `json:"_debug_render_only"`
 }
 
 type Completion struct {
@@ -273,8 +276,8 @@ func toChatCompletion(id string, r api.ChatResponse) ChatCompletion {
 			}
 			return nil
 		}(r.DoneReason),
-		}},
-		Usage: toUsage(r),
+		}}, Usage: toUsage(r),
+		DebugInfo: r.DebugInfo,
 	}
 }
@@ -568,13 +571,14 @@ func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
 	}
 
 	return &api.ChatRequest{
-		Model:    r.Model,
-		Messages: messages,
-		Format:   format,
-		Options:  options,
-		Stream:   &r.Stream,
-		Tools:    r.Tools,
-		Think:    think,
+		Model:           r.Model,
+		Messages:        messages,
+		Format:          format,
+		Options:         options,
+		Stream:          &r.Stream,
+		Tools:           r.Tools,
+		Think:           think,
+		DebugRenderOnly: r.DebugRenderOnly,
 	}, nil
 }
 
@@ -648,11 +652,12 @@ func fromCompleteRequest(r CompletionRequest) (api.GenerateRequest, error) {
 	}
 
 	return api.GenerateRequest{
-		Model:   r.Model,
-		Prompt:  r.Prompt,
-		Options: options,
-		Stream:  &r.Stream,
-		Suffix:  r.Suffix,
+		Model:           r.Model,
+		Prompt:          r.Prompt,
+		Options:         options,
+		Stream:          &r.Stream,
+		Suffix:          r.Suffix,
+		DebugRenderOnly: r.DebugRenderOnly,
 	}, nil
 }
diff --git a/parser/parser.go b/parser/parser.go
index e080f1bb7..c2e8f981f 100644
--- a/parser/parser.go
+++ b/parser/parser.go
@@ -100,6 +100,10 @@ func (f Modelfile) CreateRequest(relativeDir string) (*api.CreateRequest, error) {
 		req.System = c.Args
 	case "license":
 		licenses = append(licenses, c.Args)
+	case "renderer":
+		req.Renderer = c.Args
+	case "parser":
+		req.Parser = c.Args
 	case "message":
 		role, msg, _ := strings.Cut(c.Args, ": ")
 		messages = append(messages, api.Message{Role: role, Content: msg})
@@ -320,7 +324,7 @@ func (c Command) String() string {
 	switch c.Name {
 	case "model":
 		fmt.Fprintf(&sb, "FROM %s", c.Args)
-	case "license", "template", "system", "adapter":
+	case "license", "template", "system", "adapter", "renderer", "parser":
 		fmt.Fprintf(&sb, "%s %s", strings.ToUpper(c.Name), quote(c.Args))
 	case "message":
 		role, message, _ := strings.Cut(c.Args, ": ")
@@ -346,7 +350,7 @@ const (
 var (
 	errMissingFrom        = errors.New("no FROM line")
 	errInvalidMessageRole = errors.New("message role must be one of \"system\", \"user\", or \"assistant\"")
-	errInvalidCommand     = errors.New("command must be one of \"from\", \"license\", \"template\", \"system\", \"adapter\", \"parameter\", or \"message\"")
+	errInvalidCommand     = errors.New("command must be one of \"from\", \"license\", \"template\", \"system\", \"adapter\", \"renderer\", \"parser\", \"parameter\", or \"message\"")
 )
 
 type ParserError struct {
@@ -606,7 +610,7 @@ func isValidMessageRole(role string) bool {
 
 func isValidCommand(cmd string) bool {
 	switch strings.ToLower(cmd) {
-	case "from", "license", "template", "system", "adapter", "parameter", "message":
+	case "from", "license", "template", "system", "adapter", "renderer", "parser", "parameter", "message":
 		return true
 	default:
 		return false
diff --git a/parser/parser_test.go b/parser/parser_test.go
index 7d5a808ba..1524e890a 100644
--- a/parser/parser_test.go
+++ b/parser/parser_test.go
@@ -198,6 +198,34 @@ BADCOMMAND param1 value1
 	}
 }
 
+func TestParseFileRenderer(t *testing.T) {
+	input := `
+FROM foo
+RENDERER renderer1
+`
+
+	reader := strings.NewReader(input)
+
+	modelfile, err := ParseFile(reader)
+	require.NoError(t, err)
+
+	assert.Equal(t, []Command{{Name: "model", Args: "foo"}, {Name: "renderer", Args: "renderer1"}}, modelfile.Commands)
+}
+
+func TestParseFileParser(t *testing.T) {
+	input := `
+FROM foo
+PARSER parser1
+`
+
+	reader := strings.NewReader(input)
+
+	modelfile, err := ParseFile(reader)
+	require.NoError(t, err)
+
+	assert.Equal(t, []Command{{Name: "model", Args: "foo"}, {Name: "parser", Args: "parser1"}}, modelfile.Commands)
+}
+
 func TestParseFileMessages(t *testing.T) {
 	cases := []struct {
 		input string
diff --git a/parser/token_parser.go b/parser/token_parser.go
deleted file mode 100644
index 812458299..000000000
--- a/parser/token_parser.go
+++ /dev/null
@@ -1,126 +0,0 @@
-package parser
-
-import (
-	"encoding/json"
-	"errors"
-	"strings"
-
-	"github.com/ollama/ollama/api"
-	"github.com/ollama/ollama/harmony"
-)
-
-type TokenParserType int
-
-const (
-	TokenParserTypeDefault TokenParserType = iota
-	TokenParserTypeHarmony
-)
-
-type TokenParser struct {
-	messageHandler MessageHandler
-	parserEngine   ParserInternals
- toolParser ToolParser - lastToken string - tokenRepeat int - repeatLimit int -} - -const defaultTokenRepeatLimit = 30 - -type MessageHandler interface { - AddContent(token string) (content, thinking string, toolContent string) -} - -type ParserInternals interface { - AddImplicitStartOrPrefill(prefillString string) -} - -type ToolParser interface { - Add(token string) - Drain() (toolName *string, toolContent string) -} - -// Default implementation for the TokenParser interface as a no-op passthrough -type defaultMessageHandler struct{} - -func (defaultMessageHandler) AddContent(token string) (string, string, string) { - return token, "", "" -} - -type defaultEngine struct{} - -func (defaultEngine) AddImplicitStartOrPrefill(prefillString string) {} - -type defaultToolParser struct{} - -func (defaultToolParser) Add(token string) {} - -func (defaultToolParser) Drain() (*string, string) { return nil, "" } - -func NewTokenParser(parserType TokenParserType, prefillString string) TokenParser { - switch parserType { - case TokenParserTypeHarmony: - harmonyMessageHandler := harmony.NewHarmonyMessageHandler() - harmonyMessageHandler.HarmonyParser.AddImplicitStartOrPrefill(prefillString) - return TokenParser{ - messageHandler: harmonyMessageHandler, - parserEngine: harmonyMessageHandler.HarmonyParser, - toolParser: harmonyMessageHandler.ToolParser, - repeatLimit: defaultTokenRepeatLimit, - } - - default: - return TokenParser{ - messageHandler: defaultMessageHandler{}, - parserEngine: defaultEngine{}, - toolParser: defaultToolParser{}, - repeatLimit: 30, - } - } -} - -func (p *TokenParser) AddContent(token string) (string, string, error) { - if p.repeatLimitReached(token) { - return "", "", errors.New("token repeat limit reached") - } - content, thinking, toolContent := p.messageHandler.AddContent(token) - p.toolParser.Add(toolContent) - return content, thinking, nil -} - -// repeatLimitReached updates repeat counters and returns true if the repeat limit is reached. 
-func (p *TokenParser) repeatLimitReached(token string) bool {
-	if p == nil {
-		return false
-	}
-	trimmed := strings.TrimSpace(token)
-	if trimmed == p.lastToken {
-		p.tokenRepeat++
-	} else {
-		p.tokenRepeat = 0
-	}
-	p.lastToken = trimmed
-
-	return p.tokenRepeat >= p.repeatLimit
-}
-
-// TODO: update to work with multiple toolcalls - unmarshalling should also happen on parser level
-func (p *TokenParser) Drain() []api.ToolCall {
-	toolName, toolContent := p.toolParser.Drain()
-	if toolName != nil {
-		*toolName = strings.TrimPrefix(*toolName, "functions.")
-		var args api.ToolCallFunctionArguments
-		if err := json.Unmarshal([]byte(toolContent), &args); err != nil {
-			return nil
-		}
-		return []api.ToolCall{
-			{
-				Function: api.ToolCallFunction{
-					Name: *toolName,
-					Arguments: args,
-				},
-			},
-		}
-	}
-	return nil
-}
diff --git a/runner/ollamarunner/runner.go b/runner/ollamarunner/runner.go
index 676e5186f..480cfc19b 100644
--- a/runner/ollamarunner/runner.go
+++ b/runner/ollamarunner/runner.go
@@ -11,7 +11,6 @@ import (
 	"image"
 	"log"
 	"log/slog"
-	"math"
 	"net"
 	"net/http"
 	"os"
@@ -32,9 +31,9 @@ import (
 	"github.com/ollama/ollama/llm"
 	"github.com/ollama/ollama/logutil"
 	"github.com/ollama/ollama/ml"
+	"github.com/ollama/ollama/ml/nn/pooling"
 	"github.com/ollama/ollama/model"
 	"github.com/ollama/ollama/model/input"
-	"github.com/ollama/ollama/parser"
 	"github.com/ollama/ollama/runner/common"
 	"github.com/ollama/ollama/sample"
@@ -406,7 +405,7 @@ func (s *Server) removeSequence(seqIndex int, reason llm.DoneReason) {
 func (s *Server) run(ctx context.Context) {
 	s.ready.Wait()

-	supportsAsync := s.model.Backend().Config().Uint("pooling_type", math.MaxUint32) == math.MaxUint32
+	supportsAsync := pooling.Type(s.model.Backend().Config().Uint("pooling_type")) == pooling.TypeNone

 	var activeBatch batchState
 	for {
@@ -468,6 +467,7 @@ func (s *Server) forwardBatch(pendingBatch batchState) (nextBatch batchState, er
 	// Prepare the seqs and batch, but defer the input token values as we may not be ready yet
 	var batchInputs []*input.Input
+	var batchOutputs []int32
 	var batch input.Batch

 	resumeSeq := -1
@@ -550,9 +550,9 @@ func (s *Server) forwardBatch(pendingBatch batchState) (nextBatch batchState, er
 			batch.Positions = append(batch.Positions, int32(len(seq.cache.Inputs)+len(seq.pendingInputs)))
 			batch.Sequences = append(batch.Sequences, seq.cache.Id)

-			seq.iBatch = len(batch.Outputs)
-			if i+1 == len(seq.inputs) {
-				batch.Outputs = append(batch.Outputs, int32(len(batchInputs)-1))
+			seq.iBatch = len(batchOutputs)
+			if i+1 == len(seq.inputs) || seq.embeddingOnly {
+				batchOutputs = append(batchOutputs, int32(len(batchInputs)-1))
 			}
 			logutil.Trace("forwardBatch iBatch", "batchID", s.batchID, "seqIdx", seqIdx, "seq.iBatch", seq.iBatch, "i+1", i+1, "len(seq.inputs)", len(seq.inputs))
 			seq.pendingInputs = append(seq.pendingInputs, inp)
@@ -577,6 +577,7 @@ func (s *Server) forwardBatch(pendingBatch batchState) (nextBatch batchState, er
 	// Actual batchInputs values will be injected into the batch.Inputs tensor before calling Compute
 	batch.Inputs = nextBatch.ctx.Input().Empty(ml.DTypeI32, len(batchInputs))
+	batch.Outputs = nextBatch.ctx.Input().FromIntSlice(batchOutputs, len(batchOutputs))
 	nextBatch.modelOutput, err = model.Forward(nextBatch.ctx, s.model, batch)
 	if err != nil {
 		err = fmt.Errorf("failed to build graph: %w", err)
@@ -704,8 +705,8 @@ func (s *Server) computeBatch(activeBatch batchState) {
 		}

 		// sample a token
-		vocabSize := len(outputs) / len(activeBatch.batch.Outputs)
-		logutil.Trace("computeBatch: vocab details", "batchID", activeBatch.id, "seqIdx", i, "len(logits)", len(outputs), "len(activeBatch.batch.Outputs)", len(activeBatch.batch.Outputs), "vocabSize", vocabSize, "iBatches", iBatches)
+		vocabSize := len(outputs) / activeBatch.batch.Outputs.Dim(0)
+		logutil.Trace("computeBatch: vocab details", "batchID", activeBatch.id, "seqIdx", i, "len(logits)", len(outputs), "activeBatch.batch.Outputs.Dim(0)", activeBatch.batch.Outputs.Dim(0), "vocabSize", vocabSize, "iBatches", iBatches)
 		token, err := seq.sampler.Sample(outputs[iBatches[i]*vocabSize : (iBatches[i]+1)*vocabSize])
 		if err != nil {
 			s.hardErrCh <- fmt.Errorf("failed to sample token: %w", err)
@@ -781,8 +782,6 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 		return
 	}

-	tokenParser := parser.NewTokenParser(req.ParserType, req.PrefillString)
-
 	if req.Options == nil {
 		opts := api.DefaultOptions()
 		req.Options = &opts
 	}
@@ -873,18 +872,8 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 			return
 		case content, ok := <-seq.responses:
 			if ok {
-				var thinking string
-				var err error
-				content, thinking, err = tokenParser.AddContent(content)
-				if err != nil {
-					http.Error(w, err.Error(), http.StatusInternalServerError)
-					close(seq.quit)
-					return
-				}
-
 				if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{
-					Content:  content,
-					Thinking: thinking,
+					Content: content,
 				}); err != nil {
 					http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
 					close(seq.quit)
@@ -893,9 +882,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 				flusher.Flush()
 			} else {
-				toolCalls := tokenParser.Drain()
 				if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{
-					ToolCalls: toolCalls,
 					Done: true,
 					DoneReason: seq.doneReason,
 					PromptEvalCount: seq.numPromptInputs,
@@ -913,7 +900,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 }

 func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
-	if s.model.Backend().Config().Uint("pooling_type", math.MaxUint32) == math.MaxUint32 {
+	if pooling.Type(s.model.Backend().Config().Uint("pooling_type")) == pooling.TypeNone {
 		http.Error(w, "this model does not support embeddings", http.StatusNotImplemented)
 		return
 	}
@@ -1061,12 +1048,8 @@ func (s *Server) reserveWorstCaseGraph() error {
 		batch.Positions[i] = int32(i)
 	}

-	batch.Outputs = make([]int32, s.parallel)
-	for i := range batch.Outputs {
-		batch.Outputs[i] = int32(i)
-	}
-
 	batch.Inputs = ctx.Input().FromIntSlice(batchInputs, len(batchInputs))
+	batch.Outputs = ctx.Input().Empty(ml.DTypeI32, s.parallel)

 	cache := s.model.Config().Cache
 	if cache != nil {
diff --git a/server/create.go b/server/create.go
index bd970876f..f08f18b34 100644
--- a/server/create.go
+++ b/server/create.go
@@ -323,6 +323,8 @@ func createModel(r api.CreateRequest, name model.Name, baseLayers []*layerGGML,
 		RootFS: RootFS{
 			Type: "layers",
 		},
+		Renderer: r.Renderer,
+		Parser: r.Parser,
 	}

 	var layers []Layer
diff --git a/server/images.go b/server/images.go
index 504eb95cf..6432860f8 100644
--- a/server/images.go
+++ b/server/images.go
@@ -24,6 +24,7 @@ import (
 	"github.com/ollama/ollama/api"
 	"github.com/ollama/ollama/envconfig"
 	"github.com/ollama/ollama/fs/gguf"
+	"github.com/ollama/ollama/model/parsers"
 	"github.com/ollama/ollama/parser"
 	"github.com/ollama/ollama/template"
 	"github.com/ollama/ollama/thinking"
@@ -94,8 +95,9 @@ func (m *Model) Capabilities() []model.Capability {
 		return capabilities
 	}

+	builtinParser := parsers.ParserForName(m.Config.Parser)
 	// Check for tools capability
-	if slices.Contains(m.Template.Vars(), "tools") {
+	if slices.Contains(m.Template.Vars(), "tools") || (builtinParser != nil && builtinParser.HasToolSupport()) {
 		capabilities = append(capabilities, model.CapabilityTools)
 	}
@@ -112,7 +114,8 @@ func (m *Model) Capabilities() []model.Capability {
 	// Check for thinking capability
 	openingTag, closingTag := thinking.InferTags(m.Template.Template)
 	hasTags := openingTag != "" && closingTag != ""
-	if hasTags || slices.Contains([]string{"gptoss", "gpt-oss"}, m.Config.ModelFamily) {
+	isGptoss := slices.Contains([]string{"gptoss", "gpt-oss"}, m.Config.ModelFamily)
+	if hasTags || isGptoss || (builtinParser != nil && builtinParser.HasThinkingSupport()) {
 		capabilities = append(capabilities, model.CapabilityThinking)
 	}
@@ -198,6 +201,20 @@ func (m *Model) String() string {
 		})
 	}

+	if m.Config.Renderer != "" {
+		modelfile.Commands = append(modelfile.Commands, parser.Command{
+			Name: "renderer",
+			Args: m.Config.Renderer,
+		})
+	}
+
+	if m.Config.Parser != "" {
+		modelfile.Commands = append(modelfile.Commands, parser.Command{
+			Name: "parser",
+			Args: m.Config.Parser,
+		})
+	}
+
 	for k, v := range m.Options {
 		switch v := v.(type) {
 		case []any:
@@ -238,6 +255,8 @@ type ConfigV2 struct {
 	ModelFamilies []string `json:"model_families"`
 	ModelType string `json:"model_type"`
 	FileType string `json:"file_type"`
+	Renderer string `json:"renderer,omitempty"`
+	Parser string `json:"parser,omitempty"`

 	// required by spec
 	Architecture string `json:"architecture"`
diff --git a/server/prompt.go b/server/prompt.go
index f1d8020ea..56bc63030 100644
--- a/server/prompt.go
+++ b/server/prompt.go
@@ -11,6 +11,7 @@ import (

 	"github.com/ollama/ollama/api"
 	"github.com/ollama/ollama/llm"
+	"github.com/ollama/ollama/model/renderers"
 	"github.com/ollama/ollama/template"
 )
@@ -41,18 +42,12 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
 			}
 		}

-		thinkVal := false
-		thinkLevel := ""
-		if think != nil {
-			thinkVal = think.Bool()
-			thinkLevel = think.String()
-		}
-		var b bytes.Buffer
-		if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[i:]...), Tools: tools, Think: thinkVal, ThinkLevel: thinkLevel, IsThinkSet: think != nil}); err != nil {
+		p, err := renderPrompt(m, append(system, msgs[i:]...), tools, think)
+		if err != nil {
 			return "", nil, err
 		}

-		s, err := tokenize(ctx, b.String())
+		s, err := tokenize(ctx, p)
 		if err != nil {
 			return "", nil, err
 		}
@@ -101,6 +96,23 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
 	}

 	// truncate any messages that do not fit into the context window
+	p, err := renderPrompt(m, append(system, msgs[currMsgIdx:]...), tools, think)
+	if err != nil {
+		return "", nil, err
+	}
+
+	return p, images, nil
+}
+
+func renderPrompt(m *Model, msgs []api.Message, tools []api.Tool, think *api.ThinkValue) (string, error) {
+	if m.Config.Renderer != "" {
+		rendered, err := renderers.RenderWithRenderer(m.Config.Renderer, msgs, tools, think)
+		if err != nil {
+			return "", err
+		}
+		return rendered, nil
+	}
+
 	var b bytes.Buffer
 	thinkVal := false
 	thinkLevel := ""
 	if think != nil {
 		thinkVal = think.Bool()
 		thinkLevel = think.String()
 	}
-	if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[currMsgIdx:]...), Tools: tools, Think: thinkVal, ThinkLevel: thinkLevel, IsThinkSet: think != nil}); err != nil {
-		return "", nil, err
+	if err := m.Template.Execute(&b, template.Values{Messages: msgs, Tools: tools, Think: thinkVal, ThinkLevel: thinkLevel, IsThinkSet: think != nil}); err != nil {
+		return "", err
 	}
-
-	return b.String(), images, nil
+	return b.String(), nil
 }
diff --git a/server/routes.go b/server/routes.go
index 8dd1b217a..e999c6c01 100644
--- a/server/routes.go
+++ b/server/routes.go
@@ -35,8 +35,8 @@ import (
 	"github.com/ollama/ollama/harmony"
 	"github.com/ollama/ollama/llm"
 	"github.com/ollama/ollama/logutil"
+	"github.com/ollama/ollama/model/parsers"
 	"github.com/ollama/ollama/openai"
-	"github.com/ollama/ollama/parser"
 	"github.com/ollama/ollama/server/internal/client/ollama"
 	"github.com/ollama/ollama/server/internal/registry"
 	"github.com/ollama/ollama/template"
@@ -47,6 +47,18 @@ import (
 	"github.com/ollama/ollama/version"
 )

+func shouldUseHarmony(model *Model) bool {
+	if slices.Contains([]string{"gptoss", "gpt-oss"}, model.Config.ModelFamily) {
+		// heuristic to check whether the template expects to be parsed via harmony:
+		// search for harmony tags that are nearly always used
+		if model.Template.Contains("<|start|>") && model.Template.Contains("<|end|>") {
+			return true
+		}
+	}
+
+	return false
+}
+
 func experimentEnabled(name string) bool {
 	return slices.Contains(strings.Split(os.Getenv("OLLAMA_EXPERIMENT"), ","), name)
 }
@@ -196,17 +208,13 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 		return
 	}

-	useHarmony := harmony.ShouldUseHarmony(m.Config.ModelFamily, m.Template) && !req.Raw
-	var parserType parser.TokenParserType
+	useHarmony := shouldUseHarmony(m) && !req.Raw
+	var harmonyMessageHandler *harmony.HarmonyMessageHandler
+	var harmonyToolParser *harmony.HarmonyToolCallAccumulator
 	if useHarmony {
-		parserType = parser.TokenParserTypeHarmony
-	} else {
-		parserType = parser.TokenParserTypeDefault
-	}
-	var functionNameMap *harmony.FunctionNameMap
-
-	if useHarmony {
-		functionNameMap = harmony.NewFunctionNameMap()
+		harmonyMessageHandler = harmony.NewHarmonyMessageHandler()
+		harmonyMessageHandler.HarmonyParser.AddImplicitStart()
+		harmonyToolParser = harmonyMessageHandler.CreateToolParser()
 	}

 	// Validate Think value: string values currently only allowed for gptoss models
@@ -322,10 +330,10 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 	// If debug mode is enabled, return the rendered template instead of calling the model
 	if req.DebugRenderOnly {
-		c.JSON(http.StatusOK, api.DebugTemplateResponse{
+		c.JSON(http.StatusOK, api.GenerateResponse{
 			Model: req.Model,
 			CreatedAt: time.Now().UTC(),
-			DebugInfo: api.DebugInfo{
+			DebugInfo: &api.DebugInfo{
 				RenderedTemplate: prompt,
 				ImageCount: len(images),
 			},
@@ -350,19 +358,16 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 		var sb strings.Builder
 		defer close(ch)
 		if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
-			Prompt:     prompt,
-			Images:     images,
-			Format:     req.Format,
-			Options:    opts,
-			ParserType: parserType,
+			Prompt:  prompt,
+			Images:  images,
+			Format:  req.Format,
+			Options: opts,
 		}, func(cr llm.CompletionResponse) {
 			res := api.GenerateResponse{
 				Model: req.Model,
 				CreatedAt: time.Now().UTC(),
 				Response: cr.Content,
 				Done: cr.Done,
-				Thinking: cr.Thinking,
-				ToolCalls: cr.ToolCalls,
 				Metrics: api.Metrics{
 					PromptEvalCount: cr.PromptEvalCount,
 					PromptEvalDuration: cr.PromptEvalDuration,
@@ -371,22 +376,12 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 				},
 			}

-			if res.Done {
-				res.DoneReason = cr.DoneReason.String()
-				res.TotalDuration = time.Since(checkpointStart)
-				res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
-			}
-
 			if useHarmony {
-				for i, tool := range res.ToolCalls {
-					res.ToolCalls[i].Function.Name = functionNameMap.OriginalFromConverted(tool.Function.Name)
-				}
-				if res.Response != "" || res.Thinking != "" || len(res.ToolCalls) > 0 || res.Done {
-					ch <- res
-				}
-				return
-			}
-			if thinkingState != nil {
+				content, thinking, toolContent := harmonyMessageHandler.AddContent(cr.Content, harmonyToolParser)
+				res.Response = content
+				res.Thinking = thinking
+				harmonyToolParser.Add(toolContent)
+			} else if thinkingState != nil {
 				thinking, content := thinkingState.AddContent(cr.Content)
 				res.Thinking = thinking
 				res.Response = content
@@ -397,6 +392,30 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 			}

 			if cr.Done {
+				if useHarmony {
+					toolName, toolContent := harmonyToolParser.Drain()
+					if toolName != nil {
+						*toolName = strings.TrimPrefix(*toolName, "functions.")
+						var args api.ToolCallFunctionArguments
+						if err := json.Unmarshal([]byte(toolContent), &args); err != nil {
+							errStr := fmt.Sprintf("error parsing tool call: raw='%s', err=%s", toolContent, err.Error())
+							ch <- gin.H{"error": errStr}
+							return
+						}
+
+						res.ToolCalls = append(res.ToolCalls, api.ToolCall{
+							Function: api.ToolCallFunction{
+								Name: *toolName,
+								Arguments: args,
+							},
+						})
+					}
+				}
+
+				res.DoneReason = cr.DoneReason.String()
+				res.TotalDuration = time.Since(checkpointStart)
+				res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
+
 				if !req.Raw {
 					tokens, err := r.Tokenize(c.Request.Context(), prompt+sb.String())
 					if err != nil {
@@ -470,7 +489,6 @@ func (s *Server) EmbedHandler(c *gin.Context) {
 	}

 	truncate := true
-
 	if req.Truncate != nil && !*req.Truncate {
 		truncate = false
 	}
@@ -537,7 +555,16 @@ func (s *Server) EmbedHandler(c *gin.Context) {
 			return
 		}

+		if bos := kvData.Uint("tokenizer.ggml.bos_token_id"); tokens[0] != int(bos) && kvData.Bool("add_bos_token", true) {
+			ctxLen--
+		}
+
+		if eos := kvData.Uint("tokenizer.ggml.eos_token_id"); tokens[len(tokens)-1] != int(eos) && kvData.Bool("add_eos_token", true) {
+			ctxLen--
+		}
+
 		tokens = tokens[:ctxLen]
+
 		s, err = r.Detokenize(c.Request.Context(), tokens)
 		if err != nil {
 			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
@@ -1599,27 +1626,32 @@ func (s *Server) ChatHandler(c *gin.Context) {
 	}

 	msgs = filterThinkTags(msgs, m)

-	useHarmony := harmony.ShouldUseHarmony(m.Config.ModelFamily, m.Template)
-	var parserType parser.TokenParserType
-	if useHarmony {
-		parserType = parser.TokenParserTypeHarmony
-	} else {
-		parserType = parser.TokenParserTypeDefault
+	var builtinParser parsers.Parser
+	if m.Config.Parser != "" {
+		builtinParser = parsers.ParserForName(m.Config.Parser)
 	}
+	var harmonyMessageHandler *harmony.HarmonyMessageHandler
+	var harmonyToolParser *harmony.HarmonyToolCallAccumulator
+
+	useHarmony := shouldUseHarmony(m) || m.Config.Parser == "harmony"
+
 	processedTools := req.Tools
-	var functionNameMap *harmony.FunctionNameMap
-	var prefillString string
-	// TODO(parthsareen): this can be abstracted to not be model specific and potentially moved to the runner
 	if useHarmony {
-		prefillString = harmony.Prefill(msgs[len(msgs)-1])
-		functionNameMap = harmony.NewFunctionNameMap()
+		harmonyMessageHandler = harmony.NewHarmonyMessageHandler()
+		var lastMessage *api.Message
+		if len(msgs) > 0 {
+			lastMessage = &msgs[len(msgs)-1]
+		}
+		harmonyMessageHandler.HarmonyParser.AddImplicitStartOrPrefill(lastMessage)
+		harmonyToolParser = harmonyMessageHandler.CreateToolParser()
+
 		// make a copy of tools to pass to the chat prompt. Function names may be
 		// renamed to be valid Harmony function names.
 		processedTools = make([]api.Tool, len(req.Tools))
 		copy(processedTools, req.Tools)
 		for i, tool := range processedTools {
-			processedTools[i].Function.Name = functionNameMap.ConvertAndAdd(tool.Function.Name)
+			processedTools[i].Function.Name = harmonyMessageHandler.FunctionNameMap.ConvertAndAdd(tool.Function.Name)
 		}
 	}
@@ -1632,10 +1664,10 @@ func (s *Server) ChatHandler(c *gin.Context) {
 	// If debug mode is enabled, return the rendered template instead of calling the model
 	if req.DebugRenderOnly {
-		c.JSON(http.StatusOK, api.DebugTemplateResponse{
+		c.JSON(http.StatusOK, api.ChatResponse{
 			Model: req.Model,
 			CreatedAt: time.Now().UTC(),
-			DebugInfo: api.DebugInfo{
+			DebugInfo: &api.DebugInfo{
 				RenderedTemplate: prompt,
 				ImageCount: len(images),
 			},
@@ -1672,17 +1704,15 @@ func (s *Server) ChatHandler(c *gin.Context) {
 		defer close(ch)

 		if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
-			Prompt:        prompt,
-			Images:        images,
-			Format:        req.Format,
-			Options:       opts,
-			ParserType:    parserType,
-			PrefillString: prefillString,
+			Prompt:  prompt,
+			Images:  images,
+			Format:  req.Format,
+			Options: opts,
 		}, func(r llm.CompletionResponse) {
 			res := api.ChatResponse{
 				Model: req.Model,
 				CreatedAt: time.Now().UTC(),
-				Message: api.Message{Role: "assistant", Content: r.Content, Thinking: r.Thinking, ToolCalls: r.ToolCalls},
+				Message: api.Message{Role: "assistant", Content: r.Content},
 				Done: r.Done,
 				Metrics: api.Metrics{
 					PromptEvalCount: r.PromptEvalCount,
@@ -1697,14 +1727,54 @@ func (s *Server) ChatHandler(c *gin.Context) {
 				res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
 			}

+			// TODO(drifkin): fold this as much as possible into the generic m.Config.Parser logic
 			if useHarmony {
-				for i, tool := range res.Message.ToolCalls {
-					res.Message.ToolCalls[i].Function.Name = functionNameMap.OriginalFromConverted(tool.Function.Name)
+				content, thinking, toolContent := harmonyMessageHandler.AddContent(r.Content, harmonyToolParser)
+				res.Message.Content = content
+				res.Message.Thinking = thinking
+				harmonyToolParser.Add(toolContent)
+
+				if r.Done {
+					toolName, toolContent := harmonyToolParser.Drain()
+					if toolName != nil {
+						*toolName = strings.TrimPrefix(*toolName, "functions.")
+						*toolName = harmonyMessageHandler.FunctionNameMap.OriginalFromConverted(*toolName)
+						var args api.ToolCallFunctionArguments
+						if err := json.Unmarshal([]byte(toolContent), &args); err != nil {
+							errStr := fmt.Sprintf("error parsing tool call: raw='%s', err=%s", toolContent, err.Error())
+							ch <- gin.H{"error": errStr}
+							return
+						}
+						res.Message.ToolCalls = []api.ToolCall{{Function: api.ToolCallFunction{Name: *toolName, Arguments: args}}}
+					}
 				}
+				// only send messages with meaningful content (empty messages confuse clients)
 				if res.Message.Content != "" || res.Message.Thinking != "" || len(res.Message.ToolCalls) > 0 || res.Done {
 					ch <- res
 				}
+
+				return
+			} else if builtinParser != nil {
+				slog.Log(context.TODO(), logutil.LevelTrace, "builtin parser input", "parser", m.Config.Parser, "content", r.Content)
+
+				content, thinking, toolCalls, err := builtinParser.Add(r.Content, req.Tools)
+				if err != nil {
+					ch <- gin.H{"error": err.Error()}
+					return
+				}
+
+				res.Message.Content = content
+				res.Message.Thinking = thinking
+				res.Message.ToolCalls = toolCalls
+
+				if res.Message.Content != "" || res.Message.Thinking != "" || len(res.Message.ToolCalls) > 0 || r.Done {
+					slog.Log(context.TODO(), logutil.LevelTrace, "builtin parser output", "parser", m.Config.Parser, "content", content, "thinking", thinking, "toolCalls", toolCalls, "done", r.Done)
+					ch <- res
+				} else {
+					slog.Log(context.TODO(), logutil.LevelTrace, "builtin parser empty output", "parser", m.Config.Parser)
+				}
+				return
 			}
diff --git a/server/routes_debug_test.go b/server/routes_debug_test.go
index f04a1da99..6507284ef 100644
--- a/server/routes_debug_test.go
+++ b/server/routes_debug_test.go
@@ -180,7 +180,7 @@ func TestGenerateDebugRenderOnly(t *testing.T) {
 		t.Errorf("expected status %d, got %d, body: %s", http.StatusOK, w.Code, w.Body.String())
 	}

-	var response api.DebugTemplateResponse
+	var response api.GenerateResponse
 	if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil {
 		t.Fatalf("failed to unmarshal response: %v", err)
 	}
@@ -385,7 +385,7 @@ func TestChatDebugRenderOnly(t *testing.T) {
 		t.Errorf("expected status %d, got %d, body: %s", http.StatusOK, w.Code, w.Body.String())
 	}

-	var response api.DebugTemplateResponse
+	var response api.ChatResponse
 	if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil {
 		t.Fatalf("failed to unmarshal response: %v", err)
 	}
diff --git a/server/routes_harmony_streaming_test.go b/server/routes_harmony_streaming_test.go
index bcb020886..b1ede4e39 100644
--- a/server/routes_harmony_streaming_test.go
+++ b/server/routes_harmony_streaming_test.go
@@ -7,6 +7,7 @@ import (
 	"bytes"
 	"context"
 	"encoding/json"
+	"net/http"
 	"strings"
 	"testing"
 	"time"
@@ -117,7 +118,7 @@ func TestChatHarmonyParserStreamingRealtime(t *testing.T) {
 			name: "content streams as it arrives",
 			steps: []step{
 				{
-					input: llm.CompletionResponse{Content: "Hello", Done: false},
+					input: llm.CompletionResponse{Content: "<|message|>Hello", Done: false},
 					wantContent: "Hello",
 				},
 				{
 					input: llm.CompletionResponse{Content: ", world", Done: false},
 					wantContent: ", world",
 				},
 				{
-					input: llm.CompletionResponse{Content: "!", Done: true, DoneReason: llm.DoneReasonStop},
+					input: llm.CompletionResponse{Content: "!<|end|>", Done: true, DoneReason: llm.DoneReasonStop},
 					wantContent: "!",
 				},
 			},
@@ -134,15 +135,20 @@
 			name: "thinking streams separately from content",
 			steps: []step{
 				{
-					input: llm.CompletionResponse{Thinking: "Thinking...", Done: false},
+					input: llm.CompletionResponse{Content: "<|channel|>analysis<|message|>Thinking...", Done: false},
 					wantThinking: "Thinking...",
 				},
 				{
-					input: llm.CompletionResponse{Content: "Answer", Done: false},
-					wantContent: "Answer",
+					input: llm.CompletionResponse{Content: "<|end|>", Done: false},
+					// No output expected - just closes the analysis message and resets state to normal
 				},
 				{
-					input: llm.CompletionResponse{Done: true, DoneReason: llm.DoneReasonStop},
+					input: llm.CompletionResponse{Content: "<|start|>assistant<|message|>Answer", Done: false},
+					wantContent: "Answer", // After message end, state is reset to normal
+				},
+				{
+					input: llm.CompletionResponse{Content: "<|end|>", Done: true, DoneReason: llm.DoneReasonStop},
+					// No output expected - just closes the assistant message
 				},
 			},
 		},
@@ -150,16 +156,24 @@
 			name: "partial tags buffer until complete",
 			steps: []step{
 				{
-					input: llm.CompletionResponse{Thinking: "Deep ", Done: false},
+					input: llm.CompletionResponse{Content: "<|chan", Done: false},
+					// No output - partial tag
+				},
+				{
+					input: llm.CompletionResponse{Content: "nel|>analysis<|mess", Done: false},
+					// No output - still building tags
+				},
+				{
+					input: llm.CompletionResponse{Content: "age|>Deep ", Done: false},
 					wantThinking: "Deep ",
 				},
 				{
-					input: llm.CompletionResponse{Thinking: "thought", Done: false},
+					input: llm.CompletionResponse{Content: "thought<|end|>", Done: false},
 					wantThinking: "thought",
 				},
 				{
-					input: llm.CompletionResponse{Content: "Done", Done: true, DoneReason: llm.DoneReasonStop},
-					wantContent: "Done",
+					input: llm.CompletionResponse{Content: "<|start|>assistant<|message|>Done<|end|>", Done: true, DoneReason: llm.DoneReasonStop},
+					wantContent: "Done", // After message end, state is reset to normal
 				},
 			},
 		},
@@ -167,7 +181,7 @@
 			name: "simple assistant after analysis",
 			steps: []step{
 				{
-					input: llm.CompletionResponse{Thinking: "Think", Content: "Answer", Done: true, DoneReason: llm.DoneReasonStop},
+					input: llm.CompletionResponse{Content: "<|channel|>analysis<|message|>Think<|end|><|start|>assistant<|message|>Answer<|end|>", Done: true, DoneReason: llm.DoneReasonStop},
 					wantContent: "Answer",
 					wantThinking: "Think",
 				},
@@ -177,7 +191,7 @@
 			name: "tool call parsed and returned correctly",
 			steps: []step{
 				{
-					input: llm.CompletionResponse{Content: "The weather is sunny", ToolCalls: []api.ToolCall{{Function: api.ToolCallFunction{Name: "get_weather", Arguments: api.ToolCallFunctionArguments{"location": "San Francisco"}}}}, Done: true, DoneReason: llm.DoneReasonStop},
+					input: llm.CompletionResponse{Content: "<|channel|>commentary to=functions.get_weather<|message|>{\"location\":\"San Francisco\"}<|end|><|start|>assistant<|message|>The weather is sunny<|end|>", Done: true, DoneReason: llm.DoneReasonStop},
 					wantContent: "The weather is sunny",
 					wantToolCalls: []api.ToolCall{
 						{
@@ -196,10 +210,15 @@
 			name: "tool call with streaming JSON across chunks",
 			steps: []step{
 				{
-					input: llm.CompletionResponse{Done: false},
+					input: llm.CompletionResponse{Content: "<|channel|>commentary to=functions.calculate<|message|>{\"expr", Done: false},
+					// No output yet - incomplete JSON
 				},
 				{
-					input: llm.CompletionResponse{ToolCalls: []api.ToolCall{{Function: api.ToolCallFunction{Name: "calculate", Arguments: api.ToolCallFunctionArguments{"expression": "2+2"}}}}, Done: true},
+					input: llm.CompletionResponse{Content: "ession\":\"2+", Done: false},
+					// Still no output - incomplete JSON
+				},
+				{
+					input: llm.CompletionResponse{Content: "2\"}", Done: true},
 					wantToolCalls: []api.ToolCall{
 						{
 							Function: api.ToolCallFunction{
@@ -381,9 +400,9 @@ func TestChatHarmonyParserStreamingSimple(t *testing.T) {
 	gin.SetMode(gin.TestMode)

 	mockResponses := []llm.CompletionResponse{
-		{Content: "First ", Done: false},
+		{Content: "<|message|>First ", Done: false},
 		{Content: "chunk ", Done: false},
-		{Content: "here", Done: true, DoneReason: llm.DoneReasonStop},
+		{Content: "here<|end|>", Done: true, DoneReason: llm.DoneReasonStop},
 	}

 	mock := mockRunner{
@@ -488,3 +507,189 @@ func TestChatHarmonyParserStreamingSimple(t *testing.T) {
 		t.Errorf("expected at least 2 content chunks for streaming, got %d", contentChunks)
 	}
 }
+
+func TestChatHarmonyParserStreaming(t *testing.T) {
+	gin.SetMode(gin.TestMode)
+
+	type expectedChunk struct {
+		afterResponse int // Which mock response this chunk should appear after
+		content string // Expected content in this chunk
+		thinking string // Expected thinking in this chunk
+	}
+
+	testCases := []struct {
+		name string
+		mockResponses []llm.CompletionResponse
+		expectedChunks []expectedChunk
+		wantContent string
+		wantThinking string
+	}{
+		{
+			name: "simple message without thinking",
+			mockResponses: []llm.CompletionResponse{
+				{Content: "<|start|>assistant<|message|>Hello, ", Done: false},
+				{Content: "how can I help?", Done: false},
+				{Content: "<|end|>", Done: true, DoneReason: llm.DoneReasonStop},
+			},
+			expectedChunks: []expectedChunk{
+				{afterResponse: 1, content: "Hello, "},
+				{afterResponse: 2, content: "how can I help?"},
+			},
+			wantContent: "Hello, how can I help?",
+		},
+		{
+			name: "message with analysis channel for thinking",
+			mockResponses: []llm.CompletionResponse{
+				{Content: "<|channel|>analysis<|message|>", Done: false},
+				{Content: "Let me think ", Done: false},
+				{Content: "about this problem...", Done: false},
+				{Content: "<|end|>", Done: false},
+				{Content: "<|start|>assistant<|message|>", Done: false},
+				{Content: "The answer ", Done: false},
+				{Content: "is 42", Done: false},
+				{Content: "<|end|>", Done: true, DoneReason: llm.DoneReasonStop},
+			},
+			expectedChunks: []expectedChunk{
+				{afterResponse: 2, thinking: "Let me think "},
+				{afterResponse: 3, thinking: "about this problem..."},
+				{afterResponse: 6, content: "The answer "},
+				{afterResponse: 7, content: "is 42"},
+			},
+			wantContent: "The answer is 42",
+			wantThinking: "Let me think about this problem...",
+		},
+		{
+			name: "streaming with partial tags across boundaries",
+			mockResponses: []llm.CompletionResponse{
+				{Content: "<|chan", Done: false},
+				{Content: "nel|>analy", Done: false},
+				{Content: "sis<|mess", Done: false},
+				{Content: "age|>Think", Done: false},
+				{Content: "ing deeply...<|end|>", Done: false},
+				{Content: "<|start|>assi", Done: false},
+				{Content: "stant<|message|>Result ", Done: false},
+				{Content: "computed<|e", Done: false},
+				{Content: "nd|>", Done: true, DoneReason: llm.DoneReasonStop},
+			},
+			expectedChunks: []expectedChunk{
+				{afterResponse: 4, thinking: "Think"},
+				{afterResponse: 5, thinking: "ing deeply..."},
+				{afterResponse: 7, content: "Result "},
+				{afterResponse: 8, content: "computed"},
+			},
+			wantContent: "Result computed",
+			wantThinking: "Thinking deeply...",
+		},
+	}
+
+	for _, tc := range testCases {
+		t.Run(tc.name, func(t *testing.T) {
+			// Channel to synchronize mock responses with chunk verification
+			responsesSent := make(chan int, len(tc.mockResponses))
+
+			mock := mockRunner{
+				CompletionFn: func(ctx context.Context, r llm.CompletionRequest, fn func(llm.CompletionResponse)) error {
+					// Send mock responses one at a time, notifying when each is sent
+					for i, resp := range tc.mockResponses {
+						fn(resp)
+						responsesSent <- i + 1
+					}
+					close(responsesSent)
+					return nil
+				},
+			}
+
+			s := Server{
+				sched: &Scheduler{
+					pendingReqCh: make(chan *LlmRequest, 1),
+					finishedReqCh: make(chan *LlmRequest, 1),
+					expiredCh: make(chan *runnerRef, 1),
+					unloadedCh: make(chan any, 1),
+					loaded: make(map[string]*runnerRef),
+					newServerFn: newMockServer(&mock),
+					getGpuFn: discover.GetGPUInfo,
+					getCpuFn: discover.GetCPUInfo,
+					reschedDelay: 250 * time.Millisecond,
+					loadFn: func(req *LlmRequest, _ *ggml.GGML, _ discover.GpuInfoList, _ bool) bool {
+						req.successCh <- &runnerRef{
+							llama: &mock,
+						}
+						return false
+					},
+				},
+			}
+
+			go s.sched.Run(t.Context())
+
+			// Create a minimal model
+			_, digest := createHarmonyTestModel(t)
+
+			// Create model with passthrough template
+			stream := false
+			w := createRequest(t, s.CreateHandler, api.CreateRequest{
+				Model: "harmony-test",
+				Files: map[string]string{"file.gguf": digest},
+				Template: `<|start|><|end|>{{ with .Tools }}{{ end }}{{ .Prompt }}`,
+				Stream: &stream,
+			})
+
+			if w.Code != http.StatusOK {
+				t.Fatalf("failed to create model: %d", w.Code)
+			}
+
+			// Test chat endpoint with streaming
+			streamTrue := true
+			w = createRequest(t, s.ChatHandler, api.ChatRequest{
+				Model: "harmony-test",
+				Messages: []api.Message{{Role: "user", Content: "Hello"}},
+				Stream: &streamTrue,
+				Tools: getTestTools(),
+			})
+
+			if w.Code != http.StatusOK {
+				t.Fatalf("chat request failed: %d - %s", w.Code, w.Body.String())
+			}
+
+			// Parse streaming response
+			var chunks []api.ChatResponse
+			var content, thinking strings.Builder
+
+			decoder := json.NewDecoder(w.Body)
+			for decoder.More() {
+				var chunk api.ChatResponse
+				if err := decoder.Decode(&chunk); err != nil {
+					t.Fatalf("failed to decode chunk: %v", err)
+				}
+				chunks = append(chunks, chunk)
+
+				// Accumulate content and thinking from each chunk
+				content.WriteString(chunk.Message.Content)
+				thinking.WriteString(chunk.Message.Thinking)
+
+				// Debug output
+				t.Logf("Chunk %d: content=%q thinking=%q done=%v", len(chunks), chunk.Message.Content, chunk.Message.Thinking, chunk.Done)
+			}
+
+			// Verify we got streaming chunks
+			if len(chunks) == 0 {
+				t.Fatal("expected streaming chunks, got none")
+			}
+
+			gotContent := content.String()
+			gotThinking := thinking.String()
+
+			if gotContent != tc.wantContent {
+				t.Errorf("content mismatch: got %q, want %q", gotContent, tc.wantContent)
+			}
+			if gotThinking != tc.wantThinking {
+				t.Errorf("thinking mismatch: got %q, want %q", gotThinking, tc.wantThinking)
+			}
+
+			// Verify last chunk has done=true
+			lastChunk := chunks[len(chunks)-1]
+			if !lastChunk.Done {
+				t.Error("expected last chunk to have done=true")
+			}
+		})
+	}
+}
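
Two illustrative sketches of the surface area introduced above follow; neither is part of the diff itself. First, the new `RENDERER` and `PARSER` Modelfile commands, used exactly as in `TestParseFileRenderer` and `TestParseFileParser`; `renderer1` and `parser1` are placeholder names, and real values would presumably need to name built-ins known to `model/renderers` and `model/parsers` (i.e. whatever `parsers.ParserForName` and `renderers.RenderWithRenderer` recognize):

```
FROM foo
RENDERER renderer1
PARSER parser1
```

Second, a minimal Go sketch of driving the debug-render path end to end, assuming the published `api` client package carries the `DebugRenderOnly` request field and `DebugInfo` response field shown in this diff; the model name is a placeholder for any locally created model:

```go
package main

import (
	"context"
	"fmt"
	"log"

	"github.com/ollama/ollama/api"
)

func main() {
	client, err := api.ClientFromEnvironment()
	if err != nil {
		log.Fatal(err)
	}

	req := &api.GenerateRequest{
		Model:           "my-model", // placeholder: any locally created model
		Prompt:          "Why is the sky blue?",
		DebugRenderOnly: true, // presumably serialized as _debug_render_only
	}

	// With DebugRenderOnly set, the server answers with the rendered prompt in
	// _debug_info instead of running the model, so the callback fires once.
	if err := client.Generate(context.Background(), req, func(resp api.GenerateResponse) error {
		if resp.DebugInfo != nil {
			fmt.Println(resp.DebugInfo.RenderedTemplate)
			fmt.Println("image count:", resp.DebugInfo.ImageCount)
		}
		return nil
	}); err != nil {
		log.Fatal(err)
	}
}
```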