server: add authorized_keys file

This change adds an "authorized_keys" file similar to sshd which can control
access to an Ollama server. The file itself is very simple and consists of
various entries for Ollama public keys.

The format is:

<key format> <public key> <name> [<endpoint>,...]

Examples:

ssh-ed25519 AAAAC3NzaC1lZDI1NT... bob /api/tags,/api/ps,/api/show,/api/generate,/api/chat

Use the "*" wildcard symbol to substitute any value, e.g.:

To grant full access to "bob":
ssh-ed25519 AAAAC3NzaC1lZDI1NT... bob *

To allow all callers to view tags (i.e. "ollama ls"):
* * * /api/tags

- The key format must be set to "ssh-ed25519" or set to the wildcard character.
- The public key must be an ssh based ed25519 (Ollama) public key or set to the wildcard
  character.
- Name can be any string you wish to associate with the public key. Note that if a public
  key is used in more than one entry in the file, the first instance of the name will be
  used and subsequent name values will be ignored.
- Endpoints is a comma separated list of Ollama Server API endpoints or the wildcard
  character. The HTTP method is not currently needed, but could be added in the future.
This commit is contained in:
Patrick Devine 2025-07-29 18:06:37 -07:00
parent 8afa6e83f2
commit 5968989a7f
6 changed files with 533 additions and 31 deletions

View File

@ -42,6 +42,23 @@ type Client struct {
func checkError(resp *http.Response, body []byte) error {
if resp.StatusCode < http.StatusBadRequest {
if len(body) == 0 {
return nil
}
// streams can contain error message even with StatusOK
var errorResponse struct {
Error string `json:"error,omitempty"`
}
if err := json.Unmarshal(body, &errorResponse); err != nil {
return fmt.Errorf("unmarshal: %w", err)
}
if errorResponse.Error != "" {
return errors.New(errorResponse.Error)
}
return nil
}
@ -213,25 +230,9 @@ func (c *Client) stream(ctx context.Context, method, path string, data any, fn f
scanBuf := make([]byte, 0, maxBufferSize)
scanner.Buffer(scanBuf, maxBufferSize)
for scanner.Scan() {
var errorResponse struct {
Error string `json:"error,omitempty"`
}
bts := scanner.Bytes()
if err := json.Unmarshal(bts, &errorResponse); err != nil {
return fmt.Errorf("unmarshal: %w", err)
}
if response.StatusCode >= http.StatusBadRequest {
return StatusError{
StatusCode: response.StatusCode,
Status: response.Status,
ErrorMessage: errorResponse.Error,
}
}
if errorResponse.Error != "" {
return errors.New(errorResponse.Error)
if err := checkError(response, bts); err != nil {
return err
}
if err := fn(bts); err != nil {

View File

@ -89,16 +89,6 @@ func TestClientStream(t *testing.T) {
},
wantErr: "mid-stream error",
},
{
name: "http status error takes precedence over general error",
responses: []any{
testError{
message: "custom error message",
statusCode: http.StatusInternalServerError,
},
},
wantErr: "500",
},
{
name: "successful stream completion",
responses: []any{

View File

@ -18,6 +18,8 @@ import (
const defaultPrivateKey = "id_ed25519"
var ErrInvalidToken = errors.New("invalid token")
func keyPath() (string, error) {
home, err := os.UserHomeDir()
if err != nil {
@ -27,6 +29,39 @@ func keyPath() (string, error) {
return filepath.Join(home, ".ollama", defaultPrivateKey), nil
}
func parseToken(token string) (key, sig []byte, _ error) {
keyData, sigData, ok := strings.Cut(token, ":")
if !ok {
return nil, nil, fmt.Errorf("identity: parseToken: %w", ErrInvalidToken)
}
sig, err := base64.StdEncoding.DecodeString(sigData)
if err != nil {
return nil, nil, fmt.Errorf("identity: parseToken: base64 decoding signature: %w", err)
}
return []byte(keyData), sig, nil
}
func Authenticate(token, checkData string) (ssh.PublicKey, error) {
keyShort, sigBytes, err := parseToken(token)
if err != nil {
return nil, err
}
keyLong := append([]byte("ssh-ed25519 "), keyShort...)
pub, _, _, _, err := ssh.ParseAuthorizedKey(keyLong)
if err != nil {
return nil, err
}
if err := pub.Verify([]byte(checkData), &ssh.Signature{
Format: pub.Type(),
Blob: sigBytes,
}); err != nil {
return nil, err
}
return pub, nil
}
func GetPublicKey() (string, error) {
keyPath, err := keyPath()
if err != nil {

254
auth/authorized_keys.go Normal file
View File

@ -0,0 +1,254 @@
package auth
import (
"bufio"
"encoding/base64"
"fmt"
"io"
"log/slog"
"os"
"path/filepath"
"regexp"
"strings"
"sync"
"time"
"golang.org/x/crypto/ssh"
)
type KeyEntry struct {
Name string
PublicKey string
Endpoints []string
}
type KeyPermission struct {
Name string
Endpoints []string
}
type APIPermissions struct {
permissions map[string]*KeyPermission
lastModified time.Time
mutex sync.RWMutex
}
var ws = regexp.MustCompile(`\s+`)
func authkeyPath() (string, error) {
home, err := os.UserHomeDir()
if err != nil {
return "", err
}
return filepath.Join(home, ".ollama", "authorized_keys"), nil
}
func NewAPIPermissions() *APIPermissions {
return &APIPermissions{
permissions: make(map[string]*KeyPermission),
mutex: sync.RWMutex{},
}
}
func (ap *APIPermissions) ReloadIfNeeded() error {
ap.mutex.Lock()
defer ap.mutex.Unlock()
filename, err := authkeyPath()
if err != nil {
return err
}
fileInfo, err := os.Stat(filename)
if err != nil {
return fmt.Errorf("failed to stat file: %v", err)
}
if !fileInfo.ModTime().After(ap.lastModified) {
return nil
}
file, err := os.Open(filename)
if err != nil {
return fmt.Errorf("failed to open file: %v", err)
}
defer file.Close()
ap.lastModified = fileInfo.ModTime()
return ap.parse(file)
}
func (ap *APIPermissions) parse(r io.Reader) error {
ap.permissions = make(map[string]*KeyPermission)
scanner := bufio.NewScanner(r)
var cnt int
for scanner.Scan() {
cnt += 1
line := strings.TrimSpace(scanner.Text())
if line == "" || strings.HasPrefix(line, "#") {
continue
}
line = ws.ReplaceAllString(line, " ")
entry, err := ap.parseLine(line)
if err != nil {
slog.Warn(fmt.Sprintf("authorized_keys line %d: skipping invalid line: %v\n", cnt, err))
continue
}
var pubKeyStr string
if entry.PublicKey == "*" {
pubKeyStr = "*"
} else {
pubKey, err := ap.validateAndDecodeKey(entry)
if err != nil {
slog.Warn(fmt.Sprintf("authorized_keys line %d: invalid key for %s: %v\n", cnt, entry.Name, err))
continue
}
pubKeyStr = pubKey
}
if perm, exists := ap.permissions[pubKeyStr]; exists {
if perm.Name == "default" {
perm.Name = entry.Name
}
if len(perm.Endpoints) == 1 && perm.Endpoints[0] == "*" {
// skip redundant entries
continue
} else if len(entry.Endpoints) == 1 && entry.Endpoints[0] == "*" {
// overwrite redundant entries
perm.Endpoints = entry.Endpoints
} else {
perm.Endpoints = append(perm.Endpoints, entry.Endpoints...)
}
} else {
ap.permissions[pubKeyStr] = &KeyPermission{
Name: entry.Name,
Endpoints: entry.Endpoints,
}
}
}
return scanner.Err()
}
func (ap *APIPermissions) parseLine(line string) (*KeyEntry, error) {
parts := strings.SplitN(line, " ", 4)
if len(parts) < 2 {
return nil, fmt.Errorf("key type and public key not found")
}
kind, b64Key := parts[0], parts[1]
name := "default"
eps := "*"
if len(parts) >= 3 && parts[2] != "" {
if parts[2] != "*" {
name = parts[2]
}
}
if len(parts) == 4 && parts[3] != "" {
eps = parts[3]
}
if kind != "ssh-ed25519" && kind != "*" {
return nil, fmt.Errorf("unsupported key type %s", kind)
}
if kind == "*" && b64Key != "*" {
return nil, fmt.Errorf("unsupported key type")
}
var endpoints []string
if eps == "*" {
endpoints = []string{"*"}
} else {
for _, e := range strings.Split(eps, ",") {
e = strings.TrimSpace(e)
if e == "" {
return nil, fmt.Errorf("empty endpoint in list")
} else if e == "*" {
endpoints = []string{"*"}
break
}
endpoints = append(endpoints, e)
}
}
return &KeyEntry{
PublicKey: b64Key,
Name: name,
Endpoints: endpoints,
}, nil
}
func (ap *APIPermissions) validateAndDecodeKey(entry *KeyEntry) (string, error) {
keyBlob, err := base64.StdEncoding.DecodeString(entry.PublicKey)
if err != nil {
return "", fmt.Errorf("base64 decode: %w", err)
}
pub, err := ssh.ParsePublicKey(keyBlob)
if err != nil {
return "", fmt.Errorf("parse key: %w", err)
}
if pub.Type() != ssh.KeyAlgoED25519 {
return "", fmt.Errorf("key is not Ed25519")
}
return entry.PublicKey, nil
}
func (ap *APIPermissions) Authorize(pubKey ssh.PublicKey, endpoint string) (bool, string, error) {
if err := ap.ReloadIfNeeded(); err != nil {
return false, "unknown", err
}
ap.mutex.RLock()
defer ap.mutex.RUnlock()
if wildcardPerm, exists := ap.permissions["*"]; exists {
if len(wildcardPerm.Endpoints) == 1 && wildcardPerm.Endpoints[0] == "*" {
return true, wildcardPerm.Name, nil
}
for _, allowedEndpoint := range wildcardPerm.Endpoints {
if allowedEndpoint == endpoint {
return true, wildcardPerm.Name, nil
}
}
}
keyString := string(ssh.MarshalAuthorizedKey(pubKey))
parts := strings.SplitN(keyString, " ", 2)
var base64Key string
if len(parts) > 1 {
base64Key = parts[1]
} else {
base64Key = parts[0]
}
base64Key = strings.TrimSpace(base64Key)
perm, exists := ap.permissions[base64Key]
if !exists {
return false, "unknown", nil
}
if len(perm.Endpoints) == 1 && perm.Endpoints[0] == "*" {
return true, perm.Name, nil
}
for _, allowedEndpoint := range perm.Endpoints {
if allowedEndpoint == endpoint {
return true, perm.Name, nil
}
}
return false, "unknown", nil
}

View File

@ -0,0 +1,140 @@
package auth
import (
"bytes"
"encoding/base64"
"reflect"
"testing"
)
const validB64 = "AAAAC3NzaC1lZDI1NTE5AAAAICy1v/Sn0kGhu1LXzCsnx3wlk5ESdncS66JWo13yeJod"
var (
validKeyBlob, _ = base64.StdEncoding.DecodeString(validB64)
//validPub, _ = ssh.ParsePublicKey(validKeyBlob)
validPub = validB64
)
func TestParse(t *testing.T) {
tests := []struct {
name string
file string
want map[string]*KeyPermission
}{
{
name: "two fields only defaults",
file: "ssh-ed25519 " + validB64 + "\n",
want: map[string]*KeyPermission{
validPub: &KeyPermission{
Name: "default",
Endpoints: []string{"*"},
},
},
},
{
name: "extra whitespace collapsed and default endpoints",
file: "ssh-ed25519 " + validB64 + " alice\n",
want: map[string]*KeyPermission{
validPub: &KeyPermission{
Name: "alice",
Endpoints: []string{"*"},
},
},
},
{
name: "four fields full",
file: "ssh-ed25519 " + validB64 + " bob /api/foo,/api/bar\n",
want: map[string]*KeyPermission{
validPub: &KeyPermission{
Name: "bob",
Endpoints: []string{"/api/foo", "/api/bar"},
},
},
},
{
name: "comment lines ignored and multiple entries",
file: "# header\n\nssh-ed25519 " + validB64 + " user1\nssh-ed25519 " + validB64 + " user2 /api/x\n",
want: map[string]*KeyPermission{
validPub: &KeyPermission{
Name: "user1",
Endpoints: []string{"*"},
},
},
},
{
name: "three entries variety",
file: "ssh-ed25519 " + validB64 + "\nssh-ed25519 " + validB64 + " alice /api/a,/api/b\nssh-ed25519 " + validB64 + " bob /api/c\n",
want: map[string]*KeyPermission{
validPub: &KeyPermission{
Name: "alice",
Endpoints: []string{"*"},
},
},
},
{
name: "two entries w/ wildcard",
file: "ssh-ed25519 " + validB64 + " alice /api/a\n* * * /api/b\n",
want: map[string]*KeyPermission{
validPub: &KeyPermission{
Name: "alice",
Endpoints: []string{"/api/a"},
},
"*": &KeyPermission{
Name: "default",
Endpoints: []string{"/api/b"},
},
},
},
{
name: "tags for everyone",
file: "* * * /api/tags",
want: map[string]*KeyPermission{
"*": &KeyPermission{
Name: "default",
Endpoints: []string{"/api/tags"},
},
},
},
{
name: "default name",
file: "* * somename",
want: map[string]*KeyPermission{
"*": &KeyPermission{
Name: "somename",
Endpoints: []string{"*"},
},
},
},
{
name: "unsupported key type",
file: "ssh-rsa AAAAB3Nza...\n",
want: map[string]*KeyPermission{},
},
{
name: "bad base64",
file: "ssh-ed25519 invalid@@@\n",
want: map[string]*KeyPermission{},
},
{
name: "just an asterix",
file: "*\n",
want: map[string]*KeyPermission{},
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
perms := NewAPIPermissions()
err := perms.parse(bytes.NewBufferString(tc.file))
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(perms.permissions) != len(tc.want) {
t.Fatalf("got %d entries, want %d", len(perms.permissions), len(tc.want))
}
if !reflect.DeepEqual(perms.permissions, tc.want) {
t.Errorf("got %+v, want %+v", perms.permissions, tc.want)
}
})
}
}

View File

@ -28,6 +28,7 @@ import (
"golang.org/x/sync/errgroup"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/auth"
"github.com/ollama/ollama/discover"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/fs/ggml"
@ -55,6 +56,8 @@ var mode string = gin.DebugMode
type Server struct {
addr net.Addr
sched *Scheduler
perms *auth.APIPermissions
}
func init() {
@ -69,6 +72,38 @@ func init() {
gin.SetMode(mode)
}
func loggedFormatter(param gin.LogFormatterParams) string {
var statusColor, methodColor, resetColor string
if param.IsOutputColor() {
statusColor = param.StatusCodeColor()
methodColor = param.MethodColor()
resetColor = param.ResetColor()
}
if param.Latency > time.Minute {
param.Latency = param.Latency.Truncate(time.Second)
}
username := "default"
if userVal, exists := param.Keys["username"]; exists {
if name, ok := userVal.(string); ok {
username = name
}
}
return fmt.Sprintf(
"[Ollama] %s |%s %3d %s| %13v | %15s | %-20s |%s %-7s %s %#v\n%s",
param.TimeStamp.Format("2006/01/02 - 15:04:05"),
statusColor, param.StatusCode, resetColor,
param.Latency,
param.ClientIP,
username,
methodColor, param.Method, resetColor,
param.Path,
param.ErrorMessage,
)
}
var (
errRequired = errors.New("is required")
errBadTemplate = errors.New("template error")
@ -1111,6 +1146,44 @@ func allowedHost(host string) bool {
return false
}
func allowedEndpointsMiddleware(perms *auth.APIPermissions) gin.HandlerFunc {
return func(c *gin.Context) {
if !envconfig.UseAuth() || (c.Request.Method == "HEAD" && c.Request.URL.Path == "/") {
c.Next()
return
}
token := strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer ")
if token == "" {
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "unauthorized"})
return
}
pubKey, err := auth.Authenticate(token, fmt.Sprintf("%s,%s", c.Request.Method, c.Request.RequestURI))
if err != nil {
slog.Error("authentication error", "error", err)
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "unauthorized"})
return
}
authorized, name, err := perms.Authorize(pubKey, c.Request.URL.Path)
c.Set("username", name)
if err != nil {
slog.Error("authorization error", "error", err)
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "unauthorized"})
return
}
if !authorized {
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "unauthorized"})
return
}
c.Next()
return
}
}
func allowedHostsMiddleware(addr net.Addr) gin.HandlerFunc {
return func(c *gin.Context) {
if addr == nil {
@ -1177,10 +1250,13 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
}
corsConfig.AllowOrigins = envconfig.AllowedOrigins()
r := gin.Default()
r := gin.New()
r.HandleMethodNotAllowed = true
r.Use(
gin.LoggerWithFormatter(loggedFormatter),
gin.Recovery(),
cors.New(corsConfig),
allowedEndpointsMiddleware(s.perms),
allowedHostsMiddleware(s.addr),
)
@ -1190,7 +1266,7 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
r.HEAD("/api/version", func(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"version": version.Version}) })
r.GET("/api/version", func(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"version": version.Version}) })
// Local model cache management (new implementation is at end of function)
// Local model cache management
r.POST("/api/pull", s.PullHandler)
r.POST("/api/push", s.PushHandler)
r.HEAD("/api/tags", s.ListHandler)
@ -1222,7 +1298,7 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
// wrap old with new
rs := &registry.Local{
Client: rc,
Logger: slog.Default(), // TODO(bmizerany): Take a logger, do not use slog.Default()
Logger: slog.Default(),
Fallback: r,
Prune: PruneLayers,
@ -1267,6 +1343,12 @@ func Serve(ln net.Listener) error {
s := &Server{addr: ln.Addr()}
if envconfig.UseAuth() {
perms := auth.NewAPIPermissions()
perms.ReloadIfNeeded()
s.perms = perms
}
var rc *ollama.Registry
if useClient2 {
var err error