diff --git a/api/types.go b/api/types.go
index 5b8e034c2..e9392c743 100644
--- a/api/types.go
+++ b/api/types.go
@@ -471,10 +471,10 @@ type CreateRequest struct {
 	RemoteHost string `json:"remote_host,omitempty"`
 
 	// Files is a map of files include when creating the model.
-	Files map[string]string `json:"files,omitempty"`
+	Files Files `json:"files,omitempty"`
 
 	// Adapters is a map of LoRA adapters to include when creating the model.
-	Adapters map[string]string `json:"adapters,omitempty"`
+	Adapters Files `json:"adapters,omitempty"`
 
 	// Template is the template used when constructing a request to the model.
 	Template string `json:"template,omitempty"`
@@ -503,6 +503,45 @@ type CreateRequest struct {
 	Quantization string `json:"quantization,omitempty"`
 }
 
+// Files is a list of files that is encoded on the wire as a JSON object
+// mapping each file's Name to its Digest.
+type Files []File
+
+// MarshalJSON encodes f as a JSON object of name/digest pairs. Path is not
+// encoded, and entries sharing a Name collapse to the last one.
+func (f Files) MarshalJSON() ([]byte, error) {
+	m := make(map[string]string, len(f))
+	for _, file := range f {
+		m[file.Name] = file.Digest
+	}
+	return json.Marshal(m)
+}
+
+// UnmarshalJSON decodes a JSON object of name/digest pairs into f, replacing
+// any previous contents. The order of the decoded entries is unspecified
+// because it follows map iteration order.
+func (f *Files) UnmarshalJSON(data []byte) error {
+	var m map[string]string
+	if err := json.Unmarshal(data, &m); err != nil {
+		return err
+	}
+
+	// Build a fresh slice instead of appending to *f so that decoding into a
+	// reused value does not accumulate entries from an earlier decode.
+	files := make(Files, 0, len(m))
+	for name, digest := range m {
+		files = append(files, File{Name: name, Digest: digest})
+	}
+	*f = files
+	return nil
+}
+
+// File describes one file of a create request. Only Name and Digest are
+// encoded by [Files.MarshalJSON].
+type File struct {
+	Name, Path, Digest string
+}
+
 // DeleteRequest is the request passed to [Client.Delete].
type DeleteRequest struct { Model string `json:"model"` @@ -988,8 +1013,8 @@ func (d *Duration) UnmarshalJSON(b []byte) (err error) { return nil } -// FormatParams converts specified parameter options to their correct types -func FormatParams(params map[string][]string) (map[string]any, error) { +// FormatParameters converts specified parameter options to their correct types +func FormatParameters(params map[string][]string) (map[string]any, error) { opts := Options{} valueOpts := reflect.ValueOf(&opts).Elem() // names of the fields in the options struct typeOpts := reflect.TypeOf(opts) // types of the fields in the options struct diff --git a/api/types_test.go b/api/types_test.go index 5393b4623..056f4e58f 100644 --- a/api/types_test.go +++ b/api/types_test.go @@ -203,7 +203,7 @@ func TestUseMmapFormatParams(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - resp, err := FormatParams(test.req) + resp, err := FormatParameters(test.req) require.Equal(t, test.err, err) respVal, ok := resp["use_mmap"] if test.exp != nil { diff --git a/cmd/cmd.go b/cmd/cmd.go index 294e1662f..9420b0e25 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -33,20 +33,17 @@ import ( "github.com/olekukonko/tablewriter" "github.com/spf13/cobra" "golang.org/x/crypto/ssh" - "golang.org/x/sync/errgroup" "golang.org/x/term" "github.com/ollama/ollama/api" "github.com/ollama/ollama/auth" "github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/format" - "github.com/ollama/ollama/parser" "github.com/ollama/ollama/progress" "github.com/ollama/ollama/readline" "github.com/ollama/ollama/runner" "github.com/ollama/ollama/server" "github.com/ollama/ollama/types/model" - "github.com/ollama/ollama/types/syncmap" "github.com/ollama/ollama/version" ) @@ -89,179 +86,6 @@ func getModelfileName(cmd *cobra.Command) (string, error) { return absName, nil } -func CreateHandler(cmd *cobra.Command, args []string) error { - p := progress.NewProgress(os.Stderr) - defer p.Stop() - - var 
reader io.Reader - - filename, err := getModelfileName(cmd) - if os.IsNotExist(err) { - if filename == "" { - reader = strings.NewReader("FROM .\n") - } else { - return errModelfileNotFound - } - } else if err != nil { - return err - } else { - f, err := os.Open(filename) - if err != nil { - return err - } - - reader = f - defer f.Close() - } - - modelfile, err := parser.ParseFile(reader) - if err != nil { - return err - } - - status := "gathering model components" - spinner := progress.NewSpinner(status) - p.Add(status, spinner) - - req, err := modelfile.CreateRequest(filepath.Dir(filename)) - if err != nil { - return err - } - spinner.Stop() - - req.Model = args[0] - quantize, _ := cmd.Flags().GetString("quantize") - if quantize != "" { - req.Quantize = quantize - } - - client, err := api.ClientFromEnvironment() - if err != nil { - return err - } - - var g errgroup.Group - g.SetLimit(max(runtime.GOMAXPROCS(0)-1, 1)) - - files := syncmap.NewSyncMap[string, string]() - for f, digest := range req.Files { - g.Go(func() error { - if _, err := createBlob(cmd, client, f, digest, p); err != nil { - return err - } - - // TODO: this is incorrect since the file might be in a subdirectory - // instead this should take the path relative to the model directory - // but the current implementation does not allow this - files.Store(filepath.Base(f), digest) - return nil - }) - } - - adapters := syncmap.NewSyncMap[string, string]() - for f, digest := range req.Adapters { - g.Go(func() error { - if _, err := createBlob(cmd, client, f, digest, p); err != nil { - return err - } - - // TODO: same here - adapters.Store(filepath.Base(f), digest) - return nil - }) - } - - if err := g.Wait(); err != nil { - return err - } - - req.Files = files.Items() - req.Adapters = adapters.Items() - - bars := make(map[string]*progress.Bar) - fn := func(resp api.ProgressResponse) error { - if resp.Digest != "" { - bar, ok := bars[resp.Digest] - if !ok { - msg := resp.Status - if msg == "" { - msg = 
fmt.Sprintf("pulling %s...", resp.Digest[7:19]) - } - bar = progress.NewBar(msg, resp.Total, resp.Completed) - bars[resp.Digest] = bar - p.Add(resp.Digest, bar) - } - - bar.Set(resp.Completed) - } else if status != resp.Status { - spinner.Stop() - - status = resp.Status - spinner = progress.NewSpinner(status) - p.Add(status, spinner) - } - - return nil - } - - if err := client.Create(cmd.Context(), req, fn); err != nil { - if strings.Contains(err.Error(), "path or Modelfile are required") { - return fmt.Errorf("the ollama server must be updated to use `ollama create` with this client") - } - return err - } - - return nil -} - -func createBlob(cmd *cobra.Command, client *api.Client, path string, digest string, p *progress.Progress) (string, error) { - realPath, err := filepath.EvalSymlinks(path) - if err != nil { - return "", err - } - - bin, err := os.Open(realPath) - if err != nil { - return "", err - } - defer bin.Close() - - // Get file info to retrieve the size - fileInfo, err := bin.Stat() - if err != nil { - return "", err - } - fileSize := fileInfo.Size() - - var pw progressWriter - status := fmt.Sprintf("copying file %s 0%%", digest) - spinner := progress.NewSpinner(status) - p.Add(status, spinner) - defer spinner.Stop() - - done := make(chan struct{}) - defer close(done) - - go func() { - ticker := time.NewTicker(60 * time.Millisecond) - defer ticker.Stop() - for { - select { - case <-ticker.C: - spinner.SetMessage(fmt.Sprintf("copying file %s %d%%", digest, int(100*pw.n.Load()/fileSize))) - case <-done: - spinner.SetMessage(fmt.Sprintf("copying file %s 100%%", digest)) - return - } - } - }() - - if err := client.CreateBlob(cmd.Context(), digest, io.TeeReader(bin, &pw)); err != nil { - return "", err - } - return digest, nil -} - type progressWriter struct { n atomic.Int64 } diff --git a/cmd/create.go b/cmd/create.go new file mode 100644 index 000000000..01efed949 --- /dev/null +++ b/cmd/create.go @@ -0,0 +1,442 @@ +package cmd + +import ( + 
"crypto/sha256" + "encoding/hex" + "fmt" + "io" + "io/fs" + "iter" + "log/slog" + "net/http" + "os" + "os/user" + "path/filepath" + "runtime" + "slices" + "strings" + "sync" + "time" + + "github.com/ollama/ollama/api" + "github.com/ollama/ollama/parser" + "github.com/ollama/ollama/progress" + "github.com/spf13/cobra" + "golang.org/x/sync/errgroup" +) +
+// expandPath resolves path to a cleaned absolute path. An absolute path is
+// returned unchanged. A leading "~" or "~/" is replaced with the current
+// user's home directory and "~name" with the home directory of user name.
+// Any other path is resolved against dir, which is itself resolved against
+// the working directory when it is relative.
+func expandPath(path, dir string) (string, error) {
+	if filepath.IsAbs(path) {
+		return path, nil
+	}
+
+	path, found := strings.CutPrefix(path, "~")
+	if !found {
+		// make path relative to dir
+		if !filepath.IsAbs(dir) {
+			// if dir is relative, make it absolute relative to cwd
+			cwd, err := os.Getwd()
+			if err != nil {
+				return "", err
+			}
+			dir = filepath.Join(cwd, dir)
+		}
+		path = filepath.Join(dir, path)
+	} else if filepath.IsLocal(path) {
+		// ~name or ~name/...
+		// what follows the tilde is a local path, so its first element is a
+		// user name; make path relative to that user's home
+		split := strings.SplitN(path, "/", 2)
+		u, err := user.Lookup(split[0])
+		if err != nil {
+			return "", err
+		}
+		split[0] = u.HomeDir
+		path = filepath.Join(split...)
+	} else {
+		// ~ or ~/...
+		// make path relative to current user's home
+		home, err := os.UserHomeDir()
+		if err != nil {
+			return "", err
+		}
+		path = filepath.Join(home, path)
+	}
+
+	return filepath.Clean(path), nil
+}
+
+// detectContentType sniffs the MIME type of the file at path in fsys from at
+// most its first 512 bytes and returns it without parameters such as charset.
+func detectContentType(fsys fs.FS, path string) (string, error) {
+	f, err := fsys.Open(path)
+	if err != nil {
+		return "", err
+	}
+	defer f.Close()
+
+	// Sniff only the bytes that were actually read. Passing the whole
+	// zero-padded buffer makes any file shorter than 512 bytes look binary,
+	// so a short text file would be reported as application/octet-stream.
+	b := make([]byte, 512)
+	n, err := io.ReadFull(f, b)
+	if err != nil && err != io.EOF && err != io.ErrUnexpectedEOF {
+		return "", err
+	}
+
+	contentType, _, _ := strings.Cut(http.DetectContentType(b[:n]), ";")
+	return contentType, nil
+}
+
+// glob returns an iterator that yields files matching the given patterns and content types.
+// The patterns and content types are provided as pairs of strings.
+// If a content type is an empty string, all files matching the pattern are yielded.
+// The iterator stops after the first pattern that matches any files.
+func glob(fsys fs.FS, patternOrContentType ...string) iter.Seq2[string, error] {
+	if len(patternOrContentType)%2 != 0 {
+		panic("glob: patternOrContentType must have an even number of elements")
+	}
+
+	return func(yield func(string, error) bool) {
+		for i := 0; i < len(patternOrContentType); i += 2 {
+			pattern, contentType := patternOrContentType[i], patternOrContentType[i+1]
+
+			matches, err := fs.Glob(fsys, pattern)
+			if err != nil {
+				yield("", err)
+				return
+			}
+
+			var yielded bool
+			for _, match := range matches {
+				if contentType != "" {
+					ct, err := detectContentType(fsys, match)
+					if err != nil {
+						yield("", err)
+						return
+					}
+
+					if ct != contentType {
+						continue
+					}
+				}
+
+				if !yield(match, nil) {
+					return
+				}
+				yielded = true
+			}
+
+			// A pattern only ends the search once one of its files passed the
+			// content type filter. Stopping on a bare name match would make a
+			// later pair for the same pattern, such as the second "*.bin",
+			// unreachable.
+			if yielded {
+				return
+			}
+		}
+	}
+}
+
+// filesSeq returns an iterator over the model files found in fsys: first the
+// weights, then the tokenizer, then the JSON configuration files. Each file
+// is yielded at most once. A non-nil error ends the iteration.
+func filesSeq(fsys fs.FS) iter.Seq2[string, error] {
+	groups := [][]string{
+		{
+			"*.safetensors", "",
+			"*.bin", "application/zip",
+			"*.pth", "application/zip",
+			"*.gguf", "application/octet-stream",
+			"*.bin", "application/octet-stream",
+		},
+		{
+			// content sniffing reports JSON as text/plain, never as application/json
+			"tokenizer.json", "text/plain",
+			"tokenizer.model", "application/octet-stream",
+		},
+		{"*.json", ""},
+		// fs.Glob has no recursive wildcard, "**" behaves like "*", so this
+		// matches JSON files exactly one directory below the root
+		{"**/*.json", ""},
+	}
+
+	return func(yield func(string, error) bool) {
+		// tokenizer.json matches both its own group and "*.json"
+		seen := make(map[string]struct{})
+		for _, group := range groups {
+			for match, err := range glob(fsys, group...) {
+				if err != nil {
+					yield("", err)
+					return
+				}
+
+				if _, ok := seen[match]; ok {
+					continue
+				}
+				seen[match] = struct{}{}
+
+				if !yield(match, nil) {
+					return
+				}
+			}
+		}
+	}
+}
+
+// get returns m[key] as a T, or the zero value of T when the key is missing
+// or holds a value of another type.
+func get[T any](m map[string]any, key string) (t T) {
+	if v, ok := m[key].(T); ok {
+		t = v
+	}
+	return
+}
+
+// deprecatedParameters lists Modelfile parameters that are skipped with a warning.
+var deprecatedParameters = []string{
+	"penalize_newline",
+	"low_vram",
+	"f16_kv",
+	"logits_all",
+	"vocab_only",
+	"use_mlock",
+	"mirostat",
+	"mirostat_tau",
+	"mirostat_eta",
+}
+
+// createRequest builds the [api.CreateRequest] described by modelfile.
+// Relative model and adapter paths are resolved against dir. A model that is
+// not a local path is passed on as From, a local file is added on its own and
+// a local directory contributes the files found by filesSeq. Every local file
+// is hashed and its Path recorded so the caller can upload it.
+func createRequest(modelfile *parser.Modelfile, dir string) (*api.CreateRequest, error) {
+	// source is a local file that still has to be hashed
+	type source struct {
+		name, path string
+		adapter    bool
+	}
+
+	m := make(map[string]any)
+	parameters := make(map[string]any)
+
+	var from string
+	var sources []source
+	for _, c := range modelfile.Commands {
+		switch c.Name {
+		case "model", "adapter":
+			path, err := expandPath(c.Args, dir)
+			if err != nil {
+				return nil, err
+			}
+
+			stat, err := os.Stat(path)
+			if os.IsNotExist(err) && c.Name == "model" {
+				// not a local path, so it names a model to build from
+				from = c.Args
+				continue
+			} else if err != nil {
+				return nil, err
+			}
+
+			adapter := c.Name == "adapter"
+			if !stat.IsDir() {
+				sources = append(sources, source{name: filepath.Base(path), path: path, adapter: adapter})
+				continue
+			}
+
+			var found bool
+			for file, err := range filesSeq(os.DirFS(path)) {
+				if err != nil {
+					return nil, err
+				}
+
+				sources = append(sources, source{name: file, path: filepath.Join(path, file), adapter: adapter})
+				found = true
+			}
+
+			if !found {
+				return nil, parser.ErrModelNotFound
+			}
+		case "template", "system", "renderer", "parser":
+			m[c.Name] = c.Args
+		case "license":
+			m[c.Name] = append(get[[]string](m, c.Name), c.Args)
+		case "message":
+			role, msg, found := strings.Cut(c.Args, ": ")
+			if !found {
+				return nil, fmt.Errorf("invalid message command: %s", c.Args)
+			}
+
+			m[c.Name] = append(get[[]api.Message](m, c.Name), api.Message{
+				Role:    role,
+				Content: msg,
+			})
+		default:
+			if slices.Contains(deprecatedParameters, c.Name) {
+				slog.Warn("parameter is deprecated", "name", c.Name)
+				break
+			}
+
+			ps, err := api.FormatParameters(map[string][]string{c.Name: {c.Args}})
+			if err != nil {
+				return nil, err
+			}
+
+			for k, v := range ps {
+				if ks, ok := parameters[k].([]string); ok {
+					parameters[k] = append(ks, v.([]string)...)
+				} else if vs, ok := v.([]string); ok {
+					parameters[k] = vs
+				} else {
+					parameters[k] = v
+				}
+			}
+		}
+	}
+
+	// Hashing starts only after every command has been handled, so an early
+	// return above never leaves goroutines running.
+	var files, adapters []api.File
+
+	// mu guards files and adapters across all hashing goroutines
+	var mu sync.Mutex
+	var g errgroup.Group
+	g.SetLimit(runtime.GOMAXPROCS(0))
+	for _, src := range sources {
+		g.Go(func() error {
+			f, err := os.Open(src.path)
+			if err != nil {
+				return err
+			}
+			defer f.Close()
+
+			sha256sum := sha256.New()
+			if _, err := io.Copy(sha256sum, f); err != nil {
+				return err
+			}
+
+			file := api.File{
+				Name:   src.name,
+				Path:   src.path,
+				Digest: "sha256:" + hex.EncodeToString(sha256sum.Sum(nil)),
+			}
+
+			mu.Lock()
+			defer mu.Unlock()
+			if src.adapter {
+				adapters = append(adapters, file)
+			} else {
+				files = append(files, file)
+			}
+
+			return nil
+		})
+	}
+
+	if err := g.Wait(); err != nil {
+		return nil, err
+	}
+
+	return &api.CreateRequest{
+		From:       from,
+		Files:      files,
+		Adapters:   adapters,
+		Parameters: parameters,
+		Template:   get[string](m, "template"),
+		System:     get[string](m, "system"),
+		License:    get[[]string](m, "license"),
+		Messages:   get[[]api.Message](m, "message"),
+		Renderer:   get[string](m, "renderer"),
+		Parser:     get[string](m, "parser"),
+	}, nil
+}
+
+// CreateHandler implements `ollama create`. It parses the Modelfile, or
+// falls back to "FROM ." when none was named, uploads every local model and
+// adapter file as a blob and then asks the server to create the model,
+// rendering progress on stderr.
+func CreateHandler(cmd *cobra.Command, args []string) error {
+	p := progress.NewProgress(os.Stderr)
+	defer p.Stop()
+
+	var reader io.Reader
+
+	filename, err := getModelfileName(cmd)
+	if os.IsNotExist(err) {
+		if filename == "" {
+			reader = strings.NewReader("FROM .\n")
+		} else {
+			return errModelfileNotFound
+		}
+	} else if err != nil {
+		return err
+	} else {
+		f, err := os.Open(filename)
+		if err != nil {
+			return err
+		}
+
+		reader = f
+		defer f.Close()
+	}
+
+	modelfile, err := parser.ParseFile(reader)
+	if err != nil {
+		return err
+	}
+
+	status := "gathering model components"
+	spinner := progress.NewSpinner(status)
+	p.Add(status, spinner)
+
+	req, err := createRequest(modelfile, filepath.Dir(filename))
+	if err != nil {
+		return err
+	}
+	spinner.Stop()
+
+	req.Model = args[0]
+	quantize, _ := cmd.Flags().GetString("quantize")
+	if quantize != "" {
+		req.Quantize = quantize
+	}
+
+	client, err := api.ClientFromEnvironment()
+	if err != nil {
+		return err
+	}
+
+	// the server resolves adapters by digest as well, so their blobs have to
+	// be uploaded alongside the model files
+	var g errgroup.Group
+	g.SetLimit(runtime.GOMAXPROCS(0))
+	for _, list := range []api.Files{req.Files, req.Adapters} {
+		for _, file := range list {
+			g.Go(func() error {
+				_, err := createBlob(cmd, client, file.Path, file.Digest, p)
+				return err
+			})
+		}
+	}
+
+	if err := g.Wait(); err != nil {
+		return err
+	}
+
+	bars := make(map[string]*progress.Bar)
+	fn := func(resp api.ProgressResponse) error {
+		if resp.Digest != "" {
+			bar, ok := bars[resp.Digest]
+			if !ok {
+				msg := resp.Status
+				if msg == "" {
+					msg = fmt.Sprintf("pulling %s...", resp.Digest[7:19])
+				}
+				bar = progress.NewBar(msg, resp.Total, resp.Completed)
+				bars[resp.Digest] = bar
+				p.Add(resp.Digest, bar)
+			}
+
+			bar.Set(resp.Completed)
+		} else if status != resp.Status {
+			spinner.Stop()
+
+			status = resp.Status
+			spinner = progress.NewSpinner(status)
+			p.Add(status, spinner)
+		}
+
+		return nil
+	}
+
+	if err := client.Create(cmd.Context(), req, fn); err != nil {
+		if strings.Contains(err.Error(), "path or Modelfile are required") {
+			return fmt.Errorf("the ollama server must be updated to use `ollama create` with this client")
+		}
+		return err
+	}
+
+	return nil
+}
+
+// createBlob uploads the file at path to the server under digest while a
+// spinner reports the percentage copied. It returns digest on success.
+func createBlob(cmd *cobra.Command, client *api.Client, path string, digest string, p *progress.Progress) (string, error) {
+	realPath, err := filepath.EvalSymlinks(path)
+	if err != nil {
+		return "", err
+	}
+
+	bin, err := os.Open(realPath)
+	if err != nil {
+		return "", err
+	}
+	defer bin.Close()
+
+	// Get file info to retrieve the size
+	fileInfo, err := bin.Stat()
+	if err != nil {
+		return "", err
+	}
+	fileSize := fileInfo.Size()
+
+	var pw progressWriter
+	status := fmt.Sprintf("copying file %s 0%%", digest)
+	spinner := progress.NewSpinner(status)
+	p.Add(status, spinner)
+	defer spinner.Stop()
+
+	done := make(chan struct{})
+	defer close(done)
+
+	go func() {
+		ticker := time.NewTicker(60 * time.Millisecond)
+		defer ticker.Stop()
+		for {
+			select {
+			case <-ticker.C:
+				// an empty file has nothing to report and would divide by zero
+				var percent int
+				if fileSize > 0 {
+					percent = int(100 * pw.n.Load() / fileSize)
+				}
+				spinner.SetMessage(fmt.Sprintf("copying file %s %d%%", digest, percent))
+			case <-done:
+				spinner.SetMessage(fmt.Sprintf("copying file %s 100%%", digest))
+				return
+			}
+		}
+	}()
+
+	if err := client.CreateBlob(cmd.Context(), digest, io.TeeReader(bin, &pw)); err != nil {
+		return "", err
+	}
+	return digest, nil
+}
diff --git a/cmd/interactive.go b/cmd/interactive.go index e290d84ce..222f643df 100644 --- a/cmd/interactive.go +++ b/cmd/interactive.go @@ -316,7 +316,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error { continue } params := args[3:] -
fp, err := api.FormatParams(map[string][]string{args[2]: params}) + fp, err := api.FormatParameters(map[string][]string{args[2]: params}) if err != nil { fmt.Printf("Couldn't set parameter: %q\n", err) continue diff --git a/parser/parser.go b/parser/parser.go index bc16dd399..3f441bf7d 100644 --- a/parser/parser.go +++ b/parser/parser.go @@ -3,25 +3,14 @@ package parser import ( "bufio" "bytes" - "crypto/sha256" "errors" "fmt" "io" - "net/http" - "os" - "os/user" - "path/filepath" - "runtime" - "slices" "strconv" "strings" - "sync" - "golang.org/x/sync/errgroup" "golang.org/x/text/encoding/unicode" "golang.org/x/text/transform" - - "github.com/ollama/ollama/api" ) var ErrModelNotFound = errors.New("no Modelfile or safetensors files found") @@ -39,281 +28,6 @@ func (f Modelfile) String() string { return sb.String() } -var deprecatedParameters = []string{ - "penalize_newline", - "low_vram", - "f16_kv", - "logits_all", - "vocab_only", - "use_mlock", - "mirostat", - "mirostat_tau", - "mirostat_eta", -} - -// CreateRequest creates a new *api.CreateRequest from an existing Modelfile -func (f Modelfile) CreateRequest(relativeDir string) (*api.CreateRequest, error) { - req := &api.CreateRequest{} - - var messages []api.Message - var licenses []string - params := make(map[string]any) - - for _, c := range f.Commands { - switch c.Name { - case "model": - path, err := expandPath(c.Args, relativeDir) - if err != nil { - return nil, err - } - - digestMap, err := fileDigestMap(path) - if errors.Is(err, os.ErrNotExist) { - req.From = c.Args - continue - } else if err != nil { - return nil, err - } - - if req.Files == nil { - req.Files = digestMap - } else { - for k, v := range digestMap { - req.Files[k] = v - } - } - case "adapter": - path, err := expandPath(c.Args, relativeDir) - if err != nil { - return nil, err - } - - digestMap, err := fileDigestMap(path) - if err != nil { - return nil, err - } - - req.Adapters = digestMap - case "template": - req.Template = c.Args - case 
"system": - req.System = c.Args - case "license": - licenses = append(licenses, c.Args) - case "renderer": - req.Renderer = c.Args - case "parser": - req.Parser = c.Args - case "message": - role, msg, _ := strings.Cut(c.Args, ": ") - messages = append(messages, api.Message{Role: role, Content: msg}) - default: - if slices.Contains(deprecatedParameters, c.Name) { - fmt.Printf("warning: parameter %s is deprecated\n", c.Name) - break - } - - ps, err := api.FormatParams(map[string][]string{c.Name: {c.Args}}) - if err != nil { - return nil, err - } - - for k, v := range ps { - if ks, ok := params[k].([]string); ok { - params[k] = append(ks, v.([]string)...) - } else if vs, ok := v.([]string); ok { - params[k] = vs - } else { - params[k] = v - } - } - } - } - - if len(params) > 0 { - req.Parameters = params - } - if len(messages) > 0 { - req.Messages = messages - } - if len(licenses) > 0 { - req.License = licenses - } - - return req, nil -} - -func fileDigestMap(path string) (map[string]string, error) { - fl := make(map[string]string) - - fi, err := os.Stat(path) - if err != nil { - return nil, err - } - - var files []string - if fi.IsDir() { - fs, err := filesForModel(path) - if err != nil { - return nil, err - } - - for _, f := range fs { - f, err := filepath.EvalSymlinks(f) - if err != nil { - return nil, err - } - - rel, err := filepath.Rel(path, f) - if err != nil { - return nil, err - } - - if !filepath.IsLocal(rel) { - return nil, fmt.Errorf("insecure path: %s", rel) - } - - files = append(files, f) - } - } else { - files = []string{path} - } - - var mu sync.Mutex - var g errgroup.Group - g.SetLimit(max(runtime.GOMAXPROCS(0)-1, 1)) - for _, f := range files { - g.Go(func() error { - digest, err := digestForFile(f) - if err != nil { - return err - } - - mu.Lock() - defer mu.Unlock() - fl[f] = digest - return nil - }) - } - - if err := g.Wait(); err != nil { - return nil, err - } - - return fl, nil -} - -func digestForFile(filename string) (string, error) { - 
filepath, err := filepath.EvalSymlinks(filename) - if err != nil { - return "", err - } - - bin, err := os.Open(filepath) - if err != nil { - return "", err - } - defer bin.Close() - - hash := sha256.New() - if _, err := io.Copy(hash, bin); err != nil { - return "", err - } - return fmt.Sprintf("sha256:%x", hash.Sum(nil)), nil -} - -func filesForModel(path string) ([]string, error) { - detectContentType := func(path string) (string, error) { - f, err := os.Open(path) - if err != nil { - return "", err - } - defer f.Close() - - var b bytes.Buffer - b.Grow(512) - - if _, err := io.CopyN(&b, f, 512); err != nil && !errors.Is(err, io.EOF) { - return "", err - } - - contentType, _, _ := strings.Cut(http.DetectContentType(b.Bytes()), ";") - return contentType, nil - } - - glob := func(pattern, contentType string) ([]string, error) { - matches, err := filepath.Glob(pattern) - if err != nil { - return nil, err - } - - for _, match := range matches { - if ct, err := detectContentType(match); err != nil { - return nil, err - } else if len(contentType) > 0 && ct != contentType { - return nil, fmt.Errorf("invalid content type: expected %s for %s", ct, match) - } - } - - return matches, nil - } - - var files []string - // some safetensors files do not properly match "application/octet-stream", so skip checking their contentType - if st, _ := glob(filepath.Join(path, "*.safetensors"), ""); len(st) > 0 { - // safetensors files might be unresolved git lfs references; skip if they are - // covers model-x-of-y.safetensors, model.fp32-x-of-y.safetensors, model.safetensors - files = append(files, st...) - } else if pt, _ := glob(filepath.Join(path, "pytorch_model*.bin"), "application/zip"); len(pt) > 0 { - // pytorch files might also be unresolved git lfs references; skip if they are - // covers pytorch_model-x-of-y.bin, pytorch_model.fp32-x-of-y.bin, pytorch_model.bin - files = append(files, pt...) 
- } else if pt, _ := glob(filepath.Join(path, "consolidated*.pth"), "application/zip"); len(pt) > 0 { - // pytorch files might also be unresolved git lfs references; skip if they are - // covers consolidated.x.pth, consolidated.pth - files = append(files, pt...) - } else if gg, _ := glob(filepath.Join(path, "*.gguf"), "application/octet-stream"); len(gg) > 0 { - // covers gguf files ending in .gguf - files = append(files, gg...) - } else if gg, _ := glob(filepath.Join(path, "*.bin"), "application/octet-stream"); len(gg) > 0 { - // covers gguf files ending in .bin - files = append(files, gg...) - } else { - return nil, ErrModelNotFound - } - - // add configuration files, json files are detected as text/plain - js, err := glob(filepath.Join(path, "*.json"), "text/plain") - if err != nil { - return nil, err - } - files = append(files, js...) - - // bert models require a nested config.json - // TODO(mxyng): merge this with the glob above - js, err = glob(filepath.Join(path, "**/*.json"), "text/plain") - if err != nil { - return nil, err - } - files = append(files, js...) - - // only include tokenizer.model is tokenizer.json is not present - if !slices.ContainsFunc(files, func(s string) bool { - return slices.Contains(strings.Split(s, string(os.PathSeparator)), "tokenizer.json") - }) { - if tks, _ := glob(filepath.Join(path, "tokenizer.model"), "application/octet-stream"); len(tks) > 0 { - // add tokenizer.model if it exists, tokenizer.json is automatically picked up by the previous glob - // tokenizer.model might be a unresolved git lfs reference; error if it is - files = append(files, tks...) - } else if tks, _ := glob(filepath.Join(path, "**/tokenizer.model"), "text/plain"); len(tks) > 0 { - // some times tokenizer.model is in a subdirectory (e.g. meta-llama/Meta-Llama-3-8B) - files = append(files, tks...) 
- } - } - - return files, nil -} - type Command struct { Name string Args string @@ -340,7 +54,7 @@ type state int const ( stateNil state = iota - stateName + stateKey stateValue stateParameter stateMessage @@ -368,7 +82,7 @@ func (e *ParserError) Error() string { func ParseFile(r io.Reader) (*Modelfile, error) { var cmd Command var curr state - var currLine int = 1 + currLine := 1 var b bytes.Buffer var role string @@ -402,7 +116,7 @@ func ParseFile(r io.Reader) (*Modelfile, error) { // process the state transition, some transitions need to be intercepted and redirected if next != curr { switch curr { - case stateName: + case stateKey: if !isValidCommand(b.String()) { return nil, &ParserError{ LineNumber: currLine, @@ -505,12 +219,12 @@ func parseRuneForState(r rune, cs state) (state, rune, error) { case isSpace(r), isNewline(r): return stateNil, 0, nil default: - return stateName, r, nil + return stateKey, r, nil } - case stateName: + case stateKey: switch { case isAlpha(r): - return stateName, r, nil + return stateKey, r, nil case isSpace(r): return stateValue, 0, nil default: @@ -616,43 +330,3 @@ func isValidCommand(cmd string) bool { return false } } - -func expandPath(path, dir string) (string, error) { - if filepath.IsAbs(path) { - return path, nil - } - - path, found := strings.CutPrefix(path, "~") - if !found { - // make path relative to dir - if !filepath.IsAbs(dir) { - // if dir is relative, make it absolute relative to cwd - cwd, err := os.Getwd() - if err != nil { - return "", err - } - dir = filepath.Join(cwd, dir) - } - path = filepath.Join(dir, path) - } else if filepath.IsLocal(path) { - // ~/... - // make path relative to specified user's home - split := strings.SplitN(path, "/", 2) - u, err := user.Lookup(split[0]) - if err != nil { - return "", err - } - split[0] = u.HomeDir - path = filepath.Join(split...) - } else { - // ~ or ~/... 
- // make path relative to current user's home - home, err := os.UserHomeDir() - if err != nil { - return "", err - } - path = filepath.Join(home, path) - } - - return filepath.Clean(path), nil -} diff --git a/server/create.go b/server/create.go index 19f24ec80..3e2d09b93 100644 --- a/server/create.go +++ b/server/create.go @@ -62,8 +62,8 @@ func (s *Server) CreateHandler(c *gin.Context) { config.Renderer = r.Renderer config.Parser = r.Parser - for v := range r.Files { - if !fs.ValidPath(v) { + for _, v := range r.Files { + if !fs.ValidPath(v.Name) { c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": errFilePath.Error()}) return } @@ -276,7 +276,7 @@ func remoteURL(raw string) (string, error) { return u.String(), nil } -func convertModelFromFiles(files map[string]string, baseLayers []*layerGGML, isAdapter bool, fn func(resp api.ProgressResponse)) ([]*layerGGML, error) { +func convertModelFromFiles(files api.Files, baseLayers []*layerGGML, isAdapter bool, fn func(resp api.ProgressResponse)) ([]*layerGGML, error) { switch detectModelTypeFromFiles(files) { case "safetensors": layers, err := convertFromSafetensors(files, baseLayers, isAdapter, fn) @@ -295,7 +295,7 @@ func convertModelFromFiles(files map[string]string, baseLayers []*layerGGML, isA var digest string var allLayers []*layerGGML for _, v := range files { - digest = v + digest = v.Digest layers, err := ggufLayers(digest, fn) if err != nil { return nil, err @@ -308,15 +308,15 @@ func convertModelFromFiles(files map[string]string, baseLayers []*layerGGML, isA } } -func detectModelTypeFromFiles(files map[string]string) string { - for fn := range files { - if strings.HasSuffix(fn, ".safetensors") { +func detectModelTypeFromFiles(files api.Files) string { + for _, fn := range files { + if strings.HasSuffix(fn.Name, ".safetensors") { return "safetensors" - } else if strings.HasSuffix(fn, ".gguf") { + } else if strings.HasSuffix(fn.Name, ".gguf") { return "gguf" } else { // try to see if we can find a gguf 
file even without the file extension - blobPath, err := GetBlobsPath(files[fn]) + blobPath, err := GetBlobsPath(fn.Digest) if err != nil { slog.Error("error getting blobs path", "file", fn) return "" @@ -346,7 +346,7 @@ func detectModelTypeFromFiles(files map[string]string) string { return "" } -func convertFromSafetensors(files map[string]string, baseLayers []*layerGGML, isAdapter bool, fn func(resp api.ProgressResponse)) ([]*layerGGML, error) { +func convertFromSafetensors(files api.Files, baseLayers []*layerGGML, isAdapter bool, fn func(resp api.ProgressResponse)) ([]*layerGGML, error) { tmpDir, err := os.MkdirTemp(envconfig.Models(), "ollama-safetensors") if err != nil { return nil, err @@ -359,20 +359,20 @@ func convertFromSafetensors(files map[string]string, baseLayers []*layerGGML, is } defer root.Close() - for fp, digest := range files { - if !fs.ValidPath(fp) { - return nil, fmt.Errorf("%w: %s", errFilePath, fp) + for _, fn := range files { + if !fs.ValidPath(fn.Name) { + return nil, fmt.Errorf("%w: %s", errFilePath, fn) } - if _, err := root.Stat(fp); err != nil && !errors.Is(err, fs.ErrNotExist) { + if _, err := root.Stat(fn.Name); err != nil && !errors.Is(err, fs.ErrNotExist) { // Path is likely outside the root - return nil, fmt.Errorf("%w: %s: %s", errFilePath, err, fp) + return nil, fmt.Errorf("%w: %s: %s", errFilePath, err, fn) } - blobPath, err := GetBlobsPath(digest) + blobPath, err := GetBlobsPath(fn.Digest) if err != nil { return nil, err } - if err := createLink(blobPath, filepath.Join(tmpDir, fp)); err != nil { + if err := createLink(blobPath, filepath.Join(tmpDir, fn.Name)); err != nil { return nil, err } } diff --git a/types/syncmap/syncmap.go b/types/syncmap/syncmap.go deleted file mode 100644 index ff21cd999..000000000 --- a/types/syncmap/syncmap.go +++ /dev/null @@ -1,38 +0,0 @@ -package syncmap - -import ( - "maps" - "sync" -) - -// SyncMap is a simple, generic thread-safe map implementation. 
-type SyncMap[K comparable, V any] struct { - mu sync.RWMutex - m map[K]V -} - -func NewSyncMap[K comparable, V any]() *SyncMap[K, V] { - return &SyncMap[K, V]{ - m: make(map[K]V), - } -} - -func (s *SyncMap[K, V]) Load(key K) (V, bool) { - s.mu.RLock() - defer s.mu.RUnlock() - val, ok := s.m[key] - return val, ok -} - -func (s *SyncMap[K, V]) Store(key K, value V) { - s.mu.Lock() - defer s.mu.Unlock() - s.m[key] = value -} - -func (s *SyncMap[K, V]) Items() map[K]V { - s.mu.RLock() - defer s.mu.RUnlock() - // shallow copy map items - return maps.Clone(s.m) -}