fix: create with nested directories

Specifically, models with a nested config.json would mangle the top-level
config.json: a race caused an error whenever the nested config.json was
written after the top-level config.json.

This also fixes an issue with embedding models where 1_Pooling/config.json
could not be read, since nested files were previously written out as a
top-level config.json.
This commit is contained in:
Michael Yang 2025-09-18 17:24:39 -07:00
parent 220a0da37e
commit b846eacf42
8 changed files with 496 additions and 569 deletions

View File

@ -471,10 +471,10 @@ type CreateRequest struct {
RemoteHost string `json:"remote_host,omitempty"`
// Files is a map of files include when creating the model.
Files map[string]string `json:"files,omitempty"`
Files Files `json:"files,omitempty"`
// Adapters is a map of LoRA adapters to include when creating the model.
Adapters map[string]string `json:"adapters,omitempty"`
Adapters Files `json:"adapters,omitempty"`
// Template is the template used when constructing a request to the model.
Template string `json:"template,omitempty"`
@ -503,6 +503,31 @@ type CreateRequest struct {
Quantization string `json:"quantization,omitempty"`
}
// Files is the list of files to include when creating a model. On the
// wire it is encoded as a JSON object mapping file name to digest, for
// compatibility with older clients that used map[string]string.
type Files []File

// MarshalJSON encodes the file list as a {name: digest} object. The
// Path field is client-side only and is intentionally not serialized.
func (f Files) MarshalJSON() ([]byte, error) {
	m := make(map[string]string, len(f))
	for _, file := range f {
		m[file.Name] = file.Digest
	}
	return json.Marshal(m)
}

// UnmarshalJSON decodes a {name: digest} object into the file list.
// Entry order is unspecified because JSON objects are unordered.
func (f *Files) UnmarshalJSON(data []byte) error {
	m := make(map[string]string)
	if err := json.Unmarshal(data, &m); err != nil {
		return err
	}

	// reset so repeated unmarshals replace the contents rather than
	// append, matching encoding/json's slice semantics
	*f = (*f)[:0]
	for name, digest := range m {
		*f = append(*f, File{Name: name, Digest: digest})
	}
	return nil
}

// File describes a single model component file:
//   - Name: path relative to the model directory
//   - Path: absolute local path on the client (not serialized)
//   - Digest: "sha256:<hex>" digest of the file contents
type File struct {
	Name, Path, Digest string
}
// DeleteRequest is the request passed to [Client.Delete].
type DeleteRequest struct {
Model string `json:"model"`
@ -988,8 +1013,8 @@ func (d *Duration) UnmarshalJSON(b []byte) (err error) {
return nil
}
// FormatParams converts specified parameter options to their correct types
func FormatParams(params map[string][]string) (map[string]any, error) {
// FormatParameters converts specified parameter options to their correct types
func FormatParameters(params map[string][]string) (map[string]any, error) {
opts := Options{}
valueOpts := reflect.ValueOf(&opts).Elem() // names of the fields in the options struct
typeOpts := reflect.TypeOf(opts) // types of the fields in the options struct

View File

@ -203,7 +203,7 @@ func TestUseMmapFormatParams(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
resp, err := FormatParams(test.req)
resp, err := FormatParameters(test.req)
require.Equal(t, test.err, err)
respVal, ok := resp["use_mmap"]
if test.exp != nil {

View File

@ -33,20 +33,17 @@ import (
"github.com/olekukonko/tablewriter"
"github.com/spf13/cobra"
"golang.org/x/crypto/ssh"
"golang.org/x/sync/errgroup"
"golang.org/x/term"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/auth"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/format"
"github.com/ollama/ollama/parser"
"github.com/ollama/ollama/progress"
"github.com/ollama/ollama/readline"
"github.com/ollama/ollama/runner"
"github.com/ollama/ollama/server"
"github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/types/syncmap"
"github.com/ollama/ollama/version"
)
@ -89,179 +86,6 @@ func getModelfileName(cmd *cobra.Command) (string, error) {
return absName, nil
}
// CreateHandler implements `ollama create`: it parses the Modelfile (or
// an implicit "FROM ." when none is given), uploads the referenced model
// and adapter files as blobs, then asks the server to create the model,
// streaming progress to stderr.
func CreateHandler(cmd *cobra.Command, args []string) error {
	p := progress.NewProgress(os.Stderr)
	defer p.Stop()

	var reader io.Reader

	filename, err := getModelfileName(cmd)
	if os.IsNotExist(err) {
		if filename == "" {
			// no -f flag and no default Modelfile: build from the cwd
			reader = strings.NewReader("FROM .\n")
		} else {
			return errModelfileNotFound
		}
	} else if err != nil {
		return err
	} else {
		f, err := os.Open(filename)
		if err != nil {
			return err
		}
		reader = f
		defer f.Close()
	}

	modelfile, err := parser.ParseFile(reader)
	if err != nil {
		return err
	}

	status := "gathering model components"
	spinner := progress.NewSpinner(status)
	p.Add(status, spinner)

	// resolves MODEL/ADAPTER paths and digests the local files
	req, err := modelfile.CreateRequest(filepath.Dir(filename))
	if err != nil {
		return err
	}
	spinner.Stop()

	req.Model = args[0]
	quantize, _ := cmd.Flags().GetString("quantize")
	if quantize != "" {
		req.Quantize = quantize
	}

	client, err := api.ClientFromEnvironment()
	if err != nil {
		return err
	}

	// upload blobs concurrently, leaving one CPU for the progress UI
	var g errgroup.Group
	g.SetLimit(max(runtime.GOMAXPROCS(0)-1, 1))

	files := syncmap.NewSyncMap[string, string]()
	for f, digest := range req.Files {
		g.Go(func() error {
			if _, err := createBlob(cmd, client, f, digest, p); err != nil {
				return err
			}

			// TODO: this is incorrect since the file might be in a subdirectory
			// instead this should take the path relative to the model directory
			// but the current implementation does not allow this
			files.Store(filepath.Base(f), digest)
			return nil
		})
	}

	adapters := syncmap.NewSyncMap[string, string]()
	for f, digest := range req.Adapters {
		g.Go(func() error {
			if _, err := createBlob(cmd, client, f, digest, p); err != nil {
				return err
			}
			// TODO: same here
			adapters.Store(filepath.Base(f), digest)
			return nil
		})
	}

	if err := g.Wait(); err != nil {
		return err
	}

	// rewrite the request to reference the uploaded blobs by base name
	req.Files = files.Items()
	req.Adapters = adapters.Items()

	bars := make(map[string]*progress.Bar)
	fn := func(resp api.ProgressResponse) error {
		if resp.Digest != "" {
			// per-layer progress: one bar keyed by digest
			bar, ok := bars[resp.Digest]
			if !ok {
				msg := resp.Status
				if msg == "" {
					msg = fmt.Sprintf("pulling %s...", resp.Digest[7:19])
				}
				bar = progress.NewBar(msg, resp.Total, resp.Completed)
				bars[resp.Digest] = bar
				p.Add(resp.Digest, bar)
			}
			bar.Set(resp.Completed)
		} else if status != resp.Status {
			// status-only update: replace the spinner with a fresh one
			spinner.Stop()

			status = resp.Status
			spinner = progress.NewSpinner(status)
			p.Add(status, spinner)
		}
		return nil
	}

	if err := client.Create(cmd.Context(), req, fn); err != nil {
		if strings.Contains(err.Error(), "path or Modelfile are required") {
			return fmt.Errorf("the ollama server must be updated to use `ollama create` with this client")
		}
		return err
	}
	return nil
}
// createBlob uploads the file at path to the server as a blob with the
// given (pre-computed) digest, showing a percentage spinner while the
// copy is in flight. It returns the digest on success.
func createBlob(cmd *cobra.Command, client *api.Client, path string, digest string, p *progress.Progress) (string, error) {
	// resolve symlinks so the real file contents are uploaded
	realPath, err := filepath.EvalSymlinks(path)
	if err != nil {
		return "", err
	}

	bin, err := os.Open(realPath)
	if err != nil {
		return "", err
	}
	defer bin.Close()

	// stat for the total size so progress can be shown as a percentage
	fileInfo, err := bin.Stat()
	if err != nil {
		return "", err
	}
	fileSize := fileInfo.Size()

	var pw progressWriter
	status := fmt.Sprintf("copying file %s 0%%", digest)
	spinner := progress.NewSpinner(status)
	p.Add(status, spinner)
	defer spinner.Stop()

	done := make(chan struct{})
	defer close(done)

	go func() {
		ticker := time.NewTicker(60 * time.Millisecond)
		defer ticker.Stop()
		for {
			select {
			case <-ticker.C:
				// guard against division by zero for empty files
				if fileSize > 0 {
					spinner.SetMessage(fmt.Sprintf("copying file %s %d%%", digest, int(100*pw.n.Load()/fileSize)))
				}
			case <-done:
				spinner.SetMessage(fmt.Sprintf("copying file %s 100%%", digest))
				return
			}
		}
	}()

	if err := client.CreateBlob(cmd.Context(), digest, io.TeeReader(bin, &pw)); err != nil {
		return "", err
	}
	return digest, nil
}
// progressWriter counts the bytes written through it so the spinner
// goroutine in createBlob can report upload progress. n is atomic
// because it is written by the upload and read by the spinner goroutine.
type progressWriter struct {
	n atomic.Int64
}

442
cmd/create.go Normal file
View File

@ -0,0 +1,442 @@
package cmd
import (
"crypto/sha256"
"encoding/hex"
"fmt"
"io"
"io/fs"
"iter"
"log/slog"
"net/http"
"os"
"os/user"
"path/filepath"
"runtime"
"slices"
"strings"
"sync"
"time"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/parser"
"github.com/ollama/ollama/progress"
"github.com/spf13/cobra"
"golang.org/x/sync/errgroup"
)
func expandPath(path, dir string) (string, error) {
if filepath.IsAbs(path) {
return path, nil
}
path, found := strings.CutPrefix(path, "~")
if !found {
// make path relative to dir
if !filepath.IsAbs(dir) {
// if dir is relative, make it absolute relative to cwd
cwd, err := os.Getwd()
if err != nil {
return "", err
}
dir = filepath.Join(cwd, dir)
}
path = filepath.Join(dir, path)
} else if filepath.IsLocal(path) {
// ~<user>/...
// make path relative to specified user's home
split := strings.SplitN(path, "/", 2)
u, err := user.Lookup(split[0])
if err != nil {
return "", err
}
split[0] = u.HomeDir
path = filepath.Join(split...)
} else {
// ~ or ~/...
// make path relative to current user's home
home, err := os.UserHomeDir()
if err != nil {
return "", err
}
path = filepath.Join(home, path)
}
return filepath.Clean(path), nil
}
func detectContentType(fsys fs.FS, path string) (string, error) {
f, err := fsys.Open(path)
if err != nil {
return "", err
}
defer f.Close()
b := make([]byte, 512)
if _, err := f.Read(b); err != nil && err != io.EOF {
return "", err
}
contentType, _, _ := strings.Cut(http.DetectContentType(b), ";")
return contentType, nil
}
// glob returns an iterator that yields files matching the given patterns and content types.
// The patterns and content types are provided as pairs of strings.
// If a content type is an empty string, all files matching the pattern are yielded.
// The iterator stops after the first pattern that matches any files.
// glob returns an iterator that yields files matching the given patterns and content types.
// The patterns and content types are provided as pairs of strings.
// If a content type is an empty string, all files matching the pattern are yielded.
// The iterator stops after the first pattern that matches any files.
// On error the iterator yields an empty name with the error and stops.
func glob(fsys fs.FS, patternOrContentType ...string) iter.Seq2[string, error] {
	if len(patternOrContentType)%2 != 0 {
		// programmer error: arguments must come as (pattern, contentType) pairs
		panic("glob: patternOrContentType must have an even number of elements")
	}

	return func(yield func(string, error) bool) {
		for i := 0; i < len(patternOrContentType); i += 2 {
			pattern := patternOrContentType[i]
			contentType := patternOrContentType[i+1]

			matches, err := fs.Glob(fsys, pattern)
			if err != nil {
				// surface the error once, then stop iterating
				yield("", err)
				return
			}

			if len(matches) > 0 {
				for _, match := range matches {
					if contentType == "" {
						// no content-type filter: yield every match
						if !yield(match, nil) {
							return
						}
						continue
					}

					ct, err := detectContentType(fsys, match)
					if err != nil {
						yield("", err)
						return
					}

					// matches with a different sniffed type are silently skipped
					if ct == contentType {
						if !yield(match, nil) {
							return
						}
					}
				}

				// first pattern with any matches wins; later pairs are ignored
				return
			}
		}
	}
}
// filesSeq yields the model component files found in fsys, in priority
// order: model weight files first, then tokenizer files, then top-level
// and nested JSON configuration files. Only the matched names are
// consumed from glob; when glob reports an error it yields an empty
// name, which is passed through here.
func filesSeq(fsys fs.FS) iter.Seq[string] {
	return func(yield func(string) bool) {
		// emit forwards every match from one glob pass and reports
		// false when the consumer stopped the iteration.
		emit := func(pairs ...string) bool {
			for match := range glob(fsys, pairs...) {
				if !yield(match) {
					return false
				}
			}
			return true
		}

		// model weights: the first pattern with any matches wins
		if !emit(
			"*.safetensors", "",
			"*.bin", "application/zip",
			"*.pth", "application/zip",
			"*.gguf", "application/octet-stream",
			"*.bin", "application/octet-stream",
		) {
			return
		}

		// tokenizer definitions
		if !emit(
			"tokenizer.json", "application/json",
			"tokenizer.model", "application/octet-stream",
		) {
			return
		}

		// top-level, then nested, configuration files
		if !emit("*.json", "") {
			return
		}
		emit("**/*.json", "")
	}
}
// get returns m[key] asserted to type T, or T's zero value when the key
// is absent or holds a value of a different dynamic type.
func get[T any](m map[string]any, key string) T {
	var zero T
	v, ok := m[key].(T)
	if !ok {
		return zero
	}
	return v
}
// deprecatedParameters are Modelfile parameter names that are accepted
// for backward compatibility but no longer have any effect; they are
// skipped with a warning instead of being sent to the server.
var deprecatedParameters = []string{
	"penalize_newline",
	"low_vram",
	"f16_kv",
	"logits_all",
	"vocab_only",
	"use_mlock",
	"mirostat",
	"mirostat_tau",
	"mirostat_eta",
}
// createRequest builds an *api.CreateRequest from a parsed Modelfile.
// MODEL and ADAPTER paths are resolved relative to dir and walked for
// model component files; each file is sha256-digested concurrently.
// Other commands map onto the corresponding request fields.
func createRequest(modelfile *parser.Modelfile, dir string) (*api.CreateRequest, error) {
	m := make(map[string]any)
	parameters := make(map[string]any)

	var files, adapters []api.File

	var g errgroup.Group
	g.SetLimit(runtime.GOMAXPROCS(0))

	// mu guards files and adapters, which the digest workers append to.
	// It must be shared across all MODEL/ADAPTER commands: a per-command
	// mutex would let workers spawned by different commands race on the
	// same slices.
	var mu sync.Mutex

	for _, cmd := range modelfile.Commands {
		switch cmd.Name {
		case "model", "adapter":
			path, err := expandPath(cmd.Args, dir)
			if err != nil {
				return nil, err
			}

			// NOTE(review): a non-directory path yields (nil, nil), so
			// callers must tolerate a nil request. Unlike the previous
			// implementation, a missing path is an error here instead of
			// being treated as a remote FROM reference — confirm intended.
			if stat, err := os.Stat(path); err != nil {
				return nil, err
			} else if !stat.IsDir() {
				return nil, nil
			}

			for file := range filesSeq(os.DirFS(path)) {
				g.Go(func() error {
					f, err := os.Open(filepath.Join(path, file))
					if err != nil {
						return err
					}
					defer f.Close()

					sha256sum := sha256.New()
					if _, err := io.Copy(sha256sum, f); err != nil {
						return err
					}

					file := api.File{
						Name:   file,
						Path:   filepath.Join(path, file),
						Digest: "sha256:" + hex.EncodeToString(sha256sum.Sum(nil)),
					}

					mu.Lock()
					defer mu.Unlock()
					switch cmd.Name {
					case "model":
						files = append(files, file)
					case "adapter":
						adapters = append(adapters, file)
					}
					return nil
				})
			}
		case "template", "system", "renderer", "parser":
			m[cmd.Name] = cmd.Args
		case "license":
			m[cmd.Name] = append(get[[]string](m, cmd.Name), cmd.Args)
		case "message":
			role, msg, found := strings.Cut(cmd.Args, ": ")
			if !found {
				return nil, fmt.Errorf("invalid message command: %s", cmd.Args)
			}

			m[cmd.Name] = append(get[[]api.Message](m, cmd.Name), api.Message{
				Role:    role,
				Content: msg,
			})
		default:
			if slices.Contains(deprecatedParameters, cmd.Name) {
				slog.Warn("parameter is deprecated", "name", cmd.Name)
				break
			}

			// FormatParameters coerces the string arg to the option's
			// declared type; repeatable (slice-valued) options accumulate
			ps, err := api.FormatParameters(map[string][]string{cmd.Name: {cmd.Args}})
			if err != nil {
				return nil, err
			}

			for k, v := range ps {
				if ks, ok := parameters[k].([]string); ok {
					parameters[k] = append(ks, v.([]string)...)
				} else if vs, ok := v.([]string); ok {
					parameters[k] = vs
				} else {
					parameters[k] = v
				}
			}
		}
	}

	if err := g.Wait(); err != nil {
		return nil, err
	}

	return &api.CreateRequest{
		Files:      files,
		Adapters:   adapters,
		Parameters: parameters,
		Template:   get[string](m, "template"),
		System:     get[string](m, "system"),
		License:    get[[]string](m, "license"),
		Messages:   get[[]api.Message](m, "message"),
		Renderer:   get[string](m, "renderer"),
		Parser:     get[string](m, "parser"),
	}, nil
}
// CreateHandler implements `ollama create`: it parses the Modelfile (or
// an implicit "FROM ." when none is given), uploads the referenced model
// and adapter files as blobs, then asks the server to create the model,
// streaming progress to stderr.
func CreateHandler(cmd *cobra.Command, args []string) error {
	p := progress.NewProgress(os.Stderr)
	defer p.Stop()

	var reader io.Reader

	filename, err := getModelfileName(cmd)
	if os.IsNotExist(err) {
		if filename == "" {
			// no -f flag and no default Modelfile: build from the cwd
			reader = strings.NewReader("FROM .\n")
		} else {
			return errModelfileNotFound
		}
	} else if err != nil {
		return err
	} else {
		f, err := os.Open(filename)
		if err != nil {
			return err
		}
		reader = f
		defer f.Close()
	}

	modelfile, err := parser.ParseFile(reader)
	if err != nil {
		return err
	}

	status := "gathering model components"
	spinner := progress.NewSpinner(status)
	p.Add(status, spinner)

	// resolves MODEL/ADAPTER paths and digests the local files
	req, err := createRequest(modelfile, filepath.Dir(filename))
	if err != nil {
		return err
	}
	spinner.Stop()

	req.Model = args[0]
	quantize, _ := cmd.Flags().GetString("quantize")
	if quantize != "" {
		req.Quantize = quantize
	}

	client, err := api.ClientFromEnvironment()
	if err != nil {
		return err
	}

	var g errgroup.Group
	g.SetLimit(runtime.GOMAXPROCS(0))
	for _, file := range req.Files {
		g.Go(func() error {
			_, err := createBlob(cmd, client, file.Path, file.Digest, p)
			return err
		})
	}

	// adapter files reference local blobs too and must be uploaded
	// before the create request, just like model files
	for _, file := range req.Adapters {
		g.Go(func() error {
			_, err := createBlob(cmd, client, file.Path, file.Digest, p)
			return err
		})
	}

	if err := g.Wait(); err != nil {
		return err
	}

	bars := make(map[string]*progress.Bar)
	fn := func(resp api.ProgressResponse) error {
		if resp.Digest != "" {
			// per-layer progress: one bar keyed by digest
			bar, ok := bars[resp.Digest]
			if !ok {
				msg := resp.Status
				if msg == "" {
					msg = fmt.Sprintf("pulling %s...", resp.Digest[7:19])
				}
				bar = progress.NewBar(msg, resp.Total, resp.Completed)
				bars[resp.Digest] = bar
				p.Add(resp.Digest, bar)
			}
			bar.Set(resp.Completed)
		} else if status != resp.Status {
			// status-only update: replace the spinner with a fresh one
			spinner.Stop()

			status = resp.Status
			spinner = progress.NewSpinner(status)
			p.Add(status, spinner)
		}
		return nil
	}

	if err := client.Create(cmd.Context(), req, fn); err != nil {
		if strings.Contains(err.Error(), "path or Modelfile are required") {
			return fmt.Errorf("the ollama server must be updated to use `ollama create` with this client")
		}
		return err
	}
	return nil
}
// createBlob uploads the file at path to the server as a blob with the
// given (pre-computed) digest, showing a percentage spinner while the
// copy is in flight. It returns the digest on success.
func createBlob(cmd *cobra.Command, client *api.Client, path string, digest string, p *progress.Progress) (string, error) {
	// resolve symlinks so the real file contents are uploaded
	realPath, err := filepath.EvalSymlinks(path)
	if err != nil {
		return "", err
	}

	bin, err := os.Open(realPath)
	if err != nil {
		return "", err
	}
	defer bin.Close()

	// stat for the total size so progress can be shown as a percentage
	fileInfo, err := bin.Stat()
	if err != nil {
		return "", err
	}
	fileSize := fileInfo.Size()

	var pw progressWriter
	status := fmt.Sprintf("copying file %s 0%%", digest)
	spinner := progress.NewSpinner(status)
	p.Add(status, spinner)
	defer spinner.Stop()

	done := make(chan struct{})
	defer close(done)

	go func() {
		ticker := time.NewTicker(60 * time.Millisecond)
		defer ticker.Stop()
		for {
			select {
			case <-ticker.C:
				// guard against division by zero for empty files
				if fileSize > 0 {
					spinner.SetMessage(fmt.Sprintf("copying file %s %d%%", digest, int(100*pw.n.Load()/fileSize)))
				}
			case <-done:
				spinner.SetMessage(fmt.Sprintf("copying file %s 100%%", digest))
				return
			}
		}
	}()

	if err := client.CreateBlob(cmd.Context(), digest, io.TeeReader(bin, &pw)); err != nil {
		return "", err
	}
	return digest, nil
}

View File

@ -316,7 +316,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
continue
}
params := args[3:]
fp, err := api.FormatParams(map[string][]string{args[2]: params})
fp, err := api.FormatParameters(map[string][]string{args[2]: params})
if err != nil {
fmt.Printf("Couldn't set parameter: %q\n", err)
continue

View File

@ -3,25 +3,14 @@ package parser
import (
"bufio"
"bytes"
"crypto/sha256"
"errors"
"fmt"
"io"
"net/http"
"os"
"os/user"
"path/filepath"
"runtime"
"slices"
"strconv"
"strings"
"sync"
"golang.org/x/sync/errgroup"
"golang.org/x/text/encoding/unicode"
"golang.org/x/text/transform"
"github.com/ollama/ollama/api"
)
var ErrModelNotFound = errors.New("no Modelfile or safetensors files found")
@ -39,281 +28,6 @@ func (f Modelfile) String() string {
return sb.String()
}
var deprecatedParameters = []string{
"penalize_newline",
"low_vram",
"f16_kv",
"logits_all",
"vocab_only",
"use_mlock",
"mirostat",
"mirostat_tau",
"mirostat_eta",
}
// CreateRequest creates a new *api.CreateRequest from an existing Modelfile
func (f Modelfile) CreateRequest(relativeDir string) (*api.CreateRequest, error) {
req := &api.CreateRequest{}
var messages []api.Message
var licenses []string
params := make(map[string]any)
for _, c := range f.Commands {
switch c.Name {
case "model":
path, err := expandPath(c.Args, relativeDir)
if err != nil {
return nil, err
}
digestMap, err := fileDigestMap(path)
if errors.Is(err, os.ErrNotExist) {
req.From = c.Args
continue
} else if err != nil {
return nil, err
}
if req.Files == nil {
req.Files = digestMap
} else {
for k, v := range digestMap {
req.Files[k] = v
}
}
case "adapter":
path, err := expandPath(c.Args, relativeDir)
if err != nil {
return nil, err
}
digestMap, err := fileDigestMap(path)
if err != nil {
return nil, err
}
req.Adapters = digestMap
case "template":
req.Template = c.Args
case "system":
req.System = c.Args
case "license":
licenses = append(licenses, c.Args)
case "renderer":
req.Renderer = c.Args
case "parser":
req.Parser = c.Args
case "message":
role, msg, _ := strings.Cut(c.Args, ": ")
messages = append(messages, api.Message{Role: role, Content: msg})
default:
if slices.Contains(deprecatedParameters, c.Name) {
fmt.Printf("warning: parameter %s is deprecated\n", c.Name)
break
}
ps, err := api.FormatParams(map[string][]string{c.Name: {c.Args}})
if err != nil {
return nil, err
}
for k, v := range ps {
if ks, ok := params[k].([]string); ok {
params[k] = append(ks, v.([]string)...)
} else if vs, ok := v.([]string); ok {
params[k] = vs
} else {
params[k] = v
}
}
}
}
if len(params) > 0 {
req.Parameters = params
}
if len(messages) > 0 {
req.Messages = messages
}
if len(licenses) > 0 {
req.License = licenses
}
return req, nil
}
func fileDigestMap(path string) (map[string]string, error) {
fl := make(map[string]string)
fi, err := os.Stat(path)
if err != nil {
return nil, err
}
var files []string
if fi.IsDir() {
fs, err := filesForModel(path)
if err != nil {
return nil, err
}
for _, f := range fs {
f, err := filepath.EvalSymlinks(f)
if err != nil {
return nil, err
}
rel, err := filepath.Rel(path, f)
if err != nil {
return nil, err
}
if !filepath.IsLocal(rel) {
return nil, fmt.Errorf("insecure path: %s", rel)
}
files = append(files, f)
}
} else {
files = []string{path}
}
var mu sync.Mutex
var g errgroup.Group
g.SetLimit(max(runtime.GOMAXPROCS(0)-1, 1))
for _, f := range files {
g.Go(func() error {
digest, err := digestForFile(f)
if err != nil {
return err
}
mu.Lock()
defer mu.Unlock()
fl[f] = digest
return nil
})
}
if err := g.Wait(); err != nil {
return nil, err
}
return fl, nil
}
func digestForFile(filename string) (string, error) {
filepath, err := filepath.EvalSymlinks(filename)
if err != nil {
return "", err
}
bin, err := os.Open(filepath)
if err != nil {
return "", err
}
defer bin.Close()
hash := sha256.New()
if _, err := io.Copy(hash, bin); err != nil {
return "", err
}
return fmt.Sprintf("sha256:%x", hash.Sum(nil)), nil
}
func filesForModel(path string) ([]string, error) {
detectContentType := func(path string) (string, error) {
f, err := os.Open(path)
if err != nil {
return "", err
}
defer f.Close()
var b bytes.Buffer
b.Grow(512)
if _, err := io.CopyN(&b, f, 512); err != nil && !errors.Is(err, io.EOF) {
return "", err
}
contentType, _, _ := strings.Cut(http.DetectContentType(b.Bytes()), ";")
return contentType, nil
}
glob := func(pattern, contentType string) ([]string, error) {
matches, err := filepath.Glob(pattern)
if err != nil {
return nil, err
}
for _, match := range matches {
if ct, err := detectContentType(match); err != nil {
return nil, err
} else if len(contentType) > 0 && ct != contentType {
return nil, fmt.Errorf("invalid content type: expected %s for %s", ct, match)
}
}
return matches, nil
}
var files []string
// some safetensors files do not properly match "application/octet-stream", so skip checking their contentType
if st, _ := glob(filepath.Join(path, "*.safetensors"), ""); len(st) > 0 {
// safetensors files might be unresolved git lfs references; skip if they are
// covers model-x-of-y.safetensors, model.fp32-x-of-y.safetensors, model.safetensors
files = append(files, st...)
} else if pt, _ := glob(filepath.Join(path, "pytorch_model*.bin"), "application/zip"); len(pt) > 0 {
// pytorch files might also be unresolved git lfs references; skip if they are
// covers pytorch_model-x-of-y.bin, pytorch_model.fp32-x-of-y.bin, pytorch_model.bin
files = append(files, pt...)
} else if pt, _ := glob(filepath.Join(path, "consolidated*.pth"), "application/zip"); len(pt) > 0 {
// pytorch files might also be unresolved git lfs references; skip if they are
// covers consolidated.x.pth, consolidated.pth
files = append(files, pt...)
} else if gg, _ := glob(filepath.Join(path, "*.gguf"), "application/octet-stream"); len(gg) > 0 {
// covers gguf files ending in .gguf
files = append(files, gg...)
} else if gg, _ := glob(filepath.Join(path, "*.bin"), "application/octet-stream"); len(gg) > 0 {
// covers gguf files ending in .bin
files = append(files, gg...)
} else {
return nil, ErrModelNotFound
}
// add configuration files, json files are detected as text/plain
js, err := glob(filepath.Join(path, "*.json"), "text/plain")
if err != nil {
return nil, err
}
files = append(files, js...)
// bert models require a nested config.json
// TODO(mxyng): merge this with the glob above
js, err = glob(filepath.Join(path, "**/*.json"), "text/plain")
if err != nil {
return nil, err
}
files = append(files, js...)
// only include tokenizer.model is tokenizer.json is not present
if !slices.ContainsFunc(files, func(s string) bool {
return slices.Contains(strings.Split(s, string(os.PathSeparator)), "tokenizer.json")
}) {
if tks, _ := glob(filepath.Join(path, "tokenizer.model"), "application/octet-stream"); len(tks) > 0 {
// add tokenizer.model if it exists, tokenizer.json is automatically picked up by the previous glob
// tokenizer.model might be a unresolved git lfs reference; error if it is
files = append(files, tks...)
} else if tks, _ := glob(filepath.Join(path, "**/tokenizer.model"), "text/plain"); len(tks) > 0 {
// some times tokenizer.model is in a subdirectory (e.g. meta-llama/Meta-Llama-3-8B)
files = append(files, tks...)
}
}
return files, nil
}
type Command struct {
Name string
Args string
@ -340,7 +54,7 @@ type state int
const (
stateNil state = iota
stateName
stateKey
stateValue
stateParameter
stateMessage
@ -368,7 +82,7 @@ func (e *ParserError) Error() string {
func ParseFile(r io.Reader) (*Modelfile, error) {
var cmd Command
var curr state
var currLine int = 1
currLine := 1
var b bytes.Buffer
var role string
@ -402,7 +116,7 @@ func ParseFile(r io.Reader) (*Modelfile, error) {
// process the state transition, some transitions need to be intercepted and redirected
if next != curr {
switch curr {
case stateName:
case stateKey:
if !isValidCommand(b.String()) {
return nil, &ParserError{
LineNumber: currLine,
@ -505,12 +219,12 @@ func parseRuneForState(r rune, cs state) (state, rune, error) {
case isSpace(r), isNewline(r):
return stateNil, 0, nil
default:
return stateName, r, nil
return stateKey, r, nil
}
case stateName:
case stateKey:
switch {
case isAlpha(r):
return stateName, r, nil
return stateKey, r, nil
case isSpace(r):
return stateValue, 0, nil
default:
@ -616,43 +330,3 @@ func isValidCommand(cmd string) bool {
return false
}
}
func expandPath(path, dir string) (string, error) {
if filepath.IsAbs(path) {
return path, nil
}
path, found := strings.CutPrefix(path, "~")
if !found {
// make path relative to dir
if !filepath.IsAbs(dir) {
// if dir is relative, make it absolute relative to cwd
cwd, err := os.Getwd()
if err != nil {
return "", err
}
dir = filepath.Join(cwd, dir)
}
path = filepath.Join(dir, path)
} else if filepath.IsLocal(path) {
// ~<user>/...
// make path relative to specified user's home
split := strings.SplitN(path, "/", 2)
u, err := user.Lookup(split[0])
if err != nil {
return "", err
}
split[0] = u.HomeDir
path = filepath.Join(split...)
} else {
// ~ or ~/...
// make path relative to current user's home
home, err := os.UserHomeDir()
if err != nil {
return "", err
}
path = filepath.Join(home, path)
}
return filepath.Clean(path), nil
}

View File

@ -62,8 +62,8 @@ func (s *Server) CreateHandler(c *gin.Context) {
config.Renderer = r.Renderer
config.Parser = r.Parser
for v := range r.Files {
if !fs.ValidPath(v) {
for _, v := range r.Files {
if !fs.ValidPath(v.Name) {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": errFilePath.Error()})
return
}
@ -276,7 +276,7 @@ func remoteURL(raw string) (string, error) {
return u.String(), nil
}
func convertModelFromFiles(files map[string]string, baseLayers []*layerGGML, isAdapter bool, fn func(resp api.ProgressResponse)) ([]*layerGGML, error) {
func convertModelFromFiles(files api.Files, baseLayers []*layerGGML, isAdapter bool, fn func(resp api.ProgressResponse)) ([]*layerGGML, error) {
switch detectModelTypeFromFiles(files) {
case "safetensors":
layers, err := convertFromSafetensors(files, baseLayers, isAdapter, fn)
@ -295,7 +295,7 @@ func convertModelFromFiles(files map[string]string, baseLayers []*layerGGML, isA
var digest string
var allLayers []*layerGGML
for _, v := range files {
digest = v
digest = v.Digest
layers, err := ggufLayers(digest, fn)
if err != nil {
return nil, err
@ -308,15 +308,15 @@ func convertModelFromFiles(files map[string]string, baseLayers []*layerGGML, isA
}
}
func detectModelTypeFromFiles(files map[string]string) string {
for fn := range files {
if strings.HasSuffix(fn, ".safetensors") {
func detectModelTypeFromFiles(files api.Files) string {
for _, fn := range files {
if strings.HasSuffix(fn.Name, ".safetensors") {
return "safetensors"
} else if strings.HasSuffix(fn, ".gguf") {
} else if strings.HasSuffix(fn.Name, ".gguf") {
return "gguf"
} else {
// try to see if we can find a gguf file even without the file extension
blobPath, err := GetBlobsPath(files[fn])
blobPath, err := GetBlobsPath(fn.Digest)
if err != nil {
slog.Error("error getting blobs path", "file", fn)
return ""
@ -346,7 +346,7 @@ func detectModelTypeFromFiles(files map[string]string) string {
return ""
}
func convertFromSafetensors(files map[string]string, baseLayers []*layerGGML, isAdapter bool, fn func(resp api.ProgressResponse)) ([]*layerGGML, error) {
func convertFromSafetensors(files api.Files, baseLayers []*layerGGML, isAdapter bool, fn func(resp api.ProgressResponse)) ([]*layerGGML, error) {
tmpDir, err := os.MkdirTemp(envconfig.Models(), "ollama-safetensors")
if err != nil {
return nil, err
@ -359,20 +359,20 @@ func convertFromSafetensors(files map[string]string, baseLayers []*layerGGML, is
}
defer root.Close()
for fp, digest := range files {
if !fs.ValidPath(fp) {
return nil, fmt.Errorf("%w: %s", errFilePath, fp)
for _, fn := range files {
if !fs.ValidPath(fn.Name) {
return nil, fmt.Errorf("%w: %s", errFilePath, fn)
}
if _, err := root.Stat(fp); err != nil && !errors.Is(err, fs.ErrNotExist) {
if _, err := root.Stat(fn.Name); err != nil && !errors.Is(err, fs.ErrNotExist) {
// Path is likely outside the root
return nil, fmt.Errorf("%w: %s: %s", errFilePath, err, fp)
return nil, fmt.Errorf("%w: %s: %s", errFilePath, err, fn)
}
blobPath, err := GetBlobsPath(digest)
blobPath, err := GetBlobsPath(fn.Digest)
if err != nil {
return nil, err
}
if err := createLink(blobPath, filepath.Join(tmpDir, fp)); err != nil {
if err := createLink(blobPath, filepath.Join(tmpDir, fn.Name)); err != nil {
return nil, err
}
}

View File

@ -1,38 +0,0 @@
package syncmap
import (
"maps"
"sync"
)
// SyncMap is a minimal generic map guarded by a read-write mutex.
// The zero value is not usable; construct instances with NewSyncMap.
type SyncMap[K comparable, V any] struct {
	mu sync.RWMutex
	m  map[K]V
}

// NewSyncMap returns an empty, ready-to-use SyncMap.
func NewSyncMap[K comparable, V any]() *SyncMap[K, V] {
	return &SyncMap[K, V]{m: map[K]V{}}
}

// Load returns the value stored under key and whether it was present.
func (s *SyncMap[K, V]) Load(key K) (V, bool) {
	s.mu.RLock()
	v, ok := s.m[key]
	s.mu.RUnlock()
	return v, ok
}

// Store sets the value for key, replacing any previous value.
func (s *SyncMap[K, V]) Store(key K, value V) {
	s.mu.Lock()
	s.m[key] = value
	s.mu.Unlock()
}

// Items returns a shallow snapshot of the map's current contents;
// mutating the returned map does not affect the SyncMap.
func (s *SyncMap[K, V]) Items() map[K]V {
	s.mu.RLock()
	snapshot := maps.Clone(s.m)
	s.mu.RUnlock()
	return snapshot
}