Compare commits

..

1 Commits

Author SHA1 Message Date
jmorganca
bc40a820f0 server: fix proxy not being set from environment on ollama pull and ollama push 2024-11-25 14:28:37 -08:00
17 changed files with 97 additions and 590 deletions

View File

@@ -504,7 +504,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [Terraform AWS Ollama & Open WebUI](https://github.com/xuyangbocn/terraform-aws-self-host-llm) (A Terraform module to deploy on AWS a ready-to-use Ollama service, together with its front end Open WebUI service.)
- [node-red-contrib-ollama](https://github.com/jakubburkiewicz/node-red-contrib-ollama)
- [Local AI Helper](https://github.com/ivostoykov/localAI) (Chrome and Firefox extensions that enable interactions with the active tab and customisable API endpoints. Includes secure storage for user prompts.)
- [vnc-lm](https://github.com/jake83741/vnc-lm) (Discord bot for messaging with LLMs through Ollama and LiteLLM. Seamlessly move between local and flagship models.)
- [vnc-lm](https://github.com/jk011ru/vnc-lm) (A containerized Discord bot with support for attachments and web links)
- [LSP-AI](https://github.com/SilasMarvin/lsp-ai) (Open-source language server for AI-powered functionality)
- [QodeAssist](https://github.com/Palm1r/QodeAssist) (AI-powered coding assistant plugin for Qt Creator)
- [Obsidian Quiz Generator plugin](https://github.com/ECuiDev/obsidian-quiz-generator)
@@ -518,4 +518,3 @@ See the [API documentation](./docs/api.md) for all endpoints.
### Observability
- [OpenLIT](https://github.com/openlit/openlit) is an OpenTelemetry-native tool for monitoring Ollama Applications & GPUs using traces and metrics.
- [HoneyHive](https://docs.honeyhive.ai/integrations/ollama) is an AI observability and evaluation platform for AI agents. Use HoneyHive to evaluate agent performance, interrogate failures, and monitor quality in production.

View File

@@ -163,29 +163,24 @@ func (c *Client) stream(ctx context.Context, method, path string, data any, fn f
scanBuf := make([]byte, 0, maxBufferSize)
scanner.Buffer(scanBuf, maxBufferSize)
for scanner.Scan() {
bts := scanner.Bytes()
var errorResponse struct {
Error string `json:"error,omitempty"`
}
var errorResponse ErrorResponse
bts := scanner.Bytes()
if err := json.Unmarshal(bts, &errorResponse); err != nil {
return fmt.Errorf("unmarshal: %w", err)
}
switch errorResponse.Code {
case ErrCodeUnknownKey:
return ErrUnknownOllamaKey{
Message: errorResponse.Message,
Key: errorResponse.Data["key"].(string),
}
}
if errorResponse.Message != "" {
return errors.New(errorResponse.Message)
if errorResponse.Error != "" {
return errors.New(errorResponse.Error)
}
if response.StatusCode >= http.StatusBadRequest {
return StatusError{
StatusCode: response.StatusCode,
Status: response.Status,
ErrorMessage: errorResponse.Message,
ErrorMessage: errorResponse.Error,
}
}

View File

@@ -1,12 +1,6 @@
package api
import (
"context"
"errors"
"fmt"
"net/http"
"net/http/httptest"
"net/url"
"testing"
)
@@ -49,117 +43,3 @@ func TestClientFromEnvironment(t *testing.T) {
})
}
}
func TestStream(t *testing.T) {
tests := []struct {
name string
serverResponse []string
statusCode int
expectedError error
}{
{
name: "unknown key error",
serverResponse: []string{
`{"error":"unauthorized access","code":"unknown_key","data":{"key":"test-key"}}`,
},
statusCode: http.StatusUnauthorized,
expectedError: &ErrUnknownOllamaKey{
Message: "unauthorized access",
Key: "test-key",
},
},
{
name: "general error message",
serverResponse: []string{
`{"error":"something went wrong"}`,
},
statusCode: http.StatusInternalServerError,
expectedError: fmt.Errorf("something went wrong"),
},
{
name: "malformed json response",
serverResponse: []string{
`{invalid-json`,
},
statusCode: http.StatusOK,
expectedError: fmt.Errorf("unmarshal: invalid character 'i' looking for beginning of object key string"),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/x-ndjson")
w.WriteHeader(tt.statusCode)
for _, resp := range tt.serverResponse {
fmt.Fprintln(w, resp)
}
}))
defer server.Close()
baseURL, err := url.Parse(server.URL)
if err != nil {
t.Fatalf("failed to parse server URL: %v", err)
}
client := &Client{
http: server.Client(),
base: baseURL,
}
var responses [][]byte
err = client.stream(context.Background(), "POST", "/test", "test", func(bts []byte) error {
responses = append(responses, bts)
return nil
})
// Error checking
if tt.expectedError == nil {
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
return
}
if err == nil {
t.Fatalf("expected error %v, got nil", tt.expectedError)
}
// Check for specific error types
var unknownKeyErr ErrUnknownOllamaKey
if errors.As(tt.expectedError, &unknownKeyErr) {
var gotErr ErrUnknownOllamaKey
if !errors.As(err, &gotErr) {
t.Fatalf("expected ErrUnknownOllamaKey, got %T", err)
}
if unknownKeyErr.Key != gotErr.Key {
t.Errorf("expected key %q, got %q", unknownKeyErr.Key, gotErr.Key)
}
if unknownKeyErr.Message != gotErr.Message {
t.Errorf("expected message %q, got %q", unknownKeyErr.Message, gotErr.Message)
}
return
}
var statusErr StatusError
if errors.As(tt.expectedError, &statusErr) {
var gotErr StatusError
if !errors.As(err, &gotErr) {
t.Fatalf("expected StatusError, got %T", err)
}
if statusErr.StatusCode != gotErr.StatusCode {
t.Errorf("expected status code %d, got %d", statusErr.StatusCode, gotErr.StatusCode)
}
if statusErr.ErrorMessage != gotErr.ErrorMessage {
t.Errorf("expected error message %q, got %q", statusErr.ErrorMessage, gotErr.ErrorMessage)
}
return
}
// For other errors, compare error strings
if err.Error() != tt.expectedError.Error() {
t.Errorf("expected error %q, got %q", tt.expectedError, err)
}
})
}
}

View File

@@ -1,74 +0,0 @@
package api
import (
"fmt"
"slices"
"strings"
)
const InvalidModelNameErrMsg = "invalid model name"
// API error responses
// ErrorCode represents a standardized error code identifier
type ErrorCode string
const (
ErrCodeUnknownKey ErrorCode = "unknown_key"
ErrCodeGeneral ErrorCode = "general" // Generic fallback error code
)
// ErrorResponse implements a structured error interface
type ErrorResponse struct {
Message string `json:"error"` // Human-readable error message, uses 'error' field name for backwards compatibility
Code ErrorCode `json:"code"` // Machine-readable error code for programmatic handling, not response code
Data map[string]any `json:"data"` // Additional error specific data, if any
}
func (e ErrorResponse) Error() string {
return e.Message
}
type ErrUnknownOllamaKey struct {
Message string
Key string
}
func (e ErrUnknownOllamaKey) Error() string {
return fmt.Sprintf("unauthorized: unknown ollama key %q", strings.TrimSpace(e.Key))
}
func (e *ErrUnknownOllamaKey) FormatUserMessage(localKeys []string) string {
// The user should only be told to add the key if it is the same one that exists locally
if slices.Index(localKeys, e.Key) == -1 {
return e.Message
}
return fmt.Sprintf(`%s
Your ollama key is:
%s
Add your key at:
https://ollama.com/settings/keys`, e.Message, e.Key)
}
// StatusError is an error with an HTTP status code and message,
// it is parsed on the client-side and not returned from the API
type StatusError struct {
StatusCode int // e.g. 200
Status string // e.g. "200 OK"
ErrorMessage string `json:"error"`
}
func (e StatusError) Error() string {
switch {
case e.Status != "" && e.ErrorMessage != "":
return fmt.Sprintf("%s: %s", e.Status, e.ErrorMessage)
case e.Status != "":
return e.Status
case e.ErrorMessage != "":
return e.ErrorMessage
default:
// this should not happen
return "something went wrong, please see the ollama server logs for details"
}
}

View File

@@ -12,6 +12,27 @@ import (
"time"
)
// StatusError is an error with an HTTP status code and message.
type StatusError struct {
StatusCode int
Status string
ErrorMessage string `json:"error"`
}
func (e StatusError) Error() string {
switch {
case e.Status != "" && e.ErrorMessage != "":
return fmt.Sprintf("%s: %s", e.Status, e.ErrorMessage)
case e.Status != "":
return e.Status
case e.ErrorMessage != "":
return e.ErrorMessage
default:
// this should not happen
return "something went wrong, please see the ollama server logs for details"
}
}
// ImageData represents the raw binary data of an image file.
type ImageData []byte

View File

@@ -34,13 +34,11 @@ import (
"golang.org/x/term"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/auth"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/format"
"github.com/ollama/ollama/parser"
"github.com/ollama/ollama/progress"
"github.com/ollama/ollama/server"
"github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/version"
)
@@ -514,24 +512,6 @@ func RunHandler(cmd *cobra.Command, args []string) error {
return generate(cmd, opts)
}
func localPubKeys() ([]string, error) {
usrKey, err := auth.GetPublicKey()
if err != nil {
return nil, err
}
keys := []string{usrKey}
if runtime.GOOS == "linux" {
// try the ollama service public key if on Linux
if svcKey, err := os.ReadFile("/usr/share/ollama/.ollama/id_ed25519.pub"); err == nil {
keys = append(keys, strings.TrimSpace(string(svcKey)))
}
}
return keys, nil
}
func PushHandler(cmd *cobra.Command, args []string) error {
client, err := api.ClientFromEnvironment()
if err != nil {
@@ -578,36 +558,17 @@ func PushHandler(cmd *cobra.Command, args []string) error {
}
request := api.PushRequest{Name: args[0], Insecure: insecure}
n := model.ParseName(args[0])
isOllamaHost := strings.HasSuffix(n.Host, ".ollama.ai") || strings.HasSuffix(n.Host, ".ollama.com")
if err := client.Push(cmd.Context(), &request, fn); err != nil {
if spinner != nil {
spinner.Stop()
}
var ke api.ErrUnknownOllamaKey
if errors.As(err, &ke) && isOllamaHost {
// the user has not added their ollama key to ollama.com
// return an error with a more user-friendly message
locals, _ := localPubKeys()
return errors.New(ke.FormatUserMessage(locals))
}
if strings.Contains(err.Error(), "access denied") {
return errors.New("you are not authorized to push to this namespace, create the model under a namespace you own")
}
return fmt.Errorf("yoyoyo: %w", err)
return err
}
p.Stop()
spinner.Stop()
destination := n.String()
if isOllamaHost {
destination = "https://ollama.com/" + strings.TrimSuffix(n.DisplayShortest(), ":latest")
}
fmt.Printf("\nYou can find your model at:\n\n")
fmt.Printf("\t%s\n", destination)
return nil
}

View File

@@ -4,7 +4,6 @@ import (
"bytes"
"context"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"os"
@@ -370,143 +369,3 @@ func TestGetModelfileName(t *testing.T) {
})
}
}
func TestPushHandler(t *testing.T) {
tests := []struct {
modelName string
serverResponse map[string]func(w http.ResponseWriter, r *http.Request)
expectedError string
expectedOutput string
}{
{
modelName: "successful-push",
serverResponse: map[string]func(w http.ResponseWriter, r *http.Request){
"/api/push": func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
t.Errorf("expected POST request, got %s", r.Method)
}
var req api.PushRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
if req.Name != "successful-push" {
t.Errorf("expected model name 'successful-push', got %s", req.Name)
}
// Simulate progress updates
responses := []api.ProgressResponse{
{Status: "preparing manifest"},
{Digest: "sha256:abc123456789", Total: 100, Completed: 50},
{Digest: "sha256:abc123456789", Total: 100, Completed: 100},
}
for _, resp := range responses {
if err := json.NewEncoder(w).Encode(resp); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
w.(http.Flusher).Flush()
}
},
},
expectedOutput: "\nYou can find your model at:\n\n\thttps://ollama.com/successful-push\n",
},
{
modelName: "unauthorized-push",
serverResponse: map[string]func(w http.ResponseWriter, r *http.Request){
"/api/push": func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusUnauthorized)
err := json.NewEncoder(w).Encode(map[string]string{
"error": "access denied",
})
if err != nil {
t.Fatal(err)
}
},
},
expectedError: "you are not authorized to push to this namespace, create the model under a namespace you own",
},
{
modelName: "unknown-key-err",
serverResponse: map[string]func(w http.ResponseWriter, r *http.Request){
"/api/push": func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusUnauthorized)
uerr := api.ErrUnknownOllamaKey{
Key: "aaa",
}
err := json.NewEncoder(w).Encode(map[string]string{
"error": uerr.Error(),
})
if err != nil {
t.Fatal(err)
}
},
},
expectedError: "unauthorized: unknown ollama key \"aaa\"",
},
}
for _, tt := range tests {
t.Run(tt.modelName, func(t *testing.T) {
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if handler, ok := tt.serverResponse[r.URL.Path]; ok {
handler(w, r)
return
}
http.Error(w, "not found", http.StatusNotFound)
}))
defer mockServer.Close()
t.Setenv("OLLAMA_HOST", mockServer.URL)
cmd := &cobra.Command{}
cmd.Flags().Bool("insecure", false, "")
cmd.SetContext(context.TODO())
// Redirect stderr to capture progress output
oldStderr := os.Stderr
r, w, _ := os.Pipe()
os.Stderr = w
// Capture stdout for the "Model pushed" message
oldStdout := os.Stdout
outR, outW, _ := os.Pipe()
os.Stdout = outW
err := PushHandler(cmd, []string{tt.modelName})
// Restore stderr
w.Close()
os.Stderr = oldStderr
// drain the pipe
if _, err := io.ReadAll(r); err != nil {
t.Fatal(err)
}
// Restore stdout and get output
outW.Close()
os.Stdout = oldStdout
stdout, _ := io.ReadAll(outR)
if tt.expectedError == "" {
if err != nil {
t.Errorf("expected no error, got %v", err)
}
if tt.expectedOutput != "" {
if got := string(stdout); got != tt.expectedOutput {
t.Errorf("expected output %q, got %q", tt.expectedOutput, got)
}
}
} else {
if err == nil || !strings.Contains(err.Error(), tt.expectedError) {
t.Errorf("expected error containing %q, got %v", tt.expectedError, err)
}
}
})
}
}

View File

@@ -19,6 +19,7 @@ import (
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/parser"
"github.com/ollama/ollama/readline"
"github.com/ollama/ollama/types/errtypes"
)
type MultilineState int
@@ -219,7 +220,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
fn := func(resp api.ProgressResponse) error { return nil }
err = client.Create(cmd.Context(), req, fn)
if err != nil {
if strings.Contains(err.Error(), api.InvalidModelNameErrMsg) {
if strings.Contains(err.Error(), errtypes.InvalidModelNameErrMsg) {
fmt.Printf("error: The model name '%s' is invalid\n", args[1])
continue
}
@@ -513,7 +514,7 @@ func extractFileNames(input string) []string {
// Regex to match file paths starting with optional drive letter, / ./ \ or .\ and include escaped or unescaped spaces (\ or %20)
// and followed by more characters and a file extension
// This will capture non filename strings, but we'll check for file existence to remove mismatches
regexPattern := `(?:[a-zA-Z]:)?(?:\./|/|\\)[\S\\ ]+?\.(?i:jpg|jpeg|png)\b`
regexPattern := `(?:[a-zA-Z]:)?(?:\./|/|\\)[\S\\ ]+?\.(?i:jpg|jpeg|png|svg)\b`
re := regexp.MustCompile(regexPattern)
return re.FindAllString(input, -1)

View File

@@ -12,45 +12,44 @@ import (
func TestExtractFilenames(t *testing.T) {
// Unix style paths
input := ` some preamble
./relative\ path/one.png inbetween1 ./not a valid two.jpg inbetween2 ./1.svg
/unescaped space /three.jpeg inbetween3 /valid\ path/dir/four.png "./quoted with spaces/five.JPG`
./relative\ path/one.png inbetween1 ./not a valid two.jpg inbetween2
/unescaped space /three.jpeg inbetween3 /valid\ path/dir/four.png "./quoted with spaces/five.svg`
res := extractFileNames(input)
assert.Len(t, res, 5)
assert.Contains(t, res[0], "one.png")
assert.Contains(t, res[1], "two.jpg")
assert.Contains(t, res[2], "three.jpeg")
assert.Contains(t, res[3], "four.png")
assert.Contains(t, res[4], "five.JPG")
assert.Contains(t, res[4], "five.svg")
assert.NotContains(t, res[4], '"')
assert.NotContains(t, res, "inbetween1")
assert.NotContains(t, res, "./1.svg")
assert.NotContains(t, res, "inbtween")
// Windows style paths
input = ` some preamble
c:/users/jdoe/one.png inbetween1 c:/program files/someplace/two.jpg inbetween2
/absolute/nospace/three.jpeg inbetween3 /absolute/with space/four.png inbetween4
./relative\ path/five.JPG inbetween5 "./relative with/spaces/six.png inbetween6
d:\path with\spaces\seven.JPEG inbetween7 c:\users\jdoe\eight.png inbetween8
d:\program files\someplace\nine.png inbetween9 "E:\program files\someplace\ten.PNG some ending
./relative\ path/five.svg inbetween5 "./relative with/spaces/six.png inbetween6
d:\path with\spaces\seven.svg inbetween7 c:\users\jdoe\eight.png inbetween8
d:\program files\someplace\nine.png inbetween9 "E:\program files\someplace\ten.svg some ending
`
res = extractFileNames(input)
assert.Len(t, res, 10)
assert.NotContains(t, res, "inbetween2")
assert.NotContains(t, res, "inbtween")
assert.Contains(t, res[0], "one.png")
assert.Contains(t, res[0], "c:")
assert.Contains(t, res[1], "two.jpg")
assert.Contains(t, res[1], "c:")
assert.Contains(t, res[2], "three.jpeg")
assert.Contains(t, res[3], "four.png")
assert.Contains(t, res[4], "five.JPG")
assert.Contains(t, res[4], "five.svg")
assert.Contains(t, res[5], "six.png")
assert.Contains(t, res[6], "seven.JPEG")
assert.Contains(t, res[6], "seven.svg")
assert.Contains(t, res[6], "d:")
assert.Contains(t, res[7], "eight.png")
assert.Contains(t, res[7], "c:")
assert.Contains(t, res[8], "nine.png")
assert.Contains(t, res[8], "d:")
assert.Contains(t, res[9], "ten.PNG")
assert.Contains(t, res[9], "ten.svg")
assert.Contains(t, res[9], "E:")
}

View File

@@ -199,20 +199,6 @@ func countCommonPrefix(a []input, b []input) int {
return count
}
func (c *InputCache) ShiftDiscard(inputLen int, numKeep int) int {
targetFree := (c.numCtx - numKeep) / 2
targetFree = max(targetFree, 1)
currentFree := c.numCtx - inputLen
discard := targetFree - currentFree
if discard < 0 {
discard = 0
}
return discard
}
// Frees up space in the KV cache by deleting the oldest half of history and shifting
// the newest half into that space (saving numKeep inputs at the beginning).
//
@@ -222,7 +208,11 @@ func (c *InputCache) ShiftCacheSlot(slot *InputCacheSlot, numKeep int) error {
return fmt.Errorf("unable to shift context - keep exceeds context (keep: %v context: %v)", numKeep, c.numCtx)
}
discard := c.ShiftDiscard(len(slot.Inputs), numKeep)
targetFree := (c.numCtx - numKeep) / 2
targetFree = max(targetFree, 1)
currentFree := c.numCtx - len(slot.Inputs)
discard := targetFree - currentFree
if discard <= 0 {
return nil

View File

@@ -227,66 +227,3 @@ func TestFindCacheSlot(t *testing.T) {
})
}
}
func TestShiftDiscard(t *testing.T) {
tests := []struct {
name string
numCtx int
numKeep int
inputLen int
expected int
}{
{
name: "Shift",
numCtx: 2048,
numKeep: 5,
inputLen: 2048,
expected: 1021,
},
{
name: "Max Keep",
numCtx: 2048,
numKeep: 2047,
inputLen: 2048,
expected: 1,
},
{
name: "No Keep",
numCtx: 2048,
numKeep: 0,
inputLen: 2048,
expected: 1024,
},
{
name: "Truncate",
numCtx: 2048,
numKeep: 5,
inputLen: 5000,
expected: 3973,
},
{
name: "Truncate Keep",
numCtx: 2048,
numKeep: 2047,
inputLen: 5000,
expected: 2953,
},
{
name: "No Op",
numCtx: 2048,
numKeep: 5,
inputLen: 512,
expected: 0,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := InputCache{numCtx: tt.numCtx}
result := c.ShiftDiscard(tt.inputLen, tt.numKeep)
if result != tt.expected {
t.Errorf("shiftDiscard(ctx: %v, keep: %v input: %v): have %v; want %v", tt.numCtx, tt.numKeep, tt.inputLen, result, tt.expected)
}
})
}
}

View File

@@ -122,11 +122,9 @@ func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequen
params.numKeep = min(params.numKeep, s.cache.numCtx-1)
if len(inputs) > s.cache.numCtx {
discard := len(inputs) - s.cache.numCtx
slog.Warn("truncating input prompt", "limit", s.cache.numCtx, "prompt", len(inputs), "numKeep", params.numKeep)
newInputs := inputs[:params.numKeep]
newInputs = append(newInputs, inputs[params.numKeep+discard:]...)
slog.Warn("truncating input prompt", "limit", s.cache.numCtx, "prompt", len(inputs), "keep", params.numKeep, "new", len(newInputs))
newInputs = append(newInputs, inputs[len(inputs)-s.cache.numCtx+params.numKeep:]...)
inputs = newInputs
}
@@ -164,16 +162,10 @@ func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequen
// generating image embeddings for each image
func (s *Server) inputs(prompt string, images []ImageData) ([]input, error) {
var inputs []input
var parts []string
var matches [][]string
if s.image != nil {
re := regexp.MustCompile(`\[img-(\d+)\]`)
parts = re.Split(prompt, -1)
matches = re.FindAllStringSubmatch(prompt, -1)
} else {
parts = []string{prompt}
}
re := regexp.MustCompile(`\[img-(\d+)\]`)
parts := re.Split(prompt, -1)
matches := re.FindAllStringSubmatch(prompt, -1)
for i, part := range parts {
// text - tokenize

View File

@@ -571,7 +571,7 @@ type EmbedWriter struct {
model string
}
func (w *BaseWriter) writeError(data []byte) (int, error) {
func (w *BaseWriter) writeError(code int, data []byte) (int, error) {
var serr api.StatusError
err := json.Unmarshal(data, &serr)
if err != nil {
@@ -630,7 +630,7 @@ func (w *ChatWriter) writeResponse(data []byte) (int, error) {
func (w *ChatWriter) Write(data []byte) (int, error) {
code := w.ResponseWriter.Status()
if code != http.StatusOK {
return w.writeError(data)
return w.writeError(code, data)
}
return w.writeResponse(data)
@@ -679,7 +679,7 @@ func (w *CompleteWriter) writeResponse(data []byte) (int, error) {
func (w *CompleteWriter) Write(data []byte) (int, error) {
code := w.ResponseWriter.Status()
if code != http.StatusOK {
return w.writeError(data)
return w.writeError(code, data)
}
return w.writeResponse(data)
@@ -704,7 +704,7 @@ func (w *ListWriter) writeResponse(data []byte) (int, error) {
func (w *ListWriter) Write(data []byte) (int, error) {
code := w.ResponseWriter.Status()
if code != http.StatusOK {
return w.writeError(data)
return w.writeError(code, data)
}
return w.writeResponse(data)
@@ -730,7 +730,7 @@ func (w *RetrieveWriter) writeResponse(data []byte) (int, error) {
func (w *RetrieveWriter) Write(data []byte) (int, error) {
code := w.ResponseWriter.Status()
if code != http.StatusOK {
return w.writeError(data)
return w.writeError(code, data)
}
return w.writeResponse(data)
@@ -755,7 +755,7 @@ func (w *EmbedWriter) writeResponse(data []byte) (int, error) {
func (w *EmbedWriter) Write(data []byte) (int, error) {
code := w.ResponseWriter.Status()
if code != http.StatusOK {
return w.writeError(data)
return w.writeError(code, data)
}
return w.writeResponse(data)

View File

@@ -23,7 +23,6 @@ import (
"strings"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/auth"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/format"
"github.com/ollama/ollama/llama"
@@ -31,7 +30,6 @@ import (
"github.com/ollama/ollama/parser"
"github.com/ollama/ollama/template"
"github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/types/registry"
"github.com/ollama/ollama/version"
)
@@ -982,6 +980,8 @@ func GetSHA256Digest(r io.Reader) (string, int64) {
return fmt.Sprintf("sha256:%x", h.Sum(nil)), n
}
var errUnauthorized = errors.New("unauthorized: access denied")
func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.URL, headers http.Header, body io.ReadSeeker, regOpts *registryOptions) (*http.Response, error) {
for range 2 {
resp, err := makeRequest(ctx, method, requestURL, headers, body, regOpts)
@@ -1019,33 +1019,13 @@ func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.UR
if err != nil {
return nil, fmt.Errorf("%d: %s", resp.StatusCode, err)
}
var re registry.Errs
if err := json.Unmarshal(responseBody, &re); err == nil && len(re.Errors) > 0 {
if re.HasCode(registry.ErrCodeAnonymous) {
// if the error is due to anonymous access return a custom error
// this error is used by the CLI to direct a user to add their key to an account
pubKey, nestedErr := auth.GetPublicKey()
if nestedErr != nil {
slog.Error(fmt.Sprintf("couldn't get public key: %v", nestedErr))
return nil, re
}
return nil, api.ErrUnknownOllamaKey{
Key: pubKey,
}
}
return nil, re
}
// Fallback to returning the raw response if parsing fails
return nil, fmt.Errorf("%d: %s", resp.StatusCode, responseBody)
default:
return resp, nil
}
}
// should never be reached
return nil, fmt.Errorf("failed to make upload request")
return nil, errUnauthorized
}
// testMakeRequestDialContext specifies the dial function for the http client in
@@ -1096,15 +1076,18 @@ func makeRequest(ctx context.Context, method string, requestURL *url.URL, header
req.ContentLength = contentLength
}
c := &http.Client{
resp, err := (&http.Client{
Transport: &http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: testMakeRequestDialContext,
},
CheckRedirect: regOpts.CheckRedirect,
}).Do(req)
if err != nil {
return nil, err
}
if testMakeRequestDialContext != nil {
tr := http.DefaultTransport.(*http.Transport).Clone()
tr.DialContext = testMakeRequestDialContext
c.Transport = tr
}
return c.Do(req)
return resp, nil
}
func getValue(header, key string) string {

View File

@@ -36,6 +36,7 @@ import (
"github.com/ollama/ollama/runners"
"github.com/ollama/ollama/server/imageproc"
"github.com/ollama/ollama/template"
"github.com/ollama/ollama/types/errtypes"
"github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/version"
)
@@ -609,7 +610,7 @@ func (s *Server) PushHandler(c *gin.Context) {
defer cancel()
if err := PushModel(ctx, model, regOpts, fn); err != nil {
ch <- newErr(err)
ch <- gin.H{"error": err.Error()}
}
}()
@@ -649,7 +650,7 @@ func (s *Server) CreateHandler(c *gin.Context) {
name := model.ParseName(cmp.Or(r.Model, r.Name))
if !name.IsValid() {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": api.InvalidModelNameErrMsg})
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": errtypes.InvalidModelNameErrMsg})
return
}
@@ -1549,24 +1550,3 @@ func handleScheduleError(c *gin.Context, name string, err error) {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
}
}
// newErr creates a structured API ErrorResponse from an existing error
func newErr(err error) api.ErrorResponse {
if err == nil {
return api.ErrorResponse{}
}
// Default to just returning the generic error message
resp := api.ErrorResponse{
Code: api.ErrCodeGeneral,
Message: err.Error(),
}
// Add additional error specific data, if any
var keyErr api.ErrUnknownOllamaKey
if errors.As(err, &keyErr) {
resp.Code = api.ErrCodeUnknownKey
resp.Data = map[string]any{
"key": keyErr.Key,
}
}
return resp
}

View File

@@ -0,0 +1,21 @@
// Package errtypes contains custom error types
package errtypes
import (
"fmt"
"strings"
)
const (
UnknownOllamaKeyErrMsg = "unknown ollama key"
InvalidModelNameErrMsg = "invalid model name"
)
// TODO: This should have a structured response from the API
type UnknownOllamaKey struct {
Key string
}
func (e *UnknownOllamaKey) Error() string {
return fmt.Sprintf("unauthorized: %s %q", UnknownOllamaKeyErrMsg, strings.TrimSpace(e.Key))
}

View File

@@ -1,37 +0,0 @@
package registry
import (
"fmt"
"slices"
"strings"
)
const ErrCodeAnonymous = "ANONYMOUS_ACCESS_DENIED"
type Err struct {
Code string `json:"code"`
Message string `json:"message"`
}
// Errs represents the structure of error responses from the registry
// TODO (brucemacd): this struct should be imported from some shared package that is used between the registry and ollama
type Errs struct {
Errors []Err `json:"errors"`
}
func (e Errs) Error() string {
if len(e.Errors) == 0 {
return "unknown registry error"
}
var msgs []string
for _, err := range e.Errors {
msgs = append(msgs, fmt.Sprintf("%s: %s", err.Code, err.Message))
}
return strings.Join(msgs, "; ")
}
func (e Errs) HasCode(code string) bool {
return slices.ContainsFunc(e.Errors, func(err Err) bool {
return err.Code == code
})
}