Compare commits
40 Commits
jyan/local
...
jyan/local
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f30b54209c | ||
|
|
e39be4f63a | ||
|
|
b8c3d54f7a | ||
|
|
c8434b0e69 | ||
|
|
65658e4077 | ||
|
|
b29382b86f | ||
|
|
2efe2013a1 | ||
|
|
5c3786f4d5 | ||
|
|
33848ad10f | ||
|
|
ff06a2916d | ||
|
|
d923a59356 | ||
|
|
2b42ad5754 | ||
|
|
e3253e5469 | ||
|
|
35b49739ec | ||
|
|
bd8596d32b | ||
|
|
b85705162f | ||
|
|
d62a3a1e2b | ||
|
|
de48cd681f | ||
|
|
5d0e078057 | ||
|
|
8d5739b833 | ||
|
|
b5ff0ed4ff | ||
|
|
857054f9fa | ||
|
|
6dd9be55e2 | ||
|
|
d70707a668 | ||
|
|
c88774ffeb | ||
|
|
34d197000d | ||
|
|
6c0a8379f6 | ||
|
|
163ee9a8b0 | ||
|
|
de7b2f3948 | ||
|
|
f27c66fb0c | ||
|
|
a238191798 | ||
|
|
6436c7a375 | ||
|
|
896a15874e | ||
|
|
56008688a1 | ||
|
|
d14d38e940 | ||
|
|
03df02883d | ||
|
|
ae49abf80a | ||
|
|
2c450502db | ||
|
|
46b76aeb46 | ||
|
|
0e01da82d6 |
@@ -7,6 +7,7 @@ import (
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
@@ -26,7 +27,7 @@ func privateKey() (ssh.Signer, error) {
|
||||
|
||||
keyPath := filepath.Join(home, ".ollama", defaultPrivateKey)
|
||||
privateKeyFile, err := os.ReadFile(keyPath)
|
||||
if os.IsNotExist(err) {
|
||||
if errors.Is(err, os.ErrNotExist) {
|
||||
err := initializeKeypair()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -50,7 +51,7 @@ func GetPublicKey() (ssh.PublicKey, error) {
|
||||
|
||||
pubkeyPath := filepath.Join(home, ".ollama", defaultPrivateKey+".pub")
|
||||
pubKeyFile, err := os.ReadFile(pubkeyPath)
|
||||
if os.IsNotExist(err) {
|
||||
if errors.Is(err, os.ErrNotExist) {
|
||||
// try from privateKey
|
||||
privateKey, err := privateKey()
|
||||
if err != nil {
|
||||
@@ -113,7 +114,7 @@ func initializeKeypair() error {
|
||||
pubKeyPath := filepath.Join(home, ".ollama", "id_ed25519.pub")
|
||||
|
||||
_, err = os.Stat(privKeyPath)
|
||||
if os.IsNotExist(err) {
|
||||
if errors.Is(err, os.ErrNotExist) {
|
||||
fmt.Printf("Couldn't find '%s'. Generating new private key.\n", privKeyPath)
|
||||
cryptoPublicKey, cryptoPrivateKey, err := ed25519.GenerateKey(rand.Reader)
|
||||
if err != nil {
|
||||
|
||||
96
cmd/cmd.go
96
cmd/cmd.go
@@ -5,7 +5,6 @@ import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
@@ -111,7 +110,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
|
||||
path = tempfile
|
||||
}
|
||||
|
||||
digest, err := createBlob(cmd, client, path)
|
||||
digest, err := createBlob(cmd, path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -264,7 +263,7 @@ func tempZipFiles(path string) (string, error) {
|
||||
|
||||
var ErrBlobExists = errors.New("blob exists")
|
||||
|
||||
func createBlob(cmd *cobra.Command, client *api.Client, path string) (string, error) {
|
||||
func createBlob(cmd *cobra.Command, path string) (string, error) {
|
||||
bin, err := os.Open(path)
|
||||
if err != nil {
|
||||
return "", err
|
||||
@@ -282,40 +281,23 @@ func createBlob(cmd *cobra.Command, client *api.Client, path string) (string, er
|
||||
|
||||
digest := fmt.Sprintf("sha256:%x", hash.Sum(nil))
|
||||
|
||||
// We check if we can find the models directory locally
|
||||
// If we can, we return the path to the directory
|
||||
// If we can't, we return an error
|
||||
// If the blob exists already, we return the digest
|
||||
dest, err := getLocalPath(cmd.Context(), digest)
|
||||
|
||||
// Use our new CreateBlob request which will include the file path
|
||||
// The server checks for that file and if the server is local, it will copy the file over
|
||||
// If the local copy fails, the server will continue to the default local copy
|
||||
// If that fails, it will continue with the server POST
|
||||
err = CreateBlob(cmd.Context(), path, digest, bin)
|
||||
if errors.Is(err, ErrBlobExists) {
|
||||
return digest, nil
|
||||
}
|
||||
|
||||
// Successfully found the model directory
|
||||
if err == nil {
|
||||
// Copy blob in via OS specific copy
|
||||
// Linux errors out to use io.copy
|
||||
err = localCopy(path, dest)
|
||||
if err == nil {
|
||||
return digest, nil
|
||||
}
|
||||
|
||||
// Default copy using io.copy
|
||||
err = defaultCopy(path, dest)
|
||||
if err == nil {
|
||||
return digest, nil
|
||||
}
|
||||
}
|
||||
|
||||
// If at any point copying the blob over locally fails, we default to the copy through the server
|
||||
if err = client.CreateBlob(cmd.Context(), digest, bin); err != nil {
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return digest, nil
|
||||
}
|
||||
|
||||
func getLocalPath(ctx context.Context, digest string) (string, error) {
|
||||
func CreateBlob(ctx context.Context, src, digest string, r *os.File) (error) {
|
||||
ollamaHost := envconfig.Host
|
||||
|
||||
client := http.DefaultClient
|
||||
@@ -324,75 +306,37 @@ func getLocalPath(ctx context.Context, digest string) (string, error) {
|
||||
Host: net.JoinHostPort(ollamaHost.Host, ollamaHost.Port),
|
||||
}
|
||||
|
||||
data, err := json.Marshal(digest)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
reqBody := bytes.NewReader(data)
|
||||
path := fmt.Sprintf("/api/blobs/%s", digest)
|
||||
requestURL := base.JoinPath(path)
|
||||
request, err := http.NewRequestWithContext(ctx, http.MethodPost, requestURL.String(), reqBody)
|
||||
request, err := http.NewRequestWithContext(ctx, http.MethodPost, requestURL.String(), r)
|
||||
if err != nil {
|
||||
return "", err
|
||||
return err
|
||||
}
|
||||
|
||||
authz, err := api.Authorization(ctx, request)
|
||||
if err != nil {
|
||||
return "", err
|
||||
return err
|
||||
}
|
||||
|
||||
request.Header.Set("Authorization", authz)
|
||||
request.Header.Set("User-Agent", fmt.Sprintf("ollama/%s (%s %s) Go/%s", version.Version, runtime.GOARCH, runtime.GOOS, runtime.Version()))
|
||||
request.Header.Set("X-Redirect-Create", "1")
|
||||
request.Header.Set("X-Ollama-File", src)
|
||||
|
||||
resp, err := client.Do(request)
|
||||
if err != nil {
|
||||
return "", err
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode == http.StatusTemporaryRedirect {
|
||||
dest := resp.Header.Get("LocalLocation")
|
||||
|
||||
return dest, nil
|
||||
}
|
||||
return "", ErrBlobExists
|
||||
}
|
||||
|
||||
func defaultCopy(path string, dest string) error {
|
||||
// This function should be called if the server is local
|
||||
// It should find the model directory, copy the blob over, and return the digest
|
||||
dirPath := filepath.Dir(dest)
|
||||
|
||||
if err := os.MkdirAll(dirPath, 0o755); err != nil {
|
||||
return err
|
||||
if resp.StatusCode == http.StatusCreated {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Copy blob over
|
||||
sourceFile, err := os.Open(path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not open source file: %v", err)
|
||||
}
|
||||
defer sourceFile.Close()
|
||||
|
||||
destFile, err := os.Create(dest)
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not create destination file: %v", err)
|
||||
}
|
||||
defer destFile.Close()
|
||||
|
||||
_, err = io.CopyBuffer(destFile, sourceFile, make([]byte, 4*1024*1024))
|
||||
if err != nil {
|
||||
return fmt.Errorf("error copying file: %v", err)
|
||||
if resp.StatusCode == http.StatusOK {
|
||||
return ErrBlobExists
|
||||
}
|
||||
|
||||
err = destFile.Sync()
|
||||
if err != nil {
|
||||
return fmt.Errorf("error flushing file: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
return err
|
||||
}
|
||||
|
||||
func RunHandler(cmd *cobra.Command, args []string) error {
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package cmd
|
||||
package server
|
||||
|
||||
import (
|
||||
"os"
|
||||
@@ -1,4 +1,4 @@
|
||||
package cmd
|
||||
package server
|
||||
|
||||
import "errors"
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
//go:build windows
|
||||
// +build windows
|
||||
|
||||
package cmd
|
||||
package server
|
||||
|
||||
import (
|
||||
"os"
|
||||
@@ -942,10 +942,13 @@ func (s *Server) CreateBlobHandler(c *gin.Context) {
|
||||
c.Status(http.StatusOK)
|
||||
return
|
||||
}
|
||||
if c.GetHeader("X-Redirect-Create") == "1" && s.isLocal(c) {
|
||||
c.Header("LocalLocation", path)
|
||||
c.Status(http.StatusTemporaryRedirect)
|
||||
return
|
||||
|
||||
if c.GetHeader("X-Ollama-File") != "" && s.isLocal(c) {
|
||||
err = localBlobCopy(c.GetHeader("X-Ollama-File"), path)
|
||||
if err == nil {
|
||||
c.Status(http.StatusCreated)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
layer, err := NewLayer(c.Request.Body, "")
|
||||
@@ -962,6 +965,25 @@ func (s *Server) CreateBlobHandler(c *gin.Context) {
|
||||
c.Status(http.StatusCreated)
|
||||
}
|
||||
|
||||
func localBlobCopy (src, dest string) error {
|
||||
_, err := os.Stat(src)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = localCopy(src, dest)
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
err = defaultCopy(src, dest)
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("failed to copy blob")
|
||||
}
|
||||
|
||||
func (s *Server) isLocal(c *gin.Context) bool {
|
||||
if authz := c.GetHeader("Authorization"); authz != "" {
|
||||
parts := strings.Split(authz, ":")
|
||||
@@ -1010,6 +1032,41 @@ func (s *Server) isLocal(c *gin.Context) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func defaultCopy(path string, dest string) error {
|
||||
// This function should be called if the server is local
|
||||
// It should find the model directory, copy the blob over, and return the digest
|
||||
dirPath := filepath.Dir(dest)
|
||||
|
||||
if err := os.MkdirAll(dirPath, 0o755); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Copy blob over
|
||||
sourceFile, err := os.Open(path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not open source file: %v", err)
|
||||
}
|
||||
defer sourceFile.Close()
|
||||
|
||||
destFile, err := os.Create(dest)
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not create destination file: %v", err)
|
||||
}
|
||||
defer destFile.Close()
|
||||
|
||||
_, err = io.CopyBuffer(destFile, sourceFile, make([]byte, 4*1024*1024))
|
||||
if err != nil {
|
||||
return fmt.Errorf("error copying file: %v", err)
|
||||
}
|
||||
|
||||
err = destFile.Sync()
|
||||
if err != nil {
|
||||
return fmt.Errorf("error flushing file: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func isLocalIP(ip netip.Addr) bool {
|
||||
if interfaces, err := net.Interfaces(); err == nil {
|
||||
for _, iface := range interfaces {
|
||||
|
||||
@@ -535,6 +535,7 @@ func TestIsLocalReal(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
clientPubLoc := t.TempDir()
|
||||
t.Setenv("HOME", clientPubLoc)
|
||||
t.Setenv("USERPROFILE", clientPubLoc)
|
||||
|
||||
_, err := auth.GetPublicKey()
|
||||
if err != nil {
|
||||
@@ -572,6 +573,7 @@ func TestIsLocalReal(t *testing.T) {
|
||||
t.Run("different server pubkey", func(t *testing.T) {
|
||||
serverPubLoc := t.TempDir()
|
||||
t.Setenv("HOME", serverPubLoc)
|
||||
t.Setenv("USERPROFILE", serverPubLoc)
|
||||
_, err := auth.GetPublicKey()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
|
||||
Reference in New Issue
Block a user