saved state
This commit is contained in:
@@ -165,11 +165,12 @@ func (s weighed) Sample(logits []float64) ([]float64, error) {
|
||||
if len(logitsCopy) == 0 {
|
||||
return nil, errors.New("no valid tokens found")
|
||||
}
|
||||
logitsCopy, err := computeSoftmax(logitsCopy)
|
||||
|
||||
softmax, err := computeSoftmax(logitsCopy)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
w := sampleuv.NewWeighted(logitsCopy, nil)
|
||||
w := sampleuv.NewWeighted(softmax, nil)
|
||||
if v, ok := w.Take(); ok {
|
||||
// returns the token ID
|
||||
return []float64{float64(indices[v])}, nil
|
||||
|
||||
Reference in New Issue
Block a user