This commit is contained in:
Michael Yang 2025-10-28 20:53:32 -07:00
parent 929542140f
commit 9bee2450e9
3 changed files with 60 additions and 93 deletions

View File

@ -294,25 +294,6 @@ func (c *Client) Chat(ctx context.Context, req *ChatRequest, fn ChatResponseFunc
})
}
// PullProgressFunc is a function that [Client.Pull] invokes every time there
// is progress with a "pull" request sent to the service. If this function
// returns an error, [Client.Pull] will stop the process and return this error.
type PullProgressFunc func(ProgressResponse) error
// Pull downloads a model from the ollama library. fn is called each time
// progress is made on the request and can be used to display a progress bar,
// etc.
func (c *Client) Pull(ctx context.Context, req *PullRequest, fn PullProgressFunc) error {
return c.stream(ctx, http.MethodPost, "/api/pull", req, func(bts []byte) error {
var resp ProgressResponse
if err := json.Unmarshal(bts, &resp); err != nil {
return err
}
return fn(resp)
})
}
// PushProgressFunc is a function that [Client.Push] invokes when progress is
// made.
// It's similar to other progress function types like [PullProgressFunc].

View File

@ -10,6 +10,7 @@ import (
"errors"
"fmt"
"io"
"iter"
"log"
"math"
"net"
@ -413,7 +414,7 @@ func RunHandler(cmd *cobra.Command, args []string) error {
info, err := client.Show(cmd.Context(), showReq)
var se api.StatusError
if errors.As(err, &se) && se.StatusCode == http.StatusNotFound {
if err := PullHandler(cmd, []string{name}); err != nil {
if err := pullHandler(cmd, []string{name}); err != nil {
return nil, err
}
return client.Show(cmd.Context(), &api.ShowRequest{Name: name})
@ -985,79 +986,38 @@ func CopyHandler(cmd *cobra.Command, args []string) error {
return nil
}
func PullHandler(cmd *cobra.Command, args []string) error {
insecure, err := cmd.Flags().GetBool("insecure")
func must[T any](v T, err error) T {
if err != nil {
return err
panic(err)
}
return v
}
client, err := api.ClientFromEnvironment()
if err != nil {
return err
}
p := progress.NewProgress(os.Stderr)
defer p.Stop()
bars := make(map[string]*progress.Bar)
func progressHandler(p *progress.Progress, w iter.Seq2[api.ProgressResponse, error]) error {
var status string
var spinner *progress.Spinner
fn := func(resp api.ProgressResponse) error {
if resp.Digest != "" {
if resp.Completed == 0 {
// This is the initial status update for the
// layer, which the server sends before
// beginning the download, for clients to
// compute total size and prepare for
// downloads, if needed.
//
// Skipping this here to avoid showing a 0%
// progress bar, which *should* clue the user
// into the fact that many things are being
// downloaded and that the current active
// download is not that last. However, in rare
// cases it seems to be triggering to some, and
// it isn't worth explaining, so just ignore
// and regress to the old UI that keeps giving
// you the "But wait, there is more!" after
// each "100% done" bar, which is "better."
return nil
var state progress.State
for c := range w {
if c.Status != status {
if s, ok := state.(*progress.Spinner); ok {
s.Stop()
}
if spinner != nil {
spinner.Stop()
status = c.Status
if c.Digest != "" {
state = progress.NewBar(status, c.Total, c.Completed)
} else {
state = progress.NewSpinner(status)
}
bar, ok := bars[resp.Digest]
if !ok {
name, isDigest := strings.CutPrefix(resp.Digest, "sha256:")
name = strings.TrimSpace(name)
if isDigest {
name = name[:min(12, len(name))]
}
bar = progress.NewBar(fmt.Sprintf("pulling %s:", name), resp.Total, resp.Completed)
bars[resp.Digest] = bar
p.Add(resp.Digest, bar)
}
bar.Set(resp.Completed)
} else if status != resp.Status {
if spinner != nil {
spinner.Stop()
}
status = resp.Status
spinner = progress.NewSpinner(status)
p.Add(status, spinner)
p.Add(status, state)
}
return nil
if b, ok := state.(*progress.Bar); ok {
b.Set(c.Completed)
}
}
request := api.PullRequest{Name: args[0], Insecure: insecure}
return client.Pull(cmd.Context(), &request, fn)
return nil
}
type generateContextKey string
@ -1652,16 +1612,6 @@ func NewCLI() *cobra.Command {
RunE: RunServer,
}
pullCmd := &cobra.Command{
Use: "pull MODEL",
Short: "Pull a model from a registry",
Args: cobra.ExactArgs(1),
PreRunE: checkServerHeartbeat,
RunE: PullHandler,
}
pullCmd.Flags().Bool("insecure", false, "Use an insecure registry")
pushCmd := &cobra.Command{
Use: "push MODEL",
Short: "Push a model to a registry",
@ -1731,7 +1681,6 @@ func NewCLI() *cobra.Command {
showCmd,
runCmd,
stopCmd,
pullCmd,
pushCmd,
psCmd,
copyCmd,
@ -1771,7 +1720,7 @@ func NewCLI() *cobra.Command {
showCmd,
runCmd,
stopCmd,
pullCmd,
cmdPull(),
pushCmd,
signinCmd,
signoutCmd,

37
cmd/pull.go Normal file
View File

@ -0,0 +1,37 @@
package cmd
import (
"os"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/client"
"github.com/ollama/ollama/progress"
"github.com/spf13/cobra"
)
func cmdPull() *cobra.Command {
cmd := cobra.Command{
Use: "pull [model]",
Short: "Pull a model from a remote repository",
Args: cobra.ExactArgs(1),
PreRunE: checkServerHeartbeat,
RunE: pullHandler,
}
cmd.Flags().Bool("insecure", false, "Allow insecure server connections when pulling models")
return &cmd
}
func pullHandler(cmd *cobra.Command, args []string) error {
c := client.New()
w, err := c.Pull(cmd.Context(), api.PullRequest{
Name: args[0],
Insecure: must(cmd.Flags().GetBool("insecure")),
})
if err != nil {
return err
}
p := progress.NewProgress(os.Stderr)
defer p.Stop()
return progressHandler(p, w)
}