From a613eca69c74543319b840d137e067d498581436 Mon Sep 17 00:00:00 2001
From: ParthSareen
Date: Tue, 9 Dec 2025 22:26:30 -0800
Subject: [PATCH] improvements

---
 model/models/olmo/model.go     | 57 +++++++++++++++++++++-------------
 model/models/olmo/olmo_test.go |  2 +-
 2 files changed, 36 insertions(+), 23 deletions(-)

diff --git a/model/models/olmo/model.go b/model/models/olmo/model.go
index 6b95933ad..e57c87a64 100644
--- a/model/models/olmo/model.go
+++ b/model/models/olmo/model.go
@@ -29,6 +29,9 @@ type Options struct {
 	ropeType          string
 	ropeExtrapolation float32
 
+	ropeBetaFast float32
+	ropeBetaSlow float32
+
 	slidingWindowPattern []bool
 }
 
@@ -66,7 +69,7 @@ func New(c fs.Config) (model.Model, error) {
 	var pretokenizers []string
 	if c.String("tokenizer.ggml.pre") != "default" {
 		pretokenizers = []string{
-			`(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+`,
+			"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
 		}
 	}
 	processor := model.NewBytePairEncoding(&vocabulary, pretokenizers...)
@@ -74,8 +77,6 @@
 	hiddenSize := int(c.Uint("embedding_length"))
 	numHeads := int(c.Uint("attention.head_count"))
 	numKVHeads := int(c.Uint("attention.head_count_kv"))
-	// headDim := int(c.Uint("attention.head_count"))
-	// ropeDim := int(c.Uint("rope.dimension_count"))
 	eps := c.Float("attention.layer_norm_rms_epsilon")
 	ropeBase := c.Float("rope.freq_base", 1e4)
 	ropeScale := c.Float("rope.scaling.factor", 1)
@@ -87,8 +88,6 @@
 	fmt.Printf("hiddenSize: %d\n", hiddenSize)
 	fmt.Printf("numHeads: %d\n", numHeads)
 	fmt.Printf("numKVHeads: %d\n", numKVHeads)
-	// fmt.Printf("headDim: %d\n", headDim)
-	// fmt.Printf("ropeDim: %d\n", ropeDim)
 	fmt.Printf("eps: %f\n", eps)
 	fmt.Printf("ropeBase: %f\n", ropeBase)
 	fmt.Printf("ropeScale: %f\n", ropeScale)
@@ -102,11 +101,9 @@
 		TextProcessor: processor,
 		Layers:        make([]Layer, c.Uint("block_count")),
 		Options: Options{
-			hiddenSize: hiddenSize,
-			numHeads:   numHeads,
-			numKVHeads: numKVHeads,
-			// headDim: headDim,
-			// ropeDim: ropeDim,
+			hiddenSize:        hiddenSize,
+			numHeads:          numHeads,
+			numKVHeads:        numKVHeads,
 			eps:               eps,
 			ropeBase:          ropeBase,
 			ropeScale:         ropeScale,
@@ -114,6 +111,8 @@
 			attnFactor:        attnFactor,
 			ropeType:          ropeType,
 			ropeExtrapolation: ropeExtrapolation,
+			ropeBetaFast:      32.0,
+			ropeBetaSlow:      1.0,
 			slidingWindowPattern: c.Bools("attention.sliding_window_pattern"),
 		},
 	}
@@ -141,19 +140,31 @@ func (m *Model) applyRoPE(ctx ml.Context, states, positions ml.Tensor, ropeDim i
 
 	var ropeOpts []func(*rope.Options)
 
 	// Both SWA and non-SWA use beta_fast and beta_slow
-	// But SWA uses freq_scale=1.0, ext_factor=0.0, attn_factor=1.0
+	// defaults
+	ropeOpts = append(ropeOpts,
+		rope.WithBetaFast(m.ropeBetaFast),
+		rope.WithBetaSlow(m.ropeBetaSlow),
+	)
+
+	// SWA uses freq_scale=1.0, ext_factor=0.0, attn_factor=1.0
 	// Non-SWA uses full yarn parameters
 	if m.originalContextLength > 0 {
 		ropeOpts = append(ropeOpts,
 			rope.WithOriginalContextLength(m.originalContextLength),
 		)
-		// if !isSWA {
-		ropeOpts = append(ropeOpts,
-			rope.WithExtrapolationFactor(m.ropeExtrapolation),
-		)
-		// rope.WithAttentionFactor(m.attnFactor),
-		// )
+		// no yarn for swa
+		if isSWA {
+			ropeOpts = append(ropeOpts,
+				rope.WithExtrapolationFactor(0),
+				rope.WithAttentionFactor(1.),
+			)
+		} else {
+			ropeOpts = append(ropeOpts,
+				rope.WithExtrapolationFactor(m.ropeExtrapolation),
+				rope.WithAttentionFactor(m.attnFactor),
+			)
+		}
 	}
 
 	freqScale := float32(1.0)
@@ -221,12 +232,13 @@ func (l *Layer) Forward(ctx ml.Context, hiddenState, positions, outputs ml.Tenso
 	residual := hiddenState
 
 	hiddenState = l.SelfAttention.Forward(ctx, hiddenState, positions, cache, m, isSWA)
-	hiddenState = l.PostAttentionNorm.Forward(ctx, hiddenState, m.eps)
 
 	if outputs != nil {
 		hiddenState = hiddenState.Rows(ctx, outputs)
 		residual = residual.Rows(ctx, outputs)
 	}
 
+	// TODO: I think this should happen after selecting the output rows
+	hiddenState = l.PostAttentionNorm.Forward(ctx, hiddenState, m.eps)
 	hiddenState = hiddenState.Add(ctx, residual)
 	residual = hiddenState
@@ -256,11 +268,12 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
 			cacheType = cacheTypeCausal
 		}
 
-		cache := m.Cache.(*kvcache.WrapperCache)
-		cache.SetLayerType(cacheType)
+		wc := m.Cache.(*kvcache.WrapperCache)
+		wc.SetLayerType(cacheType)
 		// would need to check the cache at the layer instead
-		if causal, ok := cache.UnderlyingCache().(*kvcache.Causal); ok {
-			causal.SetCausal(ctx, kvcache.CausalOptions{Except: []int{i}})
+		if causal, ok := wc.UnderlyingCache().(*kvcache.Causal); ok {
+			// TODO: not sure about the index here
+			causal.SetCausal(ctx, kvcache.CausalOptions{Except: []int{}})
 		}
 
 		var outputs ml.Tensor
diff --git a/model/models/olmo/olmo_test.go b/model/models/olmo/olmo_test.go
index 2d527e761..d2d6f85c9 100644
--- a/model/models/olmo/olmo_test.go
+++ b/model/models/olmo/olmo_test.go
@@ -107,7 +107,7 @@ func TestTokenization(t *testing.T) {
 
 			prompt := args.prompt
 			if prompt == "" {
-				prompt = "hello"
+				prompt = "hello, how are you?"
 			}
 
 			tp := m.(model.TextProcessor)
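
A note on the pretokenizer hunk (@@ -66,7 +69,7 @@): moving the pattern from a raw string literal (backticks) to an interpreted string literal halves the backslashes the regexp engine sees, so a sequence like \\p{L} becomes the Unicode letter class \p{L} instead of a literal backslash followed by the text "p{L}". A minimal, standard-library-only sketch of the difference; the pattern here is illustrative, not the full pretokenizer regex:

package main

import (
	"fmt"
	"regexp"
)

func main() {
	// Raw string literal: both backslashes survive, so the engine sees an
	// escaped backslash followed by the literal text "p{L}".
	raw := regexp.MustCompile(`\\p{L}`)

	// Interpreted string literal: "\\p{L}" collapses to \p{L}, the Unicode
	// letter class the pretokenizer pattern intends.
	interp := regexp.MustCompile("\\p{L}")

	fmt.Println(raw.MatchString("héllo"))    // false: no literal backslash in the input
	fmt.Println(interp.MatchString("héllo")) // true: matches any Unicode letter
}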
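
The applyRoPE hunk accumulates a slice of functional options and branches per layer type: beta_fast/beta_slow always apply, while YaRN extrapolation and attention factors are set only for non-SWA layers. A self-contained sketch of that shape, using a stand-in options struct rather than the real rope.Options and rope.With* helpers from the ollama tree:

package main

import "fmt"

// options stands in for rope.Options; the fields mirror what the patch toggles.
type options struct {
	betaFast, betaSlow             float32
	extrapolation, attentionFactor float32
}

// Functional options, the same shape as the rope.With* helpers in the diff.
func withBetaFast(v float32) func(*options)        { return func(o *options) { o.betaFast = v } }
func withBetaSlow(v float32) func(*options)        { return func(o *options) { o.betaSlow = v } }
func withExtrapolation(v float32) func(*options)   { return func(o *options) { o.extrapolation = v } }
func withAttentionFactor(v float32) func(*options) { return func(o *options) { o.attentionFactor = v } }

func buildOpts(isSWA bool, ext, attn float32) options {
	// Defaults first: beta_fast/beta_slow apply to every layer.
	fns := []func(*options){withBetaFast(32), withBetaSlow(1)}
	if isSWA {
		// Sliding-window layers switch YaRN off: ext_factor=0, attn_factor=1.
		fns = append(fns, withExtrapolation(0), withAttentionFactor(1))
	} else {
		fns = append(fns, withExtrapolation(ext), withAttentionFactor(attn))
	}
	var o options
	for _, apply := range fns {
		apply(&o)
	}
	return o
}

func main() {
	fmt.Printf("SWA:    %+v\n", buildOpts(true, 0.5, 1.2))
	fmt.Printf("global: %+v\n", buildOpts(false, 0.5, 1.2))
}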
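
On the reordering in the Layer.Forward hunk (@@ -221,12 +232,13 @@): a per-token RMS norm operates on each row independently, so it commutes with selecting output rows, and applying PostAttentionNorm after Rows yields the same values while skipping work for tokens that are never returned. A toy demonstration of that equivalence on plain slices, not the ml.Tensor API:

package main

import (
	"fmt"
	"math"
)

// rmsNorm normalizes each row independently, like a per-token RMS norm.
func rmsNorm(rows [][]float32, eps float32) [][]float32 {
	out := make([][]float32, len(rows))
	for i, r := range rows {
		var ss float64
		for _, v := range r {
			ss += float64(v) * float64(v)
		}
		scale := float32(1 / math.Sqrt(ss/float64(len(r))+float64(eps)))
		out[i] = make([]float32, len(r))
		for j, v := range r {
			out[i][j] = v * scale
		}
	}
	return out
}

// selectRows keeps only the rows named by idx, like Rows(ctx, outputs).
func selectRows(rows [][]float32, idx []int) [][]float32 {
	out := make([][]float32, len(idx))
	for i, k := range idx {
		out[i] = rows[k]
	}
	return out
}

func main() {
	x := [][]float32{{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}
	idx := []int{2} // keep only the last token's row

	a := rmsNorm(selectRows(x, idx), 1e-6) // norm after selection
	b := selectRows(rmsNorm(x, 1e-6), idx) // norm before selection
	fmt.Println(a, b)                      // identical; selecting first just avoids wasted work
}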