// FromFloat32s converts every element of f32s to its IEEE 754 binary16
// (half-precision) bit pattern and returns the results in a newly allocated
// slice of the same length.
//
// Conversion rules:
//   - Mantissa bits that do not fit are truncated (rounded toward zero), not
//     rounded to nearest.
//   - Magnitudes too large for binary16 become signed infinity.
//   - Magnitudes too small for a binary16 subnormal become signed zero.
//   - NaN always stays NaN: if truncating the payload would leave an all-zero
//     mantissa (which encodes infinity), the quiet bit is set instead.
func FromFloat32s(f32s []float32) (u16s []uint16) {
	u16s = make([]uint16, len(f32s))
	for i, f := range f32s {
		u16s[i] = fromFloat32(f)
	}
	return u16s
}

// fromFloat32 converts a single float32 to binary16 bits.
func fromFloat32(f float32) uint16 {
	bits := math.Float32bits(f)
	sign := uint16(bits>>16) & 0x8000
	exponent := int(bits>>23) & 0xFF
	mantissa := bits & 0x7FFFFF

	if exponent == 0xFF {
		if mantissa == 0 {
			// Infinity
			return sign | 0x7C00
		}

		// NaN: keep the top 10 payload bits. A payload living only in the
		// low 13 bits would truncate to zero and turn the value into
		// infinity, so force the quiet bit in that case.
		payload := uint16(mantissa >> 13)
		if payload == 0 {
			payload = 0x0200
		}
		return sign | 0x7C00 | payload
	}

	// Re-bias the exponent from binary32 (127) to binary16 (15). float32
	// zeros and subnormals have exponent 0, which lands far below the
	// smallest binary16 subnormal and therefore yields signed zero below.
	e := exponent - 127 + 15
	switch {
	case e >= 31:
		// Overflow to infinity
		return sign | 0x7C00
	case e > 0:
		// Normal number: drop the low 13 mantissa bits.
		return sign | uint16(e)<<10 | uint16(mantissa>>13)
	case e < -10:
		// Too small even for a subnormal
		return sign
	default:
		// Subnormal: make the implicit leading 1 explicit, then shift it
		// into place. The shift is 14 for e == 0 and grows by one for every
		// step below that (at most 24, which leaves zero).
		return sign | uint16((mantissa|0x800000)>>uint(14-e))
	}
}

// Float32s converts every IEEE 754 binary16 bit pattern in u16s to the
// float32 with the same value and returns the results in a newly allocated
// slice of the same length. The conversion is exact: every binary16 value,
// including subnormals, infinities and NaN payloads, is representable as a
// float32.
func Float32s(u16s []uint16) (f32s []float32) {
	f32s = make([]float32, len(u16s))
	for i, h := range u16s {
		f32s[i] = toFloat32(h)
	}
	return f32s
}

// toFloat32 converts a single binary16 bit pattern to float32.
func toFloat32(h uint16) float32 {
	sign := uint32(h&0x8000) << 16
	exponent := uint32(h>>10) & 0x1F
	mantissa := uint32(h & 0x3FF)

	switch exponent {
	case 0:
		if mantissa == 0 {
			// Zero
			return math.Float32frombits(sign)
		}

		// Subnormal: shift the mantissa left until its leading 1 reaches the
		// implicit-bit position, lowering the exponent once per shift, so it
		// becomes a normal float32.
		e := uint32(127 - 15 + 1)
		for mantissa&0x400 == 0 {
			mantissa <<= 1
			e--
		}
		return math.Float32frombits(sign | e<<23 | (mantissa&0x3FF)<<13)
	case 0x1F:
		// Infinity (mantissa == 0) or NaN (payload moved to the top bits)
		return math.Float32frombits(sign | 0x7F800000 | mantissa<<13)
	default:
		// Normal number: re-bias the exponent from 15 to 127.
		return math.Float32frombits(sign | (exponent+127-15)<<23 | mantissa<<13)
	}
}
0x7F800000}, + {"negative infinity", 0xFC00, 0xFF800000}, + + // NaN cases + {"NaN", 0x7C01, 0x7f802000}, + {"NaN with payload", 0x7E00, 0x7FC00000}, + + // Subnormal cases + {"min positive subnormal", 0x0001, 0x33800000}, + {"max subnormal", 0x03FF, 0x387fc000}, + + // Common values + {"pi approximation", 0x4248, 0x40490000}, + {"e approximation", 0x416F, 0x402de000}, + } + + for _, tt := range cases { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + t.Run("Float32s", func(t *testing.T) { + got := Float32s([]uint16{tt.input})[0] + if diff := cmp.Diff(tt.want, math.Float32bits(got)); diff != "" { + t.Errorf("Float32s mismatch (-want +got):\n%s", diff) + } + }) + + t.Run("FromFloat32s", func(t *testing.T) { + got := FromFloat32s([]float32{math.Float32frombits(tt.want)}) + if diff := cmp.Diff([]uint16{tt.input}, got); diff != "" { + t.Errorf("FromFloat32s mismatch (-want +got):\n%s", diff) + } + }) + }) + } +} + +func BenchmarkFloat16(b *testing.B) { + f32s := make([]float32, 1_000_000) + for i := range f32s { + f32s[i] = rand.Float32() + } + for b.Loop() { + Float32s(FromFloat32s(f32s)) + } +} diff --git a/convert/reader_safetensors.go b/convert/reader_safetensors.go index fb8bfaff8..95e24934f 100644 --- a/convert/reader_safetensors.go +++ b/convert/reader_safetensors.go @@ -14,7 +14,7 @@ import ( "strings" "github.com/ollama/ollama/convert/bfloat16" - "github.com/x448/float16" + "github.com/ollama/ollama/convert/float16" ) type safetensorMetadata struct { @@ -163,10 +163,7 @@ func (st safetensor) WriteTo(w io.Writer) (int64, error) { return 0, err } - f32s = make([]float32, len(u16s)) - for i := range u16s { - f32s[i] = float16.Frombits(u16s[i]).Float32() - } + f32s = float16.Float32s(u16s) case "BF16": u16s := make([]uint16, st.size/2) @@ -191,12 +188,7 @@ func (st safetensor) WriteTo(w io.Writer) (int64, error) { case tensorKindFP32: return 0, binary.Write(w, binary.LittleEndian, f32s) case tensorKindFP16: - u16s := make([]uint16, len(f32s)) - for i := 
range f32s { - u16s[i] = float16.Fromfloat32(f32s[i]).Bits() - } - - return 0, binary.Write(w, binary.LittleEndian, u16s) + return 0, binary.Write(w, binary.LittleEndian, float16.FromFloat32s(f32s)) case tensorKindBF16: return 0, binary.Write(w, binary.LittleEndian, bfloat16.FromFloat32s(f32s)) default: diff --git a/go.mod b/go.mod index 083db72ff..b7a4001fa 100644 --- a/go.mod +++ b/go.mod @@ -10,7 +10,6 @@ require ( github.com/olekukonko/tablewriter v0.0.5 github.com/spf13/cobra v1.7.0 github.com/stretchr/testify v1.9.0 - github.com/x448/float16 v0.8.4 golang.org/x/sync v0.12.0 ) diff --git a/go.sum b/go.sum index b04430de5..eea4f7cda 100644 --- a/go.sum +++ b/go.sum @@ -195,8 +195,6 @@ github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE= github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg= -github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM= -github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg= github.com/xtgo/set v1.0.0 h1:6BCNBRv3ORNDQ7fyoJXRv+tstJz3m1JVFQErfeZz2pY= github.com/xtgo/set v1.0.0/go.mod h1:d3NHzGzSa0NmB2NhFyECA+QdRp29oEn2xbT+TpeFoM8= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=