diff --git a/llm/server.go b/llm/server.go index 373f6faef..2df701783 100644 --- a/llm/server.go +++ b/llm/server.go @@ -22,6 +22,7 @@ import ( "strings" "sync" "time" + "unicode/utf8" "golang.org/x/sync/semaphore" @@ -725,10 +726,68 @@ type CompletionResponse struct { EvalDuration time.Duration `json:"eval_duration"` } +// unicodeBufferHandler wraps a completion response callback to handle partial UTF-8 sequences. +// This function creates a stateful closure that is NOT safe for concurrent use. +// Each completion request should create its own handler instance. +func unicodeBufferHandler(fn func(CompletionResponse)) func(CompletionResponse) { + var pendingUTF8 string + + return func(resp CompletionResponse) { + if resp.Content == "" && !resp.Done { + // No content to process, just pass through + fn(resp) + return + } + + // Combine any pending UTF-8 with current content + combinedContent := pendingUTF8 + resp.Content + pendingUTF8 = "" + + // Check if combined content is valid UTF-8 + if utf8.ValidString(combinedContent) { + // Valid UTF-8, send it + resp.Content = combinedContent + fn(resp) + } else { + // Invalid UTF-8 + if resp.Done { + // This is the final response, trim incomplete UTF-8 + trimmedContent := combinedContent + for !utf8.ValidString(trimmedContent) && len(trimmedContent) > 0 { + trimmedContent = trimmedContent[:len(trimmedContent)-1] + } + resp.Content = trimmedContent + fn(resp) + } else { + // Not final response, split valid and invalid parts + validPrefix := combinedContent + for !utf8.ValidString(validPrefix) && len(validPrefix) > 0 { + validPrefix = validPrefix[:len(validPrefix)-1] + } + + if len(validPrefix) > 0 { + // Send valid prefix + resp.Content = validPrefix + fn(resp) + // Buffer the remainder + pendingUTF8 = combinedContent[len(validPrefix):] + } else { + // No valid prefix, buffer everything + pendingUTF8 = combinedContent + // Don't send this response + } + } + } + } +} + func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error { slog.Debug("completion request", "images", len(req.Images), "prompt", len(req.Prompt), "format", string(req.Format)) slog.Log(ctx, logutil.LevelTrace, "completion request", "prompt", req.Prompt) + // Wrap the callback with unicode buffer handling + unicodeFn := unicodeBufferHandler(fn) + if len(req.Format) > 0 { switch string(req.Format) { case `null`, `""`: @@ -854,13 +913,13 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu } if c.Content != "" { - fn(CompletionResponse{ + unicodeFn(CompletionResponse{ Content: c.Content, }) } if c.Done { - fn(c) + unicodeFn(c) return nil } } diff --git a/llm/server_test.go b/llm/server_test.go index b6a8705e5..00f02ea52 100644 --- a/llm/server_test.go +++ b/llm/server_test.go @@ -70,3 +70,152 @@ func TestLLMServerCompletionFormat(t *testing.T) { }, nil) checkValid(err) } + +func TestUnicodeBufferHandler(t *testing.T) { + tests := []struct { + name string + inputResponses []CompletionResponse + expectedResponses []CompletionResponse + description string + }{ + { + name: "complete_unicode", + inputResponses: []CompletionResponse{ + {Content: "Hello", Done: false}, + {Content: " world", Done: false}, + {Content: "!", Done: true}, + }, + expectedResponses: []CompletionResponse{ + {Content: "Hello", Done: false}, + {Content: " world", Done: false}, + {Content: "!", Done: true}, + }, + description: "All responses with valid unicode should pass through unchanged", + }, + { + name: "incomplete_unicode_at_end_with_done", + inputResponses: []CompletionResponse{ + {Content: "Hello", Done: false}, + {Content: string([]byte{0xF0, 0x9F}), Done: true}, // Incomplete emoji with Done=true + }, + expectedResponses: []CompletionResponse{ + {Content: "Hello", Done: false}, + {Content: "", Done: true}, // Content is trimmed but response is still sent with Done=true + }, + description: "When Done=true, incomplete Unicode at the end should be trimmed", + }, + { + name: "split_unicode_across_responses", + inputResponses: []CompletionResponse{ + {Content: "Hello " + string([]byte{0xF0, 0x9F}), Done: false}, // First part of 😀 + {Content: string([]byte{0x98, 0x80}) + " world!", Done: true}, // Second part of 😀 and more text + }, + expectedResponses: []CompletionResponse{ + {Content: "Hello ", Done: false}, // Incomplete Unicode trimmed + {Content: "😀 world!", Done: true}, // Complete emoji in second response + }, + description: "Unicode split across responses should be handled correctly", + }, + { + name: "incomplete_unicode_buffered", + inputResponses: []CompletionResponse{ + {Content: "Test " + string([]byte{0xF0, 0x9F}), Done: false}, // Incomplete emoji + {Content: string([]byte{0x98, 0x80}), Done: false}, // Complete the emoji + {Content: " done", Done: true}, + }, + expectedResponses: []CompletionResponse{ + {Content: "Test ", Done: false}, // First part without incomplete unicode + {Content: "😀", Done: false}, // Complete emoji + {Content: " done", Done: true}, + }, + description: "Incomplete unicode should be buffered and combined with next response", + }, + { + name: "empty_response_with_done", + inputResponses: []CompletionResponse{ + {Content: "Complete response", Done: false}, + {Content: "", Done: true}, // Empty response with Done=true + }, + expectedResponses: []CompletionResponse{ + {Content: "Complete response", Done: false}, + {Content: "", Done: true}, // Should still be sent because Done=true + }, + description: "Empty final response with Done=true should still be sent", + }, + { + name: "done_reason_preserved", + inputResponses: []CompletionResponse{ + {Content: "Response", Done: false}, + {Content: " complete", Done: true, DoneReason: DoneReasonStop}, + }, + expectedResponses: []CompletionResponse{ + {Content: "Response", Done: false}, + {Content: " complete", Done: true, DoneReason: DoneReasonStop}, + }, + description: "DoneReason should be preserved in the final response", + }, + { + name: "only_incomplete_unicode_not_done", + inputResponses: []CompletionResponse{ + {Content: string([]byte{0xF0, 0x9F}), Done: false}, // Only incomplete unicode + }, + expectedResponses: []CompletionResponse{ + // No response expected - should be buffered + }, + description: "Response with only incomplete unicode should be buffered if not done", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var actualResponses []CompletionResponse + + // Create a callback that collects responses + callback := func(resp CompletionResponse) { + actualResponses = append(actualResponses, resp) + } + + // Create the unicode buffer handler + handler := unicodeBufferHandler(callback) + + // Send all input responses through the handler + for _, resp := range tt.inputResponses { + handler(resp) + } + + // Verify the number of responses + if len(actualResponses) != len(tt.expectedResponses) { + t.Fatalf("%s: got %d responses, want %d responses", + tt.description, len(actualResponses), len(tt.expectedResponses)) + } + + // Verify each response matches the expected one + for i, expected := range tt.expectedResponses { + if i >= len(actualResponses) { + t.Fatalf("%s: missing response at index %d", tt.description, i) + continue + } + + actual := actualResponses[i] + + // Verify content + if actual.Content != expected.Content { + t.Errorf("%s: response[%d].Content = %q, want %q", + tt.description, i, actual.Content, expected.Content) + } + + // Verify Done flag + if actual.Done != expected.Done { + t.Errorf("%s: response[%d].Done = %v, want %v", + tt.description, i, actual.Done, expected.Done) + } + + // Verify DoneReason if specified + if actual.DoneReason != expected.DoneReason { + t.Errorf("%s: response[%d].DoneReason = %v, want %v", + tt.description, i, actual.DoneReason, expected.DoneReason) + } + } + }) + } +} diff --git a/runner/common/stop.go b/runner/common/stop.go index 3f27a286e..23d9b92c3 100644 --- a/runner/common/stop.go +++ b/runner/common/stop.go @@ -2,6 +2,8 @@ package common import ( "strings" + + "github.com/ollama/ollama/llm" ) func FindStop(sequence string, stops []string) (bool, string) { @@ -29,68 +31,41 @@ func ContainsStopSuffix(sequence string, stops []string) bool { // truncateStop removes the provided stop string from pieces, // returning the partial pieces with stop removed, including truncating // the last piece if required (and signalling if this was the case) -func TruncateStop(pieces []string, stop string) ([]string, bool) { - joined := strings.Join(pieces, "") - - index := strings.Index(joined, stop) - if index == -1 { - return pieces, false +func TruncateStop(resps []llm.CompletionResponse, stop string) ([]llm.CompletionResponse, bool) { + var sequence string + for _, resp := range resps { + sequence += resp.Content } - joined = joined[:index] - - // Split truncated string back into pieces of original lengths - lengths := make([]int, len(pieces)) - for i, piece := range pieces { - lengths[i] = len(piece) + idx := strings.Index(sequence, stop) + if idx < 0 { + return resps, false } - var result []string - tokenTruncated := false - start := 0 - for _, length := range lengths { - if start >= len(joined) { + truncated := sequence[:idx] + if len(truncated) == 0 { + return nil, true + } + + result := make([]llm.CompletionResponse, 0, len(resps)) + + // Track position in truncated sequence + pos := 0 + truncationHappened := false + for _, resp := range resps { + if pos >= len(truncated) { break } - end := start + length - if end > len(joined) { - end = len(joined) - tokenTruncated = true + chunk := truncated[pos:min(pos+len(resp.Content), len(truncated))] + if len(chunk) < len(resp.Content) { + truncationHappened = true } - result = append(result, joined[start:end]) - start = end + if len(chunk) > 0 { + result = append(result, llm.CompletionResponse{Content: chunk}) + } + pos += len(resp.Content) } - return result, tokenTruncated -} - -func IncompleteUnicode(token string) bool { - incomplete := false - - // check if there is incomplete UTF-8 character at the end - for i := 1; i < 5 && i <= len(token); i++ { - c := token[len(token)-i] - - if (c & 0xc0) == 0x80 { - // continuation byte: 10xxxxxx - continue - } - - if (c & 0xe0) == 0xc0 { - // 2-byte character: 110xxxxx ... - incomplete = i < 2 - } else if (c & 0xf0) == 0xe0 { - // 3-byte character: 1110xxxx ... - incomplete = i < 3 - } else if (c & 0xf8) == 0xf0 { - // 4-byte character: 11110xxx ... - incomplete = i < 4 - } - - // else 1-byte character or invalid byte - break - } - - return incomplete + return result, truncationHappened } diff --git a/runner/common/stop_test.go b/runner/common/stop_test.go index 8df267eb4..de6cef841 100644 --- a/runner/common/stop_test.go +++ b/runner/common/stop_test.go @@ -1,51 +1,84 @@ package common import ( + "fmt" "reflect" "testing" + + "github.com/ollama/ollama/llm" ) func TestTruncateStop(t *testing.T) { tests := []struct { name string - pieces []string + pieces []llm.CompletionResponse stop string - expected []string + expected []llm.CompletionResponse expectedTrunc bool }{ { - name: "Single word", - pieces: []string{"hello", "world"}, - stop: "world", - expected: []string{"hello"}, + name: "Single word", + pieces: []llm.CompletionResponse{ + {Content: "Hello"}, + {Content: "world"}, + }, + stop: "world", + expected: []llm.CompletionResponse{ + {Content: "Hello"}, + }, expectedTrunc: false, }, { - name: "Partial", - pieces: []string{"hello", "wor"}, - stop: "or", - expected: []string{"hello", "w"}, + name: "Partial", + pieces: []llm.CompletionResponse{ + {Content: "Hello"}, + {Content: " wor"}, + }, + stop: "or", + expected: []llm.CompletionResponse{ + {Content: "Hello"}, + {Content: " w"}, + }, expectedTrunc: true, }, { - name: "Suffix", - pieces: []string{"Hello", " there", "!"}, - stop: "!", - expected: []string{"Hello", " there"}, + name: "Suffix", + pieces: []llm.CompletionResponse{ + {Content: "Hello"}, + {Content: " there"}, + {Content: "!"}, + }, + stop: "!", + expected: []llm.CompletionResponse{ + {Content: "Hello"}, + {Content: " there"}, + }, expectedTrunc: false, }, { - name: "Suffix partial", - pieces: []string{"Hello", " the", "re!"}, - stop: "there!", - expected: []string{"Hello", " "}, + name: "Suffix partial", + pieces: []llm.CompletionResponse{ + {Content: "Hello"}, + {Content: " the"}, + {Content: "re!"}, + }, + stop: "there!", + expected: []llm.CompletionResponse{ + {Content: "Hello"}, + {Content: " "}, + }, expectedTrunc: true, }, { - name: "Middle", - pieces: []string{"hello", " wor"}, - stop: "llo w", - expected: []string{"he"}, + name: "Middle", + pieces: []llm.CompletionResponse{ + {Content: "Hello"}, + {Content: " wo"}, + }, + stop: "llo w", + expected: []llm.CompletionResponse{ + {Content: "He"}, + }, expectedTrunc: true, }, } @@ -54,76 +87,23 @@ func TestTruncateStop(t *testing.T) { t.Run(tt.name, func(t *testing.T) { result, resultTrunc := TruncateStop(tt.pieces, tt.stop) if !reflect.DeepEqual(result, tt.expected) || resultTrunc != tt.expectedTrunc { - t.Errorf("truncateStop(%v, %s): have %v (%v); want %v (%v)", tt.pieces, tt.stop, result, resultTrunc, tt.expected, tt.expectedTrunc) + t.Errorf("truncateStop(%v, %v):\n%shave truncated %v\nwant truncated %v", + tt.pieces, tt.stop, formatContentDiff(result, tt.expected), resultTrunc, tt.expectedTrunc) } }) } } -func TestIncompleteUnicode(t *testing.T) { - tests := []struct { - name string - input string - expected bool - }{ - { - name: "Basic", - input: "hi", - expected: false, - }, - { - name: "Two byte", - input: "hi" + string([]byte{0xc2, 0xa3}), - expected: false, - }, - { - name: "Two byte - missing last", - input: "hi" + string([]byte{0xc2}), - expected: true, - }, - { - name: "Three byte", - input: "hi" + string([]byte{0xe0, 0xA0, 0x80}), - expected: false, - }, - { - name: "Three byte - missing last", - input: "hi" + string([]byte{0xe0, 0xA0}), - expected: true, - }, - { - name: "Three byte - missing last 2", - input: "hi" + string([]byte{0xe0}), - expected: true, - }, - { - name: "Four byte", - input: "hi" + string([]byte{0xf0, 0x92, 0x8a, 0xb7}), - expected: false, - }, - { - name: "Four byte - missing last", - input: "hi" + string([]byte{0xf0, 0x92, 0x8a}), - expected: true, - }, - { - name: "Four byte - missing last 2", - input: "hi" + string([]byte{0xf0, 0x92}), - expected: true, - }, - { - name: "Four byte - missing last 3", - input: "hi" + string([]byte{0xf0}), - expected: true, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result := IncompleteUnicode(tt.input) - if result != tt.expected { - t.Errorf("incompleteUnicode(%s): have %v; want %v", tt.input, result, tt.expected) - } - }) +func formatContentDiff(result, expected []llm.CompletionResponse) string { + var s string + for i := 0; i < len(result) || i < len(expected); i++ { + if i < len(result) && i < len(expected) && result[i].Content != expected[i].Content { + s += fmt.Sprintf("[%d] %q vs %q\n", i, result[i].Content, expected[i].Content) + } else if i < len(result) && i >= len(expected) { + s += fmt.Sprintf("[%d] extra %q\n", i, result[i].Content) + } else if i >= len(result) && i < len(expected) { + s += fmt.Sprintf("[%d] missing %q\n", i, expected[i].Content) + } } + return s } diff --git a/runner/llamarunner/runner.go b/runner/llamarunner/runner.go index 7aa9b96a2..69efac83d 100644 --- a/runner/llamarunner/runner.go +++ b/runner/llamarunner/runner.go @@ -17,7 +17,6 @@ import ( "strings" "sync" "time" - "unicode/utf8" "golang.org/x/sync/semaphore" @@ -52,13 +51,13 @@ type Sequence struct { pendingInputs []input // tokens that have been generated but not returned yet (e.g. for stop sequences) - pendingResponses []string + pendingResponses []llm.CompletionResponse // input cache being used by this sequence cache *InputCacheSlot // channel to send responses over - responses chan string + responses chan llm.CompletionResponse // channel to stop decoding (such as if the remote connection is closed) quit chan bool @@ -89,6 +88,19 @@ type Sequence struct { numPromptInputs int } +func (seq *Sequence) send(resp llm.CompletionResponse) bool { + if len(resp.Content) > 0 || resp.Done { + select { + case seq.responses <- resp: + // Successfully sent + return true + case <-seq.quit: + return false + } + } + return true +} + type NewSequenceParams struct { numPredict int stop []string @@ -147,8 +159,8 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe numPromptInputs: len(inputs), startProcessingTime: startTime, numPredict: params.numPredict, - pendingResponses: make([]string, 0), - responses: make(chan string, 100), + pendingResponses: make([]llm.CompletionResponse, 0), + responses: make(chan llm.CompletionResponse, 100), quit: make(chan bool, 1), embedding: make(chan []float32, 1), samplingCtx: sc, @@ -272,36 +284,15 @@ func (s *Server) allNil() bool { return true } -func flushPending(seq *Sequence) bool { - joined := strings.Join(seq.pendingResponses, "") - seq.pendingResponses = []string{} - - // Check if there are any partial UTF-8 characters remaining. - // We already check and queue as we are generating but some may - // still make it here: - // - Sequence is ending, e.g. generation limit has been hit - // - Invalid characters in the middle of a string - // This is a stricter check to ensure we never output invalid Unicode. - for !utf8.ValidString(joined) { - joined = joined[:len(joined)-1] - } - - if len(joined) == 0 { - return true - } - - select { - case seq.responses <- joined: - return true - case <-seq.quit: - return false - } -} - func (s *Server) removeSequence(seqIndex int, reason llm.DoneReason) { seq := s.seqs[seqIndex] - flushPending(seq) + // Send any remaining pending responses + for _, resp := range seq.pendingResponses { + seq.send(resp) + } + seq.pendingResponses = []llm.CompletionResponse{} + seq.doneReason = reason close(seq.responses) close(seq.embedding) @@ -490,8 +481,11 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch) seq.inputs = []input{{token: token}} - seq.pendingResponses = append(seq.pendingResponses, piece) - sequence := strings.Join(seq.pendingResponses, "") + seq.pendingResponses = append(seq.pendingResponses, llm.CompletionResponse{Content: piece}) + sequence := "" + for _, r := range seq.pendingResponses { + sequence += r.Content + } if ok, stop := common.FindStop(sequence, seq.stop); ok { slog.Debug("hit stop token", "pending", seq.pendingResponses, "stop", stop) @@ -523,13 +517,13 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch) continue } - if common.IncompleteUnicode(sequence) { - continue - } - - if !flushPending(seq) { - s.removeSequence(i, llm.DoneReasonConnectionClosed) + for _, resp := range seq.pendingResponses { + if !seq.send(resp) { + s.removeSequence(i, llm.DoneReasonConnectionClosed) + break + } } + seq.pendingResponses = []llm.CompletionResponse{} } return nil @@ -627,9 +621,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) { return case content, ok := <-seq.responses: if ok { - if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{ - Content: content, - }); err != nil { + if err := json.NewEncoder(w).Encode(&content); err != nil { http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError) close(seq.quit) return diff --git a/runner/ollamarunner/runner.go b/runner/ollamarunner/runner.go index a7a889f1f..a9d03f547 100644 --- a/runner/ollamarunner/runner.go +++ b/runner/ollamarunner/runner.go @@ -20,7 +20,6 @@ import ( "strings" "sync" "time" - "unicode/utf8" "golang.org/x/image/bmp" "golang.org/x/sync/semaphore" @@ -56,13 +55,13 @@ type Sequence struct { pendingInputs []input.Input // tokens that have been generated but not returned yet (e.g. for stop sequences) - pendingResponses []string + pendingResponses []llm.CompletionResponse // input cache being used by this sequence cache *InputCacheSlot // channel to send responses over - responses chan string + responses chan llm.CompletionResponse // channel to stop decoding (such as if the remote connection is closed) quit chan bool @@ -94,6 +93,19 @@ type Sequence struct { numPromptInputs int } +func (seq *Sequence) send(resp llm.CompletionResponse) bool { + if len(resp.Content) > 0 || resp.Done { + select { + case seq.responses <- resp: + // Successfully sent + return true + case <-seq.quit: + return false + } + } + return true +} + type NewSequenceParams struct { numPredict int stop []string @@ -167,8 +179,8 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe numPromptInputs: len(inputs), startProcessingTime: startTime, numPredict: params.numPredict, - pendingResponses: make([]string, 0), - responses: make(chan string, 100), + pendingResponses: make([]llm.CompletionResponse, 0), + responses: make(chan llm.CompletionResponse, 100), quit: make(chan bool, 1), embedding: make(chan []float32, 1), sampler: params.sampler, @@ -313,36 +325,15 @@ func (s *Server) allNil() bool { return true } -func flushPending(seq *Sequence) bool { - joined := strings.Join(seq.pendingResponses, "") - seq.pendingResponses = []string{} - - // Check if there are any partial UTF-8 characters remaining. - // We already check and queue as we are generating but some may - // still make it here: - // - Sequence is ending, e.g. generation limit has been hit - // - Invalid characters in the middle of a string - // This is a stricter check to ensure we never output invalid Unicode. - for !utf8.ValidString(joined) { - joined = joined[:len(joined)-1] - } - - if len(joined) == 0 { - return true - } - - select { - case seq.responses <- joined: - return true - case <-seq.quit: - return false - } -} - func (s *Server) removeSequence(seqIndex int, reason llm.DoneReason) { seq := s.seqs[seqIndex] - flushPending(seq) + // Send any remaining pending responses + for _, resp := range seq.pendingResponses { + seq.send(resp) + } + seq.pendingResponses = []llm.CompletionResponse{} + seq.doneReason = reason close(seq.responses) close(seq.embedding) @@ -541,8 +532,11 @@ func (s *Server) processBatch() error { seq.inputs = []input.Input{{Token: token}} - seq.pendingResponses = append(seq.pendingResponses, piece) - sequence := strings.Join(seq.pendingResponses, "") + seq.pendingResponses = append(seq.pendingResponses, llm.CompletionResponse{Content: piece}) + sequence := "" + for _, r := range seq.pendingResponses { + sequence += r.Content + } if ok, stop := common.FindStop(sequence, seq.stop); ok { slog.Debug("hit stop token", "pending", seq.pendingResponses, "stop", stop) @@ -574,13 +568,14 @@ func (s *Server) processBatch() error { continue } - if common.IncompleteUnicode(sequence) { - continue - } - - if !flushPending(seq) { - s.removeSequence(i, llm.DoneReasonConnectionClosed) + // Send all pending responses directly without unicode checking + for _, resp := range seq.pendingResponses { + if !seq.send(resp) { + s.removeSequence(i, llm.DoneReasonConnectionClosed) + break + } } + seq.pendingResponses = []llm.CompletionResponse{} } return nil @@ -683,9 +678,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) { return case content, ok := <-seq.responses: if ok { - if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{ - Content: content, - }); err != nil { + if err := json.NewEncoder(w).Encode(&content); err != nil { http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError) close(seq.quit) return