refactor filesForModel

This commit is contained in:
Michael Yang 2025-07-27 14:38:20 -07:00
parent 19279d778d
commit 087beb40ed
2 changed files with 118 additions and 95 deletions

View File

@ -4,6 +4,7 @@ import (
"bytes" "bytes"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt"
"io" "io"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
@ -14,7 +15,6 @@ import (
"time" "time"
"github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
@ -596,7 +596,7 @@ func TestCreateHandler(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
if _, err := w.WriteString("FROM stdin"); err != nil { if _, err := w.WriteString("FROM test"); err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -611,21 +611,20 @@ func TestCreateHandler(t *testing.T) {
}, },
wantRequest: api.CreateRequest{ wantRequest: api.CreateRequest{
Model: "stdin", Model: "stdin",
From: "stdin", From: "test",
}, },
}, },
{ {
name: "default", name: "default",
filename: func(t *testing.T) string { filename: func(t *testing.T) string {
t.Chdir(t.TempDir()) t.Chdir(t.TempDir())
f, err := os.Create("Modelfile") f, err := os.Create("Modelfile")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
defer f.Close() defer f.Close()
if _, err := f.WriteString("FROM default"); err != nil { if _, err := f.WriteString("FROM test"); err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -633,27 +632,7 @@ func TestCreateHandler(t *testing.T) {
}, },
wantRequest: api.CreateRequest{ wantRequest: api.CreateRequest{
Model: "default", Model: "default",
From: "default", From: "test",
},
},
{
name: "file flag",
filename: func(t *testing.T) string {
f, err := os.CreateTemp(t.TempDir(), filepath.Base(t.Name()))
if err != nil {
t.Fatal(err)
}
defer f.Close()
if _, err := f.WriteString("FROM file:flag"); err != nil {
t.Fatal(err)
}
return f.Name()
},
wantRequest: api.CreateRequest{
Model: "file_flag",
From: "file:flag",
}, },
}, },
{ {
@ -674,6 +653,29 @@ func TestCreateHandler(t *testing.T) {
}, },
wantRequest: api.CreateRequest{ wantRequest: api.CreateRequest{
Model: "default_safetensors", Model: "default_safetensors",
Files: map[string]string{
"model.safetensors": "sha256:6e340b9cffb37a989ca544e6bb780a2c78901d3fb33738768511a30617afa01d",
},
},
},
{
name: "file flag",
filename: func(t *testing.T) string {
f, err := os.CreateTemp(t.TempDir(), filepath.Base(t.Name()))
if err != nil {
t.Fatal(err)
}
defer f.Close()
if _, err := f.WriteString("FROM test"); err != nil {
t.Fatal(err)
}
return f.Name()
},
wantRequest: api.CreateRequest{
Model: "file_flag",
From: "test",
}, },
}, },
{ {
@ -686,6 +688,7 @@ func TestCreateHandler(t *testing.T) {
return "" return ""
}, },
wantErr: fmt.Errorf("openat %s: path escapes from parent", "model.safetensors"),
}, },
} }
@ -708,7 +711,7 @@ func TestCreateHandler(t *testing.T) {
return return
} }
if diff := cmp.Diff(tt.wantRequest, req, cmpopts.IgnoreFields(api.CreateRequest{}, "Files")); diff != "" { if diff := cmp.Diff(tt.wantRequest, req); diff != "" {
t.Errorf("Create request mismatch (-want +got):\n%s", diff) t.Errorf("Create request mismatch (-want +got):\n%s", diff)
} }
} else if strings.HasPrefix(r.URL.Path, "/api/blobs/") { } else if strings.HasPrefix(r.URL.Path, "/api/blobs/") {
@ -727,7 +730,9 @@ func TestCreateHandler(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
if err := CreateHandler(&cmd, []string{filepath.Base(t.Name())}); !errors.Is(err, tt.wantErr) { if err := CreateHandler(&cmd, []string{filepath.Base(t.Name())}); err != tt.wantErr &&
err.Error() != tt.wantErr.Error() &&
!errors.Is(err, tt.wantErr) {
t.Fatal(err) t.Fatal(err)
} }
}) })

View File

@ -7,6 +7,8 @@ import (
"errors" "errors"
"fmt" "fmt"
"io" "io"
"io/fs"
"iter"
"net/http" "net/http"
"os" "os"
"os/user" "os/user"
@ -148,31 +150,23 @@ func fileDigestMap(path string) (map[string]string, error) {
} }
var files []string var files []string
if fi.IsDir() { if !fi.IsDir() {
fs, err := filesForModel(path) files = []string{path}
} else {
root, err := os.OpenRoot(path)
if err != nil {
return nil, err
}
defer root.Close()
fs, err := filesForModel(root.FS())
if err != nil { if err != nil {
return nil, err return nil, err
} }
for _, f := range fs { for _, f := range fs {
f, err := filepath.EvalSymlinks(f) files = append(files, filepath.Join(path, f))
if err != nil {
return nil, err
}
rel, err := filepath.Rel(path, f)
if err != nil {
return nil, err
}
if !filepath.IsLocal(rel) {
return nil, fmt.Errorf("insecure path: %s", rel)
}
files = append(files, f)
} }
} else {
files = []string{path}
} }
var mu sync.Mutex var mu sync.Mutex
@ -218,67 +212,90 @@ func digestForFile(filename string) (string, error) {
return fmt.Sprintf("sha256:%x", hash.Sum(nil)), nil return fmt.Sprintf("sha256:%x", hash.Sum(nil)), nil
} }
func filesForModel(path string) ([]string, error) { func detectContentType(fsys fs.FS, path string) (string, error) {
detectContentType := func(path string) (string, error) { f, err := fsys.Open(path)
f, err := os.Open(path) if err != nil {
if err != nil { return "", err
return "", err }
} defer f.Close()
defer f.Close()
var b bytes.Buffer bts := make([]byte, 512)
b.Grow(512) n, err := io.ReadFull(f, bts)
if errors.Is(err, io.ErrUnexpectedEOF) {
if _, err := io.CopyN(&b, f, 512); err != nil && !errors.Is(err, io.EOF) { bts = bts[:n]
return "", err } else if err != nil {
} return "", err
contentType, _, _ := strings.Cut(http.DetectContentType(b.Bytes()), ";")
return contentType, nil
} }
glob := func(pattern, contentType string) ([]string, error) { contentType, _, _ := strings.Cut(http.DetectContentType(bts), ";")
matches, err := filepath.Glob(pattern) return contentType, nil
}
func matchFirst(fsys fs.FS, patternsContentTypes ...string) iter.Seq2[string, error] {
return func(yield func(string, error) bool) {
for i := 0; i < len(patternsContentTypes); i += 2 {
pattern := patternsContentTypes[i]
contentType := patternsContentTypes[i+1]
matches, err := fs.Glob(fsys, pattern)
if err != nil {
if !yield("", err) {
return
}
continue
}
if len(matches) > 0 {
for _, match := range matches {
if ct, err := detectContentType(fsys, match); err != nil {
if !yield("", err) {
return
}
} else if ct == contentType {
if !yield(match, nil) {
return
}
}
}
return
}
}
}
}
func collect[E any](it iter.Seq2[E, error]) (s []E, _ error) {
for v, err := range it {
if err != nil { if err != nil {
return nil, err return nil, err
} }
s = append(s, v)
for _, match := range matches {
if ct, err := detectContentType(match); err != nil {
return nil, err
} else if ct != contentType {
return nil, fmt.Errorf("invalid content type: expected %s for %s", ct, match)
}
}
return matches, nil
} }
return s, nil
}
var files []string func filesForModel(fsys fs.FS) ([]string, error) {
if st, _ := glob(filepath.Join(path, "*.safetensors"), "application/octet-stream"); len(st) > 0 { files, err := collect(matchFirst(
fsys,
// safetensors files might be unresolved git lfs references; skip if they are // safetensors files might be unresolved git lfs references; skip if they are
// covers model-x-of-y.safetensors, model.fp32-x-of-y.safetensors, model.safetensors // covers model-x-of-y.safetensors, model.fp32-x-of-y.safetensors, model.safetensors
files = append(files, st...) "*.safetensors", "application/octet-stream",
} else if pt, _ := glob(filepath.Join(path, "pytorch_model*.bin"), "application/zip"); len(pt) > 0 {
// pytorch files might also be unresolved git lfs references; skip if they are // pytorch files might also be unresolved git lfs references; skip if they are
// covers pytorch_model-x-of-y.bin, pytorch_model.fp32-x-of-y.bin, pytorch_model.bin // covers pytorch_model-x-of-y.bin, pytorch_model.fp32-x-of-y.bin, pytorch_model.bin
files = append(files, pt...) "pytorch_model*.bin", "application/zip",
} else if pt, _ := glob(filepath.Join(path, "consolidated*.pth"), "application/zip"); len(pt) > 0 {
// pytorch files might also be unresolved git lfs references; skip if they are // pytorch files might also be unresolved git lfs references; skip if they are
// covers consolidated.x.pth, consolidated.pth // covers consolidated.x.pth, consolidated.pth
files = append(files, pt...) "consolidated*.pth", "application/zip",
} else if gg, _ := glob(filepath.Join(path, "*.gguf"), "application/octet-stream"); len(gg) > 0 {
// covers gguf files ending in .gguf // covers gguf files ending in .gguf
files = append(files, gg...) "*.gguf", "application/octet-stream",
} else if gg, _ := glob(filepath.Join(path, "*.bin"), "application/octet-stream"); len(gg) > 0 {
// covers gguf files ending in .bin // covers gguf files ending in .bin
files = append(files, gg...) "*.bin", "application/octet-stream",
} else { ))
return nil, ErrModelNotFound if err != nil {
return nil, err
} }
// add configuration files, json files are detected as text/plain // add configuration files, json files are detected as text/plain
js, err := glob(filepath.Join(path, "*.json"), "text/plain") js, err := collect(matchFirst(fsys, "*.json", "text/plain"))
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -286,7 +303,7 @@ func filesForModel(path string) ([]string, error) {
// bert models require a nested config.json // bert models require a nested config.json
// TODO(mxyng): merge this with the glob above // TODO(mxyng): merge this with the glob above
js, err = glob(filepath.Join(path, "**/*.json"), "text/plain") js, err = collect(matchFirst(fsys, "**/*.json", "text/plain"))
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -296,14 +313,15 @@ func filesForModel(path string) ([]string, error) {
if !slices.ContainsFunc(files, func(s string) bool { if !slices.ContainsFunc(files, func(s string) bool {
return slices.Contains(strings.Split(s, string(os.PathSeparator)), "tokenizer.json") return slices.Contains(strings.Split(s, string(os.PathSeparator)), "tokenizer.json")
}) { }) {
if tks, _ := glob(filepath.Join(path, "tokenizer.model"), "application/octet-stream"); len(tks) > 0 { tokenizers, err := collect(matchFirst(fsys,
// add tokenizer.model if it exists, tokenizer.json is automatically picked up by the previous glob // add tokenizer.model if it exists, tokenizer.json is automatically picked up by the previous glob
// tokenizer.model might be a unresolved git lfs reference; error if it is // tokenizer.model might be a unresolved git lfs reference; error if it is
files = append(files, tks...) "tokenizer.model", "application/octet-stream",
} else if tks, _ := glob(filepath.Join(path, "**/tokenizer.model"), "text/plain"); len(tks) > 0 { ))
// some times tokenizer.model is in a subdirectory (e.g. meta-llama/Meta-Llama-3-8B) if err != nil {
files = append(files, tks...) return nil, err
} }
files = append(files, tokenizers...)
} }
return files, nil return files, nil