gemma2: enable flash attention
parent f49797fbdb
commit 89637ae43b
```diff
@@ -826,10 +826,6 @@ func (f GGML) SupportsFlashAttention() bool {
 		return false
 	}
 
-	if arch := f.KV().Architecture(); slices.Contains([]string{"gemma2"}, arch) {
-		return false
-	}
-
 	// Check head counts match and are non-zero
 	headCountK := f.KV().EmbeddingHeadCountK()
 	headCountV := f.KV().EmbeddingHeadCountV()
```
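With the gemma2 blacklist removed, flash attention support is decided by the generic conditions alone. For context, here is a sketch of how the whole function plausibly reads after this commit; only the removed block above is confirmed by the diff, while the leading guard (`supportedByBackend` is a hypothetical placeholder) and the final comparison are assumptions drawn from the visible context lines:

```go
// Sketch, not verbatim source: the hunk above confirms only that the
// gemma2 early return is gone; the rest is inferred from context.
func (f GGML) SupportsFlashAttention() bool {
	// Some preceding condition (visible only as trailing context in the
	// hunk) rules out unsupported models first.
	if !f.supportedByBackend() { // hypothetical helper, for illustration
		return false
	}

	// Check head counts match and are non-zero: flash attention kernels
	// require equal, non-zero K and V embedding head sizes.
	headCountK := f.KV().EmbeddingHeadCountK()
	headCountV := f.KV().EmbeddingHeadCountV()
	return headCountK != 0 && headCountK == headCountV
}
```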
```diff
@@ -73,9 +73,10 @@ func New(c fs.Config) (model.Model, error) {
 		},
 	}
 
-	slidingWindowLen := int32(c.Uint("attention.sliding_window"))
-	m.Cache = kvcache.NewWrapperCache(kvcache.NewSWACache(slidingWindowLen, m.Shift), kvcache.NewCausalCache(m.Shift))
-	m.Cache.SetConfig(ml.CacheConfig{})
+	m.Cache = kvcache.NewWrapperCache(
+		kvcache.NewSWACache(int32(c.Uint("attention.sliding_window")), m.Shift),
+		kvcache.NewCausalCache(m.Shift),
+	)
 
 	return &m, nil
 }
```
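The cache construction is reshaped at the same time: the sliding-window length is inlined, the builder call is split across lines, and the explicit `m.Cache.SetConfig(ml.CacheConfig{})` call is dropped, presumably so the backend-requested cache configuration applies now that flash attention is a possibility, rather than being reset to zero values. To illustrate the wrapper pattern itself (gemma2 interleaves local sliding-window attention with global attention, so each layer needs the matching cache), here is a self-contained sketch with stand-in types; none of this is ollama's real `kvcache` API, and the even/odd layer split is an assumption for the demo:

```go
package main

import "fmt"

// Stand-ins for the two cache flavors; illustrative only.
type cache interface{ kind() string }

type swaCache struct{ window int32 }

func (c swaCache) kind() string { return fmt.Sprintf("sliding-window(%d)", c.window) }

type causalCache struct{}

func (causalCache) kind() string { return "causal" }

// wrapperCache bundles both so each layer can use the cache that matches
// its attention type.
type wrapperCache struct{ caches []cache }

func (w wrapperCache) forLayer(i int) cache { return w.caches[i%len(w.caches)] }

func main() {
	w := wrapperCache{caches: []cache{swaCache{window: 4096}, causalCache{}}}
	for layer := 0; layer < 4; layer++ {
		fmt.Printf("layer %d -> %s\n", layer, w.forLayer(layer).kind())
	}
}
```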