// FromFloat32s converts a slice of float32 values to a slice of bfloat16 values, represented as uint16s.
//
// Each result is the upper 16 bits of the corresponding float32 bit pattern
// (sign, 8 exponent bits and the top 7 mantissa bits). The lower 16 mantissa
// bits are truncated, not rounded.
//
// A NaN whose payload lies entirely in the discarded low bits would truncate
// to ±Inf; in that case the most significant bfloat16 mantissa bit is set so
// the result is still a NaN. NaNs that keep a non-zero mantissa after
// truncation are passed through untouched.
//
// The returned slice always has the same length as f32s.
func FromFloat32s(f32s []float32) (u16s []uint16) {
	const (
		absMask      = 0x7FFFFFFF // float32 bits without the sign
		infBits      = 0x7F800000 // float32 +Inf; any larger magnitude is a NaN
		mantissaMask = 0x007F     // bfloat16 mantissa bits
		quietBit     = 0x0040     // most significant bfloat16 mantissa bit
	)

	u16s = make([]uint16, len(f32s))
	for i, f := range f32s {
		bits := math.Float32bits(f)
		hi := uint16(bits >> 16)

		// Keep NaN a NaN: without this, exponent all-ones plus an empty
		// truncated mantissa would read back as infinity.
		if bits&absMask > infBits && hi&mantissaMask == 0 {
			hi |= quietBit
		}

		u16s[i] = hi
	}
	return u16s
}
// Float32s widens bfloat16 values, stored as uint16s, into float32 values.
//
// Each input becomes the upper 16 bits of a float32 bit pattern whose lower
// 16 bits are zero, so the conversion is exact. The returned slice has the
// same length as u16s.
func Float32s(u16s []uint16) (f32s []float32) {
	f32s = make([]float32, 0, len(u16s))
	for _, half := range u16s {
		wide := uint32(half) << 16
		f32s = append(f32s, math.Float32frombits(wide))
	}
	return f32s
}
diff := cmp.Diff([]uint16{tt.input}, got); diff != "" { + t.Errorf("FromFloat32s mismatch (-want +got):\n%s", diff) + } + }) + }) + } +} + +func BenchmarkBfloat16(b *testing.B) { + f32s := make([]float32, 1_000_000) + for i := range f32s { + f32s[i] = rand.Float32() + } + for b.Loop() { + Float32s(FromFloat32s(f32s)) + } +} diff --git a/convert/reader_safetensors.go b/convert/reader_safetensors.go index ccc596732..fb8bfaff8 100644 --- a/convert/reader_safetensors.go +++ b/convert/reader_safetensors.go @@ -13,7 +13,7 @@ import ( "slices" "strings" - "github.com/d4l3k/go-bfloat16" + "github.com/ollama/ollama/convert/bfloat16" "github.com/x448/float16" ) @@ -169,12 +169,13 @@ func (st safetensor) WriteTo(w io.Writer) (int64, error) { } case "BF16": - u8s := make([]uint8, st.size) - if err = binary.Read(br, binary.LittleEndian, u8s); err != nil { + u16s := make([]uint16, st.size/2) + if err = binary.Read(br, binary.LittleEndian, u16s); err != nil { return 0, err } - f32s = bfloat16.DecodeFloat32(u8s) + f32s = bfloat16.Float32s(u16s) + default: return 0, fmt.Errorf("unknown data type: %s", st.dtype) } @@ -190,15 +191,14 @@ func (st safetensor) WriteTo(w io.Writer) (int64, error) { case tensorKindFP32: return 0, binary.Write(w, binary.LittleEndian, f32s) case tensorKindFP16: - f16s := make([]uint16, len(f32s)) + u16s := make([]uint16, len(f32s)) for i := range f32s { - f16s[i] = float16.Fromfloat32(f32s[i]).Bits() + u16s[i] = float16.Fromfloat32(f32s[i]).Bits() } - return 0, binary.Write(w, binary.LittleEndian, f16s) + return 0, binary.Write(w, binary.LittleEndian, u16s) case tensorKindBF16: - u8s := bfloat16.EncodeFloat32(f32s) - return 0, binary.Write(w, binary.LittleEndian, u8s) + return 0, binary.Write(w, binary.LittleEndian, bfloat16.FromFloat32s(f32s)) default: return 0, fmt.Errorf("unknown storage type: %d", st.Kind()) } diff --git a/convert/reader_test.go b/convert/reader_test.go index 6dbe32a51..efcf98cb6 100644 --- a/convert/reader_test.go +++ 
b/convert/reader_test.go @@ -3,11 +3,11 @@ package convert import ( "bytes" "encoding/binary" + "math" "os" "path/filepath" "testing" - "github.com/d4l3k/go-bfloat16" "github.com/google/go-cmp/cmp" "github.com/x448/float16" ) @@ -132,12 +132,13 @@ func TestSafetensors(t *testing.T) { size: 32 * 2, // 32 brain floats, each 2 bytes shape: []uint64{16, 2}, setup: func(t *testing.T, f *os.File) { - f32s := make([]float32, 32) - for i := range f32s { - f32s[i] = float32(i) + u16s := make([]uint16, 32) + for i := range u16s { + bits := math.Float32bits(float32(i)) + u16s[i] = uint16(bits >> 16) } - if err := binary.Write(f, binary.LittleEndian, bfloat16.EncodeFloat32(f32s)); err != nil { + if err := binary.Write(f, binary.LittleEndian, u16s); err != nil { t.Fatal(err) } }, @@ -154,12 +155,13 @@ func TestSafetensors(t *testing.T) { size: 32 * 2, // 32 brain floats, each 2 bytes shape: []uint64{32}, setup: func(t *testing.T, f *os.File) { - f32s := make([]float32, 32) - for i := range f32s { - f32s[i] = float32(i) + u16s := make([]uint16, 32) + for i := range u16s { + bits := math.Float32bits(float32(i)) + u16s[i] = uint16(bits >> 16) } - if err := binary.Write(f, binary.LittleEndian, bfloat16.EncodeFloat32(f32s)); err != nil { + if err := binary.Write(f, binary.LittleEndian, u16s); err != nil { t.Fatal(err) } }, diff --git a/go.mod b/go.mod index 46e7f433f..083db72ff 100644 --- a/go.mod +++ b/go.mod @@ -16,7 +16,6 @@ require ( require ( github.com/agnivade/levenshtein v1.1.1 - github.com/d4l3k/go-bfloat16 v0.0.0-20211005043715-690c3bdd05f1 github.com/dlclark/regexp2 v1.11.4 github.com/emirpasic/gods/v2 v2.0.0-alpha github.com/google/go-cmp v0.7.0 diff --git a/go.sum b/go.sum index c0ab53aab..b04430de5 100644 --- a/go.sum +++ b/go.sum @@ -35,8 +35,6 @@ github.com/containerd/console v1.0.3 h1:lIr7SlA5PxZyMV30bDW0MGbiOPXwc63yRuCP0ARu github.com/containerd/console v1.0.3/go.mod h1:7LqA/THxQ86k76b8c/EMSiaJ3h1eZkMkXar0TQ1gf3U= github.com/cpuguy83/go-md2man/v2 v2.0.2/go.mod 
h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= -github.com/d4l3k/go-bfloat16 v0.0.0-20211005043715-690c3bdd05f1 h1:cBzrdJPAFBsgCrDPnZxlp1dF2+k4r1kVpD7+1S1PVjY= -github.com/d4l3k/go-bfloat16 v0.0.0-20211005043715-690c3bdd05f1/go.mod h1:uw2gLcxEuYUlAd/EXyjc/v55nd3+47YAgWbSXVxPrNI= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=