This commit is contained in:
jmorganca 2025-12-20 19:10:35 -08:00
parent c623b256a3
commit f90d968b8b
1 changed files with 17 additions and 43 deletions

View File

@ -132,20 +132,17 @@ func (p *blobDownloadPart) UnmarshalJSON(b []byte) error {
return nil return nil
} }
const numDownloadParts = 16
// Download tuning. Override via environment variables: // Download tuning. Override via environment variables:
// - Memory constrained: reduce OLLAMA_DOWNLOAD_CONCURRENCY // - Memory constrained: reduce OLLAMA_DOWNLOAD_CONCURRENCY
// - Default uses ~64KB memory regardless of concurrency (streams to disk, hashes from page cache) // - Default uses ~64KB memory regardless of concurrency (streams to disk, hashes from page cache)
var ( var (
// downloadPartSize is the size of each download part. // downloadPartSize is the size of each download part.
// Default: 16MB. Override with OLLAMA_DOWNLOAD_PART_SIZE (in MB). // Default: 64MB. Override with OLLAMA_DOWNLOAD_PART_SIZE (in MB).
downloadPartSize = int64(getEnvInt("OLLAMA_DOWNLOAD_PART_SIZE", 32)) * format.MegaByte downloadPartSize = int64(getEnvInt("OLLAMA_DOWNLOAD_PART_SIZE", 64)) * format.MegaByte
// downloadConcurrency limits concurrent part downloads. // downloadConcurrency limits concurrent part downloads.
// Higher = faster on fast connections. Memory stays ~64KB regardless. // Default: 32. Override with OLLAMA_DOWNLOAD_CONCURRENCY.
// Default: 64. Override with OLLAMA_DOWNLOAD_CONCURRENCY. downloadConcurrency = getEnvInt("OLLAMA_DOWNLOAD_CONCURRENCY", 32)
downloadConcurrency = getEnvInt("OLLAMA_DOWNLOAD_CONCURRENCY", 64)
) )
func getEnvInt(key string, defaultVal int) int { func getEnvInt(key string, defaultVal int) int {
@ -301,14 +298,7 @@ func (b *blobDownload) Prepare(ctx context.Context, requestURL *url.URL, opts *r
b.Total, _ = strconv.ParseInt(resp.Header.Get("Content-Length"), 10, 64) b.Total, _ = strconv.ParseInt(resp.Header.Get("Content-Length"), 10, 64)
size := b.Total / numDownloadParts size := downloadPartSize
switch {
case size < downloadPartSize:
size = downloadPartSize
case size > downloadPartSize:
size = downloadPartSize
}
var offset int64 var offset int64
for offset < b.Total { for offset < b.Total {
if offset+size > b.Total { if offset+size > b.Total {
@ -442,8 +432,11 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *regis
g.Go(func() error { g.Go(func() error {
var err error var err error
var slowRetries int
for try := 0; try < maxRetries; try++ { for try := 0; try < maxRetries; try++ {
err = b.downloadChunkToDisk(inner, directURL, file, part, tracker) // After 3 slow retries, stop checking slowness and let it complete
skipSlowCheck := slowRetries >= 3
err = b.downloadChunkToDisk(inner, directURL, file, part, tracker, skipSlowCheck)
switch { switch {
case errors.Is(err, context.Canceled), errors.Is(err, syscall.ENOSPC): case errors.Is(err, context.Canceled), errors.Is(err, syscall.ENOSPC):
return err return err
@ -451,30 +444,9 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *regis
try-- try--
continue continue
case errors.Is(err, errPartSlow): case errors.Is(err, errPartSlow):
// Race: start a second download, first to complete wins // Kill slow request, retry immediately (stays within concurrency limit)
raceCtx, raceCancel := context.WithCancel(inner) slowRetries++
var once sync.Once try--
var raceErr error
done := make(chan struct{})
for r := 0; r < 2; r++ {
go func() {
e := b.downloadChunkToDisk(raceCtx, directURL, file, part, tracker)
once.Do(func() {
raceErr = e
raceCancel()
close(done)
})
}()
}
<-done
if raceErr == nil {
sh.MarkComplete(part.N)
return nil
}
err = raceErr
// Fall through to retry logic
continue continue
case err != nil: case err != nil:
sleep := time.Second * time.Duration(math.Pow(2, float64(try))) sleep := time.Second * time.Duration(math.Pow(2, float64(try)))
@ -556,7 +528,8 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *regis
// downloadChunkToDisk streams a part directly to disk at its offset. // downloadChunkToDisk streams a part directly to disk at its offset.
// Memory: ~32KB (read buffer only). // Memory: ~32KB (read buffer only).
func (b *blobDownload) downloadChunkToDisk(ctx context.Context, requestURL *url.URL, file *os.File, part *blobDownloadPart, tracker *speedTracker) error { // If skipSlowCheck is true, don't flag slow parts (used after repeated slow retries).
func (b *blobDownload) downloadChunkToDisk(ctx context.Context, requestURL *url.URL, file *os.File, part *blobDownloadPart, tracker *speedTracker, skipSlowCheck bool) error {
g, ctx := errgroup.WithContext(ctx) g, ctx := errgroup.WithContext(ctx)
startTime := time.Now() startTime := time.Now()
var bytesAtLastCheck atomic.Int64 var bytesAtLastCheck atomic.Int64
@ -638,14 +611,15 @@ func (b *blobDownload) downloadChunkToDisk(ctx context.Context, requestURL *url.
} }
// Check for slow speed after 5+ seconds (only for multi-part downloads) // Check for slow speed after 5+ seconds (only for multi-part downloads)
// Skip if we've already retried for slowness too many times
elapsed := time.Since(startTime).Seconds() elapsed := time.Since(startTime).Seconds()
if elapsed >= 5 && currentBytes > 0 && len(b.Parts) > 1 { if !skipSlowCheck && elapsed >= 5 && currentBytes > 0 && len(b.Parts) > 1 {
currentSpeed := float64(currentBytes) / elapsed currentSpeed := float64(currentBytes) / elapsed
median := tracker.Median() median := tracker.Median()
// If we're below 10% of median speed, flag as slow // If we're below 10% of median speed, flag as slow
if median > 0 && currentSpeed < median*0.1 { if median > 0 && currentSpeed < median*0.1 {
slog.Info(fmt.Sprintf("%s part %d slow (%.0f KB/s vs median %.0f KB/s); racing", slog.Info(fmt.Sprintf("%s part %d slow (%.0f KB/s vs median %.0f KB/s); retrying",
b.Digest[7:19], part.N, currentSpeed/1024, median/1024)) b.Digest[7:19], part.N, currentSpeed/1024, median/1024))
return errPartSlow return errPartSlow
} }