71 lines
2.3 KiB
Diff
71 lines
2.3 KiB
Diff
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
|
|
From: Michael Yang <git@mxy.ng>
|
|
Date: Thu, 1 May 2025 13:45:12 -0700
|
|
Subject: [PATCH] add argsort for int32_t
|
|
|
|
---
|
|
ggml/src/ggml-cpu/ops.cpp | 43 +++++++++++++++++++++++++++++++++++++++
|
|
1 file changed, 43 insertions(+)
|
|
|
|
diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp
|
|
index 66b8da68..1ad571d3 100644
|
|
--- a/ggml/src/ggml-cpu/ops.cpp
|
|
+++ b/ggml/src/ggml-cpu/ops.cpp
|
|
@@ -6718,6 +6718,45 @@ static void ggml_compute_forward_argsort_f32(
|
|
}
|
|
}
|
|
|
|
+static void ggml_compute_forward_argsort_i32(
|

+        const ggml_compute_params * params,
|

+        ggml_tensor * dst) {
|

+    // Row-wise argsort of an I32 tensor: row i of dst receives the indices that order row i of src0 (direction from op param 0).
|

+    const ggml_tensor * src0 = dst->src[0];
|

+
|

+    GGML_TENSOR_UNARY_OP_LOCALS
|

+
|

+    GGML_ASSERT(nb0  == sizeof(int32_t)); // dst rows are written as dense int32 index arrays
|

+    GGML_ASSERT(nb00 == sizeof(int32_t)); // src_data[j] below assumes dense int32 elements within each src0 row
|

+    const int ith = params->ith;
|

+    const int nth = params->nth;
|

+
|

+    const int64_t nr = ggml_nrows(src0);
|

+
|

+    ggml_sort_order order = (ggml_sort_order) ggml_get_op_params_i32(dst, 0);
|

+
|

+    for (int64_t i = ith; i < nr; i += nth) { // rows are distributed round-robin over the nth threads
|

+        int32_t * dst_data = (int32_t *)((char *) dst->data + i*nb1);
|

+        const int32_t * src_data = (int32_t *)((char *) src0->data + i*nb01);
|

+
|

+        for (int64_t j = 0; j < ne0; j++) {
|

+            dst_data[j] = j; // start from the identity permutation
|

+        }
|

+
|

+        // O(ne0^2) in-place exchange sort of the index row (not a bubble sort, and not stable for equal values).
|

+        for (int64_t j = 0; j < ne0; j++) {
|

+            for (int64_t k = j + 1; k < ne0; k++) {
|

+                if ((order == GGML_SORT_ORDER_ASC && src_data[dst_data[j]] > src_data[dst_data[k]]) ||
|

+                    (order == GGML_SORT_ORDER_DESC && src_data[dst_data[j]] < src_data[dst_data[k]])) {
|

+                    int32_t tmp = dst_data[j];
|

+                    dst_data[j] = dst_data[k];
|

+                    dst_data[k] = tmp;
|

+                }
|

+            }
|

+        }
|

+    }
|

+}
|
|
+
|
|
void ggml_compute_forward_argsort(
|
|
const ggml_compute_params * params,
|
|
ggml_tensor * dst) {
|
|
@@ -6729,6 +6768,10 @@ void ggml_compute_forward_argsort(
|
|
{
|
|
ggml_compute_forward_argsort_f32(params, dst);
|
|
} break;
|
|
+ case GGML_TYPE_I32:
|
|
+ {
|
|
+ ggml_compute_forward_argsort_i32(params, dst);
|
|
+ } break;
|
|
default:
|
|
{
|
|
GGML_ABORT("fatal error");
|