diff --git a/server/download.go b/server/download.go
index 42d713c09..2253c1532 100644
--- a/server/download.go
+++ b/server/download.go
@@ -2,9 +2,11 @@ package server
 
 import (
 	"context"
+	"crypto/sha256"
 	"encoding/json"
 	"errors"
 	"fmt"
+	"hash"
 	"io"
 	"log/slog"
 	"math"
@@ -31,9 +33,45 @@ const maxRetries = 6
 
 var (
 	errMaxRetriesExceeded   = errors.New("max retries exceeded")
 	errPartStalled          = errors.New("part stalled")
+	errPartSlow             = errors.New("part slow, racing")
 	errMaxRedirectsExceeded = errors.New("maximum redirects exceeded (10) for directURL")
 )
 
+// speedTracker tracks per-part download speeds and computes a rolling median.
+type speedTracker struct {
+	mu     sync.Mutex
+	speeds []float64 // bytes per second
+}
+
+func (s *speedTracker) Record(bytesPerSec float64) {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+	s.speeds = append(s.speeds, bytesPerSec)
+	// Keep the last 100 samples
+	if len(s.speeds) > 100 {
+		s.speeds = s.speeds[1:]
+	}
+}
+
+func (s *speedTracker) Median() float64 {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+	if len(s.speeds) < 3 {
+		return 0 // not enough data
+	}
+	// Simple median: sort a copy and take the middle element
+	sorted := make([]float64, len(s.speeds))
+	copy(sorted, s.speeds)
+	for i := range sorted {
+		for j := i + 1; j < len(sorted); j++ {
+			if sorted[j] < sorted[i] {
+				sorted[i], sorted[j] = sorted[j], sorted[i]
+			}
+		}
+	}
+	return sorted[len(sorted)/2]
+}
+
 var blobDownloadManager sync.Map
 
 type blobDownload struct {
@@ -94,26 +132,127 @@ func (p *blobDownloadPart) UnmarshalJSON(b []byte) error {
 	return nil
 }
 
-const (
-	numDownloadParts          = 16
-	minDownloadPartSize int64 = 100 * format.MegaByte
-	maxDownloadPartSize int64 = 1000 * format.MegaByte
+// Part size and concurrency are tunable via environment variables,
+// e.g. OLLAMA_DOWNLOAD_PART_SIZE=128 (MB) and OLLAMA_DOWNLOAD_CONCURRENCY=16.
+var (
+	downloadPartSize    = int64(envInt("OLLAMA_DOWNLOAD_PART_SIZE", 64)) * format.MegaByte
+	downloadConcurrency = envInt("OLLAMA_DOWNLOAD_CONCURRENCY", 48)
 )
 
+func envInt(key string, defaultVal int) int {
+	if s := os.Getenv(key); s != "" {
+		if v, err := strconv.Atoi(s); err == nil {
+			return v
+		}
+	}
+	return defaultVal
+}
+
+// streamHasher reads a file sequentially and hashes it as chunks complete.
+// Memory usage: ~64KB (just the read buffer), regardless of file size or concurrency.
+// Works by reading from the OS page cache - data just written is still in RAM.
+type streamHasher struct {
+	file   *os.File
+	hasher hash.Hash
+	parts  []*blobDownloadPart
+	total  int64 // total bytes to hash
+	hashed atomic.Int64
+
+	mu        sync.Mutex
+	cond      *sync.Cond
+	completed []bool
+	done      bool
+	err       error
+}
+
+func newStreamHasher(file *os.File, parts []*blobDownloadPart, total int64) *streamHasher {
+	h := &streamHasher{
+		file:      file,
+		hasher:    sha256.New(),
+		parts:     parts,
+		total:     total,
+		completed: make([]bool, len(parts)),
+	}
+	h.cond = sync.NewCond(&h.mu)
+	return h
+}
+
+// MarkComplete signals that a part has been written to disk.
+func (h *streamHasher) MarkComplete(partIndex int) {
+	h.mu.Lock()
+	h.completed[partIndex] = true
+	h.cond.Broadcast()
+	h.mu.Unlock()
+}
+
+// Run reads and hashes the file sequentially. Call in a goroutine.
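+//
+// Minimal lifecycle sketch (names as defined in this file):
+//
+//	sh := newStreamHasher(file, parts, total)
+//	go sh.Run()
+//	// downloaders call sh.MarkComplete(part.N) as each part hits disk;
+//	// once Run returns, check sh.Err() and compare sh.Digest() against
+//	// the expected "sha256:..." string.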
+func (h *streamHasher) Run() {
+	buf := make([]byte, 64*1024) // 64KB read buffer
+	var offset int64
+
+	for i, part := range h.parts {
+		// Wait for this part to be written
+		h.mu.Lock()
+		for !h.completed[i] && !h.done {
+			h.cond.Wait()
+		}
+		if h.done {
+			h.mu.Unlock()
+			return
+		}
+		h.mu.Unlock()
+
+		// Read and hash this part (from the page cache)
+		remaining := part.Size
+		for remaining > 0 {
+			n := int64(len(buf))
+			if n > remaining {
+				n = remaining
+			}
+			nr, err := h.file.ReadAt(buf[:n], offset)
+			if err != nil && err != io.EOF {
+				h.mu.Lock()
+				h.err = err
+				h.mu.Unlock()
+				return
+			}
+			if nr == 0 && err == io.EOF {
+				// The file is shorter than the parts claim; without this
+				// guard a zero-byte read would loop forever because
+				// remaining never decreases.
+				h.mu.Lock()
+				h.err = io.ErrUnexpectedEOF
+				h.mu.Unlock()
+				return
+			}
+			h.hasher.Write(buf[:nr])
+			offset += int64(nr)
+			remaining -= int64(nr)
+			h.hashed.Store(offset)
+		}
+	}
+}
+
+// Stop signals the hasher to exit early.
+func (h *streamHasher) Stop() {
+	h.mu.Lock()
+	h.done = true
+	h.cond.Broadcast()
+	h.mu.Unlock()
+}
+
+// Hashed returns the number of bytes hashed so far.
+func (h *streamHasher) Hashed() int64 {
+	return h.hashed.Load()
+}
+
+// Digest returns the computed hash.
+func (h *streamHasher) Digest() string {
+	return fmt.Sprintf("sha256:%x", h.hasher.Sum(nil))
+}
+
+// Err returns any error encountered while hashing.
+func (h *streamHasher) Err() error {
+	h.mu.Lock()
+	defer h.mu.Unlock()
+	return h.err
+}
+
 func (p *blobDownloadPart) Name() string {
 	return strings.Join([]string{
 		p.blobDownload.Name, "partial", strconv.Itoa(p.N),
 	}, "-")
 }
 
-func (p *blobDownloadPart) StartsAt() int64 {
-	return p.Offset + p.Completed.Load()
-}
-
-func (p *blobDownloadPart) StopsAt() int64 {
-	return p.Offset + p.Size
-}
-
 func (p *blobDownloadPart) Write(b []byte) (n int, err error) {
 	n = len(b)
 	p.blobDownload.Completed.Add(int64(n))
@@ -151,14 +290,7 @@ func (b *blobDownload) Prepare(ctx context.Context, requestURL *url.URL, opts *r
 
 	b.Total, _ = strconv.ParseInt(resp.Header.Get("Content-Length"), 10, 64)
 
-	size := b.Total / numDownloadParts
-	switch {
-	case size < minDownloadPartSize:
-		size = minDownloadPartSize
-	case size > maxDownloadPartSize:
-		size = maxDownloadPartSize
-	}
-
+	size := downloadPartSize
 	var offset int64
 	for offset < b.Total {
 		if offset+size > b.Total {
@@ -220,9 +352,6 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *regis
 		return err
 	}
 	defer file.Close()
-	setSparse(file)
-
-	_ = file.Truncate(b.Total)
 
 	directURL, err := func() (*url.URL, error) {
 		ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
@@ -270,44 +399,106 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *regis
 		return err
 	}
 
+	// Download chunks to disk, hash by reading from the page cache.
+	// Memory: ~64KB (hasher read buffer only), regardless of concurrency.
+	// The hasher follows behind the downloaders, reading recently written
+	// data from the OS page cache (RAM) rather than disk.
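+	// If the hasher falls far enough behind that written data has been
+	// evicted from the page cache (see pageCacheWarningBytes below), reads
+	// fall through to disk: slower, but still correct.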
+	sh := newStreamHasher(file, b.Parts, b.Total)
+	tracker := &speedTracker{}
+
+	// Start the hasher goroutine
+	hashDone := make(chan struct{})
+	go func() {
+		sh.Run()
+		close(hashDone)
+	}()
+
+	// Log progress periodically.
+	// Page cache warning: if the spread exceeds 4GB, the hasher may be
+	// hitting disk instead of RAM.
+	const pageCacheWarningBytes = 4 << 30 // 4GB
+	progressDone := make(chan struct{})
+	go func() {
+		ticker := time.NewTicker(time.Second)
+		defer ticker.Stop()
+		for {
+			select {
+			case <-ticker.C:
+				downloaded := b.Completed.Load()
+				hashed := sh.Hashed()
+				dlPct := int(downloaded * 100 / b.Total)
+				hPct := int(hashed * 100 / b.Total)
+				spread := dlPct - hPct
+				spreadBytes := downloaded - hashed
+
+				slog.Debug(fmt.Sprintf("progress: downloaded %d%% | hashed %d%% | spread %d%%", dlPct, hPct, spread))
+				if spreadBytes > pageCacheWarningBytes {
+					slog.Debug("page cache pressure", "ahead", fmt.Sprintf("%.1fGB", float64(spreadBytes)/(1<<30)))
+				}
+			case <-progressDone:
+				return
+			}
+		}
+	}()
+
 	g, inner := errgroup.WithContext(ctx)
-	g.SetLimit(numDownloadParts)
+	g.SetLimit(downloadConcurrency)
 	for i := range b.Parts {
 		part := b.Parts[i]
 		if part.Completed.Load() == part.Size {
+			sh.MarkComplete(part.N)
 			continue
 		}
 
 		g.Go(func() error {
 			var err error
+			var slowRetries int
 			for try := 0; try < maxRetries; try++ {
-				w := io.NewOffsetWriter(file, part.StartsAt())
-				err = b.downloadChunk(inner, directURL, w, part)
+				// After 3 slow retries, stop checking slowness and let the part complete
+				skipSlowCheck := slowRetries >= 3
+				err = b.downloadChunkToDisk(inner, directURL, file, part, tracker, skipSlowCheck)
 				switch {
 				case errors.Is(err, context.Canceled), errors.Is(err, syscall.ENOSPC):
-					// return immediately if the context is canceled or the device is out of space
 					return err
 				case errors.Is(err, errPartStalled):
 					try--
 					continue
+				case errors.Is(err, errPartSlow):
+					// Kill the slow request and retry immediately (stays within the concurrency limit)
+					slowRetries++
+					try--
+					continue
 				case err != nil:
 					sleep := time.Second * time.Duration(math.Pow(2, float64(try)))
 					slog.Info(fmt.Sprintf("%s part %d attempt %d failed: %v, retrying in %s", b.Digest[7:19], part.N, try, err, sleep))
 					time.Sleep(sleep)
 					continue
 				default:
+					sh.MarkComplete(part.N)
 					return nil
 				}
 			}
-
 			return fmt.Errorf("%w: %w", errMaxRetriesExceeded, err)
 		})
 	}
 
 	if err := g.Wait(); err != nil {
+		close(progressDone)
+		sh.Stop()
 		return err
 	}
 
+	// Wait for the hasher to finish
+	<-hashDone
+	close(progressDone)
+	if err := sh.Err(); err != nil {
+		return err
+	}
+
+	// Verify the computed digest
+	if computed := sh.Digest(); computed != b.Digest {
+		return fmt.Errorf("digest mismatch: got %s, want %s", computed, b.Digest)
+	}
+
 	// explicitly close the file so we can rename it
 	if err := file.Close(); err != nil {
 		return err
@@ -326,38 +517,69 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *regis
 	return nil
 }
 
-func (b *blobDownload) downloadChunk(ctx context.Context, requestURL *url.URL, w io.Writer, part *blobDownloadPart) error {
+// downloadChunkToDisk streams a part directly to disk at its offset.
+// Memory: ~32KB (read buffer only).
+// If skipSlowCheck is true, don't flag slow parts (used after repeated slow retries).
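+// Note that retries restart a part from its beginning: part.Completed is only
+// stored once the whole part is on disk, so a killed slow or stalled request
+// re-downloads the full part instead of resuming mid-part.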
+func (b *blobDownload) downloadChunkToDisk(ctx context.Context, requestURL *url.URL, file *os.File, part *blobDownloadPart, tracker *speedTracker, skipSlowCheck bool) error {
 	g, ctx := errgroup.WithContext(ctx)
+	startTime := time.Now()
+	var bytesAtLastCheck atomic.Int64
+
 	g.Go(func() error {
 		req, err := http.NewRequestWithContext(ctx, http.MethodGet, requestURL.String(), nil)
 		if err != nil {
 			return err
 		}
-		req.Header.Set("Range", fmt.Sprintf("bytes=%d-%d", part.StartsAt(), part.StopsAt()-1))
+		req.Header.Set("Range", fmt.Sprintf("bytes=%d-%d", part.Offset, part.Offset+part.Size-1))
 		resp, err := http.DefaultClient.Do(req)
 		if err != nil {
 			return err
 		}
 		defer resp.Body.Close()
 
-		n, err := io.CopyN(w, io.TeeReader(resp.Body, part), part.Size-part.Completed.Load())
-		if err != nil && !errors.Is(err, context.Canceled) && !errors.Is(err, io.ErrUnexpectedEOF) {
-			// rollback progress
-			b.Completed.Add(-n)
-			return err
+		w := io.NewOffsetWriter(file, part.Offset)
+		buf := make([]byte, 32*1024)
+
+		var written int64
+		for written < part.Size {
+			n, err := resp.Body.Read(buf)
+			if n > 0 {
+				if _, werr := w.Write(buf[:n]); werr != nil {
+					return werr
+				}
+				written += int64(n)
+				b.Completed.Add(int64(n))
+				bytesAtLastCheck.Store(written)
+
+				part.lastUpdatedMu.Lock()
+				part.lastUpdated = time.Now()
+				part.lastUpdatedMu.Unlock()
+			}
+			if err == io.EOF {
+				break
+			}
+			if err != nil {
+				b.Completed.Add(-written)
+				return err
+			}
 		}
 
-		part.Completed.Add(n)
-		if err := b.writePart(part.Name(), part); err != nil {
-			return err
+		// A short body (EOF before part.Size bytes) must not be marked
+		// complete; roll back progress and let the caller retry.
+		if written < part.Size {
+			b.Completed.Add(-written)
+			return io.ErrUnexpectedEOF
+		}
+
+		// Record the observed speed for this part
+		elapsed := time.Since(startTime).Seconds()
+		if elapsed > 0 {
+			tracker.Record(float64(part.Size) / elapsed)
 		}
 
-		// return nil or context.Canceled or UnexpectedEOF (resumable)
-		return err
+		part.Completed.Store(part.Size)
+		return b.writePart(part.Name(), part)
 	})
 
 	g.Go(func() error {
 		ticker := time.NewTicker(time.Second)
+		defer ticker.Stop()
+		var lastBytes int64
+		checksWithoutProgress := 0
+
 		for {
 			select {
 			case <-ticker.C:
@@ -365,19 +587,35 @@ func (b *blobDownload) downloadChunk(ctx context.Context, requestURL *url.URL, w
 					return nil
 				}
 
-				part.lastUpdatedMu.Lock()
-				lastUpdated := part.lastUpdated
-				part.lastUpdatedMu.Unlock()
+				currentBytes := bytesAtLastCheck.Load()
 
-				if !lastUpdated.IsZero() && time.Since(lastUpdated) > 30*time.Second {
-					const msg = "%s part %d stalled; retrying. If this persists, press ctrl-c to exit, then 'ollama pull' to find a faster connection."
-					slog.Info(fmt.Sprintf(msg, b.Digest[7:19], part.N))
-					// reset last updated
-					part.lastUpdatedMu.Lock()
-					part.lastUpdated = time.Time{}
-					part.lastUpdatedMu.Unlock()
-					return errPartStalled
+				// Check for a stall (no progress for 10 seconds)
+				if currentBytes == lastBytes {
+					checksWithoutProgress++
+					if checksWithoutProgress >= 10 {
+						slog.Info(fmt.Sprintf("%s part %d stalled; retrying", b.Digest[7:19], part.N))
+						return errPartStalled
+					}
+				} else {
+					checksWithoutProgress = 0
 				}
+				lastBytes = currentBytes
+
+				// Check for slow speed after 5+ seconds (only for multi-part downloads).
+				// Skip if we've already retried for slowness too many times.
+				elapsed := time.Since(startTime).Seconds()
+				if !skipSlowCheck && elapsed >= 5 && currentBytes > 0 && len(b.Parts) > 1 {
+					currentSpeed := float64(currentBytes) / elapsed
+					median := tracker.Median()
+
+					// If this part is below 10% of the median speed, flag it as slow
+					if median > 0 && currentSpeed < median*0.1 {
+						slog.Info(fmt.Sprintf("%s part %d slow (%.0f KB/s vs median %.0f KB/s); retrying",
+							b.Digest[7:19], part.N, currentSpeed/1024, median/1024))
+						return errPartSlow
+					}
+				}
+
 			case <-ctx.Done():
 				return ctx.Err()
 			}
diff --git a/server/download_test.go b/server/download_test.go
new file mode 100644
index 000000000..d45e1113a
--- /dev/null
+++ b/server/download_test.go
@@ -0,0 +1,319 @@
+package server
+
+import (
+	"crypto/rand"
+	"crypto/sha256"
+	"fmt"
+	"os"
+	"sync"
+	"testing"
+	"time"
+)
+
+func TestSpeedTracker_Median(t *testing.T) {
+	s := &speedTracker{}
+
+	// Fewer than 3 samples returns 0
+	s.Record(100)
+	s.Record(200)
+	if got := s.Median(); got != 0 {
+		t.Errorf("expected 0 with < 3 samples, got %f", got)
+	}
+
+	// With 3+ samples, returns the median
+	s.Record(300)
+	// Samples: [100, 200, 300] -> median = 200
+	if got := s.Median(); got != 200 {
+		t.Errorf("expected median 200, got %f", got)
+	}
+
+	// Add more samples
+	s.Record(50)
+	s.Record(250)
+	// Samples: [100, 200, 300, 50, 250] sorted = [50, 100, 200, 250, 300] -> median = 200
+	if got := s.Median(); got != 200 {
+		t.Errorf("expected median 200, got %f", got)
+	}
+}
+
+func TestSpeedTracker_RollingWindow(t *testing.T) {
+	s := &speedTracker{}
+
+	// Add 105 samples (should keep only the last 100)
+	for i := 0; i < 105; i++ {
+		s.Record(float64(i))
+	}
+
+	s.mu.Lock()
+	if len(s.speeds) != 100 {
+		t.Errorf("expected 100 samples, got %d", len(s.speeds))
+	}
+	// The first sample should be 5 (0-4 were dropped)
+	if s.speeds[0] != 5 {
+		t.Errorf("expected first sample to be 5, got %f", s.speeds[0])
+	}
+	s.mu.Unlock()
+}
+
+func TestSpeedTracker_Concurrent(t *testing.T) {
+	s := &speedTracker{}
+
+	var wg sync.WaitGroup
+	for i := 0; i < 100; i++ {
+		wg.Add(1)
+		go func(v int) {
+			defer wg.Done()
+			s.Record(float64(v))
+			s.Median() // concurrent read
+		}(i)
+	}
+	wg.Wait()
+
+	// Should not panic, and should have reasonable state
+	s.mu.Lock()
+	if len(s.speeds) == 0 || len(s.speeds) > 100 {
+		t.Errorf("unexpected speeds length: %d", len(s.speeds))
+	}
+	s.mu.Unlock()
+}
+
+func TestStreamHasher_Sequential(t *testing.T) {
+	// Create a temp file
+	f, err := os.CreateTemp("", "streamhasher_test")
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer os.Remove(f.Name())
+	defer f.Close()
+
+	// Write test data
+	data := []byte("hello world, this is a test of the stream hasher")
+	if _, err := f.Write(data); err != nil {
+		t.Fatal(err)
+	}
+
+	// Create parts
+	parts := []*blobDownloadPart{
+		{Offset: 0, Size: int64(len(data))},
+	}
+
+	sh := newStreamHasher(f, parts, int64(len(data)))
+
+	// Mark complete and run
+	sh.MarkComplete(0)
+
+	done := make(chan struct{})
+	go func() {
+		sh.Run()
+		close(done)
+	}()
+	<-done
+
+	// Verify the digest
+	expected := fmt.Sprintf("sha256:%x", sha256.Sum256(data))
+	if got := sh.Digest(); got != expected {
+		t.Errorf("digest mismatch: got %s, want %s", got, expected)
+	}
+
+	if err := sh.Err(); err != nil {
+		t.Errorf("unexpected error: %v", err)
+	}
+}
+
+func TestStreamHasher_OutOfOrderCompletion(t *testing.T) {
+	// Create a temp file
+	f, err := os.CreateTemp("", "streamhasher_test")
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer os.Remove(f.Name())
+	defer f.Close()
+
+	// Write test data (3 parts of 10 bytes each)
+	data := []byte("0123456789ABCDEFGHIJabcdefghij")
+	if _, err := f.Write(data); err != nil {
+		t.Fatal(err)
+	}
+
+	// Create 3 parts
+	parts := []*blobDownloadPart{
+		{N: 0, Offset: 0, Size: 10},
+		{N: 1, Offset: 10, Size: 10},
+		{N: 2, Offset: 20, Size: 10},
+	}
+
+	sh := newStreamHasher(f, parts, int64(len(data)))
+
+	done := make(chan struct{})
+	go func() {
+		sh.Run()
+		close(done)
+	}()
+
+	// Mark parts complete out of order: 2, 0, 1
+	sh.MarkComplete(2)
+	sh.MarkComplete(0) // This should trigger hashing of part 0
+	sh.MarkComplete(1) // This should trigger hashing of parts 1 and 2
+
+	<-done
+
+	// Verify the digest
+	expected := fmt.Sprintf("sha256:%x", sha256.Sum256(data))
+	if got := sh.Digest(); got != expected {
+		t.Errorf("digest mismatch: got %s, want %s", got, expected)
+	}
+}
+
+func TestStreamHasher_Stop(t *testing.T) {
+	// Create a temp file
+	f, err := os.CreateTemp("", "streamhasher_test")
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer os.Remove(f.Name())
+	defer f.Close()
+
+	parts := []*blobDownloadPart{
+		{Offset: 0, Size: 100},
+	}
+
+	sh := newStreamHasher(f, parts, 100)
+
+	done := make(chan struct{})
+	go func() {
+		sh.Run()
+		close(done)
+	}()
+
+	// Stop without completing any parts
+	sh.Stop()
+	<-done
+
+	// Should exit cleanly without error
+	if err := sh.Err(); err != nil {
+		t.Errorf("unexpected error after Stop: %v", err)
+	}
+}
+
+func TestStreamHasher_HashedProgress(t *testing.T) {
+	// Create a temp file with known data
+	f, err := os.CreateTemp("", "streamhasher_test")
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer os.Remove(f.Name())
+	defer f.Close()
+
+	data := make([]byte, 1000)
+	rand.Read(data)
+	if _, err := f.Write(data); err != nil {
+		t.Fatal(err)
+	}
+
+	parts := []*blobDownloadPart{
+		{N: 0, Offset: 0, Size: 500},
+		{N: 1, Offset: 500, Size: 500},
+	}
+
+	sh := newStreamHasher(f, parts, 1000)
+
+	// Initially no progress
+	if got := sh.Hashed(); got != 0 {
+		t.Errorf("expected 0 hashed initially, got %d", got)
+	}
+
+	done := make(chan struct{})
+	go func() {
+		sh.Run()
+		close(done)
+	}()
+
+	// Complete part 0
+	sh.MarkComplete(0)
+
+	// Give the hasher time to process part 0; sleep between polls so this
+	// isn't a busy spin
+	for i := 0; i < 100; i++ {
+		if sh.Hashed() >= 500 {
+			break
+		}
+		time.Sleep(time.Millisecond)
+	}
+
+	// Complete part 1
+	sh.MarkComplete(1)
+	<-done
+
+	if got := sh.Hashed(); got != 1000 {
+		t.Errorf("expected 1000 hashed, got %d", got)
+	}
+}
+
+func BenchmarkSpeedTracker_Record(b *testing.B) {
+	s := &speedTracker{}
+	b.ResetTimer()
+	for i := 0; i < b.N; i++ {
+		s.Record(float64(i))
+	}
+}
+
+func BenchmarkSpeedTracker_Median(b *testing.B) {
+	s := &speedTracker{}
+	// Pre-populate with 100 samples
+	for i := 0; i < 100; i++ {
+		s.Record(float64(i))
+	}
+	b.ResetTimer()
+	for i := 0; i < b.N; i++ {
+		s.Median()
+	}
+}
+
+func BenchmarkStreamHasher(b *testing.B) {
+	// Create a temp file with test data
+	f, err := os.CreateTemp("", "streamhasher_bench")
+	if err != nil {
+		b.Fatal(err)
+	}
+	defer os.Remove(f.Name())
+	defer f.Close()
+
+	size := 64 * 1024 * 1024 // 64MB
+	data := make([]byte, size)
+	rand.Read(data)
+	if _, err := f.Write(data); err != nil {
+		b.Fatal(err)
+	}
+
+	parts := []*blobDownloadPart{
+		{Offset: 0, Size: int64(size)},
+	}
+
+	b.SetBytes(int64(size))
+	b.ResetTimer()
+
+	for i := 0; i < b.N; i++ {
+		sh := newStreamHasher(f, parts, int64(size))
+		sh.MarkComplete(0)
+
+		done := make(chan struct{})
+		go func() {
+			sh.Run()
+			close(done)
+		}()
+		<-done
+	}
+}
+
+func BenchmarkHashThroughput(b *testing.B) {
+	// Baseline: raw SHA256 throughput on this machine
+	size := 256 * 1024 * 1024 // 256MB
+	data := make([]byte, size)
+	rand.Read(data)
+
+	b.SetBytes(int64(size))
+	b.ResetTimer()
+
+	for i := 0; i < b.N; i++ {
+		h := sha256.New()
+		h.Write(data)
+		h.Sum(nil)
+	}
+}
diff --git a/server/images.go b/server/images.go
index 951f7ac6e..d3de232b1 100644
--- a/server/images.go
+++ b/server/images.go
@@ -620,9 +620,8 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
 		layers = append(layers, manifest.Config)
 	}
 
-	skipVerify := make(map[string]bool)
 	for _, layer := range layers {
-		cacheHit, err := downloadBlob(ctx, downloadOpts{
+		_, err := downloadBlob(ctx, downloadOpts{
 			mp:      mp,
 			digest:  layer.Digest,
 			regOpts: regOpts,
@@ -631,31 +630,12 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
 		if err != nil {
 			return err
 		}
-		skipVerify[layer.Digest] = cacheHit
 		delete(deleteMap, layer.Digest)
 	}
 	delete(deleteMap, manifest.Config.Digest)
 
-	fn(api.ProgressResponse{Status: "verifying sha256 digest"})
-	for _, layer := range layers {
-		if skipVerify[layer.Digest] {
-			continue
-		}
-		if err := verifyBlob(layer.Digest); err != nil {
-			if errors.Is(err, errDigestMismatch) {
-				// something went wrong, delete the blob
-				fp, err := GetBlobsPath(layer.Digest)
-				if err != nil {
-					return err
-				}
-				if err := os.Remove(fp); err != nil {
-					// log this, but return the original error
-					slog.Info(fmt.Sprintf("couldn't remove file with digest mismatch '%s': %v", fp, err))
-				}
-			}
-			return err
-		}
-	}
+	// Note: digest verification now happens inline during download in
+	// blobDownload.run() via the streamHasher, so no separate verification
+	// pass is needed.
 
 	fn(api.ProgressResponse{Status: "writing manifest"})
diff --git a/server/sparse_common.go b/server/sparse_common.go
deleted file mode 100644
index c88b2da0b..000000000
--- a/server/sparse_common.go
+++ /dev/null
@@ -1,8 +0,0 @@
-//go:build !windows
-
-package server
-
-import "os"
-
-func setSparse(*os.File) {
-}
diff --git a/server/sparse_windows.go b/server/sparse_windows.go
deleted file mode 100644
index f21cbbda7..000000000
--- a/server/sparse_windows.go
+++ /dev/null
@@ -1,17 +0,0 @@
-package server
-
-import (
-	"os"
-
-	"golang.org/x/sys/windows"
-)
-
-func setSparse(file *os.File) {
-	// exFat (and other FS types) don't support sparse files, so ignore errors
-	windows.DeviceIoControl( //nolint:errcheck
-		windows.Handle(file.Fd()), windows.FSCTL_SET_SPARSE,
-		nil, 0,
-		nil, 0,
-		nil, nil,
-	)
-}