From 6e00a0c89ae3519c73632463f99d9010418dd6ca Mon Sep 17 00:00:00 2001
From: jmorganca
Date: Sat, 20 Dec 2025 18:02:52 -0800
Subject: [PATCH] speed tracker

---
 server/download.go | 110 ++++++++++++++++++++++++++++++++++++++++++++-
 1 file changed, 108 insertions(+), 2 deletions(-)

diff --git a/server/download.go b/server/download.go
index 25e6fc697..964115a81 100644
--- a/server/download.go
+++ b/server/download.go
@@ -33,9 +33,45 @@ const maxRetries = 6
 var (
 	errMaxRetriesExceeded   = errors.New("max retries exceeded")
 	errPartStalled          = errors.New("part stalled")
+	errPartSlow             = errors.New("part slow, racing")
 	errMaxRedirectsExceeded = errors.New("maximum redirects exceeded (10) for directURL")
 )
 
+// speedTracker tracks download speeds and computes a rolling median.
+type speedTracker struct {
+	mu     sync.Mutex
+	speeds []float64 // bytes per second
+}
+
+func (s *speedTracker) Record(bytesPerSec float64) {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+	s.speeds = append(s.speeds, bytesPerSec)
+	// Keep only the last 100 samples
+	if len(s.speeds) > 100 {
+		s.speeds = s.speeds[1:]
+	}
+}
+
+func (s *speedTracker) Median() float64 {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+	if len(s.speeds) < 3 {
+		return 0 // not enough data
+	}
+	// Simple median: sort a copy and take the middle element
+	sorted := make([]float64, len(s.speeds))
+	copy(sorted, s.speeds)
+	for i := range sorted {
+		for j := i + 1; j < len(sorted); j++ {
+			if sorted[j] < sorted[i] {
+				sorted[i], sorted[j] = sorted[j], sorted[i]
+			}
+		}
+	}
+	return sorted[len(sorted)/2]
+}
+
 var blobDownloadManager sync.Map
 
 type blobDownload struct {
@@ -386,6 +422,7 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *regis
 	// The hasher follows behind the downloaders, reading recently-written
 	// data from OS page cache (RAM) rather than disk.
 	sh := newStreamHasher(file, b.Parts, b.Total)
+	tracker := &speedTracker{}
 
 	// Start hasher goroutine
 	hashDone := make(chan struct{})
@@ -406,13 +443,39 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *regis
 		g.Go(func() error {
 			var err error
 			for try := 0; try < maxRetries; try++ {
-				err = b.downloadChunkToDisk(inner, directURL, file, part)
+				err = b.downloadChunkToDisk(inner, directURL, file, part, tracker)
 				switch {
 				case errors.Is(err, context.Canceled), errors.Is(err, syscall.ENOSPC):
 					return err
 				case errors.Is(err, errPartStalled):
 					try--
 					continue
+				case errors.Is(err, errPartSlow):
+					// Race two fresh downloads of this part; the first to return wins and cancels the other
+					raceCtx, raceCancel := context.WithCancel(inner)
+					var once sync.Once
+					var raceErr error
+
+					done := make(chan struct{})
+					for r := 0; r < 2; r++ {
+						go func() {
+							e := b.downloadChunkToDisk(raceCtx, directURL, file, part, tracker)
+							once.Do(func() {
+								raceErr = e
+								raceCancel()
+								close(done)
+							})
+						}()
+					}
+					<-done
+
+					if raceErr == nil {
+						sh.MarkComplete(part.N)
+						return nil
+					}
+					err = raceErr
+					// Count the failed race as a normal attempt and retry
+					continue
 				case err != nil:
 					sleep := time.Second * time.Duration(math.Pow(2, float64(try)))
 					slog.Info(fmt.Sprintf("%s part %d attempt %d failed: %v, retrying in %s", b.Digest[7:19], part.N, try, err, sleep))
@@ -482,8 +545,10 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *regis
 
 // downloadChunkToDisk streams a part directly to disk at its offset.
 // Memory: ~32KB (read buffer only).
-func (b *blobDownload) downloadChunkToDisk(ctx context.Context, requestURL *url.URL, file *os.File, part *blobDownloadPart) error {
+func (b *blobDownload) downloadChunkToDisk(ctx context.Context, requestURL *url.URL, file *os.File, part *blobDownloadPart, tracker *speedTracker) error {
 	g, ctx := errgroup.WithContext(ctx)
+	startTime := time.Now()
+	var bytesAtLastCheck atomic.Int64
 
 	g.Go(func() error {
 		req, err := http.NewRequestWithContext(ctx, http.MethodGet, requestURL.String(), nil)
@@ -509,6 +574,7 @@ func (b *blobDownload) downloadChunkToDisk(ctx context.Context, requestURL *url.
 			}
 			written += int64(n)
 			b.Completed.Add(int64(n))
+			bytesAtLastCheck.Store(written)
 
 			part.lastUpdatedMu.Lock()
 			part.lastUpdated = time.Now()
@@ -523,6 +589,12 @@ func (b *blobDownload) downloadChunkToDisk(ctx context.Context, requestURL *url.
 			}
 		}
 
+		// Record this part's average speed
+		elapsed := time.Since(startTime).Seconds()
+		if elapsed > 0 {
+			tracker.Record(float64(part.Size) / elapsed)
+		}
+
 		part.Completed.Store(part.Size)
 		return b.writePart(part.Name(), part)
 	})
@@ -530,12 +602,19 @@ func (b *blobDownload) downloadChunkToDisk(ctx context.Context, requestURL *url.
 	g.Go(func() error {
 		ticker := time.NewTicker(time.Second)
 		defer ticker.Stop()
+		var lastBytes int64
+		checksWithoutProgress := 0
+
 		for {
 			select {
 			case <-ticker.C:
 				if part.Completed.Load() >= part.Size {
 					return nil
 				}
+
+				currentBytes := bytesAtLastCheck.Load()
+
+				// Check for a complete stall (no progress for 30 seconds)
 				part.lastUpdatedMu.Lock()
 				lastUpdated := part.lastUpdated
 				part.lastUpdatedMu.Unlock()
@@ -546,6 +625,33 @@ func (b *blobDownload) downloadChunkToDisk(ctx context.Context, requestURL *url.
 					part.lastUpdatedMu.Unlock()
 					return errPartStalled
 				}
+
+				// Check for slow speed once 5+ seconds have elapsed
+				elapsed := time.Since(startTime).Seconds()
+				if elapsed >= 5 && currentBytes > 0 {
+					currentSpeed := float64(currentBytes) / elapsed
+					median := tracker.Median()
+
+					// Below 10% of the median speed: flag this part as slow
+					if median > 0 && currentSpeed < median*0.1 {
+						slog.Info(fmt.Sprintf("%s part %d slow (%.0f KB/s vs median %.0f KB/s); racing",
+							b.Digest[7:19], part.N, currentSpeed/1024, median/1024))
+						return errPartSlow
+					}
+				}
+
+				// Secondary stall check: no bytes written for 10 consecutive ticks
+				if currentBytes == lastBytes {
+					checksWithoutProgress++
+					if checksWithoutProgress >= 10 {
+						slog.Info(fmt.Sprintf("%s part %d no progress for 10s; retrying", b.Digest[7:19], part.N))
+						return errPartStalled
+					}
+				} else {
+					checksWithoutProgress = 0
+				}
+				lastBytes = currentBytes
+
 			case <-ctx.Done():
 				return ctx.Err()
 			}
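
A quick way to sanity-check the rolling-median logic is a small unit test. The sketch below is not part of the patch: the file name (say, server/speed_tracker_test.go) and the sample values are assumptions for illustration, but it exercises Record, Median, and the 10%-of-median threshold the way the ticker goroutine in downloadChunkToDisk applies it.

package server

import "testing"

func TestSpeedTrackerMedian(t *testing.T) {
	tracker := &speedTracker{}

	// With fewer than 3 samples Median returns 0, so no part can be
	// flagged slow before there is meaningful data.
	tracker.Record(1e6)
	tracker.Record(2e6)
	if got := tracker.Median(); got != 0 {
		t.Fatalf("expected 0 with <3 samples, got %f", got)
	}

	// Nine healthy parts near 1 MB/s and one straggler at 50 KB/s.
	for i := 0; i < 9; i++ {
		tracker.Record(1e6)
	}
	tracker.Record(50e3)

	median := tracker.Median()
	if median != 1e6 {
		t.Fatalf("expected median 1e6, got %f", median)
	}

	// The straggler sits below 10% of the median, so the ticker goroutine
	// in downloadChunkToDisk would return errPartSlow for it.
	if !(50e3 < median*0.1) {
		t.Fatal("expected 50 KB/s to be flagged slow against a 1 MB/s median")
	}
}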
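
The errPartSlow branch's race can also be read in isolation. Below is a self-contained sketch of the same pattern (two attempts race, sync.Once takes the first result, the shared context cancels the loser); fakeDownload and the durations are stand-ins invented here, not code from this change.

package main

import (
	"context"
	"fmt"
	"sync"
	"time"
)

// fakeDownload stands in for downloadChunkToDisk: it "finishes" after d
// unless the race context is canceled first.
func fakeDownload(ctx context.Context, d time.Duration) error {
	select {
	case <-time.After(d):
		return nil
	case <-ctx.Done():
		return ctx.Err()
	}
}

func main() {
	raceCtx, raceCancel := context.WithCancel(context.Background())
	defer raceCancel()

	var once sync.Once
	var raceErr error
	done := make(chan struct{})

	// Same shape as the errPartSlow branch: two attempts race; the first
	// to return records its result, cancels the other, and unblocks done.
	for _, d := range []time.Duration{300 * time.Millisecond, 50 * time.Millisecond} {
		go func(d time.Duration) {
			e := fakeDownload(raceCtx, d)
			once.Do(func() {
				raceErr = e
				raceCancel()
				close(done)
			})
		}(d)
	}
	<-done

	// Prints "winner err: <nil>" — the 50ms attempt wins the race.
	fmt.Println("winner err:", raceErr)
}

Note that, as in the patch, the winner is simply the first goroutine to return: if the faster attempt fails immediately, its error wins the race even though the slower attempt might still have succeeded.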