From 12a7e5ec46c2ed20611fcac3ffb9c934107afcbb Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Tue, 19 Aug 2025 13:43:42 -0700 Subject: [PATCH] gemma3: scale in attention --- model/models/gemma3/model_text.go | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/model/models/gemma3/model_text.go b/model/models/gemma3/model_text.go index 70d7797e9..f13053417 100644 --- a/model/models/gemma3/model_text.go +++ b/model/models/gemma3/model_text.go @@ -86,12 +86,6 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, pos q = sa.QueryNorm.Forward(ctx, q, opts.eps) q = fast.RoPE(ctx, q, positionIDs, opts.attnKeyLen, ropeBase, opts.ropeScale, rope.WithTypeNeoX()) - if opts.largeModelScaling { - q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.hiddenSize/opts.numHeads))) - } else { - q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.attnKeyLen))) - } - k := sa.Key.Forward(ctx, hiddenState) k = k.Reshape(ctx, opts.attnKeyLen, opts.numKVHeads, batchSize) k = sa.KeyNorm.Forward(ctx, k, opts.eps) @@ -100,8 +94,12 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, pos v := sa.Value.Forward(ctx, hiddenState) v = v.Reshape(ctx, opts.attnValLen, opts.numKVHeads, batchSize) - scaleFactor := 1.0 - kqv := nn.Attention(ctx, q, k, v, scaleFactor, cache) + scale := 1.0 / math.Sqrt(float64(opts.attnKeyLen)) + if opts.largeModelScaling { + scale = 1.0 / math.Sqrt(float64(opts.hiddenSize/opts.numHeads)) + } + + kqv := nn.Attention(ctx, q, k, v, scale, cache) kqv = kqv.Reshape(ctx, opts.attnValLen*opts.numHeads, batchSize) return sa.Output.Forward(ctx, kqv)