// Source: ollama/runner/ollamarunner/runner_test.go (168 lines, 4.4 KiB, Go)

package ollamarunner
import (
"testing"
"github.com/ollama/ollama/llm"
)
// TestEnableContextShiftLogic walks a table of context-window scenarios and
// checks that the expectations recorded in each row agree with the branch the
// row falls into.
//
// The branching reproduces, inline, the decision the original author describes
// as the core logic of processBatch; the test does not call processBatch
// itself, so it validates the decision table rather than the runner.
//
// Branches, in order of evaluation:
//   - the inputs fit in the context window: the sequence must not be removed;
//   - the window overflows while inputs are pending: the batch is broken and
//     the sequence must not be removed;
//   - the window overflows and context shifting is disabled: the sequence must
//     be removed with llm.DoneReasonContextShift;
//   - the window overflows and context shifting is enabled: the context is
//     shifted and the sequence must not be removed.
func TestEnableContextShiftLogic(t *testing.T) {
	type scenario struct {
		name               string
		enableContextShift bool
		contextLength      int32
		cacheInputs        int
		pendingInputs      int
		minBatch           int
		expectedDoneReason llm.DoneReason
		shouldRemove       bool
	}

	cases := []scenario{
		{
			name:               "context shifting enabled - should shift",
			enableContextShift: true,
			contextLength:      100,
			cacheInputs:        80,
			pendingInputs:      0,
			minBatch:           30,
			expectedDoneReason: llm.DoneReasonStop,
			shouldRemove:       false,
		},
		{
			name:               "context shifting disabled - should remove with DoneReasonContextShift",
			enableContextShift: false,
			contextLength:      100,
			cacheInputs:        80,
			pendingInputs:      0,
			minBatch:           30,
			expectedDoneReason: llm.DoneReasonContextShift,
			shouldRemove:       true,
		},
		{
			name:               "context shifting disabled - within limits",
			enableContextShift: false,
			contextLength:      100,
			cacheInputs:        50,
			pendingInputs:      0,
			minBatch:           30,
			expectedDoneReason: llm.DoneReasonStop,
			shouldRemove:       false,
		},
		{
			name:               "context shifting disabled - exact limit",
			enableContextShift: false,
			contextLength:      100,
			cacheInputs:        100,
			pendingInputs:      0,
			minBatch:           1,
			expectedDoneReason: llm.DoneReasonContextShift,
			shouldRemove:       true,
		},
		{
			name:               "pending inputs - should break batch",
			enableContextShift: true,
			contextLength:      100,
			cacheInputs:        50,
			pendingInputs:      20,
			minBatch:           30,
			expectedDoneReason: llm.DoneReasonStop,
			shouldRemove:       false,
		},
		{
			name:               "no pending inputs - should shift",
			enableContextShift: true,
			contextLength:      100,
			cacheInputs:        80,
			pendingInputs:      0,
			minBatch:           30,
			expectedDoneReason: llm.DoneReasonStop,
			shouldRemove:       false,
		},
		{
			name:               "long prompt with context shifting disabled - will be handled at runtime",
			enableContextShift: false,
			contextLength:      100,
			cacheInputs:        0,
			pendingInputs:      0,
			minBatch:           150, // Simulates a long prompt
			expectedDoneReason: llm.DoneReasonContextShift,
			shouldRemove:       true,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			// Same comparison as the original: the int sum is narrowed to
			// int32 before being compared with the context length.
			needed := int32(tc.cacheInputs + tc.pendingInputs + tc.minBatch)
			overflows := needed > tc.contextLength

			switch {
			case !overflows:
				// Everything fits; nothing should be evicted.
				if tc.shouldRemove {
					t.Errorf("should not remove sequence when within context limits")
				}

			case tc.pendingInputs != 0:
				// Pending inputs break the batch instead of removing the sequence.
				if tc.shouldRemove {
					t.Error("should not remove sequence when pending inputs exist")
				}

			case !tc.enableContextShift:
				// Without shifting, the only option is to end the sequence
				// and report the context-shift done reason.
				if !tc.shouldRemove {
					t.Error("should remove sequence when context shifting disabled")
				}
				if tc.expectedDoneReason != llm.DoneReasonContextShift {
					t.Errorf("expected DoneReason %v, got %v", llm.DoneReasonContextShift, tc.expectedDoneReason)
				}

			default:
				// Shifting is enabled, so the sequence stays alive.
				if tc.shouldRemove {
					t.Error("should not remove sequence when context shifting enabled")
				}
			}
		})
	}
}
// TestPredictLimitLogic checks the prediction-limit rule over a small table:
// a sequence is expected to be removed only when a positive limit is set
// (numPredict > 0) and the number of predicted tokens has reached or passed
// it. A limit of zero means "no limit", so nothing is removed regardless of
// how many tokens were predicted.
//
// The rule is evaluated inline in the test; the original author describes it
// as the core logic of processBatch, which is not called here.
func TestPredictLimitLogic(t *testing.T) {
	type scenario struct {
		name         string
		numPredict   int
		numPredicted int
		expectRemove bool
	}

	cases := []scenario{
		{name: "predict limit not reached", numPredict: 5, numPredicted: 3, expectRemove: false},
		{name: "predict limit reached", numPredict: 5, numPredicted: 5, expectRemove: true},
		{name: "predict limit exceeded", numPredict: 5, numPredicted: 6, expectRemove: true},
		{name: "no predict limit", numPredict: 0, numPredicted: 100, expectRemove: false},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			// A limit applies only when it is positive.
			limitActive := tc.numPredict > 0
			limitHit := limitActive && tc.numPredicted >= tc.numPredict

			if limitHit == tc.expectRemove {
				return
			}
			t.Errorf("expected remove=%v, got %v", tc.expectRemove, limitHit)
		})
	}
}