diff --git a/api/types.go b/api/types.go
index 94d492006..a900051a0 100644
--- a/api/types.go
+++ b/api/types.go
@@ -285,6 +285,7 @@ type Options struct {
 	PresencePenalty  float32  `json:"presence_penalty,omitempty"`
 	FrequencyPenalty float32  `json:"frequency_penalty,omitempty"`
 	Stop             []string `json:"stop,omitempty"`
+	ShiftContext     bool     `json:"shift_context,omitempty"`
 }
 
 // Runner options which must be set when the model is loaded into memory
@@ -663,6 +664,7 @@ func DefaultOptions() Options {
 		PresencePenalty:  0.0,
 		FrequencyPenalty: 0.0,
 		Seed:             -1,
+		ShiftContext:     true,
 
 		Runner: Runner{
 			// options set when the model is loaded
diff --git a/llm/server.go b/llm/server.go
index 373f6faef..b4acde9ac 100644
--- a/llm/server.go
+++ b/llm/server.go
@@ -700,6 +700,8 @@ const (
 	DoneReasonStop DoneReason = iota
 	// DoneReasonLength indicates the completion stopped due to length limits
 	DoneReasonLength
+	// DoneReasonContextShift indicates the completion stopped because the context limit was reached while context shifting was disabled
+	DoneReasonContextShift
 	// DoneReasonConnectionClosed indicates the completion stopped due to the connection being closed
 	DoneReasonConnectionClosed
 )
@@ -710,6 +712,8 @@ func (d DoneReason) String() string {
 		return "length"
 	case DoneReasonStop:
 		return "stop"
+	case DoneReasonContextShift:
+		return "context_limit_reached"
 	default:
 		return "" // closed
 	}
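
Not part of the patch, but for orientation: a minimal client-side sketch of opting out of context shifting once this lands. It assumes the existing api.Client generate flow; the only new piece is the "shift_context" key, which maps onto the ShiftContext field added above. The model name and prompt are placeholders.

package main

import (
    "context"
    "fmt"
    "log"

    "github.com/ollama/ollama/api"
)

func main() {
    client, err := api.ClientFromEnvironment()
    if err != nil {
        log.Fatal(err)
    }

    req := &api.GenerateRequest{
        Model:  "llama3.2", // placeholder model name
        Prompt: "Summarize the history of the printing press.",
        Options: map[string]any{
            // New option from this patch: refuse to shift (discard) old
            // context; the runner stops with "context_limit_reached" instead.
            "shift_context": false,
        },
    }

    err = client.Generate(context.Background(), req, func(resp api.GenerateResponse) error {
        fmt.Print(resp.Response)
        if resp.Done {
            fmt.Printf("\ndone reason: %s\n", resp.DoneReason)
        }
        return nil
    })
    if err != nil {
        log.Fatal(err)
    }
}
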
diff --git a/runner/llamarunner/runner.go b/runner/llamarunner/runner.go
index 7aa9b96a2..88dc71a59 100644
--- a/runner/llamarunner/runner.go
+++ b/runner/llamarunner/runner.go
@@ -80,6 +80,9 @@ type Sequence struct {
 	// true if an embedding are to be returned instead of text generation
 	embeddingOnly bool
 
+	// true if context shifting should be enabled
+	shiftContext bool
+
 	doneReason llm.DoneReason
 
 	// Metrics
@@ -90,11 +93,12 @@ type Sequence struct {
 }
 
 type NewSequenceParams struct {
-	numPredict     int
-	stop           []string
-	numKeep        int
-	samplingParams *llama.SamplingParams
-	embedding      bool
+	numPredict         int
+	stop               []string
+	numKeep            int
+	samplingParams     *llama.SamplingParams
+	embedding          bool
+	enableContextShift bool
 }
 
 func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSequenceParams) (*Sequence, error) {
@@ -120,7 +124,7 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
 	// Ensure that at least 1 input can be discarded during shift
 	params.numKeep = min(params.numKeep, s.cache.numCtx-1)
 
-	if len(inputs) > s.cache.numCtx {
+	if len(inputs) > s.cache.numCtx && params.enableContextShift {
 		discard := len(inputs) - s.cache.numCtx
 		newInputs := inputs[:params.numKeep]
 		newInputs = append(newInputs, inputs[params.numKeep+discard:]...)
@@ -155,6 +159,7 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
 		embeddingOnly: params.embedding,
 		stop:          params.stop,
 		numKeep:       params.numKeep,
+		shiftContext:  params.enableContextShift,
 	}, nil
 }
 
@@ -300,13 +305,26 @@ func flushPending(seq *Sequence) bool {
 func (s *Server) removeSequence(seqIndex int, reason llm.DoneReason) {
 	seq := s.seqs[seqIndex]
 
+	if seq == nil {
+		return
+	}
+
+	// Mark the sequence as being removed to prevent further processing
+	s.seqs[seqIndex] = nil
+
+	if seq.cache != nil {
+		seq.cache.InUse = false
+	}
+
+	if len(seq.pendingResponses) > 0 {
+		flushPending(seq)
+	}
-	flushPending(seq)
 	seq.doneReason = reason
+
 	close(seq.responses)
 	close(seq.embedding)
-	seq.cache.InUse = false
-	s.seqs[seqIndex] = nil
+
 	s.seqsSem.Release(1)
 }
 
@@ -340,7 +358,7 @@ func (s *Server) run(ctx context.Context) {
 		default:
 			err := s.processBatch(tokenBatch, embedBatch)
 			if err != nil {
-				panic(err)
+				slog.Error("error processing batch", "error", err)
 			}
 
 			tokenBatch.Clear()
@@ -382,6 +400,10 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
 		for i, input := range seq.inputs {
 			if len(seq.cache.Inputs)+len(seq.pendingInputs)+1 > s.cache.numCtx {
+				if !seq.shiftContext {
+					s.removeSequence(seqIdx, llm.DoneReasonContextShift)
+					continue
+				}
 				if len(seq.pendingInputs) == 0 {
 					err := s.cache.ShiftCacheSlot(seq.cache, seq.numKeep)
 					if err != nil {
@@ -573,11 +595,12 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 	}
 
 	seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{
-		numPredict:     req.Options.NumPredict,
-		stop:           req.Options.Stop,
-		numKeep:        req.Options.NumKeep,
-		samplingParams: &samplingParams,
-		embedding:      false,
+		numPredict:         req.Options.NumPredict,
+		stop:               req.Options.Stop,
+		numKeep:            req.Options.NumKeep,
+		samplingParams:     &samplingParams,
+		embedding:          false,
+		enableContextShift: req.Options.ShiftContext,
 	})
 	if err != nil {
 		http.Error(w, fmt.Sprintf("Failed to create new sequence: %v", err), http.StatusInternalServerError)
diff --git a/runner/llamarunner/runner_test.go b/runner/llamarunner/runner_test.go
new file mode 100644
index 000000000..4f2d05d51
--- /dev/null
+++ b/runner/llamarunner/runner_test.go
@@ -0,0 +1,152 @@
+package llamarunner
+
+import (
+	"testing"
+
+	"github.com/ollama/ollama/llm"
+)
+
+func TestContextShiftLogic(t *testing.T) {
+	tests := []struct {
+		name               string
+		enableContextShift bool
+		contextLength      int32
+		cacheInputs        int
+		pendingInputs      int
+		minBatch           int
+		expectedDoneReason llm.DoneReason
+		shouldRemove       bool
+	}{
+		{
+			name:               "context shifting enabled - should shift",
+			enableContextShift: true,
+			contextLength:      100,
+			cacheInputs:        80,
+			pendingInputs:      0,
+			minBatch:           30,
+			expectedDoneReason: llm.DoneReasonStop,
+			shouldRemove:       false,
+		},
+		{
+			name:               "context shifting disabled - should remove",
+			enableContextShift: false,
+			contextLength:      100,
+			cacheInputs:        80,
+			pendingInputs:      0,
+			minBatch:           30,
+			expectedDoneReason: llm.DoneReasonContextShift,
+			shouldRemove:       true,
+		},
+		{
+			name:               "context shifting disabled - within limits",
+			enableContextShift: false,
+			contextLength:      100,
+			cacheInputs:        50,
+			pendingInputs:      0,
+			minBatch:           30,
+			expectedDoneReason: llm.DoneReasonStop,
+			shouldRemove:       false,
+		},
+		{
+			name:               "pending inputs - should break batch",
+			enableContextShift: true,
+			contextLength:      100,
+			cacheInputs:        50,
+			pendingInputs:      20,
+			minBatch:           30,
+			expectedDoneReason: llm.DoneReasonStop,
+			shouldRemove:       false,
+		},
+		{
+			name:               "no pending inputs - should shift",
+			enableContextShift: true,
+			contextLength:      100,
+			cacheInputs:        80,
+			pendingInputs:      0,
+			minBatch:           30,
+			expectedDoneReason: llm.DoneReasonStop,
+			shouldRemove:       false,
+		},
+		{
+			name:               "long prompt with context shifting disabled - will be handled at runtime",
+			enableContextShift: false,
+			contextLength:      100,
+			cacheInputs:        0,
+			pendingInputs:      0,
+			minBatch:           150, // Simulates a long prompt
+			expectedDoneReason: llm.DoneReasonContextShift,
+			shouldRemove:       true,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			// Test the core logic from processBatch
+			if int32(tt.cacheInputs+tt.pendingInputs+tt.minBatch) > tt.contextLength {
+				if tt.pendingInputs != 0 {
+					// Should break batch
+					if tt.shouldRemove {
+						t.Error("should not remove sequence when pending inputs exist")
+					}
+				} else if !tt.enableContextShift {
+					// Should remove with DoneReasonContextShift
+					if !tt.shouldRemove {
+						t.Error("should remove sequence when context shifting disabled")
+					}
+					if tt.expectedDoneReason != llm.DoneReasonContextShift {
+						t.Errorf("expected DoneReason %v, got %v", llm.DoneReasonContextShift, tt.expectedDoneReason)
+					}
+				} else {
+					// Should shift context
+					if tt.shouldRemove {
+						t.Error("should not remove sequence when context shifting enabled")
+					}
+				}
+			}
+		})
+	}
+}
+
+func TestPredictLimitLogic(t *testing.T) {
+	tests := []struct {
+		name         string
+		numPredict   int
+		numPredicted int
+		expectRemove bool
+	}{
+		{
+			name:         "predict limit not reached",
+			numPredict:   5,
+			numPredicted: 3,
+			expectRemove: false,
+		},
+		{
+			name:         "predict limit reached",
+			numPredict:   5,
+			numPredicted: 5,
+			expectRemove: true,
+		},
+		{
+			name:         "predict limit exceeded",
+			numPredict:   5,
+			numPredicted: 6,
+			expectRemove: true,
+		},
+		{
+			name:         "no predict limit",
+			numPredict:   0,
+			numPredicted: 100,
+			expectRemove: false,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			// Test the core logic from processBatch
+			shouldRemove := tt.numPredict > 0 && tt.numPredicted >= tt.numPredict
+			if shouldRemove != tt.expectRemove {
+				t.Errorf("expected remove=%v, got %v", tt.expectRemove, shouldRemove)
+			}
+		})
+	}
+}
diff --git a/runner/ollamarunner/runner.go b/runner/ollamarunner/runner.go
index a7a889f1f..bf2ca8343 100644
--- a/runner/ollamarunner/runner.go
+++ b/runner/ollamarunner/runner.go
@@ -85,6 +85,9 @@ type Sequence struct {
 	// true if an embedding are to be returned instead of text generation
 	embeddingOnly bool
 
+	// true if context shifting should be enabled
+	shiftContext bool
+
 	doneReason llm.DoneReason
 
 	// Metrics
@@ -95,11 +98,12 @@ type Sequence struct {
 }
 
 type NewSequenceParams struct {
-	numPredict int
-	stop       []string
-	numKeep    int32
-	sampler    sample.Sampler
-	embedding  bool
+	numPredict         int
+	stop               []string
+	numKeep            int32
+	sampler            sample.Sampler
+	embedding          bool
+	enableContextShift bool
 }
 
 func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSequenceParams) (*Sequence, error) {
@@ -121,7 +125,7 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
 	// Ensure that at least 1 input can be discarded during shift
 	params.numKeep = min(params.numKeep, s.cache.numCtx-1)
 
-	if int32(len(inputs)) > s.cache.numCtx {
+	if int32(len(inputs)) > s.cache.numCtx && params.enableContextShift {
 		discard := int32(len(inputs)) - s.cache.numCtx
 		promptStart := params.numKeep + discard
 
@@ -175,6 +179,7 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
 		embeddingOnly: params.embedding,
 		stop:          params.stop,
 		numKeep:       params.numKeep,
+		shiftContext:  params.enableContextShift,
 	}, nil
 }
 
@@ -341,13 +346,25 @@ func flushPending(seq *Sequence) bool {
 func (s *Server) removeSequence(seqIndex int, reason llm.DoneReason) {
 	seq := s.seqs[seqIndex]
 
+	if seq == nil {
+		return
+	}
+
+	// Mark the sequence as being removed to prevent further processing
+	s.seqs[seqIndex] = nil
+
+	if seq.cache != nil {
+		seq.cache.InUse = false
+	}
+
+	if len(seq.pendingResponses) > 0 {
+		flushPending(seq)
+	}
-	flushPending(seq)
 	seq.doneReason = reason
+
 	close(seq.responses)
 	close(seq.embedding)
-	seq.cache.InUse = false
-	s.seqs[seqIndex] = nil
 	s.seqsSem.Release(1)
 }
 
@@ -431,6 +448,11 @@ func (s *Server) processBatch() error {
 				break
 			}
 
+			if !seq.shiftContext {
+				s.removeSequence(seqIdx, llm.DoneReasonContextShift)
+				continue
+			}
+
 			err := s.cache.ShiftCacheSlot(seq.cache, seq.numKeep)
 			if err != nil {
 				var reprocess *ErrReprocessInputs
@@ -629,11 +651,12 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 	)
 
 	seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{
-		numPredict: req.Options.NumPredict,
-		stop:       req.Options.Stop,
-		numKeep:    int32(req.Options.NumKeep),
-		sampler:    sampler,
-		embedding:  false,
+		numPredict:         req.Options.NumPredict,
+		stop:               req.Options.Stop,
+		numKeep:            int32(req.Options.NumKeep),
+		sampler:            sampler,
+		embedding:          false,
+		enableContextShift: req.Options.ShiftContext,
 	})
 	if err != nil {
 		http.Error(w, fmt.Sprintf("Failed to create new sequence: %v", err), http.StatusInternalServerError)
diff --git a/runner/ollamarunner/runner_test.go b/runner/ollamarunner/runner_test.go
new file mode 100644
index 000000000..5fb8d1a56
--- /dev/null
+++ b/runner/ollamarunner/runner_test.go
@@ -0,0 +1,167 @@
+package ollamarunner
+
+import (
+	"testing"
+
+	"github.com/ollama/ollama/llm"
+)
+
+func TestEnableContextShiftLogic(t *testing.T) {
+	tests := []struct {
+		name               string
+		enableContextShift bool
+		contextLength      int32
+		cacheInputs        int
+		pendingInputs      int
+		minBatch           int
+		expectedDoneReason llm.DoneReason
+		shouldRemove       bool
+	}{
+		{
+			name:               "context shifting enabled - should shift",
+			enableContextShift: true,
+			contextLength:      100,
+			cacheInputs:        80,
+			pendingInputs:      0,
+			minBatch:           30,
+			expectedDoneReason: llm.DoneReasonStop,
+			shouldRemove:       false,
+		},
+		{
+			name:               "context shifting disabled - should remove with DoneReasonContextShift",
+			enableContextShift: false,
+			contextLength:      100,
+			cacheInputs:        80,
+			pendingInputs:      0,
+			minBatch:           30,
+			expectedDoneReason: llm.DoneReasonContextShift,
+			shouldRemove:       true,
+		},
+		{
+			name:               "context shifting disabled - within limits",
+			enableContextShift: false,
+			contextLength:      100,
+			cacheInputs:        50,
+			pendingInputs:      0,
+			minBatch:           30,
+			expectedDoneReason: llm.DoneReasonStop,
+			shouldRemove:       false,
+		},
+		{
+			name:               "context shifting disabled - exact limit",
+			enableContextShift: false,
+			contextLength:      100,
+			cacheInputs:        100,
+			pendingInputs:      0,
+			minBatch:           1,
+			expectedDoneReason: llm.DoneReasonContextShift,
+			shouldRemove:       true,
+		},
+		{
+			name:               "pending inputs - should break batch",
+			enableContextShift: true,
+			contextLength:      100,
+			cacheInputs:        50,
+			pendingInputs:      20,
+			minBatch:           30,
+			expectedDoneReason: llm.DoneReasonStop,
+			shouldRemove:       false,
+		},
+		{
+			name:               "no pending inputs - should shift",
+			enableContextShift: true,
+			contextLength:      100,
+			cacheInputs:        80,
+			pendingInputs:      0,
+			minBatch:           30,
+			expectedDoneReason: llm.DoneReasonStop,
+			shouldRemove:       false,
+		},
+		{
+			name:               "long prompt with context shifting disabled - will be handled at runtime",
+			enableContextShift: false,
+			contextLength:      100,
+			cacheInputs:        0,
+			pendingInputs:      0,
+			minBatch:           150, // Simulates a long prompt
+			expectedDoneReason: llm.DoneReasonContextShift,
+			shouldRemove:       true,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			// Test the core logic from processBatch - matches actual implementation
+			if int32(tt.cacheInputs+tt.pendingInputs+tt.minBatch) > tt.contextLength {
+				if tt.pendingInputs != 0 {
+					// Should break batch - don't remove sequence
+					if tt.shouldRemove {
+						t.Error("should not remove sequence when pending inputs exist")
+					}
+				} else if !tt.enableContextShift {
+					// Should remove with DoneReasonContextShift
+					if !tt.shouldRemove {
+						t.Error("should remove sequence when context shifting disabled")
+					}
+					if tt.expectedDoneReason != llm.DoneReasonContextShift {
+						t.Errorf("expected DoneReason %v, got %v", llm.DoneReasonContextShift, tt.expectedDoneReason)
+					}
+				} else {
+					// Should shift context - don't remove sequence
+					if tt.shouldRemove {
+						t.Error("should not remove sequence when context shifting enabled")
+					}
+				}
+			} else {
+				// Within limits - should not remove
+				if tt.shouldRemove {
+					t.Errorf("should not remove sequence when within context limits")
+				}
+			}
+		})
+	}
+}
+
+func TestPredictLimitLogic(t *testing.T) {
+	tests := []struct {
+		name         string
+		numPredict   int
+		numPredicted int
+		expectRemove bool
+	}{
+		{
+			name:         "predict limit not reached",
+			numPredict:   5,
+			numPredicted: 3,
+			expectRemove: false,
+		},
+		{
+			name:         "predict limit reached",
+			numPredict:   5,
+			numPredicted: 5,
+			expectRemove: true,
+		},
+		{
+			name:         "predict limit exceeded",
+			numPredict:   5,
+			numPredicted: 6,
+			expectRemove: true,
+		},
+		{
+			name:         "no predict limit",
+			numPredict:   0,
+			numPredicted: 100,
+			expectRemove: false,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			// Test the core logic from processBatch
+			shouldRemove := tt.numPredict > 0 && tt.numPredicted >= tt.numPredict
+			if shouldRemove != tt.expectRemove {
+				t.Errorf("expected remove=%v, got %v", tt.expectRemove, shouldRemove)
+			}
+		})
+	}
+}
diff --git a/server/prompt.go b/server/prompt.go
index f8c895d71..55fb61e9d 100644
--- a/server/prompt.go
+++ b/server/prompt.go
@@ -63,6 +63,9 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
 		}
 
 		if ctxLen > opts.NumCtx {
+			if !opts.ShiftContext {
+				return "", nil, fmt.Errorf("context length of %d tokens exceeded, context shifting is disabled", opts.NumCtx)
+			}
 			slog.Debug("truncating input messages which exceed context length", "truncated", len(msgs[i:]))
 			break
 		} else {
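
A hedged sketch of the new chatPrompt guard from a caller's perspective: an oversized chat history with shift_context disabled should surface the error above through the API rather than being silently truncated. This assumes the existing api package chat flow; the model name and sizes are placeholders.

package main

import (
    "context"
    "fmt"
    "log"
    "strings"

    "github.com/ollama/ollama/api"
)

func main() {
    client, err := api.ClientFromEnvironment()
    if err != nil {
        log.Fatal(err)
    }

    req := &api.ChatRequest{
        Model: "llama3.2", // placeholder model name
        Messages: []api.Message{
            // Deliberately oversized input to trip the new guard
            {Role: "user", Content: strings.Repeat("a very long prompt ", 4096)},
        },
        Options: map[string]any{
            "num_ctx":       2048,
            "shift_context": false,
        },
    }

    err = client.Chat(context.Background(), req, func(resp api.ChatResponse) error {
        fmt.Print(resp.Message.Content)
        return nil
    })
    if err != nil {
        // With shift_context=false, oversized histories are rejected up front, e.g.:
        //   context length of 2048 tokens exceeded, context shifting is disabled
        log.Fatal(err)
    }
}
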
", - images: [][]byte{ - []byte("something"), - }, + error: fmt.Errorf("context length of 64 tokens exceeded, context shifting is disabled"), }, }, { @@ -85,10 +83,7 @@ func TestChatPrompt(t *testing.T) { {Role: "user", Content: "A test. And a thumping good one at that, I'd wager.", Images: []api.ImageData{[]byte("somethingelse")}}, }, expect: expect{ - prompt: "[img-0]A test. And a thumping good one at that, I'd wager. ", - images: [][]byte{ - []byte("somethingelse"), - }, + error: fmt.Errorf("context length of 64 tokens exceeded, context shifting is disabled"), }, }, { @@ -156,10 +151,7 @@ func TestChatPrompt(t *testing.T) { {Role: "user", Content: "A test. And a thumping good one at that, I'd wager."}, }, expect: expect{ - prompt: "[img-0] I-I'm a what? A test. And a thumping good one at that, I'd wager. ", - images: [][]byte{ - []byte("somethingelse"), - }, + error: fmt.Errorf("context length of 1024 tokens exceeded, context shifting is disabled"), }, }, { @@ -208,12 +200,25 @@ func TestChatPrompt(t *testing.T) { t.Run(tt.name, func(t *testing.T) { model := tt.model opts := api.Options{Runner: api.Runner{NumCtx: tt.limit}} + + // For truncation tests, disable context shifting to test the truncation behavior + if tt.name == "truncate messages" || + tt.name == "truncate messages with image" || + tt.name == "truncate messages with images" || + tt.name == "truncate message with interleaved images" { + opts.ShiftContext = false + } + think := false prompt, images, err := chatPrompt(t.Context(), &model, mockRunner{}.Tokenize, &opts, tt.msgs, nil, &think) if tt.error == nil && err != nil { t.Fatal(err) - } else if tt.error != nil && err != tt.error { - t.Fatalf("expected err '%q', got '%q'", tt.error, err) + } else if tt.error != nil && err != nil { + if err.Error() != tt.error.Error() { + t.Fatalf("expected err '%q', got '%q'", tt.error, err) + } + } else if tt.error != nil && err == nil { + t.Fatalf("expected err '%q', got nil", tt.error) } if diff := cmp.Diff(prompt, tt.prompt); diff != "" { diff --git a/server/routes_test.go b/server/routes_test.go index 7c44bc957..aa6121ed9 100644 --- a/server/routes_test.go +++ b/server/routes_test.go @@ -5,6 +5,7 @@ import ( "context" "encoding/binary" "encoding/json" + "errors" "fmt" "io" "io/fs" @@ -25,6 +26,7 @@ import ( "github.com/google/go-cmp/cmp" "github.com/ollama/ollama/api" "github.com/ollama/ollama/fs/ggml" + "github.com/ollama/ollama/llm" "github.com/ollama/ollama/openai" "github.com/ollama/ollama/server/internal/client/ollama" "github.com/ollama/ollama/types/model" @@ -968,3 +970,154 @@ func TestWaitForStream(t *testing.T) { }) } } + +func TestEnableContextShiftNonStreamingResponse(t *testing.T) { + tests := []struct { + name string + enableContextShift bool + responses []llm.CompletionResponse + expectedDone bool + expectedDoneReason string + }{ + { + name: "context shifting disabled - should have DoneReasonLength", + enableContextShift: false, + responses: []llm.CompletionResponse{ + {Content: "Hello", Done: false}, + {Content: " world", Done: false}, + {Content: "", Done: true, DoneReason: llm.DoneReasonLength}, + }, + expectedDone: true, + expectedDoneReason: "length", + }, + { + name: "context shifting enabled - should have DoneReasonStop", + enableContextShift: true, + responses: []llm.CompletionResponse{ + {Content: "Hello", Done: false}, + {Content: " world", Done: false}, + {Content: "", Done: true, DoneReason: llm.DoneReasonStop}, + }, + expectedDone: true, + expectedDoneReason: "stop", + }, + { + name: "no final response with 
Done=true", + enableContextShift: false, + responses: []llm.CompletionResponse{ + {Content: "Hello", Done: false}, + {Content: " world", Done: false}, + }, + expectedDone: false, + expectedDoneReason: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // The last response in the channel will naturally be the final state + lastResponse := tt.responses[len(tt.responses)-1] + + if lastResponse.Done != tt.expectedDone { + t.Errorf("expected Done=%v, got %v", tt.expectedDone, lastResponse.Done) + } + + if tt.expectedDoneReason != "" { + if lastResponse.DoneReason.String() != tt.expectedDoneReason { + t.Errorf("expected DoneReason=%s, got %s", tt.expectedDoneReason, lastResponse.DoneReason.String()) + } + } + }) + } +} + +func TestHandleScheduleError(t *testing.T) { + tests := []struct { + name string + errorMessage string + expectedStatus int + }{ + { + name: "context length exceeded error", + errorMessage: "context length of 100 tokens exceeded, context shifting is disabled", + expectedStatus: http.StatusInternalServerError, + }, + { + name: "other error", + errorMessage: "some other error", + expectedStatus: http.StatusInternalServerError, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + + err := errors.New(tt.errorMessage) + + handleScheduleError(c, "test-model", err) + + if w.Code != tt.expectedStatus { + t.Errorf("expected status %d, got %d", tt.expectedStatus, w.Code) + } + + var response map[string]any + if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil { + t.Fatalf("failed to unmarshal response: %v", err) + } + + if errorMsg, ok := response["error"].(string); !ok || errorMsg != tt.errorMessage { + t.Errorf("expected error message '%s', got '%s'", tt.errorMessage, errorMsg) + } + }) + } +} + +func TestEnableContextShiftOptions(t *testing.T) { + t.Run("default options have enableContextShift=true", func(t *testing.T) { + opts := api.DefaultOptions() + if !opts.ShiftContext { + t.Errorf("expected EnableContextShift=true by default, got %v", opts.ShiftContext) + } + }) + + t.Run("can set enableContextShift to false", func(t *testing.T) { + opts := api.DefaultOptions() + opts.ShiftContext = false + if opts.ShiftContext { + t.Errorf("expected EnableContextShift=false after setting, got %v", opts.ShiftContext) + } + }) + + t.Run("JSON serialization omits false values", func(t *testing.T) { + opts := api.DefaultOptions() + opts.ShiftContext = false + + data, err := json.Marshal(opts) + if err != nil { + t.Fatalf("failed to marshal options: %v", err) + } + + // Check that enable_context_shift is not in the JSON when false + if bytes.Contains(data, []byte("enable_context_shift")) { + t.Errorf("expected enable_context_shift to be omitted from JSON when false, but found it in: %s", string(data)) + } + }) + + t.Run("JSON serialization includes true values", func(t *testing.T) { + opts := api.DefaultOptions() + opts.ShiftContext = true + + data, err := json.Marshal(opts) + if err != nil { + t.Fatalf("failed to marshal options: %v", err) + } + + // Check that enable_context_shift is in the JSON when true + if !bytes.Contains(data, []byte("enable_context_shift")) { + t.Errorf("expected enable_context_shift to be in JSON when true, but not found in: %s", string(data)) + } + }) +}