Compare commits
5 Commits
pdevine/au
...
mxyng/crea
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
087beb40ed | ||
|
|
19279d778d | ||
|
|
ff89ba90bc | ||
|
|
6dcc5dfb9c | ||
|
|
25911a6e6b |
2
.github/workflows/release.yaml
vendored
2
.github/workflows/release.yaml
vendored
@@ -23,7 +23,7 @@ jobs:
|
||||
echo GOFLAGS="'-ldflags=-w -s \"-X=github.com/ollama/ollama/version.Version=${GITHUB_REF_NAME#v}\" \"-X=github.com/ollama/ollama/server.mode=release\"'" >>$GITHUB_OUTPUT
|
||||
|
||||
darwin-build:
|
||||
runs-on: macos-13
|
||||
runs-on: macos-13-xlarge
|
||||
environment: release
|
||||
needs: setup-environment
|
||||
strategy:
|
||||
|
||||
@@ -42,23 +42,6 @@ type Client struct {
|
||||
|
||||
func checkError(resp *http.Response, body []byte) error {
|
||||
if resp.StatusCode < http.StatusBadRequest {
|
||||
if len(body) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// streams can contain error message even with StatusOK
|
||||
var errorResponse struct {
|
||||
Error string `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(body, &errorResponse); err != nil {
|
||||
return fmt.Errorf("unmarshal: %w", err)
|
||||
}
|
||||
|
||||
if errorResponse.Error != "" {
|
||||
return errors.New(errorResponse.Error)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -230,9 +213,25 @@ func (c *Client) stream(ctx context.Context, method, path string, data any, fn f
|
||||
scanBuf := make([]byte, 0, maxBufferSize)
|
||||
scanner.Buffer(scanBuf, maxBufferSize)
|
||||
for scanner.Scan() {
|
||||
var errorResponse struct {
|
||||
Error string `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
bts := scanner.Bytes()
|
||||
if err := checkError(response, bts); err != nil {
|
||||
return err
|
||||
if err := json.Unmarshal(bts, &errorResponse); err != nil {
|
||||
return fmt.Errorf("unmarshal: %w", err)
|
||||
}
|
||||
|
||||
if response.StatusCode >= http.StatusBadRequest {
|
||||
return StatusError{
|
||||
StatusCode: response.StatusCode,
|
||||
Status: response.Status,
|
||||
ErrorMessage: errorResponse.Error,
|
||||
}
|
||||
}
|
||||
|
||||
if errorResponse.Error != "" {
|
||||
return errors.New(errorResponse.Error)
|
||||
}
|
||||
|
||||
if err := fn(bts); err != nil {
|
||||
|
||||
@@ -89,6 +89,16 @@ func TestClientStream(t *testing.T) {
|
||||
},
|
||||
wantErr: "mid-stream error",
|
||||
},
|
||||
{
|
||||
name: "http status error takes precedence over general error",
|
||||
responses: []any{
|
||||
testError{
|
||||
message: "custom error message",
|
||||
statusCode: http.StatusInternalServerError,
|
||||
},
|
||||
},
|
||||
wantErr: "500",
|
||||
},
|
||||
{
|
||||
name: "successful stream completion",
|
||||
responses: []any{
|
||||
|
||||
35
auth/auth.go
35
auth/auth.go
@@ -18,8 +18,6 @@ import (
|
||||
|
||||
const defaultPrivateKey = "id_ed25519"
|
||||
|
||||
var ErrInvalidToken = errors.New("invalid token")
|
||||
|
||||
func keyPath() (string, error) {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
@@ -29,39 +27,6 @@ func keyPath() (string, error) {
|
||||
return filepath.Join(home, ".ollama", defaultPrivateKey), nil
|
||||
}
|
||||
|
||||
func parseToken(token string) (key, sig []byte, _ error) {
|
||||
keyData, sigData, ok := strings.Cut(token, ":")
|
||||
if !ok {
|
||||
return nil, nil, fmt.Errorf("identity: parseToken: %w", ErrInvalidToken)
|
||||
}
|
||||
sig, err := base64.StdEncoding.DecodeString(sigData)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("identity: parseToken: base64 decoding signature: %w", err)
|
||||
}
|
||||
return []byte(keyData), sig, nil
|
||||
}
|
||||
|
||||
func Authenticate(token, checkData string) (ssh.PublicKey, error) {
|
||||
keyShort, sigBytes, err := parseToken(token)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
keyLong := append([]byte("ssh-ed25519 "), keyShort...)
|
||||
pub, _, _, _, err := ssh.ParseAuthorizedKey(keyLong)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := pub.Verify([]byte(checkData), &ssh.Signature{
|
||||
Format: pub.Type(),
|
||||
Blob: sigBytes,
|
||||
}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return pub, nil
|
||||
}
|
||||
|
||||
func GetPublicKey() (string, error) {
|
||||
keyPath, err := keyPath()
|
||||
if err != nil {
|
||||
|
||||
@@ -1,254 +0,0 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
type KeyEntry struct {
|
||||
Name string
|
||||
PublicKey string
|
||||
Endpoints []string
|
||||
}
|
||||
|
||||
type KeyPermission struct {
|
||||
Name string
|
||||
Endpoints []string
|
||||
}
|
||||
|
||||
type APIPermissions struct {
|
||||
permissions map[string]*KeyPermission
|
||||
lastModified time.Time
|
||||
mutex sync.RWMutex
|
||||
}
|
||||
|
||||
var ws = regexp.MustCompile(`\s+`)
|
||||
|
||||
func authkeyPath() (string, error) {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return filepath.Join(home, ".ollama", "authorized_keys"), nil
|
||||
}
|
||||
|
||||
func NewAPIPermissions() *APIPermissions {
|
||||
return &APIPermissions{
|
||||
permissions: make(map[string]*KeyPermission),
|
||||
mutex: sync.RWMutex{},
|
||||
}
|
||||
}
|
||||
|
||||
func (ap *APIPermissions) ReloadIfNeeded() error {
|
||||
ap.mutex.Lock()
|
||||
defer ap.mutex.Unlock()
|
||||
|
||||
filename, err := authkeyPath()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
fileInfo, err := os.Stat(filename)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to stat file: %v", err)
|
||||
}
|
||||
|
||||
if !fileInfo.ModTime().After(ap.lastModified) {
|
||||
return nil
|
||||
}
|
||||
|
||||
file, err := os.Open(filename)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to open file: %v", err)
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
ap.lastModified = fileInfo.ModTime()
|
||||
return ap.parse(file)
|
||||
}
|
||||
|
||||
func (ap *APIPermissions) parse(r io.Reader) error {
|
||||
ap.permissions = make(map[string]*KeyPermission)
|
||||
|
||||
scanner := bufio.NewScanner(r)
|
||||
var cnt int
|
||||
for scanner.Scan() {
|
||||
cnt += 1
|
||||
line := strings.TrimSpace(scanner.Text())
|
||||
|
||||
if line == "" || strings.HasPrefix(line, "#") {
|
||||
continue
|
||||
}
|
||||
|
||||
line = ws.ReplaceAllString(line, " ")
|
||||
|
||||
entry, err := ap.parseLine(line)
|
||||
if err != nil {
|
||||
slog.Warn(fmt.Sprintf("authorized_keys line %d: skipping invalid line: %v\n", cnt, err))
|
||||
continue
|
||||
}
|
||||
|
||||
var pubKeyStr string
|
||||
|
||||
if entry.PublicKey == "*" {
|
||||
pubKeyStr = "*"
|
||||
} else {
|
||||
pubKey, err := ap.validateAndDecodeKey(entry)
|
||||
if err != nil {
|
||||
slog.Warn(fmt.Sprintf("authorized_keys line %d: invalid key for %s: %v\n", cnt, entry.Name, err))
|
||||
continue
|
||||
}
|
||||
pubKeyStr = pubKey
|
||||
}
|
||||
|
||||
if perm, exists := ap.permissions[pubKeyStr]; exists {
|
||||
if perm.Name == "default" {
|
||||
perm.Name = entry.Name
|
||||
}
|
||||
if len(perm.Endpoints) == 1 && perm.Endpoints[0] == "*" {
|
||||
// skip redundant entries
|
||||
continue
|
||||
} else if len(entry.Endpoints) == 1 && entry.Endpoints[0] == "*" {
|
||||
// overwrite redundant entries
|
||||
perm.Endpoints = entry.Endpoints
|
||||
} else {
|
||||
perm.Endpoints = append(perm.Endpoints, entry.Endpoints...)
|
||||
}
|
||||
} else {
|
||||
ap.permissions[pubKeyStr] = &KeyPermission{
|
||||
Name: entry.Name,
|
||||
Endpoints: entry.Endpoints,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return scanner.Err()
|
||||
}
|
||||
|
||||
func (ap *APIPermissions) parseLine(line string) (*KeyEntry, error) {
|
||||
parts := strings.SplitN(line, " ", 4)
|
||||
if len(parts) < 2 {
|
||||
return nil, fmt.Errorf("key type and public key not found")
|
||||
}
|
||||
|
||||
kind, b64Key := parts[0], parts[1]
|
||||
name := "default"
|
||||
eps := "*"
|
||||
|
||||
if len(parts) >= 3 && parts[2] != "" {
|
||||
if parts[2] != "*" {
|
||||
name = parts[2]
|
||||
}
|
||||
}
|
||||
|
||||
if len(parts) == 4 && parts[3] != "" {
|
||||
eps = parts[3]
|
||||
}
|
||||
|
||||
if kind != "ssh-ed25519" && kind != "*" {
|
||||
return nil, fmt.Errorf("unsupported key type %s", kind)
|
||||
}
|
||||
|
||||
if kind == "*" && b64Key != "*" {
|
||||
return nil, fmt.Errorf("unsupported key type")
|
||||
}
|
||||
|
||||
var endpoints []string
|
||||
if eps == "*" {
|
||||
endpoints = []string{"*"}
|
||||
} else {
|
||||
for _, e := range strings.Split(eps, ",") {
|
||||
e = strings.TrimSpace(e)
|
||||
if e == "" {
|
||||
return nil, fmt.Errorf("empty endpoint in list")
|
||||
} else if e == "*" {
|
||||
endpoints = []string{"*"}
|
||||
break
|
||||
}
|
||||
endpoints = append(endpoints, e)
|
||||
}
|
||||
}
|
||||
|
||||
return &KeyEntry{
|
||||
PublicKey: b64Key,
|
||||
Name: name,
|
||||
Endpoints: endpoints,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (ap *APIPermissions) validateAndDecodeKey(entry *KeyEntry) (string, error) {
|
||||
keyBlob, err := base64.StdEncoding.DecodeString(entry.PublicKey)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("base64 decode: %w", err)
|
||||
}
|
||||
pub, err := ssh.ParsePublicKey(keyBlob)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("parse key: %w", err)
|
||||
}
|
||||
if pub.Type() != ssh.KeyAlgoED25519 {
|
||||
return "", fmt.Errorf("key is not Ed25519")
|
||||
}
|
||||
|
||||
return entry.PublicKey, nil
|
||||
}
|
||||
|
||||
func (ap *APIPermissions) Authorize(pubKey ssh.PublicKey, endpoint string) (bool, string, error) {
|
||||
if err := ap.ReloadIfNeeded(); err != nil {
|
||||
return false, "unknown", err
|
||||
}
|
||||
|
||||
ap.mutex.RLock()
|
||||
defer ap.mutex.RUnlock()
|
||||
|
||||
if wildcardPerm, exists := ap.permissions["*"]; exists {
|
||||
if len(wildcardPerm.Endpoints) == 1 && wildcardPerm.Endpoints[0] == "*" {
|
||||
return true, wildcardPerm.Name, nil
|
||||
}
|
||||
|
||||
for _, allowedEndpoint := range wildcardPerm.Endpoints {
|
||||
if allowedEndpoint == endpoint {
|
||||
return true, wildcardPerm.Name, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
keyString := string(ssh.MarshalAuthorizedKey(pubKey))
|
||||
parts := strings.SplitN(keyString, " ", 2)
|
||||
var base64Key string
|
||||
if len(parts) > 1 {
|
||||
base64Key = parts[1]
|
||||
} else {
|
||||
base64Key = parts[0]
|
||||
}
|
||||
|
||||
base64Key = strings.TrimSpace(base64Key)
|
||||
|
||||
perm, exists := ap.permissions[base64Key]
|
||||
if !exists {
|
||||
return false, "unknown", nil
|
||||
}
|
||||
|
||||
if len(perm.Endpoints) == 1 && perm.Endpoints[0] == "*" {
|
||||
return true, perm.Name, nil
|
||||
}
|
||||
|
||||
for _, allowedEndpoint := range perm.Endpoints {
|
||||
if allowedEndpoint == endpoint {
|
||||
return true, perm.Name, nil
|
||||
}
|
||||
}
|
||||
|
||||
return false, "unknown", nil
|
||||
}
|
||||
@@ -1,133 +0,0 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
const validB64 = "AAAAC3NzaC1lZDI1NTE5AAAAICy1v/Sn0kGhu1LXzCsnx3wlk5ESdncS66JWo13yeJod"
|
||||
|
||||
func TestParse(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
file string
|
||||
want map[string]*KeyPermission
|
||||
}{
|
||||
{
|
||||
name: "two fields only defaults",
|
||||
file: "ssh-ed25519 " + validB64 + "\n",
|
||||
want: map[string]*KeyPermission{
|
||||
validB64: {
|
||||
Name: "default",
|
||||
Endpoints: []string{"*"},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "extra whitespace collapsed and default endpoints",
|
||||
file: "ssh-ed25519 " + validB64 + " alice\n",
|
||||
want: map[string]*KeyPermission{
|
||||
validB64: {
|
||||
Name: "alice",
|
||||
Endpoints: []string{"*"},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "four fields full",
|
||||
file: "ssh-ed25519 " + validB64 + " bob /api/foo,/api/bar\n",
|
||||
want: map[string]*KeyPermission{
|
||||
validB64: {
|
||||
Name: "bob",
|
||||
Endpoints: []string{"/api/foo", "/api/bar"},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "comment lines ignored and multiple entries",
|
||||
file: "# header\n\nssh-ed25519 " + validB64 + " user1\nssh-ed25519 " + validB64 + " user2 /api/x\n",
|
||||
want: map[string]*KeyPermission{
|
||||
validB64: {
|
||||
Name: "user1",
|
||||
Endpoints: []string{"*"},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "three entries variety",
|
||||
file: "ssh-ed25519 " + validB64 + "\nssh-ed25519 " + validB64 + " alice /api/a,/api/b\nssh-ed25519 " + validB64 + " bob /api/c\n",
|
||||
want: map[string]*KeyPermission{
|
||||
validB64: {
|
||||
Name: "alice",
|
||||
Endpoints: []string{"*"},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "two entries w/ wildcard",
|
||||
file: "ssh-ed25519 " + validB64 + " alice /api/a\n* * * /api/b\n",
|
||||
want: map[string]*KeyPermission{
|
||||
validB64: {
|
||||
Name: "alice",
|
||||
Endpoints: []string{"/api/a"},
|
||||
},
|
||||
"*": {
|
||||
Name: "default",
|
||||
Endpoints: []string{"/api/b"},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "tags for everyone",
|
||||
file: "* * * /api/tags",
|
||||
want: map[string]*KeyPermission{
|
||||
"*": {
|
||||
Name: "default",
|
||||
Endpoints: []string{"/api/tags"},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "default name",
|
||||
file: "* * somename",
|
||||
want: map[string]*KeyPermission{
|
||||
"*": {
|
||||
Name: "somename",
|
||||
Endpoints: []string{"*"},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "unsupported key type",
|
||||
file: "ssh-rsa AAAAB3Nza...\n",
|
||||
want: map[string]*KeyPermission{},
|
||||
},
|
||||
{
|
||||
name: "bad base64",
|
||||
file: "ssh-ed25519 invalid@@@\n",
|
||||
want: map[string]*KeyPermission{},
|
||||
},
|
||||
{
|
||||
name: "just an asterix",
|
||||
file: "*\n",
|
||||
want: map[string]*KeyPermission{},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
perms := NewAPIPermissions()
|
||||
err := perms.parse(bytes.NewBufferString(tc.file))
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if len(perms.permissions) != len(tc.want) {
|
||||
t.Fatalf("got %d entries, want %d", len(perms.permissions), len(tc.want))
|
||||
}
|
||||
if !reflect.DeepEqual(perms.permissions, tc.want) {
|
||||
t.Errorf("got %+v, want %+v", perms.permissions, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
72
cmd/cmd.go
72
cmd/cmd.go
@@ -64,54 +64,37 @@ func ensureThinkingSupport(ctx context.Context, client *api.Client, name string)
|
||||
fmt.Fprintf(os.Stderr, "warning: model %q does not support thinking output\n", name)
|
||||
}
|
||||
|
||||
var errModelfileNotFound = errors.New("specified Modelfile wasn't found")
|
||||
|
||||
func getModelfileName(cmd *cobra.Command) (string, error) {
|
||||
filename, _ := cmd.Flags().GetString("file")
|
||||
|
||||
if filename == "" {
|
||||
filename = "Modelfile"
|
||||
}
|
||||
|
||||
absName, err := filepath.Abs(filename)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
_, err = os.Stat(absName)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return absName, nil
|
||||
}
|
||||
|
||||
func CreateHandler(cmd *cobra.Command, args []string) error {
|
||||
p := progress.NewProgress(os.Stderr)
|
||||
defer p.Stop()
|
||||
|
||||
var reader io.Reader
|
||||
|
||||
filename, err := getModelfileName(cmd)
|
||||
if os.IsNotExist(err) {
|
||||
if filename == "" {
|
||||
reader = strings.NewReader("FROM .\n")
|
||||
} else {
|
||||
return errModelfileNotFound
|
||||
}
|
||||
} else if err != nil {
|
||||
return err
|
||||
} else {
|
||||
f, err := os.Open(filename)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
reader = f
|
||||
defer f.Close()
|
||||
filename, err := cmd.Flags().GetString("file")
|
||||
if err != nil {
|
||||
return fmt.Errorf("error retrieving file flag: %w", err)
|
||||
}
|
||||
|
||||
modelfile, err := parser.ParseFile(reader)
|
||||
var r, fallback io.Reader
|
||||
switch filename {
|
||||
case "-":
|
||||
r = os.Stdin
|
||||
case "":
|
||||
filename = "Modelfile"
|
||||
fallback = strings.NewReader("FROM .")
|
||||
fallthrough
|
||||
default:
|
||||
r, err = os.Open(filename)
|
||||
if errors.Is(err, os.ErrNotExist) && fallback != nil {
|
||||
r = fallback
|
||||
} else if errors.Is(err, os.ErrNotExist) {
|
||||
return fmt.Errorf("%w: Modelfile %q does not exist, please create it or use --file to specify a different file", err, filename)
|
||||
} else if err != nil {
|
||||
return err
|
||||
} else {
|
||||
defer r.(*os.File).Close()
|
||||
}
|
||||
}
|
||||
|
||||
modelfile, err := parser.ParseFile(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -127,10 +110,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
|
||||
spinner.Stop()
|
||||
|
||||
req.Model = args[0]
|
||||
quantize, _ := cmd.Flags().GetString("quantize")
|
||||
if quantize != "" {
|
||||
req.Quantize = quantize
|
||||
}
|
||||
req.Quantize, _ = cmd.Flags().GetString("quantize")
|
||||
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
|
||||
350
cmd/cmd_test.go
350
cmd/cmd_test.go
@@ -3,10 +3,13 @@ package cmd
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -18,6 +21,13 @@ import (
|
||||
"github.com/ollama/ollama/types/model"
|
||||
)
|
||||
|
||||
func mockServer(t *testing.T, h http.HandlerFunc) {
|
||||
t.Helper()
|
||||
s := httptest.NewServer(h)
|
||||
t.Cleanup(s.Close)
|
||||
t.Setenv("OLLAMA_HOST", s.URL)
|
||||
}
|
||||
|
||||
func TestShowInfo(t *testing.T) {
|
||||
t.Run("bare details", func(t *testing.T) {
|
||||
var b bytes.Buffer
|
||||
@@ -351,101 +361,6 @@ func TestDeleteHandler(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetModelfileName(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
modelfileName string
|
||||
fileExists bool
|
||||
expectedName string
|
||||
expectedErr error
|
||||
}{
|
||||
{
|
||||
name: "no modelfile specified, no modelfile exists",
|
||||
modelfileName: "",
|
||||
fileExists: false,
|
||||
expectedName: "",
|
||||
expectedErr: os.ErrNotExist,
|
||||
},
|
||||
{
|
||||
name: "no modelfile specified, modelfile exists",
|
||||
modelfileName: "",
|
||||
fileExists: true,
|
||||
expectedName: "Modelfile",
|
||||
expectedErr: nil,
|
||||
},
|
||||
{
|
||||
name: "modelfile specified, no modelfile exists",
|
||||
modelfileName: "crazyfile",
|
||||
fileExists: false,
|
||||
expectedName: "",
|
||||
expectedErr: os.ErrNotExist,
|
||||
},
|
||||
{
|
||||
name: "modelfile specified, modelfile exists",
|
||||
modelfileName: "anotherfile",
|
||||
fileExists: true,
|
||||
expectedName: "anotherfile",
|
||||
expectedErr: nil,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
cmd := &cobra.Command{
|
||||
Use: "fakecmd",
|
||||
}
|
||||
cmd.Flags().String("file", "", "path to modelfile")
|
||||
|
||||
var expectedFilename string
|
||||
|
||||
if tt.fileExists {
|
||||
var fn string
|
||||
if tt.modelfileName != "" {
|
||||
fn = tt.modelfileName
|
||||
} else {
|
||||
fn = "Modelfile"
|
||||
}
|
||||
|
||||
tempFile, err := os.CreateTemp(t.TempDir(), fn)
|
||||
if err != nil {
|
||||
t.Fatalf("temp modelfile creation failed: %v", err)
|
||||
}
|
||||
defer tempFile.Close()
|
||||
|
||||
expectedFilename = tempFile.Name()
|
||||
err = cmd.Flags().Set("file", expectedFilename)
|
||||
if err != nil {
|
||||
t.Fatalf("couldn't set file flag: %v", err)
|
||||
}
|
||||
} else {
|
||||
expectedFilename = tt.expectedName
|
||||
if tt.modelfileName != "" {
|
||||
err := cmd.Flags().Set("file", tt.modelfileName)
|
||||
if err != nil {
|
||||
t.Fatalf("couldn't set file flag: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
actualFilename, actualErr := getModelfileName(cmd)
|
||||
|
||||
if actualFilename != expectedFilename {
|
||||
t.Errorf("expected filename: '%s' actual filename: '%s'", expectedFilename, actualFilename)
|
||||
}
|
||||
|
||||
if tt.expectedErr != os.ErrNotExist {
|
||||
if actualErr != tt.expectedErr {
|
||||
t.Errorf("expected err: %v actual err: %v", tt.expectedErr, actualErr)
|
||||
}
|
||||
} else {
|
||||
if !os.IsNotExist(actualErr) {
|
||||
t.Errorf("expected err: %v actual err: %v", tt.expectedErr, actualErr)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPushHandler(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -661,128 +576,165 @@ func TestListHandler(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestCreateHandler(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
modelName string
|
||||
modelFile string
|
||||
serverResponse map[string]func(w http.ResponseWriter, r *http.Request)
|
||||
expectedError string
|
||||
expectedOutput string
|
||||
cases := []struct {
|
||||
name string
|
||||
filename func(*testing.T) string
|
||||
|
||||
wantRequest api.CreateRequest
|
||||
wantErr error
|
||||
}{
|
||||
{
|
||||
name: "successful create",
|
||||
modelName: "test-model",
|
||||
modelFile: "FROM foo",
|
||||
serverResponse: map[string]func(w http.ResponseWriter, r *http.Request){
|
||||
"/api/create": func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
t.Errorf("expected POST request, got %s", r.Method)
|
||||
}
|
||||
name: "not exist",
|
||||
filename: func(*testing.T) string { return "not_exist" },
|
||||
wantErr: os.ErrNotExist,
|
||||
},
|
||||
{
|
||||
name: "stdin",
|
||||
filename: func(t *testing.T) string {
|
||||
r, w, err := os.Pipe()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
req := api.CreateRequest{}
|
||||
if _, err := w.WriteString("FROM test"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := w.Close(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
stdin := os.Stdin
|
||||
t.Cleanup(func() { os.Stdin = stdin })
|
||||
os.Stdin = r
|
||||
return "-"
|
||||
},
|
||||
wantRequest: api.CreateRequest{
|
||||
Model: "stdin",
|
||||
From: "test",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "default",
|
||||
filename: func(t *testing.T) string {
|
||||
t.Chdir(t.TempDir())
|
||||
f, err := os.Create("Modelfile")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
if _, err := f.WriteString("FROM test"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
return ""
|
||||
},
|
||||
wantRequest: api.CreateRequest{
|
||||
Model: "default",
|
||||
From: "test",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "default safetensors",
|
||||
filename: func(t *testing.T) string {
|
||||
t.Chdir(t.TempDir())
|
||||
f, err := os.Create("model.safetensors")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
if err := f.Truncate(1); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
return ""
|
||||
},
|
||||
wantRequest: api.CreateRequest{
|
||||
Model: "default_safetensors",
|
||||
Files: map[string]string{
|
||||
"model.safetensors": "sha256:6e340b9cffb37a989ca544e6bb780a2c78901d3fb33738768511a30617afa01d",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "file flag",
|
||||
filename: func(t *testing.T) string {
|
||||
f, err := os.CreateTemp(t.TempDir(), filepath.Base(t.Name()))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
if _, err := f.WriteString("FROM test"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
return f.Name()
|
||||
},
|
||||
wantRequest: api.CreateRequest{
|
||||
Model: "file_flag",
|
||||
From: "test",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "insecure path",
|
||||
filename: func(t *testing.T) string {
|
||||
t.Chdir(t.TempDir())
|
||||
if err := os.Symlink("../../../../../../nope", "model.safetensors"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
return ""
|
||||
},
|
||||
wantErr: fmt.Errorf("openat %s: path escapes from parent", "model.safetensors"),
|
||||
},
|
||||
}
|
||||
|
||||
var cmd cobra.Command
|
||||
cmd.SetContext(t.Context())
|
||||
cmd.Flags().String("file", "", "")
|
||||
cmd.Flags().String("quantize", "", "")
|
||||
for _, tt := range cases {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
mockServer(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
if r.URL.Path == "/api/create" {
|
||||
var req api.CreateRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if req.Model != "test-model" {
|
||||
t.Errorf("expected model name 'test-model', got %s", req.Name)
|
||||
if diff := cmp.Diff(tt.wantRequest, req); diff != "" {
|
||||
t.Errorf("Create request mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
|
||||
if req.From != "foo" {
|
||||
t.Errorf("expected from 'foo', got %s", req.From)
|
||||
}
|
||||
|
||||
responses := []api.ProgressResponse{
|
||||
{Status: "using existing layer sha256:56bb8bd477a519ffa694fc449c2413c6f0e1d3b1c88fa7e3c9d88d3ae49d4dcb"},
|
||||
{Status: "writing manifest"},
|
||||
{Status: "success"},
|
||||
}
|
||||
|
||||
for _, resp := range responses {
|
||||
if err := json.NewEncoder(w).Encode(resp); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
w.(http.Flusher).Flush()
|
||||
}
|
||||
},
|
||||
},
|
||||
expectedOutput: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
handler, ok := tt.serverResponse[r.URL.Path]
|
||||
if !ok {
|
||||
t.Errorf("unexpected request to %s", r.URL.Path)
|
||||
} else if strings.HasPrefix(r.URL.Path, "/api/blobs/") {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
} else {
|
||||
http.Error(w, "not found", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
handler(w, r)
|
||||
}))
|
||||
t.Setenv("OLLAMA_HOST", mockServer.URL)
|
||||
t.Cleanup(mockServer.Close)
|
||||
tempFile, err := os.CreateTemp(t.TempDir(), "modelfile")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer os.Remove(tempFile.Name())
|
||||
})
|
||||
|
||||
if _, err := tempFile.WriteString(tt.modelFile); err != nil {
|
||||
t.Fatal(err)
|
||||
var filename string
|
||||
if tt.filename != nil {
|
||||
filename = tt.filename(t)
|
||||
}
|
||||
if err := tempFile.Close(); err != nil {
|
||||
|
||||
if err := cmd.Flags().Set("file", filename); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
cmd := &cobra.Command{}
|
||||
cmd.Flags().String("file", "", "")
|
||||
if err := cmd.Flags().Set("file", tempFile.Name()); err != nil {
|
||||
if err := CreateHandler(&cmd, []string{filepath.Base(t.Name())}); err != tt.wantErr &&
|
||||
err.Error() != tt.wantErr.Error() &&
|
||||
!errors.Is(err, tt.wantErr) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
cmd.Flags().Bool("insecure", false, "")
|
||||
cmd.SetContext(t.Context())
|
||||
|
||||
// Redirect stderr to capture progress output
|
||||
oldStderr := os.Stderr
|
||||
r, w, _ := os.Pipe()
|
||||
os.Stderr = w
|
||||
|
||||
// Capture stdout for the "Model pushed" message
|
||||
oldStdout := os.Stdout
|
||||
outR, outW, _ := os.Pipe()
|
||||
os.Stdout = outW
|
||||
|
||||
err = CreateHandler(cmd, []string{tt.modelName})
|
||||
|
||||
// Restore stderr
|
||||
w.Close()
|
||||
os.Stderr = oldStderr
|
||||
// drain the pipe
|
||||
if _, err := io.ReadAll(r); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Restore stdout and get output
|
||||
outW.Close()
|
||||
os.Stdout = oldStdout
|
||||
stdout, _ := io.ReadAll(outR)
|
||||
|
||||
if tt.expectedError == "" {
|
||||
if err != nil {
|
||||
t.Errorf("expected no error, got %v", err)
|
||||
}
|
||||
|
||||
if tt.expectedOutput != "" {
|
||||
if got := string(stdout); got != tt.expectedOutput {
|
||||
t.Errorf("expected output %q, got %q", tt.expectedOutput, got)
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -58,7 +58,7 @@ func AMDGetGPUInfo() ([]RocmGPUInfo, error) {
|
||||
driverMajor, driverMinor, err := AMDDriverVersion()
|
||||
if err != nil {
|
||||
// TODO - if we see users crash and burn with the upstreamed kernel this can be adjusted to hard-fail rocm support and fallback to CPU
|
||||
slog.Warn("ollama recommends running the https://www.amd.com/en/support/linux-drivers", "error", err)
|
||||
slog.Warn("ollama recommends running the https://www.amd.com/en/support/download/linux-drivers.html", "error", err)
|
||||
}
|
||||
|
||||
// Determine if the user has already pre-selected which GPUs to look at, then ignore the others
|
||||
|
||||
@@ -19,7 +19,7 @@ diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m
|
||||
index a9eeebc6..110c9ece 100644
|
||||
--- a/ggml/src/ggml-metal/ggml-metal.m
|
||||
+++ b/ggml/src/ggml-metal/ggml-metal.m
|
||||
@@ -489,6 +489,7 @@ enum ggml_metal_kernel_type {
|
||||
@@ -489,6 +489,7 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
|
||||
GGML_METAL_KERNEL_TYPE_COS,
|
||||
GGML_METAL_KERNEL_TYPE_NEG,
|
||||
GGML_METAL_KERNEL_TYPE_SUM_ROWS,
|
||||
@@ -27,7 +27,7 @@ index a9eeebc6..110c9ece 100644
|
||||
GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,
|
||||
GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32,
|
||||
GGML_METAL_KERNEL_TYPE_ARGMAX,
|
||||
@@ -1436,6 +1437,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
||||
@@ -1436,6 +1437,7 @@ @implementation GGMLMetalClass
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COS, cos, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NEG, neg, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
|
||||
|
||||
27
llama/patches/0022-BF16-macos-version-guard.patch
Normal file
27
llama/patches/0022-BF16-macos-version-guard.patch
Normal file
@@ -0,0 +1,27 @@
|
||||
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
|
||||
From: Daniel Hiltgen <daniel@ollama.com>
|
||||
Date: Wed, 30 Jul 2025 08:43:46 -0700
|
||||
Subject: [PATCH] BF16 macos version guard
|
||||
|
||||
Only enable BF16 on supported MacOS versions (v14+)
|
||||
---
|
||||
ggml/src/ggml-metal/ggml-metal.m | 6 +++++-
|
||||
1 file changed, 5 insertions(+), 1 deletion(-)
|
||||
|
||||
diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m
|
||||
index 110c9ece..ab46f6e3 100644
|
||||
--- a/ggml/src/ggml-metal/ggml-metal.m
|
||||
+++ b/ggml/src/ggml-metal/ggml-metal.m
|
||||
@@ -89,7 +89,11 @@
|
||||
ctx->has_bfloat |= [ctx->mtl_device supportsFamily:MTLGPUFamilyApple6];
|
||||
|
||||
#if defined(GGML_METAL_USE_BF16)
|
||||
- ctx->use_bfloat = ctx->has_bfloat;
|
||||
+ if (@available(macOS 14.0, *)) {
|
||||
+ ctx->use_bfloat = ctx->has_bfloat;
|
||||
+ } else {
|
||||
+ ctx->use_bfloat = false;
|
||||
+ }
|
||||
#else
|
||||
ctx->use_bfloat = false;
|
||||
#endif
|
||||
@@ -89,7 +89,11 @@ static id<MTLDevice> ggml_backend_metal_device_acq(struct ggml_backend_metal_dev
|
||||
ctx->has_bfloat |= [ctx->mtl_device supportsFamily:MTLGPUFamilyApple6];
|
||||
|
||||
#if defined(GGML_METAL_USE_BF16)
|
||||
ctx->use_bfloat = ctx->has_bfloat;
|
||||
if (@available(macOS 14.0, *)) {
|
||||
ctx->use_bfloat = ctx->has_bfloat;
|
||||
} else {
|
||||
ctx->use_bfloat = false;
|
||||
}
|
||||
#else
|
||||
ctx->use_bfloat = false;
|
||||
#endif
|
||||
|
||||
152
parser/parser.go
152
parser/parser.go
@@ -7,6 +7,8 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/fs"
|
||||
"iter"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/user"
|
||||
@@ -148,31 +150,23 @@ func fileDigestMap(path string) (map[string]string, error) {
|
||||
}
|
||||
|
||||
var files []string
|
||||
if fi.IsDir() {
|
||||
fs, err := filesForModel(path)
|
||||
if !fi.IsDir() {
|
||||
files = []string{path}
|
||||
} else {
|
||||
root, err := os.OpenRoot(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer root.Close()
|
||||
|
||||
fs, err := filesForModel(root.FS())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, f := range fs {
|
||||
f, err := filepath.EvalSymlinks(f)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rel, err := filepath.Rel(path, f)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !filepath.IsLocal(rel) {
|
||||
return nil, fmt.Errorf("insecure path: %s", rel)
|
||||
}
|
||||
|
||||
files = append(files, f)
|
||||
files = append(files, filepath.Join(path, f))
|
||||
}
|
||||
} else {
|
||||
files = []string{path}
|
||||
}
|
||||
|
||||
var mu sync.Mutex
|
||||
@@ -218,67 +212,90 @@ func digestForFile(filename string) (string, error) {
|
||||
return fmt.Sprintf("sha256:%x", hash.Sum(nil)), nil
|
||||
}
|
||||
|
||||
func filesForModel(path string) ([]string, error) {
|
||||
detectContentType := func(path string) (string, error) {
|
||||
f, err := os.Open(path)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer f.Close()
|
||||
func detectContentType(fsys fs.FS, path string) (string, error) {
|
||||
f, err := fsys.Open(path)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
var b bytes.Buffer
|
||||
b.Grow(512)
|
||||
|
||||
if _, err := io.CopyN(&b, f, 512); err != nil && !errors.Is(err, io.EOF) {
|
||||
return "", err
|
||||
}
|
||||
|
||||
contentType, _, _ := strings.Cut(http.DetectContentType(b.Bytes()), ";")
|
||||
return contentType, nil
|
||||
bts := make([]byte, 512)
|
||||
n, err := io.ReadFull(f, bts)
|
||||
if errors.Is(err, io.ErrUnexpectedEOF) {
|
||||
bts = bts[:n]
|
||||
} else if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
glob := func(pattern, contentType string) ([]string, error) {
|
||||
matches, err := filepath.Glob(pattern)
|
||||
contentType, _, _ := strings.Cut(http.DetectContentType(bts), ";")
|
||||
return contentType, nil
|
||||
}
|
||||
|
||||
func matchFirst(fsys fs.FS, patternsContentTypes ...string) iter.Seq2[string, error] {
|
||||
return func(yield func(string, error) bool) {
|
||||
for i := 0; i < len(patternsContentTypes); i += 2 {
|
||||
pattern := patternsContentTypes[i]
|
||||
contentType := patternsContentTypes[i+1]
|
||||
matches, err := fs.Glob(fsys, pattern)
|
||||
if err != nil {
|
||||
if !yield("", err) {
|
||||
return
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if len(matches) > 0 {
|
||||
for _, match := range matches {
|
||||
if ct, err := detectContentType(fsys, match); err != nil {
|
||||
if !yield("", err) {
|
||||
return
|
||||
}
|
||||
} else if ct == contentType {
|
||||
if !yield(match, nil) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func collect[E any](it iter.Seq2[E, error]) (s []E, _ error) {
|
||||
for v, err := range it {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, match := range matches {
|
||||
if ct, err := detectContentType(match); err != nil {
|
||||
return nil, err
|
||||
} else if ct != contentType {
|
||||
return nil, fmt.Errorf("invalid content type: expected %s for %s", ct, match)
|
||||
}
|
||||
}
|
||||
|
||||
return matches, nil
|
||||
s = append(s, v)
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
var files []string
|
||||
if st, _ := glob(filepath.Join(path, "*.safetensors"), "application/octet-stream"); len(st) > 0 {
|
||||
func filesForModel(fsys fs.FS) ([]string, error) {
|
||||
files, err := collect(matchFirst(
|
||||
fsys,
|
||||
// safetensors files might be unresolved git lfs references; skip if they are
|
||||
// covers model-x-of-y.safetensors, model.fp32-x-of-y.safetensors, model.safetensors
|
||||
files = append(files, st...)
|
||||
} else if pt, _ := glob(filepath.Join(path, "pytorch_model*.bin"), "application/zip"); len(pt) > 0 {
|
||||
"*.safetensors", "application/octet-stream",
|
||||
// pytorch files might also be unresolved git lfs references; skip if they are
|
||||
// covers pytorch_model-x-of-y.bin, pytorch_model.fp32-x-of-y.bin, pytorch_model.bin
|
||||
files = append(files, pt...)
|
||||
} else if pt, _ := glob(filepath.Join(path, "consolidated*.pth"), "application/zip"); len(pt) > 0 {
|
||||
"pytorch_model*.bin", "application/zip",
|
||||
// pytorch files might also be unresolved git lfs references; skip if they are
|
||||
// covers consolidated.x.pth, consolidated.pth
|
||||
files = append(files, pt...)
|
||||
} else if gg, _ := glob(filepath.Join(path, "*.gguf"), "application/octet-stream"); len(gg) > 0 {
|
||||
"consolidated*.pth", "application/zip",
|
||||
// covers gguf files ending in .gguf
|
||||
files = append(files, gg...)
|
||||
} else if gg, _ := glob(filepath.Join(path, "*.bin"), "application/octet-stream"); len(gg) > 0 {
|
||||
"*.gguf", "application/octet-stream",
|
||||
// covers gguf files ending in .bin
|
||||
files = append(files, gg...)
|
||||
} else {
|
||||
return nil, ErrModelNotFound
|
||||
"*.bin", "application/octet-stream",
|
||||
))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// add configuration files, json files are detected as text/plain
|
||||
js, err := glob(filepath.Join(path, "*.json"), "text/plain")
|
||||
js, err := collect(matchFirst(fsys, "*.json", "text/plain"))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -286,7 +303,7 @@ func filesForModel(path string) ([]string, error) {
|
||||
|
||||
// bert models require a nested config.json
|
||||
// TODO(mxyng): merge this with the glob above
|
||||
js, err = glob(filepath.Join(path, "**/*.json"), "text/plain")
|
||||
js, err = collect(matchFirst(fsys, "**/*.json", "text/plain"))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -296,14 +313,15 @@ func filesForModel(path string) ([]string, error) {
|
||||
if !slices.ContainsFunc(files, func(s string) bool {
|
||||
return slices.Contains(strings.Split(s, string(os.PathSeparator)), "tokenizer.json")
|
||||
}) {
|
||||
if tks, _ := glob(filepath.Join(path, "tokenizer.model"), "application/octet-stream"); len(tks) > 0 {
|
||||
tokenizers, err := collect(matchFirst(fsys,
|
||||
// add tokenizer.model if it exists, tokenizer.json is automatically picked up by the previous glob
|
||||
// tokenizer.model might be a unresolved git lfs reference; error if it is
|
||||
files = append(files, tks...)
|
||||
} else if tks, _ := glob(filepath.Join(path, "**/tokenizer.model"), "text/plain"); len(tks) > 0 {
|
||||
// some times tokenizer.model is in a subdirectory (e.g. meta-llama/Meta-Llama-3-8B)
|
||||
files = append(files, tks...)
|
||||
"tokenizer.model", "application/octet-stream",
|
||||
))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
files = append(files, tokenizers...)
|
||||
}
|
||||
|
||||
return files, nil
|
||||
|
||||
@@ -28,7 +28,6 @@ import (
|
||||
"golang.org/x/sync/errgroup"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/auth"
|
||||
"github.com/ollama/ollama/discover"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
@@ -56,8 +55,6 @@ var mode string = gin.DebugMode
|
||||
type Server struct {
|
||||
addr net.Addr
|
||||
sched *Scheduler
|
||||
|
||||
perms *auth.APIPermissions
|
||||
}
|
||||
|
||||
func init() {
|
||||
@@ -72,38 +69,6 @@ func init() {
|
||||
gin.SetMode(mode)
|
||||
}
|
||||
|
||||
func loggedFormatter(param gin.LogFormatterParams) string {
|
||||
var statusColor, methodColor, resetColor string
|
||||
if param.IsOutputColor() {
|
||||
statusColor = param.StatusCodeColor()
|
||||
methodColor = param.MethodColor()
|
||||
resetColor = param.ResetColor()
|
||||
}
|
||||
|
||||
if param.Latency > time.Minute {
|
||||
param.Latency = param.Latency.Truncate(time.Second)
|
||||
}
|
||||
|
||||
username := "default"
|
||||
if userVal, exists := param.Keys["username"]; exists {
|
||||
if name, ok := userVal.(string); ok {
|
||||
username = name
|
||||
}
|
||||
}
|
||||
|
||||
return fmt.Sprintf(
|
||||
"[Ollama] %s |%s %3d %s| %13v | %15s | %-20s |%s %-7s %s %#v\n%s",
|
||||
param.TimeStamp.Format("2006/01/02 - 15:04:05"),
|
||||
statusColor, param.StatusCode, resetColor,
|
||||
param.Latency,
|
||||
param.ClientIP,
|
||||
username,
|
||||
methodColor, param.Method, resetColor,
|
||||
param.Path,
|
||||
param.ErrorMessage,
|
||||
)
|
||||
}
|
||||
|
||||
var (
|
||||
errRequired = errors.New("is required")
|
||||
errBadTemplate = errors.New("template error")
|
||||
@@ -1146,43 +1111,6 @@ func allowedHost(host string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func allowedEndpointsMiddleware(perms *auth.APIPermissions) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
if !envconfig.UseAuth() || (c.Request.Method == "HEAD" && c.Request.URL.Path == "/") {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
token := strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer ")
|
||||
if token == "" {
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "unauthorized"})
|
||||
return
|
||||
}
|
||||
|
||||
pubKey, err := auth.Authenticate(token, fmt.Sprintf("%s,%s", c.Request.Method, c.Request.RequestURI))
|
||||
if err != nil {
|
||||
slog.Error("authentication error", "error", err)
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "unauthorized"})
|
||||
return
|
||||
}
|
||||
|
||||
authorized, name, err := perms.Authorize(pubKey, c.Request.URL.Path)
|
||||
c.Set("username", name)
|
||||
if err != nil {
|
||||
slog.Error("authorization error", "error", err)
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "unauthorized"})
|
||||
return
|
||||
}
|
||||
|
||||
if !authorized {
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "unauthorized"})
|
||||
return
|
||||
}
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func allowedHostsMiddleware(addr net.Addr) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
if addr == nil {
|
||||
@@ -1249,13 +1177,10 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
|
||||
}
|
||||
corsConfig.AllowOrigins = envconfig.AllowedOrigins()
|
||||
|
||||
r := gin.New()
|
||||
r := gin.Default()
|
||||
r.HandleMethodNotAllowed = true
|
||||
r.Use(
|
||||
gin.LoggerWithFormatter(loggedFormatter),
|
||||
gin.Recovery(),
|
||||
cors.New(corsConfig),
|
||||
allowedEndpointsMiddleware(s.perms),
|
||||
allowedHostsMiddleware(s.addr),
|
||||
)
|
||||
|
||||
@@ -1265,7 +1190,7 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
|
||||
r.HEAD("/api/version", func(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"version": version.Version}) })
|
||||
r.GET("/api/version", func(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"version": version.Version}) })
|
||||
|
||||
// Local model cache management
|
||||
// Local model cache management (new implementation is at end of function)
|
||||
r.POST("/api/pull", s.PullHandler)
|
||||
r.POST("/api/push", s.PushHandler)
|
||||
r.HEAD("/api/tags", s.ListHandler)
|
||||
@@ -1297,7 +1222,7 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
|
||||
// wrap old with new
|
||||
rs := ®istry.Local{
|
||||
Client: rc,
|
||||
Logger: slog.Default(),
|
||||
Logger: slog.Default(), // TODO(bmizerany): Take a logger, do not use slog.Default()
|
||||
Fallback: r,
|
||||
|
||||
Prune: PruneLayers,
|
||||
@@ -1342,12 +1267,6 @@ func Serve(ln net.Listener) error {
|
||||
|
||||
s := &Server{addr: ln.Addr()}
|
||||
|
||||
if envconfig.UseAuth() {
|
||||
perms := auth.NewAPIPermissions()
|
||||
perms.ReloadIfNeeded()
|
||||
s.perms = perms
|
||||
}
|
||||
|
||||
var rc *ollama.Registry
|
||||
if useClient2 {
|
||||
var err error
|
||||
|
||||
Reference in New Issue
Block a user