speed tracker

jmorganca 2025-12-20 18:02:52 -08:00
parent 55b1ee2557
commit 6e00a0c89a
1 changed file with 108 additions and 2 deletions


@@ -33,9 +33,45 @@ const maxRetries = 6

var (
	errMaxRetriesExceeded   = errors.New("max retries exceeded")
	errPartStalled          = errors.New("part stalled")
	errPartSlow             = errors.New("part slow, racing")
	errMaxRedirectsExceeded = errors.New("maximum redirects exceeded (10) for directURL")
)

// speedTracker tracks download speeds and computes a rolling median.
type speedTracker struct {
	mu     sync.Mutex
	speeds []float64 // bytes per second
}

func (s *speedTracker) Record(bytesPerSec float64) {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.speeds = append(s.speeds, bytesPerSec)
	// Keep the last 100 samples
	if len(s.speeds) > 100 {
		s.speeds = s.speeds[1:]
	}
}

func (s *speedTracker) Median() float64 {
	s.mu.Lock()
	defer s.mu.Unlock()
	if len(s.speeds) < 3 {
		return 0 // not enough data
	}
	// Simple median: sort a copy and take the middle element
	sorted := make([]float64, len(s.speeds))
	copy(sorted, s.speeds)
	for i := range sorted {
		for j := i + 1; j < len(sorted); j++ {
			if sorted[j] < sorted[i] {
				sorted[i], sorted[j] = sorted[j], sorted[i]
			}
		}
	}
	return sorted[len(sorted)/2]
}

var blobDownloadManager sync.Map

type blobDownload struct {
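The median is what the per-part watchdog later compares against; with fewer than three samples it deliberately reports 0, so nothing gets flagged before there is meaningful data. A test-style sketch of that behavior (hypothetical file, assuming it sits in the same package as the type above):

	package server // assumption: same package as the diff above

	import "testing"

	func TestSpeedTrackerMedian(t *testing.T) {
		st := &speedTracker{}

		// Fewer than 3 samples: not enough data, Median returns 0
		st.Record(1e6)
		st.Record(2e6)
		if got := st.Median(); got != 0 {
			t.Fatalf("want 0 with <3 samples, got %f", got)
		}

		// With four samples the upper-middle element is returned,
		// so a single fast outlier barely moves the result
		st.Record(3e6)
		st.Record(100e6)
		if got := st.Median(); got != 3e6 {
			t.Fatalf("want 3e6, got %f", got)
		}
	}

Note that for an even sample count this Median picks the upper of the two middle values rather than averaging them, which is plenty for a coarse 10% threshold.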
@@ -386,6 +422,7 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *regis

	// The hasher follows behind the downloaders, reading recently-written
	// data from OS page cache (RAM) rather than disk.
	sh := newStreamHasher(file, b.Parts, b.Total)
	tracker := &speedTracker{}

	// Start hasher goroutine
	hashDone := make(chan struct{})
@@ -406,13 +443,39 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *regis
		g.Go(func() error {
			var err error
			for try := 0; try < maxRetries; try++ {
-				err = b.downloadChunkToDisk(inner, directURL, file, part)
+				err = b.downloadChunkToDisk(inner, directURL, file, part, tracker)
				switch {
				case errors.Is(err, context.Canceled), errors.Is(err, syscall.ENOSPC):
					return err
				case errors.Is(err, errPartStalled):
					try--
					continue
				case errors.Is(err, errPartSlow):
					// Race: launch two fresh attempts; the first to
					// finish wins and cancels the other
					raceCtx, raceCancel := context.WithCancel(inner)
					var once sync.Once
					var raceErr error
					done := make(chan struct{})
					for r := 0; r < 2; r++ {
						go func() {
							e := b.downloadChunkToDisk(raceCtx, directURL, file, part, tracker)
							once.Do(func() {
								raceErr = e
								raceCancel()
								close(done)
							})
						}()
					}
					<-done
					if raceErr == nil {
						sh.MarkComplete(part.N)
						return nil
					}
					err = raceErr
					// Fall through to the retry logic below
					continue
				case err != nil:
					sleep := time.Second * time.Duration(math.Pow(2, float64(try)))
					slog.Info(fmt.Sprintf("%s part %d attempt %d failed: %v, retrying in %s", b.Digest[7:19], part.N, try, err, sleep))
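The race body above is a compact use of a common Go pattern: the attempts share a cancellable context, and a sync.Once guards the single winner slot. A standalone sketch of just the pattern; all names here are hypothetical, not from the commit:

	package main

	import (
		"context"
		"fmt"
		"math/rand"
		"sync"
		"time"
	)

	// attempt stands in for downloadChunkToDisk: it either finishes its
	// simulated work or gets canceled because the other racer won
	func attempt(ctx context.Context) error {
		select {
		case <-time.After(time.Duration(rand.Intn(100)) * time.Millisecond):
			return nil
		case <-ctx.Done():
			return ctx.Err()
		}
	}

	func race(ctx context.Context) error {
		raceCtx, cancel := context.WithCancel(ctx)
		var once sync.Once
		var raceErr error
		done := make(chan struct{})
		for r := 0; r < 2; r++ {
			go func() {
				err := attempt(raceCtx)
				once.Do(func() {
					raceErr = err // first finisher wins, success or not
					cancel()
					close(done)
				})
			}()
		}
		<-done
		return raceErr
	}

	func main() {
		fmt.Println(race(context.Background())) // usually <nil>
	}

One semantic worth noting in the diff: the first attempt to return wins even if it returned an error, in which case the race cancels a sibling that might still have succeeded; the outer loop then falls through to the normal backoff-and-retry path.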
@@ -482,8 +545,10 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *regis

// downloadChunkToDisk streams a part directly to disk at its offset.
// Memory: ~32KB (read buffer only).
-func (b *blobDownload) downloadChunkToDisk(ctx context.Context, requestURL *url.URL, file *os.File, part *blobDownloadPart) error {
+func (b *blobDownload) downloadChunkToDisk(ctx context.Context, requestURL *url.URL, file *os.File, part *blobDownloadPart, tracker *speedTracker) error {
	g, ctx := errgroup.WithContext(ctx)
	startTime := time.Now()
	var bytesAtLastCheck atomic.Int64

	g.Go(func() error {
		req, err := http.NewRequestWithContext(ctx, http.MethodGet, requestURL.String(), nil)
@@ -509,6 +574,7 @@ func (b *blobDownload) downloadChunkToDisk(ctx context.Context, requestURL *url.
			}
			written += int64(n)
			b.Completed.Add(int64(n))
			bytesAtLastCheck.Store(written)

			part.lastUpdatedMu.Lock()
			part.lastUpdated = time.Now()
@@ -523,6 +589,12 @@ func (b *blobDownload) downloadChunkToDisk(ctx context.Context, requestURL *url.
			}
		}

		// Record the average speed for this part
		elapsed := time.Since(startTime).Seconds()
		if elapsed > 0 {
			tracker.Record(float64(part.Size) / elapsed)
		}

		part.Completed.Store(part.Size)
		return b.writePart(part.Name(), part)
	})
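For scale: Record is called once per completed part with the part's whole-life average, not an instantaneous rate. A quick arithmetic check, with illustrative numbers not taken from the commit:

	package main

	import "fmt"

	func main() {
		// A hypothetical 96 MiB part that took 12.3 s to download
		partSize := int64(96 << 20) // 100663296 bytes
		elapsed := 12.3             // seconds
		bytesPerSec := float64(partSize) / elapsed
		fmt.Printf("recorded: %.1f MiB/s\n", bytesPerSec/(1<<20)) // recorded: 7.8 MiB/s
	}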
@@ -530,12 +602,19 @@ func (b *blobDownload) downloadChunkToDisk(ctx context.Context, requestURL *url.

	g.Go(func() error {
		ticker := time.NewTicker(time.Second)
		defer ticker.Stop()
		var lastBytes int64
		checksWithoutProgress := 0
		for {
			select {
			case <-ticker.C:
				if part.Completed.Load() >= part.Size {
					return nil
				}
				currentBytes := bytesAtLastCheck.Load()

				// Check for a complete stall (30 seconds with no progress)
				part.lastUpdatedMu.Lock()
				lastUpdated := part.lastUpdated
				part.lastUpdatedMu.Unlock()
@@ -546,6 +625,33 @@ func (b *blobDownload) downloadChunkToDisk(ctx context.Context, requestURL *url.
					part.lastUpdatedMu.Unlock()
					return errPartStalled
				}

				// After 5+ seconds, compare this part's speed to the
				// rolling median across parts
				elapsed := time.Since(startTime).Seconds()
				if elapsed >= 5 && currentBytes > 0 {
					currentSpeed := float64(currentBytes) / elapsed
					median := tracker.Median()
					// Below 10% of the median speed: flag the part as slow
					if median > 0 && currentSpeed < median*0.1 {
						slog.Info(fmt.Sprintf("%s part %d slow (%.0f KB/s vs median %.0f KB/s); racing",
							b.Digest[7:19], part.N, currentSpeed/1024, median/1024))
						return errPartSlow
					}
				}

				// Catch stalls faster than the 30s check: no new bytes
				// for 10 consecutive one-second ticks
				if currentBytes == lastBytes {
					checksWithoutProgress++
					if checksWithoutProgress >= 10 {
						slog.Info(fmt.Sprintf("%s part %d no progress for 10s; retrying", b.Digest[7:19], part.N))
						return errPartStalled
					}
				} else {
					checksWithoutProgress = 0
				}
				lastBytes = currentBytes
			case <-ctx.Done():
				return ctx.Err()
			}
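The flagging rule itself is pure arithmetic and easy to sanity-check in isolation. A sketch, with isSlow as a hypothetical extraction of the condition in the ticker loop above:

	package main

	import "fmt"

	// isSlow mirrors the check above: race a part when its average speed
	// falls below 10% of the rolling median, and never flag while the
	// tracker has too few samples to report a median
	func isSlow(currentSpeed, median float64) bool {
		return median > 0 && currentSpeed < median*0.1
	}

	func main() {
		median := 5.0 * (1 << 20) // healthy parts averaging 5 MiB/s
		fmt.Println(isSlow(400*1024, median))  // true: 400 KiB/s < 512 KiB/s cutoff
		fmt.Println(isSlow(1*(1<<20), median)) // false: 1 MiB/s clears the cutoff
		fmt.Println(isSlow(200*1024, 0))       // false: <3 samples, median is 0
	}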