Compare commits

..

4 Commits

Author SHA1 Message Date
Jeffrey Morgan
f8453e9d4a llm: attempt to evaluate symlinks, but do not fail (#9089)
provides a better approach to #9088 that will attempt to
evaluate symlinks (important for macOS where 'ollama' is
often a symlink), but use the result of os.Executable()
as a fallback in scenarios where filepath.EvalSymlinks
fails due to permission errors or other issues
2025-02-13 22:38:23 -08:00
Jeffrey Morgan
55c0ee76b4 llm: do not evaluate symlink for exe path lookup (#9088)
In some cases, the directories in the executable path read by
filepath.EvalSymlinks are not accessible, resulting in permission
errors which result in an error when running models. It also
doesn't work well on long paths on Windows, also resulting in
errors. This change removes filepath.EvalSymlinks when accessing
os.Executable() altogether.
2025-02-13 22:13:47 -08:00
Jeffrey Morgan
c03e248735 ml/backend/ggml: stable sort devices by score (#9081) 2025-02-13 18:43:33 -08:00
Jeffrey Morgan
a4f69a0191 build: add -DGGML_CUDA_NO_PEER_COPY=ON for rocm builds on windows (#9060) 2025-02-13 00:23:17 -08:00
13 changed files with 145 additions and 324 deletions

View File

@@ -104,6 +104,10 @@ if(CMAKE_HIP_COMPILER)
if(AMDGPU_TARGETS)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-hip)
if (WIN32)
target_compile_definitions(ggml-hip PRIVATE GGML_CUDA_NO_PEER_COPY=1)
endif()
set(OLLAMA_HIP_INSTALL_DIR ${OLLAMA_INSTALL_DIR}/rocm)
install(TARGETS ggml-hip
RUNTIME_DEPENDENCIES

View File

@@ -77,8 +77,6 @@ type GenerateRequest struct {
// request, for multimodal models.
Images []ImageData `json:"images,omitempty"`
LogProbs int `json:"logprobs,omitempty"`
// Options lists model-specific options. For example, temperature can be
// set through this field, if the model supports it.
Options map[string]interface{} `json:"options"`
@@ -105,8 +103,6 @@ type ChatRequest struct {
// Tools is an optional list of tools the model has access to.
Tools `json:"tools,omitempty"`
LogProbs int `json:"logprobs,omitempty"`
// Options lists model-specific options.
Options map[string]interface{} `json:"options"`
}
@@ -186,20 +182,13 @@ func (t *ToolFunction) String() string {
return string(bts)
}
type TokenProbs struct {
TokenID int `json:"id"`
LogProb float32 `json:"logprob"`
Token string `json:"token"`
}
// ChatResponse is the response returned by [Client.Chat]. Its fields are
// similar to [GenerateResponse].
type ChatResponse struct {
Model string `json:"model"`
CreatedAt time.Time `json:"created_at"`
Message Message `json:"message"`
DoneReason string `json:"done_reason,omitempty"`
LogProbs []TokenProbs `json:"logprobs,omitempty"`
Model string `json:"model"`
CreatedAt time.Time `json:"created_at"`
Message Message `json:"message"`
DoneReason string `json:"done_reason,omitempty"`
Done bool `json:"done"`
@@ -463,8 +452,6 @@ type GenerateResponse struct {
// can be sent in the next request to keep a conversational memory.
Context []int `json:"context,omitempty"`
LogProbs []TokenProbs `json:"logprobs,omitempty"`
Metrics
}

View File

@@ -19,9 +19,8 @@ var LibOllamaPath string = func() string {
return ""
}
exe, err = filepath.EvalSymlinks(exe)
if err != nil {
return ""
if eval, err := filepath.EvalSymlinks(exe); err == nil {
exe = eval
}
var libPath string

View File

@@ -50,7 +50,7 @@ import (
_ "github.com/ollama/ollama/llama/llama.cpp/common"
_ "github.com/ollama/ollama/llama/llama.cpp/examples/llava"
_ "github.com/ollama/ollama/llama/llama.cpp/src"
ggml "github.com/ollama/ollama/ml/backend/ggml/ggml/src"
"github.com/ollama/ollama/ml/backend/ggml/ggml/src"
)
func BackendInit() {
@@ -220,19 +220,6 @@ func (c *Context) GetEmbeddingsIth(i int) []float32 {
return embeddings
}
// GetLogits returns the logits from the last decode operation.
// The returned slice has length equal to the vocabulary size.
func (c *Context) GetLogits() []float32 {
logits := unsafe.Pointer(C.llama_get_logits(c.c))
if logits == nil {
return nil
}
// Get the number of vocabulary tokens to determine array size
vocabSize := c.Model().NumVocab()
return unsafe.Slice((*float32)(logits), vocabSize)
}
type ModelParams struct {
NumGpuLayers int
MainGpu int

View File

@@ -8,7 +8,7 @@ Subject: [PATCH] sort devices by score
1 file changed, 13 insertions(+), 8 deletions(-)
diff --git a/ggml/src/ggml-backend-reg.cpp b/ggml/src/ggml-backend-reg.cpp
index 899d16f2..ac5cda07 100644
index 899d16f2..135f7df0 100644
--- a/ggml/src/ggml-backend-reg.cpp
+++ b/ggml/src/ggml-backend-reg.cpp
@@ -150,7 +150,7 @@ struct ggml_backend_reg_entry {
@@ -29,7 +29,7 @@ index 899d16f2..ac5cda07 100644
if (!reg) {
return;
}
@@ -206,15 +206,15 @@ struct ggml_backend_registry {
@@ -206,15 +206,20 @@ struct ggml_backend_registry {
#endif
backends.push_back({ reg, std::move(handle) });
for (size_t i = 0; i < ggml_backend_reg_dev_count(reg); i++) {
@@ -45,10 +45,15 @@ index 899d16f2..ac5cda07 100644
#endif
- devices.push_back(device);
+ devices.push_back({device, score});
+ std::stable_sort(devices.begin(), devices.end(),
+ [](const auto & a, const auto & b) {
+ return a.second > b.second;
+ }
+ );
}
ggml_backend_reg_t load_backend(const std::wstring & path, bool silent) {
@@ -257,7 +257,7 @@ struct ggml_backend_registry {
@@ -257,7 +262,7 @@ struct ggml_backend_registry {
GGML_LOG_INFO("%s: loaded %s backend from %s\n", __func__, ggml_backend_reg_name(reg), utf16_to_utf8(path).c_str());
@@ -57,7 +62,7 @@ index 899d16f2..ac5cda07 100644
return reg;
}
@@ -280,7 +280,7 @@ struct ggml_backend_registry {
@@ -280,7 +285,7 @@ struct ggml_backend_registry {
// remove devices
devices.erase(
std::remove_if(devices.begin(), devices.end(),
@@ -66,17 +71,12 @@ index 899d16f2..ac5cda07 100644
devices.end());
// remove backend
@@ -338,7 +338,12 @@ size_t ggml_backend_dev_count() {
@@ -338,7 +343,7 @@ size_t ggml_backend_dev_count() {
ggml_backend_dev_t ggml_backend_dev_get(size_t index) {
GGML_ASSERT(index < ggml_backend_dev_count());
- return get_reg().devices[index];
+ auto devices = get_reg().devices;
+ if (!std::is_heap(devices.begin(), devices.end())) {
+ std::make_heap(devices.begin(), devices.end(), [](const auto & a, const auto & b) { return a.second < b.second; });
+ }
+
+ return devices[index].first;
+ return get_reg().devices[index].first;
}
ggml_backend_dev_t ggml_backend_dev_by_name(const char * name) {

View File

@@ -8,7 +8,7 @@ Subject: [PATCH] try/catch backend load
1 file changed, 23 insertions(+), 22 deletions(-)
diff --git a/ggml/src/ggml-backend-reg.cpp b/ggml/src/ggml-backend-reg.cpp
index ac5cda07..374c3b21 100644
index 135f7df0..84b21dd8 100644
--- a/ggml/src/ggml-backend-reg.cpp
+++ b/ggml/src/ggml-backend-reg.cpp
@@ -512,32 +512,33 @@ static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent,

View File

@@ -8,14 +8,12 @@ import (
"fmt"
"log"
"log/slog"
"math"
"net"
"net/http"
"os"
"path/filepath"
"regexp"
"runtime"
"sort"
"strconv"
"strings"
"sync"
@@ -50,9 +48,8 @@ type Sequence struct {
// inputs that have been added to a batch but not yet submitted to Decode
pendingInputs []input
// TODO: update this comment
// tokens that have been generated but not returned yet (e.g. for stop sequences)
pendingResponses []CompletionResponse
pendingResponses []string
// input cache being used by this sequence
cache *InputCacheSlot
@@ -62,7 +59,7 @@ type Sequence struct {
crossAttention bool
// channel to send responses over
responses chan CompletionResponse
responses chan string
// channel to stop decoding (such as if the remote connection is closed)
quit chan bool
@@ -86,11 +83,6 @@ type Sequence struct {
doneReason string
logits []float32
// number of logprobs to return with the completion response
logprobs int
// Metrics
startProcessingTime time.Time
startGenerationTime time.Time
@@ -104,7 +96,6 @@ type NewSequenceParams struct {
numKeep int
samplingParams *llama.SamplingParams
embedding bool
logprobs int
}
func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequenceParams) (*Sequence, error) {
@@ -157,15 +148,14 @@ func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequen
numPromptInputs: len(inputs),
startProcessingTime: startTime,
numPredict: params.numPredict,
pendingResponses: make([]CompletionResponse, 0),
responses: make(chan CompletionResponse, 100),
pendingResponses: make([]string, 0),
responses: make(chan string, 100),
quit: make(chan bool, 1),
embedding: make(chan []float32, 1),
samplingCtx: sc,
embeddingOnly: params.embedding,
stop: params.stop,
numKeep: params.numKeep,
logprobs: params.logprobs,
}, nil
}
@@ -284,37 +274,29 @@ func (s *Server) allNil() bool {
}
func flushPending(seq *Sequence) bool {
if len(seq.pendingResponses) == 0 {
joined := strings.Join(seq.pendingResponses, "")
seq.pendingResponses = []string{}
// Check if there are any partial UTF-8 characters remaining.
// We already check and queue as we are generating but some may
// still make it here:
// - Sequence is ending, e.g. generation limit has been hit
// - Invalid characters in the middle of a string
// This is a stricter check to ensure we never output invalid Unicode.
for !utf8.ValidString(joined) {
joined = joined[:len(joined)-1]
}
if len(joined) == 0 {
return true
}
resps := []CompletionResponse{}
for _, resp := range seq.pendingResponses {
resps = append(resps, resp)
select {
case seq.responses <- joined:
return true
case <-seq.quit:
return false
}
seq.pendingResponses = []CompletionResponse{}
// TODO: figure out this result logic
result := false
for _, resp := range resps {
// Check if there are any partial UTF-8 characters remaining.
// We already check and queue as we are generating but some may
// still make it here:
// - Sequence is ending, e.g. generation limit has been hit
// - Invalid characters in the middle of a string
// This is a stricter check to ensure we never output invalid Unicode.
for !utf8.ValidString(resp.Content) {
resp.Content = resp.Content[:len(resp.Content)-1]
}
select {
case seq.responses <- resp:
result = true
case <-seq.quit:
result = false
}
}
return result
}
func (s *Server) removeSequence(seqIndex int, reason string) {
@@ -368,63 +350,6 @@ func (s *Server) run(ctx context.Context) {
}
}
// TokenProbs represents probability information for a token
type TokenProbs struct {
TokenID int `json:"id"`
Logit float32 `json:"logit"`
Prob float32 `json:"prob"`
LogProb float32 `json:"logprob"`
Token string `json:"token"`
}
// probs returns sorted token probabilities for a specific token index
func probs(logits []float32, vocabSize int) []TokenProbs {
probs := make([]TokenProbs, vocabSize)
// Initialize token data with logits
for i := 0; i < vocabSize; i++ {
probs[i] = TokenProbs{
TokenID: i,
Logit: logits[i],
}
}
// Sort tokens by logits in descending order
sort.Slice(probs, func(i, j int) bool {
return probs[i].Logit > probs[j].Logit
})
// Apply softmax
maxLogit := probs[0].Logit
var sum float32 = 0.0
for i := range probs {
p := float32(math.Exp(float64(probs[i].Logit - maxLogit)))
probs[i].Prob = p
sum += p
}
// Normalize probabilities and calculate log probs
for i := range probs {
prob := probs[i].Prob / sum
probs[i].Prob = prob
probs[i].LogProb = float32(math.Log(float64(prob)))
}
return probs
}
// probs returns sorted token probabilities for a specific token index
func (s *Server) probs(seq *Sequence) []TokenProbs {
// Get logits for the specific token index
logits := s.lc.GetLogits()
seq.logits = make([]float32, len(logits))
copy(seq.logits, logits)
vocabSize := s.model.NumVocab()
return probs(logits, vocabSize)
}
// TODO (jmorganca): processBatch should be simplified, removing:
// * sampling
// * stop token checking
@@ -558,19 +483,6 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
seq.numPredicted++
resp := CompletionResponse{Content: piece}
if seq.logprobs > 0 {
// TODO: return selected token in logprobs always
resp.LogProbs = s.probs(seq)
// TODO: fix this logprobs limit
resp.LogProbs = resp.LogProbs[:min(len(resp.LogProbs), seq.logprobs)]
for i := range resp.LogProbs {
// decode the token id to a piece
resp.LogProbs[i].Token = s.model.TokenToPiece(resp.LogProbs[i].TokenID)
}
}
// if it's an end of sequence token, break
if s.model.TokenIsEog(token) {
// TODO (jmorganca): we should send this back
@@ -583,21 +495,16 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
seq.inputs = []input{{token: token}}
// TODO: add probs here
seq.pendingResponses = append(seq.pendingResponses, resp)
var sequence string
for _, r := range seq.pendingResponses {
sequence += r.Content
}
seq.pendingResponses = append(seq.pendingResponses, piece)
sequence := strings.Join(seq.pendingResponses, "")
if ok, stop := findStop(sequence, seq.stop); ok {
slog.Debug("hit stop token", "pending", seq.pendingResponses, "stop", stop)
// TODO: fix this stop sequence caching
var tokenTruncated bool
origLen := len(sequence)
sequence, tokenTruncated = truncateStop(sequence, stop)
newLen := len(sequence)
origLen := len(seq.pendingResponses)
seq.pendingResponses, tokenTruncated = truncateStop(seq.pendingResponses, stop)
newLen := len(seq.pendingResponses)
// Update the cache based on the tokens that will be returned:
// - We have 1 token more than is currently in the cache because
@@ -668,7 +575,6 @@ type CompletionRequest struct {
Images []ImageData `json:"image_data"`
Grammar string `json:"grammar"`
CachePrompt bool `json:"cache_prompt"`
Logprobs int `json:"logprobs,omitempty"`
Options
}
@@ -684,10 +590,8 @@ type CompletionResponse struct {
Content string `json:"content"`
Stop bool `json:"stop"`
Model string `json:"model,omitempty"`
Prompt string `json:"prompt,omitempty"`
LogProbs []TokenProbs `json:"logprobs,omitempty"`
Model string `json:"model,omitempty"`
Prompt string `json:"prompt,omitempty"`
StoppedLimit bool `json:"stopped_limit,omitempty"`
PredictedN int `json:"predicted_n,omitempty"`
PredictedMS float64 `json:"predicted_ms,omitempty"`
@@ -705,6 +609,10 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
return
}
// Set the headers to indicate streaming
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Transfer-Encoding", "chunked")
flusher, ok := w.(http.Flusher)
if !ok {
http.Error(w, "Streaming not supported", http.StatusInternalServerError)
@@ -733,7 +641,6 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
numKeep: req.NumKeep,
samplingParams: &samplingParams,
embedding: false,
logprobs: req.Logprobs,
})
if err != nil {
http.Error(w, fmt.Sprintf("Failed to create new sequence: %v", err), http.StatusInternalServerError)
@@ -781,10 +688,11 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
case <-r.Context().Done():
close(seq.quit)
return
case resp, ok := <-seq.responses:
case content, ok := <-seq.responses:
if ok {
fmt.Println("response", resp)
if err := json.NewEncoder(w).Encode(&resp); err != nil {
if err := json.NewEncoder(w).Encode(&CompletionResponse{
Content: content,
}); err != nil {
http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
close(seq.quit)
return

View File

@@ -1,58 +0,0 @@
package runner
import (
"math"
"testing"
)
func TestProbs(t *testing.T) {
// Input test data
logits := []float32{1.0, 2.0, 0.5, -1.0}
vocabSize := 4
want := []TokenProbs{
{TokenID: 1, Logit: 2.0}, // Highest logit
{TokenID: 0, Logit: 1.0}, // Second highest
{TokenID: 2, Logit: 0.5}, // Third
{TokenID: 3, Logit: -1.0}, // Lowest
}
got := probs(logits, vocabSize)
// Test 1: Check sorting order
for i := 0; i < len(got)-1; i++ {
if got[i].Logit < got[i+1].Logit {
t.Errorf("probs not properly sorted: logit at pos %d (%f) < logit at pos %d (%f)",
i, got[i].Logit, i+1, got[i+1].Logit)
}
}
// Test 2: Check probability normalization
var sum float32
for _, p := range got {
sum += p.Prob
}
if math.Abs(float64(sum-1.0)) > 1e-6 {
t.Errorf("probabilities do not sum to 1: got %v", sum)
}
// Test 3: Check token IDs match expected order
for i, want := range want {
if got[i].TokenID != want.TokenID {
t.Errorf("wrong token ID at position %d: got %d, want %d",
i, got[i].TokenID, want.TokenID)
}
if got[i].Logit != want.Logit {
t.Errorf("wrong logit at position %d: got %f, want %f",
i, got[i].Logit, want.Logit)
}
}
// Test 4: Check log probs are correctly calculated
for i, p := range got {
expectedLogProb := float32(math.Log(float64(p.Prob)))
if math.Abs(float64(p.LogProb-expectedLogProb)) > 1e-6 {
t.Errorf("wrong log prob at position %d: got %f, want %f",
i, p.LogProb, expectedLogProb)
}
}
}

View File

@@ -26,15 +26,43 @@ func containsStopSuffix(sequence string, stops []string) bool {
return false
}
// truncateStop removes the provided stop string from sequence,
// returning both the truncated sequence and a bool indicating if truncation occurred
func truncateStop(sequence string, stop string) (string, bool) {
index := strings.Index(sequence, stop)
// truncateStop removes the provided stop string from pieces,
// returning the partial pieces with stop removed, including truncating
// the last piece if required (and signalling if this was the case)
func truncateStop(pieces []string, stop string) ([]string, bool) {
joined := strings.Join(pieces, "")
index := strings.Index(joined, stop)
if index == -1 {
return sequence, false
return pieces, false
}
return sequence[:index], true
joined = joined[:index]
// Split truncated string back into pieces of original lengths
lengths := make([]int, len(pieces))
for i, piece := range pieces {
lengths[i] = len(piece)
}
var result []string
tokenTruncated := false
start := 0
for _, length := range lengths {
if start >= len(joined) {
break
}
end := start + length
if end > len(joined) {
end = len(joined)
tokenTruncated = true
}
result = append(result, joined[start:end])
start = end
}
return result, tokenTruncated
}
func incompleteUnicode(token string) bool {

View File

@@ -1,60 +1,60 @@
package runner
import (
"reflect"
"testing"
)
func TestTruncateStop(t *testing.T) {
tests := []struct {
name string
sequence string
pieces []string
stop string
expected string
expected []string
expectedTrunc bool
}{
{
name: "Single word",
sequence: "helloworld",
pieces: []string{"hello", "world"},
stop: "world",
expected: "hello",
expectedTrunc: true,
expected: []string{"hello"},
expectedTrunc: false,
},
{
name: "Partial",
sequence: "hellowor",
pieces: []string{"hello", "wor"},
stop: "or",
expected: "hellow",
expected: []string{"hello", "w"},
expectedTrunc: true,
},
{
name: "Suffix",
sequence: "Hello there!",
pieces: []string{"Hello", " there", "!"},
stop: "!",
expected: "Hello there",
expected: []string{"Hello", " there"},
expectedTrunc: false,
},
{
name: "Suffix partial",
pieces: []string{"Hello", " the", "re!"},
stop: "there!",
expected: []string{"Hello", " "},
expectedTrunc: true,
},
{
name: "Middle",
sequence: "hello wor",
pieces: []string{"hello", " wor"},
stop: "llo w",
expected: "he",
expected: []string{"he"},
expectedTrunc: true,
},
{
name: "No stop found",
sequence: "hello world",
stop: "xyz",
expected: "hello world",
expectedTrunc: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result, truncated := truncateStop(tt.sequence, tt.stop)
if result != tt.expected || truncated != tt.expectedTrunc {
t.Errorf("truncateStop(%q, %q): have %q (%v); want %q (%v)",
tt.sequence, tt.stop, result, truncated, tt.expected, tt.expectedTrunc)
result, resultTrunc := truncateStop(tt.pieces, tt.stop)
if !reflect.DeepEqual(result, tt.expected) || resultTrunc != tt.expectedTrunc {
t.Errorf("truncateStop(%v, %s): have %v (%v); want %v (%v)", tt.pieces, tt.stop, result, resultTrunc, tt.expected, tt.expectedTrunc)
}
})
}

View File

@@ -320,9 +320,8 @@ func NewLlamaServer(gpus discover.GpuInfoList, model string, ggml *GGML, adapter
return nil, fmt.Errorf("unable to lookup executable path: %w", err)
}
exe, err = filepath.EvalSymlinks(exe)
if err != nil {
return nil, fmt.Errorf("unable to evaluate symlinks for executable path: %w", err)
if eval, err := filepath.EvalSymlinks(exe); err == nil {
exe = eval
}
// TODO - once fully switched to the Go runner, load the model here for tokenize/detokenize cgo access
@@ -644,22 +643,12 @@ type ImageData struct {
AspectRatioID int `json:"aspect_ratio_id"`
}
// TokenProbs represents probability information for a token
type TokenProbs struct {
TokenID int `json:"id"`
Logit float32 `json:"logit"`
Prob float32 `json:"prob"`
LogProb float32 `json:"logprob"`
Token string `json:"token"`
}
type completion struct {
Content string `json:"content"`
Model string `json:"model"`
Prompt string `json:"prompt"`
Stop bool `json:"stop"`
StoppedLimit bool `json:"stopped_limit"`
LogProbs []TokenProbs `json:"logprobs"`
Content string `json:"content"`
Model string `json:"model"`
Prompt string `json:"prompt"`
Stop bool `json:"stop"`
StoppedLimit bool `json:"stopped_limit"`
Timings struct {
PredictedN int `json:"predicted_n"`
@@ -670,16 +659,14 @@ type completion struct {
}
type CompletionRequest struct {
Prompt string
Format json.RawMessage
Images []ImageData
LogProbs int
Options *api.Options
Prompt string
Format json.RawMessage
Images []ImageData
Options *api.Options
}
type CompletionResponse struct {
Content string
LogProbs []TokenProbs
DoneReason string
Done bool
PromptEvalCount int
@@ -710,12 +697,9 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
"seed": req.Options.Seed,
"stop": req.Options.Stop,
"image_data": req.Images,
"logprobs": req.LogProbs,
"cache_prompt": true,
}
fmt.Println("completion request:", request)
if len(req.Format) > 0 {
switch string(req.Format) {
case `null`, `""`:
@@ -811,6 +795,7 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
continue
}
// slog.Debug("got line", "line", string(line))
evt, ok := bytes.CutPrefix(line, []byte("data: "))
if !ok {
evt = line
@@ -836,8 +821,7 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
if c.Content != "" {
fn(CompletionResponse{
Content: c.Content,
LogProbs: c.LogProbs,
Content: c.Content,
})
}
@@ -854,7 +838,6 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
PromptEvalDuration: parseDurationMs(c.Timings.PromptMS),
EvalCount: c.Timings.PredictedN,
EvalDuration: parseDurationMs(c.Timings.PredictedMS),
LogProbs: c.LogProbs,
})
return nil
}

View File

@@ -215,6 +215,11 @@ struct ggml_backend_registry {
GGML_LOG_DEBUG("%s: registered device %s (%s)\n", __func__, ggml_backend_dev_name(device), ggml_backend_dev_description(device));
#endif
devices.push_back({device, score});
std::stable_sort(devices.begin(), devices.end(),
[](const auto & a, const auto & b) {
return a.second > b.second;
}
);
}
ggml_backend_reg_t load_backend(const std::wstring & path, bool silent) {
@@ -338,12 +343,7 @@ size_t ggml_backend_dev_count() {
ggml_backend_dev_t ggml_backend_dev_get(size_t index) {
GGML_ASSERT(index < ggml_backend_dev_count());
auto devices = get_reg().devices;
if (!std::is_heap(devices.begin(), devices.end())) {
std::make_heap(devices.begin(), devices.end(), [](const auto & a, const auto & b) { return a.second < b.second; });
}
return devices[index].first;
return get_reg().devices[index].first;
}
ggml_backend_dev_t ggml_backend_dev_by_name(const char * name) {

View File

@@ -293,13 +293,11 @@ func (s *Server) GenerateHandler(c *gin.Context) {
var sb strings.Builder
defer close(ch)
if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
Prompt: prompt,
Images: images,
Format: req.Format,
LogProbs: req.LogProbs,
Options: opts,
Prompt: prompt,
Images: images,
Format: req.Format,
Options: opts,
}, func(cr llm.CompletionResponse) {
fmt.Printf("banana: %#v\n", cr)
res := api.GenerateResponse{
Model: req.Model,
CreatedAt: time.Now().UTC(),
@@ -313,13 +311,6 @@ func (s *Server) GenerateHandler(c *gin.Context) {
EvalDuration: cr.EvalDuration,
},
}
for _, p := range cr.LogProbs {
res.LogProbs = append(res.LogProbs, api.TokenProbs{
TokenID: p.TokenID,
LogProb: p.LogProb,
Token: p.Token,
})
}
if _, err := sb.WriteString(cr.Content); err != nil {
ch <- gin.H{"error": err.Error()}
@@ -1475,11 +1466,10 @@ func (s *Server) ChatHandler(c *gin.Context) {
var sb strings.Builder
var toolCallIndex int = 0
if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
Prompt: prompt,
Images: images,
Format: req.Format,
LogProbs: req.LogProbs,
Options: opts,
Prompt: prompt,
Images: images,
Format: req.Format,
Options: opts,
}, func(r llm.CompletionResponse) {
res := api.ChatResponse{
Model: req.Model,
@@ -1494,13 +1484,6 @@ func (s *Server) ChatHandler(c *gin.Context) {
EvalDuration: r.EvalDuration,
},
}
for _, p := range r.LogProbs {
res.LogProbs = append(res.LogProbs, api.TokenProbs{
TokenID: p.TokenID,
LogProb: p.LogProb,
Token: p.Token,
})
}
if r.Done {
res.TotalDuration = time.Since(checkpointStart)