Update sync with latest llama.cpp layout, and run against b3485

This commit is contained in:
Daniel Hiltgen
2024-07-29 16:21:09 -07:00
parent 5152a430f5
commit 41bf8d9932
235 changed files with 22809 additions and 12964 deletions

View File

@@ -43,8 +43,10 @@
// [1] J. Tunney, LLaMA Now Goes Faster on CPUs, Mar. 2024. [Online].
// Available: https://justine.lol/matmul/. [Accessed: 29-Mar-2024].
#if defined(__GNUC__)
#pragma GCC diagnostic ignored "-Wpedantic"
#pragma GCC diagnostic ignored "-Wignored-attributes"
#endif
#include "sgemm.h"
#include "ggml-impl.h"
@@ -247,9 +249,8 @@ class tinyBLAS {
: A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
}
void matmul(int64_t m, int64_t n, int task) {
if (task == GGML_TASK_TYPE_COMPUTE)
mnpack(0, m, 0, n);
void matmul(int64_t m, int64_t n) {
mnpack(0, m, 0, n);
}
private:
@@ -456,9 +457,8 @@ class tinyBLAS_Q0_ARM {
: A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
}
void matmul(int64_t m, int64_t n, int task) {
if (task == GGML_TASK_TYPE_COMPUTE)
mnpack(0, m, 0, n);
void matmul(int64_t m, int64_t n) {
mnpack(0, m, 0, n);
}
private:
@@ -594,9 +594,8 @@ class tinyBLAS_Q0_AVX {
: A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
}
void matmul(int64_t m, int64_t n, int task) {
if (task == GGML_TASK_TYPE_COMPUTE)
mnpack(0, m, 0, n);
void matmul(int64_t m, int64_t n) {
mnpack(0, m, 0, n);
}
private:
@@ -827,7 +826,7 @@ class tinyBLAS_Q0_AVX {
* For example, for single-threaded single-precision GEMM you can say
*
* llamafile_sgemm(m, n, k, A, lda, B, ldb, C, ldc,
* 0, 1, GGML_TASK_TYPE_COMPUTE,
* 0, 1,
* GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32);
*
* @param m is rows in `A` and `C`
@@ -841,14 +840,13 @@ class tinyBLAS_Q0_AVX {
* @param ldc is row stride of `C`
* @param ith is thread id (must be less than `nth`)
* @param nth is number of threads (must be greater than zero)
* @param task is GGML task type
* @param Atype is GGML data type of `A`
* @param Btype is GGML data type of `B`
* @param Ctype is GGML data type of `C`
* @return true if this function was able to service the matmul request
*/
bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda, const void *B, int64_t ldb, void *C,
int64_t ldc, int ith, int nth, int task, int Atype, int Btype, int Ctype) {
int64_t ldc, int ith, int nth, int Atype, int Btype, int Ctype) {
assert(m >= 0);
assert(n >= 0);
@@ -875,7 +873,7 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
(const float *)B, ldb,
(float *)C, ldc,
ith, nth};
tb.matmul(m, n, task);
tb.matmul(m, n);
return true;
#elif defined(__AVX__) || defined(__AVX2__)
if (k % 8)
@@ -885,7 +883,7 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
(const float *)B, ldb,
(float *)C, ldc,
ith, nth};
tb.matmul(m, n, task);
tb.matmul(m, n);
return true;
#elif defined(__ARM_NEON)
if (n < 4)
@@ -897,7 +895,7 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
(const float *)B, ldb,
(float *)C, ldc,
ith, nth};
tb.matmul(m, n, task);
tb.matmul(m, n);
return true;
#else
return false;
@@ -915,7 +913,7 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
(const float *)B, ldb,
(float *)C, ldc,
ith, nth};
tb.matmul(m, n, task);
tb.matmul(m, n);
return true;
#elif (defined(__AVX__) || defined(__AVX2__)) && defined(__F16C__)
if (k % 8)
@@ -927,7 +925,7 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
(const float *)B, ldb,
(float *)C, ldc,
ith, nth};
tb.matmul(m, n, task);
tb.matmul(m, n);
return true;
#elif defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER)
if (n < 8)
@@ -941,7 +939,7 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
(const ggml_fp16_t *)B, ldb,
(float *)C, ldc,
ith, nth};
tb.matmul(m, n, task);
tb.matmul(m, n);
return true;
#elif defined(__ARM_NEON) && !defined(_MSC_VER)
if (k % 4)
@@ -953,7 +951,7 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
(const float *)B, ldb,
(float *)C, ldc,
ith, nth};
tb.matmul(m, n, task);
tb.matmul(m, n);
return true;
#else
return false;
@@ -969,7 +967,7 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
(const block_q8_0 *)B, ldb,
(float *)C, ldc,
ith, nth};
tb.matmul(m, n, task);
tb.matmul(m, n);
return true;
#elif defined(__ARM_FEATURE_DOTPROD)
tinyBLAS_Q0_ARM<block_q8_0> tb{
@@ -977,7 +975,7 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
(const block_q8_0 *)B, ldb,
(float *)C, ldc,
ith, nth};
tb.matmul(m, n, task);
tb.matmul(m, n);
return true;
#else
return false;
@@ -993,7 +991,7 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
(const block_q8_0 *)B, ldb,
(float *)C, ldc,
ith, nth};
tb.matmul(m, n, task);
tb.matmul(m, n);
return true;
#elif defined(__ARM_FEATURE_DOTPROD)
tinyBLAS_Q0_ARM<block_q4_0> tb{
@@ -1001,7 +999,7 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
(const block_q8_0 *)B, ldb,
(float *)C, ldc,
ith, nth};
tb.matmul(m, n, task);
tb.matmul(m, n);
return true;
#else
return false;
@@ -1023,7 +1021,6 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
(void)ldc;
(void)ith;
(void)nth;
(void)task;
(void)Atype;
(void)Btype;
(void)Ctype;