improvements

ParthSareen 2025-12-09 22:26:30 -08:00
parent 3015146cda
commit a613eca69c
2 changed files with 36 additions and 23 deletions


@@ -29,6 +29,9 @@ type Options struct {
ropeType string
ropeExtrapolation float32
+ropeBetaFast float32
+ropeBetaSlow float32
slidingWindowPattern []bool
}
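
For context, ropeBetaFast and ropeBetaSlow are the YaRN ramp parameters: rotary dimensions that complete at least beta_fast rotations over the original context window are extrapolated unchanged, dimensions completing fewer than beta_slow rotations are fully interpolated, and a linear ramp blends the two in between. A self-contained sketch of the conventional correction-dimension formula from the YaRN paper (illustrative only, not code from this repo):

package main

import (
	"fmt"
	"math"
)

// correctionDim returns the rotary dimension index whose frequency completes
// numRotations full turns over the original context length origCtx.
func correctionDim(numRotations float64, dim int, base float64, origCtx int) float64 {
	return float64(dim) * math.Log(float64(origCtx)/(numRotations*2*math.Pi)) / (2 * math.Log(base))
}

func main() {
	dim, base, origCtx := 128, 1e4, 8192 // hypothetical model dimensions
	betaFast, betaSlow := 32.0, 1.0      // the defaults hardcoded later in this commit
	low := math.Max(math.Floor(correctionDim(betaFast, dim, base, origCtx)), 0)
	high := math.Min(math.Ceil(correctionDim(betaSlow, dim, base, origCtx)), float64(dim-1))
	fmt.Printf("YaRN ramp spans rotary dims [%v, %v]\n", low, high)
}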
@@ -66,7 +69,7 @@ func New(c fs.Config) (model.Model, error) {
var pretokenizers []string
if c.String("tokenizer.ggml.pre") != "default" {
pretokenizers = []string{
-`(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+`,
+"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
}
}
processor := model.NewBytePairEncoding(&vocabulary, pretokenizers...)
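
The quoting change above appears to be a correctness fix, not a style tweak: in a Go raw string literal (backticks), \\r stays as two characters, so the regexp engine would match a literal backslash followed by r, whereas the interpreted literal "\\r" becomes \r, an actual carriage-return escape. A minimal demonstration with the standard library:

package main

import (
	"fmt"
	"regexp"
)

func main() {
	interp := regexp.MustCompile("\\r") // pattern \r: matches a carriage return
	raw := regexp.MustCompile(`\\r`)    // pattern \\r: matches the literal text \r

	s := "a\rb"
	fmt.Println(interp.MatchString(s))   // true
	fmt.Println(raw.MatchString(s))      // false
	fmt.Println(raw.MatchString(`a\rb`)) // true: backslash then r
}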
@@ -74,8 +77,6 @@ func New(c fs.Config) (model.Model, error) {
hiddenSize := int(c.Uint("embedding_length"))
numHeads := int(c.Uint("attention.head_count"))
numKVHeads := int(c.Uint("attention.head_count_kv"))
-// headDim := int(c.Uint("attention.head_count"))
-// ropeDim := int(c.Uint("rope.dimension_count"))
eps := c.Float("attention.layer_norm_rms_epsilon")
ropeBase := c.Float("rope.freq_base", 1e4)
ropeScale := c.Float("rope.scaling.factor", 1)
@@ -87,8 +88,6 @@ func New(c fs.Config) (model.Model, error) {
fmt.Printf("hiddenSize: %d\n", hiddenSize)
fmt.Printf("numHeads: %d\n", numHeads)
fmt.Printf("numKVHeads: %d\n", numKVHeads)
// fmt.Printf("headDim: %d\n", headDim)
// fmt.Printf("ropeDim: %d\n", ropeDim)
fmt.Printf("eps: %f\n", eps)
fmt.Printf("ropeBase: %f\n", ropeBase)
fmt.Printf("ropeScale: %f\n", ropeScale)
@@ -102,11 +101,9 @@ func New(c fs.Config) (model.Model, error) {
TextProcessor: processor,
Layers: make([]Layer, c.Uint("block_count")),
Options: Options{
-hiddenSize: hiddenSize,
-numHeads: numHeads,
-numKVHeads: numKVHeads,
-// headDim: headDim,
-// ropeDim: ropeDim,
+hiddenSize: hiddenSize,
+numHeads: numHeads,
+numKVHeads: numKVHeads,
eps: eps,
ropeBase: ropeBase,
ropeScale: ropeScale,
@@ -114,6 +111,8 @@ func New(c fs.Config) (model.Model, error) {
attnFactor: attnFactor,
ropeType: ropeType,
ropeExtrapolation: ropeExtrapolation,
+ropeBetaFast: 32.0,
+ropeBetaSlow: 1.0,
slidingWindowPattern: c.Bools("attention.sliding_window_pattern"),
},
}
@@ -141,19 +140,31 @@ func (m *Model) applyRoPE(ctx ml.Context, states, positions ml.Tensor, ropeDim i
var ropeOpts []func(*rope.Options)
-// Both SWA and non-SWA use beta_fast and beta_slow
-// But SWA uses freq_scale=1.0, ext_factor=0.0, attn_factor=1.0
+// defaults
+ropeOpts = append(ropeOpts,
+rope.WithBetaFast(m.ropeBetaFast),
+rope.WithBetaSlow(m.ropeBetaSlow),
+)
+// SWA uses freq_scale=1.0, ext_factor=0.0, attn_factor=1.0
+// Non-SWA uses the full YaRN parameters
if m.originalContextLength > 0 {
ropeOpts = append(ropeOpts,
rope.WithOriginalContextLength(m.originalContextLength),
)
-// if !isSWA {
-ropeOpts = append(ropeOpts,
-rope.WithExtrapolationFactor(m.ropeExtrapolation),
-)
-// rope.WithAttentionFactor(m.attnFactor),
-// )
+// no YaRN for SWA layers
+if isSWA {
+ropeOpts = append(ropeOpts,
+rope.WithExtrapolationFactor(0),
+rope.WithAttentionFactor(1.0),
+)
+} else {
+ropeOpts = append(ropeOpts,
+rope.WithExtrapolationFactor(m.ropeExtrapolation),
+rope.WithAttentionFactor(m.attnFactor),
+)
+}
}
freqScale := float32(1.0)
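
ropeOpts above is a slice of functional options, each a closure that mutates a rope.Options value when the rope operation is built. A minimal sketch of the pattern (field and constructor names here are illustrative assumptions, not this repo's actual rope package):

package main

import "fmt"

// opts mirrors the shape implied by the calls above; the exact field set is assumed.
type opts struct {
	BetaFast, BetaSlow, ExtrapolationFactor, AttentionFactor float32
	OriginalContextLength                                    uint32
}

func withBetaFast(v float32) func(*opts) { return func(o *opts) { o.BetaFast = v } }
func withBetaSlow(v float32) func(*opts) { return func(o *opts) { o.BetaSlow = v } }

// build applies the options over defaults, as the rope package presumably does internally.
func build(fns ...func(*opts)) opts {
	o := opts{AttentionFactor: 1}
	for _, fn := range fns {
		fn(&o)
	}
	return o
}

func main() {
	fmt.Printf("%+v\n", build(withBetaFast(32), withBetaSlow(1)))
}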
@@ -221,12 +232,13 @@ func (l *Layer) Forward(ctx ml.Context, hiddenState, positions, outputs ml.Tenso
residual := hiddenState
hiddenState = l.SelfAttention.Forward(ctx, hiddenState, positions, cache, m, isSWA)
-hiddenState = l.PostAttentionNorm.Forward(ctx, hiddenState, m.eps)
if outputs != nil {
hiddenState = hiddenState.Rows(ctx, outputs)
residual = residual.Rows(ctx, outputs)
}
+// apply the post-attention norm after selecting the output rows
+hiddenState = l.PostAttentionNorm.Forward(ctx, hiddenState, m.eps)
hiddenState = hiddenState.Add(ctx, residual)
residual = hiddenState
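
Reordering the norm after the row selection is sound provided PostAttentionNorm is row-wise (the m.eps argument suggests an RMSNorm): a per-row norm commutes with selecting a subset of rows, so norm(rows(x)) equals rows(norm(x)) while touching fewer rows, and slicing residual with the same outputs keeps the residual add aligned. A self-contained check of that commutation, assuming a plain RMSNorm:

package main

import (
	"fmt"
	"math"
)

// rmsNorm normalizes one row: x[i] / sqrt(mean(x^2) + eps).
func rmsNorm(row []float64, eps float64) []float64 {
	var ss float64
	for _, v := range row {
		ss += v * v
	}
	scale := 1 / math.Sqrt(ss/float64(len(row))+eps)
	out := make([]float64, len(row))
	for i, v := range row {
		out[i] = v * scale
	}
	return out
}

func main() {
	x := [][]float64{{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}
	keep := []int{0, 2} // analogous to hiddenState.Rows(ctx, outputs)

	normed := make([][]float64, len(x)) // norm first, then select
	for i, row := range x {
		normed[i] = rmsNorm(row, 1e-6)
	}
	for _, i := range keep {
		fmt.Println(rmsNorm(x[i], 1e-6)) // select first, then norm
		fmt.Println(normed[i])           // identical output
	}
}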
@@ -256,11 +268,12 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
cacheType = cacheTypeCausal
}
-cache := m.Cache.(*kvcache.WrapperCache)
-cache.SetLayerType(cacheType)
+wc := m.Cache.(*kvcache.WrapperCache)
+wc.SetLayerType(cacheType)
// would need to check the cache at the layer level instead
-if causal, ok := cache.UnderlyingCache().(*kvcache.Causal); ok {
-causal.SetCausal(ctx, kvcache.CausalOptions{Except: []int{i}})
+if causal, ok := wc.UnderlyingCache().(*kvcache.Causal); ok {
+// TODO: not sure about the index here
+causal.SetCausal(ctx, kvcache.CausalOptions{Except: []int{}})
}
var outputs ml.Tensor
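
The per-layer cache type chosen above follows slidingWindowPattern. A sketch of how such a per-layer pattern is commonly indexed (whether the pattern repeats when shorter than the block count is an assumption here, not something this diff confirms):

package main

import "fmt"

// layerUsesSWA reports whether layer i uses sliding-window attention under
// a repeating pattern; the modulo repetition is an assumed convention.
func layerUsesSWA(pattern []bool, i int) bool {
	if len(pattern) == 0 {
		return false
	}
	return pattern[i%len(pattern)]
}

func main() {
	pattern := []bool{true, true, true, true, true, false} // e.g. five SWA layers per full-attention layer
	for i := 0; i < 12; i++ {
		fmt.Printf("layer %2d SWA=%v\n", i, layerUsesSWA(pattern, i))
	}
}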


@@ -107,7 +107,7 @@ func TestTokenization(t *testing.T) {
prompt := args.prompt
if prompt == "" {
prompt = "hello"
prompt = "hello, how are you?"
}
tp := m.(model.TextProcessor)