ollama/llama/patches/0018-add-argsort-for-int32_...

71 lines
2.3 KiB
Diff

From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: Michael Yang <git@mxy.ng>
Date: Thu, 1 May 2025 13:45:12 -0700
Subject: [PATCH] add argsort for int32_t
---
ggml/src/ggml-cpu/ops.cpp | 43 +++++++++++++++++++++++++++++++++++++++
1 file changed, 43 insertions(+)
diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp
index 66b8da68..1ad571d3 100644
--- a/ggml/src/ggml-cpu/ops.cpp
+++ b/ggml/src/ggml-cpu/ops.cpp
@@ -6718,6 +6718,45 @@ static void ggml_compute_forward_argsort_f32(
}
}
+static void ggml_compute_forward_argsort_i32(
+ const ggml_compute_params * params,
+ ggml_tensor * dst) {
+ // int32 argsort: each row of dst receives the indices 0..ne0-1 ordered by the values in the same row of src0
+ const ggml_tensor * src0 = dst->src[0];
+ // NOTE(review): ne0, nb0, nb1 and nb01 are not declared in this function; presumably this macro defines them - confirm
+ GGML_TENSOR_UNARY_OP_LOCALS
+ // dst_data[j] below indexes dst as packed int32; nb0 is presumably dst's dim-0 stride in bytes - confirm
+ GGML_ASSERT(nb0 == sizeof(int32_t));
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ const int64_t nr = ggml_nrows(src0);
+ // sort direction is read from op param 0 of dst
+ ggml_sort_order order = (ggml_sort_order) ggml_get_op_params_i32(dst, 0);
+ // rows are interleaved across calls: this one handles rows ith, ith + nth, ith + 2*nth, ...
+ for (int64_t i = ith; i < nr; i += nth) {
+ int32_t * dst_data = (int32_t *)((char *) dst->data + i*nb1);
+ const int32_t * src_data = (int32_t *)((char *) src0->data + i*nb01);
+ // start from the identity permutation
+ for (int64_t j = 0; j < ne0; j++) {
+ dst_data[j] = j;
+ }
+ // O(ne0^2) exchange sort of the indices, keyed by src_data (src_data is only read); not stable for equal values
+ // if order is neither ASC nor DESC no swap happens and the row stays 0..ne0-1
+ for (int64_t j = 0; j < ne0; j++) {
+ for (int64_t k = j + 1; k < ne0; k++) {
+ if ((order == GGML_SORT_ORDER_ASC && src_data[dst_data[j]] > src_data[dst_data[k]]) ||
+ (order == GGML_SORT_ORDER_DESC && src_data[dst_data[j]] < src_data[dst_data[k]])) {
+ int32_t tmp = dst_data[j];
+ dst_data[j] = dst_data[k];
+ dst_data[k] = tmp;
+ }
+ }
+ }
+ }
+}
+
void ggml_compute_forward_argsort(
const ggml_compute_params * params,
ggml_tensor * dst) {
@@ -6729,6 +6768,10 @@ void ggml_compute_forward_argsort(
{
ggml_compute_forward_argsort_f32(params, dst);
} break;
+ case GGML_TYPE_I32:
+ {
+ ggml_compute_forward_argsort_i32(params, dst);
+ } break;
default:
{
GGML_ABORT("fatal error");