Compare commits
3 Commits
v0.11.5-rc
...
mxyng/gemm
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
12a7e5ec46 | ||
|
|
b323cfe731 | ||
|
|
05ccb17c6e |
@@ -378,9 +378,7 @@ func (c *Causal) buildMask(ctx ml.Context) ml.Tensor {
|
||||
maskTensor := ctx.Input().FromFloatSlice(mask, length, batchSize)
|
||||
|
||||
if c.config.MaskDType != ml.DTypeF32 {
|
||||
out := ctx.Input().Empty(c.config.MaskDType, maskTensor.Shape()...)
|
||||
ctx.Forward(maskTensor.Copy(ctx, out))
|
||||
maskTensor = out
|
||||
maskTensor = maskTensor.Cast(ctx, c.config.MaskDType)
|
||||
}
|
||||
|
||||
return maskTensor
|
||||
|
||||
@@ -396,6 +396,7 @@ type Tensor interface {
|
||||
|
||||
Shape() []int
|
||||
DType() DType
|
||||
Cast(ctx Context, dtype DType) Tensor
|
||||
|
||||
Bytes() []byte
|
||||
Floats() []float32
|
||||
|
||||
@@ -843,23 +843,7 @@ func (c *Context) newTensor(dtype ml.DType, shape []int) ml.Tensor {
|
||||
panic("set Input or Layer before creating tensors")
|
||||
}
|
||||
|
||||
var cdtype uint32
|
||||
switch dtype {
|
||||
case ml.DTypeF32:
|
||||
cdtype = C.GGML_TYPE_F32
|
||||
case ml.DTypeF16:
|
||||
cdtype = C.GGML_TYPE_F16
|
||||
case ml.DTypeQ80:
|
||||
cdtype = C.GGML_TYPE_Q8_0
|
||||
case ml.DTypeQ40:
|
||||
cdtype = C.GGML_TYPE_Q4_0
|
||||
case ml.DTypeI32:
|
||||
cdtype = C.GGML_TYPE_I32
|
||||
case ml.DTypeMXFP4:
|
||||
cdtype = C.GGML_TYPE_MXFP4
|
||||
default:
|
||||
panic("unsupported dtype")
|
||||
}
|
||||
cdtype := ggmlDType(dtype)
|
||||
|
||||
if len(shape) < 1 || shape[0] == 0 {
|
||||
var shape C.int64_t = 0
|
||||
@@ -1056,6 +1040,32 @@ func (t *Tensor) DType() ml.DType {
|
||||
}
|
||||
}
|
||||
|
||||
func ggmlDType(dtype ml.DType) uint32 {
|
||||
switch dtype {
|
||||
case ml.DTypeF32:
|
||||
return C.GGML_TYPE_F32
|
||||
case ml.DTypeF16:
|
||||
return C.GGML_TYPE_F16
|
||||
case ml.DTypeQ80:
|
||||
return C.GGML_TYPE_Q8_0
|
||||
case ml.DTypeQ40:
|
||||
return C.GGML_TYPE_Q4_0
|
||||
case ml.DTypeI32:
|
||||
return C.GGML_TYPE_I32
|
||||
case ml.DTypeMXFP4:
|
||||
return C.GGML_TYPE_MXFP4
|
||||
default:
|
||||
panic("unsupported dtype")
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) Cast(ctx ml.Context, dtype ml.DType) ml.Tensor {
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_cast(ctx.(*Context).ctx, t.t, ggmlDType(dtype)),
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) Neg(ctx ml.Context) ml.Tensor {
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
|
||||
@@ -69,10 +69,10 @@ func New(c fs.Config) (model.Model, error) {
|
||||
},
|
||||
}
|
||||
|
||||
slidingWindowLen := int32(c.Uint("attention.sliding_window"))
|
||||
m.Cache = kvcache.NewWrapperCache(kvcache.NewSWACache(slidingWindowLen, m.Shift), kvcache.NewCausalCache(m.Shift))
|
||||
m.Cache.SetConfig(ml.CacheConfig{})
|
||||
|
||||
m.Cache = kvcache.NewWrapperCache(
|
||||
kvcache.NewSWACache(int32(c.Uint("attention.sliding_window")), m.Shift),
|
||||
kvcache.NewCausalCache(m.Shift),
|
||||
)
|
||||
return &m, nil
|
||||
}
|
||||
|
||||
@@ -90,12 +90,6 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
|
||||
q = q.Reshape(ctx, opts.attnKeyLen, opts.numHeads, batchSize)
|
||||
q = fast.RoPE(ctx, q, positionIDs, opts.attnKeyLen, opts.ropeBase, opts.ropeScale, rope.WithTypeNeoX())
|
||||
|
||||
if opts.largeModelScaling {
|
||||
q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.hiddenSize/opts.numHeads)))
|
||||
} else {
|
||||
q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.attnKeyLen)))
|
||||
}
|
||||
|
||||
k := sa.Key.Forward(ctx, hiddenState)
|
||||
k = k.Reshape(ctx, opts.attnKeyLen, opts.numKVHeads, batchSize)
|
||||
k = fast.RoPE(ctx, k, positionIDs, opts.attnKeyLen, opts.ropeBase, opts.ropeScale, rope.WithTypeNeoX())
|
||||
@@ -103,28 +97,14 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
|
||||
v := sa.Value.Forward(ctx, hiddenState)
|
||||
v = v.Reshape(ctx, opts.attnValLen, opts.numKVHeads, batchSize)
|
||||
|
||||
cache.Put(ctx, k, v)
|
||||
k, v, mask := cache.Get(ctx)
|
||||
scale := 1.0 / math.Sqrt(float64(opts.attnKeyLen))
|
||||
if opts.largeModelScaling {
|
||||
scale = 1.0 / math.Sqrt(float64(opts.hiddenSize/opts.numHeads))
|
||||
}
|
||||
|
||||
q = q.Permute(ctx, 0, 2, 1, 3)
|
||||
k = k.Permute(ctx, 0, 2, 1, 3)
|
||||
v = v.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
|
||||
|
||||
kq := k.Mulmat(ctx, q)
|
||||
|
||||
// logit softcap
|
||||
kq = kq.Scale(ctx, 1.0/float64(opts.attnLogitSoftcap))
|
||||
kq = kq.Tanh(ctx)
|
||||
kq = kq.Scale(ctx, float64(opts.attnLogitSoftcap))
|
||||
|
||||
kq = kq.Add(ctx, mask)
|
||||
kq = kq.Softmax(ctx)
|
||||
|
||||
kqv := v.Mulmat(ctx, kq)
|
||||
kqv = kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
|
||||
kqv = kqv.Reshape(ctx, opts.attnValLen*opts.numHeads, batchSize)
|
||||
|
||||
return sa.Output.Forward(ctx, kqv)
|
||||
attn := nn.Attention(ctx, q, k, v, scale, cache)
|
||||
attn = attn.Reshape(ctx, opts.attnValLen*opts.numHeads, batchSize)
|
||||
return sa.Output.Forward(ctx, attn)
|
||||
}
|
||||
|
||||
func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
|
||||
|
||||
@@ -86,12 +86,6 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, pos
|
||||
q = sa.QueryNorm.Forward(ctx, q, opts.eps)
|
||||
q = fast.RoPE(ctx, q, positionIDs, opts.attnKeyLen, ropeBase, opts.ropeScale, rope.WithTypeNeoX())
|
||||
|
||||
if opts.largeModelScaling {
|
||||
q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.hiddenSize/opts.numHeads)))
|
||||
} else {
|
||||
q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.attnKeyLen)))
|
||||
}
|
||||
|
||||
k := sa.Key.Forward(ctx, hiddenState)
|
||||
k = k.Reshape(ctx, opts.attnKeyLen, opts.numKVHeads, batchSize)
|
||||
k = sa.KeyNorm.Forward(ctx, k, opts.eps)
|
||||
@@ -100,8 +94,12 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, pos
|
||||
v := sa.Value.Forward(ctx, hiddenState)
|
||||
v = v.Reshape(ctx, opts.attnValLen, opts.numKVHeads, batchSize)
|
||||
|
||||
scaleFactor := 1.0
|
||||
kqv := nn.Attention(ctx, q, k, v, scaleFactor, cache)
|
||||
scale := 1.0 / math.Sqrt(float64(opts.attnKeyLen))
|
||||
if opts.largeModelScaling {
|
||||
scale = 1.0 / math.Sqrt(float64(opts.hiddenSize/opts.numHeads))
|
||||
}
|
||||
|
||||
kqv := nn.Attention(ctx, q, k, v, scale, cache)
|
||||
kqv = kqv.Reshape(ctx, opts.attnValLen*opts.numHeads, batchSize)
|
||||
|
||||
return sa.Output.Forward(ctx, kqv)
|
||||
|
||||
Reference in New Issue
Block a user