From 5c3bf414ef9baeaa9e081c20062c1861834c3a70 Mon Sep 17 00:00:00 2001
From: ParthSareen
Date: Mon, 8 Dec 2025 18:17:56 -0800
Subject: [PATCH] close to working

---
 cmd/testolmo/main.go       |  6 ++--
 fs/ggml/ggml.go            |  1 +
 model/models/olmo/model.go | 60 ++++++++++++++++++++++++++++++--------
 model/renderers/olmo3.go   |  2 +-
 4 files changed, 53 insertions(+), 16 deletions(-)

diff --git a/cmd/testolmo/main.go b/cmd/testolmo/main.go
index 67f32c4fb..842d66394 100644
--- a/cmd/testolmo/main.go
+++ b/cmd/testolmo/main.go
@@ -43,7 +43,7 @@ func main() {
 
 	// Use the olmo3 renderer to format the prompt properly
 	messages := []api.Message{
-		{Role: "user", Content: "What is machine learning?"},
+		{Role: "user", Content: "wagwan"},
 	}
 	// prompt := "Question: What is machine learning? Answer:"
 	prompt, err := renderers.RenderWithRenderer("olmo3", messages, nil, nil)
@@ -55,7 +55,7 @@ func main() {
 	fmt.Printf("\nRendered prompt:\n%s\n", prompt)
 
 	tp := m.(model.TextProcessor)
-	tokens, err := tp.Encode(prompt, true)
+	tokens, err := tp.Encode(prompt, false)
 	if err != nil {
 		log.Fatal(err)
 	}
@@ -67,7 +67,7 @@ func main() {
 	generated := make([]int32, 0, maxTokens)
 
 	// Create sampler (temperature=0 for greedy sampling)
-	sampler := sample.NewSampler(0.6, 0, 0, 0, -1, nil)
+	sampler := sample.NewSampler(0, 0, 0, 0, -1, nil)
 
 	for i := 0; i < maxTokens; i++ {
 		// Create a new context for each generation step to avoid memory buildup
diff --git a/fs/ggml/ggml.go b/fs/ggml/ggml.go
index 6ce9724f2..0f4cfeced 100644
--- a/fs/ggml/ggml.go
+++ b/fs/ggml/ggml.go
@@ -252,6 +252,7 @@ func (kv KV) OllamaEngineRequired() bool {
 		"deepseekocr",
 		"deepseek2",
 		"nomic-bert",
+		"olmo2",
 	}, kv.Architecture())
 }
 
diff --git a/model/models/olmo/model.go b/model/models/olmo/model.go
index 231cbf7f8..6fb205d3f 100644
--- a/model/models/olmo/model.go
+++ b/model/models/olmo/model.go
@@ -2,7 +2,6 @@ package olmo
 
 import (
 	"cmp"
-	"fmt"
 	"math"
 
 	"github.com/ollama/ollama/fs"
@@ -23,6 +22,8 @@ type Options struct {
 	originalContextLength int
 	attnFactor            float32
+	slidingWindow         int32
+	slidingWindowPattern  []bool // per-layer SWA pattern (true = SWA, false = full attention)
 }
 
 type Model struct {
@@ -38,7 +39,6 @@
 }
 
 func New(c fs.Config) (model.Model, error) {
-	fmt.Println("🦙 OLMo model loaded!")
 	vocabulary := model.Vocabulary{
 		Values: c.Strings("tokenizer.ggml.tokens"),
 		Scores: c.Floats("tokenizer.ggml.scores"),
@@ -65,6 +65,9 @@
 	}
 	processor := model.NewBytePairEncoding(&vocabulary, pretokenizers...)
 
+	slidingWindow := int32(c.Uint("attention.sliding_window"))
+	slidingWindowPattern := c.Bools("attention.sliding_window_pattern")
+
 	m := Model{
 		TextProcessor: processor,
 		Layers:        make([]Layer, c.Uint("block_count")),
@@ -80,10 +83,16 @@
 			clampKQV:              c.Float("attention.clamp_kqv", 0),
 			originalContextLength: int(c.Uint("rope.scaling.original_context_length")),
 			attnFactor:            c.Float("rope.scaling.attn_factor", 1),
+			slidingWindow:         slidingWindow,
+			slidingWindowPattern:  slidingWindowPattern,
 		},
 	}
 
-	m.Cache = kvcache.NewCausalCache(m.Shift)
+	// OLMo3 uses interleaved sliding window attention (every 4th layer is full attention)
+	m.Cache = kvcache.NewWrapperCache(
+		kvcache.NewSWACache(slidingWindow, m.Shift),
+		kvcache.NewCausalCache(m.Shift),
+	)
 
 	return &m, nil
 }
@@ -103,12 +112,23 @@ func (o *Options) ropeOptions(factors ml.Tensor, isSWA bool) []func(*rope.Option
 		rope.WithFactors(factors),
 	}
 
-	if !isSWA && o.originalContextLength > 0 {
-		opts = append(opts,
-			rope.WithOriginalContextLength(o.originalContextLength),
-			rope.WithExtrapolationFactor(1.),
-			rope.WithAttentionFactor(o.attnFactor),
-		)
+	if o.originalContextLength > 0 {
+		if isSWA {
+			// For SWA layers, use regular rope with no YaRN scaling
+			// ext_factor=0.0, attn_factor=1.0 per llama.cpp
+			opts = append(opts,
+				rope.WithOriginalContextLength(o.originalContextLength),
+				rope.WithExtrapolationFactor(0.),
+				rope.WithAttentionFactor(1.),
+			)
+		} else {
+			// For full attention layers, use YaRN scaling
+			opts = append(opts,
+				rope.WithOriginalContextLength(o.originalContextLength),
+				rope.WithExtrapolationFactor(1.),
+				rope.WithAttentionFactor(o.attnFactor),
+			)
+		}
 	}
 
 	return opts
@@ -150,7 +170,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positions ml.Tenso
 
 func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
 	ropeDim := cmp.Or(m.ropeDim, m.hiddenSize/m.numHeads)
-	isSWA := isSWALayer(layer)
+	isSWA := m.isSWALayer(layer)
 
 	freqScale := float32(1.0)
 	if !isSWA {
@@ -204,7 +224,14 @@ func (l *Layer) Forward(ctx ml.Context, hiddenState, positions, outputs ml.Tenso
 	return hiddenState.Add(ctx, ffnInput)
 }
 
-func isSWALayer(layerIdx int) bool {
+// isSWALayer returns true if the layer uses sliding window attention.
+// Uses the sliding_window_pattern from the model config if available,
+// otherwise falls back to the default OLMo3 pattern (every 4th layer is full attention).
+func (m *Model) isSWALayer(layerIdx int) bool {
+	if len(m.slidingWindowPattern) > layerIdx {
+		return m.slidingWindowPattern[layerIdx]
+	}
+	// Fallback: OLMo3 pattern where every 4th layer (indices 3, 7, 11, ...) uses full attention
 	return (layerIdx+1)%4 != 0
 }
 
@@ -216,7 +243,16 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
 
 	for i, layer := range m.Layers {
 		m.Cache.SetLayer(i)
-		isSWA := isSWALayer(i)
+		isSWA := m.isSWALayer(i)
+
+		// Set cache type for interleaved SWA (OLMo3)
+		if wc, ok := m.Cache.(*kvcache.WrapperCache); ok {
+			if isSWA {
+				wc.SetLayerType(0) // SWA cache
+			} else {
+				wc.SetLayerType(1) // Causal cache
+			}
+		}
 
 		var outputs ml.Tensor
 		if i == len(m.Layers)-1 {
diff --git a/model/renderers/olmo3.go b/model/renderers/olmo3.go
index 24ade20dc..6b0e320dd 100644
--- a/model/renderers/olmo3.go
+++ b/model/renderers/olmo3.go
@@ -140,7 +140,7 @@ func (r *Olmo3Renderer) Render(messages []api.Message, tools []api.Tool, _ *api.
 	}
 
 	if needsGenerationPrompt {
-		sb.WriteString("<|im_start|>assistant\n\n")
+		sb.WriteString("<|im_start|>assistant\n")
 	}
 
 	return sb.String(), nil
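
Note for reviewers: the interleaved-attention pattern this patch encodes can be seen in isolation with a short sketch. The helper below mirrors the new isSWALayer method (pattern slice first, then the every-4th-layer fallback) as a standalone program; it is illustrative only, not part of the patch, and the layer count of 8 is arbitrary.

package main

import "fmt"

// isSWALayer mirrors the method added in model/models/olmo/model.go:
// prefer the per-layer pattern from the config (true = SWA, false = full
// attention); otherwise fall back to the OLMo3 default, where every 4th
// layer (indices 3, 7, 11, ...) uses full attention.
func isSWALayer(layerIdx int, pattern []bool) bool {
	if len(pattern) > layerIdx {
		return pattern[layerIdx]
	}
	return (layerIdx+1)%4 != 0
}

func main() {
	for i := 0; i < 8; i++ {
		kind := "SWA (layer type 0)"
		if !isSWALayer(i, nil) {
			kind = "full attention (layer type 1)"
		}
		fmt.Printf("layer %d: %s\n", i, kind)
	}
}

The printed layer types line up with the WrapperCache order set up in New: index 0 selects the SWA cache, index 1 the causal cache.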