From 5305e2ad14a594a0e19fb1022312158346ab00ee Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Tue, 15 Jul 2025 14:50:01 -0600 Subject: [PATCH] feat: Sync llama.cpp Branch: GraniteFour Signed-off-by: Gabe Goodhart --- llama/build-info.cpp | 2 +- llama/llama.cpp/include/llama.h | 13 +- llama/llama.cpp/src/llama-arch.cpp | 60 + llama/llama.cpp/src/llama-arch.h | 7 + llama/llama.cpp/src/llama-chat.cpp | 25 +- llama/llama.cpp/src/llama-chat.h | 1 + llama/llama.cpp/src/llama-context.cpp | 15 +- llama/llama.cpp/src/llama-hparams.cpp | 5 + llama/llama.cpp/src/llama-hparams.h | 4 +- llama/llama.cpp/src/llama-model.cpp | 626 +++++++ llama/llama.cpp/src/llama-model.h | 11 + llama/llama.cpp/src/llama-quant.cpp | 4 +- llama/llama.cpp/src/llama-vocab.cpp | 363 ++++- llama/llama.cpp/src/llama-vocab.h | 1 + llama/llama.cpp/src/unicode.cpp | 207 +++ llama/llama.cpp/src/unicode.h | 2 + .../ggml/src/ggml-cpu/llamafile/sgemm.cpp | 1431 ++++------------- .../ggml/ggml/src/ggml-cuda/ggml-cuda.cu | 15 + .../ggml/ggml/src/ggml-cuda/set-rows.cu | 151 ++ .../ggml/ggml/src/ggml-cuda/set-rows.cuh | 7 + .../ggml/ggml/src/ggml-cuda/ssm-conv.cu | 12 +- ml/backend/ggml/ggml/src/ggml-cuda/unary.cu | 7 + ml/backend/ggml/ggml/src/ggml-cuda/unary.cuh | 2 + .../ggml/ggml/src/ggml-cuda/vendors/hip.h | 19 +- .../src/ggml-metal/ggml-metal-embed.metal | 45 + .../ggml/ggml/src/ggml-metal/ggml-metal.m | 90 ++ .../ggml/ggml/src/ggml-metal/ggml-metal.metal | 45 + 27 files changed, 2050 insertions(+), 1120 deletions(-) create mode 100644 ml/backend/ggml/ggml/src/ggml-cuda/set-rows.cu create mode 100644 ml/backend/ggml/ggml/src/ggml-cuda/set-rows.cuh diff --git a/llama/build-info.cpp b/llama/build-info.cpp index fb1e51f03..ab60add7d 100644 --- a/llama/build-info.cpp +++ b/llama/build-info.cpp @@ -1,4 +1,4 @@ int LLAMA_BUILD_NUMBER = 0; -char const *LLAMA_COMMIT = "aaa088d87f9006d56866085fe46e4b2755ef723f"; +char const *LLAMA_COMMIT = "4a4f426944e79b79e389f9ed7b34831cb9b637ad"; char const *LLAMA_COMPILER = ""; char const *LLAMA_BUILD_TARGET = ""; diff --git a/llama/llama.cpp/include/llama.h b/llama/llama.cpp/include/llama.h index f73b1ab65..c83b75915 100644 --- a/llama/llama.cpp/include/llama.h +++ b/llama/llama.cpp/include/llama.h @@ -71,12 +71,13 @@ extern "C" { typedef int32_t llama_seq_id; enum llama_vocab_type { - LLAMA_VOCAB_TYPE_NONE = 0, // For models without vocab - LLAMA_VOCAB_TYPE_SPM = 1, // LLaMA tokenizer based on byte-level BPE with byte fallback - LLAMA_VOCAB_TYPE_BPE = 2, // GPT-2 tokenizer based on byte-level BPE - LLAMA_VOCAB_TYPE_WPM = 3, // BERT tokenizer based on WordPiece - LLAMA_VOCAB_TYPE_UGM = 4, // T5 tokenizer based on Unigram - LLAMA_VOCAB_TYPE_RWKV = 5, // RWKV tokenizer based on greedy tokenization + LLAMA_VOCAB_TYPE_NONE = 0, // For models without vocab + LLAMA_VOCAB_TYPE_SPM = 1, // LLaMA tokenizer based on byte-level BPE with byte fallback + LLAMA_VOCAB_TYPE_BPE = 2, // GPT-2 tokenizer based on byte-level BPE + LLAMA_VOCAB_TYPE_WPM = 3, // BERT tokenizer based on WordPiece + LLAMA_VOCAB_TYPE_UGM = 4, // T5 tokenizer based on Unigram + LLAMA_VOCAB_TYPE_RWKV = 5, // RWKV tokenizer based on greedy tokenization + LLAMA_VOCAB_TYPE_PLAMO2 = 6, // PLaMo-2 tokenizer based on Aho-Corasick with dynamic programming }; enum llama_rope_type { diff --git a/llama/llama.cpp/src/llama-arch.cpp b/llama/llama.cpp/src/llama-arch.cpp index d9d5ec655..78b50e9e9 100644 --- a/llama/llama.cpp/src/llama-arch.cpp +++ b/llama/llama.cpp/src/llama-arch.cpp @@ -34,6 +34,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_PHI3, "phi3" }, { LLM_ARCH_PHIMOE, "phimoe" }, { LLM_ARCH_PLAMO, "plamo" }, + { LLM_ARCH_PLAMO2, "plamo2" }, { LLM_ARCH_CODESHELL, "codeshell" }, { LLM_ARCH_ORION, "orion" }, { LLM_ARCH_INTERNLM2, "internlm2" }, @@ -84,6 +85,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_ERNIE4_5, "ernie4_5" }, { LLM_ARCH_HUNYUAN_MOE, "hunyuan-moe" }, { LLM_ARCH_SMOLLM3, "smollm3" }, + { LLM_ARCH_LFM2, "lfm2" }, { LLM_ARCH_UNKNOWN, "(unknown)" }, }; @@ -190,6 +192,8 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_CLASSIFIER_OUTPUT_LABELS, "%s.classifier.output_labels" }, + { LLM_KV_SHORTCONV_L_CACHE, "%s.shortconv.l_cache" }, + { LLM_KV_TOKENIZER_MODEL, "tokenizer.ggml.model" }, { LLM_KV_TOKENIZER_PRE, "tokenizer.ggml.pre" }, { LLM_KV_TOKENIZER_LIST, "tokenizer.ggml.tokens" }, @@ -783,6 +787,36 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, }, }, + { + LLM_ARCH_PLAMO2, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ROPE_FREQS, "rope_freqs" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, + { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, + { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_SSM_IN, "blk.%d.ssm_in" }, + { LLM_TENSOR_SSM_CONV1D, "blk.%d.ssm_conv1d" }, + { LLM_TENSOR_SSM_X, "blk.%d.ssm_x" }, + { LLM_TENSOR_SSM_DT, "blk.%d.ssm_dt" }, + { LLM_TENSOR_SSM_A, "blk.%d.ssm_a" }, + { LLM_TENSOR_SSM_D, "blk.%d.ssm_d" }, + { LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" }, + { LLM_TENSOR_SSM_DT_NORM, "blk.%d.ssm_dt_norm" }, + { LLM_TENSOR_SSM_B_NORM, "blk.%d.ssm_b_norm" }, + { LLM_TENSOR_SSM_C_NORM, "blk.%d.ssm_c_norm" }, + { LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" }, + { LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" }, + }, + }, { LLM_ARCH_CODESHELL, { @@ -1850,6 +1884,27 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, }, }, + { + LLM_ARCH_LFM2, + { + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, + { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_SHORTCONV_CONV, "blk.%d.shortconv.conv" }, + { LLM_TENSOR_SHORTCONV_INPROJ, "blk.%d.shortconv.in_proj" }, + { LLM_TENSOR_SHORTCONV_OUTPROJ, "blk.%d.shortconv.out_proj" }, + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" }, + } + }, { LLM_ARCH_UNKNOWN, { @@ -2018,6 +2073,9 @@ static const std::map LLM_TENSOR_INFOS = { {LLM_TENSOR_CONVNEXT_PW1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_CONVNEXT_PW2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_CONVNEXT_GAMMA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_SHORTCONV_CONV, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_SSM_CONV}}, + {LLM_TENSOR_SHORTCONV_INPROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_SHORTCONV_OUTPROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, }; LLM_KV::LLM_KV(llm_arch arch, const char * suffix) : arch(arch), suffix(suffix) {} @@ -2088,7 +2146,9 @@ bool llm_arch_is_hybrid(const llm_arch & arch) { switch (arch) { case LLM_ARCH_JAMBA: case LLM_ARCH_FALCON_H1: + case LLM_ARCH_PLAMO2: case LLM_ARCH_GRANITE_HYBRID: + case LLM_ARCH_LFM2: return true; default: return false; diff --git a/llama/llama.cpp/src/llama-arch.h b/llama/llama.cpp/src/llama-arch.h index 2cb0fd95d..6b0a1df95 100644 --- a/llama/llama.cpp/src/llama-arch.h +++ b/llama/llama.cpp/src/llama-arch.h @@ -38,6 +38,7 @@ enum llm_arch { LLM_ARCH_PHI3, LLM_ARCH_PHIMOE, LLM_ARCH_PLAMO, + LLM_ARCH_PLAMO2, LLM_ARCH_CODESHELL, LLM_ARCH_ORION, LLM_ARCH_INTERNLM2, @@ -88,6 +89,7 @@ enum llm_arch { LLM_ARCH_ERNIE4_5, LLM_ARCH_HUNYUAN_MOE, LLM_ARCH_SMOLLM3, + LLM_ARCH_LFM2, LLM_ARCH_UNKNOWN, }; @@ -229,6 +231,8 @@ enum llm_kv { LLM_KV_CLASSIFIER_OUTPUT_LABELS, + LLM_KV_SHORTCONV_L_CACHE, + // deprecated: LLM_KV_TOKENIZER_PREFIX_ID, LLM_KV_TOKENIZER_SUFFIX_ID, @@ -399,6 +403,9 @@ enum llm_tensor { LLM_TENSOR_POS_NET_ATTN_K, LLM_TENSOR_POS_NET_ATTN_V, LLM_TENSOR_POS_NET_ATTN_OUT, + LLM_TENSOR_SHORTCONV_CONV, + LLM_TENSOR_SHORTCONV_INPROJ, + LLM_TENSOR_SHORTCONV_OUTPROJ, }; enum llm_tensor_layer { diff --git a/llama/llama.cpp/src/llama-chat.cpp b/llama/llama.cpp/src/llama-chat.cpp index cbc19d3c4..240937ece 100644 --- a/llama/llama.cpp/src/llama-chat.cpp +++ b/llama/llama.cpp/src/llama-chat.cpp @@ -65,6 +65,7 @@ static const std::map LLM_CHAT_TEMPLATES = { { "llama4", LLM_CHAT_TEMPLATE_LLAMA4 }, { "smolvlm", LLM_CHAT_TEMPLATE_SMOLVLM }, { "hunyuan-moe", LLM_CHAT_TEMPLATE_HUNYUAN_MOE }, + { "kimi-k2", LLM_CHAT_TEMPLATE_KIMI_K2 }, }; llm_chat_template llm_chat_template_from_str(const std::string & name) { @@ -170,7 +171,7 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) { // ref: https://huggingface.co/LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct/discussions/8#66bae61b1893d14ee8ed85bb // EXAONE-3.0-7.8B-Instruct return LLM_CHAT_TEMPLATE_EXAONE_3; - } else if (tmpl_contains("rwkv-world")) { + } else if (tmpl_contains("rwkv-world") || tmpl_contains("{{- 'User: ' + message['content']|trim + '\\n\\n' -}}")) { return LLM_CHAT_TEMPLATE_RWKV_WORLD; } else if (tmpl_contains("<|start_of_role|>")) { return LLM_CHAT_TEMPLATE_GRANITE; @@ -188,6 +189,8 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) { return LLM_CHAT_TEMPLATE_DOTS1; } else if (tmpl_contains("<|startoftext|>") && tmpl_contains("<|extra_4|>")) { return LLM_CHAT_TEMPLATE_HUNYUAN_MOE; + } else if (tmpl_contains("<|im_assistant|>assistant<|im_middle|>")) { + return LLM_CHAT_TEMPLATE_KIMI_K2; } return LLM_CHAT_TEMPLATE_UNKNOWN; } @@ -680,6 +683,26 @@ int32_t llm_chat_apply_template( ss << "<|startoftext|>" << message->content << "<|extra_0|>"; } } + } else if (tmpl == LLM_CHAT_TEMPLATE_KIMI_K2) { + // moonshotai/Kimi-K2-Instruct + for (auto message : chat) { + std::string role(message->role); + if (role == "system") { + ss << "<|im_system|>system<|im_middle|>"; + } else if (role == "user") { + ss << "<|im_user|>user<|im_middle|>"; + } else if (role == "assistant") { + ss << "<|im_assistant|>assistant<|im_middle|>"; + } else if (role == "tool") { + ss << "<|im_system|>tool<|im_middle|>"; + } + + ss << message->content << "<|im_end|>"; + + if (add_ass) { + ss << "<|im_assistant|>assistant<|im_middle|>"; + } + } } else { // template not supported return -1; diff --git a/llama/llama.cpp/src/llama-chat.h b/llama/llama.cpp/src/llama-chat.h index b621fda28..cab053348 100644 --- a/llama/llama.cpp/src/llama-chat.h +++ b/llama/llama.cpp/src/llama-chat.h @@ -45,6 +45,7 @@ enum llm_chat_template { LLM_CHAT_TEMPLATE_SMOLVLM, LLM_CHAT_TEMPLATE_DOTS1, LLM_CHAT_TEMPLATE_HUNYUAN_MOE, + LLM_CHAT_TEMPLATE_KIMI_K2, LLM_CHAT_TEMPLATE_UNKNOWN, }; diff --git a/llama/llama.cpp/src/llama-context.cpp b/llama/llama.cpp/src/llama-context.cpp index 06e93b19c..7c07b047b 100644 --- a/llama/llama.cpp/src/llama-context.cpp +++ b/llama/llama.cpp/src/llama-context.cpp @@ -731,7 +731,8 @@ int llama_context::encode(const llama_batch & batch_inp) { const auto & hparams = model.hparams; - const int64_t n_embd = hparams.n_embd; + const int64_t n_embd = hparams.n_embd; + const int32_t n_vocab = model.vocab.n_tokens(); // note: during encode, we always pass the full sequence starting from pos = 0 if (!balloc->init(batch_inp, model.vocab, nullptr, n_embd, true)) { @@ -791,10 +792,20 @@ int llama_context::encode(const llama_batch & batch_inp) { } } + auto * t_logits = res->get_logits(); auto * t_embd = res->get_embd_pooled() ? res->get_embd_pooled() : res->get_embd(); + // extract logits + if (logits && t_logits) { + ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched.get(), t_logits); + GGML_ASSERT(backend_res != nullptr); + GGML_ASSERT(logits != nullptr); + + ggml_backend_tensor_get_async(backend_res, t_logits, logits, 0, n_tokens*n_vocab*sizeof(float)); + } + // extract embeddings - if (t_embd) { + if (embd && t_embd) { ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd); GGML_ASSERT(backend_embd != nullptr); diff --git a/llama/llama.cpp/src/llama-hparams.cpp b/llama/llama.cpp/src/llama-hparams.cpp index f1c965b8a..b4e456e3e 100644 --- a/llama/llama.cpp/src/llama-hparams.cpp +++ b/llama/llama.cpp/src/llama-hparams.cpp @@ -71,6 +71,11 @@ uint32_t llama_hparams::n_embd_r() const { return token_shift_count * n_embd; } + if (n_shortconv_l_cache != 0) { + // for LFM2 models + return n_embd * (n_shortconv_l_cache - 1); + } + // TODO: maybe support other convolution strides than 1 // NOTE: since the first column of the conv_state is shifted out each time, it's not actually needed // Corresponds to Mamba's conv_states size diff --git a/llama/llama.cpp/src/llama-hparams.h b/llama/llama.cpp/src/llama-hparams.h index 906fa1851..1badf9921 100644 --- a/llama/llama.cpp/src/llama-hparams.h +++ b/llama/llama.cpp/src/llama-hparams.h @@ -6,7 +6,7 @@ // bump if necessary #define LLAMA_MAX_LAYERS 512 -#define LLAMA_MAX_EXPERTS 256 // DeepSeekV3 +#define LLAMA_MAX_EXPERTS 384 // Kimi-K2 enum llama_expert_gating_func_type { LLAMA_EXPERT_GATING_FUNC_TYPE_NONE = 0, @@ -55,6 +55,8 @@ struct llama_hparams { struct llama_hparams_posnet posnet; struct llama_hparams_convnext convnext; + uint32_t n_shortconv_l_cache = 0; + std::array n_head_arr; std::array n_head_kv_arr; std::array n_ff_arr; diff --git a/llama/llama.cpp/src/llama-model.cpp b/llama/llama.cpp/src/llama-model.cpp index 35d7a4dfb..5cd1d0ec9 100644 --- a/llama/llama.cpp/src/llama-model.cpp +++ b/llama/llama.cpp/src/llama-model.cpp @@ -43,15 +43,18 @@ const char * llm_type_name(llm_type type) { case LLM_TYPE_256M: return "256M"; case LLM_TYPE_270M: return "270M"; case LLM_TYPE_335M: return "335M"; + case LLM_TYPE_350M: return "350M"; case LLM_TYPE_410M: return "410M"; case LLM_TYPE_450M: return "450M"; case LLM_TYPE_475M: return "475M"; + case LLM_TYPE_700M: return "700M"; case LLM_TYPE_770M: return "770M"; case LLM_TYPE_780M: return "780M"; case LLM_TYPE_0_3B: return "0.3B"; case LLM_TYPE_0_5B: return "0.5B"; case LLM_TYPE_0_6B: return "0.6B"; case LLM_TYPE_1B: return "1B"; + case LLM_TYPE_1_2B: return "1.2B"; case LLM_TYPE_1_3B: return "1.3B"; case LLM_TYPE_1_4B: return "1.4B"; case LLM_TYPE_1_5B: return "1.5B"; @@ -932,6 +935,33 @@ void llama_model::load_hparams(llama_model_loader & ml) { default: type = LLM_TYPE_UNKNOWN; } } break; + case LLM_ARCH_PLAMO2: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + // Load Mamba SSM parameters + ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv); + ml.get_key(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner); + ml.get_key(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state); + ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank); + ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group); + + for (uint32_t i = 0; i < hparams.n_layer; ++i) { + hparams.recurrent_layer_arr[i] = hparams.n_head_kv(i) == 0; + } + + switch (hparams.n_layer) { + case 16: type = LLM_TYPE_1B; break; + case 32: + if (hparams.n_embd == 2048) { + type = LLM_TYPE_2B; + } else if (hparams.n_embd == 4096) { + type = LLM_TYPE_8B; + } + break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; case LLM_ARCH_GPT2: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); @@ -1678,6 +1708,20 @@ void llama_model::load_hparams(llama_model_loader & ml) { default: type = LLM_TYPE_UNKNOWN; } } break; + case LLM_ARCH_LFM2: + { + ml.get_key(LLM_KV_SHORTCONV_L_CACHE, hparams.n_shortconv_l_cache); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + for (uint32_t il = 0; il < hparams.n_layer; ++il) { + hparams.recurrent_layer_arr[il] = hparams.n_head_kv(il) == 0; + } + switch (hparams.n_embd) { + case 1024: type = LLM_TYPE_350M; break; + case 1536: type = LLM_TYPE_700M; break; + case 2048: type = LLM_TYPE_1_2B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; default: throw std::runtime_error("unsupported model architecture"); } @@ -2936,6 +2980,73 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); } } break; + case LLM_ARCH_PLAMO2: + { + const uint32_t d_conv = hparams.ssm_d_conv; + const uint32_t d_state = hparams.ssm_d_state; + const uint32_t num_heads = hparams.ssm_dt_rank; + const uint32_t intermediate_size = hparams.ssm_d_inner; + const uint32_t head_dim = intermediate_size / num_heads; + const uint32_t qk_dim = head_dim; + const uint32_t v_dim = head_dim; + const int64_t num_attention_heads = hparams.n_head(); + const int64_t q_num_heads = num_attention_heads; + const int64_t dt_dim = std::max(64, int(hparams.n_embd / 16)); + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + bool is_mamba_layer = hparams.is_recurrent(i); + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + if (is_mamba_layer) { + layer.ssm_in = create_tensor(tn(LLM_TENSOR_SSM_IN, "weight", i), {n_embd, 2 * intermediate_size}, 0); + layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {d_conv, intermediate_size}, 0); + + layer.ssm_x = create_tensor(tn(LLM_TENSOR_SSM_X, "weight", i), {intermediate_size, dt_dim + 2*d_state}, 0); + layer.ssm_dt = create_tensor(tn(LLM_TENSOR_SSM_DT, "weight", i), {dt_dim, num_heads}, 0); + layer.ssm_dt_b = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), {num_heads}, 0); + + layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), {num_heads}, 0); + layer.ssm_d = create_tensor(tn(LLM_TENSOR_SSM_D, i), {num_heads}, 0); + + layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), {intermediate_size, n_embd}, 0); + + layer.ssm_dt_norm = create_tensor(tn(LLM_TENSOR_SSM_DT_NORM, i), {dt_dim}, 0); + layer.ssm_b_norm = create_tensor(tn(LLM_TENSOR_SSM_B_NORM, i), {d_state}, 0); + layer.ssm_c_norm = create_tensor(tn(LLM_TENSOR_SSM_C_NORM, i), {d_state}, 0); + } else { + const int64_t num_key_value_heads = hparams.n_head_kv(i); + const int64_t k_num_heads = num_key_value_heads; + const int64_t v_num_heads = num_key_value_heads; + const int64_t q_proj_dim = q_num_heads * qk_dim; + const int64_t k_proj_dim = k_num_heads * qk_dim; + const int64_t v_proj_dim = v_num_heads * v_dim; + + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, q_proj_dim + k_proj_dim + v_proj_dim}, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {head_dim, num_attention_heads}, 0); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {head_dim, k_num_heads}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {q_num_heads * v_dim, n_embd}, 0); + } + + // All layers have post-attention norm, FFN norm, and FFN tensors + layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, i), {n_embd}, 0); + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff * 2}, 0); + layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, i), {n_embd}, 0); + } + } break; case LLM_ARCH_GPT2: { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); @@ -4949,6 +5060,39 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); } } break; + case LLM_ARCH_LFM2: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + // ffn is same for transformer and conv layers + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + + // for operator_norm + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + if (!hparams.is_recurrent(i)) { + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); + GGML_ASSERT(n_embd_v_gqa == n_embd_k_gqa); + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, hparams.n_embd_k_gqa(i)}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, hparams.n_embd_v_gqa(i)}, 0); + + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + } else { + layer.shortconv.conv = create_tensor(tn(LLM_TENSOR_SHORTCONV_CONV, "weight", i), {hparams.n_shortconv_l_cache, n_embd}, 0); + layer.shortconv.in_proj = create_tensor(tn(LLM_TENSOR_SHORTCONV_INPROJ, "weight", i), {n_embd, 3 * n_embd}, 0); + layer.shortconv.out_proj = create_tensor(tn(LLM_TENSOR_SHORTCONV_OUTPROJ, "weight", i), {n_embd, n_embd}, 0); + } + } + } break; default: throw std::runtime_error("unknown architecture"); } @@ -5202,6 +5346,7 @@ void llama_model::print_info() const { arch == LLM_ARCH_MAMBA2 || arch == LLM_ARCH_JAMBA || arch == LLM_ARCH_FALCON_H1 || + arch == LLM_ARCH_PLAMO2 || arch == LLM_ARCH_GRANITE_HYBRID) { LLAMA_LOG_INFO("%s: ssm_d_conv = %u\n", __func__, hparams.ssm_d_conv); LLAMA_LOG_INFO("%s: ssm_d_inner = %u\n", __func__, hparams.ssm_d_inner); @@ -15628,6 +15773,320 @@ struct llm_build_falcon_h1 : public llm_graph_context_mamba { } }; +struct llm_build_plamo2 : public llm_graph_context_mamba { + llm_build_plamo2(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context_mamba(params) { + ggml_tensor * cur; + ggml_tensor * inpL; + + // {n_embd, n_tokens} + inpL = build_inp_embd(model.tok_embd); + cb(inpL, "embedding_output", -1); + + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_hybrid = build_inp_mem_hybrid(); + + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * residual = inpL; + + // ggml_graph_add_node(gf, model.layers[il].attn_norm); + // cb(model.layers[il].attn_norm, "attn_norm", il); + + // pre_mixer_norm + cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il); + + // check if this layer is Mamba or Attention + bool is_mamba_layer = hparams.is_recurrent(il); + + if (is_mamba_layer) { + // PLaMo-2 Mamba layer + cur = build_plamo2_mamba_layer(inp_hybrid->get_recr(), gf, cur, model, ubatch, il); + } else { + // PLaMo-2 Attention layer + cur = build_plamo2_attn_layer(inp_hybrid->get_attn(), inp_pos, gf, cur, model, il); + } + + // post_mixer_norm + cur = build_norm(cur, model.layers[il].attn_post_norm, NULL, LLM_NORM_RMS, il); + cb(cur, "attn_post_norm", il); + + // residual connection + cur = ggml_add(ctx0, cur, residual); + cb(cur, "attn_residual", il); + residual = cur; + + // pre-ffn norm + cur = build_norm(cur, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, il); + cb(cur, "ffn_pre_norm", il); + + // feed-forward network + cur = build_ffn(cur, + model.layers[il].ffn_up, NULL, NULL, + NULL, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_SWIGLU, LLM_FFN_SEQ, il); + cb(cur, "ffn_out", il); + + // post ffn norm + cur = build_norm(cur, model.layers[il].ffn_post_norm, NULL, LLM_NORM_RMS, il); + cb(cur, "ffn_post_norm", il); + + if (il == n_layer - 1 && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + residual = ggml_get_rows(ctx0, residual, inp_out_ids); + } + + // residual connection + cur = ggml_add(ctx0, cur, residual); + cb(cur, "ffn_residual", il); + + inpL = cur; + } + + cur = inpL; + + // final norm + cur = build_norm(cur, model.output_norm, NULL, LLM_NORM_RMS, -1); + cb(cur, "result_norm", -1); + + // lm_head + cur = build_lora_mm(model.output, cur); + cb(cur, "result_output", -1); + + // Explicitly mark as output tensor to ensure proper backend assignment + ggml_set_output(cur); + + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } + +private: + ggml_tensor * build_plamo2_attn_layer( + llm_graph_input_attn_kv_unified * inp, + ggml_tensor * inp_pos, + ggml_cgraph * gf, + ggml_tensor * cur, + const llama_model & model, + int il) { + + // self-attention + { + // PLaMo-2 uses combined QKV tensor + ggml_tensor * qkv = build_lora_mm(model.layers[il].wqkv, cur); + cb(qkv, "qkv", il); + + // split QKV tensor into Q, K, V + const int64_t n_embd_head_q = hparams.n_embd_head_k; + const int64_t n_embd_head_k = hparams.n_embd_head_k; + const int64_t n_embd_head_v = hparams.n_embd_head_v; + int32_t n_head_kv = hparams.n_head_kv(il); + + const int64_t q_offset = 0; + const int64_t k_offset = n_embd_head_q * n_head; + const int64_t v_offset = k_offset + n_embd_head_k * n_head_kv; + + ggml_tensor * Qcur = ggml_view_3d(ctx0, qkv, n_embd_head_q, n_head, n_tokens, n_embd_head_q * sizeof(float), qkv->nb[1], q_offset * ggml_element_size(qkv)); + ggml_tensor * Kcur = ggml_view_3d(ctx0, qkv, n_embd_head_k, n_head_kv, n_tokens, n_embd_head_k * sizeof(float), qkv->nb[1], k_offset * ggml_element_size(qkv)); + ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, qkv, n_embd_head_v * n_head_kv, n_tokens, qkv->nb[1], v_offset * ggml_element_size(qkv))); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head_v, n_head_kv, n_tokens); + + Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il); + cb(Qcur, "Qcur_normed", il); + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il); + cb(Kcur, "Kcur_normed", il); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cur = build_attn(inp, gf, model.layers[il].wo, NULL, Qcur, Kcur, Vcur, NULL, NULL, 1.0f, il); + } + + cb(cur, "attn_out", il); + + return cur; + } + + ggml_tensor * build_plamo2_mamba_layer( + llm_graph_input_rs * inp, + ggml_cgraph * gf, + ggml_tensor * cur, + const llama_model & model, + const llama_ubatch & ubatch, + int il) { + + const auto * mctx_cur = inp->mctx; + + const auto kv_head = mctx_cur->get_head(); + + const int64_t d_conv = hparams.ssm_d_conv; + const int64_t d_inner = hparams.ssm_d_inner; + const int64_t d_state = hparams.ssm_d_state; + const int64_t n_heads = hparams.ssm_dt_rank; + const int64_t head_dim = d_inner / n_heads; + const int64_t n_group = hparams.ssm_n_group; + const int64_t n_seqs = ubatch.n_seqs; + + const int64_t n_seq_tokens = ubatch.n_seq_tokens; + + GGML_ASSERT(n_seqs != 0); + GGML_ASSERT(ubatch.equal_seqs); + GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs); + + ggml_tensor * conv_states_all = mctx_cur->get_r_l(il); + ggml_tensor * ssm_states_all = mctx_cur->get_s_l(il); + + ggml_tensor * conv = build_rs(inp, gf, conv_states_all, hparams.n_embd_r(), n_seqs); + conv = ggml_reshape_3d(ctx0, conv, d_conv - 1, d_inner + 2*n_group*d_state, n_seqs); + + // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs} + cur = ggml_reshape_3d(ctx0, cur, cur->ne[0], n_seq_tokens, n_seqs); + + // in_proj: {n_embd, 2*d_inner} @ {n_embd, n_seq_tokens, n_seqs} => {2*d_inner, n_seq_tokens, n_seqs} + ggml_tensor * zx = build_lora_mm(model.layers[il].ssm_in, cur); + cb(zx, "mamba_in_proj", il); + // {8192, 5, 1, 1} -> {8192, 1, 5, 1} + zx = ggml_permute(ctx0, zx, 0, 2, 1, 3); + zx = ggml_reshape_4d(ctx0, zx, head_dim * 2, n_heads, n_seq_tokens, n_seqs); + cb(zx, "mamba_in_proj_out", il); + + // split into z and x + // => {head_dim * n_heads, n_seq_tokens, n_seqs} + ggml_tensor * x = ggml_view_4d(ctx0, zx, head_dim, n_heads, n_seq_tokens, n_seqs, zx->nb[1], zx->nb[2], zx->nb[3], head_dim*ggml_element_size(zx)); + x = ggml_cont(ctx0, x); + x = ggml_reshape_3d(ctx0, x, head_dim * n_heads, n_seq_tokens, n_seqs); + // x = ggml_permute(ctx0, x, 0, 2, 1, 3); + cb(x, "mamba_x_split", il); + + ggml_tensor * z = ggml_view_4d(ctx0, zx, head_dim, n_heads, n_seq_tokens, n_seqs, zx->nb[1], zx->nb[2], zx->nb[3], 0); + cb(z, "mamba_z_split", il); + + // conv1d + { + // => {d_conv - 1 + n_seq_tokens, d_inner, n_seqs} + x = ggml_view_2d(ctx0, x, d_inner, n_seq_tokens * n_seqs, d_inner * x->nb[0], 0); + ggml_tensor * conv_x = ggml_concat(ctx0, conv, ggml_transpose(ctx0, x), 0); + cb(conv_x, "mamba_conv1d_input", il); + + // copy last (d_conv - 1) columns back into the state cache + ggml_tensor * last_conv = ggml_view_3d(ctx0, conv_x, d_conv - 1, d_inner, n_seqs, + conv_x->nb[1], conv_x->nb[2], n_seq_tokens*(conv_x->nb[0])); + + ggml_build_forward_expand(gf, + ggml_cpy(ctx0, last_conv, + ggml_view_1d(ctx0, conv_states_all, + (d_conv - 1)*(d_inner)*(n_seqs), + kv_head*(d_conv - 1)*(d_inner)*ggml_element_size(conv_states_all)))); + + // 1D convolution + x = ggml_ssm_conv(ctx0, conv_x, model.layers[il].ssm_conv1d); + cb(x, "mamba_conv1d", il); + + x = ggml_silu(ctx0, x); + cb(x, "mamba_conv1d_silu", il); + } + + // SSM + { + // bcdt_proj: {d_inner, dt_rank + 2*d_state} @ {d_inner, n_seq_tokens, n_seqs} => {dt_rank + 2*d_state, n_seq_tokens, n_seqs} + ggml_tensor * x_bcdt = build_lora_mm(model.layers[il].ssm_x, x); + cb(x_bcdt, "mamba_bcdt_proj", il); + + // split into dt, B, C + const int64_t dt_dim = std::max(64, int(hparams.n_embd / 16)); + ggml_tensor * B = ggml_view_3d(ctx0, x_bcdt, d_state, n_seq_tokens, n_seqs, x_bcdt->nb[1], x_bcdt->nb[2], 0); + ggml_tensor * C = ggml_view_3d(ctx0, x_bcdt, d_state, n_seq_tokens, n_seqs, x_bcdt->nb[1], x_bcdt->nb[2], ggml_element_size(x_bcdt)*d_state); + ggml_tensor * dt = ggml_view_3d(ctx0, x_bcdt, dt_dim, n_seq_tokens, n_seqs, x_bcdt->nb[1], x_bcdt->nb[2], ggml_element_size(x_bcdt)*(2*d_state)); + cb(B, "mamba_B_raw", il); + cb(C, "mamba_C_raw", il); + cb(dt, "mamba_dt_raw", il); + + // Apply RMS norm to dt, B, C (PLaMo-2 specific) + B = build_norm(B, model.layers[il].ssm_b_norm, NULL, LLM_NORM_RMS, il); + C = build_norm(C, model.layers[il].ssm_c_norm, NULL, LLM_NORM_RMS, il); + dt = build_norm(dt, model.layers[il].ssm_dt_norm, NULL, LLM_NORM_RMS, il); + cb(B, "mamba_B_normed", il); + cb(C, "mamba_C_normed", il); + cb(dt, "mamba_dt_normed", il); + + // dt_proj: {dt_rank, d_inner} @ {dt_rank, n_seq_tokens, n_seqs} => {d_inner, n_seq_tokens, n_seqs} + dt = build_lora_mm(model.layers[il].ssm_dt, dt); + dt = ggml_add(ctx0, dt, model.layers[il].ssm_dt_b); + cb(dt, "mamba_dt_proj", il); + + ggml_tensor * A = ggml_reshape_2d(ctx0, model.layers[il].ssm_a, 1, n_heads); + cb(A, "mamba_A", il); + + x = ggml_view_4d(ctx0, x, head_dim, n_heads, n_seq_tokens, n_seqs, head_dim * ggml_element_size(x), head_dim * n_heads * ggml_element_size(x), head_dim * n_heads * n_seq_tokens * ggml_element_size(x), 0); + B = ggml_view_4d(ctx0, B, d_state, 1, n_seq_tokens, n_seqs, d_state * B->nb[0], B->nb[1], B->nb[2], 0); + C = ggml_view_4d(ctx0, C, d_state, 1, n_seq_tokens, n_seqs, d_state * C->nb[0], C->nb[1], C->nb[2], 0); + + // use the states and the indices provided by build_recurrent_state + // (this is necessary in order to properly use the states before they are overwritten, + // while avoiding to make unnecessary copies of the states) + auto get_ssm_rows = [&](ggml_context * ctx, ggml_tensor * states, ggml_tensor * ids) { + ggml_tensor * ssm = ggml_reshape_4d(ctx, states, d_state, head_dim, n_heads, mctx_cur->get_size()); + + // Custom operator to optimize the parallel associative scan + // as described in the Annex D of the Mamba paper. + // => {d_inner, n_seq_tokens, n_seqs} and {d_state, d_inner, n_seqs} + return ggml_ssm_scan(ctx, ssm, x, dt, A, B, C, ids); + }; + + ggml_tensor * y_ssm = build_rs(inp, gf, ssm_states_all, hparams.n_embd_s(), ubatch.n_seqs, get_ssm_rows); + cb(y_ssm, "mamba_ssm_scan", il); + + // store last states + ggml_build_forward_expand(gf, + ggml_cpy(ctx0, + ggml_view_1d(ctx0, y_ssm, d_state*d_inner*n_seqs, x->nb[3]*x->ne[3]), + ggml_view_1d(ctx0, ssm_states_all, d_state*d_inner*n_seqs, + kv_head*d_state*d_inner*ggml_element_size(ssm_states_all)))); + + ggml_tensor * y = ggml_view_4d(ctx0, y_ssm, head_dim, n_heads, n_seq_tokens, n_seqs, head_dim * ggml_element_size(x), head_dim * n_heads * ggml_element_size(x), head_dim * n_heads * n_seq_tokens * ggml_element_size(x), 0); + cb(y, "mamba_y_view", il); + + // Add D parameter and apply gating with z + // {d_inner, n_seq_tokens, n_seqs} * {d_inner} => {d_inner, n_seq_tokens, n_seqs} + ggml_tensor * D = ggml_reshape_2d(ctx0, model.layers[il].ssm_d, 1, n_heads); + y = ggml_add(ctx0, y, ggml_mul(ctx0, x, D)); + cb(y, "mamba_y_add_d", il); + + y = ggml_swiglu_split(ctx0, ggml_cont(ctx0, z), y); + cb(y, "mamba_y_swiglu_z", il); + + // out_proj: {d_inner, n_embd} @ {d_inner, n_seq_tokens, n_seqs} => {n_embd, n_seq_tokens, n_seqs} + y = ggml_view_3d(ctx0, y, head_dim * n_heads, n_seq_tokens, n_seqs, y->nb[2], y->nb[3], 0); + cur = build_lora_mm(model.layers[il].ssm_out, y); + cb(cur, "mamba_out_proj", il); + } + + // {n_embd, n_seq_tokens, n_seqs} => {n_embd, n_tokens} + cur = ggml_reshape_2d(ctx0, cur, cur->ne[0], n_seq_tokens * n_seqs); + cb(cur, "mamba_out", il); + + return cur; + } +}; + struct llm_build_arcee : public llm_graph_context { llm_build_arcee(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; @@ -16061,6 +16520,163 @@ struct llm_build_smollm3 : public llm_graph_context { } }; +struct llm_build_lfm2 : public llm_graph_context { + const llama_model & model; + + llm_build_lfm2(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params), model(model) { + + ggml_tensor * cur = build_inp_embd(model.tok_embd); + cb(cur, "model.embed_tokens", -1); + + ggml_tensor * inp_pos = build_inp_pos(); + auto * inp_hybrid = build_inp_mem_hybrid(); + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + for (int il = 0; il < n_layer; ++il) { + auto * prev_cur = cur; + cur = build_norm(cur, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il); + cb(cur, "model.layers.{}.operator_norm", il); + + cur = hparams.is_recurrent(il) ? + build_shortconv_block(gf, cur, inp_hybrid->get_recr(), il) : + build_attn_block(gf, cur, inp_pos, inp_hybrid->get_attn(), il) ; + + if (il == n_layer - 1 && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + prev_cur = ggml_get_rows(ctx0, prev_cur, inp_out_ids); + } + + cur = ggml_add(ctx0, prev_cur, cur); + cur = ggml_add(ctx0, cur, build_feed_forward(cur, il)); + } + + cur = build_norm(cur, model.tok_norm, NULL, LLM_NORM_RMS, -1); + cb(cur, "model.embedding_norm", -1); + res->t_embd = cur; + + // lm_head is tied with embeddings + cur = build_lora_mm(model.tok_embd, cur); + cb(cur, "lm_head", -1); + + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } + + ggml_tensor * build_feed_forward(ggml_tensor * cur, + int il) const { + cur = build_norm(cur, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, il); + cb(cur, "model.layers.{}.ffn_norm", il); + + GGML_ASSERT(!model.layers[il].ffn_up_b); + GGML_ASSERT(!model.layers[il].ffn_gate_b); + GGML_ASSERT(!model.layers[il].ffn_down_b); + cur = build_ffn(cur, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "model.layers.{}.feed_forward.w2", il); + + return cur; + } + + ggml_tensor * build_attn_block(ggml_cgraph * gf, + ggml_tensor * cur, + ggml_tensor * inp_pos, + llm_graph_input_attn_kv_unified * inp_attn, + int il) const { + GGML_ASSERT(hparams.n_embd_v_gqa(il) == hparams.n_embd_k_gqa(il)); + auto const n_embd_head = hparams.n_embd_head_v; + auto const n_head_kv = hparams.n_head_kv(il); + + auto * q = build_lora_mm(model.layers[il].wq, cur); + cb(q, "model.layers.{}.self_attn.q_proj", il); + auto * k = build_lora_mm(model.layers[il].wk, cur); + cb(k, "model.layers.{}.self_attn.k_proj", il); + auto * v = build_lora_mm(model.layers[il].wv, cur); + cb(v, "model.layers.{}.self_attn.v_proj", il); + + q = ggml_reshape_3d(ctx0, q, n_embd_head, n_head, n_tokens); + k = ggml_reshape_3d(ctx0, k, n_embd_head, n_head_kv, n_tokens); + v = ggml_reshape_3d(ctx0, v, n_embd_head, n_head_kv, n_tokens); + + // qk norm + q = build_norm(q, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il); + cb(q, "model.layers.{}.self_attn.q_layernorm", il); + k = build_norm(k, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il); + cb(k, "model.layers.{}.self_attn.k_layernorm", il); + + // RoPE + q = ggml_rope_ext( + ctx0, q, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + k = ggml_rope_ext( + ctx0, k, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cur = build_attn(inp_attn, gf, model.layers[il].wo, NULL, + q, k, v, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + + cb(cur, "model.layers.{}.self_attn.out_proj", il); + + return cur; + } + + ggml_tensor * build_shortconv_block(ggml_cgraph * gf, + ggml_tensor * cur, + llm_graph_input_rs * inp_recr, + int il) { + const auto * mctx_cur = static_cast(mctx)->get_recr(); + + auto * bcx = build_lora_mm(model.layers[il].shortconv.in_proj, cur); + cb(bcx, "model.layers.{}.conv.in_proj", il); + + constexpr auto n_chunks = 3; + GGML_ASSERT(bcx->ne[0] % n_chunks == 0); + auto const chunk_size = bcx->ne[0] / n_chunks; + auto * b = ggml_view_2d(ctx0, bcx, chunk_size, bcx->ne[1], bcx->nb[1], 0 * chunk_size * ggml_element_size(bcx)); + auto * c = ggml_view_2d(ctx0, bcx, chunk_size, bcx->ne[1], bcx->nb[1], 1 * chunk_size * ggml_element_size(bcx)); + auto * x = ggml_view_2d(ctx0, bcx, chunk_size, bcx->ne[1], bcx->nb[1], 2 * chunk_size * ggml_element_size(bcx)); + + auto * bx = ggml_transpose(ctx0, ggml_mul(ctx0, b, x)); + + // read conv state directly, with build_rs generation is slower + ggml_tensor * conv_state = mctx_cur->get_r_l(il); + const int64_t n_seqs = ubatch.n_seqs; + ggml_tensor * conv = build_rs(inp_recr, gf, conv_state, hparams.n_embd_r(), n_seqs); + conv = ggml_reshape_3d(ctx0, conv_state, hparams.n_shortconv_l_cache - 1, hparams.n_embd, n_seqs); + + bx = ggml_concat(ctx0, conv, bx, 0); + GGML_ASSERT(bx->ne[0] > conv->ne[0]); + + auto * new_conv = ggml_view_2d(ctx0, bx, conv->ne[0], bx->ne[1], bx->nb[1], (bx->ne[0] - conv->ne[0]) * ggml_element_size(bx)); + GGML_ASSERT(ggml_are_same_shape(conv, new_conv)); + + // write conv state + ggml_build_forward_expand(gf, ggml_cpy(ctx0, new_conv, conv_state)); + + auto * conv_kernel = model.layers[il].shortconv.conv; + GGML_ASSERT(hparams.n_shortconv_l_cache > 0); + + // construct ssm_conv op + ggml_tensor * conv_out = ggml_ssm_conv(ctx0, bx, conv_kernel); + cb(conv_out, "model.layers.{}.conv.conv", il); + + auto * y = ggml_mul(ctx0, c, conv_out); + + y = build_lora_mm(model.layers[il].shortconv.out_proj, y); + cb(y, "model.layers.{}.conv.out_proj", il); + + return y; + } +}; + llama_memory_i * llama_model::create_memory(const llama_memory_params & params, llama_cparams & cparams) const { llama_memory_i * res; @@ -16257,6 +16873,10 @@ llm_graph_result_ptr llama_model::build_graph( { llm = std::make_unique(*this, params, gf); } break; + case LLM_ARCH_PLAMO2: + { + llm = std::make_unique(*this, params, gf); + } break; case LLM_ARCH_GPT2: { llm = std::make_unique(*this, params, gf); @@ -16467,6 +17087,10 @@ llm_graph_result_ptr llama_model::build_graph( { llm = std::make_unique(*this, params, gf); } break; + case LLM_ARCH_LFM2: + { + llm = std::make_unique(*this, params, gf); + } break; default: GGML_ABORT("fatal error"); } @@ -16647,6 +17271,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_PHI3: case LLM_ARCH_PHIMOE: case LLM_ARCH_PLAMO: + case LLM_ARCH_PLAMO2: case LLM_ARCH_GEMMA: case LLM_ARCH_GEMMA2: case LLM_ARCH_GEMMA3: @@ -16661,6 +17286,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_MINICPM3: case LLM_ARCH_DOTS1: case LLM_ARCH_HUNYUAN_MOE: + case LLM_ARCH_LFM2: return LLAMA_ROPE_TYPE_NEOX; case LLM_ARCH_QWEN2VL: diff --git a/llama/llama.cpp/src/llama-model.h b/llama/llama.cpp/src/llama-model.h index 05a9adfa9..68bbf01bc 100644 --- a/llama/llama.cpp/src/llama-model.h +++ b/llama/llama.cpp/src/llama-model.h @@ -35,15 +35,18 @@ enum llm_type { LLM_TYPE_256M, LLM_TYPE_270M, LLM_TYPE_335M, + LLM_TYPE_350M, LLM_TYPE_410M, LLM_TYPE_450M, LLM_TYPE_475M, + LLM_TYPE_700M, LLM_TYPE_770M, LLM_TYPE_780M, LLM_TYPE_0_3B, LLM_TYPE_0_5B, LLM_TYPE_0_6B, LLM_TYPE_1B, + LLM_TYPE_1_2B, LLM_TYPE_1_3B, LLM_TYPE_1_4B, LLM_TYPE_1_5B, @@ -156,6 +159,12 @@ struct llama_layer_convnext { struct ggml_tensor * gamma = nullptr; }; +struct llama_layer_shortconv { + struct ggml_tensor * in_proj = nullptr; + struct ggml_tensor * conv = nullptr; + struct ggml_tensor * out_proj = nullptr; +}; + struct llama_layer { // normalization struct ggml_tensor * attn_norm = nullptr; @@ -344,6 +353,8 @@ struct llama_layer { struct llama_layer_posnet posnet; struct llama_layer_convnext convnext; + + struct llama_layer_shortconv shortconv; }; struct llama_model { diff --git a/llama/llama.cpp/src/llama-quant.cpp b/llama/llama.cpp/src/llama-quant.cpp index f4b5713d7..a00af7a1d 100644 --- a/llama/llama.cpp/src/llama-quant.cpp +++ b/llama/llama.cpp/src/llama-quant.cpp @@ -844,6 +844,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: // do not quantize Mamba's small yet 2D weights // NOTE: can't use LLM_TN here because the layer number is not known quantize &= name.find("ssm_conv1d.weight") == std::string::npos; + quantize &= name.find("shortconv.conv.weight") == std::string::npos; // do not quantize RWKV's small yet 2D weights quantize &= name.find("time_mix_first.weight") == std::string::npos; @@ -883,8 +884,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: if (std::regex pattern(tname); std::regex_search(tensor_name, pattern)) { if (qtype != new_type) { LLAMA_LOG_DEBUG("(overriding %s) ", ggml_type_name(new_type)); - new_type = qtype; - break; // if two or more types are specified for the tensor, first match wins + new_type = qtype; // if two or more types are specified for the same tensor, the last match wins } } } diff --git a/llama/llama.cpp/src/llama-vocab.cpp b/llama/llama.cpp/src/llama-vocab.cpp index d0339c64d..3484560eb 100644 --- a/llama/llama.cpp/src/llama-vocab.cpp +++ b/llama/llama.cpp/src/llama-vocab.cpp @@ -11,6 +11,7 @@ #include #include #include +#include #include #include #include @@ -404,6 +405,13 @@ struct llm_tokenizer_bpe : llm_tokenizer { "[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}])([^a-z]))*((?=[\\p{L}])([^A-Z]))+(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])?|[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}])([^a-z]))+((?=[\\p{L}])([^A-Z]))*(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])?|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", }; break; + case LLAMA_VOCAB_PRE_TYPE_KIMI_K2: + regex_exprs = { + // K2 trigger pattern - this will activate the custom K2 handler in unicode.cpp + // The custom handler implements all K2 patterns with proper Han character exclusion + "\\p{Han}+", + }; + break; case LLAMA_VOCAB_PRE_TYPE_SUPERBPE: regex_exprs = { "\\p{N}+", @@ -1196,6 +1204,284 @@ private: const llm_tokenizer_rwkv & tokenizer; }; +struct llm_tokenizer_plamo2 : llm_tokenizer { + llm_tokenizer_plamo2(const llama_vocab & vocab) { + build(vocab); + } + + void build(const llama_vocab & vocab) { + // Reset internal structures + tokens_.clear(); + bytes_.assign(256, 0); + to_suffix_id_.clear(); + table_.clear(); + + // Build token list and byte mapping + std::unordered_map suffix_to_score; + std::unordered_map token_to_id; + + for (size_t token_id = 0; token_id < vocab.n_tokens(); ++token_id) { + const auto & entry = vocab.get_token_data(token_id); + tokens_.push_back(entry.text); + token_to_id[entry.text] = static_cast(token_id); + + // Handle byte tokens + if (vocab.is_byte(token_id)) { + if (entry.text.length() == 6 && entry.text.substr(0, 3) == "<0x" && entry.text.back() == '>') { + std::string hex_str = entry.text.substr(3, 2); + int byte_val = std::stoi(hex_str, nullptr, 16); + bytes_[byte_val] = static_cast(token_id); + } + continue; + } + + // Add token and all its suffixes to suffix_to_score + suffix_to_score[entry.text] = entry.score; + + // Extract suffixes character by character (UTF-8 aware) + std::vector cpts = unicode_cpts_from_utf8(entry.text); + for (size_t i = 1; i < cpts.size(); ++i) { + std::string suffix; + for (size_t j = i; j < cpts.size(); ++j) { + suffix += unicode_cpt_to_utf8(cpts[j]); + } + if (suffix_to_score.find(suffix) == suffix_to_score.end()) { + suffix_to_score[suffix] = std::numeric_limits::quiet_NaN(); + } + } + } + + // Check that all byte tokens are set + for (int i = 0; i < 256; ++i) { + if (bytes_[i] == 0) { + throw std::runtime_error("Byte token for <0x" + std::to_string(i) + "> is not set"); + } + } + + // Build suffix list in lexicographical order of reversed strings + std::vector suffixes; + for (const auto & pair : suffix_to_score) { + suffixes.push_back(pair.first); + } + suffixes.push_back(""); // Empty suffix + + std::sort(suffixes.begin(), suffixes.end(), [](const std::string & a, const std::string & b) { + std::string rev_a(a.rbegin(), a.rend()); + std::string rev_b(b.rbegin(), b.rend()); + return rev_a < rev_b; + }); + + // Build suffix_to_id and to_suffix_id_ + std::unordered_map suffix_to_id; + int32_t num_pieces = 0; + + for (const auto & suffix : suffixes) { + suffix_to_id[suffix] = num_pieces; + if (!suffix.empty()) { + std::vector cpts = unicode_cpts_from_utf8(suffix); + + std::string remaining; + for (size_t i = 1; i < cpts.size(); ++i) { + remaining += unicode_cpt_to_utf8(cpts[i]); + } + + int64_t piece_code = (static_cast(cpts[0]) << 32) | suffix_to_id[remaining]; + to_suffix_id_[piece_code] = num_pieces; + + // Count number of pieces for this suffix + int32_t pieces_for_suffix = 1; // sentinel row + for (int32_t piece_length = static_cast(cpts.size()); piece_length > 0; --piece_length) { + std::string piece; + for (int32_t i = 0; i < piece_length; ++i) { + piece += unicode_cpt_to_utf8(cpts[i]); + } + if (suffix_to_score.find(piece) != suffix_to_score.end()) { + pieces_for_suffix++; + } + } + num_pieces += pieces_for_suffix; + } else { + num_pieces++; // Empty suffix contributes one piece (sentinel row) + } + } + + // Build flattened table + table_.resize(num_pieces, std::vector(4, 0)); + int32_t table_idx = 0; + + for (const auto & suffix : suffixes) { + // Add all prefixes of the suffix to the table (in decreasing order of length) + std::vector cpts = unicode_cpts_from_utf8(suffix); + for (int32_t piece_length = static_cast(cpts.size()); piece_length > 0; --piece_length) { + std::string piece; + for (int32_t i = 0; i < piece_length; ++i) { + piece += unicode_cpt_to_utf8(cpts[i]); + } + + auto score_it = suffix_to_score.find(piece); + if (score_it == suffix_to_score.end()) { + continue; + } + + table_[table_idx][TABLE_PIECE_LENGTH] = piece_length; + auto token_it = token_to_id.find(piece); + table_[table_idx][TABLE_TOKEN_ID] = (token_it != token_to_id.end()) ? token_it->second : -1; + + float score = score_it->second; + table_[table_idx][TABLE_SCORE] = std::isfinite(score) ? + static_cast(std::round(score * 1e4)) : INVALID_SCORE; + table_[table_idx][TABLE_PIECE_ID] = suffix_to_id[piece]; + + table_idx++; + } + + // Add sentinel row + table_[table_idx][TABLE_PIECE_LENGTH] = 1; + table_[table_idx][TABLE_TOKEN_ID] = -1; + table_[table_idx][TABLE_SCORE] = UNKNOWN_SCORE; + table_idx++; + } + } + + std::vector encode(const std::string & text) const { + std::vector unicode_data = unicode_cpts_from_utf8(text); + // Skip the first code point if it is a BOM (Byte Order Mark) + if (!unicode_data.empty() && unicode_data[0] == 0xFEFF) { + unicode_data.erase(unicode_data.begin()); + } + + if (unicode_data.empty()) { + return {}; + } + + const size_t data_len = unicode_data.size(); + + // Initialize scores array (dynamic programming) + std::vector scores(data_len + 1, static_cast(1) << 60); + scores[data_len] = 0; + + // Path array to track best tokenization + std::vector> path(data_len + 1, std::vector(3, 0)); + + int32_t suffix_id = 0; + + // Process from end to beginning + for (int i = static_cast(data_len) - 1; i >= 0; --i) { + uint32_t c = unicode_data[i]; + + // Find next suffix ID + for (size_t p = suffix_id; p < table_.size(); ++p) { + int64_t piece_code = (static_cast(c) << 32) | table_[p][TABLE_PIECE_ID]; + auto it = to_suffix_id_.find(piece_code); + suffix_id = (it != to_suffix_id_.end()) ? it->second : 0; + + if (suffix_id > 0 || table_[p][TABLE_SCORE] == UNKNOWN_SCORE) { + break; + } + } + + // Update best path + for (size_t p = suffix_id; p < table_.size(); ++p) { + int32_t score = table_[p][TABLE_SCORE]; + if (score > INVALID_SCORE) { + int32_t piece_length = table_[p][TABLE_PIECE_LENGTH]; + int64_t s = scores[i + piece_length] - score; + + if (s < scores[i]) { + scores[i] = s; + path[i][PATH_TOKEN_LENGTH] = piece_length; + path[i][PATH_TOKEN_ID] = table_[p][TABLE_TOKEN_ID]; + path[i][PATH_NUM_TOKENS] = path[i + piece_length][PATH_NUM_TOKENS] + 1; + + if (score == UNKNOWN_SCORE) { + // Add UTF-8 byte count + path[i][PATH_NUM_TOKENS] += (c >= 0x80) + (c >= 0x800) + (c >= 0x10000); + } + } + } + + if (score == UNKNOWN_SCORE) { + break; + } + } + } + + // Decode the best path + std::vector token_ids; + token_ids.reserve(path[0][PATH_NUM_TOKENS]); + + int pos = 0; + while (pos < static_cast(data_len)) { + if (path[pos][PATH_TOKEN_ID] >= 0) { + token_ids.push_back(path[pos][PATH_TOKEN_ID]); + } else { + // Fall back to byte tokens + uint32_t c = unicode_data[pos]; + int s = 1 + (c >= 0x80) + (c >= 0x800) + (c >= 0x10000); + + for (int i = 0; i < s; ++i) { + uint8_t b; + if (s == 1) { + b = c; + } else { + if (i == 0) { + b = (0xF00 >> s) & 0xFF; + } else { + b = 0x80; + } + } + token_ids.push_back(bytes_[b | ((c >> ((s - i - 1) * 6)) & 0x3F)]); + } + } + + assert(path[pos][PATH_TOKEN_LENGTH] > 0); + pos += path[pos][PATH_TOKEN_LENGTH]; + } + + return token_ids; + } +private: + // Constants for table structure + static constexpr int32_t TABLE_PIECE_LENGTH = 0; + static constexpr int32_t TABLE_TOKEN_ID = 1; + static constexpr int32_t TABLE_SCORE = 2; + static constexpr int32_t TABLE_PIECE_ID = 3; + + // Constants for path array + static constexpr int32_t PATH_TOKEN_LENGTH = 0; + static constexpr int32_t PATH_TOKEN_ID = 1; + static constexpr int32_t PATH_NUM_TOKENS = 2; + + // Score constants + static constexpr int32_t INVALID_SCORE = -20000000; + static constexpr int32_t UNKNOWN_SCORE = -10000000; + + // List of tokens in the vocabulary + std::vector tokens_; + + // Mapping from byte code point to token ID (for byte fallback) + std::vector bytes_; + + // Mapping from piece code to suffix ID + std::unordered_map to_suffix_id_; + + // Flattened table representing the Trie structure + // Each row contains: [piece_length, token_id, score, piece_id] + std::vector> table_; +}; + +struct llm_tokenizer_plamo2_session { + llm_tokenizer_plamo2_session(const llm_tokenizer_plamo2 & tokenizer) : tokenizer(tokenizer) {} + + void tokenize(const std::string & text, std::vector & output) { + std::vector tokens = tokenizer.encode(text); + output.insert(output.end(), tokens.begin(), tokens.end()); + } + +private: + const llm_tokenizer_plamo2 & tokenizer; +}; + // // impl // @@ -1497,6 +1783,16 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { special_unk_id = LLAMA_TOKEN_NULL; special_sep_id = LLAMA_TOKEN_NULL; special_pad_id = LLAMA_TOKEN_NULL; + } else if (tokenizer_model == "plamo2") { + type = LLAMA_VOCAB_TYPE_PLAMO2; + + // PLaMo-2 default special tokens (these will be overridden by model config) + special_bos_id = 1; // <|plamo:bos|> + special_eos_id = 2; // <|plamo:eos|> + special_unk_id = 0; // <|plamo:unk|> + special_sep_id = LLAMA_TOKEN_NULL; + special_pad_id = 3; // <|plamo:pad|> + special_mask_id = LLAMA_TOKEN_NULL; } else { throw std::runtime_error(format("unknown tokenizer: '%s'", tokenizer_model.c_str())); } @@ -1514,7 +1810,8 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { tokenizer_pre == "falcon3" || tokenizer_pre == "falcon-h1" || tokenizer_pre == "pixtral" || - tokenizer_pre == "midm-2.0") { + tokenizer_pre == "midm-2.0" || + tokenizer_pre == "lfm2") { pre_type = LLAMA_VOCAB_PRE_TYPE_LLAMA3; ignore_merges = true; add_bos = true; @@ -1653,6 +1950,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { tokenizer_pre == "hunyuan") { pre_type = LLAMA_VOCAB_PRE_TYPE_HUNYUAN; clean_spaces = false; + } else if ( + tokenizer_pre == "kimi-k2") { + pre_type = LLAMA_VOCAB_PRE_TYPE_KIMI_K2; + clean_spaces = false; } else { LLAMA_LOG_WARN("%s: missing or unrecognized pre-tokenizer type, using: 'default'\n", __func__); pre_type = LLAMA_VOCAB_PRE_TYPE_DEFAULT; @@ -2134,13 +2435,14 @@ enum llama_vocab_type llama_vocab::impl::get_type() const { std::string llama_vocab::impl::type_name() const{ switch (type) { - case LLAMA_VOCAB_TYPE_NONE: return "no vocab"; - case LLAMA_VOCAB_TYPE_SPM: return "SPM"; - case LLAMA_VOCAB_TYPE_BPE: return "BPE"; - case LLAMA_VOCAB_TYPE_WPM: return "WPM"; - case LLAMA_VOCAB_TYPE_UGM: return "UGM"; - case LLAMA_VOCAB_TYPE_RWKV: return "RWKV"; - default: return "unknown"; + case LLAMA_VOCAB_TYPE_NONE: return "no vocab"; + case LLAMA_VOCAB_TYPE_SPM: return "SPM"; + case LLAMA_VOCAB_TYPE_BPE: return "BPE"; + case LLAMA_VOCAB_TYPE_WPM: return "WPM"; + case LLAMA_VOCAB_TYPE_UGM: return "UGM"; + case LLAMA_VOCAB_TYPE_RWKV: return "RWKV"; + case LLAMA_VOCAB_TYPE_PLAMO2: return "PLaMo2"; + default: return "unknown"; } } @@ -2223,6 +2525,9 @@ void llama_vocab::impl::init_tokenizer(enum llama_vocab_type type) { case LLAMA_VOCAB_TYPE_RWKV: tokenizer = std::make_unique(vocab); break; + case LLAMA_VOCAB_TYPE_PLAMO2: + tokenizer = std::make_unique(vocab); + break; default: GGML_ABORT("unsupported vocab type"); } @@ -2555,6 +2860,23 @@ std::vector llama_vocab::impl::tokenize( if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) { std::string text = fragment.raw_text.substr(fragment.offset, fragment.length); +#ifdef PRETOKENIZERDEBUG + LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", text.length(), fragment.offset, fragment.length, text.c_str()); +#endif + + session.tokenize(text, output); + } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN) + output.push_back(fragment.token); + } + } + } break; + case LLAMA_VOCAB_TYPE_PLAMO2: + { + llm_tokenizer_plamo2_session session(*static_cast(tokenizer.get())); + for (const auto & fragment : fragment_buffer) { + if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) { + std::string text = fragment.raw_text.substr(fragment.offset, fragment.length); + #ifdef PRETOKENIZERDEBUG LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", text.length(), fragment.offset, fragment.length, text.c_str()); #endif @@ -2653,6 +2975,24 @@ int32_t llama_vocab::impl::token_to_piece(llama_token token, char * buf, int32_t memcpy(buf, result.data(), result.size()); return (int)result.size(); } + case LLAMA_VOCAB_TYPE_PLAMO2: { + // PLaMo-2 uses similar token handling as BPE/SPM + if (vocab.is_byte(token)) { + // Handle byte tokens like <0xXX> + if (token_text.length() == 6 && token_text.substr(0, 3) == "<0x" && token_text.back() == '>') { + int hex_val = std::stoi(token_text.substr(3, 2), nullptr, 16); + if (length < 1) { + return -1; + } + buf[0] = static_cast(hex_val); + return 1; + } + } + + // Normal token - just copy the text + std::string result = token_text; + return _try_copy(result.data(), result.size()); + } default: GGML_ABORT("fatal error"); } @@ -2897,6 +3237,12 @@ llama_token llama_vocab::byte_to_token(uint8_t ch) const { case LLAMA_VOCAB_TYPE_BPE: { return pimpl->token_to_id.at(unicode_byte_to_utf8(ch)); } + case LLAMA_VOCAB_TYPE_PLAMO2: { + // PLaMo-2 uses byte tokens in format <0xXX> + char hex_str[8]; + snprintf(hex_str, sizeof(hex_str), "<0x%02X>", ch); + return pimpl->token_to_id.at(hex_str); + } default: GGML_ABORT("fatal error"); } @@ -3374,4 +3720,3 @@ int32_t llama_detokenize( bool unparse_special) { return vocab->detokenize(tokens, n_tokens, text, text_len_max, remove_special, unparse_special); } - diff --git a/llama/llama.cpp/src/llama-vocab.h b/llama/llama.cpp/src/llama-vocab.h index 46a1ccecb..1ce8fd307 100644 --- a/llama/llama.cpp/src/llama-vocab.h +++ b/llama/llama.cpp/src/llama-vocab.h @@ -45,6 +45,7 @@ enum llama_vocab_pre_type { LLAMA_VOCAB_PRE_TYPE_PIXTRAL = 34, LLAMA_VOCAB_PRE_TYPE_SEED_CODER = 35, LLAMA_VOCAB_PRE_TYPE_HUNYUAN = 36, + LLAMA_VOCAB_PRE_TYPE_KIMI_K2 = 37, }; struct LLM_KV; diff --git a/llama/llama.cpp/src/unicode.cpp b/llama/llama.cpp/src/unicode.cpp index 4da581c5e..ce336a228 100644 --- a/llama/llama.cpp/src/unicode.cpp +++ b/llama/llama.cpp/src/unicode.cpp @@ -578,6 +578,178 @@ static std::vector unicode_regex_split_stl(const std::string & text, con return bpe_offsets; } +// K2 system regex patterns (from tokenization_kimi.py): +// [\p{Han}]+|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]*[\p{Ll}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]+[\p{Ll}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+ +static std::vector unicode_regex_split_custom_kimi_k2(const std::string & text, const std::vector & offsets) { + std::vector bpe_offsets; + bpe_offsets.reserve(offsets.size()); + + const auto cpts = unicode_cpts_from_utf8(text); + + size_t start = 0; + for (auto offset : offsets) { + const size_t offset_ini = start; + const size_t offset_end = start + offset; + assert(offset_end <= cpts.size()); + start = offset_end; + + static const uint32_t OUT_OF_RANGE = 0xFFFFFFFF; + auto _get_cpt = [&] (const size_t pos) -> uint32_t { + return (offset_ini <= pos && pos < offset_end) ? cpts[pos] : OUT_OF_RANGE; + }; + + auto _get_flags = [&] (const size_t pos) -> unicode_cpt_flags { + return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_flags_from_cpt(cpts[pos]) : unicode_cpt_flags{}; + }; + + size_t _prev_end = offset_ini; + auto _add_token = [&] (const size_t end) -> size_t { + assert(_prev_end <= end && end <= offset_end); + size_t len = end - _prev_end; + if (len > 0) { + bpe_offsets.push_back(len); + } + _prev_end = end; + return len; + }; + + for (size_t pos = offset_ini; pos < offset_end; /*pos++*/ ) { + const uint32_t cpt = _get_cpt(pos); + const auto flags = _get_flags(pos); + + // Pattern 1: [\p{Han}]+ (Chinese characters) + if (unicode_cpt_is_han(cpt)) { + while (unicode_cpt_is_han(_get_cpt(pos))) { + pos++; + } + _add_token(pos); + continue; + } + + // Pattern 2 & 3: Letter words excluding Han characters with optional contractions + // [^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]*[\p{Ll}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]+(?:'s|'t|'re|'ve|'m|'ll|'d)? + // [^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]+[\p{Ll}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]*(?:'s|'t|'re|'ve|'m|'ll|'d)? + // Check if current char is a letter OR if current char could be a leading char and next char is a letter + bool is_letter_pattern = (flags.is_letter && !unicode_cpt_is_han(cpt)) || + (!(cpt == '\r' || cpt == '\n' || flags.is_letter || flags.is_number) && + _get_flags(pos + 1).is_letter && !unicode_cpt_is_han(_get_cpt(pos + 1))); + + if (is_letter_pattern) { + // Handle optional leading non-letter/non-number character + bool has_leading_char = false; + if (!(cpt == '\r' || cpt == '\n' || flags.is_letter || flags.is_number)) { + has_leading_char = true; + pos++; + } + + // Match letter sequence (excluding Han characters) + bool has_letters = false; + while (_get_flags(pos).is_letter && !unicode_cpt_is_han(_get_cpt(pos))) { + has_letters = true; + pos++; + } + + // Only proceed if we found letters (after potentially skipping leading char) + if (has_letters || (!has_leading_char && _get_flags(pos).is_letter && !unicode_cpt_is_han(_get_cpt(pos)))) { + if (!has_letters) pos++; // consume the first letter if we didn't already + + // Continue consuming letters + while (_get_flags(pos).is_letter && !unicode_cpt_is_han(_get_cpt(pos))) { + pos++; + } + + // Check for optional contractions (?:'s|'t|'re|'ve|'m|'ll|'d) + if (_get_cpt(pos) == '\'' && pos + 1 < offset_end) { + uint32_t cpt_next = unicode_tolower(_get_cpt(pos + 1)); + if (cpt_next == 's' || cpt_next == 't' || cpt_next == 'm' || cpt_next == 'd') { + pos += 2; + } else if (pos + 2 < offset_end) { + uint32_t cpt_next_next = unicode_tolower(_get_cpt(pos + 2)); + if ((cpt_next == 'r' && cpt_next_next == 'e') || + (cpt_next == 'v' && cpt_next_next == 'e') || + (cpt_next == 'l' && cpt_next_next == 'l')) { + pos += 3; + } + } + } + + _add_token(pos); + continue; + } else if (has_leading_char) { + // We consumed a leading char but found no letters, backtrack + pos--; + } + } + + // Pattern 4: \p{N}{1,3} (numbers 1-3 digits) + if (flags.is_number) { + size_t ini = pos; + while (_get_flags(pos).is_number) { + if (++pos - ini >= 3) { + _add_token(pos); + ini = pos; + } + } + _add_token(pos); + continue; + } + + // Pattern 5: ?[^\s\p{L}\p{N}]+[\r\n]* (optional space + non-word chars + optional newlines) + auto flags2 = (cpt == ' ' ? _get_flags(pos + 1) : flags); + if (!(flags2.is_whitespace || flags2.is_letter || flags2.is_number) && flags2.as_uint()) { + pos += (cpt == ' '); + while (!(flags2.is_whitespace || flags2.is_letter || flags2.is_number) && flags2.as_uint()) { + flags2 = _get_flags(++pos); + } + // Match optional [\r\n]* + uint32_t cpt2 = _get_cpt(pos); + while (cpt2 == '\r' || cpt2 == '\n') { + cpt2 = _get_cpt(++pos); + } + _add_token(pos); + continue; + } + + // Count whitespace characters + size_t num_whitespaces = 0; + size_t last_end_r_or_n = 0; + while (_get_flags(pos + num_whitespaces).is_whitespace) { + uint32_t cpt2 = _get_cpt(pos + num_whitespaces); + if (cpt2 == '\r' || cpt2 == '\n') { + last_end_r_or_n = pos + num_whitespaces + 1; + } + num_whitespaces++; + } + + // Pattern 6: \s*[\r\n]+ (whitespace with newlines) + if (last_end_r_or_n > 0) { + pos = last_end_r_or_n; + _add_token(pos); + continue; + } + + // Pattern 7: \s+(?!\S) (trailing whitespace) + if (num_whitespaces > 1 && _get_cpt(pos + num_whitespaces) != OUT_OF_RANGE) { + pos += num_whitespaces - 1; + _add_token(pos); + continue; + } + + // Pattern 8: \s+ (general whitespace) + if (num_whitespaces > 0) { + pos += num_whitespaces; + _add_token(pos); + continue; + } + + // No matches - consume single character + _add_token(++pos); + } + } + + return bpe_offsets; +} + static std::vector unicode_regex_split_custom(const std::string & text, const std::string & regex_expr, const std::vector & offsets) { std::vector bpe_offsets; @@ -588,6 +760,9 @@ static std::vector unicode_regex_split_custom(const std::string & text, regex_expr == "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+") { bpe_offsets = unicode_regex_split_custom_llama3(text, offsets); + } else if (regex_expr == "\\p{Han}+") { + // K2's first pattern - handle all K2 patterns together + bpe_offsets = unicode_regex_split_custom_kimi_k2(text, offsets); } return bpe_offsets; @@ -693,6 +868,38 @@ uint32_t unicode_tolower(uint32_t cpt) { return cpt; // Return the original code point if no lowercase mapping is found } +bool unicode_cpt_is_han(uint32_t cpt) { + // Han character ranges (Chinese/CJK characters) + // CJK Unified Ideographs (most common) + if (cpt >= 0x4E00 && cpt <= 0x9FFF) return true; + + // CJK Extension A + if (cpt >= 0x3400 && cpt <= 0x4DBF) return true; + + // CJK Extension B + if (cpt >= 0x20000 && cpt <= 0x2A6DF) return true; + + // CJK Extension C + if (cpt >= 0x2A700 && cpt <= 0x2B73F) return true; + + // CJK Extension D + if (cpt >= 0x2B740 && cpt <= 0x2B81F) return true; + + // CJK Extension E + if (cpt >= 0x2B820 && cpt <= 0x2CEAF) return true; + + // CJK Extension F + if (cpt >= 0x2CEB0 && cpt <= 0x2EBEF) return true; + + // CJK Compatibility Ideographs + if (cpt >= 0xF900 && cpt <= 0xFAFF) return true; + + // CJK Compatibility Ideographs Supplement + if (cpt >= 0x2F800 && cpt <= 0x2FA1F) return true; + + return false; +} + std::vector unicode_regex_split(const std::string & text, const std::vector & regex_exprs) { // unicode categories static const std::map k_ucat_enum = { diff --git a/llama/llama.cpp/src/unicode.h b/llama/llama.cpp/src/unicode.h index c27098df7..0a5fa2a78 100644 --- a/llama/llama.cpp/src/unicode.h +++ b/llama/llama.cpp/src/unicode.h @@ -63,4 +63,6 @@ uint8_t unicode_utf8_to_byte(const std::string & utf8); uint32_t unicode_tolower(uint32_t cpt); +bool unicode_cpt_is_han(uint32_t cpt); + std::vector unicode_regex_split(const std::string & text, const std::vector & regex_exprs); diff --git a/ml/backend/ggml/ggml/src/ggml-cpu/llamafile/sgemm.cpp b/ml/backend/ggml/ggml/src/ggml-cpu/llamafile/sgemm.cpp index ed61869a5..2be54c31b 100644 --- a/ml/backend/ggml/ggml/src/ggml-cpu/llamafile/sgemm.cpp +++ b/ml/backend/ggml/ggml/src/ggml-cpu/llamafile/sgemm.cpp @@ -1541,7 +1541,7 @@ class tinyBLAS_BF16_PPC { } else if constexpr(RM == 8 && RN == 4) { KERNEL_8x4(ii,jj); } else { - static_assert(false, "RN/RM values not supported"); + assert(false && "RN/RM values not supported"); } } @@ -1573,13 +1573,13 @@ class tinyBLAS_BF16_PPC { const int nth; }; -template +template class tinyBLAS_Q0_PPC { public: tinyBLAS_Q0_PPC(int64_t k, const TA *A, int64_t lda, - const TB *B, int64_t ldb, - TC *C, int64_t ldc, + const block_q8_0 *B, int64_t ldb, + float *C, int64_t ldc, int ith, int nth) : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) { } @@ -1590,8 +1590,7 @@ class tinyBLAS_Q0_PPC { private: - template - inline void save_res(int ii, int jj, int idx, vector float* fin_res) { + inline void save_res(int ii, int jj, int idx, vector float* fin_res, int RM=4, int RN=4) { for (int I = 0; I < RM; I++) { for (int J = 0; J < RN; J++) { *((float*)(C+ii+((jj+J)*ldc)+I)) = *((float*)&fin_res[idx+I]+J); @@ -1611,29 +1610,67 @@ class tinyBLAS_Q0_PPC { fin_res[s_idx+i] = vec_madd(res[i], vs[s_idx+i], fin_res[s_idx+i]); } } - - template - void packNormalInt4(const TA* a, int64_t lda, int rows, int cols, VA* vec, std::array& comparray) { - int64_t i, j; - TA *aoffset = NULL; - VA *vecOffset = NULL; - TA *aoffset1 = NULL, *aoffset2 = NULL, *aoffset3 = NULL, *aoffset4 = NULL; - TA *aoffset5 = NULL, *aoffset6 = NULL, *aoffset7 = NULL, *aoffset8 = NULL; - VB c1[2] = {0}, c2[2] = {0}, c3[2] = {0}, c4[2] = {0}; - VB c5[2] = {0}, c6[2] = {0}, c7[2] = {0}, c8[2] = {0}; - VB t1, t2, t3, t4, t5, t6, t7, t8; + /* This function processes quantized data from block_q4_0 elements. + * First the we try to extract the two int4 values stored in single int8_t into two signed int8. + * And then we subtract each of the resultant element with 8, to convert signed int8 to unsigned int8. + * Also compute the rowsum which is required to compensate the above conversion. */ + inline void process_q4_elements(vector signed char (&c)[2], int* ca) { const vector signed char lowMask = vec_splats((signed char)0xF); const vector unsigned char v4 = vec_splats((unsigned char)0x4); const vector signed char v8 = vec_splats((signed char)0x8); - aoffset = const_cast(a); - vecOffset = vec; + vector signed int vsum = {0}; + vector signed int vsum2 = {0}; + c[0] = vec_and(c[1], lowMask); + c[1] = vec_sr(c[1], v4); + c[0] = vec_sub(c[0], v8); + c[1] = vec_sub(c[1], v8); + vsum = vec_sum4s(c[0], vsum); + vsum2 = vec_sum4s(c[1], vsum2); + vsum = vec_add(vsum, vsum2); + *(ca) = vsum[0] + vsum[1] + vsum[2] + vsum[3]; + } + + template + inline void vector_permute_store(V2 &s1, V2 &s2, V2 &s3, V2 &s4, V1 *vecOffset, bool flip) { vector unsigned char swiz1 = {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23}; vector unsigned char swiz2 = {8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31}; vector unsigned char swiz3 = {0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27}; vector unsigned char swiz4 = {4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31}; - vector signed int vsum = {0}; - vector signed int vsum2 = {0}; + V2 t1, t2, t3, t4, t5, t6, t7, t8; + vector unsigned char xor_vector; + uint8_t flip_vec = 0x80; + xor_vector = vec_splats(flip_vec); + t1 = vec_perm(s1, s2, swiz1); + t2 = vec_perm(s1, s2, swiz2); + t3 = vec_perm(s3, s4, swiz1); + t4 = vec_perm(s3, s4, swiz2); + t5 = vec_perm(t1, t3, swiz3); + t6 = vec_perm(t1, t3, swiz4); + t7 = vec_perm(t2, t4, swiz3); + t8 = vec_perm(t2, t4, swiz4); + if (flip == true) { + t5 = vec_xor(t5, xor_vector); + t6 = vec_xor(t6, xor_vector); + t7 = vec_xor(t7, xor_vector); + t8 = vec_xor(t8, xor_vector); + } + vec_xst(t5, 0, vecOffset); + vec_xst(t6, 0, vecOffset+16); + vec_xst(t7, 0, vecOffset+32); + vec_xst(t8, 0, vecOffset+48); + } + template + void packNormalInt4(const TA* a, int64_t lda, int rows, int cols, int8_t* vec, std::array& comparray) { + int64_t i, j; + TA *aoffset = NULL; + int8_t *vecOffset = NULL; + TA *aoffset1 = NULL, *aoffset2 = NULL, *aoffset3 = NULL, *aoffset4 = NULL; + TA *aoffset5 = NULL, *aoffset6 = NULL, *aoffset7 = NULL, *aoffset8 = NULL; + vector signed char c1[2] = {0}, c2[2] = {0}, c3[2] = {0}, c4[2] = {0}; + vector signed char c5[2] = {0}, c6[2] = {0}, c7[2] = {0}, c8[2] = {0}; + aoffset = const_cast(a); + vecOffset = vec; j = (rows >> 3); if (j > 0) { do { @@ -1646,159 +1683,30 @@ class tinyBLAS_Q0_PPC { aoffset7 = aoffset6 + lda; aoffset8 = aoffset7 + lda; aoffset += 8 * lda; - i = (cols >> 2); if (i > 0) { do { - c1[1] = reinterpret_cast(vec_xl(0, aoffset1->qs)); - c2[1] = reinterpret_cast(vec_xl(0, aoffset2->qs)); - c3[1] = reinterpret_cast(vec_xl(0, aoffset3->qs)); - c4[1] = reinterpret_cast(vec_xl(0, aoffset4->qs)); - c5[1] = reinterpret_cast(vec_xl(0, aoffset5->qs)); - c6[1] = reinterpret_cast(vec_xl(0, aoffset6->qs)); - c7[1] = reinterpret_cast(vec_xl(0, aoffset7->qs)); - c8[1] = reinterpret_cast(vec_xl(0, aoffset8->qs)); - - c1[0] = vec_and(c1[1], lowMask); - c1[1] = vec_sr(c1[1], v4); - c1[0] = vec_sub(c1[0], v8); - c1[1] = vec_sub(c1[1], v8); - vsum = vec_sum4s(c1[0], vsum); - vsum2 = vec_sum4s(c1[1], vsum2); - vsum = vec_add(vsum, vsum2); - comparray[0] = vsum[0] + vsum[1] + vsum[2] + vsum[3]; - vsum = vec_splats(0); - vsum2 = vec_splats(0); - - c2[0] = vec_and(c2[1], lowMask); - c2[1] = vec_sr(c2[1], v4); - c2[0] = vec_sub(c2[0], v8); - c2[1] = vec_sub(c2[1], v8); - vsum = vec_sum4s(c2[0], vsum); - vsum2 = vec_sum4s(c2[1], vsum2); - vsum = vec_add(vsum, vsum2); - comparray[1] = vsum[0] + vsum[1] + vsum[2] + vsum[3]; - vsum = vec_splats(0); - vsum2 = vec_splats(0); - - c3[0] = vec_and(c3[1], lowMask); - c3[1] = vec_sr(c3[1], v4); - c3[0] = vec_sub(c3[0], v8); - c3[1] = vec_sub(c3[1], v8); - vsum = vec_sum4s(c3[0], vsum); - vsum2 = vec_sum4s(c3[1], vsum2); - vsum = vec_add(vsum, vsum2); - comparray[2] = vsum[0] + vsum[1] + vsum[2] + vsum[3]; - vsum = vec_splats(0); - vsum2 = vec_splats(0); - - c4[0] = vec_and(c4[1], lowMask); - c4[1] = vec_sr(c4[1], v4); - c4[0] = vec_sub(c4[0], v8); - c4[1] = vec_sub(c4[1], v8); - vsum = vec_sum4s(c4[0], vsum); - vsum2 = vec_sum4s(c4[1], vsum2); - vsum = vec_add(vsum, vsum2); - comparray[3] = vsum[0] + vsum[1] + vsum[2] + vsum[3]; - vsum = vec_splats(0); - vsum2 = vec_splats(0); - - c5[0] = vec_and(c5[1], lowMask); - c5[1] = vec_sr(c5[1], v4); - c5[0] = vec_sub(c5[0], v8); - c5[1] = vec_sub(c5[1], v8); - vsum = vec_sum4s(c5[0], vsum); - vsum2 = vec_sum4s(c5[1], vsum2); - vsum = vec_add(vsum, vsum2); - comparray[4] = vsum[0] + vsum[1] + vsum[2] + vsum[3]; - vsum = vec_splats(0); - vsum2 = vec_splats(0); - - c6[0] = vec_and(c6[1], lowMask); - c6[1] = vec_sr(c6[1], v4); - c6[0] = vec_sub(c6[0], v8); - c6[1] = vec_sub(c6[1], v8); - vsum = vec_sum4s(c6[0], vsum); - vsum2 = vec_sum4s(c6[1], vsum2); - vsum = vec_add(vsum, vsum2); - comparray[5] = vsum[0] + vsum[1] + vsum[2] + vsum[3]; - vsum = vec_splats(0); - vsum2 = vec_splats(0); - - c7[0] = vec_and(c7[1], lowMask); - c7[1] = vec_sr(c7[1], v4); - c7[0] = vec_sub(c7[0], v8); - c7[1] = vec_sub(c7[1], v8); - vsum = vec_sum4s(c7[0], vsum); - vsum2 = vec_sum4s(c7[1], vsum2); - vsum = vec_add(vsum, vsum2); - comparray[6] = vsum[0] + vsum[1] + vsum[2] + vsum[3]; - vsum = vec_splats(0); - vsum2 = vec_splats(0); - - c8[0] = vec_and(c8[1], lowMask); - c8[1] = vec_sr(c8[1], v4); - c8[0] = vec_sub(c8[0], v8); - c8[1] = vec_sub(c8[1], v8); - vsum = vec_sum4s(c8[0], vsum); - vsum2 = vec_sum4s(c8[1], vsum2); - vsum = vec_add(vsum, vsum2); - comparray[7] = vsum[0] + vsum[1] + vsum[2] + vsum[3]; - vsum = vec_splats(0); - vsum2 = vec_splats(0); - - t1 = vec_perm(c1[0], c2[0], swiz1); - t2 = vec_perm(c1[0], c2[0], swiz2); - t3 = vec_perm(c3[0], c4[0], swiz1); - t4 = vec_perm(c3[0], c4[0], swiz2); - t5 = vec_perm(t1, t3, swiz3); - t6 = vec_perm(t1, t3, swiz4); - t7 = vec_perm(t2, t4, swiz3); - t8 = vec_perm(t2, t4, swiz4); - vec_xst(t5, 0, vecOffset); - vec_xst(t6, 0, vecOffset+16); - vec_xst(t7, 0, vecOffset+32); - vec_xst(t8, 0, vecOffset+48); - - t1 = vec_perm(c1[1], c2[1], swiz1); - t2 = vec_perm(c1[1], c2[1], swiz2); - t3 = vec_perm(c3[1], c4[1], swiz1); - t4 = vec_perm(c3[1], c4[1], swiz2); - t5 = vec_perm(t1, t3, swiz3); - t6 = vec_perm(t1, t3, swiz4); - t7 = vec_perm(t2, t4, swiz3); - t8 = vec_perm(t2, t4, swiz4); - vec_xst(t5, 0, vecOffset+64); - vec_xst(t6, 0, vecOffset+80); - vec_xst(t7, 0, vecOffset+96); - vec_xst(t8, 0, vecOffset+112); - - t1 = vec_perm(c5[0], c6[0], swiz1); - t2 = vec_perm(c5[0], c6[0], swiz2); - t3 = vec_perm(c7[0], c8[0], swiz1); - t4 = vec_perm(c7[0], c8[0], swiz2); - t5 = vec_perm(t1, t3, swiz3); - t6 = vec_perm(t1, t3, swiz4); - t7 = vec_perm(t2, t4, swiz3); - t8 = vec_perm(t2, t4, swiz4); - vec_xst(t5, 0, vecOffset+128); - vec_xst(t6, 0, vecOffset+144); - vec_xst(t7, 0, vecOffset+160); - vec_xst(t8, 0, vecOffset+176); - - t1 = vec_perm(c5[1], c6[1], swiz1); - t2 = vec_perm(c5[1], c6[1], swiz2); - t3 = vec_perm(c7[1], c8[1], swiz1); - t4 = vec_perm(c7[1], c8[1], swiz2); - t5 = vec_perm(t1, t3, swiz3); - t6 = vec_perm(t1, t3, swiz4); - t7 = vec_perm(t2, t4, swiz3); - t8 = vec_perm(t2, t4, swiz4); - vec_xst(t5, 0, vecOffset+192); - vec_xst(t6, 0, vecOffset+208); - vec_xst(t7, 0, vecOffset+224); - vec_xst(t8, 0, vecOffset+240); + c1[1] = reinterpret_cast(vec_xl(0, aoffset1->qs)); + c2[1] = reinterpret_cast(vec_xl(0, aoffset2->qs)); + c3[1] = reinterpret_cast(vec_xl(0, aoffset3->qs)); + c4[1] = reinterpret_cast(vec_xl(0, aoffset4->qs)); + c5[1] = reinterpret_cast(vec_xl(0, aoffset5->qs)); + c6[1] = reinterpret_cast(vec_xl(0, aoffset6->qs)); + c7[1] = reinterpret_cast(vec_xl(0, aoffset7->qs)); + c8[1] = reinterpret_cast(vec_xl(0, aoffset8->qs)); + process_q4_elements(c1, &comparray[0]); + process_q4_elements(c2, &comparray[1]); + process_q4_elements(c3, &comparray[2]); + process_q4_elements(c4, &comparray[3]); + process_q4_elements(c5, &comparray[4]); + process_q4_elements(c6, &comparray[5]); + process_q4_elements(c7, &comparray[6]); + process_q4_elements(c8, &comparray[7]); + vector_permute_store(c1[0], c2[0], c3[0], c4[0], vecOffset, false); + vector_permute_store(c1[1], c2[1], c3[1], c4[1], vecOffset+64, false); + vector_permute_store(c5[0], c6[0], c7[0], c8[0], vecOffset+128, false); + vector_permute_store(c5[1], c6[1], c7[1], c8[1], vecOffset+192, false); aoffset1 += lda; aoffset2 += lda; aoffset3 += lda; @@ -1821,85 +1729,20 @@ class tinyBLAS_Q0_PPC { aoffset3 = aoffset2 + lda; aoffset4 = aoffset3 + lda; aoffset += 4 * lda; - i = (cols >> 2); if (i > 0) { do { - c1[1] = reinterpret_cast(vec_xl(0, aoffset1->qs)); - c2[1] = reinterpret_cast(vec_xl(0, aoffset2->qs)); - c3[1] = reinterpret_cast(vec_xl(0, aoffset3->qs)); - c4[1] = reinterpret_cast(vec_xl(0, aoffset4->qs)); - - c1[0] = vec_and(c1[1], lowMask); - c1[1] = vec_sr(c1[1], v4); - c1[0] = vec_sub(c1[0], v8); - c1[1] = vec_sub(c1[1], v8); - vsum = vec_sum4s(c1[0], vsum); - vsum2 = vec_sum4s(c1[1], vsum2); - vsum = vec_add(vsum, vsum2); - comparray[0] = vsum[0] + vsum[1] + vsum[2] + vsum[3]; - vsum = vec_splats(0); - vsum2 = vec_splats(0); - - c2[0] = vec_and(c2[1], lowMask); - c2[1] = vec_sr(c2[1], v4); - c2[0] = vec_sub(c2[0], v8); - c2[1] = vec_sub(c2[1], v8); - vsum = vec_sum4s(c2[0], vsum); - vsum2 = vec_sum4s(c2[1], vsum2); - vsum = vec_add(vsum, vsum2); - comparray[1] = vsum[0] + vsum[1] + vsum[2] + vsum[3]; - vsum = vec_splats(0); - vsum2 = vec_splats(0); - - c3[0] = vec_and(c3[1], lowMask); - c3[1] = vec_sr(c3[1], v4); - c3[0] = vec_sub(c3[0], v8); - c3[1] = vec_sub(c3[1], v8); - vsum = vec_sum4s(c3[0], vsum); - vsum2 = vec_sum4s(c3[1], vsum2); - vsum = vec_add(vsum, vsum2); - comparray[2] = vsum[0] + vsum[1] + vsum[2] + vsum[3]; - vsum = vec_splats(0); - vsum2 = vec_splats(0); - - c4[0] = vec_and(c4[1], lowMask); - c4[1] = vec_sr(c4[1], v4); - c4[0] = vec_sub(c4[0], v8); - c4[1] = vec_sub(c4[1], v8); - vsum = vec_sum4s(c4[0], vsum); - vsum2 = vec_sum4s(c4[1], vsum2); - vsum = vec_add(vsum, vsum2); - comparray[3] = vsum[0] + vsum[1] + vsum[2] + vsum[3]; - vsum = vec_splats(0); - vsum2 = vec_splats( 0); - - t1 = vec_perm(c1[0], c2[0], swiz1); - t2 = vec_perm(c1[0], c2[0], swiz2); - t3 = vec_perm(c3[0], c4[0], swiz1); - t4 = vec_perm(c3[0], c4[0], swiz2); - t5 = vec_perm(t1, t3, swiz3); - t6 = vec_perm(t1, t3, swiz4); - t7 = vec_perm(t2, t4, swiz3); - t8 = vec_perm(t2, t4, swiz4); - vec_xst(t5, 0, vecOffset); - vec_xst(t6, 0, vecOffset+16); - vec_xst(t7, 0, vecOffset+32); - vec_xst(t8, 0, vecOffset+48); - - t1 = vec_perm(c1[1], c2[1], swiz1); - t2 = vec_perm(c1[1], c2[1], swiz2); - t3 = vec_perm(c3[1], c4[1], swiz1); - t4 = vec_perm(c3[1], c4[1], swiz2); - t5 = vec_perm(t1, t3, swiz3); - t6 = vec_perm(t1, t3, swiz4); - t7 = vec_perm(t2, t4, swiz3); - t8 = vec_perm(t2, t4, swiz4); - vec_xst(t5, 0, vecOffset+64); - vec_xst(t6, 0, vecOffset+80); - vec_xst(t7, 0, vecOffset+96); - vec_xst(t8, 0, vecOffset+112); + c1[1] = reinterpret_cast(vec_xl(0, aoffset1->qs)); + c2[1] = reinterpret_cast(vec_xl(0, aoffset2->qs)); + c3[1] = reinterpret_cast(vec_xl(0, aoffset3->qs)); + c4[1] = reinterpret_cast(vec_xl(0, aoffset4->qs)); + process_q4_elements(c1, &comparray[0]); + process_q4_elements(c2, &comparray[1]); + process_q4_elements(c3, &comparray[2]); + process_q4_elements(c4, &comparray[3]); + vector_permute_store(c1[0], c2[0], c3[0], c4[0], vecOffset, false); + vector_permute_store(c1[1], c2[1], c3[1], c4[1], vecOffset+64, false); aoffset1 += lda; aoffset2 += lda; aoffset3 += lda; @@ -1918,80 +1761,17 @@ class tinyBLAS_Q0_PPC { if (i > 0) { do { switch(rows) { - case 3: c3[1] = reinterpret_cast(vec_xl(0, aoffset3->qs)); - case 2: c2[1] = reinterpret_cast(vec_xl(0, aoffset2->qs)); - case 1: c1[1] = reinterpret_cast(vec_xl(0, aoffset1->qs)); + case 3: c3[1] = reinterpret_cast(vec_xl(0, aoffset3->qs)); + case 2: c2[1] = reinterpret_cast(vec_xl(0, aoffset2->qs)); + case 1: c1[1] = reinterpret_cast(vec_xl(0, aoffset1->qs)); break; } - c1[0] = vec_and(c1[1], lowMask); - c1[1] = vec_sr(c1[1], v4); - c1[0] = vec_sub(c1[0], v8); - c1[1] = vec_sub(c1[1], v8); - vsum = vec_sum4s(c1[0], vsum); - vsum2 = vec_sum4s(c1[1], vsum2); - vsum = vec_add(vsum, vsum2); - comparray[0] = vsum[0] + vsum[1] + vsum[2] + vsum[3]; - vsum = vec_splats(0); - vsum2 = vec_splats(0); - - c2[0] = vec_and(c2[1], lowMask); - c2[1] = vec_sr(c2[1], v4); - c2[0] = vec_sub(c2[0], v8); - c2[1] = vec_sub(c2[1], v8); - vsum = vec_sum4s(c2[0], vsum); - vsum2 = vec_sum4s(c2[1], vsum2); - vsum = vec_add(vsum, vsum2); - comparray[1] = vsum[0] + vsum[1] + vsum[2] + vsum[3]; - vsum = vec_splats(0); - vsum2 = vec_splats(0); - - c3[0] = vec_and(c3[1], lowMask); - c3[1] = vec_sr(c3[1], v4); - c3[0] = vec_sub(c3[0], v8); - c3[1] = vec_sub(c3[1], v8); - vsum = vec_sum4s(c3[0], vsum); - vsum2 = vec_sum4s(c3[1], vsum2); - vsum = vec_add(vsum, vsum2); - comparray[2] = vsum[0] + vsum[1] + vsum[2] + vsum[3]; - vsum = vec_splats(0); - vsum2 = vec_splats(0); - - c4[0] = vec_and(c4[1], lowMask); - c4[1] = vec_sr(c4[1], v4); - c4[0] = vec_sub(c4[0], v8); - c4[1] = vec_sub(c4[1], v8); - vsum = vec_sum4s(c4[0], vsum); - vsum2 = vec_sum4s(c4[1], vsum2); - vsum = vec_add(vsum, vsum2); - comparray[3] = vsum[0] + vsum[1] + vsum[2] + vsum[3]; - vsum = vec_splats(0); - vsum2 = vec_splats(0); - - t1 = vec_perm(c1[0], c2[0], swiz1); - t2 = vec_perm(c1[0], c2[0], swiz2); - t3 = vec_perm(c3[0], c4[0], swiz1); - t4 = vec_perm(c3[0], c4[0], swiz2); - t5 = vec_perm(t1, t3, swiz3); - t6 = vec_perm(t1, t3, swiz4); - t7 = vec_perm(t2, t4, swiz3); - t8 = vec_perm(t2, t4, swiz4); - vec_xst(t5, 0, vecOffset); - vec_xst(t6, 0, vecOffset+16); - vec_xst(t7, 0, vecOffset+32); - vec_xst(t8, 0, vecOffset+48); - - t1 = vec_perm(c1[1], c2[1], swiz1); - t2 = vec_perm(c1[1], c2[1], swiz2); - t3 = vec_perm(c3[1], c4[1], swiz1); - t4 = vec_perm(c3[1], c4[1], swiz2); - t5 = vec_perm(t1, t3, swiz3); - t6 = vec_perm(t1, t3, swiz4); - t7 = vec_perm(t2, t4, swiz3); - t8 = vec_perm(t2, t4, swiz4); - vec_xst(t5, 0, vecOffset+64); - vec_xst(t6, 0, vecOffset+80); - vec_xst(t7, 0, vecOffset+96); - vec_xst(t8, 0, vecOffset+112); + process_q4_elements(c1, &comparray[0]); + process_q4_elements(c2, &comparray[1]); + process_q4_elements(c3, &comparray[2]); + process_q4_elements(c4, &comparray[3]); + vector_permute_store(c1[0], c2[0], c3[0], c4[0], vecOffset, false); + vector_permute_store(c1[1], c2[1], c3[1], c4[1], vecOffset+64, false); aoffset1 += lda; aoffset2 += lda; aoffset3 += lda; @@ -2001,146 +1781,40 @@ class tinyBLAS_Q0_PPC { } } } - template - void packNormal(const TB* a, int64_t lda, int rows, int cols, VA* vec, bool flip) { + void packNormal(const block_q8_0* a, int64_t lda, int rows, int cols, VA* vec, bool flip) { int64_t i, j; - TB *aoffset = NULL; + block_q8_0 *aoffset = NULL; VA *vecOffset = NULL; - TB *aoffset1 = NULL, *aoffset2 = NULL, *aoffset3 = NULL, *aoffset4 = NULL; - TB *aoffset5 = NULL, *aoffset6 = NULL, *aoffset7 = NULL, *aoffset8 = NULL; - __vector_pair C1, C2, C3, C4, C5, C6, C7, C8; - VB c1[2] = {0}, c2[2] = {0}, c3[2] = {0}, c4[2]={0}; - VB c5[2] = {0}, c6[2] = {0}, c7[2] = {0}, c8[2]={0}; - VB t1, t2, t3, t4, t5, t6, t7, t8; - vector unsigned char xor_vector; - uint8_t flip_vec = 0x80; - xor_vector = vec_splats(flip_vec); - vector unsigned char swiz1 = {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23}; - vector unsigned char swiz2 = {8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31}; - vector unsigned char swiz3 = {0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27}; - vector unsigned char swiz4 = {4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31}; - - aoffset = const_cast(a); + block_q8_0* aoffsets[8]; + __vector_pair arr[8]; + VB c[8][2] = {0}; + VB c1[8] = {0}; VB c2[8] = {0}; + aoffset = const_cast(a); vecOffset = vec; j = (rows >> 3); if (j > 0) { do { - aoffset1 = aoffset; - aoffset2 = aoffset1 + lda; - aoffset3 = aoffset2 + lda; - aoffset4 = aoffset3 + lda; - aoffset5 = aoffset4 + lda; - aoffset6 = aoffset5 + lda; - aoffset7 = aoffset6 + lda; - aoffset8 = aoffset7 + lda; + aoffsets[0] = aoffset; + for (int it = 1; it < 8; it++) + aoffsets[it] = aoffsets[it-1] + lda; aoffset += 8 * lda; i = (cols >> 3); if (i > 0) { do { - C1 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset1->qs); - C2 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset2->qs); - C3 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset3->qs); - C4 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset4->qs); - C5 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset5->qs); - C6 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset6->qs); - C7 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset7->qs); - C8 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset8->qs); - - __builtin_vsx_disassemble_pair(c1, &C1); - __builtin_vsx_disassemble_pair(c2, &C2); - __builtin_vsx_disassemble_pair(c3, &C3); - __builtin_vsx_disassemble_pair(c4, &C4); - __builtin_vsx_disassemble_pair(c5, &C5); - __builtin_vsx_disassemble_pair(c6, &C6); - __builtin_vsx_disassemble_pair(c7, &C7); - __builtin_vsx_disassemble_pair(c8, &C8); - - t1 = vec_perm(c1[0], c2[0], swiz1); - t2 = vec_perm(c1[0], c2[0], swiz2); - t3 = vec_perm(c3[0], c4[0], swiz1); - t4 = vec_perm(c3[0], c4[0], swiz2); - t5 = vec_perm(t1, t3, swiz3); - t6 = vec_perm(t1, t3, swiz4); - t7 = vec_perm(t2, t4, swiz3); - t8 = vec_perm(t2, t4, swiz4); - if (flip == true) { - t5 = vec_xor(t5, xor_vector); - t6 = vec_xor(t6, xor_vector); - t7 = vec_xor(t7, xor_vector); - t8 = vec_xor(t8, xor_vector); + for (int it = 0; it < 8; it++) { + arr[it] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[it]->qs); + __builtin_vsx_disassemble_pair(c[it], &arr[it]); + c1[it] = c[it][0]; + c2[it] = c[it][1]; } - vec_xst(t5, 0, vecOffset); - vec_xst(t6, 0, vecOffset+16); - vec_xst(t7, 0, vecOffset+32); - vec_xst(t8, 0, vecOffset+48); - - t1 = vec_perm(c1[1], c2[1], swiz1); - t2 = vec_perm(c1[1], c2[1], swiz2); - t3 = vec_perm(c3[1], c4[1], swiz1); - t4 = vec_perm(c3[1], c4[1], swiz2); - t5 = vec_perm(t1, t3, swiz3); - t6 = vec_perm(t1, t3, swiz4); - t7 = vec_perm(t2, t4, swiz3); - t8 = vec_perm(t2, t4, swiz4); - if (flip == true) { - t5 = vec_xor(t5, xor_vector); - t6 = vec_xor(t6, xor_vector); - t7 = vec_xor(t7, xor_vector); - t8 = vec_xor(t8, xor_vector); - } - vec_xst(t5, 0, vecOffset+64); - vec_xst(t6, 0, vecOffset+80); - vec_xst(t7, 0, vecOffset+96); - vec_xst(t8, 0, vecOffset+112); - - t1 = vec_perm(c5[0], c6[0], swiz1); - t2 = vec_perm(c5[0], c6[0], swiz2); - t3 = vec_perm(c7[0], c8[0], swiz1); - t4 = vec_perm(c7[0], c8[0], swiz2); - t5 = vec_perm(t1, t3, swiz3); - t6 = vec_perm(t1, t3, swiz4); - t7 = vec_perm(t2, t4, swiz3); - t8 = vec_perm(t2, t4, swiz4); - if (flip == true) { - t5 = vec_xor(t5, xor_vector); - t6 = vec_xor(t6, xor_vector); - t7 = vec_xor(t7, xor_vector); - t8 = vec_xor(t8, xor_vector); - } - vec_xst(t5, 0, vecOffset+128); - vec_xst(t6, 0, vecOffset+144); - vec_xst(t7, 0, vecOffset+160); - vec_xst(t8, 0, vecOffset+176); - - t1 = vec_perm(c5[1], c6[1], swiz1); - t2 = vec_perm(c5[1], c6[1], swiz2); - t3 = vec_perm(c7[1], c8[1], swiz1); - t4 = vec_perm(c7[1], c8[1], swiz2); - t5 = vec_perm(t1, t3, swiz3); - t6 = vec_perm(t1, t3, swiz4); - t7 = vec_perm(t2, t4, swiz3); - t8 = vec_perm(t2, t4, swiz4); - if (flip == true) { - t5 = vec_xor(t5, xor_vector); - t6 = vec_xor(t6, xor_vector); - t7 = vec_xor(t7, xor_vector); - t8 = vec_xor(t8, xor_vector); - } - vec_xst(t5, 0, vecOffset+192); - vec_xst(t6, 0, vecOffset+208); - vec_xst(t7, 0, vecOffset+224); - vec_xst(t8, 0, vecOffset+240); - - aoffset1 += lda; - aoffset2 += lda; - aoffset3 += lda; - aoffset4 += lda; - aoffset5 += lda; - aoffset6 += lda; - aoffset7 += lda; - aoffset8 += lda; + vector_permute_store(c1[0], c1[1], c1[2], c1[3], vecOffset, flip); + vector_permute_store(c2[0], c2[1], c2[2], c2[3], vecOffset+64, flip); + vector_permute_store(c1[4], c1[5], c1[6], c1[7], vecOffset+128, flip); + vector_permute_store(c2[4], c2[5], c2[6], c2[7], vecOffset+192, flip); + for (int it = 0; it < 8; it++) + aoffsets[it] += lda; vecOffset += 256; i--; } while(i > 0); @@ -2150,129 +1824,53 @@ class tinyBLAS_Q0_PPC { } if (rows & 4) { - aoffset1 = aoffset; - aoffset2 = aoffset1 + lda; - aoffset3 = aoffset2 + lda; - aoffset4 = aoffset3 + lda; - aoffset += 4 * lda; - + aoffsets[0] = aoffset; + for (int it = 1; it < 4; it++ ) + aoffsets[it] = aoffsets[it-1] + lda; + aoffset += 4 * lda; i = (cols >> 3); if (i > 0) { do { - C1 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset1->qs); - C2 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset2->qs); - C3 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset3->qs); - C4 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset4->qs); - - __builtin_vsx_disassemble_pair(c1, &C1); - __builtin_vsx_disassemble_pair(c2, &C2); - __builtin_vsx_disassemble_pair(c3, &C3); - __builtin_vsx_disassemble_pair(c4, &C4); - - t1 = vec_perm(c1[0], c2[0], swiz1); - t2 = vec_perm(c1[0], c2[0], swiz2); - t3 = vec_perm(c3[0], c4[0], swiz1); - t4 = vec_perm(c3[0], c4[0], swiz2); - t5 = vec_perm(t1, t3, swiz3); - t6 = vec_perm(t1, t3, swiz4); - t7 = vec_perm(t2, t4, swiz3); - t8 = vec_perm(t2, t4, swiz4); - if (flip == true) { - t5 = vec_xor(t5, xor_vector); - t6 = vec_xor(t6, xor_vector); - t7 = vec_xor(t7, xor_vector); - t8 = vec_xor(t8, xor_vector); + for (int it = 0; it < 4; it++) { + arr[it] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[it]->qs); + __builtin_vsx_disassemble_pair(c[it], &arr[it]); + c1[it] = c[it][0]; + c2[it] = c[it][1]; } - vec_xst(t5, 0, vecOffset); - vec_xst(t6, 0, vecOffset+16); - vec_xst(t7, 0, vecOffset+32); - vec_xst(t8, 0, vecOffset+48); - - t1 = vec_perm(c1[1], c2[1], swiz1); - t2 = vec_perm(c1[1], c2[1], swiz2); - t3 = vec_perm(c3[1], c4[1], swiz1); - t4 = vec_perm(c3[1], c4[1], swiz2); - t5 = vec_perm(t1, t3, swiz3); - t6 = vec_perm(t1, t3, swiz4); - t7 = vec_perm(t2, t4, swiz3); - t8 = vec_perm(t2, t4, swiz4); - if (flip == true) { - t5 = vec_xor(t5, xor_vector); - t6 = vec_xor(t6, xor_vector); - t7 = vec_xor(t7, xor_vector); - t8 = vec_xor(t8, xor_vector); + vector_permute_store(c1[0], c1[1], c1[2], c1[3], vecOffset, flip); + vector_permute_store(c2[0], c2[1], c2[2], c2[3], vecOffset+64, flip); + for (int it = 0; it < 4; it++) { + aoffsets[it] += lda; } - vec_xst(t5, 0, vecOffset+64); - vec_xst(t6, 0, vecOffset+80); - vec_xst(t7, 0, vecOffset+96); - vec_xst(t8, 0, vecOffset+112); - - aoffset1 += lda; - aoffset2 += lda; - aoffset3 += lda; - aoffset4 += lda; vecOffset += 128; i--; } while(i > 0); } } + if (rows & 3) { - aoffset1 = aoffset; - aoffset2 = aoffset1 + lda; - aoffset3 = aoffset2 + lda; + aoffsets[0] = aoffset; + for (int it = 1; it < 3; it++ ) + aoffsets[it] = aoffsets[it-1] + lda; i = (cols >> 3); if (i > 0) { do { switch(rows) { - case 3: C3 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset3->qs); - __builtin_vsx_disassemble_pair(c3, &C3); - case 2: C2 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset2->qs); - __builtin_vsx_disassemble_pair(c2, &C2); - case 1: C1 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset1->qs); - __builtin_vsx_disassemble_pair(c1, &C1); + case 3: arr[2] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[2]->qs); + __builtin_vsx_disassemble_pair(c[2], &arr[2]); + c1[2] = c[2][0]; c2[2] = c[2][1]; + case 2: arr[1] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[1]->qs); + __builtin_vsx_disassemble_pair(c[1], &arr[1]); + c1[1] = c[1][0]; c2[1] = c[1][1]; + case 1: arr[0] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[0]->qs); + __builtin_vsx_disassemble_pair(c[0], &arr[0]); + c1[0] = c[0][0]; c2[0] = c[0][1]; break; } - t1 = vec_perm(c1[0], c2[0], swiz1); - t2 = vec_perm(c1[0], c2[0], swiz2); - t3 = vec_perm(c3[0], c4[0], swiz1); - t4 = vec_perm(c3[0], c4[0], swiz2); - t5 = vec_perm(t1, t3, swiz3); - t6 = vec_perm(t1, t3, swiz4); - t7 = vec_perm(t2, t4, swiz3); - t8 = vec_perm(t2, t4, swiz4); - if (flip == true) { - t5 = vec_xor(t5, xor_vector); - t6 = vec_xor(t6, xor_vector); - t7 = vec_xor(t7, xor_vector); - t8 = vec_xor(t8, xor_vector); - } - vec_xst(t5, 0, vecOffset); - vec_xst(t6, 0, vecOffset+16); - vec_xst(t7, 0, vecOffset+32); - vec_xst(t8, 0, vecOffset+48); - - t1 = vec_perm(c1[1], c2[1], swiz1); - t2 = vec_perm(c1[1], c2[1], swiz2); - t3 = vec_perm(c3[1], c4[1], swiz1); - t4 = vec_perm(c3[1], c4[1], swiz2); - t5 = vec_perm(t1, t3, swiz3); - t6 = vec_perm(t1, t3, swiz4); - t7 = vec_perm(t2, t4, swiz3); - t8 = vec_perm(t2, t4, swiz4); - if (flip == true) { - t5 = vec_xor(t5, xor_vector); - t6 = vec_xor(t6, xor_vector); - t7 = vec_xor(t7, xor_vector); - t8 = vec_xor(t8, xor_vector); - } - vec_xst(t5, 0, vecOffset+64); - vec_xst(t6, 0, vecOffset+80); - vec_xst(t7, 0, vecOffset+96); - vec_xst(t8, 0, vecOffset+112); - - aoffset1 += lda; - aoffset2 += lda; - aoffset3 += lda; + vector_permute_store(c1[0], c1[1], c1[2], c1[3], vecOffset, flip); + vector_permute_store(c2[0], c2[1], c2[2], c2[3], vecOffset+64, flip); + for (int it = 0; it < 3; it++) + aoffsets[it] += lda; vecOffset += 128; i--; } while(i > 0); @@ -2281,159 +1879,42 @@ class tinyBLAS_Q0_PPC { } void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) { - int64_t mc, nc, mp, np; - int m_rem = MIN(m - m0, 8); - int n_rem = MIN(n - n0, 8); - // TO-DO: KERNEL_16x8 and KERNEL_8x16 are having some performance - // issues. After resolving them, below code will be enabled. - /*if (m_rem >= 16 && n_rem >= 8) { - mc = 16; - nc = 8; - gemm<16,8>(m0, m, n0, n); - } else if(m_rem >= 8 && n_rem >= 16) { - mc = 8; - nc = 16; - gemm<8,16>(m0, m, n0, n); - }*/ + int m_rem = MIN(m - m0, 16); + int n_rem = MIN(n - n0, 16); + + int mc = 0, nc = 0; + if (m_rem >= 8 && n_rem >= 8) { - mc = 8; - nc = 8; - gemm<8,8>(m0, m, n0, n); + mc = 8; + nc = 8; + gemm<8, 8>(m0, m, n0, n); } else if (m_rem >= 4 && n_rem >= 8) { mc = 4; nc = 8; - gemm<4,8>(m0, m, n0, n); + gemm<4, 8>(m0, m, n0, n); } else if (m_rem >= 8 && n_rem >= 4) { mc = 8; nc = 4; - gemm<8,4>(m0, m, n0, n); + gemm<8, 4>(m0, m, n0, n); } else if (m_rem >= 4 && n_rem >= 4) { mc = 4; nc = 4; - gemm_small<4, 4>(m0, m, n0, n); - } else if ((m_rem < 4) && (n_rem > 4)) { - nc = 4; - switch(m_rem) { - case 1: - mc = 1; - gemm_small<1, 4>(m0, m, n0, n); - break; - case 2: - mc = 2; - gemm_small<2, 4>(m0, m, n0, n); - break; - case 3: - mc = 3; - gemm_small<3, 4>(m0, m, n0, n); - break; - default: - return; - } - } else if ((m_rem > 4) && (n_rem < 4)) { - mc = 4; - switch(n_rem) { - case 1: - nc = 1; - gemm_small<4, 1>(m0, m, n0, n); - break; - case 2: - nc = 2; - gemm_small<4, 2>(m0, m, n0, n); - break; - case 3: - nc = 3; - gemm_small<4, 3>(m0, m, n0, n); - break; - default: - return; - } + gemm_small(m0, m, n0, n, mc, nc); } else { - switch((m_rem << 4) | n_rem) { - case 0x43: - mc = 4; - nc = 3; - gemm_small<4, 3>(m0, m, n0, n); - break; - case 0x42: - mc = 4; - nc = 2; - gemm_small<4, 2>(m0, m, n0, n); - break; - case 0x41: - mc = 4; - nc = 1; - gemm_small<4, 1>(m0, m, n0, n); - break; - case 0x34: - mc = 3; - nc = 4; - gemm_small<3, 4>(m0, m, n0, n); - break; - case 0x33: - mc = 3; - nc = 3; - gemm_small<3, 3>(m0, m, n0, n); - break; - case 0x32: - mc = 3; - nc = 2; - gemm_small<3, 2>(m0, m, n0, n); - break; - case 0x31: - mc = 3; - nc = 1; - gemm_small<3, 1>(m0, m, n0, n); - break; - case 0x24: - mc = 2; - nc = 4; - gemm_small<2, 4>(m0, m, n0, n); - break; - case 0x23: - mc = 2; - nc = 3; - gemm_small<2, 3>(m0, m, n0, n); - break; - case 0x22: - mc = 2; - nc = 2; - gemm_small<2, 2>(m0, m, n0, n); - break; - case 0x21: - mc = 2; - nc = 1; - gemm_small<2, 1>(m0, m, n0, n); - break; - case 0x14: - mc = 1; - nc = 4; - gemm_small<1, 4>(m0, m, n0, n); - break; - case 0x13: - mc = 1; - nc = 3; - gemm_small<1, 3>(m0, m, n0, n); - break; - case 0x12: - mc = 1; - nc = 2; - gemm_small<1, 2>(m0, m, n0, n); - break; - case 0x11: - mc = 1; - nc = 1; - gemm_small<1, 1>(m0, m, n0, n); - break; - default: - return; - } + mc = (m_rem >= 4) ? 4 : m_rem; + nc = (n_rem >= 4) ? 4 : n_rem; + if (mc == 0 || nc == 0) + return; + gemm_small(m0, m, n0, n, mc, nc); } - mp = m0 + (m - m0) / mc * mc; - np = n0 + (n - n0) / nc * nc; + + int64_t mp = m0 + ((m - m0) / mc) * mc; + int64_t np = n0 + ((n - n0) / nc) * nc; mnpack(mp, m, n0, np); mnpack(m0, m, np, n); } + void KERNEL_4x8(int64_t ii, int64_t jj) { vec_t vec_A[8], vec_B[16] = {0}; acc_t acc_0, acc_1; @@ -2445,9 +1926,9 @@ class tinyBLAS_Q0_PPC { __builtin_mma_xxsetaccz(&acc_0); __builtin_mma_xxsetaccz(&acc_1); if (std::is_same_v) { - packNormalInt4((A+(ii*lda)+l), lda, 4, 4, (int8_t*)vec_A, comparray); + packNormalInt4<4>((A+(ii*lda)+l), lda, 4, 4, (int8_t*)vec_A, comparray); } else { - packNormal((const TB*)(A+(ii*lda)+l), lda, 4, 8, (int8_t*)vec_A, false); + packNormal((const block_q8_0*)(A+(ii*lda)+l), lda, 4, 8, (int8_t*)vec_A, false); } packNormal((B+(jj*ldb)+l), ldb, 8, 8, (uint8_t*)vec_B, true); for(int x = 0; x < 8; x++) { @@ -2475,8 +1956,8 @@ class tinyBLAS_Q0_PPC { compute<4>(&acc_0, 0, 0, comparray, vs, fin_res); compute<4>(&acc_1, 0, 4, comparray, vs, fin_res); } - save_res<4, 4>(ii, jj, 0, fin_res); - save_res<4, 4>(ii, jj+4, 4, fin_res); + save_res(ii, jj, 0, fin_res); + save_res(ii, jj+4, 4, fin_res); } void KERNEL_8x4(int64_t ii, int64_t jj) { @@ -2490,9 +1971,9 @@ class tinyBLAS_Q0_PPC { __builtin_mma_xxsetaccz(&acc_0); __builtin_mma_xxsetaccz(&acc_1); if (std::is_same_v) { - packNormalInt4((A+(ii*lda)+l), lda, 8, 4, (int8_t*)vec_A, comparray); + packNormalInt4<8>((A+(ii*lda)+l), lda, 8, 4, (int8_t*)vec_A, comparray); } else { - packNormal((const TB*)(A+(ii*lda)+l), lda, 8, 8, (int8_t*)vec_A, false); + packNormal((const block_q8_0*)(A+(ii*lda)+l), lda, 8, 8, (int8_t*)vec_A, false); } packNormal((B+(jj*ldb)+l), ldb, 4, 8, (uint8_t*)vec_B, true); for(int x = 0; x < 8; x++) { @@ -2519,8 +2000,8 @@ class tinyBLAS_Q0_PPC { compute<8>(&acc_0, 0, 0, comparray, vs, fin_res); compute<8>(&acc_1, 4, 4, comparray, vs, fin_res); } - save_res<4, 4>(ii, jj, 0, fin_res); - save_res<4, 4>(ii+4, jj, 4, fin_res); + save_res(ii, jj, 0, fin_res); + save_res(ii+4, jj, 4, fin_res); } void KERNEL_8x8(int64_t ii, int64_t jj) { @@ -2536,9 +2017,9 @@ class tinyBLAS_Q0_PPC { __builtin_mma_xxsetaccz(&acc_2); __builtin_mma_xxsetaccz(&acc_3); if (std::is_same_v) { - packNormalInt4((A+(ii*lda)+l), lda, 8, 4, (int8_t*)vec_A, comparray); + packNormalInt4<8>((A+(ii*lda)+l), lda, 8, 4, (int8_t*)vec_A, comparray); } else { - packNormal((const TB*)(A+(ii*lda)+l), lda, 8, 8, (int8_t*)vec_A, false); + packNormal((const block_q8_0*)(A+(ii*lda)+l), lda, 8, 8, (int8_t*)vec_A, false); } packNormal((B+(jj*ldb)+l), ldb, 8, 8, (uint8_t*)vec_B, true); for(int x = 0; x < 8; x++) { @@ -2570,14 +2051,13 @@ class tinyBLAS_Q0_PPC { compute<8>(&acc_2, 0, 8, comparray, vs, fin_res); compute<8>(&acc_3, 4, 12, comparray, vs, fin_res); } - save_res<4, 4>(ii, jj, 0, fin_res); - save_res<4, 4>(ii+4, jj, 4, fin_res); - save_res<4, 4>(ii, jj+4, 8, fin_res); - save_res<4, 4>(ii+4, jj+4, 12, fin_res); + save_res(ii, jj, 0, fin_res); + save_res(ii+4, jj, 4, fin_res); + save_res(ii, jj+4, 8, fin_res); + save_res(ii+4, jj+4, 12, fin_res); } - template - void gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n) { + void gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n, int RM, int RN) { int64_t ytiles = (m - m0) / RM; int64_t xtiles = (n - n0) / RN; int64_t tiles = xtiles * ytiles; @@ -2606,9 +2086,9 @@ class tinyBLAS_Q0_PPC { __builtin_prefetch((B+(jj*ldb)+(l+1))->qs, 0, 1); // prefetch one loop ahead __builtin_mma_xxsetaccz(&acc_0); if (isAblock_q4) { - packNormalInt4((A+(ii*lda)+l), lda, RM, 4, (int8_t*)vec_A, comparray); + packNormalInt4<4>((A+(ii*lda)+l), lda, RM, 4, (int8_t*)vec_A, comparray); } else { - packNormal((const TB*)(A+(ii*lda)+l), lda, RM, 8, (int8_t*)vec_A, false); + packNormal((const block_q8_0*)(A+(ii*lda)+l), lda, RM, 8, (int8_t*)vec_A, false); } packNormal((B+(jj*ldb)+l), ldb, RN, 8, (uint8_t*)vec_B, true); for(int x = 0; x < 8; x+=4) { @@ -2641,7 +2121,7 @@ class tinyBLAS_Q0_PPC { fin_res[i] = vec_madd(res[i], vs[i], fin_res[i]); } } - save_res(ii, jj, 0, fin_res); + save_res(ii, jj, 0, fin_res, RM, RN); } } @@ -2654,7 +2134,7 @@ class tinyBLAS_Q0_PPC { } else if constexpr(RM == 8 && RN == 8) { KERNEL_8x8(ii,jj); } else { - static_assert(false, "RN/RM values not supported"); + assert(false && "RN/RM values not supported"); } } @@ -2676,10 +2156,8 @@ class tinyBLAS_Q0_PPC { } const TA *const A; - const TB *const B; - TC *C; - TA *At; - TB *Bt; + const block_q8_0 *const B; + float *C; const int64_t k; const int64_t lda; const int64_t ldb; @@ -2688,13 +2166,12 @@ class tinyBLAS_Q0_PPC { const int nth; }; -template class tinyBLAS_PPC { public: tinyBLAS_PPC(int64_t k, - const TA *A, int64_t lda, - const TB *B, int64_t ldb, - TC *C, int64_t ldc, + const float *A, int64_t lda, + const float *B, int64_t ldb, + float *C, int64_t ldc, int ith, int nth) : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) { } @@ -2707,247 +2184,139 @@ class tinyBLAS_PPC { void (tinyBLAS_PPC::*kernel)(int64_t, int64_t); - template - void packTranspose(const TA* a, int64_t lda, int rows, int cols, TA* vec) { + inline void vector_permute_store_4(vector float *src, float *vecOffset) { + vector float t1, t2, t3, t4, t5, t6, t7, t8; + t1 = vec_mergeh(src[0], src[1]); + t2 = vec_mergeh(src[2], src[3]); + t3 = vec_mergel(src[0], src[1]); + t4 = vec_mergel(src[2], src[3]); + + t5 = vec_xxpermdi(t1, t2, 0); + t6 = vec_xxpermdi(t1, t2, 3); + t7 = vec_xxpermdi(t3, t4, 0); + t8 = vec_xxpermdi(t3, t4, 3); + + vec_xst(t5, 0, vecOffset); + vec_xst(t6, 0, vecOffset + 4); + vec_xst(t7, 0, vecOffset + 8); + vec_xst(t8, 0, vecOffset + 12); + } + + inline void vector_permute_store_8(vector float *src, float *vecOffset) { + vector float t1, t2, t3, t4, t5, t6, t7, t8; + t1 = vec_mergeh(src[0], src[1]); + t2 = vec_mergeh(src[2], src[3]); + t3 = vec_mergeh(src[4], src[5]); + t4 = vec_mergeh(src[6], src[7]); + + t5 = vec_xxpermdi(t1, t2, 0); + t6 = vec_xxpermdi(t3, t4, 0); + t7 = vec_xxpermdi(t1, t2, 3); + t8 = vec_xxpermdi(t3, t4, 3); + + vec_xst(t5, 0, vecOffset); + vec_xst(t6, 0, vecOffset + 4); + vec_xst(t7, 0, vecOffset + 8); + vec_xst(t8, 0, vecOffset + 12); + + t1 = vec_mergel(src[0], src[1]); + t2 = vec_mergel(src[2], src[3]); + t3 = vec_mergel(src[4], src[5]); + t4 = vec_mergel(src[6], src[7]); + + t5 = vec_xxpermdi(t1, t2, 0); + t6 = vec_xxpermdi(t3, t4, 0); + t7 = vec_xxpermdi(t1, t2, 3); + t8 = vec_xxpermdi(t3, t4, 3); + + vec_xst(t5, 0, vecOffset + 16); + vec_xst(t6, 0, vecOffset + 20); + vec_xst(t7, 0, vecOffset + 24); + vec_xst(t8, 0, vecOffset + 28); + } + + void packTranspose(const float* a, int64_t lda, int rows, int cols, float* vec) { int64_t i, j; - TA *aoffset = NULL, *boffset = NULL; - TA *aoffset1 = NULL, *aoffset2 = NULL, *aoffset3 = NULL, *aoffset4 = NULL; - TA *aoffset5 = NULL, *aoffset6 = NULL, *aoffset7 = NULL, *aoffset8 = NULL; - __vector_pair C1, C2, C3, C4, C5, C6, C7, C8; - VA c1[2] = {0}, c2[2] = {0}, c3[2] = {0}, c4[2] = {0}; - VA c5[2] = {0}, c6[2] = {0}, c7[2] = {0}, c8[2] = {0}; - VA t1, t2, t3, t4, t5, t6, t7, t8; - aoffset = const_cast(a); + float * aoffsets[8]; + float *aoffset = NULL, *boffset = NULL; + __vector_pair arr[8]; + vector float c[8][2] = {0}; + vector float c1[8] = {0}; + vector float c2[8] = {0}; + aoffset = const_cast(a); boffset = vec; j = (rows >> 3); if (j > 0) { do { - aoffset1 = aoffset; - aoffset2 = aoffset1 + lda; - aoffset3 = aoffset2 + lda; - aoffset4 = aoffset3 + lda; - aoffset5 = aoffset4 + lda; - aoffset6 = aoffset5 + lda; - aoffset7 = aoffset6 + lda; - aoffset8 = aoffset7 + lda; + aoffsets[0] = aoffset; + for (int it = 1; it< 8; it++) + aoffsets[it] = aoffsets[it-1] + lda; aoffset += 8 * lda; i = (cols >> 3); if (i > 0) { do { - C1 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset1); - C2 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset2); - C3 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset3); - C4 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset4); - C5 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset5); - C6 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset6); - C7 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset7); - C8 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset8); - __builtin_vsx_disassemble_pair(c1, &C1); - __builtin_vsx_disassemble_pair(c2, &C2); - __builtin_vsx_disassemble_pair(c3, &C3); - __builtin_vsx_disassemble_pair(c4, &C4); - __builtin_vsx_disassemble_pair(c5, &C5); - __builtin_vsx_disassemble_pair(c6, &C6); - __builtin_vsx_disassemble_pair(c7, &C7); - __builtin_vsx_disassemble_pair(c8, &C8); + for (int it = 0; it< 8; it++) { + arr[it] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[it]); + __builtin_vsx_disassemble_pair(c[it], &arr[it]); + c1[it] = c[it][0]; + c2[it] = c[it][1]; + } - t1 = vec_mergeh(c1[0], c2[0]); - t2 = vec_mergeh(c3[0], c4[0]); - t3 = vec_mergeh(c5[0], c6[0]); - t4 = vec_mergeh(c7[0], c8[0]); - t5 = vec_xxpermdi(t1, t2, 0); - t6 = vec_xxpermdi(t3, t4, 0); - t7 = vec_xxpermdi(t1, t2, 3); - t8 = vec_xxpermdi(t3, t4, 3); - vec_xst(t5, 0, boffset); - vec_xst(t6, 0, boffset+4); - vec_xst(t7, 0, boffset+8); - vec_xst(t8, 0, boffset+12); - - t1 = vec_mergel(c1[0], c2[0]); - t2 = vec_mergel(c3[0], c4[0]); - t3 = vec_mergel(c5[0], c6[0]); - t4 = vec_mergel(c7[0], c8[0]); - t5 = vec_xxpermdi(t1, t2, 0); - t6 = vec_xxpermdi(t3, t4, 0); - t7 = vec_xxpermdi(t1, t2, 3); - t8 = vec_xxpermdi(t3, t4, 3); - vec_xst(t5, 0, boffset+16); - vec_xst(t6, 0, boffset+20); - vec_xst(t7, 0, boffset+24); - vec_xst(t8, 0, boffset+28); - - t1 = vec_mergeh(c1[1], c2[1]); - t2 = vec_mergeh(c3[1], c4[1]); - t3 = vec_mergeh(c5[1], c6[1]); - t4 = vec_mergeh(c7[1], c8[1]); - t5 = vec_xxpermdi(t1, t2, 0); - t6 = vec_xxpermdi(t3, t4, 0); - t7 = vec_xxpermdi(t1, t2, 3); - t8 = vec_xxpermdi(t3, t4, 3); - vec_xst(t5, 0, boffset+32); - vec_xst(t6, 0, boffset+36); - vec_xst(t7, 0, boffset+40); - vec_xst(t8, 0, boffset+44); - - t1 = vec_mergel(c1[1], c2[1]); - t2 = vec_mergel(c3[1], c4[1]); - t3 = vec_mergel(c5[1], c6[1]); - t4 = vec_mergel(c7[1], c8[1]); - t5 = vec_xxpermdi(t1, t2, 0); - t6 = vec_xxpermdi(t3, t4, 0); - t7 = vec_xxpermdi(t1, t2, 3); - t8 = vec_xxpermdi(t3, t4, 3); - vec_xst(t5, 0, boffset+48); - vec_xst(t6, 0, boffset+52); - vec_xst(t7, 0, boffset+56); - vec_xst(t8, 0, boffset+60); - - aoffset1 += 8*lda; - aoffset2 += 8*lda; - aoffset3 += 8*lda; - aoffset4 += 8*lda; + vector_permute_store_8(c1, boffset); + vector_permute_store_8(c2, boffset+32); + for (int it = 0; it < 4; it++) + aoffsets[it] = aoffsets[it] + 8*lda; boffset += 64; i--; } while(i > 0); } if (cols & 4) { - c1[0] = vec_xl(0, aoffset1); - c2[0] = vec_xl(0, aoffset2); - c3[0] = vec_xl(0, aoffset3); - c4[0] = vec_xl(0, aoffset4); - c5[0] = vec_xl(0, aoffset5); - c6[0] = vec_xl(0, aoffset6); - c7[0] = vec_xl(0, aoffset7); - c8[0] = vec_xl(0, aoffset8); - - t1 = vec_mergeh(c1[0], c2[0]); - t2 = vec_mergeh(c3[0], c4[0]); - t3 = vec_mergeh(c5[0], c6[0]); - t4 = vec_mergeh(c7[0], c8[0]); - t5 = vec_xxpermdi(t1, t2, 0); - t6 = vec_xxpermdi(t3, t4, 0); - t7 = vec_xxpermdi(t1, t2, 3); - t8 = vec_xxpermdi(t3, t4, 3); - vec_xst(t5, 0, boffset); - vec_xst(t6, 0, boffset+4); - vec_xst(t7, 0, boffset+8); - vec_xst(t8, 0, boffset+12); - - t1 = vec_mergel(c1[0], c2[0]); - t2 = vec_mergel(c3[0], c4[0]); - t3 = vec_mergel(c5[0], c6[0]); - t4 = vec_mergel(c7[0], c8[0]); - t5 = vec_xxpermdi(t1, t2, 0); - t6 = vec_xxpermdi(t3, t4, 0); - t7 = vec_xxpermdi(t1, t2, 3); - t8 = vec_xxpermdi(t3, t4, 3); - vec_xst(t5, 0, boffset+16); - vec_xst(t6, 0, boffset+20); - vec_xst(t7, 0, boffset+24); - vec_xst(t8, 0, boffset+28); + for (int it = 0; it < 8 ; it++) + c1[it] = vec_xl(0, aoffsets[it]); + vector_permute_store_8(c1, boffset); } j--; } while(j > 0); } if (rows & 4) { - aoffset1 = aoffset; - aoffset2 = aoffset1 + lda; - aoffset3 = aoffset2 + lda; - aoffset4 = aoffset3 + lda; + aoffsets[0] = aoffset; + for (int it = 1; it < 4; it++) + aoffsets[it] = aoffsets[it-1] + lda; aoffset += 4 * lda; i = (cols >> 3); if (i > 0) { do { - C1 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset1); - C2 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset2); - C3 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset3); - C4 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset4); - __builtin_vsx_disassemble_pair(c1, &C1); - __builtin_vsx_disassemble_pair(c2, &C2); - __builtin_vsx_disassemble_pair(c3, &C3); - __builtin_vsx_disassemble_pair(c4, &C4); - - t1 = vec_mergeh(c1[0], c2[0]); - t2 = vec_mergeh(c3[0], c4[0]); - t3 = vec_mergel(c1[0], c2[0]); - t4 = vec_mergel(c3[0], c4[0]); - t5 = vec_xxpermdi(t1, t2, 0); - t6 = vec_xxpermdi(t1, t2, 3); - t7 = vec_xxpermdi(t3, t4, 0); - t8 = vec_xxpermdi(t3, t4, 3); - vec_xst(t5, 0, boffset); - vec_xst(t6, 0, boffset+4); - vec_xst(t7, 0, boffset+8); - vec_xst(t8, 0, boffset+12); - - t1 = vec_mergeh(c1[1], c2[1]); - t2 = vec_mergeh(c3[1], c4[1]); - t3 = vec_mergel(c1[1], c2[1]); - t4 = vec_mergel(c3[1], c4[1]); - t5 = vec_xxpermdi(t1, t2, 0); - t6 = vec_xxpermdi(t1, t2, 3); - t7 = vec_xxpermdi(t3, t4, 0); - t8 = vec_xxpermdi(t3, t4, 3); - vec_xst(t5, 0, boffset+16); - vec_xst(t6, 0, boffset+20); - vec_xst(t7, 0, boffset+24); - vec_xst(t8, 0, boffset+28); - - aoffset1 += 8*lda; - aoffset2 += 8*lda; - aoffset3 += 8*lda; - aoffset4 += 8*lda; + for (int it = 0; it < 4; it++) { + arr[it] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[it]); + __builtin_vsx_disassemble_pair(c[it], &arr[it]); + c1[it] = c[it][0]; + c2[it] = c[it][1]; + } + vector_permute_store_4(c1, boffset); + vector_permute_store_4(c2, boffset+16); + for (int it = 0; it < 4; it++) + aoffsets[it] += 8*lda; boffset += 32; i--; } while(i > 0); } if (cols & 4) { - c1[0] = vec_xl(0, aoffset1); - c2[0] = vec_xl(0, aoffset2); - c3[0] = vec_xl(0, aoffset3); - c4[0] = vec_xl(0, aoffset4); - - t1 = vec_mergeh(c1[0], c2[0]); - t2 = vec_mergeh(c3[0], c4[0]); - t3 = vec_xxpermdi(t1, t2, 0); - t4 = vec_xxpermdi(t1, t2, 3); - vec_xst(t3, 0, boffset); - vec_xst(t4, 0, boffset+4); - - t1 = vec_mergel(c1[0], c2[0]); - t2 = vec_mergel(c3[0], c4[0]); - t3 = vec_xxpermdi(t1, t2, 0); - t4 = vec_xxpermdi(t1, t2, 3); - vec_xst(t3, 0, boffset+8); - vec_xst(t4, 0, boffset+12); + for (int it = 0; it < 4; it++) + c1[it] = vec_xl(0, aoffsets[it]); + vector_permute_store_4(c1, boffset); } } if (rows & 3) { - aoffset1 = aoffset; - aoffset2 = aoffset1 + lda; - aoffset3 = aoffset2 + lda; + aoffsets[0] = aoffset; + for (int it = 1; it < 3; it++) + aoffsets[it] = aoffsets[it-1] + lda; if (cols & 4) { - c1[0] = vec_xl(0, aoffset1); - c2[0] = vec_xl(0, aoffset2); - c3[0] = vec_xl(0, aoffset3); - - t1 = vec_mergeh(c1[0], c2[0]); - t2 = vec_mergeh(c3[0], c4[0]); - t3 = vec_xxpermdi(t1, t2, 0); - t4 = vec_xxpermdi(t1, t2, 3); - vec_xst(t3, 0, boffset); - vec_xst(t4, 0, boffset+4); - - t1 = vec_mergel(c1[0], c2[0]); - t2 = vec_mergel(c3[0], c4[0]); - t3 = vec_xxpermdi(t1, t2, 0); - t4 = vec_xxpermdi(t1, t2, 3); - vec_xst(t3, 0, boffset+8); - vec_xst(t4, 0, boffset+12); + for (int it = 0; it < 3; it++) + c1[it] = vec_xl(0, aoffsets[it]); + vector_permute_store_4(c1, boffset); } } } @@ -2957,8 +2326,8 @@ class tinyBLAS_PPC { acc_t acc_0; __builtin_mma_xxsetaccz(&acc_0); for (int l = 0; l < k; l+=4) { - packTranspose(A+(ii*lda)+l, lda, 4, 4, (TA*)vec_A); - packTranspose(B+(jj*ldb)+l, ldb, 4, 4, (TA*)vec_B); + packTranspose(A+(ii*lda)+l, lda, 4, 4, (float*)vec_A); + packTranspose(B+(jj*ldb)+l, ldb, 4, 4, (float*)vec_B); __builtin_mma_xvf32gerpp(&acc_0, vec_A[0], vec_B[0]); __builtin_mma_xvf32gerpp(&acc_0, vec_A[1], vec_B[1]); __builtin_mma_xvf32gerpp(&acc_0, vec_A[2], vec_B[2]); @@ -2973,8 +2342,8 @@ class tinyBLAS_PPC { __builtin_mma_xxsetaccz(&acc_0); __builtin_mma_xxsetaccz(&acc_1); for (int64_t l = 0; l < k; l+=4) { - packTranspose(A+(ii*lda)+l, lda, 4, 4, (TA*)vec_A); - packTranspose(B+(jj*ldb)+l, ldb, 8, 4, (TA*)vec_B); + packTranspose(A+(ii*lda)+l, lda, 4, 4, (float*)vec_A); + packTranspose(B+(jj*ldb)+l, ldb, 8, 4, (float*)vec_B); __builtin_mma_xvf32gerpp(&acc_0, vec_A[0], (vec_t)vec_B[0]); __builtin_mma_xvf32gerpp(&acc_1, vec_A[0], (vec_t)vec_B[1]); __builtin_mma_xvf32gerpp(&acc_0, vec_A[1], (vec_t)vec_B[2]); @@ -2994,8 +2363,8 @@ class tinyBLAS_PPC { __builtin_mma_xxsetaccz(&acc_0); __builtin_mma_xxsetaccz(&acc_1); for (int64_t l = 0; l < k; l+=4) { - packTranspose(A+(ii*lda)+l, lda, 8, 4, (TA*)vec_A); - packTranspose(B+(jj*ldb)+l, ldb, 4, 4, (TA*)vec_B); + packTranspose(A+(ii*lda)+l, lda, 8, 4, (float*)vec_A); + packTranspose(B+(jj*ldb)+l, ldb, 4, 4, (float*)vec_B); __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[0], vec_B[0]); __builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[1], vec_B[0]); __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[2], vec_B[1]); @@ -3017,8 +2386,8 @@ class tinyBLAS_PPC { __builtin_mma_xxsetaccz(&acc_2); __builtin_mma_xxsetaccz(&acc_3); for (int l = 0; l < k; l+=8) { - packTranspose(A+(ii*lda)+l, lda, 8, 8, (TA*)vec_A); - packTranspose(B+(jj*ldb)+l, ldb, 8, 8, (TA*)vec_B); + packTranspose(A+(ii*lda)+l, lda, 8, 8, (float*)vec_A); + packTranspose(B+(jj*ldb)+l, ldb, 8, 8, (float*)vec_B); for(int x = 0; x < 16; x+=2) { __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[x], vec_B[x]); __builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[x], vec_B[x+1]); @@ -3033,155 +2402,37 @@ class tinyBLAS_PPC { } void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) { - int64_t mc, nc, mp, np; - int m_rem = MIN(m - m0, 16); - int n_rem = MIN(n - n0, 16); - if (m_rem >= 16 && n_rem >= 8) { - mc = 8; - nc = 8; - gemm<8,8>(m0, m, n0, n); - } else if(m_rem >= 8 && n_rem >= 16) { - mc = 8; - nc = 8; - gemm<8,8>(m0, m, n0, n); - } else if (m_rem >= 8 && n_rem >= 8) { - mc = 8; - nc = 8; - gemm<8,8>(m0, m, n0, n); + int m_rem = MIN(m - m0, 8); + int n_rem = MIN(n - n0, 8); + int mc = 0, nc = 0; + if (m_rem >= 8 && n_rem >= 8) { + mc = 8; + nc = 8; + gemm<8, 8>(m0, m, n0, n); } else if (m_rem >= 4 && n_rem >= 8) { - mc = 4; - nc = 8; - gemm<4,8>(m0, m, n0, n); + mc = 4; + nc = 8; + gemm<4, 8>(m0, m, n0, n); } else if (m_rem >= 8 && n_rem >= 4) { - mc = 8; - nc = 4; - gemm<8,4>(m0, m, n0, n); + mc = 8; + nc = 4; + gemm<8, 4>(m0, m, n0, n); } else if (m_rem >= 4 && n_rem >= 4) { - mc = 4; - nc = 4; - gemm<4,4>(m0, m, n0, n); - } else if ((m_rem < 4) && (n_rem > 4)) { - nc = 4; - switch(m_rem) { - case 1: - mc = 1; - gemm_small(m0, m, n0, n, mc, nc); - break; - case 2: - mc = 2; - gemm_small(m0, m, n0, n, mc, nc); - break; - case 3: - mc = 3; - gemm_small(m0, m, n0, n, mc, nc); - break; - default: - return; - } - } else if ((m_rem > 4) && (n_rem < 4)) { - mc = 4; - switch(n_rem) { - case 1: - nc = 1; - gemm_small(m0, m, n0, n, mc, nc); - break; - case 2: - nc = 2; - gemm_small(m0, m, n0, n, mc, nc); - break; - case 3: - nc = 3; - gemm_small(m0, m, n0, n, mc, nc); - break; - default: - return; - } + mc = 4; + nc = 4; + gemm<4, 4>(m0, m, n0, n); } else { - switch((m_rem << 4) | n_rem) { - case 0x43: - mc = 4; - nc = 3; - gemm_small(m0, m, n0, n, mc, nc); - break; - case 0x42: - mc = 4; - nc = 2; - gemm_small(m0, m, n0, n, mc, nc); - break; - case 0x41: - mc = 4; - nc = 1; - gemm_small(m0, m, n0, n, mc, nc); - break; - case 0x34: - mc = 3; - nc = 4; - gemm_small(m0, m, n0, n, mc, nc); - break; - case 0x33: - mc = 3; - nc = 3; - gemm_small(m0, m, n0, n, mc, nc); - break; - case 0x32: - mc = 3; - nc = 2; - gemm_small(m0, m, n0, n, mc, nc); - break; - case 0x31: - mc = 3; - nc = 1; - gemm_small(m0, m, n0, n, mc, nc); - break; - case 0x24: - mc = 2; - nc = 4; - gemm_small(m0, m, n0, n, mc, nc); - break; - case 0x23: - mc = 2; - nc = 3; - gemm_small(m0, m, n0, n, mc, nc); - break; - case 0x22: - mc = 2; - nc = 2; - gemm_small(m0, m, n0, n, mc, nc); - break; - case 0x21: - mc = 2; - nc = 1; - gemm_small(m0, m, n0, n, mc, nc); - break; - case 0x14: - mc = 1; - nc = 4; - gemm_small(m0, m, n0, n, mc, nc); - break; - case 0x13: - mc = 1; - nc = 3; - gemm_small(m0, m, n0, n, mc, nc); - break; - case 0x12: - mc = 1; - nc = 2; - gemm_small(m0, m, n0, n, mc, nc); - break; - case 0x11: - mc = 1; - nc = 1; - gemm_small(m0, m, n0, n, mc, nc); - break; - default: - return; - } + mc = (m_rem >= 4) ? 4 : m_rem; + nc = (n_rem >= 4) ? 4 : n_rem; + if (mc == 0 || nc == 0) + return; + gemm_small(m0, m, n0, n, mc, nc); } - mp = m0 + (m - m0) / mc * mc; - np = n0 + (n - n0) / nc * nc; + int64_t mp = m0 + ((m - m0) / mc) * mc; + int64_t np = n0 + ((n - n0) / nc) * nc; mnpack(mp, m, n0, np); mnpack(m0, m, np, n); - } + } void gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n, int RM, int RN) { int64_t ytiles = (m - m0) / RM; @@ -3206,22 +2457,22 @@ class tinyBLAS_PPC { * matrix elements. */ if (RM == 1) { - TA* a = const_cast(A+(ii)*lda+l); - packTranspose(B+(jj*ldb)+l, ldb, RN, 4, (TA*)vec_B); + float* a = const_cast(A+(ii)*lda+l); + packTranspose(B+(jj*ldb)+l, ldb, RN, 4, (float*)vec_B); vec_A[0] = (vec_t)vec_xl(0,a); - vec_A[1] = (vec_t)vec_splats(*((TA*)&vec_A+1)); - vec_A[2] = (vec_t)vec_splats(*((TA*)&vec_A+2)); - vec_A[3] = (vec_t)vec_splats(*((TA*)&vec_A+3)); + vec_A[1] = (vec_t)vec_splats(*((float*)&vec_A+1)); + vec_A[2] = (vec_t)vec_splats(*((float*)&vec_A+2)); + vec_A[3] = (vec_t)vec_splats(*((float*)&vec_A+3)); } else if (RN == 1) { - packTranspose(A+(ii*lda)+l, lda, RM, 4, (TA*)vec_A); - TB* b = const_cast(B+(jj)*ldb+l); + packTranspose(A+(ii*lda)+l, lda, RM, 4, (float*)vec_A); + float* b = const_cast(B+(jj)*ldb+l); vec_B[0] = (vec_t)vec_xl(0,b); - vec_B[1] = (vec_t)vec_splats(*((TB*)&vec_B+1)); - vec_B[2] = (vec_t)vec_splats(*((TB*)&vec_B+2)); - vec_B[3] = (vec_t)vec_splats(*((TB*)&vec_B+3)); + vec_B[1] = (vec_t)vec_splats(*((float*)&vec_B+1)); + vec_B[2] = (vec_t)vec_splats(*((float*)&vec_B+2)); + vec_B[3] = (vec_t)vec_splats(*((float*)&vec_B+3)); } else { - packTranspose(A+(ii*lda)+l, lda, RM, 4, (TA*)vec_A); - packTranspose(B+(jj*ldb)+l, ldb, RN, 4, (TA*)vec_B); + packTranspose(A+(ii*lda)+l, lda, RM, 4, (float*)vec_A); + packTranspose(B+(jj*ldb)+l, ldb, RN, 4, (float*)vec_B); } __builtin_mma_xvf32gerpp(&acc_0, vec_A[0], vec_B[0]); __builtin_mma_xvf32gerpp(&acc_0, vec_A[1], vec_B[1]); @@ -3231,7 +2482,7 @@ class tinyBLAS_PPC { __builtin_mma_disassemble_acc(vec_C, &acc_0); for (int I = 0; I < RM; I++) { for (int J = 0; J < RN; J++) { - *((TC*)(C+ii+((jj+J)*ldc)+I)) = *((TC*)&vec_C[I]+J); + *((float*)(C+ii+((jj+J)*ldc)+I)) = *((float*)&vec_C[I]+J); } } } @@ -3263,11 +2514,9 @@ class tinyBLAS_PPC { } } - const TA *const A; - const TB *const B; - TC *C; - TA *At; - TB *Bt; + const float *const A; + const float *const B; + float *C; const int64_t k; const int64_t lda; const int64_t ldb; @@ -3366,7 +2615,7 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64 #elif defined(__MMA__) if (k % 8) return false; - tinyBLAS_PPC tb{ + tinyBLAS_PPC tb{ k, (const float *)A, lda, (const float *)B, ldb, (float *)C, ldc, @@ -3493,7 +2742,7 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64 return false; if (m < 8 && m != 4) return false; - tinyBLAS_Q0_PPC tb{ + tinyBLAS_Q0_PPC tb{ k, (const block_q8_0 *)A, lda, (const block_q8_0 *)B, ldb, (float *)C, ldc, @@ -3530,7 +2779,7 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64 return false; if (m < 8 && m != 4) return false; - tinyBLAS_Q0_PPC tb{ + tinyBLAS_Q0_PPC tb{ k, (const block_q4_0 *)A, lda, (const block_q8_0 *)B, ldb, (float *)C, ldc, diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/ggml-cuda.cu b/ml/backend/ggml/ggml/src/ggml-cuda/ggml-cuda.cu index d2f1ed879..1610e5d4b 100644 --- a/ml/backend/ggml/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ml/backend/ggml/ggml/src/ggml-cuda/ggml-cuda.cu @@ -43,6 +43,7 @@ #include "ggml-cuda/upscale.cuh" #include "ggml-cuda/wkv.cuh" #include "ggml-cuda/gla.cuh" +#include "ggml-cuda/set-rows.cuh" #include "ggml.h" #include @@ -2233,6 +2234,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg case GGML_OP_GET_ROWS_BACK: ggml_cuda_op_get_rows_back(ctx, dst); break; + case GGML_OP_SET_ROWS: + ggml_cuda_op_set_rows(ctx, dst); + break; case GGML_OP_DUP: ggml_cuda_dup(ctx, dst); break; @@ -2302,6 +2306,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg case GGML_UNARY_OP_EXP: ggml_cuda_op_exp(ctx, dst); break; + case GGML_UNARY_OP_ELU: + ggml_cuda_op_elu(ctx, dst); + break; default: return false; } @@ -3122,6 +3129,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g case GGML_UNARY_OP_GELU_QUICK: case GGML_UNARY_OP_TANH: case GGML_UNARY_OP_EXP: + case GGML_UNARY_OP_ELU: return ggml_is_contiguous(op->src[0]); default: return false; @@ -3226,6 +3234,13 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g { return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32 && op->ne[2] == 1 && op->ne[3] == 1; } break; + case GGML_OP_SET_ROWS: + { +#pragma message("TODO: implement Q4_0, Q4_1, Q5_0, Q5_1, Q8_0, IQ4_NL support (https://github.com/ggml-org/llama.cpp/pull/14661)") + return (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_BF16) && + op->src[0]->type == GGML_TYPE_F32 && + op->src[1]->type == GGML_TYPE_I64; + } break; case GGML_OP_CPY: { ggml_type src0_type = op->src[0]->type; diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/set-rows.cu b/ml/backend/ggml/ggml/src/ggml-cuda/set-rows.cu new file mode 100644 index 000000000..58cee9244 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-cuda/set-rows.cu @@ -0,0 +1,151 @@ +#include "set-rows.cuh" + +typedef void (*set_rows_kernel_t)(const char * src, char * dst); + +template +__device__ void set_rows_1(const src_t * src_f, dst_t * dst_f) { + GGML_UNUSED(src_f); + GGML_UNUSED(dst_f); +} + +template<> +__device__ __forceinline__ void set_rows_1(const float * src_f, half * dst_h) { + *dst_h = __float2half(*src_f); +} + +template<> +__device__ __forceinline__ void set_rows_1(const float * src_f, nv_bfloat16 * dst_b) { + *dst_b = *src_f; +} + +template<> +__device__ __forceinline__ void set_rows_1(const float * src_f, float * dst_f) { + *dst_f = *src_f; +} + +template +static __global__ void k_set_rows( + const src_t * __restrict__ src0, const int64_t * __restrict__ src1, dst_t * __restrict__ dst, + const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03, + const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13, + const int64_t s01, const int64_t s02, const int64_t s03, + const int64_t s10, const int64_t s11, const int64_t s12, + const int64_t s1, const int64_t s2, const int64_t s3) { + + const int64_t i = int64_t(blockDim.x) * blockIdx.x + threadIdx.x; + const int64_t ne_total = ne00 * ne01 * ne02 * ne03; + + if (i >= ne_total) { + return; + } + + const int64_t i03 = i / (ne00 * ne01 * ne02); + const int64_t i02 = (i - i03 * ne00 * ne01 * ne02) / (ne00 * ne01); + const int64_t i01 = (i - i03 * ne00 * ne01 * ne02 - i02 * ne00 * ne01) / ne00; + const int64_t i00 = i - i03 * ne00 * ne01 * ne02 - i02 * ne00 * ne01 - i01 * ne00; + + const int64_t i12 = i03 % ne12; + const int64_t i11 = i02 % ne11; + const int64_t i10 = i01; + + const int64_t dst_row = *(src1 + i10*s10 + i11*s11 + i12*s12); + + const src_t * src0_row = src0 + i01*s01 + i02*s02 + i03*s03; + dst_t * dst_row_ptr = dst + dst_row*s1 + i02*s2 + i03*s3; + + const src_t* src_elem = src0_row + i00; + dst_t* dst_elem = dst_row_ptr + i00; + set_rows_1(src_elem, dst_elem); + + GGML_UNUSED(ne10); + GGML_UNUSED(ne13); +} + +template +static void set_rows_cuda( + const src_t * src0_d, const int64_t * src1_d, dst_t * dst_d, + const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03, + const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13, + const size_t nb01, const size_t nb02, const size_t nb03, + const size_t nb10, const size_t nb11, const size_t nb12, + const size_t nb1, const size_t nb2, const size_t nb3, + cudaStream_t stream) { + + const int64_t ne_total = ne00 * ne01 * ne02 * ne03; + const int num_blocks = (ne_total + CUDA_SET_ROWS_BLOCK_SIZE - 1) / CUDA_SET_ROWS_BLOCK_SIZE; + const dim3 block_size(CUDA_SET_ROWS_BLOCK_SIZE); + const dim3 grid_size(num_blocks); + + + const int64_t s01 = nb01/sizeof(src_t); + const int64_t s02 = nb02/sizeof(src_t); + const int64_t s03 = nb03/sizeof(src_t); + const int64_t s10 = nb10/sizeof(int64_t); + const int64_t s11 = nb11/sizeof(int64_t); + const int64_t s12 = nb12/sizeof(int64_t); + const int64_t s1 = nb1/sizeof(dst_t); + const int64_t s2 = nb2/sizeof(dst_t); + const int64_t s3 = nb3/sizeof(dst_t); + + if (ne_total > 0) { + k_set_rows<<>>( + src0_d, src1_d, dst_d, + ne00, ne01, ne02, ne03, + ne10, ne11, ne12, ne13, + s01, s02, s03, + s10, s11, s12, + s1, s2, s3); + } +} + + +void ggml_cuda_op_set_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(src1->type == GGML_TYPE_I64); + + GGML_TENSOR_BINARY_OP_LOCALS + + const float * src0_d = (const float *)src0->data; + const int64_t * src1_d = (const int64_t *)src1->data; + + cudaStream_t stream = ctx.stream(); + + + + if (dst->type == GGML_TYPE_F32) { + set_rows_cuda( + src0_d, src1_d, (float*)dst->data, + ne00, ne01, ne02, ne03, + ne10, ne11, ne12, ne13, + nb01, nb02, nb03, + nb10, nb11, nb12, + nb1, nb2, nb3, + stream + ); + } else if (dst->type == GGML_TYPE_F16) { + set_rows_cuda( + src0_d, src1_d, (half*)dst->data, + ne00, ne01, ne02, ne03, + ne10, ne11, ne12, ne13, + nb01, nb02, nb03, + nb10, nb11, nb12, + nb1, nb2, nb3, + stream + ); + } else if (dst->type == GGML_TYPE_BF16) { + set_rows_cuda( + src0_d, src1_d, (nv_bfloat16*)dst->data, + ne00, ne01, ne02, ne03, + ne10, ne11, ne12, ne13, + nb01, nb02, nb03, + nb10, nb11, nb12, + nb1, nb2, nb3, + stream + ); + } else { + GGML_ABORT("unsupported type"); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/set-rows.cuh b/ml/backend/ggml/ggml/src/ggml-cuda/set-rows.cuh new file mode 100644 index 000000000..c140c0873 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-cuda/set-rows.cuh @@ -0,0 +1,7 @@ +#pragma once + +#include "common.cuh" + +#define CUDA_SET_ROWS_BLOCK_SIZE 256 + +void ggml_cuda_op_set_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/ssm-conv.cu b/ml/backend/ggml/ggml/src/ggml-cuda/ssm-conv.cu index f63757196..419797336 100644 --- a/ml/backend/ggml/ggml/src/ggml-cuda/ssm-conv.cu +++ b/ml/backend/ggml/ggml/src/ggml-cuda/ssm-conv.cu @@ -107,8 +107,11 @@ static void ssm_conv_f32_cuda(const float * src0, const float * src1, const int if (nc == 4) { ssm_conv_f32<<>>(src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1, dst, dst_nb0, dst_nb1, dst_nb2, n_t); + } else if (nc == 3) { + ssm_conv_f32<<>>(src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1, + dst, dst_nb0, dst_nb1, dst_nb2, n_t); } else { - GGML_ABORT("Only support kernel size = 4 now."); + GGML_ABORT("Only support kernel size = 3 or size = 4 right now."); } } else { if (nc == 4) { @@ -116,8 +119,13 @@ static void ssm_conv_f32_cuda(const float * src0, const float * src1, const int dim3 blocks(n_s, (nr + threads - 1) / threads, (n_t + split_n_t - 1) / split_n_t); ssm_conv_long_token_f32<<>>( src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1, dst, dst_nb0, dst_nb1, dst_nb2, n_t); + } else if (nc == 3) { + const int64_t split_n_t = 32; + dim3 blocks(n_s, (nr + threads - 1) / threads, (n_t + split_n_t - 1) / split_n_t); + ssm_conv_long_token_f32<<>>( + src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1, dst, dst_nb0, dst_nb1, dst_nb2, n_t); } else { - GGML_ABORT("Only support kernel size = 4 right now."); + GGML_ABORT("Only support kernel size = 3 or size = 4 right now."); } } } diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/unary.cu b/ml/backend/ggml/ggml/src/ggml-cuda/unary.cu index f9c7b83c4..91c830c4d 100644 --- a/ml/backend/ggml/ggml/src/ggml-cuda/unary.cu +++ b/ml/backend/ggml/ggml/src/ggml-cuda/unary.cu @@ -83,6 +83,10 @@ static __device__ __forceinline__ float op_log(float x) { return logf(x); } +static __device__ __forceinline__ float op_elu(float x) { + return (x > 0.f) ? x : expm1f(x); +} + template static __global__ void unary_op_kernel(const T * x, T * dst, const int k) { const int i = blockDim.x*blockIdx.x + threadIdx.x; @@ -196,6 +200,9 @@ void ggml_cuda_op_log(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { ggml_cuda_op_unary(ctx, dst); } +void ggml_cuda_op_elu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + ggml_cuda_op_unary(ctx, dst); +} /* gated ops */ template diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/unary.cuh b/ml/backend/ggml/ggml/src/ggml-cuda/unary.cuh index 289d690e5..cb14d16f8 100644 --- a/ml/backend/ggml/ggml/src/ggml-cuda/unary.cuh +++ b/ml/backend/ggml/ggml/src/ggml-cuda/unary.cuh @@ -59,6 +59,8 @@ void ggml_cuda_op_cos(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_log(ggml_backend_cuda_context & ctx, ggml_tensor * dst); +void ggml_cuda_op_elu(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + void ggml_cuda_op_reglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_geglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/vendors/hip.h b/ml/backend/ggml/ggml/src/ggml-cuda/vendors/hip.h index 1a28831b7..184d445f5 100644 --- a/ml/backend/ggml/ggml/src/ggml-cuda/vendors/hip.h +++ b/ml/backend/ggml/ggml/src/ggml-cuda/vendors/hip.h @@ -10,9 +10,6 @@ #include "rocblas/rocblas.h" #endif // __HIP_PLATFORM_AMD__ -#define CUBLAS_COMPUTE_16F HIPBLAS_R_16F -#define CUBLAS_COMPUTE_32F HIPBLAS_R_32F -#define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_R_32F #define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT #define CUBLAS_GEMM_DEFAULT_TENSOR_OP HIPBLAS_GEMM_DEFAULT #define CUBLAS_OP_N HIPBLAS_OP_N @@ -30,7 +27,6 @@ #define CU_CHECK(fn) {hipError_t err = fn; if(err != hipSuccess) { GGML_ABORT("HipVMM Failure: %s\n", hipGetErrorString(err)); }} #define __shfl_sync(mask, var, laneMask, width) __shfl(var, laneMask, width) #define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width) -#define cublasComputeType_t hipblasDatatype_t //deprecated, new hipblasComputeType_t not in 5.6 #define cublasCreate hipblasCreate #define cublasDestroy hipblasDestroy #define cublasGemmEx hipblasGemmEx @@ -42,7 +38,6 @@ #define cublasSgemm hipblasSgemm #define cublasStatus_t hipblasStatus_t #define cublasOperation_t hipblasOperation_t -#define cudaDataType_t hipblasDatatype_t //deprecated, new hipblasDatatype not in 5.6 #define cudaDeviceCanAccessPeer hipDeviceCanAccessPeer #define cudaDeviceDisablePeerAccess hipDeviceDisablePeerAccess #define cudaDeviceEnablePeerAccess hipDeviceEnablePeerAccess @@ -144,6 +139,20 @@ #define CUBLAS_STATUS_INTERNAL_ERROR HIPBLAS_STATUS_INTERNAL_ERROR #define CUBLAS_STATUS_NOT_SUPPORTED HIPBLAS_STATUS_NOT_SUPPORTED +#if defined(__HIP_PLATFORM_AMD__) && HIP_VERSION >= 70000000 +#define CUBLAS_COMPUTE_16F HIPBLAS_COMPUTE_16F +#define CUBLAS_COMPUTE_32F HIPBLAS_COMPUTE_32F +#define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_COMPUTE_32F_FAST_16F +#define cublasComputeType_t hipblasComputeType_t +#define cudaDataType_t hipDataType +#else +#define CUBLAS_COMPUTE_16F HIPBLAS_R_16F +#define CUBLAS_COMPUTE_32F HIPBLAS_R_32F +#define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_R_32F +#define cublasComputeType_t hipblasDatatype_t +#define cudaDataType_t hipblasDatatype_t +#endif + #define __CUDA_ARCH__ 1300 #if defined(__gfx803__) || defined(__gfx900__) || defined(__gfx906__) diff --git a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-embed.metal b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-embed.metal index 403fdee94..f6e048215 100644 --- a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-embed.metal +++ b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-embed.metal @@ -3723,6 +3723,51 @@ kernel void kernel_neg( dst[tpig] = -src0[tpig]; } +kernel void kernel_abs( + device const float * src0, + device float * dst, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = fabs(src0[tpig]); +} + +kernel void kernel_sgn( + device const float * src0, + device float * dst, + uint tpig[[thread_position_in_grid]]) { + device const float & x = src0[tpig]; + dst[tpig] = (x > 0.0f) ? 1.0f : ((x < 0.0f) ? -1.0f : 0.0f); +} + +kernel void kernel_step( + device const float * src0, + device float * dst, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = src0[tpig] > 0.0f ? 1.0f : 0.0f; +} + +kernel void kernel_hardswish( + device const float * src0, + device float * dst, + uint tpig[[thread_position_in_grid]]) { + device const float & x = src0[tpig]; + dst[tpig] = x * fmin(1.0f, fmax(0.0f, (x + 3.0f) / 6.0f)); +} + +kernel void kernel_hardsigmoid( + device const float * src0, + device float * dst, + uint tpig[[thread_position_in_grid]]) { + device const float & x = src0[tpig]; + dst[tpig] = fmin(1.0f, fmax(0.0f, (x + 3.0f) / 6.0f)); +} + +kernel void kernel_exp( + device const float * src0, + device float * dst, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = exp(src0[tpig]); +} + kernel void kernel_reglu( device const char * src0, device const char * src1, diff --git a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.m b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.m index 22414ac38..731c25eba 100644 --- a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.m +++ b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.m @@ -173,6 +173,12 @@ enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_SILU, GGML_METAL_KERNEL_TYPE_SILU_4, GGML_METAL_KERNEL_TYPE_ELU, + GGML_METAL_KERNEL_TYPE_ABS, + GGML_METAL_KERNEL_TYPE_SGN, + GGML_METAL_KERNEL_TYPE_STEP, + GGML_METAL_KERNEL_TYPE_HARDSWISH, + GGML_METAL_KERNEL_TYPE_HARDSIGMOID, + GGML_METAL_KERNEL_TYPE_EXP, GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16, GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4, GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32, @@ -1155,6 +1161,12 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU, silu, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU_4, silu_4, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ELU, elu, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ABS, abs, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SGN, sgn, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_STEP, step, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_HARDSWISH, hardswish, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_HARDSIGMOID, hardsigmoid, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_EXP, exp, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16, soft_max_f16, has_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4, soft_max_f16_4, has_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32, soft_max_f32, has_simdgroup_reduction); @@ -1688,6 +1700,12 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex case GGML_UNARY_OP_SILU: case GGML_UNARY_OP_ELU: case GGML_UNARY_OP_NEG: + case GGML_UNARY_OP_ABS: + case GGML_UNARY_OP_SGN: + case GGML_UNARY_OP_STEP: + case GGML_UNARY_OP_HARDSWISH: + case GGML_UNARY_OP_HARDSIGMOID: + case GGML_UNARY_OP_EXP: return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32; default: return false; @@ -2439,6 +2457,78 @@ static bool ggml_metal_encode_node( [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; } break; + case GGML_UNARY_OP_ABS: + { + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ABS].pipeline; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + + const int64_t n = ggml_nelements(dst); + + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case GGML_UNARY_OP_SGN: + { + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SGN].pipeline; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + + const int64_t n = ggml_nelements(dst); + + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case GGML_UNARY_OP_STEP: + { + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_STEP].pipeline; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + + const int64_t n = ggml_nelements(dst); + + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case GGML_UNARY_OP_HARDSWISH: + { + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_HARDSWISH].pipeline; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + + const int64_t n = ggml_nelements(dst); + + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case GGML_UNARY_OP_HARDSIGMOID: + { + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_HARDSIGMOID].pipeline; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + + const int64_t n = ggml_nelements(dst); + + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; + case GGML_UNARY_OP_EXP: + { + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_EXP].pipeline; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + + const int64_t n = ggml_nelements(dst); + + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; default: { GGML_LOG_WARN("%s: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(dst->op)); diff --git a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.metal b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.metal index 239ec31fb..13235e288 100644 --- a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.metal +++ b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.metal @@ -1199,6 +1199,51 @@ kernel void kernel_neg( dst[tpig] = -src0[tpig]; } +kernel void kernel_abs( + device const float * src0, + device float * dst, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = fabs(src0[tpig]); +} + +kernel void kernel_sgn( + device const float * src0, + device float * dst, + uint tpig[[thread_position_in_grid]]) { + device const float & x = src0[tpig]; + dst[tpig] = (x > 0.0f) ? 1.0f : ((x < 0.0f) ? -1.0f : 0.0f); +} + +kernel void kernel_step( + device const float * src0, + device float * dst, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = src0[tpig] > 0.0f ? 1.0f : 0.0f; +} + +kernel void kernel_hardswish( + device const float * src0, + device float * dst, + uint tpig[[thread_position_in_grid]]) { + device const float & x = src0[tpig]; + dst[tpig] = x * fmin(1.0f, fmax(0.0f, (x + 3.0f) / 6.0f)); +} + +kernel void kernel_hardsigmoid( + device const float * src0, + device float * dst, + uint tpig[[thread_position_in_grid]]) { + device const float & x = src0[tpig]; + dst[tpig] = fmin(1.0f, fmax(0.0f, (x + 3.0f) / 6.0f)); +} + +kernel void kernel_exp( + device const float * src0, + device float * dst, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = exp(src0[tpig]); +} + kernel void kernel_reglu( device const char * src0, device const char * src1,