diff --git a/cmd/cmd_test.go b/cmd/cmd_test.go index 096e65b9a..11b1643ff 100644 --- a/cmd/cmd_test.go +++ b/cmd/cmd_test.go @@ -4,6 +4,7 @@ import ( "bytes" "encoding/json" "errors" + "fmt" "io" "net/http" "net/http/httptest" @@ -14,7 +15,6 @@ import ( "time" "github.com/google/go-cmp/cmp" - "github.com/google/go-cmp/cmp/cmpopts" "github.com/spf13/cobra" "github.com/ollama/ollama/api" @@ -596,7 +596,7 @@ func TestCreateHandler(t *testing.T) { t.Fatal(err) } - if _, err := w.WriteString("FROM stdin"); err != nil { + if _, err := w.WriteString("FROM test"); err != nil { t.Fatal(err) } @@ -611,21 +611,20 @@ func TestCreateHandler(t *testing.T) { }, wantRequest: api.CreateRequest{ Model: "stdin", - From: "stdin", + From: "test", }, }, { name: "default", filename: func(t *testing.T) string { t.Chdir(t.TempDir()) - f, err := os.Create("Modelfile") if err != nil { t.Fatal(err) } defer f.Close() - if _, err := f.WriteString("FROM default"); err != nil { + if _, err := f.WriteString("FROM test"); err != nil { t.Fatal(err) } @@ -633,27 +632,7 @@ func TestCreateHandler(t *testing.T) { }, wantRequest: api.CreateRequest{ Model: "default", - From: "default", - }, - }, - { - name: "file flag", - filename: func(t *testing.T) string { - f, err := os.CreateTemp(t.TempDir(), filepath.Base(t.Name())) - if err != nil { - t.Fatal(err) - } - defer f.Close() - - if _, err := f.WriteString("FROM file:flag"); err != nil { - t.Fatal(err) - } - - return f.Name() - }, - wantRequest: api.CreateRequest{ - Model: "file_flag", - From: "file:flag", + From: "test", }, }, { @@ -674,6 +653,29 @@ func TestCreateHandler(t *testing.T) { }, wantRequest: api.CreateRequest{ Model: "default_safetensors", + Files: map[string]string{ + "model.safetensors": "sha256:6e340b9cffb37a989ca544e6bb780a2c78901d3fb33738768511a30617afa01d", + }, + }, + }, + { + name: "file flag", + filename: func(t *testing.T) string { + f, err := os.CreateTemp(t.TempDir(), filepath.Base(t.Name())) + if err != nil { + t.Fatal(err) + } + defer f.Close() + + if _, err := f.WriteString("FROM test"); err != nil { + t.Fatal(err) + } + + return f.Name() + }, + wantRequest: api.CreateRequest{ + Model: "file_flag", + From: "test", }, }, { @@ -686,6 +688,7 @@ func TestCreateHandler(t *testing.T) { return "" }, + wantErr: fmt.Errorf("openat %s: path escapes from parent", "model.safetensors"), }, } @@ -708,7 +711,7 @@ func TestCreateHandler(t *testing.T) { return } - if diff := cmp.Diff(tt.wantRequest, req, cmpopts.IgnoreFields(api.CreateRequest{}, "Files")); diff != "" { + if diff := cmp.Diff(tt.wantRequest, req); diff != "" { t.Errorf("Create request mismatch (-want +got):\n%s", diff) } } else if strings.HasPrefix(r.URL.Path, "/api/blobs/") { @@ -727,7 +730,9 @@ func TestCreateHandler(t *testing.T) { t.Fatal(err) } - if err := CreateHandler(&cmd, []string{filepath.Base(t.Name())}); !errors.Is(err, tt.wantErr) { + if err := CreateHandler(&cmd, []string{filepath.Base(t.Name())}); err != tt.wantErr && + err.Error() != tt.wantErr.Error() && + !errors.Is(err, tt.wantErr) { t.Fatal(err) } }) diff --git a/parser/parser.go b/parser/parser.go index d40a79c29..b7ebcba85 100644 --- a/parser/parser.go +++ b/parser/parser.go @@ -7,6 +7,8 @@ import ( "errors" "fmt" "io" + "io/fs" + "iter" "net/http" "os" "os/user" @@ -148,31 +150,23 @@ func fileDigestMap(path string) (map[string]string, error) { } var files []string - if fi.IsDir() { - fs, err := filesForModel(path) + if !fi.IsDir() { + files = []string{path} + } else { + root, err := os.OpenRoot(path) + if err != nil { + return nil, err + } + defer root.Close() + + fs, err := filesForModel(root.FS()) if err != nil { return nil, err } for _, f := range fs { - f, err := filepath.EvalSymlinks(f) - if err != nil { - return nil, err - } - - rel, err := filepath.Rel(path, f) - if err != nil { - return nil, err - } - - if !filepath.IsLocal(rel) { - return nil, fmt.Errorf("insecure path: %s", rel) - } - - files = append(files, f) + files = append(files, filepath.Join(path, f)) } - } else { - files = []string{path} } var mu sync.Mutex @@ -218,67 +212,90 @@ func digestForFile(filename string) (string, error) { return fmt.Sprintf("sha256:%x", hash.Sum(nil)), nil } -func filesForModel(path string) ([]string, error) { - detectContentType := func(path string) (string, error) { - f, err := os.Open(path) - if err != nil { - return "", err - } - defer f.Close() +func detectContentType(fsys fs.FS, path string) (string, error) { + f, err := fsys.Open(path) + if err != nil { + return "", err + } + defer f.Close() - var b bytes.Buffer - b.Grow(512) - - if _, err := io.CopyN(&b, f, 512); err != nil && !errors.Is(err, io.EOF) { - return "", err - } - - contentType, _, _ := strings.Cut(http.DetectContentType(b.Bytes()), ";") - return contentType, nil + bts := make([]byte, 512) + n, err := io.ReadFull(f, bts) + if errors.Is(err, io.ErrUnexpectedEOF) { + bts = bts[:n] + } else if err != nil { + return "", err } - glob := func(pattern, contentType string) ([]string, error) { - matches, err := filepath.Glob(pattern) + contentType, _, _ := strings.Cut(http.DetectContentType(bts), ";") + return contentType, nil +} + +func matchFirst(fsys fs.FS, patternsContentTypes ...string) iter.Seq2[string, error] { + return func(yield func(string, error) bool) { + for i := 0; i < len(patternsContentTypes); i += 2 { + pattern := patternsContentTypes[i] + contentType := patternsContentTypes[i+1] + matches, err := fs.Glob(fsys, pattern) + if err != nil { + if !yield("", err) { + return + } + continue + } + + if len(matches) > 0 { + for _, match := range matches { + if ct, err := detectContentType(fsys, match); err != nil { + if !yield("", err) { + return + } + } else if ct == contentType { + if !yield(match, nil) { + return + } + } + } + + return + } + } + } +} + +func collect[E any](it iter.Seq2[E, error]) (s []E, _ error) { + for v, err := range it { if err != nil { return nil, err } - - for _, match := range matches { - if ct, err := detectContentType(match); err != nil { - return nil, err - } else if ct != contentType { - return nil, fmt.Errorf("invalid content type: expected %s for %s", ct, match) - } - } - - return matches, nil + s = append(s, v) } + return s, nil +} - var files []string - if st, _ := glob(filepath.Join(path, "*.safetensors"), "application/octet-stream"); len(st) > 0 { +func filesForModel(fsys fs.FS) ([]string, error) { + files, err := collect(matchFirst( + fsys, // safetensors files might be unresolved git lfs references; skip if they are // covers model-x-of-y.safetensors, model.fp32-x-of-y.safetensors, model.safetensors - files = append(files, st...) - } else if pt, _ := glob(filepath.Join(path, "pytorch_model*.bin"), "application/zip"); len(pt) > 0 { + "*.safetensors", "application/octet-stream", // pytorch files might also be unresolved git lfs references; skip if they are // covers pytorch_model-x-of-y.bin, pytorch_model.fp32-x-of-y.bin, pytorch_model.bin - files = append(files, pt...) - } else if pt, _ := glob(filepath.Join(path, "consolidated*.pth"), "application/zip"); len(pt) > 0 { + "pytorch_model*.bin", "application/zip", // pytorch files might also be unresolved git lfs references; skip if they are // covers consolidated.x.pth, consolidated.pth - files = append(files, pt...) - } else if gg, _ := glob(filepath.Join(path, "*.gguf"), "application/octet-stream"); len(gg) > 0 { + "consolidated*.pth", "application/zip", // covers gguf files ending in .gguf - files = append(files, gg...) - } else if gg, _ := glob(filepath.Join(path, "*.bin"), "application/octet-stream"); len(gg) > 0 { + "*.gguf", "application/octet-stream", // covers gguf files ending in .bin - files = append(files, gg...) - } else { - return nil, ErrModelNotFound + "*.bin", "application/octet-stream", + )) + if err != nil { + return nil, err } // add configuration files, json files are detected as text/plain - js, err := glob(filepath.Join(path, "*.json"), "text/plain") + js, err := collect(matchFirst(fsys, "*.json", "text/plain")) if err != nil { return nil, err } @@ -286,7 +303,7 @@ func filesForModel(path string) ([]string, error) { // bert models require a nested config.json // TODO(mxyng): merge this with the glob above - js, err = glob(filepath.Join(path, "**/*.json"), "text/plain") + js, err = collect(matchFirst(fsys, "**/*.json", "text/plain")) if err != nil { return nil, err } @@ -296,14 +313,15 @@ func filesForModel(path string) ([]string, error) { if !slices.ContainsFunc(files, func(s string) bool { return slices.Contains(strings.Split(s, string(os.PathSeparator)), "tokenizer.json") }) { - if tks, _ := glob(filepath.Join(path, "tokenizer.model"), "application/octet-stream"); len(tks) > 0 { + tokenizers, err := collect(matchFirst(fsys, // add tokenizer.model if it exists, tokenizer.json is automatically picked up by the previous glob // tokenizer.model might be a unresolved git lfs reference; error if it is - files = append(files, tks...) - } else if tks, _ := glob(filepath.Join(path, "**/tokenizer.model"), "text/plain"); len(tks) > 0 { - // some times tokenizer.model is in a subdirectory (e.g. meta-llama/Meta-Llama-3-8B) - files = append(files, tks...) + "tokenizer.model", "application/octet-stream", + )) + if err != nil { + return nil, err } + files = append(files, tokenizers...) } return files, nil