pull
This commit is contained in:
parent
929542140f
commit
9bee2450e9
|
|
@ -294,25 +294,6 @@ func (c *Client) Chat(ctx context.Context, req *ChatRequest, fn ChatResponseFunc
|
|||
})
|
||||
}
|
||||
|
||||
// PullProgressFunc is a function that [Client.Pull] invokes every time there
|
||||
// is progress with a "pull" request sent to the service. If this function
|
||||
// returns an error, [Client.Pull] will stop the process and return this error.
|
||||
type PullProgressFunc func(ProgressResponse) error
|
||||
|
||||
// Pull downloads a model from the ollama library. fn is called each time
|
||||
// progress is made on the request and can be used to display a progress bar,
|
||||
// etc.
|
||||
func (c *Client) Pull(ctx context.Context, req *PullRequest, fn PullProgressFunc) error {
|
||||
return c.stream(ctx, http.MethodPost, "/api/pull", req, func(bts []byte) error {
|
||||
var resp ProgressResponse
|
||||
if err := json.Unmarshal(bts, &resp); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return fn(resp)
|
||||
})
|
||||
}
|
||||
|
||||
// PushProgressFunc is a function that [Client.Push] invokes when progress is
|
||||
// made.
|
||||
// It's similar to other progress function types like [PullProgressFunc].
|
||||
|
|
|
|||
97
cmd/cmd.go
97
cmd/cmd.go
|
|
@ -10,6 +10,7 @@ import (
|
|||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"iter"
|
||||
"log"
|
||||
"math"
|
||||
"net"
|
||||
|
|
@ -413,7 +414,7 @@ func RunHandler(cmd *cobra.Command, args []string) error {
|
|||
info, err := client.Show(cmd.Context(), showReq)
|
||||
var se api.StatusError
|
||||
if errors.As(err, &se) && se.StatusCode == http.StatusNotFound {
|
||||
if err := PullHandler(cmd, []string{name}); err != nil {
|
||||
if err := pullHandler(cmd, []string{name}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return client.Show(cmd.Context(), &api.ShowRequest{Name: name})
|
||||
|
|
@ -985,79 +986,38 @@ func CopyHandler(cmd *cobra.Command, args []string) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func PullHandler(cmd *cobra.Command, args []string) error {
|
||||
insecure, err := cmd.Flags().GetBool("insecure")
|
||||
func must[T any](v T, err error) T {
|
||||
if err != nil {
|
||||
return err
|
||||
panic(err)
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
p := progress.NewProgress(os.Stderr)
|
||||
defer p.Stop()
|
||||
|
||||
bars := make(map[string]*progress.Bar)
|
||||
|
||||
func progressHandler(p *progress.Progress, w iter.Seq2[api.ProgressResponse, error]) error {
|
||||
var status string
|
||||
var spinner *progress.Spinner
|
||||
|
||||
fn := func(resp api.ProgressResponse) error {
|
||||
if resp.Digest != "" {
|
||||
if resp.Completed == 0 {
|
||||
// This is the initial status update for the
|
||||
// layer, which the server sends before
|
||||
// beginning the download, for clients to
|
||||
// compute total size and prepare for
|
||||
// downloads, if needed.
|
||||
//
|
||||
// Skipping this here to avoid showing a 0%
|
||||
// progress bar, which *should* clue the user
|
||||
// into the fact that many things are being
|
||||
// downloaded and that the current active
|
||||
// download is not that last. However, in rare
|
||||
// cases it seems to be triggering to some, and
|
||||
// it isn't worth explaining, so just ignore
|
||||
// and regress to the old UI that keeps giving
|
||||
// you the "But wait, there is more!" after
|
||||
// each "100% done" bar, which is "better."
|
||||
return nil
|
||||
var state progress.State
|
||||
for c := range w {
|
||||
if c.Status != status {
|
||||
if s, ok := state.(*progress.Spinner); ok {
|
||||
s.Stop()
|
||||
}
|
||||
|
||||
if spinner != nil {
|
||||
spinner.Stop()
|
||||
status = c.Status
|
||||
if c.Digest != "" {
|
||||
state = progress.NewBar(status, c.Total, c.Completed)
|
||||
} else {
|
||||
state = progress.NewSpinner(status)
|
||||
}
|
||||
|
||||
bar, ok := bars[resp.Digest]
|
||||
if !ok {
|
||||
name, isDigest := strings.CutPrefix(resp.Digest, "sha256:")
|
||||
name = strings.TrimSpace(name)
|
||||
if isDigest {
|
||||
name = name[:min(12, len(name))]
|
||||
}
|
||||
bar = progress.NewBar(fmt.Sprintf("pulling %s:", name), resp.Total, resp.Completed)
|
||||
bars[resp.Digest] = bar
|
||||
p.Add(resp.Digest, bar)
|
||||
}
|
||||
|
||||
bar.Set(resp.Completed)
|
||||
} else if status != resp.Status {
|
||||
if spinner != nil {
|
||||
spinner.Stop()
|
||||
}
|
||||
|
||||
status = resp.Status
|
||||
spinner = progress.NewSpinner(status)
|
||||
p.Add(status, spinner)
|
||||
p.Add(status, state)
|
||||
}
|
||||
|
||||
return nil
|
||||
if b, ok := state.(*progress.Bar); ok {
|
||||
b.Set(c.Completed)
|
||||
}
|
||||
}
|
||||
|
||||
request := api.PullRequest{Name: args[0], Insecure: insecure}
|
||||
return client.Pull(cmd.Context(), &request, fn)
|
||||
return nil
|
||||
}
|
||||
|
||||
type generateContextKey string
|
||||
|
|
@ -1652,16 +1612,6 @@ func NewCLI() *cobra.Command {
|
|||
RunE: RunServer,
|
||||
}
|
||||
|
||||
pullCmd := &cobra.Command{
|
||||
Use: "pull MODEL",
|
||||
Short: "Pull a model from a registry",
|
||||
Args: cobra.ExactArgs(1),
|
||||
PreRunE: checkServerHeartbeat,
|
||||
RunE: PullHandler,
|
||||
}
|
||||
|
||||
pullCmd.Flags().Bool("insecure", false, "Use an insecure registry")
|
||||
|
||||
pushCmd := &cobra.Command{
|
||||
Use: "push MODEL",
|
||||
Short: "Push a model to a registry",
|
||||
|
|
@ -1731,7 +1681,6 @@ func NewCLI() *cobra.Command {
|
|||
showCmd,
|
||||
runCmd,
|
||||
stopCmd,
|
||||
pullCmd,
|
||||
pushCmd,
|
||||
psCmd,
|
||||
copyCmd,
|
||||
|
|
@ -1771,7 +1720,7 @@ func NewCLI() *cobra.Command {
|
|||
showCmd,
|
||||
runCmd,
|
||||
stopCmd,
|
||||
pullCmd,
|
||||
cmdPull(),
|
||||
pushCmd,
|
||||
signinCmd,
|
||||
signoutCmd,
|
||||
|
|
|
|||
|
|
@ -0,0 +1,37 @@
|
|||
package cmd
|
||||
|
||||
import (
|
||||
"os"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/client"
|
||||
"github.com/ollama/ollama/progress"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
func cmdPull() *cobra.Command {
|
||||
cmd := cobra.Command{
|
||||
Use: "pull [model]",
|
||||
Short: "Pull a model from a remote repository",
|
||||
Args: cobra.ExactArgs(1),
|
||||
PreRunE: checkServerHeartbeat,
|
||||
RunE: pullHandler,
|
||||
}
|
||||
cmd.Flags().Bool("insecure", false, "Allow insecure server connections when pulling models")
|
||||
return &cmd
|
||||
}
|
||||
|
||||
func pullHandler(cmd *cobra.Command, args []string) error {
|
||||
c := client.New()
|
||||
w, err := c.Pull(cmd.Context(), api.PullRequest{
|
||||
Name: args[0],
|
||||
Insecure: must(cmd.Flags().GetBool("insecure")),
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
p := progress.NewProgress(os.Stderr)
|
||||
defer p.Stop()
|
||||
return progressHandler(p, w)
|
||||
}
|
||||
Loading…
Reference in New Issue