Compare commits

..

15 Commits

Author SHA1 Message Date
Patrick Devine
d3e0a0dee4 model: ministral w/ llama4 scaling (#13292)
This change:

* fixes rope scaling in the mistral converter
* updates ministral to include llama4 scaling
* includes a new ministral parser for parsing reasoning and tool calling

---------

Co-authored-by: jmorganca <jmorganca@gmail.com>
2025-12-01 23:20:14 -08:00
Daniel Hiltgen
554172759c win: warn if ggml-base detected in PATH (#13289)
If the user has somehow installed another GGML based app which places a
ggml-base lib somewhere in their PATH, we can experience runtime problems
due to incompatibilities.  This change adds a warning message if we detect
a ggml-base outside of our install location to aid in troubleshooting.
2025-12-01 15:36:47 -08:00
Bruce MacDonald
5b6a8e6001 api/client: handle non-json streaming errors (#13007)
While processing the response stream during a chat or generation if an error is occurred it is parsed and returned to the user. The issue with the existing code is that this assumed the response would be valid JSON, which is not a safe assumption and caused cryptic error messages to be displayed due to parsing failures:
`invalid character 'i' looking for beginning of value`

This change updates the stream function to return the raw error string if it cant be parsed as JSON. This should help with debugging issues by making sure the actual error reaches the user.
2025-12-01 15:10:16 -08:00
Daniel Hiltgen
467bbc0dd5 jetpack: require exact match or skip cuda_jetpack* (#13288)
The cuda_jetpack libs will enumerate discrete GPUs on SBSA systems
which leads to runtime failures of missing kernels.  This fix
requires an exact match to enable jetpacks instead of relying on
enumeration to filter out supported libraries.
2025-12-01 12:48:16 -08:00
Jeffrey Morgan
6d9f9323c5 .gitattributes: add app/webview to linguist-vendored (#13274) 2025-11-29 23:46:10 -05:00
Ondrej Kokes
0c2489605d docs: fix output formatting in faq.mdx (#13231)
There were a few Markdown typos in one FAQ answer. It now renders as a proper ascii table.
2025-11-28 19:19:21 -05:00
EntropyYue
8b1b89a984 docs: remove deprecated parameters (#13237) 2025-11-26 11:03:09 +09:00
Eva H
47e272c35a app/cmd: update ollama help to navigate to ollama doc instead of github page (#13174) 2025-11-20 16:30:35 -05:00
Jeffrey Morgan
417a81fda3 app: open app instead of always navigating to / on connect (#13164) 2025-11-20 12:59:17 -08:00
Daniel Hiltgen
dba62ff3a5 discovery: fix cuda overlap case (#13176)
Recent refactoring introduced a regression for filtering cuda overlap to favor newest supported version.
2025-11-20 12:15:37 -08:00
Grace
d70e935526 Parser for Cogito v2 (#13145) 2025-11-19 17:21:07 -08:00
Michael Yang
5c1063df7f deepseek2: upgrade to run v3+ models (#13166)
the check for mla omits v3 and r1 which should not return unsupported.
instead check the tokenizer for compatibility
2025-11-19 17:05:39 -08:00
Jesse Gross
cb485b2019 kvcache: Run tests both with and without PermutedV
The causal cache can store data differently depending on what is
best for the backend. We should run tests both ways.
2025-11-19 16:45:30 -08:00
nicole pardal
b2af50960f nomic-embed: nomic-embed-text defaulted to ollama runner (#13144) 2025-11-19 13:03:44 -08:00
Michael Yang
eac5b8bfbd chore: mark vulkan shaders as vendored files 2025-11-19 12:01:23 -08:00
31 changed files with 2043 additions and 600 deletions

4
.gitattributes vendored
View File

@@ -15,8 +15,12 @@ ml/backend/**/*.cu linguist-vendored
ml/backend/**/*.cuh linguist-vendored
ml/backend/**/*.m linguist-vendored
ml/backend/**/*.metal linguist-vendored
ml/backend/**/*.comp linguist-vendored
ml/backend/**/*.glsl linguist-vendored
ml/backend/**/CMakeLists.txt linguist-vendored
app/webview linguist-vendored
llama/build-info.cpp linguist-generated
ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-embed.s linguist-generated

View File

@@ -11,7 +11,6 @@ linters:
- errorlint
- exptostd
- gocheckcompilerdirectives
- gocritic
- govet
- ineffassign
- intrange

View File

@@ -226,7 +226,14 @@ func (c *Client) stream(ctx context.Context, method, path string, data any, fn f
bts := scanner.Bytes()
if err := json.Unmarshal(bts, &errorResponse); err != nil {
return fmt.Errorf("unmarshal: %w", err)
if response.StatusCode >= http.StatusBadRequest {
return StatusError{
StatusCode: response.StatusCode,
Status: response.Status,
ErrorMessage: string(bts),
}
}
return errors.New(string(bts))
}
if response.StatusCode == http.StatusUnauthorized {

View File

@@ -55,6 +55,7 @@ func TestClientFromEnvironment(t *testing.T) {
type testError struct {
message string
statusCode int
raw bool // if true, write message as-is instead of JSON encoding
}
func (e testError) Error() string {
@@ -111,6 +112,20 @@ func TestClientStream(t *testing.T) {
},
},
},
{
name: "plain text error response",
responses: []any{
"internal server error",
},
wantErr: "internal server error",
},
{
name: "HTML error page",
responses: []any{
"<html><body>404 Not Found</body></html>",
},
wantErr: "404 Not Found",
},
}
for _, tc := range testCases {
@@ -135,6 +150,12 @@ func TestClientStream(t *testing.T) {
return
}
if str, ok := resp.(string); ok {
fmt.Fprintln(w, str)
flusher.Flush()
continue
}
if err := json.NewEncoder(w).Encode(resp); err != nil {
t.Fatalf("failed to encode response: %v", err)
}
@@ -173,9 +194,10 @@ func TestClientStream(t *testing.T) {
func TestClientDo(t *testing.T) {
testCases := []struct {
name string
response any
wantErr string
name string
response any
wantErr string
wantStatusCode int
}{
{
name: "immediate error response",
@@ -183,7 +205,8 @@ func TestClientDo(t *testing.T) {
message: "test error message",
statusCode: http.StatusBadRequest,
},
wantErr: "test error message",
wantErr: "test error message",
wantStatusCode: http.StatusBadRequest,
},
{
name: "server error response",
@@ -191,7 +214,8 @@ func TestClientDo(t *testing.T) {
message: "internal error",
statusCode: http.StatusInternalServerError,
},
wantErr: "internal error",
wantErr: "internal error",
wantStatusCode: http.StatusInternalServerError,
},
{
name: "successful response",
@@ -203,6 +227,26 @@ func TestClientDo(t *testing.T) {
Success: true,
},
},
{
name: "plain text error response",
response: testError{
message: "internal server error",
statusCode: http.StatusInternalServerError,
raw: true,
},
wantErr: "internal server error",
wantStatusCode: http.StatusInternalServerError,
},
{
name: "HTML error page",
response: testError{
message: "<html><body>404 Not Found</body></html>",
statusCode: http.StatusNotFound,
raw: true,
},
wantErr: "<html><body>404 Not Found</body></html>",
wantStatusCode: http.StatusNotFound,
},
}
for _, tc := range testCases {
@@ -210,11 +254,16 @@ func TestClientDo(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if errResp, ok := tc.response.(testError); ok {
w.WriteHeader(errResp.statusCode)
err := json.NewEncoder(w).Encode(map[string]string{
"error": errResp.message,
})
if err != nil {
t.Fatal("failed to encode error response:", err)
if !errResp.raw {
err := json.NewEncoder(w).Encode(map[string]string{
"error": errResp.message,
})
if err != nil {
t.Fatal("failed to encode error response:", err)
}
} else {
// Write raw message (simulates non-JSON error responses)
fmt.Fprint(w, errResp.message)
}
return
}
@@ -241,6 +290,15 @@ func TestClientDo(t *testing.T) {
if err.Error() != tc.wantErr {
t.Errorf("error message mismatch: got %q, want %q", err.Error(), tc.wantErr)
}
if tc.wantStatusCode != 0 {
if statusErr, ok := err.(StatusError); ok {
if statusErr.StatusCode != tc.wantStatusCode {
t.Errorf("status code mismatch: got %d, want %d", statusErr.StatusCode, tc.wantStatusCode)
}
} else {
t.Errorf("expected StatusError, got %T", err)
}
}
return
}

View File

@@ -397,8 +397,8 @@ func checkUserLoggedIn(uiServerPort int) bool {
// handleConnectURLScheme fetches the connect URL and opens it in the browser
func handleConnectURLScheme() {
if checkUserLoggedIn(uiServerPort) {
slog.Info("user is already logged in, opening settings instead")
sendUIRequestMessage("/")
slog.Info("user is already logged in, opening app instead")
showWindow(wv.webview.Window())
return
}
@@ -466,6 +466,8 @@ func handleURLSchemeInCurrentInstance(urlSchemeRequest string) {
if isConnect {
handleConnectURLScheme()
} else {
sendUIRequestMessage("/")
if wv.webview != nil {
showWindow(wv.webview.Window())
}
}
}

View File

@@ -24,27 +24,14 @@ bool firstTimeRun,startHidden; // Set in run before initialization
for (NSURL *url in urls) {
if ([url.scheme isEqualToString:@"ollama"]) {
NSString *path = url.path;
if (!path || [path isEqualToString:@""]) {
// For URLs like ollama://settings (without triple slash),
// the "settings" part is parsed as the host, not the path.
// We need to convert it to a path by prepending "/"
if (url.host && ![url.host isEqualToString:@""]) {
path = [@"/" stringByAppendingString:url.host];
} else {
path = @"/";
}
}
if ([path isEqualToString:@"/connect"] || [url.host isEqualToString:@"connect"]) {
if (path && ([path isEqualToString:@"/connect"] || [url.host isEqualToString:@"connect"])) {
// Special case: handle connect by opening browser instead of app
handleConnectURL();
} else {
// Set app to be active and visible
[NSApp setActivationPolicy:NSApplicationActivationPolicyRegular];
[NSApp activateIgnoringOtherApps:YES];
// Open the path with the UI
[self uiRequest:path];
}
break;
@@ -260,7 +247,7 @@ bool firstTimeRun,startHidden; // Set in run before initialization
}
- (void)openHelp:(id)sender {
NSURL *url = [NSURL URLWithString:@"https://github.com/ollama/ollama/tree/main/docs"];
NSURL *url = [NSURL URLWithString:@"https://docs.ollama.com/"];
[[NSWorkspace sharedWorkspace] openURL:url];
}

View File

@@ -147,7 +147,9 @@ func handleURLSchemeRequest(urlScheme string) {
if isConnect {
handleConnectURLScheme()
} else {
sendUIRequestMessage("/")
if wv.webview != nil {
showWindow(wv.webview.Window())
}
}
}

View File

@@ -25,7 +25,7 @@ declare module "@/gotypes" {
}
Model.prototype.isCloud = function (): boolean {
return this.model.endsWith("cloud") || this.model === "gemini-3-pro-preview";
return this.model.endsWith("cloud");
};
// Helper function to convert Uint8Array to base64

View File

@@ -14,8 +14,8 @@ describe("Model merging logic", () => {
const merged = mergeModels(localModels);
// First verify cloud models are first and in FEATURED_MODELS order
const cloudModels = FEATURED_MODELS.filter(
(m: string) => m.endsWith("cloud") || m === "gemini-3-pro-preview",
const cloudModels = FEATURED_MODELS.filter((m: string) =>
m.endsWith("cloud"),
);
for (let i = 0; i < cloudModels.length; i++) {
expect(merged[i].model).toBe(cloudModels[i]);
@@ -24,7 +24,7 @@ describe("Model merging logic", () => {
// Then verify non-cloud featured models are next and in FEATURED_MODELS order
const nonCloudFeatured = FEATURED_MODELS.filter(
(m: string) => !m.endsWith("cloud") && m !== "gemini-3-pro-preview",
(m: string) => !m.endsWith("cloud"),
);
for (let i = 0; i < nonCloudFeatured.length; i++) {
const model = merged[i + cloudModels.length];
@@ -54,9 +54,9 @@ describe("Model merging logic", () => {
const cloudModels = merged.filter((m) => m.isCloud());
expect(cloudModels.length).toBe(0);
// Should have non-cloud featured models (excluding gemini-3-pro-preview which is treated as cloud)
// Should have non-cloud featured models
const nonCloudFeatured = FEATURED_MODELS.filter(
(m) => !m.endsWith("cloud") && m !== "gemini-3-pro-preview",
(m) => !m.endsWith("cloud"),
);
for (let i = 0; i < nonCloudFeatured.length; i++) {
const model = merged[i];
@@ -74,9 +74,7 @@ describe("Model merging logic", () => {
const merged = mergeModels([]);
// First verify cloud models are first and in FEATURED_MODELS order
const cloudModels = FEATURED_MODELS.filter(
(m) => m.endsWith("cloud") || m === "gemini-3-pro-preview",
);
const cloudModels = FEATURED_MODELS.filter((m) => m.endsWith("cloud"));
for (let i = 0; i < cloudModels.length; i++) {
expect(merged[i].model).toBe(cloudModels[i]);
expect(merged[i].isCloud()).toBe(true);
@@ -84,7 +82,7 @@ describe("Model merging logic", () => {
// Then verify non-cloud featured models are next and in FEATURED_MODELS order
const nonCloudFeatured = FEATURED_MODELS.filter(
(m) => !m.endsWith("cloud") && m !== "gemini-3-pro-preview",
(m) => !m.endsWith("cloud"),
);
for (let i = 0; i < nonCloudFeatured.length; i++) {
const model = merged[i + cloudModels.length];
@@ -106,9 +104,7 @@ describe("Model merging logic", () => {
const merged = mergeModels(localModels);
// First verify cloud models are first and in FEATURED_MODELS order
const cloudModels = FEATURED_MODELS.filter(
(m) => m.endsWith("cloud") || m === "gemini-3-pro-preview",
);
const cloudModels = FEATURED_MODELS.filter((m) => m.endsWith("cloud"));
for (let i = 0; i < cloudModels.length; i++) {
expect(merged[i].model).toBe(cloudModels[i]);
expect(merged[i].isCloud()).toBe(true);
@@ -116,7 +112,7 @@ describe("Model merging logic", () => {
// Then verify non-cloud featured models are next and in FEATURED_MODELS order
const nonCloudFeatured = FEATURED_MODELS.filter(
(m) => !m.endsWith("cloud") && m !== "gemini-3-pro-preview",
(m) => !m.endsWith("cloud"),
);
for (let i = 0; i < nonCloudFeatured.length; i++) {
const model = merged[i + cloudModels.length];

View File

@@ -4,7 +4,6 @@ import { Model } from "@/gotypes";
export const FEATURED_MODELS = [
"gpt-oss:120b-cloud",
"gpt-oss:20b-cloud",
"gemini-3-pro-preview",
"deepseek-v3.1:671b-cloud",
"qwen3-coder:480b-cloud",
"qwen3-vl:235b-cloud",
@@ -41,9 +40,7 @@ export function mergeModels(
const cloudModels = [...allModels.filter((m) => m.isCloud())];
// Add any cloud models from FEATURED_MODELS that aren't in local models
FEATURED_MODELS.filter(
(f) => f.endsWith("cloud") || f === "gemini-3-pro-preview",
).forEach((cloudModel) => {
FEATURED_MODELS.filter((f) => f.endsWith("cloud")).forEach((cloudModel) => {
if (!cloudModels.some((m) => m.model === cloudModel)) {
cloudModels.push(new Model({ model: cloudModel }));
}
@@ -51,7 +48,7 @@ export function mergeModels(
// 2. Get other featured models (non-cloud)
const featuredModels = FEATURED_MODELS.filter(
(f) => !f.endsWith("cloud") && f !== "gemini-3-pro-preview",
(f) => !f.endsWith("cloud"),
).map((model) => {
// Check if this model exists in local models
const localMatch = allModels.find(

View File

@@ -1430,7 +1430,7 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
latest.Summary()
}
return &api.Message{Role: role, Content: fullResponse.String()}, nil
return &api.Message{Role: role, Thinking: thinkingContent.String(), Content: fullResponse.String()}, nil
}
func generate(cmd *cobra.Command, opts runOptions) error {

View File

@@ -29,6 +29,15 @@ type mistral3Model struct {
SlidingWindow *uint32 `json:"sliding_window"`
HiddenAct string `json:"hidden_act"`
VocabSize uint32 `json:"vocab_size"`
RopeParameters struct {
BetaFast float32 `json:"beta_fast"`
BetaSlow float32 `json:"beta_slow"`
Factor float32 `json:"factor"`
ScalingBeta float32 `json:"llama_4_scaling_beta"`
OrigMaxPositionEmbeddings uint32 `json:"original_max_position_embeddings"`
RopeType string `json:"rope_type"`
RopeTheta float32 `json:"rope_theta"`
} `json:"rope_parameters"`
} `json:"text_config"`
VisionModel struct {
NumAttentionHeads uint32 `json:"num_attention_heads"`
@@ -61,8 +70,13 @@ func (p *mistral3Model) KV(t *Tokenizer) ggml.KV {
kv["mistral3.attention.layer_norm_rms_epsilon"] = p.TextModel.RMSNormEPS
kv["mistral3.attention.key_length"] = p.TextModel.HeadDim
kv["mistral3.attention.value_length"] = p.TextModel.HeadDim
kv["mistral3.rope.dimension_count"] = p.TextModel.HiddenSize / p.TextModel.NumHiddenLayers
kv["mistral3.rope.freq_base"] = p.TextModel.RopeTheta
kv["mistral3.rope.dimension_count"] = cmp.Or(p.TextModel.HeadDim, p.TextModel.HiddenSize/p.TextModel.NumAttentionHeads)
kv["mistral3.rope.freq_base"] = cmp.Or(p.TextModel.RopeTheta, p.TextModel.RopeParameters.RopeTheta)
if p.TextModel.RopeParameters.OrigMaxPositionEmbeddings > 0 {
kv["mistral3.rope.scaling.original_context_length"] = p.TextModel.RopeParameters.OrigMaxPositionEmbeddings
kv["mistral3.rope.scaling_beta"] = p.TextModel.RopeParameters.ScalingBeta
}
// Vision configuration
kv["mistral3.vision.block_count"] = p.VisionModel.NumHiddenLayers

View File

@@ -65,6 +65,7 @@ func GPUDevices(ctx context.Context, runners []ml.FilteredRunnerDiscovery) []ml.
}
slog.Info("discovering available GPUs...")
detectIncompatibleLibraries()
// Warn if any user-overrides are set which could lead to incorrect GPU discovery
overrideWarnings()
@@ -98,6 +99,9 @@ func GPUDevices(ctx context.Context, runners []ml.FilteredRunnerDiscovery) []ml.
continue
} else if jetpack != "" && filepath.Base(dir) != "cuda_"+jetpack {
continue
} else if jetpack == "" && strings.Contains(filepath.Base(dir), "cuda_jetpack") {
slog.Debug("jetpack not detected (set JETSON_JETPACK or OLLAMA_LLM_LIBRARY to override), skipping", "libDir", dir)
continue
} else if !envconfig.EnableVulkan() && strings.Contains(filepath.Base(dir), "vulkan") {
slog.Info("experimental Vulkan support disabled. To enable, set OLLAMA_VULKAN=1")
continue
@@ -125,10 +129,20 @@ func GPUDevices(ctx context.Context, runners []ml.FilteredRunnerDiscovery) []ml.
supportedMu := sync.Mutex{}
supported := make(map[string]map[string]map[string]int) // [Library][libDir][ID] = pre-deletion devices index
for i := range devices {
libDir := devices[i].LibraryPath[len(devices[i].LibraryPath)-1]
if !devices[i].NeedsInitValidation() {
// No need to validate, add to the supported map
supportedMu.Lock()
if _, ok := supported[devices[i].Library]; !ok {
supported[devices[i].Library] = make(map[string]map[string]int)
}
if _, ok := supported[devices[i].Library][libDir]; !ok {
supported[devices[i].Library][libDir] = make(map[string]int)
}
supported[devices[i].Library][libDir][devices[i].ID] = i
supportedMu.Unlock()
continue
}
libDir := devices[i].LibraryPath[len(devices[i].LibraryPath)-1]
slog.Debug("verifying if device is supported", "library", libDir, "description", devices[i].Description, "compute", devices[i].Compute(), "id", devices[i].ID, "pci_id", devices[i].PCIID)
wg.Add(1)
go func(i int) {
@@ -474,3 +488,16 @@ func overrideWarnings() {
slog.Warn("if GPUs are not correctly discovered, unset and try again")
}
}
func detectIncompatibleLibraries() {
if runtime.GOOS != "windows" {
return
}
basePath, err := exec.LookPath("ggml-base.dll")
if err != nil || basePath == "" {
return
}
if !strings.HasPrefix(basePath, ml.LibOllamaPath) {
slog.Warn("potentially incompatible library detected in PATH", "location", basePath)
}
}

View File

@@ -57,8 +57,13 @@ ollama ps
```
<Info>
**Output**: ``` NAME ID SIZE PROCESSOR UNTIL llama3:70b bcfb190ca3a7 42 GB
100% GPU 4 minutes from now ```
**Output**:
```
NAME ID SIZE PROCESSOR UNTIL
llama3:70b bcfb190ca3a7 42 GB 100% GPU 4 minutes from now
```
</Info>
The `Processor` column will show which memory the model was loaded in to:
@@ -385,4 +390,4 @@ Ollama for Windows and macOS register as a login item during installation. You
- In `Task Manager` go to the `Startup apps` tab, search for `ollama` then click `Disable`
**MacOS**
- Open `Settings` and search for "Login Items", find the `Ollama` entry under "Allow in the Background`, then click the slider to disable.
- Open `Settings` and search for "Login Items", find the `Ollama` entry under "Allow in the Background`, then click the slider to disable.

View File

@@ -149,9 +149,6 @@ PARAMETER <parameter> <parametervalue>
| Parameter | Description | Value Type | Example Usage |
| -------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ---------- | -------------------- |
| mirostat | Enable Mirostat sampling for controlling perplexity. (default: 0, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0) | int | mirostat 0 |
| mirostat_eta | Influences how quickly the algorithm responds to feedback from the generated text. A lower learning rate will result in slower adjustments, while a higher learning rate will make the algorithm more responsive. (Default: 0.1) | float | mirostat_eta 0.1 |
| mirostat_tau | Controls the balance between coherence and diversity of the output. A lower value will result in more focused and coherent text. (Default: 5.0) | float | mirostat_tau 5.0 |
| num_ctx | Sets the size of the context window used to generate the next token. (Default: 2048) | int | num_ctx 4096 |
| repeat_last_n | Sets how far back for the model to look back to prevent repetition. (Default: 64, 0 = disabled, -1 = num_ctx) | int | repeat_last_n 64 |
| repeat_penalty | Sets how strongly to penalize repetitions. A higher value (e.g., 1.5) will penalize repetitions more strongly, while a lower value (e.g., 0.9) will be more lenient. (Default: 1.1) | float | repeat_penalty 1.1 |

View File

@@ -251,6 +251,7 @@ func (kv KV) OllamaEngineRequired() bool {
"qwen3vl", "qwen3vlmoe",
"deepseekocr",
"deepseek2",
"nomic-bert",
}, kv.Architecture())
}

View File

@@ -388,9 +388,9 @@ func NewFunctionNameMap() *FunctionNameMap {
}
}
// Init initializes the handler with tools and optional last message
// Init initializes the handler with tools, optional last message, and think value
// Implements the Parser interface
func (h *HarmonyMessageHandler) Init(tools []api.Tool, lastMessage *api.Message) []api.Tool {
func (h *HarmonyMessageHandler) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
// Initialize the harmony parser
if h.HarmonyParser == nil {
h.HarmonyParser = &HarmonyParser{

File diff suppressed because it is too large Load Diff

View File

@@ -236,11 +236,6 @@ type Model struct {
}
func New(c fs.Config) (model.Model, error) {
if c.Uint("attention.key_length_mla") == 0 {
// non-MLA models aren't yet supported
return nil, model.ErrUnsupportedModel
}
layers := make([]Layer, c.Uint("block_count"))
firstDenseLayerIndex := int(c.Uint("leading_dense_block_count"))
@@ -259,6 +254,30 @@ func New(c fs.Config) (model.Model, error) {
keyLength := int(cmp.Or(c.Uint("attention.key_length_mla"), c.Uint("attention.key_length")))
valueLength := int(cmp.Or(c.Uint("attention.value_length_mla"), c.Uint("attention.value_length")))
var pre []string
switch c.String("tokenizer.ggml.pre") {
case "deepseek-v3":
pre = []string{
// Split regex into multiple parts (according to DeepSeek3's regex)
"\\p{N}{1,3}",
`[一-龥぀-ゟ゠-ヿ]+`,
"[!\"#$%&'()*+,\\-./:;<=>?@\\[\\\\\\]^_`{|}~][A-Za-z]+|[^\r\n\\p{L}\\p{P}\\p{S}]?[\\p{L}\\p{M}]+| ?[\\p{P}\\p{S}]+[\r\n]*|\\s*[\r\n]+|\\s+(?!\\S)|\\s+",
}
case "deepseek-llm":
// TODO: these models haven't been vetted so skip for now
// pre = []string{
// "[\r\n]",
// "\\s?[A-Za-zµÀ-ÖØ-öø-ƺƼ-ƿDŽ-ʓʕ-ʯͰ-ͳͶͷͻ-ͽͿΆΈ-ΊΌΎ-ΡΣ-ϵϷ-ҁҊ-ԯԱ-ՖႠ-ჅᎠ-Ᏽᏸ-ᏽᲐ-ᲺᲽ-Ჿᴀ-ᴫᵫ-ᵷᵹ-ᶚḀ-ἕἘ-Ἕἠ-ὅὈ-Ὅὐ-ὗὙὛὝὟ-ώᾀ-ᾴᾶ-ᾼιῂ-ῄῆ-ῌῐ-ΐῖ-Ίῠ-Ῥῲ-ῴῶ-ῼℂℇℊ--ℝℤΩℨK--ℴℹℼ-ℿⅅ-ⅉⅎↃↄⰀ-ⱻⱾ-ⳤⳫ-ⳮⳲⳳꙀ-ꙭꚀ-ꚛꜢ-ꝯꝱ-ꞇꞋ-ꞎꭰ-ꮿff-stﬓ-ﬗA--z𐐀-𐑏𐒰-𐓓𐓘-𐓻𐲀-𐲲𐳀-𐳲𑢠-𑣟𞤀-𞥃]+",
// "\\s?[!-/:-~---‟ -。]+",
// "\\s+$",
// "[一-龥ࠀ-一가-퟿]+",
// "[0-9]",
// }
fallthrough
default:
return nil, model.ErrUnsupportedTokenizer
}
m := Model{
BytePairEncoding: model.NewBytePairEncoding(
&model.Vocabulary{
@@ -273,10 +292,7 @@ func New(c fs.Config) (model.Model, error) {
c.Ints("tokenizer.ggml.eos_token_ids")...,
),
},
// Split regex into multiple parts (according to DeepSeek3's regex)
"\\p{N}{1,3}",
`[一-龥぀-ゟ゠-ヿ]+`,
"[!\"#$%&'()*+,\\-./:;<=>?@\\[\\\\\\]^_`{|}~][A-Za-z]+|[^\r\n\\p{L}\\p{P}\\p{S}]?[\\p{L}\\p{M}]+| ?[\\p{P}\\p{S}]+[\r\n]*|\\s*[\r\n]+|\\s+(?!\\S)|\\s+",
pre...,
),
Layers: layers,
Options: &Options{

View File

@@ -159,8 +159,9 @@ func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
positions := ctx.Input().FromInts(batch.Positions, len(batch.Positions))
positionsScale := m.getScale(ctx, batch.Positions)
return m.TextModel.Forward(ctx, batch.Inputs, positions, batch.Outputs, batch, m.Cache), nil
return m.TextModel.Forward(ctx, batch.Inputs, positions, positionsScale, batch.Outputs, batch, m.Cache), nil
}
func init() {

View File

@@ -16,6 +16,8 @@ type TextOptions struct {
hiddenSize, numHeads, numKVHeads int
headDim, ropeDim int
eps, ropeBase, ropeScale float32
ropeOrigPosEmbeddings int
ropeScalingBeta float32
}
type TextModel struct {
@@ -34,7 +36,7 @@ type SelfAttention struct {
Output *nn.Linear `gguf:"attn_output"`
}
func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor {
func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs, positionsScale ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor {
batchSize := hiddenState.Dim(1)
headDim := cmp.Or(opts.headDim, opts.hiddenSize/opts.numHeads)
@@ -49,6 +51,10 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
v := sa.Value.Forward(ctx, hiddenState)
v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
if opts.ropeOrigPosEmbeddings > 0 {
q = q.Mul(ctx, positionsScale)
}
kqv := nn.Attention(ctx, q, k, v, 1.0/math.Sqrt(float64(headDim)), cache)
kqv = kqv.Reshape(ctx, headDim*opts.numHeads, batchSize)
return sa.Output.Forward(ctx, kqv)
@@ -76,11 +82,11 @@ type Layer struct {
MLP *MLP
}
func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor {
func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, positionsScale, outputs ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor {
residual := hiddenState
hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
hiddenState = l.SelfAttention.Forward(ctx, hiddenState, positionIDs, cache, opts)
hiddenState = l.SelfAttention.Forward(ctx, hiddenState, positionIDs, positionsScale, cache, opts)
// In the final layer (outputs != nil), optimize by pruning to just the token positions
// we need logits for.
@@ -97,7 +103,7 @@ func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Ten
return hiddenState.Add(ctx, residual)
}
func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor, batch input.Batch, cache kvcache.Cache) ml.Tensor {
func (m *TextModel) Forward(ctx ml.Context, inputs, positions, positionsScale, outputs ml.Tensor, batch input.Batch, cache kvcache.Cache) ml.Tensor {
hiddenState := m.TokenEmbedding.Forward(ctx, inputs).Duplicate(ctx)
// image embeddings
@@ -114,25 +120,36 @@ func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor
lastLayerOutputs = outputs
}
hiddenState = layer.Forward(ctx, hiddenState, positions, lastLayerOutputs, cache, m.TextOptions)
hiddenState = layer.Forward(ctx, hiddenState, positions, positionsScale, lastLayerOutputs, cache, m.TextOptions)
}
hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
return m.Output.Forward(ctx, hiddenState)
}
func (m *TextModel) getScale(ctx ml.Context, positions []int32) ml.Tensor {
posScale := make([]float32, len(positions))
for n, pos := range positions {
interval := math.Floor(float64(pos) / float64(m.ropeOrigPosEmbeddings))
posScale[n] = float32(1.0 + float64(m.ropeScalingBeta)*math.Log(1.0+interval))
}
return ctx.Input().FromFloats(posScale, 1, 1, len(posScale))
}
func newTextModel(c fs.Config) *TextModel {
return &TextModel{
Layers: make([]Layer, c.Uint("block_count")),
TextOptions: &TextOptions{
hiddenSize: int(c.Uint("embedding_length")),
numHeads: int(c.Uint("attention.head_count")),
numKVHeads: int(c.Uint("attention.head_count_kv")),
headDim: int(c.Uint("attention.key_length")),
ropeDim: int(c.Uint("rope.dimension_count")),
eps: c.Float("attention.layer_norm_rms_epsilon"),
ropeBase: c.Float("rope.freq_base"),
ropeScale: c.Float("rope.scaling.factor", 1),
hiddenSize: int(c.Uint("embedding_length")),
numHeads: int(c.Uint("attention.head_count")),
numKVHeads: int(c.Uint("attention.head_count_kv")),
headDim: int(c.Uint("attention.key_length")),
ropeDim: int(c.Uint("rope.dimension_count")),
eps: c.Float("attention.layer_norm_rms_epsilon"),
ropeBase: c.Float("rope.freq_base"),
ropeScale: c.Float("rope.scaling.factor", 1),
ropeOrigPosEmbeddings: int(c.Uint("rope.scaling.original_context_length")),
ropeScalingBeta: c.Float("rope.scaling_beta"),
},
}
}

319
model/parsers/cogito.go Normal file
View File

@@ -0,0 +1,319 @@
package parsers
import (
"encoding/json"
"errors"
"log/slog"
"strings"
"unicode"
"github.com/ollama/ollama/api"
)
type CogitoParserState int
const (
CogitoCollectingThinking CogitoParserState = iota
CogitoCollectingContent
CogitoCollectingToolCalls
CogitoCollectingToolOutput
)
const (
cogitoThinkingCloseTag = "</think>"
cogitoToolCallsBeginTag = "<tool▁calls▁begin>"
cogitoToolCallsEndTag = "<tool▁calls▁end>"
cogitoToolCallBeginTag = "<tool▁call▁begin>"
cogitoToolCallEndTag = "<tool▁call▁end>"
cogitoToolSepTag = "<tool▁sep>"
cogitoToolOutputBeginTag = "<tool▁output▁begin>"
cogitoToolOutputEndTag = "<tool▁output▁end>"
cogitoToolOutputsBeginTag = "<tool▁outputs▁begin>"
cogitoToolOutputsEndTag = "<tool▁outputs▁end>"
)
type CogitoParser struct {
state CogitoParserState
buffer strings.Builder
}
func (p *CogitoParser) HasToolSupport() bool {
return true
}
func (p *CogitoParser) HasThinkingSupport() bool {
return true
}
func (p *CogitoParser) setInitialState(lastMessage *api.Message, tools []api.Tool, thinkValue *api.ThinkValue) {
prefill := lastMessage != nil && lastMessage.Role == "assistant"
// Check both model capability AND request preference
thinkingEnabled := thinkValue != nil && thinkValue.Bool()
// thinkingEnabled should be set to false for tools
if !thinkingEnabled {
p.state = CogitoCollectingContent
return
}
if prefill && lastMessage.Content != "" {
p.state = CogitoCollectingContent
return
}
// Note: for cogito, if there are tools, then we don't want to be thinking
if len(tools) > 0 {
p.state = CogitoCollectingContent
return
}
p.state = CogitoCollectingThinking
}
func (p *CogitoParser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
p.setInitialState(lastMessage, tools, thinkValue)
return tools
}
type cogitoEvent interface {
isCogitoEvent()
}
type cogitoEventThinkingContent struct {
content string
}
type cogitoEventContent struct {
content string
}
type cogitoEventToolCall struct {
toolCall api.ToolCall
}
func (cogitoEventThinkingContent) isCogitoEvent() {}
func (cogitoEventContent) isCogitoEvent() {}
func (cogitoEventToolCall) isCogitoEvent() {}
func (p *CogitoParser) Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error) {
p.buffer.WriteString(s)
events := p.parseEvents()
var toolCalls []api.ToolCall
var contentSb strings.Builder
var thinkingSb strings.Builder
for _, event := range events {
switch event := event.(type) {
case cogitoEventToolCall:
toolCalls = append(toolCalls, event.toolCall)
case cogitoEventThinkingContent:
thinkingSb.WriteString(event.content)
case cogitoEventContent:
contentSb.WriteString(event.content)
}
}
return contentSb.String(), thinkingSb.String(), toolCalls, nil
}
func (p *CogitoParser) parseEvents() []cogitoEvent {
var all []cogitoEvent
keepLooping := true
for keepLooping {
var events []cogitoEvent
events, keepLooping = p.eat()
if len(events) > 0 {
all = append(all, events...)
}
}
return all
}
func (p *CogitoParser) eat() ([]cogitoEvent, bool) {
var events []cogitoEvent
bufStr := p.buffer.String()
if bufStr == "" {
return events, false
}
switch p.state {
case CogitoCollectingThinking:
if strings.Contains(bufStr, cogitoThinkingCloseTag) { // thinking[</think>] -> content
split := strings.SplitN(bufStr, cogitoThinkingCloseTag, 2)
thinking := split[0]
thinking = strings.TrimRightFunc(thinking, unicode.IsSpace)
remaining := split[1]
remaining = strings.TrimLeftFunc(remaining, unicode.IsSpace)
p.buffer.Reset()
p.buffer.WriteString(remaining)
p.state = CogitoCollectingContent
if len(thinking) > 0 {
events = append(events, cogitoEventThinkingContent{content: thinking})
}
return events, true
} else if overlapLen := overlap(bufStr, cogitoThinkingCloseTag); overlapLen > 0 { // partial </think>
beforePartialTag := bufStr[:len(bufStr)-overlapLen]
trailingLen := trailingWhitespaceLen(beforePartialTag)
ambiguousStart := len(beforePartialTag) - trailingLen
unambiguous := bufStr[:ambiguousStart]
ambiguous := bufStr[ambiguousStart:]
p.buffer.Reset()
p.buffer.WriteString(ambiguous)
if len(unambiguous) > 0 {
events = append(events, cogitoEventThinkingContent{content: unambiguous})
}
return events, false
} else { // otherwise its thinking content
whitespaceLen := trailingWhitespaceLen(bufStr)
ambiguousStart := len(bufStr) - whitespaceLen
unambiguous := bufStr[:ambiguousStart]
ambiguous := bufStr[ambiguousStart:]
p.buffer.Reset()
p.buffer.WriteString(ambiguous)
if len(unambiguous) > 0 {
events = append(events, cogitoEventThinkingContent{content: unambiguous})
}
return events, false
}
case CogitoCollectingContent:
switch {
case strings.Contains(bufStr, cogitoToolCallsBeginTag): // content[<tool▁calls▁begin>] -> tool calls
split := strings.SplitN(bufStr, cogitoToolCallsBeginTag, 2)
contentBefore := strings.TrimRightFunc(split[0], unicode.IsSpace)
remaining := split[1]
p.buffer.Reset()
p.buffer.WriteString(remaining)
p.state = CogitoCollectingToolCalls
if len(contentBefore) > 0 {
events = append(events, cogitoEventContent{content: contentBefore})
}
return events, true
case strings.Contains(bufStr, cogitoToolOutputsBeginTag): // content[<tool▁outputs▁begin>] -> tool outputs
split := strings.SplitN(bufStr, cogitoToolOutputsBeginTag, 2)
contentBefore := strings.TrimRightFunc(split[0], unicode.IsSpace)
remaining := split[1]
p.buffer.Reset()
p.buffer.WriteString(remaining)
p.state = CogitoCollectingToolOutput
if len(contentBefore) > 0 {
events = append(events, cogitoEventContent{content: contentBefore})
}
return events, true
default: // otherwise its content
p.buffer.Reset()
if len(bufStr) > 0 {
events = append(events, cogitoEventContent{content: bufStr})
}
return events, false
}
case CogitoCollectingToolCalls:
if idx := strings.Index(bufStr, cogitoToolCallBeginTag); idx != -1 {
startIdx := idx + len(cogitoToolCallBeginTag)
if endIdx := strings.Index(bufStr[startIdx:], cogitoToolCallEndTag); endIdx != -1 {
toolCallContent := bufStr[startIdx : startIdx+endIdx]
if toolCall, err := p.parseToolCallContent(toolCallContent); err == nil {
remaining := bufStr[startIdx+endIdx+len(cogitoToolCallEndTag):]
remaining = strings.TrimLeftFunc(remaining, unicode.IsSpace)
p.buffer.Reset()
p.buffer.WriteString(remaining)
events = append(events, cogitoEventToolCall{toolCall: toolCall})
return events, true
} else {
slog.Warn("cogito tool call parsing failed", "error", err)
}
}
}
if idx := strings.Index(bufStr, cogitoToolCallsEndTag); idx != -1 {
remaining := bufStr[idx+len(cogitoToolCallsEndTag):]
remaining = strings.TrimLeftFunc(remaining, unicode.IsSpace)
p.buffer.Reset()
p.buffer.WriteString(remaining)
p.state = CogitoCollectingContent
return events, true
}
return events, false
case CogitoCollectingToolOutput:
if idx := strings.Index(bufStr, cogitoToolOutputBeginTag); idx != -1 {
startIdx := idx + len(cogitoToolOutputBeginTag)
if endIdx := strings.Index(bufStr[startIdx:], cogitoToolOutputEndTag); endIdx != -1 {
remaining := bufStr[startIdx+endIdx+len(cogitoToolOutputEndTag):]
remaining = strings.TrimLeftFunc(remaining, unicode.IsSpace)
p.buffer.Reset()
p.buffer.WriteString(remaining)
return events, true
}
}
if idx := strings.Index(bufStr, cogitoToolOutputsEndTag); idx != -1 {
remaining := bufStr[idx+len(cogitoToolOutputsEndTag):]
remaining = strings.TrimLeftFunc(remaining, unicode.IsSpace)
p.buffer.Reset()
p.buffer.WriteString(remaining)
p.state = CogitoCollectingContent
return events, true
}
return events, false
}
return events, false
}
func (p *CogitoParser) parseToolCallContent(content string) (api.ToolCall, error) {
// Expected format: function<tool▁sep>tool_name\n```json\n{args}\n```
parts := strings.SplitN(content, cogitoToolSepTag, 2)
if len(parts) < 2 {
return api.ToolCall{}, errors.New("invalid format")
}
nameAndArgs := parts[1]
jsonStart := strings.Index(nameAndArgs, "\n```json\n")
if jsonStart == -1 {
return api.ToolCall{}, errors.New("invalid format")
}
toolName := strings.TrimSpace(nameAndArgs[:jsonStart])
jsonContent := nameAndArgs[jsonStart+len("\n```json\n"):]
jsonEnd := strings.Index(jsonContent, "\n```")
if jsonEnd == -1 {
return api.ToolCall{}, errors.New("invalid format")
}
argsJSON := jsonContent[:jsonEnd]
var args api.ToolCallFunctionArguments
if err := json.Unmarshal([]byte(argsJSON), &args); err != nil {
return api.ToolCall{}, err
}
return api.ToolCall{
Function: api.ToolCallFunction{
Name: toolName,
Arguments: args,
},
}, nil
}

View File

@@ -0,0 +1,565 @@
package parsers
import (
"strings"
"testing"
"github.com/google/go-cmp/cmp"
"github.com/ollama/ollama/api"
)
func TestCogitoParser(t *testing.T) {
tests := []struct {
name string
input string
expectedContent string
expectedThinking string
expectedToolCalls []api.ToolCall
tools []api.Tool
lastMessage *api.Message
}{
{
name: "simple_content",
input: "This is a simple response.",
expectedContent: "This is a simple response.",
expectedThinking: "",
},
{
name: "thinking_only",
input: "This is thinking content.</think>This is response content.",
expectedContent: "This is response content.",
expectedThinking: "This is thinking content.",
},
{
name: "tool_call_simple",
input: `<tool▁calls▁begin><tool▁call▁begin>function<tool▁sep>get_weather
` + "```json\n" + `{"location":"Paris"}
` + "```" + `<tool▁call▁end><tool▁calls▁end>`,
expectedToolCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: api.ToolCallFunctionArguments{
"location": "Paris",
},
},
},
},
tools: []api.Tool{
{
Type: "function",
Function: api.ToolFunction{
Name: "get_weather",
Parameters: api.ToolFunctionParameters{
Properties: map[string]api.ToolProperty{
"location": {Type: api.PropertyType{"string"}},
},
},
},
},
},
},
{
name: "thinking_with_tool_call",
input: `I need to check the weather.</think><tool▁calls▁begin><tool▁call▁begin>function<tool▁sep>get_weather
` + "```json\n" + `{"location":"Paris"}
` + "```" + `<tool▁call▁end><tool▁calls▁end>`,
expectedContent: "I need to check the weather.</think>",
expectedThinking: "", // No thinking when tools are present (Cogito-specific behavior)
expectedToolCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: api.ToolCallFunctionArguments{
"location": "Paris",
},
},
},
},
tools: []api.Tool{
{
Type: "function",
Function: api.ToolFunction{
Name: "get_weather",
Parameters: api.ToolFunctionParameters{
Properties: map[string]api.ToolProperty{
"location": {Type: api.PropertyType{"string"}},
},
},
},
},
},
},
{
name: "multiple_tool_calls",
input: `<tool▁calls▁begin><tool▁call▁begin>function<tool▁sep>get_weather
` + "```json\n" + `{"location":"Paris"}
` + "```" + `<tool▁call▁end>
<tool▁call▁begin>function<tool▁sep>get_weather
` + "```json\n" + `{"location":"London"}
` + "```" + `<tool▁call▁end><tool▁calls▁end>`,
expectedToolCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: api.ToolCallFunctionArguments{
"location": "Paris",
},
},
},
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: api.ToolCallFunctionArguments{
"location": "London",
},
},
},
},
tools: []api.Tool{
{
Type: "function",
Function: api.ToolFunction{
Name: "get_weather",
Parameters: api.ToolFunctionParameters{
Properties: map[string]api.ToolProperty{
"location": {Type: api.PropertyType{"string"}},
},
},
},
},
},
},
{
name: "complex_tool_arguments",
input: `<tool▁calls▁begin><tool▁call▁begin>function<tool▁sep>process_data
` + "```json\n" + `{"items":["item1","item2"],"config":{"enabled":true,"threshold":0.95},"count":42}
` + "```" + `<tool▁call▁end><tool▁calls▁end>`,
expectedToolCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "process_data",
Arguments: api.ToolCallFunctionArguments{
"items": []any{"item1", "item2"},
"config": map[string]any{"enabled": true, "threshold": 0.95},
"count": 42.0,
},
},
},
},
},
{
name: "tool_output_parsing",
input: `<tool▁outputs▁begin><tool▁output▁begin>{"temperature": 22, "condition": "sunny"}<tool▁output▁end><tool▁outputs▁end>`,
expectedContent: "",
expectedThinking: "",
},
{
name: "thinking_with_multiline_content",
input: `This is line 1
This is line 2
This is line 3</think>Final response here.`,
expectedContent: "Final response here.",
expectedThinking: "This is line 1\nThis is line 2\nThis is line 3",
},
{
name: "no_thinking_simple",
input: "This is content.",
expectedContent: "This is content.",
expectedThinking: "",
},
{
name: "prefill_content_only",
input: "Continuing from previous content.",
expectedContent: "Continuing from previous content.",
lastMessage: &api.Message{
Role: "assistant",
Content: "Previous content",
},
},
{
name: "prefill_with_thinking",
input: "Continuing thinking</think>Continuing content.",
expectedContent: "Continuing content.",
expectedThinking: "Continuing thinking",
lastMessage: &api.Message{
Role: "assistant",
},
},
// Edge cases
{
name: "nested_think_tags_in_thinking",
input: "I'm thinking <think>nested</think> more thinking</think>Final content.",
expectedContent: "more thinking</think>Final content.",
expectedThinking: "I'm thinking <think>nested",
},
{
name: "multiple_think_close_tags",
input: "First thinking</think>Content</think>More content.",
expectedContent: "Content</think>More content.",
expectedThinking: "First thinking",
},
{
name: "empty_thinking_content",
input: "</think>Just content here.",
expectedContent: "</think>Just content here.",
expectedThinking: "",
},
{
name: "thinking_disabled_with_think_tags",
input: "Content with </think> tags should be treated as content.",
expectedContent: "Content with </think> tags should be treated as content.",
expectedThinking: "",
lastMessage: &api.Message{
Role: "assistant",
Content: "existing", // Forces non-thinking mode
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Use thinking-enabled parser for tests that expect thinking
hasThinking := tt.expectedThinking != ""
parser := &CogitoParser{} // it has thinking support
parser.Init(tt.tools, tt.lastMessage, &api.ThinkValue{Value: hasThinking}) // but we should set it with the request that the user wants
content, thinking, toolCalls, err := parser.Add(tt.input, true)
if err != nil {
t.Fatalf("Add() error = %v", err)
}
if diff := cmp.Diff(tt.expectedContent, content); diff != "" {
t.Errorf("content mismatch (-want +got):\n%s", diff)
}
if diff := cmp.Diff(tt.expectedThinking, thinking); diff != "" {
t.Errorf("thinking mismatch (-want +got):\n%s", diff)
}
if diff := cmp.Diff(tt.expectedToolCalls, toolCalls); diff != "" {
t.Errorf("tool calls mismatch (-want +got):\n%s", diff)
}
})
}
}
func TestCogitoParser_Streaming(t *testing.T) {
parser := &CogitoParser{}
parser.Init(nil, nil, &api.ThinkValue{Value: true})
chunks := []string{
"This is ",
"thinking content",
".</think>This is ",
"content.<tool▁calls▁begin><tool▁call▁begin>function<tool▁sep>test_tool\n```json\n{\"arg\":\"value\"}\n```<tool▁call▁end><tool▁calls▁end>",
}
var finalContent, finalThinking strings.Builder
var finalToolCalls []api.ToolCall
for i, chunk := range chunks {
done := i == len(chunks)-1
content, thinking, toolCalls, err := parser.Add(chunk, done)
if err != nil {
t.Fatalf("Add() error on chunk %d: %v", i, err)
}
finalContent.WriteString(content)
finalThinking.WriteString(thinking)
finalToolCalls = append(finalToolCalls, toolCalls...)
}
expectedContent := "This is content."
expectedThinking := "This is thinking content."
expectedToolCalls := []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "test_tool",
Arguments: api.ToolCallFunctionArguments{
"arg": "value",
},
},
},
}
if finalContent.String() != expectedContent {
t.Errorf("expected content %q, got %q", expectedContent, finalContent.String())
}
if finalThinking.String() != expectedThinking {
t.Errorf("expected thinking %q, got %q", expectedThinking, finalThinking.String())
}
if diff := cmp.Diff(expectedToolCalls, finalToolCalls); diff != "" {
t.Errorf("tool calls mismatch (-want +got):\n%s", diff)
}
}
func TestCogitoParser_StreamingEdgeCases(t *testing.T) {
tests := []struct {
name string
chunks []string
expectedContent string
expectedThinking string
expectedToolCalls []api.ToolCall
hasThinkingSupport bool
}{
{
name: "split_thinking_tag",
chunks: []string{
"This is thinking content</thi",
"nk>This is content.",
},
expectedContent: "This is content.",
expectedThinking: "This is thinking content",
hasThinkingSupport: true,
},
{
name: "split_tool_calls_begin_tag_conservative_parsing",
chunks: []string{
"Content before<tool▁calls▁beg",
"in><tool▁call▁begin>function<tool▁sep>test\n```json\n{}\n```<tool▁call▁end><tool▁calls▁end>",
},
// Parser is conservative - treats incomplete tags as content
expectedContent: "Content before<tool▁calls▁begin><tool▁call▁begin>function<tool▁sep>test\n```json\n{}\n```<tool▁call▁end><tool▁calls▁end>",
expectedToolCalls: nil,
hasThinkingSupport: false,
},
{
name: "thinking_disabled_with_split_tags",
chunks: []string{
"Content with </thi",
"nk> should be treated as content.",
},
expectedContent: "Content with </think> should be treated as content.",
expectedThinking: "",
hasThinkingSupport: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
parser := &CogitoParser{}
parser.Init(nil, nil, &api.ThinkValue{Value: tt.hasThinkingSupport})
var finalContent, finalThinking strings.Builder
var finalToolCalls []api.ToolCall
for i, chunk := range tt.chunks {
done := i == len(tt.chunks)-1
content, thinking, toolCalls, err := parser.Add(chunk, done)
if err != nil {
t.Fatalf("Add() error on chunk %d: %v", i, err)
}
finalContent.WriteString(content)
finalThinking.WriteString(thinking)
finalToolCalls = append(finalToolCalls, toolCalls...)
}
if finalContent.String() != tt.expectedContent {
t.Errorf("expected content %q, got %q", tt.expectedContent, finalContent.String())
}
if finalThinking.String() != tt.expectedThinking {
t.Errorf("expected thinking %q, got %q", tt.expectedThinking, finalThinking.String())
}
if diff := cmp.Diff(tt.expectedToolCalls, finalToolCalls); diff != "" {
t.Errorf("tool calls mismatch (-want +got):\n%s", diff)
}
})
}
}
func TestCogitoParser_HasToolSupport(t *testing.T) {
parser := &CogitoParser{}
if !parser.HasToolSupport() {
t.Error("CogitoParser should support tools")
}
}
func TestCogitoParser_Init(t *testing.T) {
parser := &CogitoParser{}
tools := []api.Tool{
{Function: api.ToolFunction{Name: "test_tool"}},
}
lastMessage := &api.Message{Role: "assistant", Content: "previous"}
returnedTools := parser.Init(tools, lastMessage, nil)
if len(returnedTools) != len(tools) {
t.Errorf("expected %d tools returned, got %d", len(tools), len(returnedTools))
}
}
func TestCogitoParser_parseToolCallContent(t *testing.T) {
tests := []struct {
name string
content string
expected api.ToolCall
expectError bool
}{
{
name: "valid_tool_call_standard_format",
content: `function<tool▁sep>get_weather
` + "```json\n" + `{"location":"Paris"}
` + "```",
expected: api.ToolCall{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: api.ToolCallFunctionArguments{
"location": "Paris",
},
},
},
expectError: false,
},
{
name: "valid_tool_call_complex_args",
content: `function<tool▁sep>process_data
` + "```json\n" + `{"items":["item1","item2"],"config":{"enabled":true},"count":42}
` + "```",
expected: api.ToolCall{
Function: api.ToolCallFunction{
Name: "process_data",
Arguments: api.ToolCallFunctionArguments{
"items": []any{"item1", "item2"},
"config": map[string]any{"enabled": true},
"count": 42.0,
},
},
},
expectError: false,
},
{
name: "valid_tool_call_empty_args",
content: `function<tool▁sep>no_args_tool
` + "```json\n" + `{}
` + "```",
expected: api.ToolCall{
Function: api.ToolCallFunction{
Name: "no_args_tool",
Arguments: api.ToolCallFunctionArguments{},
},
},
expectError: false,
},
{
name: "missing_separator",
content: `functionget_weather` + "```json\n" + `{"location":"Paris"}` + "\n```",
expected: api.ToolCall{},
expectError: true,
},
{
name: "invalid_function_type",
content: `not_function<tool▁sep>get_weather` + "```json\n" + `{"location":"Paris"}` + "\n```",
expected: api.ToolCall{},
expectError: true,
},
{
name: "missing_json_block_start",
content: `function<tool▁sep>get_weather{"location":"Paris"}` + "```",
expected: api.ToolCall{},
expectError: true,
},
{
name: "missing_json_block_end",
content: `function<tool▁sep>get_weather` + "```json\n" + `{"location":"Paris"}`,
expected: api.ToolCall{},
expectError: true,
},
{
name: "invalid_json",
content: `function<tool▁sep>get_weather` + "```json\n" + `{location:Paris}` + "\n```",
expected: api.ToolCall{},
expectError: true,
},
{
name: "empty_function_type",
content: `<tool▁sep>get_weather` + "```json\n" + `{"location":"Paris"}` + "\n```",
expected: api.ToolCall{},
expectError: true,
},
{
name: "tool_with_spaces_in_name",
content: `function<tool▁sep> get_weather
` + "```json\n" + `{"location":"Paris"}
` + "```",
expected: api.ToolCall{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: api.ToolCallFunctionArguments{
"location": "Paris",
},
},
},
expectError: false,
},
{
name: "tool_with_multiline_json",
content: `function<tool▁sep>get_weather
` + "```json\n" + `{
"location": "Paris",
"units": "metric"
}
` + "```",
expected: api.ToolCall{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: api.ToolCallFunctionArguments{
"location": "Paris",
"units": "metric",
},
},
},
expectError: false,
},
{
name: "tool_with_nested_objects",
content: `function<tool▁sep>complex_tool
` + "```json\n" + `{"nested":{"deep":{"value":123}}}
` + "```",
expected: api.ToolCall{
Function: api.ToolCallFunction{
Name: "complex_tool",
Arguments: api.ToolCallFunctionArguments{
"nested": map[string]any{
"deep": map[string]any{
"value": 123.0,
},
},
},
},
},
expectError: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
parser := &CogitoParser{}
result, err := parser.parseToolCallContent(tt.content)
if tt.expectError {
if err == nil {
t.Errorf("expected error but got none")
}
return
}
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if diff := cmp.Diff(tt.expected, result); diff != "" {
t.Errorf("tool call mismatch (-want +got):\n%s", diff)
}
})
}
}

136
model/parsers/ministral.go Normal file
View File

@@ -0,0 +1,136 @@
package parsers
import (
"encoding/json"
"fmt"
"strings"
"github.com/ollama/ollama/api"
)
type ministralParserState int
const (
ministralCollectingContent = iota
ministralCollectingThinkingContent
ministralCollectingToolName
ministralCollectingToolArgs
)
type MinistralParser struct {
state ministralParserState
buffer strings.Builder
tools []api.Tool
hasThinkingSupport bool
currentTool *api.Tool
}
func (p *MinistralParser) HasToolSupport() bool {
return true
}
func (p *MinistralParser) HasThinkingSupport() bool {
return p.hasThinkingSupport
}
func (p *MinistralParser) setInitialState(lastMessage *api.Message) {
prefill := lastMessage != nil && lastMessage.Role == "assistant"
if !p.HasThinkingSupport() {
p.state = ministralCollectingContent
return
}
if prefill && lastMessage.Content != "" {
p.state = ministralCollectingContent
return
}
p.state = ministralCollectingThinkingContent
}
func (p *MinistralParser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
p.tools = tools
p.setInitialState(lastMessage)
return tools
}
func toolByName(tools []api.Tool, n string) (*api.Tool, error) {
for i := range tools {
if tools[i].Function.Name == n {
return &tools[i], nil
}
}
return nil, fmt.Errorf("tool '%s' not found", n)
}
func (p *MinistralParser) Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error) {
p.buffer.WriteString(s)
switch p.state {
case ministralCollectingContent:
if strings.Contains(p.buffer.String(), "[TOOL_CALLS]") {
before, _ := splitAtTag(&p.buffer, "[TOOL_CALLS]", false)
if before != "" {
return before, "", calls, nil
}
p.state = ministralCollectingToolName
} else if strings.Contains(p.buffer.String(), "[THINK]") {
p.state = ministralCollectingThinkingContent
return "", "", calls, nil
} else {
p.buffer.Reset()
return s, "", calls, nil
}
case ministralCollectingThinkingContent:
if strings.Contains(p.buffer.String(), "[/THINK]") {
thinkingContent, after := splitAtTag(&p.buffer, "[/THINK]", true)
p.state = ministralCollectingContent
if after != "" {
p.buffer.Reset()
return after, thinkingContent, calls, nil
}
return "", thinkingContent, calls, nil
} else {
p.buffer.Reset()
return "", s, calls, nil
}
case ministralCollectingToolName:
if strings.Contains(p.buffer.String(), "[ARGS]") {
name, _ := splitAtTag(&p.buffer, "[ARGS]", false)
t, err := toolByName(p.tools, name)
if err != nil {
return "", "", calls, err
}
p.currentTool = t
p.state = ministralCollectingToolArgs
return "", "", calls, nil
}
return "", "", calls, nil
case ministralCollectingToolArgs:
if strings.Contains(p.buffer.String(), "}") {
before, _ := splitAtTag(&p.buffer, "}", false)
before += "}"
var data map[string]any
if err := json.Unmarshal([]byte(before), &data); err != nil {
// todo - throw a better error
return "", "", calls, err
}
p.state = ministralCollectingContent
call := api.ToolCall{
Function: api.ToolCallFunction{
Name: p.currentTool.Function.Name,
Arguments: api.ToolCallFunctionArguments(data),
},
}
calls = append(calls, call)
return "", "", calls, nil
}
return "", "", calls, nil
}
return p.buffer.String(), thinking, calls, nil
}

View File

@@ -1,14 +1,17 @@
package parsers
import (
"strings"
"unicode"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/harmony"
)
type Parser interface {
// Init initializes the parser with tools and optional last message for chat prefill
// Init initializes the parser with tools, optional last message for chat prefill, and think value
// Returns processed tools if the parser needs to modify them (e.g., harmony renames them)
Init(tools []api.Tool, lastMessage *api.Message) []api.Tool
Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool
// Add processes streamed content and returns parsed content, thinking, and tool calls
// The done flag indicates if this is the last chunk (used for draining accumulators)
Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error)
@@ -38,28 +41,32 @@ func ParserForName(name string) Parser {
if parser, ok := registry.constructors[name]; ok {
return parser()
}
var p Parser
switch name {
case "qwen3-coder":
parser := &Qwen3CoderParser{}
return parser
p = &Qwen3CoderParser{}
case "qwen3-vl-instruct":
parser := &Qwen3VLParser{hasThinkingSupport: false}
return parser
p = &Qwen3VLParser{hasThinkingSupport: false}
case "qwen3-vl-thinking":
parser := &Qwen3VLParser{hasThinkingSupport: true}
return parser
p = &Qwen3VLParser{hasThinkingSupport: true}
case "ministral":
p = &MinistralParser{hasThinkingSupport: false}
case "passthrough":
return &PassthroughParser{}
case "harmony":
return harmony.NewHarmonyMessageHandler()
case "cogito":
return &CogitoParser{}
default:
return nil
}
return p
}
type PassthroughParser struct{}
func (p *PassthroughParser) Init(tools []api.Tool, lastMessage *api.Message) []api.Tool {
func (p *PassthroughParser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
return tools // passthrough doesn't modify tools
}
@@ -74,3 +81,20 @@ func (p *PassthroughParser) HasToolSupport() bool {
func (p *PassthroughParser) HasThinkingSupport() bool {
return false
}
func splitAtTag(sb *strings.Builder, tag string, trimAfter bool) (string, string) {
split := strings.SplitN(sb.String(), tag, 2)
if len(split) == 1 {
sb.Reset()
return split[0], ""
}
before := split[0]
before = strings.TrimRightFunc(before, unicode.IsSpace)
after := split[1]
if trimAfter {
after = strings.TrimLeftFunc(after, unicode.IsSpace)
}
sb.Reset()
sb.WriteString(after)
return before, after // return events
}

View File

@@ -1,6 +1,7 @@
package parsers
import (
"strings"
"testing"
"github.com/ollama/ollama/api"
@@ -10,7 +11,7 @@ type mockParser struct {
name string
}
func (m *mockParser) Init(tools []api.Tool, lastMessage *api.Message) []api.Tool {
func (m *mockParser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
return tools
}
@@ -95,3 +96,164 @@ func TestUnknownParserReturnsNil(t *testing.T) {
t.Error("expected nil for unknown parser")
}
}
func TestSplitAtTag(t *testing.T) {
tests := []struct {
name string
input string
tag string
trimAfter bool
wantBefore string
wantAfter string
wantSB string // expected content of strings.Builder after operation
}{
{
name: "basic split with trimAfter true",
input: "hello <!-- split --> world",
tag: "<!-- split -->",
trimAfter: true,
wantBefore: "hello",
wantAfter: "world",
wantSB: "world",
},
{
name: "basic split with trimAfter false",
input: "hello <!-- split --> world",
tag: "<!-- split -->",
trimAfter: false,
wantBefore: "hello",
wantAfter: " world",
wantSB: " world",
},
{
name: "tag at beginning with trimAfter true",
input: "<!-- split -->world",
tag: "<!-- split -->",
trimAfter: true,
wantBefore: "",
wantAfter: "world",
wantSB: "world",
},
{
name: "tag at beginning with trimAfter false",
input: "<!-- split --> world",
tag: "<!-- split -->",
trimAfter: false,
wantBefore: "",
wantAfter: " world",
wantSB: " world",
},
{
name: "tag at end with trimAfter true",
input: "hello <!-- split -->",
tag: "<!-- split -->",
trimAfter: true,
wantBefore: "hello",
wantAfter: "",
wantSB: "",
},
{
name: "tag at end with trimAfter false",
input: "hello <!-- split -->",
tag: "<!-- split -->",
trimAfter: false,
wantBefore: "hello",
wantAfter: "",
wantSB: "",
},
{
name: "multiple tags splits at first occurrence",
input: "hello <!-- split --> world <!-- split --> end",
tag: "<!-- split -->",
trimAfter: true,
wantBefore: "hello",
wantAfter: "world <!-- split --> end",
wantSB: "world <!-- split --> end",
},
{
name: "tag not present",
input: "hello world",
tag: "<!-- split -->",
trimAfter: true,
wantBefore: "hello world",
wantAfter: "",
wantSB: "",
},
{
name: "empty input",
input: "",
tag: "<!-- split -->",
trimAfter: true,
wantBefore: "",
wantAfter: "",
wantSB: "",
},
{
name: "only whitespace before tag",
input: " \t\n<!-- split -->world",
tag: "<!-- split -->",
trimAfter: true,
wantBefore: "",
wantAfter: "world",
wantSB: "world",
},
{
name: "only whitespace after tag with trimAfter true",
input: "hello<!-- split --> \t\n",
tag: "<!-- split -->",
trimAfter: true,
wantBefore: "hello",
wantAfter: "",
wantSB: "",
},
{
name: "only whitespace after tag with trimAfter false",
input: "hello<!-- split --> \t\n",
tag: "<!-- split -->",
trimAfter: false,
wantBefore: "hello",
wantAfter: " \t\n",
wantSB: " \t\n",
},
{
name: "complex whitespace trimming",
input: " hello \t\n <!-- split --> \n\t world ",
tag: "<!-- split -->",
trimAfter: true,
wantBefore: " hello",
wantAfter: "world ",
wantSB: "world ",
},
{
name: "tag with special characters",
input: "text <tag attr=\"value\"> more text",
tag: "<tag attr=\"value\">",
trimAfter: true,
wantBefore: "text",
wantAfter: "more text",
wantSB: "more text",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
sb := &strings.Builder{}
sb.WriteString(tt.input)
before, after := splitAtTag(sb, tt.tag, tt.trimAfter)
// Check return values
if before != tt.wantBefore {
t.Errorf("splitAtTag() before = %q, want %q", before, tt.wantBefore)
}
if after != tt.wantAfter {
t.Errorf("splitAtTag() after = %q, want %q", after, tt.wantAfter)
}
// Check strings.Builder state
if sb.String() != tt.wantSB {
t.Errorf("strings.Builder after split = %q, want %q", sb.String(), tt.wantSB)
}
})
}
}

View File

@@ -43,7 +43,7 @@ func (p *Qwen3CoderParser) HasThinkingSupport() bool {
return false
}
func (p *Qwen3CoderParser) Init(tools []api.Tool, lastMessage *api.Message) []api.Tool {
func (p *Qwen3CoderParser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
p.tools = tools
return tools // Qwen doesn't modify tools
}
@@ -432,7 +432,7 @@ func transformToXML(raw string) string {
groups := qwenTagRegex.FindStringSubmatch(match)
tag := groups[1]
var escapedValue strings.Builder
xml.EscapeText(&escapedValue, []byte(groups[2]))
_ = xml.EscapeText(&escapedValue, []byte(groups[2])) // error is always nil for strings.Builder
return fmt.Sprintf(`<%s name="%s">`, tag, escapedValue.String())
})

View File

@@ -54,7 +54,7 @@ func (p *Qwen3VLParser) setInitialState(lastMessage *api.Message) {
p.state = CollectingThinkingContent
}
func (p *Qwen3VLParser) Init(tools []api.Tool, lastMessage *api.Message) []api.Tool {
func (p *Qwen3VLParser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
p.tools = tools
p.setInitialState(lastMessage)
return tools
@@ -70,7 +70,6 @@ func (p *Qwen3VLParser) Add(s string, done bool) (content string, thinking strin
p.buffer.WriteString(s)
events := p.parseEvents()
var toolCalls []api.ToolCall
var contentSb strings.Builder
var thinkingSb strings.Builder
for _, event := range events {
@@ -81,7 +80,7 @@ func (p *Qwen3VLParser) Add(s string, done bool) (content string, thinking strin
slog.Warn("qwen tool call parsing failed", "error", err)
return "", "", nil, err
}
toolCalls = append(toolCalls, toolCall)
calls = append(calls, toolCall)
case qwenEventThinkingContent:
thinkingSb.WriteString(event.content)
case qwenEventContent:
@@ -91,7 +90,7 @@ func (p *Qwen3VLParser) Add(s string, done bool) (content string, thinking strin
}
}
return contentSb.String(), thinkingSb.String(), toolCalls, nil
return contentSb.String(), thinkingSb.String(), calls, nil
}
func (p *Qwen3VLParser) parseEvents() []qwenEvent {
@@ -113,19 +112,6 @@ func (p *Qwen3VLParser) parseEvents() []qwenEvent {
return all
}
func splitAtTag(p *Qwen3VLParser, tag string, trimAfter bool) (string, string) {
split := strings.SplitN(p.buffer.String(), tag, 2)
before := split[0]
before = strings.TrimRightFunc(before, unicode.IsSpace)
after := split[1]
if trimAfter {
after = strings.TrimLeftFunc(after, unicode.IsSpace)
}
p.buffer.Reset()
p.buffer.WriteString(after)
return before, after // return events
}
func (p *Qwen3VLParser) eatLeadingWhitespaceAndTransitionTo(nextState qwenParserState) ([]qwenEvent, bool) {
trimmed := strings.TrimLeftFunc(p.buffer.String(), unicode.IsSpace)
p.buffer.Reset()
@@ -144,7 +130,7 @@ func (p *Qwen3VLParser) eat() ([]qwenEvent, bool) {
case CollectingContent:
if strings.Contains(p.buffer.String(), toolOpenTag) {
// events = emitContentBeforeTag(p, events, toolOpenTag)
before, _ := splitAtTag(p, toolOpenTag, false)
before, _ := splitAtTag(&p.buffer, toolOpenTag, false)
if len(before) > 0 {
events = append(events, qwenEventContent{content: before})
}
@@ -195,7 +181,7 @@ func (p *Qwen3VLParser) eat() ([]qwenEvent, bool) {
}
case CollectingThinkingContent:
if strings.Contains(p.buffer.String(), thinkingCloseTag) {
thinking, remaining := splitAtTag(p, thinkingCloseTag, true)
thinking, remaining := splitAtTag(&p.buffer, thinkingCloseTag, true)
if len(thinking) > 0 {
events = append(events, qwenEventThinkingContent{content: thinking})
}

View File

@@ -198,7 +198,7 @@ func TestQwen3VLNonThinkingParserStreaming(t *testing.T) {
t.Run(tc.desc, func(t *testing.T) {
parser := Qwen3VLParser{hasThinkingSupport: false}
parser.Init([]api.Tool{}, nil)
parser.Init([]api.Tool{}, nil, nil)
for i, step := range tc.steps {
parser.buffer.WriteString(step.input)
@@ -515,7 +515,7 @@ func TestQwenOldParserStreaming(t *testing.T) {
t.Run(tc.desc, func(t *testing.T) {
parser := Qwen3VLParser{hasThinkingSupport: false}
parser.Init([]api.Tool{}, nil)
parser.Init([]api.Tool{}, nil, nil)
for i, step := range tc.steps {
parser.buffer.WriteString(step.input)
@@ -822,7 +822,7 @@ func TestQwen3VLNonThinkingToolCallWhitespaceHandling(t *testing.T) {
t.Run(tc.desc, func(t *testing.T) {
parser := Qwen3VLParser{hasThinkingSupport: false}
parser.Init([]api.Tool{}, nil)
parser.Init([]api.Tool{}, nil, nil)
for i, step := range tc.steps {
parser.buffer.WriteString(step.input)

View File

@@ -205,7 +205,7 @@ func TestQwen3VLThinkingParserStreaming(t *testing.T) {
t.Run(tc.desc, func(t *testing.T) {
parser := Qwen3VLParser{hasThinkingSupport: true}
parser.Init([]api.Tool{}, nil)
parser.Init([]api.Tool{}, nil, nil)
// parser.state = CollectingThinkingContent
for i, step := range tc.steps {
@@ -386,7 +386,7 @@ func TestQwen3VLParserState(t *testing.T) {
for _, tc := range cases {
parser := Qwen3VLParser{hasThinkingSupport: tc.hasThinking}
parser.Init(nil, tc.last)
parser.Init(nil, tc.last, nil)
if parser.state != tc.wantState {
t.Errorf("%s: got state %v, want %v", tc.desc, parser.state, tc.wantState)
}
@@ -437,7 +437,7 @@ func TestQwen3VLThinkingParserWithThinkingPrefill(t *testing.T) {
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
parser := Qwen3VLParser{hasThinkingSupport: true}
parser.Init([]api.Tool{}, last)
parser.Init([]api.Tool{}, last, nil)
for i, step := range tc.steps {
parser.buffer.WriteString(step.input)
@@ -500,7 +500,7 @@ func TestQwen3VLThinkingParserWithNonThinkingPrefill(t *testing.T) {
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
parser := Qwen3VLParser{hasThinkingSupport: true}
parser.Init([]api.Tool{}, last)
parser.Init([]api.Tool{}, last, nil)
for i, step := range tc.steps {
parser.buffer.WriteString(step.input)
@@ -523,7 +523,7 @@ func TestQwen3VLThinkingParserStreamingAssistantPrefillContent(t *testing.T) {
// last message is assistant with content ⇒ start in CollectingContent
last := &api.Message{Role: "assistant", Content: "has content"}
parser := Qwen3VLParser{hasThinkingSupport: true}
parser.Init([]api.Tool{}, last)
parser.Init([]api.Tool{}, last, nil)
type step struct {
input string
@@ -750,7 +750,7 @@ func TestQwen3VLThinkingWhitespaceHandling(t *testing.T) {
t.Run(tc.desc, func(t *testing.T) {
parser := Qwen3VLParser{hasThinkingSupport: true}
parser.Init([]api.Tool{}, nil)
parser.Init([]api.Tool{}, nil, nil)
for i, step := range tc.steps {
parser.buffer.WriteString(step.input)
@@ -859,7 +859,7 @@ func TestQwen3VLToolCallWhitespaceHandling(t *testing.T) {
t.Run(tc.desc, func(t *testing.T) {
parser := Qwen3VLParser{hasThinkingSupport: true}
parser.Init([]api.Tool{}, tc.prefillMsg)
parser.Init([]api.Tool{}, tc.prefillMsg, nil)
for i, step := range tc.steps {
parser.buffer.WriteString(step.input)

View File

@@ -340,7 +340,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
builtinParser = parsers.ParserForName(m.Config.Parser)
if builtinParser != nil {
// no tools or last message for generate endpoint
builtinParser.Init(nil, nil)
builtinParser.Init(nil, nil, req.Think)
}
}
@@ -2051,7 +2051,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
lastMessage = &msgs[len(msgs)-1]
}
// Initialize parser and get processed tools
processedTools = builtinParser.Init(req.Tools, lastMessage)
processedTools = builtinParser.Init(req.Tools, lastMessage, req.Think)
}
}