fix: Fix Solar and argsort/copy patches after bump
Branch: GraniteFour Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
This commit is contained in:
parent
8fbeb68858
commit
94912ec7dd
|
|
@ -15,10 +15,10 @@ adds support for the Solar Pro architecture
|
|||
7 files changed, 248 insertions(+)
|
||||
|
||||
diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp
|
||||
index 1105139f..d9d5ec65 100644
|
||||
index dbf9774..eb6be95 100644
|
||||
--- a/src/llama-arch.cpp
|
||||
+++ b/src/llama-arch.cpp
|
||||
@@ -75,6 +75,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
|
||||
@@ -77,6 +77,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
|
||||
{ LLM_ARCH_GRANITE_MOE, "granitemoe" },
|
||||
{ LLM_ARCH_GRANITE_HYBRID, "granitehybrid" },
|
||||
{ LLM_ARCH_CHAMELEON, "chameleon" },
|
||||
|
|
@ -26,7 +26,7 @@ index 1105139f..d9d5ec65 100644
|
|||
{ LLM_ARCH_WAVTOKENIZER_DEC, "wavtokenizer-dec" },
|
||||
{ LLM_ARCH_PLM, "plm" },
|
||||
{ LLM_ARCH_BAILINGMOE, "bailingmoe" },
|
||||
@@ -153,6 +154,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
|
||||
@@ -159,6 +160,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
|
||||
{ LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, "%s.attention.relative_buckets_count" },
|
||||
{ LLM_KV_ATTENTION_SLIDING_WINDOW, "%s.attention.sliding_window" },
|
||||
{ LLM_KV_ATTENTION_SCALE, "%s.attention.scale" },
|
||||
|
|
@ -34,7 +34,7 @@ index 1105139f..d9d5ec65 100644
|
|||
{ LLM_KV_ATTENTION_KEY_LENGTH_MLA, "%s.attention.key_length_mla" },
|
||||
{ LLM_KV_ATTENTION_VALUE_LENGTH_MLA, "%s.attention.value_length_mla" },
|
||||
|
||||
@@ -1697,6 +1699,24 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
|
||||
@@ -1755,6 +1757,24 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
|
||||
{ LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
|
||||
},
|
||||
},
|
||||
|
|
@ -59,7 +59,7 @@ index 1105139f..d9d5ec65 100644
|
|||
{
|
||||
LLM_ARCH_WAVTOKENIZER_DEC,
|
||||
{
|
||||
@@ -1981,6 +2001,7 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
|
||||
@@ -2123,6 +2143,7 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
|
||||
{LLM_TENSOR_LAUREL_POST_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
|
||||
// this tensor is loaded for T5, but never used
|
||||
{LLM_TENSOR_DEC_CROSS_ATTN_REL_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_NONE}},
|
||||
|
|
@ -68,10 +68,10 @@ index 1105139f..d9d5ec65 100644
|
|||
{LLM_TENSOR_POS_NET_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
|
||||
{LLM_TENSOR_POS_NET_NORM1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
|
||||
diff --git a/src/llama-arch.h b/src/llama-arch.h
|
||||
index a9dd188a..2cb0fd95 100644
|
||||
index 8267a8d..2983556 100644
|
||||
--- a/src/llama-arch.h
|
||||
+++ b/src/llama-arch.h
|
||||
@@ -79,6 +79,7 @@ enum llm_arch {
|
||||
@@ -81,6 +81,7 @@ enum llm_arch {
|
||||
LLM_ARCH_GRANITE_MOE,
|
||||
LLM_ARCH_GRANITE_HYBRID,
|
||||
LLM_ARCH_CHAMELEON,
|
||||
|
|
@ -79,7 +79,7 @@ index a9dd188a..2cb0fd95 100644
|
|||
LLM_ARCH_WAVTOKENIZER_DEC,
|
||||
LLM_ARCH_PLM,
|
||||
LLM_ARCH_BAILINGMOE,
|
||||
@@ -157,6 +158,7 @@ enum llm_kv {
|
||||
@@ -163,6 +164,7 @@ enum llm_kv {
|
||||
LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT,
|
||||
LLM_KV_ATTENTION_SLIDING_WINDOW,
|
||||
LLM_KV_ATTENTION_SCALE,
|
||||
|
|
@ -87,7 +87,7 @@ index a9dd188a..2cb0fd95 100644
|
|||
LLM_KV_ATTENTION_KEY_LENGTH_MLA,
|
||||
LLM_KV_ATTENTION_VALUE_LENGTH_MLA,
|
||||
|
||||
@@ -380,6 +382,7 @@ enum llm_tensor {
|
||||
@@ -388,6 +390,7 @@ enum llm_tensor {
|
||||
LLM_TENSOR_ENC_OUTPUT_NORM,
|
||||
LLM_TENSOR_CLS,
|
||||
LLM_TENSOR_CLS_OUT,
|
||||
|
|
@ -96,10 +96,10 @@ index a9dd188a..2cb0fd95 100644
|
|||
LLM_TENSOR_CONVNEXT_DW,
|
||||
LLM_TENSOR_CONVNEXT_NORM,
|
||||
diff --git a/src/llama-hparams.cpp b/src/llama-hparams.cpp
|
||||
index 86c814d5..f1c965b8 100644
|
||||
index 7a06368..35fc054 100644
|
||||
--- a/src/llama-hparams.cpp
|
||||
+++ b/src/llama-hparams.cpp
|
||||
@@ -95,6 +95,14 @@ uint32_t llama_hparams::n_pos_per_embd() const {
|
||||
@@ -146,6 +146,14 @@ uint32_t llama_hparams::n_pos_per_embd() const {
|
||||
return rope_type == LLAMA_ROPE_TYPE_MROPE ? 4 : 1;
|
||||
}
|
||||
|
||||
|
|
@ -115,10 +115,10 @@ index 86c814d5..f1c965b8 100644
|
|||
if (il < n_layer) {
|
||||
return swa_layers[il];
|
||||
diff --git a/src/llama-hparams.h b/src/llama-hparams.h
|
||||
index 476d0a5e..906fa185 100644
|
||||
index 8b7e2a1..d5f673e 100644
|
||||
--- a/src/llama-hparams.h
|
||||
+++ b/src/llama-hparams.h
|
||||
@@ -59,6 +59,8 @@ struct llama_hparams {
|
||||
@@ -61,6 +61,8 @@ struct llama_hparams {
|
||||
std::array<uint32_t, LLAMA_MAX_LAYERS> n_head_kv_arr;
|
||||
std::array<uint32_t, LLAMA_MAX_LAYERS> n_ff_arr;
|
||||
|
||||
|
|
@ -127,7 +127,7 @@ index 476d0a5e..906fa185 100644
|
|||
uint32_t n_layer_dense_lead = 0;
|
||||
uint32_t n_lora_q = 0;
|
||||
uint32_t n_lora_kv = 0;
|
||||
@@ -201,6 +203,9 @@ struct llama_hparams {
|
||||
@@ -218,6 +220,9 @@ struct llama_hparams {
|
||||
|
||||
uint32_t n_pos_per_embd() const;
|
||||
|
||||
|
|
@ -138,7 +138,7 @@ index 476d0a5e..906fa185 100644
|
|||
};
|
||||
|
||||
diff --git a/src/llama-model-loader.cpp b/src/llama-model-loader.cpp
|
||||
index bd9e6da8..99ea20df 100644
|
||||
index bd9e6da..99ea20d 100644
|
||||
--- a/src/llama-model-loader.cpp
|
||||
+++ b/src/llama-model-loader.cpp
|
||||
@@ -464,6 +464,7 @@ namespace GGUFMeta {
|
||||
|
|
@ -150,10 +150,10 @@ index bd9e6da8..99ea20df 100644
|
|||
llama_model_loader::llama_model_loader(
|
||||
const std::string & fname,
|
||||
diff --git a/src/llama-model.cpp b/src/llama-model.cpp
|
||||
index 8fc025af..35d7a4df 100644
|
||||
index e3aa9e6..20a7060 100644
|
||||
--- a/src/llama-model.cpp
|
||||
+++ b/src/llama-model.cpp
|
||||
@@ -1567,6 +1567,21 @@ void llama_model::load_hparams(llama_model_loader & ml) {
|
||||
@@ -1648,6 +1648,21 @@ void llama_model::load_hparams(llama_model_loader & ml) {
|
||||
default: type = LLM_TYPE_UNKNOWN;
|
||||
}
|
||||
} break;
|
||||
|
|
@ -175,7 +175,7 @@ index 8fc025af..35d7a4df 100644
|
|||
case LLM_ARCH_WAVTOKENIZER_DEC:
|
||||
{
|
||||
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
|
||||
@@ -4325,6 +4340,34 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
|
||||
@@ -4555,6 +4570,34 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
|
||||
|
||||
layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
|
||||
|
||||
|
|
@ -210,12 +210,12 @@ index 8fc025af..35d7a4df 100644
|
|||
layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
|
||||
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
|
||||
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
|
||||
@@ -14369,6 +14412,165 @@ struct llm_build_granite_hybrid : public llm_graph_context_mamba {
|
||||
@@ -14925,6 +14968,165 @@ struct llm_build_granite_hybrid : public llm_graph_context_mamba {
|
||||
}
|
||||
};
|
||||
|
||||
+struct llm_build_solar : public llm_graph_context {
|
||||
+ llm_build_solar(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
|
||||
+ llm_build_solar(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
|
||||
+ const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||
+ GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
||||
+ GGML_ASSERT(n_embd_head == hparams.n_rot);
|
||||
|
|
@ -314,7 +314,7 @@ index 8fc025af..35d7a4df 100644
|
|||
+ cb(Kcur, "Kcur", il);
|
||||
+ cb(Vcur, "Vcur", il);
|
||||
+
|
||||
+ cur = build_attn(inp_attn, gf,
|
||||
+ cur = build_attn(inp_attn,
|
||||
+ model.layers[il].wo, model.layers[il].bo,
|
||||
+ Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il);
|
||||
+ cb(cur, "attn_out", il);
|
||||
|
|
@ -376,18 +376,18 @@ index 8fc025af..35d7a4df 100644
|
|||
// ref: https://github.com/facebookresearch/chameleon
|
||||
// based on the original build_llama() function, changes:
|
||||
// * qk-norm
|
||||
@@ -16225,6 +16427,10 @@ llm_graph_result_ptr llama_model::build_graph(
|
||||
@@ -17582,6 +17784,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
|
||||
{
|
||||
llm = std::make_unique<llm_build_chameleon>(*this, params, gf);
|
||||
llm = std::make_unique<llm_build_chameleon>(*this, params);
|
||||
} break;
|
||||
+ case LLM_ARCH_SOLAR:
|
||||
+ {
|
||||
+ llm = std::make_unique<llm_build_solar>(*this, params, gf);
|
||||
+ llm = std::make_unique<llm_build_solar>(*this, params);
|
||||
+ } break;
|
||||
case LLM_ARCH_WAVTOKENIZER_DEC:
|
||||
{
|
||||
llm = std::make_unique<llm_build_wavtokenizer_dec>(*this, params, gf);
|
||||
@@ -16412,6 +16618,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
|
||||
llm = std::make_unique<llm_build_wavtokenizer_dec>(*this, params);
|
||||
@@ -17785,6 +17991,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
|
||||
case LLM_ARCH_GRANITE_MOE:
|
||||
case LLM_ARCH_GRANITE_HYBRID:
|
||||
case LLM_ARCH_CHAMELEON:
|
||||
|
|
@ -396,10 +396,10 @@ index 8fc025af..35d7a4df 100644
|
|||
case LLM_ARCH_NEO_BERT:
|
||||
case LLM_ARCH_SMOLLM3:
|
||||
diff --git a/src/llama-model.h b/src/llama-model.h
|
||||
index 431efbd5..05a9adfa 100644
|
||||
index 094e238..2692cf8 100644
|
||||
--- a/src/llama-model.h
|
||||
+++ b/src/llama-model.h
|
||||
@@ -67,6 +67,7 @@ enum llm_type {
|
||||
@@ -70,6 +70,7 @@ enum llm_type {
|
||||
LLM_TYPE_15B,
|
||||
LLM_TYPE_16B,
|
||||
LLM_TYPE_20B,
|
||||
|
|
@ -407,7 +407,7 @@ index 431efbd5..05a9adfa 100644
|
|||
LLM_TYPE_27B,
|
||||
LLM_TYPE_30B,
|
||||
LLM_TYPE_32B,
|
||||
@@ -338,6 +339,8 @@ struct llama_layer {
|
||||
@@ -349,6 +350,8 @@ struct llama_layer {
|
||||
struct ggml_tensor * laurel_r = nullptr;
|
||||
struct ggml_tensor * laurel_post_norm = nullptr;
|
||||
|
||||
|
|
|
|||
|
|
@ -6,14 +6,14 @@ Subject: [PATCH] add argsort and cuda copy for i32
|
|||
---
|
||||
ggml/src/ggml-cpu/ops.cpp | 43 ++++++++++++++
|
||||
ggml/src/ggml-cuda/argsort.cu | 102 +++++++++++++++++++++++++++++++++-
|
||||
ggml/src/ggml-cuda/cpy.cu | 49 ++++++++++++++++
|
||||
3 files changed, 192 insertions(+), 2 deletions(-)
|
||||
ggml/src/ggml-cuda/cpy.cu | 2 +
|
||||
3 files changed, 145 insertions(+), 2 deletions(-)
|
||||
|
||||
diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp
|
||||
index aefa37e0..ed5a6c91 100644
|
||||
index 6581d27..9dedc11 100644
|
||||
--- a/ggml/src/ggml-cpu/ops.cpp
|
||||
+++ b/ggml/src/ggml-cpu/ops.cpp
|
||||
@@ -7050,6 +7050,45 @@ static void ggml_compute_forward_argsort_f32(
|
||||
@@ -7967,6 +7967,45 @@ static void ggml_compute_forward_argsort_f32(
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -59,7 +59,7 @@ index aefa37e0..ed5a6c91 100644
|
|||
void ggml_compute_forward_argsort(
|
||||
const ggml_compute_params * params,
|
||||
ggml_tensor * dst) {
|
||||
@@ -7061,6 +7100,10 @@ void ggml_compute_forward_argsort(
|
||||
@@ -7978,6 +8017,10 @@ void ggml_compute_forward_argsort(
|
||||
{
|
||||
ggml_compute_forward_argsort_f32(params, dst);
|
||||
} break;
|
||||
|
|
@ -71,7 +71,7 @@ index aefa37e0..ed5a6c91 100644
|
|||
{
|
||||
GGML_ABORT("fatal error");
|
||||
diff --git a/ggml/src/ggml-cuda/argsort.cu b/ggml/src/ggml-cuda/argsort.cu
|
||||
index 607ded85..53b02634 100644
|
||||
index 607ded8..53b0263 100644
|
||||
--- a/ggml/src/ggml-cuda/argsort.cu
|
||||
+++ b/ggml/src/ggml-cuda/argsort.cu
|
||||
@@ -85,13 +85,107 @@ static void argsort_f32_i32_cuda(const float * x, int * dst, const int ncols, co
|
||||
|
|
@ -195,83 +195,15 @@ index 607ded85..53b02634 100644
|
|||
+ }
|
||||
}
|
||||
diff --git a/ggml/src/ggml-cuda/cpy.cu b/ggml/src/ggml-cuda/cpy.cu
|
||||
index 2c55d214..90d95d32 100644
|
||||
index f9bb025..f99467e 100644
|
||||
--- a/ggml/src/ggml-cuda/cpy.cu
|
||||
+++ b/ggml/src/ggml-cuda/cpy.cu
|
||||
@@ -41,6 +41,13 @@ static __device__ void cpy_1_f16_f32(const char * cxi, char * cdsti) {
|
||||
*dsti = *xi;
|
||||
}
|
||||
|
||||
+static __device__ void cpy_1_i32_i32(const char * cxi, char * cdsti) {
|
||||
+ const int32_t * xi = (const int32_t *) cxi;
|
||||
+ int32_t * dsti = (int32_t *) cdsti;
|
||||
+
|
||||
+ *dsti = *xi;
|
||||
+}
|
||||
+
|
||||
template <cpy_kernel_t cpy_1>
|
||||
static __global__ void cpy_f32_f16(const char * cx, char * cdst_direct, const int ne,
|
||||
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
|
||||
@@ -71,6 +78,44 @@ static __global__ void cpy_f32_f16(const char * cx, char * cdst_direct, const in
|
||||
cpy_1(cx + x_offset, cdst + dst_offset);
|
||||
}
|
||||
|
||||
+// First, add this template function after the other template functions
|
||||
+template <cpy_kernel_t cpy_1>
|
||||
+static __global__ void cpy_i32_i32(const char * cx, char * cdst, const int ne,
|
||||
+ const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
|
||||
+ const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
|
||||
+ const int nb12, const int nb13) {
|
||||
+ const int64_t i = blockDim.x*blockIdx.x + threadIdx.x;
|
||||
+
|
||||
+ if (i >= ne) {
|
||||
+ return;
|
||||
+ }
|
||||
+
|
||||
+ const int64_t i03 = i/(ne00 * ne01 * ne02);
|
||||
+ const int64_t i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01);
|
||||
+ const int64_t i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00;
|
||||
+ const int64_t i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00;
|
||||
+ const int64_t x_offset = i00*nb00 + i01*nb01 + i02*nb02 + i03 * nb03;
|
||||
+
|
||||
+ const int64_t i13 = i/(ne10 * ne11 * ne12);
|
||||
+ const int64_t i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11);
|
||||
+ const int64_t i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10;
|
||||
+ const int64_t i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10;
|
||||
+ const int64_t dst_offset = i10*nb10 + i11*nb11 + i12*nb12 + i13 * nb13;
|
||||
+
|
||||
+ cpy_1(cx + x_offset, cdst + dst_offset);
|
||||
+}
|
||||
+
|
||||
+// Then modify the ggml_cpy_i32_i32_cuda function to use the new template
|
||||
+static void ggml_cpy_i32_i32_cuda(
|
||||
+ const char * cx, char * cdst, const int ne,
|
||||
+ const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
|
||||
+ const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int graph_cpynode_index) {
|
||||
+
|
||||
+ const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
|
||||
+ cpy_i32_i32<cpy_1_i32_i32><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
|
||||
+ (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
|
||||
+}
|
||||
+
|
||||
static __device__ void cpy_blck_f32_q8_0(const char * cxi, char * cdsti) {
|
||||
const float * xi = (const float *) cxi;
|
||||
block_q8_0 * dsti = (block_q8_0 *) cdsti;
|
||||
@@ -643,6 +688,8 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
|
||||
ggml_cpy_f16_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
|
||||
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
|
||||
ggml_cpy_f16_f32_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
|
||||
+ } else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_I32) {
|
||||
+ ggml_cpy_i32_i32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
|
||||
@@ -373,6 +373,8 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
|
||||
ggml_cpy_flt_cuda<nv_bfloat16, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
|
||||
} else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F16) {
|
||||
ggml_cpy_flt_cuda<nv_bfloat16, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
|
||||
+ } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F32) {
|
||||
+ ggml_cpy_flt_cuda<int, int>(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
|
||||
} else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F32) {
|
||||
ggml_cpy_flt_cuda<nv_bfloat16, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
|
||||
} else {
|
||||
GGML_ABORT("%s: unsupported type combination (%s to %s)\n", __func__,
|
||||
ggml_type_name(src0->type), ggml_type_name(src1->type));
|
||||
@@ -698,6 +745,8 @@ void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
|
||||
return (void*) cpy_f32_f16<cpy_1_f32_f16>;
|
||||
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
|
||||
return (void*) cpy_f32_f16<cpy_1_f16_f32>;
|
||||
+ } else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_I32) {
|
||||
+ return (void*) cpy_i32_i32<cpy_1_i32_i32>;
|
||||
} else {
|
||||
GGML_ABORT("%s: unsupported type combination (%s to %s)\n", __func__,
|
||||
ggml_type_name(src0->type), ggml_type_name(src1->type));
|
||||
|
|
|
|||
Loading…
Reference in New Issue