From e7f56ef3d8ac70280b05ec66989dfe0845f8f114 Mon Sep 17 00:00:00 2001 From: Devon Rifkin Date: Thu, 18 Sep 2025 14:55:59 -0700 Subject: [PATCH 01/11] harmony: remove special casing in routes.go Now that we have a built-in parser abstraction, which was introduced in , we can modify our harmony parser to match this and then get rid of nearly all of the harmony-specific logic in routes.go. We do have a small amount of code that turns the parser on by default if the architecture matches and no other built-in parser was provided. The built-in parser interface was modified in order to handle harmony's prefill and tool name translation requirements. --- .gitignore | 1 + harmony/harmonyparser.go | 77 +++++++++++++++++++++ model/parsers/parsers.go | 16 ++++- model/parsers/qwen3coder.go | 14 ++-- server/routes.go | 132 ++++++++++++------------------------ 5 files changed, 144 insertions(+), 96 deletions(-) diff --git a/.gitignore b/.gitignore index 3a2af0bd1..eabf94c28 100644 --- a/.gitignore +++ b/.gitignore @@ -6,6 +6,7 @@ dist build .cache +.gocache *.exe .idea test_data diff --git a/harmony/harmonyparser.go b/harmony/harmonyparser.go index a51819dda..b365b763d 100644 --- a/harmony/harmonyparser.go +++ b/harmony/harmonyparser.go @@ -1,6 +1,7 @@ package harmony import ( + "encoding/json" "fmt" "log/slog" "strings" @@ -265,6 +266,8 @@ type HarmonyMessageHandler struct { state harmonyMessageState HarmonyParser *HarmonyParser FunctionNameMap *FunctionNameMap + toolAccumulator *HarmonyToolCallAccumulator + convertedTools map[string]struct{} } // NewHarmonyMessageHandler creates a new message handler @@ -277,6 +280,7 @@ func NewHarmonyMessageHandler() *HarmonyMessageHandler { HeaderEndTag: "<|message|>", }, FunctionNameMap: NewFunctionNameMap(), + convertedTools: make(map[string]struct{}), } } @@ -384,6 +388,79 @@ func NewFunctionNameMap() *FunctionNameMap { } } +// Init initializes the handler with tools and optional last message +// Implements the Parser interface +func (h *HarmonyMessageHandler) Init(tools []api.Tool, lastMessage *api.Message) []api.Tool { + // Initialize the harmony parser + if h.HarmonyParser == nil { + h.HarmonyParser = &HarmonyParser{ + MessageStartTag: "<|start|>", + MessageEndTag: "<|end|>", + HeaderEndTag: "<|message|>", + } + } + + // Handle prefill for chat mode + if lastMessage != nil { + h.HarmonyParser.AddImplicitStartOrPrefill(lastMessage) + } else { + h.HarmonyParser.AddImplicitStart() + } + + // Initialize tool accumulator + h.toolAccumulator = h.CreateToolParser() + + // Process tools and return renamed versions + if len(tools) == 0 { + return tools + } + + processedTools := make([]api.Tool, len(tools)) + copy(processedTools, tools) + for i, tool := range processedTools { + if tool.Function.Name != "" { + processedTools[i].Function.Name = h.FunctionNameMap.ConvertAndAdd(tool.Function.Name) + h.convertedTools[tool.Function.Name] = struct{}{} + } + } + return processedTools +} + +// Add implements the Parser interface - processes streamed content and extracts content, thinking, and tool calls +func (h *HarmonyMessageHandler) Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error) { + content, thinking, toolContent := h.AddContent(s, h.toolAccumulator) + if toolContent != "" { + h.toolAccumulator.Add(toolContent) + } + + // tool calls always happen one at a time, and always at the end of a message, + // so for simplicity we defer parsing them until we know we're done + if done { + toolName, raw := h.toolAccumulator.Drain() + if 
toolName != nil { + name := strings.TrimPrefix(*toolName, "functions.") + name = h.FunctionNameMap.OriginalFromConverted(name) + var args api.ToolCallFunctionArguments + if err := json.Unmarshal([]byte(raw), &args); err != nil { + return "", "", nil, fmt.Errorf("error parsing tool call: raw='%s', err=%w", raw, err) + } + calls = append(calls, api.ToolCall{Function: api.ToolCallFunction{Name: name, Arguments: args}}) + } + } + + return content, thinking, calls, nil +} + +// HasToolSupport implements the Parser interface +func (h *HarmonyMessageHandler) HasToolSupport() bool { + return true +} + +// HasThinkingSupport implements the Parser interface +func (h *HarmonyMessageHandler) HasThinkingSupport() bool { + return true +} + func (m *FunctionNameMap) ConvertAndAdd(userFunctionName string) string { harmonyFunctionName := m.deriveName(userFunctionName) m.userToHarmony[userFunctionName] = harmonyFunctionName diff --git a/model/parsers/parsers.go b/model/parsers/parsers.go index e6dbd1f4f..a1d4e8127 100644 --- a/model/parsers/parsers.go +++ b/model/parsers/parsers.go @@ -2,10 +2,16 @@ package parsers import ( "github.com/ollama/ollama/api" + "github.com/ollama/ollama/harmony" ) type Parser interface { - Add(s string, tools []api.Tool) (content string, thinking string, calls []api.ToolCall, err error) + // Init initializes the parser with tools and optional last message for chat prefill + // Returns processed tools if the parser needs to modify them (e.g., harmony renames them) + Init(tools []api.Tool, lastMessage *api.Message) []api.Tool + // Add processes streamed content and returns parsed content, thinking, and tool calls + // The done flag indicates if this is the last chunk (used for draining accumulators) + Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error) HasToolSupport() bool HasThinkingSupport() bool } @@ -17,6 +23,8 @@ func ParserForName(name string) Parser { return parser case "passthrough": return &PassthroughParser{} + case "harmony": + return harmony.NewHarmonyMessageHandler() default: return nil } @@ -24,7 +32,11 @@ func ParserForName(name string) Parser { type PassthroughParser struct{} -func (p *PassthroughParser) Add(s string, tools []api.Tool) (content string, thinking string, calls []api.ToolCall, err error) { +func (p *PassthroughParser) Init(tools []api.Tool, lastMessage *api.Message) []api.Tool { + return tools // passthrough doesn't modify tools +} + +func (p *PassthroughParser) Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error) { return s, "", nil, nil } diff --git a/model/parsers/qwen3coder.go b/model/parsers/qwen3coder.go index b0e8ec48c..b3629a5cc 100644 --- a/model/parsers/qwen3coder.go +++ b/model/parsers/qwen3coder.go @@ -31,6 +31,7 @@ const ( type Qwen3CoderParser struct { state qwenParserState acc strings.Builder + tools []api.Tool } func (p *Qwen3CoderParser) HasToolSupport() bool { @@ -41,7 +42,12 @@ func (p *Qwen3CoderParser) HasThinkingSupport() bool { return false } -func (p *Qwen3CoderParser) Add(s string, tools []api.Tool) (content string, thinking string, calls []api.ToolCall, err error) { +func (p *Qwen3CoderParser) Init(tools []api.Tool, lastMessage *api.Message) []api.Tool { + p.tools = tools + return tools // Qwen doesn't modify tools +} + +func (p *Qwen3CoderParser) Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error) { p.acc.WriteString(s) events := p.parseEvents() @@ -51,7 +57,7 @@ func (p *Qwen3CoderParser) Add(s string, 
tools []api.Tool) (content string, thin for _, event := range events { switch event := event.(type) { case qwenEventRawToolCall: - toolCall, err := parseToolCall(event, tools) + toolCall, err := parseToolCall(event, p.tools) if err != nil { slog.Warn("qwen tool call parsing failed", "error", err) return "", "", nil, err @@ -359,7 +365,7 @@ func parseValue(raw string, paramType api.PropertyType) any { // Try array if typeSet["array"] { - var arr []interface{} + var arr []any if err := json.Unmarshal([]byte(raw), &arr); err == nil { return arr } @@ -371,7 +377,7 @@ func parseValue(raw string, paramType api.PropertyType) any { // Try object if typeSet["object"] { - var obj map[string]interface{} + var obj map[string]any if err := json.Unmarshal([]byte(raw), &obj); err == nil { return obj } diff --git a/server/routes.go b/server/routes.go index c02045318..3e9407025 100644 --- a/server/routes.go +++ b/server/routes.go @@ -34,7 +34,6 @@ import ( "github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/format" "github.com/ollama/ollama/fs/ggml" - "github.com/ollama/ollama/harmony" "github.com/ollama/ollama/llm" "github.com/ollama/ollama/logutil" "github.com/ollama/ollama/model/parsers" @@ -288,17 +287,21 @@ func (s *Server) GenerateHandler(c *gin.Context) { return } - useHarmony := shouldUseHarmony(m) && !req.Raw - var harmonyMessageHandler *harmony.HarmonyMessageHandler - var harmonyToolParser *harmony.HarmonyToolCallAccumulator - if useHarmony { - harmonyMessageHandler = harmony.NewHarmonyMessageHandler() - harmonyMessageHandler.HarmonyParser.AddImplicitStart() - harmonyToolParser = harmonyMessageHandler.CreateToolParser() + var builtinParser parsers.Parser + if shouldUseHarmony(m) && m.Config.Parser == "" { + m.Config.Parser = "harmony" } - // Validate Think value: string values currently only allowed for gptoss models - if req.Think != nil && req.Think.IsString() && !useHarmony { + if !req.Raw && m.Config.Parser != "" { + builtinParser = parsers.ParserForName(m.Config.Parser) + if builtinParser != nil { + // no tools or last message for generate endpoint + builtinParser.Init(nil, nil) + } + } + + // Validate Think value: string values currently only allowed for harmony/gptoss models + if req.Think != nil && req.Think.IsString() && m.Config.Parser != "harmony" { c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("think value %q is not supported for this model", req.Think.String())}) return } @@ -422,7 +425,7 @@ func (s *Server) GenerateHandler(c *gin.Context) { } var thinkingState *thinking.Parser - if !useHarmony { + if builtinParser == nil { openingTag, closingTag := thinking.InferTags(m.Template.Template) if req.Think != nil && req.Think.Bool() && openingTag != "" && closingTag != "" { thinkingState = &thinking.Parser{ @@ -459,11 +462,17 @@ func (s *Server) GenerateHandler(c *gin.Context) { }, } - if useHarmony { - content, thinking, toolContent := harmonyMessageHandler.AddContent(cr.Content, harmonyToolParser) + if builtinParser != nil { + content, thinking, toolCalls, err := builtinParser.Add(cr.Content, cr.Done) + if err != nil { + ch <- gin.H{"error": err.Error()} + return + } res.Response = content res.Thinking = thinking - harmonyToolParser.Add(toolContent) + if cr.Done && len(toolCalls) > 0 { + res.ToolCalls = toolCalls + } } else if thinkingState != nil { thinking, content := thinkingState.AddContent(cr.Content) res.Thinking = thinking @@ -475,26 +484,6 @@ func (s *Server) GenerateHandler(c *gin.Context) { } if cr.Done { - if useHarmony { - toolName, toolContent := 
harmonyToolParser.Drain() - if toolName != nil { - *toolName = strings.TrimPrefix(*toolName, "functions.") - var args api.ToolCallFunctionArguments - if err := json.Unmarshal([]byte(toolContent), &args); err != nil { - errStr := fmt.Sprintf("error parsing tool call: raw='%s', err=%s", toolContent, err.Error()) - ch <- gin.H{"error": errStr} - return - } - - res.ToolCalls = append(res.ToolCalls, api.ToolCall{ - Function: api.ToolCallFunction{ - Name: *toolName, - Arguments: args, - }, - }) - } - } - res.DoneReason = cr.DoneReason.String() res.TotalDuration = time.Since(checkpointStart) res.LoadDuration = checkpointLoaded.Sub(checkpointStart) @@ -509,7 +498,7 @@ func (s *Server) GenerateHandler(c *gin.Context) { } } - if useHarmony { + if builtinParser != nil { // only send messages with meaningful content (empty messages confuse clients) if res.Response != "" || res.Thinking != "" || res.Done || len(res.ToolCalls) > 0 { ch <- res @@ -1853,32 +1842,23 @@ func (s *Server) ChatHandler(c *gin.Context) { } msgs = filterThinkTags(msgs, m) - var builtinParser parsers.Parser - if m.Config.Parser != "" { - builtinParser = parsers.ParserForName(m.Config.Parser) + if shouldUseHarmony(m) && m.Config.Parser == "" { + m.Config.Parser = "harmony" } - var harmonyMessageHandler *harmony.HarmonyMessageHandler - var harmonyToolParser *harmony.HarmonyToolCallAccumulator - - useHarmony := shouldUseHarmony(m) || m.Config.Parser == "harmony" - + var builtinParser parsers.Parser processedTools := req.Tools - if useHarmony { - harmonyMessageHandler = harmony.NewHarmonyMessageHandler() - var lastMessage *api.Message - if len(msgs) > 0 { - lastMessage = &msgs[len(msgs)-1] - } - harmonyMessageHandler.HarmonyParser.AddImplicitStartOrPrefill(lastMessage) - harmonyToolParser = harmonyMessageHandler.CreateToolParser() - // make a copy of tools to pass to the chat prompt. Function names may be - // renamed to be valid Harmony function names. 
- processedTools = make([]api.Tool, len(req.Tools)) - copy(processedTools, req.Tools) - for i, tool := range processedTools { - processedTools[i].Function.Name = harmonyMessageHandler.FunctionNameMap.ConvertAndAdd(tool.Function.Name) + if m.Config.Parser != "" { + builtinParser = parsers.ParserForName(m.Config.Parser) + if builtinParser != nil { + // Determine last message for chat prefill + var lastMessage *api.Message + if len(msgs) > 0 { + lastMessage = &msgs[len(msgs)-1] + } + // Initialize parser and get processed tools + processedTools = builtinParser.Init(req.Tools, lastMessage) } } @@ -1902,8 +1882,8 @@ func (s *Server) ChatHandler(c *gin.Context) { return } - // Validate Think value: string values currently only allowed for gptoss models - if req.Think != nil && req.Think.IsString() && !useHarmony { + // Validate Think value: string values currently only allowed for harmony/gptoss models + if req.Think != nil && req.Think.IsString() && m.Config.Parser != "harmony" { c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("think value %q is not supported for this model", req.Think.String())}) return } @@ -1922,7 +1902,7 @@ func (s *Server) ChatHandler(c *gin.Context) { } var toolParser *tools.Parser - if len(req.Tools) > 0 && !useHarmony { + if len(req.Tools) > 0 && (builtinParser == nil || !builtinParser.HasToolSupport()) { toolParser = tools.NewParser(m.Template.Template, req.Tools) } @@ -1954,38 +1934,10 @@ func (s *Server) ChatHandler(c *gin.Context) { res.LoadDuration = checkpointLoaded.Sub(checkpointStart) } - // TODO(drifkin): fold this as much as possibleinto the generic m.Config.Parser logic - if useHarmony { - content, thinking, toolContent := harmonyMessageHandler.AddContent(r.Content, harmonyToolParser) - res.Message.Content = content - res.Message.Thinking = thinking - harmonyToolParser.Add(toolContent) - - if r.Done { - toolName, toolContent := harmonyToolParser.Drain() - if toolName != nil { - *toolName = strings.TrimPrefix(*toolName, "functions.") - *toolName = harmonyMessageHandler.FunctionNameMap.OriginalFromConverted(*toolName) - var args api.ToolCallFunctionArguments - if err := json.Unmarshal([]byte(toolContent), &args); err != nil { - errStr := fmt.Sprintf("error parsing tool call: raw='%s', err=%s", toolContent, err.Error()) - ch <- gin.H{"error": errStr} - return - } - res.Message.ToolCalls = []api.ToolCall{{Function: api.ToolCallFunction{Name: *toolName, Arguments: args}}} - } - } - - // only send messages with meaningful content (empty messages confuse clients) - if res.Message.Content != "" || res.Message.Thinking != "" || len(res.Message.ToolCalls) > 0 || res.Done { - ch <- res - } - - return - } else if builtinParser != nil { + if builtinParser != nil { slog.Log(context.TODO(), logutil.LevelTrace, "builtin parser input", "parser", m.Config.Parser, "content", r.Content) - content, thinking, toolCalls, err := builtinParser.Add(r.Content, req.Tools) + content, thinking, toolCalls, err := builtinParser.Add(r.Content, r.Done) if err != nil { ch <- gin.H{"error": err.Error()} return From ae5c33008e53c1db1465b722b60835916e375d15 Mon Sep 17 00:00:00 2001 From: jmorganca Date: Fri, 19 Sep 2025 15:49:56 -0700 Subject: [PATCH 02/11] docs: move turbo.md to cloud.md --- docs/{turbo.md => cloud.md} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename docs/{turbo.md => cloud.md} (100%) diff --git a/docs/turbo.md b/docs/cloud.md similarity index 100% rename from docs/turbo.md rename to docs/cloud.md From af060eb2508e8bed25241163243bdd7471cb7fd6 Mon Sep 17 
00:00:00 2001 From: jmorganca Date: Fri, 19 Sep 2025 15:50:41 -0700 Subject: [PATCH 03/11] docs: update cloud.md for cloud models --- docs/cloud.md | 113 ++++++++++---------------------------------------- 1 file changed, 23 insertions(+), 90 deletions(-) diff --git a/docs/cloud.md b/docs/cloud.md index d75d95570..300e6f5e0 100644 --- a/docs/cloud.md +++ b/docs/cloud.md @@ -1,107 +1,40 @@ -# Turbo +# Cloud -> ⚠️ Turbo is preview +| Ollama's cloud is currently in preview. For full documentation, see [Ollama's documentation](https://docs.ollama.com/cloud). -Ollama’s [Turbo](https://ollama.com/turbo) is a new way to run open-source models with acceleration from datacenter-grade hardware. +## Cloud Models -Currently, the following models are available in Turbo: +[Cloud models](https://ollama.com/cloud) are a new kind of model in Ollama that can run without a powerful GPU. Instead, cloud models are automatically offloaded to Ollama's cloud while offering the same capabilities as local models, making it possible to keep using your local tools while running larger models that wouldn’t fit on a personal computer. -- `gpt-oss:20b` -- `gpt-oss:120b` +Ollama currently supports the following cloud models, with more coming soon: -## Get started +- `gpt-oss:20b-cloud` +- `gpt-oss:120b-cloud` +- `deepseek-v3.1:671b-cloud` +- `qwen3-coder:480b-cloud` -### Ollama for macOS & Windows +### Get started -Download Ollama +To run a cloud model, open the terminal and run: -- Select a model such as `gpt-oss:20b` or `gpt-oss:120b` -- Click on **Turbo**. You’ll be prompted to create an account or sign in - -### Ollama’s CLI - -- [Sign up](https://ollama.com/signup) for an Ollama account -- Add your Ollama key [to ollama.com](https://ollama.com/settings/keys). - - On macOS and Linux: - - ```shell - cat ~/.ollama/id_ed25519.pub - ``` - - On Windows: - - ``` - type "%USERPROFILE%\.ollama\id_ed25519.pub" - ``` - -- Then run a model setting `OLLAMA_HOST` to `ollama.com`: - ```shell - OLLAMA_HOST=ollama.com ollama run gpt-oss:120b - ``` - -### Ollama’s Python library - -- Download Ollama's [Python library](https://github.com/ollama/ollama-python) -- [Sign up](https://ollama.com/signup) for an Ollama account -- Create an API key by visiting https://ollama.com/settings/keys - -```python -from ollama import Client - -client = Client( - host="https://ollama.com", - headers={'Authorization': ''} -) - -messages = [ - { - 'role': 'user', - 'content': 'Why is the sky blue?', - }, -] - -for part in client.chat('gpt-oss:120b', messages=messages, stream=True): - print(part['message']['content'], end='', flush=True) +``` +ollama run gpt-oss:120b-cloud ``` -### Ollama’s JavaScript library +To run cloud models with integrations that work with Ollama, first download the cloud model: -- Download Ollama's [JavaScript library](https://github.com/ollama/ollama-js) -- [Sign up](https://ollama.com/signup) for an Ollama account -- Create an API key by visiting https://ollama.com/settings/keys - -```typescript -import { Ollama } from 'ollama'; - -const ollama = new Ollama({ - host: 'https://ollama.com', - headers: { - Authorization: "Bearer " - } -}); - -const response = await ollama.chat({ - model: 'gpt-oss:120b', - messages: [{ role: 'user', content: 'Explain quantum computing' }], - stream: true -}); - -for await (const part of response) { - process.stdout.write(part.message.content) -} +``` +ollama pull qwen3-coder:480b-cloud ``` -### Community integrations +Then sign in to Ollama: -Turbo mode is also compatible with several community 
integrations. +``` +ollama signin +``` -#### Open WebUI +Finally, access the model using the model name `qwen3-coder:480b-cloud` via Ollama's local API or tooling. -- Go to **settings** → **Admin settings** → **Connections** -- Under **Ollama API,** click **+** -- For the **URL** put `https://ollama.com` -- For the **API key,** create an API key on https://ollama.com/settings/keys and add it. -- Click **Save** +## Cloud API access -Now, if you navigate to the model selector, Turbo models should be available under **External**. +Cloud models can also be accessed directly on ollama.com's API. For more information, see the [docs](https://docs.ollama.com/cloud). From c23e6f4cae3cbf62db68c2c9bf993925626fbe7c Mon Sep 17 00:00:00 2001 From: Daniel Hiltgen Date: Mon, 22 Sep 2025 11:23:14 -0700 Subject: [PATCH 04/11] tests: add single threaded history test (#12295) * tests: add single threaded history test Also tidies up some existing tests to handle more model output variation * test: add support for testing specific architectures --- integration/README.md | 3 + integration/api_test.go | 14 ++--- integration/basic_test.go | 8 +-- integration/context_test.go | 98 +++++++++++++++++++++++++++++- integration/library_models_test.go | 17 +++++- integration/model_arch_test.go | 5 +- integration/model_perf_test.go | 34 ++++++++--- integration/quantization_test.go | 5 +- integration/utils_test.go | 28 +++++++-- 9 files changed, 173 insertions(+), 39 deletions(-) diff --git a/integration/README.md b/integration/README.md index e52ba71ee..1dfd0e359 100644 --- a/integration/README.md +++ b/integration/README.md @@ -12,3 +12,6 @@ The integration tests have 2 modes of operating. > [!IMPORTANT] > Before running the tests locally without the "test existing" setting, compile ollama from the top of the source tree `go build .` in addition to GPU support with cmake if applicable on your platform. The integration tests expect to find an ollama binary at the top of the tree. + + +Many tests use a default small model suitable to run on many systems. You can override this default model by setting `OLLAMA_TEST_DEFAULT_MODEL` \ No newline at end of file diff --git a/integration/api_test.go b/integration/api_test.go index c39192c99..48572085d 100644 --- a/integration/api_test.go +++ b/integration/api_test.go @@ -22,13 +22,12 @@ func TestAPIGenerate(t *testing.T) { // Set up the test data req := api.GenerateRequest{ Model: smol, - Prompt: "why is the sky blue? be brief", + Prompt: blueSkyPrompt, Options: map[string]interface{}{ "temperature": 0, "seed": 123, }, } - anyResp := []string{"rayleigh", "scattering"} client, _, cleanup := InitServerConnection(ctx, t) defer cleanup() @@ -120,14 +119,14 @@ func TestAPIGenerate(t *testing.T) { // Verify the response contains the expected data response := buf.String() atLeastOne := false - for _, resp := range anyResp { + for _, resp := range blueSkyExpected { if strings.Contains(strings.ToLower(response), resp) { atLeastOne = true break } } if !atLeastOne { - t.Errorf("none of %v found in %s", anyResp, response) + t.Errorf("none of %v found in %s", blueSkyExpected, response) } case <-ctx.Done(): t.Error("outer test context done while waiting for generate") @@ -181,7 +180,7 @@ func TestAPIChat(t *testing.T) { Messages: []api.Message{ { Role: "user", - Content: "why is the sky blue? 
be brief", + Content: blueSkyPrompt, }, }, Options: map[string]interface{}{ @@ -189,7 +188,6 @@ func TestAPIChat(t *testing.T) { "seed": 123, }, } - anyResp := []string{"rayleigh", "scattering"} client, _, cleanup := InitServerConnection(ctx, t) defer cleanup() @@ -279,14 +277,14 @@ func TestAPIChat(t *testing.T) { // Verify the response contains the expected data response := buf.String() atLeastOne := false - for _, resp := range anyResp { + for _, resp := range blueSkyExpected { if strings.Contains(strings.ToLower(response), resp) { atLeastOne = true break } } if !atLeastOne { - t.Errorf("none of %v found in %s", anyResp, response) + t.Errorf("none of %v found in %s", blueSkyExpected, response) } case <-ctx.Done(): t.Error("outer test context done while waiting for chat") diff --git a/integration/basic_test.go b/integration/basic_test.go index 60cff172b..0a6b9253d 100644 --- a/integration/basic_test.go +++ b/integration/basic_test.go @@ -19,14 +19,14 @@ func TestBlueSky(t *testing.T) { // Set up the test data req := api.GenerateRequest{ Model: smol, - Prompt: "why is the sky blue?", + Prompt: blueSkyPrompt, Stream: &stream, Options: map[string]any{ "temperature": 0, "seed": 123, }, } - GenerateTestHelper(ctx, t, req, []string{"rayleigh", "scattering"}) + GenerateTestHelper(ctx, t, req, blueSkyExpected) } func TestUnicode(t *testing.T) { @@ -110,12 +110,12 @@ func TestUnicodeModelDir(t *testing.T) { req := api.GenerateRequest{ Model: smol, - Prompt: "why is the sky blue?", + Prompt: blueSkyPrompt, Stream: &stream, Options: map[string]any{ "temperature": 0, "seed": 123, }, } - GenerateTestHelper(ctx, t, req, []string{"rayleigh", "scattering"}) + GenerateTestHelper(ctx, t, req, blueSkyExpected) } diff --git a/integration/context_test.go b/integration/context_test.go index 15c157858..9d13f7acb 100644 --- a/integration/context_test.go +++ b/integration/context_test.go @@ -63,11 +63,11 @@ func TestContextExhaustion(t *testing.T) { if err := PullIfMissing(ctx, client, req.Model); err != nil { t.Fatalf("PullIfMissing failed: %v", err) } - DoGenerate(ctx, t, client, req, []string{"once", "upon", "lived", "sunny", "cloudy", "clear", "water"}, 120*time.Second, 10*time.Second) + DoGenerate(ctx, t, client, req, []string{"once", "upon", "lived", "sunny", "cloudy", "clear", "water", "time", "travel", "world"}, 120*time.Second, 10*time.Second) } // Send multiple generate requests with prior context and ensure the response is coherant and expected -func TestGenerateWithHistory(t *testing.T) { +func TestParallelGenerateWithHistory(t *testing.T) { modelOverride := ollamaEngineChatModels[0] // Most recent ollama engine model req, resp := GenerateRequests() numParallel := 2 @@ -113,8 +113,48 @@ func TestGenerateWithHistory(t *testing.T) { wg.Wait() } +// Send generate requests with prior context and ensure the response is coherant and expected +func TestGenerateWithHistory(t *testing.T) { + req := api.GenerateRequest{ + Model: smol, + Prompt: rainbowPrompt, + Stream: &stream, + KeepAlive: &api.Duration{Duration: 10 * time.Second}, + Options: map[string]any{ + "num_ctx": 16384, + }, + } + + softTimeout, hardTimeout := getTimeouts(t) + ctx, cancel := context.WithTimeout(context.Background(), hardTimeout) + defer cancel() + client, _, cleanup := InitServerConnection(ctx, t) + defer cleanup() + + // Get the server running (if applicable) warm the model up with a single initial request + slog.Info("loading", "model", req.Model) + err := client.Generate(ctx, + &api.GenerateRequest{Model: req.Model, KeepAlive: 
&api.Duration{Duration: 10 * time.Second}, Options: req.Options}, + func(response api.GenerateResponse) error { return nil }, + ) + if err != nil { + t.Fatalf("failed to load model %s: %s", req.Model, err) + } + + req.Context = DoGenerate(ctx, t, client, req, rainbowExpected, 30*time.Second, 20*time.Second) + + for i := 0; i < len(rainbowFollowups); i++ { + req.Prompt = rainbowFollowups[i] + if time.Now().Sub(started) > softTimeout { + slog.Info("exceeded soft timeout, winding down test") + return + } + req.Context = DoGenerate(ctx, t, client, req, rainbowExpected, 30*time.Second, 20*time.Second) + } +} + // Send multiple chat requests with prior context and ensure the response is coherant and expected -func TestChatWithHistory(t *testing.T) { +func TestParallelChatWithHistory(t *testing.T) { modelOverride := ollamaEngineChatModels[0] // Most recent ollama engine model req, resp := ChatRequests() numParallel := 2 @@ -164,3 +204,55 @@ func TestChatWithHistory(t *testing.T) { } wg.Wait() } + +// Send generate requests with prior context and ensure the response is coherant and expected +func TestChatWithHistory(t *testing.T) { + req := api.ChatRequest{ + Model: smol, + Stream: &stream, + KeepAlive: &api.Duration{Duration: 10 * time.Second}, + Options: map[string]any{ + "num_ctx": 16384, + }, + Messages: []api.Message{ + { + Role: "user", + Content: rainbowPrompt, + }, + }, + } + + softTimeout, hardTimeout := getTimeouts(t) + ctx, cancel := context.WithTimeout(context.Background(), hardTimeout) + defer cancel() + client, _, cleanup := InitServerConnection(ctx, t) + defer cleanup() + + // Get the server running (if applicable) warm the model up with a single initial request + slog.Info("loading", "model", req.Model) + err := client.Generate(ctx, + &api.GenerateRequest{Model: req.Model, KeepAlive: &api.Duration{Duration: 10 * time.Second}, Options: req.Options}, + func(response api.GenerateResponse) error { return nil }, + ) + if err != nil { + t.Fatalf("failed to load model %s: %s", req.Model, err) + } + + assistant := DoChat(ctx, t, client, req, rainbowExpected, 30*time.Second, 20*time.Second) + + for i := 0; i < len(rainbowFollowups); i++ { + if time.Now().Sub(started) > softTimeout { + slog.Info("exceeded soft timeout, winding down test") + return + } + req.Messages = append(req.Messages, + *assistant, + api.Message{Role: "user", Content: rainbowFollowups[i]}, + ) + + assistant = DoChat(ctx, t, client, req, rainbowExpected, 30*time.Second, 20*time.Second) + if assistant == nil { + t.Fatalf("didn't get an assistant response for context") + } + } +} diff --git a/integration/library_models_test.go b/integration/library_models_test.go index cdf65efc8..49e1097b8 100644 --- a/integration/library_models_test.go +++ b/integration/library_models_test.go @@ -4,7 +4,9 @@ package integration import ( "context" + "fmt" "log/slog" + "os" "testing" "time" @@ -20,6 +22,7 @@ func TestLibraryModelsGenerate(t *testing.T) { defer cancel() client, _, cleanup := InitServerConnection(ctx, t) defer cleanup() + targetArch := os.Getenv("OLLAMA_TEST_ARCHITECTURE") chatModels := libraryChatModels for _, model := range chatModels { @@ -30,16 +33,26 @@ func TestLibraryModelsGenerate(t *testing.T) { if err := PullIfMissing(ctx, client, model); err != nil { t.Fatalf("pull failed %s", err) } + if targetArch != "" { + resp, err := client.Show(ctx, &api.ShowRequest{Name: model}) + if err != nil { + t.Fatalf("unable to show model: %s", err) + } + arch := resp.ModelInfo["general.architecture"].(string) + if arch != targetArch { 
+ t.Skip(fmt.Sprintf("Skipping %s architecture %s != %s", model, arch, targetArch)) + } + } req := api.GenerateRequest{ Model: model, - Prompt: "why is the sky blue?", + Prompt: blueSkyPrompt, KeepAlive: &api.Duration{Duration: 10 * time.Second}, Options: map[string]interface{}{ "temperature": 0.1, "seed": 123, }, } - anyResp := []string{"rayleigh", "scatter", "atmosphere", "nitrogen", "oxygen", "wavelength"} + anyResp := blueSkyExpected // Special cases if model == "duckdb-nsql" { anyResp = []string{"select", "from"} diff --git a/integration/model_arch_test.go b/integration/model_arch_test.go index 9fc2e01dd..721d95c54 100644 --- a/integration/model_arch_test.go +++ b/integration/model_arch_test.go @@ -68,14 +68,13 @@ func TestModelsGenerate(t *testing.T) { // TODO - fiddle with context size req := api.GenerateRequest{ Model: model, - Prompt: "why is the sky blue?", + Prompt: blueSkyPrompt, Options: map[string]interface{}{ "temperature": 0, "seed": 123, }, } - anyResp := []string{"rayleigh", "scattering", "atmosphere", "nitrogen", "oxygen"} - DoGenerate(ctx, t, client, req, anyResp, 120*time.Second, 30*time.Second) + DoGenerate(ctx, t, client, req, blueSkyExpected, 120*time.Second, 30*time.Second) }) } } diff --git a/integration/model_perf_test.go b/integration/model_perf_test.go index 759e8b9a2..3d6ba9239 100644 --- a/integration/model_perf_test.go +++ b/integration/model_perf_test.go @@ -40,6 +40,18 @@ var ( // cat int.log | grep MODEL_PERF_HEADER | head -1| cut -f2- -d: > perf.csv // cat int.log | grep MODEL_PERF_DATA | cut -f2- -d: >> perf.csv func TestModelsPerf(t *testing.T) { + if s := os.Getenv("OLLAMA_NEW_ENGINE"); s != "" { + doModelPerfTest(t, ollamaEngineChatModels) + } else { + doModelPerfTest(t, append(ollamaEngineChatModels, llamaRunnerChatModels...)) + } +} + +func TestLibraryModelsPerf(t *testing.T) { + doModelPerfTest(t, libraryChatModels) +} + +func doModelPerfTest(t *testing.T, chatModels []string) { softTimeout, hardTimeout := getTimeouts(t) slog.Info("Setting timeouts", "soft", softTimeout, "hard", hardTimeout) ctx, cancel := context.WithTimeout(context.Background(), hardTimeout) @@ -65,14 +77,12 @@ func TestModelsPerf(t *testing.T) { } longPrompt := "summarize the following: " + string(data) - var chatModels []string - if s := os.Getenv("OLLAMA_NEW_ENGINE"); s != "" { - chatModels = ollamaEngineChatModels - } else { - chatModels = append(ollamaEngineChatModels, llamaRunnerChatModels...) 
- } + targetArch := os.Getenv("OLLAMA_TEST_ARCHITECTURE") for _, model := range chatModels { + if !strings.Contains(model, ":") { + model = model + ":latest" + } t.Run(model, func(t *testing.T) { if time.Now().Sub(started) > softTimeout { t.Skip("skipping remaining tests to avoid excessive runtime") @@ -88,6 +98,9 @@ func TestModelsPerf(t *testing.T) { } arch := resp.ModelInfo["general.architecture"].(string) maxContext = int(resp.ModelInfo[fmt.Sprintf("%s.context_length", arch)].(float64)) + if targetArch != "" && arch != targetArch { + t.Skip(fmt.Sprintf("Skipping %s architecture %s != %s", model, arch, targetArch)) + } if maxVram > 0 { resp, err := client.List(ctx) @@ -151,8 +164,8 @@ func TestModelsPerf(t *testing.T) { prompt string anyResp []string }{ - {"why is the sky blue?", []string{"rayleigh", "scattering", "atmosphere", "nitrogen", "oxygen"}}, - {maxPrompt, []string{"shakespeare", "oppression", "sorrows", "gutenberg", "child", "license", "sonnet", "melancholy"}}, + {blueSkyPrompt, blueSkyExpected}, + {maxPrompt, []string{"shakespeare", "oppression", "sorrows", "gutenberg", "child", "license", "sonnet", "melancholy", "love", "sorrow", "beauty"}}, } var gpuPercent int for _, tc := range testCases { @@ -241,11 +254,12 @@ func TestModelsPerf(t *testing.T) { } } } + // Round the logged prompt count for comparisons across versions/configurations which can vary slightly fmt.Fprintf(os.Stderr, "MODEL_PERF_HEADER:%s,%s,%s,%s,%s,%s,%s\n", "MODEL", "CONTEXT", "GPU PERCENT", - "PROMPT COUNT", + "APPROX PROMPT COUNT", "LOAD TIME", "PROMPT EVAL TPS", "EVAL TPS", @@ -254,7 +268,7 @@ func TestModelsPerf(t *testing.T) { model, numCtx, gpuPercent, - resp.PromptEvalCount, + (resp.PromptEvalCount/10)*10, float64(resp.LoadDuration)/1000000000.0, float64(resp.PromptEvalCount)/(float64(resp.PromptEvalDuration)/1000000000.0), float64(resp.EvalCount)/(float64(resp.EvalDuration)/1000000000.0), diff --git a/integration/quantization_test.go b/integration/quantization_test.go index af9da0b62..305647496 100644 --- a/integration/quantization_test.go +++ b/integration/quantization_test.go @@ -76,7 +76,7 @@ func TestQuantization(t *testing.T) { stream := true genReq := api.GenerateRequest{ Model: newName, - Prompt: "why is the sky blue?", + Prompt: blueSkyPrompt, KeepAlive: &api.Duration{Duration: 3 * time.Second}, Options: map[string]any{ "seed": 42, @@ -88,14 +88,13 @@ func TestQuantization(t *testing.T) { // Some smaller quantizations can cause models to have poor quality // or get stuck in repetition loops, so we stop as soon as we have any matches - anyResp := []string{"rayleigh", "scattering", "day", "sun", "moon", "color", "nitrogen", "oxygen"} reqCtx, reqCancel := context.WithCancel(ctx) atLeastOne := false var buf bytes.Buffer genfn := func(response api.GenerateResponse) error { buf.Write([]byte(response.Response)) fullResp := strings.ToLower(buf.String()) - for _, resp := range anyResp { + for _, resp := range blueSkyExpected { if strings.Contains(fullResp, resp) { atLeastOne = true t.Log(fullResp) diff --git a/integration/utils_test.go b/integration/utils_test.go index 7901fed3f..f8ec13f39 100644 --- a/integration/utils_test.go +++ b/integration/utils_test.go @@ -256,13 +256,29 @@ var ( "snowflake-arctic-embed", "snowflake-arctic-embed2", } + + blueSkyPrompt = "why is the sky blue? Be brief but factual in your reply" + blueSkyExpected = []string{"rayleigh", "scatter", "atmosphere", "nitrogen", "oxygen", "wavelength", "interact"} + + rainbowPrompt = "how do rainbows form? 
Be brief but factual in your reply" + rainbowFollowups = []string{ + "Explain the physics involved in them. Be breif in your reply", + "Explain the chemistry involved in them. Be breif in your reply", + "Explain the quantum mechanics involved in them. Be breif in your reply", + "What are common myths related to them? Be brief in your reply", + "What are common fairytales related to them? Be brief in your reply", + "Can they form if there is no rain? Be breif in your reply", + "Can they form if there are no clouds? Be breif in your reply", + "Do they happen on other planets? Be brief in your reply", + } + rainbowExpected = []string{"water", "droplet", "mist", "glow", "refracted", "reflect", "color", "spectrum", "frequency", "end", "gold", "fortune", "blessing", "prosperity"} ) func init() { lifecycle.InitLogging() - custom := os.Getenv("OLLAMA_TEST_SMOL_MODEL") + custom := os.Getenv("OLLAMA_TEST_DEFAULT_MODEL") if custom != "" { - slog.Info("setting smol test model to " + custom) + slog.Info("setting default test model to " + custom) smol = custom } } @@ -577,11 +593,11 @@ func GenerateRequests() ([]api.GenerateRequest, [][]string) { }, }, [][]string{ - {"sunlight", "scattering", "interact", "color", "surface", "depth", "red", "orange", "yellow", "absorbs", "wavelength"}, - {"soil", "organic", "earth", "black", "tan", "chemical", "processes", "pigments", "particles", "iron oxide", "rust", "air", "water", "mixture", "mixing"}, - {"water", "droplet", "refracted", "reflect", "color", "spectrum"}, + {"sunlight", "scatter", "interact", "color", "surface", "depth", "red", "orange", "yellow", "absorb", "wavelength", "water", "molecule"}, + {"soil", "organic", "earth", "black", "tan", "chemical", "processes", "pigment", "particle", "iron oxide", "rust", "air", "water", "wet", "mixture", "mixing", "mineral", "element", "decomposed", "matter", "wavelength"}, + {"water", "droplet", "refract", "reflect", "color", "spectrum", "raindrop"}, {"fourth", "july", "declaration", "independence"}, - {"nitrogen", "oxygen", "carbon", "dioxide", "water", "vapor"}, + {"nitrogen", "oxygen", "carbon", "dioxide", "water", "vapor", "fluid", "particles", "gas"}, } } From 64883e3c4c0238dc70fddcc456af569d1489415d Mon Sep 17 00:00:00 2001 From: Patrick Devine Date: Mon, 22 Sep 2025 23:20:20 -0700 Subject: [PATCH 05/11] auth: fix problems with the ollama keypairs (#12373) * auth: fix problems with the ollama keypairs This change adds several fixes including: - reading in the pubkey files correctly - fixing the push unit test to create a keypair file in a temp directory - not return 500 errors for normal status error --- api/client.go | 24 ++++++++---- api/types.go | 2 +- auth/auth.go | 40 ++------------------ cmd/cmd.go | 56 +++++++++++++--------------- cmd/cmd_test.go | 3 ++ server/routes.go | 96 +++++++++++++++++++++++++++++++++++------------- 6 files changed, 119 insertions(+), 102 deletions(-) diff --git a/api/client.go b/api/client.go index 20e6d7957..0d4c97ba9 100644 --- a/api/client.go +++ b/api/client.go @@ -45,6 +45,12 @@ func checkError(resp *http.Response, body []byte) error { return nil } + if resp.StatusCode == http.StatusUnauthorized { + authError := AuthorizationError{StatusCode: resp.StatusCode} + json.Unmarshal(body, &authError) + return authError + } + apiError := StatusError{StatusCode: resp.StatusCode} err := json.Unmarshal(body, &apiError) @@ -214,7 +220,8 @@ func (c *Client) stream(ctx context.Context, method, path string, data any, fn f scanner.Buffer(scanBuf, maxBufferSize) for scanner.Scan() { var 
errorResponse struct { - Error string `json:"error,omitempty"` + Error string `json:"error,omitempty"` + SigninURL string `json:"signin_url,omitempty"` } bts := scanner.Bytes() @@ -223,14 +230,10 @@ func (c *Client) stream(ctx context.Context, method, path string, data any, fn f } if response.StatusCode == http.StatusUnauthorized { - pubKey, pkErr := auth.GetPublicKey() - if pkErr != nil { - return pkErr - } return AuthorizationError{ StatusCode: response.StatusCode, Status: response.Status, - PublicKey: pubKey, + SigninURL: errorResponse.SigninURL, } } else if response.StatusCode >= http.StatusBadRequest { return StatusError{ @@ -439,8 +442,13 @@ func (c *Client) Version(ctx context.Context) (string, error) { return version.Version, nil } -// Signout will disconnect an ollama instance from ollama.com -func (c *Client) Signout(ctx context.Context, encodedKey string) error { +// Signout will signout a client for a local ollama server. +func (c *Client) Signout(ctx context.Context) error { + return c.do(ctx, http.MethodPost, "/api/signout", nil, nil) +} + +// Disconnect will disconnect an ollama instance from ollama.com. +func (c *Client) Disconnect(ctx context.Context, encodedKey string) error { return c.do(ctx, http.MethodDelete, fmt.Sprintf("/api/user/keys/%s", encodedKey), nil, nil) } diff --git a/api/types.go b/api/types.go index 5b8e034c2..8cc7752ca 100644 --- a/api/types.go +++ b/api/types.go @@ -41,7 +41,7 @@ func (e StatusError) Error() string { type AuthorizationError struct { StatusCode int Status string - PublicKey string `json:"public_key"` + SigninURL string `json:"signin_url"` } func (e AuthorizationError) Error() string { diff --git a/auth/auth.go b/auth/auth.go index b26e2315b..f820964e7 100644 --- a/auth/auth.go +++ b/auth/auth.go @@ -18,46 +18,13 @@ import ( const defaultPrivateKey = "id_ed25519" -func keyPath() (string, error) { - fileIsReadable := func(fp string) bool { - info, err := os.Stat(fp) - if err != nil { - return false - } - - // Check that it's a regular file, not a directory or other file type - if !info.Mode().IsRegular() { - return false - } - - // Try to open it to check readability - file, err := os.Open(fp) - if err != nil { - return false - } - file.Close() - return true - } - - systemPath := filepath.Join("/usr/share/ollama/.ollama", defaultPrivateKey) - if fileIsReadable(systemPath) { - return systemPath, nil - } - +func GetPublicKey() (string, error) { home, err := os.UserHomeDir() if err != nil { return "", err } - return filepath.Join(home, ".ollama", defaultPrivateKey), nil -} - -func GetPublicKey() (string, error) { - keyPath, err := keyPath() - if err != nil { - return "", err - } - + keyPath := filepath.Join(home, ".ollama", defaultPrivateKey) privateKeyFile, err := os.ReadFile(keyPath) if err != nil { slog.Info(fmt.Sprintf("Failed to load private key: %v", err)) @@ -84,11 +51,12 @@ func NewNonce(r io.Reader, length int) (string, error) { } func Sign(ctx context.Context, bts []byte) (string, error) { - keyPath, err := keyPath() + home, err := os.UserHomeDir() if err != nil { return "", err } + keyPath := filepath.Join(home, ".ollama", defaultPrivateKey) privateKeyFile, err := os.ReadFile(keyPath) if err != nil { slog.Info(fmt.Sprintf("Failed to load private key: %v", err)) diff --git a/cmd/cmd.go b/cmd/cmd.go index 294e1662f..e8cfa1347 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -5,7 +5,6 @@ import ( "context" "crypto/ed25519" "crypto/rand" - "encoding/base64" "encoding/json" "encoding/pem" "errors" @@ -15,7 +14,6 @@ import ( "math" "net" 
"net/http" - "net/url" "os" "os/signal" "path/filepath" @@ -37,7 +35,6 @@ import ( "golang.org/x/term" "github.com/ollama/ollama/api" - "github.com/ollama/ollama/auth" "github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/format" "github.com/ollama/ollama/parser" @@ -50,7 +47,7 @@ import ( "github.com/ollama/ollama/version" ) -const ConnectInstructions = "To sign in, navigate to:\n https://ollama.com/connect?name=%s&key=%s\n\n" +const ConnectInstructions = "To sign in, navigate to:\n %s\n\n" // ensureThinkingSupport emits a warning if the model does not advertise thinking support func ensureThinkingSupport(ctx context.Context, client *api.Client, name string) { @@ -452,16 +449,10 @@ func RunHandler(cmd *cobra.Command, args []string) error { if err := loadOrUnloadModel(cmd, &opts); err != nil { var sErr api.AuthorizationError if errors.As(err, &sErr) && sErr.StatusCode == http.StatusUnauthorized { - pubKey, pkErr := auth.GetPublicKey() - if pkErr != nil { - return pkErr - } - // the server and the client both have the same public key - if pubKey == sErr.PublicKey { - h, _ := os.Hostname() - encKey := base64.RawURLEncoding.EncodeToString([]byte(pubKey)) - fmt.Printf("You need to be signed in to Ollama to run Cloud models.\n\n") - fmt.Printf(ConnectInstructions, url.PathEscape(h), encKey) + fmt.Printf("You need to be signed in to Ollama to run Cloud models.\n\n") + + if sErr.SigninURL != "" { + fmt.Printf(ConnectInstructions, sErr.SigninURL) } return nil } @@ -493,6 +484,16 @@ func SigninHandler(cmd *cobra.Command, args []string) error { user, err := client.Whoami(cmd.Context()) if err != nil { + var aErr api.AuthorizationError + if errors.As(err, &aErr) && aErr.StatusCode == http.StatusUnauthorized { + fmt.Println("You need to be signed in to Ollama to run Cloud models.") + fmt.Println() + + if aErr.SigninURL != "" { + fmt.Printf(ConnectInstructions, aErr.SigninURL) + } + return nil + } return err } @@ -502,34 +503,27 @@ func SigninHandler(cmd *cobra.Command, args []string) error { return nil } - pubKey, pkErr := auth.GetPublicKey() - if pkErr != nil { - return pkErr - } - encKey := base64.RawURLEncoding.EncodeToString([]byte(pubKey)) - - h, _ := os.Hostname() - fmt.Printf(ConnectInstructions, url.PathEscape(h), encKey) - return nil } func SignoutHandler(cmd *cobra.Command, args []string) error { - pubKey, pkErr := auth.GetPublicKey() - if pkErr != nil { - return pkErr - } - encKey := base64.RawURLEncoding.EncodeToString([]byte(pubKey)) - client, err := api.ClientFromEnvironment() if err != nil { return err } - err = client.Signout(cmd.Context(), encKey) + err = client.Signout(cmd.Context()) if err != nil { - return err + var aErr api.AuthorizationError + if errors.As(err, &aErr) && aErr.StatusCode == http.StatusUnauthorized { + fmt.Println("You are not signed in to ollama.com") + fmt.Println() + return nil + } else { + return err + } } + fmt.Println("You have signed out of ollama.com") fmt.Println() return nil diff --git a/cmd/cmd_test.go b/cmd/cmd_test.go index bb793572f..24d287055 100644 --- a/cmd/cmd_test.go +++ b/cmd/cmd_test.go @@ -525,6 +525,9 @@ func TestPushHandler(t *testing.T) { defer mockServer.Close() t.Setenv("OLLAMA_HOST", mockServer.URL) + tmpDir := t.TempDir() + t.Setenv("HOME", tmpDir) + t.Setenv("USERPROFILE", tmpDir) initializeKeypair() cmd := &cobra.Command{} diff --git a/server/routes.go b/server/routes.go index a2078ec10..21a1b2b3d 100644 --- a/server/routes.go +++ b/server/routes.go @@ -4,6 +4,7 @@ import ( "bytes" "cmp" "context" + "encoding/base64" 
"encoding/json" "errors" "fmt" @@ -48,6 +49,8 @@ import ( "github.com/ollama/ollama/version" ) +const signinURLStr = "https://ollama.com/connect?name=%s&key=%s" + func shouldUseHarmony(model *Model) bool { if slices.Contains([]string{"gptoss", "gpt-oss"}, model.Config.ModelFamily) { // heuristic to check whether the template expects to be parsed via harmony: @@ -150,6 +153,17 @@ func (s *Server) scheduleRunner(ctx context.Context, name string, caps []model.C return runner.llama, model, &opts, nil } +func signinURL() (string, error) { + pubKey, err := auth.GetPublicKey() + if err != nil { + return "", err + } + + encKey := base64.RawURLEncoding.EncodeToString([]byte(pubKey)) + h, _ := os.Hostname() + return fmt.Sprintf(signinURLStr, url.PathEscape(h), encKey), nil +} + func (s *Server) GenerateHandler(c *gin.Context) { checkpointStart := time.Now() var req api.GenerateRequest @@ -250,18 +264,21 @@ func (s *Server) GenerateHandler(c *gin.Context) { client := api.NewClient(remoteURL, http.DefaultClient) err = client.Generate(c, &req, fn) if err != nil { - var sErr api.AuthorizationError - if errors.As(err, &sErr) && sErr.StatusCode == http.StatusUnauthorized { - pk, pkErr := auth.GetPublicKey() - if pkErr != nil { - slog.Error("couldn't get public key", "error", pkErr) - c.JSON(http.StatusUnauthorized, gin.H{"error": "error getting public key"}) + var authError api.AuthorizationError + if errors.As(err, &authError) { + sURL, sErr := signinURL() + if sErr != nil { + slog.Error(sErr.Error()) + c.JSON(http.StatusInternalServerError, gin.H{"error": "error getting authorization details"}) return } - c.JSON(http.StatusUnauthorized, gin.H{ - "error": "unauthorized", - "public_key": pk, - }) + + c.JSON(authError.StatusCode, gin.H{"error": "unauthorized", "signin_url": sURL}) + return + } + var apiError api.StatusError + if errors.As(err, &apiError) { + c.JSON(apiError.StatusCode, apiError) return } c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) @@ -1412,9 +1429,12 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) { r.POST("/api/show", s.ShowHandler) r.DELETE("/api/delete", s.DeleteHandler) - r.DELETE("/api/user/keys/:encodedKey", s.SignoutHandler) r.POST("/api/me", s.WhoamiHandler) + r.POST("/api/signout", s.SignoutHandler) + // deprecated + r.DELETE("/api/user/keys/:encodedKey", s.SignoutHandler) + // Create r.POST("/api/create", s.CreateHandler) r.POST("/api/blobs/:digest", s.CreateBlobHandler) @@ -1625,11 +1645,32 @@ func (s *Server) WhoamiHandler(c *gin.Context) { if err != nil { slog.Error(err.Error()) } + + // user isn't signed in + if user != nil && user.Name == "" { + sURL, sErr := signinURL() + if sErr != nil { + slog.Error(sErr.Error()) + c.JSON(http.StatusInternalServerError, gin.H{"error": "error getting authorization details"}) + return + } + + c.JSON(http.StatusUnauthorized, gin.H{"error": "unauthorized", "signin_url": sURL}) + return + } + c.JSON(http.StatusOK, user) } func (s *Server) SignoutHandler(c *gin.Context) { - encodedKey := c.Param("encodedKey") + pubKey, err := auth.GetPublicKey() + if err != nil { + slog.Error("couldn't get public key", "error", err) + c.JSON(http.StatusInternalServerError, gin.H{"error": "there was an error signing out"}) + return + } + + encKey := base64.RawURLEncoding.EncodeToString([]byte(pubKey)) // todo allow other hosts u, err := url.Parse("https://ollama.com") @@ -1640,11 +1681,11 @@ func (s *Server) SignoutHandler(c *gin.Context) { } client := api.NewClient(u, http.DefaultClient) - err = 
client.Signout(c, encodedKey) + err = client.Disconnect(c, encKey) if err != nil { - slog.Error(err.Error()) - if strings.Contains(err.Error(), "page not found") || strings.Contains(err.Error(), "invalid credentials") { - c.JSON(http.StatusNotFound, gin.H{"error": "you are not currently signed in"}) + var authError api.AuthorizationError + if errors.As(err, &authError) { + c.JSON(http.StatusUnauthorized, gin.H{"error": "you are not currently signed in"}) return } c.JSON(http.StatusInternalServerError, gin.H{"error": "there was an error signing out"}) @@ -1802,18 +1843,21 @@ func (s *Server) ChatHandler(c *gin.Context) { client := api.NewClient(remoteURL, http.DefaultClient) err = client.Chat(c, &req, fn) if err != nil { - var sErr api.AuthorizationError - if errors.As(err, &sErr) && sErr.StatusCode == http.StatusUnauthorized { - pk, pkErr := auth.GetPublicKey() - if pkErr != nil { - slog.Error("couldn't get public key", "error", pkErr) - c.JSON(http.StatusUnauthorized, gin.H{"error": "error getting public key"}) + var authError api.AuthorizationError + if errors.As(err, &authError) { + sURL, sErr := signinURL() + if sErr != nil { + slog.Error(sErr.Error()) + c.JSON(http.StatusInternalServerError, gin.H{"error": "error getting authorization details"}) return } - c.JSON(http.StatusUnauthorized, gin.H{ - "error": "unauthorized", - "public_key": pk, - }) + + c.JSON(authError.StatusCode, gin.H{"error": "unauthorized", "signin_url": sURL}) + return + } + var apiError api.StatusError + if errors.As(err, &apiError) { + c.JSON(apiError.StatusCode, apiError) return } c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) From a40d427bcea52ad5c7e93780564fc15e5ef80473 Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Tue, 23 Sep 2025 13:21:47 -0700 Subject: [PATCH 06/11] multi-regexp pretokenizer (#12325) --- model/bytepairencoding.go | 54 ++++++++++++++++++++++++++++------ model/bytepairencoding_test.go | 40 ++++++++++++++++++++++++- model/models/gptoss/model.go | 20 ++++++------- model/models/llama/model.go | 28 +++++++++++++++--- model/models/llama4/model.go | 3 +- model/models/mistral3/model.go | 2 +- model/models/mllama/model.go | 2 +- model/models/qwen2/model.go | 2 +- model/models/qwen25vl/model.go | 2 +- model/models/qwen3/embed.go | 2 +- model/models/qwen3/model.go | 2 +- sample/samplers_test.go | 1 - 12 files changed, 124 insertions(+), 34 deletions(-) diff --git a/model/bytepairencoding.go b/model/bytepairencoding.go index e21564aa5..3d51f70e8 100644 --- a/model/bytepairencoding.go +++ b/model/bytepairencoding.go @@ -5,6 +5,7 @@ import ( "fmt" "iter" "log/slog" + "slices" "strings" "github.com/dlclark/regexp2" @@ -13,16 +14,28 @@ import ( ) type BytePairEncoding struct { - pre *regexp2.Regexp - vocab *Vocabulary + vocab *Vocabulary + regexps []*regexp2.Regexp } var _ TextProcessor = (*BytePairEncoding)(nil) -func NewBytePairEncoding(pre string, vocab *Vocabulary) BytePairEncoding { +func NewBytePairEncoding(vocab *Vocabulary, pretokenizers ...string) BytePairEncoding { + if len(pretokenizers) == 0 { + // set default byte-level pretokenizer if none provided, e.g. 
+ // https://github.com/huggingface/tokenizers/blob/main/tokenizers/src/pre_tokenizers/byte_level.rs#L44 + pretokenizers = []string{`'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+`} + } + return BytePairEncoding{ - pre: regexp2.MustCompile(pre, regexp2.None), vocab: vocab, + regexps: slices.Collect(func(yield func(*regexp2.Regexp) bool) { + for _, p := range pretokenizers { + if !yield(regexp2.MustCompile(p, regexp2.RE2)) { + return + } + } + }), } } @@ -35,13 +48,36 @@ func (bpe BytePairEncoding) Is(id int32, special Special) bool { } func (bpe *BytePairEncoding) split(s string) iter.Seq[string] { - return func(yield func(string) bool) { - for m, _ := bpe.pre.FindStringMatch(s); m != nil; m, _ = bpe.pre.FindNextMatch(m) { - if !yield(m.String()) { - break + parts := []string{s} + for _, re := range bpe.regexps { + parts = slices.Collect(func(yield func(string) bool) { + for _, part := range parts { + r := []rune(part) + var offset int + for m, _ := re.FindRunesMatch(r); m != nil; m, _ = re.FindNextMatch(m) { + if offset-m.Index != 0 { + if !yield(string(r[:m.Index])) { + return + } + } + + if !yield(m.String()) { + return + } + + offset = m.Index + m.Length + } + + if offset < len(r) { + if !yield(string(r[offset:])) { + return + } + } } - } + }) } + + return slices.Values(parts) } // fragment is a string fragment and their corresponding token IDs diff --git a/model/bytepairencoding_test.go b/model/bytepairencoding_test.go index 71947be99..39e5ab452 100644 --- a/model/bytepairencoding_test.go +++ b/model/bytepairencoding_test.go @@ -59,12 +59,12 @@ func llama(t testing.TB) BytePairEncoding { } return NewBytePairEncoding( - `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`, &Vocabulary{ Values: tokens, Types: types, Merges: merges, }, + "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", ) } @@ -282,3 +282,41 @@ func BenchmarkBytePairEncoding(b *testing.B) { }) } } + +func TestSplit(t *testing.T) { + cases := []struct { + name string + patterns, + want []string + }{ + { + name: "default", + want: []string{"Hello", ",", " WORLD", "!!", " How", "'s", " it", " going", "?", " 123", " 一二三"}, + }, + { + name: "unicode", + patterns: []string{ + "\\p{N}{1,3}", + `[一-龥぀-ゟ゠-ヿ]+`, + "[!\"#$%&'()*+,\\-./:;<=>?@\\[\\\\\\]^_`{|}~][A-Za-z]+|[^\r\n\\p{L}\\p{P}\\p{S}]?[\\p{L}\\p{M}]+| ?[\\p{P}\\p{S}]+[\r\n]*|\\s*[\r\n]+|\\s+(?!\\S)|\\s+", + }, + want: []string{"Hello", ",", " WORLD", "!!", " How", "'s", " it", " going", "?", " ", "123", " ", "一二三"}, + }, + { + name: "individual digits", + patterns: []string{ + "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", + }, + want: []string{"Hello", ",", " WORLD", "!!", " How", "'s", " it", " going", "?", " ", "1", "2", "3", " 一二三"}, + }, + } + + for _, tt := range cases { + t.Run(tt.name, func(t *testing.T) { + tokenizer := NewBytePairEncoding(nil, tt.patterns...) + if diff := cmp.Diff(tt.want, slices.Collect(tokenizer.split("Hello, WORLD!! How's it going? 
123 一二三"))); diff != "" { + t.Errorf("no match (-theirs +ours):\n%s", diff) + } + }) + } +} diff --git a/model/models/gptoss/model.go b/model/models/gptoss/model.go index 8456ea5f7..6a3270651 100644 --- a/model/models/gptoss/model.go +++ b/model/models/gptoss/model.go @@ -227,17 +227,6 @@ func New(c fs.Config) (model.Model, error) { m := Transformer{ TransformerBlocks: make([]TransformerBlock, c.Uint("block_count")), BytePairEncoding: model.NewBytePairEncoding( - c.String("tokenizer.ggml.pretokenizer", - strings.Join([]string{ - `[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?`, - `[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?`, - `\p{N}{1,3}`, - ` ?[^\s\p{L}\p{N}]+[\r\n/]*`, - `\s*[\r\n]+`, - `\s+(?!\S)`, - `\s+`, - }, "|"), - ), &model.Vocabulary{ Values: c.Strings("tokenizer.ggml.tokens"), Types: c.Ints("tokenizer.ggml.token_type"), @@ -250,6 +239,15 @@ func New(c fs.Config) (model.Model, error) { c.Ints("tokenizer.ggml.eos_token_ids")..., ), }, + strings.Join([]string{ + `[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?`, + `[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?`, + `\p{N}{1,3}`, + ` ?[^\s\p{L}\p{N}]+[\r\n/]*`, + `\s*[\r\n]+`, + `\s+(?!\S)`, + `\s+`, + }, "|"), ), Options: Options{ hiddenSize: int(c.Uint("embedding_length")), diff --git a/model/models/llama/model.go b/model/models/llama/model.go index f6ec02273..c03f04a0d 100644 --- a/model/models/llama/model.go +++ b/model/models/llama/model.go @@ -54,10 +54,30 @@ func New(c fs.Config) (model.Model, error) { } switch c.String("tokenizer.ggml.model") { case "gpt2": - processor = model.NewBytePairEncoding( - `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`, - &vocabulary, - ) + var pretokenizers []string + switch c.String("tokenizer.ggml.pre") { + case "default": + // no-op use the default bpe pretokenizer + case "qwen2": + pretokenizers = []string{ + "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", + } + case "refact": + pretokenizers = []string{ + `\p{N}`, + `'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+`, + } + case "tekken": + pretokenizers = []string{ + "[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]*[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]+|[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]+[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]*|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", + } + default: + // use a llama-style pretokenizer + pretokenizers = []string{ + "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", + } + } + processor = model.NewBytePairEncoding(&vocabulary, pretokenizers...) 
case "llama": processor = model.NewSentencePiece(&vocabulary) default: diff --git a/model/models/llama4/model.go b/model/models/llama4/model.go index 9cb2efc87..e80fbaed6 100644 --- a/model/models/llama4/model.go +++ b/model/models/llama4/model.go @@ -34,8 +34,6 @@ func (p *Projector) Forward(ctx ml.Context, visionOutputs ml.Tensor) ml.Tensor { func New(c fs.Config) (model.Model, error) { m := Model{ BytePairEncoding: model.NewBytePairEncoding( - c.String("tokenizer.ggml.pretokenizer", - `[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n/]*|\s*[\r\n]+|\s+(?!\S)|\s+`), &model.Vocabulary{ Values: c.Strings("tokenizer.ggml.tokens"), Types: c.Ints("tokenizer.ggml.token_type"), @@ -48,6 +46,7 @@ func New(c fs.Config) (model.Model, error) { c.Ints("tokenizer.ggml.eos_token_ids")..., ), }, + `[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n/]*|\s*[\r\n]+|\s+(?!\S)|\s+`, ), ImageProcessor: newImageProcessor(c), VisionModel: newVisionModel(c), diff --git a/model/models/mistral3/model.go b/model/models/mistral3/model.go index 435b1a304..5c46615e9 100644 --- a/model/models/mistral3/model.go +++ b/model/models/mistral3/model.go @@ -33,7 +33,6 @@ var _ model.TextProcessor = (*Model)(nil) func New(c fs.Config) (model.Model, error) { m := &Model{ BytePairEncoding: model.NewBytePairEncoding( - c.String("tokenizer.ggml.pretokenizer", `[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n/]*|\s*[\r\n]+|\s+(?!\S)|\s+`), &model.Vocabulary{ Values: c.Strings("tokenizer.ggml.tokens"), Types: c.Ints("tokenizer.ggml.token_type"), @@ -46,6 +45,7 @@ func New(c fs.Config) (model.Model, error) { c.Ints("tokenizer.ggml.eos_token_ids")..., ), }, + `[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n/]*|\s*[\r\n]+|\s+(?!\S)|\s+`, ), TextModel: newTextModel(c), VisionModel: newVisionModel(c), diff --git a/model/models/mllama/model.go b/model/models/mllama/model.go index 239d999d5..769743694 100644 --- a/model/models/mllama/model.go +++ b/model/models/mllama/model.go @@ -33,7 +33,6 @@ const ( func New(c fs.Config) (model.Model, error) { m := Model{ BytePairEncoding: model.NewBytePairEncoding( - c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`), &model.Vocabulary{ Values: c.Strings("tokenizer.ggml.tokens"), Types: c.Ints("tokenizer.ggml.token_type"), @@ -46,6 +45,7 @@ func New(c fs.Config) (model.Model, error) { c.Ints("tokenizer.ggml.eos_token_ids")..., ), }, + `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`, ), ImageProcessor: newImageProcessor(c), VisionModel: newVisionModel(c), diff --git a/model/models/qwen2/model.go b/model/models/qwen2/model.go index 5a3458378..2e2347102 100644 --- a/model/models/qwen2/model.go +++ b/model/models/qwen2/model.go @@ -139,7 +139,6 @@ func New(c fs.Config) 
(model.Model, error) { m := Model{ Layers: make([]DecoderLayer, c.Uint("block_count")), BytePairEncoding: model.NewBytePairEncoding( - c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`), &model.Vocabulary{ Values: c.Strings("tokenizer.ggml.tokens"), Types: c.Ints("tokenizer.ggml.token_type"), @@ -152,6 +151,7 @@ func New(c fs.Config) (model.Model, error) { c.Ints("tokenizer.ggml.eos_token_ids")..., ), }, + `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`, ), Options: Options{ hiddenSize: int(c.Uint("embedding_length")), diff --git a/model/models/qwen25vl/model.go b/model/models/qwen25vl/model.go index 6c76305db..6898e38ca 100644 --- a/model/models/qwen25vl/model.go +++ b/model/models/qwen25vl/model.go @@ -29,7 +29,6 @@ var _ model.MultimodalProcessor = (*Model)(nil) func New(c fs.Config) (model.Model, error) { m := &Model{ BytePairEncoding: model.NewBytePairEncoding( - c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`), &model.Vocabulary{ Values: c.Strings("tokenizer.ggml.tokens"), Types: c.Ints("tokenizer.ggml.token_type"), @@ -42,6 +41,7 @@ func New(c fs.Config) (model.Model, error) { c.Ints("tokenizer.ggml.eos_token_ids")..., ), }, + `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`, ), TextModel: NewTextModel(c), VisionModel: newVisionModel(c), diff --git a/model/models/qwen3/embed.go b/model/models/qwen3/embed.go index 9a77efea9..c03888d45 100644 --- a/model/models/qwen3/embed.go +++ b/model/models/qwen3/embed.go @@ -35,7 +35,6 @@ func newEmbed(c fs.Config) (model.Model, error) { } m := embedModel{ BytePairEncoding: model.NewBytePairEncoding( - `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`, &model.Vocabulary{ Values: c.Strings("tokenizer.ggml.tokens"), Types: c.Ints("tokenizer.ggml.token_type"), @@ -48,6 +47,7 @@ func newEmbed(c fs.Config) (model.Model, error) { c.Ints("tokenizer.ggml.eos_token_ids")..., ), }, + `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`, ), Model: &Model{ Layers: layers, diff --git a/model/models/qwen3/model.go b/model/models/qwen3/model.go index 352268347..cc58e4a28 100644 --- a/model/models/qwen3/model.go +++ b/model/models/qwen3/model.go @@ -200,7 +200,6 @@ func New(c fs.Config) (model.Model, error) { m := Model{ BytePairEncoding: model.NewBytePairEncoding( - `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`, &model.Vocabulary{ Values: c.Strings("tokenizer.ggml.tokens"), Types: c.Ints("tokenizer.ggml.token_type"), @@ -213,6 +212,7 @@ func New(c fs.Config) (model.Model, error) { c.Ints("tokenizer.ggml.eos_token_ids")..., ), }, + `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`, ), Layers: layers, Options: &Options{ diff --git a/sample/samplers_test.go b/sample/samplers_test.go index b720f027c..eb10295d4 100644 --- a/sample/samplers_test.go +++ b/sample/samplers_test.go @@ -82,7 +82,6 @@ func modelHelper(t testing.TB) model.BytePairEncoding { merges := make([]string, 0, 1) // Only need vocab for Grammar Test return model.NewBytePairEncoding( - ``, &model.Vocabulary{ Values: 
tokens, Types: make([]int32, len(vocab)), From bf78ed6ee94e593a7edae2e277a736379cbc2413 Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Tue, 23 Sep 2025 16:08:57 -0700 Subject: [PATCH 07/11] add pre:, suf: to tags (#12274) --- model/model.go | 67 ++++++++++++++++++++----------- model/model_test.go | 61 +++++++++++++++++++++++++--- model/models/llama4/model_text.go | 14 +------ 3 files changed, 101 insertions(+), 41 deletions(-) diff --git a/model/model.go b/model/model.go index f3d6bb3db..2b6ad7317 100644 --- a/model/model.go +++ b/model/model.go @@ -5,6 +5,7 @@ import ( "fmt" _ "image/jpeg" _ "image/png" + "log/slog" "os" "reflect" "strconv" @@ -171,35 +172,42 @@ func populateFields(base Base, v reflect.Value, tags ...Tag) reflect.Value { // make a copy tagsCopy := tags if tag := t.Field(i).Tag.Get("gguf"); tag != "" { - tagsCopy = append(tagsCopy, ParseTags(tag)) + tagsCopy = append(tagsCopy, parseTag(tag)) } if tt == reflect.TypeOf((*Base)(nil)).Elem() { vv.Set(reflect.ValueOf(base)) } else if tt == reflect.TypeOf((*ml.Tensor)(nil)).Elem() { - var fn func([]Tag) [][]string - fn = func(tags []Tag) (names [][]string) { + var fn func([]Tag, string, string) [][]string + fn = func(tags []Tag, prefix, suffix string) (fullNames [][]string) { if len(tags) > 0 { - localNames := []string{tags[0].Name} - localNames = append(localNames, tags[0].Alternate...) + var names []string + if tags[0].name != "" { + for _, n := range append([]string{tags[0].name}, tags[0].alternatives...) { + names = append(names, prefix+n+suffix) + } + } - for _, localName := range localNames { - fullName := []string{localName} - nested := fn(tags[1:]) - if len(nested) > 0 { - for _, rest := range nested { - names = append(names, append(fullName, rest...)) + if childNames := fn(tags[1:], tags[0].prefix, tags[0].suffix); len(childNames) == 0 { + // no child names, append current names + fullNames = append(fullNames, names) + } else if len(names) == 0 { + // no current names, append child names + fullNames = append(fullNames, childNames...) 
+ } else { + // combine current and child names + for _, name := range names { + for _, childName := range childNames { + fullNames = append(fullNames, append([]string{name}, childName...)) } - } else { - names = append(names, fullName) } } } - return names + return fullNames } - names := fn(tagsCopy) + names := fn(tagsCopy, "", "") for _, name := range names { if tensor := base.Backend().Get(strings.Join(name, ".")); tensor != nil { logutil.Trace("found tensor", "", tensor) @@ -213,9 +221,9 @@ func populateFields(base Base, v reflect.Value, tags ...Tag) reflect.Value { for i := range vv.Len() { vvv := vv.Index(i) if vvv.Kind() == reflect.Pointer || vvv.Kind() == reflect.Interface { - setPointer(base, vvv, append(tagsCopy, Tag{Name: strconv.Itoa(i)})) + setPointer(base, vvv, append(tagsCopy, Tag{name: strconv.Itoa(i)})) } else { - vvv.Set(populateFields(base, vvv, append(tagsCopy, Tag{Name: strconv.Itoa(i)})...)) + vvv.Set(populateFields(base, vvv, append(tagsCopy, Tag{name: strconv.Itoa(i)})...)) } } } @@ -254,18 +262,31 @@ func setPointer(base Base, v reflect.Value, tags []Tag) { } type Tag struct { - Name string - Alternate []string + name, + // prefix and suffix are applied to child tags + prefix, + suffix string + alternatives []string } -func ParseTags(s string) (tag Tag) { +func parseTag(s string) (tag Tag) { parts := strings.Split(s, ",") if len(parts) > 0 { - tag.Name = parts[0] + tag.name = parts[0] for _, part := range parts[1:] { - if value, ok := strings.CutPrefix(part, "alt:"); ok { - tag.Alternate = append(tag.Alternate, value) + if value, ok := strings.CutPrefix(part, "alt:"); ok && tag.name == "" { + // elevate alternative to primary if no primary given + tag.name = value + slog.Warn("gguf tag has alt: but no primary name", "tag", s) + } else if ok { + tag.alternatives = append(tag.alternatives, value) + } + if value, ok := strings.CutPrefix(part, "pre:"); ok { + tag.prefix = value + } + if value, ok := strings.CutPrefix(part, "suf:"); ok { + tag.suffix = value } } } diff --git a/model/model_test.go b/model/model_test.go index 01080ffdf..e47278540 100644 --- a/model/model_test.go +++ b/model/model_test.go @@ -22,14 +22,14 @@ func TestParseTags(t *testing.T) { { value: "output", want: Tag{ - Name: "output", + name: "output", }, }, { value: "output,alt:token_embd", want: Tag{ - Name: "output", - Alternate: []string{ + name: "output", + alternatives: []string{ "token_embd", }, }, @@ -38,8 +38,8 @@ func TestParseTags(t *testing.T) { for _, tt := range cases { t.Run(tt.value, func(t *testing.T) { - got := ParseTags(tt.value) - if diff := cmp.Diff(tt.want, got); diff != "" { + got := parseTag(tt.value) + if diff := cmp.Diff(tt.want, got, cmp.AllowUnexported((Tag{}))); diff != "" { t.Errorf("ParseTags() returned unexpected values (-want +got):\n%s", diff) } }) @@ -147,6 +147,57 @@ func TestPopulateFieldsAlternateName(t *testing.T) { } } +func TestPopulateFieldsPrefixSuffixName(t *testing.T) { + type fakeBlock struct { + A *nn.Linear `gguf:"a"` + B *nn.Linear `gguf:",pre:b_"` + C *nn.Linear `gguf:",suf:_c"` + XY *nn.Linear `gguf:",pre:x_,suf:_y"` + } + + type fakeModel struct { + Blocks []fakeBlock `gguf:"blk"` + } + + m := fakeModel{ + Blocks: make([]fakeBlock, 2), + } + v := reflect.ValueOf(&m) + v.Elem().Set(populateFields(Base{b: &fakeBackend{ + names: []string{ + "blk.0.a.weight", + "blk.0.b_weight", + "blk.0.b_bias", + "blk.0.weight_c", + "blk.0.x_weight_y", + "blk.1.a.weight", + "blk.1.b_weight", + "blk.1.b_bias", + "blk.1.weight_c", + "blk.1.x_weight_y", + }, + }}, 
v.Elem())) + + if diff := cmp.Diff(fakeModel{ + Blocks: []fakeBlock{ + { + A: &nn.Linear{Weight: &fakeTensor{Name: "blk.0.a.weight"}}, + B: &nn.Linear{Weight: &fakeTensor{Name: "blk.0.b_weight"}, Bias: &fakeTensor{Name: "blk.0.b_bias"}}, + C: &nn.Linear{Weight: &fakeTensor{Name: "blk.0.weight_c"}}, + XY: &nn.Linear{Weight: &fakeTensor{Name: "blk.0.x_weight_y"}}, + }, + { + A: &nn.Linear{Weight: &fakeTensor{Name: "blk.1.a.weight"}}, + B: &nn.Linear{Weight: &fakeTensor{Name: "blk.1.b_weight"}, Bias: &fakeTensor{Name: "blk.1.b_bias"}}, + C: &nn.Linear{Weight: &fakeTensor{Name: "blk.1.weight_c"}}, + XY: &nn.Linear{Weight: &fakeTensor{Name: "blk.1.x_weight_y"}}, + }, + }, + }, m); diff != "" { + t.Errorf("populateFields() set incorrect values (-want +got):\n%s", diff) + } +} + func TestModelForArch(t *testing.T) { type fakeModel struct { Model diff --git a/model/models/llama4/model_text.go b/model/models/llama4/model_text.go index e0f932600..e056391f5 100644 --- a/model/models/llama4/model_text.go +++ b/model/models/llama4/model_text.go @@ -88,22 +88,10 @@ func (e *TextExperts) Forward(ctx ml.Context, hiddenStates, routerLogits ml.Tens return nextStates } -// TextSharedExpert is TextMLP with different tensor names -type TextSharedExpert struct { - Gate *nn.Linear `gguf:"ffn_gate_shexp"` - Up *nn.Linear `gguf:"ffn_up_shexp"` - Down *nn.Linear `gguf:"ffn_down_shexp"` -} - -func (mlp *TextSharedExpert) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *TextOptions) ml.Tensor { - hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx, mlp.Up.Forward(ctx, hiddenStates)) - return mlp.Down.Forward(ctx, hiddenStates) -} - type TextMOE struct { Router *nn.Linear `gguf:"ffn_gate_inp"` Experts *TextExperts - SharedExpert *TextSharedExpert + SharedExpert *TextMLP `gguf:",suf:_shexp"` } func (moe *TextMOE) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *TextOptions) ml.Tensor { From e1979c571aff857568c9c35f5994da40568ef15c Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Tue, 23 Sep 2025 17:50:53 -0700 Subject: [PATCH 08/11] fix: leaf alt name (#12390) a leaf node with an alternative name gets all its alternatives names added into the same branch rather than creating branches themselves --- model/model.go | 16 +++++++++------- model/model_test.go | 3 +++ 2 files changed, 12 insertions(+), 7 deletions(-) diff --git a/model/model.go b/model/model.go index 2b6ad7317..0af16da80 100644 --- a/model/model.go +++ b/model/model.go @@ -187,15 +187,17 @@ func populateFields(base Base, v reflect.Value, tags ...Tag) reflect.Value { names = append(names, prefix+n+suffix) } } - - if childNames := fn(tags[1:], tags[0].prefix, tags[0].suffix); len(childNames) == 0 { - // no child names, append current names - fullNames = append(fullNames, names) - } else if len(names) == 0 { - // no current names, append child names + childNames := fn(tags[1:], tags[0].prefix, tags[0].suffix) + if len(names) == 0 { + // current tag has no name, use child names only fullNames = append(fullNames, childNames...) 
+ } else if len(childNames) == 0 { + // current tag has names but no children, create branches for each name + for _, name := range names { + fullNames = append(fullNames, []string{name}) + } } else { - // combine current and child names + // merge each name with each child for _, name := range names { for _, childName := range childNames { fullNames = append(fullNames, append([]string{name}, childName...)) diff --git a/model/model_test.go b/model/model_test.go index e47278540..f6d75b230 100644 --- a/model/model_test.go +++ b/model/model_test.go @@ -125,6 +125,7 @@ func TestPopulateFieldsAlternateName(t *testing.T) { Input *nn.Embedding `gguf:"input"` Output *nn.Linear `gguf:"output,alt:input"` Nested *nested `gguf:"nested"` + Tensor ml.Tensor `gguf:"leaf,alt:tensor"` } var m fakeModel @@ -133,6 +134,7 @@ func TestPopulateFieldsAlternateName(t *testing.T) { names: []string{ "input.weight", "nested.b.weight", + "leaf", }, }}, v.Elem())) @@ -142,6 +144,7 @@ func TestPopulateFieldsAlternateName(t *testing.T) { Nested: &nested{ Weight: &nn.Linear{Weight: &fakeTensor{Name: "nested.b.weight"}}, }, + Tensor: &fakeTensor{Name: "leaf"}, }, m); diff != "" { t.Errorf("populateFields() set incorrect values (-want +got):\n%s", diff) } From fd88cd7cb0966a26f41ec41bc012f2c4d725ab98 Mon Sep 17 00:00:00 2001 From: Devon Rifkin Date: Tue, 23 Sep 2025 23:34:55 -0700 Subject: [PATCH 09/11] harmony: don't sanitize built-ins In #11910 we started sanitizing function names, but we accidentally were modifying built-ins like `browser.open` to `browser_open`. This was removing the special prompt rendering for built-ins, but this wasn't immediately apparent since the models seem to be reasonably good at remembering the built-ins even when presented with these slightly renamed version. This fix prevents built-ins from ever being renamed. 
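
For illustration, here is a rough usage sketch of the intended behavior after this change (a hypothetical example, assuming the package's NewFunctionNameMap constructor; the expected outputs mirror the test cases added below):

package main

import (
	"fmt"

	"github.com/ollama/ollama/harmony"
)

func main() {
	m := harmony.NewFunctionNameMap()

	// User-defined tool names are still sanitized into harmony-safe identifiers.
	fmt.Println(m.ConvertAndAdd("get weather"))             // get_weather
	fmt.Println(m.ConvertAndAdd("not.a.built-in.function")) // not_a_built_in_function

	// Built-ins keep their dotted names so their special prompt rendering is preserved.
	fmt.Println(m.ConvertAndAdd("browser.open")) // browser.open
	fmt.Println(m.ConvertAndAdd("python"))       // python

	// Only exact built-in names are exempt; look-alike names are still sanitized.
	fmt.Println(m.ConvertAndAdd("browser.not_a_real_built_in")) // browser_not_a_real_built_in
}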
--- harmony/harmonyparser.go | 4 ++++ harmony/harmonyparser_test.go | 1 + 2 files changed, 5 insertions(+) diff --git a/harmony/harmonyparser.go b/harmony/harmonyparser.go index b365b763d..da9fe3e93 100644 --- a/harmony/harmonyparser.go +++ b/harmony/harmonyparser.go @@ -463,6 +463,10 @@ func (h *HarmonyMessageHandler) HasThinkingSupport() bool { func (m *FunctionNameMap) ConvertAndAdd(userFunctionName string) string { harmonyFunctionName := m.deriveName(userFunctionName) + // built-in functions should not be renamed + if userFunctionName == "browser.open" || userFunctionName == "browser.search" || userFunctionName == "browser.find" || userFunctionName == "python" { + harmonyFunctionName = userFunctionName + } m.userToHarmony[userFunctionName] = harmonyFunctionName m.harmonyToUser[harmonyFunctionName] = userFunctionName return harmonyFunctionName diff --git a/harmony/harmonyparser_test.go b/harmony/harmonyparser_test.go index b988a018f..e56178c61 100644 --- a/harmony/harmonyparser_test.go +++ b/harmony/harmonyparser_test.go @@ -513,6 +513,7 @@ func TestFunctionConvertAndAdd(t *testing.T) { {name: "dupes from different user-specified names", in: []string{"get weather", "get_weather", "get-weather"}, want: []string{"get_weather", "get_weather_2", "get_weather_3"}}, {name: "non dupes after dupes", in: []string{"get weather", "get_weather", "get-weather", "something-different"}, want: []string{"get_weather", "get_weather_2", "get_weather_3", "something_different"}}, {name: "multiple sets of dupes", in: []string{"a", "a", "b", "a", "a", "b", "a"}, want: []string{"a", "a_2", "b", "a_3", "a_4", "b_2", "a_5"}}, + {name: "built-in functions should not be renamed", in: []string{"browser.open", "python", "not.a.built-in.function", "browser.not_a_real_built_in"}, want: []string{"browser.open", "python", "not_a_built_in_function", "browser_not_a_real_built_in"}}, } for i, tt := range tests { From 2e742544bfc5242be4d76c6fee5082c7e41b3df2 Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Wed, 24 Sep 2025 11:21:32 -0700 Subject: [PATCH 10/11] prefer ollama engine for qwen3moe (#12374) --- fs/ggml/ggml.go | 1 + 1 file changed, 1 insertion(+) diff --git a/fs/ggml/ggml.go b/fs/ggml/ggml.go index 5da902bcb..58803f58f 100644 --- a/fs/ggml/ggml.go +++ b/fs/ggml/ggml.go @@ -244,6 +244,7 @@ func (kv KV) OllamaEngineRequired() bool { "gemma3n", "mistral3", "qwen3", + "qwen3moe", "llama4", "mllama", "qwen25vl", From fbd82ba5bb35c42a6b09f5bd50ff1aa0690b9626 Mon Sep 17 00:00:00 2001 From: Grace <88872231+gr4ceG@users.noreply.github.com> Date: Wed, 24 Sep 2025 15:19:47 -0700 Subject: [PATCH 11/11] Grace/deepseek v3 migration (#12385) * init deepseek model file * temp removal of flash attention implementation * shapes and proper, can make a pass * query, key, value have good cosine similarity, but the max diff is a bit high * Attention block is working! ** with eager for now, have not added the mask line * Attention block is working! ** with eager for now, have not added the mask line * working MoE at around 0.95 cosine sim * added cosine similarity function * Starting end to end structure * Trying (and failing) to get rope to work, going to test full thing on tater * running on tater36... just not the right outputs * we have the right values for rope... but its still not working? 
* chnage Extrapolation Factor to 1 * removed adding residuals twice, removed normalization from shared expert, refactored Norms (Attention, MLP) to be outside the (Attention, MLP) blocks and in the Transformer block instead, add cache setLayer * Temporary modelfiles for cpu * change kpass intermediate step to kv, two layer outputs [0,1] look fine * this calls for 16 chicken nuggets * whoops * cleaning up code * delete stuff we dont need * getting rid of debug statements for llama cpp * working with long contexts * fix long context view error * reverting some changes I made for files that are not apart of pr * Added proper tokenizer for deeepseek3 * clean up model and go test * remove Modelfile * not passing the tests * whoops * how to pass the ci tests * resolving some of the comments * rename * linted and renamed deepseek3 -> deepseek2 * remove name go * addressed changes - main change was adopting qwen3 naming scheme * I cannot with linters * clean up logs * clean up logs --------- Co-authored-by: Grace Guo Co-authored-by: Grace Guo Co-authored-by: graceguo --- model/models/deepseek2/model.go | 324 ++++++++++++++++++++++++++++++++ model/models/models.go | 1 + 2 files changed, 325 insertions(+) create mode 100644 model/models/deepseek2/model.go diff --git a/model/models/deepseek2/model.go b/model/models/deepseek2/model.go new file mode 100644 index 000000000..7b88711ba --- /dev/null +++ b/model/models/deepseek2/model.go @@ -0,0 +1,324 @@ +package deepseek2 + +// uses deepseek 2 architecture but written based on deepseek 3 model + +import ( + "math" + + "github.com/ollama/ollama/fs" + "github.com/ollama/ollama/kvcache" + "github.com/ollama/ollama/ml" + "github.com/ollama/ollama/ml/nn" + "github.com/ollama/ollama/ml/nn/fast" + "github.com/ollama/ollama/ml/nn/rope" + "github.com/ollama/ollama/model" + "github.com/ollama/ollama/model/input" +) + +type Options struct { + numExpertsUsed int + numExperts int + normTopKProb bool + routedScalingFactor float32 + + kvLoraRank, + qkNopeHeadDim, + qkRopeHeadDim, + kqNopeHeadDim, + qkHeadDim int + qLoraRank int + vHeadDim int + + hiddenSize, + numHeads, + numKVHeads, + keyLength, + valueLength, + originalContextLength int + + eps, + ropeBase, + ropeScale float32 + kqScale float64 +} + +func (o Options) RoPEOptions() []func(*rope.Options) { + attnFactor := float32(1.0 / (1.0 + 0.1*math.Log(float64(o.ropeScale)))) + return []func(*rope.Options){ + rope.WithOriginalContextLength(o.originalContextLength), + rope.WithExtrapolationFactor(1.), + rope.WithAttentionFactor(attnFactor), + } +} + +type Attention struct { + Q *nn.Linear `gguf:"attn_q"` + + QA *nn.Linear `gguf:"attn_q_a"` + QANorm *nn.RMSNorm `gguf:"attn_q_a_norm"` + QB *nn.Linear `gguf:"attn_q_b"` + + KVA *nn.Linear `gguf:"attn_kv_a_mqa"` + KVANorm *nn.RMSNorm `gguf:"attn_kv_a_norm"` + KVB *nn.Linear `gguf:"attn_kv_b"` + + Output *nn.Linear `gguf:"attn_out,alt:attn_output"` +} + +func (attn *Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor { + seqLength := hiddenStates.Dim(1) + + var query ml.Tensor + if opts.qLoraRank == 0 { // nil { + query = attn.Q.Forward(ctx, hiddenStates) + } else { + query = attn.QA.Forward(ctx, hiddenStates) + query = attn.QANorm.Forward(ctx, query, opts.eps) + query = attn.QB.Forward(ctx, query) + } + + query = query.Reshape(ctx, query.Dim(0)/opts.numHeads, opts.numHeads, seqLength) + + qPass := query.View(ctx, 0, + opts.qkNopeHeadDim, query.Stride(1), + query.Dim(1), query.Stride(2), + query.Dim(2)) + + 
qRot := query.View(ctx, opts.qkNopeHeadDim*query.Stride(0), + opts.qkRopeHeadDim, query.Stride(1), + query.Dim(1), query.Stride(2), + query.Dim(2)) + + compressedKV := attn.KVA.Forward(ctx, hiddenStates) + + kPass := compressedKV.View(ctx, 0, opts.kvLoraRank, compressedKV.Stride(1), compressedKV.Dim(1)) + kRot := compressedKV.View(ctx, opts.kvLoraRank*compressedKV.Stride(0), + opts.qkRopeHeadDim, compressedKV.Stride(1), + 1, compressedKV.Stride(1), + compressedKV.Dim(1)) + + kPass = attn.KVANorm.Forward(ctx, kPass, opts.eps) + kPass = attn.KVB.Forward(ctx, kPass) + + kv := kPass.Reshape(ctx, kPass.Dim(0)/opts.numKVHeads, opts.numKVHeads, seqLength) + kPass = kv.View(ctx, 0, opts.kqNopeHeadDim, kv.Stride(1), kv.Dim(1), kv.Stride(2), kv.Dim(2)) + value := kv.View(ctx, opts.kqNopeHeadDim*kv.Stride(0), + opts.vHeadDim, kv.Stride(1), + kv.Dim(1), kv.Stride(2), + kv.Dim(2)).Contiguous(ctx) + + qRot = fast.RoPE(ctx, qRot, positions, opts.qkRopeHeadDim, opts.ropeBase, 1./opts.ropeScale, opts.RoPEOptions()...) + kRot = fast.RoPE(ctx, kRot, positions, opts.qkRopeHeadDim, opts.ropeBase, 1./opts.ropeScale, opts.RoPEOptions()...) + + kRot = kRot.Repeat(ctx, 1, qPass.Dim(1)) + + query = qRot.Concat(ctx, qPass, 0) + key := kRot.Concat(ctx, kPass, 0) + + attention := nn.Attention(ctx, query, key, value, opts.kqScale, cache) + attention = attention.Reshape(ctx, attention.Dim(0)*attention.Dim(1), seqLength) + return attn.Output.Forward(ctx, attention) +} + +type MLP interface { + Forward(ml.Context, ml.Tensor, *Options) ml.Tensor +} + +type sparse struct { + Router *nn.Linear `gguf:"ffn_gate_inp"` + Gate *nn.Linear `gguf:"ffn_gate_exps"` + Up *nn.Linear `gguf:"ffn_up_exps"` + Down *nn.Linear `gguf:"ffn_down_exps"` + SharedExpert *dense `gguf:",suf:_shexp"` + ExpProbsBias ml.Tensor `gguf:"exp_probs_b.bias,alt:exp_probs_b"` +} + +func (moe *sparse) Moe(ctx ml.Context, hiddenStates, topKIndices, topKWeights ml.Tensor, opts *Options) ml.Tensor { + hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0), 1, hiddenStates.Dim(1)) + + upStates := moe.Up.Weight.MulmatID(ctx, hiddenStates, topKIndices) + hiddenStates = moe.Gate.Weight.MulmatID(ctx, hiddenStates, topKIndices) + hiddenStates = hiddenStates.SILU(ctx, upStates) + + experts := moe.Down.Weight.MulmatID(ctx, hiddenStates, topKIndices) + experts = experts.Mul(ctx, topKWeights) + nextStates := experts.View(ctx, 0, experts.Dim(0), experts.Stride(2), experts.Dim(2)) + for i := 1; i < opts.numExpertsUsed; i++ { + nextStates = nextStates.Add(ctx, experts.View(ctx, i*experts.Stride(1), experts.Dim(0), experts.Stride(2), experts.Dim(2))) + } + return nextStates +} + +func (moe *sparse) topKIndices(ctx ml.Context, scores ml.Tensor, opts *Options) ml.Tensor { + scores = scores.Add(ctx, moe.ExpProbsBias) + topKIndices := scores.TopK(ctx, opts.numExpertsUsed) + return topKIndices +} + +func (moe *sparse) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options) ml.Tensor { + residuals := hiddenStates + + routerLogits := moe.Router.Forward(ctx, hiddenStates) + scores := routerLogits.Sigmoid(ctx) + topKIndices := moe.topKIndices(ctx, scores, opts) + topKWeights := scores.Reshape(ctx, 1, opts.numExperts, hiddenStates.Dim(1)).Rows(ctx, topKIndices) + + if opts.normTopKProb { + topKWeights = topKWeights.Reshape(ctx, opts.numExpertsUsed, hiddenStates.Dim(1)) + topKWeights = topKWeights.Div(ctx, topKWeights.SumRows(ctx)) + topKWeights = topKWeights.Reshape(ctx, 1, opts.numExpertsUsed, hiddenStates.Dim(1)) + } + + topKWeights = topKWeights.Scale(ctx, 
float64(opts.routedScalingFactor)) + hiddenStates = moe.Moe(ctx, hiddenStates, topKIndices, topKWeights, opts) + sharedExpertResult := moe.SharedExpert.Forward(ctx, residuals, opts) + + hiddenStates = hiddenStates.Add(ctx, sharedExpertResult) + return hiddenStates +} + +type dense struct { + Gate *nn.Linear `gguf:"ffn_gate"` + Up *nn.Linear `gguf:"ffn_up"` + Down *nn.Linear `gguf:"ffn_down"` +} + +func (mlp *dense) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options) ml.Tensor { + hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx, mlp.Up.Forward(ctx, hiddenStates)) + return mlp.Down.Forward(ctx, hiddenStates) +} + +type Layer struct { + AttentionNorm *nn.RMSNorm `gguf:"attn_norm"` + Attention *Attention + + MLPNorm *nn.RMSNorm `gguf:"ffn_norm"` + MLP MLP +} + +func (t *Layer) Forward(ctx ml.Context, hiddenStates, positions, outputs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor { + residual := hiddenStates + hiddenStates = t.AttentionNorm.Forward(ctx, hiddenStates, opts.eps) + hiddenStates = t.Attention.Forward(ctx, hiddenStates, positions, cache, opts) + + if outputs != nil { + hiddenStates = hiddenStates.Rows(ctx, outputs) + residual = residual.Rows(ctx, outputs) + } + + hiddenStates = hiddenStates.Add(ctx, residual) + residual = hiddenStates + + hiddenStates = t.MLPNorm.Forward(ctx, hiddenStates, opts.eps) + hiddenStates = t.MLP.Forward(ctx, hiddenStates, opts) + hiddenStates = hiddenStates.Add(ctx, residual) + return hiddenStates +} + +type Model struct { + model.Base + model.BytePairEncoding + + TokenEmbedding *nn.Embedding `gguf:"token_embd"` + Layers []Layer `gguf:"blk"` + + OutputNorm *nn.RMSNorm `gguf:"output_norm"` + Output *nn.Linear `gguf:"output,alt:token_embd"` + + *Options +} + +func New(c fs.Config) (model.Model, error) { + layers := make([]Layer, c.Uint("block_count")) + + firstDenseLayerIndex := int(c.Uint("leading_dense_block_count")) + for i := range layers { + if i < firstDenseLayerIndex { + layers[i].MLP = &dense{} + } else { + layers[i].MLP = &sparse{} + } + } + + mScale := float32(1.0 + float64(c.Float("rope.scaling.yarn_log_multiplier"))*math.Log(float64(c.Float("rope.scaling.factor")))) + kqScale := float64(mScale) * float64(mScale) / math.Sqrt(float64(c.Uint("attention.key_length"))) + + m := Model{ + BytePairEncoding: model.NewBytePairEncoding( + &model.Vocabulary{ + Values: c.Strings("tokenizer.ggml.tokens"), + Types: c.Ints("tokenizer.ggml.token_type"), + Merges: c.Strings("tokenizer.ggml.merges"), + AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true), + BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))}, + AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false), + EOS: append( + []int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))}, + c.Ints("tokenizer.ggml.eos_token_ids")..., + ), + }, + // Split regex into multiple parts (according to DeepSeek3's regex) + "\\p{N}{1,3}", + `[一-龥぀-ゟ゠-ヿ]+`, + "[!\"#$%&'()*+,\\-./:;<=>?@\\[\\\\\\]^_`{|}~][A-Za-z]+|[^\r\n\\p{L}\\p{P}\\p{S}]?[\\p{L}\\p{M}]+| ?[\\p{P}\\p{S}]+[\r\n]*|\\s*[\r\n]+|\\s+(?!\\S)|\\s+", + ), + Layers: layers, + Options: &Options{ + hiddenSize: int(c.Uint("embedding_length")), + numHeads: int(c.Uint("attention.head_count")), + numKVHeads: int(c.Uint("attention.head_count_kv")), + keyLength: int(c.Uint("attention.key_length")), + valueLength: int(c.Uint("attention.value_length")), + eps: c.Float("attention.layer_norm_rms_epsilon"), + ropeBase: c.Float("rope.freq_base"), + ropeScale: c.Float("rope.scaling.factor", 1), + numExperts: int(c.Uint("expert_count")), + 
numExpertsUsed: int(c.Uint("expert_used_count")), + normTopKProb: c.Bool("expert_weights_norm", true), + + qLoraRank: int(c.Uint("attention.q_lora_rank")), //&qLoraRankVal, + kvLoraRank: int(c.Uint("attention.kv_lora_rank")), + qkHeadDim: int(c.Uint("attention.key_length")), + vHeadDim: int(c.Uint("attention.value_length")), + qkRopeHeadDim: int(c.Uint("rope.dimension_count")), + qkNopeHeadDim: int(c.Uint("attention.key_length")) - int(c.Uint("rope.dimension_count")), + kqNopeHeadDim: int(c.Uint("attention.key_length")) - int(c.Uint("rope.dimension_count")), + + routedScalingFactor: c.Float("expert_weights_scale"), + originalContextLength: int(c.Uint("rope.scaling.original_context_length")), + + kqScale: kqScale, + }, + } + + m.Cache = kvcache.NewCausalCache(m.Shift) + return &m, nil +} + +func (m Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { + return fast.RoPE(ctx, key, shift, m.qkRopeHeadDim, m.ropeBase, 1./m.ropeScale, m.RoPEOptions()...), nil +} + +func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { + positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions)) + + hiddenStates := m.TokenEmbedding.Forward(ctx, batch.Inputs) + + for i, layer := range m.Layers { + m.Cache.SetLayer(i) + + var outputs ml.Tensor + if i == len(m.Layers)-1 { + outputs = batch.Outputs + } + + hiddenStates = layer.Forward(ctx, hiddenStates, positions, outputs, m.Cache, m.Options) + } + + hiddenStates = m.OutputNorm.Forward(ctx, hiddenStates, m.eps) + return m.Output.Forward(ctx, hiddenStates), nil +} + +func init() { + model.Register("deepseek2", New) +} diff --git a/model/models/models.go b/model/models/models.go index cc9980789..0cda615af 100644 --- a/model/models/models.go +++ b/model/models/models.go @@ -2,6 +2,7 @@ package models import ( _ "github.com/ollama/ollama/model/models/bert" + _ "github.com/ollama/ollama/model/models/deepseek2" _ "github.com/ollama/ollama/model/models/gemma2" _ "github.com/ollama/ollama/model/models/gemma3" _ "github.com/ollama/ollama/model/models/gemma3n"