gemma2: enable flash attention
This commit is contained in:
parent
f49797fbdb
commit
89637ae43b
|
|
@@ -826,10 +826,6 @@ func (f GGML) SupportsFlashAttention() bool {
 		return false
 	}
 
-	if arch := f.KV().Architecture(); slices.Contains([]string{"gemma2"}, arch) {
-		return false
-	}
-
 	// Check head counts match and are non-zero
 	headCountK := f.KV().EmbeddingHeadCountK()
 	headCountV := f.KV().EmbeddingHeadCountV()
@@ -73,9 +73,10 @@ func New(c fs.Config) (model.Model, error) {
 		},
 	}
 
-	slidingWindowLen := int32(c.Uint("attention.sliding_window"))
-	m.Cache = kvcache.NewWrapperCache(kvcache.NewSWACache(slidingWindowLen, m.Shift), kvcache.NewCausalCache(m.Shift))
-	m.Cache.SetConfig(ml.CacheConfig{})
+	m.Cache = kvcache.NewWrapperCache(
+		kvcache.NewSWACache(int32(c.Uint("attention.sliding_window")), m.Shift),
+		kvcache.NewCausalCache(m.Shift),
+	)
 
 	return &m, nil
 }
Loading…
Reference in New Issue