gemma3: scale in attention
parent b323cfe731
commit 12a7e5ec46
@@ -86,12 +86,6 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, pos
 	q = sa.QueryNorm.Forward(ctx, q, opts.eps)
 	q = fast.RoPE(ctx, q, positionIDs, opts.attnKeyLen, ropeBase, opts.ropeScale, rope.WithTypeNeoX())
 
-	if opts.largeModelScaling {
-		q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.hiddenSize/opts.numHeads)))
-	} else {
-		q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.attnKeyLen)))
-	}
-
 	k := sa.Key.Forward(ctx, hiddenState)
 	k = k.Reshape(ctx, opts.attnKeyLen, opts.numKVHeads, batchSize)
 	k = sa.KeyNorm.Forward(ctx, k, opts.eps)
@@ -100,8 +94,12 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, pos
 	v := sa.Value.Forward(ctx, hiddenState)
 	v = v.Reshape(ctx, opts.attnValLen, opts.numKVHeads, batchSize)
 
-	scaleFactor := 1.0
-	kqv := nn.Attention(ctx, q, k, v, scaleFactor, cache)
+	scale := 1.0 / math.Sqrt(float64(opts.attnKeyLen))
+	if opts.largeModelScaling {
+		scale = 1.0 / math.Sqrt(float64(opts.hiddenSize/opts.numHeads))
+	}
+
+	kqv := nn.Attention(ctx, q, k, v, scale, cache)
 	kqv = kqv.Reshape(ctx, opts.attnValLen*opts.numHeads, batchSize)
 
 	return sa.Output.Forward(ctx, kqv)
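
Net effect of the diff: rather than pre-scaling q with an explicit Scale op and then attending with a scaleFactor of 1.0, the 1/sqrt(d) factor is now passed to nn.Attention itself, presumably so a fused attention implementation can apply it inside the kernel. The two forms agree because the dot product is linear in q: (s*q)·k == s*(q·k). Below is a minimal standalone sketch of that identity using plain float64 slices; dot, the vector values, and the variable names are illustrative only and are not part of ollama's ml.Tensor API.

package main

import (
	"fmt"
	"math"
)

// dot computes a plain float64 dot product (stand-in for one q·k logit).
func dot(a, b []float64) float64 {
	var sum float64
	for i := range a {
		sum += a[i] * b[i]
	}
	return sum
}

func main() {
	q := []float64{0.3, -1.2, 0.7, 2.1}
	k := []float64{1.1, 0.4, -0.6, 0.9}

	// len(q) plays the role of attnKeyLen (the head dimension) here.
	scale := 1.0 / math.Sqrt(float64(len(q)))

	// Old form: scale q up front, then attend with scaleFactor = 1.0.
	qScaled := make([]float64, len(q))
	for i, x := range q {
		qScaled[i] = x * scale
	}
	oldLogit := 1.0 * dot(qScaled, k)

	// New form: leave q untouched and hand the scale to attention.
	newLogit := scale * dot(q, k)

	// The two logits match up to floating-point rounding.
	fmt.Println(oldLogit, newLogit)
}

The largeModelScaling branch survives the move: larger Gemma variants scale by hiddenSize/numHeads instead of the head dimension attnKeyLen. Only where the scaling happens in the compute graph changes, not the attention math.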