rope options

This commit is contained in:
nicole pardal 2025-12-10 14:40:15 -08:00
parent a613eca69c
commit b6f769ae60
1 changed file with 6 additions and 1 deletion

View File

@ -139,6 +139,8 @@ func (m *Model) applyRoPE(ctx ml.Context, states, positions ml.Tensor, ropeDim i
var ropeOpts []func(*rope.Options)
ropeOpts = append(ropeOpts, rope.WithTypeNeoX())
// Both SWA and non-SWA use beta_fast and beta_slow
// defaults
ropeOpts = append(ropeOpts,
@ -187,7 +189,6 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positions ml.Tenso
//check here
query = m.applyRoPE(ctx, query, positions, ropeDim, isSWA)
// and here
key := sa.Key.Forward(ctx, hiddenState)
key = sa.KNorm.Forward(ctx, key, m.eps)
@ -233,6 +234,8 @@ func (l *Layer) Forward(ctx ml.Context, hiddenState, positions, outputs ml.Tenso
hiddenState = l.SelfAttention.Forward(ctx, hiddenState, positions, cache, m, isSWA)
// return hiddenState
if outputs != nil {
hiddenState = hiddenState.Rows(ctx, outputs)
residual = residual.Rows(ctx, outputs)
@ -282,6 +285,8 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
}
hiddenState = layer.Forward(ctx, hiddenState, positions, outputs, m.Cache, m, isSWA)
// return hiddenState, nil
}
hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)