sync llama.cpp vulkan code
This commit is contained in:
parent 163f62fcb6
commit 93d7126ce5
File diff suppressed because it is too large
@@ -1,20 +1,34 @@
#version 450

#extension GL_EXT_shader_16bit_storage : require
#if ADD_RMS
#extension GL_KHR_shader_subgroup_arithmetic : enable
#extension GL_KHR_shader_subgroup_basic : enable
#endif

#include "types.comp"
#include "generic_binary_head.comp"

const uint num_threads = 256;

layout (binding = 3, std430) buffer PartialBuf {float partial_sums[];};

layout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in;

#if ADD_RMS
// XXX TODO this could be sized based on number of subgroups, but that's not considered a constant
shared FLOAT_TYPE sumsh[num_threads];
#endif

void main() {
    uint idx = get_idx();
    uint orig_idx = idx;

    // num_threads * num_iter must equal 512, to match the wg_denoms and get_idx calculation
    const uint num_iter = 2;

    FLOAT_TYPE sum_sq = 0;

    [[unroll]] for (uint i = 0; i < num_iter; ++i) {
        if (idx >= p.ne) {
            continue;
@@ -22,8 +36,34 @@ void main() {
        uint i00, i01, i02, i03;
        get_indices(idx, i00, i01, i02, i03);

        data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)]) + FLOAT_TYPE(data_b[get_boffset() + src1_idx(i00, i01, i02, i03)]));
        FLOAT_TYPE sum = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)]) + FLOAT_TYPE(data_b[get_boffset() + src1_idx(i00, i01, i02, i03)]);
        sum_sq += sum*sum;

        data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(sum);

        idx += num_threads;
    }

#if ADD_RMS
    if (p.param3 != 0) {
        // reduce the sum within each subgroup, then across subgroups
        const uint NumSubgroups = num_threads / gl_SubgroupSize;
        sum_sq = subgroupAdd(sum_sq);
        if (gl_SubgroupInvocationID == 0) {
            sumsh[gl_SubgroupID] = sum_sq;
        }
        barrier();
        [[unroll]] for (uint s = NumSubgroups / 2; s > 0; s >>= 1) {
            if (gl_SubgroupID < s && gl_SubgroupInvocationID == 0) {
                sum_sq += sumsh[gl_SubgroupID + s];
                sumsh[gl_SubgroupID] = sum_sq;
            }
            barrier();
        }

        if (gl_SubgroupID == 0 && gl_SubgroupInvocationID == 0) {
            partial_sums[orig_idx / (num_iter * num_threads)] = sum_sq;
        }
    }
#endif
}

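Note: the ADD_RMS path above fuses an element-wise add with the first half of an RMS norm; each workgroup writes one partial sum of squares into partial_sums, and a follow-up pass has to reduce those partials into the RMS scale. As a rough illustration only (the shader, its bindings, and push constants here are hypothetical and not part of this diff), such a consumer could look like:

    #version 450
    #extension GL_EXT_control_flow_attributes : enable
    layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
    layout(binding = 0) readonly buffer P {float partial_sums[];};
    layout(binding = 1) writeonly buffer R {float inv_rms;};
    // hypothetical push constants: partial count, element count, epsilon
    layout(push_constant) uniform PC { uint n_partials; uint ne; float eps; } pc;
    shared float sh[256];
    void main() {
        // stride over the per-workgroup partials
        float s = 0.0;
        for (uint i = gl_LocalInvocationID.x; i < pc.n_partials; i += 256) {
            s += partial_sums[i];
        }
        sh[gl_LocalInvocationID.x] = s;
        barrier();
        // shared-memory tree reduction, mirroring the pattern in the diff above
        [[unroll]] for (uint m = 128; m > 0; m >>= 1) {
            if (gl_LocalInvocationID.x < m) {
                sh[gl_LocalInvocationID.x] += sh[gl_LocalInvocationID.x + m];
            }
            barrier();
        }
        if (gl_LocalInvocationID.x == 0) {
            inv_rms = inversesqrt(sh[0] / float(pc.ne) + pc.eps);
        }
    }
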
@@ -5,6 +5,8 @@
#extension GL_EXT_control_flow_attributes : enable

#define FLT_MAX 3.402823466e+38F

layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;

layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
@@ -19,19 +21,26 @@ void main() {
    const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
    const uint col = gl_LocalInvocationID.x;

    if (col >= p.KX) {
    if (row >= p.KY) {
        return;
    }
    A_TYPE amax = data_a[row*p.KX + col];
    tmp[col] = col;

    A_TYPE amax = -FLT_MAX;
    uint acol = col;

    if (col < p.KX) {
        amax = data_a[row*p.KX + col];
    }

    for (uint i = col + BLOCK_SIZE; i < p.KX; i += BLOCK_SIZE) {
        A_TYPE val = data_a[row*p.KX + i];
        if (val > amax) {
            amax = val;
            tmp[col] = i;
            acol = i;
        }
    }

    tmp[col] = acol;
    tmpmax[col] = amax;

    barrier();

@@ -1,22 +1,24 @@
#version 450
#extension GL_EXT_control_flow_attributes : enable

#include "types.comp"

#define BLOCK_SIZE 1024
layout(constant_id = 0) const int BLOCK_SIZE = 1024;
layout(constant_id = 1) const int BLOCK_SIZE_LOG2 = 10;
#define ASC 0

layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;

layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
layout (binding = 1) buffer D {int data_d[];};

layout (push_constant) uniform parameter {
    uint ncols;
    uint ncols_pad;
    uint order;
} p;

shared int dst_row[BLOCK_SIZE];
shared A_TYPE a_sh[BLOCK_SIZE];

void swap(uint idx0, uint idx1) {
    int tmp = dst_row[idx0];
@@ -24,7 +26,7 @@ void swap(uint idx0, uint idx1) {
    dst_row[idx1] = tmp;
}

void main() {
void argsort(bool needs_bounds_check) {
    // bitonic sort
    const int col = int(gl_LocalInvocationID.x);
    const uint row = gl_WorkGroupID.y;
@@ -32,38 +34,46 @@ void main() {
    const uint row_offset = row * p.ncols;

    // initialize indices
    if (col < p.ncols_pad) {
        dst_row[col] = col;
    }
    dst_row[col] = col;
    a_sh[col] = data_a[row_offset + col];
    barrier();

    for (uint k = 2; k <= p.ncols_pad; k *= 2) {
        for (uint j = k / 2; j > 0; j /= 2) {
            const uint ixj = col ^ j;
            if (col < p.ncols_pad && ixj > col) {
                if ((col & k) == 0) {
                    if (dst_row[col] >= p.ncols ||
                        (dst_row[ixj] < p.ncols && (p.order == ASC ?
                        data_a[row_offset + dst_row[col]] > data_a[row_offset + dst_row[ixj]] :
                        data_a[row_offset + dst_row[col]] < data_a[row_offset + dst_row[ixj]]))
                    ) {
                        swap(col, ixj);
                    }
                } else {
                    if (dst_row[ixj] >= p.ncols ||
                        (dst_row[col] < p.ncols && (p.order == ASC ?
                        data_a[row_offset + dst_row[col]] < data_a[row_offset + dst_row[ixj]] :
                        data_a[row_offset + dst_row[col]] > data_a[row_offset + dst_row[ixj]]))
                    ) {
                        swap(col, ixj);
                    }
                }
    uint num_outer_loop_iters = BLOCK_SIZE_LOG2;
    [[unroll]] for (uint k = 2, outer_idx = 0; outer_idx < num_outer_loop_iters; k *= 2, outer_idx++) {
        uint num_inner_loop_iters = outer_idx + 1;
        [[unroll]] for (uint j = k / 2, inner_idx = 0; inner_idx < num_inner_loop_iters; j /= 2, inner_idx++) {
            const int ixj = int(col ^ j);

            int idx_0 = (col & k) == 0 ? col : ixj;
            int idx_1 = (col & k) == 0 ? ixj : col;

            int sh_idx_0 = dst_row[idx_0];
            int sh_idx_1 = dst_row[idx_1];
            bool idx_0_oob = needs_bounds_check ? sh_idx_0 >= p.ncols : false;
            bool idx_1_oob = needs_bounds_check ? sh_idx_1 >= p.ncols : false;

            if ((idx_0_oob ||
                (!idx_1_oob && a_sh[sh_idx_0] > a_sh[sh_idx_1])) && (ixj > col)) {
                swap(idx_0, idx_1);
            }

            barrier();
        }
    }

    if (col < p.ncols) {
        data_d[row_offset + col] = dst_row[col];
        if (p.order == ASC) {
            data_d[row_offset + col] = dst_row[col];
        } else {
            data_d[row_offset + p.ncols - col - 1] = dst_row[col];
        }
    }
}

void main() {
    if (p.ncols == BLOCK_SIZE) {
        argsort(false);
    } else {
        argsort(true);
    }
}

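Note: with BLOCK_SIZE and BLOCK_SIZE_LOG2 as specialization constants, both bitonic-sort loops now have compile-time trip counts and can be fully unrolled: the outer loop runs BLOCK_SIZE_LOG2 times and outer pass number outer_idx runs outer_idx + 1 inner steps, i.e. 1 + 2 + ... + log2(N) = log2(N) * (log2(N) + 1) / 2 compare-exchange rounds in total. For the default BLOCK_SIZE = 1024 that is 10 outer passes and 55 unrolled inner steps.
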
@@ -16,7 +16,7 @@
// shape notation: [dim(N), ..., dim(0)] -- stride(dim(j)) >= stride(dim(i)) if i > j
layout(binding = 0) readonly buffer A {
    A_TYPE knl_data[];
}; // src0 - kernel: [KW, KH, Cin, Cout]
}; // src0 - kernel: [KW, KH, Cin, Cout] for conv_2d, [KW, KH, Cout, Cin] for conv_transposed_2d

layout(binding = 1) readonly buffer B {
    B_TYPE src_data[];
@@ -66,6 +66,10 @@ layout(push_constant) uniform parameter {
    uint32_t KWKHmp; uint32_t KWKHL;
    uint32_t OWmp; uint32_t OWL;
    uint32_t OWOHmp; uint32_t OWOHL;
#ifdef TRANSPOSE
    uint32_t s0mp; uint32_t s0L;
    uint32_t s1mp; uint32_t s1L;
#endif
}

p;
@@ -225,7 +229,11 @@ void main() {
            uint32_t B_ly = r_offset + Ar;
            uint32_t B_lx = Ac;
            uint32_t K_idx = B_idx_K * BS_K + B_ly; /* Global K_idx (row index of A)*/
#ifdef TRANSPOSE
            uint32_t knl_idx = min(KW_idx_a + KH_idx_a * p.nb01 + K_idx * p.nb02 + Cin_idx_a * p.nb03, K * CRS - 1);
#else
            uint32_t knl_idx = min(KW_idx_a + KH_idx_a * p.nb01 + Cin_idx_a * p.nb02 + K_idx * p.nb03, K * CRS - 1);
#endif
            float val = knl_data[knl_idx];
            if (K_idx >= K || CRS_idx_a >= CRS) {
                val = 0.0;
@@ -267,12 +275,24 @@ void main() {
                KW_idx_b = CRS_remainder - KH_idx_b * p.KW;
#endif

#ifdef TRANSPOSE
                uint32_t H_idx_x_s1 = OH_idx - KH_idx_b * p.d1 + p.p1;
                uint32_t W_idx_x_s0 = OW_idx - KW_idx_b * p.d0 + p.p0;
                uint32_t H_idx = fastdiv(H_idx_x_s1, p.s1mp, p.s1L);
                uint32_t W_idx = fastdiv(W_idx_x_s0, p.s0mp, p.s0L);
#else
                uint32_t H_idx = OH_idx * p.s1 + KH_idx_b * p.d1 - p.p1;
                uint32_t W_idx = OW_idx * p.s0 + KW_idx_b * p.d0 - p.p0;
#endif
                uint32_t src_idx =
                    min(max(W_idx + H_idx * p.nb11 + Cin_idx_b * p.nb12 + N_idx * p.nb13, 0), p.Cin * p.N * p.W * p.H - 1);
                float val = src_data[src_idx];
                if (CRS_idx_b >= CRS || NPQ_idx >= NPQ || H_idx < 0 || H_idx >= p.H || W_idx < 0 || W_idx >= p.W) {
                if (CRS_idx_b >= CRS || NPQ_idx >= NPQ
                    || H_idx >= p.H || W_idx >= p.W // Lower bound checks aren't necessary. (idx >= 0x80000000 for such case)
#ifdef TRANSPOSE
                    || (H_idx_x_s1 - H_idx * p.s1 != 0) || (W_idx_x_s0 - W_idx * p.s0 != 0)
#endif
                ) {
                    val = 0.0;
                }
                Bsh[B_ly * Bsh_stride + B_lx] = SHMEM_TYPE(val);

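Note: in the TRANSPOSE branch above, the index mapping runs in reverse. A forward convolution reads input row H_idx = OH_idx * s1 + KH_idx * d1 - p1; a transposed convolution instead has to solve OH_idx = H_idx * s1 + KH_idx * d1 - p1 for H_idx, giving H_idx = (OH_idx - KH_idx * d1 + p1) / s1. The division uses the precomputed fastdiv magic numbers (s1mp, s1L), and a kernel tap only contributes when the division is exact, which is what the added (H_idx_x_s1 - H_idx * p.s1 != 0) test rejects. Underflow of the unsigned subtraction lands at values >= 0x80000000 and is caught by the H_idx >= p.H compare, so no separate lower-bound check is needed.
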
@@ -15,8 +15,15 @@ layout (binding = 0) readonly buffer S {float data_s[];};

#if defined(SET_ROWS)
#include "generic_binary_head.comp"
layout (binding = 1) readonly buffer C {uvec2 data_i[];};
layout (binding = 1) readonly buffer C {B_TYPE data_i[];};
layout (binding = 2) writeonly buffer Q {A_TYPE data_q[];};

#if B_SIZE == 64
#define DATA_I_SWIZZLE .x
#else
#define DATA_I_SWIZZLE
#endif

#else
#include "generic_unary_head.comp"
layout (binding = 1) writeonly buffer Q {A_TYPE data_q[];};
@@ -259,7 +266,7 @@ void main() {
    uint i11 = fastmod(i02, p.ne11);
    uint i10 = i01;

    uint i1 = data_i[src1_idx(i10, i11, i12, 0) + get_boffset()].x;
    uint i1 = data_i[src1_idx(i10, i11, i12, 0) + get_boffset()] DATA_I_SWIZZLE;

    uint src0_idx = src0_idx(i00, i01, i02, i03) + get_aoffset();
    uint dst_idx = dst_idx(i00 / QUANT_K, i1, i02, i03) + get_doffset();

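Note: the swizzle macro makes the row-index load generic over the index width. When B_SIZE is 64, B_TYPE is presumably a uvec2 holding a 64-bit index (matching the old hardcoded declaration), and the load expands to

    uint i1 = data_i[...].x; // low 32 bits of the 64-bit index

while with 32-bit indices DATA_I_SWIZZLE expands to nothing and the load stays a plain data_i[...].
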
@@ -478,3 +478,139 @@ vec2 get_dm(uint ib, uint a_offset) {
    return vec2(float(data_a[a_offset + ib].d), float(data_a[a_offset + ib].m));
}
#endif

#if defined(DATA_A_Q2_K)
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
    iqs /= 2;
    const uint qsi = (iqs / 64) * 32 + (iqs % 16) * 2; // 0,2,4..30
    const uint scalesi = iqs / 8;                      // 0..15
    const uint qsshift = ((iqs % 64) / 16) * 2;        // 0,2,4,6

    const uvec2 qs = uvec2(data_a[a_offset + ib].qs[qsi], data_a[a_offset + ib].qs[qsi + 1]);
    const uint scales = data_a[a_offset + ib].scales[scalesi];
    const vec2 d = vec2(data_a[a_offset + ib].d);

    return d.x * float(scales & 0xF) * vec2((qs >> qsshift) & 3) - d.y * float(scales >> 4);
}
vec2 get_dm(uint ib, uint a_offset) {
    return vec2(1, 0);
}
#endif

#if defined(DATA_A_Q3_K)
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
    iqs /= 2;
    const uint n = iqs / 64;                  // 0,1
    const uint qsi = n * 32 + (iqs % 16) * 2; // 0,2,4..62
    const uint hmi = (iqs % 16) * 2;          // 0,2,4..30
    const uint j = (iqs % 64) / 4;            // 0..3
    const uint is = iqs / 8;                  // 0..15
    const uint halfsplit = ((iqs % 64) / 16); // 0,1,2,3
    const uint qsshift = halfsplit * 2;       // 0,2,4,6
    const uint m = 1 << (4 * n + halfsplit);  // 1,2,4,8,16,32,64,128

    const int8_t us = int8_t(((data_a[a_offset + ib].scales[is % 8] >> (4 * int(is / 8))) & 0xF)
        | (((data_a[a_offset + ib].scales[8 + (is % 4)] >> (2 * int(is / 4))) & 3) << 4));
    const float dl = float(data_a[a_offset + ib].d) * float(us - 32);

    return vec2(dl * float(int8_t((data_a[a_offset + ib].qs[qsi    ] >> qsshift) & 3) - (((data_a[a_offset + ib].hmask[hmi    ] & m) != 0) ? 0 : 4)),
                dl * float(int8_t((data_a[a_offset + ib].qs[qsi + 1] >> qsshift) & 3) - (((data_a[a_offset + ib].hmask[hmi + 1] & m) != 0) ? 0 : 4)));
}
vec2 get_dm(uint ib, uint a_offset) {
    return vec2(1, 0);
}
#endif

#if defined(DATA_A_Q4_K)
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
    iqs /= 2;
    const uint n = iqs / 32;                  // 0,1,2,3
    const uint b = (iqs % 32) / 16;           // 0,1
    const uint is = 2 * n + b;                // 0..7
    const uint qsi = n * 32 + (iqs % 16) * 2; // 0,2,4..126

    const vec2 loadd = vec2(data_a[a_offset + ib].d);

    const uint scidx0 = (is < 4) ? is : (is + 4);
    const uint scidx1 = (is < 4) ? is : (is - 4);
    const uint scidxmask1 = (is < 4) ? 0x30 : 0xC0;
    const uint scidxshift1 = (is < 4) ? 0 : 2;
    const uint mbidx0 = is + 4;
    const uint mbidx1 = (is < 4) ? is + 4 : is;
    const uint mbidxmask0 = (is < 4) ? 0xF : 0xF0;
    const uint mbidxshift0 = (is < 4) ? 0 : 4;
    const uint mbidxmask1 = (is < 4) ? 0x30 : 0xC0;
    const uint mbidxshift1 = (is < 4) ? 0 : 2;

    const uint8_t sc = uint8_t((data_a[a_offset + ib].scales[scidx0] & 0xF) | ((data_a[a_offset + ib].scales[scidx1] & scidxmask1) >> scidxshift1));
    const uint8_t mbyte = uint8_t((data_a[a_offset + ib].scales[mbidx0] & mbidxmask0) >> mbidxshift0 | ((data_a[a_offset + ib].scales[mbidx1] & mbidxmask1) >> mbidxshift1));

    const float d = loadd.x * sc;
    const float m = -loadd.y * mbyte;

    return vec2(fma(d, float((data_a[a_offset + ib].qs[qsi    ] >> (b * 4)) & 0xF), m),
                fma(d, float((data_a[a_offset + ib].qs[qsi + 1] >> (b * 4)) & 0xF), m));
}
vec2 get_dm(uint ib, uint a_offset) {
    return vec2(1, 0);
}
#endif

#if defined(DATA_A_Q5_K)
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
    iqs /= 2;
    const uint n = iqs / 32;                  // 0,1,2,3
    const uint b = (iqs % 32) / 16;           // 0,1
    const uint is = 2 * n + b;                // 0..7
    const uint qsi = n * 32 + (iqs % 16) * 2; // 0,2,4..126
    const uint qhi = (iqs % 16) * 2;          // 0,2,4..30

    const uint8_t hm = uint8_t(1 << (iqs / 16));

    const vec2 loadd = vec2(data_a[a_offset + ib].d);

    const uint scidx0 = (is < 4) ? is : (is + 4);
    const uint scidx1 = (is < 4) ? is : (is - 4);
    const uint scidxmask1 = (is < 4) ? 0x30 : 0xC0;
    const uint scidxshift1 = (is < 4) ? 0 : 2;
    const uint mbidx0 = is + 4;
    const uint mbidx1 = (is < 4) ? is + 4 : is;
    const uint mbidxmask0 = (is < 4) ? 0xF : 0xF0;
    const uint mbidxshift0 = (is < 4) ? 0 : 4;
    const uint mbidxmask1 = (is < 4) ? 0x30 : 0xC0;
    const uint mbidxshift1 = (is < 4) ? 0 : 2;

    const uint8_t sc = uint8_t((data_a[a_offset + ib].scales[scidx0] & 0xF) | ((data_a[a_offset + ib].scales[scidx1] & scidxmask1) >> scidxshift1));
    const uint8_t mbyte = uint8_t(((data_a[a_offset + ib].scales[mbidx0] & mbidxmask0) >> mbidxshift0) | ((data_a[a_offset + ib].scales[mbidx1] & mbidxmask1) >> mbidxshift1));

    const float d = loadd.x * sc;
    const float m = -loadd.y * mbyte;

    return vec2(fma(d, float((data_a[a_offset + ib].qs[qsi    ] >> (b * 4)) & 0xF) + float((data_a[a_offset + ib].qh[qhi    ] & hm) != 0 ? 16 : 0), m),
                fma(d, float((data_a[a_offset + ib].qs[qsi + 1] >> (b * 4)) & 0xF) + float((data_a[a_offset + ib].qh[qhi + 1] & hm) != 0 ? 16 : 0), m));
}
vec2 get_dm(uint ib, uint a_offset) {
    return vec2(1, 0);
}
#endif

#if defined(DATA_A_Q6_K)
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
    iqs /= 2;
    const uint n = iqs / 64;                    // 0,1
    const uint b = (iqs % 64) / 32;             // 0,1
    const uint is_b = (iqs % 16) / 8;           // 0,1
    const uint qhshift = ((iqs % 64) / 16) * 2; // 0,2,4,6
    const uint is = 8 * n + qhshift + is_b;     // 0..15
    const uint qsi = n * 64 + (iqs % 32) * 2;   // 0,2,4..126
    const uint qhi = n * 32 + (iqs % 16) * 2;   // 0,2,4..62

    const float dscale = float(data_a[a_offset + ib].d) * float(data_a[a_offset + ib].scales[is]);

    return vec2(dscale * float(int8_t(((data_a[a_offset + ib].ql[qsi    ] >> (b * 4)) & 0xF) | (((data_a[a_offset + ib].qh[qhi    ] >> qhshift) & 3) << 4)) - 32),
                dscale * float(int8_t(((data_a[a_offset + ib].ql[qsi + 1] >> (b * 4)) & 0xF) | (((data_a[a_offset + ib].qh[qhi + 1] >> qhshift) & 3) << 4)) - 32));
}
vec2 get_dm(uint ib, uint a_offset) {
    return vec2(1, 0);
}
#endif

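Note: all of the K-quant helpers above return vec2(1, 0) from get_dm because the super-block scale and min are already folded into the dequantized values. For Q4_K/Q5_K, the 12-byte scales array packs eight 6-bit scales and eight 6-bit mins, and the scidx*/mbidx* tables select where the bits of each field live. Working the masks through (my reading of the code, not stated in the diff):

    // is < 4 : scale = scales[is]   & 0x3F
    //          min   = scales[is+4] & 0x3F
    // is >= 4: scale = (scales[is+4] & 0xF) | ((scales[is-4] & 0xC0) >> 2)
    //          min   = (scales[is+4] >> 4)  | ((scales[is]   & 0xC0) >> 2)
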
@@ -29,7 +29,7 @@ void main() {
        uint qs = data_a[ib].qs[4 * ib32 + l];
        const uint8_t sign = data_a[ib].qs[QUANT_K / 8 + 4 * ib32 + l];
        qs |= (qh << (8 - 2 * l)) & 0x300;
        const uvec2 grid = iq2s_grid[qs & 511];
        const uvec2 grid = iq2s_grid[qs];
        const u8vec4 grid0 = unpack8(grid.x);
        const u8vec4 grid1 = unpack8(grid.y);
        data_b[b_idx + 8 * l + 0] = D_TYPE(db[l/2] * grid0.x * ((sign & 1) != 0 ? -1.0 : 1.0));

@@ -33,7 +33,8 @@ void main() {
    [[unroll]] for (uint l = 0; l < 4; ++l) {
        const uint sign7 = bitfieldExtract(signscale, 7 * int(l), 7);
        const uint sign8 = sign7 | (bitCount(sign7) << 7); // parity bit
        const uvec2 grid = iq2xxs_grid[data_a[ib].qs[8 * is + l]];
        const uint qs = data_a[ib].qs[8 * is + l];
        const uvec2 grid = iq2xxs_grid[qs];
        const u8vec4 grid0 = unpack8(grid.x);
        const u8vec4 grid1 = unpack8(grid.y);
        data_b[b_idx + 8 * l + 0] = D_TYPE(db * grid0.x * ((sign8 & 1) != 0 ? -1.0 : 1.0));

@@ -22,15 +22,16 @@ void main() {
    const uint b_idx = 256 * ib + 32 * is;

    const float d = float(data_a[ib].d);
    const float db = d * (1 + 2 * ((data_a[ib].scales[is] >> (4 * (is % 2))) & 0xf));
    const float db = d * (1 + 2 * ((data_a[ib].scales[is / 2] >> (4 * (is % 2))) & 0xf));

    // We must produce 32 values using 4 sign bytes, 1 qh byte, 8 qs bytes.
    uint qh = data_a[ib].qh[is];
    [[unroll]] for (uint l = 0; l < 8; ++l) {
        uint qs = data_a[ib].qs[8 * is + l];
        uint gidx = qs | ((qh << (8 - l)) & 256);
        uint8_t signs = data_a[ib].signs[8 * is + l / 2] >> (4 * (l & 1));
        u8vec4 grid = unpack8(iq3s_grid[gidx]);
        const uint iqs = 8 * is + l;
        const uint qs = data_a[ib].qs[iqs];
        const uint gidx = qs | ((qh << (8 - l)) & 256);
        const uint8_t signs = data_a[ib].signs[iqs / 2] >> (4 * (l & 1));
        const u8vec4 grid = unpack8(iq3s_grid[gidx]);
        data_b[b_idx + 4 * l + 0] = D_TYPE(db * grid.x * ((signs & 1) != 0 ? -1.0 : 1.0));
        data_b[b_idx + 4 * l + 1] = D_TYPE(db * grid.y * ((signs & 2) != 0 ? -1.0 : 1.0));
        data_b[b_idx + 4 * l + 2] = D_TYPE(db * grid.z * ((signs & 4) != 0 ? -1.0 : 1.0));

@@ -35,8 +35,10 @@ void main() {
        const uint sign7 = bitfieldExtract(signscale, 7 * int(l), 7);
        // Restore parity bit.
        const uint sign8 = sign7 | (bitCount(sign7) << 7);
        const u8vec4 grid0 = unpack8(iq3xxs_grid[data_a[ib].qs[8 * is + 2 * l]]);
        const u8vec4 grid1 = unpack8(iq3xxs_grid[data_a[ib].qs[8 * is + 2 * l + 1]]);
        const uint qs0 = data_a[ib].qs[8 * is + 2 * l];
        const uint qs1 = data_a[ib].qs[8 * is + 2 * l + 1];
        const u8vec4 grid0 = unpack8(iq3xxs_grid[qs0]);
        const u8vec4 grid1 = unpack8(iq3xxs_grid[qs1]);
        data_b[b_idx + 8 * l + 0] = D_TYPE(db * grid0.x * ((sign8 & 1) != 0 ? -1.0 : 1.0));
        data_b[b_idx + 8 * l + 1] = D_TYPE(db * grid0.y * ((sign8 & 2) != 0 ? -1.0 : 1.0));
        data_b[b_idx + 8 * l + 2] = D_TYPE(db * grid0.z * ((sign8 & 4) != 0 ? -1.0 : 1.0));

@@ -0,0 +1,21 @@
#version 450

#include "rte.comp"
#include "generic_head.comp"
#include "types.comp"

#extension GL_EXT_control_flow_attributes : enable

layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;

layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};

void main() {
    const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;

    if (i >= p.KX) {
        return;
    }
    data_d[i] = D_TYPE(exp(float(data_a[i])));
}

@@ -117,6 +117,9 @@ void main() {

        [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
            if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {
                continue;
            }
            [[unroll]] for (uint32_t d = 0; d < HSK_per_thread / 4; ++d) {
#if BLOCK_SIZE > 1
                uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
@@ -155,7 +158,11 @@ void main() {
                uint32_t c = (idx + tid) % Bc;
                uint32_t r = (idx + tid) / Bc;
                if (idx + tid < Bc * Br) {
                    masksh[c][r] = float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]);
                    if (!KV_bounds_check || j * Bc + c < KV) {
                        masksh[c][r] = float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]);
                    } else {
                        masksh[c][r] = float(0);
                    }
                }
            }
            barrier();
@@ -172,8 +179,11 @@ void main() {

        float rowmaxf[Br], Pf[Br][cols_per_thread], rowsumf[Br], eMf[Br], Moldf[Br];
        [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
            rowmaxf[r] = Sf[r][0];
            rowmaxf[r] = NEG_FLT_MAX_OVER_2;
            [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
                if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {
                    continue;
                }
                rowmaxf[r] = max(rowmaxf[r], Sf[r][c]);
            }
            Moldf[r] = Mf[r];
@@ -190,6 +200,9 @@ void main() {
            // Compute sum across row of P
            rowsumf[r] = 0.0;
            [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
                if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {
                    continue;
                }
                rowsumf[r] += Pf[r][c];
            }

@@ -203,6 +216,9 @@ void main() {
        }

        [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
            if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {
                continue;
            }
            [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
#if BLOCK_SIZE > 1
                uint coord = (j * Bc + c * cols_per_iter + col_tid) * v_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
@@ -334,6 +350,9 @@ void main() {
    [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
        [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
            Of[r][d] *= Lfrcp[r];
#if defined(ACC_TYPE_MAX)
            Of[r][d] = clamp(Of[r][d], -vec4(ACC_TYPE_MAX), vec4(ACC_TYPE_MAX));
#endif
        }
    }

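Note: initializing the row maximum to NEG_FLT_MAX_OVER_2 (presumably -FLT_MAX/2, defined elsewhere in these shaders) rather than to the first score matters once columns can be skipped: if every column of a row is out of KV range, the running max M stays at this finite sentinel instead of -inf, so the online-softmax rescale factor exp(Mold - M) = exp(0) = 1 stays finite and never produces an inf - inf = NaN.
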
@@ -9,6 +9,12 @@ layout (constant_id = 4) const uint32_t HSV = 32;
layout (constant_id = 5) const uint32_t Clamp = 0;
layout (constant_id = 6) const uint32_t D_split = 16;

// Round up head sizes to a multiple of 16, for coopmat1/coopmat2 paths
const uint32_t HSK_pad = (HSK + 15) & ~15;
const uint32_t HSV_pad = (HSV + 15) & ~15;

const bool KV_bounds_check = Clamp != 0;

layout (push_constant) uniform parameter {
    uint32_t N;
    uint32_t KV;
@@ -61,30 +67,48 @@ layout (binding = 5) writeonly buffer O {D_TYPE data_o[];};
#if defined(A_TYPE_PACKED16)
#define BINDING_IDX_K 0
#define BINDING_IDX_V 1
layout (binding = 1) readonly buffer KV_PACKED16 {A_TYPE_PACKED16 data_packed16[];} kv_packed[2];
layout (binding = 1) readonly buffer K_PACKED16 {A_TYPE_PACKED16 k_data_packed16[];} k_packed;
layout (binding = 2) readonly buffer V_PACKED16 {A_TYPE_PACKED16 v_data_packed16[];} v_packed;
#endif

#if defined(DATA_A_Q4_0)
#define BLOCK_BYTE_SIZE 18

vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
    uint vui_lo = uint(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]);
    uint vui_hi = uint(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]);
    uint shift = (iqs & 0x10) >> 2;
    vui_lo >>= shift;
    vui_hi >>= shift;
    if (binding_idx == BINDING_IDX_K) {
        uint vui_lo = uint(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]);
        uint vui_hi = uint(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]);
        uint shift = (iqs & 0x10) >> 2;
        vui_lo >>= shift;
        vui_hi >>= shift;

    return float(kv_packed[binding_idx].data_packed16[a_offset + ib].d) * (vec4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - 8.0f);
        return float(k_packed.k_data_packed16[a_offset + ib].d) * (vec4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - 8.0f);
    } else {
        uint vui_lo = uint(v_packed.v_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]);
        uint vui_hi = uint(v_packed.v_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]);
        uint shift = (iqs & 0x10) >> 2;
        vui_lo >>= shift;
        vui_hi >>= shift;

        return float(v_packed.v_data_packed16[a_offset + ib].d) * (vec4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - 8.0f);
    }
}
#endif

#if defined(DATA_A_Q8_0)
#define BLOCK_BYTE_SIZE 34
vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
    const i8vec2 v0 = unpack8(int32_t(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[iqs / 2])).xy; // vec4 used due to #12147
    const i8vec2 v1 = unpack8(int32_t(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[iqs / 2 + 1])).xy;
    if (binding_idx == BINDING_IDX_K) {
        const i8vec2 v0 = unpack8(int32_t(k_packed.k_data_packed16[a_offset + ib].qs[iqs / 2])).xy; // vec4 used due to #12147
        const i8vec2 v1 = unpack8(int32_t(k_packed.k_data_packed16[a_offset + ib].qs[iqs / 2 + 1])).xy;

    return float(kv_packed[binding_idx].data_packed16[a_offset + ib].d) * vec4(v0.x, v0.y, v1.x, v1.y);
        return float(k_packed.k_data_packed16[a_offset + ib].d) * vec4(v0.x, v0.y, v1.x, v1.y);
    } else {
        const i8vec2 v0 = unpack8(int32_t(v_packed.v_data_packed16[a_offset + ib].qs[iqs / 2])).xy; // vec4 used due to #12147
        const i8vec2 v1 = unpack8(int32_t(v_packed.v_data_packed16[a_offset + ib].qs[iqs / 2 + 1])).xy;

        return float(v_packed.v_data_packed16[a_offset + ib].d) * vec4(v0.x, v0.y, v1.x, v1.y);
    }
}
#endif

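Note: splitting the old kv_packed[2] buffer array into explicit k_packed / v_packed declarations trades a dynamically indexed array of SSBOs for a branch on binding_idx, presumably because indexing buffer arrays is handled poorly by some implementations. Every call site passes a constant selector, so the branch is trivially uniform and folds away, e.g. (from the coopmat1 shader in this same diff):

    f16vec4 K_Tf = f16vec4(dequantize4(ib, iqs, k_offset, BINDING_IDX_K));
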
@@ -46,14 +46,14 @@ const uint32_t MatBc = 16;
shared FLOAT_TYPE tmpsh[gl_WorkGroupSize.x];
shared ACC_TYPEV4 tmpshv4[gl_WorkGroupSize.x];

const uint32_t qstride = HSK / 4 + 2; // in units of f16vec4
const uint32_t qstride = HSK_pad / 4 + 2; // in units of f16vec4
shared f16vec4 Qf[Br * qstride];

// Avoid padding for hsk==256 to make it fit in 48KB shmem.
const uint32_t sfshstride = (HSK <= 128) ? (Br + 8) : Br;
shared ACC_TYPE sfsh[Bc * sfshstride];

const uint32_t kshstride = HSK / 4 + 2; // in units of f16vec4
const uint32_t kshstride = HSK_pad / 4 + 2; // in units of f16vec4
shared f16vec4 ksh[Bc * kshstride];

shared float slope[Br];
@@ -74,6 +74,21 @@ void main() {

#define tile_row(r) (row_tid * rows_per_thread + (r))

    // Zero-initialize shared memory for Q/K when HSK is not a multiple of 16 (HSK_pad > HSK).
    if ((HSK % 16) != 0) {
        [[unroll]] for (uint i = 0; i < Br * qstride; i += gl_WorkGroupSize.x) {
            if (i + tid < Br * qstride) {
                Qf[i + tid] = f16vec4(0);
            }
        }
        [[unroll]] for (uint i = 0; i < Bc * kshstride; i += gl_WorkGroupSize.x) {
            if (i + tid < Bc * kshstride) {
                ksh[i + tid] = f16vec4(0);
            }
        }
        barrier();
    }

    uint32_t q_offset = (iq2*p.nb02+iq3*p.nb03) / 4;

    [[unroll]] for (uint32_t idx = 0; idx < Br * HSK / 4; idx += gl_WorkGroupSize.x) {
@@ -137,28 +152,31 @@ void main() {
            uint32_t d = (idx + tid) % (HSK / 4);
            uint32_t c = (idx + tid) / (HSK / 4);
            if (c < Bc && d < HSK / 4) {
                f16vec4 K_Tf = f16vec4(0);
                if (!KV_bounds_check || j * Bc + c < KV) {
#if BLOCK_SIZE > 1
                uint coord = (j * Bc + c) * k_stride * BLOCK_SIZE + 4 * d;
                uint ib = coord / BLOCK_SIZE;
                uint iqs = (coord % BLOCK_SIZE);
                f16vec4 K_Tf = f16vec4(dequantize4(ib, iqs, k_offset, BINDING_IDX_K));
                    uint coord = (j * Bc + c) * k_stride * BLOCK_SIZE + 4 * d;
                    uint ib = coord / BLOCK_SIZE;
                    uint iqs = (coord % BLOCK_SIZE);
                    K_Tf = f16vec4(dequantize4(ib, iqs, k_offset, BINDING_IDX_K));
#else
                f16vec4 K_Tf = f16vec4(data_kv4[k_offset / 4 + (j * Bc + c) * k_stride / 4 + d]);
                    K_Tf = f16vec4(data_kv4[k_offset / 4 + (j * Bc + c) * k_stride / 4 + d]);
#endif
                }

                ksh[c * kshstride + d] = K_Tf;
            }
        }
        barrier();

        // K * Q^T -> S^T: Bc x HSK * HSK x Br -> Bc x Br
        // K * Q^T -> S^T: Bc x HSK_pad * HSK_pad x Br -> Bc x Br
        // Bc split across workgroup (four subgroups), loop over HSK in chunks of 16: 16 x 16 * 16 x 16 -> 16 x 16
        // This is written transposed in order to allow for N being 8 if implementations need it
        coopmat<ACC_TYPE, gl_ScopeSubgroup, MatBc, MatBr, gl_MatrixUseAccumulator> SfMat = coopmat<ACC_TYPE, gl_ScopeSubgroup, MatBc, MatBr, gl_MatrixUseAccumulator>(0);
        coopmat<float16_t, gl_ScopeSubgroup, MatBc, 16, gl_MatrixUseA> KMat;
        coopmat<float16_t, gl_ScopeSubgroup, 16, MatBr, gl_MatrixUseB> QMat;

        for (uint32_t d = 0; d < HSK / 16; ++d) {
        for (uint32_t d = 0; d < HSK_pad / 16; ++d) {
            coopMatLoad(QMat, Qf, d * 16 / 4, qstride, gl_CooperativeMatrixLayoutColumnMajor);

            uint coord = (gl_SubgroupID * MatBc) * kshstride + d * 16 / 4;
@@ -187,7 +205,9 @@ void main() {
                uint32_t c = (idx + tid) % Bc;
                uint32_t r = (idx + tid) / Bc;
                if (idx + tid < Bc * Br || idx + gl_WorkGroupSize.x <= Bc * Br) {
                    sfsh[c * sfshstride + r] += ACC_TYPE(slope[r] * float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]));
                    if (!KV_bounds_check || j * Bc + c < KV) {
                        sfsh[c * sfshstride + r] += ACC_TYPE(slope[r] * float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]));
                    }
                }
            }
            barrier();
@@ -195,8 +215,11 @@ void main() {

        float eMf[rows_per_thread];
        [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
            float rowmaxf = sfsh[tile_row(r) + (0 * cols_per_iter + col_tid) * sfshstride];
            float rowmaxf = NEG_FLT_MAX_OVER_2;
            [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
                if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {
                    continue;
                }
                rowmaxf = max(rowmaxf, float(sfsh[tile_row(r) + (c * cols_per_iter + col_tid) * sfshstride]));
            }
            float Moldf = Mf[r];
@@ -210,7 +233,7 @@ void main() {

        [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
            [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
                Of[r][d] = float16_t(eMf[r]) * Of[r][d];
                Of[r][d] = ACC_TYPE(eMf[r]) * Of[r][d];
            }
        }
        [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
@@ -218,6 +241,9 @@ void main() {
        }

        [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
            if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {
                continue;
            }
            float Pf[rows_per_thread];
            [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
                Pf[r] = exp(sfsh[tile_row(r) + (c * cols_per_iter + col_tid) * sfshstride] - Mf[r]);
@@ -233,7 +259,7 @@ void main() {
            vec4 Vf = vec4(data_vv4[v_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * v_stride / 4 + d * D_split + d_tid]);
#endif
            [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
                Of[r][d] += float16_t(Pf[r]) * ACC_TYPEV4(Vf);
                Of[r][d] += ACC_TYPE(Pf[r]) * ACC_TYPEV4(Vf);
            }
        }
    }
@@ -288,7 +314,7 @@ void main() {
    [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
        [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {

            Of[r][d] = float16_t(eMf[r]) * Of[r][d];
            Of[r][d] = ACC_TYPE(eMf[r]) * Of[r][d];
            tmpshv4[tid] = Of[r][d];

            barrier();
@@ -357,7 +383,10 @@ void main() {

    [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
        [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
            Of[r][d] *= float16_t(Lfrcp[r]);
            Of[r][d] *= ACC_TYPE(Lfrcp[r]);
#if defined(ACC_TYPE_MAX)
            Of[r][d] = clamp(Of[r][d], -ACC_TYPE_MAX, ACC_TYPE_MAX);
#endif
        }
    }

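Note: the zero-init block only runs when HSK is not a multiple of 16; the cooperative-matrix multiply then iterates over HSK_pad / 16 chunks, so the padding tail of Qf and ksh must hold zeros to contribute nothing to K * Q^T. Zeroing once before the main loop suffices because later loads only overwrite the first HSK / 4 vectors of each row. For example, HSK = 72 gives HSK_pad = (72 + 15) & ~15 = 80, i.e. 2 of the 20 f16vec4 per row are permanent zero padding.
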
@@ -104,16 +104,16 @@ void main() {
    tensorLayoutK = setTensorLayoutStrideNV(tensorLayoutK, k_stride, 1);
    tensorLayoutV = setTensorLayoutStrideNV(tensorLayoutV, v_stride, 1);

    coopmat<Q_TYPE, gl_ScopeWorkgroup, Br, HSK, gl_MatrixUseAccumulator> Q;
    coopmat<float16_t, gl_ScopeWorkgroup, Br, HSK, gl_MatrixUseA> Qf16;
    coopmat<Q_TYPE, gl_ScopeWorkgroup, Br, HSK_pad, gl_MatrixUseAccumulator> Q;
    coopmat<float16_t, gl_ScopeWorkgroup, Br, HSK_pad, gl_MatrixUseA> Qf16;

    uint32_t q_offset = iq2*p.nb02+iq3*p.nb03;
    coopMatLoadTensorNV(Q, data_q, q_offset, sliceTensorLayoutNV(tensorLayoutQ, i * Br, Br, 0, HSK));
    coopMatLoadTensorNV(Q, data_q, q_offset, sliceTensorLayoutNV(tensorLayoutQ, i * Br, Br, 0, HSK_pad));

    Qf16 = coopmat<float16_t, gl_ScopeWorkgroup, Br, HSK, gl_MatrixUseA>(Q);
    Qf16 = coopmat<float16_t, gl_ScopeWorkgroup, Br, HSK_pad, gl_MatrixUseA>(Q);
    Qf16 *= float16_t(p.scale);

    coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator> O = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator>(0);
    coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> O = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator>(0);

    coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> L, M;

@@ -140,10 +140,10 @@ void main() {

        coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> S = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0);

        coopmat<float16_t, gl_ScopeWorkgroup, HSK, Bc, gl_MatrixUseB> K_T;
        coopmat<float16_t, gl_ScopeWorkgroup, HSK_pad, Bc, gl_MatrixUseB> K_T;

        uint32_t k_offset = ik2*p.nb12 + ik3*p.nb13;
        coopMatLoadTensorNV(K_T, data_k, k_offset, sliceTensorLayoutNV(tensorLayoutK, j * Bc, Bc, 0, HSK), tensorViewTranspose DECODEFUNC);
        coopMatLoadTensorNV(K_T, data_k, k_offset, sliceTensorLayoutNV(tensorLayoutK, j * Bc, Bc, 0, HSK_pad), tensorViewTranspose DECODEFUNC);
        S = coopMatMulAdd(Qf16, K_T, S);

        if (p.logit_softcap != 0.0f) {
@@ -208,31 +208,31 @@ void main() {
        rowsum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0.0);
        rowsum = coopMatMulAdd(P_A, One, rowsum);

        coopmat<float16_t, gl_ScopeWorkgroup, Bc, HSV, gl_MatrixUseB> V;
        coopmat<float16_t, gl_ScopeWorkgroup, Bc, HSV_pad, gl_MatrixUseB> V;
        uint32_t v_offset = iv2*p.nb22 + iv3*p.nb23;
        coopMatLoadTensorNV(V, data_v, v_offset, sliceTensorLayoutNV(tensorLayoutV, j * Bc, Bc, 0, HSV) DECODEFUNC);
        coopMatLoadTensorNV(V, data_v, v_offset, sliceTensorLayoutNV(tensorLayoutV, j * Bc, Bc, 0, HSV_pad) DECODEFUNC);

        L = eM*L + rowsum;

        // This is the "diagonal" matrix in the paper, but since we do componentwise
        // multiply rather than matrix multiply it has the diagonal element smeared
        // across the row
        coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator> eMdiag;
        coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> eMdiag;

        // resize eM by using smear/reduce
        coopMatReduceNV(eMdiag, eM, gl_CooperativeMatrixReduceRowNV, smearReduce);

        // multiply with fp16 accumulation, then add to O.
        coopmat<float16_t, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator> PV = coopmat<float16_t, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator>(0);
        coopmat<float16_t, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> PV = coopmat<float16_t, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator>(0);
        PV = coopMatMulAdd(P_A, V, PV);

        O = eMdiag * O + coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator>(PV);
        O = eMdiag * O + coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator>(PV);
    }

    // If there is split_k, then the split_k resolve shader does the final
    // division by L. Store the intermediate O value and per-row m and L values.
    if (p.k_num > 1) {
        coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator> O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator>(O);
        coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator>(O);

        uint32_t o_offset = HSV * p.ne1 * (split_k_index + iq3 * p.k_num);
        coopMatPerElementNV(O_D, O_D, perElemOpGqaStore, o_offset, iq2, N);
@@ -243,16 +243,16 @@ void main() {
        return;
    }

    coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator> Ldiag;
    coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> Ldiag;

    // resize L by using smear/reduce
    coopMatReduceNV(Ldiag, L, gl_CooperativeMatrixReduceRowNV, smearReduce);

    if ((p.mask_n_head_log2 & SINK_ENABLE_BIT) != 0) {
        coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator> S;
        coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> S;
        coopMatPerElementNV(S, S, perElemOpGetSink, iq2);

        coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator> Mr;
        coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> Mr;

        // resize M by using smear/reduce
        coopMatReduceNV(Mr, M, gl_CooperativeMatrixReduceRowNV, smearReduce);
@@ -283,9 +283,13 @@ void main() {

    O = Ldiag*O;

#if defined(ACC_TYPE_MAX)
    [[unroll]] for (uint i = 0; i < O.length(); ++i) { O[i] = clamp(O[i], -ACC_TYPE_MAX, ACC_TYPE_MAX); }
#endif

    uint32_t o_offset = iq3*p.ne2*p.ne1*HSV;

    coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator> O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV, gl_MatrixUseAccumulator>(O);
    coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator>(O);
    if (p.gqa_ratio > 1) {
        coopMatPerElementNV(O_D, O_D, perElemOpGqaStore, o_offset, iq2, N);
    } else {
@@ -295,6 +299,6 @@ void main() {
        // permute dimensions
        tensorViewNV<3, false, 1, 0, 2> tensorViewPermute = createTensorViewNV(3, false, 1, 0, 2);

        coopMatStoreTensorNV(O_D, data_o, o_offset, sliceTensorLayoutNV(tensorLayoutD, i * Br, Br, iq2, N, 0, HSV), tensorViewPermute);
        coopMatStoreTensorNV(O_D, data_o, o_offset, sliceTensorLayoutNV(tensorLayoutD, i * Br, Br, iq2, N, 0, HSV_pad), tensorViewPermute);
    }
}

@@ -111,6 +111,10 @@ void main() {
            }
        }
        O *= L;

        const float FLT_MAX = uintBitsToFloat(0x7F7FFFFF);
        O = clamp(O, -FLT_MAX, FLT_MAX);

        data_d[iq3 * D * N + D * n + d] = O;
    }
}

@@ -2,6 +2,7 @@
#extension GL_EXT_control_flow_attributes : require

#include "rte.comp"
#include "utils.comp"

layout (push_constant) uniform parameter
{
@@ -28,25 +29,9 @@ uint get_aoffset() { return p.misalign_offsets >> 16; }
uint get_boffset() { return (p.misalign_offsets >> 8) & 0xFF; }
uint get_doffset() { return p.misalign_offsets & 0xFF; }

// mod and div are expensive and coordinates/dimensions are often power of 2 or equal to 1
uint fastmod(uint a, uint b) {
    if ((b & (b-1)) == 0) {
        return a & (b-1);
    }
    return a % b;
}

uint fastdiv(uint a, uint b) {
    return (a < b) ? 0 : (a / b);
}

void get_indices(uint idx, out uint i00, out uint i01, out uint i02, out uint i03) {
    i03 = fastdiv(idx, (p.ne02*p.ne01*p.ne00));
    const uint i03_offset = i03 * p.ne02*p.ne01*p.ne00;
    i02 = fastdiv((idx - i03_offset), (p.ne01*p.ne00));
    const uint i02_offset = i02*p.ne01*p.ne00;
    i01 = (idx - i03_offset - i02_offset) / p.ne00;
    i00 = idx - i03_offset - i02_offset - i01*p.ne00;
    get_indices(idx, i00, i01, i02, i03, p.ne00, p.ne01, p.ne02, p.ne03);
}

uint src0_idx(uint i00, uint i01, uint i02, uint i03) {

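Note: the removed fastmod/fastdiv helpers move into the shared utils.comp, and get_indices now delegates to a generalized overload there. Elsewhere in this sync (conv2d, im2col) divisions use precomputed magic-number pairs such as (s1mp, s1L). As a hedged sketch of how such a division typically works (the exact formulation in utils.comp may differ):

    // Host side precomputes, for a divisor d > 1:
    //   L  = ceil(log2(d))
    //   mp = floor(2^32 * (2^L - d) / d) + 1
    // Then for any 32-bit n:  n / d == (mulhi(n, mp) + n) >> L
    uint fastdiv(uint n, uint mp, uint L) {
        uint msbs, lsbs;
        umulExtended(n, mp, msbs, lsbs); // msbs = high 32 bits of n * mp
        return (msbs + n) >> L;
    }
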
@@ -7,27 +7,36 @@ layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;

void main() {
    const uint i00 = gl_GlobalInvocationID.x;
    const uint i10 = gl_GlobalInvocationID.y;
    const uint i11 = (gl_GlobalInvocationID.z)/p.ne12;
    const uint i12 = (gl_GlobalInvocationID.z)%p.ne12;

    if (i00 >= p.ne00) {
        return;
    }

    const uint i01 = data_b[get_boffset() + i10*p.nb10 + i11*p.nb11 + i12*p.nb12];
    uint gid_z = gl_GlobalInvocationID.z;
    while (gid_z < p.ne11 * p.ne12) {
        uint gid_y = gl_GlobalInvocationID.y;
        while (gid_y < p.ne10) {
            const uint i10 = gid_y;
            const uint i11 = gid_z / p.ne12;
            const uint i12 = gid_z % p.ne12;

    const uint a_offset = get_aoffset() + i01*p.nb01 + i11*p.nb02 + i12*p.nb03;
    const uint d_offset = get_doffset() + i10*p.nb21 + i11*p.nb22 + i12*p.nb23;
            const uint i01 = data_b[get_boffset() + i10*p.nb10 + i11*p.nb11 + i12*p.nb12];

            const uint a_offset = get_aoffset() + i01*p.nb01 + i11*p.nb02 + i12*p.nb03;
            const uint d_offset = get_doffset() + i10*p.nb21 + i11*p.nb22 + i12*p.nb23;

#if defined(DATA_A_BF16)
    FLOAT_TYPE v = FLOAT_TYPE(bf16_to_fp32(data_a[a_offset + i00]));
            FLOAT_TYPE v = FLOAT_TYPE(bf16_to_fp32(data_a[a_offset + i00]));
#else
    FLOAT_TYPE v = FLOAT_TYPE(data_a[a_offset + i00]);
            FLOAT_TYPE v = FLOAT_TYPE(data_a[a_offset + i00]);
#endif
#ifndef OPTIMIZATION_ERROR_WORKAROUND
    data_d[d_offset + i00] = D_TYPE(v);
            data_d[d_offset + i00] = D_TYPE(v);
#else
    data_d[d_offset + i00] = D_TYPE(v);
            data_d[d_offset + i00] = D_TYPE(v);
#endif
            gid_y += gl_WorkGroupSize.y * gl_NumWorkGroups.y;
        }
        gid_z += gl_WorkGroupSize.z * gl_NumWorkGroups.z;
    }
}

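Note: both get_rows variants switch from a one-to-one invocation-to-element mapping to grid-stride while loops over the y and z dimensions. Vulkan caps each dispatch dimension at maxComputeWorkGroupCount (guaranteed only up to 65535 per dimension), so a very large ne10 or ne11 * ne12 could previously not be covered by a single dispatch; now each invocation strides until the range is exhausted, following the usual pattern:

    for (uint y = gl_GlobalInvocationID.y; y < total_y; y += gl_WorkGroupSize.y * gl_NumWorkGroups.y) { ... }
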
@@ -10,9 +10,6 @@ layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;

void main() {
    const uint i00 = (gl_GlobalInvocationID.x)*2;
    const uint i10 = gl_GlobalInvocationID.y;
    const uint i11 = (gl_GlobalInvocationID.z)/p.ne12;
    const uint i12 = (gl_GlobalInvocationID.z)%p.ne12;

#ifdef NEEDS_INIT_IQ_SHMEM
    init_iq_shmem(gl_WorkGroupSize);
@@ -22,20 +19,33 @@ void main() {
        return;
    }

    const uint i01 = data_b[i10*p.nb10 + i11*p.nb11 + i12*p.nb12];
    uint gid_z = gl_GlobalInvocationID.z;
    while (gid_z < p.ne11 * p.ne12) {
        uint gid_y = gl_GlobalInvocationID.y;
        while (gid_y < p.ne10) {
            const uint i10 = gid_y;
            const uint i11 = gid_z / p.ne12;
            const uint i12 = gid_z % p.ne12;

    const uint a_offset = i01*p.nb01 + i11*p.nb02 + i12*p.nb03;
    const uint d_offset = i10*p.nb21 + i11*p.nb22 + i12*p.nb23;
            const uint i01 = data_b[i10*p.nb10 + i11*p.nb11 + i12*p.nb12];

    const uint ib = a_offset + i00/QUANT_K; // block index
    const uint iqs = (i00%QUANT_K)/QUANT_R; // quant index
    const uint iybs = i00 - i00%QUANT_K; // dst block start index
    const uint y_offset = QUANT_R == 1 ? 1 : QUANT_K/2;
            const uint a_offset = i01*p.nb01 + i11*p.nb02 + i12*p.nb03;
            const uint d_offset = i10*p.nb21 + i11*p.nb22 + i12*p.nb23;

    vec2 v = dequantize(ib, iqs, 0);
    const vec2 dm = get_dm(ib, 0);
    v = v * dm.x + dm.y;
            const uint ib = a_offset + i00/QUANT_K; // block index
            const uint iqs = (i00%QUANT_K)/QUANT_R; // quant index
            const uint iybs = i00 - i00%QUANT_K; // dst block start index
            const uint y_offset = QUANT_R == 1 ? 1 : QUANT_K/2;

    data_d[d_offset + iybs + iqs           ] = D_TYPE(v.x);
    data_d[d_offset + iybs + iqs + y_offset] = D_TYPE(v.y);
            vec2 v = dequantize(ib, iqs, 0);
            const vec2 dm = get_dm(ib, 0);
            v = v * dm.x + dm.y;

            data_d[d_offset + iybs + iqs           ] = D_TYPE(v.x);
            data_d[d_offset + iybs + iqs + y_offset] = D_TYPE(v.y);

            gid_y += gl_WorkGroupSize.y * gl_NumWorkGroups.y;
        }
        gid_z += gl_WorkGroupSize.z * gl_NumWorkGroups.z;
    }
}

@@ -0,0 +1,22 @@
#version 450

#include "generic_head.comp"
#include "types.comp"

#extension GL_EXT_control_flow_attributes : enable

layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;

layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};

void main() {
    const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;

    if (i >= p.KX) {
        return;
    }

    const float x = float(data_a[i]);
    data_d[i] = D_TYPE(min(1.0f, max(0.0f, (x + 3.0f) / 6.0f)));
}

@@ -0,0 +1,22 @@
#version 450

#include "generic_head.comp"
#include "types.comp"

#extension GL_EXT_control_flow_attributes : enable

layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;

layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};

void main() {
    const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;

    if (i >= p.KX) {
        return;
    }

    const float x = float(data_a[i]);
    data_d[i] = D_TYPE(x * min(1.0f, max(0.0f, (x + 3.0f) / 6.0f)));
}

@@ -5,8 +5,11 @@

#include "rte.comp"

#include "types.comp"

layout (push_constant) uniform parameter
{
    BDA_STORAGE_T dst_addr;
    uint batch_offset; uint offset_delta;
    uint IC;
    uint IW; uint IH;
@@ -19,8 +22,6 @@ layout (push_constant) uniform parameter
    int d0; int d1;
} p;

#include "types.comp"

layout(constant_id = 0) const uint BLOCK_SIZE = 32;

const uint NUM_ITER = 512 / BLOCK_SIZE;
@@ -30,6 +31,10 @@ layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};

#if BDA
layout (buffer_reference) buffer D_ptr {D_TYPE d;};
#endif

void main() {
    const uint gidx = gl_GlobalInvocationID.x;

@@ -38,7 +43,7 @@ void main() {
    const uint ic = gl_GlobalInvocationID.z % p.IC;

    const uint src_base = ic * p.offset_delta + batch * p.batch_offset;
    const uint dst_base = ((batch * p.OH + oh) * p.OW) * p.CHW + ic * (p.KW * p.KH);
    const BDA_OFFSET_T dst_base = ((BDA_OFFSET_T(batch) * p.OH + oh) * p.OW) * p.CHW + BDA_OFFSET_T(ic) * (p.KW * p.KH);
    const int oh_s1 = int(oh) * p.s1;
    const uint ksize = p.OW * p.KH;

@@ -50,7 +55,7 @@ void main() {
    uint current_ix = rem % p.OW;

    A_TYPE values[NUM_ITER];
    uint offset_dst[NUM_ITER];
    BDA_OFFSET_T offset_dst[NUM_ITER];
    [[unroll]] for (uint idx = 0; idx < NUM_ITER; ++idx) {
        values[idx] = A_TYPE(0);
    }
@@ -66,7 +71,7 @@ void main() {
        const uint iiw = current_ix * p.s0 + current_kx * p.d0 - p.p0;
        const uint iih = oh_s1 + current_ky * p.d1 - p.p1;

        offset_dst[idx] = dst_base + current_ix * p.CHW + current_ky * p.KW + current_kx;
        offset_dst[idx] = dst_base + BDA_OFFSET_T(current_ix) * p.CHW + current_ky * p.KW + current_kx;

        if ((iih < p.IH) && (iiw < p.IW)) {
            values[idx] = data_a[src_base + iih * p.IW + iiw];
@@ -89,7 +94,11 @@ void main() {
            continue;
        }

#if BDA
        D_ptr dst_addr = D_ptr(p.dst_addr + D_SIZE * offset_dst[idx]);
        dst_addr.d = D_TYPE(values[idx]);
#else
        data_d[offset_dst[idx]] = D_TYPE(values[idx]);
#endif
    }

}

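Note: the BDA path writes through a buffer_reference pointer instead of the bound SSBO, so the destination offset is no longer limited to what a uint can address. BDA_OFFSET_T and BDA_STORAGE_T are presumably typedef'd in types.comp to 64-bit types when buffer device address is enabled, with D_SIZE the element size in bytes; the usage mirrors the diff:

    D_ptr dst_addr = D_ptr(p.dst_addr + D_SIZE * offset_dst[idx]);
    dst_addr.d = D_TYPE(values[idx]);

This lets a single im2col destination grow past the 4 GiB reachable with 32-bit offsets.
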
@@ -0,0 +1,126 @@
#version 450

#extension GL_EXT_shader_16bit_storage : require
#extension GL_EXT_control_flow_attributes : require
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require

#include "rte.comp"

#include "types.comp"

layout (push_constant) uniform parameter
{
    BDA_STORAGE_T dst_addr;
    uint32_t nb10;
    uint32_t nb11;
    uint32_t nb12;
    uint32_t nb13;
    uint32_t s0;
    uint32_t s1;
    uint32_t s2;
    uint32_t p0;
    uint32_t p1;
    uint32_t p2;
    uint32_t d0;
    uint32_t d1;
    uint32_t d2;
    uint32_t IW;
    uint32_t IH;
    uint32_t ID;
    uint32_t IC;
    uint32_t KW;
    uint32_t OH;
    uint32_t KD_KH_KW;
    uint32_t KH_KW;
    uint32_t IC_KD_KH_KW;
    uint32_t N_OD_OH;
    uint32_t OD_OH;
    uint32_t OD_OH_OW_IC_KD_KH_KW;
    uint32_t OH_OW_IC_KD_KH_KW;
    uint32_t OW_IC_KD_KH_KW;
    uint32_t misalign_offsets;
} p;

uint get_aoffset() { return p.misalign_offsets >> 16; }
uint get_doffset() { return p.misalign_offsets & 0xFFFF; }

layout(constant_id = 0) const uint BLOCK_SIZE = 32;

layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;

layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};

#if BDA
layout (buffer_reference) buffer D_ptr {D_TYPE d;};
#endif

void main() {
    const uint32_t i = gl_GlobalInvocationID.x;

    uint32_t nb10 = p.nb10;
    uint32_t nb11 = p.nb11;
    uint32_t nb12 = p.nb12;
    uint32_t nb13 = p.nb13;
    uint32_t s0 = p.s0;
    uint32_t s1 = p.s1;
    uint32_t s2 = p.s2;
    uint32_t p0 = p.p0;
    uint32_t p1 = p.p1;
    uint32_t p2 = p.p2;
    uint32_t d0 = p.d0;
    uint32_t d1 = p.d1;
    uint32_t d2 = p.d2;
    uint32_t IW = p.IW;
    uint32_t IH = p.IH;
    uint32_t ID = p.ID;
    uint32_t IC = p.IC;
    uint32_t KW = p.KW;
    uint32_t OH = p.OH;
    uint32_t KD_KH_KW = p.KD_KH_KW;
    uint32_t KH_KW = p.KH_KW;
    uint32_t IC_KD_KH_KW = p.IC_KD_KH_KW;
    uint32_t N_OD_OH = p.N_OD_OH;
    uint32_t OD_OH = p.OD_OH;
    uint32_t OD_OH_OW_IC_KD_KH_KW = p.OD_OH_OW_IC_KD_KH_KW;
    uint32_t OH_OW_IC_KD_KH_KW = p.OH_OW_IC_KD_KH_KW;
    uint32_t OW_IC_KD_KH_KW = p.OW_IC_KD_KH_KW;

    if (i >= IC_KD_KH_KW) {
        return;
    }

    const uint32_t iic = i / KD_KH_KW;
    const uint32_t ikd = (i - iic * KD_KH_KW) / KH_KW;
    const uint32_t ikh = (i - iic * KD_KH_KW - ikd * KH_KW) / KW;
    const uint32_t ikw = i % KW;

    const uint32_t iow = gl_GlobalInvocationID.y;
    for (uint32_t iz = gl_GlobalInvocationID.z; iz < N_OD_OH; iz += gl_NumWorkGroups.z) {
        const uint32_t in_ = iz / OD_OH;
        const uint32_t iod = (iz - in_*OD_OH) / OH;
        const uint32_t ioh = iz % OH;

        const uint32_t iiw = iow * s0 + ikw * d0 - p0;
        const uint32_t iih = ioh * s1 + ikh * d1 - p1;
        const uint32_t iid = iod * s2 + ikd * d2 - p2;

        const BDA_OFFSET_T offset_dst = BDA_OFFSET_T(in_)*OD_OH_OW_IC_KD_KH_KW + BDA_OFFSET_T(iod)*OH_OW_IC_KD_KH_KW + BDA_OFFSET_T(ioh)*OW_IC_KD_KH_KW + BDA_OFFSET_T(iow)*IC_KD_KH_KW + iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw;

        const uint32_t offset_src = (in_*IC + iic)*nb13 + iid*nb12 + iih*nb11 + iiw*nb10;
#if BDA
        D_ptr dst_addr = D_ptr(p.dst_addr + D_SIZE * offset_dst);
        if (iih >= IH || iiw >= IW || iid >= ID) {
            dst_addr.d = D_TYPE(0.0f);
        } else {
            dst_addr.d = D_TYPE(data_a[offset_src + get_aoffset()]);
        }
#else
        if (iih >= IH || iiw >= IW || iid >= ID) {
            data_d[offset_dst + get_doffset()] = D_TYPE(0.0f);
        } else {
            data_d[offset_dst + get_doffset()] = D_TYPE(data_a[offset_src + get_aoffset()]);
        }
#endif
    }
}

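Note: main() copies every push constant into a local before the loop, presumably as a compiler hint so the values are held in registers rather than re-read from push-constant storage on each iteration. Each invocation handles one fixed kernel/channel slot, decomposed from the flat index as i = iic*KD_KH_KW + ikd*KH_KW + ikh*KW + ikw, while gl_GlobalInvocationID.y covers OW and the z grid-stride loop covers N*OD*OH.
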
@@ -2,16 +2,30 @@
#extension GL_EXT_shader_16bit_storage : require
#extension GL_EXT_shader_8bit_storage : require

#if USE_SUBGROUP_ADD || USE_SUBGROUP_ADD_NO_SHMEM
#extension GL_KHR_shader_subgroup_basic : require
#extension GL_KHR_shader_subgroup_arithmetic : require
#endif

#ifdef MUL_MAT_ID
#define EXPERT_COUNT 8
#endif

#include "types.comp"

#ifndef MMQ
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
#else
layout (binding = 0) readonly buffer A {A_TYPE_PACKED16 data_a[];};
#endif

layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
#ifdef B_TYPE_VEC2
layout (binding = 1) readonly buffer BV2 {B_TYPE_VEC2 data_b_v2[];};
#endif
#ifdef B_TYPE_VEC4
layout (binding = 1) readonly buffer BV4 {B_TYPE_VEC4 data_b_v4[];};
#endif

layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
#ifdef MUL_MAT_ID
@@ -88,9 +102,57 @@ layout (constant_id = 0) const uint BLOCK_SIZE = 32;
layout (constant_id = 1) const uint NUM_ROWS = 1;
layout (constant_id = 2) const uint NUM_COLS = 1;

#ifdef USE_SUBGROUP_ADD_NO_SHMEM
void reduce_result(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const in uint32_t d_offset, const in uint32_t first_row, const in uint32_t num_rows, const in uint32_t tid) {
    [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
        [[unroll]] for (uint n = 0; n < num_rows; ++n) {
            temp[j][n] = subgroupAdd(temp[j][n]);
        }
    }

    if (tid == 0) {
        [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
            [[unroll]] for (uint n = 0; n < num_rows; ++n) {
                data_d[j*p.batch_stride_d + d_offset + first_row + n] = D_TYPE(temp[j][n]);
            }
        }
    }
}
#else
shared FLOAT_TYPE tmpsh[NUM_COLS][NUM_ROWS][BLOCK_SIZE];

void reduce_result(const in FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const in uint32_t d_offset, const in uint32_t first_row, const in uint32_t num_rows, const in uint32_t tid) {
void reduce_result(FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const in uint32_t d_offset, const in uint32_t first_row, const in uint32_t num_rows, const in uint32_t tid) {
    // subgroupAdd is probably faster on devices that support it,
    // particularly when the workgroup has more than one subgroup
#if USE_SUBGROUP_ADD
    // sum up partial sums within a subgroup
    [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
        [[unroll]] for (uint n = 0; n < num_rows; ++n) {
            temp[j][n] = subgroupAdd(temp[j][n]);
        }
    }

    // Go through shared memory to sum partials across subgroups
    if (gl_SubgroupInvocationID == 0) {
        [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
            [[unroll]] for (uint n = 0; n < num_rows; ++n) {
                tmpsh[j][n][gl_SubgroupID] = temp[j][n];
            }
        }
    }
    barrier();
    if (tid == 0) {
        [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
            [[unroll]] for (uint n = 0; n < num_rows; ++n) {
                temp[j][n] = FLOAT_TYPE(0);
                [[unroll]] for (uint s = 0; s < gl_NumSubgroups; ++s) {
                    temp[j][n] += tmpsh[j][n][s];
                }
                data_d[j*p.batch_stride_d + d_offset + first_row + n] = D_TYPE(temp[j][n]);
            }
        }
    }
#else
    // sum up partial sums and write back result
    [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
        [[unroll]] for (uint n = 0; n < num_rows; ++n) {
@@ -115,4 +177,6 @@ void reduce_result(const in FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const in uint32
        }
    }
}
#endif
}
#endif

@ -0,0 +1,140 @@
|
|||
#version 450

#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
#extension GL_EXT_integer_dot_product : require

#define MMQ
#define B_TYPE block_q8_1_x4

#include "mul_mat_vec_base.comp"

layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;

#define K_PER_ITER 8

#include "mul_mmq_funcs.comp"

uint a_offset, b_offset, d_offset;

int32_t cache_b_qs[2];
vec2 cache_b_ds;

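// One call to iter() consumes K_PER_ITER int8 values per thread: the q8_1 B
// block is cached as two packed int32 words plus its ds scale pair, then dotted
// against the repacked A quants with dotPacked4x8EXT.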
void iter(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const uint first_row, const uint num_rows, const uint tid, const uint i) {
    [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
        const uint col = i*BLOCK_SIZE + tid*K_PER_ITER;

        // Preload data_b block
        const uint b_block_idx = (j*p.batch_stride_b + col) / QUANT_K_Q8_1 + b_offset;
        const uint b_qs_idx = tid % 4;
        const uint b_block_idx_outer = b_block_idx / 4;
        const uint b_block_idx_inner = b_block_idx % 4;
        cache_b_ds = vec2(data_b[b_block_idx_outer].ds[b_block_idx_inner]);

#if QUANT_R == 2
        cache_b_qs[0] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx];
        cache_b_qs[1] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx + 4];
#else
        cache_b_qs[0] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx * 2];
        cache_b_qs[1] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx * 2 + 1];
#endif

        uint ibi = first_row*p.ncols;
        [[unroll]] for (uint n = 0; n < num_rows; ++n) {
            const uint a_block_idx = (ibi + col)/QUANT_K + a_offset;
            ibi += p.ncols;

            int32_t q_sum = 0;
#if QUANT_R == 2
            const i32vec2 data_a_qs = repack(a_block_idx, b_qs_idx);
            q_sum += dotPacked4x8EXT(data_a_qs.x,
                                     cache_b_qs[0]);
            q_sum += dotPacked4x8EXT(data_a_qs.y,
                                     cache_b_qs[1]);
#else
            int32_t data_a_qs = repack(a_block_idx, b_qs_idx * 2);
            q_sum += dotPacked4x8EXT(data_a_qs,
                                     cache_b_qs[0]);
            data_a_qs = repack(a_block_idx, b_qs_idx * 2 + 1);
            q_sum += dotPacked4x8EXT(data_a_qs,
                                     cache_b_qs[1]);
#endif

#if QUANT_AUXF == 1
            temp[j][n] += mul_q8_1(q_sum, get_d(a_block_idx), cache_b_ds, 4);
#else
            temp[j][n] += mul_q8_1(q_sum, get_dm(a_block_idx), cache_b_ds, 4);
#endif
        }
    }
}

void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
    const uint tid = gl_LocalInvocationID.x;

    get_offsets(a_offset, b_offset, d_offset);
    a_offset /= QUANT_K;
    b_offset /= QUANT_K_Q8_1;

    FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];

    [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
        [[unroll]] for (uint n = 0; n < num_rows; ++n) {
            temp[j][n] = FLOAT_TYPE(0.0f);
        }
    }

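    // Each workgroup pass covers K_PER_ITER * BLOCK_SIZE columns; threads whose
    // starting column would still be in range run one extra iteration. The main
    // loop below is then manually unrolled 4x, then 2x, with a scalar tail.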
    uint num_iters = p.ncols / (K_PER_ITER * BLOCK_SIZE);
    if (num_iters * K_PER_ITER * BLOCK_SIZE + K_PER_ITER*tid < p.ncols) {
        num_iters++;
    }
    int unroll_count = 4;
    uint unrolled_iters = num_iters & ~(unroll_count - 1);

    uint i = 0;
    while (i < unrolled_iters) {
        // Manually partially unroll the loop
        [[unroll]] for (uint k = 0; k < unroll_count; ++k) {
            iter(temp, first_row, num_rows, tid, i*K_PER_ITER);
            i++;
        }
    }

    unroll_count = 2;
    unrolled_iters = num_iters & ~(unroll_count - 1);

#if K_PER_ITER == 2
    if ((p.ncols & 1) != 0 &&
        unrolled_iters == num_iters &&
        unrolled_iters > 0) {
        unrolled_iters -= unroll_count;
    }
#endif

    while (i < unrolled_iters) {
        // Manually partially unroll the loop
        [[unroll]] for (uint k = 0; k < unroll_count; ++k) {
            iter(temp, first_row, num_rows, tid, i*K_PER_ITER);
            i++;
        }
    }
    while (i < num_iters) {
        iter(temp, first_row, num_rows, tid, i*K_PER_ITER);
        i++;
    }

    reduce_result(temp, d_offset, first_row, num_rows, tid);
}

void main() {
    const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z);

    // do NUM_ROWS at a time, unless there aren't enough remaining rows
    if (first_row + NUM_ROWS <= p.stride_d) {
        compute_outputs(first_row, NUM_ROWS);
    } else {
        if (first_row >= p.stride_d) {
            return;
        }
        compute_outputs(first_row, p.stride_d - first_row);
    }
}

@@ -17,6 +17,9 @@
#ifdef COOPMAT
#extension GL_KHR_cooperative_matrix : enable
#extension GL_KHR_memory_scope_semantics : enable
#endif

#if defined(COOPMAT) || defined(MUL_MAT_ID_USE_SUBGROUPS)
#extension GL_KHR_shader_subgroup_basic : enable
#extension GL_KHR_shader_subgroup_ballot : enable
#endif

@@ -34,6 +37,18 @@
#define LOAD_VEC_B 1
#endif

// Load 2 values at once without affecting index calculations through LOAD_VEC
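// The 2-element batch pairs each thread's loads with the vec2 shared-memory
// layout (SHMEM_STRIDE counted in vec2 units) introduced further below.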
#if (defined(DATA_A_F32) || defined(DATA_A_F16) || defined(DATA_A_BF16)) && !defined(ALIGNED)
#define LOAD_VEC_BATCH_A 2
#else
#define LOAD_VEC_BATCH_A 1
#endif
#if !defined(ALIGNED)
#define LOAD_VEC_BATCH_B 2
#else
#define LOAD_VEC_BATCH_B 1
#endif

#if !defined(TO_FLOAT_TYPE)
#define TO_FLOAT_TYPE FLOAT_TYPE
#endif

@@ -95,28 +110,93 @@ layout (constant_id = 9) const uint TK = 1; // Only needed for coopmat
layout (constant_id = 10) const uint WARP = 32;

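// Shared memory now holds vec2 pairs, so the stride is counted in vec2 units
// (BK / 2) plus padding, presumably to avoid bank conflicts.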
#ifdef COOPMAT
#define SHMEM_STRIDE (BK + 8)
#define SHMEM_STRIDE (BK / 2 + 4)
#else
#define SHMEM_STRIDE (BK + 1)
#define SHMEM_STRIDE (BK / 2 + 1)
#endif

shared FLOAT_TYPE buf_a[BM * SHMEM_STRIDE];
shared FLOAT_TYPE buf_b[BN * SHMEM_STRIDE];

#ifdef MUL_MAT_ID
shared u16vec2 row_ids[4096];
uint _ne1;
#ifdef COOPMAT
shared uint _ne1_sh;
#endif
#endif // MUL_MAT_ID
shared FLOAT_TYPE_VEC2 buf_a[BM * SHMEM_STRIDE];
shared FLOAT_TYPE_VEC2 buf_b[BN * SHMEM_STRIDE];

#define NUM_WARPS (BLOCK_SIZE / WARP)

#ifdef MUL_MAT_ID
shared u16vec2 row_ids[BN];
uint _ne1;

#ifdef MUL_MAT_ID_USE_SUBGROUPS
shared uvec4 ballots_sh[NUM_WARPS];

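// Builds the row_ids window for output tile ic: all threads scan the
// nei1 x nei0 expert-id matrix, a subgroup ballot flags entries matching
// expert_idx, and ballot bit counts compact the matches in order so that only
// rows in [ic*BN, (ic+1)*BN) are written to shared memory.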
void load_row_ids(uint expert_idx, bool nei0_is_pow2, uint ic) {
    _ne1 = 0;
    uint num_elements = p.nei1 * p.nei0;
    uint nei0shift = findLSB(p.nei0);

    uint ids[16];
    uint iter = 0;

    for (uint j = 0; j < num_elements; j += BLOCK_SIZE) {
        // prefetch up to 16 elements
        if (iter == 0) {
            [[unroll]] for (uint k = 0; k < 16; ++k) {
                uint i = j + gl_LocalInvocationIndex + k*BLOCK_SIZE;
                bool in_range = i < num_elements;
                uint ii1;
                if (nei0_is_pow2) {
                    ii1 = i >> nei0shift;
                } else {
                    ii1 = i / p.nei0;
                }
                uint ii0 = i - ii1 * p.nei0;
                ids[k] = in_range ? data_ids[ii1*p.nbi1 + ii0] : 0;
            }
        }
        uint i = j + gl_LocalInvocationIndex;
        bool in_range = i < num_elements;
        uint ii1;
        if (nei0_is_pow2) {
            ii1 = i >> nei0shift;
        } else {
            ii1 = i / p.nei0;
        }
        uint ii0 = i - ii1 * p.nei0;
        uint id = ids[iter++];
        uvec4 ballot = subgroupBallot(in_range && id == expert_idx);

        ballots_sh[gl_SubgroupID] = ballot;
        barrier();

        uint subgroup_base = 0;
        uint total = 0;
        for (uint k = 0; k < gl_NumSubgroups; ++k) {
            if (k == gl_SubgroupID) {
                subgroup_base = total;
            }
            total += subgroupBallotBitCount(ballots_sh[k]);
        }
        barrier();

        uint idx = subgroup_base + subgroupBallotExclusiveBitCount(ballot);
        if (in_range && id == expert_idx && _ne1 + idx >= ic * BN && _ne1 + idx < (ic + 1) * BN) {
            row_ids[_ne1 + idx - ic * BN] = u16vec2(ii0, ii1);
        }
        _ne1 += total;
        iter &= 15;
        if (_ne1 >= (ic + 1) * BN) {
            break;
        }
    }
    barrier();
}
#endif // MUL_MAT_ID_USE_SUBGROUPS
#endif // MUL_MAT_ID

#ifdef COOPMAT
shared ACC_TYPE coopmat_stage[TM * TN * NUM_WARPS];
#endif

#include "mul_mm_funcs.comp"

void main() {
#ifdef NEEDS_INIT_IQ_SHMEM
    init_iq_shmem(gl_WorkGroupSize);

@@ -168,60 +248,29 @@ void main() {
    const uint warp_r = warp_i % (BM / WM);
    const uint warp_c = warp_i / (BM / WM);

    const uint loadr_a = gl_LocalInvocationID.x % (BK / LOAD_VEC_A);
    const uint loadc_a = gl_LocalInvocationID.x / (BK / LOAD_VEC_A);
    const uint loadr_b = gl_LocalInvocationID.x % (BK / LOAD_VEC_B);
    const uint loadc_b = gl_LocalInvocationID.x / (BK / LOAD_VEC_B);
    const uint loadr_a = gl_LocalInvocationID.x % (BK / LOAD_VEC_A / LOAD_VEC_BATCH_A);
    const uint loadc_a = gl_LocalInvocationID.x / (BK / LOAD_VEC_A / LOAD_VEC_BATCH_A);
    const uint loadr_b = gl_LocalInvocationID.x % (BK / LOAD_VEC_B / LOAD_VEC_BATCH_B);
    const uint loadc_b = gl_LocalInvocationID.x / (BK / LOAD_VEC_B / LOAD_VEC_BATCH_B);

    const uint loadstride_a = gl_WorkGroupSize.x * LOAD_VEC_A / BK;
    const uint loadstride_b = gl_WorkGroupSize.x * LOAD_VEC_B / BK;
    const uint loadstride_a = gl_WorkGroupSize.x * LOAD_VEC_A * LOAD_VEC_BATCH_A / BK;
    const uint loadstride_b = gl_WorkGroupSize.x * LOAD_VEC_B * LOAD_VEC_BATCH_B / BK;

#ifdef MUL_MAT_ID
#ifdef COOPMAT
    // Spread the search across all elements in the first subgroup
    if (gl_SubgroupID == 0) {
        _ne1 = 0;
        uint num_elements = p.nei1 * p.nei0;

        uint ids[16];
        uint iter = 0;

        for (uint j = 0; j < num_elements; j += gl_SubgroupSize) {
            // prefetch up to 16 elements
            if (iter == 0) {
                [[unroll]] for (uint k = 0; k < 16; ++k) {
                    uint i = j + gl_SubgroupInvocationID + k*gl_SubgroupSize;
                    bool in_range = i < num_elements;
                    uint ii1 = i / p.nei0;
                    uint ii0 = i % p.nei0;
                    ids[k] = in_range ? data_ids[ii1*p.nbi1 + ii0] : 0;
                }
            }
            uint i = j + gl_SubgroupInvocationID;
            bool in_range = i < num_elements;
            uint ii1 = i / p.nei0;
            uint ii0 = i % p.nei0;
            uint id = ids[iter++];
            uvec4 ballot = subgroupBallot(in_range && id == expert_idx);
            uint idx = subgroupBallotExclusiveBitCount(ballot);
            if (in_range && id == expert_idx) {
                row_ids[_ne1 + idx] = u16vec2(ii0, ii1);
            }
            _ne1 += subgroupBallotBitCount(ballot);
            iter &= 15;
        }
        _ne1_sh = _ne1;
#ifdef MUL_MAT_ID_USE_SUBGROUPS
    if (bitCount(p.nei0) == 1) {
        load_row_ids(expert_idx, true, ic);
    } else {
        load_row_ids(expert_idx, false, ic);
    }

    barrier();

    _ne1 = _ne1_sh;
#else
    _ne1 = 0;
    for (uint ii1 = 0; ii1 < p.nei1; ii1++) {
        for (uint ii0 = 0; ii0 < p.nei0; ii0++) {
    for (uint ii1 = 0; ii1 < p.nei1 && _ne1 < (ic + 1) * BN; ii1++) {
        for (uint ii0 = 0; ii0 < p.nei0 && _ne1 < (ic + 1) * BN; ii0++) {
            if (data_ids[ii1*p.nbi1 + ii0] == expert_idx) {
                row_ids[_ne1] = u16vec2(ii0, ii1);
                if (_ne1 >= ic * BN) {
                    row_ids[_ne1 - ic * BN] = u16vec2(ii0, ii1);
                }
                _ne1++;
            }
        }

@@ -265,8 +314,8 @@ void main() {
    }
#else
    ACC_TYPE sums[WMITER * TM * WNITER * TN];
    FLOAT_TYPE cache_a[WMITER * TM];
    FLOAT_TYPE cache_b[TN];
    FLOAT_TYPE_VEC2 cache_a[WMITER * TM];
    FLOAT_TYPE_VEC2 cache_b[TN];

    [[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN; i++) {
        sums[i] = ACC_TYPE(0.0f);

@@ -275,538 +324,13 @@ void main() {

    for (uint block = start_k; block < end_k; block += BK) {
        [[unroll]] for (uint l = 0; l < BM; l += loadstride_a) {

#if defined(DATA_A_F32) || defined(DATA_A_F16)
#if LOAD_VEC_A == 8
            const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
            const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
            buf_a[buf_idx    ] = FLOAT_TYPE(data_a[idx][0].x);
            buf_a[buf_idx + 1] = FLOAT_TYPE(data_a[idx][0].y);
            buf_a[buf_idx + 2] = FLOAT_TYPE(data_a[idx][0].z);
            buf_a[buf_idx + 3] = FLOAT_TYPE(data_a[idx][0].w);
            buf_a[buf_idx + 4] = FLOAT_TYPE(data_a[idx][1].x);
            buf_a[buf_idx + 5] = FLOAT_TYPE(data_a[idx][1].y);
            buf_a[buf_idx + 6] = FLOAT_TYPE(data_a[idx][1].z);
            buf_a[buf_idx + 7] = FLOAT_TYPE(data_a[idx][1].w);
#elif LOAD_VEC_A == 4
            const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
            const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
            buf_a[buf_idx    ] = FLOAT_TYPE(data_a[idx].x);
            buf_a[buf_idx + 1] = FLOAT_TYPE(data_a[idx].y);
            buf_a[buf_idx + 2] = FLOAT_TYPE(data_a[idx].z);
            buf_a[buf_idx + 3] = FLOAT_TYPE(data_a[idx].w);
#else
            if (ir * BM + loadc_a + l < p.M && block + loadr_a < end_k) {
                buf_a[(loadc_a + l) * SHMEM_STRIDE + loadr_a] = FLOAT_TYPE(data_a[pos_a + (loadc_a + l) * p.stride_a + loadr_a]);
            } else {
                buf_a[(loadc_a + l) * SHMEM_STRIDE + loadr_a] = FLOAT_TYPE(0.0f);
            }
#endif
#elif defined(DATA_A_BF16)
#if LOAD_VEC_A == 4
            const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
            const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
            buf_a[buf_idx    ] = TO_FLOAT_TYPE(data_a[idx].x);
            buf_a[buf_idx + 1] = TO_FLOAT_TYPE(data_a[idx].y);
            buf_a[buf_idx + 2] = TO_FLOAT_TYPE(data_a[idx].z);
            buf_a[buf_idx + 3] = TO_FLOAT_TYPE(data_a[idx].w);
#else
            if (ir * BM + loadc_a + l < p.M && block + loadr_a < end_k) {
                buf_a[(loadc_a + l) * SHMEM_STRIDE + loadr_a] = TO_FLOAT_TYPE(data_a[pos_a + (loadc_a + l) * p.stride_a + loadr_a]);
            } else {
                buf_a[(loadc_a + l) * SHMEM_STRIDE + loadr_a] = TO_FLOAT_TYPE(uint16_t(0));
            }
#endif
#elif defined(DATA_A_Q4_0)
            const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
            const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + 4 * loadr_a;

            const uint ib = idx / 4;
            const uint iqs = idx & 0x03;

            const float d = float(data_a_packed16[ib].d);
            const uint vui = uint(data_a_packed16[ib].qs[2*iqs]) | (uint(data_a_packed16[ib].qs[2*iqs + 1]) << 16);
            const vec4 v0 = (vec4(unpack8(vui & 0x0F0F0F0F)) - 8.0f) * d;
            const vec4 v1 = (vec4(unpack8((vui >> 4) & 0x0F0F0F0F)) - 8.0f) * d;

            buf_a[buf_idx     ] = FLOAT_TYPE(v0.x);
            buf_a[buf_idx + 1 ] = FLOAT_TYPE(v0.y);
            buf_a[buf_idx + 2 ] = FLOAT_TYPE(v0.z);
            buf_a[buf_idx + 3 ] = FLOAT_TYPE(v0.w);
            buf_a[buf_idx + 16] = FLOAT_TYPE(v1.x);
            buf_a[buf_idx + 17] = FLOAT_TYPE(v1.y);
            buf_a[buf_idx + 18] = FLOAT_TYPE(v1.z);
            buf_a[buf_idx + 19] = FLOAT_TYPE(v1.w);
#elif defined(DATA_A_Q4_1)
            const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
            const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + 4 * loadr_a;

            const uint ib = idx / 4;
            const uint iqs = idx & 0x03;

            const float d = float(data_a_packed16[ib].d);
            const float m = float(data_a_packed16[ib].m);
            const uint vui = uint(data_a_packed16[ib].qs[2*iqs]) | (uint(data_a_packed16[ib].qs[2*iqs + 1]) << 16);
            const vec4 v0 = vec4(unpack8(vui & 0x0F0F0F0F)) * d + m;
            const vec4 v1 = vec4(unpack8((vui >> 4) & 0x0F0F0F0F)) * d + m;

            buf_a[buf_idx     ] = FLOAT_TYPE(v0.x);
            buf_a[buf_idx + 1 ] = FLOAT_TYPE(v0.y);
            buf_a[buf_idx + 2 ] = FLOAT_TYPE(v0.z);
            buf_a[buf_idx + 3 ] = FLOAT_TYPE(v0.w);
            buf_a[buf_idx + 16] = FLOAT_TYPE(v1.x);
            buf_a[buf_idx + 17] = FLOAT_TYPE(v1.y);
            buf_a[buf_idx + 18] = FLOAT_TYPE(v1.z);
            buf_a[buf_idx + 19] = FLOAT_TYPE(v1.w);
#elif defined(DATA_A_Q5_0)
            const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
            const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + 2 * loadr_a;

            const uint ib = idx / 8;
            const uint iqs = idx & 0x07;

            const float d = float(data_a_packed16[ib].d);
            const uint uint_qh = uint(data_a_packed16[ib].qh[1]) << 16 | uint(data_a_packed16[ib].qh[0]);
            const ivec2 qh0 = ivec2(((uint_qh >> 2*iqs) << 4) & 0x10, (uint_qh >> (2*iqs + 12)) & 0x10);
            const ivec2 qh1 = ivec2(((uint_qh >> (2*iqs + 1)) << 4) & 0x10, (uint_qh >> (2*iqs + 13)) & 0x10);

            const uint vui = uint(data_a_packed16[ib].qs[iqs]);
            const vec4 v = (vec4((vui & 0xF) | qh0.x, ((vui >> 4) & 0xF) | qh0.y, ((vui >> 8) & 0xF) | qh1.x, (vui >> 12) | qh1.y) - 16.0f) * d;

            buf_a[buf_idx     ] = FLOAT_TYPE(v.x);
            buf_a[buf_idx + 1 ] = FLOAT_TYPE(v.z);
            buf_a[buf_idx + 16] = FLOAT_TYPE(v.y);
            buf_a[buf_idx + 17] = FLOAT_TYPE(v.w);
#elif defined(DATA_A_Q5_1)
            const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
            const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + 2 * loadr_a;

            const uint ib = idx / 8;
            const uint iqs = idx & 0x07;

            const float d = float(data_a_packed16[ib].d);
            const float m = float(data_a_packed16[ib].m);
            const uint uint_qh = data_a_packed16[ib].qh;
            const ivec2 qh0 = ivec2(((uint_qh >> 2*iqs) << 4) & 0x10, (uint_qh >> (2*iqs + 12)) & 0x10);
            const ivec2 qh1 = ivec2(((uint_qh >> (2*iqs + 1)) << 4) & 0x10, (uint_qh >> (2*iqs + 13)) & 0x10);

            const uint vui = uint(data_a_packed16[ib].qs[iqs]);
            const vec4 v = vec4((vui & 0xF) | qh0.x, ((vui >> 4) & 0xF) | qh0.y, ((vui >> 8) & 0xF) | qh1.x, (vui >> 12) | qh1.y) * d + m;

            buf_a[buf_idx     ] = FLOAT_TYPE(v.x);
            buf_a[buf_idx + 1 ] = FLOAT_TYPE(v.z);
            buf_a[buf_idx + 16] = FLOAT_TYPE(v.y);
            buf_a[buf_idx + 17] = FLOAT_TYPE(v.w);
#elif defined(DATA_A_Q8_0)
            const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
            const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;

            const uint ib = idx / 8;
            const uint iqs = idx & 0x07;

            const float d = float(data_a_packed16[ib].d);
            const i8vec2 v0 = unpack8(int32_t(data_a_packed16[ib].qs[2*iqs])).xy; // vec4 used due to #12147
            const i8vec2 v1 = unpack8(int32_t(data_a_packed16[ib].qs[2*iqs + 1])).xy;
            const vec4 v = vec4(v0.x, v0.y, v1.x, v1.y) * d;

            buf_a[buf_idx    ] = FLOAT_TYPE(v.x);
            buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
            buf_a[buf_idx + 2] = FLOAT_TYPE(v.z);
            buf_a[buf_idx + 3] = FLOAT_TYPE(v.w);
#elif defined(DATA_A_Q2_K)
            const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
            const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;

            const uint ib = idx / 128;                 // 2 values per idx
            const uint iqs = idx % 128;                // 0..127

            const uint qsi = (iqs / 64) * 32 + (iqs % 16) * 2; // 0,2,4..30
            const uint scalesi = iqs / 8;                      // 0..15
            const uint qsshift = ((iqs % 64) / 16) * 2;        // 0,2,4,6

            const uvec2 qs = uvec2(data_a[ib].qs[qsi], data_a[ib].qs[qsi + 1]);
            const uint scales = data_a[ib].scales[scalesi];
            const vec2 d = vec2(data_a[ib].d);

            const vec2 v = d.x * float(scales & 0xF) * vec2((qs >> qsshift) & 3) - d.y * float(scales >> 4);

            buf_a[buf_idx    ] = FLOAT_TYPE(v.x);
            buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
#elif defined(DATA_A_Q3_K)
            const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
            const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;

            const uint ib = idx / 128;                 // 2 values per idx
            const uint iqs = idx % 128;                // 0..127

            const uint n = iqs / 64;                   // 0,1
            const uint qsi = n * 32 + (iqs % 16) * 2;  // 0,2,4..62
            const uint hmi = (iqs % 16) * 2;           // 0,2,4..30
            const uint j = (iqs % 64) / 4;             // 0..3
            const uint is = iqs / 8;                   // 0..15
            const uint halfsplit = ((iqs % 64) / 16);  // 0,1,2,3
            const uint qsshift = halfsplit * 2;        // 0,2,4,6
            const uint m = 1 << (4 * n + halfsplit);   // 1,2,4,8,16,32,64,128

            const int8_t us = int8_t(((data_a[ib].scales[is % 8] >> (4 * int(is / 8))) & 0xF)
                                   | (((data_a[ib].scales[8 + (is % 4)] >> (2 * int(is / 4))) & 3) << 4));
            const float dl = float(data_a[ib].d) * float(us - 32);

            buf_a[buf_idx    ] = FLOAT_TYPE(dl * float(int8_t((data_a[ib].qs[qsi    ] >> qsshift) & 3) - (((data_a[ib].hmask[hmi    ] & m) != 0) ? 0 : 4)));
            buf_a[buf_idx + 1] = FLOAT_TYPE(dl * float(int8_t((data_a[ib].qs[qsi + 1] >> qsshift) & 3) - (((data_a[ib].hmask[hmi + 1] & m) != 0) ? 0 : 4)));
#elif defined(DATA_A_Q4_K)
            const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
            const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;

            const uint ib = idx / 128;                 // 2 values per idx
            const uint iqs = idx % 128;                // 0..127

            const uint n = iqs / 32;                   // 0,1,2,3
            const uint b = (iqs % 32) / 16;            // 0,1
            const uint is = 2 * n + b;                 // 0..7
            const uint qsi = n * 32 + (iqs % 16) * 2;  // 0,2,4..126

            const vec2 loadd = vec2(data_a[ib].d);

            const uint scidx0 = (is < 4) ? is : (is + 4);
            const uint scidx1 = (is < 4) ? is : (is - 4);
            const uint scidxmask1 = (is < 4) ? 0x30 : 0xC0;
            const uint scidxshift1 = (is < 4) ? 0 : 2;
            const uint mbidx0 = is + 4;
            const uint mbidx1 = (is < 4) ? is + 4 : is;
            const uint mbidxmask0 = (is < 4) ? 0xF : 0xF0;
            const uint mbidxshift0 = (is < 4) ? 0 : 4;
            const uint mbidxmask1 = (is < 4) ? 0x30 : 0xC0;
            const uint mbidxshift1 = (is < 4) ? 0 : 2;

            const uint8_t sc = uint8_t((data_a[ib].scales[scidx0] & 0xF) | ((data_a[ib].scales[scidx1] & scidxmask1) >> scidxshift1));
            const uint8_t mbyte = uint8_t((data_a[ib].scales[mbidx0] & mbidxmask0) >> mbidxshift0 | ((data_a[ib].scales[mbidx1] & mbidxmask1) >> mbidxshift1));

            const float d = loadd.x * sc;
            const float m = -loadd.y * mbyte;

            buf_a[buf_idx    ] = FLOAT_TYPE(fma(d, float((data_a[ib].qs[qsi    ] >> (b * 4)) & 0xF), m));
            buf_a[buf_idx + 1] = FLOAT_TYPE(fma(d, float((data_a[ib].qs[qsi + 1] >> (b * 4)) & 0xF), m));
#elif defined(DATA_A_Q5_K)
            const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
            const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;

            const uint ib = idx / 128;                 // 2 values per idx
            const uint iqs = idx % 128;                // 0..127

            const uint n = iqs / 32;                   // 0,1,2,3
            const uint b = (iqs % 32) / 16;            // 0,1
            const uint is = 2 * n + b;                 // 0..7
            const uint qsi = n * 32 + (iqs % 16) * 2;  // 0,2,4..126
            const uint qhi = (iqs % 16) * 2;           // 0,2,4..30

            const uint8_t hm = uint8_t(1 << (iqs / 16));

            const vec2 loadd = vec2(data_a[ib].d);

            const uint scidx0 = (is < 4) ? is : (is + 4);
            const uint scidx1 = (is < 4) ? is : (is - 4);
            const uint scidxmask1 = (is < 4) ? 0x30 : 0xC0;
            const uint scidxshift1 = (is < 4) ? 0 : 2;
            const uint mbidx0 = is + 4;
            const uint mbidx1 = (is < 4) ? is + 4 : is;
            const uint mbidxmask0 = (is < 4) ? 0xF : 0xF0;
            const uint mbidxshift0 = (is < 4) ? 0 : 4;
            const uint mbidxmask1 = (is < 4) ? 0x30 : 0xC0;
            const uint mbidxshift1 = (is < 4) ? 0 : 2;

            const uint8_t sc = uint8_t((data_a[ib].scales[scidx0] & 0xF) | ((data_a[ib].scales[scidx1] & scidxmask1) >> scidxshift1));
            const uint8_t mbyte = uint8_t(((data_a[ib].scales[mbidx0] & mbidxmask0) >> mbidxshift0) | ((data_a[ib].scales[mbidx1] & mbidxmask1) >> mbidxshift1));

            const float d = loadd.x * sc;
            const float m = -loadd.y * mbyte;

            buf_a[buf_idx    ] = FLOAT_TYPE(fma(d, float((data_a[ib].qs[qsi    ] >> (b * 4)) & 0xF) + float((data_a[ib].qh[qhi    ] & hm) != 0 ? 16 : 0), m));
            buf_a[buf_idx + 1] = FLOAT_TYPE(fma(d, float((data_a[ib].qs[qsi + 1] >> (b * 4)) & 0xF) + float((data_a[ib].qh[qhi + 1] & hm) != 0 ? 16 : 0), m));
#elif defined(DATA_A_Q6_K)
            const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
            const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;

            const uint ib = idx / 128;                 // 2 values per idx
            const uint iqs = idx % 128;                // 0..127

            const uint n = iqs / 64;                   // 0,1
            const uint b = (iqs % 64) / 32;            // 0,1
            const uint is_b = (iqs % 16) / 8;          // 0,1
            const uint qhshift = ((iqs % 64) / 16) * 2;// 0,2,4,6
            const uint is = 8 * n + qhshift + is_b;    // 0..15
            const uint qsi = n * 64 + (iqs % 32) * 2;  // 0,2,4..126
            const uint qhi = n * 32 + (iqs % 16) * 2;  // 0,2,4..62

            const float dscale = float(data_a[ib].d) * float(data_a[ib].scales[is]);

            buf_a[buf_idx    ] = FLOAT_TYPE(dscale * float(int8_t(((data_a[ib].ql[qsi    ] >> (b * 4)) & 0xF) | (((data_a[ib].qh[qhi    ] >> qhshift) & 3) << 4)) - 32));
            buf_a[buf_idx + 1] = FLOAT_TYPE(dscale * float(int8_t(((data_a[ib].ql[qsi + 1] >> (b * 4)) & 0xF) | (((data_a[ib].qh[qhi + 1] >> qhshift) & 3) << 4)) - 32));
#elif defined(DATA_A_IQ1_S)
            const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
            const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;

            const uint ib = idx / 32;                  // 8 values per idx
            const uint ib32 = (idx % 32) / 4;          // 0..7
            const uint ib8 = idx % 32;

            const float d = float(data_a[ib].d);
            const uint qh = data_a[ib].qh[ib32];
            const uint qs = data_a[ib].qs[ib8];
            const float dl = d * (2 * bitfieldExtract(qh, 12, 3) + 1);
            const float delta = ((qh & 0x8000) != 0) ? -IQ1S_DELTA : IQ1S_DELTA;
            const int16_t grid = int16_t(iq1s_grid[qs | (bitfieldExtract(qh, 3 * int(ib8 & 3), 3) << 8)]);

            [[unroll]] for (int k = 0; k < 8; ++k) {
                buf_a[buf_idx + k] = FLOAT_TYPE(dl * (bitfieldExtract(grid, 2 * k, 2) + delta));
            }
#elif defined(DATA_A_IQ1_M)
            const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
            const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;

            const uint ib = idx / 32;                  // 8 values per idx
            const uint ib8 = idx % 32;
            const uint ib16 = ib8 / 2;

            const uint16_t[4] scales = data_a[ib].scales;
            const u16vec4 s = u16vec4(scales[0], scales[1], scales[2], scales[3]) >> 12;
            const float d = float(unpackHalf2x16(s.x | (s.y << 4) | (s.z << 8) | (s.w << 12)).x);
            const uint sc = scales[ib8 / 8];
            const uint qs = data_a[ib].qs[ib8];
            const uint qh = data_a[ib].qh[ib16] >> (4 * (ib8 & 1));
            const float dl = d * (2 * bitfieldExtract(sc, 3 * int(ib16 & 3), 3) + 1);
            const float delta = ((qh & 8) != 0) ? -IQ1M_DELTA : IQ1M_DELTA;
            const int16_t grid = int16_t(iq1s_grid[qs | ((qh & 7) << 8)]);

            [[unroll]] for (int k = 0; k < 8; ++k) {
                buf_a[buf_idx + k] = FLOAT_TYPE(dl * (bitfieldExtract(grid, 2 * k, 2) + delta));
            }
#elif defined(DATA_A_IQ2_XXS)
            const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
            const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;

            const uint ib = idx / 32;                  // 8 values per idx
            const uint ib32 = (idx % 32) / 4;          // 0..7
            const uint ib8 = idx % 4;

            const float d = float(data_a[ib].d);
            const uint qs = data_a[ib].qs[8 * ib32 + ib8];
            const uint signs = pack32(u8vec4(
                data_a[ib].qs[8*ib32 + 4],
                data_a[ib].qs[8*ib32 + 5],
                data_a[ib].qs[8*ib32 + 6],
                data_a[ib].qs[8*ib32 + 7]
            ));
            const FLOAT_TYPE db = FLOAT_TYPE(d * 0.25 * (0.5 + (signs >> 28)));
            const uint32_t sign7 = bitfieldExtract(signs, 7 * int(ib8), 7);
            const uint sign = sign7 | (bitCount(sign7) << 7);
            const uvec2 grid = iq2xxs_grid[qs];
            const vec4 grid0 = vec4(unpack8(grid.x));
            const vec4 grid1 = vec4(unpack8(grid.y));

            buf_a[buf_idx    ] = db * FLOAT_TYPE((sign &   1) != 0 ? -grid0.x : grid0.x);
            buf_a[buf_idx + 1] = db * FLOAT_TYPE((sign &   2) != 0 ? -grid0.y : grid0.y);
            buf_a[buf_idx + 2] = db * FLOAT_TYPE((sign &   4) != 0 ? -grid0.z : grid0.z);
            buf_a[buf_idx + 3] = db * FLOAT_TYPE((sign &   8) != 0 ? -grid0.w : grid0.w);
            buf_a[buf_idx + 4] = db * FLOAT_TYPE((sign &  16) != 0 ? -grid1.x : grid1.x);
            buf_a[buf_idx + 5] = db * FLOAT_TYPE((sign &  32) != 0 ? -grid1.y : grid1.y);
            buf_a[buf_idx + 6] = db * FLOAT_TYPE((sign &  64) != 0 ? -grid1.z : grid1.z);
            buf_a[buf_idx + 7] = db * FLOAT_TYPE((sign & 128) != 0 ? -grid1.w : grid1.w);
#elif defined(DATA_A_IQ2_XS)
            const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
            const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;

            const uint ib = idx / 32;                  // 8 values per idx
            const uint ib32 = (idx % 32) / 4;          // 0..7
            const uint ib8 = idx % 4;                  // 0..3

            const float d = float(data_a[ib].d);
            const uint scale = (data_a[ib].scales[ib32] >> (2 * (ib8 & 2))) & 0xf;
            const FLOAT_TYPE db = FLOAT_TYPE(d * 0.25 * (0.5 + scale));
            const uint qs = data_a[ib].qs[4 * ib32 + ib8];
            const uint sign7 = qs >> 9;
            const uint sign = sign7 | (bitCount(sign7) << 7);
            const uvec2 grid = iq2xs_grid[qs & 511];
            const vec4 grid0 = vec4(unpack8(grid.x));
            const vec4 grid1 = vec4(unpack8(grid.y));

            buf_a[buf_idx    ] = db * FLOAT_TYPE((sign &   1) != 0 ? -grid0.x : grid0.x);
            buf_a[buf_idx + 1] = db * FLOAT_TYPE((sign &   2) != 0 ? -grid0.y : grid0.y);
            buf_a[buf_idx + 2] = db * FLOAT_TYPE((sign &   4) != 0 ? -grid0.z : grid0.z);
            buf_a[buf_idx + 3] = db * FLOAT_TYPE((sign &   8) != 0 ? -grid0.w : grid0.w);
            buf_a[buf_idx + 4] = db * FLOAT_TYPE((sign &  16) != 0 ? -grid1.x : grid1.x);
            buf_a[buf_idx + 5] = db * FLOAT_TYPE((sign &  32) != 0 ? -grid1.y : grid1.y);
            buf_a[buf_idx + 6] = db * FLOAT_TYPE((sign &  64) != 0 ? -grid1.z : grid1.z);
            buf_a[buf_idx + 7] = db * FLOAT_TYPE((sign & 128) != 0 ? -grid1.w : grid1.w);
#elif defined(DATA_A_IQ2_S)
            const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
            const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;

            const uint ib = idx / 32;                  // 8 values per idx
            const uint ib8 = idx % 32;                 // 0..31
            const uint ib32 = ib8 / 4;                 // 0..7

            const uint scale = (data_a[ib].scales[ib32] >> (2 * (ib8 & 2))) & 0xf;
            const uint qs = data_a[ib].qs[ib8];
            const uint qh = data_a[ib].qh[ib32];
            const uint qhshift = 2 * (ib8 % 4);
            const uint sign = data_a[ib].qs[QUANT_K / 8 + ib8];

            const float d = float(data_a[ib].d);
            const FLOAT_TYPE db = FLOAT_TYPE(d * 0.25 * (0.5 + scale));
            const uvec2 grid = iq2s_grid[qs | ((qh << (8 - qhshift)) & 0x300)];
            const vec4 grid0 = vec4(unpack8(grid.x));
            const vec4 grid1 = vec4(unpack8(grid.y));

            buf_a[buf_idx    ] = db * FLOAT_TYPE((sign &   1) != 0 ? -grid0.x : grid0.x);
            buf_a[buf_idx + 1] = db * FLOAT_TYPE((sign &   2) != 0 ? -grid0.y : grid0.y);
            buf_a[buf_idx + 2] = db * FLOAT_TYPE((sign &   4) != 0 ? -grid0.z : grid0.z);
            buf_a[buf_idx + 3] = db * FLOAT_TYPE((sign &   8) != 0 ? -grid0.w : grid0.w);
            buf_a[buf_idx + 4] = db * FLOAT_TYPE((sign &  16) != 0 ? -grid1.x : grid1.x);
            buf_a[buf_idx + 5] = db * FLOAT_TYPE((sign &  32) != 0 ? -grid1.y : grid1.y);
            buf_a[buf_idx + 6] = db * FLOAT_TYPE((sign &  64) != 0 ? -grid1.z : grid1.z);
            buf_a[buf_idx + 7] = db * FLOAT_TYPE((sign & 128) != 0 ? -grid1.w : grid1.w);
#elif defined(DATA_A_IQ3_XXS)
            const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
            const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;

            const uint ib = idx / 64;                  // 4 values per idx
            const uint iqs = idx % 64;                 // 0..63
            const uint is = QUANT_K / 4 + 4 * (iqs / 8); // 8 values

            const float d = float(data_a[ib].d);
            const uint qs = data_a[ib].qs[iqs];
            const uint signs = pack32(u8vec4(
                data_a[ib].qs[is+0],
                data_a[ib].qs[is+1],
                data_a[ib].qs[is+2],
                data_a[ib].qs[is+3]
            ));
            const float db = d * 0.5 * (0.5 + (signs >> 28));
            const uint32_t sign7 = bitfieldExtract(signs, 7 * (int(iqs / 2) % 4), 7);
            const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (4 * (idx % 2));
            const uint grid = iq3xxs_grid[qs];
            const vec4 v = db * vec4(unpack8(grid));

            buf_a[buf_idx    ] = FLOAT_TYPE((sign & 1) != 0 ? -v.x : v.x);
            buf_a[buf_idx + 1] = FLOAT_TYPE((sign & 2) != 0 ? -v.y : v.y);
            buf_a[buf_idx + 2] = FLOAT_TYPE((sign & 4) != 0 ? -v.z : v.z);
            buf_a[buf_idx + 3] = FLOAT_TYPE((sign & 8) != 0 ? -v.w : v.w);
#elif defined(DATA_A_IQ3_S)
            const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
            const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;

            const uint ib = idx / 64;                  // 4 values per idx
            const uint iqs = idx % 64;                 // 0..63
            const uint iqh = iqs / 8;

            const float d = float(data_a[ib].d);
            const uint qs = data_a[ib].qs[iqs];
            const uint qh = data_a[ib].qh[iqh];
            const int8_t sign = int8_t(data_a[ib].signs[iqs / 2] >> (4 * (idx % 2)));
            const uint scale = data_a[ib].scales[iqs / 16];
            const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(sign << 1, sign)));
            const float db = d * (1 + 2 * ((scale >> (4 * (iqh & 1))) & 0xf));
            const uint32_t grid = iq3s_grid[qs | ((qh << (8 - (iqs % 8))) & 256)];
            const vec4 v = db * vec4(unpack8(grid));

            buf_a[buf_idx    ] = FLOAT_TYPE((sign & 1) != 0 ? -v.x : v.x);
            buf_a[buf_idx + 1] = FLOAT_TYPE((sign & 2) != 0 ? -v.y : v.y);
            buf_a[buf_idx + 2] = FLOAT_TYPE((sign & 4) != 0 ? -v.z : v.z);
            buf_a[buf_idx + 3] = FLOAT_TYPE((sign & 8) != 0 ? -v.w : v.w);
#elif defined(DATA_A_IQ4_XS)
            const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
            const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;

            const uint ib = idx / 128;                 // 2 values per idx
            const uint ib32 = (idx % 128) / 16;        // 0..7
            const uint iq = 16 * ib32 + 2 * (idx % 8);

            const uint sl = (data_a[ib].scales_l[ib32/2] >> (4 * (ib32 & 1))) & 0xF;
            const uint sh = ((data_a[ib].scales_h) >> (2 * ib32)) & 3;
            const uint qshift = (idx & 8) >> 1;
            u8vec2 qs = u8vec2(data_a[ib].qs[iq], data_a[ib].qs[iq + 1]);
            qs = (qs >> qshift) & uint8_t(0xF);

            const float d = float(data_a[ib].d);
            const vec2 v = d * float(int(sl | (sh << 4)) - 32) * vec2(kvalues_iq4nl[qs.x], kvalues_iq4nl[qs.y]);

            buf_a[buf_idx    ] = FLOAT_TYPE(v.x);
            buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
#elif defined(DATA_A_IQ4_NL)
            const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
            const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + 2 * loadr_a;

            const uint ib = idx / 8;
            const uint iqs = idx & 0x07;

            const FLOAT_TYPE d = FLOAT_TYPE(data_a_packed16[ib].d);
            const uint vui = uint(data_a_packed16[ib].qs[iqs]);

            buf_a[buf_idx     ] = FLOAT_TYPE(kvalues_iq4nl[vui & 0xF]) * d;
            buf_a[buf_idx + 1 ] = FLOAT_TYPE(kvalues_iq4nl[bitfieldExtract(vui, 8, 4)]) * d;
            buf_a[buf_idx + 16] = FLOAT_TYPE(kvalues_iq4nl[bitfieldExtract(vui, 4, 4)]) * d;
            buf_a[buf_idx + 17] = FLOAT_TYPE(kvalues_iq4nl[vui >> 12]) * d;
#elif defined(DATA_A_MXFP4)
            const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
            const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + 2 * loadr_a;

            const uint ib = idx / 8;
            const uint iqs = (idx & 0x07) * 2;

            const float d = e8m0_to_fp32(data_a[ib].e);
            const uint vui = uint(data_a[ib].qs[iqs]);
            const uint vui2 = uint(data_a[ib].qs[iqs+1]);

            buf_a[buf_idx     ] = FLOAT_TYPE(kvalues_mxfp4[vui & 0xF] * d);
            buf_a[buf_idx + 16] = FLOAT_TYPE(kvalues_mxfp4[vui >> 4] * d);
            buf_a[buf_idx + 1 ] = FLOAT_TYPE(kvalues_mxfp4[vui2 & 0xF] * d);
            buf_a[buf_idx + 17] = FLOAT_TYPE(kvalues_mxfp4[vui2 >> 4] * d);
#endif
            load_a_to_shmem(pos_a, loadr_a, loadc_a + l, ir * BM + loadc_a + l, block, end_k);
        }
        [[unroll]] for (uint l = 0; l < BN; l += loadstride_b) {
#if LOAD_VEC_B == 8
#ifdef MUL_MAT_ID
            const u16vec2 row_idx = row_ids[ic * BN + loadc_b + l];
            const uint idx = pos_b + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + loadr_b;
#if !defined(MUL_MAT_ID)
            load_b_to_shmem(pos_b, loadr_b, loadc_b + l, ic * BN + loadc_b + l, block, end_k);
#else
            const uint idx = pos_b + (loadc_b + l) * p.stride_b / LOAD_VEC_B + loadr_b;
#endif
            const uint buf_idx = (loadc_b + l) * SHMEM_STRIDE + loadr_b * LOAD_VEC_B;
            buf_b[buf_idx + 0] = FLOAT_TYPE(data_b[idx][0].x);
            buf_b[buf_idx + 1] = FLOAT_TYPE(data_b[idx][0].y);
            buf_b[buf_idx + 2] = FLOAT_TYPE(data_b[idx][0].z);
            buf_b[buf_idx + 3] = FLOAT_TYPE(data_b[idx][0].w);
            buf_b[buf_idx + 4] = FLOAT_TYPE(data_b[idx][1].x);
            buf_b[buf_idx + 5] = FLOAT_TYPE(data_b[idx][1].y);
            buf_b[buf_idx + 6] = FLOAT_TYPE(data_b[idx][1].z);
            buf_b[buf_idx + 7] = FLOAT_TYPE(data_b[idx][1].w);
#elif LOAD_VEC_B == 4
#ifdef MUL_MAT_ID
            const u16vec2 row_idx = row_ids[ic * BN + loadc_b + l];
            const uint idx = pos_b + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + loadr_b;
#else
            const uint idx = pos_b + (loadc_b + l) * p.stride_b / LOAD_VEC_B + loadr_b;
#endif
            const uint buf_idx = (loadc_b + l) * SHMEM_STRIDE + loadr_b * LOAD_VEC_B;
            buf_b[buf_idx + 0] = TO_FLOAT_TYPE(data_b[idx].x);
            buf_b[buf_idx + 1] = TO_FLOAT_TYPE(data_b[idx].y);
            buf_b[buf_idx + 2] = TO_FLOAT_TYPE(data_b[idx].z);
            buf_b[buf_idx + 3] = TO_FLOAT_TYPE(data_b[idx].w);
#elif !MUL_MAT_ID
            if (ic * BN + loadc_b + l < p.N && block + loadr_b < end_k) {
                buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = TO_FLOAT_TYPE(data_b[pos_b + (loadc_b + l) * p.stride_b + loadr_b]);
            } else {
                buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = FLOAT_TYPE(0.0f);
            }
#else
            const uint row_i = ic * BN + loadc_b + l;
            if (row_i < _ne1) {
                const u16vec2 row_idx = row_ids[row_i];
                buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = TO_FLOAT_TYPE(data_b[pos_b + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + loadr_b]);
            } else {
                buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = FLOAT_TYPE(0.0f);
            }
            load_b_to_shmem(pos_b, loadr_b, loadc_b + l, ic, _ne1, block, end_k);
#endif
        }

@@ -819,17 +343,17 @@ void main() {
        [[unroll]] for (uint i = 0; i < BK; i += TK) {
            [[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
                // Load from shared into cache
                coopMatLoad(cache_a, buf_a, (warp_r * WM + cm_row * TM) * SHMEM_STRIDE + i, SHMEM_STRIDE, gl_CooperativeMatrixLayoutRowMajor);
                coopMatLoad(cache_a, buf_a, (warp_r * WM + cm_row * TM) * SHMEM_STRIDE + i / 2, SHMEM_STRIDE, gl_CooperativeMatrixLayoutRowMajor);

                [[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
                    coopMatLoad(cache_b, buf_b, (warp_c * WN + cm_col * TN) * SHMEM_STRIDE + i, SHMEM_STRIDE, gl_CooperativeMatrixLayoutColumnMajor);
                    coopMatLoad(cache_b, buf_b, (warp_c * WN + cm_col * TN) * SHMEM_STRIDE + i / 2, SHMEM_STRIDE, gl_CooperativeMatrixLayoutColumnMajor);

                    sums[cm_col * cms_per_row + cm_row] = coopMatMulAdd(cache_a, cache_b, sums[cm_col * cms_per_row + cm_row]);
                }
            }
        }
#else
        [[unroll]] for (uint i = 0; i < BK; i++) {
        [[unroll]] for (uint i = 0; i < BK / 2; i++) {
            // Load from shared into cache
            [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
                [[unroll]] for (uint j = 0; j < TM; j++) {

@@ -845,7 +369,7 @@ void main() {
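                // Each shared-memory entry is a vec2 holding two k-values, so
                // the accumulate below performs two fused multiply-adds.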
                [[unroll]] for (uint cc = 0; cc < TN; cc++) {
                    [[unroll]] for (uint cr = 0; cr < TM; cr++) {
                        const uint sums_idx = (wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr;
                        sums[sums_idx] = fma(ACC_TYPE(cache_a[wsir * TM + cr]), ACC_TYPE(cache_b[cc]), sums[sums_idx]);
                        sums[sums_idx] = fma(ACC_TYPE(cache_a[wsir * TM + cr].x), ACC_TYPE(cache_b[cc].x), fma(ACC_TYPE(cache_a[wsir * TM + cr].y), ACC_TYPE(cache_b[cc].y), sums[sums_idx]));
                    }
                }
            }

@@ -856,6 +380,20 @@ void main() {
        barrier();
    }

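    // Optionally clamp accumulators so low-precision accumulation (e.g. an
    // fp16 ACC_TYPE) saturates at ACC_TYPE_MAX instead of overflowing to infinity.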
#if defined(ACC_TYPE_MAX)
#ifdef COOPMAT
    [[unroll]] for (uint j = 0; j < cms_per_row * cms_per_col; j++) {
        [[unroll]] for (uint i = 0; i < sums[j].length(); ++i) {
            sums[j][i] = clamp(sums[j][i], -ACC_TYPE_MAX, ACC_TYPE_MAX);
        }
    }
#else
    [[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN; i++) {
        sums[i] = clamp(sums[i], -ACC_TYPE_MAX, ACC_TYPE_MAX);
    }
#endif
#endif

    const uint dr = ir * BM + warp_r * WM;
    const uint dc = ic * BN + warp_c * WN;

@@ -873,9 +411,11 @@ void main() {
                const uint row_i = dc + cm_col * TN + col + store_c;
                if (row_i >= _ne1) break;

                const u16vec2 row_idx = row_ids[row_i];
                const u16vec2 row_idx = row_ids[row_i - ic * BN];

                data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]);
                if (dr + cm_row * TM + store_r < p.M) {
                    data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]);
                }
            }
        }
    }

@@ -921,11 +461,13 @@ void main() {
            const uint row_i = dc_warp + cc;
            if (row_i >= _ne1) break;

            const u16vec2 row_idx = row_ids[row_i];
            const u16vec2 row_idx = row_ids[row_i - ic * BN];
#endif // MUL_MAT_ID
            [[unroll]] for (uint cr = 0; cr < TM; cr++) {
#ifdef MUL_MAT_ID
                data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]);
                if (dr_warp + cr < p.M) {
                    data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]);
                }
#else
                if (dr_warp + cr < p.M && dc_warp + cc < p.N) {
                    data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]);

@@ -19,6 +19,7 @@
#endif

#include "types.comp"
#include "utils.comp"

layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;

@@ -92,14 +93,15 @@ layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
#ifdef MUL_MAT_ID
layout (binding = 3) readonly buffer IDS {int data_ids[];};

shared u16vec4 row_ids[4096];
shared u16vec4 row_ids[BN];

layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufB {
    B_TYPE b[];
};

uint _ne1;
shared uint _ne1_sh;
layout (constant_id = 5) const uint subgroup_size = 32;
shared uvec4 ballots_sh[BLOCK_SIZE / subgroup_size];

B_TYPE decodeFuncB(const in decodeBufB bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{

@@ -109,7 +111,7 @@ B_TYPE decodeFuncB(const in decodeBufB bl, const in uint blockCoords[2], const i
        return B_TYPE(0.0);
    }

    const u16vec4 row_idx = row_ids[row_i];
    const u16vec4 row_idx = row_ids[row_i & (BN - 1)];
    B_TYPE ret = data_b[row_idx.y * p.batch_stride_b + row_idx.x * p.stride_b + blockCoords[1]];

    return ret;

@@ -121,13 +123,74 @@ D_TYPE perElemOpD(const in uint32_t r, const in uint32_t c, const in D_TYPE elem
    uint dc = ic * BN + c;

    if (dr < p.M && dc < _ne1) {
        uint row_i = dc;
        uint row_i = c;
        const u16vec4 row_idx = row_ids[row_i];
        data_d[row_idx.y * p.batch_stride_d + row_idx.z * p.stride_d + dr] = elem;
    }
    return elem;
}

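// Same ballot-based compaction as the load_row_ids variant earlier in this
// commit, but entries are u16vec4: x holds fastmod(ii0, ne11) for addressing B,
// y holds ii1, and z keeps the raw ii0 used when storing D.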
void load_row_ids(uint expert_idx, bool nei0_is_pow2, uint ic) {
    _ne1 = 0;
    uint num_elements = p.nei1 * p.nei0;
    uint nei0shift = findLSB(p.nei0);

    uint ids[16];
    uint iter = 0;

    for (uint j = 0; j < num_elements; j += BLOCK_SIZE) {
        // prefetch up to 16 elements
        if (iter == 0) {
            [[unroll]] for (uint k = 0; k < 16; ++k) {
                uint i = j + gl_LocalInvocationIndex + k*BLOCK_SIZE;
                bool in_range = i < num_elements;
                uint ii1;
                if (nei0_is_pow2) {
                    ii1 = i >> nei0shift;
                } else {
                    ii1 = i / p.nei0;
                }
                uint ii0 = i - ii1 * p.nei0;
                ids[k] = in_range ? data_ids[ii1*p.nbi1 + ii0] : 0;
            }
        }
        uint i = j + gl_LocalInvocationIndex;
        bool in_range = i < num_elements;
        uint ii1;
        if (nei0_is_pow2) {
            ii1 = i >> nei0shift;
        } else {
            ii1 = i / p.nei0;
        }
        uint ii0 = i - ii1 * p.nei0;
        uint id = ids[iter++];
        uvec4 ballot = subgroupBallot(in_range && id == expert_idx);

        ballots_sh[gl_SubgroupID] = ballot;
        barrier();

        uint subgroup_base = 0;
        uint total = 0;
        for (uint k = 0; k < gl_NumSubgroups; ++k) {
            if (k == gl_SubgroupID) {
                subgroup_base = total;
            }
            total += subgroupBallotBitCount(ballots_sh[k]);
        }
        barrier();

        uint idx = subgroup_base + subgroupBallotExclusiveBitCount(ballot);
        if (in_range && id == expert_idx && _ne1 + idx >= ic * BN && _ne1 + idx < (ic + 1) * BN) {
            row_ids[_ne1 + idx - ic * BN] = u16vec4(fastmod(ii0, p.ne11), ii1, ii0, 0);
        }
        _ne1 += total;
        iter &= 15;
        if (_ne1 >= (ic + 1) * BN) {
            break;
        }
    }
    barrier();
}
#endif

void main() {

@@ -157,45 +220,12 @@ void main() {
    const uint ic = gl_WorkGroupID.y;

#ifdef MUL_MAT_ID
    // Spread the search across all elements in the first subgroup
    if (gl_SubgroupID == 0) {
        _ne1 = 0;
        uint num_elements = p.nei1 * p.nei0;

        uint ids[16];
        uint iter = 0;

        for (uint j = 0; j < num_elements; j += gl_SubgroupSize) {
            // prefetch up to 16 elements
            if (iter == 0) {
                [[unroll]] for (uint k = 0; k < 16; ++k) {
                    uint i = j + gl_SubgroupInvocationID + k*gl_SubgroupSize;
                    bool in_range = i < num_elements;
                    uint ii1 = i / p.nei0;
                    uint ii0 = i % p.nei0;
                    ids[k] = in_range ? data_ids[ii1*p.nbi1 + ii0] : 0;
                }
            }
            uint i = j + gl_SubgroupInvocationID;
            bool in_range = i < num_elements;
            uint ii1 = i / p.nei0;
            uint ii0 = i % p.nei0;
            uint id = ids[iter++];
            uvec4 ballot = subgroupBallot(in_range && id == expert_idx);
            uint idx = subgroupBallotExclusiveBitCount(ballot);
            if (in_range && id == expert_idx) {
                row_ids[_ne1 + idx] = u16vec4(ii0 % p.ne11, ii1, ii0, 0);
            }
            _ne1 += subgroupBallotBitCount(ballot);
            iter &= 15;
        }
        _ne1_sh = _ne1;
    if (bitCount(p.nei0) == 1) {
        load_row_ids(expert_idx, true, ic);
    } else {
        load_row_ids(expert_idx, false, ic);
    }

    barrier();

    _ne1 = _ne1_sh;

    // Workgroup has no work
    if (ic * BN >= _ne1) return;
#endif

@@ -235,7 +265,6 @@ void main() {
    tensorLayoutNV<2> tensorLayoutB = createTensorLayoutNV(2);
    tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutBClamp = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV);
    tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutD = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV);
    tensorLayoutD = setTensorLayoutStrideNV(tensorLayoutD, p.stride_d, 1);

#if QUANT_K > 1
    tensorLayoutA = setTensorLayoutBlockSizeNV(tensorLayoutA, 1, QUANT_K);

@@ -251,6 +280,8 @@ void main() {
    tensorLayoutAClamp = setTensorLayoutDimensionNV(tensorLayoutAClamp, p.M, end_k);
    tensorLayoutBClamp = setTensorLayoutDimensionNV(tensorLayoutBClamp, p.padded_N, end_k);

    tensorLayoutD = setTensorLayoutStrideNV(tensorLayoutD, p.stride_d, 1);

    tensorViewNV<2, false, 1, 0> tensorViewTranspose = createTensorViewNV(2, false, 1, 0);

#if !defined(MUL_MAT_ID)

@@ -319,6 +350,10 @@ void main() {
            sum = coopMatMulAdd(mat_a, mat_b, sum);
            block_k += BK;
        }
#if defined(ACC_TYPE_MAX)
        [[unroll]] for (uint i = 0; i < sum.length(); ++i) { sum[i] = clamp(sum[i], -ACC_TYPE_MAX, ACC_TYPE_MAX); }
#endif

        coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BNover4, gl_MatrixUseAccumulator> mat_d = coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BNover4, gl_MatrixUseAccumulator>(sum);

        coopMatStoreTensorNV(mat_d, data_d, pos_d, sliceTensorLayoutNV(tensorLayoutD, ic * BN, BNover4, ir * BM, BM), tensorViewTranspose);

@@ -358,6 +393,10 @@ void main() {
            sum = coopMatMulAdd(mat_a, mat_b, sum);
            block_k += BK;
        }
#if defined(ACC_TYPE_MAX)
        [[unroll]] for (uint i = 0; i < sum.length(); ++i) { sum[i] = clamp(sum[i], -ACC_TYPE_MAX, ACC_TYPE_MAX); }
#endif

        coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BNover2, gl_MatrixUseAccumulator> mat_d = coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BNover2, gl_MatrixUseAccumulator>(sum);

        coopMatStoreTensorNV(mat_d, data_d, pos_d, sliceTensorLayoutNV(tensorLayoutD, ic * BN, BNover2, ir * BM, BM), tensorViewTranspose);

@@ -398,6 +437,10 @@ void main() {
            sum = coopMatMulAdd(mat_a, mat_b, sum);
            block_k += BK;
        }
#if defined(ACC_TYPE_MAX)
        [[unroll]] for (uint i = 0; i < sum.length(); ++i) { sum[i] = clamp(sum[i], -ACC_TYPE_MAX, ACC_TYPE_MAX); }
#endif

        coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator> mat_d = coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator>(sum);

        coopMatStoreTensorNV(mat_d, data_d, pos_d, sliceTensorLayoutNV(tensorLayoutD, ic * BN, BN, ir * BM, BM), tensorViewTranspose);

@@ -414,18 +457,111 @@ void main() {
|
||||
tensorLayoutBClamp = setTensorLayoutStrideNV(tensorLayoutBClamp, stride_b, 1);
|
||||
|
||||
coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator> sum;
|
||||
sum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator>(0.0);
|
||||
|
||||
uint k_iters = (end_k - start_k + BK - 1) / BK;
|
||||
|
||||
fetch_scales(ir * BM, pos_a, stride_a, start_k, tid, false);
|
||||
store_scales(tid);
|
||||
|
||||
#ifdef MUL_MAT_ID
|
||||
if (enable_smaller_matrices && ic * BN + BNover4 >= _ne1) {
|
||||
coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BNover4, gl_MatrixUseAccumulator> sum;
|
||||
sum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BNover4, gl_MatrixUseAccumulator>(0.0);
|
||||
|
||||
[[dont_unroll]]
|
||||
for (uint block_k = start_k, i = 0; i < k_iters; block_k += BK, ++i) {
|
||||
|
||||
if ((block_k % QUANT_K) == 0) {
|
||||
store_scales(tid);
|
||||
}
|
||||
if (block_k + BK < end_k && ((block_k + BK) % QUANT_K) == 0) {
|
||||
fetch_scales(ir * BM, pos_a, stride_a, block_k + BK, tid, false);
|
||||
}
|
||||
|
||||
if ((ir + 1) * BM <= p.M && block_k + BK <= end_k) {
|
||||
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
|
||||
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BNover4, gl_MatrixUseB> mat_b;
|
||||
|
||||
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
|
||||
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover4, block_k, BK), tensorViewTranspose, decodeFuncB);
|
||||
|
||||
sum = coopMatMulAdd(mat_a, mat_b, sum);
|
||||
} else {
|
||||
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
|
||||
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BNover4, gl_MatrixUseB> mat_b;
|
||||
|
||||
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA);
|
||||
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover4, block_k, BK), tensorViewTranspose, decodeFuncB);
|
||||
|
||||
sum = coopMatMulAdd(mat_a, mat_b, sum);
|
||||
}
|
||||
}
|
||||
#if defined(ACC_TYPE_MAX)
|
||||
[[unroll]] for (uint i = 0; i < sum.length(); ++i) { sum[i] = clamp(sum[i], -ACC_TYPE_MAX, ACC_TYPE_MAX); }
|
||||
#endif
|
||||
|
||||
// Convert from ACC_TYPE to D_TYPE
|
||||
coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BNover4, gl_MatrixUseAccumulator> mat_d;
|
||||
mat_d = coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BNover4, gl_MatrixUseAccumulator>(sum);
|
||||
|
||||
// Call callback to store each element, remapping row through shared memory
|
||||
coopMatPerElementNV(mat_d, mat_d, perElemOpD, ir, ic);
|
||||
return;
|
||||
}
|
||||
if (enable_smaller_matrices && ic * BN + BNover2 >= _ne1) {
|
||||
coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BNover2, gl_MatrixUseAccumulator> sum;
|
||||
sum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BNover2, gl_MatrixUseAccumulator>(0.0);
|
||||
|
||||
[[dont_unroll]]
|
||||
for (uint block_k = start_k, i = 0; i < k_iters; block_k += BK, ++i) {
|
||||
|
||||
if ((block_k % QUANT_K) == 0) {
|
||||
store_scales(tid);
|
||||
}
|
||||
if (block_k + BK < end_k && ((block_k + BK) % QUANT_K) == 0) {
|
||||
fetch_scales(ir * BM, pos_a, stride_a, block_k + BK, tid, false);
|
||||
}
|
||||
|
||||
if ((ir + 1) * BM <= p.M && block_k + BK <= end_k) {
|
||||
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
|
||||
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BNover2, gl_MatrixUseB> mat_b;
|
||||
|
||||
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
|
||||
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover2, block_k, BK), tensorViewTranspose, decodeFuncB);
|
||||
|
||||
sum = coopMatMulAdd(mat_a, mat_b, sum);
|
||||
} else {
|
||||
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
|
||||
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BNover2, gl_MatrixUseB> mat_b;
|
||||
|
||||
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA);
|
||||
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover2, block_k, BK), tensorViewTranspose, decodeFuncB);
|
||||
|
||||
sum = coopMatMulAdd(mat_a, mat_b, sum);
|
||||
}
|
||||
}
|
||||
#if defined(ACC_TYPE_MAX)
|
||||
[[unroll]] for (uint i = 0; i < sum.length(); ++i) { sum[i] = clamp(sum[i], -ACC_TYPE_MAX, ACC_TYPE_MAX); }
|
||||
#endif
|
||||
|
||||
// Convert from ACC_TYPE to D_TYPE
|
||||
coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BNover2, gl_MatrixUseAccumulator> mat_d;
|
||||
mat_d = coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BNover2, gl_MatrixUseAccumulator>(sum);
|
||||
|
||||
// Call callback to store each element, remapping row through shared memory
|
||||
coopMatPerElementNV(mat_d, mat_d, perElemOpD, ir, ic);
|
||||
return;
|
||||
}
|
||||
#endif
|
||||
coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator> sum;
|
||||
sum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator>(0.0);
|
||||
|
||||
[[dont_unroll]]
|
||||
for (uint block_k = start_k, i = 0; i < k_iters; block_k += BK, ++i) {
|
||||
|
||||
store_scales(tid);
|
||||
if (block_k + BK < end_k) {
|
||||
if ((block_k % QUANT_K) == 0) {
|
||||
store_scales(tid);
|
||||
}
|
||||
if (block_k + BK < end_k && ((block_k + BK) % QUANT_K) == 0) {
|
||||
fetch_scales(ir * BM, pos_a, stride_a, block_k + BK, tid, false);
|
||||
}
|
||||
|
||||
|
|
@ -455,6 +591,9 @@ void main() {
|
|||
sum = coopMatMulAdd(mat_a, mat_b, sum);
|
||||
}
|
||||
}
|
||||
#if defined(ACC_TYPE_MAX)
|
||||
[[unroll]] for (uint i = 0; i < sum.length(); ++i) { sum[i] = clamp(sum[i], -ACC_TYPE_MAX, ACC_TYPE_MAX); }
|
||||
#endif
|
||||
|
||||
// Convert from ACC_TYPE to D_TYPE
|
||||
coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator> mat_d;
|
||||
|
|
|
|||
|
|

@@ -0,0 +1,556 @@
void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uint idx_m, const uint block, const uint end_k) {
#if defined(DATA_A_F32) || defined(DATA_A_F16)
#if LOAD_VEC_A == 8
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
FLOAT_TYPE_VEC8 aa = FLOAT_TYPE_VEC8(data_a[idx]);
buf_a[buf_idx ] = aa[0].xy;
buf_a[buf_idx + 1] = aa[0].zw;
buf_a[buf_idx + 2] = aa[1].xy;
buf_a[buf_idx + 3] = aa[1].zw;
#elif LOAD_VEC_A == 4
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
FLOAT_TYPE_VEC4 aa = FLOAT_TYPE_VEC4(data_a[idx]);
buf_a[buf_idx ] = aa.xy;
buf_a[buf_idx + 1] = aa.zw;
#else // LOAD_VEC_BATCH_A == 2
const uint idx = pos_a + col * p.stride_a + row * 2;
const uint buf_idx = col * SHMEM_STRIDE + row;
if (idx_m < p.M && block + row * 2 + 1 < end_k) {
buf_a[buf_idx] = FLOAT_TYPE_VEC2(data_a[idx],
data_a[idx + 1]);
} else if (idx_m < p.M && block + row * 2 < end_k) {
buf_a[buf_idx] = FLOAT_TYPE_VEC2(data_a[idx], 0.0f);
} else {
buf_a[buf_idx] = FLOAT_TYPE_VEC2(0.0f);
}
#endif
#elif defined(DATA_A_BF16)
#if LOAD_VEC_A == 4
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
FLOAT_TYPE_VEC4 aa = FLOAT_TYPE_VEC4(TO_FLOAT_TYPE(data_a[idx]));
buf_a[buf_idx ] = aa.xy;
buf_a[buf_idx + 1] = aa.zw;
#else // LOAD_VEC_BATCH_A == 2
const uint idx = pos_a + col * p.stride_a + row * 2;
const uint buf_idx = col * SHMEM_STRIDE + row;
if (idx_m < p.M && block + row * 2 + 1 < end_k) {
buf_a[buf_idx] = FLOAT_TYPE_VEC2(TO_FLOAT_TYPE(data_a[idx]),
TO_FLOAT_TYPE(data_a[idx + 1]));
} else if (idx_m < p.M && block + row * 2 < end_k) {
buf_a[buf_idx] = FLOAT_TYPE_VEC2(TO_FLOAT_TYPE(data_a[idx]), 0.0f);
} else {
buf_a[buf_idx] = FLOAT_TYPE_VEC2(0.0f);
}
#endif
#elif defined(DATA_A_Q4_0)
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
const uint buf_idx = col * SHMEM_STRIDE + 2 * row;

const uint ib = idx / 4;
const uint iqs = idx & 0x03;

const float d = float(data_a_packed16[ib].d);
const uint vui = uint(data_a_packed16[ib].qs[2*iqs]) | (uint(data_a_packed16[ib].qs[2*iqs + 1]) << 16);
const vec4 v0 = (vec4(unpack8(vui & 0x0F0F0F0F)) - 8.0f) * d;
const vec4 v1 = (vec4(unpack8((vui >> 4) & 0x0F0F0F0F)) - 8.0f) * d;

buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v0.xy);
buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2(v0.zw);
buf_a[buf_idx + 8] = FLOAT_TYPE_VEC2(v1.xy);
buf_a[buf_idx + 9] = FLOAT_TYPE_VEC2(v1.zw);
#elif defined(DATA_A_Q4_1)
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
const uint buf_idx = col * SHMEM_STRIDE + 2 * row;

const uint ib = idx / 4;
const uint iqs = idx & 0x03;

const float d = float(data_a_packed16[ib].d);
const float m = float(data_a_packed16[ib].m);
const uint vui = uint(data_a_packed16[ib].qs[2*iqs]) | (uint(data_a_packed16[ib].qs[2*iqs + 1]) << 16);
const vec4 v0 = vec4(unpack8(vui & 0x0F0F0F0F)) * d + m;
const vec4 v1 = vec4(unpack8((vui >> 4) & 0x0F0F0F0F)) * d + m;

buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v0.xy);
buf_a[buf_idx + 1 ] = FLOAT_TYPE_VEC2(v0.zw);
buf_a[buf_idx + 8 ] = FLOAT_TYPE_VEC2(v1.xy);
buf_a[buf_idx + 9 ] = FLOAT_TYPE_VEC2(v1.zw);
#elif defined(DATA_A_Q5_0)
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
const uint buf_idx = col * SHMEM_STRIDE + row;

const uint ib = idx / 8;
const uint iqs = idx & 0x07;

const float d = float(data_a_packed16[ib].d);
const uint uint_qh = uint(data_a_packed16[ib].qh[1]) << 16 | uint(data_a_packed16[ib].qh[0]);
const ivec2 qh0 = ivec2(((uint_qh >> 2*iqs) << 4) & 0x10, (uint_qh >> (2*iqs + 12)) & 0x10);
const ivec2 qh1 = ivec2(((uint_qh >> (2*iqs + 1)) << 4) & 0x10, (uint_qh >> (2*iqs + 13)) & 0x10);

const uint vui = uint(data_a_packed16[ib].qs[iqs]);
const vec4 v = (vec4((vui & 0xF) | qh0.x, ((vui >> 4) & 0xF) | qh0.y, ((vui >> 8) & 0xF) | qh1.x, (vui >> 12) | qh1.y) - 16.0f) * d;

buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v.xz);
buf_a[buf_idx + 8] = FLOAT_TYPE_VEC2(v.yw);
#elif defined(DATA_A_Q5_1)
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
const uint buf_idx = col * SHMEM_STRIDE + row;

const uint ib = idx / 8;
const uint iqs = idx & 0x07;

const float d = float(data_a_packed16[ib].d);
const float m = float(data_a_packed16[ib].m);
const uint uint_qh = data_a_packed16[ib].qh;
const ivec2 qh0 = ivec2(((uint_qh >> 2*iqs) << 4) & 0x10, (uint_qh >> (2*iqs + 12)) & 0x10);
const ivec2 qh1 = ivec2(((uint_qh >> (2*iqs + 1)) << 4) & 0x10, (uint_qh >> (2*iqs + 13)) & 0x10);

const uint vui = uint(data_a_packed16[ib].qs[iqs]);
const vec4 v = vec4((vui & 0xF) | qh0.x, ((vui >> 4) & 0xF) | qh0.y, ((vui >> 8) & 0xF) | qh1.x, (vui >> 12) | qh1.y) * d + m;

buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v.xz);
buf_a[buf_idx + 8] = FLOAT_TYPE_VEC2(v.yw);
#elif defined(DATA_A_Q8_0)
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;

const uint ib = idx / 8;
const uint iqs = idx & 0x07;

const float d = float(data_a_packed16[ib].d);
const i8vec2 v0 = unpack8(int32_t(data_a_packed16[ib].qs[2*iqs])).xy; // vec4 used due to #12147
const i8vec2 v1 = unpack8(int32_t(data_a_packed16[ib].qs[2*iqs + 1])).xy;
const vec4 v = vec4(v0.x, v0.y, v1.x, v1.y) * d;

buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v.xy);
buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2(v.zw);
#elif defined(DATA_A_Q2_K)
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;

const uint ib = idx / 128; // 2 values per idx
const uint iqs = idx % 128; // 0..127

const uint qsi = (iqs / 64) * 32 + (iqs % 16) * 2; // 0,2,4..30
const uint scalesi = iqs / 8; // 0..15
const uint qsshift = ((iqs % 64) / 16) * 2; // 0,2,4,6

const uvec2 qs = uvec2(data_a[ib].qs[qsi], data_a[ib].qs[qsi + 1]);
const uint scales = data_a[ib].scales[scalesi];
const vec2 d = vec2(data_a[ib].d);

const vec2 v = d.x * float(scales & 0xF) * vec2((qs >> qsshift) & 3) - d.y * float(scales >> 4);

buf_a[buf_idx] = FLOAT_TYPE_VEC2(v.xy);
#elif defined(DATA_A_Q3_K)
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;

const uint ib = idx / 128; // 2 values per idx
const uint iqs = idx % 128; // 0..127

const uint n = iqs / 64; // 0,1
const uint qsi = n * 32 + (iqs % 16) * 2; // 0,2,4..62
const uint hmi = (iqs % 16) * 2; // 0,2,4..30
const uint j = (iqs % 64) / 4; // 0..3
const uint is = iqs / 8; // 0..15
const uint halfsplit = ((iqs % 64) / 16); // 0,1,2,3
const uint qsshift = halfsplit * 2; // 0,2,4,6
const uint m = 1 << (4 * n + halfsplit); // 1,2,4,8,16,32,64,128

const int8_t us = int8_t(((data_a[ib].scales[is % 8] >> (4 * int(is / 8))) & 0xF)
| (((data_a[ib].scales[8 + (is % 4)] >> (2 * int(is / 4))) & 3) << 4));
const float dl = float(data_a[ib].d) * float(us - 32);

buf_a[buf_idx] = FLOAT_TYPE_VEC2(dl * float(int8_t((data_a[ib].qs[qsi ] >> qsshift) & 3) - (((data_a[ib].hmask[hmi ] & m) != 0) ? 0 : 4)),
dl * float(int8_t((data_a[ib].qs[qsi + 1] >> qsshift) & 3) - (((data_a[ib].hmask[hmi + 1] & m) != 0) ? 0 : 4)));
#elif defined(DATA_A_Q4_K)
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;

const uint ib = idx / 128; // 2 values per idx
const uint iqs = idx % 128; // 0..127

const uint n = iqs / 32; // 0,1,2,3
const uint b = (iqs % 32) / 16; // 0,1
const uint is = 2 * n + b; // 0..7
const uint qsi = n * 32 + (iqs % 16) * 2; // 0,2,4..126

const vec2 loadd = vec2(data_a[ib].d);

const uint scidx0 = (is < 4) ? is : (is + 4);
const uint scidx1 = (is < 4) ? is : (is - 4);
const uint scidxmask1 = (is < 4) ? 0x30 : 0xC0;
const uint scidxshift1 = (is < 4) ? 0 : 2;
const uint mbidx0 = is + 4;
const uint mbidx1 = (is < 4) ? is + 4 : is;
const uint mbidxmask0 = (is < 4) ? 0xF : 0xF0;
const uint mbidxshift0 = (is < 4) ? 0 : 4;
const uint mbidxmask1 = (is < 4) ? 0x30 : 0xC0;
const uint mbidxshift1 = (is < 4) ? 0 : 2;

const uint8_t sc = uint8_t((data_a[ib].scales[scidx0] & 0xF) | ((data_a[ib].scales[scidx1] & scidxmask1) >> scidxshift1));
const uint8_t mbyte = uint8_t((data_a[ib].scales[mbidx0] & mbidxmask0) >> mbidxshift0 | ((data_a[ib].scales[mbidx1] & mbidxmask1) >> mbidxshift1));

const float d = loadd.x * sc;
const float m = -loadd.y * mbyte;

buf_a[buf_idx] = FLOAT_TYPE_VEC2(fma(d, float((data_a[ib].qs[qsi ] >> (b * 4)) & 0xF), m),
fma(d, float((data_a[ib].qs[qsi + 1] >> (b * 4)) & 0xF), m));
#elif defined(DATA_A_Q5_K)
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;

const uint ib = idx / 128; // 2 values per idx
const uint iqs = idx % 128; // 0..127

const uint n = iqs / 32; // 0,1,2,3
const uint b = (iqs % 32) / 16; // 0,1
const uint is = 2 * n + b; // 0..7
const uint qsi = n * 32 + (iqs % 16) * 2; // 0,2,4..126
const uint qhi = (iqs % 16) * 2; // 0,2,4..30

const uint8_t hm = uint8_t(1 << (iqs / 16));

const vec2 loadd = vec2(data_a[ib].d);

const uint scidx0 = (is < 4) ? is : (is + 4);
const uint scidx1 = (is < 4) ? is : (is - 4);
const uint scidxmask1 = (is < 4) ? 0x30 : 0xC0;
const uint scidxshift1 = (is < 4) ? 0 : 2;
const uint mbidx0 = is + 4;
const uint mbidx1 = (is < 4) ? is + 4 : is;
const uint mbidxmask0 = (is < 4) ? 0xF : 0xF0;
const uint mbidxshift0 = (is < 4) ? 0 : 4;
const uint mbidxmask1 = (is < 4) ? 0x30 : 0xC0;
const uint mbidxshift1 = (is < 4) ? 0 : 2;

const uint8_t sc = uint8_t((data_a[ib].scales[scidx0] & 0xF) | ((data_a[ib].scales[scidx1] & scidxmask1) >> scidxshift1));
const uint8_t mbyte = uint8_t(((data_a[ib].scales[mbidx0] & mbidxmask0) >> mbidxshift0) | ((data_a[ib].scales[mbidx1] & mbidxmask1) >> mbidxshift1));

const float d = loadd.x * sc;
const float m = -loadd.y * mbyte;

buf_a[buf_idx] = FLOAT_TYPE_VEC2(fma(d, float((data_a[ib].qs[qsi ] >> (b * 4)) & 0xF) + float((data_a[ib].qh[qhi ] & hm) != 0 ? 16 : 0), m),
fma(d, float((data_a[ib].qs[qsi + 1] >> (b * 4)) & 0xF) + float((data_a[ib].qh[qhi + 1] & hm) != 0 ? 16 : 0), m));
#elif defined(DATA_A_Q6_K)
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;

const uint ib = idx / 128; // 2 values per idx
const uint iqs = idx % 128; // 0..127

const uint n = iqs / 64; // 0,1
const uint b = (iqs % 64) / 32; // 0,1
const uint is_b = (iqs % 16) / 8; // 0,1
const uint qhshift = ((iqs % 64) / 16) * 2; // 0,2,4,6
const uint is = 8 * n + qhshift + is_b; // 0..15
const uint qsi = n * 64 + (iqs % 32) * 2; // 0,2,4..126
const uint qhi = n * 32 + (iqs % 16) * 2; // 0,2,4..62

const float dscale = float(data_a[ib].d) * float(data_a[ib].scales[is]);

buf_a[buf_idx] = FLOAT_TYPE_VEC2(dscale * float(int8_t(((data_a[ib].ql[qsi ] >> (b * 4)) & 0xF) | (((data_a[ib].qh[qhi ] >> qhshift) & 3) << 4)) - 32),
dscale * float(int8_t(((data_a[ib].ql[qsi + 1] >> (b * 4)) & 0xF) | (((data_a[ib].qh[qhi + 1] >> qhshift) & 3) << 4)) - 32));
#elif defined(DATA_A_IQ1_S)
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;

const uint ib = idx / 32; // 8 values per idx
const uint ib32 = (idx % 32) / 4; // 0..7
const uint ib8 = idx % 32;

const float d = float(data_a[ib].d);
const uint qh = data_a[ib].qh[ib32];
const uint qs = data_a[ib].qs[ib8];
const float dl = d * (2 * bitfieldExtract(qh, 12, 3) + 1);
const float delta = ((qh & 0x8000) != 0) ? -IQ1S_DELTA : IQ1S_DELTA;
const int16_t grid = int16_t(iq1s_grid[qs | (bitfieldExtract(qh, 3 * int(ib8 & 3), 3) << 8)]);

[[unroll]] for (int k = 0; k < 4; ++k) {
buf_a[buf_idx + k] = FLOAT_TYPE_VEC2(dl * (bitfieldExtract(grid, 4 * k , 2) + delta),
dl * (bitfieldExtract(grid, 4 * k + 2, 2) + delta));
}
#elif defined(DATA_A_IQ1_M)
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;

const uint ib = idx / 32; // 8 values per idx
const uint ib8 = idx % 32;
const uint ib16 = ib8 / 2;

const uint16_t[4] scales = data_a[ib].scales;
const u16vec4 s = u16vec4(scales[0], scales[1], scales[2], scales[3]) >> 12;
const float d = float(unpackHalf2x16(s.x | (s.y << 4) | (s.z << 8) | (s.w << 12)).x);
const uint sc = scales[ib8 / 8];
const uint qs = data_a[ib].qs[ib8];
const uint qh = data_a[ib].qh[ib16] >> (4 * (ib8 & 1));
const float dl = d * (2 * bitfieldExtract(sc, 3 * int(ib16 & 3), 3) + 1);
const float delta = ((qh & 8) != 0) ? -IQ1M_DELTA : IQ1M_DELTA;
const int16_t grid = int16_t(iq1s_grid[qs | ((qh & 7) << 8)]);

[[unroll]] for (int k = 0; k < 4; ++k) {
buf_a[buf_idx + k] = FLOAT_TYPE_VEC2(dl * (bitfieldExtract(grid, 4 * k , 2) + delta),
dl * (bitfieldExtract(grid, 4 * k + 2, 2) + delta));
}
#elif defined(DATA_A_IQ2_XXS)
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;

const uint ib = idx / 32; // 8 values per idx
const uint ib32 = (idx % 32) / 4; // 0..7
const uint ib8 = idx % 4;

const float d = float(data_a[ib].d);
const uint qs = data_a[ib].qs[8 * ib32 + ib8];
const uint signs = pack32(u8vec4(
data_a[ib].qs[8*ib32 + 4],
data_a[ib].qs[8*ib32 + 5],
data_a[ib].qs[8*ib32 + 6],
data_a[ib].qs[8*ib32 + 7]
));
const FLOAT_TYPE db = FLOAT_TYPE(d * 0.25 * (0.5 + (signs >> 28)));
const uint32_t sign7 = bitfieldExtract(signs, 7 * int(ib8), 7);
const uint sign = sign7 | (bitCount(sign7) << 7);
const uvec2 grid = iq2xxs_grid[qs];
const vec4 grid0 = vec4(unpack8(grid.x));
const vec4 grid1 = vec4(unpack8(grid.y));

buf_a[buf_idx ] = db * FLOAT_TYPE_VEC2((sign & 1) != 0 ? -grid0.x : grid0.x,
(sign & 2) != 0 ? -grid0.y : grid0.y);
buf_a[buf_idx + 1] = db * FLOAT_TYPE_VEC2((sign & 4) != 0 ? -grid0.z : grid0.z,
(sign & 8) != 0 ? -grid0.w : grid0.w);
buf_a[buf_idx + 2] = db * FLOAT_TYPE_VEC2((sign & 16) != 0 ? -grid1.x : grid1.x,
(sign & 32) != 0 ? -grid1.y : grid1.y);
buf_a[buf_idx + 3] = db * FLOAT_TYPE_VEC2((sign & 64) != 0 ? -grid1.z : grid1.z,
(sign & 128) != 0 ? -grid1.w : grid1.w);
#elif defined(DATA_A_IQ2_XS)
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;

const uint ib = idx / 32; // 8 values per idx
const uint ib32 = (idx % 32) / 4; // 0..7
const uint ib8 = idx % 4; // 0..3

const float d = float(data_a[ib].d);
const uint scale = (data_a[ib].scales[ib32] >> (2 * (ib8 & 2))) & 0xf;
const FLOAT_TYPE db = FLOAT_TYPE(d * 0.25 * (0.5 + scale));
const uint qs = data_a[ib].qs[4 * ib32 + ib8];
const uint sign7 = qs >> 9;
const uint sign = sign7 | (bitCount(sign7) << 7);
const uvec2 grid = iq2xs_grid[qs & 511];
const vec4 grid0 = vec4(unpack8(grid.x));
const vec4 grid1 = vec4(unpack8(grid.y));

buf_a[buf_idx ] = db * FLOAT_TYPE_VEC2((sign & 1) != 0 ? -grid0.x : grid0.x,
(sign & 2) != 0 ? -grid0.y : grid0.y);
buf_a[buf_idx + 1] = db * FLOAT_TYPE_VEC2((sign & 4) != 0 ? -grid0.z : grid0.z,
(sign & 8) != 0 ? -grid0.w : grid0.w);
buf_a[buf_idx + 2] = db * FLOAT_TYPE_VEC2((sign & 16) != 0 ? -grid1.x : grid1.x,
(sign & 32) != 0 ? -grid1.y : grid1.y);
buf_a[buf_idx + 3] = db * FLOAT_TYPE_VEC2((sign & 64) != 0 ? -grid1.z : grid1.z,
(sign & 128) != 0 ? -grid1.w : grid1.w);
#elif defined(DATA_A_IQ2_S)
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;

const uint ib = idx / 32; // 8 values per idx
const uint ib8 = idx % 32; // 0..31
const uint ib32 = ib8 / 4; // 0..7

const uint scale = (data_a[ib].scales[ib32] >> (2 * (ib8 & 2))) & 0xf;
const uint qs = data_a[ib].qs[ib8];
const uint qh = data_a[ib].qh[ib32];
const uint qhshift = 2 * (ib8 % 4);
const uint sign = data_a[ib].qs[QUANT_K / 8 + ib8];

const float d = float(data_a[ib].d);
const FLOAT_TYPE db = FLOAT_TYPE(d * 0.25 * (0.5 + scale));
const uvec2 grid = iq2s_grid[qs | ((qh << (8 - qhshift)) & 0x300)];
const vec4 grid0 = vec4(unpack8(grid.x));
const vec4 grid1 = vec4(unpack8(grid.y));

buf_a[buf_idx ] = db * FLOAT_TYPE_VEC2((sign & 1) != 0 ? -grid0.x : grid0.x,
(sign & 2) != 0 ? -grid0.y : grid0.y);
buf_a[buf_idx + 1] = db * FLOAT_TYPE_VEC2((sign & 4) != 0 ? -grid0.z : grid0.z,
(sign & 8) != 0 ? -grid0.w : grid0.w);
buf_a[buf_idx + 2] = db * FLOAT_TYPE_VEC2((sign & 16) != 0 ? -grid1.x : grid1.x,
(sign & 32) != 0 ? -grid1.y : grid1.y);
buf_a[buf_idx + 3] = db * FLOAT_TYPE_VEC2((sign & 64) != 0 ? -grid1.z : grid1.z,
(sign & 128) != 0 ? -grid1.w : grid1.w);
#elif defined(DATA_A_IQ3_XXS)
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;

const uint ib = idx / 64; // 4 values per idx
const uint iqs = idx % 64; // 0..63
const uint is = QUANT_K / 4 + 4 * (iqs / 8); // 8 values

const float d = float(data_a[ib].d);
const uint qs = data_a[ib].qs[iqs];
const uint signs = pack32(u8vec4(
data_a[ib].qs[is+0],
data_a[ib].qs[is+1],
data_a[ib].qs[is+2],
data_a[ib].qs[is+3]
));
const float db = d * 0.5 * (0.5 + (signs >> 28));
const uint32_t sign7 = bitfieldExtract(signs, 7 * (int(iqs / 2) % 4), 7);
const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (4 * (idx % 2));
const uint grid = iq3xxs_grid[qs];
const vec4 v = db * vec4(unpack8(grid));

buf_a[buf_idx ] = FLOAT_TYPE_VEC2((sign & 1) != 0 ? -v.x : v.x,
(sign & 2) != 0 ? -v.y : v.y);
buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2((sign & 4) != 0 ? -v.z : v.z,
(sign & 8) != 0 ? -v.w : v.w);
#elif defined(DATA_A_IQ3_S)
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;

const uint ib = idx / 64; // 4 values per idx
const uint iqs = idx % 64; // 0..63
const uint iqh = iqs / 8;

const float d = float(data_a[ib].d);
const uint qs = data_a[ib].qs[iqs];
const uint qh = data_a[ib].qh[iqh];
const int8_t sign = int8_t(data_a[ib].signs[iqs / 2] >> (4 * (idx % 2)));
const uint scale = data_a[ib].scales[iqs / 16];
const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(sign << 1, sign)));
const float db = d * (1 + 2 * ((scale >> (4 * (iqh & 1))) & 0xf));
const uint32_t grid = iq3s_grid[qs | ((qh << (8 - (iqs % 8))) & 256)];
const vec4 v = db * vec4(unpack8(grid));

buf_a[buf_idx ] = FLOAT_TYPE_VEC2((sign & 1) != 0 ? -v.x : v.x,
(sign & 2) != 0 ? -v.y : v.y);
buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2((sign & 4) != 0 ? -v.z : v.z,
(sign & 8) != 0 ? -v.w : v.w);
#elif defined(DATA_A_IQ4_XS)
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;

const uint ib = idx / 128; // 2 values per idx
const uint ib32 = (idx % 128) / 16; // 0..7
const uint iq = 16 * ib32 + 2 * (idx % 8);

const uint sl = (data_a[ib].scales_l[ib32/2] >> (4 * (ib32 & 1))) & 0xF;
const uint sh = ((data_a[ib].scales_h) >> (2 * ib32)) & 3;
const uint qshift = (idx & 8) >> 1;
u8vec2 qs = u8vec2(data_a[ib].qs[iq], data_a[ib].qs[iq + 1]);
qs = (qs >> qshift) & uint8_t(0xF);

const float d = float(data_a[ib].d);
const vec2 v = d * float(int(sl | (sh << 4)) - 32) * vec2(kvalues_iq4nl[qs.x], kvalues_iq4nl[qs.y]);

buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v.xy);
#elif defined(DATA_A_IQ4_NL)
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
const uint buf_idx = col * SHMEM_STRIDE + row;

const uint ib = idx / 8;
const uint iqs = idx & 0x07;

const FLOAT_TYPE d = FLOAT_TYPE(data_a_packed16[ib].d);
const uint vui = uint(data_a_packed16[ib].qs[iqs]);

buf_a[buf_idx ] = d * FLOAT_TYPE_VEC2(kvalues_iq4nl[vui & 0xF],
kvalues_iq4nl[bitfieldExtract(vui, 8, 4)]);
buf_a[buf_idx + 8] = d * FLOAT_TYPE_VEC2(kvalues_iq4nl[bitfieldExtract(vui, 4, 4)],
kvalues_iq4nl[vui >> 12]);
#elif defined(DATA_A_MXFP4)
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
const uint buf_idx = col * SHMEM_STRIDE + row;

const uint ib = idx / 8;
const uint iqs = (idx & 0x07) * 2;

const float d = e8m0_to_fp32(data_a[ib].e);
const uint vui = uint(data_a[ib].qs[iqs]);
const uint vui2 = uint(data_a[ib].qs[iqs+1]);

buf_a[buf_idx ] = FLOAT_TYPE_VEC2(kvalues_mxfp4[vui & 0xF] * d,
kvalues_mxfp4[vui2 & 0xF] * d);
buf_a[buf_idx + 8] = FLOAT_TYPE_VEC2(kvalues_mxfp4[vui >> 4] * d,
kvalues_mxfp4[vui2 >> 4] * d);
#endif
}
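The legacy-quant branches above all reduce to the same affine dequantization; as a summary of the code (no new behavior), with q the stored 4/5/8-bit value and d, m the per-block scale and minimum:

$$x = d\,(q-8)\ \text{(q4\_0)},\quad x = d\,q + m\ \text{(q4\_1)},\quad x = d\,(q-16)\ \text{(q5\_0)},\quad x = d\,q + m\ \text{(q5\_1)},\quad x = d\,q\ \text{(q8\_0)}$$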

#if !defined(MUL_MAT_ID)
void load_b_to_shmem(const uint pos_b, const uint row, const uint col, const uint idx_n, const uint block, const uint end_k) {
#if LOAD_VEC_B == 8
// Not supported for b_type bf16 because bf16mat2x4 does not exist
const uint idx = pos_b + col * p.stride_b / LOAD_VEC_B + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B / 2;
FLOAT_TYPE_VEC8 bb = FLOAT_TYPE_VEC8(data_b[idx]);
buf_b[buf_idx + 0] = bb[0].xy;
buf_b[buf_idx + 1] = bb[0].zw;
buf_b[buf_idx + 2] = bb[1].xy;
buf_b[buf_idx + 3] = bb[1].zw;
#elif LOAD_VEC_B == 4
const uint idx = pos_b + col * p.stride_b / LOAD_VEC_B + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B / 2;
#if defined(DATA_B_BF16)
FLOAT_TYPE_VEC4 bb = FLOAT_TYPE_VEC4(TO_FLOAT_TYPE(data_b[idx]));
#else
FLOAT_TYPE_VEC4 bb = FLOAT_TYPE_VEC4(data_b[idx]);
#endif
buf_b[buf_idx + 0] = bb.xy;
buf_b[buf_idx + 1] = bb.zw;
#else // LOAD_VEC_BATCH_B == 2
const uint idx = pos_b + col * p.stride_b + row * 2;
const uint buf_idx = col * SHMEM_STRIDE + row;
if (idx_n < p.N && block + row * 2 + 1 < end_k) {
buf_b[buf_idx] = FLOAT_TYPE_VEC2(TO_FLOAT_TYPE(data_b[idx]),
TO_FLOAT_TYPE(data_b[idx + 1]));
} else if (idx_n < p.N && block + row * 2 < end_k) {
buf_b[buf_idx] = FLOAT_TYPE_VEC2(TO_FLOAT_TYPE(data_b[idx]), 0.0f);
} else {
buf_b[buf_idx] = FLOAT_TYPE_VEC2(0.0f);
}
#endif
}
#else
void load_b_to_shmem(const uint pos_b, const uint row, const uint col, const uint ic, const uint _ne1, const uint block, const uint end_k) {
#if LOAD_VEC_B == 8
// Not supported for b_type bf16 because bf16mat2x4 does not exist
const u16vec2 row_idx = row_ids[col];
const uint idx = pos_b + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B / 2;
FLOAT_TYPE_VEC8 bb = FLOAT_TYPE_VEC8(data_b[idx]);
buf_b[buf_idx + 0] = bb[0].xy;
buf_b[buf_idx + 1] = bb[0].zw;
buf_b[buf_idx + 2] = bb[1].xy;
buf_b[buf_idx + 3] = bb[1].zw;
#elif LOAD_VEC_B == 4
const u16vec2 row_idx = row_ids[col];
const uint idx = pos_b + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B / 2;
#if defined(DATA_B_BF16)
FLOAT_TYPE_VEC4 bb = FLOAT_TYPE_VEC4(TO_FLOAT_TYPE(data_b[idx]));
#else
FLOAT_TYPE_VEC4 bb = FLOAT_TYPE_VEC4(data_b[idx]);
#endif
buf_b[buf_idx + 0] = bb.xy;
buf_b[buf_idx + 1] = bb.zw;
#else // LOAD_VEC_BATCH_B == 2
const uint row_i = ic * BN + col;
const uint buf_idx = col * SHMEM_STRIDE + row;
if (row_i < _ne1 && block + row * 2 + 1 < end_k) {
const u16vec2 row_idx = row_ids[col];
const uint idx = pos_b + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + row * 2;
buf_b[buf_idx] = FLOAT_TYPE_VEC2(TO_FLOAT_TYPE(data_b[idx]),
TO_FLOAT_TYPE(data_b[idx + 1]));
} else if (row_i < _ne1 && block + row * 2 < end_k) {
const u16vec2 row_idx = row_ids[col];
const uint idx = pos_b + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + row * 2;
buf_b[buf_idx] = FLOAT_TYPE_VEC2(TO_FLOAT_TYPE(data_b[idx]), 0.0f);
} else {
buf_b[buf_idx] = FLOAT_TYPE_VEC2(0.0f);
}
#endif
}
#endif

@@ -28,7 +28,7 @@ layout (binding = 0) readonly buffer A {A_TYPE_PACKED16 data_a[];};
#if defined(A_TYPE_PACKED32)
layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32[];};
#endif
layout (binding = 1) readonly buffer B {block_q8_1_packed32 data_b[];};
layout (binding = 1) readonly buffer B {block_q8_1_x4_packed128 data_b[];};
layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};

#ifdef MUL_MAT_ID

@@ -98,7 +98,7 @@ shared FLOAT_TYPE_VEC2 buf_b_ds[BN];
#endif

#define LOAD_VEC_A (4 * QUANT_R)
#define LOAD_VEC_B 4
#define LOAD_VEC_B 16

#ifdef MUL_MAT_ID
shared u16vec2 row_ids[4096];

@@ -270,15 +270,22 @@ void main() {
const uint iqs = idx & 0x7;
#else
const uint ib = pos_b_ib + (loadc_b + l) * p.stride_b / BK;
const uint ib_outer = ib / 4;
const uint ib_inner = ib % 4;

const uint iqs = loadr_b;
#endif

const uint buf_ib = loadc_b + l;

if (iqs == 0) {
buf_b_ds[buf_ib] = FLOAT_TYPE_VEC2(data_b[ib].ds);
buf_b_ds[buf_ib] = FLOAT_TYPE_VEC2(data_b[ib_outer].ds[ib_inner]);
}
buf_b_qs[buf_ib * SHMEM_STRIDE + iqs] = data_b[ib].qs[iqs];
const ivec4 values = data_b[ib_outer].qs[ib_inner * 2 + iqs];
buf_b_qs[buf_ib * SHMEM_STRIDE + iqs * 4 ] = values.x;
buf_b_qs[buf_ib * SHMEM_STRIDE + iqs * 4 + 1] = values.y;
buf_b_qs[buf_ib * SHMEM_STRIDE + iqs * 4 + 2] = values.z;
buf_b_qs[buf_ib * SHMEM_STRIDE + iqs * 4 + 3] = values.w;
}

barrier();

@@ -349,7 +356,7 @@ void main() {
cache_b_qs[cc * (BK / 4) + idx_k]);
}

sums[sums_idx] += mul_q8_1(q_sum, cache_a_dm[cache_a_idx], cache_b_ds[cc]);
sums[sums_idx] += mul_q8_1(q_sum, cache_a_dm[cache_a_idx], cache_b_ds[cc], 1);
}
}
}

@@ -16,8 +16,8 @@ i32vec2 repack(uint ib, uint iqs) {
(vui >> 4) & 0x0F0F0F0F);
}

ACC_TYPE mul_q8_1(int32_t q_sum, float da, vec2 dsb) {
return ACC_TYPE(da * (float(q_sum) * dsb.x - 8.0f * dsb.y));
ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) {
return ACC_TYPE(da * (float(q_sum) * dsb.x - (8 / sum_divisor) * dsb.y));
}
#endif
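To see where the `8 * dsb.y` correction comes from, expand the q4_0 x q8_1 dot product. With d_a the q4_0 scale and dsb = (d_b, d_b * sum of q_b) the q8_1 scale/sum pair produced by the quantize shader, this just restates the return expression above:

$$\sum_i d_a\,(q_{a,i}-8)\,d_b\,q_{b,i} \;=\; d_a\Big(d_b\sum_i q_{a,i}q_{b,i} \;-\; 8\,d_b\sum_i q_{b,i}\Big) \;=\; d_a\,(\mathtt{q\_sum}\cdot\mathtt{dsb.x} - 8\cdot\mathtt{dsb.y})$$

The new sum_divisor argument presumably rescales that correction when only a fraction of a q8_1 block contributes to q_sum; the caller in this diff always passes 1.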

@@ -29,8 +29,8 @@ i32vec2 repack(uint ib, uint iqs) {
(vui >> 4) & 0x0F0F0F0F);
}

ACC_TYPE mul_q8_1(int32_t q_sum, vec2 dma, vec2 dsb) {
return ACC_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y);
ACC_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) {
return ACC_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y / sum_divisor);
}
#endif

@@ -50,8 +50,8 @@ i32vec2 repack(uint ib, uint iqs) {
return i32vec2(v0, v1);
}

ACC_TYPE mul_q8_1(int32_t q_sum, float da, vec2 dsb) {
return ACC_TYPE(da * (float(q_sum) * dsb.x - 16.0f * dsb.y));
ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) {
return ACC_TYPE(da * (float(q_sum) * dsb.x - (16 / sum_divisor) * dsb.y));
}
#endif

@@ -69,8 +69,8 @@ i32vec2 repack(uint ib, uint iqs) {
return i32vec2(v0, v1);
}

ACC_TYPE mul_q8_1(int32_t q_sum, vec2 dma, vec2 dsb) {
return ACC_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y);
ACC_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) {
return ACC_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y / sum_divisor);
}
#endif

@@ -81,7 +81,7 @@ int32_t repack(uint ib, uint iqs) {
data_a[ib].qs[iqs * 2 + 1]));
}

ACC_TYPE mul_q8_1(int32_t q_sum, float da, vec2 dsb) {
ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) {
return ACC_TYPE(float(q_sum) * da * dsb.x);
}
#endif

@@ -0,0 +1,111 @@
#version 450

#extension GL_EXT_shader_16bit_storage : require
#extension GL_EXT_nonuniform_qualifier : enable
#extension GL_EXT_control_flow_attributes : require
#if ADD_RMS
#extension GL_KHR_shader_subgroup_arithmetic : enable
#extension GL_KHR_shader_subgroup_basic : enable
#endif

#include "rte.comp"
#include "types.comp"
#include "utils.comp"

layout (push_constant) uniform parameter2
{
// shape for dst
uint ne20; uint ne21; uint ne22; uint ne23;

// strides for srcs+dst
uint nb[12][4];

uint rms_partials;
} p;

// Workaround for MoltenVK Bug, see https://github.com/ggml-org/llama.cpp/issues/15498
// layout (binding = 0) readonly buffer A {A_TYPE data_a[];} a[];
// layout (binding = 0) writeonly buffer D {D_TYPE data_d[];} d[];
layout (binding = 0) buffer A {A_TYPE data_a[];} a[];
layout (binding = 0) buffer D {D_TYPE data_d[];} d[];

layout (binding = 0, std430) buffer PartialBuf {float partial_sums[];} partials[];

layout(constant_id = 0) const uint num_srcs = 2;

uint src_idx(uint s, uint i00, uint i01, uint i02, uint i03) {
return i03*p.nb[s][3] + i02*p.nb[s][2] + i01*p.nb[s][1] + i00*p.nb[s][0];
}

uint dst_idx(uint i00, uint i01, uint i02, uint i03) {
uint nb20 = p.nb[num_srcs][0];
uint nb21 = p.nb[num_srcs][1];
uint nb22 = p.nb[num_srcs][2];
uint nb23 = p.nb[num_srcs][3];
return i03*nb23 + i02*nb22 + i01*nb21 + i00*nb20;
}

uint get_idx() {
return gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
}

const uint num_threads = 256;

layout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in;

#if ADD_RMS
// XXX TODO this could be sized based on number of subgroups, but that's not considered a constant
shared FLOAT_TYPE sumsh[num_threads];
#endif

void main() {
uint idx = get_idx();
uint orig_idx = idx;

uint ne = p.ne20 * p.ne21 * p.ne22 * p.ne23;

// num_threads * num_iter must equal 512, to match the wg_denoms and get_idx calculation
const uint num_iter = 2;

FLOAT_TYPE sum_sq = 0;

[[unroll]] for (uint i = 0; i < num_iter; ++i) {
if (idx >= ne) {
continue;
}
uint i00, i01, i02, i03;
get_indices(idx, i00, i01, i02, i03, p.ne20, p.ne21, p.ne22, p.ne23);

FLOAT_TYPE sum = FLOAT_TYPE(0);
[[unroll]] for (uint s = 0; s < num_srcs; ++s) {
sum += FLOAT_TYPE(a[s].data_a[src_idx(s, i00, i01, i02, i03)]);
}
sum_sq += sum*sum;
d[num_srcs].data_d[dst_idx(i00, i01, i02, i03)] = D_TYPE(sum);

idx += num_threads;
}

#if ADD_RMS
if (p.rms_partials != 0) {
// reduce the sum within each subgroup, then across subgroups
const uint NumSubgroups = num_threads / gl_SubgroupSize;
sum_sq = subgroupAdd(sum_sq);
if (gl_SubgroupInvocationID == 0) {
sumsh[gl_SubgroupID] = sum_sq;
}
barrier();
[[unroll]] for (uint s = NumSubgroups / 2; s > 0; s >>= 1) {
if (gl_SubgroupID < s && gl_SubgroupInvocationID == 0) {
sum_sq += sumsh[gl_SubgroupID + s];
sumsh[gl_SubgroupID] = sum_sq;
}
barrier();
}

if (gl_SubgroupID == 0 && gl_SubgroupInvocationID == 0) {
partials[num_srcs + 1].partial_sums[orig_idx / (num_iter * num_threads)] = sum_sq;
}
}
#endif
}
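In ADD_RMS mode each 256-thread workgroup covers 512 output elements and emits one partial sum. Writing g for the workgroup index, what lands in the partials buffer is (a restatement of the loops above):

$$P_g = \sum_{i \in \text{workgroup } g} \Big( \sum_{s=0}^{\mathtt{num\_srcs}-1} a_s[i] \Big)^2$$

A later pass (the rms_norm_partials shader further down in this diff) folds the P_g together instead of re-reading the whole tensor.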

@@ -0,0 +1,22 @@
#version 450

#include "generic_head.comp"

layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;

layout (binding = 0) buffer X {A_TYPE data_x[];};
layout (binding = 1) readonly buffer G {A_TYPE data_grad[];};
layout (binding = 2) readonly buffer P {float data_params[2];};

void main() {
const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;

if (i >= p.KX) {
return;
}

const float alpha = data_params[0];
const float keep = 1.f - alpha * data_params[1];

data_x[i] = data_x[i] * keep - alpha * data_grad[i];
}
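Reading data_params[0] as the learning rate and data_params[1] as a weight-decay coefficient (the buffer layout suggests this, but the parameter names are not part of the diff), the update is plain SGD with decoupled weight decay:

$$x_i \leftarrow (1 - \alpha\lambda)\,x_i - \alpha\,g_i$$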

@@ -1,7 +1,25 @@
#version 450

#include "types.comp"
#include "generic_unary_head.comp"

layout (push_constant) uniform parameter
{
uint ne;
uint ne00; uint ne01; uint ne02; uint ne03; uint nb00; uint nb01; uint nb02; uint nb03;
uint ne10; uint ne11; uint ne12; uint ne13; uint nb10; uint nb11; uint nb12; uint nb13;
uint misalign_offsets;

uint lp0; uint rp0;
uint lp1; uint rp1;
uint lp2; uint rp2;
uint lp3; uint rp3;
} p;

uint get_aoffset() { return p.misalign_offsets >> 16; }
uint get_doffset() { return p.misalign_offsets & 0xFFFF; }

layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};

layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;

@@ -19,10 +37,13 @@ void main() {
const uint i1 = (idx - i3_offset - i2_offset) / p.ne10;
const uint i0 = idx - i3_offset - i2_offset - i1*p.ne10;

const uint src0_idx = i3*p.nb03 + i2*p.nb02 + i1*p.nb01 + i0*p.nb00;
const uint src0_idx = (i3 - p.lp3)*p.nb03 + (i2 - p.lp2)*p.nb02 + (i1 - p.lp1)*p.nb01 + (i0 - p.lp0)*p.nb00;
const uint dst_idx = i3*p.nb13 + i2*p.nb12 + i1*p.nb11 + i0*p.nb10;

const bool is_src0 = i0 < p.ne00 && i1 < p.ne01 && i2 < p.ne02 && i3 < p.ne03;
const bool is_src0 = i0 >= p.lp0 && i0 < p.ne10 - p.rp0 &&
i1 >= p.lp1 && i1 < p.ne11 - p.rp1 &&
i2 >= p.lp2 && i2 < p.ne12 - p.rp2 &&
i3 >= p.lp3 && i3 < p.ne13 - p.rp3;

data_d[get_doffset() + dst_idx] = D_TYPE(is_src0 ? data_a[get_aoffset() + src0_idx] : 0.0f);
}
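The new lp/rp fields are left and right padding counts per dimension: a destination element keeps a source value exactly when it lies in the unpadded interior, with the source read at the coordinate shifted back by the left padding. Restating the predicate above:

$$\text{is\_src0} \iff \forall k \in \{0,\dots,3\}:\; lp_k \le i_k < ne1_k - rp_k, \qquad \text{source index } i_k - lp_k$$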

@@ -3,6 +3,15 @@
#extension GL_EXT_control_flow_attributes : require
#extension GL_EXT_shader_16bit_storage : require

#ifdef USE_SUBGROUPS
#extension GL_KHR_shader_subgroup_basic : require
#extension GL_KHR_shader_subgroup_clustered : require

#define INVOCATION_ID gl_SubgroupInvocationID.x
#else
#define INVOCATION_ID gl_LocalInvocationID.x
#endif

layout (push_constant) uniform parameter
{
uint ne;

@@ -14,13 +23,19 @@ layout(constant_id = 0) const uint GROUP_SIZE = 32;
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;

layout (binding = 0) readonly buffer A {vec4 data_a[];};
#ifndef QBLOCK_X4
layout (binding = 1) writeonly buffer D {block_q8_1_packed32 data_b[];};
#else
layout (binding = 1) writeonly buffer D {block_q8_1_x4 data_b[];};
#endif

#ifndef USE_SUBGROUPS
shared float shmem[GROUP_SIZE];
#endif

void quantize() {
const uint wgid = gl_WorkGroupID.x;
const uint tid = gl_LocalInvocationID.x;
const uint tid = INVOCATION_ID;

// Each thread handles a vec4, so 8 threads handle a block
const uint blocks_per_group = GROUP_SIZE / 8;

@@ -30,9 +45,19 @@ void quantize() {
const uint ib = wgid * blocks_per_group + block_in_wg;
const uint iqs = tid % 8;

#ifndef QBLOCK_X4
if (ib >= gl_NumWorkGroups.x * blocks_per_group) {
return;
}
#else
const uint ibx4_outer = ib / 4;
const uint ibx4_inner = ib % 4;

const uint required_x4_blocks = (p.ne + 127) / 128;
if (ibx4_outer >= required_x4_blocks) {
return;
}
#endif

const uint a_idx = ib * 8 + iqs;

@@ -40,7 +65,9 @@ void quantize() {
const vec4 abs_vals = abs(vals);

// Find absolute max for each block
shmem[tid] = max(max(abs_vals.x, abs_vals.y), max(abs_vals.z, abs_vals.w));
const float thread_max = max(max(abs_vals.x, abs_vals.y), max(abs_vals.z, abs_vals.w));
#ifndef USE_SUBGROUPS
shmem[tid] = thread_max;
barrier();
[[unroll]] for (uint s = 4; s > 0; s >>= 1) {
if (iqs < s) {

@@ -50,14 +77,28 @@ void quantize() {
}

const float amax = shmem[block_in_wg * 8];
#else
const float amax = subgroupClusteredMax(thread_max, 8);
#endif

const float d = amax / 127.0;
const float d_inv = d != 0.0 ? 1.0 / d : 0.0;
vals = round(vals * d_inv);

#ifndef QBLOCK_X4
data_b[ib].qs[iqs] = pack32(i8vec4(round(vals)));
#else
data_b[ibx4_outer].qs[ibx4_inner * 8 + iqs] = pack32(i8vec4(round(vals)));
#endif

#ifndef USE_SUBGROUPS
barrier();
#endif

// Calculate the sum for each block
shmem[tid] = vals.x + vals.y + vals.z + vals.w;
const float thread_sum = vals.x + vals.y + vals.z + vals.w;
#ifndef USE_SUBGROUPS
shmem[tid] = thread_sum;
barrier();
[[unroll]] for (uint s = 4; s > 0; s >>= 1) {
if (iqs < s) {

@@ -65,10 +106,19 @@ void quantize() {
}
barrier();
}
#else
const float sum = subgroupClusteredAdd(thread_sum, 8);
#endif
if (iqs == 0) {
#ifndef USE_SUBGROUPS
const float sum = shmem[tid];
#endif

#ifndef QBLOCK_X4
data_b[ib].ds = f16vec2(vec2(d, sum * d));
#else
data_b[ibx4_outer].ds[ibx4_inner] = f16vec2(vec2(d, sum * d));
#endif
}
}
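The quantization itself is the standard symmetric q8_1 scheme; per 32-value block (restating the code above):

$$d = \frac{\max_i |x_i|}{127}, \qquad q_i = \mathrm{round}\!\left(\frac{x_i}{d}\right), \qquad ds = \Big(d,\; d\sum_i q_i\Big)$$

Storing d times the quant sum alongside d is what lets the mul_q8_1 helpers earlier in the diff fold the quant-offset correction into two multiplies.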

@@ -10,9 +10,9 @@ layout (constant_id = 1) const bool do_multiply = false;

layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;

shared FLOAT_TYPE sum[BLOCK_SIZE];
shared FLOAT_TYPE sumsh[BLOCK_SIZE];

void main() {
void rms_norm(uint num_iters) {
const uint ncols = p.ne00;
const uint nrows = gl_NumWorkGroups.x;
const uint nchannels = gl_NumWorkGroups.y;

@@ -30,38 +30,76 @@ void main() {
uint32_t b_offset = src1_idx(0, row, channel, samp) + get_boffset();
uint32_t d_offset = ((samp*nchannels + channel)*nrows + row)*ncols + get_doffset();

sum[tid] = FLOAT_TYPE(0.0f); // partial sum for thread in warp
FLOAT_TYPE sum = FLOAT_TYPE(0.0f); // partial sum for thread in warp

[[unroll]] for (uint col = tid; col < ncols; col += BLOCK_SIZE) {
const FLOAT_TYPE xi = FLOAT_TYPE(data_a[a_offset + col]);
sum[tid] += xi * xi;
[[unroll]] for (uint col = tid, idx = 0; idx < num_iters; col += BLOCK_SIZE, ++idx) {
FLOAT_TYPE xi = FLOAT_TYPE(0);
if (col < ncols) {
xi = FLOAT_TYPE(data_a[a_offset + col]);
}
sum += xi * xi;
}

sumsh[tid] = sum;
// sum up partial sums and write back result
barrier();
[[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
if (tid < s) {
sum[tid] += sum[tid + s];
sum += sumsh[tid + s];
sumsh[tid] = sum;
}
barrier();
}
sum = sumsh[0];

const FLOAT_TYPE mean = sum[0] / FLOAT_TYPE(ncols);
const FLOAT_TYPE mean = sum / FLOAT_TYPE(ncols);
const FLOAT_TYPE scale = inversesqrt(mean + FLOAT_TYPE(p.param1));

if (do_multiply) {
if (ncols > p.ne10) {
[[unroll]] for (uint col = tid; col < ncols; col += BLOCK_SIZE) {
[[unroll]] for (uint col = tid, idx = 0; idx < num_iters; col += BLOCK_SIZE, ++idx) {
if (col >= ncols) {
continue;
}
data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]) * FLOAT_TYPE(data_b[b_offset + fastmod(col, p.ne10)]));
}
} else {
[[unroll]] for (uint col = tid; col < ncols; col += BLOCK_SIZE) {
[[unroll]] for (uint col = tid, idx = 0; idx < num_iters; col += BLOCK_SIZE, ++idx) {
if (col >= ncols) {
continue;
}
data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]) * FLOAT_TYPE(data_b[b_offset + col]));
}
}
} else {
[[unroll]] for (uint col = tid; col < ncols; col += BLOCK_SIZE) {
[[unroll]] for (uint col = tid, idx = 0; idx < num_iters; col += BLOCK_SIZE, ++idx) {
if (col >= ncols) {
continue;
}
data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]));
}
}
}

void main() {
// instantiate the rms_norm function for several different
// dimensions, to allow loop unrolling
uint num_blocks = (p.ne00 + BLOCK_SIZE - 1) / BLOCK_SIZE;
if (num_blocks > 32) {
rms_norm(num_blocks);
} else if (num_blocks > 16) {
rms_norm(32);
} else if (num_blocks > 8) {
rms_norm(16);
} else if (num_blocks > 4) {
rms_norm(8);
} else if (num_blocks == 4) {
rms_norm(4);
} else if (num_blocks == 3) {
rms_norm(3);
} else if (num_blocks == 2) {
rms_norm(2);
} else if (num_blocks == 1) {
rms_norm(1);
}
}
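What the shader computes per row is standard RMS normalization with p.param1 as epsilon, plus an optional elementwise multiply by src1 (broadcast via fastmod when src1 is narrower):

$$y_i = x_i \cdot \Big(\tfrac{1}{n}\sum_{j=1}^{n} x_j^2 + \varepsilon\Big)^{-1/2} \cdot b_{\,i \bmod ne_{10}}$$

The num_iters specialization only changes how the column loop unrolls, not this result.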

@@ -0,0 +1,65 @@
#version 450

#include "generic_binary_head.comp"
#include "types.comp"

#extension GL_EXT_control_flow_attributes : enable
#extension GL_KHR_shader_subgroup_arithmetic : enable
#extension GL_KHR_shader_subgroup_basic : enable

#define BLOCK_SIZE 128

layout (constant_id = 1) const bool do_multiply = false;

layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;

layout (binding = 3, std430) readonly buffer PartialsBuf {float partial_sums[];};

shared FLOAT_TYPE sumsh[BLOCK_SIZE];

void main() {
const uint ncols = p.ne00;
const uint nrows = gl_NumWorkGroups.x;
const uint nchannels = gl_NumWorkGroups.y;

const uint row = 0;
const uint channel = gl_WorkGroupID.y;
const uint samp = gl_WorkGroupID.z;
// The work is split across multiple workgroups in the x dimension. Each invocation
// processes one element
const uint tid = gl_GlobalInvocationID.x;

const uint stride_row = p.nb01;
const uint stride_channel = p.nb02;
const uint stride_sample = p.nb03;

uint32_t a_offset = samp*stride_sample + channel*stride_channel + row*stride_row + get_aoffset();
uint32_t b_offset = src1_idx(0, row, channel, samp) + get_boffset();
uint32_t d_offset = ((samp*nchannels + channel)*nrows + row)*ncols + get_doffset();

FLOAT_TYPE sum = FLOAT_TYPE(0.0f); // partial sum for thread in warp

uint32_t num_partials = p.param3;
for (uint32_t i = gl_SubgroupInvocationID; i < num_partials; i += gl_SubgroupSize) {
sum += partial_sums[i];
}
sum = subgroupAdd(sum);

uint col = tid;
if (col >= ncols) {
return;
}

const FLOAT_TYPE mean = sum / FLOAT_TYPE(ncols);
const FLOAT_TYPE scale = inversesqrt(mean + FLOAT_TYPE(p.param1));

if (do_multiply) {
if (ncols > p.ne10) {
data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]) * FLOAT_TYPE(data_b[b_offset + fastmod(col, p.ne10)]));
} else {
data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]) * FLOAT_TYPE(data_b[b_offset + col]));
}
} else {
data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]));
}
}
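This is the consumer of the partial sums produced by the fused-add shader above: rather than re-reading x to form the sum of squares, each subgroup lane folds in a stripe of the per-workgroup partials, so

$$\sum_i x_i^2 = \sum_{g=0}^{\mathtt{num\_partials}-1} P_g, \qquad y_i = x_i\Big(\tfrac{1}{n}\sum_j x_j^2 + \varepsilon\Big)^{-1/2}$$

with the same optional src1 multiply as the standalone rms_norm shader.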

@@ -20,6 +20,10 @@ void main() {
const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
const uint tid = gl_LocalInvocationID.x;

if (row >= p.KY) {
return;
}

FLOAT_TYPE scale = p.param1;

// partial sums for thread in warp

@@ -0,0 +1,17 @@
#version 450

#include "types.comp"
#include "generic_unary_head.comp"

layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;

void main() {
const uint idx = get_idx();

if (idx >= p.ne) {
return;
}

const FLOAT_TYPE val = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(idx)]);
data_d[get_doffset() + dst_idx(idx)] = D_TYPE(sqrt(val));
}

@@ -1,9 +1,9 @@
#version 450

#include "generic_head.comp"
#include "types.comp"

#extension GL_EXT_control_flow_attributes : enable

layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;

layout (binding = 0) readonly buffer A {A_TYPE data_a[];};

@@ -11,16 +11,49 @@ layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};

layout (constant_id = 0) const uint BLOCK_SIZE = 32;

layout (push_constant) uniform parameter
{
uint n_cols;
uint ne01, ne02;
uint nb01, nb02, nb03;
uint nb11, nb12, nb13;
float weight;
uint misalign_offsets;
uint ne0_12mp, ne0_12L;
uint ne0_1mp, ne0_1L;
} p;

uint get_aoffset() { return p.misalign_offsets >> 16; }
uint get_doffset() { return p.misalign_offsets & 0xFFFF; }

// see init_fastdiv_values in ggml-vulkan.cpp
uint fastdiv(uint n, uint mp, uint L) {
uint msbs, lsbs;
// msbs = mulhi(n, mp)
umulExtended(n, mp, msbs, lsbs);
return (msbs + n) >> L;
}
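A quick sketch of why this works. The exact constants live host-side in init_fastdiv_values, so the construction below is the standard round-up method (Granlund-Montgomery style) rather than a quote of that code: to divide by a fixed d, pick L as the ceiling of log2(d) and a magic multiplier

$$m_p = \left\lceil \frac{2^{32+L}}{d} \right\rceil - 2^{32}$$

so that for every 32-bit n,

$$\left\lfloor \frac{n}{d} \right\rfloor = \big(\mathrm{mulhi}(n, m_p) + n\big) \gg L$$

which is exactly the umulExtended-plus-shift sequence above, with no divide instruction left in the shader.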
shared FLOAT_TYPE tmp[BLOCK_SIZE];

void main() {
const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
const uint col = gl_LocalInvocationID.x;
const float weight = p.weight;

tmp[col] = FLOAT_TYPE(0.0f);
const uint i03 = fastdiv(row, p.ne0_12mp, p.ne0_12L);
const uint i03_offset = i03 * p.ne01*p.ne02;
const uint i02 = fastdiv(row - i03_offset, p.ne0_1mp, p.ne0_1L);
const uint i01 = row - i03_offset - i02*p.ne01;

for (uint i = col; i < p.KX; i += BLOCK_SIZE) {
tmp[col] += FLOAT_TYPE(data_a[row*p.KX + i]);
const uint src_idx = get_aoffset() + i01 * p.nb01 + i02 * p.nb02 + i03 * p.nb03;
const uint dst_idx = get_doffset() + i01 * p.nb11 + i02 * p.nb12 + i03 * p.nb13;

tmp[col] = FLOAT_TYPE(0.0);

for (uint i = col; i < p.n_cols; i += BLOCK_SIZE) {
tmp[col] += FLOAT_TYPE(data_a[src_idx + i]);
}

barrier();

@@ -32,6 +65,6 @@ void main() {
}

if (col == 0) {
data_d[row] = D_TYPE(tmp[0]);
data_d[dst_idx] = D_TYPE(tmp[0] * weight);
}
}

@@ -24,11 +24,12 @@ void main() {
const uint j = gl_GlobalInvocationID.x;
const uint d_offset = i * p.nb1;

if (p.dim % 2 != 0 && j == ((p.dim + 1) / 2)) {
data_d[d_offset + p.dim] = 0.f;
const uint half_dim = p.dim / 2;

if (p.dim % 2 != 0 && j == half_dim) {
data_d[d_offset + 2 * half_dim] = 0.f;
}

const uint half_dim = p.dim / 2;
if (j >= half_dim) {
return;
}

@@ -11,12 +11,12 @@
#define QUANT_K 1
#define QUANT_R 1

#if !defined(LOAD_VEC_A) || LOAD_VEC_A == 1
#define A_TYPE float
#elif LOAD_VEC_A == 4
#if LOAD_VEC_A == 4
#define A_TYPE vec4
#elif LOAD_VEC_A == 8
#define A_TYPE mat2x4
#else
#define A_TYPE float
#endif
#endif

@@ -24,12 +24,12 @@
#define QUANT_K 1
#define QUANT_R 1

#if !defined(LOAD_VEC_A) || LOAD_VEC_A == 1
#define A_TYPE float16_t
#elif LOAD_VEC_A == 4
#if LOAD_VEC_A == 4
#define A_TYPE f16vec4
#elif LOAD_VEC_A == 8
#define A_TYPE f16mat2x4
#else
#define A_TYPE float16_t
#endif
#endif

@@ -37,12 +37,12 @@
#define QUANT_K 1
#define QUANT_R 1

#if !defined(LOAD_VEC_A) || LOAD_VEC_A == 1
#define A_TYPE uint16_t
#elif LOAD_VEC_A == 4
#if LOAD_VEC_A == 4
#define A_TYPE u16vec4
#elif LOAD_VEC_A == 8
#error unsupported
#else
#define A_TYPE uint16_t
#endif
#endif

@@ -207,6 +207,18 @@ struct block_q8_1_packed32
int32_t qs[8];
};

// 4 blocks in one to allow 16-byte/128-bit alignment and loads
struct block_q8_1_x4
{
f16vec2 ds[4];
int32_t qs[32];
};
struct block_q8_1_x4_packed128
{
f16vec2 ds[4];
ivec4 qs[8];
};

// K-quants
#define QUANT_K_Q2_K 256
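The x4 packing is sized so the quant data starts and stays 16-byte aligned: each q8_1 block holds 32 int8 weights (32 B, i.e. 8 int32), so four blocks give 16 B of f16vec2 scale/sum pairs followed by 128 B of weights, which the packed128 variant views as eight ivec4 loads:

$$\underbrace{4 \times 4\,\text{B}}_{ds[4]} + \underbrace{4 \times 32\,\text{B}}_{qs} = 144\,\text{B}, \qquad 128\,\text{B of } qs = 8 \times \text{ivec4}$$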
@ -233,6 +245,7 @@ struct block_q2_K_packed32

#if defined(DATA_A_Q2_K)
#define QUANT_K QUANT_K_Q2_K
#define QUANT_R 1
#define A_TYPE block_q2_K
#define A_TYPE_PACKED16 block_q2_K_packed16
#define A_TYPE_PACKED32 block_q2_K_packed32

@ -258,6 +271,7 @@ struct block_q3_K_packed16

#if defined(DATA_A_Q3_K)
#define QUANT_K QUANT_K_Q3_K
#define QUANT_R 1
#define A_TYPE block_q3_K
#define A_TYPE_PACKED16 block_q3_K_packed16
#endif

@ -292,6 +306,7 @@ struct block_q4_K_packed128

#if defined(DATA_A_Q4_K)
#define QUANT_K QUANT_K_Q4_K
#define QUANT_R 1
#define A_TYPE block_q4_K
#define A_TYPE_PACKED16 block_q4_K_packed16
#define A_TYPE_PACKED32 block_q4_K_packed32

@ -322,6 +337,7 @@ struct block_q5_K_packed128

#if defined(DATA_A_Q5_K)
#define QUANT_K QUANT_K_Q5_K
#define QUANT_R 1
#define A_TYPE block_q5_K
#define A_TYPE_PACKED16 block_q5_K_packed16
#endif

@ -346,6 +362,7 @@ struct block_q6_K_packed16

#if defined(DATA_A_Q6_K)
#define QUANT_K QUANT_K_Q6_K
#define QUANT_R 1
#define A_TYPE block_q6_K
#define A_TYPE_PACKED16 block_q6_K_packed16
#endif

@ -1412,6 +1429,11 @@ float bf16_to_fp32(uint32_t u)
return uintBitsToFloat(u << 16);
}

vec4 bf16_to_fp32(uvec4 u)
{
return vec4(bf16_to_fp32(u.x), bf16_to_fp32(u.y), bf16_to_fp32(u.z), bf16_to_fp32(u.w));
}

float e8m0_to_fp32(uint8_t x) {
uint32_t bits;

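The same conversion on the host, for reference (bf16 is the top half of an IEEE-754 float, so widening is a 16-bit shift; a sketch, not part of the diff):

```cpp
#include <cstdint>
#include <cstring>

float bf16_to_fp32_host(uint16_t u) {
    const uint32_t bits = static_cast<uint32_t>(u) << 16;  // bf16 -> high half
    float f;
    std::memcpy(&f, &bits, sizeof(f));  // equivalent of GLSL uintBitsToFloat
    return f;
}
```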
@ -1425,4 +1447,19 @@ float e8m0_to_fp32(uint8_t x) {
return uintBitsToFloat(bits);
}

#if BDA

#extension GL_EXT_buffer_reference : enable
#extension GL_EXT_shader_explicit_arithmetic_types_int64 : enable

#define BDA_STORAGE_T uint64_t
#define BDA_OFFSET_T uint64_t

#else

#define BDA_STORAGE_T uvec2
#define BDA_OFFSET_T uint

#endif

#endif // !defined(GGML_TYPES_COMP)

@ -0,0 +1,25 @@
#ifndef UTILS_COMP
#define UTILS_COMP

// mod and div are expensive and coordinates/dimensions are often power of 2 or equal to 1
uint fastmod(uint a, uint b) {
if ((b & (b-1)) == 0) {
return a & (b-1);
}
return a % b;
}

uint fastdiv(uint a, uint b) {
return (a < b) ? 0 : (a / b);
}

void get_indices(uint idx, out uint i00, out uint i01, out uint i02, out uint i03, uint ne00, uint ne01, uint ne02, uint ne03) {
i03 = fastdiv(idx, (ne02*ne01*ne00));
const uint i03_offset = i03 * ne02*ne01*ne00;
i02 = fastdiv((idx - i03_offset), (ne01*ne00));
const uint i02_offset = i02*ne01*ne00;
i01 = (idx - i03_offset - i02_offset) / ne00;
i00 = idx - i03_offset - i02_offset - i01*ne00;
}

#endif // UTILS_COMP

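A C++ port of the new helpers with a small check (a sketch; ne03 is accepted but unused by the GLSL version, so it is dropped here):

```cpp
#include <cassert>
#include <cstdint>

// Power-of-two moduli reduce to a mask; everything else falls back to %.
uint32_t fastmod(uint32_t a, uint32_t b) {
    if ((b & (b - 1)) == 0) return a & (b - 1);
    return a % b;
}

// Skips the division entirely in the common a < b case.
uint32_t fastdiv(uint32_t a, uint32_t b) {
    return (a < b) ? 0 : (a / b);
}

void get_indices(uint32_t idx, uint32_t &i00, uint32_t &i01, uint32_t &i02,
                 uint32_t &i03, uint32_t ne00, uint32_t ne01, uint32_t ne02) {
    i03 = fastdiv(idx, ne02 * ne01 * ne00);
    const uint32_t i03_offset = i03 * ne02 * ne01 * ne00;
    i02 = fastdiv(idx - i03_offset, ne01 * ne00);
    const uint32_t i02_offset = i02 * ne01 * ne00;
    i01 = (idx - i03_offset - i02_offset) / ne00;
    i00 = idx - i03_offset - i02_offset - i01 * ne00;
}

int main() {
    uint32_t i00, i01, i02, i03;
    // Linear index 23 in a 4x3x2x1 tensor: 23 = 1*12 + 2*4 + 3.
    get_indices(23, i00, i01, i02, i03, /*ne00=*/4, /*ne01=*/3, /*ne02=*/2);
    assert(i00 == 3 && i01 == 2 && i02 == 1 && i03 == 0);
    return 0;
}
```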
@ -68,6 +68,12 @@ const std::vector<std::string> type_names = {
"bf16",
};

enum MatMulIdType {
NONE,
DEFAULT,
SUBGROUP,
};

namespace {
void execute_command(const std::string& command, std::string& stdout_str, std::string& stderr_str) {
#ifdef _WIN32

@ -200,6 +206,22 @@ bool string_ends_with(const std::string& str, const std::string& suffix) {
return std::equal(suffix.rbegin(), suffix.rend(), str.rbegin());
}

bool is_quantized_type(const std::string& type_name) {
return type_name != "f32" && type_name != "f16" && type_name != "bf16";
}

bool is_legacy_quant(const std::string& type_name) {
return type_name == "q4_0" || type_name == "q4_1" || type_name == "q5_0" || type_name == "q5_1" || type_name == "q8_0";
}

bool is_k_quant(const std::string& type_name) {
return string_ends_with(type_name, "_k");
}

bool is_iq_quant(const std::string& type_name) {
return string_starts_with(type_name, "iq");
}

static const char path_separator = '/';

std::string join_paths(const std::string& path1, const std::string& path2) {

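Expected behaviour of the new classifiers, with inline mirrors so the checks run standalone (type names follow the lower-case convention of type_names):

```cpp
#include <cassert>
#include <string>

static bool is_legacy(const std::string &t) {
    return t == "q4_0" || t == "q4_1" || t == "q5_0" || t == "q5_1" || t == "q8_0";
}
static bool is_k(const std::string &t) {  // mirrors string_ends_with(t, "_k")
    return t.size() >= 2 && t.compare(t.size() - 2, 2, "_k") == 0;
}

int main() {
    assert(is_legacy("q8_0") && !is_legacy("q4_k"));
    assert(is_k("q4_k") && !is_k("q4_0") && !is_k("iq4_nl"));
    return 0;
}
```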
@ -223,7 +245,8 @@ void string_to_spv_func(const std::string& _name, const std::string& in_fname, c
std::string target_env = (name.find("_cm2") != std::string::npos) ? "--target-env=vulkan1.3" : "--target-env=vulkan1.2";

// disable spirv-opt for coopmat shaders for https://github.com/ggerganov/llama.cpp/issues/10734
std::string opt_level = coopmat ? "" : "-O";
// disable spirv-opt for bf16 shaders for https://github.com/ggml-org/llama.cpp/issues/15344
std::string opt_level = (coopmat || name.find("bf16") != std::string::npos) ? "" : "-O";

#ifdef _WIN32
std::vector<std::string> cmd = {GLSLC, "-fshader-stage=compute", target_env, opt_level, "\"" + in_path + "\"", "-o", "\"" + out_fname + "\""};

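The decision above now covers two miscompilation workarounds; a standalone mirror of the same logic (the helper name is hypothetical):

```cpp
#include <string>

// Empty string means spirv-opt (-O) is skipped for this shader.
std::string opt_level_for(const std::string &name, bool coopmat) {
    return (coopmat || name.find("bf16") != std::string::npos) ? "" : "-O";
}
// opt_level_for("matmul_f32_f16", false) -> "-O"
// opt_level_for("matmul_bf16", false)    -> ""  (issue 15344)
// opt_level_for("matmul_f16_cm2", true)  -> ""  (issue 10734)
```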
@ -292,26 +315,32 @@ void string_to_spv(const std::string& _name, const std::string& in_fname, const
compiles.push_back(std::async(string_to_spv_func, _name, in_fname, defines, fp16, coopmat, coopmat2, f16acc));
}

void matmul_shaders(bool fp16, bool matmul_id, bool coopmat, bool coopmat2, bool f16acc) {
void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool coopmat2, bool f16acc) {
std::string load_vec = coopmat2 ? "1" : fp16 ? "8" : "4";
std::string aligned_b_type_f32 = coopmat2 ? "float" : fp16 ? "mat2x4" : "vec4";
std::string aligned_b_type_f16 = coopmat2 ? "float16_t" : fp16 ? "f16mat2x4" : "f16vec4";

std::map<std::string, std::string> base_dict = {
{"FLOAT_TYPE_VEC2", (coopmat2 || fp16) ? "f16vec2" : "vec2"},
};
std::map<std::string, std::string> base_dict;
std::string shader_name = "matmul";

if (matmul_id) {
if (matmul_id_type == MatMulIdType::DEFAULT) {
base_dict["MUL_MAT_ID"] = "1";
shader_name = "matmul_id";
} else if (matmul_id_type == MatMulIdType::SUBGROUP) {
base_dict["MUL_MAT_ID"] = "1";
base_dict["MUL_MAT_ID_USE_SUBGROUPS"] = "1";
shader_name = "matmul_id_subgroup";
}

if (fp16) {
base_dict["FLOAT16"] = "1";
}

base_dict["ACC_TYPE"] = f16acc ? "float16_t" : "float";
base_dict["ACC_TYPE" ] = f16acc ? "float16_t" : "float";
base_dict["ACC_TYPE_VEC2"] = f16acc ? "f16vec2" : "vec2";
if (f16acc) {
base_dict["ACC_TYPE_MAX"] = "\"float16_t(65504.0)\"";
}

if (coopmat) {
base_dict["COOPMAT"] = "1";

@ -319,43 +348,96 @@ void matmul_shaders(bool fp16, bool matmul_id, bool coopmat, bool coopmat2, bool

const std::string source_name = coopmat2 ? "mul_mm_cm2.comp" : "mul_mm.comp";

auto const &FLOAT_TYPE = [&](const std::string &t) -> std::string {
if (t == "bf16") {
// scalar path promotes to float
if (!coopmat && !coopmat2) {
return "float";
auto const &FLOAT_TYPE = [&](int vec, const std::string &t) -> std::string {
switch (vec) {
case 1:
if (t == "bf16") {
// scalar path promotes to float
if (!coopmat && !coopmat2) {
return "float";
}
return "bfloat16_t";
}
return "bfloat16_t";
if (coopmat2 || fp16) {
return "float16_t";
}
return "float";
case 2:
if (t == "bf16") {
// scalar path promotes to float
if (!coopmat && !coopmat2) {
return "vec2";
}
return "bf16vec2";
}
if (coopmat2 || fp16) {
return "f16vec2";
}
return "vec2";
case 4:
if (t == "bf16") {
// scalar path promotes to float
if (!coopmat && !coopmat2) {
return "vec4";
}
return "bf16vec4";
}
if (coopmat2 || fp16) {
return "f16vec4";
}
return "vec4";
case 8:
if (t == "bf16") {
// scalar path promotes to float
if (!coopmat && !coopmat2) {
return "mat2x4";
}
throw std::runtime_error("bf16 vec8 not supported");
}
if (coopmat2 || fp16) {
return "f16mat2x4";
}
return "mat2x4";
default:
throw std::runtime_error("invalid vector size");
}
if (coopmat2 || fp16) {
return "float16_t";
}
return "float";
};

const std::map<std::string, std::string> float_type_dict_f16 = {
{"FLOAT_TYPE", FLOAT_TYPE(1, "f16")},
{"FLOAT_TYPE_VEC2", FLOAT_TYPE(2, "f16")},
{"FLOAT_TYPE_VEC4", FLOAT_TYPE(4, "f16")},
{"FLOAT_TYPE_VEC8", FLOAT_TYPE(8, "f16")},
};

// Shaders with f16 B_TYPE
string_to_spv(shader_name + "_f32_f16", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("f16")}, {"DATA_A_F32", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}, }), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_f32_f16_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("f16")}, {"DATA_A_F32", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_f32_f16", source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F32", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}, }), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_f32_f16_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F32", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);

string_to_spv(shader_name + "_f16_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("f16")}, {"DATA_A_F16", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_f16", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("f16")}, {"DATA_A_F16", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_f16", source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F16", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_f16_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F16", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);

// bf16
{
std::string load_vec_a_unaligned = "1";
// For aligned matmul loads
std::string load_vec_a = coopmat2 ? "1" : "4";

// scalar path promotes to float
std::string to_float_type = (coopmat || coopmat2) ? "uintBitsToBFloat16EXT" : "bf16_to_fp32";

const std::map<std::string, std::string> float_type_dict_bf16 = {
{"FLOAT_TYPE", FLOAT_TYPE(1, "bf16")},
{"FLOAT_TYPE_VEC2", FLOAT_TYPE(2, "bf16")},
{"FLOAT_TYPE_VEC4", FLOAT_TYPE(4, "bf16")},
};

// If bfloat16 is not supported, then only compile the scalar (promote to fp32) shader
#if !defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
if (!(coopmat || coopmat2))
#endif
{
string_to_spv(shader_name + "_bf16_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("bf16")}, {"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", "4"}, {"B_TYPE", coopmat2 ? "bfloat16_t" : "u16vec4"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_bf16", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("bf16")}, {"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", coopmat2 ? "bfloat16_t" : "uint16_t"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}}), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_bf16", source_name, merge_maps(merge_maps(base_dict, float_type_dict_bf16), {{"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"B_TYPE", coopmat2 ? "bfloat16_t" : "uint16_t"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"DATA_B_BF16", "1"}}), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_bf16_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict_bf16), {{"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", "4"}, {"B_TYPE", coopmat2 ? "bfloat16_t" : "u16vec4"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"DATA_B_BF16", "1"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
}
}

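Standalone mirror of the widened FLOAT_TYPE for the vec4 case, showing how the GLSL type string follows the precision flags (a sketch):

```cpp
#include <cassert>
#include <string>

std::string float_type_vec4(const std::string &t, bool fp16,
                            bool coopmat, bool coopmat2) {
    if (t == "bf16") {
        if (!coopmat && !coopmat2) return "vec4";  // scalar path promotes to float
        return "bf16vec4";
    }
    return (coopmat2 || fp16) ? "f16vec4" : "vec4";
}

int main() {
    assert(float_type_vec4("f16",  true,  false, false) == "f16vec4");
    assert(float_type_vec4("f16",  false, false, false) == "vec4");
    assert(float_type_vec4("bf16", false, false, false) == "vec4");
    assert(float_type_vec4("bf16", false, true,  false) == "bf16vec4");
    return 0;
}
```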
@ -376,20 +458,27 @@ void matmul_shaders(bool fp16, bool matmul_id, bool coopmat, bool coopmat2, bool
// For aligned matmul loads
std::string load_vec_a = (coopmat2 || tname == "f32" || tname == "f16" || tname == "bf16") ? load_vec : load_vec_quant;

const std::map<std::string, std::string> float_type_dict = {
{"FLOAT_TYPE", FLOAT_TYPE(1, tname)},
{"FLOAT_TYPE_VEC2", FLOAT_TYPE(2, tname)},
{"FLOAT_TYPE_VEC4", FLOAT_TYPE(4, tname)},
{"FLOAT_TYPE_VEC8", FLOAT_TYPE(8, tname)},
};

// don't generate f32 variants for coopmat2
if (!coopmat2) {
string_to_spv(shader_name + "_" + tname + "_f32", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE(tname)}, {data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_" + tname + "_f32_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE(tname)}, {data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_" + tname + "_f32", source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_" + tname + "_f32_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
}

if (tname != "f16" && tname != "f32") {
string_to_spv(shader_name + "_" + tname + "_f16", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE(tname)}, {data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_" + tname + "_f16_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE(tname)}, {data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_" + tname + "_f16", source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_" + tname + "_f16_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
}

#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
if (!coopmat && !coopmat2 && !matmul_id && (tname == "q4_0" || tname == "q4_1" || tname == "q5_0" || tname == "q5_1" || tname == "q8_0")) {
string_to_spv(shader_name + "_" + tname + "_q8_1", "mul_mmq.comp", merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE(tname)}, {data_a_key, "1"}, {"D_TYPE", "float"},}), fp16, coopmat, coopmat2, f16acc);
if (!coopmat && !coopmat2 && matmul_id_type == MatMulIdType::NONE && is_legacy_quant(tname)) {
string_to_spv(shader_name + "_" + tname + "_q8_1", "mul_mmq.comp", merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"D_TYPE", "float"},}), fp16, coopmat, coopmat2, f16acc);
}
#endif
}

@ -400,32 +489,38 @@ void process_shaders() {
std::map<std::string, std::string> base_dict = {{"FLOAT_TYPE", "float"}};

// matmul
for (const auto& matmul_id : {false, true}) {
for (const MatMulIdType& matmul_id_type : {MatMulIdType::NONE, MatMulIdType::DEFAULT, MatMulIdType::SUBGROUP}) {
// No coopmats
// fp32
matmul_shaders(false, matmul_id, false, false, false);
matmul_shaders(false, matmul_id_type, false, false, false);

// fp16, fp32acc and fp16acc
matmul_shaders(true, matmul_id, false, false, false);
matmul_shaders(true, matmul_id, false, false, true);
matmul_shaders(true, matmul_id_type, false, false, false);
matmul_shaders(true, matmul_id_type, false, false, true);

if (matmul_id_type != MatMulIdType::DEFAULT) {
#if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
// Coopmat, fp32acc and fp16acc
matmul_shaders(true, matmul_id, true, false, false);
matmul_shaders(true, matmul_id, true, false, true);
// Coopmat, fp32acc and fp16acc
matmul_shaders(true, matmul_id_type, true, false, false);
matmul_shaders(true, matmul_id_type, true, false, true);
#endif

#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
// Coopmat2, fp32acc and fp16acc
matmul_shaders(true, matmul_id, false, true, false);
matmul_shaders(true, matmul_id, false, true, true);
// Coopmat2, fp32acc and fp16acc
matmul_shaders(true, matmul_id_type, false, true, false);
matmul_shaders(true, matmul_id_type, false, true, true);
#endif
}
}

// flash attention
for (const auto& f16acc : {false, true}) {
std::string acctype = f16acc ? "float16_t" : "float";
std::string acctypev4 = f16acc ? "f16vec4" : "vec4";
std::map<std::string, std::string> fa_base_dict = base_dict;
fa_base_dict["ACC_TYPE"] = f16acc ? "float16_t" : "float";
fa_base_dict["ACC_TYPEV4"] = f16acc ? "f16vec4" : "vec4";
if (f16acc) {
fa_base_dict["ACC_TYPE_MAX"] = "\"float16_t(65504.0)\"";
}

for (const auto& tname : type_names) {
if (tname == "f32") {

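Mirror of the new gating in the loop above (a sketch): coopmat and coopmat2 matmul variants are only generated when the mul_mat_id type is NONE or SUBGROUP.

```cpp
enum class MatMulIdType { NONE, DEFAULT, SUBGROUP };

bool builds_coopmat_variants(MatMulIdType t) {
    // DEFAULT mul_mat_id stays on the scalar fp16/fp32 pipeline only.
    return t != MatMulIdType::DEFAULT;
}
```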
@ -436,30 +531,30 @@ void process_shaders() {
#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
if (tname == "f16") {
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp",
merge_maps(base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}}), true, false, true, f16acc);
merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}}), true, false, true, f16acc);
} else {
std::string data_a_key = "DATA_A_" + to_uppercase(tname);
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp",
merge_maps(base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}, {"DEQUANTFUNC", "dequantFunc"+to_uppercase(tname) }, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, false, true, f16acc);
merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"DEQUANTFUNC", "dequantFunc"+to_uppercase(tname) }, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, false, true, f16acc);
}
#endif
#if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
if (tname == "f16") {
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp",
merge_maps(base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}, {"ACC_TYPEV4", acctypev4}, {"COOPMAT", "1"}}), true, true, false, f16acc);
merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"COOPMAT", "1"}}), true, true, false, f16acc);
} else if (tname == "q4_0" || tname == "q8_0") {
std::string data_a_key = "DATA_A_" + to_uppercase(tname);
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp",
merge_maps(base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}, {"ACC_TYPEV4", acctypev4}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname)}, {"COOPMAT", "1"}}), true, true, false, f16acc);
merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname)}, {"COOPMAT", "1"}}), true, true, false, f16acc);
}
#endif
if (tname == "f16") {
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp",
merge_maps(base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}}), true, false, false, f16acc);
merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}}), true, false, false, f16acc);
} else if (tname == "q4_0" || tname == "q8_0") {
std::string data_a_key = "DATA_A_" + to_uppercase(tname);
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp",
merge_maps(base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, false, false, f16acc);
merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, false, false, f16acc);
}
}
}

@ -472,23 +567,36 @@ void process_shaders() {
string_to_spv("mul_mat_vec_" + tname + "_f32_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}}));
string_to_spv("mul_mat_vec_" + tname + "_f16_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float16_t"}, {"B_TYPE_VEC2", "f16vec2"}, {"B_TYPE_VEC4", "f16vec4"}, {"D_TYPE", "float"}}));

string_to_spv("mul_mat_vec_" + tname + "_f32_f32_subgroup", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}}));
string_to_spv("mul_mat_vec_" + tname + "_f16_f32_subgroup", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float16_t"}, {"B_TYPE_VEC2", "f16vec2"}, {"B_TYPE_VEC4", "f16vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}}));

string_to_spv("mul_mat_vec_" + tname + "_f32_f32_subgroup_no_shmem", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}}));
string_to_spv("mul_mat_vec_" + tname + "_f16_f32_subgroup_no_shmem", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float16_t"}, {"B_TYPE_VEC2", "f16vec2"}, {"B_TYPE_VEC4", "f16vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}}));

string_to_spv("mul_mat_vec_id_" + tname + "_f32", shader, merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}}));

// mul mat vec with integer dot product
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
if (is_legacy_quant(tname)) {
string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32", "mul_mat_vecq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}}));
string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32_subgroup", "mul_mat_vecq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}}));
string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32_subgroup_no_shmem", "mul_mat_vecq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}}));
}
#endif

// Dequant shaders
if (tname != "f16" && tname != "bf16") {
string_to_spv("dequant_" + tname, "dequant_" + tname + ".comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float16_t"}}));
}

if (!string_ends_with(tname, "_k")) {
shader = (tname == "f32" || tname == "f16" || tname == "bf16") ? "get_rows.comp" : "get_rows_quant.comp";
shader = (tname == "f32" || tname == "f16" || tname == "bf16") ? "get_rows.comp" : "get_rows_quant.comp";

if (tname == "f16") {
string_to_spv("get_rows_" + tname, shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}}));
} else {
string_to_spv("get_rows_" + tname, shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float16_t"}}));
}
string_to_spv("get_rows_" + tname + "_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float"}}));
if (tname == "f16") {
string_to_spv("get_rows_" + tname, shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}}));
} else {
string_to_spv("get_rows_" + tname, shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float16_t"}}));
}
string_to_spv("get_rows_" + tname + "_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float"}}));
}

string_to_spv("mul_mat_vec_p021_f16_f32_subgroup_add", "mul_mat_vec_p021.comp", {{"A_TYPE", "float16_t"}, {"A_TYPE_VEC4", "f16vec4"}, {"B_TYPE", "float"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}});

@ -499,6 +607,7 @@ void process_shaders() {
string_to_spv("norm_f32", "norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
string_to_spv("group_norm_f32", "group_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
string_to_spv("rms_norm_f32", "rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
string_to_spv("rms_norm_partials_f32", "rms_norm_partials.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
string_to_spv("rms_norm_back_f32", "rms_norm_back.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
string_to_spv("l2_norm_f32", "l2_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));

@ -508,10 +617,14 @@ void process_shaders() {
string_to_spv("cpy_f16_f32", "copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}});
string_to_spv("cpy_f32_bf16","copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "uint16_t"}, {"DATA_D_BF16", "1"}});
string_to_spv("contig_cpy_f32_f32", "contig_copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("contig_cpy_f32_i32", "contig_copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "int"}});
string_to_spv("contig_cpy_i32_f32", "contig_copy.comp", {{"A_TYPE", "int"}, {"D_TYPE", "float"}});
string_to_spv("contig_cpy_f32_f16", "contig_copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}});
string_to_spv("contig_cpy_f16_f16", "contig_copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}});
string_to_spv("contig_cpy_f16_f32", "contig_copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}});
string_to_spv("contig_cpy_f32_bf16","contig_copy.comp",{{"A_TYPE", "float"}, {"D_TYPE", "uint16_t"}, {"DATA_D_BF16", "1"}});
string_to_spv("cpy_f32_i32", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "int"}});
string_to_spv("cpy_i32_f32", "copy.comp", {{"A_TYPE", "int"}, {"D_TYPE", "float"}});

for (std::string t : {"q4_0", "q4_1", "q5_0", "q5_1", "q8_0", "iq4_nl"}) {
string_to_spv("cpy_f32_" + t, "copy_to_quant.comp", {{"DATA_A_" + to_uppercase(t), "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});

@ -520,8 +633,10 @@ void process_shaders() {
}

for (std::string t : {"f32", "f16", "bf16", "q4_0", "q4_1", "q5_0", "q5_1", "q8_0", "iq4_nl"}) {
string_to_spv("set_rows_" + t, "copy_to_quant.comp", {{"SET_ROWS", "1"}, {"DATA_A_" + to_uppercase(t), "1"}, {"B_TYPE", "uvec2"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
string_to_spv("set_rows_" + t + "_rte", "copy_to_quant.comp", {{"SET_ROWS", "1"}, {"DATA_A_" + to_uppercase(t), "1"}, {"B_TYPE", "uvec2"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}});
string_to_spv("set_rows_" + t + "_i32", "copy_to_quant.comp", {{"SET_ROWS", "1"}, {"DATA_A_" + to_uppercase(t), "1"}, {"B_TYPE", "uint"}, {"B_SIZE", "32"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
string_to_spv("set_rows_" + t + "_i32_rte", "copy_to_quant.comp", {{"SET_ROWS", "1"}, {"DATA_A_" + to_uppercase(t), "1"}, {"B_TYPE", "uint"}, {"B_SIZE", "32"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}});
string_to_spv("set_rows_" + t + "_i64", "copy_to_quant.comp", {{"SET_ROWS", "1"}, {"DATA_A_" + to_uppercase(t), "1"}, {"B_TYPE", "uvec2"}, {"B_SIZE", "64"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
string_to_spv("set_rows_" + t + "_i64_rte", "copy_to_quant.comp", {{"SET_ROWS", "1"}, {"DATA_A_" + to_uppercase(t), "1"}, {"B_TYPE", "uvec2"}, {"B_SIZE", "64"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}});
}

auto get_type_str = [](bool f16) {

@ -534,13 +649,15 @@ void process_shaders() {
s += std::string(dst_f16 ? "_f16" : "_f32");
return s;
};
for (std::string op : {"add", "sub", "mul", "div"}) {
for (std::string op : {"add", "sub", "mul", "div", "add_rms", }) {
for (auto src0_f16 : {false, true}) {
for (auto src1_f16 : {false, true}) {
for (auto dst_f16 : {false, true}) {
for (auto rte : {false, true}) {
auto source = op == "add_rms" ? std::string("add") : op;
auto name = op + get_suffix(src0_f16, src1_f16, dst_f16) + (rte ? "_rte" : "");
string_to_spv(name.c_str(), op + ".comp", {{"A_TYPE", get_type_str(src0_f16)}, {"B_TYPE", get_type_str(src1_f16)}, {"D_TYPE", get_type_str(dst_f16)}, {"FLOAT_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}});
auto add_rms = op == "add_rms" ? "1" : "0";
string_to_spv(name.c_str(), source + ".comp", {{"A_TYPE", get_type_str(src0_f16)}, {"B_TYPE", get_type_str(src1_f16)}, {"D_TYPE", get_type_str(dst_f16)}, {"FLOAT_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}, {"ADD_RMS" , add_rms}});
}
}
}

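Name and source expansion for the new add_rms entries (a sketch: add_rms reuses add.comp with ADD_RMS=1; the suffix below is what get_suffix would build for src0=f16, src1=f32, dst=f32):

```cpp
#include <cassert>
#include <string>

int main() {
    const std::string op = "add_rms", suffix = "_f16_f32_f32";
    const bool rte = true;
    const std::string source = (op == "add_rms") ? "add" : op;     // shader file stem
    const std::string name   = op + suffix + (rte ? "_rte" : "");  // SPIR-V variant name
    assert(name == "add_rms_f16_f32_f32_rte");
    assert(source + ".comp" == "add.comp");
    return 0;
}
```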
@ -553,7 +670,12 @@ void process_shaders() {

string_to_spv("split_k_reduce", "mul_mat_split_k_reduce.comp", {});
string_to_spv("fa_split_k_reduce", "flash_attn_split_k_reduce.comp", {});

string_to_spv("quantize_q8_1", "quantize_q8_1.comp", {});
string_to_spv("quantize_q8_1_subgroup", "quantize_q8_1.comp", {{"USE_SUBGROUPS", "1"}});

string_to_spv("quantize_q8_1_x4", "quantize_q8_1.comp", {{"QBLOCK_X4", "1"}});
string_to_spv("quantize_q8_1_x4_subgroup", "quantize_q8_1.comp", {{"QBLOCK_X4", "1"}, {"USE_SUBGROUPS", "1"}});

string_to_spv("mul_f32", "mul.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});

@ -566,6 +688,8 @@ void process_shaders() {

string_to_spv("sqr_f32", "square.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});

string_to_spv("sqrt_f32", "sqrt.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});

string_to_spv("sin_f32", "sin.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});

string_to_spv("cos_f32", "cos.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});

@ -580,6 +704,11 @@ void process_shaders() {

string_to_spv("upscale_f32", "upscale.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});

for (auto rte : {false, true}) {
std::string suffix = rte ? "_rte" : "";
string_to_spv("exp_f16" + suffix, "exp.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}});
string_to_spv("exp_f32" + suffix, "exp.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"} , {"RTE16", rte ? "1" : "0"}});
}
string_to_spv("gelu_f16", "gelu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
string_to_spv("gelu_f32", "gelu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("gelu_erf_f16", "gelu_erf.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});

@ -594,6 +723,10 @@ void process_shaders() {
string_to_spv("tanh_f32", "tanh.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("sigmoid_f16", "sigmoid.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
string_to_spv("sigmoid_f32", "sigmoid.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("hardsigmoid_f16","hardsigmoid.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
string_to_spv("hardsigmoid_f32","hardsigmoid.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("hardswish_f16", "hardswish.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
string_to_spv("hardswish_f32", "hardswish.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});

for (auto rte : {false, true}) {
std::string suffix = rte ? "_rte" : "";

@ -642,9 +775,15 @@ void process_shaders() {
string_to_spv("sum_rows_f32", "sum_rows.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
string_to_spv("count_equal_i32", "count_equal.comp", merge_maps(base_dict, {{"A_TYPE", "int"}, {"B_TYPE", "int"}, {"D_TYPE", "int"}}));

string_to_spv("im2col_f32", "im2col.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
string_to_spv("im2col_f32_f16", "im2col.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}}));
string_to_spv("im2col_f32_f16_rte", "im2col.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}}));
for (std::string dim_str : {"", "_3d"}) {
for (bool bda : {false, true}) {
std::string bda_str = bda ? "_bda" : "";
std::string bda_def = bda ? "1" : "0";
string_to_spv("im2col" + dim_str + "_f32" + bda_str, "im2col" + dim_str + ".comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"D_SIZE", "4"}, {"BDA", bda_def}}));
string_to_spv("im2col" + dim_str + "_f32_f16" + bda_str, "im2col" + dim_str + ".comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"D_SIZE", "2"}, {"BDA", bda_def}}));
string_to_spv("im2col" + dim_str + "_f32_f16_rte" + bda_str, "im2col" + dim_str + ".comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"D_SIZE", "2"}, {"RTE16", "1"}, {"BDA", bda_def}}));
}
}

string_to_spv("timestep_embedding_f32", "timestep_embedding.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));

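The nested loop expands to 2 x 2 x 3 = 12 im2col variants; a sketch that prints the generated names:

```cpp
#include <cstdio>
#include <string>

int main() {
    for (std::string dim : {"", "_3d"})
        for (bool bda : {false, true})
            for (std::string v : {"_f32", "_f32_f16", "_f32_f16_rte"})
                std::printf("im2col%s%s%s\n", dim.c_str(), v.c_str(),
                            bda ? "_bda" : "");  // e.g. im2col_3d_f32_f16_bda
    return 0;
}
```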
@ -657,25 +796,41 @@ void process_shaders() {
string_to_spv("rwkv_wkv7_f32", "wkv7.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));

string_to_spv("opt_step_adamw_f32", "opt_step_adamw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
string_to_spv("opt_step_sgd_f32", "opt_step_sgd.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));

string_to_spv("conv2d_f32_unroll", "conv2d_mm.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"USE_COLLECTIVES", "1"}, {"UNROLL", "[[unroll]]"}});
string_to_spv("conv2d_f16_f32_unroll", "conv2d_mm.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"USE_COLLECTIVES", "1"}, {"UNROLL", "[[unroll]]"}});

string_to_spv("conv2d_f32", "conv2d_mm.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"USE_COLLECTIVES", "1"}, {"UNROLL", ""}});
string_to_spv("conv2d_f16_f32", "conv2d_mm.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"USE_COLLECTIVES", "1"}, {"UNROLL", ""}});

for (auto transpose : {false, true}) {
for (auto unroll : {false, true}) {
for (auto a_f16 : {false, true}) {
std::map<std::string, std::string> defines = {
{"A_TYPE", a_f16 ? "float16_t" : "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"},
{"USE_COLLECTIVES", "1"}, {"UNROLL", unroll ? "[[unroll]]" : ""},
};
if (transpose) defines["TRANSPOSE"] = "1";
std::string name = std::string(transpose ? "conv_transpose_2d": "conv2d")
+ (a_f16 ? "_f16" : "") + "_f32";
string_to_spv(name + (unroll ? "_unroll" : ""), "conv2d_mm.comp", defines);
#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
string_to_spv("conv2d_f32", "conv2d_mm.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"USE_COLLECTIVES", "1"}, {"UNROLL", "[[unroll]]"}, {"COOPMAT2", "1"}}, true, false, true);
string_to_spv("conv2d_f16_f32", "conv2d_mm.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"USE_COLLECTIVES", "1"}, {"UNROLL", "[[unroll]]"}, {"COOPMAT2", "1"}}, true, false, true);
if (unroll) {
defines["COOPMAT2"] = "1";
string_to_spv(name, "conv2d_mm.comp", defines, true, false, true);
}
#endif
}
}
}

string_to_spv("conv2d_dw_whcn_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"WHCN", "1"}}));
string_to_spv("conv2d_dw_cwhn_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"CWHN", "1"}}));
string_to_spv("conv2d_dw_whcn_f16_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"WHCN", "1"}}));
string_to_spv("conv2d_dw_cwhn_f16_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"CWHN", "1"}}));

string_to_spv("roll_f32", "roll.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));

string_to_spv("add_id_f32", "add_id.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));

string_to_spv("multi_add_f32", "multi_add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}, {"ADD_RMS" , "0"}});
string_to_spv("multi_add_rms_f32", "multi_add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}, {"ADD_RMS" , "1"}});

for (auto &c : compiles) {
c.wait();
}

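The consolidated loop covers the four hand-written conv2d variants and adds transposed versions; a sketch of the names it produces (coopmat2 builds reuse the unsuffixed name):

```cpp
#include <cstdio>
#include <string>

int main() {
    for (bool transpose : {false, true})
        for (bool unroll : {false, true})
            for (bool a_f16 : {false, true}) {
                std::string name = std::string(transpose ? "conv_transpose_2d" : "conv2d")
                                 + (a_f16 ? "_f16" : "") + "_f32";
                // e.g. conv2d_f32_unroll, conv_transpose_2d_f16_f32
                std::printf("%s%s\n", name.c_str(), unroll ? "_unroll" : "");
            }
    return 0;
}
```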
@ -732,7 +887,7 @@ void write_output_files() {
}

std::string suffixes[2] = {"_f32", "_f16"};
for (const char *op : {"add", "sub", "mul", "div"}) {
for (const char *op : {"add", "sub", "mul", "div", "add_rms"}) {
fprintf(hdr, "extern unsigned char *%s_data[2][2][2][2];\n", op);
fprintf(hdr, "extern uint64_t %s_len[2][2][2][2];\n", op);
std::string data = "unsigned char *" + std::string(op) + "_data[2][2][2][2] = ";

@ -784,6 +939,27 @@ void write_output_files() {
fputs(data.c_str(), src);
fputs(len.c_str(), src);
}

std::vector<std::string> btypes = {"f16", "f32"};

#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
btypes.push_back("q8_1");
#endif

for (const std::string& btype : btypes) {
for (const auto& tname : type_names) {
if (btype == "q8_1" && !is_legacy_quant(tname)) {
continue;
}
fprintf(hdr, "extern unsigned char *arr_dmmv_%s_%s_f32_data[3];\n", tname.c_str(), btype.c_str());
fprintf(hdr, "extern uint64_t arr_dmmv_%s_%s_f32_len[3];\n", tname.c_str(), btype.c_str());
std::string data = "unsigned char *arr_dmmv_" + tname + "_" + btype + "_f32_data[3] = {mul_mat_vec_" + tname + "_" + btype + "_f32_data, mul_mat_vec_" + tname + "_" + btype + "_f32_subgroup_data, mul_mat_vec_" + tname + "_" + btype + "_f32_subgroup_no_shmem_data};\n";
std::string len = "uint64_t arr_dmmv_" + tname + "_" + btype + "_f32_len[3] = {mul_mat_vec_" + tname + "_" + btype + "_f32_len, mul_mat_vec_" + tname + "_" + btype + "_f32_subgroup_len, mul_mat_vec_" + tname + "_" + btype + "_f32_subgroup_no_shmem_len};\n";
fputs(data.c_str(), src);
fputs(len.c_str(), src);
}
}

fclose(hdr);
fclose(src);
}

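Shape of the generated dmmv lookup tables (a sketch; index order follows the initializers above): entry 0 is the base shader, 1 the subgroup variant, 2 the subgroup_no_shmem variant. Hypothetical runtime-side usage, compiled against the generated header:

```cpp
#include <cstdint>

enum DmmvVariant { DMMV_BASE = 0, DMMV_SUBGROUP = 1, DMMV_SUBGROUP_NO_SHMEM = 2 };

// Emitted by write_output_files(); the q4_0/f32 pair is just one example.
extern unsigned char *arr_dmmv_q4_0_f32_f32_data[3];
extern uint64_t       arr_dmmv_q4_0_f32_f32_len[3];

const unsigned char *pick_dmmv(DmmvVariant v) {
    return arr_dmmv_q4_0_f32_f32_data[v];
}
```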