Compare commits

..

2 Commits

Author SHA1 Message Date
Bruce MacDonald
e19c64e047 lint fix 2024-12-12 13:07:09 -08:00
Bruce MacDonald
9e190ac4d9 api: return structured error on unauthorized push
This commit implements a structured error response system for the Ollama API, replacing
ad-hoc error handling and string parsing with proper error types and codes through a new
ErrorResponse struct. Instead of relying on regex to parse error messages for SSH keys,
the API now passes this data in a structured format with standardized fields for error
messages, codes, and additional data. This structured approach makes the API more
maintainable and reliable while improving the developer experience by enabling
programmatic error handling, consistent error formats, and better error
documentation.
2024-12-12 11:49:36 -08:00
12 changed files with 249 additions and 136 deletions

View File

@@ -163,24 +163,29 @@ func (c *Client) stream(ctx context.Context, method, path string, data any, fn f
scanBuf := make([]byte, 0, maxBufferSize)
scanner.Buffer(scanBuf, maxBufferSize)
for scanner.Scan() {
var errorResponse struct {
Error string `json:"error,omitempty"`
}
bts := scanner.Bytes()
var errorResponse ErrorResponse
if err := json.Unmarshal(bts, &errorResponse); err != nil {
return fmt.Errorf("unmarshal: %w", err)
}
if errorResponse.Error != "" {
return errors.New(errorResponse.Error)
switch errorResponse.Code {
case ErrCodeUnknownKey:
return ErrUnknownOllamaKey{
Message: errorResponse.Message,
Key: errorResponse.Data["key"].(string),
}
}
if errorResponse.Message != "" {
return errors.New(errorResponse.Message)
}
if response.StatusCode >= http.StatusBadRequest {
return StatusError{
StatusCode: response.StatusCode,
Status: response.Status,
ErrorMessage: errorResponse.Error,
ErrorMessage: errorResponse.Message,
}
}

View File

@@ -1,6 +1,12 @@
package api
import (
"context"
"errors"
"fmt"
"net/http"
"net/http/httptest"
"net/url"
"testing"
)
@@ -43,3 +49,117 @@ func TestClientFromEnvironment(t *testing.T) {
})
}
}
func TestStream(t *testing.T) {
tests := []struct {
name string
serverResponse []string
statusCode int
expectedError error
}{
{
name: "unknown key error",
serverResponse: []string{
`{"error":"unauthorized access","code":"unknown_key","data":{"key":"test-key"}}`,
},
statusCode: http.StatusUnauthorized,
expectedError: &ErrUnknownOllamaKey{
Message: "unauthorized access",
Key: "test-key",
},
},
{
name: "general error message",
serverResponse: []string{
`{"error":"something went wrong"}`,
},
statusCode: http.StatusInternalServerError,
expectedError: fmt.Errorf("something went wrong"),
},
{
name: "malformed json response",
serverResponse: []string{
`{invalid-json`,
},
statusCode: http.StatusOK,
expectedError: fmt.Errorf("unmarshal: invalid character 'i' looking for beginning of object key string"),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/x-ndjson")
w.WriteHeader(tt.statusCode)
for _, resp := range tt.serverResponse {
fmt.Fprintln(w, resp)
}
}))
defer server.Close()
baseURL, err := url.Parse(server.URL)
if err != nil {
t.Fatalf("failed to parse server URL: %v", err)
}
client := &Client{
http: server.Client(),
base: baseURL,
}
var responses [][]byte
err = client.stream(context.Background(), "POST", "/test", "test", func(bts []byte) error {
responses = append(responses, bts)
return nil
})
// Error checking
if tt.expectedError == nil {
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
return
}
if err == nil {
t.Fatalf("expected error %v, got nil", tt.expectedError)
}
// Check for specific error types
var unknownKeyErr ErrUnknownOllamaKey
if errors.As(tt.expectedError, &unknownKeyErr) {
var gotErr ErrUnknownOllamaKey
if !errors.As(err, &gotErr) {
t.Fatalf("expected ErrUnknownOllamaKey, got %T", err)
}
if unknownKeyErr.Key != gotErr.Key {
t.Errorf("expected key %q, got %q", unknownKeyErr.Key, gotErr.Key)
}
if unknownKeyErr.Message != gotErr.Message {
t.Errorf("expected message %q, got %q", unknownKeyErr.Message, gotErr.Message)
}
return
}
var statusErr StatusError
if errors.As(tt.expectedError, &statusErr) {
var gotErr StatusError
if !errors.As(err, &gotErr) {
t.Fatalf("expected StatusError, got %T", err)
}
if statusErr.StatusCode != gotErr.StatusCode {
t.Errorf("expected status code %d, got %d", statusErr.StatusCode, gotErr.StatusCode)
}
if statusErr.ErrorMessage != gotErr.ErrorMessage {
t.Errorf("expected error message %q, got %q", statusErr.ErrorMessage, gotErr.ErrorMessage)
}
return
}
// For other errors, compare error strings
if err.Error() != tt.expectedError.Error() {
t.Errorf("expected error %q, got %q", tt.expectedError, err)
}
})
}
}

74
api/errors.go Normal file
View File

@@ -0,0 +1,74 @@
package api
import (
"fmt"
"slices"
"strings"
)
const InvalidModelNameErrMsg = "invalid model name"
// API error responses
// ErrorCode represents a standardized error code identifier
type ErrorCode string
const (
ErrCodeUnknownKey ErrorCode = "unknown_key"
ErrCodeGeneral ErrorCode = "general" // Generic fallback error code
)
// ErrorResponse implements a structured error interface
type ErrorResponse struct {
Message string `json:"error"` // Human-readable error message, uses 'error' field name for backwards compatibility
Code ErrorCode `json:"code"` // Machine-readable error code for programmatic handling, not response code
Data map[string]any `json:"data"` // Additional error specific data, if any
}
func (e ErrorResponse) Error() string {
return e.Message
}
type ErrUnknownOllamaKey struct {
Message string
Key string
}
func (e ErrUnknownOllamaKey) Error() string {
return fmt.Sprintf("unauthorized: unknown ollama key %q", strings.TrimSpace(e.Key))
}
func (e *ErrUnknownOllamaKey) FormatUserMessage(localKeys []string) string {
// The user should only be told to add the key if it is the same one that exists locally
if slices.Index(localKeys, e.Key) == -1 {
return e.Message
}
return fmt.Sprintf(`%s
Your ollama key is:
%s
Add your key at:
https://ollama.com/settings/keys`, e.Message, e.Key)
}
// StatusError is an error with an HTTP status code and message,
// it is parsed on the client-side and not returned from the API
type StatusError struct {
StatusCode int // e.g. 200
Status string // e.g. "200 OK"
ErrorMessage string `json:"error"`
}
func (e StatusError) Error() string {
switch {
case e.Status != "" && e.ErrorMessage != "":
return fmt.Sprintf("%s: %s", e.Status, e.ErrorMessage)
case e.Status != "":
return e.Status
case e.ErrorMessage != "":
return e.ErrorMessage
default:
// this should not happen
return "something went wrong, please see the ollama server logs for details"
}
}

View File

@@ -12,27 +12,6 @@ import (
"time"
)
// StatusError is an error with an HTTP status code and message.
type StatusError struct {
StatusCode int
Status string
ErrorMessage string `json:"error"`
}
func (e StatusError) Error() string {
switch {
case e.Status != "" && e.ErrorMessage != "":
return fmt.Sprintf("%s: %s", e.Status, e.ErrorMessage)
case e.Status != "":
return e.Status
case e.ErrorMessage != "":
return e.ErrorMessage
default:
// this should not happen
return "something went wrong, please see the ollama server logs for details"
}
}
// ImageData represents the raw binary data of an image file.
type ImageData []byte

View File

@@ -8,7 +8,6 @@ import (
"crypto/ed25519"
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"encoding/pem"
"errors"
"fmt"
@@ -17,11 +16,9 @@ import (
"math"
"net"
"net/http"
"net/url"
"os"
"os/signal"
"path/filepath"
"regexp"
"runtime"
"strconv"
"strings"
@@ -32,7 +29,6 @@ import (
"github.com/containerd/console"
"github.com/mattn/go-runewidth"
"github.com/olekukonko/tablewriter"
"github.com/pkg/browser"
"github.com/spf13/cobra"
"golang.org/x/crypto/ssh"
"golang.org/x/term"
@@ -44,7 +40,6 @@ import (
"github.com/ollama/ollama/parser"
"github.com/ollama/ollama/progress"
"github.com/ollama/ollama/server"
"github.com/ollama/ollama/types/errtypes"
"github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/version"
)
@@ -519,74 +514,22 @@ func RunHandler(cmd *cobra.Command, args []string) error {
return generate(cmd, opts)
}
func generateFingerprint(key string) string {
hash := sha256.Sum256([]byte(key))
fingerprint := base64.RawURLEncoding.EncodeToString(hash[:6])
var formatted strings.Builder
for i, char := range fingerprint {
if i > 0 && i%2 == 0 {
formatted.WriteRune('-')
}
formatted.WriteRune(char)
func localPubKeys() ([]string, error) {
usrKey, err := auth.GetPublicKey()
if err != nil {
return nil, err
}
return formatted.String()
}
keys := []string{usrKey}
// tryConnect handles key validation when a connection fails due to an unknown key.
// It attempts to open the browser for interactive sessions to let users connect their key,
// falling back to command-line instructions for non-interactive sessions.
// Returns nil if browser flow succeeds, or an error with connection instructions otherwise.
func tryConnect(unknownKeyErr error) error {
// find SSH public key in the error message
// TODO (brucemacd): the API should return structured errors so that this message parsing isn't needed
sshKeyPattern := `ssh-\w+ [^\s"]+`
re := regexp.MustCompile(sshKeyPattern)
matches := re.FindStringSubmatch(unknownKeyErr.Error())
if len(matches) > 0 {
serverPubKey := matches[0]
localPubKey, err := auth.GetPublicKey()
if err != nil {
return unknownKeyErr
if runtime.GOOS == "linux" {
// try the ollama service public key if on Linux
if svcKey, err := os.ReadFile("/usr/share/ollama/.ollama/id_ed25519.pub"); err == nil {
keys = append(keys, strings.TrimSpace(string(svcKey)))
}
if runtime.GOOS == "linux" && serverPubKey != localPubKey {
// try the ollama service public key
svcPubKey, err := os.ReadFile("/usr/share/ollama/.ollama/id_ed25519.pub")
if err != nil {
return unknownKeyErr
}
localPubKey = strings.TrimSpace(string(svcPubKey))
}
// check if the returned public key matches the local public key, this prevents adding a remote key to the user's account
if serverPubKey != localPubKey {
return unknownKeyErr
}
if term.IsTerminal(int(os.Stdout.Fd())) {
// URL encode the key and device name for the browser URL
encodedKey := base64.RawURLEncoding.EncodeToString([]byte(localPubKey))
d, _ := os.Hostname()
encodedDevice := url.QueryEscape(d)
browserURL := fmt.Sprintf("https://ollama.com/connect?host=%s&key=%s", encodedDevice, encodedKey)
if err := browser.OpenURL(browserURL); err == nil {
fmt.Printf("\nOpening browser to add your key...\n")
fmt.Printf("\nCheck that this code matches what is shown in your browser:\n")
fmt.Printf("\n %s\n", generateFingerprint(localPubKey))
return nil
}
}
// only return error for non-interactive terminals or if browser opening failed
return fmt.Errorf("%s\nAdd your key at:\nhttps://ollama.com/settings/keys", unknownKeyErr.Error())
}
return unknownKeyErr
return keys, nil
}
func PushHandler(cmd *cobra.Command, args []string) error {
@@ -642,18 +585,17 @@ func PushHandler(cmd *cobra.Command, args []string) error {
if spinner != nil {
spinner.Stop()
}
if p != nil {
p.Stop()
var ke api.ErrUnknownOllamaKey
if errors.As(err, &ke) && isOllamaHost {
// the user has not added their ollama key to ollama.com
// return an error with a more user-friendly message
locals, _ := localPubKeys()
return errors.New(ke.FormatUserMessage(locals))
}
if strings.Contains(err.Error(), "access denied") {
return errors.New("you are not authorized to push to this namespace, create the model under a namespace you own")
}
if strings.Contains(err.Error(), errtypes.UnknownOllamaKeyErrMsg) && isOllamaHost {
// the user has not added their ollama key to ollama.com
// return an error with a more user-friendly message
return tryConnect(err)
}
return err
return fmt.Errorf("yoyoyo: %w", err)
}
p.Stop()

View File

@@ -16,7 +16,6 @@ import (
"github.com/spf13/cobra"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/types/errtypes"
)
func TestShowInfo(t *testing.T) {
@@ -437,7 +436,7 @@ func TestPushHandler(t *testing.T) {
"/api/push": func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusUnauthorized)
uerr := errtypes.UnknownOllamaKey{
uerr := api.ErrUnknownOllamaKey{
Key: "aaa",
}
err := json.NewEncoder(w).Encode(map[string]string{

View File

@@ -19,7 +19,6 @@ import (
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/parser"
"github.com/ollama/ollama/readline"
"github.com/ollama/ollama/types/errtypes"
)
type MultilineState int
@@ -220,7 +219,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
fn := func(resp api.ProgressResponse) error { return nil }
err = client.Create(cmd.Context(), req, fn)
if err != nil {
if strings.Contains(err.Error(), errtypes.InvalidModelNameErrMsg) {
if strings.Contains(err.Error(), api.InvalidModelNameErrMsg) {
fmt.Printf("error: The model name '%s' is invalid\n", args[1])
continue
}

1
go.mod
View File

@@ -22,7 +22,6 @@ require (
github.com/mattn/go-runewidth v0.0.14
github.com/nlpodyssey/gopickle v0.3.0
github.com/pdevine/tensor v0.0.0-20240510204454-f88f4562727c
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c
golang.org/x/image v0.22.0
)

2
go.sum
View File

@@ -159,8 +159,6 @@ github.com/phpdave11/gofpdf v1.4.2/go.mod h1:zpO6xFn9yxo3YLyMvW8HcKWVdbNqgIfOOp2
github.com/phpdave11/gofpdi v1.0.12/go.mod h1:vBmVV0Do6hSBHC8uKUQ71JGW+ZGQq74llk/7bXwjDoI=
github.com/pierrec/lz4/v4 v4.1.8 h1:ieHkV+i2BRzngO4Wd/3HGowuZStgq6QkPsD1eolNAO4=
github.com/pierrec/lz4/v4 v4.1.8/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuRQyBid4=
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmdv1U2eRNDiU2ErMBj1gwrq8eQ=
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c/go.mod h1:7rwL4CYBLnjLxUqIJNnCWiEdr3bn6IUYi15bNlnbCCU=
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=

View File

@@ -30,7 +30,6 @@ import (
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/parser"
"github.com/ollama/ollama/template"
"github.com/ollama/ollama/types/errtypes"
"github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/types/registry"
"github.com/ollama/ollama/version"
@@ -1031,7 +1030,7 @@ func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.UR
slog.Error(fmt.Sprintf("couldn't get public key: %v", nestedErr))
return nil, re
}
return nil, errtypes.UnknownOllamaKey{
return nil, api.ErrUnknownOllamaKey{
Key: pubKey,
}
}

View File

@@ -36,7 +36,6 @@ import (
"github.com/ollama/ollama/runners"
"github.com/ollama/ollama/server/imageproc"
"github.com/ollama/ollama/template"
"github.com/ollama/ollama/types/errtypes"
"github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/version"
)
@@ -610,7 +609,7 @@ func (s *Server) PushHandler(c *gin.Context) {
defer cancel()
if err := PushModel(ctx, model, regOpts, fn); err != nil {
ch <- gin.H{"error": err.Error()}
ch <- newErr(err)
}
}()
@@ -650,7 +649,7 @@ func (s *Server) CreateHandler(c *gin.Context) {
name := model.ParseName(cmp.Or(r.Model, r.Name))
if !name.IsValid() {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": errtypes.InvalidModelNameErrMsg})
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": api.InvalidModelNameErrMsg})
return
}
@@ -1550,3 +1549,24 @@ func handleScheduleError(c *gin.Context, name string, err error) {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
}
}
// newErr creates a structured API ErrorResponse from an existing error
func newErr(err error) api.ErrorResponse {
if err == nil {
return api.ErrorResponse{}
}
// Default to just returning the generic error message
resp := api.ErrorResponse{
Code: api.ErrCodeGeneral,
Message: err.Error(),
}
// Add additional error specific data, if any
var keyErr api.ErrUnknownOllamaKey
if errors.As(err, &keyErr) {
resp.Code = api.ErrCodeUnknownKey
resp.Data = map[string]any{
"key": keyErr.Key,
}
}
return resp
}

View File

@@ -1,21 +0,0 @@
// Package errtypes contains custom error types
package errtypes
import (
"fmt"
"strings"
)
const (
UnknownOllamaKeyErrMsg = "unknown ollama key"
InvalidModelNameErrMsg = "invalid model name"
)
// TODO: This should have a structured response from the API
type UnknownOllamaKey struct {
Key string
}
func (e UnknownOllamaKey) Error() string {
return fmt.Sprintf("unauthorized: %s %q", UnknownOllamaKeyErrMsg, strings.TrimSpace(e.Key))
}