close to working
This commit is contained in:
parent
0a9862a383
commit
5c3bf414ef
|
|
@ -43,7 +43,7 @@ func main() {
|
|||
|
||||
// Use the olmo3 renderer to format the prompt properly
|
||||
messages := []api.Message{
|
||||
{Role: "user", Content: "What is machine learning?"},
|
||||
{Role: "user", Content: "wagwan"},
|
||||
}
|
||||
// prompt := "Question: What is machine learning? Answer:"
|
||||
prompt, err := renderers.RenderWithRenderer("olmo3", messages, nil, nil)
|
||||
|
|
@ -55,7 +55,7 @@ func main() {
|
|||
fmt.Printf("\nRendered prompt:\n%s\n", prompt)
|
||||
|
||||
tp := m.(model.TextProcessor)
|
||||
tokens, err := tp.Encode(prompt, true)
|
||||
tokens, err := tp.Encode(prompt, false)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
|
@ -67,7 +67,7 @@ func main() {
|
|||
generated := make([]int32, 0, maxTokens)
|
||||
|
||||
// Create sampler (temperature=0 for greedy sampling)
|
||||
sampler := sample.NewSampler(0.6, 0, 0, 0, -1, nil)
|
||||
sampler := sample.NewSampler(0, 0, 0, 0, -1, nil)
|
||||
|
||||
for i := 0; i < maxTokens; i++ {
|
||||
// Create a new context for each generation step to avoid memory buildup
|
||||
|
|
|
|||
|
|
@ -252,6 +252,7 @@ func (kv KV) OllamaEngineRequired() bool {
|
|||
"deepseekocr",
|
||||
"deepseek2",
|
||||
"nomic-bert",
|
||||
"olmo2",
|
||||
}, kv.Architecture())
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -2,7 +2,6 @@ package olmo
|
|||
|
||||
import (
|
||||
"cmp"
|
||||
"fmt"
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/fs"
|
||||
|
|
@ -23,6 +22,8 @@ type Options struct {
|
|||
|
||||
originalContextLength int
|
||||
attnFactor float32
|
||||
slidingWindow int32
|
||||
slidingWindowPattern []bool // per-layer SWA pattern (true = SWA, false = full attention)
|
||||
}
|
||||
|
||||
type Model struct {
|
||||
|
|
@ -38,7 +39,6 @@ type Model struct {
|
|||
}
|
||||
|
||||
func New(c fs.Config) (model.Model, error) {
|
||||
fmt.Println("🦙 OLMo model loaded!")
|
||||
vocabulary := model.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Scores: c.Floats("tokenizer.ggml.scores"),
|
||||
|
|
@ -65,6 +65,9 @@ func New(c fs.Config) (model.Model, error) {
|
|||
}
|
||||
processor := model.NewBytePairEncoding(&vocabulary, pretokenizers...)
|
||||
|
||||
slidingWindow := int32(c.Uint("attention.sliding_window"))
|
||||
slidingWindowPattern := c.Bools("attention.sliding_window_pattern")
|
||||
|
||||
m := Model{
|
||||
TextProcessor: processor,
|
||||
Layers: make([]Layer, c.Uint("block_count")),
|
||||
|
|
@ -80,10 +83,16 @@ func New(c fs.Config) (model.Model, error) {
|
|||
clampKQV: c.Float("attention.clamp_kqv", 0),
|
||||
originalContextLength: int(c.Uint("rope.scaling.original_context_length")),
|
||||
attnFactor: c.Float("rope.scaling.attn_factor", 1),
|
||||
slidingWindow: slidingWindow,
|
||||
slidingWindowPattern: slidingWindowPattern,
|
||||
},
|
||||
}
|
||||
|
||||
m.Cache = kvcache.NewCausalCache(m.Shift)
|
||||
// OLMo3 uses interleaved sliding window attention (every 4th layer is full attention)
|
||||
m.Cache = kvcache.NewWrapperCache(
|
||||
kvcache.NewSWACache(slidingWindow, m.Shift),
|
||||
kvcache.NewCausalCache(m.Shift),
|
||||
)
|
||||
|
||||
return &m, nil
|
||||
}
|
||||
|
|
@ -103,12 +112,23 @@ func (o *Options) ropeOptions(factors ml.Tensor, isSWA bool) []func(*rope.Option
|
|||
rope.WithFactors(factors),
|
||||
}
|
||||
|
||||
if !isSWA && o.originalContextLength > 0 {
|
||||
opts = append(opts,
|
||||
rope.WithOriginalContextLength(o.originalContextLength),
|
||||
rope.WithExtrapolationFactor(1.),
|
||||
rope.WithAttentionFactor(o.attnFactor),
|
||||
)
|
||||
if o.originalContextLength > 0 {
|
||||
if isSWA {
|
||||
// For SWA layers, use regular rope with no YaRN scaling
|
||||
// ext_factor=0.0, attn_factor=1.0 per llama.cpp
|
||||
opts = append(opts,
|
||||
rope.WithOriginalContextLength(o.originalContextLength),
|
||||
rope.WithExtrapolationFactor(0.),
|
||||
rope.WithAttentionFactor(1.),
|
||||
)
|
||||
} else {
|
||||
// For full attention layers, use YaRN scaling
|
||||
opts = append(opts,
|
||||
rope.WithOriginalContextLength(o.originalContextLength),
|
||||
rope.WithExtrapolationFactor(1.),
|
||||
rope.WithAttentionFactor(o.attnFactor),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
return opts
|
||||
|
|
@ -150,7 +170,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positions ml.Tenso
|
|||
|
||||
func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
|
||||
ropeDim := cmp.Or(m.ropeDim, m.hiddenSize/m.numHeads)
|
||||
isSWA := isSWALayer(layer)
|
||||
isSWA := m.isSWALayer(layer)
|
||||
|
||||
freqScale := float32(1.0)
|
||||
if !isSWA {
|
||||
|
|
@ -204,7 +224,14 @@ func (l *Layer) Forward(ctx ml.Context, hiddenState, positions, outputs ml.Tenso
|
|||
return hiddenState.Add(ctx, ffnInput)
|
||||
}
|
||||
|
||||
func isSWALayer(layerIdx int) bool {
|
||||
// isSWALayer returns true if the layer uses sliding window attention.
|
||||
// Uses the sliding_window_pattern from the model config if available,
|
||||
// otherwise falls back to the default OLMo3 pattern (every 4th layer is full attention).
|
||||
func (m *Model) isSWALayer(layerIdx int) bool {
|
||||
if len(m.slidingWindowPattern) > layerIdx {
|
||||
return m.slidingWindowPattern[layerIdx]
|
||||
}
|
||||
// Fallback: OLMo3 pattern where every 4th layer (indices 3, 7, 11, ...) uses full attention
|
||||
return (layerIdx+1)%4 != 0
|
||||
}
|
||||
|
||||
|
|
@ -216,7 +243,16 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
|||
for i, layer := range m.Layers {
|
||||
m.Cache.SetLayer(i)
|
||||
|
||||
isSWA := isSWALayer(i)
|
||||
isSWA := m.isSWALayer(i)
|
||||
|
||||
// Set cache type for interleaved SWA (OLMo3)
|
||||
if wc, ok := m.Cache.(*kvcache.WrapperCache); ok {
|
||||
if isSWA {
|
||||
wc.SetLayerType(0) // SWA cache
|
||||
} else {
|
||||
wc.SetLayerType(1) // Causal cache
|
||||
}
|
||||
}
|
||||
|
||||
var outputs ml.Tensor
|
||||
if i == len(m.Layers)-1 {
|
||||
|
|
|
|||
|
|
@ -140,7 +140,7 @@ func (r *Olmo3Renderer) Render(messages []api.Message, tools []api.Tool, _ *api.
|
|||
}
|
||||
|
||||
if needsGenerationPrompt {
|
||||
sb.WriteString("<|im_start|>assistant\n\n")
|
||||
sb.WriteString("<|im_start|>assistant\n")
|
||||
}
|
||||
|
||||
return sb.String(), nil
|
||||
|
|
|
|||
Loading…
Reference in New Issue