close to working

ParthSareen 2025-12-08 18:17:56 -08:00
parent 0a9862a383
commit 5c3bf414ef
4 changed files with 53 additions and 16 deletions

View File

@@ -43,7 +43,7 @@ func main() {
// Use the olmo3 renderer to format the prompt properly
messages := []api.Message{
- {Role: "user", Content: "What is machine learning?"},
+ {Role: "user", Content: "wagwan"},
}
// prompt := "Question: What is machine learning? Answer:"
prompt, err := renderers.RenderWithRenderer("olmo3", messages, nil, nil)
@@ -55,7 +55,7 @@ func main() {
fmt.Printf("\nRendered prompt:\n%s\n", prompt)
tp := m.(model.TextProcessor)
- tokens, err := tp.Encode(prompt, true)
+ tokens, err := tp.Encode(prompt, false)
if err != nil {
log.Fatal(err)
}
@@ -67,7 +67,7 @@ func main() {
generated := make([]int32, 0, maxTokens)
// Create sampler (temperature=0 for greedy sampling)
- sampler := sample.NewSampler(0.6, 0, 0, 0, -1, nil)
+ sampler := sample.NewSampler(0, 0, 0, 0, -1, nil)
for i := 0; i < maxTokens; i++ {
// Create a new context for each generation step to avoid memory buildup

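The sampler change above brings the code in line with its own comment: with temperature 0 the sampler is greedy, i.e. it always picks the highest-logit token. A minimal standalone sketch of what that reduces to, independent of the sample package (the function and values below are illustrative, not part of this commit):

package main

import "fmt"

// greedy is what temperature-0 sampling reduces to: take the argmax
// over the logits instead of drawing from a softmax distribution.
func greedy(logits []float32) int {
	best := 0
	for i, v := range logits {
		if v > logits[best] {
			best = i
		}
	}
	return best
}

func main() {
	logits := []float32{0.1, 2.3, -0.5, 1.7} // illustrative values
	fmt.Println("greedy pick:", greedy(logits)) // prints 1
}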
View File

@@ -252,6 +252,7 @@ func (kv KV) OllamaEngineRequired() bool {
"deepseekocr",
"deepseek2",
"nomic-bert",
+ "olmo2",
}, kv.Architecture())
}

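This hunk only shows the tail of the architecture list plus the closing "}, kv.Architecture())", which suggests OllamaEngineRequired checks the model's architecture against that list, presumably via something like slices.Contains. A sketch under that assumption (the surrounding function body and the rest of the list are not visible in the diff):

package main

import (
	"fmt"
	"slices"
)

// Standalone stand-in for the method in the diff; the slices.Contains
// wrapper and the trimmed list (only entries visible in the hunk) are
// assumptions based on the "}, kv.Architecture())" fragment.
func ollamaEngineRequired(architecture string) bool {
	return slices.Contains([]string{
		"deepseekocr",
		"deepseek2",
		"nomic-bert",
		"olmo2",
	}, architecture)
}

func main() {
	fmt.Println(ollamaEngineRequired("olmo2")) // true after this commit
	fmt.Println(ollamaEngineRequired("llama")) // false
}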
View File

@@ -2,7 +2,6 @@ package olmo
import (
"cmp"
- "fmt"
"math"
"github.com/ollama/ollama/fs"
@@ -23,6 +22,8 @@ type Options struct {
originalContextLength int
attnFactor float32
slidingWindow int32
+ slidingWindowPattern []bool // per-layer SWA pattern (true = SWA, false = full attention)
}
type Model struct {
@@ -38,7 +39,6 @@ type Model struct {
}
func New(c fs.Config) (model.Model, error) {
- fmt.Println("🦙 OLMo model loaded!")
vocabulary := model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
@@ -65,6 +65,9 @@ func New(c fs.Config) (model.Model, error) {
}
processor := model.NewBytePairEncoding(&vocabulary, pretokenizers...)
slidingWindow := int32(c.Uint("attention.sliding_window"))
+ slidingWindowPattern := c.Bools("attention.sliding_window_pattern")
m := Model{
TextProcessor: processor,
Layers: make([]Layer, c.Uint("block_count")),
@@ -80,10 +83,16 @@ func New(c fs.Config) (model.Model, error) {
clampKQV: c.Float("attention.clamp_kqv", 0),
originalContextLength: int(c.Uint("rope.scaling.original_context_length")),
attnFactor: c.Float("rope.scaling.attn_factor", 1),
slidingWindow: slidingWindow,
+ slidingWindowPattern: slidingWindowPattern,
},
}
- m.Cache = kvcache.NewCausalCache(m.Shift)
+ // OLMo3 uses interleaved sliding window attention (every 4th layer is full attention)
+ m.Cache = kvcache.NewWrapperCache(
+ kvcache.NewSWACache(slidingWindow, m.Shift),
+ kvcache.NewCausalCache(m.Shift),
+ )
return &m, nil
}
@@ -103,12 +112,23 @@ func (o *Options) ropeOptions(factors ml.Tensor, isSWA bool) []func(*rope.Option
rope.WithFactors(factors),
}
- if !isSWA && o.originalContextLength > 0 {
- opts = append(opts,
- rope.WithOriginalContextLength(o.originalContextLength),
- rope.WithExtrapolationFactor(1.),
- rope.WithAttentionFactor(o.attnFactor),
- )
+ if o.originalContextLength > 0 {
+ if isSWA {
+ // For SWA layers, use regular rope with no YaRN scaling
+ // ext_factor=0.0, attn_factor=1.0 per llama.cpp
+ opts = append(opts,
+ rope.WithOriginalContextLength(o.originalContextLength),
+ rope.WithExtrapolationFactor(0.),
+ rope.WithAttentionFactor(1.),
+ )
+ } else {
+ // For full attention layers, use YaRN scaling
+ opts = append(opts,
+ rope.WithOriginalContextLength(o.originalContextLength),
+ rope.WithExtrapolationFactor(1.),
+ rope.WithAttentionFactor(o.attnFactor),
+ )
+ }
}
return opts
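The net effect of the restructured ropeOptions is a per-layer-type choice of YaRN factors: sliding-window layers keep plain RoPE (extrapolation factor 0, attention factor 1), while full-attention layers get the configured YaRN scaling. A small sketch of just that decision, separated from the rope-options plumbing (the function name and sample values are illustrative):

package main

import "fmt"

// yarnFactors mirrors the branch added in ropeOptions: sliding-window
// layers get no YaRN scaling, full-attention layers use the configured
// attention factor.
func yarnFactors(isSWA bool, attnFactor float32) (ext, attn float32) {
	if isSWA {
		return 0, 1
	}
	return 1, attnFactor
}

func main() {
	ext, attn := yarnFactors(true, 1.2)
	fmt.Println("SWA layer:", ext, attn) // 0 1
	ext, attn = yarnFactors(false, 1.2)
	fmt.Println("full-attention layer:", ext, attn) // 1 1.2
}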
@@ -150,7 +170,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positions ml.Tenso
func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
ropeDim := cmp.Or(m.ropeDim, m.hiddenSize/m.numHeads)
- isSWA := isSWALayer(layer)
+ isSWA := m.isSWALayer(layer)
freqScale := float32(1.0)
if !isSWA {
@@ -204,7 +224,14 @@ func (l *Layer) Forward(ctx ml.Context, hiddenState, positions, outputs ml.Tenso
return hiddenState.Add(ctx, ffnInput)
}
- func isSWALayer(layerIdx int) bool {
+ // isSWALayer returns true if the layer uses sliding window attention.
+ // Uses the sliding_window_pattern from the model config if available,
+ // otherwise falls back to the default OLMo3 pattern (every 4th layer is full attention).
+ func (m *Model) isSWALayer(layerIdx int) bool {
+ if len(m.slidingWindowPattern) > layerIdx {
+ return m.slidingWindowPattern[layerIdx]
+ }
+ // Fallback: OLMo3 pattern where every 4th layer (indices 3, 7, 11, ...) uses full attention
return (layerIdx+1)%4 != 0
}
@@ -216,7 +243,16 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
for i, layer := range m.Layers {
m.Cache.SetLayer(i)
- isSWA := isSWALayer(i)
+ isSWA := m.isSWALayer(i)
+ // Set cache type for interleaved SWA (OLMo3)
+ if wc, ok := m.Cache.(*kvcache.WrapperCache); ok {
+ if isSWA {
+ wc.SetLayerType(0) // SWA cache
+ } else {
+ wc.SetLayerType(1) // Causal cache
+ }
+ }
var outputs ml.Tensor
if i == len(m.Layers)-1 {

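Taken together, the wrapper cache wired up in New and the isSWALayer check in Forward route every layer to one of two KV caches: index 0 (the SWA cache) for sliding-window layers and index 1 (the causal cache) for full-attention layers. A standalone sketch of the selection logic, with the method rewritten as a free function purely for illustration:

package main

import "fmt"

// isSWALayer restates the logic from the diff: prefer the per-layer
// sliding_window_pattern when the config provides one, otherwise fall
// back to the OLMo3 default where every 4th layer (indices 3, 7, 11, ...)
// uses full attention.
func isSWALayer(pattern []bool, layerIdx int) bool {
	if len(pattern) > layerIdx {
		return pattern[layerIdx]
	}
	return (layerIdx+1)%4 != 0
}

func main() {
	// With no pattern in the config the first eight layers resolve to:
	// SWA SWA SWA full SWA SWA SWA full
	for i := 0; i < 8; i++ {
		cacheIdx := 1 // causal cache (full attention)
		if isSWALayer(nil, i) {
			cacheIdx = 0 // SWA cache
		}
		fmt.Printf("layer %d -> wrapper cache %d\n", i, cacheIdx)
	}
}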
View File

@@ -140,7 +140,7 @@ func (r *Olmo3Renderer) Render(messages []api.Message, tools []api.Tool, _ *api.
}
if needsGenerationPrompt {
- sb.WriteString("<|im_start|>assistant\n\n")
+ sb.WriteString("<|im_start|>assistant\n")
}
return sb.String(), nil
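The renderer change drops the second newline after the assistant header, so the rendered prompt now ends in "<|im_start|>assistant" followed by a single newline. A minimal sketch of the resulting suffix (the user-turn framing is illustrative; the real Olmo3Renderer also handles tools and other message handling not shown here):

package main

import (
	"fmt"
	"strings"
)

func main() {
	var sb strings.Builder
	// Illustrative single user turn; only the generation-prompt suffix
	// below is taken from the diff.
	sb.WriteString("<|im_start|>user\nwagwan<|im_end|>\n")
	// After this commit: one trailing newline instead of two.
	sb.WriteString("<|im_start|>assistant\n")
	fmt.Printf("%q\n", sb.String())
}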