Compare commits
9 Commits
brucemacd/ ... brucemacd/
| SHA1 |
|---|
| b88489a87e |
| fdbb0b5cfe |
| 64f95067ba |
| 6dfcdec2da |
| 7d16ec8fe8 |
| 82658c3eec |
| 378d6e1e6a |
| afa55bc70c |
| 49df03da9a |
@@ -437,9 +437,10 @@ See the [API documentation](./docs/api.md) for all endpoints.
 - [Pacman](https://archlinux.org/packages/extra/x86_64/ollama/)
 - [Gentoo](https://github.com/gentoo/guru/tree/master/app-misc/ollama)
+- [Homebrew](https://formulae.brew.sh/formula/ollama)
 - [Helm Chart](https://artifacthub.io/packages/helm/ollama-helm/ollama)
 - [Guix channel](https://codeberg.org/tusharhero/ollama-guix)
-- [Nix package](https://search.nixos.org/packages?channel=24.05&show=ollama&from=0&size=50&sort=relevance&type=packages&query=ollama)
+- [Nix package](https://search.nixos.org/packages?show=ollama&from=0&size=50&sort=relevance&type=packages&query=ollama)
 - [Flox](https://flox.dev/blog/ollama-part-one)

 ### Libraries
@@ -494,7 +495,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
 - [multi-llm-ts](https://github.com/nbonamy/multi-llm-ts) (A Typescript/JavaScript library allowing access to different LLM in unified API)
 - [LlmTornado](https://github.com/lofcz/llmtornado) (C# library providing a unified interface for major FOSS & Commercial inference APIs)
 - [Ollama for Zig](https://github.com/dravenk/ollama-zig)
-- [Abso](https://github.com/lunary-ai/abso/blob/main/README.md#ollama) (OpenAI-compatible TypeScript SDK for any LLM provider)
+- [Abso](https://github.com/lunary-ai/abso) (OpenAI-compatible TypeScript SDK for any LLM provider)

 ### Mobile

@@ -18,6 +18,7 @@ import (
 	"bytes"
 	"context"
 	"encoding/json"
+	"errors"
 	"fmt"
 	"io"
 	"net/http"
@@ -29,28 +30,6 @@ import (
 	"github.com/ollama/ollama/version"
 )

-// StatusError is an error with an HTTP status code and message,
-// it is parsed on the client-side and not returned from the API
-type StatusError struct {
-	StatusCode int    // e.g. 200
-	Status     string // e.g. "200 OK"
-	ErrorResponse
-}
-
-func (e StatusError) Error() string {
-	switch {
-	case e.Status != "" && e.Err != "":
-		return fmt.Sprintf("%s: %s", e.Status, e.Err)
-	case e.Status != "":
-		return e.Status
-	case e.Err != "":
-		return e.Err
-	default:
-		// this should not happen
-		return "something went wrong, please see the ollama server logs for details"
-	}
-}
-
 // Client encapsulates client state for interacting with the ollama
 // service. Use [ClientFromEnvironment] to create new Clients.
 type Client struct {
@@ -68,7 +47,7 @@ func checkError(resp *http.Response, body []byte) error {
 	err := json.Unmarshal(body, &apiError)
 	if err != nil {
 		// Use the full body as the message if we fail to decode a response.
-		apiError.Err = string(body)
+		apiError.ErrorMessage = string(body)
 	}

 	return apiError
@@ -153,7 +132,7 @@ func (c *Client) do(ctx context.Context, method, path string, reqData, respData
 const maxBufferSize = 512 * format.KiloByte

 func (c *Client) stream(ctx context.Context, method, path string, data any, fn func([]byte) error) error {
-	var buf io.Reader
+	var buf *bytes.Buffer
 	if data != nil {
 		bts, err := json.Marshal(data)
 		if err != nil {
@@ -184,22 +163,24 @@ func (c *Client) stream(ctx context.Context, method, path string, data any, fn f
 	scanBuf := make([]byte, 0, maxBufferSize)
 	scanner.Buffer(scanBuf, maxBufferSize)
 	for scanner.Scan() {
-		bts := scanner.Bytes()
+		var errorResponse struct {
+			Error string `json:"error,omitempty"`
+		}

-		var errorResponse ErrorResponse
+		bts := scanner.Bytes()
 		if err := json.Unmarshal(bts, &errorResponse); err != nil {
 			return fmt.Errorf("unmarshal: %w", err)
 		}

-		if errorResponse.Err != "" {
-			return errorResponse
+		if errorResponse.Error != "" {
+			return errors.New(errorResponse.Error)
 		}

 		if response.StatusCode >= http.StatusBadRequest {
 			return StatusError{
 				StatusCode: response.StatusCode,
 				Status:     response.Status,
-				ErrorResponse: errorResponse,
+				ErrorMessage: errorResponse.Error,
 			}
 		}

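Note on the client change above: a mid-stream `{"error": ...}` object now surfaces as a plain error, while an HTTP status >= 400 still produces an `api.StatusError` (now carrying `ErrorMessage`). A minimal sketch of how a caller might distinguish the two once this lands; the model name and prompt are placeholders, not part of this diff:

```go
package main

import (
	"context"
	"errors"
	"fmt"
	"log"

	"github.com/ollama/ollama/api"
)

func main() {
	client, err := api.ClientFromEnvironment()
	if err != nil {
		log.Fatal(err)
	}

	// placeholder request values for illustration only
	req := &api.ChatRequest{
		Model:    "llama3.2",
		Messages: []api.Message{{Role: "user", Content: "hello"}},
	}

	err = client.Chat(context.Background(), req, func(r api.ChatResponse) error {
		fmt.Print(r.Message.Content)
		return nil
	})

	var statusErr api.StatusError
	switch {
	case errors.As(err, &statusErr):
		// HTTP-level failure: status plus the server's "error" message
		log.Fatalf("request failed: %d %s", statusErr.StatusCode, statusErr.ErrorMessage)
	case err != nil:
		// mid-stream error object forwarded as a plain error
		log.Fatal(err)
	}
}
```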
@@ -1,12 +1,6 @@
 package api

 import (
-	"context"
-	"encoding/json"
-	"fmt"
-	"net/http"
-	"net/http/httptest"
-	"net/url"
 	"testing"
 )

@@ -49,270 +43,3 @@ func TestClientFromEnvironment(t *testing.T) {
 		})
 	}
 }
-
-// testError represents an internal error type for testing different error formats
-type testError struct {
-	message    string         // basic error message
-	structured *ErrorResponse // structured error response, nil for basic format
-	statusCode int
-}
-
-func (e testError) Error() string {
-	return e.message
-}
-
-func TestClientStream(t *testing.T) {
-	testCases := []struct {
-		name      string
-		responses []any
-		wantErr   string
-	}{
-		{
-			name: "basic error format",
-			responses: []any{
-				testError{
-					message:    "test error message",
-					statusCode: http.StatusBadRequest,
-				},
-			},
-			wantErr: "test error message",
-		},
-		{
-			name: "structured error format",
-			responses: []any{
-				testError{
-					message: "test structured error",
-					structured: &ErrorResponse{
-						Err:  "test structured error",
-						Hint: "test hint",
-					},
-					statusCode: http.StatusBadRequest,
-				},
-			},
-			wantErr: "test structured error\ntest hint",
-		},
-		{
-			name: "error after chunks - basic format",
-			responses: []any{
-				ChatResponse{Message: Message{Content: "partial 1"}},
-				ChatResponse{Message: Message{Content: "partial 2"}},
-				testError{
-					message:    "mid-stream basic error",
-					statusCode: http.StatusOK,
-				},
-			},
-			wantErr: "mid-stream basic error",
-		},
-		{
-			name: "error after chunks - structured format",
-			responses: []any{
-				ChatResponse{Message: Message{Content: "partial 1"}},
-				ChatResponse{Message: Message{Content: "partial 2"}},
-				testError{
-					message: "mid-stream structured error",
-					structured: &ErrorResponse{
-						Err:  "mid-stream structured error",
-						Hint: "additional context",
-					},
-					statusCode: http.StatusOK,
-				},
-			},
-			wantErr: "mid-stream structured error\nadditional context",
-		},
-		{
-			name: "successful stream completion",
-			responses: []any{
-				ChatResponse{Message: Message{Content: "chunk 1"}},
-				ChatResponse{Message: Message{Content: "chunk 2"}},
-				ChatResponse{
-					Message:    Message{Content: "final chunk"},
-					Done:       true,
-					DoneReason: "stop",
-				},
-			},
-		},
-	}
-
-	for _, tc := range testCases {
-		t.Run(tc.name, func(t *testing.T) {
-			ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
-				flusher, ok := w.(http.Flusher)
-				if !ok {
-					t.Fatal("expected http.Flusher")
-				}
-
-				w.Header().Set("Content-Type", "application/x-ndjson")
-
-				for _, resp := range tc.responses {
-					if errResp, ok := resp.(testError); ok {
-						w.WriteHeader(errResp.statusCode)
-						var err error
-						if errResp.structured != nil {
-							err = json.NewEncoder(w).Encode(errResp.structured)
-						} else {
-							err = json.NewEncoder(w).Encode(map[string]string{
-								"error": errResp.message,
-							})
-						}
-						if err != nil {
-							t.Fatal("failed to encode error response:", err)
-						}
-						return
-					}
-
-					if err := json.NewEncoder(w).Encode(resp); err != nil {
-						t.Fatalf("failed to encode response: %v", err)
-					}
-					flusher.Flush()
-				}
-			}))
-			defer ts.Close()
-
-			client := NewClient(&url.URL{Scheme: "http", Host: ts.Listener.Addr().String()}, http.DefaultClient)
-
-			var receivedChunks []ChatResponse
-			err := client.stream(context.Background(), http.MethodPost, "/v1/chat", nil, func(chunk []byte) error {
-				var resp ChatResponse
-				if err := json.Unmarshal(chunk, &resp); err != nil {
-					return fmt.Errorf("failed to unmarshal chunk: %w", err)
-				}
-				receivedChunks = append(receivedChunks, resp)
-				return nil
-			})
-
-			if tc.wantErr != "" {
-				if err == nil {
-					t.Fatalf("got nil, want error %q", tc.wantErr)
-				}
-				if err.Error() != tc.wantErr {
-					t.Errorf("error message mismatch: got %q, want %q", err.Error(), tc.wantErr)
-				}
-			} else {
-				if err != nil {
-					t.Errorf("got error %q, want nil", err)
-				}
-			}
-		})
-	}
-}
-
-func TestClientDo(t *testing.T) {
-	testCases := []struct {
-		name     string
-		response any
-		wantErr  string
-	}{
-		{
-			name: "basic error format",
-			response: testError{
-				message:    "test error message",
-				statusCode: http.StatusBadRequest,
-			},
-			wantErr: "test error message",
-		},
-		{
-			name: "structured error format",
-			response: testError{
-				message: "test structured error",
-				structured: &ErrorResponse{
-					Err:  "test structured error",
-					Hint: "test hint",
-				},
-				statusCode: http.StatusBadRequest,
-			},
-			wantErr: "test structured error",
-		},
-		{
-			name: "server error - basic format",
-			response: testError{
-				message:    "internal error",
-				statusCode: http.StatusInternalServerError,
-			},
-			wantErr: "internal error",
-		},
-		{
-			name: "server error - structured format",
-			response: testError{
-				message: "internal server error",
-				structured: &ErrorResponse{
-					Err:  "internal server error",
-					Hint: "please try again later",
-				},
-				statusCode: http.StatusInternalServerError,
-			},
-			wantErr: "internal server error",
-		},
-		{
-			name: "successful response",
-			response: struct {
-				ID      string `json:"id"`
-				Success bool   `json:"success"`
-			}{
-				ID:      "msg_123",
-				Success: true,
-			},
-		},
-	}
-
-	for _, tc := range testCases {
-		t.Run(tc.name, func(t *testing.T) {
-			ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
-				if errResp, ok := tc.response.(testError); ok {
-					w.WriteHeader(errResp.statusCode)
-					var err error
-					if errResp.structured != nil {
-						err = json.NewEncoder(w).Encode(errResp.structured)
-					} else {
-						err = json.NewEncoder(w).Encode(map[string]string{
-							"error": errResp.message,
-						})
-					}
-					if err != nil {
-						t.Fatal("failed to encode error response:", err)
-					}
-					return
-				}
-
-				w.Header().Set("Content-Type", "application/json")
-				if err := json.NewEncoder(w).Encode(tc.response); err != nil {
-					t.Fatalf("failed to encode response: %v", err)
-				}
-			}))
-			defer ts.Close()
-
-			client := NewClient(&url.URL{Scheme: "http", Host: ts.Listener.Addr().String()}, http.DefaultClient)
-
-			var resp struct {
-				ID      string `json:"id"`
-				Success bool   `json:"success"`
-			}
-			err := client.do(context.Background(), http.MethodPost, "/v1/messages", nil, &resp)
-
-			if tc.wantErr != "" {
-				if err == nil {
-					t.Fatalf("got nil, want error %q", tc.wantErr)
-				}
-				if err.Error() != tc.wantErr {
-					t.Errorf("error message mismatch: got %q, want %q", err.Error(), tc.wantErr)
-				}
-				return
-			}
-
-			if err != nil {
-				t.Errorf("got error %q, want nil", err)
-			}
-
-			if expectedResp, ok := tc.response.(struct {
-				ID      string `json:"id"`
-				Success bool   `json:"success"`
-			}); ok {
-				if resp.ID != expectedResp.ID {
-					t.Errorf("response ID mismatch: got %q, want %q", resp.ID, expectedResp.ID)
-				}
-				if resp.Success != expectedResp.Success {
-					t.Errorf("response Success mismatch: got %v, want %v", resp.Success, expectedResp.Success)
-				}
-			}
-		})
-	}
-}
api/types.go (58)
@@ -12,6 +12,27 @@ import (
 	"time"
 )

+// StatusError is an error with an HTTP status code and message.
+type StatusError struct {
+	StatusCode   int
+	Status       string
+	ErrorMessage string `json:"error"`
+}
+
+func (e StatusError) Error() string {
+	switch {
+	case e.Status != "" && e.ErrorMessage != "":
+		return fmt.Sprintf("%s: %s", e.Status, e.ErrorMessage)
+	case e.Status != "":
+		return e.Status
+	case e.ErrorMessage != "":
+		return e.ErrorMessage
+	default:
+		// this should not happen
+		return "something went wrong, please see the ollama server logs for details"
+	}
+}
+
 // ImageData represents the raw binary data of an image file.
 type ImageData []byte

@@ -56,6 +77,8 @@ type GenerateRequest struct {
 	// request, for multimodal models.
 	Images []ImageData `json:"images,omitempty"`

+	LogProbs int `json:"logprobs,omitempty"`
+
 	// Options lists model-specific options. For example, temperature can be
 	// set through this field, if the model supports it.
 	Options map[string]interface{} `json:"options"`
@@ -82,6 +105,8 @@ type ChatRequest struct {
 	// Tools is an optional list of tools the model has access to.
 	Tools `json:"tools,omitempty"`

+	LogProbs int `json:"logprobs,omitempty"`
+
 	// Options lists model-specific options.
 	Options map[string]interface{} `json:"options"`
 }
@@ -161,13 +186,20 @@ func (t *ToolFunction) String() string {
 	return string(bts)
 }

+type TokenProbs struct {
+	TokenID int     `json:"id"`
+	LogProb float32 `json:"logprob"`
+	Token   string  `json:"token"`
+}
+
 // ChatResponse is the response returned by [Client.Chat]. Its fields are
 // similar to [GenerateResponse].
 type ChatResponse struct {
 	Model      string    `json:"model"`
 	CreatedAt  time.Time `json:"created_at"`
 	Message    Message   `json:"message"`
 	DoneReason string    `json:"done_reason,omitempty"`
+	LogProbs   []TokenProbs `json:"logprobs,omitempty"`

 	Done bool `json:"done"`

@@ -431,6 +463,8 @@ type GenerateResponse struct {
 	// can be sent in the next request to keep a conversational memory.
 	Context []int `json:"context,omitempty"`

+	LogProbs []TokenProbs `json:"logprobs,omitempty"`
+
 	Metrics
 }

@@ -640,22 +674,6 @@ func (d *Duration) UnmarshalJSON(b []byte) (err error) {
 	return nil
 }

-// ErrorResponse implements a structured error interface that is returned from the Ollama server
-type ErrorResponse struct {
-	// Err is the error from the server. It helps with debugging the code-path
-	Err string `json:"error"`
-
-	// Hint is a user-friendly message about what went wrong, with suggested troubleshooting
-	Hint string `json:"hint"`
-}
-
-func (e ErrorResponse) Error() string {
-	if e.Hint == "" {
-		return e.Err
-	}
-	return fmt.Sprintf("%s\n%s", e.Err, e.Hint)
-}
-
 // FormatParams converts specified parameter options to their correct types
 func FormatParams(params map[string][]string) (map[string]interface{}, error) {
 	opts := Options{}
@@ -50,7 +50,7 @@ import (
 	_ "github.com/ollama/ollama/llama/llama.cpp/common"
 	_ "github.com/ollama/ollama/llama/llama.cpp/examples/llava"
 	_ "github.com/ollama/ollama/llama/llama.cpp/src"
-	"github.com/ollama/ollama/ml/backend/ggml/ggml/src"
+	ggml "github.com/ollama/ollama/ml/backend/ggml/ggml/src"
 )

 func BackendInit() {
@@ -220,6 +220,19 @@ func (c *Context) GetEmbeddingsIth(i int) []float32 {
 	return embeddings
 }

+// GetLogits returns the logits from the last decode operation.
+// The returned slice has length equal to the vocabulary size.
+func (c *Context) GetLogits() []float32 {
+	logits := unsafe.Pointer(C.llama_get_logits(c.c))
+	if logits == nil {
+		return nil
+	}
+
+	// Get the number of vocabulary tokens to determine array size
+	vocabSize := c.Model().NumVocab()
+	return unsafe.Slice((*float32)(logits), vocabSize)
+}
+
 type ModelParams struct {
 	NumGpuLayers int
 	MainGpu      int
llama/patches/0017-try-catch-backend-load.patch (69, new file)
@@ -0,0 +1,69 @@
+From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
+From: Michael Yang <mxyng@pm.me>
+Date: Tue, 11 Feb 2025 14:06:36 -0800
+Subject: [PATCH] try/catch backend load
+
+---
+ ggml/src/ggml-backend-reg.cpp | 45 ++++++++++++++++++-----------------
+ 1 file changed, 23 insertions(+), 22 deletions(-)
+
+diff --git a/ggml/src/ggml-backend-reg.cpp b/ggml/src/ggml-backend-reg.cpp
+index ac5cda07..374c3b21 100644
+--- a/ggml/src/ggml-backend-reg.cpp
++++ b/ggml/src/ggml-backend-reg.cpp
+@@ -512,32 +512,33 @@ static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent,
+         }
+         fs::directory_iterator dir_it(search_path, fs::directory_options::skip_permission_denied);
+         for (const auto & entry : dir_it) {
+-            if (entry.is_regular_file()) {
+-                std::wstring filename = entry.path().filename().wstring();
+-                std::wstring ext = entry.path().extension().wstring();
+-                if (filename.find(file_prefix) == 0 && ext == backend_filename_suffix()) {
+-                    dl_handle_ptr handle { dl_load_library(entry.path().wstring()) };
+-                    if (!handle && !silent) {
+-                        GGML_LOG_ERROR("%s: failed to load %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str());
+-                    }
+-                    if (handle) {
++            try {
++                if (entry.is_regular_file()) {
++                    std::wstring filename = entry.path().filename().wstring();
++                    std::wstring ext = entry.path().extension().wstring();
++                    if (filename.find(file_prefix) == 0 && ext == backend_filename_suffix()) {
++                        dl_handle_ptr handle { dl_load_library(entry.path().wstring()) };
++                        if (!handle) {
++                            GGML_LOG_ERROR("%s: failed to load %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str());
++                            continue;
++                        }
++
+                         auto score_fn = (ggml_backend_score_t) dl_get_sym(handle.get(), "ggml_backend_score");
+-                        if (score_fn) {
+-                            int s = score_fn();
+-#ifndef NDEBUG
+-                            GGML_LOG_DEBUG("%s: %s score: %d\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str(), s);
+-#endif
+-                            if (s > best_score) {
+-                                best_score = s;
+-                                best_path = entry.path().wstring();
+-                            }
+-                        } else {
+-                            if (!silent) {
+-                                GGML_LOG_INFO("%s: failed to find ggml_backend_score in %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str());
+-                            }
++                        if (!score_fn) {
++                            GGML_LOG_DEBUG("%s: failed to find ggml_backend_score in %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str());
++                            continue;
++                        }
++
++                        int s = score_fn();
++                        GGML_LOG_DEBUG("%s: %s score: %d\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str(), s);
++                        if (s > best_score) {
++                            best_score = s;
++                            best_path = entry.path().wstring();
+                         }
+                     }
+                 }
++            } catch (const std::exception & e) {
++                GGML_LOG_ERROR("%s: failed to load %s: %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str(), e.what());
+             }
+         }
+     }
@@ -8,12 +8,14 @@ import (
 	"fmt"
 	"log"
 	"log/slog"
+	"math"
 	"net"
 	"net/http"
 	"os"
 	"path/filepath"
 	"regexp"
 	"runtime"
+	"sort"
 	"strconv"
 	"strings"
 	"sync"
@@ -48,8 +50,9 @@ type Sequence struct {
 	// inputs that have been added to a batch but not yet submitted to Decode
 	pendingInputs []input

+	// TODO: update this comment
 	// tokens that have been generated but not returned yet (e.g. for stop sequences)
-	pendingResponses []string
+	pendingResponses []CompletionResponse

 	// input cache being used by this sequence
 	cache *InputCacheSlot
@@ -59,7 +62,7 @@ type Sequence struct {
 	crossAttention bool

 	// channel to send responses over
-	responses chan string
+	responses chan CompletionResponse

 	// channel to stop decoding (such as if the remote connection is closed)
 	quit chan bool
@@ -83,6 +86,11 @@ type Sequence struct {

 	doneReason string

+	logits []float32
+
+	// number of logprobs to return with the completion response
+	logprobs int
+
 	// Metrics
 	startProcessingTime time.Time
 	startGenerationTime time.Time
@@ -96,6 +104,7 @@ type NewSequenceParams struct {
 	numKeep        int
 	samplingParams *llama.SamplingParams
 	embedding      bool
+	logprobs       int
 }

 func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequenceParams) (*Sequence, error) {
@@ -148,14 +157,15 @@ func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequen
 		numPromptInputs:     len(inputs),
 		startProcessingTime: startTime,
 		numPredict:          params.numPredict,
-		pendingResponses:    make([]string, 0),
-		responses:           make(chan string, 100),
+		pendingResponses:    make([]CompletionResponse, 0),
+		responses:           make(chan CompletionResponse, 100),
 		quit:                make(chan bool, 1),
 		embedding:           make(chan []float32, 1),
 		samplingCtx:         sc,
 		embeddingOnly:       params.embedding,
 		stop:                params.stop,
 		numKeep:             params.numKeep,
+		logprobs:            params.logprobs,
 	}, nil
 }

@@ -274,29 +284,37 @@ func (s *Server) allNil() bool {
 }

 func flushPending(seq *Sequence) bool {
-	joined := strings.Join(seq.pendingResponses, "")
-	seq.pendingResponses = []string{}
-
-	// Check if there are any partial UTF-8 characters remaining.
-	// We already check and queue as we are generating but some may
-	// still make it here:
-	// - Sequence is ending, e.g. generation limit has been hit
-	// - Invalid characters in the middle of a string
-	// This is a stricter check to ensure we never output invalid Unicode.
-	for !utf8.ValidString(joined) {
-		joined = joined[:len(joined)-1]
-	}
-
-	if len(joined) == 0 {
+	if len(seq.pendingResponses) == 0 {
 		return true
 	}
-	select {
-	case seq.responses <- joined:
-		return true
-	case <-seq.quit:
-		return false
+	resps := []CompletionResponse{}
+	for _, resp := range seq.pendingResponses {
+		resps = append(resps, resp)
 	}
+	seq.pendingResponses = []CompletionResponse{}
+
+	// TODO: figure out this result logic
+	result := false
+	for _, resp := range resps {
+		// Check if there are any partial UTF-8 characters remaining.
+		// We already check and queue as we are generating but some may
+		// still make it here:
+		// - Sequence is ending, e.g. generation limit has been hit
+		// - Invalid characters in the middle of a string
+		// This is a stricter check to ensure we never output invalid Unicode.
+		for !utf8.ValidString(resp.Content) {
+			resp.Content = resp.Content[:len(resp.Content)-1]
+		}
+
+		select {
+		case seq.responses <- resp:
+			result = true
+		case <-seq.quit:
+			result = false
+		}
+	}
+
+	return result
 }

 func (s *Server) removeSequence(seqIndex int, reason string) {
@@ -350,6 +368,63 @@ func (s *Server) run(ctx context.Context) {
 	}
 }

+// TokenProbs represents probability information for a token
+type TokenProbs struct {
+	TokenID int     `json:"id"`
+	Logit   float32 `json:"logit"`
+	Prob    float32 `json:"prob"`
+	LogProb float32 `json:"logprob"`
+	Token   string  `json:"token"`
+}
+
+// probs returns sorted token probabilities for a specific token index
+func probs(logits []float32, vocabSize int) []TokenProbs {
+	probs := make([]TokenProbs, vocabSize)
+
+	// Initialize token data with logits
+	for i := 0; i < vocabSize; i++ {
+		probs[i] = TokenProbs{
+			TokenID: i,
+			Logit:   logits[i],
+		}
+	}
+
+	// Sort tokens by logits in descending order
+	sort.Slice(probs, func(i, j int) bool {
+		return probs[i].Logit > probs[j].Logit
+	})
+
+	// Apply softmax
+	maxLogit := probs[0].Logit
+	var sum float32 = 0.0
+
+	for i := range probs {
+		p := float32(math.Exp(float64(probs[i].Logit - maxLogit)))
+		probs[i].Prob = p
+		sum += p
+	}
+
+	// Normalize probabilities and calculate log probs
+	for i := range probs {
+		prob := probs[i].Prob / sum
+		probs[i].Prob = prob
+		probs[i].LogProb = float32(math.Log(float64(prob)))
+	}
+
+	return probs
+}
+
+// probs returns sorted token probabilities for a specific token index
+func (s *Server) probs(seq *Sequence) []TokenProbs {
+	// Get logits for the specific token index
+	logits := s.lc.GetLogits()
+	seq.logits = make([]float32, len(logits))
+	copy(seq.logits, logits)
+
+	vocabSize := s.model.NumVocab()
+	return probs(logits, vocabSize)
+}
+
 // TODO (jmorganca): processBatch should be simplified, removing:
 // * sampling
 // * stop token checking
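A hand check of the softmax/log-prob arithmetic in the added probs() helper, using the logits {1.0, 2.0, 0.5, -1.0} from the new runner test: the maximum logit is 2.0, the shifted exponentials are exp(0) = 1.000, exp(-1) ≈ 0.368, exp(-1.5) ≈ 0.223, exp(-3) ≈ 0.050, and their sum ≈ 1.641. Normalizing gives probabilities ≈ 0.610, 0.224, 0.136, 0.030 in sorted (descending-logit) order, with log-probs ≈ -0.495, -1.495, -1.995, -3.495; equivalently, logprob_i = logit_i - max - ln(sum).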
@@ -483,6 +558,19 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)

 		seq.numPredicted++

+		resp := CompletionResponse{Content: piece}
+
+		if seq.logprobs > 0 {
+			// TODO: return selected token in logprobs always
+			resp.LogProbs = s.probs(seq)
+			// TODO: fix this logprobs limit
+			resp.LogProbs = resp.LogProbs[:min(len(resp.LogProbs), seq.logprobs)]
+			for i := range resp.LogProbs {
+				// decode the token id to a piece
+				resp.LogProbs[i].Token = s.model.TokenToPiece(resp.LogProbs[i].TokenID)
+			}
+		}
+
 		// if it's an end of sequence token, break
 		if s.model.TokenIsEog(token) {
 			// TODO (jmorganca): we should send this back
@@ -495,16 +583,21 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)

 		seq.inputs = []input{{token: token}}

-		seq.pendingResponses = append(seq.pendingResponses, piece)
-		sequence := strings.Join(seq.pendingResponses, "")
+		// TODO: add probs here
+		seq.pendingResponses = append(seq.pendingResponses, resp)
+		var sequence string
+		for _, r := range seq.pendingResponses {
+			sequence += r.Content
+		}

 		if ok, stop := findStop(sequence, seq.stop); ok {
 			slog.Debug("hit stop token", "pending", seq.pendingResponses, "stop", stop)

+			// TODO: fix this stop sequence caching
 			var tokenTruncated bool
-			origLen := len(seq.pendingResponses)
-			seq.pendingResponses, tokenTruncated = truncateStop(seq.pendingResponses, stop)
-			newLen := len(seq.pendingResponses)
+			origLen := len(sequence)
+			sequence, tokenTruncated = truncateStop(sequence, stop)
+			newLen := len(sequence)

 			// Update the cache based on the tokens that will be returned:
 			// - We have 1 token more than is currently in the cache because
@@ -575,6 +668,7 @@ type CompletionRequest struct {
 	Images      []ImageData `json:"image_data"`
 	Grammar     string      `json:"grammar"`
 	CachePrompt bool        `json:"cache_prompt"`
+	Logprobs    int         `json:"logprobs,omitempty"`

 	Options
 }
@@ -590,8 +684,10 @@ type CompletionResponse struct {
 	Content string `json:"content"`
 	Stop    bool   `json:"stop"`

 	Model  string `json:"model,omitempty"`
 	Prompt string `json:"prompt,omitempty"`
+	LogProbs []TokenProbs `json:"logprobs,omitempty"`
+
 	StoppedLimit bool    `json:"stopped_limit,omitempty"`
 	PredictedN   int     `json:"predicted_n,omitempty"`
 	PredictedMS  float64 `json:"predicted_ms,omitempty"`
@@ -609,10 +705,6 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 		return
 	}

-	// Set the headers to indicate streaming
-	w.Header().Set("Content-Type", "application/json")
-	w.Header().Set("Transfer-Encoding", "chunked")
-
 	flusher, ok := w.(http.Flusher)
 	if !ok {
 		http.Error(w, "Streaming not supported", http.StatusInternalServerError)
@@ -641,6 +733,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 		numKeep:        req.NumKeep,
 		samplingParams: &samplingParams,
 		embedding:      false,
+		logprobs:       req.Logprobs,
 	})
 	if err != nil {
 		http.Error(w, fmt.Sprintf("Failed to create new sequence: %v", err), http.StatusInternalServerError)
@@ -688,11 +781,10 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 		case <-r.Context().Done():
 			close(seq.quit)
 			return
-		case content, ok := <-seq.responses:
+		case resp, ok := <-seq.responses:
 			if ok {
-				if err := json.NewEncoder(w).Encode(&CompletionResponse{
-					Content: content,
-				}); err != nil {
+				fmt.Println("response", resp)
+				if err := json.NewEncoder(w).Encode(&resp); err != nil {
 					http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
 					close(seq.quit)
 					return
llama/runner/runner_test.go (58, new file)
@@ -0,0 +1,58 @@
+package runner
+
+import (
+	"math"
+	"testing"
+)
+
+func TestProbs(t *testing.T) {
+	// Input test data
+	logits := []float32{1.0, 2.0, 0.5, -1.0}
+	vocabSize := 4
+	want := []TokenProbs{
+		{TokenID: 1, Logit: 2.0},  // Highest logit
+		{TokenID: 0, Logit: 1.0},  // Second highest
+		{TokenID: 2, Logit: 0.5},  // Third
+		{TokenID: 3, Logit: -1.0}, // Lowest
+	}
+
+	got := probs(logits, vocabSize)
+
+	// Test 1: Check sorting order
+	for i := 0; i < len(got)-1; i++ {
+		if got[i].Logit < got[i+1].Logit {
+			t.Errorf("probs not properly sorted: logit at pos %d (%f) < logit at pos %d (%f)",
+				i, got[i].Logit, i+1, got[i+1].Logit)
+		}
+	}
+
+	// Test 2: Check probability normalization
+	var sum float32
+	for _, p := range got {
+		sum += p.Prob
+	}
+	if math.Abs(float64(sum-1.0)) > 1e-6 {
+		t.Errorf("probabilities do not sum to 1: got %v", sum)
+	}
+
+	// Test 3: Check token IDs match expected order
+	for i, want := range want {
+		if got[i].TokenID != want.TokenID {
+			t.Errorf("wrong token ID at position %d: got %d, want %d",
+				i, got[i].TokenID, want.TokenID)
+		}
+		if got[i].Logit != want.Logit {
+			t.Errorf("wrong logit at position %d: got %f, want %f",
+				i, got[i].Logit, want.Logit)
+		}
+	}
+
+	// Test 4: Check log probs are correctly calculated
+	for i, p := range got {
+		expectedLogProb := float32(math.Log(float64(p.Prob)))
+		if math.Abs(float64(p.LogProb-expectedLogProb)) > 1e-6 {
+			t.Errorf("wrong log prob at position %d: got %f, want %f",
+				i, p.LogProb, expectedLogProb)
+		}
+	}
+}
@@ -26,43 +26,15 @@ func containsStopSuffix(sequence string, stops []string) bool {
 	return false
 }

-// truncateStop removes the provided stop string from pieces,
-// returning the partial pieces with stop removed, including truncating
-// the last piece if required (and signalling if this was the case)
-func truncateStop(pieces []string, stop string) ([]string, bool) {
-	joined := strings.Join(pieces, "")
-
-	index := strings.Index(joined, stop)
+// truncateStop removes the provided stop string from sequence,
+// returning both the truncated sequence and a bool indicating if truncation occurred
+func truncateStop(sequence string, stop string) (string, bool) {
+	index := strings.Index(sequence, stop)
 	if index == -1 {
-		return pieces, false
+		return sequence, false
 	}

-	joined = joined[:index]
-
-	// Split truncated string back into pieces of original lengths
-	lengths := make([]int, len(pieces))
-	for i, piece := range pieces {
-		lengths[i] = len(piece)
-	}
-
-	var result []string
-	tokenTruncated := false
-	start := 0
-	for _, length := range lengths {
-		if start >= len(joined) {
-			break
-		}
-
-		end := start + length
-		if end > len(joined) {
-			end = len(joined)
-			tokenTruncated = true
-		}
-		result = append(result, joined[start:end])
-		start = end
-	}
-
-	return result, tokenTruncated
+	return sequence[:index], true
 }

 func incompleteUnicode(token string) bool {
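A quick worked example of the new string-based signature, using values from the updated test table below: truncateStop("hellowor", "or") returns ("hellow", true), and truncateStop("hello world", "xyz") returns ("hello world", false). Note that the "Suffix" case ("Hello there!" with stop "!") now reports truncation as true, where the piece-based version reported false; the test expectations below change accordingly.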
@@ -1,60 +1,60 @@
 package runner

 import (
-	"reflect"
 	"testing"
 )

 func TestTruncateStop(t *testing.T) {
 	tests := []struct {
 		name          string
-		pieces        []string
+		sequence      string
 		stop          string
-		expected      []string
+		expected      string
 		expectedTrunc bool
 	}{
 		{
 			name:          "Single word",
-			pieces:        []string{"hello", "world"},
+			sequence:      "helloworld",
 			stop:          "world",
-			expected:      []string{"hello"},
-			expectedTrunc: false,
+			expected:      "hello",
+			expectedTrunc: true,
 		},
 		{
 			name:          "Partial",
-			pieces:        []string{"hello", "wor"},
+			sequence:      "hellowor",
 			stop:          "or",
-			expected:      []string{"hello", "w"},
+			expected:      "hellow",
 			expectedTrunc: true,
 		},
 		{
 			name:          "Suffix",
-			pieces:        []string{"Hello", " there", "!"},
+			sequence:      "Hello there!",
 			stop:          "!",
-			expected:      []string{"Hello", " there"},
-			expectedTrunc: false,
-		},
-		{
-			name:          "Suffix partial",
-			pieces:        []string{"Hello", " the", "re!"},
-			stop:          "there!",
-			expected:      []string{"Hello", " "},
+			expected:      "Hello there",
 			expectedTrunc: true,
 		},
 		{
 			name:          "Middle",
-			pieces:        []string{"hello", " wor"},
+			sequence:      "hello wor",
 			stop:          "llo w",
-			expected:      []string{"he"},
+			expected:      "he",
 			expectedTrunc: true,
 		},
+		{
+			name:          "No stop found",
+			sequence:      "hello world",
+			stop:          "xyz",
+			expected:      "hello world",
+			expectedTrunc: false,
+		},
 	}

 	for _, tt := range tests {
 		t.Run(tt.name, func(t *testing.T) {
-			result, resultTrunc := truncateStop(tt.pieces, tt.stop)
-			if !reflect.DeepEqual(result, tt.expected) || resultTrunc != tt.expectedTrunc {
-				t.Errorf("truncateStop(%v, %s): have %v (%v); want %v (%v)", tt.pieces, tt.stop, result, resultTrunc, tt.expected, tt.expectedTrunc)
+			result, truncated := truncateStop(tt.sequence, tt.stop)
+			if result != tt.expected || truncated != tt.expectedTrunc {
+				t.Errorf("truncateStop(%q, %q): have %q (%v); want %q (%v)",
+					tt.sequence, tt.stop, result, truncated, tt.expected, tt.expectedTrunc)
 			}
 		})
 	}
@@ -644,12 +644,22 @@ type ImageData struct {
 	AspectRatioID int `json:"aspect_ratio_id"`
 }

+// TokenProbs represents probability information for a token
+type TokenProbs struct {
+	TokenID int     `json:"id"`
+	Logit   float32 `json:"logit"`
+	Prob    float32 `json:"prob"`
+	LogProb float32 `json:"logprob"`
+	Token   string  `json:"token"`
+}
+
 type completion struct {
 	Content      string `json:"content"`
 	Model        string `json:"model"`
 	Prompt       string `json:"prompt"`
 	Stop         bool   `json:"stop"`
 	StoppedLimit bool   `json:"stopped_limit"`
+	LogProbs     []TokenProbs `json:"logprobs"`

 	Timings struct {
 		PredictedN int `json:"predicted_n"`
@@ -660,14 +670,16 @@ type completion struct {
 	}
 }

 type CompletionRequest struct {
 	Prompt  string
 	Format  json.RawMessage
 	Images  []ImageData
-	Options *api.Options
+	LogProbs int
+	Options  *api.Options
 }

 type CompletionResponse struct {
 	Content string
+	LogProbs []TokenProbs
 	DoneReason string
 	Done bool
 	PromptEvalCount int
@@ -698,9 +710,12 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
 		"seed":       req.Options.Seed,
 		"stop":       req.Options.Stop,
 		"image_data": req.Images,
+		"logprobs":   req.LogProbs,
 		"cache_prompt": true,
 	}

+	fmt.Println("completion request:", request)
+
 	if len(req.Format) > 0 {
 		switch string(req.Format) {
 		case `null`, `""`:
@@ -796,7 +811,6 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
 			continue
 		}

-		// slog.Debug("got line", "line", string(line))
 		evt, ok := bytes.CutPrefix(line, []byte("data: "))
 		if !ok {
 			evt = line
@@ -822,7 +836,8 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu

 		if c.Content != "" {
 			fn(CompletionResponse{
 				Content: c.Content,
+				LogProbs: c.LogProbs,
 			})
 		}

@@ -839,6 +854,7 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
 			PromptEvalDuration: parseDurationMs(c.Timings.PromptMS),
 			EvalCount:          c.Timings.PredictedN,
 			EvalDuration:       parseDurationMs(c.Timings.PredictedMS),
+			LogProbs:           c.LogProbs,
 		})
 		return nil
 	}
ml/backend/ggml/ggml/src/ggml-backend-reg.cpp (45, vendored)
@@ -512,32 +512,33 @@ static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent,
         }
         fs::directory_iterator dir_it(search_path, fs::directory_options::skip_permission_denied);
         for (const auto & entry : dir_it) {
-            if (entry.is_regular_file()) {
-                std::wstring filename = entry.path().filename().wstring();
-                std::wstring ext = entry.path().extension().wstring();
-                if (filename.find(file_prefix) == 0 && ext == backend_filename_suffix()) {
-                    dl_handle_ptr handle { dl_load_library(entry.path().wstring()) };
-                    if (!handle && !silent) {
-                        GGML_LOG_ERROR("%s: failed to load %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str());
-                    }
-                    if (handle) {
+            try {
+                if (entry.is_regular_file()) {
+                    std::wstring filename = entry.path().filename().wstring();
+                    std::wstring ext = entry.path().extension().wstring();
+                    if (filename.find(file_prefix) == 0 && ext == backend_filename_suffix()) {
+                        dl_handle_ptr handle { dl_load_library(entry.path().wstring()) };
+                        if (!handle) {
+                            GGML_LOG_ERROR("%s: failed to load %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str());
+                            continue;
+                        }
+
                         auto score_fn = (ggml_backend_score_t) dl_get_sym(handle.get(), "ggml_backend_score");
-                        if (score_fn) {
-                            int s = score_fn();
-#ifndef NDEBUG
-                            GGML_LOG_DEBUG("%s: %s score: %d\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str(), s);
-#endif
-                            if (s > best_score) {
-                                best_score = s;
-                                best_path = entry.path().wstring();
-                            }
-                        } else {
-                            if (!silent) {
-                                GGML_LOG_INFO("%s: failed to find ggml_backend_score in %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str());
-                            }
+                        if (!score_fn) {
+                            GGML_LOG_DEBUG("%s: failed to find ggml_backend_score in %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str());
+                            continue;
+                        }
+
+                        int s = score_fn();
+                        GGML_LOG_DEBUG("%s: %s score: %d\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str(), s);
+                        if (s > best_score) {
+                            best_score = s;
+                            best_path = entry.path().wstring();
                         }
                     }
                 }
+            } catch (const std::exception & e) {
+                GGML_LOG_ERROR("%s: failed to load %s: %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str(), e.what());
             }
         }
     }
@@ -79,6 +79,11 @@ var OnceLoad = sync.OnceFunc(func() {
 			continue
 		}

+		if abspath != filepath.Dir(exe) && !strings.Contains(abspath, filepath.FromSlash("lib/ollama")) {
+			slog.Debug("skipping path which is not part of ollama", "path", abspath)
+			continue
+		}
+
 		if _, ok := visited[abspath]; !ok {
 			func() {
 				slog.Debug("ggml backend load all from path", "path", abspath)
@@ -610,14 +610,14 @@ type EmbedWriter struct {
 }

 func (w *BaseWriter) writeError(data []byte) (int, error) {
-	var er api.ErrorResponse // error response is used here to parse the error message
-	err := json.Unmarshal(data, &er)
+	var serr api.StatusError
+	err := json.Unmarshal(data, &serr)
 	if err != nil {
 		return 0, err
 	}

 	w.ResponseWriter.Header().Set("Content-Type", "application/json")
-	err = json.NewEncoder(w.ResponseWriter).Encode(NewError(http.StatusInternalServerError, er.Err))
+	err = json.NewEncoder(w.ResponseWriter).Encode(NewError(http.StatusInternalServerError, serr.Error()))
 	if err != nil {
 		return 0, err
 	}
@@ -550,7 +550,7 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu

 	manifest, err = pullModelManifest(ctx, mp, regOpts)
 	if err != nil {
-		return fmt.Errorf("pull model manifest: %w", err)
+		return fmt.Errorf("pull model manifest: %s", err)
 	}

 	var layers []Layer
@@ -629,18 +629,13 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
 	return nil
 }

-var ErrRemoteModelNotFound = errors.New("model not found")
-
 func pullModelManifest(ctx context.Context, mp ModelPath, regOpts *registryOptions) (*Manifest, error) {
 	requestURL := mp.BaseURL().JoinPath("v2", mp.GetNamespaceRepository(), "manifests", mp.Tag)

 	headers := make(http.Header)
 	headers.Set("Accept", "application/vnd.docker.distribution.manifest.v2+json")
 	resp, err := makeRequestWithRetry(ctx, http.MethodGet, requestURL, headers, nil, regOpts)
-	if errors.Is(err, os.ErrNotExist) {
-		// The model was not found on the remote registry
-		return nil, fmt.Errorf("%w: %s", ErrRemoteModelNotFound, err)
-	} else if err != nil {
+	if err != nil {
 		return nil, err
 	}
 	defer resp.Body.Close()
@@ -293,11 +293,13 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 		var sb strings.Builder
 		defer close(ch)
 		if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
 			Prompt:  prompt,
 			Images:  images,
 			Format:  req.Format,
-			Options: opts,
+			LogProbs: req.LogProbs,
+			Options:  opts,
 		}, func(cr llm.CompletionResponse) {
+			fmt.Printf("banana: %#v\n", cr)
 			res := api.GenerateResponse{
 				Model:     req.Model,
 				CreatedAt: time.Now().UTC(),
@@ -311,6 +313,13 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 					EvalDuration:    cr.EvalDuration,
 				},
 			}
+			for _, p := range cr.LogProbs {
+				res.LogProbs = append(res.LogProbs, api.TokenProbs{
+					TokenID: p.TokenID,
+					LogProb: p.LogProb,
+					Token:   p.Token,
+				})
+			}

 			if _, err := sb.WriteString(cr.Content); err != nil {
 				ch <- gin.H{"error": err.Error()}
@@ -564,8 +573,7 @@ func (s *Server) PullHandler(c *gin.Context) {
 		return
 	}

-	reqName := cmp.Or(req.Model, req.Name)
-	name := model.ParseName(reqName)
+	name := model.ParseName(cmp.Or(req.Model, req.Name))
 	if !name.IsValid() {
 		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": errtypes.InvalidModelNameErrMsg})
 		return
@@ -592,18 +600,7 @@ func (s *Server) PullHandler(c *gin.Context) {
 		defer cancel()

 		if err := PullModel(ctx, name.DisplayShortest(), regOpts, fn); err != nil {
-			if errors.Is(err, ErrRemoteModelNotFound) {
-				hint := fmt.Sprintf("Model %q not found - please check the model name is correct and try again", reqName)
-				if name.Host == DefaultRegistry {
-					hint = fmt.Sprintf("Model %q not found - search available models at: https://ollama.com/search?q=%s", reqName, reqName)
-				}
-				ch <- api.ErrorResponse{
-					Err:  err.Error(),
-					Hint: hint,
-				}
-			} else {
-				ch <- gin.H{"error": err.Error()}
-			}
+			ch <- gin.H{"error": err.Error()}
 		}
 	}()

@@ -1478,10 +1475,11 @@ func (s *Server) ChatHandler(c *gin.Context) {
 		var sb strings.Builder
 		var toolCallIndex int = 0
 		if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
 			Prompt:  prompt,
 			Images:  images,
 			Format:  req.Format,
-			Options: opts,
+			LogProbs: req.LogProbs,
+			Options:  opts,
 		}, func(r llm.CompletionResponse) {
 			res := api.ChatResponse{
 				Model: req.Model,
@@ -1496,6 +1494,13 @@ func (s *Server) ChatHandler(c *gin.Context) {
 					EvalDuration:    r.EvalDuration,
 				},
 			}
+			for _, p := range r.LogProbs {
+				res.LogProbs = append(res.LogProbs, api.TokenProbs{
+					TokenID: p.TokenID,
+					LogProb: p.LogProb,
+					Token:   p.Token,
+				})
+			}

 			if r.Done {
 				res.TotalDuration = time.Since(checkpointStart)
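End to end, the new field flows from api.GenerateRequest.LogProbs (or api.ChatRequest.LogProbs) through llm.CompletionRequest into the runner, and back out as api.GenerateResponse.LogProbs / api.ChatResponse.LogProbs. A minimal sketch of how a client might exercise it once this lands; the model name and prompt are placeholders, and the response shape is as defined in api/types.go above:

```go
package main

import (
	"context"
	"fmt"
	"log"

	"github.com/ollama/ollama/api"
)

func main() {
	client, err := api.ClientFromEnvironment()
	if err != nil {
		log.Fatal(err)
	}

	req := &api.GenerateRequest{
		Model:    "llama3.2", // placeholder model name
		Prompt:   "Why is the sky blue?",
		LogProbs: 3, // ask for the top 3 token alternatives per generated token
	}

	err = client.Generate(context.Background(), req, func(r api.GenerateResponse) error {
		// each streamed chunk carries the log probabilities for that step
		for _, p := range r.LogProbs {
			fmt.Printf("token=%q id=%d logprob=%.3f\n", p.Token, p.TokenID, p.LogProb)
		}
		return nil
	})
	if err != nil {
		log.Fatal(err)
	}
}
```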