Compare commits

..

3 Commits

Author SHA1 Message Date
Jesse Gross
8253ad4d2b ggml: Prevent kv cache quanitization on gpt-oss
KV cache quantization has a dependency on the flash attention kernel.
We currently cannot use flash attention with gpt-oss as it requires
additional operations.

The model definition does not call flash attention, so it works
regardless of the setting but the cache will pick up the
quantization type. This updates the flash attention setting earlier
in the loading flow so that all downstream settings are also set correctly.

Fixes: #11671
2025-08-05 13:04:03 -07:00
Michael Yang
fa7776fd24 gpt-oss (#11672)
* bf16

* tests

* gpt-oss

* enable gptoss for engine

* rough estimate

* convert to mxfp4

* handle safetensors U8

* clamp glu/linear

* update tokenizer

* MXFP4 support

This implements the Open Compute Microscaling (MX) FP4 format
as a tensor type with backend implementations focusing
on mulmat and mulmatid on CPU, CUDA, and Metal.

* Unit tests for MXFP4 support

This exercises various operations and shapes on both CPU and GPU (if detected
on the system)

* cuda graph

* unit test adjustments

* cuda: optimize memory access

Read 4 bytes at a time (8 elements) when performing mul_mat_vec_mxfp4

* mac: fix crash on old macos versions

cblas_sgemm is only supported on v13.3 and up, however bf16 is
only supported on v14+ so we were falling back to ggml-blas and
crashing on bf16 tensors.  Checking for the function being null
seems to be the simplest way to condittionally avoid registering the
backend.

* server: Minimum context length for gptoss

This model requires a minimum context length of 8192 to function
effectively. Users can set higher values through all normal mechanisms
but lower values will be silently reset.

* ggml: Multiply by numParallel for gptoss sliding window

When computing the graph size estimate, the context size is already
multiplied by numParallel so estimates reflect that. However, since
sliding window models use a smaller, fixed context size, they need
to manually take numParallel into account.

* gpt-oss integration

includes harmony parser and thinking levels, etc.

* fix sync

* fix tests

* fix lint

---------

Co-authored-by: Daniel Hiltgen <daniel@ollama.com>
Co-authored-by: Jesse Gross <jesse@ollama.com>
Co-authored-by: Devon Rifkin <drifkin@drifkin.net>
2025-08-05 12:21:16 -07:00
Jesse Gross
0d38b66502 kvcache: Log contents of cache when unable to find a slot
There is a bug when using sliding window attention where we run
out of KV cache slots. This is likely due to not correctly removing
all of the entries as they slide out of range. This adds additional
logging when this occurs to track down the source.

Bug #10127
2025-08-04 16:59:29 -07:00
7 changed files with 82 additions and 100 deletions

View File

@@ -80,7 +80,6 @@ func (m *gptossModel) Tensors(ts []Tensor) []*ggml.Tensor {
case "scales":
mxfp4s[name].scales = t
}
} else {
out = append(out, &ggml.Tensor{
Name: t.Name(),

View File

@@ -54,7 +54,7 @@ func (t tensorBase) Kind() uint32 {
case 1:
return tensorKindFP32
default:
return tensorKindBF16
return tensorKindFP16
}
}

View File

@@ -93,6 +93,15 @@ type safetensor struct {
*tensorBase
}
func (st safetensor) Kind() uint32 {
kind := st.tensorBase.Kind()
if st.dtype == "BF16" && kind != tensorKindFP32 {
kind = tensorKindBF16
}
return kind
}
func (st safetensor) Clone() Tensor {
return &safetensor{
fs: st.fs,

View File

@@ -761,6 +761,10 @@ func (f GGML) SupportsFlashAttention() bool {
return false
}
if f.KV().Architecture() == "gptoss" {
return false
}
// Check head counts match and are non-zero
headCountK := f.KV().EmbeddingHeadCountK()
headCountV := f.KV().EmbeddingHeadCountV()

View File

@@ -214,6 +214,7 @@ func (c *Causal) StartForward(ctx ml.Context, batch input.Batch, reserve bool) e
c.curLoc, err = c.findStartLoc()
}
if err != nil {
slog.Warn("unable to find a kv cache slot", "cache", c)
return err
}

View File

@@ -10,8 +10,8 @@ typedef union {
template <typename type_acc, int block_size> // TODO type_acc unused - consider bf16 support
static __global__ void mul_mat_vec_mxfp4(
const block_mxfp4 * __restrict__ x_base, const float * __restrict__ y, const int32_t * __restrict__ ids, float * __restrict__ dst,
const int64_t ncols, const int64_t nchannels_y, const int64_t stride_row,
const block_mxfp4 * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, float * __restrict__ dst,
const int64_t ncols2, const int64_t nchannels_y, const int64_t stride_row,
const int64_t channel_ratio, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst,
const int64_t sample_ratio, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst) {
const int64_t row = blockIdx.x;
@@ -23,20 +23,16 @@ static __global__ void mul_mat_vec_mxfp4(
const int64_t sample_y = sample_dst;
const int tid = threadIdx.x;
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
const int64_t ncols8 = ncols / 8;
const uint16_t dst_bias = 15;
const uint16_t dst_0p5 = 0x3800;
const uint16_t dst_m_bits = 10;
// x_base is offset by blocks of 32 elements
x_base += sample_x *stride_sample_x + channel_x *stride_channel_x + row*stride_row;
// y is offset by elements
y += sample_y *stride_sample_y + channel_y *stride_channel_y;
// dst is offset by elements
x += sample_x *stride_sample_x + channel_x *stride_channel_x + row*stride_row;
y += sample_y *stride_sample_y + channel_y *stride_channel_y;
dst += sample_dst*stride_sample_dst + channel_dst*stride_channel_dst;
const float4 * y4 = (const float4 *) y;
const float2 * y2 = (const float2 *) y;
extern __shared__ char data_mmv[]; // allocated in GPU shared memory: warp_size*sizeof(float)
float * buf_iw = (float *) data_mmv;
@@ -50,72 +46,50 @@ static __global__ void mul_mat_vec_mxfp4(
float sumf = 0.0f;
// each i8 index proceses 8 items at a time
for (int64_t i8 = tid; i8 < ncols8; i8 += block_size) {
// As i8 indexes past a block, we have to offset further
int offset0 = i8 / (MXFP4/8);
int xi = (i8 % (MXFP4/8)) * 4; // jump 4 bytes for each 8 elements
const block_mxfp4 *x = x_base+offset0;
for (int64_t col2 = tid; col2 < ncols2; col2 += block_size) {
int offset0 = col2 / (MXFP4/2);
int i = col2 % (MXFP4/2);
const block_mxfp4 *x2 = x+offset0;
union {
uint32_t as_bits;
float as_value;
} scale;
scale.as_bits = (((uint32_t)x->d) << 23);
scale.as_bits = (((uint32_t)x2->d) << 23);
uint16_t em0 = x2->qs[i] & 0x07;
uint16_t em1 = x2->qs[i] & 0x70;
// float16 values
f16_t x0;
f16_t x1;
x0.u16 = (em0 << (dst_m_bits - 1)) | ((x2->qs[i] & 0x08) << 12);
x1.u16 = (em1 << (dst_m_bits - 5)) | ((x2->qs[i] & 0x80) << 8);
// Three cases:
// x is normal and non-zero: Correct bias
if ((em0 & 0x06) != 0) {
x0.u16 = x0.u16 + ((dst_bias - 1) << dst_m_bits);
}
if ((em1 & 0x60) != 0) {
x1.u16 = x1.u16 + ((dst_bias - 1) << dst_m_bits);
}
// x is subnormal (x == 0bs001 where s is the sign): Map to +-0.5 in the dst type
if (em0 == 0x01) {
x0.u16 = dst_0p5 | (x0.u16 & 0x8000);
}
if (em1 == 0x10) {
x1.u16 = dst_0p5 | (x1.u16 & 0x8000);
}
// x is zero, do nothing
if (isnan(scale.as_value)) {
sumf = scale.as_value;
break;
}
const uint8_t qs[4] = {
(uint8_t)(x->qs[xi]),
(uint8_t)(x->qs[xi+1]),
(uint8_t)(x->qs[xi+2]),
(uint8_t)(x->qs[xi+3])
};
const uint8_t el[8] = {
(uint8_t)(qs[0] & 0xf),
(uint8_t)((qs[0] & 0xf0) >> 4),
(uint8_t)(qs[1] & 0xf),
(uint8_t)((qs[1] & 0xf0) >> 4),
(uint8_t)(qs[2] & 0xf),
(uint8_t)((qs[2] & 0xf0) >> 4),
(uint8_t)(qs[3] & 0xf),
(uint8_t)((qs[3] & 0xf0) >> 4)
};
uint16_t em[8];
#pragma unroll
for (int i = 0; i < 8; i++) { em[i] = (uint16_t)(el[i] & 0x07); }
// float16 values
f16_t x4u[8];
#pragma unroll
for (int i = 0; i < 8; i++) { x4u[i].u16 = (em[i] << (dst_m_bits - 1)) | ((el[i] & 0x08) << 12); }
// Three cases:
// x is normal and non-zero: Correct bias
#pragma unroll
for (int i = 0; i < 8; i++) { if ((em[i] & 0x06) != 0) { x4u[i].u16 = x4u[i].u16 + ((dst_bias - 1) << dst_m_bits); } }
// x is subnormal (x == 0bs001 where s is the sign): Map to +-0.5 in the dst type
#pragma unroll
for (int i = 0; i < 8; i++) { if (em[i] == 0x01) { x4u[i].u16 = dst_0p5 | (x4u[i].u16 & 0x8000); } }
// x is zero, do nothing
const float scalef = scale.as_value;
const float4 tmpx0 = {x4u[0].f16, x4u[1].f16, x4u[2].f16, x4u[3].f16};
const float4 tmpx1 = {x4u[4].f16, x4u[5].f16, x4u[6].f16, x4u[7].f16};
const float4 tmpy0 = y4[i8*2];
const float4 tmpy1 = y4[i8*2+1];
sumf += tmpx0.x * tmpy0.x * scalef;
sumf += tmpx0.y * tmpy0.y * scalef;
sumf += tmpx0.z * tmpy0.z * scalef;
sumf += tmpx0.w * tmpy0.w * scalef;
sumf += tmpx1.x * tmpy1.x * scalef;
sumf += tmpx1.y * tmpy1.y * scalef;
sumf += tmpx1.z * tmpy1.z * scalef;
sumf += tmpx1.w * tmpy1.w * scalef;
const float2 tmpx = {x0.f16, x1.f16};
const float2 tmpy = y2[col2];
sumf += tmpx.x*tmpy.x*scale.as_value;
sumf += tmpx.y*tmpy.y*scale.as_value;
}
sumf = warp_reduce_sum<warp_size>(sumf);
@@ -177,42 +151,42 @@ static void launch_mul_mat_vec_cuda_mxfp4(
switch (block_size_best) {
case 32: {
mul_mat_vec_mxfp4<type_acc, 32><<<block_nums, block_dims, smem, stream>>>
(x, y, ids, dst, ncols, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
} break;
case 64: {
mul_mat_vec_mxfp4<type_acc, 64><<<block_nums, block_dims, smem, stream>>>
(x, y, ids, dst, ncols, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
} break;
case 96: {
mul_mat_vec_mxfp4<type_acc, 96><<<block_nums, block_dims, smem, stream>>>
(x, y, ids, dst, ncols, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
} break;
case 128: {
mul_mat_vec_mxfp4<type_acc, 128><<<block_nums, block_dims, smem, stream>>>
(x, y, ids, dst, ncols, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
} break;
case 160: {
mul_mat_vec_mxfp4<type_acc, 160><<<block_nums, block_dims, smem, stream>>>
(x, y, ids, dst, ncols, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
} break;
case 192: {
mul_mat_vec_mxfp4<type_acc, 192><<<block_nums, block_dims, smem, stream>>>
(x, y, ids, dst, ncols, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
} break;
case 224: {
mul_mat_vec_mxfp4<type_acc, 224><<<block_nums, block_dims, smem, stream>>>
(x, y, ids, dst, ncols, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
} break;
case 256: {
mul_mat_vec_mxfp4<type_acc, 256><<<block_nums, block_dims, smem, stream>>>
(x, y, ids, dst, ncols, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
} break;
default: {

View File

@@ -22,27 +22,25 @@ import (
// MXFP4 reference: https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
var (
// E2M1 values
mxfp4_vals = []float32{
0.0, // 0 00 0 = 0x0
0.5, // 0 00 1 = 0x1
1.0, // 0 01 0 = 0x2
1.5, // 0 01 1 = 0x3
2.0, // 0 10 0 = 0x4
3.0, // 0 10 1 = 0x5
4.0, // 0 11 0 = 0x6
6.0, // 0 11 1 = 0x7
0.0, // 1 00 0 = 0x8
-0.5, // 1 00 1 = 0x9
-1.0, // 1 01 0 = 0xa
-1.5, // 1 01 1 = 0xb
-2.0, // 1 10 0 = 0xc
-3.0, // 1 10 1 = 0xd
-4.0, // 1 11 0 = 0xe
-6.0, // 1 11 1 = 0xf
}
)
// E2M1 values
var mxfp4_vals = []float32{
0.0, // 0 00 0 = 0x0
0.5, // 0 00 1 = 0x1
1.0, // 0 01 0 = 0x2
1.5, // 0 01 1 = 0x3
2.0, // 0 10 0 = 0x4
3.0, // 0 10 1 = 0x5
4.0, // 0 11 0 = 0x6
6.0, // 0 11 1 = 0x7
0.0, // 1 00 0 = 0x8
-0.5, // 1 00 1 = 0x9
-1.0, // 1 01 0 = 0xa
-1.5, // 1 01 1 = 0xb
-2.0, // 1 10 0 = 0xc
-3.0, // 1 10 1 = 0xd
-4.0, // 1 11 0 = 0xe
-6.0, // 1 11 1 = 0xf
}
func TestMXFP4Ops(t *testing.T) {
b := setup(t)
@@ -412,11 +410,9 @@ func TestMXFP4Ops(t *testing.T) {
}
// t.Logf("MulmatID results matched:\n%s", d4)
})
})
t.Run("mm", func(t *testing.T) {
t.Run("example", func(t *testing.T) {
r := rand.New(rand.NewSource(0))
ctx := initContextOrSkip(t, b, useGPU)
@@ -735,7 +731,6 @@ func TestMXFP4Simple(t *testing.T) {
}
t.Logf("result (mxfp4): \n%s", d3)
})
}
func TestMXFP4Conversion(t *testing.T) {
@@ -744,7 +739,7 @@ func TestMXFP4Conversion(t *testing.T) {
data := [32 * 4]float32{}
for i := range data {
data[i] = mxfp4_vals[r.Int()%len(mxfp4_vals)] * 0.1
data[i] = mxfp4_vals[r.Int()%len(mxfp4_vals)]
}
mxData := Quantize(fsggml.TensorTypeMXFP4, data[:], []uint64{uint64(len(data))})
newData := ConvertToF32(mxData, uint32(fsggml.TensorTypeMXFP4), uint64(len(data)))