From de82b1f9a3881987f658999382dbd764f4b0b538 Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Tue, 18 Nov 2025 17:06:24 -0800 Subject: [PATCH] cleanup attention interface the updated interface supports variadic attention options which removes the need for individual `AttentionWith...` functions. it means more models can use the attention interface, e.g. models with custom masks, logit softcapping, etc. additionally, this interface should be less error prone since there are now reasonable defaults for all optional parameters --- ml/backend/ggml/ggml.go | 88 ++++++++++++++++++++++++ ml/nn/attention.go | 53 +++++--------- ml/nn/attention/options.go | 55 +++++++++++++++ model/models/bert/embed.go | 3 +- model/models/deepseek2/model.go | 32 +++++---- model/models/deepseekocr/model_sam.go | 20 ++---- model/models/deepseekocr/model_text.go | 4 +- model/models/deepseekocr/model_vision.go | 2 +- model/models/gemma2/model.go | 28 ++------ model/models/gemma3/model_text.go | 4 +- model/models/gemma3/model_vision.go | 4 +- model/models/gemma3n/model_text.go | 7 +- model/models/gptoss/model.go | 8 +-- model/models/llama/model.go | 3 +- model/models/llama4/model_text.go | 2 +- model/models/llama4/model_vision.go | 2 +- model/models/mistral3/model_text.go | 2 +- model/models/mistral3/model_vision.go | 2 +- model/models/mllama/model_text.go | 19 +---- model/models/mllama/model_vision.go | 3 +- model/models/nomicbert/model.go | 3 +- model/models/olmo3/model.go | 3 +- model/models/qwen2/model.go | 3 +- model/models/qwen25vl/model_text.go | 5 +- model/models/qwen25vl/model_vision.go | 23 ++----- model/models/qwen3/model.go | 2 +- model/models/qwen3vl/model_text.go | 2 +- model/models/qwen3vl/model_vision.go | 2 +- 28 files changed, 222 insertions(+), 162 deletions(-) create mode 100644 ml/nn/attention/options.go diff --git a/ml/backend/ggml/ggml.go b/ml/backend/ggml/ggml.go index 6a044260a..424e0796a 100644 --- a/ml/backend/ggml/ggml.go +++ b/ml/backend/ggml/ggml.go @@ -19,6 +19,7 @@ import ( "io" "log/slog" "maps" + "math" "os" "runtime" "slices" @@ -35,6 +36,7 @@ import ( "github.com/ollama/ollama/logutil" "github.com/ollama/ollama/ml" ggml "github.com/ollama/ollama/ml/backend/ggml/ggml/src" + "github.com/ollama/ollama/ml/nn/attention" "github.com/ollama/ollama/ml/nn/rope" "golang.org/x/sync/errgroup" ) @@ -1849,3 +1851,89 @@ func (t *Tensor) ChunkSections(ctx ml.Context, dim int, sections ...int) []ml.Te } return s } + +func (t *Tensor) SDPA(ctx ml.Context, key, value ml.Tensor, fns ...func(*attention.Options)) ml.Tensor { + opts := attention.Options{ + Scale: 1 / math.Sqrt(float64(t.Dim(0))), + } + + for _, fn := range fns { + fn(&opts) + } + + if !opts.Cached { + config := t.b.CacheConfig() + if config.PermutedV { + value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx) + } + + if opts.Mask != nil { + if padSize := int(pad(C.size_t(opts.Mask.Dim(1)), C.size_t(config.MaskBatchPadding))) - opts.Mask.Dim(1); padSize > 0 { + opts.Mask = opts.Mask.Pad(ctx, 0, padSize, 0, 0) + } + + if opts.Mask.DType() != config.MaskDType { + opts.Mask = opts.Mask.Cast(ctx, config.MaskDType) + } + } + } + + query := t.Permute(ctx, 0, 2, 1, 3) + key = key.Permute(ctx, 0, 2, 1, 3) + + var mask *C.struct_ggml_tensor + if opts.Mask != nil { + mask = opts.Mask.(*Tensor).t + } + + if t.b.flashAttention == ml.FlashAttentionEnabled { + value = value.Permute(ctx, 0, 2, 1, 3) + + tt := C.ggml_flash_attn_ext(ctx.(*Context).ctx, query.(*Tensor).t, key.(*Tensor).t, value.(*Tensor).t, mask, C.float(opts.Scale), 0, C.float(opts.LogitSoftcap)) + C.ggml_flash_attn_ext_set_prec(tt, C.GGML_PREC_F32) + if opts.Sinks != nil { + C.ggml_flash_attn_ext_add_sinks(tt, opts.Sinks.(*Tensor).t) + } + + var attention ml.Tensor = &Tensor{b: t.b, t: tt} + if opts.MLA != nil { + attention = attention.Permute(ctx, 0, 2, 1, 3) + attention = opts.MLA.Mulmat(ctx, attention) + attention = attention.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx) + } + + return attention + } + + scores := key.Mulmat(ctx, query) + C.ggml_mul_mat_set_prec(scores.(*Tensor).t, C.GGML_PREC_F32) + if opts.LogitSoftcap > 0 { + scores = scores.Scale(ctx, 1/float64(opts.LogitSoftcap)).Tanh(ctx).Scale(ctx, float64(opts.LogitSoftcap)) + } + + if opts.Cached { + scores = &Tensor{b: t.b, t: C.ggml_soft_max_ext(ctx.(*Context).ctx, scores.(*Tensor).t, mask, C.float(opts.Scale), 0)} + } else { + scores = scores.Scale(ctx, opts.Scale) + if opts.Mask != nil { + scores = scores.Add(ctx, opts.Mask) + } + + scores = scores.Softmax(ctx) + } + + if opts.Sinks != nil { + C.ggml_soft_max_add_sinks(scores.(*Tensor).t, opts.Sinks.(*Tensor).t) + } + + if key.Dim(1) == value.Dim(2) && key.Dim(2) == value.Dim(1) { + value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx) + } + + attention := value.Mulmat(ctx, scores) + if opts.MLA != nil { + attention = opts.MLA.Mulmat(ctx, attention) + } + + return attention.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx) +} diff --git a/ml/nn/attention.go b/ml/nn/attention.go index e495e1f60..5a3bfc092 100644 --- a/ml/nn/attention.go +++ b/ml/nn/attention.go @@ -1,12 +1,17 @@ package nn import ( - "fmt" + "log" "github.com/ollama/ollama/kvcache" "github.com/ollama/ollama/ml" + "github.com/ollama/ollama/ml/nn/attention" ) +type fastAttention interface { + SDPA(ctx ml.Context, key, value ml.Tensor, opts ...func(*attention.Options)) ml.Tensor +} + // Attention implements scaled dot-product attention for transformer models: // Attention(Q, K, V) = softmax(QK^T/√d_k)V // @@ -21,27 +26,19 @@ import ( // Returns: // // Attention output with shape [d_v, heads, seq_len_q] -func Attention(ctx ml.Context, query, key, value ml.Tensor, scale float64, cache kvcache.Cache) ml.Tensor { - return AttentionWithVMLA(ctx, query, key, value, nil, nil, scale, cache) -} -func AttentionWithSinks(ctx ml.Context, query, key, value, sinks ml.Tensor, scale float64, cache kvcache.Cache) ml.Tensor { - return AttentionWithVMLA(ctx, query, key, value, sinks, nil, scale, cache) -} - -func AttentionWithVMLA(ctx ml.Context, query, key, value, sinks ml.Tensor, vmla ml.Tensor, scale float64, cache kvcache.Cache) ml.Tensor { - ctx.Forward(query) +func Attention(ctx ml.Context, query, key, value ml.Tensor, cache kvcache.Cache, fns ...func(*attention.Options)) ml.Tensor { if key != nil && value != nil { if query.Dim(0) != key.Dim(0) { - panic(fmt.Errorf("d_k in attention operation does not match between query(%v) and key(%v)", query.Dim(0), key.Dim(0))) + log.Fatalf("d_k in attention operation does not match between query(%v) and key(%v)", query.Dim(0), key.Dim(0)) } if key.Dim(1) != value.Dim(1) { - panic(fmt.Errorf("kv_heads in attention operation does not match between key(%v) and value(%v)", key.Dim(1), value.Dim(1))) + log.Fatalf("kv_heads in attention operation does not match between key(%v) and value(%v)", key.Dim(1), value.Dim(1)) } if key.Dim(2) != value.Dim(2) { - panic(fmt.Errorf("seq_len_k in attention operation does not match between key(%v) and value(%v)", key.Dim(2), value.Dim(2))) + log.Fatalf("seq_len_k in attention operation does not match between key(%v) and value(%v)", key.Dim(2), value.Dim(2)) } ctx.Forward(key, value) @@ -57,28 +54,12 @@ func AttentionWithVMLA(ctx ml.Context, query, key, value, sinks ml.Tensor, vmla key, value, mask = cache.Get(ctx) } - if sdpa, ok := query.(ml.ScaledDotProductAttention); ok { - cacheConfigApplied := cache != nil - return sdpa.ScaledDotProductAttention(ctx, key, value, mask, sinks, vmla, scale, cacheConfigApplied) - } else { - query = query.Permute(ctx, 0, 2, 1, 3) - key = key.Permute(ctx, 0, 2, 1, 3) - value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx) - - kq := key.MulmatFullPrec(ctx, query) - - kq = kq.Scale(ctx, scale) - if mask != nil { - kq = kq.Add(ctx, mask) - } - kq = kq.Softmax(ctx) - - kqv := value.Mulmat(ctx, kq) - - if vmla != nil { - kqv = vmla.Mulmat(ctx, kqv) - } - - return kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx) + if t, ok := query.(fastAttention); ok { + return t.SDPA(ctx, key, value, append([]func(*attention.Options){ + attention.WithMask(mask), + func(opts *attention.Options) { opts.Cached = cache != nil }, + }, fns...)...) } + + panic("Attention not implemented for this tensor type") } diff --git a/ml/nn/attention/options.go b/ml/nn/attention/options.go new file mode 100644 index 000000000..4bf1c377c --- /dev/null +++ b/ml/nn/attention/options.go @@ -0,0 +1,55 @@ +package attention + +import ( + "github.com/ollama/ollama/ml" +) + +type Options struct { + // Scale is a scaling factor applied to the attention scores. Default is 1/√d_k. + Scale float64 + + // LogitSoftcap is used to apply a soft cap to the logits before softmax. + LogitSoftcap float32 + + // Mask is used in some attention mechanisms to mask out certain positions. + Mask ml.Tensor + + // Sinks is used in some attention mechanisms to store additional data. + Sinks ml.Tensor + + // MLA is used in some attention mechanisms for multi-latent attention. + MLA ml.Tensor + + // Cached indicates whether key/value were retrieved from cache. + Cached bool +} + +func WithScale(scale float64) func(*Options) { + return func(o *Options) { + o.Scale = scale + } +} + +func WithSinks(sinks ml.Tensor) func(*Options) { + return func(o *Options) { + o.Sinks = sinks + } +} + +func WithMLA(mla ml.Tensor) func(*Options) { + return func(o *Options) { + o.MLA = mla + } +} + +func WithMask(mask ml.Tensor) func(*Options) { + return func(o *Options) { + o.Mask = mask + } +} + +func WithLogitSoftcap(softcap float32) func(*Options) { + return func(o *Options) { + o.LogitSoftcap = softcap + } +} diff --git a/model/models/bert/embed.go b/model/models/bert/embed.go index 5e7ca5e92..27c315757 100644 --- a/model/models/bert/embed.go +++ b/model/models/bert/embed.go @@ -2,7 +2,6 @@ package bert import ( "cmp" - "math" "github.com/ollama/ollama/fs" "github.com/ollama/ollama/ml" @@ -99,7 +98,7 @@ func (a *Attention) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Option value := a.Value.Forward(ctx, hiddenStates) value = value.Reshape(ctx, opts.headDim(), cmp.Or(opts.numKVHeads, opts.numHeads), batchSize) - attention := nn.Attention(ctx, query, key, value, 1/math.Sqrt(float64(opts.headDim())), nil) + attention := nn.Attention(ctx, query, key, value, nil) attention = attention.Reshape(ctx, opts.hiddenSize, batchSize) return a.Output.Forward(ctx, attention) } diff --git a/model/models/deepseek2/model.go b/model/models/deepseek2/model.go index 576076aab..d40df204f 100644 --- a/model/models/deepseek2/model.go +++ b/model/models/deepseek2/model.go @@ -10,6 +10,7 @@ import ( "github.com/ollama/ollama/kvcache" "github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml/nn" + "github.com/ollama/ollama/ml/nn/attention" "github.com/ollama/ollama/ml/nn/rope" "github.com/ollama/ollama/model" "github.com/ollama/ollama/model/input" @@ -66,22 +67,22 @@ type Attention struct { Output *nn.Linear `gguf:"attn_out,alt:attn_output"` } -func (attn *Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor { +func (m *Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor { seqLength := hiddenStates.Dim(1) var query ml.Tensor if opts.qLoraRank == 0 { - query = attn.Q.Forward(ctx, hiddenStates) + query = m.Q.Forward(ctx, hiddenStates) } else { - query = attn.QA.Forward(ctx, hiddenStates) - query = attn.QANorm.Forward(ctx, query, opts.eps) - query = attn.QB.Forward(ctx, query) + query = m.QA.Forward(ctx, hiddenStates) + query = m.QANorm.Forward(ctx, query, opts.eps) + query = m.QB.Forward(ctx, query) } query = query.Reshape(ctx, query.Dim(0)/opts.numHeads, opts.numHeads, seqLength) queryChunks := query.ChunkSections(ctx, 0, opts.qkNopeHeadDim, opts.qkRopeHeadDim) - compressedKV := attn.KVA.Forward(ctx, hiddenStates) + compressedKV := m.KVA.Forward(ctx, hiddenStates) kPass := compressedKV.Slice(ctx, 0, 0, opts.kvLoraRank, 1) kRot := compressedKV.View(ctx, opts.kvLoraRank*compressedKV.Stride(0), opts.qkRopeHeadDim, @@ -91,12 +92,10 @@ func (attn *Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor qRot := opts.applyRotaryPositionEmbeddings(ctx, queryChunks[1], positions) kRot = opts.applyRotaryPositionEmbeddings(ctx, kRot, positions) - kPass = attn.KVANorm.Forward(ctx, kPass, opts.eps) - - var attention ml.Tensor + kPass = m.KVANorm.Forward(ctx, kPass, opts.eps) if !opts.isMLA { // v3 - kPass = attn.KVB.Forward(ctx, kPass) + kPass = m.KVB.Forward(ctx, kPass) kv := kPass.Reshape(ctx, kPass.Dim(0)/opts.numKVHeads, opts.numKVHeads, seqLength) kvChunks := kv.ChunkSections(ctx, 0, opts.kqNopeHeadDim, opts.vHeadDim) @@ -104,10 +103,10 @@ func (attn *Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor kRot = kRot.Repeat(ctx, 1, queryChunks[0].Dim(1)) query = qRot.Concat(ctx, queryChunks[0], 0) key := kRot.Concat(ctx, kvChunks[0], 0) - attention = nn.Attention(ctx, query, key, kvChunks[1], opts.kqScale, cache) + hiddenStates = nn.Attention(ctx, query, key, kvChunks[1], cache, attention.WithScale(opts.kqScale)) } else { // v3.1 qPass := queryChunks[0].Permute(ctx, 0, 2, 1, 3) - qPassAbsorb := attn.KB.Forward(ctx, qPass) + qPassAbsorb := m.KB.Forward(ctx, qPass) qPassAbsorb = qPassAbsorb.Permute(ctx, 0, 2, 1, 3) query = qRot.Concat(ctx, qPassAbsorb, 0) @@ -115,11 +114,14 @@ func (attn *Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor key := kRot.Concat(ctx, kPass, 0) value := kPass - attention = nn.AttentionWithVMLA(ctx, query, key, value, nil, attn.VB.Weight, opts.kqScale, cache) + hiddenStates = nn.Attention(ctx, query, key, value, cache, + attention.WithMLA(m.VB.Weight), + attention.WithScale(opts.kqScale), + ) } - attention = attention.Reshape(ctx, attention.Dim(0)*attention.Dim(1), seqLength) - return attn.Output.Forward(ctx, attention) + hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0)*hiddenStates.Dim(1), seqLength) + return m.Output.Forward(ctx, hiddenStates) } type MLP interface { diff --git a/model/models/deepseekocr/model_sam.go b/model/models/deepseekocr/model_sam.go index 8bf30f96c..636743b72 100644 --- a/model/models/deepseekocr/model_sam.go +++ b/model/models/deepseekocr/model_sam.go @@ -1,11 +1,11 @@ package deepseekocr import ( - "math" "slices" "github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml/nn" + "github.com/ollama/ollama/ml/nn/attention" ) type samModel struct { @@ -166,23 +166,13 @@ func (m *samAttention) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts samO ctx.Forward(query, key, value) - query = query.Permute(ctx, 0, 2, 1, 3) - rh, rw := m.decomposedRelativePositions(ctx, query, []int{h, w}, []int{h, w}) + rh, rw := m.decomposedRelativePositions(ctx, query.Permute(ctx, 0, 2, 1, 3), []int{h, w}, []int{h, w}) mask := rh.Repeat(ctx, 0, rw.Dim(0)).Add(ctx, rw) mask = mask.Reshape(ctx, h*w, -1, opts.numHeads, b) - key = key.Permute(ctx, 0, 2, 1, 3) - scores := key.MulmatFullPrec(ctx, query) - scores = scores.Scale(ctx, 1/math.Sqrt(float64(opts.headDim()))) - - scores = scores.Add(ctx, mask) - scores = scores.Softmax(ctx) - - value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx) - attention := value.Mulmat(ctx, scores) - attention = attention.Permute(ctx, 0, 2, 1, 3) - attention = attention.Contiguous(ctx, -1, w, h, b) - return m.Output.Forward(ctx, attention) + hiddenStates = nn.Attention(ctx, query, key, value, nil, attention.WithMask(mask)) + hiddenStates = hiddenStates.Contiguous(ctx, -1, w, h, b) + return m.Output.Forward(ctx, hiddenStates) } type samMLP struct { diff --git a/model/models/deepseekocr/model_text.go b/model/models/deepseekocr/model_text.go index ab6221ccf..f3aac476b 100644 --- a/model/models/deepseekocr/model_text.go +++ b/model/models/deepseekocr/model_text.go @@ -1,8 +1,6 @@ package deepseekocr import ( - "math" - "github.com/ollama/ollama/kvcache" "github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml/nn" @@ -85,7 +83,7 @@ func (m *textAttention) Forward(ctx ml.Context, hiddenStates, positions ml.Tenso query = opts.applyRotaryPositionEmbeddings(ctx, query, positions) key = opts.applyRotaryPositionEmbeddings(ctx, key, positions) - attention := nn.Attention(ctx, query, key, value, 1./math.Sqrt(float64(opts.headDim())), cache) + attention := nn.Attention(ctx, query, key, value, cache) attention = attention.Reshape(ctx, -1, attention.Dim(2)) return m.Output.Forward(ctx, attention) } diff --git a/model/models/deepseekocr/model_vision.go b/model/models/deepseekocr/model_vision.go index 61121ebfd..ac7380dae 100644 --- a/model/models/deepseekocr/model_vision.go +++ b/model/models/deepseekocr/model_vision.go @@ -102,7 +102,7 @@ func (m *visionAttention) Forward(ctx ml.Context, t ml.Tensor, opts visionOption chunks := qkv.Chunk(ctx, 1, opts.numHeads) query, key, value := chunks[0], chunks[1], chunks[2] - attention := nn.Attention(ctx, query, key, value, 1/math.Sqrt(float64(opts.headDim())), nil) + attention := nn.Attention(ctx, query, key, value, nil) attention = attention.Reshape(ctx, -1, attention.Dim(2), attention.Dim(3)) return m.Output.Forward(ctx, attention) } diff --git a/model/models/gemma2/model.go b/model/models/gemma2/model.go index 7b0aa2f01..25daa3795 100644 --- a/model/models/gemma2/model.go +++ b/model/models/gemma2/model.go @@ -7,6 +7,7 @@ import ( "github.com/ollama/ollama/kvcache" "github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml/nn" + "github.com/ollama/ollama/ml/nn/attention" "github.com/ollama/ollama/ml/nn/rope" "github.com/ollama/ollama/model" "github.com/ollama/ollama/model/input" @@ -106,28 +107,13 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten v := sa.Value.Forward(ctx, hiddenState) v = v.Reshape(ctx, opts.attnValLen, opts.numKVHeads, batchSize) - cache.Put(ctx, k, v) - k, v, mask := cache.Get(ctx) + hiddenState = nn.Attention(ctx, q, k, v, cache, + attention.WithLogitSoftcap(opts.attnLogitSoftcap), + attention.WithScale(1), + ) - q = q.Permute(ctx, 0, 2, 1, 3) - k = k.Permute(ctx, 0, 2, 1, 3) - v = v.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx) - - kq := k.Mulmat(ctx, q) - - // logit softcap - kq = kq.Scale(ctx, 1.0/float64(opts.attnLogitSoftcap)) - kq = kq.Tanh(ctx) - kq = kq.Scale(ctx, float64(opts.attnLogitSoftcap)) - - kq = kq.Add(ctx, mask) - kq = kq.Softmax(ctx) - - kqv := v.Mulmat(ctx, kq) - kqv = kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx) - kqv = kqv.Reshape(ctx, opts.attnValLen*opts.numHeads, batchSize) - - return sa.Output.Forward(ctx, kqv) + hiddenState = hiddenState.Reshape(ctx, opts.attnValLen*opts.numHeads, batchSize) + return sa.Output.Forward(ctx, hiddenState) } func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { diff --git a/model/models/gemma3/model_text.go b/model/models/gemma3/model_text.go index e1c0004d9..25b0a3d60 100644 --- a/model/models/gemma3/model_text.go +++ b/model/models/gemma3/model_text.go @@ -7,6 +7,7 @@ import ( "github.com/ollama/ollama/kvcache" "github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml/nn" + "github.com/ollama/ollama/ml/nn/attention" "github.com/ollama/ollama/ml/nn/rope" "github.com/ollama/ollama/model/input" ) @@ -165,8 +166,7 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, pos v := sa.Value.Forward(ctx, hiddenState) v = v.Reshape(ctx, opts.attnValLen, opts.numKVHeads, batchSize) - scaleFactor := 1.0 - kqv := nn.Attention(ctx, q, k, v, scaleFactor, cache) + kqv := nn.Attention(ctx, q, k, v, cache, attention.WithScale(1)) kqv = kqv.Reshape(ctx, opts.attnValLen*opts.numHeads, batchSize) return sa.Output.Forward(ctx, kqv) diff --git a/model/models/gemma3/model_vision.go b/model/models/gemma3/model_vision.go index 8b1a8eb00..e99def5e9 100644 --- a/model/models/gemma3/model_vision.go +++ b/model/models/gemma3/model_vision.go @@ -1,8 +1,6 @@ package gemma3 import ( - "math" - "github.com/ollama/ollama/fs" "github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml/nn" @@ -28,7 +26,7 @@ func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenState ml.Tensor, op key = key.Reshape(ctx, headDim, opts.numHeads, key.Dim(1), batchSize) value = value.Reshape(ctx, headDim, opts.numHeads, value.Dim(1), batchSize) - attention := nn.Attention(ctx, query, key, value, 1.0/math.Sqrt(float64(headDim)), nil) + attention := nn.Attention(ctx, query, key, value, nil) attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2), batchSize) hiddenState = sa.Output.Forward(ctx, attention) diff --git a/model/models/gemma3n/model_text.go b/model/models/gemma3n/model_text.go index 89cc54b8b..0e28fbc98 100644 --- a/model/models/gemma3n/model_text.go +++ b/model/models/gemma3n/model_text.go @@ -8,6 +8,7 @@ import ( "github.com/ollama/ollama/kvcache" "github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml/nn" + "github.com/ollama/ollama/ml/nn/attention" "github.com/ollama/ollama/ml/nn/rope" "github.com/ollama/ollama/model/input" ) @@ -269,9 +270,9 @@ func (attn TextAttention) Forward(ctx ml.Context, hiddenStates, positions ml.Ten value = value.RMSNorm(ctx, nil, opts.eps) } - attention := nn.Attention(ctx, query, key, value, 1., cache) - attention = attention.Reshape(ctx, attention.Dim(0)*attention.Dim(1), batchSize) - return attn.Output.Forward(ctx, attention) + hiddenStates = nn.Attention(ctx, query, key, value, cache, attention.WithScale(1)) + hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0)*hiddenStates.Dim(1), batchSize) + return attn.Output.Forward(ctx, hiddenStates) } type TextMLP struct { diff --git a/model/models/gptoss/model.go b/model/models/gptoss/model.go index 9d1520bf3..3e061a412 100644 --- a/model/models/gptoss/model.go +++ b/model/models/gptoss/model.go @@ -2,13 +2,13 @@ package gptoss import ( "cmp" - "math" "strings" "github.com/ollama/ollama/fs" "github.com/ollama/ollama/kvcache" "github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml/nn" + "github.com/ollama/ollama/ml/nn/attention" "github.com/ollama/ollama/ml/nn/rope" "github.com/ollama/ollama/model" "github.com/ollama/ollama/model/input" @@ -137,9 +137,9 @@ func (attn *AttentionBlock) Forward(ctx ml.Context, hiddenStates, positions ml.T query = opts.applyRotaryPositionEmbeddings(ctx, query, positions) key = opts.applyRotaryPositionEmbeddings(ctx, key, positions) - attention := nn.AttentionWithSinks(ctx, query, key, value, attn.Sinks, 1/math.Sqrt(float64(opts.headDim())), cache) - attention = attention.Reshape(ctx, attention.Dim(0)*attention.Dim(1), batchSize) - return attn.Output.Forward(ctx, attention).Add(ctx, residual) + hiddenStates = nn.Attention(ctx, query, key, value, cache, attention.WithSinks(attn.Sinks)) + hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0)*hiddenStates.Dim(1), batchSize) + return attn.Output.Forward(ctx, hiddenStates).Add(ctx, residual) } type MLPBlock struct { diff --git a/model/models/llama/model.go b/model/models/llama/model.go index 5ff4894e4..023545213 100644 --- a/model/models/llama/model.go +++ b/model/models/llama/model.go @@ -2,7 +2,6 @@ package llama import ( "cmp" - "math" "github.com/ollama/ollama/fs" "github.com/ollama/ollama/kvcache" @@ -131,7 +130,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positions ml.Tenso query = opts.applyRotaryPositionEmbeddings(ctx, query, positions, sa.RopeFactors) key = opts.applyRotaryPositionEmbeddings(ctx, key, positions, sa.RopeFactors) - attention := nn.Attention(ctx, query, key, value, 1.0/math.Sqrt(float64(headDim)), cache) + attention := nn.Attention(ctx, query, key, value, cache) attention = attention.Reshape(ctx, headDim*opts.numHeads, batchSize) return sa.Output.Forward(ctx, attention) } diff --git a/model/models/llama4/model_text.go b/model/models/llama4/model_text.go index c2bf06148..692fd2396 100644 --- a/model/models/llama4/model_text.go +++ b/model/models/llama4/model_text.go @@ -45,7 +45,7 @@ func (sa *TextAttention) Forward(ctx ml.Context, hiddenStates, positions, attent query = query.Mul(ctx, attentionScales) } - attention := nn.Attention(ctx, query, key, value, 1./math.Sqrt(float64(headDim)), cache) + attention := nn.Attention(ctx, query, key, value, cache) attention = attention.Reshape(ctx, opts.hiddenSize, batchSize) return sa.Output.Forward(ctx, attention) } diff --git a/model/models/llama4/model_vision.go b/model/models/llama4/model_vision.go index ff6b7fcf2..0186b5db1 100644 --- a/model/models/llama4/model_vision.go +++ b/model/models/llama4/model_vision.go @@ -72,7 +72,7 @@ func (sa *VisionAttention) Forward(ctx ml.Context, hiddenState, cos, sin ml.Tens query = applyVisionRotaryEmbedding(ctx, query, cos, sin) key = applyVisionRotaryEmbedding(ctx, key, cos, sin) - attention := nn.Attention(ctx, query, key, value, 1./math.Sqrt(float64(headDim)), nil) + attention := nn.Attention(ctx, query, key, value, nil) attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2), attention.Dim(3)) return sa.Output.Forward(ctx, attention) } diff --git a/model/models/mistral3/model_text.go b/model/models/mistral3/model_text.go index 36106107b..cb5b1c090 100644 --- a/model/models/mistral3/model_text.go +++ b/model/models/mistral3/model_text.go @@ -79,7 +79,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs, posit q = q.Mul(ctx, positionsScale) } - kqv := nn.Attention(ctx, q, k, v, 1.0/math.Sqrt(float64(headDim)), cache) + kqv := nn.Attention(ctx, q, k, v, cache) kqv = kqv.Reshape(ctx, headDim*opts.numHeads, batchSize) return sa.Output.Forward(ctx, kqv) } diff --git a/model/models/mistral3/model_vision.go b/model/models/mistral3/model_vision.go index 1de0412d5..8e291e9d0 100644 --- a/model/models/mistral3/model_vision.go +++ b/model/models/mistral3/model_vision.go @@ -39,7 +39,7 @@ func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenStates, cos, sin ml query = applyRotaryPositionEmbeddings(ctx, query, cos, sin) key = applyRotaryPositionEmbeddings(ctx, key, cos, sin) - attention := nn.Attention(ctx, query, key, value, 1./math.Sqrt(float64(opts.headDim)), nil) + attention := nn.Attention(ctx, query, key, value, nil) attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2), batchSize) return sa.Output.Forward(ctx, attention) } diff --git a/model/models/mllama/model_text.go b/model/models/mllama/model_text.go index afd674eb9..fd6e2fb3f 100644 --- a/model/models/mllama/model_text.go +++ b/model/models/mllama/model_text.go @@ -1,7 +1,6 @@ package mllama import ( - "math" "slices" "github.com/ollama/ollama/fs" @@ -34,8 +33,7 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions ml.T value := sa.Value.Forward(ctx, hiddenState) value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize) - scaleFactor := 1.0 / math.Sqrt(float64(headDim)) - attention := nn.Attention(ctx, query, key, value, scaleFactor, cache) + attention := nn.Attention(ctx, query, key, value, cache) attention = attention.Reshape(ctx, opts.hiddenSize, batchSize) return sa.Output.Forward(ctx, attention) @@ -122,20 +120,7 @@ func (ca *TextCrossAttention) Forward(ctx ml.Context, hiddenState, crossAttentio } key, value, _ = cache.Get(ctx) - - scaleFactor := 1.0 / math.Sqrt(float64(headDim)) - - query = query.Permute(ctx, 0, 2, 1, 3) - key = key.Permute(ctx, 0, 2, 1, 3) - value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx) - - kq := key.MulmatFullPrec(ctx, query) - - kq = kq.Scale(ctx, scaleFactor) - kq = kq.Softmax(ctx) - - kqv := value.Mulmat(ctx, kq) - attention := kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx) + attention := nn.Attention(ctx, query, key, value, nil) attention = attention.Reshape(ctx, opts.hiddenSize, batchSize) return ca.Output.Forward(ctx, attention) diff --git a/model/models/mllama/model_vision.go b/model/models/mllama/model_vision.go index 2d4249472..2c48ae870 100644 --- a/model/models/mllama/model_vision.go +++ b/model/models/mllama/model_vision.go @@ -1,7 +1,6 @@ package mllama import ( - "math" "slices" "github.com/ollama/ollama/fs" @@ -30,7 +29,7 @@ func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenState ml.Tensor, op value := sa.Value.Forward(ctx, hiddenState) value = value.Reshape(ctx, headDim, opts.numHeads, value.Dim(1), batchSize) - attention := nn.Attention(ctx, query, key, value, 1./math.Sqrt(float64(headDim)), nil) + attention := nn.Attention(ctx, query, key, value, nil) attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2), batchSize) return sa.Output.Forward(ctx, attention) } diff --git a/model/models/nomicbert/model.go b/model/models/nomicbert/model.go index 096d046a0..75c41de6f 100644 --- a/model/models/nomicbert/model.go +++ b/model/models/nomicbert/model.go @@ -2,7 +2,6 @@ package nomicbert import ( "cmp" - "math" "github.com/ollama/ollama/fs" "github.com/ollama/ollama/ml" @@ -166,7 +165,7 @@ func (a *Attention) Forward(ctx ml.Context, hiddenStates ml.Tensor, positions ml query = opts.applyRotaryPositionEmbeddings(ctx, query, positions) key = opts.applyRotaryPositionEmbeddings(ctx, key, positions) - attention := nn.Attention(ctx, query, key, value, 1.0/math.Sqrt(float64(opts.headDim)), nil) + attention := nn.Attention(ctx, query, key, value, nil) attention = attention.Reshape(ctx, opts.hiddenSize, batchSize) diff --git a/model/models/olmo3/model.go b/model/models/olmo3/model.go index 523c00e68..ba1287279 100644 --- a/model/models/olmo3/model.go +++ b/model/models/olmo3/model.go @@ -2,7 +2,6 @@ package olmo3 import ( "fmt" - "math" "github.com/ollama/ollama/fs" "github.com/ollama/ollama/kvcache" @@ -132,7 +131,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positions ml.Tenso value := sa.Value.Forward(ctx, hiddenState) value = value.Reshape(ctx, headDim, m.numKVHeads, batchSize) - attention := nn.Attention(ctx, query, key, value, 1.0/math.Sqrt(float64(headDim)), cache) + attention := nn.Attention(ctx, query, key, value, cache) attention = attention.Reshape(ctx, m.hiddenSize, batchSize) return sa.Output.Forward(ctx, attention) diff --git a/model/models/qwen2/model.go b/model/models/qwen2/model.go index 66f546ae6..36d5e26f7 100644 --- a/model/models/qwen2/model.go +++ b/model/models/qwen2/model.go @@ -3,7 +3,6 @@ package qwen2 import ( "cmp" "fmt" - "math" "strings" "github.com/ollama/ollama/fs" @@ -48,7 +47,7 @@ func (attn Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor, query = opts.applyRotaryPositionEmbeddings(ctx, query, positions) key = opts.applyRotaryPositionEmbeddings(ctx, key, positions) - attention := nn.Attention(ctx, query, key, value, 1.0/math.Sqrt(float64(headDim)), cache) + attention := nn.Attention(ctx, query, key, value, cache) attention = attention.Reshape(ctx, headDim*opts.numHeads, batchSize) return attn.Output.Forward(ctx, attention) diff --git a/model/models/qwen25vl/model_text.go b/model/models/qwen25vl/model_text.go index 61b072d67..6f209d2a1 100644 --- a/model/models/qwen25vl/model_text.go +++ b/model/models/qwen25vl/model_text.go @@ -1,8 +1,6 @@ package qwen25vl import ( - "math" - "github.com/ollama/ollama/fs" "github.com/ollama/ollama/kvcache" "github.com/ollama/ollama/ml" @@ -81,8 +79,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten v := sa.Value.Forward(ctx, hiddenState) v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize) - scaleFactor := 1.0 / math.Sqrt(float64(headDim)) - kqv := nn.Attention(ctx, q, k, v, scaleFactor, cache) + kqv := nn.Attention(ctx, q, k, v, cache) kqv = kqv.Reshape(ctx, opts.hiddenSize, batchSize) return sa.Output.Forward(ctx, kqv) diff --git a/model/models/qwen25vl/model_vision.go b/model/models/qwen25vl/model_vision.go index f1275437f..b0f2eb54b 100644 --- a/model/models/qwen25vl/model_vision.go +++ b/model/models/qwen25vl/model_vision.go @@ -8,6 +8,7 @@ import ( "github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml/nn" "github.com/ollama/ollama/ml/nn/rope" + "github.com/ollama/ollama/ml/nn/attention" ) func blockDiagonalMask(ctx ml.Context, seqLength int, bounds []int) ml.Tensor { @@ -50,25 +51,9 @@ func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenStates, positions, query = opts.applyRotaryPositionEmbeddings(ctx, query, positions) key = opts.applyRotaryPositionEmbeddings(ctx, key, positions) - // Scale factor for scaled dot-product attention - scale := 1.0 / math.Sqrt(float64(opts.headDim)) - - // Scaled dot-product attention - query = query.Permute(ctx, 0, 2, 1, 3) - key = key.Permute(ctx, 0, 2, 1, 3) - value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx) - - kq := key.MulmatFullPrec(ctx, query) - kq = kq.Scale(ctx, scale) - if mask != nil { - kq = kq.Add(ctx, mask) - } - kq = kq.Softmax(ctx) - kqv := value.Mulmat(ctx, kq) - attention := kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx) - attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2)) - - return sa.Output.Forward(ctx, attention) + hiddenStates = nn.Attention(ctx, query, key, value, nil, attention.WithMask(mask)) + hiddenStates = hiddenStates.Reshape(ctx, opts.hiddenSize, hiddenStates.Dim(2)) + return sa.Output.Forward(ctx, hiddenStates) } // VisionMLP implements the multi-layer perceptron diff --git a/model/models/qwen3/model.go b/model/models/qwen3/model.go index d7747364e..c081aac7c 100644 --- a/model/models/qwen3/model.go +++ b/model/models/qwen3/model.go @@ -74,7 +74,7 @@ func (sa *Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor, query = opts.applyRotaryPositionEmbeddings(ctx, query, positions) key = opts.applyRotaryPositionEmbeddings(ctx, key, positions) - attention := nn.Attention(ctx, query, key, value, 1./math.Sqrt(float64(opts.headDim())), cache) + attention := nn.Attention(ctx, query, key, value, cache) attention = attention.Reshape(ctx, attention.Dim(0)*attention.Dim(1), batchSize) return sa.Output.Forward(ctx, attention) } diff --git a/model/models/qwen3vl/model_text.go b/model/models/qwen3vl/model_text.go index 750c2473a..8fc1f80bd 100644 --- a/model/models/qwen3vl/model_text.go +++ b/model/models/qwen3vl/model_text.go @@ -66,7 +66,7 @@ func (sa *TextAttention) Forward(ctx ml.Context, hiddenStates, positions ml.Tens query = opts.applyRotaryPositionEmbeddings(ctx, query, positions) key = opts.applyRotaryPositionEmbeddings(ctx, key, positions) - attention := nn.Attention(ctx, query, key, value, 1./math.Sqrt(float64(opts.headDim())), cache) + attention := nn.Attention(ctx, query, key, value, cache) attention = attention.Reshape(ctx, attention.Dim(0)*attention.Dim(1), batchSize) return sa.Output.Forward(ctx, attention) } diff --git a/model/models/qwen3vl/model_vision.go b/model/models/qwen3vl/model_vision.go index 761281edc..3ec2061b2 100644 --- a/model/models/qwen3vl/model_vision.go +++ b/model/models/qwen3vl/model_vision.go @@ -39,7 +39,7 @@ func (sa *VisionAttention) Forward(ctx ml.Context, hiddenStates, cos, sin ml.Ten value := sa.Value.Forward(ctx, hiddenStates) value = value.Reshape(ctx, opts.headDim(), opts.numHeads, value.Dim(1)) - attention := nn.Attention(ctx, query, key, value, math.Pow(float64(opts.headDim()), -0.5), nil) + attention := nn.Attention(ctx, query, key, value, nil) attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2)) return sa.Output.Forward(ctx, attention) }