From b6f769ae605b5dc6f53d1485a33c22ccf5d85572 Mon Sep 17 00:00:00 2001
From: nicole pardal
Date: Wed, 10 Dec 2025 14:40:15 -0800
Subject: [PATCH] rope options

Select the NeoX rotation layout when building the RoPE options for OLMo,
and drop a leftover placeholder comment.
---
 model/models/olmo/model.go | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/model/models/olmo/model.go b/model/models/olmo/model.go
index e57c87a64..b3ea3e65e 100644
--- a/model/models/olmo/model.go
+++ b/model/models/olmo/model.go
@@ -139,6 +139,8 @@ func (m *Model) applyRoPE(ctx ml.Context, states, positions ml.Tensor, ropeDim i
 
 	var ropeOpts []func(*rope.Options)
 
+	ropeOpts = append(ropeOpts, rope.WithTypeNeoX()) // use the NeoX rotation layout
+
 	// Both SWA and non-SWA use beta_fast and beta_slow
 	// defaults
 	ropeOpts = append(ropeOpts,
@@ -187,7 +189,6 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positions ml.Tenso
 
 	//check here
 	query = m.applyRoPE(ctx, query, positions, ropeDim, isSWA)
-	// and here
 
 	key := sa.Key.Forward(ctx, hiddenState)
 	key = sa.KNorm.Forward(ctx, key, m.eps)
@@ -233,6 +234,8 @@ func (l *Layer) Forward(ctx ml.Context, hiddenState, positions, outputs ml.Tenso
 
 	hiddenState = l.SelfAttention.Forward(ctx, hiddenState, positions, cache, m, isSWA)
 
+	// return hiddenState
+
 	if outputs != nil {
 		hiddenState = hiddenState.Rows(ctx, outputs)
 		residual = residual.Rows(ctx, outputs)
@@ -282,6 +285,8 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
 		}
 
 		hiddenState = layer.Forward(ctx, hiddenState, positions, outputs, m.Cache, m, isSWA)
+
+		// return hiddenState, nil
 	}
 
 	hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
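
Reviewer note (not part of the commit): for context, below is a minimal,
self-contained sketch of the functional-options pattern that ropeOpts relies
on above. The Options fields and the WithBetas constructor are illustrative
assumptions made for this note, not the actual ml/nn/rope API, and the beta
values are placeholders rather than the model's real defaults.

    package main

    import "fmt"

    // Options stands in for a rope.Options-style config struct.
    type Options struct {
    	NeoX               bool    // true: rotate split halves (NeoX); false: adjacent pairs
    	BetaFast, BetaSlow float32 // YaRN-style ramp parameters
    }

    // WithTypeNeoX mirrors the option the patch appends: it selects the NeoX layout.
    func WithTypeNeoX() func(*Options) {
    	return func(o *Options) { o.NeoX = true }
    }

    // WithBetas is a hypothetical option setting beta_fast/beta_slow together.
    func WithBetas(fast, slow float32) func(*Options) {
    	return func(o *Options) { o.BetaFast, o.BetaSlow = fast, slow }
    }

    func main() {
    	// Accumulate options the same way applyRoPE builds ropeOpts...
    	var ropeOpts []func(*Options)
    	ropeOpts = append(ropeOpts, WithTypeNeoX())
    	ropeOpts = append(ropeOpts, WithBetas(32, 1)) // placeholder values

    	// ...then apply them in order to a zero-value Options before running RoPE.
    	var opts Options
    	for _, opt := range ropeOpts {
    		opt(&opts)
    	}
    	fmt.Printf("%+v\n", opts) // {NeoX:true BetaFast:32 BetaSlow:1}
    }

This pattern keeps per-call overrides (like the SWA/non-SWA split in
applyRoPE) composable: each call site appends only the options it needs, and
the final loop applies them in order over the defaults.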