diff --git a/server/download.go b/server/download.go
index 81b4054c8..68836a824 100644
--- a/server/download.go
+++ b/server/download.go
@@ -132,20 +132,17 @@ func (p *blobDownloadPart) UnmarshalJSON(b []byte) error {
 	return nil
 }
 
-const numDownloadParts = 16
-
 // Download tuning. Override via environment variables:
 // - Memory constrained: reduce OLLAMA_DOWNLOAD_CONCURRENCY
 // - Default uses ~64KB memory regardless of concurrency (streams to disk, hashes from page cache)
 var (
 	// downloadPartSize is the size of each download part.
-	// Default: 16MB. Override with OLLAMA_DOWNLOAD_PART_SIZE (in MB).
-	downloadPartSize = int64(getEnvInt("OLLAMA_DOWNLOAD_PART_SIZE", 32)) * format.MegaByte
+	// Default: 64MB. Override with OLLAMA_DOWNLOAD_PART_SIZE (in MB).
+	downloadPartSize = int64(getEnvInt("OLLAMA_DOWNLOAD_PART_SIZE", 64)) * format.MegaByte
 
 	// downloadConcurrency limits concurrent part downloads.
-	// Higher = faster on fast connections. Memory stays ~64KB regardless.
-	// Default: 64. Override with OLLAMA_DOWNLOAD_CONCURRENCY.
-	downloadConcurrency = getEnvInt("OLLAMA_DOWNLOAD_CONCURRENCY", 64)
+	// Default: 32. Override with OLLAMA_DOWNLOAD_CONCURRENCY.
+	downloadConcurrency = getEnvInt("OLLAMA_DOWNLOAD_CONCURRENCY", 32)
 )
 
 func getEnvInt(key string, defaultVal int) int {
@@ -301,14 +298,7 @@ func (b *blobDownload) Prepare(ctx context.Context, requestURL *url.URL, opts *r
 
 	b.Total, _ = strconv.ParseInt(resp.Header.Get("Content-Length"), 10, 64)
 
-	size := b.Total / numDownloadParts
-	switch {
-	case size < downloadPartSize:
-		size = downloadPartSize
-	case size > downloadPartSize:
-		size = downloadPartSize
-	}
-
+	size := downloadPartSize
 	var offset int64
 	for offset < b.Total {
 		if offset+size > b.Total {
@@ -442,8 +432,11 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *regis
 
 		g.Go(func() error {
 			var err error
+			var slowRetries int
 			for try := 0; try < maxRetries; try++ {
-				err = b.downloadChunkToDisk(inner, directURL, file, part, tracker)
+				// After 3 slow retries, stop checking slowness and let it complete
+				skipSlowCheck := slowRetries >= 3
+				err = b.downloadChunkToDisk(inner, directURL, file, part, tracker, skipSlowCheck)
 				switch {
 				case errors.Is(err, context.Canceled), errors.Is(err, syscall.ENOSPC):
 					return err
@@ -451,30 +444,9 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *regis
 					try--
 					continue
 				case errors.Is(err, errPartSlow):
-					// Race: start a second download, first to complete wins
-					raceCtx, raceCancel := context.WithCancel(inner)
-					var once sync.Once
-					var raceErr error
-
-					done := make(chan struct{})
-					for r := 0; r < 2; r++ {
-						go func() {
-							e := b.downloadChunkToDisk(raceCtx, directURL, file, part, tracker)
-							once.Do(func() {
-								raceErr = e
-								raceCancel()
-								close(done)
-							})
-						}()
-					}
-					<-done
-
-					if raceErr == nil {
-						sh.MarkComplete(part.N)
-						return nil
-					}
-					err = raceErr
-					// Fall through to retry logic
+					// Kill slow request, retry immediately (stays within concurrency limit)
+					slowRetries++
+					try--
 					continue
 				case err != nil:
 					sleep := time.Second * time.Duration(math.Pow(2, float64(try)))
@@ -556,7 +528,8 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *regis
 
 // downloadChunkToDisk streams a part directly to disk at its offset.
 // Memory: ~32KB (read buffer only).
-func (b *blobDownload) downloadChunkToDisk(ctx context.Context, requestURL *url.URL, file *os.File, part *blobDownloadPart, tracker *speedTracker) error {
+// If skipSlowCheck is true, don't flag slow parts (used after repeated slow retries).
+func (b *blobDownload) downloadChunkToDisk(ctx context.Context, requestURL *url.URL, file *os.File, part *blobDownloadPart, tracker *speedTracker, skipSlowCheck bool) error {
 	g, ctx := errgroup.WithContext(ctx)
 	startTime := time.Now()
 	var bytesAtLastCheck atomic.Int64
@@ -638,14 +611,15 @@ func (b *blobDownload) downloadChunkToDisk(ctx context.Context, requestURL *url.
 			}
 
 			// Check for slow speed after 5+ seconds (only for multi-part downloads)
+			// Skip if we've already retried for slowness too many times
 			elapsed := time.Since(startTime).Seconds()
-			if elapsed >= 5 && currentBytes > 0 && len(b.Parts) > 1 {
+			if !skipSlowCheck && elapsed >= 5 && currentBytes > 0 && len(b.Parts) > 1 {
 				currentSpeed := float64(currentBytes) / elapsed
 				median := tracker.Median()
 
 				// If we're below 10% of median speed, flag as slow
 				if median > 0 && currentSpeed < median*0.1 {
-					slog.Info(fmt.Sprintf("%s part %d slow (%.0f KB/s vs median %.0f KB/s); racing",
+					slog.Info(fmt.Sprintf("%s part %d slow (%.0f KB/s vs median %.0f KB/s); retrying",
						b.Digest[7:19], part.N, currentSpeed/1024, median/1024))
 					return errPartSlow
 				}
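
Reviewer note: below is a minimal, self-contained Go sketch of the slow-part heuristic this change relies on, for evaluating the new retry path without the full file. It is an illustration under stated assumptions, not code from this patch: speedTracker here is a hypothetical stand-in (only its Median method appears in the hunks), getEnvInt's body is assumed since it lies outside the diff, and isSlow condenses the threshold check that the real code performs inline.

package main

import (
	"fmt"
	"os"
	"sort"
	"strconv"
	"sync"
)

// getEnvInt mirrors the helper referenced above: the integer value of key, or
// defaultVal when unset or unparseable (assumed behavior; body is outside this diff).
func getEnvInt(key string, defaultVal int) int {
	if v := os.Getenv(key); v != "" {
		if n, err := strconv.Atoi(v); err == nil {
			return n
		}
	}
	return defaultVal
}

// speedTracker is a hypothetical stand-in that records per-part speeds so a
// straggler can be compared against its peers. Only Median is needed here.
type speedTracker struct {
	mu     sync.Mutex
	speeds []float64 // bytes/sec samples
}

func (t *speedTracker) Record(bytesPerSec float64) {
	t.mu.Lock()
	defer t.mu.Unlock()
	t.speeds = append(t.speeds, bytesPerSec)
}

// Median returns the upper median of recorded speeds, or 0 with no samples.
func (t *speedTracker) Median() float64 {
	t.mu.Lock()
	defer t.mu.Unlock()
	if len(t.speeds) == 0 {
		return 0
	}
	s := append([]float64(nil), t.speeds...)
	sort.Float64s(s)
	return s[len(s)/2]
}

// isSlow condenses the patch's inline check: after 5+ seconds a part is
// flagged when it runs below 10% of the median speed, unless the caller has
// already retried for slowness three times and passes skipSlowCheck.
func isSlow(currentSpeed, median float64, skipSlowCheck bool) bool {
	return !skipSlowCheck && median > 0 && currentSpeed < median*0.1
}

func main() {
	// New defaults: 64MB parts, 32 concurrent part downloads.
	fmt.Println(getEnvInt("OLLAMA_DOWNLOAD_CONCURRENCY", 32))

	t := &speedTracker{}
	for _, s := range []float64{8e6, 9e6, 10e6, 11e6} { // healthy parts, ~10 MB/s
		t.Record(s)
	}
	fmt.Println(isSlow(0.5e6, t.Median(), false)) // true: 0.5 MB/s is under 10% of the 10 MB/s median
	fmt.Println(isSlow(0.5e6, t.Median(), true))  // false: slow checks disabled after 3 slow retries
}

The design point worth noting: killing the slow request and retrying, rather than racing a duplicate request as before, keeps each worker inside the errgroup's concurrency limit, so a cluster of slow parts can no longer double the number of in-flight connections.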