Compare commits

..

1 Commits

Author SHA1 Message Date
Michael Yang
c11b70da72 simplify expand path 2025-11-18 13:59:24 -08:00
58 changed files with 848 additions and 4417 deletions

4
.gitattributes vendored
View File

@@ -15,12 +15,8 @@ ml/backend/**/*.cu linguist-vendored
ml/backend/**/*.cuh linguist-vendored
ml/backend/**/*.m linguist-vendored
ml/backend/**/*.metal linguist-vendored
ml/backend/**/*.comp linguist-vendored
ml/backend/**/*.glsl linguist-vendored
ml/backend/**/CMakeLists.txt linguist-vendored
app/webview linguist-vendored
llama/build-info.cpp linguist-generated
ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-embed.s linguist-generated

View File

@@ -226,14 +226,7 @@ func (c *Client) stream(ctx context.Context, method, path string, data any, fn f
bts := scanner.Bytes()
if err := json.Unmarshal(bts, &errorResponse); err != nil {
if response.StatusCode >= http.StatusBadRequest {
return StatusError{
StatusCode: response.StatusCode,
Status: response.Status,
ErrorMessage: string(bts),
}
}
return errors.New(string(bts))
return fmt.Errorf("unmarshal: %w", err)
}
if response.StatusCode == http.StatusUnauthorized {

View File

@@ -55,7 +55,6 @@ func TestClientFromEnvironment(t *testing.T) {
type testError struct {
message string
statusCode int
raw bool // if true, write message as-is instead of JSON encoding
}
func (e testError) Error() string {
@@ -112,20 +111,6 @@ func TestClientStream(t *testing.T) {
},
},
},
{
name: "plain text error response",
responses: []any{
"internal server error",
},
wantErr: "internal server error",
},
{
name: "HTML error page",
responses: []any{
"<html><body>404 Not Found</body></html>",
},
wantErr: "404 Not Found",
},
}
for _, tc := range testCases {
@@ -150,12 +135,6 @@ func TestClientStream(t *testing.T) {
return
}
if str, ok := resp.(string); ok {
fmt.Fprintln(w, str)
flusher.Flush()
continue
}
if err := json.NewEncoder(w).Encode(resp); err != nil {
t.Fatalf("failed to encode response: %v", err)
}
@@ -194,10 +173,9 @@ func TestClientStream(t *testing.T) {
func TestClientDo(t *testing.T) {
testCases := []struct {
name string
response any
wantErr string
wantStatusCode int
name string
response any
wantErr string
}{
{
name: "immediate error response",
@@ -205,8 +183,7 @@ func TestClientDo(t *testing.T) {
message: "test error message",
statusCode: http.StatusBadRequest,
},
wantErr: "test error message",
wantStatusCode: http.StatusBadRequest,
wantErr: "test error message",
},
{
name: "server error response",
@@ -214,8 +191,7 @@ func TestClientDo(t *testing.T) {
message: "internal error",
statusCode: http.StatusInternalServerError,
},
wantErr: "internal error",
wantStatusCode: http.StatusInternalServerError,
wantErr: "internal error",
},
{
name: "successful response",
@@ -227,26 +203,6 @@ func TestClientDo(t *testing.T) {
Success: true,
},
},
{
name: "plain text error response",
response: testError{
message: "internal server error",
statusCode: http.StatusInternalServerError,
raw: true,
},
wantErr: "internal server error",
wantStatusCode: http.StatusInternalServerError,
},
{
name: "HTML error page",
response: testError{
message: "<html><body>404 Not Found</body></html>",
statusCode: http.StatusNotFound,
raw: true,
},
wantErr: "<html><body>404 Not Found</body></html>",
wantStatusCode: http.StatusNotFound,
},
}
for _, tc := range testCases {
@@ -254,16 +210,11 @@ func TestClientDo(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if errResp, ok := tc.response.(testError); ok {
w.WriteHeader(errResp.statusCode)
if !errResp.raw {
err := json.NewEncoder(w).Encode(map[string]string{
"error": errResp.message,
})
if err != nil {
t.Fatal("failed to encode error response:", err)
}
} else {
// Write raw message (simulates non-JSON error responses)
fmt.Fprint(w, errResp.message)
err := json.NewEncoder(w).Encode(map[string]string{
"error": errResp.message,
})
if err != nil {
t.Fatal("failed to encode error response:", err)
}
return
}
@@ -290,15 +241,6 @@ func TestClientDo(t *testing.T) {
if err.Error() != tc.wantErr {
t.Errorf("error message mismatch: got %q, want %q", err.Error(), tc.wantErr)
}
if tc.wantStatusCode != 0 {
if statusErr, ok := err.(StatusError); ok {
if statusErr.StatusCode != tc.wantStatusCode {
t.Errorf("status code mismatch: got %d, want %d", statusErr.StatusCode, tc.wantStatusCode)
}
} else {
t.Errorf("expected StatusError, got %T", err)
}
}
return
}

View File

@@ -397,8 +397,8 @@ func checkUserLoggedIn(uiServerPort int) bool {
// handleConnectURLScheme fetches the connect URL and opens it in the browser
func handleConnectURLScheme() {
if checkUserLoggedIn(uiServerPort) {
slog.Info("user is already logged in, opening app instead")
showWindow(wv.webview.Window())
slog.Info("user is already logged in, opening settings instead")
sendUIRequestMessage("/")
return
}
@@ -466,8 +466,6 @@ func handleURLSchemeInCurrentInstance(urlSchemeRequest string) {
if isConnect {
handleConnectURLScheme()
} else {
if wv.webview != nil {
showWindow(wv.webview.Window())
}
sendUIRequestMessage("/")
}
}

View File

@@ -24,14 +24,27 @@ bool firstTimeRun,startHidden; // Set in run before initialization
for (NSURL *url in urls) {
if ([url.scheme isEqualToString:@"ollama"]) {
NSString *path = url.path;
if (path && ([path isEqualToString:@"/connect"] || [url.host isEqualToString:@"connect"])) {
if (!path || [path isEqualToString:@""]) {
// For URLs like ollama://settings (without triple slash),
// the "settings" part is parsed as the host, not the path.
// We need to convert it to a path by prepending "/"
if (url.host && ![url.host isEqualToString:@""]) {
path = [@"/" stringByAppendingString:url.host];
} else {
path = @"/";
}
}
if ([path isEqualToString:@"/connect"] || [url.host isEqualToString:@"connect"]) {
// Special case: handle connect by opening browser instead of app
handleConnectURL();
} else {
// Set app to be active and visible
[NSApp setActivationPolicy:NSApplicationActivationPolicyRegular];
[NSApp activateIgnoringOtherApps:YES];
// Open the path with the UI
[self uiRequest:path];
}
break;
@@ -247,7 +260,7 @@ bool firstTimeRun,startHidden; // Set in run before initialization
}
- (void)openHelp:(id)sender {
NSURL *url = [NSURL URLWithString:@"https://docs.ollama.com/"];
NSURL *url = [NSURL URLWithString:@"https://github.com/ollama/ollama/tree/main/docs"];
[[NSWorkspace sharedWorkspace] openURL:url];
}

View File

@@ -147,9 +147,7 @@ func handleURLSchemeRequest(urlScheme string) {
if isConnect {
handleConnectURLScheme()
} else {
if wv.webview != nil {
showWindow(wv.webview.Window())
}
sendUIRequestMessage("/")
}
}

View File

@@ -1,625 +0,0 @@
#!/usr/bin/env python3
# /// script
# requires-python = ">=3.11"
# dependencies = [
# "transformers>=4.57.0",
# "jinja2",
# "fastapi",
# "uvicorn",
# "pydantic",
# "requests",
# ]
# ///
"""
Chat Template Testing Tool
Test HuggingFace chat templates against Ollama renderers.
Usage:
# Run predefined test cases against a HuggingFace model
uv run cmd/chat_template/chat_template.py --model PrimeIntellect/INTELLECT-3
# Compare HuggingFace output with Ollama renderer
uv run cmd/chat_template/chat_template.py --model PrimeIntellect/INTELLECT-3 --ollama-model intellect3
# Start server for manual curl testing
uv run cmd/chat_template/chat_template.py --serve
# Show chat template for a model
uv run cmd/chat_template/chat_template.py --model PrimeIntellect/INTELLECT-3 --show-template
"""
import argparse
import json
import sys
from typing import Any
from transformers import AutoTokenizer
TEST_CASES = [
{
"name": "basic_user_message",
"messages": [{"role": "user", "content": "Hello!"}],
"tools": None,
},
{
"name": "with_system_message",
"messages": [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Hello!"},
],
"tools": None,
},
{
"name": "multi_turn_conversation",
"messages": [
{"role": "user", "content": "Hello"},
{"role": "assistant", "content": "Hi there!"},
{"role": "user", "content": "How are you?"},
],
"tools": None,
},
{
"name": "with_tools",
"messages": [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "What is the weather?"},
],
"tools": [
{
"type": "function",
"function": {
"name": "get_weather",
"description": "Get the current weather",
"parameters": {
"type": "object",
"required": ["location"],
"properties": {
"location": {"type": "string", "description": "The city"}
},
},
},
}
],
},
{
"name": "tool_call_and_response",
"messages": [
{"role": "user", "content": "What is the weather in SF?"},
{
"role": "assistant",
"content": "Let me check the weather.",
"tool_calls": [
{
"id": "call_1",
"type": "function",
"function": {
"name": "get_weather",
"arguments": {"location": "San Francisco"},
},
}
],
},
{"role": "tool", "content": '{"temperature": 68}', "tool_call_id": "call_1"},
],
"tools": [
{
"type": "function",
"function": {
"name": "get_weather",
"description": "Get the current weather",
"parameters": {
"type": "object",
"required": ["location"],
"properties": {
"location": {"type": "string", "description": "The city"}
},
},
},
}
],
},
{
"name": "parallel_tool_calls",
"messages": [
{"role": "user", "content": "Get weather in SF and NYC"},
{
"role": "assistant",
"tool_calls": [
{
"id": "call_1",
"type": "function",
"function": {
"name": "get_weather",
"arguments": {"location": "San Francisco"},
},
},
{
"id": "call_2",
"type": "function",
"function": {
"name": "get_weather",
"arguments": {"location": "New York"},
},
},
],
},
{"role": "tool", "content": '{"temperature": 68}', "tool_call_id": "call_1"},
{"role": "tool", "content": '{"temperature": 55}', "tool_call_id": "call_2"},
],
"tools": [
{
"type": "function",
"function": {
"name": "get_weather",
"parameters": {
"type": "object",
"properties": {"location": {"type": "string"}},
},
},
}
],
},
# Thinking tests
{
"name": "assistant_with_thinking",
"messages": [
{"role": "user", "content": "What is 2+2?"},
{
"role": "assistant",
"content": "The answer is 4.",
"thinking": "Let me calculate: 2 + 2 = 4. This is basic arithmetic.",
},
{"role": "user", "content": "And 3+3?"},
],
"tools": None,
},
{
"name": "thinking_with_tool_call",
"messages": [
{"role": "user", "content": "What's the weather in Paris?"},
{
"role": "assistant",
"content": "I'll check the weather for you.",
"thinking": "The user wants to know the weather in Paris. I should call the get_weather function.",
"tool_calls": [
{
"id": "call_1",
"type": "function",
"function": {
"name": "get_weather",
"arguments": {"location": "Paris"},
},
}
],
},
{"role": "tool", "content": '{"temperature": 18, "condition": "cloudy"}', "tool_call_id": "call_1"},
],
"tools": [
{
"type": "function",
"function": {
"name": "get_weather",
"description": "Get current weather",
"parameters": {
"type": "object",
"properties": {"location": {"type": "string"}},
},
},
}
],
},
{
"name": "thinking_only_no_content",
"messages": [
{"role": "user", "content": "Think about this silently."},
{
"role": "assistant",
"content": "", # HuggingFace requires content field
"thinking": "I'm thinking about this but won't respond with visible content.",
},
{"role": "user", "content": "What did you think?"},
],
"tools": None,
},
]
# Cache for tokenizers
_tokenizer_cache: dict[str, Any] = {}
def get_tokenizer(model_name: str):
"""Get or create tokenizer for the given model."""
if model_name not in _tokenizer_cache:
print(f"Loading tokenizer for {model_name}...", file=sys.stderr)
_tokenizer_cache[model_name] = AutoTokenizer.from_pretrained(model_name)
return _tokenizer_cache[model_name]
def apply_template(
model: str,
messages: list[dict],
tools: list[dict] | None = None,
) -> str:
"""Apply HuggingFace chat template to messages."""
tokenizer = get_tokenizer(model)
if tools:
return tokenizer.apply_chat_template(
messages,
tools=tools,
tokenize=False,
add_generation_prompt=True,
)
else:
return tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True,
)
def get_ollama_prompt(
ollama_model: str,
messages: list[dict],
tools: list[dict] | None = None,
ollama_host: str = "http://localhost:11434",
) -> str | None:
"""Get rendered prompt from Ollama using debug_render_only."""
import requests
# Convert messages to Ollama format
ollama_messages = []
for msg in messages:
ollama_msg = {"role": msg["role"]}
if "content" in msg:
ollama_msg["content"] = msg["content"]
if "thinking" in msg:
ollama_msg["thinking"] = msg["thinking"]
if "tool_calls" in msg:
# Convert tool_calls to Ollama format
tool_calls = []
for tc in msg["tool_calls"]:
tool_call = {
"function": {
"name": tc["function"]["name"],
"arguments": tc["function"]["arguments"],
}
}
if "id" in tc:
tool_call["id"] = tc["id"]
tool_calls.append(tool_call)
ollama_msg["tool_calls"] = tool_calls
if "tool_call_id" in msg:
ollama_msg["tool_call_id"] = msg["tool_call_id"]
ollama_messages.append(ollama_msg)
payload = {
"model": ollama_model,
"messages": ollama_messages,
"stream": False,
"_debug_render_only": True,
}
if tools:
payload["tools"] = tools
try:
resp = requests.post(f"{ollama_host}/api/chat", json=payload, timeout=30)
resp.raise_for_status()
data = resp.json()
# Field name is _debug_info with underscore prefix
if "_debug_info" in data and "rendered_template" in data["_debug_info"]:
return data["_debug_info"]["rendered_template"]
return None
except requests.exceptions.ConnectionError:
print(f" [ERROR] Cannot connect to Ollama at {ollama_host}", file=sys.stderr)
return None
except Exception as e:
print(f" [ERROR] Ollama request failed: {e}", file=sys.stderr)
return None
def compute_diff(hf_prompt: str, ollama_prompt: str) -> str:
"""Compute a unified diff between HuggingFace and Ollama prompts."""
import difflib
hf_lines = hf_prompt.splitlines(keepends=True)
ollama_lines = ollama_prompt.splitlines(keepends=True)
diff = difflib.unified_diff(
ollama_lines,
hf_lines,
fromfile="Ollama",
tofile="HuggingFace",
lineterm="",
)
return "".join(diff)
def print_test_output(
name: str,
messages: list[dict],
tools: list[dict] | None,
hf_prompt: str,
ollama_prompt: str | None = None,
as_repr: bool = False,
):
"""Print test output in a format suitable for Go test creation and LLM diffing."""
print(f"\n{'='*60}")
print(f"Test: {name}")
print("=" * 60)
print("\n--- Input Messages ---")
print(json.dumps(messages, indent=2))
if tools:
print("\n--- Tools ---")
print(json.dumps(tools, indent=2))
if ollama_prompt is not None:
# Comparison mode
if hf_prompt == ollama_prompt:
print("\n--- Result: MATCH ---")
print("\n--- Prompt (both identical) ---")
if as_repr:
print(repr(hf_prompt))
else:
print(hf_prompt)
else:
print("\n--- Result: MISMATCH ---")
print("\n--- HuggingFace Prompt ---")
if as_repr:
print(repr(hf_prompt))
else:
print(hf_prompt)
print("\n--- Ollama Prompt ---")
if as_repr:
print(repr(ollama_prompt))
else:
print(ollama_prompt)
print("\n--- Diff (Ollama -> HuggingFace) ---")
diff = compute_diff(hf_prompt, ollama_prompt)
if diff:
print(diff)
else:
print("(no line-level diff, check whitespace)")
else:
# HuggingFace only mode
print("\n--- HuggingFace Prompt ---")
if as_repr:
print(repr(hf_prompt))
else:
print(hf_prompt)
print("=" * 60)
def run_tests(
model: str,
as_repr: bool = False,
test_filter: str | None = None,
ollama_model: str | None = None,
ollama_host: str = "http://localhost:11434",
):
"""Run all predefined test cases against a model."""
if ollama_model:
print(f"\nComparing HuggingFace ({model}) vs Ollama ({ollama_model})\n")
else:
print(f"\nRunning tests against: {model}\n")
matches = 0
mismatches = 0
errors = 0
for test_case in TEST_CASES:
name = test_case["name"]
messages = test_case["messages"]
tools = test_case["tools"]
# Filter tests if specified
if test_filter and test_filter.lower() not in name.lower():
continue
try:
hf_prompt = apply_template(model, messages, tools)
ollama_prompt = None
if ollama_model:
ollama_prompt = get_ollama_prompt(
ollama_model, messages, tools, ollama_host
)
if ollama_prompt is None:
errors += 1
elif hf_prompt == ollama_prompt:
matches += 1
else:
mismatches += 1
print_test_output(
name, messages, tools, hf_prompt, ollama_prompt, as_repr=as_repr
)
except Exception as e:
errors += 1
print(f"\n{'='*60}")
print(f"Test: {name} - FAILED")
print(f"--- Input Messages ---")
print(json.dumps(messages, indent=2))
if tools:
print(f"--- Tools ---")
print(json.dumps(tools, indent=2))
print(f"--- Error ---")
print(f"{e}")
print("=" * 60)
# Print summary if comparing
if ollama_model:
total = matches + mismatches + errors
print(f"\n{'='*60}")
print("SUMMARY")
print("=" * 60)
print(f" Total: {total}")
print(f" Matches: {matches}")
print(f" Mismatches: {mismatches}")
print(f" Errors: {errors}")
print("=" * 60)
def show_template(model: str):
"""Show the chat template for a model."""
tokenizer = get_tokenizer(model)
print(f"\nChat template for {model}:\n")
print("-" * 60)
print(tokenizer.chat_template)
print("-" * 60)
def start_server(host: str = "0.0.0.0", port: int = 8000):
"""Start the FastAPI server for manual testing."""
from typing import Optional, List, Dict, Any as TypingAny
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import uvicorn
class Message(BaseModel):
role: str
content: Optional[str] = None
tool_calls: Optional[List[Dict[str, TypingAny]]] = None
tool_call_id: Optional[str] = None
class GeneratePromptRequest(BaseModel):
messages: List[Message]
model: str = "PrimeIntellect/INTELLECT-3"
tools: Optional[List[Dict[str, TypingAny]]] = None
inject_tools_as_functions: bool = False
class GeneratePromptResponse(BaseModel):
prompt: str
model: str
app = FastAPI(title="HuggingFace Prompt Generator", version="1.0.0")
@app.post("/generate-prompt", response_model=GeneratePromptResponse)
async def generate_prompt(request: GeneratePromptRequest):
try:
messages = []
for msg in request.messages:
message_dict = {"role": msg.role}
if msg.content is not None:
message_dict["content"] = msg.content
if msg.tool_calls is not None:
tool_calls = []
for tc in msg.tool_calls:
tc_copy = tc.copy()
if "function" in tc_copy and "arguments" in tc_copy["function"]:
args = tc_copy["function"]["arguments"]
if isinstance(args, str):
try:
tc_copy["function"]["arguments"] = json.loads(args)
except json.JSONDecodeError:
pass
tool_calls.append(tc_copy)
message_dict["tool_calls"] = tool_calls
if msg.tool_call_id is not None:
message_dict["tool_call_id"] = msg.tool_call_id
messages.append(message_dict)
prompt = apply_template(request.model, messages, request.tools)
return GeneratePromptResponse(prompt=prompt, model=request.model)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.get("/health")
async def health_check():
return {"status": "healthy"}
print(f"Starting server on http://{host}:{port}")
print("Endpoints:")
print(" POST /generate-prompt - Generate prompt from messages")
print(" GET /health - Health check")
uvicorn.run(app, host=host, port=port)
def main():
parser = argparse.ArgumentParser(
description="HuggingFace Prompt Testing Tool",
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog=__doc__,
)
parser.add_argument(
"--model",
"-m",
type=str,
help="HuggingFace model name (e.g., PrimeIntellect/INTELLECT-3)",
)
parser.add_argument(
"--ollama-model",
"-o",
type=str,
help="Ollama model name to compare against (e.g., qwen3-coder)",
)
parser.add_argument(
"--ollama-host",
type=str,
default="http://localhost:11434",
help="Ollama server URL (default: http://localhost:11434)",
)
parser.add_argument(
"--serve",
"-s",
action="store_true",
help="Start FastAPI server for manual curl testing",
)
parser.add_argument(
"--port",
"-p",
type=int,
default=8000,
help="Server port (default: 8000)",
)
parser.add_argument(
"--show-template",
"-t",
action="store_true",
help="Show the chat template for the model",
)
parser.add_argument(
"--repr",
"-r",
action="store_true",
help="Output prompts as Python repr (shows escape sequences)",
)
parser.add_argument(
"--filter",
"-f",
type=str,
help="Filter tests by name (substring match)",
)
args = parser.parse_args()
if args.serve:
start_server(port=args.port)
elif args.model:
if args.show_template:
show_template(args.model)
else:
run_tests(
args.model,
as_repr=args.repr,
test_filter=args.filter,
ollama_model=args.ollama_model,
ollama_host=args.ollama_host,
)
else:
parser.print_help()
print("\nExample usage:")
print(" uv run cmd/chat_template/chat_template.py --model PrimeIntellect/INTELLECT-3")
print(" uv run cmd/chat_template/chat_template.py --model Qwen/Qwen3-Coder-480B-A35B-Instruct --ollama-model qwen3-coder")
print(" uv run cmd/chat_template/chat_template.py --serve")
sys.exit(1)
if __name__ == "__main__":
main()

View File

@@ -206,8 +206,6 @@ func ConvertModel(fsys fs.FS, f *os.File) error {
conv = &commandrModel{}
case "GptOssForCausalLM":
conv = &gptossModel{}
case "DeepseekOCRForCausalLM":
conv = &deepseekocr{}
default:
return fmt.Errorf("unsupported architecture %q", p.Architectures[0])
}

View File

@@ -1,136 +0,0 @@
package convert
import (
"fmt"
"github.com/ollama/ollama/fs/ggml"
)
type deepseekocr struct {
ModelParameters
LanguageConfig struct {
MaxPositionEmbeddings uint32 `json:"max_position_embeddings"`
HiddenSize uint32 `json:"hidden_size"`
HiddenLayers uint32 `json:"num_hidden_layers"`
IntermediateSize uint32 `json:"intermediate_size"`
NumAttentionHeads uint32 `json:"num_attention_heads"`
NumKeyValueHeads uint32 `json:"num_key_value_heads"`
NumRoutedExperts uint32 `json:"n_routed_experts"`
NumSharedExperts uint32 `json:"n_shared_experts"`
NumExpertsPerToken uint32 `json:"num_experts_per_tok"`
FirstKDenseReplace uint32 `json:"first_k_dense_replace"`
} `json:"language_config"`
VisionConfig struct {
ImageSize uint32 `json:"image_size"`
Width struct {
Vision struct {
Heads uint32 `json:"heads"`
ImageSize uint32 `json:"image_size"`
Layers uint32 `json:"layers"`
PatchSize uint32 `json:"patch_size"`
Width uint32 `json:"width"`
} `json:"clip-l-14-224"`
Sam struct {
GlobalAttentionIndexes []int32 `json:"global_attn_indexes"`
Heads uint32 `json:"heads"`
Layers uint32 `json:"layers"`
Width uint32 `json:"width"`
} `json:"sam_vit_b"`
}
} `json:"vision_config"`
}
func (m *deepseekocr) KV(t *Tokenizer) ggml.KV {
kv := m.ModelParameters.KV(t)
kv["general.architecture"] = "deepseekocr"
kv["block_count"] = m.LanguageConfig.HiddenLayers
kv["context_length"] = m.LanguageConfig.MaxPositionEmbeddings
kv["embedding_length"] = m.LanguageConfig.HiddenSize
kv["feed_forward_length"] = m.LanguageConfig.IntermediateSize
kv["attention.head_count"] = m.LanguageConfig.NumAttentionHeads
kv["attention.head_count_kv"] = m.LanguageConfig.NumKeyValueHeads
kv["expert_count"] = m.LanguageConfig.NumRoutedExperts
kv["expert_used_count"] = m.LanguageConfig.NumExpertsPerToken
kv["leading_dense_block_count"] = m.LanguageConfig.FirstKDenseReplace
kv["vision.block_count"] = m.VisionConfig.Width.Vision.Layers
kv["vision.embedding_length"] = m.VisionConfig.Width.Vision.Width
kv["vision.head_count"] = m.VisionConfig.Width.Vision.Heads
kv["vision.image_size"] = m.VisionConfig.Width.Vision.ImageSize
kv["vision.patch_size"] = m.VisionConfig.Width.Vision.PatchSize
kv["sam.block_count"] = m.VisionConfig.Width.Sam.Layers
kv["sam.embedding_length"] = m.VisionConfig.Width.Sam.Width
kv["sam.head_count"] = m.VisionConfig.Width.Sam.Heads
kv["sam.global_attention_indexes"] = m.VisionConfig.Width.Sam.GlobalAttentionIndexes
return kv
}
func (m *deepseekocr) Tensors(s []Tensor) (out []*ggml.Tensor) {
merges := make([]merge, m.LanguageConfig.HiddenLayers*3)
for i := range m.LanguageConfig.HiddenLayers {
merges[i*3+0] = merge{
fmt.Sprintf("blk.%d.mlp.experts.*.gate_proj.weight", i),
fmt.Sprintf("blk.%d.ffn_gate_exps.weight", i),
}
merges[i*3+1] = merge{
fmt.Sprintf("blk.%d.mlp.experts.*.up_proj.weight", i),
fmt.Sprintf("blk.%d.ffn_up_exps.weight", i),
}
merges[i*3+2] = merge{
fmt.Sprintf("blk.%d.mlp.experts.*.down_proj.weight", i),
fmt.Sprintf("blk.%d.ffn_down_exps.weight", i),
}
}
out, s = mergeTensors(s, merges...)
for _, t := range s {
out = append(out, &ggml.Tensor{
Name: t.Name(),
Kind: t.Kind(),
Shape: t.Shape(),
WriterTo: t,
})
}
return out
}
func (m *deepseekocr) Replacements() []string {
return []string{
"model.embed_tokens", "token_embd",
"model.layers", "blk",
"input_layernorm", "attn_norm",
"self_attn.q_proj", "attn_q",
"self_attn.k_proj", "attn_k",
"self_attn.v_proj", "attn_v",
"self_attn.o_proj", "attn_output",
"post_attention_layernorm", "ffn_norm",
"mlp.gate_proj", "ffn_gate",
"mlp.up_proj", "ffn_up",
"mlp.down_proj", "ffn_down",
"mlp.gate", "ffn_gate_inp",
"mlp.shared_experts.gate_proj", "ffn_gate_shexp",
"mlp.shared_experts.up_proj", "ffn_up_shexp",
"mlp.shared_experts.down_proj", "ffn_down_shexp",
"model.norm", "output_norm",
"lm_head", "output",
"model.vision_model", "v",
"embeddings.patch_embedding", "patch_embd",
"embeddings.class_embedding", "class_embd",
"embeddings.position_embedding", "position_embd",
"transformer.layers", "blk",
"model.projector", "mm",
"model.image_newline", "mm.image_newline",
//nolint:misspell // this misspelling is upstream. fixing it breaks the model
"model.view_seperator", "mm.view_seperator",
"model.sam_model.patch_embed.proj", "s.patch_embd",
"model.sam_model.pos_embed", "s.position_embd",
"model.sam_model.blocks", "s.blk",
"model.sam_model.neck", "s.neck",
"model.sam_model.net_", "s.net_",
}
}

View File

@@ -44,10 +44,7 @@ func (t tensorBase) Kind() uint32 {
t.name == "v.positional_embedding_vlm" ||
t.name == "v.tile_position_embd.weight" ||
t.name == "v.pre_tile_position_embd.weight" ||
t.name == "v.post_tile_position_embd.weight" ||
t.name == "s.position_embd" ||
strings.HasSuffix(t.name, "rel_pos_h") ||
strings.HasSuffix(t.name, "rel_pos_w") {
t.name == "v.post_tile_position_embd.weight" {
// these tensors are always F32
return tensorKindFP32
}

View File

@@ -96,10 +96,7 @@ type safetensor struct {
func (st safetensor) Kind() uint32 {
kind := st.tensorBase.Kind()
if st.dtype == "BF16" &&
!strings.HasPrefix(st.name, "v.") &&
!strings.HasPrefix(st.name, "s.") &&
kind != tensorKindFP32 {
if !strings.HasPrefix(st.name, "v.") && st.dtype == "BF16" && kind != tensorKindFP32 {
kind = tensorKindBF16
}

View File

@@ -65,7 +65,6 @@ func GPUDevices(ctx context.Context, runners []ml.FilteredRunnerDiscovery) []ml.
}
slog.Info("discovering available GPUs...")
detectIncompatibleLibraries()
// Warn if any user-overrides are set which could lead to incorrect GPU discovery
overrideWarnings()
@@ -99,9 +98,6 @@ func GPUDevices(ctx context.Context, runners []ml.FilteredRunnerDiscovery) []ml.
continue
} else if jetpack != "" && filepath.Base(dir) != "cuda_"+jetpack {
continue
} else if jetpack == "" && strings.Contains(filepath.Base(dir), "cuda_jetpack") {
slog.Debug("jetpack not detected (set JETSON_JETPACK or OLLAMA_LLM_LIBRARY to override), skipping", "libDir", dir)
continue
} else if !envconfig.EnableVulkan() && strings.Contains(filepath.Base(dir), "vulkan") {
slog.Info("experimental Vulkan support disabled. To enable, set OLLAMA_VULKAN=1")
continue
@@ -129,20 +125,10 @@ func GPUDevices(ctx context.Context, runners []ml.FilteredRunnerDiscovery) []ml.
supportedMu := sync.Mutex{}
supported := make(map[string]map[string]map[string]int) // [Library][libDir][ID] = pre-deletion devices index
for i := range devices {
libDir := devices[i].LibraryPath[len(devices[i].LibraryPath)-1]
if !devices[i].NeedsInitValidation() {
// No need to validate, add to the supported map
supportedMu.Lock()
if _, ok := supported[devices[i].Library]; !ok {
supported[devices[i].Library] = make(map[string]map[string]int)
}
if _, ok := supported[devices[i].Library][libDir]; !ok {
supported[devices[i].Library][libDir] = make(map[string]int)
}
supported[devices[i].Library][libDir][devices[i].ID] = i
supportedMu.Unlock()
continue
}
libDir := devices[i].LibraryPath[len(devices[i].LibraryPath)-1]
slog.Debug("verifying if device is supported", "library", libDir, "description", devices[i].Description, "compute", devices[i].Compute(), "id", devices[i].ID, "pci_id", devices[i].PCIID)
wg.Add(1)
go func(i int) {
@@ -488,16 +474,3 @@ func overrideWarnings() {
slog.Warn("if GPUs are not correctly discovered, unset and try again")
}
}
func detectIncompatibleLibraries() {
if runtime.GOOS != "windows" {
return
}
basePath, err := exec.LookPath("ggml-base.dll")
if err != nil || basePath == "" {
return
}
if !strings.HasPrefix(basePath, ml.LibOllamaPath) {
slog.Warn("potentially incompatible library detected in PATH", "location", basePath)
}
}

View File

@@ -57,13 +57,8 @@ ollama ps
```
<Info>
**Output**:
```
NAME ID SIZE PROCESSOR UNTIL
llama3:70b bcfb190ca3a7 42 GB 100% GPU 4 minutes from now
```
**Output**: ``` NAME ID SIZE PROCESSOR UNTIL llama3:70b bcfb190ca3a7 42 GB
100% GPU 4 minutes from now ```
</Info>
The `Processor` column will show which memory the model was loaded in to:
@@ -390,4 +385,4 @@ Ollama for Windows and macOS register as a login item during installation. You
- In `Task Manager` go to the `Startup apps` tab, search for `ollama` then click `Disable`
**MacOS**
- Open `Settings` and search for "Login Items", find the `Ollama` entry under "Allow in the Background`, then click the slider to disable.
- Open `Settings` and search for "Login Items", find the `Ollama` entry under "Allow in the Background`, then click the slider to disable.

View File

@@ -149,6 +149,9 @@ PARAMETER <parameter> <parametervalue>
| Parameter | Description | Value Type | Example Usage |
| -------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ---------- | -------------------- |
| mirostat | Enable Mirostat sampling for controlling perplexity. (default: 0, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0) | int | mirostat 0 |
| mirostat_eta | Influences how quickly the algorithm responds to feedback from the generated text. A lower learning rate will result in slower adjustments, while a higher learning rate will make the algorithm more responsive. (Default: 0.1) | float | mirostat_eta 0.1 |
| mirostat_tau | Controls the balance between coherence and diversity of the output. A lower value will result in more focused and coherent text. (Default: 5.0) | float | mirostat_tau 5.0 |
| num_ctx | Sets the size of the context window used to generate the next token. (Default: 2048) | int | num_ctx 4096 |
| repeat_last_n | Sets how far back for the model to look back to prevent repetition. (Default: 64, 0 = disabled, -1 = num_ctx) | int | repeat_last_n 64 |
| repeat_penalty | Sets how strongly to penalize repetitions. A higher value (e.g., 1.5) will penalize repetitions more strongly, while a lower value (e.g., 0.9) will be more lenient. (Default: 1.1) | float | repeat_penalty 1.1 |

View File

@@ -249,9 +249,6 @@ func (kv KV) OllamaEngineRequired() bool {
"qwen25vl",
"qwen3", "qwen3moe",
"qwen3vl", "qwen3vlmoe",
"deepseekocr",
"deepseek2",
"nomic-bert",
}, kv.Architecture())
}

View File

@@ -388,9 +388,9 @@ func NewFunctionNameMap() *FunctionNameMap {
}
}
// Init initializes the handler with tools, optional last message, and think value
// Init initializes the handler with tools and optional last message
// Implements the Parser interface
func (h *HarmonyMessageHandler) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
func (h *HarmonyMessageHandler) Init(tools []api.Tool, lastMessage *api.Message) []api.Tool {
// Initialize the harmony parser
if h.HarmonyParser == nil {
h.HarmonyParser = &HarmonyParser{

View File

@@ -3,6 +3,7 @@ package kvcache
import (
"errors"
"fmt"
"log/slog"
"math"
"slices"
@@ -39,18 +40,18 @@ type Causal struct {
// ** current forward pass **
// the active layer for Get and Put
curLayer int
// starting location for data storage for this batch
curLoc int
// size of the current batch
curBatchSize int
// locations for data storage for this batch
curLoc ml.Tensor
// mask of the cache as used by this batch
curMask ml.Tensor
// the active layer for Get and Put
curLayer int
// locations in the cache that are needed for this batch
curCellRange cellRange
@@ -205,47 +206,45 @@ func (c *Causal) StartForward(ctx ml.Context, batch input.Batch, reserve bool) e
c.curPositions = batch.Positions
c.opts.Except = nil
var locs []int32
if !reserve {
c.updateSlidingWindow()
var err error
locs, err = c.findLocs()
c.curLoc, err = c.findStartLoc()
if errors.Is(err, ErrKvCacheFull) {
c.defrag()
c.curLoc, err = c.findStartLoc()
}
if err != nil {
return err
}
for i, pos := range batch.Positions {
seq := batch.Sequences[i]
loc := int(locs[i])
c.cells[loc] = cacheCell{pos: pos, sequences: []int{seq}}
c.cells[c.curLoc+i] = cacheCell{pos: pos, sequences: []int{seq}}
seqRange, ok := c.cellRanges[seq]
if !ok {
seqRange = newRange()
}
seqRange.min = min(seqRange.min, loc)
c.curCellRange.min = min(c.curCellRange.min, loc)
seqRange.min = min(seqRange.min, c.curLoc+i)
c.curCellRange.min = min(c.curCellRange.min, c.curLoc+i)
seqRange.max = max(seqRange.max, loc)
c.curCellRange.max = max(c.curCellRange.max, loc)
seqRange.max = max(seqRange.max, c.curLoc+i)
c.curCellRange.max = max(c.curCellRange.max, c.curLoc+i)
c.cellRanges[seq] = seqRange
}
} else {
// If we are reserving memory, don't update any of the cache metadata but set the size
// to the worst case.
locs = make([]int32, c.curBatchSize)
for i := range locs {
locs[i] = int32(i)
}
c.curLoc = 0
c.curCellRange.min = 0
c.curCellRange.max = len(c.cells) - 1
}
c.curLoc = ctx.Input().FromInts(locs, len(locs))
c.curMask = c.buildMask(ctx)
return nil
@@ -258,20 +257,22 @@ func newRange() cellRange {
}
}
// Returns a slice of locations where each token in the batch should be stored
func (c *Causal) findLocs() ([]int32, error) {
loc := make([]int32, 0, c.curBatchSize)
// Find the first contiguous block of at least curBatchSize
func (c *Causal) findStartLoc() (int, error) {
var start, count int
for i := range c.cells {
if len(c.cells[i].sequences) == 0 {
loc = append(loc, int32(i))
if len(loc) >= c.curBatchSize {
return loc, nil
count++
if count >= c.curBatchSize {
return start, nil
}
} else {
start = i + 1
count = 0
}
}
return nil, fmt.Errorf("%w (cache: %v batch: %v)", ErrKvCacheFull, len(c.cells), c.curBatchSize)
return 0, fmt.Errorf("%w (cache: %v batch: %v)", ErrKvCacheFull, len(c.cells), c.curBatchSize)
}
func (c *Causal) updateSlidingWindow() {
@@ -401,6 +402,145 @@ func (c *Causal) buildMask(ctx ml.Context) ml.Tensor {
return maskTensor
}
func (c *Causal) moveCells(ctx ml.Context, src, dst, length int) {
for i, key := range c.keys {
if key == nil {
continue
}
kHeadDim := key.Dim(0)
numKVHeads := key.Dim(1)
rowSize := key.Stride(2)
kSrcView := key.View(ctx, rowSize*src, kHeadDim*numKVHeads*length)
kDstView := key.View(ctx, rowSize*dst, kHeadDim*numKVHeads*length)
value := c.values[i]
var vSrcView, vDstView ml.Tensor
if c.config.PermutedV {
vHeadDim := value.Dim(1)
elemSize := value.Stride(0)
vSrcView = value.View(ctx, elemSize*src, length, len(c.cells)*elemSize, vHeadDim*numKVHeads)
vDstView = value.View(ctx, elemSize*dst, length, len(c.cells)*elemSize, vHeadDim*numKVHeads)
} else {
vHeadDim := value.Dim(0)
rowSize := value.Stride(2)
vSrcView = value.View(ctx, rowSize*src, vHeadDim*numKVHeads*length)
vDstView = value.View(ctx, rowSize*dst, vHeadDim*numKVHeads*length)
}
ctx.Forward(
kSrcView.Copy(ctx, kDstView),
vSrcView.Copy(ctx, vDstView),
)
}
}
func (c *Causal) defrag() {
slog.Debug("defragmenting kv cache")
// Defrag strategy:
// - Search for empty holes at the beginning of the cache,
// filling them with active data starting at the end
// - If there are contiguous elements that need to be moved,
// combine them into a single operation by holding new moves
// until we see that the next one is non-contiguous
// - Fill up the context with the maximum number of operations it
// can hold then compute that and continue with a new context
//
// We could try to optimize placement by grouping blocks from
// the same sequences together but most likely the next forward
// pass will disrupt this anyways, so the real world benefit
// seems limited as this time.
ctx := c.backend.NewContext()
// For every move, 6 tensors are required per layer (2 views and a
// copy for each of k and v). We also need to refer to the original
// k and v cache tensors - once per layer, not per move.
layers := 0
for _, key := range c.keys {
if key == nil {
continue
}
layers++
}
maxMoves := (ctx.MaxGraphNodes() - 2*layers) / (6 * layers)
moves := 0
var pendingSrc, pendingDst, pendingLen int
src := len(c.cells) - 1
for dst := 0; dst < src; dst++ {
if len(c.cells[dst].sequences) == 0 {
for ; src > dst; src-- {
if len(c.cells[src].sequences) != 0 {
c.cells[dst] = c.cells[src]
c.cells[src] = cacheCell{}
if pendingLen > 0 {
if src == pendingSrc-pendingLen && dst == pendingDst+pendingLen {
pendingSrc = src
pendingLen++
break
} else {
c.moveCells(ctx, pendingSrc, pendingDst, pendingLen)
moves++
}
}
pendingSrc = src
pendingDst = dst
pendingLen = 1
break
}
}
}
if moves >= maxMoves {
ctx.Compute()
ctx.Close()
ctx = c.backend.NewContext()
moves = 0
}
}
if pendingLen > 0 {
c.moveCells(ctx, pendingSrc, pendingDst, pendingLen)
moves++
}
if moves > 0 {
ctx.Compute()
}
ctx.Close()
// Reset range metadata
for seq := range c.cellRanges {
seqRange := newRange()
for i, cell := range c.cells {
if slices.Contains(cell.sequences, seq) {
if i < seqRange.min {
seqRange.min = i
}
if i > seqRange.max {
seqRange.max = i
}
}
}
c.cellRanges[seq] = seqRange
}
c.updateSlidingWindow()
}
func (c *Causal) SetLayer(layer int) {
c.curLayer = layer
}
@@ -485,25 +625,18 @@ func (c *Causal) Put(ctx ml.Context, key, value ml.Tensor) {
}
}
key = key.Reshape(ctx, kHeadDim*numKVHeads, batchSize)
keyCache := c.keys[c.curLayer]
keyCache = keyCache.Reshape(ctx, kHeadDim*numKVHeads, len(c.cells))
ctx.Forward(keyCache.SetRows(ctx, key, c.curLoc))
rowSize := c.keys[c.curLayer].Stride(2)
ctx.Forward(key.Copy(ctx, c.keys[c.curLayer].View(ctx, rowSize*c.curLoc, kHeadDim*numKVHeads*batchSize)))
if c.config.PermutedV {
value = value.Reshape(ctx, vHeadDim*numKVHeads, 1, batchSize)
value = value.Permute(ctx, 2, 0, 1, 3)
elemSize := c.values[c.curLayer].Stride(0)
valueCache := c.values[c.curLayer]
valueCache = valueCache.Reshape(ctx, 1, len(c.cells), vHeadDim*numKVHeads)
ctx.Forward(valueCache.SetRows(ctx, value, c.curLoc))
value = value.Permute(ctx, 1, 2, 0, 3)
ctx.Forward(value.Copy(ctx, c.values[c.curLayer].View(ctx, elemSize*c.curLoc, batchSize, len(c.cells)*elemSize, vHeadDim*numKVHeads)))
} else {
value = value.Reshape(ctx, vHeadDim*numKVHeads, batchSize)
valueCache := c.values[c.curLayer]
valueCache = valueCache.Reshape(ctx, vHeadDim*numKVHeads, len(c.cells))
rowSize := c.values[c.curLayer].Stride(2)
ctx.Forward(valueCache.SetRows(ctx, value, c.curLoc))
ctx.Forward(value.Copy(ctx, c.values[c.curLayer].View(ctx, rowSize*c.curLoc, vHeadDim*numKVHeads*batchSize)))
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -38,7 +38,7 @@ index 44ae76d66..639d551a2 100644
#ifdef __cplusplus
}
diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
index ca02ea079..c12b069e5 100644
index d2c278a35..221e29509 100644
--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
+++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -73,6 +73,7 @@ DispatchLoaderDynamic & ggml_vk_default_dispatcher();

View File

@@ -11,7 +11,7 @@ vidmem optimization.
1 file changed, 1 insertion(+), 4 deletions(-)
diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
index c12b069e5..76c78c2ea 100644
index 221e29509..18b7cbccf 100644
--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
+++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -5654,14 +5654,11 @@ static void ggml_vk_buffer_copy(vk_buffer& dst, size_t dst_offset, vk_buffer& sr

View File

@@ -50,7 +50,7 @@ Subject: [PATCH] Vulkan MMQ Integer Dot Refactor and K-Quant support (#16536)
create mode 100644 ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl
diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
index 76c78c2ea..7669ed206 100644
index 18b7cbccf..53b57c179 100644
--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
+++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -488,6 +488,7 @@ struct vk_device_struct {

View File

@@ -58,7 +58,7 @@ index 639d551a2..e5c446d1d 100644
GGML_API size_t gguf_type_size(enum gguf_type type);
GGML_API struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_params params);
diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
index 7669ed206..63a762ec2 100644
index 53b57c179..b2855b078 100644
--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
+++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -387,12 +387,76 @@ static constexpr uint32_t num_argsort_pipelines = 11;

View File

@@ -31,7 +31,7 @@ Add new backend tests.
6 files changed, 371 insertions(+), 117 deletions(-)
diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
index 63a762ec2..db92a7901 100644
index b2855b078..aaf4334b5 100644
--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
+++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -458,6 +458,11 @@ static topk_moe_mode ggml_vk_num_additional_ops_to_topk_moe_mode(uint32_t num) {

View File

@@ -9,7 +9,7 @@ Subject: [PATCH] vulkan: Handle argsort with a large number of rows (#16851)
2 files changed, 16 insertions(+), 4 deletions(-)
diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
index db92a7901..e959674d1 100644
index aaf4334b5..3604ceb04 100644
--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
+++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -1084,6 +1084,7 @@ struct vk_op_soft_max_push_constants {

View File

@@ -20,7 +20,7 @@ Subject: [PATCH] vulkan: Fix crash when FP16 mul_mat accumulation is not
1 file changed, 13 insertions(+), 7 deletions(-)
diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
index e959674d1..903050b0b 100644
index 3604ceb04..80185d9f0 100644
--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
+++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -146,8 +146,13 @@ static void ggml_vk_destroy_pipeline(vk::Device& device, vk_pipeline& pipeline);

View File

@@ -1,25 +0,0 @@
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: Michael Yang <git@mxy.ng>
Date: Tue, 18 Nov 2025 11:13:04 -0800
Subject: [PATCH] ggml-cuda: skip large batches
cuda panics on batches larger than 1024 so mark it as unsupported to
fallback to cpu
---
ggml/src/ggml-cuda/ggml-cuda.cu | 3 +++
1 file changed, 3 insertions(+)
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index f1a20e7fe..1a71e07c9 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -3677,6 +3677,9 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
if (b->type == GGML_TYPE_F16 && a->type != GGML_TYPE_F16) {
return false;
}
+ if (op->op == GGML_OP_MUL_MAT && b->ne[2] * b->ne[3] > 1024) {
+ return false;
+ }
#ifdef GGML_USE_MUSA
const int cc = ggml_cuda_info().devices[dev_ctx->device].cc;
if (b->ne[2]*b->ne[3] > 1 && !ggml_is_transposed(a) && !ggml_is_transposed(b)) {

View File

@@ -1,28 +0,0 @@
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: Daniel Hiltgen <daniel@ollama.com>
Date: Tue, 18 Nov 2025 09:58:23 -0800
Subject: [PATCH] win: exit instead of abort
---
ggml/src/ggml.c | 7 ++++++-
1 file changed, 6 insertions(+), 1 deletion(-)
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index 9be35c1be..923c33d05 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -229,8 +229,13 @@ void ggml_abort(const char * file, int line, const char * fmt, ...) {
fprintf(stderr, "%s\n", message);
ggml_print_backtrace();
}
-
+#if defined(_WIN32)
+ fflush(stderr);
+ fflush(stdout);
+ exit(1);
+#else
abort();
+#endif
}
// ggml_print_backtrace is registered with std::set_terminate by ggml.cpp

View File

@@ -173,7 +173,6 @@ type Tensor interface {
Cos(ctx Context) Tensor
Tanh(ctx Context) Tensor
GELU(ctx Context, up ...Tensor) Tensor
QuickGELU(ctx Context, up ...Tensor) Tensor
SILU(ctx Context, up ...Tensor) Tensor
RELU(ctx Context, up ...Tensor) Tensor
Sigmoid(ctx Context) Tensor
@@ -194,7 +193,6 @@ type Tensor interface {
Repeat(ctx Context, dim, n int) Tensor
Concat(ctx Context, t2 Tensor, dim int) Tensor
Rows(ctx Context, t2 Tensor) Tensor
SetRows(ctx Context, src Tensor, idxs Tensor) Tensor
Copy(ctx Context, t2 Tensor) Tensor
Duplicate(ctx Context) Tensor
@@ -209,8 +207,6 @@ type Tensor interface {
Stddev(ctx Context) Tensor
Sqr(ctx Context) Tensor
Sqrt(ctx Context) Tensor
Interpolate(ctx Context, dims [4]int, samplingMode SamplingMode) Tensor
}
// ScaledDotProductAttention implements a fused attention
@@ -376,10 +372,3 @@ const (
DTypeI32
DTypeMXFP4
)
type SamplingMode int
const (
SamplingModeNearest SamplingMode = iota
SamplingModeBilinear
)

View File

@@ -314,7 +314,7 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) {
"altup_proj", "altup_unembd_proj",
"per_layer_token_embd", "per_layer_model_proj", "per_layer_proj_norm"):
createTensor(tensor{source: t}, output.bts, blocks)
case strings.HasPrefix(t.Name, "v.") || strings.HasPrefix(t.Name, "mm.") || strings.HasPrefix(t.Name, "s."):
case strings.HasPrefix(t.Name, "v.") || strings.HasPrefix(t.Name, "mm."):
// TODO: assign vision tensors to the gpu if possible
createTensor(tensor{source: t}, output.bts, blocks)
case contains(t.Name, "rope_freqs", "rope_factors_long", "rope_factors_short"):
@@ -1338,13 +1338,6 @@ func (t *Tensor) Rows(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
}
}
func (t *Tensor) SetRows(ctx ml.Context, src ml.Tensor, idxs ml.Tensor) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_set_rows(ctx.(*Context).ctx, t.t, src.(*Tensor).t, idxs.(*Tensor).t),
}
}
func (t *Tensor) Copy(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
return &Tensor{
b: t.b,
@@ -1385,10 +1378,6 @@ func inferShape(t *Tensor, shape []int) {
}
func (t *Tensor) Reshape(ctx ml.Context, shape ...int) ml.Tensor {
if !C.ggml_is_contiguous(t.t) {
return t.Contiguous(ctx, shape...)
}
if slices.Contains(shape, -1) {
inferShape(t, shape)
}
@@ -1578,16 +1567,6 @@ func (t *Tensor) GELU(ctx ml.Context, t2 ...ml.Tensor) ml.Tensor {
}
}
func (t *Tensor) QuickGELU(ctx ml.Context, t2 ...ml.Tensor) ml.Tensor {
var tt *C.struct_ggml_tensor
if len(t2) > 0 {
tt = C.ggml_geglu_quick_split(ctx.(*Context).ctx, t.t, t2[0].(*Tensor).t)
} else {
tt = C.ggml_gelu_quick_inplace(ctx.(*Context).ctx, t.t)
}
return &Tensor{b: t.b, t: tt}
}
func (t *Tensor) SILU(ctx ml.Context, t2 ...ml.Tensor) ml.Tensor {
if len(t2) > 0 {
return &Tensor{
@@ -1745,23 +1724,6 @@ func (t *Tensor) Sqrt(ctx ml.Context) ml.Tensor {
}
}
func (t *Tensor) Interpolate(ctx ml.Context, dims [4]int, samplingMode ml.SamplingMode) ml.Tensor {
var mode C.uint32_t
switch samplingMode {
case ml.SamplingModeNearest:
mode = C.GGML_SCALE_MODE_NEAREST
case ml.SamplingModeBilinear:
mode = C.GGML_SCALE_MODE_BILINEAR
default:
panic("unsupported interpolate mode")
}
return &Tensor{
b: t.b,
t: C.ggml_interpolate(ctx.(*Context).ctx, t.t, C.int64_t(dims[0]), C.int64_t(dims[1]), C.int64_t(dims[2]), C.int64_t(dims[3]), mode),
}
}
// Slice returns a view of the tensor sliced along dim from low to high in step steps.
// Slice panics if the dimension is invalid or the slice parameters are out of range.
// If dim=0 and step>1, the tensor is a copy rather than a view to ensure proper shape.

View File

@@ -3677,9 +3677,6 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
if (b->type == GGML_TYPE_F16 && a->type != GGML_TYPE_F16) {
return false;
}
if (op->op == GGML_OP_MUL_MAT && b->ne[2] * b->ne[3] > 1024) {
return false;
}
#ifdef GGML_USE_MUSA
const int cc = ggml_cuda_info().devices[dev_ctx->device].cc;
if (b->ne[2]*b->ne[3] > 1 && !ggml_is_transposed(a) && !ggml_is_transposed(b)) {

View File

@@ -229,13 +229,8 @@ void ggml_abort(const char * file, int line, const char * fmt, ...) {
fprintf(stderr, "%s\n", message);
ggml_print_backtrace();
}
#if defined(_WIN32)
fflush(stderr);
fflush(stdout);
exit(1);
#else
abort();
#endif
}
// ggml_print_backtrace is registered with std::set_terminate by ggml.cpp

View File

@@ -25,15 +25,12 @@ const (
// Composite returns an image with the alpha channel removed by drawing over a white background.
func Composite(img image.Image) image.Image {
white := color.RGBA{255, 255, 255, 255}
return CompositeColor(img, white)
}
// CompositeColor returns an image with the alpha channel removed by drawing over a white background.
func CompositeColor(img image.Image, color color.Color) image.Image {
dst := image.NewRGBA(img.Bounds())
draw.Draw(dst, dst.Bounds(), &image.Uniform{color}, image.Point{}, draw.Src)
white := color.RGBA{255, 255, 255, 255}
draw.Draw(dst, dst.Bounds(), &image.Uniform{white}, image.Point{}, draw.Src)
draw.Draw(dst, dst.Bounds(), img, img.Bounds().Min, draw.Over)
return dst
}
@@ -58,31 +55,6 @@ func Resize(img image.Image, newSize image.Point, method int) image.Image {
return dst
}
// Pad returns an image which has been resized to fit within a new size, preserving aspect ratio, and padded with a color.
func Pad(img image.Image, newSize image.Point, color color.Color, kernel draw.Interpolator) image.Image {
dst := image.NewRGBA(image.Rect(0, 0, newSize.X, newSize.Y))
draw.Draw(dst, dst.Bounds(), &image.Uniform{color}, image.Point{}, draw.Src)
var minPoint, maxPoint image.Point
if img.Bounds().Dx() > img.Bounds().Dy() {
// landscape
height := newSize.X * img.Bounds().Dy() / img.Bounds().Dx()
minPoint = image.Point{0, (newSize.Y - height) / 2}
maxPoint = image.Point{newSize.X, height + minPoint.Y}
} else {
// portrait
width := newSize.Y * img.Bounds().Dx() / img.Bounds().Dy()
minPoint = image.Point{(newSize.X - width) / 2, 0}
maxPoint = image.Point{minPoint.X + width, newSize.Y}
}
kernel.Scale(dst, image.Rectangle{
Min: minPoint,
Max: maxPoint,
}, img, img.Bounds(), draw.Over, nil)
return dst
}
// Normalize returns a slice of float32 containing each of the r, g, b values for an image normalized around a value.
func Normalize(img image.Image, mean, std [3]float32, rescale bool, channelFirst bool) []float32 {
var pixelVals []float32

View File

@@ -156,7 +156,6 @@ func New(c fs.Config) (model.Model, error) {
)),
},
},
true,
)
default:
return nil, model.ErrUnsupportedTokenizer

View File

@@ -254,30 +254,6 @@ func New(c fs.Config) (model.Model, error) {
keyLength := int(cmp.Or(c.Uint("attention.key_length_mla"), c.Uint("attention.key_length")))
valueLength := int(cmp.Or(c.Uint("attention.value_length_mla"), c.Uint("attention.value_length")))
var pre []string
switch c.String("tokenizer.ggml.pre") {
case "deepseek-v3":
pre = []string{
// Split regex into multiple parts (according to DeepSeek3's regex)
"\\p{N}{1,3}",
`[一-龥぀-ゟ゠-ヿ]+`,
"[!\"#$%&'()*+,\\-./:;<=>?@\\[\\\\\\]^_`{|}~][A-Za-z]+|[^\r\n\\p{L}\\p{P}\\p{S}]?[\\p{L}\\p{M}]+| ?[\\p{P}\\p{S}]+[\r\n]*|\\s*[\r\n]+|\\s+(?!\\S)|\\s+",
}
case "deepseek-llm":
// TODO: these models haven't been vetted so skip for now
// pre = []string{
// "[\r\n]",
// "\\s?[A-Za-zµÀ-ÖØ-öø-ƺƼ-ƿDŽ-ʓʕ-ʯͰ-ͳͶͷͻ-ͽͿΆΈ-ΊΌΎ-ΡΣ-ϵϷ-ҁҊ-ԯԱ-ՖႠ-ჅᎠ-Ᏽᏸ-ᏽᲐ-ᲺᲽ-Ჿᴀ-ᴫᵫ-ᵷᵹ-ᶚḀ-ἕἘ-Ἕἠ-ὅὈ-Ὅὐ-ὗὙὛὝὟ-ώᾀ-ᾴᾶ-ᾼιῂ-ῄῆ-ῌῐ-ΐῖ-Ίῠ-Ῥῲ-ῴῶ-ῼℂℇℊ--ℝℤΩℨK--ℴℹℼ-ℿⅅ-ⅉⅎↃↄⰀ-ⱻⱾ-ⳤⳫ-ⳮⳲⳳꙀ-ꙭꚀ-ꚛꜢ-ꝯꝱ-ꞇꞋ-ꞎꭰ-ꮿff-stﬓ-ﬗA--z𐐀-𐑏𐒰-𐓓𐓘-𐓻𐲀-𐲲𐳀-𐳲𑢠-𑣟𞤀-𞥃]+",
// "\\s?[!-/:-~---‟ -。]+",
// "\\s+$",
// "[一-龥ࠀ-一가-퟿]+",
// "[0-9]",
// }
fallthrough
default:
return nil, model.ErrUnsupportedTokenizer
}
m := Model{
BytePairEncoding: model.NewBytePairEncoding(
&model.Vocabulary{
@@ -292,7 +268,10 @@ func New(c fs.Config) (model.Model, error) {
c.Ints("tokenizer.ggml.eos_token_ids")...,
),
},
pre...,
// Split regex into multiple parts (according to DeepSeek3's regex)
"\\p{N}{1,3}",
`[一-龥぀-ゟ゠-ヿ]+`,
"[!\"#$%&'()*+,\\-./:;<=>?@\\[\\\\\\]^_`{|}~][A-Za-z]+|[^\r\n\\p{L}\\p{P}\\p{S}]?[\\p{L}\\p{M}]+| ?[\\p{P}\\p{S}]+[\r\n]*|\\s*[\r\n]+|\\s+(?!\\S)|\\s+",
),
Layers: layers,
Options: &Options{

View File

@@ -1,83 +0,0 @@
package deepseekocr
import (
"bytes"
"image"
"image/color"
"math"
"slices"
"golang.org/x/image/draw"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model/imageproc"
)
type ratio struct {
x, y int
}
func ProcessImage(ctx ml.Context, bts []byte) (ml.Tensor, ml.Tensor, []int, error) {
img, _, err := image.Decode(bytes.NewReader(bts))
if err != nil {
return nil, nil, nil, err
}
minNum, maxNum, imageSize, baseSize := 2, 9, 640, 1024
var targetRatios []ratio
for n := minNum; n <= maxNum; n++ {
for i := 1; i <= n; i++ {
for j := 1; j <= n; j++ {
if i*j <= maxNum && i*j >= minNum && !slices.Contains(targetRatios, ratio{i, j}) {
targetRatios = append(targetRatios, ratio{i, j})
}
}
}
}
targetRatio := findBestAspectRatio(targetRatios, img.Bounds().Dx(), img.Bounds().Dy(), imageSize)
targetWidth, targetHeight := imageSize*targetRatio.x, imageSize*targetRatio.y
blocks := targetRatio.x * targetRatio.y
mean := imageproc.ImageNetStandardMean
std := imageproc.ImageNetStandardSTD
var patches []float32
resized := imageproc.Resize(img, image.Point{X: targetWidth, Y: targetHeight}, imageproc.ResizeBilinear)
for i := range blocks {
patch := image.NewRGBA(image.Rect(0, 0, imageSize, imageSize))
draw.Draw(patch, patch.Bounds(), resized, image.Point{
X: i % (targetWidth / imageSize) * imageSize,
Y: i / (targetWidth / imageSize) * imageSize,
}, draw.Over)
patches = append(patches, imageproc.Normalize(patch, mean, std, true, true)...)
}
img = imageproc.CompositeColor(img, color.Gray{})
img = imageproc.Pad(img, image.Point{X: baseSize, Y: baseSize}, color.Gray{127}, draw.BiLinear)
return ctx.Input().FromFloats(patches, imageSize, imageSize, 3, blocks),
ctx.Input().FromFloats(imageproc.Normalize(img, mean, std, true, true), baseSize, baseSize, 3),
[]int{targetRatio.x, targetRatio.y},
nil
}
func findBestAspectRatio(targetRatios []ratio, width, height, imageSize int) ratio {
bestDiff := math.MaxFloat64
best := ratio{1, 1}
realRatio := float64(width) / float64(height)
for _, target := range targetRatios {
targetRatio := float64(target.x) / float64(target.y)
diff := math.Abs(realRatio - targetRatio)
if diff < bestDiff {
bestDiff = diff
best = target
} else if diff == bestDiff {
if float64(width*height) > 0.5*float64(imageSize*imageSize*best.x*best.y) {
best = target
}
}
}
return best
}

View File

@@ -1,192 +0,0 @@
package deepseekocr
import (
"math"
"slices"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
)
type Model struct {
model.Base
model.TextProcessor
Sam *samModel `gguf:"s"`
Vision *visionModel `gguf:"v"`
Text *textModel
ImageNewline ml.Tensor `gguf:"mm.image_newline"`
//nolint:misspell // this misspelling is upstream. fixing it breaks the model
ViewSeperator ml.Tensor `gguf:"mm.view_seperator"`
Projector *nn.Linear `gguf:"mm.layers"`
}
func (m *Model) EncodeMultimodal(ctx ml.Context, bts []byte) ([]input.Multimodal, error) {
patches, original, crop, err := ProcessImage(ctx, bts)
if err != nil {
return nil, err
}
var outputs []ml.Tensor
if true { // TODO: local features if sum(patches) != 0
samOutputs := m.Sam.Forward(ctx, patches)
visionOutputs := m.Vision.Forward(ctx, patches, samOutputs)
samOutputs = samOutputs.Reshape(ctx, -1, samOutputs.Dim(2), samOutputs.Dim(3)).Permute(ctx, 1, 0, 2, 3)
visionOutputs = visionOutputs.Slice(ctx, 1, 1, visionOutputs.Dim(1), 1)
localOutputs := visionOutputs.Concat(ctx, samOutputs, 0)
localOutputs = m.Projector.Forward(ctx, localOutputs)
hw := int(math.Sqrt(float64(localOutputs.Dim(1))))
localOutputs = localOutputs.Reshape(ctx, -1, hw, crop[0], crop[1])
localOutputs = localOutputs.Permute(ctx, 0, 2, 1, 3)
localOutputs = localOutputs.Contiguous(ctx, -1, crop[0]*hw, crop[1]*hw)
localOutputs = localOutputs.Concat(ctx, m.ImageNewline.Repeat(ctx, 2, localOutputs.Dim(2)), 1)
localOutputs = localOutputs.Reshape(ctx, localOutputs.Dim(0), -1)
outputs = append(outputs, localOutputs)
}
samOutputs := m.Sam.Forward(ctx, original)
visionOutputs := m.Vision.Forward(ctx, original, samOutputs)
samOutputs = samOutputs.Reshape(ctx, -1, samOutputs.Dim(2), samOutputs.Dim(3)).Permute(ctx, 1, 0, 2, 3)
visionOutputs = visionOutputs.Slice(ctx, 1, 1, visionOutputs.Dim(1), 1)
globalOutputs := visionOutputs.Concat(ctx, samOutputs, 0)
globalOutputs = m.Projector.Forward(ctx, globalOutputs)
hw := int(math.Sqrt(float64(globalOutputs.Dim(1))))
globalOutputs = globalOutputs.Reshape(ctx, -1, hw, hw)
globalOutputs = globalOutputs.Concat(ctx, m.ImageNewline.Repeat(ctx, 2, globalOutputs.Dim(2)), 1)
globalOutputs = globalOutputs.Reshape(ctx, globalOutputs.Dim(0), -1)
outputs = append(outputs, globalOutputs, m.ViewSeperator)
return []input.Multimodal{
{Tensor: outputs[0].Stack(ctx, 1, outputs[1:]...)},
}, nil
}
func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
outputs := make([]*input.Input, 0, len(inputs))
for i := range inputs {
if inputs[i].Multimodal == nil {
outputs = append(outputs, inputs[i])
continue
}
t := inputs[i].Multimodal[0].Tensor
outputs = append(outputs, &input.Input{
Token: 128815,
Multimodal: inputs[i].Multimodal,
MultimodalHash: inputs[i].MultimodalHash,
SameBatch: t.Dim(1) - 1,
})
outputs = slices.Grow(outputs, t.Dim(1)-1)
outputs = append(outputs, slices.Repeat([]*input.Input{{Token: 128815}}, t.Dim(1)-1)...)
}
return outputs, nil
}
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
inputsEmbeds := m.Text.TokenEmbedding.Forward(ctx, batch.Inputs).Duplicate(ctx)
positions := ctx.Input().FromInts(batch.Positions, len(batch.Positions))
for _, mm := range batch.Multimodal {
t := mm.Multimodal[0].Tensor
ctx.Forward(t.Copy(ctx, inputsEmbeds.View(ctx, mm.Index*inputsEmbeds.Stride(1), t.Dim(0)*t.Dim(1))))
}
hiddenStates := inputsEmbeds
for i, block := range m.Text.Blocks {
if m.Cache != nil {
m.Cache.SetLayer(i)
}
var outputs ml.Tensor
if i == len(m.Text.Blocks)-1 {
outputs = batch.Outputs
}
hiddenStates = block.Forward(ctx, hiddenStates, positions, outputs, m.Cache, m.Text.Options)
}
hiddenStates = m.Text.OutputNorm.Forward(ctx, hiddenStates, m.Text.Options.eps)
return m.Text.Output.Forward(ctx, hiddenStates), nil
}
func init() {
model.Register("deepseekocr", func(c fs.Config) (model.Model, error) {
textBlocks := make([]textBlock, c.Uint("block_count"))
leadingDenseBlockCount := int(c.Uint("leading_dense_block_count", 1))
for i := range textBlocks {
if i >= leadingDenseBlockCount {
textBlocks[i].FeedForward = &textMoe{}
} else {
textBlocks[i].FeedForward = &textMLP{}
}
}
m := Model{
TextProcessor: model.NewBytePairEncoding(
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
EOS: append(
[]int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))},
c.Ints("tokenizer.ggml.eos_token_ids")...,
),
},
// Split regex into multiple parts (according to DeepSeek3's regex)
"\\p{N}{1,3}",
`[一-龥぀-ゟ゠-ヿ]+`,
"[!\"#$%&'()*+,\\-./:;<=>?@\\[\\\\\\]^_`{|}~][A-Za-z]+|[^\r\n\\p{L}\\p{P}\\p{S}]?[\\p{L}\\p{M}]+| ?[\\p{P}\\p{S}]+[\r\n]*|\\s*[\r\n]+|\\s+(?!\\S)|\\s+",
),
Text: &textModel{
Blocks: textBlocks,
Options: textOptions{
hiddenSize: int(c.Uint("embedding_length")),
numHeads: int(c.Uint("attention.head_count")),
numKVHeads: int(c.Uint("attention.head_count_kv")),
numExperts: int(c.Uint("expert_count")),
numExpertsUsed: int(c.Uint("expert_used_count")),
ropeBase: c.Float("rope.freq_base", 10_000),
ropeScale: c.Float("rope.scaling.factor", 1.0),
eps: c.Float("attention.layer_norm_rms_epsilon", 1e-6),
},
},
Vision: &visionModel{
Blocks: make([]visionBlock, c.Uint("vision.block_count")),
Options: visionOptions{
hiddenSize: int(c.Uint("vision.embedding_length")),
numHeads: int(c.Uint("vision.head_count")),
imageSize: int(c.Uint("vision.image_size", 224)),
patchSize: int(c.Uint("vision.patch_size", 14)),
eps: c.Float("vision.attention.layer_norm_epsilon", 1e-5),
},
},
Sam: &samModel{
Blocks: make([]samBlock, c.Uint("sam.block_count")),
Options: samOptions{
hiddenSize: int(c.Uint("sam.embedding_length")),
numHeads: int(c.Uint("sam.head_count")),
eps: c.Float("sam.attention.layer_norm_epsilon", 1e-6),
globalAttentionLayers: c.Ints("sam.global_attention_indexes"),
},
},
}
m.Cache = kvcache.NewCausalCache(m.Text.Shift)
return &m, nil
})
}

View File

@@ -1,225 +0,0 @@
package deepseekocr
import (
"math"
"slices"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
)
type samModel struct {
PatchEmbedding *nn.Conv2D `gguf:"patch_embd"`
PositionEmbedding ml.Tensor `gguf:"position_embd"`
Blocks []samBlock `gguf:"blk"`
Neck *samNeck `gguf:"neck"`
Net2 *nn.Conv2D `gguf:"net_2"`
Net3 *nn.Conv2D `gguf:"net_3"`
Options samOptions
}
func (m *samModel) absolutePositionEmbedding(ctx ml.Context, hiddenStates ml.Tensor) ml.Tensor {
source := m.PositionEmbedding.Dim(1)
target := hiddenStates.Dim(2)
if source != target {
positionEmbed := m.PositionEmbedding.Permute(ctx, 2, 0, 1, 3)
positionEmbed = positionEmbed.Interpolate(ctx, [4]int{target, target, hiddenStates.Dim(0), 1}, ml.SamplingModeBilinear)
return positionEmbed.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
}
return m.PositionEmbedding
}
func (m *samModel) Forward(ctx ml.Context, t ml.Tensor) ml.Tensor {
hiddenStates := m.PatchEmbedding.Forward(ctx, t, 16, 16, 0, 0, 1, 1)
hiddenStates = hiddenStates.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
if m.PositionEmbedding != nil {
hiddenStates = hiddenStates.Add(ctx, m.absolutePositionEmbedding(ctx, hiddenStates))
}
for i, block := range m.Blocks {
var windowSize int
if !slices.Contains(m.Options.globalAttentionLayers, int32(i)) {
windowSize = 14
}
hiddenStates = block.Forward(ctx, hiddenStates, windowSize, m.Options)
}
hiddenStates = hiddenStates.Permute(ctx, 2, 0, 1, 3).Contiguous(ctx)
hiddenStates = m.Neck.Forward(ctx, hiddenStates, m.Options)
hiddenStates = m.Net2.Forward(ctx, hiddenStates, 2, 2, 1, 1, 1, 1)
hiddenStates = m.Net3.Forward(ctx, hiddenStates, 2, 2, 1, 1, 1, 1)
return hiddenStates
}
type samOptions struct {
hiddenSize,
numHeads int
eps float32
globalAttentionLayers []int32
}
func (o samOptions) headDim() int {
return o.hiddenSize / o.numHeads
}
type samBlock struct {
Norm1 *nn.LayerNorm `gguf:"norm1"`
Attention *samAttention `gguf:"attn"`
Norm2 *nn.LayerNorm `gguf:"norm2"`
FeedForward *samMLP `gguf:"mlp"`
}
func (m *samBlock) Forward(ctx ml.Context, hiddenStates ml.Tensor, windowSize int, opts samOptions) ml.Tensor {
c, w, h := hiddenStates.Dim(0), hiddenStates.Dim(1), hiddenStates.Dim(2)
residual := hiddenStates
hiddenStates = m.Norm1.Forward(ctx, hiddenStates, opts.eps)
var pw, ph int
if windowSize > 0 {
pw = (windowSize - hiddenStates.Dim(1)%windowSize) % windowSize
ph = (windowSize - hiddenStates.Dim(2)%windowSize) % windowSize
if pw > 0 || ph > 0 {
hiddenStates = hiddenStates.Pad(ctx, 0, pw, ph, 0)
}
hiddenStates = hiddenStates.Reshape(ctx, c*windowSize, (w+pw)/windowSize, windowSize, -1)
hiddenStates = hiddenStates.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx, c, windowSize, windowSize, -1)
}
hiddenStates = m.Attention.Forward(ctx, hiddenStates, opts)
if windowSize > 0 {
hiddenStates = hiddenStates.Reshape(ctx, c*windowSize, windowSize, (w+pw)/windowSize, -1)
hiddenStates = hiddenStates.Permute(ctx, 0, 2, 1, 3)
hiddenStates = hiddenStates.Contiguous(ctx, c, w+pw, h+ph, -1)
hiddenStates = hiddenStates.Pad(ctx, 0, -pw, -ph, 0)
}
hiddenStates = hiddenStates.Add(ctx, residual)
residual = hiddenStates
hiddenStates = m.Norm2.Forward(ctx, hiddenStates, opts.eps)
hiddenStates = m.FeedForward.Forward(ctx, hiddenStates, opts)
return hiddenStates.Add(ctx, residual)
}
type samAttention struct {
QKV *nn.Linear `gguf:"qkv"`
Output *nn.Linear `gguf:"proj"`
RelativePosition *struct {
Height ml.Tensor `gguf:"h"`
Width ml.Tensor `gguf:"w"`
} `gguf:",pre:rel_pos_"`
}
func relativeCoordinates(ctx ml.Context, qn, kn int) ml.Tensor {
s := make([]int32, qn*kn)
for i := range qn {
for j := range kn {
q := i * max(kn/qn, 1)
k := j * max(qn/kn, 1)
s[i*kn+j] = int32(q - k + (kn-1)*max(qn/kn, 1))
}
}
return ctx.Input().FromInts(s, qn*kn)
}
func relativePositions(ctx ml.Context, positions ml.Tensor, qn, kn int) ml.Tensor {
maxRelativeDistance := 2*max(qn, kn) - 1
if positions.Dim(1) != maxRelativeDistance {
// linear interpolation kernel not available so approx. with bilinear interpolation
positions = positions.Interpolate(ctx, [4]int{positions.Dim(0), maxRelativeDistance, 1, 1}, ml.SamplingModeBilinear)
}
rc := relativeCoordinates(ctx, qn, kn)
return positions.Rows(ctx, rc).Reshape(ctx, positions.Dim(0), kn, qn)
}
func (m *samAttention) decomposedRelativePositions(ctx ml.Context, query ml.Tensor, qn, kn []int) (ml.Tensor, ml.Tensor) {
qh, qw := qn[0], qn[1]
kh, kw := kn[0], kn[1]
rh := relativePositions(ctx, m.RelativePosition.Height, qh, kh)
rw := relativePositions(ctx, m.RelativePosition.Width, qw, kw)
query = query.Contiguous(ctx, query.Dim(0), qw, qh, -1)
rh = rh.Mulmat(ctx, query).Reshape(ctx, 1, kh, qh*qw, -1)
rw = rw.Mulmat(ctx, query.Permute(ctx, 0, 2, 1, 3)).Permute(ctx, 0, 2, 1, 3).Contiguous(ctx, kw, 1, qh*qw, -1)
return rh, rw
}
func (m *samAttention) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts samOptions) ml.Tensor {
w, h, b := hiddenStates.Dim(1), hiddenStates.Dim(2), hiddenStates.Dim(3)
qkv := m.QKV.Forward(ctx, hiddenStates)
qkv = qkv.Reshape(ctx, opts.headDim(), -1, w*h, b)
chunks := qkv.Chunk(ctx, 1, opts.numHeads)
query, key, value := chunks[0], chunks[1], chunks[2]
ctx.Forward(query, key, value)
query = query.Permute(ctx, 0, 2, 1, 3)
rh, rw := m.decomposedRelativePositions(ctx, query, []int{h, w}, []int{h, w})
mask := rh.Repeat(ctx, 0, rw.Dim(0)).Add(ctx, rw)
mask = mask.Reshape(ctx, h*w, -1, opts.numHeads, b)
key = key.Permute(ctx, 0, 2, 1, 3)
scores := key.MulmatFullPrec(ctx, query)
scores = scores.Scale(ctx, 1/math.Sqrt(float64(opts.headDim())))
scores = scores.Add(ctx, mask)
scores = scores.Softmax(ctx)
value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
attention := value.Mulmat(ctx, scores)
attention = attention.Permute(ctx, 0, 2, 1, 3)
attention = attention.Contiguous(ctx, -1, w, h, b)
return m.Output.Forward(ctx, attention)
}
type samMLP struct {
Lin1 *nn.Linear `gguf:"lin1"`
Lin2 *nn.Linear `gguf:"lin2"`
}
func (m *samMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts samOptions) ml.Tensor {
return m.Lin2.Forward(ctx, m.Lin1.Forward(ctx, hiddenStates).GELU(ctx))
}
type LayerNorm2D struct {
Weight ml.Tensor `gguf:"weight"`
Bias ml.Tensor `gguf:"bias"`
}
func (ln *LayerNorm2D) Forward(ctx ml.Context, x ml.Tensor, eps float32) ml.Tensor {
x = x.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
u := x.Mean(ctx)
d := x.Sub(ctx, u)
s := d.Sqr(ctx).Mean(ctx)
x = d.Div(ctx, s.Add(ctx, ctx.Input().FromFloats([]float32{eps}, 1)).Sqrt(ctx))
x = x.Mul(ctx, ln.Weight).Add(ctx, ln.Bias)
return x.Permute(ctx, 2, 0, 1, 3).Contiguous(ctx)
}
type samNeck struct {
C1 *nn.Conv2D `gguf:"0"`
LN1 *LayerNorm2D `gguf:"1"`
C2 *nn.Conv2D `gguf:"2"`
LN2 *LayerNorm2D `gguf:"3"`
}
func (m *samNeck) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts samOptions) ml.Tensor {
hiddenStates = m.C1.Forward(ctx, hiddenStates, 1, 1, 0, 0, 1, 1)
hiddenStates = m.LN1.Forward(ctx, hiddenStates, opts.eps)
hiddenStates = m.C2.Forward(ctx, hiddenStates, 1, 1, 1, 1, 1, 1)
hiddenStates = m.LN2.Forward(ctx, hiddenStates, opts.eps)
return hiddenStates
}

View File

@@ -1,140 +0,0 @@
package deepseekocr
import (
"math"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/ml/nn/fast"
"github.com/ollama/ollama/ml/nn/rope"
)
type textModel struct {
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
Blocks []textBlock `gguf:"blk"`
OutputNorm *nn.RMSNorm `gguf:"output_norm"`
Output *nn.Linear `gguf:"output"`
Options textOptions
}
func (m *textModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
return m.Options.applyRotaryPositionalEmbedding(ctx, key, shift), nil
}
type textOptions struct {
hiddenSize,
numHeads,
numKVHeads,
numExperts,
numExpertsUsed int
ropeBase,
ropeScale,
eps float32
}
func (o textOptions) headDim() int {
return o.hiddenSize / o.numHeads
}
func (o textOptions) applyRotaryPositionalEmbedding(ctx ml.Context, t, p ml.Tensor) ml.Tensor {
return fast.RoPE(ctx, t, p, o.headDim(), o.ropeBase, 1/o.ropeScale, rope.WithTypeNeoX())
}
type textBlock struct {
AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
Attention *textAttention
MLPNNorm *nn.RMSNorm `gguf:"ffn_norm"`
FeedForward textFeedForward
}
func (m *textBlock) Forward(ctx ml.Context, hiddenStates, positions, outputs ml.Tensor, cache kvcache.Cache, opts textOptions) ml.Tensor {
residual := hiddenStates
hiddenStates = m.AttentionNorm.Forward(ctx, hiddenStates, opts.eps)
hiddenStates = m.Attention.Forward(ctx, hiddenStates, positions, cache, opts)
if outputs != nil {
hiddenStates = hiddenStates.Rows(ctx, outputs)
residual = residual.Rows(ctx, outputs)
}
hiddenStates = hiddenStates.Add(ctx, residual)
residual = hiddenStates
hiddenStates = m.MLPNNorm.Forward(ctx, hiddenStates, opts.eps)
hiddenStates = m.FeedForward.Forward(ctx, hiddenStates, opts)
return hiddenStates.Add(ctx, residual)
}
type textAttention struct {
Query *nn.Linear `gguf:"attn_q"`
Key *nn.Linear `gguf:"attn_k"`
Value *nn.Linear `gguf:"attn_v"`
Output *nn.Linear `gguf:"attn_output"`
}
func (m *textAttention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor, cache kvcache.Cache, opts textOptions) ml.Tensor {
query := m.Query.Forward(ctx, hiddenStates)
query = query.Reshape(ctx, opts.headDim(), opts.numHeads, -1)
key := m.Key.Forward(ctx, hiddenStates)
key = key.Reshape(ctx, opts.headDim(), opts.numKVHeads, -1)
value := m.Value.Forward(ctx, hiddenStates)
value = value.Reshape(ctx, opts.headDim(), opts.numKVHeads, -1)
query = opts.applyRotaryPositionalEmbedding(ctx, query, positions)
key = opts.applyRotaryPositionalEmbedding(ctx, key, positions)
attention := nn.Attention(ctx, query, key, value, 1./math.Sqrt(float64(opts.headDim())), cache)
attention = attention.Reshape(ctx, -1, attention.Dim(2))
return m.Output.Forward(ctx, attention)
}
type textFeedForward interface {
Forward(ml.Context, ml.Tensor, textOptions) ml.Tensor
}
type textMoe struct {
Router *nn.Linear `gguf:"ffn_gate_inp"`
Gate *nn.LinearBatch `gguf:"ffn_gate_exps"`
Up *nn.LinearBatch `gguf:"ffn_up_exps"`
Down *nn.LinearBatch `gguf:"ffn_down_exps"`
SharedExperts *textMLP `gguf:",suf:_shexp"`
}
func (m *textMoe) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts textOptions) ml.Tensor {
scores := m.Router.Forward(ctx, hiddenStates).Softmax(ctx)
indices := scores.TopK(ctx, opts.numExpertsUsed)
weights := scores.Reshape(ctx, 1, opts.numExperts, hiddenStates.Dim(1)).Rows(ctx, indices)
experts := hiddenStates.Reshape(ctx, hiddenStates.Dim(0), 1, hiddenStates.Dim(1))
experts = m.Gate.Forward(ctx, experts, indices).SILU(ctx, m.Up.Forward(ctx, experts, indices))
experts = m.Down.Forward(ctx, experts, indices)
experts = experts.Mul(ctx, weights)
expert := func(i int) ml.Tensor {
return experts.View(
ctx, i*experts.Stride(1), experts.Dim(0), experts.Stride(2), experts.Dim(2),
)
}
routedStates := expert(0)
for i := 1; i < opts.numExpertsUsed; i++ {
routedStates = routedStates.Add(ctx, expert(i))
}
sharedStates := m.SharedExperts.Forward(ctx, hiddenStates, opts)
return routedStates.Add(ctx, sharedStates)
}
type textMLP struct {
Gate *nn.Linear `gguf:"ffn_gate"`
Up *nn.Linear `gguf:"ffn_up"`
Down *nn.Linear `gguf:"ffn_down"`
}
func (m *textMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, _ textOptions) ml.Tensor {
hiddenStates = m.Gate.Forward(ctx, hiddenStates).SILU(ctx, m.Up.Forward(ctx, hiddenStates))
return m.Down.Forward(ctx, hiddenStates)
}

View File

@@ -1,117 +0,0 @@
package deepseekocr
import (
"math"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
)
type visionModel struct {
PatchEmbedding *nn.Conv2D `gguf:"patch_embd"`
ClassEmbedding ml.Tensor `gguf:"class_embd"`
PositionEmbedding *nn.Embedding `gguf:"position_embd"`
PreLayerNorm *nn.LayerNorm `gguf:"pre_layrnorm"`
Blocks []visionBlock `gguf:"blk"`
Options visionOptions
}
func (m *visionModel) absolutePositionEmbedding(ctx ml.Context, embeds ml.Tensor) ml.Tensor {
numPatches := m.Options.imageSize / m.Options.patchSize * m.Options.imageSize / m.Options.patchSize
positions := ctx.Arange(0, float32(numPatches+1), 1, ml.DTypeI32)
positionEmbeds := m.PositionEmbedding.Forward(ctx, positions)
source := int(math.Sqrt(float64(positionEmbeds.Dim(1) - 1)))
target := int(math.Sqrt(float64(embeds.Dim(1) - 1)))
if source != target {
newPositionEmbeds := positionEmbeds.Slice(ctx, 1, 1, positionEmbeds.Dim(1), 1)
newPositionEmbeds = newPositionEmbeds.Reshape(ctx, -1, source, source)
newPositionEmbeds = newPositionEmbeds.Permute(ctx, 2, 0, 1, 3).Contiguous(ctx)
newPositionEmbeds = newPositionEmbeds.Interpolate(ctx, [4]int{target, target, embeds.Dim(0), 1}, ml.SamplingModeBilinear)
newPositionEmbeds = newPositionEmbeds.Permute(ctx, 1, 2, 0, 3)
newPositionEmbeds = newPositionEmbeds.Contiguous(ctx, -1, target*target)
positionEmbeds = positionEmbeds.Slice(ctx, 1, 0, 1, 1).Concat(ctx, newPositionEmbeds, 1)
}
return positionEmbeds
}
func (m *visionModel) Forward(ctx ml.Context, pixelValues, patchEmbeds ml.Tensor) ml.Tensor {
if patchEmbeds == nil {
patchEmbeds = m.PatchEmbedding.Forward(ctx, pixelValues, m.Options.patchSize, m.Options.patchSize, 0, 0, 1, 1)
}
patchEmbeds = patchEmbeds.Reshape(ctx, -1, patchEmbeds.Dim(2), patchEmbeds.Dim(3))
patchEmbeds = patchEmbeds.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
classEmbeds := m.ClassEmbedding.Repeat(ctx, 2, patchEmbeds.Dim(2))
embeds := classEmbeds.Concat(ctx, patchEmbeds, 1)
embeds = embeds.Add(ctx, m.absolutePositionEmbedding(ctx, embeds))
hiddenStates := m.PreLayerNorm.Forward(ctx, embeds, m.Options.eps)
for _, block := range m.Blocks {
hiddenStates = block.Forward(ctx, hiddenStates, m.Options)
}
return hiddenStates
}
type visionOptions struct {
hiddenSize,
numHeads int
eps float32
imageSize, patchSize int
}
func (o visionOptions) headDim() int {
return o.hiddenSize / o.numHeads
}
type visionBlock struct {
Norm1 *nn.LayerNorm `gguf:"layer_norm1"`
Attention *visionAttention `gguf:"self_attn"`
Norm2 *nn.LayerNorm `gguf:"layer_norm2"`
FeedForward *visionMLP `gguf:"mlp"`
}
func (m *visionBlock) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts visionOptions) ml.Tensor {
residual := hiddenStates
hiddenStates = m.Norm1.Forward(ctx, hiddenStates, opts.eps)
hiddenStates = m.Attention.Forward(ctx, hiddenStates, opts)
hiddenStates = hiddenStates.Add(ctx, residual)
residual = hiddenStates
hiddenStates = m.Norm2.Forward(ctx, hiddenStates, opts.eps)
hiddenStates = m.FeedForward.Forward(ctx, hiddenStates)
hiddenStates = hiddenStates.Add(ctx, residual)
return hiddenStates
}
type visionAttention struct {
QKV *nn.Linear `gguf:"qkv_proj"`
Output *nn.Linear `gguf:"out_proj"`
}
func (m *visionAttention) Forward(ctx ml.Context, t ml.Tensor, opts visionOptions) ml.Tensor {
qkv := m.QKV.Forward(ctx, t)
qkv = qkv.Reshape(ctx, opts.headDim(), -1, qkv.Dim(1), qkv.Dim(2))
chunks := qkv.Chunk(ctx, 1, opts.numHeads)
query, key, value := chunks[0], chunks[1], chunks[2]
attention := nn.Attention(ctx, query, key, value, 1/math.Sqrt(float64(opts.headDim())), nil)
attention = attention.Reshape(ctx, -1, attention.Dim(2), attention.Dim(3))
return m.Output.Forward(ctx, attention)
}
type visionMLP struct {
FC1 *nn.Linear `gguf:"fc1"`
FC2 *nn.Linear `gguf:"fc2"`
}
func (m *visionMLP) Forward(ctx ml.Context, t ml.Tensor) ml.Tensor {
return m.FC2.Forward(ctx, m.FC1.Forward(ctx, t).QuickGELU(ctx))
}

View File

@@ -3,7 +3,6 @@ package models
import (
_ "github.com/ollama/ollama/model/models/bert"
_ "github.com/ollama/ollama/model/models/deepseek2"
_ "github.com/ollama/ollama/model/models/deepseekocr"
_ "github.com/ollama/ollama/model/models/gemma2"
_ "github.com/ollama/ollama/model/models/gemma3"
_ "github.com/ollama/ollama/model/models/gemma3n"
@@ -12,7 +11,6 @@ import (
_ "github.com/ollama/ollama/model/models/llama4"
_ "github.com/ollama/ollama/model/models/mistral3"
_ "github.com/ollama/ollama/model/models/mllama"
_ "github.com/ollama/ollama/model/models/nomicbert"
_ "github.com/ollama/ollama/model/models/qwen2"
_ "github.com/ollama/ollama/model/models/qwen25vl"
_ "github.com/ollama/ollama/model/models/qwen3"

View File

@@ -1,170 +0,0 @@
package nomicbert
import (
"cmp"
"math"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/ml/nn/fast"
"github.com/ollama/ollama/ml/nn/pooling"
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
)
type Model struct {
model.Base
model.TextProcessor
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
TypeEmbedding *nn.Embedding `gguf:"token_types"`
TokenEmbeddingNorm *nn.LayerNorm `gguf:"token_embd_norm"`
Layers []EncoderLayer `gguf:"blk"`
Options
}
type Options struct {
hiddenSize int
numHeads int
headDim int
eps float32
poolingType pooling.Type
normalize bool
ropeFreqBase float32
}
// Single Encoder Layer
type EncoderLayer struct {
*Attention
AttentionNorm *nn.LayerNorm `gguf:"attn_output_norm"`
*MLP
MLPNorm *nn.LayerNorm `gguf:"layer_output_norm"`
}
type Attention struct {
QKV *nn.Linear `gguf:"attn_qkv"`
Output *nn.Linear `gguf:"attn_output"`
}
type MLP struct {
Gate *nn.Linear `gguf:"ffn_gate"`
Up *nn.Linear `gguf:"ffn_up"`
Down *nn.Linear `gguf:"ffn_down"`
}
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
hiddenStates := m.TokenEmbedding.Forward(ctx, batch.Inputs)
typeEmbed := m.TypeEmbedding.Weight.Slice(ctx, 1, 0, 1, 1)
hiddenStates = hiddenStates.Add(ctx, typeEmbed)
hiddenStates = m.TokenEmbeddingNorm.Forward(ctx, hiddenStates, m.eps)
positions := ctx.Input().FromInts(batch.Positions, len(batch.Positions))
for _, layer := range m.Layers {
hiddenStates = layer.Forward(ctx, hiddenStates, positions, &m.Options)
}
hiddenStates = m.poolingType.Forward(ctx, hiddenStates)
if m.normalize {
hiddenStates = hiddenStates.L2Norm(ctx, 1e-12)
}
return hiddenStates, nil
}
func (e *EncoderLayer) Forward(ctx ml.Context, hiddenStates ml.Tensor, positions ml.Tensor, opts *Options) ml.Tensor {
residual := hiddenStates
hiddenStates = e.Attention.Forward(ctx, hiddenStates, positions, opts)
hiddenStates = hiddenStates.Add(ctx, residual)
hiddenStates = e.AttentionNorm.Forward(ctx, hiddenStates, opts.eps)
residual = hiddenStates
hiddenStates = e.MLP.Forward(ctx, hiddenStates)
hiddenStates = hiddenStates.Add(ctx, residual)
hiddenStates = e.MLPNorm.Forward(ctx, hiddenStates, opts.eps)
return hiddenStates
}
func (a *Attention) Forward(ctx ml.Context, hiddenStates ml.Tensor, positions ml.Tensor, opts *Options) ml.Tensor {
batchSize := hiddenStates.Dim(1)
qkv := a.QKV.Forward(ctx, hiddenStates)
qkv = qkv.Reshape(ctx, opts.headDim, opts.numHeads*3, batchSize)
chunks := qkv.Chunk(ctx, 1, opts.numHeads)
query, key, value := chunks[0], chunks[1], chunks[2]
query = fast.RoPE(ctx, query, positions, opts.headDim, opts.ropeFreqBase, 1.0, rope.WithTypeNeoX())
key = fast.RoPE(ctx, key, positions, opts.headDim, opts.ropeFreqBase, 1.0, rope.WithTypeNeoX())
attention := nn.Attention(ctx, query, key, value, 1.0/math.Sqrt(float64(opts.headDim)), nil)
attention = attention.Reshape(ctx, opts.hiddenSize, batchSize)
return a.Output.Forward(ctx, attention)
}
func (m *MLP) Forward(ctx ml.Context, hiddenStates ml.Tensor) ml.Tensor {
hidden := m.Gate.Forward(ctx, hiddenStates).SILU(ctx, m.Up.Forward(ctx, hiddenStates))
return m.Down.Forward(ctx, hidden)
}
func New(c fs.Config) (model.Model, error) {
hiddenSize := int(c.Uint("embedding_length"))
numHeads := int(c.Uint("attention.head_count"))
headDim := hiddenSize / numHeads
processor := model.NewWordPiece(
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
BOS: []int32{
int32(cmp.Or(
c.Uint("tokenizer.ggml.cls_token_id"),
c.Uint("tokenizer.ggml.bos_token_id"),
)),
},
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", true),
EOS: []int32{
int32(cmp.Or(
c.Uint("tokenizer.ggml.separator_token_id"),
c.Uint("tokenizer.ggml.eos_token_id"),
)),
},
},
false,
)
return &Model{
TextProcessor: processor,
Layers: make([]EncoderLayer, c.Uint("block_count")),
Options: Options{
hiddenSize: hiddenSize,
numHeads: numHeads,
headDim: headDim,
eps: c.Float("attention.layer_norm_epsilon"),
poolingType: pooling.Type(c.Uint("pooling_type")),
normalize: c.Bool("normalize_embeddings", false),
ropeFreqBase: c.Float("rope.freq_base", 1000.0),
},
}, nil
}
func init() {
model.Register("nomic-bert", New)
model.Register("nomic-bert_embed", New)
}

View File

@@ -1,319 +0,0 @@
package parsers
import (
"encoding/json"
"errors"
"log/slog"
"strings"
"unicode"
"github.com/ollama/ollama/api"
)
type CogitoParserState int
const (
CogitoCollectingThinking CogitoParserState = iota
CogitoCollectingContent
CogitoCollectingToolCalls
CogitoCollectingToolOutput
)
const (
cogitoThinkingCloseTag = "</think>"
cogitoToolCallsBeginTag = "<tool▁calls▁begin>"
cogitoToolCallsEndTag = "<tool▁calls▁end>"
cogitoToolCallBeginTag = "<tool▁call▁begin>"
cogitoToolCallEndTag = "<tool▁call▁end>"
cogitoToolSepTag = "<tool▁sep>"
cogitoToolOutputBeginTag = "<tool▁output▁begin>"
cogitoToolOutputEndTag = "<tool▁output▁end>"
cogitoToolOutputsBeginTag = "<tool▁outputs▁begin>"
cogitoToolOutputsEndTag = "<tool▁outputs▁end>"
)
type CogitoParser struct {
state CogitoParserState
buffer strings.Builder
}
func (p *CogitoParser) HasToolSupport() bool {
return true
}
func (p *CogitoParser) HasThinkingSupport() bool {
return true
}
func (p *CogitoParser) setInitialState(lastMessage *api.Message, tools []api.Tool, thinkValue *api.ThinkValue) {
prefill := lastMessage != nil && lastMessage.Role == "assistant"
// Check both model capability AND request preference
thinkingEnabled := thinkValue != nil && thinkValue.Bool()
// thinkingEnabled should be set to false for tools
if !thinkingEnabled {
p.state = CogitoCollectingContent
return
}
if prefill && lastMessage.Content != "" {
p.state = CogitoCollectingContent
return
}
// Note: for cogito, if there are tools, then we don't want to be thinking
if len(tools) > 0 {
p.state = CogitoCollectingContent
return
}
p.state = CogitoCollectingThinking
}
func (p *CogitoParser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
p.setInitialState(lastMessage, tools, thinkValue)
return tools
}
type cogitoEvent interface {
isCogitoEvent()
}
type cogitoEventThinkingContent struct {
content string
}
type cogitoEventContent struct {
content string
}
type cogitoEventToolCall struct {
toolCall api.ToolCall
}
func (cogitoEventThinkingContent) isCogitoEvent() {}
func (cogitoEventContent) isCogitoEvent() {}
func (cogitoEventToolCall) isCogitoEvent() {}
func (p *CogitoParser) Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error) {
p.buffer.WriteString(s)
events := p.parseEvents()
var toolCalls []api.ToolCall
var contentSb strings.Builder
var thinkingSb strings.Builder
for _, event := range events {
switch event := event.(type) {
case cogitoEventToolCall:
toolCalls = append(toolCalls, event.toolCall)
case cogitoEventThinkingContent:
thinkingSb.WriteString(event.content)
case cogitoEventContent:
contentSb.WriteString(event.content)
}
}
return contentSb.String(), thinkingSb.String(), toolCalls, nil
}
func (p *CogitoParser) parseEvents() []cogitoEvent {
var all []cogitoEvent
keepLooping := true
for keepLooping {
var events []cogitoEvent
events, keepLooping = p.eat()
if len(events) > 0 {
all = append(all, events...)
}
}
return all
}
func (p *CogitoParser) eat() ([]cogitoEvent, bool) {
var events []cogitoEvent
bufStr := p.buffer.String()
if bufStr == "" {
return events, false
}
switch p.state {
case CogitoCollectingThinking:
if strings.Contains(bufStr, cogitoThinkingCloseTag) { // thinking[</think>] -> content
split := strings.SplitN(bufStr, cogitoThinkingCloseTag, 2)
thinking := split[0]
thinking = strings.TrimRightFunc(thinking, unicode.IsSpace)
remaining := split[1]
remaining = strings.TrimLeftFunc(remaining, unicode.IsSpace)
p.buffer.Reset()
p.buffer.WriteString(remaining)
p.state = CogitoCollectingContent
if len(thinking) > 0 {
events = append(events, cogitoEventThinkingContent{content: thinking})
}
return events, true
} else if overlapLen := overlap(bufStr, cogitoThinkingCloseTag); overlapLen > 0 { // partial </think>
beforePartialTag := bufStr[:len(bufStr)-overlapLen]
trailingLen := trailingWhitespaceLen(beforePartialTag)
ambiguousStart := len(beforePartialTag) - trailingLen
unambiguous := bufStr[:ambiguousStart]
ambiguous := bufStr[ambiguousStart:]
p.buffer.Reset()
p.buffer.WriteString(ambiguous)
if len(unambiguous) > 0 {
events = append(events, cogitoEventThinkingContent{content: unambiguous})
}
return events, false
} else { // otherwise its thinking content
whitespaceLen := trailingWhitespaceLen(bufStr)
ambiguousStart := len(bufStr) - whitespaceLen
unambiguous := bufStr[:ambiguousStart]
ambiguous := bufStr[ambiguousStart:]
p.buffer.Reset()
p.buffer.WriteString(ambiguous)
if len(unambiguous) > 0 {
events = append(events, cogitoEventThinkingContent{content: unambiguous})
}
return events, false
}
case CogitoCollectingContent:
switch {
case strings.Contains(bufStr, cogitoToolCallsBeginTag): // content[<tool▁calls▁begin>] -> tool calls
split := strings.SplitN(bufStr, cogitoToolCallsBeginTag, 2)
contentBefore := strings.TrimRightFunc(split[0], unicode.IsSpace)
remaining := split[1]
p.buffer.Reset()
p.buffer.WriteString(remaining)
p.state = CogitoCollectingToolCalls
if len(contentBefore) > 0 {
events = append(events, cogitoEventContent{content: contentBefore})
}
return events, true
case strings.Contains(bufStr, cogitoToolOutputsBeginTag): // content[<tool▁outputs▁begin>] -> tool outputs
split := strings.SplitN(bufStr, cogitoToolOutputsBeginTag, 2)
contentBefore := strings.TrimRightFunc(split[0], unicode.IsSpace)
remaining := split[1]
p.buffer.Reset()
p.buffer.WriteString(remaining)
p.state = CogitoCollectingToolOutput
if len(contentBefore) > 0 {
events = append(events, cogitoEventContent{content: contentBefore})
}
return events, true
default: // otherwise its content
p.buffer.Reset()
if len(bufStr) > 0 {
events = append(events, cogitoEventContent{content: bufStr})
}
return events, false
}
case CogitoCollectingToolCalls:
if idx := strings.Index(bufStr, cogitoToolCallBeginTag); idx != -1 {
startIdx := idx + len(cogitoToolCallBeginTag)
if endIdx := strings.Index(bufStr[startIdx:], cogitoToolCallEndTag); endIdx != -1 {
toolCallContent := bufStr[startIdx : startIdx+endIdx]
if toolCall, err := p.parseToolCallContent(toolCallContent); err == nil {
remaining := bufStr[startIdx+endIdx+len(cogitoToolCallEndTag):]
remaining = strings.TrimLeftFunc(remaining, unicode.IsSpace)
p.buffer.Reset()
p.buffer.WriteString(remaining)
events = append(events, cogitoEventToolCall{toolCall: toolCall})
return events, true
} else {
slog.Warn("cogito tool call parsing failed", "error", err)
}
}
}
if idx := strings.Index(bufStr, cogitoToolCallsEndTag); idx != -1 {
remaining := bufStr[idx+len(cogitoToolCallsEndTag):]
remaining = strings.TrimLeftFunc(remaining, unicode.IsSpace)
p.buffer.Reset()
p.buffer.WriteString(remaining)
p.state = CogitoCollectingContent
return events, true
}
return events, false
case CogitoCollectingToolOutput:
if idx := strings.Index(bufStr, cogitoToolOutputBeginTag); idx != -1 {
startIdx := idx + len(cogitoToolOutputBeginTag)
if endIdx := strings.Index(bufStr[startIdx:], cogitoToolOutputEndTag); endIdx != -1 {
remaining := bufStr[startIdx+endIdx+len(cogitoToolOutputEndTag):]
remaining = strings.TrimLeftFunc(remaining, unicode.IsSpace)
p.buffer.Reset()
p.buffer.WriteString(remaining)
return events, true
}
}
if idx := strings.Index(bufStr, cogitoToolOutputsEndTag); idx != -1 {
remaining := bufStr[idx+len(cogitoToolOutputsEndTag):]
remaining = strings.TrimLeftFunc(remaining, unicode.IsSpace)
p.buffer.Reset()
p.buffer.WriteString(remaining)
p.state = CogitoCollectingContent
return events, true
}
return events, false
}
return events, false
}
func (p *CogitoParser) parseToolCallContent(content string) (api.ToolCall, error) {
// Expected format: function<tool▁sep>tool_name\n```json\n{args}\n```
parts := strings.SplitN(content, cogitoToolSepTag, 2)
if len(parts) < 2 {
return api.ToolCall{}, errors.New("invalid format")
}
nameAndArgs := parts[1]
jsonStart := strings.Index(nameAndArgs, "\n```json\n")
if jsonStart == -1 {
return api.ToolCall{}, errors.New("invalid format")
}
toolName := strings.TrimSpace(nameAndArgs[:jsonStart])
jsonContent := nameAndArgs[jsonStart+len("\n```json\n"):]
jsonEnd := strings.Index(jsonContent, "\n```")
if jsonEnd == -1 {
return api.ToolCall{}, errors.New("invalid format")
}
argsJSON := jsonContent[:jsonEnd]
var args api.ToolCallFunctionArguments
if err := json.Unmarshal([]byte(argsJSON), &args); err != nil {
return api.ToolCall{}, err
}
return api.ToolCall{
Function: api.ToolCallFunction{
Name: toolName,
Arguments: args,
},
}, nil
}

View File

@@ -1,565 +0,0 @@
package parsers
import (
"strings"
"testing"
"github.com/google/go-cmp/cmp"
"github.com/ollama/ollama/api"
)
func TestCogitoParser(t *testing.T) {
tests := []struct {
name string
input string
expectedContent string
expectedThinking string
expectedToolCalls []api.ToolCall
tools []api.Tool
lastMessage *api.Message
}{
{
name: "simple_content",
input: "This is a simple response.",
expectedContent: "This is a simple response.",
expectedThinking: "",
},
{
name: "thinking_only",
input: "This is thinking content.</think>This is response content.",
expectedContent: "This is response content.",
expectedThinking: "This is thinking content.",
},
{
name: "tool_call_simple",
input: `<tool▁calls▁begin><tool▁call▁begin>function<tool▁sep>get_weather
` + "```json\n" + `{"location":"Paris"}
` + "```" + `<tool▁call▁end><tool▁calls▁end>`,
expectedToolCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: api.ToolCallFunctionArguments{
"location": "Paris",
},
},
},
},
tools: []api.Tool{
{
Type: "function",
Function: api.ToolFunction{
Name: "get_weather",
Parameters: api.ToolFunctionParameters{
Properties: map[string]api.ToolProperty{
"location": {Type: api.PropertyType{"string"}},
},
},
},
},
},
},
{
name: "thinking_with_tool_call",
input: `I need to check the weather.</think><tool▁calls▁begin><tool▁call▁begin>function<tool▁sep>get_weather
` + "```json\n" + `{"location":"Paris"}
` + "```" + `<tool▁call▁end><tool▁calls▁end>`,
expectedContent: "I need to check the weather.</think>",
expectedThinking: "", // No thinking when tools are present (Cogito-specific behavior)
expectedToolCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: api.ToolCallFunctionArguments{
"location": "Paris",
},
},
},
},
tools: []api.Tool{
{
Type: "function",
Function: api.ToolFunction{
Name: "get_weather",
Parameters: api.ToolFunctionParameters{
Properties: map[string]api.ToolProperty{
"location": {Type: api.PropertyType{"string"}},
},
},
},
},
},
},
{
name: "multiple_tool_calls",
input: `<tool▁calls▁begin><tool▁call▁begin>function<tool▁sep>get_weather
` + "```json\n" + `{"location":"Paris"}
` + "```" + `<tool▁call▁end>
<tool▁call▁begin>function<tool▁sep>get_weather
` + "```json\n" + `{"location":"London"}
` + "```" + `<tool▁call▁end><tool▁calls▁end>`,
expectedToolCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: api.ToolCallFunctionArguments{
"location": "Paris",
},
},
},
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: api.ToolCallFunctionArguments{
"location": "London",
},
},
},
},
tools: []api.Tool{
{
Type: "function",
Function: api.ToolFunction{
Name: "get_weather",
Parameters: api.ToolFunctionParameters{
Properties: map[string]api.ToolProperty{
"location": {Type: api.PropertyType{"string"}},
},
},
},
},
},
},
{
name: "complex_tool_arguments",
input: `<tool▁calls▁begin><tool▁call▁begin>function<tool▁sep>process_data
` + "```json\n" + `{"items":["item1","item2"],"config":{"enabled":true,"threshold":0.95},"count":42}
` + "```" + `<tool▁call▁end><tool▁calls▁end>`,
expectedToolCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "process_data",
Arguments: api.ToolCallFunctionArguments{
"items": []any{"item1", "item2"},
"config": map[string]any{"enabled": true, "threshold": 0.95},
"count": 42.0,
},
},
},
},
},
{
name: "tool_output_parsing",
input: `<tool▁outputs▁begin><tool▁output▁begin>{"temperature": 22, "condition": "sunny"}<tool▁output▁end><tool▁outputs▁end>`,
expectedContent: "",
expectedThinking: "",
},
{
name: "thinking_with_multiline_content",
input: `This is line 1
This is line 2
This is line 3</think>Final response here.`,
expectedContent: "Final response here.",
expectedThinking: "This is line 1\nThis is line 2\nThis is line 3",
},
{
name: "no_thinking_simple",
input: "This is content.",
expectedContent: "This is content.",
expectedThinking: "",
},
{
name: "prefill_content_only",
input: "Continuing from previous content.",
expectedContent: "Continuing from previous content.",
lastMessage: &api.Message{
Role: "assistant",
Content: "Previous content",
},
},
{
name: "prefill_with_thinking",
input: "Continuing thinking</think>Continuing content.",
expectedContent: "Continuing content.",
expectedThinking: "Continuing thinking",
lastMessage: &api.Message{
Role: "assistant",
},
},
// Edge cases
{
name: "nested_think_tags_in_thinking",
input: "I'm thinking <think>nested</think> more thinking</think>Final content.",
expectedContent: "more thinking</think>Final content.",
expectedThinking: "I'm thinking <think>nested",
},
{
name: "multiple_think_close_tags",
input: "First thinking</think>Content</think>More content.",
expectedContent: "Content</think>More content.",
expectedThinking: "First thinking",
},
{
name: "empty_thinking_content",
input: "</think>Just content here.",
expectedContent: "</think>Just content here.",
expectedThinking: "",
},
{
name: "thinking_disabled_with_think_tags",
input: "Content with </think> tags should be treated as content.",
expectedContent: "Content with </think> tags should be treated as content.",
expectedThinking: "",
lastMessage: &api.Message{
Role: "assistant",
Content: "existing", // Forces non-thinking mode
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Use thinking-enabled parser for tests that expect thinking
hasThinking := tt.expectedThinking != ""
parser := &CogitoParser{} // it has thinking support
parser.Init(tt.tools, tt.lastMessage, &api.ThinkValue{Value: hasThinking}) // but we should set it with the request that the user wants
content, thinking, toolCalls, err := parser.Add(tt.input, true)
if err != nil {
t.Fatalf("Add() error = %v", err)
}
if diff := cmp.Diff(tt.expectedContent, content); diff != "" {
t.Errorf("content mismatch (-want +got):\n%s", diff)
}
if diff := cmp.Diff(tt.expectedThinking, thinking); diff != "" {
t.Errorf("thinking mismatch (-want +got):\n%s", diff)
}
if diff := cmp.Diff(tt.expectedToolCalls, toolCalls); diff != "" {
t.Errorf("tool calls mismatch (-want +got):\n%s", diff)
}
})
}
}
func TestCogitoParser_Streaming(t *testing.T) {
parser := &CogitoParser{}
parser.Init(nil, nil, &api.ThinkValue{Value: true})
chunks := []string{
"This is ",
"thinking content",
".</think>This is ",
"content.<tool▁calls▁begin><tool▁call▁begin>function<tool▁sep>test_tool\n```json\n{\"arg\":\"value\"}\n```<tool▁call▁end><tool▁calls▁end>",
}
var finalContent, finalThinking strings.Builder
var finalToolCalls []api.ToolCall
for i, chunk := range chunks {
done := i == len(chunks)-1
content, thinking, toolCalls, err := parser.Add(chunk, done)
if err != nil {
t.Fatalf("Add() error on chunk %d: %v", i, err)
}
finalContent.WriteString(content)
finalThinking.WriteString(thinking)
finalToolCalls = append(finalToolCalls, toolCalls...)
}
expectedContent := "This is content."
expectedThinking := "This is thinking content."
expectedToolCalls := []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "test_tool",
Arguments: api.ToolCallFunctionArguments{
"arg": "value",
},
},
},
}
if finalContent.String() != expectedContent {
t.Errorf("expected content %q, got %q", expectedContent, finalContent.String())
}
if finalThinking.String() != expectedThinking {
t.Errorf("expected thinking %q, got %q", expectedThinking, finalThinking.String())
}
if diff := cmp.Diff(expectedToolCalls, finalToolCalls); diff != "" {
t.Errorf("tool calls mismatch (-want +got):\n%s", diff)
}
}
func TestCogitoParser_StreamingEdgeCases(t *testing.T) {
tests := []struct {
name string
chunks []string
expectedContent string
expectedThinking string
expectedToolCalls []api.ToolCall
hasThinkingSupport bool
}{
{
name: "split_thinking_tag",
chunks: []string{
"This is thinking content</thi",
"nk>This is content.",
},
expectedContent: "This is content.",
expectedThinking: "This is thinking content",
hasThinkingSupport: true,
},
{
name: "split_tool_calls_begin_tag_conservative_parsing",
chunks: []string{
"Content before<tool▁calls▁beg",
"in><tool▁call▁begin>function<tool▁sep>test\n```json\n{}\n```<tool▁call▁end><tool▁calls▁end>",
},
// Parser is conservative - treats incomplete tags as content
expectedContent: "Content before<tool▁calls▁begin><tool▁call▁begin>function<tool▁sep>test\n```json\n{}\n```<tool▁call▁end><tool▁calls▁end>",
expectedToolCalls: nil,
hasThinkingSupport: false,
},
{
name: "thinking_disabled_with_split_tags",
chunks: []string{
"Content with </thi",
"nk> should be treated as content.",
},
expectedContent: "Content with </think> should be treated as content.",
expectedThinking: "",
hasThinkingSupport: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
parser := &CogitoParser{}
parser.Init(nil, nil, &api.ThinkValue{Value: tt.hasThinkingSupport})
var finalContent, finalThinking strings.Builder
var finalToolCalls []api.ToolCall
for i, chunk := range tt.chunks {
done := i == len(tt.chunks)-1
content, thinking, toolCalls, err := parser.Add(chunk, done)
if err != nil {
t.Fatalf("Add() error on chunk %d: %v", i, err)
}
finalContent.WriteString(content)
finalThinking.WriteString(thinking)
finalToolCalls = append(finalToolCalls, toolCalls...)
}
if finalContent.String() != tt.expectedContent {
t.Errorf("expected content %q, got %q", tt.expectedContent, finalContent.String())
}
if finalThinking.String() != tt.expectedThinking {
t.Errorf("expected thinking %q, got %q", tt.expectedThinking, finalThinking.String())
}
if diff := cmp.Diff(tt.expectedToolCalls, finalToolCalls); diff != "" {
t.Errorf("tool calls mismatch (-want +got):\n%s", diff)
}
})
}
}
func TestCogitoParser_HasToolSupport(t *testing.T) {
parser := &CogitoParser{}
if !parser.HasToolSupport() {
t.Error("CogitoParser should support tools")
}
}
func TestCogitoParser_Init(t *testing.T) {
parser := &CogitoParser{}
tools := []api.Tool{
{Function: api.ToolFunction{Name: "test_tool"}},
}
lastMessage := &api.Message{Role: "assistant", Content: "previous"}
returnedTools := parser.Init(tools, lastMessage, nil)
if len(returnedTools) != len(tools) {
t.Errorf("expected %d tools returned, got %d", len(tools), len(returnedTools))
}
}
func TestCogitoParser_parseToolCallContent(t *testing.T) {
tests := []struct {
name string
content string
expected api.ToolCall
expectError bool
}{
{
name: "valid_tool_call_standard_format",
content: `function<tool▁sep>get_weather
` + "```json\n" + `{"location":"Paris"}
` + "```",
expected: api.ToolCall{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: api.ToolCallFunctionArguments{
"location": "Paris",
},
},
},
expectError: false,
},
{
name: "valid_tool_call_complex_args",
content: `function<tool▁sep>process_data
` + "```json\n" + `{"items":["item1","item2"],"config":{"enabled":true},"count":42}
` + "```",
expected: api.ToolCall{
Function: api.ToolCallFunction{
Name: "process_data",
Arguments: api.ToolCallFunctionArguments{
"items": []any{"item1", "item2"},
"config": map[string]any{"enabled": true},
"count": 42.0,
},
},
},
expectError: false,
},
{
name: "valid_tool_call_empty_args",
content: `function<tool▁sep>no_args_tool
` + "```json\n" + `{}
` + "```",
expected: api.ToolCall{
Function: api.ToolCallFunction{
Name: "no_args_tool",
Arguments: api.ToolCallFunctionArguments{},
},
},
expectError: false,
},
{
name: "missing_separator",
content: `functionget_weather` + "```json\n" + `{"location":"Paris"}` + "\n```",
expected: api.ToolCall{},
expectError: true,
},
{
name: "invalid_function_type",
content: `not_function<tool▁sep>get_weather` + "```json\n" + `{"location":"Paris"}` + "\n```",
expected: api.ToolCall{},
expectError: true,
},
{
name: "missing_json_block_start",
content: `function<tool▁sep>get_weather{"location":"Paris"}` + "```",
expected: api.ToolCall{},
expectError: true,
},
{
name: "missing_json_block_end",
content: `function<tool▁sep>get_weather` + "```json\n" + `{"location":"Paris"}`,
expected: api.ToolCall{},
expectError: true,
},
{
name: "invalid_json",
content: `function<tool▁sep>get_weather` + "```json\n" + `{location:Paris}` + "\n```",
expected: api.ToolCall{},
expectError: true,
},
{
name: "empty_function_type",
content: `<tool▁sep>get_weather` + "```json\n" + `{"location":"Paris"}` + "\n```",
expected: api.ToolCall{},
expectError: true,
},
{
name: "tool_with_spaces_in_name",
content: `function<tool▁sep> get_weather
` + "```json\n" + `{"location":"Paris"}
` + "```",
expected: api.ToolCall{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: api.ToolCallFunctionArguments{
"location": "Paris",
},
},
},
expectError: false,
},
{
name: "tool_with_multiline_json",
content: `function<tool▁sep>get_weather
` + "```json\n" + `{
"location": "Paris",
"units": "metric"
}
` + "```",
expected: api.ToolCall{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: api.ToolCallFunctionArguments{
"location": "Paris",
"units": "metric",
},
},
},
expectError: false,
},
{
name: "tool_with_nested_objects",
content: `function<tool▁sep>complex_tool
` + "```json\n" + `{"nested":{"deep":{"value":123}}}
` + "```",
expected: api.ToolCall{
Function: api.ToolCallFunction{
Name: "complex_tool",
Arguments: api.ToolCallFunctionArguments{
"nested": map[string]any{
"deep": map[string]any{
"value": 123.0,
},
},
},
},
},
expectError: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
parser := &CogitoParser{}
result, err := parser.parseToolCallContent(tt.content)
if tt.expectError {
if err == nil {
t.Errorf("expected error but got none")
}
return
}
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if diff := cmp.Diff(tt.expected, result); diff != "" {
t.Errorf("tool call mismatch (-want +got):\n%s", diff)
}
})
}
}

View File

@@ -6,9 +6,9 @@ import (
)
type Parser interface {
// Init initializes the parser with tools, optional last message for chat prefill, and think value
// Init initializes the parser with tools and optional last message for chat prefill
// Returns processed tools if the parser needs to modify them (e.g., harmony renames them)
Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool
Init(tools []api.Tool, lastMessage *api.Message) []api.Tool
// Add processes streamed content and returns parsed content, thinking, and tool calls
// The done flag indicates if this is the last chunk (used for draining accumulators)
Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error)
@@ -52,8 +52,6 @@ func ParserForName(name string) Parser {
return &PassthroughParser{}
case "harmony":
return harmony.NewHarmonyMessageHandler()
case "cogito":
return &CogitoParser{}
default:
return nil
}
@@ -61,7 +59,7 @@ func ParserForName(name string) Parser {
type PassthroughParser struct{}
func (p *PassthroughParser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
func (p *PassthroughParser) Init(tools []api.Tool, lastMessage *api.Message) []api.Tool {
return tools // passthrough doesn't modify tools
}

View File

@@ -10,7 +10,7 @@ type mockParser struct {
name string
}
func (m *mockParser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
func (m *mockParser) Init(tools []api.Tool, lastMessage *api.Message) []api.Tool {
return tools
}

View File

@@ -43,7 +43,7 @@ func (p *Qwen3CoderParser) HasThinkingSupport() bool {
return false
}
func (p *Qwen3CoderParser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
func (p *Qwen3CoderParser) Init(tools []api.Tool, lastMessage *api.Message) []api.Tool {
p.tools = tools
return tools // Qwen doesn't modify tools
}
@@ -432,7 +432,7 @@ func transformToXML(raw string) string {
groups := qwenTagRegex.FindStringSubmatch(match)
tag := groups[1]
var escapedValue strings.Builder
_ = xml.EscapeText(&escapedValue, []byte(groups[2])) // error is always nil for strings.Builder
xml.EscapeText(&escapedValue, []byte(groups[2]))
return fmt.Sprintf(`<%s name="%s">`, tag, escapedValue.String())
})

View File

@@ -54,7 +54,7 @@ func (p *Qwen3VLParser) setInitialState(lastMessage *api.Message) {
p.state = CollectingThinkingContent
}
func (p *Qwen3VLParser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
func (p *Qwen3VLParser) Init(tools []api.Tool, lastMessage *api.Message) []api.Tool {
p.tools = tools
p.setInitialState(lastMessage)
return tools

View File

@@ -198,7 +198,7 @@ func TestQwen3VLNonThinkingParserStreaming(t *testing.T) {
t.Run(tc.desc, func(t *testing.T) {
parser := Qwen3VLParser{hasThinkingSupport: false}
parser.Init([]api.Tool{}, nil, nil)
parser.Init([]api.Tool{}, nil)
for i, step := range tc.steps {
parser.buffer.WriteString(step.input)
@@ -515,7 +515,7 @@ func TestQwenOldParserStreaming(t *testing.T) {
t.Run(tc.desc, func(t *testing.T) {
parser := Qwen3VLParser{hasThinkingSupport: false}
parser.Init([]api.Tool{}, nil, nil)
parser.Init([]api.Tool{}, nil)
for i, step := range tc.steps {
parser.buffer.WriteString(step.input)
@@ -822,7 +822,7 @@ func TestQwen3VLNonThinkingToolCallWhitespaceHandling(t *testing.T) {
t.Run(tc.desc, func(t *testing.T) {
parser := Qwen3VLParser{hasThinkingSupport: false}
parser.Init([]api.Tool{}, nil, nil)
parser.Init([]api.Tool{}, nil)
for i, step := range tc.steps {
parser.buffer.WriteString(step.input)

View File

@@ -205,7 +205,7 @@ func TestQwen3VLThinkingParserStreaming(t *testing.T) {
t.Run(tc.desc, func(t *testing.T) {
parser := Qwen3VLParser{hasThinkingSupport: true}
parser.Init([]api.Tool{}, nil, nil)
parser.Init([]api.Tool{}, nil)
// parser.state = CollectingThinkingContent
for i, step := range tc.steps {
@@ -386,7 +386,7 @@ func TestQwen3VLParserState(t *testing.T) {
for _, tc := range cases {
parser := Qwen3VLParser{hasThinkingSupport: tc.hasThinking}
parser.Init(nil, tc.last, nil)
parser.Init(nil, tc.last)
if parser.state != tc.wantState {
t.Errorf("%s: got state %v, want %v", tc.desc, parser.state, tc.wantState)
}
@@ -437,7 +437,7 @@ func TestQwen3VLThinkingParserWithThinkingPrefill(t *testing.T) {
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
parser := Qwen3VLParser{hasThinkingSupport: true}
parser.Init([]api.Tool{}, last, nil)
parser.Init([]api.Tool{}, last)
for i, step := range tc.steps {
parser.buffer.WriteString(step.input)
@@ -500,7 +500,7 @@ func TestQwen3VLThinkingParserWithNonThinkingPrefill(t *testing.T) {
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
parser := Qwen3VLParser{hasThinkingSupport: true}
parser.Init([]api.Tool{}, last, nil)
parser.Init([]api.Tool{}, last)
for i, step := range tc.steps {
parser.buffer.WriteString(step.input)
@@ -523,7 +523,7 @@ func TestQwen3VLThinkingParserStreamingAssistantPrefillContent(t *testing.T) {
// last message is assistant with content ⇒ start in CollectingContent
last := &api.Message{Role: "assistant", Content: "has content"}
parser := Qwen3VLParser{hasThinkingSupport: true}
parser.Init([]api.Tool{}, last, nil)
parser.Init([]api.Tool{}, last)
type step struct {
input string
@@ -750,7 +750,7 @@ func TestQwen3VLThinkingWhitespaceHandling(t *testing.T) {
t.Run(tc.desc, func(t *testing.T) {
parser := Qwen3VLParser{hasThinkingSupport: true}
parser.Init([]api.Tool{}, nil, nil)
parser.Init([]api.Tool{}, nil)
for i, step := range tc.steps {
parser.buffer.WriteString(step.input)
@@ -859,7 +859,7 @@ func TestQwen3VLToolCallWhitespaceHandling(t *testing.T) {
t.Run(tc.desc, func(t *testing.T) {
parser := Qwen3VLParser{hasThinkingSupport: true}
parser.Init([]api.Tool{}, tc.prefillMsg, nil)
parser.Init([]api.Tool{}, tc.prefillMsg)
for i, step := range tc.steps {
parser.buffer.WriteString(step.input)

View File

@@ -1,129 +0,0 @@
package renderers
import (
"encoding/json"
"strings"
"github.com/ollama/ollama/api"
)
type CogitoRenderer struct {
isThinking bool
}
func (r *CogitoRenderer) Render(messages []api.Message, tools []api.Tool, thinkValue *api.ThinkValue) (string, error) {
var sb strings.Builder
defaultPrompt := "You are Cogito, an AI assistant created by Deep Cogito, which is an AI research lab based in San Francisco."
// thinking is enabled: model must support it AND user must request it (true)
enableThinking := r.isThinking && (thinkValue != nil && thinkValue.Bool())
var systemPrompt string
var conversationMessages []api.Message
if len(messages) > 0 && messages[0].Role == "system" {
systemPrompt = messages[0].Content
conversationMessages = messages[1:]
} else {
conversationMessages = messages
}
var finalSystemPrompt string
if enableThinking {
finalSystemPrompt = "Enable deep thinking subroutine.\n\n" + defaultPrompt
if systemPrompt != "" {
finalSystemPrompt += "\n\n" + systemPrompt + "\n\n"
}
} else {
finalSystemPrompt = defaultPrompt
if systemPrompt != "" {
finalSystemPrompt += "\n\n" + systemPrompt
}
}
if len(tools) > 0 {
if finalSystemPrompt != "" {
finalSystemPrompt += "\nYou have the following functions available:\n"
} else {
finalSystemPrompt = "You have the following functions available:\n"
}
for _, tool := range tools {
toolJSON, _ := json.MarshalIndent(tool, "", " ") // TODO(gguo): double check json format
finalSystemPrompt += "```json\n" + string(toolJSON) + "\n```\n"
}
}
sb.WriteString("<begin▁of▁sentence>" + finalSystemPrompt)
outputsOpen := false
isLastUser := false
for i, message := range conversationMessages {
switch message.Role {
case "user":
isLastUser = true
sb.WriteString("<User>" + message.Content + "<Assistant>")
case "assistant":
isLastUser = false
if len(message.ToolCalls) > 0 {
if message.Content != "" {
sb.WriteString(message.Content)
}
sb.WriteString("<tool▁calls▁begin>")
for j, toolCall := range message.ToolCalls {
sb.WriteString("<tool▁call▁begin>function<tool▁sep>" + toolCall.Function.Name)
argsJSON, _ := json.Marshal(toolCall.Function.Arguments)
sb.WriteString("\n```json\n" + string(argsJSON) + "\n```")
sb.WriteString("<tool▁call▁end>")
if j < len(message.ToolCalls)-1 {
sb.WriteString("\n")
}
}
sb.WriteString("<tool▁calls▁end><end▁of▁sentence>")
} else {
sb.WriteString(message.Content + "<end▁of▁sentence>")
}
case "tool":
isLastUser = false
if !outputsOpen {
sb.WriteString("<tool▁outputs▁begin>")
outputsOpen = true
}
sb.WriteString("<tool▁output▁begin>" + message.Content + "<tool▁output▁end>")
hasNextTool := i+1 < len(conversationMessages) && conversationMessages[i+1].Role == "tool"
if hasNextTool {
sb.WriteString("\n")
} else {
sb.WriteString("<tool▁outputs▁end>")
outputsOpen = false
}
}
}
if outputsOpen {
sb.WriteString("<tool▁outputs▁end>")
}
if !isLastUser {
sb.WriteString("<Assistant>")
}
if enableThinking {
sb.WriteString("<think>\n")
}
return sb.String(), nil
}

View File

@@ -1,491 +0,0 @@
package renderers
import (
"testing"
"github.com/google/go-cmp/cmp"
"github.com/ollama/ollama/api"
)
func TestCogitoRenderer(t *testing.T) {
tests := []struct {
name string
messages []api.Message
tools []api.Tool
thinkValue *api.ThinkValue
expected string
}{
{
name: "basic user message",
messages: []api.Message{
{Role: "user", Content: "Hello, how are you?"},
},
thinkValue: &api.ThinkValue{Value: false},
expected: `<begin▁of▁sentence>You are Cogito, an AI assistant created by Deep Cogito, which is an AI research lab based in San Francisco.<User>Hello, how are you?<Assistant>`,
},
{
name: "basic with system message",
messages: []api.Message{
{Role: "system", Content: "You are a helpful assistant."},
{Role: "user", Content: "Hello, how are you?"},
},
thinkValue: &api.ThinkValue{Value: false},
expected: `<begin▁of▁sentence>You are Cogito, an AI assistant created by Deep Cogito, which is an AI research lab based in San Francisco.
You are a helpful assistant.<User>Hello, how are you?<Assistant>`,
},
{
name: "conversation with assistant response",
messages: []api.Message{
{Role: "user", Content: "What is the capital of France?"},
{Role: "assistant", Content: "The capital of France is Paris."},
{Role: "user", Content: "Fantastic!"},
},
thinkValue: &api.ThinkValue{Value: false},
expected: `<begin▁of▁sentence>You are Cogito, an AI assistant created by Deep Cogito, which is an AI research lab based in San Francisco.<User>What is the capital of France?<Assistant>The capital of France is Paris.<end▁of▁sentence><User>Fantastic!<Assistant>`,
},
{
name: "thinking enabled without system",
messages: []api.Message{
{Role: "user", Content: "Hello, how are you?"},
},
thinkValue: &api.ThinkValue{Value: true},
expected: `<begin▁of▁sentence>Enable deep thinking subroutine.
You are Cogito, an AI assistant created by Deep Cogito, which is an AI research lab based in San Francisco.<User>Hello, how are you?<Assistant><think>
`,
},
{
name: "thinking enabled with system",
messages: []api.Message{
{Role: "system", Content: "You are a helpful assistant."},
{Role: "user", Content: "Hello, how are you?"},
},
thinkValue: &api.ThinkValue{Value: true},
expected: `<begin▁of▁sentence>Enable deep thinking subroutine.
You are Cogito, an AI assistant created by Deep Cogito, which is an AI research lab based in San Francisco.
You are a helpful assistant.
<User>Hello, how are you?<Assistant><think>
`,
},
{
name: "thinking disabled",
messages: []api.Message{
{Role: "user", Content: "Hello, how are you?"},
},
thinkValue: &api.ThinkValue{Value: false},
expected: `<begin▁of▁sentence>You are Cogito, an AI assistant created by Deep Cogito, which is an AI research lab based in San Francisco.<User>Hello, how are you?<Assistant>`,
},
{
name: "with tools",
messages: []api.Message{
{Role: "user", Content: "What's the weather like?"},
},
thinkValue: &api.ThinkValue{Value: false},
tools: []api.Tool{
{
Type: "function",
Function: api.ToolFunction{
Name: "get_weather",
Description: "Get current weather",
Parameters: api.ToolFunctionParameters{
Type: "object",
Properties: map[string]api.ToolProperty{
"location": {
Type: api.PropertyType{"string"},
Description: "City name",
},
},
Required: []string{"location"},
},
},
},
},
expected: `<begin▁of▁sentence>You are Cogito, an AI assistant created by Deep Cogito, which is an AI research lab based in San Francisco.
You have the following functions available:
` + "```json\n" + `{
"type": "function",
"function": {
"name": "get_weather",
"description": "Get current weather",
"parameters": {
"type": "object",
"required": [
"location"
],
"properties": {
"location": {
"type": "string",
"description": "City name"
}
}
}
}
}
` + "```\n" + `<User>What's the weather like?<Assistant>`,
},
{
name: "assistant with tool calls",
messages: []api.Message{
{Role: "user", Content: "What's the weather in Paris?"},
{
Role: "assistant",
Content: "I'll check the weather in Paris for you.",
ToolCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: api.ToolCallFunctionArguments{
"location": "Paris",
},
},
},
},
},
},
thinkValue: &api.ThinkValue{Value: false},
expected: `<begin▁of▁sentence>You are Cogito, an AI assistant created by Deep Cogito, which is an AI research lab based in San Francisco.<User>What's the weather in Paris?<Assistant>I'll check the weather in Paris for you.<tool▁calls▁begin><tool▁call▁begin>function<tool▁sep>get_weather
` + "```json\n" + `{"location":"Paris"}
` + "```" + `<tool▁call▁end><tool▁calls▁end><end▁of▁sentence><Assistant>`,
},
{
name: "tool response",
messages: []api.Message{
{Role: "user", Content: "What's the weather in Paris?"},
{
Role: "assistant",
ToolCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: api.ToolCallFunctionArguments{
"location": "Paris",
},
},
},
},
},
{Role: "tool", Content: "Temperature: 22°C, Sunny"},
},
thinkValue: &api.ThinkValue{Value: false},
expected: `<begin▁of▁sentence>You are Cogito, an AI assistant created by Deep Cogito, which is an AI research lab based in San Francisco.<User>What's the weather in Paris?<Assistant><tool▁calls▁begin><tool▁call▁begin>function<tool▁sep>get_weather
` + "```json\n" + `{"location":"Paris"}
` + "```" + `<tool▁call▁end><tool▁calls▁end><end▁of▁sentence><tool▁outputs▁begin><tool▁output▁begin>Temperature: 22°C, Sunny<tool▁output▁end><tool▁outputs▁end><Assistant>`,
},
{
name: "multiple tool responses",
messages: []api.Message{
{Role: "user", Content: "Get weather for Paris and London"},
{
Role: "assistant",
ToolCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: api.ToolCallFunctionArguments{
"location": "Paris",
},
},
},
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: api.ToolCallFunctionArguments{
"location": "London",
},
},
},
},
},
{Role: "tool", Content: "Paris: 22°C, Sunny"},
{Role: "tool", Content: "London: 18°C, Cloudy"},
},
thinkValue: &api.ThinkValue{Value: false},
expected: `<begin▁of▁sentence>You are Cogito, an AI assistant created by Deep Cogito, which is an AI research lab based in San Francisco.<User>Get weather for Paris and London<Assistant><tool▁calls▁begin><tool▁call▁begin>function<tool▁sep>get_weather
` + "```json\n" + `{"location":"Paris"}
` + "```" + `<tool▁call▁end>
<tool▁call▁begin>function<tool▁sep>get_weather
` + "```json\n" + `{"location":"London"}
` + "```" + `<tool▁call▁end><tool▁calls▁end><end▁of▁sentence><tool▁outputs▁begin><tool▁output▁begin>Paris: 22°C, Sunny<tool▁output▁end>
<tool▁output▁begin>London: 18°C, Cloudy<tool▁output▁end><tool▁outputs▁end><Assistant>`,
},
{
name: "thinking with tools",
messages: []api.Message{
{Role: "user", Content: "What's the weather like?"},
},
tools: []api.Tool{
{
Type: "function",
Function: api.ToolFunction{
Name: "get_weather",
Description: "Get current weather",
Parameters: api.ToolFunctionParameters{
Type: "object",
Properties: map[string]api.ToolProperty{
"location": {
Type: api.PropertyType{"string"},
Description: "City name",
},
},
Required: []string{"location"},
},
},
},
},
thinkValue: &api.ThinkValue{Value: true},
expected: `<begin▁of▁sentence>Enable deep thinking subroutine.
You are Cogito, an AI assistant created by Deep Cogito, which is an AI research lab based in San Francisco.
You have the following functions available:
` + "```json\n" + `{
"type": "function",
"function": {
"name": "get_weather",
"description": "Get current weather",
"parameters": {
"type": "object",
"required": [
"location"
],
"properties": {
"location": {
"type": "string",
"description": "City name"
}
}
}
}
}
` + "```\n" + `<User>What's the weather like?<Assistant><think>
`,
},
// test cases based on cogito
{
name: "single_turn_thinking_false",
messages: []api.Message{
{Role: "user", Content: "Hello"},
},
thinkValue: &api.ThinkValue{Value: false},
expected: `<begin▁of▁sentence>You are Cogito, an AI assistant created by Deep Cogito, which is an AI research lab based in San Francisco.<User>Hello<Assistant>`,
},
{
name: "single_turn_thinking_true",
messages: []api.Message{
{Role: "user", Content: "Hello"},
},
thinkValue: &api.ThinkValue{Value: true},
expected: `<begin▁of▁sentence>Enable deep thinking subroutine.
You are Cogito, an AI assistant created by Deep Cogito, which is an AI research lab based in San Francisco.<User>Hello<Assistant><think>
`,
},
{
name: "multi_turn_thinking_false",
messages: []api.Message{
{Role: "user", Content: "Hello"},
{Role: "assistant", Content: "Hi there!"},
{Role: "user", Content: "How are you?"},
},
thinkValue: &api.ThinkValue{Value: false},
expected: `<begin▁of▁sentence>You are Cogito, an AI assistant created by Deep Cogito, which is an AI research lab based in San Francisco.<User>Hello<Assistant>Hi there!<end▁of▁sentence><User>How are you?<Assistant>`,
},
{
name: "multi_turn_thinking_true",
messages: []api.Message{
{Role: "user", Content: "Hello"},
{Role: "assistant", Content: "Hi there!"},
{Role: "user", Content: "How are you?"},
},
thinkValue: &api.ThinkValue{Value: true},
expected: `<begin▁of▁sentence>Enable deep thinking subroutine.
You are Cogito, an AI assistant created by Deep Cogito, which is an AI research lab based in San Francisco.<User>Hello<Assistant>Hi there!<end▁of▁sentence><User>How are you?<Assistant><think>
`,
},
{
name: "multi_with_system_thinking_false",
messages: []api.Message{
{Role: "system", Content: "You are a helpful assistant"},
{Role: "user", Content: "Start"},
{Role: "assistant", Content: "Okay"},
},
thinkValue: &api.ThinkValue{Value: false},
expected: `<begin▁of▁sentence>You are Cogito, an AI assistant created by Deep Cogito, which is an AI research lab based in San Francisco.
You are a helpful assistant<User>Start<Assistant>Okay<end▁of▁sentence><Assistant>`,
},
{
name: "multi_with_system_thinking_true",
messages: []api.Message{
{Role: "system", Content: "You are a helpful assistant"},
{Role: "user", Content: "Start"},
{Role: "assistant", Content: "Okay"},
},
thinkValue: &api.ThinkValue{Value: true},
expected: `<begin▁of▁sentence>Enable deep thinking subroutine.
You are Cogito, an AI assistant created by Deep Cogito, which is an AI research lab based in San Francisco.
You are a helpful assistant
<User>Start<Assistant>Okay<end▁of▁sentence><Assistant><think>
`,
},
{
name: "multi_with_system2_thinking_false",
messages: []api.Message{
{Role: "system", Content: "You are a pirate chatbot who always responds in pirate speak!"},
{Role: "user", Content: "Give me a short introduction to LLMs."},
{Role: "assistant", Content: "Arrr! I'm a pirate"},
{Role: "user", Content: "Tell me more about LLMs."},
},
thinkValue: &api.ThinkValue{Value: false},
expected: `<begin▁of▁sentence>You are Cogito, an AI assistant created by Deep Cogito, which is an AI research lab based in San Francisco.
You are a pirate chatbot who always responds in pirate speak!<User>Give me a short introduction to LLMs.<Assistant>Arrr! I'm a pirate<end▁of▁sentence><User>Tell me more about LLMs.<Assistant>`,
},
{
name: "multi_with_system2_thinking_true",
messages: []api.Message{
{Role: "system", Content: "You are a pirate chatbot who always responds in pirate speak!"},
{Role: "user", Content: "Give me a short introduction to LLMs."},
{Role: "assistant", Content: "Arrr! I'm a pirate"},
{Role: "user", Content: "Tell me more about LLMs."},
},
thinkValue: &api.ThinkValue{Value: true},
expected: `<begin▁of▁sentence>Enable deep thinking subroutine.
You are Cogito, an AI assistant created by Deep Cogito, which is an AI research lab based in San Francisco.
You are a pirate chatbot who always responds in pirate speak!
<User>Give me a short introduction to LLMs.<Assistant>Arrr! I'm a pirate<end▁of▁sentence><User>Tell me more about LLMs.<Assistant><think>
`,
},
// tools
{
name: "tool_calls_only_no_content",
messages: []api.Message{
{Role: "user", Content: "Get weather for Paris"},
{
Role: "assistant",
ToolCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: api.ToolCallFunctionArguments{
"location": "Paris",
},
},
},
},
},
},
thinkValue: &api.ThinkValue{Value: false},
expected: `<begin▁of▁sentence>You are Cogito, an AI assistant created by Deep Cogito, which is an AI research lab based in San Francisco.<User>Get weather for Paris<Assistant><tool▁calls▁begin><tool▁call▁begin>function<tool▁sep>get_weather
` + "```json\n" + `{"location":"Paris"}
` + "```" + `<tool▁call▁end><tool▁calls▁end><end▁of▁sentence><Assistant>`,
},
{
name: "complex_tool_arguments",
messages: []api.Message{
{Role: "user", Content: "Process complex data"},
{
Role: "assistant",
ToolCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "process_data",
Arguments: api.ToolCallFunctionArguments{
"items": []any{"item1", "item2", "item3"},
"config": map[string]any{
"enabled": true,
"threshold": 0.95,
"tags": []string{"important", "urgent"},
},
},
},
},
},
},
},
thinkValue: &api.ThinkValue{Value: false},
expected: `<begin▁of▁sentence>You are Cogito, an AI assistant created by Deep Cogito, which is an AI research lab based in San Francisco.<User>Process complex data<Assistant><tool▁calls▁begin><tool▁call▁begin>function<tool▁sep>process_data
` + "```json\n" + `{"config":{"enabled":true,"tags":["important","urgent"],"threshold":0.95},"items":["item1","item2","item3"]}
` + "```" + `<tool▁call▁end><tool▁calls▁end><end▁of▁sentence><Assistant>`,
},
{
name: "empty_messages",
messages: []api.Message{
{Role: "system", Content: ""},
{Role: "user", Content: "Hello"},
{Role: "assistant", Content: ""},
},
thinkValue: &api.ThinkValue{Value: false},
expected: `<begin▁of▁sentence>You are Cogito, an AI assistant created by Deep Cogito, which is an AI research lab based in San Francisco.<User>Hello<Assistant><end▁of▁sentence><Assistant>`,
},
{
name: "thinking_with_empty_assistant_content",
messages: []api.Message{
{Role: "user", Content: "Think about this"},
{Role: "assistant", Content: ""},
},
thinkValue: &api.ThinkValue{Value: true},
expected: `<begin▁of▁sentence>Enable deep thinking subroutine.
You are Cogito, an AI assistant created by Deep Cogito, which is an AI research lab based in San Francisco.<User>Think about this<Assistant><end▁of▁sentence><Assistant><think>
`,
},
{
name: "multiple_system_messages",
messages: []api.Message{
{Role: "system", Content: "First instruction"},
{Role: "system", Content: "Second instruction"},
{Role: "user", Content: "Hello"},
},
thinkValue: &api.ThinkValue{Value: false},
expected: `<begin▁of▁sentence>You are Cogito, an AI assistant created by Deep Cogito, which is an AI research lab based in San Francisco.
First instruction<User>Hello<Assistant>`,
},
{
name: "special_characters_in_content",
messages: []api.Message{
{Role: "user", Content: "What about <|special|> tokens and \"quotes\"?"},
{Role: "assistant", Content: "They're handled normally in content."},
},
thinkValue: &api.ThinkValue{Value: false},
expected: `<begin▁of▁sentence>You are Cogito, an AI assistant created by Deep Cogito, which is an AI research lab based in San Francisco.<User>What about <|special|> tokens and "quotes"?<Assistant>They're handled normally in content.<end▁of▁sentence><Assistant>`,
},
{
name: "long_conversation_multiple_rounds",
messages: []api.Message{
{Role: "user", Content: "Hi"},
{Role: "assistant", Content: "Hello!"},
{Role: "user", Content: "How are you?"},
{Role: "assistant", Content: "Good, thanks!"},
{Role: "user", Content: "What's the weather?"},
},
thinkValue: &api.ThinkValue{Value: false},
expected: `<begin▁of▁sentence>You are Cogito, an AI assistant created by Deep Cogito, which is an AI research lab based in San Francisco.<User>Hi<Assistant>Hello!<end▁of▁sentence><User>How are you?<Assistant>Good, thanks!<end▁of▁sentence><User>What's the weather?<Assistant>`,
},
}
renderer := &CogitoRenderer{isThinking: true}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
rendered, err := renderer.Render(tt.messages, tt.tools, tt.thinkValue)
if err != nil {
t.Fatalf("Render() error = %v", err)
}
if diff := cmp.Diff(tt.expected, rendered); diff != "" {
t.Errorf("Render() mismatch (-want +got):\n%s", diff)
}
})
}
}

View File

@@ -56,9 +56,6 @@ func rendererForName(name string) Renderer {
case "qwen3-vl-thinking":
renderer := &Qwen3VLRenderer{isThinking: true, useImgTags: RenderImgTags}
return renderer
case "cogito":
renderer := &CogitoRenderer{isThinking: true}
return renderer
default:
return nil
}

View File

@@ -10,8 +10,7 @@ import (
)
type WordPiece struct {
vocab *Vocabulary
lowercase bool
vocab *Vocabulary
}
// ggmlPrefix is the prefix used by GGML vocabularies to indicate word boundaries.
@@ -115,10 +114,8 @@ func (wpm WordPiece) Encode(s string, addSpecial bool) ([]int32, error) {
subword = ggmlPrefix + subword
}
if wpm.lowercase {
subword = strings.ToLower(subword)
}
piece = wpm.vocab.Encode(subword)
// TODO: some models might not want [ToLower]
piece = wpm.vocab.Encode(strings.ToLower(subword))
if piece >= 0 {
break
}
@@ -163,9 +160,8 @@ func (wpm WordPiece) Vocabulary() *Vocabulary {
var _ TextProcessor = (*WordPiece)(nil)
func NewWordPiece(vocab *Vocabulary, lowercase bool) WordPiece {
func NewWordPiece(vocab *Vocabulary) WordPiece {
return WordPiece{
vocab: vocab,
lowercase: lowercase,
vocab: vocab,
}
}

View File

@@ -15,9 +15,7 @@ func TestWordPiece(t *testing.T) {
AddEOS: true,
BOS: []int32{1},
EOS: []int32{2},
},
true, // lowercase
)
})
ids, err := wpm.Encode("Hello world!", true)
if err != nil {

View File

@@ -1,123 +0,0 @@
package parser
import (
"os"
"os/user"
"path/filepath"
"runtime"
"testing"
)
func TestExpandPath(t *testing.T) {
mockCurrentUser := func() (*user.User, error) {
return &user.User{
Username: "testuser",
HomeDir: func() string {
if os.PathSeparator == '\\' {
return filepath.FromSlash("D:/home/testuser")
}
return "/home/testuser"
}(),
}, nil
}
mockLookupUser := func(username string) (*user.User, error) {
fakeUsers := map[string]string{
"testuser": func() string {
if os.PathSeparator == '\\' {
return filepath.FromSlash("D:/home/testuser")
}
return "/home/testuser"
}(),
"anotheruser": func() string {
if os.PathSeparator == '\\' {
return filepath.FromSlash("D:/home/anotheruser")
}
return "/home/anotheruser"
}(),
}
if homeDir, ok := fakeUsers[username]; ok {
return &user.User{
Username: username,
HomeDir: homeDir,
}, nil
}
return nil, os.ErrNotExist
}
pwd, err := os.Getwd()
if err != nil {
t.Fatal(err)
}
t.Run("unix tests", func(t *testing.T) {
if runtime.GOOS == "windows" {
return
}
tests := []struct {
path string
relativeDir string
expected string
shouldErr bool
}{
{"~", "", "/home/testuser", false},
{"~/myfolder/myfile.txt", "", "/home/testuser/myfolder/myfile.txt", false},
{"~anotheruser/docs/file.txt", "", "/home/anotheruser/docs/file.txt", false},
{"~nonexistentuser/file.txt", "", "", true},
{"relative/path/to/file", "", filepath.Join(pwd, "relative/path/to/file"), false},
{"/absolute/path/to/file", "", "/absolute/path/to/file", false},
{"/absolute/path/to/file", "someotherdir/", "/absolute/path/to/file", false},
{".", pwd, pwd, false},
{".", "", pwd, false},
{"somefile", "somedir", filepath.Join(pwd, "somedir", "somefile"), false},
}
for _, test := range tests {
result, err := expandPathImpl(test.path, test.relativeDir, mockCurrentUser, mockLookupUser)
if (err != nil) != test.shouldErr {
t.Errorf("expandPathImpl(%q) returned error: %v, expected error: %v", test.path, err != nil, test.shouldErr)
}
if result != test.expected && !test.shouldErr {
t.Errorf("expandPathImpl(%q) = %q, want %q", test.path, result, test.expected)
}
}
})
t.Run("windows tests", func(t *testing.T) {
if runtime.GOOS != "windows" {
return
}
tests := []struct {
path string
relativeDir string
expected string
shouldErr bool
}{
{"~", "", "D:\\home\\testuser", false},
{"~/myfolder/myfile.txt", "", "D:\\home\\testuser\\myfolder\\myfile.txt", false},
{"~anotheruser/docs/file.txt", "", "D:\\home\\anotheruser\\docs\\file.txt", false},
{"~nonexistentuser/file.txt", "", "", true},
{"relative\\path\\to\\file", "", filepath.Join(pwd, "relative\\path\\to\\file"), false},
{"D:\\absolute\\path\\to\\file", "", "D:\\absolute\\path\\to\\file", false},
{"D:\\absolute\\path\\to\\file", "someotherdir/", "D:\\absolute\\path\\to\\file", false},
{".", pwd, pwd, false},
{".", "", pwd, false},
{"somefile", "somedir", filepath.Join(pwd, "somedir", "somefile"), false},
}
for _, test := range tests {
result, err := expandPathImpl(test.path, test.relativeDir, mockCurrentUser, mockLookupUser)
if (err != nil) != test.shouldErr {
t.Errorf("expandPathImpl(%q) returned error: %v, expected error: %v", test.path, err != nil, test.shouldErr)
}
if result != test.expected && !test.shouldErr {
t.Errorf("expandPathImpl(%q) = %q, want %q", test.path, result, test.expected)
}
}
})
}

View File

@@ -620,43 +620,43 @@ func isValidCommand(cmd string) bool {
}
}
func expandPathImpl(path, relativeDir string, currentUserFunc func() (*user.User, error), lookupUserFunc func(string) (*user.User, error)) (string, error) {
if filepath.IsAbs(path) || strings.HasPrefix(path, "\\") || strings.HasPrefix(path, "/") {
return filepath.Abs(path)
} else if strings.HasPrefix(path, "~") {
var homeDir string
if path == "~" || strings.HasPrefix(path, "~/") {
// Current user's home directory
currentUser, err := currentUserFunc()
if err != nil {
return "", fmt.Errorf("failed to get current user: %w", err)
}
homeDir = currentUser.HomeDir
path = strings.TrimPrefix(path, "~")
} else {
// Specific user's home directory
parts := strings.SplitN(path[1:], "/", 2)
userInfo, err := lookupUserFunc(parts[0])
if err != nil {
return "", fmt.Errorf("failed to find user '%s': %w", parts[0], err)
}
homeDir = userInfo.HomeDir
if len(parts) > 1 {
path = "/" + parts[1]
} else {
path = ""
}
}
path = filepath.Join(homeDir, path)
} else {
path = filepath.Join(relativeDir, path)
func expandPath(path, dir string) (string, error) {
if filepath.IsAbs(path) {
return path, nil
}
return filepath.Abs(path)
}
path, found := strings.CutPrefix(path, "~")
switch {
case !found:
// make path relative to dir
if !filepath.IsAbs(dir) {
// if dir is relative, make it absolute relative to cwd
cwd, err := os.Getwd()
if err != nil {
return "", err
}
dir = filepath.Join(cwd, dir)
}
path = filepath.Join(dir, path)
case filepath.IsLocal(path):
// ~<user>/...
// make path relative to specified user's home
split := strings.SplitN(path, string(os.PathSeparator), 2)
u, err := user.Lookup(split[0])
if err != nil {
return "", err
}
split[0] = u.HomeDir
path = filepath.Join(split...)
default:
// ~ or ~/...
// make path relative to current user's home
home, err := os.UserHomeDir()
if err != nil {
return "", err
}
path = filepath.Join(home, path)
}
func expandPath(path, relativeDir string) (string, error) {
return expandPathImpl(path, relativeDir, user.Current, user.Lookup)
return filepath.Clean(path), nil
}

View File

@@ -9,7 +9,9 @@ import (
"io"
"maps"
"os"
"os/user"
"path/filepath"
"runtime"
"strings"
"testing"
"unicode/utf16"
@@ -1126,3 +1128,62 @@ func TestFilesForModel(t *testing.T) {
})
}
}
func TestExpandPath(t *testing.T) {
home := t.TempDir()
t.Setenv("HOME", home)
t.Setenv("USERPROFILE", home)
cwd, err := os.Getwd()
if err != nil {
t.Fatal(err)
}
u, err := user.Current()
if err != nil {
t.Fatal(err)
}
volume := ""
if runtime.GOOS == "windows" {
volume = "D:"
}
cases := []struct {
input,
dir,
want string
err error
}{
{"~", "", home, nil},
{"~/path/to/file", "", filepath.Join(home, filepath.ToSlash("path/to/file")), nil},
{"~" + u.Username + "/path/to/file", "", filepath.Join(u.HomeDir, filepath.ToSlash("path/to/file")), nil},
{"~nonexistentuser/path/to/file", "", "", user.UnknownUserError("nonexistentuser")},
{"relative/path/to/file", "", filepath.Join(cwd, filepath.ToSlash("relative/path/to/file")), nil},
{volume + "/absolute/path/to/file", "", filepath.ToSlash(volume + "/absolute/path/to/file"), nil},
{volume + "/absolute/path/to/file", filepath.ToSlash("another/path"), filepath.ToSlash(volume + "/absolute/path/to/file"), nil},
{".", cwd, cwd, nil},
{".", "", cwd, nil},
{"", cwd, cwd, nil},
{"", "", cwd, nil},
{"file", "path/to", filepath.Join(cwd, filepath.ToSlash("path/to/file")), nil},
}
for _, tt := range cases {
t.Run(tt.input, func(t *testing.T) {
got, err := expandPath(tt.input, tt.dir)
// On Windows, user.Lookup does not map syscall errors to user.UnknownUserError
// so we special case the test to just check for an error.
// See https://cs.opensource.google/go/go/+/refs/tags/go1.25.1:src/os/user/lookup_windows.go;l=455
if runtime.GOOS != "windows" && !errors.Is(err, tt.err) {
t.Fatalf("expandPath(%q) error = %v, wantErr %v", tt.input, err, tt.err)
} else if tt.err != nil && err == nil {
t.Fatal("test case expected to fail on windows")
}
if got != tt.want {
t.Errorf("expandPath(%q) = %v, want %v", tt.input, got, tt.want)
}
})
}
}

View File

@@ -340,7 +340,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
builtinParser = parsers.ParserForName(m.Config.Parser)
if builtinParser != nil {
// no tools or last message for generate endpoint
builtinParser.Init(nil, nil, req.Think)
builtinParser.Init(nil, nil)
}
}
@@ -2051,7 +2051,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
lastMessage = &msgs[len(msgs)-1]
}
// Initialize parser and get processed tools
processedTools = builtinParser.Init(req.Tools, lastMessage, req.Think)
processedTools = builtinParser.Init(req.Tools, lastMessage)
}
}