diff --git a/ml/backend.go b/ml/backend.go index 4d930fe43..6e5a059ad 100644 --- a/ml/backend.go +++ b/ml/backend.go @@ -233,8 +233,10 @@ type Tensor interface { // // kqv := value.Mulmat(ctx, kq) // return kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx) +// +// cacheConfigApplied indicates whether the optimizations requested through CacheConfig have been performed type ScaledDotProductAttention interface { - ScaledDotProductAttention(ctx Context, key, value, mask, sinks Tensor, vmla Tensor, scale float64) Tensor + ScaledDotProductAttention(ctx Context, key, value, mask, sinks Tensor, vmla Tensor, scale float64, cacheConfigApplied bool) Tensor } type number interface { diff --git a/ml/backend/ggml/ggml.go b/ml/backend/ggml/ggml.go index f1a19e0bd..18bdc91eb 100644 --- a/ml/backend/ggml/ggml.go +++ b/ml/backend/ggml/ggml.go @@ -1645,7 +1645,29 @@ func (t *Tensor) AvgPool2D(ctx ml.Context, k, s int, p float32) ml.Tensor { } } -func (t *Tensor) ScaledDotProductAttention(ctx ml.Context, key, value, mask, sinks ml.Tensor, vmla ml.Tensor, scale float64) ml.Tensor { +func (t *Tensor) ScaledDotProductAttention(ctx ml.Context, key, value, mask, sinks ml.Tensor, vmla ml.Tensor, scale float64, cacheConfigApplied bool) ml.Tensor { + // If the cache didn't help us with required transformations, do them here + if !cacheConfigApplied { + cacheConfig := t.b.CacheConfig() + + // Padding key and value to CachePadding is a performance optimization, not a requirement, so we don't do it if it wasn't done by the caller + + if cacheConfig.PermutedV { + value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx) + } + + if mask != nil { + padSize := int(pad(C.size_t(mask.Dim(1)), C.size_t(cacheConfig.MaskBatchPadding))) - mask.Dim(1) + if padSize > 0 { + mask = mask.Pad(ctx, 0, padSize, 0, 0) + } + + if mask.DType() != cacheConfig.MaskDType { + mask = mask.Cast(ctx, cacheConfig.MaskDType) + } + } + } + var kqMask *C.struct_ggml_tensor if mask != nil { kqMask = mask.(*Tensor).t diff --git a/ml/nn/attention.go b/ml/nn/attention.go index 123ae5378..e495e1f60 100644 --- a/ml/nn/attention.go +++ b/ml/nn/attention.go @@ -57,10 +57,9 @@ func AttentionWithVMLA(ctx ml.Context, query, key, value, sinks ml.Tensor, vmla key, value, mask = cache.Get(ctx) } - // Only use the fast SDPA implementation if we have a cache, since that's what - // will do any expected backend-specific transformations for us - if sdpa, ok := query.(ml.ScaledDotProductAttention); ok && cache != nil { - return sdpa.ScaledDotProductAttention(ctx, key, value, mask, sinks, vmla, scale) + if sdpa, ok := query.(ml.ScaledDotProductAttention); ok { + cacheConfigApplied := cache != nil + return sdpa.ScaledDotProductAttention(ctx, key, value, mask, sinks, vmla, scale, cacheConfigApplied) } else { query = query.Permute(ctx, 0, 2, 1, 3) key = key.Permute(ctx, 0, 2, 1, 3)