ml: let model specify rope configuration

Add support for model-specific RoPE configuration parameters by:

1. Creating a new `RopeConfig` struct to encapsulate all RoPE parameters
2. Adding `RopeType` enum to specify different RoPE variants (Standard/NeoX)
3. Reading the context length (`context_length`) from the model config and passing it to RoPE as the original context length
4. Refactoring `RoPE()` interface to use the new config struct
5. Updating llama and mllama models to use new RoPE configuration

This change allows models to specify their RoPE implementation type and
original context length, which is important for proper position embedding
calculation and model compatibility.
This commit is contained in:
Bruce MacDonald
2025-02-14 14:11:36 -08:00
parent 010313bb63
commit eb086514da
5 changed files with 104 additions and 28 deletions

View File

@@ -10,10 +10,10 @@ import (
)
// Options holds the text-model hyperparameters read by the attention code.
// NOTE(review): this listing is a rendered diff without +/- markers, so the
// field list appears twice; presumably the first four field lines are the
// removed version and the last four (which add ctxLen) are the replacement.
type Options struct {
// RopeFactors carries the gguf tag "rope_freqs.weight".
RopeFactors ml.Tensor `gguf:"rope_freqs.weight"`
hiddenSize, numHeads, numKVHeads int
eps, ropeBase, ropeScale float32
ropeDim uint32
RopeFactors ml.Tensor `gguf:"rope_freqs.weight"`
// ctxLen is filled from the "context_length" config key in New and is
// passed as RopeConfig.OrigCtxLen at the RoPE call sites.
// NOTE(review): confirm that "context_length" is the intended source for
// the original (pre-scaling) context length.
ctxLen, hiddenSize, numHeads, numKVHeads int
eps, ropeBase, ropeScale float32
ropeDim uint32
}
type Model struct {
@@ -46,6 +46,7 @@ func New(c ml.Config) (model.Model, error) {
numHeads: int(c.Uint("attention.head_count")),
numKVHeads: int(c.Uint("attention.head_count_kv")),
eps: c.Float("attention.layer_norm_rms_epsilon"),
ctxLen: int(c.Uint("context_length")),
ropeBase: c.Float("rope.freq_base"),
ropeScale: c.Float("rope.freq_scale", 1),
ropeDim: c.Uint("rope.dimension_count"),
@@ -67,14 +68,23 @@ type SelfAttention struct {
func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
batchSize := hiddenState.Dim(1)
headDim := opts.hiddenSize / opts.numHeads
rc := ml.RopeConfig{
PositionIDs: positionIDs,
RopeFactors: opts.RopeFactors,
RopeDim: opts.ropeDim,
RopeType: ml.RopeTypeStandard,
OrigCtxLen: opts.ctxLen,
RopeBase: opts.ropeBase,
RopeScale: opts.ropeScale,
}
q := sa.Query.Forward(ctx, hiddenState)
q = q.Reshape(ctx, headDim, opts.numHeads, batchSize)
q = q.RoPE(ctx, positionIDs, opts.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale)
q = q.RoPE(ctx, rc)
k := sa.Key.Forward(ctx, hiddenState)
k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
k = k.RoPE(ctx, positionIDs, opts.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale)
k = k.RoPE(ctx, rc)
v := sa.Value.Forward(ctx, hiddenState)
v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
@@ -99,7 +109,18 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
}
// Shift applies RoPE to key, using shift as the position IDs and the RoPE
// parameters stored in m.Options. It always returns a nil error. The layer
// argument is not used in the visible body.
// NOTE(review): this listing is a rendered diff without +/- markers; the
// first return below is presumably the removed positional-argument call and
// the second is its RopeConfig-based replacement.
func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
return key.RoPE(ctx, shift, m.Options.RopeFactors, m.Options.ropeDim, m.Options.ropeBase, m.Options.ropeScale), nil
return key.RoPE(
ctx,
ml.RopeConfig{
// The shift tensor takes the place of the regular position IDs.
PositionIDs: shift,
RopeFactors: m.Options.RopeFactors,
RopeDim: m.Options.ropeDim,
RopeType: ml.RopeTypeStandard,
OrigCtxLen: m.Options.ctxLen,
RopeBase: m.Options.ropeBase,
RopeScale: m.Options.ropeScale,
},
), nil
}
type MLP struct {

View File

@@ -19,14 +19,23 @@ type TextSelfAttention struct {
func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, _ ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
batchSize := hiddenState.Dim(1)
headDim := opts.hiddenSize / opts.numHeads
rc := ml.RopeConfig{
PositionIDs: positions,
RopeFactors: opts.RopeFactors,
RopeDim: opts.ropeDim,
RopeType: ml.RopeTypeStandard,
OrigCtxLen: opts.ctxLen,
RopeBase: opts.ropeBase,
RopeScale: opts.ropeScale,
}
query := sa.Query.Forward(ctx, hiddenState)
query = query.Reshape(ctx, headDim, opts.numHeads, batchSize)
query = query.RoPE(ctx, positions, opts.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale)
query = query.RoPE(ctx, rc)
key := sa.Key.Forward(ctx, hiddenState)
key = key.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
key = key.RoPE(ctx, positions, opts.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale)
key = key.RoPE(ctx, rc)
value := sa.Value.Forward(ctx, hiddenState)
value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
@@ -52,7 +61,18 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, _ m
// Shift applies RoPE to key, using shift as the position IDs and the RoPE
// parameters read from m. It always returns a nil error. The layer argument
// is not used in the visible body.
// NOTE(review): this listing is a rendered diff without +/- markers; the
// first return below is presumably the removed positional-argument call and
// the second is its RopeConfig-based replacement.
func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
// This will only get called for layers in the cache, which are just the self attention layers
return key.RoPE(ctx, shift, m.RopeFactors, m.ropeDim, m.ropeBase, m.ropeScale), nil
return key.RoPE(
ctx,
ml.RopeConfig{
// The shift tensor takes the place of the regular position IDs.
PositionIDs: shift,
RopeFactors: m.RopeFactors,
RopeDim: m.ropeDim,
RopeType: ml.RopeTypeStandard,
OrigCtxLen: m.ctxLen,
RopeBase: m.ropeBase,
RopeScale: m.ropeScale,
},
), nil
}
type TextMLP struct {
@@ -189,9 +209,9 @@ func (d *TextDecoder) Forward(ctx ml.Context, hiddenState, positionIDs, mask, cr
// TextModelOptions holds the text-model hyperparameters read by the
// attention code.
// NOTE(review): this listing is a rendered diff without +/- markers, so the
// scalar field lines appear twice; presumably the first three are the removed
// version and the next three (which add ctxLen) are the replacement.
type TextModelOptions struct {
// RopeFactors carries the gguf tag "rope_freqs.weight".
RopeFactors ml.Tensor `gguf:"rope_freqs.weight"`
hiddenSize, numHeads, numKVHeads int
eps, ropeBase, ropeScale float32
ropeDim uint32
// ctxLen is passed as RopeConfig.OrigCtxLen at the RoPE call sites.
ctxLen, hiddenSize, numHeads, numKVHeads int
eps, ropeBase, ropeScale float32
ropeDim uint32
crossAttentionLayers []uint32
}