diff --git a/convert/reader_test.go b/convert/reader_test.go index efcf98cb6..ea9624e63 100644 --- a/convert/reader_test.go +++ b/convert/reader_test.go @@ -3,13 +3,13 @@ package convert import ( "bytes" "encoding/binary" - "math" "os" "path/filepath" "testing" "github.com/google/go-cmp/cmp" - "github.com/x448/float16" + "github.com/ollama/ollama/convert/bfloat16" + "github.com/ollama/ollama/convert/float16" ) func TestSafetensors(t *testing.T) { @@ -21,6 +21,11 @@ func TestSafetensors(t *testing.T) { } defer root.Close() + f32s := make([]float32, 32) + for i := range f32s { + f32s[i] = float32(i) + } + cases := []struct { name, dtype string @@ -36,11 +41,6 @@ func TestSafetensors(t *testing.T) { size: 32 * 4, // 32 floats, each 4 bytes shape: []uint64{32}, setup: func(t *testing.T, f *os.File) { - f32s := make([]float32, 32) - for i := range f32s { - f32s[i] = float32(i) - } - if err := binary.Write(f, binary.LittleEndian, f32s); err != nil { t.Fatal(err) } @@ -62,11 +62,6 @@ func TestSafetensors(t *testing.T) { size: 32 * 4, // 32 floats, each 4 bytes shape: []uint64{16, 2}, setup: func(t *testing.T, f *os.File) { - f32s := make([]float32, 32) - for i := range f32s { - f32s[i] = float32(i) - } - if err := binary.Write(f, binary.LittleEndian, f32s); err != nil { t.Fatal(err) } @@ -84,12 +79,7 @@ func TestSafetensors(t *testing.T) { size: 32 * 2, // 32 floats, each 2 bytes shape: []uint64{16, 2}, setup: func(t *testing.T, f *os.File) { - u16s := make([]uint16, 32) - for i := range u16s { - u16s[i] = float16.Fromfloat32(float32(i)).Bits() - } - - if err := binary.Write(f, binary.LittleEndian, u16s); err != nil { + if err := binary.Write(f, binary.LittleEndian, float16.FromFloat32s(f32s)); err != nil { t.Fatal(err) } }, @@ -106,12 +96,7 @@ func TestSafetensors(t *testing.T) { size: 32 * 2, // 32 floats, each 2 bytes shape: []uint64{32}, setup: func(t *testing.T, f *os.File) { - u16s := make([]uint16, 32) - for i := range u16s { - u16s[i] = float16.Fromfloat32(float32(i)).Bits() - } - - if err := binary.Write(f, binary.LittleEndian, u16s); err != nil { + if err := binary.Write(f, binary.LittleEndian, float16.FromFloat32s(f32s)); err != nil { t.Fatal(err) } }, @@ -132,13 +117,7 @@ func TestSafetensors(t *testing.T) { size: 32 * 2, // 32 brain floats, each 2 bytes shape: []uint64{16, 2}, setup: func(t *testing.T, f *os.File) { - u16s := make([]uint16, 32) - for i := range u16s { - bits := math.Float32bits(float32(i)) - u16s[i] = uint16(bits >> 16) - } - - if err := binary.Write(f, binary.LittleEndian, u16s); err != nil { + if err := binary.Write(f, binary.LittleEndian, bfloat16.FromFloat32s(f32s)); err != nil { t.Fatal(err) } }, @@ -155,13 +134,7 @@ func TestSafetensors(t *testing.T) { size: 32 * 2, // 32 brain floats, each 2 bytes shape: []uint64{32}, setup: func(t *testing.T, f *os.File) { - u16s := make([]uint16, 32) - for i := range u16s { - bits := math.Float32bits(float32(i)) - u16s[i] = uint16(bits >> 16) - } - - if err := binary.Write(f, binary.LittleEndian, u16s); err != nil { + if err := binary.Write(f, binary.LittleEndian, bfloat16.FromFloat32s(f32s)); err != nil { t.Fatal(err) } },