diff --git a/server/download.go b/server/download.go index 382fdaa9f..3996a8334 100644 --- a/server/download.go +++ b/server/download.go @@ -34,7 +34,7 @@ const maxRetries = 6 var ( errMaxRetriesExceeded = errors.New("max retries exceeded") errPartStalled = errors.New("part stalled") - errPartSlow = errors.New("part slow, racing") + errPartSlow = errors.New("part too slow") errMaxRedirectsExceeded = errors.New("maximum redirects exceeded (10) for directURL") ) @@ -48,8 +48,8 @@ func (s *speedTracker) Record(bytesPerSec float64) { s.mu.Lock() defer s.mu.Unlock() s.speeds = append(s.speeds, bytesPerSec) - // Keep last 100 samples - if len(s.speeds) > 100 { + // Keep last 30 samples (flushes stale speeds faster when conditions change) + if len(s.speeds) > 30 { s.speeds = s.speeds[1:] } } @@ -57,8 +57,8 @@ func (s *speedTracker) Record(bytesPerSec float64) { func (s *speedTracker) Median() float64 { s.mu.Lock() defer s.mu.Unlock() - if len(s.speeds) < 3 { - return 0 // not enough data + if len(s.speeds) < 10 { + return 0 // not enough data for reliable median } sorted := slices.Clone(s.speeds) @@ -90,9 +90,6 @@ type blobDownloadPart struct { Size int64 Completed atomic.Int64 - lastUpdatedMu sync.Mutex - lastUpdated time.Time - *blobDownload `json:"-"` } @@ -128,7 +125,7 @@ func (p *blobDownloadPart) UnmarshalJSON(b []byte) error { var ( downloadPartSize = int64(envInt("OLLAMA_DOWNLOAD_PART_SIZE", 64)) * format.MegaByte - downloadConcurrency = envInt("OLLAMA_DOWNLOAD_CONCURRENCY", 48) + downloadConcurrency = envInt("OLLAMA_DOWNLOAD_CONCURRENCY", 32) ) func envInt(key string, defaultVal int) int { @@ -142,7 +139,7 @@ func envInt(key string, defaultVal int) int { // streamHasher reads a file sequentially and hashes it as chunks complete. // Memory usage: ~64KB (just the read buffer), regardless of file size or concurrency. -// Works by reading from OS page cache - data just written is still in RAM. +// Works by trying to read from OS page cache - data just written should still be in RAM. type streamHasher struct { file *os.File hasher hash.Hash @@ -169,8 +166,8 @@ func newStreamHasher(file *os.File, parts []*blobDownloadPart, total int64) *str return h } -// MarkComplete signals that a part has been written to disk. -func (h *streamHasher) MarkComplete(partIndex int) { +// Done signals that a part has been written to disk. +func (h *streamHasher) Done(partIndex int) { h.mu.Lock() h.completed[partIndex] = true h.cond.Broadcast() @@ -194,7 +191,7 @@ func (h *streamHasher) Run() { } h.mu.Unlock() - // Read and hash this part (from page cache) + // Read and hash part remaining := part.Size for remaining > 0 { n := int64(len(buf)) @@ -250,9 +247,6 @@ func (p *blobDownloadPart) Name() string { func (p *blobDownloadPart) Write(b []byte) (n int, err error) { n = len(b) p.blobDownload.Completed.Add(int64(n)) - p.lastUpdatedMu.Lock() - p.lastUpdated = time.Now() - p.lastUpdatedMu.Unlock() return n, nil } @@ -410,7 +404,7 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *regis for i := range b.Parts { part := b.Parts[i] if part.Completed.Load() == part.Size { - sh.MarkComplete(part.N) + sh.Done(part.N) continue } @@ -420,7 +414,7 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *regis for try := 0; try < maxRetries; try++ { // After 3 slow retries, stop checking slowness and let it complete skipSlowCheck := slowRetries >= 3 - err = b.downloadChunkToDisk(inner, directURL, file, part, tracker, skipSlowCheck) + err = b.downloadChunk(inner, directURL, file, part, tracker, skipSlowCheck) switch { case errors.Is(err, context.Canceled), errors.Is(err, syscall.ENOSPC): return err @@ -438,7 +432,7 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *regis time.Sleep(sleep) continue default: - sh.MarkComplete(part.N) + sh.Done(part.N) return nil } } @@ -480,9 +474,9 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *regis return nil } -// downloadChunkToDisk streams a part directly to disk at its offset. +// downloadChunk streams a part directly to disk at its offset. // If skipSlowCheck is true, don't flag slow parts (used after repeated slow retries). -func (b *blobDownload) downloadChunkToDisk(ctx context.Context, requestURL *url.URL, file *os.File, part *blobDownloadPart, tracker *speedTracker, skipSlowCheck bool) error { +func (b *blobDownload) downloadChunk(ctx context.Context, requestURL *url.URL, file *os.File, part *blobDownloadPart, tracker *speedTracker, skipSlowCheck bool) error { g, ctx := errgroup.WithContext(ctx) startTime := time.Now() var bytesAtLastCheck atomic.Int64 @@ -512,10 +506,6 @@ func (b *blobDownload) downloadChunkToDisk(ctx context.Context, requestURL *url. written += int64(n) b.Completed.Add(int64(n)) bytesAtLastCheck.Store(written) - - part.lastUpdatedMu.Lock() - part.lastUpdated = time.Now() - part.lastUpdatedMu.Unlock() } if err == io.EOF { break @@ -663,21 +653,21 @@ type downloadOpts struct { } // downloadBlob downloads a blob from the registry and stores it in the blobs directory -func downloadBlob(ctx context.Context, opts downloadOpts) (cacheHit bool, _ error) { +func downloadBlob(ctx context.Context, opts downloadOpts) error { if opts.digest == "" { - return false, fmt.Errorf(("%s: %s"), opts.mp.GetNamespaceRepository(), "digest is empty") + return fmt.Errorf(("%s: %s"), opts.mp.GetNamespaceRepository(), "digest is empty") } fp, err := GetBlobsPath(opts.digest) if err != nil { - return false, err + return err } fi, err := os.Stat(fp) switch { case errors.Is(err, os.ErrNotExist): case err != nil: - return false, err + return err default: opts.fn(api.ProgressResponse{ Status: fmt.Sprintf("pulling %s", opts.digest[7:19]), @@ -686,7 +676,7 @@ func downloadBlob(ctx context.Context, opts downloadOpts) (cacheHit bool, _ erro Completed: fi.Size(), }) - return true, nil + return nil } data, ok := blobDownloadManager.LoadOrStore(opts.digest, &blobDownload{Name: fp, Digest: opts.digest}) @@ -696,12 +686,12 @@ func downloadBlob(ctx context.Context, opts downloadOpts) (cacheHit bool, _ erro requestURL = requestURL.JoinPath("v2", opts.mp.GetNamespaceRepository(), "blobs", opts.digest) if err := download.Prepare(ctx, requestURL, opts.regOpts); err != nil { blobDownloadManager.Delete(opts.digest) - return false, err + return err } //nolint:contextcheck go download.Run(context.Background(), requestURL, opts.regOpts) } - return false, download.Wait(ctx, opts.fn) + return download.Wait(ctx, opts.fn) } diff --git a/server/download_test.go b/server/download_test.go index d45e1113a..56842f237 100644 --- a/server/download_test.go +++ b/server/download_test.go @@ -12,40 +12,41 @@ import ( func TestSpeedTracker_Median(t *testing.T) { s := &speedTracker{} - // Less than 3 samples returns 0 - s.Record(100) - s.Record(200) + // Less than 10 samples returns 0 + for i := 0; i < 9; i++ { + s.Record(float64(100 + i*10)) + } if got := s.Median(); got != 0 { - t.Errorf("expected 0 with < 3 samples, got %f", got) + t.Errorf("expected 0 with < 10 samples, got %f", got) } - // With 3+ samples, returns median - s.Record(300) - // Samples: [100, 200, 300] -> median = 200 - if got := s.Median(); got != 200 { - t.Errorf("expected median 200, got %f", got) + // With 10+ samples, returns median + s.Record(190) + // Samples: [100, 110, 120, 130, 140, 150, 160, 170, 180, 190] -> median = 150 + if got := s.Median(); got != 150 { + t.Errorf("expected median 150, got %f", got) } // Add more samples s.Record(50) - s.Record(250) - // Samples: [100, 200, 300, 50, 250] sorted = [50, 100, 200, 250, 300] -> median = 200 - if got := s.Median(); got != 200 { - t.Errorf("expected median 200, got %f", got) + // Samples: [100, 110, 120, 130, 140, 150, 160, 170, 180, 190, 50] + // sorted = [50, 100, 110, 120, 130, 140, 150, 160, 170, 180, 190] -> median = 140 + if got := s.Median(); got != 140 { + t.Errorf("expected median 140, got %f", got) } } func TestSpeedTracker_RollingWindow(t *testing.T) { s := &speedTracker{} - // Add 105 samples (should keep only last 100) - for i := 0; i < 105; i++ { + // Add 35 samples (should keep only last 30) + for i := 0; i < 35; i++ { s.Record(float64(i)) } s.mu.Lock() - if len(s.speeds) != 100 { - t.Errorf("expected 100 samples, got %d", len(s.speeds)) + if len(s.speeds) != 30 { + t.Errorf("expected 30 samples, got %d", len(s.speeds)) } // First sample should be 5 (0-4 were dropped) if s.speeds[0] != 5 { @@ -99,7 +100,7 @@ func TestStreamHasher_Sequential(t *testing.T) { sh := newStreamHasher(f, parts, int64(len(data))) // Mark complete and run - sh.MarkComplete(0) + sh.Done(0) done := make(chan struct{}) go func() { @@ -150,9 +151,9 @@ func TestStreamHasher_OutOfOrderCompletion(t *testing.T) { }() // Mark parts complete out of order: 2, 0, 1 - sh.MarkComplete(2) - sh.MarkComplete(0) // This should trigger hashing of part 0 - sh.MarkComplete(1) // This should trigger hashing of parts 1 and 2 + sh.Done(2) + sh.Done(0) // This should trigger hashing of part 0 + sh.Done(1) // This should trigger hashing of parts 1 and 2 <-done @@ -228,7 +229,7 @@ func TestStreamHasher_HashedProgress(t *testing.T) { }() // Complete part 0 - sh.MarkComplete(0) + sh.Done(0) // Give hasher time to process for i := 0; i < 100; i++ { @@ -238,7 +239,7 @@ func TestStreamHasher_HashedProgress(t *testing.T) { } // Complete part 1 - sh.MarkComplete(1) + sh.Done(1) <-done if got := sh.Hashed(); got != 1000 { @@ -291,7 +292,7 @@ func BenchmarkStreamHasher(b *testing.B) { for i := 0; i < b.N; i++ { sh := newStreamHasher(f, parts, int64(size)) - sh.MarkComplete(0) + sh.Done(0) done := make(chan struct{}) go func() { diff --git a/server/images.go b/server/images.go index d3de232b1..4cec2233e 100644 --- a/server/images.go +++ b/server/images.go @@ -621,22 +621,18 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu } for _, layer := range layers { - _, err := downloadBlob(ctx, downloadOpts{ + if err := downloadBlob(ctx, downloadOpts{ mp: mp, digest: layer.Digest, regOpts: regOpts, fn: fn, - }) - if err != nil { + }); err != nil { return err } delete(deleteMap, layer.Digest) } delete(deleteMap, manifest.Config.Digest) - // Note: Digest verification now happens inline during download in blobDownload.run() - // via the orderedWriter, so no separate verification pass is needed. - fn(api.ProgressResponse{Status: "writing manifest"}) manifestJSON, err := json.Marshal(manifest) @@ -839,25 +835,3 @@ func parseRegistryChallenge(authStr string) registryChallenge { Scope: getValue(authStr, "scope"), } } - -var errDigestMismatch = errors.New("digest mismatch, file must be downloaded again") - -func verifyBlob(digest string) error { - fp, err := GetBlobsPath(digest) - if err != nil { - return err - } - - f, err := os.Open(fp) - if err != nil { - return err - } - defer f.Close() - - fileDigest, _ := GetSHA256Digest(f) - if digest != fileDigest { - return fmt.Errorf("%w: want %s, got %s", errDigestMismatch, digest, fileDigest) - } - - return nil -}