accept `-` to create from stdin
This commit is contained in:
parent
ff89ba90bc
commit
19279d778d
72
cmd/cmd.go
72
cmd/cmd.go
|
|
@ -64,54 +64,37 @@ func ensureThinkingSupport(ctx context.Context, client *api.Client, name string)
|
|||
fmt.Fprintf(os.Stderr, "warning: model %q does not support thinking output\n", name)
|
||||
}
|
||||
|
||||
var errModelfileNotFound = errors.New("specified Modelfile wasn't found")
|
||||
|
||||
func getModelfileName(cmd *cobra.Command) (string, error) {
|
||||
filename, _ := cmd.Flags().GetString("file")
|
||||
|
||||
if filename == "" {
|
||||
filename = "Modelfile"
|
||||
}
|
||||
|
||||
absName, err := filepath.Abs(filename)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
_, err = os.Stat(absName)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return absName, nil
|
||||
}
|
||||
|
||||
func CreateHandler(cmd *cobra.Command, args []string) error {
|
||||
p := progress.NewProgress(os.Stderr)
|
||||
defer p.Stop()
|
||||
|
||||
var reader io.Reader
|
||||
|
||||
filename, err := getModelfileName(cmd)
|
||||
if os.IsNotExist(err) {
|
||||
if filename == "" {
|
||||
reader = strings.NewReader("FROM .\n")
|
||||
} else {
|
||||
return errModelfileNotFound
|
||||
}
|
||||
} else if err != nil {
|
||||
return err
|
||||
} else {
|
||||
f, err := os.Open(filename)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
reader = f
|
||||
defer f.Close()
|
||||
filename, err := cmd.Flags().GetString("file")
|
||||
if err != nil {
|
||||
return fmt.Errorf("error retrieving file flag: %w", err)
|
||||
}
|
||||
|
||||
modelfile, err := parser.ParseFile(reader)
|
||||
var r, fallback io.Reader
|
||||
switch filename {
|
||||
case "-":
|
||||
r = os.Stdin
|
||||
case "":
|
||||
filename = "Modelfile"
|
||||
fallback = strings.NewReader("FROM .")
|
||||
fallthrough
|
||||
default:
|
||||
r, err = os.Open(filename)
|
||||
if errors.Is(err, os.ErrNotExist) && fallback != nil {
|
||||
r = fallback
|
||||
} else if errors.Is(err, os.ErrNotExist) {
|
||||
return fmt.Errorf("%w: Modelfile %q does not exist, please create it or use --file to specify a different file", err, filename)
|
||||
} else if err != nil {
|
||||
return err
|
||||
} else {
|
||||
defer r.(*os.File).Close()
|
||||
}
|
||||
}
|
||||
|
||||
modelfile, err := parser.ParseFile(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
@ -127,10 +110,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
|
|||
spinner.Stop()
|
||||
|
||||
req.Model = args[0]
|
||||
quantize, _ := cmd.Flags().GetString("quantize")
|
||||
if quantize != "" {
|
||||
req.Quantize = quantize
|
||||
}
|
||||
req.Quantize, _ = cmd.Flags().GetString("quantize")
|
||||
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
|
|
|
|||
345
cmd/cmd_test.go
345
cmd/cmd_test.go
|
|
@ -3,21 +3,31 @@ package cmd
|
|||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/google/go-cmp/cmp/cmpopts"
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
)
|
||||
|
||||
func mockServer(t *testing.T, h http.HandlerFunc) {
|
||||
t.Helper()
|
||||
s := httptest.NewServer(h)
|
||||
t.Cleanup(s.Close)
|
||||
t.Setenv("OLLAMA_HOST", s.URL)
|
||||
}
|
||||
|
||||
func TestShowInfo(t *testing.T) {
|
||||
t.Run("bare details", func(t *testing.T) {
|
||||
var b bytes.Buffer
|
||||
|
|
@ -351,101 +361,6 @@ func TestDeleteHandler(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestGetModelfileName(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
modelfileName string
|
||||
fileExists bool
|
||||
expectedName string
|
||||
expectedErr error
|
||||
}{
|
||||
{
|
||||
name: "no modelfile specified, no modelfile exists",
|
||||
modelfileName: "",
|
||||
fileExists: false,
|
||||
expectedName: "",
|
||||
expectedErr: os.ErrNotExist,
|
||||
},
|
||||
{
|
||||
name: "no modelfile specified, modelfile exists",
|
||||
modelfileName: "",
|
||||
fileExists: true,
|
||||
expectedName: "Modelfile",
|
||||
expectedErr: nil,
|
||||
},
|
||||
{
|
||||
name: "modelfile specified, no modelfile exists",
|
||||
modelfileName: "crazyfile",
|
||||
fileExists: false,
|
||||
expectedName: "",
|
||||
expectedErr: os.ErrNotExist,
|
||||
},
|
||||
{
|
||||
name: "modelfile specified, modelfile exists",
|
||||
modelfileName: "anotherfile",
|
||||
fileExists: true,
|
||||
expectedName: "anotherfile",
|
||||
expectedErr: nil,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
cmd := &cobra.Command{
|
||||
Use: "fakecmd",
|
||||
}
|
||||
cmd.Flags().String("file", "", "path to modelfile")
|
||||
|
||||
var expectedFilename string
|
||||
|
||||
if tt.fileExists {
|
||||
var fn string
|
||||
if tt.modelfileName != "" {
|
||||
fn = tt.modelfileName
|
||||
} else {
|
||||
fn = "Modelfile"
|
||||
}
|
||||
|
||||
tempFile, err := os.CreateTemp(t.TempDir(), fn)
|
||||
if err != nil {
|
||||
t.Fatalf("temp modelfile creation failed: %v", err)
|
||||
}
|
||||
defer tempFile.Close()
|
||||
|
||||
expectedFilename = tempFile.Name()
|
||||
err = cmd.Flags().Set("file", expectedFilename)
|
||||
if err != nil {
|
||||
t.Fatalf("couldn't set file flag: %v", err)
|
||||
}
|
||||
} else {
|
||||
expectedFilename = tt.expectedName
|
||||
if tt.modelfileName != "" {
|
||||
err := cmd.Flags().Set("file", tt.modelfileName)
|
||||
if err != nil {
|
||||
t.Fatalf("couldn't set file flag: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
actualFilename, actualErr := getModelfileName(cmd)
|
||||
|
||||
if actualFilename != expectedFilename {
|
||||
t.Errorf("expected filename: '%s' actual filename: '%s'", expectedFilename, actualFilename)
|
||||
}
|
||||
|
||||
if tt.expectedErr != os.ErrNotExist {
|
||||
if actualErr != tt.expectedErr {
|
||||
t.Errorf("expected err: %v actual err: %v", tt.expectedErr, actualErr)
|
||||
}
|
||||
} else {
|
||||
if !os.IsNotExist(actualErr) {
|
||||
t.Errorf("expected err: %v actual err: %v", tt.expectedErr, actualErr)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPushHandler(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
|
|
@ -661,128 +576,160 @@ func TestListHandler(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestCreateHandler(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
modelName string
|
||||
modelFile string
|
||||
serverResponse map[string]func(w http.ResponseWriter, r *http.Request)
|
||||
expectedError string
|
||||
expectedOutput string
|
||||
cases := []struct {
|
||||
name string
|
||||
filename func(*testing.T) string
|
||||
|
||||
wantRequest api.CreateRequest
|
||||
wantErr error
|
||||
}{
|
||||
{
|
||||
name: "successful create",
|
||||
modelName: "test-model",
|
||||
modelFile: "FROM foo",
|
||||
serverResponse: map[string]func(w http.ResponseWriter, r *http.Request){
|
||||
"/api/create": func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
t.Errorf("expected POST request, got %s", r.Method)
|
||||
}
|
||||
name: "not exist",
|
||||
filename: func(*testing.T) string { return "not_exist" },
|
||||
wantErr: os.ErrNotExist,
|
||||
},
|
||||
{
|
||||
name: "stdin",
|
||||
filename: func(t *testing.T) string {
|
||||
r, w, err := os.Pipe()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
req := api.CreateRequest{}
|
||||
if _, err := w.WriteString("FROM stdin"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := w.Close(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
stdin := os.Stdin
|
||||
t.Cleanup(func() { os.Stdin = stdin })
|
||||
os.Stdin = r
|
||||
return "-"
|
||||
},
|
||||
wantRequest: api.CreateRequest{
|
||||
Model: "stdin",
|
||||
From: "stdin",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "default",
|
||||
filename: func(t *testing.T) string {
|
||||
t.Chdir(t.TempDir())
|
||||
|
||||
f, err := os.Create("Modelfile")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
if _, err := f.WriteString("FROM default"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
return ""
|
||||
},
|
||||
wantRequest: api.CreateRequest{
|
||||
Model: "default",
|
||||
From: "default",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "file flag",
|
||||
filename: func(t *testing.T) string {
|
||||
f, err := os.CreateTemp(t.TempDir(), filepath.Base(t.Name()))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
if _, err := f.WriteString("FROM file:flag"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
return f.Name()
|
||||
},
|
||||
wantRequest: api.CreateRequest{
|
||||
Model: "file_flag",
|
||||
From: "file:flag",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "default safetensors",
|
||||
filename: func(t *testing.T) string {
|
||||
t.Chdir(t.TempDir())
|
||||
f, err := os.Create("model.safetensors")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
if err := f.Truncate(1); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
return ""
|
||||
},
|
||||
wantRequest: api.CreateRequest{
|
||||
Model: "default_safetensors",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "insecure path",
|
||||
filename: func(t *testing.T) string {
|
||||
t.Chdir(t.TempDir())
|
||||
if err := os.Symlink("../../../../../../nope", "model.safetensors"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
return ""
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
var cmd cobra.Command
|
||||
cmd.SetContext(t.Context())
|
||||
cmd.Flags().String("file", "", "")
|
||||
cmd.Flags().String("quantize", "", "")
|
||||
for _, tt := range cases {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
mockServer(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
if r.URL.Path == "/api/create" {
|
||||
var req api.CreateRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if req.Model != "test-model" {
|
||||
t.Errorf("expected model name 'test-model', got %s", req.Name)
|
||||
if diff := cmp.Diff(tt.wantRequest, req, cmpopts.IgnoreFields(api.CreateRequest{}, "Files")); diff != "" {
|
||||
t.Errorf("Create request mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
|
||||
if req.From != "foo" {
|
||||
t.Errorf("expected from 'foo', got %s", req.From)
|
||||
}
|
||||
|
||||
responses := []api.ProgressResponse{
|
||||
{Status: "using existing layer sha256:56bb8bd477a519ffa694fc449c2413c6f0e1d3b1c88fa7e3c9d88d3ae49d4dcb"},
|
||||
{Status: "writing manifest"},
|
||||
{Status: "success"},
|
||||
}
|
||||
|
||||
for _, resp := range responses {
|
||||
if err := json.NewEncoder(w).Encode(resp); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
w.(http.Flusher).Flush()
|
||||
}
|
||||
},
|
||||
},
|
||||
expectedOutput: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
handler, ok := tt.serverResponse[r.URL.Path]
|
||||
if !ok {
|
||||
t.Errorf("unexpected request to %s", r.URL.Path)
|
||||
} else if strings.HasPrefix(r.URL.Path, "/api/blobs/") {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
} else {
|
||||
http.Error(w, "not found", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
handler(w, r)
|
||||
}))
|
||||
t.Setenv("OLLAMA_HOST", mockServer.URL)
|
||||
t.Cleanup(mockServer.Close)
|
||||
tempFile, err := os.CreateTemp(t.TempDir(), "modelfile")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer os.Remove(tempFile.Name())
|
||||
})
|
||||
|
||||
if _, err := tempFile.WriteString(tt.modelFile); err != nil {
|
||||
t.Fatal(err)
|
||||
var filename string
|
||||
if tt.filename != nil {
|
||||
filename = tt.filename(t)
|
||||
}
|
||||
if err := tempFile.Close(); err != nil {
|
||||
|
||||
if err := cmd.Flags().Set("file", filename); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
cmd := &cobra.Command{}
|
||||
cmd.Flags().String("file", "", "")
|
||||
if err := cmd.Flags().Set("file", tempFile.Name()); err != nil {
|
||||
if err := CreateHandler(&cmd, []string{filepath.Base(t.Name())}); !errors.Is(err, tt.wantErr) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
cmd.Flags().Bool("insecure", false, "")
|
||||
cmd.SetContext(t.Context())
|
||||
|
||||
// Redirect stderr to capture progress output
|
||||
oldStderr := os.Stderr
|
||||
r, w, _ := os.Pipe()
|
||||
os.Stderr = w
|
||||
|
||||
// Capture stdout for the "Model pushed" message
|
||||
oldStdout := os.Stdout
|
||||
outR, outW, _ := os.Pipe()
|
||||
os.Stdout = outW
|
||||
|
||||
err = CreateHandler(cmd, []string{tt.modelName})
|
||||
|
||||
// Restore stderr
|
||||
w.Close()
|
||||
os.Stderr = oldStderr
|
||||
// drain the pipe
|
||||
if _, err := io.ReadAll(r); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Restore stdout and get output
|
||||
outW.Close()
|
||||
os.Stdout = oldStdout
|
||||
stdout, _ := io.ReadAll(outR)
|
||||
|
||||
if tt.expectedError == "" {
|
||||
if err != nil {
|
||||
t.Errorf("expected no error, got %v", err)
|
||||
}
|
||||
|
||||
if tt.expectedOutput != "" {
|
||||
if got := string(stdout); got != tt.expectedOutput {
|
||||
t.Errorf("expected output %q, got %q", tt.expectedOutput, got)
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue