diff --git a/ml/backend.go b/ml/backend.go
index 1e781fa7f..620f29d81 100644
--- a/ml/backend.go
+++ b/ml/backend.go
@@ -33,7 +33,7 @@ type Backend interface {
 
 // BackendCacheConfig should be implemented by backends that need special output
 // from the cache to meet specific requirements. It is frequently implemented in
-// conjunction with ScaledDotProductAttention.
+// conjunction with [nn.fastAttention].
 type BackendCacheConfig interface {
 	CacheConfig() CacheConfig
 }
@@ -152,7 +152,6 @@ type Tensor interface {
 	Div(ctx Context, t2 Tensor) Tensor
 
 	Mulmat(ctx Context, t2 Tensor) Tensor
-	MulmatFullPrec(ctx Context, t2 Tensor) Tensor
 	MulmatID(ctx Context, t2, ids Tensor) Tensor
 	AddID(ctx Context, t2, ids Tensor) Tensor
 
@@ -213,32 +212,6 @@ type Tensor interface {
 	Interpolate(ctx Context, dims [4]int, samplingMode SamplingMode) Tensor
 }
 
-// ScaledDotProductAttention implements a fused attention
-// operation equivalent to following code on a tensor named
-// query:
-//
-//	query = query.Permute(ctx, 0, 2, 1, 3)
-//	key = key.Permute(ctx, 0, 2, 1, 3)
-//	value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
-//
-//	kq := key.MulmatFullPrec(ctx, query)
-//
-//	kq = kq.Scale(ctx, scale)
-//
-//	if mask != nil {
-//		kq = kq.Add(ctx, mask)
-//	}
-//
-//	kq = kq.Softmax(ctx)
-//
-//	kqv := value.Mulmat(ctx, kq)
-//	return kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
-//
-// cacheConfigApplied indicates whether the optimizations requested through CacheConfig have been performed
-type ScaledDotProductAttention interface {
-	ScaledDotProductAttention(ctx Context, key, value, mask, sinks Tensor, vmla Tensor, scale float64, cacheConfigApplied bool) Tensor
-}
-
 type number interface {
 	~int | ~int8 | ~int16 | ~int32 | ~int64 |
 		~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64 |
diff --git a/ml/backend/ggml/ggml.go b/ml/backend/ggml/ggml.go
index 424e0796a..919203549 100644
--- a/ml/backend/ggml/ggml.go
+++ b/ml/backend/ggml/ggml.go
@@ -1250,16 +1250,6 @@ func (t *Tensor) Mulmat(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
 	}
 }
 
-func (t *Tensor) MulmatFullPrec(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
-	mul := C.ggml_mul_mat(ctx.(*Context).ctx, t.t, t2.(*Tensor).t)
-	C.ggml_mul_mat_set_prec(mul, C.GGML_PREC_F32)
-
-	return &Tensor{
-		b: t.b,
-		t: mul,
-	}
-}
-
 func (t *Tensor) MulmatID(ctx ml.Context, t2, ids ml.Tensor) ml.Tensor {
 	return &Tensor{
 		b: t.b,
@@ -1650,75 +1640,6 @@ func (t *Tensor) AvgPool2D(ctx ml.Context, k, s int, p float32) ml.Tensor {
 	}
 }
 
-func (t *Tensor) ScaledDotProductAttention(ctx ml.Context, key, value, mask, sinks ml.Tensor, vmla ml.Tensor, scale float64, cacheConfigApplied bool) ml.Tensor {
-	// If the cache didn't help us with required transformations, do them here
-	if !cacheConfigApplied {
-		cacheConfig := t.b.CacheConfig()
-
-		// Padding key and value to CachePadding is a performance optimization, not a requirement, so we don't do it if it wasn't done by the caller
-
-		if cacheConfig.PermutedV {
-			value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
-		}
-
-		if mask != nil {
-			padSize := int(pad(C.size_t(mask.Dim(1)), C.size_t(cacheConfig.MaskBatchPadding))) - mask.Dim(1)
-			if padSize > 0 {
-				mask = mask.Pad(ctx, 0, padSize, 0, 0)
-			}
-
-			if mask.DType() != cacheConfig.MaskDType {
-				mask = mask.Cast(ctx, cacheConfig.MaskDType)
-			}
-		}
-	}
-
-	var kqMask *C.struct_ggml_tensor
-	if mask != nil {
-		kqMask = mask.(*Tensor).t
-	}
-
-	query := t.Permute(ctx, 0, 2, 1, 3)
-	key = key.Permute(ctx, 0, 2, 1, 3)
-
-	if t.b.flashAttention == ml.FlashAttentionEnabled {
-		value = value.Permute(ctx, 0, 2, 1, 3)
-
-		kqv := C.ggml_flash_attn_ext(ctx.(*Context).ctx, query.(*Tensor).t, key.(*Tensor).t, value.(*Tensor).t, kqMask, C.float(scale), 0, 0)
-		if sinks != nil {
-			C.ggml_flash_attn_ext_add_sinks(kqv, sinks.(*Tensor).t)
-		}
-		C.ggml_flash_attn_ext_set_prec(kqv, C.GGML_PREC_F32)
-
-		if vmla != nil {
-			var cur ml.Tensor = &Tensor{b: t.b, t: kqv}
-			cur = cur.Permute(ctx, 0, 2, 1, 3)
-			cur = vmla.Mulmat(ctx, cur)
-			cur = cur.Permute(ctx, 0, 2, 1, 3)
-			cur = cur.Contiguous(ctx)
-			kqv = cur.(*Tensor).t
-		}
-
-		return &Tensor{b: t.b, t: kqv}
-	} else {
-		kq := key.MulmatFullPrec(ctx, query)
-		kq = &Tensor{
-			b: t.b,
-			t: C.ggml_soft_max_ext(ctx.(*Context).ctx, kq.(*Tensor).t, kqMask, C.float(scale), 0),
-		}
-		if sinks != nil {
-			C.ggml_soft_max_add_sinks(kq.(*Tensor).t, sinks.(*Tensor).t)
-		}
-
-		kqv := value.Mulmat(ctx, kq)
-		if vmla != nil {
-			kqv = vmla.Mulmat(ctx, kqv)
-		}
-
-		return kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
-	}
-}
-
 func (t *Tensor) Duplicate(ctx ml.Context) ml.Tensor {
 	return &Tensor{
 		b: t.b,
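For reference, the doc comment deleted above described the fused attention op as equivalent to an unfused sequence of tensor ops. The sketch below restates that sequence against the ml.Tensor methods that survive this diff, using Mulmat in place of the removed MulmatFullPrec; the function name, package clause, and import path are assumptions for illustration, not part of the change.

```go
package sketch

import "github.com/ollama/ollama/ml" // assumed import path for the ml package edited above

// unfusedAttention is a hypothetical reference for the sequence the removed
// ScaledDotProductAttention doc comment described: permute Q/K/V, compute K^T*Q,
// scale, add the optional mask, softmax, multiply by V, and restore the layout.
// It uses Mulmat rather than the deleted MulmatFullPrec, so the scores matmul
// runs at the backend's default precision.
func unfusedAttention(ctx ml.Context, query, key, value, mask ml.Tensor, scale float64) ml.Tensor {
	query = query.Permute(ctx, 0, 2, 1, 3)
	key = key.Permute(ctx, 0, 2, 1, 3)
	value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)

	// Attention scores, scaled and optionally masked, then normalized.
	kq := key.Mulmat(ctx, query)
	kq = kq.Scale(ctx, scale)
	if mask != nil {
		kq = kq.Add(ctx, mask)
	}
	kq = kq.Softmax(ctx)

	// Weighted sum over values, permuted back to the caller's layout.
	kqv := value.Mulmat(ctx, kq)
	return kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
}
```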