accept `-` to create from stdin

This commit is contained in:
Michael Yang 2025-07-26 09:58:27 -07:00
parent ff89ba90bc
commit 19279d778d
2 changed files with 172 additions and 245 deletions

View File

@ -64,54 +64,37 @@ func ensureThinkingSupport(ctx context.Context, client *api.Client, name string)
fmt.Fprintf(os.Stderr, "warning: model %q does not support thinking output\n", name)
}
var errModelfileNotFound = errors.New("specified Modelfile wasn't found")
func getModelfileName(cmd *cobra.Command) (string, error) {
filename, _ := cmd.Flags().GetString("file")
if filename == "" {
filename = "Modelfile"
}
absName, err := filepath.Abs(filename)
if err != nil {
return "", err
}
_, err = os.Stat(absName)
if err != nil {
return "", err
}
return absName, nil
}
func CreateHandler(cmd *cobra.Command, args []string) error {
p := progress.NewProgress(os.Stderr)
defer p.Stop()
var reader io.Reader
filename, err := getModelfileName(cmd)
if os.IsNotExist(err) {
if filename == "" {
reader = strings.NewReader("FROM .\n")
} else {
return errModelfileNotFound
}
} else if err != nil {
return err
} else {
f, err := os.Open(filename)
if err != nil {
return err
}
reader = f
defer f.Close()
filename, err := cmd.Flags().GetString("file")
if err != nil {
return fmt.Errorf("error retrieving file flag: %w", err)
}
modelfile, err := parser.ParseFile(reader)
var r, fallback io.Reader
switch filename {
case "-":
r = os.Stdin
case "":
filename = "Modelfile"
fallback = strings.NewReader("FROM .")
fallthrough
default:
r, err = os.Open(filename)
if errors.Is(err, os.ErrNotExist) && fallback != nil {
r = fallback
} else if errors.Is(err, os.ErrNotExist) {
return fmt.Errorf("%w: Modelfile %q does not exist, please create it or use --file to specify a different file", err, filename)
} else if err != nil {
return err
} else {
defer r.(*os.File).Close()
}
}
modelfile, err := parser.ParseFile(r)
if err != nil {
return err
}
@ -127,10 +110,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
spinner.Stop()
req.Model = args[0]
quantize, _ := cmd.Flags().GetString("quantize")
if quantize != "" {
req.Quantize = quantize
}
req.Quantize, _ = cmd.Flags().GetString("quantize")
client, err := api.ClientFromEnvironment()
if err != nil {

View File

@ -3,21 +3,31 @@ package cmd
import (
"bytes"
"encoding/json"
"errors"
"io"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"strings"
"testing"
"time"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/spf13/cobra"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/types/model"
)
func mockServer(t *testing.T, h http.HandlerFunc) {
t.Helper()
s := httptest.NewServer(h)
t.Cleanup(s.Close)
t.Setenv("OLLAMA_HOST", s.URL)
}
func TestShowInfo(t *testing.T) {
t.Run("bare details", func(t *testing.T) {
var b bytes.Buffer
@ -351,101 +361,6 @@ func TestDeleteHandler(t *testing.T) {
}
}
func TestGetModelfileName(t *testing.T) {
tests := []struct {
name string
modelfileName string
fileExists bool
expectedName string
expectedErr error
}{
{
name: "no modelfile specified, no modelfile exists",
modelfileName: "",
fileExists: false,
expectedName: "",
expectedErr: os.ErrNotExist,
},
{
name: "no modelfile specified, modelfile exists",
modelfileName: "",
fileExists: true,
expectedName: "Modelfile",
expectedErr: nil,
},
{
name: "modelfile specified, no modelfile exists",
modelfileName: "crazyfile",
fileExists: false,
expectedName: "",
expectedErr: os.ErrNotExist,
},
{
name: "modelfile specified, modelfile exists",
modelfileName: "anotherfile",
fileExists: true,
expectedName: "anotherfile",
expectedErr: nil,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cmd := &cobra.Command{
Use: "fakecmd",
}
cmd.Flags().String("file", "", "path to modelfile")
var expectedFilename string
if tt.fileExists {
var fn string
if tt.modelfileName != "" {
fn = tt.modelfileName
} else {
fn = "Modelfile"
}
tempFile, err := os.CreateTemp(t.TempDir(), fn)
if err != nil {
t.Fatalf("temp modelfile creation failed: %v", err)
}
defer tempFile.Close()
expectedFilename = tempFile.Name()
err = cmd.Flags().Set("file", expectedFilename)
if err != nil {
t.Fatalf("couldn't set file flag: %v", err)
}
} else {
expectedFilename = tt.expectedName
if tt.modelfileName != "" {
err := cmd.Flags().Set("file", tt.modelfileName)
if err != nil {
t.Fatalf("couldn't set file flag: %v", err)
}
}
}
actualFilename, actualErr := getModelfileName(cmd)
if actualFilename != expectedFilename {
t.Errorf("expected filename: '%s' actual filename: '%s'", expectedFilename, actualFilename)
}
if tt.expectedErr != os.ErrNotExist {
if actualErr != tt.expectedErr {
t.Errorf("expected err: %v actual err: %v", tt.expectedErr, actualErr)
}
} else {
if !os.IsNotExist(actualErr) {
t.Errorf("expected err: %v actual err: %v", tt.expectedErr, actualErr)
}
}
})
}
}
func TestPushHandler(t *testing.T) {
tests := []struct {
name string
@ -661,128 +576,160 @@ func TestListHandler(t *testing.T) {
}
func TestCreateHandler(t *testing.T) {
tests := []struct {
name string
modelName string
modelFile string
serverResponse map[string]func(w http.ResponseWriter, r *http.Request)
expectedError string
expectedOutput string
cases := []struct {
name string
filename func(*testing.T) string
wantRequest api.CreateRequest
wantErr error
}{
{
name: "successful create",
modelName: "test-model",
modelFile: "FROM foo",
serverResponse: map[string]func(w http.ResponseWriter, r *http.Request){
"/api/create": func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
t.Errorf("expected POST request, got %s", r.Method)
}
name: "not exist",
filename: func(*testing.T) string { return "not_exist" },
wantErr: os.ErrNotExist,
},
{
name: "stdin",
filename: func(t *testing.T) string {
r, w, err := os.Pipe()
if err != nil {
t.Fatal(err)
}
req := api.CreateRequest{}
if _, err := w.WriteString("FROM stdin"); err != nil {
t.Fatal(err)
}
if err := w.Close(); err != nil {
t.Fatal(err)
}
stdin := os.Stdin
t.Cleanup(func() { os.Stdin = stdin })
os.Stdin = r
return "-"
},
wantRequest: api.CreateRequest{
Model: "stdin",
From: "stdin",
},
},
{
name: "default",
filename: func(t *testing.T) string {
t.Chdir(t.TempDir())
f, err := os.Create("Modelfile")
if err != nil {
t.Fatal(err)
}
defer f.Close()
if _, err := f.WriteString("FROM default"); err != nil {
t.Fatal(err)
}
return ""
},
wantRequest: api.CreateRequest{
Model: "default",
From: "default",
},
},
{
name: "file flag",
filename: func(t *testing.T) string {
f, err := os.CreateTemp(t.TempDir(), filepath.Base(t.Name()))
if err != nil {
t.Fatal(err)
}
defer f.Close()
if _, err := f.WriteString("FROM file:flag"); err != nil {
t.Fatal(err)
}
return f.Name()
},
wantRequest: api.CreateRequest{
Model: "file_flag",
From: "file:flag",
},
},
{
name: "default safetensors",
filename: func(t *testing.T) string {
t.Chdir(t.TempDir())
f, err := os.Create("model.safetensors")
if err != nil {
t.Fatal(err)
}
defer f.Close()
if err := f.Truncate(1); err != nil {
t.Fatal(err)
}
return ""
},
wantRequest: api.CreateRequest{
Model: "default_safetensors",
},
},
{
name: "insecure path",
filename: func(t *testing.T) string {
t.Chdir(t.TempDir())
if err := os.Symlink("../../../../../../nope", "model.safetensors"); err != nil {
t.Fatal(err)
}
return ""
},
},
}
var cmd cobra.Command
cmd.SetContext(t.Context())
cmd.Flags().String("file", "", "")
cmd.Flags().String("quantize", "", "")
for _, tt := range cases {
t.Run(tt.name, func(t *testing.T) {
mockServer(t, func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
return
}
if r.URL.Path == "/api/create" {
var req api.CreateRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
if req.Model != "test-model" {
t.Errorf("expected model name 'test-model', got %s", req.Name)
if diff := cmp.Diff(tt.wantRequest, req, cmpopts.IgnoreFields(api.CreateRequest{}, "Files")); diff != "" {
t.Errorf("Create request mismatch (-want +got):\n%s", diff)
}
if req.From != "foo" {
t.Errorf("expected from 'foo', got %s", req.From)
}
responses := []api.ProgressResponse{
{Status: "using existing layer sha256:56bb8bd477a519ffa694fc449c2413c6f0e1d3b1c88fa7e3c9d88d3ae49d4dcb"},
{Status: "writing manifest"},
{Status: "success"},
}
for _, resp := range responses {
if err := json.NewEncoder(w).Encode(resp); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
w.(http.Flusher).Flush()
}
},
},
expectedOutput: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
handler, ok := tt.serverResponse[r.URL.Path]
if !ok {
t.Errorf("unexpected request to %s", r.URL.Path)
} else if strings.HasPrefix(r.URL.Path, "/api/blobs/") {
w.WriteHeader(http.StatusOK)
} else {
http.Error(w, "not found", http.StatusNotFound)
return
}
handler(w, r)
}))
t.Setenv("OLLAMA_HOST", mockServer.URL)
t.Cleanup(mockServer.Close)
tempFile, err := os.CreateTemp(t.TempDir(), "modelfile")
if err != nil {
t.Fatal(err)
}
defer os.Remove(tempFile.Name())
})
if _, err := tempFile.WriteString(tt.modelFile); err != nil {
t.Fatal(err)
var filename string
if tt.filename != nil {
filename = tt.filename(t)
}
if err := tempFile.Close(); err != nil {
if err := cmd.Flags().Set("file", filename); err != nil {
t.Fatal(err)
}
cmd := &cobra.Command{}
cmd.Flags().String("file", "", "")
if err := cmd.Flags().Set("file", tempFile.Name()); err != nil {
if err := CreateHandler(&cmd, []string{filepath.Base(t.Name())}); !errors.Is(err, tt.wantErr) {
t.Fatal(err)
}
cmd.Flags().Bool("insecure", false, "")
cmd.SetContext(t.Context())
// Redirect stderr to capture progress output
oldStderr := os.Stderr
r, w, _ := os.Pipe()
os.Stderr = w
// Capture stdout for the "Model pushed" message
oldStdout := os.Stdout
outR, outW, _ := os.Pipe()
os.Stdout = outW
err = CreateHandler(cmd, []string{tt.modelName})
// Restore stderr
w.Close()
os.Stderr = oldStderr
// drain the pipe
if _, err := io.ReadAll(r); err != nil {
t.Fatal(err)
}
// Restore stdout and get output
outW.Close()
os.Stdout = oldStdout
stdout, _ := io.ReadAll(outR)
if tt.expectedError == "" {
if err != nil {
t.Errorf("expected no error, got %v", err)
}
if tt.expectedOutput != "" {
if got := string(stdout); got != tt.expectedOutput {
t.Errorf("expected output %q, got %q", tt.expectedOutput, got)
}
}
}
})
}
}