cleanup attention interface

the updated interface supports variadic attention options which
removes the need for individual `AttentionWith...` functions. it means
more models can use the attention interface, e.g. models with
custom masks, logit softcapping, etc.

additionally, this interface should be less error prone since there are
now reasonable defaults for all optional parameters
This commit is contained in:
Michael Yang 2025-11-18 17:06:24 -08:00
parent 89eb795293
commit de82b1f9a3
28 changed files with 222 additions and 162 deletions

View File

@ -19,6 +19,7 @@ import (
"io"
"log/slog"
"maps"
"math"
"os"
"runtime"
"slices"
@ -35,6 +36,7 @@ import (
"github.com/ollama/ollama/logutil"
"github.com/ollama/ollama/ml"
ggml "github.com/ollama/ollama/ml/backend/ggml/ggml/src"
"github.com/ollama/ollama/ml/nn/attention"
"github.com/ollama/ollama/ml/nn/rope"
"golang.org/x/sync/errgroup"
)
@ -1849,3 +1851,89 @@ func (t *Tensor) ChunkSections(ctx ml.Context, dim int, sections ...int) []ml.Te
}
return s
}
func (t *Tensor) SDPA(ctx ml.Context, key, value ml.Tensor, fns ...func(*attention.Options)) ml.Tensor {
opts := attention.Options{
Scale: 1 / math.Sqrt(float64(t.Dim(0))),
}
for _, fn := range fns {
fn(&opts)
}
if !opts.Cached {
config := t.b.CacheConfig()
if config.PermutedV {
value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
}
if opts.Mask != nil {
if padSize := int(pad(C.size_t(opts.Mask.Dim(1)), C.size_t(config.MaskBatchPadding))) - opts.Mask.Dim(1); padSize > 0 {
opts.Mask = opts.Mask.Pad(ctx, 0, padSize, 0, 0)
}
if opts.Mask.DType() != config.MaskDType {
opts.Mask = opts.Mask.Cast(ctx, config.MaskDType)
}
}
}
query := t.Permute(ctx, 0, 2, 1, 3)
key = key.Permute(ctx, 0, 2, 1, 3)
var mask *C.struct_ggml_tensor
if opts.Mask != nil {
mask = opts.Mask.(*Tensor).t
}
if t.b.flashAttention == ml.FlashAttentionEnabled {
value = value.Permute(ctx, 0, 2, 1, 3)
tt := C.ggml_flash_attn_ext(ctx.(*Context).ctx, query.(*Tensor).t, key.(*Tensor).t, value.(*Tensor).t, mask, C.float(opts.Scale), 0, C.float(opts.LogitSoftcap))
C.ggml_flash_attn_ext_set_prec(tt, C.GGML_PREC_F32)
if opts.Sinks != nil {
C.ggml_flash_attn_ext_add_sinks(tt, opts.Sinks.(*Tensor).t)
}
var attention ml.Tensor = &Tensor{b: t.b, t: tt}
if opts.MLA != nil {
attention = attention.Permute(ctx, 0, 2, 1, 3)
attention = opts.MLA.Mulmat(ctx, attention)
attention = attention.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
}
return attention
}
scores := key.Mulmat(ctx, query)
C.ggml_mul_mat_set_prec(scores.(*Tensor).t, C.GGML_PREC_F32)
if opts.LogitSoftcap > 0 {
scores = scores.Scale(ctx, 1/float64(opts.LogitSoftcap)).Tanh(ctx).Scale(ctx, float64(opts.LogitSoftcap))
}
if opts.Cached {
scores = &Tensor{b: t.b, t: C.ggml_soft_max_ext(ctx.(*Context).ctx, scores.(*Tensor).t, mask, C.float(opts.Scale), 0)}
} else {
scores = scores.Scale(ctx, opts.Scale)
if opts.Mask != nil {
scores = scores.Add(ctx, opts.Mask)
}
scores = scores.Softmax(ctx)
}
if opts.Sinks != nil {
C.ggml_soft_max_add_sinks(scores.(*Tensor).t, opts.Sinks.(*Tensor).t)
}
if key.Dim(1) == value.Dim(2) && key.Dim(2) == value.Dim(1) {
value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
}
attention := value.Mulmat(ctx, scores)
if opts.MLA != nil {
attention = opts.MLA.Mulmat(ctx, attention)
}
return attention.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
}

View File

@ -1,12 +1,17 @@
package nn
import (
"fmt"
"log"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn/attention"
)
type fastAttention interface {
SDPA(ctx ml.Context, key, value ml.Tensor, opts ...func(*attention.Options)) ml.Tensor
}
// Attention implements scaled dot-product attention for transformer models:
// Attention(Q, K, V) = softmax(QK^T/√d_k)V
//
@ -21,27 +26,19 @@ import (
// Returns:
//
// Attention output with shape [d_v, heads, seq_len_q]
func Attention(ctx ml.Context, query, key, value ml.Tensor, scale float64, cache kvcache.Cache) ml.Tensor {
return AttentionWithVMLA(ctx, query, key, value, nil, nil, scale, cache)
}
func AttentionWithSinks(ctx ml.Context, query, key, value, sinks ml.Tensor, scale float64, cache kvcache.Cache) ml.Tensor {
return AttentionWithVMLA(ctx, query, key, value, sinks, nil, scale, cache)
}
func AttentionWithVMLA(ctx ml.Context, query, key, value, sinks ml.Tensor, vmla ml.Tensor, scale float64, cache kvcache.Cache) ml.Tensor {
ctx.Forward(query)
func Attention(ctx ml.Context, query, key, value ml.Tensor, cache kvcache.Cache, fns ...func(*attention.Options)) ml.Tensor {
if key != nil && value != nil {
if query.Dim(0) != key.Dim(0) {
panic(fmt.Errorf("d_k in attention operation does not match between query(%v) and key(%v)", query.Dim(0), key.Dim(0)))
log.Fatalf("d_k in attention operation does not match between query(%v) and key(%v)", query.Dim(0), key.Dim(0))
}
if key.Dim(1) != value.Dim(1) {
panic(fmt.Errorf("kv_heads in attention operation does not match between key(%v) and value(%v)", key.Dim(1), value.Dim(1)))
log.Fatalf("kv_heads in attention operation does not match between key(%v) and value(%v)", key.Dim(1), value.Dim(1))
}
if key.Dim(2) != value.Dim(2) {
panic(fmt.Errorf("seq_len_k in attention operation does not match between key(%v) and value(%v)", key.Dim(2), value.Dim(2)))
log.Fatalf("seq_len_k in attention operation does not match between key(%v) and value(%v)", key.Dim(2), value.Dim(2))
}
ctx.Forward(key, value)
@ -57,28 +54,12 @@ func AttentionWithVMLA(ctx ml.Context, query, key, value, sinks ml.Tensor, vmla
key, value, mask = cache.Get(ctx)
}
if sdpa, ok := query.(ml.ScaledDotProductAttention); ok {
cacheConfigApplied := cache != nil
return sdpa.ScaledDotProductAttention(ctx, key, value, mask, sinks, vmla, scale, cacheConfigApplied)
} else {
query = query.Permute(ctx, 0, 2, 1, 3)
key = key.Permute(ctx, 0, 2, 1, 3)
value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
kq := key.MulmatFullPrec(ctx, query)
kq = kq.Scale(ctx, scale)
if mask != nil {
kq = kq.Add(ctx, mask)
}
kq = kq.Softmax(ctx)
kqv := value.Mulmat(ctx, kq)
if vmla != nil {
kqv = vmla.Mulmat(ctx, kqv)
}
return kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
if t, ok := query.(fastAttention); ok {
return t.SDPA(ctx, key, value, append([]func(*attention.Options){
attention.WithMask(mask),
func(opts *attention.Options) { opts.Cached = cache != nil },
}, fns...)...)
}
panic("Attention not implemented for this tensor type")
}

View File

@ -0,0 +1,55 @@
package attention
import (
"github.com/ollama/ollama/ml"
)
type Options struct {
// Scale is a scaling factor applied to the attention scores. Default is 1/√d_k.
Scale float64
// LogitSoftcap is used to apply a soft cap to the logits before softmax.
LogitSoftcap float32
// Mask is used in some attention mechanisms to mask out certain positions.
Mask ml.Tensor
// Sinks is used in some attention mechanisms to store additional data.
Sinks ml.Tensor
// MLA is used in some attention mechanisms for multi-latent attention.
MLA ml.Tensor
// Cached indicates whether key/value were retrieved from cache.
Cached bool
}
func WithScale(scale float64) func(*Options) {
return func(o *Options) {
o.Scale = scale
}
}
func WithSinks(sinks ml.Tensor) func(*Options) {
return func(o *Options) {
o.Sinks = sinks
}
}
func WithMLA(mla ml.Tensor) func(*Options) {
return func(o *Options) {
o.MLA = mla
}
}
func WithMask(mask ml.Tensor) func(*Options) {
return func(o *Options) {
o.Mask = mask
}
}
func WithLogitSoftcap(softcap float32) func(*Options) {
return func(o *Options) {
o.LogitSoftcap = softcap
}
}

View File

@ -2,7 +2,6 @@ package bert
import (
"cmp"
"math"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/ml"
@ -99,7 +98,7 @@ func (a *Attention) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Option
value := a.Value.Forward(ctx, hiddenStates)
value = value.Reshape(ctx, opts.headDim(), cmp.Or(opts.numKVHeads, opts.numHeads), batchSize)
attention := nn.Attention(ctx, query, key, value, 1/math.Sqrt(float64(opts.headDim())), nil)
attention := nn.Attention(ctx, query, key, value, nil)
attention = attention.Reshape(ctx, opts.hiddenSize, batchSize)
return a.Output.Forward(ctx, attention)
}

View File

@ -10,6 +10,7 @@ import (
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/ml/nn/attention"
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
@ -66,22 +67,22 @@ type Attention struct {
Output *nn.Linear `gguf:"attn_out,alt:attn_output"`
}
func (attn *Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
func (m *Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
seqLength := hiddenStates.Dim(1)
var query ml.Tensor
if opts.qLoraRank == 0 {
query = attn.Q.Forward(ctx, hiddenStates)
query = m.Q.Forward(ctx, hiddenStates)
} else {
query = attn.QA.Forward(ctx, hiddenStates)
query = attn.QANorm.Forward(ctx, query, opts.eps)
query = attn.QB.Forward(ctx, query)
query = m.QA.Forward(ctx, hiddenStates)
query = m.QANorm.Forward(ctx, query, opts.eps)
query = m.QB.Forward(ctx, query)
}
query = query.Reshape(ctx, query.Dim(0)/opts.numHeads, opts.numHeads, seqLength)
queryChunks := query.ChunkSections(ctx, 0, opts.qkNopeHeadDim, opts.qkRopeHeadDim)
compressedKV := attn.KVA.Forward(ctx, hiddenStates)
compressedKV := m.KVA.Forward(ctx, hiddenStates)
kPass := compressedKV.Slice(ctx, 0, 0, opts.kvLoraRank, 1)
kRot := compressedKV.View(ctx,
opts.kvLoraRank*compressedKV.Stride(0), opts.qkRopeHeadDim,
@ -91,12 +92,10 @@ func (attn *Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor
qRot := opts.applyRotaryPositionEmbeddings(ctx, queryChunks[1], positions)
kRot = opts.applyRotaryPositionEmbeddings(ctx, kRot, positions)
kPass = attn.KVANorm.Forward(ctx, kPass, opts.eps)
var attention ml.Tensor
kPass = m.KVANorm.Forward(ctx, kPass, opts.eps)
if !opts.isMLA { // v3
kPass = attn.KVB.Forward(ctx, kPass)
kPass = m.KVB.Forward(ctx, kPass)
kv := kPass.Reshape(ctx, kPass.Dim(0)/opts.numKVHeads, opts.numKVHeads, seqLength)
kvChunks := kv.ChunkSections(ctx, 0, opts.kqNopeHeadDim, opts.vHeadDim)
@ -104,10 +103,10 @@ func (attn *Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor
kRot = kRot.Repeat(ctx, 1, queryChunks[0].Dim(1))
query = qRot.Concat(ctx, queryChunks[0], 0)
key := kRot.Concat(ctx, kvChunks[0], 0)
attention = nn.Attention(ctx, query, key, kvChunks[1], opts.kqScale, cache)
hiddenStates = nn.Attention(ctx, query, key, kvChunks[1], cache, attention.WithScale(opts.kqScale))
} else { // v3.1
qPass := queryChunks[0].Permute(ctx, 0, 2, 1, 3)
qPassAbsorb := attn.KB.Forward(ctx, qPass)
qPassAbsorb := m.KB.Forward(ctx, qPass)
qPassAbsorb = qPassAbsorb.Permute(ctx, 0, 2, 1, 3)
query = qRot.Concat(ctx, qPassAbsorb, 0)
@ -115,11 +114,14 @@ func (attn *Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor
key := kRot.Concat(ctx, kPass, 0)
value := kPass
attention = nn.AttentionWithVMLA(ctx, query, key, value, nil, attn.VB.Weight, opts.kqScale, cache)
hiddenStates = nn.Attention(ctx, query, key, value, cache,
attention.WithMLA(m.VB.Weight),
attention.WithScale(opts.kqScale),
)
}
attention = attention.Reshape(ctx, attention.Dim(0)*attention.Dim(1), seqLength)
return attn.Output.Forward(ctx, attention)
hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0)*hiddenStates.Dim(1), seqLength)
return m.Output.Forward(ctx, hiddenStates)
}
type MLP interface {

View File

@ -1,11 +1,11 @@
package deepseekocr
import (
"math"
"slices"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/ml/nn/attention"
)
type samModel struct {
@ -166,23 +166,13 @@ func (m *samAttention) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts samO
ctx.Forward(query, key, value)
query = query.Permute(ctx, 0, 2, 1, 3)
rh, rw := m.decomposedRelativePositions(ctx, query, []int{h, w}, []int{h, w})
rh, rw := m.decomposedRelativePositions(ctx, query.Permute(ctx, 0, 2, 1, 3), []int{h, w}, []int{h, w})
mask := rh.Repeat(ctx, 0, rw.Dim(0)).Add(ctx, rw)
mask = mask.Reshape(ctx, h*w, -1, opts.numHeads, b)
key = key.Permute(ctx, 0, 2, 1, 3)
scores := key.MulmatFullPrec(ctx, query)
scores = scores.Scale(ctx, 1/math.Sqrt(float64(opts.headDim())))
scores = scores.Add(ctx, mask)
scores = scores.Softmax(ctx)
value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
attention := value.Mulmat(ctx, scores)
attention = attention.Permute(ctx, 0, 2, 1, 3)
attention = attention.Contiguous(ctx, -1, w, h, b)
return m.Output.Forward(ctx, attention)
hiddenStates = nn.Attention(ctx, query, key, value, nil, attention.WithMask(mask))
hiddenStates = hiddenStates.Contiguous(ctx, -1, w, h, b)
return m.Output.Forward(ctx, hiddenStates)
}
type samMLP struct {

View File

@ -1,8 +1,6 @@
package deepseekocr
import (
"math"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
@ -85,7 +83,7 @@ func (m *textAttention) Forward(ctx ml.Context, hiddenStates, positions ml.Tenso
query = opts.applyRotaryPositionEmbeddings(ctx, query, positions)
key = opts.applyRotaryPositionEmbeddings(ctx, key, positions)
attention := nn.Attention(ctx, query, key, value, 1./math.Sqrt(float64(opts.headDim())), cache)
attention := nn.Attention(ctx, query, key, value, cache)
attention = attention.Reshape(ctx, -1, attention.Dim(2))
return m.Output.Forward(ctx, attention)
}

View File

@ -102,7 +102,7 @@ func (m *visionAttention) Forward(ctx ml.Context, t ml.Tensor, opts visionOption
chunks := qkv.Chunk(ctx, 1, opts.numHeads)
query, key, value := chunks[0], chunks[1], chunks[2]
attention := nn.Attention(ctx, query, key, value, 1/math.Sqrt(float64(opts.headDim())), nil)
attention := nn.Attention(ctx, query, key, value, nil)
attention = attention.Reshape(ctx, -1, attention.Dim(2), attention.Dim(3))
return m.Output.Forward(ctx, attention)
}

View File

@ -7,6 +7,7 @@ import (
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/ml/nn/attention"
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
@ -106,28 +107,13 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
v := sa.Value.Forward(ctx, hiddenState)
v = v.Reshape(ctx, opts.attnValLen, opts.numKVHeads, batchSize)
cache.Put(ctx, k, v)
k, v, mask := cache.Get(ctx)
hiddenState = nn.Attention(ctx, q, k, v, cache,
attention.WithLogitSoftcap(opts.attnLogitSoftcap),
attention.WithScale(1),
)
q = q.Permute(ctx, 0, 2, 1, 3)
k = k.Permute(ctx, 0, 2, 1, 3)
v = v.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
kq := k.Mulmat(ctx, q)
// logit softcap
kq = kq.Scale(ctx, 1.0/float64(opts.attnLogitSoftcap))
kq = kq.Tanh(ctx)
kq = kq.Scale(ctx, float64(opts.attnLogitSoftcap))
kq = kq.Add(ctx, mask)
kq = kq.Softmax(ctx)
kqv := v.Mulmat(ctx, kq)
kqv = kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
kqv = kqv.Reshape(ctx, opts.attnValLen*opts.numHeads, batchSize)
return sa.Output.Forward(ctx, kqv)
hiddenState = hiddenState.Reshape(ctx, opts.attnValLen*opts.numHeads, batchSize)
return sa.Output.Forward(ctx, hiddenState)
}
func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {

View File

@ -7,6 +7,7 @@ import (
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/ml/nn/attention"
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model/input"
)
@ -165,8 +166,7 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, pos
v := sa.Value.Forward(ctx, hiddenState)
v = v.Reshape(ctx, opts.attnValLen, opts.numKVHeads, batchSize)
scaleFactor := 1.0
kqv := nn.Attention(ctx, q, k, v, scaleFactor, cache)
kqv := nn.Attention(ctx, q, k, v, cache, attention.WithScale(1))
kqv = kqv.Reshape(ctx, opts.attnValLen*opts.numHeads, batchSize)
return sa.Output.Forward(ctx, kqv)

View File

@ -1,8 +1,6 @@
package gemma3
import (
"math"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
@ -28,7 +26,7 @@ func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenState ml.Tensor, op
key = key.Reshape(ctx, headDim, opts.numHeads, key.Dim(1), batchSize)
value = value.Reshape(ctx, headDim, opts.numHeads, value.Dim(1), batchSize)
attention := nn.Attention(ctx, query, key, value, 1.0/math.Sqrt(float64(headDim)), nil)
attention := nn.Attention(ctx, query, key, value, nil)
attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2), batchSize)
hiddenState = sa.Output.Forward(ctx, attention)

View File

@ -8,6 +8,7 @@ import (
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/ml/nn/attention"
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model/input"
)
@ -269,9 +270,9 @@ func (attn TextAttention) Forward(ctx ml.Context, hiddenStates, positions ml.Ten
value = value.RMSNorm(ctx, nil, opts.eps)
}
attention := nn.Attention(ctx, query, key, value, 1., cache)
attention = attention.Reshape(ctx, attention.Dim(0)*attention.Dim(1), batchSize)
return attn.Output.Forward(ctx, attention)
hiddenStates = nn.Attention(ctx, query, key, value, cache, attention.WithScale(1))
hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0)*hiddenStates.Dim(1), batchSize)
return attn.Output.Forward(ctx, hiddenStates)
}
type TextMLP struct {

View File

@ -2,13 +2,13 @@ package gptoss
import (
"cmp"
"math"
"strings"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/ml/nn/attention"
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
@ -137,9 +137,9 @@ func (attn *AttentionBlock) Forward(ctx ml.Context, hiddenStates, positions ml.T
query = opts.applyRotaryPositionEmbeddings(ctx, query, positions)
key = opts.applyRotaryPositionEmbeddings(ctx, key, positions)
attention := nn.AttentionWithSinks(ctx, query, key, value, attn.Sinks, 1/math.Sqrt(float64(opts.headDim())), cache)
attention = attention.Reshape(ctx, attention.Dim(0)*attention.Dim(1), batchSize)
return attn.Output.Forward(ctx, attention).Add(ctx, residual)
hiddenStates = nn.Attention(ctx, query, key, value, cache, attention.WithSinks(attn.Sinks))
hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0)*hiddenStates.Dim(1), batchSize)
return attn.Output.Forward(ctx, hiddenStates).Add(ctx, residual)
}
type MLPBlock struct {

View File

@ -2,7 +2,6 @@ package llama
import (
"cmp"
"math"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/kvcache"
@ -131,7 +130,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positions ml.Tenso
query = opts.applyRotaryPositionEmbeddings(ctx, query, positions, sa.RopeFactors)
key = opts.applyRotaryPositionEmbeddings(ctx, key, positions, sa.RopeFactors)
attention := nn.Attention(ctx, query, key, value, 1.0/math.Sqrt(float64(headDim)), cache)
attention := nn.Attention(ctx, query, key, value, cache)
attention = attention.Reshape(ctx, headDim*opts.numHeads, batchSize)
return sa.Output.Forward(ctx, attention)
}

View File

@ -45,7 +45,7 @@ func (sa *TextAttention) Forward(ctx ml.Context, hiddenStates, positions, attent
query = query.Mul(ctx, attentionScales)
}
attention := nn.Attention(ctx, query, key, value, 1./math.Sqrt(float64(headDim)), cache)
attention := nn.Attention(ctx, query, key, value, cache)
attention = attention.Reshape(ctx, opts.hiddenSize, batchSize)
return sa.Output.Forward(ctx, attention)
}

View File

@ -72,7 +72,7 @@ func (sa *VisionAttention) Forward(ctx ml.Context, hiddenState, cos, sin ml.Tens
query = applyVisionRotaryEmbedding(ctx, query, cos, sin)
key = applyVisionRotaryEmbedding(ctx, key, cos, sin)
attention := nn.Attention(ctx, query, key, value, 1./math.Sqrt(float64(headDim)), nil)
attention := nn.Attention(ctx, query, key, value, nil)
attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2), attention.Dim(3))
return sa.Output.Forward(ctx, attention)
}

View File

@ -79,7 +79,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs, posit
q = q.Mul(ctx, positionsScale)
}
kqv := nn.Attention(ctx, q, k, v, 1.0/math.Sqrt(float64(headDim)), cache)
kqv := nn.Attention(ctx, q, k, v, cache)
kqv = kqv.Reshape(ctx, headDim*opts.numHeads, batchSize)
return sa.Output.Forward(ctx, kqv)
}

View File

@ -39,7 +39,7 @@ func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenStates, cos, sin ml
query = applyRotaryPositionEmbeddings(ctx, query, cos, sin)
key = applyRotaryPositionEmbeddings(ctx, key, cos, sin)
attention := nn.Attention(ctx, query, key, value, 1./math.Sqrt(float64(opts.headDim)), nil)
attention := nn.Attention(ctx, query, key, value, nil)
attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2), batchSize)
return sa.Output.Forward(ctx, attention)
}

View File

@ -1,7 +1,6 @@
package mllama
import (
"math"
"slices"
"github.com/ollama/ollama/fs"
@ -34,8 +33,7 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions ml.T
value := sa.Value.Forward(ctx, hiddenState)
value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
scaleFactor := 1.0 / math.Sqrt(float64(headDim))
attention := nn.Attention(ctx, query, key, value, scaleFactor, cache)
attention := nn.Attention(ctx, query, key, value, cache)
attention = attention.Reshape(ctx, opts.hiddenSize, batchSize)
return sa.Output.Forward(ctx, attention)
@ -122,20 +120,7 @@ func (ca *TextCrossAttention) Forward(ctx ml.Context, hiddenState, crossAttentio
}
key, value, _ = cache.Get(ctx)
scaleFactor := 1.0 / math.Sqrt(float64(headDim))
query = query.Permute(ctx, 0, 2, 1, 3)
key = key.Permute(ctx, 0, 2, 1, 3)
value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
kq := key.MulmatFullPrec(ctx, query)
kq = kq.Scale(ctx, scaleFactor)
kq = kq.Softmax(ctx)
kqv := value.Mulmat(ctx, kq)
attention := kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
attention := nn.Attention(ctx, query, key, value, nil)
attention = attention.Reshape(ctx, opts.hiddenSize, batchSize)
return ca.Output.Forward(ctx, attention)

View File

@ -1,7 +1,6 @@
package mllama
import (
"math"
"slices"
"github.com/ollama/ollama/fs"
@ -30,7 +29,7 @@ func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenState ml.Tensor, op
value := sa.Value.Forward(ctx, hiddenState)
value = value.Reshape(ctx, headDim, opts.numHeads, value.Dim(1), batchSize)
attention := nn.Attention(ctx, query, key, value, 1./math.Sqrt(float64(headDim)), nil)
attention := nn.Attention(ctx, query, key, value, nil)
attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2), batchSize)
return sa.Output.Forward(ctx, attention)
}

View File

@ -2,7 +2,6 @@ package nomicbert
import (
"cmp"
"math"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/ml"
@ -166,7 +165,7 @@ func (a *Attention) Forward(ctx ml.Context, hiddenStates ml.Tensor, positions ml
query = opts.applyRotaryPositionEmbeddings(ctx, query, positions)
key = opts.applyRotaryPositionEmbeddings(ctx, key, positions)
attention := nn.Attention(ctx, query, key, value, 1.0/math.Sqrt(float64(opts.headDim)), nil)
attention := nn.Attention(ctx, query, key, value, nil)
attention = attention.Reshape(ctx, opts.hiddenSize, batchSize)

View File

@ -2,7 +2,6 @@ package olmo3
import (
"fmt"
"math"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/kvcache"
@ -132,7 +131,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positions ml.Tenso
value := sa.Value.Forward(ctx, hiddenState)
value = value.Reshape(ctx, headDim, m.numKVHeads, batchSize)
attention := nn.Attention(ctx, query, key, value, 1.0/math.Sqrt(float64(headDim)), cache)
attention := nn.Attention(ctx, query, key, value, cache)
attention = attention.Reshape(ctx, m.hiddenSize, batchSize)
return sa.Output.Forward(ctx, attention)

View File

@ -3,7 +3,6 @@ package qwen2
import (
"cmp"
"fmt"
"math"
"strings"
"github.com/ollama/ollama/fs"
@ -48,7 +47,7 @@ func (attn Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor,
query = opts.applyRotaryPositionEmbeddings(ctx, query, positions)
key = opts.applyRotaryPositionEmbeddings(ctx, key, positions)
attention := nn.Attention(ctx, query, key, value, 1.0/math.Sqrt(float64(headDim)), cache)
attention := nn.Attention(ctx, query, key, value, cache)
attention = attention.Reshape(ctx, headDim*opts.numHeads, batchSize)
return attn.Output.Forward(ctx, attention)

View File

@ -1,8 +1,6 @@
package qwen25vl
import (
"math"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
@ -81,8 +79,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
v := sa.Value.Forward(ctx, hiddenState)
v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
scaleFactor := 1.0 / math.Sqrt(float64(headDim))
kqv := nn.Attention(ctx, q, k, v, scaleFactor, cache)
kqv := nn.Attention(ctx, q, k, v, cache)
kqv = kqv.Reshape(ctx, opts.hiddenSize, batchSize)
return sa.Output.Forward(ctx, kqv)

View File

@ -8,6 +8,7 @@ import (
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/ml/nn/attention"
)
func blockDiagonalMask(ctx ml.Context, seqLength int, bounds []int) ml.Tensor {
@ -50,25 +51,9 @@ func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenStates, positions,
query = opts.applyRotaryPositionEmbeddings(ctx, query, positions)
key = opts.applyRotaryPositionEmbeddings(ctx, key, positions)
// Scale factor for scaled dot-product attention
scale := 1.0 / math.Sqrt(float64(opts.headDim))
// Scaled dot-product attention
query = query.Permute(ctx, 0, 2, 1, 3)
key = key.Permute(ctx, 0, 2, 1, 3)
value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
kq := key.MulmatFullPrec(ctx, query)
kq = kq.Scale(ctx, scale)
if mask != nil {
kq = kq.Add(ctx, mask)
}
kq = kq.Softmax(ctx)
kqv := value.Mulmat(ctx, kq)
attention := kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2))
return sa.Output.Forward(ctx, attention)
hiddenStates = nn.Attention(ctx, query, key, value, nil, attention.WithMask(mask))
hiddenStates = hiddenStates.Reshape(ctx, opts.hiddenSize, hiddenStates.Dim(2))
return sa.Output.Forward(ctx, hiddenStates)
}
// VisionMLP implements the multi-layer perceptron

View File

@ -74,7 +74,7 @@ func (sa *Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor,
query = opts.applyRotaryPositionEmbeddings(ctx, query, positions)
key = opts.applyRotaryPositionEmbeddings(ctx, key, positions)
attention := nn.Attention(ctx, query, key, value, 1./math.Sqrt(float64(opts.headDim())), cache)
attention := nn.Attention(ctx, query, key, value, cache)
attention = attention.Reshape(ctx, attention.Dim(0)*attention.Dim(1), batchSize)
return sa.Output.Forward(ctx, attention)
}

View File

@ -66,7 +66,7 @@ func (sa *TextAttention) Forward(ctx ml.Context, hiddenStates, positions ml.Tens
query = opts.applyRotaryPositionEmbeddings(ctx, query, positions)
key = opts.applyRotaryPositionEmbeddings(ctx, key, positions)
attention := nn.Attention(ctx, query, key, value, 1./math.Sqrt(float64(opts.headDim())), cache)
attention := nn.Attention(ctx, query, key, value, cache)
attention = attention.Reshape(ctx, attention.Dim(0)*attention.Dim(1), batchSize)
return sa.Output.Forward(ctx, attention)
}

View File

@ -39,7 +39,7 @@ func (sa *VisionAttention) Forward(ctx ml.Context, hiddenStates, cos, sin ml.Ten
value := sa.Value.Forward(ctx, hiddenStates)
value = value.Reshape(ctx, opts.headDim(), opts.numHeads, value.Dim(1))
attention := nn.Attention(ctx, query, key, value, math.Pow(float64(opts.headDim()), -0.5), nil)
attention := nn.Attention(ctx, query, key, value, nil)
attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2))
return sa.Output.Forward(ctx, attention)
}