init changes
This commit is contained in:
parent
49a9c9ba6a
commit
f331801252
|
|
@ -39,6 +39,10 @@ type Options struct {
|
|||
ropeBase,
|
||||
ropeScale float32
|
||||
kqScale float64
|
||||
|
||||
attentionTemperatureScale float32
|
||||
attentionTemperatureLength int
|
||||
attentionTemperatureFloorScale int
|
||||
}
|
||||
|
||||
func (o Options) applyRotaryPositionEmbeddings(ctx ml.Context, t, p ml.Tensor) ml.Tensor {
|
||||
|
|
@ -66,7 +70,7 @@ type Attention struct {
|
|||
Output *nn.Linear `gguf:"attn_out,alt:attn_output"`
|
||||
}
|
||||
|
||||
func (attn *Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
|
||||
func (attn *Attention) Forward(ctx ml.Context, hiddenStates, positions, attentionScales ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
|
||||
seqLength := hiddenStates.Dim(1)
|
||||
|
||||
var query ml.Tensor
|
||||
|
|
@ -104,6 +108,11 @@ func (attn *Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor
|
|||
kRot = kRot.Repeat(ctx, 1, queryChunks[0].Dim(1))
|
||||
query = qRot.Concat(ctx, queryChunks[0], 0)
|
||||
key := kRot.Concat(ctx, kvChunks[0], 0)
|
||||
|
||||
if attentionScales != nil {
|
||||
query = query.Mul(ctx, attentionScales)
|
||||
}
|
||||
|
||||
attention = nn.Attention(ctx, query, key, kvChunks[1], opts.kqScale, cache)
|
||||
} else { // v3.1
|
||||
qPass := queryChunks[0].Permute(ctx, 0, 2, 1, 3)
|
||||
|
|
@ -115,6 +124,10 @@ func (attn *Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor
|
|||
key := kRot.Concat(ctx, kPass, 0)
|
||||
value := kPass
|
||||
|
||||
if attentionScales != nil {
|
||||
query = query.Mul(ctx, attentionScales)
|
||||
}
|
||||
|
||||
attention = nn.AttentionWithVMLA(ctx, query, key, value, nil, attn.VB.Weight, opts.kqScale, cache)
|
||||
}
|
||||
|
||||
|
|
@ -201,10 +214,10 @@ type Layer struct {
|
|||
MLP MLP
|
||||
}
|
||||
|
||||
func (t *Layer) Forward(ctx ml.Context, hiddenStates, positions, outputs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
|
||||
func (t *Layer) Forward(ctx ml.Context, hiddenStates, positions, attentionScales, outputs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
|
||||
residual := hiddenStates
|
||||
hiddenStates = t.AttentionNorm.Forward(ctx, hiddenStates, opts.eps)
|
||||
hiddenStates = t.Attention.Forward(ctx, hiddenStates, positions, cache, opts)
|
||||
hiddenStates = t.Attention.Forward(ctx, hiddenStates, positions, attentionScales, cache, opts)
|
||||
|
||||
if outputs != nil {
|
||||
hiddenStates = hiddenStates.Rows(ctx, outputs)
|
||||
|
|
@ -316,6 +329,11 @@ func New(c fs.Config) (model.Model, error) {
|
|||
routedScalingFactor: c.Float("expert_weights_scale"),
|
||||
originalContextLength: int(c.Uint("rope.scaling.original_context_length")),
|
||||
|
||||
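// TODO: confirm these GGUF keys and their defaults; attentionTemperatureLength is loaded here but is not read in the hunks shown.
// NOTE(review): temperature_scale appears to default to 1.0, so the `!= 0.0` check in Model.Forward would pass even when the key is absent — confirm this is intended.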
|
||||
attentionTemperatureScale: c.Float("attention.temperature_scale", 1.0),
|
||||
attentionTemperatureLength: int(c.Uint("attention.temperature_length")),
|
||||
attentionTemperatureFloorScale: int(c.Uint("attention.temperature_floor_scale", 8192)),
|
||||
|
||||
kqScale: kqScale,
|
||||
},
|
||||
}
|
||||
|
|
@ -333,6 +351,21 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
|||
|
||||
hiddenStates := m.TokenEmbedding.Forward(ctx, batch.Inputs)
|
||||
|
||||
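// Temperature tuning (used by mistral-large): build a per-token scale of
// 1 + attentionTemperatureScale*ln(floor((pos+1)/attentionTemperatureFloorScale) + 1), which Attention.Forward multiplies into the query.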
|
||||
var attentionScales ml.Tensor
|
||||
if m.attentionTemperatureScale != 0.0 {
|
||||
nTokens := len(batch.Positions)
|
||||
scales := make([]float32, nTokens)
|
||||
|
||||
for i, pos := range batch.Positions {
|
||||
posFloat := float64(pos)
|
||||
scaleValue := math.Log(math.Floor((posFloat+1.0)/float64(m.attentionTemperatureFloorScale))+1.0)*float64(m.attentionTemperatureScale) + 1.0
|
||||
scales[i] = float32(scaleValue)
|
||||
}
|
||||
|
||||
attentionScales = ctx.Input().FromFloats(scales, 1, 1, nTokens)
|
||||
}
|
||||
|
||||
for i, layer := range m.Layers {
|
||||
m.Cache.SetLayer(i)
|
||||
|
||||
|
|
@ -341,7 +374,7 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
|||
outputs = batch.Outputs
|
||||
}
|
||||
|
||||
hiddenStates = layer.Forward(ctx, hiddenStates, positions, outputs, m.Cache, m.Options)
|
||||
hiddenStates = layer.Forward(ctx, hiddenStates, positions, attentionScales, outputs, m.Cache, m.Options)
|
||||
}
|
||||
|
||||
hiddenStates = m.OutputNorm.Forward(ctx, hiddenStates, m.eps)
|
||||
|
|
|
|||
Loading…
Reference in New Issue