From 0ddb64db1f3b461969eee4911e7332abdae63d67 Mon Sep 17 00:00:00 2001 From: Thomas Stocker Date: Sat, 9 Aug 2025 22:33:42 +0200 Subject: [PATCH] Revert changes in transforms_test.go --- sample/transforms_test.go | 78 +++++++-------------------------------- 1 file changed, 13 insertions(+), 65 deletions(-) diff --git a/sample/transforms_test.go b/sample/transforms_test.go index 92e57a987..5307c5f8a 100644 --- a/sample/transforms_test.go +++ b/sample/transforms_test.go @@ -178,16 +178,8 @@ func TestTopP(t *testing.T) { // Test with normal p value got = topP(tokens, 0.95) - // Should keep tokens until cumsum > 0.95 + if len(got) > 3 { - t.Errorf("topP(0.95): kept too many tokens: got %d", len(got)) - t.Logf("got: %v", got) - } - - // Test with normal p value - got = topP(tokens, 0.95) - - if len(tokens) > 3 { t.Errorf("topP(0.95): kept too many tokens: got %d", len(tokens)) t.Logf("got: %v", got) } @@ -216,17 +208,8 @@ func TestTopP(t *testing.T) { softmax(tokens) got = topP(tokens, 1e-10) if len(got) == 0 { - t.Errorf("topP(1e-10): should keep at least one token, got %d", len(tokens)) - t.Logf("got: %v", tokens) - } - - // Test edge case - ensure at least one token remains - input = []float32{-1e6, -1e6, -1e6} // One dominant token - tokens = toTokens(input) - softmax(tokens) - tokens = topP(tokens, 0.0) // Very small p - if len(tokens) < 1 { - t.Error("topP should keep at least one token") + t.Errorf("topP(1e-10): should keep at least one token, got %d", len(got)) + t.Logf("got: %v", got) } } @@ -268,27 +251,6 @@ func TestMinP(t *testing.T) { t.Logf("got: %v", tokens) } - tokens = topK(tokens, 20) - softmax(tokens) - - tokens = minP(tokens, 1.0) - - if len(tokens) != 1 { - t.Errorf("minP(1.0): should keep all tokens, got %d, want %d", len(tokens), len(tokens)) - } - - // Test with normal p value - tokens = toTokens(input) // Reset tokens - tokens = topK(tokens, 20) - softmax(tokens) - tokens = minP(tokens, 0.2) - - // Should keep tokens with prob >= 0.2 * max_prob - if len(tokens) > 3 { - t.Errorf("minP(0.2): kept too many tokens: got %d", len(tokens)) - t.Logf("got: %v", tokens) - } - // Test with single token tokens = toTokens(input[:1]) tokens = topK(tokens, 20) @@ -307,32 +269,18 @@ func TestMinP(t *testing.T) { tokens = minP(tokens, 1.0) if len(tokens) < 1 { t.Error("minP should keep at least one token even with extreme probabilities") - } + got := minP(tokens, 1.0) + + if len(got) != 1 { + t.Errorf("minP(1.0): should keep all tokens, got %d, want %d", len(got), len(tokens)) + } + + // Test with normal p value + got = minP(tokens, 0.2) // Should keep tokens with prob >= 0.2 * max_prob - if len(tokens) > 3 { - t.Errorf("minP(0.2): kept too many tokens: got %d", len(tokens)) - t.Logf("got: %v", tokens) - } - - // Test with zero p value - tokens = toTokens(input) // Reset tokens - tokens = topK(tokens, 20) - softmax(tokens) - tokens = minP(tokens, 0.0) - - // Should keep only the highest probability token - if len(tokens) != len(input) { - t.Errorf("minP(0.0): should keep only one token, got %d", len(tokens)) - t.Logf("got: %v", tokens) - } - - input = []float32{1e-10, 1e-10, 1e-10} - tokens = toTokens(input) - softmax(tokens) - tokens = minP(tokens, 1.0) - if len(tokens) < 1 { - t.Error("minP should keep at least one token even with extreme probabilities") + if len(got) > 3 { + t.Errorf("minP(0.2): kept too many tokens: got %d", len(got)) t.Logf("got: %v", got) }