streaming
parent 7c5b656bb3 · commit 51cb1155ba
@@ -15,6 +15,7 @@ import (
 	"net/url"
 	"os"
 	"path/filepath"
+	"runtime"
 	"strconv"
 	"strings"
 	"sync"
@@ -98,27 +99,18 @@ func (p *blobDownloadPart) UnmarshalJSON(b []byte) error {
 
 const numDownloadParts = 16
 
-// Download tuning. Override via environment variables for different environments:
-// - Developer laptop: defaults are fine
-// - Fast server (10Gbps+): OLLAMA_DOWNLOAD_CONCURRENCY=32 or higher
-// - Memory constrained: reduce OLLAMA_DOWNLOAD_PART_SIZE and OLLAMA_DOWNLOAD_BUFFER_SIZE
+// Download tuning. Override via environment variables:
+// - Memory constrained: reduce OLLAMA_DOWNLOAD_CONCURRENCY
+// - Default uses ~64KB memory regardless of concurrency (streams to disk, hashes from page cache)
 var (
 	// downloadPartSize is the size of each download part.
 	// Smaller = less memory, more HTTP requests.
 	// Default: 16MB. Override with OLLAMA_DOWNLOAD_PART_SIZE (in MB).
 	downloadPartSize = int64(getEnvInt("OLLAMA_DOWNLOAD_PART_SIZE", 16)) * format.MegaByte
 
 	// downloadConcurrency limits concurrent part downloads.
-	// Higher = faster on fast connections, more memory.
-	// Default: 8. Override with OLLAMA_DOWNLOAD_CONCURRENCY.
-	downloadConcurrency = getEnvInt("OLLAMA_DOWNLOAD_CONCURRENCY", 8)
-
-	// downloadBufferSize is the max bytes buffered in orderedWriter before
-	// Submit blocks (backpressure). This bounds memory usage.
-	// Default: 128MB. Override with OLLAMA_DOWNLOAD_BUFFER_SIZE (in MB).
-	// Total memory ≈ (concurrency × part_size) + buffer_size
-	// Default: (8 × 16MB) + 128MB = 256MB max
-	downloadBufferSize = int64(getEnvInt("OLLAMA_DOWNLOAD_BUFFER_SIZE", 128)) * format.MegaByte
+	// Higher = faster on fast connections. Memory stays ~64KB regardless.
+	// Default: 2 * GOMAXPROCS (scales with CPU cores). Override with OLLAMA_DOWNLOAD_CONCURRENCY.
+	downloadConcurrency = getEnvInt("OLLAMA_DOWNLOAD_CONCURRENCY", 2*runtime.GOMAXPROCS(0))
 )
 
 func getEnvInt(key string, defaultVal int) int {
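Note: the tuning variables above all flow through getEnvInt, whose body sits mostly outside this diff. A minimal sketch of what such a helper plausibly looks like, assuming it ignores unset, malformed, and non-positive values (only the trailing `return defaultVal` is visible in the next hunk):

```go
// Assumed reconstruction of getEnvInt for illustration; only its trailing
// "return defaultVal" appears in this diff, so the parsing details here
// are a guess, not the committed code.
func getEnvInt(key string, defaultVal int) int {
	if v := os.Getenv(key); v != "" {
		if n, err := strconv.Atoi(v); err == nil && n > 0 {
			return n
		}
	}
	return defaultVal
}
```

With a helper of this shape, something like OLLAMA_DOWNLOAD_CONCURRENCY=32 raises the fan-out at runtime without a rebuild.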
@@ -130,64 +122,105 @@ func getEnvInt(key string, defaultVal int) int {
 	return defaultVal
 }
 
-// orderedWriter buffers out-of-order parts and writes them sequentially
-// through a hasher and file. This allows parallel downloads while computing
-// the hash incrementally without a post-download verification pass.
-type orderedWriter struct {
+// streamHasher reads a file sequentially and hashes it as chunks complete.
+// Memory usage: ~64KB (just the read buffer), regardless of file size or concurrency.
+// Works by reading from OS page cache - data just written is still in RAM.
+type streamHasher struct {
+	file   *os.File
+	hasher hash.Hash
+	parts  []*blobDownloadPart
+	total  int64 // total bytes to hash
+	hashed atomic.Int64
+
 	mu   sync.Mutex
 	cond *sync.Cond
-	next        int            // next expected part index
-	pending     map[int][]byte // out-of-order parts waiting to be written
-	pendingSize int64          // total bytes in pending
-	out         io.Writer      // destination (typically MultiWriter(file, hasher))
-	hasher      hash.Hash      // for computing final digest
+	completed []bool
+	done      bool
+	err       error
 }
 
-func newOrderedWriter(file io.Writer, hasher hash.Hash) *orderedWriter {
-	w := &orderedWriter{
-		pending: make(map[int][]byte),
-		out:     io.MultiWriter(file, hasher),
-		hasher:  hasher,
+func newStreamHasher(file *os.File, parts []*blobDownloadPart, total int64) *streamHasher {
+	h := &streamHasher{
+		file:      file,
+		hasher:    sha256.New(),
+		parts:     parts,
+		total:     total,
+		completed: make([]bool, len(parts)),
 	}
-	w.cond = sync.NewCond(&w.mu)
-	return w
+	h.cond = sync.NewCond(&h.mu)
+	return h
 }
 
-// Submit adds a part to the writer. Parts are written in order; if this part
-// is out of order, it's buffered until earlier parts arrive. Blocks if the
-// pending buffer exceeds downloadBufferSize (backpressure), unless this is the
-// next expected part (which will drain the buffer).
-func (w *orderedWriter) Submit(partIndex int, data []byte) error {
-	w.mu.Lock()
-	defer w.mu.Unlock()
-
-	// Backpressure: wait if buffer is too full, unless we're the next part
-	// (the next part will drain the buffer, so it must always proceed)
-	for w.pendingSize+int64(len(data)) > downloadBufferSize && partIndex != w.next {
-		w.cond.Wait()
-	}
-
-	w.pending[partIndex] = data
-	w.pendingSize += int64(len(data))
-
-	// Write all consecutive parts starting from next
-	for w.pending[w.next] != nil {
-		data := w.pending[w.next]
-		if _, err := w.out.Write(data); err != nil {
-			return err
-		}
-		w.pendingSize -= int64(len(data))
-		w.pending[w.next] = nil // help GC free the slice
-		delete(w.pending, w.next)
-		w.next++
-		w.cond.Broadcast() // wake any blocked submitters
-	}
-	return nil
+// MarkComplete signals that a part has been written to disk.
+func (h *streamHasher) MarkComplete(partIndex int) {
+	h.mu.Lock()
+	h.completed[partIndex] = true
+	h.cond.Broadcast()
+	h.mu.Unlock()
 }
 
-// Digest returns the computed hash after all parts have been written.
-func (w *orderedWriter) Digest() string {
-	return fmt.Sprintf("sha256:%x", w.hasher.Sum(nil))
+// Run reads and hashes the file sequentially. Call in a goroutine.
+func (h *streamHasher) Run() {
+	buf := make([]byte, 64*1024) // 64KB read buffer
+	var offset int64
+
+	for i, part := range h.parts {
+		// Wait for this part to be written
+		h.mu.Lock()
+		for !h.completed[i] && !h.done {
+			h.cond.Wait()
+		}
+		if h.done {
+			h.mu.Unlock()
+			return
+		}
+		h.mu.Unlock()
+
+		// Read and hash this part (from page cache)
+		remaining := part.Size
+		for remaining > 0 {
+			n := int64(len(buf))
+			if n > remaining {
+				n = remaining
+			}
+			nr, err := h.file.ReadAt(buf[:n], offset)
+			if err != nil && err != io.EOF {
+				h.mu.Lock()
+				h.err = err
+				h.mu.Unlock()
+				return
+			}
+			h.hasher.Write(buf[:nr])
+			offset += int64(nr)
+			remaining -= int64(nr)
+			h.hashed.Store(offset)
+		}
+	}
+}
+
+// Stop signals the hasher to exit early.
+func (h *streamHasher) Stop() {
+	h.mu.Lock()
+	h.done = true
+	h.cond.Broadcast()
+	h.mu.Unlock()
+}
+
+// Hashed returns bytes hashed so far.
+func (h *streamHasher) Hashed() int64 {
+	return h.hashed.Load()
+}
+
+// Digest returns the computed hash.
+func (h *streamHasher) Digest() string {
+	return fmt.Sprintf("sha256:%x", h.hasher.Sum(nil))
+}
+
+// Err returns any error from hashing.
+func (h *streamHasher) Err() error {
+	h.mu.Lock()
+	defer h.mu.Unlock()
+	return h.err
 }
 
 func (p *blobDownloadPart) Name() string {
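To see the streamHasher pattern in isolation, here is a hypothetical, self-contained demo: a writer completes two parts out of order, and a sequential reader hashes the file as parts are marked complete, reading just-written bytes back from the page cache. The part type and sizes are invented for illustration; this is not the package's code.

```go
package main

import (
	"crypto/sha256"
	"fmt"
	"os"
	"sync"
)

// part is a stand-in for blobDownloadPart, reduced to what the demo needs.
type part struct{ Offset, Size int64 }

func main() {
	f, err := os.CreateTemp("", "blob")
	if err != nil {
		panic(err)
	}
	defer os.Remove(f.Name())

	parts := []part{{Offset: 0, Size: 4}, {Offset: 4, Size: 4}}
	completed := make([]bool, len(parts))
	var mu sync.Mutex
	cond := sync.NewCond(&mu)

	markComplete := func(i int) {
		mu.Lock()
		completed[i] = true
		cond.Broadcast()
		mu.Unlock()
	}

	// Simulate out-of-order downloads: part 1 lands before part 0.
	go func() {
		f.WriteAt([]byte("data"), parts[1].Offset)
		markComplete(1)
		f.WriteAt([]byte("blob"), parts[0].Offset)
		markComplete(0)
	}()

	// Sequential hasher: wait for each part in order, then read it back.
	h := sha256.New()
	buf := make([]byte, 64*1024)
	var offset int64
	for i, p := range parts {
		mu.Lock()
		for !completed[i] {
			cond.Wait()
		}
		mu.Unlock()
		n, _ := f.ReadAt(buf[:p.Size], offset)
		h.Write(buf[:n])
		offset += int64(n)
	}
	fmt.Printf("sha256:%x\n", h.Sum(nil)) // digest of "blobdata"
}
```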
@@ -196,14 +229,6 @@ func (p *blobDownloadPart) Name() string {
 	}, "-")
 }
 
-func (p *blobDownloadPart) StartsAt() int64 {
-	return p.Offset + p.Completed.Load()
-}
-
-func (p *blobDownloadPart) StopsAt() int64 {
-	return p.Offset + p.Size
-}
-
 func (p *blobDownloadPart) Write(b []byte) (n int, err error) {
 	n = len(b)
 	p.blobDownload.Completed.Add(int64(n))
@@ -357,24 +382,32 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *regis
 		return err
 	}
 
-	// Download chunks in parallel, hash while writing via ordered writer.
-	// Memory: (concurrency × part_size) + buffer. For 2GB files with defaults:
-	// 8 × 32MB + 64MB = 320MB max.
-	ow := newOrderedWriter(file, sha256.New())
+	// Download chunks to disk, hash by reading from page cache.
+	// Memory: ~64KB (hasher read buffer only), regardless of concurrency.
+	// The hasher follows behind the downloaders, reading recently-written
+	// data from OS page cache (RAM) rather than disk.
+	sh := newStreamHasher(file, b.Parts, b.Total)
+
+	// Start hasher goroutine
+	hashDone := make(chan struct{})
+	go func() {
+		sh.Run()
+		close(hashDone)
+	}()
 
 	g, inner := errgroup.WithContext(ctx)
 	g.SetLimit(downloadConcurrency)
 	for i := range b.Parts {
 		part := b.Parts[i]
 		if part.Completed.Load() == part.Size {
+			sh.MarkComplete(part.N)
 			continue
 		}
 
 		g.Go(func() error {
-			var data []byte
 			var err error
 			for try := 0; try < maxRetries; try++ {
-				data, err = b.downloadChunkToBuffer(inner, directURL, part)
+				err = b.downloadChunkToDisk(inner, directURL, file, part)
 				switch {
 				case errors.Is(err, context.Canceled), errors.Is(err, syscall.ENOSPC):
 					return err
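The fan-out above leans on errgroup's SetLimit to bound in-flight parts. A minimal sketch of that pattern, with a stub downloadPart standing in for the real chunk download (the stub and the part count are invented):

```go
package main

import (
	"context"
	"fmt"
	"runtime"

	"golang.org/x/sync/errgroup"
)

// downloadPart is a hypothetical stand-in for the real per-part download.
func downloadPart(ctx context.Context, n int) error {
	fmt.Println("downloading part", n)
	return nil
}

func main() {
	g, ctx := errgroup.WithContext(context.Background())
	g.SetLimit(2 * runtime.GOMAXPROCS(0)) // mirrors the new default above

	for i := 0; i < 16; i++ {
		i := i // copy for pre-Go 1.22 loop semantics
		g.Go(func() error {
			return downloadPart(ctx, i)
		})
	}
	if err := g.Wait(); err != nil {
		fmt.Println("download failed:", err)
	}
}
```

SetLimit makes g.Go block once the limit is reached, so the loop itself provides backpressure: no more than the limit are in flight, and the first error cancels ctx for the rest.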
@@ -387,22 +420,46 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *regis
 					time.Sleep(sleep)
 					continue
 				default:
-					err := ow.Submit(part.N, data)
-					data = nil // help GC
-					return err
+					sh.MarkComplete(part.N)
+					return nil
 				}
 			}
 			return fmt.Errorf("%w: %w", errMaxRetriesExceeded, err)
 		})
 	}
 
+	// Log progress periodically
+	progressDone := make(chan struct{})
+	go func() {
+		ticker := time.NewTicker(time.Second)
+		defer ticker.Stop()
+		for {
+			select {
+			case <-ticker.C:
+				dl := int(b.Completed.Load() * 100 / b.Total)
+				h := int(sh.Hashed() * 100 / b.Total)
+				slog.Info(fmt.Sprintf("progress: downloaded %d%% | hashed %d%%", dl, h))
+			case <-progressDone:
+				return
+			}
+		}
+	}()
+
 	if err := g.Wait(); err != nil {
+		close(progressDone)
+		sh.Stop()
 		return err
 	}
 
-	// Verify hash - no re-read needed, hash was computed while writing
-	computed := ow.Digest()
-	if computed != b.Digest {
+	// Wait for hasher to finish
+	<-hashDone
+	close(progressDone)
+	if err := sh.Err(); err != nil {
+		return err
+	}
+
+	// Verify hash
+	if computed := sh.Digest(); computed != b.Digest {
 		return fmt.Errorf("digest mismatch: got %s, want %s", computed, b.Digest)
 	}
 
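The progress logger added above is the usual ticker-plus-done-channel shape. Stripped to its essentials (the counters and total here are simulated, not the real blobDownload fields):

```go
package main

import (
	"fmt"
	"log/slog"
	"sync/atomic"
	"time"
)

func main() {
	var downloaded, hashed atomic.Int64 // stand-ins for b.Completed / sh.Hashed()
	total := int64(100)
	progressDone := make(chan struct{})

	// Reporter: tick once per second, exit when progressDone closes.
	go func() {
		ticker := time.NewTicker(time.Second)
		defer ticker.Stop()
		for {
			select {
			case <-ticker.C:
				dl := int(downloaded.Load() * 100 / total)
				h := int(hashed.Load() * 100 / total)
				slog.Info(fmt.Sprintf("progress: downloaded %d%% | hashed %d%%", dl, h))
			case <-progressDone:
				return
			}
		}
	}()

	// Simulated work: hashing trails downloading slightly.
	for i := int64(0); i < total; i++ {
		downloaded.Add(1)
		if i > 0 {
			hashed.Add(1)
		}
		time.Sleep(30 * time.Millisecond)
	}
	hashed.Add(1)
	close(progressDone)
}
```

One subtlety carried over from the hunk: on the success path the code waits on hashDone before reading the digest, so Sum never races with the hasher's final Write.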
@@ -424,11 +481,11 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *regis
 	return nil
 }
 
-// downloadChunkToBuffer downloads a part to a buffer, tracking progress and detecting stalls.
-func (b *blobDownload) downloadChunkToBuffer(ctx context.Context, requestURL *url.URL, part *blobDownloadPart) ([]byte, error) {
+// downloadChunkToDisk streams a part directly to disk at its offset.
+// Memory: ~32KB (read buffer only).
+func (b *blobDownload) downloadChunkToDisk(ctx context.Context, requestURL *url.URL, file *os.File, part *blobDownloadPart) error {
 	g, ctx := errgroup.WithContext(ctx)
 
-	var data []byte
 	g.Go(func() error {
 		req, err := http.NewRequestWithContext(ctx, http.MethodGet, requestURL.String(), nil)
 		if err != nil {
@@ -441,14 +498,17 @@ func (b *blobDownload) downloadChunkToBuffer(ctx context.Context, requestURL *ur
 		}
 		defer resp.Body.Close()
 
-		// Pre-allocate buffer for the part
-		data = make([]byte, 0, part.Size)
-		buf := make([]byte, 32*1024) // 32KB read buffer
+		w := io.NewOffsetWriter(file, part.Offset)
+		buf := make([]byte, 32*1024)
 
-		for {
+		var written int64
+		for written < part.Size {
 			n, err := resp.Body.Read(buf)
 			if n > 0 {
-				data = append(data, buf[:n]...)
+				if _, werr := w.Write(buf[:n]); werr != nil {
+					return werr
+				}
+				written += int64(n)
 				b.Completed.Add(int64(n))
 
 				part.lastUpdatedMu.Lock()
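The switch from an in-memory buffer to io.NewOffsetWriter (Go 1.20+) is what lets every part stream straight to its own region of the file: OffsetWriter translates sequential Writes into WriteAt calls from a fixed base offset, so concurrent parts never contend over a shared file cursor. A small demo (file name and contents invented):

```go
package main

import (
	"fmt"
	"io"
	"os"
)

func main() {
	f, err := os.CreateTemp("", "offset")
	if err != nil {
		panic(err)
	}
	defer os.Remove(f.Name())

	// Each writer owns a disjoint region of the same file.
	w0 := io.NewOffsetWriter(f, 0) // part 0 starts at offset 0
	w1 := io.NewOffsetWriter(f, 4) // part 1 starts at offset 4

	// Completing out of order is fine: writes land at absolute offsets.
	w1.Write([]byte("data"))
	w0.Write([]byte("blob"))

	b, _ := os.ReadFile(f.Name())
	fmt.Println(string(b)) // "blobdata"
}
```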
@@ -459,18 +519,13 @@ func (b *blobDownload) downloadChunkToBuffer(ctx context.Context, requestURL *ur
 				break
 			}
 			if err != nil {
-				// rollback progress
-				b.Completed.Add(-int64(len(data)))
+				b.Completed.Add(-written)
 				return err
 			}
 		}
 
 		part.Completed.Store(part.Size)
-		if err := b.writePart(part.Name(), part); err != nil {
-			return err
-		}
-
-		return nil
+		return b.writePart(part.Name(), part)
 	})
 
 	g.Go(func() error {
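The b.Completed.Add(-written) line above is the retry-accounting trick: bytes are counted as they land, and a failed attempt rewinds exactly its own tally so the next try never double-counts. A sketch of that invariant using a reader that fails mid-part (iotest stands in for a flaky connection; names are illustrative):

```go
package main

import (
	"bytes"
	"errors"
	"fmt"
	"io"
	"sync/atomic"
	"testing/iotest"
)

var completed atomic.Int64 // shared progress counter, as on blobDownload

// attempt mimics one download try: count bytes as they arrive, and on a
// non-EOF error rewind exactly this attempt's tally before returning.
func attempt(r io.Reader, size int64) error {
	buf := make([]byte, 4)
	var written int64
	for written < size {
		n, err := r.Read(buf)
		if n > 0 {
			written += int64(n)
			completed.Add(int64(n))
		}
		if errors.Is(err, io.EOF) {
			break
		}
		if err != nil {
			completed.Add(-written) // rollback: as if this try never ran
			return err
		}
	}
	return nil
}

func main() {
	data := bytes.Repeat([]byte("x"), 8)

	// First try fails after one byte; the counter is rewound to zero.
	_ = attempt(iotest.TimeoutReader(iotest.OneByteReader(bytes.NewReader(data))), 8)
	fmt.Println(completed.Load()) // 0

	// The retry succeeds; the part is counted exactly once.
	_ = attempt(bytes.NewReader(data), 8)
	fmt.Println(completed.Load()) // 8
}
```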
@@ -482,15 +537,11 @@ func (b *blobDownload) downloadChunkToBuffer(ctx context.Context, requestURL *ur
 			if part.Completed.Load() >= part.Size {
 				return nil
 			}
 
 			part.lastUpdatedMu.Lock()
 			lastUpdated := part.lastUpdated
 			part.lastUpdatedMu.Unlock()
 
 			if !lastUpdated.IsZero() && time.Since(lastUpdated) > 30*time.Second {
-				const msg = "%s part %d stalled; retrying. If this persists, press ctrl-c to exit, then 'ollama pull' to find a faster connection."
-				slog.Info(fmt.Sprintf(msg, b.Digest[7:19], part.N))
-				// reset last updated
+				slog.Info(fmt.Sprintf("%s part %d stalled; retrying", b.Digest[7:19], part.N))
 				part.lastUpdatedMu.Lock()
 				part.lastUpdated = time.Time{}
 				part.lastUpdatedMu.Unlock()
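The stall watchdog pattern above (writers stamp lastUpdated on progress; a monitor declares a stall when the stamp goes quiet and zeroes it so the warning fires once per stall) compresses to something like the following sketch. The cancel-and-retry plumbing lives outside this hunk and is omitted:

```go
package main

import (
	"sync"
	"time"
)

// watchdog is an illustrative reduction of the per-part stall tracking.
type watchdog struct {
	mu          sync.Mutex
	lastUpdated time.Time
}

// touch is called by the writer whenever bytes arrive.
func (w *watchdog) touch() {
	w.mu.Lock()
	w.lastUpdated = time.Now()
	w.mu.Unlock()
}

// stalled reports whether progress has gone quiet for longer than limit,
// resetting the stamp so each stall is reported only once.
func (w *watchdog) stalled(limit time.Duration) bool {
	w.mu.Lock()
	defer w.mu.Unlock()
	if w.lastUpdated.IsZero() || time.Since(w.lastUpdated) <= limit {
		return false
	}
	w.lastUpdated = time.Time{}
	return true
}

func main() {
	var w watchdog
	w.touch()
	time.Sleep(10 * time.Millisecond)
	println(w.stalled(5 * time.Millisecond)) // true: no progress within limit
	println(w.stalled(5 * time.Millisecond)) // false: stamp was reset
}
```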
@@ -502,10 +553,7 @@ func (b *blobDownload) downloadChunkToBuffer(ctx context.Context, requestURL *ur
 			}
 		})
 
-	if err := g.Wait(); err != nil {
-		return nil, err
-	}
-	return data, nil
+	return g.Wait()
 }
 
 func (b *blobDownload) newPart(offset, size int64) error {