improve temperature sampler

This commit is contained in:
ParthSareen
2025-01-20 13:42:23 -08:00
parent b91487f289
commit 7cd9fbbbb1
2 changed files with 56 additions and 32 deletions

View File

@@ -16,16 +16,19 @@ type Sampler interface {
type Temperature float64
func (s Temperature) Sample(logits []float64) ([]float64, error) {
if s < 0 || s > 1 {
return nil, errors.New("temperature must be between 0 and 1")
func (t Temperature) Sample(logits []float64) ([]float64, error) {
if t < 0 || t > 2 {
return nil, errors.New("temperature must be between 0 and 2")
}
// greedy sampling
if s == 0 {
return []float64{floats.Max(logits)}, nil
// subtracting max logit to avoid under/overflow
maxLogit := floats.Max(logits)
temp := math.Max(float64(t), 1e-7)
for i := range logits {
logits[i] = (logits[i] - maxLogit) / temp
}
floats.Scale(1.0/float64(s), logits)
return logits, nil
}
@@ -47,10 +50,8 @@ func computeSoftmax(logits []float64) ([]float64, error) {
}
floatSum := floats.Sum(copiedLogits)
if floatSum == 0 {
return nil, errors.New("no valid tokens found")
}
floats.Scale(1.0/floatSum, copiedLogits)
return copiedLogits, nil
}
@@ -175,9 +176,28 @@ func (s weighed) Sample(logits []float64) ([]float64, error) {
return nil, errors.New("weighed sampler failed")
}
// TODO: remove after next PR merge
type greedy struct{}
func Greedy() Sampler {
return greedy{}
}
func (greedy) Sample(logits []float64) ([]float64, error) {
return []float64{float64(floats.MaxIdx(logits))}, nil
}
func Sample(logits []float64, samplers ...Sampler) ([]float64, error) {
var err error
for _, sampler := range samplers {
if sampler == Temperature(0) {
// early return with greedy if temperature is 0
logits, err = Greedy().Sample(logits)
if err != nil {
return nil, err
}
return logits, nil
}
logits, err = sampler.Sample(logits)
if err != nil {
return nil, err