Update Vulkan Code to de4c07f93783a1a96456a44dc16b9db538ee1618

Inforithmics 2025-08-10 16:01:07 +02:00
parent bc5c3fb213
commit 2e7452be71
42 changed files with 4607 additions and 809 deletions

View File

@ -32,8 +32,10 @@ if (Vulkan_FOUND)
if (${glslc_error} MATCHES ".*extension not supported: GL_KHR_cooperative_matrix.*")
message(STATUS "GL_KHR_cooperative_matrix not supported by glslc")
set(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT OFF)
else()
message(STATUS "GL_KHR_cooperative_matrix supported by glslc")
set(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT ON)
add_compile_definitions(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
endif()
@ -46,11 +48,45 @@ if (Vulkan_FOUND)
if (${glslc_error} MATCHES ".*extension not supported: GL_NV_cooperative_matrix2.*")
message(STATUS "GL_NV_cooperative_matrix2 not supported by glslc")
set(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT OFF)
else()
message(STATUS "GL_NV_cooperative_matrix2 supported by glslc")
set(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT ON)
add_compile_definitions(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
endif()
# Compile a test shader to determine whether GL_EXT_integer_dot_product is supported.
# If it's not, there will be an error to stderr.
# If it's supported, set a define to indicate that we should compile those shaders
execute_process(COMMAND ${Vulkan_GLSLC_EXECUTABLE} -o - -fshader-stage=compute --target-env=vulkan1.3 "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/test_integer_dot_support.comp"
OUTPUT_VARIABLE glslc_output
ERROR_VARIABLE glslc_error)
if (${glslc_error} MATCHES ".*extension not supported: GL_EXT_integer_dot_product.*")
message(STATUS "GL_EXT_integer_dot_product not supported by glslc")
set(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT OFF)
else()
message(STATUS "GL_EXT_integer_dot_product supported by glslc")
set(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT ON)
add_compile_definitions(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
endif()
# Compile a test shader to determine whether GL_EXT_bfloat16 is supported.
# If it's not, there will be an error to stderr.
# If it's supported, set a define to indicate that we should compile those shaders
execute_process(COMMAND ${Vulkan_GLSLC_EXECUTABLE} -o - -fshader-stage=compute --target-env=vulkan1.3 "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/test_bfloat16_support.comp"
OUTPUT_VARIABLE glslc_output
ERROR_VARIABLE glslc_error)
if (${glslc_error} MATCHES ".*extension not supported: GL_EXT_bfloat16.*")
message(STATUS "GL_EXT_bfloat16 not supported by glslc")
set(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT OFF)
else()
message(STATUS "GL_EXT_bfloat16 supported by glslc")
set(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT ON)
add_compile_definitions(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
endif()
target_link_libraries(ggml-vulkan PRIVATE Vulkan::Vulkan)
target_include_directories(ggml-vulkan PRIVATE ${CMAKE_CURRENT_BINARY_DIR})
@ -119,6 +155,10 @@ if (Vulkan_FOUND)
SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders
CMAKE_ARGS -DCMAKE_TOOLCHAIN_FILE=${HOST_CMAKE_TOOLCHAIN_FILE}
-DCMAKE_INSTALL_PREFIX=${CMAKE_BINARY_DIR}
-DGGML_VULKAN_COOPMAT_GLSLC_SUPPORT=${GGML_VULKAN_COOPMAT_GLSLC_SUPPORT}
-DGGML_VULKAN_COOPMAT2_GLSLC_SUPPORT=${GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT}
-DGGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT=${GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT}
-DGGML_VULKAN_BFLOAT16_GLSLC_SUPPORT=${GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT}
BUILD_COMMAND ${CMAKE_COMMAND} --build .
INSTALL_COMMAND ${CMAKE_COMMAND} --install .
INSTALL_DIR ${CMAKE_BINARY_DIR}
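
The probe shaders referenced in the feature checks above (test_integer_dot_support.comp, test_bfloat16_support.comp) are part of this commit but not shown in this excerpt. A minimal sketch of what such a probe plausibly contains, assuming the check relies only on glslc rejecting an unknown extension, is:

    #version 460
    // Only the extension directive matters: if glslc does not know it,
    // compilation fails with "extension not supported: GL_EXT_integer_dot_product"
    // on stderr, which the MATCHES check above turns into *_GLSLC_SUPPORT OFF.
    #extension GL_EXT_integer_dot_product : require

    layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;

    void main() {}

The bfloat16 probe would have the same shape with GL_EXT_bfloat16 in the #extension line.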

View File

@ -0,0 +1,15 @@
set(CMAKE_BUILD_TYPE Release)
set(CMAKE_C_FLAGS -O2)
set(CMAKE_CXX_FLAGS -O2)
set(CMAKE_FIND_ROOT_PATH_MODE_PROGRAM NEVER)
set(CMAKE_FIND_ROOT_PATH_MODE_LIBRARY NEVER)
set(CMAKE_FIND_ROOT_PATH_MODE_INCLUDE NEVER)
set(CMAKE_C_COMPILER "@HOST_C_COMPILER@")
set(CMAKE_CXX_COMPILER "@HOST_CXX_COMPILER@")
set(CMAKE_RUNTIME_OUTPUT_DIRECTORY @CMAKE_RUNTIME_OUTPUT_DIRECTORY@)
if("@CMAKE_C_COMPILER_ID@" STREQUAL "MSVC")
foreach(CONFIG IN ITEMS DEBUG RELEASE MINSIZEREL RELWITHDEBINFO)
set(CMAKE_RUNTIME_OUTPUT_DIRECTORY_${CONFIG} ${CMAKE_RUNTIME_OUTPUT_DIRECTORY})
endforeach()
endif()

File diff suppressed because it is too large.

View File

@ -1,9 +1,20 @@
find_package (Threads REQUIRED)
find_program(GLSLC_EXECUTABLE glslc)
if(NOT GLSLC_EXECUTABLE)
message(FATAL_ERROR "glslc not found.")
endif()
cmake_minimum_required(VERSION 3.19)
project("vulkan-shaders-gen" C CXX)
find_package (Threads REQUIRED)
if (GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
add_compile_definitions(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
endif()
if (GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
add_compile_definitions(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
endif()
if (GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
add_compile_definitions(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
endif()
if (GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
add_compile_definitions(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
endif()
set(TARGET vulkan-shaders-gen)
add_executable(${TARGET} vulkan-shaders-gen.cpp)
install(TARGETS ${TARGET} RUNTIME)

View File

@ -18,7 +18,11 @@ void main() {
// fast path for when all four iterations are in-bounds
if (idx + (num_iter-1)*num_threads < p.ne) {
[[unroll]] for (uint i = 0; i < num_iter; ++i) {
#ifndef OPTIMIZATION_ERROR_WORKAROUND
#if defined(DATA_D_BF16)
float f = float(data_a[get_aoffset() + idx]);
data_d[get_doffset() + idx] = D_TYPE(fp32_to_bf16(f));
#elif !defined(OPTIMIZATION_ERROR_WORKAROUND)
data_d[get_doffset() + idx] = D_TYPE(data_a[get_aoffset() + idx]);
#else
data_d[get_doffset() + idx] = data_a[get_aoffset() + idx];
@ -31,7 +35,10 @@ void main() {
continue;
}
#ifndef OPTIMIZATION_ERROR_WORKAROUND
#if defined(DATA_D_BF16)
float f = float(data_a[get_aoffset() + idx]);
data_d[get_doffset() + idx] = D_TYPE(fp32_to_bf16(f));
#elif !defined(OPTIMIZATION_ERROR_WORKAROUND)
data_d[get_doffset() + idx] = D_TYPE(data_a[get_aoffset() + idx]);
#else
data_d[get_doffset() + idx] = data_a[get_aoffset() + idx];
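
The DATA_D_BF16 paths above rely on fp32_to_bf16 (and, in other shaders, bf16_to_fp32) from types.comp, which is not part of this excerpt. As a hedged sketch, assuming the helpers use the usual truncate-with-round-to-nearest-even scheme, they amount to:

    // bf16 keeps the sign, exponent and top 7 mantissa bits of an fp32 value.
    uint fp32_to_bf16(const float f) {
        const uint u = floatBitsToUint(f);
        // add a rounding bias so the truncation rounds to nearest, ties to even
        return (u + (0x7fff + ((u >> 16) & 1))) >> 16;
    }

    float bf16_to_fp32(const uint u) {
        // shift back into the high half of an fp32 bit pattern
        return uintBitsToFloat(u << 16);
    }

The real helpers may use explicit 16/32-bit storage types; the bit manipulation is the point here.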

View File

@ -0,0 +1,105 @@
#version 450
#include "types.comp"
layout (push_constant) uniform parameter
{
uint ne;
uint batches;
uint channels;
uint dst_w;
uint dst_h;
uint src_w;
uint src_h;
uint knl_w;
uint knl_h;
int stride_x;
int stride_y;
int pad_x;
int pad_y;
int dilation_x;
int dilation_y;
} p;
layout (binding = 0) readonly buffer A {A_TYPE knl_data[];};
layout (binding = 1) readonly buffer B {B_TYPE src_data[];};
layout (binding = 2) writeonly buffer D {D_TYPE dst_data[];};
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
FLOAT_TYPE conv_2d_dw_whcn(uint idx) {
uint i0 = idx / p.dst_w;
uint dst_x = idx - i0 * p.dst_w;
uint i1 = i0 / p.dst_h;
uint dst_y = i0 - i1 * p.dst_h;
uint n = i1 / p.channels;
uint c = i1 - n * p.channels;
uint src_i = n * p.channels * p.src_h * p.src_w + c * p.src_h * p.src_w;
uint knl_i = c * p.knl_h * p.knl_w;
FLOAT_TYPE sum = 0.0;
for (uint knl_y = 0; knl_y < p.knl_h; ++knl_y) {
uint src_y = dst_y * p.stride_y + knl_y * p.dilation_y - p.pad_y;
if (src_y >= p.src_h) { // src_y < 0 will wrap to a large unsigned int
continue;
}
for (uint knl_x = 0; knl_x < p.knl_w; ++knl_x) {
uint src_x = dst_x * p.stride_x + knl_x * p.dilation_x - p.pad_x;
if (src_x >= p.src_w) { // src_x < 0 will wrap to a large unsigned int
continue;
}
FLOAT_TYPE v = FLOAT_TYPE(src_data[src_i + src_y * p.src_w + src_x]);
FLOAT_TYPE k = FLOAT_TYPE(knl_data[knl_i + knl_y * p.knl_w + knl_x]);
sum = fma(v, k, sum);
}
}
return sum;
}
FLOAT_TYPE conv_2d_dw_cwhn(uint idx) {
uint i0 = idx / p.channels;
uint c = idx - i0 * p.channels;
uint i1 = i0 / p.dst_w;
uint dst_x = i0 - i1 * p.dst_w;
uint n = i1 / p.dst_h;
uint dst_y = i1 - n * p.dst_h;
uint src_i = n * p.channels * p.src_h * p.src_w;
uint src_row = p.src_w * p.channels;
uint knl_row = p.knl_w * p.channels;
FLOAT_TYPE sum = 0.0;
for (uint knl_y = 0; knl_y < p.knl_h; ++knl_y) {
uint src_y = dst_y * p.stride_y + knl_y * p.dilation_y - p.pad_y;
if (src_y >= p.src_h) { // src_y < 0 will wrap to a large unsigned int
continue;
}
for (uint knl_x = 0; knl_x < p.knl_w; ++knl_x) {
uint src_x = dst_x * p.stride_x + knl_x * p.dilation_x - p.pad_x;
if (src_x >= p.src_w) { // src_x < 0 will wrap to a large unsigned int
continue;
}
FLOAT_TYPE v = FLOAT_TYPE(src_data[src_i + src_y * src_row + src_x * p.channels + c]);
FLOAT_TYPE k = FLOAT_TYPE(knl_data[ knl_y * knl_row + knl_x * p.channels + c]);
sum = fma(v, k, sum);
}
}
return sum;
}
void main() {
uint idx = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
if (idx >= p.ne) {
return;
}
FLOAT_TYPE result =
#ifdef WHCN
conv_2d_dw_whcn(idx);
#else
conv_2d_dw_cwhn(idx);
#endif
dst_data[idx] = D_TYPE(result);
}
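
As a small worked example of the WHCN index decomposition above (values chosen for illustration only):

    // dst_w = 4, dst_h = 3, channels = 2, idx = 30
    //   i0 = 30 / 4 = 7,   dst_x = 30 - 7*4 = 2
    //   i1 =  7 / 3 = 2,   dst_y =  7 - 2*3 = 1
    //   n  =  2 / 2 = 1,   c     =  2 - 1*2 = 0
    // -> output element (n = 1, c = 0, dst_y = 1, dst_x = 2)
    // The CWHN variant is the same idea with the channel as the fastest-varying index.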

View File

@ -12,7 +12,10 @@ void main() {
return;
}
#ifndef OPTIMIZATION_ERROR_WORKAROUND
#if defined(DATA_D_BF16)
float f = float(data_a[get_aoffset() + src0_idx(idx)]);
data_d[get_doffset() + dst_idx(idx)] = D_TYPE(fp32_to_bf16(f));
#elif !defined(OPTIMIZATION_ERROR_WORKAROUND)
data_d[get_doffset() + dst_idx(idx)] = D_TYPE(data_a[get_aoffset() + src0_idx(idx)]);
#else
data_d[get_doffset() + dst_idx(idx)] = data_a[get_aoffset() + src0_idx(idx)];

View File

@ -1,5 +1,10 @@
#version 450
#if RTE16
#extension GL_EXT_spirv_intrinsics : enable
spirv_execution_mode(capabilities = [4467], 4462, 16); // RoundingModeRTE, 16 bits
#endif // RTE16
#include "types.comp"
#include "generic_unary_head.comp"

View File

@ -23,6 +23,12 @@ vec2 dequantize(uint ib, uint iqs, uint a_offset) {
}
#endif
#if defined(DATA_A_BF16)
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
return vec2(bf16_to_fp32(data_a[a_offset + ib]), bf16_to_fp32(data_a[a_offset + ib + 1]));
}
#endif
#if defined(DATA_A_Q4_0)
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
const uint vui = uint(data_a[a_offset + ib].qs[iqs]);
@ -82,9 +88,9 @@ vec2 dequantize(uint ib, uint iqs, uint a_offset) {
return vec2(int(data_a[a_offset + ib].qs[iqs]), int(data_a[a_offset + ib].qs[iqs + 1]));
}
vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
uint32_t v0 = data_a_packed16[a_offset + ib].qs[iqs/2];
uint32_t v1 = data_a_packed16[a_offset + ib].qs[iqs/2 + 1];
return vec4(int8_t(v0 & 0xFF), int8_t(v0 >> 8), int8_t(v1 & 0xFF), int8_t(v1 >> 8));
const i8vec2 v0 = unpack8(int32_t(data_a_packed16[a_offset + ib].qs[iqs/2])).xy; // vec4 used due to #12147
const i8vec2 v1 = unpack8(int32_t(data_a_packed16[a_offset + ib].qs[iqs/2 + 1])).xy;
return vec4(v0.x, v0.y, v1.x, v1.y);
}
#endif
@ -428,7 +434,7 @@ vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
}
#endif
#if defined(DATA_A_F32) || defined(DATA_A_F16)
#if defined(DATA_A_F32) || defined(DATA_A_F16) || defined(DATA_A_BF16)
vec2 get_dm(uint ib, uint a_offset) {
return vec2(0, 0);
}

View File

@ -92,7 +92,7 @@ float16_t dequantFuncQ8_0(const in decodeBufQ8_0 bl, const in uint blockCoords[2
const uint iqs = idx;
// Load 16b and select the byte for this element
int32_t qs = unpack8(int32_t(bl.block.qs[(iqs & 0x1E) >> 1]))[iqs & 1];
int32_t qs = unpack8(bl.block.qs[(iqs & 0x1E) >> 1])[iqs & 1];
float16_t ret = float16_t(qs) * d;
return ret;
}
@ -167,6 +167,101 @@ layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ4
block_q4_K_packed128 block;
};
#if defined(IS_MUL_MM2)
// For Q4_K and Q5_K in the mat-mul shader, we decode a tile's worth of scales
// into shared memory and then process the whole tile using those scales.
// There is a fetch function that loads into private variables and then a store
// function that stores into shared memory.
// Q4_K and Q5_K have the same encoding of scales, so everything is shared except
// the part that fetches from the structure (which has a different block layout).
#if defined(DATA_A_Q4_K) || defined(DATA_A_Q5_K)
const uint shAscales_stride = (BM + 2);
// 1 scale per 32 elements -> 8 scales per block, per row
shared vec2 shAscales[8 * shAscales_stride];
uvec4 row_v;
#endif
#if defined(DATA_A_Q4_K)
layout (binding = 0) readonly buffer A_Q4_K_128 {block_q4_K_packed128 data_a_q4_k_packed128[];};
void fetch_scalesQ4_K(uint ir_BM, uint pos_a, uint stride_a, uint block_k, uint tid, bool in_bounds)
{
uint tids_per_row = BLOCK_SIZE / BM;
uint is_per_tid = 8 / tids_per_row;
uint is_start = is_per_tid * (tid % tids_per_row);
uint tid_row = tid / tids_per_row;
uint row = ir_BM + tid_row;
uint block_index = pos_a + row * stride_a + (block_k / QUANT_K);
if (in_bounds || row < p.M) {
row_v = data_a_q4_k_packed128[block_index].q4k[0];
}
}
#endif
#if defined(DATA_A_Q5_K)
layout (binding = 0) readonly buffer A_Q5_K_128 {block_q5_K_packed128 data_a_q5_k_packed128[];};
void fetch_scalesQ5_K(uint ir_BM, uint pos_a, uint stride_a, uint block_k, uint tid, bool in_bounds)
{
uint tids_per_row = BLOCK_SIZE / BM;
uint is_per_tid = 8 / tids_per_row;
uint is_start = is_per_tid * (tid % tids_per_row);
uint tid_row = tid / tids_per_row;
uint row = ir_BM + tid_row;
uint block_index = pos_a + row * stride_a + (block_k / QUANT_K);
if (in_bounds || row < p.M) {
row_v = data_a_q5_k_packed128[block_index].q5k[0];
}
}
#endif
#if defined(DATA_A_Q4_K) || defined(DATA_A_Q5_K)
void store_scalesQ4_K(uint tid)
{
barrier();
uint tids_per_row = BLOCK_SIZE / BM;
uint is_per_tid = 8 / tids_per_row;
uint is_start = is_per_tid * (tid % tids_per_row);
uint tid_row = tid / tids_per_row;
[[unroll]] for (uint idx = 0; idx < is_per_tid; ++idx) {
uint is = idx + is_start;
uvec4 v = row_v;
const vec2 loadd = vec2(unpackFloat2x16(v.x));
uint32_t sc;
uint32_t mbyte;
uint32_t scale0 = v.y;
uint32_t scale4 = v.z;
uint32_t scale8 = v.w;
uint32_t sc_lo = scale0;
uint32_t mb_lo = scale4;
uint32_t sc_hi = (scale8 & 0x0F0F0F0F) | ((scale0 & 0xC0C0C0C0) >> 2);
uint32_t mb_hi = ((scale8 & 0xF0F0F0F0) >> 4) | ((scale4 & 0xC0C0C0C0) >> 2);
sc = is < 4 ? sc_lo : sc_hi;
mbyte = is < 4 ? mb_lo : mb_hi;
sc = sc >> (8 * (is & 3));
mbyte = mbyte >> (8 * (is & 3));
sc &= 0x3F;
mbyte &= 0x3F;
const float d = loadd.x * float(sc);
const float m = loadd.y * float(mbyte);
shAscales[is * shAscales_stride + tid_row] = vec2(d,m);
}
barrier();
}
#endif
#endif
float16_t dequantFuncQ4_K(const in decodeBufQ4_K bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
decodeBufQ4_K_packed16 bl16 = decodeBufQ4_K_packed16(bl);
@ -176,9 +271,13 @@ float16_t dequantFuncQ4_K(const in decodeBufQ4_K bl, const in uint blockCoords[2
const uint b = (idx & 0x20) >> 5; // 0,1
const uint is = (idx & 0xE0) >> 5; // 0..7
#if defined(IS_MUL_MM2) && defined(DATA_A_Q4_K)
vec2 v = shAscales[is * shAscales_stride + (blockCoords[0] % BM)];
float d = v.x;
float m = v.y;
#else
uvec4 v = bl128.block.q4k[0];
const f16vec2 loadd = unpackFloat2x16(v.x);
const vec2 loadd = vec2(unpackFloat2x16(v.x));
uint32_t sc;
uint32_t mbyte;
@ -199,15 +298,16 @@ float16_t dequantFuncQ4_K(const in decodeBufQ4_K bl, const in uint blockCoords[2
sc &= 0x3F;
mbyte &= 0x3F;
const float16_t d = loadd.x * float16_t(sc);
const float16_t m = loadd.y * float16_t(mbyte);
const float d = loadd.x * float(sc);
const float m = loadd.y * float(mbyte);
#endif
uint qs = uint32_t(bl16.block.qs[((idx & 0xC0) >> 2) + ((idx & 0x1E) >> 1)]);
qs = (qs >> (b * 4 + 8 * (idx & 1))) & 0xF;
float16_t ret = d * float16_t(qs) - m;
float ret = d * float(qs) - m;
return ret;
return float16_t(ret);
}
layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ5_K {
@ -231,6 +331,11 @@ float16_t dequantFuncQ5_K(const in decodeBufQ5_K bl, const in uint blockCoords[2
const uint b = (idx & 0x20) >> 5; // 0,1
const uint is = (idx & 0xE0) >> 5; // 0..7
#if defined(IS_MUL_MM2) && defined(DATA_A_Q5_K)
vec2 v = shAscales[is * shAscales_stride + (blockCoords[0] % BM)];
float d = v.x;
float m = v.y;
#else
uvec4 v = bl128.block.q5k[0];
const f16vec2 loadd = unpackFloat2x16(v.x);
@ -256,6 +361,7 @@ float16_t dequantFuncQ5_K(const in decodeBufQ5_K bl, const in uint blockCoords[2
const float16_t d = loadd.x * float16_t(sc);
const float16_t m = loadd.y * float16_t(mbyte);
#endif
uint qh = uint32_t(bl16.block.qh[(idx & 0x1E) >> 1]);
qh = ((qh >> is) & 0x101) << 4;
@ -264,9 +370,9 @@ float16_t dequantFuncQ5_K(const in decodeBufQ5_K bl, const in uint blockCoords[2
qs = (qs >> (b * 4)) & 0x0F0F;
qs = unpack8(qs | qh)[idx & 1];
float16_t ret = d * (float16_t(qs)) - m;
float ret = d * float(qs) - m;
return ret;
return float16_t(ret);
}
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ6_K {
@ -311,8 +417,8 @@ float16_t dequantFuncIQ1_S(const in decodeBufIQ1_S bl, const in uint blockCoords
const float16_t d = bl.block.d;
const uint idx = coordInBlock[1];
const uint ib32 = idx / 32;
const uint ib8 = idx / 8;
const uint ib32 = (idx & 0xE0) >> 5;
const uint ib8 = (idx & 0xF8) >> 3;
const uint qh = bl.block.qh[ib32];
const uint qs = bl.block.qs[ib8];
@ -330,14 +436,20 @@ layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ1
block_iq1_m block;
};
layout(buffer_reference, std430, buffer_reference_align = 8) buffer decodeBufIQ1_M_packed64 {
block_iq1_m_packed64 block;
};
float16_t dequantFuncIQ1_M(const in decodeBufIQ1_M bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
const u16vec4 scales = u16vec4(bl.block.scales[0], bl.block.scales[1], bl.block.scales[2], bl.block.scales[3]) >> 12;
const float16_t d = uint16BitsToHalf(scales.x | (scales.y << 4) | (scales.z << 8) | (scales.w << 12));
decodeBufIQ1_M_packed64 bl64 = decodeBufIQ1_M_packed64(bl);
const uint idx = coordInBlock[1];
const uint ib8 = idx / 8;
const uint ib16 = idx / 16;
uvec2 scales = unpack32(bl64.block.scales);
const float16_t d = uint16BitsToHalf(uint16_t(((scales.x & 0xF000) >> 12) | ((scales.x & 0xF0000000) >> 24) | ((scales.y & 0xF000) >> 4) | ((scales.y & 0xF0000000) >> 16)));
const uint ib8 = (idx & 0xF8) >> 3;
const uint ib16 = (idx & 0xF0) >> 4;
const int i8 = int(idx % 8);
const uint sc = bl.block.scales[ib8 / 8];
const uint qs = bl.block.qs[ib8];
@ -370,7 +482,7 @@ float16_t dequantFuncIQ2_XXS(const in decodeBufIQ2_XXS bl, const in uint blockCo
const uint ib8 = (idx & 0x18) >> 3; // 0..3
const uint iqs = 8 * ib32 + ib8;
const uint8_t qs = bl.block.qs[iqs];
const uint qs = bl.block.qs[iqs];
const uint signscale = pack32(u16vec2(bl16.block.qs[4*ib32+2], bl16.block.qs[4*ib32+3]));
const float dscale = float(bl.block.d) * 0.25 * (0.5 + float(signscale >> 28));
@ -558,8 +670,12 @@ float16_t dequantFuncIQ4_NL(const in decodeBufIQ4_NL bl, const in uint blockCoor
#define dequantFuncA dequantFuncQ3_K
#elif defined(DATA_A_Q4_K)
#define dequantFuncA dequantFuncQ4_K
#define fetch_scales fetch_scalesQ4_K
#define store_scales store_scalesQ4_K
#elif defined(DATA_A_Q5_K)
#define dequantFuncA dequantFuncQ5_K
#define fetch_scales fetch_scalesQ5_K
#define store_scales store_scalesQ4_K
#elif defined(DATA_A_Q6_K)
#define dequantFuncA dequantFuncQ6_K
#elif defined(DATA_A_IQ1_S)

View File

@ -0,0 +1,483 @@
#version 450
#extension GL_EXT_control_flow_attributes : enable
#extension GL_EXT_shader_16bit_storage : require
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
#extension GL_KHR_shader_subgroup_shuffle : enable
#include "types.comp"
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
layout (constant_id = 1) const uint32_t Br = 1;
layout (constant_id = 2) const uint32_t Bc = 32;
layout (constant_id = 3) const uint32_t D = 32;
layout (constant_id = 5) const uint32_t D_split = 16;
const uint32_t D_per_thread = D / D_split;
const uint32_t cols_per_iter = gl_WorkGroupSize.x / D_split;
const uint32_t cols_per_thread = Bc / cols_per_iter;
layout (push_constant) uniform parameter {
uint32_t N;
uint32_t KV;
uint32_t ne1;
uint32_t ne2;
uint32_t ne3;
uint32_t neq2;
uint32_t neq3;
uint32_t nek2;
uint32_t nek3;
uint32_t nev2;
uint32_t nev3;
uint32_t nem1;
uint32_t nb01;
uint32_t nb02;
uint32_t nb03;
uint32_t nb11;
uint32_t nb12;
uint32_t nb13;
uint32_t nb21;
uint32_t nb22;
uint32_t nb23;
uint32_t nb31;
float scale;
float max_bias;
float logit_softcap;
uint32_t mask;
uint32_t n_head_log2;
float m0;
float m1;
uint32_t gqa_ratio;
uint32_t split_kv;
uint32_t k_num;
} p;
layout (binding = 0) readonly buffer Q {float data_q[];};
layout (binding = 0) readonly buffer QV4 {vec4 data_qv4[];};
layout (binding = 1) readonly buffer K {float16_t data_k[];};
layout (binding = 1) readonly buffer KV4 {f16vec4 data_kv4[];};
layout (binding = 2) readonly buffer V {float16_t data_v[];};
layout (binding = 2) readonly buffer VV4 {f16vec4 data_vv4[];};
layout (binding = 3) readonly buffer M {float16_t data_m[];};
layout (binding = 4) writeonly buffer O {D_TYPE data_o[];};
#if defined(A_TYPE_PACKED16)
#define BINDING_IDX_K 0
#define BINDING_IDX_V 1
layout (binding = 1) readonly buffer KV_PACKED16 {A_TYPE_PACKED16 data_packed16[];} kv_packed[2];
#endif
#if defined(DATA_A_Q4_0)
#define BLOCK_BYTE_SIZE 18
vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
uint vui_lo = uint(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]);
uint vui_hi = uint(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]);
uint shift = (iqs & 0x10) >> 2;
vui_lo >>= shift;
vui_hi >>= shift;
return float(kv_packed[binding_idx].data_packed16[a_offset + ib].d) * (vec4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - 8.0f);
}
#endif
#if defined(DATA_A_Q8_0)
#define BLOCK_BYTE_SIZE 34
vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
const i8vec2 v0 = unpack8(int32_t(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[iqs / 2])).xy; // vec4 used due to #12147
const i8vec2 v1 = unpack8(int32_t(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[iqs / 2 + 1])).xy;
return float(kv_packed[binding_idx].data_packed16[a_offset + ib].d) * vec4(v0.x, v0.y, v1.x, v1.y);
}
#endif
#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))
// Store the output when doing grouped query attention.
// Rows are indexed by Q's dimension 2, and the first N rows are valid.
D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
{
uint32_t offset = (iq2 + r) * D + c;
data_o[o_offset + offset] = D_TYPE(elem);
return elem;
}
// Store column zero. This is used to save per-row m and L values for split_k.
ACC_TYPE perElemOpStoreCol0(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
{
if (r < N && c == 0) {
uint32_t offset = iq2 + r;
data_o[o_offset + offset] = D_TYPE(elem);
}
return elem;
}
// Load the slope matrix, indexed by Q's dimension 2.
ACC_TYPE perElemOpComputeSlope(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t iq2)
{
const uint32_t h = iq2 + (r % p.gqa_ratio);
const ACC_TYPE base = ACC_TYPE(h < p.n_head_log2 ? p.m0 : p.m1);
const int exph = int(h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1);
return ACC_TYPE(pow(base, ACC_TYPE(exph)));
}
shared FLOAT_TYPE tmpsh[gl_WorkGroupSize.x];
shared vec4 tmpshv4[gl_WorkGroupSize.x];
shared float masksh[Bc][Br];
shared vec4 Qf[Br][D / 4];
void main() {
#ifdef NEEDS_INIT_IQ_SHMEM
init_iq_shmem(gl_WorkGroupSize);
#endif
const uint32_t tid = gl_LocalInvocationIndex;
const uint32_t N = p.N;
const uint32_t KV = p.KV;
const uint32_t d_tid = gl_LocalInvocationIndex % D_split;
const uint32_t col_tid = gl_LocalInvocationIndex / D_split;
uint32_t i = gl_WorkGroupID.x;
uint32_t split_k_index = 0;
if (p.k_num > 1) {
i = 0;
split_k_index = gl_WorkGroupID.x;
}
const uint32_t Tr = CEIL_DIV(N, Br);
const uint32_t start_j = split_k_index * p.split_kv / Bc;
const uint32_t end_j = CEIL_DIV(min(KV, (split_k_index + 1) * p.split_kv), Bc);
// When not using grouped query attention, all rows share the same iq2, equal to gl_WorkGroupID.y.
// When using grouped query attention, each workgroup does gqa_ratio consecutive values of iq2.
const uint32_t iq2 = gl_WorkGroupID.y * p.gqa_ratio;
const uint32_t iq3 = gl_WorkGroupID.z;
// broadcast factors
const uint32_t rk2 = p.neq2/p.nek2;
const uint32_t rk3 = p.neq3/p.nek3;
const uint32_t rv2 = p.neq2/p.nev2;
const uint32_t rv3 = p.neq3/p.nev3;
// k indices
const uint32_t ik3 = iq3 / rk3;
const uint32_t ik2 = iq2 / rk2;
// v indices
const uint32_t iv3 = iq3 / rv3;
const uint32_t iv2 = iq2 / rv2;
// nb?1 are already divided by the type size and are in units of elements.
// When using grouped query attention, Q is indexed by iq2, so the stride
// should be nb02 (which is in bytes).
uint32_t q_stride = p.gqa_ratio > 1 ? (p.nb02 / 4) : p.nb01;
uint32_t k_stride = p.nb11;
uint32_t v_stride = p.nb21;
// When using grouped query attention, all rows use the same mask (stride 0).
// "p.gqa_ratio >> 16" is just a roundabout way of writing zero
// that prevents the compiler from folding the "&" through the select
// and breaking the alignment detection.
uint32_t m_stride = (p.gqa_ratio > 1) ? (p.gqa_ratio >> 16) : KV;
uint32_t q_offset = (iq2*p.nb02+iq3*p.nb03) / 4;
[[unroll]] for (uint32_t idx = 0; idx < Br * D / 4; idx += gl_WorkGroupSize.x) {
uint32_t d = (idx + tid) % (D / 4);
uint32_t r = (idx + tid) / (D / 4);
if (r < Br && d < D / 4 &&
i * Br + r < N) {
Qf[r][d] = vec4(data_qv4[q_offset / 4 + (i * Br + r) * q_stride / 4 + d]) * p.scale;
}
}
barrier();
vec4 Of[Br][D_per_thread / 4];
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
Of[r][d] = vec4(0.0);
}
}
float Lf[Br], Mf[Br];
// Use -FLT_MAX/2 rather than -inf to reduce the possibility of NaNs, e.g. when computing Mold-M.
const float NEG_FLT_MAX_OVER_2 = uintBitsToFloat(0xFEFFFFFF);
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
Lf[r] = 0;
Mf[r] = NEG_FLT_MAX_OVER_2;
}
float slope[Br];
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
slope[r] = 1.0;
}
// ALiBi
if (p.max_bias > 0.0f) {
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
slope[r] = perElemOpComputeSlope(r, col_tid, ACC_TYPE(0), iq2);
}
}
#if BLOCK_SIZE > 1
uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / BLOCK_BYTE_SIZE;
uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / BLOCK_BYTE_SIZE;
#else
uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / 2;
uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2;
#endif
[[dont_unroll]]
for (uint32_t j = start_j; j < end_j; ++j) {
float Sf[Br][cols_per_thread];
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
Sf[r][c] = 0.0;
}
}
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
#if BLOCK_SIZE > 1
uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
uint ib = coord / BLOCK_SIZE;
uint iqs = (coord % BLOCK_SIZE);
vec4 K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K);
#else
vec4 K_Tf = vec4(data_kv4[k_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * k_stride / 4 + d * D_split + d_tid]);
#endif
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
Sf[r][c] += dot(Qf[r][d * D_split + d_tid], K_Tf);
}
}
}
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
// Compute sum across the D_split
[[unroll]] for (uint s = D_split / 2; s > 0; s >>= 1) {
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
Sf[r][c] += subgroupShuffleXor(Sf[r][c], s);
}
}
}
if (p.logit_softcap != 0.0f) {
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
Sf[r][c] = p.logit_softcap * tanh(Sf[r][c]);
}
}
}
if (p.mask != 0) {
[[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
uint32_t c = (idx + tid) % Bc;
uint32_t r = (idx + tid) / Bc;
if (idx + tid < Bc * Br) {
masksh[c][r] = float(data_m[(i * Br + r) * m_stride + (j * Bc + c)]);
}
}
barrier();
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
float mvf = masksh[c * cols_per_iter + col_tid][r];
Sf[r][c] += slope[r]*mvf;
}
}
barrier();
}
float rowmaxf[Br], Pf[Br][cols_per_thread], rowsumf[Br], eMf[Br], Moldf[Br];
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
rowmaxf[r] = Sf[r][0];
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
rowmaxf[r] = max(rowmaxf[r], Sf[r][c]);
}
Moldf[r] = Mf[r];
// M = max(rowmax, Mold)
// P = e^(S - M)
// eM = e^(Mold - M)
Mf[r] = max(rowmaxf[r], Moldf[r]);
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
Pf[r][c] = exp(Sf[r][c] - Mf[r]);
}
eMf[r] = exp(Moldf[r] - Mf[r]);
// Compute sum across row of P
rowsumf[r] = 0.0;
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
rowsumf[r] += Pf[r][c];
}
Lf[r] = eMf[r]*Lf[r] + rowsumf[r];
}
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
Of[r][d] = eMf[r] * Of[r][d];
}
}
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
#if BLOCK_SIZE > 1
uint coord = (j * Bc + c * cols_per_iter + col_tid) * v_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
uint ib = coord / BLOCK_SIZE;
uint iqs = (coord % BLOCK_SIZE);
vec4 Vf = dequantize4(ib, iqs, v_offset, BINDING_IDX_V);
#else
vec4 Vf = vec4(data_vv4[v_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * v_stride / 4 + d * D_split + d_tid]);
#endif
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
Of[r][d] += Pf[r][c] * Vf;
}
}
}
barrier();
}
// reduce across threads
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
float rowmaxf, eMf;
tmpsh[tid] = Mf[r];
// Compute max across the row
barrier();
[[unroll]] for (int s = int(gl_WorkGroupSize.x) / 2; s >= D_split; s >>= 1) {
if (tid < s) {
tmpsh[tid] = max(tmpsh[tid], tmpsh[tid + s]);
}
barrier();
}
rowmaxf = tmpsh[d_tid];
barrier();
float Moldf = Mf[r];
// M = max(rowmax, Mold)
// eM = e^(Mold - M)
Mf[r] = max(rowmaxf, Moldf);
eMf = exp(Moldf - Mf[r]);
Lf[r] = eMf*Lf[r];
tmpsh[tid] = Lf[r];
// Compute sum across the row
barrier();
[[unroll]] for (int s = int(gl_WorkGroupSize.x) / 2; s >= D_split; s >>= 1) {
if (tid < s) {
tmpsh[tid] = tmpsh[tid] + tmpsh[tid + s];
}
barrier();
}
Lf[r] = tmpsh[d_tid];
barrier();
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
Of[r][d] = eMf * Of[r][d];
tmpshv4[tid] = Of[r][d];
barrier();
[[unroll]] for (int s = int(gl_WorkGroupSize.x) / 2; s >= D_split; s >>= 1) {
if (tid < s) {
Of[r][d] += tmpshv4[tid + s];
tmpshv4[tid] = Of[r][d];
}
barrier();
}
Of[r][d] = tmpshv4[d_tid];
barrier();
}
}
// If there is split_k, then the split_k resolve shader does the final
// division by L. Store the intermediate O value and per-row m and L values.
if (p.k_num > 1) {
uint32_t o_offset = D * p.ne1 * split_k_index;
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
if (r < N) {
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
[[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
perElemOpGqaStore(r, 4*(d * D_split + d_tid) + comp, Of[r][d][comp], o_offset, iq2, N);
}
}
}
}
o_offset = D * p.ne1 * p.k_num + p.ne1 * split_k_index * 2;
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
if (r < N) {
perElemOpStoreCol0(r, 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N);
perElemOpStoreCol0(r, 0u, ACC_TYPE(Mf[r]), o_offset + p.ne1, iq2, N);
}
}
return;
}
float Lfrcp[Br];
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
Lfrcp[r] = 1.0 / Lf[r];
}
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
Of[r][d] *= Lfrcp[r];
}
}
uint32_t o_offset = iq3*p.ne2*p.ne1;
if (p.gqa_ratio > 1) {
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
if (r < N) {
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
[[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
perElemOpGqaStore(r, 4*(d * D_split + d_tid) + comp, Of[r][d][comp], o_offset, iq2, N);
}
}
}
}
} else {
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
if (i * Br + r < N) {
[[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
[[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
data_o[o_offset + iq2 * D + (i * Br + r) * p.ne1 * D + 4*(d * D_split + d_tid) + comp] = D_TYPE(Of[r][d][comp]);
}
}
}
}
}
}
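
For reference, the per-tile update in the main loop above is the standard online-softmax recurrence; restating the in-code comments as equations (nothing beyond what the code already does):

    M' = \max(\mathrm{rowmax}(S),\ M)
    P  = e^{S - M'}
    L' = e^{M - M'} L + \mathrm{rowsum}(P)
    O' = e^{M - M'} O + P V

After the KV loop, O is divided by L unless p.k_num > 1, in which case the per-split O, L and M are written out for the split_k resolve shader further down in this commit.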

View File

@ -61,6 +61,10 @@ layout (push_constant) uniform parameter {
uint32_t n_head_log2;
float m0;
float m1;
uint32_t gqa_ratio;
uint32_t split_kv;
uint32_t k_num;
} p;
layout (binding = 0) readonly buffer Q {uint8_t data_q[];};
@ -103,6 +107,38 @@ ACC_TYPE Max(const in uint32_t row, const in uint32_t col, const in ACC_TYPE ele
#define DECODEFUNC
#endif
// Store the output when doing grouped query attention.
// Rows are indexed by Q's dimension 2, and the first N rows are valid.
D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
{
if (r < N && c < D) {
uint32_t offset = (iq2 + r) * D + c;
data_o[o_offset + offset] = D_TYPE(elem);
}
return elem;
}
// Store column zero. This is used to save per-row m and L values for split_k.
ACC_TYPE perElemOpStoreCol0(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
{
if (r < N && c == 0) {
uint32_t offset = iq2 + r;
data_o[o_offset + offset] = D_TYPE(elem);
}
return elem;
}
// Load the slope matrix, indexed by Q's dimension 2.
ACC_TYPE perElemOpComputeSlope(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t iq2)
{
const uint32_t h = iq2 + (r % p.gqa_ratio);
const ACC_TYPE base = ACC_TYPE(h < p.n_head_log2 ? p.m0 : p.m1);
const int exph = int(h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1);
return ACC_TYPE(pow(base, ACC_TYPE(exph)));
}
void main() {
#ifdef NEEDS_INIT_IQ_SHMEM
init_iq_shmem(gl_WorkGroupSize);
@ -111,12 +147,22 @@ void main() {
const uint32_t N = p.N;
const uint32_t KV = p.KV;
uint32_t i = gl_WorkGroupID.x;
uint32_t split_k_index = 0;
if (p.k_num > 1) {
i = 0;
split_k_index = gl_WorkGroupID.x;
}
const uint32_t Tr = CEIL_DIV(N, Br);
const uint32_t Tc = CEIL_DIV(KV, Bc);
const uint32_t i = gl_WorkGroupID.x;
const uint32_t start_j = split_k_index * p.split_kv / Bc;
const uint32_t end_j = CEIL_DIV(min(KV, (split_k_index + 1) * p.split_kv), Bc);
const uint32_t iq2 = gl_WorkGroupID.y;
// When not using grouped query attention, all rows share the same iq2, equal to gl_WorkGroupID.y.
// When using grouped query attention, each workgroup does gqa_ratio consecutive values of iq2.
const uint32_t iq2 = gl_WorkGroupID.y * p.gqa_ratio;
const uint32_t iq3 = gl_WorkGroupID.z;
// broadcast factors
@ -149,10 +195,17 @@ void main() {
tensorLayoutK = setTensorLayoutDimensionNV(tensorLayoutK, KV, D);
tensorLayoutV = setTensorLayoutDimensionNV(tensorLayoutV, KV, D);
// nb?1 are already divided by the type size and are in units of elements
uint32_t q_stride = p.nb01;
// nb?1 are already divided by the type size and are in units of elements.
// When using grouped query attention, Q is indexed by iq2, so the stride
// should be nb02 (which is in bytes).
uint32_t q_stride = p.gqa_ratio > 1 ? (p.nb02 / 4) : p.nb01;
uint32_t k_stride = p.nb11;
uint32_t v_stride = p.nb21;
// When using grouped query attention, all rows use the same mask (stride 0).
// "p.gqa_ratio >> 16" is just a roundabout way of writing zero
// that prevents the compiler from folding the "&" through the select
// and breaking the alignment detection.
uint32_t m_stride = (p.gqa_ratio > 1) ? (p.gqa_ratio >> 16) : KV;
// hint to the compiler that strides are aligned for the aligned variant of the shader
if (Clamp != gl_CooperativeMatrixClampModeConstantNV)
{
@ -161,6 +214,7 @@ void main() {
k_stride &= ~7;
v_stride &= ~7;
#endif
m_stride &= ~7;
}
tensorLayoutQ = setTensorLayoutStrideNV(tensorLayoutQ, q_stride, 1);
tensorLayoutK = setTensorLayoutStrideNV(tensorLayoutK, k_stride, 1);
@ -179,23 +233,21 @@ void main() {
coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> L, M;
L = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0);
M = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(-1.0/0.0);
// Use -FLT_MAX/2 rather than -inf to reduce the possibility of NaNs, e.g. when computing Mold-M.
const float NEG_FLT_MAX_OVER_2 = uintBitsToFloat(0xFEFFFFFF);
ACC_TYPE slope = ACC_TYPE(1.0);
L = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0);
M = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(NEG_FLT_MAX_OVER_2);
coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> slopeMat = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(1.0);
// ALiBi
if (p.max_bias > 0.0f) {
const uint32_t h = iq2;
const ACC_TYPE base = ACC_TYPE(h < p.n_head_log2 ? p.m0 : p.m1);
const int exph = int(h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1);
slope = pow(base, ACC_TYPE(exph));
coopMatPerElementNV(slopeMat, slopeMat, perElemOpComputeSlope, iq2);
}
[[dont_unroll]]
for (uint32_t j = 0; j < Tc; ++j) {
for (uint32_t j = start_j; j < end_j; ++j) {
coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> S = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0);
@ -213,14 +265,15 @@ void main() {
}
if (p.mask != 0) {
tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutM = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV);
tensorLayoutNV<2, Clamp> tensorLayoutM = createTensorLayoutNV(2, Clamp);
tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, p.nem1, KV);
tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1);
coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mv;
coopMatLoadTensorNV(mv, data_m, 0, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));
S += slope*coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(mv);
S += slopeMat*coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(mv);
}
// Clear padding elements to -inf, so they don't contribute to rowmax
@ -231,7 +284,7 @@ void main() {
uint R = ((i + 1) * Br > N) ? (N % Br) : Br;
uint C = ((j + 1) * Bc > KV) ? (KV % Bc) : Bc;
coopMatPerElementNV(S, S, replacePadding, ACC_TYPE(-1.0/0.0), R, C);
coopMatPerElementNV(S, S, replacePadding, ACC_TYPE(NEG_FLT_MAX_OVER_2), R, C);
}
coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> rowmax, P, rowsum, eM;
@ -280,9 +333,25 @@ void main() {
// resize eM by using smear/reduce
coopMatReduceNV(eMdiag, eM, gl_CooperativeMatrixReduceRowNV, smearReduce);
O = eMdiag * O;
// multiply with fp16 accumulation, then add to O.
coopmat<float16_t, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> PV = coopmat<float16_t, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator>(0);
PV = coopMatMulAdd(P_A, V, PV);
O = coopMatMulAdd(P_A, V, O);
O = eMdiag * O + coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator>(PV);
}
// If there is split_k, then the split_k resolve shader does the final
// division by L. Store the intermediate O value and per-row m and L values.
if (p.k_num > 1) {
coopmat<D_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator>(O);
uint32_t o_offset = D * p.ne1 * split_k_index;
coopMatPerElementNV(O_D, O_D, perElemOpGqaStore, o_offset, iq2, N);
o_offset = D * p.ne1 * p.k_num + p.ne1 * split_k_index * 2;
coopMatPerElementNV(L, L, perElemOpStoreCol0, o_offset, iq2, N);
coopMatPerElementNV(M, M, perElemOpStoreCol0, o_offset + p.ne1, iq2, N);
return;
}
coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> Ldiag;
@ -297,13 +366,18 @@ void main() {
O = Ldiag*O;
tensorLayoutNV<3, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutD = createTensorLayoutNV(3, gl_CooperativeMatrixClampModeConstantNV);
tensorLayoutD = setTensorLayoutDimensionNV(tensorLayoutD, p.ne2, p.ne1, D);
// permute dimensions
tensorViewNV<3, false, 1, 0, 2> tensorViewPermute = createTensorViewNV(3, false, 1, 0, 2);
uint32_t o_offset = iq3*p.ne2*p.ne1;
coopmat<D_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator>(O);
coopMatStoreTensorNV(O_D, data_o, o_offset, sliceTensorLayoutNV(tensorLayoutD, i * Br, Br, iq2, 1, 0, D), tensorViewPermute);
if (p.gqa_ratio > 1) {
coopMatPerElementNV(O_D, O_D, perElemOpGqaStore, o_offset, iq2, N);
} else {
tensorLayoutNV<3, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutD = createTensorLayoutNV(3, gl_CooperativeMatrixClampModeConstantNV);
tensorLayoutD = setTensorLayoutDimensionNV(tensorLayoutD, p.ne2, p.ne1, D);
// permute dimensions
tensorViewNV<3, false, 1, 0, 2> tensorViewPermute = createTensorViewNV(3, false, 1, 0, 2);
coopMatStoreTensorNV(O_D, data_o, o_offset, sliceTensorLayoutNV(tensorLayoutD, i * Br, Br, iq2, N, 0, D), tensorViewPermute);
}
}

View File

@ -0,0 +1,59 @@
#version 450
#extension GL_EXT_control_flow_attributes : enable
#define BLOCK_SIZE 32
layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {float data_a[];};
layout (binding = 1) writeonly buffer D {float data_d[];};
layout (push_constant) uniform parameter {
uint D;
uint N;
uint k_num;
} p;
void main() {
// Each workgroup handles a row
const uint n = gl_WorkGroupID.x;
const uint tid = gl_LocalInvocationID.x;
uint D = p.D;
uint N = p.N;
uint k_num = p.k_num;
uint l_offset = D * N * k_num + n;
uint m_offset = D * N * k_num + N + n;
uint lm_stride = N * 2;
// Compute the max m value for the row
float m_max = -1.0/0.0;
[[unroll]] for (uint k = 0; k < k_num; ++k) {
float m = data_a[m_offset + k * lm_stride];
m_max = max(m_max, m);
}
// Compute L based on m_max
float L = 0;
[[unroll]] for (uint k = 0; k < k_num; ++k) {
float l = data_a[l_offset + k * lm_stride];
float m = data_a[m_offset + k * lm_stride];
L += exp(m - m_max) * l;
}
L = 1.0 / L;
// Scale and sum the O contributions based on m_max and store the result to memory
for (uint d = tid; d < D; d += BLOCK_SIZE) {
float O = 0.0;
[[unroll]] for (uint k = 0; k < k_num; ++k) {
uint o_offset = D * N * k + D * n + d;
float m = data_a[m_offset + k * lm_stride];
O += exp(m - m_max) * data_a[o_offset];
}
O *= L;
data_d[D * n + d] = O;
}
}
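
Restated as equations, the resolve above combines the k_num partial results with the usual log-sum-exp trick (this mirrors the loops in the code, it is not additional behavior):

    m   = \max_k m_k
    L   = \sum_k e^{m_k - m} L_k
    O_d = \frac{1}{L} \sum_k e^{m_k - m} O_{k,d}

where (O_k, L_k, m_k) are the per-split outputs stored by the flash-attention shaders when p.k_num > 1.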

View File

@ -20,9 +20,14 @@ void main() {
const uint a_offset = get_aoffset() + i01*p.nb01 + i11*p.nb02 + i12*p.nb03;
const uint d_offset = get_doffset() + i10*p.nb21 + i11*p.nb22 + i12*p.nb23;
#ifndef OPTIMIZATION_ERROR_WORKAROUND
data_d[d_offset + i00] = D_TYPE(data_a[a_offset + i00]);
#if defined(DATA_A_BF16)
FLOAT_TYPE v = FLOAT_TYPE(bf16_to_fp32(data_a[a_offset + i00]));
#else
data_d[d_offset + i00] = data_a[a_offset + i00];
FLOAT_TYPE v = FLOAT_TYPE(data_a[a_offset + i00]);
#endif
#ifndef OPTIMIZATION_ERROR_WORKAROUND
data_d[d_offset + i00] = D_TYPE(v);
#else
data_d[d_offset + i00] = D_TYPE(v);
#endif
}

View File

@ -1,5 +1,7 @@
#version 450
#extension GL_EXT_control_flow_attributes : enable
#include "types.comp"
#include "generic_binary_head.comp"
#include "dequant_funcs.comp"

View File

@ -40,6 +40,20 @@ void main() {
const uint batch = gl_GlobalInvocationID.z / p.IC;
const uint ic = gl_GlobalInvocationID.z % p.IC;
const uint src_base = ic * p.offset_delta + batch * p.batch_offset;
const uint dst_base = ((batch * p.OH + oh) * p.OW) * p.CHW + ic * (p.KW * p.KH);
const int oh_s1 = int(oh) * p.s1;
const uint ksize = p.OW * (p.KH > 1 ? p.KW : 1);
const uint base_linear_idx = gidx * NUM_ITER;
const uint max_ky = ksize / p.OW;
uint current_kx = base_linear_idx / ksize;
const uint rem = base_linear_idx - (current_kx * ksize);
uint current_ky = rem / p.OW;
uint current_ix = rem % p.OW;
A_TYPE values[NUM_ITER];
uint offset_dst[NUM_ITER];
[[unroll]] for (uint idx = 0; idx < NUM_ITER; ++idx) {
@ -48,36 +62,35 @@ void main() {
[[unroll]] for (uint idx = 0; idx < NUM_ITER; ++idx) {
const uint i = gidx * NUM_ITER + idx;
const uint linear_idx = base_linear_idx + idx;
const uint ksize = p.OW * (p.KH > 1 ? p.KW : 1);
const uint kx = i / ksize;
const uint kd = kx * ksize;
const uint ky = (i - kd) / p.OW;
const uint ix = i % p.OW;
const uint iiw = ix * p.s0 + kx * p.d0 - p.p0;
const uint iih = oh * p.s1 + ky * p.d1 - p.p1;
offset_dst[idx] =
((batch * p.OH + oh) * p.OW + ix) * p.CHW +
(ic * (p.KW * p.KH) + ky * p.KW + kx);
if (i >= p.pelements) {
if (linear_idx >= p.pelements) {
continue;
}
if (iih < p.IH && iiw < p.IW) {
const uint offset_src = ic * p.offset_delta + batch * p.batch_offset;
values[idx] = data_a[offset_src + iih * p.IW + iiw];
const uint iiw = current_ix * p.s0 + current_kx * p.d0 - p.p0;
const uint iih = oh_s1 + current_ky * p.d1 - p.p1;
offset_dst[idx] = dst_base + current_ix * p.CHW + current_ky * p.KW + current_kx;
if ((iih < p.IH) && (iiw < p.IW)) {
values[idx] = data_a[src_base + iih * p.IW + iiw];
}
if (++current_ix == p.OW) {
current_ix = 0;
if (++current_ky == max_ky) {
current_ky = 0;
current_kx++;
}
}
}
[[unroll]] for (uint idx = 0; idx < NUM_ITER; ++idx) {
const uint i = gidx * NUM_ITER + idx;
const uint linear_idx = base_linear_idx + idx;
if (i >= p.pelements) {
if (linear_idx >= p.pelements) {
continue;
}
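
The rewrite above trades the per-element division and modulo of the old code for counters carried across iterations. The underlying identity, restated:

    // With ksize = OW * max_ky, a linear index decomposes as
    //   linear_idx = kx * ksize + ky * OW + ix
    // so ix varies fastest and carries into ky, which carries into kx --
    // exactly what the ++current_ix / ++current_ky / current_kx++ chain
    // implements, one increment per element instead of a div and a mod.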

View File

@ -0,0 +1,41 @@
#version 450
#include "generic_head.comp"
#include "types.comp"
#extension GL_EXT_control_flow_attributes : enable
#define BLOCK_SIZE 512
layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
shared FLOAT_TYPE sum[BLOCK_SIZE];
void main() {
const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
const uint tid = gl_LocalInvocationID.x;
sum[tid] = FLOAT_TYPE(0.0f); // partial sum for thread in warp
[[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) {
const FLOAT_TYPE xi = FLOAT_TYPE(data_a[row*p.KX + col]);
sum[tid] += xi * xi;
}
// sum up partial sums and write back result
barrier();
[[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
if (tid < s) {
sum[tid] += sum[tid + s];
}
barrier();
}
const FLOAT_TYPE scale = inversesqrt(max(sum[0], FLOAT_TYPE(p.param1)));
[[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) {
data_d[row*p.KX + col] = D_TYPE(scale * FLOAT_TYPE(data_a[row*p.KX + col]));
}
}
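
In formula form, the shader above computes per row (restating the code, with p.param1 acting as the lower clamp under the square root):

    y_i = \frac{x_i}{\sqrt{\max\left(\sum_j x_j^2,\ \mathrm{param1}\right)}}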

View File

@ -6,7 +6,7 @@
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
#if !defined(DATA_A_F32) && !defined(DATA_A_F16)
#if !defined(DATA_A_F32) && !defined(DATA_A_F16) && !defined(DATA_A_BF16)
#define K_PER_ITER 8
#else
#define K_PER_ITER 2
@ -105,6 +105,16 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
int unroll_count = 4;
uint unrolled_iters = num_iters & ~(unroll_count - 1);
#if K_PER_ITER == 2
// If the K dimension is odd, we need lastiter==true on the last iteration
// so OOB is computed correctly. Skip some unrolling to make that happen.
if ((p.ncols & 1) != 0 &&
unrolled_iters == num_iters &&
unrolled_iters > 0) {
unrolled_iters -= unroll_count;
}
#endif
uint i = 0;
while (i < unrolled_iters) {
// Manually partially unroll the loop
@ -113,8 +123,18 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
i++;
}
}
unroll_count = 2;
unrolled_iters = num_iters & ~(unroll_count - 1);
#if K_PER_ITER == 2
if ((p.ncols & 1) != 0 &&
unrolled_iters == num_iters &&
unrolled_iters > 0) {
unrolled_iters -= unroll_count;
}
#endif
while (i < unrolled_iters) {
// Manually partially unroll the loop
[[unroll]] for (uint k = 0; k < unroll_count; ++k) {

View File

@ -0,0 +1,90 @@
#version 450
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
#include "mul_mat_vec_base.comp"
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
void calc_superblock(const uint a_offset, const uint b_offset, const uint itid, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows) {
const uint y_idx = i * QUANT_K + 16 * itid;
const uint nibble_shift = 4 * (itid & 1);
const uint ib32 = itid / 2; // 0..7
uint ibi = a_offset / QUANT_K + first_row * num_blocks_per_row + i;
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
const float d = float(data_a[ibi].d);
const uint scale = (data_a[ibi].scales[ib32] >> nibble_shift) & 0xF;
const float db = d * (0.5 + scale) * 0.25;
const uint qh = data_a[ibi].qh[ib32];
const u8vec2 qs16 = unpack8(uint32_t(data_a_packed16[ibi].qs[itid])).xy; // vec4 used due to #12147
const u8vec2 sign16 = unpack8(uint32_t(data_a_packed16[ibi].qs[QUANT_K / 16 + itid])).xy;
[[unroll]] for (uint l = 0; l < 2; ++l) {
const uint8_t sign = sign16[l];
const uint qs = qs16[l] | ((qh << (8 - nibble_shift - 2 * l)) & 0x300);
const uvec2 grid = iq2s_grid[qs];
const vec4 grid0 = vec4(unpack8(grid.x));
const vec4 grid1 = vec4(unpack8(grid.y));
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
vec4 b0 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 0]);
vec4 b4 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 1]);
FLOAT_TYPE sum =
fma(FLOAT_TYPE(b0.x), FLOAT_TYPE((sign & 1) != 0 ? -grid0.x : grid0.x),
fma(FLOAT_TYPE(b0.y), FLOAT_TYPE((sign & 2) != 0 ? -grid0.y : grid0.y),
fma(FLOAT_TYPE(b0.z), FLOAT_TYPE((sign & 4) != 0 ? -grid0.z : grid0.z),
fma(FLOAT_TYPE(b0.w), FLOAT_TYPE((sign & 8) != 0 ? -grid0.w : grid0.w),
fma(FLOAT_TYPE(b4.x), FLOAT_TYPE((sign & 16) != 0 ? -grid1.x : grid1.x),
fma(FLOAT_TYPE(b4.y), FLOAT_TYPE((sign & 32) != 0 ? -grid1.y : grid1.y),
fma(FLOAT_TYPE(b4.z), FLOAT_TYPE((sign & 64) != 0 ? -grid1.z : grid1.z),
fma(FLOAT_TYPE(b4.w), FLOAT_TYPE((sign & 128) != 0 ? -grid1.w : grid1.w),
FLOAT_TYPE(0.0)))))))));
temp[j][n] = fma(db, sum, temp[j][n]);
}
}
ibi += num_blocks_per_row;
}
}
void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
uint a_offset, b_offset, d_offset;
get_offsets(a_offset, b_offset, d_offset);
const uint num_blocks_per_row = p.ncols / QUANT_K;
// 16 threads are used to process each block
const uint blocks_per_wg = gl_WorkGroupSize.x/16;
const uint tid = gl_LocalInvocationID.x;
const uint itid = tid % 16; // 0...15
const uint ix = tid / 16;
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
[[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) {
temp[j][i] = FLOAT_TYPE(0);
}
}
[[unroll]] for (uint i = ix; i < num_blocks_per_row; i += blocks_per_wg)
calc_superblock(a_offset, b_offset, itid, i, num_blocks_per_row, first_row, num_rows);
reduce_result(temp, d_offset, first_row, num_rows, tid);
}
void main() {
const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z);
init_iq_shmem(gl_WorkGroupSize);
// do NUM_ROWS at a time, unless there aren't enough remaining rows
if (first_row + NUM_ROWS <= p.stride_d) {
compute_outputs(first_row, NUM_ROWS);
} else {
if (first_row >= p.stride_d) {
return;
}
compute_outputs(first_row, p.stride_d - first_row);
}
}

View File

@ -0,0 +1,87 @@
#version 450
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
#include "mul_mat_vec_base.comp"
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
void calc_superblock(const uint a_offset, const uint b_offset, const uint itid, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows) {
const uint y_idx = i * QUANT_K + 16 * itid;
const uint nibble_shift = 4 * (itid & 1);
const uint ib32 = itid / 2; // 0..7
uint ibi = a_offset / QUANT_K + first_row * num_blocks_per_row + i;
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
const float d = float(data_a[ibi].d);
const uint scale = (data_a[ibi].scales[ib32] >> nibble_shift) & 0xF;
const float db = d * (0.5 + scale) * 0.25;
[[unroll]] for (uint l = 0; l < 2; ++l) {
const uint qs = data_a[ibi].qs[2 * itid + l];
const uint sign = qs >> 9;
const uint sign7 = bitCount(sign);
const vec4 grid0 = vec4(unpack8(iq2xs_grid[qs & 511].x));
const vec4 grid1 = vec4(unpack8(iq2xs_grid[qs & 511].y));
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
vec4 b0 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 0]);
vec4 b4 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 1]);
FLOAT_TYPE sum =
fma(FLOAT_TYPE(b0.x), FLOAT_TYPE((sign & 1) != 0 ? -grid0.x : grid0.x),
fma(FLOAT_TYPE(b0.y), FLOAT_TYPE((sign & 2) != 0 ? -grid0.y : grid0.y),
fma(FLOAT_TYPE(b0.z), FLOAT_TYPE((sign & 4) != 0 ? -grid0.z : grid0.z),
fma(FLOAT_TYPE(b0.w), FLOAT_TYPE((sign & 8) != 0 ? -grid0.w : grid0.w),
fma(FLOAT_TYPE(b4.x), FLOAT_TYPE((sign & 16) != 0 ? -grid1.x : grid1.x),
fma(FLOAT_TYPE(b4.y), FLOAT_TYPE((sign & 32) != 0 ? -grid1.y : grid1.y),
fma(FLOAT_TYPE(b4.z), FLOAT_TYPE((sign & 64) != 0 ? -grid1.z : grid1.z),
fma(FLOAT_TYPE(b4.w), FLOAT_TYPE((sign7 & 1) != 0 ? -grid1.w : grid1.w),
FLOAT_TYPE(0.0)))))))));
temp[j][n] = fma(db, sum, temp[j][n]);
}
}
ibi += num_blocks_per_row;
}
}
void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
uint a_offset, b_offset, d_offset;
get_offsets(a_offset, b_offset, d_offset);
const uint num_blocks_per_row = p.ncols / QUANT_K;
// 16 threads are used to process each block
const uint blocks_per_wg = gl_WorkGroupSize.x/16;
const uint tid = gl_LocalInvocationID.x;
const uint itid = tid % 16; // 0...15
const uint ix = tid / 16;
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
[[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) {
temp[j][i] = FLOAT_TYPE(0);
}
}
[[unroll]] for (uint i = ix; i < num_blocks_per_row; i += blocks_per_wg)
calc_superblock(a_offset, b_offset, itid, i, num_blocks_per_row, first_row, num_rows);
reduce_result(temp, d_offset, first_row, num_rows, tid);
}
void main() {
const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z);
init_iq_shmem(gl_WorkGroupSize);
// do NUM_ROWS at a time, unless there aren't enough remaining rows
if (first_row + NUM_ROWS <= p.stride_d) {
compute_outputs(first_row, NUM_ROWS);
} else {
if (first_row >= p.stride_d) {
return;
}
compute_outputs(first_row, p.stride_d - first_row);
}
}
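
One detail of the IQ2_XS sign handling above that is easy to miss (restating the code): each 16-bit qs entry packs a 9-bit grid index and only 7 explicit sign bits, with the eighth sign reconstructed from parity. The IQ2_XXS shader below uses the same trick.

    // sign  = qs >> 9          -> 7 stored sign bits for grid elements 0..6
    // sign7 = bitCount(sign)   -> its parity (sign7 & 1) supplies the 8th sign,
    //                             used for the .w component of grid1 above.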

View File

@ -0,0 +1,87 @@
#version 450
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
#include "mul_mat_vec_base.comp"
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
void calc_superblock(const uint a_offset, const uint b_offset, const uint itid, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows) {
const uint y_idx = i * QUANT_K + 16 * itid;
const uint ib32 = itid / 2; // 0..7
uint ibi = a_offset / QUANT_K + first_row * num_blocks_per_row + i;
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
const float d = float(data_a[ibi].d);
const uint signscale = pack32(u16vec2(
data_a_packed16[ibi].qs[4 * ib32 + 2],
data_a_packed16[ibi].qs[4 * ib32 + 3]));
const float db = d * 0.25 * (0.5 + (signscale >> 28));
[[unroll]] for (uint l = 0; l < 2; ++l) {
const uint qs = data_a[ibi].qs[8 * ib32 + 2 * (itid & 1) + l];
const uint sign = bitfieldExtract(signscale, 7 * int(2 * (itid & 1) + l), 7);
const uint sign7 = bitCount(sign);
const vec4 grid0 = vec4(unpack8(iq2xxs_grid[qs].x));
const vec4 grid1 = vec4(unpack8(iq2xxs_grid[qs].y));
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
const vec4 b0 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 0]);
const vec4 b4 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 1]);
FLOAT_TYPE sum =
fma(FLOAT_TYPE(b0.x), FLOAT_TYPE((sign & 1) != 0 ? -grid0.x : grid0.x),
fma(FLOAT_TYPE(b0.y), FLOAT_TYPE((sign & 2) != 0 ? -grid0.y : grid0.y),
fma(FLOAT_TYPE(b0.z), FLOAT_TYPE((sign & 4) != 0 ? -grid0.z : grid0.z),
fma(FLOAT_TYPE(b0.w), FLOAT_TYPE((sign & 8) != 0 ? -grid0.w : grid0.w),
fma(FLOAT_TYPE(b4.x), FLOAT_TYPE((sign & 16) != 0 ? -grid1.x : grid1.x),
fma(FLOAT_TYPE(b4.y), FLOAT_TYPE((sign & 32) != 0 ? -grid1.y : grid1.y),
fma(FLOAT_TYPE(b4.z), FLOAT_TYPE((sign & 64) != 0 ? -grid1.z : grid1.z),
fma(FLOAT_TYPE(b4.w), FLOAT_TYPE((sign7 & 1) != 0 ? -grid1.w : grid1.w),
FLOAT_TYPE(0.0)))))))));
temp[j][n] = fma(db, sum, temp[j][n]);
}
}
ibi += num_blocks_per_row;
}
}
void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
uint a_offset, b_offset, d_offset;
get_offsets(a_offset, b_offset, d_offset);
const uint num_blocks_per_row = p.ncols / QUANT_K;
// 16 threads are used to process each block
const uint blocks_per_wg = gl_WorkGroupSize.x/16;
const uint tid = gl_LocalInvocationID.x;
const uint itid = tid % 16; // 0...15
const uint ix = tid / 16;
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
[[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) {
temp[j][i] = FLOAT_TYPE(0);
}
}
[[unroll]] for (uint i = ix; i < num_blocks_per_row; i += blocks_per_wg)
calc_superblock(a_offset, b_offset, itid, i, num_blocks_per_row, first_row, num_rows);
reduce_result(temp, d_offset, first_row, num_rows, tid);
}
void main() {
const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z);
init_iq_shmem(gl_WorkGroupSize);
// do NUM_ROWS at a time, unless there aren't enough remaining rows
if (first_row + NUM_ROWS <= p.stride_d) {
compute_outputs(first_row, NUM_ROWS);
} else {
if (first_row >= p.stride_d) {
return;
}
compute_outputs(first_row, p.stride_d - first_row);
}
}

View File

@ -0,0 +1,90 @@
#version 450
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
#include "mul_mat_vec_base.comp"
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
void calc_superblock(const uint a_offset, const uint b_offset, const uint ib32, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows) {
const uint y_idx = i * QUANT_K + 32 * ib32;
uint ibi = a_offset / QUANT_K + first_row * num_blocks_per_row + i;
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
const float d = float(data_a[ibi].d);
const uint scale = (data_a[ibi].scales[ib32/2] >> (4 * (ib32 & 1))) & 0xF;
const float dscale = d * (1 + 2 * scale);
const uint qh = data_a[ibi].qh[ib32];
FLOAT_TYPE sum[NUM_COLS];
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
sum[j] = 0.0;
}
[[unroll]] for (uint l = 0; l < 4; ++l) {
const u8vec2 qs = unpack8(uint32_t(data_a_packed16[ibi].qs[4 * ib32 + l])).xy; // vec4 used due to #12147
const uint sign = data_a[ibi].signs[4 * ib32 + l];
const vec4 grid0 = vec4(unpack8(iq3s_grid[qs.x | ((qh << (8 - 2*l)) & 0x100)]));
const vec4 grid1 = vec4(unpack8(iq3s_grid[qs.y | ((qh << (7 - 2*l)) & 0x100)]));
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
const vec4 b0 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 0]);
const vec4 b4 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 1]);
sum[j] =
fma(FLOAT_TYPE(b0.x), FLOAT_TYPE((sign & 1) != 0 ? -grid0.x : grid0.x),
fma(FLOAT_TYPE(b0.y), FLOAT_TYPE((sign & 2) != 0 ? -grid0.y : grid0.y),
fma(FLOAT_TYPE(b0.z), FLOAT_TYPE((sign & 4) != 0 ? -grid0.z : grid0.z),
fma(FLOAT_TYPE(b0.w), FLOAT_TYPE((sign & 8) != 0 ? -grid0.w : grid0.w),
fma(FLOAT_TYPE(b4.x), FLOAT_TYPE((sign & 16) != 0 ? -grid1.x : grid1.x),
fma(FLOAT_TYPE(b4.y), FLOAT_TYPE((sign & 32) != 0 ? -grid1.y : grid1.y),
fma(FLOAT_TYPE(b4.z), FLOAT_TYPE((sign & 64) != 0 ? -grid1.z : grid1.z),
fma(FLOAT_TYPE(b4.w), FLOAT_TYPE((sign & 128) != 0 ? -grid1.w : grid1.w),
sum[j]))))))));
}
}
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
temp[j][n] = fma(dscale, sum[j], temp[j][n]);
}
ibi += num_blocks_per_row;
}
}
void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
uint a_offset, b_offset, d_offset;
get_offsets(a_offset, b_offset, d_offset);
const uint num_blocks_per_row = p.ncols / QUANT_K;
// 8 threads are used to process each block
const uint blocks_per_wg = gl_WorkGroupSize.x/8;
const uint tid = gl_LocalInvocationID.x;
const uint itid = tid % 8; // 0...7
const uint ix = tid / 8;
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
[[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) {
temp[j][i] = FLOAT_TYPE(0);
}
}
[[unroll]] for (uint i = ix; i < num_blocks_per_row; i += blocks_per_wg)
calc_superblock(a_offset, b_offset, itid, i, num_blocks_per_row, first_row, num_rows);
reduce_result(temp, d_offset, first_row, num_rows, tid);
}
void main() {
const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z);
init_iq_shmem(gl_WorkGroupSize);
// do NUM_ROWS at a time, unless there aren't enough remaining rows
if (first_row + NUM_ROWS <= p.stride_d) {
compute_outputs(first_row, NUM_ROWS);
} else {
if (first_row >= p.stride_d) {
return;
}
compute_outputs(first_row, p.stride_d - first_row);
}
}

View File

@ -0,0 +1,88 @@
#version 450
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
#include "mul_mat_vec_base.comp"
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
void calc_superblock(const uint a_offset, const uint b_offset, const uint itid, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows) {
const uint y_idx = i * QUANT_K + 16 * itid;
const uint ib32 = itid / 2; // 0..7
uint ibi = a_offset / QUANT_K + first_row * num_blocks_per_row + i;
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
const float d = float(data_a[ibi].d);
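// the sign/scale words follow the packed grid indices, hence the QUANT_K / 8 offset; their top 4 bits again hold the block scale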
const uint signscale = pack32(u16vec2(
data_a_packed16[ibi].qs[QUANT_K / 8 + 2 * ib32],
data_a_packed16[ibi].qs[QUANT_K / 8 + 2 * ib32 + 1]));
const float db = d * 0.5 * (0.5 + (signscale >> 28));
[[unroll]] for (uint l = 0; l < 2; ++l) {
const uint qs0 = data_a[ibi].qs[8 * ib32 + 4 * (itid & 1) + 2 * l];
const uint qs1 = data_a[ibi].qs[8 * ib32 + 4 * (itid & 1) + 2 * l + 1];
const uint sign = bitfieldExtract(signscale, 7 * int(2 * (itid & 1) + l), 7);
const uint sign7 = bitCount(sign);
const vec4 grid0 = vec4(unpack8(iq3xxs_grid[qs0]));
const vec4 grid1 = vec4(unpack8(iq3xxs_grid[qs1]));
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
const vec4 b0 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 0]);
const vec4 b4 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 1]);
FLOAT_TYPE sum =
fma(FLOAT_TYPE(b0.x), FLOAT_TYPE((sign & 1) != 0 ? -grid0.x : grid0.x),
fma(FLOAT_TYPE(b0.y), FLOAT_TYPE((sign & 2) != 0 ? -grid0.y : grid0.y),
fma(FLOAT_TYPE(b0.z), FLOAT_TYPE((sign & 4) != 0 ? -grid0.z : grid0.z),
fma(FLOAT_TYPE(b0.w), FLOAT_TYPE((sign & 8) != 0 ? -grid0.w : grid0.w),
fma(FLOAT_TYPE(b4.x), FLOAT_TYPE((sign & 16) != 0 ? -grid1.x : grid1.x),
fma(FLOAT_TYPE(b4.y), FLOAT_TYPE((sign & 32) != 0 ? -grid1.y : grid1.y),
fma(FLOAT_TYPE(b4.z), FLOAT_TYPE((sign & 64) != 0 ? -grid1.z : grid1.z),
fma(FLOAT_TYPE(b4.w), FLOAT_TYPE((sign7 & 1) != 0 ? -grid1.w : grid1.w),
FLOAT_TYPE(0.0)))))))));
temp[j][n] = fma(db, sum, temp[j][n]);
}
}
ibi += num_blocks_per_row;
}
}
void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
uint a_offset, b_offset, d_offset;
get_offsets(a_offset, b_offset, d_offset);
const uint num_blocks_per_row = p.ncols / QUANT_K;
// 16 threads are used to process each block
const uint blocks_per_wg = gl_WorkGroupSize.x/16;
const uint tid = gl_LocalInvocationID.x;
const uint itid = tid % 16; // 0...15
const uint ix = tid / 16;
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
[[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) {
temp[j][i] = FLOAT_TYPE(0);
}
}
[[unroll]] for (uint i = ix; i < num_blocks_per_row; i += blocks_per_wg)
calc_superblock(a_offset, b_offset, itid, i, num_blocks_per_row, first_row, num_rows);
reduce_result(temp, d_offset, first_row, num_rows, tid);
}
void main() {
const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z);
init_iq_shmem(gl_WorkGroupSize);
// do NUM_ROWS at a time, unless there aren't enough remaining rows
if (first_row + NUM_ROWS <= p.stride_d) {
compute_outputs(first_row, NUM_ROWS);
} else {
if (first_row >= p.stride_d) {
return;
}
compute_outputs(first_row, p.stride_d - first_row);
}
}

View File

@ -12,13 +12,18 @@ layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
layout (binding = 2) writeonly buffer D {D_TYPE dst[];};
layout (binding = 0) readonly buffer AV4 {A_TYPE_VEC4 data_a_v4[];};
layout (binding = 1) readonly buffer BV4 {B_TYPE_VEC4 data_b_v4[];};
layout (push_constant) uniform parameter
{
uint ncols_x;
uint nrows_x;
uint row_stride_x;
uint channel_stride_x;
uint channel_stride_y;
uint channel_x_divisor;
uint ne12;
uint b_offset;
uint d_offset;
} p;
@ -30,6 +35,7 @@ void main() {
const uint row_x = gl_GlobalInvocationID.y;
const uint channel = gl_GlobalInvocationID.z;
const uint channel_x = channel / p.channel_x_divisor;
const uint channel_y = channel % p.ne12;
const uint nrows_y = p.ncols_x;
const uint nrows_dst = p.nrows_x;
@ -37,25 +43,66 @@ void main() {
const uint idst = channel*nrows_dst + row_dst;
tmp[tid] = 0.0f;
FLOAT_TYPE temp = 0.0f;
for (uint col_x0 = 0; col_x0 < p.ncols_x; col_x0 += BLOCK_SIZE) {
const uint col_x = col_x0 + tid;
// Detect alignment for vector loads
bool is_aligned = (p.ncols_x % 4) == 0 && (p.row_stride_x % 4) == 0 && (p.channel_stride_x % 4) == 0;
if (col_x >= p.ncols_x) {
break;
for (uint col_x0 = 0; col_x0 < p.ncols_x;) {
// Unroll 2x and do vec4 loads if aligned
const uint unroll_count = 2;
if (col_x0 + unroll_count * 4 * BLOCK_SIZE <= p.ncols_x && is_aligned) {
[[unroll]] for (uint i = 0; i < unroll_count; ++i) {
const uint col_x = col_x0 + 4*tid;
const uint row_y = col_x;
const uint ix = channel_x*p.channel_stride_x + row_x*p.row_stride_x + col_x;
const uint iy = channel_y*p.channel_stride_y + row_y;
const vec4 av4 = vec4(data_a_v4[ix / 4]);
const vec4 bv4 = vec4(data_b_v4[iy / 4]);
temp += dot(av4, bv4);
col_x0 += 4*BLOCK_SIZE;
}
// do vec4 loads if aligned
} else if (col_x0 + 4*BLOCK_SIZE <= p.ncols_x && is_aligned) {
const uint col_x = col_x0 + 4*tid;
const uint row_y = col_x;
const uint ix = channel_x*p.channel_stride_x + row_x*p.row_stride_x + col_x;
const uint iy = channel_y*p.channel_stride_y + row_y;
const vec4 av4 = vec4(data_a_v4[ix / 4]);
const vec4 bv4 = vec4(data_b_v4[iy / 4]);
temp += dot(av4, bv4);
col_x0 += 4*BLOCK_SIZE;
} else {
const uint col_x = col_x0 + tid;
if (col_x >= p.ncols_x) {
break;
}
const uint row_y = col_x;
const uint ix = channel_x*p.channel_stride_x + row_x*p.row_stride_x + col_x;
const uint iy = channel_y*p.channel_stride_y + row_y;
const FLOAT_TYPE xi = FLOAT_TYPE(data_a[ix]);
temp = fma(xi, FLOAT_TYPE(data_b[iy]), temp);
col_x0 += BLOCK_SIZE;
}
const uint row_y = col_x;
const uint ix = channel_x*p.channel_stride_x + row_x*p.row_stride_x + col_x;
const uint iy = channel*nrows_y + row_y;
const FLOAT_TYPE xi = FLOAT_TYPE(data_a[ix]);
tmp[tid] = fma(xi, FLOAT_TYPE(data_b[iy]), tmp[tid]);
}
tmp[tid] = temp;
// sum up partial sums and write back result
barrier();
[[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) {

View File

@ -2,16 +2,25 @@
#extension GL_EXT_control_flow_attributes : enable
#extension GL_EXT_shader_16bit_storage : require
#if USE_SUBGROUP_ADD
#extension GL_KHR_shader_subgroup_arithmetic : enable
#endif
#define BLOCK_SIZE 32
#define FLOAT_TYPE float
layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
layout (binding = 2) writeonly buffer D {D_TYPE dst[];};
layout (binding = 0) readonly buffer AV4 {A_TYPE_VEC4 data_a_v4[];};
layout (binding = 1) readonly buffer BV4 {B_TYPE_VEC4 data_b_v4[];};
layout(constant_id = 0) const int BLOCK_SIZE = 32;
// gqa_ratio is in the range [1,8]
layout(constant_id = 1) const uint gqa_ratio = 1;
layout (push_constant) uniform parameter
{
uint ncols_x;
@ -22,52 +31,124 @@ layout (push_constant) uniform parameter
uint d_offset;
} p;
shared FLOAT_TYPE tmp[BLOCK_SIZE];
#if !USE_SUBGROUP_ADD
shared FLOAT_TYPE tmp[8][BLOCK_SIZE];
#endif
void main() {
const uint tid = gl_LocalInvocationID.x;
const uint row_x = gl_GlobalInvocationID.y;
const uint channel = gl_GlobalInvocationID.z;
const uint channel_x = channel / (p.nchannels_y / p.nchannels_x);
uint channel, channel_x;
// When gqa_ratio > 1, each invocation does multiple rows:
// the row in the A matrix starts at channel / gqa_ratio and the
// rows in the B matrix are [channel, channel+gqa_ratio).
// When gqa_ratio is 1, each invocation does one row.
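// e.g. with gqa_ratio = 4 and gl_GlobalInvocationID.z = 2: channel_x = 2 (one A row), B channels 8..11 are accumulated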
if (gqa_ratio > 1) {
channel_x = gl_GlobalInvocationID.z;
channel = channel_x * gqa_ratio;
} else {
channel = gl_GlobalInvocationID.z;
channel_x = channel / (p.nchannels_y / p.nchannels_x);
}
const uint nrows_y = p.ncols_x;
const uint nrows_dst = p.nrows_x;
const uint row_dst = row_x;
tmp[tid] = FLOAT_TYPE(0.0f);
for (uint col_x0 = 0; col_x0 < p.ncols_x; col_x0 += BLOCK_SIZE) {
const uint col_x = col_x0 + tid;
if (col_x >= p.ncols_x) {
break;
}
// x is transposed and permuted
const uint ix = row_x*p.nchannels_x*p.ncols_x + channel_x*p.ncols_x + col_x;
const FLOAT_TYPE xi = FLOAT_TYPE(data_a[ix]);
const uint row_y = col_x;
// y is not transposed but permuted
const uint iy = channel*nrows_y + row_y;
tmp[tid] = fma(xi, FLOAT_TYPE(data_b[iy]), tmp[tid]);
FLOAT_TYPE temp[8];
[[unroll]] for (uint i = 0; i < 8; ++i) {
temp[i] = FLOAT_TYPE(0.0f);
}
// dst is not transposed and not permuted
const uint idst = channel*nrows_dst + row_dst;
// Detect alignment for vector loads
bool is_aligned = (p.ncols_x % 4) == 0 && (p.nchannels_x % 4) == 0 && (nrows_y % 4) == 0;
for (uint col_x0 = 0; col_x0 < p.ncols_x; col_x0 += BLOCK_SIZE) {
// Use vec4 loads if aligned
if (col_x0 + 4*BLOCK_SIZE <= p.ncols_x && is_aligned) {
uint col_x = col_x0 + 4*tid;
const uint row_y = col_x;
// x is transposed and permuted
const uint ix = row_x*p.nchannels_x*p.ncols_x + channel_x*p.ncols_x + col_x;
const vec4 av4 = vec4(data_a_v4[ix / 4]);
[[unroll]] for (uint c = 0; c < gqa_ratio; ++c) {
// y is not transposed but permuted
const uint iy = (channel + c)*nrows_y + row_y;
vec4 bv4 = data_b_v4[iy / 4];
temp[c] += dot(av4, bv4);
}
col_x0 += 3*BLOCK_SIZE;
} else {
const uint col_x = col_x0 + tid;
if (col_x >= p.ncols_x) {
break;
}
// x is transposed and permuted
const uint ix = row_x*p.nchannels_x*p.ncols_x + channel_x*p.ncols_x + col_x;
const FLOAT_TYPE xi = FLOAT_TYPE(data_a[ix]);
const uint row_y = col_x;
[[unroll]] for (uint c = 0; c < gqa_ratio; ++c) {
// y is not transposed but permuted
const uint iy = (channel + c)*nrows_y + row_y;
temp[c] = fma(xi, FLOAT_TYPE(data_b[iy]), temp[c]);
}
}
}
#if USE_SUBGROUP_ADD
// reduce vec4 at a time
vec4 t = vec4(temp[0], temp[1], temp[2], temp[3]);
t = subgroupAdd(t);
temp[0] = t[0];
temp[1] = t[1];
temp[2] = t[2];
temp[3] = t[3];
if (gqa_ratio > 4) {
t = vec4(temp[4], temp[5], temp[6], temp[7]);
t = subgroupAdd(t);
temp[4] = t[0];
temp[5] = t[1];
temp[6] = t[2];
temp[7] = t[3];
}
#else
[[unroll]] for (uint c = 0; c < gqa_ratio; ++c) {
tmp[c][tid] = temp[c];
}
// sum up partial sums and write back result
barrier();
[[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
if (tid < s) {
tmp[tid] += tmp[tid + s];
[[unroll]] for (uint c = 0; c < gqa_ratio; ++c) {
temp[c] += tmp[c][tid + s];
tmp[c][tid] = temp[c];
}
}
barrier();
}
[[unroll]] for (uint c = 0; c < gqa_ratio; ++c) {
temp[c] = tmp[c][tid];
}
#endif
if (tid == 0) {
dst[idst] = tmp[0];
[[unroll]] for (uint c = 0; c < gqa_ratio; ++c) {
// dst is not transposed and not permuted
const uint idst = (channel + c)*nrows_dst + row_dst;
dst[idst] = temp[c];
}
}
}

View File

@ -5,23 +5,24 @@
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
shared FLOAT_TYPE sccache1[BLOCK_SIZE/16][16];
shared FLOAT_TYPE sccache2[BLOCK_SIZE/16][16];
shared FLOAT_TYPE sccache1[2][BLOCK_SIZE/16][16];
shared FLOAT_TYPE sccache2[2][BLOCK_SIZE/16][16];
FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
uint csel = 0;
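// csel toggles between the two copies of the scale cache (double buffering), so writes for the next block cannot race with reads of the previous block's scales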
void calc_superblock(const uint a_offset, const uint b_offset, const uint itid, const uint v_im, const uint ix, const uint q_offset, const uint y_offset, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows, const bool all_threads) {
const uint y_idx = i * QUANT_K + y_offset;
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
csel ^= 1;
barrier();
if (!all_threads) { // when we don't have enough blocks to use all threads
if (i < num_blocks_per_row) {
const uint32_t scale = uint32_t(data_a[ib0 + i].scales[itid]);
sccache1[ix][itid] = FLOAT_TYPE(scale & 0xF);
sccache2[ix][itid] = FLOAT_TYPE((scale >> 4) & 0xF);
sccache1[csel][ix][itid] = FLOAT_TYPE(scale & 0xF);
sccache2[csel][ix][itid] = FLOAT_TYPE((scale >> 4) & 0xF);
}
barrier();
@ -29,8 +30,8 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint itid,
continue;
} else {
const uint32_t scale = uint32_t(data_a[ib0 + i].scales[itid]);
sccache1[ix][itid] = FLOAT_TYPE(scale & 0xF);
sccache2[ix][itid] = FLOAT_TYPE((scale >> 4) & 0xF);
sccache1[csel][ix][itid] = FLOAT_TYPE(scale & 0xF);
sccache2[csel][ix][itid] = FLOAT_TYPE((scale >> 4) & 0xF);
barrier();
}
@ -57,22 +58,22 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint itid,
FLOAT_TYPE sum1 = FLOAT_TYPE(0.0);
FLOAT_TYPE sum2 = FLOAT_TYPE(0.0);
[[unroll]] for (int l = 0; l < 2; ++l) {
sum1 = fma(FLOAT_TYPE(b0[l]), sccache1[ix][ 8*v_im] * qs_u32_0[l ],
fma(FLOAT_TYPE(b16[l]), sccache1[ix][1 + 8*v_im] * qs_u32_0[l+2],
fma(FLOAT_TYPE(b32[l]), sccache1[ix][2 + 8*v_im] * qs_u32_2[l ],
fma(FLOAT_TYPE(b48[l]), sccache1[ix][3 + 8*v_im] * qs_u32_2[l+2],
fma(FLOAT_TYPE(b64[l]), sccache1[ix][4 + 8*v_im] * qs_u32_4[l ],
fma(FLOAT_TYPE(b80[l]), sccache1[ix][5 + 8*v_im] * qs_u32_4[l+2],
fma(FLOAT_TYPE(b96[l]), sccache1[ix][6 + 8*v_im] * qs_u32_6[l ],
fma(FLOAT_TYPE(b112[l]), sccache1[ix][7 + 8*v_im] * qs_u32_6[l+2], sum1))))))));
sum2 = fma(FLOAT_TYPE(b0[l]), sccache2[ix][ 8*v_im],
fma(FLOAT_TYPE(b16[l]), sccache2[ix][1 + 8*v_im],
fma(FLOAT_TYPE(b32[l]), sccache2[ix][2 + 8*v_im],
fma(FLOAT_TYPE(b48[l]), sccache2[ix][3 + 8*v_im],
fma(FLOAT_TYPE(b64[l]), sccache2[ix][4 + 8*v_im],
fma(FLOAT_TYPE(b80[l]), sccache2[ix][5 + 8*v_im],
fma(FLOAT_TYPE(b96[l]), sccache2[ix][6 + 8*v_im],
fma(FLOAT_TYPE(b112[l]), sccache2[ix][7 + 8*v_im], sum2))))))));
sum1 = fma(FLOAT_TYPE(b0[l]), sccache1[csel][ix][ 8*v_im] * qs_u32_0[l ],
fma(FLOAT_TYPE(b16[l]), sccache1[csel][ix][1 + 8*v_im] * qs_u32_0[l+2],
fma(FLOAT_TYPE(b32[l]), sccache1[csel][ix][2 + 8*v_im] * qs_u32_2[l ],
fma(FLOAT_TYPE(b48[l]), sccache1[csel][ix][3 + 8*v_im] * qs_u32_2[l+2],
fma(FLOAT_TYPE(b64[l]), sccache1[csel][ix][4 + 8*v_im] * qs_u32_4[l ],
fma(FLOAT_TYPE(b80[l]), sccache1[csel][ix][5 + 8*v_im] * qs_u32_4[l+2],
fma(FLOAT_TYPE(b96[l]), sccache1[csel][ix][6 + 8*v_im] * qs_u32_6[l ],
fma(FLOAT_TYPE(b112[l]), sccache1[csel][ix][7 + 8*v_im] * qs_u32_6[l+2], sum1))))))));
sum2 = fma(FLOAT_TYPE(b0[l]), sccache2[csel][ix][ 8*v_im],
fma(FLOAT_TYPE(b16[l]), sccache2[csel][ix][1 + 8*v_im],
fma(FLOAT_TYPE(b32[l]), sccache2[csel][ix][2 + 8*v_im],
fma(FLOAT_TYPE(b48[l]), sccache2[csel][ix][3 + 8*v_im],
fma(FLOAT_TYPE(b64[l]), sccache2[csel][ix][4 + 8*v_im],
fma(FLOAT_TYPE(b80[l]), sccache2[csel][ix][5 + 8*v_im],
fma(FLOAT_TYPE(b96[l]), sccache2[csel][ix][6 + 8*v_im],
fma(FLOAT_TYPE(b112[l]), sccache2[csel][ix][7 + 8*v_im], sum2))))))));
}
temp[j][n] = fma(dall, sum1, fma(-dmin, sum2, temp[j][n]));
}

View File

@ -5,20 +5,21 @@
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
shared FLOAT_TYPE sccache[BLOCK_SIZE/16][2][8];
shared FLOAT_TYPE sccache[2][BLOCK_SIZE/16][2][8];
FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
uint csel = 0;
void calc_superblock(const uint a_offset, const uint b_offset, const uint ix, const uint itid8, const uint v_im, const uint v_im4, const uint v_in, const uint32_t hm_m[4], const uint q_offset, const uint y_offset, const uint s_shift, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows, const bool all_threads) {
const uint y_idx = i * QUANT_K + y_offset;
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
csel ^= 1;
if (!all_threads) { // when we don't have enough blocks to use all threads
barrier();
if (i < num_blocks_per_row)
sccache[ix][v_im][itid8] = FLOAT_TYPE(int8_t(((data_a[ib0+i].scales[itid8] >> v_im4) & 0xF) | (((data_a[ib0+i].scales[itid8%4+8] >> s_shift) & 3) << 4)) - 32);
sccache[csel][ix][v_im][itid8] = FLOAT_TYPE(int8_t(((data_a[ib0+i].scales[itid8] >> v_im4) & 0xF) | (((data_a[ib0+i].scales[itid8%4+8] >> s_shift) & 3) << 4)) - 32);
barrier();
if (i >= num_blocks_per_row)
@ -40,8 +41,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint ix, co
const vec4 qs_u32_6 = vec4(unpack8((qs_u32 >> 6) & 0x03030303));
if (all_threads) {
barrier();
sccache[ix][v_im][itid8] = FLOAT_TYPE(int8_t(((data_a[ib0+i].scales[itid8] >> v_im4) & 0xF) | (((data_a[ib0+i].scales[itid8%4+8] >> s_shift) & 3) << 4)) - 32);
sccache[csel][ix][v_im][itid8] = FLOAT_TYPE(int8_t(((data_a[ib0+i].scales[itid8] >> v_im4) & 0xF) | (((data_a[ib0+i].scales[itid8%4+8] >> s_shift) & 3) << 4)) - 32);
barrier();
}
@ -59,14 +59,14 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint ix, co
FLOAT_TYPE sum = FLOAT_TYPE(0.0);
[[unroll]] for (int l = 0; l < 2; ++l) {
sum = fma(FLOAT_TYPE( b0[l]) * sccache[ix][v_im][0], qs_u32_0[l ] - hmk_0[l ],
fma(FLOAT_TYPE( b16[l]) * sccache[ix][v_im][1], qs_u32_0[l+2] - hmk_0[l+2],
fma(FLOAT_TYPE( b32[l]) * sccache[ix][v_im][2], qs_u32_2[l ] - hmk_1[l ],
fma(FLOAT_TYPE( b48[l]) * sccache[ix][v_im][3], qs_u32_2[l+2] - hmk_1[l+2],
fma(FLOAT_TYPE( b64[l]) * sccache[ix][v_im][4], qs_u32_4[l ] - hmk_2[l ],
fma(FLOAT_TYPE( b80[l]) * sccache[ix][v_im][5], qs_u32_4[l+2] - hmk_2[l+2],
fma(FLOAT_TYPE( b96[l]) * sccache[ix][v_im][6], qs_u32_6[l ] - hmk_3[l ],
fma(FLOAT_TYPE(b112[l]) * sccache[ix][v_im][7], qs_u32_6[l+2] - hmk_3[l+2], sum))))))));
sum = fma(FLOAT_TYPE( b0[l]) * sccache[csel][ix][v_im][0], qs_u32_0[l ] - hmk_0[l ],
fma(FLOAT_TYPE( b16[l]) * sccache[csel][ix][v_im][1], qs_u32_0[l+2] - hmk_0[l+2],
fma(FLOAT_TYPE( b32[l]) * sccache[csel][ix][v_im][2], qs_u32_2[l ] - hmk_1[l ],
fma(FLOAT_TYPE( b48[l]) * sccache[csel][ix][v_im][3], qs_u32_2[l+2] - hmk_1[l+2],
fma(FLOAT_TYPE( b64[l]) * sccache[csel][ix][v_im][4], qs_u32_4[l ] - hmk_2[l ],
fma(FLOAT_TYPE( b80[l]) * sccache[csel][ix][v_im][5], qs_u32_4[l+2] - hmk_2[l+2],
fma(FLOAT_TYPE( b96[l]) * sccache[csel][ix][v_im][6], qs_u32_6[l ] - hmk_3[l ],
fma(FLOAT_TYPE(b112[l]) * sccache[csel][ix][v_im][7], qs_u32_6[l+2] - hmk_3[l+2], sum))))))));
}
temp[j][n] = fma(d, sum, temp[j][n]);
}

View File

@ -6,20 +6,21 @@
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
shared FLOAT_TYPE sccache[BLOCK_SIZE/16][16];
shared FLOAT_TYPE sccache[2][BLOCK_SIZE/16][16];
FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
uint csel = 0;
void calc_superblock(const uint a_offset, const uint b_offset, const uint itid, const uint ix, const uint ql_offset, const uint qh_offset, const uint s_offset, const uint y_offset, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows, const bool all_threads) {
const uint y_idx = i * QUANT_K + y_offset;
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
csel ^= 1;
if (!all_threads) { // when we don't have enough blocks to use all threads
barrier();
if (i < num_blocks_per_row)
sccache[ix][itid] = FLOAT_TYPE(data_a[ib0 + i].scales[itid]);
sccache[csel][ix][itid] = FLOAT_TYPE(data_a[ib0 + i].scales[itid]);
barrier();
if (i >= num_blocks_per_row)
@ -51,8 +52,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint itid,
const vec4 q3 = vec4(unpack8(q3_u32)) - 32;
if (all_threads) {
barrier();
sccache[ix][itid] = FLOAT_TYPE(data_a[ib0 + i].scales[itid]);
sccache[csel][ix][itid] = FLOAT_TYPE(data_a[ib0 + i].scales[itid]);
barrier();
}
@ -71,7 +71,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint itid,
sum[2] = fma(FLOAT_TYPE(by64[l]), q2[l], sum[2]);
sum[3] = fma(FLOAT_TYPE(by96[l]), q3[l], sum[3]);
}
temp[j][n] = fma(fma(sum[0], sccache[ix][s_offset], fma(sum[1], sccache[ix][s_offset + 2], fma(sum[2], sccache[ix][s_offset + 4], sum[3] * sccache[ix][s_offset + 6]))), d, temp[j][n]);
temp[j][n] = fma(fma(sum[0], sccache[csel][ix][s_offset], fma(sum[1], sccache[csel][ix][s_offset + 2], fma(sum[2], sccache[csel][ix][s_offset + 4], sum[3] * sccache[csel][ix][s_offset + 6]))), d, temp[j][n]);
}
}
}

View File

@ -10,6 +10,10 @@
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
#endif
#if defined(DATA_A_BF16) && defined(COOPMAT)
#extension GL_EXT_bfloat16 : enable
#endif
#ifdef COOPMAT
#extension GL_KHR_cooperative_matrix : enable
#extension GL_KHR_memory_scope_semantics : enable
@ -29,9 +33,20 @@
#define LOAD_VEC_B 1
#endif
#if !defined(TO_FLOAT_TYPE)
#define TO_FLOAT_TYPE FLOAT_TYPE
#endif
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
#if defined(A_TYPE_PACKED16)
layout (binding = 0) readonly buffer A_PACKED16 {A_TYPE_PACKED16 data_a_packed16[];};
#endif
#if defined(A_TYPE_PACKED32)
layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32[];};
#endif
layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
@ -88,7 +103,7 @@ shared FLOAT_TYPE buf_a[BM * SHMEM_STRIDE];
shared FLOAT_TYPE buf_b[BN * SHMEM_STRIDE];
#ifdef MUL_MAT_ID
shared u16vec2 row_ids[3072];
shared u16vec2 row_ids[4096];
#endif // MUL_MAT_ID
#define NUM_WARPS (BLOCK_SIZE / WARP)
@ -195,8 +210,8 @@ void main() {
#endif
#ifdef COOPMAT
coopmat<float16_t, gl_ScopeSubgroup, TM, TK, gl_MatrixUseA> cache_a;
coopmat<float16_t, gl_ScopeSubgroup, TK, TN, gl_MatrixUseB> cache_b;
coopmat<FLOAT_TYPE, gl_ScopeSubgroup, TM, TK, gl_MatrixUseA> cache_a;
coopmat<FLOAT_TYPE, gl_ScopeSubgroup, TK, TN, gl_MatrixUseB> cache_b;
coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> sums[cms_per_row * cms_per_col];
[[unroll]] for (uint i = 0; i < cms_per_row * cms_per_col; i++) {
@ -205,7 +220,7 @@ void main() {
#else
ACC_TYPE sums[WMITER * TM * WNITER * TN];
FLOAT_TYPE cache_a[WMITER * TM];
FLOAT_TYPE cache_b[WNITER * TN];
FLOAT_TYPE cache_b[TN];
[[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN; i++) {
sums[i] = ACC_TYPE(0.0f);
@ -241,76 +256,117 @@ void main() {
buf_a[(loadc_a + l) * SHMEM_STRIDE + loadr_a] = FLOAT_TYPE(0.0f);
}
#endif
#elif defined(DATA_A_BF16)
#if LOAD_VEC_A == 4
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
buf_a[buf_idx ] = TO_FLOAT_TYPE(data_a[idx].x);
buf_a[buf_idx + 1] = TO_FLOAT_TYPE(data_a[idx].y);
buf_a[buf_idx + 2] = TO_FLOAT_TYPE(data_a[idx].z);
buf_a[buf_idx + 3] = TO_FLOAT_TYPE(data_a[idx].w);
#else
if (ir * BM + loadc_a + l < p.M && block + loadr_a < end_k) {
buf_a[(loadc_a + l) * SHMEM_STRIDE + loadr_a] = TO_FLOAT_TYPE(data_a[pos_a + (loadc_a + l) * p.stride_a + loadr_a]);
} else {
buf_a[(loadc_a + l) * SHMEM_STRIDE + loadr_a] = TO_FLOAT_TYPE(uint16_t(0));
}
#endif
#elif defined(DATA_A_Q4_0)
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a;
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + 4 * loadr_a;
const uint ib = idx / 16;
const uint iqs = idx & 0xF;
const uint ib = idx / 4;
const uint iqs = idx & 0x03;
const float d = float(data_a[ib].d);
const uint vui = uint(data_a[ib].qs[iqs]);
const vec2 v = (vec2(vui & 0xF, vui >> 4) - 8.0f) * d;
const float d = float(data_a_packed16[ib].d);
const uint vui = uint(data_a_packed16[ib].qs[2*iqs]) | (uint(data_a_packed16[ib].qs[2*iqs + 1]) << 16);
const vec4 v0 = (vec4(unpack8(vui & 0x0F0F0F0F)) - 8.0f) * d;
const vec4 v1 = (vec4(unpack8((vui >> 4) & 0x0F0F0F0F)) - 8.0f) * d;
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
buf_a[buf_idx + 16] = FLOAT_TYPE(v.y);
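// the low nibbles are the first four values of this slice, the high nibbles sit 16 positions later (second half of the block), hence the +16..+19 stores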
buf_a[buf_idx ] = FLOAT_TYPE(v0.x);
buf_a[buf_idx + 1 ] = FLOAT_TYPE(v0.y);
buf_a[buf_idx + 2 ] = FLOAT_TYPE(v0.z);
buf_a[buf_idx + 3 ] = FLOAT_TYPE(v0.w);
buf_a[buf_idx + 16] = FLOAT_TYPE(v1.x);
buf_a[buf_idx + 17] = FLOAT_TYPE(v1.y);
buf_a[buf_idx + 18] = FLOAT_TYPE(v1.z);
buf_a[buf_idx + 19] = FLOAT_TYPE(v1.w);
#elif defined(DATA_A_Q4_1)
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a;
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + 4 * loadr_a;
const uint ib = idx / 16;
const uint iqs = idx & 0xF;
const uint ib = idx / 4;
const uint iqs = idx & 0x03;
const float d = float(data_a[ib].d);
const float m = float(data_a[ib].m);
const uint vui = uint(data_a[ib].qs[iqs]);
const vec2 v = vec2(vui & 0xF, vui >> 4) * d + m;
const float d = float(data_a_packed16[ib].d);
const float m = float(data_a_packed16[ib].m);
const uint vui = uint(data_a_packed16[ib].qs[2*iqs]) | (uint(data_a_packed16[ib].qs[2*iqs + 1]) << 16);
const vec4 v0 = vec4(unpack8(vui & 0x0F0F0F0F)) * d + m;
const vec4 v1 = vec4(unpack8((vui >> 4) & 0x0F0F0F0F)) * d + m;
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
buf_a[buf_idx + 16] = FLOAT_TYPE(v.y);
buf_a[buf_idx ] = FLOAT_TYPE(v0.x);
buf_a[buf_idx + 1 ] = FLOAT_TYPE(v0.y);
buf_a[buf_idx + 2 ] = FLOAT_TYPE(v0.z);
buf_a[buf_idx + 3 ] = FLOAT_TYPE(v0.w);
buf_a[buf_idx + 16] = FLOAT_TYPE(v1.x);
buf_a[buf_idx + 17] = FLOAT_TYPE(v1.y);
buf_a[buf_idx + 18] = FLOAT_TYPE(v1.z);
buf_a[buf_idx + 19] = FLOAT_TYPE(v1.w);
#elif defined(DATA_A_Q5_0)
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a;
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + 2 * loadr_a;
const uint ib = idx / 16;
const uint iqs = idx & 0xF;
const uint ib = idx / 8;
const uint iqs = idx & 0x07;
const float d = float(data_a[ib].d);
const uint uint_qh = uint(data_a[ib].qh[1]) << 16 | data_a[ib].qh[0];
const ivec2 qh = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10);
const uint vui = uint(data_a[ib].qs[iqs]);
const vec2 v = (vec2((vui & 0xF) | qh.x, (vui >> 4) | qh.y) - 16.0f) * d;
const float d = float(data_a_packed16[ib].d);
const uint uint_qh = uint(data_a_packed16[ib].qh[1]) << 16 | uint(data_a_packed16[ib].qh[0]);
const ivec2 qh0 = ivec2(((uint_qh >> 2*iqs) << 4) & 0x10, (uint_qh >> (2*iqs + 12)) & 0x10);
const ivec2 qh1 = ivec2(((uint_qh >> (2*iqs + 1)) << 4) & 0x10, (uint_qh >> (2*iqs + 13)) & 0x10);
const uint vui = uint(data_a_packed16[ib].qs[iqs]);
const vec4 v = (vec4((vui & 0xF) | qh0.x, ((vui >> 4) & 0xF) | qh0.y, ((vui >> 8) & 0xF) | qh1.x, (vui >> 12) | qh1.y) - 16.0f) * d;
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
buf_a[buf_idx + 1 ] = FLOAT_TYPE(v.z);
buf_a[buf_idx + 16] = FLOAT_TYPE(v.y);
buf_a[buf_idx + 17] = FLOAT_TYPE(v.w);
#elif defined(DATA_A_Q5_1)
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a;
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + 2 * loadr_a;
const uint ib = idx / 16;
const uint iqs = idx & 0xF;
const uint ib = idx / 8;
const uint iqs = idx & 0x07;
const float d = float(data_a[ib].d);
const float m = float(data_a[ib].m);
const uint uint_qh = data_a[ib].qh;
const ivec2 qh = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10);
const uint vui = uint(data_a[ib].qs[iqs]);
const vec2 v = vec2((vui & 0xF) | qh.x, (vui >> 4) | qh.y) * d + m;
const float d = float(data_a_packed16[ib].d);
const float m = float(data_a_packed16[ib].m);
const uint uint_qh = data_a_packed16[ib].qh;
const ivec2 qh0 = ivec2(((uint_qh >> 2*iqs) << 4) & 0x10, (uint_qh >> (2*iqs + 12)) & 0x10);
const ivec2 qh1 = ivec2(((uint_qh >> (2*iqs + 1)) << 4) & 0x10, (uint_qh >> (2*iqs + 13)) & 0x10);
const uint vui = uint(data_a_packed16[ib].qs[iqs]);
const vec4 v = vec4((vui & 0xF) | qh0.x, ((vui >> 4) & 0xF) | qh0.y, ((vui >> 8) & 0xF) | qh1.x, (vui >> 12) | qh1.y) * d + m;
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
buf_a[buf_idx + 1 ] = FLOAT_TYPE(v.z);
buf_a[buf_idx + 16] = FLOAT_TYPE(v.y);
buf_a[buf_idx + 17] = FLOAT_TYPE(v.w);
#elif defined(DATA_A_Q8_0)
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
const uint ib = idx / 16;
const uint iqs = (idx & 0xF) * 2;
const uint ib = idx / 8;
const uint iqs = idx & 0x07;
const float d = float(data_a[ib].d);
const vec2 v = vec2(int(data_a[ib].qs[iqs]), int(data_a[ib].qs[iqs + 1])) * d;
const float d = float(data_a_packed16[ib].d);
const i8vec2 v0 = unpack8(int32_t(data_a_packed16[ib].qs[2*iqs])).xy; // vec4 used due to #12147
const i8vec2 v1 = unpack8(int32_t(data_a_packed16[ib].qs[2*iqs + 1])).xy;
const vec4 v = vec4(v0.x, v0.y, v1.x, v1.y) * d;
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
buf_a[buf_idx + 2] = FLOAT_TYPE(v.z);
buf_a[buf_idx + 3] = FLOAT_TYPE(v.w);
#elif defined(DATA_A_Q2_K)
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
@ -511,7 +567,7 @@ void main() {
const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (2 * (idx % 4));
const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign))));
const uint grid = iq2xxs_grid[qs][(idx % 4) / 2] >> (16 * (idx & 1));
const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy);
const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); // vec4 used due to #12147
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
@ -531,7 +587,7 @@ void main() {
const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (2 * (idx % 4));
const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign))));
const uint grid = iq2xs_grid[qs & 511][(idx % 4) / 2] >> (16 * (idx & 1));
const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy);
const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); // vec4 used due to #12147
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
@ -553,7 +609,7 @@ void main() {
const float db = d * 0.25 * (0.5 + scale);
const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign))));
const uint16_t grid = unpack16(iq2s_grid[qs | ((qh << (8 - qhshift)) & 0x300)][(idx & 2) >> 1])[idx & 1];
const vec2 v = db * vec2(sign01) * vec2(unpack8(grid));
const vec2 v = db * vec2(sign01) * vec2(unpack8(uint32_t(grid)).xy); // vec4 used due to #12147
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
@ -578,7 +634,7 @@ void main() {
const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (2 * (idx % 4));
const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign))));
const uint grid = iq3xxs_grid[qs] >> (16 * (idx & 1));
const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy);
const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); // vec4 used due to #12147
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
@ -598,7 +654,7 @@ void main() {
const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(sign << 1, sign)));
const float db = d * (1 + 2 * ((scale >> (4 * (iqh & 1))) & 0xf));
const uint32_t grid = iq3s_grid[qs | ((qh << (8 - (iqs % 8))) & 256)] >> (16 * (idx % 2));
const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy);
const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); // vec4 used due to #12147
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
@ -623,17 +679,18 @@ void main() {
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
#elif defined(DATA_A_IQ4_NL)
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a;
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + 2 * loadr_a;
const uint ib = idx / 16;
const uint iqs = idx & 0xF;
const uint ib = idx / 8;
const uint iqs = idx & 0x07;
const float d = float(data_a[ib].d);
const uint vui = uint(data_a[ib].qs[iqs]);
const vec2 v = vec2(kvalues_iq4nl[vui & 0xF], kvalues_iq4nl[vui >> 4]) * d;
const FLOAT_TYPE d = FLOAT_TYPE(data_a_packed16[ib].d);
const uint vui = uint(data_a_packed16[ib].qs[iqs]);
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
buf_a[buf_idx + 16] = FLOAT_TYPE(v.y);
buf_a[buf_idx ] = FLOAT_TYPE(kvalues_iq4nl[vui & 0xF]) * d;
buf_a[buf_idx + 1 ] = FLOAT_TYPE(kvalues_iq4nl[bitfieldExtract(vui, 8, 4)]) * d;
buf_a[buf_idx + 16] = FLOAT_TYPE(kvalues_iq4nl[bitfieldExtract(vui, 4, 4)]) * d;
buf_a[buf_idx + 17] = FLOAT_TYPE(kvalues_iq4nl[vui >> 12]) * d;
#endif
}
[[unroll]] for (uint l = 0; l < BN; l += loadstride_b) {
@ -661,13 +718,13 @@ void main() {
const uint idx = pos_b + (loadc_b + l) * p.stride_b / LOAD_VEC_B + loadr_b;
#endif
const uint buf_idx = (loadc_b + l) * SHMEM_STRIDE + loadr_b * LOAD_VEC_B;
buf_b[buf_idx + 0] = FLOAT_TYPE(data_b[idx].x);
buf_b[buf_idx + 1] = FLOAT_TYPE(data_b[idx].y);
buf_b[buf_idx + 2] = FLOAT_TYPE(data_b[idx].z);
buf_b[buf_idx + 3] = FLOAT_TYPE(data_b[idx].w);
buf_b[buf_idx + 0] = TO_FLOAT_TYPE(data_b[idx].x);
buf_b[buf_idx + 1] = TO_FLOAT_TYPE(data_b[idx].y);
buf_b[buf_idx + 2] = TO_FLOAT_TYPE(data_b[idx].z);
buf_b[buf_idx + 3] = TO_FLOAT_TYPE(data_b[idx].w);
#elif !MUL_MAT_ID
if (ic * BN + loadc_b + l < p.N && block + loadr_b < end_k) {
buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = FLOAT_TYPE(data_b[pos_b + (loadc_b + l) * p.stride_b + loadr_b]);
buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = TO_FLOAT_TYPE(data_b[pos_b + (loadc_b + l) * p.stride_b + loadr_b]);
} else {
buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = FLOAT_TYPE(0.0f);
}
@ -675,7 +732,7 @@ void main() {
const uint row_i = ic * BN + loadc_b + l;
if (row_i < _ne1) {
const u16vec2 row_idx = row_ids[row_i];
buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = FLOAT_TYPE(data_b[pos_b + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + loadr_b]);
buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = TO_FLOAT_TYPE(data_b[pos_b + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + loadr_b]);
} else {
buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = FLOAT_TYPE(0.0f);
}
@ -710,16 +767,14 @@ void main() {
}
[[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
[[unroll]] for (uint j = 0; j < TN; j++) {
cache_b[wsic * TN + j] = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + j) * SHMEM_STRIDE + i];
cache_b[j] = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + j) * SHMEM_STRIDE + i];
}
}
[[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
[[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
[[unroll]] for (uint cc = 0; cc < TN; cc++) {
[[unroll]] for (uint cr = 0; cr < TM; cr++) {
const uint sums_idx = (wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr;
sums[sums_idx] = fma(ACC_TYPE(cache_a[wsir * TM + cr]), ACC_TYPE(cache_b[wsic * TN + cc]), sums[sums_idx]);
sums[sums_idx] = fma(ACC_TYPE(cache_a[wsir * TM + cr]), ACC_TYPE(cache_b[cc]), sums[sums_idx]);
}
}
}
@ -743,7 +798,7 @@ void main() {
[[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor);
[[unroll]] for (uint col = 0; col < BN; col += storestride) {
[[unroll]] for (uint col = 0; col < TN; col += storestride) {
const uint row_i = dc + cm_col * TN + col + store_c;
if (row_i >= _ne1) break;

View File

@ -14,15 +14,25 @@
#extension GL_EXT_buffer_reference : enable
#extension GL_KHR_shader_subgroup_ballot : enable
#extension GL_KHR_shader_subgroup_vote : enable
#ifdef DATA_A_BF16
#extension GL_EXT_bfloat16 : enable
#endif
#include "types.comp"
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
#define IS_MUL_MM2 1
layout (constant_id = 0) const uint BLOCK_SIZE = 256;
layout (constant_id = 1) const uint BM = 64;
layout (constant_id = 2) const uint BN = 64;
layout (constant_id = 3) const uint BK = 16; // Assumed to be 32 if working with a quant
layout (constant_id = 4) const bool enable_smaller_matrices = false;
const uint BNover2 = enable_smaller_matrices ? (BN / 2) : BN;
const uint BNover4 = enable_smaller_matrices ? (BN / 4) : BN;
layout (push_constant) uniform parameter
{
uint M;
@ -48,6 +58,8 @@ layout (push_constant) uniform parameter
uint broadcast2;
uint broadcast3;
#endif
// N dimension for the B matrix can be >= p.N
uint padded_N;
} p;
@ -64,10 +76,23 @@ layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
#define DECODEFUNCA
#endif
#if !defined(fetch_scales)
#define fetch_scales(a, b, c, d, e, f)
#endif
#if !defined(store_scales)
#define store_scales(a)
#endif
#if defined(DATA_A_BF16)
#define MAT_TYPE bfloat16_t
#else
#define MAT_TYPE FLOAT_TYPE
#endif
#ifdef MUL_MAT_ID
layout (binding = 3) readonly buffer IDS {int data_ids[];};
shared u16vec4 row_ids[3072];
shared u16vec4 row_ids[4096];
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufB {
B_TYPE b[];
@ -110,6 +135,8 @@ void main() {
init_iq_shmem(gl_WorkGroupSize);
#endif
const uint tid = gl_LocalInvocationIndex;
#ifdef MUL_MAT_ID
const uint expert_idx = gl_GlobalInvocationID.z;
#else
@ -166,15 +193,13 @@ void main() {
const uint end_k = min(p.K, (ik + 1) * p.k_split);
#endif
coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator> sum;
sum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator>(0.0);
#ifdef MUL_MAT_ID
uint pos_a = (expert_idx * p.batch_stride_a) / QUANT_K;
uint pos_b = 0;
#else
uint pos_a = (batch_idx_a * p.batch_stride_a) / QUANT_K;
uint pos_b = batch_idx * p.batch_stride_b;
uint pos_d = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z;
#endif
uint stride_a = p.stride_a / QUANT_K;
@ -195,6 +220,7 @@ void main() {
tensorLayoutNV<2> tensorLayoutB = createTensorLayoutNV(2);
tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutBClamp = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV);
tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutD = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV);
tensorLayoutD = setTensorLayoutStrideNV(tensorLayoutD, p.stride_d, 1);
#if QUANT_K > 1
tensorLayoutA = setTensorLayoutBlockSizeNV(tensorLayoutA, 1, QUANT_K);
@ -202,24 +228,32 @@ void main() {
#endif
// Use end_k rather than p.K as the dimension because that's what
// we need to bound check against when using split_k
// we need to bound check against when using split_k.
// Bounds check B against padded_N, but bounds check D against N.
tensorLayoutA = setTensorLayoutDimensionNV(tensorLayoutA, p.M, end_k);
tensorLayoutB = setTensorLayoutDimensionNV(tensorLayoutB, p.N, end_k);
tensorLayoutB = setTensorLayoutDimensionNV(tensorLayoutB, p.padded_N, end_k);
tensorLayoutD = setTensorLayoutDimensionNV(tensorLayoutD, p.N, p.M);
tensorLayoutAClamp = setTensorLayoutDimensionNV(tensorLayoutAClamp, p.M, end_k);
tensorLayoutBClamp = setTensorLayoutDimensionNV(tensorLayoutBClamp, p.N, end_k);
tensorLayoutBClamp = setTensorLayoutDimensionNV(tensorLayoutBClamp, p.padded_N, end_k);
tensorViewNV<2, false, 1, 0> tensorViewTranspose = createTensorViewNV(2, false, 1, 0);
#if !defined(MUL_MAT_ID)
const uint START_ALIGN_K = 256;
// For Qi_K (block size 256), unroll whole 256 element tiles.
// For legacy quants (block size 32), unroll 8x.
const uint UNROLL_K = (QUANT_K == 256) ? 256 : (BK * 8);
const uint unroll_count = UNROLL_K / BK;
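// with the typical BK values (32 for quants per the comment above, 16 otherwise) both cases unroll 8 BK-wide steps per outer iteration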
// Detect a fast path where all loads are entirely in bounds and no clamping is required
if ((ir + 1) * BM <= p.M && (ic + 1) * BN <= p.N && (start_k % BK) == 0 && (end_k % BK) == 0 &&
if ((ir + 1) * BM <= p.M && (ic + 1) * BN <= p.padded_N && (start_k % START_ALIGN_K) == 0 && (end_k % BK) == 0 &&
#if QUANT_K == 1
(stride_a % 8) == 0 &&
#endif
(stride_b % 8) == 0 && (start_k % 8) == 0) {
(stride_b % 8) == 0) {
// Hint to the compiler that values are aligned (want 16B alignment)
start_k &= ~7;
start_k &= ~(START_ALIGN_K-1);
stride_b &= ~7;
#if QUANT_K == 1
stride_a &= ~7;
@ -228,17 +262,131 @@ void main() {
tensorLayoutA = setTensorLayoutStrideNV(tensorLayoutA, stride_a, 1);
tensorLayoutB = setTensorLayoutStrideNV(tensorLayoutB, stride_b, 1);
uint k_iters = (end_k - start_k + BK - 1) / BK;
uint k_iters = (end_k - start_k) / UNROLL_K;
uint block_k = start_k;
for (uint block_k = start_k, i = 0; i < k_iters; block_k += BK, ++i) {
// fetch scale values for a tile of quants. These will be copied into shared memory.
// The fetches and stores are pipelined to hide the latency.
fetch_scales(ir * BM, pos_a, stride_a, start_k, tid, true);
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
if (enable_smaller_matrices && ic * BN + BNover4 >= p.N) {
coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BNover4, gl_MatrixUseAccumulator> sum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BNover4, gl_MatrixUseAccumulator>(0.0);
for (uint i = 0; i < k_iters; ++i) {
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose);
store_scales(tid);
if (block_k + UNROLL_K < end_k) {
fetch_scales(ir * BM, pos_a, stride_a, block_k + UNROLL_K, tid, true);
}
sum = coopMatMulAdd(mat_a, mat_b, sum);
// Manual partial unroll
[[unroll]] for (uint j = 0; j < unroll_count; ++j) {
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BNover4, gl_MatrixUseB> mat_b;
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover4, block_k, BK), tensorViewTranspose);
sum = coopMatMulAdd(mat_a, mat_b, sum);
block_k += BK;
}
}
// Do any remaining iterations that were not unrolled
if (block_k < end_k) {
store_scales(tid);
}
while (block_k < end_k) {
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BNover4, gl_MatrixUseB> mat_b;
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover4, block_k, BK), tensorViewTranspose);
sum = coopMatMulAdd(mat_a, mat_b, sum);
block_k += BK;
}
coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BNover4, gl_MatrixUseAccumulator> mat_d = coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BNover4, gl_MatrixUseAccumulator>(sum);
coopMatStoreTensorNV(mat_d, data_d, pos_d, sliceTensorLayoutNV(tensorLayoutD, ic * BN, BNover4, ir * BM, BM), tensorViewTranspose);
return;
} else if (enable_smaller_matrices && ic * BN + BNover2 >= p.N) {
coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BNover2, gl_MatrixUseAccumulator> sum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BNover2, gl_MatrixUseAccumulator>(0.0);
for (uint i = 0; i < k_iters; ++i) {
store_scales(tid);
if (block_k + UNROLL_K < end_k) {
fetch_scales(ir * BM, pos_a, stride_a, block_k + UNROLL_K, tid, true);
}
// Manual partial unroll
[[unroll]] for (uint j = 0; j < unroll_count; ++j) {
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BNover2, gl_MatrixUseB> mat_b;
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover2, block_k, BK), tensorViewTranspose);
sum = coopMatMulAdd(mat_a, mat_b, sum);
block_k += BK;
}
}
// Do any remaining iterations that were not unrolled
if (block_k < end_k) {
store_scales(tid);
}
while (block_k < end_k) {
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BNover2, gl_MatrixUseB> mat_b;
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover2, block_k, BK), tensorViewTranspose);
sum = coopMatMulAdd(mat_a, mat_b, sum);
block_k += BK;
}
coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BNover2, gl_MatrixUseAccumulator> mat_d = coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BNover2, gl_MatrixUseAccumulator>(sum);
coopMatStoreTensorNV(mat_d, data_d, pos_d, sliceTensorLayoutNV(tensorLayoutD, ic * BN, BNover2, ir * BM, BM), tensorViewTranspose);
return;
} else {
coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator> sum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator>(0.0);
for (uint i = 0; i < k_iters; ++i) {
store_scales(tid);
if (block_k + UNROLL_K < end_k) {
fetch_scales(ir * BM, pos_a, stride_a, block_k + UNROLL_K, tid, true);
}
// Manual partial unroll
[[unroll]] for (uint j = 0; j < unroll_count; ++j) {
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose);
sum = coopMatMulAdd(mat_a, mat_b, sum);
block_k += BK;
}
}
// Do any remaining iterations that were not unrolled
if (block_k < end_k) {
store_scales(tid);
}
while (block_k < end_k) {
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose);
sum = coopMatMulAdd(mat_a, mat_b, sum);
block_k += BK;
}
coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator> mat_d = coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator>(sum);
coopMatStoreTensorNV(mat_d, data_d, pos_d, sliceTensorLayoutNV(tensorLayoutD, ic * BN, BN, ir * BM, BM), tensorViewTranspose);
return;
}
} else
#endif // !defined(MUL_MAT_ID)
@ -251,61 +399,43 @@ void main() {
tensorLayoutBClamp = setTensorLayoutStrideNV(tensorLayoutBClamp, stride_b, 1);
coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator> sum;
sum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator>(0.0);
uint k_iters = (end_k - start_k + BK - 1) / BK;
fetch_scales(ir * BM, pos_a, stride_a, start_k, tid, false);
[[dont_unroll]]
for (uint block_k = start_k; block_k < end_k; block_k += BK) {
for (uint block_k = start_k, i = 0; i < k_iters; block_k += BK, ++i) {
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
// Clamping is expensive, so detect different code paths for each combination
// of A and B needing clamping.
bool unclampedA = (ir + 1) * BM <= p.M && block_k + BK <= end_k && (block_k % 8) == 0;
#ifdef MUL_MAT_ID
bool unclampedB = true;
#else
bool unclampedB = (ic + 1) * BN <= p.N && block_k + BK <= end_k && (block_k % 8) == 0;
#endif
if (unclampedA && unclampedB) {
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, (block_k & ~7), BK) DECODEFUNCA);
#ifdef MUL_MAT_ID
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose, decodeFuncB);
#else
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, (block_k & ~7), BK), tensorViewTranspose);
#endif
sum = coopMatMulAdd(mat_a, mat_b, sum);
} else if (unclampedA && !unclampedB) {
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, (block_k & ~7), BK) DECODEFUNCA);
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutBClamp, ic * BN, BN, block_k, BK), tensorViewTranspose);
sum = coopMatMulAdd(mat_a, mat_b, sum);
} else if (!unclampedA && unclampedB) {
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA);
#ifdef MUL_MAT_ID
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose, decodeFuncB);
#else
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, (block_k & ~7), BK), tensorViewTranspose);
#endif
sum = coopMatMulAdd(mat_a, mat_b, sum);
} else if (!unclampedA && !unclampedB) {
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA);
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutBClamp, ic * BN, BN, block_k, BK), tensorViewTranspose);
sum = coopMatMulAdd(mat_a, mat_b, sum);
store_scales(tid);
if (block_k + BK < end_k) {
fetch_scales(ir * BM, pos_a, stride_a, block_k + BK, tid, false);
}
}
}
// Convert from ACC_TYPE to D_TYPE
coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator> mat_d;
mat_d = coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator>(sum);
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA);
#ifdef MUL_MAT_ID
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose, decodeFuncB);
#else
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutBClamp, ic * BN, BN, block_k, BK), tensorViewTranspose);
#endif
sum = coopMatMulAdd(mat_a, mat_b, sum);
}
// Convert from ACC_TYPE to D_TYPE
coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator> mat_d;
mat_d = coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator>(sum);
#ifdef MUL_MAT_ID
// Call callback to store each element, remapping row through shared memory
coopMatPerElementNV(mat_d, mat_d, perElemOpD, ir, ic);
// Call callback to store each element, remapping row through shared memory
coopMatPerElementNV(mat_d, mat_d, perElemOpD, ir, ic);
#else
tensorLayoutD = setTensorLayoutStrideNV(tensorLayoutD, p.stride_d, 1);
uint pos_d = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z;
coopMatStoreTensorNV(mat_d, data_d, pos_d, sliceTensorLayoutNV(tensorLayoutD, ic * BN, BN, ir * BM, BM), tensorViewTranspose);
coopMatStoreTensorNV(mat_d, data_d, pos_d, sliceTensorLayoutNV(tensorLayoutD, ic * BN, BN, ir * BM, BM), tensorViewTranspose);
#endif
}
}

View File

@ -0,0 +1,442 @@
#version 450
#extension GL_EXT_control_flow_attributes : enable
#extension GL_EXT_shader_16bit_storage : require
#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require
#extension GL_EXT_integer_dot_product : require
#ifdef FLOAT16
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
#endif
#ifdef COOPMAT
#extension GL_KHR_cooperative_matrix : enable
#extension GL_KHR_memory_scope_semantics : enable
#extension GL_KHR_shader_subgroup_basic : enable
#endif
#ifdef MUL_MAT_ID
#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
#endif
#include "types.comp"
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {A_TYPE_PACKED16 data_a[];};
#if defined(A_TYPE_PACKED32)
layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32[];};
#endif
layout (binding = 1) readonly buffer B {block_q8_1_packed32 data_b[];};
layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
#ifdef MUL_MAT_ID
layout (binding = 3) readonly buffer IDS {int data_ids[];};
#endif
layout (push_constant) uniform parameter
{
uint M;
uint N;
uint K;
uint stride_a;
uint stride_b;
uint stride_d;
uint batch_stride_a;
uint batch_stride_b;
uint batch_stride_d;
#ifdef MUL_MAT_ID
uint nei0;
uint nei1;
uint nbi1;
uint ne11;
#else
uint k_split;
uint ne02;
uint ne12;
uint broadcast2;
uint broadcast3;
#endif
} p;
layout (constant_id = 0) const uint BLOCK_SIZE = 64;
layout (constant_id = 1) const uint BM = 64;
layout (constant_id = 2) const uint BN = 64;
// layout (constant_id = 3) const uint BK = 32;
layout (constant_id = 4) const uint WM = 32;
layout (constant_id = 5) const uint WN = 32;
layout (constant_id = 6) const uint WMITER = 2;
layout (constant_id = 7) const uint TM = 4;
layout (constant_id = 8) const uint TN = 2;
layout (constant_id = 9) const uint TK = 1; // Only needed for coopmat
layout (constant_id = 10) const uint WARP = 32;
#define BK 32
#ifdef COOPMAT
#define SHMEM_STRIDE (BK / 4 + 4)
#else
#define SHMEM_STRIDE (BK / 4 + 1)
#endif
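// the +4 / +1 padding on the stride is presumably there to avoid shared-memory bank conflicts between rows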
shared int32_t buf_a_qs[BM * SHMEM_STRIDE];
#ifndef COOPMAT
#if QUANT_AUXF == 1
shared FLOAT_TYPE buf_a_dm[BM];
#else
shared FLOAT_TYPE_VEC2 buf_a_dm[BM];
#endif
#endif
shared int32_t buf_b_qs[BN * SHMEM_STRIDE];
#ifndef COOPMAT
shared FLOAT_TYPE_VEC2 buf_b_ds[BN];
#endif
#define LOAD_VEC_A (4 * QUANT_R)
#define LOAD_VEC_B 4
#ifdef MUL_MAT_ID
shared u16vec2 row_ids[4096];
#endif // MUL_MAT_ID
#define NUM_WARPS (BLOCK_SIZE / WARP)
#ifdef COOPMAT
shared ACC_TYPE coopmat_stage[TM * TN * NUM_WARPS];
#endif
#include "mul_mmq_funcs.comp"
void main() {
#ifdef NEEDS_INIT_IQ_SHMEM
init_iq_shmem(gl_WorkGroupSize);
#endif
#ifdef MUL_MAT_ID
const uint expert_idx = gl_GlobalInvocationID.z;
#else
const uint batch_idx = gl_GlobalInvocationID.z;
const uint i13 = batch_idx / p.ne12;
const uint i12 = batch_idx % p.ne12;
const uint i03 = i13 / p.broadcast3;
const uint i02 = i12 / p.broadcast2;
const uint batch_idx_a = i03 * p.ne02 + i02;
#endif
const uint blocks_m = (p.M + BM - 1) / BM;
const uint ir = gl_WorkGroupID.x % blocks_m;
const uint ik = gl_WorkGroupID.x / blocks_m;
const uint ic = gl_WorkGroupID.y;
const uint WNITER = (WM * WN) / (WARP * TM * TN * WMITER);
const uint WSUBM = WM / WMITER;
const uint WSUBN = WN / WNITER;
#ifdef COOPMAT
const uint warp_i = gl_SubgroupID;
const uint tiw = gl_SubgroupInvocationID;
const uint cms_per_row = WM / TM;
const uint cms_per_col = WN / TN;
const uint storestride = WARP / TM;
const uint store_r = tiw % TM;
const uint store_c = tiw / TM;
#else
const uint warp_i = gl_LocalInvocationID.x / WARP;
const uint tiw = gl_LocalInvocationID.x % WARP;
const uint tiwr = tiw % (WSUBM / TM);
const uint tiwc = tiw / (WSUBM / TM);
#endif
const uint warp_r = warp_i % (BM / WM);
const uint warp_c = warp_i / (BM / WM);
const uint loadr_a = gl_LocalInvocationID.x % (BK / LOAD_VEC_A);
const uint loadc_a = gl_LocalInvocationID.x / (BK / LOAD_VEC_A);
const uint loadr_b = gl_LocalInvocationID.x % (BK / LOAD_VEC_B);
const uint loadc_b = gl_LocalInvocationID.x / (BK / LOAD_VEC_B);
const uint loadstride_a = BLOCK_SIZE * LOAD_VEC_A / BK;
const uint loadstride_b = BLOCK_SIZE * LOAD_VEC_B / BK;
#ifdef MUL_MAT_ID
uint _ne1 = 0;
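    // Gather the (ii0, ii1) index pairs routed to this expert into row_ids; _ne1 counts them.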
for (uint ii1 = 0; ii1 < p.nei1; ii1++) {
for (uint ii0 = 0; ii0 < p.nei0; ii0++) {
if (data_ids[ii1*p.nbi1 + ii0] == expert_idx) {
row_ids[_ne1] = u16vec2(ii0, ii1);
_ne1++;
}
}
}
barrier();
// Workgroup has no work
if (ic * BN >= _ne1) return;
#endif
#ifdef MUL_MAT_ID
const uint start_k = 0;
const uint end_k = p.K;
#else
const uint start_k = ik * p.k_split;
const uint end_k = min(p.K, (ik + 1) * p.k_split);
#endif
uint pos_a_ib = (
#ifdef MUL_MAT_ID
expert_idx * p.batch_stride_a +
#else
batch_idx_a * p.batch_stride_a +
#endif
ir * BM * p.stride_a + start_k) / BK;
#ifdef MUL_MAT_ID
uint pos_b_ib = 0;
#else
uint pos_b_ib = (batch_idx * p.batch_stride_b + ic * BN * p.stride_b + start_k) / BK;
#endif
#ifdef COOPMAT
coopmat<int8_t, gl_ScopeSubgroup, TM, TK, gl_MatrixUseA> cache_a;
coopmat<int8_t, gl_ScopeSubgroup, TK, TN, gl_MatrixUseB> cache_b;
coopmat<int32_t, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> cm_result;
coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> factors[cms_per_row * cms_per_col];
coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> sums[cms_per_row * cms_per_col];
[[unroll]] for (uint i = 0; i < cms_per_row * cms_per_col; i++) {
sums[i] = coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(0.0f);
}
#else
int32_t cache_a_qs[WMITER * TM * BK / 4];
int32_t cache_b_qs[TN * BK / 4];
ACC_TYPE sums[WMITER * TM * WNITER * TN];
[[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN; i++) {
sums[i] = ACC_TYPE(0.0f);
}
#endif
#if QUANT_AUXF == 1
FLOAT_TYPE cache_a_dm[WMITER * TM];
#else
FLOAT_TYPE_VEC2 cache_a_dm[WMITER * TM];
#endif
FLOAT_TYPE_VEC2 cache_b_ds[TN];
for (uint block = start_k; block < end_k; block += BK) {
[[unroll]] for (uint l = 0; loadc_a + l < BM; l += loadstride_a) {
const uint ib = pos_a_ib + (loadc_a + l) * p.stride_a / BK;
const uint iqs = loadr_a;
const uint buf_ib = loadc_a + l;
if (iqs == 0) {
#if QUANT_AUXF == 1
buf_a_dm[buf_ib] = get_d(ib);
#else
buf_a_dm[buf_ib] = get_dm(ib);
#endif
}
#if QUANT_R == 1
buf_a_qs[buf_ib * SHMEM_STRIDE + iqs] = repack(ib, iqs);
#else
const i32vec2 vals = repack(ib, iqs);
buf_a_qs[buf_ib * SHMEM_STRIDE + iqs ] = vals.x;
buf_a_qs[buf_ib * SHMEM_STRIDE + iqs + 4] = vals.y;
#endif
}
[[unroll]] for (uint l = 0; loadc_b + l < BN; l += loadstride_b) {
#ifdef MUL_MAT_ID
const u16vec2 row_idx = row_ids[ic * BN + loadc_b + l];
const uint idx = pos_b_ib + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + loadr_b;
const uint ib = idx / 8;
const uint iqs = idx & 0x7;
#else
const uint ib = pos_b_ib + (loadc_b + l) * p.stride_b / BK;
const uint iqs = loadr_b;
#endif
const uint buf_ib = loadc_b + l;
if (iqs == 0) {
buf_b_ds[buf_ib] = FLOAT_TYPE_VEC2(data_b[ib].ds);
}
buf_b_qs[buf_ib * SHMEM_STRIDE + iqs] = data_b[ib].qs[iqs];
}
barrier();
pos_a_ib += 1;
pos_b_ib += 1;
#ifdef COOPMAT
[[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
const uint ib_a = warp_r * WM + cm_row * TM;
// Load from shared into cache
coopMatLoad(cache_a, buf_a_qs, ib_a * SHMEM_STRIDE, SHMEM_STRIDE, gl_CooperativeMatrixLayoutRowMajor);
// TODO: only cache values that are actually needed
[[unroll]] for (uint t_idx = 0; t_idx < TM; t_idx++) {
cache_a_dm[t_idx] = buf_a_dm[ib_a + t_idx];
}
[[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
const uint ib_b = warp_c * WN + cm_col * TN;
coopMatLoad(cache_b, buf_b_qs, ib_b * SHMEM_STRIDE, SHMEM_STRIDE, gl_CooperativeMatrixLayoutColumnMajor);
// TODO: only cache values that are actually needed
[[unroll]] for (uint t_idx = 0; t_idx < TN; t_idx++) {
                cache_b_ds[t_idx] = buf_b_ds[ib_b + t_idx];
}
cm_result = coopmat<int32_t, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(0);
cm_result = coopMatMulAdd(cache_a, cache_b, cm_result);
[[unroll]] for (uint col = 0; col < TN; col += storestride) {
                    // Assumes QUANT_AUXF == 1 (scalar A scale); cache_b_ds.x holds the q8_1 scale d.
                    coopmat_stage[warp_i * TM * TN + (store_c + col) * TM + store_r] = ACC_TYPE(float(cache_a_dm[store_r]) * float(cache_b_ds[store_c + col].x));
}
coopMatLoad(factors, coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor);
sums[cm_col * cms_per_row + cm_row] += factors * coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(cm_result);
}
}
#else
// Load from shared into cache
[[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
[[unroll]] for (uint cr = 0; cr < TM; cr++) {
const uint ib = warp_r * WM + wsir * WSUBM + tiwr * TM + cr;
cache_a_dm[wsir * TM + cr] = buf_a_dm[ib];
[[unroll]] for (uint idx_k = 0; idx_k < BK / 4; idx_k++) {
cache_a_qs[(wsir * TM + cr) * (BK / 4) + idx_k] = buf_a_qs[ib * SHMEM_STRIDE + idx_k];
}
}
}
[[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
[[unroll]] for (uint cc = 0; cc < TN; cc++) {
const uint ib = warp_c * WN + wsic * WSUBN + tiwc * TN + cc;
cache_b_ds[cc] = buf_b_ds[ib];
[[unroll]] for (uint idx_k = 0; idx_k < BK / 4; idx_k++) {
cache_b_qs[cc * (BK / 4) + idx_k] = buf_b_qs[ib * SHMEM_STRIDE + idx_k];
}
}
[[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
[[unroll]] for (uint cc = 0; cc < TN; cc++) {
[[unroll]] for (uint cr = 0; cr < TM; cr++) {
const uint cache_a_idx = wsir * TM + cr;
const uint sums_idx = (wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr;
int32_t q_sum = 0;
[[unroll]] for (uint idx_k = 0; idx_k < BK / 4; idx_k++) {
q_sum += dotPacked4x8EXT(cache_a_qs[cache_a_idx * (BK / 4) + idx_k],
cache_b_qs[cc * (BK / 4) + idx_k]);
}
sums[sums_idx] += mul_q8_1(q_sum, cache_a_dm[cache_a_idx], cache_b_ds[cc]);
}
}
}
}
#endif
barrier();
}
const uint dr = ir * BM + warp_r * WM;
const uint dc = ic * BN + warp_c * WN;
#ifndef MUL_MAT_ID
const uint offsets = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z;
#endif
#ifdef COOPMAT
#ifdef MUL_MAT_ID
[[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
[[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor);
[[unroll]] for (uint col = 0; col < BN; col += storestride) {
const uint row_i = dc + cm_col * TN + col + store_c;
if (row_i >= _ne1) break;
const u16vec2 row_idx = row_ids[row_i];
data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]);
}
}
}
#else
const bool is_aligned = p.stride_d % 4 == 0; // Assumption: D_TYPE == float
[[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
[[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
const bool is_in_bounds = dr + (cm_row + 1) * TM <= p.M && dc + (cm_col + 1) * TN <= p.N;
if (is_aligned && is_in_bounds) {
// Full coopMat is within bounds and stride_d is aligned with 16B
coopmat<D_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> cm_dtype = coopmat<D_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(sums[cm_col * cms_per_row + cm_row]);
coopMatStore(cm_dtype, data_d, offsets + (dc + cm_col * TN) * p.stride_d + dr + cm_row * TM, p.stride_d, gl_CooperativeMatrixLayoutColumnMajor);
} else if (is_in_bounds) {
// Full coopMat is within bounds, but stride_d is not aligned
coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor);
[[unroll]] for (uint col = 0; col < TN; col += storestride) {
data_d[offsets + (dc + cm_col * TN + col + store_c) * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]);
}
} else if (dr + cm_row * TM < p.M && dc + cm_col * TN < p.N) {
// Partial coopMat is within bounds
coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor);
[[unroll]] for (uint col = 0; col < TN; col += storestride) {
if (dr + cm_row * TM + store_r < p.M && dc + cm_col * TN + col + store_c < p.N) {
data_d[offsets + (dc + cm_col * TN + col + store_c) * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]);
}
}
}
}
}
#endif // MUL_MAT_ID
#else
[[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
[[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
const uint dr_warp = dr + wsir * WSUBM + tiwr * TM;
const uint dc_warp = dc + wsic * WSUBN + tiwc * TN;
[[unroll]] for (uint cc = 0; cc < TN; cc++) {
#ifdef MUL_MAT_ID
const uint row_i = dc_warp + cc;
if (row_i >= _ne1) break;
const u16vec2 row_idx = row_ids[row_i];
#endif // MUL_MAT_ID
[[unroll]] for (uint cr = 0; cr < TM; cr++) {
#ifdef MUL_MAT_ID
data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]);
#else
if (dr_warp + cr < p.M && dc_warp + cc < p.N) {
data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]);
}
#endif // MUL_MAT_ID
}
}
}
}
#endif // COOPMAT
}
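For reference, dotPacked4x8EXT (from GL_EXT_integer_dot_product) is the core of the scalar inner loop above: it treats each 32-bit operand as four signed 8-bit values and sums the four products into a 32-bit integer. A minimal CPU sketch of that per-lane operation, assuming little-endian byte packing:

#include <cstdint>
#include <cstdio>

// Hypothetical CPU equivalent of GLSL dotPacked4x8EXT on two int8x4 values
// packed into 32-bit integers (little-endian byte order assumed).
static int32_t dot_packed_4x8(int32_t a, int32_t b) {
    int32_t sum = 0;
    for (int i = 0; i < 4; ++i) {
        const int8_t ai = (int8_t)((uint32_t)a >> (8 * i));
        const int8_t bi = (int8_t)((uint32_t)b >> (8 * i));
        sum += (int32_t)ai * (int32_t)bi;
    }
    return sum;
}

int main() {
    // 4 signed bytes per operand: a = {1, -2, 3, -4}, b = {5, 6, -7, 8}
    const int32_t a = (int32_t)(0x01u | (0xFEu << 8) | (0x03u << 16) | (0xFCu << 24));
    const int32_t b = (int32_t)(0x05u | (0x06u << 8) | (0xF9u << 16) | (0x08u << 24));
    printf("%d\n", dot_packed_4x8(a, b)); // 1*5 + (-2)*6 + 3*(-7) + (-4)*8 = -60
    return 0;
}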

View File

@ -0,0 +1,99 @@
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require
#include "types.comp"
// Each iqs value maps to a 32-bit integer
#if defined(DATA_A_Q4_0)
i32vec2 repack(uint ib, uint iqs) {
// Use 2-byte loads since a q4_0 block (18 bytes) is not divisible by 4
const u16vec2 quants = u16vec2(data_a[ib].qs[iqs * 2 ],
data_a[ib].qs[iqs * 2 + 1]);
const uint32_t vui = pack32(quants);
return i32vec2( vui & 0x0F0F0F0F,
(vui >> 4) & 0x0F0F0F0F);
}
ACC_TYPE mul_q8_1(int32_t q_sum, float da, vec2 dsb) {
return ACC_TYPE(da * (float(q_sum) * dsb.x - 8.0f * dsb.y));
}
#endif
#if defined(DATA_A_Q4_1)
i32vec2 repack(uint ib, uint iqs) {
// Use 4-byte loads since a q4_1 block (20 bytes) is divisible by 4
const uint32_t vui = data_a_packed32[ib].qs[iqs];
return i32vec2( vui & 0x0F0F0F0F,
(vui >> 4) & 0x0F0F0F0F);
}
ACC_TYPE mul_q8_1(int32_t q_sum, vec2 dma, vec2 dsb) {
return ACC_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y);
}
#endif
#if defined(DATA_A_Q5_0)
i32vec2 repack(uint ib, uint iqs) {
// Use 2-byte loads since a q5_0 block (22 bytes) is not divisible by 4
const u16vec2 quants = u16vec2(data_a[ib].qs[iqs * 2 ],
data_a[ib].qs[iqs * 2 + 1]);
const uint32_t vui = pack32(quants);
const int32_t qh = int32_t((uint32_t(data_a[ib].qh[1]) << 16 | data_a[ib].qh[0]) >> (4 * iqs));
const int32_t v0 = int32_t(vui & 0x0F0F0F0F)
| ((qh & 0xF) * 0x02040810) & 0x10101010; // (0,1,2,3) -> (4,12,20,28)
const int32_t v1 = int32_t((vui >> 4) & 0x0F0F0F0F)
| (((qh >> 16) & 0xF) * 0x02040810) & 0x10101010; // (16,17,18,19) -> (4,12,20,28)
return i32vec2(v0, v1);
}
ACC_TYPE mul_q8_1(int32_t q_sum, float da, vec2 dsb) {
return ACC_TYPE(da * (float(q_sum) * dsb.x - 16.0f * dsb.y));
}
#endif
#if defined(DATA_A_Q5_1)
i32vec2 repack(uint ib, uint iqs) {
// Use 4-byte loads since a q5_1 block (24 bytes) is divisible by 4
const uint32_t vui = data_a_packed32[ib].qs[iqs];
const int32_t qh = int32_t(data_a_packed32[ib].qh >> (4 * iqs));
const int32_t v0 = int32_t(vui & 0x0F0F0F0F)
| ((qh & 0xF) * 0x02040810) & 0x10101010; // (0,1,2,3) -> (4,12,20,28)
const int32_t v1 = int32_t((vui >> 4) & 0x0F0F0F0F)
| (((qh >> 16) & 0xF) * 0x02040810) & 0x10101010; // (16,17,18,19) -> (4,12,20,28)
return i32vec2(v0, v1);
}
ACC_TYPE mul_q8_1(int32_t q_sum, vec2 dma, vec2 dsb) {
return ACC_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y);
}
#endif
#if defined(DATA_A_Q8_0)
int32_t repack(uint ib, uint iqs) {
// Use 2-byte loads since a q8_0 block (34 bytes) is not divisible by 4
return pack32(i16vec2(data_a[ib].qs[iqs * 2 ],
data_a[ib].qs[iqs * 2 + 1]));
}
ACC_TYPE mul_q8_1(int32_t q_sum, float da, vec2 dsb) {
return ACC_TYPE(float(q_sum) * da * dsb.x);
}
#endif
#if defined(DATA_A_Q4_0) || defined(DATA_A_Q5_0) || defined(DATA_A_Q8_0) || defined(DATA_A_IQ1_S) || defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL)
FLOAT_TYPE get_d(uint ib) {
return FLOAT_TYPE(data_a[ib].d);
}
#endif
#if defined(DATA_A_Q4_1) || defined(DATA_A_Q5_1)
FLOAT_TYPE_VEC2 get_dm(uint ib) {
return FLOAT_TYPE_VEC2(data_a_packed32[ib].dm);
}
#endif
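The q4_0 path above relies on an identity: the 4-bit values are dotted against the q8_1 quants without subtracting the 8 offset, and the offset is corrected afterwards through the q8_1 block sum stored in ds.y. A minimal CPU sketch checking that identity with illustrative values:

#include <cstdio>

// Minimal sketch (CPU, float) of the q4_0 x q8_1 rescale identity used by mul_q8_1():
// sum_i (da*(qa_i - 8)) * (db*qb_i) == da * (q_sum*db - 8*(db*sum_qb))
// where q_sum is the raw integer dot over the unshifted 4-bit values qa_i in [0,15].
int main() {
    const int   qa[4] = {0, 7, 15, 8};       // unpacked q4_0 nibbles (still offset by +8)
    const int   qb[4] = {12, -3, 127, -90};  // q8_1 quants
    const float da    = 0.5f;                // q4_0 scale
    const float db    = 0.25f;               // q8_1 scale (ds.x)

    float ref   = 0.0f;
    int   q_sum = 0, sum_qb = 0;
    for (int i = 0; i < 4; ++i) {
        ref    += da * (qa[i] - 8) * (db * qb[i]);
        q_sum  += qa[i] * qb[i];
        sum_qb += qb[i];
    }
    const float dsb_y = db * sum_qb;                      // q8_1 ds.y = d * sum of quants
    const float mmq   = da * (q_sum * db - 8.0f * dsb_y); // what mul_q8_1() returns
    printf("ref=%f mmq=%f\n", ref, mmq);                  // equal up to float rounding
    return 0;
}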

View File

@ -0,0 +1,77 @@
#version 450
#extension GL_EXT_control_flow_attributes : require
#extension GL_EXT_shader_16bit_storage : require
layout (push_constant) uniform parameter
{
uint ne;
} p;
#include "types.comp"
layout(constant_id = 0) const uint GROUP_SIZE = 32;
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {vec4 data_a[];};
layout (binding = 1) writeonly buffer D {block_q8_1_packed32 data_b[];};
shared float shmem[GROUP_SIZE];
void quantize() {
const uint wgid = gl_WorkGroupID.x;
const uint tid = gl_LocalInvocationID.x;
// Each thread handles a vec4, so 8 threads handle a block
const uint blocks_per_group = GROUP_SIZE / 8;
const uint block_in_wg = tid / 8;
const uint ib = wgid * blocks_per_group + block_in_wg;
const uint iqs = tid % 8;
if (ib >= gl_NumWorkGroups.x * blocks_per_group) {
return;
}
const uint a_idx = ib * 8 + iqs;
vec4 vals = a_idx < p.ne ? data_a[a_idx] : vec4(0.0f);
const vec4 abs_vals = abs(vals);
// Find absolute max for each block
shmem[tid] = max(max(abs_vals.x, abs_vals.y), max(abs_vals.z, abs_vals.w));
barrier();
[[unroll]] for (uint s = 4; s > 0; s >>= 1) {
if (iqs < s) {
shmem[tid] = max(shmem[tid], shmem[tid + s]);
}
barrier();
}
const float amax = shmem[block_in_wg * 8];
const float d = amax / 127.0;
const float d_inv = d != 0.0 ? 1.0 / d : 0.0;
vals = round(vals * d_inv);
data_b[ib].qs[iqs] = pack32(i8vec4(round(vals)));
barrier();
// Calculate the sum for each block
shmem[tid] = vals.x + vals.y + vals.z + vals.w;
barrier();
[[unroll]] for (uint s = 4; s > 0; s >>= 1) {
if (iqs < s) {
shmem[tid] += shmem[tid + s];
}
barrier();
}
if (iqs == 0) {
const float sum = shmem[tid];
data_b[ib].ds = f16vec2(vec2(d, sum * d));
}
}
void main() {
quantize();
}
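A rough CPU reference of the same per-block math (scale d = amax/127, rounded quants, ds = (d, d * sum of quants)); names and types here are illustrative, rounding-mode details aside, and the shader additionally stores ds as fp16:

#include <cmath>
#include <cstdint>
#include <cstdio>

// Illustrative CPU sketch of one q8_1 block, mirroring the shader's math.
struct BlockQ8_1 {
    float  d;
    float  ds_y;    // d * sum of the quantized values
    int8_t qs[32];
};

static BlockQ8_1 quantize_q8_1_block(const float *x) {
    float amax = 0.0f;
    for (int i = 0; i < 32; ++i) amax = std::fmax(amax, std::fabs(x[i]));
    const float d     = amax / 127.0f;
    const float d_inv = d != 0.0f ? 1.0f / d : 0.0f;

    BlockQ8_1 b{};
    int sum = 0;
    for (int i = 0; i < 32; ++i) {
        const int q = (int)std::lrintf(x[i] * d_inv);
        b.qs[i] = (int8_t)q;
        sum    += q;
    }
    b.d    = d;
    b.ds_y = d * (float)sum;
    return b;
}

int main() {
    float x[32];
    for (int i = 0; i < 32; ++i) x[i] = 0.1f * (float)(i - 16);
    const BlockQ8_1 b = quantize_q8_1_block(x);
    std::printf("d=%f ds.y=%f q0=%d\n", b.d, b.ds_y, (int)b.qs[0]);
    return 0;
}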

View File

@ -17,5 +17,5 @@ void main() {
return;
}
data_d[i] = max(float(data_a[i]), 0);
data_d[i] = D_TYPE(max(float(data_a[i]), 0));
}

View File

@ -1,6 +1,6 @@
#version 450
#include "generic_head.comp"
#include "generic_unary_head.comp"
#include "types.comp"
#extension GL_EXT_control_flow_attributes : enable
@ -8,19 +8,29 @@
layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
shared FLOAT_TYPE sum[BLOCK_SIZE];
void main() {
const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
const uint tid = gl_LocalInvocationID.x;
const uint ncols = p.ne00;
const uint nrows = gl_NumWorkGroups.x;
const uint nchannels = gl_NumWorkGroups.y;
const uint row = gl_WorkGroupID.x;
const uint channel = gl_WorkGroupID.y;
const uint samp = gl_WorkGroupID.z;
const uint tid = gl_LocalInvocationID.x;
const uint stride_row = p.nb01;
const uint stride_channel = p.nb02;
const uint stride_sample = p.nb03;
uint32_t a_offset = samp*stride_sample + channel*stride_channel + row*stride_row + get_aoffset();
uint32_t d_offset = ((samp*nchannels + channel)*nrows + row)*ncols + get_doffset();
sum[tid] = FLOAT_TYPE(0.0f); // partial sum for thread in warp
[[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) {
const FLOAT_TYPE xi = FLOAT_TYPE(data_a[row*p.KX + col]);
[[unroll]] for (uint col = tid; col < ncols; col += BLOCK_SIZE) {
const FLOAT_TYPE xi = FLOAT_TYPE(data_a[a_offset + col]);
sum[tid] += xi * xi;
}
@ -33,10 +43,10 @@ void main() {
barrier();
}
const FLOAT_TYPE mean = sum[0] / FLOAT_TYPE(p.KX);
const FLOAT_TYPE mean = sum[0] / FLOAT_TYPE(ncols);
const FLOAT_TYPE scale = inversesqrt(mean + FLOAT_TYPE(p.param1));
[[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) {
data_d[row*p.KX + col] = D_TYPE(scale * FLOAT_TYPE(data_a[row*p.KX + col]));
[[unroll]] for (uint col = tid; col < ncols; col += BLOCK_SIZE) {
data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]));
}
}

View File

@ -16,5 +16,5 @@ void main() {
if (i >= p.KX) {
return;
}
data_d[i] = D_TYPE(1. / (1 + exp(-1. *data_a[i])));
data_d[i] = D_TYPE(1. / (1 + exp(-1. * float(data_a[i]))));
}

View File

@ -16,5 +16,5 @@ void main() {
if (i >= p.KX) {
return;
}
data_d[i] = D_TYPE(1. - 2. / (exp(2.*data_a[i]) + 1.));
data_d[i] = D_TYPE(1. - 2. / (exp(2.*float(data_a[i])) + 1.));
}

View File

@ -0,0 +1,7 @@
#version 460
#extension GL_EXT_bfloat16 : require
void main()
{
}

View File

@ -0,0 +1,7 @@
#version 460
#extension GL_EXT_integer_dot_product : require
void main()
{
}

View File

@ -1,7 +1,7 @@
#if !defined(GGML_TYPES_COMP)
#define GGML_TYPES_COMP
#extension GL_EXT_shader_explicit_arithmetic_types_int64 : require
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require
@ -33,6 +33,19 @@
#endif
#endif
#if defined(DATA_A_BF16)
#define QUANT_K 1
#define QUANT_R 1
#if !defined(LOAD_VEC_A) || LOAD_VEC_A == 1
#define A_TYPE uint16_t
#elif LOAD_VEC_A == 4
#define A_TYPE u16vec4
#elif LOAD_VEC_A == 8
#error unsupported
#endif
#endif
#define QUANT_K_Q4_0 32
#define QUANT_R_Q4_0 2
@ -50,6 +63,7 @@ struct block_q4_0_packed16
#if defined(DATA_A_Q4_0)
#define QUANT_K QUANT_K_Q4_0
#define QUANT_R QUANT_R_Q4_0
#define QUANT_AUXF 1
#define A_TYPE block_q4_0
#define A_TYPE_PACKED16 block_q4_0_packed16
#endif
@ -71,11 +85,19 @@ struct block_q4_1_packed16
uint16_t qs[16/2];
};
struct block_q4_1_packed32
{
f16vec2 dm;
uint32_t qs[16/4];
};
#if defined(DATA_A_Q4_1)
#define QUANT_K QUANT_K_Q4_1
#define QUANT_R QUANT_R_Q4_1
#define QUANT_AUXF 2
#define A_TYPE block_q4_1
#define A_TYPE_PACKED16 block_q4_1_packed16
#define A_TYPE_PACKED32 block_q4_1_packed32
#endif
#define QUANT_K_Q5_0 32
@ -98,6 +120,7 @@ struct block_q5_0_packed16
#if defined(DATA_A_Q5_0)
#define QUANT_K QUANT_K_Q5_0
#define QUANT_R QUANT_R_Q5_0
#define QUANT_AUXF 1
#define A_TYPE block_q5_0
#define A_TYPE_PACKED16 block_q5_0_packed16
#endif
@ -121,11 +144,20 @@ struct block_q5_1_packed16
uint16_t qs[16/2];
};
struct block_q5_1_packed32
{
f16vec2 dm;
uint qh;
uint32_t qs[16/4];
};
#if defined(DATA_A_Q5_1)
#define QUANT_K QUANT_K_Q5_1
#define QUANT_R QUANT_R_Q5_1
#define QUANT_AUXF 2
#define A_TYPE block_q5_1
#define A_TYPE_PACKED16 block_q5_1_packed16
#define A_TYPE_PACKED32 block_q5_1_packed32
#endif
#define QUANT_K_Q8_0 32
@ -139,16 +171,42 @@ struct block_q8_0
struct block_q8_0_packed16
{
float16_t d;
uint16_t qs[32/2];
int16_t qs[32/2];
};
struct block_q8_0_packed32
{
float16_t d;
int32_t qs[32/4];
};
#if defined(DATA_A_Q8_0)
#define QUANT_K QUANT_K_Q8_0
#define QUANT_R QUANT_R_Q8_0
#define QUANT_AUXF 1
#define A_TYPE block_q8_0
#define A_TYPE_PACKED16 block_q8_0_packed16
#define A_TYPE_PACKED32 block_q8_0_packed32
#endif
#define QUANT_K_Q8_1 32
#define QUANT_R_Q8_1 1
struct block_q8_1
{
f16vec2 ds;
int8_t qs[32];
};
struct block_q8_1_packed16
{
f16vec2 ds;
int16_t qs[16];
};
struct block_q8_1_packed32
{
f16vec2 ds;
int32_t qs[8];
};
// K-quants
#define QUANT_K_Q2_K 256
@ -312,6 +370,12 @@ struct block_iq1_m {
uint16_t scales[QUANT_K_IQ1_M/64];
};
struct block_iq1_m_packed64 {
uint64_t qs[QUANT_K_IQ1_M/8/8];
uint64_t qh[QUANT_K_IQ1_M/16/8];
uint64_t scales;
};
#if defined(DATA_A_IQ1_S)
#define QUANT_K QUANT_K_IQ1_S
#define QUANT_R QUANT_R_IQ1_S
@ -466,10 +530,13 @@ shared uint16_t iq1s_grid[2048];
void init_iq_shmem(uvec3 wgsize)
{
// copy the table into shared memory and sync
for (uint i = gl_LocalInvocationIndex.x; i < iq1s_grid_const.length(); i += wgsize.x) {
u16vec2 g = unpack16(iq1s_grid_const[i]);
iq1s_grid[2*i+0] = g.x;
iq1s_grid[2*i+1] = g.y;
[[unroll]] for (uint i = 0; i < iq1s_grid_const.length(); i += wgsize.x) {
uint idx = i + gl_LocalInvocationIndex.x;
if (iq1s_grid_const.length() % wgsize.x == 0 || idx < iq1s_grid_const.length()) {
u16vec2 g = unpack16(iq1s_grid_const[idx]);
iq1s_grid[2*idx+0] = g.x;
iq1s_grid[2*idx+1] = g.y;
}
}
barrier();
}
@ -565,8 +632,10 @@ shared uvec2 iq2xxs_grid[256];
void init_iq_shmem(uvec3 wgsize)
{
// copy the table into shared memory and sync
for (uint i = gl_LocalInvocationIndex.x; i < iq2xxs_grid.length(); i += wgsize.x) {
iq2xxs_grid[i] = iq2xxs_grid_const[i];
[[unroll]] for (uint i = 0; i < iq2xxs_grid.length(); i += wgsize.x) {
if (iq2xxs_grid_const.length() % wgsize.x == 0 || i + gl_LocalInvocationIndex.x < iq2xxs_grid_const.length()) {
iq2xxs_grid[i + gl_LocalInvocationIndex.x] = iq2xxs_grid_const[i + gl_LocalInvocationIndex.x];
}
}
barrier();
}
@ -733,8 +802,10 @@ shared uvec2 iq2xs_grid[512];
void init_iq_shmem(uvec3 wgsize)
{
// copy the table into shared memory and sync
for (uint i = gl_LocalInvocationIndex.x; i < iq2xs_grid.length(); i += wgsize.x) {
iq2xs_grid[i] = iq2xs_grid_const[i];
[[unroll]] for (uint i = 0; i < iq2xs_grid.length(); i += wgsize.x) {
if (iq2xs_grid.length() % wgsize.x == 0 || i + gl_LocalInvocationIndex.x < iq2xs_grid_const.length()) {
iq2xs_grid[i + gl_LocalInvocationIndex.x] = iq2xs_grid_const[i + gl_LocalInvocationIndex.x];
}
}
barrier();
}
@ -756,6 +827,14 @@ struct block_iq2_s
uint8_t scales[QUANT_K_IQ2_S/32];
};
struct block_iq2_s_packed16
{
float16_t d;
uint16_t qs[QUANT_K_IQ2_S/8];
uint16_t qh[QUANT_K_IQ2_S/64];
uint16_t scales[QUANT_K_IQ2_S/64];
};
#if defined(DATA_A_IQ2_S)
const uvec2 iq2s_grid_const[1024] = {
@ -1023,8 +1102,10 @@ shared uvec2 iq2s_grid[1024];
void init_iq_shmem(uvec3 wgsize)
{
// copy the table into shared memory and sync
for (uint i = gl_LocalInvocationIndex.x; i < iq2s_grid.length(); i += wgsize.x) {
iq2s_grid[i] = iq2s_grid_const[i];
[[unroll]] for (uint i = 0; i < iq2s_grid.length(); i += wgsize.x) {
if (iq2s_grid.length() % wgsize.x == 0 || i + gl_LocalInvocationIndex.x < iq2s_grid_const.length()) {
iq2s_grid[i + gl_LocalInvocationIndex.x] = iq2s_grid_const[i + gl_LocalInvocationIndex.x];
}
}
barrier();
}
@ -1032,6 +1113,7 @@ void init_iq_shmem(uvec3 wgsize)
#define QUANT_K QUANT_K_IQ2_S
#define QUANT_R QUANT_R_IQ2_S
#define A_TYPE block_iq2_s
#define A_TYPE_PACKED16 block_iq2_s_packed16
#endif
#define QUANT_K_IQ3_XXS 256
@ -1092,8 +1174,10 @@ shared uint32_t iq3xxs_grid[256];
void init_iq_shmem(uvec3 wgsize)
{
// copy the table into shared memory and sync
for (uint i = gl_LocalInvocationIndex.x; i < iq3xxs_grid.length(); i += wgsize.x) {
iq3xxs_grid[i] = iq3xxs_grid_const[i];
[[unroll]] for (uint i = 0; i < iq3xxs_grid.length(); i += wgsize.x) {
if (iq3xxs_grid.length() % wgsize.x == 0 || i + gl_LocalInvocationIndex.x < iq3xxs_grid.length()) {
iq3xxs_grid[i + gl_LocalInvocationIndex.x] = iq3xxs_grid_const[i + gl_LocalInvocationIndex.x];
}
}
barrier();
}
@ -1200,8 +1284,10 @@ shared uint32_t iq3s_grid[512];
void init_iq_shmem(uvec3 wgsize)
{
// copy the table into shared memory and sync
for (uint i = gl_LocalInvocationIndex.x; i < iq3s_grid.length(); i += wgsize.x) {
iq3s_grid[i] = iq3s_grid_const[i];
[[unroll]] for (uint i = 0; i < iq3s_grid.length(); i += wgsize.x) {
if (iq3s_grid.length() % wgsize.x == 0 || i + gl_LocalInvocationIndex.x < iq3s_grid.length()) {
iq3s_grid[i + gl_LocalInvocationIndex.x] = iq3s_grid_const[i + gl_LocalInvocationIndex.x];
}
}
barrier();
}
@ -1270,4 +1356,18 @@ void init_iq_shmem(uvec3 wgsize)
}
#endif
// returns the bfloat value in the low 16b.
// See ggml_compute_fp32_to_bf16
uint32_t fp32_to_bf16(float f)
{
uint32_t u = floatBitsToUint(f);
u = (u + (0x7fff + ((u >> 16) & 1))) >> 16;
return u;
}
float bf16_to_fp32(uint32_t u)
{
return uintBitsToFloat(u << 16);
}
#endif // !defined(GGML_TYPES_COMP)
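The two bf16 helpers above do a round-to-nearest-even truncation to the upper 16 bits. A CPU sketch of the same bit manipulation, mirroring the GLSL (the NaN special case handled by ggml_compute_fp32_to_bf16 is omitted, as it is here):

#include <cstdint>
#include <cstring>
#include <cstdio>

// CPU sketch of the GLSL helpers above: round-to-nearest-even conversion of
// fp32 to bf16 (kept in the low 16 bits) and the exact widening back to fp32.
static uint32_t fp32_to_bf16(float f) {
    uint32_t u;
    std::memcpy(&u, &f, sizeof(u));
    u = (u + (0x7fff + ((u >> 16) & 1))) >> 16;
    return u;
}

static float bf16_to_fp32(uint32_t u) {
    u <<= 16;
    float f;
    std::memcpy(&f, &u, sizeof(f));
    return f;
}

int main() {
    const float x = 3.14159265f;
    printf("%f -> 0x%04x -> %f\n", x, fp32_to_bf16(x), bf16_to_fp32(fp32_to_bf16(x)));
    return 0;
}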

View File

@ -63,7 +63,8 @@ const std::vector<std::string> type_names = {
"iq3_xxs",
"iq3_s",
"iq4_xs",
"iq4_nl"
"iq4_nl",
"bf16",
};
namespace {
@ -295,7 +296,9 @@ void matmul_shaders(bool fp16, bool matmul_id, bool coopmat, bool coopmat2, bool
std::string aligned_b_type_f32 = coopmat2 ? "float" : fp16 ? "mat2x4" : "vec4";
std::string aligned_b_type_f16 = coopmat2 ? "float16_t" : fp16 ? "f16mat2x4" : "f16vec4";
std::map<std::string, std::string> base_dict = {{"FLOAT_TYPE", (coopmat2 || fp16) ? "float16_t" : "float"}};
std::map<std::string, std::string> base_dict = {
{"FLOAT_TYPE_VEC2", (coopmat2 || fp16) ? "f16vec2" : "vec2"},
};
std::string shader_name = "matmul";
if (matmul_id) {
@ -313,34 +316,81 @@ void matmul_shaders(bool fp16, bool matmul_id, bool coopmat, bool coopmat2, bool
base_dict["COOPMAT"] = "1";
}
base_dict["ACC_TYPE"] = f16acc ? "float16_t" : "float";
const std::string source_name = coopmat2 ? "mul_mm_cm2.comp" : "mul_mm.comp";
std::string source_name = coopmat2 ? "mul_mm_cm2.comp" : "mul_mm.comp";
auto const &FLOAT_TYPE = [&](const std::string &t) -> std::string {
if (t == "bf16") {
// scalar path promotes to float
if (!coopmat && !coopmat2) {
return "float";
}
return "bfloat16_t";
}
if (coopmat2 || fp16) {
return "float16_t";
}
return "float";
};
// Shaders with f16 B_TYPE
string_to_spv(shader_name + "_f32_f16", source_name, merge_maps(base_dict, {{"DATA_A_F32", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}, }), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_f32_f16_aligned", source_name, merge_maps(base_dict, {{"DATA_A_F32", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_f32_f16", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("f16")}, {"DATA_A_F32", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}, }), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_f32_f16_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("f16")}, {"DATA_A_F32", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_f16_aligned", source_name, merge_maps(base_dict, {{"DATA_A_F16", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_f16", source_name, merge_maps(base_dict, {{"DATA_A_F16", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_f16_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("f16")}, {"DATA_A_F16", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_f16", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("f16")}, {"DATA_A_F16", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
// bf16
{
std::string load_vec_a_unaligned = "1";
// For aligned matmul loads
std::string load_vec_a = coopmat2 ? "1" : "4";
// scalar path promotes to float
std::string to_float_type = (coopmat || coopmat2) ? "uintBitsToBFloat16EXT" : "bf16_to_fp32";
// If bfloat16 is not supported, then only compile the scalar (promote to fp32) shader
#if !defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
if (!(coopmat || coopmat2))
#endif
{
string_to_spv(shader_name + "_bf16_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("bf16")}, {"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", "4"}, {"B_TYPE", coopmat2 ? "bfloat16_t" : "u16vec4"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_bf16", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("bf16")}, {"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", coopmat2 ? "bfloat16_t" : "uint16_t"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}}), fp16, coopmat, coopmat2, f16acc);
}
}
for (const auto& tname : type_names) {
std::string load_vec_quant = "2";
if ((tname == "q4_0") || (tname == "q4_1"))
load_vec_quant = "8";
else if ((tname == "q5_0") || (tname == "q5_1") || (tname == "q8_0") || (tname == "iq4_nl"))
load_vec_quant = "4";
if (tname == "bf16") {
continue;
}
std::string data_a_key = "DATA_A_" + to_uppercase(tname);
        // For unaligned, load one value at a time for f32/f16/bf16, or 2/4/8 (type-dependent) for quants
std::string load_vec_a_unaligned = (coopmat2 || tname == "f32" || tname == "f16") ? "1" : "2";
std::string load_vec_a_unaligned = (coopmat2 || tname == "f32" || tname == "f16" || tname == "bf16") ? "1" : load_vec_quant;
// For aligned matmul loads
std::string load_vec_a = (coopmat2 || tname == "f32" || tname == "f16") ? load_vec : "2";
std::string load_vec_a = (coopmat2 || tname == "f32" || tname == "f16" || tname == "bf16") ? load_vec : load_vec_quant;
// don't generate f32 variants for coopmat2
if (!coopmat2) {
string_to_spv(shader_name + "_" + tname + "_f32", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}}), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_" + tname + "_f32_aligned", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_" + tname + "_f32", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE(tname)}, {data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_" + tname + "_f32_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE(tname)}, {data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
}
if (tname != "f16" && tname != "f32") {
string_to_spv(shader_name + "_" + tname + "_f16", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}}), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_" + tname + "_f16_aligned", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_" + tname + "_f16", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE(tname)}, {data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_" + tname + "_f16_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE(tname)}, {data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
}
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
if (!coopmat && !coopmat2 && !matmul_id && (tname == "q4_0" || tname == "q4_1" || tname == "q5_0" || tname == "q5_1" || tname == "q8_0")) {
string_to_spv(shader_name + "_" + tname + "_q8_1", "mul_mmq.comp", merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE(tname)}, {data_a_key, "1"}, {"D_TYPE", "float"},}), fp16, coopmat, coopmat2, f16acc);
}
#endif
}
}
@ -371,7 +421,6 @@ void process_shaders() {
#endif
}
#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
// flash attention
for (const auto& f16acc : {false, true}) {
std::string acctype = f16acc ? "float16_t" : "float";
@ -380,7 +429,9 @@ void process_shaders() {
if (tname == "f32") {
continue;
}
if (tname == "bf16") continue;
#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
if (tname == "f16") {
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp",
merge_maps(base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}}), true, false, true, f16acc);
@ -389,14 +440,22 @@ void process_shaders() {
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp",
merge_maps(base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}, {"DEQUANTFUNC", "dequantFunc"+to_uppercase(tname) }, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, false, true, f16acc);
}
#endif
if (tname == "f16") {
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp",
merge_maps(base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}}), true, false, false, f16acc);
} else if (tname == "q4_0" || tname == "q8_0") {
std::string data_a_key = "DATA_A_" + to_uppercase(tname);
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp",
merge_maps(base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, false, false, f16acc);
}
}
}
#endif
for (const auto& tname : type_names) {
// mul mat vec
std::string data_a_key = "DATA_A_" + to_uppercase(tname);
std::string shader = (string_ends_with(tname, "_k") || string_starts_with(tname, "iq1_")) ? "mul_mat_vec_" + tname + ".comp" : "mul_mat_vec.comp";
std::string shader = (string_ends_with(tname, "_k") || string_starts_with(tname, "iq1_") || string_starts_with(tname, "iq2_") || string_starts_with(tname, "iq3_")) ? "mul_mat_vec_" + tname + ".comp" : "mul_mat_vec.comp";
string_to_spv("mul_mat_vec_" + tname + "_f32_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}}));
string_to_spv("mul_mat_vec_" + tname + "_f16_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float16_t"}, {"B_TYPE_VEC2", "f16vec2"}, {"B_TYPE_VEC4", "f16vec4"}, {"D_TYPE", "float"}}));
@ -404,12 +463,12 @@ void process_shaders() {
string_to_spv("mul_mat_vec_id_" + tname + "_f32", shader, merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}}));
// Dequant shaders
if (tname != "f16") {
if (tname != "f16" && tname != "bf16") {
string_to_spv("dequant_" + tname, "dequant_" + tname + ".comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float16_t"}}));
}
if (!string_ends_with(tname, "_k")) {
shader = (tname == "f32" || tname == "f16") ? "get_rows.comp" : "get_rows_quant.comp";
shader = (tname == "f32" || tname == "f16" || tname == "bf16") ? "get_rows.comp" : "get_rows_quant.comp";
if (tname == "f16") {
string_to_spv("get_rows_" + tname, shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}}));
@ -420,35 +479,62 @@ void process_shaders() {
}
}
string_to_spv("mul_mat_vec_p021_f16_f32", "mul_mat_vec_p021.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("mul_mat_vec_nc_f16_f32", "mul_mat_vec_nc.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("mul_mat_vec_p021_f16_f32_subgroup_add", "mul_mat_vec_p021.comp", {{"A_TYPE", "float16_t"}, {"A_TYPE_VEC4", "f16vec4"}, {"B_TYPE", "float"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}});
string_to_spv("mul_mat_vec_p021_f16_f32", "mul_mat_vec_p021.comp", {{"A_TYPE", "float16_t"}, {"A_TYPE_VEC4", "f16vec4"}, {"B_TYPE", "float"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}});
string_to_spv("mul_mat_vec_nc_f16_f32", "mul_mat_vec_nc.comp", {{"A_TYPE", "float16_t"}, {"A_TYPE_VEC4", "f16vec4"}, {"B_TYPE", "float"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}});
// Norms
string_to_spv("norm_f32", "norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
string_to_spv("group_norm_f32", "group_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
string_to_spv("rms_norm_f32", "rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
string_to_spv("rms_norm_back_f32", "rms_norm_back.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
string_to_spv("l2_norm_f32", "l2_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
string_to_spv("cpy_f32_f32", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("cpy_f32_f16", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}});
string_to_spv("cpy_f16_f16", "copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}});
string_to_spv("cpy_f16_f32", "copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}});
string_to_spv("cpy_f32_bf16","copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "uint16_t"}, {"DATA_D_BF16", "1"}});
string_to_spv("contig_cpy_f32_f32", "contig_copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("contig_cpy_f32_f16", "contig_copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}});
string_to_spv("contig_cpy_f16_f16", "contig_copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}});
string_to_spv("contig_cpy_f16_f32", "contig_copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}});
string_to_spv("contig_cpy_f32_bf16","contig_copy.comp",{{"A_TYPE", "float"}, {"D_TYPE", "uint16_t"}, {"DATA_D_BF16", "1"}});
for (std::string t : {"q4_0", "q4_1", "q5_0", "q5_1", "q8_0", "iq4_nl"}) {
string_to_spv("cpy_f32_" + t, "copy_to_quant.comp", {{"DATA_A_" + to_uppercase(t), "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
string_to_spv("cpy_f32_" + t + "_rte", "copy_to_quant.comp", {{"DATA_A_" + to_uppercase(t), "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}});
string_to_spv("cpy_" + t + "_f32", "copy_from_quant.comp", {{"DATA_A_" + to_uppercase(t), "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
}
string_to_spv("add_f32", "add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
string_to_spv("add_f16_f32_f16", "add.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"FLOAT_TYPE", "float"}});
auto get_type_str = [](bool f16) {
return f16 ? "float16_t" : "float";
};
auto get_suffix = [](bool src0_f16, bool src1_f16, bool dst_f16) {
std::string s;
s += std::string(src0_f16 ? "_f16" : "_f32");
s += std::string(src1_f16 ? "_f16" : "_f32");
s += std::string(dst_f16 ? "_f16" : "_f32");
return s;
};
for (std::string op : {"add", "sub", "mul", "div"}) {
for (auto src0_f16 : {false, true}) {
for (auto src1_f16 : {false, true}) {
for (auto dst_f16 : {false, true}) {
auto name = op + get_suffix(src0_f16, src1_f16, dst_f16);
string_to_spv(name.c_str(), op + ".comp", {{"A_TYPE", get_type_str(src0_f16)}, {"B_TYPE", get_type_str(src1_f16)}, {"D_TYPE", get_type_str(dst_f16)}, {"FLOAT_TYPE", "float"}});
}
}
}
}
string_to_spv("sub_f32", "sub.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
string_to_spv("acc_f32", "acc.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
string_to_spv("split_k_reduce", "mul_mat_split_k_reduce.comp", {});
string_to_spv("fa_split_k_reduce", "flash_attn_split_k_reduce.comp", {});
string_to_spv("quantize_q8_1", "quantize_q8_1.comp", {});
string_to_spv("mul_f32", "mul.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
@ -475,14 +561,21 @@ void process_shaders() {
string_to_spv("upscale_f32", "upscale.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("gelu_f32", "gelu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("gelu_quick_f32", "gelu_quick.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("silu_f32", "silu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("silu_back_f32", "silu_back.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("relu_f32", "relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("leaky_relu_f32", "leaky_relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("tanh_f32", "tanh.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("sigmoid_f32", "sigmoid.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("gelu_f16", "gelu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
string_to_spv("gelu_f32", "gelu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("gelu_quick_f16", "gelu_quick.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
string_to_spv("gelu_quick_f32", "gelu_quick.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("silu_f16", "silu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
string_to_spv("silu_f32", "silu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("relu_f16", "relu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
string_to_spv("relu_f32", "relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("tanh_f16", "tanh.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
string_to_spv("tanh_f32", "tanh.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("sigmoid_f16", "sigmoid.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
string_to_spv("sigmoid_f32", "sigmoid.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("leaky_relu_f32", "leaky_relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("silu_back_f32", "silu_back.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("diag_mask_inf_f32", "diag_mask_inf.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
@ -522,8 +615,13 @@ void process_shaders() {
string_to_spv("rwkv_wkv6_f32", "wkv6.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
string_to_spv("rwkv_wkv7_f32", "wkv7.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
string_to_spv("opt_step_adamw_f32", "opt_step_adamw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
string_to_spv("conv2d_dw_whcn_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"WHCN", "1"}}));
string_to_spv("conv2d_dw_cwhn_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"CWHN", "1"}}));
for (auto &c : compiles) {
c.wait();
}
@ -578,7 +676,12 @@ void write_output_files() {
std::remove(path.c_str());
}
}
for (const char *op : {"add", "sub", "mul", "div"}) {
fprintf(hdr, "extern unsigned char *%s_data[2][2][2];\n", op);
fprintf(hdr, "extern uint64_t %s_len[2][2][2];\n", op);
fprintf(src, "unsigned char *%s_data[2][2][2] = {{{%s_f32_f32_f32_data, %s_f32_f32_f16_data}, {%s_f32_f16_f32_data, %s_f32_f16_f16_data}}, {{%s_f16_f32_f32_data, %s_f16_f32_f16_data}, {%s_f16_f16_f32_data, %s_f16_f16_f16_data}}};\n", op, op, op, op, op, op, op, op, op);
fprintf(src, "uint64_t %s_len[2][2][2] = {{{%s_f32_f32_f32_len, %s_f32_f32_f16_len}, {%s_f32_f16_f32_len, %s_f32_f16_f16_len}}, {{%s_f16_f32_f32_len, %s_f16_f32_f16_len}, {%s_f16_f16_f32_len, %s_f16_f16_f16_len}}};\n", op, op, op, op, op, op, op, op, op);
}
fclose(hdr);
fclose(src);
}
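The 2x2x2 arrays emitted above let the backend pick a binary-op pipeline by element type. A sketch of the assumed lookup, with index 0 meaning f32 and 1 meaning f16 for src0, src1 and dst respectively (the extern declarations mirror the generated header; the helper name is hypothetical):

#include <cstdint>

// Hypothetical lookup of a generated binary-op shader blob, mirroring the
// layout written by write_output_files(): [src0_is_f16][src1_is_f16][dst_is_f16].
extern unsigned char *add_data[2][2][2];
extern uint64_t       add_len [2][2][2];

struct ShaderBlob {
    const unsigned char *data;
    uint64_t             len;
};

// Returns e.g. add_f16_f32_f32 when src0 is f16 and src1/dst are f32.
ShaderBlob pick_add_shader(bool src0_f16, bool src1_f16, bool dst_f16) {
    return { add_data[src0_f16][src1_f16][dst_f16],
             add_len [src0_f16][src1_f16][dst_f16] };
}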

View File

@ -0,0 +1,91 @@
#version 450
#extension GL_EXT_control_flow_attributes : require
#define BLOCK_SIZE 64
layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
layout(push_constant) uniform Parameters {
uint B;
uint T;
uint C;
uint H;
};
layout(binding = 0) readonly buffer RBuf { A_TYPE r[]; };
layout(binding = 1) readonly buffer WBuf { A_TYPE w[]; };
layout(binding = 2) readonly buffer KBuf { A_TYPE k[]; };
layout(binding = 3) readonly buffer VBuf { A_TYPE v[]; };
layout(binding = 4) readonly buffer ABuf { A_TYPE a[]; };
layout(binding = 5) readonly buffer BBuf { A_TYPE b[]; };
layout(binding = 6) readonly buffer StateBuf { A_TYPE state_in[]; };
layout(binding = 7) buffer DstBuf { A_TYPE dst[]; };
shared A_TYPE _r[BLOCK_SIZE], _w[BLOCK_SIZE], _k[BLOCK_SIZE], _a[BLOCK_SIZE], _b[BLOCK_SIZE];
void main() {
const uint head_size = BLOCK_SIZE;
const uint batch_id = gl_WorkGroupID.x / H;
const uint head_id = gl_WorkGroupID.x % H;
const uint tid = gl_LocalInvocationID.x;
const uint state_size = C * head_size;
const uint n_seq_tokens = T / B;
if (batch_id >= B || head_id >= H) {
return;
}
A_TYPE state[BLOCK_SIZE];
[[unroll]] for (uint i = 0; i < head_size; i++) {
state[i] = state_in[batch_id * state_size + head_id * head_size * head_size
+ tid * head_size + i];
}
const uint start_t = batch_id * n_seq_tokens * C + head_id * head_size + tid;
const uint end_t = (batch_id + 1) * n_seq_tokens * C + head_id * head_size + tid;
for (uint t = start_t; t < end_t; t += C) {
barrier();
_r[tid] = r[t];
_w[tid] = w[t];
_k[tid] = k[t];
_a[tid] = a[t];
_b[tid] = b[t];
barrier();
A_TYPE sa = 0.0;
[[unroll]] for (uint j = 0; j < head_size; j += 4) {
vec4 s_vec = vec4(state[j], state[j+1], state[j+2], state[j+3]);
vec4 a_vec = vec4(_a[j], _a[j+1], _a[j+2], _a[j+3]);
sa += dot(s_vec, a_vec);
}
const A_TYPE v_val = v[t];
A_TYPE y = 0.0;
[[unroll]] for (uint j = 0; j < head_size; j += 4) {
vec4 r_vec = vec4(_r[j], _r[j+1], _r[j+2], _r[j+3]);
vec4 w_vec = vec4(_w[j], _w[j+1], _w[j+2], _w[j+3]);
vec4 k_vec = vec4(_k[j], _k[j+1], _k[j+2], _k[j+3]);
vec4 b_vec = vec4(_b[j], _b[j+1], _b[j+2], _b[j+3]);
vec4 s_vec = vec4(state[j], state[j+1], state[j+2], state[j+3]);
vec4 kv = k_vec * v_val;
s_vec = s_vec * w_vec + kv + sa * b_vec;
y += dot(r_vec, s_vec);
state[j] = s_vec.x;
state[j+1] = s_vec.y;
state[j+2] = s_vec.z;
state[j+3] = s_vec.w;
}
dst[t] = y;
}
[[unroll]] for (uint i = 0; i < head_size; i++) {
dst[T * C + batch_id * state_size + head_id * head_size * head_size
+ tid * head_size + i] = state[i];
}
}
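For orientation, a scalar CPU sketch of the recurrence each thread of wkv7.comp performs for its state row (variable names are illustrative): decay the row by w, add the k*v contribution and the sa*b correction, then accumulate the output as a dot product with r:

#include <vector>

// Scalar sketch of the WKV7 recurrence done per (head, row) in wkv7.comp.
// For one token: sa = dot(state, a); state[j] = state[j]*w[j] + k[j]*v + sa*b[j]; y = dot(r, state).
// head_size matches BLOCK_SIZE in the shader; all names here are illustrative.
float wkv7_step(std::vector<float> &state,             // one row of the per-head state
                const std::vector<float> &r,
                const std::vector<float> &w,
                const std::vector<float> &k,
                const std::vector<float> &a,
                const std::vector<float> &b,
                float v) {                              // v value for this row
    const size_t head_size = state.size();

    float sa = 0.0f;
    for (size_t j = 0; j < head_size; ++j) sa += state[j] * a[j];

    float y = 0.0f;
    for (size_t j = 0; j < head_size; ++j) {
        state[j] = state[j] * w[j] + k[j] * v + sa * b[j];
        y += r[j] * state[j];
    }
    return y;
}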