ml: let model specify rope configuration
Add support for model-specific RoPE configuration parameters by:

1. Creating a new `RopeConfig` struct to encapsulate all RoPE parameters
2. Adding a `RopeType` enum to specify different RoPE variants (Standard/NeoX)
3. Extracting the original context length from the model config
4. Refactoring the `RoPE()` interface to use the new config struct
5. Updating the llama and mllama models to use the new RoPE configuration

This change allows models to specify their RoPE implementation type and original context length, which is important for proper position embedding calculation and model compatibility.
This commit is contained in:
@@ -10,10 +10,10 @@ import (
|
||||
)
|
||||
|
||||
type Options struct {
|
||||
RopeFactors ml.Tensor `gguf:"rope_freqs.weight"`
|
||||
hiddenSize, numHeads, numKVHeads int
|
||||
eps, ropeBase, ropeScale float32
|
||||
ropeDim uint32
|
||||
RopeFactors ml.Tensor `gguf:"rope_freqs.weight"`
|
||||
ctxLen, hiddenSize, numHeads, numKVHeads int
|
||||
eps, ropeBase, ropeScale float32
|
||||
ropeDim uint32
|
||||
}
|
||||
|
||||
type Model struct {
|
||||
@@ -46,6 +46,7 @@ func New(c ml.Config) (model.Model, error) {
|
||||
numHeads: int(c.Uint("attention.head_count")),
|
||||
numKVHeads: int(c.Uint("attention.head_count_kv")),
|
||||
eps: c.Float("attention.layer_norm_rms_epsilon"),
|
||||
ctxLen: int(c.Uint("context_length")),
|
||||
ropeBase: c.Float("rope.freq_base"),
|
||||
ropeScale: c.Float("rope.freq_scale", 1),
|
||||
ropeDim: c.Uint("rope.dimension_count"),
|
||||
@@ -67,14 +68,23 @@ type SelfAttention struct {
|
||||
func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
|
||||
batchSize := hiddenState.Dim(1)
|
||||
headDim := opts.hiddenSize / opts.numHeads
|
||||
rc := ml.RopeConfig{
|
||||
PositionIDs: positionIDs,
|
||||
RopeFactors: opts.RopeFactors,
|
||||
RopeDim: opts.ropeDim,
|
||||
RopeType: ml.RopeTypeStandard,
|
||||
OrigCtxLen: opts.ctxLen,
|
||||
RopeBase: opts.ropeBase,
|
||||
RopeScale: opts.ropeScale,
|
||||
}
|
||||
|
||||
q := sa.Query.Forward(ctx, hiddenState)
|
||||
q = q.Reshape(ctx, headDim, opts.numHeads, batchSize)
|
||||
q = q.RoPE(ctx, positionIDs, opts.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale)
|
||||
q = q.RoPE(ctx, rc)
|
||||
|
||||
k := sa.Key.Forward(ctx, hiddenState)
|
||||
k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
|
||||
k = k.RoPE(ctx, positionIDs, opts.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale)
|
||||
k = k.RoPE(ctx, rc)
|
||||
|
||||
v := sa.Value.Forward(ctx, hiddenState)
|
||||
v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
|
||||
@@ -99,7 +109,18 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
|
||||
}
|
||||
|
||||
func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
|
||||
return key.RoPE(ctx, shift, m.Options.RopeFactors, m.Options.ropeDim, m.Options.ropeBase, m.Options.ropeScale), nil
|
||||
return key.RoPE(
|
||||
ctx,
|
||||
ml.RopeConfig{
|
||||
PositionIDs: shift,
|
||||
RopeFactors: m.Options.RopeFactors,
|
||||
RopeDim: m.Options.ropeDim,
|
||||
RopeType: ml.RopeTypeStandard,
|
||||
OrigCtxLen: m.Options.ctxLen,
|
||||
RopeBase: m.Options.ropeBase,
|
||||
RopeScale: m.Options.ropeScale,
|
||||
},
|
||||
), nil
|
||||
}
|
||||
|
||||
type MLP struct {
|
||||
|
||||
@@ -19,14 +19,23 @@ type TextSelfAttention struct {
|
||||
func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, _ ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
|
||||
batchSize := hiddenState.Dim(1)
|
||||
headDim := opts.hiddenSize / opts.numHeads
|
||||
rc := ml.RopeConfig{
|
||||
PositionIDs: positions,
|
||||
RopeFactors: opts.RopeFactors,
|
||||
RopeDim: opts.ropeDim,
|
||||
RopeType: ml.RopeTypeStandard,
|
||||
OrigCtxLen: opts.ctxLen,
|
||||
RopeBase: opts.ropeBase,
|
||||
RopeScale: opts.ropeScale,
|
||||
}
|
||||
|
||||
query := sa.Query.Forward(ctx, hiddenState)
|
||||
query = query.Reshape(ctx, headDim, opts.numHeads, batchSize)
|
||||
query = query.RoPE(ctx, positions, opts.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale)
|
||||
query = query.RoPE(ctx, rc)
|
||||
|
||||
key := sa.Key.Forward(ctx, hiddenState)
|
||||
key = key.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
|
||||
key = key.RoPE(ctx, positions, opts.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale)
|
||||
key = key.RoPE(ctx, rc)
|
||||
|
||||
value := sa.Value.Forward(ctx, hiddenState)
|
||||
value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
|
||||
@@ -52,7 +61,18 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, _ m
|
||||
|
||||
func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
|
||||
// This will only get called for layers in the cache, which are just the self attention layers
|
||||
return key.RoPE(ctx, shift, m.RopeFactors, m.ropeDim, m.ropeBase, m.ropeScale), nil
|
||||
return key.RoPE(
|
||||
ctx,
|
||||
ml.RopeConfig{
|
||||
PositionIDs: shift,
|
||||
RopeFactors: m.RopeFactors,
|
||||
RopeDim: m.ropeDim,
|
||||
RopeType: ml.RopeTypeStandard,
|
||||
OrigCtxLen: m.ctxLen,
|
||||
RopeBase: m.ropeBase,
|
||||
RopeScale: m.ropeScale,
|
||||
},
|
||||
), nil
|
||||
}
|
||||
|
||||
type TextMLP struct {
|
||||
@@ -189,9 +209,9 @@ func (d *TextDecoder) Forward(ctx ml.Context, hiddenState, positionIDs, mask, cr
|
||||
type TextModelOptions struct {
|
||||
RopeFactors ml.Tensor `gguf:"rope_freqs.weight"`
|
||||
|
||||
hiddenSize, numHeads, numKVHeads int
|
||||
eps, ropeBase, ropeScale float32
|
||||
ropeDim uint32
|
||||
ctxLen, hiddenSize, numHeads, numKVHeads int
|
||||
eps, ropeBase, ropeScale float32
|
||||
ropeDim uint32
|
||||
|
||||
crossAttentionLayers []uint32
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user