diff --git a/api/types.go b/api/types.go
index a7ddbc373..df3504c3b 100644
--- a/api/types.go
+++ b/api/types.go
@@ -313,10 +313,11 @@ func (t *ToolFunction) String() string {
// ChatResponse is the response returned by [Client.Chat]. Its fields are
// similar to [GenerateResponse].
type ChatResponse struct {
- Model string `json:"model"`
- CreatedAt time.Time `json:"created_at"`
- Message Message `json:"message"`
- DoneReason string `json:"done_reason,omitempty"`
+ Model string `json:"model"`
+ CreatedAt time.Time `json:"created_at"`
+ Message Message `json:"message"`
+ DoneReason string `json:"done_reason,omitempty"`
+ DebugInfo *DebugInfo `json:"_debug_info,omitempty"`
Done bool `json:"done"`
@@ -329,13 +330,6 @@ type DebugInfo struct {
ImageCount int `json:"image_count,omitempty"`
}
-// DebugTemplateResponse is returned when _debug_render_only is set to true
-type DebugTemplateResponse struct {
- Model string `json:"model"`
- CreatedAt time.Time `json:"created_at"`
- DebugInfo DebugInfo `json:"_debug_info"`
-}
-
type Metrics struct {
TotalDuration time.Duration `json:"total_duration,omitempty"`
LoadDuration time.Duration `json:"load_duration,omitempty"`
@@ -443,6 +437,8 @@ type CreateRequest struct {
System string `json:"system,omitempty"`
Parameters map[string]any `json:"parameters,omitempty"`
Messages []Message `json:"messages,omitempty"`
+ Renderer string `json:"renderer,omitempty"`
+ Parser string `json:"parser,omitempty"`
// Deprecated: set the model name with Model instead
Name string `json:"name"`
@@ -480,6 +476,8 @@ type ShowResponse struct {
Parameters string `json:"parameters,omitempty"`
Template string `json:"template,omitempty"`
System string `json:"system,omitempty"`
+ Renderer string `json:"renderer,omitempty"`
+ Parser string `json:"parser,omitempty"`
Details ModelDetails `json:"details,omitempty"`
Messages []Message `json:"messages,omitempty"`
ModelInfo map[string]any `json:"model_info,omitempty"`
@@ -592,6 +590,8 @@ type GenerateResponse struct {
Metrics
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
+
+ DebugInfo *DebugInfo `json:"_debug_info,omitempty"`
}
// ModelDetails provides details about a model.
diff --git a/convert/convert_bert.go b/convert/convert_bert.go
index a9f4b8a77..6b0d0030a 100644
--- a/convert/convert_bert.go
+++ b/convert/convert_bert.go
@@ -28,6 +28,7 @@ type bertModel struct {
LayerNormEPS float32 `json:"layer_norm_eps"`
LayerNormEpsilon float32 `json:"layer_norm_epsilon"`
NormEpsilon float32 `json:"norm_epsilon"`
+ normalizeEmbeddings bool
PoolingType uint32
}
@@ -54,9 +55,11 @@ func (p *bertModel) parseMore(fsys fs.FS) error {
var pooling string
for _, m := range modules {
- if m.Type == "sentence_transformers.models.Pooling" {
+ switch m.Type {
+ case "sentence_transformers.models.Pooling":
pooling = m.Path
- break
+ case "sentence_transformers.models.Normalize":
+ p.normalizeEmbeddings = true
}
}
@@ -90,6 +93,7 @@ func (p *bertModel) KV(t *Tokenizer) ggml.KV {
kv["general.architecture"] = "bert"
kv["bert.attention.causal"] = false
kv["bert.pooling_type"] = p.PoolingType
+ kv["bert.normalize_embeddings"] = p.normalizeEmbeddings
kv["bert.block_count"] = cmp.Or(p.NLayers, p.NumHiddenLayers, p.NLayer)
diff --git a/discover/cuda_common.go b/discover/cuda_common.go
index ca008af63..3c7a92114 100644
--- a/discover/cuda_common.go
+++ b/discover/cuda_common.go
@@ -45,10 +45,18 @@ func cudaVariant(gpuInfo CudaGPUInfo) string {
}
}
+ // Check GPU compute capability FIRST
+ isOldGPU := gpuInfo.computeMajor < 7 || (gpuInfo.computeMajor == 7 && gpuInfo.computeMinor < 5)
+ if isOldGPU {
+ // GPU is Pascal or older (CC <= 7.4) - use CUDA v12 (supports CC 6.1)
+ return "v12"
+ }
+
+ // GPU is Turing or newer (CC >= 7.5) - can use newer CUDA
if gpuInfo.DriverMajor < 13 {
// The detected driver is older than 580 (Aug 2025)
// Warn if their CC is compatible with v13 and they should upgrade their driver to get better performance
- if gpuInfo.computeMajor > 7 || (gpuInfo.computeMajor == 7 && gpuInfo.computeMinor >= 5) {
+ if !isOldGPU {
slog.Warn("old CUDA driver detected - please upgrade to a newer driver for best performance", "version", fmt.Sprintf("%d.%d", gpuInfo.DriverMajor, gpuInfo.DriverMinor))
}
return "v12"
diff --git a/docs/development.md b/docs/development.md
index 9726b5d91..ff07b5fb6 100644
--- a/docs/development.md
+++ b/docs/development.md
@@ -11,6 +11,10 @@ Then build and run Ollama from the root directory of the repository:
go run . serve
```
+> [!NOTE]
+> Ollama includes native code compiled with CGO. From time to time these data structures can change and CGO can get out of sync resulting in unexpected crashes. You can force a full build of the native code by running `go clean -cache` first.
+
+
## macOS (Apple Silicon)
macOS Apple Silicon supports Metal which is built-in to the Ollama binary. No additional steps are required.
diff --git a/harmony/harmonyparser.go b/harmony/harmonyparser.go
index 3ec2c21f1..a51819dda 100644
--- a/harmony/harmonyparser.go
+++ b/harmony/harmonyparser.go
@@ -3,29 +3,15 @@ package harmony
import (
"fmt"
"log/slog"
- "slices"
"strings"
"unicode"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/logutil"
- "github.com/ollama/ollama/template"
)
type harmonyParserState int
-func ShouldUseHarmony(modelFamily string, template *template.Template) bool {
- if slices.Contains([]string{"gptoss", "gpt-oss"}, modelFamily) {
- // heuristic to check whether the template expects to be parsed via harmony:
- // search for harmony tags that are nearly always used
- if template.Contains("<|start|>") && template.Contains("<|end|>") {
- return true
- }
- }
-
- return false
-}
-
const (
harmonyParserState_LookingForMessageStart harmonyParserState = iota
harmonyParserState_ParsingHeader
@@ -89,28 +75,18 @@ func (s *HarmonyParser) AddImplicitStart() {
s.acc.WriteString("<|start|>assistant")
}
-func Prefill(lastMessage api.Message) string {
- if lastMessage.Role != "assistant" {
- return ""
- }
-
- switch {
- case strings.TrimSpace(lastMessage.Content) != "":
- return "<|start|>assistant<|channel|>final<|message|>"
- case strings.TrimSpace(lastMessage.Thinking) != "":
- return "<|start|>assistant<|channel|>analysis<|message|>"
- default:
- return ""
- }
-}
-
-// AddImplicitStartOrPrefill adds an implicit start tag or prefill string if provided
-func (s *HarmonyParser) AddImplicitStartOrPrefill(prefillString string) {
- if strings.TrimSpace(prefillString) != "" {
- s.acc.WriteString(prefillString)
- } else {
- s.AddImplicitStart()
+func (s *HarmonyParser) AddImplicitStartOrPrefill(lastMessage *api.Message) {
+ if lastMessage != nil && lastMessage.Role == "assistant" {
+ // handle prefilling conditions
+ if lastMessage.Content != "" {
+ s.acc.WriteString("<|start|>assistant<|channel|>final<|message|>")
+ return
+ } else if lastMessage.Thinking != "" {
+ s.acc.WriteString("<|start|>assistant<|channel|>analysis<|message|>")
+ return
+ }
}
+ s.AddImplicitStart()
}
func (s *HarmonyParser) AddContent(content string) []HarmonyEvent {
@@ -289,7 +265,6 @@ type HarmonyMessageHandler struct {
state harmonyMessageState
HarmonyParser *HarmonyParser
FunctionNameMap *FunctionNameMap
- ToolParser *HarmonyToolCallAccumulator
}
// NewHarmonyMessageHandler creates a new message handler
@@ -302,16 +277,12 @@ func NewHarmonyMessageHandler() *HarmonyMessageHandler {
HeaderEndTag: "<|message|>",
},
FunctionNameMap: NewFunctionNameMap(),
- ToolParser: &HarmonyToolCallAccumulator{
- state: harmonyToolCallState_Normal,
- currentToolName: nil,
- },
}
}
// AddContent processes the content and returns the content, thinking, and tool content.
// content and thinking are already fully parsed, but tool content still needs to be passed to the tool parser
-func (h *HarmonyMessageHandler) AddContent(content string) (string, string, string) {
+func (h *HarmonyMessageHandler) AddContent(content string, toolParser *HarmonyToolCallAccumulator) (string, string, string) {
contentSb := strings.Builder{}
thinkingSb := strings.Builder{}
toolContentSb := strings.Builder{}
@@ -328,14 +299,14 @@ func (h *HarmonyMessageHandler) AddContent(content string) (string, string, stri
// event.Header.Recipient is the tool name, something like
// "browser.search" for a built-in, or "functions.calc" for a
// custom one
- h.ToolParser.SetToolName(event.Header.Recipient)
+ toolParser.SetToolName(event.Header.Recipient)
} else {
h.state = harmonyMessageState_Thinking
}
case "commentary":
if event.Header.Recipient != "" {
h.state = harmonyMessageState_ToolCalling
- h.ToolParser.SetToolName(event.Header.Recipient)
+ toolParser.SetToolName(event.Header.Recipient)
} else {
h.state = harmonyMessageState_Normal
}
@@ -358,6 +329,13 @@ func (h *HarmonyMessageHandler) AddContent(content string) (string, string, stri
return contentSb.String(), thinkingSb.String(), toolContentSb.String()
}
+func (h *HarmonyMessageHandler) CreateToolParser() *HarmonyToolCallAccumulator {
+ return &HarmonyToolCallAccumulator{
+ state: harmonyToolCallState_Normal,
+ currentToolName: nil,
+ }
+}
+
type harmonyToolCallState int
const (
diff --git a/harmony/harmonyparser_test.go b/harmony/harmonyparser_test.go
index 82bf5b2de..b988a018f 100644
--- a/harmony/harmonyparser_test.go
+++ b/harmony/harmonyparser_test.go
@@ -3,7 +3,6 @@ package harmony
import (
"fmt"
"reflect"
- "strings"
"testing"
)
@@ -536,202 +535,3 @@ func TestFunctionConvertAndAdd(t *testing.T) {
})
}
}
-
-func TestHarmonyMessageHandlerStreamingScenarios(t *testing.T) {
- t.Run("thinking_then_content_streams", func(t *testing.T) {
- handler := NewHarmonyMessageHandler()
- handler.HarmonyParser.AddImplicitStart()
- tp := handler.ToolParser
- type step struct {
- in string
- wantContent string
- wantThinking string
- }
- steps := []step{
- {in: "<|channel|>analysis<|message|>Thinking...", wantThinking: "Thinking..."},
- {in: "<|end|>", wantThinking: ""},
- {in: "<|start|>assistant<|message|>Answer", wantContent: "Answer"},
- {in: "<|end|>", wantContent: ""},
- }
- for i, s := range steps {
- content, thinking, tool := handler.AddContent(s.in)
- if tool != "" {
- tp.Add(tool)
- }
- if content != s.wantContent || thinking != s.wantThinking {
- t.Fatalf("step %d: got (content=%q thinking=%q), want (content=%q thinking=%q)", i, content, thinking, s.wantContent, s.wantThinking)
- }
- }
- })
-
- t.Run("content_streams_as_it_arrives", func(t *testing.T) {
- handler := NewHarmonyMessageHandler()
- handler.HarmonyParser.AddImplicitStart()
- tp := handler.ToolParser
- inputs := []string{
- "<|start|>assistant<|message|>Hello",
- ", world",
- "!<|end|>",
- }
- var got []string
- for _, in := range inputs {
- content, thinking, tool := handler.AddContent(in)
- if tool != "" {
- tp.Add(tool)
- }
- if thinking != "" {
- t.Fatalf("unexpected thinking %q", thinking)
- }
- if content != "" {
- got = append(got, content)
- }
- }
- want := []string{"Hello", ", world", "!"}
- if !reflect.DeepEqual(got, want) {
- t.Fatalf("content pieces mismatch: got %v want %v", got, want)
- }
- })
-
- t.Run("thinking_streams_separately_from_content", func(t *testing.T) {
- handler := NewHarmonyMessageHandler()
- handler.HarmonyParser.AddImplicitStart()
- tp := handler.ToolParser
- inputs := []string{
- "<|channel|>analysis<|message|>Thinking...",
- "<|end|>",
- "<|start|>assistant<|message|>Answer",
- "<|end|>",
- }
- var got []string
- for _, in := range inputs {
- content, thinking, tool := handler.AddContent(in)
- if tool != "" {
- tp.Add(tool)
- }
- if thinking != "" {
- got = append(got, thinking)
- }
- if content != "" {
- got = append(got, content)
- }
- }
- want := []string{"Thinking...", "Answer"}
- if !reflect.DeepEqual(got, want) {
- t.Fatalf("content pieces mismatch: got %v want %v", got, want)
- }
- })
-
- t.Run("partial_tags_buffer_until_complete", func(t *testing.T) {
- handler := NewHarmonyMessageHandler()
- handler.HarmonyParser.AddImplicitStart()
- tp := handler.ToolParser
- inputs := []string{
- "<|chan",
- "nel|>analysis<|mess",
- "age|>Deep ",
- "thought",
- "<|end|>",
- "<|start|>assistant<|message|>Done",
- "<|end|>",
- }
- var thinkingPieces []string
- var contentPieces []string
- for _, in := range inputs {
- content, thinking, tool := handler.AddContent(in)
- if tool != "" {
- tp.Add(tool)
- }
- if thinking != "" {
- thinkingPieces = append(thinkingPieces, thinking)
- }
- if content != "" {
- contentPieces = append(contentPieces, content)
- }
- }
- if want := []string{"Deep ", "thought"}; !reflect.DeepEqual(thinkingPieces, want) {
- t.Fatalf("thinking pieces mismatch: got %v want %v", thinkingPieces, want)
- }
- if want := []string{"Done"}; !reflect.DeepEqual(contentPieces, want) {
- t.Fatalf("content pieces mismatch: got %v want %v", contentPieces, want)
- }
- })
-
- t.Run("simple_assistant_after_analysis", func(t *testing.T) {
- handler := NewHarmonyMessageHandler()
- handler.HarmonyParser.AddImplicitStart()
- tp := handler.ToolParser
- inputs := []string{
- "<|channel|>analysis<|message|>Think",
- "<|end|>",
- "<|start|>assistant<|message|>Answer",
- "<|end|>",
- }
- var contentSb, thinkingSb strings.Builder
- for _, in := range inputs {
- content, thinking, tool := handler.AddContent(in)
- if tool != "" {
- tp.Add(tool)
- }
- contentSb.WriteString(content)
- thinkingSb.WriteString(thinking)
- }
- if contentSb.String() != "Answer" {
- t.Fatalf("content mismatch: got %q want %q", contentSb.String(), "Answer")
- }
- if thinkingSb.String() != "Think" {
- t.Fatalf("thinking mismatch: got %q want %q", thinkingSb.String(), "Think")
- }
- })
-
- t.Run("tool_call_parsed_and_returned_correctly", func(t *testing.T) {
- handler := NewHarmonyMessageHandler()
- handler.HarmonyParser.AddImplicitStart()
- tp := handler.ToolParser
- inputs := []string{
- "<|channel|>commentary to=functions.calculate<|message|>{\"expression\":\"2+2\"}<|end|>",
- }
- for _, in := range inputs {
- content, thinking, tool := handler.AddContent(in)
- if content != "" || thinking != "" {
- continue
- }
- if tool != "" {
- tp.Add(tool)
- }
- }
- name, args := tp.Drain()
- if name == nil || *name != "functions.calculate" {
- t.Fatalf("unexpected tool name: %v", name)
- }
- if got, want := args, "{\"expression\":\"2+2\"}"; got != want {
- t.Fatalf("unexpected tool args: got %s want %s", got, want)
- }
- })
-
- t.Run("tool_call_across_chunks", func(t *testing.T) {
- handler := NewHarmonyMessageHandler()
- handler.HarmonyParser.AddImplicitStart()
- tp := handler.ToolParser
- inputs := []string{
- "<|channel|>commentary to=functions.calculate<|message|>{\"expression\":\"2+",
- "2\"}",
- "<|end|>",
- }
- for _, in := range inputs {
- content, thinking, tool := handler.AddContent(in)
- if content != "" || thinking != "" {
- continue
- }
- if tool != "" {
- tp.Add(tool)
- }
- }
- name, args := tp.Drain()
- if name == nil || *name != "functions.calculate" {
- t.Fatalf("unexpected tool name: %v", name)
- }
- if got, want := args, "{\"expression\":\"2+2\"}"; got != want {
- t.Fatalf("unexpected tool args: got %s want %s", got, want)
- }
- })
-}
diff --git a/integration/context_test.go b/integration/context_test.go
index ca6f16087..15c157858 100644
--- a/integration/context_test.go
+++ b/integration/context_test.go
@@ -50,7 +50,7 @@ func TestContextExhaustion(t *testing.T) {
// Set up the test data
req := api.GenerateRequest{
Model: smol,
- Prompt: "Write me a story with a ton of emojis?",
+ Prompt: "Write me a story in english with a lot of emojis",
Stream: &stream,
Options: map[string]any{
"temperature": 0,
diff --git a/integration/utils_test.go b/integration/utils_test.go
index ec74b2e3d..7901fed3f 100644
--- a/integration/utils_test.go
+++ b/integration/utils_test.go
@@ -561,7 +561,7 @@ func GenerateRequests() ([]api.GenerateRequest, [][]string) {
KeepAlive: &api.Duration{Duration: 10 * time.Second},
}, {
Model: smol,
- Prompt: "what is the origin of the US thanksgiving holiday? Be brief but factual in your reply",
+ Prompt: "how do rainbows form? Be brief but factual in your reply",
Stream: &stream,
KeepAlive: &api.Duration{Duration: 10 * time.Second},
}, {
@@ -579,9 +579,9 @@ func GenerateRequests() ([]api.GenerateRequest, [][]string) {
[][]string{
{"sunlight", "scattering", "interact", "color", "surface", "depth", "red", "orange", "yellow", "absorbs", "wavelength"},
{"soil", "organic", "earth", "black", "tan", "chemical", "processes", "pigments", "particles", "iron oxide", "rust", "air", "water", "mixture", "mixing"},
- {"england", "english", "massachusetts", "pilgrims", "colonists", "independence", "british", "feast", "family", "gatherings", "traditions", "turkey", "colonial", "period", "harvest", "agricultural", "european settlers", "american revolution", "civil war", "16th century", "17th century", "native american", "united states", "cultural", "hardship", "autumn", "festival"},
+ {"water", "droplet", "refracted", "reflect", "color", "spectrum"},
{"fourth", "july", "declaration", "independence"},
- {"nitrogen", "oxygen", "carbon", "dioxide"},
+ {"nitrogen", "oxygen", "carbon", "dioxide", "water", "vapor"},
}
}
diff --git a/llama/llama.go b/llama/llama.go
index ac2c112c2..88672a033 100644
--- a/llama/llama.go
+++ b/llama/llama.go
@@ -515,33 +515,34 @@ func (c *MtmdContext) NewEmbed(llamaContext *Context, data []byte) ([][]float32,
}
nChunks := C.mtmd_input_chunks_size(ic)
numEmbed := llamaContext.Model().NEmbd()
- lastChunkSize := 0
+ embed := make([][]float32, 0)
for i := range int(nChunks) {
chunk := C.mtmd_input_chunks_get(ic, C.size_t(i))
numTokens := int(C.mtmd_input_chunk_get_n_tokens(chunk))
- lastChunkSize = numTokens
+ slog.Debug("chunk tokens", "index", i, "numTokens", numTokens)
// Encode the chunk
if C.int32_t(0) != C.mtmd_encode_chunk(c.c, chunk) {
return nil, errors.New("unable to encode mtmd image chunk")
}
- }
- // Get the embeddings
- embed := make([][]float32, lastChunkSize)
- embd := C.mtmd_get_output_embd(c.c)
- if nil == embd {
- return nil, errors.New("failed to get image embedding")
- }
+ // Get the embeddings for this chunk
+ chunkEmbed := make([][]float32, numTokens)
+ chunkEmbd := C.mtmd_get_output_embd(c.c)
+ if nil == chunkEmbd {
+ continue
+ }
- // Extend the embedding array for each token
- s := unsafe.Slice((*float32)(embd), numEmbed*lastChunkSize)
- rows := make([]float32, len(s))
- copy(rows, s)
- for i := range lastChunkSize {
- embed[i] = rows[i*numEmbed : (i+1)*numEmbed]
+ // Extend the embedding array for each token
+ s := unsafe.Slice((*float32)(chunkEmbd), numTokens*numEmbed)
+ rows := make([]float32, len(s))
+ copy(rows, s)
+ for i := range numTokens {
+ chunkEmbed[i] = rows[i*numEmbed : (i+1)*numEmbed]
+ }
+ embed = append(embed, chunkEmbed...)
}
-
+ slog.Debug("image embeddings", "totalEmbeddings", len(embed))
return embed, nil
}
diff --git a/llm/server.go b/llm/server.go
index 2af82fa04..7f7d68cd0 100644
--- a/llm/server.go
+++ b/llm/server.go
@@ -35,7 +35,6 @@ import (
"github.com/ollama/ollama/logutil"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model"
- "github.com/ollama/ollama/parser"
)
type filteredEnv []string
@@ -1350,9 +1349,7 @@ type CompletionRequest struct {
Images []ImageData
Options *api.Options
- Grammar string // set before sending the request to the subprocess
- ParserType parser.TokenParserType
- PrefillString string
+ Grammar string // set before sending the request to the subprocess
}
// DoneReason represents the reason why a completion response is done
@@ -1379,15 +1376,13 @@ func (d DoneReason) String() string {
}
type CompletionResponse struct {
- Content string `json:"content"`
- Thinking string `json:"thinking"`
- ToolCalls []api.ToolCall `json:"tool_calls"`
- DoneReason DoneReason `json:"done_reason"`
- Done bool `json:"done"`
- PromptEvalCount int `json:"prompt_eval_count"`
- PromptEvalDuration time.Duration `json:"prompt_eval_duration"`
- EvalCount int `json:"eval_count"`
- EvalDuration time.Duration `json:"eval_duration"`
+ Content string `json:"content"`
+ DoneReason DoneReason `json:"done_reason"`
+ Done bool `json:"done"`
+ PromptEvalCount int `json:"prompt_eval_count"`
+ PromptEvalDuration time.Duration `json:"prompt_eval_duration"`
+ EvalCount int `json:"eval_count"`
+ EvalDuration time.Duration `json:"eval_duration"`
}
func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error {
@@ -1505,8 +1500,7 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
return fmt.Errorf("error unmarshalling llm prediction response: %v", err)
}
switch {
- // TODO(parthsareen): token repeat limit is now handled in the runner, this currently support legacy model and can be removed in the future
- case strings.TrimSpace(c.Content) == lastToken && c.Content != "":
+ case strings.TrimSpace(c.Content) == lastToken:
tokenRepeat++
default:
lastToken = strings.TrimSpace(c.Content)
@@ -1519,14 +1513,16 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
return ctx.Err()
}
+ if c.Content != "" {
+ fn(CompletionResponse{
+ Content: c.Content,
+ })
+ }
+
if c.Done {
fn(c)
return nil
}
-
- if c.Content != "" || c.Thinking != "" || len(c.ToolCalls) > 0 {
- fn(c)
- }
}
}
diff --git a/ml/backend.go b/ml/backend.go
index 154a0f1b5..455715b0d 100644
--- a/ml/backend.go
+++ b/ml/backend.go
@@ -416,6 +416,7 @@ type Tensor interface {
AddID(ctx Context, t2, ids Tensor) Tensor
Softmax(ctx Context) Tensor
+ L2Norm(ctx Context, eps float32) Tensor
LayerNorm(ctx Context, weight, bias Tensor, eps float32) Tensor
RMSNorm(ctx Context, weight Tensor, eps float32) Tensor
Scale(ctx Context, s float64) Tensor
@@ -429,12 +430,13 @@ type Tensor interface {
Sin(ctx Context) Tensor
Cos(ctx Context) Tensor
Tanh(ctx Context) Tensor
- GELU(ctx Context) Tensor
- QuickGELU(ctx Context) Tensor
- SILU(ctx Context) Tensor
- RELU(ctx Context) Tensor
+ GELU(ctx Context, up ...Tensor) Tensor
+ SILU(ctx Context, up ...Tensor) Tensor
+ RELU(ctx Context, up ...Tensor) Tensor
Sigmoid(ctx Context) Tensor
- SwiGLU(ctx Context, up Tensor, alpha, limit float32) Tensor
+
+ // AlphaLimitSILU is a variant of SILU that clamps the input to the range [-limit, limit]
+ SILUAlphaLimit(ctx Context, up Tensor, alpha, limit float32) Tensor
Reshape(ctx Context, shape ...int) Tensor
View(ctx Context, offset int, shape ...int) Tensor
diff --git a/ml/backend/ggml/ggml.go b/ml/backend/ggml/ggml.go
index 931386d56..49dc3e1ab 100644
--- a/ml/backend/ggml/ggml.go
+++ b/ml/backend/ggml/ggml.go
@@ -1205,6 +1205,13 @@ func (t *Tensor) AddID(ctx ml.Context, t2, ids ml.Tensor) ml.Tensor {
}
}
+func (t *Tensor) L2Norm(ctx ml.Context, eps float32) ml.Tensor {
+ return &Tensor{
+ b: t.b,
+ t: C.ggml_l2_norm(ctx.(*Context).ctx, t.t, C.float(eps)),
+ }
+}
+
func (t *Tensor) LayerNorm(ctx ml.Context, w, b ml.Tensor, eps float32) ml.Tensor {
tt := C.ggml_norm(ctx.(*Context).ctx, t.t, C.float(eps))
if w != nil {
@@ -1424,35 +1431,46 @@ func (t *Tensor) IM2Col(ctx ml.Context, t2 ml.Tensor, s0, s1, p0, p1, d0, d1 int
}
}
-func (t *Tensor) GELU(ctx ml.Context) ml.Tensor {
+func (t *Tensor) GELU(ctx ml.Context, t2 ...ml.Tensor) ml.Tensor {
+ if len(t2) > 0 {
+ return &Tensor{
+ b: t.b,
+ t: C.ggml_geglu_split(ctx.(*Context).ctx, t.t, t2[0].(*Tensor).t),
+ }
+ }
return &Tensor{
b: t.b,
t: C.ggml_gelu_inplace(ctx.(*Context).ctx, t.t),
}
}
-func (t *Tensor) QuickGELU(ctx ml.Context) ml.Tensor {
- return &Tensor{
- b: t.b,
- t: C.ggml_gelu_quick_inplace(ctx.(*Context).ctx, t.t),
+func (t *Tensor) SILU(ctx ml.Context, t2 ...ml.Tensor) ml.Tensor {
+ if len(t2) > 0 {
+ return &Tensor{
+ b: t.b,
+ t: C.ggml_swiglu_split(ctx.(*Context).ctx, t.t, t2[0].(*Tensor).t),
+ }
}
-}
-
-func (t *Tensor) SILU(ctx ml.Context) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_silu_inplace(ctx.(*Context).ctx, t.t),
}
}
-func (t *Tensor) RELU(ctx ml.Context) ml.Tensor {
+func (t *Tensor) RELU(ctx ml.Context, t2 ...ml.Tensor) ml.Tensor {
+ if len(t2) > 0 {
+ return &Tensor{
+ b: t.b,
+ t: C.ggml_reglu_split(ctx.(*Context).ctx, t.t, t2[0].(*Tensor).t),
+ }
+ }
return &Tensor{
b: t.b,
t: C.ggml_relu_inplace(ctx.(*Context).ctx, t.t),
}
}
-func (t *Tensor) SwiGLU(ctx ml.Context, up ml.Tensor, alpha, limit float32) ml.Tensor {
+func (t *Tensor) SILUAlphaLimit(ctx ml.Context, up ml.Tensor, alpha, limit float32) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_swiglu_oai(ctx.(*Context).ctx, t.t, up.(*Tensor).t, C.float(alpha), C.float(limit)),
diff --git a/ml/nn/attention.go b/ml/nn/attention.go
index 21b4a28ae..94dbde0b0 100644
--- a/ml/nn/attention.go
+++ b/ml/nn/attention.go
@@ -26,6 +26,7 @@ func Attention(ctx ml.Context, query, key, value ml.Tensor, scale float64, cache
}
func AttentionWithSinks(ctx ml.Context, query, key, value, sinks ml.Tensor, scale float64, cache kvcache.Cache) ml.Tensor {
+ ctx.Forward(query)
if key != nil && value != nil {
if query.Dim(0) != key.Dim(0) {
panic(fmt.Errorf("d_k in attention operation does not match between query(%v) and key(%v)", query.Dim(0), key.Dim(0)))
@@ -39,6 +40,7 @@ func AttentionWithSinks(ctx ml.Context, query, key, value, sinks ml.Tensor, scal
panic(fmt.Errorf("seq_len_k in attention operation does not match between key(%v) and value(%v)", key.Dim(2), value.Dim(2)))
}
+ ctx.Forward(key, value)
if cache != nil {
cache.Put(ctx, key, value)
}
diff --git a/ml/nn/pooling/pooling.go b/ml/nn/pooling/pooling.go
new file mode 100644
index 000000000..63b63b3af
--- /dev/null
+++ b/ml/nn/pooling/pooling.go
@@ -0,0 +1,42 @@
+package pooling
+
+import (
+ "github.com/ollama/ollama/ml"
+)
+
+type Type uint32
+
+const (
+ TypeNone Type = iota
+ TypeMean
+ TypeCLS
+ TypeLast
+)
+
+func (t Type) String() string {
+ switch t {
+ case TypeMean:
+ return "Mean"
+ case TypeCLS:
+ return "CLS"
+ case TypeLast:
+ return "Last"
+ default:
+ return "Unknown"
+ }
+}
+
+func (t Type) Forward(ctx ml.Context, hiddenStates ml.Tensor) ml.Tensor {
+ switch t {
+ case TypeMean:
+ hiddenStates = hiddenStates.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx).Mean(ctx)
+ return hiddenStates.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
+ case TypeCLS:
+ return hiddenStates.View(ctx, 0, hiddenStates.Dim(0))
+ case TypeLast:
+ hiddenStates = hiddenStates.View(ctx, (hiddenStates.Dim(1)-1)*hiddenStates.Stride(1), hiddenStates.Dim(0))
+ return hiddenStates
+ default:
+ panic("unknown pooling type")
+ }
+}
diff --git a/ml/nn/pooling/pooling_test.go b/ml/nn/pooling/pooling_test.go
new file mode 100644
index 000000000..c80019459
--- /dev/null
+++ b/ml/nn/pooling/pooling_test.go
@@ -0,0 +1,79 @@
+package pooling_test
+
+import (
+ "bytes"
+ "os"
+ "slices"
+ "testing"
+
+ "github.com/google/go-cmp/cmp"
+ "github.com/ollama/ollama/discover"
+ fsggml "github.com/ollama/ollama/fs/ggml"
+ "github.com/ollama/ollama/ml"
+ "github.com/ollama/ollama/ml/backend/ggml"
+ "github.com/ollama/ollama/ml/nn/pooling"
+)
+
+func setup(tb testing.TB, n int) ml.Backend {
+ tb.Helper()
+
+ f, err := os.CreateTemp(tb.TempDir(), "*.bin")
+ if err != nil {
+ tb.Fatal(err)
+ }
+ defer f.Close()
+
+ if err := fsggml.WriteGGUF(f, fsggml.KV{
+ "general.architecture": "test",
+ "test.block_count": uint32(1),
+ }, []*fsggml.Tensor{
+ {Name: "blk.0.weight", Shape: []uint64{1}, WriterTo: bytes.NewBuffer(make([]byte, 4))},
+ }); err != nil {
+ tb.Fatal(err)
+ }
+
+ var gpuLayers ml.GPULayersList
+ if gpus := discover.GetGPUInfo(); len(gpus) > 0 {
+ gpuLayers = append(gpuLayers, ml.GPULayers{
+ ID: gpus[0].ID,
+ Layers: slices.Collect(func(yield func(int) bool) {
+ for i := range n {
+ if !yield(i) {
+ return
+ }
+ }
+ }),
+ })
+ }
+ b, err := ggml.New(f.Name(), ml.BackendParams{AllocMemory: true, GPULayers: gpuLayers})
+ if err != nil {
+ tb.Fatal(err)
+ }
+
+ return b
+}
+
+func TestForward(t *testing.T) {
+ cases := map[pooling.Type][]float32{
+ pooling.TypeMean: {4, 5, 6, 7, 8, 9, 10, 11},
+ pooling.TypeCLS: {0, 1, 2, 3, 4, 5, 6, 7},
+ pooling.TypeLast: {8, 9, 10, 11, 12, 13, 14, 15},
+ }
+ for typ, want := range cases {
+ t.Run(typ.String(), func(t *testing.T) {
+ b := setup(t, 99)
+ defer b.Close()
+
+ ctx := b.NewContext()
+ defer ctx.Close()
+
+ tt := ctx.Input().Arange(0, 16, 1, ml.DTypeF32).Reshape(ctx, 8, 2)
+ tt = typ.Forward(ctx, tt)
+
+ ctx.Forward(tt).Compute(tt)
+ if diff := cmp.Diff(want, tt.Floats()); diff != "" {
+ t.Error(diff)
+ }
+ })
+ }
+}
diff --git a/model/input/input.go b/model/input/input.go
index bd9b53ec6..35dc41b35 100644
--- a/model/input/input.go
+++ b/model/input/input.go
@@ -54,10 +54,9 @@ type Batch struct {
// Inputs is the input tokens, including placeholders for multimodal inputs.
Inputs ml.Tensor
- // Multimodal is a set of multimodal embeddings previously created by
- // EncodeMultimodal, along with an index into Inputs. Unused for text-only
- // models or for batches without multimodal elements.
- Multimodal []MultimodalIndex
+ // Outputs are the set of indicies into Inputs for which output data should
+ // be returned.
+ Outputs ml.Tensor
// Positions is the position for each Input, relative to its sequence. Equal
// in length to Inputs.
@@ -66,7 +65,8 @@ type Batch struct {
// Sequences is the sequence for each Input. Equal in length to Inputs.
Sequences []int
- // Outputs are the set of indicies into Inputs for which output data should
- // be returned.
- Outputs []int32
+ // Multimodal is a set of multimodal embeddings previously created by
+ // EncodeMultimodal, along with an index into Inputs. Unused for text-only
+ // models or for batches without multimodal elements.
+ Multimodal []MultimodalIndex
}
diff --git a/model/model.go b/model/model.go
index 3a72f09aa..5493a4e63 100644
--- a/model/model.go
+++ b/model/model.go
@@ -5,7 +5,6 @@ import (
"fmt"
_ "image/jpeg"
_ "image/png"
- "math"
"os"
"reflect"
"strconv"
@@ -21,10 +20,15 @@ import (
"github.com/ollama/ollama/logutil"
"github.com/ollama/ollama/ml"
_ "github.com/ollama/ollama/ml/backend"
+ "github.com/ollama/ollama/ml/nn/pooling"
"github.com/ollama/ollama/model/input"
)
-var ErrNoVisionModel = errors.New("this model is missing data required for image input")
+var (
+ ErrNoVisionModel = errors.New("this model is missing data required for image input")
+ ErrUnsupportedModel = errors.New("model not supported")
+ ErrUnsupportedTokenizer = errors.New("tokenizer not supported")
+)
// Model implements a specific model architecture, defining the forward pass and any model-specific configuration
type Model interface {
@@ -104,7 +108,7 @@ func New(modelPath string, params ml.BackendParams) (Model, error) {
}
arch := b.Config().Architecture()
- if b.Config().Uint("pooling_type", math.MaxUint32) != math.MaxUint32 {
+ if pooling.Type(b.Config().Uint("pooling_type")) != pooling.TypeNone {
arch = arch + "_embed"
}
@@ -242,7 +246,7 @@ func setPointer(base Base, v reflect.Value, tags []Tag) {
vv = vv.Elem()
}
- vv = vv.Elem()
+ vv = reflect.Indirect(vv)
if v.IsNil() {
vv = reflect.New(v.Type().Elem()).Elem()
}
diff --git a/model/models/bert/embed.go b/model/models/bert/embed.go
new file mode 100644
index 000000000..166c11e13
--- /dev/null
+++ b/model/models/bert/embed.go
@@ -0,0 +1,181 @@
+package bert
+
+import (
+ "cmp"
+ "math"
+
+ "github.com/ollama/ollama/fs"
+ "github.com/ollama/ollama/ml"
+ "github.com/ollama/ollama/ml/nn"
+ "github.com/ollama/ollama/ml/nn/pooling"
+ "github.com/ollama/ollama/model"
+ "github.com/ollama/ollama/model/input"
+)
+
+type Model struct {
+ model.Base
+ model.TextProcessor
+
+ TokenEmbedding *nn.Embedding `gguf:"token_embd"`
+ TypeEmbedding *nn.Embedding `gguf:"token_types"`
+ PositionEmbedding *nn.Embedding `gguf:"position_embd"`
+ TokenEmbeddingNorm *nn.LayerNorm `gguf:"token_embd_norm"`
+
+ Layers []EncoderLayer `gguf:"blk"`
+
+ Options
+}
+
+// Forward implements model.Model.
+func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
+ hiddenStates := m.TokenEmbedding.Forward(ctx, batch.Inputs)
+ hiddenStates = hiddenStates.Add(ctx, m.TypeEmbedding.Weight.View(ctx, 0, m.hiddenSize))
+ hiddenStates = hiddenStates.Add(ctx, m.PositionEmbedding.Forward(ctx, ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))))
+ hiddenStates = m.TokenEmbeddingNorm.Forward(ctx, hiddenStates, m.eps)
+
+ for _, layer := range m.Layers {
+ hiddenStates = layer.Forward(ctx, hiddenStates, &m.Options)
+ }
+
+ hiddenStates = m.poolingType.Forward(ctx, hiddenStates)
+ if m.normalize {
+ hiddenStates = hiddenStates.L2Norm(ctx, 1e-12)
+ }
+
+ return hiddenStates, nil
+}
+
+type EncoderLayer struct {
+ *Attention
+ AttentionNorm *nn.LayerNorm `gguf:"attn_output_norm"`
+
+ *MLP
+ MLPNorm *nn.LayerNorm `gguf:"layer_output_norm"`
+}
+
+func (e *EncoderLayer) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options) ml.Tensor {
+ // Attention
+ residual := hiddenStates
+ hiddenStates = e.Attention.Forward(ctx, hiddenStates, opts)
+ hiddenStates = hiddenStates.Add(ctx, residual)
+ hiddenStates = e.AttentionNorm.Forward(ctx, hiddenStates, opts.eps)
+
+ // MLP
+ residual = hiddenStates
+ hiddenStates = e.MLP.Forward(ctx, hiddenStates, opts)
+ hiddenStates = hiddenStates.Add(ctx, residual)
+ hiddenStates = e.MLPNorm.Forward(ctx, hiddenStates, opts.eps)
+
+ return hiddenStates
+}
+
+type Attention struct {
+ Query *nn.Linear `gguf:"attn_q"`
+ QueryNorm *nn.LayerNorm `gguf:"attn_q_norm"`
+
+ Key *nn.Linear `gguf:"attn_k"`
+ KeyNorm *nn.LayerNorm `gguf:"attn_k_norm"`
+
+ Value *nn.Linear `gguf:"attn_v"`
+
+ Output *nn.Linear `gguf:"attn_output"`
+}
+
+func (a *Attention) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options) ml.Tensor {
+ batchSize := hiddenStates.Dim(1)
+
+ query := a.Query.Forward(ctx, hiddenStates)
+ if a.QueryNorm != nil {
+ query = a.QueryNorm.Forward(ctx, query, opts.eps)
+ }
+ query = query.Reshape(ctx, opts.headDim(), opts.numHeads, batchSize)
+
+ key := a.Key.Forward(ctx, hiddenStates)
+ if a.KeyNorm != nil {
+ key = a.KeyNorm.Forward(ctx, key, opts.eps)
+ }
+ key = key.Reshape(ctx, opts.headDim(), cmp.Or(opts.numKVHeads, opts.numHeads), batchSize)
+
+ value := a.Value.Forward(ctx, hiddenStates)
+ value = value.Reshape(ctx, opts.headDim(), cmp.Or(opts.numKVHeads, opts.numHeads), batchSize)
+
+ attention := nn.Attention(ctx, query, key, value, 1/math.Sqrt(float64(opts.headDim())), nil)
+ attention = attention.Reshape(ctx, opts.hiddenSize, batchSize)
+ return a.Output.Forward(ctx, attention)
+}
+
+type MLP struct {
+ Up *nn.Linear `gguf:"ffn_up"`
+ Down *nn.Linear `gguf:"ffn_down"`
+}
+
+func (m *MLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options) ml.Tensor {
+ return m.Down.Forward(ctx, m.Up.Forward(ctx, hiddenStates).GELU(ctx))
+}
+
+type Options struct {
+ hiddenSize,
+ numHeads,
+ numKVHeads,
+ keyLength,
+ valueLength int
+ poolingType pooling.Type
+ eps float32
+ normalize bool
+}
+
+func (o Options) headDim() int {
+ return cmp.Or(o.keyLength, o.valueLength, o.hiddenSize/o.numHeads)
+}
+
+func New(c fs.Config) (model.Model, error) {
+ var processor model.TextProcessor
+ switch c.String("tokenizer.ggml.model", "bert") {
+ case "bert":
+ processor = model.NewWordPiece(
+ &model.Vocabulary{
+ Values: c.Strings("tokenizer.ggml.tokens"),
+ Scores: c.Floats("tokenizer.ggml.scores"),
+ Types: c.Ints("tokenizer.ggml.token_type"),
+ AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
+ BOS: []int32{
+ int32(cmp.Or(
+ c.Uint("tokenizer.ggml.cls_token_id"),
+ c.Uint("tokenizer.ggml.bos_token_id"),
+ )),
+ },
+ AddEOS: c.Bool("tokenizer.ggml.add_eos_token", true),
+ EOS: []int32{
+ int32(cmp.Or(
+ c.Uint("tokenizer.ggml.separator_token_id"),
+ //nolint:misspell
+ // NOTE: "seperator_token_id" is a typo in model metadata but we need to
+ // support it for compatibility.
+ c.Uint("tokenizer.ggml.seperator_token_id"),
+ c.Uint("tokenizer.ggml.eos_token_id"),
+ )),
+ },
+ },
+ )
+ default:
+ return nil, model.ErrUnsupportedTokenizer
+ }
+
+ return &Model{
+ TextProcessor: processor,
+ Layers: make([]EncoderLayer, c.Uint("block_count")),
+ Options: Options{
+ hiddenSize: int(c.Uint("embedding_length")),
+ numHeads: int(c.Uint("attention.head_count")),
+ numKVHeads: int(c.Uint("attention.head_count_kv")),
+ eps: c.Float("attention.layer_norm_epsilon"),
+ poolingType: pooling.Type(c.Uint("pooling_type")),
+ normalize: c.Bool("normalize_embeddings", true),
+ },
+ }, nil
+}
+
+func init() {
+ model.Register("bert", New)
+ model.Register("bert_embed", New)
+}
diff --git a/model/models/gemma2/model.go b/model/models/gemma2/model.go
index e621d03ae..96ace7c74 100644
--- a/model/models/gemma2/model.go
+++ b/model/models/gemma2/model.go
@@ -24,7 +24,7 @@ type Options struct {
type Model struct {
model.Base
- model.SentencePieceModel
+ model.SentencePiece
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
Layers []Layer `gguf:"blk"`
@@ -40,7 +40,7 @@ const (
func New(c fs.Config) (model.Model, error) {
m := Model{
- SentencePieceModel: model.NewSentencePieceModel(
+ SentencePiece: model.NewSentencePiece(
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
@@ -138,7 +138,7 @@ type MLP struct {
}
func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *Options) ml.Tensor {
- hiddenState = mlp.Gate.Forward(ctx, hiddenState).GELU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState))
+ hiddenState = mlp.Gate.Forward(ctx, hiddenState).GELU(ctx, mlp.Up.Forward(ctx, hiddenState))
return mlp.Down.Forward(ctx, hiddenState)
}
@@ -176,7 +176,6 @@ func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Ten
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
- outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs)
hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.Options.hiddenSize)))
@@ -193,7 +192,7 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
var lastLayerOutputs ml.Tensor
if i == len(m.Layers)-1 {
- lastLayerOutputs = outputs
+ lastLayerOutputs = batch.Outputs
}
hiddenState = layer.Forward(ctx, hiddenState, positions, lastLayerOutputs, m.Cache, m.Options)
diff --git a/model/models/gemma3/embed.go b/model/models/gemma3/embed.go
index 16c299e22..525547767 100644
--- a/model/models/gemma3/embed.go
+++ b/model/models/gemma3/embed.go
@@ -1,49 +1,38 @@
package gemma3
import (
- "errors"
-
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
+ "github.com/ollama/ollama/ml/nn/pooling"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
)
type embedModel struct {
model.Base
- model.SentencePieceModel
+ model.SentencePiece
*TextModel
- PoolingType uint32
+ poolingType pooling.Type
Dense [2]*nn.Linear `gguf:"dense"`
}
func (m *embedModel) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
- batch.Outputs = batch.Positions // return all positions
hiddenStates := m.TextModel.Forward(ctx, batch, m.Cache)
-
- switch m.PoolingType {
- case 0: // None
- case 1: // Mean
- hiddenStates = hiddenStates.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx).Mean(ctx)
- hiddenStates = hiddenStates.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
- default:
- return nil, errors.New("unsupported pooling type")
- }
-
+ hiddenStates = m.poolingType.Forward(ctx, hiddenStates)
for _, dense := range m.Dense {
hiddenStates = dense.Forward(ctx, hiddenStates)
}
-
+ hiddenStates = hiddenStates.L2Norm(ctx, 1e-12)
return hiddenStates, nil
}
func newEmbedModel(c fs.Config) (model.Model, error) {
m := &embedModel{
- SentencePieceModel: model.NewSentencePieceModel(
+ SentencePiece: model.NewSentencePiece(
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
@@ -61,7 +50,7 @@ func newEmbedModel(c fs.Config) (model.Model, error) {
},
),
TextModel: newTextModel(c),
- PoolingType: c.Uint("pooling_type", 0),
+ poolingType: pooling.Type(c.Uint("pooling_type", 0)),
}
m.Cache = kvcache.NewWrapperCache(
diff --git a/model/models/gemma3/model.go b/model/models/gemma3/model.go
index 5c92b6bf9..27da889e4 100644
--- a/model/models/gemma3/model.go
+++ b/model/models/gemma3/model.go
@@ -16,7 +16,7 @@ import (
type Model struct {
model.Base
- model.SentencePieceModel
+ model.SentencePiece
*VisionModel `gguf:"v"`
*TextModel
@@ -55,7 +55,7 @@ func (p *MultiModalProjector) Forward(ctx ml.Context, visionOutputs ml.Tensor, i
func New(c fs.Config) (model.Model, error) {
m := Model{
- SentencePieceModel: model.NewSentencePieceModel(
+ SentencePiece: model.NewSentencePiece(
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
diff --git a/model/models/gemma3/model_text.go b/model/models/gemma3/model_text.go
index 2a3b23939..d38746dc8 100644
--- a/model/models/gemma3/model_text.go
+++ b/model/models/gemma3/model_text.go
@@ -123,7 +123,7 @@ type TextMLP struct {
}
func (mlp *TextMLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextConfig) ml.Tensor {
- hiddenState = mlp.Gate.Forward(ctx, hiddenState).GELU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState))
+ hiddenState = mlp.Gate.Forward(ctx, hiddenState).GELU(ctx, mlp.Up.Forward(ctx, hiddenState))
return mlp.Down.Forward(ctx, hiddenState)
}
@@ -161,7 +161,6 @@ func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, positionIDs,
func (m *TextModel) Forward(ctx ml.Context, batch input.Batch, cache kvcache.Cache) ml.Tensor {
positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
- outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs)
hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.TextConfig.hiddenSize)))
@@ -194,7 +193,7 @@ func (m *TextModel) Forward(ctx ml.Context, batch input.Batch, cache kvcache.Cac
var lastLayerOutputs ml.Tensor
if i == len(m.Layers)-1 {
- lastLayerOutputs = outputs
+ lastLayerOutputs = batch.Outputs
}
hiddenState = layer.Forward(ctx, i, hiddenState, positions, lastLayerOutputs, cache, m.TextConfig)
diff --git a/model/models/gemma3n/model.go b/model/models/gemma3n/model.go
index 6e83a9724..e59e3193f 100644
--- a/model/models/gemma3n/model.go
+++ b/model/models/gemma3n/model.go
@@ -10,7 +10,7 @@ import (
type Model struct {
model.Base
- model.SentencePieceModel
+ model.SentencePiece
*TextModel
}
@@ -23,7 +23,7 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
func New(c fs.Config) (model.Model, error) {
m := Model{
TextModel: newTextModel(c),
- SentencePieceModel: model.NewSentencePieceModel(
+ SentencePiece: model.NewSentencePiece(
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
diff --git a/model/models/gemma3n/model_text.go b/model/models/gemma3n/model_text.go
index b75a2abb3..2682a45f7 100644
--- a/model/models/gemma3n/model_text.go
+++ b/model/models/gemma3n/model_text.go
@@ -83,7 +83,7 @@ func (m *TextModel) Forward(ctx ml.Context, batch input.Batch, cache kvcache.Cac
hiddenStates = hiddenStates.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx).Mean(ctx)
hiddenStates = hiddenStates.Permute(ctx, 2, 0, 1, 3).Contiguous(ctx)
- hiddenStates = hiddenStates.Rows(ctx, ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs)))
+ hiddenStates = hiddenStates.Rows(ctx, batch.Outputs)
hiddenStates = m.OutputNorm.Forward(ctx, hiddenStates, m.eps)
return m.Output.Forward(ctx, hiddenStates), nil
@@ -170,8 +170,7 @@ func (d TextLayer) Forward(ctx ml.Context, hiddenStates, perLayerInput, position
}
active = d.PerLayerInputGate.Forward(ctx, active)
- active = active.GELU(ctx)
- active = active.Mul(ctx, perLayerInput)
+ active = active.GELU(ctx, perLayerInput)
active = d.PerLayerProjection.Forward(ctx, active)
active = d.PostPerLayerNorm.Forward(ctx, active, opts.eps)
@@ -292,7 +291,7 @@ func (mlp TextMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, activationSpa
hiddenStates = hiddenStates.Sub(ctx, cutoff).RELU(ctx)
}
- hiddenStates = hiddenStates.GELU(ctx).Mul(ctx, upStates)
+ hiddenStates = hiddenStates.GELU(ctx, upStates)
hiddenStates = mlp.Down.Forward(ctx, hiddenStates)
return hiddenStates
}
diff --git a/model/models/gptoss/model.go b/model/models/gptoss/model.go
index 3ef078095..8456ea5f7 100644
--- a/model/models/gptoss/model.go
+++ b/model/models/gptoss/model.go
@@ -41,8 +41,8 @@ func (m *Transformer) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, err
}
var outputs ml.Tensor
- if len(batch.Outputs) > 0 && i == len(m.TransformerBlocks)-1 {
- outputs = ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
+ if i == len(m.TransformerBlocks)-1 {
+ outputs = batch.Outputs
}
hiddenStates = block.Forward(ctx, hiddenStates, positions, outputs, one, m.Cache, &m.Options)
@@ -210,7 +210,7 @@ func (mlp *MLPBlock) Forward(ctx ml.Context, hiddenStates, one ml.Tensor, opts *
up = mlp.Up.Forward(ctx, hiddenStates, selectedExperts)
}
- hiddenStates = gate.SwiGLU(ctx, up, 1.702, 7)
+ hiddenStates = gate.SILUAlphaLimit(ctx, up, 1.702, 7)
experts := mlp.Down.Forward(ctx, hiddenStates, selectedExperts)
experts = experts.Mul(ctx, routingWeights)
diff --git a/model/models/llama/model.go b/model/models/llama/model.go
index 77d8f36d3..572c687a9 100644
--- a/model/models/llama/model.go
+++ b/model/models/llama/model.go
@@ -118,7 +118,7 @@ type MLP struct {
}
func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *Options) ml.Tensor {
- hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState))
+ hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx, mlp.Up.Forward(ctx, hiddenState))
return mlp.Down.Forward(ctx, hiddenState)
}
@@ -160,7 +160,7 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
var outputs ml.Tensor
if i == len(m.Layers)-1 {
- outputs = ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
+ outputs = batch.Outputs
}
hiddenState = layer.Forward(ctx, hiddenState, positions, outputs, m.Cache, m.Options)
diff --git a/model/models/llama4/model.go b/model/models/llama4/model.go
index 99a898d2d..9cb2efc87 100644
--- a/model/models/llama4/model.go
+++ b/model/models/llama4/model.go
@@ -176,9 +176,7 @@ func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
- outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
-
- return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, batch, m.Cache), nil
+ return m.TextModel.Forward(ctx, batch.Inputs, positions, batch.Outputs, batch, m.Cache), nil
}
func init() {
diff --git a/model/models/llama4/model_text.go b/model/models/llama4/model_text.go
index 045ab403f..dbe6bba7a 100644
--- a/model/models/llama4/model_text.go
+++ b/model/models/llama4/model_text.go
@@ -58,14 +58,14 @@ type TextMLP struct {
}
func (mlp *TextMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *TextOptions) ml.Tensor {
- hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenStates))
+ hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx, mlp.Up.Forward(ctx, hiddenStates))
return mlp.Down.Forward(ctx, hiddenStates)
}
type TextExperts struct {
- Gate *nn.Linear `gguf:"ffn_gate_exps"`
- Up *nn.Linear `gguf:"ffn_up_exps"`
- Down *nn.Linear `gguf:"ffn_down_exps"`
+ Gate *nn.LinearBatch `gguf:"ffn_gate_exps"`
+ Up *nn.LinearBatch `gguf:"ffn_up_exps"`
+ Down *nn.LinearBatch `gguf:"ffn_down_exps"`
}
func (e *TextExperts) Forward(ctx ml.Context, hiddenStates, routerLogits ml.Tensor, opts *TextOptions) ml.Tensor {
@@ -76,9 +76,9 @@ func (e *TextExperts) Forward(ctx ml.Context, hiddenStates, routerLogits ml.Tens
hiddenStates = hiddenStates.Repeat(ctx, 1, opts.numExpertsUsed)
hiddenStates = hiddenStates.Mul(ctx, scores)
- upStates := e.Up.Weight.MulmatID(ctx, hiddenStates, experts)
- gateStates := e.Gate.Weight.MulmatID(ctx, hiddenStates, experts)
- downStates := e.Down.Weight.MulmatID(ctx, upStates.Mul(ctx, gateStates.SILU(ctx)), experts)
+ upStates := e.Up.Forward(ctx, hiddenStates, experts)
+ gateStates := e.Gate.Forward(ctx, hiddenStates, experts)
+ downStates := e.Down.Forward(ctx, upStates.Mul(ctx, gateStates.SILU(ctx)), experts)
nextStates := downStates.View(ctx, 0, hiddenStates.Dim(0), downStates.Stride(2), hiddenStates.Dim(2))
for i := 1; i < opts.numExpertsUsed; i++ {
@@ -96,7 +96,7 @@ type TextSharedExpert struct {
}
func (mlp *TextSharedExpert) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *TextOptions) ml.Tensor {
- hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenStates))
+ hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx, mlp.Up.Forward(ctx, hiddenStates))
return mlp.Down.Forward(ctx, hiddenStates)
}
diff --git a/model/models/mistral3/model.go b/model/models/mistral3/model.go
index 408e54d3d..435b1a304 100644
--- a/model/models/mistral3/model.go
+++ b/model/models/mistral3/model.go
@@ -159,9 +159,8 @@ func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
- outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
- return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, batch, m.Cache), nil
+ return m.TextModel.Forward(ctx, batch.Inputs, positions, batch.Outputs, batch, m.Cache), nil
}
func init() {
diff --git a/model/models/mistral3/model_text.go b/model/models/mistral3/model_text.go
index 19c36f9fe..132d1756c 100644
--- a/model/models/mistral3/model_text.go
+++ b/model/models/mistral3/model_text.go
@@ -65,7 +65,7 @@ type MLP struct {
}
func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextOptions) ml.Tensor {
- hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState))
+ hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx, mlp.Up.Forward(ctx, hiddenState))
return mlp.Down.Forward(ctx, hiddenState)
}
diff --git a/model/models/mistral3/model_vision.go b/model/models/mistral3/model_vision.go
index 65bdcff2a..3bfb8c90a 100644
--- a/model/models/mistral3/model_vision.go
+++ b/model/models/mistral3/model_vision.go
@@ -51,7 +51,7 @@ type VisionMLP struct {
}
func (mlp *VisionMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *VisionModelOptions) ml.Tensor {
- hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenStates))
+ hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx, mlp.Up.Forward(ctx, hiddenStates))
return mlp.Down.Forward(ctx, hiddenStates)
}
diff --git a/model/models/mllama/model.go b/model/models/mllama/model.go
index d0ad4670e..239d999d5 100644
--- a/model/models/mllama/model.go
+++ b/model/models/mllama/model.go
@@ -107,10 +107,9 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
}
positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
- outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
// TODO: attention mask, cross attention mask
- return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, crossAttentionStates, nil, m.Cache.(*kvcache.WrapperCache)), nil
+ return m.TextModel.Forward(ctx, batch.Inputs, positions, batch.Outputs, crossAttentionStates, nil, m.Cache.(*kvcache.WrapperCache)), nil
}
func init() {
diff --git a/model/models/mllama/model_text.go b/model/models/mllama/model_text.go
index 47a518ced..cb18f0878 100644
--- a/model/models/mllama/model_text.go
+++ b/model/models/mllama/model_text.go
@@ -58,7 +58,7 @@ type TextMLP struct {
}
func (mlp *TextMLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextModelOptions) ml.Tensor {
- hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState))
+ hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx, mlp.Up.Forward(ctx, hiddenState))
return mlp.Down.Forward(ctx, hiddenState)
}
diff --git a/model/models/models.go b/model/models/models.go
index c880a4720..cc9980789 100644
--- a/model/models/models.go
+++ b/model/models/models.go
@@ -1,6 +1,7 @@
package models
import (
+ _ "github.com/ollama/ollama/model/models/bert"
_ "github.com/ollama/ollama/model/models/gemma2"
_ "github.com/ollama/ollama/model/models/gemma3"
_ "github.com/ollama/ollama/model/models/gemma3n"
diff --git a/model/models/qwen2/model.go b/model/models/qwen2/model.go
index 3c662f068..5a8bea29e 100644
--- a/model/models/qwen2/model.go
+++ b/model/models/qwen2/model.go
@@ -59,7 +59,7 @@ type MLP struct {
}
func (mlp MLP) Forward(ctx ml.Context, hiddenStates ml.Tensor) ml.Tensor {
- hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenStates))
+ hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx, mlp.Up.Forward(ctx, hiddenStates))
return mlp.Down.Forward(ctx, hiddenStates)
}
@@ -111,7 +111,7 @@ func (m Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
var outputs ml.Tensor
if i == len(m.Layers)-1 {
- outputs = ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
+ outputs = batch.Outputs
}
hiddenStates = layer.Forward(ctx, hiddenStates, positions, outputs, m.Cache, &m.Options)
diff --git a/model/models/qwen25vl/model.go b/model/models/qwen25vl/model.go
index d73f499d2..6c76305db 100644
--- a/model/models/qwen25vl/model.go
+++ b/model/models/qwen25vl/model.go
@@ -140,9 +140,8 @@ func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
- outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
- return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, batch, m.Cache)
+ return m.TextModel.Forward(ctx, batch.Inputs, positions, batch.Outputs, batch, m.Cache)
}
func init() {
diff --git a/model/models/qwen25vl/model_text.go b/model/models/qwen25vl/model_text.go
index 4b6bc1666..4f4e1effd 100644
--- a/model/models/qwen25vl/model_text.go
+++ b/model/models/qwen25vl/model_text.go
@@ -90,7 +90,7 @@ type MLP struct {
func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextOptions) ml.Tensor {
// Apply SwiGLU activation gating
- hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState))
+ hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx, mlp.Up.Forward(ctx, hiddenState))
// Project back to hidden dimension
return mlp.Down.Forward(ctx, hiddenState)
}
diff --git a/model/models/qwen25vl/model_vision.go b/model/models/qwen25vl/model_vision.go
index 4d7afaa14..3dd60e3ba 100644
--- a/model/models/qwen25vl/model_vision.go
+++ b/model/models/qwen25vl/model_vision.go
@@ -100,8 +100,7 @@ type VisionMLP struct {
func (mlp *VisionMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *VisionModelOptions) ml.Tensor {
// Using activation as specified in config (likely GELU or SiLU/Swish)
gateOutput := mlp.Gate.Forward(ctx, hiddenStates)
- upOutput := mlp.Up.Forward(ctx, hiddenStates)
- hiddenStates = gateOutput.SILU(ctx).Mul(ctx, upOutput)
+ hiddenStates = gateOutput.SILU(ctx, mlp.Up.Forward(ctx, hiddenStates))
return mlp.Down.Forward(ctx, hiddenStates)
}
diff --git a/model/models/qwen3/model.go b/model/models/qwen3/model.go
index 7a83e0d04..3f86d0236 100644
--- a/model/models/qwen3/model.go
+++ b/model/models/qwen3/model.go
@@ -30,10 +30,10 @@ func (o Options) headDim() int {
}
type Attention struct {
- QueryNorm *nn.RMSNorm `gguf:"attn_q_norm"`
Query *nn.Linear `gguf:"attn_q"`
- KeyNorm *nn.RMSNorm `gguf:"attn_k_norm"`
+ QueryNorm *nn.RMSNorm `gguf:"attn_q_norm"`
Key *nn.Linear `gguf:"attn_k"`
+ KeyNorm *nn.RMSNorm `gguf:"attn_k_norm"`
Value *nn.Linear `gguf:"attn_v"`
Output *nn.Linear `gguf:"attn_output"`
}
@@ -65,10 +65,10 @@ type MLP interface {
}
type sparse struct {
- Router *nn.Linear `gguf:"ffn_gate_inp"`
- Gate *nn.Linear `gguf:"ffn_gate_exps"`
- Up *nn.Linear `gguf:"ffn_up_exps"`
- Down *nn.Linear `gguf:"ffn_down_exps"`
+ Router *nn.Linear `gguf:"ffn_gate_inp"`
+ Gate *nn.LinearBatch `gguf:"ffn_gate_exps"`
+ Up *nn.LinearBatch `gguf:"ffn_up_exps"`
+ Down *nn.LinearBatch `gguf:"ffn_down_exps"`
}
func (mlp *sparse) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options) ml.Tensor {
@@ -87,13 +87,9 @@ func (mlp *sparse) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options
hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0), 1, hiddenStates.Dim(1))
- upStates := mlp.Up.Weight.MulmatID(ctx, hiddenStates, selectedExperts)
+ hiddenStates = mlp.Gate.Forward(ctx, hiddenStates, selectedExperts).SILU(ctx, mlp.Up.Forward(ctx, hiddenStates, selectedExperts))
- hiddenStates = mlp.Gate.Weight.MulmatID(ctx, hiddenStates, selectedExperts)
- hiddenStates = hiddenStates.SILU(ctx)
- hiddenStates = hiddenStates.Mul(ctx, upStates)
-
- experts := mlp.Down.Weight.MulmatID(ctx, hiddenStates, selectedExperts)
+ experts := mlp.Down.Forward(ctx, hiddenStates, selectedExperts)
experts = experts.Mul(ctx, routingWeights)
nextStates := experts.View(ctx, 0, experts.Dim(0), experts.Stride(2), experts.Dim(2))
@@ -111,7 +107,8 @@ type dense struct {
}
func (mlp *dense) Forward(ctx ml.Context, hiddenStates ml.Tensor, _ *Options) ml.Tensor {
- hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenStates))
+ hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).
+ SILU(ctx, mlp.Up.Forward(ctx, hiddenStates))
return mlp.Down.Forward(ctx, hiddenStates)
}
@@ -165,7 +162,7 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
var outputs ml.Tensor
if i == len(m.Layers)-1 {
- outputs = ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
+ outputs = batch.Outputs
}
hiddenStates = layer.Forward(ctx, hiddenStates, positions, outputs, m.Cache, m.Options)
diff --git a/model/parsers/parsers.go b/model/parsers/parsers.go
new file mode 100644
index 000000000..e6dbd1f4f
--- /dev/null
+++ b/model/parsers/parsers.go
@@ -0,0 +1,37 @@
+package parsers
+
+import (
+ "github.com/ollama/ollama/api"
+)
+
+type Parser interface {
+ Add(s string, tools []api.Tool) (content string, thinking string, calls []api.ToolCall, err error)
+ HasToolSupport() bool
+ HasThinkingSupport() bool
+}
+
+func ParserForName(name string) Parser {
+ switch name {
+ case "qwen3-coder":
+ parser := &Qwen3CoderParser{}
+ return parser
+ case "passthrough":
+ return &PassthroughParser{}
+ default:
+ return nil
+ }
+}
+
+type PassthroughParser struct{}
+
+func (p *PassthroughParser) Add(s string, tools []api.Tool) (content string, thinking string, calls []api.ToolCall, err error) {
+ return s, "", nil, nil
+}
+
+func (p *PassthroughParser) HasToolSupport() bool {
+ return false
+}
+
+func (p *PassthroughParser) HasThinkingSupport() bool {
+ return false
+}
diff --git a/model/parsers/qwen3coder.go b/model/parsers/qwen3coder.go
new file mode 100644
index 000000000..b0e8ec48c
--- /dev/null
+++ b/model/parsers/qwen3coder.go
@@ -0,0 +1,410 @@
+package parsers
+
+import (
+ "context"
+ "encoding/json"
+ "encoding/xml"
+ "fmt"
+ "log/slog"
+ "math"
+ "regexp"
+ "strconv"
+ "strings"
+ "unicode"
+
+ "github.com/ollama/ollama/api"
+ "github.com/ollama/ollama/logutil"
+)
+
+type qwenParserState int
+
+const (
+ toolOpenTag = ""
+ toolCloseTag = ""
+)
+
+const (
+ qwenParserState_LookingForToolStart qwenParserState = iota
+ qwenParserState_CollectingToolContent
+)
+
+type Qwen3CoderParser struct {
+ state qwenParserState
+ acc strings.Builder
+}
+
+func (p *Qwen3CoderParser) HasToolSupport() bool {
+ return true
+}
+
+func (p *Qwen3CoderParser) HasThinkingSupport() bool {
+ return false
+}
+
+func (p *Qwen3CoderParser) Add(s string, tools []api.Tool) (content string, thinking string, calls []api.ToolCall, err error) {
+ p.acc.WriteString(s)
+
+ events := p.parseEvents()
+
+ var toolCalls []api.ToolCall
+ var sb strings.Builder
+ for _, event := range events {
+ switch event := event.(type) {
+ case qwenEventRawToolCall:
+ toolCall, err := parseToolCall(event, tools)
+ if err != nil {
+ slog.Warn("qwen tool call parsing failed", "error", err)
+ return "", "", nil, err
+ }
+ toolCalls = append(toolCalls, toolCall)
+ case qwenEventContent:
+ // TODO(drifkin): if the same turn contains multiple interleaved content
+ // events, we naively append them together here. See the note below about
+ // `qwenEvent`s for more details
+ sb.WriteString(event.content)
+ }
+ }
+
+ return sb.String(), "", toolCalls, nil
+}
+
+func (p *Qwen3CoderParser) parseEvents() []qwenEvent {
+ var all []qwenEvent
+
+ keepLooping := true
+ for keepLooping {
+ var events []qwenEvent
+ events, keepLooping = eat(p)
+ if len(events) > 0 {
+ all = append(all, events...)
+ }
+ }
+
+ if len(all) > 0 {
+ slog.Log(context.TODO(), logutil.LevelTrace, "qwen events parsed", "events", all, "state", p.state, "acc", p.acc.String())
+ }
+
+ return all
+}
+
+// we use some internal event types in order to communicate between `Add` and
+// `eat`. We do this to support interleaving content and parallel tool calls in
+// the parser, even though qwen3-coder isn't supposed to do this. Our API
+// doesn't currently support models outputting multiple messages in a turn, so
+// we wouldn't be able to represent it yet, but there's no reason to prevent the
+// parser from supporting it, especially for future models if they end up using
+// a similar format.
+type qwenEvent interface {
+ isQwenEvent()
+}
+
+type qwenEventRawToolCall struct {
+ raw string
+}
+
+type qwenEventContent struct {
+ content string
+}
+
+func (qwenEventContent) isQwenEvent() {}
+func (qwenEventRawToolCall) isQwenEvent() {}
+
+// eat consumes the parser's buffer, and returns a list of any unambiguous
+// events from the current parser state. If the parser transitions to another
+// state, it may have additional events to emit on the next call, which is what
+// the second return value indicates
+func eat(p *Qwen3CoderParser) ([]qwenEvent, bool) {
+ var events []qwenEvent
+
+ switch p.state {
+ case qwenParserState_LookingForToolStart:
+ if strings.Contains(p.acc.String(), toolOpenTag) {
+ // we found a full tool open tag, so we can emit the content before the
+ // tag, being sure to trim any trailing whitespace
+ split := strings.SplitN(p.acc.String(), toolOpenTag, 2)
+ before := split[0]
+ before = strings.TrimRightFunc(before, unicode.IsSpace)
+ if len(before) > 0 {
+ events = append(events, qwenEventContent{content: before})
+ }
+ after := split[1]
+ p.acc.Reset()
+ p.acc.WriteString(after)
+ p.state = qwenParserState_CollectingToolContent
+ return events, true
+ } else if overlap := overlap(p.acc.String(), toolOpenTag); overlap > 0 {
+ // we found a partial tool open tag, so we can emit the unambiguous part,
+ // which is the (trailing-whitespace trimmed) content before the partial
+ // tool open tag
+ beforePartialTag := p.acc.String()[:len(p.acc.String())-overlap]
+ trailingWhitespaceLen := trailingWhitespaceLen(beforePartialTag)
+ ambiguousStart := len(beforePartialTag) - trailingWhitespaceLen
+ unambiguous := p.acc.String()[:ambiguousStart]
+ ambiguous := p.acc.String()[ambiguousStart:]
+ p.acc.Reset()
+ p.acc.WriteString(ambiguous)
+ events = append(events, qwenEventContent{content: unambiguous})
+ return events, false
+ } else {
+ // we found content that is entirely not a tool call. We should withhold
+ // any trailing whitespace in case this is the end of the content
+ whitespaceLen := trailingWhitespaceLen(p.acc.String())
+ ambiguousStart := len(p.acc.String()) - whitespaceLen
+ unambiguous := p.acc.String()[:ambiguousStart]
+ ambiguous := p.acc.String()[ambiguousStart:]
+ p.acc.Reset()
+ p.acc.WriteString(ambiguous)
+ if len(unambiguous) > 0 {
+ events = append(events, qwenEventContent{content: unambiguous})
+ }
+ return events, false
+ }
+ case qwenParserState_CollectingToolContent:
+ if strings.Contains(p.acc.String(), toolCloseTag) {
+ split := strings.SplitN(p.acc.String(), toolCloseTag, 2)
+ before := split[0]
+ if len(before) == 0 {
+ slog.Warn("qwen tool call closing tag found but no content before it")
+ }
+ // remove any whitespace between the tool call and any content after it
+ after := strings.TrimLeftFunc(split[1], unicode.IsSpace)
+ p.acc.Reset()
+ p.acc.WriteString(after)
+ events = append(events, qwenEventRawToolCall{raw: before})
+ p.state = qwenParserState_LookingForToolStart
+ return events, true
+ } else {
+ // note that we don't need to check the overlap here because we only plan
+ // on parsing the tool call once we see the full closing tag. We don't
+ // stream back the unparsed tool content, so there's no need to be eager
+ // here
+ return events, false
+ }
+ default:
+ panic("unreachable")
+ }
+}
+
+// TODO(drifkin): move this to a shared location
+// longest overlap between suffix of s and prefix of delim
+func overlap(s, delim string) int {
+ max := min(len(delim), len(s))
+ for i := max; i > 0; i-- {
+ if strings.HasSuffix(s, delim[:i]) {
+ return i
+ }
+ }
+ return 0
+}
+
+func trailingWhitespaceLen(s string) int {
+ for i := len(s) - 1; i >= 0; i-- {
+ if !unicode.IsSpace(rune(s[i])) {
+ return len(s) - i - 1
+ }
+ }
+ return len(s)
+}
+
+type XMLFunctionCall struct {
+ XMLName xml.Name `xml:"function"`
+ Name string `xml:"name,attr"`
+ Parameters []XMLParameter `xml:"parameter"`
+}
+
+type XMLParameter struct {
+ Name string `xml:"name,attr"`
+ Value string `xml:",chardata"`
+}
+
+// parseToolCall parses a raw tool call string into an api.ToolCall.
+// The raw string follows an xml-like format, here's an example:
+//
+//
+//
+// San Francisco
+//
+//
+// celsius
+//
+//
+func parseToolCall(raw qwenEventRawToolCall, tools []api.Tool) (api.ToolCall, error) {
+ toolCall := api.ToolCall{}
+
+ xmlString := transformToXML(raw.raw)
+
+ var functionCall XMLFunctionCall
+ err := xml.Unmarshal([]byte(xmlString), &functionCall)
+ if err != nil {
+ return api.ToolCall{}, err
+ }
+
+ toolCall.Function = api.ToolCallFunction{
+ Name: functionCall.Name,
+ }
+
+ // Find the matching tool to get parameter types
+ var matchedTool *api.Tool
+ for i := range tools {
+ if tools[i].Function.Name == functionCall.Name {
+ matchedTool = &tools[i]
+ break
+ }
+ }
+
+ toolCall.Function.Arguments = make(api.ToolCallFunctionArguments)
+ for _, parameter := range functionCall.Parameters {
+ // Look up the parameter type if we found the tool
+ var paramType api.PropertyType
+ if matchedTool != nil && matchedTool.Function.Parameters.Properties != nil {
+ if prop, ok := matchedTool.Function.Parameters.Properties[parameter.Name]; ok {
+ paramType = prop.Type
+ }
+ }
+
+ toolCall.Function.Arguments[parameter.Name] = parseValue(parameter.Value, paramType)
+ }
+
+ return toolCall, nil
+}
+
+// parseValue converts a raw string value to the appropriate type based on the parameter type specification.
+//
+// For union types (multiple types in PropertyType, which we support but doesn't
+// seem as though the reference parser does type coercion with those types in
+// mind) we use a type precedence approach:
+// 1. null - checked first regardless of declared types (matches reference implementation)
+// 2. boolean - only "true"/"false" are valid booleans
+// 3. integer - must parse as a whole number
+// 4. number - must parse as numeric (returns int if no decimal part)
+// 5. array - must parse as valid JSON array
+// 6. object - must parse as valid JSON object
+// 7. string - always succeeds (least specific type)
+//
+// This precedence ensures we return the most specific type that successfully parses,
+// following the principle of least surprise. For example, with PropertyType{"string", "number"},
+// "123" becomes 123 (number), while "hello" becomes "hello" (string).
+func parseValue(raw string, paramType api.PropertyType) any {
+ // first remove a single leading newlines, and a single trailing newline (if
+ // they exist). This follows the reference implementation
+ raw = strings.TrimPrefix(raw, "\n")
+ raw = strings.TrimSuffix(raw, "\n")
+
+ // Check for null first (case-insensitive) - this takes precedence over any type
+ if strings.ToLower(raw) == "null" {
+ return nil
+ }
+
+ // If no type is specified, default to string
+ if len(paramType) == 0 {
+ return raw
+ }
+
+ // Check if any of the specified types match, using type precedence
+ // Order: boolean -> integer -> number -> array -> object -> string
+ typeSet := make(map[string]bool)
+ for _, t := range paramType {
+ typeSet[t] = true
+ }
+
+ // Try boolean first (most restrictive)
+ if typeSet["boolean"] {
+ lower := strings.ToLower(raw)
+ switch lower {
+ case "true":
+ return true
+ case "false":
+ return false
+ }
+ // If not a valid boolean but boolean is the only type, return false (matching reference)
+ if len(paramType) == 1 {
+ return false
+ }
+ // Otherwise try other types
+ }
+
+ // Try integer
+ if typeSet["integer"] {
+ if i, err := strconv.ParseInt(raw, 10, 64); err == nil {
+ // Return as int if it fits in int32, otherwise int64
+ if i >= math.MinInt32 && i <= math.MaxInt32 {
+ return int(i)
+ }
+ return i
+ }
+ // If integer is the only type and parsing failed, fall back to string
+ if len(paramType) == 1 {
+ return raw
+ }
+ }
+
+ // Try number (float)
+ if typeSet["number"] {
+ if f, err := strconv.ParseFloat(raw, 64); err == nil {
+ // If the number has no decimal part, return as int (matching reference)
+ if f == math.Trunc(f) {
+ i := int64(f)
+ if i >= math.MinInt32 && i <= math.MaxInt32 {
+ return int(i)
+ }
+ return i
+ }
+ return f
+ }
+ // If number is the only type and parsing failed, fall back to string
+ if len(paramType) == 1 {
+ return raw
+ }
+ }
+
+ // Try array
+ if typeSet["array"] {
+ var arr []interface{}
+ if err := json.Unmarshal([]byte(raw), &arr); err == nil {
+ return arr
+ }
+ // If array is the only type and parsing failed, fall back to string
+ if len(paramType) == 1 {
+ return raw
+ }
+ }
+
+ // Try object
+ if typeSet["object"] {
+ var obj map[string]interface{}
+ if err := json.Unmarshal([]byte(raw), &obj); err == nil {
+ return obj
+ }
+ // If object is the only type and parsing failed, fall back to string
+ if len(paramType) == 1 {
+ return raw
+ }
+ }
+
+ // String always succeeds (or if "string" is in the type set)
+ if typeSet["string"] {
+ return raw
+ }
+
+ // If we get here, none of the types matched and string wasn't an option
+ // We return string as a fallback. The reference implementation will attempt
+ // to parse the value as a python literal, but we purposefully don't support
+ // that
+ return raw
+}
+
+var qwenTagRegex = regexp.MustCompile(`<(\w+)=([^>]+)>`)
+
+// transformToXML transforms a raw qwen tool call with xml-like tags into valid
+// xml so that it can be parsed by any xml parser
+func transformToXML(raw string) string {
+ // take the form `` and transform it to ``, taking
+ // care to properly escape the string that becomes the attribute value
+ return qwenTagRegex.ReplaceAllStringFunc(raw, func(match string) string {
+ groups := qwenTagRegex.FindStringSubmatch(match)
+ tag := groups[1]
+ var escapedValue strings.Builder
+ xml.EscapeText(&escapedValue, []byte(groups[2]))
+ return fmt.Sprintf(`<%s name="%s">`, tag, escapedValue.String())
+ })
+}
diff --git a/model/parsers/qwen3coder_test.go b/model/parsers/qwen3coder_test.go
new file mode 100644
index 000000000..2389c77b5
--- /dev/null
+++ b/model/parsers/qwen3coder_test.go
@@ -0,0 +1,830 @@
+package parsers
+
+import (
+ "reflect"
+ "testing"
+
+ "github.com/ollama/ollama/api"
+)
+
+// tool creates a test tool with the given name and properties
+func tool(name string, props map[string]api.ToolProperty) api.Tool {
+ t := api.Tool{Type: "function", Function: api.ToolFunction{Name: name}}
+ t.Function.Parameters.Type = "object"
+ t.Function.Parameters.Properties = props
+ return t
+}
+
+func TestQwenParserStreaming(t *testing.T) {
+ type step struct {
+ input string
+ wantEvents []qwenEvent
+ }
+
+ cases := []struct {
+ desc string
+ steps []step
+ only bool
+ }{
+ {
+ desc: "simple message streamed word by word",
+ steps: []step{
+ {
+ input: "hi",
+ wantEvents: []qwenEvent{qwenEventContent{content: "hi"}},
+ },
+ {
+ input: " there",
+ wantEvents: []qwenEvent{qwenEventContent{content: " there"}},
+ },
+ },
+ },
+ {
+ desc: "content before tool call",
+ steps: []step{
+ {
+ input: "hi there",
+ wantEvents: []qwenEvent{qwenEventContent{content: "hi there"}},
+ },
+ },
+ },
+ {
+ desc: "multiple tool calls in one message",
+ steps: []step{
+ {
+ input: "before1in tool callafter1in tool call 2after2",
+ wantEvents: []qwenEvent{
+ qwenEventContent{content: "before1"},
+ qwenEventRawToolCall{raw: "in tool call"},
+ qwenEventContent{content: "after1"},
+ qwenEventRawToolCall{raw: "in tool call 2"},
+ qwenEventContent{content: "after2"},
+ },
+ },
+ },
+ },
+ {
+ desc: "tool calls with split tags",
+ steps: []step{
+ {
+ input: "beforein tool callaf",
+ wantEvents: []qwenEvent{
+ qwenEventRawToolCall{raw: "in tool call"},
+ qwenEventContent{content: "af"},
+ },
+ },
+ {
+ input: "ter",
+ wantEvents: []qwenEvent{
+ qwenEventContent{content: "ter"},
+ },
+ },
+ },
+ },
+ {
+ desc: "trailing whitespace between content and tool call",
+ steps: []step{
+ {
+ input: "abc\ndef",
+ wantEvents: []qwenEvent{
+ qwenEventContent{content: "abc"},
+ qwenEventRawToolCall{raw: "def"},
+ },
+ },
+ },
+ },
+ {
+ desc: "trailing whitespace between tool call and content",
+ steps: []step{
+ {
+ input: "abc\ndef",
+ wantEvents: []qwenEvent{
+ qwenEventRawToolCall{raw: "abc"},
+ qwenEventContent{content: "def"},
+ },
+ },
+ },
+ },
+ {
+ desc: "empty content before tool call",
+ steps: []step{
+ {
+ input: "\nabc",
+ wantEvents: []qwenEvent{
+ qwenEventRawToolCall{raw: "abc"},
+ },
+ },
+ },
+ },
+ {
+ desc: "partial tool open tag fakeout",
+ steps: []step{
+ {
+ input: "abc\n
+
+San Francisco
+
+
+celsius
+
+`,
+ wantToolCall: api.ToolCall{
+ Function: api.ToolCallFunction{
+ Name: "get_current_temperature",
+ Arguments: map[string]any{
+ "location": "San Francisco",
+ "unit": "celsius",
+ },
+ },
+ },
+ },
+ {
+ name: "names with spaces",
+ tools: []api.Tool{},
+ rawToolCall: `
+
+San Francisco
+
+
+celsius
+
+`,
+ wantToolCall: api.ToolCall{
+ Function: api.ToolCallFunction{
+ Name: "get current temperature",
+ Arguments: map[string]any{
+ "location with spaces": "San Francisco",
+ "unit with spaces": "celsius",
+ },
+ },
+ },
+ },
+ // this mirrors the reference implementation's behavior, but unclear if it
+ // ever happens. If so, then we should probably remove them instead, this
+ // test is to just document the current behavior and test that we don't get
+ // xml errors
+ {
+ name: "names with quotes",
+ tools: []api.Tool{},
+ rawToolCall: `
+
+San Francisco
+
+
+"celsius"
+
+`,
+ wantToolCall: api.ToolCall{
+ Function: api.ToolCallFunction{
+ Name: "\"get current temperature\"",
+ Arguments: map[string]any{
+ "\"location with spaces\"": "San Francisco",
+ "\"unit with spaces\"": "\"celsius\"",
+ },
+ },
+ },
+ },
+ {
+ name: "tool call with typed parameters",
+ tools: []api.Tool{
+ tool("calculate", map[string]api.ToolProperty{
+ "x": {Type: api.PropertyType{"number"}},
+ "y": {Type: api.PropertyType{"integer"}},
+ "enabled": {Type: api.PropertyType{"boolean"}},
+ "items": {Type: api.PropertyType{"array"}},
+ }),
+ },
+ rawToolCall: `
+
+3.14
+
+
+42
+
+
+true
+
+
+["a", "b", "c"]
+
+`,
+ wantToolCall: api.ToolCall{
+ Function: api.ToolCallFunction{
+ Name: "calculate",
+ Arguments: map[string]any{
+ "x": 3.14,
+ "y": 42,
+ "enabled": true,
+ "items": []any{"a", "b", "c"},
+ },
+ },
+ },
+ },
+ }
+
+ for i, step := range steps {
+ gotToolCall, err := parseToolCall(qwenEventRawToolCall{raw: step.rawToolCall}, step.tools)
+ if err != nil {
+ t.Errorf("step %d (%s): %v", i, step.name, err)
+ }
+ if !reflect.DeepEqual(gotToolCall, step.wantToolCall) {
+ t.Errorf("step %d (%s): got tool call %#v, want %#v", i, step.name, gotToolCall, step.wantToolCall)
+ }
+ }
+}
+
+func TestQwenToolCallValueParsing(t *testing.T) {
+ cases := []struct {
+ desc string
+ raw string
+ paramType api.PropertyType
+ want any
+ }{
+ {
+ desc: "default string value (no type specified)",
+ paramType: api.PropertyType{},
+ raw: "some-string",
+ want: "some-string",
+ },
+ {
+ desc: "trim a single leading and trailing newline",
+ paramType: api.PropertyType{},
+ raw: "\nsome-string\n",
+ want: "some-string",
+ },
+ {
+ desc: "trim at most one leading and trailing newline",
+ paramType: api.PropertyType{},
+ raw: "\n\nsome-string\n\n",
+ want: "\nsome-string\n",
+ },
+ {
+ desc: "newline really has to be the first character to be trimmed",
+ paramType: api.PropertyType{},
+ raw: " \nsome-string\n ",
+ want: " \nsome-string\n ",
+ },
+ {
+ desc: "numeric type",
+ paramType: api.PropertyType{"number"},
+ raw: "123",
+ want: 123,
+ },
+ // Integer parsing tests
+ {
+ desc: "integer type",
+ paramType: api.PropertyType{"integer"},
+ raw: "42",
+ want: 42,
+ },
+ {
+ desc: "negative integer",
+ paramType: api.PropertyType{"integer"},
+ raw: "-100",
+ want: -100,
+ },
+ {
+ desc: "zero integer",
+ paramType: api.PropertyType{"integer"},
+ raw: "0",
+ want: 0,
+ },
+ {
+ desc: "integer with leading zeros",
+ paramType: api.PropertyType{"integer"},
+ raw: "007",
+ want: 7,
+ },
+ {
+ desc: "large integer",
+ paramType: api.PropertyType{"integer"},
+ raw: "2147483648", // Just beyond int32 max
+ want: int64(2147483648),
+ },
+ // Float/number parsing tests
+ {
+ desc: "float type",
+ paramType: api.PropertyType{"number"},
+ raw: "3.14",
+ want: 3.14,
+ },
+ {
+ desc: "negative float",
+ paramType: api.PropertyType{"number"},
+ raw: "-273.15",
+ want: -273.15,
+ },
+ {
+ desc: "float without decimal part",
+ paramType: api.PropertyType{"number"},
+ raw: "100.0",
+ want: 100,
+ },
+ {
+ desc: "scientific notation positive",
+ paramType: api.PropertyType{"number"},
+ raw: "1.23e5",
+ want: 123000, // Will be int since it has no decimal part
+ },
+ {
+ desc: "scientific notation negative",
+ paramType: api.PropertyType{"number"},
+ raw: "1.5e-3",
+ want: 0.0015,
+ },
+ {
+ desc: "very small float",
+ paramType: api.PropertyType{"number"},
+ raw: "0.00000001",
+ want: 0.00000001,
+ },
+ // String parsing tests
+ {
+ desc: "explicit string type",
+ paramType: api.PropertyType{"string"},
+ raw: "hello world",
+ want: "hello world",
+ },
+ {
+ desc: "string with special characters",
+ paramType: api.PropertyType{"string"},
+ raw: "/usr/local/bin/test-file_v2.0.sh",
+ want: "/usr/local/bin/test-file_v2.0.sh",
+ },
+ {
+ desc: "string with quotes",
+ paramType: api.PropertyType{"string"},
+ raw: `He said "hello" to me`,
+ want: `He said "hello" to me`,
+ },
+ {
+ desc: "multiline string",
+ paramType: api.PropertyType{"string"},
+ raw: "line one\nline two\nline three",
+ want: "line one\nline two\nline three",
+ },
+ {
+ desc: "empty string",
+ paramType: api.PropertyType{"string"},
+ raw: "",
+ want: "",
+ },
+ {
+ desc: "string that looks like a number",
+ paramType: api.PropertyType{"string"},
+ raw: "12345",
+ want: "12345",
+ },
+ // Boolean parsing tests
+ {
+ desc: "boolean true",
+ paramType: api.PropertyType{"boolean"},
+ raw: "true",
+ want: true,
+ },
+ {
+ desc: "boolean false",
+ paramType: api.PropertyType{"boolean"},
+ raw: "false",
+ want: false,
+ },
+ {
+ desc: "boolean case insensitive true",
+ paramType: api.PropertyType{"boolean"},
+ raw: "True",
+ want: true,
+ },
+ {
+ desc: "boolean case insensitive false",
+ paramType: api.PropertyType{"boolean"},
+ raw: "FALSE",
+ want: false,
+ },
+ // Null parsing tests
+ {
+ desc: "null value lowercase",
+ paramType: api.PropertyType{"string"},
+ raw: "null",
+ want: nil,
+ },
+ {
+ desc: "null value case insensitive",
+ paramType: api.PropertyType{"integer"},
+ raw: "NULL",
+ want: nil,
+ },
+ // Array parsing tests
+ {
+ desc: "array of strings",
+ paramType: api.PropertyType{"array"},
+ raw: `["foo", "bar", "baz"]`,
+ want: []any{"foo", "bar", "baz"},
+ },
+ {
+ desc: "array of numbers",
+ paramType: api.PropertyType{"array"},
+ raw: `[1, 2.5, 3]`,
+ want: []any{float64(1), 2.5, float64(3)},
+ },
+ {
+ desc: "array of mixed types",
+ paramType: api.PropertyType{"array"},
+ raw: `["string", 123, true, null]`,
+ want: []any{"string", float64(123), true, nil},
+ },
+ {
+ desc: "empty array",
+ paramType: api.PropertyType{"array"},
+ raw: `[]`,
+ want: []any{},
+ },
+ // Object parsing tests
+ {
+ desc: "simple object",
+ paramType: api.PropertyType{"object"},
+ raw: `{"key": "value", "number": 42}`,
+ want: map[string]any{"key": "value", "number": float64(42)},
+ },
+ {
+ desc: "nested object",
+ paramType: api.PropertyType{"object"},
+ raw: `{"outer": {"inner": "value"}}`,
+ want: map[string]any{"outer": map[string]any{"inner": "value"}},
+ },
+ {
+ desc: "empty object",
+ paramType: api.PropertyType{"object"},
+ raw: `{}`,
+ want: map[string]any{},
+ },
+ // Error cases and fallback behavior
+ {
+ desc: "invalid integer falls back to string",
+ paramType: api.PropertyType{"integer"},
+ raw: "not-a-number",
+ want: "not-a-number",
+ },
+ {
+ desc: "invalid float falls back to string",
+ paramType: api.PropertyType{"number"},
+ raw: "3.14.159",
+ want: "3.14.159",
+ },
+ {
+ desc: "invalid boolean falls back to false",
+ paramType: api.PropertyType{"boolean"},
+ raw: "yes",
+ want: false,
+ },
+ {
+ desc: "invalid JSON array falls back to string",
+ paramType: api.PropertyType{"array"},
+ raw: "[1, 2, unclosed",
+ want: "[1, 2, unclosed",
+ },
+ {
+ desc: "invalid JSON object falls back to string",
+ paramType: api.PropertyType{"object"},
+ raw: `{"key": unclosed`,
+ want: `{"key": unclosed`,
+ },
+ // Edge cases
+ {
+ desc: "integer overflow should use int64",
+ paramType: api.PropertyType{"integer"},
+ raw: "2147483648", // Beyond int32 max
+ want: int64(2147483648),
+ },
+ {
+ desc: "float with many decimal places",
+ paramType: api.PropertyType{"number"},
+ raw: "3.141592653589793",
+ want: 3.141592653589793,
+ },
+ {
+ desc: "string with JSON-like content",
+ paramType: api.PropertyType{"string"},
+ raw: `{"this": "is", "just": "a string"}`,
+ want: `{"this": "is", "just": "a string"}`,
+ },
+ {
+ desc: "whitespace-only string",
+ paramType: api.PropertyType{"string"},
+ raw: " ",
+ want: " ",
+ },
+ // Unknown parameter (no type specified in tools)
+ {
+ desc: "parameter not in tool definition defaults to string",
+ paramType: api.PropertyType{},
+ raw: "some value",
+ want: "some value",
+ },
+ // Union type tests
+ {
+ desc: "string or number union - valid number",
+ paramType: api.PropertyType{"string", "number"},
+ raw: "42.5",
+ want: 42.5,
+ },
+ {
+ desc: "string or number union - non-numeric string",
+ paramType: api.PropertyType{"string", "number"},
+ raw: "hello",
+ want: "hello",
+ },
+ {
+ desc: "number or string union - valid number (order shouldn't matter)",
+ paramType: api.PropertyType{"number", "string"},
+ raw: "42.5",
+ want: 42.5,
+ },
+ {
+ desc: "integer or null union - valid integer",
+ paramType: api.PropertyType{"integer", "null"},
+ raw: "123",
+ want: 123,
+ },
+ {
+ desc: "integer or null union - null value",
+ paramType: api.PropertyType{"integer", "null"},
+ raw: "null",
+ want: nil,
+ },
+ {
+ desc: "null or integer union - null value (order shouldn't matter)",
+ paramType: api.PropertyType{"null", "integer"},
+ raw: "null",
+ want: nil,
+ },
+ {
+ desc: "boolean or string union - valid boolean",
+ paramType: api.PropertyType{"boolean", "string"},
+ raw: "true",
+ want: true,
+ },
+ {
+ desc: "boolean or string union - non-boolean becomes string",
+ paramType: api.PropertyType{"boolean", "string"},
+ raw: "yes",
+ want: "yes",
+ },
+ {
+ desc: "string or boolean union - valid boolean (precedence test)",
+ paramType: api.PropertyType{"string", "boolean"},
+ raw: "false",
+ want: false, // Should be boolean, not string "false"
+ },
+ {
+ desc: "integer or number union - integer value",
+ paramType: api.PropertyType{"integer", "number"},
+ raw: "42",
+ want: 42,
+ },
+ {
+ desc: "integer or number union - float value",
+ paramType: api.PropertyType{"integer", "number"},
+ raw: "42.5",
+ want: 42.5,
+ },
+ {
+ desc: "number or integer union - integer value (precedence test)",
+ paramType: api.PropertyType{"number", "integer"},
+ raw: "42",
+ want: 42, // Should try integer first due to precedence
+ },
+ {
+ desc: "array or object union - valid array",
+ paramType: api.PropertyType{"array", "object"},
+ raw: `[1, 2, 3]`,
+ want: []any{float64(1), float64(2), float64(3)},
+ },
+ {
+ desc: "array or object union - valid object",
+ paramType: api.PropertyType{"array", "object"},
+ raw: `{"key": "value"}`,
+ want: map[string]any{"key": "value"},
+ },
+ {
+ desc: "object or array union - valid array (precedence test)",
+ paramType: api.PropertyType{"object", "array"},
+ raw: `[1, 2, 3]`,
+ want: []any{float64(1), float64(2), float64(3)},
+ },
+ {
+ desc: "complex multi-type union - null",
+ paramType: api.PropertyType{"string", "number", "boolean", "null"},
+ raw: "null",
+ want: nil,
+ },
+ {
+ desc: "complex multi-type union - boolean",
+ paramType: api.PropertyType{"string", "number", "boolean", "null"},
+ raw: "true",
+ want: true,
+ },
+ {
+ desc: "complex multi-type union - number",
+ paramType: api.PropertyType{"string", "number", "boolean", "null"},
+ raw: "3.14",
+ want: 3.14,
+ },
+ {
+ desc: "complex multi-type union - string",
+ paramType: api.PropertyType{"string", "number", "boolean", "null"},
+ raw: "hello",
+ want: "hello",
+ },
+ {
+ desc: "integer string union - integer string becomes integer",
+ paramType: api.PropertyType{"integer", "string"},
+ raw: "123",
+ want: 123,
+ },
+ {
+ desc: "string integer union - integer string becomes integer (precedence)",
+ paramType: api.PropertyType{"string", "integer"},
+ raw: "123",
+ want: 123, // Integer has higher precedence than string
+ },
+ }
+
+ for _, tc := range cases {
+ t.Run(tc.desc, func(t *testing.T) {
+ got := parseValue(tc.raw, tc.paramType)
+ if !reflect.DeepEqual(got, tc.want) {
+ t.Errorf("got %v (type %T), want %v (type %T)", got, got, tc.want, tc.want)
+ }
+ })
+ }
+}
+
+func TestQwenXMLTransform(t *testing.T) {
+ cases := []struct {
+ desc string
+ raw string
+ want string
+ }{
+ {
+ desc: "simple example",
+ raw: `
+
+San Francisco
+
+
+celsius
+
+`,
+ want: `
+
+San Francisco
+
+
+celsius
+
+`,
+ },
+ // even though quotes aren't expected in these tags, we have these tests to
+ // make sure they're escaped so they don't blow up the xml parser in case
+ // they happen
+ {
+ desc: "names with quotes",
+ raw: `
+
+San Francisco
+
+
+celsius
+
+`,
+ want: `
+
+San Francisco
+
+
+celsius
+
+`,
+ },
+ }
+
+ for _, tc := range cases {
+ got := transformToXML(tc.raw)
+ if got != tc.want {
+ t.Errorf("got %q, want %q", got, tc.want)
+ }
+ }
+}
+
+func TestTrailingWhitespaceLen(t *testing.T) {
+ cases := []struct {
+ desc string
+ s string
+ want int
+ }{
+ {desc: "no whitespace", s: "abc", want: 0},
+ {desc: "trailing whitespace", s: "abc ", want: 1},
+ {desc: "trailing whitespace with newlines", s: "abc \n", want: 2},
+ {desc: "only whitespace", s: " \n ", want: 4},
+ {desc: "leading whitespace doesn't count", s: " \n abc", want: 0},
+ }
+
+ for _, tc := range cases {
+ got := trailingWhitespaceLen(tc.s)
+ if got != tc.want {
+ t.Errorf("got %d, want %d", got, tc.want)
+ }
+ }
+}
diff --git a/model/renderers/qwen3coder.go b/model/renderers/qwen3coder.go
new file mode 100644
index 000000000..df3b3a45b
--- /dev/null
+++ b/model/renderers/qwen3coder.go
@@ -0,0 +1,217 @@
+package renderers
+
+import (
+ "encoding/json"
+ "fmt"
+ "reflect"
+ "strings"
+
+ "github.com/ollama/ollama/api"
+)
+
+var (
+ imStartTag = "<|im_start|>"
+ imEndTag = "<|im_end|>"
+)
+
+// renderAdditionalKeys renders all JSON fields except the ones in handledKeys
+// This follows the same approach from the reference implementation, which gives
+// a particular key ordering
+func renderAdditionalKeys(obj any, handledKeys map[string]bool) string {
+ data, err := json.Marshal(obj)
+ if err != nil {
+ return ""
+ }
+
+ var m map[string]any
+ if err := json.Unmarshal(data, &m); err != nil {
+ return ""
+ }
+
+ var sb strings.Builder
+ for key, value := range m {
+ if handledKeys[key] {
+ continue
+ }
+
+ // Check if value is a map or array (needs JSON serialization)
+ switch v := value.(type) {
+ case map[string]any, []any:
+ jsonBytes, _ := json.Marshal(v)
+ // TODO(drifkin): it would be nice to format the JSON here similarly to
+ // python's default json.dumps behavior (spaces after commas and colons).
+ // This would let us be byte-for-byte compatible with the reference
+ // implementation for most common inputs
+ jsonStr := string(jsonBytes)
+ sb.WriteString("\n<" + key + ">" + jsonStr + "" + key + ">")
+ case nil:
+ continue
+ default:
+ // Simple types, convert to string
+ sb.WriteString("\n<" + key + ">" + fmt.Sprintf("%v", value) + "" + key + ">")
+ }
+ }
+
+ return sb.String()
+}
+
+func Qwen3CoderRenderer(messages []api.Message, tools []api.Tool, _ *api.ThinkValue) (string, error) {
+ var sb strings.Builder
+
+ // filter out system messages and choose the first (if any) to win
+ var systemMessage string
+ var filteredMessages []api.Message
+ for _, message := range messages {
+ if message.Role != "system" {
+ filteredMessages = append(filteredMessages, message)
+ continue
+ }
+
+ if systemMessage == "" {
+ systemMessage = message.Content
+ }
+ }
+
+ if systemMessage != "" || len(tools) > 0 {
+ sb.WriteString(imStartTag + "system\n")
+
+ // if we have tools but no system message, match the reference implementation by providing a default system message
+ if systemMessage == "" {
+ systemMessage = "You are Qwen, a helpful AI assistant that can interact with a computer to solve tasks."
+ }
+
+ sb.WriteString(systemMessage)
+
+ if len(tools) > 0 {
+ sb.WriteString("\n\n# Tools\n\nYou have access to the following functions:\n\n")
+ sb.WriteString("")
+ for _, tool := range tools {
+ sb.WriteString("\n")
+ sb.WriteString("\n")
+ sb.WriteString("" + tool.Function.Name + "")
+ if tool.Function.Description != "" {
+ sb.WriteString("\n" + tool.Function.Description + "")
+ }
+ sb.WriteString("\n")
+
+ for name, prop := range tool.Function.Parameters.Properties {
+ sb.WriteString("\n")
+ sb.WriteString("\n" + name + "")
+
+ if len(prop.Type) > 0 {
+ // TODO(!!!)(drifkin): we should match the reference implementation for
+ // more complex types here instead of using this format
+ sb.WriteString("\n" + prop.ToTypeScriptType() + "")
+ }
+
+ if prop.Description != "" {
+ sb.WriteString("\n" + prop.Description + "")
+ }
+
+ // Render any additional keys not already handled
+ handledKeys := map[string]bool{
+ "type": true,
+ "description": true,
+ }
+ sb.WriteString(renderAdditionalKeys(prop, handledKeys))
+
+ sb.WriteString("\n")
+ }
+
+ // Render extra keys for parameters (everything except 'type' and 'properties')
+ paramHandledKeys := map[string]bool{
+ "type": true,
+ "properties": true,
+ }
+ sb.WriteString(renderAdditionalKeys(tool.Function.Parameters, paramHandledKeys))
+
+ sb.WriteString("\n")
+ sb.WriteString("\n")
+ }
+ sb.WriteString("\n")
+ sb.WriteString("\n\nIf you choose to call a function ONLY reply in the following format with NO suffix:\n\n\n\n\nvalue_1\n\n\nThis is the value for the second parameter\nthat can span\nmultiple lines\n\n\n\n\n\nReminder:\n- Function calls MUST follow the specified format: an inner block must be nested within XML tags\n- Required parameters MUST be specified\n- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after\n- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls\n")
+ }
+
+ sb.WriteString(imEndTag + "\n")
+ }
+
+ for i, message := range filteredMessages {
+ lastMessage := i == len(filteredMessages)-1
+ prefill := lastMessage && message.Role == "assistant"
+ switch message.Role {
+ case "assistant":
+ if len(message.ToolCalls) > 0 {
+ sb.WriteString(imStartTag + "assistant\n")
+ if message.Content != "" {
+ sb.WriteString(message.Content + "\n")
+ }
+ for _, toolCall := range message.ToolCalls {
+ sb.WriteString("\n\n")
+ for name, value := range toolCall.Function.Arguments {
+ valueStr := formatToolCallArgument(value)
+ sb.WriteString("\n\n" + valueStr + "\n")
+ }
+ sb.WriteString("\n\n")
+ }
+ sb.WriteString("<|im_end|>\n")
+ } else {
+ sb.WriteString(imStartTag + "assistant\n")
+ sb.WriteString(message.Content)
+ if !prefill {
+ sb.WriteString(imEndTag + "\n")
+ }
+ }
+ case "tool":
+ // consecutive tool responses should share a single `user`, but
+ // have their own tags
+
+ // only start a new user block if this is the first tool response
+ if i == 0 || filteredMessages[i-1].Role != "tool" {
+ sb.WriteString(imStartTag + "user\n")
+ }
+
+ sb.WriteString("\n")
+ sb.WriteString(message.Content)
+ sb.WriteString("\n\n")
+
+ // close the user block only if this is the last tool response
+ if i == len(filteredMessages)-1 || filteredMessages[i+1].Role != "tool" {
+ sb.WriteString(imEndTag + "\n")
+ }
+ default:
+ sb.WriteString(imStartTag + message.Role + "\n")
+ sb.WriteString(message.Content)
+ sb.WriteString(imEndTag + "\n")
+ }
+
+ if lastMessage && !prefill {
+ sb.WriteString(imStartTag + "assistant\n")
+ }
+ }
+
+ return sb.String(), nil
+}
+
+func formatToolCallArgument(value any) string {
+ if value == nil {
+ return "null"
+ }
+
+ switch v := value.(type) {
+ case string:
+ return v
+ case []byte:
+ return string(v)
+ }
+
+ if reflect.TypeOf(value) != nil {
+ kind := reflect.TypeOf(value).Kind()
+ if kind == reflect.Map || kind == reflect.Slice || kind == reflect.Array {
+ if marshalled, err := json.Marshal(value); err == nil {
+ return string(marshalled)
+ }
+ }
+ }
+
+ return fmt.Sprintf("%v", value)
+}
diff --git a/model/renderers/qwen3coder_test.go b/model/renderers/qwen3coder_test.go
new file mode 100644
index 000000000..4aaa066d6
--- /dev/null
+++ b/model/renderers/qwen3coder_test.go
@@ -0,0 +1,338 @@
+package renderers
+
+import (
+ "testing"
+
+ "github.com/google/go-cmp/cmp"
+ "github.com/ollama/ollama/api"
+)
+
+func TestQwen3CoderRenderer(t *testing.T) {
+ tests := []struct {
+ name string
+ msgs []api.Message
+ tools []api.Tool
+ expected string
+ }{
+ {
+ name: "basic",
+ msgs: []api.Message{
+ {Role: "system", Content: "You are a helpful assistant."},
+ {Role: "user", Content: "Hello, how are you?"},
+ },
+ expected: `<|im_start|>system
+You are a helpful assistant.<|im_end|>
+<|im_start|>user
+Hello, how are you?<|im_end|>
+<|im_start|>assistant
+`,
+ },
+ {
+ name: "with tools and response",
+ msgs: []api.Message{
+ {Role: "system", Content: "You are a helpful assistant with access to tools."},
+ {Role: "user", Content: "What is the weather like in San Francisco?"},
+ {
+ Role: "assistant",
+ Content: "I'll check the weather in San Francisco for you.",
+ ToolCalls: []api.ToolCall{
+ {
+ Function: api.ToolCallFunction{
+ Name: "get_weather",
+ Arguments: map[string]any{
+ "unit": "fahrenheit",
+ },
+ },
+ },
+ },
+ },
+ {Role: "tool", Content: "{\"location\": \"San Francisco, CA\", \"temperature\": 68, \"condition\": \"partly cloudy\", \"humidity\": 65, \"wind_speed\": 12}", ToolName: "get_weather"},
+ {Role: "user", Content: "That sounds nice! What about New York?"},
+ },
+ tools: []api.Tool{
+ {Function: api.ToolFunction{
+ Name: "get_weather",
+ Description: "Get the current weather in a given location",
+ Parameters: api.ToolFunctionParameters{
+ Required: []string{"unit"},
+ Properties: map[string]api.ToolProperty{
+ "unit": {Type: api.PropertyType{"string"}, Enum: []any{"celsius", "fahrenheit"}, Description: "The unit of temperature"},
+ // TODO(drifkin): add multiple params back once we have predictable
+ // order via some sort of ordered map type (see
+ // )
+ /*
+ "location": {Type: api.PropertyType{"string"}, Description: "The city and state, e.g. San Francisco, CA"},
+ */
+ },
+ },
+ }},
+ },
+ expected: `<|im_start|>system
+You are a helpful assistant with access to tools.
+
+# Tools
+
+You have access to the following functions:
+
+
+
+get_weather
+Get the current weather in a given location
+
+
+unit
+string
+The unit of temperature
+["celsius","fahrenheit"]
+
+["unit"]
+
+
+
+
+If you choose to call a function ONLY reply in the following format with NO suffix:
+
+
+
+
+value_1
+
+
+This is the value for the second parameter
+that can span
+multiple lines
+
+
+
+
+
+Reminder:
+- Function calls MUST follow the specified format: an inner block must be nested within XML tags
+- Required parameters MUST be specified
+- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after
+- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls
+<|im_end|>
+<|im_start|>user
+What is the weather like in San Francisco?<|im_end|>
+<|im_start|>assistant
+I'll check the weather in San Francisco for you.
+
+
+
+
+fahrenheit
+
+
+<|im_end|>
+<|im_start|>user
+
+{"location": "San Francisco, CA", "temperature": 68, "condition": "partly cloudy", "humidity": 65, "wind_speed": 12}
+
+<|im_end|>
+<|im_start|>user
+That sounds nice! What about New York?<|im_end|>
+<|im_start|>assistant
+`,
+ },
+ {
+ name: "parallel tool calls",
+ msgs: []api.Message{
+ {Role: "system", Content: "You are a helpful assistant with access to tools."},
+ {Role: "user", Content: "call double(1) and triple(2)"},
+ {Role: "assistant", Content: "I'll call double(1) and triple(2) for you.", ToolCalls: []api.ToolCall{
+ {Function: api.ToolCallFunction{Name: "double", Arguments: map[string]any{"number": "1"}}},
+ {Function: api.ToolCallFunction{Name: "triple", Arguments: map[string]any{"number": "2"}}},
+ }},
+ {Role: "tool", Content: "{\"number\": 2}", ToolName: "double"},
+ {Role: "tool", Content: "{\"number\": 6}", ToolName: "triple"},
+ },
+ tools: []api.Tool{
+ {Function: api.ToolFunction{Name: "double", Description: "Double a number", Parameters: api.ToolFunctionParameters{Properties: map[string]api.ToolProperty{
+ "number": {Type: api.PropertyType{"string"}, Description: "The number to double"},
+ }}}},
+ {Function: api.ToolFunction{Name: "triple", Description: "Triple a number", Parameters: api.ToolFunctionParameters{Properties: map[string]api.ToolProperty{
+ "number": {Type: api.PropertyType{"string"}, Description: "The number to triple"},
+ }}}},
+ },
+ expected: `<|im_start|>system
+You are a helpful assistant with access to tools.
+
+# Tools
+
+You have access to the following functions:
+
+
+
+double
+Double a number
+
+
+number
+string
+The number to double
+
+
+
+
+triple
+Triple a number
+
+
+number
+string
+The number to triple
+
+
+
+
+
+If you choose to call a function ONLY reply in the following format with NO suffix:
+
+
+
+
+value_1
+
+
+This is the value for the second parameter
+that can span
+multiple lines
+
+
+
+
+
+Reminder:
+- Function calls MUST follow the specified format: an inner block must be nested within XML tags
+- Required parameters MUST be specified
+- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after
+- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls
+<|im_end|>
+<|im_start|>user
+call double(1) and triple(2)<|im_end|>
+<|im_start|>assistant
+I'll call double(1) and triple(2) for you.
+
+
+
+
+1
+
+
+
+
+
+
+2
+
+
+<|im_end|>
+<|im_start|>user
+
+{"number": 2}
+
+
+{"number": 6}
+
+<|im_end|>
+<|im_start|>assistant
+`,
+ },
+ {
+ name: "prefill",
+ msgs: []api.Message{
+ {Role: "system", Content: "You are a helpful assistant."},
+ {Role: "user", Content: "Tell me something interesting."},
+ {Role: "assistant", Content: "I'll tell you something interesting about cats"},
+ },
+ expected: `<|im_start|>system
+You are a helpful assistant.<|im_end|>
+<|im_start|>user
+Tell me something interesting.<|im_end|>
+<|im_start|>assistant
+I'll tell you something interesting about cats`,
+ },
+ {
+ name: "complex tool call arguments should remain json encoded",
+ msgs: []api.Message{
+ {Role: "user", Content: "call tool"},
+ {Role: "assistant", ToolCalls: []api.ToolCall{
+ {Function: api.ToolCallFunction{
+ Name: "echo",
+ Arguments: map[string]any{
+ "payload": map[string]any{"foo": "bar"},
+ },
+ }},
+ }},
+ {Role: "tool", Content: "{\"payload\": {\"foo\": \"bar\"}}", ToolName: "echo"},
+ },
+ expected: `<|im_start|>user
+call tool<|im_end|>
+<|im_start|>assistant
+
+
+
+
+{"foo":"bar"}
+
+
+<|im_end|>
+<|im_start|>user
+
+{"payload": {"foo": "bar"}}
+
+<|im_end|>
+<|im_start|>assistant
+`,
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ rendered, err := Qwen3CoderRenderer(tt.msgs, tt.tools, nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if diff := cmp.Diff(rendered, tt.expected); diff != "" {
+ t.Errorf("mismatch (-got +want):\n%s", diff)
+ }
+ })
+ }
+}
+
+func TestFormatToolCallArgument(t *testing.T) {
+ tests := []struct {
+ name string
+ arg any
+ expected string
+ }{
+ {
+ name: "string",
+ arg: "foo",
+ // notice no quotes around the string
+ expected: "foo",
+ },
+ {
+ name: "map",
+ arg: map[string]any{"foo": "bar"},
+ expected: "{\"foo\":\"bar\"}",
+ },
+ {
+ name: "number",
+ arg: 1,
+ expected: "1",
+ },
+ {
+ name: "boolean",
+ arg: true,
+ expected: "true",
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got := formatToolCallArgument(tt.arg)
+ if got != tt.expected {
+ t.Errorf("formatToolCallArgument(%v) = %v, want %v", tt.arg, got, tt.expected)
+ }
+ })
+ }
+}
diff --git a/model/renderers/renderer.go b/model/renderers/renderer.go
new file mode 100644
index 000000000..2dfb51e49
--- /dev/null
+++ b/model/renderers/renderer.go
@@ -0,0 +1,26 @@
+package renderers
+
+import (
+ "fmt"
+
+ "github.com/ollama/ollama/api"
+)
+
+type rendererFunc func([]api.Message, []api.Tool, *api.ThinkValue) (string, error)
+
+func RenderWithRenderer(name string, msgs []api.Message, tools []api.Tool, think *api.ThinkValue) (string, error) {
+ renderer := rendererForName(name)
+ if renderer == nil {
+ return "", fmt.Errorf("unknown renderer %q", name)
+ }
+ return renderer(msgs, tools, think)
+}
+
+func rendererForName(name string) rendererFunc {
+ switch name {
+ case "qwen3-coder":
+ return Qwen3CoderRenderer
+ default:
+ return nil
+ }
+}
diff --git a/model/sentencepiece.go b/model/sentencepiece.go
index 827ce00d9..db07beee9 100644
--- a/model/sentencepiece.go
+++ b/model/sentencepiece.go
@@ -12,18 +12,18 @@ import (
const spmWhitespaceSep = "▁"
-type SentencePieceModel struct {
+type SentencePiece struct {
maxTokenLen int
vocab *Vocabulary
}
-var _ TextProcessor = (*SentencePieceModel)(nil)
+var _ TextProcessor = (*SentencePiece)(nil)
-func (spm SentencePieceModel) Vocabulary() *Vocabulary {
+func (spm SentencePiece) Vocabulary() *Vocabulary {
return spm.vocab
}
-func NewSentencePieceModel(vocab *Vocabulary) SentencePieceModel {
+func NewSentencePiece(vocab *Vocabulary) SentencePiece {
logutil.Trace("Tokens", "num tokens", len(vocab.Values), "vals", vocab.Values[:5], "scores", vocab.Scores[:5], "types", vocab.Types[:5])
counter := map[int]int{}
@@ -42,17 +42,17 @@ func NewSentencePieceModel(vocab *Vocabulary) SentencePieceModel {
"user defined", counter[TOKEN_TYPE_USER_DEFINED], "unused", counter[TOKEN_TYPE_UNUSED], "byte", counter[TOKEN_TYPE_BYTE],
"max token len", maxTokenLen)
- return SentencePieceModel{
+ return SentencePiece{
maxTokenLen: maxTokenLen,
vocab: vocab,
}
}
-func (spm SentencePieceModel) Is(id int32, special Special) bool {
+func (spm SentencePiece) Is(id int32, special Special) bool {
return spm.vocab.Is(id, special)
}
-func (spm SentencePieceModel) Encode(s string, addSpecial bool) ([]int32, error) {
+func (spm SentencePiece) Encode(s string, addSpecial bool) ([]int32, error) {
fragments := []fragment{{value: s}}
for _, special := range spm.vocab.SpecialVocabulary() {
id := spm.vocab.Encode(special)
@@ -218,7 +218,7 @@ func (q *queue) Pop() interface{} {
return item
}
-func (spm SentencePieceModel) Decode(ids []int32) (string, error) {
+func (spm SentencePiece) Decode(ids []int32) (string, error) {
var sb strings.Builder
for _, id := range ids {
data := spm.vocab.Decode(id)
diff --git a/model/sentencepiece_test.go b/model/sentencepiece_test.go
index 50ac26787..8f4570c17 100644
--- a/model/sentencepiece_test.go
+++ b/model/sentencepiece_test.go
@@ -12,7 +12,7 @@ import (
"github.com/ollama/ollama/convert/sentencepiece"
)
-func loadSentencePieceVocab(t *testing.T) SentencePieceModel {
+func loadSentencePieceVocab(t *testing.T) SentencePiece {
t.Helper()
bts, err := os.ReadFile(filepath.Join("testdata", "gemma2", "tokenizer.model"))
@@ -45,7 +45,7 @@ func loadSentencePieceVocab(t *testing.T) SentencePieceModel {
}
}
- return NewSentencePieceModel(&v)
+ return NewSentencePiece(&v)
}
func TestSentencePieceEncode(t *testing.T) {
@@ -115,7 +115,7 @@ func TestSentencePieceEncode(t *testing.T) {
})
}
-func TestSentencePieceModelDecodeByteTokens(t *testing.T) {
+func TestSentencePieceDecodeByteTokens(t *testing.T) {
vocab := &Vocabulary{
Values: []string{
"normal",
@@ -134,7 +134,7 @@ func TestSentencePieceModelDecodeByteTokens(t *testing.T) {
Scores: []float32{0, 0, 0, 0, 0},
}
- spm := NewSentencePieceModel(vocab)
+ spm := NewSentencePiece(vocab)
tests := []struct {
name string
diff --git a/model/wordpiece.go b/model/wordpiece.go
new file mode 100644
index 000000000..e8d5e848a
--- /dev/null
+++ b/model/wordpiece.go
@@ -0,0 +1,167 @@
+package model
+
+import (
+ "fmt"
+ "iter"
+ "strings"
+ "unicode"
+
+ "github.com/ollama/ollama/logutil"
+)
+
+type WordPiece struct {
+ vocab *Vocabulary
+}
+
+// ggmlPrefix is the prefix used by GGML vocabularies to indicate word boundaries.
+// this differs from original word piece which uses "##" to indicate subwords.
+const ggmlPrefix = "▁"
+
+var wordPieceReplacer = strings.NewReplacer(
+ " .", ".",
+ " ?", "?",
+ " !", "!",
+ " ,", ",",
+ " ' ", "'",
+ " n't", "n't",
+ " 'm", "'m",
+ " do not", " don't",
+ " 's", "'s",
+ " 've", "'ve",
+ " 're", "'re",
+)
+
+// Decode implements TextProcessor.
+func (wpm WordPiece) Decode(ids []int32) (string, error) {
+ var sb strings.Builder
+ for i, id := range ids {
+ if id < 0 || int(id) >= len(wpm.vocab.Values) {
+ return "", fmt.Errorf("invalid token id: %d", id)
+ }
+
+ var separator string
+ piece := wpm.vocab.Values[id]
+ if i > 0 &&
+ (strings.HasPrefix(piece, ggmlPrefix) ||
+ (strings.HasPrefix(piece, "[") && strings.HasSuffix(piece, "]"))) {
+ separator = " "
+ }
+
+ sb.WriteString(wordPieceReplacer.Replace(separator + strings.TrimPrefix(piece, ggmlPrefix)))
+ }
+
+ return sb.String(), nil
+}
+
+// words splits a string into words, treating CJK characters as separate words.
+// TODO: this is specifically for BERT and may need to be adjusted or refactored for other models.
+func (wpm WordPiece) words(s string) iter.Seq[string] {
+ return func(yield func(string) bool) {
+ runes := make([]rune, 0, len(s)*3)
+ for _, r := range s {
+ switch {
+ case r >= 0x4E00 && r <= 0x9FFF,
+ r >= 0x3400 && r <= 0x4DBF,
+ r >= 0x20000 && r <= 0x2A6DF,
+ r >= 0x2A700 && r <= 0x2B73F,
+ r >= 0x2B740 && r <= 0x2B81F,
+ r >= 0x2B820 && r <= 0x2CEAF,
+ r >= 0xF900 && r <= 0xFAFF,
+ r >= 0x2F800 && r <= 0x2FA1F:
+ runes = append(runes, ' ', r, ' ')
+ default:
+ runes = append(runes, r)
+ }
+ }
+
+ for w := range strings.FieldsFuncSeq(string(runes), unicode.IsSpace) {
+ // split on but keep punctuation
+ var start int
+ for start < len(w) {
+ end := strings.IndexFunc(w[start:], unicode.IsPunct)
+ if end < 0 {
+ end = len(w) - start
+ } else if end == 0 {
+ end = 1
+ }
+
+ if !yield(w[start : start+end]) {
+ return
+ }
+
+ start += end
+ }
+ }
+ }
+}
+
+// Encode implements TextProcessor.
+func (wpm WordPiece) Encode(s string, addSpecial bool) ([]int32, error) {
+ var ids []int32
+
+ // TODO: use [UNK] from config
+ unk := wpm.vocab.Encode("[UNK]")
+ for word := range wpm.words(s) {
+ var start int
+ var pieces []int32
+ for start < len(word) {
+ end := len(word)
+
+ var piece int32
+ for start < end {
+ subword := word[start:end]
+ if start == 0 {
+ subword = ggmlPrefix + subword
+ }
+
+ // TODO: some models might not want [ToLower]
+ piece = wpm.vocab.Encode(strings.ToLower(subword))
+ if piece >= 0 {
+ break
+ }
+
+ end--
+ }
+
+ if piece < 0 {
+ // Unknown token
+ pieces = pieces[:0]
+ break
+ }
+
+ pieces = append(pieces, piece)
+ start = end
+ }
+
+ if len(pieces) > 0 {
+ ids = append(ids, pieces...)
+ } else {
+ ids = append(ids, unk)
+ }
+ }
+
+ if addSpecial && len(ids) > 0 {
+ ids = wpm.vocab.addSpecials(ids)
+ }
+
+ logutil.Trace("encoded", "string", s, "ids", ids)
+ return ids, nil
+}
+
+// Is implements TextProcessor.
+func (wpm WordPiece) Is(id int32, special Special) bool {
+ return wpm.vocab.Is(id, special)
+}
+
+// Vocabulary implements TextProcessor.
+func (wpm WordPiece) Vocabulary() *Vocabulary {
+ return wpm.vocab
+}
+
+var _ TextProcessor = (*WordPiece)(nil)
+
+func NewWordPiece(vocab *Vocabulary) WordPiece {
+ return WordPiece{
+ vocab: vocab,
+ }
+}
diff --git a/model/wordpiece_test.go b/model/wordpiece_test.go
new file mode 100644
index 000000000..258fbffcb
--- /dev/null
+++ b/model/wordpiece_test.go
@@ -0,0 +1,51 @@
+package model
+
+import (
+ "slices"
+ "testing"
+
+ "github.com/google/go-cmp/cmp"
+)
+
+func TestWordPiece(t *testing.T) {
+ wpm := NewWordPiece(
+ &Vocabulary{
+ Values: []string{"[UNK]", "[CLS]", "[SEP]", "▁hello", "▁world", "s", "▁!", "▁@", "▁#"},
+ AddBOS: true,
+ AddEOS: true,
+ BOS: []int32{1},
+ EOS: []int32{2},
+ })
+
+ ids, err := wpm.Encode("Hello world!", true)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ if diff := cmp.Diff([]int32{1, 3, 4, 6, 2}, ids); diff != "" {
+ t.Errorf("unexpected ids (-want +got):\n%s", diff)
+ }
+
+ words, err := wpm.Decode(ids)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ if diff := cmp.Diff("[CLS] hello world! [SEP]", words); diff != "" {
+ t.Errorf("unexpected words (-want +got):\n%s", diff)
+ }
+}
+
+func TestWordPieceWords(t *testing.T) {
+ var wpm WordPiece
+
+ basic := slices.Collect(wpm.words("Hey friend! How are you?!?"))
+ if diff := cmp.Diff([]string{"Hey", "friend", "!", "How", "are", "you", "?", "!", "?"}, basic); diff != "" {
+ t.Errorf("unexpected words (-want +got):\n%s", diff)
+ }
+
+ chinese := slices.Collect(wpm.words("野口里佳 Noguchi Rika"))
+ if diff := cmp.Diff([]string{"野", "口", "里", "佳", "Noguchi", "Rika"}, chinese); diff != "" {
+ t.Errorf("unexpected words (-want +got):\n%s", diff)
+ }
+}
diff --git a/openai/openai.go b/openai/openai.go
index b6a8a95e2..7ef5ac6de 100644
--- a/openai/openai.go
+++ b/openai/openai.go
@@ -105,16 +105,18 @@ type ChatCompletionRequest struct {
Tools []api.Tool `json:"tools"`
Reasoning *Reasoning `json:"reasoning,omitempty"`
ReasoningEffort *string `json:"reasoning_effort,omitempty"`
+ DebugRenderOnly bool `json:"_debug_render_only"`
}
type ChatCompletion struct {
- Id string `json:"id"`
- Object string `json:"object"`
- Created int64 `json:"created"`
- Model string `json:"model"`
- SystemFingerprint string `json:"system_fingerprint"`
- Choices []Choice `json:"choices"`
- Usage Usage `json:"usage,omitempty"`
+ Id string `json:"id"`
+ Object string `json:"object"`
+ Created int64 `json:"created"`
+ Model string `json:"model"`
+ SystemFingerprint string `json:"system_fingerprint"`
+ Choices []Choice `json:"choices"`
+ Usage Usage `json:"usage,omitempty"`
+ DebugInfo *api.DebugInfo `json:"_debug_info,omitempty"`
}
type ChatCompletionChunk struct {
@@ -141,6 +143,7 @@ type CompletionRequest struct {
Temperature *float32 `json:"temperature"`
TopP float32 `json:"top_p"`
Suffix string `json:"suffix"`
+ DebugRenderOnly bool `json:"_debug_render_only"`
}
type Completion struct {
@@ -273,8 +276,8 @@ func toChatCompletion(id string, r api.ChatResponse) ChatCompletion {
}
return nil
}(r.DoneReason),
- }},
- Usage: toUsage(r),
+ }}, Usage: toUsage(r),
+ DebugInfo: r.DebugInfo,
}
}
@@ -568,13 +571,14 @@ func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
}
return &api.ChatRequest{
- Model: r.Model,
- Messages: messages,
- Format: format,
- Options: options,
- Stream: &r.Stream,
- Tools: r.Tools,
- Think: think,
+ Model: r.Model,
+ Messages: messages,
+ Format: format,
+ Options: options,
+ Stream: &r.Stream,
+ Tools: r.Tools,
+ Think: think,
+ DebugRenderOnly: r.DebugRenderOnly,
}, nil
}
@@ -648,11 +652,12 @@ func fromCompleteRequest(r CompletionRequest) (api.GenerateRequest, error) {
}
return api.GenerateRequest{
- Model: r.Model,
- Prompt: r.Prompt,
- Options: options,
- Stream: &r.Stream,
- Suffix: r.Suffix,
+ Model: r.Model,
+ Prompt: r.Prompt,
+ Options: options,
+ Stream: &r.Stream,
+ Suffix: r.Suffix,
+ DebugRenderOnly: r.DebugRenderOnly,
}, nil
}
diff --git a/parser/parser.go b/parser/parser.go
index e080f1bb7..c2e8f981f 100644
--- a/parser/parser.go
+++ b/parser/parser.go
@@ -100,6 +100,10 @@ func (f Modelfile) CreateRequest(relativeDir string) (*api.CreateRequest, error)
req.System = c.Args
case "license":
licenses = append(licenses, c.Args)
+ case "renderer":
+ req.Renderer = c.Args
+ case "parser":
+ req.Parser = c.Args
case "message":
role, msg, _ := strings.Cut(c.Args, ": ")
messages = append(messages, api.Message{Role: role, Content: msg})
@@ -320,7 +324,7 @@ func (c Command) String() string {
switch c.Name {
case "model":
fmt.Fprintf(&sb, "FROM %s", c.Args)
- case "license", "template", "system", "adapter":
+ case "license", "template", "system", "adapter", "renderer", "parser":
fmt.Fprintf(&sb, "%s %s", strings.ToUpper(c.Name), quote(c.Args))
case "message":
role, message, _ := strings.Cut(c.Args, ": ")
@@ -346,7 +350,7 @@ const (
var (
errMissingFrom = errors.New("no FROM line")
errInvalidMessageRole = errors.New("message role must be one of \"system\", \"user\", or \"assistant\"")
- errInvalidCommand = errors.New("command must be one of \"from\", \"license\", \"template\", \"system\", \"adapter\", \"parameter\", or \"message\"")
+ errInvalidCommand = errors.New("command must be one of \"from\", \"license\", \"template\", \"system\", \"adapter\", \"renderer\", \"parser\", \"parameter\", or \"message\"")
)
type ParserError struct {
@@ -606,7 +610,7 @@ func isValidMessageRole(role string) bool {
func isValidCommand(cmd string) bool {
switch strings.ToLower(cmd) {
- case "from", "license", "template", "system", "adapter", "parameter", "message":
+ case "from", "license", "template", "system", "adapter", "renderer", "parser", "parameter", "message":
return true
default:
return false
diff --git a/parser/parser_test.go b/parser/parser_test.go
index 7d5a808ba..1524e890a 100644
--- a/parser/parser_test.go
+++ b/parser/parser_test.go
@@ -198,6 +198,34 @@ BADCOMMAND param1 value1
}
}
+func TestParseFileRenderer(t *testing.T) {
+ input := `
+FROM foo
+RENDERER renderer1
+`
+
+ reader := strings.NewReader(input)
+
+ modelfile, err := ParseFile(reader)
+ require.NoError(t, err)
+
+ assert.Equal(t, []Command{{Name: "model", Args: "foo"}, {Name: "renderer", Args: "renderer1"}}, modelfile.Commands)
+}
+
+func TestParseFileParser(t *testing.T) {
+ input := `
+FROM foo
+PARSER parser1
+`
+
+ reader := strings.NewReader(input)
+
+ modelfile, err := ParseFile(reader)
+ require.NoError(t, err)
+
+ assert.Equal(t, []Command{{Name: "model", Args: "foo"}, {Name: "parser", Args: "parser1"}}, modelfile.Commands)
+}
+
func TestParseFileMessages(t *testing.T) {
cases := []struct {
input string
diff --git a/parser/token_parser.go b/parser/token_parser.go
deleted file mode 100644
index 812458299..000000000
--- a/parser/token_parser.go
+++ /dev/null
@@ -1,126 +0,0 @@
-package parser
-
-import (
- "encoding/json"
- "errors"
- "strings"
-
- "github.com/ollama/ollama/api"
- "github.com/ollama/ollama/harmony"
-)
-
-type TokenParserType int
-
-const (
- TokenParserTypeDefault TokenParserType = iota
- TokenParserTypeHarmony
-)
-
-type TokenParser struct {
- messageHandler MessageHandler
- parserEngine ParserInternals
- toolParser ToolParser
- lastToken string
- tokenRepeat int
- repeatLimit int
-}
-
-const defaultTokenRepeatLimit = 30
-
-type MessageHandler interface {
- AddContent(token string) (content, thinking string, toolContent string)
-}
-
-type ParserInternals interface {
- AddImplicitStartOrPrefill(prefillString string)
-}
-
-type ToolParser interface {
- Add(token string)
- Drain() (toolName *string, toolContent string)
-}
-
-// Default implementation for the TokenParser interface as a no-op passthrough
-type defaultMessageHandler struct{}
-
-func (defaultMessageHandler) AddContent(token string) (string, string, string) {
- return token, "", ""
-}
-
-type defaultEngine struct{}
-
-func (defaultEngine) AddImplicitStartOrPrefill(prefillString string) {}
-
-type defaultToolParser struct{}
-
-func (defaultToolParser) Add(token string) {}
-
-func (defaultToolParser) Drain() (*string, string) { return nil, "" }
-
-func NewTokenParser(parserType TokenParserType, prefillString string) TokenParser {
- switch parserType {
- case TokenParserTypeHarmony:
- harmonyMessageHandler := harmony.NewHarmonyMessageHandler()
- harmonyMessageHandler.HarmonyParser.AddImplicitStartOrPrefill(prefillString)
- return TokenParser{
- messageHandler: harmonyMessageHandler,
- parserEngine: harmonyMessageHandler.HarmonyParser,
- toolParser: harmonyMessageHandler.ToolParser,
- repeatLimit: defaultTokenRepeatLimit,
- }
-
- default:
- return TokenParser{
- messageHandler: defaultMessageHandler{},
- parserEngine: defaultEngine{},
- toolParser: defaultToolParser{},
- repeatLimit: 30,
- }
- }
-}
-
-func (p *TokenParser) AddContent(token string) (string, string, error) {
- if p.repeatLimitReached(token) {
- return "", "", errors.New("token repeat limit reached")
- }
- content, thinking, toolContent := p.messageHandler.AddContent(token)
- p.toolParser.Add(toolContent)
- return content, thinking, nil
-}
-
-// repeatLimitReached updates repeat counters and returns true if the repeat limit is reached.
-func (p *TokenParser) repeatLimitReached(token string) bool {
- if p == nil {
- return false
- }
- trimmed := strings.TrimSpace(token)
- if trimmed == p.lastToken {
- p.tokenRepeat++
- } else {
- p.tokenRepeat = 0
- }
- p.lastToken = trimmed
-
- return p.tokenRepeat >= p.repeatLimit
-}
-
-// TODO: update to work with multiple toolcalls - unmarshalling should also happen on parser level
-func (p *TokenParser) Drain() []api.ToolCall {
- toolName, toolContent := p.toolParser.Drain()
- if toolName != nil {
- *toolName = strings.TrimPrefix(*toolName, "functions.")
- var args api.ToolCallFunctionArguments
- if err := json.Unmarshal([]byte(toolContent), &args); err != nil {
- return nil
- }
- return []api.ToolCall{
- {
- Function: api.ToolCallFunction{
- Name: *toolName,
- Arguments: args,
- },
- },
- }
- }
- return nil
-}
diff --git a/runner/ollamarunner/runner.go b/runner/ollamarunner/runner.go
index 676e5186f..480cfc19b 100644
--- a/runner/ollamarunner/runner.go
+++ b/runner/ollamarunner/runner.go
@@ -11,7 +11,6 @@ import (
"image"
"log"
"log/slog"
- "math"
"net"
"net/http"
"os"
@@ -32,9 +31,9 @@ import (
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/logutil"
"github.com/ollama/ollama/ml"
+ "github.com/ollama/ollama/ml/nn/pooling"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
- "github.com/ollama/ollama/parser"
"github.com/ollama/ollama/runner/common"
"github.com/ollama/ollama/sample"
@@ -406,7 +405,7 @@ func (s *Server) removeSequence(seqIndex int, reason llm.DoneReason) {
func (s *Server) run(ctx context.Context) {
s.ready.Wait()
- supportsAsync := s.model.Backend().Config().Uint("pooling_type", math.MaxUint32) == math.MaxUint32
+ supportsAsync := pooling.Type(s.model.Backend().Config().Uint("pooling_type")) == pooling.TypeNone
var activeBatch batchState
for {
@@ -468,6 +467,7 @@ func (s *Server) forwardBatch(pendingBatch batchState) (nextBatch batchState, er
// Prepare the seqs and batch, but defer the input token values as we may not be ready yet
var batchInputs []*input.Input
+ var batchOutputs []int32
var batch input.Batch
resumeSeq := -1
@@ -550,9 +550,9 @@ func (s *Server) forwardBatch(pendingBatch batchState) (nextBatch batchState, er
batch.Positions = append(batch.Positions, int32(len(seq.cache.Inputs)+len(seq.pendingInputs)))
batch.Sequences = append(batch.Sequences, seq.cache.Id)
- seq.iBatch = len(batch.Outputs)
- if i+1 == len(seq.inputs) {
- batch.Outputs = append(batch.Outputs, int32(len(batchInputs)-1))
+ seq.iBatch = len(batchOutputs)
+ if i+1 == len(seq.inputs) || seq.embeddingOnly {
+ batchOutputs = append(batchOutputs, int32(len(batchInputs)-1))
}
logutil.Trace("forwardBatch iBatch", "batchID", s.batchID, "seqIdx", seqIdx, "seq.iBatch", seq.iBatch, "i+1", i+1, "len(seq.inputs)", len(seq.inputs))
seq.pendingInputs = append(seq.pendingInputs, inp)
@@ -577,6 +577,7 @@ func (s *Server) forwardBatch(pendingBatch batchState) (nextBatch batchState, er
// Actual batchInputs values will be injected into the batch.Inputs tensor before calling Compute
batch.Inputs = nextBatch.ctx.Input().Empty(ml.DTypeI32, len(batchInputs))
+ batch.Outputs = nextBatch.ctx.Input().FromIntSlice(batchOutputs, len(batchOutputs))
nextBatch.modelOutput, err = model.Forward(nextBatch.ctx, s.model, batch)
if err != nil {
err = fmt.Errorf("failed to build graph: %w", err)
@@ -704,8 +705,8 @@ func (s *Server) computeBatch(activeBatch batchState) {
}
// sample a token
- vocabSize := len(outputs) / len(activeBatch.batch.Outputs)
- logutil.Trace("computeBatch: vocab details", "batchID", activeBatch.id, "seqIdx", i, "len(logits)", len(outputs), "len(activeBatch.batch.Outputs)", len(activeBatch.batch.Outputs), "vocabSize", vocabSize, "iBatches", iBatches)
+ vocabSize := len(outputs) / activeBatch.batch.Outputs.Dim(0)
+ logutil.Trace("computeBatch: vocab details", "batchID", activeBatch.id, "seqIdx", i, "len(logits)", len(outputs), "len(activeBatch.batch.Outputs)", activeBatch.batch.Outputs.Dim(0), "vocabSize", vocabSize, "iBatches", iBatches)
token, err := seq.sampler.Sample(outputs[iBatches[i]*vocabSize : (iBatches[i]+1)*vocabSize])
if err != nil {
s.hardErrCh <- fmt.Errorf("failed to sample token: %w", err)
@@ -781,8 +782,6 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
return
}
- tokenParser := parser.NewTokenParser(req.ParserType, req.PrefillString)
-
if req.Options == nil {
opts := api.DefaultOptions()
req.Options = &opts
@@ -873,18 +872,8 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
return
case content, ok := <-seq.responses:
if ok {
- var thinking string
- var err error
- content, thinking, err = tokenParser.AddContent(content)
- if err != nil {
- http.Error(w, err.Error(), http.StatusInternalServerError)
- close(seq.quit)
- return
- }
-
if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{
- Content: content,
- Thinking: thinking,
+ Content: content,
}); err != nil {
http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
close(seq.quit)
@@ -893,9 +882,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
flusher.Flush()
} else {
- toolCalls := tokenParser.Drain()
if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{
- ToolCalls: toolCalls,
Done: true,
DoneReason: seq.doneReason,
PromptEvalCount: seq.numPromptInputs,
@@ -913,7 +900,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
}
func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
- if s.model.Backend().Config().Uint("pooling_type", math.MaxUint32) == math.MaxUint32 {
+ if pooling.Type(s.model.Backend().Config().Uint("pooling_type")) == pooling.TypeNone {
http.Error(w, "this model does not support embeddings", http.StatusNotImplemented)
return
}
@@ -1061,12 +1048,8 @@ func (s *Server) reserveWorstCaseGraph() error {
batch.Positions[i] = int32(i)
}
- batch.Outputs = make([]int32, s.parallel)
- for i := range batch.Outputs {
- batch.Outputs[i] = int32(i)
- }
-
batch.Inputs = ctx.Input().FromIntSlice(batchInputs, len(batchInputs))
+ batch.Outputs = ctx.Input().Empty(ml.DTypeI32, s.parallel)
cache := s.model.Config().Cache
if cache != nil {
diff --git a/server/create.go b/server/create.go
index bd970876f..f08f18b34 100644
--- a/server/create.go
+++ b/server/create.go
@@ -323,6 +323,8 @@ func createModel(r api.CreateRequest, name model.Name, baseLayers []*layerGGML,
RootFS: RootFS{
Type: "layers",
},
+ Renderer: r.Renderer,
+ Parser: r.Parser,
}
var layers []Layer
diff --git a/server/images.go b/server/images.go
index 504eb95cf..6432860f8 100644
--- a/server/images.go
+++ b/server/images.go
@@ -24,6 +24,7 @@ import (
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/fs/gguf"
+ "github.com/ollama/ollama/model/parsers"
"github.com/ollama/ollama/parser"
"github.com/ollama/ollama/template"
"github.com/ollama/ollama/thinking"
@@ -94,8 +95,9 @@ func (m *Model) Capabilities() []model.Capability {
return capabilities
}
+ builtinParser := parsers.ParserForName(m.Config.Parser)
// Check for tools capability
- if slices.Contains(m.Template.Vars(), "tools") {
+ if slices.Contains(m.Template.Vars(), "tools") || (builtinParser != nil && builtinParser.HasToolSupport()) {
capabilities = append(capabilities, model.CapabilityTools)
}
@@ -112,7 +114,8 @@ func (m *Model) Capabilities() []model.Capability {
// Check for thinking capability
openingTag, closingTag := thinking.InferTags(m.Template.Template)
hasTags := openingTag != "" && closingTag != ""
- if hasTags || slices.Contains([]string{"gptoss", "gpt-oss"}, m.Config.ModelFamily) {
+ isGptoss := slices.Contains([]string{"gptoss", "gpt-oss"}, m.Config.ModelFamily)
+ if hasTags || isGptoss || (builtinParser != nil && builtinParser.HasThinkingSupport()) {
capabilities = append(capabilities, model.CapabilityThinking)
}
@@ -198,6 +201,20 @@ func (m *Model) String() string {
})
}
+ if m.Config.Renderer != "" {
+ modelfile.Commands = append(modelfile.Commands, parser.Command{
+ Name: "renderer",
+ Args: m.Config.Renderer,
+ })
+ }
+
+ if m.Config.Parser != "" {
+ modelfile.Commands = append(modelfile.Commands, parser.Command{
+ Name: "parser",
+ Args: m.Config.Parser,
+ })
+ }
+
for k, v := range m.Options {
switch v := v.(type) {
case []any:
@@ -238,6 +255,8 @@ type ConfigV2 struct {
ModelFamilies []string `json:"model_families"`
ModelType string `json:"model_type"`
FileType string `json:"file_type"`
+ Renderer string `json:"renderer,omitempty"`
+ Parser string `json:"parser,omitempty"`
// required by spec
Architecture string `json:"architecture"`
diff --git a/server/prompt.go b/server/prompt.go
index f1d8020ea..56bc63030 100644
--- a/server/prompt.go
+++ b/server/prompt.go
@@ -11,6 +11,7 @@ import (
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/llm"
+ "github.com/ollama/ollama/model/renderers"
"github.com/ollama/ollama/template"
)
@@ -41,18 +42,12 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
}
}
- thinkVal := false
- thinkLevel := ""
- if think != nil {
- thinkVal = think.Bool()
- thinkLevel = think.String()
- }
- var b bytes.Buffer
- if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[i:]...), Tools: tools, Think: thinkVal, ThinkLevel: thinkLevel, IsThinkSet: think != nil}); err != nil {
+ p, err := renderPrompt(m, append(system, msgs[i:]...), tools, think)
+ if err != nil {
return "", nil, err
}
- s, err := tokenize(ctx, b.String())
+ s, err := tokenize(ctx, p)
if err != nil {
return "", nil, err
}
@@ -101,6 +96,23 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
}
// truncate any messages that do not fit into the context window
+ p, err := renderPrompt(m, append(system, msgs[currMsgIdx:]...), tools, think)
+ if err != nil {
+ return "", nil, err
+ }
+
+ return p, images, nil
+}
+
+func renderPrompt(m *Model, msgs []api.Message, tools []api.Tool, think *api.ThinkValue) (string, error) {
+ if m.Config.Renderer != "" {
+ rendered, err := renderers.RenderWithRenderer(m.Config.Renderer, msgs, tools, think)
+ if err != nil {
+ return "", err
+ }
+ return rendered, nil
+ }
+
var b bytes.Buffer
thinkVal := false
thinkLevel := ""
@@ -108,9 +120,8 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
thinkVal = think.Bool()
thinkLevel = think.String()
}
- if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[currMsgIdx:]...), Tools: tools, Think: thinkVal, ThinkLevel: thinkLevel, IsThinkSet: think != nil}); err != nil {
- return "", nil, err
+ if err := m.Template.Execute(&b, template.Values{Messages: msgs, Tools: tools, Think: thinkVal, ThinkLevel: thinkLevel, IsThinkSet: think != nil}); err != nil {
+ return "", err
}
-
- return b.String(), images, nil
+ return b.String(), nil
}
diff --git a/server/routes.go b/server/routes.go
index 8dd1b217a..e999c6c01 100644
--- a/server/routes.go
+++ b/server/routes.go
@@ -35,8 +35,8 @@ import (
"github.com/ollama/ollama/harmony"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/logutil"
+ "github.com/ollama/ollama/model/parsers"
"github.com/ollama/ollama/openai"
- "github.com/ollama/ollama/parser"
"github.com/ollama/ollama/server/internal/client/ollama"
"github.com/ollama/ollama/server/internal/registry"
"github.com/ollama/ollama/template"
@@ -47,6 +47,18 @@ import (
"github.com/ollama/ollama/version"
)
+func shouldUseHarmony(model *Model) bool {
+ if slices.Contains([]string{"gptoss", "gpt-oss"}, model.Config.ModelFamily) {
+ // heuristic to check whether the template expects to be parsed via harmony:
+ // search for harmony tags that are nearly always used
+ if model.Template.Contains("<|start|>") && model.Template.Contains("<|end|>") {
+ return true
+ }
+ }
+
+ return false
+}
+
func experimentEnabled(name string) bool {
return slices.Contains(strings.Split(os.Getenv("OLLAMA_EXPERIMENT"), ","), name)
}
@@ -196,17 +208,13 @@ func (s *Server) GenerateHandler(c *gin.Context) {
return
}
- useHarmony := harmony.ShouldUseHarmony(m.Config.ModelFamily, m.Template) && !req.Raw
- var parserType parser.TokenParserType
+ useHarmony := shouldUseHarmony(m) && !req.Raw
+ var harmonyMessageHandler *harmony.HarmonyMessageHandler
+ var harmonyToolParser *harmony.HarmonyToolCallAccumulator
if useHarmony {
- parserType = parser.TokenParserTypeHarmony
- } else {
- parserType = parser.TokenParserTypeDefault
- }
- var functionNameMap *harmony.FunctionNameMap
-
- if useHarmony {
- functionNameMap = harmony.NewFunctionNameMap()
+ harmonyMessageHandler = harmony.NewHarmonyMessageHandler()
+ harmonyMessageHandler.HarmonyParser.AddImplicitStart()
+ harmonyToolParser = harmonyMessageHandler.CreateToolParser()
}
// Validate Think value: string values currently only allowed for gptoss models
@@ -322,10 +330,10 @@ func (s *Server) GenerateHandler(c *gin.Context) {
// If debug mode is enabled, return the rendered template instead of calling the model
if req.DebugRenderOnly {
- c.JSON(http.StatusOK, api.DebugTemplateResponse{
+ c.JSON(http.StatusOK, api.GenerateResponse{
Model: req.Model,
CreatedAt: time.Now().UTC(),
- DebugInfo: api.DebugInfo{
+ DebugInfo: &api.DebugInfo{
RenderedTemplate: prompt,
ImageCount: len(images),
},
@@ -350,19 +358,16 @@ func (s *Server) GenerateHandler(c *gin.Context) {
var sb strings.Builder
defer close(ch)
if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
- Prompt: prompt,
- Images: images,
- Format: req.Format,
- Options: opts,
- ParserType: parserType,
+ Prompt: prompt,
+ Images: images,
+ Format: req.Format,
+ Options: opts,
}, func(cr llm.CompletionResponse) {
res := api.GenerateResponse{
Model: req.Model,
CreatedAt: time.Now().UTC(),
Response: cr.Content,
Done: cr.Done,
- Thinking: cr.Thinking,
- ToolCalls: cr.ToolCalls,
Metrics: api.Metrics{
PromptEvalCount: cr.PromptEvalCount,
PromptEvalDuration: cr.PromptEvalDuration,
@@ -371,22 +376,12 @@ func (s *Server) GenerateHandler(c *gin.Context) {
},
}
- if res.Done {
- res.DoneReason = cr.DoneReason.String()
- res.TotalDuration = time.Since(checkpointStart)
- res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
- }
-
if useHarmony {
- for i, tool := range res.ToolCalls {
- res.ToolCalls[i].Function.Name = functionNameMap.OriginalFromConverted(tool.Function.Name)
- }
- if res.Response != "" || res.Thinking != "" || len(res.ToolCalls) > 0 || res.Done {
- ch <- res
- }
- return
- }
- if thinkingState != nil {
+ content, thinking, toolContent := harmonyMessageHandler.AddContent(cr.Content, harmonyToolParser)
+ res.Response = content
+ res.Thinking = thinking
+ harmonyToolParser.Add(toolContent)
+ } else if thinkingState != nil {
thinking, content := thinkingState.AddContent(cr.Content)
res.Thinking = thinking
res.Response = content
@@ -397,6 +392,30 @@ func (s *Server) GenerateHandler(c *gin.Context) {
}
if cr.Done {
+ if useHarmony {
+ toolName, toolContent := harmonyToolParser.Drain()
+ if toolName != nil {
+ *toolName = strings.TrimPrefix(*toolName, "functions.")
+ var args api.ToolCallFunctionArguments
+ if err := json.Unmarshal([]byte(toolContent), &args); err != nil {
+ errStr := fmt.Sprintf("error parsing tool call: raw='%s', err=%s", toolContent, err.Error())
+ ch <- gin.H{"error": errStr}
+ return
+ }
+
+ res.ToolCalls = append(res.ToolCalls, api.ToolCall{
+ Function: api.ToolCallFunction{
+ Name: *toolName,
+ Arguments: args,
+ },
+ })
+ }
+ }
+
+ res.DoneReason = cr.DoneReason.String()
+ res.TotalDuration = time.Since(checkpointStart)
+ res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
+
if !req.Raw {
tokens, err := r.Tokenize(c.Request.Context(), prompt+sb.String())
if err != nil {
@@ -470,7 +489,6 @@ func (s *Server) EmbedHandler(c *gin.Context) {
}
truncate := true
-
if req.Truncate != nil && !*req.Truncate {
truncate = false
}
@@ -537,7 +555,16 @@ func (s *Server) EmbedHandler(c *gin.Context) {
return
}
+ if bos := kvData.Uint("tokenizer.ggml.bos_token_id"); tokens[0] != int(bos) && kvData.Bool("add_bos_token", true) {
+ ctxLen--
+ }
+
+ if eos := kvData.Uint("tokenizer.ggml.eos_token_id"); tokens[len(tokens)-1] != int(eos) && kvData.Bool("add_eos_token", true) {
+ ctxLen--
+ }
+
tokens = tokens[:ctxLen]
+
s, err = r.Detokenize(c.Request.Context(), tokens)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
@@ -1599,27 +1626,32 @@ func (s *Server) ChatHandler(c *gin.Context) {
}
msgs = filterThinkTags(msgs, m)
- useHarmony := harmony.ShouldUseHarmony(m.Config.ModelFamily, m.Template)
- var parserType parser.TokenParserType
- if useHarmony {
- parserType = parser.TokenParserTypeHarmony
- } else {
- parserType = parser.TokenParserTypeDefault
+ var builtinParser parsers.Parser
+ if m.Config.Parser != "" {
+ builtinParser = parsers.ParserForName(m.Config.Parser)
}
+ var harmonyMessageHandler *harmony.HarmonyMessageHandler
+ var harmonyToolParser *harmony.HarmonyToolCallAccumulator
+
+ useHarmony := shouldUseHarmony(m) || m.Config.Parser == "harmony"
+
processedTools := req.Tools
- var functionNameMap *harmony.FunctionNameMap
- var prefillString string
- // TODO(parthsareen): this can be abstracted to not be model specific and potentially moved to the runner
if useHarmony {
- prefillString = harmony.Prefill(msgs[len(msgs)-1])
- functionNameMap = harmony.NewFunctionNameMap()
+ harmonyMessageHandler = harmony.NewHarmonyMessageHandler()
+ var lastMessage *api.Message
+ if len(msgs) > 0 {
+ lastMessage = &msgs[len(msgs)-1]
+ }
+ harmonyMessageHandler.HarmonyParser.AddImplicitStartOrPrefill(lastMessage)
+ harmonyToolParser = harmonyMessageHandler.CreateToolParser()
+
// make a copy of tools to pass to the chat prompt. Function names may be
// renamed to be valid Harmony function names.
processedTools = make([]api.Tool, len(req.Tools))
copy(processedTools, req.Tools)
for i, tool := range processedTools {
- processedTools[i].Function.Name = functionNameMap.ConvertAndAdd(tool.Function.Name)
+ processedTools[i].Function.Name = harmonyMessageHandler.FunctionNameMap.ConvertAndAdd(tool.Function.Name)
}
}
@@ -1632,10 +1664,10 @@ func (s *Server) ChatHandler(c *gin.Context) {
// If debug mode is enabled, return the rendered template instead of calling the model
if req.DebugRenderOnly {
- c.JSON(http.StatusOK, api.DebugTemplateResponse{
+ c.JSON(http.StatusOK, api.ChatResponse{
Model: req.Model,
CreatedAt: time.Now().UTC(),
- DebugInfo: api.DebugInfo{
+ DebugInfo: &api.DebugInfo{
RenderedTemplate: prompt,
ImageCount: len(images),
},
@@ -1672,17 +1704,15 @@ func (s *Server) ChatHandler(c *gin.Context) {
defer close(ch)
if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
- Prompt: prompt,
- Images: images,
- Format: req.Format,
- Options: opts,
- ParserType: parserType,
- PrefillString: prefillString,
+ Prompt: prompt,
+ Images: images,
+ Format: req.Format,
+ Options: opts,
}, func(r llm.CompletionResponse) {
res := api.ChatResponse{
Model: req.Model,
CreatedAt: time.Now().UTC(),
- Message: api.Message{Role: "assistant", Content: r.Content, Thinking: r.Thinking, ToolCalls: r.ToolCalls},
+ Message: api.Message{Role: "assistant", Content: r.Content},
Done: r.Done,
Metrics: api.Metrics{
PromptEvalCount: r.PromptEvalCount,
@@ -1697,14 +1727,54 @@ func (s *Server) ChatHandler(c *gin.Context) {
res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
}
+ // TODO(drifkin): fold this as much as possibleinto the generic m.Config.Parser logic
if useHarmony {
- for i, tool := range res.Message.ToolCalls {
- res.Message.ToolCalls[i].Function.Name = functionNameMap.OriginalFromConverted(tool.Function.Name)
+ content, thinking, toolContent := harmonyMessageHandler.AddContent(r.Content, harmonyToolParser)
+ res.Message.Content = content
+ res.Message.Thinking = thinking
+ harmonyToolParser.Add(toolContent)
+
+ if r.Done {
+ toolName, toolContent := harmonyToolParser.Drain()
+ if toolName != nil {
+ *toolName = strings.TrimPrefix(*toolName, "functions.")
+ *toolName = harmonyMessageHandler.FunctionNameMap.OriginalFromConverted(*toolName)
+ var args api.ToolCallFunctionArguments
+ if err := json.Unmarshal([]byte(toolContent), &args); err != nil {
+ errStr := fmt.Sprintf("error parsing tool call: raw='%s', err=%s", toolContent, err.Error())
+ ch <- gin.H{"error": errStr}
+ return
+ }
+ res.Message.ToolCalls = []api.ToolCall{{Function: api.ToolCallFunction{Name: *toolName, Arguments: args}}}
+ }
}
+
// only send messages with meaningful content (empty messages confuse clients)
if res.Message.Content != "" || res.Message.Thinking != "" || len(res.Message.ToolCalls) > 0 || res.Done {
ch <- res
}
+
+ return
+ } else if builtinParser != nil {
+ slog.Log(context.TODO(), logutil.LevelTrace, "builtin parser input", "parser", m.Config.Parser, "content", r.Content)
+
+ content, thinking, toolCalls, err := builtinParser.Add(r.Content, req.Tools)
+ if err != nil {
+ ch <- gin.H{"error": err.Error()}
+ return
+ }
+
+ res.Message.Content = content
+ res.Message.Thinking = thinking
+ res.Message.ToolCalls = toolCalls
+
+ if res.Message.Content != "" || res.Message.Thinking != "" || len(res.Message.ToolCalls) > 0 || r.Done {
+ slog.Log(context.TODO(), logutil.LevelTrace, "builtin parser output", "parser", m.Config.Parser, "content", content, "thinking", thinking, "toolCalls", toolCalls, "done", r.Done)
+ ch <- res
+ } else {
+ slog.Log(context.TODO(), logutil.LevelTrace, "builtin parser empty output", "parser", m.Config.Parser)
+ }
+
return
}
diff --git a/server/routes_debug_test.go b/server/routes_debug_test.go
index f04a1da99..6507284ef 100644
--- a/server/routes_debug_test.go
+++ b/server/routes_debug_test.go
@@ -180,7 +180,7 @@ func TestGenerateDebugRenderOnly(t *testing.T) {
t.Errorf("expected status %d, got %d, body: %s", http.StatusOK, w.Code, w.Body.String())
}
- var response api.DebugTemplateResponse
+ var response api.GenerateResponse
if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil {
t.Fatalf("failed to unmarshal response: %v", err)
}
@@ -385,7 +385,7 @@ func TestChatDebugRenderOnly(t *testing.T) {
t.Errorf("expected status %d, got %d, body: %s", http.StatusOK, w.Code, w.Body.String())
}
- var response api.DebugTemplateResponse
+ var response api.ChatResponse
if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil {
t.Fatalf("failed to unmarshal response: %v", err)
}
diff --git a/server/routes_harmony_streaming_test.go b/server/routes_harmony_streaming_test.go
index bcb020886..b1ede4e39 100644
--- a/server/routes_harmony_streaming_test.go
+++ b/server/routes_harmony_streaming_test.go
@@ -7,6 +7,7 @@ import (
"bytes"
"context"
"encoding/json"
+ "net/http"
"strings"
"testing"
"time"
@@ -117,7 +118,7 @@ func TestChatHarmonyParserStreamingRealtime(t *testing.T) {
name: "content streams as it arrives",
steps: []step{
{
- input: llm.CompletionResponse{Content: "Hello", Done: false},
+ input: llm.CompletionResponse{Content: "<|message|>Hello", Done: false},
wantContent: "Hello",
},
{
@@ -125,7 +126,7 @@ func TestChatHarmonyParserStreamingRealtime(t *testing.T) {
wantContent: ", world",
},
{
- input: llm.CompletionResponse{Content: "!", Done: true, DoneReason: llm.DoneReasonStop},
+ input: llm.CompletionResponse{Content: "!<|end|>", Done: true, DoneReason: llm.DoneReasonStop},
wantContent: "!",
},
},
@@ -134,15 +135,20 @@ func TestChatHarmonyParserStreamingRealtime(t *testing.T) {
name: "thinking streams separately from content",
steps: []step{
{
- input: llm.CompletionResponse{Thinking: "Thinking...", Done: false},
+ input: llm.CompletionResponse{Content: "<|channel|>analysis<|message|>Thinking...", Done: false},
wantThinking: "Thinking...",
},
{
- input: llm.CompletionResponse{Content: "Answer", Done: false},
- wantContent: "Answer",
+ input: llm.CompletionResponse{Content: "<|end|>", Done: false},
+ // No output expected - just closes the analysis message and resets state to normal
},
{
- input: llm.CompletionResponse{Done: true, DoneReason: llm.DoneReasonStop},
+ input: llm.CompletionResponse{Content: "<|start|>assistant<|message|>Answer", Done: false},
+ wantContent: "Answer", // After message end, state is reset to normal
+ },
+ {
+ input: llm.CompletionResponse{Content: "<|end|>", Done: true, DoneReason: llm.DoneReasonStop},
+ // No output expected - just closes the assistant message
},
},
},
@@ -150,16 +156,24 @@ func TestChatHarmonyParserStreamingRealtime(t *testing.T) {
name: "partial tags buffer until complete",
steps: []step{
{
- input: llm.CompletionResponse{Thinking: "Deep ", Done: false},
+ input: llm.CompletionResponse{Content: "<|chan", Done: false},
+ // No output - partial tag
+ },
+ {
+ input: llm.CompletionResponse{Content: "nel|>analysis<|mess", Done: false},
+ // No output - still building tags
+ },
+ {
+ input: llm.CompletionResponse{Content: "age|>Deep ", Done: false},
wantThinking: "Deep ",
},
{
- input: llm.CompletionResponse{Thinking: "thought", Done: false},
+ input: llm.CompletionResponse{Content: "thought<|end|>", Done: false},
wantThinking: "thought",
},
{
- input: llm.CompletionResponse{Content: "Done", Done: true, DoneReason: llm.DoneReasonStop},
- wantContent: "Done",
+ input: llm.CompletionResponse{Content: "<|start|>assistant<|message|>Done<|end|>", Done: true, DoneReason: llm.DoneReasonStop},
+ wantContent: "Done", // After message end, state is reset to normal
},
},
},
@@ -167,7 +181,7 @@ func TestChatHarmonyParserStreamingRealtime(t *testing.T) {
name: "simple assistant after analysis",
steps: []step{
{
- input: llm.CompletionResponse{Thinking: "Think", Content: "Answer", Done: true, DoneReason: llm.DoneReasonStop},
+ input: llm.CompletionResponse{Content: "<|channel|>analysis<|message|>Think<|end|><|start|>assistant<|message|>Answer<|end|>", Done: true, DoneReason: llm.DoneReasonStop},
wantContent: "Answer",
wantThinking: "Think",
},
@@ -177,7 +191,7 @@ func TestChatHarmonyParserStreamingRealtime(t *testing.T) {
name: "tool call parsed and returned correctly",
steps: []step{
{
- input: llm.CompletionResponse{Content: "The weather is sunny", ToolCalls: []api.ToolCall{{Function: api.ToolCallFunction{Name: "get_weather", Arguments: api.ToolCallFunctionArguments{"location": "San Francisco"}}}}, Done: true, DoneReason: llm.DoneReasonStop},
+ input: llm.CompletionResponse{Content: "<|channel|>commentary to=functions.get_weather<|message|>{\"location\":\"San Francisco\"}<|end|><|start|>assistant<|message|>The weather is sunny<|end|>", Done: true, DoneReason: llm.DoneReasonStop},
wantContent: "The weather is sunny",
wantToolCalls: []api.ToolCall{
{
@@ -196,10 +210,15 @@ func TestChatHarmonyParserStreamingRealtime(t *testing.T) {
name: "tool call with streaming JSON across chunks",
steps: []step{
{
- input: llm.CompletionResponse{Done: false},
+ input: llm.CompletionResponse{Content: "<|channel|>commentary to=functions.calculate<|message|>{\"expr", Done: false},
+ // No output yet - incomplete JSON
},
{
- input: llm.CompletionResponse{ToolCalls: []api.ToolCall{{Function: api.ToolCallFunction{Name: "calculate", Arguments: api.ToolCallFunctionArguments{"expression": "2+2"}}}}, Done: true},
+ input: llm.CompletionResponse{Content: "ession\":\"2+", Done: false},
+ // Still no output - incomplete JSON
+ },
+ {
+ input: llm.CompletionResponse{Content: "2\"}", Done: true},
wantToolCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
@@ -381,9 +400,9 @@ func TestChatHarmonyParserStreamingSimple(t *testing.T) {
gin.SetMode(gin.TestMode)
mockResponses := []llm.CompletionResponse{
- {Content: "First ", Done: false},
+ {Content: "<|message|>First ", Done: false},
{Content: "chunk ", Done: false},
- {Content: "here", Done: true, DoneReason: llm.DoneReasonStop},
+ {Content: "here<|end|>", Done: true, DoneReason: llm.DoneReasonStop},
}
mock := mockRunner{
@@ -488,3 +507,189 @@ func TestChatHarmonyParserStreamingSimple(t *testing.T) {
t.Errorf("expected at least 2 content chunks for streaming, got %d", contentChunks)
}
}
+
+func TestChatHarmonyParserStreaming(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ type expectedChunk struct {
+ afterResponse int // Which mock response this chunk should appear after
+ content string // Expected content in this chunk
+ thinking string // Expected thinking in this chunk
+ }
+
+ testCases := []struct {
+ name string
+ mockResponses []llm.CompletionResponse
+ expectedChunks []expectedChunk
+ wantContent string
+ wantThinking string
+ }{
+ {
+ name: "simple message without thinking",
+ mockResponses: []llm.CompletionResponse{
+ {Content: "<|start|>assistant<|message|>Hello, ", Done: false},
+ {Content: "how can I help?", Done: false},
+ {Content: "<|end|>", Done: true, DoneReason: llm.DoneReasonStop},
+ },
+ expectedChunks: []expectedChunk{
+ {afterResponse: 1, content: "Hello, "},
+ {afterResponse: 2, content: "how can I help?"},
+ },
+ wantContent: "Hello, how can I help?",
+ },
+ {
+ name: "message with analysis channel for thinking",
+ mockResponses: []llm.CompletionResponse{
+ {Content: "<|channel|>analysis<|message|>", Done: false},
+ {Content: "Let me think ", Done: false},
+ {Content: "about this problem...", Done: false},
+ {Content: "<|end|>", Done: false},
+ {Content: "<|start|>assistant<|message|>", Done: false},
+ {Content: "The answer ", Done: false},
+ {Content: "is 42", Done: false},
+ {Content: "<|end|>", Done: true, DoneReason: llm.DoneReasonStop},
+ },
+ expectedChunks: []expectedChunk{
+ {afterResponse: 2, thinking: "Let me think "},
+ {afterResponse: 3, thinking: "about this problem..."},
+ {afterResponse: 6, content: "The answer "},
+ {afterResponse: 7, content: "is 42"},
+ },
+ wantContent: "The answer is 42",
+ wantThinking: "Let me think about this problem...",
+ },
+ {
+ name: "streaming with partial tags across boundaries",
+ mockResponses: []llm.CompletionResponse{
+ {Content: "<|chan", Done: false},
+ {Content: "nel|>analy", Done: false},
+ {Content: "sis<|mess", Done: false},
+ {Content: "age|>Think", Done: false},
+ {Content: "ing deeply...<|end|>", Done: false},
+ {Content: "<|start|>assi", Done: false},
+ {Content: "stant<|message|>Result ", Done: false},
+ {Content: "computed<|e", Done: false},
+ {Content: "nd|>", Done: true, DoneReason: llm.DoneReasonStop},
+ },
+ expectedChunks: []expectedChunk{
+ {afterResponse: 4, thinking: "Think"},
+ {afterResponse: 5, thinking: "ing deeply..."},
+ {afterResponse: 7, content: "Result "},
+ {afterResponse: 8, content: "computed"},
+ },
+ wantContent: "Result computed",
+ wantThinking: "Thinking deeply...",
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ // Channel to synchronize mock responses with chunk verification
+ responsesSent := make(chan int, len(tc.mockResponses))
+
+ mock := mockRunner{
+ CompletionFn: func(ctx context.Context, r llm.CompletionRequest, fn func(llm.CompletionResponse)) error {
+ // Send mock responses one at a time, notifying when each is sent
+ for i, resp := range tc.mockResponses {
+ fn(resp)
+ responsesSent <- i + 1
+ }
+ close(responsesSent)
+ return nil
+ },
+ }
+
+ s := Server{
+ sched: &Scheduler{
+ pendingReqCh: make(chan *LlmRequest, 1),
+ finishedReqCh: make(chan *LlmRequest, 1),
+ expiredCh: make(chan *runnerRef, 1),
+ unloadedCh: make(chan any, 1),
+ loaded: make(map[string]*runnerRef),
+ newServerFn: newMockServer(&mock),
+ getGpuFn: discover.GetGPUInfo,
+ getCpuFn: discover.GetCPUInfo,
+ reschedDelay: 250 * time.Millisecond,
+ loadFn: func(req *LlmRequest, _ *ggml.GGML, _ discover.GpuInfoList, _ bool) bool {
+ req.successCh <- &runnerRef{
+ llama: &mock,
+ }
+ return false
+ },
+ },
+ }
+
+ go s.sched.Run(t.Context())
+
+ // Create a minimal model
+ _, digest := createHarmonyTestModel(t)
+
+ // Create model with passthrough template
+ stream := false
+ w := createRequest(t, s.CreateHandler, api.CreateRequest{
+ Model: "harmony-test",
+ Files: map[string]string{"file.gguf": digest},
+ Template: `<|start|><|end|>{{ with .Tools }}{{ end }}{{ .Prompt }}`,
+ Stream: &stream,
+ })
+
+ if w.Code != http.StatusOK {
+ t.Fatalf("failed to create model: %d", w.Code)
+ }
+
+ // Test chat endpoint with streaming
+ streamTrue := true
+ w = createRequest(t, s.ChatHandler, api.ChatRequest{
+ Model: "harmony-test",
+ Messages: []api.Message{{Role: "user", Content: "Hello"}},
+ Stream: &streamTrue,
+ Tools: getTestTools(),
+ })
+
+ if w.Code != http.StatusOK {
+ t.Fatalf("chat request failed: %d - %s", w.Code, w.Body.String())
+ }
+
+ // Parse streaming response
+ var chunks []api.ChatResponse
+ var content, thinking strings.Builder
+
+ decoder := json.NewDecoder(w.Body)
+ for decoder.More() {
+ var chunk api.ChatResponse
+ if err := decoder.Decode(&chunk); err != nil {
+ t.Fatalf("failed to decode chunk: %v", err)
+ }
+ chunks = append(chunks, chunk)
+
+ // Accumulate content and thinking from each chunk
+ content.WriteString(chunk.Message.Content)
+ thinking.WriteString(chunk.Message.Thinking)
+
+ // Debug output
+ t.Logf("Chunk %d: content=%q thinking=%q done=%v", len(chunks), chunk.Message.Content, chunk.Message.Thinking, chunk.Done)
+ }
+
+ // Verify we got streaming chunks
+ if len(chunks) == 0 {
+ t.Fatal("expected streaming chunks, got none")
+ }
+
+ gotContent := content.String()
+ gotThinking := thinking.String()
+
+ if gotContent != tc.wantContent {
+ t.Errorf("content mismatch: got %q, want %q", gotContent, tc.wantContent)
+ }
+ if gotThinking != tc.wantThinking {
+ t.Errorf("thinking mismatch: got %q, want %q", gotThinking, tc.wantThinking)
+ }
+
+ // Verify last chunk has done=true
+ lastChunk := chunks[len(chunks)-1]
+ if !lastChunk.Done {
+ t.Error("expected last chunk to have done=true")
+ }
+ })
+ }
+}