Compare commits

..

5 Commits

Author SHA1 Message Date
Josh Yan
0e1ec461f9 import 2024-08-28 11:18:23 -07:00
Josh Yan
52ef79bb7d last lint (hopefully) 2024-08-28 11:12:39 -07:00
Josh Yan
800edd7884 lint again 2024-08-28 11:10:03 -07:00
Josh Yan
01b20fe6f1 lint 2024-08-28 11:07:43 -07:00
Josh Yan
340162fbc3 convert progress 2024-08-28 10:54:52 -07:00
18 changed files with 311 additions and 507 deletions

View File

@@ -124,6 +124,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
}
bars := make(map[string]*progress.Bar)
var convertSpin *progress.Spinner
fn := func(resp api.ProgressResponse) error {
if resp.Digest != "" {
spinner.Stop()
@@ -136,6 +137,16 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
}
bar.Set(resp.Completed)
} else if strings.Contains(resp.Status, "converting") {
spinner.Stop()
if convertSpin != nil {
convertSpin.SetMessage(resp.Status)
} else {
status = resp.Status
convertSpin = progress.NewSpinner(resp.Status)
p.Add("convert", convertSpin)
}
} else if status != resp.Status {
spinner.Stop()

View File

@@ -9,6 +9,7 @@ import (
"log/slog"
"strings"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/llm"
)
@@ -79,12 +80,12 @@ func (ModelParameters) specialTokenTypes() []string {
}
}
func (ModelParameters) writeFile(ws io.WriteSeeker, kv llm.KV, ts []llm.Tensor) error {
return llm.WriteGGUF(ws, kv, ts)
func (ModelParameters) writeFile(ws io.WriteSeeker, kv llm.KV, ts []llm.Tensor, fn func(api.ProgressResponse)) error {
return llm.WriteGGUF(ws, kv, ts, fn)
}
func (AdapterParameters) writeFile(ws io.WriteSeeker, kv llm.KV, ts []llm.Tensor) error {
return llm.WriteGGUF(ws, kv, ts)
func (AdapterParameters) writeFile(ws io.WriteSeeker, kv llm.KV, ts []llm.Tensor, fn func(api.ProgressResponse)) error {
return llm.WriteGGUF(ws, kv, ts, fn)
}
type ModelConverter interface {
@@ -99,7 +100,7 @@ type ModelConverter interface {
// specialTokenTypes returns any special token types the model uses
specialTokenTypes() []string
// writeFile writes the model to the provided io.WriteSeeker
writeFile(io.WriteSeeker, llm.KV, []llm.Tensor) error
writeFile(io.WriteSeeker, llm.KV, []llm.Tensor, func(api.ProgressResponse)) error
}
type moreParser interface {
@@ -115,10 +116,10 @@ type AdapterConverter interface {
// See [strings.Replacer](https://pkg.go.dev/strings#Replacer) for details
Replacements() []string
writeFile(io.WriteSeeker, llm.KV, []llm.Tensor) error
writeFile(io.WriteSeeker, llm.KV, []llm.Tensor, func(api.ProgressResponse)) error
}
func ConvertAdapter(fsys fs.FS, ws io.WriteSeeker, baseKV llm.KV) error {
func ConvertAdapter(fsys fs.FS, ws io.WriteSeeker, baseKV llm.KV, fn func(api.ProgressResponse)) error {
bts, err := fs.ReadFile(fsys, "adapter_config.json")
if err != nil {
return err
@@ -153,14 +154,17 @@ func ConvertAdapter(fsys fs.FS, ws io.WriteSeeker, baseKV llm.KV) error {
return err
}
return conv.writeFile(ws, conv.KV(baseKV), conv.Tensors(ts))
fn(api.ProgressResponse{
Status: fmt.Sprintf("converting adapter 0%%"),
})
return conv.writeFile(ws, conv.KV(baseKV), conv.Tensors(ts), fn)
}
// Convert writes an Ollama compatible model to the provided io.WriteSeeker based on configurations
// and files it finds in the input path.
// Supported input model formats include safetensors.
// Supported input tokenizers files include tokenizer.json (preferred) and tokenizer.model.
func ConvertModel(fsys fs.FS, ws io.WriteSeeker) error {
func ConvertModel(fsys fs.FS, ws io.WriteSeeker, fn func(api.ProgressResponse)) error {
bts, err := fs.ReadFile(fsys, "config.json")
if err != nil {
return err
@@ -224,5 +228,8 @@ func ConvertModel(fsys fs.FS, ws io.WriteSeeker) error {
return err
}
return conv.writeFile(ws, conv.KV(t), conv.Tensors(ts))
fn(api.ProgressResponse{
Status: fmt.Sprintf("converting model 0%%"),
})
return conv.writeFile(ws, conv.KV(t), conv.Tensors(ts), fn)
}

View File

@@ -19,6 +19,7 @@ import (
"golang.org/x/exp/maps"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/llm"
)
@@ -31,7 +32,7 @@ func convertFull(t *testing.T, fsys fs.FS) (*os.File, llm.KV, llm.Tensors) {
}
defer f.Close()
if err := ConvertModel(fsys, f); err != nil {
if err := ConvertModel(fsys, f, func(api.ProgressResponse) {}); err != nil {
t.Fatal(err)
}
@@ -89,7 +90,7 @@ func TestMain(m *testing.M) {
os.Exit(m.Run())
}
func TestConvertModel(t *testing.T) {
func TestConvertFull(t *testing.T) {
cases := []string{
"Meta-Llama-3-8B-Instruct",
"Meta-Llama-3.1-8B-Instruct",
@@ -150,7 +151,7 @@ func TestConvertInvalidDatatype(t *testing.T) {
tempDir := t.TempDir()
generateSafetensorTestData(t, tempDir)
err = ConvertModel(os.DirFS(tempDir), f)
err = ConvertModel(os.DirFS(tempDir), f, func(api.ProgressResponse) {})
if err == nil || err.Error() != "unsupported safetensors model" {
t.Errorf("expected error but didn't get one")
}
@@ -287,7 +288,7 @@ func TestConvertAdapter(t *testing.T) {
tempDir := t.TempDir()
generateLoraTestData(t, tempDir)
if err = ConvertAdapter(os.DirFS(tempDir), f, c.BaseKV); err != nil {
if err = ConvertAdapter(os.DirFS(tempDir), f, c.BaseKV, func(api.ProgressResponse) {}); err != nil {
t.Fatal(err)
}

View File

@@ -100,21 +100,8 @@ func parseTokenizer(fsys fs.FS, specialTokenTypes []string) (*Tokenizer, error)
}
if template, ok := p["chat_template"]; ok {
var s []struct {
Name string `json:"name"`
Template string `json:"template"`
}
if err := json.Unmarshal(template, &t.Template); err == nil {
// noop
} else if err := json.Unmarshal(template, &s); err == nil {
for _, e := range s {
if e.Name == "default" {
t.Template = e.Template
break
}
}
} else {
return nil, fmt.Errorf("invalid chat_template: %w", err)
if err := json.Unmarshal(template, &t.Template); err != nil {
return nil, err
}
}
@@ -154,6 +141,7 @@ func parseTokenizer(fsys fs.FS, specialTokenTypes []string) (*Tokenizer, error)
}
type tokenizer struct {
Version string `json:"version"`
AddedTokens []token `json:"added_tokens"`
Model struct {
Type string `json:"type"`
@@ -251,7 +239,7 @@ func parseVocabulary(fsys fs.FS) (*Vocabulary, error) {
return pattern.Func(fsys)
}
return nil, errors.New("unknown tokenizer format")
return nil, errors.New("unknown tensor format")
}
type SpecialVocabulary struct {

View File

@@ -1,208 +0,0 @@
package convert
import (
"io"
"io/fs"
"os"
"path/filepath"
"strings"
"testing"
"github.com/google/go-cmp/cmp"
)
// createTokenizerFS writes every entry of files into dir, using the map key as
// the file name and the reader as the file contents, and returns dir as an
// fs.FS. Any create or copy error aborts the calling test via t.Fatalf.
func createTokenizerFS(t *testing.T, dir string, files map[string]io.Reader) fs.FS {
	t.Helper()

	// writeOne copies a single reader to dir/name. Keeping it in its own
	// function scope makes the deferred Close run once per file rather than
	// when createTokenizerFS returns.
	writeOne := func(name string, src io.Reader) error {
		dst, err := os.Create(filepath.Join(dir, name))
		if err != nil {
			return err
		}
		defer dst.Close()

		_, err = io.Copy(dst, src)
		return err
	}

	for name, src := range files {
		if err := writeOne(name, src); err != nil {
			t.Fatalf("unexpected error: %v", err)
		}
	}

	return os.DirFS(dir)
}
// TestParseTokenizer runs parseTokenizer against small on-disk fixtures built
// with createTokenizerFS and compares the returned value with the expected
// *Tokenizer using cmp.Diff. Each case is executed as its own subtest.
func TestParseTokenizer(t *testing.T) {
cases := []struct {
// name is the subtest name passed to t.Run.
name string
// fsys holds the fixture files handed to parseTokenizer.
fsys fs.FS
// specialTokenTypes is forwarded unchanged as parseTokenizer's second argument.
specialTokenTypes []string
// want is the exact Tokenizer the call is expected to return.
want *Tokenizer
}{
// chat_template given as a plain JSON string: the string is expected
// verbatim in Tokenizer.Template.
{
name: "string chat template",
fsys: createTokenizerFS(t, t.TempDir(), map[string]io.Reader{
"tokenizer.json": strings.NewReader(`{}`),
"tokenizer_config.json": strings.NewReader(`{
"chat_template": "<default template>"
}`),
}),
want: &Tokenizer{
Vocabulary: &Vocabulary{Model: "gpt2"},
Pre: "default",
Template: "<default template>",
},
},
// chat_template given as a list of {name, template} objects: the entry
// named "default" is the one expected in Tokenizer.Template.
{
name: "list chat template",
fsys: createTokenizerFS(t, t.TempDir(), map[string]io.Reader{
"tokenizer.json": strings.NewReader(`{}`),
"tokenizer_config.json": strings.NewReader(`{
"chat_template": [
{
"name": "default",
"template": "<default template>"
},
{
"name": "tools",
"template": "<tools template>"
}
]
}`),
}),
want: &Tokenizer{
Vocabulary: &Vocabulary{Model: "gpt2"},
Pre: "default",
Template: "<default template>",
},
},
// Only tokenizer.json is provided, with one non-special added token and no
// model vocab. The expectation is a single token with score 999 and type 4
// (NOTE(review): 4 presumably denotes a user-defined token type — confirm).
{
name: "added tokens",
fsys: createTokenizerFS(t, t.TempDir(), map[string]io.Reader{
"tokenizer.json": strings.NewReader(`{
"added_tokens": [
{
"id": 999,
"content": "<unused999>",
"special": false
}
]
}`),
}),
want: &Tokenizer{
Vocabulary: &Vocabulary{
Model: "gpt2",
Tokens: []string{"<unused999>"},
Scores: []float32{999},
Types: []int32{4},
},
Pre: "default",
},
},
// The added token is also present in model.vocab with the same id. The
// expected vocabulary lists it once, with type 3
// (NOTE(review): 3 presumably denotes a control/special token — confirm).
{
name: "added tokens overlap vocab",
fsys: createTokenizerFS(t, t.TempDir(), map[string]io.Reader{
"tokenizer.json": strings.NewReader(`{
"added_tokens": [
{
"id": 0,
"content": "<pad>",
"special": true
}
],
"model": {
"vocab": {
"<pad>": 0
}
}
}`),
}),
want: &Tokenizer{
Vocabulary: &Vocabulary{
Model: "gpt2",
Tokens: []string{"<pad>"},
Scores: []float32{0},
Types: []int32{3},
},
Pre: "default",
},
},
// Four special tokens are declared in tokenizer.json and named in
// tokenizer_config.json. One SpecialVocabulary entry is expected per
// requested type; AddToken is expected true only for "bos", matching
// add_bos_token=true and add_eos_token=false in the config fixture.
{
name: "special token types",
fsys: createTokenizerFS(t, t.TempDir(), map[string]io.Reader{
"tokenizer.json": strings.NewReader(`{
"added_tokens": [
{
"id": 0,
"content": "<pad>",
"special": true
},
{
"id": 1,
"content": "<eos>",
"special": true
},
{
"id": 2,
"content": "<bos>",
"special": true
},
{
"id": 3,
"content": "<unk>",
"special": true
}
],
"model": {
"vocab": {
"<pad>": 0,
"<eos>": 1,
"<bos>": 2,
"<unk>": 3
}
}
}`),
"tokenizer_config.json": strings.NewReader(`{
"add_bos_token": true,
"add_eos_token": false,
"bos_token": "<bos>",
"eos_token": "<eos>",
"pad_token": "<pad>",
"unk_token": "<unk>"
}`),
}),
specialTokenTypes: []string{"pad", "eos", "bos", "unk"},
want: &Tokenizer{
Vocabulary: &Vocabulary{
Model: "gpt2",
Tokens: []string{"<pad>", "<eos>", "<bos>", "<unk>"},
Scores: []float32{0, 1, 2, 3},
Types: []int32{3, 3, 3, 3},
},
SpecialVocabulary: []*SpecialVocabulary{
{Type: "pad", Content: "<pad>", ID: 0, AddToken: false},
{Type: "eos", Content: "<eos>", ID: 1, AddToken: false},
{Type: "bos", Content: "<bos>", ID: 2, AddToken: true},
{Type: "unk", Content: "<unk>", ID: 3, AddToken: false},
},
Pre: "default",
},
},
}
// Every case must parse without error and match its expectation exactly;
// a mismatch is reported as a (-want +got) diff.
for _, tt := range cases {
t.Run(tt.name, func(t *testing.T) {
tokenizer, err := parseTokenizer(tt.fsys, tt.specialTokenTypes)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if diff := cmp.Diff(tt.want, tokenizer); diff != "" {
t.Errorf("unexpected tokenizer (-want +got):\n%s", diff)
}
})
}
}

View File

@@ -12,6 +12,8 @@ import (
"strings"
"golang.org/x/exp/maps"
"github.com/ollama/ollama/api"
)
type containerGGUF struct {
@@ -506,7 +508,7 @@ func writeGGUFArray[S ~[]E, E any](w io.Writer, t uint32, s S) error {
return binary.Write(w, binary.LittleEndian, s)
}
func WriteGGUF(ws io.WriteSeeker, kv KV, ts []Tensor) error {
func WriteGGUF(ws io.WriteSeeker, kv KV, ts []Tensor, fn func(api.ProgressResponse)) error {
if err := binary.Write(ws, binary.LittleEndian, []byte("GGUF")); err != nil {
return err
}
@@ -552,7 +554,10 @@ func WriteGGUF(ws io.WriteSeeker, kv KV, ts []Tensor) error {
}
var alignment int64 = 32
for _, t := range ts {
for i, t := range ts {
fn(api.ProgressResponse{
Status: fmt.Sprintf("converting model %d%%", 100*(i+1)/len(ts)),
})
if err := ggufWriteTensor(ws, t, alignment); err != nil {
return err
}

View File

@@ -41,7 +41,7 @@ func TestEstimateGPULayers(t *testing.T) {
"tokenizer.ggml.tokens": []string{" "},
"tokenizer.ggml.scores": []float32{0},
"tokenizer.ggml.token_type": []int32{0},
}, tensors)
}, tensors, func(api.ProgressResponse) {})
require.NoError(t, err)
ggml, err := LoadModel(f.Name(), 0)

View File

@@ -24,7 +24,6 @@ import (
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/format"
"github.com/ollama/ollama/types/model"
)
const maxRetries = 6
@@ -452,16 +451,15 @@ func (b *blobDownload) Wait(ctx context.Context, fn func(api.ProgressResponse))
}
}
type downloadOptions struct {
name model.Name
baseURL *url.URL
type downloadOpts struct {
mp ModelPath
digest string
regOpts *registryOptions
fn func(api.ProgressResponse)
}
// downloadBlob downloads a blob from the registry and stores it in the blobs directory
func downloadBlob(ctx context.Context, opts downloadOptions) (cacheHit bool, _ error) {
func downloadBlob(ctx context.Context, opts downloadOpts) (cacheHit bool, _ error) {
fp, err := GetBlobsPath(opts.digest)
if err != nil {
return false, err
@@ -486,7 +484,8 @@ func downloadBlob(ctx context.Context, opts downloadOptions) (cacheHit bool, _ e
data, ok := blobDownloadManager.LoadOrStore(opts.digest, &blobDownload{Name: fp, Digest: opts.digest})
download := data.(*blobDownload)
if !ok {
requestURL := opts.baseURL.JoinPath("blobs", opts.digest)
requestURL := opts.mp.BaseURL()
requestURL = requestURL.JoinPath("v2", opts.mp.GetNamespaceRepository(), "blobs", opts.digest)
if err := download.Prepare(ctx, requestURL, opts.regOpts); err != nil {
blobDownloadManager.Delete(opts.digest)
return false, err

View File

@@ -16,7 +16,6 @@ import (
"net/http"
"net/url"
"os"
"path"
"path/filepath"
"runtime"
"slices"
@@ -501,7 +500,7 @@ func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantizatio
return false
}
if err := layer.Prune(); err != nil {
if err := layer.Remove(); err != nil {
return false
}
@@ -689,40 +688,152 @@ func CopyModel(src, dst model.Name) error {
return err
}
func PushModel(ctx context.Context, name model.Name, opts registryOptions, fn func(api.ProgressResponse)) error {
m, err := ParseNamedManifest(name)
func deleteUnusedLayers(deleteMap map[string]struct{}) error {
manifests, err := Manifests()
if err != nil {
return err
}
scheme := "https"
if opts.Insecure {
scheme = "http"
for _, manifest := range manifests {
for _, layer := range manifest.Layers {
delete(deleteMap, layer.Digest)
}
delete(deleteMap, manifest.Config.Digest)
}
baseURL, err := url.Parse(fmt.Sprintf("%s://%s", scheme, path.Join(name.Host, "v2", name.Namespace, name.Model)))
// only delete the files which are still in the deleteMap
for k := range deleteMap {
fp, err := GetBlobsPath(k)
if err != nil {
slog.Info(fmt.Sprintf("couldn't get file path for '%s': %v", k, err))
continue
}
if err := os.Remove(fp); err != nil {
slog.Info(fmt.Sprintf("couldn't remove file '%s': %v", fp, err))
continue
}
}
return nil
}
func PruneLayers() error {
deleteMap := make(map[string]struct{})
p, err := GetBlobsPath("")
if err != nil {
return err
}
for _, layer := range append(m.Layers, m.Config) {
if err := uploadBlob(ctx, uploadOptions{name: name, baseURL: baseURL, layer: layer, regOpts: &opts, fn: fn}); err != nil {
blobs, err := os.ReadDir(p)
if err != nil {
slog.Info(fmt.Sprintf("couldn't read dir '%s': %v", p, err))
return err
}
for _, blob := range blobs {
name := blob.Name()
name = strings.ReplaceAll(name, "-", ":")
_, err := GetBlobsPath(name)
if err != nil {
if errors.Is(err, ErrInvalidDigestFormat) {
// remove invalid blobs (e.g. partial downloads)
if err := os.Remove(filepath.Join(p, blob.Name())); err != nil {
slog.Error("couldn't remove blob", "blob", blob.Name(), "error", err)
}
}
continue
}
deleteMap[name] = struct{}{}
}
slog.Info(fmt.Sprintf("total blobs: %d", len(deleteMap)))
if err := deleteUnusedLayers(deleteMap); err != nil {
slog.Error(fmt.Sprintf("couldn't remove unused layers: %v", err))
return nil
}
slog.Info(fmt.Sprintf("total unused blobs removed: %d", len(deleteMap)))
return nil
}
// PruneDirectory walks path depth-first and removes every directory that is
// empty once its children have been processed, including path itself.
//
// Anything that is not a real directory (regular files, symlinks) is left
// alone and reported as success. The first error from Lstat, ReadDir or Remove
// stops the walk and is returned.
func PruneDirectory(path string) error {
	info, err := os.Lstat(path)
	if err != nil {
		return err
	}

	// Lstat does not follow links, so a symlink to a directory is skipped
	// rather than descended into.
	if !info.IsDir() || info.Mode()&os.ModeSymlink != 0 {
		return nil
	}

	children, err := os.ReadDir(path)
	if err != nil {
		return err
	}

	for _, child := range children {
		if err := PruneDirectory(filepath.Join(path, child.Name())); err != nil {
			return err
		}
	}

	// Re-read the directory: the recursive calls above may have emptied it.
	remaining, err := os.ReadDir(path)
	if err != nil {
		return err
	}

	if len(remaining) == 0 {
		return os.Remove(path)
	}

	return nil
}
func PushModel(ctx context.Context, name string, regOpts *registryOptions, fn func(api.ProgressResponse)) error {
mp := ParseModelPath(name)
fn(api.ProgressResponse{Status: "retrieving manifest"})
if mp.ProtocolScheme == "http" && !regOpts.Insecure {
return errors.New("insecure protocol http")
}
manifest, _, err := GetManifest(mp)
if err != nil {
fn(api.ProgressResponse{Status: "couldn't retrieve manifest"})
return err
}
var layers []Layer
layers = append(layers, manifest.Layers...)
if manifest.Config.Digest != "" {
layers = append(layers, manifest.Config)
}
for _, layer := range layers {
if err := uploadBlob(ctx, mp, layer, regOpts, fn); err != nil {
slog.Info(fmt.Sprintf("error uploading blob: %v", err))
return err
}
}
fn(api.ProgressResponse{Status: "pushing manifest"})
requestURL := baseURL.JoinPath("manifests", name.Tag)
requestURL := mp.BaseURL()
requestURL = requestURL.JoinPath("v2", mp.GetNamespaceRepository(), "manifests", mp.Tag)
manifestJSON, err := json.Marshal(m)
manifestJSON, err := json.Marshal(manifest)
if err != nil {
return err
}
headers := make(http.Header)
headers.Set("Content-Type", "application/vnd.docker.distribution.manifest.v2+json")
resp, err := makeRequestWithRetry(ctx, http.MethodPut, requestURL, headers, bytes.NewReader(manifestJSON), &opts)
resp, err := makeRequestWithRetry(ctx, http.MethodPut, requestURL, headers, bytes.NewReader(manifestJSON), regOpts)
if err != nil {
return err
}
@@ -733,83 +844,118 @@ func PushModel(ctx context.Context, name model.Name, opts registryOptions, fn fu
return nil
}
func PullModel(ctx context.Context, name model.Name, opts *registryOptions, fn func(api.ProgressResponse)) error {
mm, _ := ParseNamedManifest(name)
func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn func(api.ProgressResponse)) error {
mp := ParseModelPath(name)
scheme := "https"
if opts.Insecure {
scheme = "http"
// build deleteMap to prune unused layers
deleteMap := make(map[string]struct{})
manifest, _, err := GetManifest(mp)
if errors.Is(err, os.ErrNotExist) {
// noop
} else if err != nil && !errors.Is(err, os.ErrNotExist) {
return err
} else {
for _, l := range manifest.Layers {
deleteMap[l.Digest] = struct{}{}
}
if manifest.Config.Digest != "" {
deleteMap[manifest.Config.Digest] = struct{}{}
}
}
baseURL, err := url.Parse(fmt.Sprintf("%s://%s", scheme, path.Join(name.Host, "v2", name.Namespace, name.Model)))
if err != nil {
return err
if mp.ProtocolScheme == "http" && !regOpts.Insecure {
return errors.New("insecure protocol http")
}
fn(api.ProgressResponse{Status: "pulling manifest"})
m, err := pullModelManifest(ctx, name, baseURL, opts)
manifest, err = pullModelManifest(ctx, mp, regOpts)
if err != nil {
return fmt.Errorf("pull model manifest: %s", err)
}
layers := append(m.Layers, m.Config)
var layers []Layer
layers = append(layers, manifest.Layers...)
if manifest.Config.Digest != "" {
layers = append(layers, manifest.Config)
}
skipVerify := make(map[string]bool)
for _, layer := range layers {
hit, err := downloadBlob(ctx, downloadOptions{
name: name,
baseURL: baseURL,
cacheHit, err := downloadBlob(ctx, downloadOpts{
mp: mp,
digest: layer.Digest,
regOpts: opts,
regOpts: regOpts,
fn: fn,
})
if err != nil {
return err
}
skipVerify[layer.Digest] = hit
skipVerify[layer.Digest] = cacheHit
delete(deleteMap, layer.Digest)
}
delete(deleteMap, manifest.Config.Digest)
fn(api.ProgressResponse{Status: "verifying sha256 digest"})
for _, layer := range layers {
if !skipVerify[layer.Digest] {
if err := verifyBlob(layer.Digest); errors.Is(err, errDigestMismatch) {
if skipVerify[layer.Digest] {
continue
}
if err := verifyBlob(layer.Digest); err != nil {
if errors.Is(err, errDigestMismatch) {
// something went wrong, delete the blob
fp, err := GetBlobsPath(layer.Digest)
if err != nil {
return err
}
if err := os.Remove(fp); err != nil {
// log this, but return the original error
slog.Info(fmt.Sprintf("couldn't remove file with digest mismatch '%s': %v", fp, err))
}
} else if err != nil {
return err
}
return err
}
}
fn(api.ProgressResponse{Status: "writing manifest"})
if err := WriteManifest(name, m.Config, m.Layers); err != nil {
manifestJSON, err := json.Marshal(manifest)
if err != nil {
return err
}
if !envconfig.NoPrune() && mm != nil {
fn(api.ProgressResponse{Status: "pruning old layers"})
_ = mm.RemoveLayers()
fp, err := mp.GetManifestPath()
if err != nil {
return err
}
if err := os.MkdirAll(filepath.Dir(fp), 0o755); err != nil {
return err
}
err = os.WriteFile(fp, manifestJSON, 0o644)
if err != nil {
slog.Info(fmt.Sprintf("couldn't write to %s", fp))
return err
}
if !envconfig.NoPrune() && len(deleteMap) > 0 {
fn(api.ProgressResponse{Status: "removing unused layers"})
if err := deleteUnusedLayers(deleteMap); err != nil {
fn(api.ProgressResponse{Status: fmt.Sprintf("couldn't remove unused layers: %v", err)})
}
}
fn(api.ProgressResponse{Status: "success"})
return nil
}
func pullModelManifest(ctx context.Context, name model.Name, baseURL *url.URL, opts *registryOptions) (*Manifest, error) {
requestURL := baseURL.JoinPath("manifests", name.Tag)
func pullModelManifest(ctx context.Context, mp ModelPath, regOpts *registryOptions) (*Manifest, error) {
requestURL := mp.BaseURL().JoinPath("v2", mp.GetNamespaceRepository(), "manifests", mp.Tag)
headers := make(http.Header)
headers.Set("Accept", "application/vnd.docker.distribution.manifest.v2+json")
resp, err := makeRequestWithRetry(ctx, http.MethodGet, requestURL, headers, nil, opts)
resp, err := makeRequestWithRetry(ctx, http.MethodGet, requestURL, headers, nil, regOpts)
if err != nil {
return nil, err
}
@@ -959,7 +1105,6 @@ func makeRequest(ctx context.Context, method string, requestURL *url.URL, header
return nil, err
}
slog.Debug("request upstream", "method", method, "request", requestURL.Redacted(), "status", resp.StatusCode)
return resp, nil
}

View File

@@ -5,10 +5,7 @@ import (
"errors"
"fmt"
"io"
"log/slog"
"os"
"path/filepath"
"strings"
)
type Layer struct {
@@ -104,8 +101,7 @@ func (l *Layer) Open() (io.ReadSeekCloser, error) {
return os.Open(blob)
}
// Prune removes the layer from the filesystem if it is not referenced any manifest.
func (l *Layer) Prune() error {
func (l *Layer) Remove() error {
if l.Digest == "" {
return nil
}
@@ -129,41 +125,5 @@ func (l *Layer) Prune() error {
return err
}
slog.Debug("pruning layer", "digest", l.Digest)
return os.Remove(blob)
}
// Layers globs every entry directly under the path returned by
// GetBlobsPath("") and returns a map from digest to Layer.
//
// The digest key is the entry's relative name with the first "sha256-"
// replaced by "sha256:". When NewLayerFromLayer rejects a digest, the entry is
// still recorded, using a bare Layer whose Digest is the raw relative name.
func Layers() (map[string]Layer, error) {
	root, err := GetBlobsPath("")
	if err != nil {
		return nil, err
	}

	// TODO(mxyng): use something less brittle
	paths, err := filepath.Glob(filepath.Join(root, "*"))
	if err != nil {
		return nil, err
	}

	result := make(map[string]Layer)
	for _, p := range paths {
		name, relErr := filepath.Rel(root, p)
		if relErr != nil {
			slog.Warn("bad filepath", "path", p, "error", relErr)
			continue
		}

		// TODO(mxyng): this should ideally use model.Digest but
		// that's currently incompatible with the manifest digest
		digest := strings.Replace(name, "sha256-", "sha256:", 1)

		entry, layerErr := NewLayerFromLayer(digest, "", "")
		if layerErr != nil {
			slog.Warn("bad blob", "digest", digest, "error", layerErr)
			// Keep a placeholder keyed by the computed digest but carrying
			// the on-disk name.
			entry = Layer{Digest: name}
		}

		result[digest] = entry
	}

	return result, nil
}

View File

@@ -43,13 +43,13 @@ func (m *Manifest) Remove() error {
return err
}
return pruneEmptyDirectory(manifests)
return PruneDirectory(manifests)
}
func (m *Manifest) RemoveLayers() error {
for _, layer := range append(m.Layers, m.Config) {
if layer.Digest != "" {
if err := layer.Prune(); errors.Is(err, os.ErrNotExist) {
if err := layer.Remove(); errors.Is(err, os.ErrNotExist) {
slog.Debug("layer does not exist", "digest", layer.Digest)
} else if err != nil {
return err
@@ -169,38 +169,3 @@ func Manifests() (map[model.Name]*Manifest, error) {
return ms, nil
}
// pruneEmptyDirectory removes p if it is empty after first pruning each of its
// subdirectories recursively.
//
// A symlink at p is ignored and reported as success. Non-directory entries
// inside p are never visited; their presence simply keeps p from being
// removed. The first Lstat, ReadDir or Remove error is returned.
func pruneEmptyDirectory(p string) error {
	fi, err := os.Lstat(p)
	if err != nil {
		return err
	}

	// Lstat does not follow links, so symlinks are skipped here.
	if fi.Mode()&os.ModeSymlink != 0 {
		return nil
	}

	children, err := os.ReadDir(p)
	if err != nil {
		return err
	}

	for _, child := range children {
		if !child.IsDir() {
			continue
		}
		if err := pruneEmptyDirectory(filepath.Join(p, child.Name())); err != nil {
			return err
		}
	}

	// Re-read the directory: pruning the children may have emptied it.
	left, err := os.ReadDir(p)
	if err != nil {
		return err
	}

	if len(left) > 0 {
		return nil
	}

	return os.Remove(p)
}

View File

@@ -34,7 +34,7 @@ func parseFromModel(ctx context.Context, name model.Name, fn func(api.ProgressRe
m, err := ParseNamedManifest(name)
switch {
case errors.Is(err, os.ErrNotExist):
if err := PullModel(ctx, name, &registryOptions{}, fn); err != nil {
if err := PullModel(ctx, name.String(), &registryOptions{}, fn); err != nil {
return nil, err
}
@@ -98,7 +98,6 @@ func parseFromZipFile(_ context.Context, command string, baseLayers []*layerGGML
}
defer os.RemoveAll(p)
fn(api.ProgressResponse{Status: "converting model"})
// TODO(mxyng): this should write directly into a layer
// e.g. NewLayer(arch.Reader(), "application/vnd.ollama.image.model")
t, err := os.CreateTemp(p, "fp16")
@@ -123,13 +122,18 @@ func parseFromZipFile(_ context.Context, command string, baseLayers []*layerGGML
if baseModel == nil {
return nil, fmt.Errorf("no base model specified for the adapter")
}
if err := convert.ConvertAdapter(convert.NewZipReader(r, p, 32<<20), t, baseModel.KV()); err != nil {
fn(api.ProgressResponse{
Status: "converting adapter",
})
if err := convert.ConvertAdapter(convert.NewZipReader(r, p, 32<<20), t, baseModel.KV(), fn); err != nil {
return nil, err
}
layerType = "application/vnd.ollama.image.adapter"
case "model":
if err := convert.ConvertModel(convert.NewZipReader(r, p, 32<<20), t); err != nil {
fn(api.ProgressResponse{
Status: "converting model",
})
if err := convert.ConvertModel(convert.NewZipReader(r, p, 32<<20), t, fn); err != nil {
return nil, err
}
layerType = "application/vnd.ollama.image.model"

View File

@@ -145,7 +145,7 @@ func TestParseFromFileFromLayer(t *testing.T) {
t.Fatalf("failed to open file: %v", err)
}
defer file.Close()
if err := llm.WriteGGUF(file, llm.KV{"general.architecture": "gemma"}, []llm.Tensor{}); err != nil {
if err := llm.WriteGGUF(file, llm.KV{"general.architecture": "gemma"}, []llm.Tensor{}, func(api.ProgressResponse) {}); err != nil {
t.Fatalf("failed to write gguf: %v", err)
}
@@ -197,7 +197,7 @@ func TestParseLayerFromCopy(t *testing.T) {
defer file2.Close()
for range 5 {
if err := llm.WriteGGUF(file2, llm.KV{"general.architecture": "gemma"}, []llm.Tensor{}); err != nil {
if err := llm.WriteGGUF(file2, llm.KV{"general.architecture": "gemma"}, []llm.Tensor{}, func(api.ProgressResponse) {}); err != nil {
t.Fatalf("failed to write gguf: %v", err)
}
}

View File

@@ -464,22 +464,24 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
}
func (s *Server) PullHandler(c *gin.Context) {
var r api.PullRequest
if err := c.ShouldBindJSON(&r); errors.Is(err, io.EOF) {
var req api.PullRequest
err := c.ShouldBindJSON(&req)
switch {
case errors.Is(err, io.EOF):
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
return
} else if err != nil {
case err != nil:
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
n := model.ParseName(cmp.Or(r.Model, r.Name))
if !n.IsValid() {
name := model.ParseName(cmp.Or(req.Model, req.Name))
if !name.IsValid() {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid model name"})
return
}
if err := checkNameExists(n); err != nil {
if err := checkNameExists(name); err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
@@ -491,15 +493,19 @@ func (s *Server) PullHandler(c *gin.Context) {
ch <- r
}
regOpts := &registryOptions{
Insecure: req.Insecure,
}
ctx, cancel := context.WithCancel(c.Request.Context())
defer cancel()
if err := PullModel(ctx, n, &registryOptions{Insecure: r.Insecure}, fn); err != nil {
if err := PullModel(ctx, name.DisplayShortest(), regOpts, fn); err != nil {
ch <- gin.H{"error": err.Error()}
}
}()
if r.Stream != nil && !*r.Stream {
if req.Stream != nil && !*req.Stream {
waitForStream(c, ch)
return
}
@@ -508,18 +514,24 @@ func (s *Server) PullHandler(c *gin.Context) {
}
func (s *Server) PushHandler(c *gin.Context) {
var r api.PushRequest
if err := c.ShouldBindJSON(&r); errors.Is(err, io.EOF) {
var req api.PushRequest
err := c.ShouldBindJSON(&req)
switch {
case errors.Is(err, io.EOF):
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
return
} else if err != nil {
case err != nil:
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
n := model.ParseName(cmp.Or(r.Model, r.Name))
if !n.IsValid() {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("name %q is invalid", cmp.Or(r.Model, r.Name))})
var model string
if req.Model != "" {
model = req.Model
} else if req.Name != "" {
model = req.Name
} else {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
return
}
@@ -530,15 +542,19 @@ func (s *Server) PushHandler(c *gin.Context) {
ch <- r
}
regOpts := &registryOptions{
Insecure: req.Insecure,
}
ctx, cancel := context.WithCancel(c.Request.Context())
defer cancel()
if err := PushModel(ctx, n, registryOptions{Insecure: r.Insecure}, fn); err != nil {
if err := PushModel(ctx, model, regOpts, fn); err != nil {
ch <- gin.H{"error": err.Error()}
}
}()
if r.Stream != nil && !*r.Stream {
if req.Stream != nil && !*req.Stream {
waitForStream(c, ch)
return
}
@@ -1131,15 +1147,18 @@ func Serve(ln net.Listener) error {
}
if !envconfig.NoPrune() {
layers, err := Layers()
// clean up unused layers and manifests
if err := PruneLayers(); err != nil {
return err
}
manifestsPath, err := GetManifestPath()
if err != nil {
return err
}
for _, layer := range layers {
if err := layer.Prune(); err != nil {
return err
}
if err := PruneDirectory(manifestsPath); err != nil {
return err
}
}

View File

@@ -30,7 +30,7 @@ func createBinFile(t *testing.T, kv map[string]any, ti []llm.Tensor) string {
}
defer f.Close()
if err := llm.WriteGGUF(f, kv, ti); err != nil {
if err := llm.WriteGGUF(f, kv, ti, func(api.ProgressResponse) {}); err != nil {
t.Fatal(err)
}

View File

@@ -5,21 +5,16 @@ import (
"context"
"encoding/binary"
"encoding/json"
"errors"
"fmt"
"io"
"math"
"net"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"sort"
"strings"
"testing"
"time"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@@ -457,84 +452,3 @@ func TestNormalize(t *testing.T) {
})
}
}
// TestServe seeds three models into a temporary OLLAMA_MODELS directory,
// starts Serve on an ephemeral local port, waits until GET / answers 200, and
// asserts that the set of files under blobs/ is the same before and after
// startup.
func TestServe(t *testing.T) {
	gin.SetMode(gin.TestMode)

	modelsDir := t.TempDir()
	t.Setenv("OLLAMA_MODELS", modelsDir)

	var s Server

	// seed some models: one base model and two variants derived from it
	seeds := []api.CreateRequest{
		{
			Name:      "test-model",
			Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, nil, nil)),
		},
		{
			Name:      "test-model-2",
			Modelfile: "FROM test-model\nSYSTEM You are a good robot.",
		},
		{
			Name:      "test-model-3",
			Modelfile: "FROM test-model\nSYSTEM You are a bad robot.",
		},
	}
	for _, seed := range seeds {
		createRequest(t, s.CreateHandler, seed)
	}

	// The same blob files are expected both before and after the server starts.
	blobPattern := filepath.Join(modelsDir, "blobs", "*")
	blobNames := []string{
		"sha256-1c515c46e60f849c6aeffa86e256508ac450464762a31ca08648e418f07c9819",
		"sha256-461fd034bb72312965d46160399b1b882c6a2f8c7305237ed7dd65f848fba10c",
		"sha256-66e9776a5bb7e5f6093681aa8ba01a7a6b6ae1dd697281f11fa714eaa948a6a4",
		"sha256-a4e5e156ddec27e286f75328784d7106b60a4eb1d246e950a001a3f944fbda99",
		"sha256-b3a5b5b438604c5103ba403a5455af94ea98494b5bbc177f4665716a37b99c1e",
		"sha256-ca239d7bd8ea90e4a5d2e6bf88f8d74a47b14336e73eb4e18bed4dd325018116",
	}
	wantBlobs := make([]string, 0, len(blobNames))
	for _, name := range blobNames {
		wantBlobs = append(wantBlobs, filepath.Join(modelsDir, "blobs", name))
	}

	checkFileExists(t, blobPattern, wantBlobs)

	ln, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatal(err)
	}
	defer ln.Close()

	//nolint:errcheck
	go Serve(ln)

	// wait for server to be healthy (GET / => 200)
	ctx, cancel := context.WithTimeout(context.TODO(), time.Second)
	defer cancel()

	// waitHealthy polls the listener every 20ms until it answers 200 or the
	// one-second context expires.
	waitHealthy := func() error {
		ticker := time.NewTicker(20 * time.Millisecond)
		defer ticker.Stop()

		for {
			select {
			case <-ctx.Done():
				return errors.New("server did not become healthy")
			case <-ticker.C:
				resp, err := http.Get(fmt.Sprintf("http://%s", ln.Addr()))
				if err != nil {
					// Server not accepting connections yet; try again.
					continue
				}

				if err := resp.Body.Close(); err != nil {
					return err
				}

				if resp.StatusCode == http.StatusOK {
					return nil
				}
			}
		}
	}

	if err := waitHealthy(); err != nil {
		t.Fatal(err)
	}

	checkFileExists(t, blobPattern, wantBlobs)
}

View File

@@ -128,7 +128,8 @@ func newScenarioRequest(t *testing.T, ctx context.Context, modelName string, est
}, []llm.Tensor{
{Name: "blk.0.attn.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: bytes.NewReader(make([]byte, 32))},
{Name: "output.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: bytes.NewReader(make([]byte, 32))},
}))
},
func(api.ProgressResponse) {}))
require.NoError(t, err)
fname := f.Name()

View File

@@ -21,7 +21,6 @@ import (
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/format"
"github.com/ollama/ollama/types/model"
)
var blobUploadManager sync.Map
@@ -109,7 +108,7 @@ func (b *blobUpload) Prepare(ctx context.Context, requestURL *url.URL, opts *reg
offset += size
}
slog.Info("uploading blob", "digest", b.Digest, "size", format.HumanBytes(b.Total), "parts", len(b.Parts), "size per part", format.HumanBytes(b.Parts[0].Size))
slog.Info(fmt.Sprintf("uploading %s in %d %s part(s)", b.Digest[7:19], len(b.Parts), format.HumanBytes(b.Parts[0].Size)))
requestURL, err = url.Parse(location)
if err != nil {
@@ -363,46 +362,40 @@ func (p *progressWriter) Rollback() {
p.written = 0
}
type uploadOptions struct {
name model.Name
baseURL *url.URL
layer Layer
regOpts *registryOptions
fn func(api.ProgressResponse)
}
func uploadBlob(ctx context.Context, mp ModelPath, layer Layer, opts *registryOptions, fn func(api.ProgressResponse)) error {
requestURL := mp.BaseURL()
requestURL = requestURL.JoinPath("v2", mp.GetNamespaceRepository(), "blobs", layer.Digest)
func uploadBlob(ctx context.Context, opts uploadOptions) error {
requestURL := opts.baseURL.JoinPath("blobs", opts.layer.Digest)
resp, err := makeRequestWithRetry(ctx, http.MethodHead, requestURL, nil, nil, opts.regOpts)
resp, err := makeRequestWithRetry(ctx, http.MethodHead, requestURL, nil, nil, opts)
switch {
case errors.Is(err, os.ErrNotExist):
case err != nil:
return err
default:
defer resp.Body.Close()
opts.fn(api.ProgressResponse{
Status: fmt.Sprintf("pushing %s", opts.layer.Digest[7:19]),
Digest: opts.layer.Digest,
Total: opts.layer.Size,
Completed: opts.layer.Size,
fn(api.ProgressResponse{
Status: fmt.Sprintf("pushing %s", layer.Digest[7:19]),
Digest: layer.Digest,
Total: layer.Size,
Completed: layer.Size,
})
return nil
}
data, ok := blobUploadManager.LoadOrStore(opts.layer.Digest, &blobUpload{Layer: opts.layer})
data, ok := blobUploadManager.LoadOrStore(layer.Digest, &blobUpload{Layer: layer})
upload := data.(*blobUpload)
if !ok {
requestURL := opts.baseURL.JoinPath("blobs", "uploads")
if err := upload.Prepare(ctx, requestURL, opts.regOpts); err != nil {
blobUploadManager.Delete(opts.layer.Digest)
requestURL := mp.BaseURL()
requestURL = requestURL.JoinPath("v2", mp.GetNamespaceRepository(), "blobs/uploads/")
if err := upload.Prepare(ctx, requestURL, opts); err != nil {
blobUploadManager.Delete(layer.Digest)
return err
}
//nolint:contextcheck
go upload.Run(context.Background(), opts.regOpts)
go upload.Run(context.Background(), opts)
}
return upload.Wait(ctx, opts.fn)
return upload.Wait(ctx, fn)
}