Compare commits

..

9 Commits

Author SHA1 Message Date
Bruce MacDonald
b88489a87e ... 2025-02-20 09:36:36 -08:00
Bruce MacDonald
fdbb0b5cfe prototype 2025-02-13 15:22:15 -08:00
Bruce MacDonald
64f95067ba ... 2025-02-13 14:02:04 -08:00
Bruce MacDonald
6dfcdec2da send completion response on chan 2025-02-12 17:03:52 -08:00
Bruce MacDonald
7d16ec8fe8 print logprobs 2025-02-12 16:36:03 -08:00
Clinton
82658c3eec readme: add Homebrew to package managers section (#9052) 2025-02-12 11:17:39 -08:00
bloominstrong
378d6e1e6a docs: fix nix package link (#9045)
removing the channel tag from the url so it will always go to the current stable channel.
2025-02-12 09:16:26 -08:00
Hugues Chocart
afa55bc70c doc: fix link for Abso (#9043) 2025-02-12 09:15:08 -08:00
Michael Yang
49df03da9a fix: harden backend loading (#9024)
* wrap ggml_backend_load_best in try/catch
* ignore non-ollama paths
2025-02-11 15:36:53 -08:00
16 changed files with 440 additions and 487 deletions

View File

@@ -437,9 +437,10 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [Pacman](https://archlinux.org/packages/extra/x86_64/ollama/) - [Pacman](https://archlinux.org/packages/extra/x86_64/ollama/)
- [Gentoo](https://github.com/gentoo/guru/tree/master/app-misc/ollama) - [Gentoo](https://github.com/gentoo/guru/tree/master/app-misc/ollama)
- [Homebrew](https://formulae.brew.sh/formula/ollama)
- [Helm Chart](https://artifacthub.io/packages/helm/ollama-helm/ollama) - [Helm Chart](https://artifacthub.io/packages/helm/ollama-helm/ollama)
- [Guix channel](https://codeberg.org/tusharhero/ollama-guix) - [Guix channel](https://codeberg.org/tusharhero/ollama-guix)
- [Nix package](https://search.nixos.org/packages?channel=24.05&show=ollama&from=0&size=50&sort=relevance&type=packages&query=ollama) - [Nix package](https://search.nixos.org/packages?show=ollama&from=0&size=50&sort=relevance&type=packages&query=ollama)
- [Flox](https://flox.dev/blog/ollama-part-one) - [Flox](https://flox.dev/blog/ollama-part-one)
### Libraries ### Libraries
@@ -494,7 +495,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [multi-llm-ts](https://github.com/nbonamy/multi-llm-ts) (A Typescript/JavaScript library allowing access to different LLM in unified API) - [multi-llm-ts](https://github.com/nbonamy/multi-llm-ts) (A Typescript/JavaScript library allowing access to different LLM in unified API)
- [LlmTornado](https://github.com/lofcz/llmtornado) (C# library providing a unified interface for major FOSS & Commercial inference APIs) - [LlmTornado](https://github.com/lofcz/llmtornado) (C# library providing a unified interface for major FOSS & Commercial inference APIs)
- [Ollama for Zig](https://github.com/dravenk/ollama-zig) - [Ollama for Zig](https://github.com/dravenk/ollama-zig)
- [Abso](https://github.com/lunary-ai/abso/blob/main/README.md#ollama) (OpenAI-compatible TypeScript SDK for any LLM provider) - [Abso](https://github.com/lunary-ai/abso) (OpenAI-compatible TypeScript SDK for any LLM provider)
### Mobile ### Mobile

View File

@@ -18,6 +18,7 @@ import (
"bytes" "bytes"
"context" "context"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
@@ -29,28 +30,6 @@ import (
"github.com/ollama/ollama/version" "github.com/ollama/ollama/version"
) )
// StatusError is an error with an HTTP status code and message,
// it is parsed on the client-side and not returned from the API
type StatusError struct {
StatusCode int // e.g. 200
Status string // e.g. "200 OK"
ErrorResponse
}
func (e StatusError) Error() string {
switch {
case e.Status != "" && e.Err != "":
return fmt.Sprintf("%s: %s", e.Status, e.Err)
case e.Status != "":
return e.Status
case e.Err != "":
return e.Err
default:
// this should not happen
return "something went wrong, please see the ollama server logs for details"
}
}
// Client encapsulates client state for interacting with the ollama // Client encapsulates client state for interacting with the ollama
// service. Use [ClientFromEnvironment] to create new Clients. // service. Use [ClientFromEnvironment] to create new Clients.
type Client struct { type Client struct {
@@ -68,7 +47,7 @@ func checkError(resp *http.Response, body []byte) error {
err := json.Unmarshal(body, &apiError) err := json.Unmarshal(body, &apiError)
if err != nil { if err != nil {
// Use the full body as the message if we fail to decode a response. // Use the full body as the message if we fail to decode a response.
apiError.Err = string(body) apiError.ErrorMessage = string(body)
} }
return apiError return apiError
@@ -153,7 +132,7 @@ func (c *Client) do(ctx context.Context, method, path string, reqData, respData
const maxBufferSize = 512 * format.KiloByte const maxBufferSize = 512 * format.KiloByte
func (c *Client) stream(ctx context.Context, method, path string, data any, fn func([]byte) error) error { func (c *Client) stream(ctx context.Context, method, path string, data any, fn func([]byte) error) error {
var buf io.Reader var buf *bytes.Buffer
if data != nil { if data != nil {
bts, err := json.Marshal(data) bts, err := json.Marshal(data)
if err != nil { if err != nil {
@@ -184,22 +163,24 @@ func (c *Client) stream(ctx context.Context, method, path string, data any, fn f
scanBuf := make([]byte, 0, maxBufferSize) scanBuf := make([]byte, 0, maxBufferSize)
scanner.Buffer(scanBuf, maxBufferSize) scanner.Buffer(scanBuf, maxBufferSize)
for scanner.Scan() { for scanner.Scan() {
bts := scanner.Bytes() var errorResponse struct {
Error string `json:"error,omitempty"`
}
var errorResponse ErrorResponse bts := scanner.Bytes()
if err := json.Unmarshal(bts, &errorResponse); err != nil { if err := json.Unmarshal(bts, &errorResponse); err != nil {
return fmt.Errorf("unmarshal: %w", err) return fmt.Errorf("unmarshal: %w", err)
} }
if errorResponse.Err != "" { if errorResponse.Error != "" {
return errorResponse return errors.New(errorResponse.Error)
} }
if response.StatusCode >= http.StatusBadRequest { if response.StatusCode >= http.StatusBadRequest {
return StatusError{ return StatusError{
StatusCode: response.StatusCode, StatusCode: response.StatusCode,
Status: response.Status, Status: response.Status,
ErrorResponse: errorResponse, ErrorMessage: errorResponse.Error,
} }
} }

View File

@@ -1,12 +1,6 @@
package api package api
import ( import (
"context"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"net/url"
"testing" "testing"
) )
@@ -49,270 +43,3 @@ func TestClientFromEnvironment(t *testing.T) {
}) })
} }
} }
// testError represents an internal error type for testing different error formats
type testError struct {
message string // basic error message
structured *ErrorResponse // structured error response, nil for basic format
statusCode int
}
func (e testError) Error() string {
return e.message
}
func TestClientStream(t *testing.T) {
testCases := []struct {
name string
responses []any
wantErr string
}{
{
name: "basic error format",
responses: []any{
testError{
message: "test error message",
statusCode: http.StatusBadRequest,
},
},
wantErr: "test error message",
},
{
name: "structured error format",
responses: []any{
testError{
message: "test structured error",
structured: &ErrorResponse{
Err: "test structured error",
Hint: "test hint",
},
statusCode: http.StatusBadRequest,
},
},
wantErr: "test structured error\ntest hint",
},
{
name: "error after chunks - basic format",
responses: []any{
ChatResponse{Message: Message{Content: "partial 1"}},
ChatResponse{Message: Message{Content: "partial 2"}},
testError{
message: "mid-stream basic error",
statusCode: http.StatusOK,
},
},
wantErr: "mid-stream basic error",
},
{
name: "error after chunks - structured format",
responses: []any{
ChatResponse{Message: Message{Content: "partial 1"}},
ChatResponse{Message: Message{Content: "partial 2"}},
testError{
message: "mid-stream structured error",
structured: &ErrorResponse{
Err: "mid-stream structured error",
Hint: "additional context",
},
statusCode: http.StatusOK,
},
},
wantErr: "mid-stream structured error\nadditional context",
},
{
name: "successful stream completion",
responses: []any{
ChatResponse{Message: Message{Content: "chunk 1"}},
ChatResponse{Message: Message{Content: "chunk 2"}},
ChatResponse{
Message: Message{Content: "final chunk"},
Done: true,
DoneReason: "stop",
},
},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
flusher, ok := w.(http.Flusher)
if !ok {
t.Fatal("expected http.Flusher")
}
w.Header().Set("Content-Type", "application/x-ndjson")
for _, resp := range tc.responses {
if errResp, ok := resp.(testError); ok {
w.WriteHeader(errResp.statusCode)
var err error
if errResp.structured != nil {
err = json.NewEncoder(w).Encode(errResp.structured)
} else {
err = json.NewEncoder(w).Encode(map[string]string{
"error": errResp.message,
})
}
if err != nil {
t.Fatal("failed to encode error response:", err)
}
return
}
if err := json.NewEncoder(w).Encode(resp); err != nil {
t.Fatalf("failed to encode response: %v", err)
}
flusher.Flush()
}
}))
defer ts.Close()
client := NewClient(&url.URL{Scheme: "http", Host: ts.Listener.Addr().String()}, http.DefaultClient)
var receivedChunks []ChatResponse
err := client.stream(context.Background(), http.MethodPost, "/v1/chat", nil, func(chunk []byte) error {
var resp ChatResponse
if err := json.Unmarshal(chunk, &resp); err != nil {
return fmt.Errorf("failed to unmarshal chunk: %w", err)
}
receivedChunks = append(receivedChunks, resp)
return nil
})
if tc.wantErr != "" {
if err == nil {
t.Fatalf("got nil, want error %q", tc.wantErr)
}
if err.Error() != tc.wantErr {
t.Errorf("error message mismatch: got %q, want %q", err.Error(), tc.wantErr)
}
} else {
if err != nil {
t.Errorf("got error %q, want nil", err)
}
}
})
}
}
func TestClientDo(t *testing.T) {
testCases := []struct {
name string
response any
wantErr string
}{
{
name: "basic error format",
response: testError{
message: "test error message",
statusCode: http.StatusBadRequest,
},
wantErr: "test error message",
},
{
name: "structured error format",
response: testError{
message: "test structured error",
structured: &ErrorResponse{
Err: "test structured error",
Hint: "test hint",
},
statusCode: http.StatusBadRequest,
},
wantErr: "test structured error",
},
{
name: "server error - basic format",
response: testError{
message: "internal error",
statusCode: http.StatusInternalServerError,
},
wantErr: "internal error",
},
{
name: "server error - structured format",
response: testError{
message: "internal server error",
structured: &ErrorResponse{
Err: "internal server error",
Hint: "please try again later",
},
statusCode: http.StatusInternalServerError,
},
wantErr: "internal server error",
},
{
name: "successful response",
response: struct {
ID string `json:"id"`
Success bool `json:"success"`
}{
ID: "msg_123",
Success: true,
},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if errResp, ok := tc.response.(testError); ok {
w.WriteHeader(errResp.statusCode)
var err error
if errResp.structured != nil {
err = json.NewEncoder(w).Encode(errResp.structured)
} else {
err = json.NewEncoder(w).Encode(map[string]string{
"error": errResp.message,
})
}
if err != nil {
t.Fatal("failed to encode error response:", err)
}
return
}
w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(tc.response); err != nil {
t.Fatalf("failed to encode response: %v", err)
}
}))
defer ts.Close()
client := NewClient(&url.URL{Scheme: "http", Host: ts.Listener.Addr().String()}, http.DefaultClient)
var resp struct {
ID string `json:"id"`
Success bool `json:"success"`
}
err := client.do(context.Background(), http.MethodPost, "/v1/messages", nil, &resp)
if tc.wantErr != "" {
if err == nil {
t.Fatalf("got nil, want error %q", tc.wantErr)
}
if err.Error() != tc.wantErr {
t.Errorf("error message mismatch: got %q, want %q", err.Error(), tc.wantErr)
}
return
}
if err != nil {
t.Errorf("got error %q, want nil", err)
}
if expectedResp, ok := tc.response.(struct {
ID string `json:"id"`
Success bool `json:"success"`
}); ok {
if resp.ID != expectedResp.ID {
t.Errorf("response ID mismatch: got %q, want %q", resp.ID, expectedResp.ID)
}
if resp.Success != expectedResp.Success {
t.Errorf("response Success mismatch: got %v, want %v", resp.Success, expectedResp.Success)
}
}
})
}
}

View File

@@ -12,6 +12,27 @@ import (
"time" "time"
) )
// StatusError is an error with an HTTP status code and message.
type StatusError struct {
StatusCode int
Status string
ErrorMessage string `json:"error"`
}
func (e StatusError) Error() string {
switch {
case e.Status != "" && e.ErrorMessage != "":
return fmt.Sprintf("%s: %s", e.Status, e.ErrorMessage)
case e.Status != "":
return e.Status
case e.ErrorMessage != "":
return e.ErrorMessage
default:
// this should not happen
return "something went wrong, please see the ollama server logs for details"
}
}
// ImageData represents the raw binary data of an image file. // ImageData represents the raw binary data of an image file.
type ImageData []byte type ImageData []byte
@@ -56,6 +77,8 @@ type GenerateRequest struct {
// request, for multimodal models. // request, for multimodal models.
Images []ImageData `json:"images,omitempty"` Images []ImageData `json:"images,omitempty"`
LogProbs int `json:"logprobs,omitempty"`
// Options lists model-specific options. For example, temperature can be // Options lists model-specific options. For example, temperature can be
// set through this field, if the model supports it. // set through this field, if the model supports it.
Options map[string]interface{} `json:"options"` Options map[string]interface{} `json:"options"`
@@ -82,6 +105,8 @@ type ChatRequest struct {
// Tools is an optional list of tools the model has access to. // Tools is an optional list of tools the model has access to.
Tools `json:"tools,omitempty"` Tools `json:"tools,omitempty"`
LogProbs int `json:"logprobs,omitempty"`
// Options lists model-specific options. // Options lists model-specific options.
Options map[string]interface{} `json:"options"` Options map[string]interface{} `json:"options"`
} }
@@ -161,13 +186,20 @@ func (t *ToolFunction) String() string {
return string(bts) return string(bts)
} }
type TokenProbs struct {
TokenID int `json:"id"`
LogProb float32 `json:"logprob"`
Token string `json:"token"`
}
// ChatResponse is the response returned by [Client.Chat]. Its fields are // ChatResponse is the response returned by [Client.Chat]. Its fields are
// similar to [GenerateResponse]. // similar to [GenerateResponse].
type ChatResponse struct { type ChatResponse struct {
Model string `json:"model"` Model string `json:"model"`
CreatedAt time.Time `json:"created_at"` CreatedAt time.Time `json:"created_at"`
Message Message `json:"message"` Message Message `json:"message"`
DoneReason string `json:"done_reason,omitempty"` DoneReason string `json:"done_reason,omitempty"`
LogProbs []TokenProbs `json:"logprobs,omitempty"`
Done bool `json:"done"` Done bool `json:"done"`
@@ -431,6 +463,8 @@ type GenerateResponse struct {
// can be sent in the next request to keep a conversational memory. // can be sent in the next request to keep a conversational memory.
Context []int `json:"context,omitempty"` Context []int `json:"context,omitempty"`
LogProbs []TokenProbs `json:"logprobs,omitempty"`
Metrics Metrics
} }
@@ -640,22 +674,6 @@ func (d *Duration) UnmarshalJSON(b []byte) (err error) {
return nil return nil
} }
// ErrorResponse implements a structured error interface that is returned from the Ollama server
type ErrorResponse struct {
// Err is the error from the server. It helps with debugging the code-path
Err string `json:"error"`
// Hint is a user-friendly message about what went wrong, with suggested troubleshooting
Hint string `json:"hint"`
}
func (e ErrorResponse) Error() string {
if e.Hint == "" {
return e.Err
}
return fmt.Sprintf("%s\n%s", e.Err, e.Hint)
}
// FormatParams converts specified parameter options to their correct types // FormatParams converts specified parameter options to their correct types
func FormatParams(params map[string][]string) (map[string]interface{}, error) { func FormatParams(params map[string][]string) (map[string]interface{}, error) {
opts := Options{} opts := Options{}

View File

@@ -50,7 +50,7 @@ import (
_ "github.com/ollama/ollama/llama/llama.cpp/common" _ "github.com/ollama/ollama/llama/llama.cpp/common"
_ "github.com/ollama/ollama/llama/llama.cpp/examples/llava" _ "github.com/ollama/ollama/llama/llama.cpp/examples/llava"
_ "github.com/ollama/ollama/llama/llama.cpp/src" _ "github.com/ollama/ollama/llama/llama.cpp/src"
"github.com/ollama/ollama/ml/backend/ggml/ggml/src" ggml "github.com/ollama/ollama/ml/backend/ggml/ggml/src"
) )
func BackendInit() { func BackendInit() {
@@ -220,6 +220,19 @@ func (c *Context) GetEmbeddingsIth(i int) []float32 {
return embeddings return embeddings
} }
// GetLogits returns the logits from the last decode operation.
// The returned slice has length equal to the vocabulary size.
func (c *Context) GetLogits() []float32 {
logits := unsafe.Pointer(C.llama_get_logits(c.c))
if logits == nil {
return nil
}
// Get the number of vocabulary tokens to determine array size
vocabSize := c.Model().NumVocab()
return unsafe.Slice((*float32)(logits), vocabSize)
}
type ModelParams struct { type ModelParams struct {
NumGpuLayers int NumGpuLayers int
MainGpu int MainGpu int

View File

@@ -0,0 +1,69 @@
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: Michael Yang <mxyng@pm.me>
Date: Tue, 11 Feb 2025 14:06:36 -0800
Subject: [PATCH] try/catch backend load
---
ggml/src/ggml-backend-reg.cpp | 45 ++++++++++++++++++-----------------
1 file changed, 23 insertions(+), 22 deletions(-)
diff --git a/ggml/src/ggml-backend-reg.cpp b/ggml/src/ggml-backend-reg.cpp
index ac5cda07..374c3b21 100644
--- a/ggml/src/ggml-backend-reg.cpp
+++ b/ggml/src/ggml-backend-reg.cpp
@@ -512,32 +512,33 @@ static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent,
}
fs::directory_iterator dir_it(search_path, fs::directory_options::skip_permission_denied);
for (const auto & entry : dir_it) {
- if (entry.is_regular_file()) {
- std::wstring filename = entry.path().filename().wstring();
- std::wstring ext = entry.path().extension().wstring();
- if (filename.find(file_prefix) == 0 && ext == backend_filename_suffix()) {
- dl_handle_ptr handle { dl_load_library(entry.path().wstring()) };
- if (!handle && !silent) {
- GGML_LOG_ERROR("%s: failed to load %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str());
- }
- if (handle) {
+ try {
+ if (entry.is_regular_file()) {
+ std::wstring filename = entry.path().filename().wstring();
+ std::wstring ext = entry.path().extension().wstring();
+ if (filename.find(file_prefix) == 0 && ext == backend_filename_suffix()) {
+ dl_handle_ptr handle { dl_load_library(entry.path().wstring()) };
+ if (!handle) {
+ GGML_LOG_ERROR("%s: failed to load %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str());
+ continue;
+ }
+
auto score_fn = (ggml_backend_score_t) dl_get_sym(handle.get(), "ggml_backend_score");
- if (score_fn) {
- int s = score_fn();
-#ifndef NDEBUG
- GGML_LOG_DEBUG("%s: %s score: %d\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str(), s);
-#endif
- if (s > best_score) {
- best_score = s;
- best_path = entry.path().wstring();
- }
- } else {
- if (!silent) {
- GGML_LOG_INFO("%s: failed to find ggml_backend_score in %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str());
- }
+ if (!score_fn) {
+ GGML_LOG_DEBUG("%s: failed to find ggml_backend_score in %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str());
+ continue;
+ }
+
+ int s = score_fn();
+ GGML_LOG_DEBUG("%s: %s score: %d\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str(), s);
+ if (s > best_score) {
+ best_score = s;
+ best_path = entry.path().wstring();
}
}
}
+ } catch (const std::exception & e) {
+ GGML_LOG_ERROR("%s: failed to load %s: %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str(), e.what());
}
}
}

View File

@@ -8,12 +8,14 @@ import (
"fmt" "fmt"
"log" "log"
"log/slog" "log/slog"
"math"
"net" "net"
"net/http" "net/http"
"os" "os"
"path/filepath" "path/filepath"
"regexp" "regexp"
"runtime" "runtime"
"sort"
"strconv" "strconv"
"strings" "strings"
"sync" "sync"
@@ -48,8 +50,9 @@ type Sequence struct {
// inputs that have been added to a batch but not yet submitted to Decode // inputs that have been added to a batch but not yet submitted to Decode
pendingInputs []input pendingInputs []input
// TODO: update this comment
// tokens that have been generated but not returned yet (e.g. for stop sequences) // tokens that have been generated but not returned yet (e.g. for stop sequences)
pendingResponses []string pendingResponses []CompletionResponse
// input cache being used by this sequence // input cache being used by this sequence
cache *InputCacheSlot cache *InputCacheSlot
@@ -59,7 +62,7 @@ type Sequence struct {
crossAttention bool crossAttention bool
// channel to send responses over // channel to send responses over
responses chan string responses chan CompletionResponse
// channel to stop decoding (such as if the remote connection is closed) // channel to stop decoding (such as if the remote connection is closed)
quit chan bool quit chan bool
@@ -83,6 +86,11 @@ type Sequence struct {
doneReason string doneReason string
logits []float32
// number of logprobs to return with the completion response
logprobs int
// Metrics // Metrics
startProcessingTime time.Time startProcessingTime time.Time
startGenerationTime time.Time startGenerationTime time.Time
@@ -96,6 +104,7 @@ type NewSequenceParams struct {
numKeep int numKeep int
samplingParams *llama.SamplingParams samplingParams *llama.SamplingParams
embedding bool embedding bool
logprobs int
} }
func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequenceParams) (*Sequence, error) { func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequenceParams) (*Sequence, error) {
@@ -148,14 +157,15 @@ func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequen
numPromptInputs: len(inputs), numPromptInputs: len(inputs),
startProcessingTime: startTime, startProcessingTime: startTime,
numPredict: params.numPredict, numPredict: params.numPredict,
pendingResponses: make([]string, 0), pendingResponses: make([]CompletionResponse, 0),
responses: make(chan string, 100), responses: make(chan CompletionResponse, 100),
quit: make(chan bool, 1), quit: make(chan bool, 1),
embedding: make(chan []float32, 1), embedding: make(chan []float32, 1),
samplingCtx: sc, samplingCtx: sc,
embeddingOnly: params.embedding, embeddingOnly: params.embedding,
stop: params.stop, stop: params.stop,
numKeep: params.numKeep, numKeep: params.numKeep,
logprobs: params.logprobs,
}, nil }, nil
} }
@@ -274,29 +284,37 @@ func (s *Server) allNil() bool {
} }
func flushPending(seq *Sequence) bool { func flushPending(seq *Sequence) bool {
joined := strings.Join(seq.pendingResponses, "") if len(seq.pendingResponses) == 0 {
seq.pendingResponses = []string{}
// Check if there are any partial UTF-8 characters remaining.
// We already check and queue as we are generating but some may
// still make it here:
// - Sequence is ending, e.g. generation limit has been hit
// - Invalid characters in the middle of a string
// This is a stricter check to ensure we never output invalid Unicode.
for !utf8.ValidString(joined) {
joined = joined[:len(joined)-1]
}
if len(joined) == 0 {
return true return true
} }
resps := []CompletionResponse{}
select { for _, resp := range seq.pendingResponses {
case seq.responses <- joined: resps = append(resps, resp)
return true
case <-seq.quit:
return false
} }
seq.pendingResponses = []CompletionResponse{}
// TODO: figure out this result logic
result := false
for _, resp := range resps {
// Check if there are any partial UTF-8 characters remaining.
// We already check and queue as we are generating but some may
// still make it here:
// - Sequence is ending, e.g. generation limit has been hit
// - Invalid characters in the middle of a string
// This is a stricter check to ensure we never output invalid Unicode.
for !utf8.ValidString(resp.Content) {
resp.Content = resp.Content[:len(resp.Content)-1]
}
select {
case seq.responses <- resp:
result = true
case <-seq.quit:
result = false
}
}
return result
} }
func (s *Server) removeSequence(seqIndex int, reason string) { func (s *Server) removeSequence(seqIndex int, reason string) {
@@ -350,6 +368,63 @@ func (s *Server) run(ctx context.Context) {
} }
} }
// TokenProbs represents probability information for a token
type TokenProbs struct {
TokenID int `json:"id"`
Logit float32 `json:"logit"`
Prob float32 `json:"prob"`
LogProb float32 `json:"logprob"`
Token string `json:"token"`
}
// probs returns sorted token probabilities for a specific token index
func probs(logits []float32, vocabSize int) []TokenProbs {
probs := make([]TokenProbs, vocabSize)
// Initialize token data with logits
for i := 0; i < vocabSize; i++ {
probs[i] = TokenProbs{
TokenID: i,
Logit: logits[i],
}
}
// Sort tokens by logits in descending order
sort.Slice(probs, func(i, j int) bool {
return probs[i].Logit > probs[j].Logit
})
// Apply softmax
maxLogit := probs[0].Logit
var sum float32 = 0.0
for i := range probs {
p := float32(math.Exp(float64(probs[i].Logit - maxLogit)))
probs[i].Prob = p
sum += p
}
// Normalize probabilities and calculate log probs
for i := range probs {
prob := probs[i].Prob / sum
probs[i].Prob = prob
probs[i].LogProb = float32(math.Log(float64(prob)))
}
return probs
}
// probs returns sorted token probabilities for a specific token index
func (s *Server) probs(seq *Sequence) []TokenProbs {
// Get logits for the specific token index
logits := s.lc.GetLogits()
seq.logits = make([]float32, len(logits))
copy(seq.logits, logits)
vocabSize := s.model.NumVocab()
return probs(logits, vocabSize)
}
// TODO (jmorganca): processBatch should be simplified, removing: // TODO (jmorganca): processBatch should be simplified, removing:
// * sampling // * sampling
// * stop token checking // * stop token checking
@@ -483,6 +558,19 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
seq.numPredicted++ seq.numPredicted++
resp := CompletionResponse{Content: piece}
if seq.logprobs > 0 {
// TODO: return selected token in logprobs always
resp.LogProbs = s.probs(seq)
// TODO: fix this logprobs limit
resp.LogProbs = resp.LogProbs[:min(len(resp.LogProbs), seq.logprobs)]
for i := range resp.LogProbs {
// decode the token id to a piece
resp.LogProbs[i].Token = s.model.TokenToPiece(resp.LogProbs[i].TokenID)
}
}
// if it's an end of sequence token, break // if it's an end of sequence token, break
if s.model.TokenIsEog(token) { if s.model.TokenIsEog(token) {
// TODO (jmorganca): we should send this back // TODO (jmorganca): we should send this back
@@ -495,16 +583,21 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
seq.inputs = []input{{token: token}} seq.inputs = []input{{token: token}}
seq.pendingResponses = append(seq.pendingResponses, piece) // TODO: add probs here
sequence := strings.Join(seq.pendingResponses, "") seq.pendingResponses = append(seq.pendingResponses, resp)
var sequence string
for _, r := range seq.pendingResponses {
sequence += r.Content
}
if ok, stop := findStop(sequence, seq.stop); ok { if ok, stop := findStop(sequence, seq.stop); ok {
slog.Debug("hit stop token", "pending", seq.pendingResponses, "stop", stop) slog.Debug("hit stop token", "pending", seq.pendingResponses, "stop", stop)
// TODO: fix this stop sequence caching
var tokenTruncated bool var tokenTruncated bool
origLen := len(seq.pendingResponses) origLen := len(sequence)
seq.pendingResponses, tokenTruncated = truncateStop(seq.pendingResponses, stop) sequence, tokenTruncated = truncateStop(sequence, stop)
newLen := len(seq.pendingResponses) newLen := len(sequence)
// Update the cache based on the tokens that will be returned: // Update the cache based on the tokens that will be returned:
// - We have 1 token more than is currently in the cache because // - We have 1 token more than is currently in the cache because
@@ -575,6 +668,7 @@ type CompletionRequest struct {
Images []ImageData `json:"image_data"` Images []ImageData `json:"image_data"`
Grammar string `json:"grammar"` Grammar string `json:"grammar"`
CachePrompt bool `json:"cache_prompt"` CachePrompt bool `json:"cache_prompt"`
Logprobs int `json:"logprobs,omitempty"`
Options Options
} }
@@ -590,8 +684,10 @@ type CompletionResponse struct {
Content string `json:"content"` Content string `json:"content"`
Stop bool `json:"stop"` Stop bool `json:"stop"`
Model string `json:"model,omitempty"` Model string `json:"model,omitempty"`
Prompt string `json:"prompt,omitempty"` Prompt string `json:"prompt,omitempty"`
LogProbs []TokenProbs `json:"logprobs,omitempty"`
StoppedLimit bool `json:"stopped_limit,omitempty"` StoppedLimit bool `json:"stopped_limit,omitempty"`
PredictedN int `json:"predicted_n,omitempty"` PredictedN int `json:"predicted_n,omitempty"`
PredictedMS float64 `json:"predicted_ms,omitempty"` PredictedMS float64 `json:"predicted_ms,omitempty"`
@@ -609,10 +705,6 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
return return
} }
// Set the headers to indicate streaming
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Transfer-Encoding", "chunked")
flusher, ok := w.(http.Flusher) flusher, ok := w.(http.Flusher)
if !ok { if !ok {
http.Error(w, "Streaming not supported", http.StatusInternalServerError) http.Error(w, "Streaming not supported", http.StatusInternalServerError)
@@ -641,6 +733,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
numKeep: req.NumKeep, numKeep: req.NumKeep,
samplingParams: &samplingParams, samplingParams: &samplingParams,
embedding: false, embedding: false,
logprobs: req.Logprobs,
}) })
if err != nil { if err != nil {
http.Error(w, fmt.Sprintf("Failed to create new sequence: %v", err), http.StatusInternalServerError) http.Error(w, fmt.Sprintf("Failed to create new sequence: %v", err), http.StatusInternalServerError)
@@ -688,11 +781,10 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
case <-r.Context().Done(): case <-r.Context().Done():
close(seq.quit) close(seq.quit)
return return
case content, ok := <-seq.responses: case resp, ok := <-seq.responses:
if ok { if ok {
if err := json.NewEncoder(w).Encode(&CompletionResponse{ fmt.Println("response", resp)
Content: content, if err := json.NewEncoder(w).Encode(&resp); err != nil {
}); err != nil {
http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError) http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
close(seq.quit) close(seq.quit)
return return

View File

@@ -0,0 +1,58 @@
package runner
import (
"math"
"testing"
)
// TestProbs verifies that probs sorts tokens by descending logit and
// produces a normalized probability distribution with consistent
// log-probabilities.
func TestProbs(t *testing.T) {
	logits := []float32{1.0, 2.0, 0.5, -1.0}
	// Derive vocab size from the input instead of hard-coding it, so the
	// fixture can be changed in one place.
	vocabSize := len(logits)

	// Expected order after sorting by logit, highest first.
	want := []TokenProbs{
		{TokenID: 1, Logit: 2.0},  // Highest logit
		{TokenID: 0, Logit: 1.0},  // Second highest
		{TokenID: 2, Logit: 0.5},  // Third
		{TokenID: 3, Logit: -1.0}, // Lowest
	}

	got := probs(logits, vocabSize)

	// Guard before indexing so a short result fails cleanly.
	if len(got) != len(want) {
		t.Fatalf("wrong result length: got %d, want %d", len(got), len(want))
	}

	// Test 1: Check sorting order (logits must be non-increasing).
	for i := 0; i < len(got)-1; i++ {
		if got[i].Logit < got[i+1].Logit {
			t.Errorf("probs not properly sorted: logit at pos %d (%f) < logit at pos %d (%f)",
				i, got[i].Logit, i+1, got[i+1].Logit)
		}
	}

	// Test 2: Check probability normalization.
	var sum float32
	for _, p := range got {
		sum += p.Prob
	}
	if math.Abs(float64(sum-1.0)) > 1e-6 {
		t.Errorf("probabilities do not sum to 1: got %v", sum)
	}

	// Test 3: Check token IDs and logits match the expected sorted order.
	for i, w := range want {
		if got[i].TokenID != w.TokenID {
			t.Errorf("wrong token ID at position %d: got %d, want %d",
				i, got[i].TokenID, w.TokenID)
		}
		if got[i].Logit != w.Logit {
			t.Errorf("wrong logit at position %d: got %f, want %f",
				i, got[i].Logit, w.Logit)
		}
	}

	// Test 4: Check probabilities against an independently computed softmax.
	// Checking only LogProb == log(Prob) would pass even if the distribution
	// itself were wrong, so compare each Prob to exp(logit)/sum(exp(logits)).
	var denom float64
	for _, l := range logits {
		denom += math.Exp(float64(l))
	}
	for i, p := range got {
		wantProb := float32(math.Exp(float64(p.Logit)) / denom)
		if math.Abs(float64(p.Prob-wantProb)) > 1e-6 {
			t.Errorf("wrong prob at position %d: got %f, want %f",
				i, p.Prob, wantProb)
		}
	}

	// Test 5: Check log probs are the natural log of the probabilities.
	for i, p := range got {
		wantLogProb := float32(math.Log(float64(p.Prob)))
		if math.Abs(float64(p.LogProb-wantLogProb)) > 1e-6 {
			t.Errorf("wrong log prob at position %d: got %f, want %f",
				i, p.LogProb, wantLogProb)
		}
	}
}

View File

@@ -26,43 +26,15 @@ func containsStopSuffix(sequence string, stops []string) bool {
return false return false
} }
// truncateStop removes the provided stop string from pieces, // truncateStop removes the provided stop string from sequence,
// returning the partial pieces with stop removed, including truncating // returning both the truncated sequence and a bool indicating if truncation occurred
// the last piece if required (and signalling if this was the case) func truncateStop(sequence string, stop string) (string, bool) {
func truncateStop(pieces []string, stop string) ([]string, bool) { index := strings.Index(sequence, stop)
joined := strings.Join(pieces, "")
index := strings.Index(joined, stop)
if index == -1 { if index == -1 {
return pieces, false return sequence, false
} }
joined = joined[:index] return sequence[:index], true
// Split truncated string back into pieces of original lengths
lengths := make([]int, len(pieces))
for i, piece := range pieces {
lengths[i] = len(piece)
}
var result []string
tokenTruncated := false
start := 0
for _, length := range lengths {
if start >= len(joined) {
break
}
end := start + length
if end > len(joined) {
end = len(joined)
tokenTruncated = true
}
result = append(result, joined[start:end])
start = end
}
return result, tokenTruncated
} }
func incompleteUnicode(token string) bool { func incompleteUnicode(token string) bool {

View File

@@ -1,60 +1,60 @@
package runner package runner
import ( import (
"reflect"
"testing" "testing"
) )
func TestTruncateStop(t *testing.T) { func TestTruncateStop(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
pieces []string sequence string
stop string stop string
expected []string expected string
expectedTrunc bool expectedTrunc bool
}{ }{
{ {
name: "Single word", name: "Single word",
pieces: []string{"hello", "world"}, sequence: "helloworld",
stop: "world", stop: "world",
expected: []string{"hello"}, expected: "hello",
expectedTrunc: false, expectedTrunc: true,
}, },
{ {
name: "Partial", name: "Partial",
pieces: []string{"hello", "wor"}, sequence: "hellowor",
stop: "or", stop: "or",
expected: []string{"hello", "w"}, expected: "hellow",
expectedTrunc: true, expectedTrunc: true,
}, },
{ {
name: "Suffix", name: "Suffix",
pieces: []string{"Hello", " there", "!"}, sequence: "Hello there!",
stop: "!", stop: "!",
expected: []string{"Hello", " there"}, expected: "Hello there",
expectedTrunc: false,
},
{
name: "Suffix partial",
pieces: []string{"Hello", " the", "re!"},
stop: "there!",
expected: []string{"Hello", " "},
expectedTrunc: true, expectedTrunc: true,
}, },
{ {
name: "Middle", name: "Middle",
pieces: []string{"hello", " wor"}, sequence: "hello wor",
stop: "llo w", stop: "llo w",
expected: []string{"he"}, expected: "he",
expectedTrunc: true, expectedTrunc: true,
}, },
{
name: "No stop found",
sequence: "hello world",
stop: "xyz",
expected: "hello world",
expectedTrunc: false,
},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
result, resultTrunc := truncateStop(tt.pieces, tt.stop) result, truncated := truncateStop(tt.sequence, tt.stop)
if !reflect.DeepEqual(result, tt.expected) || resultTrunc != tt.expectedTrunc { if result != tt.expected || truncated != tt.expectedTrunc {
t.Errorf("truncateStop(%v, %s): have %v (%v); want %v (%v)", tt.pieces, tt.stop, result, resultTrunc, tt.expected, tt.expectedTrunc) t.Errorf("truncateStop(%q, %q): have %q (%v); want %q (%v)",
tt.sequence, tt.stop, result, truncated, tt.expected, tt.expectedTrunc)
} }
}) })
} }

View File

@@ -644,12 +644,22 @@ type ImageData struct {
AspectRatioID int `json:"aspect_ratio_id"` AspectRatioID int `json:"aspect_ratio_id"`
} }
// TokenProbs represents probability information for a token, as reported
// by the runner's logprobs support. Probabilities across the reported
// tokens sum to 1, and LogProb is the natural log of Prob (see TestProbs).
type TokenProbs struct {
	// TokenID is the vocabulary index of the token.
	TokenID int `json:"id"`
	// Logit is the raw, unnormalized model score for the token.
	Logit float32 `json:"logit"`
	// Prob is the normalized probability of the token.
	Prob float32 `json:"prob"`
	// LogProb is the natural log of Prob.
	LogProb float32 `json:"logprob"`
	// Token is the decoded text of the token.
	// NOTE(review): not populated by the probs test fixture — confirm
	// where it is filled in before relying on it.
	Token string `json:"token"`
}
type completion struct { type completion struct {
Content string `json:"content"` Content string `json:"content"`
Model string `json:"model"` Model string `json:"model"`
Prompt string `json:"prompt"` Prompt string `json:"prompt"`
Stop bool `json:"stop"` Stop bool `json:"stop"`
StoppedLimit bool `json:"stopped_limit"` StoppedLimit bool `json:"stopped_limit"`
LogProbs []TokenProbs `json:"logprobs"`
Timings struct { Timings struct {
PredictedN int `json:"predicted_n"` PredictedN int `json:"predicted_n"`
@@ -660,14 +670,16 @@ type completion struct {
} }
type CompletionRequest struct { type CompletionRequest struct {
Prompt string Prompt string
Format json.RawMessage Format json.RawMessage
Images []ImageData Images []ImageData
Options *api.Options LogProbs int
Options *api.Options
} }
type CompletionResponse struct { type CompletionResponse struct {
Content string Content string
LogProbs []TokenProbs
DoneReason string DoneReason string
Done bool Done bool
PromptEvalCount int PromptEvalCount int
@@ -698,9 +710,12 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
"seed": req.Options.Seed, "seed": req.Options.Seed,
"stop": req.Options.Stop, "stop": req.Options.Stop,
"image_data": req.Images, "image_data": req.Images,
"logprobs": req.LogProbs,
"cache_prompt": true, "cache_prompt": true,
} }
fmt.Println("completion request:", request)
if len(req.Format) > 0 { if len(req.Format) > 0 {
switch string(req.Format) { switch string(req.Format) {
case `null`, `""`: case `null`, `""`:
@@ -796,7 +811,6 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
continue continue
} }
// slog.Debug("got line", "line", string(line))
evt, ok := bytes.CutPrefix(line, []byte("data: ")) evt, ok := bytes.CutPrefix(line, []byte("data: "))
if !ok { if !ok {
evt = line evt = line
@@ -822,7 +836,8 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
if c.Content != "" { if c.Content != "" {
fn(CompletionResponse{ fn(CompletionResponse{
Content: c.Content, Content: c.Content,
LogProbs: c.LogProbs,
}) })
} }
@@ -839,6 +854,7 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
PromptEvalDuration: parseDurationMs(c.Timings.PromptMS), PromptEvalDuration: parseDurationMs(c.Timings.PromptMS),
EvalCount: c.Timings.PredictedN, EvalCount: c.Timings.PredictedN,
EvalDuration: parseDurationMs(c.Timings.PredictedMS), EvalDuration: parseDurationMs(c.Timings.PredictedMS),
LogProbs: c.LogProbs,
}) })
return nil return nil
} }

View File

@@ -512,32 +512,33 @@ static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent,
} }
fs::directory_iterator dir_it(search_path, fs::directory_options::skip_permission_denied); fs::directory_iterator dir_it(search_path, fs::directory_options::skip_permission_denied);
for (const auto & entry : dir_it) { for (const auto & entry : dir_it) {
if (entry.is_regular_file()) { try {
std::wstring filename = entry.path().filename().wstring(); if (entry.is_regular_file()) {
std::wstring ext = entry.path().extension().wstring(); std::wstring filename = entry.path().filename().wstring();
if (filename.find(file_prefix) == 0 && ext == backend_filename_suffix()) { std::wstring ext = entry.path().extension().wstring();
dl_handle_ptr handle { dl_load_library(entry.path().wstring()) }; if (filename.find(file_prefix) == 0 && ext == backend_filename_suffix()) {
if (!handle && !silent) { dl_handle_ptr handle { dl_load_library(entry.path().wstring()) };
GGML_LOG_ERROR("%s: failed to load %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str()); if (!handle) {
} GGML_LOG_ERROR("%s: failed to load %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str());
if (handle) { continue;
}
auto score_fn = (ggml_backend_score_t) dl_get_sym(handle.get(), "ggml_backend_score"); auto score_fn = (ggml_backend_score_t) dl_get_sym(handle.get(), "ggml_backend_score");
if (score_fn) { if (!score_fn) {
int s = score_fn(); GGML_LOG_DEBUG("%s: failed to find ggml_backend_score in %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str());
#ifndef NDEBUG continue;
GGML_LOG_DEBUG("%s: %s score: %d\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str(), s); }
#endif
if (s > best_score) { int s = score_fn();
best_score = s; GGML_LOG_DEBUG("%s: %s score: %d\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str(), s);
best_path = entry.path().wstring(); if (s > best_score) {
} best_score = s;
} else { best_path = entry.path().wstring();
if (!silent) {
GGML_LOG_INFO("%s: failed to find ggml_backend_score in %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str());
}
} }
} }
} }
} catch (const std::exception & e) {
GGML_LOG_ERROR("%s: failed to load %s: %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str(), e.what());
} }
} }
} }

View File

@@ -79,6 +79,11 @@ var OnceLoad = sync.OnceFunc(func() {
continue continue
} }
if abspath != filepath.Dir(exe) && !strings.Contains(abspath, filepath.FromSlash("lib/ollama")) {
slog.Debug("skipping path which is not part of ollama", "path", abspath)
continue
}
if _, ok := visited[abspath]; !ok { if _, ok := visited[abspath]; !ok {
func() { func() {
slog.Debug("ggml backend load all from path", "path", abspath) slog.Debug("ggml backend load all from path", "path", abspath)

View File

@@ -610,14 +610,14 @@ type EmbedWriter struct {
} }
func (w *BaseWriter) writeError(data []byte) (int, error) { func (w *BaseWriter) writeError(data []byte) (int, error) {
var er api.ErrorResponse // error response is used here to parse the error message var serr api.StatusError
err := json.Unmarshal(data, &er) err := json.Unmarshal(data, &serr)
if err != nil { if err != nil {
return 0, err return 0, err
} }
w.ResponseWriter.Header().Set("Content-Type", "application/json") w.ResponseWriter.Header().Set("Content-Type", "application/json")
err = json.NewEncoder(w.ResponseWriter).Encode(NewError(http.StatusInternalServerError, er.Err)) err = json.NewEncoder(w.ResponseWriter).Encode(NewError(http.StatusInternalServerError, serr.Error()))
if err != nil { if err != nil {
return 0, err return 0, err
} }

View File

@@ -550,7 +550,7 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
manifest, err = pullModelManifest(ctx, mp, regOpts) manifest, err = pullModelManifest(ctx, mp, regOpts)
if err != nil { if err != nil {
return fmt.Errorf("pull model manifest: %w", err) return fmt.Errorf("pull model manifest: %s", err)
} }
var layers []Layer var layers []Layer
@@ -629,18 +629,13 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
return nil return nil
} }
var ErrRemoteModelNotFound = errors.New("model not found")
func pullModelManifest(ctx context.Context, mp ModelPath, regOpts *registryOptions) (*Manifest, error) { func pullModelManifest(ctx context.Context, mp ModelPath, regOpts *registryOptions) (*Manifest, error) {
requestURL := mp.BaseURL().JoinPath("v2", mp.GetNamespaceRepository(), "manifests", mp.Tag) requestURL := mp.BaseURL().JoinPath("v2", mp.GetNamespaceRepository(), "manifests", mp.Tag)
headers := make(http.Header) headers := make(http.Header)
headers.Set("Accept", "application/vnd.docker.distribution.manifest.v2+json") headers.Set("Accept", "application/vnd.docker.distribution.manifest.v2+json")
resp, err := makeRequestWithRetry(ctx, http.MethodGet, requestURL, headers, nil, regOpts) resp, err := makeRequestWithRetry(ctx, http.MethodGet, requestURL, headers, nil, regOpts)
if errors.Is(err, os.ErrNotExist) { if err != nil {
// The model was not found on the remote registry
return nil, fmt.Errorf("%w: %s", ErrRemoteModelNotFound, err)
} else if err != nil {
return nil, err return nil, err
} }
defer resp.Body.Close() defer resp.Body.Close()

View File

@@ -293,11 +293,13 @@ func (s *Server) GenerateHandler(c *gin.Context) {
var sb strings.Builder var sb strings.Builder
defer close(ch) defer close(ch)
if err := r.Completion(c.Request.Context(), llm.CompletionRequest{ if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
Prompt: prompt, Prompt: prompt,
Images: images, Images: images,
Format: req.Format, Format: req.Format,
Options: opts, LogProbs: req.LogProbs,
Options: opts,
}, func(cr llm.CompletionResponse) { }, func(cr llm.CompletionResponse) {
fmt.Printf("banana: %#v\n", cr)
res := api.GenerateResponse{ res := api.GenerateResponse{
Model: req.Model, Model: req.Model,
CreatedAt: time.Now().UTC(), CreatedAt: time.Now().UTC(),
@@ -311,6 +313,13 @@ func (s *Server) GenerateHandler(c *gin.Context) {
EvalDuration: cr.EvalDuration, EvalDuration: cr.EvalDuration,
}, },
} }
for _, p := range cr.LogProbs {
res.LogProbs = append(res.LogProbs, api.TokenProbs{
TokenID: p.TokenID,
LogProb: p.LogProb,
Token: p.Token,
})
}
if _, err := sb.WriteString(cr.Content); err != nil { if _, err := sb.WriteString(cr.Content); err != nil {
ch <- gin.H{"error": err.Error()} ch <- gin.H{"error": err.Error()}
@@ -564,8 +573,7 @@ func (s *Server) PullHandler(c *gin.Context) {
return return
} }
reqName := cmp.Or(req.Model, req.Name) name := model.ParseName(cmp.Or(req.Model, req.Name))
name := model.ParseName(reqName)
if !name.IsValid() { if !name.IsValid() {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": errtypes.InvalidModelNameErrMsg}) c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": errtypes.InvalidModelNameErrMsg})
return return
@@ -592,18 +600,7 @@ func (s *Server) PullHandler(c *gin.Context) {
defer cancel() defer cancel()
if err := PullModel(ctx, name.DisplayShortest(), regOpts, fn); err != nil { if err := PullModel(ctx, name.DisplayShortest(), regOpts, fn); err != nil {
if errors.Is(err, ErrRemoteModelNotFound) { ch <- gin.H{"error": err.Error()}
hint := fmt.Sprintf("Model %q not found - please check the model name is correct and try again", reqName)
if name.Host == DefaultRegistry {
hint = fmt.Sprintf("Model %q not found - search available models at: https://ollama.com/search?q=%s", reqName, reqName)
}
ch <- api.ErrorResponse{
Err: err.Error(),
Hint: hint,
}
} else {
ch <- gin.H{"error": err.Error()}
}
} }
}() }()
@@ -1478,10 +1475,11 @@ func (s *Server) ChatHandler(c *gin.Context) {
var sb strings.Builder var sb strings.Builder
var toolCallIndex int = 0 var toolCallIndex int = 0
if err := r.Completion(c.Request.Context(), llm.CompletionRequest{ if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
Prompt: prompt, Prompt: prompt,
Images: images, Images: images,
Format: req.Format, Format: req.Format,
Options: opts, LogProbs: req.LogProbs,
Options: opts,
}, func(r llm.CompletionResponse) { }, func(r llm.CompletionResponse) {
res := api.ChatResponse{ res := api.ChatResponse{
Model: req.Model, Model: req.Model,
@@ -1496,6 +1494,13 @@ func (s *Server) ChatHandler(c *gin.Context) {
EvalDuration: r.EvalDuration, EvalDuration: r.EvalDuration,
}, },
} }
for _, p := range r.LogProbs {
res.LogProbs = append(res.LogProbs, api.TokenProbs{
TokenID: p.TokenID,
LogProb: p.LogProb,
Token: p.Token,
})
}
if r.Done { if r.Done {
res.TotalDuration = time.Since(checkpointStart) res.TotalDuration = time.Since(checkpointStart)