diff --git a/cmd/cmd.go b/cmd/cmd.go index 1d1d116ba..8b7cee95c 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -64,54 +64,37 @@ func ensureThinkingSupport(ctx context.Context, client *api.Client, name string) fmt.Fprintf(os.Stderr, "warning: model %q does not support thinking output\n", name) } -var errModelfileNotFound = errors.New("specified Modelfile wasn't found") - -func getModelfileName(cmd *cobra.Command) (string, error) { - filename, _ := cmd.Flags().GetString("file") - - if filename == "" { - filename = "Modelfile" - } - - absName, err := filepath.Abs(filename) - if err != nil { - return "", err - } - - _, err = os.Stat(absName) - if err != nil { - return "", err - } - - return absName, nil -} - func CreateHandler(cmd *cobra.Command, args []string) error { p := progress.NewProgress(os.Stderr) defer p.Stop() - var reader io.Reader - - filename, err := getModelfileName(cmd) - if os.IsNotExist(err) { - if filename == "" { - reader = strings.NewReader("FROM .\n") - } else { - return errModelfileNotFound - } - } else if err != nil { - return err - } else { - f, err := os.Open(filename) - if err != nil { - return err - } - - reader = f - defer f.Close() + filename, err := cmd.Flags().GetString("file") + if err != nil { + return fmt.Errorf("error retrieving file flag: %w", err) } - modelfile, err := parser.ParseFile(reader) + var r, fallback io.Reader + switch filename { + case "-": + r = os.Stdin + case "": + filename = "Modelfile" + fallback = strings.NewReader("FROM .") + fallthrough + default: + r, err = os.Open(filename) + if errors.Is(err, os.ErrNotExist) && fallback != nil { + r = fallback + } else if errors.Is(err, os.ErrNotExist) { + return fmt.Errorf("%w: Modelfile %q does not exist, please create it or use --file to specify a different file", err, filename) + } else if err != nil { + return err + } else { + defer r.(*os.File).Close() + } + } + + modelfile, err := parser.ParseFile(r) if err != nil { return err } @@ -127,10 +110,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error { spinner.Stop() req.Model = args[0] - quantize, _ := cmd.Flags().GetString("quantize") - if quantize != "" { - req.Quantize = quantize - } + req.Quantize, _ = cmd.Flags().GetString("quantize") client, err := api.ClientFromEnvironment() if err != nil { diff --git a/cmd/cmd_test.go b/cmd/cmd_test.go index cf5fe7caa..096e65b9a 100644 --- a/cmd/cmd_test.go +++ b/cmd/cmd_test.go @@ -3,21 +3,31 @@ package cmd import ( "bytes" "encoding/json" + "errors" "io" "net/http" "net/http/httptest" "os" + "path/filepath" "strings" "testing" "time" "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" "github.com/spf13/cobra" "github.com/ollama/ollama/api" "github.com/ollama/ollama/types/model" ) +func mockServer(t *testing.T, h http.HandlerFunc) { + t.Helper() + s := httptest.NewServer(h) + t.Cleanup(s.Close) + t.Setenv("OLLAMA_HOST", s.URL) +} + func TestShowInfo(t *testing.T) { t.Run("bare details", func(t *testing.T) { var b bytes.Buffer @@ -351,101 +361,6 @@ func TestDeleteHandler(t *testing.T) { } } -func TestGetModelfileName(t *testing.T) { - tests := []struct { - name string - modelfileName string - fileExists bool - expectedName string - expectedErr error - }{ - { - name: "no modelfile specified, no modelfile exists", - modelfileName: "", - fileExists: false, - expectedName: "", - expectedErr: os.ErrNotExist, - }, - { - name: "no modelfile specified, modelfile exists", - modelfileName: "", - fileExists: true, - expectedName: "Modelfile", - expectedErr: nil, - }, - { - name: "modelfile specified, no modelfile exists", - modelfileName: "crazyfile", - fileExists: false, - expectedName: "", - expectedErr: os.ErrNotExist, - }, - { - name: "modelfile specified, modelfile exists", - modelfileName: "anotherfile", - fileExists: true, - expectedName: "anotherfile", - expectedErr: nil, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - cmd := &cobra.Command{ - Use: "fakecmd", - } - cmd.Flags().String("file", "", "path to modelfile") - - var expectedFilename string - - if tt.fileExists { - var fn string - if tt.modelfileName != "" { - fn = tt.modelfileName - } else { - fn = "Modelfile" - } - - tempFile, err := os.CreateTemp(t.TempDir(), fn) - if err != nil { - t.Fatalf("temp modelfile creation failed: %v", err) - } - defer tempFile.Close() - - expectedFilename = tempFile.Name() - err = cmd.Flags().Set("file", expectedFilename) - if err != nil { - t.Fatalf("couldn't set file flag: %v", err) - } - } else { - expectedFilename = tt.expectedName - if tt.modelfileName != "" { - err := cmd.Flags().Set("file", tt.modelfileName) - if err != nil { - t.Fatalf("couldn't set file flag: %v", err) - } - } - } - - actualFilename, actualErr := getModelfileName(cmd) - - if actualFilename != expectedFilename { - t.Errorf("expected filename: '%s' actual filename: '%s'", expectedFilename, actualFilename) - } - - if tt.expectedErr != os.ErrNotExist { - if actualErr != tt.expectedErr { - t.Errorf("expected err: %v actual err: %v", tt.expectedErr, actualErr) - } - } else { - if !os.IsNotExist(actualErr) { - t.Errorf("expected err: %v actual err: %v", tt.expectedErr, actualErr) - } - } - }) - } -} - func TestPushHandler(t *testing.T) { tests := []struct { name string @@ -661,128 +576,160 @@ func TestListHandler(t *testing.T) { } func TestCreateHandler(t *testing.T) { - tests := []struct { - name string - modelName string - modelFile string - serverResponse map[string]func(w http.ResponseWriter, r *http.Request) - expectedError string - expectedOutput string + cases := []struct { + name string + filename func(*testing.T) string + + wantRequest api.CreateRequest + wantErr error }{ { - name: "successful create", - modelName: "test-model", - modelFile: "FROM foo", - serverResponse: map[string]func(w http.ResponseWriter, r *http.Request){ - "/api/create": func(w http.ResponseWriter, r *http.Request) { - if r.Method != http.MethodPost { - t.Errorf("expected POST request, got %s", r.Method) - } + name: "not exist", + filename: func(*testing.T) string { return "not_exist" }, + wantErr: os.ErrNotExist, + }, + { + name: "stdin", + filename: func(t *testing.T) string { + r, w, err := os.Pipe() + if err != nil { + t.Fatal(err) + } - req := api.CreateRequest{} + if _, err := w.WriteString("FROM stdin"); err != nil { + t.Fatal(err) + } + + if err := w.Close(); err != nil { + t.Fatal(err) + } + + stdin := os.Stdin + t.Cleanup(func() { os.Stdin = stdin }) + os.Stdin = r + return "-" + }, + wantRequest: api.CreateRequest{ + Model: "stdin", + From: "stdin", + }, + }, + { + name: "default", + filename: func(t *testing.T) string { + t.Chdir(t.TempDir()) + + f, err := os.Create("Modelfile") + if err != nil { + t.Fatal(err) + } + defer f.Close() + + if _, err := f.WriteString("FROM default"); err != nil { + t.Fatal(err) + } + + return "" + }, + wantRequest: api.CreateRequest{ + Model: "default", + From: "default", + }, + }, + { + name: "file flag", + filename: func(t *testing.T) string { + f, err := os.CreateTemp(t.TempDir(), filepath.Base(t.Name())) + if err != nil { + t.Fatal(err) + } + defer f.Close() + + if _, err := f.WriteString("FROM file:flag"); err != nil { + t.Fatal(err) + } + + return f.Name() + }, + wantRequest: api.CreateRequest{ + Model: "file_flag", + From: "file:flag", + }, + }, + { + name: "default safetensors", + filename: func(t *testing.T) string { + t.Chdir(t.TempDir()) + f, err := os.Create("model.safetensors") + if err != nil { + t.Fatal(err) + } + defer f.Close() + + if err := f.Truncate(1); err != nil { + t.Fatal(err) + } + + return "" + }, + wantRequest: api.CreateRequest{ + Model: "default_safetensors", + }, + }, + { + name: "insecure path", + filename: func(t *testing.T) string { + t.Chdir(t.TempDir()) + if err := os.Symlink("../../../../../../nope", "model.safetensors"); err != nil { + t.Fatal(err) + } + + return "" + }, + }, + } + + var cmd cobra.Command + cmd.SetContext(t.Context()) + cmd.Flags().String("file", "", "") + cmd.Flags().String("quantize", "", "") + for _, tt := range cases { + t.Run(tt.name, func(t *testing.T) { + mockServer(t, func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + + if r.URL.Path == "/api/create" { + var req api.CreateRequest if err := json.NewDecoder(r.Body).Decode(&req); err != nil { http.Error(w, err.Error(), http.StatusBadRequest) return } - if req.Model != "test-model" { - t.Errorf("expected model name 'test-model', got %s", req.Name) + if diff := cmp.Diff(tt.wantRequest, req, cmpopts.IgnoreFields(api.CreateRequest{}, "Files")); diff != "" { + t.Errorf("Create request mismatch (-want +got):\n%s", diff) } - - if req.From != "foo" { - t.Errorf("expected from 'foo', got %s", req.From) - } - - responses := []api.ProgressResponse{ - {Status: "using existing layer sha256:56bb8bd477a519ffa694fc449c2413c6f0e1d3b1c88fa7e3c9d88d3ae49d4dcb"}, - {Status: "writing manifest"}, - {Status: "success"}, - } - - for _, resp := range responses { - if err := json.NewEncoder(w).Encode(resp); err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - w.(http.Flusher).Flush() - } - }, - }, - expectedOutput: "", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - handler, ok := tt.serverResponse[r.URL.Path] - if !ok { - t.Errorf("unexpected request to %s", r.URL.Path) + } else if strings.HasPrefix(r.URL.Path, "/api/blobs/") { + w.WriteHeader(http.StatusOK) + } else { http.Error(w, "not found", http.StatusNotFound) - return } - handler(w, r) - })) - t.Setenv("OLLAMA_HOST", mockServer.URL) - t.Cleanup(mockServer.Close) - tempFile, err := os.CreateTemp(t.TempDir(), "modelfile") - if err != nil { - t.Fatal(err) - } - defer os.Remove(tempFile.Name()) + }) - if _, err := tempFile.WriteString(tt.modelFile); err != nil { - t.Fatal(err) + var filename string + if tt.filename != nil { + filename = tt.filename(t) } - if err := tempFile.Close(); err != nil { + + if err := cmd.Flags().Set("file", filename); err != nil { t.Fatal(err) } - cmd := &cobra.Command{} - cmd.Flags().String("file", "", "") - if err := cmd.Flags().Set("file", tempFile.Name()); err != nil { + if err := CreateHandler(&cmd, []string{filepath.Base(t.Name())}); !errors.Is(err, tt.wantErr) { t.Fatal(err) } - - cmd.Flags().Bool("insecure", false, "") - cmd.SetContext(t.Context()) - - // Redirect stderr to capture progress output - oldStderr := os.Stderr - r, w, _ := os.Pipe() - os.Stderr = w - - // Capture stdout for the "Model pushed" message - oldStdout := os.Stdout - outR, outW, _ := os.Pipe() - os.Stdout = outW - - err = CreateHandler(cmd, []string{tt.modelName}) - - // Restore stderr - w.Close() - os.Stderr = oldStderr - // drain the pipe - if _, err := io.ReadAll(r); err != nil { - t.Fatal(err) - } - - // Restore stdout and get output - outW.Close() - os.Stdout = oldStdout - stdout, _ := io.ReadAll(outR) - - if tt.expectedError == "" { - if err != nil { - t.Errorf("expected no error, got %v", err) - } - - if tt.expectedOutput != "" { - if got := string(stdout); got != tt.expectedOutput { - t.Errorf("expected output %q, got %q", tt.expectedOutput, got) - } - } - } }) } }