runner: enable returning more info from runner processing
Currently we return only the text predicted from the LLM. This was nice in that it was simple, but there may be other info we want to know from the processing. This change adds the ability to return more information from the runner than just the text predicted.
This commit is contained in:
parent
9f8a18ec05
commit
d5eae8248d
|
|
@ -22,6 +22,7 @@ import (
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
"unicode/utf8"
|
||||||
|
|
||||||
"golang.org/x/sync/semaphore"
|
"golang.org/x/sync/semaphore"
|
||||||
|
|
||||||
|
|
@ -725,10 +726,68 @@ type CompletionResponse struct {
|
||||||
EvalDuration time.Duration `json:"eval_duration"`
|
EvalDuration time.Duration `json:"eval_duration"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// unicodeBufferHandler wraps a completion response callback to handle partial UTF-8 sequences.
|
||||||
|
// This function creates a stateful closure that is NOT safe for concurrent use.
|
||||||
|
// Each completion request should create its own handler instance.
|
||||||
|
func unicodeBufferHandler(fn func(CompletionResponse)) func(CompletionResponse) {
|
||||||
|
var pendingUTF8 string
|
||||||
|
|
||||||
|
return func(resp CompletionResponse) {
|
||||||
|
if resp.Content == "" && !resp.Done {
|
||||||
|
// No content to process, just pass through
|
||||||
|
fn(resp)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Combine any pending UTF-8 with current content
|
||||||
|
combinedContent := pendingUTF8 + resp.Content
|
||||||
|
pendingUTF8 = ""
|
||||||
|
|
||||||
|
// Check if combined content is valid UTF-8
|
||||||
|
if utf8.ValidString(combinedContent) {
|
||||||
|
// Valid UTF-8, send it
|
||||||
|
resp.Content = combinedContent
|
||||||
|
fn(resp)
|
||||||
|
} else {
|
||||||
|
// Invalid UTF-8
|
||||||
|
if resp.Done {
|
||||||
|
// This is the final response, trim incomplete UTF-8
|
||||||
|
trimmedContent := combinedContent
|
||||||
|
for !utf8.ValidString(trimmedContent) && len(trimmedContent) > 0 {
|
||||||
|
trimmedContent = trimmedContent[:len(trimmedContent)-1]
|
||||||
|
}
|
||||||
|
resp.Content = trimmedContent
|
||||||
|
fn(resp)
|
||||||
|
} else {
|
||||||
|
// Not final response, split valid and invalid parts
|
||||||
|
validPrefix := combinedContent
|
||||||
|
for !utf8.ValidString(validPrefix) && len(validPrefix) > 0 {
|
||||||
|
validPrefix = validPrefix[:len(validPrefix)-1]
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(validPrefix) > 0 {
|
||||||
|
// Send valid prefix
|
||||||
|
resp.Content = validPrefix
|
||||||
|
fn(resp)
|
||||||
|
// Buffer the remainder
|
||||||
|
pendingUTF8 = combinedContent[len(validPrefix):]
|
||||||
|
} else {
|
||||||
|
// No valid prefix, buffer everything
|
||||||
|
pendingUTF8 = combinedContent
|
||||||
|
// Don't send this response
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error {
|
func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error {
|
||||||
slog.Debug("completion request", "images", len(req.Images), "prompt", len(req.Prompt), "format", string(req.Format))
|
slog.Debug("completion request", "images", len(req.Images), "prompt", len(req.Prompt), "format", string(req.Format))
|
||||||
slog.Log(ctx, logutil.LevelTrace, "completion request", "prompt", req.Prompt)
|
slog.Log(ctx, logutil.LevelTrace, "completion request", "prompt", req.Prompt)
|
||||||
|
|
||||||
|
// Wrap the callback with unicode buffer handling
|
||||||
|
unicodeFn := unicodeBufferHandler(fn)
|
||||||
|
|
||||||
if len(req.Format) > 0 {
|
if len(req.Format) > 0 {
|
||||||
switch string(req.Format) {
|
switch string(req.Format) {
|
||||||
case `null`, `""`:
|
case `null`, `""`:
|
||||||
|
|
@ -854,13 +913,13 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
|
||||||
}
|
}
|
||||||
|
|
||||||
if c.Content != "" {
|
if c.Content != "" {
|
||||||
fn(CompletionResponse{
|
unicodeFn(CompletionResponse{
|
||||||
Content: c.Content,
|
Content: c.Content,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
if c.Done {
|
if c.Done {
|
||||||
fn(c)
|
unicodeFn(c)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -70,3 +70,152 @@ func TestLLMServerCompletionFormat(t *testing.T) {
|
||||||
}, nil)
|
}, nil)
|
||||||
checkValid(err)
|
checkValid(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestUnicodeBufferHandler(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
inputResponses []CompletionResponse
|
||||||
|
expectedResponses []CompletionResponse
|
||||||
|
description string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "complete_unicode",
|
||||||
|
inputResponses: []CompletionResponse{
|
||||||
|
{Content: "Hello", Done: false},
|
||||||
|
{Content: " world", Done: false},
|
||||||
|
{Content: "!", Done: true},
|
||||||
|
},
|
||||||
|
expectedResponses: []CompletionResponse{
|
||||||
|
{Content: "Hello", Done: false},
|
||||||
|
{Content: " world", Done: false},
|
||||||
|
{Content: "!", Done: true},
|
||||||
|
},
|
||||||
|
description: "All responses with valid unicode should pass through unchanged",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "incomplete_unicode_at_end_with_done",
|
||||||
|
inputResponses: []CompletionResponse{
|
||||||
|
{Content: "Hello", Done: false},
|
||||||
|
{Content: string([]byte{0xF0, 0x9F}), Done: true}, // Incomplete emoji with Done=true
|
||||||
|
},
|
||||||
|
expectedResponses: []CompletionResponse{
|
||||||
|
{Content: "Hello", Done: false},
|
||||||
|
{Content: "", Done: true}, // Content is trimmed but response is still sent with Done=true
|
||||||
|
},
|
||||||
|
description: "When Done=true, incomplete Unicode at the end should be trimmed",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "split_unicode_across_responses",
|
||||||
|
inputResponses: []CompletionResponse{
|
||||||
|
{Content: "Hello " + string([]byte{0xF0, 0x9F}), Done: false}, // First part of 😀
|
||||||
|
{Content: string([]byte{0x98, 0x80}) + " world!", Done: true}, // Second part of 😀 and more text
|
||||||
|
},
|
||||||
|
expectedResponses: []CompletionResponse{
|
||||||
|
{Content: "Hello ", Done: false}, // Incomplete Unicode trimmed
|
||||||
|
{Content: "😀 world!", Done: true}, // Complete emoji in second response
|
||||||
|
},
|
||||||
|
description: "Unicode split across responses should be handled correctly",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "incomplete_unicode_buffered",
|
||||||
|
inputResponses: []CompletionResponse{
|
||||||
|
{Content: "Test " + string([]byte{0xF0, 0x9F}), Done: false}, // Incomplete emoji
|
||||||
|
{Content: string([]byte{0x98, 0x80}), Done: false}, // Complete the emoji
|
||||||
|
{Content: " done", Done: true},
|
||||||
|
},
|
||||||
|
expectedResponses: []CompletionResponse{
|
||||||
|
{Content: "Test ", Done: false}, // First part without incomplete unicode
|
||||||
|
{Content: "😀", Done: false}, // Complete emoji
|
||||||
|
{Content: " done", Done: true},
|
||||||
|
},
|
||||||
|
description: "Incomplete unicode should be buffered and combined with next response",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty_response_with_done",
|
||||||
|
inputResponses: []CompletionResponse{
|
||||||
|
{Content: "Complete response", Done: false},
|
||||||
|
{Content: "", Done: true}, // Empty response with Done=true
|
||||||
|
},
|
||||||
|
expectedResponses: []CompletionResponse{
|
||||||
|
{Content: "Complete response", Done: false},
|
||||||
|
{Content: "", Done: true}, // Should still be sent because Done=true
|
||||||
|
},
|
||||||
|
description: "Empty final response with Done=true should still be sent",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "done_reason_preserved",
|
||||||
|
inputResponses: []CompletionResponse{
|
||||||
|
{Content: "Response", Done: false},
|
||||||
|
{Content: " complete", Done: true, DoneReason: DoneReasonStop},
|
||||||
|
},
|
||||||
|
expectedResponses: []CompletionResponse{
|
||||||
|
{Content: "Response", Done: false},
|
||||||
|
{Content: " complete", Done: true, DoneReason: DoneReasonStop},
|
||||||
|
},
|
||||||
|
description: "DoneReason should be preserved in the final response",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "only_incomplete_unicode_not_done",
|
||||||
|
inputResponses: []CompletionResponse{
|
||||||
|
{Content: string([]byte{0xF0, 0x9F}), Done: false}, // Only incomplete unicode
|
||||||
|
},
|
||||||
|
expectedResponses: []CompletionResponse{
|
||||||
|
// No response expected - should be buffered
|
||||||
|
},
|
||||||
|
description: "Response with only incomplete unicode should be buffered if not done",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
var actualResponses []CompletionResponse
|
||||||
|
|
||||||
|
// Create a callback that collects responses
|
||||||
|
callback := func(resp CompletionResponse) {
|
||||||
|
actualResponses = append(actualResponses, resp)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create the unicode buffer handler
|
||||||
|
handler := unicodeBufferHandler(callback)
|
||||||
|
|
||||||
|
// Send all input responses through the handler
|
||||||
|
for _, resp := range tt.inputResponses {
|
||||||
|
handler(resp)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify the number of responses
|
||||||
|
if len(actualResponses) != len(tt.expectedResponses) {
|
||||||
|
t.Fatalf("%s: got %d responses, want %d responses",
|
||||||
|
tt.description, len(actualResponses), len(tt.expectedResponses))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify each response matches the expected one
|
||||||
|
for i, expected := range tt.expectedResponses {
|
||||||
|
if i >= len(actualResponses) {
|
||||||
|
t.Fatalf("%s: missing response at index %d", tt.description, i)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
actual := actualResponses[i]
|
||||||
|
|
||||||
|
// Verify content
|
||||||
|
if actual.Content != expected.Content {
|
||||||
|
t.Errorf("%s: response[%d].Content = %q, want %q",
|
||||||
|
tt.description, i, actual.Content, expected.Content)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify Done flag
|
||||||
|
if actual.Done != expected.Done {
|
||||||
|
t.Errorf("%s: response[%d].Done = %v, want %v",
|
||||||
|
tt.description, i, actual.Done, expected.Done)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify DoneReason if specified
|
||||||
|
if actual.DoneReason != expected.DoneReason {
|
||||||
|
t.Errorf("%s: response[%d].DoneReason = %v, want %v",
|
||||||
|
tt.description, i, actual.DoneReason, expected.DoneReason)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -2,6 +2,8 @@ package common
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/llm"
|
||||||
)
|
)
|
||||||
|
|
||||||
func FindStop(sequence string, stops []string) (bool, string) {
|
func FindStop(sequence string, stops []string) (bool, string) {
|
||||||
|
|
@ -29,68 +31,41 @@ func ContainsStopSuffix(sequence string, stops []string) bool {
|
||||||
// truncateStop removes the provided stop string from pieces,
|
// truncateStop removes the provided stop string from pieces,
|
||||||
// returning the partial pieces with stop removed, including truncating
|
// returning the partial pieces with stop removed, including truncating
|
||||||
// the last piece if required (and signalling if this was the case)
|
// the last piece if required (and signalling if this was the case)
|
||||||
func TruncateStop(pieces []string, stop string) ([]string, bool) {
|
func TruncateStop(resps []llm.CompletionResponse, stop string) ([]llm.CompletionResponse, bool) {
|
||||||
joined := strings.Join(pieces, "")
|
var sequence string
|
||||||
|
for _, resp := range resps {
|
||||||
index := strings.Index(joined, stop)
|
sequence += resp.Content
|
||||||
if index == -1 {
|
|
||||||
return pieces, false
|
|
||||||
}
|
}
|
||||||
|
|
||||||
joined = joined[:index]
|
idx := strings.Index(sequence, stop)
|
||||||
|
if idx < 0 {
|
||||||
// Split truncated string back into pieces of original lengths
|
return resps, false
|
||||||
lengths := make([]int, len(pieces))
|
|
||||||
for i, piece := range pieces {
|
|
||||||
lengths[i] = len(piece)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var result []string
|
truncated := sequence[:idx]
|
||||||
tokenTruncated := false
|
if len(truncated) == 0 {
|
||||||
start := 0
|
return nil, true
|
||||||
for _, length := range lengths {
|
}
|
||||||
if start >= len(joined) {
|
|
||||||
|
result := make([]llm.CompletionResponse, 0, len(resps))
|
||||||
|
|
||||||
|
// Track position in truncated sequence
|
||||||
|
pos := 0
|
||||||
|
truncationHappened := false
|
||||||
|
for _, resp := range resps {
|
||||||
|
if pos >= len(truncated) {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
end := start + length
|
chunk := truncated[pos:min(pos+len(resp.Content), len(truncated))]
|
||||||
if end > len(joined) {
|
if len(chunk) < len(resp.Content) {
|
||||||
end = len(joined)
|
truncationHappened = true
|
||||||
tokenTruncated = true
|
|
||||||
}
|
}
|
||||||
result = append(result, joined[start:end])
|
if len(chunk) > 0 {
|
||||||
start = end
|
result = append(result, llm.CompletionResponse{Content: chunk})
|
||||||
|
}
|
||||||
|
pos += len(resp.Content)
|
||||||
}
|
}
|
||||||
|
|
||||||
return result, tokenTruncated
|
return result, truncationHappened
|
||||||
}
|
|
||||||
|
|
||||||
func IncompleteUnicode(token string) bool {
|
|
||||||
incomplete := false
|
|
||||||
|
|
||||||
// check if there is incomplete UTF-8 character at the end
|
|
||||||
for i := 1; i < 5 && i <= len(token); i++ {
|
|
||||||
c := token[len(token)-i]
|
|
||||||
|
|
||||||
if (c & 0xc0) == 0x80 {
|
|
||||||
// continuation byte: 10xxxxxx
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if (c & 0xe0) == 0xc0 {
|
|
||||||
// 2-byte character: 110xxxxx ...
|
|
||||||
incomplete = i < 2
|
|
||||||
} else if (c & 0xf0) == 0xe0 {
|
|
||||||
// 3-byte character: 1110xxxx ...
|
|
||||||
incomplete = i < 3
|
|
||||||
} else if (c & 0xf8) == 0xf0 {
|
|
||||||
// 4-byte character: 11110xxx ...
|
|
||||||
incomplete = i < 4
|
|
||||||
}
|
|
||||||
|
|
||||||
// else 1-byte character or invalid byte
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
return incomplete
|
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,51 +1,84 @@
|
||||||
package common
|
package common
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"reflect"
|
"reflect"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/llm"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestTruncateStop(t *testing.T) {
|
func TestTruncateStop(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
pieces []string
|
pieces []llm.CompletionResponse
|
||||||
stop string
|
stop string
|
||||||
expected []string
|
expected []llm.CompletionResponse
|
||||||
expectedTrunc bool
|
expectedTrunc bool
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "Single word",
|
name: "Single word",
|
||||||
pieces: []string{"hello", "world"},
|
pieces: []llm.CompletionResponse{
|
||||||
stop: "world",
|
{Content: "Hello"},
|
||||||
expected: []string{"hello"},
|
{Content: "world"},
|
||||||
|
},
|
||||||
|
stop: "world",
|
||||||
|
expected: []llm.CompletionResponse{
|
||||||
|
{Content: "Hello"},
|
||||||
|
},
|
||||||
expectedTrunc: false,
|
expectedTrunc: false,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Partial",
|
name: "Partial",
|
||||||
pieces: []string{"hello", "wor"},
|
pieces: []llm.CompletionResponse{
|
||||||
stop: "or",
|
{Content: "Hello"},
|
||||||
expected: []string{"hello", "w"},
|
{Content: " wor"},
|
||||||
|
},
|
||||||
|
stop: "or",
|
||||||
|
expected: []llm.CompletionResponse{
|
||||||
|
{Content: "Hello"},
|
||||||
|
{Content: " w"},
|
||||||
|
},
|
||||||
expectedTrunc: true,
|
expectedTrunc: true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Suffix",
|
name: "Suffix",
|
||||||
pieces: []string{"Hello", " there", "!"},
|
pieces: []llm.CompletionResponse{
|
||||||
stop: "!",
|
{Content: "Hello"},
|
||||||
expected: []string{"Hello", " there"},
|
{Content: " there"},
|
||||||
|
{Content: "!"},
|
||||||
|
},
|
||||||
|
stop: "!",
|
||||||
|
expected: []llm.CompletionResponse{
|
||||||
|
{Content: "Hello"},
|
||||||
|
{Content: " there"},
|
||||||
|
},
|
||||||
expectedTrunc: false,
|
expectedTrunc: false,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Suffix partial",
|
name: "Suffix partial",
|
||||||
pieces: []string{"Hello", " the", "re!"},
|
pieces: []llm.CompletionResponse{
|
||||||
stop: "there!",
|
{Content: "Hello"},
|
||||||
expected: []string{"Hello", " "},
|
{Content: " the"},
|
||||||
|
{Content: "re!"},
|
||||||
|
},
|
||||||
|
stop: "there!",
|
||||||
|
expected: []llm.CompletionResponse{
|
||||||
|
{Content: "Hello"},
|
||||||
|
{Content: " "},
|
||||||
|
},
|
||||||
expectedTrunc: true,
|
expectedTrunc: true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Middle",
|
name: "Middle",
|
||||||
pieces: []string{"hello", " wor"},
|
pieces: []llm.CompletionResponse{
|
||||||
stop: "llo w",
|
{Content: "Hello"},
|
||||||
expected: []string{"he"},
|
{Content: " wo"},
|
||||||
|
},
|
||||||
|
stop: "llo w",
|
||||||
|
expected: []llm.CompletionResponse{
|
||||||
|
{Content: "He"},
|
||||||
|
},
|
||||||
expectedTrunc: true,
|
expectedTrunc: true,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
@ -54,76 +87,23 @@ func TestTruncateStop(t *testing.T) {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
result, resultTrunc := TruncateStop(tt.pieces, tt.stop)
|
result, resultTrunc := TruncateStop(tt.pieces, tt.stop)
|
||||||
if !reflect.DeepEqual(result, tt.expected) || resultTrunc != tt.expectedTrunc {
|
if !reflect.DeepEqual(result, tt.expected) || resultTrunc != tt.expectedTrunc {
|
||||||
t.Errorf("truncateStop(%v, %s): have %v (%v); want %v (%v)", tt.pieces, tt.stop, result, resultTrunc, tt.expected, tt.expectedTrunc)
|
t.Errorf("truncateStop(%v, %v):\n%shave truncated %v\nwant truncated %v",
|
||||||
|
tt.pieces, tt.stop, formatContentDiff(result, tt.expected), resultTrunc, tt.expectedTrunc)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestIncompleteUnicode(t *testing.T) {
|
func formatContentDiff(result, expected []llm.CompletionResponse) string {
|
||||||
tests := []struct {
|
var s string
|
||||||
name string
|
for i := 0; i < len(result) || i < len(expected); i++ {
|
||||||
input string
|
if i < len(result) && i < len(expected) && result[i].Content != expected[i].Content {
|
||||||
expected bool
|
s += fmt.Sprintf("[%d] %q vs %q\n", i, result[i].Content, expected[i].Content)
|
||||||
}{
|
} else if i < len(result) && i >= len(expected) {
|
||||||
{
|
s += fmt.Sprintf("[%d] extra %q\n", i, result[i].Content)
|
||||||
name: "Basic",
|
} else if i >= len(result) && i < len(expected) {
|
||||||
input: "hi",
|
s += fmt.Sprintf("[%d] missing %q\n", i, expected[i].Content)
|
||||||
expected: false,
|
}
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Two byte",
|
|
||||||
input: "hi" + string([]byte{0xc2, 0xa3}),
|
|
||||||
expected: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Two byte - missing last",
|
|
||||||
input: "hi" + string([]byte{0xc2}),
|
|
||||||
expected: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Three byte",
|
|
||||||
input: "hi" + string([]byte{0xe0, 0xA0, 0x80}),
|
|
||||||
expected: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Three byte - missing last",
|
|
||||||
input: "hi" + string([]byte{0xe0, 0xA0}),
|
|
||||||
expected: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Three byte - missing last 2",
|
|
||||||
input: "hi" + string([]byte{0xe0}),
|
|
||||||
expected: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Four byte",
|
|
||||||
input: "hi" + string([]byte{0xf0, 0x92, 0x8a, 0xb7}),
|
|
||||||
expected: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Four byte - missing last",
|
|
||||||
input: "hi" + string([]byte{0xf0, 0x92, 0x8a}),
|
|
||||||
expected: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Four byte - missing last 2",
|
|
||||||
input: "hi" + string([]byte{0xf0, 0x92}),
|
|
||||||
expected: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Four byte - missing last 3",
|
|
||||||
input: "hi" + string([]byte{0xf0}),
|
|
||||||
expected: true,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
result := IncompleteUnicode(tt.input)
|
|
||||||
if result != tt.expected {
|
|
||||||
t.Errorf("incompleteUnicode(%s): have %v; want %v", tt.input, result, tt.expected)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
return s
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -17,7 +17,6 @@ import (
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
"unicode/utf8"
|
|
||||||
|
|
||||||
"golang.org/x/sync/semaphore"
|
"golang.org/x/sync/semaphore"
|
||||||
|
|
||||||
|
|
@ -52,13 +51,13 @@ type Sequence struct {
|
||||||
pendingInputs []input
|
pendingInputs []input
|
||||||
|
|
||||||
// tokens that have been generated but not returned yet (e.g. for stop sequences)
|
// tokens that have been generated but not returned yet (e.g. for stop sequences)
|
||||||
pendingResponses []string
|
pendingResponses []llm.CompletionResponse
|
||||||
|
|
||||||
// input cache being used by this sequence
|
// input cache being used by this sequence
|
||||||
cache *InputCacheSlot
|
cache *InputCacheSlot
|
||||||
|
|
||||||
// channel to send responses over
|
// channel to send responses over
|
||||||
responses chan string
|
responses chan llm.CompletionResponse
|
||||||
|
|
||||||
// channel to stop decoding (such as if the remote connection is closed)
|
// channel to stop decoding (such as if the remote connection is closed)
|
||||||
quit chan bool
|
quit chan bool
|
||||||
|
|
@ -89,6 +88,19 @@ type Sequence struct {
|
||||||
numPromptInputs int
|
numPromptInputs int
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (seq *Sequence) send(resp llm.CompletionResponse) bool {
|
||||||
|
if len(resp.Content) > 0 || resp.Done {
|
||||||
|
select {
|
||||||
|
case seq.responses <- resp:
|
||||||
|
// Successfully sent
|
||||||
|
return true
|
||||||
|
case <-seq.quit:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
type NewSequenceParams struct {
|
type NewSequenceParams struct {
|
||||||
numPredict int
|
numPredict int
|
||||||
stop []string
|
stop []string
|
||||||
|
|
@ -147,8 +159,8 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
|
||||||
numPromptInputs: len(inputs),
|
numPromptInputs: len(inputs),
|
||||||
startProcessingTime: startTime,
|
startProcessingTime: startTime,
|
||||||
numPredict: params.numPredict,
|
numPredict: params.numPredict,
|
||||||
pendingResponses: make([]string, 0),
|
pendingResponses: make([]llm.CompletionResponse, 0),
|
||||||
responses: make(chan string, 100),
|
responses: make(chan llm.CompletionResponse, 100),
|
||||||
quit: make(chan bool, 1),
|
quit: make(chan bool, 1),
|
||||||
embedding: make(chan []float32, 1),
|
embedding: make(chan []float32, 1),
|
||||||
samplingCtx: sc,
|
samplingCtx: sc,
|
||||||
|
|
@ -272,36 +284,15 @@ func (s *Server) allNil() bool {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func flushPending(seq *Sequence) bool {
|
|
||||||
joined := strings.Join(seq.pendingResponses, "")
|
|
||||||
seq.pendingResponses = []string{}
|
|
||||||
|
|
||||||
// Check if there are any partial UTF-8 characters remaining.
|
|
||||||
// We already check and queue as we are generating but some may
|
|
||||||
// still make it here:
|
|
||||||
// - Sequence is ending, e.g. generation limit has been hit
|
|
||||||
// - Invalid characters in the middle of a string
|
|
||||||
// This is a stricter check to ensure we never output invalid Unicode.
|
|
||||||
for !utf8.ValidString(joined) {
|
|
||||||
joined = joined[:len(joined)-1]
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(joined) == 0 {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
select {
|
|
||||||
case seq.responses <- joined:
|
|
||||||
return true
|
|
||||||
case <-seq.quit:
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Server) removeSequence(seqIndex int, reason llm.DoneReason) {
|
func (s *Server) removeSequence(seqIndex int, reason llm.DoneReason) {
|
||||||
seq := s.seqs[seqIndex]
|
seq := s.seqs[seqIndex]
|
||||||
|
|
||||||
flushPending(seq)
|
// Send any remaining pending responses
|
||||||
|
for _, resp := range seq.pendingResponses {
|
||||||
|
seq.send(resp)
|
||||||
|
}
|
||||||
|
seq.pendingResponses = []llm.CompletionResponse{}
|
||||||
|
|
||||||
seq.doneReason = reason
|
seq.doneReason = reason
|
||||||
close(seq.responses)
|
close(seq.responses)
|
||||||
close(seq.embedding)
|
close(seq.embedding)
|
||||||
|
|
@ -490,8 +481,11 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
|
||||||
|
|
||||||
seq.inputs = []input{{token: token}}
|
seq.inputs = []input{{token: token}}
|
||||||
|
|
||||||
seq.pendingResponses = append(seq.pendingResponses, piece)
|
seq.pendingResponses = append(seq.pendingResponses, llm.CompletionResponse{Content: piece})
|
||||||
sequence := strings.Join(seq.pendingResponses, "")
|
sequence := ""
|
||||||
|
for _, r := range seq.pendingResponses {
|
||||||
|
sequence += r.Content
|
||||||
|
}
|
||||||
|
|
||||||
if ok, stop := common.FindStop(sequence, seq.stop); ok {
|
if ok, stop := common.FindStop(sequence, seq.stop); ok {
|
||||||
slog.Debug("hit stop token", "pending", seq.pendingResponses, "stop", stop)
|
slog.Debug("hit stop token", "pending", seq.pendingResponses, "stop", stop)
|
||||||
|
|
@ -523,13 +517,13 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if common.IncompleteUnicode(sequence) {
|
for _, resp := range seq.pendingResponses {
|
||||||
continue
|
if !seq.send(resp) {
|
||||||
}
|
s.removeSequence(i, llm.DoneReasonConnectionClosed)
|
||||||
|
break
|
||||||
if !flushPending(seq) {
|
}
|
||||||
s.removeSequence(i, llm.DoneReasonConnectionClosed)
|
|
||||||
}
|
}
|
||||||
|
seq.pendingResponses = []llm.CompletionResponse{}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
|
@ -627,9 +621,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
||||||
return
|
return
|
||||||
case content, ok := <-seq.responses:
|
case content, ok := <-seq.responses:
|
||||||
if ok {
|
if ok {
|
||||||
if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{
|
if err := json.NewEncoder(w).Encode(&content); err != nil {
|
||||||
Content: content,
|
|
||||||
}); err != nil {
|
|
||||||
http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
|
http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
|
||||||
close(seq.quit)
|
close(seq.quit)
|
||||||
return
|
return
|
||||||
|
|
|
||||||
|
|
@ -20,7 +20,6 @@ import (
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
"unicode/utf8"
|
|
||||||
|
|
||||||
"golang.org/x/image/bmp"
|
"golang.org/x/image/bmp"
|
||||||
"golang.org/x/sync/semaphore"
|
"golang.org/x/sync/semaphore"
|
||||||
|
|
@ -56,13 +55,13 @@ type Sequence struct {
|
||||||
pendingInputs []input.Input
|
pendingInputs []input.Input
|
||||||
|
|
||||||
// tokens that have been generated but not returned yet (e.g. for stop sequences)
|
// tokens that have been generated but not returned yet (e.g. for stop sequences)
|
||||||
pendingResponses []string
|
pendingResponses []llm.CompletionResponse
|
||||||
|
|
||||||
// input cache being used by this sequence
|
// input cache being used by this sequence
|
||||||
cache *InputCacheSlot
|
cache *InputCacheSlot
|
||||||
|
|
||||||
// channel to send responses over
|
// channel to send responses over
|
||||||
responses chan string
|
responses chan llm.CompletionResponse
|
||||||
|
|
||||||
// channel to stop decoding (such as if the remote connection is closed)
|
// channel to stop decoding (such as if the remote connection is closed)
|
||||||
quit chan bool
|
quit chan bool
|
||||||
|
|
@ -94,6 +93,19 @@ type Sequence struct {
|
||||||
numPromptInputs int
|
numPromptInputs int
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (seq *Sequence) send(resp llm.CompletionResponse) bool {
|
||||||
|
if len(resp.Content) > 0 || resp.Done {
|
||||||
|
select {
|
||||||
|
case seq.responses <- resp:
|
||||||
|
// Successfully sent
|
||||||
|
return true
|
||||||
|
case <-seq.quit:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
type NewSequenceParams struct {
|
type NewSequenceParams struct {
|
||||||
numPredict int
|
numPredict int
|
||||||
stop []string
|
stop []string
|
||||||
|
|
@ -167,8 +179,8 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
|
||||||
numPromptInputs: len(inputs),
|
numPromptInputs: len(inputs),
|
||||||
startProcessingTime: startTime,
|
startProcessingTime: startTime,
|
||||||
numPredict: params.numPredict,
|
numPredict: params.numPredict,
|
||||||
pendingResponses: make([]string, 0),
|
pendingResponses: make([]llm.CompletionResponse, 0),
|
||||||
responses: make(chan string, 100),
|
responses: make(chan llm.CompletionResponse, 100),
|
||||||
quit: make(chan bool, 1),
|
quit: make(chan bool, 1),
|
||||||
embedding: make(chan []float32, 1),
|
embedding: make(chan []float32, 1),
|
||||||
sampler: params.sampler,
|
sampler: params.sampler,
|
||||||
|
|
@ -313,36 +325,15 @@ func (s *Server) allNil() bool {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func flushPending(seq *Sequence) bool {
|
|
||||||
joined := strings.Join(seq.pendingResponses, "")
|
|
||||||
seq.pendingResponses = []string{}
|
|
||||||
|
|
||||||
// Check if there are any partial UTF-8 characters remaining.
|
|
||||||
// We already check and queue as we are generating but some may
|
|
||||||
// still make it here:
|
|
||||||
// - Sequence is ending, e.g. generation limit has been hit
|
|
||||||
// - Invalid characters in the middle of a string
|
|
||||||
// This is a stricter check to ensure we never output invalid Unicode.
|
|
||||||
for !utf8.ValidString(joined) {
|
|
||||||
joined = joined[:len(joined)-1]
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(joined) == 0 {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
select {
|
|
||||||
case seq.responses <- joined:
|
|
||||||
return true
|
|
||||||
case <-seq.quit:
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Server) removeSequence(seqIndex int, reason llm.DoneReason) {
|
func (s *Server) removeSequence(seqIndex int, reason llm.DoneReason) {
|
||||||
seq := s.seqs[seqIndex]
|
seq := s.seqs[seqIndex]
|
||||||
|
|
||||||
flushPending(seq)
|
// Send any remaining pending responses
|
||||||
|
for _, resp := range seq.pendingResponses {
|
||||||
|
seq.send(resp)
|
||||||
|
}
|
||||||
|
seq.pendingResponses = []llm.CompletionResponse{}
|
||||||
|
|
||||||
seq.doneReason = reason
|
seq.doneReason = reason
|
||||||
close(seq.responses)
|
close(seq.responses)
|
||||||
close(seq.embedding)
|
close(seq.embedding)
|
||||||
|
|
@ -541,8 +532,11 @@ func (s *Server) processBatch() error {
|
||||||
|
|
||||||
seq.inputs = []input.Input{{Token: token}}
|
seq.inputs = []input.Input{{Token: token}}
|
||||||
|
|
||||||
seq.pendingResponses = append(seq.pendingResponses, piece)
|
seq.pendingResponses = append(seq.pendingResponses, llm.CompletionResponse{Content: piece})
|
||||||
sequence := strings.Join(seq.pendingResponses, "")
|
sequence := ""
|
||||||
|
for _, r := range seq.pendingResponses {
|
||||||
|
sequence += r.Content
|
||||||
|
}
|
||||||
|
|
||||||
if ok, stop := common.FindStop(sequence, seq.stop); ok {
|
if ok, stop := common.FindStop(sequence, seq.stop); ok {
|
||||||
slog.Debug("hit stop token", "pending", seq.pendingResponses, "stop", stop)
|
slog.Debug("hit stop token", "pending", seq.pendingResponses, "stop", stop)
|
||||||
|
|
@ -574,13 +568,14 @@ func (s *Server) processBatch() error {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if common.IncompleteUnicode(sequence) {
|
// Send all pending responses directly without unicode checking
|
||||||
continue
|
for _, resp := range seq.pendingResponses {
|
||||||
}
|
if !seq.send(resp) {
|
||||||
|
s.removeSequence(i, llm.DoneReasonConnectionClosed)
|
||||||
if !flushPending(seq) {
|
break
|
||||||
s.removeSequence(i, llm.DoneReasonConnectionClosed)
|
}
|
||||||
}
|
}
|
||||||
|
seq.pendingResponses = []llm.CompletionResponse{}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
|
@ -683,9 +678,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
||||||
return
|
return
|
||||||
case content, ok := <-seq.responses:
|
case content, ok := <-seq.responses:
|
||||||
if ok {
|
if ok {
|
||||||
if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{
|
if err := json.NewEncoder(w).Encode(&content); err != nil {
|
||||||
Content: content,
|
|
||||||
}); err != nil {
|
|
||||||
http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
|
http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
|
||||||
close(seq.quit)
|
close(seq.quit)
|
||||||
return
|
return
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue