ollama/cmd/create.go

454 lines
9.6 KiB
Go

package cmd
import (
"crypto/sha256"
"encoding/hex"
"errors"
"fmt"
"io"
"io/fs"
"iter"
"log/slog"
"net/http"
"os"
"os/user"
"path/filepath"
"runtime"
"slices"
"strings"
"sync"
"time"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/parser"
"github.com/ollama/ollama/progress"
"github.com/spf13/cobra"
"golang.org/x/sync/errgroup"
)
func expandPath(path, dir string) (string, error) {
if filepath.IsAbs(path) {
return path, nil
}
path, found := strings.CutPrefix(path, "~")
if !found {
// make path relative to dir
if !filepath.IsAbs(dir) {
// if dir is relative, make it absolute relative to cwd
cwd, err := os.Getwd()
if err != nil {
return "", err
}
dir = filepath.Join(cwd, dir)
}
path = filepath.Join(dir, path)
} else if filepath.IsLocal(path) {
// ~<user>/...
// make path relative to specified user's home
split := strings.SplitN(path, "/", 2)
u, err := user.Lookup(split[0])
if err != nil {
return "", err
}
split[0] = u.HomeDir
path = filepath.Join(split...)
} else {
// ~ or ~/...
// make path relative to current user's home
home, err := os.UserHomeDir()
if err != nil {
return "", err
}
path = filepath.Join(home, path)
}
return filepath.Clean(path), nil
}
func detectContentType(fsys fs.FS, path string) (string, error) {
f, err := fsys.Open(path)
if err != nil {
return "", err
}
defer f.Close()
b := make([]byte, 512)
if _, err := f.Read(b); err != nil && err != io.EOF {
return "", err
}
contentType, _, _ := strings.Cut(http.DetectContentType(b), ";")
return contentType, nil
}
// glob returns an iterator that yields files matching the given patterns and content types.
// The patterns and content types are provided as pairs of strings.
// If a content type is an empty string, all files matching the pattern are yielded.
// The iterator stops after the first pattern that matches any files.
func glob(fsys fs.FS, patternOrContentType ...string) iter.Seq2[string, error] {
if len(patternOrContentType)%2 != 0 {
panic("glob: patternOrContentType must have an even number of elements")
}
return func(yield func(string, error) bool) {
for i := 0; i < len(patternOrContentType); i += 2 {
pattern := patternOrContentType[i]
contentType := patternOrContentType[i+1]
matches, err := fs.Glob(fsys, pattern)
if err != nil {
yield("", err)
return
}
if len(matches) > 0 {
for _, match := range matches {
if contentType == "" {
if !yield(match, nil) {
return
}
continue
}
ct, err := detectContentType(fsys, match)
if err != nil {
yield("", err)
return
}
if ct == contentType {
if !yield(match, nil) {
return
}
}
}
return
}
}
}
}
func filesSeq(fsys fs.FS) iter.Seq[string] {
return func(yield func(string) bool) {
for match := range glob(fsys,
"*.safetensors", "",
"*.bin", "application/zip",
"*.pth", "application/zip",
"*.gguf", "application/octet-stream",
"*.bin", "application/octet-stream") {
if !yield(match) {
return
}
}
for match := range glob(fsys,
"tokenizer.json", "application/json",
"tokenizer.model", "application/octet-stream",
) {
if !yield(match) {
return
}
}
for match := range glob(fsys, "*.json", "") {
if !yield(match) {
return
}
}
for match := range glob(fsys, "**/*.json", "") {
if !yield(match) {
return
}
}
}
}
func get[T any](m map[string]any, key string) (t T) {
if v, ok := m[key].(T); ok {
t = v
}
return
}
var deprecatedParameters = []string{
"penalize_newline",
"low_vram",
"f16_kv",
"logits_all",
"vocab_only",
"use_mlock",
"mirostat",
"mirostat_tau",
"mirostat_eta",
}
func createRequest(modelfile *parser.Modelfile, dir string) (*api.CreateRequest, error) {
m := make(map[string]any)
parameters := make(map[string]any)
var files, adapters []api.File
var g errgroup.Group
g.SetLimit(runtime.GOMAXPROCS(0))
for _, cmd := range modelfile.Commands {
switch cmd.Name {
case "model", "adapter":
path, err := expandPath(cmd.Args, dir)
if err != nil {
return nil, err
}
fsys := os.DirFS(path)
seq := filesSeq(fsys)
if fi, err := os.Stat(path); errors.Is(err, os.ErrNotExist) {
m["from"] = cmd.Args
break
} else if err != nil {
return nil, err
} else if !fi.IsDir() {
base := filepath.Base(path)
path = filepath.Dir(path)
seq = func(yield func(string) bool) {
yield(base)
}
}
var mu sync.Mutex
for file := range seq {
g.Go(func() error {
f, err := os.Open(filepath.Join(path, file))
if err != nil {
return err
}
defer f.Close()
sha256sum := sha256.New()
if _, err := io.Copy(sha256sum, f); err != nil {
return err
}
file := api.File{
Name: file,
Path: filepath.Join(path, file),
Digest: "sha256:" + hex.EncodeToString(sha256sum.Sum(nil)),
}
mu.Lock()
defer mu.Unlock()
switch cmd.Name {
case "model":
files = append(files, file)
case "adapter":
adapters = append(adapters, file)
}
return nil
})
}
case "template", "system", "renderer", "parser":
m[cmd.Name] = cmd.Args
case "license":
m[cmd.Name] = append(get[[]string](m, cmd.Name), cmd.Args)
case "message":
role, msg, found := strings.Cut(cmd.Args, ": ")
if !found {
return nil, fmt.Errorf("invalid message command: %s", cmd.Args)
}
m[cmd.Name] = append(get[[]api.Message](m, cmd.Name), api.Message{
Role: role,
Content: msg,
})
default:
if slices.Contains(deprecatedParameters, cmd.Name) {
slog.Warn("parameter is deprecated", "name", cmd.Name)
break
}
ps, err := api.FormatParameters(map[string][]string{cmd.Name: {cmd.Args}})
if err != nil {
return nil, err
}
for k, v := range ps {
if ks, ok := parameters[k].([]string); ok {
parameters[k] = append(ks, v.([]string)...)
} else if vs, ok := v.([]string); ok {
parameters[k] = vs
} else {
parameters[k] = v
}
}
}
}
if err := g.Wait(); err != nil {
return nil, err
}
return &api.CreateRequest{
From: get[string](m, "from"),
Files: files,
Adapters: adapters,
License: get[[]string](m, "license"),
Messages: get[[]api.Message](m, "message"),
Parameters: parameters,
Parser: get[string](m, "parser"),
Renderer: get[string](m, "renderer"),
System: get[string](m, "system"),
Template: get[string](m, "template"),
}, nil
}
func CreateHandler(cmd *cobra.Command, args []string) error {
p := progress.NewProgress(os.Stderr)
defer p.Stop()
var reader io.Reader
filename, err := getModelfileName(cmd)
if os.IsNotExist(err) {
if filename == "" {
reader = strings.NewReader("FROM .\n")
} else {
return errModelfileNotFound
}
} else if err != nil {
return err
} else {
f, err := os.Open(filename)
if err != nil {
return err
}
reader = f
defer f.Close()
}
modelfile, err := parser.ParseFile(reader)
if err != nil {
return err
}
status := "gathering model components"
spinner := progress.NewSpinner(status)
p.Add(status, spinner)
req, err := createRequest(modelfile, filepath.Dir(filename))
if err != nil {
return err
}
spinner.Stop()
req.Model = args[0]
quantize, _ := cmd.Flags().GetString("quantize")
if quantize != "" {
req.Quantize = quantize
}
client, err := api.ClientFromEnvironment()
if err != nil {
return err
}
var g errgroup.Group
g.SetLimit(runtime.GOMAXPROCS(0))
for _, file := range req.Files {
g.Go(func() error {
_, err := createBlob(cmd, client, file.Path, file.Digest, p)
return err
})
}
if err := g.Wait(); err != nil {
return err
}
bars := make(map[string]*progress.Bar)
fn := func(resp api.ProgressResponse) error {
if resp.Digest != "" {
bar, ok := bars[resp.Digest]
if !ok {
msg := resp.Status
if msg == "" {
msg = fmt.Sprintf("pulling %s...", resp.Digest[7:19])
}
bar = progress.NewBar(msg, resp.Total, resp.Completed)
bars[resp.Digest] = bar
p.Add(resp.Digest, bar)
}
bar.Set(resp.Completed)
} else if status != resp.Status {
spinner.Stop()
status = resp.Status
spinner = progress.NewSpinner(status)
p.Add(status, spinner)
}
return nil
}
if err := client.Create(cmd.Context(), req, fn); err != nil {
if strings.Contains(err.Error(), "path or Modelfile are required") {
return fmt.Errorf("the ollama server must be updated to use `ollama create` with this client")
}
return err
}
return nil
}
func createBlob(cmd *cobra.Command, client *api.Client, path string, digest string, p *progress.Progress) (string, error) {
realPath, err := filepath.EvalSymlinks(path)
if err != nil {
return "", err
}
bin, err := os.Open(realPath)
if err != nil {
return "", err
}
defer bin.Close()
// Get file info to retrieve the size
fileInfo, err := bin.Stat()
if err != nil {
return "", err
}
fileSize := fileInfo.Size()
var pw progressWriter
status := fmt.Sprintf("copying file %s 0%%", digest)
spinner := progress.NewSpinner(status)
p.Add(status, spinner)
defer spinner.Stop()
done := make(chan struct{})
defer close(done)
go func() {
ticker := time.NewTicker(60 * time.Millisecond)
defer ticker.Stop()
for {
select {
case <-ticker.C:
spinner.SetMessage(fmt.Sprintf("copying file %s %d%%", digest, int(100*pw.n.Load()/fileSize)))
case <-done:
spinner.SetMessage(fmt.Sprintf("copying file %s 100%%", digest))
return
}
}
}()
if err := client.CreateBlob(cmd.Context(), digest, io.TeeReader(bin, &pw)); err != nil {
return "", err
}
return digest, nil
}