From 2e7452be718999c738adadafc9d8e05ff1bd7576 Mon Sep 17 00:00:00 2001
From: Inforithmics
Date: Sun, 10 Aug 2025 16:01:07 +0200
Subject: [PATCH] Update Vulkan Code to de4c07f93783a1a96456a44dc16b9db538ee1618

---
 .../ggml/ggml/src/ggml-vulkan/CMakeLists.txt  |   40 +
 .../ggml-vulkan/cmake/host-toolchain.cmake.in |   15 +
 .../ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp | 2007 +++++++++++++----
 .../ggml-vulkan/vulkan-shaders/CMakeLists.txt |   21 +-
 .../vulkan-shaders/contig_copy.comp           |   11 +-
 .../ggml-vulkan/vulkan-shaders/conv2d_dw.comp |  105 +
 .../src/ggml-vulkan/vulkan-shaders/copy.comp  |    5 +-
 .../vulkan-shaders/copy_to_quant.comp         |    5 +
 .../vulkan-shaders/dequant_funcs.comp         |   14 +-
 .../vulkan-shaders/dequant_funcs_cm2.comp     |  148 +-
 .../vulkan-shaders/flash_attn.comp            |  483 ++++
 .../vulkan-shaders/flash_attn_cm2.comp        |  126 +-
 .../flash_attn_split_k_reduce.comp            |   59 +
 .../ggml-vulkan/vulkan-shaders/get_rows.comp  |   11 +-
 .../vulkan-shaders/get_rows_quant.comp        |    2 +
 .../ggml-vulkan/vulkan-shaders/im2col.comp    |   53 +-
 .../ggml-vulkan/vulkan-shaders/l2_norm.comp   |   41 +
 .../vulkan-shaders/mul_mat_vec.comp           |   22 +-
 .../vulkan-shaders/mul_mat_vec_iq2_s.comp     |   90 +
 .../vulkan-shaders/mul_mat_vec_iq2_xs.comp    |   87 +
 .../vulkan-shaders/mul_mat_vec_iq2_xxs.comp   |   87 +
 .../vulkan-shaders/mul_mat_vec_iq3_s.comp     |   90 +
 .../vulkan-shaders/mul_mat_vec_iq3_xxs.comp   |   88 +
 .../vulkan-shaders/mul_mat_vec_nc.comp        |   75 +-
 .../vulkan-shaders/mul_mat_vec_p021.comp      |  137 +-
 .../vulkan-shaders/mul_mat_vec_q2_k.comp      |   47 +-
 .../vulkan-shaders/mul_mat_vec_q3_k.comp      |   26 +-
 .../vulkan-shaders/mul_mat_vec_q6_k.comp      |   12 +-
 .../ggml-vulkan/vulkan-shaders/mul_mm.comp    |  187 +-
 .../vulkan-shaders/mul_mm_cm2.comp            |  262 ++-
 .../ggml-vulkan/vulkan-shaders/mul_mmq.comp   |  442 ++++
 .../vulkan-shaders/mul_mmq_funcs.comp         |   99 +
 .../vulkan-shaders/quantize_q8_1.comp         |   77 +
 .../src/ggml-vulkan/vulkan-shaders/relu.comp  |    2 +-
 .../ggml-vulkan/vulkan-shaders/rms_norm.comp  |   32 +-
 .../ggml-vulkan/vulkan-shaders/sigmoid.comp   |    2 +-
 .../src/ggml-vulkan/vulkan-shaders/tanh.comp  |    2 +-
 .../vulkan-shaders/test_bfloat16_support.comp |    7 +
 .../test_integer_dot_support.comp             |    7 +
 .../src/ggml-vulkan/vulkan-shaders/types.comp |  132 +-
 .../vulkan-shaders/vulkan-shaders-gen.cpp     |  167 +-
 .../src/ggml-vulkan/vulkan-shaders/wkv7.comp  |   91 +
 42 files changed, 4607 insertions(+), 809 deletions(-)
 create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/cmake/host-toolchain.cmake.in
 create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp
 create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp
 create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp
 create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp
 create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp
 create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp
 create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp
 create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp
 create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp
 create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp
 create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp
 create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp
 create
mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/test_bfloat16_support.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/test_integer_dot_support.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/wkv7.comp diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/CMakeLists.txt b/ml/backend/ggml/ggml/src/ggml-vulkan/CMakeLists.txt index d970f7e20..31816219c 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/CMakeLists.txt +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/CMakeLists.txt @@ -32,8 +32,10 @@ if (Vulkan_FOUND) if (${glslc_error} MATCHES ".*extension not supported: GL_KHR_cooperative_matrix.*") message(STATUS "GL_KHR_cooperative_matrix not supported by glslc") + set(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT OFF) else() message(STATUS "GL_KHR_cooperative_matrix supported by glslc") + set(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT ON) add_compile_definitions(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) endif() @@ -46,11 +48,45 @@ if (Vulkan_FOUND) if (${glslc_error} MATCHES ".*extension not supported: GL_NV_cooperative_matrix2.*") message(STATUS "GL_NV_cooperative_matrix2 not supported by glslc") + set(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT OFF) else() message(STATUS "GL_NV_cooperative_matrix2 supported by glslc") + set(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT ON) add_compile_definitions(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) endif() + # Compile a test shader to determine whether GL_EXT_integer_dot_product is supported. + # If it's not, there will be an error to stderr. + # If it's supported, set a define to indicate that we should compile those shaders + execute_process(COMMAND ${Vulkan_GLSLC_EXECUTABLE} -o - -fshader-stage=compute --target-env=vulkan1.3 "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/test_integer_dot_support.comp" + OUTPUT_VARIABLE glslc_output + ERROR_VARIABLE glslc_error) + + if (${glslc_error} MATCHES ".*extension not supported: GL_EXT_integer_dot_product.*") + message(STATUS "GL_EXT_integer_dot_product not supported by glslc") + set(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT OFF) + else() + message(STATUS "GL_EXT_integer_dot_product supported by glslc") + set(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT ON) + add_compile_definitions(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) + endif() + + # Compile a test shader to determine whether GL_EXT_bfloat16 is supported. + # If it's not, there will be an error to stderr. 
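
Editor's aside: the glslc probes in this CMakeLists.txt hunk all follow one pattern: compile a tiny test shader that requires the extension, and treat an "extension not supported" message on stderr as "unavailable". The result is recorded twice, once as a CMake variable so it can be forwarded to the cross-compiled vulkan-shaders-gen ExternalProject further down, and once as a compile definition so ggml-vulkan.cpp only references shader variants that were actually generated. Below is a minimal stand-alone C++ sketch of the same probe; only the glslc flags and the error marker come from the patch, everything else (POSIX popen(), glslc on PATH, the main() driver) is an illustrative assumption.

    // Stand-alone sketch of the glslc extension probe (illustrative, not from the patch).
    #include <array>
    #include <cstdio>
    #include <string>

    static bool glslc_supports(const std::string & test_shader) {
        // Same flags as the CMake probe; 2>&1 so the "extension not supported"
        // message glslc prints to stderr flows through the pipe.
        const std::string cmd = "glslc -o - -fshader-stage=compute --target-env=vulkan1.3 \"" +
                                test_shader + "\" 2>&1";
        FILE * pipe = popen(cmd.c_str(), "r");   // assumes a POSIX popen()
        if (pipe == nullptr) {
            return false;
        }
        std::string output;
        std::array<char, 256> buf;
        while (fgets(buf.data(), buf.size(), pipe) != nullptr) {
            output += buf.data();
        }
        pclose(pipe);
        return output.find("extension not supported") == std::string::npos;
    }

    int main() {
        std::printf("GL_EXT_integer_dot_product: %s\n",
                    glslc_supports("vulkan-shaders/test_integer_dot_support.comp") ? "yes" : "no");
        std::printf("GL_EXT_bfloat16:            %s\n",
                    glslc_supports("vulkan-shaders/test_bfloat16_support.comp") ? "yes" : "no");
    }
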
+ # If it's supported, set a define to indicate that we should compile those shaders + execute_process(COMMAND ${Vulkan_GLSLC_EXECUTABLE} -o - -fshader-stage=compute --target-env=vulkan1.3 "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/test_bfloat16_support.comp" + OUTPUT_VARIABLE glslc_output + ERROR_VARIABLE glslc_error) + + if (${glslc_error} MATCHES ".*extension not supported: GL_EXT_bfloat16.*") + message(STATUS "GL_EXT_bfloat16 not supported by glslc") + set(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT OFF) + else() + message(STATUS "GL_EXT_bfloat16 supported by glslc") + set(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT ON) + add_compile_definitions(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT) + endif() + target_link_libraries(ggml-vulkan PRIVATE Vulkan::Vulkan) target_include_directories(ggml-vulkan PRIVATE ${CMAKE_CURRENT_BINARY_DIR}) @@ -119,6 +155,10 @@ if (Vulkan_FOUND) SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders CMAKE_ARGS -DCMAKE_TOOLCHAIN_FILE=${HOST_CMAKE_TOOLCHAIN_FILE} -DCMAKE_INSTALL_PREFIX=${CMAKE_BINARY_DIR} + -DGGML_VULKAN_COOPMAT_GLSLC_SUPPORT=${GGML_VULKAN_COOPMAT_GLSLC_SUPPORT} + -DGGML_VULKAN_COOPMAT2_GLSLC_SUPPORT=${GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT} + -DGGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT=${GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT} + -DGGML_VULKAN_BFLOAT16_GLSLC_SUPPORT=${GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT} BUILD_COMMAND ${CMAKE_COMMAND} --build . INSTALL_COMMAND ${CMAKE_COMMAND} --install . INSTALL_DIR ${CMAKE_BINARY_DIR} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/cmake/host-toolchain.cmake.in b/ml/backend/ggml/ggml/src/ggml-vulkan/cmake/host-toolchain.cmake.in new file mode 100644 index 000000000..2d8a85696 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/cmake/host-toolchain.cmake.in @@ -0,0 +1,15 @@ +set(CMAKE_BUILD_TYPE Release) +set(CMAKE_C_FLAGS -O2) +set(CMAKE_CXX_FLAGS -O2) +set(CMAKE_FIND_ROOT_PATH_MODE_PROGRAM NEVER) +set(CMAKE_FIND_ROOT_PATH_MODE_LIBRARY NEVER) +set(CMAKE_FIND_ROOT_PATH_MODE_INCLUDE NEVER) +set(CMAKE_C_COMPILER "@HOST_C_COMPILER@") +set(CMAKE_CXX_COMPILER "@HOST_CXX_COMPILER@") +set(CMAKE_RUNTIME_OUTPUT_DIRECTORY @CMAKE_RUNTIME_OUTPUT_DIRECTORY@) + +if("@CMAKE_C_COMPILER_ID@" STREQUAL "MSVC") + foreach(CONFIG IN ITEMS DEBUG RELEASE MINSIZEREL RELWITHDEBINFO) + set(CMAKE_RUNTIME_OUTPUT_DIRECTORY_${CONFIG} ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}) + endforeach() +endif() diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ml/backend/ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp index abe3e7908..e2b357fdc 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -24,12 +24,54 @@ #include #include +#if defined(_MSC_VER) +# define NOMINMAX 1 +# include +# define YIELD() YieldProcessor() +#elif defined(__clang__) || defined(__GNUC__) +# if defined(__x86_64__) ||defined(__i386__) +# include +# define YIELD() _mm_pause() +# elif defined(__arm__) || defined(__aarch64__) +# if defined(__clang__) +# include +# define YIELD() __yield() +# else +# define YIELD() asm volatile("yield") +# endif +# endif +#endif + +#if !defined(YIELD) +#define YIELD() +#endif + #include "ggml-impl.h" #include "ggml-backend-impl.h" #include "ggml-vulkan-shaders.hpp" +// remove this once it's more widely available in the SDK +#if !defined(VK_KHR_shader_bfloat16) + +#define VK_KHR_shader_bfloat16 1 +#define VK_KHR_SHADER_BFLOAT16_SPEC_VERSION 1 +#define VK_KHR_SHADER_BFLOAT16_EXTENSION_NAME "VK_KHR_shader_bfloat16" +#define VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_BFLOAT16_FEATURES_KHR 
((VkStructureType)1000141000) +#define VK_COMPONENT_TYPE_BFLOAT16_KHR ((VkComponentTypeKHR)1000141000) + +typedef struct VkPhysicalDeviceShaderBfloat16FeaturesKHR { + VkStructureType sType; + void* pNext; + VkBool32 shaderBFloat16Type; + VkBool32 shaderBFloat16DotProduct; + VkBool32 shaderBFloat16CooperativeMatrix; +} VkPhysicalDeviceShaderBfloat16FeaturesKHR; +#endif + +#define ROUNDUP_POW2(M, N) (((M) + (N) - 1) & ~((N) - 1)) #define CEIL_DIV(M, N) (((M) + (N)-1) / (N)) +static bool is_pow2(uint32_t x) { return x > 1 && (x & (x-1)) == 0; } #define VK_VENDOR_ID_AMD 0x1002 #define VK_VENDOR_ID_APPLE 0x106b @@ -148,6 +190,67 @@ class vk_perf_logger; static void ggml_vk_destroy_buffer(vk_buffer& buf); static constexpr uint32_t mul_mat_vec_max_cols = 8; +static constexpr uint32_t p021_max_gqa_ratio = 8; + +enum vk_device_architecture { + OTHER, + AMD_GCN, + AMD_RDNA1, + AMD_RDNA2, + AMD_RDNA3, +}; + +static vk_device_architecture get_device_architecture(const vk::PhysicalDevice& device) { + vk::PhysicalDeviceProperties props = device.getProperties(); + + if (props.vendorID == VK_VENDOR_ID_AMD) { + const std::vector ext_props = device.enumerateDeviceExtensionProperties(); + + bool amd_shader_core_properties = false; + bool integer_dot_product = false; + bool subgroup_size_control = false; + + for (const auto& properties : ext_props) { + if (strcmp("VK_AMD_shader_core_properties", properties.extensionName) == 0) { + amd_shader_core_properties = true; + } else if (strcmp("VK_KHR_shader_integer_dot_product", properties.extensionName) == 0) { + integer_dot_product = true; + } else if (strcmp("VK_EXT_subgroup_size_control", properties.extensionName) == 0) { + subgroup_size_control = true; + } + } + + if (!amd_shader_core_properties || !integer_dot_product || !subgroup_size_control) { + return vk_device_architecture::OTHER; + } + + vk::PhysicalDeviceProperties2 props2; + vk::PhysicalDeviceShaderCorePropertiesAMD shader_core_props_amd; + vk::PhysicalDeviceShaderIntegerDotProductPropertiesKHR integer_dot_props; + vk::PhysicalDeviceSubgroupSizeControlPropertiesEXT subgroup_size_control_props; + + props2.pNext = &shader_core_props_amd; + shader_core_props_amd.pNext = &integer_dot_props; + integer_dot_props.pNext = &subgroup_size_control_props; + + device.getProperties2(&props2); + + if (subgroup_size_control_props.maxSubgroupSize == 64 && subgroup_size_control_props.minSubgroupSize == 64) { + return vk_device_architecture::AMD_GCN; + } + if (subgroup_size_control_props.maxSubgroupSize == 64 && subgroup_size_control_props.minSubgroupSize == 32) { + // RDNA + if (shader_core_props_amd.wavefrontsPerSimd == 20) { + return vk_device_architecture::AMD_RDNA1; + } + if (integer_dot_props.integerDotProduct4x8BitPackedMixedSignednessAccelerated) { + return vk_device_architecture::AMD_RDNA3; + } + return vk_device_architecture::AMD_RDNA2; + } + } + return vk_device_architecture::OTHER; +} struct vk_device_struct { std::mutex mutex; @@ -161,6 +264,8 @@ struct vk_device_struct { bool pipeline_robustness; vk::Device device; uint32_t vendor_id; + vk::DriverId driver_id; + vk_device_architecture architecture; vk_queue compute_queue; vk_queue transfer_queue; bool single_queue; @@ -169,6 +274,10 @@ struct vk_device_struct { bool uma; bool prefer_host_memory; bool float_controls_rte_fp16; + bool subgroup_add; + bool subgroup_shuffle; + + bool integer_dot_product; bool subgroup_size_control; uint32_t subgroup_min_size; @@ -176,11 +285,18 @@ struct vk_device_struct { bool subgroup_require_full_support; bool coopmat_support; 
- bool coopmat_acc_f32_support; - bool coopmat_acc_f16_support; + bool coopmat_acc_f32_support {}; + bool coopmat_acc_f16_support {}; + bool coopmat_bf16_support {}; uint32_t coopmat_m; uint32_t coopmat_n; uint32_t coopmat_k; + + bool coopmat_int_support; + uint32_t coopmat_int_m; + uint32_t coopmat_int_n; + uint32_t coopmat_int_k; + bool coopmat2; size_t idx; @@ -197,34 +313,45 @@ struct vk_device_struct { vk_matmul_pipeline pipeline_matmul_f32 {}; vk_matmul_pipeline pipeline_matmul_f32_f16 {}; + vk_matmul_pipeline pipeline_matmul_bf16 {}; vk_matmul_pipeline2 pipeline_matmul_f16; vk_matmul_pipeline2 pipeline_matmul_f16_f32; - vk_pipeline pipeline_matmul_split_k_reduce; - vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_COUNT]; vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat[GGML_TYPE_COUNT]; + vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_COUNT]; + vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_COUNT]; vk_matmul_pipeline pipeline_matmul_id_f32 {}; + vk_matmul_pipeline pipeline_matmul_id_bf16 {}; vk_matmul_pipeline2 pipeline_matmul_id_f16; vk_matmul_pipeline2 pipeline_matmul_id_f16_f32; vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_id[GGML_TYPE_COUNT]; + vk_pipeline pipeline_matmul_split_k_reduce; + vk_pipeline pipeline_quantize_q8_1; + vk_pipeline pipeline_dequant[GGML_TYPE_COUNT]; vk_pipeline pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_COUNT][mul_mat_vec_max_cols]; vk_pipeline pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_COUNT][mul_mat_vec_max_cols]; vk_pipeline pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_COUNT]; - vk_pipeline pipeline_mul_mat_vec_p021_f16_f32; + vk_pipeline pipeline_mul_mat_vec_p021_f16_f32[p021_max_gqa_ratio]; vk_pipeline pipeline_mul_mat_vec_nc_f16_f32; vk_pipeline pipeline_get_rows[GGML_TYPE_COUNT]; vk_pipeline pipeline_get_rows_f32[GGML_TYPE_COUNT]; vk_pipeline pipeline_acc_f32; - vk_pipeline pipeline_add_f32, pipeline_add_f32_norepeat; - vk_pipeline pipeline_add_f16_f32_f16, pipeline_add_f16_f32_f16_norepeat; - vk_pipeline pipeline_sub_f32, pipeline_sub_f32_norepeat; - vk_pipeline pipeline_mul_f32, pipeline_mul_f32_norepeat; - vk_pipeline pipeline_div_f32, pipeline_div_f32_norepeat; + + // [src0 0=fp32,1=fp16][src1 0=fp32,1=fp16][dst 0=fp32,1=fp16] + vk_pipeline pipeline_add[2][2][2]; + vk_pipeline pipeline_add_norepeat[2][2][2]; + vk_pipeline pipeline_sub[2][2][2]; + vk_pipeline pipeline_sub_norepeat[2][2][2]; + vk_pipeline pipeline_mul[2][2][2]; + vk_pipeline pipeline_mul_norepeat[2][2][2]; + vk_pipeline pipeline_div[2][2][2]; + vk_pipeline pipeline_div_norepeat[2][2][2]; + vk_pipeline pipeline_concat_f32, pipeline_concat_f16, pipeline_concat_i32; vk_pipeline pipeline_upscale_f32; vk_pipeline pipeline_scale_f32; @@ -234,22 +361,26 @@ struct vk_device_struct { vk_pipeline pipeline_clamp_f32; vk_pipeline pipeline_pad_f32; vk_pipeline pipeline_repeat_f32, pipeline_repeat_back_f32; - vk_pipeline pipeline_cpy_f32_f32, pipeline_cpy_f32_f16, pipeline_cpy_f16_f16; - vk_pipeline pipeline_contig_cpy_f32_f32, pipeline_contig_cpy_f32_f16, pipeline_contig_cpy_f16_f16; + vk_pipeline pipeline_cpy_f32_f32, pipeline_cpy_f32_f16, pipeline_cpy_f16_f16, pipeline_cpy_f16_f32, pipeline_cpy_f32_bf16; + vk_pipeline pipeline_contig_cpy_f32_f32, pipeline_contig_cpy_f32_f16, pipeline_contig_cpy_f16_f16, pipeline_contig_cpy_f16_f32, pipeline_contig_cpy_f32_bf16; vk_pipeline pipeline_cpy_f32_quant[GGML_TYPE_COUNT]; vk_pipeline pipeline_cpy_quant_f32[GGML_TYPE_COUNT]; vk_pipeline pipeline_norm_f32; vk_pipeline pipeline_group_norm_f32; 
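
Editor's aside: the ROUNDUP_POW2, CEIL_DIV and is_pow2 helpers added near the top of ggml-vulkan.cpp above are used throughout the new code to pad sizes and to size dispatches. A small self-contained check of what they compute; the definitions are copied verbatim from the patch, the numbers are only examples.

    // Self-check of the size helpers introduced by the patch (illustrative values).
    #include <cassert>
    #include <cstdint>

    #define ROUNDUP_POW2(M, N) (((M) + (N) - 1) & ~((N) - 1))   // N must be a power of two
    #define CEIL_DIV(M, N) (((M) + (N)-1) / (N))
    static bool is_pow2(uint32_t x) { return x > 1 && (x & (x-1)) == 0; }

    int main() {
        assert(ROUNDUP_POW2(70u, 32u) == 96u);   // 70 padded up to a multiple of 32
        assert(ROUNDUP_POW2(64u, 32u) == 64u);   // already aligned: unchanged
        assert(CEIL_DIV(70u, 32u) == 3u);        // 70 elements need 3 groups of 32
        assert(!is_pow2(1u) && is_pow2(2u) && !is_pow2(96u));   // 1 deliberately reports false
        return 0;
    }
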
vk_pipeline pipeline_rms_norm_f32; vk_pipeline pipeline_rms_norm_back_f32; - vk_pipeline pipeline_gelu_f32; - vk_pipeline pipeline_gelu_quick_f32; - vk_pipeline pipeline_silu_f32; - vk_pipeline pipeline_silu_back_f32; - vk_pipeline pipeline_relu_f32; + vk_pipeline pipeline_l2_norm_f32; + + // [src/dst 0=fp32,1=fp16] + vk_pipeline pipeline_gelu[2]; + vk_pipeline pipeline_gelu_quick[2]; + vk_pipeline pipeline_silu[2]; + vk_pipeline pipeline_relu[2]; + vk_pipeline pipeline_tanh[2]; + vk_pipeline pipeline_sigmoid[2]; + vk_pipeline pipeline_leaky_relu_f32; - vk_pipeline pipeline_tanh_f32; - vk_pipeline pipeline_sigmoid_f32; + vk_pipeline pipeline_silu_back_f32; vk_pipeline pipeline_diag_mask_inf_f32; vk_pipeline pipeline_soft_max_f32, pipeline_soft_max_f32_f16; vk_pipeline pipeline_soft_max_f32_wg512, pipeline_soft_max_f32_f16_wg512; @@ -266,9 +397,19 @@ struct vk_device_struct { vk_pipeline pipeline_timestep_embedding_f32; vk_pipeline pipeline_pool2d_f32; vk_pipeline pipeline_rwkv_wkv6_f32; + vk_pipeline pipeline_rwkv_wkv7_f32; vk_pipeline pipeline_opt_step_adamw_f32; + vk_pipeline pipeline_conv2d_dw_whcn_f32; + vk_pipeline pipeline_conv2d_dw_cwhn_f32; // [2][2][2] is for {f16acc,f32acc}x{large,small_rows}x{unaligned, aligned} + vk_pipeline pipeline_flash_attn_f32_f16_D64_cm2[GGML_TYPE_COUNT][2][2][2]; + vk_pipeline pipeline_flash_attn_f32_f16_D80_cm2[GGML_TYPE_COUNT][2][2][2]; + vk_pipeline pipeline_flash_attn_f32_f16_D96_cm2[GGML_TYPE_COUNT][2][2][2]; + vk_pipeline pipeline_flash_attn_f32_f16_D112_cm2[GGML_TYPE_COUNT][2][2][2]; + vk_pipeline pipeline_flash_attn_f32_f16_D128_cm2[GGML_TYPE_COUNT][2][2][2]; + vk_pipeline pipeline_flash_attn_f32_f16_D256_cm2[GGML_TYPE_COUNT][2][2][2]; + vk_pipeline pipeline_flash_attn_f32_f16_D64[GGML_TYPE_COUNT][2][2][2]; vk_pipeline pipeline_flash_attn_f32_f16_D80[GGML_TYPE_COUNT][2][2][2]; vk_pipeline pipeline_flash_attn_f32_f16_D96[GGML_TYPE_COUNT][2][2][2]; @@ -276,6 +417,8 @@ struct vk_device_struct { vk_pipeline pipeline_flash_attn_f32_f16_D128[GGML_TYPE_COUNT][2][2][2]; vk_pipeline pipeline_flash_attn_f32_f16_D256[GGML_TYPE_COUNT][2][2][2]; + vk_pipeline pipeline_flash_attn_split_k_reduce; + std::unordered_map pipelines; std::unordered_map pipeline_descriptor_set_requirements; @@ -368,6 +511,7 @@ struct vk_mat_mat_push_constants { uint32_t batch_stride_a; uint32_t batch_stride_b; uint32_t batch_stride_d; uint32_t k_split; uint32_t ne02; uint32_t ne12; uint32_t broadcast2; uint32_t broadcast3; + uint32_t padded_N; }; struct vk_mat_vec_push_constants { uint32_t ncols; uint32_t stride_a; uint32_t stride_b; uint32_t stride_d; @@ -380,6 +524,7 @@ struct vk_mat_mat_id_push_constants { uint32_t stride_a; uint32_t stride_b; uint32_t stride_d; uint32_t batch_stride_a; uint32_t batch_stride_b; uint32_t batch_stride_d; uint32_t nei0; uint32_t nei1; uint32_t nbi1; uint32_t ne11; + uint32_t padded_N; }; struct vk_mat_vec_id_push_constants { uint32_t ncols; uint32_t stride_a; uint32_t stride_b; uint32_t stride_d; @@ -422,6 +567,10 @@ struct vk_flash_attn_push_constants { uint32_t n_head_log2; float m0; float m1; + + uint32_t gqa_ratio; + uint32_t split_kv; + uint32_t k_num; }; struct vk_op_push_constants { @@ -565,13 +714,29 @@ struct vk_op_rwkv_wkv6_push_constants { uint32_t H; }; -// Allow pre-recording command buffers -struct vk_staging_memcpy { - vk_staging_memcpy(void * _dst, const void * _src, size_t _n) : dst(_dst), src(_src), n(_n) {} +struct vk_op_rwkv_wkv7_push_constants { + uint32_t B; + uint32_t T; + uint32_t C; + uint32_t H; +}; - void * dst; - const void 
* src; - size_t n; +struct vk_op_conv2d_dw_push_constants { + uint32_t ne; + uint32_t batches; + uint32_t channels; + uint32_t dst_w; + uint32_t dst_h; + uint32_t src_w; + uint32_t src_h; + uint32_t knl_w; + uint32_t knl_h; + int32_t stride_x; + int32_t stride_y; + int32_t pad_x; + int32_t pad_y; + int32_t dilation_x; + int32_t dilation_y; }; struct vk_op_upscale_push_constants { @@ -581,6 +746,15 @@ struct vk_op_upscale_push_constants { float sf0; float sf1; float sf2; float sf3; }; +// Allow pre-recording command buffers +struct vk_staging_memcpy { + vk_staging_memcpy(void * _dst, const void * _src, size_t _n) : dst(_dst), src(_src), n(_n) {} + + void * dst; + const void * src; + size_t n; +}; + struct vk_context_struct { vk_submission * s; std::vector seqs; @@ -695,7 +869,8 @@ struct ggml_backend_vk_context { ggml_vk_garbage_collector gc; size_t prealloc_size_x, prealloc_size_y, prealloc_size_split_k; vk_buffer prealloc_x, prealloc_y, prealloc_split_k; - vk::Fence fence; + vk::Fence fence, almost_ready_fence; + bool almost_ready_fence_pending {}; vk_buffer buffer_pool[MAX_VK_BUFFERS]; @@ -786,6 +961,39 @@ typedef void (*ggml_vk_func_t)(ggml_backend_vk_context * ctx, vk_context& subctx static void ggml_backend_vk_free(ggml_backend_t backend); +// Wait for ctx->fence to be signaled. +static void ggml_vk_wait_for_fence(ggml_backend_vk_context * ctx) { + // Use waitForFences while most of the graph executes. Hopefully the CPU can sleep + // during this wait. + if (ctx->almost_ready_fence_pending) { + VK_CHECK(ctx->device->device.waitForFences({ ctx->almost_ready_fence }, true, UINT64_MAX), "almost_ready_fence"); + ctx->device->device.resetFences({ ctx->almost_ready_fence }); + ctx->almost_ready_fence_pending = false; + } + + // Spin (w/pause) waiting for the graph to finish executing. + vk::Result result; + while ((result = ctx->device->device.getFenceStatus(ctx->fence)) != vk::Result::eSuccess) { + if (result != vk::Result::eNotReady) { + fprintf(stderr, "ggml_vulkan: error %s at %s:%d\n", to_string(result).c_str(), __FILE__, __LINE__); + exit(1); + } + for (uint32_t i = 0; i < 100; ++i) { + YIELD(); + YIELD(); + YIELD(); + YIELD(); + YIELD(); + YIELD(); + YIELD(); + YIELD(); + YIELD(); + YIELD(); + } + } + ctx->device->device.resetFences({ ctx->fence }); +} + // variables to track number of compiles in progress static uint32_t compile_count = 0; static std::mutex compile_count_mutex; @@ -1382,13 +1590,29 @@ static void ggml_vk_wait_events(vk_context& ctx, std::vector&& events // number of rows/cols for flash attention shader static constexpr uint32_t flash_attention_num_small_rows = 32; -static std::array fa_rows_cols(uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) { +static constexpr uint32_t scalar_flash_attention_num_small_rows = 1; +static constexpr uint32_t scalar_flash_attention_num_large_rows = 8; + +static uint32_t get_fa_num_small_rows(bool scalar) { + return scalar ? 
scalar_flash_attention_num_small_rows : flash_attention_num_small_rows; +} + +static std::array fa_rows_cols(bool scalar, uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) { GGML_UNUSED(clamp); + if (scalar) { + if (small_rows) { + return {scalar_flash_attention_num_small_rows, 64}; + } else { + return {scalar_flash_attention_num_large_rows, 32}; + } + } + // small rows, large cols if (small_rows) { - return {flash_attention_num_small_rows, 128}; + return {get_fa_num_small_rows(scalar), 32}; } + // small cols to reduce register count if (ggml_is_quantized(type) || D == 256) { return {64, 32}; @@ -1433,7 +1657,7 @@ static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vec const uint32_t warps = warptile[0] / warptile[10]; const uint32_t load_bufs = (warptile[1] + warptile[2]) * (warptile[3] + bank_conflict_offset) * type_size; - const uint32_t mmid_row_ids = mul_mat_id ? 3072 * sizeof(uint32_t) : 0; + const uint32_t mmid_row_ids = mul_mat_id ? 4096 * sizeof(uint32_t) : 0; const uint32_t coopmat_stage = device->coopmat_support ? warptile[7] * warptile[8] / warps * sizeof(float) : 0; const uint32_t total_size = load_bufs + mmid_row_ids + coopmat_stage + lut_size; @@ -1445,6 +1669,73 @@ static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vec return supported; } +struct GpuPipelineConfig { + // GPU architecture identifier. + // Example: vk_device_architecture::AMD_GCN + vk_device_architecture arch; + + // Mapping of pipeline names to their specific subgroup sizes. + // Example: {"soft_max_f32", 64} + std::unordered_map pipelines; + + // Default subgroup size for this GPU. + // Defaults to 0 if not explicitly provided. + uint32_t default_subgroup_size = 0; +}; + +// Pipeline configuration for RDNA1 GPUs. +static const std::unordered_map rdna1_pipelines = { + {"soft_max", 64}, {"im2col", 64}, + {"argmax", 64}, {"mul_mat_vec", 64}, + {"mul_mat_vec_f16", 32}, {"mul_mat_vec_f32_f16", 32} +}; + +// Pipeline configuration for RDNA2 GPUs. +static const std::unordered_map rdna2_pipelines = { + {"soft_max", 64}, {"im2col", 64}, +}; + +static constexpr uint32_t RDNA_DEFAULT_SUBGROUP_SIZE = 32; + +// Define configurations for different GPUs. 
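
Editor's aside: the rdna1_pipelines / rdna2_pipelines tables above map pipeline-name fragments to required subgroup sizes; get_subgroup_size(), defined a few lines further down, resolves a concrete pipeline name against them with an exact match first, then the longest fragment that is a substring of the name, then the architecture default. The following stand-alone re-implementation of that lookup for the RDNA1 table only illustrates the expected resolution order; the function in the patch is authoritative.

    // Illustration of the per-architecture subgroup-size lookup (RDNA1 table from the patch).
    #include <algorithm>
    #include <cstdint>
    #include <cstdio>
    #include <string>
    #include <unordered_map>
    #include <vector>

    int main() {
        const std::unordered_map<std::string, uint32_t> rdna1 = {
            {"soft_max", 64}, {"im2col", 64},
            {"argmax", 64},   {"mul_mat_vec", 64},
            {"mul_mat_vec_f16", 32}, {"mul_mat_vec_f32_f16", 32},
        };
        const uint32_t default_subgroup = 32;   // RDNA_DEFAULT_SUBGROUP_SIZE

        auto resolve = [&](const std::string & name) -> uint32_t {
            auto it = rdna1.find(name);
            if (it != rdna1.end()) return it->second;              // exact match
            std::vector<std::pair<std::string, uint32_t>> sorted(rdna1.begin(), rdna1.end());
            std::sort(sorted.begin(), sorted.end(),
                      [](const auto & a, const auto & b) { return a.first.size() > b.first.size(); });
            for (const auto & e : sorted) {                        // longest matching fragment
                if (name.find(e.first) != std::string::npos) return e.second;
            }
            return default_subgroup;                               // architecture default
        };

        // "mul_mat_vec_f16_f32_f32_1" contains both "mul_mat_vec" and "mul_mat_vec_f16";
        // the longer fragment wins, so it resolves to 32, not 64.
        std::printf("%u %u %u\n",
                    resolve("soft_max_f32"),              // 64 via "soft_max"
                    resolve("mul_mat_vec_f16_f32_f32_1"), // 32 via "mul_mat_vec_f16"
                    resolve("rms_norm_f32"));             // 32 via the default
    }
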
+static std::vector gpu_pipeline_configs = { + { + vk_device_architecture::AMD_RDNA1, + { + rdna1_pipelines, + }, + RDNA_DEFAULT_SUBGROUP_SIZE + }, + { + vk_device_architecture::AMD_RDNA2, + { + rdna2_pipelines, + }, + RDNA_DEFAULT_SUBGROUP_SIZE + }, +}; + +static uint32_t get_subgroup_size(const std::string &pipeline_name, const vk_device_architecture &arch) { + for (const auto &config : gpu_pipeline_configs) { + if (config.arch == arch) { + auto pipIt = config.pipelines.find(pipeline_name); + if (pipIt != config.pipelines.end()) { + return pipIt->second; + } + std::vector> sorted_pipelines(config.pipelines.begin(), config.pipelines.end()); + std::sort(sorted_pipelines.begin(), sorted_pipelines.end(), + [](const auto &a, const auto &b) { return a.first.size() > b.first.size(); }); + for (const auto &entry : sorted_pipelines) { + if (pipeline_name.find(entry.first) != std::string::npos) { + return entry.second; + } + } + return config.default_subgroup_size; + } + } + return 0; // If no matching configuration is found +} + static void ggml_vk_load_shaders(vk_device& device) { VK_LOG_DEBUG("ggml_vk_load_shaders(" << device->name << ")"); @@ -1456,6 +1747,7 @@ static void ggml_vk_load_shaders(vk_device& device) { // mulmat std::vector l_warptile, m_warptile, s_warptile, l_warptile_mmq, m_warptile_mmq, s_warptile_mmq, + l_warptile_mmq_int, m_warptile_mmq_int, s_warptile_mmq_int, l_warptile_mmq_k, m_warptile_mmq_k, s_warptile_mmq_k, l_warptile_mmqid, m_warptile_mmqid, s_warptile_mmqid; std::array l_wg_denoms, m_wg_denoms, s_wg_denoms, @@ -1466,36 +1758,36 @@ static void ggml_vk_load_shaders(vk_device& device) { uint32_t l_align, m_align, s_align; if (device->coopmat2) { // spec constants and tile sizes for non-quant matmul/matmul_id - l_warptile = { 256, 128, 256, 64 }; - m_warptile = { 256, 128, 128, 64 }; - s_warptile = { 128, 64, 64, 64 }; + l_warptile = { 256, 128, 256, 64, 1 }; + m_warptile = { 256, 128, 128, 64, 0 }; + s_warptile = { 128, 64, 64, 64, 0 }; l_wg_denoms = {128, 256, 1 }; m_wg_denoms = {128, 128, 1 }; s_wg_denoms = { 64, 64, 1 }; // spec constants and tile sizes for quant matmul (non-Qi_K) - l_warptile_mmq = { 256, 128, 256, 64 }; - m_warptile_mmq = { 256, 128, 128, 64 }; - s_warptile_mmq = { 256, 128, 128, 64 }; + l_warptile_mmq = { 256, 128, 256, 64, 1 }; + m_warptile_mmq = { 256, 128, 128, 64, 1 }; + s_warptile_mmq = { 256, 32, 64, 128, 0 }; l_mmq_wg_denoms = { 128, 256, 1 }; m_mmq_wg_denoms = { 128, 128, 1 }; - s_mmq_wg_denoms = { 128, 128, 1 }; + s_mmq_wg_denoms = { 32, 64, 1 }; // spec constants and tile sizes for quant matmul (Qi_K) - l_warptile_mmq_k = { 256, 128, 512, 16 }; - m_warptile_mmq_k = { 256, 128, 256, 16 }; - s_warptile_mmq_k = { 256, 32, 128, 64 }; - l_mmq_wg_denoms_k = { 128, 512, 1 }; - m_mmq_wg_denoms_k = { 128, 256, 1 }; - s_mmq_wg_denoms_k = { 32, 128, 1 }; + l_warptile_mmq_k = { 256, 64, 128, 64, 1 }; + m_warptile_mmq_k = { 256, 32, 64, 64, 0 }; + s_warptile_mmq_k = { 256, 32, 32, 128, 0 }; + l_mmq_wg_denoms_k = { 64, 128, 1 }; + m_mmq_wg_denoms_k = { 32, 64, 1 }; + s_mmq_wg_denoms_k = { 32, 32, 1 }; // spec constants and tile sizes for quant matmul_id - l_warptile_mmqid = { 256, 128, 128, 16 }; - m_warptile_mmqid = { 256, 128, 64, 16 }; - s_warptile_mmqid = { 256, 64, 64, 16 }; - l_mmqid_wg_denoms = { 128, 128, 1 }; + l_warptile_mmqid = { 256, 128, 64, 16, 0 }; + m_warptile_mmqid = { 256, 128, 64, 16, 0 }; + s_warptile_mmqid = { 256, 128, 64, 16, 0 }; + l_mmqid_wg_denoms = { 128, 64, 1 }; m_mmqid_wg_denoms = { 128, 64, 1 }; - s_mmqid_wg_denoms = { 
64, 64, 1 }; + s_mmqid_wg_denoms = { 128, 64, 1 }; l_align = 128; m_align = 64; @@ -1520,6 +1812,15 @@ static void ggml_vk_load_shaders(vk_device& device) { m_warptile_mmq = { 128, 64, 64, 32, subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, subgroup_size_8 }; s_warptile_mmq = { subgroup_size_32, 32, 32, 32, 32, 32, 2, tm_s, tn_s, tk_s, subgroup_size_8 }; + l_warptile_mmq_int = { 128, 128, 128, 32, subgroup_size_8 * 2, 64, 2, 4, 4, 1, subgroup_size_8 }; + m_warptile_mmq_int = { 128, 64, 64, 32, subgroup_size_8, 32, 2, 2, 2, 1, subgroup_size_8 }; + s_warptile_mmq_int = { subgroup_size_32, 32, 32, 32, 32, 32, 2, 2, 1, 1, subgroup_size_8 }; + + // chip specific tuning + if ((device->architecture == AMD_GCN) && (device->driver_id != vk::DriverId::eAmdProprietary)) { + m_warptile_mmq = m_warptile_mmq_int = { 256, 64, 64, 32, 16, 16, 2, 2, 2, 1, 16 }; + } + l_mmq_wg_denoms = l_wg_denoms = {128, 128, 1 }; m_mmq_wg_denoms = m_wg_denoms = { 64, 64, 1 }; s_mmq_wg_denoms = s_wg_denoms = { 32, 32, 1 }; @@ -1565,12 +1866,22 @@ static void ggml_vk_load_shaders(vk_device& device) { if (!device->pipeline_matmul_id_f32) { device->pipeline_matmul_id_f32 = std::make_shared(); } + if (!device->pipeline_matmul_bf16) { + device->pipeline_matmul_bf16 = std::make_shared(); + } + if (!device->pipeline_matmul_id_bf16) { + device->pipeline_matmul_id_bf16 = std::make_shared(); + } std::vector> compiles; auto const &ggml_vk_create_pipeline = [&](vk_device& device, vk_pipeline& pipeline, const std::string &name, size_t spv_size, const void* spv_data, const std::string &entrypoint, uint32_t parameter_count, uint32_t push_constant_size, std::array wg_denoms, const std::vector& specialization_constants, uint32_t align, bool disable_robustness = false, bool require_full_subgroups = false, uint32_t required_subgroup_size = 0) { + if (!require_full_subgroups && required_subgroup_size == 0) { + required_subgroup_size = get_subgroup_size(name, device->architecture); + } + if (!pipeline) { pipeline = std::make_shared(); pipeline->name = name; @@ -1596,63 +1907,66 @@ static void ggml_vk_load_shaders(vk_device& device) { parameter_count, wg_denoms, specialization_constants, disable_robustness, require_full_subgroups, required_subgroup_size)); }; + auto const &fa_wg_denoms = [&](bool scalar, uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) -> std::array { + return {fa_rows_cols(scalar, D, clamp, type, small_rows)[0], 1, 1}; + }; + + auto const &fa_spec_constants = [&](bool scalar, uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) -> std::vector { + // For large number of rows, 128 invocations seems to work best. + // For small number of rows (e.g. N==1), 256 works better. But matrix granularity for 256 is 32, so we + // can't use 256 for D==80. + // For scalar, use 128 (arbitrary) + uint32_t wg_size = scalar ? 128 : ((small_rows && (D % 32) == 0) ? 256 : 128); + auto rows_cols = fa_rows_cols(scalar, D, clamp, type, small_rows); + + // D_split can't be larger than a subgroup because we use subgroupShuffle to reduce it. + // D_split can't be larger than the LSB of D divided by 4 due to vectorization in the shader. 
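
Editor's aside: the D_split computation that follows the two comments above uses a lowest-set-bit trick: D ^ (D & (D-1)) isolates the least significant set bit of the head size D, and D_split is then capped both at min(subgroup_size, 8) and at that bit divided by 4. A quick self-contained check with the head sizes these flash-attention pipelines are built for; the formula is copied from the patch, the subgroup size of 32 is just an example.

    // Worked check of the D_split formula used by fa_spec_constants (illustrative).
    #include <algorithm>
    #include <cassert>
    #include <cstdint>

    static uint32_t d_split(uint32_t D, uint32_t subgroup_size) {
        const uint32_t D_lsb = D ^ (D & (D - 1));                  // lowest set bit of D
        return std::min(std::min(subgroup_size, 8u), D_lsb / 4);
    }

    int main() {
        assert(d_split(128, 32) == 8);   // 128 = 10000000b, lsb 128 -> 32, capped at 8
        assert(d_split(96,  32) == 8);   //  96 =  1100000b, lsb  32 -> 32/4 = 8
        assert(d_split(80,  32) == 4);   //  80 =  1010000b, lsb  16 -> 16/4 = 4
        return 0;
    }
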
+ const uint32_t D_lsb = D ^ (D & (D-1)); + uint32_t D_split = std::min(std::min(device->subgroup_size, 8u), D_lsb / 4); + + // mask dim1 is padded to 64, we rely on this to avoid clamping mask loads + GGML_ASSERT((GGML_KQ_MASK_PAD % rows_cols[0]) == 0); + return {wg_size, rows_cols[0], rows_cols[1], (D), clamp, D_split}; + }; + +#define CREATE_FA2(TYPE, NAMELC, SCALAR, SUFFIX, D) \ + ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][0][0], "flash_attn_f32_f16_D" #D "_f16acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,1,TYPE,false), fa_spec_constants(SCALAR, D,1,TYPE,false), 1, true); \ + ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][0][1], "flash_attn_f32_f16_D" #D "_aligned_f16acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,0,TYPE,false), fa_spec_constants(SCALAR, D,0,TYPE,false), fa_rows_cols(SCALAR,D,0,TYPE,false)[1], true); \ + ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][0][0], "flash_attn_f32_f16_D" #D "_f32acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,1,TYPE,false), fa_spec_constants(SCALAR, D,1,TYPE,false), 1, true); \ + ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][0][1], "flash_attn_f32_f16_D" #D "_aligned_f32acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,0,TYPE,false), fa_spec_constants(SCALAR, D,0,TYPE,false), fa_rows_cols(SCALAR,D,0,TYPE,false)[1], true); \ + ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][1][0], "flash_attn_f32_f16_D" #D "_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,1,TYPE,true), fa_spec_constants(SCALAR, D,1,TYPE,true), 1, true); \ + ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][1][1], "flash_attn_f32_f16_D" #D "_aligned_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,0,TYPE,true), fa_spec_constants(SCALAR, D,0,TYPE,true), fa_rows_cols(SCALAR,D,0,TYPE,true)[1], true); \ + ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][1][0], "flash_attn_f32_f16_D" #D "_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,1,TYPE,true), fa_spec_constants(SCALAR, D,1,TYPE,true), 1, true); \ + ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][1][1], "flash_attn_f32_f16_D" #D 
"_aligned_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,0,TYPE,true), fa_spec_constants(SCALAR, D,0,TYPE,true), fa_rows_cols(SCALAR,D,0,TYPE,true)[1], true); \ + +#define CREATE_FA(TYPE, NAMELC, SCALAR, SUFFIX) \ + CREATE_FA2(TYPE, NAMELC, SCALAR, SUFFIX, 64) \ + CREATE_FA2(TYPE, NAMELC, SCALAR, SUFFIX, 80) \ + CREATE_FA2(TYPE, NAMELC, SCALAR, SUFFIX, 96) \ + CREATE_FA2(TYPE, NAMELC, SCALAR, SUFFIX, 112) \ + CREATE_FA2(TYPE, NAMELC, SCALAR, SUFFIX, 128) \ + CREATE_FA2(TYPE, NAMELC, SCALAR, SUFFIX, 256) + + CREATE_FA(GGML_TYPE_F16, f16, true, ) + CREATE_FA(GGML_TYPE_Q4_0, q4_0, true, ) + CREATE_FA(GGML_TYPE_Q8_0, q8_0, true, ) #if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) if (device->coopmat2) { - - auto const &fa_wg_denoms = [&](uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) -> std::array { - return {fa_rows_cols(D, clamp, type, small_rows)[0], 1, 1}; - }; - - auto const &fa_spec_constants = [&](uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) -> std::vector { - // For large number of rows, 128 invocations seems to work best. - // For small number of rows (e.g. N==1), 256 works better. But matrix granularity for 256 is 32, so we - // can't use 256 for D==80. - uint32_t wg_size = (small_rows && (D % 32) == 0) ? 256 : 128; - auto rows_cols = fa_rows_cols(D, clamp, type, small_rows); - return {wg_size, rows_cols[0], rows_cols[1], (D), clamp}; - }; - -#define CREATE_FA2(TYPE, NAMELC, D) \ - ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][0][0][0], "flash_attn_f32_f16_D" #D "_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_len, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,1,TYPE,false), fa_spec_constants(D,1,TYPE,false), 1); \ - ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][0][0][1], "flash_attn_f32_f16_D" #D "_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_len, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,0,TYPE,false), fa_spec_constants(D,0,TYPE,false), fa_rows_cols(D,0,TYPE,false)[1]); \ - ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][1][0][0], "flash_attn_f32_f16_D" #D "_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _cm2_len, flash_attn_f32_f16_ ## NAMELC ## _cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,1,TYPE,false), fa_spec_constants(D,1,TYPE,false), 1); \ - ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][1][0][1], "flash_attn_f32_f16_D" #D "_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _cm2_len, flash_attn_f32_f16_ ## NAMELC ## _cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,0,TYPE,false), fa_spec_constants(D,0,TYPE,false), fa_rows_cols(D,0,TYPE,false)[1]); \ - ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][0][1][0], "flash_attn_f32_f16_D" #D "_f16acc_smallrows" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_len, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,1,TYPE,true), fa_spec_constants(D,1,TYPE,true), 1); \ - ggml_vk_create_pipeline(device, 
device->pipeline_flash_attn_f32_f16_D ## D[TYPE][0][1][1], "flash_attn_f32_f16_D" #D "_aligned_f16acc_smallrows" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_len, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,0,TYPE,true), fa_spec_constants(D,0,TYPE,true), fa_rows_cols(D,0,TYPE,true)[1]); \ - ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][1][1][0], "flash_attn_f32_f16_D" #D "_f32acc_smallrows" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _cm2_len, flash_attn_f32_f16_ ## NAMELC ## _cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,1,TYPE,true), fa_spec_constants(D,1,TYPE,true), 1); \ - ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][1][1][1], "flash_attn_f32_f16_D" #D "_aligned_f32acc_smallrows" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _cm2_len, flash_attn_f32_f16_ ## NAMELC ## _cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,0,TYPE,true), fa_spec_constants(D,0,TYPE,true), fa_rows_cols(D,0,TYPE,true)[1]); \ - -#define CREATE_FA(TYPE, NAMELC) \ - CREATE_FA2(TYPE, NAMELC, 64) \ - CREATE_FA2(TYPE, NAMELC, 80) \ - CREATE_FA2(TYPE, NAMELC, 96) \ - CREATE_FA2(TYPE, NAMELC, 112) \ - CREATE_FA2(TYPE, NAMELC, 128) \ - CREATE_FA2(TYPE, NAMELC, 256) - - CREATE_FA(GGML_TYPE_F16, f16) - CREATE_FA(GGML_TYPE_Q4_0, q4_0) - CREATE_FA(GGML_TYPE_Q4_1, q4_1) - CREATE_FA(GGML_TYPE_Q5_0, q5_0) - CREATE_FA(GGML_TYPE_Q5_1, q5_1) - CREATE_FA(GGML_TYPE_Q8_0, q8_0) - // K dequants currently disabled because D dimension is rounded up to 256 and runs inefficiently - //CREATE_FA(GGML_TYPE_Q2_K, q2_k) - //CREATE_FA(GGML_TYPE_Q3_K, q3_k) - //CREATE_FA(GGML_TYPE_Q4_K, q4_k) - //CREATE_FA(GGML_TYPE_Q5_K, q5_k) - //CREATE_FA(GGML_TYPE_Q6_K, q6_k) - //CREATE_FA(GGML_TYPE_IQ1_S, iq1_s) - //CREATE_FA(GGML_TYPE_IQ1_M, iq1_m) - //CREATE_FA(GGML_TYPE_IQ2_XXS, iq2_xxs) - //CREATE_FA(GGML_TYPE_IQ2_XS, iq2_xs) - //CREATE_FA(GGML_TYPE_IQ2_S, iq2_s) - //CREATE_FA(GGML_TYPE_IQ3_XXS, iq3_xxs) - //CREATE_FA(GGML_TYPE_IQ3_S, iq3_s) - //CREATE_FA(GGML_TYPE_IQ4_XS, iq4_xs) - CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl) + CREATE_FA(GGML_TYPE_F16, f16, false, _cm2) + CREATE_FA(GGML_TYPE_Q4_0, q4_0, false, _cm2) + CREATE_FA(GGML_TYPE_Q4_1, q4_1, false, _cm2) + CREATE_FA(GGML_TYPE_Q5_0, q5_0, false, _cm2) + CREATE_FA(GGML_TYPE_Q5_1, q5_1, false, _cm2) + CREATE_FA(GGML_TYPE_Q8_0, q8_0, false, _cm2) + CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, false, _cm2) + } +#endif +#undef CREATE_FA2 #undef CREATE_FA +#if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) + if (device->coopmat2) { + // Create 6 variants, {s,m,l}x{unaligned,aligned} #define CREATE_MM(PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \ @@ -1668,6 +1982,11 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM(PIPELINE_NAME . 
f32acc, NAMELC, , WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \ CREATE_MM2(pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3) +#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT) + if (device->coopmat_bf16_support) { + CREATE_MM(pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3) + } +#endif CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) @@ -1689,6 +2008,11 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) CREATE_MM2(pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_id_push_constants, 4) +#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT) + if (device->coopmat_bf16_support) { + CREATE_MM(pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4) + } +#endif CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) @@ -1742,6 +2066,11 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); +#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT) + if (device->coopmat_bf16_support) { + CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ) + } +#endif if (device->coopmat_acc_f16_support) { CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); @@ -1790,6 +2119,11 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); +#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT) + if (device->coopmat_bf16_support) { + CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); + } +#endif if (device->coopmat_acc_f16_support) { CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, 
vk_mat_mat_id_push_constants, 4, _id); @@ -1854,6 +2188,14 @@ static void ggml_vk_load_shaders(vk_device& device) { if (device->mul_mat ## ID ## _s[TYPE]) \ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align); \ +#define CREATE_MMQ(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ + if (device->mul_mat ## ID ## _l[TYPE]) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \ + if (device->mul_mat ## ID ## _m[TYPE]) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \ + if (device->mul_mat ## ID ## _s[TYPE]) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \ + // Create 2 variants, {f16,f32} accumulator #define CREATE_MM2(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ CREATE_MM(TYPE, PIPELINE_NAME . f16acc, NAMELC, _f16acc, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ @@ -1864,6 +2206,8 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); @@ -1885,10 +2229,22 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f16acc, matmul_iq4_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); +#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) + if (device->integer_dot_product) { + CREATE_MMQ(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_q8_1, _f16acc, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, ); + CREATE_MMQ(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_q8_1, _f16acc, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, ); + CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_q8_1, _f16acc, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, ); + 
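
Editor's aside: the CREATE_MMQ pipelines in this block, together with the new pipeline_quantize_q8_1, form the integer-dot-product matmul path: the float operand is first requantized to Q8_1 on the GPU (quantize_q8_1.comp) and the mul_mmq shaders then accumulate int8 dot products. The sketch below shows, on the host, what a single Q8_1 block holds, assuming ggml's usual Q8_1 layout of 32 int8 quants plus a scale d and a precomputed s = d * sum(q); the GPU shader added by the patch is the authoritative version.

    // Host-side sketch of Q8_1 block quantization (illustrative; assumes ggml's Q8_1 layout).
    #include <cmath>
    #include <cstdint>
    #include <cstdio>

    struct block_q8_1_f32 {      // float fields for clarity; ggml stores d and s as fp16
        float  d;                // per-block scale
        float  s;                // d * sum(qs), folded into the offset term of the dot product
        int8_t qs[32];           // quantized values
    };

    static block_q8_1_f32 quantize_block(const float * x) {
        float amax = 0.0f;
        for (int i = 0; i < 32; ++i) amax = std::fmax(amax, std::fabs(x[i]));
        block_q8_1_f32 b{};
        b.d = amax / 127.0f;
        const float id = b.d != 0.0f ? 1.0f / b.d : 0.0f;
        int sum = 0;
        for (int i = 0; i < 32; ++i) {
            b.qs[i] = (int8_t) std::lround(x[i] * id);
            sum += b.qs[i];
        }
        b.s = b.d * (float) sum;
        return b;
    }

    int main() {
        float x[32];
        for (int i = 0; i < 32; ++i) x[i] = 0.1f * (i - 16);
        const block_q8_1_f32 b = quantize_block(x);
        std::printf("d=%f s=%f q0=%d q31=%d\n", b.d, b.s, b.qs[0], b.qs[31]);
    }
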
CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_1].f16acc, matmul_q5_1_q8_1, _f16acc, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, ); + CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_q8_1, _f16acc, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, ); + } +#endif + CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); @@ -1910,6 +2266,7 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f16acc, matmul_id_iq4_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); #undef CREATE_MM2 +#undef CREATE_MMQ #undef CREATE_MM } else { // Create 6 variants, {s,m,l}x{unaligned,aligned} @@ -1927,11 +2284,21 @@ static void ggml_vk_load_shaders(vk_device& device) { if (device->mul_mat ## ID ## _s[TYPE]) \ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align); \ +#define CREATE_MMQ(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ + if (device->mul_mat ## ID ## _l[TYPE]) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \ + if (device->mul_mat ## ID ## _m[TYPE]) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \ + if (device->mul_mat ## ID ## _s[TYPE]) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \ + CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); CREATE_MM(GGML_TYPE_F16, 
pipeline_matmul_f16.f32acc, matmul_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); CREATE_MM(GGML_TYPE_F16, pipeline_matmul_f16_f32.f32acc, matmul_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f32acc, matmul_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); @@ -1953,10 +2320,22 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f32acc, matmul_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); +#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) + if (device->integer_dot_product) { + CREATE_MMQ(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_q8_1, , mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, ); + CREATE_MMQ(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_1].f32acc, matmul_q4_1_q8_1, , mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, ); + CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_q8_1, , mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, ); + CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_1].f32acc, matmul_q5_1_q8_1, , mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, ); + CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q8_0].f32acc, matmul_q8_0_q8_1, , mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, ); + } +#endif + CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16.f32acc, matmul_id_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16_f32.f32acc, matmul_id_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f32acc, matmul_id_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f32acc, matmul_id_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f32acc, matmul_id_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); @@ -1977,8 +2356,26 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f32acc, matmul_id_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); CREATE_MM(GGML_TYPE_IQ4_XS, 
pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f32acc, matmul_id_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc, matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); -#undef CREATE_MM } + // reusing CREATE_MM from the fp32 path + if ((device->coopmat2 || device->coopmat_support) +#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) + && !device->coopmat_bf16_support +#endif + ) { + // use scalar tile sizes + l_warptile = { 128, 128, 128, 16, subgroup_size_8 * 2, 64, 2, 4, 4, 1, subgroup_size_8 }; + m_warptile = { 128, 64, 64, 16, subgroup_size_8, 32, 2, 4, 2, 1, subgroup_size_8 }; + s_warptile = { subgroup_size_16, 32, 32, 16, 32, 32, 2, 2, 2, 1, subgroup_size_8 }; + + l_wg_denoms = {128, 128, 1 }; + m_wg_denoms = { 64, 64, 1 }; + s_wg_denoms = { 32, 32, 1 }; + + CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4, _id); + } +#undef CREATE_MM // mul mat vec @@ -1986,16 +2383,18 @@ static void ggml_vk_load_shaders(vk_device& device) { uint32_t rm_stdq = 1; uint32_t rm_kq = 2; if (device->vendor_id == VK_VENDOR_ID_AMD) { - if (device->subgroup_min_size == 64 && device->subgroup_max_size == 64) { // GCN + if (device->architecture == AMD_GCN) { rm_stdq = 2; rm_kq = 4; } } else if (device->vendor_id == VK_VENDOR_ID_INTEL) rm_stdq = 2; + uint32_t rm_iq = 2 * rm_kq; for (uint32_t i = 0; i < mul_mat_vec_max_cols; ++i) { ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f32_f32_"+std::to_string(i+1), mul_mat_vec_f32_f32_f32_len, mul_mat_vec_f32_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_F16 ][i], "mul_mat_vec_f16_f32_f32_"+std::to_string(i+1), mul_mat_vec_f16_f32_f32_len, mul_mat_vec_f16_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_BF16][i], "mul_mat_vec_bf16_f32_f32_"+std::to_string(i+1), mul_mat_vec_bf16_f32_f32_len, mul_mat_vec_bf16_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_0][i], "mul_mat_vec_q4_0_f32_f32_"+std::to_string(i+1), mul_mat_vec_q4_0_f32_f32_len, mul_mat_vec_q4_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_1][i], "mul_mat_vec_q4_1_f32_f32_"+std::to_string(i+1), mul_mat_vec_q4_1_f32_f32_len, mul_mat_vec_q4_1_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_0][i], "mul_mat_vec_q5_0_f32_f32_"+std::to_string(i+1), mul_mat_vec_q5_0_f32_f32_len, mul_mat_vec_q5_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true); @@ 
-2006,18 +2405,19 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_K][i], "mul_mat_vec_q4_k_f32_f32_"+std::to_string(i+1), mul_mat_vec_q4_k_f32_f32_len, mul_mat_vec_q4_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_K][i], "mul_mat_vec_q5_k_f32_f32_"+std::to_string(i+1), mul_mat_vec_q5_k_f32_f32_len, mul_mat_vec_q5_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q6_K][i], "mul_mat_vec_q6_k_f32_f32_"+std::to_string(i+1), mul_mat_vec_q6_k_f32_f32_len, mul_mat_vec_q6_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ1_S][i], "mul_mat_vec_iq1_s_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq1_s_f32_f32_len, mul_mat_vec_iq1_s_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ1_M][i], "mul_mat_vec_iq1_m_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq1_m_f32_f32_len, mul_mat_vec_iq1_m_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ2_XXS][i], "mul_mat_vec_iq2_xxs_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq2_xxs_f32_f32_len, mul_mat_vec_iq2_xxs_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ2_XS][i], "mul_mat_vec_iq2_xs_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq2_xs_f32_f32_len, mul_mat_vec_iq2_xs_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ2_S][i], "mul_mat_vec_iq2_s_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq2_s_f32_f32_len, mul_mat_vec_iq2_s_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ3_XXS][i], "mul_mat_vec_iq3_xxs_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq3_xxs_f32_f32_len, mul_mat_vec_iq3_xxs_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ3_S][i], "mul_mat_vec_iq3_s_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq3_s_f32_f32_len, mul_mat_vec_iq3_s_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ4_XS][i], "mul_mat_vec_iq4_xs_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq4_xs_f32_f32_len, mul_mat_vec_iq4_xs_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, 
true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq4_nl_f32_f32_len, mul_mat_vec_iq4_nl_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {subgroup_size_16, 2*rm_stdq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ1_S][i], "mul_mat_vec_iq1_s_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq1_s_f32_f32_len, mul_mat_vec_iq1_s_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ1_M][i], "mul_mat_vec_iq1_m_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq1_m_f32_f32_len, mul_mat_vec_iq1_m_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ2_XXS][i], "mul_mat_vec_iq2_xxs_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq2_xxs_f32_f32_len, mul_mat_vec_iq2_xxs_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ2_XS][i], "mul_mat_vec_iq2_xs_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq2_xs_f32_f32_len, mul_mat_vec_iq2_xs_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ2_S][i], "mul_mat_vec_iq2_s_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq2_s_f32_f32_len, mul_mat_vec_iq2_s_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ3_XXS][i], "mul_mat_vec_iq3_xxs_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq3_xxs_f32_f32_len, mul_mat_vec_iq3_xxs_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ3_S][i], "mul_mat_vec_iq3_s_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq3_s_f32_f32_len, mul_mat_vec_iq3_s_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ4_XS][i], "mul_mat_vec_iq4_xs_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq4_xs_f32_f32_len, mul_mat_vec_iq4_xs_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq4_nl_f32_f32_len, mul_mat_vec_iq4_nl_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f16_f32_"+std::to_string(i+1), mul_mat_vec_f32_f16_f32_len, mul_mat_vec_f32_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1); ggml_vk_create_pipeline(device, 
device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F16 ][i], "mul_mat_vec_f16_f16_f32_"+std::to_string(i+1), mul_mat_vec_f16_f16_f32_len, mul_mat_vec_f16_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_BF16][i], "mul_mat_vec_bf16_f16_f32_"+std::to_string(i+1), mul_mat_vec_bf16_f16_f32_len, mul_mat_vec_bf16_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_0][i], "mul_mat_vec_q4_0_f16_f32_"+std::to_string(i+1), mul_mat_vec_q4_0_f16_f32_len, mul_mat_vec_q4_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_1][i], "mul_mat_vec_q4_1_f16_f32_"+std::to_string(i+1), mul_mat_vec_q4_1_f16_f32_len, mul_mat_vec_q4_1_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_0][i], "mul_mat_vec_q5_0_f16_f32_"+std::to_string(i+1), mul_mat_vec_q5_0_f16_f32_len, mul_mat_vec_q5_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true); @@ -2028,19 +2428,20 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_K][i], "mul_mat_vec_q4_k_f16_f32_"+std::to_string(i+1), mul_mat_vec_q4_k_f16_f32_len, mul_mat_vec_q4_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_K][i], "mul_mat_vec_q5_k_f16_f32_"+std::to_string(i+1), mul_mat_vec_q5_k_f16_f32_len, mul_mat_vec_q5_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q6_K][i], "mul_mat_vec_q6_k_f16_f32_"+std::to_string(i+1), mul_mat_vec_q6_k_f16_f32_len, mul_mat_vec_q6_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ1_S][i], "mul_mat_vec_iq1_s_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq1_s_f16_f32_len, mul_mat_vec_iq1_s_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ1_M][i], "mul_mat_vec_iq1_m_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq1_m_f16_f32_len, mul_mat_vec_iq1_m_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ2_XXS][i], "mul_mat_vec_iq2_xxs_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq2_xxs_f16_f32_len, mul_mat_vec_iq2_xxs_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); - ggml_vk_create_pipeline(device, 
device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ2_XS][i], "mul_mat_vec_iq2_xs_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq2_xs_f16_f32_len, mul_mat_vec_iq2_xs_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ2_S][i], "mul_mat_vec_iq2_s_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq2_s_f16_f32_len, mul_mat_vec_iq2_s_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ3_XXS][i], "mul_mat_vec_iq3_xxs_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq3_xxs_f16_f32_len, mul_mat_vec_iq3_xxs_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ3_S][i], "mul_mat_vec_iq3_s_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq3_s_f16_f32_len, mul_mat_vec_iq3_s_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ4_XS][i], "mul_mat_vec_iq4_xs_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq4_xs_f16_f32_len, mul_mat_vec_iq4_xs_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq4_nl_f16_f32_len, mul_mat_vec_iq4_nl_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {subgroup_size_16, 2*rm_stdq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ1_S][i], "mul_mat_vec_iq1_s_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq1_s_f16_f32_len, mul_mat_vec_iq1_s_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ1_M][i], "mul_mat_vec_iq1_m_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq1_m_f16_f32_len, mul_mat_vec_iq1_m_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ2_XXS][i], "mul_mat_vec_iq2_xxs_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq2_xxs_f16_f32_len, mul_mat_vec_iq2_xxs_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ2_XS][i], "mul_mat_vec_iq2_xs_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq2_xs_f16_f32_len, mul_mat_vec_iq2_xs_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ2_S][i], "mul_mat_vec_iq2_s_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq2_s_f16_f32_len, mul_mat_vec_iq2_s_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true); + ggml_vk_create_pipeline(device, 
device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ3_XXS][i], "mul_mat_vec_iq3_xxs_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq3_xxs_f16_f32_len, mul_mat_vec_iq3_xxs_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ3_S][i], "mul_mat_vec_iq3_s_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq3_s_f16_f32_len, mul_mat_vec_iq3_s_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ4_XS][i], "mul_mat_vec_iq4_xs_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq4_xs_f16_f32_len, mul_mat_vec_iq4_xs_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq4_nl_f16_f32_len, mul_mat_vec_iq4_nl_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true); } ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F32 ], "mul_mat_vec_id_f32_f32", mul_mat_vec_id_f32_f32_len, mul_mat_vec_id_f32_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F16 ], "mul_mat_vec_id_f16_f32", mul_mat_vec_id_f16_f32_len, mul_mat_vec_id_f16_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_BF16], "mul_mat_vec_id_bf16_f32", mul_mat_vec_id_bf16_f32_len, mul_mat_vec_id_bf16_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_0], "mul_mat_vec_id_q4_0_f32", mul_mat_vec_id_q4_0_f32_len, mul_mat_vec_id_q4_0_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_1], "mul_mat_vec_id_q4_1_f32", mul_mat_vec_id_q4_1_f32_len, mul_mat_vec_id_q4_1_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_0], "mul_mat_vec_id_q5_0_f32", mul_mat_vec_id_q5_0_f32_len, mul_mat_vec_id_q5_0_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true); @@ -2051,15 +2452,15 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_K], "mul_mat_vec_id_q4_k_f32", mul_mat_vec_id_q4_k_f32_len, mul_mat_vec_id_q4_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_K], "mul_mat_vec_id_q5_k_f32", mul_mat_vec_id_q5_k_f32_len, mul_mat_vec_id_q5_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, 
{subgroup_size_16, rm_kq}, 1, true); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q6_K], "mul_mat_vec_id_q6_k_f32", mul_mat_vec_id_q6_k_f32_len, mul_mat_vec_id_q6_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ1_S], "mul_mat_vec_id_iq1_s_f32", mul_mat_vec_id_iq1_s_f32_len, mul_mat_vec_id_iq1_s_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ1_M], "mul_mat_vec_id_iq1_m_f32", mul_mat_vec_id_iq1_m_f32_len, mul_mat_vec_id_iq1_m_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ2_XXS], "mul_mat_vec_id_iq2_xxs_f32", mul_mat_vec_id_iq2_xxs_f32_len, mul_mat_vec_id_iq2_xxs_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ2_XS], "mul_mat_vec_id_iq2_xs_f32", mul_mat_vec_id_iq2_xs_f32_len, mul_mat_vec_id_iq2_xs_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ2_S], "mul_mat_vec_id_iq2_s_f32", mul_mat_vec_id_iq2_s_f32_len, mul_mat_vec_id_iq2_s_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ3_XXS], "mul_mat_vec_id_iq3_xxs_f32", mul_mat_vec_id_iq3_xxs_f32_len, mul_mat_vec_id_iq3_xxs_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ3_S], "mul_mat_vec_id_iq3_s_f32", mul_mat_vec_id_iq3_s_f32_len, mul_mat_vec_id_iq3_s_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ4_XS], "mul_mat_vec_id_iq4_xs_f32", mul_mat_vec_id_iq4_xs_f32_len, mul_mat_vec_id_iq4_xs_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_id_iq4_nl_f32", mul_mat_vec_id_iq4_nl_f32_len, mul_mat_vec_id_iq4_nl_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {subgroup_size_16, 2*rm_stdq}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ1_S], "mul_mat_vec_id_iq1_s_f32", mul_mat_vec_id_iq1_s_f32_len, mul_mat_vec_id_iq1_s_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ1_M], "mul_mat_vec_id_iq1_m_f32", mul_mat_vec_id_iq1_m_f32_len, mul_mat_vec_id_iq1_m_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true); + ggml_vk_create_pipeline(device, 
device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ2_XXS], "mul_mat_vec_id_iq2_xxs_f32", mul_mat_vec_id_iq2_xxs_f32_len, mul_mat_vec_id_iq2_xxs_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ2_XS], "mul_mat_vec_id_iq2_xs_f32", mul_mat_vec_id_iq2_xs_f32_len, mul_mat_vec_id_iq2_xs_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ2_S], "mul_mat_vec_id_iq2_s_f32", mul_mat_vec_id_iq2_s_f32_len, mul_mat_vec_id_iq2_s_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ3_XXS], "mul_mat_vec_id_iq3_xxs_f32", mul_mat_vec_id_iq3_xxs_f32_len, mul_mat_vec_id_iq3_xxs_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ3_S], "mul_mat_vec_id_iq3_s_f32", mul_mat_vec_id_iq3_s_f32_len, mul_mat_vec_id_iq3_s_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ4_XS], "mul_mat_vec_id_iq4_xs_f32", mul_mat_vec_id_iq4_xs_f32_len, mul_mat_vec_id_iq4_xs_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_id_iq4_nl_f32", mul_mat_vec_id_iq4_nl_f32_len, mul_mat_vec_id_iq4_nl_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true); // dequant shaders ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_F32 ], "f32_to_f16", dequant_f32_len, dequant_f32_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1); @@ -2086,6 +2487,7 @@ static void ggml_vk_load_shaders(vk_device& device) { // get_rows ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_F32 ], "get_rows_f32", get_rows_f32_len, get_rows_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_F16 ], "get_rows_f16", get_rows_f16_len, get_rows_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_BF16], "get_rows_bf16", get_rows_bf16_len, get_rows_bf16_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q4_0], "get_rows_q4_0", get_rows_q4_0_len, get_rows_q4_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q4_1], "get_rows_q4_1", get_rows_q4_1_len, get_rows_q4_1_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q5_0], "get_rows_q5_0", get_rows_q5_0_len, get_rows_q5_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); @@ -2103,6 +2505,7 @@ static void ggml_vk_load_shaders(vk_device& 
device) { ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_F32 ], "get_rows_f32_f32", get_rows_f32_f32_len, get_rows_f32_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_F16 ], "get_rows_f16_f32", get_rows_f16_f32_len, get_rows_f16_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_BF16], "get_rows_bf16_f32", get_rows_bf16_f32_len, get_rows_bf16_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q4_0], "get_rows_q4_0_f32", get_rows_q4_0_f32_len, get_rows_q4_0_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q4_1], "get_rows_q4_1_f32", get_rows_q4_1_f32_len, get_rows_q4_1_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q5_0], "get_rows_q5_0_f32", get_rows_q5_0_f32_len, get_rows_q5_0_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); @@ -2119,29 +2522,51 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl_f32", get_rows_iq4_nl_f32_len, get_rows_iq4_nl_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256 * 4, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_flash_attn_split_k_reduce, "fa_split_k_reduce", fa_split_k_reduce_len, fa_split_k_reduce_data, "main", 2, 3 * sizeof(uint32_t), {1, 1, 1}, {}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_quantize_q8_1, "quantize_q8_1", quantize_q8_1_len, quantize_q8_1_data, "main", 2, 1 * sizeof(uint32_t), {32 * device->subgroup_size / 8, 1, 1}, { device->subgroup_size }, 1); - ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_p021_f16_f32, "mul_mat_vec_p021_f16_f32", mul_mat_vec_p021_f16_f32_len, mul_mat_vec_p021_f16_f32_data, "main", 3, 6 * sizeof(uint32_t), {1, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_nc_f16_f32, "mul_mat_vec_nc_f16_f32", mul_mat_vec_nc_f16_f32_len, mul_mat_vec_nc_f16_f32_data, "main", 3, 7 * sizeof(uint32_t), {1, 1, 1}, {}, 1); + for (uint32_t i = 0; i < p021_max_gqa_ratio; ++i) { + if (device->subgroup_add && device->subgroup_require_full_support) { + ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_p021_f16_f32[i], "mul_mat_vec_p021_f16_f32"+std::to_string(i+1), mul_mat_vec_p021_f16_f32_subgroup_add_len, mul_mat_vec_p021_f16_f32_subgroup_add_data, "main", 3, 6 * sizeof(uint32_t), {1, 1, 1}, {device->subgroup_size, i + 1}, 1, true, true); + } else { + ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_p021_f16_f32[i], "mul_mat_vec_p021_f16_f32"+std::to_string(i+1), mul_mat_vec_p021_f16_f32_len, mul_mat_vec_p021_f16_f32_data, "main", 3, 6 * sizeof(uint32_t), {1, 1, 1}, {device->subgroup_size, i + 1}, 1, true); + } + } + ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_nc_f16_f32, "mul_mat_vec_nc_f16_f32", mul_mat_vec_nc_f16_f32_len, mul_mat_vec_nc_f16_f32_data, "main", 3, 9 
* sizeof(uint32_t), {1, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_norm_f32, "norm_f32", norm_f32_len, norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_group_norm_f32, "group_norm_f32", group_norm_f32_len, group_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {1, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_rms_norm_back_f32, "rms_norm_back_f32", rms_norm_back_f32_len, rms_norm_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_l2_norm_f32, "l2_norm_f32", l2_norm_f32_len, l2_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f32, "cpy_f32_f32", cpy_f32_f32_len, cpy_f32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f16, "cpy_f32_f16", cpy_f32_f16_len, cpy_f32_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_cpy_f16_f16, "cpy_f16_f16", cpy_f16_f16_len, cpy_f16_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_f16_f32, "cpy_f16_f32", cpy_f16_f32_len, cpy_f16_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_bf16,"cpy_f32_bf16",cpy_f32_bf16_len,cpy_f32_bf16_data,"main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_f32, "contig_cpy_f32_f32", contig_cpy_f32_f32_len, contig_cpy_f32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_f16, "contig_cpy_f32_f16", contig_cpy_f32_f16_len, contig_cpy_f32_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f16_f16, "contig_cpy_f16_f16", contig_cpy_f16_f16_len, contig_cpy_f16_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f16_f32, "contig_cpy_f16_f32", contig_cpy_f16_f32_len, contig_cpy_f16_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_bf16,"contig_cpy_f32_bf16",contig_cpy_f32_bf16_len,contig_cpy_f32_bf16_data,"main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0", cpy_f32_q4_0_len, cpy_f32_q4_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_0), 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_1], "cpy_f32_q4_1", cpy_f32_q4_1_len, cpy_f32_q4_1_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_1), 1, 1}, {}, 1); - 
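// A short sketch (not part of the patch) of the shader-variant choice made for
// the p021 mat-vec pipelines above: one pipeline is built per GQA ratio, and
// the subgroupAdd-based shader is selected only when the device reports
// subgroup arithmetic support and full subgroups can be required for the
// dispatch. pick_p021_shader and the struct below are illustrative only.
struct P021Shader { const char * name; bool needs_full_subgroups; };
static P021Shader pick_p021_shader(bool subgroup_add, bool subgroup_require_full_support) {
    if (subgroup_add && subgroup_require_full_support) {
        // reduction done with subgroup arithmetic; needs the full-subgroups guarantee
        return { "mul_mat_vec_p021_f16_f32_subgroup_add", true };
    }
    // portable fallback variant without the subgroup arithmetic requirement
    return { "mul_mat_vec_p021_f16_f32", false };
}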
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_0], "cpy_f32_q5_0", cpy_f32_q5_0_len, cpy_f32_q5_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q5_0), 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_1], "cpy_f32_q5_1", cpy_f32_q5_1_len, cpy_f32_q5_1_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q5_1), 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q8_0], "cpy_f32_q8_0", cpy_f32_q8_0_len, cpy_f32_q8_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q8_0), 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_IQ4_NL], "cpy_f32_iq4_nl", cpy_f32_iq4_nl_len, cpy_f32_iq4_nl_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_IQ4_NL), 1, 1}, {}, 1); + if (device->float_controls_rte_fp16) { + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0", cpy_f32_q4_0_rte_len, cpy_f32_q4_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_0), 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_1], "cpy_f32_q4_1", cpy_f32_q4_1_rte_len, cpy_f32_q4_1_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_1), 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_0], "cpy_f32_q5_0", cpy_f32_q5_0_rte_len, cpy_f32_q5_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q5_0), 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_1], "cpy_f32_q5_1", cpy_f32_q5_1_rte_len, cpy_f32_q5_1_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q5_1), 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q8_0], "cpy_f32_q8_0", cpy_f32_q8_0_rte_len, cpy_f32_q8_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q8_0), 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_IQ4_NL], "cpy_f32_iq4_nl", cpy_f32_iq4_nl_rte_len, cpy_f32_iq4_nl_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_IQ4_NL), 1, 1}, {}, 1); + } else { + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0", cpy_f32_q4_0_len, cpy_f32_q4_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_0), 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_1], "cpy_f32_q4_1", cpy_f32_q4_1_len, cpy_f32_q4_1_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_1), 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_0], "cpy_f32_q5_0", cpy_f32_q5_0_len, cpy_f32_q5_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q5_0), 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_1], "cpy_f32_q5_1", cpy_f32_q5_1_len, cpy_f32_q5_1_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q5_1), 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, 
device->pipeline_cpy_f32_quant[GGML_TYPE_Q8_0], "cpy_f32_q8_0", cpy_f32_q8_0_len, cpy_f32_q8_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q8_0), 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_IQ4_NL], "cpy_f32_iq4_nl", cpy_f32_iq4_nl_len, cpy_f32_iq4_nl_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_IQ4_NL), 1, 1}, {}, 1); + } ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_Q4_0], "cpy_q4_0_f32", cpy_q4_0_f32_len, cpy_q4_0_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_0), 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_Q4_1], "cpy_q4_1_f32", cpy_q4_1_f32_len, cpy_q4_1_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_1), 1, 1}, {}, 1); @@ -2150,20 +2575,32 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_Q8_0], "cpy_q8_0_f32", cpy_q8_0_f32_len, cpy_q8_0_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q8_0), 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_IQ4_NL], "cpy_iq4_nl_f32", cpy_iq4_nl_f32_len, cpy_iq4_nl_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_IQ4_NL), 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_add_f32, "add_f32", add_f32_len, add_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0}, 1); - ggml_vk_create_pipeline(device, device->pipeline_add_f32_norepeat, "add_f32_norepeat", add_f32_len, add_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {1}, 1); - ggml_vk_create_pipeline(device, device->pipeline_add_f16_f32_f16, "add_f16_f32_f16", add_f16_f32_f16_len, add_f16_f32_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0}, 1); - ggml_vk_create_pipeline(device, device->pipeline_add_f16_f32_f16_norepeat, "add_f16_f32_f16_norepeat", add_f16_f32_f16_len, add_f16_f32_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {1}, 1); + auto get_suffix = [](bool src0_f16, bool src1_f16, bool dst_f16) { + std::string s; + s += std::string(src0_f16 ? "_f16" : "_f32"); + s += std::string(src1_f16 ? "_f16" : "_f32"); + s += std::string(dst_f16 ? 
"_f16" : "_f32"); + return s; + }; + +#define CREATE_BINARY(name, namemod, spec) \ + for (int s0 : {0,1}) for (int s1 : {0,1}) for (int d : {0,1}) \ + ggml_vk_create_pipeline(device, device->pipeline_ ## name ## namemod[s0][s1][d], \ + #name + get_suffix(s0, s1, d) + #namemod, name ## _len[s0][s1][d], name ## _data[s0][s1][d], \ + "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, spec, 1); + + CREATE_BINARY(add, , {0}) + CREATE_BINARY(add, _norepeat, {1}) + CREATE_BINARY(sub, , {0}) + CREATE_BINARY(sub, _norepeat, {1}) + CREATE_BINARY(mul, , {0}) + CREATE_BINARY(mul, _norepeat, {1}) + CREATE_BINARY(div, , {0}) + CREATE_BINARY(div, _norepeat, {1}) +#undef CREATE_BINARY ggml_vk_create_pipeline(device, device->pipeline_acc_f32, "acc_f32", acc_f32_len, acc_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_sub_f32, "sub_f32", sub_f32_len, sub_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0}, 1); - ggml_vk_create_pipeline(device, device->pipeline_sub_f32_norepeat, "sub_f32_norepeat", sub_f32_len, sub_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {1}, 1); - ggml_vk_create_pipeline(device, device->pipeline_mul_f32, "mul_f32", mul_f32_len, mul_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0}, 1); - ggml_vk_create_pipeline(device, device->pipeline_mul_f32_norepeat, "mul_f32_norepeat", mul_f32_len, mul_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {1}, 1); - ggml_vk_create_pipeline(device, device->pipeline_div_f32, "div_f32", div_f32_len, div_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0}, 1); - ggml_vk_create_pipeline(device, device->pipeline_div_f32_norepeat, "div_f32_norepeat", div_f32_len, div_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {1}, 1); - ggml_vk_create_pipeline(device, device->pipeline_concat_f32, "concat_f32", concat_f32_len, concat_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_concat_f16, "concat_f16", concat_f16_len, concat_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_concat_i32, "concat_i32", concat_i32_len, concat_i32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1); @@ -2183,14 +2620,20 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_repeat_f32, "repeat_f32", repeat_f32_len, repeat_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_repeat_back_f32, "repeat_back_f32", repeat_back_f32_len, repeat_back_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_gelu_f32, "gelu_f32", gelu_f32_len, gelu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_gelu_quick_f32, "gelu_quick_f32", gelu_quick_f32_len, gelu_quick_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_silu_f32, "silu_f32", silu_f32_len, silu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_silu_back_f32, "silu_back_f32", silu_back_f32_len, silu_back_f32_data, "main", 3, 
sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_relu_f32, "relu_f32", relu_f32_len, relu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); +#define CREATE_UNARY(name) \ + ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32", name ## _f32_len, name ## _f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); \ + ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); + + CREATE_UNARY(gelu) + CREATE_UNARY(gelu_quick) + CREATE_UNARY(silu) + CREATE_UNARY(relu) + CREATE_UNARY(tanh) + CREATE_UNARY(sigmoid) +#undef CREATE_UNARY + ggml_vk_create_pipeline(device, device->pipeline_leaky_relu_f32, "leaky_relu_f32", leaky_relu_f32_len, leaky_relu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_tanh_f32, "tanh_f32", tanh_f32_len, tanh_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_sigmoid_f32, "sigmoid_f32", sigmoid_f32_len, sigmoid_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_silu_back_f32, "silu_back_f32", silu_back_f32_len, silu_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_diag_mask_inf_f32, "diag_mask_inf_f32", diag_mask_inf_f32_len, diag_mask_inf_f32_data, "main", 2, sizeof(vk_op_diag_mask_push_constants), {1, 512, 1}, {}, 1, true); @@ -2238,15 +2681,20 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv6_f32, "rwkv_wkv6_f32", rwkv_wkv6_f32_len, rwkv_wkv6_f32_data, "main", 7, sizeof(vk_op_rwkv_wkv6_push_constants), {1, 1, 1}, {device->subgroup_size}, 1); + ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv7_f32, "rwkv_wkv7_f32", rwkv_wkv7_f32_len, rwkv_wkv7_f32_data, "main", 8, sizeof(vk_op_rwkv_wkv7_push_constants), {1, 1, 1}, {device->subgroup_size}, 1); + ggml_vk_create_pipeline(device, device->pipeline_opt_step_adamw_f32, "opt_step_adamw_f32", opt_step_adamw_f32_len, opt_step_adamw_f32_data, "main", 5, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_whcn_f32, "conv2d_dw_whcn_f32", conv2d_dw_whcn_f32_len, conv2d_dw_whcn_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_cwhn_f32, "conv2d_dw_cwhn_f32", conv2d_dw_cwhn_f32_len, conv2d_dw_cwhn_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1); + for (auto &c : compiles) { c.wait(); } device->need_compiles = false; } -static bool ggml_vk_khr_cooperative_matrix_support(const vk::PhysicalDeviceProperties& props, const vk::PhysicalDeviceDriverProperties& driver_props); +static bool ggml_vk_khr_cooperative_matrix_support(const vk::PhysicalDeviceProperties& props, const vk::PhysicalDeviceDriverProperties& driver_props, vk_device_architecture arch); static vk_device ggml_vk_get_device(size_t idx) { VK_LOG_DEBUG("ggml_vk_get_device(" << idx << ")"); @@ -2275,6 +2723,8 @@ static vk_device ggml_vk_get_device(size_t idx) { device->physical_device = physical_devices[dev_num]; const std::vector ext_props = device->physical_device.enumerateDeviceExtensionProperties(); + 
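// A compact sketch (not lifted verbatim from the patch) of the property-query
// pattern the hunks around this point rely on: extension property structs are
// linked onto PhysicalDeviceProperties2 before a single getProperties2() call,
// and optional structs are appended through a running last_struct pointer.
// Simplified from the full chain used in the code; sketch_query_props is a
// placeholder name.
#include <vulkan/vulkan.hpp>
static void sketch_query_props(vk::PhysicalDevice pd, bool integer_dot_product) {
    vk::PhysicalDeviceProperties2 props2;
    vk::PhysicalDeviceVulkan11Properties vk11_props;
    vk::PhysicalDeviceVulkan12Properties vk12_props;
    vk::PhysicalDeviceShaderIntegerDotProductPropertiesKHR dot_props;

    props2.pNext     = &vk11_props;
    vk11_props.pNext = &vk12_props;

    VkBaseOutStructure * last_struct = (VkBaseOutStructure *)&vk12_props;
    if (integer_dot_product) {
        // only chain the struct when the extension was found on the device
        last_struct->pNext = (VkBaseOutStructure *)&dot_props;
        last_struct = (VkBaseOutStructure *)&dot_props;
    }

    pd.getProperties2(&props2);   // fills every struct on the chain in one call
}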
device->architecture = get_device_architecture(device->physical_device); + const char* GGML_VK_PREFER_HOST_MEMORY = getenv("GGML_VK_PREFER_HOST_MEMORY"); device->prefer_host_memory = GGML_VK_PREFER_HOST_MEMORY != nullptr; @@ -2286,8 +2736,9 @@ static vk_device ggml_vk_get_device(size_t idx) { bool pipeline_robustness = false; bool coopmat2_support = false; device->coopmat_support = false; + device->integer_dot_product = false; + bool bfloat16_support = false; - // Check if maintenance4 is supported for (const auto& properties : ext_props) { if (strcmp("VK_KHR_maintenance4", properties.extensionName) == 0) { maintenance4_support = true; @@ -2312,6 +2763,14 @@ static vk_device ggml_vk_get_device(size_t idx) { } else if (strcmp("VK_NV_cooperative_matrix2", properties.extensionName) == 0 && !getenv("GGML_VK_DISABLE_COOPMAT2")) { coopmat2_support = true; +#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) + } else if (strcmp("VK_KHR_shader_integer_dot_product", properties.extensionName) == 0 && + !getenv("GGML_VK_DISABLE_INTEGER_DOT_PRODUCT")) { + device->integer_dot_product = true; +#endif + } else if (strcmp("VK_KHR_shader_bfloat16", properties.extensionName) == 0 && + !getenv("GGML_VK_DISABLE_BFLOAT16")) { + bfloat16_support = true; } } @@ -2322,13 +2781,16 @@ static vk_device ggml_vk_get_device(size_t idx) { vk::PhysicalDeviceDriverProperties driver_props; vk::PhysicalDeviceShaderSMBuiltinsPropertiesNV sm_props; vk::PhysicalDeviceShaderCoreProperties2AMD amd_shader_core_properties2_props; + vk::PhysicalDeviceVulkan11Properties vk11_props; vk::PhysicalDeviceVulkan12Properties vk12_props; vk::PhysicalDeviceSubgroupSizeControlPropertiesEXT subgroup_size_control_props; + vk::PhysicalDeviceShaderIntegerDotProductPropertiesKHR shader_integer_dot_product_props; props2.pNext = &props3; props3.pNext = &subgroup_props; subgroup_props.pNext = &driver_props; - driver_props.pNext = &vk12_props; + driver_props.pNext = &vk11_props; + vk11_props.pNext = &vk12_props; VkBaseOutStructure * last_struct = (VkBaseOutStructure *)&vk12_props; @@ -2357,9 +2819,15 @@ static vk_device ggml_vk_get_device(size_t idx) { } #endif + if (device->integer_dot_product) { + last_struct->pNext = (VkBaseOutStructure *)&shader_integer_dot_product_props; + last_struct = (VkBaseOutStructure *)&shader_integer_dot_product_props; + } + device->physical_device.getProperties2(&props2); device->properties = props2.properties; device->vendor_id = device->properties.vendorID; + device->driver_id = driver_props.driverID; const char* GGML_VK_FORCE_MAX_ALLOCATION_SIZE = getenv("GGML_VK_FORCE_MAX_ALLOCATION_SIZE"); @@ -2375,13 +2843,9 @@ static vk_device ggml_vk_get_device(size_t idx) { if (GGML_VK_SUBALLOCATION_BLOCK_SIZE != nullptr) { device->suballocation_block_size = std::stoul(GGML_VK_SUBALLOCATION_BLOCK_SIZE); -#if defined(_WIN32) - } else if (device->vendor_id == VK_VENDOR_ID_NVIDIA) { + } else { // Limit batching of allocations to 1GB by default to avoid fragmentation issues device->suballocation_block_size = 1024*1024*1024; -#endif - } else { - device->suballocation_block_size = device->max_memory_allocation_size; } device->suballocation_block_size = std::min(device->suballocation_block_size, device->max_memory_allocation_size); @@ -2396,14 +2860,22 @@ static vk_device ggml_vk_get_device(size_t idx) { } device->float_controls_rte_fp16 = vk12_props.shaderRoundingModeRTEFloat16; + device->subgroup_add = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) && + (vk11_props.subgroupSupportedOperations & 
vk::SubgroupFeatureFlagBits::eArithmetic); + + device->subgroup_shuffle = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) && + (vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eShuffle); + const bool force_disable_f16 = getenv("GGML_VK_DISABLE_F16") != nullptr; device->fp16 = !force_disable_f16 && fp16_storage && fp16_compute; - if (!ggml_vk_khr_cooperative_matrix_support(device->properties, driver_props)) { + if (!ggml_vk_khr_cooperative_matrix_support(device->properties, driver_props, device->architecture)) { device->coopmat_support = false; } + device->integer_dot_product = device->integer_dot_product && shader_integer_dot_product_props.integerDotProduct4x8BitPackedSignedAccelerated; + std::vector queue_family_props = device->physical_device.getQueueFamilyProperties(); // Try to find a non-graphics compute queue and transfer-focused queues @@ -2488,6 +2960,17 @@ static vk_device ggml_vk_get_device(size_t idx) { } #endif +#if defined(VK_KHR_shader_bfloat16) + VkPhysicalDeviceShaderBfloat16FeaturesKHR bfloat16_features {}; + bfloat16_features.pNext = nullptr; + bfloat16_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_BFLOAT16_FEATURES_KHR; + if (bfloat16_support) { + last_struct->pNext = (VkBaseOutStructure *)&bfloat16_features; + last_struct = (VkBaseOutStructure *)&bfloat16_features; + device_extensions.push_back("VK_KHR_shader_bfloat16"); + } +#endif + VkPhysicalDeviceMaintenance4Features maint4_features {}; maint4_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_MAINTENANCE_4_FEATURES; if (maintenance4_support) { @@ -2496,6 +2979,14 @@ static vk_device ggml_vk_get_device(size_t idx) { device_extensions.push_back("VK_KHR_maintenance4"); } + VkPhysicalDeviceShaderIntegerDotProductFeaturesKHR shader_integer_dot_product_features {}; + shader_integer_dot_product_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_INTEGER_DOT_PRODUCT_FEATURES_KHR; + if (device->integer_dot_product) { + last_struct->pNext = (VkBaseOutStructure *)&shader_integer_dot_product_features; + last_struct = (VkBaseOutStructure *)&shader_integer_dot_product_features; + device_extensions.push_back("VK_KHR_shader_integer_dot_product"); + } + vkGetPhysicalDeviceFeatures2(device->physical_device, &device_features2); device->fp16 = device->fp16 && vk12_features.shaderFloat16; @@ -2665,7 +3156,37 @@ static vk_device ggml_vk_get_device(size_t idx) { device->coopmat_acc_f16_support = true; } } + } else if ((vk::ComponentTypeKHR)prop.AType == vk::ComponentTypeKHR::eSint8 && + (vk::ComponentTypeKHR)prop.BType == vk::ComponentTypeKHR::eSint8 && + (vk::ComponentTypeKHR)prop.CType == vk::ComponentTypeKHR::eSint32 && + (vk::ComponentTypeKHR)prop.ResultType == vk::ComponentTypeKHR::eSint32 && + (vk::ScopeKHR)prop.scope == vk::ScopeKHR::eSubgroup && + device->coopmat_int_m == 0 + ) { + device->coopmat_int_support = true; + device->coopmat_int_m = prop.MSize; + device->coopmat_int_n = prop.NSize; + device->coopmat_int_k = prop.KSize; } +#if defined(VK_KHR_shader_bfloat16) && defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT) + if (prop.AType == VK_COMPONENT_TYPE_BFLOAT16_KHR && + prop.BType == VK_COMPONENT_TYPE_BFLOAT16_KHR && + prop.CType == VK_COMPONENT_TYPE_FLOAT32_KHR && + prop.ResultType == VK_COMPONENT_TYPE_FLOAT32_KHR && + (vk::ScopeKHR)prop.scope == vk::ScopeKHR::eSubgroup + ) { + // coopmat sizes not set yet + if (device->coopmat_m == 0) { + device->coopmat_bf16_support = true; + device->coopmat_m = prop.MSize; + device->coopmat_n = prop.NSize; + device->coopmat_k = 
prop.KSize; + } else if (device->coopmat_m == prop.MSize && device->coopmat_n == prop.NSize && device->coopmat_k == prop.KSize) { + // Only enable if shape is identical + device->coopmat_bf16_support = true; + } + } +#endif } if (device->coopmat_m == 0 || !device->coopmat_acc_f32_support) { @@ -2673,11 +3194,19 @@ static vk_device ggml_vk_get_device(size_t idx) { GGML_LOG_DEBUG("ggml_vulkan: WARNING: No suitable matrix core mode found. Disabling matrix cores.\n"); device->coopmat_support = false; } + if (getenv("GGML_VK_DISABLE_BFLOAT16")) { + device->coopmat_bf16_support = false; + } } if (device->coopmat_support) { device_extensions.push_back("VK_KHR_cooperative_matrix"); } +#if defined(VK_KHR_shader_bfloat16) + if (device->coopmat_bf16_support) { + device_extensions.push_back("VK_KHR_shader_bfloat16"); + } +#endif #endif device->name = GGML_VK_NAME + std::to_string(idx); @@ -2769,22 +3298,11 @@ static void ggml_vk_print_gpu_info(size_t idx) { vk::PhysicalDevice physical_device = devices[dev_num]; std::vector ext_props = physical_device.enumerateDeviceExtensionProperties(); - vk::PhysicalDeviceProperties2 props2; - vk::PhysicalDeviceMaintenance3Properties props3; - vk::PhysicalDeviceSubgroupProperties subgroup_props; - vk::PhysicalDeviceDriverProperties driver_props; - props2.pNext = &props3; - props3.pNext = &subgroup_props; - subgroup_props.pNext = &driver_props; - physical_device.getProperties2(&props2); - - const size_t subgroup_size = subgroup_props.subgroupSize; - const bool uma = props2.properties.deviceType == vk::PhysicalDeviceType::eIntegratedGpu; - bool fp16_storage = false; bool fp16_compute = false; bool coopmat_support = false; bool coopmat2_support = false; + bool integer_dot_product = false; for (auto properties : ext_props) { if (strcmp("VK_KHR_16bit_storage", properties.extensionName) == 0) { @@ -2800,25 +3318,44 @@ static void ggml_vk_print_gpu_info(size_t idx) { } else if (strcmp("VK_NV_cooperative_matrix2", properties.extensionName) == 0 && !getenv("GGML_VK_DISABLE_COOPMAT2")) { coopmat2_support = true; +#endif +#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) + } else if (strcmp("VK_KHR_shader_integer_dot_product", properties.extensionName) == 0 && + !getenv("GGML_VK_DISABLE_INTEGER_DOT_PRODUCT")) { + integer_dot_product = true; #endif } } - if (!ggml_vk_khr_cooperative_matrix_support(props2.properties, driver_props)) { - coopmat_support = false; - } + const vk_device_architecture device_architecture = get_device_architecture(physical_device); const char* GGML_VK_DISABLE_F16 = getenv("GGML_VK_DISABLE_F16"); bool force_disable_f16 = GGML_VK_DISABLE_F16 != nullptr; bool fp16 = !force_disable_f16 && fp16_storage && fp16_compute; - vk::PhysicalDeviceFeatures device_features = physical_device.getFeatures(); + vk::PhysicalDeviceProperties2 props2; + vk::PhysicalDeviceMaintenance3Properties props3; + vk::PhysicalDeviceSubgroupProperties subgroup_props; + vk::PhysicalDeviceDriverProperties driver_props; + vk::PhysicalDeviceShaderIntegerDotProductPropertiesKHR shader_integer_dot_product_props; + props2.pNext = &props3; + props3.pNext = &subgroup_props; + subgroup_props.pNext = &driver_props; + + // Pointer to the last chain element + VkBaseOutStructure * last_struct = (VkBaseOutStructure *)&driver_props; + + if (integer_dot_product) { + last_struct->pNext = (VkBaseOutStructure *)&shader_integer_dot_product_props; + last_struct = (VkBaseOutStructure *)&shader_integer_dot_product_props; + } + + physical_device.getProperties2(&props2); VkPhysicalDeviceFeatures2 
device_features2; device_features2.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2; device_features2.pNext = nullptr; - device_features2.features = (VkPhysicalDeviceFeatures)device_features; VkPhysicalDeviceVulkan11Features vk11_features; vk11_features.pNext = nullptr; @@ -2831,7 +3368,7 @@ static void ggml_vk_print_gpu_info(size_t idx) { vk11_features.pNext = &vk12_features; // Pointer to the last chain element - VkBaseOutStructure * last_struct = (VkBaseOutStructure *)&vk12_features; + last_struct = (VkBaseOutStructure *)&vk12_features; #if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) VkPhysicalDeviceCooperativeMatrixFeaturesKHR coopmat_features; @@ -2843,20 +3380,39 @@ static void ggml_vk_print_gpu_info(size_t idx) { last_struct->pNext = (VkBaseOutStructure *)&coopmat_features; last_struct = (VkBaseOutStructure *)&coopmat_features; } +#endif + + VkPhysicalDeviceShaderIntegerDotProductFeaturesKHR shader_integer_dot_product_features {}; + shader_integer_dot_product_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_INTEGER_DOT_PRODUCT_FEATURES_KHR; + if (integer_dot_product) { + last_struct->pNext = (VkBaseOutStructure *)&shader_integer_dot_product_features; + last_struct = (VkBaseOutStructure *)&shader_integer_dot_product_features; + } vkGetPhysicalDeviceFeatures2(physical_device, &device_features2); fp16 = fp16 && vk12_features.shaderFloat16; - coopmat_support = coopmat_support && coopmat_features.cooperativeMatrix; + uint32_t default_subgroup_size = get_subgroup_size("", device_architecture); + const size_t subgroup_size = (default_subgroup_size != 0) ? default_subgroup_size : subgroup_props.subgroupSize; + const bool uma = props2.properties.deviceType == vk::PhysicalDeviceType::eIntegratedGpu; + + integer_dot_product = integer_dot_product + && shader_integer_dot_product_props.integerDotProduct4x8BitPackedSignedAccelerated + && shader_integer_dot_product_features.shaderIntegerDotProduct; + + coopmat_support = coopmat_support +#if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) + && coopmat_features.cooperativeMatrix #endif + && ggml_vk_khr_cooperative_matrix_support(props2.properties, driver_props, device_architecture); std::string matrix_cores = coopmat2_support ? "NV_coopmat2" : coopmat_support ? "KHR_coopmat" : "none"; std::string device_name = props2.properties.deviceName.data(); - GGML_LOG_DEBUG("ggml_vulkan: %zu = %s (%s) | uma: %d | fp16: %d | warp size: %zu | shared memory: %d | matrix cores: %s\n", + GGML_LOG_DEBUG("ggml_vulkan: %zu = %s (%s) | uma: %d | fp16: %d | warp size: %zu | shared memory: %d | int dot: %d | matrix cores: %s\n", idx, device_name.c_str(), driver_props.driverName.data(), uma, fp16, subgroup_size, - props2.properties.limits.maxComputeSharedMemorySize, matrix_cores.c_str()); + props2.properties.limits.maxComputeSharedMemorySize, integer_dot_product, matrix_cores.c_str()); if (props2.properties.deviceType == vk::PhysicalDeviceType::eCpu) { GGML_LOG_DEBUG("ggml_vulkan: Warning: Device type is CPU. 
This is probably not the device you want.\n"); @@ -3058,6 +3614,7 @@ static void ggml_vk_init(ggml_backend_vk_context * ctx, size_t idx) { ctx->prealloc_size_split_k = 0; ctx->fence = ctx->device->device.createFence({}); + ctx->almost_ready_fence = ctx->device->device.createFence({}); #ifdef GGML_VULKAN_CHECK_RESULTS const char* skip_checks = getenv("GGML_VULKAN_SKIP_CHECKS"); @@ -3106,6 +3663,9 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_conte if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F16) { return ctx->device->pipeline_matmul_f32_f16; } + if (src0_type == GGML_TYPE_BF16 && src1_type == GGML_TYPE_BF16) { + return ctx->device->pipeline_matmul_bf16; + } if (prec == GGML_PREC_DEFAULT && ctx->device->fp16 && !(ctx->device->coopmat_support && !ctx->device->coopmat_acc_f16_support)) { if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) { return ctx->device->pipeline_matmul_f16_f32.f16acc; @@ -3122,6 +3682,17 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_conte } } + // MMQ + if (src1_type == GGML_TYPE_Q8_1) { + vk_matmul_pipeline pipelines = ctx->device->pipeline_dequant_mul_mat_mat_q8_1[src0_type].f16acc; + + if (pipelines->s == nullptr && pipelines->m == nullptr && pipelines->l == nullptr) { + return nullptr; + } + + return pipelines; + } + if (src1_type != GGML_TYPE_F32 && !ctx->device->coopmat2) { return nullptr; } @@ -3166,6 +3737,7 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context * switch (a_type) { case GGML_TYPE_F32: case GGML_TYPE_F16: + case GGML_TYPE_BF16: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: @@ -3198,6 +3770,9 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_co if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) { return ctx->device->pipeline_matmul_id_f32; } + if (src0_type == GGML_TYPE_BF16 && src1_type == GGML_TYPE_BF16) { + return ctx->device->pipeline_matmul_id_bf16; + } if (prec == GGML_PREC_DEFAULT && ctx->device->fp16 && !(ctx->device->coopmat_support && !ctx->device->coopmat_acc_f16_support)) { if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) { return ctx->device->pipeline_matmul_id_f16_f32.f16acc; @@ -3251,6 +3826,7 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec_id(ggml_backend_vk_context switch (a_type) { case GGML_TYPE_F32: case GGML_TYPE_F16: + case GGML_TYPE_BF16: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: @@ -3414,8 +3990,6 @@ static vk_submission ggml_vk_begin_submission(vk_device& device, vk_queue& q, bo return s; } - - static void ggml_vk_dispatch_pipeline(ggml_backend_vk_context* ctx, vk_context& subctx, vk_pipeline& pipeline, std::initializer_list const& descriptor_buffer_infos, size_t push_constant_size, const void* push_constants, std::array elements) { const uint32_t wg0 = CEIL_DIV(elements[0], pipeline->wg_denoms[0]); const uint32_t wg1 = CEIL_DIV(elements[1], pipeline->wg_denoms[1]); @@ -3839,20 +4413,30 @@ static uint32_t ggml_vk_guess_split_k(ggml_backend_vk_context * ctx, int m, int if (split_k == 3) { split_k = 2; } + if (ctx->device->coopmat2) { + // coopmat2 shader expects splits to be aligned to 256 + while (split_k > 1 && ((k / split_k) % 256) != 0) { + split_k /= 2; + } + } } } return split_k; } -static vk_pipeline ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, bool aligned, ggml_type src0_type) { - VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline(" << m << ", " << n << ", 
" << aligned << ", " << ggml_type_name(src0_type) << ")"); +static vk_pipeline ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, uint32_t m, uint32_t n, bool aligned, ggml_type src0_type, ggml_type src1_type) { + VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline(" << m << ", " << n << ", " << aligned << ", " << ggml_type_name(src0_type) << ", " << ggml_type_name(src1_type) << ")"); if (ctx->device->coopmat2) { - if ((ctx->device->mul_mat_l[src0_type] && (m % mmp->l->wg_denoms[0]) == 0 && (n % mmp->l->wg_denoms[1]) == 0) || (!ctx->device->mul_mat_m[src0_type] && !ctx->device->mul_mat_s[src0_type])) { + // Use large shader when the N dimension is greater than the medium shader's tile size + uint32_t crossover_large = mmp->m->wg_denoms[1]; + if ((ctx->device->mul_mat_l[src0_type] && (n > crossover_large)) || (!ctx->device->mul_mat_m[src0_type] && !ctx->device->mul_mat_s[src0_type])) { return aligned ? mmp->a_l : mmp->l; } - if ((ctx->device->mul_mat_m[src0_type] && (m % mmp->m->wg_denoms[0]) == 0 && (n % mmp->m->wg_denoms[1]) == 0) || !ctx->device->mul_mat_s[src0_type]) { + // Use medium shader when the N dimension is greater than the small shader's tile size + uint32_t crossover_medium = mmp->s->wg_denoms[1]; + if ((ctx->device->mul_mat_m[src0_type] && (n > crossover_medium)) || !ctx->device->mul_mat_s[src0_type]) { return aligned ? mmp->a_m : mmp->m; } return aligned ? mmp->a_s : mmp->s; @@ -3867,9 +4451,9 @@ static vk_pipeline ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx, return aligned ? mmp->a_l : mmp->l; } -static uint32_t ggml_vk_guess_matmul_pipeline_align(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, ggml_type src0_type) { - VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline_align(" << m << ", " << n << ", " << ggml_type_name(src0_type) << ")"); - return ggml_vk_guess_matmul_pipeline(ctx, mmp, m, n, true, src0_type)->align; +static uint32_t ggml_vk_guess_matmul_pipeline_align(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, ggml_type src0_type, ggml_type src1_type) { + VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline_align(" << m << ", " << n << ", " << ggml_type_name(src0_type) << ", " << ggml_type_name(src1_type) << ")"); + return ggml_vk_guess_matmul_pipeline(ctx, mmp, m, n, true, src0_type, src1_type)->align; } static void ggml_vk_matmul( @@ -3877,18 +4461,19 @@ static void ggml_vk_matmul( vk_subbuffer&& a, vk_subbuffer&& b, vk_subbuffer&& d, vk_subbuffer&& split_k_buffer, uint32_t m, uint32_t n, uint32_t k, uint32_t stride_a, uint32_t stride_b, uint32_t stride_d, uint32_t batch_stride_a, uint32_t batch_stride_b, uint32_t batch_stride_d, - uint32_t split_k, uint32_t batch, uint32_t ne02, uint32_t ne12, uint32_t broadcast2, uint32_t broadcast3) { - VK_LOG_DEBUG("ggml_vk_matmul(a: (" << a.buffer->buffer << ", " << a.offset << ", " << a.size << "), b: (" << b.buffer->buffer << ", " << b.offset << ", " << b.size << "), d: (" << d.buffer->buffer << ", " << d.offset << ", " << d.size << "), split_k: (" << (split_k_buffer.buffer != nullptr ? 
split_k_buffer.buffer->buffer : VK_NULL_HANDLE) << ", " << split_k_buffer.offset << ", " << split_k_buffer.size << "), m: " << m << ", n: " << n << ", k: " << k << ", stride_a: " << stride_a << ", stride_b: " << stride_b << ", stride_d: " << stride_d << ", batch_stride_a: " << batch_stride_a << ", batch_stride_b: " << batch_stride_b << ", batch_stride_d: " << batch_stride_d << ", split_k: " << split_k << ", batch: " << batch << ", ne02: " << ne02 << ", ne12: " << ne12 << ", broadcast2: " << broadcast2 << ", broadcast3: " << broadcast3 << ")"); + uint32_t split_k, uint32_t batch, uint32_t ne02, uint32_t ne12, uint32_t broadcast2, uint32_t broadcast3, + uint32_t padded_n) { + VK_LOG_DEBUG("ggml_vk_matmul(a: (" << a.buffer->buffer << ", " << a.offset << ", " << a.size << "), b: (" << b.buffer->buffer << ", " << b.offset << ", " << b.size << "), d: (" << d.buffer->buffer << ", " << d.offset << ", " << d.size << "), split_k: (" << (split_k_buffer.buffer != nullptr ? split_k_buffer.buffer->buffer : VK_NULL_HANDLE) << ", " << split_k_buffer.offset << ", " << split_k_buffer.size << "), m: " << m << ", n: " << n << ", k: " << k << ", stride_a: " << stride_a << ", stride_b: " << stride_b << ", stride_d: " << stride_d << ", batch_stride_a: " << batch_stride_a << ", batch_stride_b: " << batch_stride_b << ", batch_stride_d: " << batch_stride_d << ", split_k: " << split_k << ", batch: " << batch << ", ne02: " << ne02 << ", ne12: " << ne12 << ", broadcast2: " << broadcast2 << ", broadcast3: " << broadcast3 << ", padded_n: " << padded_n << ")"); ggml_vk_sync_buffers(subctx); if (split_k == 1) { - const vk_mat_mat_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, k, ne02, ne12, broadcast2, broadcast3 }; + const vk_mat_mat_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, k, ne02, ne12, broadcast2, broadcast3, padded_n }; ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, d }, sizeof(vk_mat_mat_push_constants), &pc, { m, n, batch }); return; } GGML_ASSERT(batch_stride_d == m * n); - const vk_mat_mat_push_constants pc1 = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, CEIL_DIV(k, split_k), ne02, ne12, broadcast2, broadcast3 }; + const vk_mat_mat_push_constants pc1 = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, CEIL_DIV(k, split_k), ne02, ne12, broadcast2, broadcast3, padded_n }; // Make sure enough workgroups get assigned for split k to work ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, split_k_buffer }, sizeof(vk_mat_mat_push_constants), &pc1, { (CEIL_DIV(m, pipeline->wg_denoms[0]) * pipeline->wg_denoms[0]) * split_k, n, batch }); ggml_vk_sync_buffers(subctx); @@ -3896,14 +4481,18 @@ static void ggml_vk_matmul( ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_matmul_split_k_reduce, { split_k_buffer, d }, pc2.size() * sizeof(uint32_t), pc2.data(), { m * n * batch, 1, 1 }); } -static vk_pipeline ggml_vk_guess_matmul_id_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, bool aligned, ggml_type src0_type) { - VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline(" << m << ", " << n << ", " << aligned << ", " << ggml_type_name(src0_type) << ")"); +static vk_pipeline ggml_vk_guess_matmul_id_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, uint32_t m, uint32_t n, bool aligned, ggml_type src0_type) { + VK_LOG_DEBUG("ggml_vk_guess_matmul_id_pipeline(" << 
m << ", " << n << ", " << aligned << ", " << ggml_type_name(src0_type) << ")"); if (ctx->device->coopmat2) { - if ((ctx->device->mul_mat_id_l[src0_type] && (m % mmp->l->wg_denoms[0]) == 0 && (n % mmp->l->wg_denoms[1]) == 0) || (!ctx->device->mul_mat_id_m[src0_type] && !ctx->device->mul_mat_id_s[src0_type])) { + // Use large shader when the N dimension is greater than the medium shader's tile size + uint32_t crossover_large = mmp->m->wg_denoms[1]; + if ((ctx->device->mul_mat_id_l[src0_type] && (n > crossover_large)) || (!ctx->device->mul_mat_id_m[src0_type] && !ctx->device->mul_mat_id_s[src0_type])) { return aligned ? mmp->a_l : mmp->l; } - if ((ctx->device->mul_mat_id_m[src0_type] && (m % mmp->m->wg_denoms[0]) == 0 && (n % mmp->m->wg_denoms[1]) == 0) || !ctx->device->mul_mat_id_s[src0_type]) { + // Use medium shader when the N dimension is greater than the small shader's tile size + uint32_t crossover_medium = mmp->s->wg_denoms[1]; + if ((ctx->device->mul_mat_id_m[src0_type] && (n > crossover_medium)) || !ctx->device->mul_mat_id_s[src0_type]) { return aligned ? mmp->a_m : mmp->m; } return aligned ? mmp->a_s : mmp->s; @@ -3928,14 +4517,15 @@ static void ggml_vk_matmul_id( vk_subbuffer&& a, vk_subbuffer&& b, vk_subbuffer&& d, vk_subbuffer&& ids, uint32_t m, uint32_t n, uint32_t k, uint32_t stride_a, uint32_t stride_b, uint32_t stride_d, uint32_t batch_stride_a, uint32_t batch_stride_b, uint32_t batch_stride_d, - uint32_t n_as, uint32_t nei0, uint32_t nei1, uint32_t nbi1, uint32_t ne11) { + uint32_t n_as, uint32_t nei0, uint32_t nei1, uint32_t nbi1, uint32_t ne11, + uint32_t padded_n) { VK_LOG_DEBUG("ggml_vk_matmul_id(a: (" << a.buffer->buffer << ", " << a.offset << ", " << a.size << "), b: (" << b.buffer->buffer << ", " << b.offset << ", " << b.size << "), d: (" << d.buffer->buffer << ", " << d.offset << ", " << d.size << "), ids: (" << ids.buffer->buffer << ", " << ids.offset << ", " << ids.size << "), " << "m: " << m << ", n: " << n << ", k: " << k << ", stride_a: " << stride_a << ", stride_b: " << stride_b << ", stride_d: " << stride_d << ", " << "batch_stride_a: " << batch_stride_a << ", batch_stride_b: " << batch_stride_b << ", batch_stride_d: " << batch_stride_d << ", " << "n_as: " << n_as << ", nei0: " << nei0 << ", nei1: " << nei1 << ", nbi1: " << nbi1 << ", ne11: " << ne11 << ")"); ggml_vk_sync_buffers(subctx); const vk_mat_mat_id_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, - nei0, nei1, nbi1, ne11 }; + nei0, nei1, nbi1, ne11, padded_n }; ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, d, ids }, sizeof(vk_mat_mat_id_push_constants), &pc, { m, nei1, n_as }); } @@ -3972,6 +4562,20 @@ static vk_pipeline ggml_vk_get_cpy_pipeline(ggml_backend_vk_context * ctx, const return ctx->device->pipeline_cpy_f16_f16; } } + if (src->type == GGML_TYPE_F16 && to == GGML_TYPE_F32) { + if (contig) { + return ctx->device->pipeline_contig_cpy_f16_f32; + } else { + return ctx->device->pipeline_cpy_f16_f32; + } + } + if (src->type == GGML_TYPE_F32 && to == GGML_TYPE_BF16) { + if (contig) { + return ctx->device->pipeline_contig_cpy_f32_bf16; + } else { + return ctx->device->pipeline_cpy_f32_bf16; + } + } if (src->type == GGML_TYPE_F32) { switch (to) { case GGML_TYPE_Q4_0: @@ -4033,6 +4637,25 @@ static void ggml_vk_cpy_to_contiguous(ggml_backend_vk_context * ctx, vk_context& ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { in, out }, sizeof(vk_op_unary_push_constants), &pc, elements); } +static vk_pipeline 
ggml_vk_get_quantize_pipeline(ggml_backend_vk_context * ctx, ggml_type type) { + switch(type) { + case GGML_TYPE_Q8_1: + return ctx->device->pipeline_quantize_q8_1; + default: + std::cerr << "Missing quantize pipeline for type: " << ggml_type_name(type) << std::endl; + GGML_ABORT("fatal error"); + } +} + +static void ggml_vk_quantize_q8_1(ggml_backend_vk_context * ctx, vk_context& subctx, vk_subbuffer&& in, vk_subbuffer&& out, uint32_t ne) { + VK_LOG_DEBUG("ggml_vk_quantize_q8_1(" << "buffer in size=" << in.buffer->size << ", buffer out size=" << out.buffer->size << ", " << ne << ")"); + + vk_pipeline pipeline = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1); + + ggml_vk_sync_buffers(subctx); + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { in, out }, sizeof(uint32_t), &ne, { ne, 1, 1 }); +} + static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { VK_LOG_DEBUG("ggml_vk_mul_mat_q_f16((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3]; std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3]; @@ -4080,56 +4703,76 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub const bool x_non_contig = (ctx->device->coopmat2 && src0->type == GGML_TYPE_F32) || !ggml_vk_dim01_contiguous(src0); const bool y_non_contig = (ctx->device->coopmat2 && src1->type == GGML_TYPE_F32) || + (src0->type == GGML_TYPE_BF16 && src1->type != GGML_TYPE_BF16) || !ggml_vk_dim01_contiguous(src1); + // If src0 is BF16, try to use a BF16 x BF16 multiply + ggml_type f16_type = src0->type == GGML_TYPE_BF16 ? GGML_TYPE_BF16 : GGML_TYPE_F16; + const bool y_f32_kernel = src1->type == GGML_TYPE_F32 && !y_non_contig; - vk_matmul_pipeline mmp = ggml_vk_get_mul_mat_mat_pipeline(ctx, src0->type, y_non_contig ? GGML_TYPE_F16 : src1->type, (ggml_prec)dst->op_params[0]); + bool quantize_y = ctx->device->integer_dot_product && src1->type == GGML_TYPE_F32 && ggml_is_contiguous(src1) && (ne11 * ne10) % 4 == 0; + + // Check for mmq first + vk_matmul_pipeline mmp = quantize_y ? ggml_vk_get_mul_mat_mat_pipeline(ctx, src0->type, GGML_TYPE_Q8_1, (ggml_prec)dst->op_params[0]) : nullptr; + + if (mmp == nullptr) { + // Fall back to f16 dequant mul mat + mmp = ggml_vk_get_mul_mat_mat_pipeline(ctx, src0->type, y_non_contig ? f16_type : src1->type, (ggml_prec)dst->op_params[0]); + quantize_y = false; + } const bool qx_needs_dequant = mmp == nullptr || x_non_contig; - const bool qy_needs_dequant = (src1->type != GGML_TYPE_F16 && !y_f32_kernel) || y_non_contig; + const bool qy_needs_dequant = !quantize_y && ((src1->type != f16_type && !y_f32_kernel) || y_non_contig); if (qx_needs_dequant) { // Fall back to dequant + f16 mulmat - mmp = ggml_vk_get_mul_mat_mat_pipeline(ctx, GGML_TYPE_F16, y_f32_kernel ? GGML_TYPE_F32 : GGML_TYPE_F16, (ggml_prec)dst->op_params[0]); + mmp = ggml_vk_get_mul_mat_mat_pipeline(ctx, f16_type, y_f32_kernel ? 
GGML_TYPE_F32 : f16_type, (ggml_prec)dst->op_params[0]); } // Not implemented GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT + const uint32_t kpad = quantize_y ? 0 : ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ctx, mmp, ne01, ne11, qx_needs_dequant ? f16_type : src0->type, quantize_y ? GGML_TYPE_Q8_1 : (y_f32_kernel ? GGML_TYPE_F32 : src1->type))); + const bool aligned = !quantize_y && ne10 == kpad && ne01 > 8 && ne11 > 8; + + vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline(ctx, mmp, ne01, ne11, aligned, qx_needs_dequant ? f16_type : src0->type, quantize_y ? GGML_TYPE_Q8_1 : (y_f32_kernel ? GGML_TYPE_F32 : src1->type)); + + // Reserve extra storage in the N dimension for the Y matrix, so we can avoid bounds-checking + uint32_t padded_n = qy_needs_dequant ? ROUNDUP_POW2(ne11, pipeline->wg_denoms[1]) : ne11; const int x_ne = ne01 * ne00; - const int y_ne = ne11 * ne10; + const int y_ne = padded_n * ne10; const int d_ne = ne11 * ne01; - const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ctx, mmp, ne01, ne11, qx_needs_dequant ? GGML_TYPE_F16 : src0->type)); - const bool aligned = ne10 == kpad && ne01 > 8 && ne11 > 8; - - vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline(ctx, mmp, ne01, ne11, aligned, qx_needs_dequant ? GGML_TYPE_F16 : src0->type); - const uint32_t split_k = ggml_vk_guess_split_k(ctx, ne01, ne11, ne10, pipeline); const uint64_t qx_sz = ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type); const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type); const uint64_t x_sz = !qx_needs_dequant ? qx_sz : sizeof(ggml_fp16_t) * x_ne; - const uint64_t y_sz = y_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne; + const uint64_t y_sz = quantize_y ? (y_ne * ggml_type_size(GGML_TYPE_Q8_1) / ggml_blck_size(GGML_TYPE_Q8_1)) : (y_f32_kernel ? 
sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne); const uint64_t d_sz = sizeof(float) * d_ne; vk_pipeline to_fp16_vk_0 = nullptr; vk_pipeline to_fp16_vk_1 = nullptr; + vk_pipeline to_q8_1 = nullptr; if (x_non_contig) { - to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0, nullptr, GGML_TYPE_F16); + to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0, nullptr, f16_type); } else { to_fp16_vk_0 = ggml_vk_get_to_fp16(ctx, src0->type); } if (y_non_contig) { - to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1, nullptr, GGML_TYPE_F16); + to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1, nullptr, f16_type); } else { to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type); } GGML_ASSERT(!qx_needs_dequant || to_fp16_vk_0 != nullptr); // NOLINT GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr); // NOLINT + if (quantize_y) { + to_q8_1 = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1); + } + if (dryrun) { const uint64_t x_sz_upd = x_sz * ne02 * ne03; const uint64_t y_sz_upd = y_sz * ne12 * ne13; @@ -4143,7 +4786,7 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub if (qx_needs_dequant && ctx->prealloc_size_x < x_sz_upd) { ctx->prealloc_size_x = x_sz_upd; } - if (qy_needs_dequant && ctx->prealloc_size_y < y_sz_upd) { + if ((qy_needs_dequant || quantize_y) && ctx->prealloc_size_y < y_sz_upd) { ctx->prealloc_size_y = y_sz_upd; } if (split_k > 1 && ctx->prealloc_size_split_k < split_k_size) { @@ -4158,6 +4801,9 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub if (qy_needs_dequant) { ggml_pipeline_request_descriptor_sets(ctx->device, to_fp16_vk_1, 1); } + if (quantize_y) { + ggml_pipeline_request_descriptor_sets(ctx->device, to_q8_1, 1); + } if (split_k > 1) { ggml_pipeline_request_descriptor_sets(ctx->device, ctx->device->pipeline_matmul_split_k_reduce, 1); } @@ -4193,6 +4839,9 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub if (qy_needs_dequant) { d_Y = ctx->prealloc_y; GGML_ASSERT(d_Y->size >= y_sz * ne12 * ne13); + } else if (quantize_y) { + d_Y = ctx->prealloc_y; + GGML_ASSERT(d_Y->size >= y_ne * ggml_type_size(GGML_TYPE_Q8_1) / ggml_blck_size(GGML_TYPE_Q8_1)); } else { d_Y = d_Qy; y_buf_offset = qy_buf_offset; @@ -4209,6 +4858,9 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub if (y_non_contig) { ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE }); } + if (quantize_y) { + ggml_vk_quantize_q8_1(ctx, subctx, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE }, y_ne * ne12 * ne13); + } uint32_t stride_batch_x = ne00*ne01; uint32_t stride_batch_y = ne10*ne11; @@ -4217,7 +4869,7 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub stride_batch_x = src0->nb[0] / ggml_type_size(src0->type); } - if (!ggml_vk_dim01_contiguous(src1) && !qy_needs_dequant) { + if (!ggml_vk_dim01_contiguous(src1) && !qy_needs_dequant && !quantize_y) { stride_batch_y = src1->nb[0] / ggml_type_size(src1->type); } @@ -4228,7 +4880,7 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub { d_D, d_buf_offset, d_sz * ne12 * ne13 }, { ctx->prealloc_split_k, 0, d_sz * ne12 * ne13 * split_k }, ne01, ne11, ne10, ne10, ne10, ne01, stride_batch_x, stride_batch_y, ne20*ne21, - split_k, ne12*ne13, ne02, ne12, r2, r3 + split_k, ne12*ne13, ne02, ne12, r2, r3, padded_n ); // NOLINT } @@ -4465,9 +5117,15 @@ static void 
ggml_vk_mul_mat_vec_p021_f16_f32(ggml_backend_vk_context * ctx, vk_c const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type); const uint64_t d_sz = sizeof(float) * d_ne; + // With grouped query attention there are > 1 Q matrices per K, V matrix. + uint32_t gqa_ratio = (uint32_t)ne12 / (uint32_t)ne02; + if (gqa_ratio > 8 || gqa_ratio == 0 || ne12 != ne02 * gqa_ratio) { + gqa_ratio = 1; + } + if (dryrun) { // Request descriptor sets - ggml_pipeline_request_descriptor_sets(ctx->device, ctx->device->pipeline_mul_mat_vec_p021_f16_f32, 1); + ggml_pipeline_request_descriptor_sets(ctx->device, ctx->device->pipeline_mul_mat_vec_p021_f16_f32[gqa_ratio - 1], 1); return; } @@ -4491,8 +5149,15 @@ static void ggml_vk_mul_mat_vec_p021_f16_f32(ggml_backend_vk_context * ctx, vk_c // compute const std::array pc = { (uint32_t)ne00, (uint32_t)ne01, (uint32_t)ne02, (uint32_t)ne12, (uint32_t)(qy_shader_offset / ggml_type_size(src1->type)), (uint32_t)(d_shader_offset / ggml_type_size(dst->type)) }; + + uint32_t workgroups_z = (uint32_t)ne12; + // When gqa_ratio > 1, each invocation does multiple rows and we can launch fewer workgroups + if (gqa_ratio > 1) { + workgroups_z /= gqa_ratio; + } + ggml_vk_sync_buffers(subctx); - ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_mul_mat_vec_p021_f16_f32, { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz }, vk_subbuffer{ d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, vk_subbuffer{ d_D, d_buffer_offset, d_sz + d_shader_offset } }, 6 * sizeof(uint32_t), &pc, { 1, (uint32_t)ne01, (uint32_t)ne12 }); + ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_mul_mat_vec_p021_f16_f32[gqa_ratio - 1], { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz }, vk_subbuffer{ d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, vk_subbuffer{ d_D, d_buffer_offset, d_sz + d_shader_offset } }, 6 * sizeof(uint32_t), &pc, { 1, (uint32_t)ne01, workgroups_z }); } static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { @@ -4514,6 +5179,8 @@ static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_con const uint64_t nb01 = src0->nb[1]; const uint64_t nb02 = src0->nb[2]; + const uint64_t nb12 = src1->nb[2]; + // const uint64_t ne10 = src1->ne[0]; const uint64_t ne11 = src1->ne[1]; const uint64_t ne12 = src1->ne[2]; @@ -4539,6 +5206,7 @@ static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_con const uint32_t row_stride_x = nb01 / sizeof(ggml_fp16_t); const uint32_t channel_stride_x = nb02 / sizeof(ggml_fp16_t); + const uint32_t channel_stride_y = nb12 / sizeof(float); const uint64_t qx_sz = ggml_nbytes(src0); const uint64_t qy_sz = ggml_nbytes(src1); @@ -4569,7 +5237,7 @@ static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_con const uint64_t d_shader_offset = d_buf_offset - d_buffer_offset; // compute - const std::array pc = { (uint32_t)ne00, (uint32_t)ne01, row_stride_x, channel_stride_x, (uint32_t)(ne12 / ne02), (uint32_t)(qy_shader_offset / ggml_type_size(src1->type)), (uint32_t)(d_shader_offset / ggml_type_size(dst->type)) }; + const std::array pc = { (uint32_t)ne00, (uint32_t)ne01, row_stride_x, channel_stride_x, channel_stride_y, (uint32_t)(ne12 / ne02), (uint32_t)ne12, (uint32_t)(qy_shader_offset / ggml_type_size(src1->type)), (uint32_t)(d_shader_offset / ggml_type_size(dst->type)) }; ggml_vk_sync_buffers(subctx); ggml_vk_dispatch_pipeline(ctx, subctx, 
ctx->device->pipeline_mul_mat_vec_nc_f16_f32, { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz }, vk_subbuffer{ d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, vk_subbuffer{ d_D, d_buffer_offset, d_sz + d_shader_offset } }, 7 * sizeof(uint32_t), &pc, { 1, (uint32_t)ne01, (uint32_t)ne12 }); @@ -4594,7 +5262,7 @@ static void ggml_vk_mul_mat(ggml_backend_vk_context * ctx, vk_context& subctx, c // mul_mat_vec supports batching ne12*ne13 when ne11==1, or treating ne11 as the batch size (up to four) // when ne12 and ne13 are one. } else if ((dst->ne[1] == 1 || (dst->ne[1] <= mul_mat_vec_max_cols && src1->ne[2] * src1->ne[3] == 1)) && - (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type))) { + (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16 || ggml_is_quantized(src0->type))) { ggml_vk_mul_mat_vec_q_f16(ctx, subctx, src0, src1, dst, dryrun); } else { ggml_vk_mul_mat_q_f16(ctx, subctx, src0, src1, dst, dryrun); @@ -4621,7 +5289,7 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context& const uint64_t nei0 = ids->ne[0]; const uint64_t nei1 = ids->ne[1]; - GGML_ASSERT(nei0 * nei1 <= 3072); + GGML_ASSERT(nei0 * nei1 <= 4096); const uint32_t nbi1 = ids->nb[1]; const uint32_t nbi2 = ids->nb[2]; @@ -4662,31 +5330,37 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context& const bool x_non_contig = (ctx->device->coopmat2 && src0->type == GGML_TYPE_F32) || !ggml_vk_dim01_contiguous(src0); const bool y_non_contig = (ctx->device->coopmat2 && src1->type == GGML_TYPE_F32) || + (src0->type == GGML_TYPE_BF16 && src1->type != GGML_TYPE_BF16) || !ggml_vk_dim01_contiguous(src1); + // If src0 is BF16, try to use a BF16 x BF16 multiply + ggml_type f16_type = src0->type == GGML_TYPE_BF16 ? GGML_TYPE_BF16 : GGML_TYPE_F16; + const bool y_f32_kernel = src1->type == GGML_TYPE_F32 && !y_non_contig; - vk_matmul_pipeline mmp = ggml_vk_get_mul_mat_mat_id_pipeline(ctx, src0->type, y_non_contig ? GGML_TYPE_F16 : src1->type, (ggml_prec)dst->op_params[0]); + vk_matmul_pipeline mmp = ggml_vk_get_mul_mat_mat_id_pipeline(ctx, src0->type, y_non_contig ? f16_type : src1->type, (ggml_prec)dst->op_params[0]); const bool qx_needs_dequant = mmp == nullptr || x_non_contig; - const bool qy_needs_dequant = (src1->type != GGML_TYPE_F16 && !y_f32_kernel) || y_non_contig; + const bool qy_needs_dequant = (src1->type != f16_type && !y_f32_kernel) || y_non_contig; if (qx_needs_dequant) { // Fall back to dequant + f16 mulmat - mmp = ggml_vk_get_mul_mat_mat_id_pipeline(ctx, GGML_TYPE_F16, y_f32_kernel ? GGML_TYPE_F32 : GGML_TYPE_F16, (ggml_prec)dst->op_params[0]); + mmp = ggml_vk_get_mul_mat_mat_id_pipeline(ctx, f16_type, y_f32_kernel ? GGML_TYPE_F32 : f16_type, (ggml_prec)dst->op_params[0]); } // Not implemented GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT - const uint64_t x_ne = ne01 * ne00; - const uint64_t y_ne = ne11 * ne10; - const uint64_t d_ne = ne21 * ne20; - - const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_id_pipeline_align(ctx, mmp, ne01, nei1, qx_needs_dequant ? GGML_TYPE_F16 : src0->type)); + const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_id_pipeline_align(ctx, mmp, ne01, nei1, qx_needs_dequant ? f16_type : src0->type)); const bool aligned = ne10 == kpad && ne01 > 8 && nei1 > 8; - vk_pipeline pipeline = ggml_vk_guess_matmul_id_pipeline(ctx, mmp, ne01, nei1, aligned, qx_needs_dequant ? 
GGML_TYPE_F16 : src0->type); + vk_pipeline pipeline = ggml_vk_guess_matmul_id_pipeline(ctx, mmp, ne01, nei1, aligned, qx_needs_dequant ? f16_type : src0->type); + + // Reserve extra storage in the N dimension for the Y matrix, so we can avoid bounds-checking + uint32_t padded_n = qy_needs_dequant ? ROUNDUP_POW2(ne11, pipeline->wg_denoms[1]) :ne11; + const uint64_t x_ne = ne01 * ne00; + const uint64_t y_ne = padded_n * ne10; + const uint64_t d_ne = ne21 * ne20; const uint64_t qx_sz = ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type); const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type); @@ -4699,12 +5373,12 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context& vk_pipeline to_fp16_vk_1 = nullptr; if (x_non_contig) { - to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0, nullptr, GGML_TYPE_F16); + to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0, nullptr, f16_type); } else { to_fp16_vk_0 = ggml_vk_get_to_fp16(ctx, src0->type); } if (y_non_contig) { - to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1, nullptr, GGML_TYPE_F16); + to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1, nullptr, f16_type); } else { to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type); } @@ -4806,7 +5480,7 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context& { d_D, d_buf_offset, d_sz * ne22 * ne23 }, { d_ids, ids_buf_offset, ids_sz }, ne01, ne21, ne10, ne10, ne10, ne01, stride_batch_x, stride_batch_y, ne20*ne21, - n_as, nei0, nei1, nbi1 / ggml_type_size(ids->type), ne11 + n_as, nei0, nei1, nbi1 / ggml_type_size(ids->type), ne11, padded_n ); // NOLINT } @@ -5034,7 +5708,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx const uint32_t nbm1 = mask ? mask->nb[1] : 0; const uint32_t D = neq0; - const uint32_t N = neq1; + uint32_t N = neq1; const uint32_t KV = nek1; GGML_ASSERT(ne0 == D); @@ -5064,20 +5738,57 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx assert(q->type == GGML_TYPE_F32); assert(k->type == v->type); + bool scalar = !ctx->device->coopmat2; + + uint32_t gqa_ratio = 1; + uint32_t qk_ratio = neq2 / nek2; + uint32_t workgroups_x = (uint32_t)neq1; + uint32_t workgroups_y = (uint32_t)neq2; + uint32_t workgroups_z = (uint32_t)neq3; + + // For scalar FA, we can use the "large" size to accommodate qga. + // For coopmat FA, we always use the small size (which is still pretty large for gqa). + const uint32_t max_gqa = scalar ? scalar_flash_attention_num_large_rows : get_fa_num_small_rows(false); + + if (N == 1 && qk_ratio > 1 && qk_ratio <= max_gqa && + qk_ratio * nek2 == neq2 && nek2 == nev2 && neq3 == 1 && nek3 == 1 && nev3 == 1) { + // grouped query attention - make the N dimension equal to gqa_ratio, reduce + // workgroups proportionally in y dimension. The shader will detect gqa_ratio > 1 + // and change addressing calculations to index Q's dimension 2. 
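+        // A concrete walk-through with assumed shapes (neq2 = 32 Q heads sharing
+        // nek2 = 8 K/V heads; all other conditions above satisfied): qk_ratio = 32 / 8 = 4,
+        // so N becomes 4 and workgroups_y shrinks from 32 to 8. Each remaining workgroup
+        // then covers the 4 Q heads that share one K/V head.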
+ gqa_ratio = qk_ratio; + N = gqa_ratio; + workgroups_y /= N; + } + vk_pipeline *pipelines; // XXX TODO other backends may be changing accumulator precision to default to f32 soon - bool f32acc = dst->op_params[3] == GGML_PREC_F32; - bool small_rows = N <= flash_attention_num_small_rows; - switch (D) { - case 64: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D64[k->type][f32acc][small_rows][0]; break; - case 80: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D80[k->type][f32acc][small_rows][0]; break; - case 96: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D96[k->type][f32acc][small_rows][0]; break; - case 112: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D112[k->type][f32acc][small_rows][0]; break; - case 128: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D128[k->type][f32acc][small_rows][0]; break; - case 256: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D256[k->type][f32acc][small_rows][0]; break; - default: - assert(!"unsupported D value"); - return; + bool f32acc = scalar || dst->op_params[3] == GGML_PREC_F32; + bool small_rows = N <= get_fa_num_small_rows(scalar); + + if (scalar) { + switch (D) { + case 64: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D64[k->type][f32acc][small_rows][0]; break; + case 80: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D80[k->type][f32acc][small_rows][0]; break; + case 96: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D96[k->type][f32acc][small_rows][0]; break; + case 112: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D112[k->type][f32acc][small_rows][0]; break; + case 128: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D128[k->type][f32acc][small_rows][0]; break; + case 256: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D256[k->type][f32acc][small_rows][0]; break; + default: + GGML_ASSERT(!"unsupported D value"); + return; + } + } else { + switch (D) { + case 64: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D64_cm2[k->type][f32acc][small_rows][0]; break; + case 80: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D80_cm2[k->type][f32acc][small_rows][0]; break; + case 96: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D96_cm2[k->type][f32acc][small_rows][0]; break; + case 112: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D112_cm2[k->type][f32acc][small_rows][0]; break; + case 128: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D128_cm2[k->type][f32acc][small_rows][0]; break; + case 256: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D256_cm2[k->type][f32acc][small_rows][0]; break; + default: + GGML_ASSERT(!"unsupported D value"); + return; + } } assert(pipelines); @@ -5089,12 +5800,47 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx // the "aligned" shader variant will forcibly align strides, for performance (q_stride & 7) == 0 && (k_stride & 7) == 0 && (v_stride & 7) == 0; + // mask dim1 is padded to 64, we rely on this to avoid clamping mask loads + GGML_ASSERT((nem1 % GGML_KQ_MASK_PAD) == 0); + vk_pipeline pipeline = pipelines[aligned]; assert(pipeline); + uint32_t split_kv = KV; + uint32_t split_k = 1; + + // Use a placeholder core count if one isn't available. split_k is a big help for perf. + const uint32_t shader_core_count = ctx->device->shader_core_count ? ctx->device->shader_core_count : 16; + + // Try to use split_k when KV is large enough to be worth the overhead + if (workgroups_x == 1 && shader_core_count > 0 && KV >= 512) { + // Try to run two workgroups per SM. 
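+        // Illustrative numbers (an assumed 32-core GPU with workgroups_y = 8): split_k
+        // starts at 32 * 2 / 8 = 8, KV is divided into 8 chunks, each chunk is rounded up
+        // to the aligned pipeline's "align" value, and split_k / workgroups_x are then
+        // recomputed below from the rounded chunk size.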
+ split_k = ctx->device->shader_core_count * 2 / workgroups_y; + if (split_k > 1) { + // Try to evenly split KV into split_k chunks, but it needs to be a multiple + // of "align", so recompute split_k based on that. + split_kv = ROUNDUP_POW2(KV / split_k, pipelines[1]->align); + split_k = CEIL_DIV(KV, split_kv); + workgroups_x = split_k; + } + } + + // Reserve space for split_k temporaries. For each split, we need to store the O matrix (D x ne1) + // and the per-row m and L values (ne1 rows). + const uint64_t split_k_size = split_k > 1 ? (D * ne1 * sizeof(float) + ne1 * sizeof(float) * 2) * split_k : 0; + if (split_k_size > ctx->device->max_memory_allocation_size) { + GGML_ABORT("Requested preallocation size is too large"); + } + if (ctx->prealloc_size_split_k < split_k_size) { + ctx->prealloc_size_split_k = split_k_size; + } + if (dryrun) { // Request descriptor sets ggml_pipeline_request_descriptor_sets(ctx->device, pipeline, 1); + if (split_k > 1) { + ggml_pipeline_request_descriptor_sets(ctx->device, ctx->device->pipeline_flash_attn_split_k_reduce, 1); + } return; } @@ -5115,8 +5861,6 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); - ggml_vk_sync_buffers(subctx); - vk_buffer d_Q = nullptr, d_K = nullptr, d_V = nullptr, d_D = nullptr, d_M = nullptr; size_t q_buf_offset = 0, k_buf_offset = 0, v_buf_offset = 0, d_buf_offset = 0, m_buf_offset = 0; @@ -5181,16 +5925,45 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx v_stride, (uint32_t)nbv2, (uint32_t)nbv3, nbm1, scale, max_bias, logit_softcap, - mask != nullptr, n_head_log2, m0, m1 }; - ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, - { - vk_subbuffer{d_Q, q_buf_offset, VK_WHOLE_SIZE}, - vk_subbuffer{d_K, k_buf_offset, VK_WHOLE_SIZE}, - vk_subbuffer{d_V, v_buf_offset, VK_WHOLE_SIZE}, - vk_subbuffer{d_M, m_buf_offset, VK_WHOLE_SIZE}, - vk_subbuffer{d_D, d_buf_offset, VK_WHOLE_SIZE}, - }, - sizeof(vk_flash_attn_push_constants), &pc, { (uint32_t)neq1, (uint32_t)neq2, (uint32_t)neq3 }); + mask != nullptr, n_head_log2, m0, m1, + gqa_ratio, split_kv, split_k }; + + ggml_vk_sync_buffers(subctx); + + if (split_k > 1) { + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, + { + vk_subbuffer{d_Q, q_buf_offset, VK_WHOLE_SIZE}, + vk_subbuffer{d_K, k_buf_offset, VK_WHOLE_SIZE}, + vk_subbuffer{d_V, v_buf_offset, VK_WHOLE_SIZE}, + vk_subbuffer{d_M, m_buf_offset, VK_WHOLE_SIZE}, + vk_subbuffer{ctx->prealloc_split_k, 0, VK_WHOLE_SIZE}, + }, + // We only use split_k when group query attention is enabled, which means + // there's no more than one tile of rows (i.e. workgroups_x would have been + // one). We reuse workgroups_x to mean the number of splits, so we need to + // cancel out the divide by wg_denoms[0]. 
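+            // As a sketch with an assumed wg_denoms[0] of 32 and split_k = 8: the x element
+            // count is 8 * 32, which ggml_vk_dispatch_pipeline divides back down to 8
+            // workgroups, one per KV chunk. The split_k_reduce dispatch below then folds the
+            // 8 partial O/m/L sets for each of the ne1 rows into the final output.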
+ sizeof(vk_flash_attn_push_constants), &pc, { workgroups_x * pipeline->wg_denoms[0], workgroups_y, workgroups_z }); + + ggml_vk_sync_buffers(subctx); + const std::array pc2 = { D, (uint32_t)ne1, split_k }; + ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_flash_attn_split_k_reduce, + { + vk_subbuffer{ctx->prealloc_split_k, 0, VK_WHOLE_SIZE}, + vk_subbuffer{d_D, d_buf_offset, VK_WHOLE_SIZE}, + }, + pc2.size() * uint32_t{sizeof(uint32_t)}, pc2.data(), { (uint32_t)ne1, 1, 1 }); + } else { + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, + { + vk_subbuffer{d_Q, q_buf_offset, VK_WHOLE_SIZE}, + vk_subbuffer{d_K, k_buf_offset, VK_WHOLE_SIZE}, + vk_subbuffer{d_V, v_buf_offset, VK_WHOLE_SIZE}, + vk_subbuffer{d_M, m_buf_offset, VK_WHOLE_SIZE}, + vk_subbuffer{d_D, d_buf_offset, VK_WHOLE_SIZE}, + }, + sizeof(vk_flash_attn_push_constants), &pc, { workgroups_x, workgroups_y, workgroups_z }); + } } static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, ggml_op op) { @@ -5210,26 +5983,37 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const } return nullptr; case GGML_OP_ADD: - if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { - return ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_add_f32_norepeat : ctx->device->pipeline_add_f32; - } - if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) { - return ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_add_f16_f32_f16_norepeat : ctx->device->pipeline_add_f16_f32_f16; - } - return nullptr; case GGML_OP_SUB: - if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { - return ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_sub_f32_norepeat : ctx->device->pipeline_sub_f32; - } - return nullptr; case GGML_OP_MUL: - if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { - return ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_mul_f32_norepeat : ctx->device->pipeline_mul_f32; - } - return nullptr; case GGML_OP_DIV: - if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { - return ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_div_f32_norepeat : ctx->device->pipeline_div_f32; + if ((src0->type != GGML_TYPE_F32 && src0->type != GGML_TYPE_F16) || + (src1->type != GGML_TYPE_F32 && src1->type != GGML_TYPE_F16) || + (dst->type != GGML_TYPE_F32 && dst->type != GGML_TYPE_F16)) { + return nullptr; + } + switch (op) { + case GGML_OP_ADD: + { + auto pipelines = ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_add_norepeat : ctx->device->pipeline_add; + return pipelines[src0->type == GGML_TYPE_F16][src1->type == GGML_TYPE_F16][dst->type == GGML_TYPE_F16]; + } + case GGML_OP_SUB: + { + auto pipelines = ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_sub_norepeat : ctx->device->pipeline_sub; + return pipelines[src0->type == GGML_TYPE_F16][src1->type == GGML_TYPE_F16][dst->type == GGML_TYPE_F16]; + } + case GGML_OP_MUL: + { + auto pipelines = ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_mul_norepeat : ctx->device->pipeline_mul; + return pipelines[src0->type == GGML_TYPE_F16][src1->type == GGML_TYPE_F16][dst->type == GGML_TYPE_F16]; + } + case GGML_OP_DIV: + { + auto pipelines = ggml_are_same_shape(src0, src1) ? 
ctx->device->pipeline_div_norepeat : ctx->device->pipeline_div; + return pipelines[src0->type == GGML_TYPE_F16][src1->type == GGML_TYPE_F16][dst->type == GGML_TYPE_F16]; + } + default: + break; } return nullptr; case GGML_OP_CONCAT: @@ -5244,7 +6028,7 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const } return nullptr; case GGML_OP_UPSCALE: - if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 && dst->op_params[0] == GGML_SCALE_MODE_NEAREST) { return ctx->device->pipeline_upscale_f32; } return nullptr; @@ -5317,38 +6101,31 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const return ctx->device->pipeline_rms_norm_back_f32; } return nullptr; + case GGML_OP_L2_NORM: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_l2_norm_f32; + } + return nullptr; case GGML_OP_UNARY: + if ((src0->type != GGML_TYPE_F32 && src0->type != GGML_TYPE_F16) || + (dst->type != GGML_TYPE_F32 && dst->type != GGML_TYPE_F16) || + (src0->type != dst->type)) { + return nullptr; + } + switch (ggml_get_unary_op(dst)) { case GGML_UNARY_OP_SILU: - if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { - return ctx->device->pipeline_silu_f32; - } - break; + return ctx->device->pipeline_silu[dst->type == GGML_TYPE_F16]; case GGML_UNARY_OP_GELU: - if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { - return ctx->device->pipeline_gelu_f32; - } - break; + return ctx->device->pipeline_gelu[dst->type == GGML_TYPE_F16]; case GGML_UNARY_OP_GELU_QUICK: - if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { - return ctx->device->pipeline_gelu_quick_f32; - } - break; + return ctx->device->pipeline_gelu_quick[dst->type == GGML_TYPE_F16]; case GGML_UNARY_OP_RELU: - if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { - return ctx->device->pipeline_relu_f32; - } - break; + return ctx->device->pipeline_relu[dst->type == GGML_TYPE_F16]; case GGML_UNARY_OP_TANH: - if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { - return ctx->device->pipeline_tanh_f32; - } - break; + return ctx->device->pipeline_tanh[dst->type == GGML_TYPE_F16]; case GGML_UNARY_OP_SIGMOID: - if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { - return ctx->device->pipeline_sigmoid_f32; - } - break; + return ctx->device->pipeline_sigmoid[dst->type == GGML_TYPE_F16]; default: break; } @@ -5456,6 +6233,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const return ctx->device->pipeline_rwkv_wkv6_f32; } return nullptr; + case GGML_OP_RWKV_WKV7: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_rwkv_wkv7_f32; + } + return nullptr; case GGML_OP_OPT_STEP_ADAMW: if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { return ctx->device->pipeline_opt_step_adamw_f32; @@ -5466,6 +6248,15 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const return ctx->device->pipeline_leaky_relu_f32; } return nullptr; + case GGML_OP_CONV_2D_DW: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + if (ggml_is_contiguous(src1)) { + return ctx->device->pipeline_conv2d_dw_whcn_f32; + } else if (ggml_is_contiguous_channels(src1)) { + return ctx->device->pipeline_conv2d_dw_cwhn_f32; + } + } + return nullptr; default: return nullptr; } @@ -5491,6 +6282,8 @@ static bool ggml_vk_op_supports_incontiguous(ggml_op op) { case 
GGML_OP_REPEAT: case GGML_OP_REPEAT_BACK: case GGML_OP_ROPE: + case GGML_OP_RMS_NORM: + case GGML_OP_CONV_2D_DW: return true; default: return false; @@ -5701,8 +6494,8 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co switch (op) { case GGML_OP_NORM: - case GGML_OP_RMS_NORM: case GGML_OP_RMS_NORM_BACK: + case GGML_OP_L2_NORM: case GGML_OP_SOFT_MAX: case GGML_OP_SOFT_MAX_BACK: case GGML_OP_SUM_ROWS: @@ -5717,6 +6510,10 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co elements = { nr, 1, 1 }; } } break; + case GGML_OP_RMS_NORM: + elements = { (uint32_t)ne01, (uint32_t)ne02, (uint32_t)ne03 }; + break; + case GGML_OP_SUM: // We use GGML_OP_SUM_ROWS with 1 row. elements = { 1, 1, 1 }; @@ -5783,6 +6580,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co case GGML_OP_CONCAT: case GGML_OP_UPSCALE: case GGML_OP_UNARY: + case GGML_OP_CONV_2D_DW: { const uint32_t ne = ggml_nelements(dst); if (ne > 262144) { @@ -5952,23 +6750,17 @@ static void ggml_vk_div(ggml_backend_vk_context * ctx, vk_context& subctx, const }, dryrun); } -static void ggml_vk_op_f32_rwkv6(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, const vk_op_rwkv_wkv6_push_constants&& pc, bool dryrun = false) { - const ggml_tensor * k = dst->src[0]; - const ggml_tensor * v = dst->src[1]; - const ggml_tensor * r = dst->src[2]; - const ggml_tensor * tf = dst->src[3]; - const ggml_tensor * td = dst->src[4]; - const ggml_tensor * state = dst->src[5]; +static void ggml_vk_op_f32_wkv(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, const vk_op_rwkv_wkv6_push_constants&& pc, int version, bool dryrun = false) { + GGML_ASSERT(version == 6 || version == 7); + int num_srcs = version == 6 ? 
6 : 7; + + for (int i = 0; i < num_srcs; i++) { + GGML_ASSERT(!ggml_is_quantized(dst->src[i]->type)); + } - GGML_ASSERT(!ggml_is_quantized(k->type)); - GGML_ASSERT(!ggml_is_quantized(v->type)); - GGML_ASSERT(!ggml_is_quantized(r->type)); - GGML_ASSERT(!ggml_is_quantized(tf->type)); - GGML_ASSERT(!ggml_is_quantized(td->type)); - GGML_ASSERT(!ggml_is_quantized(state->type)); GGML_ASSERT(dst->buffer != nullptr); - vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, k, v, r, dst, GGML_OP_RWKV_WKV6); + vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, dst->src[0], dst->src[1], dst->src[2], dst, dst->op); GGML_ASSERT(pipeline != nullptr); if (dryrun) { @@ -5977,89 +6769,73 @@ static void ggml_vk_op_f32_rwkv6(ggml_backend_vk_context * ctx, vk_context& subc } ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context; - ggml_backend_vk_buffer_context * k_buf_ctx = (ggml_backend_vk_buffer_context *)k->buffer->context; - ggml_backend_vk_buffer_context * v_buf_ctx = (ggml_backend_vk_buffer_context *)v->buffer->context; - ggml_backend_vk_buffer_context * r_buf_ctx = (ggml_backend_vk_buffer_context *)r->buffer->context; - ggml_backend_vk_buffer_context * tf_buf_ctx = (ggml_backend_vk_buffer_context *)tf->buffer->context; - ggml_backend_vk_buffer_context * td_buf_ctx = (ggml_backend_vk_buffer_context *)td->buffer->context; - ggml_backend_vk_buffer_context * state_buf_ctx = (ggml_backend_vk_buffer_context *)state->buffer->context; + ggml_backend_vk_buffer_context * src_buf_ctxs[7] = { nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr }; + for (int i = 0; i < num_srcs; i++) { + src_buf_ctxs[i] = (ggml_backend_vk_buffer_context *)dst->src[i]->buffer->context; + } ggml_vk_sync_buffers(subctx); - vk_buffer d_D = nullptr, d_K = nullptr, d_V = nullptr, d_R = nullptr, d_TF = nullptr, d_TD = nullptr, d_State = nullptr; - size_t k_offset = 0, v_offset = 0, r_offset = 0, tf_offset = 0, td_offset = 0, state_offset = 0, dst_offset = 0; - bool K_uma = false, V_uma = false, R_uma = false, TF_uma = false, TD_uma = false, STATE_uma = false, DST_uma = false; + vk_buffer d_D = nullptr, d_srcs[7] = { nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr }; + size_t dst_offset = 0, src_offsets[7] = { 0, 0, 0, 0, 0, 0, 0 }; + bool dst_uma = false, srcs_uma[7] = { false, false, false, false, false, false, false }; if (ctx->device->uma) { - ggml_vk_host_get(ctx->device, k->data, d_K, k_offset); - ggml_vk_host_get(ctx->device, v->data, d_V, v_offset); - ggml_vk_host_get(ctx->device, r->data, d_R, r_offset); - ggml_vk_host_get(ctx->device, tf->data, d_TF, tf_offset); - ggml_vk_host_get(ctx->device, td->data, d_TD, td_offset); - ggml_vk_host_get(ctx->device, state->data, d_State, state_offset); + for (int i = 0; i < num_srcs; i++) { + ggml_vk_host_get(ctx->device, dst->src[i]->data, d_srcs[i], src_offsets[i]); + srcs_uma[i] = d_srcs[i] != nullptr; + } + ggml_vk_host_get(ctx->device, dst->data, d_D, dst_offset); - - K_uma = d_K != nullptr; - V_uma = d_V != nullptr; - R_uma = d_R != nullptr; - TF_uma = d_TF != nullptr; - TD_uma = d_TD != nullptr; - STATE_uma = d_State != nullptr; - DST_uma = d_D != nullptr; + dst_uma = d_D != nullptr; } - if (!K_uma) { - d_K = k_buf_ctx->dev_buffer; - k_offset = vk_tensor_offset(k) + k->view_offs; + uint64_t src_sizes[7] = { 0, 0, 0, 0, 0, 0, 0 }; + for (int i = 0; i < num_srcs; i++) { + src_sizes[i] = ggml_nbytes(dst->src[i]); + if (!srcs_uma[i]) { + d_srcs[i] = src_buf_ctxs[i]->dev_buffer; + src_offsets[i] = 
vk_tensor_offset(dst->src[i]) + dst->src[i]->view_offs; + } } - if (!V_uma) { - d_V = v_buf_ctx->dev_buffer; - v_offset = vk_tensor_offset(v) + v->view_offs; - } - if (!R_uma) { - d_R = r_buf_ctx->dev_buffer; - r_offset = vk_tensor_offset(r) + r->view_offs; - } - if (!TF_uma) { - d_TF = tf_buf_ctx->dev_buffer; - tf_offset = vk_tensor_offset(tf) + tf->view_offs; - } - if (!TD_uma) { - d_TD = td_buf_ctx->dev_buffer; - td_offset = vk_tensor_offset(td) + td->view_offs; - } - if (!STATE_uma) { - d_State = state_buf_ctx->dev_buffer; - state_offset = vk_tensor_offset(state) + state->view_offs; - } - if (!DST_uma) { + + const uint64_t dst_size = ggml_nbytes(dst); + if (!dst_uma) { d_D = dst_buf_ctx->dev_buffer; dst_offset = vk_tensor_offset(dst) + dst->view_offs; } - const uint64_t k_size = ggml_nbytes(k); - const uint64_t v_size = ggml_nbytes(v); - const uint64_t r_size = ggml_nbytes(r); - const uint64_t tf_size = ggml_nbytes(tf); - const uint64_t td_size = ggml_nbytes(td); - const uint64_t state_size = ggml_nbytes(state); - const uint64_t dst_size = ggml_nbytes(dst); - std::array elements = { (uint32_t)(pc.B * pc.H), 1, 1 }; - ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { - vk_subbuffer{ d_K, k_offset, k_size }, - vk_subbuffer{ d_V, v_offset, v_size }, - vk_subbuffer{ d_R, r_offset, r_size }, - vk_subbuffer{ d_TF, tf_offset, tf_size }, - vk_subbuffer{ d_TD, td_offset, td_size }, - vk_subbuffer{ d_State, state_offset, state_size }, - vk_subbuffer{ d_D, dst_offset, dst_size } - }, sizeof(vk_op_rwkv_wkv6_push_constants), &pc, elements); + if (version == 6) { + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { + vk_subbuffer{ d_srcs[0], src_offsets[0], src_sizes[0] }, + vk_subbuffer{ d_srcs[1], src_offsets[1], src_sizes[1] }, + vk_subbuffer{ d_srcs[2], src_offsets[2], src_sizes[2] }, + vk_subbuffer{ d_srcs[3], src_offsets[3], src_sizes[3] }, + vk_subbuffer{ d_srcs[4], src_offsets[4], src_sizes[4] }, + vk_subbuffer{ d_srcs[5], src_offsets[5], src_sizes[5] }, + vk_subbuffer{ d_D, dst_offset, dst_size } + }, sizeof(vk_op_rwkv_wkv6_push_constants), &pc, elements); + } else if (version == 7) { + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { + vk_subbuffer{ d_srcs[0], src_offsets[0], src_sizes[0] }, + vk_subbuffer{ d_srcs[1], src_offsets[1], src_sizes[1] }, + vk_subbuffer{ d_srcs[2], src_offsets[2], src_sizes[2] }, + vk_subbuffer{ d_srcs[3], src_offsets[3], src_sizes[3] }, + vk_subbuffer{ d_srcs[4], src_offsets[4], src_sizes[4] }, + vk_subbuffer{ d_srcs[5], src_offsets[5], src_sizes[5] }, + vk_subbuffer{ d_srcs[6], src_offsets[6], src_sizes[6] }, + vk_subbuffer{ d_D, dst_offset, dst_size } + }, sizeof(vk_op_rwkv_wkv7_push_constants), &pc, elements); + } else { + // shouldn't happen + GGML_ASSERT(false); + } } static void ggml_vk_rwkv_wkv6(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, bool dryrun = false) { @@ -6068,7 +6844,7 @@ static void ggml_vk_rwkv_wkv6(ggml_backend_vk_context * ctx, vk_context& subctx, const size_t n_heads = dst->src[0]->ne[1]; const size_t n_seqs = dst->src[5]->ne[1]; - ggml_vk_op_f32_rwkv6( + ggml_vk_op_f32_wkv( ctx, subctx, dst, { (uint32_t)n_seqs, @@ -6076,6 +6852,26 @@ static void ggml_vk_rwkv_wkv6(ggml_backend_vk_context * ctx, vk_context& subctx, (uint32_t)n_embed, (uint32_t)n_heads, }, + 6, + dryrun + ); +} + +static void ggml_vk_rwkv_wkv7(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, bool dryrun = false) { + const size_t seq_length = dst->src[0]->ne[2]; + const size_t n_embed = dst->ne[0]; + const size_t n_heads = 
dst->src[0]->ne[1]; + const size_t n_seqs = dst->src[6]->ne[1]; + + ggml_vk_op_f32_wkv( + ctx, subctx, dst, + { + (uint32_t)n_seqs, + (uint32_t)seq_length, + (uint32_t)n_embed, + (uint32_t)n_heads, + }, + 7, dryrun ); } @@ -6369,7 +7165,17 @@ static void ggml_vk_group_norm(ggml_backend_vk_context * ctx, vk_context& subctx static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { float * op_params = (float *)dst->op_params; - ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_RMS_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f }, dryrun); + const uint32_t src0_type_size = ggml_type_size(src0->type); + const uint32_t dst_type_size = ggml_type_size(dst->type); + + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_RMS_NORM, { + (uint32_t)ggml_nelements(src0), + (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, + (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, + 0, + op_params[0], 0.0f, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + }, dryrun); } static void ggml_vk_rms_norm_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { @@ -6377,6 +7183,11 @@ static void ggml_vk_rms_norm_back(ggml_backend_vk_context * ctx, vk_context& sub ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_RMS_NORM_BACK, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f }, dryrun); } +static void ggml_vk_l2_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { + float * op_params = (float *)dst->op_params; + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_L2_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f }, dryrun); +} + static void ggml_vk_unary(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_UNARY, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun); } @@ -6556,6 +7367,30 @@ static void ggml_vk_pool_2d(ggml_backend_vk_context * ctx, vk_context& subctx, c }, dryrun); } +static void ggml_vk_conv_2d_dw(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { + vk_op_conv2d_dw_push_constants p{}; + p.ne = ggml_nelements(dst); + p.channels = dst->ne[2]; + p.batches = dst->ne[3]; + p.dst_w = dst->ne[0]; + p.dst_h = dst->ne[1]; + p.src_w = src1->ne[0]; + p.src_h = src1->ne[1]; + p.knl_w = src0->ne[0]; + p.knl_h = src0->ne[1]; + p.stride_x = dst->op_params[0]; + p.stride_y = dst->op_params[1]; + p.pad_x = dst->op_params[2]; + p.pad_y = dst->op_params[3]; + p.dilation_x = dst->op_params[4]; + p.dilation_y = dst->op_params[5]; + + GGML_ASSERT(src0->ne[3] == p.channels); + GGML_ASSERT(src1->ne[3] == p.batches); + + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_CONV_2D_DW, std::move(p), dryrun); +} + static void 
ggml_vk_leaky_relu(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { const float * op_params = (const float *)dst->op_params; ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_LEAKY_RELU, { (uint32_t)ggml_nelements(src0), 0, op_params[0], 0.0f }, dryrun); @@ -6717,6 +7552,10 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t } } + if (ctx->device->need_compiles) { + ggml_vk_load_shaders(ctx->device); + } + ggml_pipeline_allocate_descriptor_sets(ctx->device); vk_buffer d_X = ggml_vk_create_buffer_check(ctx->device, sizeof(X_TYPE) * x_ne, vk::MemoryPropertyFlagBits::eDeviceLocal); @@ -6766,7 +7605,7 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t ctx, subctx, p, ggml_vk_subbuffer(d_X), ggml_vk_subbuffer(d_Y), ggml_vk_subbuffer(d_D), ggml_vk_subbuffer(ctx->prealloc_split_k), m, n, k, k, k, m, k*m, k*n, m*n, - split_k, batch, batch, batch, 1, 1 + split_k, batch, batch, batch, 1, 1, n ); } ggml_vk_ctx_end(subctx); @@ -6965,6 +7804,10 @@ static void ggml_vk_test_dequant(ggml_backend_vk_context * ctx, size_t ne, ggml_ ggml_pipeline_request_descriptor_sets(ctx->device, p, 1); + if (ctx->device->need_compiles) { + ggml_vk_load_shaders(ctx->device); + } + ggml_pipeline_allocate_descriptor_sets(ctx->device); ggml_vk_buffer_write(qx_buf, 0, qx, qx_sz); @@ -7024,66 +7867,198 @@ static void ggml_vk_test_dequant(ggml_backend_vk_context * ctx, size_t ne, ggml_ free(x_chk); } -static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m, size_t n, size_t k, size_t batch, size_t num_it, size_t split_k, size_t shader_size, ggml_type quant) { +// This does not work without ggml q8_1 quantization support +// +// typedef uint16_t ggml_half; +// typedef uint32_t ggml_half2; +// +// #define QK8_1 32 +// typedef struct { +// union { +// struct { +// ggml_half d; // delta +// ggml_half s; // d * sum(qs[i]) +// } GGML_COMMON_AGGR_S; +// ggml_half2 ds; +// } GGML_COMMON_AGGR_U; +// int8_t qs[QK8_1]; // quants +// } block_q8_1; +// +// static void ggml_vk_test_quantize(ggml_backend_vk_context * ctx, size_t ne, ggml_type quant) { +// VK_LOG_DEBUG("ggml_vk_test_quantize(" << ne << ")"); +// GGML_ASSERT(quant == GGML_TYPE_Q8_1); +// +// const size_t x_sz = sizeof(float) * ne; +// const size_t qx_sz = ne * ggml_type_size(quant)/ggml_blck_size(quant); +// float * x = (float *) malloc(x_sz); +// block_q8_1 * qx = (block_q8_1 *)malloc(qx_sz); +// block_q8_1 * qx_res = (block_q8_1 *)malloc(qx_sz); +// vk_buffer x_buf = ggml_vk_create_buffer_check(ctx->device, x_sz, vk::MemoryPropertyFlagBits::eDeviceLocal); +// vk_buffer qx_buf = ggml_vk_create_buffer_check(ctx->device, qx_sz, vk::MemoryPropertyFlagBits::eDeviceLocal); +// +// for (size_t i = 0; i < ne; i++) { +// x[i] = rand() / (float)RAND_MAX; +// } +// +// vk_pipeline p = ggml_vk_get_quantize_pipeline(ctx, quant); +// +// ggml_pipeline_request_descriptor_sets(ctx->device, p, 1); +// +// if (ctx->device->need_compiles) { +// ggml_vk_load_shaders(ctx->device); +// } +// +// ggml_pipeline_allocate_descriptor_sets(ctx->device); +// +// ggml_vk_buffer_write(x_buf, 0, x, x_sz); +// +// vk_context subctx = ggml_vk_create_context(ctx, ctx->device->compute_queue); +// ggml_vk_ctx_begin(ctx->device, subctx); +// ggml_vk_quantize_q8_1(ctx, subctx, ggml_vk_subbuffer(x_buf), ggml_vk_subbuffer(qx_buf), ne); +// ggml_vk_ctx_end(subctx); +// +// auto begin = std::chrono::high_resolution_clock::now(); +// +// 
ggml_vk_submit(subctx, ctx->fence); +// VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_vk_test_quantize waitForFences"); +// ctx->device->device.resetFences({ ctx->fence }); +// +// auto end = std::chrono::high_resolution_clock::now(); +// +// double ms_quant = std::chrono::duration_cast(end-begin).count() / 1000.0; +// ggml_vk_buffer_read(qx_buf, 0, qx, qx_sz); +// +// ggml_vk_quantize_data(x, qx_res, ne, quant); +// +// int first_err = -1; +// +// for (size_t i = 0; i < ne / 32; i++) { +// double error = std::fabs(ggml_fp16_to_fp32(qx_res[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d) - ggml_fp16_to_fp32(qx[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d)); +// +// if (first_err < 0 && error > 0.1) { +// first_err = i; +// } +// +// error = std::fabs(ggml_fp16_to_fp32(qx_res[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.s) - ggml_fp16_to_fp32(qx[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.s)); +// +// if (first_err < 0 && error > 0.1) { +// first_err = i; +// } +// +// for (size_t j = 0; j < 32; j++) { +// uint64_t error = std::abs(qx_res[i].qs[j] - qx[i].qs[j]); +// +// if (first_err < 0 && error > 1) { +// first_err = i; +// } +// } +// } +// +// std::cerr << "TEST QUANTIZE " << ggml_type_name(quant) << " time=" << ms_quant << "ms " << (first_err == -1 ? "CORRECT" : "INCORRECT") << std::endl; +// +// if (first_err != -1) { +// std::cerr << "first_error = " << first_err << std::endl; +// std::cerr << "Actual result: " << std::endl << std::endl; +// std::cout << "d=" << ggml_fp16_to_fp32(qx[first_err].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d) << " s=" << ggml_fp16_to_fp32(qx[first_err].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.s) << " "; +// for (size_t j = 0; j < 32; j++) { +// std::cout << " qs" << j << "=" << (uint32_t)qx[first_err].qs[j] << " "; +// } +// std::cerr << std::endl << std::endl << "Expected result: " << std::endl << std::endl; +// std::cout << "d=" << ggml_fp16_to_fp32(qx_res[first_err].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d) << " s=" << ggml_fp16_to_fp32(qx_res[first_err].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.s) << " "; +// for (size_t j = 0; j < 32; j++) { +// std::cout << " qs" << j << "=" << (uint32_t)qx_res[first_err].qs[j] << " "; +// } +// std::cerr << std::endl; +// } +// +// ggml_vk_destroy_buffer(x_buf); +// ggml_vk_destroy_buffer(qx_buf); +// +// free(x); +// free(qx); +// free(qx_res); +// } + +static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m, size_t n, size_t k, size_t batch, size_t num_it, size_t split_k, size_t shader_size, ggml_type quant, bool mmq = false) { VK_LOG_DEBUG("ggml_vk_test_dequant_matmul(" << m << ", " << n << ", " << k << ", " << batch << ", " << num_it << ", " << split_k << ", " << ggml_type_name(quant) << ")"); const size_t x_ne = m * k * batch; const size_t y_ne = k * n * batch; const size_t d_ne = m * n * batch; + vk_matmul_pipeline2 * pipelines; + + if (mmq) { + pipelines = ctx->device->pipeline_dequant_mul_mat_mat_q8_1; + } else { + pipelines = ctx->device->pipeline_dequant_mul_mat_mat; + } + + const bool fp16acc = ctx->device->fp16; + vk_pipeline p; std::string shname; if (shader_size == 0) { - p = ctx->device->fp16 ? ctx->device->pipeline_dequant_mul_mat_mat[quant].f16acc->a_s : ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->a_s; + p = fp16acc ? pipelines[quant].f16acc->a_s : pipelines[quant].f32acc->a_s; shname = std::string(ggml_type_name(quant)) + "_ALIGNED_S"; } else if (shader_size == 1) { - p = ctx->device->fp16 ? 
ctx->device->pipeline_dequant_mul_mat_mat[quant].f16acc->a_m : ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->a_m; + p = fp16acc ? pipelines[quant].f16acc->a_m : pipelines[quant].f32acc->a_m; shname = std::string(ggml_type_name(quant)) + "_ALIGNED_M"; } else if (shader_size == 2) { - p = ctx->device->fp16 ? ctx->device->pipeline_dequant_mul_mat_mat[quant].f16acc->a_l : ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->a_l; + p = fp16acc ? pipelines[quant].f16acc->a_l : pipelines[quant].f32acc->a_l; shname = std::string(ggml_type_name(quant)) + "_ALIGNED_L"; } else { GGML_ASSERT(0); } - const size_t kpad = ggml_vk_align_size(k, p->align); + const size_t kpad = mmq ? 0 : ggml_vk_align_size(k, p->align); - if (k != kpad) { + if (mmq || k != kpad) { if (shader_size == 0) { - p = ctx->device->fp16 ? ctx->device->pipeline_dequant_mul_mat_mat[quant].f16acc->s : ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->s; + p = fp16acc ? pipelines[quant].f16acc->s : pipelines[quant].f32acc->s; shname = std::string(ggml_type_name(quant)) + "_S"; } else if (shader_size == 1) { - p = ctx->device->fp16 ? ctx->device->pipeline_dequant_mul_mat_mat[quant].f16acc->m : ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->m; + p = fp16acc ? pipelines[quant].f16acc->m : pipelines[quant].f32acc->m; shname = std::string(ggml_type_name(quant)) + "_M"; } else if (shader_size == 2) { - p = ctx->device->fp16 ? ctx->device->pipeline_dequant_mul_mat_mat[quant].f16acc->l : ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->l; + p = fp16acc ? pipelines[quant].f16acc->l : pipelines[quant].f32acc->l; shname = std::string(ggml_type_name(quant)) + "_L"; } else { GGML_ASSERT(0); } } + if (p == nullptr) { + std::cerr << "error: no pipeline for ggml_vk_test_dequant_matmul " << ggml_type_name(quant) << std::endl; + return; + } + const size_t x_sz = sizeof(float) * x_ne; const size_t y_sz = sizeof(float) * y_ne; const size_t qx_sz = x_ne * ggml_type_size(quant)/ggml_blck_size(quant); + const size_t qy_sz = mmq ? y_ne * ggml_type_size(GGML_TYPE_Q8_1)/ggml_blck_size(GGML_TYPE_Q8_1) : y_sz; const size_t d_sz = sizeof(float) * d_ne; float * x = (float *) malloc(x_sz); float * y = (float *) malloc(y_sz); void * qx = malloc(qx_sz); vk_buffer qx_buf = ggml_vk_create_buffer_check(ctx->device, qx_sz, vk::MemoryPropertyFlagBits::eDeviceLocal); vk_buffer y_buf = ggml_vk_create_buffer_check(ctx->device, y_sz, vk::MemoryPropertyFlagBits::eDeviceLocal); + vk_buffer qy_buf = ggml_vk_create_buffer_check(ctx->device, qy_sz, vk::MemoryPropertyFlagBits::eDeviceLocal); vk_buffer d_buf = ggml_vk_create_buffer_check(ctx->device, d_sz, vk::MemoryPropertyFlagBits::eDeviceLocal); float * d = (float *) malloc(d_sz); float * d_chk = (float *) malloc(d_sz); for (size_t i = 0; i < x_ne; i++) { x[i] = (rand() / (float)RAND_MAX) * 2.0f - 1.0f; + // x[i] = (i % k == i / k) ? 1.0f : 0.0f; + // x[i] = i % k; } ggml_vk_quantize_data(x, qx, x_ne, quant); for (size_t i = 0; i < y_ne; i++) { - // y[i] = rand() / (float)RAND_MAX; - y[i] = (i % k == i / k) ? 1.0f : 0.0f; + y[i] = (rand() / (float)RAND_MAX) * 2.0f - 1.0f; + // y[i] = (i % k == i / k) ? 
1.0f : 0.0f; + // y[i] = i % k; } ggml_pipeline_request_descriptor_sets(ctx->device, p, num_it); @@ -7098,6 +8073,13 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m, ctx->prealloc_split_k = ggml_vk_create_buffer_check(ctx->device, sizeof(float) * d_ne * split_k, vk::MemoryPropertyFlagBits::eDeviceLocal); } } + if (mmq) { + ggml_pipeline_request_descriptor_sets(ctx->device, ctx->device->pipeline_quantize_q8_1, num_it); + } + + if (ctx->device->need_compiles) { + ggml_vk_load_shaders(ctx->device); + } ggml_pipeline_allocate_descriptor_sets(ctx->device); @@ -7106,13 +8088,25 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m, vk_context subctx = ggml_vk_create_context(ctx, ctx->device->compute_queue); ggml_vk_ctx_begin(ctx->device, subctx); - for (size_t i = 0; i < num_it; i++) { - ggml_vk_matmul( - ctx, subctx, p, ggml_vk_subbuffer(qx_buf), ggml_vk_subbuffer(y_buf), ggml_vk_subbuffer(d_buf), ggml_vk_subbuffer(ctx->prealloc_split_k), - m, n, k, - k, k, m, k*m, k*n, m*n, - split_k, batch, batch, batch, 1, 1 - ); + if (mmq) { + for (size_t i = 0; i < num_it; i++) { + ggml_vk_quantize_q8_1(ctx, subctx, { y_buf, 0, y_sz }, { qy_buf, 0, qy_sz }, y_ne); + ggml_vk_matmul( + ctx, subctx, p, { qx_buf, 0, qx_sz }, { qy_buf, 0, qy_sz }, { d_buf, 0, d_sz }, { ctx->prealloc_split_k, 0, ctx->prealloc_size_split_k }, + m, n, k, + k, k, m, k*m, k*n, m*n, + split_k, batch, batch, batch, 1, 1, n + ); + } + } else { + for (size_t i = 0; i < num_it; i++) { + ggml_vk_matmul( + ctx, subctx, p, { qx_buf, 0, qx_sz }, { y_buf, 0, y_sz }, { d_buf, 0, d_sz }, { ctx->prealloc_split_k, 0, ctx->prealloc_size_split_k }, + m, n, k, + k, k, m, k*m, k*n, m*n, + split_k, batch, batch, batch, 1, 1, n + ); + } } ggml_vk_ctx_end(subctx); @@ -7170,7 +8164,11 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m, double tflops = 2.0*m*n*k*batch*num_it / (time_ms / 1000.0) / (1000.0*1000.0*1000.0*1000.0); - std::cerr << "TEST MMQ " << shname << " m=" << m << " n=" << n << " k=" << k << " batch=" << batch << " split_k=" << split_k << " matmul " << time_ms / num_it << "ms " << tflops << " TFLOPS avg_err=" << avg_err << std::endl; + std::cerr << "TEST dequant matmul " << shname; + if (mmq) { + std::cerr << " mmq"; + } + std::cerr << " m=" << m << " n=" << n << " k=" << k << " batch=" << batch << " split_k=" << split_k << " matmul " << time_ms / num_it << "ms " << tflops << " TFLOPS avg_err=" << avg_err << std::endl; if (avg_err > 0.01 || std::isnan(avg_err)) { std::cerr << "m = " << first_err_m << " n = " << first_err_n << " b = " << first_err_b << std::endl; @@ -7180,6 +8178,12 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m, std::cerr << "Expected result: " << std::endl << std::endl; ggml_vk_print_matrix_area(d_chk, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b); + std::cerr << "src0: " << std::endl << std::endl; + ggml_vk_print_matrix_area(x, GGML_TYPE_F32, k, m, first_err_m, first_err_n, first_err_b); + std::cerr << std::endl; + std::cerr << "src1: " << std::endl << std::endl; + ggml_vk_print_matrix_area(y, GGML_TYPE_F32, k, n, first_err_m, first_err_n, first_err_b); + if (split_k > 1) { float * split_k_buf = (float *) malloc(sizeof(float) * d_ne * split_k); ggml_vk_buffer_read(ctx->prealloc_split_k, 0, split_k_buf, sizeof(float) * d_ne * split_k); @@ -7202,6 +8206,7 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m, ggml_vk_destroy_buffer(qx_buf); 
ggml_vk_destroy_buffer(y_buf); + ggml_vk_destroy_buffer(qy_buf); ggml_vk_destroy_buffer(d_buf); free(x); @@ -7236,6 +8241,24 @@ static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx) { }; const size_t num_it = 100; + ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 0, GGML_TYPE_Q4_0); + ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 1, GGML_TYPE_Q4_0); + ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 2, GGML_TYPE_Q4_0); + + ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 0, GGML_TYPE_Q4_0, true); + ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 1, GGML_TYPE_Q4_0, true); + ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 2, GGML_TYPE_Q4_0, true); + + ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 0, GGML_TYPE_Q8_0); + ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 1, GGML_TYPE_Q8_0); + ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 2, GGML_TYPE_Q8_0); + + ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 0, GGML_TYPE_Q8_0, true); + ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 1, GGML_TYPE_Q8_0, true); + ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 2, GGML_TYPE_Q8_0, true); + + abort(); + for (size_t i = 0; i < vals.size(); i += 3) { ggml_vk_test_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 0); ggml_vk_test_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 1); @@ -7310,11 +8333,11 @@ static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx) { } } -static bool ggml_vk_compute_forward(ggml_backend_vk_context* ctx, ggml_tensor* tensor, int tensor_idx, bool use_fence); +static bool ggml_vk_compute_forward(ggml_backend_vk_context* ctx, ggml_tensor* tensor, int tensor_idx, bool use_fence, bool almost_ready); // Returns true if node has enqueued work into the queue, false otherwise // If submit is true the current all operations queued so far are being submitted to Vulkan to overlap cmdlist creation and GPU execution. 
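// Note on the almost_ready parameter introduced below: the graph-compute loop sets it
// once fewer than roughly 20% of the graph's nodes remain. When such a batch is flushed
// without an explicit fence, ggml_vk_compute_forward submits it with
// ctx->almost_ready_fence and marks ctx->almost_ready_fence_pending, rather than
// submitting with no fence at all.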
-static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * node, int node_idx, ggml_tensor *node_begin, int node_idx_begin, bool dryrun, bool last_node, bool submit){ +static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * node, int node_idx, ggml_tensor *node_begin, int node_idx_begin, bool dryrun, bool last_node, bool almost_ready, bool submit){ if (ggml_is_empty(node) || !node->buffer) { return false; } @@ -7372,6 +8395,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod case GGML_OP_GROUP_NORM: case GGML_OP_RMS_NORM: case GGML_OP_RMS_NORM_BACK: + case GGML_OP_L2_NORM: case GGML_OP_DIAG_MASK_INF: case GGML_OP_SOFT_MAX: case GGML_OP_SOFT_MAX_BACK: @@ -7387,7 +8411,9 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod case GGML_OP_IM2COL: case GGML_OP_TIMESTEP_EMBEDDING: case GGML_OP_POOL_2D: + case GGML_OP_CONV_2D_DW: case GGML_OP_RWKV_WKV6: + case GGML_OP_RWKV_WKV7: case GGML_OP_LEAKY_RELU: case GGML_OP_FLASH_ATTN_EXT: case GGML_OP_OPT_STEP_ADAMW: @@ -7434,6 +8460,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod case GGML_OP_GROUP_NORM: case GGML_OP_RMS_NORM: case GGML_OP_RMS_NORM_BACK: + case GGML_OP_L2_NORM: case GGML_OP_UNARY: case GGML_OP_DIAG_MASK_INF: case GGML_OP_SOFT_MAX: @@ -7448,6 +8475,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod case GGML_OP_IM2COL: case GGML_OP_TIMESTEP_EMBEDDING: case GGML_OP_POOL_2D: + case GGML_OP_CONV_2D_DW: case GGML_OP_LEAKY_RELU: { // These operations all go through ggml_vk_op_f32, so short-circuit and @@ -7551,6 +8579,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod case GGML_OP_RMS_NORM_BACK: ggml_vk_rms_norm_back(ctx, compute_ctx, src0, src1, node, dryrun); + break; + case GGML_OP_L2_NORM: + ggml_vk_l2_norm(ctx, compute_ctx, src0, node, dryrun); + break; case GGML_OP_UNARY: switch (ggml_get_unary_op(node)) { @@ -7617,6 +8649,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod case GGML_OP_POOL_2D: ggml_vk_pool_2d(ctx, compute_ctx, src0, node, dryrun); + break; + case GGML_OP_CONV_2D_DW: + ggml_vk_conv_2d_dw(ctx, compute_ctx, src0, src1, node, dryrun); + break; case GGML_OP_LEAKY_RELU: ggml_vk_leaky_relu(ctx, compute_ctx, src0, node, dryrun); @@ -7641,6 +8677,11 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod break; + case GGML_OP_RWKV_WKV7: + ggml_vk_rwkv_wkv7(ctx, compute_ctx, node, dryrun); + + break; + case GGML_OP_OPT_STEP_ADAMW: ggml_vk_opt_step_adamw(ctx, compute_ctx, node, dryrun); @@ -7674,7 +8715,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod ctx->compute_ctx.reset(); - bool ok = ggml_vk_compute_forward(ctx, node_begin, node_idx_begin, false); + bool ok = ggml_vk_compute_forward(ctx, node_begin, node_idx_begin, false, almost_ready); if (!ok) { if (node->op == GGML_OP_UNARY) { std::cerr << __func__ << ": error: op not supported UNARY " << node->name << " (" << ggml_unary_op_name(static_cast(node->op_params[0])) << ")" << std::endl; @@ -7688,7 +8729,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod return true; } -static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor * tensor, int tensor_idx, bool use_fence = true){ +static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor * tensor, int tensor_idx, bool use_fence = true, bool 
almost_ready = false) { ggml_backend_buffer * buf = nullptr; switch (tensor->op) { @@ -7714,6 +8755,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor * case GGML_OP_GROUP_NORM: case GGML_OP_RMS_NORM: case GGML_OP_RMS_NORM_BACK: + case GGML_OP_L2_NORM: case GGML_OP_DIAG_MASK_INF: case GGML_OP_SOFT_MAX: case GGML_OP_SOFT_MAX_BACK: @@ -7732,7 +8774,9 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor * case GGML_OP_IM2COL: case GGML_OP_TIMESTEP_EMBEDDING: case GGML_OP_POOL_2D: + case GGML_OP_CONV_2D_DW: case GGML_OP_RWKV_WKV6: + case GGML_OP_RWKV_WKV7: case GGML_OP_LEAKY_RELU: case GGML_OP_REPEAT: case GGML_OP_REPEAT_BACK: @@ -7789,12 +8833,15 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor * memcpy(cpy.dst, cpy.src, cpy.n); } - ggml_vk_submit(subctx, use_fence ? ctx->fence : vk::Fence{}); + if (almost_ready && !ctx->almost_ready_fence_pending && !use_fence) { + ggml_vk_submit(subctx, ctx->almost_ready_fence); + ctx->almost_ready_fence_pending = true; + } else { + ggml_vk_submit(subctx, use_fence ? ctx->fence : vk::Fence{}); + } if (use_fence) { - VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_vk_compute_forward waitForFences"); - - ctx->device->device.resetFences({ ctx->fence }); + ggml_vk_wait_for_fence(ctx); } #ifdef GGML_VULKAN_CHECK_RESULTS ggml_vk_check_results_1(tensor); @@ -7880,6 +8927,7 @@ static void ggml_vk_cleanup(ggml_backend_vk_context * ctx) { ctx->gc.events.clear(); ctx->device->device.destroyFence(ctx->fence); + ctx->device->device.destroyFence(ctx->almost_ready_fence); } static int ggml_vk_get_device_count() { @@ -7922,11 +8970,12 @@ static void * ggml_backend_vk_buffer_get_base(ggml_backend_buffer_t buffer) { UNUSED(buffer); } -static void ggml_backend_vk_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) { +static enum ggml_status ggml_backend_vk_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) { VK_LOG_DEBUG("ggml_backend_vk_buffer_init_tensor(" << buffer << " (" << buffer->context << "), " << tensor << ")"); if (tensor->view_src != nullptr) { GGML_ASSERT(tensor->view_src->buffer->buft == buffer->buft); } + return GGML_STATUS_SUCCESS; } static void ggml_backend_vk_buffer_memset_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) { @@ -8225,8 +9274,7 @@ static void ggml_backend_vk_synchronize(ggml_backend_t backend) { } ggml_vk_submit(transfer_ctx, ctx->fence); - VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_backend_vk_synchronize waitForFences"); - ctx->device->device.resetFences({ ctx->fence }); + ggml_vk_wait_for_fence(ctx); for (auto& cpy : transfer_ctx->out_memcpys) { memcpy(cpy.dst, cpy.src, cpy.n); @@ -8243,8 +9291,12 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg VK_LOG_DEBUG("ggml_backend_vk_graph_compute(" << cgraph->n_nodes << " nodes)"); ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context; + uint64_t total_mat_mul_bytes = 0; for (int i = 0; i < cgraph->n_nodes; i++) { - ggml_vk_build_graph(ctx, cgraph->nodes[i], i, nullptr, 0, true, false, false); + ggml_vk_build_graph(ctx, cgraph->nodes[i], i, nullptr, 0, true, false, false, false); + if (cgraph->nodes[i]->op == GGML_OP_MUL_MAT || cgraph->nodes[i]->op == GGML_OP_MUL_MAT_ID) { + total_mat_mul_bytes += ggml_nbytes(cgraph->nodes[i]->src[0]); + } } if (ctx->device->need_compiles) { 
ggml_vk_load_shaders(ctx->device); @@ -8265,19 +9317,32 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg bool first_node_in_batch = true; // true if next node will be first node in a batch int submit_node_idx = 0; // index to first node in a batch - // Submit work every nodes_per_submit nodes to overlap CPU cmdbuffer generation with GPU execution. - // Start with a smaller count to get work submitted right away, and increase it after each submit. - int nodes_per_submit = 20; + // Submit after enough work has accumulated, to overlap CPU cmdbuffer generation with GPU execution. + // Estimate the amount of matmul work by looking at the weight matrix size, and submit every 100MB + // (and scaled down based on model size, so smaller models submit earlier). + // Also submit at least every 100 nodes, in case there are workloads without as much matmul. + int nodes_per_submit = 100; int submitted_nodes = 0; int submit_count = 0; + uint64_t mul_mat_bytes = 0; + uint64_t mul_mat_bytes_per_submit = std::min(uint64_t(100*1000*1000), total_mat_mul_bytes / 40u); for (int i = 0; i < cgraph->n_nodes; i++) { if (first_node_in_batch) { submit_node_idx = i; } - bool submit = (submitted_nodes >= nodes_per_submit) || (i == last_node); + if (cgraph->nodes[i]->op == GGML_OP_MUL_MAT || cgraph->nodes[i]->op == GGML_OP_MUL_MAT_ID) { + mul_mat_bytes += ggml_nbytes(cgraph->nodes[i]->src[0]); + } - bool enqueued = ggml_vk_build_graph(ctx, cgraph->nodes[i], i, cgraph->nodes[submit_node_idx], submit_node_idx, false, i == last_node, submit); + // Signal the almost_ready fence when the graph is mostly complete (< 20% remaining) + bool almost_ready = (cgraph->n_nodes - i) < cgraph->n_nodes / 5; + bool submit = (submitted_nodes >= nodes_per_submit) || + (mul_mat_bytes >= mul_mat_bytes_per_submit) || + (i == last_node) || + (almost_ready && !ctx->almost_ready_fence_pending); + + bool enqueued = ggml_vk_build_graph(ctx, cgraph->nodes[i], i, cgraph->nodes[submit_node_idx], submit_node_idx, false, i == last_node, almost_ready, submit); if (enqueued) { ++submitted_nodes; @@ -8289,16 +9354,12 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg #endif } - if (submit) { + if (submit && enqueued) { first_node_in_batch = true; submitted_nodes = 0; - switch (submit_count) { - case 0: - nodes_per_submit = 50; - break; - default: - nodes_per_submit = 100; - break; + mul_mat_bytes = 0; + if (submit_count < 3) { + mul_mat_bytes_per_submit *= 2; } submit_count++; } @@ -8450,7 +9511,10 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_UNARY_OP_RELU: case GGML_UNARY_OP_TANH: case GGML_UNARY_OP_SIGMOID: - return ggml_is_contiguous(op->src[0]); + return ggml_is_contiguous(op->src[0]) && + (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) && + (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && + (op->src[0]->type == op->type); default: return false; } @@ -8468,6 +9532,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm switch (src0_type) { case GGML_TYPE_F32: case GGML_TYPE_F16: + case GGML_TYPE_BF16: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: @@ -8503,19 +9568,23 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm if (a->ne[3] != b->ne[3]) { return false; } - if (!(ggml_vk_dim01_contiguous(op->src[0]) || op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) || + if 
(!(ggml_vk_dim01_contiguous(op->src[0]) || op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_BF16) || !(ggml_vk_dim01_contiguous(op->src[1]) || op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_F16)) { return false; } + if (op->src[0]->type == GGML_TYPE_BF16 && op->src[1]->type == GGML_TYPE_F16) { + // We currently don't have a bf16 x f16 shader, or an fp16->bf16 copy shader. + // So don't support this combination for now. + return false; + } return true; } break; case GGML_OP_FLASH_ATTN_EXT: { ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; - if (!ggml_vk_get_device(ctx->device)->coopmat2) { - return false; - } + auto device = ggml_vk_get_device(ctx->device); + bool coopmat2 = device->coopmat2; switch (op->src[0]->ne[0]) { case 64: case 80: @@ -8527,6 +9596,10 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm default: return false; } + if (op->src[1]->ne[0] != op->src[2]->ne[0]) { + // different head sizes of K and V are not supported yet + return false; + } if (op->src[0]->type != GGML_TYPE_F32) { return false; } @@ -8544,10 +9617,12 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm switch (op->src[1]->type) { case GGML_TYPE_F16: case GGML_TYPE_Q4_0: + case GGML_TYPE_Q8_0: + // supported in scalar and coopmat2 paths + break; case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: case GGML_TYPE_Q5_1: - case GGML_TYPE_Q8_0: // K dequants currently disabled because D dimension is rounded up to 256 and runs inefficiently //case GGML_TYPE_Q2_K: //case GGML_TYPE_Q3_K: @@ -8563,10 +9638,18 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm //case GGML_TYPE_IQ3_S: //case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_NL: + // currently supported only in coopmat2 path + if (!coopmat2) { + return false; + } break; default: return false; } + if (!coopmat2 && !device->subgroup_shuffle) { + // scalar FA uses subgroupShuffle + return false; + } return true; } case GGML_OP_GET_ROWS: @@ -8574,6 +9657,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm switch (op->src[0]->type) { case GGML_TYPE_F32: case GGML_TYPE_F16: + case GGML_TYPE_BF16: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: @@ -8604,6 +9688,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm switch (src1_type) { case GGML_TYPE_F32: case GGML_TYPE_F16: + case GGML_TYPE_BF16: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: @@ -8617,6 +9702,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm } if (src1_type == GGML_TYPE_F32) { switch (src0_type) { + case GGML_TYPE_F16: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: @@ -8645,25 +9731,31 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_OP_VIEW: case GGML_OP_PERMUTE: case GGML_OP_TRANSPOSE: + case GGML_OP_RMS_NORM: return true; case GGML_OP_NORM: case GGML_OP_GROUP_NORM: - case GGML_OP_RMS_NORM: + case GGML_OP_L2_NORM: return ggml_is_contiguous(op->src[0]); case GGML_OP_ADD: - case GGML_OP_ACC: case GGML_OP_SUB: case GGML_OP_MUL: case GGML_OP_DIV: - case GGML_OP_CONCAT: + return (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) && + (op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_F16) && + (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16); case GGML_OP_SILU_BACK: case 
GGML_OP_RMS_NORM_BACK: - case GGML_OP_UPSCALE: - case GGML_OP_SCALE: case GGML_OP_SQR: case GGML_OP_SIN: case GGML_OP_COS: case GGML_OP_CLAMP: + return op->src[0]->type == GGML_TYPE_F32; + case GGML_OP_UPSCALE: + return op->op_params[0] == GGML_SCALE_MODE_NEAREST; + case GGML_OP_ACC: + case GGML_OP_CONCAT: + case GGML_OP_SCALE: case GGML_OP_PAD: case GGML_OP_DIAG_MASK_INF: case GGML_OP_SOFT_MAX: @@ -8675,8 +9767,10 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_OP_COUNT_EQUAL: case GGML_OP_IM2COL: case GGML_OP_TIMESTEP_EMBEDDING: + case GGML_OP_CONV_2D_DW: case GGML_OP_POOL_2D: case GGML_OP_RWKV_WKV6: + case GGML_OP_RWKV_WKV7: case GGML_OP_LEAKY_RELU: case GGML_OP_OPT_STEP_ADAMW: return true; @@ -8823,7 +9917,7 @@ static bool ggml_vk_instance_portability_enumeration_ext_available(const std::ve UNUSED(instance_extensions); } -static bool ggml_vk_khr_cooperative_matrix_support(const vk::PhysicalDeviceProperties& props, const vk::PhysicalDeviceDriverProperties& driver_props) { +static bool ggml_vk_khr_cooperative_matrix_support(const vk::PhysicalDeviceProperties& props, const vk::PhysicalDeviceDriverProperties& driver_props, vk_device_architecture arch) { switch (props.vendorID) { case VK_VENDOR_ID_INTEL: // Intel drivers don't support coopmat properly yet @@ -8831,10 +9925,7 @@ static bool ggml_vk_khr_cooperative_matrix_support(const vk::PhysicalDevicePrope case VK_VENDOR_ID_AMD: if (driver_props.driverID == vk::DriverId::eAmdProprietary || driver_props.driverID == vk::DriverId::eAmdOpenSource) { // Workaround for AMD proprietary driver reporting support on all GPUs - const std::string name = props.deviceName; - return name.rfind("AMD Radeon RX 7", 0) == 0 || name.rfind("AMD Radeon(TM) RX 7", 0) == 0 || // RDNA 3 consumer GPUs - name.rfind("AMD Radeon PRO W7", 0) == 0 || name.rfind("AMD Radeon(TM) PRO W7", 0) == 0 || // RDNA 3 workstation GPUs - name.rfind("AMD Radeon 7", 0) == 0 || name.rfind("AMD Radeon(TM) 7", 0) == 0; // RDNA 3 APUs + return arch == vk_device_architecture::AMD_RDNA3; } return true; default: @@ -9017,7 +10108,7 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) { } if (tensor->op == GGML_OP_FLASH_ATTN_EXT) { - const float *params = (const float *)tensor->op_params; + const float * params = (const float *)tensor->op_params; tensor_clone = ggml_flash_attn_ext(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], src_clone[3], params[0], params[1], params[2]); } else if (tensor->op == GGML_OP_MUL_MAT) { tensor_clone = ggml_mul_mat(ggml_ctx, src_clone[0], src_clone[1]); @@ -9032,9 +10123,10 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) { } else if (tensor->op == GGML_OP_CONCAT) { tensor_clone = ggml_concat(ggml_ctx, src_clone[0], src_clone[1], *(int *)tensor->op_params); } else if (tensor->op == GGML_OP_UPSCALE) { - tensor_clone = ggml_upscale_ext(ggml_ctx, src_clone[0], tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]); + tensor_clone = ggml_upscale_ext(ggml_ctx, src_clone[0], tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3], tensor->op_params[0], tensor->op_params[1], (ggml_scale_mode) tensor->op_params[0]); } else if (tensor->op == GGML_OP_SCALE) { - tensor_clone = ggml_scale(ggml_ctx, src_clone[0], ((float *)tensor->op_params)[0]); + const float * params = (const float *)tensor->op_params; + tensor_clone = ggml_scale(ggml_ctx, src_clone[0], params[0]); } else if (tensor->op == GGML_OP_SQR) { tensor_clone = ggml_sqr(ggml_ctx, src_clone[0]); } else if (tensor->op == GGML_OP_SIN) { @@ 
-9042,7 +10134,8 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) { } else if (tensor->op == GGML_OP_COS) { tensor_clone = ggml_cos(ggml_ctx, src_clone[0]); } else if (tensor->op == GGML_OP_CLAMP) { - tensor_clone = ggml_clamp(ggml_ctx, src_clone[0], ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]); + const float * params = (const float *)tensor->op_params; + tensor_clone = ggml_clamp(ggml_ctx, src_clone[0], params[0], params[1]); } else if (tensor->op == GGML_OP_PAD) { tensor_clone = ggml_pad(ggml_ctx, src_clone[0], tensor->ne[0] - src_clone[0]->ne[0], tensor->ne[1] - src_clone[0]->ne[1], tensor->ne[2] - src_clone[0]->ne[2], tensor->ne[3] - src_clone[0]->ne[3]); } else if (tensor->op == GGML_OP_REPEAT) { @@ -9056,7 +10149,8 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) { } else if (tensor->op == GGML_OP_NORM) { tensor_clone = ggml_norm(ggml_ctx, src_clone[0], *(float *)tensor->op_params); } else if (tensor->op == GGML_OP_GROUP_NORM) { - tensor_clone = ggml_group_norm(ggml_ctx, src_clone[0], *(int *)tensor->op_params, ((float *)tensor->op_params)[1]); + const float * float_params = (const float *)tensor->op_params; + tensor_clone = ggml_group_norm(ggml_ctx, src_clone[0], tensor->op_params[0], float_params[1]); } else if (tensor->op == GGML_OP_RMS_NORM) { tensor_clone = ggml_rms_norm(ggml_ctx, src_clone[0], *(float *)tensor->op_params); } else if (tensor->op == GGML_OP_RMS_NORM_BACK) { @@ -9064,16 +10158,20 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) { tensor_clone = ggml_rms_norm_back(ggml_ctx, src_clone[0], src_clone[1], eps); } else if (tensor->op == GGML_OP_SILU_BACK) { tensor_clone = ggml_silu_back(ggml_ctx, src_clone[0], src_clone[1]); + } else if (tensor->op == GGML_OP_L2_NORM) { + const float eps = ((float *) tensor->op_params)[0]; + tensor_clone = ggml_l2_norm(ggml_ctx, src_clone[0], eps); } else if (tensor->op == GGML_OP_SOFT_MAX) { if (src1 != nullptr) { - tensor_clone = ggml_soft_max_ext(ggml_ctx, src_clone[0], src_clone[1], ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]); + const float * params = (const float *)tensor->op_params; + tensor_clone = ggml_soft_max_ext(ggml_ctx, src_clone[0], src_clone[1], params[0], params[1]); } else { tensor_clone = ggml_soft_max(ggml_ctx, src_clone[0]); } } else if (tensor->op == GGML_OP_SOFT_MAX_BACK) { tensor_clone = ggml_soft_max_ext_back(ggml_ctx, src_clone[0], src_clone[1], ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]); } else if (tensor->op == GGML_OP_DIAG_MASK_INF) { - tensor_clone = ggml_diag_mask_inf(ggml_ctx, src_clone[0], *(int *)tensor->op_params); + tensor_clone = ggml_diag_mask_inf(ggml_ctx, src_clone[0], tensor->op_params[0]); } else if (tensor->op == GGML_OP_ROPE || tensor->op == GGML_OP_ROPE_BACK) { const int n_dims = ((int32_t *) tensor->op_params)[1]; const int mode = ((int32_t *) tensor->op_params)[2]; @@ -9183,6 +10281,9 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) { } else if (tensor->op == GGML_OP_RWKV_WKV6) { tensor_clone = ggml_rwkv_wkv6(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], src_clone[3], src_clone[4], src_clone[5]); + } else if (tensor->op == GGML_OP_RWKV_WKV7) { + tensor_clone = ggml_rwkv_wkv7(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], src_clone[3], + src_clone[4], src_clone[5], src_clone[6]); } else if (tensor->op == GGML_OP_OPT_STEP_ADAMW) { src_clone[0]->flags = src0->flags; tensor_clone = ggml_opt_step_adamw(ggml_ctx, src_clone[0], src_clone[1], diff --git 
a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt index 074031087..ad13f69b3 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt @@ -1,9 +1,20 @@ -find_package (Threads REQUIRED) -find_program(GLSLC_EXECUTABLE glslc) -if(NOT GLSLC_EXECUTABLE) - message(FATAL_ERROR "glslc not found.") -endif() +cmake_minimum_required(VERSION 3.19) +project("vulkan-shaders-gen" C CXX) +find_package (Threads REQUIRED) + +if (GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) + add_compile_definitions(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) +endif() +if (GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) + add_compile_definitions(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) +endif() +if (GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) + add_compile_definitions(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) +endif() +if (GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT) + add_compile_definitions(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT) +endif() set(TARGET vulkan-shaders-gen) add_executable(${TARGET} vulkan-shaders-gen.cpp) install(TARGETS ${TARGET} RUNTIME) diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp index dd828c232..6567a8c54 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp @@ -18,7 +18,11 @@ void main() { // fast path for when all four iterations are in-bounds if (idx + (num_iter-1)*num_threads < p.ne) { [[unroll]] for (uint i = 0; i < num_iter; ++i) { -#ifndef OPTIMIZATION_ERROR_WORKAROUND + +#if defined(DATA_D_BF16) + float f = float(data_a[get_aoffset() + idx]); + data_d[get_doffset() + idx] = D_TYPE(fp32_to_bf16(f)); +#elif !defined(OPTIMIZATION_ERROR_WORKAROUND) data_d[get_doffset() + idx] = D_TYPE(data_a[get_aoffset() + idx]); #else data_d[get_doffset() + idx] = data_a[get_aoffset() + idx]; @@ -31,7 +35,10 @@ void main() { continue; } -#ifndef OPTIMIZATION_ERROR_WORKAROUND +#if defined(DATA_D_BF16) + float f = float(data_a[get_aoffset() + idx]); + data_d[get_doffset() + idx] = D_TYPE(fp32_to_bf16(f)); +#elif !defined(OPTIMIZATION_ERROR_WORKAROUND) data_d[get_doffset() + idx] = D_TYPE(data_a[get_aoffset() + idx]); #else data_d[get_doffset() + idx] = data_a[get_aoffset() + idx]; diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp new file mode 100644 index 000000000..938c74da5 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp @@ -0,0 +1,105 @@ +#version 450 + +#include "types.comp" + +layout (push_constant) uniform parameter +{ + uint ne; + uint batches; + uint channels; + uint dst_w; + uint dst_h; + uint src_w; + uint src_h; + uint knl_w; + uint knl_h; + int stride_x; + int stride_y; + int pad_x; + int pad_y; + int dilation_x; + int dilation_y; +} p; + +layout (binding = 0) readonly buffer A {A_TYPE knl_data[];}; +layout (binding = 1) readonly buffer B {B_TYPE src_data[];}; +layout (binding = 2) writeonly buffer D {D_TYPE dst_data[];}; + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +FLOAT_TYPE conv_2d_dw_whcn(uint idx) { + uint i0 = idx / p.dst_w; + uint dst_x = idx - i0 * p.dst_w; + uint i1 = i0 / p.dst_h; + uint dst_y = i0 - i1 * p.dst_h; + uint n = i1 / p.channels; + uint c = i1 - n * p.channels; + + uint src_i = n * 
p.channels * p.src_h * p.src_w + c * p.src_h * p.src_w; + uint knl_i = c * p.knl_h * p.knl_w; + + FLOAT_TYPE sum = 0.0; + for (uint knl_y = 0; knl_y < p.knl_h; ++knl_y) { + uint src_y = dst_y * p.stride_y + knl_y * p.dilation_y - p.pad_y; + if (src_y >= p.src_h) { // src_y < 0 will wrap to a large unsigned int + continue; + } + for (uint knl_x = 0; knl_x < p.knl_w; ++knl_x) { + uint src_x = dst_x * p.stride_x + knl_x * p.dilation_x - p.pad_x; + if (src_x >= p.src_w) { // src_x < 0 will wrap to a large unsigned int + continue; + } + FLOAT_TYPE v = FLOAT_TYPE(src_data[src_i + src_y * p.src_w + src_x]); + FLOAT_TYPE k = FLOAT_TYPE(knl_data[knl_i + knl_y * p.knl_w + knl_x]); + sum = fma(v, k, sum); + } + } + return sum; +} + +FLOAT_TYPE conv_2d_dw_cwhn(uint idx) { + uint i0 = idx / p.channels; + uint c = idx - i0 * p.channels; + uint i1 = i0 / p.dst_w; + uint dst_x = i0 - i1 * p.dst_w; + uint n = i1 / p.dst_h; + uint dst_y = i1 - n * p.dst_h; + + uint src_i = n * p.channels * p.src_h * p.src_w; + uint src_row = p.src_w * p.channels; + uint knl_row = p.knl_w * p.channels; + + FLOAT_TYPE sum = 0.0; + for (uint knl_y = 0; knl_y < p.knl_h; ++knl_y) { + uint src_y = dst_y * p.stride_y + knl_y * p.dilation_y - p.pad_y; + if (src_y >= p.src_h) { // src_y < 0 will wrap to a large unsigned int + continue; + } + for (uint knl_x = 0; knl_x < p.knl_w; ++knl_x) { + uint src_x = dst_x * p.stride_x + knl_x * p.dilation_x - p.pad_x; + if (src_x >= p.src_w) { // src_x < 0 will wrap to a large unsigned int + continue; + } + FLOAT_TYPE v = FLOAT_TYPE(src_data[src_i + src_y * src_row + src_x * p.channels + c]); + FLOAT_TYPE k = FLOAT_TYPE(knl_data[ knl_y * knl_row + knl_x * p.channels + c]); + sum = fma(v, k, sum); + } + } + return sum; +} + +void main() { + uint idx = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; + if (idx >= p.ne) { + return; + } + + FLOAT_TYPE result = +#ifdef WHCN + conv_2d_dw_whcn(idx); +#else + conv_2d_dw_cwhn(idx); +#endif + dst_data[idx] = D_TYPE(result); +} + diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp index 29c906494..f476a2e3d 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp @@ -12,7 +12,10 @@ void main() { return; } -#ifndef OPTIMIZATION_ERROR_WORKAROUND +#if defined(DATA_D_BF16) + float f = float(data_a[get_aoffset() + src0_idx(idx)]); + data_d[get_doffset() + dst_idx(idx)] = D_TYPE(fp32_to_bf16(f)); +#elif !defined(OPTIMIZATION_ERROR_WORKAROUND) data_d[get_doffset() + dst_idx(idx)] = D_TYPE(data_a[get_aoffset() + src0_idx(idx)]); #else data_d[get_doffset() + dst_idx(idx)] = data_a[get_aoffset() + src0_idx(idx)]; diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp index c813f1404..9c76437d9 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp @@ -1,5 +1,10 @@ #version 450 +#if RTE16 +#extension GL_EXT_spirv_intrinsics : enable +spirv_execution_mode(capabilities = [4467], 4462, 16); // RoundingModeRTE, 16 bits +#endif // RTE16 + #include "types.comp" #include "generic_unary_head.comp" diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp index 
10318e876..0d9739d40 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp @@ -23,6 +23,12 @@ vec2 dequantize(uint ib, uint iqs, uint a_offset) { } #endif +#if defined(DATA_A_BF16) +vec2 dequantize(uint ib, uint iqs, uint a_offset) { + return vec2(bf16_to_fp32(data_a[a_offset + ib]), bf16_to_fp32(data_a[a_offset + ib + 1])); +} +#endif + #if defined(DATA_A_Q4_0) vec2 dequantize(uint ib, uint iqs, uint a_offset) { const uint vui = uint(data_a[a_offset + ib].qs[iqs]); @@ -82,9 +88,9 @@ vec2 dequantize(uint ib, uint iqs, uint a_offset) { return vec2(int(data_a[a_offset + ib].qs[iqs]), int(data_a[a_offset + ib].qs[iqs + 1])); } vec4 dequantize4(uint ib, uint iqs, uint a_offset) { - uint32_t v0 = data_a_packed16[a_offset + ib].qs[iqs/2]; - uint32_t v1 = data_a_packed16[a_offset + ib].qs[iqs/2 + 1]; - return vec4(int8_t(v0 & 0xFF), int8_t(v0 >> 8), int8_t(v1 & 0xFF), int8_t(v1 >> 8)); + const i8vec2 v0 = unpack8(int32_t(data_a_packed16[a_offset + ib].qs[iqs/2])).xy; // vec4 used due to #12147 + const i8vec2 v1 = unpack8(int32_t(data_a_packed16[a_offset + ib].qs[iqs/2 + 1])).xy; + return vec4(v0.x, v0.y, v1.x, v1.y); } #endif @@ -428,7 +434,7 @@ vec4 dequantize4(uint ib, uint iqs, uint a_offset) { } #endif -#if defined(DATA_A_F32) || defined(DATA_A_F16) +#if defined(DATA_A_F32) || defined(DATA_A_F16) || defined(DATA_A_BF16) vec2 get_dm(uint ib, uint a_offset) { return vec2(0, 0); } diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp index 4770469ed..9cb7da2da 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp @@ -92,7 +92,7 @@ float16_t dequantFuncQ8_0(const in decodeBufQ8_0 bl, const in uint blockCoords[2 const uint iqs = idx; // Load 16b and select the byte for this element - int32_t qs = unpack8(int32_t(bl.block.qs[(iqs & 0x1E) >> 1]))[iqs & 1]; + int32_t qs = unpack8(bl.block.qs[(iqs & 0x1E) >> 1])[iqs & 1]; float16_t ret = float16_t(qs) * d; return ret; } @@ -167,6 +167,101 @@ layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ4 block_q4_K_packed128 block; }; +#if defined(IS_MUL_MM2) + +// For Q4_K and Q5_K in the mat-mul shader, we decode a tile's worth of scales +// into shared memory and then process the whole tile using those scales. +// There is a fetch function that loads into private variables and then a store +// function that stores into shared memory. +// Q4_K and Q5_K have the same encoding of scales, so everything is shared except +// the part that fetches from the structure (which has a different block layout). 
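// Illustrative call pattern (an assumption about the cooperative-matrix mat-mul shader,
// which is not part of this hunk): each tile iteration is expected to run the two halves
// back to back, with the per-element dequant function then reading the decoded scales
// from shared memory instead of re-deriving them:
//
//     fetch_scales(ir * BM, pos_a, stride_a, block_k, tid, block_k_is_in_bounds); // threads load packed scale words into row_v
//     store_scales(tid);                                                          // barrier, decode (d, m) pairs into shAscales, barrier
//     // ... coopmat loads call dequantFuncA, which indexes shAscales[is * shAscales_stride + row] ...
//
// fetch_scales/store_scales are the per-type macros defined later in this file;
// `ir`, `pos_a`, `stride_a`, `block_k` and `block_k_is_in_bounds` stand in for whatever
// the actual caller passes.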
+#if defined(DATA_A_Q4_K) || defined(DATA_A_Q5_K) +const uint shAscales_stride = (BM + 2); +// 1 scale per 32 elements -> 8 scales per block, per row +shared vec2 shAscales[8 * shAscales_stride]; +uvec4 row_v; +#endif + +#if defined(DATA_A_Q4_K) +layout (binding = 0) readonly buffer A_Q4_K_128 {block_q4_K_packed128 data_a_q4_k_packed128[];}; + +void fetch_scalesQ4_K(uint ir_BM, uint pos_a, uint stride_a, uint block_k, uint tid, bool in_bounds) +{ + uint tids_per_row = BLOCK_SIZE / BM; + uint is_per_tid = 8 / tids_per_row; + uint is_start = is_per_tid * (tid % tids_per_row); + uint tid_row = tid / tids_per_row; + + uint row = ir_BM + tid_row; + uint block_index = pos_a + row * stride_a + (block_k / QUANT_K); + if (in_bounds || row < p.M) { + row_v = data_a_q4_k_packed128[block_index].q4k[0]; + } +} +#endif +#if defined(DATA_A_Q5_K) +layout (binding = 0) readonly buffer A_Q5_K_128 {block_q5_K_packed128 data_a_q5_k_packed128[];}; + +void fetch_scalesQ5_K(uint ir_BM, uint pos_a, uint stride_a, uint block_k, uint tid, bool in_bounds) +{ + uint tids_per_row = BLOCK_SIZE / BM; + uint is_per_tid = 8 / tids_per_row; + uint is_start = is_per_tid * (tid % tids_per_row); + uint tid_row = tid / tids_per_row; + + uint row = ir_BM + tid_row; + uint block_index = pos_a + row * stride_a + (block_k / QUANT_K); + if (in_bounds || row < p.M) { + row_v = data_a_q5_k_packed128[block_index].q5k[0]; + } +} +#endif + +#if defined(DATA_A_Q4_K) || defined(DATA_A_Q5_K) +void store_scalesQ4_K(uint tid) +{ + barrier(); + + uint tids_per_row = BLOCK_SIZE / BM; + uint is_per_tid = 8 / tids_per_row; + uint is_start = is_per_tid * (tid % tids_per_row); + uint tid_row = tid / tids_per_row; + + [[unroll]] for (uint idx = 0; idx < is_per_tid; ++idx) { + uint is = idx + is_start; + uvec4 v = row_v; + const vec2 loadd = vec2(unpackFloat2x16(v.x)); + + uint32_t sc; + uint32_t mbyte; + + uint32_t scale0 = v.y; + uint32_t scale4 = v.z; + uint32_t scale8 = v.w; + + uint32_t sc_lo = scale0; + uint32_t mb_lo = scale4; + uint32_t sc_hi = (scale8 & 0x0F0F0F0F) | ((scale0 & 0xC0C0C0C0) >> 2); + uint32_t mb_hi = ((scale8 & 0xF0F0F0F0) >> 4) | ((scale4 & 0xC0C0C0C0) >> 2); + + sc = is < 4 ? sc_lo : sc_hi; + mbyte = is < 4 ? 
mb_lo : mb_hi; + sc = sc >> (8 * (is & 3)); + mbyte = mbyte >> (8 * (is & 3)); + sc &= 0x3F; + mbyte &= 0x3F; + + const float d = loadd.x * float(sc); + const float m = loadd.y * float(mbyte); + shAscales[is * shAscales_stride + tid_row] = vec2(d,m); + } + + barrier(); +} +#endif + +#endif + float16_t dequantFuncQ4_K(const in decodeBufQ4_K bl, const in uint blockCoords[2], const in uint coordInBlock[2]) { decodeBufQ4_K_packed16 bl16 = decodeBufQ4_K_packed16(bl); @@ -176,9 +271,13 @@ float16_t dequantFuncQ4_K(const in decodeBufQ4_K bl, const in uint blockCoords[2 const uint b = (idx & 0x20) >> 5; // 0,1 const uint is = (idx & 0xE0) >> 5; // 0..7 +#if defined(IS_MUL_MM2) && defined(DATA_A_Q4_K) + vec2 v = shAscales[is * shAscales_stride + (blockCoords[0] % BM)]; + float d = v.x; + float m = v.y; +#else uvec4 v = bl128.block.q4k[0]; - - const f16vec2 loadd = unpackFloat2x16(v.x); + const vec2 loadd = vec2(unpackFloat2x16(v.x)); uint32_t sc; uint32_t mbyte; @@ -199,15 +298,16 @@ float16_t dequantFuncQ4_K(const in decodeBufQ4_K bl, const in uint blockCoords[2 sc &= 0x3F; mbyte &= 0x3F; - const float16_t d = loadd.x * float16_t(sc); - const float16_t m = loadd.y * float16_t(mbyte); + const float d = loadd.x * float(sc); + const float m = loadd.y * float(mbyte); +#endif uint qs = uint32_t(bl16.block.qs[((idx & 0xC0) >> 2) + ((idx & 0x1E) >> 1)]); qs = (qs >> (b * 4 + 8 * (idx & 1))) & 0xF; - float16_t ret = d * float16_t(qs) - m; + float ret = d * float(qs) - m; - return ret; + return float16_t(ret); } layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ5_K { @@ -231,6 +331,11 @@ float16_t dequantFuncQ5_K(const in decodeBufQ5_K bl, const in uint blockCoords[2 const uint b = (idx & 0x20) >> 5; // 0,1 const uint is = (idx & 0xE0) >> 5; // 0..7 +#if defined(IS_MUL_MM2) && defined(DATA_A_Q5_K) + vec2 v = shAscales[is * shAscales_stride + (blockCoords[0] % BM)]; + float d = v.x; + float m = v.y; +#else uvec4 v = bl128.block.q5k[0]; const f16vec2 loadd = unpackFloat2x16(v.x); @@ -256,6 +361,7 @@ float16_t dequantFuncQ5_K(const in decodeBufQ5_K bl, const in uint blockCoords[2 const float16_t d = loadd.x * float16_t(sc); const float16_t m = loadd.y * float16_t(mbyte); +#endif uint qh = uint32_t(bl16.block.qh[(idx & 0x1E) >> 1]); qh = ((qh >> is) & 0x101) << 4; @@ -264,9 +370,9 @@ float16_t dequantFuncQ5_K(const in decodeBufQ5_K bl, const in uint blockCoords[2 qs = (qs >> (b * 4)) & 0x0F0F; qs = unpack8(qs | qh)[idx & 1]; - float16_t ret = d * (float16_t(qs)) - m; + float ret = d * float(qs) - m; - return ret; + return float16_t(ret); } layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ6_K { @@ -311,8 +417,8 @@ float16_t dequantFuncIQ1_S(const in decodeBufIQ1_S bl, const in uint blockCoords const float16_t d = bl.block.d; const uint idx = coordInBlock[1]; - const uint ib32 = idx / 32; - const uint ib8 = idx / 8; + const uint ib32 = (idx & 0xE0) >> 5; + const uint ib8 = (idx & 0xF8) >> 3; const uint qh = bl.block.qh[ib32]; const uint qs = bl.block.qs[ib8]; @@ -330,14 +436,20 @@ layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ1 block_iq1_m block; }; +layout(buffer_reference, std430, buffer_reference_align = 8) buffer decodeBufIQ1_M_packed64 { + block_iq1_m_packed64 block; +}; + float16_t dequantFuncIQ1_M(const in decodeBufIQ1_M bl, const in uint blockCoords[2], const in uint coordInBlock[2]) { - const u16vec4 scales = u16vec4(bl.block.scales[0], bl.block.scales[1], bl.block.scales[2], bl.block.scales[3]) >> 12; - const 
float16_t d = uint16BitsToHalf(scales.x | (scales.y << 4) | (scales.z << 8) | (scales.w << 12)); + decodeBufIQ1_M_packed64 bl64 = decodeBufIQ1_M_packed64(bl); const uint idx = coordInBlock[1]; - const uint ib8 = idx / 8; - const uint ib16 = idx / 16; + uvec2 scales = unpack32(bl64.block.scales); + const float16_t d = uint16BitsToHalf(uint16_t(((scales.x & 0xF000) >> 12) | ((scales.x & 0xF0000000) >> 24) | ((scales.y & 0xF000) >> 4) | ((scales.y & 0xF0000000) >> 16))); + + const uint ib8 = (idx & 0xF8) >> 3; + const uint ib16 = (idx & 0xF0) >> 4; const int i8 = int(idx % 8); const uint sc = bl.block.scales[ib8 / 8]; const uint qs = bl.block.qs[ib8]; @@ -370,7 +482,7 @@ float16_t dequantFuncIQ2_XXS(const in decodeBufIQ2_XXS bl, const in uint blockCo const uint ib8 = (idx & 0x18) >> 3; // 0..3 const uint iqs = 8 * ib32 + ib8; - const uint8_t qs = bl.block.qs[iqs]; + const uint qs = bl.block.qs[iqs]; const uint signscale = pack32(u16vec2(bl16.block.qs[4*ib32+2], bl16.block.qs[4*ib32+3])); const float dscale = float(bl.block.d) * 0.25 * (0.5 + float(signscale >> 28)); @@ -558,8 +670,12 @@ float16_t dequantFuncIQ4_NL(const in decodeBufIQ4_NL bl, const in uint blockCoor #define dequantFuncA dequantFuncQ3_K #elif defined(DATA_A_Q4_K) #define dequantFuncA dequantFuncQ4_K +#define fetch_scales fetch_scalesQ4_K +#define store_scales store_scalesQ4_K #elif defined(DATA_A_Q5_K) #define dequantFuncA dequantFuncQ5_K +#define fetch_scales fetch_scalesQ5_K +#define store_scales store_scalesQ4_K #elif defined(DATA_A_Q6_K) #define dequantFuncA dequantFuncQ6_K #elif defined(DATA_A_IQ1_S) diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp new file mode 100644 index 000000000..e6545160d --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp @@ -0,0 +1,483 @@ +#version 450 + +#extension GL_EXT_control_flow_attributes : enable +#extension GL_EXT_shader_16bit_storage : require + +#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require +#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require + +#extension GL_KHR_shader_subgroup_shuffle : enable + +#include "types.comp" + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +layout (constant_id = 1) const uint32_t Br = 1; +layout (constant_id = 2) const uint32_t Bc = 32; +layout (constant_id = 3) const uint32_t D = 32; + +layout (constant_id = 5) const uint32_t D_split = 16; +const uint32_t D_per_thread = D / D_split; + +const uint32_t cols_per_iter = gl_WorkGroupSize.x / D_split; +const uint32_t cols_per_thread = Bc / cols_per_iter; + +layout (push_constant) uniform parameter { + uint32_t N; + uint32_t KV; + + uint32_t ne1; + uint32_t ne2; + uint32_t ne3; + + uint32_t neq2; + uint32_t neq3; + uint32_t nek2; + uint32_t nek3; + uint32_t nev2; + uint32_t nev3; + uint32_t nem1; + + uint32_t nb01; + uint32_t nb02; + uint32_t nb03; + uint32_t nb11; + uint32_t nb12; + uint32_t nb13; + uint32_t nb21; + uint32_t nb22; + uint32_t nb23; + uint32_t nb31; + + float scale; + float max_bias; + float logit_softcap; + + uint32_t mask; + uint32_t n_head_log2; + float m0; + float m1; + + uint32_t gqa_ratio; + uint32_t split_kv; + uint32_t k_num; +} p; + +layout (binding = 0) readonly buffer Q {float data_q[];}; +layout (binding = 0) readonly buffer QV4 {vec4 data_qv4[];}; +layout (binding = 1) readonly buffer K {float16_t data_k[];}; +layout (binding = 1) readonly buffer KV4 {f16vec4 data_kv4[];}; +layout 
(binding = 2) readonly buffer V {float16_t data_v[];}; +layout (binding = 2) readonly buffer VV4 {f16vec4 data_vv4[];}; +layout (binding = 3) readonly buffer M {float16_t data_m[];}; +layout (binding = 4) writeonly buffer O {D_TYPE data_o[];}; + +#if defined(A_TYPE_PACKED16) +#define BINDING_IDX_K 0 +#define BINDING_IDX_V 1 +layout (binding = 1) readonly buffer KV_PACKED16 {A_TYPE_PACKED16 data_packed16[];} kv_packed[2]; +#endif + +#if defined(DATA_A_Q4_0) +#define BLOCK_BYTE_SIZE 18 + +vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) { + uint vui_lo = uint(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]); + uint vui_hi = uint(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]); + uint shift = (iqs & 0x10) >> 2; + vui_lo >>= shift; + vui_hi >>= shift; + + return float(kv_packed[binding_idx].data_packed16[a_offset + ib].d) * (vec4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - 8.0f); +} +#endif + +#if defined(DATA_A_Q8_0) +#define BLOCK_BYTE_SIZE 34 +vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) { + const i8vec2 v0 = unpack8(int32_t(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[iqs / 2])).xy; // vec4 used due to #12147 + const i8vec2 v1 = unpack8(int32_t(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[iqs / 2 + 1])).xy; + + return float(kv_packed[binding_idx].data_packed16[a_offset + ib].d) * vec4(v0.x, v0.y, v1.x, v1.y); +} +#endif + +#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b)) + +// Store the output when doing grouped query attention. +// Rows index by Q's dimension 2, and the first N rows are valid. +D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N) +{ + uint32_t offset = (iq2 + r) * D + c; + data_o[o_offset + offset] = D_TYPE(elem); + return elem; +} + +// Store column zero. This is used to save per-row m and L values for split_k. +ACC_TYPE perElemOpStoreCol0(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N) +{ + if (r < N && c == 0) { + uint32_t offset = iq2 + r; + data_o[o_offset + offset] = D_TYPE(elem); + } + return elem; +} + +// Load the slope matrix, indexed by Q's dimension 2. +ACC_TYPE perElemOpComputeSlope(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t iq2) +{ + const uint32_t h = iq2 + (r % p.gqa_ratio); + + const ACC_TYPE base = ACC_TYPE(h < p.n_head_log2 ? p.m0 : p.m1); + const int exph = int(h < p.n_head_log2 ? 
h + 1 : 2*(h - p.n_head_log2) + 1); + + return ACC_TYPE(pow(base, ACC_TYPE(exph))); +} + +shared FLOAT_TYPE tmpsh[gl_WorkGroupSize.x]; +shared vec4 tmpshv4[gl_WorkGroupSize.x]; + +shared float masksh[Bc][Br]; +shared vec4 Qf[Br][D / 4]; + +void main() { +#ifdef NEEDS_INIT_IQ_SHMEM + init_iq_shmem(gl_WorkGroupSize); +#endif + + const uint32_t tid = gl_LocalInvocationIndex; + const uint32_t N = p.N; + const uint32_t KV = p.KV; + + const uint32_t d_tid = gl_LocalInvocationIndex % D_split; + const uint32_t col_tid = gl_LocalInvocationIndex / D_split; + + uint32_t i = gl_WorkGroupID.x; + uint32_t split_k_index = 0; + + if (p.k_num > 1) { + i = 0; + split_k_index = gl_WorkGroupID.x; + } + + const uint32_t Tr = CEIL_DIV(N, Br); + + const uint32_t start_j = split_k_index * p.split_kv / Bc; + const uint32_t end_j = CEIL_DIV(min(KV, (split_k_index + 1) * p.split_kv), Bc); + + // When not using grouped query attention, all rows share the same iq2, equal to gl_WorkGroupID.y. + // When using grouped query attention, each workgroup does gqa_ratio consecutive values of iq2. + const uint32_t iq2 = gl_WorkGroupID.y * p.gqa_ratio; + const uint32_t iq3 = gl_WorkGroupID.z; + + // broadcast factors + const uint32_t rk2 = p.neq2/p.nek2; + const uint32_t rk3 = p.neq3/p.nek3; + + const uint32_t rv2 = p.neq2/p.nev2; + const uint32_t rv3 = p.neq3/p.nev3; + + // k indices + const uint32_t ik3 = iq3 / rk3; + const uint32_t ik2 = iq2 / rk2; + + // v indices + const uint32_t iv3 = iq3 / rv3; + const uint32_t iv2 = iq2 / rv2; + + // nb?1 are already divided by the type size and are in units of elements. + // When using grouped query attention, Q is indexed by iq2, so the stride + // should be nb02 (which is in bytes). + uint32_t q_stride = p.gqa_ratio > 1 ? (p.nb02 / 4) : p.nb01; + uint32_t k_stride = p.nb11; + uint32_t v_stride = p.nb21; + // When using grouped query attention, all rows use the same mask (stride 0). + // "p.gqa_ratio >> 16" is just a roundabout way of writing zero + // that prevents the compiler from folding the "&" through the select + // and breaking the alignment detection. + uint32_t m_stride = (p.gqa_ratio > 1) ? (p.gqa_ratio >> 16) : KV; + + uint32_t q_offset = (iq2*p.nb02+iq3*p.nb03) / 4; + + [[unroll]] for (uint32_t idx = 0; idx < Br * D / 4; idx += gl_WorkGroupSize.x) { + uint32_t d = (idx + tid) % (D / 4); + uint32_t r = (idx + tid) / (D / 4); + if (r < Br && d < D / 4 && + i * Br + r < N) { + Qf[r][d] = vec4(data_qv4[q_offset / 4 + (i * Br + r) * q_stride / 4 + d]) * p.scale; + } + } + barrier(); + + vec4 Of[Br][D_per_thread / 4]; + [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { + [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + Of[r][d] = vec4(0.0); + } + } + + float Lf[Br], Mf[Br]; + + // Use -FLT_MAX/2 rather than -inf to reduce the possibility of NaNs, e.g. when computing Mold-M. 
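+    // 0xFEFFFFFF is the IEEE-754 bit pattern of -FLT_MAX/2 (sign = 1, exponent = 0xFD, mantissa all ones).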
+ const float NEG_FLT_MAX_OVER_2 = uintBitsToFloat(0xFEFFFFFF); + + [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + Lf[r] = 0; + Mf[r] = NEG_FLT_MAX_OVER_2; + } + + float slope[Br]; + [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + slope[r] = 1.0; + } + + // ALiBi + if (p.max_bias > 0.0f) { + [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + slope[r] = perElemOpComputeSlope(r, col_tid, ACC_TYPE(0), iq2); + } + } + +#if BLOCK_SIZE > 1 + uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / BLOCK_BYTE_SIZE; + uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / BLOCK_BYTE_SIZE; +#else + uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / 2; + uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2; +#endif + + [[dont_unroll]] + for (uint32_t j = start_j; j < end_j; ++j) { + + float Sf[Br][cols_per_thread]; + [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { + Sf[r][c] = 0.0; + } + } + + + [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { + [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { +#if BLOCK_SIZE > 1 + uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid); + uint ib = coord / BLOCK_SIZE; + uint iqs = (coord % BLOCK_SIZE); + vec4 K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K); +#else + vec4 K_Tf = vec4(data_kv4[k_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * k_stride / 4 + d * D_split + d_tid]); +#endif + [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + Sf[r][c] += dot(Qf[r][d * D_split + d_tid], K_Tf); + } + } + } + + [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { + // Compute sum across the D_split + [[unroll]] for (uint s = D_split / 2; s > 0; s >>= 1) { + [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + Sf[r][c] += subgroupShuffleXor(Sf[r][c], s); + } + } + } + + if (p.logit_softcap != 0.0f) { + [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { + Sf[r][c] = p.logit_softcap * tanh(Sf[r][c]); + } + } + } + + if (p.mask != 0) { + + [[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) { + uint32_t c = (idx + tid) % Bc; + uint32_t r = (idx + tid) / Bc; + if (idx + tid < Bc * Br) { + masksh[c][r] = float(data_m[(i * Br + r) * m_stride + (j * Bc + c)]); + } + } + barrier(); + + [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { + [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + float mvf = masksh[c * cols_per_iter + col_tid][r]; + + Sf[r][c] += slope[r]*mvf; + } + } + barrier(); + } + + float rowmaxf[Br], Pf[Br][cols_per_thread], rowsumf[Br], eMf[Br], Moldf[Br]; + [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + rowmaxf[r] = Sf[r][0]; + [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { + rowmaxf[r] = max(rowmaxf[r], Sf[r][c]); + } + Moldf[r] = Mf[r]; + + // M = max(rowmax, Mold) + // P = e^(S - M) + // eM = e^(Mold - M) + Mf[r] = max(rowmaxf[r], Moldf[r]); + [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { + Pf[r][c] = exp(Sf[r][c] - Mf[r]); + } + eMf[r] = exp(Moldf[r] - Mf[r]); + + // Compute sum across row of P + rowsumf[r] = 0.0; + [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { + rowsumf[r] += Pf[r][c]; + } + + Lf[r] = eMf[r]*Lf[r] + rowsumf[r]; + } + + [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { + [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + Of[r][d] = eMf[r] * Of[r][d]; + } + } + + [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { + [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { 
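+                // Load a vec4 of V (dequantized when BLOCK_SIZE > 1) and accumulate the second matmul, Of += Pf * V.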
+#if BLOCK_SIZE > 1 + uint coord = (j * Bc + c * cols_per_iter + col_tid) * v_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid); + uint ib = coord / BLOCK_SIZE; + uint iqs = (coord % BLOCK_SIZE); + vec4 Vf = dequantize4(ib, iqs, v_offset, BINDING_IDX_V); +#else + vec4 Vf = vec4(data_vv4[v_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * v_stride / 4 + d * D_split + d_tid]); +#endif + [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + Of[r][d] += Pf[r][c] * Vf; + } + } + } + + barrier(); + } + + // reduce across threads + + [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + float rowmaxf, eMf; + + tmpsh[tid] = Mf[r]; + // Compute max across the row + barrier(); + [[unroll]] for (int s = int(gl_WorkGroupSize.x) / 2; s >= D_split; s >>= 1) { + if (tid < s) { + tmpsh[tid] = max(tmpsh[tid], tmpsh[tid + s]); + } + barrier(); + } + rowmaxf = tmpsh[d_tid]; + barrier(); + + float Moldf = Mf[r]; + + // M = max(rowmax, Mold) + // eM = e^(Mold - M) + Mf[r] = max(rowmaxf, Moldf); + eMf = exp(Moldf - Mf[r]); + + Lf[r] = eMf*Lf[r]; + + tmpsh[tid] = Lf[r]; + + // Compute sum across the row + barrier(); + [[unroll]] for (int s = int(gl_WorkGroupSize.x) / 2; s >= D_split; s >>= 1) { + if (tid < s) { + tmpsh[tid] = tmpsh[tid] + tmpsh[tid + s]; + } + barrier(); + } + Lf[r] = tmpsh[d_tid]; + barrier(); + + [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { + + Of[r][d] = eMf * Of[r][d]; + tmpshv4[tid] = Of[r][d]; + + barrier(); + [[unroll]] for (int s = int(gl_WorkGroupSize.x) / 2; s >= D_split; s >>= 1) { + if (tid < s) { + Of[r][d] += tmpshv4[tid + s]; + tmpshv4[tid] = Of[r][d]; + } + barrier(); + } + Of[r][d] = tmpshv4[d_tid]; + barrier(); + } + } + + + // If there is split_k, then the split_k resolve shader does the final + // division by L. Store the intermediate O value and per-row m and L values. 
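+    // Intermediate layout consumed by flash_attn_split_k_reduce.comp: k_num slices of ne1*D O values,
+    // followed by, for each slice, ne1 L values and then ne1 M values.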
+ if (p.k_num > 1) { + uint32_t o_offset = D * p.ne1 * split_k_index; + + [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + if (r < N) { + [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { + [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) { + perElemOpGqaStore(r, 4*(d * D_split + d_tid) + comp, Of[r][d][comp], o_offset, iq2, N); + } + } + } + } + + o_offset = D * p.ne1 * p.k_num + p.ne1 * split_k_index * 2; + [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + if (r < N) { + perElemOpStoreCol0(r, 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N); + perElemOpStoreCol0(r, 0u, ACC_TYPE(Mf[r]), o_offset + p.ne1, iq2, N); + } + } + + return; + } + + float Lfrcp[Br]; + [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + Lfrcp[r] = 1.0 / Lf[r]; + } + + [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { + [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + Of[r][d] *= Lfrcp[r]; + } + } + + uint32_t o_offset = iq3*p.ne2*p.ne1; + + if (p.gqa_ratio > 1) { + [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + if (r < N) { + [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { + [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) { + perElemOpGqaStore(r, 4*(d * D_split + d_tid) + comp, Of[r][d][comp], o_offset, iq2, N); + } + } + } + } + } else { + [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + if (i * Br + r < N) { + [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) { + [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) { + data_o[o_offset + iq2 * D + (i * Br + r) * p.ne1 * D + 4*(d * D_split + d_tid) + comp] = D_TYPE(Of[r][d][comp]); + } + } + } + } + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp index df30355f6..b926a578a 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp @@ -61,6 +61,10 @@ layout (push_constant) uniform parameter { uint32_t n_head_log2; float m0; float m1; + + uint32_t gqa_ratio; + uint32_t split_kv; + uint32_t k_num; } p; layout (binding = 0) readonly buffer Q {uint8_t data_q[];}; @@ -103,6 +107,38 @@ ACC_TYPE Max(const in uint32_t row, const in uint32_t col, const in ACC_TYPE ele #define DECODEFUNC #endif +// Store the output when doing grouped query attention. +// Rows index by Q's dimension 2, and the first N rows are valid. +D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N) +{ + if (r < N && c < D) { + uint32_t offset = (iq2 + r) * D + c; + data_o[o_offset + offset] = D_TYPE(elem); + } + return elem; +} + +// Store column zero. This is used to save per-row m and L values for split_k. +ACC_TYPE perElemOpStoreCol0(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N) +{ + if (r < N && c == 0) { + uint32_t offset = iq2 + r; + data_o[o_offset + offset] = D_TYPE(elem); + } + return elem; +} + +// Load the slope matrix, indexed by Q's dimension 2. +ACC_TYPE perElemOpComputeSlope(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t iq2) +{ + const uint32_t h = iq2 + (r % p.gqa_ratio); + + const ACC_TYPE base = ACC_TYPE(h < p.n_head_log2 ? p.m0 : p.m1); + const int exph = int(h < p.n_head_log2 ? 
h + 1 : 2*(h - p.n_head_log2) + 1); + + return ACC_TYPE(pow(base, ACC_TYPE(exph))); +} + void main() { #ifdef NEEDS_INIT_IQ_SHMEM init_iq_shmem(gl_WorkGroupSize); @@ -111,12 +147,22 @@ void main() { const uint32_t N = p.N; const uint32_t KV = p.KV; + uint32_t i = gl_WorkGroupID.x; + uint32_t split_k_index = 0; + + if (p.k_num > 1) { + i = 0; + split_k_index = gl_WorkGroupID.x; + } + const uint32_t Tr = CEIL_DIV(N, Br); - const uint32_t Tc = CEIL_DIV(KV, Bc); - const uint32_t i = gl_WorkGroupID.x; + const uint32_t start_j = split_k_index * p.split_kv / Bc; + const uint32_t end_j = CEIL_DIV(min(KV, (split_k_index + 1) * p.split_kv), Bc); - const uint32_t iq2 = gl_WorkGroupID.y; + // When not using grouped query attention, all rows share the same iq2, equal to gl_WorkGroupID.y. + // When using grouped query attention, each workgroup does gqa_ratio consecutive values of iq2. + const uint32_t iq2 = gl_WorkGroupID.y * p.gqa_ratio; const uint32_t iq3 = gl_WorkGroupID.z; // broadcast factors @@ -149,10 +195,17 @@ void main() { tensorLayoutK = setTensorLayoutDimensionNV(tensorLayoutK, KV, D); tensorLayoutV = setTensorLayoutDimensionNV(tensorLayoutV, KV, D); - // nb?1 are already divided by the type size and are in units of elements - uint32_t q_stride = p.nb01; + // nb?1 are already divided by the type size and are in units of elements. + // When using grouped query attention, Q is indexed by iq2, so the stride + // should be nb02 (which is in bytes). + uint32_t q_stride = p.gqa_ratio > 1 ? (p.nb02 / 4) : p.nb01; uint32_t k_stride = p.nb11; uint32_t v_stride = p.nb21; + // When using grouped query attention, all rows use the same mask (stride 0). + // "p.gqa_ratio >> 16" is just a roundabout way of writing zero + // that prevents the compiler from folding the "&" through the select + // and breaking the alignment detection. + uint32_t m_stride = (p.gqa_ratio > 1) ? (p.gqa_ratio >> 16) : KV; // hint to the compiler that strides are aligned for the aligned variant of the shader if (Clamp != gl_CooperativeMatrixClampModeConstantNV) { @@ -161,6 +214,7 @@ void main() { k_stride &= ~7; v_stride &= ~7; #endif + m_stride &= ~7; } tensorLayoutQ = setTensorLayoutStrideNV(tensorLayoutQ, q_stride, 1); tensorLayoutK = setTensorLayoutStrideNV(tensorLayoutK, k_stride, 1); @@ -179,23 +233,21 @@ void main() { coopmat L, M; - L = coopmat(0); - M = coopmat(-1.0/0.0); + // Use -FLT_MAX/2 rather than -inf to reduce the possibility of NaNs, e.g. when computing Mold-M. + const float NEG_FLT_MAX_OVER_2 = uintBitsToFloat(0xFEFFFFFF); - ACC_TYPE slope = ACC_TYPE(1.0); + L = coopmat(0); + M = coopmat(NEG_FLT_MAX_OVER_2); + + coopmat slopeMat = coopmat(1.0); // ALiBi if (p.max_bias > 0.0f) { - const uint32_t h = iq2; - - const ACC_TYPE base = ACC_TYPE(h < p.n_head_log2 ? p.m0 : p.m1); - const int exph = int(h < p.n_head_log2 ? 
h + 1 : 2*(h - p.n_head_log2) + 1); - - slope = pow(base, ACC_TYPE(exph)); + coopMatPerElementNV(slopeMat, slopeMat, perElemOpComputeSlope, iq2); } [[dont_unroll]] - for (uint32_t j = 0; j < Tc; ++j) { + for (uint32_t j = start_j; j < end_j; ++j) { coopmat S = coopmat(0); @@ -213,14 +265,15 @@ void main() { } if (p.mask != 0) { - tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutM = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV); + tensorLayoutNV<2, Clamp> tensorLayoutM = createTensorLayoutNV(2, Clamp); tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, p.nem1, KV); + tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1); coopmat mv; coopMatLoadTensorNV(mv, data_m, 0, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc)); - S += slope*coopmat(mv); + S += slopeMat*coopmat(mv); } // Clear padding elements to -inf, so they don't contribute to rowmax @@ -231,7 +284,7 @@ void main() { uint R = ((i + 1) * Br > N) ? (N % Br) : Br; uint C = ((j + 1) * Bc > KV) ? (KV % Bc) : Bc; - coopMatPerElementNV(S, S, replacePadding, ACC_TYPE(-1.0/0.0), R, C); + coopMatPerElementNV(S, S, replacePadding, ACC_TYPE(NEG_FLT_MAX_OVER_2), R, C); } coopmat rowmax, P, rowsum, eM; @@ -280,9 +333,25 @@ void main() { // resize eM by using smear/reduce coopMatReduceNV(eMdiag, eM, gl_CooperativeMatrixReduceRowNV, smearReduce); - O = eMdiag * O; + // multiply with fp16 accumulation, then add to O. + coopmat PV = coopmat(0); + PV = coopMatMulAdd(P_A, V, PV); - O = coopMatMulAdd(P_A, V, O); + O = eMdiag * O + coopmat(PV); + } + + // If there is split_k, then the split_k resolve shader does the final + // division by L. Store the intermediate O value and per-row m and L values. + if (p.k_num > 1) { + coopmat O_D = coopmat(O); + + uint32_t o_offset = D * p.ne1 * split_k_index; + coopMatPerElementNV(O_D, O_D, perElemOpGqaStore, o_offset, iq2, N); + + o_offset = D * p.ne1 * p.k_num + p.ne1 * split_k_index * 2; + coopMatPerElementNV(L, L, perElemOpStoreCol0, o_offset, iq2, N); + coopMatPerElementNV(M, M, perElemOpStoreCol0, o_offset + p.ne1, iq2, N); + return; } coopmat Ldiag; @@ -297,13 +366,18 @@ void main() { O = Ldiag*O; - tensorLayoutNV<3, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutD = createTensorLayoutNV(3, gl_CooperativeMatrixClampModeConstantNV); - tensorLayoutD = setTensorLayoutDimensionNV(tensorLayoutD, p.ne2, p.ne1, D); - - // permute dimensions - tensorViewNV<3, false, 1, 0, 2> tensorViewPermute = createTensorViewNV(3, false, 1, 0, 2); uint32_t o_offset = iq3*p.ne2*p.ne1; coopmat O_D = coopmat(O); - coopMatStoreTensorNV(O_D, data_o, o_offset, sliceTensorLayoutNV(tensorLayoutD, i * Br, Br, iq2, 1, 0, D), tensorViewPermute); + if (p.gqa_ratio > 1) { + coopMatPerElementNV(O_D, O_D, perElemOpGqaStore, o_offset, iq2, N); + } else { + tensorLayoutNV<3, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutD = createTensorLayoutNV(3, gl_CooperativeMatrixClampModeConstantNV); + tensorLayoutD = setTensorLayoutDimensionNV(tensorLayoutD, p.ne2, p.ne1, D); + + // permute dimensions + tensorViewNV<3, false, 1, 0, 2> tensorViewPermute = createTensorViewNV(3, false, 1, 0, 2); + + coopMatStoreTensorNV(O_D, data_o, o_offset, sliceTensorLayoutNV(tensorLayoutD, i * Br, Br, iq2, N, 0, D), tensorViewPermute); + } } diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp new file mode 100644 index 000000000..a7e395685 --- /dev/null +++ 
b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp @@ -0,0 +1,59 @@ +#version 450 + +#extension GL_EXT_control_flow_attributes : enable + +#define BLOCK_SIZE 32 + +layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {float data_a[];}; +layout (binding = 1) writeonly buffer D {float data_d[];}; + +layout (push_constant) uniform parameter { + uint D; + uint N; + uint k_num; +} p; + +void main() { + // Each workgroup handles a row + const uint n = gl_WorkGroupID.x; + const uint tid = gl_LocalInvocationID.x; + + uint D = p.D; + uint N = p.N; + uint k_num = p.k_num; + + uint l_offset = D * N * k_num + n; + uint m_offset = D * N * k_num + N + n; + uint lm_stride = N * 2; + + // Compute the max m value for the row + float m_max = -1.0/0.0; + [[unroll]] for (uint k = 0; k < k_num; ++k) { + float m = data_a[m_offset + k * lm_stride]; + m_max = max(m_max, m); + } + + // Compute L based on m_max + float L = 0; + [[unroll]] for (uint k = 0; k < k_num; ++k) { + float l = data_a[l_offset + k * lm_stride]; + float m = data_a[m_offset + k * lm_stride]; + L += exp(m - m_max) * l; + } + + L = 1.0 / L; + + // Scale and sum the O contributions based on m_max and store the result to memory + for (uint d = tid; d < D; d += BLOCK_SIZE) { + float O = 0.0; + [[unroll]] for (uint k = 0; k < k_num; ++k) { + uint o_offset = D * N * k + D * n + d; + float m = data_a[m_offset + k * lm_stride]; + O += exp(m - m_max) * data_a[o_offset]; + } + O *= L; + data_d[D * n + d] = O; + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp index e877ed779..ee6b86a18 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp @@ -20,9 +20,14 @@ void main() { const uint a_offset = get_aoffset() + i01*p.nb01 + i11*p.nb02 + i12*p.nb03; const uint d_offset = get_doffset() + i10*p.nb21 + i11*p.nb22 + i12*p.nb23; -#ifndef OPTIMIZATION_ERROR_WORKAROUND - data_d[d_offset + i00] = D_TYPE(data_a[a_offset + i00]); +#if defined(DATA_A_BF16) + FLOAT_TYPE v = FLOAT_TYPE(bf16_to_fp32(data_a[a_offset + i00])); #else - data_d[d_offset + i00] = data_a[a_offset + i00]; + FLOAT_TYPE v = FLOAT_TYPE(data_a[a_offset + i00]); +#endif +#ifndef OPTIMIZATION_ERROR_WORKAROUND + data_d[d_offset + i00] = D_TYPE(v); +#else + data_d[d_offset + i00] = D_TYPE(v); #endif } diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp index c9f855687..cfd645a38 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp @@ -1,5 +1,7 @@ #version 450 +#extension GL_EXT_control_flow_attributes : enable + #include "types.comp" #include "generic_binary_head.comp" #include "dequant_funcs.comp" diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp index 122b1e93f..09aa849e8 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp @@ -40,6 +40,20 @@ void main() { const uint batch = gl_GlobalInvocationID.z / p.IC; const uint ic = gl_GlobalInvocationID.z % p.IC; + const uint src_base = ic * p.offset_delta + batch * 
p.batch_offset; + const uint dst_base = ((batch * p.OH + oh) * p.OW) * p.CHW + ic * (p.KW * p.KH); + const int oh_s1 = int(oh) * p.s1; + const uint ksize = p.OW * (p.KH > 1 ? p.KW : 1); + + const uint base_linear_idx = gidx * NUM_ITER; + + const uint max_ky = ksize / p.OW; + + uint current_kx = base_linear_idx / ksize; + const uint rem = base_linear_idx - (current_kx * ksize); + uint current_ky = rem / p.OW; + uint current_ix = rem % p.OW; + A_TYPE values[NUM_ITER]; uint offset_dst[NUM_ITER]; [[unroll]] for (uint idx = 0; idx < NUM_ITER; ++idx) { @@ -48,36 +62,35 @@ void main() { [[unroll]] for (uint idx = 0; idx < NUM_ITER; ++idx) { - const uint i = gidx * NUM_ITER + idx; + const uint linear_idx = base_linear_idx + idx; - const uint ksize = p.OW * (p.KH > 1 ? p.KW : 1); - const uint kx = i / ksize; - const uint kd = kx * ksize; - const uint ky = (i - kd) / p.OW; - const uint ix = i % p.OW; - - const uint iiw = ix * p.s0 + kx * p.d0 - p.p0; - const uint iih = oh * p.s1 + ky * p.d1 - p.p1; - - offset_dst[idx] = - ((batch * p.OH + oh) * p.OW + ix) * p.CHW + - (ic * (p.KW * p.KH) + ky * p.KW + kx); - - if (i >= p.pelements) { + if (linear_idx >= p.pelements) { continue; } - if (iih < p.IH && iiw < p.IW) { - const uint offset_src = ic * p.offset_delta + batch * p.batch_offset; - values[idx] = data_a[offset_src + iih * p.IW + iiw]; + const uint iiw = current_ix * p.s0 + current_kx * p.d0 - p.p0; + const uint iih = oh_s1 + current_ky * p.d1 - p.p1; + + offset_dst[idx] = dst_base + current_ix * p.CHW + current_ky * p.KW + current_kx; + + if ((iih < p.IH) && (iiw < p.IW)) { + values[idx] = data_a[src_base + iih * p.IW + iiw]; + } + + if (++current_ix == p.OW) { + current_ix = 0; + if (++current_ky == max_ky) { + current_ky = 0; + current_kx++; + } } } [[unroll]] for (uint idx = 0; idx < NUM_ITER; ++idx) { - const uint i = gidx * NUM_ITER + idx; + const uint linear_idx = base_linear_idx + idx; - if (i >= p.pelements) { + if (linear_idx >= p.pelements) { continue; } diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp new file mode 100644 index 000000000..deba8c398 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp @@ -0,0 +1,41 @@ +#version 450 + +#include "generic_head.comp" +#include "types.comp" + +#extension GL_EXT_control_flow_attributes : enable +#define BLOCK_SIZE 512 + +layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +shared FLOAT_TYPE sum[BLOCK_SIZE]; + +void main() { + const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x; + const uint tid = gl_LocalInvocationID.x; + + sum[tid] = FLOAT_TYPE(0.0f); // partial sum for thread in warp + + [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) { + const FLOAT_TYPE xi = FLOAT_TYPE(data_a[row*p.KX + col]); + sum[tid] += xi * xi; + } + + // sum up partial sums and write back result + barrier(); + [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) { + if (tid < s) { + sum[tid] += sum[tid + s]; + } + barrier(); + } + + const FLOAT_TYPE scale = inversesqrt(max(sum[0], FLOAT_TYPE(p.param1))); + + [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) { + data_d[row*p.KX + col] = D_TYPE(scale * FLOAT_TYPE(data_a[row*p.KX + col])); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp 
b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp index 31ecd9f81..bb429dd59 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp @@ -6,7 +6,7 @@ layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; -#if !defined(DATA_A_F32) && !defined(DATA_A_F16) +#if !defined(DATA_A_F32) && !defined(DATA_A_F16) && !defined(DATA_A_BF16) #define K_PER_ITER 8 #else #define K_PER_ITER 2 @@ -105,6 +105,16 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { int unroll_count = 4; uint unrolled_iters = num_iters & ~(unroll_count - 1); +#if K_PER_ITER == 2 + // If the K dimension is odd, we need lastiter==true on the last iteration + // so OOB is computed correctly. Skip some unrolling to make that happen. + if ((p.ncols & 1) != 0 && + unrolled_iters == num_iters && + unrolled_iters > 0) { + unrolled_iters -= unroll_count; + } +#endif + uint i = 0; while (i < unrolled_iters) { // Manually partially unroll the loop @@ -113,8 +123,18 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { i++; } } + unroll_count = 2; unrolled_iters = num_iters & ~(unroll_count - 1); + +#if K_PER_ITER == 2 + if ((p.ncols & 1) != 0 && + unrolled_iters == num_iters && + unrolled_iters > 0) { + unrolled_iters -= unroll_count; + } +#endif + while (i < unrolled_iters) { // Manually partially unroll the loop [[unroll]] for (uint k = 0; k < unroll_count; ++k) { diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp new file mode 100644 index 000000000..8d01536fa --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp @@ -0,0 +1,90 @@ +#version 450 +#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require + +#include "mul_mat_vec_base.comp" + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; + +void calc_superblock(const uint a_offset, const uint b_offset, const uint itid, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows) { + const uint y_idx = i * QUANT_K + 16 * itid; + const uint nibble_shift = 4 * (itid & 1); + const uint ib32 = itid / 2; // 0..7 + + uint ibi = a_offset / QUANT_K + first_row * num_blocks_per_row + i; + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + const float d = float(data_a[ibi].d); + const uint scale = (data_a[ibi].scales[ib32] >> nibble_shift) & 0xF; + const float db = d * (0.5 + scale) * 0.25; + + const uint qh = data_a[ibi].qh[ib32]; + const u8vec2 qs16 = unpack8(uint32_t(data_a_packed16[ibi].qs[itid])).xy; // vec4 used due to #12147 + const u8vec2 sign16 = unpack8(uint32_t(data_a_packed16[ibi].qs[QUANT_K / 16 + itid])).xy; + [[unroll]] for (uint l = 0; l < 2; ++l) { + const uint8_t sign = sign16[l]; + const uint qs = qs16[l] | ((qh << (8 - nibble_shift - 2 * l)) & 0x300); + const uvec2 grid = iq2s_grid[qs]; + const vec4 grid0 = vec4(unpack8(grid.x)); + const vec4 grid1 = vec4(unpack8(grid.y)); + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + vec4 b0 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 0]); + vec4 b4 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 1]); + + FLOAT_TYPE sum = + fma(FLOAT_TYPE(b0.x), FLOAT_TYPE((sign & 1) != 0 ? -grid0.x : grid0.x), + fma(FLOAT_TYPE(b0.y), FLOAT_TYPE((sign & 2) != 0 ? 
-grid0.y : grid0.y), + fma(FLOAT_TYPE(b0.z), FLOAT_TYPE((sign & 4) != 0 ? -grid0.z : grid0.z), + fma(FLOAT_TYPE(b0.w), FLOAT_TYPE((sign & 8) != 0 ? -grid0.w : grid0.w), + fma(FLOAT_TYPE(b4.x), FLOAT_TYPE((sign & 16) != 0 ? -grid1.x : grid1.x), + fma(FLOAT_TYPE(b4.y), FLOAT_TYPE((sign & 32) != 0 ? -grid1.y : grid1.y), + fma(FLOAT_TYPE(b4.z), FLOAT_TYPE((sign & 64) != 0 ? -grid1.z : grid1.z), + fma(FLOAT_TYPE(b4.w), FLOAT_TYPE((sign & 128) != 0 ? -grid1.w : grid1.w), + FLOAT_TYPE(0.0))))))))); + temp[j][n] = fma(db, sum, temp[j][n]); + } + } + ibi += num_blocks_per_row; + } +} + +void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { + uint a_offset, b_offset, d_offset; + get_offsets(a_offset, b_offset, d_offset); + + const uint num_blocks_per_row = p.ncols / QUANT_K; + + // 16 threads are used to process each block + const uint blocks_per_wg = gl_WorkGroupSize.x/16; + const uint tid = gl_LocalInvocationID.x; + const uint itid = tid % 16; // 0...15 + const uint ix = tid / 16; + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) { + temp[j][i] = FLOAT_TYPE(0); + } + } + + [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += blocks_per_wg) + calc_superblock(a_offset, b_offset, itid, i, num_blocks_per_row, first_row, num_rows); + + reduce_result(temp, d_offset, first_row, num_rows, tid); +} + +void main() { + const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z); + + init_iq_shmem(gl_WorkGroupSize); + + // do NUM_ROWS at a time, unless there aren't enough remaining rows + if (first_row + NUM_ROWS <= p.stride_d) { + compute_outputs(first_row, NUM_ROWS); + } else { + if (first_row >= p.stride_d) { + return; + } + compute_outputs(first_row, p.stride_d - first_row); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp new file mode 100644 index 000000000..c49604324 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp @@ -0,0 +1,87 @@ +#version 450 +#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require + +#include "mul_mat_vec_base.comp" + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; + +void calc_superblock(const uint a_offset, const uint b_offset, const uint itid, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows) { + const uint y_idx = i * QUANT_K + 16 * itid; + const uint nibble_shift = 4 * (itid & 1); + const uint ib32 = itid / 2; // 0..7 + + uint ibi = a_offset / QUANT_K + first_row * num_blocks_per_row + i; + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + const float d = float(data_a[ibi].d); + const uint scale = (data_a[ibi].scales[ib32] >> nibble_shift) & 0xF; + const float db = d * (0.5 + scale) * 0.25; + + [[unroll]] for (uint l = 0; l < 2; ++l) { + const uint qs = data_a[ibi].qs[2 * itid + l]; + const uint sign = qs >> 9; + const uint sign7 = bitCount(sign); + const vec4 grid0 = vec4(unpack8(iq2xs_grid[qs & 511].x)); + const vec4 grid1 = vec4(unpack8(iq2xs_grid[qs & 511].y)); + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + vec4 b0 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 0]); + vec4 b4 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 1]); + + FLOAT_TYPE sum = + fma(FLOAT_TYPE(b0.x), FLOAT_TYPE((sign & 1) != 0 ? 
-grid0.x : grid0.x), + fma(FLOAT_TYPE(b0.y), FLOAT_TYPE((sign & 2) != 0 ? -grid0.y : grid0.y), + fma(FLOAT_TYPE(b0.z), FLOAT_TYPE((sign & 4) != 0 ? -grid0.z : grid0.z), + fma(FLOAT_TYPE(b0.w), FLOAT_TYPE((sign & 8) != 0 ? -grid0.w : grid0.w), + fma(FLOAT_TYPE(b4.x), FLOAT_TYPE((sign & 16) != 0 ? -grid1.x : grid1.x), + fma(FLOAT_TYPE(b4.y), FLOAT_TYPE((sign & 32) != 0 ? -grid1.y : grid1.y), + fma(FLOAT_TYPE(b4.z), FLOAT_TYPE((sign & 64) != 0 ? -grid1.z : grid1.z), + fma(FLOAT_TYPE(b4.w), FLOAT_TYPE((sign7 & 1) != 0 ? -grid1.w : grid1.w), + FLOAT_TYPE(0.0))))))))); + temp[j][n] = fma(db, sum, temp[j][n]); + } + } + ibi += num_blocks_per_row; + } +} + +void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { + uint a_offset, b_offset, d_offset; + get_offsets(a_offset, b_offset, d_offset); + + const uint num_blocks_per_row = p.ncols / QUANT_K; + + // 16 threads are used to process each block + const uint blocks_per_wg = gl_WorkGroupSize.x/16; + const uint tid = gl_LocalInvocationID.x; + const uint itid = tid % 16; // 0...15 + const uint ix = tid / 16; + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) { + temp[j][i] = FLOAT_TYPE(0); + } + } + + [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += blocks_per_wg) + calc_superblock(a_offset, b_offset, itid, i, num_blocks_per_row, first_row, num_rows); + + reduce_result(temp, d_offset, first_row, num_rows, tid); +} + +void main() { + const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z); + + init_iq_shmem(gl_WorkGroupSize); + + // do NUM_ROWS at a time, unless there aren't enough remaining rows + if (first_row + NUM_ROWS <= p.stride_d) { + compute_outputs(first_row, NUM_ROWS); + } else { + if (first_row >= p.stride_d) { + return; + } + compute_outputs(first_row, p.stride_d - first_row); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp new file mode 100644 index 000000000..94d4b92e1 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp @@ -0,0 +1,87 @@ +#version 450 +#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require + +#include "mul_mat_vec_base.comp" + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; + +void calc_superblock(const uint a_offset, const uint b_offset, const uint itid, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows) { + const uint y_idx = i * QUANT_K + 16 * itid; + const uint ib32 = itid / 2; // 0..7 + + uint ibi = a_offset / QUANT_K + first_row * num_blocks_per_row + i; + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + const float d = float(data_a[ibi].d); + const uint signscale = pack32(u16vec2( + data_a_packed16[ibi].qs[4 * ib32 + 2], + data_a_packed16[ibi].qs[4 * ib32 + 3])); + const float db = d * 0.25 * (0.5 + (signscale >> 28)); + [[unroll]] for (uint l = 0; l < 2; ++l) { + const uint qs = data_a[ibi].qs[8 * ib32 + 2 * (itid & 1) + l]; + const uint sign = bitfieldExtract(signscale, 7 * int(2 * (itid & 1) + l), 7); + const uint sign7 = bitCount(sign); + const vec4 grid0 = vec4(unpack8(iq2xxs_grid[qs].x)); + const vec4 grid1 = vec4(unpack8(iq2xxs_grid[qs].y)); + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + const vec4 b0 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 0]); + const vec4 b4 = 
vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 1]); + + FLOAT_TYPE sum = + fma(FLOAT_TYPE(b0.x), FLOAT_TYPE((sign & 1) != 0 ? -grid0.x : grid0.x), + fma(FLOAT_TYPE(b0.y), FLOAT_TYPE((sign & 2) != 0 ? -grid0.y : grid0.y), + fma(FLOAT_TYPE(b0.z), FLOAT_TYPE((sign & 4) != 0 ? -grid0.z : grid0.z), + fma(FLOAT_TYPE(b0.w), FLOAT_TYPE((sign & 8) != 0 ? -grid0.w : grid0.w), + fma(FLOAT_TYPE(b4.x), FLOAT_TYPE((sign & 16) != 0 ? -grid1.x : grid1.x), + fma(FLOAT_TYPE(b4.y), FLOAT_TYPE((sign & 32) != 0 ? -grid1.y : grid1.y), + fma(FLOAT_TYPE(b4.z), FLOAT_TYPE((sign & 64) != 0 ? -grid1.z : grid1.z), + fma(FLOAT_TYPE(b4.w), FLOAT_TYPE((sign7 & 1) != 0 ? -grid1.w : grid1.w), + FLOAT_TYPE(0.0))))))))); + temp[j][n] = fma(db, sum, temp[j][n]); + } + } + ibi += num_blocks_per_row; + } +} + +void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { + uint a_offset, b_offset, d_offset; + get_offsets(a_offset, b_offset, d_offset); + + const uint num_blocks_per_row = p.ncols / QUANT_K; + + // 16 threads are used to process each block + const uint blocks_per_wg = gl_WorkGroupSize.x/16; + const uint tid = gl_LocalInvocationID.x; + const uint itid = tid % 16; // 0...15 + const uint ix = tid / 16; + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) { + temp[j][i] = FLOAT_TYPE(0); + } + } + + [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += blocks_per_wg) + calc_superblock(a_offset, b_offset, itid, i, num_blocks_per_row, first_row, num_rows); + + reduce_result(temp, d_offset, first_row, num_rows, tid); +} + +void main() { + const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z); + + init_iq_shmem(gl_WorkGroupSize); + + // do NUM_ROWS at a time, unless there aren't enough remaining rows + if (first_row + NUM_ROWS <= p.stride_d) { + compute_outputs(first_row, NUM_ROWS); + } else { + if (first_row >= p.stride_d) { + return; + } + compute_outputs(first_row, p.stride_d - first_row); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp new file mode 100644 index 000000000..f021e4047 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp @@ -0,0 +1,90 @@ +#version 450 +#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require + +#include "mul_mat_vec_base.comp" + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; + +void calc_superblock(const uint a_offset, const uint b_offset, const uint ib32, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows) { + const uint y_idx = i * QUANT_K + 32 * ib32; + + uint ibi = a_offset / QUANT_K + first_row * num_blocks_per_row + i; + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + const float d = float(data_a[ibi].d); + const uint scale = (data_a[ibi].scales[ib32/2] >> (4 * (ib32 & 1))) & 0xF; + const float dscale = d * (1 + 2 * scale); + const uint qh = data_a[ibi].qh[ib32]; + FLOAT_TYPE sum[NUM_COLS]; + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + sum[j] = 0.0; + } + [[unroll]] for (uint l = 0; l < 4; ++l) { + const u8vec2 qs = unpack8(uint32_t(data_a_packed16[ibi].qs[4 * ib32 + l])).xy; // vec4 used due to #12147 + const uint sign = data_a[ibi].signs[4 * ib32 + l]; + const vec4 grid0 = vec4(unpack8(iq3s_grid[qs.x | ((qh << (8 - 2*l)) & 0x100)])); + const vec4 grid1 = 
vec4(unpack8(iq3s_grid[qs.y | ((qh << (7 - 2*l)) & 0x100)])); + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + const vec4 b0 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 0]); + const vec4 b4 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 1]); + + sum[j] = + fma(FLOAT_TYPE(b0.x), FLOAT_TYPE((sign & 1) != 0 ? -grid0.x : grid0.x), + fma(FLOAT_TYPE(b0.y), FLOAT_TYPE((sign & 2) != 0 ? -grid0.y : grid0.y), + fma(FLOAT_TYPE(b0.z), FLOAT_TYPE((sign & 4) != 0 ? -grid0.z : grid0.z), + fma(FLOAT_TYPE(b0.w), FLOAT_TYPE((sign & 8) != 0 ? -grid0.w : grid0.w), + fma(FLOAT_TYPE(b4.x), FLOAT_TYPE((sign & 16) != 0 ? -grid1.x : grid1.x), + fma(FLOAT_TYPE(b4.y), FLOAT_TYPE((sign & 32) != 0 ? -grid1.y : grid1.y), + fma(FLOAT_TYPE(b4.z), FLOAT_TYPE((sign & 64) != 0 ? -grid1.z : grid1.z), + fma(FLOAT_TYPE(b4.w), FLOAT_TYPE((sign & 128) != 0 ? -grid1.w : grid1.w), + sum[j])))))))); + } + } + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + temp[j][n] = fma(dscale, sum[j], temp[j][n]); + } + ibi += num_blocks_per_row; + } +} + +void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { + uint a_offset, b_offset, d_offset; + get_offsets(a_offset, b_offset, d_offset); + + const uint num_blocks_per_row = p.ncols / QUANT_K; + + // 8 threads are used to process each block + const uint blocks_per_wg = gl_WorkGroupSize.x/8; + const uint tid = gl_LocalInvocationID.x; + const uint itid = tid % 8; // 0...7 + const uint ix = tid / 8; + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) { + temp[j][i] = FLOAT_TYPE(0); + } + } + + [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += blocks_per_wg) + calc_superblock(a_offset, b_offset, itid, i, num_blocks_per_row, first_row, num_rows); + + reduce_result(temp, d_offset, first_row, num_rows, tid); +} + +void main() { + const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z); + + init_iq_shmem(gl_WorkGroupSize); + + // do NUM_ROWS at a time, unless there aren't enough remaining rows + if (first_row + NUM_ROWS <= p.stride_d) { + compute_outputs(first_row, NUM_ROWS); + } else { + if (first_row >= p.stride_d) { + return; + } + compute_outputs(first_row, p.stride_d - first_row); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp new file mode 100644 index 000000000..3fe9dc3a4 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp @@ -0,0 +1,88 @@ +#version 450 +#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require + +#include "mul_mat_vec_base.comp" + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; + +void calc_superblock(const uint a_offset, const uint b_offset, const uint itid, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows) { + const uint y_idx = i * QUANT_K + 16 * itid; + const uint ib32 = itid / 2; // 0..7 + + uint ibi = a_offset / QUANT_K + first_row * num_blocks_per_row + i; + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + const float d = float(data_a[ibi].d); + const uint signscale = pack32(u16vec2( + data_a_packed16[ibi].qs[QUANT_K / 8 + 2 * ib32], + data_a_packed16[ibi].qs[QUANT_K / 8 + 2 * ib32 + 1])); + const float db = d * 0.5 * (0.5 + (signscale >> 28)); + [[unroll]] for (uint l = 0; l < 2; ++l) { + const uint qs0 = data_a[ibi].qs[8 * 
ib32 + 4 * (itid & 1) + 2 * l]; + const uint qs1 = data_a[ibi].qs[8 * ib32 + 4 * (itid & 1) + 2 * l + 1]; + const uint sign = bitfieldExtract(signscale, 7 * int(2 * (itid & 1) + l), 7); + const uint sign7 = bitCount(sign); + const vec4 grid0 = vec4(unpack8(iq3xxs_grid[qs0])); + const vec4 grid1 = vec4(unpack8(iq3xxs_grid[qs1])); + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + const vec4 b0 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 0]); + const vec4 b4 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 1]); + + FLOAT_TYPE sum = + fma(FLOAT_TYPE(b0.x), FLOAT_TYPE((sign & 1) != 0 ? -grid0.x : grid0.x), + fma(FLOAT_TYPE(b0.y), FLOAT_TYPE((sign & 2) != 0 ? -grid0.y : grid0.y), + fma(FLOAT_TYPE(b0.z), FLOAT_TYPE((sign & 4) != 0 ? -grid0.z : grid0.z), + fma(FLOAT_TYPE(b0.w), FLOAT_TYPE((sign & 8) != 0 ? -grid0.w : grid0.w), + fma(FLOAT_TYPE(b4.x), FLOAT_TYPE((sign & 16) != 0 ? -grid1.x : grid1.x), + fma(FLOAT_TYPE(b4.y), FLOAT_TYPE((sign & 32) != 0 ? -grid1.y : grid1.y), + fma(FLOAT_TYPE(b4.z), FLOAT_TYPE((sign & 64) != 0 ? -grid1.z : grid1.z), + fma(FLOAT_TYPE(b4.w), FLOAT_TYPE((sign7 & 1) != 0 ? -grid1.w : grid1.w), + FLOAT_TYPE(0.0))))))))); + temp[j][n] = fma(db, sum, temp[j][n]); + } + } + ibi += num_blocks_per_row; + } +} + +void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { + uint a_offset, b_offset, d_offset; + get_offsets(a_offset, b_offset, d_offset); + + const uint num_blocks_per_row = p.ncols / QUANT_K; + + // 16 threads are used to process each block + const uint blocks_per_wg = gl_WorkGroupSize.x/16; + const uint tid = gl_LocalInvocationID.x; + const uint itid = tid % 16; // 0...15 + const uint ix = tid / 16; + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) { + temp[j][i] = FLOAT_TYPE(0); + } + } + + [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += blocks_per_wg) + calc_superblock(a_offset, b_offset, itid, i, num_blocks_per_row, first_row, num_rows); + + reduce_result(temp, d_offset, first_row, num_rows, tid); +} + +void main() { + const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z); + + init_iq_shmem(gl_WorkGroupSize); + + // do NUM_ROWS at a time, unless there aren't enough remaining rows + if (first_row + NUM_ROWS <= p.stride_d) { + compute_outputs(first_row, NUM_ROWS); + } else { + if (first_row >= p.stride_d) { + return; + } + compute_outputs(first_row, p.stride_d - first_row); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp index 1cc4996d3..bc633369f 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp @@ -12,13 +12,18 @@ layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; layout (binding = 1) readonly buffer B {B_TYPE data_b[];}; layout (binding = 2) writeonly buffer D {D_TYPE dst[];}; +layout (binding = 0) readonly buffer AV4 {A_TYPE_VEC4 data_a_v4[];}; +layout (binding = 1) readonly buffer BV4 {B_TYPE_VEC4 data_b_v4[];}; + layout (push_constant) uniform parameter { uint ncols_x; uint nrows_x; uint row_stride_x; uint channel_stride_x; + uint channel_stride_y; uint channel_x_divisor; + uint ne12; uint b_offset; uint d_offset; } p; @@ -30,6 +35,7 @@ void main() { const uint row_x = gl_GlobalInvocationID.y; const uint channel = gl_GlobalInvocationID.z; const uint channel_x 
= channel / p.channel_x_divisor; + const uint channel_y = channel % p.ne12; const uint nrows_y = p.ncols_x; const uint nrows_dst = p.nrows_x; @@ -37,25 +43,66 @@ void main() { const uint idst = channel*nrows_dst + row_dst; - tmp[tid] = 0.0f; + FLOAT_TYPE temp = 0.0f; - for (uint col_x0 = 0; col_x0 < p.ncols_x; col_x0 += BLOCK_SIZE) { - const uint col_x = col_x0 + tid; + // Detect alignment for vector loads + bool is_aligned = (p.ncols_x % 4) == 0 && (p.row_stride_x % 4) == 0 && (p.channel_stride_x % 4) == 0; - if (col_x >= p.ncols_x) { - break; + for (uint col_x0 = 0; col_x0 < p.ncols_x;) { + + // Unroll 2x and do vec4 loads if aligned + const uint unroll_count = 2; + if (col_x0 + unroll_count * 4 * BLOCK_SIZE <= p.ncols_x && is_aligned) { + [[unroll]] for (uint i = 0; i < unroll_count; ++i) { + const uint col_x = col_x0 + 4*tid; + + const uint row_y = col_x; + + const uint ix = channel_x*p.channel_stride_x + row_x*p.row_stride_x + col_x; + const uint iy = channel_y*p.channel_stride_y + row_y; + + const vec4 av4 = vec4(data_a_v4[ix / 4]); + const vec4 bv4 = vec4(data_b_v4[iy / 4]); + + temp += dot(av4, bv4); + + col_x0 += 4*BLOCK_SIZE; + } + // do vec4 loads if aligned + } else if (col_x0 + 4*BLOCK_SIZE <= p.ncols_x && is_aligned) { + const uint col_x = col_x0 + 4*tid; + + const uint row_y = col_x; + + const uint ix = channel_x*p.channel_stride_x + row_x*p.row_stride_x + col_x; + const uint iy = channel_y*p.channel_stride_y + row_y; + + const vec4 av4 = vec4(data_a_v4[ix / 4]); + const vec4 bv4 = vec4(data_b_v4[iy / 4]); + + temp += dot(av4, bv4); + + col_x0 += 4*BLOCK_SIZE; + } else { + const uint col_x = col_x0 + tid; + if (col_x >= p.ncols_x) { + break; + } + + const uint row_y = col_x; + + const uint ix = channel_x*p.channel_stride_x + row_x*p.row_stride_x + col_x; + const uint iy = channel_y*p.channel_stride_y + row_y; + + const FLOAT_TYPE xi = FLOAT_TYPE(data_a[ix]); + + temp = fma(xi, FLOAT_TYPE(data_b[iy]), temp); + col_x0 += BLOCK_SIZE; } - - const uint row_y = col_x; - - const uint ix = channel_x*p.channel_stride_x + row_x*p.row_stride_x + col_x; - const uint iy = channel*nrows_y + row_y; - - const FLOAT_TYPE xi = FLOAT_TYPE(data_a[ix]); - - tmp[tid] = fma(xi, FLOAT_TYPE(data_b[iy]), tmp[tid]); } + tmp[tid] = temp; + // sum up partial sums and write back result barrier(); [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) { diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp index 9b443807d..7aa070eeb 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp @@ -2,16 +2,25 @@ #extension GL_EXT_control_flow_attributes : enable #extension GL_EXT_shader_16bit_storage : require +#if USE_SUBGROUP_ADD +#extension GL_KHR_shader_subgroup_arithmetic : enable +#endif -#define BLOCK_SIZE 32 #define FLOAT_TYPE float -layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in; +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; layout (binding = 1) readonly buffer B {B_TYPE data_b[];}; layout (binding = 2) writeonly buffer D {D_TYPE dst[];}; +layout (binding = 0) readonly buffer AV4 {A_TYPE_VEC4 data_a_v4[];}; +layout (binding = 1) readonly buffer BV4 {B_TYPE_VEC4 data_b_v4[];}; + +layout(constant_id = 0) const int BLOCK_SIZE = 32; +// gqa_ratio is in the range [1,8] 
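+// (the number of query heads that share one KV head; each invocation keeps one accumulator per head in the group)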
+layout(constant_id = 1) const uint gqa_ratio = 1; + layout (push_constant) uniform parameter { uint ncols_x; @@ -22,52 +31,124 @@ layout (push_constant) uniform parameter uint d_offset; } p; -shared FLOAT_TYPE tmp[BLOCK_SIZE]; +#if !USE_SUBGROUP_ADD +shared FLOAT_TYPE tmp[8][BLOCK_SIZE]; +#endif void main() { const uint tid = gl_LocalInvocationID.x; const uint row_x = gl_GlobalInvocationID.y; - const uint channel = gl_GlobalInvocationID.z; - const uint channel_x = channel / (p.nchannels_y / p.nchannels_x); + + uint channel, channel_x; + + // When gqa_ratio > 1, each invocation does multiple rows. + // The row in the A matrix is starting from channel / gqa_ratio and the + // rows in the B matrix are [channel, channel+gqa_ratio). + // When gpa_ratio is 1, each invocation does one row. + if (gqa_ratio > 1) { + channel_x = gl_GlobalInvocationID.z; + channel = channel_x * gqa_ratio; + } else { + channel = gl_GlobalInvocationID.z; + channel_x = channel / (p.nchannels_y / p.nchannels_x);; + } const uint nrows_y = p.ncols_x; const uint nrows_dst = p.nrows_x; const uint row_dst = row_x; - tmp[tid] = FLOAT_TYPE(0.0f); - - for (uint col_x0 = 0; col_x0 < p.ncols_x; col_x0 += BLOCK_SIZE) { - const uint col_x = col_x0 + tid; - - if (col_x >= p.ncols_x) { - break; - } - - // x is transposed and permuted - const uint ix = row_x*p.nchannels_x*p.ncols_x + channel_x*p.ncols_x + col_x; - const FLOAT_TYPE xi = FLOAT_TYPE(data_a[ix]); - - const uint row_y = col_x; - - // y is not transposed but permuted - const uint iy = channel*nrows_y + row_y; - - tmp[tid] = fma(xi, FLOAT_TYPE(data_b[iy]), tmp[tid]); + FLOAT_TYPE temp[8]; + [[unroll]] for (uint i = 0; i < 8; ++i) { + temp[i] = FLOAT_TYPE(0.0f); } - // dst is not transposed and not permuted - const uint idst = channel*nrows_dst + row_dst; + // Detect alignment for vector loads + bool is_aligned = (p.ncols_x % 4) == 0 && (p.nchannels_x % 4) == 0 && (nrows_y % 4) == 0; + for (uint col_x0 = 0; col_x0 < p.ncols_x; col_x0 += BLOCK_SIZE) { + + // Use vec4 loads if aligned + if (col_x0 + 4*BLOCK_SIZE <= p.ncols_x && is_aligned) { + + uint col_x = col_x0 + 4*tid; + const uint row_y = col_x; + + // x is transposed and permuted + const uint ix = row_x*p.nchannels_x*p.ncols_x + channel_x*p.ncols_x + col_x; + const vec4 av4 = vec4(data_a_v4[ix / 4]); + + [[unroll]] for (uint c = 0; c < gqa_ratio; ++c) { + // y is not transposed but permuted + const uint iy = (channel + c)*nrows_y + row_y; + + vec4 bv4 = data_b_v4[iy / 4]; + temp[c] += dot(av4, bv4); + } + + col_x0 += 3*BLOCK_SIZE; + } else { + const uint col_x = col_x0 + tid; + + if (col_x >= p.ncols_x) { + break; + } + + // x is transposed and permuted + const uint ix = row_x*p.nchannels_x*p.ncols_x + channel_x*p.ncols_x + col_x; + const FLOAT_TYPE xi = FLOAT_TYPE(data_a[ix]); + + const uint row_y = col_x; + + [[unroll]] for (uint c = 0; c < gqa_ratio; ++c) { + // y is not transposed but permuted + const uint iy = (channel + c)*nrows_y + row_y; + + temp[c] = fma(xi, FLOAT_TYPE(data_b[iy]), temp[c]); + } + } + } + +#if USE_SUBGROUP_ADD + // reduce vec4 at a time + vec4 t = vec4(temp[0], temp[1], temp[2], temp[3]); + t = subgroupAdd(t); + temp[0] = t[0]; + temp[1] = t[1]; + temp[2] = t[2]; + temp[3] = t[3]; + if (gqa_ratio > 4) { + t = vec4(temp[4], temp[5], temp[6], temp[7]); + t = subgroupAdd(t); + temp[4] = t[0]; + temp[5] = t[1]; + temp[6] = t[2]; + temp[7] = t[3]; + } +#else + [[unroll]] for (uint c = 0; c < gqa_ratio; ++c) { + tmp[c][tid] = temp[c]; + } // sum up partial sums and write back result barrier(); 
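+    // Tree reduction over shared memory: each step halves the active threads, and each step reduces all gqa_ratio accumulators.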
[[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) { if (tid < s) { - tmp[tid] += tmp[tid + s]; + [[unroll]] for (uint c = 0; c < gqa_ratio; ++c) { + temp[c] += tmp[c][tid + s]; + tmp[c][tid] = temp[c]; + } } barrier(); } + [[unroll]] for (uint c = 0; c < gqa_ratio; ++c) { + temp[c] = tmp[c][tid]; + } +#endif if (tid == 0) { - dst[idst] = tmp[0]; + [[unroll]] for (uint c = 0; c < gqa_ratio; ++c) { + // dst is not transposed and not permuted + const uint idst = (channel + c)*nrows_dst + row_dst; + dst[idst] = temp[c]; + } } } diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp index 8cdc640e8..423ceb8a3 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp @@ -5,23 +5,24 @@ layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; -shared FLOAT_TYPE sccache1[BLOCK_SIZE/16][16]; -shared FLOAT_TYPE sccache2[BLOCK_SIZE/16][16]; +shared FLOAT_TYPE sccache1[2][BLOCK_SIZE/16][16]; +shared FLOAT_TYPE sccache2[2][BLOCK_SIZE/16][16]; FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; +uint csel = 0; void calc_superblock(const uint a_offset, const uint b_offset, const uint itid, const uint v_im, const uint ix, const uint q_offset, const uint y_offset, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows, const bool all_threads) { const uint y_idx = i * QUANT_K + y_offset; [[unroll]] for (uint n = 0; n < num_rows; ++n) { const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row; + csel ^= 1; - barrier(); if (!all_threads) { // when we don't have enough blocks to use all threads if (i < num_blocks_per_row) { const uint32_t scale = uint32_t(data_a[ib0 + i].scales[itid]); - sccache1[ix][itid] = FLOAT_TYPE(scale & 0xF); - sccache2[ix][itid] = FLOAT_TYPE((scale >> 4) & 0xF); + sccache1[csel][ix][itid] = FLOAT_TYPE(scale & 0xF); + sccache2[csel][ix][itid] = FLOAT_TYPE((scale >> 4) & 0xF); } barrier(); @@ -29,8 +30,8 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint itid, continue; } else { const uint32_t scale = uint32_t(data_a[ib0 + i].scales[itid]); - sccache1[ix][itid] = FLOAT_TYPE(scale & 0xF); - sccache2[ix][itid] = FLOAT_TYPE((scale >> 4) & 0xF); + sccache1[csel][ix][itid] = FLOAT_TYPE(scale & 0xF); + sccache2[csel][ix][itid] = FLOAT_TYPE((scale >> 4) & 0xF); barrier(); } @@ -57,22 +58,22 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint itid, FLOAT_TYPE sum1 = FLOAT_TYPE(0.0); FLOAT_TYPE sum2 = FLOAT_TYPE(0.0); [[unroll]] for (int l = 0; l < 2; ++l) { - sum1 = fma(FLOAT_TYPE(b0[l]), sccache1[ix][ 8*v_im] * qs_u32_0[l ], - fma(FLOAT_TYPE(b16[l]), sccache1[ix][1 + 8*v_im] * qs_u32_0[l+2], - fma(FLOAT_TYPE(b32[l]), sccache1[ix][2 + 8*v_im] * qs_u32_2[l ], - fma(FLOAT_TYPE(b48[l]), sccache1[ix][3 + 8*v_im] * qs_u32_2[l+2], - fma(FLOAT_TYPE(b64[l]), sccache1[ix][4 + 8*v_im] * qs_u32_4[l ], - fma(FLOAT_TYPE(b80[l]), sccache1[ix][5 + 8*v_im] * qs_u32_4[l+2], - fma(FLOAT_TYPE(b96[l]), sccache1[ix][6 + 8*v_im] * qs_u32_6[l ], - fma(FLOAT_TYPE(b112[l]), sccache1[ix][7 + 8*v_im] * qs_u32_6[l+2], sum1)))))))); - sum2 = fma(FLOAT_TYPE(b0[l]), sccache2[ix][ 8*v_im], - fma(FLOAT_TYPE(b16[l]), sccache2[ix][1 + 8*v_im], - fma(FLOAT_TYPE(b32[l]), sccache2[ix][2 + 8*v_im], - fma(FLOAT_TYPE(b48[l]), sccache2[ix][3 + 8*v_im], - fma(FLOAT_TYPE(b64[l]), sccache2[ix][4 + 8*v_im], - 
fma(FLOAT_TYPE(b80[l]), sccache2[ix][5 + 8*v_im], - fma(FLOAT_TYPE(b96[l]), sccache2[ix][6 + 8*v_im], - fma(FLOAT_TYPE(b112[l]), sccache2[ix][7 + 8*v_im], sum2)))))))); + sum1 = fma(FLOAT_TYPE(b0[l]), sccache1[csel][ix][ 8*v_im] * qs_u32_0[l ], + fma(FLOAT_TYPE(b16[l]), sccache1[csel][ix][1 + 8*v_im] * qs_u32_0[l+2], + fma(FLOAT_TYPE(b32[l]), sccache1[csel][ix][2 + 8*v_im] * qs_u32_2[l ], + fma(FLOAT_TYPE(b48[l]), sccache1[csel][ix][3 + 8*v_im] * qs_u32_2[l+2], + fma(FLOAT_TYPE(b64[l]), sccache1[csel][ix][4 + 8*v_im] * qs_u32_4[l ], + fma(FLOAT_TYPE(b80[l]), sccache1[csel][ix][5 + 8*v_im] * qs_u32_4[l+2], + fma(FLOAT_TYPE(b96[l]), sccache1[csel][ix][6 + 8*v_im] * qs_u32_6[l ], + fma(FLOAT_TYPE(b112[l]), sccache1[csel][ix][7 + 8*v_im] * qs_u32_6[l+2], sum1)))))))); + sum2 = fma(FLOAT_TYPE(b0[l]), sccache2[csel][ix][ 8*v_im], + fma(FLOAT_TYPE(b16[l]), sccache2[csel][ix][1 + 8*v_im], + fma(FLOAT_TYPE(b32[l]), sccache2[csel][ix][2 + 8*v_im], + fma(FLOAT_TYPE(b48[l]), sccache2[csel][ix][3 + 8*v_im], + fma(FLOAT_TYPE(b64[l]), sccache2[csel][ix][4 + 8*v_im], + fma(FLOAT_TYPE(b80[l]), sccache2[csel][ix][5 + 8*v_im], + fma(FLOAT_TYPE(b96[l]), sccache2[csel][ix][6 + 8*v_im], + fma(FLOAT_TYPE(b112[l]), sccache2[csel][ix][7 + 8*v_im], sum2)))))))); } temp[j][n] = fma(dall, sum1, fma(-dmin, sum2, temp[j][n])); } diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp index 3116fad16..e91724a28 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp @@ -5,20 +5,21 @@ layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; -shared FLOAT_TYPE sccache[BLOCK_SIZE/16][2][8]; +shared FLOAT_TYPE sccache[2][BLOCK_SIZE/16][2][8]; FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; +uint csel = 0; void calc_superblock(const uint a_offset, const uint b_offset, const uint ix, const uint itid8, const uint v_im, const uint v_im4, const uint v_in, const uint32_t hm_m[4], const uint q_offset, const uint y_offset, const uint s_shift, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows, const bool all_threads) { const uint y_idx = i * QUANT_K + y_offset; [[unroll]] for (uint n = 0; n < num_rows; ++n) { const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row; + csel ^= 1; if (!all_threads) { // when we don't have enough blocks to use all threads - barrier(); if (i < num_blocks_per_row) - sccache[ix][v_im][itid8] = FLOAT_TYPE(int8_t(((data_a[ib0+i].scales[itid8] >> v_im4) & 0xF) | (((data_a[ib0+i].scales[itid8%4+8] >> s_shift) & 3) << 4)) - 32); + sccache[csel][ix][v_im][itid8] = FLOAT_TYPE(int8_t(((data_a[ib0+i].scales[itid8] >> v_im4) & 0xF) | (((data_a[ib0+i].scales[itid8%4+8] >> s_shift) & 3) << 4)) - 32); barrier(); if (i >= num_blocks_per_row) @@ -40,8 +41,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint ix, co const vec4 qs_u32_6 = vec4(unpack8((qs_u32 >> 6) & 0x03030303)); if (all_threads) { - barrier(); - sccache[ix][v_im][itid8] = FLOAT_TYPE(int8_t(((data_a[ib0+i].scales[itid8] >> v_im4) & 0xF) | (((data_a[ib0+i].scales[itid8%4+8] >> s_shift) & 3) << 4)) - 32); + sccache[csel][ix][v_im][itid8] = FLOAT_TYPE(int8_t(((data_a[ib0+i].scales[itid8] >> v_im4) & 0xF) | (((data_a[ib0+i].scales[itid8%4+8] >> s_shift) & 3) << 4)) - 32); barrier(); } @@ -59,14 +59,14 @@ void calc_superblock(const uint a_offset, const uint b_offset, 
const uint ix, co FLOAT_TYPE sum = FLOAT_TYPE(0.0); [[unroll]] for (int l = 0; l < 2; ++l) { - sum = fma(FLOAT_TYPE( b0[l]) * sccache[ix][v_im][0], qs_u32_0[l ] - hmk_0[l ], - fma(FLOAT_TYPE( b16[l]) * sccache[ix][v_im][1], qs_u32_0[l+2] - hmk_0[l+2], - fma(FLOAT_TYPE( b32[l]) * sccache[ix][v_im][2], qs_u32_2[l ] - hmk_1[l ], - fma(FLOAT_TYPE( b48[l]) * sccache[ix][v_im][3], qs_u32_2[l+2] - hmk_1[l+2], - fma(FLOAT_TYPE( b64[l]) * sccache[ix][v_im][4], qs_u32_4[l ] - hmk_2[l ], - fma(FLOAT_TYPE( b80[l]) * sccache[ix][v_im][5], qs_u32_4[l+2] - hmk_2[l+2], - fma(FLOAT_TYPE( b96[l]) * sccache[ix][v_im][6], qs_u32_6[l ] - hmk_3[l ], - fma(FLOAT_TYPE(b112[l]) * sccache[ix][v_im][7], qs_u32_6[l+2] - hmk_3[l+2], sum)))))))); + sum = fma(FLOAT_TYPE( b0[l]) * sccache[csel][ix][v_im][0], qs_u32_0[l ] - hmk_0[l ], + fma(FLOAT_TYPE( b16[l]) * sccache[csel][ix][v_im][1], qs_u32_0[l+2] - hmk_0[l+2], + fma(FLOAT_TYPE( b32[l]) * sccache[csel][ix][v_im][2], qs_u32_2[l ] - hmk_1[l ], + fma(FLOAT_TYPE( b48[l]) * sccache[csel][ix][v_im][3], qs_u32_2[l+2] - hmk_1[l+2], + fma(FLOAT_TYPE( b64[l]) * sccache[csel][ix][v_im][4], qs_u32_4[l ] - hmk_2[l ], + fma(FLOAT_TYPE( b80[l]) * sccache[csel][ix][v_im][5], qs_u32_4[l+2] - hmk_2[l+2], + fma(FLOAT_TYPE( b96[l]) * sccache[csel][ix][v_im][6], qs_u32_6[l ] - hmk_3[l ], + fma(FLOAT_TYPE(b112[l]) * sccache[csel][ix][v_im][7], qs_u32_6[l+2] - hmk_3[l+2], sum)))))))); } temp[j][n] = fma(d, sum, temp[j][n]); } diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp index f05f96b5e..d53d9ee0a 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp @@ -6,20 +6,21 @@ layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; -shared FLOAT_TYPE sccache[BLOCK_SIZE/16][16]; +shared FLOAT_TYPE sccache[2][BLOCK_SIZE/16][16]; FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; +uint csel = 0; void calc_superblock(const uint a_offset, const uint b_offset, const uint itid, const uint ix, const uint ql_offset, const uint qh_offset, const uint s_offset, const uint y_offset, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows, const bool all_threads) { const uint y_idx = i * QUANT_K + y_offset; [[unroll]] for (uint n = 0; n < num_rows; ++n) { const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row; + csel ^= 1; if (!all_threads) { // when we don't have enough blocks to use all threads - barrier(); if (i < num_blocks_per_row) - sccache[ix][itid] = FLOAT_TYPE(data_a[ib0 + i].scales[itid]); + sccache[csel][ix][itid] = FLOAT_TYPE(data_a[ib0 + i].scales[itid]); barrier(); if (i >= num_blocks_per_row) @@ -51,8 +52,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint itid, const vec4 q3 = vec4(unpack8(q3_u32)) - 32; if (all_threads) { - barrier(); - sccache[ix][itid] = FLOAT_TYPE(data_a[ib0 + i].scales[itid]); + sccache[csel][ix][itid] = FLOAT_TYPE(data_a[ib0 + i].scales[itid]); barrier(); } @@ -71,7 +71,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint itid, sum[2] = fma(FLOAT_TYPE(by64[l]), q2[l], sum[2]); sum[3] = fma(FLOAT_TYPE(by96[l]), q3[l], sum[3]); } - temp[j][n] = fma(fma(sum[0], sccache[ix][s_offset], fma(sum[1], sccache[ix][s_offset + 2], fma(sum[2], sccache[ix][s_offset + 4], sum[3] * sccache[ix][s_offset + 6]))), d, temp[j][n]); + temp[j][n] = fma(fma(sum[0], 
sccache[csel][ix][s_offset], fma(sum[1], sccache[csel][ix][s_offset + 2], fma(sum[2], sccache[csel][ix][s_offset + 4], sum[3] * sccache[csel][ix][s_offset + 6]))), d, temp[j][n]); } } } diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp index 39657195c..7859a1a60 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp @@ -10,6 +10,10 @@ #extension GL_EXT_shader_explicit_arithmetic_types_float16 : require #endif +#if defined(DATA_A_BF16) && defined(COOPMAT) +#extension GL_EXT_bfloat16 : enable +#endif + #ifdef COOPMAT #extension GL_KHR_cooperative_matrix : enable #extension GL_KHR_memory_scope_semantics : enable @@ -29,9 +33,20 @@ #define LOAD_VEC_B 1 #endif +#if !defined(TO_FLOAT_TYPE) +#define TO_FLOAT_TYPE FLOAT_TYPE +#endif + layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +#if defined(A_TYPE_PACKED16) +layout (binding = 0) readonly buffer A_PACKED16 {A_TYPE_PACKED16 data_a_packed16[];}; +#endif +#if defined(A_TYPE_PACKED32) +layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32[];}; +#endif + layout (binding = 1) readonly buffer B {B_TYPE data_b[];}; layout (binding = 2) writeonly buffer D {D_TYPE data_d[];}; @@ -88,7 +103,7 @@ shared FLOAT_TYPE buf_a[BM * SHMEM_STRIDE]; shared FLOAT_TYPE buf_b[BN * SHMEM_STRIDE]; #ifdef MUL_MAT_ID -shared u16vec2 row_ids[3072]; +shared u16vec2 row_ids[4096]; #endif // MUL_MAT_ID #define NUM_WARPS (BLOCK_SIZE / WARP) @@ -195,8 +210,8 @@ void main() { #endif #ifdef COOPMAT - coopmat cache_a; - coopmat cache_b; + coopmat cache_a; + coopmat cache_b; coopmat sums[cms_per_row * cms_per_col]; [[unroll]] for (uint i = 0; i < cms_per_row * cms_per_col; i++) { @@ -205,7 +220,7 @@ void main() { #else ACC_TYPE sums[WMITER * TM * WNITER * TN]; FLOAT_TYPE cache_a[WMITER * TM]; - FLOAT_TYPE cache_b[WNITER * TN]; + FLOAT_TYPE cache_b[TN]; [[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN; i++) { sums[i] = ACC_TYPE(0.0f); @@ -241,76 +256,117 @@ void main() { buf_a[(loadc_a + l) * SHMEM_STRIDE + loadr_a] = FLOAT_TYPE(0.0f); } #endif +#elif defined(DATA_A_BF16) +#if LOAD_VEC_A == 4 + const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; + const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; + buf_a[buf_idx ] = TO_FLOAT_TYPE(data_a[idx].x); + buf_a[buf_idx + 1] = TO_FLOAT_TYPE(data_a[idx].y); + buf_a[buf_idx + 2] = TO_FLOAT_TYPE(data_a[idx].z); + buf_a[buf_idx + 3] = TO_FLOAT_TYPE(data_a[idx].w); +#else + if (ir * BM + loadc_a + l < p.M && block + loadr_a < end_k) { + buf_a[(loadc_a + l) * SHMEM_STRIDE + loadr_a] = TO_FLOAT_TYPE(data_a[pos_a + (loadc_a + l) * p.stride_a + loadr_a]); + } else { + buf_a[(loadc_a + l) * SHMEM_STRIDE + loadr_a] = TO_FLOAT_TYPE(uint16_t(0)); + } +#endif #elif defined(DATA_A_Q4_0) const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; - const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a; + const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + 4 * loadr_a; - const uint ib = idx / 16; - const uint iqs = idx & 0xF; + const uint ib = idx / 4; + const uint iqs = idx & 0x03; - const float d = float(data_a[ib].d); - const uint vui = uint(data_a[ib].qs[iqs]); - const vec2 v = (vec2(vui & 0xF, vui >> 4) - 8.0f) * d; + const float d = float(data_a_packed16[ib].d); + const uint vui = 
uint(data_a_packed16[ib].qs[2*iqs]) | (uint(data_a_packed16[ib].qs[2*iqs + 1]) << 16); + const vec4 v0 = (vec4(unpack8(vui & 0x0F0F0F0F)) - 8.0f) * d; + const vec4 v1 = (vec4(unpack8((vui >> 4) & 0x0F0F0F0F)) - 8.0f) * d; - buf_a[buf_idx ] = FLOAT_TYPE(v.x); - buf_a[buf_idx + 16] = FLOAT_TYPE(v.y); + buf_a[buf_idx ] = FLOAT_TYPE(v0.x); + buf_a[buf_idx + 1 ] = FLOAT_TYPE(v0.y); + buf_a[buf_idx + 2 ] = FLOAT_TYPE(v0.z); + buf_a[buf_idx + 3 ] = FLOAT_TYPE(v0.w); + buf_a[buf_idx + 16] = FLOAT_TYPE(v1.x); + buf_a[buf_idx + 17] = FLOAT_TYPE(v1.y); + buf_a[buf_idx + 18] = FLOAT_TYPE(v1.z); + buf_a[buf_idx + 19] = FLOAT_TYPE(v1.w); #elif defined(DATA_A_Q4_1) const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; - const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a; + const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + 4 * loadr_a; - const uint ib = idx / 16; - const uint iqs = idx & 0xF; + const uint ib = idx / 4; + const uint iqs = idx & 0x03; - const float d = float(data_a[ib].d); - const float m = float(data_a[ib].m); - const uint vui = uint(data_a[ib].qs[iqs]); - const vec2 v = vec2(vui & 0xF, vui >> 4) * d + m; + const float d = float(data_a_packed16[ib].d); + const float m = float(data_a_packed16[ib].m); + const uint vui = uint(data_a_packed16[ib].qs[2*iqs]) | (uint(data_a_packed16[ib].qs[2*iqs + 1]) << 16); + const vec4 v0 = vec4(unpack8(vui & 0x0F0F0F0F)) * d + m; + const vec4 v1 = vec4(unpack8((vui >> 4) & 0x0F0F0F0F)) * d + m; - buf_a[buf_idx ] = FLOAT_TYPE(v.x); - buf_a[buf_idx + 16] = FLOAT_TYPE(v.y); + buf_a[buf_idx ] = FLOAT_TYPE(v0.x); + buf_a[buf_idx + 1 ] = FLOAT_TYPE(v0.y); + buf_a[buf_idx + 2 ] = FLOAT_TYPE(v0.z); + buf_a[buf_idx + 3 ] = FLOAT_TYPE(v0.w); + buf_a[buf_idx + 16] = FLOAT_TYPE(v1.x); + buf_a[buf_idx + 17] = FLOAT_TYPE(v1.y); + buf_a[buf_idx + 18] = FLOAT_TYPE(v1.z); + buf_a[buf_idx + 19] = FLOAT_TYPE(v1.w); #elif defined(DATA_A_Q5_0) const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; - const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a; + const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + 2 * loadr_a; - const uint ib = idx / 16; - const uint iqs = idx & 0xF; + const uint ib = idx / 8; + const uint iqs = idx & 0x07; - const float d = float(data_a[ib].d); - const uint uint_qh = uint(data_a[ib].qh[1]) << 16 | data_a[ib].qh[0]; - const ivec2 qh = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10); - const uint vui = uint(data_a[ib].qs[iqs]); - const vec2 v = (vec2((vui & 0xF) | qh.x, (vui >> 4) | qh.y) - 16.0f) * d; + const float d = float(data_a_packed16[ib].d); + const uint uint_qh = uint(data_a_packed16[ib].qh[1]) << 16 | uint(data_a_packed16[ib].qh[0]); + const ivec2 qh0 = ivec2(((uint_qh >> 2*iqs) << 4) & 0x10, (uint_qh >> (2*iqs + 12)) & 0x10); + const ivec2 qh1 = ivec2(((uint_qh >> (2*iqs + 1)) << 4) & 0x10, (uint_qh >> (2*iqs + 13)) & 0x10); + + const uint vui = uint(data_a_packed16[ib].qs[iqs]); + const vec4 v = (vec4((vui & 0xF) | qh0.x, ((vui >> 4) & 0xF) | qh0.y, ((vui >> 8) & 0xF) | qh1.x, (vui >> 12) | qh1.y) - 16.0f) * d; buf_a[buf_idx ] = FLOAT_TYPE(v.x); + buf_a[buf_idx + 1 ] = FLOAT_TYPE(v.z); buf_a[buf_idx + 16] = FLOAT_TYPE(v.y); + buf_a[buf_idx + 17] = FLOAT_TYPE(v.w); #elif defined(DATA_A_Q5_1) const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; - const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a; + const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + 2 * loadr_a; - const uint ib = idx / 16; - const uint iqs = idx & 0xF; + const uint 
ib = idx / 8; + const uint iqs = idx & 0x07; - const float d = float(data_a[ib].d); - const float m = float(data_a[ib].m); - const uint uint_qh = data_a[ib].qh; - const ivec2 qh = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10); - const uint vui = uint(data_a[ib].qs[iqs]); - const vec2 v = vec2((vui & 0xF) | qh.x, (vui >> 4) | qh.y) * d + m; + const float d = float(data_a_packed16[ib].d); + const float m = float(data_a_packed16[ib].m); + const uint uint_qh = data_a_packed16[ib].qh; + const ivec2 qh0 = ivec2(((uint_qh >> 2*iqs) << 4) & 0x10, (uint_qh >> (2*iqs + 12)) & 0x10); + const ivec2 qh1 = ivec2(((uint_qh >> (2*iqs + 1)) << 4) & 0x10, (uint_qh >> (2*iqs + 13)) & 0x10); + + const uint vui = uint(data_a_packed16[ib].qs[iqs]); + const vec4 v = vec4((vui & 0xF) | qh0.x, ((vui >> 4) & 0xF) | qh0.y, ((vui >> 8) & 0xF) | qh1.x, (vui >> 12) | qh1.y) * d + m; buf_a[buf_idx ] = FLOAT_TYPE(v.x); + buf_a[buf_idx + 1 ] = FLOAT_TYPE(v.z); buf_a[buf_idx + 16] = FLOAT_TYPE(v.y); + buf_a[buf_idx + 17] = FLOAT_TYPE(v.w); #elif defined(DATA_A_Q8_0) const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; - const uint ib = idx / 16; - const uint iqs = (idx & 0xF) * 2; + const uint ib = idx / 8; + const uint iqs = idx & 0x07; - const float d = float(data_a[ib].d); - const vec2 v = vec2(int(data_a[ib].qs[iqs]), int(data_a[ib].qs[iqs + 1])) * d; + const float d = float(data_a_packed16[ib].d); + const i8vec2 v0 = unpack8(int32_t(data_a_packed16[ib].qs[2*iqs])).xy; // vec4 used due to #12147 + const i8vec2 v1 = unpack8(int32_t(data_a_packed16[ib].qs[2*iqs + 1])).xy; + const vec4 v = vec4(v0.x, v0.y, v1.x, v1.y) * d; buf_a[buf_idx ] = FLOAT_TYPE(v.x); buf_a[buf_idx + 1] = FLOAT_TYPE(v.y); + buf_a[buf_idx + 2] = FLOAT_TYPE(v.z); + buf_a[buf_idx + 3] = FLOAT_TYPE(v.w); #elif defined(DATA_A_Q2_K) const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; @@ -511,7 +567,7 @@ void main() { const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (2 * (idx % 4)); const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign)))); const uint grid = iq2xxs_grid[qs][(idx % 4) / 2] >> (16 * (idx & 1)); - const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); + const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); // vec4 used due to #12147 buf_a[buf_idx ] = FLOAT_TYPE(v.x); buf_a[buf_idx + 1] = FLOAT_TYPE(v.y); @@ -531,7 +587,7 @@ void main() { const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (2 * (idx % 4)); const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign)))); const uint grid = iq2xs_grid[qs & 511][(idx % 4) / 2] >> (16 * (idx & 1)); - const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); + const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); // vec4 used due to #12147 buf_a[buf_idx ] = FLOAT_TYPE(v.x); buf_a[buf_idx + 1] = FLOAT_TYPE(v.y); @@ -553,7 +609,7 @@ void main() { const float db = d * 0.25 * (0.5 + scale); const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign)))); const uint16_t grid = unpack16(iq2s_grid[qs | ((qh << (8 - qhshift)) & 0x300)][(idx & 2) >> 1])[idx & 1]; - const vec2 v = db * vec2(sign01) * vec2(unpack8(grid)); + const vec2 v = db * vec2(sign01) * vec2(unpack8(uint32_t(grid)).xy); // vec4 used due to #12147 buf_a[buf_idx ] = FLOAT_TYPE(v.x); buf_a[buf_idx + 1] = FLOAT_TYPE(v.y); @@ -578,7 +634,7 @@ void main() 
{ const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (2 * (idx % 4)); const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign)))); const uint grid = iq3xxs_grid[qs] >> (16 * (idx & 1)); - const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); + const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); // vec4 used due to #12147 buf_a[buf_idx ] = FLOAT_TYPE(v.x); buf_a[buf_idx + 1] = FLOAT_TYPE(v.y); @@ -598,7 +654,7 @@ void main() { const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(sign << 1, sign))); const float db = d * (1 + 2 * ((scale >> (4 * (iqh & 1))) & 0xf)); const uint32_t grid = iq3s_grid[qs | ((qh << (8 - (iqs % 8))) & 256)] >> (16 * (idx % 2)); - const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); + const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); // vec4 used due to #12147 buf_a[buf_idx ] = FLOAT_TYPE(v.x); buf_a[buf_idx + 1] = FLOAT_TYPE(v.y); @@ -623,17 +679,18 @@ void main() { buf_a[buf_idx + 1] = FLOAT_TYPE(v.y); #elif defined(DATA_A_IQ4_NL) const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; - const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a; + const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + 2 * loadr_a; - const uint ib = idx / 16; - const uint iqs = idx & 0xF; + const uint ib = idx / 8; + const uint iqs = idx & 0x07; - const float d = float(data_a[ib].d); - const uint vui = uint(data_a[ib].qs[iqs]); - const vec2 v = vec2(kvalues_iq4nl[vui & 0xF], kvalues_iq4nl[vui >> 4]) * d; + const FLOAT_TYPE d = FLOAT_TYPE(data_a_packed16[ib].d); + const uint vui = uint(data_a_packed16[ib].qs[iqs]); - buf_a[buf_idx ] = FLOAT_TYPE(v.x); - buf_a[buf_idx + 16] = FLOAT_TYPE(v.y); + buf_a[buf_idx ] = FLOAT_TYPE(kvalues_iq4nl[vui & 0xF]) * d; + buf_a[buf_idx + 1 ] = FLOAT_TYPE(kvalues_iq4nl[bitfieldExtract(vui, 8, 4)]) * d; + buf_a[buf_idx + 16] = FLOAT_TYPE(kvalues_iq4nl[bitfieldExtract(vui, 4, 4)]) * d; + buf_a[buf_idx + 17] = FLOAT_TYPE(kvalues_iq4nl[vui >> 12]) * d; #endif } [[unroll]] for (uint l = 0; l < BN; l += loadstride_b) { @@ -661,13 +718,13 @@ void main() { const uint idx = pos_b + (loadc_b + l) * p.stride_b / LOAD_VEC_B + loadr_b; #endif const uint buf_idx = (loadc_b + l) * SHMEM_STRIDE + loadr_b * LOAD_VEC_B; - buf_b[buf_idx + 0] = FLOAT_TYPE(data_b[idx].x); - buf_b[buf_idx + 1] = FLOAT_TYPE(data_b[idx].y); - buf_b[buf_idx + 2] = FLOAT_TYPE(data_b[idx].z); - buf_b[buf_idx + 3] = FLOAT_TYPE(data_b[idx].w); + buf_b[buf_idx + 0] = TO_FLOAT_TYPE(data_b[idx].x); + buf_b[buf_idx + 1] = TO_FLOAT_TYPE(data_b[idx].y); + buf_b[buf_idx + 2] = TO_FLOAT_TYPE(data_b[idx].z); + buf_b[buf_idx + 3] = TO_FLOAT_TYPE(data_b[idx].w); #elif !MUL_MAT_ID if (ic * BN + loadc_b + l < p.N && block + loadr_b < end_k) { - buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = FLOAT_TYPE(data_b[pos_b + (loadc_b + l) * p.stride_b + loadr_b]); + buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = TO_FLOAT_TYPE(data_b[pos_b + (loadc_b + l) * p.stride_b + loadr_b]); } else { buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = FLOAT_TYPE(0.0f); } @@ -675,7 +732,7 @@ void main() { const uint row_i = ic * BN + loadc_b + l; if (row_i < _ne1) { const u16vec2 row_idx = row_ids[row_i]; - buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = FLOAT_TYPE(data_b[pos_b + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + loadr_b]); + buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = TO_FLOAT_TYPE(data_b[pos_b + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + loadr_b]); } else { buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = 
FLOAT_TYPE(0.0f); } @@ -710,16 +767,14 @@ void main() { } [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) { [[unroll]] for (uint j = 0; j < TN; j++) { - cache_b[wsic * TN + j] = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + j) * SHMEM_STRIDE + i]; + cache_b[j] = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + j) * SHMEM_STRIDE + i]; } - } - [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) { [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) { [[unroll]] for (uint cc = 0; cc < TN; cc++) { [[unroll]] for (uint cr = 0; cr < TM; cr++) { const uint sums_idx = (wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr; - sums[sums_idx] = fma(ACC_TYPE(cache_a[wsir * TM + cr]), ACC_TYPE(cache_b[wsic * TN + cc]), sums[sums_idx]); + sums[sums_idx] = fma(ACC_TYPE(cache_a[wsir * TM + cr]), ACC_TYPE(cache_b[cc]), sums[sums_idx]); } } } @@ -743,7 +798,7 @@ void main() { [[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) { coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor); - [[unroll]] for (uint col = 0; col < BN; col += storestride) { + [[unroll]] for (uint col = 0; col < TN; col += storestride) { const uint row_i = dc + cm_col * TN + col + store_c; if (row_i >= _ne1) break; diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp index 66dd2c860..918465757 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp @@ -14,15 +14,25 @@ #extension GL_EXT_buffer_reference : enable #extension GL_KHR_shader_subgroup_ballot : enable #extension GL_KHR_shader_subgroup_vote : enable +#ifdef DATA_A_BF16 +#extension GL_EXT_bfloat16 : enable +#endif #include "types.comp" layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; +#define IS_MUL_MM2 1 + +layout (constant_id = 0) const uint BLOCK_SIZE = 256; layout (constant_id = 1) const uint BM = 64; layout (constant_id = 2) const uint BN = 64; layout (constant_id = 3) const uint BK = 16; // Assumed to be 32 if working with a quant +layout (constant_id = 4) const bool enable_smaller_matrices = false; +const uint BNover2 = enable_smaller_matrices ? (BN / 2) : BN; +const uint BNover4 = enable_smaller_matrices ? 
(BN / 4) : BN; + layout (push_constant) uniform parameter { uint M; @@ -48,6 +58,8 @@ layout (push_constant) uniform parameter uint broadcast2; uint broadcast3; #endif + // N dimension for the B matrix can be >= p.N + uint padded_N; } p; @@ -64,10 +76,23 @@ layout (binding = 2) writeonly buffer D {D_TYPE data_d[];}; #define DECODEFUNCA #endif +#if !defined(fetch_scales) +#define fetch_scales(a, b, c, d, e, f) +#endif +#if !defined(store_scales) +#define store_scales(a) +#endif + +#if defined(DATA_A_BF16) +#define MAT_TYPE bfloat16_t +#else +#define MAT_TYPE FLOAT_TYPE +#endif + #ifdef MUL_MAT_ID layout (binding = 3) readonly buffer IDS {int data_ids[];}; -shared u16vec4 row_ids[3072]; +shared u16vec4 row_ids[4096]; layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufB { B_TYPE b[]; @@ -110,6 +135,8 @@ void main() { init_iq_shmem(gl_WorkGroupSize); #endif + const uint tid = gl_LocalInvocationIndex; + #ifdef MUL_MAT_ID const uint expert_idx = gl_GlobalInvocationID.z; #else @@ -166,15 +193,13 @@ void main() { const uint end_k = min(p.K, (ik + 1) * p.k_split); #endif - coopmat sum; - sum = coopmat(0.0); - #ifdef MUL_MAT_ID uint pos_a = (expert_idx * p.batch_stride_a) / QUANT_K; uint pos_b = 0; #else uint pos_a = (batch_idx_a * p.batch_stride_a) / QUANT_K; uint pos_b = batch_idx * p.batch_stride_b; + uint pos_d = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z; #endif uint stride_a = p.stride_a / QUANT_K; @@ -195,6 +220,7 @@ void main() { tensorLayoutNV<2> tensorLayoutB = createTensorLayoutNV(2); tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutBClamp = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV); tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutD = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV); + tensorLayoutD = setTensorLayoutStrideNV(tensorLayoutD, p.stride_d, 1); #if QUANT_K > 1 tensorLayoutA = setTensorLayoutBlockSizeNV(tensorLayoutA, 1, QUANT_K); @@ -202,24 +228,32 @@ void main() { #endif // Use end_k rather than p.K as the dimension because that's what - // we need to bound check against when using split_k + // we need to bound check against when using split_k. + // Bounds check B against padded_N, but bounds check D against N. tensorLayoutA = setTensorLayoutDimensionNV(tensorLayoutA, p.M, end_k); - tensorLayoutB = setTensorLayoutDimensionNV(tensorLayoutB, p.N, end_k); + tensorLayoutB = setTensorLayoutDimensionNV(tensorLayoutB, p.padded_N, end_k); tensorLayoutD = setTensorLayoutDimensionNV(tensorLayoutD, p.N, p.M); tensorLayoutAClamp = setTensorLayoutDimensionNV(tensorLayoutAClamp, p.M, end_k); - tensorLayoutBClamp = setTensorLayoutDimensionNV(tensorLayoutBClamp, p.N, end_k); + tensorLayoutBClamp = setTensorLayoutDimensionNV(tensorLayoutBClamp, p.padded_N, end_k); tensorViewNV<2, false, 1, 0> tensorViewTranspose = createTensorViewNV(2, false, 1, 0); #if !defined(MUL_MAT_ID) + + const uint START_ALIGN_K = 256; + // For Qi_K (block size 256), unroll whole 256 element tiles. + // For legacy quants (block size 32), unroll 8x. + const uint UNROLL_K = (QUANT_K == 256) ? 
256 : (BK * 8); + const uint unroll_count = UNROLL_K / BK; + // Detect a fast path where all loads are entirely in bounds and no clamping is required - if ((ir + 1) * BM <= p.M && (ic + 1) * BN <= p.N && (start_k % BK) == 0 && (end_k % BK) == 0 && + if ((ir + 1) * BM <= p.M && (ic + 1) * BN <= p.padded_N && (start_k % START_ALIGN_K) == 0 && (end_k % BK) == 0 && #if QUANT_K == 1 (stride_a % 8) == 0 && #endif - (stride_b % 8) == 0 && (start_k % 8) == 0) { + (stride_b % 8) == 0) { // Hint to the compiler that values are aligned (want 16B alignment) - start_k &= ~7; + start_k &= ~(START_ALIGN_K-1); stride_b &= ~7; #if QUANT_K == 1 stride_a &= ~7; @@ -228,17 +262,131 @@ void main() { tensorLayoutA = setTensorLayoutStrideNV(tensorLayoutA, stride_a, 1); tensorLayoutB = setTensorLayoutStrideNV(tensorLayoutB, stride_b, 1); - uint k_iters = (end_k - start_k + BK - 1) / BK; + uint k_iters = (end_k - start_k) / UNROLL_K; + uint block_k = start_k; - for (uint block_k = start_k, i = 0; i < k_iters; block_k += BK, ++i) { + // fetch scale values for a tile of quants. These will be copied into shared memory. + // The fetches and stores are pipelined to hide the latency. + fetch_scales(ir * BM, pos_a, stride_a, start_k, tid, true); - coopmat mat_a; - coopmat mat_b; + if (enable_smaller_matrices && ic * BN + BNover4 >= p.N) { + coopmat sum = coopmat(0.0); + for (uint i = 0; i < k_iters; ++i) { - coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA); - coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose); + store_scales(tid); + if (block_k + UNROLL_K < end_k) { + fetch_scales(ir * BM, pos_a, stride_a, block_k + UNROLL_K, tid, true); + } - sum = coopMatMulAdd(mat_a, mat_b, sum); + // Manually partial unroll + [[unroll]] for (uint j = 0; j < unroll_count; ++j) { + coopmat mat_a; + coopmat mat_b; + + coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA); + coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover4, block_k, BK), tensorViewTranspose); + + sum = coopMatMulAdd(mat_a, mat_b, sum); + block_k += BK; + } + } + // Do any remaining iterations that were not unrolled + if (block_k < end_k) { + store_scales(tid); + } + while (block_k < end_k) { + coopmat mat_a; + coopmat mat_b; + + coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA); + coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover4, block_k, BK), tensorViewTranspose); + + sum = coopMatMulAdd(mat_a, mat_b, sum); + block_k += BK; + } + coopmat mat_d = coopmat(sum); + + coopMatStoreTensorNV(mat_d, data_d, pos_d, sliceTensorLayoutNV(tensorLayoutD, ic * BN, BNover4, ir * BM, BM), tensorViewTranspose); + return; + } else if (enable_smaller_matrices && ic * BN + BNover2 >= p.N) { + coopmat sum = coopmat(0.0); + for (uint i = 0; i < k_iters; ++i) { + + store_scales(tid); + if (block_k + UNROLL_K < end_k) { + fetch_scales(ir * BM, pos_a, stride_a, block_k + UNROLL_K, tid, true); + } + + // Manually partial unroll + [[unroll]] for (uint j = 0; j < unroll_count; ++j) { + coopmat mat_a; + coopmat mat_b; + + coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA); + coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover2, block_k, BK), 
tensorViewTranspose); + + sum = coopMatMulAdd(mat_a, mat_b, sum); + block_k += BK; + } + } + // Do any remaining iterations that were not unrolled + if (block_k < end_k) { + store_scales(tid); + } + while (block_k < end_k) { + coopmat mat_a; + coopmat mat_b; + + coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA); + coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover2, block_k, BK), tensorViewTranspose); + + sum = coopMatMulAdd(mat_a, mat_b, sum); + block_k += BK; + } + coopmat mat_d = coopmat(sum); + + coopMatStoreTensorNV(mat_d, data_d, pos_d, sliceTensorLayoutNV(tensorLayoutD, ic * BN, BNover2, ir * BM, BM), tensorViewTranspose); + return; + } else { + coopmat sum = coopmat(0.0); + + for (uint i = 0; i < k_iters; ++i) { + + store_scales(tid); + if (block_k + UNROLL_K < end_k) { + fetch_scales(ir * BM, pos_a, stride_a, block_k + UNROLL_K, tid, true); + } + + // Manually partial unroll + [[unroll]] for (uint j = 0; j < unroll_count; ++j) { + coopmat mat_a; + coopmat mat_b; + + coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA); + coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose); + + sum = coopMatMulAdd(mat_a, mat_b, sum); + block_k += BK; + } + } + // Do any remaining iterations that were not unrolled + if (block_k < end_k) { + store_scales(tid); + } + while (block_k < end_k) { + coopmat mat_a; + coopmat mat_b; + + coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA); + coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose); + + sum = coopMatMulAdd(mat_a, mat_b, sum); + block_k += BK; + } + coopmat mat_d = coopmat(sum); + + coopMatStoreTensorNV(mat_d, data_d, pos_d, sliceTensorLayoutNV(tensorLayoutD, ic * BN, BN, ir * BM, BM), tensorViewTranspose); + return; } } else #endif // !defined(MUL_MAT_ID) @@ -251,61 +399,43 @@ void main() { tensorLayoutBClamp = setTensorLayoutStrideNV(tensorLayoutBClamp, stride_b, 1); + coopmat sum; + sum = coopmat(0.0); + + uint k_iters = (end_k - start_k + BK - 1) / BK; + + fetch_scales(ir * BM, pos_a, stride_a, start_k, tid, false); + [[dont_unroll]] - for (uint block_k = start_k; block_k < end_k; block_k += BK) { + for (uint block_k = start_k, i = 0; i < k_iters; block_k += BK, ++i) { - coopmat mat_a; - coopmat mat_b; - - // Clamping is expensive, so detect different code paths for each combination - // of A and B needing clamping. 
- bool unclampedA = (ir + 1) * BM <= p.M && block_k + BK <= end_k && (block_k % 8) == 0; -#ifdef MUL_MAT_ID - bool unclampedB = true; -#else - bool unclampedB = (ic + 1) * BN <= p.N && block_k + BK <= end_k && (block_k % 8) == 0; -#endif - if (unclampedA && unclampedB) { - coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, (block_k & ~7), BK) DECODEFUNCA); -#ifdef MUL_MAT_ID - coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose, decodeFuncB); -#else - coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, (block_k & ~7), BK), tensorViewTranspose); -#endif - sum = coopMatMulAdd(mat_a, mat_b, sum); - } else if (unclampedA && !unclampedB) { - coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, (block_k & ~7), BK) DECODEFUNCA); - coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutBClamp, ic * BN, BN, block_k, BK), tensorViewTranspose); - - sum = coopMatMulAdd(mat_a, mat_b, sum); - } else if (!unclampedA && unclampedB) { - coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA); -#ifdef MUL_MAT_ID - coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose, decodeFuncB); -#else - coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, (block_k & ~7), BK), tensorViewTranspose); -#endif - sum = coopMatMulAdd(mat_a, mat_b, sum); - } else if (!unclampedA && !unclampedB) { - coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA); - coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutBClamp, ic * BN, BN, block_k, BK), tensorViewTranspose); - - sum = coopMatMulAdd(mat_a, mat_b, sum); + store_scales(tid); + if (block_k + BK < end_k) { + fetch_scales(ir * BM, pos_a, stride_a, block_k + BK, tid, false); } - } - } - // Convert from ACC_TYPE to D_TYPE - coopmat mat_d; - mat_d = coopmat(sum); + coopmat mat_a; + coopmat mat_b; + + coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA); +#ifdef MUL_MAT_ID + coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose, decodeFuncB); +#else + coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutBClamp, ic * BN, BN, block_k, BK), tensorViewTranspose); +#endif + + sum = coopMatMulAdd(mat_a, mat_b, sum); + } + + // Convert from ACC_TYPE to D_TYPE + coopmat mat_d; + mat_d = coopmat(sum); #ifdef MUL_MAT_ID - // Call callback to store each element, remapping row through shared memory - coopMatPerElementNV(mat_d, mat_d, perElemOpD, ir, ic); + // Call callback to store each element, remapping row through shared memory + coopMatPerElementNV(mat_d, mat_d, perElemOpD, ir, ic); #else - tensorLayoutD = setTensorLayoutStrideNV(tensorLayoutD, p.stride_d, 1); - - uint pos_d = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z; - coopMatStoreTensorNV(mat_d, data_d, pos_d, sliceTensorLayoutNV(tensorLayoutD, ic * BN, BN, ir * BM, BM), tensorViewTranspose); + coopMatStoreTensorNV(mat_d, data_d, pos_d, sliceTensorLayoutNV(tensorLayoutD, ic * BN, BN, ir * BM, BM), tensorViewTranspose); #endif + } } diff --git 
a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp new file mode 100644 index 000000000..83de90eb7 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp @@ -0,0 +1,442 @@ +#version 450 + +#extension GL_EXT_control_flow_attributes : enable +#extension GL_EXT_shader_16bit_storage : require +#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require + +#extension GL_EXT_integer_dot_product : require + +#ifdef FLOAT16 +#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require +#endif + +#ifdef COOPMAT +#extension GL_KHR_cooperative_matrix : enable +#extension GL_KHR_memory_scope_semantics : enable +#extension GL_KHR_shader_subgroup_basic : enable +#endif + +#ifdef MUL_MAT_ID +#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require +#endif + +#include "types.comp" + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {A_TYPE_PACKED16 data_a[];}; +#if defined(A_TYPE_PACKED32) +layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32[];}; +#endif +layout (binding = 1) readonly buffer B {block_q8_1_packed32 data_b[];}; +layout (binding = 2) writeonly buffer D {D_TYPE data_d[];}; + +#ifdef MUL_MAT_ID +layout (binding = 3) readonly buffer IDS {int data_ids[];}; +#endif + +layout (push_constant) uniform parameter +{ + uint M; + uint N; + uint K; + uint stride_a; + uint stride_b; + uint stride_d; + + uint batch_stride_a; + uint batch_stride_b; + uint batch_stride_d; + +#ifdef MUL_MAT_ID + uint nei0; + uint nei1; + uint nbi1; + uint ne11; +#else + uint k_split; + uint ne02; + uint ne12; + uint broadcast2; + uint broadcast3; +#endif +} p; + +layout (constant_id = 0) const uint BLOCK_SIZE = 64; +layout (constant_id = 1) const uint BM = 64; +layout (constant_id = 2) const uint BN = 64; +// layout (constant_id = 3) const uint BK = 32; +layout (constant_id = 4) const uint WM = 32; +layout (constant_id = 5) const uint WN = 32; +layout (constant_id = 6) const uint WMITER = 2; +layout (constant_id = 7) const uint TM = 4; +layout (constant_id = 8) const uint TN = 2; +layout (constant_id = 9) const uint TK = 1; // Only needed for coopmat +layout (constant_id = 10) const uint WARP = 32; + +#define BK 32 + +#ifdef COOPMAT +#define SHMEM_STRIDE (BK / 4 + 4) +#else +#define SHMEM_STRIDE (BK / 4 + 1) +#endif + +shared int32_t buf_a_qs[BM * SHMEM_STRIDE]; + +#ifndef COOPMAT +#if QUANT_AUXF == 1 +shared FLOAT_TYPE buf_a_dm[BM]; +#else +shared FLOAT_TYPE_VEC2 buf_a_dm[BM]; +#endif +#endif + +shared int32_t buf_b_qs[BN * SHMEM_STRIDE]; +#ifndef COOPMAT +shared FLOAT_TYPE_VEC2 buf_b_ds[BN]; +#endif + +#define LOAD_VEC_A (4 * QUANT_R) +#define LOAD_VEC_B 4 + +#ifdef MUL_MAT_ID +shared u16vec2 row_ids[4096]; +#endif // MUL_MAT_ID + +#define NUM_WARPS (BLOCK_SIZE / WARP) + +#ifdef COOPMAT +shared ACC_TYPE coopmat_stage[TM * TN * NUM_WARPS]; +#endif + +#include "mul_mmq_funcs.comp" + +void main() { +#ifdef NEEDS_INIT_IQ_SHMEM + init_iq_shmem(gl_WorkGroupSize); +#endif + +#ifdef MUL_MAT_ID + const uint expert_idx = gl_GlobalInvocationID.z; +#else + const uint batch_idx = gl_GlobalInvocationID.z; + + const uint i13 = batch_idx / p.ne12; + const uint i12 = batch_idx % p.ne12; + + const uint i03 = i13 / p.broadcast3; + const uint i02 = i12 / p.broadcast2; + + const uint batch_idx_a = i03 * p.ne02 + i02; +#endif + + const uint blocks_m = (p.M + BM - 1) / BM; + const uint ir = gl_WorkGroupID.x % blocks_m; + 
const uint ik = gl_WorkGroupID.x / blocks_m; + const uint ic = gl_WorkGroupID.y; + + const uint WNITER = (WM * WN) / (WARP * TM * TN * WMITER); + const uint WSUBM = WM / WMITER; + const uint WSUBN = WN / WNITER; + +#ifdef COOPMAT + const uint warp_i = gl_SubgroupID; + + const uint tiw = gl_SubgroupInvocationID; + + const uint cms_per_row = WM / TM; + const uint cms_per_col = WN / TN; + + const uint storestride = WARP / TM; + const uint store_r = tiw % TM; + const uint store_c = tiw / TM; +#else + const uint warp_i = gl_LocalInvocationID.x / WARP; + + const uint tiw = gl_LocalInvocationID.x % WARP; + + const uint tiwr = tiw % (WSUBM / TM); + const uint tiwc = tiw / (WSUBM / TM); +#endif + + const uint warp_r = warp_i % (BM / WM); + const uint warp_c = warp_i / (BM / WM); + + const uint loadr_a = gl_LocalInvocationID.x % (BK / LOAD_VEC_A); + const uint loadc_a = gl_LocalInvocationID.x / (BK / LOAD_VEC_A); + const uint loadr_b = gl_LocalInvocationID.x % (BK / LOAD_VEC_B); + const uint loadc_b = gl_LocalInvocationID.x / (BK / LOAD_VEC_B); + + const uint loadstride_a = BLOCK_SIZE * LOAD_VEC_A / BK; + const uint loadstride_b = BLOCK_SIZE * LOAD_VEC_B / BK; + +#ifdef MUL_MAT_ID + uint _ne1 = 0; + for (uint ii1 = 0; ii1 < p.nei1; ii1++) { + for (uint ii0 = 0; ii0 < p.nei0; ii0++) { + if (data_ids[ii1*p.nbi1 + ii0] == expert_idx) { + row_ids[_ne1] = u16vec2(ii0, ii1); + _ne1++; + } + } + } + + barrier(); + + // Workgroup has no work + if (ic * BN >= _ne1) return; +#endif + +#ifdef MUL_MAT_ID + const uint start_k = 0; + const uint end_k = p.K; +#else + const uint start_k = ik * p.k_split; + const uint end_k = min(p.K, (ik + 1) * p.k_split); +#endif + + uint pos_a_ib = ( +#ifdef MUL_MAT_ID + expert_idx * p.batch_stride_a + +#else + batch_idx_a * p.batch_stride_a + +#endif + ir * BM * p.stride_a + start_k) / BK; +#ifdef MUL_MAT_ID + uint pos_b_ib = 0; +#else + uint pos_b_ib = (batch_idx * p.batch_stride_b + ic * BN * p.stride_b + start_k) / BK; +#endif + +#ifdef COOPMAT + coopmat cache_a; + coopmat cache_b; + coopmat cm_result; + + coopmat factors[cms_per_row * cms_per_col]; + + coopmat sums[cms_per_row * cms_per_col]; + + [[unroll]] for (uint i = 0; i < cms_per_row * cms_per_col; i++) { + sums[i] = coopmat(0.0f); + } +#else + int32_t cache_a_qs[WMITER * TM * BK / 4]; + + int32_t cache_b_qs[TN * BK / 4]; + + ACC_TYPE sums[WMITER * TM * WNITER * TN]; + + [[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN; i++) { + sums[i] = ACC_TYPE(0.0f); + } +#endif + +#if QUANT_AUXF == 1 + FLOAT_TYPE cache_a_dm[WMITER * TM]; +#else + FLOAT_TYPE_VEC2 cache_a_dm[WMITER * TM]; +#endif + + FLOAT_TYPE_VEC2 cache_b_ds[TN]; + + for (uint block = start_k; block < end_k; block += BK) { + [[unroll]] for (uint l = 0; loadc_a + l < BM; l += loadstride_a) { + const uint ib = pos_a_ib + (loadc_a + l) * p.stride_a / BK; + const uint iqs = loadr_a; + const uint buf_ib = loadc_a + l; + + if (iqs == 0) { +#if QUANT_AUXF == 1 + buf_a_dm[buf_ib] = get_d(ib); +#else + buf_a_dm[buf_ib] = get_dm(ib); +#endif + } +#if QUANT_R == 1 + buf_a_qs[buf_ib * SHMEM_STRIDE + iqs] = repack(ib, iqs); +#else + const i32vec2 vals = repack(ib, iqs); + buf_a_qs[buf_ib * SHMEM_STRIDE + iqs ] = vals.x; + buf_a_qs[buf_ib * SHMEM_STRIDE + iqs + 4] = vals.y; +#endif + } + [[unroll]] for (uint l = 0; loadc_b + l < BN; l += loadstride_b) { +#ifdef MUL_MAT_ID + const u16vec2 row_idx = row_ids[ic * BN + loadc_b + l]; + const uint idx = pos_b_ib + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + loadr_b; + const uint ib 
= idx / 8; + const uint iqs = idx & 0x7; +#else + const uint ib = pos_b_ib + (loadc_b + l) * p.stride_b / BK; + const uint iqs = loadr_b; +#endif + + const uint buf_ib = loadc_b + l; + + if (iqs == 0) { + buf_b_ds[buf_ib] = FLOAT_TYPE_VEC2(data_b[ib].ds); + } + buf_b_qs[buf_ib * SHMEM_STRIDE + iqs] = data_b[ib].qs[iqs]; + } + + barrier(); + + pos_a_ib += 1; + pos_b_ib += 1; + +#ifdef COOPMAT + [[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) { + const uint ib_a = warp_r * WM + cm_row * TM; + // Load from shared into cache + coopMatLoad(cache_a, buf_a_qs, ib_a * SHMEM_STRIDE, SHMEM_STRIDE, gl_CooperativeMatrixLayoutRowMajor); + + // TODO: only cache values that are actually needed + [[unroll]] for (uint t_idx = 0; t_idx < TM; t_idx++) { + cache_a_dm[t_idx] = buf_a_dm[ib_a + t_idx]; + } + + [[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) { + const uint ib_b = warp_c * WN + cm_col * TN; + coopMatLoad(cache_b, buf_b_qs, ib_b * SHMEM_STRIDE, SHMEM_STRIDE, gl_CooperativeMatrixLayoutColumnMajor); + + // TODO: only cache values that are actually needed + [[unroll]] for (uint t_idx = 0; t_idx < TN; t_idx++) { + cache_b_dm[t_idx] = buf_b_d[ib_b + t_idx]; + } + + cm_result = coopmat(0); + cm_result = coopMatMulAdd(cache_a, cache_b, cm_result); + + [[unroll]] for (uint col = 0; col < TN; col += storestride) { + coopmat_stage[warp_i * TM * TN + (store_c + col) * TM + store_r] = ACC_TYPE(float(cache_a_d[store_r]) * float(cache_b_d[store_c + col])); + } + + coopMatLoad(factors, coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor); + sums[cm_col * cms_per_row + cm_row] += factors * coopmat(cm_result); + } + } +#else + // Load from shared into cache + [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) { + [[unroll]] for (uint cr = 0; cr < TM; cr++) { + const uint ib = warp_r * WM + wsir * WSUBM + tiwr * TM + cr; + cache_a_dm[wsir * TM + cr] = buf_a_dm[ib]; + [[unroll]] for (uint idx_k = 0; idx_k < BK / 4; idx_k++) { + cache_a_qs[(wsir * TM + cr) * (BK / 4) + idx_k] = buf_a_qs[ib * SHMEM_STRIDE + idx_k]; + } + } + } + + [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) { + [[unroll]] for (uint cc = 0; cc < TN; cc++) { + const uint ib = warp_c * WN + wsic * WSUBN + tiwc * TN + cc; + cache_b_ds[cc] = buf_b_ds[ib]; + [[unroll]] for (uint idx_k = 0; idx_k < BK / 4; idx_k++) { + cache_b_qs[cc * (BK / 4) + idx_k] = buf_b_qs[ib * SHMEM_STRIDE + idx_k]; + } + } + + [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) { + [[unroll]] for (uint cc = 0; cc < TN; cc++) { + [[unroll]] for (uint cr = 0; cr < TM; cr++) { + const uint cache_a_idx = wsir * TM + cr; + const uint sums_idx = (wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr; + int32_t q_sum = 0; + [[unroll]] for (uint idx_k = 0; idx_k < BK / 4; idx_k++) { + q_sum += dotPacked4x8EXT(cache_a_qs[cache_a_idx * (BK / 4) + idx_k], + cache_b_qs[cc * (BK / 4) + idx_k]); + } + + sums[sums_idx] += mul_q8_1(q_sum, cache_a_dm[cache_a_idx], cache_b_ds[cc]); + } + } + } + } +#endif + + barrier(); + } + + const uint dr = ir * BM + warp_r * WM; + const uint dc = ic * BN + warp_c * WN; + +#ifndef MUL_MAT_ID + const uint offsets = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z; +#endif + +#ifdef COOPMAT +#ifdef MUL_MAT_ID + [[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) { + [[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) { + coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor); + + 
[[unroll]] for (uint col = 0; col < BN; col += storestride) { + const uint row_i = dc + cm_col * TN + col + store_c; + if (row_i >= _ne1) break; + + const u16vec2 row_idx = row_ids[row_i]; + + data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]); + } + } + } +#else + const bool is_aligned = p.stride_d % 4 == 0; // Assumption: D_TYPE == float + + [[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) { + [[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) { + const bool is_in_bounds = dr + (cm_row + 1) * TM <= p.M && dc + (cm_col + 1) * TN <= p.N; + + if (is_aligned && is_in_bounds) { + // Full coopMat is within bounds and stride_d is aligned with 16B + coopmat cm_dtype = coopmat(sums[cm_col * cms_per_row + cm_row]); + coopMatStore(cm_dtype, data_d, offsets + (dc + cm_col * TN) * p.stride_d + dr + cm_row * TM, p.stride_d, gl_CooperativeMatrixLayoutColumnMajor); + } else if (is_in_bounds) { + // Full coopMat is within bounds, but stride_d is not aligned + coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor); + + [[unroll]] for (uint col = 0; col < TN; col += storestride) { + data_d[offsets + (dc + cm_col * TN + col + store_c) * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]); + } + } else if (dr + cm_row * TM < p.M && dc + cm_col * TN < p.N) { + // Partial coopMat is within bounds + coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor); + + [[unroll]] for (uint col = 0; col < TN; col += storestride) { + if (dr + cm_row * TM + store_r < p.M && dc + cm_col * TN + col + store_c < p.N) { + data_d[offsets + (dc + cm_col * TN + col + store_c) * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]); + } + } + } + } + } +#endif // MUL_MAT_ID +#else + [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) { + [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) { + + const uint dr_warp = dr + wsir * WSUBM + tiwr * TM; + const uint dc_warp = dc + wsic * WSUBN + tiwc * TN; + [[unroll]] for (uint cc = 0; cc < TN; cc++) { +#ifdef MUL_MAT_ID + const uint row_i = dc_warp + cc; + if (row_i >= _ne1) break; + + const u16vec2 row_idx = row_ids[row_i]; +#endif // MUL_MAT_ID + [[unroll]] for (uint cr = 0; cr < TM; cr++) { +#ifdef MUL_MAT_ID + data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]); +#else + if (dr_warp + cr < p.M && dc_warp + cc < p.N) { + data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]); + } +#endif // MUL_MAT_ID + } + } + } + } +#endif // COOPMAT +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp new file mode 100644 index 000000000..63b15471b --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp @@ -0,0 +1,99 @@ +#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require +#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require +#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require + +#include "types.comp" + +// Each iqs value maps to a 32-bit integer + 
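+// How these helpers combine with the integer dot product in mul_mmq.comp: the B
+// matrix is quantized to q8_1, whose ds (see quantize_q8_1.comp) stores
+// (d_b, d_b * sum(q8)). For q4_0, each weight is d_a * (q - 8), so the block's
+// contribution is
+//   sum_i d_a*(q_i - 8) * d_b*q8_i = d_a * (q_sum * dsb.x - 8.0 * dsb.y)
+// where q_sum accumulates dotPacked4x8EXT() products of the repacked nibble words
+// with the q8_1 words, which is exactly what mul_q8_1() returns below. The q4_1
+// variant uses d_a*q + m_a instead, giving q_sum * dma.x * dsb.x + dma.y * dsb.y.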
+#if defined(DATA_A_Q4_0) +i32vec2 repack(uint ib, uint iqs) { + // Use 2-byte loads since a q4_0 block (18 bytes) is not divisible by 4 + const u16vec2 quants = u16vec2(data_a[ib].qs[iqs * 2 ], + data_a[ib].qs[iqs * 2 + 1]); + const uint32_t vui = pack32(quants); + return i32vec2( vui & 0x0F0F0F0F, + (vui >> 4) & 0x0F0F0F0F); +} + +ACC_TYPE mul_q8_1(int32_t q_sum, float da, vec2 dsb) { + return ACC_TYPE(da * (float(q_sum) * dsb.x - 8.0f * dsb.y)); +} +#endif + +#if defined(DATA_A_Q4_1) +i32vec2 repack(uint ib, uint iqs) { + // Use 4-byte loads since a q4_1 block (20 bytes) is divisible by 4 + const uint32_t vui = data_a_packed32[ib].qs[iqs]; + return i32vec2( vui & 0x0F0F0F0F, + (vui >> 4) & 0x0F0F0F0F); +} + +ACC_TYPE mul_q8_1(int32_t q_sum, vec2 dma, vec2 dsb) { + return ACC_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y); +} +#endif + +#if defined(DATA_A_Q5_0) +i32vec2 repack(uint ib, uint iqs) { + // Use 2-byte loads since a q5_0 block (22 bytes) is not divisible by 4 + const u16vec2 quants = u16vec2(data_a[ib].qs[iqs * 2 ], + data_a[ib].qs[iqs * 2 + 1]); + const uint32_t vui = pack32(quants); + const int32_t qh = int32_t((uint32_t(data_a[ib].qh[1]) << 16 | data_a[ib].qh[0]) >> (4 * iqs)); + const int32_t v0 = int32_t(vui & 0x0F0F0F0F) + | ((qh & 0xF) * 0x02040810) & 0x10101010; // (0,1,2,3) -> (4,12,20,28) + + const int32_t v1 = int32_t((vui >> 4) & 0x0F0F0F0F) + | (((qh >> 16) & 0xF) * 0x02040810) & 0x10101010; // (16,17,18,19) -> (4,12,20,28) + + return i32vec2(v0, v1); +} + +ACC_TYPE mul_q8_1(int32_t q_sum, float da, vec2 dsb) { + return ACC_TYPE(da * (float(q_sum) * dsb.x - 16.0f * dsb.y)); +} +#endif + +#if defined(DATA_A_Q5_1) +i32vec2 repack(uint ib, uint iqs) { + // Use 4-byte loads since a q5_1 block (24 bytes) is divisible by 4 + const uint32_t vui = data_a_packed32[ib].qs[iqs]; + const int32_t qh = int32_t(data_a_packed32[ib].qh >> (4 * iqs)); + const int32_t v0 = int32_t(vui & 0x0F0F0F0F) + | ((qh & 0xF) * 0x02040810) & 0x10101010; // (0,1,2,3) -> (4,12,20,28) + + const int32_t v1 = int32_t((vui >> 4) & 0x0F0F0F0F) + | (((qh >> 16) & 0xF) * 0x02040810) & 0x10101010; // (16,17,18,19) -> (4,12,20,28) + + return i32vec2(v0, v1); +} + +ACC_TYPE mul_q8_1(int32_t q_sum, vec2 dma, vec2 dsb) { + return ACC_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y); +} +#endif + +#if defined(DATA_A_Q8_0) +int32_t repack(uint ib, uint iqs) { + // Use 2-byte loads since a q8_0 block (34 bytes) is not divisible by 4 + return pack32(i16vec2(data_a[ib].qs[iqs * 2 ], + data_a[ib].qs[iqs * 2 + 1])); +} + +ACC_TYPE mul_q8_1(int32_t q_sum, float da, vec2 dsb) { + return ACC_TYPE(float(q_sum) * da * dsb.x); +} +#endif + +#if defined(DATA_A_Q4_0) || defined(DATA_A_Q5_0) || defined(DATA_A_Q8_0) || defined(DATA_A_IQ1_S) || defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL) +FLOAT_TYPE get_d(uint ib) { + return FLOAT_TYPE(data_a[ib].d); +} +#endif + +#if defined(DATA_A_Q4_1) || defined(DATA_A_Q5_1) +FLOAT_TYPE_VEC2 get_dm(uint ib) { + return FLOAT_TYPE_VEC2(data_a_packed32[ib].dm); +} +#endif diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp new file mode 100644 index 000000000..e2e020fec --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp @@ -0,0 +1,77 @@ +#version 450 + +#extension GL_EXT_control_flow_attributes : require 
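+// Summary of the output layout (as consumed by mul_mmq.comp): each group of 32
+// input floats becomes one block_q8_1_packed32 holding 8 packed int32 words of
+// int8 quants plus ds = f16vec2(d, d * sum), where d = amax / 127 and sum is the
+// sum of the 32 quantized values.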
+#extension GL_EXT_shader_16bit_storage : require + +layout (push_constant) uniform parameter +{ + uint ne; +} p; + +#include "types.comp" + +layout(constant_id = 0) const uint GROUP_SIZE = 32; +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {vec4 data_a[];}; +layout (binding = 1) writeonly buffer D {block_q8_1_packed32 data_b[];}; + +shared float shmem[GROUP_SIZE]; + +void quantize() { + const uint wgid = gl_WorkGroupID.x; + const uint tid = gl_LocalInvocationID.x; + + // Each thread handles a vec4, so 8 threads handle a block + const uint blocks_per_group = GROUP_SIZE / 8; + + const uint block_in_wg = tid / 8; + + const uint ib = wgid * blocks_per_group + block_in_wg; + const uint iqs = tid % 8; + + if (ib >= gl_NumWorkGroups.x * blocks_per_group) { + return; + } + + const uint a_idx = ib * 8 + iqs; + + vec4 vals = a_idx < p.ne ? data_a[a_idx] : vec4(0.0f); + const vec4 abs_vals = abs(vals); + + // Find absolute max for each block + shmem[tid] = max(max(abs_vals.x, abs_vals.y), max(abs_vals.z, abs_vals.w)); + barrier(); + [[unroll]] for (uint s = 4; s > 0; s >>= 1) { + if (iqs < s) { + shmem[tid] = max(shmem[tid], shmem[tid + s]); + } + barrier(); + } + + const float amax = shmem[block_in_wg * 8]; + const float d = amax / 127.0; + const float d_inv = d != 0.0 ? 1.0 / d : 0.0; + vals = round(vals * d_inv); + data_b[ib].qs[iqs] = pack32(i8vec4(round(vals))); + barrier(); + + // Calculate the sum for each block + shmem[tid] = vals.x + vals.y + vals.z + vals.w; + barrier(); + [[unroll]] for (uint s = 4; s > 0; s >>= 1) { + if (iqs < s) { + shmem[tid] += shmem[tid + s]; + } + barrier(); + } + if (iqs == 0) { + const float sum = shmem[tid]; + + data_b[ib].ds = f16vec2(vec2(d, sum * d)); + } +} + +void main() { + quantize(); +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp index 52a19b62a..4f806270c 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp @@ -17,5 +17,5 @@ void main() { return; } - data_d[i] = max(float(data_a[i]), 0); + data_d[i] = D_TYPE(max(float(data_a[i]), 0)); } diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp index b554400ba..deb8ee996 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp @@ -1,6 +1,6 @@ #version 450 -#include "generic_head.comp" +#include "generic_unary_head.comp" #include "types.comp" #extension GL_EXT_control_flow_attributes : enable @@ -8,19 +8,29 @@ layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in; -layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; -layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; - shared FLOAT_TYPE sum[BLOCK_SIZE]; void main() { - const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x; - const uint tid = gl_LocalInvocationID.x; + const uint ncols = p.ne00; + const uint nrows = gl_NumWorkGroups.x; + const uint nchannels = gl_NumWorkGroups.y; + + const uint row = gl_WorkGroupID.x; + const uint channel = gl_WorkGroupID.y; + const uint samp = gl_WorkGroupID.z; + const uint tid = gl_LocalInvocationID.x; + + const uint stride_row = p.nb01; + const uint stride_channel = p.nb02; + const uint stride_sample = p.nb03; + + uint32_t 
a_offset = samp*stride_sample + channel*stride_channel + row*stride_row + get_aoffset(); + uint32_t d_offset = ((samp*nchannels + channel)*nrows + row)*ncols + get_doffset(); sum[tid] = FLOAT_TYPE(0.0f); // partial sum for thread in warp - [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) { - const FLOAT_TYPE xi = FLOAT_TYPE(data_a[row*p.KX + col]); + [[unroll]] for (uint col = tid; col < ncols; col += BLOCK_SIZE) { + const FLOAT_TYPE xi = FLOAT_TYPE(data_a[a_offset + col]); sum[tid] += xi * xi; } @@ -33,10 +43,10 @@ void main() { barrier(); } - const FLOAT_TYPE mean = sum[0] / FLOAT_TYPE(p.KX); + const FLOAT_TYPE mean = sum[0] / FLOAT_TYPE(ncols); const FLOAT_TYPE scale = inversesqrt(mean + FLOAT_TYPE(p.param1)); - [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) { - data_d[row*p.KX + col] = D_TYPE(scale * FLOAT_TYPE(data_a[row*p.KX + col])); + [[unroll]] for (uint col = tid; col < ncols; col += BLOCK_SIZE) { + data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col])); } } diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp index 776581e2c..5c9e5c350 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp @@ -16,5 +16,5 @@ void main() { if (i >= p.KX) { return; } - data_d[i] = D_TYPE(1. / (1 + exp(-1. *data_a[i]))); + data_d[i] = D_TYPE(1. / (1 + exp(-1. * float(data_a[i])))); } diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp index 495f966bd..8a6f868f5 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp @@ -16,5 +16,5 @@ void main() { if (i >= p.KX) { return; } - data_d[i] = D_TYPE(1. - 2. / (exp(2.*data_a[i]) + 1.)); + data_d[i] = D_TYPE(1. - 2. 
/ (exp(2.*float(data_a[i])) + 1.)); } diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/test_bfloat16_support.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/test_bfloat16_support.comp new file mode 100644 index 000000000..fd0ba401f --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/test_bfloat16_support.comp @@ -0,0 +1,7 @@ +#version 460 + +#extension GL_EXT_bfloat16 : require + +void main() +{ +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/test_integer_dot_support.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/test_integer_dot_support.comp new file mode 100644 index 000000000..470e3074d --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/test_integer_dot_support.comp @@ -0,0 +1,7 @@ +#version 460 + +#extension GL_EXT_integer_dot_product : require + +void main() +{ +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/types.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/types.comp index dfa16cda5..3bde71783 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/types.comp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/types.comp @@ -1,7 +1,7 @@ - #if !defined(GGML_TYPES_COMP) #define GGML_TYPES_COMP +#extension GL_EXT_shader_explicit_arithmetic_types_int64 : require #extension GL_EXT_shader_explicit_arithmetic_types_int32 : require #extension GL_EXT_shader_explicit_arithmetic_types_int16 : require #extension GL_EXT_shader_explicit_arithmetic_types_int8 : require @@ -33,6 +33,19 @@ #endif #endif +#if defined(DATA_A_BF16) +#define QUANT_K 1 +#define QUANT_R 1 + +#if !defined(LOAD_VEC_A) || LOAD_VEC_A == 1 +#define A_TYPE uint16_t +#elif LOAD_VEC_A == 4 +#define A_TYPE u16vec4 +#elif LOAD_VEC_A == 8 +#error unsupported +#endif +#endif + #define QUANT_K_Q4_0 32 #define QUANT_R_Q4_0 2 @@ -50,6 +63,7 @@ struct block_q4_0_packed16 #if defined(DATA_A_Q4_0) #define QUANT_K QUANT_K_Q4_0 #define QUANT_R QUANT_R_Q4_0 +#define QUANT_AUXF 1 #define A_TYPE block_q4_0 #define A_TYPE_PACKED16 block_q4_0_packed16 #endif @@ -71,11 +85,19 @@ struct block_q4_1_packed16 uint16_t qs[16/2]; }; +struct block_q4_1_packed32 +{ + f16vec2 dm; + uint32_t qs[16/4]; +}; + #if defined(DATA_A_Q4_1) #define QUANT_K QUANT_K_Q4_1 #define QUANT_R QUANT_R_Q4_1 +#define QUANT_AUXF 2 #define A_TYPE block_q4_1 #define A_TYPE_PACKED16 block_q4_1_packed16 +#define A_TYPE_PACKED32 block_q4_1_packed32 #endif #define QUANT_K_Q5_0 32 @@ -98,6 +120,7 @@ struct block_q5_0_packed16 #if defined(DATA_A_Q5_0) #define QUANT_K QUANT_K_Q5_0 #define QUANT_R QUANT_R_Q5_0 +#define QUANT_AUXF 1 #define A_TYPE block_q5_0 #define A_TYPE_PACKED16 block_q5_0_packed16 #endif @@ -121,11 +144,20 @@ struct block_q5_1_packed16 uint16_t qs[16/2]; }; +struct block_q5_1_packed32 +{ + f16vec2 dm; + uint qh; + uint32_t qs[16/4]; +}; + #if defined(DATA_A_Q5_1) #define QUANT_K QUANT_K_Q5_1 #define QUANT_R QUANT_R_Q5_1 +#define QUANT_AUXF 2 #define A_TYPE block_q5_1 #define A_TYPE_PACKED16 block_q5_1_packed16 +#define A_TYPE_PACKED32 block_q5_1_packed32 #endif #define QUANT_K_Q8_0 32 @@ -139,16 +171,42 @@ struct block_q8_0 struct block_q8_0_packed16 { float16_t d; - uint16_t qs[32/2]; + int16_t qs[32/2]; +}; +struct block_q8_0_packed32 +{ + float16_t d; + int32_t qs[32/4]; }; #if defined(DATA_A_Q8_0) #define QUANT_K QUANT_K_Q8_0 #define QUANT_R QUANT_R_Q8_0 +#define QUANT_AUXF 1 #define A_TYPE block_q8_0 #define A_TYPE_PACKED16 block_q8_0_packed16 +#define A_TYPE_PACKED32 block_q8_0_packed32 #endif +#define QUANT_K_Q8_1 
32 +#define QUANT_R_Q8_1 1 + +struct block_q8_1 +{ + f16vec2 ds; + int8_t qs[32]; +}; +struct block_q8_1_packed16 +{ + f16vec2 ds; + int16_t qs[16]; +}; +struct block_q8_1_packed32 +{ + f16vec2 ds; + int32_t qs[8]; +}; + // K-quants #define QUANT_K_Q2_K 256 @@ -312,6 +370,12 @@ struct block_iq1_m { uint16_t scales[QUANT_K_IQ1_M/64]; }; +struct block_iq1_m_packed64 { + uint64_t qs[QUANT_K_IQ1_M/8/8]; + uint64_t qh[QUANT_K_IQ1_M/16/8]; + uint64_t scales; +}; + #if defined(DATA_A_IQ1_S) #define QUANT_K QUANT_K_IQ1_S #define QUANT_R QUANT_R_IQ1_S @@ -466,10 +530,13 @@ shared uint16_t iq1s_grid[2048]; void init_iq_shmem(uvec3 wgsize) { // copy the table into shared memory and sync - for (uint i = gl_LocalInvocationIndex.x; i < iq1s_grid_const.length(); i += wgsize.x) { - u16vec2 g = unpack16(iq1s_grid_const[i]); - iq1s_grid[2*i+0] = g.x; - iq1s_grid[2*i+1] = g.y; + [[unroll]] for (uint i = 0; i < iq1s_grid_const.length(); i += wgsize.x) { + uint idx = i + gl_LocalInvocationIndex.x; + if (iq1s_grid_const.length() % wgsize.x == 0 || idx < iq1s_grid_const.length()) { + u16vec2 g = unpack16(iq1s_grid_const[idx]); + iq1s_grid[2*idx+0] = g.x; + iq1s_grid[2*idx+1] = g.y; + } } barrier(); } @@ -565,8 +632,10 @@ shared uvec2 iq2xxs_grid[256]; void init_iq_shmem(uvec3 wgsize) { // copy the table into shared memory and sync - for (uint i = gl_LocalInvocationIndex.x; i < iq2xxs_grid.length(); i += wgsize.x) { - iq2xxs_grid[i] = iq2xxs_grid_const[i]; + [[unroll]] for (uint i = 0; i < iq2xxs_grid.length(); i += wgsize.x) { + if (iq2xxs_grid_const.length() % wgsize.x == 0 || i + gl_LocalInvocationIndex.x < iq2xxs_grid_const.length()) { + iq2xxs_grid[i + gl_LocalInvocationIndex.x] = iq2xxs_grid_const[i + gl_LocalInvocationIndex.x]; + } } barrier(); } @@ -733,8 +802,10 @@ shared uvec2 iq2xs_grid[512]; void init_iq_shmem(uvec3 wgsize) { // copy the table into shared memory and sync - for (uint i = gl_LocalInvocationIndex.x; i < iq2xs_grid.length(); i += wgsize.x) { - iq2xs_grid[i] = iq2xs_grid_const[i]; + [[unroll]] for (uint i = 0; i < iq2xs_grid.length(); i += wgsize.x) { + if (iq2xs_grid.length() % wgsize.x == 0 || i + gl_LocalInvocationIndex.x < iq2xs_grid_const.length()) { + iq2xs_grid[i + gl_LocalInvocationIndex.x] = iq2xs_grid_const[i + gl_LocalInvocationIndex.x]; + } } barrier(); } @@ -756,6 +827,14 @@ struct block_iq2_s uint8_t scales[QUANT_K_IQ2_S/32]; }; +struct block_iq2_s_packed16 +{ + float16_t d; + uint16_t qs[QUANT_K_IQ2_S/8]; + uint16_t qh[QUANT_K_IQ2_S/64]; + uint16_t scales[QUANT_K_IQ2_S/64]; +}; + #if defined(DATA_A_IQ2_S) const uvec2 iq2s_grid_const[1024] = { @@ -1023,8 +1102,10 @@ shared uvec2 iq2s_grid[1024]; void init_iq_shmem(uvec3 wgsize) { // copy the table into shared memory and sync - for (uint i = gl_LocalInvocationIndex.x; i < iq2s_grid.length(); i += wgsize.x) { - iq2s_grid[i] = iq2s_grid_const[i]; + [[unroll]] for (uint i = 0; i < iq2s_grid.length(); i += wgsize.x) { + if (iq2s_grid.length() % wgsize.x == 0 || i + gl_LocalInvocationIndex.x < iq2s_grid_const.length()) { + iq2s_grid[i + gl_LocalInvocationIndex.x] = iq2s_grid_const[i + gl_LocalInvocationIndex.x]; + } } barrier(); } @@ -1032,6 +1113,7 @@ void init_iq_shmem(uvec3 wgsize) #define QUANT_K QUANT_K_IQ2_S #define QUANT_R QUANT_R_IQ2_S #define A_TYPE block_iq2_s +#define A_TYPE_PACKED16 block_iq2_s_packed16 #endif #define QUANT_K_IQ3_XXS 256 @@ -1092,8 +1174,10 @@ shared uint32_t iq3xxs_grid[256]; void init_iq_shmem(uvec3 wgsize) { // copy the table into shared memory and sync - for (uint i = gl_LocalInvocationIndex.x; i < 
iq3xxs_grid.length(); i += wgsize.x) { - iq3xxs_grid[i] = iq3xxs_grid_const[i]; + [[unroll]] for (uint i = 0; i < iq3xxs_grid.length(); i += wgsize.x) { + if (iq3xxs_grid.length() % wgsize.x == 0 || i + gl_LocalInvocationIndex.x < iq3xxs_grid.length()) { + iq3xxs_grid[i + gl_LocalInvocationIndex.x] = iq3xxs_grid_const[i + gl_LocalInvocationIndex.x]; + } } barrier(); } @@ -1200,8 +1284,10 @@ shared uint32_t iq3s_grid[512]; void init_iq_shmem(uvec3 wgsize) { // copy the table into shared memory and sync - for (uint i = gl_LocalInvocationIndex.x; i < iq3s_grid.length(); i += wgsize.x) { - iq3s_grid[i] = iq3s_grid_const[i]; + [[unroll]] for (uint i = 0; i < iq3s_grid.length(); i += wgsize.x) { + if (iq3s_grid.length() % wgsize.x == 0 || i + gl_LocalInvocationIndex.x < iq3s_grid.length()) { + iq3s_grid[i + gl_LocalInvocationIndex.x] = iq3s_grid_const[i + gl_LocalInvocationIndex.x]; + } } barrier(); } @@ -1270,4 +1356,18 @@ void init_iq_shmem(uvec3 wgsize) } #endif +// returns the bfloat value in the low 16b. +// See ggml_compute_fp32_to_bf16 +uint32_t fp32_to_bf16(float f) +{ + uint32_t u = floatBitsToUint(f); + u = (u + (0x7fff + ((u >> 16) & 1))) >> 16; + return u; +} + +float bf16_to_fp32(uint32_t u) +{ + return uintBitsToFloat(u << 16); +} + #endif // !defined(GGML_TYPES_COMP) diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index c5e0bba82..d196137eb 100644 --- a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -63,7 +63,8 @@ const std::vector type_names = { "iq3_xxs", "iq3_s", "iq4_xs", - "iq4_nl" + "iq4_nl", + "bf16", }; namespace { @@ -295,7 +296,9 @@ void matmul_shaders(bool fp16, bool matmul_id, bool coopmat, bool coopmat2, bool std::string aligned_b_type_f32 = coopmat2 ? "float" : fp16 ? "mat2x4" : "vec4"; std::string aligned_b_type_f16 = coopmat2 ? "float16_t" : fp16 ? "f16mat2x4" : "f16vec4"; - std::map base_dict = {{"FLOAT_TYPE", (coopmat2 || fp16) ? "float16_t" : "float"}}; + std::map base_dict = { + {"FLOAT_TYPE_VEC2", (coopmat2 || fp16) ? "f16vec2" : "vec2"}, + }; std::string shader_name = "matmul"; if (matmul_id) { @@ -313,34 +316,81 @@ void matmul_shaders(bool fp16, bool matmul_id, bool coopmat, bool coopmat2, bool base_dict["COOPMAT"] = "1"; } - base_dict["ACC_TYPE"] = f16acc ? "float16_t" : "float"; + const std::string source_name = coopmat2 ? "mul_mm_cm2.comp" : "mul_mm.comp"; - std::string source_name = coopmat2 ? 
"mul_mm_cm2.comp" : "mul_mm.comp"; + auto const &FLOAT_TYPE = [&](const std::string &t) -> std::string { + if (t == "bf16") { + // scalar path promotes to float + if (!coopmat && !coopmat2) { + return "float"; + } + return "bfloat16_t"; + } + if (coopmat2 || fp16) { + return "float16_t"; + } + return "float"; + }; // Shaders with f16 B_TYPE - string_to_spv(shader_name + "_f32_f16", source_name, merge_maps(base_dict, {{"DATA_A_F32", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}, }), fp16, coopmat, coopmat2, f16acc); - string_to_spv(shader_name + "_f32_f16_aligned", source_name, merge_maps(base_dict, {{"DATA_A_F32", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_f32_f16", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("f16")}, {"DATA_A_F32", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}, }), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_f32_f16_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("f16")}, {"DATA_A_F32", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); - string_to_spv(shader_name + "_f16_aligned", source_name, merge_maps(base_dict, {{"DATA_A_F16", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); - string_to_spv(shader_name + "_f16", source_name, merge_maps(base_dict, {{"DATA_A_F16", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_f16_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("f16")}, {"DATA_A_F16", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_f16", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("f16")}, {"DATA_A_F16", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc); + + // bf16 + { + std::string load_vec_a_unaligned = "1"; + // For aligned matmul loads + std::string load_vec_a = coopmat2 ? "1" : "4"; + + // scalar path promotes to float + std::string to_float_type = (coopmat || coopmat2) ? "uintBitsToBFloat16EXT" : "bf16_to_fp32"; + + // If bfloat16 is not supported, then only compile the scalar (promote to fp32) shader +#if !defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT) + if (!(coopmat || coopmat2)) +#endif + { + string_to_spv(shader_name + "_bf16_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("bf16")}, {"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", "4"}, {"B_TYPE", coopmat2 ? "bfloat16_t" : "u16vec4"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_bf16", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("bf16")}, {"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", coopmat2 ? 
"bfloat16_t" : "uint16_t"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}}), fp16, coopmat, coopmat2, f16acc); + } + } for (const auto& tname : type_names) { + std::string load_vec_quant = "2"; + if ((tname == "q4_0") || (tname == "q4_1")) + load_vec_quant = "8"; + else if ((tname == "q5_0") || (tname == "q5_1") || (tname == "q8_0") || (tname == "iq4_nl")) + load_vec_quant = "4"; + + if (tname == "bf16") { + continue; + } + std::string data_a_key = "DATA_A_" + to_uppercase(tname); // For unaligned, load one at a time for f32/f16, or two at a time for quants - std::string load_vec_a_unaligned = (coopmat2 || tname == "f32" || tname == "f16") ? "1" : "2"; + std::string load_vec_a_unaligned = (coopmat2 || tname == "f32" || tname == "f16" || tname == "bf16") ? "1" : load_vec_quant; // For aligned matmul loads - std::string load_vec_a = (coopmat2 || tname == "f32" || tname == "f16") ? load_vec : "2"; + std::string load_vec_a = (coopmat2 || tname == "f32" || tname == "f16" || tname == "bf16") ? load_vec : load_vec_quant; // don't generate f32 variants for coopmat2 if (!coopmat2) { - string_to_spv(shader_name + "_" + tname + "_f32", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}}), fp16, coopmat, coopmat2, f16acc); - string_to_spv(shader_name + "_" + tname + "_f32_aligned", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_" + tname + "_f32", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE(tname)}, {data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_" + tname + "_f32_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE(tname)}, {data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); } if (tname != "f16" && tname != "f32") { - string_to_spv(shader_name + "_" + tname + "_f16", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}}), fp16, coopmat, coopmat2, f16acc); - string_to_spv(shader_name + "_" + tname + "_f16_aligned", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_" + tname + "_f16", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE(tname)}, {data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_" + tname + "_f16_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE(tname)}, {data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); } + +#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) + if (!coopmat && !coopmat2 && !matmul_id && (tname == "q4_0" || tname == "q4_1" || tname == "q5_0" || tname == "q5_1" || tname == "q8_0")) { 
+ string_to_spv(shader_name + "_" + tname + "_q8_1", "mul_mmq.comp", merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE(tname)}, {data_a_key, "1"}, {"D_TYPE", "float"},}), fp16, coopmat, coopmat2, f16acc); + } +#endif } } @@ -371,7 +421,6 @@ void process_shaders() { #endif } -#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) // flash attention for (const auto& f16acc : {false, true}) { std::string acctype = f16acc ? "float16_t" : "float"; @@ -380,7 +429,9 @@ void process_shaders() { if (tname == "f32") { continue; } + if (tname == "bf16") continue; +#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) if (tname == "f16") { string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp", merge_maps(base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}}), true, false, true, f16acc); @@ -389,14 +440,22 @@ void process_shaders() { string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}, {"DEQUANTFUNC", "dequantFunc"+to_uppercase(tname) }, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, false, true, f16acc); } +#endif + if (tname == "f16") { + string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp", + merge_maps(base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}}), true, false, false, f16acc); + } else if (tname == "q4_0" || tname == "q8_0") { + std::string data_a_key = "DATA_A_" + to_uppercase(tname); + string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp", + merge_maps(base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, false, false, f16acc); + } } } -#endif for (const auto& tname : type_names) { // mul mat vec std::string data_a_key = "DATA_A_" + to_uppercase(tname); - std::string shader = (string_ends_with(tname, "_k") || string_starts_with(tname, "iq1_")) ? "mul_mat_vec_" + tname + ".comp" : "mul_mat_vec.comp"; + std::string shader = (string_ends_with(tname, "_k") || string_starts_with(tname, "iq1_") || string_starts_with(tname, "iq2_") || string_starts_with(tname, "iq3_")) ? "mul_mat_vec_" + tname + ".comp" : "mul_mat_vec.comp"; string_to_spv("mul_mat_vec_" + tname + "_f32_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}})); string_to_spv("mul_mat_vec_" + tname + "_f16_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float16_t"}, {"B_TYPE_VEC2", "f16vec2"}, {"B_TYPE_VEC4", "f16vec4"}, {"D_TYPE", "float"}})); @@ -404,12 +463,12 @@ void process_shaders() { string_to_spv("mul_mat_vec_id_" + tname + "_f32", shader, merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}})); // Dequant shaders - if (tname != "f16") { + if (tname != "f16" && tname != "bf16") { string_to_spv("dequant_" + tname, "dequant_" + tname + ".comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float16_t"}})); } if (!string_ends_with(tname, "_k")) { - shader = (tname == "f32" || tname == "f16") ? "get_rows.comp" : "get_rows_quant.comp"; + shader = (tname == "f32" || tname == "f16" || tname == "bf16") ? 
"get_rows.comp" : "get_rows_quant.comp"; if (tname == "f16") { string_to_spv("get_rows_" + tname, shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}})); @@ -420,35 +479,62 @@ void process_shaders() { } } - string_to_spv("mul_mat_vec_p021_f16_f32", "mul_mat_vec_p021.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}); - string_to_spv("mul_mat_vec_nc_f16_f32", "mul_mat_vec_nc.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("mul_mat_vec_p021_f16_f32_subgroup_add", "mul_mat_vec_p021.comp", {{"A_TYPE", "float16_t"}, {"A_TYPE_VEC4", "f16vec4"}, {"B_TYPE", "float"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}}); + string_to_spv("mul_mat_vec_p021_f16_f32", "mul_mat_vec_p021.comp", {{"A_TYPE", "float16_t"}, {"A_TYPE_VEC4", "f16vec4"}, {"B_TYPE", "float"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}}); + string_to_spv("mul_mat_vec_nc_f16_f32", "mul_mat_vec_nc.comp", {{"A_TYPE", "float16_t"}, {"A_TYPE_VEC4", "f16vec4"}, {"B_TYPE", "float"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}}); // Norms string_to_spv("norm_f32", "norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); string_to_spv("group_norm_f32", "group_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); string_to_spv("rms_norm_f32", "rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); string_to_spv("rms_norm_back_f32", "rms_norm_back.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}})); + string_to_spv("l2_norm_f32", "l2_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); string_to_spv("cpy_f32_f32", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); string_to_spv("cpy_f32_f16", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}}); string_to_spv("cpy_f16_f16", "copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}}); + string_to_spv("cpy_f16_f32", "copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}}); + string_to_spv("cpy_f32_bf16","copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "uint16_t"}, {"DATA_D_BF16", "1"}}); string_to_spv("contig_cpy_f32_f32", "contig_copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); string_to_spv("contig_cpy_f32_f16", "contig_copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}}); string_to_spv("contig_cpy_f16_f16", "contig_copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}}); + string_to_spv("contig_cpy_f16_f32", "contig_copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}}); + string_to_spv("contig_cpy_f32_bf16","contig_copy.comp",{{"A_TYPE", "float"}, {"D_TYPE", "uint16_t"}, {"DATA_D_BF16", "1"}}); for (std::string t : {"q4_0", "q4_1", "q5_0", "q5_1", "q8_0", "iq4_nl"}) { string_to_spv("cpy_f32_" + t, "copy_to_quant.comp", {{"DATA_A_" + to_uppercase(t), "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); + string_to_spv("cpy_f32_" + t + "_rte", "copy_to_quant.comp", {{"DATA_A_" + to_uppercase(t), "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}}); string_to_spv("cpy_" + t + "_f32", "copy_from_quant.comp", {{"DATA_A_" + to_uppercase(t), "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); } - string_to_spv("add_f32", "add.comp", {{"A_TYPE", "float"}, 
{"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); - string_to_spv("add_f16_f32_f16", "add.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"FLOAT_TYPE", "float"}}); + auto get_type_str = [](bool f16) { + return f16 ? "float16_t" : "float"; + }; + auto get_suffix = [](bool src0_f16, bool src1_f16, bool dst_f16) { + std::string s; + s += std::string(src0_f16 ? "_f16" : "_f32"); + s += std::string(src1_f16 ? "_f16" : "_f32"); + s += std::string(dst_f16 ? "_f16" : "_f32"); + return s; + }; + for (std::string op : {"add", "sub", "mul", "div"}) { + for (auto src0_f16 : {false, true}) { + for (auto src1_f16 : {false, true}) { + for (auto dst_f16 : {false, true}) { + auto name = op + get_suffix(src0_f16, src1_f16, dst_f16); + string_to_spv(name.c_str(), op + ".comp", {{"A_TYPE", get_type_str(src0_f16)}, {"B_TYPE", get_type_str(src1_f16)}, {"D_TYPE", get_type_str(dst_f16)}, {"FLOAT_TYPE", "float"}}); + } + } + } + } string_to_spv("sub_f32", "sub.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); string_to_spv("acc_f32", "acc.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); string_to_spv("split_k_reduce", "mul_mat_split_k_reduce.comp", {}); + string_to_spv("fa_split_k_reduce", "flash_attn_split_k_reduce.comp", {}); + string_to_spv("quantize_q8_1", "quantize_q8_1.comp", {}); string_to_spv("mul_f32", "mul.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); @@ -475,14 +561,21 @@ void process_shaders() { string_to_spv("upscale_f32", "upscale.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}); - string_to_spv("gelu_f32", "gelu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); - string_to_spv("gelu_quick_f32", "gelu_quick.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); - string_to_spv("silu_f32", "silu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); - string_to_spv("silu_back_f32", "silu_back.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}); - string_to_spv("relu_f32", "relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); - string_to_spv("leaky_relu_f32", "leaky_relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); - string_to_spv("tanh_f32", "tanh.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); - string_to_spv("sigmoid_f32", "sigmoid.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("gelu_f16", "gelu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("gelu_f32", "gelu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("gelu_quick_f16", "gelu_quick.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("gelu_quick_f32", "gelu_quick.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("silu_f16", "silu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("silu_f32", "silu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("relu_f16", "relu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("relu_f32", "relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("tanh_f16", "tanh.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("tanh_f32", "tanh.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("sigmoid_f16", "sigmoid.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("sigmoid_f32", "sigmoid.comp", {{"A_TYPE", 
"float"}, {"D_TYPE", "float"}}); + + string_to_spv("leaky_relu_f32", "leaky_relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("silu_back_f32", "silu_back.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}); string_to_spv("diag_mask_inf_f32", "diag_mask_inf.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); @@ -522,8 +615,13 @@ void process_shaders() { string_to_spv("rwkv_wkv6_f32", "wkv6.comp", merge_maps(base_dict, {{"A_TYPE", "float"}})); + string_to_spv("rwkv_wkv7_f32", "wkv7.comp", merge_maps(base_dict, {{"A_TYPE", "float"}})); + string_to_spv("opt_step_adamw_f32", "opt_step_adamw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}})); + string_to_spv("conv2d_dw_whcn_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"WHCN", "1"}})); + string_to_spv("conv2d_dw_cwhn_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"CWHN", "1"}})); + for (auto &c : compiles) { c.wait(); } @@ -578,7 +676,12 @@ void write_output_files() { std::remove(path.c_str()); } } - + for (const char *op : {"add", "sub", "mul", "div"}) { + fprintf(hdr, "extern unsigned char *%s_data[2][2][2];\n", op); + fprintf(hdr, "extern uint64_t %s_len[2][2][2];\n", op); + fprintf(src, "unsigned char *%s_data[2][2][2] = {{{%s_f32_f32_f32_data, %s_f32_f32_f16_data}, {%s_f32_f16_f32_data, %s_f32_f16_f16_data}}, {{%s_f16_f32_f32_data, %s_f16_f32_f16_data}, {%s_f16_f16_f32_data, %s_f16_f16_f16_data}}};\n", op, op, op, op, op, op, op, op, op); + fprintf(src, "uint64_t %s_len[2][2][2] = {{{%s_f32_f32_f32_len, %s_f32_f32_f16_len}, {%s_f32_f16_f32_len, %s_f32_f16_f16_len}}, {{%s_f16_f32_f32_len, %s_f16_f32_f16_len}, {%s_f16_f16_f32_len, %s_f16_f16_f16_len}}};\n", op, op, op, op, op, op, op, op, op); + } fclose(hdr); fclose(src); } diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/wkv7.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/wkv7.comp new file mode 100644 index 000000000..88c1c02b3 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/wkv7.comp @@ -0,0 +1,91 @@ +#version 450 + +#extension GL_EXT_control_flow_attributes : require + +#define BLOCK_SIZE 64 +layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in; + +layout(push_constant) uniform Parameters { + uint B; + uint T; + uint C; + uint H; +}; + +layout(binding = 0) readonly buffer RBuf { A_TYPE r[]; }; +layout(binding = 1) readonly buffer WBuf { A_TYPE w[]; }; +layout(binding = 2) readonly buffer KBuf { A_TYPE k[]; }; +layout(binding = 3) readonly buffer VBuf { A_TYPE v[]; }; +layout(binding = 4) readonly buffer ABuf { A_TYPE a[]; }; +layout(binding = 5) readonly buffer BBuf { A_TYPE b[]; }; +layout(binding = 6) readonly buffer StateBuf { A_TYPE state_in[]; }; +layout(binding = 7) buffer DstBuf { A_TYPE dst[]; }; + +shared A_TYPE _r[BLOCK_SIZE], _w[BLOCK_SIZE], _k[BLOCK_SIZE], _a[BLOCK_SIZE], _b[BLOCK_SIZE]; + +void main() { + const uint head_size = BLOCK_SIZE; + const uint batch_id = gl_WorkGroupID.x / H; + const uint head_id = gl_WorkGroupID.x % H; + const uint tid = gl_LocalInvocationID.x; + + const uint state_size = C * head_size; + const uint n_seq_tokens = T / B; + + if (batch_id >= B || head_id >= H) { + return; + } + + A_TYPE state[BLOCK_SIZE]; + [[unroll]] for (uint i = 0; i < head_size; i++) { + state[i] = state_in[batch_id * state_size + head_id * head_size * head_size + + tid * head_size + i]; + } + + const uint start_t = batch_id * 
n_seq_tokens * C + head_id * head_size + tid; + const uint end_t = (batch_id + 1) * n_seq_tokens * C + head_id * head_size + tid; + + for (uint t = start_t; t < end_t; t += C) { + barrier(); + _r[tid] = r[t]; + _w[tid] = w[t]; + _k[tid] = k[t]; + _a[tid] = a[t]; + _b[tid] = b[t]; + barrier(); + + A_TYPE sa = 0.0; + [[unroll]] for (uint j = 0; j < head_size; j += 4) { + vec4 s_vec = vec4(state[j], state[j+1], state[j+2], state[j+3]); + vec4 a_vec = vec4(_a[j], _a[j+1], _a[j+2], _a[j+3]); + sa += dot(s_vec, a_vec); + } + + const A_TYPE v_val = v[t]; + A_TYPE y = 0.0; + + [[unroll]] for (uint j = 0; j < head_size; j += 4) { + vec4 r_vec = vec4(_r[j], _r[j+1], _r[j+2], _r[j+3]); + vec4 w_vec = vec4(_w[j], _w[j+1], _w[j+2], _w[j+3]); + vec4 k_vec = vec4(_k[j], _k[j+1], _k[j+2], _k[j+3]); + vec4 b_vec = vec4(_b[j], _b[j+1], _b[j+2], _b[j+3]); + vec4 s_vec = vec4(state[j], state[j+1], state[j+2], state[j+3]); + + vec4 kv = k_vec * v_val; + s_vec = s_vec * w_vec + kv + sa * b_vec; + y += dot(r_vec, s_vec); + + state[j] = s_vec.x; + state[j+1] = s_vec.y; + state[j+2] = s_vec.z; + state[j+3] = s_vec.w; + } + + dst[t] = y; + } + + [[unroll]] for (uint i = 0; i < head_size; i++) { + dst[T * C + batch_id * state_size + head_id * head_size * head_size + + tid * head_size + i] = state[i]; + } +}
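For reference, the new wkv7.comp kernel above evaluates the RWKV-7 recurrence with one thread per output channel of a head: for each token it first reduces sa = dot(state_row, a), then updates every element of its state row with state = state * w + k * v + sa * b while accumulating the output y = dot(r, state). The state row stays in registers and r/w/k/a/b are staged through shared memory once per token, which is what the two barrier() calls bracket. A minimal scalar C++ sketch of that per-token update follows; it mirrors the shader's math for a single head, with an illustrative name (wkv7_step) that is not part of the ggml API.

#include <vector>

// Scalar reference of one token step of the wkv7 recurrence for a single head.
// state is head_size x head_size, indexed state[out * head_size + in];
// r, w, k, a, b are per-token vectors of length head_size; v_out is v at channel `out`.
float wkv7_step(std::vector<float> &state, const std::vector<float> &r,
                const std::vector<float> &w, const std::vector<float> &k,
                const std::vector<float> &a, const std::vector<float> &b,
                float v_out, int out, int head_size) {
    float *s = state.data() + out * head_size;

    // sa = dot(state_row, a), computed before the row is updated
    float sa = 0.0f;
    for (int j = 0; j < head_size; ++j) {
        sa += s[j] * a[j];
    }

    // state_row = state_row * w + k * v + sa * b;  y = dot(r, state_row)
    float y = 0.0f;
    for (int j = 0; j < head_size; ++j) {
        s[j] = s[j] * w[j] + k[j] * v_out + sa * b[j];
        y += r[j] * s[j];
    }
    return y;
}

The shader vectorizes the two inner loops over j in steps of four (vec4 dot products), but the arithmetic is the same as in this scalar sketch; after the last token each thread writes its state row back to dst past the T * C outputs, exactly as the final loop in the shader does.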