init changes

Grace Guo 2025-12-08 17:35:31 -08:00
parent 49a9c9ba6a
commit f331801252
1 changed file with 37 additions and 4 deletions


@@ -39,6 +39,10 @@ type Options struct {
 	ropeBase,
 	ropeScale float32
 	kqScale float64
+	attentionTemperatureScale float32
+	attentionTemperatureLength int
+	attentionTemperatureFloorScale int
 }
 
 func (o Options) applyRotaryPositionEmbeddings(ctx ml.Context, t, p ml.Tensor) ml.Tensor {
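
For orientation, a quick sketch of what the three new Options fields hold. The comments are inferred from how the fields are used later in this commit, not from any upstream documentation, and the wrapper type name is hypothetical:

// Sketch only: inferred meanings of the new temperature-tuning options.
type temperatureOptions struct {
	attentionTemperatureScale      float32 // multiplier on the log term when scaling query vectors
	attentionTemperatureLength     int     // read from config in New() but not referenced elsewhere in this change
	attentionTemperatureFloorScale int     // position bucket size inside floor((pos+1)/floorScale)
}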
@@ -66,7 +70,7 @@ type Attention struct {
 	Output *nn.Linear `gguf:"attn_out,alt:attn_output"`
 }
 
-func (attn *Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
+func (attn *Attention) Forward(ctx ml.Context, hiddenStates, positions, attentionScales ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
 	seqLength := hiddenStates.Dim(1)
 
 	var query ml.Tensor
@@ -104,6 +108,11 @@ func (attn *Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor
 		kRot = kRot.Repeat(ctx, 1, queryChunks[0].Dim(1))
 		query = qRot.Concat(ctx, queryChunks[0], 0)
 		key := kRot.Concat(ctx, kvChunks[0], 0)
+
+		if attentionScales != nil {
+			query = query.Mul(ctx, attentionScales)
+		}
+
 		attention = nn.Attention(ctx, query, key, kvChunks[1], opts.kqScale, cache)
 	} else { // v3.1
 		qPass := queryChunks[0].Permute(ctx, 0, 2, 1, 3)
@@ -115,6 +124,10 @@ func (attn *Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor
 		key := kRot.Concat(ctx, kPass, 0)
 		value := kPass
 
+		if attentionScales != nil {
+			query = query.Mul(ctx, attentionScales)
+		}
+
 		attention = nn.AttentionWithVMLA(ctx, query, key, value, nil, attn.VB.Weight, opts.kqScale, cache)
 	}
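
Both the v3 and v3.1 branches apply the same per-token scaling to the query. Because the attention logits are dot products between queries and keys, multiplying a token's query by its scale multiplies all of that token's logits by the same factor, i.e. it lowers the softmax temperature for that position. A rough sketch of the intended broadcast semantics, using plain slices instead of ml.Tensor (the [token][head][dim] layout is purely illustrative; the real layout is whatever the ml package uses when a (1, 1, nTokens) scale tensor is multiplied into the query):

// scaleQueries multiplies every element of token t's query (all heads, all
// head dimensions) by scales[t], which is what the Mul against the
// (1, 1, nTokens) attentionScales tensor is expected to do elementwise.
func scaleQueries(query [][][]float32, scales []float32) {
	for t := range query { // tokens
		for h := range query[t] { // heads
			for d := range query[t][h] { // head dimensions
				query[t][h][d] *= scales[t]
			}
		}
	}
}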
@@ -201,10 +214,10 @@ type Layer struct {
 	MLP MLP
 }
 
-func (t *Layer) Forward(ctx ml.Context, hiddenStates, positions, outputs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
+func (t *Layer) Forward(ctx ml.Context, hiddenStates, positions, attentionScales, outputs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
 	residual := hiddenStates
 	hiddenStates = t.AttentionNorm.Forward(ctx, hiddenStates, opts.eps)
-	hiddenStates = t.Attention.Forward(ctx, hiddenStates, positions, cache, opts)
+	hiddenStates = t.Attention.Forward(ctx, hiddenStates, positions, attentionScales, cache, opts)
 
 	if outputs != nil {
 		hiddenStates = hiddenStates.Rows(ctx, outputs)
@@ -316,6 +329,11 @@ func New(c fs.Config) (model.Model, error) {
 			routedScalingFactor: c.Float("expert_weights_scale"),
 			originalContextLength: int(c.Uint("rope.scaling.original_context_length")),
+
+			// TODO: double check these values
+			attentionTemperatureScale: c.Float("attention.temperature_scale", 1.0),
+			attentionTemperatureLength: int(c.Uint("attention.temperature_length")),
+			attentionTemperatureFloorScale: int(c.Uint("attention.temperature_floor_scale", 8192)),
 			kqScale: kqScale,
 		},
 	}
@@ -333,6 +351,21 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
 	hiddenStates := m.TokenEmbedding.Forward(ctx, batch.Inputs)
 
+	// Temperature tuning - used by mistral-large
+	var attentionScales ml.Tensor
+	if m.attentionTemperatureScale != 0.0 {
+		nTokens := len(batch.Positions)
+		scales := make([]float32, nTokens)
+		for i, pos := range batch.Positions {
+			posFloat := float64(pos)
+			scaleValue := math.Log(math.Floor((posFloat+1.0)/float64(m.attentionTemperatureFloorScale))+1.0)*float64(m.attentionTemperatureScale) + 1.0
+			scales[i] = float32(scaleValue)
+		}
+		attentionScales = ctx.Input().FromFloats(scales, 1, 1, nTokens)
+	}
+
 	for i, layer := range m.Layers {
 		m.Cache.SetLayer(i)
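
The per-position scale computed above is scale(pos) = 1 + temperature_scale * ln(floor((pos + 1) / floor_scale) + 1). With the defaults wired up in New() (temperature_scale 1.0, floor_scale 8192), any position below 8191 gets a scale of exactly 1.0, and the scale then grows logarithmically with position, so only long-context tokens have their attention logits sharpened. A small self-contained sketch of the formula; the helper name and sample values are illustrative only:

package main

import (
	"fmt"
	"math"
)

// attentionScale mirrors the per-position formula used in Model.Forward above:
// scale(pos) = 1 + temperatureScale*ln(floor((pos+1)/floorScale) + 1)
func attentionScale(pos int32, temperatureScale float64, floorScale int) float32 {
	v := math.Log(math.Floor((float64(pos)+1.0)/float64(floorScale))+1.0)*temperatureScale + 1.0
	return float32(v)
}

func main() {
	// 1.0 and 8192 are the defaults used in New() above; real checkpoints may override them.
	for _, pos := range []int32{0, 8191, 16383, 65535} {
		fmt.Printf("pos=%d scale=%.4f\n", pos, attentionScale(pos, 1.0, 8192))
	}
}

This prints 1.0000 at position 0 (and at every position below 8191), about 1.6931 at 8191, 2.0986 at 16383, and 3.1972 at 65535.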
@@ -341,7 +374,7 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
 			outputs = batch.Outputs
 		}
 
-		hiddenStates = layer.Forward(ctx, hiddenStates, positions, outputs, m.Cache, m.Options)
+		hiddenStates = layer.Forward(ctx, hiddenStates, positions, attentionScales, outputs, m.Cache, m.Options)
 	}
 
 	hiddenStates = m.OutputNorm.Forward(ctx, hiddenStates, m.eps)