drop bfloat16 dependency

goos: darwin
goarch: arm64
pkg: github.com/ollama/ollama/convert/bfloat16
cpu: Apple M3 Max
BenchmarkBfloat16/d4l3k/go-bfloat16-16               516           2269453 ns/op
BenchmarkBfloat16/simple-16                          1759            626316 ns/op
PASS
ok      github.com/ollama/ollama/convert/bfloat16        2.502s
This commit is contained in:
Michael Yang 2025-08-05 17:32:22 -07:00
parent ef7d26ba2c
commit 276c4df770
6 changed files with 123 additions and 21 deletions

View File

@ -0,0 +1,21 @@
package bfloat16
import "math"
// FromFloat32s converts a slice of float32 values to a slice of bfloat16 values, represented as uint16s.
func FromFloat32s(f32s []float32) (u16s []uint16) {
u16s = make([]uint16, len(f32s))
for i := range f32s {
u16s[i] = uint16(math.Float32bits(f32s[i]) >> 16)
}
return u16s
}
// Float32s converts a slice of bfloat16 values, represented as uint16s, back to a slice of float32 values.
func Float32s(u16s []uint16) (f32s []float32) {
f32s = make([]float32, len(u16s))
for i := range u16s {
f32s[i] = math.Float32frombits(uint32(u16s[i]) << 16)
}
return f32s
}

View File

@ -0,0 +1,82 @@
package bfloat16
import (
"math"
"math/rand/v2"
"testing"
"github.com/google/go-cmp/cmp"
)
func TestBfloat16(t *testing.T) {
cases := []struct {
name string
input uint16
want uint32
}{
// Zero cases
{"positive zero", 0x0000, 0x0},
{"negative zero", 0x8000, 0x80000000},
// Normal numbers
{"one", 0x3F80, 0x3F800000},
{"negative one", 0xBF80, 0xBF800000},
{"two", 0x4000, 0x40000000},
{"half", 0x3F00, 0x3F000000},
{"quarter", 0x3E80, 0x3E800000},
{"max finite", 0x7F7F, 0x7F7F0000},
{"min positive normal", 0x0080, 0x00800000},
// Infinity cases
{"positive infinity", 0x7F80, 0x7F800000},
{"negative infinity", 0xFF80, 0xFF800000},
// NaN cases
{"NaN", 0x7FC0, 0x7FC00000},
{"NaN with payload", 0x7FC1, 0x7FC10000},
// Subnormal cases
{"min positive subnormal", 0x0001, 0x00010000},
{"max subnormal", 0x007F, 0x007F0000},
// Powers of 2
{"2^10", 0x4480, 0x44800000},
{"2^-10", 0x3A80, 0x3A800000},
{"2^20", 0x4B80, 0x4B800000},
// Common approximations in BF16
{"pi approximation", 0x4049, 0x40490000},
{"e approximation", 0x402E, 0x402E0000},
{"sqrt(2) approximation", 0x3FB5, 0x3FB50000},
}
for _, tt := range cases {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
t.Run("Float32s", func(t *testing.T) {
got := Float32s([]uint16{tt.input})[0]
if diff := cmp.Diff(tt.want, math.Float32bits(got)); diff != "" {
t.Errorf("Float32s mismatch (-want +got):\n%s", diff)
}
})
t.Run("FromFloat32s", func(t *testing.T) {
got := FromFloat32s([]float32{math.Float32frombits(tt.want)})
if diff := cmp.Diff([]uint16{tt.input}, got); diff != "" {
t.Errorf("FromFloat32s mismatch (-want +got):\n%s", diff)
}
})
})
}
}
func BenchmarkBfloat16(b *testing.B) {
f32s := make([]float32, 1_000_000)
for i := range f32s {
f32s[i] = rand.Float32()
}
for b.Loop() {
Float32s(FromFloat32s(f32s))
}
}

View File

@ -13,7 +13,7 @@ import (
"slices"
"strings"
"github.com/d4l3k/go-bfloat16"
"github.com/ollama/ollama/convert/bfloat16"
"github.com/x448/float16"
)
@ -169,12 +169,13 @@ func (st safetensor) WriteTo(w io.Writer) (int64, error) {
}
case "BF16":
u8s := make([]uint8, st.size)
if err = binary.Read(br, binary.LittleEndian, u8s); err != nil {
u16s := make([]uint16, st.size/2)
if err = binary.Read(br, binary.LittleEndian, u16s); err != nil {
return 0, err
}
f32s = bfloat16.DecodeFloat32(u8s)
f32s = bfloat16.Float32s(u16s)
default:
return 0, fmt.Errorf("unknown data type: %s", st.dtype)
}
@ -190,15 +191,14 @@ func (st safetensor) WriteTo(w io.Writer) (int64, error) {
case tensorKindFP32:
return 0, binary.Write(w, binary.LittleEndian, f32s)
case tensorKindFP16:
f16s := make([]uint16, len(f32s))
u16s := make([]uint16, len(f32s))
for i := range f32s {
f16s[i] = float16.Fromfloat32(f32s[i]).Bits()
u16s[i] = float16.Fromfloat32(f32s[i]).Bits()
}
return 0, binary.Write(w, binary.LittleEndian, f16s)
return 0, binary.Write(w, binary.LittleEndian, u16s)
case tensorKindBF16:
u8s := bfloat16.EncodeFloat32(f32s)
return 0, binary.Write(w, binary.LittleEndian, u8s)
return 0, binary.Write(w, binary.LittleEndian, bfloat16.FromFloat32s(f32s))
default:
return 0, fmt.Errorf("unknown storage type: %d", st.Kind())
}

View File

@ -3,11 +3,11 @@ package convert
import (
"bytes"
"encoding/binary"
"math"
"os"
"path/filepath"
"testing"
"github.com/d4l3k/go-bfloat16"
"github.com/google/go-cmp/cmp"
"github.com/x448/float16"
)
@ -132,12 +132,13 @@ func TestSafetensors(t *testing.T) {
size: 32 * 2, // 32 brain floats, each 2 bytes
shape: []uint64{16, 2},
setup: func(t *testing.T, f *os.File) {
f32s := make([]float32, 32)
for i := range f32s {
f32s[i] = float32(i)
u16s := make([]uint16, 32)
for i := range u16s {
bits := math.Float32bits(float32(i))
u16s[i] = uint16(bits >> 16)
}
if err := binary.Write(f, binary.LittleEndian, bfloat16.EncodeFloat32(f32s)); err != nil {
if err := binary.Write(f, binary.LittleEndian, u16s); err != nil {
t.Fatal(err)
}
},
@ -154,12 +155,13 @@ func TestSafetensors(t *testing.T) {
size: 32 * 2, // 32 brain floats, each 2 bytes
shape: []uint64{32},
setup: func(t *testing.T, f *os.File) {
f32s := make([]float32, 32)
for i := range f32s {
f32s[i] = float32(i)
u16s := make([]uint16, 32)
for i := range u16s {
bits := math.Float32bits(float32(i))
u16s[i] = uint16(bits >> 16)
}
if err := binary.Write(f, binary.LittleEndian, bfloat16.EncodeFloat32(f32s)); err != nil {
if err := binary.Write(f, binary.LittleEndian, u16s); err != nil {
t.Fatal(err)
}
},

1
go.mod
View File

@ -16,7 +16,6 @@ require (
require (
github.com/agnivade/levenshtein v1.1.1
github.com/d4l3k/go-bfloat16 v0.0.0-20211005043715-690c3bdd05f1
github.com/dlclark/regexp2 v1.11.4
github.com/emirpasic/gods/v2 v2.0.0-alpha
github.com/google/go-cmp v0.7.0

2
go.sum
View File

@ -35,8 +35,6 @@ github.com/containerd/console v1.0.3 h1:lIr7SlA5PxZyMV30bDW0MGbiOPXwc63yRuCP0ARu
github.com/containerd/console v1.0.3/go.mod h1:7LqA/THxQ86k76b8c/EMSiaJ3h1eZkMkXar0TQ1gf3U=
github.com/cpuguy83/go-md2man/v2 v2.0.2/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o=
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
github.com/d4l3k/go-bfloat16 v0.0.0-20211005043715-690c3bdd05f1 h1:cBzrdJPAFBsgCrDPnZxlp1dF2+k4r1kVpD7+1S1PVjY=
github.com/d4l3k/go-bfloat16 v0.0.0-20211005043715-690c3bdd05f1/go.mod h1:uw2gLcxEuYUlAd/EXyjc/v55nd3+47YAgWbSXVxPrNI=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=