fixed generation issue

This commit is contained in:
nicole pardal 2025-12-08 00:35:49 -08:00 committed by ParthSareen
parent 3eea7f198b
commit d8bf6a5dee
1 changed files with 64 additions and 30 deletions

View File

@ -19,6 +19,9 @@ type Options struct {
headDim, ropeDim int headDim, ropeDim int
eps, ropeBase, ropeScale float32 eps, ropeBase, ropeScale float32
clampKQV float32 clampKQV float32
originalContextLength int
attnFactor float32
} }
type Model struct { type Model struct {
@ -64,26 +67,21 @@ func New(c fs.Config) (model.Model, error) {
TextProcessor: processor, TextProcessor: processor,
Layers: make([]Layer, c.Uint("block_count")), Layers: make([]Layer, c.Uint("block_count")),
Options: Options{ Options: Options{
hiddenSize: int(c.Uint("embedding_length")), hiddenSize: int(c.Uint("embedding_length")),
numHeads: int(c.Uint("attention.head_count")), numHeads: int(c.Uint("attention.head_count")),
numKVHeads: int(c.Uint("attention.head_count_kv")), numKVHeads: int(c.Uint("attention.head_count_kv")),
headDim: int(c.Uint("attention.key_length")), headDim: int(c.Uint("attention.key_length")),
ropeDim: int(c.Uint("rope.dimension_count")), ropeDim: int(c.Uint("rope.dimension_count")),
eps: c.Float("attention.layer_norm_rms_epsilon"), eps: c.Float("attention.layer_norm_rms_epsilon"),
ropeBase: c.Float("rope.freq_base", 1e4), ropeBase: c.Float("rope.freq_base", 1e4),
ropeScale: c.Float("rope.scaling.factor", 1), ropeScale: c.Float("rope.scaling.factor", 1),
clampKQV: c.Float("attention.clamp_kqv", 0), clampKQV: c.Float("attention.clamp_kqv", 0),
originalContextLength: int(c.Uint("rope.scaling.original_context_length")),
attnFactor: c.Float("rope.scaling.attn_factor", 1),
}, },
} }
if slidingWindow := c.Uint("attention.sliding_window"); slidingWindow > 0 { m.Cache = kvcache.NewCausalCache(m.Shift)
m.Cache = kvcache.NewWrapperCache(
kvcache.NewSWACache(int32(slidingWindow), m.Shift),
kvcache.NewCausalCache(m.Shift),
)
} else {
m.Cache = kvcache.NewCausalCache(m.Shift)
}
return &m, nil return &m, nil
} }
@ -98,7 +96,23 @@ type SelfAttention struct {
RopeFactors ml.Tensor `gguf:"rope_freqs.weight"` RopeFactors ml.Tensor `gguf:"rope_freqs.weight"`
} }
func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positions ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor { func (o *Options) ropeOptions(factors ml.Tensor, isSWA bool) []func(*rope.Options) {
opts := []func(*rope.Options){
rope.WithFactors(factors),
}
if !isSWA && o.originalContextLength > 0 {
opts = append(opts,
rope.WithOriginalContextLength(o.originalContextLength),
rope.WithExtrapolationFactor(1.),
rope.WithAttentionFactor(o.attnFactor),
)
}
return opts
}
func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positions ml.Tensor, cache kvcache.Cache, opts *Options, isSWA bool) ml.Tensor {
batchSize := hiddenState.Dim(1) batchSize := hiddenState.Dim(1)
headDim := cmp.Or(opts.headDim, opts.hiddenSize/opts.numHeads) headDim := cmp.Or(opts.headDim, opts.hiddenSize/opts.numHeads)
ropeDim := cmp.Or(opts.ropeDim, headDim) ropeDim := cmp.Or(opts.ropeDim, headDim)
@ -118,8 +132,14 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positions ml.Tenso
value := sa.Value.Forward(ctx, hiddenState) value := sa.Value.Forward(ctx, hiddenState)
value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize) value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
query = fast.RoPE(ctx, query, positions, ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithFactors(sa.RopeFactors)) freqScale := float32(1.0)
key = fast.RoPE(ctx, key, positions, ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithFactors(sa.RopeFactors)) if !isSWA {
freqScale = 1. / opts.ropeScale
}
ropeOpts := opts.ropeOptions(sa.RopeFactors, isSWA)
query = fast.RoPE(ctx, query, positions, ropeDim, opts.ropeBase, freqScale, ropeOpts...)
key = fast.RoPE(ctx, key, positions, ropeDim, opts.ropeBase, freqScale, ropeOpts...)
attention := nn.Attention(ctx, query, key, value, 1.0/math.Sqrt(float64(headDim)), cache) attention := nn.Attention(ctx, query, key, value, 1.0/math.Sqrt(float64(headDim)), cache)
attention = attention.Reshape(ctx, headDim*opts.numHeads, batchSize) attention = attention.Reshape(ctx, headDim*opts.numHeads, batchSize)
@ -128,7 +148,15 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positions ml.Tenso
func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
ropeDim := cmp.Or(m.ropeDim, m.hiddenSize/m.numHeads) ropeDim := cmp.Or(m.ropeDim, m.hiddenSize/m.numHeads)
return fast.RoPE(ctx, key, shift, ropeDim, m.ropeBase, 1./m.ropeScale, rope.WithFactors(m.Layers[layer].SelfAttention.RopeFactors)), nil isSWA := isSWALayer(layer)
freqScale := float32(1.0)
if !isSWA {
freqScale = 1. / m.ropeScale
}
ropeOpts := m.Options.ropeOptions(m.Layers[layer].SelfAttention.RopeFactors, isSWA)
return fast.RoPE(ctx, key, shift, ropeDim, m.ropeBase, freqScale, ropeOpts...), nil
} }
type MLP struct { type MLP struct {
@ -149,28 +177,33 @@ type Layer struct {
PostFFWNorm *nn.RMSNorm `gguf:"post_ffw_norm"` PostFFWNorm *nn.RMSNorm `gguf:"post_ffw_norm"`
} }
func (l *Layer) Forward(ctx ml.Context, hiddenState, positions, outputs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor { func (l *Layer) Forward(ctx ml.Context, hiddenState, positions, outputs ml.Tensor, cache kvcache.Cache, opts *Options, isSWA bool) ml.Tensor {
residual := hiddenState residual := hiddenState
hiddenState = l.SelfAttention.Forward(ctx, hiddenState, positions, cache, opts)
hiddenState = l.SelfAttention.Forward(ctx, hiddenState, positions, cache, opts, isSWA)
if outputs != nil { if outputs != nil {
hiddenState = hiddenState.Rows(ctx, outputs) hiddenState = hiddenState.Rows(ctx, outputs)
residual = residual.Rows(ctx, outputs) residual = residual.Rows(ctx, outputs)
} }
hiddenState = hiddenState.Add(ctx, residual)
if l.PostAttentionNorm != nil { if l.PostAttentionNorm != nil {
hiddenState = l.PostAttentionNorm.Forward(ctx, hiddenState, opts.eps) hiddenState = l.PostAttentionNorm.Forward(ctx, hiddenState, opts.eps)
} }
residual = hiddenState ffnInput := hiddenState.Add(ctx, residual)
hiddenState = l.MLP.Forward(ctx, hiddenState, opts)
hiddenState = hiddenState.Add(ctx, residual) hiddenState = l.MLP.Forward(ctx, ffnInput, opts)
if l.PostFFWNorm != nil { if l.PostFFWNorm != nil {
hiddenState = l.PostFFWNorm.Forward(ctx, hiddenState, opts.eps) hiddenState = l.PostFFWNorm.Forward(ctx, hiddenState, opts.eps)
} }
return hiddenState return hiddenState.Add(ctx, ffnInput)
}
func isSWALayer(layerIdx int) bool {
return (layerIdx+1)%4 != 0
} }
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
@ -181,12 +214,14 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
for i, layer := range m.Layers { for i, layer := range m.Layers {
m.Cache.SetLayer(i) m.Cache.SetLayer(i)
isSWA := isSWALayer(i)
var outputs ml.Tensor var outputs ml.Tensor
if i == len(m.Layers)-1 { if i == len(m.Layers)-1 {
outputs = batch.Outputs outputs = batch.Outputs
} }
hiddenState = layer.Forward(ctx, hiddenState, positions, outputs, m.Cache, &m.Options) hiddenState = layer.Forward(ctx, hiddenState, positions, outputs, m.Cache, &m.Options, isSWA)
} }
hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps) hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
@ -194,6 +229,5 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
} }
func init() { func init() {
model.Register("olmo", New)
model.Register("olmo2", New) model.Register("olmo2", New)
} }