From 2a03498bbb1d41fa413ef404cd106032118f879e Mon Sep 17 00:00:00 2001
From: Michael Yang
Date: Tue, 15 Jul 2025 13:12:21 -0700
Subject: [PATCH] iter quant

---
 ml/backend/ggml/quantization.go | 63 ++++++++++++++++++++++----------
 server/quantization.go          | 21 ++++++++---
 server/quantization_test.go     | 61 +++++++++++++++++++++++++++++++
 3 files changed, 116 insertions(+), 29 deletions(-)

diff --git a/ml/backend/ggml/quantization.go b/ml/backend/ggml/quantization.go
index bb31e455d..6d2f5924b 100644
--- a/ml/backend/ggml/quantization.go
+++ b/ml/backend/ggml/quantization.go
@@ -10,6 +10,8 @@ package ggml
 import "C"
 
 import (
+	"iter"
+	"slices"
 	"unsafe"
 
 	fsggml "github.com/ollama/ollama/fs/ggml"
@@ -50,32 +52,43 @@ func ConvertToF32(data []byte, dtype uint32, nelements uint64) []float32 {
 	return f32s
 }
 
-func Quantize(newType fsggml.TensorType, f32s []float32, shape []uint64) []byte {
-	buf := make([]byte, len(f32s)*4) // upper bound on size
-	nPerRow := C.int64_t(shape[0])
-	nrows := C.int64_t(1)
-	if len(shape) > 1 {
-		nrows = C.int64_t(shape[1])
+// Quantize converts f32s, laid out according to shape, to newType and returns
+// an iterator that yields the quantized bytes one shape[0] x shape[1] matrix at
+// a time. The yielded slice is reused between iterations, so callers must
+// consume or copy it before advancing.
+func Quantize(newType fsggml.TensorType, f32s []float32, shape []uint64) iter.Seq[[]byte] {
+	return func(yield func([]byte) bool) {
+		if len(f32s) == 0 {
+			return
+		}
+
+		C.ggml_quantize_init(uint32(newType))
+		defer C.ggml_quantize_free()
+
+		dims := slices.Repeat([]C.int64_t{1}, 4)
+		for i, s := range shape {
+			dims[i] = C.int64_t(s)
+		}
+
+		bts := make([]byte, C.ggml_row_size(uint32(newType), dims[0])*C.size_t(dims[1]))
+		for chunk := range dims[2] {
+			offset := chunk * dims[0] * dims[1]
+
+			// ggml_quantize_chunk advances both src and dst by its start
+			// argument. bts only holds one chunk, so advance the source
+			// pointer here and pass start=0 to keep the write inside bts.
+			n := C.ggml_quantize_chunk(
+				uint32(newType),
+				(*C.float)(&f32s[offset]),
+				unsafe.Pointer(&bts[0]),
+				0, dims[1], dims[0], nil,
+			)
+
+			if !yield(bts[:n]) {
+				return
+			}
+		}
 	}
-	shape2 := C.int64_t(1)
-	if len(shape) > 2 {
-		shape2 = C.int64_t(shape[2])
-	}
-	nelements_matrix := nPerRow * nrows
-	newSize := C.size_t(0)
-	for i03 := C.int64_t(0); i03 < shape2; i03++ {
-		f32s_03 := i03 * nelements_matrix
-		buf_03 := C.int64_t(C.ggml_row_size(uint32(newType), nPerRow)) * i03 * nrows
-		newSize += C.ggml_quantize_chunk(
-			uint32(newType),
-			(*C.float)(&f32s[f32s_03]),
-			unsafe.Pointer((uintptr)(unsafe.Pointer(&buf[0]))+uintptr(buf_03)),
-			0,
-			nrows,
-			nPerRow,
-			nil)
-	}
-	return buf[:newSize]
 }
 
 func QuantizationVersion() uint32 {
diff --git a/server/quantization.go b/server/quantization.go
index 10175a351..004709744 100644
--- a/server/quantization.go
+++ b/server/quantization.go
@@ -40,10 +40,23 @@ func (q quantizer) WriteTo(w io.Writer) (int64, error) {
 	} else {
 		f32s = ggml.ConvertToF32(data, q.from.Kind, q.from.Elements())
 	}
-	data = ggml.Quantize(newType, f32s, q.from.Shape)
-	n, err := w.Write(data)
-	q.progressFn(q.from.Size())
-	return int64(n), err
+
+	// Stream each quantized chunk to w as it is produced. The iterator reuses
+	// its chunk buffer, so every chunk is written out before requesting the next.
+	var n int64
+	for bts := range ggml.Quantize(newType, f32s, q.from.Shape) {
+		nn, err := w.Write(bts)
+		n += int64(nn)
+		if err != nil {
+			return n, err
+		}
+
+		// NOTE(review): progress is now reported in bytes written rather than
+		// q.from.Size(); confirm the caller's total uses the same unit.
+		q.progressFn(uint64(nn))
+	}
+
+	return n, nil
 }
 
 type quantizeState struct {
diff --git a/server/quantization_test.go b/server/quantization_test.go
index 8b726c836..dbd7df372 100644
--- a/server/quantization_test.go
+++ b/server/quantization_test.go
@@ -2,12 +2,14 @@ package server
 
 import (
 	"bytes"
+	"encoding/binary"
 	"fmt"
 	"math"
 	"os"
 	"strings"
 	"testing"
 
+	"github.com/google/go-cmp/cmp"
 	fsggml "github.com/ollama/ollama/fs/ggml"
 	"github.com/ollama/ollama/ml/backend/ggml"
 )
@@ -649,3 +651,62 @@ var (
 		},
 	}
 )
+
+// TestQuantizer quantizes a 256-element F32 tensor to every type listed in
+// quantBytes and compares the bytes produced by WriteTo with the expected data.
+func TestQuantizer(t *testing.T) {
+	from := fsggml.Tensor{
+		Name:  "fp32",
+		Shape: []uint64{256},
+		Kind:  uint32(fsggml.TensorTypeF32),
+	}
+
+	temp, err := os.CreateTemp(t.TempDir(), "*.bin")
+	if err != nil {
+		t.Fatalf("failed to create temp file: %v", err)
+	}
+
+	f32s := make([]float32, 256)
+	for i := range f32s {
+		f32s[i] = float32(i)
+	}
+
+	if err := binary.Write(temp, binary.LittleEndian, f32s); err != nil {
+		t.Fatalf("failed to write to temp file: %v", err)
+	}
+
+	// Close the writer handle; each subtest reopens the file read-only.
+	if err := temp.Close(); err != nil {
+		t.Fatalf("failed to close temp file: %v", err)
+	}
+
+	for type_, want := range quantBytes {
+		t.Run(type_.String(), func(t *testing.T) {
+			f, err := os.Open(temp.Name())
+			if err != nil {
+				t.Fatalf("failed to open temp file: %v", err)
+			}
+			defer f.Close()
+
+			q := quantizer{
+				File: f,
+				from: &from,
+				to: &fsggml.Tensor{
+					Name:  type_.String(),
+					Shape: from.Shape,
+					Kind:  uint32(type_),
+				},
+				progressFn: func(uint64) {},
+			}
+
+			var b bytes.Buffer
+			if _, err := q.WriteTo(&b); err != nil {
+				t.Fatalf("WriteTo failed: %v", err)
+			}
+
+			if diff := cmp.Diff(b.Bytes(), want); diff != "" {
+				t.Errorf("quantized data mismatch for %s (-got +want):\n%s", type_, diff)
+			}
+		})
+	}
+}