From 9bee2450e9d20aff4cddef76f298899954aa4839 Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Tue, 28 Oct 2025 20:53:32 -0700 Subject: [PATCH] pull --- api/client.go | 19 ---------- cmd/cmd.go | 97 ++++++++++++--------------------------------------- cmd/pull.go | 37 ++++++++++++++++++++ 3 files changed, 60 insertions(+), 93 deletions(-) create mode 100644 cmd/pull.go diff --git a/api/client.go b/api/client.go index c4a8a75ca..9bab4462e 100644 --- a/api/client.go +++ b/api/client.go @@ -294,25 +294,6 @@ func (c *Client) Chat(ctx context.Context, req *ChatRequest, fn ChatResponseFunc }) } -// PullProgressFunc is a function that [Client.Pull] invokes every time there -// is progress with a "pull" request sent to the service. If this function -// returns an error, [Client.Pull] will stop the process and return this error. -type PullProgressFunc func(ProgressResponse) error - -// Pull downloads a model from the ollama library. fn is called each time -// progress is made on the request and can be used to display a progress bar, -// etc. -func (c *Client) Pull(ctx context.Context, req *PullRequest, fn PullProgressFunc) error { - return c.stream(ctx, http.MethodPost, "/api/pull", req, func(bts []byte) error { - var resp ProgressResponse - if err := json.Unmarshal(bts, &resp); err != nil { - return err - } - - return fn(resp) - }) -} - // PushProgressFunc is a function that [Client.Push] invokes when progress is // made. // It's similar to other progress function types like [PullProgressFunc]. diff --git a/cmd/cmd.go b/cmd/cmd.go index 5b04a1864..86e4d9f3b 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -10,6 +10,7 @@ import ( "errors" "fmt" "io" + "iter" "log" "math" "net" @@ -413,7 +414,7 @@ func RunHandler(cmd *cobra.Command, args []string) error { info, err := client.Show(cmd.Context(), showReq) var se api.StatusError if errors.As(err, &se) && se.StatusCode == http.StatusNotFound { - if err := PullHandler(cmd, []string{name}); err != nil { + if err := pullHandler(cmd, []string{name}); err != nil { return nil, err } return client.Show(cmd.Context(), &api.ShowRequest{Name: name}) @@ -985,79 +986,38 @@ func CopyHandler(cmd *cobra.Command, args []string) error { return nil } -func PullHandler(cmd *cobra.Command, args []string) error { - insecure, err := cmd.Flags().GetBool("insecure") +func must[T any](v T, err error) T { if err != nil { - return err + panic(err) } + return v +} - client, err := api.ClientFromEnvironment() - if err != nil { - return err - } - - p := progress.NewProgress(os.Stderr) - defer p.Stop() - - bars := make(map[string]*progress.Bar) - +func progressHandler(p *progress.Progress, w iter.Seq2[api.ProgressResponse, error]) error { var status string - var spinner *progress.Spinner - - fn := func(resp api.ProgressResponse) error { - if resp.Digest != "" { - if resp.Completed == 0 { - // This is the initial status update for the - // layer, which the server sends before - // beginning the download, for clients to - // compute total size and prepare for - // downloads, if needed. - // - // Skipping this here to avoid showing a 0% - // progress bar, which *should* clue the user - // into the fact that many things are being - // downloaded and that the current active - // download is not that last. However, in rare - // cases it seems to be triggering to some, and - // it isn't worth explaining, so just ignore - // and regress to the old UI that keeps giving - // you the "But wait, there is more!" after - // each "100% done" bar, which is "better." - return nil + var state progress.State + for c := range w { + if c.Status != status { + if s, ok := state.(*progress.Spinner); ok { + s.Stop() } - if spinner != nil { - spinner.Stop() + status = c.Status + if c.Digest != "" { + state = progress.NewBar(status, c.Total, c.Completed) + } else { + state = progress.NewSpinner(status) } - bar, ok := bars[resp.Digest] - if !ok { - name, isDigest := strings.CutPrefix(resp.Digest, "sha256:") - name = strings.TrimSpace(name) - if isDigest { - name = name[:min(12, len(name))] - } - bar = progress.NewBar(fmt.Sprintf("pulling %s:", name), resp.Total, resp.Completed) - bars[resp.Digest] = bar - p.Add(resp.Digest, bar) - } - - bar.Set(resp.Completed) - } else if status != resp.Status { - if spinner != nil { - spinner.Stop() - } - - status = resp.Status - spinner = progress.NewSpinner(status) - p.Add(status, spinner) + p.Add(status, state) } - return nil + if b, ok := state.(*progress.Bar); ok { + b.Set(c.Completed) + } } - request := api.PullRequest{Name: args[0], Insecure: insecure} - return client.Pull(cmd.Context(), &request, fn) + return nil } type generateContextKey string @@ -1652,16 +1612,6 @@ func NewCLI() *cobra.Command { RunE: RunServer, } - pullCmd := &cobra.Command{ - Use: "pull MODEL", - Short: "Pull a model from a registry", - Args: cobra.ExactArgs(1), - PreRunE: checkServerHeartbeat, - RunE: PullHandler, - } - - pullCmd.Flags().Bool("insecure", false, "Use an insecure registry") - pushCmd := &cobra.Command{ Use: "push MODEL", Short: "Push a model to a registry", @@ -1731,7 +1681,6 @@ func NewCLI() *cobra.Command { showCmd, runCmd, stopCmd, - pullCmd, pushCmd, psCmd, copyCmd, @@ -1771,7 +1720,7 @@ func NewCLI() *cobra.Command { showCmd, runCmd, stopCmd, - pullCmd, + cmdPull(), pushCmd, signinCmd, signoutCmd, diff --git a/cmd/pull.go b/cmd/pull.go new file mode 100644 index 000000000..53af1713c --- /dev/null +++ b/cmd/pull.go @@ -0,0 +1,37 @@ +package cmd + +import ( + "os" + + "github.com/ollama/ollama/api" + "github.com/ollama/ollama/client" + "github.com/ollama/ollama/progress" + "github.com/spf13/cobra" +) + +func cmdPull() *cobra.Command { + cmd := cobra.Command{ + Use: "pull [model]", + Short: "Pull a model from a remote repository", + Args: cobra.ExactArgs(1), + PreRunE: checkServerHeartbeat, + RunE: pullHandler, + } + cmd.Flags().Bool("insecure", false, "Allow insecure server connections when pulling models") + return &cmd +} + +func pullHandler(cmd *cobra.Command, args []string) error { + c := client.New() + w, err := c.Pull(cmd.Context(), api.PullRequest{ + Name: args[0], + Insecure: must(cmd.Flags().GetBool("insecure")), + }) + if err != nil { + return err + } + + p := progress.NewProgress(os.Stderr) + defer p.Stop() + return progressHandler(p, w) +}