gemma2: enable flash attention

This commit is contained in:
Michael Yang 2025-12-15 11:31:42 -08:00
parent f49797fbdb
commit 89637ae43b
2 changed files with 4 additions and 7 deletions

View File

@ -826,10 +826,6 @@ func (f GGML) SupportsFlashAttention() bool {
return false return false
} }
if arch := f.KV().Architecture(); slices.Contains([]string{"gemma2"}, arch) {
return false
}
// Check head counts match and are non-zero // Check head counts match and are non-zero
headCountK := f.KV().EmbeddingHeadCountK() headCountK := f.KV().EmbeddingHeadCountK()
headCountV := f.KV().EmbeddingHeadCountV() headCountV := f.KV().EmbeddingHeadCountV()

View File

@ -73,9 +73,10 @@ func New(c fs.Config) (model.Model, error) {
}, },
} }
slidingWindowLen := int32(c.Uint("attention.sliding_window")) m.Cache = kvcache.NewWrapperCache(
m.Cache = kvcache.NewWrapperCache(kvcache.NewSWACache(slidingWindowLen, m.Shift), kvcache.NewCausalCache(m.Shift)) kvcache.NewSWACache(int32(c.Uint("attention.sliding_window")), m.Shift),
m.Cache.SetConfig(ml.CacheConfig{}) kvcache.NewCausalCache(m.Shift),
)
return &m, nil return &m, nil
} }