streaming

jmorganca 2025-12-20 17:39:55 -08:00
parent 7c5b656bb3
commit 51cb1155ba
1 changed file with 154 additions and 106 deletions


@@ -15,6 +15,7 @@ import (
"net/url"
"os"
"path/filepath"
"runtime"
"strconv"
"strings"
"sync"
@@ -98,27 +99,18 @@ func (p *blobDownloadPart) UnmarshalJSON(b []byte) error {
const numDownloadParts = 16
// Download tuning. Override via environment variables for different environments:
// - Developer laptop: defaults are fine
// - Fast server (10Gbps+): OLLAMA_DOWNLOAD_CONCURRENCY=32 or higher
// - Memory constrained: reduce OLLAMA_DOWNLOAD_PART_SIZE and OLLAMA_DOWNLOAD_BUFFER_SIZE
// Download tuning. Override via environment variables:
// - Memory constrained: reduce OLLAMA_DOWNLOAD_CONCURRENCY
// - Defaults use little memory: parts stream to disk through 32KB buffers and the hash is computed from the page cache with a single 64KB read buffer
var (
// downloadPartSize is the size of each download part.
// Smaller = less memory, more HTTP requests.
// Default: 16MB. Override with OLLAMA_DOWNLOAD_PART_SIZE (in MB).
downloadPartSize = int64(getEnvInt("OLLAMA_DOWNLOAD_PART_SIZE", 16)) * format.MegaByte
// downloadConcurrency limits concurrent part downloads.
// Higher = faster on fast connections, more memory.
// Default: 8. Override with OLLAMA_DOWNLOAD_CONCURRENCY.
downloadConcurrency = getEnvInt("OLLAMA_DOWNLOAD_CONCURRENCY", 8)
// downloadBufferSize is the max bytes buffered in orderedWriter before
// Submit blocks (backpressure). This bounds memory usage.
// Default: 128MB. Override with OLLAMA_DOWNLOAD_BUFFER_SIZE (in MB).
// Total memory ≈ (concurrency × part_size) + buffer_size
// Default: (8 × 16MB) + 128MB = 256MB max
downloadBufferSize = int64(getEnvInt("OLLAMA_DOWNLOAD_BUFFER_SIZE", 128)) * format.MegaByte
// Higher = faster on fast connections; memory grows by only a 32KB stream buffer per additional in-flight part.
// Default: 2 * GOMAXPROCS (scales with CPU cores). Override with OLLAMA_DOWNLOAD_CONCURRENCY.
downloadConcurrency = getEnvInt("OLLAMA_DOWNLOAD_CONCURRENCY", 2*runtime.GOMAXPROCS(0))
)
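For scale, taking the defaults visible in this diff at face value (and GOMAXPROCS = 8 as a purely illustrative machine size), peak memory of the two approaches works out roughly as:

    old (buffered):  concurrency × part size + reorder buffer = 8 × 16MB + 128MB = 256MB
    new (streaming): hasher buffer + concurrency × stream buffer = 64KB + 16 × 32KB ≈ 576KB

The streaming figure excludes whatever the OS keeps in its page cache, which is reclaimable and not charged to the process.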
func getEnvInt(key string, defaultVal int) int {
@@ -130,64 +122,105 @@ func getEnvInt(key string, defaultVal int) int {
return defaultVal
}
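The rest of getEnvInt sits outside this hunk; only its fallback return is visible above. For readers skimming the diff, a minimal sketch of such a helper (hypothetical, assuming it parses the variable with strconv.Atoi and ignores malformed or non-positive values) could look like:

// Hypothetical sketch only; the real body of getEnvInt is not shown in this diff.
func getEnvIntSketch(key string, defaultVal int) int {
	if v := os.Getenv(key); v != "" {
		if n, err := strconv.Atoi(v); err == nil && n > 0 {
			return n
		}
	}
	return defaultVal
}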
// orderedWriter buffers out-of-order parts and writes them sequentially
// through a hasher and file. This allows parallel downloads while computing
// the hash incrementally without a post-download verification pass.
type orderedWriter struct {
// streamHasher reads a file sequentially and hashes it as chunks complete.
// Memory usage: ~64KB (just the read buffer), regardless of file size or concurrency.
// Works by reading from OS page cache - data just written is still in RAM.
type streamHasher struct {
file *os.File
hasher hash.Hash
parts []*blobDownloadPart
total int64 // total bytes to hash
hashed atomic.Int64
mu sync.Mutex
cond *sync.Cond
next int // next expected part index
pending map[int][]byte // out-of-order parts waiting to be written
pendingSize int64 // total bytes in pending
out io.Writer // destination (typically MultiWriter(file, hasher))
hasher hash.Hash // for computing final digest
completed []bool
done bool
err error
}
func newOrderedWriter(file io.Writer, hasher hash.Hash) *orderedWriter {
w := &orderedWriter{
pending: make(map[int][]byte),
out: io.MultiWriter(file, hasher),
hasher: hasher,
func newStreamHasher(file *os.File, parts []*blobDownloadPart, total int64) *streamHasher {
h := &streamHasher{
file: file,
hasher: sha256.New(),
parts: parts,
total: total,
completed: make([]bool, len(parts)),
}
w.cond = sync.NewCond(&w.mu)
return w
h.cond = sync.NewCond(&h.mu)
return h
}
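To make the intended call pattern explicit, here is a minimal usage sketch built only from the streamHasher API added in this commit; the wrapper function and its downloadPart callback are hypothetical, not part of the change:

// Hypothetical wiring sketch: parallel downloads feeding a single stream hasher.
func hashWhileDownloading(file *os.File, parts []*blobDownloadPart, total int64,
	downloadPart func(*blobDownloadPart) error) (string, error) {
	sh := newStreamHasher(file, parts, total)
	done := make(chan struct{})
	go func() { sh.Run(); close(done) }()

	var g errgroup.Group
	for _, part := range parts {
		part := part
		g.Go(func() error {
			if err := downloadPart(part); err != nil {
				return err
			}
			sh.MarkComplete(part.N) // wake the hasher: this range is now on disk
			return nil
		})
	}
	if err := g.Wait(); err != nil {
		sh.Stop() // unblock Run so the hasher goroutine exits
		return "", err
	}
	<-done
	if err := sh.Err(); err != nil {
		return "", err
	}
	return sh.Digest(), nil
}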
// Submit adds a part to the writer. Parts are written in order; if this part
// is out of order, it's buffered until earlier parts arrive. Blocks if the
// pending buffer exceeds downloadBufferSize (backpressure), unless this is the
// next expected part (which will drain the buffer).
func (w *orderedWriter) Submit(partIndex int, data []byte) error {
w.mu.Lock()
defer w.mu.Unlock()
// Backpressure: wait if buffer is too full, unless we're the next part
// (the next part will drain the buffer, so it must always proceed)
for w.pendingSize+int64(len(data)) > downloadBufferSize && partIndex != w.next {
w.cond.Wait()
// MarkComplete signals that a part has been written to disk.
func (h *streamHasher) MarkComplete(partIndex int) {
h.mu.Lock()
h.completed[partIndex] = true
h.cond.Broadcast()
h.mu.Unlock()
}
w.pending[partIndex] = data
w.pendingSize += int64(len(data))
// Run reads and hashes the file sequentially. Call in a goroutine.
func (h *streamHasher) Run() {
buf := make([]byte, 64*1024) // 64KB read buffer
var offset int64
// Write all consecutive parts starting from next
for w.pending[w.next] != nil {
data := w.pending[w.next]
if _, err := w.out.Write(data); err != nil {
return err
for i, part := range h.parts {
// Wait for this part to be written
h.mu.Lock()
for !h.completed[i] && !h.done {
h.cond.Wait()
}
if h.done {
h.mu.Unlock()
return
}
h.mu.Unlock()
// Read and hash this part (from page cache)
remaining := part.Size
for remaining > 0 {
n := int64(len(buf))
if n > remaining {
n = remaining
}
nr, err := h.file.ReadAt(buf[:n], offset)
if err != nil && err != io.EOF {
h.mu.Lock()
h.err = err
h.mu.Unlock()
return
}
h.hasher.Write(buf[:nr])
offset += int64(nr)
remaining -= int64(nr)
h.hashed.Store(offset)
if err == io.EOF && remaining > 0 {
// file is shorter than expected; fail instead of spinning on repeated EOF reads
h.mu.Lock()
h.err = io.ErrUnexpectedEOF
h.mu.Unlock()
return
}
}
w.pendingSize -= int64(len(data))
w.pending[w.next] = nil // help GC free the slice
delete(w.pending, w.next)
w.next++
w.cond.Broadcast() // wake any blocked submitters
}
return nil
}
// Digest returns the computed hash after all parts have been written.
func (w *orderedWriter) Digest() string {
return fmt.Sprintf("sha256:%x", w.hasher.Sum(nil))
// Stop signals the hasher to exit early.
func (h *streamHasher) Stop() {
h.mu.Lock()
h.done = true
h.cond.Broadcast()
h.mu.Unlock()
}
// Hashed returns bytes hashed so far.
func (h *streamHasher) Hashed() int64 {
return h.hashed.Load()
}
// Digest returns the computed hash.
func (h *streamHasher) Digest() string {
return fmt.Sprintf("sha256:%x", h.hasher.Sum(nil))
}
// Err returns any error from hashing.
func (h *streamHasher) Err() error {
h.mu.Lock()
defer h.mu.Unlock()
return h.err
}
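A quick way to sanity-check this design is a test that completes parts out of order and compares the result with a plain sha256 of the same bytes. The following is a hypothetical sketch, not part of the commit; it assumes only the exported N, Offset, and Size fields of blobDownloadPart that this file already uses:

// Hypothetical test sketch for streamHasher; not part of this commit.
func TestStreamHasherOutOfOrder(t *testing.T) {
	f, err := os.CreateTemp(t.TempDir(), "blob")
	if err != nil {
		t.Fatal(err)
	}
	defer f.Close()

	const partSize = 1 << 20 // three 1MB parts
	data := make([]byte, 3*partSize)
	for i := range data {
		data[i] = byte(i)
	}
	if _, err := f.Write(data); err != nil {
		t.Fatal(err)
	}

	var parts []*blobDownloadPart
	for i := 0; i < 3; i++ {
		parts = append(parts, &blobDownloadPart{N: i, Offset: int64(i * partSize), Size: partSize})
	}

	sh := newStreamHasher(f, parts, int64(len(data)))
	done := make(chan struct{})
	go func() { sh.Run(); close(done) }()

	for _, n := range []int{2, 0, 1} { // complete parts out of order
		sh.MarkComplete(n)
	}
	<-done

	if err := sh.Err(); err != nil {
		t.Fatal(err)
	}
	want := fmt.Sprintf("sha256:%x", sha256.Sum256(data))
	if got := sh.Digest(); got != want {
		t.Fatalf("digest mismatch: got %s, want %s", got, want)
	}
}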
func (p *blobDownloadPart) Name() string {
@@ -196,14 +229,6 @@ func (p *blobDownloadPart) Name() string {
}, "-")
}
func (p *blobDownloadPart) StartsAt() int64 {
return p.Offset + p.Completed.Load()
}
func (p *blobDownloadPart) StopsAt() int64 {
return p.Offset + p.Size
}
func (p *blobDownloadPart) Write(b []byte) (n int, err error) {
n = len(b)
p.blobDownload.Completed.Add(int64(n))
@@ -357,24 +382,32 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *regis
return err
}
// Download chunks in parallel, hash while writing via ordered writer.
// Memory: (concurrency × part_size) + buffer. For 2GB files with defaults:
// 8 × 32MB + 64MB = 320MB max.
ow := newOrderedWriter(file, sha256.New())
// Download chunks to disk, hash by reading from page cache.
// Memory: ~64KB for the hasher's read buffer plus a 32KB stream buffer per in-flight part.
// The hasher follows behind the downloaders, reading recently-written
// data from OS page cache (RAM) rather than disk.
sh := newStreamHasher(file, b.Parts, b.Total)
// Start hasher goroutine
hashDone := make(chan struct{})
go func() {
sh.Run()
close(hashDone)
}()
g, inner := errgroup.WithContext(ctx)
g.SetLimit(downloadConcurrency)
for i := range b.Parts {
part := b.Parts[i]
if part.Completed.Load() == part.Size {
sh.MarkComplete(part.N)
continue
}
g.Go(func() error {
var data []byte
var err error
for try := 0; try < maxRetries; try++ {
data, err = b.downloadChunkToBuffer(inner, directURL, part)
err = b.downloadChunkToDisk(inner, directURL, file, part)
switch {
case errors.Is(err, context.Canceled), errors.Is(err, syscall.ENOSPC):
return err
@@ -387,22 +420,46 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *regis
time.Sleep(sleep)
continue
default:
err := ow.Submit(part.N, data)
data = nil // help GC
return err
sh.MarkComplete(part.N)
return nil
}
}
return fmt.Errorf("%w: %w", errMaxRetriesExceeded, err)
})
}
// Log progress periodically
progressDone := make(chan struct{})
go func() {
ticker := time.NewTicker(time.Second)
defer ticker.Stop()
for {
select {
case <-ticker.C:
dl := int(b.Completed.Load() * 100 / b.Total)
h := int(sh.Hashed() * 100 / b.Total)
slog.Info(fmt.Sprintf("progress: downloaded %d%% | hashed %d%%", dl, h))
case <-progressDone:
return
}
}
}()
if err := g.Wait(); err != nil {
close(progressDone)
sh.Stop()
return err
}
// Verify hash - no re-read needed, hash was computed while writing
computed := ow.Digest()
if computed != b.Digest {
// Wait for hasher to finish
<-hashDone
close(progressDone)
if err := sh.Err(); err != nil {
return err
}
// Verify hash
if computed := sh.Digest(); computed != b.Digest {
return fmt.Errorf("digest mismatch: got %s, want %s", computed, b.Digest)
}
@@ -424,11 +481,11 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *regis
return nil
}
// downloadChunkToBuffer downloads a part to a buffer, tracking progress and detecting stalls.
func (b *blobDownload) downloadChunkToBuffer(ctx context.Context, requestURL *url.URL, part *blobDownloadPart) ([]byte, error) {
// downloadChunkToDisk streams a part directly to disk at its offset.
// Memory: ~32KB (read buffer only).
func (b *blobDownload) downloadChunkToDisk(ctx context.Context, requestURL *url.URL, file *os.File, part *blobDownloadPart) error {
g, ctx := errgroup.WithContext(ctx)
var data []byte
g.Go(func() error {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, requestURL.String(), nil)
if err != nil {
@@ -441,14 +498,17 @@ func (b *blobDownload) downloadChunkToBuffer(ctx context.Context, requestURL *ur
}
defer resp.Body.Close()
// Pre-allocate buffer for the part
data = make([]byte, 0, part.Size)
buf := make([]byte, 32*1024) // 32KB read buffer
w := io.NewOffsetWriter(file, part.Offset)
buf := make([]byte, 32*1024)
for {
var written int64
for written < part.Size {
n, err := resp.Body.Read(buf)
if n > 0 {
data = append(data, buf[:n]...)
if _, werr := w.Write(buf[:n]); werr != nil {
return werr
}
written += int64(n)
b.Completed.Add(int64(n))
part.lastUpdatedMu.Lock()
@@ -459,18 +519,13 @@ func (b *blobDownload) downloadChunkToBuffer(ctx context.Context, requestURL *ur
break
}
if err != nil {
// rollback progress
b.Completed.Add(-int64(len(data)))
b.Completed.Add(-written)
return err
}
}
part.Completed.Store(part.Size)
if err := b.writePart(part.Name(), part); err != nil {
return err
}
return nil
return b.writePart(part.Name(), part)
})
g.Go(func() error {
@@ -482,15 +537,11 @@ func (b *blobDownload) downloadChunkToBuffer(ctx context.Context, requestURL *ur
if part.Completed.Load() >= part.Size {
return nil
}
part.lastUpdatedMu.Lock()
lastUpdated := part.lastUpdated
part.lastUpdatedMu.Unlock()
if !lastUpdated.IsZero() && time.Since(lastUpdated) > 30*time.Second {
const msg = "%s part %d stalled; retrying. If this persists, press ctrl-c to exit, then 'ollama pull' to find a faster connection."
slog.Info(fmt.Sprintf(msg, b.Digest[7:19], part.N))
// reset last updated
slog.Info(fmt.Sprintf("%s part %d stalled; retrying", b.Digest[7:19], part.N))
part.lastUpdatedMu.Lock()
part.lastUpdated = time.Time{}
part.lastUpdatedMu.Unlock()
@@ -502,10 +553,7 @@ func (b *blobDownload) downloadChunkToBuffer(ctx context.Context, requestURL *ur
}
})
if err := g.Wait(); err != nil {
return nil, err
}
return data, nil
return g.Wait()
}
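downloadChunkToDisk relies on writes to disjoint ranges of one *os.File proceeding safely in parallel: each part gets its own io.NewOffsetWriter, which issues WriteAt calls confined to that part's byte range. A standalone sketch of the pattern, hypothetical and independent of this file:

// Hypothetical sketch: several goroutines writing non-overlapping ranges of one file.
func writeDisjointRanges(f *os.File, chunks [][]byte) error {
	var g errgroup.Group
	offset := int64(0)
	for _, c := range chunks {
		c, start := c, offset // capture per-iteration values
		g.Go(func() error {
			// This goroutine only touches [start, start+len(c)), so the
			// concurrent WriteAt calls underneath never overlap.
			_, err := io.NewOffsetWriter(f, start).Write(c)
			return err
		})
		offset += int64(len(c))
	}
	return g.Wait()
}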
func (b *blobDownload) newPart(offset, size int64) error {