diff --git a/fs/ggml/ggml.go b/fs/ggml/ggml.go index 56614a321..3baa7408f 100644 --- a/fs/ggml/ggml.go +++ b/fs/ggml/ggml.go @@ -826,10 +826,6 @@ func (f GGML) SupportsFlashAttention() bool { return false } - if arch := f.KV().Architecture(); slices.Contains([]string{"gemma2"}, arch) { - return false - } - // Check head counts match and are non-zero headCountK := f.KV().EmbeddingHeadCountK() headCountV := f.KV().EmbeddingHeadCountV() diff --git a/model/models/gemma2/model.go b/model/models/gemma2/model.go index 25daa3795..83ba6c719 100644 --- a/model/models/gemma2/model.go +++ b/model/models/gemma2/model.go @@ -73,9 +73,10 @@ func New(c fs.Config) (model.Model, error) { }, } - slidingWindowLen := int32(c.Uint("attention.sliding_window")) - m.Cache = kvcache.NewWrapperCache(kvcache.NewSWACache(slidingWindowLen, m.Shift), kvcache.NewCausalCache(m.Shift)) - m.Cache.SetConfig(ml.CacheConfig{}) + m.Cache = kvcache.NewWrapperCache( + kvcache.NewSWACache(int32(c.Uint("attention.sliding_window")), m.Shift), + kvcache.NewCausalCache(m.Shift), + ) return &m, nil }