From b323cfe731ac401b5df178a5fdd2ae852a0f7056 Mon Sep 17 00:00:00 2001
From: Michael Yang
Date: Tue, 19 Aug 2025 13:32:42 -0700
Subject: [PATCH] gemma2: use fast attention

---
 model/models/gemma2/model.go | 42 ++++++++++--------------------------
 1 file changed, 11 insertions(+), 31 deletions(-)

diff --git a/model/models/gemma2/model.go b/model/models/gemma2/model.go
index e621d03ae..aba56e634 100644
--- a/model/models/gemma2/model.go
+++ b/model/models/gemma2/model.go
@@ -69,10 +69,10 @@ func New(c fs.Config) (model.Model, error) {
 		},
 	}
 
-	slidingWindowLen := int32(c.Uint("attention.sliding_window"))
-	m.Cache = kvcache.NewWrapperCache(kvcache.NewSWACache(slidingWindowLen, m.Shift), kvcache.NewCausalCache(m.Shift))
-	m.Cache.SetConfig(ml.CacheConfig{})
-
+	m.Cache = kvcache.NewWrapperCache(
+		kvcache.NewSWACache(int32(c.Uint("attention.sliding_window")), m.Shift),
+		kvcache.NewCausalCache(m.Shift),
+	)
 	return &m, nil
 }
 
@@ -90,12 +90,6 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
 	q = q.Reshape(ctx, opts.attnKeyLen, opts.numHeads, batchSize)
 	q = fast.RoPE(ctx, q, positionIDs, opts.attnKeyLen, opts.ropeBase, opts.ropeScale, rope.WithTypeNeoX())
 
-	if opts.largeModelScaling {
-		q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.hiddenSize/opts.numHeads)))
-	} else {
-		q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.attnKeyLen)))
-	}
-
 	k := sa.Key.Forward(ctx, hiddenState)
 	k = k.Reshape(ctx, opts.attnKeyLen, opts.numKVHeads, batchSize)
 	k = fast.RoPE(ctx, k, positionIDs, opts.attnKeyLen, opts.ropeBase, opts.ropeScale, rope.WithTypeNeoX())
@@ -103,28 +97,14 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
 	v := sa.Value.Forward(ctx, hiddenState)
 	v = v.Reshape(ctx, opts.attnValLen, opts.numKVHeads, batchSize)
 
-	cache.Put(ctx, k, v)
-	k, v, mask := cache.Get(ctx)
+	scale := 1.0 / math.Sqrt(float64(opts.attnKeyLen))
+	if opts.largeModelScaling {
+		scale = 1.0 / math.Sqrt(float64(opts.hiddenSize/opts.numHeads))
+	}
 
-	q = q.Permute(ctx, 0, 2, 1, 3)
-	k = k.Permute(ctx, 0, 2, 1, 3)
-	v = v.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
-
-	kq := k.Mulmat(ctx, q)
-
-	// logit softcap
-	kq = kq.Scale(ctx, 1.0/float64(opts.attnLogitSoftcap))
-	kq = kq.Tanh(ctx)
-	kq = kq.Scale(ctx, float64(opts.attnLogitSoftcap))
-
-	kq = kq.Add(ctx, mask)
-	kq = kq.Softmax(ctx)
-
-	kqv := v.Mulmat(ctx, kq)
-	kqv = kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
-	kqv = kqv.Reshape(ctx, opts.attnValLen*opts.numHeads, batchSize)
-
-	return sa.Output.Forward(ctx, kqv)
+	attn := nn.Attention(ctx, q, k, v, scale, cache)
+	attn = attn.Reshape(ctx, opts.attnValLen*opts.numHeads, batchSize)
+	return sa.Output.Forward(ctx, attn)
 }
 
 func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
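
For reference, below is a minimal, self-contained sketch of what the manual path deleted in the last hunk computed for a single head: scores are scaled, soft-capped with tanh via attnLogitSoftcap, masked, softmaxed, and used to weight the value vectors; the patch collapses that bookkeeping into the single nn.Attention(ctx, q, k, v, scale, cache) call, with the scale still chosen by largeModelScaling. The helper name softcapAttention and the plain float64 slices here are illustrative assumptions, not part of the ollama codebase, which operates on ml.Tensor values.

package main

import (
	"fmt"
	"math"
)

// softcapAttention is an illustrative, single-head restatement of the removed
// manual path: score = softcap * tanh((q·k * scale) / softcap), then mask,
// softmax, and a weighted sum over the value vectors.
// q, k are [seqLen][headDim]; v is [seqLen][valDim]; mask is [qLen][kLen]
// (0 for visible positions, -Inf for masked ones).
func softcapAttention(q, k, v, mask [][]float64, scale, softcap float64) [][]float64 {
	out := make([][]float64, len(q))
	for i := range q {
		// raw attention scores for query i against every key
		scores := make([]float64, len(k))
		for j := range k {
			var dot float64
			for d := range q[i] {
				dot += q[i][d] * k[j][d]
			}
			s := dot * scale
			// logit softcapping keeps the score inside (-softcap, softcap)
			s = softcap * math.Tanh(s/softcap)
			scores[j] = s + mask[i][j]
		}

		// numerically stable softmax over the capped, masked scores
		maxScore := math.Inf(-1)
		for _, s := range scores {
			if s > maxScore {
				maxScore = s
			}
		}
		var sum float64
		for j, s := range scores {
			scores[j] = math.Exp(s - maxScore)
			sum += scores[j]
		}

		// weighted sum of value vectors
		out[i] = make([]float64, len(v[0]))
		for j := range v {
			w := scores[j] / sum
			for d := range v[j] {
				out[i][d] += w * v[j][d]
			}
		}
	}
	return out
}

func main() {
	q := [][]float64{{1, 0}, {0, 1}}
	k := [][]float64{{1, 0}, {0, 1}}
	v := [][]float64{{1, 2}, {3, 4}}
	mask := [][]float64{{0, math.Inf(-1)}, {0, 0}} // causal mask for two positions
	fmt.Println(softcapAttention(q, k, v, mask, 1/math.Sqrt(2), 50.0))
}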