Compare commits
3 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
8253ad4d2b | ||
|
|
fa7776fd24 | ||
|
|
0d38b66502 |
@@ -80,7 +80,6 @@ func (m *gptossModel) Tensors(ts []Tensor) []*ggml.Tensor {
|
||||
case "scales":
|
||||
mxfp4s[name].scales = t
|
||||
}
|
||||
|
||||
} else {
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: t.Name(),
|
||||
|
||||
@@ -54,7 +54,7 @@ func (t tensorBase) Kind() uint32 {
|
||||
case 1:
|
||||
return tensorKindFP32
|
||||
default:
|
||||
return tensorKindBF16
|
||||
return tensorKindFP16
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -93,6 +93,15 @@ type safetensor struct {
|
||||
*tensorBase
|
||||
}
|
||||
|
||||
func (st safetensor) Kind() uint32 {
|
||||
kind := st.tensorBase.Kind()
|
||||
if st.dtype == "BF16" && kind != tensorKindFP32 {
|
||||
kind = tensorKindBF16
|
||||
}
|
||||
|
||||
return kind
|
||||
}
|
||||
|
||||
func (st safetensor) Clone() Tensor {
|
||||
return &safetensor{
|
||||
fs: st.fs,
|
||||
|
||||
@@ -761,6 +761,10 @@ func (f GGML) SupportsFlashAttention() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
if f.KV().Architecture() == "gptoss" {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check head counts match and are non-zero
|
||||
headCountK := f.KV().EmbeddingHeadCountK()
|
||||
headCountV := f.KV().EmbeddingHeadCountV()
|
||||
|
||||
@@ -214,6 +214,7 @@ func (c *Causal) StartForward(ctx ml.Context, batch input.Batch, reserve bool) e
|
||||
c.curLoc, err = c.findStartLoc()
|
||||
}
|
||||
if err != nil {
|
||||
slog.Warn("unable to find a kv cache slot", "cache", c)
|
||||
return err
|
||||
}
|
||||
|
||||
|
||||
120
ml/backend/ggml/ggml/src/ggml-cuda/mmvmxfp4.cu
vendored
120
ml/backend/ggml/ggml/src/ggml-cuda/mmvmxfp4.cu
vendored
@@ -10,8 +10,8 @@ typedef union {
|
||||
|
||||
template <typename type_acc, int block_size> // TODO type_acc unused - consider bf16 support
|
||||
static __global__ void mul_mat_vec_mxfp4(
|
||||
const block_mxfp4 * __restrict__ x_base, const float * __restrict__ y, const int32_t * __restrict__ ids, float * __restrict__ dst,
|
||||
const int64_t ncols, const int64_t nchannels_y, const int64_t stride_row,
|
||||
const block_mxfp4 * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, float * __restrict__ dst,
|
||||
const int64_t ncols2, const int64_t nchannels_y, const int64_t stride_row,
|
||||
const int64_t channel_ratio, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst,
|
||||
const int64_t sample_ratio, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst) {
|
||||
const int64_t row = blockIdx.x;
|
||||
@@ -23,20 +23,16 @@ static __global__ void mul_mat_vec_mxfp4(
|
||||
const int64_t sample_y = sample_dst;
|
||||
const int tid = threadIdx.x;
|
||||
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
||||
const int64_t ncols8 = ncols / 8;
|
||||
|
||||
const uint16_t dst_bias = 15;
|
||||
const uint16_t dst_0p5 = 0x3800;
|
||||
const uint16_t dst_m_bits = 10;
|
||||
|
||||
// x_base is offset by blocks of 32 elements
|
||||
x_base += sample_x *stride_sample_x + channel_x *stride_channel_x + row*stride_row;
|
||||
// y is offset by elements
|
||||
y += sample_y *stride_sample_y + channel_y *stride_channel_y;
|
||||
// dst is offset by elements
|
||||
x += sample_x *stride_sample_x + channel_x *stride_channel_x + row*stride_row;
|
||||
y += sample_y *stride_sample_y + channel_y *stride_channel_y;
|
||||
dst += sample_dst*stride_sample_dst + channel_dst*stride_channel_dst;
|
||||
|
||||
const float4 * y4 = (const float4 *) y;
|
||||
const float2 * y2 = (const float2 *) y;
|
||||
|
||||
extern __shared__ char data_mmv[]; // allocated in GPU shared memory: warp_size*sizeof(float)
|
||||
float * buf_iw = (float *) data_mmv;
|
||||
@@ -50,72 +46,50 @@ static __global__ void mul_mat_vec_mxfp4(
|
||||
|
||||
float sumf = 0.0f;
|
||||
|
||||
// each i8 index proceses 8 items at a time
|
||||
for (int64_t i8 = tid; i8 < ncols8; i8 += block_size) {
|
||||
// As i8 indexes past a block, we have to offset further
|
||||
int offset0 = i8 / (MXFP4/8);
|
||||
int xi = (i8 % (MXFP4/8)) * 4; // jump 4 bytes for each 8 elements
|
||||
const block_mxfp4 *x = x_base+offset0;
|
||||
for (int64_t col2 = tid; col2 < ncols2; col2 += block_size) {
|
||||
int offset0 = col2 / (MXFP4/2);
|
||||
int i = col2 % (MXFP4/2);
|
||||
const block_mxfp4 *x2 = x+offset0;
|
||||
|
||||
union {
|
||||
uint32_t as_bits;
|
||||
float as_value;
|
||||
} scale;
|
||||
scale.as_bits = (((uint32_t)x->d) << 23);
|
||||
scale.as_bits = (((uint32_t)x2->d) << 23);
|
||||
uint16_t em0 = x2->qs[i] & 0x07;
|
||||
uint16_t em1 = x2->qs[i] & 0x70;
|
||||
// float16 values
|
||||
f16_t x0;
|
||||
f16_t x1;
|
||||
x0.u16 = (em0 << (dst_m_bits - 1)) | ((x2->qs[i] & 0x08) << 12);
|
||||
x1.u16 = (em1 << (dst_m_bits - 5)) | ((x2->qs[i] & 0x80) << 8);
|
||||
|
||||
// Three cases:
|
||||
// x is normal and non-zero: Correct bias
|
||||
if ((em0 & 0x06) != 0) {
|
||||
x0.u16 = x0.u16 + ((dst_bias - 1) << dst_m_bits);
|
||||
}
|
||||
if ((em1 & 0x60) != 0) {
|
||||
x1.u16 = x1.u16 + ((dst_bias - 1) << dst_m_bits);
|
||||
}
|
||||
// x is subnormal (x == 0bs001 where s is the sign): Map to +-0.5 in the dst type
|
||||
if (em0 == 0x01) {
|
||||
x0.u16 = dst_0p5 | (x0.u16 & 0x8000);
|
||||
}
|
||||
if (em1 == 0x10) {
|
||||
x1.u16 = dst_0p5 | (x1.u16 & 0x8000);
|
||||
}
|
||||
// x is zero, do nothing
|
||||
|
||||
if (isnan(scale.as_value)) {
|
||||
sumf = scale.as_value;
|
||||
break;
|
||||
}
|
||||
const uint8_t qs[4] = {
|
||||
(uint8_t)(x->qs[xi]),
|
||||
(uint8_t)(x->qs[xi+1]),
|
||||
(uint8_t)(x->qs[xi+2]),
|
||||
(uint8_t)(x->qs[xi+3])
|
||||
};
|
||||
|
||||
const uint8_t el[8] = {
|
||||
(uint8_t)(qs[0] & 0xf),
|
||||
(uint8_t)((qs[0] & 0xf0) >> 4),
|
||||
(uint8_t)(qs[1] & 0xf),
|
||||
(uint8_t)((qs[1] & 0xf0) >> 4),
|
||||
(uint8_t)(qs[2] & 0xf),
|
||||
(uint8_t)((qs[2] & 0xf0) >> 4),
|
||||
(uint8_t)(qs[3] & 0xf),
|
||||
(uint8_t)((qs[3] & 0xf0) >> 4)
|
||||
};
|
||||
|
||||
uint16_t em[8];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 8; i++) { em[i] = (uint16_t)(el[i] & 0x07); }
|
||||
|
||||
// float16 values
|
||||
f16_t x4u[8];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 8; i++) { x4u[i].u16 = (em[i] << (dst_m_bits - 1)) | ((el[i] & 0x08) << 12); }
|
||||
|
||||
// Three cases:
|
||||
// x is normal and non-zero: Correct bias
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 8; i++) { if ((em[i] & 0x06) != 0) { x4u[i].u16 = x4u[i].u16 + ((dst_bias - 1) << dst_m_bits); } }
|
||||
|
||||
// x is subnormal (x == 0bs001 where s is the sign): Map to +-0.5 in the dst type
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 8; i++) { if (em[i] == 0x01) { x4u[i].u16 = dst_0p5 | (x4u[i].u16 & 0x8000); } }
|
||||
// x is zero, do nothing
|
||||
|
||||
const float scalef = scale.as_value;
|
||||
const float4 tmpx0 = {x4u[0].f16, x4u[1].f16, x4u[2].f16, x4u[3].f16};
|
||||
const float4 tmpx1 = {x4u[4].f16, x4u[5].f16, x4u[6].f16, x4u[7].f16};
|
||||
const float4 tmpy0 = y4[i8*2];
|
||||
const float4 tmpy1 = y4[i8*2+1];
|
||||
sumf += tmpx0.x * tmpy0.x * scalef;
|
||||
sumf += tmpx0.y * tmpy0.y * scalef;
|
||||
sumf += tmpx0.z * tmpy0.z * scalef;
|
||||
sumf += tmpx0.w * tmpy0.w * scalef;
|
||||
sumf += tmpx1.x * tmpy1.x * scalef;
|
||||
sumf += tmpx1.y * tmpy1.y * scalef;
|
||||
sumf += tmpx1.z * tmpy1.z * scalef;
|
||||
sumf += tmpx1.w * tmpy1.w * scalef;
|
||||
const float2 tmpx = {x0.f16, x1.f16};
|
||||
const float2 tmpy = y2[col2];
|
||||
sumf += tmpx.x*tmpy.x*scale.as_value;
|
||||
sumf += tmpx.y*tmpy.y*scale.as_value;
|
||||
}
|
||||
|
||||
sumf = warp_reduce_sum<warp_size>(sumf);
|
||||
@@ -177,42 +151,42 @@ static void launch_mul_mat_vec_cuda_mxfp4(
|
||||
switch (block_size_best) {
|
||||
case 32: {
|
||||
mul_mat_vec_mxfp4<type_acc, 32><<<block_nums, block_dims, smem, stream>>>
|
||||
(x, y, ids, dst, ncols, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
|
||||
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
|
||||
stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||
} break;
|
||||
case 64: {
|
||||
mul_mat_vec_mxfp4<type_acc, 64><<<block_nums, block_dims, smem, stream>>>
|
||||
(x, y, ids, dst, ncols, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
|
||||
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
|
||||
stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||
} break;
|
||||
case 96: {
|
||||
mul_mat_vec_mxfp4<type_acc, 96><<<block_nums, block_dims, smem, stream>>>
|
||||
(x, y, ids, dst, ncols, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
|
||||
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
|
||||
stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||
} break;
|
||||
case 128: {
|
||||
mul_mat_vec_mxfp4<type_acc, 128><<<block_nums, block_dims, smem, stream>>>
|
||||
(x, y, ids, dst, ncols, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
|
||||
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
|
||||
stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||
} break;
|
||||
case 160: {
|
||||
mul_mat_vec_mxfp4<type_acc, 160><<<block_nums, block_dims, smem, stream>>>
|
||||
(x, y, ids, dst, ncols, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
|
||||
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
|
||||
stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||
} break;
|
||||
case 192: {
|
||||
mul_mat_vec_mxfp4<type_acc, 192><<<block_nums, block_dims, smem, stream>>>
|
||||
(x, y, ids, dst, ncols, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
|
||||
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
|
||||
stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||
} break;
|
||||
case 224: {
|
||||
mul_mat_vec_mxfp4<type_acc, 224><<<block_nums, block_dims, smem, stream>>>
|
||||
(x, y, ids, dst, ncols, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
|
||||
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
|
||||
stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||
} break;
|
||||
case 256: {
|
||||
mul_mat_vec_mxfp4<type_acc, 256><<<block_nums, block_dims, smem, stream>>>
|
||||
(x, y, ids, dst, ncols, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
|
||||
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
|
||||
stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||
} break;
|
||||
default: {
|
||||
|
||||
@@ -22,27 +22,25 @@ import (
|
||||
|
||||
// MXFP4 reference: https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
|
||||
|
||||
var (
|
||||
// E2M1 values
|
||||
mxfp4_vals = []float32{
|
||||
0.0, // 0 00 0 = 0x0
|
||||
0.5, // 0 00 1 = 0x1
|
||||
1.0, // 0 01 0 = 0x2
|
||||
1.5, // 0 01 1 = 0x3
|
||||
2.0, // 0 10 0 = 0x4
|
||||
3.0, // 0 10 1 = 0x5
|
||||
4.0, // 0 11 0 = 0x6
|
||||
6.0, // 0 11 1 = 0x7
|
||||
0.0, // 1 00 0 = 0x8
|
||||
-0.5, // 1 00 1 = 0x9
|
||||
-1.0, // 1 01 0 = 0xa
|
||||
-1.5, // 1 01 1 = 0xb
|
||||
-2.0, // 1 10 0 = 0xc
|
||||
-3.0, // 1 10 1 = 0xd
|
||||
-4.0, // 1 11 0 = 0xe
|
||||
-6.0, // 1 11 1 = 0xf
|
||||
}
|
||||
)
|
||||
// E2M1 values
|
||||
var mxfp4_vals = []float32{
|
||||
0.0, // 0 00 0 = 0x0
|
||||
0.5, // 0 00 1 = 0x1
|
||||
1.0, // 0 01 0 = 0x2
|
||||
1.5, // 0 01 1 = 0x3
|
||||
2.0, // 0 10 0 = 0x4
|
||||
3.0, // 0 10 1 = 0x5
|
||||
4.0, // 0 11 0 = 0x6
|
||||
6.0, // 0 11 1 = 0x7
|
||||
0.0, // 1 00 0 = 0x8
|
||||
-0.5, // 1 00 1 = 0x9
|
||||
-1.0, // 1 01 0 = 0xa
|
||||
-1.5, // 1 01 1 = 0xb
|
||||
-2.0, // 1 10 0 = 0xc
|
||||
-3.0, // 1 10 1 = 0xd
|
||||
-4.0, // 1 11 0 = 0xe
|
||||
-6.0, // 1 11 1 = 0xf
|
||||
}
|
||||
|
||||
func TestMXFP4Ops(t *testing.T) {
|
||||
b := setup(t)
|
||||
@@ -412,11 +410,9 @@ func TestMXFP4Ops(t *testing.T) {
|
||||
}
|
||||
// t.Logf("MulmatID results matched:\n%s", d4)
|
||||
})
|
||||
|
||||
})
|
||||
|
||||
t.Run("mm", func(t *testing.T) {
|
||||
|
||||
t.Run("example", func(t *testing.T) {
|
||||
r := rand.New(rand.NewSource(0))
|
||||
ctx := initContextOrSkip(t, b, useGPU)
|
||||
@@ -735,7 +731,6 @@ func TestMXFP4Simple(t *testing.T) {
|
||||
}
|
||||
t.Logf("result (mxfp4): \n%s", d3)
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
func TestMXFP4Conversion(t *testing.T) {
|
||||
@@ -744,7 +739,7 @@ func TestMXFP4Conversion(t *testing.T) {
|
||||
|
||||
data := [32 * 4]float32{}
|
||||
for i := range data {
|
||||
data[i] = mxfp4_vals[r.Int()%len(mxfp4_vals)] * 0.1
|
||||
data[i] = mxfp4_vals[r.Int()%len(mxfp4_vals)]
|
||||
}
|
||||
mxData := Quantize(fsggml.TensorTypeMXFP4, data[:], []uint64{uint64(len(data))})
|
||||
newData := ConvertToF32(mxData, uint32(fsggml.TensorTypeMXFP4), uint64(len(data)))
|
||||
|
||||
Reference in New Issue
Block a user