drop float16 dependency

goos: darwin
goarch: arm64
pkg: github.com/ollama/ollama/convert/float16
cpu: Apple M3 Max
BenchmarkFloat16/x448/float16-16                     159           7398462 ns/op
BenchmarkFloat16/simple-16                           512           2327098 ns/op
PASS
ok      github.com/ollama/ollama/convert/float16        2.553s
This commit is contained in:
Michael Yang
2025-08-05 20:06:33 -07:00
parent 276c4df770
commit 7bd3f0269c
5 changed files with 175 additions and 14 deletions

View File

@@ -0,0 +1,97 @@
package float16
import (
"math"
)
func FromFloat32s(f32s []float32) (u16s []uint16) {
u16s = make([]uint16, len(f32s))
for i := range f32s {
bits := math.Float32bits(f32s[i])
sign := (bits >> 31) & 0x1
exponent := (bits >> 23) & 0xFF
mantissa := bits & 0x7FFFFF
if exponent == 0xFF {
if mantissa == 0 {
// Infinity
u16s[i] = uint16((sign << 15) | 0x7C00)
} else {
// NaN
u16s[i] = uint16((sign << 15) | 0x7C00 | (mantissa >> 13))
}
} else if exponent == 0 && mantissa == 0 {
// Zero
u16s[i] = uint16(sign << 15)
} else {
// Convert exponent from FP32 bias (127) to FP16 bias (15)
exponent := int(exponent) - 127 + 15
if exponent >= 31 {
// Overflow to infinity
u16s[i] = uint16((sign << 15) | 0x7C00)
} else if exponent <= 0 {
// Underflow - create subnormal or zero
if exponent < -10 {
u16s[i] = uint16(sign << 15) // Zero
} else {
// Subnormal number
mantissa = (mantissa | 0x800000) >> uint(-exponent+1)
u16s[i] = uint16((sign << 15) | (mantissa >> 13))
}
} else {
// Normal number - truncate mantissa from 23 to 10 bits
u16s[i] = uint16((sign << 15) | (uint32(exponent) << 10) | (mantissa >> 13))
}
}
}
return u16s
}
func Float32s(u16s []uint16) (f32s []float32) {
f32s = make([]float32, len(u16s))
for i := range u16s {
sign := (u16s[i] >> 15) & 0x1
exponent := (u16s[i] >> 10) & 0x1F
mantissa := u16s[i] & 0x3FF
var u32 uint32
switch exponent {
case 0:
if mantissa == 0 {
// Zero
u32 = uint32(sign) << 31
} else {
// Subnormal - convert to normal
// Find leading 1 bit
shift := 0
temp := mantissa
for temp&0x400 == 0 {
temp <<= 1
shift++
}
exponent := 127 - 15 + 1 - shift
mantissa := (uint32(temp&0x3FF) << 13)
u32 = (uint32(sign) << 31) | (uint32(exponent) << 23) | mantissa
}
case 0x1F:
if mantissa == 0 {
// Infinity
u32 = (uint32(sign) << 31) | 0x7F800000
} else {
// NaN
u32 = (uint32(sign) << 31) | 0x7F800000 | (uint32(mantissa) << 13)
}
default:
// Normal number
exponent := uint32(exponent) - 15 + 127
mantissa := uint32(mantissa) << 13
u32 = (uint32(sign) << 31) | (exponent << 23) | mantissa
}
f32s[i] = math.Float32frombits(u32)
}
return f32s
}

View File

@@ -0,0 +1,75 @@
package float16
import (
"math"
"math/rand/v2"
"testing"
"github.com/google/go-cmp/cmp"
)
func TestFloat16(t *testing.T) {
cases := []struct {
name string
input uint16
want uint32
}{
// Zero cases
{"positive zero", 0x0000, 0x0},
{"negative zero", 0x8000, 0x80000000},
// Normal numbers
{"one", 0x3C00, 0x3F800000},
{"negative one", 0xBC00, 0xBF800000},
{"two", 0x4000, 0x40000000},
{"half", 0x3800, 0x3F000000},
{"max normal", 0x7BFF, 0x477fe000},
{"min positive normal", 0x0400, 0x38800000},
// Infinity cases
{"positive infinity", 0x7C00, 0x7F800000},
{"negative infinity", 0xFC00, 0xFF800000},
// NaN cases
{"NaN", 0x7C01, 0x7f802000},
{"NaN with payload", 0x7E00, 0x7FC00000},
// Subnormal cases
{"min positive subnormal", 0x0001, 0x33800000},
{"max subnormal", 0x03FF, 0x387fc000},
// Common values
{"pi approximation", 0x4248, 0x40490000},
{"e approximation", 0x416F, 0x402de000},
}
for _, tt := range cases {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
t.Run("Float32s", func(t *testing.T) {
got := Float32s([]uint16{tt.input})[0]
if diff := cmp.Diff(tt.want, math.Float32bits(got)); diff != "" {
t.Errorf("Float32s mismatch (-want +got):\n%s", diff)
}
})
t.Run("FromFloat32s", func(t *testing.T) {
got := FromFloat32s([]float32{math.Float32frombits(tt.want)})
if diff := cmp.Diff([]uint16{tt.input}, got); diff != "" {
t.Errorf("FromFloat32s mismatch (-want +got):\n%s", diff)
}
})
})
}
}
func BenchmarkFloat16(b *testing.B) {
f32s := make([]float32, 1_000_000)
for i := range f32s {
f32s[i] = rand.Float32()
}
for b.Loop() {
Float32s(FromFloat32s(f32s))
}
}