cleanup attention interface
the updated interface supports variadic attention options which removes the need for individual `AttentionWith...` functions. it means more models can use the attention interface, e.g. models with custom masks, logit softcapping, etc. additionally, this interface should be less error prone since there are now reasonable defaults for all optional parameters
This commit is contained in:
parent
89eb795293
commit
de82b1f9a3
|
|
@ -19,6 +19,7 @@ import (
|
|||
"io"
|
||||
"log/slog"
|
||||
"maps"
|
||||
"math"
|
||||
"os"
|
||||
"runtime"
|
||||
"slices"
|
||||
|
|
@ -35,6 +36,7 @@ import (
|
|||
"github.com/ollama/ollama/logutil"
|
||||
"github.com/ollama/ollama/ml"
|
||||
ggml "github.com/ollama/ollama/ml/backend/ggml/ggml/src"
|
||||
"github.com/ollama/ollama/ml/nn/attention"
|
||||
"github.com/ollama/ollama/ml/nn/rope"
|
||||
"golang.org/x/sync/errgroup"
|
||||
)
|
||||
|
|
@ -1849,3 +1851,89 @@ func (t *Tensor) ChunkSections(ctx ml.Context, dim int, sections ...int) []ml.Te
|
|||
}
|
||||
return s
|
||||
}
|
||||
|
||||
func (t *Tensor) SDPA(ctx ml.Context, key, value ml.Tensor, fns ...func(*attention.Options)) ml.Tensor {
|
||||
opts := attention.Options{
|
||||
Scale: 1 / math.Sqrt(float64(t.Dim(0))),
|
||||
}
|
||||
|
||||
for _, fn := range fns {
|
||||
fn(&opts)
|
||||
}
|
||||
|
||||
if !opts.Cached {
|
||||
config := t.b.CacheConfig()
|
||||
if config.PermutedV {
|
||||
value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
|
||||
}
|
||||
|
||||
if opts.Mask != nil {
|
||||
if padSize := int(pad(C.size_t(opts.Mask.Dim(1)), C.size_t(config.MaskBatchPadding))) - opts.Mask.Dim(1); padSize > 0 {
|
||||
opts.Mask = opts.Mask.Pad(ctx, 0, padSize, 0, 0)
|
||||
}
|
||||
|
||||
if opts.Mask.DType() != config.MaskDType {
|
||||
opts.Mask = opts.Mask.Cast(ctx, config.MaskDType)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
query := t.Permute(ctx, 0, 2, 1, 3)
|
||||
key = key.Permute(ctx, 0, 2, 1, 3)
|
||||
|
||||
var mask *C.struct_ggml_tensor
|
||||
if opts.Mask != nil {
|
||||
mask = opts.Mask.(*Tensor).t
|
||||
}
|
||||
|
||||
if t.b.flashAttention == ml.FlashAttentionEnabled {
|
||||
value = value.Permute(ctx, 0, 2, 1, 3)
|
||||
|
||||
tt := C.ggml_flash_attn_ext(ctx.(*Context).ctx, query.(*Tensor).t, key.(*Tensor).t, value.(*Tensor).t, mask, C.float(opts.Scale), 0, C.float(opts.LogitSoftcap))
|
||||
C.ggml_flash_attn_ext_set_prec(tt, C.GGML_PREC_F32)
|
||||
if opts.Sinks != nil {
|
||||
C.ggml_flash_attn_ext_add_sinks(tt, opts.Sinks.(*Tensor).t)
|
||||
}
|
||||
|
||||
var attention ml.Tensor = &Tensor{b: t.b, t: tt}
|
||||
if opts.MLA != nil {
|
||||
attention = attention.Permute(ctx, 0, 2, 1, 3)
|
||||
attention = opts.MLA.Mulmat(ctx, attention)
|
||||
attention = attention.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
|
||||
}
|
||||
|
||||
return attention
|
||||
}
|
||||
|
||||
scores := key.Mulmat(ctx, query)
|
||||
C.ggml_mul_mat_set_prec(scores.(*Tensor).t, C.GGML_PREC_F32)
|
||||
if opts.LogitSoftcap > 0 {
|
||||
scores = scores.Scale(ctx, 1/float64(opts.LogitSoftcap)).Tanh(ctx).Scale(ctx, float64(opts.LogitSoftcap))
|
||||
}
|
||||
|
||||
if opts.Cached {
|
||||
scores = &Tensor{b: t.b, t: C.ggml_soft_max_ext(ctx.(*Context).ctx, scores.(*Tensor).t, mask, C.float(opts.Scale), 0)}
|
||||
} else {
|
||||
scores = scores.Scale(ctx, opts.Scale)
|
||||
if opts.Mask != nil {
|
||||
scores = scores.Add(ctx, opts.Mask)
|
||||
}
|
||||
|
||||
scores = scores.Softmax(ctx)
|
||||
}
|
||||
|
||||
if opts.Sinks != nil {
|
||||
C.ggml_soft_max_add_sinks(scores.(*Tensor).t, opts.Sinks.(*Tensor).t)
|
||||
}
|
||||
|
||||
if key.Dim(1) == value.Dim(2) && key.Dim(2) == value.Dim(1) {
|
||||
value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
|
||||
}
|
||||
|
||||
attention := value.Mulmat(ctx, scores)
|
||||
if opts.MLA != nil {
|
||||
attention = opts.MLA.Mulmat(ctx, attention)
|
||||
}
|
||||
|
||||
return attention.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,12 +1,17 @@
|
|||
package nn
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
|
||||
"github.com/ollama/ollama/kvcache"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/ml/nn/attention"
|
||||
)
|
||||
|
||||
type fastAttention interface {
|
||||
SDPA(ctx ml.Context, key, value ml.Tensor, opts ...func(*attention.Options)) ml.Tensor
|
||||
}
|
||||
|
||||
// Attention implements scaled dot-product attention for transformer models:
|
||||
// Attention(Q, K, V) = softmax(QK^T/√d_k)V
|
||||
//
|
||||
|
|
@ -21,27 +26,19 @@ import (
|
|||
// Returns:
|
||||
//
|
||||
// Attention output with shape [d_v, heads, seq_len_q]
|
||||
func Attention(ctx ml.Context, query, key, value ml.Tensor, scale float64, cache kvcache.Cache) ml.Tensor {
|
||||
return AttentionWithVMLA(ctx, query, key, value, nil, nil, scale, cache)
|
||||
}
|
||||
|
||||
func AttentionWithSinks(ctx ml.Context, query, key, value, sinks ml.Tensor, scale float64, cache kvcache.Cache) ml.Tensor {
|
||||
return AttentionWithVMLA(ctx, query, key, value, sinks, nil, scale, cache)
|
||||
}
|
||||
|
||||
func AttentionWithVMLA(ctx ml.Context, query, key, value, sinks ml.Tensor, vmla ml.Tensor, scale float64, cache kvcache.Cache) ml.Tensor {
|
||||
ctx.Forward(query)
|
||||
func Attention(ctx ml.Context, query, key, value ml.Tensor, cache kvcache.Cache, fns ...func(*attention.Options)) ml.Tensor {
|
||||
if key != nil && value != nil {
|
||||
if query.Dim(0) != key.Dim(0) {
|
||||
panic(fmt.Errorf("d_k in attention operation does not match between query(%v) and key(%v)", query.Dim(0), key.Dim(0)))
|
||||
log.Fatalf("d_k in attention operation does not match between query(%v) and key(%v)", query.Dim(0), key.Dim(0))
|
||||
}
|
||||
|
||||
if key.Dim(1) != value.Dim(1) {
|
||||
panic(fmt.Errorf("kv_heads in attention operation does not match between key(%v) and value(%v)", key.Dim(1), value.Dim(1)))
|
||||
log.Fatalf("kv_heads in attention operation does not match between key(%v) and value(%v)", key.Dim(1), value.Dim(1))
|
||||
}
|
||||
|
||||
if key.Dim(2) != value.Dim(2) {
|
||||
panic(fmt.Errorf("seq_len_k in attention operation does not match between key(%v) and value(%v)", key.Dim(2), value.Dim(2)))
|
||||
log.Fatalf("seq_len_k in attention operation does not match between key(%v) and value(%v)", key.Dim(2), value.Dim(2))
|
||||
}
|
||||
|
||||
ctx.Forward(key, value)
|
||||
|
|
@ -57,28 +54,12 @@ func AttentionWithVMLA(ctx ml.Context, query, key, value, sinks ml.Tensor, vmla
|
|||
key, value, mask = cache.Get(ctx)
|
||||
}
|
||||
|
||||
if sdpa, ok := query.(ml.ScaledDotProductAttention); ok {
|
||||
cacheConfigApplied := cache != nil
|
||||
return sdpa.ScaledDotProductAttention(ctx, key, value, mask, sinks, vmla, scale, cacheConfigApplied)
|
||||
} else {
|
||||
query = query.Permute(ctx, 0, 2, 1, 3)
|
||||
key = key.Permute(ctx, 0, 2, 1, 3)
|
||||
value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
|
||||
|
||||
kq := key.MulmatFullPrec(ctx, query)
|
||||
|
||||
kq = kq.Scale(ctx, scale)
|
||||
if mask != nil {
|
||||
kq = kq.Add(ctx, mask)
|
||||
}
|
||||
kq = kq.Softmax(ctx)
|
||||
|
||||
kqv := value.Mulmat(ctx, kq)
|
||||
|
||||
if vmla != nil {
|
||||
kqv = vmla.Mulmat(ctx, kqv)
|
||||
}
|
||||
|
||||
return kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
|
||||
if t, ok := query.(fastAttention); ok {
|
||||
return t.SDPA(ctx, key, value, append([]func(*attention.Options){
|
||||
attention.WithMask(mask),
|
||||
func(opts *attention.Options) { opts.Cached = cache != nil },
|
||||
}, fns...)...)
|
||||
}
|
||||
|
||||
panic("Attention not implemented for this tensor type")
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,55 @@
|
|||
package attention
|
||||
|
||||
import (
|
||||
"github.com/ollama/ollama/ml"
|
||||
)
|
||||
|
||||
type Options struct {
|
||||
// Scale is a scaling factor applied to the attention scores. Default is 1/√d_k.
|
||||
Scale float64
|
||||
|
||||
// LogitSoftcap is used to apply a soft cap to the logits before softmax.
|
||||
LogitSoftcap float32
|
||||
|
||||
// Mask is used in some attention mechanisms to mask out certain positions.
|
||||
Mask ml.Tensor
|
||||
|
||||
// Sinks is used in some attention mechanisms to store additional data.
|
||||
Sinks ml.Tensor
|
||||
|
||||
// MLA is used in some attention mechanisms for multi-latent attention.
|
||||
MLA ml.Tensor
|
||||
|
||||
// Cached indicates whether key/value were retrieved from cache.
|
||||
Cached bool
|
||||
}
|
||||
|
||||
func WithScale(scale float64) func(*Options) {
|
||||
return func(o *Options) {
|
||||
o.Scale = scale
|
||||
}
|
||||
}
|
||||
|
||||
func WithSinks(sinks ml.Tensor) func(*Options) {
|
||||
return func(o *Options) {
|
||||
o.Sinks = sinks
|
||||
}
|
||||
}
|
||||
|
||||
func WithMLA(mla ml.Tensor) func(*Options) {
|
||||
return func(o *Options) {
|
||||
o.MLA = mla
|
||||
}
|
||||
}
|
||||
|
||||
func WithMask(mask ml.Tensor) func(*Options) {
|
||||
return func(o *Options) {
|
||||
o.Mask = mask
|
||||
}
|
||||
}
|
||||
|
||||
func WithLogitSoftcap(softcap float32) func(*Options) {
|
||||
return func(o *Options) {
|
||||
o.LogitSoftcap = softcap
|
||||
}
|
||||
}
|
||||
|
|
@ -2,7 +2,6 @@ package bert
|
|||
|
||||
import (
|
||||
"cmp"
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/fs"
|
||||
"github.com/ollama/ollama/ml"
|
||||
|
|
@ -99,7 +98,7 @@ func (a *Attention) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Option
|
|||
value := a.Value.Forward(ctx, hiddenStates)
|
||||
value = value.Reshape(ctx, opts.headDim(), cmp.Or(opts.numKVHeads, opts.numHeads), batchSize)
|
||||
|
||||
attention := nn.Attention(ctx, query, key, value, 1/math.Sqrt(float64(opts.headDim())), nil)
|
||||
attention := nn.Attention(ctx, query, key, value, nil)
|
||||
attention = attention.Reshape(ctx, opts.hiddenSize, batchSize)
|
||||
return a.Output.Forward(ctx, attention)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ import (
|
|||
"github.com/ollama/ollama/kvcache"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
"github.com/ollama/ollama/ml/nn/attention"
|
||||
"github.com/ollama/ollama/ml/nn/rope"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
|
|
@ -66,22 +67,22 @@ type Attention struct {
|
|||
Output *nn.Linear `gguf:"attn_out,alt:attn_output"`
|
||||
}
|
||||
|
||||
func (attn *Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
|
||||
func (m *Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
|
||||
seqLength := hiddenStates.Dim(1)
|
||||
|
||||
var query ml.Tensor
|
||||
if opts.qLoraRank == 0 {
|
||||
query = attn.Q.Forward(ctx, hiddenStates)
|
||||
query = m.Q.Forward(ctx, hiddenStates)
|
||||
} else {
|
||||
query = attn.QA.Forward(ctx, hiddenStates)
|
||||
query = attn.QANorm.Forward(ctx, query, opts.eps)
|
||||
query = attn.QB.Forward(ctx, query)
|
||||
query = m.QA.Forward(ctx, hiddenStates)
|
||||
query = m.QANorm.Forward(ctx, query, opts.eps)
|
||||
query = m.QB.Forward(ctx, query)
|
||||
}
|
||||
|
||||
query = query.Reshape(ctx, query.Dim(0)/opts.numHeads, opts.numHeads, seqLength)
|
||||
queryChunks := query.ChunkSections(ctx, 0, opts.qkNopeHeadDim, opts.qkRopeHeadDim)
|
||||
|
||||
compressedKV := attn.KVA.Forward(ctx, hiddenStates)
|
||||
compressedKV := m.KVA.Forward(ctx, hiddenStates)
|
||||
kPass := compressedKV.Slice(ctx, 0, 0, opts.kvLoraRank, 1)
|
||||
kRot := compressedKV.View(ctx,
|
||||
opts.kvLoraRank*compressedKV.Stride(0), opts.qkRopeHeadDim,
|
||||
|
|
@ -91,12 +92,10 @@ func (attn *Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor
|
|||
|
||||
qRot := opts.applyRotaryPositionEmbeddings(ctx, queryChunks[1], positions)
|
||||
kRot = opts.applyRotaryPositionEmbeddings(ctx, kRot, positions)
|
||||
kPass = attn.KVANorm.Forward(ctx, kPass, opts.eps)
|
||||
|
||||
var attention ml.Tensor
|
||||
kPass = m.KVANorm.Forward(ctx, kPass, opts.eps)
|
||||
|
||||
if !opts.isMLA { // v3
|
||||
kPass = attn.KVB.Forward(ctx, kPass)
|
||||
kPass = m.KVB.Forward(ctx, kPass)
|
||||
|
||||
kv := kPass.Reshape(ctx, kPass.Dim(0)/opts.numKVHeads, opts.numKVHeads, seqLength)
|
||||
kvChunks := kv.ChunkSections(ctx, 0, opts.kqNopeHeadDim, opts.vHeadDim)
|
||||
|
|
@ -104,10 +103,10 @@ func (attn *Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor
|
|||
kRot = kRot.Repeat(ctx, 1, queryChunks[0].Dim(1))
|
||||
query = qRot.Concat(ctx, queryChunks[0], 0)
|
||||
key := kRot.Concat(ctx, kvChunks[0], 0)
|
||||
attention = nn.Attention(ctx, query, key, kvChunks[1], opts.kqScale, cache)
|
||||
hiddenStates = nn.Attention(ctx, query, key, kvChunks[1], cache, attention.WithScale(opts.kqScale))
|
||||
} else { // v3.1
|
||||
qPass := queryChunks[0].Permute(ctx, 0, 2, 1, 3)
|
||||
qPassAbsorb := attn.KB.Forward(ctx, qPass)
|
||||
qPassAbsorb := m.KB.Forward(ctx, qPass)
|
||||
qPassAbsorb = qPassAbsorb.Permute(ctx, 0, 2, 1, 3)
|
||||
|
||||
query = qRot.Concat(ctx, qPassAbsorb, 0)
|
||||
|
|
@ -115,11 +114,14 @@ func (attn *Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor
|
|||
key := kRot.Concat(ctx, kPass, 0)
|
||||
value := kPass
|
||||
|
||||
attention = nn.AttentionWithVMLA(ctx, query, key, value, nil, attn.VB.Weight, opts.kqScale, cache)
|
||||
hiddenStates = nn.Attention(ctx, query, key, value, cache,
|
||||
attention.WithMLA(m.VB.Weight),
|
||||
attention.WithScale(opts.kqScale),
|
||||
)
|
||||
}
|
||||
|
||||
attention = attention.Reshape(ctx, attention.Dim(0)*attention.Dim(1), seqLength)
|
||||
return attn.Output.Forward(ctx, attention)
|
||||
hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0)*hiddenStates.Dim(1), seqLength)
|
||||
return m.Output.Forward(ctx, hiddenStates)
|
||||
}
|
||||
|
||||
type MLP interface {
|
||||
|
|
|
|||
|
|
@ -1,11 +1,11 @@
|
|||
package deepseekocr
|
||||
|
||||
import (
|
||||
"math"
|
||||
"slices"
|
||||
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
"github.com/ollama/ollama/ml/nn/attention"
|
||||
)
|
||||
|
||||
type samModel struct {
|
||||
|
|
@ -166,23 +166,13 @@ func (m *samAttention) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts samO
|
|||
|
||||
ctx.Forward(query, key, value)
|
||||
|
||||
query = query.Permute(ctx, 0, 2, 1, 3)
|
||||
rh, rw := m.decomposedRelativePositions(ctx, query, []int{h, w}, []int{h, w})
|
||||
rh, rw := m.decomposedRelativePositions(ctx, query.Permute(ctx, 0, 2, 1, 3), []int{h, w}, []int{h, w})
|
||||
mask := rh.Repeat(ctx, 0, rw.Dim(0)).Add(ctx, rw)
|
||||
mask = mask.Reshape(ctx, h*w, -1, opts.numHeads, b)
|
||||
|
||||
key = key.Permute(ctx, 0, 2, 1, 3)
|
||||
scores := key.MulmatFullPrec(ctx, query)
|
||||
scores = scores.Scale(ctx, 1/math.Sqrt(float64(opts.headDim())))
|
||||
|
||||
scores = scores.Add(ctx, mask)
|
||||
scores = scores.Softmax(ctx)
|
||||
|
||||
value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
|
||||
attention := value.Mulmat(ctx, scores)
|
||||
attention = attention.Permute(ctx, 0, 2, 1, 3)
|
||||
attention = attention.Contiguous(ctx, -1, w, h, b)
|
||||
return m.Output.Forward(ctx, attention)
|
||||
hiddenStates = nn.Attention(ctx, query, key, value, nil, attention.WithMask(mask))
|
||||
hiddenStates = hiddenStates.Contiguous(ctx, -1, w, h, b)
|
||||
return m.Output.Forward(ctx, hiddenStates)
|
||||
}
|
||||
|
||||
type samMLP struct {
|
||||
|
|
|
|||
|
|
@ -1,8 +1,6 @@
|
|||
package deepseekocr
|
||||
|
||||
import (
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/kvcache"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
|
|
@ -85,7 +83,7 @@ func (m *textAttention) Forward(ctx ml.Context, hiddenStates, positions ml.Tenso
|
|||
query = opts.applyRotaryPositionEmbeddings(ctx, query, positions)
|
||||
key = opts.applyRotaryPositionEmbeddings(ctx, key, positions)
|
||||
|
||||
attention := nn.Attention(ctx, query, key, value, 1./math.Sqrt(float64(opts.headDim())), cache)
|
||||
attention := nn.Attention(ctx, query, key, value, cache)
|
||||
attention = attention.Reshape(ctx, -1, attention.Dim(2))
|
||||
return m.Output.Forward(ctx, attention)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -102,7 +102,7 @@ func (m *visionAttention) Forward(ctx ml.Context, t ml.Tensor, opts visionOption
|
|||
chunks := qkv.Chunk(ctx, 1, opts.numHeads)
|
||||
query, key, value := chunks[0], chunks[1], chunks[2]
|
||||
|
||||
attention := nn.Attention(ctx, query, key, value, 1/math.Sqrt(float64(opts.headDim())), nil)
|
||||
attention := nn.Attention(ctx, query, key, value, nil)
|
||||
attention = attention.Reshape(ctx, -1, attention.Dim(2), attention.Dim(3))
|
||||
return m.Output.Forward(ctx, attention)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ import (
|
|||
"github.com/ollama/ollama/kvcache"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
"github.com/ollama/ollama/ml/nn/attention"
|
||||
"github.com/ollama/ollama/ml/nn/rope"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
|
|
@ -106,28 +107,13 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
|
|||
v := sa.Value.Forward(ctx, hiddenState)
|
||||
v = v.Reshape(ctx, opts.attnValLen, opts.numKVHeads, batchSize)
|
||||
|
||||
cache.Put(ctx, k, v)
|
||||
k, v, mask := cache.Get(ctx)
|
||||
hiddenState = nn.Attention(ctx, q, k, v, cache,
|
||||
attention.WithLogitSoftcap(opts.attnLogitSoftcap),
|
||||
attention.WithScale(1),
|
||||
)
|
||||
|
||||
q = q.Permute(ctx, 0, 2, 1, 3)
|
||||
k = k.Permute(ctx, 0, 2, 1, 3)
|
||||
v = v.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
|
||||
|
||||
kq := k.Mulmat(ctx, q)
|
||||
|
||||
// logit softcap
|
||||
kq = kq.Scale(ctx, 1.0/float64(opts.attnLogitSoftcap))
|
||||
kq = kq.Tanh(ctx)
|
||||
kq = kq.Scale(ctx, float64(opts.attnLogitSoftcap))
|
||||
|
||||
kq = kq.Add(ctx, mask)
|
||||
kq = kq.Softmax(ctx)
|
||||
|
||||
kqv := v.Mulmat(ctx, kq)
|
||||
kqv = kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
|
||||
kqv = kqv.Reshape(ctx, opts.attnValLen*opts.numHeads, batchSize)
|
||||
|
||||
return sa.Output.Forward(ctx, kqv)
|
||||
hiddenState = hiddenState.Reshape(ctx, opts.attnValLen*opts.numHeads, batchSize)
|
||||
return sa.Output.Forward(ctx, hiddenState)
|
||||
}
|
||||
|
||||
func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ import (
|
|||
"github.com/ollama/ollama/kvcache"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
"github.com/ollama/ollama/ml/nn/attention"
|
||||
"github.com/ollama/ollama/ml/nn/rope"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
)
|
||||
|
|
@ -165,8 +166,7 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, pos
|
|||
v := sa.Value.Forward(ctx, hiddenState)
|
||||
v = v.Reshape(ctx, opts.attnValLen, opts.numKVHeads, batchSize)
|
||||
|
||||
scaleFactor := 1.0
|
||||
kqv := nn.Attention(ctx, q, k, v, scaleFactor, cache)
|
||||
kqv := nn.Attention(ctx, q, k, v, cache, attention.WithScale(1))
|
||||
kqv = kqv.Reshape(ctx, opts.attnValLen*opts.numHeads, batchSize)
|
||||
|
||||
return sa.Output.Forward(ctx, kqv)
|
||||
|
|
|
|||
|
|
@ -1,8 +1,6 @@
|
|||
package gemma3
|
||||
|
||||
import (
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/fs"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
|
|
@ -28,7 +26,7 @@ func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenState ml.Tensor, op
|
|||
key = key.Reshape(ctx, headDim, opts.numHeads, key.Dim(1), batchSize)
|
||||
value = value.Reshape(ctx, headDim, opts.numHeads, value.Dim(1), batchSize)
|
||||
|
||||
attention := nn.Attention(ctx, query, key, value, 1.0/math.Sqrt(float64(headDim)), nil)
|
||||
attention := nn.Attention(ctx, query, key, value, nil)
|
||||
attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2), batchSize)
|
||||
|
||||
hiddenState = sa.Output.Forward(ctx, attention)
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ import (
|
|||
"github.com/ollama/ollama/kvcache"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
"github.com/ollama/ollama/ml/nn/attention"
|
||||
"github.com/ollama/ollama/ml/nn/rope"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
)
|
||||
|
|
@ -269,9 +270,9 @@ func (attn TextAttention) Forward(ctx ml.Context, hiddenStates, positions ml.Ten
|
|||
value = value.RMSNorm(ctx, nil, opts.eps)
|
||||
}
|
||||
|
||||
attention := nn.Attention(ctx, query, key, value, 1., cache)
|
||||
attention = attention.Reshape(ctx, attention.Dim(0)*attention.Dim(1), batchSize)
|
||||
return attn.Output.Forward(ctx, attention)
|
||||
hiddenStates = nn.Attention(ctx, query, key, value, cache, attention.WithScale(1))
|
||||
hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0)*hiddenStates.Dim(1), batchSize)
|
||||
return attn.Output.Forward(ctx, hiddenStates)
|
||||
}
|
||||
|
||||
type TextMLP struct {
|
||||
|
|
|
|||
|
|
@ -2,13 +2,13 @@ package gptoss
|
|||
|
||||
import (
|
||||
"cmp"
|
||||
"math"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/fs"
|
||||
"github.com/ollama/ollama/kvcache"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
"github.com/ollama/ollama/ml/nn/attention"
|
||||
"github.com/ollama/ollama/ml/nn/rope"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
|
|
@ -137,9 +137,9 @@ func (attn *AttentionBlock) Forward(ctx ml.Context, hiddenStates, positions ml.T
|
|||
query = opts.applyRotaryPositionEmbeddings(ctx, query, positions)
|
||||
key = opts.applyRotaryPositionEmbeddings(ctx, key, positions)
|
||||
|
||||
attention := nn.AttentionWithSinks(ctx, query, key, value, attn.Sinks, 1/math.Sqrt(float64(opts.headDim())), cache)
|
||||
attention = attention.Reshape(ctx, attention.Dim(0)*attention.Dim(1), batchSize)
|
||||
return attn.Output.Forward(ctx, attention).Add(ctx, residual)
|
||||
hiddenStates = nn.Attention(ctx, query, key, value, cache, attention.WithSinks(attn.Sinks))
|
||||
hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0)*hiddenStates.Dim(1), batchSize)
|
||||
return attn.Output.Forward(ctx, hiddenStates).Add(ctx, residual)
|
||||
}
|
||||
|
||||
type MLPBlock struct {
|
||||
|
|
|
|||
|
|
@ -2,7 +2,6 @@ package llama
|
|||
|
||||
import (
|
||||
"cmp"
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/fs"
|
||||
"github.com/ollama/ollama/kvcache"
|
||||
|
|
@ -131,7 +130,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positions ml.Tenso
|
|||
query = opts.applyRotaryPositionEmbeddings(ctx, query, positions, sa.RopeFactors)
|
||||
key = opts.applyRotaryPositionEmbeddings(ctx, key, positions, sa.RopeFactors)
|
||||
|
||||
attention := nn.Attention(ctx, query, key, value, 1.0/math.Sqrt(float64(headDim)), cache)
|
||||
attention := nn.Attention(ctx, query, key, value, cache)
|
||||
attention = attention.Reshape(ctx, headDim*opts.numHeads, batchSize)
|
||||
return sa.Output.Forward(ctx, attention)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -45,7 +45,7 @@ func (sa *TextAttention) Forward(ctx ml.Context, hiddenStates, positions, attent
|
|||
query = query.Mul(ctx, attentionScales)
|
||||
}
|
||||
|
||||
attention := nn.Attention(ctx, query, key, value, 1./math.Sqrt(float64(headDim)), cache)
|
||||
attention := nn.Attention(ctx, query, key, value, cache)
|
||||
attention = attention.Reshape(ctx, opts.hiddenSize, batchSize)
|
||||
return sa.Output.Forward(ctx, attention)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -72,7 +72,7 @@ func (sa *VisionAttention) Forward(ctx ml.Context, hiddenState, cos, sin ml.Tens
|
|||
query = applyVisionRotaryEmbedding(ctx, query, cos, sin)
|
||||
key = applyVisionRotaryEmbedding(ctx, key, cos, sin)
|
||||
|
||||
attention := nn.Attention(ctx, query, key, value, 1./math.Sqrt(float64(headDim)), nil)
|
||||
attention := nn.Attention(ctx, query, key, value, nil)
|
||||
attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2), attention.Dim(3))
|
||||
return sa.Output.Forward(ctx, attention)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -79,7 +79,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs, posit
|
|||
q = q.Mul(ctx, positionsScale)
|
||||
}
|
||||
|
||||
kqv := nn.Attention(ctx, q, k, v, 1.0/math.Sqrt(float64(headDim)), cache)
|
||||
kqv := nn.Attention(ctx, q, k, v, cache)
|
||||
kqv = kqv.Reshape(ctx, headDim*opts.numHeads, batchSize)
|
||||
return sa.Output.Forward(ctx, kqv)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -39,7 +39,7 @@ func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenStates, cos, sin ml
|
|||
query = applyRotaryPositionEmbeddings(ctx, query, cos, sin)
|
||||
key = applyRotaryPositionEmbeddings(ctx, key, cos, sin)
|
||||
|
||||
attention := nn.Attention(ctx, query, key, value, 1./math.Sqrt(float64(opts.headDim)), nil)
|
||||
attention := nn.Attention(ctx, query, key, value, nil)
|
||||
attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2), batchSize)
|
||||
return sa.Output.Forward(ctx, attention)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,7 +1,6 @@
|
|||
package mllama
|
||||
|
||||
import (
|
||||
"math"
|
||||
"slices"
|
||||
|
||||
"github.com/ollama/ollama/fs"
|
||||
|
|
@ -34,8 +33,7 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions ml.T
|
|||
value := sa.Value.Forward(ctx, hiddenState)
|
||||
value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
|
||||
|
||||
scaleFactor := 1.0 / math.Sqrt(float64(headDim))
|
||||
attention := nn.Attention(ctx, query, key, value, scaleFactor, cache)
|
||||
attention := nn.Attention(ctx, query, key, value, cache)
|
||||
attention = attention.Reshape(ctx, opts.hiddenSize, batchSize)
|
||||
|
||||
return sa.Output.Forward(ctx, attention)
|
||||
|
|
@ -122,20 +120,7 @@ func (ca *TextCrossAttention) Forward(ctx ml.Context, hiddenState, crossAttentio
|
|||
}
|
||||
|
||||
key, value, _ = cache.Get(ctx)
|
||||
|
||||
scaleFactor := 1.0 / math.Sqrt(float64(headDim))
|
||||
|
||||
query = query.Permute(ctx, 0, 2, 1, 3)
|
||||
key = key.Permute(ctx, 0, 2, 1, 3)
|
||||
value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
|
||||
|
||||
kq := key.MulmatFullPrec(ctx, query)
|
||||
|
||||
kq = kq.Scale(ctx, scaleFactor)
|
||||
kq = kq.Softmax(ctx)
|
||||
|
||||
kqv := value.Mulmat(ctx, kq)
|
||||
attention := kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
|
||||
attention := nn.Attention(ctx, query, key, value, nil)
|
||||
attention = attention.Reshape(ctx, opts.hiddenSize, batchSize)
|
||||
|
||||
return ca.Output.Forward(ctx, attention)
|
||||
|
|
|
|||
|
|
@ -1,7 +1,6 @@
|
|||
package mllama
|
||||
|
||||
import (
|
||||
"math"
|
||||
"slices"
|
||||
|
||||
"github.com/ollama/ollama/fs"
|
||||
|
|
@ -30,7 +29,7 @@ func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenState ml.Tensor, op
|
|||
value := sa.Value.Forward(ctx, hiddenState)
|
||||
value = value.Reshape(ctx, headDim, opts.numHeads, value.Dim(1), batchSize)
|
||||
|
||||
attention := nn.Attention(ctx, query, key, value, 1./math.Sqrt(float64(headDim)), nil)
|
||||
attention := nn.Attention(ctx, query, key, value, nil)
|
||||
attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2), batchSize)
|
||||
return sa.Output.Forward(ctx, attention)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -2,7 +2,6 @@ package nomicbert
|
|||
|
||||
import (
|
||||
"cmp"
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/fs"
|
||||
"github.com/ollama/ollama/ml"
|
||||
|
|
@ -166,7 +165,7 @@ func (a *Attention) Forward(ctx ml.Context, hiddenStates ml.Tensor, positions ml
|
|||
query = opts.applyRotaryPositionEmbeddings(ctx, query, positions)
|
||||
key = opts.applyRotaryPositionEmbeddings(ctx, key, positions)
|
||||
|
||||
attention := nn.Attention(ctx, query, key, value, 1.0/math.Sqrt(float64(opts.headDim)), nil)
|
||||
attention := nn.Attention(ctx, query, key, value, nil)
|
||||
|
||||
attention = attention.Reshape(ctx, opts.hiddenSize, batchSize)
|
||||
|
||||
|
|
|
|||
|
|
@ -2,7 +2,6 @@ package olmo3
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/fs"
|
||||
"github.com/ollama/ollama/kvcache"
|
||||
|
|
@ -132,7 +131,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positions ml.Tenso
|
|||
value := sa.Value.Forward(ctx, hiddenState)
|
||||
value = value.Reshape(ctx, headDim, m.numKVHeads, batchSize)
|
||||
|
||||
attention := nn.Attention(ctx, query, key, value, 1.0/math.Sqrt(float64(headDim)), cache)
|
||||
attention := nn.Attention(ctx, query, key, value, cache)
|
||||
attention = attention.Reshape(ctx, m.hiddenSize, batchSize)
|
||||
|
||||
return sa.Output.Forward(ctx, attention)
|
||||
|
|
|
|||
|
|
@ -3,7 +3,6 @@ package qwen2
|
|||
import (
|
||||
"cmp"
|
||||
"fmt"
|
||||
"math"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/fs"
|
||||
|
|
@ -48,7 +47,7 @@ func (attn Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor,
|
|||
query = opts.applyRotaryPositionEmbeddings(ctx, query, positions)
|
||||
key = opts.applyRotaryPositionEmbeddings(ctx, key, positions)
|
||||
|
||||
attention := nn.Attention(ctx, query, key, value, 1.0/math.Sqrt(float64(headDim)), cache)
|
||||
attention := nn.Attention(ctx, query, key, value, cache)
|
||||
attention = attention.Reshape(ctx, headDim*opts.numHeads, batchSize)
|
||||
|
||||
return attn.Output.Forward(ctx, attention)
|
||||
|
|
|
|||
|
|
@ -1,8 +1,6 @@
|
|||
package qwen25vl
|
||||
|
||||
import (
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/fs"
|
||||
"github.com/ollama/ollama/kvcache"
|
||||
"github.com/ollama/ollama/ml"
|
||||
|
|
@ -81,8 +79,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
|
|||
v := sa.Value.Forward(ctx, hiddenState)
|
||||
v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
|
||||
|
||||
scaleFactor := 1.0 / math.Sqrt(float64(headDim))
|
||||
kqv := nn.Attention(ctx, q, k, v, scaleFactor, cache)
|
||||
kqv := nn.Attention(ctx, q, k, v, cache)
|
||||
kqv = kqv.Reshape(ctx, opts.hiddenSize, batchSize)
|
||||
|
||||
return sa.Output.Forward(ctx, kqv)
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ import (
|
|||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
"github.com/ollama/ollama/ml/nn/rope"
|
||||
"github.com/ollama/ollama/ml/nn/attention"
|
||||
)
|
||||
|
||||
func blockDiagonalMask(ctx ml.Context, seqLength int, bounds []int) ml.Tensor {
|
||||
|
|
@ -50,25 +51,9 @@ func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenStates, positions,
|
|||
query = opts.applyRotaryPositionEmbeddings(ctx, query, positions)
|
||||
key = opts.applyRotaryPositionEmbeddings(ctx, key, positions)
|
||||
|
||||
// Scale factor for scaled dot-product attention
|
||||
scale := 1.0 / math.Sqrt(float64(opts.headDim))
|
||||
|
||||
// Scaled dot-product attention
|
||||
query = query.Permute(ctx, 0, 2, 1, 3)
|
||||
key = key.Permute(ctx, 0, 2, 1, 3)
|
||||
value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
|
||||
|
||||
kq := key.MulmatFullPrec(ctx, query)
|
||||
kq = kq.Scale(ctx, scale)
|
||||
if mask != nil {
|
||||
kq = kq.Add(ctx, mask)
|
||||
}
|
||||
kq = kq.Softmax(ctx)
|
||||
kqv := value.Mulmat(ctx, kq)
|
||||
attention := kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
|
||||
attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2))
|
||||
|
||||
return sa.Output.Forward(ctx, attention)
|
||||
hiddenStates = nn.Attention(ctx, query, key, value, nil, attention.WithMask(mask))
|
||||
hiddenStates = hiddenStates.Reshape(ctx, opts.hiddenSize, hiddenStates.Dim(2))
|
||||
return sa.Output.Forward(ctx, hiddenStates)
|
||||
}
|
||||
|
||||
// VisionMLP implements the multi-layer perceptron
|
||||
|
|
|
|||
|
|
@ -74,7 +74,7 @@ func (sa *Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor,
|
|||
query = opts.applyRotaryPositionEmbeddings(ctx, query, positions)
|
||||
key = opts.applyRotaryPositionEmbeddings(ctx, key, positions)
|
||||
|
||||
attention := nn.Attention(ctx, query, key, value, 1./math.Sqrt(float64(opts.headDim())), cache)
|
||||
attention := nn.Attention(ctx, query, key, value, cache)
|
||||
attention = attention.Reshape(ctx, attention.Dim(0)*attention.Dim(1), batchSize)
|
||||
return sa.Output.Forward(ctx, attention)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -66,7 +66,7 @@ func (sa *TextAttention) Forward(ctx ml.Context, hiddenStates, positions ml.Tens
|
|||
query = opts.applyRotaryPositionEmbeddings(ctx, query, positions)
|
||||
key = opts.applyRotaryPositionEmbeddings(ctx, key, positions)
|
||||
|
||||
attention := nn.Attention(ctx, query, key, value, 1./math.Sqrt(float64(opts.headDim())), cache)
|
||||
attention := nn.Attention(ctx, query, key, value, cache)
|
||||
attention = attention.Reshape(ctx, attention.Dim(0)*attention.Dim(1), batchSize)
|
||||
return sa.Output.Forward(ctx, attention)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -39,7 +39,7 @@ func (sa *VisionAttention) Forward(ctx ml.Context, hiddenStates, cos, sin ml.Ten
|
|||
value := sa.Value.Forward(ctx, hiddenStates)
|
||||
value = value.Reshape(ctx, opts.headDim(), opts.numHeads, value.Dim(1))
|
||||
|
||||
attention := nn.Attention(ctx, query, key, value, math.Pow(float64(opts.headDim()), -0.5), nil)
|
||||
attention := nn.Attention(ctx, query, key, value, nil)
|
||||
attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2))
|
||||
return sa.Output.Forward(ctx, attention)
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue