diff --git a/api/client.go b/api/client.go index 7cc2acb3d..39b74d828 100644 --- a/api/client.go +++ b/api/client.go @@ -42,6 +42,23 @@ type Client struct { func checkError(resp *http.Response, body []byte) error { if resp.StatusCode < http.StatusBadRequest { + if len(body) == 0 { + return nil + } + + // streams can contain error message even with StatusOK + var errorResponse struct { + Error string `json:"error,omitempty"` + } + + if err := json.Unmarshal(body, &errorResponse); err != nil { + return fmt.Errorf("unmarshal: %w", err) + } + + if errorResponse.Error != "" { + return errors.New(errorResponse.Error) + } + return nil } @@ -213,25 +230,9 @@ func (c *Client) stream(ctx context.Context, method, path string, data any, fn f scanBuf := make([]byte, 0, maxBufferSize) scanner.Buffer(scanBuf, maxBufferSize) for scanner.Scan() { - var errorResponse struct { - Error string `json:"error,omitempty"` - } - bts := scanner.Bytes() - if err := json.Unmarshal(bts, &errorResponse); err != nil { - return fmt.Errorf("unmarshal: %w", err) - } - - if response.StatusCode >= http.StatusBadRequest { - return StatusError{ - StatusCode: response.StatusCode, - Status: response.Status, - ErrorMessage: errorResponse.Error, - } - } - - if errorResponse.Error != "" { - return errors.New(errorResponse.Error) + if err := checkError(response, bts); err != nil { + return err } if err := fn(bts); err != nil { diff --git a/api/client_test.go b/api/client_test.go index f0034e02d..2ceeec9cf 100644 --- a/api/client_test.go +++ b/api/client_test.go @@ -89,16 +89,6 @@ func TestClientStream(t *testing.T) { }, wantErr: "mid-stream error", }, - { - name: "http status error takes precedence over general error", - responses: []any{ - testError{ - message: "custom error message", - statusCode: http.StatusInternalServerError, - }, - }, - wantErr: "500", - }, { name: "successful stream completion", responses: []any{ diff --git a/auth/auth.go b/auth/auth.go index e1d854124..6d5374a08 100644 --- a/auth/auth.go +++ b/auth/auth.go @@ -18,6 +18,8 @@ import ( const defaultPrivateKey = "id_ed25519" +var ErrInvalidToken = errors.New("invalid token") + func keyPath() (string, error) { home, err := os.UserHomeDir() if err != nil { @@ -27,6 +29,39 @@ func keyPath() (string, error) { return filepath.Join(home, ".ollama", defaultPrivateKey), nil } +func parseToken(token string) (key, sig []byte, _ error) { + keyData, sigData, ok := strings.Cut(token, ":") + if !ok { + return nil, nil, fmt.Errorf("identity: parseToken: %w", ErrInvalidToken) + } + sig, err := base64.StdEncoding.DecodeString(sigData) + if err != nil { + return nil, nil, fmt.Errorf("identity: parseToken: base64 decoding signature: %w", err) + } + return []byte(keyData), sig, nil +} + +func Authenticate(token, checkData string) (ssh.PublicKey, error) { + keyShort, sigBytes, err := parseToken(token) + if err != nil { + return nil, err + } + keyLong := append([]byte("ssh-ed25519 "), keyShort...) + pub, _, _, _, err := ssh.ParseAuthorizedKey(keyLong) + if err != nil { + return nil, err + } + + if err := pub.Verify([]byte(checkData), &ssh.Signature{ + Format: pub.Type(), + Blob: sigBytes, + }); err != nil { + return nil, err + } + + return pub, nil +} + func GetPublicKey() (string, error) { keyPath, err := keyPath() if err != nil { diff --git a/auth/authorized_keys.go b/auth/authorized_keys.go new file mode 100644 index 000000000..3a120da44 --- /dev/null +++ b/auth/authorized_keys.go @@ -0,0 +1,254 @@ +package auth + +import ( + "bufio" + "encoding/base64" + "fmt" + "io" + "log/slog" + "os" + "path/filepath" + "regexp" + "strings" + "sync" + "time" + + "golang.org/x/crypto/ssh" +) + +type KeyEntry struct { + Name string + PublicKey string + Endpoints []string +} + +type KeyPermission struct { + Name string + Endpoints []string +} + +type APIPermissions struct { + permissions map[string]*KeyPermission + lastModified time.Time + mutex sync.RWMutex +} + +var ws = regexp.MustCompile(`\s+`) + +func authkeyPath() (string, error) { + home, err := os.UserHomeDir() + if err != nil { + return "", err + } + + return filepath.Join(home, ".ollama", "authorized_keys"), nil +} + +func NewAPIPermissions() *APIPermissions { + return &APIPermissions{ + permissions: make(map[string]*KeyPermission), + mutex: sync.RWMutex{}, + } +} + +func (ap *APIPermissions) ReloadIfNeeded() error { + ap.mutex.Lock() + defer ap.mutex.Unlock() + + filename, err := authkeyPath() + if err != nil { + return err + } + + fileInfo, err := os.Stat(filename) + if err != nil { + return fmt.Errorf("failed to stat file: %v", err) + } + + if !fileInfo.ModTime().After(ap.lastModified) { + return nil + } + + file, err := os.Open(filename) + if err != nil { + return fmt.Errorf("failed to open file: %v", err) + } + defer file.Close() + + ap.lastModified = fileInfo.ModTime() + return ap.parse(file) +} + +func (ap *APIPermissions) parse(r io.Reader) error { + ap.permissions = make(map[string]*KeyPermission) + + scanner := bufio.NewScanner(r) + var cnt int + for scanner.Scan() { + cnt += 1 + line := strings.TrimSpace(scanner.Text()) + + if line == "" || strings.HasPrefix(line, "#") { + continue + } + + line = ws.ReplaceAllString(line, " ") + + entry, err := ap.parseLine(line) + if err != nil { + slog.Warn(fmt.Sprintf("authorized_keys line %d: skipping invalid line: %v\n", cnt, err)) + continue + } + + var pubKeyStr string + + if entry.PublicKey == "*" { + pubKeyStr = "*" + } else { + pubKey, err := ap.validateAndDecodeKey(entry) + if err != nil { + slog.Warn(fmt.Sprintf("authorized_keys line %d: invalid key for %s: %v\n", cnt, entry.Name, err)) + continue + } + pubKeyStr = pubKey + } + + if perm, exists := ap.permissions[pubKeyStr]; exists { + if perm.Name == "default" { + perm.Name = entry.Name + } + if len(perm.Endpoints) == 1 && perm.Endpoints[0] == "*" { + // skip redundant entries + continue + } else if len(entry.Endpoints) == 1 && entry.Endpoints[0] == "*" { + // overwrite redundant entries + perm.Endpoints = entry.Endpoints + } else { + perm.Endpoints = append(perm.Endpoints, entry.Endpoints...) + } + } else { + ap.permissions[pubKeyStr] = &KeyPermission{ + Name: entry.Name, + Endpoints: entry.Endpoints, + } + } + } + + return scanner.Err() +} + +func (ap *APIPermissions) parseLine(line string) (*KeyEntry, error) { + parts := strings.SplitN(line, " ", 4) + if len(parts) < 2 { + return nil, fmt.Errorf("key type and public key not found") + } + + kind, b64Key := parts[0], parts[1] + name := "default" + eps := "*" + + if len(parts) >= 3 && parts[2] != "" { + if parts[2] != "*" { + name = parts[2] + } + } + + if len(parts) == 4 && parts[3] != "" { + eps = parts[3] + } + + if kind != "ssh-ed25519" && kind != "*" { + return nil, fmt.Errorf("unsupported key type %s", kind) + } + + if kind == "*" && b64Key != "*" { + return nil, fmt.Errorf("unsupported key type") + } + + var endpoints []string + if eps == "*" { + endpoints = []string{"*"} + } else { + for _, e := range strings.Split(eps, ",") { + e = strings.TrimSpace(e) + if e == "" { + return nil, fmt.Errorf("empty endpoint in list") + } else if e == "*" { + endpoints = []string{"*"} + break + } + endpoints = append(endpoints, e) + } + } + + return &KeyEntry{ + PublicKey: b64Key, + Name: name, + Endpoints: endpoints, + }, nil +} + +func (ap *APIPermissions) validateAndDecodeKey(entry *KeyEntry) (string, error) { + keyBlob, err := base64.StdEncoding.DecodeString(entry.PublicKey) + if err != nil { + return "", fmt.Errorf("base64 decode: %w", err) + } + pub, err := ssh.ParsePublicKey(keyBlob) + if err != nil { + return "", fmt.Errorf("parse key: %w", err) + } + if pub.Type() != ssh.KeyAlgoED25519 { + return "", fmt.Errorf("key is not Ed25519") + } + + return entry.PublicKey, nil +} + +func (ap *APIPermissions) Authorize(pubKey ssh.PublicKey, endpoint string) (bool, string, error) { + if err := ap.ReloadIfNeeded(); err != nil { + return false, "unknown", err + } + + ap.mutex.RLock() + defer ap.mutex.RUnlock() + + if wildcardPerm, exists := ap.permissions["*"]; exists { + if len(wildcardPerm.Endpoints) == 1 && wildcardPerm.Endpoints[0] == "*" { + return true, wildcardPerm.Name, nil + } + + for _, allowedEndpoint := range wildcardPerm.Endpoints { + if allowedEndpoint == endpoint { + return true, wildcardPerm.Name, nil + } + } + } + + keyString := string(ssh.MarshalAuthorizedKey(pubKey)) + parts := strings.SplitN(keyString, " ", 2) + var base64Key string + if len(parts) > 1 { + base64Key = parts[1] + } else { + base64Key = parts[0] + } + + base64Key = strings.TrimSpace(base64Key) + + perm, exists := ap.permissions[base64Key] + if !exists { + return false, "unknown", nil + } + + if len(perm.Endpoints) == 1 && perm.Endpoints[0] == "*" { + return true, perm.Name, nil + } + + for _, allowedEndpoint := range perm.Endpoints { + if allowedEndpoint == endpoint { + return true, perm.Name, nil + } + } + + return false, "unknown", nil +} diff --git a/auth/authorized_keys_test.go b/auth/authorized_keys_test.go new file mode 100644 index 000000000..0b8b9ad13 --- /dev/null +++ b/auth/authorized_keys_test.go @@ -0,0 +1,140 @@ +package auth + +import ( + "bytes" + "encoding/base64" + "reflect" + "testing" +) + +const validB64 = "AAAAC3NzaC1lZDI1NTE5AAAAICy1v/Sn0kGhu1LXzCsnx3wlk5ESdncS66JWo13yeJod" + +var ( + validKeyBlob, _ = base64.StdEncoding.DecodeString(validB64) + //validPub, _ = ssh.ParsePublicKey(validKeyBlob) + validPub = validB64 +) + +func TestParse(t *testing.T) { + tests := []struct { + name string + file string + want map[string]*KeyPermission + }{ + { + name: "two fields only defaults", + file: "ssh-ed25519 " + validB64 + "\n", + want: map[string]*KeyPermission{ + validPub: &KeyPermission{ + Name: "default", + Endpoints: []string{"*"}, + }, + }, + }, + { + name: "extra whitespace collapsed and default endpoints", + file: "ssh-ed25519 " + validB64 + " alice\n", + want: map[string]*KeyPermission{ + validPub: &KeyPermission{ + Name: "alice", + Endpoints: []string{"*"}, + }, + }, + }, + { + name: "four fields full", + file: "ssh-ed25519 " + validB64 + " bob /api/foo,/api/bar\n", + want: map[string]*KeyPermission{ + validPub: &KeyPermission{ + Name: "bob", + Endpoints: []string{"/api/foo", "/api/bar"}, + }, + }, + }, + { + name: "comment lines ignored and multiple entries", + file: "# header\n\nssh-ed25519 " + validB64 + " user1\nssh-ed25519 " + validB64 + " user2 /api/x\n", + want: map[string]*KeyPermission{ + validPub: &KeyPermission{ + Name: "user1", + Endpoints: []string{"*"}, + }, + }, + }, + { + name: "three entries variety", + file: "ssh-ed25519 " + validB64 + "\nssh-ed25519 " + validB64 + " alice /api/a,/api/b\nssh-ed25519 " + validB64 + " bob /api/c\n", + want: map[string]*KeyPermission{ + validPub: &KeyPermission{ + Name: "alice", + Endpoints: []string{"*"}, + }, + }, + }, + { + name: "two entries w/ wildcard", + file: "ssh-ed25519 " + validB64 + " alice /api/a\n* * * /api/b\n", + want: map[string]*KeyPermission{ + validPub: &KeyPermission{ + Name: "alice", + Endpoints: []string{"/api/a"}, + }, + "*": &KeyPermission{ + Name: "default", + Endpoints: []string{"/api/b"}, + }, + }, + }, + { + name: "tags for everyone", + file: "* * * /api/tags", + want: map[string]*KeyPermission{ + "*": &KeyPermission{ + Name: "default", + Endpoints: []string{"/api/tags"}, + }, + }, + }, + { + name: "default name", + file: "* * somename", + want: map[string]*KeyPermission{ + "*": &KeyPermission{ + Name: "somename", + Endpoints: []string{"*"}, + }, + }, + }, + { + name: "unsupported key type", + file: "ssh-rsa AAAAB3Nza...\n", + want: map[string]*KeyPermission{}, + }, + { + name: "bad base64", + file: "ssh-ed25519 invalid@@@\n", + want: map[string]*KeyPermission{}, + }, + { + name: "just an asterix", + file: "*\n", + want: map[string]*KeyPermission{}, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + perms := NewAPIPermissions() + err := perms.parse(bytes.NewBufferString(tc.file)) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(perms.permissions) != len(tc.want) { + t.Fatalf("got %d entries, want %d", len(perms.permissions), len(tc.want)) + } + if !reflect.DeepEqual(perms.permissions, tc.want) { + t.Errorf("got %+v, want %+v", perms.permissions, tc.want) + } + }) + } +} diff --git a/server/routes.go b/server/routes.go index 40348e737..24c422306 100644 --- a/server/routes.go +++ b/server/routes.go @@ -28,6 +28,7 @@ import ( "golang.org/x/sync/errgroup" "github.com/ollama/ollama/api" + "github.com/ollama/ollama/auth" "github.com/ollama/ollama/discover" "github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/fs/ggml" @@ -55,6 +56,8 @@ var mode string = gin.DebugMode type Server struct { addr net.Addr sched *Scheduler + + perms *auth.APIPermissions } func init() { @@ -69,6 +72,38 @@ func init() { gin.SetMode(mode) } +func loggedFormatter(param gin.LogFormatterParams) string { + var statusColor, methodColor, resetColor string + if param.IsOutputColor() { + statusColor = param.StatusCodeColor() + methodColor = param.MethodColor() + resetColor = param.ResetColor() + } + + if param.Latency > time.Minute { + param.Latency = param.Latency.Truncate(time.Second) + } + + username := "default" + if userVal, exists := param.Keys["username"]; exists { + if name, ok := userVal.(string); ok { + username = name + } + } + + return fmt.Sprintf( + "[Ollama] %s |%s %3d %s| %13v | %15s | %-20s |%s %-7s %s %#v\n%s", + param.TimeStamp.Format("2006/01/02 - 15:04:05"), + statusColor, param.StatusCode, resetColor, + param.Latency, + param.ClientIP, + username, + methodColor, param.Method, resetColor, + param.Path, + param.ErrorMessage, + ) +} + var ( errRequired = errors.New("is required") errBadTemplate = errors.New("template error") @@ -1111,6 +1146,44 @@ func allowedHost(host string) bool { return false } +func allowedEndpointsMiddleware(perms *auth.APIPermissions) gin.HandlerFunc { + return func(c *gin.Context) { + if !envconfig.UseAuth() || (c.Request.Method == "HEAD" && c.Request.URL.Path == "/") { + c.Next() + return + } + + token := strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer ") + if token == "" { + c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "unauthorized"}) + return + } + + pubKey, err := auth.Authenticate(token, fmt.Sprintf("%s,%s", c.Request.Method, c.Request.RequestURI)) + if err != nil { + slog.Error("authentication error", "error", err) + c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "unauthorized"}) + return + } + + authorized, name, err := perms.Authorize(pubKey, c.Request.URL.Path) + c.Set("username", name) + if err != nil { + slog.Error("authorization error", "error", err) + c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "unauthorized"}) + return + } + + if !authorized { + c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "unauthorized"}) + return + } + + c.Next() + return + } +} + func allowedHostsMiddleware(addr net.Addr) gin.HandlerFunc { return func(c *gin.Context) { if addr == nil { @@ -1177,10 +1250,13 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) { } corsConfig.AllowOrigins = envconfig.AllowedOrigins() - r := gin.Default() + r := gin.New() r.HandleMethodNotAllowed = true r.Use( + gin.LoggerWithFormatter(loggedFormatter), + gin.Recovery(), cors.New(corsConfig), + allowedEndpointsMiddleware(s.perms), allowedHostsMiddleware(s.addr), ) @@ -1190,7 +1266,7 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) { r.HEAD("/api/version", func(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"version": version.Version}) }) r.GET("/api/version", func(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"version": version.Version}) }) - // Local model cache management (new implementation is at end of function) + // Local model cache management r.POST("/api/pull", s.PullHandler) r.POST("/api/push", s.PushHandler) r.HEAD("/api/tags", s.ListHandler) @@ -1222,7 +1298,7 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) { // wrap old with new rs := ®istry.Local{ Client: rc, - Logger: slog.Default(), // TODO(bmizerany): Take a logger, do not use slog.Default() + Logger: slog.Default(), Fallback: r, Prune: PruneLayers, @@ -1267,6 +1343,12 @@ func Serve(ln net.Listener) error { s := &Server{addr: ln.Addr()} + if envconfig.UseAuth() { + perms := auth.NewAPIPermissions() + perms.ReloadIfNeeded() + s.perms = perms + } + var rc *ollama.Registry if useClient2 { var err error