Revert changes in transforms_test.go
This commit is contained in:
parent
29b1ed0077
commit
0ddb64db1f
|
|
@ -178,16 +178,8 @@ func TestTopP(t *testing.T) {
|
|||
|
||||
// Test with normal p value
|
||||
got = topP(tokens, 0.95)
|
||||
// Should keep tokens until cumsum > 0.95
|
||||
|
||||
if len(got) > 3 {
|
||||
t.Errorf("topP(0.95): kept too many tokens: got %d", len(got))
|
||||
t.Logf("got: %v", got)
|
||||
}
|
||||
|
||||
// Test with normal p value
|
||||
got = topP(tokens, 0.95)
|
||||
|
||||
if len(tokens) > 3 {
|
||||
t.Errorf("topP(0.95): kept too many tokens: got %d", len(tokens))
|
||||
t.Logf("got: %v", got)
|
||||
}
|
||||
|
|
@ -216,17 +208,8 @@ func TestTopP(t *testing.T) {
|
|||
softmax(tokens)
|
||||
got = topP(tokens, 1e-10)
|
||||
if len(got) == 0 {
|
||||
t.Errorf("topP(1e-10): should keep at least one token, got %d", len(tokens))
|
||||
t.Logf("got: %v", tokens)
|
||||
}
|
||||
|
||||
// Test edge case - ensure at least one token remains
|
||||
input = []float32{-1e6, -1e6, -1e6} // One dominant token
|
||||
tokens = toTokens(input)
|
||||
softmax(tokens)
|
||||
tokens = topP(tokens, 0.0) // Very small p
|
||||
if len(tokens) < 1 {
|
||||
t.Error("topP should keep at least one token")
|
||||
t.Errorf("topP(1e-10): should keep at least one token, got %d", len(got))
|
||||
t.Logf("got: %v", got)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -268,27 +251,6 @@ func TestMinP(t *testing.T) {
|
|||
t.Logf("got: %v", tokens)
|
||||
}
|
||||
|
||||
tokens = topK(tokens, 20)
|
||||
softmax(tokens)
|
||||
|
||||
tokens = minP(tokens, 1.0)
|
||||
|
||||
if len(tokens) != 1 {
|
||||
t.Errorf("minP(1.0): should keep all tokens, got %d, want %d", len(tokens), len(tokens))
|
||||
}
|
||||
|
||||
// Test with normal p value
|
||||
tokens = toTokens(input) // Reset tokens
|
||||
tokens = topK(tokens, 20)
|
||||
softmax(tokens)
|
||||
tokens = minP(tokens, 0.2)
|
||||
|
||||
// Should keep tokens with prob >= 0.2 * max_prob
|
||||
if len(tokens) > 3 {
|
||||
t.Errorf("minP(0.2): kept too many tokens: got %d", len(tokens))
|
||||
t.Logf("got: %v", tokens)
|
||||
}
|
||||
|
||||
// Test with single token
|
||||
tokens = toTokens(input[:1])
|
||||
tokens = topK(tokens, 20)
|
||||
|
|
@ -307,32 +269,18 @@ func TestMinP(t *testing.T) {
|
|||
tokens = minP(tokens, 1.0)
|
||||
if len(tokens) < 1 {
|
||||
t.Error("minP should keep at least one token even with extreme probabilities")
|
||||
}
|
||||
got := minP(tokens, 1.0)
|
||||
|
||||
if len(got) != 1 {
|
||||
t.Errorf("minP(1.0): should keep all tokens, got %d, want %d", len(got), len(tokens))
|
||||
}
|
||||
|
||||
// Test with normal p value
|
||||
got = minP(tokens, 0.2)
|
||||
|
||||
// Should keep tokens with prob >= 0.2 * max_prob
|
||||
if len(tokens) > 3 {
|
||||
t.Errorf("minP(0.2): kept too many tokens: got %d", len(tokens))
|
||||
t.Logf("got: %v", tokens)
|
||||
}
|
||||
|
||||
// Test with zero p value
|
||||
tokens = toTokens(input) // Reset tokens
|
||||
tokens = topK(tokens, 20)
|
||||
softmax(tokens)
|
||||
tokens = minP(tokens, 0.0)
|
||||
|
||||
// Should keep only the highest probability token
|
||||
if len(tokens) != len(input) {
|
||||
t.Errorf("minP(0.0): should keep only one token, got %d", len(tokens))
|
||||
t.Logf("got: %v", tokens)
|
||||
}
|
||||
|
||||
input = []float32{1e-10, 1e-10, 1e-10}
|
||||
tokens = toTokens(input)
|
||||
softmax(tokens)
|
||||
tokens = minP(tokens, 1.0)
|
||||
if len(tokens) < 1 {
|
||||
t.Error("minP should keep at least one token even with extreme probabilities")
|
||||
if len(got) > 3 {
|
||||
t.Errorf("minP(0.2): kept too many tokens: got %d", len(got))
|
||||
t.Logf("got: %v", got)
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue