From bddb27ab5b1414c83dc8fe5b56f9ab9e501f9a54 Mon Sep 17 00:00:00 2001 From: jmorganca Date: Sat, 20 Dec 2025 16:16:34 -0800 Subject: [PATCH] client2 updated --- server/internal/cache/blob/cache.go | 52 ++-- server/internal/cache/blob/cache_test.go | 27 +- server/internal/client/ollama/registry.go | 251 ++++++++---------- server/internal/client/ollama/trace.go | 34 +-- server/internal/internal/backoff/backoff.go | 64 ++--- .../internal/backoff/backoff_synctest_test.go | 69 +++-- .../internal/internal/backoff/backoff_test.go | 27 +- server/internal/registry/server.go | 66 ++--- 8 files changed, 271 insertions(+), 319 deletions(-) diff --git a/server/internal/cache/blob/cache.go b/server/internal/cache/blob/cache.go index 2f02ca955..2d31630e8 100644 --- a/server/internal/cache/blob/cache.go +++ b/server/internal/cache/blob/cache.go @@ -10,7 +10,6 @@ import ( "hash" "io" "io/fs" - "iter" "os" "path/filepath" "strings" @@ -327,21 +326,19 @@ func (c *DiskCache) GetFile(d Digest) string { return absJoin(c.dir, "blobs", filename) } -// Links returns a sequence of link names. The sequence is in lexical order. +// Links returns a slice of link names in lexical order. // Names are converted from their relative path form to their name form but are // not guaranteed to be valid. Callers should validate the names before using. -func (c *DiskCache) Links() iter.Seq2[string, error] { - return func(yield func(string, error) bool) { - for path, err := range c.links() { - if err != nil { - yield("", err) - return - } - if !yield(pathToName(path), nil) { - return - } - } +func (c *DiskCache) Links() ([]string, error) { + paths, err := c.links() + if err != nil { + return nil, err } + names := make([]string, len(paths)) + for i, path := range paths { + names[i] = pathToName(path) + } + return names, nil } // pathToName converts a path to a name. It is the inverse of nameToPath. The @@ -372,10 +369,11 @@ func (c *DiskCache) manifestPath(name string) (string, error) { } maybe := filepath.Join("manifests", np) - for l, err := range c.links() { - if err != nil { - return "", err - } + paths, err := c.links() + if err != nil { + return "", err + } + for _, l := range paths { if strings.EqualFold(maybe, l) { return filepath.Join(c.dir, l), nil } @@ -383,22 +381,10 @@ func (c *DiskCache) manifestPath(name string) (string, error) { return filepath.Join(c.dir, maybe), nil } -// links returns a sequence of links in the cache in lexical order. -func (c *DiskCache) links() iter.Seq2[string, error] { - // TODO(bmizerany): reuse empty dirnames if exist - return func(yield func(string, error) bool) { - fsys := os.DirFS(c.dir) - manifests, err := fs.Glob(fsys, "manifests/*/*/*/*") - if err != nil { - yield("", err) - return - } - for _, manifest := range manifests { - if !yield(manifest, nil) { - return - } - } - } +// links returns a slice of link paths in the cache in lexical order. +func (c *DiskCache) links() ([]string, error) { + fsys := os.DirFS(c.dir) + return fs.Glob(fsys, "manifests/*/*/*/*") } type checkWriter struct { diff --git a/server/internal/cache/blob/cache_test.go b/server/internal/cache/blob/cache_test.go index af29a3123..f6f15a4e2 100644 --- a/server/internal/cache/blob/cache_test.go +++ b/server/internal/cache/blob/cache_test.go @@ -466,12 +466,9 @@ func testManifestNameReuse(t *testing.T) { t.Fatalf("g = %v, want %v", g, w) } - var got []string - for l, err := range c.links() { - if err != nil { - t.Fatal(err) - } - got = append(got, l) + got, err := c.links() + if err != nil { + t.Fatal(err) } want := []string{"manifests/h/n/m/t"} if !slices.Equal(got, want) { @@ -487,12 +484,9 @@ func testManifestNameReuse(t *testing.T) { err = c.Link("h/n/m:T", d1) check(err) - got = got[:0] - for l, err := range c.links() { - if err != nil { - t.Fatal(err) - } - got = append(got, l) + got, err = c.links() + if err != nil { + t.Fatal(err) } // we should have only one link that is same case as the last link @@ -554,12 +548,9 @@ func TestNames(t *testing.T) { check(c.Link("h/n/m:t", mkdigest("1"))) check(c.Link("h/n/m:u", mkdigest("2"))) - var got []string - for l, err := range c.Links() { - if err != nil { - t.Fatal(err) - } - got = append(got, l) + got, err := c.Links() + if err != nil { + t.Fatal(err) } want := []string{"h/n/m:t", "h/n/m:u"} if !slices.Equal(got, want) { diff --git a/server/internal/client/ollama/registry.go b/server/internal/client/ollama/registry.go index eae130bf4..adfb4a376 100644 --- a/server/internal/client/ollama/registry.go +++ b/server/internal/client/ollama/registry.go @@ -19,7 +19,6 @@ import ( "fmt" "io" "io/fs" - "iter" "log/slog" "net/http" "os" @@ -546,18 +545,7 @@ func (r *Registry) Pull(ctx context.Context, name string) error { }) }() - for cs, err := range r.chunksums(ctx, name, l) { - if err != nil { - // Note the chunksum stream - // interruption, but do not cancel - // in-flight downloads. We can still - // make progress on them. Once they are - // done, ErrIncomplete will be returned - // below. - update(0, err) - break - } - + err = r.chunksums(ctx, name, l, func(cs chunksum) bool { cacheKey := fmt.Sprintf( "v1 pull chunksum %s %s %d-%d", l.Digest, @@ -569,7 +557,7 @@ func (r *Registry) Pull(ctx context.Context, name string) error { _, err := c.Get(cacheKeyDigest) if err == nil { update(cs.Chunk.Size(), ErrCached) - continue + return true // continue } wg.Add(1) @@ -620,6 +608,13 @@ func (r *Registry) Pull(ctx context.Context, name string) error { // Record the downloading of this chunk. return blob.PutBytes(c, cacheKeyDigest, cacheKey) }) + return true // continue processing chunks + }) + if err != nil { + // Note the chunksum stream interruption, but do not cancel + // in-flight downloads. We can still make progress on them. + // Once they are done, ErrIncomplete will be returned below. + update(0, err) } return nil @@ -674,19 +669,6 @@ func (m *Manifest) Layer(d blob.Digest) *Layer { return nil } -func (m *Manifest) All() iter.Seq[*Layer] { - return func(yield func(*Layer) bool) { - if !yield(m.Config) { - return - } - for _, l := range m.Layers { - if !yield(l) { - return - } - } - } -} - func (m *Manifest) Size() int64 { var size int64 if m.Config != nil { @@ -811,125 +793,114 @@ type chunksum struct { Digest blob.Digest } -// chunksums returns a sequence of chunksums for the given layer. If the layer is under the -// chunking threshold, a single chunksum is returned that covers the entire layer. If the layer -// is over the chunking threshold, the chunksums are read from the chunksums endpoint. -func (r *Registry) chunksums(ctx context.Context, name string, l *Layer) iter.Seq2[chunksum, error] { - return func(yield func(chunksum, error) bool) { - scheme, n, _, err := r.parseNameExtended(name) +// chunksums calls fn for each chunksum in the layer. If the layer is under the +// chunking threshold, a single chunksum covering the entire layer is passed to fn. +// If the layer is over the chunking threshold, chunksums are read from the chunksums endpoint. +// Returns an error if the chunksum stream fails, or nil if all chunksums were processed. +// If fn returns false, iteration stops early and chunksums returns nil. +func (r *Registry) chunksums(ctx context.Context, name string, l *Layer, fn func(chunksum) bool) error { + scheme, n, _, err := r.parseNameExtended(name) + if err != nil { + return err + } + + if l.Size < r.maxChunkingThreshold() { + // any layer under the threshold should be downloaded + // in one go. + cs := chunksum{ + URL: fmt.Sprintf("%s://%s/v2/%s/%s/blobs/%s", + scheme, + n.Host(), + n.Namespace(), + n.Model(), + l.Digest, + ), + Chunk: blob.Chunk{Start: 0, End: l.Size - 1}, + Digest: l.Digest, + } + fn(cs) + return nil + } + + // The response is a sequence of chunksums. + // + // Chunksums are chunks of a larger blob that can be + // downloaded and verified independently. + // + // The chunksums endpoint is a GET request that returns a + // sequence of chunksums in the following format: + // + // > GET /v2///chunksums/ + // + // < HTTP/1.1 200 OK + // < Content-Location: + // < + // < - + // < ... + // + // The is the URL to download the chunks from and + // each is the digest of the chunk, and - + // is the range the chunk in the blob. + // + // Ranges may be used directly in Range headers like + // "bytes=-". + // + // The chunksums returned are guaranteed to be contiguous and + // include all bytes of the layer. If the stream is cut short, + // clients should retry. + + chunksumsURL := fmt.Sprintf("%s://%s/v2/%s/%s/chunksums/%s", + scheme, + n.Host(), + n.Namespace(), + n.Model(), + l.Digest, + ) + + req, err := r.newRequest(ctx, "GET", chunksumsURL, nil) + if err != nil { + return err + } + res, err := sendRequest(r.client(), req) + if err != nil { + return err + } + defer res.Body.Close() + if res.StatusCode != 200 { + return fmt.Errorf("chunksums: unexpected status code %d", res.StatusCode) + } + blobURL := res.Header.Get("Content-Location") + + s := bufio.NewScanner(res.Body) + s.Split(bufio.ScanWords) + for { + if !s.Scan() { + return s.Err() + } + d, err := blob.ParseDigest(s.Bytes()) if err != nil { - yield(chunksum{}, err) - return + return fmt.Errorf("invalid digest: %q", s.Bytes()) } - if l.Size < r.maxChunkingThreshold() { - // any layer under the threshold should be downloaded - // in one go. - cs := chunksum{ - URL: fmt.Sprintf("%s://%s/v2/%s/%s/blobs/%s", - scheme, - n.Host(), - n.Namespace(), - n.Model(), - l.Digest, - ), - Chunk: blob.Chunk{Start: 0, End: l.Size - 1}, - Digest: l.Digest, + if !s.Scan() { + err := s.Err() + if err == nil { + err = fmt.Errorf("missing chunk range for digest %s", d) } - yield(cs, nil) - return + return err } - - // The response is a sequence of chunksums. - // - // Chunksums are chunks of a larger blob that can be - // downloaded and verified independently. - // - // The chunksums endpoint is a GET request that returns a - // sequence of chunksums in the following format: - // - // > GET /v2///chunksums/ - // - // < HTTP/1.1 200 OK - // < Content-Location: - // < - // < - - // < ... - // - // The is the URL to download the chunks from and - // each is the digest of the chunk, and - - // is the range the chunk in the blob. - // - // Ranges may be used directly in Range headers like - // "bytes=-". - // - // The chunksums returned are guaranteed to be contiguous and - // include all bytes of the layer. If the stream is cut short, - // clients should retry. - - chunksumsURL := fmt.Sprintf("%s://%s/v2/%s/%s/chunksums/%s", - scheme, - n.Host(), - n.Namespace(), - n.Model(), - l.Digest, - ) - - req, err := r.newRequest(ctx, "GET", chunksumsURL, nil) + chunk, err := parseChunk(s.Bytes()) if err != nil { - yield(chunksum{}, err) - return + return fmt.Errorf("invalid chunk range for digest %s: %q", d, s.Bytes()) } - res, err := sendRequest(r.client(), req) - if err != nil { - yield(chunksum{}, err) - return + + cs := chunksum{ + URL: blobURL, + Chunk: chunk, + Digest: d, } - defer res.Body.Close() - if res.StatusCode != 200 { - err := fmt.Errorf("chunksums: unexpected status code %d", res.StatusCode) - yield(chunksum{}, err) - return - } - blobURL := res.Header.Get("Content-Location") - - s := bufio.NewScanner(res.Body) - s.Split(bufio.ScanWords) - for { - if !s.Scan() { - if s.Err() != nil { - yield(chunksum{}, s.Err()) - } - return - } - d, err := blob.ParseDigest(s.Bytes()) - if err != nil { - yield(chunksum{}, fmt.Errorf("invalid digest: %q", s.Bytes())) - return - } - - if !s.Scan() { - err := s.Err() - if err == nil { - err = fmt.Errorf("missing chunk range for digest %s", d) - } - yield(chunksum{}, err) - return - } - chunk, err := parseChunk(s.Bytes()) - if err != nil { - yield(chunksum{}, fmt.Errorf("invalid chunk range for digest %s: %q", d, s.Bytes())) - return - } - - cs := chunksum{ - URL: blobURL, - Chunk: chunk, - Digest: d, - } - if !yield(cs, nil) { - return - } + if !fn(cs) { + return nil } } } @@ -1176,8 +1147,8 @@ func splitExtended(s string) (scheme, name, digest string) { return scheme, s, digest } -// parseChunk parses a string in the form "start-end" and returns the Chunk. -func parseChunk[S ~string | ~[]byte](s S) (blob.Chunk, error) { +// parseChunk parses a byte slice in the form "start-end" and returns the Chunk. +func parseChunk(s []byte) (blob.Chunk, error) { startPart, endPart, found := strings.Cut(string(s), "-") if !found { return blob.Chunk{}, fmt.Errorf("chunks: invalid range %q: missing '-'", s) diff --git a/server/internal/client/ollama/trace.go b/server/internal/client/ollama/trace.go index a7cac0d5d..a7e6c3ff7 100644 --- a/server/internal/client/ollama/trace.go +++ b/server/internal/client/ollama/trace.go @@ -27,46 +27,20 @@ type Trace struct { } func (t *Trace) update(l *Layer, n int64, err error) { - if t.Update != nil { + if t != nil && t.Update != nil { t.Update(l, n, err) } } type traceKey struct{} -// WithTrace adds a trace to the context for transfer progress reporting. +// WithTrace attaches a Trace to the context for transfer progress reporting. func WithTrace(ctx context.Context, t *Trace) context.Context { - old := traceFromContext(ctx) - if old == t { - // No change, return the original context. This also prevents - // infinite recursion below, if the caller passes the same - // Trace. - return ctx - } - - // Create a new Trace that wraps the old one, if any. If we used the - // same pointer t, we end up with a recursive structure. - composed := &Trace{ - Update: func(l *Layer, n int64, err error) { - if old != nil { - old.update(l, n, err) - } - t.update(l, n, err) - }, - } - return context.WithValue(ctx, traceKey{}, composed) + return context.WithValue(ctx, traceKey{}, t) } -var emptyTrace = &Trace{} - -// traceFromContext returns the Trace associated with ctx, or an empty Trace if -// none is found. -// -// It never returns nil. +// traceFromContext returns the Trace associated with ctx, or nil if none. func traceFromContext(ctx context.Context) *Trace { t, _ := ctx.Value(traceKey{}).(*Trace) - if t == nil { - return emptyTrace - } return t } diff --git a/server/internal/internal/backoff/backoff.go b/server/internal/internal/backoff/backoff.go index 08b4ed7f9..3dcbf06cf 100644 --- a/server/internal/internal/backoff/backoff.go +++ b/server/internal/internal/backoff/backoff.go @@ -2,44 +2,46 @@ package backoff import ( "context" - "iter" "math/rand/v2" "time" ) -func Loop(ctx context.Context, maxBackoff time.Duration) iter.Seq2[int, error] { - var n int - return func(yield func(int, error) bool) { - var t *time.Timer - for { - if ctx.Err() != nil { - yield(n, ctx.Err()) - return - } +// Retry calls fn repeatedly with exponential backoff until it returns nil, +// a non-retryable error (shouldRetry returns false), or the context is cancelled. +// The shouldRetry function determines if an error is retryable. +// Returns the last error encountered, or nil if fn succeeded. +func Retry(ctx context.Context, maxBackoff time.Duration, shouldRetry func(error) bool, fn func() error) error { + var t *time.Timer + for n := 0; ; n++ { + if err := ctx.Err(); err != nil { + return err + } - if !yield(n, nil) { - return - } + err := fn() + if err == nil { + return nil + } + if !shouldRetry(err) { + return err + } - n++ + // n^2 backoff timer is a little smoother than the + // common choice of 2^n. + d := min(time.Duration(n*n)*10*time.Millisecond, maxBackoff) + // Randomize the delay between 0.5-1.5 x msec, in order + // to prevent accidental "thundering herd" problems. + d = time.Duration(float64(d) * (rand.Float64() + 0.5)) - // n^2 backoff timer is a little smoother than the - // common choice of 2^n. - d := min(time.Duration(n*n)*10*time.Millisecond, maxBackoff) - // Randomize the delay between 0.5-1.5 x msec, in order - // to prevent accidental "thundering herd" problems. - d = time.Duration(float64(d) * (rand.Float64() + 0.5)) - - if t == nil { - t = time.NewTimer(d) - } else { - t.Reset(d) - } - select { - case <-ctx.Done(): - t.Stop() - case <-t.C: - } + if t == nil { + t = time.NewTimer(d) + } else { + t.Reset(d) + } + select { + case <-ctx.Done(): + t.Stop() + return ctx.Err() + case <-t.C: } } } diff --git a/server/internal/internal/backoff/backoff_synctest_test.go b/server/internal/internal/backoff/backoff_synctest_test.go index cf17ce80a..abf7d7d40 100644 --- a/server/internal/internal/backoff/backoff_synctest_test.go +++ b/server/internal/internal/backoff/backoff_synctest_test.go @@ -10,31 +10,70 @@ import ( "time" ) -func TestLoop(t *testing.T) { +func TestRetry(t *testing.T) { synctest.Run(func() { - last := -1 + n := 0 ctx, cancel := context.WithCancel(t.Context()) defer cancel() - for n, err := range Loop(ctx, 100*time.Millisecond) { - if !errors.Is(err, ctx.Err()) { - t.Errorf("err = %v, want nil", err) - } - if err != nil { - break - } - if n != last+1 { - t.Errorf("n = %d, want %d", n, last+1) - } - last = n + err := Retry(ctx, 100*time.Millisecond, func(err error) bool { return true }, func() error { + n++ if n > 5 { cancel() } + return errors.New("keep going") + }) + + if !errors.Is(err, context.Canceled) { + t.Errorf("err = %v, want context.Canceled", err) } - if last != 6 { - t.Errorf("last = %d, want 6", last) + if n != 6 { + t.Errorf("n = %d, want 6", n) + } + }) +} + +func TestRetrySuccess(t *testing.T) { + synctest.Run(func() { + n := 0 + err := Retry(t.Context(), 100*time.Millisecond, func(err error) bool { return true }, func() error { + n++ + if n >= 3 { + return nil // success + } + return errors.New("retry") + }) + + if err != nil { + t.Errorf("err = %v, want nil", err) + } + if n != 3 { + t.Errorf("n = %d, want 3", n) + } + }) +} + +func TestRetryNonRetryable(t *testing.T) { + synctest.Run(func() { + permanent := errors.New("permanent error") + n := 0 + err := Retry(t.Context(), 100*time.Millisecond, func(err error) bool { + return !errors.Is(err, permanent) + }, func() error { + n++ + if n >= 2 { + return permanent + } + return errors.New("retry") + }) + + if !errors.Is(err, permanent) { + t.Errorf("err = %v, want permanent", err) + } + if n != 2 { + t.Errorf("n = %d, want 2", n) } }) } diff --git a/server/internal/internal/backoff/backoff_test.go b/server/internal/internal/backoff/backoff_test.go index f474118f0..ba1562aa7 100644 --- a/server/internal/internal/backoff/backoff_test.go +++ b/server/internal/internal/backoff/backoff_test.go @@ -3,37 +3,46 @@ package backoff import ( + "errors" "testing" "testing/synctest" "time" ) -func TestLoopAllocs(t *testing.T) { +var errRetry = errors.New("retry") + +func TestRetryAllocs(t *testing.T) { for i := range 3 { got := testing.AllocsPerRun(1000, func() { - for tick := range Loop(t.Context(), 1) { + tick := 0 + Retry(t.Context(), 1, func(err error) bool { return true }, func() error { + tick++ if tick >= i { - break + return nil } - } + return errRetry + }) }) want := float64(0) if i > 0 { want = 3 // due to time.NewTimer } if got > want { - t.Errorf("[%d ticks]: allocs = %v, want 0", i, want) + t.Errorf("[%d ticks]: allocs = %v, want <= %v", i, got, want) } } } -func BenchmarkLoop(b *testing.B) { +func BenchmarkRetry(b *testing.B) { ctx := b.Context() synctest.Run(func() { - for n := range Loop(ctx, 100*time.Millisecond) { + n := 0 + Retry(ctx, 100*time.Millisecond, func(err error) bool { return true }, func() error { + n++ if n == b.N { - break + return nil } - } + return errRetry + }) }) } diff --git a/server/internal/registry/server.go b/server/internal/registry/server.go index f62a622a9..ab8528553 100644 --- a/server/internal/registry/server.go +++ b/server/internal/registry/server.go @@ -231,7 +231,7 @@ func (s *Local) handleDelete(_ http.ResponseWriter, r *http.Request) error { if r.Method != "DELETE" { return errMethodNotAllowed } - p, err := decodeUserJSON[*params](r.Body) + p, err := decodeParams(r.Body) if err != nil { return err } @@ -261,7 +261,7 @@ func (s *Local) handlePull(w http.ResponseWriter, r *http.Request) error { return errMethodNotAllowed } - p, err := decodeUserJSON[*params](r.Body) + p, err := decodeParams(r.Body) if err != nil { return err } @@ -293,10 +293,14 @@ func (s *Local) handlePull(w http.ResponseWriter, r *http.Request) error { } } - t := time.NewTicker(1<<63 - 1) // "unstarted" timer + // ticker controls periodic progress flushing. It starts paused (very long + // interval) and is activated by start() once all layers are registered, + // so clients see a complete total before progress begins. + ticker := time.NewTicker(1 << 62) // effectively paused until started + defer ticker.Stop() start := sync.OnceFunc(func() { - flushProgress() // flush initial state - t.Reset(100 * time.Millisecond) + flushProgress() + ticker.Reset(100 * time.Millisecond) }) ctx := ollama.WithTrace(r.Context(), &ollama.Trace{ Update: func(l *ollama.Layer, n int64, err error) { @@ -320,36 +324,21 @@ func (s *Local) handlePull(w http.ResponseWriter, r *http.Request) error { }) }() - // Block flushing progress updates until every - // layer is accounted for. Clients depend on a - // complete model size to calculate progress - // correctly; if they use an incomplete total, - // progress indicators would erratically jump - // as new layers are registered. start() }, }) done := make(chan error, 1) - go func() (err error) { - defer func() { done <- err }() - for _, err := range backoff.Loop(ctx, 3*time.Second) { - if err != nil { - return err - } - err := s.Client.Pull(ctx, p.model()) - if canRetry(err) { - continue - } - return err - } - return nil + go func() { + done <- backoff.Retry(ctx, 3*time.Second, canRetry, func() error { + return s.Client.Pull(ctx, p.model()) + }) }() enc.Encode(progressUpdateJSON{Status: "pulling manifest"}) for { select { - case <-t.C: + case <-ticker.C: flushProgress() case err := <-done: flushProgress() @@ -374,20 +363,13 @@ func (s *Local) handlePull(w http.ResponseWriter, r *http.Request) error { } } -func decodeUserJSON[T any](r io.Reader) (T, error) { - var v T - err := json.NewDecoder(r).Decode(&v) +func decodeParams(r io.Reader) (*params, error) { + var p params + err := json.NewDecoder(r).Decode(&p) if err == nil { - return v, nil + return &p, nil } - var zero T - // Not sure why, but I can't seem to be able to use: - // - // errors.As(err, &json.UnmarshalTypeError{}) - // - // This is working fine in stdlib, so I'm not sure what rules changed - // and why this no longer works here. So, we do it the verbose way. var a *json.UnmarshalTypeError var b *json.SyntaxError if errors.As(err, &a) || errors.As(err, &b) { @@ -396,7 +378,7 @@ func decodeUserJSON[T any](r io.Reader) (T, error) { if errors.Is(err, io.EOF) { err = &serverError{Status: 400, Message: "empty request body", Code: "bad_request"} } - return zero, err + return nil, err } func canRetry(err error) bool { @@ -408,10 +390,8 @@ func canRetry(err error) bool { return oe.Temporary() } s := err.Error() - return cmp.Or( - errors.Is(err, context.DeadlineExceeded), - strings.Contains(s, "unreachable"), - strings.Contains(s, "no route to host"), - strings.Contains(s, "connection reset by peer"), - ) + return errors.Is(err, context.DeadlineExceeded) || + strings.Contains(s, "unreachable") || + strings.Contains(s, "no route to host") || + strings.Contains(s, "connection reset by peer") }