diff --git a/server/download.go b/server/download.go
index 1265460ea..30acb314d 100644
--- a/server/download.go
+++ b/server/download.go
@@ -15,6 +15,7 @@ import (
 	"net/url"
 	"os"
 	"path/filepath"
+	"runtime"
 	"strconv"
 	"strings"
 	"sync"
@@ -98,27 +99,18 @@ func (p *blobDownloadPart) UnmarshalJSON(b []byte) error {
 
 const numDownloadParts = 16
 
-// Download tuning. Override via environment variables for different environments:
-// - Developer laptop: defaults are fine
-// - Fast server (10Gbps+): OLLAMA_DOWNLOAD_CONCURRENCY=32 or higher
-// - Memory constrained: reduce OLLAMA_DOWNLOAD_PART_SIZE and OLLAMA_DOWNLOAD_BUFFER_SIZE
+// Download tuning. Override via environment variables:
+// - Memory constrained: reduce OLLAMA_DOWNLOAD_CONCURRENCY
+// - Default uses ~64KB memory regardless of concurrency (streams to disk, hashes from page cache)
 var (
 	// downloadPartSize is the size of each download part.
-	// Smaller = less memory, more HTTP requests.
 	// Default: 16MB. Override with OLLAMA_DOWNLOAD_PART_SIZE (in MB).
 	downloadPartSize = int64(getEnvInt("OLLAMA_DOWNLOAD_PART_SIZE", 16)) * format.MegaByte
 
 	// downloadConcurrency limits concurrent part downloads.
-	// Higher = faster on fast connections, more memory.
-	// Default: 8. Override with OLLAMA_DOWNLOAD_CONCURRENCY.
-	downloadConcurrency = getEnvInt("OLLAMA_DOWNLOAD_CONCURRENCY", 8)
-
-	// downloadBufferSize is the max bytes buffered in orderedWriter before
-	// Submit blocks (backpressure). This bounds memory usage.
-	// Default: 128MB. Override with OLLAMA_DOWNLOAD_BUFFER_SIZE (in MB).
-	// Total memory ≈ (concurrency × part_size) + buffer_size
-	// Default: (8 × 16MB) + 128MB = 256MB max
-	downloadBufferSize = int64(getEnvInt("OLLAMA_DOWNLOAD_BUFFER_SIZE", 128)) * format.MegaByte
+	// Higher = faster on fast connections. Memory stays ~64KB regardless.
+	// Default: 2 * GOMAXPROCS (scales with CPU cores). Override with OLLAMA_DOWNLOAD_CONCURRENCY.
+	downloadConcurrency = getEnvInt("OLLAMA_DOWNLOAD_CONCURRENCY", 2*runtime.GOMAXPROCS(0))
 )
 
 func getEnvInt(key string, defaultVal int) int {
@@ -130,64 +122,105 @@ func getEnvInt(key string, defaultVal int) int {
 	return defaultVal
 }
 
-// orderedWriter buffers out-of-order parts and writes them sequentially
-// through a hasher and file. This allows parallel downloads while computing
-// the hash incrementally without a post-download verification pass.
-type orderedWriter struct {
-	mu          sync.Mutex
-	cond        *sync.Cond
-	next        int            // next expected part index
-	pending     map[int][]byte // out-of-order parts waiting to be written
-	pendingSize int64          // total bytes in pending
-	out         io.Writer      // destination (typically MultiWriter(file, hasher))
-	hasher      hash.Hash      // for computing final digest
+// streamHasher reads a file sequentially and hashes it as chunks complete.
+// Memory usage: ~64KB (just the read buffer), regardless of file size or concurrency.
+// Works by reading from OS page cache - data just written is still in RAM.
+type streamHasher struct {
+	file   *os.File
+	hasher hash.Hash
+	parts  []*blobDownloadPart
+	total  int64 // total bytes to hash
+	hashed atomic.Int64
+
+	mu        sync.Mutex
+	cond      *sync.Cond
+	completed []bool
+	done      bool
+	err       error
 }
 
-func newOrderedWriter(file io.Writer, hasher hash.Hash) *orderedWriter {
-	w := &orderedWriter{
-		pending: make(map[int][]byte),
-		out:     io.MultiWriter(file, hasher),
-		hasher:  hasher,
+func newStreamHasher(file *os.File, parts []*blobDownloadPart, total int64) *streamHasher {
+	h := &streamHasher{
+		file:      file,
+		hasher:    sha256.New(),
+		parts:     parts,
+		total:     total,
+		completed: make([]bool, len(parts)),
 	}
-	w.cond = sync.NewCond(&w.mu)
-	return w
+	h.cond = sync.NewCond(&h.mu)
+	return h
 }
 
-// Submit adds a part to the writer. Parts are written in order; if this part
-// is out of order, it's buffered until earlier parts arrive. Blocks if the
-// pending buffer exceeds downloadBufferSize (backpressure), unless this is the
-// next expected part (which will drain the buffer).
-func (w *orderedWriter) Submit(partIndex int, data []byte) error {
-	w.mu.Lock()
-	defer w.mu.Unlock()
+// MarkComplete signals that a part has been written to disk.
+func (h *streamHasher) MarkComplete(partIndex int) {
+	h.mu.Lock()
+	h.completed[partIndex] = true
+	h.cond.Broadcast()
+	h.mu.Unlock()
+}
 
-	// Backpressure: wait if buffer is too full, unless we're the next part
-	// (the next part will drain the buffer, so it must always proceed)
-	for w.pendingSize+int64(len(data)) > downloadBufferSize && partIndex != w.next {
-		w.cond.Wait()
-	}
+// Run reads and hashes the file sequentially. Call in a goroutine.
+func (h *streamHasher) Run() {
+	buf := make([]byte, 64*1024) // 64KB read buffer
	var offset int64
 
-	w.pending[partIndex] = data
-	w.pendingSize += int64(len(data))
-
-	// Write all consecutive parts starting from next
-	for w.pending[w.next] != nil {
-		data := w.pending[w.next]
-		if _, err := w.out.Write(data); err != nil {
-			return err
+	for i, part := range h.parts {
+		// Wait for this part to be written
+		h.mu.Lock()
+		for !h.completed[i] && !h.done {
+			h.cond.Wait()
+		}
+		if h.done {
+			h.mu.Unlock()
+			return
+		}
+		h.mu.Unlock()
+
+		// Read and hash this part (from page cache)
+		remaining := part.Size
+		for remaining > 0 {
+			n := int64(len(buf))
+			if n > remaining {
+				n = remaining
+			}
+			nr, err := h.file.ReadAt(buf[:n], offset)
+			if err != nil && err != io.EOF {
+				h.mu.Lock()
+				h.err = err
+				h.mu.Unlock()
+				return
+			}
+			h.hasher.Write(buf[:nr])
+			offset += int64(nr)
+			remaining -= int64(nr)
+			h.hashed.Store(offset)
 		}
-		w.pendingSize -= int64(len(data))
-		w.pending[w.next] = nil // help GC free the slice
-		delete(w.pending, w.next)
-		w.next++
-		w.cond.Broadcast() // wake any blocked submitters
 	}
-	return nil
 }
 
-// Digest returns the computed hash after all parts have been written.
-func (w *orderedWriter) Digest() string {
-	return fmt.Sprintf("sha256:%x", w.hasher.Sum(nil))
+// Stop signals the hasher to exit early.
+func (h *streamHasher) Stop() {
+	h.mu.Lock()
+	h.done = true
+	h.cond.Broadcast()
+	h.mu.Unlock()
+}
+
+// Hashed returns bytes hashed so far.
+func (h *streamHasher) Hashed() int64 {
+	return h.hashed.Load()
+}
+
+// Digest returns the computed hash.
+func (h *streamHasher) Digest() string {
+	return fmt.Sprintf("sha256:%x", h.hasher.Sum(nil))
+}
+
+// Err returns any error from hashing.
+func (h *streamHasher) Err() error {
+	h.mu.Lock()
+	defer h.mu.Unlock()
+	return h.err
 }
 
 func (p *blobDownloadPart) Name() string {
@@ -196,14 +229,6 @@ func (p *blobDownloadPart) Name() string {
 	}, "-")
 }
 
-func (p *blobDownloadPart) StartsAt() int64 {
-	return p.Offset + p.Completed.Load()
-}
-
-func (p *blobDownloadPart) StopsAt() int64 {
-	return p.Offset + p.Size
-}
-
 func (p *blobDownloadPart) Write(b []byte) (n int, err error) {
 	n = len(b)
 	p.blobDownload.Completed.Add(int64(n))
@@ -357,24 +382,32 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *regis
 		return err
 	}
 
-	// Download chunks in parallel, hash while writing via ordered writer.
-	// Memory: (concurrency × part_size) + buffer. For 2GB files with defaults:
-	// 8 × 32MB + 64MB = 320MB max.
-	ow := newOrderedWriter(file, sha256.New())
+	// Download chunks to disk, hash by reading from page cache.
+	// Memory: ~64KB (hasher read buffer only), regardless of concurrency.
+	// The hasher follows behind the downloaders, reading recently-written
+	// data from OS page cache (RAM) rather than disk.
+	sh := newStreamHasher(file, b.Parts, b.Total)
+
+	// Start hasher goroutine
+	hashDone := make(chan struct{})
+	go func() {
+		sh.Run()
+		close(hashDone)
+	}()
 
 	g, inner := errgroup.WithContext(ctx)
 	g.SetLimit(downloadConcurrency)
 	for i := range b.Parts {
 		part := b.Parts[i]
 		if part.Completed.Load() == part.Size {
+			sh.MarkComplete(part.N)
 			continue
 		}
 
 		g.Go(func() error {
-			var data []byte
 			var err error
 			for try := 0; try < maxRetries; try++ {
-				data, err = b.downloadChunkToBuffer(inner, directURL, part)
+				err = b.downloadChunkToDisk(inner, directURL, file, part)
 				switch {
 				case errors.Is(err, context.Canceled), errors.Is(err, syscall.ENOSPC):
 					return err
@@ -387,22 +420,46 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *regis
 					time.Sleep(sleep)
 					continue
 				default:
-					err := ow.Submit(part.N, data)
-					data = nil // help GC
-					return err
+					sh.MarkComplete(part.N)
+					return nil
 				}
 			}
 			return fmt.Errorf("%w: %w", errMaxRetriesExceeded, err)
 		})
 	}
 
+	// Log progress periodically
+	progressDone := make(chan struct{})
+	go func() {
+		ticker := time.NewTicker(time.Second)
+		defer ticker.Stop()
+		for {
+			select {
+			case <-ticker.C:
+				dl := int(b.Completed.Load() * 100 / b.Total)
+				h := int(sh.Hashed() * 100 / b.Total)
+				slog.Info(fmt.Sprintf("progress: downloaded %d%% | hashed %d%%", dl, h))
+			case <-progressDone:
+				return
+			}
+		}
+	}()
+
 	if err := g.Wait(); err != nil {
+		close(progressDone)
+		sh.Stop()
 		return err
 	}
 
-	// Verify hash - no re-read needed, hash was computed while writing
-	computed := ow.Digest()
-	if computed != b.Digest {
+	// Wait for hasher to finish
+	<-hashDone
+	close(progressDone)
+	if err := sh.Err(); err != nil {
+		return err
+	}
+
+	// Verify hash
+	if computed := sh.Digest(); computed != b.Digest {
 		return fmt.Errorf("digest mismatch: got %s, want %s", computed, b.Digest)
 	}
 
@@ -424,11 +481,11 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *regis
 	return nil
 }
 
-// downloadChunkToBuffer downloads a part to a buffer, tracking progress and detecting stalls.
-func (b *blobDownload) downloadChunkToBuffer(ctx context.Context, requestURL *url.URL, part *blobDownloadPart) ([]byte, error) {
+// downloadChunkToDisk streams a part directly to disk at its offset.
+// Memory: ~32KB (read buffer only).
+func (b *blobDownload) downloadChunkToDisk(ctx context.Context, requestURL *url.URL, file *os.File, part *blobDownloadPart) error {
 	g, ctx := errgroup.WithContext(ctx)
-	var data []byte
 
 	g.Go(func() error {
 		req, err := http.NewRequestWithContext(ctx, http.MethodGet, requestURL.String(), nil)
 		if err != nil {
@@ -441,14 +498,17 @@ func (b *blobDownload) downloadChunkToBuffer(ctx context.Context, requestURL *ur
 		}
 		defer resp.Body.Close()
 
-		// Pre-allocate buffer for the part
-		data = make([]byte, 0, part.Size)
-		buf := make([]byte, 32*1024) // 32KB read buffer
+		w := io.NewOffsetWriter(file, part.Offset)
+		buf := make([]byte, 32*1024)
 
-		for {
+		var written int64
+		for written < part.Size {
 			n, err := resp.Body.Read(buf)
 			if n > 0 {
-				data = append(data, buf[:n]...)
+				if _, werr := w.Write(buf[:n]); werr != nil {
+					return werr
+				}
+				written += int64(n)
 				b.Completed.Add(int64(n))
 
 				part.lastUpdatedMu.Lock()
@@ -459,18 +519,13 @@ func (b *blobDownload) downloadChunkToBuffer(ctx context.Context, requestURL *ur
 				break
 			}
 			if err != nil {
-				// rollback progress
-				b.Completed.Add(-int64(len(data)))
+				b.Completed.Add(-written)
 				return err
 			}
 		}
 
 		part.Completed.Store(part.Size)
-		if err := b.writePart(part.Name(), part); err != nil {
-			return err
-		}
-
-		return nil
+		return b.writePart(part.Name(), part)
 	})
 
 	g.Go(func() error {
@@ -482,15 +537,11 @@ func (b *blobDownload) downloadChunkToBuffer(ctx context.Context, requestURL *ur
 			if part.Completed.Load() >= part.Size {
 				return nil
 			}
-
 			part.lastUpdatedMu.Lock()
 			lastUpdated := part.lastUpdated
 			part.lastUpdatedMu.Unlock()
-
 			if !lastUpdated.IsZero() && time.Since(lastUpdated) > 30*time.Second {
-				const msg = "%s part %d stalled; retrying. If this persists, press ctrl-c to exit, then 'ollama pull' to find a faster connection."
-				slog.Info(fmt.Sprintf(msg, b.Digest[7:19], part.N))
-				// reset last updated
+				slog.Info(fmt.Sprintf("%s part %d stalled; retrying", b.Digest[7:19], part.N))
 				part.lastUpdatedMu.Lock()
 				part.lastUpdated = time.Time{}
 				part.lastUpdatedMu.Unlock()
@@ -502,10 +553,7 @@ func (b *blobDownload) downloadChunkToBuffer(ctx context.Context, requestURL *ur
 		}
 	})
 
-	if err := g.Wait(); err != nil {
-		return nil, err
-	}
-	return data, nil
+	return g.Wait()
 }
 
 func (b *blobDownload) newPart(offset, size int64) error {
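
Reviewer note (not part of the patch): below is a minimal, self-contained sketch of the write-at-offset / hash-behind pattern this diff adopts. Concurrent writers place each part at its own offset with io.NewOffsetWriter and signal completion, while a single reader hashes the file in order through one small buffer, so bytes just written are served from the OS page cache. It assumes Go 1.20+ for io.NewOffsetWriter; the part size, part count, temp file, and channel-per-part signalling are illustrative stand-ins, not the streamHasher API introduced here.

package main

import (
	"bytes"
	"crypto/sha256"
	"fmt"
	"io"
	"os"
	"sync"
)

func main() {
	// Hypothetical sizes for the demo; the real code derives these from the blob.
	const partSize = int64(1 << 20)
	const numParts = 8

	f, err := os.CreateTemp("", "stream-hash-*")
	if err != nil {
		panic(err)
	}
	defer os.Remove(f.Name())
	defer f.Close()
	if err := f.Truncate(partSize * numParts); err != nil {
		panic(err)
	}

	// One channel per part plays the role of streamHasher.MarkComplete.
	done := make([]chan struct{}, numParts)
	for i := range done {
		done[i] = make(chan struct{})
	}

	// "Downloaders": write each part at its own offset, then signal completion.
	var wg sync.WaitGroup
	for i := 0; i < numParts; i++ {
		wg.Add(1)
		go func(i int) {
			defer wg.Done()
			w := io.NewOffsetWriter(f, int64(i)*partSize)
			chunk := bytes.Repeat([]byte{byte(i)}, int(partSize)) // stand-in for downloaded bytes
			if _, err := w.Write(chunk); err != nil {
				panic(err)
			}
			close(done[i])
		}(i)
	}

	// "Hasher": trail the writers, reading parts in order through one small
	// buffer; recently written bytes come straight from the page cache.
	h := sha256.New()
	buf := make([]byte, 64*1024)
	var offset int64
	for i := 0; i < numParts; i++ {
		<-done[i]
		remaining := partSize
		for remaining > 0 {
			n := remaining
			if n > int64(len(buf)) {
				n = int64(len(buf))
			}
			nr, err := f.ReadAt(buf[:n], offset)
			if err != nil && err != io.EOF {
				panic(err)
			}
			h.Write(buf[:nr])
			offset += int64(nr)
			remaining -= int64(nr)
		}
	}
	wg.Wait()
	fmt.Printf("sha256:%x\n", h.Sum(nil))
}

The point of the pattern, and of the diff, is that the hasher never buffers whole parts in memory: it trails the writers and reads back through fixed-size buffers, so peak memory no longer scales with concurrency × part size as it did with orderedWriter.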