iter quant
This commit is contained in:
@@ -10,6 +10,8 @@ package ggml
|
||||
import "C"
|
||||
|
||||
import (
|
||||
"iter"
|
||||
"slices"
|
||||
"unsafe"
|
||||
|
||||
fsggml "github.com/ollama/ollama/fs/ggml"
|
||||
@@ -50,32 +52,32 @@ func ConvertToF32(data []byte, dtype uint32, nelements uint64) []float32 {
|
||||
return f32s
|
||||
}
|
||||
|
||||
func Quantize(newType fsggml.TensorType, f32s []float32, shape []uint64) []byte {
|
||||
buf := make([]byte, len(f32s)*4) // upper bound on size
|
||||
nPerRow := C.int64_t(shape[0])
|
||||
nrows := C.int64_t(1)
|
||||
if len(shape) > 1 {
|
||||
nrows = C.int64_t(shape[1])
|
||||
func Quantize(newType fsggml.TensorType, f32s []float32, shape []uint64) iter.Seq[[]byte] {
|
||||
return func(yield func([]byte) bool) {
|
||||
C.ggml_quantize_init(uint32(newType))
|
||||
defer C.ggml_quantize_free()
|
||||
|
||||
dims := slices.Repeat([]C.int64_t{1}, 4)
|
||||
for i, s := range shape {
|
||||
dims[i] = C.int64_t(s)
|
||||
}
|
||||
|
||||
bts := make([]byte, C.ggml_row_size(uint32(newType), dims[0])*C.size_t(dims[1]))
|
||||
for chunk := range dims[2] {
|
||||
offset := chunk * dims[0] * dims[1]
|
||||
|
||||
n := C.ggml_quantize_chunk(
|
||||
uint32(newType),
|
||||
(*C.float)(&f32s[0]),
|
||||
unsafe.Pointer(&bts[0]),
|
||||
offset, dims[1], dims[0], nil,
|
||||
)
|
||||
|
||||
if !yield(bts[:n]) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
shape2 := C.int64_t(1)
|
||||
if len(shape) > 2 {
|
||||
shape2 = C.int64_t(shape[2])
|
||||
}
|
||||
nelements_matrix := nPerRow * nrows
|
||||
newSize := C.size_t(0)
|
||||
for i03 := C.int64_t(0); i03 < shape2; i03++ {
|
||||
f32s_03 := i03 * nelements_matrix
|
||||
buf_03 := C.int64_t(C.ggml_row_size(uint32(newType), nPerRow)) * i03 * nrows
|
||||
newSize += C.ggml_quantize_chunk(
|
||||
uint32(newType),
|
||||
(*C.float)(&f32s[f32s_03]),
|
||||
unsafe.Pointer((uintptr)(unsafe.Pointer(&buf[0]))+uintptr(buf_03)),
|
||||
0,
|
||||
nrows,
|
||||
nPerRow,
|
||||
nil)
|
||||
}
|
||||
return buf[:newSize]
|
||||
}
|
||||
|
||||
func QuantizationVersion() uint32 {
|
||||
|
||||
Reference in New Issue
Block a user