Compare commits
5 Commits
jmorganca/
...
v0.3.13
| Author | SHA1 | Date | |
|---|---|---|---|
| | 7fe3902552 | | |
| | 0077e22d52 | | |
| | 03408f3437 | | |
| | cd7e01e8b9 | | |
| | 7a962bd802 | | |
9
.gitattributes
vendored
9
.gitattributes
vendored
@@ -1,5 +1,12 @@
|
||||
llm/ext_server/* linguist-vendored
|
||||
llama/** linguist-vendored
|
||||
llama/**/*.cpp linguist-vendored
|
||||
llama/**/*.hpp linguist-vendored
|
||||
llama/**/*.h linguist-vendored
|
||||
llama/**/*.c linguist-vendored
|
||||
llama/**/*.cu linguist-vendored
|
||||
llama/**/*.cuh linguist-vendored
|
||||
llama/**/*.m linguist-vendored
|
||||
llama/**/*.metal linguist-vendored
|
||||
|
||||
* text=auto
|
||||
*.go text eol=lf
|
||||
|
||||
@@ -442,13 +442,6 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
|
||||
return err
|
||||
}
|
||||
|
||||
// clear all previous images for better responses
|
||||
if len(images) > 0 {
|
||||
for i := range opts.Messages {
|
||||
opts.Messages[i].Images = nil
|
||||
}
|
||||
}
|
||||
|
||||
newMessage.Content = msg
|
||||
newMessage.Images = images
|
||||
}
|
||||
|
||||
@@ -451,14 +451,27 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
|
||||
sequence := strings.Join(seq.pendingResponses, "")
|
||||
|
||||
if ok, stop := findStop(sequence, seq.stop); ok {
|
||||
slog.Debug("hit stop token", "stop", seq.stop)
|
||||
slog.Debug("hit stop token", "pending", seq.pendingResponses, "stop", stop)
|
||||
|
||||
trimCacheLen := len(seq.pendingResponses) - 1
|
||||
seq.pendingResponses = truncateStop(seq.pendingResponses, stop)
|
||||
trimCacheLen -= len(seq.pendingResponses)
|
||||
var tokenTruncated bool
|
||||
origLen := len(seq.pendingResponses)
|
||||
seq.pendingResponses, tokenTruncated = truncateStop(seq.pendingResponses, stop)
|
||||
newLen := len(seq.pendingResponses)
|
||||
|
||||
// Update the cache based on the tokens that will be returned:
|
||||
// - We have 1 token more than is currently in the cache because
|
||||
// the last one generated wasn't submitted to Decode
|
||||
// - Remove any stop sequences that we stripped out
|
||||
// - If truncateStop removed a portion of a token, drop that
|
||||
// - As defense-in-depth, if truncateStop didn't find a stop token
|
||||
// remove the extra one that we added to the cache len
|
||||
tokenLen := len(seq.cache.Inputs) + 1
|
||||
tokenLen -= origLen - newLen
|
||||
if tokenTruncated || origLen == newLen {
|
||||
tokenLen--
|
||||
}
|
||||
seq.cache.Inputs = seq.cache.Inputs[:tokenLen]
|
||||
|
||||
// remove any tokens from the cache that we don't actually return
|
||||
seq.cache.Inputs = seq.cache.Inputs[:len(seq.cache.Inputs)-trimCacheLen]
|
||||
s.removeSequence(i, "stop")
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -28,13 +28,13 @@ func containsStopSuffix(sequence string, stops []string) bool {
|
||||
|
||||
// truncateStop removes the provided stop string from pieces,
|
||||
// returning the partial pieces with stop removed, including truncating
|
||||
// the last piece if required
|
||||
func truncateStop(pieces []string, stop string) []string {
|
||||
// the last piece if required (and signalling if this was the case)
|
||||
func truncateStop(pieces []string, stop string) ([]string, bool) {
|
||||
joined := strings.Join(pieces, "")
|
||||
|
||||
index := strings.Index(joined, stop)
|
||||
if index == -1 {
|
||||
return pieces
|
||||
return pieces, false
|
||||
}
|
||||
|
||||
joined = joined[:index]
|
||||
@@ -46,6 +46,7 @@ func truncateStop(pieces []string, stop string) []string {
|
||||
}
|
||||
|
||||
var result []string
|
||||
tokenTruncated := false
|
||||
start := 0
|
||||
for _, length := range lengths {
|
||||
if start >= len(joined) {
|
||||
@@ -55,12 +56,13 @@ func truncateStop(pieces []string, stop string) []string {
|
||||
end := start + length
|
||||
if end > len(joined) {
|
||||
end = len(joined)
|
||||
tokenTruncated = true
|
||||
}
|
||||
result = append(result, joined[start:end])
|
||||
start = end
|
||||
}
|
||||
|
||||
return result
|
||||
return result, tokenTruncated
|
||||
}
|
||||
|
||||
func incompleteUnicode(token string) bool {
|
||||
|
||||
@@ -7,42 +7,54 @@ import (
|
||||
|
||||
func TestTruncateStop(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
pieces []string
|
||||
stop string
|
||||
expected []string
|
||||
name string
|
||||
pieces []string
|
||||
stop string
|
||||
expected []string
|
||||
expectedTrunc bool
|
||||
}{
|
||||
{
|
||||
name: "Single word",
|
||||
pieces: []string{"hello", "world"},
|
||||
stop: "world",
|
||||
expected: []string{"hello"},
|
||||
name: "Single word",
|
||||
pieces: []string{"hello", "world"},
|
||||
stop: "world",
|
||||
expected: []string{"hello"},
|
||||
expectedTrunc: false,
|
||||
},
|
||||
{
|
||||
name: "Partial",
|
||||
pieces: []string{"hello", "wor"},
|
||||
stop: "or",
|
||||
expected: []string{"hello", "w"},
|
||||
name: "Partial",
|
||||
pieces: []string{"hello", "wor"},
|
||||
stop: "or",
|
||||
expected: []string{"hello", "w"},
|
||||
expectedTrunc: true,
|
||||
},
|
||||
{
|
||||
name: "Suffix",
|
||||
pieces: []string{"Hello", " there", "!"},
|
||||
stop: "!",
|
||||
expected: []string{"Hello", " there"},
|
||||
name: "Suffix",
|
||||
pieces: []string{"Hello", " there", "!"},
|
||||
stop: "!",
|
||||
expected: []string{"Hello", " there"},
|
||||
expectedTrunc: false,
|
||||
},
|
||||
{
|
||||
name: "Middle",
|
||||
pieces: []string{"hello", " wor"},
|
||||
stop: "llo w",
|
||||
expected: []string{"he"},
|
||||
name: "Suffix partial",
|
||||
pieces: []string{"Hello", " the", "re!"},
|
||||
stop: "there!",
|
||||
expected: []string{"Hello", " "},
|
||||
expectedTrunc: true,
|
||||
},
|
||||
{
|
||||
name: "Middle",
|
||||
pieces: []string{"hello", " wor"},
|
||||
stop: "llo w",
|
||||
expected: []string{"he"},
|
||||
expectedTrunc: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := truncateStop(tt.pieces, tt.stop)
|
||||
if !reflect.DeepEqual(result, tt.expected) {
|
||||
t.Errorf("truncateStop(%v, %s): have %v; want %v", tt.pieces, tt.stop, result, tt.expected)
|
||||
result, resultTrunc := truncateStop(tt.pieces, tt.stop)
|
||||
if !reflect.DeepEqual(result, tt.expected) || resultTrunc != tt.expectedTrunc {
|
||||
t.Errorf("truncateStop(%v, %s): have %v (%v); want %v (%v)", tt.pieces, tt.stop, result, resultTrunc, tt.expected, tt.expectedTrunc)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1086,10 +1086,13 @@ func (s *llmServer) Detokenize(ctx context.Context, tokens []int) (string, error
|
||||
}
|
||||
|
||||
func (s *llmServer) Close() error {
|
||||
s.modelLock.Lock()
|
||||
if s.model != nil {
|
||||
llama.FreeModel(s.model)
|
||||
s.model = nil
|
||||
}
|
||||
s.modelLock.Unlock()
|
||||
|
||||
if s.cmd != nil {
|
||||
slog.Debug("stopping llama server")
|
||||
if err := s.cmd.Process.Kill(); err != nil {
|
||||
@@ -1100,7 +1103,6 @@ func (s *llmServer) Close() error {
|
||||
slog.Debug("waiting for llama server to exit")
|
||||
<-s.done
|
||||
}
|
||||
s.cmd = nil
|
||||
|
||||
slog.Debug("llama server stopped")
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user