Compare commits

..

12 Commits

Author SHA1 Message Date
Bruce MacDonald
e19c64e047 lint fix 2024-12-12 13:07:09 -08:00
Bruce MacDonald
9e190ac4d9 api: return structured error on unauthorized push
This commit implements a structured error response system for the Ollama API, replacing
ad-hoc error handling and string parsing with proper error types and codes through a new
ErrorResponse struct. Instead of relying on regex to parse error messages for SSH keys,
the API now passes this data in a structured format with standardized fields for error
messages, codes, and additional data. This structured approach makes the API more
maintainable and reliable while improving the developer experience by enabling
programmatic error handling, consistent error formats, and better error
documentation.
2024-12-12 11:49:36 -08:00
Bruce MacDonald
ae9165d661 remove images_test.go (uses filesystem key) 2024-11-27 15:52:30 -08:00
Bruce MacDonald
a262b86a5e fix lint checks 2024-11-27 15:45:45 -08:00
Bruce MacDonald
4d5d3c3276 Update error.go 2024-11-27 15:24:08 -08:00
Bruce MacDonald
ea90ee7415 Update cmd.go 2024-11-27 15:22:27 -08:00
Bruce MacDonald
40134c6587 server: show user feedback when key is anonymous
When an ollama key is not registered with any account on ollama.com this is
not obvious. In the current CLI an error message that the user is not
authorized is displayed. This change brings back previous behavior to show
the user their key and where they should add it. It protects against adding
unexpected keys by checking that the key is available locally.

A follow-up change should add structured errors from the API. This change
just relies on a known error message.
2024-11-27 15:01:12 -08:00
Bruce MacDonald
940e62772e openai: remove unused error code (#7850)
The writeError takes a code argument which is no longer used. Remove it for clarity.
2024-11-26 16:08:09 -08:00
Jesse Gross
71e6a0d0d1 runner.go: Don't try to extract image tags for text models
When processing a prompt, we look for image tags of the form
[img-0], which are inserted by the Ollama server process.
However, this can cause errors if the original prompt has these
tags - typically an image not found error is returned.

This changes tag searching behavior to be similar to the 0.3.x
series, which will largely avoid these problems. However,they can
still happen when input text with these tags is used with image
models. The correct solution is to escape the tags but this is a
larger issue with special sequences in general so this is an
incremental fix that should avoid the problem for the majority
of cases.
2024-11-26 13:23:24 -08:00
Jesse Gross
2cd11ae365 runner.go: Add unit tests for context shifting
This also makes it easier to truncate long inputs the same as
shifting but does not actually implement it. This type of
truncation has a trade off between quality and time to first
token.
2024-11-26 11:21:35 -08:00
jake83741
52bbad12f9 readme: update description for vnc-lm community integration (#7832) 2024-11-25 17:56:30 -08:00
frob
30e88d7f31 cmd: don't submit svg files as images for now (#7830) 2024-11-25 16:43:29 -08:00
17 changed files with 455 additions and 97 deletions

View File

@@ -504,7 +504,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [Terraform AWS Ollama & Open WebUI](https://github.com/xuyangbocn/terraform-aws-self-host-llm) (A Terraform module to deploy on AWS a ready-to-use Ollama service, together with its front end Open WebUI service.)
- [node-red-contrib-ollama](https://github.com/jakubburkiewicz/node-red-contrib-ollama)
- [Local AI Helper](https://github.com/ivostoykov/localAI) (Chrome and Firefox extensions that enable interactions with the active tab and customisable API endpoints. Includes secure storage for user prompts.)
- [vnc-lm](https://github.com/jk011ru/vnc-lm) (A containerized Discord bot with support for attachments and web links)
- [vnc-lm](https://github.com/jake83741/vnc-lm) (Discord bot for messaging with LLMs through Ollama and LiteLLM. Seamlessly move between local and flagship models.)
- [LSP-AI](https://github.com/SilasMarvin/lsp-ai) (Open-source language server for AI-powered functionality)
- [QodeAssist](https://github.com/Palm1r/QodeAssist) (AI-powered coding assistant plugin for Qt Creator)
- [Obsidian Quiz Generator plugin](https://github.com/ECuiDev/obsidian-quiz-generator)

View File

@@ -163,24 +163,29 @@ func (c *Client) stream(ctx context.Context, method, path string, data any, fn f
scanBuf := make([]byte, 0, maxBufferSize)
scanner.Buffer(scanBuf, maxBufferSize)
for scanner.Scan() {
var errorResponse struct {
Error string `json:"error,omitempty"`
}
bts := scanner.Bytes()
var errorResponse ErrorResponse
if err := json.Unmarshal(bts, &errorResponse); err != nil {
return fmt.Errorf("unmarshal: %w", err)
}
if errorResponse.Error != "" {
return errors.New(errorResponse.Error)
switch errorResponse.Code {
case ErrCodeUnknownKey:
return ErrUnknownOllamaKey{
Message: errorResponse.Message,
Key: errorResponse.Data["key"].(string),
}
}
if errorResponse.Message != "" {
return errors.New(errorResponse.Message)
}
if response.StatusCode >= http.StatusBadRequest {
return StatusError{
StatusCode: response.StatusCode,
Status: response.Status,
ErrorMessage: errorResponse.Error,
ErrorMessage: errorResponse.Message,
}
}

View File

@@ -1,6 +1,12 @@
package api
import (
"context"
"errors"
"fmt"
"net/http"
"net/http/httptest"
"net/url"
"testing"
)
@@ -43,3 +49,117 @@ func TestClientFromEnvironment(t *testing.T) {
})
}
}
func TestStream(t *testing.T) {
tests := []struct {
name string
serverResponse []string
statusCode int
expectedError error
}{
{
name: "unknown key error",
serverResponse: []string{
`{"error":"unauthorized access","code":"unknown_key","data":{"key":"test-key"}}`,
},
statusCode: http.StatusUnauthorized,
expectedError: &ErrUnknownOllamaKey{
Message: "unauthorized access",
Key: "test-key",
},
},
{
name: "general error message",
serverResponse: []string{
`{"error":"something went wrong"}`,
},
statusCode: http.StatusInternalServerError,
expectedError: fmt.Errorf("something went wrong"),
},
{
name: "malformed json response",
serverResponse: []string{
`{invalid-json`,
},
statusCode: http.StatusOK,
expectedError: fmt.Errorf("unmarshal: invalid character 'i' looking for beginning of object key string"),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/x-ndjson")
w.WriteHeader(tt.statusCode)
for _, resp := range tt.serverResponse {
fmt.Fprintln(w, resp)
}
}))
defer server.Close()
baseURL, err := url.Parse(server.URL)
if err != nil {
t.Fatalf("failed to parse server URL: %v", err)
}
client := &Client{
http: server.Client(),
base: baseURL,
}
var responses [][]byte
err = client.stream(context.Background(), "POST", "/test", "test", func(bts []byte) error {
responses = append(responses, bts)
return nil
})
// Error checking
if tt.expectedError == nil {
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
return
}
if err == nil {
t.Fatalf("expected error %v, got nil", tt.expectedError)
}
// Check for specific error types
var unknownKeyErr ErrUnknownOllamaKey
if errors.As(tt.expectedError, &unknownKeyErr) {
var gotErr ErrUnknownOllamaKey
if !errors.As(err, &gotErr) {
t.Fatalf("expected ErrUnknownOllamaKey, got %T", err)
}
if unknownKeyErr.Key != gotErr.Key {
t.Errorf("expected key %q, got %q", unknownKeyErr.Key, gotErr.Key)
}
if unknownKeyErr.Message != gotErr.Message {
t.Errorf("expected message %q, got %q", unknownKeyErr.Message, gotErr.Message)
}
return
}
var statusErr StatusError
if errors.As(tt.expectedError, &statusErr) {
var gotErr StatusError
if !errors.As(err, &gotErr) {
t.Fatalf("expected StatusError, got %T", err)
}
if statusErr.StatusCode != gotErr.StatusCode {
t.Errorf("expected status code %d, got %d", statusErr.StatusCode, gotErr.StatusCode)
}
if statusErr.ErrorMessage != gotErr.ErrorMessage {
t.Errorf("expected error message %q, got %q", statusErr.ErrorMessage, gotErr.ErrorMessage)
}
return
}
// For other errors, compare error strings
if err.Error() != tt.expectedError.Error() {
t.Errorf("expected error %q, got %q", tt.expectedError, err)
}
})
}
}

74
api/errors.go Normal file
View File

@@ -0,0 +1,74 @@
package api
import (
"fmt"
"slices"
"strings"
)
const InvalidModelNameErrMsg = "invalid model name"
// API error responses
// ErrorCode represents a standardized error code identifier
type ErrorCode string
const (
ErrCodeUnknownKey ErrorCode = "unknown_key"
ErrCodeGeneral ErrorCode = "general" // Generic fallback error code
)
// ErrorResponse implements a structured error interface
type ErrorResponse struct {
Message string `json:"error"` // Human-readable error message, uses 'error' field name for backwards compatibility
Code ErrorCode `json:"code"` // Machine-readable error code for programmatic handling, not response code
Data map[string]any `json:"data"` // Additional error specific data, if any
}
func (e ErrorResponse) Error() string {
return e.Message
}
type ErrUnknownOllamaKey struct {
Message string
Key string
}
func (e ErrUnknownOllamaKey) Error() string {
return fmt.Sprintf("unauthorized: unknown ollama key %q", strings.TrimSpace(e.Key))
}
func (e *ErrUnknownOllamaKey) FormatUserMessage(localKeys []string) string {
// The user should only be told to add the key if it is the same one that exists locally
if slices.Index(localKeys, e.Key) == -1 {
return e.Message
}
return fmt.Sprintf(`%s
Your ollama key is:
%s
Add your key at:
https://ollama.com/settings/keys`, e.Message, e.Key)
}
// StatusError is an error with an HTTP status code and message,
// it is parsed on the client-side and not returned from the API
type StatusError struct {
StatusCode int // e.g. 200
Status string // e.g. "200 OK"
ErrorMessage string `json:"error"`
}
func (e StatusError) Error() string {
switch {
case e.Status != "" && e.ErrorMessage != "":
return fmt.Sprintf("%s: %s", e.Status, e.ErrorMessage)
case e.Status != "":
return e.Status
case e.ErrorMessage != "":
return e.ErrorMessage
default:
// this should not happen
return "something went wrong, please see the ollama server logs for details"
}
}

View File

@@ -12,27 +12,6 @@ import (
"time"
)
// StatusError is an error with an HTTP status code and message.
type StatusError struct {
StatusCode int
Status string
ErrorMessage string `json:"error"`
}
func (e StatusError) Error() string {
switch {
case e.Status != "" && e.ErrorMessage != "":
return fmt.Sprintf("%s: %s", e.Status, e.ErrorMessage)
case e.Status != "":
return e.Status
case e.ErrorMessage != "":
return e.ErrorMessage
default:
// this should not happen
return "something went wrong, please see the ollama server logs for details"
}
}
// ImageData represents the raw binary data of an image file.
type ImageData []byte

View File

@@ -34,6 +34,7 @@ import (
"golang.org/x/term"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/auth"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/format"
"github.com/ollama/ollama/parser"
@@ -513,6 +514,24 @@ func RunHandler(cmd *cobra.Command, args []string) error {
return generate(cmd, opts)
}
func localPubKeys() ([]string, error) {
usrKey, err := auth.GetPublicKey()
if err != nil {
return nil, err
}
keys := []string{usrKey}
if runtime.GOOS == "linux" {
// try the ollama service public key if on Linux
if svcKey, err := os.ReadFile("/usr/share/ollama/.ollama/id_ed25519.pub"); err == nil {
keys = append(keys, strings.TrimSpace(string(svcKey)))
}
}
return keys, nil
}
func PushHandler(cmd *cobra.Command, args []string) error {
client, err := api.ClientFromEnvironment()
if err != nil {
@@ -561,21 +580,29 @@ func PushHandler(cmd *cobra.Command, args []string) error {
request := api.PushRequest{Name: args[0], Insecure: insecure}
n := model.ParseName(args[0])
isOllamaHost := strings.HasSuffix(n.Host, ".ollama.ai") || strings.HasSuffix(n.Host, ".ollama.com")
if err := client.Push(cmd.Context(), &request, fn); err != nil {
if spinner != nil {
spinner.Stop()
}
var ke api.ErrUnknownOllamaKey
if errors.As(err, &ke) && isOllamaHost {
// the user has not added their ollama key to ollama.com
// return an error with a more user-friendly message
locals, _ := localPubKeys()
return errors.New(ke.FormatUserMessage(locals))
}
if strings.Contains(err.Error(), "access denied") {
return errors.New("you are not authorized to push to this namespace, create the model under a namespace you own")
}
return err
return fmt.Errorf("yoyoyo: %w", err)
}
p.Stop()
spinner.Stop()
destination := n.String()
if strings.HasSuffix(n.Host, ".ollama.ai") || strings.HasSuffix(n.Host, ".ollama.com") {
if isOllamaHost {
destination = "https://ollama.com/" + strings.TrimSuffix(n.DisplayShortest(), ":latest")
}
fmt.Printf("\nYou can find your model at:\n\n")

View File

@@ -373,15 +373,13 @@ func TestGetModelfileName(t *testing.T) {
func TestPushHandler(t *testing.T) {
tests := []struct {
name string
modelName string
serverResponse map[string]func(w http.ResponseWriter, r *http.Request)
expectedError string
expectedOutput string
}{
{
name: "successful push",
modelName: "test-model",
modelName: "successful-push",
serverResponse: map[string]func(w http.ResponseWriter, r *http.Request){
"/api/push": func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
@@ -394,8 +392,8 @@ func TestPushHandler(t *testing.T) {
return
}
if req.Name != "test-model" {
t.Errorf("expected model name 'test-model', got %s", req.Name)
if req.Name != "successful-push" {
t.Errorf("expected model name 'successful-push', got %s", req.Name)
}
// Simulate progress updates
@@ -414,11 +412,10 @@ func TestPushHandler(t *testing.T) {
}
},
},
expectedOutput: "\nYou can find your model at:\n\n\thttps://ollama.com/test-model\n",
expectedOutput: "\nYou can find your model at:\n\n\thttps://ollama.com/successful-push\n",
},
{
name: "unauthorized push",
modelName: "unauthorized-model",
modelName: "unauthorized-push",
serverResponse: map[string]func(w http.ResponseWriter, r *http.Request){
"/api/push": func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
@@ -433,10 +430,29 @@ func TestPushHandler(t *testing.T) {
},
expectedError: "you are not authorized to push to this namespace, create the model under a namespace you own",
},
{
modelName: "unknown-key-err",
serverResponse: map[string]func(w http.ResponseWriter, r *http.Request){
"/api/push": func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusUnauthorized)
uerr := api.ErrUnknownOllamaKey{
Key: "aaa",
}
err := json.NewEncoder(w).Encode(map[string]string{
"error": uerr.Error(),
})
if err != nil {
t.Fatal(err)
}
},
},
expectedError: "unauthorized: unknown ollama key \"aaa\"",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Run(tt.modelName, func(t *testing.T) {
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if handler, ok := tt.serverResponse[r.URL.Path]; ok {
handler(w, r)

View File

@@ -19,7 +19,6 @@ import (
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/parser"
"github.com/ollama/ollama/readline"
"github.com/ollama/ollama/types/errtypes"
)
type MultilineState int
@@ -220,7 +219,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
fn := func(resp api.ProgressResponse) error { return nil }
err = client.Create(cmd.Context(), req, fn)
if err != nil {
if strings.Contains(err.Error(), errtypes.InvalidModelNameErrMsg) {
if strings.Contains(err.Error(), api.InvalidModelNameErrMsg) {
fmt.Printf("error: The model name '%s' is invalid\n", args[1])
continue
}
@@ -514,7 +513,7 @@ func extractFileNames(input string) []string {
// Regex to match file paths starting with optional drive letter, / ./ \ or .\ and include escaped or unescaped spaces (\ or %20)
// and followed by more characters and a file extension
// This will capture non filename strings, but we'll check for file existence to remove mismatches
regexPattern := `(?:[a-zA-Z]:)?(?:\./|/|\\)[\S\\ ]+?\.(?i:jpg|jpeg|png|svg)\b`
regexPattern := `(?:[a-zA-Z]:)?(?:\./|/|\\)[\S\\ ]+?\.(?i:jpg|jpeg|png)\b`
re := regexp.MustCompile(regexPattern)
return re.FindAllString(input, -1)

View File

@@ -12,44 +12,45 @@ import (
func TestExtractFilenames(t *testing.T) {
// Unix style paths
input := ` some preamble
./relative\ path/one.png inbetween1 ./not a valid two.jpg inbetween2
/unescaped space /three.jpeg inbetween3 /valid\ path/dir/four.png "./quoted with spaces/five.svg`
./relative\ path/one.png inbetween1 ./not a valid two.jpg inbetween2 ./1.svg
/unescaped space /three.jpeg inbetween3 /valid\ path/dir/four.png "./quoted with spaces/five.JPG`
res := extractFileNames(input)
assert.Len(t, res, 5)
assert.Contains(t, res[0], "one.png")
assert.Contains(t, res[1], "two.jpg")
assert.Contains(t, res[2], "three.jpeg")
assert.Contains(t, res[3], "four.png")
assert.Contains(t, res[4], "five.svg")
assert.Contains(t, res[4], "five.JPG")
assert.NotContains(t, res[4], '"')
assert.NotContains(t, res, "inbtween")
assert.NotContains(t, res, "inbetween1")
assert.NotContains(t, res, "./1.svg")
// Windows style paths
input = ` some preamble
c:/users/jdoe/one.png inbetween1 c:/program files/someplace/two.jpg inbetween2
/absolute/nospace/three.jpeg inbetween3 /absolute/with space/four.png inbetween4
./relative\ path/five.svg inbetween5 "./relative with/spaces/six.png inbetween6
d:\path with\spaces\seven.svg inbetween7 c:\users\jdoe\eight.png inbetween8
d:\program files\someplace\nine.png inbetween9 "E:\program files\someplace\ten.svg some ending
./relative\ path/five.JPG inbetween5 "./relative with/spaces/six.png inbetween6
d:\path with\spaces\seven.JPEG inbetween7 c:\users\jdoe\eight.png inbetween8
d:\program files\someplace\nine.png inbetween9 "E:\program files\someplace\ten.PNG some ending
`
res = extractFileNames(input)
assert.Len(t, res, 10)
assert.NotContains(t, res, "inbtween")
assert.NotContains(t, res, "inbetween2")
assert.Contains(t, res[0], "one.png")
assert.Contains(t, res[0], "c:")
assert.Contains(t, res[1], "two.jpg")
assert.Contains(t, res[1], "c:")
assert.Contains(t, res[2], "three.jpeg")
assert.Contains(t, res[3], "four.png")
assert.Contains(t, res[4], "five.svg")
assert.Contains(t, res[4], "five.JPG")
assert.Contains(t, res[5], "six.png")
assert.Contains(t, res[6], "seven.svg")
assert.Contains(t, res[6], "seven.JPEG")
assert.Contains(t, res[6], "d:")
assert.Contains(t, res[7], "eight.png")
assert.Contains(t, res[7], "c:")
assert.Contains(t, res[8], "nine.png")
assert.Contains(t, res[8], "d:")
assert.Contains(t, res[9], "ten.svg")
assert.Contains(t, res[9], "ten.PNG")
assert.Contains(t, res[9], "E:")
}

View File

@@ -199,6 +199,20 @@ func countCommonPrefix(a []input, b []input) int {
return count
}
func (c *InputCache) ShiftDiscard(inputLen int, numKeep int) int {
targetFree := (c.numCtx - numKeep) / 2
targetFree = max(targetFree, 1)
currentFree := c.numCtx - inputLen
discard := targetFree - currentFree
if discard < 0 {
discard = 0
}
return discard
}
// Frees up space in the KV cache by deleting the oldest half of history and shifting
// the newest half into that space (saving numKeep inputs at the beginning).
//
@@ -208,11 +222,7 @@ func (c *InputCache) ShiftCacheSlot(slot *InputCacheSlot, numKeep int) error {
return fmt.Errorf("unable to shift context - keep exceeds context (keep: %v context: %v)", numKeep, c.numCtx)
}
targetFree := (c.numCtx - numKeep) / 2
targetFree = max(targetFree, 1)
currentFree := c.numCtx - len(slot.Inputs)
discard := targetFree - currentFree
discard := c.ShiftDiscard(len(slot.Inputs), numKeep)
if discard <= 0 {
return nil

View File

@@ -227,3 +227,66 @@ func TestFindCacheSlot(t *testing.T) {
})
}
}
func TestShiftDiscard(t *testing.T) {
tests := []struct {
name string
numCtx int
numKeep int
inputLen int
expected int
}{
{
name: "Shift",
numCtx: 2048,
numKeep: 5,
inputLen: 2048,
expected: 1021,
},
{
name: "Max Keep",
numCtx: 2048,
numKeep: 2047,
inputLen: 2048,
expected: 1,
},
{
name: "No Keep",
numCtx: 2048,
numKeep: 0,
inputLen: 2048,
expected: 1024,
},
{
name: "Truncate",
numCtx: 2048,
numKeep: 5,
inputLen: 5000,
expected: 3973,
},
{
name: "Truncate Keep",
numCtx: 2048,
numKeep: 2047,
inputLen: 5000,
expected: 2953,
},
{
name: "No Op",
numCtx: 2048,
numKeep: 5,
inputLen: 512,
expected: 0,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := InputCache{numCtx: tt.numCtx}
result := c.ShiftDiscard(tt.inputLen, tt.numKeep)
if result != tt.expected {
t.Errorf("shiftDiscard(ctx: %v, keep: %v input: %v): have %v; want %v", tt.numCtx, tt.numKeep, tt.inputLen, result, tt.expected)
}
})
}
}

View File

@@ -122,9 +122,11 @@ func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequen
params.numKeep = min(params.numKeep, s.cache.numCtx-1)
if len(inputs) > s.cache.numCtx {
slog.Warn("truncating input prompt", "limit", s.cache.numCtx, "prompt", len(inputs), "numKeep", params.numKeep)
discard := len(inputs) - s.cache.numCtx
newInputs := inputs[:params.numKeep]
newInputs = append(newInputs, inputs[len(inputs)-s.cache.numCtx+params.numKeep:]...)
newInputs = append(newInputs, inputs[params.numKeep+discard:]...)
slog.Warn("truncating input prompt", "limit", s.cache.numCtx, "prompt", len(inputs), "keep", params.numKeep, "new", len(newInputs))
inputs = newInputs
}
@@ -162,10 +164,16 @@ func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequen
// generating image embeddings for each image
func (s *Server) inputs(prompt string, images []ImageData) ([]input, error) {
var inputs []input
var parts []string
var matches [][]string
re := regexp.MustCompile(`\[img-(\d+)\]`)
parts := re.Split(prompt, -1)
matches := re.FindAllStringSubmatch(prompt, -1)
if s.image != nil {
re := regexp.MustCompile(`\[img-(\d+)\]`)
parts = re.Split(prompt, -1)
matches = re.FindAllStringSubmatch(prompt, -1)
} else {
parts = []string{prompt}
}
for i, part := range parts {
// text - tokenize

View File

@@ -571,7 +571,7 @@ type EmbedWriter struct {
model string
}
func (w *BaseWriter) writeError(code int, data []byte) (int, error) {
func (w *BaseWriter) writeError(data []byte) (int, error) {
var serr api.StatusError
err := json.Unmarshal(data, &serr)
if err != nil {
@@ -630,7 +630,7 @@ func (w *ChatWriter) writeResponse(data []byte) (int, error) {
func (w *ChatWriter) Write(data []byte) (int, error) {
code := w.ResponseWriter.Status()
if code != http.StatusOK {
return w.writeError(code, data)
return w.writeError(data)
}
return w.writeResponse(data)
@@ -679,7 +679,7 @@ func (w *CompleteWriter) writeResponse(data []byte) (int, error) {
func (w *CompleteWriter) Write(data []byte) (int, error) {
code := w.ResponseWriter.Status()
if code != http.StatusOK {
return w.writeError(code, data)
return w.writeError(data)
}
return w.writeResponse(data)
@@ -704,7 +704,7 @@ func (w *ListWriter) writeResponse(data []byte) (int, error) {
func (w *ListWriter) Write(data []byte) (int, error) {
code := w.ResponseWriter.Status()
if code != http.StatusOK {
return w.writeError(code, data)
return w.writeError(data)
}
return w.writeResponse(data)
@@ -730,7 +730,7 @@ func (w *RetrieveWriter) writeResponse(data []byte) (int, error) {
func (w *RetrieveWriter) Write(data []byte) (int, error) {
code := w.ResponseWriter.Status()
if code != http.StatusOK {
return w.writeError(code, data)
return w.writeError(data)
}
return w.writeResponse(data)
@@ -755,7 +755,7 @@ func (w *EmbedWriter) writeResponse(data []byte) (int, error) {
func (w *EmbedWriter) Write(data []byte) (int, error) {
code := w.ResponseWriter.Status()
if code != http.StatusOK {
return w.writeError(code, data)
return w.writeError(data)
}
return w.writeResponse(data)

View File

@@ -23,6 +23,7 @@ import (
"strings"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/auth"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/format"
"github.com/ollama/ollama/llama"
@@ -30,6 +31,7 @@ import (
"github.com/ollama/ollama/parser"
"github.com/ollama/ollama/template"
"github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/types/registry"
"github.com/ollama/ollama/version"
)
@@ -980,8 +982,6 @@ func GetSHA256Digest(r io.Reader) (string, int64) {
return fmt.Sprintf("sha256:%x", h.Sum(nil)), n
}
var errUnauthorized = errors.New("unauthorized: access denied")
func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.URL, headers http.Header, body io.ReadSeeker, regOpts *registryOptions) (*http.Response, error) {
for range 2 {
resp, err := makeRequest(ctx, method, requestURL, headers, body, regOpts)
@@ -1019,13 +1019,33 @@ func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.UR
if err != nil {
return nil, fmt.Errorf("%d: %s", resp.StatusCode, err)
}
var re registry.Errs
if err := json.Unmarshal(responseBody, &re); err == nil && len(re.Errors) > 0 {
if re.HasCode(registry.ErrCodeAnonymous) {
// if the error is due to anonymous access return a custom error
// this error is used by the CLI to direct a user to add their key to an account
pubKey, nestedErr := auth.GetPublicKey()
if nestedErr != nil {
slog.Error(fmt.Sprintf("couldn't get public key: %v", nestedErr))
return nil, re
}
return nil, api.ErrUnknownOllamaKey{
Key: pubKey,
}
}
return nil, re
}
// Fallback to returning the raw response if parsing fails
return nil, fmt.Errorf("%d: %s", resp.StatusCode, responseBody)
default:
return resp, nil
}
}
return nil, errUnauthorized
// should never be reached
return nil, fmt.Errorf("failed to make upload request")
}
// testMakeRequestDialContext specifies the dial function for the http client in

View File

@@ -36,7 +36,6 @@ import (
"github.com/ollama/ollama/runners"
"github.com/ollama/ollama/server/imageproc"
"github.com/ollama/ollama/template"
"github.com/ollama/ollama/types/errtypes"
"github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/version"
)
@@ -610,7 +609,7 @@ func (s *Server) PushHandler(c *gin.Context) {
defer cancel()
if err := PushModel(ctx, model, regOpts, fn); err != nil {
ch <- gin.H{"error": err.Error()}
ch <- newErr(err)
}
}()
@@ -650,7 +649,7 @@ func (s *Server) CreateHandler(c *gin.Context) {
name := model.ParseName(cmp.Or(r.Model, r.Name))
if !name.IsValid() {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": errtypes.InvalidModelNameErrMsg})
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": api.InvalidModelNameErrMsg})
return
}
@@ -1550,3 +1549,24 @@ func handleScheduleError(c *gin.Context, name string, err error) {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
}
}
// newErr creates a structured API ErrorResponse from an existing error
func newErr(err error) api.ErrorResponse {
if err == nil {
return api.ErrorResponse{}
}
// Default to just returning the generic error message
resp := api.ErrorResponse{
Code: api.ErrCodeGeneral,
Message: err.Error(),
}
// Add additional error specific data, if any
var keyErr api.ErrUnknownOllamaKey
if errors.As(err, &keyErr) {
resp.Code = api.ErrCodeUnknownKey
resp.Data = map[string]any{
"key": keyErr.Key,
}
}
return resp
}

View File

@@ -1,21 +0,0 @@
// Package errtypes contains custom error types
package errtypes
import (
"fmt"
"strings"
)
const (
UnknownOllamaKeyErrMsg = "unknown ollama key"
InvalidModelNameErrMsg = "invalid model name"
)
// TODO: This should have a structured response from the API
type UnknownOllamaKey struct {
Key string
}
func (e *UnknownOllamaKey) Error() string {
return fmt.Sprintf("unauthorized: %s %q", UnknownOllamaKeyErrMsg, strings.TrimSpace(e.Key))
}

37
types/registry/error.go Normal file
View File

@@ -0,0 +1,37 @@
package registry
import (
"fmt"
"slices"
"strings"
)
const ErrCodeAnonymous = "ANONYMOUS_ACCESS_DENIED"
type Err struct {
Code string `json:"code"`
Message string `json:"message"`
}
// Errs represents the structure of error responses from the registry
// TODO (brucemacd): this struct should be imported from some shared package that is used between the registry and ollama
type Errs struct {
Errors []Err `json:"errors"`
}
func (e Errs) Error() string {
if len(e.Errors) == 0 {
return "unknown registry error"
}
var msgs []string
for _, err := range e.Errors {
msgs = append(msgs, fmt.Sprintf("%s: %s", err.Code, err.Message))
}
return strings.Join(msgs, "; ")
}
func (e Errs) HasCode(code string) bool {
return slices.ContainsFunc(e.Errors, func(err Err) bool {
return err.Code == code
})
}