sample: add sampling package for new engine (#8410)
This commit is contained in:
238
sample/samplers_test.go
Normal file
238
sample/samplers_test.go
Normal file
@@ -0,0 +1,238 @@
|
||||
package sample
|
||||
|
||||
import (
|
||||
"math"
|
||||
"math/rand/v2"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
)
|
||||
|
||||
func TestWeighted(t *testing.T) {
|
||||
got, err := Weighted(nil).Sample([]float32{float32(math.Inf(-1)), 2, float32(math.Inf(-1)), float32(math.Inf(-1))})
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
want := int32(1)
|
||||
if want != got {
|
||||
t.Errorf("index mismatch: want %d, got %d", want, got)
|
||||
}
|
||||
|
||||
got, err = Weighted(nil).Sample([]float32{float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1))})
|
||||
if err == nil {
|
||||
t.Error("expected error for no valid tokens, got index", got)
|
||||
}
|
||||
|
||||
seed := uint64(42)
|
||||
got, err = Weighted(&seed).Sample([]float32{1, 2, 3, 4})
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
// With seed 42, we expect a consistent sample
|
||||
want = int32(3) // This will be deterministic due to the seed
|
||||
if want != got {
|
||||
t.Errorf("index mismatch: want %d, got %d", want, got)
|
||||
}
|
||||
}
|
||||
|
||||
type testTransform struct {
|
||||
id int
|
||||
callOrder *[]int
|
||||
}
|
||||
|
||||
func (ts *testTransform) Apply(logits []float64) []float64 {
|
||||
if ts.callOrder != nil {
|
||||
*ts.callOrder = append(*ts.callOrder, ts.id)
|
||||
}
|
||||
return logits
|
||||
}
|
||||
|
||||
func TestSample(t *testing.T) {
|
||||
input := []float32{1, 2, 3, 4}
|
||||
|
||||
var callOrder []int
|
||||
mock1 := &testTransform{
|
||||
id: 1,
|
||||
callOrder: &callOrder,
|
||||
}
|
||||
mock2 := &testTransform{
|
||||
id: 2,
|
||||
callOrder: &callOrder,
|
||||
}
|
||||
mock3 := &testTransform{
|
||||
id: 3,
|
||||
callOrder: &callOrder,
|
||||
}
|
||||
|
||||
got, err := Greedy(mock1, mock2, mock3).Sample(input)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
want := int32(3) // Greedy sampler should pick highest logit
|
||||
if want != got {
|
||||
t.Errorf("index mismatch: want %d, got %d", want, got)
|
||||
}
|
||||
wantOrder := []int{1, 2, 3}
|
||||
if diff := cmp.Diff(wantOrder, callOrder); diff != "" {
|
||||
t.Errorf("call order mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
|
||||
callOrder = nil
|
||||
|
||||
_, err = Weighted(nil, mock1, mock2, mock3).Sample(input)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
wantOrder = []int{1, 2, 3}
|
||||
if diff := cmp.Diff(wantOrder, callOrder); diff != "" {
|
||||
t.Errorf("call order mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewSampler(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
temperature float32
|
||||
topK int
|
||||
topP float32
|
||||
minP float32
|
||||
seed int
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "no transforms",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "temperature",
|
||||
temperature: 0.5,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "invalid temperature negative",
|
||||
temperature: -1,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid temperature too high",
|
||||
temperature: 2.1,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "top k",
|
||||
topK: 10,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "invalid top k negative",
|
||||
topK: -1,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "top p",
|
||||
topP: 0.9,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "invalid top p negative",
|
||||
topP: -0.1,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid top p one",
|
||||
topP: 1.0,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "min p",
|
||||
minP: 0.2,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "invalid min p negative",
|
||||
minP: -0.1,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid min p one",
|
||||
minP: 1.0,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "seed",
|
||||
seed: 42,
|
||||
wantErr: true, // seed alone is not valid without other transforms
|
||||
},
|
||||
{
|
||||
name: "default values",
|
||||
temperature: 0.8,
|
||||
topK: 40,
|
||||
topP: 0.9,
|
||||
minP: 0.0,
|
||||
seed: 0,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "all zeroes",
|
||||
temperature: 0.0,
|
||||
topK: 0,
|
||||
topP: 0.0,
|
||||
minP: 0.0,
|
||||
seed: 0,
|
||||
wantErr: true, // all zeroes means no transforms
|
||||
},
|
||||
{
|
||||
name: "all transforms",
|
||||
temperature: 0.8,
|
||||
topK: 50,
|
||||
topP: 0.95,
|
||||
minP: 0.1,
|
||||
seed: 42,
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
_, err := NewSampler(tt.temperature, tt.topK, tt.topP, tt.minP, tt.seed)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("NewSampler() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkSample(b *testing.B) {
|
||||
transforms := []Transform{
|
||||
Temperature(0.5),
|
||||
TopK(10),
|
||||
TopP(0.9),
|
||||
MinP(0.2),
|
||||
}
|
||||
|
||||
samplers := map[string]Sampler{
|
||||
"Greedy": Greedy(transforms...),
|
||||
"Weighted": Weighted(nil, transforms...),
|
||||
}
|
||||
|
||||
logits := make([]float32, 1<<16)
|
||||
for i := range logits {
|
||||
logits[i] = rand.Float32()
|
||||
}
|
||||
|
||||
for name, s := range samplers {
|
||||
b.Run(name, func(b *testing.B) {
|
||||
b.ResetTimer()
|
||||
for range b.N {
|
||||
if _, err := s.Sample(logits); err != nil {
|
||||
b.Error(err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user