fix(llama): other llama flavours (#12308)

* fix(llama): rope scale

* spm llama

* skip moe models

* cleanup
This commit is contained in:
Michael Yang
2025-09-17 12:12:21 -07:00
committed by GitHub
parent a417ac97ee
commit 564b558c92
10 changed files with 75 additions and 67 deletions

View File

@@ -63,7 +63,7 @@ func New(c fs.Config) (model.Model, error) {
attnValLen: int(c.Uint("attention.value_length")),
eps: c.Float("attention.layer_norm_rms_epsilon"),
ropeBase: c.Float("rope.freq_base", 10000.0),
ropeScale: c.Float("rope.freq_scale", 1.0),
ropeScale: c.Float("rope.scaling.factor", 1.0),
attnLogitSoftcap: c.Float("attn_logit_softcapping"),
finalLogitSoftcap: c.Float("final_logit_softcapping"),
},
@@ -88,7 +88,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
q := sa.Query.Forward(ctx, hiddenState)
q = q.Reshape(ctx, opts.attnKeyLen, opts.numHeads, batchSize)
q = fast.RoPE(ctx, q, positionIDs, opts.attnKeyLen, opts.ropeBase, opts.ropeScale, rope.WithTypeNeoX())
q = fast.RoPE(ctx, q, positionIDs, opts.attnKeyLen, opts.ropeBase, 1./opts.ropeScale, rope.WithTypeNeoX())
if opts.largeModelScaling {
q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.hiddenSize/opts.numHeads)))
@@ -98,7 +98,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
k := sa.Key.Forward(ctx, hiddenState)
k = k.Reshape(ctx, opts.attnKeyLen, opts.numKVHeads, batchSize)
k = fast.RoPE(ctx, k, positionIDs, opts.attnKeyLen, opts.ropeBase, opts.ropeScale, rope.WithTypeNeoX())
k = fast.RoPE(ctx, k, positionIDs, opts.attnKeyLen, opts.ropeBase, 1./opts.ropeScale, rope.WithTypeNeoX())
v := sa.Value.Forward(ctx, hiddenState)
v = v.Reshape(ctx, opts.attnValLen, opts.numKVHeads, batchSize)