65 lines
1.4 KiB
Go
65 lines
1.4 KiB
Go
package pooling_test
|
|
|
|
import (
|
|
"bytes"
|
|
"os"
|
|
"testing"
|
|
|
|
"github.com/google/go-cmp/cmp"
|
|
fsggml "github.com/ollama/ollama/fs/ggml"
|
|
"github.com/ollama/ollama/ml"
|
|
"github.com/ollama/ollama/ml/backend/ggml"
|
|
"github.com/ollama/ollama/ml/nn/pooling"
|
|
)
|
|
|
|
func setup(tb testing.TB, n int) ml.Backend {
|
|
tb.Helper()
|
|
|
|
f, err := os.CreateTemp(tb.TempDir(), "*.bin")
|
|
if err != nil {
|
|
tb.Fatal(err)
|
|
}
|
|
defer f.Close()
|
|
|
|
if err := fsggml.WriteGGUF(f, fsggml.KV{
|
|
"general.architecture": "test",
|
|
"test.block_count": uint32(1),
|
|
}, []*fsggml.Tensor{
|
|
{Name: "blk.0.weight", Shape: []uint64{1}, WriterTo: bytes.NewBuffer(make([]byte, 4))},
|
|
}); err != nil {
|
|
tb.Fatal(err)
|
|
}
|
|
|
|
b, err := ggml.New(f.Name(), ml.BackendParams{AllocMemory: true})
|
|
if err != nil {
|
|
tb.Fatal(err)
|
|
}
|
|
|
|
return b
|
|
}
|
|
|
|
func TestForward(t *testing.T) {
|
|
cases := map[pooling.Type][]float32{
|
|
pooling.TypeMean: {4, 5, 6, 7, 8, 9, 10, 11},
|
|
pooling.TypeCLS: {0, 1, 2, 3, 4, 5, 6, 7},
|
|
pooling.TypeLast: {8, 9, 10, 11, 12, 13, 14, 15},
|
|
}
|
|
for typ, want := range cases {
|
|
t.Run(typ.String(), func(t *testing.T) {
|
|
b := setup(t, 99)
|
|
defer b.Close()
|
|
|
|
ctx := b.NewContext()
|
|
defer ctx.Close()
|
|
|
|
tt := ctx.Input().Arange(0, 16, 1, ml.DTypeF32).Reshape(ctx, 8, 2)
|
|
tt = typ.Forward(ctx, tt)
|
|
|
|
ctx.Forward(tt).Compute(tt)
|
|
if diff := cmp.Diff(want, tt.Floats()); diff != "" {
|
|
t.Error(diff)
|
|
}
|
|
})
|
|
}
|
|
}
|