Compare commits

..

7 Commits

Author SHA1 Message Date
nicole pardal
03abdb4969 fixed pretokenizer 2025-12-09 10:02:17 -08:00
nicole pardal
57c1d7db9a fixed generation issue 2025-12-08 00:35:49 -08:00
nicole pardal
91d6370a62 removed original olmo support 2025-12-01 14:17:46 -08:00
nicole pardal
38a2a6468f removed olmo1 support 2025-12-01 14:14:31 -08:00
nicole pardal
064ec63ddf lint 2025-11-26 20:05:25 -08:00
nicole pardal
fd959fbf7a updated converter 2025-11-26 19:42:34 -08:00
nicole pardal
cfc9729edf olmo model initial 2025-11-25 15:49:09 -08:00
11 changed files with 347 additions and 728 deletions

2
.gitattributes vendored
View File

@@ -19,8 +19,6 @@ ml/backend/**/*.comp linguist-vendored
ml/backend/**/*.glsl linguist-vendored
ml/backend/**/CMakeLists.txt linguist-vendored
app/webview linguist-vendored
llama/build-info.cpp linguist-generated
ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-embed.s linguist-generated

View File

@@ -226,14 +226,7 @@ func (c *Client) stream(ctx context.Context, method, path string, data any, fn f
bts := scanner.Bytes()
if err := json.Unmarshal(bts, &errorResponse); err != nil {
if response.StatusCode >= http.StatusBadRequest {
return StatusError{
StatusCode: response.StatusCode,
Status: response.Status,
ErrorMessage: string(bts),
}
}
return errors.New(string(bts))
return fmt.Errorf("unmarshal: %w", err)
}
if response.StatusCode == http.StatusUnauthorized {

View File

@@ -55,7 +55,6 @@ func TestClientFromEnvironment(t *testing.T) {
type testError struct {
message string
statusCode int
raw bool // if true, write message as-is instead of JSON encoding
}
func (e testError) Error() string {
@@ -112,20 +111,6 @@ func TestClientStream(t *testing.T) {
},
},
},
{
name: "plain text error response",
responses: []any{
"internal server error",
},
wantErr: "internal server error",
},
{
name: "HTML error page",
responses: []any{
"<html><body>404 Not Found</body></html>",
},
wantErr: "404 Not Found",
},
}
for _, tc := range testCases {
@@ -150,12 +135,6 @@ func TestClientStream(t *testing.T) {
return
}
if str, ok := resp.(string); ok {
fmt.Fprintln(w, str)
flusher.Flush()
continue
}
if err := json.NewEncoder(w).Encode(resp); err != nil {
t.Fatalf("failed to encode response: %v", err)
}
@@ -194,10 +173,9 @@ func TestClientStream(t *testing.T) {
func TestClientDo(t *testing.T) {
testCases := []struct {
name string
response any
wantErr string
wantStatusCode int
name string
response any
wantErr string
}{
{
name: "immediate error response",
@@ -205,8 +183,7 @@ func TestClientDo(t *testing.T) {
message: "test error message",
statusCode: http.StatusBadRequest,
},
wantErr: "test error message",
wantStatusCode: http.StatusBadRequest,
wantErr: "test error message",
},
{
name: "server error response",
@@ -214,8 +191,7 @@ func TestClientDo(t *testing.T) {
message: "internal error",
statusCode: http.StatusInternalServerError,
},
wantErr: "internal error",
wantStatusCode: http.StatusInternalServerError,
wantErr: "internal error",
},
{
name: "successful response",
@@ -227,26 +203,6 @@ func TestClientDo(t *testing.T) {
Success: true,
},
},
{
name: "plain text error response",
response: testError{
message: "internal server error",
statusCode: http.StatusInternalServerError,
raw: true,
},
wantErr: "internal server error",
wantStatusCode: http.StatusInternalServerError,
},
{
name: "HTML error page",
response: testError{
message: "<html><body>404 Not Found</body></html>",
statusCode: http.StatusNotFound,
raw: true,
},
wantErr: "<html><body>404 Not Found</body></html>",
wantStatusCode: http.StatusNotFound,
},
}
for _, tc := range testCases {
@@ -254,16 +210,11 @@ func TestClientDo(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if errResp, ok := tc.response.(testError); ok {
w.WriteHeader(errResp.statusCode)
if !errResp.raw {
err := json.NewEncoder(w).Encode(map[string]string{
"error": errResp.message,
})
if err != nil {
t.Fatal("failed to encode error response:", err)
}
} else {
// Write raw message (simulates non-JSON error responses)
fmt.Fprint(w, errResp.message)
err := json.NewEncoder(w).Encode(map[string]string{
"error": errResp.message,
})
if err != nil {
t.Fatal("failed to encode error response:", err)
}
return
}
@@ -290,15 +241,6 @@ func TestClientDo(t *testing.T) {
if err.Error() != tc.wantErr {
t.Errorf("error message mismatch: got %q, want %q", err.Error(), tc.wantErr)
}
if tc.wantStatusCode != 0 {
if statusErr, ok := err.(StatusError); ok {
if statusErr.StatusCode != tc.wantStatusCode {
t.Errorf("status code mismatch: got %d, want %d", statusErr.StatusCode, tc.wantStatusCode)
}
} else {
t.Errorf("expected StatusError, got %T", err)
}
}
return
}

View File

@@ -1,625 +0,0 @@
#!/usr/bin/env python3
# /// script
# requires-python = ">=3.11"
# dependencies = [
# "transformers>=4.57.0",
# "jinja2",
# "fastapi",
# "uvicorn",
# "pydantic",
# "requests",
# ]
# ///
"""
Chat Template Testing Tool
Test HuggingFace chat templates against Ollama renderers.
Usage:
# Run predefined test cases against a HuggingFace model
uv run cmd/chat_template/chat_template.py --model PrimeIntellect/INTELLECT-3
# Compare HuggingFace output with Ollama renderer
uv run cmd/chat_template/chat_template.py --model PrimeIntellect/INTELLECT-3 --ollama-model intellect3
# Start server for manual curl testing
uv run cmd/chat_template/chat_template.py --serve
# Show chat template for a model
uv run cmd/chat_template/chat_template.py --model PrimeIntellect/INTELLECT-3 --show-template
"""
import argparse
import json
import sys
from typing import Any
from transformers import AutoTokenizer
TEST_CASES = [
{
"name": "basic_user_message",
"messages": [{"role": "user", "content": "Hello!"}],
"tools": None,
},
{
"name": "with_system_message",
"messages": [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Hello!"},
],
"tools": None,
},
{
"name": "multi_turn_conversation",
"messages": [
{"role": "user", "content": "Hello"},
{"role": "assistant", "content": "Hi there!"},
{"role": "user", "content": "How are you?"},
],
"tools": None,
},
{
"name": "with_tools",
"messages": [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "What is the weather?"},
],
"tools": [
{
"type": "function",
"function": {
"name": "get_weather",
"description": "Get the current weather",
"parameters": {
"type": "object",
"required": ["location"],
"properties": {
"location": {"type": "string", "description": "The city"}
},
},
},
}
],
},
{
"name": "tool_call_and_response",
"messages": [
{"role": "user", "content": "What is the weather in SF?"},
{
"role": "assistant",
"content": "Let me check the weather.",
"tool_calls": [
{
"id": "call_1",
"type": "function",
"function": {
"name": "get_weather",
"arguments": {"location": "San Francisco"},
},
}
],
},
{"role": "tool", "content": '{"temperature": 68}', "tool_call_id": "call_1"},
],
"tools": [
{
"type": "function",
"function": {
"name": "get_weather",
"description": "Get the current weather",
"parameters": {
"type": "object",
"required": ["location"],
"properties": {
"location": {"type": "string", "description": "The city"}
},
},
},
}
],
},
{
"name": "parallel_tool_calls",
"messages": [
{"role": "user", "content": "Get weather in SF and NYC"},
{
"role": "assistant",
"tool_calls": [
{
"id": "call_1",
"type": "function",
"function": {
"name": "get_weather",
"arguments": {"location": "San Francisco"},
},
},
{
"id": "call_2",
"type": "function",
"function": {
"name": "get_weather",
"arguments": {"location": "New York"},
},
},
],
},
{"role": "tool", "content": '{"temperature": 68}', "tool_call_id": "call_1"},
{"role": "tool", "content": '{"temperature": 55}', "tool_call_id": "call_2"},
],
"tools": [
{
"type": "function",
"function": {
"name": "get_weather",
"parameters": {
"type": "object",
"properties": {"location": {"type": "string"}},
},
},
}
],
},
# Thinking tests
{
"name": "assistant_with_thinking",
"messages": [
{"role": "user", "content": "What is 2+2?"},
{
"role": "assistant",
"content": "The answer is 4.",
"thinking": "Let me calculate: 2 + 2 = 4. This is basic arithmetic.",
},
{"role": "user", "content": "And 3+3?"},
],
"tools": None,
},
{
"name": "thinking_with_tool_call",
"messages": [
{"role": "user", "content": "What's the weather in Paris?"},
{
"role": "assistant",
"content": "I'll check the weather for you.",
"thinking": "The user wants to know the weather in Paris. I should call the get_weather function.",
"tool_calls": [
{
"id": "call_1",
"type": "function",
"function": {
"name": "get_weather",
"arguments": {"location": "Paris"},
},
}
],
},
{"role": "tool", "content": '{"temperature": 18, "condition": "cloudy"}', "tool_call_id": "call_1"},
],
"tools": [
{
"type": "function",
"function": {
"name": "get_weather",
"description": "Get current weather",
"parameters": {
"type": "object",
"properties": {"location": {"type": "string"}},
},
},
}
],
},
{
"name": "thinking_only_no_content",
"messages": [
{"role": "user", "content": "Think about this silently."},
{
"role": "assistant",
"content": "", # HuggingFace requires content field
"thinking": "I'm thinking about this but won't respond with visible content.",
},
{"role": "user", "content": "What did you think?"},
],
"tools": None,
},
]
# Cache for tokenizers
_tokenizer_cache: dict[str, Any] = {}
def get_tokenizer(model_name: str):
"""Get or create tokenizer for the given model."""
if model_name not in _tokenizer_cache:
print(f"Loading tokenizer for {model_name}...", file=sys.stderr)
_tokenizer_cache[model_name] = AutoTokenizer.from_pretrained(model_name)
return _tokenizer_cache[model_name]
def apply_template(
model: str,
messages: list[dict],
tools: list[dict] | None = None,
) -> str:
"""Apply HuggingFace chat template to messages."""
tokenizer = get_tokenizer(model)
if tools:
return tokenizer.apply_chat_template(
messages,
tools=tools,
tokenize=False,
add_generation_prompt=True,
)
else:
return tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True,
)
def get_ollama_prompt(
ollama_model: str,
messages: list[dict],
tools: list[dict] | None = None,
ollama_host: str = "http://localhost:11434",
) -> str | None:
"""Get rendered prompt from Ollama using debug_render_only."""
import requests
# Convert messages to Ollama format
ollama_messages = []
for msg in messages:
ollama_msg = {"role": msg["role"]}
if "content" in msg:
ollama_msg["content"] = msg["content"]
if "thinking" in msg:
ollama_msg["thinking"] = msg["thinking"]
if "tool_calls" in msg:
# Convert tool_calls to Ollama format
tool_calls = []
for tc in msg["tool_calls"]:
tool_call = {
"function": {
"name": tc["function"]["name"],
"arguments": tc["function"]["arguments"],
}
}
if "id" in tc:
tool_call["id"] = tc["id"]
tool_calls.append(tool_call)
ollama_msg["tool_calls"] = tool_calls
if "tool_call_id" in msg:
ollama_msg["tool_call_id"] = msg["tool_call_id"]
ollama_messages.append(ollama_msg)
payload = {
"model": ollama_model,
"messages": ollama_messages,
"stream": False,
"_debug_render_only": True,
}
if tools:
payload["tools"] = tools
try:
resp = requests.post(f"{ollama_host}/api/chat", json=payload, timeout=30)
resp.raise_for_status()
data = resp.json()
# Field name is _debug_info with underscore prefix
if "_debug_info" in data and "rendered_template" in data["_debug_info"]:
return data["_debug_info"]["rendered_template"]
return None
except requests.exceptions.ConnectionError:
print(f" [ERROR] Cannot connect to Ollama at {ollama_host}", file=sys.stderr)
return None
except Exception as e:
print(f" [ERROR] Ollama request failed: {e}", file=sys.stderr)
return None
def compute_diff(hf_prompt: str, ollama_prompt: str) -> str:
"""Compute a unified diff between HuggingFace and Ollama prompts."""
import difflib
hf_lines = hf_prompt.splitlines(keepends=True)
ollama_lines = ollama_prompt.splitlines(keepends=True)
diff = difflib.unified_diff(
ollama_lines,
hf_lines,
fromfile="Ollama",
tofile="HuggingFace",
lineterm="",
)
return "".join(diff)
def print_test_output(
name: str,
messages: list[dict],
tools: list[dict] | None,
hf_prompt: str,
ollama_prompt: str | None = None,
as_repr: bool = False,
):
"""Print test output in a format suitable for Go test creation and LLM diffing."""
print(f"\n{'='*60}")
print(f"Test: {name}")
print("=" * 60)
print("\n--- Input Messages ---")
print(json.dumps(messages, indent=2))
if tools:
print("\n--- Tools ---")
print(json.dumps(tools, indent=2))
if ollama_prompt is not None:
# Comparison mode
if hf_prompt == ollama_prompt:
print("\n--- Result: MATCH ---")
print("\n--- Prompt (both identical) ---")
if as_repr:
print(repr(hf_prompt))
else:
print(hf_prompt)
else:
print("\n--- Result: MISMATCH ---")
print("\n--- HuggingFace Prompt ---")
if as_repr:
print(repr(hf_prompt))
else:
print(hf_prompt)
print("\n--- Ollama Prompt ---")
if as_repr:
print(repr(ollama_prompt))
else:
print(ollama_prompt)
print("\n--- Diff (Ollama -> HuggingFace) ---")
diff = compute_diff(hf_prompt, ollama_prompt)
if diff:
print(diff)
else:
print("(no line-level diff, check whitespace)")
else:
# HuggingFace only mode
print("\n--- HuggingFace Prompt ---")
if as_repr:
print(repr(hf_prompt))
else:
print(hf_prompt)
print("=" * 60)
def run_tests(
model: str,
as_repr: bool = False,
test_filter: str | None = None,
ollama_model: str | None = None,
ollama_host: str = "http://localhost:11434",
):
"""Run all predefined test cases against a model."""
if ollama_model:
print(f"\nComparing HuggingFace ({model}) vs Ollama ({ollama_model})\n")
else:
print(f"\nRunning tests against: {model}\n")
matches = 0
mismatches = 0
errors = 0
for test_case in TEST_CASES:
name = test_case["name"]
messages = test_case["messages"]
tools = test_case["tools"]
# Filter tests if specified
if test_filter and test_filter.lower() not in name.lower():
continue
try:
hf_prompt = apply_template(model, messages, tools)
ollama_prompt = None
if ollama_model:
ollama_prompt = get_ollama_prompt(
ollama_model, messages, tools, ollama_host
)
if ollama_prompt is None:
errors += 1
elif hf_prompt == ollama_prompt:
matches += 1
else:
mismatches += 1
print_test_output(
name, messages, tools, hf_prompt, ollama_prompt, as_repr=as_repr
)
except Exception as e:
errors += 1
print(f"\n{'='*60}")
print(f"Test: {name} - FAILED")
print(f"--- Input Messages ---")
print(json.dumps(messages, indent=2))
if tools:
print(f"--- Tools ---")
print(json.dumps(tools, indent=2))
print(f"--- Error ---")
print(f"{e}")
print("=" * 60)
# Print summary if comparing
if ollama_model:
total = matches + mismatches + errors
print(f"\n{'='*60}")
print("SUMMARY")
print("=" * 60)
print(f" Total: {total}")
print(f" Matches: {matches}")
print(f" Mismatches: {mismatches}")
print(f" Errors: {errors}")
print("=" * 60)
def show_template(model: str):
"""Show the chat template for a model."""
tokenizer = get_tokenizer(model)
print(f"\nChat template for {model}:\n")
print("-" * 60)
print(tokenizer.chat_template)
print("-" * 60)
def start_server(host: str = "0.0.0.0", port: int = 8000):
"""Start the FastAPI server for manual testing."""
from typing import Optional, List, Dict, Any as TypingAny
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import uvicorn
class Message(BaseModel):
role: str
content: Optional[str] = None
tool_calls: Optional[List[Dict[str, TypingAny]]] = None
tool_call_id: Optional[str] = None
class GeneratePromptRequest(BaseModel):
messages: List[Message]
model: str = "PrimeIntellect/INTELLECT-3"
tools: Optional[List[Dict[str, TypingAny]]] = None
inject_tools_as_functions: bool = False
class GeneratePromptResponse(BaseModel):
prompt: str
model: str
app = FastAPI(title="HuggingFace Prompt Generator", version="1.0.0")
@app.post("/generate-prompt", response_model=GeneratePromptResponse)
async def generate_prompt(request: GeneratePromptRequest):
try:
messages = []
for msg in request.messages:
message_dict = {"role": msg.role}
if msg.content is not None:
message_dict["content"] = msg.content
if msg.tool_calls is not None:
tool_calls = []
for tc in msg.tool_calls:
tc_copy = tc.copy()
if "function" in tc_copy and "arguments" in tc_copy["function"]:
args = tc_copy["function"]["arguments"]
if isinstance(args, str):
try:
tc_copy["function"]["arguments"] = json.loads(args)
except json.JSONDecodeError:
pass
tool_calls.append(tc_copy)
message_dict["tool_calls"] = tool_calls
if msg.tool_call_id is not None:
message_dict["tool_call_id"] = msg.tool_call_id
messages.append(message_dict)
prompt = apply_template(request.model, messages, request.tools)
return GeneratePromptResponse(prompt=prompt, model=request.model)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.get("/health")
async def health_check():
return {"status": "healthy"}
print(f"Starting server on http://{host}:{port}")
print("Endpoints:")
print(" POST /generate-prompt - Generate prompt from messages")
print(" GET /health - Health check")
uvicorn.run(app, host=host, port=port)
def main():
parser = argparse.ArgumentParser(
description="HuggingFace Prompt Testing Tool",
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog=__doc__,
)
parser.add_argument(
"--model",
"-m",
type=str,
help="HuggingFace model name (e.g., PrimeIntellect/INTELLECT-3)",
)
parser.add_argument(
"--ollama-model",
"-o",
type=str,
help="Ollama model name to compare against (e.g., qwen3-coder)",
)
parser.add_argument(
"--ollama-host",
type=str,
default="http://localhost:11434",
help="Ollama server URL (default: http://localhost:11434)",
)
parser.add_argument(
"--serve",
"-s",
action="store_true",
help="Start FastAPI server for manual curl testing",
)
parser.add_argument(
"--port",
"-p",
type=int,
default=8000,
help="Server port (default: 8000)",
)
parser.add_argument(
"--show-template",
"-t",
action="store_true",
help="Show the chat template for the model",
)
parser.add_argument(
"--repr",
"-r",
action="store_true",
help="Output prompts as Python repr (shows escape sequences)",
)
parser.add_argument(
"--filter",
"-f",
type=str,
help="Filter tests by name (substring match)",
)
args = parser.parse_args()
if args.serve:
start_server(port=args.port)
elif args.model:
if args.show_template:
show_template(args.model)
else:
run_tests(
args.model,
as_repr=args.repr,
test_filter=args.filter,
ollama_model=args.ollama_model,
ollama_host=args.ollama_host,
)
else:
parser.print_help()
print("\nExample usage:")
print(" uv run cmd/chat_template/chat_template.py --model PrimeIntellect/INTELLECT-3")
print(" uv run cmd/chat_template/chat_template.py --model Qwen/Qwen3-Coder-480B-A35B-Instruct --ollama-model qwen3-coder")
print(" uv run cmd/chat_template/chat_template.py --serve")
sys.exit(1)
if __name__ == "__main__":
main()

View File

@@ -200,6 +200,8 @@ func ConvertModel(fsys fs.FS, f *os.File) error {
conv = &qwen25VLModel{}
case "Qwen3VLForConditionalGeneration", "Qwen3VLMoeForConditionalGeneration":
conv = &qwen3VLModel{}
case "OLMo2ForCausalLM", "Olmo2ForCausalLM", "OLMo3ForCausalLM", "Olmo3ForCausalLM":
conv = &olmoModel{}
case "BertModel":
conv = &bertModel{}
case "CohereForCausalLM":

94
convert/convert_olmo.go Normal file
View File

@@ -0,0 +1,94 @@
package convert
import (
"cmp"
"github.com/ollama/ollama/fs/ggml"
)
type olmoModel struct {
ModelParameters
HiddenSize uint32 `json:"hidden_size"`
NumHiddenLayers uint32 `json:"num_hidden_layers"`
IntermediateSize uint32 `json:"intermediate_size"`
NumAttentionHeads uint32 `json:"num_attention_heads"`
NumKeyValueHeads uint32 `json:"num_key_value_heads"`
MaxPositionEmbeddings uint32 `json:"max_position_embeddings"`
RMSNormEPS float32 `json:"rms_norm_eps"`
RopeTheta float32 `json:"rope_theta"`
ClampKQV float32 `json:"f_clamp_kqv"`
SlidingWindow uint32 `json:"sliding_window"`
LayerTypes []string `json:"layer_types"`
}
var _ ModelConverter = (*olmoModel)(nil)
func (p *olmoModel) KV(t *Tokenizer) ggml.KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "olmo"
kv["olmo.block_count"] = p.NumHiddenLayers
kv["olmo.context_length"] = p.MaxPositionEmbeddings
kv["olmo.embedding_length"] = p.HiddenSize
kv["olmo.feed_forward_length"] = p.IntermediateSize
kv["olmo.attention.head_count"] = p.NumAttentionHeads
kv["olmo.attention.head_count_kv"] = cmp.Or(p.NumKeyValueHeads, p.NumAttentionHeads)
if p.RopeTheta > 0 {
kv["olmo.rope.freq_base"] = p.RopeTheta
} else {
kv["olmo.rope.freq_base"] = float32(10000.0)
}
if p.RMSNormEPS > 0 {
kv["olmo.attention.layer_norm_rms_epsilon"] = p.RMSNormEPS
}
if p.ClampKQV > 0 {
kv["olmo.attention.clamp_kqv"] = p.ClampKQV
}
if p.SlidingWindow > 0 {
kv["olmo.attention.sliding_window"] = p.SlidingWindow
}
if len(p.LayerTypes) > 0 {
kv["olmo.attention.layer_types"] = p.LayerTypes
}
return kv
}
func (p *olmoModel) Tensors(ts []Tensor) []*ggml.Tensor {
out := make([]*ggml.Tensor, 0, len(ts))
for _, t := range ts {
out = append(out, &ggml.Tensor{
Name: t.Name(),
Kind: t.Kind(),
Shape: t.Shape(),
WriterTo: t,
})
}
return out
}
func (p *olmoModel) Replacements() []string {
return []string{
"lm_head", "output",
"model.embed_tokens", "token_embd",
"model.layers", "blk",
"model.norm", "output_norm",
"self_attn.q_proj", "attn_q",
"self_attn.k_proj", "attn_k",
"self_attn.v_proj", "attn_v",
"self_attn.o_proj", "attn_output",
"self_attn.q_norm", "attn_q_norm",
"self_attn.k_norm", "attn_k_norm",
"post_attention_layernorm", "post_attention_norm",
"post_feedforward_layernorm", "post_ffw_norm",
"mlp.gate_proj", "ffn_gate",
"mlp.down_proj", "ffn_down",
"mlp.up_proj", "ffn_up",
}
}

View File

@@ -65,7 +65,6 @@ func GPUDevices(ctx context.Context, runners []ml.FilteredRunnerDiscovery) []ml.
}
slog.Info("discovering available GPUs...")
detectIncompatibleLibraries()
// Warn if any user-overrides are set which could lead to incorrect GPU discovery
overrideWarnings()
@@ -99,9 +98,6 @@ func GPUDevices(ctx context.Context, runners []ml.FilteredRunnerDiscovery) []ml.
continue
} else if jetpack != "" && filepath.Base(dir) != "cuda_"+jetpack {
continue
} else if jetpack == "" && strings.Contains(filepath.Base(dir), "cuda_jetpack") {
slog.Debug("jetpack not detected (set JETSON_JETPACK or OLLAMA_LLM_LIBRARY to override), skipping", "libDir", dir)
continue
} else if !envconfig.EnableVulkan() && strings.Contains(filepath.Base(dir), "vulkan") {
slog.Info("experimental Vulkan support disabled. To enable, set OLLAMA_VULKAN=1")
continue
@@ -488,16 +484,3 @@ func overrideWarnings() {
slog.Warn("if GPUs are not correctly discovered, unset and try again")
}
}
func detectIncompatibleLibraries() {
if runtime.GOOS != "windows" {
return
}
basePath, err := exec.LookPath("ggml-base.dll")
if err != nil || basePath == "" {
return
}
if !strings.HasPrefix(basePath, ml.LibOllamaPath) {
slog.Warn("potentially incompatible library detected in PATH", "location", basePath)
}
}

View File

@@ -57,13 +57,8 @@ ollama ps
```
<Info>
**Output**:
```
NAME ID SIZE PROCESSOR UNTIL
llama3:70b bcfb190ca3a7 42 GB 100% GPU 4 minutes from now
```
**Output**: ``` NAME ID SIZE PROCESSOR UNTIL llama3:70b bcfb190ca3a7 42 GB
100% GPU 4 minutes from now ```
</Info>
The `Processor` column will show which memory the model was loaded in to:
@@ -390,4 +385,4 @@ Ollama for Windows and macOS register as a login item during installation. You
- In `Task Manager` go to the `Startup apps` tab, search for `ollama` then click `Disable`
**MacOS**
- Open `Settings` and search for "Login Items", find the `Ollama` entry under "Allow in the Background`, then click the slider to disable.
- Open `Settings` and search for "Login Items", find the `Ollama` entry under "Allow in the Background`, then click the slider to disable.

View File

@@ -149,6 +149,9 @@ PARAMETER <parameter> <parametervalue>
| Parameter | Description | Value Type | Example Usage |
| -------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ---------- | -------------------- |
| mirostat | Enable Mirostat sampling for controlling perplexity. (default: 0, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0) | int | mirostat 0 |
| mirostat_eta | Influences how quickly the algorithm responds to feedback from the generated text. A lower learning rate will result in slower adjustments, while a higher learning rate will make the algorithm more responsive. (Default: 0.1) | float | mirostat_eta 0.1 |
| mirostat_tau | Controls the balance between coherence and diversity of the output. A lower value will result in more focused and coherent text. (Default: 5.0) | float | mirostat_tau 5.0 |
| num_ctx | Sets the size of the context window used to generate the next token. (Default: 2048) | int | num_ctx 4096 |
| repeat_last_n | Sets how far back for the model to look back to prevent repetition. (Default: 64, 0 = disabled, -1 = num_ctx) | int | repeat_last_n 64 |
| repeat_penalty | Sets how strongly to penalize repetitions. A higher value (e.g., 1.5) will penalize repetitions more strongly, while a lower value (e.g., 0.9) will be more lenient. (Default: 1.1) | float | repeat_penalty 1.1 |

View File

@@ -13,6 +13,7 @@ import (
_ "github.com/ollama/ollama/model/models/mistral3"
_ "github.com/ollama/ollama/model/models/mllama"
_ "github.com/ollama/ollama/model/models/nomicbert"
_ "github.com/ollama/ollama/model/models/olmo"
_ "github.com/ollama/ollama/model/models/qwen2"
_ "github.com/ollama/ollama/model/models/qwen25vl"
_ "github.com/ollama/ollama/model/models/qwen3"

233
model/models/olmo/model.go Normal file
View File

@@ -0,0 +1,233 @@
package olmo
import (
"cmp"
"math"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/ml/nn/fast"
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
)
type Options struct {
hiddenSize, numHeads, numKVHeads int
headDim, ropeDim int
eps, ropeBase, ropeScale float32
clampKQV float32
originalContextLength int
attnFactor float32
}
type Model struct {
model.Base
model.TextProcessor
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
Layers []Layer `gguf:"blk"`
OutputNorm *nn.RMSNorm `gguf:"output_norm"`
Output *nn.Linear `gguf:"output,alt:token_embd"`
Options
}
func New(c fs.Config) (model.Model, error) {
vocabulary := model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
EOS: append(
[]int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))},
c.Ints("tokenizer.ggml.eos_token_ids")...,
),
}
if c.String("tokenizer.ggml.model") != "gpt2" {
return nil, model.ErrUnsupportedTokenizer
}
var pretokenizers []string
if c.String("tokenizer.ggml.pre") != "default" {
pretokenizers = []string{
`(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+`,
}
}
processor := model.NewBytePairEncoding(&vocabulary, pretokenizers...)
m := Model{
TextProcessor: processor,
Layers: make([]Layer, c.Uint("block_count")),
Options: Options{
hiddenSize: int(c.Uint("embedding_length")),
numHeads: int(c.Uint("attention.head_count")),
numKVHeads: int(c.Uint("attention.head_count_kv")),
headDim: int(c.Uint("attention.key_length")),
ropeDim: int(c.Uint("rope.dimension_count")),
eps: c.Float("attention.layer_norm_rms_epsilon"),
ropeBase: c.Float("rope.freq_base", 1e4),
ropeScale: c.Float("rope.scaling.factor", 1),
clampKQV: c.Float("attention.clamp_kqv", 0),
originalContextLength: int(c.Uint("rope.scaling.original_context_length")),
attnFactor: c.Float("rope.scaling.attn_factor", 1),
},
}
m.Cache = kvcache.NewCausalCache(m.Shift)
return &m, nil
}
type SelfAttention struct {
Query *nn.Linear `gguf:"attn_q"`
Key *nn.Linear `gguf:"attn_k"`
Value *nn.Linear `gguf:"attn_v"`
Output *nn.Linear `gguf:"attn_output"`
QNorm *nn.RMSNorm `gguf:"attn_q_norm"`
KNorm *nn.RMSNorm `gguf:"attn_k_norm"`
RopeFactors ml.Tensor `gguf:"rope_freqs.weight"`
}
func (o *Options) ropeOptions(factors ml.Tensor, isSWA bool) []func(*rope.Options) {
opts := []func(*rope.Options){
rope.WithFactors(factors),
}
if !isSWA && o.originalContextLength > 0 {
opts = append(opts,
rope.WithOriginalContextLength(o.originalContextLength),
rope.WithExtrapolationFactor(1.),
rope.WithAttentionFactor(o.attnFactor),
)
}
return opts
}
func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positions ml.Tensor, cache kvcache.Cache, opts *Options, isSWA bool) ml.Tensor {
batchSize := hiddenState.Dim(1)
headDim := cmp.Or(opts.headDim, opts.hiddenSize/opts.numHeads)
ropeDim := cmp.Or(opts.ropeDim, headDim)
query := sa.Query.Forward(ctx, hiddenState)
if sa.QNorm != nil {
query = sa.QNorm.Forward(ctx, query, opts.eps)
}
query = query.Reshape(ctx, headDim, opts.numHeads, batchSize)
key := sa.Key.Forward(ctx, hiddenState)
if sa.KNorm != nil {
key = sa.KNorm.Forward(ctx, key, opts.eps)
}
key = key.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
value := sa.Value.Forward(ctx, hiddenState)
value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
freqScale := float32(1.0)
if !isSWA {
freqScale = 1. / opts.ropeScale
}
ropeOpts := opts.ropeOptions(sa.RopeFactors, isSWA)
query = fast.RoPE(ctx, query, positions, ropeDim, opts.ropeBase, freqScale, ropeOpts...)
key = fast.RoPE(ctx, key, positions, ropeDim, opts.ropeBase, freqScale, ropeOpts...)
attention := nn.Attention(ctx, query, key, value, 1.0/math.Sqrt(float64(headDim)), cache)
attention = attention.Reshape(ctx, headDim*opts.numHeads, batchSize)
return sa.Output.Forward(ctx, attention)
}
func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
ropeDim := cmp.Or(m.ropeDim, m.hiddenSize/m.numHeads)
isSWA := isSWALayer(layer)
freqScale := float32(1.0)
if !isSWA {
freqScale = 1. / m.ropeScale
}
ropeOpts := m.Options.ropeOptions(m.Layers[layer].SelfAttention.RopeFactors, isSWA)
return fast.RoPE(ctx, key, shift, ropeDim, m.ropeBase, freqScale, ropeOpts...), nil
}
type MLP struct {
Up *nn.Linear `gguf:"ffn_up"`
Down *nn.Linear `gguf:"ffn_down"`
Gate *nn.Linear `gguf:"ffn_gate"`
}
func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *Options) ml.Tensor {
hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx, mlp.Up.Forward(ctx, hiddenState))
return mlp.Down.Forward(ctx, hiddenState)
}
type Layer struct {
SelfAttention *SelfAttention
PostAttentionNorm *nn.RMSNorm `gguf:"post_attention_norm"`
MLP *MLP
PostFFWNorm *nn.RMSNorm `gguf:"post_ffw_norm"`
}
func (l *Layer) Forward(ctx ml.Context, hiddenState, positions, outputs ml.Tensor, cache kvcache.Cache, opts *Options, isSWA bool) ml.Tensor {
residual := hiddenState
hiddenState = l.SelfAttention.Forward(ctx, hiddenState, positions, cache, opts, isSWA)
if outputs != nil {
hiddenState = hiddenState.Rows(ctx, outputs)
residual = residual.Rows(ctx, outputs)
}
if l.PostAttentionNorm != nil {
hiddenState = l.PostAttentionNorm.Forward(ctx, hiddenState, opts.eps)
}
ffnInput := hiddenState.Add(ctx, residual)
hiddenState = l.MLP.Forward(ctx, ffnInput, opts)
if l.PostFFWNorm != nil {
hiddenState = l.PostFFWNorm.Forward(ctx, hiddenState, opts.eps)
}
return hiddenState.Add(ctx, ffnInput)
}
func isSWALayer(layerIdx int) bool {
return (layerIdx+1)%4 != 0
}
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
positions := ctx.Input().FromInts(batch.Positions, len(batch.Positions))
hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs)
for i, layer := range m.Layers {
m.Cache.SetLayer(i)
isSWA := isSWALayer(i)
var outputs ml.Tensor
if i == len(m.Layers)-1 {
outputs = batch.Outputs
}
hiddenState = layer.Forward(ctx, hiddenState, positions, outputs, m.Cache, &m.Options, isSWA)
}
hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
return m.Output.Forward(ctx, hiddenState), nil
}
func init() {
model.Register("olmo2", New)
}