rope options
This commit is contained in:
parent
a613eca69c
commit
b6f769ae60
|
|
@ -139,6 +139,8 @@ func (m *Model) applyRoPE(ctx ml.Context, states, positions ml.Tensor, ropeDim i
|
|||
|
||||
var ropeOpts []func(*rope.Options)
|
||||
|
||||
ropeOpts = append(ropeOpts, rope.WithTypeNeoX())
|
||||
|
||||
// Both SWA and non-SWA use beta_fast and beta_slow
|
||||
// defaults
|
||||
ropeOpts = append(ropeOpts,
|
||||
|
|
@ -187,7 +189,6 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positions ml.Tenso
|
|||
// check here
|
||||
|
||||
query = m.applyRoPE(ctx, query, positions, ropeDim, isSWA)
|
||||
// and here
|
||||
|
||||
key := sa.Key.Forward(ctx, hiddenState)
|
||||
key = sa.KNorm.Forward(ctx, key, m.eps)
|
||||
|
|
@ -233,6 +234,8 @@ func (l *Layer) Forward(ctx ml.Context, hiddenState, positions, outputs ml.Tenso
|
|||
|
||||
hiddenState = l.SelfAttention.Forward(ctx, hiddenState, positions, cache, m, isSWA)
|
||||
|
||||
// return hiddenState
|
||||
|
||||
if outputs != nil {
|
||||
hiddenState = hiddenState.Rows(ctx, outputs)
|
||||
residual = residual.Rows(ctx, outputs)
|
||||
|
|
@ -282,6 +285,8 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
|||
}
|
||||
|
||||
hiddenState = layer.Forward(ctx, hiddenState, positions, outputs, m.Cache, m, isSWA)
|
||||
|
||||
// return hiddenState, nil
|
||||
}
|
||||
|
||||
hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
|
||||
|
|
|
|||
Loading…
Reference in New Issue