From f331801252f102767e6becf4631a5b4ad9b3f3d9 Mon Sep 17 00:00:00 2001
From: Grace Guo
Date: Mon, 8 Dec 2025 17:35:31 -0800
Subject: [PATCH] init changes

---
 model/models/deepseek2/model.go | 41 +++++++++++++++++++++++++++++----
 1 file changed, 37 insertions(+), 4 deletions(-)

diff --git a/model/models/deepseek2/model.go b/model/models/deepseek2/model.go
index 576076aab..2f1fec961 100644
--- a/model/models/deepseek2/model.go
+++ b/model/models/deepseek2/model.go
@@ -39,6 +39,10 @@ type Options struct {
 	ropeBase, ropeScale float32
 
 	kqScale float64
+
+	attentionTemperatureScale      float32
+	attentionTemperatureLength     int
+	attentionTemperatureFloorScale int
 }
 
 func (o Options) applyRotaryPositionEmbeddings(ctx ml.Context, t, p ml.Tensor) ml.Tensor {
@@ -66,7 +70,7 @@ type Attention struct {
 	Output *nn.Linear `gguf:"attn_out,alt:attn_output"`
 }
 
-func (attn *Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
+func (attn *Attention) Forward(ctx ml.Context, hiddenStates, positions, attentionScales ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
 	seqLength := hiddenStates.Dim(1)
 
 	var query ml.Tensor
@@ -104,6 +108,11 @@ func (attn *Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor
 		kRot = kRot.Repeat(ctx, 1, queryChunks[0].Dim(1))
 		query = qRot.Concat(ctx, queryChunks[0], 0)
 		key := kRot.Concat(ctx, kvChunks[0], 0)
+
+		if attentionScales != nil {
+			query = query.Mul(ctx, attentionScales)
+		}
+
 		attention = nn.Attention(ctx, query, key, kvChunks[1], opts.kqScale, cache)
 	} else { // v3.1
 		qPass := queryChunks[0].Permute(ctx, 0, 2, 1, 3)
@@ -115,6 +124,10 @@ func (attn *Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor
 		key := kRot.Concat(ctx, kPass, 0)
 		value := kPass
 
+		if attentionScales != nil {
+			query = query.Mul(ctx, attentionScales)
+		}
+
 		attention = nn.AttentionWithVMLA(ctx, query, key, value, nil, attn.VB.Weight, opts.kqScale, cache)
 	}
 
@@ -201,10 +214,10 @@ type Layer struct {
 	MLP MLP
 }
 
-func (t *Layer) Forward(ctx ml.Context, hiddenStates, positions, outputs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
+func (t *Layer) Forward(ctx ml.Context, hiddenStates, positions, attentionScales, outputs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
 	residual := hiddenStates
 	hiddenStates = t.AttentionNorm.Forward(ctx, hiddenStates, opts.eps)
-	hiddenStates = t.Attention.Forward(ctx, hiddenStates, positions, cache, opts)
+	hiddenStates = t.Attention.Forward(ctx, hiddenStates, positions, attentionScales, cache, opts)
 
 	if outputs != nil {
 		hiddenStates = hiddenStates.Rows(ctx, outputs)
@@ -316,6 +329,11 @@ func New(c fs.Config) (model.Model, error) {
 			routedScalingFactor:   c.Float("expert_weights_scale"),
 			originalContextLength: int(c.Uint("rope.scaling.original_context_length")),
 
+			// TODO: confirm these keys/values; a zero scale (the default) disables temperature tuning
+			attentionTemperatureScale:      c.Float("attention.temperature_scale", 0.0),
+			attentionTemperatureLength:     int(c.Uint("attention.temperature_length")),
+			attentionTemperatureFloorScale: int(c.Uint("attention.temperature_floor_scale", 8192)),
+
 			kqScale: kqScale,
 		},
 	}
@@ -333,6 +351,21 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
 
 	hiddenStates := m.TokenEmbedding.Forward(ctx, batch.Inputs)
 
+	// Temperature tuning - used by mistral-large
+	var attentionScales ml.Tensor
+	if m.attentionTemperatureScale != 0.0 {
+		nTokens := len(batch.Positions)
+		scales := make([]float32, nTokens)
+
+		for i, pos := range batch.Positions {
+			posFloat := float64(pos)
+			scaleValue := math.Log(math.Floor((posFloat+1.0)/float64(m.attentionTemperatureFloorScale))+1.0)*float64(m.attentionTemperatureScale) + 1.0
+			scales[i] = float32(scaleValue)
+		}
+
+		attentionScales = ctx.Input().FromFloats(scales, 1, 1, nTokens)
+	}
+
 	for i, layer := range m.Layers {
 		m.Cache.SetLayer(i)
 
@@ -341,7 +374,7 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
 			outputs = batch.Outputs
 		}
 
-		hiddenStates = layer.Forward(ctx, hiddenStates, positions, outputs, m.Cache, m.Options)
+		hiddenStates = layer.Forward(ctx, hiddenStates, positions, attentionScales, outputs, m.Cache, m.Options)
 	}
 
 	hiddenStates = m.OutputNorm.Forward(ctx, hiddenStates, m.eps)