Compare commits

..

1 Commits

Author SHA1 Message Date
jmorganca
b54dcc750c update .gitattributes with proper linguist-vendored entry 2024-10-09 13:52:33 -07:00
6 changed files with 42 additions and 71 deletions

9
.gitattributes vendored
View File

@@ -1,12 +1,5 @@
llm/ext_server/* linguist-vendored
llama/**/*.cpp linguist-vendored
llama/**/*.hpp linguist-vendored
llama/**/*.h linguist-vendored
llama/**/*.c linguist-vendored
llama/**/*.cu linguist-vendored
llama/**/*.cuh linguist-vendored
llama/**/*.m linguist-vendored
llama/**/*.metal linguist-vendored
llama/** linguist-vendored
* text=auto
*.go text eol=lf

View File

@@ -442,6 +442,13 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
return err
}
// clear all previous images for better responses
if len(images) > 0 {
for i := range opts.Messages {
opts.Messages[i].Images = nil
}
}
newMessage.Content = msg
newMessage.Images = images
}

View File

@@ -451,27 +451,14 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
sequence := strings.Join(seq.pendingResponses, "")
if ok, stop := findStop(sequence, seq.stop); ok {
slog.Debug("hit stop token", "pending", seq.pendingResponses, "stop", stop)
slog.Debug("hit stop token", "stop", seq.stop)
var tokenTruncated bool
origLen := len(seq.pendingResponses)
seq.pendingResponses, tokenTruncated = truncateStop(seq.pendingResponses, stop)
newLen := len(seq.pendingResponses)
// Update the cache based on the tokens that will be returned:
// - We have 1 token more than is currently in the cache because
// the last one generated wasn't submitted to Decode
// - Remove any stop sequences that we stripped out
// - If truncateStop removed a portion of a token, drop that
// - As defense-in-depth, if truncateStop didn't find a stop token
// remove the extra one that we added to the cache len
tokenLen := len(seq.cache.Inputs) + 1
tokenLen -= origLen - newLen
if tokenTruncated || origLen == newLen {
tokenLen--
}
seq.cache.Inputs = seq.cache.Inputs[:tokenLen]
trimCacheLen := len(seq.pendingResponses) - 1
seq.pendingResponses = truncateStop(seq.pendingResponses, stop)
trimCacheLen -= len(seq.pendingResponses)
// remove any tokens from the cache that we don't actually return
seq.cache.Inputs = seq.cache.Inputs[:len(seq.cache.Inputs)-trimCacheLen]
s.removeSequence(i, "stop")
continue
}

10
llama/runner/stop.go vendored
View File

@@ -28,13 +28,13 @@ func containsStopSuffix(sequence string, stops []string) bool {
// truncateStop removes the provided stop string from pieces,
// returning the partial pieces with stop removed, including truncating
// the last piece if required (and signalling if this was the case)
func truncateStop(pieces []string, stop string) ([]string, bool) {
// the last piece if required
func truncateStop(pieces []string, stop string) []string {
joined := strings.Join(pieces, "")
index := strings.Index(joined, stop)
if index == -1 {
return pieces, false
return pieces
}
joined = joined[:index]
@@ -46,7 +46,6 @@ func truncateStop(pieces []string, stop string) ([]string, bool) {
}
var result []string
tokenTruncated := false
start := 0
for _, length := range lengths {
if start >= len(joined) {
@@ -56,13 +55,12 @@ func truncateStop(pieces []string, stop string) ([]string, bool) {
end := start + length
if end > len(joined) {
end = len(joined)
tokenTruncated = true
}
result = append(result, joined[start:end])
start = end
}
return result, tokenTruncated
return result
}
func incompleteUnicode(token string) bool {

View File

@@ -7,54 +7,42 @@ import (
func TestTruncateStop(t *testing.T) {
tests := []struct {
name string
pieces []string
stop string
expected []string
expectedTrunc bool
name string
pieces []string
stop string
expected []string
}{
{
name: "Single word",
pieces: []string{"hello", "world"},
stop: "world",
expected: []string{"hello"},
expectedTrunc: false,
name: "Single word",
pieces: []string{"hello", "world"},
stop: "world",
expected: []string{"hello"},
},
{
name: "Partial",
pieces: []string{"hello", "wor"},
stop: "or",
expected: []string{"hello", "w"},
expectedTrunc: true,
name: "Partial",
pieces: []string{"hello", "wor"},
stop: "or",
expected: []string{"hello", "w"},
},
{
name: "Suffix",
pieces: []string{"Hello", " there", "!"},
stop: "!",
expected: []string{"Hello", " there"},
expectedTrunc: false,
name: "Suffix",
pieces: []string{"Hello", " there", "!"},
stop: "!",
expected: []string{"Hello", " there"},
},
{
name: "Suffix partial",
pieces: []string{"Hello", " the", "re!"},
stop: "there!",
expected: []string{"Hello", " "},
expectedTrunc: true,
},
{
name: "Middle",
pieces: []string{"hello", " wor"},
stop: "llo w",
expected: []string{"he"},
expectedTrunc: true,
name: "Middle",
pieces: []string{"hello", " wor"},
stop: "llo w",
expected: []string{"he"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result, resultTrunc := truncateStop(tt.pieces, tt.stop)
if !reflect.DeepEqual(result, tt.expected) || resultTrunc != tt.expectedTrunc {
t.Errorf("truncateStop(%v, %s): have %v (%v); want %v (%v)", tt.pieces, tt.stop, result, resultTrunc, tt.expected, tt.expectedTrunc)
result := truncateStop(tt.pieces, tt.stop)
if !reflect.DeepEqual(result, tt.expected) {
t.Errorf("truncateStop(%v, %s): have %v; want %v", tt.pieces, tt.stop, result, tt.expected)
}
})
}

View File

@@ -1086,13 +1086,10 @@ func (s *llmServer) Detokenize(ctx context.Context, tokens []int) (string, error
}
func (s *llmServer) Close() error {
s.modelLock.Lock()
if s.model != nil {
llama.FreeModel(s.model)
s.model = nil
}
s.modelLock.Unlock()
if s.cmd != nil {
slog.Debug("stopping llama server")
if err := s.cmd.Process.Kill(); err != nil {
@@ -1103,6 +1100,7 @@ func (s *llmServer) Close() error {
slog.Debug("waiting for llama server to exit")
<-s.done
}
s.cmd = nil
slog.Debug("llama server stopped")
}