update tests

This commit is contained in:
Michael Yang 2025-08-05 20:43:38 -07:00
parent 7bd3f0269c
commit 69f3dfdedf
1 changed files with 11 additions and 38 deletions

View File

@ -3,13 +3,13 @@ package convert
import (
"bytes"
"encoding/binary"
"math"
"os"
"path/filepath"
"testing"
"github.com/google/go-cmp/cmp"
"github.com/x448/float16"
"github.com/ollama/ollama/convert/bfloat16"
"github.com/ollama/ollama/convert/float16"
)
func TestSafetensors(t *testing.T) {
@ -21,6 +21,11 @@ func TestSafetensors(t *testing.T) {
}
defer root.Close()
f32s := make([]float32, 32)
for i := range f32s {
f32s[i] = float32(i)
}
cases := []struct {
name,
dtype string
@ -36,11 +41,6 @@ func TestSafetensors(t *testing.T) {
size: 32 * 4, // 32 floats, each 4 bytes
shape: []uint64{32},
setup: func(t *testing.T, f *os.File) {
f32s := make([]float32, 32)
for i := range f32s {
f32s[i] = float32(i)
}
if err := binary.Write(f, binary.LittleEndian, f32s); err != nil {
t.Fatal(err)
}
@ -62,11 +62,6 @@ func TestSafetensors(t *testing.T) {
size: 32 * 4, // 32 floats, each 4 bytes
shape: []uint64{16, 2},
setup: func(t *testing.T, f *os.File) {
f32s := make([]float32, 32)
for i := range f32s {
f32s[i] = float32(i)
}
if err := binary.Write(f, binary.LittleEndian, f32s); err != nil {
t.Fatal(err)
}
@ -84,12 +79,7 @@ func TestSafetensors(t *testing.T) {
size: 32 * 2, // 32 floats, each 2 bytes
shape: []uint64{16, 2},
setup: func(t *testing.T, f *os.File) {
u16s := make([]uint16, 32)
for i := range u16s {
u16s[i] = float16.Fromfloat32(float32(i)).Bits()
}
if err := binary.Write(f, binary.LittleEndian, u16s); err != nil {
if err := binary.Write(f, binary.LittleEndian, float16.FromFloat32s(f32s)); err != nil {
t.Fatal(err)
}
},
@ -106,12 +96,7 @@ func TestSafetensors(t *testing.T) {
size: 32 * 2, // 32 floats, each 2 bytes
shape: []uint64{32},
setup: func(t *testing.T, f *os.File) {
u16s := make([]uint16, 32)
for i := range u16s {
u16s[i] = float16.Fromfloat32(float32(i)).Bits()
}
if err := binary.Write(f, binary.LittleEndian, u16s); err != nil {
if err := binary.Write(f, binary.LittleEndian, float16.FromFloat32s(f32s)); err != nil {
t.Fatal(err)
}
},
@ -132,13 +117,7 @@ func TestSafetensors(t *testing.T) {
size: 32 * 2, // 32 brain floats, each 2 bytes
shape: []uint64{16, 2},
setup: func(t *testing.T, f *os.File) {
u16s := make([]uint16, 32)
for i := range u16s {
bits := math.Float32bits(float32(i))
u16s[i] = uint16(bits >> 16)
}
if err := binary.Write(f, binary.LittleEndian, u16s); err != nil {
if err := binary.Write(f, binary.LittleEndian, bfloat16.FromFloat32s(f32s)); err != nil {
t.Fatal(err)
}
},
@ -155,13 +134,7 @@ func TestSafetensors(t *testing.T) {
size: 32 * 2, // 32 brain floats, each 2 bytes
shape: []uint64{32},
setup: func(t *testing.T, f *os.File) {
u16s := make([]uint16, 32)
for i := range u16s {
bits := math.Float32bits(float32(i))
u16s[i] = uint16(bits >> 16)
}
if err := binary.Write(f, binary.LittleEndian, u16s); err != nil {
if err := binary.Write(f, binary.LittleEndian, bfloat16.FromFloat32s(f32s)); err != nil {
t.Fatal(err)
}
},