Compare commits
1 Commits
brucemacd/
...
jmorganca/
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
bc40a820f0 |
@@ -504,7 +504,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
|
||||
- [Terraform AWS Ollama & Open WebUI](https://github.com/xuyangbocn/terraform-aws-self-host-llm) (A Terraform module to deploy on AWS a ready-to-use Ollama service, together with its front end Open WebUI service.)
|
||||
- [node-red-contrib-ollama](https://github.com/jakubburkiewicz/node-red-contrib-ollama)
|
||||
- [Local AI Helper](https://github.com/ivostoykov/localAI) (Chrome and Firefox extensions that enable interactions with the active tab and customisable API endpoints. Includes secure storage for user prompts.)
|
||||
- [vnc-lm](https://github.com/jake83741/vnc-lm) (Discord bot for messaging with LLMs through Ollama and LiteLLM. Seamlessly move between local and flagship models.)
|
||||
- [vnc-lm](https://github.com/jk011ru/vnc-lm) (A containerized Discord bot with support for attachments and web links)
|
||||
- [LSP-AI](https://github.com/SilasMarvin/lsp-ai) (Open-source language server for AI-powered functionality)
|
||||
- [QodeAssist](https://github.com/Palm1r/QodeAssist) (AI-powered coding assistant plugin for Qt Creator)
|
||||
- [Obsidian Quiz Generator plugin](https://github.com/ECuiDev/obsidian-quiz-generator)
|
||||
@@ -518,4 +518,3 @@ See the [API documentation](./docs/api.md) for all endpoints.
|
||||
### Observability
|
||||
|
||||
- [OpenLIT](https://github.com/openlit/openlit) is an OpenTelemetry-native tool for monitoring Ollama Applications & GPUs using traces and metrics.
|
||||
- [HoneyHive](https://docs.honeyhive.ai/integrations/ollama) is an AI observability and evaluation platform for AI agents. Use HoneyHive to evaluate agent performance, interrogate failures, and monitor quality in production.
|
||||
|
||||
@@ -163,29 +163,24 @@ func (c *Client) stream(ctx context.Context, method, path string, data any, fn f
|
||||
scanBuf := make([]byte, 0, maxBufferSize)
|
||||
scanner.Buffer(scanBuf, maxBufferSize)
|
||||
for scanner.Scan() {
|
||||
bts := scanner.Bytes()
|
||||
var errorResponse struct {
|
||||
Error string `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
var errorResponse ErrorResponse
|
||||
bts := scanner.Bytes()
|
||||
if err := json.Unmarshal(bts, &errorResponse); err != nil {
|
||||
return fmt.Errorf("unmarshal: %w", err)
|
||||
}
|
||||
|
||||
switch errorResponse.Code {
|
||||
case ErrCodeUnknownKey:
|
||||
return ErrUnknownOllamaKey{
|
||||
Message: errorResponse.Message,
|
||||
Key: errorResponse.Data["key"].(string),
|
||||
}
|
||||
}
|
||||
if errorResponse.Message != "" {
|
||||
return errors.New(errorResponse.Message)
|
||||
if errorResponse.Error != "" {
|
||||
return errors.New(errorResponse.Error)
|
||||
}
|
||||
|
||||
if response.StatusCode >= http.StatusBadRequest {
|
||||
return StatusError{
|
||||
StatusCode: response.StatusCode,
|
||||
Status: response.Status,
|
||||
ErrorMessage: errorResponse.Message,
|
||||
ErrorMessage: errorResponse.Error,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,12 +1,6 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"testing"
|
||||
)
|
||||
|
||||
@@ -49,117 +43,3 @@ func TestClientFromEnvironment(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStream(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
serverResponse []string
|
||||
statusCode int
|
||||
expectedError error
|
||||
}{
|
||||
{
|
||||
name: "unknown key error",
|
||||
serverResponse: []string{
|
||||
`{"error":"unauthorized access","code":"unknown_key","data":{"key":"test-key"}}`,
|
||||
},
|
||||
statusCode: http.StatusUnauthorized,
|
||||
expectedError: &ErrUnknownOllamaKey{
|
||||
Message: "unauthorized access",
|
||||
Key: "test-key",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "general error message",
|
||||
serverResponse: []string{
|
||||
`{"error":"something went wrong"}`,
|
||||
},
|
||||
statusCode: http.StatusInternalServerError,
|
||||
expectedError: fmt.Errorf("something went wrong"),
|
||||
},
|
||||
{
|
||||
name: "malformed json response",
|
||||
serverResponse: []string{
|
||||
`{invalid-json`,
|
||||
},
|
||||
statusCode: http.StatusOK,
|
||||
expectedError: fmt.Errorf("unmarshal: invalid character 'i' looking for beginning of object key string"),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/x-ndjson")
|
||||
w.WriteHeader(tt.statusCode)
|
||||
for _, resp := range tt.serverResponse {
|
||||
fmt.Fprintln(w, resp)
|
||||
}
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
baseURL, err := url.Parse(server.URL)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to parse server URL: %v", err)
|
||||
}
|
||||
|
||||
client := &Client{
|
||||
http: server.Client(),
|
||||
base: baseURL,
|
||||
}
|
||||
|
||||
var responses [][]byte
|
||||
err = client.stream(context.Background(), "POST", "/test", "test", func(bts []byte) error {
|
||||
responses = append(responses, bts)
|
||||
return nil
|
||||
})
|
||||
|
||||
// Error checking
|
||||
if tt.expectedError == nil {
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if err == nil {
|
||||
t.Fatalf("expected error %v, got nil", tt.expectedError)
|
||||
}
|
||||
|
||||
// Check for specific error types
|
||||
var unknownKeyErr ErrUnknownOllamaKey
|
||||
if errors.As(tt.expectedError, &unknownKeyErr) {
|
||||
var gotErr ErrUnknownOllamaKey
|
||||
if !errors.As(err, &gotErr) {
|
||||
t.Fatalf("expected ErrUnknownOllamaKey, got %T", err)
|
||||
}
|
||||
if unknownKeyErr.Key != gotErr.Key {
|
||||
t.Errorf("expected key %q, got %q", unknownKeyErr.Key, gotErr.Key)
|
||||
}
|
||||
if unknownKeyErr.Message != gotErr.Message {
|
||||
t.Errorf("expected message %q, got %q", unknownKeyErr.Message, gotErr.Message)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
var statusErr StatusError
|
||||
if errors.As(tt.expectedError, &statusErr) {
|
||||
var gotErr StatusError
|
||||
if !errors.As(err, &gotErr) {
|
||||
t.Fatalf("expected StatusError, got %T", err)
|
||||
}
|
||||
if statusErr.StatusCode != gotErr.StatusCode {
|
||||
t.Errorf("expected status code %d, got %d", statusErr.StatusCode, gotErr.StatusCode)
|
||||
}
|
||||
if statusErr.ErrorMessage != gotErr.ErrorMessage {
|
||||
t.Errorf("expected error message %q, got %q", statusErr.ErrorMessage, gotErr.ErrorMessage)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// For other errors, compare error strings
|
||||
if err.Error() != tt.expectedError.Error() {
|
||||
t.Errorf("expected error %q, got %q", tt.expectedError, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,74 +0,0 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"slices"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const InvalidModelNameErrMsg = "invalid model name"
|
||||
|
||||
// API error responses
|
||||
// ErrorCode represents a standardized error code identifier
|
||||
type ErrorCode string
|
||||
|
||||
const (
|
||||
ErrCodeUnknownKey ErrorCode = "unknown_key"
|
||||
ErrCodeGeneral ErrorCode = "general" // Generic fallback error code
|
||||
)
|
||||
|
||||
// ErrorResponse implements a structured error interface
|
||||
type ErrorResponse struct {
|
||||
Message string `json:"error"` // Human-readable error message, uses 'error' field name for backwards compatibility
|
||||
Code ErrorCode `json:"code"` // Machine-readable error code for programmatic handling, not response code
|
||||
Data map[string]any `json:"data"` // Additional error specific data, if any
|
||||
}
|
||||
|
||||
func (e ErrorResponse) Error() string {
|
||||
return e.Message
|
||||
}
|
||||
|
||||
type ErrUnknownOllamaKey struct {
|
||||
Message string
|
||||
Key string
|
||||
}
|
||||
|
||||
func (e ErrUnknownOllamaKey) Error() string {
|
||||
return fmt.Sprintf("unauthorized: unknown ollama key %q", strings.TrimSpace(e.Key))
|
||||
}
|
||||
|
||||
func (e *ErrUnknownOllamaKey) FormatUserMessage(localKeys []string) string {
|
||||
// The user should only be told to add the key if it is the same one that exists locally
|
||||
if slices.Index(localKeys, e.Key) == -1 {
|
||||
return e.Message
|
||||
}
|
||||
|
||||
return fmt.Sprintf(`%s
|
||||
|
||||
Your ollama key is:
|
||||
%s
|
||||
Add your key at:
|
||||
https://ollama.com/settings/keys`, e.Message, e.Key)
|
||||
}
|
||||
|
||||
// StatusError is an error with an HTTP status code and message,
|
||||
// it is parsed on the client-side and not returned from the API
|
||||
type StatusError struct {
|
||||
StatusCode int // e.g. 200
|
||||
Status string // e.g. "200 OK"
|
||||
ErrorMessage string `json:"error"`
|
||||
}
|
||||
|
||||
func (e StatusError) Error() string {
|
||||
switch {
|
||||
case e.Status != "" && e.ErrorMessage != "":
|
||||
return fmt.Sprintf("%s: %s", e.Status, e.ErrorMessage)
|
||||
case e.Status != "":
|
||||
return e.Status
|
||||
case e.ErrorMessage != "":
|
||||
return e.ErrorMessage
|
||||
default:
|
||||
// this should not happen
|
||||
return "something went wrong, please see the ollama server logs for details"
|
||||
}
|
||||
}
|
||||
21
api/types.go
21
api/types.go
@@ -12,6 +12,27 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// StatusError is an error with an HTTP status code and message.
|
||||
type StatusError struct {
|
||||
StatusCode int
|
||||
Status string
|
||||
ErrorMessage string `json:"error"`
|
||||
}
|
||||
|
||||
func (e StatusError) Error() string {
|
||||
switch {
|
||||
case e.Status != "" && e.ErrorMessage != "":
|
||||
return fmt.Sprintf("%s: %s", e.Status, e.ErrorMessage)
|
||||
case e.Status != "":
|
||||
return e.Status
|
||||
case e.ErrorMessage != "":
|
||||
return e.ErrorMessage
|
||||
default:
|
||||
// this should not happen
|
||||
return "something went wrong, please see the ollama server logs for details"
|
||||
}
|
||||
}
|
||||
|
||||
// ImageData represents the raw binary data of an image file.
|
||||
type ImageData []byte
|
||||
|
||||
|
||||
41
cmd/cmd.go
41
cmd/cmd.go
@@ -34,13 +34,11 @@ import (
|
||||
"golang.org/x/term"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/auth"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"github.com/ollama/ollama/format"
|
||||
"github.com/ollama/ollama/parser"
|
||||
"github.com/ollama/ollama/progress"
|
||||
"github.com/ollama/ollama/server"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
"github.com/ollama/ollama/version"
|
||||
)
|
||||
|
||||
@@ -514,24 +512,6 @@ func RunHandler(cmd *cobra.Command, args []string) error {
|
||||
return generate(cmd, opts)
|
||||
}
|
||||
|
||||
func localPubKeys() ([]string, error) {
|
||||
usrKey, err := auth.GetPublicKey()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
keys := []string{usrKey}
|
||||
|
||||
if runtime.GOOS == "linux" {
|
||||
// try the ollama service public key if on Linux
|
||||
if svcKey, err := os.ReadFile("/usr/share/ollama/.ollama/id_ed25519.pub"); err == nil {
|
||||
keys = append(keys, strings.TrimSpace(string(svcKey)))
|
||||
}
|
||||
}
|
||||
|
||||
return keys, nil
|
||||
}
|
||||
|
||||
func PushHandler(cmd *cobra.Command, args []string) error {
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
@@ -578,36 +558,17 @@ func PushHandler(cmd *cobra.Command, args []string) error {
|
||||
}
|
||||
|
||||
request := api.PushRequest{Name: args[0], Insecure: insecure}
|
||||
|
||||
n := model.ParseName(args[0])
|
||||
isOllamaHost := strings.HasSuffix(n.Host, ".ollama.ai") || strings.HasSuffix(n.Host, ".ollama.com")
|
||||
if err := client.Push(cmd.Context(), &request, fn); err != nil {
|
||||
if spinner != nil {
|
||||
spinner.Stop()
|
||||
}
|
||||
var ke api.ErrUnknownOllamaKey
|
||||
if errors.As(err, &ke) && isOllamaHost {
|
||||
// the user has not added their ollama key to ollama.com
|
||||
// return an error with a more user-friendly message
|
||||
locals, _ := localPubKeys()
|
||||
return errors.New(ke.FormatUserMessage(locals))
|
||||
}
|
||||
if strings.Contains(err.Error(), "access denied") {
|
||||
return errors.New("you are not authorized to push to this namespace, create the model under a namespace you own")
|
||||
}
|
||||
return fmt.Errorf("yoyoyo: %w", err)
|
||||
return err
|
||||
}
|
||||
|
||||
p.Stop()
|
||||
spinner.Stop()
|
||||
|
||||
destination := n.String()
|
||||
if isOllamaHost {
|
||||
destination = "https://ollama.com/" + strings.TrimSuffix(n.DisplayShortest(), ":latest")
|
||||
}
|
||||
fmt.Printf("\nYou can find your model at:\n\n")
|
||||
fmt.Printf("\t%s\n", destination)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
141
cmd/cmd_test.go
141
cmd/cmd_test.go
@@ -4,7 +4,6 @@ import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
@@ -370,143 +369,3 @@ func TestGetModelfileName(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPushHandler(t *testing.T) {
|
||||
tests := []struct {
|
||||
modelName string
|
||||
serverResponse map[string]func(w http.ResponseWriter, r *http.Request)
|
||||
expectedError string
|
||||
expectedOutput string
|
||||
}{
|
||||
{
|
||||
modelName: "successful-push",
|
||||
serverResponse: map[string]func(w http.ResponseWriter, r *http.Request){
|
||||
"/api/push": func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
t.Errorf("expected POST request, got %s", r.Method)
|
||||
}
|
||||
|
||||
var req api.PushRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if req.Name != "successful-push" {
|
||||
t.Errorf("expected model name 'successful-push', got %s", req.Name)
|
||||
}
|
||||
|
||||
// Simulate progress updates
|
||||
responses := []api.ProgressResponse{
|
||||
{Status: "preparing manifest"},
|
||||
{Digest: "sha256:abc123456789", Total: 100, Completed: 50},
|
||||
{Digest: "sha256:abc123456789", Total: 100, Completed: 100},
|
||||
}
|
||||
|
||||
for _, resp := range responses {
|
||||
if err := json.NewEncoder(w).Encode(resp); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
w.(http.Flusher).Flush()
|
||||
}
|
||||
},
|
||||
},
|
||||
expectedOutput: "\nYou can find your model at:\n\n\thttps://ollama.com/successful-push\n",
|
||||
},
|
||||
{
|
||||
modelName: "unauthorized-push",
|
||||
serverResponse: map[string]func(w http.ResponseWriter, r *http.Request){
|
||||
"/api/push": func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
err := json.NewEncoder(w).Encode(map[string]string{
|
||||
"error": "access denied",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
},
|
||||
expectedError: "you are not authorized to push to this namespace, create the model under a namespace you own",
|
||||
},
|
||||
{
|
||||
modelName: "unknown-key-err",
|
||||
serverResponse: map[string]func(w http.ResponseWriter, r *http.Request){
|
||||
"/api/push": func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
uerr := api.ErrUnknownOllamaKey{
|
||||
Key: "aaa",
|
||||
}
|
||||
err := json.NewEncoder(w).Encode(map[string]string{
|
||||
"error": uerr.Error(),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
},
|
||||
expectedError: "unauthorized: unknown ollama key \"aaa\"",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.modelName, func(t *testing.T) {
|
||||
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if handler, ok := tt.serverResponse[r.URL.Path]; ok {
|
||||
handler(w, r)
|
||||
return
|
||||
}
|
||||
http.Error(w, "not found", http.StatusNotFound)
|
||||
}))
|
||||
defer mockServer.Close()
|
||||
|
||||
t.Setenv("OLLAMA_HOST", mockServer.URL)
|
||||
|
||||
cmd := &cobra.Command{}
|
||||
cmd.Flags().Bool("insecure", false, "")
|
||||
cmd.SetContext(context.TODO())
|
||||
|
||||
// Redirect stderr to capture progress output
|
||||
oldStderr := os.Stderr
|
||||
r, w, _ := os.Pipe()
|
||||
os.Stderr = w
|
||||
|
||||
// Capture stdout for the "Model pushed" message
|
||||
oldStdout := os.Stdout
|
||||
outR, outW, _ := os.Pipe()
|
||||
os.Stdout = outW
|
||||
|
||||
err := PushHandler(cmd, []string{tt.modelName})
|
||||
|
||||
// Restore stderr
|
||||
w.Close()
|
||||
os.Stderr = oldStderr
|
||||
// drain the pipe
|
||||
if _, err := io.ReadAll(r); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Restore stdout and get output
|
||||
outW.Close()
|
||||
os.Stdout = oldStdout
|
||||
stdout, _ := io.ReadAll(outR)
|
||||
|
||||
if tt.expectedError == "" {
|
||||
if err != nil {
|
||||
t.Errorf("expected no error, got %v", err)
|
||||
}
|
||||
if tt.expectedOutput != "" {
|
||||
if got := string(stdout); got != tt.expectedOutput {
|
||||
t.Errorf("expected output %q, got %q", tt.expectedOutput, got)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if err == nil || !strings.Contains(err.Error(), tt.expectedError) {
|
||||
t.Errorf("expected error containing %q, got %v", tt.expectedError, err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -19,6 +19,7 @@ import (
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"github.com/ollama/ollama/parser"
|
||||
"github.com/ollama/ollama/readline"
|
||||
"github.com/ollama/ollama/types/errtypes"
|
||||
)
|
||||
|
||||
type MultilineState int
|
||||
@@ -219,7 +220,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
|
||||
fn := func(resp api.ProgressResponse) error { return nil }
|
||||
err = client.Create(cmd.Context(), req, fn)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), api.InvalidModelNameErrMsg) {
|
||||
if strings.Contains(err.Error(), errtypes.InvalidModelNameErrMsg) {
|
||||
fmt.Printf("error: The model name '%s' is invalid\n", args[1])
|
||||
continue
|
||||
}
|
||||
@@ -513,7 +514,7 @@ func extractFileNames(input string) []string {
|
||||
// Regex to match file paths starting with optional drive letter, / ./ \ or .\ and include escaped or unescaped spaces (\ or %20)
|
||||
// and followed by more characters and a file extension
|
||||
// This will capture non filename strings, but we'll check for file existence to remove mismatches
|
||||
regexPattern := `(?:[a-zA-Z]:)?(?:\./|/|\\)[\S\\ ]+?\.(?i:jpg|jpeg|png)\b`
|
||||
regexPattern := `(?:[a-zA-Z]:)?(?:\./|/|\\)[\S\\ ]+?\.(?i:jpg|jpeg|png|svg)\b`
|
||||
re := regexp.MustCompile(regexPattern)
|
||||
|
||||
return re.FindAllString(input, -1)
|
||||
|
||||
@@ -12,45 +12,44 @@ import (
|
||||
func TestExtractFilenames(t *testing.T) {
|
||||
// Unix style paths
|
||||
input := ` some preamble
|
||||
./relative\ path/one.png inbetween1 ./not a valid two.jpg inbetween2 ./1.svg
|
||||
/unescaped space /three.jpeg inbetween3 /valid\ path/dir/four.png "./quoted with spaces/five.JPG`
|
||||
./relative\ path/one.png inbetween1 ./not a valid two.jpg inbetween2
|
||||
/unescaped space /three.jpeg inbetween3 /valid\ path/dir/four.png "./quoted with spaces/five.svg`
|
||||
res := extractFileNames(input)
|
||||
assert.Len(t, res, 5)
|
||||
assert.Contains(t, res[0], "one.png")
|
||||
assert.Contains(t, res[1], "two.jpg")
|
||||
assert.Contains(t, res[2], "three.jpeg")
|
||||
assert.Contains(t, res[3], "four.png")
|
||||
assert.Contains(t, res[4], "five.JPG")
|
||||
assert.Contains(t, res[4], "five.svg")
|
||||
assert.NotContains(t, res[4], '"')
|
||||
assert.NotContains(t, res, "inbetween1")
|
||||
assert.NotContains(t, res, "./1.svg")
|
||||
assert.NotContains(t, res, "inbtween")
|
||||
|
||||
// Windows style paths
|
||||
input = ` some preamble
|
||||
c:/users/jdoe/one.png inbetween1 c:/program files/someplace/two.jpg inbetween2
|
||||
/absolute/nospace/three.jpeg inbetween3 /absolute/with space/four.png inbetween4
|
||||
./relative\ path/five.JPG inbetween5 "./relative with/spaces/six.png inbetween6
|
||||
d:\path with\spaces\seven.JPEG inbetween7 c:\users\jdoe\eight.png inbetween8
|
||||
d:\program files\someplace\nine.png inbetween9 "E:\program files\someplace\ten.PNG some ending
|
||||
./relative\ path/five.svg inbetween5 "./relative with/spaces/six.png inbetween6
|
||||
d:\path with\spaces\seven.svg inbetween7 c:\users\jdoe\eight.png inbetween8
|
||||
d:\program files\someplace\nine.png inbetween9 "E:\program files\someplace\ten.svg some ending
|
||||
`
|
||||
res = extractFileNames(input)
|
||||
assert.Len(t, res, 10)
|
||||
assert.NotContains(t, res, "inbetween2")
|
||||
assert.NotContains(t, res, "inbtween")
|
||||
assert.Contains(t, res[0], "one.png")
|
||||
assert.Contains(t, res[0], "c:")
|
||||
assert.Contains(t, res[1], "two.jpg")
|
||||
assert.Contains(t, res[1], "c:")
|
||||
assert.Contains(t, res[2], "three.jpeg")
|
||||
assert.Contains(t, res[3], "four.png")
|
||||
assert.Contains(t, res[4], "five.JPG")
|
||||
assert.Contains(t, res[4], "five.svg")
|
||||
assert.Contains(t, res[5], "six.png")
|
||||
assert.Contains(t, res[6], "seven.JPEG")
|
||||
assert.Contains(t, res[6], "seven.svg")
|
||||
assert.Contains(t, res[6], "d:")
|
||||
assert.Contains(t, res[7], "eight.png")
|
||||
assert.Contains(t, res[7], "c:")
|
||||
assert.Contains(t, res[8], "nine.png")
|
||||
assert.Contains(t, res[8], "d:")
|
||||
assert.Contains(t, res[9], "ten.PNG")
|
||||
assert.Contains(t, res[9], "ten.svg")
|
||||
assert.Contains(t, res[9], "E:")
|
||||
}
|
||||
|
||||
|
||||
@@ -199,20 +199,6 @@ func countCommonPrefix(a []input, b []input) int {
|
||||
return count
|
||||
}
|
||||
|
||||
func (c *InputCache) ShiftDiscard(inputLen int, numKeep int) int {
|
||||
targetFree := (c.numCtx - numKeep) / 2
|
||||
targetFree = max(targetFree, 1)
|
||||
|
||||
currentFree := c.numCtx - inputLen
|
||||
discard := targetFree - currentFree
|
||||
|
||||
if discard < 0 {
|
||||
discard = 0
|
||||
}
|
||||
|
||||
return discard
|
||||
}
|
||||
|
||||
// Frees up space in the KV cache by deleting the oldest half of history and shifting
|
||||
// the newest half into that space (saving numKeep inputs at the beginning).
|
||||
//
|
||||
@@ -222,7 +208,11 @@ func (c *InputCache) ShiftCacheSlot(slot *InputCacheSlot, numKeep int) error {
|
||||
return fmt.Errorf("unable to shift context - keep exceeds context (keep: %v context: %v)", numKeep, c.numCtx)
|
||||
}
|
||||
|
||||
discard := c.ShiftDiscard(len(slot.Inputs), numKeep)
|
||||
targetFree := (c.numCtx - numKeep) / 2
|
||||
targetFree = max(targetFree, 1)
|
||||
|
||||
currentFree := c.numCtx - len(slot.Inputs)
|
||||
discard := targetFree - currentFree
|
||||
|
||||
if discard <= 0 {
|
||||
return nil
|
||||
|
||||
@@ -227,66 +227,3 @@ func TestFindCacheSlot(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestShiftDiscard(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
numCtx int
|
||||
numKeep int
|
||||
inputLen int
|
||||
expected int
|
||||
}{
|
||||
{
|
||||
name: "Shift",
|
||||
numCtx: 2048,
|
||||
numKeep: 5,
|
||||
inputLen: 2048,
|
||||
expected: 1021,
|
||||
},
|
||||
{
|
||||
name: "Max Keep",
|
||||
numCtx: 2048,
|
||||
numKeep: 2047,
|
||||
inputLen: 2048,
|
||||
expected: 1,
|
||||
},
|
||||
{
|
||||
name: "No Keep",
|
||||
numCtx: 2048,
|
||||
numKeep: 0,
|
||||
inputLen: 2048,
|
||||
expected: 1024,
|
||||
},
|
||||
{
|
||||
name: "Truncate",
|
||||
numCtx: 2048,
|
||||
numKeep: 5,
|
||||
inputLen: 5000,
|
||||
expected: 3973,
|
||||
},
|
||||
{
|
||||
name: "Truncate Keep",
|
||||
numCtx: 2048,
|
||||
numKeep: 2047,
|
||||
inputLen: 5000,
|
||||
expected: 2953,
|
||||
},
|
||||
{
|
||||
name: "No Op",
|
||||
numCtx: 2048,
|
||||
numKeep: 5,
|
||||
inputLen: 512,
|
||||
expected: 0,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
c := InputCache{numCtx: tt.numCtx}
|
||||
result := c.ShiftDiscard(tt.inputLen, tt.numKeep)
|
||||
if result != tt.expected {
|
||||
t.Errorf("shiftDiscard(ctx: %v, keep: %v input: %v): have %v; want %v", tt.numCtx, tt.numKeep, tt.inputLen, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -122,11 +122,9 @@ func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequen
|
||||
params.numKeep = min(params.numKeep, s.cache.numCtx-1)
|
||||
|
||||
if len(inputs) > s.cache.numCtx {
|
||||
discard := len(inputs) - s.cache.numCtx
|
||||
slog.Warn("truncating input prompt", "limit", s.cache.numCtx, "prompt", len(inputs), "numKeep", params.numKeep)
|
||||
newInputs := inputs[:params.numKeep]
|
||||
newInputs = append(newInputs, inputs[params.numKeep+discard:]...)
|
||||
|
||||
slog.Warn("truncating input prompt", "limit", s.cache.numCtx, "prompt", len(inputs), "keep", params.numKeep, "new", len(newInputs))
|
||||
newInputs = append(newInputs, inputs[len(inputs)-s.cache.numCtx+params.numKeep:]...)
|
||||
inputs = newInputs
|
||||
}
|
||||
|
||||
@@ -164,16 +162,10 @@ func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequen
|
||||
// generating image embeddings for each image
|
||||
func (s *Server) inputs(prompt string, images []ImageData) ([]input, error) {
|
||||
var inputs []input
|
||||
var parts []string
|
||||
var matches [][]string
|
||||
|
||||
if s.image != nil {
|
||||
re := regexp.MustCompile(`\[img-(\d+)\]`)
|
||||
parts = re.Split(prompt, -1)
|
||||
matches = re.FindAllStringSubmatch(prompt, -1)
|
||||
} else {
|
||||
parts = []string{prompt}
|
||||
}
|
||||
re := regexp.MustCompile(`\[img-(\d+)\]`)
|
||||
parts := re.Split(prompt, -1)
|
||||
matches := re.FindAllStringSubmatch(prompt, -1)
|
||||
|
||||
for i, part := range parts {
|
||||
// text - tokenize
|
||||
|
||||
@@ -571,7 +571,7 @@ type EmbedWriter struct {
|
||||
model string
|
||||
}
|
||||
|
||||
func (w *BaseWriter) writeError(data []byte) (int, error) {
|
||||
func (w *BaseWriter) writeError(code int, data []byte) (int, error) {
|
||||
var serr api.StatusError
|
||||
err := json.Unmarshal(data, &serr)
|
||||
if err != nil {
|
||||
@@ -630,7 +630,7 @@ func (w *ChatWriter) writeResponse(data []byte) (int, error) {
|
||||
func (w *ChatWriter) Write(data []byte) (int, error) {
|
||||
code := w.ResponseWriter.Status()
|
||||
if code != http.StatusOK {
|
||||
return w.writeError(data)
|
||||
return w.writeError(code, data)
|
||||
}
|
||||
|
||||
return w.writeResponse(data)
|
||||
@@ -679,7 +679,7 @@ func (w *CompleteWriter) writeResponse(data []byte) (int, error) {
|
||||
func (w *CompleteWriter) Write(data []byte) (int, error) {
|
||||
code := w.ResponseWriter.Status()
|
||||
if code != http.StatusOK {
|
||||
return w.writeError(data)
|
||||
return w.writeError(code, data)
|
||||
}
|
||||
|
||||
return w.writeResponse(data)
|
||||
@@ -704,7 +704,7 @@ func (w *ListWriter) writeResponse(data []byte) (int, error) {
|
||||
func (w *ListWriter) Write(data []byte) (int, error) {
|
||||
code := w.ResponseWriter.Status()
|
||||
if code != http.StatusOK {
|
||||
return w.writeError(data)
|
||||
return w.writeError(code, data)
|
||||
}
|
||||
|
||||
return w.writeResponse(data)
|
||||
@@ -730,7 +730,7 @@ func (w *RetrieveWriter) writeResponse(data []byte) (int, error) {
|
||||
func (w *RetrieveWriter) Write(data []byte) (int, error) {
|
||||
code := w.ResponseWriter.Status()
|
||||
if code != http.StatusOK {
|
||||
return w.writeError(data)
|
||||
return w.writeError(code, data)
|
||||
}
|
||||
|
||||
return w.writeResponse(data)
|
||||
@@ -755,7 +755,7 @@ func (w *EmbedWriter) writeResponse(data []byte) (int, error) {
|
||||
func (w *EmbedWriter) Write(data []byte) (int, error) {
|
||||
code := w.ResponseWriter.Status()
|
||||
if code != http.StatusOK {
|
||||
return w.writeError(data)
|
||||
return w.writeError(code, data)
|
||||
}
|
||||
|
||||
return w.writeResponse(data)
|
||||
|
||||
@@ -23,7 +23,6 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/auth"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"github.com/ollama/ollama/format"
|
||||
"github.com/ollama/ollama/llama"
|
||||
@@ -31,7 +30,6 @@ import (
|
||||
"github.com/ollama/ollama/parser"
|
||||
"github.com/ollama/ollama/template"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
"github.com/ollama/ollama/types/registry"
|
||||
"github.com/ollama/ollama/version"
|
||||
)
|
||||
|
||||
@@ -982,6 +980,8 @@ func GetSHA256Digest(r io.Reader) (string, int64) {
|
||||
return fmt.Sprintf("sha256:%x", h.Sum(nil)), n
|
||||
}
|
||||
|
||||
var errUnauthorized = errors.New("unauthorized: access denied")
|
||||
|
||||
func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.URL, headers http.Header, body io.ReadSeeker, regOpts *registryOptions) (*http.Response, error) {
|
||||
for range 2 {
|
||||
resp, err := makeRequest(ctx, method, requestURL, headers, body, regOpts)
|
||||
@@ -1019,33 +1019,13 @@ func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.UR
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%d: %s", resp.StatusCode, err)
|
||||
}
|
||||
|
||||
var re registry.Errs
|
||||
if err := json.Unmarshal(responseBody, &re); err == nil && len(re.Errors) > 0 {
|
||||
if re.HasCode(registry.ErrCodeAnonymous) {
|
||||
// if the error is due to anonymous access return a custom error
|
||||
// this error is used by the CLI to direct a user to add their key to an account
|
||||
pubKey, nestedErr := auth.GetPublicKey()
|
||||
if nestedErr != nil {
|
||||
slog.Error(fmt.Sprintf("couldn't get public key: %v", nestedErr))
|
||||
return nil, re
|
||||
}
|
||||
return nil, api.ErrUnknownOllamaKey{
|
||||
Key: pubKey,
|
||||
}
|
||||
}
|
||||
return nil, re
|
||||
}
|
||||
|
||||
// Fallback to returning the raw response if parsing fails
|
||||
return nil, fmt.Errorf("%d: %s", resp.StatusCode, responseBody)
|
||||
default:
|
||||
return resp, nil
|
||||
}
|
||||
}
|
||||
|
||||
// should never be reached
|
||||
return nil, fmt.Errorf("failed to make upload request")
|
||||
return nil, errUnauthorized
|
||||
}
|
||||
|
||||
// testMakeRequestDialContext specifies the dial function for the http client in
|
||||
@@ -1096,15 +1076,18 @@ func makeRequest(ctx context.Context, method string, requestURL *url.URL, header
|
||||
req.ContentLength = contentLength
|
||||
}
|
||||
|
||||
c := &http.Client{
|
||||
resp, err := (&http.Client{
|
||||
Transport: &http.Transport{
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
DialContext: testMakeRequestDialContext,
|
||||
},
|
||||
CheckRedirect: regOpts.CheckRedirect,
|
||||
}).Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if testMakeRequestDialContext != nil {
|
||||
tr := http.DefaultTransport.(*http.Transport).Clone()
|
||||
tr.DialContext = testMakeRequestDialContext
|
||||
c.Transport = tr
|
||||
}
|
||||
return c.Do(req)
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func getValue(header, key string) string {
|
||||
|
||||
@@ -36,6 +36,7 @@ import (
|
||||
"github.com/ollama/ollama/runners"
|
||||
"github.com/ollama/ollama/server/imageproc"
|
||||
"github.com/ollama/ollama/template"
|
||||
"github.com/ollama/ollama/types/errtypes"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
"github.com/ollama/ollama/version"
|
||||
)
|
||||
@@ -609,7 +610,7 @@ func (s *Server) PushHandler(c *gin.Context) {
|
||||
defer cancel()
|
||||
|
||||
if err := PushModel(ctx, model, regOpts, fn); err != nil {
|
||||
ch <- newErr(err)
|
||||
ch <- gin.H{"error": err.Error()}
|
||||
}
|
||||
}()
|
||||
|
||||
@@ -649,7 +650,7 @@ func (s *Server) CreateHandler(c *gin.Context) {
|
||||
|
||||
name := model.ParseName(cmp.Or(r.Model, r.Name))
|
||||
if !name.IsValid() {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": api.InvalidModelNameErrMsg})
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": errtypes.InvalidModelNameErrMsg})
|
||||
return
|
||||
}
|
||||
|
||||
@@ -1549,24 +1550,3 @@ func handleScheduleError(c *gin.Context, name string, err error) {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
}
|
||||
}
|
||||
|
||||
// newErr creates a structured API ErrorResponse from an existing error
|
||||
func newErr(err error) api.ErrorResponse {
|
||||
if err == nil {
|
||||
return api.ErrorResponse{}
|
||||
}
|
||||
// Default to just returning the generic error message
|
||||
resp := api.ErrorResponse{
|
||||
Code: api.ErrCodeGeneral,
|
||||
Message: err.Error(),
|
||||
}
|
||||
// Add additional error specific data, if any
|
||||
var keyErr api.ErrUnknownOllamaKey
|
||||
if errors.As(err, &keyErr) {
|
||||
resp.Code = api.ErrCodeUnknownKey
|
||||
resp.Data = map[string]any{
|
||||
"key": keyErr.Key,
|
||||
}
|
||||
}
|
||||
return resp
|
||||
}
|
||||
|
||||
21
types/errtypes/errtypes.go
Normal file
21
types/errtypes/errtypes.go
Normal file
@@ -0,0 +1,21 @@
|
||||
// Package errtypes contains custom error types
|
||||
package errtypes
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
UnknownOllamaKeyErrMsg = "unknown ollama key"
|
||||
InvalidModelNameErrMsg = "invalid model name"
|
||||
)
|
||||
|
||||
// TODO: This should have a structured response from the API
|
||||
type UnknownOllamaKey struct {
|
||||
Key string
|
||||
}
|
||||
|
||||
func (e *UnknownOllamaKey) Error() string {
|
||||
return fmt.Sprintf("unauthorized: %s %q", UnknownOllamaKeyErrMsg, strings.TrimSpace(e.Key))
|
||||
}
|
||||
@@ -1,37 +0,0 @@
|
||||
package registry
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"slices"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const ErrCodeAnonymous = "ANONYMOUS_ACCESS_DENIED"
|
||||
|
||||
type Err struct {
|
||||
Code string `json:"code"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
// Errs represents the structure of error responses from the registry
|
||||
// TODO (brucemacd): this struct should be imported from some shared package that is used between the registry and ollama
|
||||
type Errs struct {
|
||||
Errors []Err `json:"errors"`
|
||||
}
|
||||
|
||||
func (e Errs) Error() string {
|
||||
if len(e.Errors) == 0 {
|
||||
return "unknown registry error"
|
||||
}
|
||||
var msgs []string
|
||||
for _, err := range e.Errors {
|
||||
msgs = append(msgs, fmt.Sprintf("%s: %s", err.Code, err.Message))
|
||||
}
|
||||
return strings.Join(msgs, "; ")
|
||||
}
|
||||
|
||||
func (e Errs) HasCode(code string) bool {
|
||||
return slices.ContainsFunc(e.Errors, func(err Err) bool {
|
||||
return err.Code == code
|
||||
})
|
||||
}
|
||||
Reference in New Issue
Block a user