From 2aee6c172b019bfe3f3b5a54b4feaa84bcf89dd6 Mon Sep 17 00:00:00 2001 From: jmorganca Date: Sat, 20 Dec 2025 16:16:34 -0800 Subject: [PATCH 1/4] server: stream hash verification during download Hash blob data while downloading (by trying to use the page cache as much as possible) instead of after, improving download speeds. Add configurable download concurrency (default 48) and part size (default 64MB) for faster downloads on high-bandwidth connections. --- server/download.go | 340 +++++++++++++++++++++++++++++++++------ server/download_test.go | 319 ++++++++++++++++++++++++++++++++++++ server/images.go | 26 +-- server/sparse_common.go | 8 - server/sparse_windows.go | 17 -- 5 files changed, 611 insertions(+), 99 deletions(-) create mode 100644 server/download_test.go delete mode 100644 server/sparse_common.go delete mode 100644 server/sparse_windows.go diff --git a/server/download.go b/server/download.go index 42d713c09..2253c1532 100644 --- a/server/download.go +++ b/server/download.go @@ -2,9 +2,11 @@ package server import ( "context" + "crypto/sha256" "encoding/json" "errors" "fmt" + "hash" "io" "log/slog" "math" @@ -31,9 +33,45 @@ const maxRetries = 6 var ( errMaxRetriesExceeded = errors.New("max retries exceeded") errPartStalled = errors.New("part stalled") + errPartSlow = errors.New("part slow, racing") errMaxRedirectsExceeded = errors.New("maximum redirects exceeded (10) for directURL") ) +// speedTracker tracks download speeds and computes rolling median. 
+type speedTracker struct { + mu sync.Mutex + speeds []float64 // bytes per second +} + +func (s *speedTracker) Record(bytesPerSec float64) { + s.mu.Lock() + defer s.mu.Unlock() + s.speeds = append(s.speeds, bytesPerSec) + // Keep last 100 samples + if len(s.speeds) > 100 { + s.speeds = s.speeds[1:] + } +} + +func (s *speedTracker) Median() float64 { + s.mu.Lock() + defer s.mu.Unlock() + if len(s.speeds) < 3 { + return 0 // not enough data + } + // Simple median: sort a copy and take middle + sorted := make([]float64, len(s.speeds)) + copy(sorted, s.speeds) + for i := range sorted { + for j := i + 1; j < len(sorted); j++ { + if sorted[j] < sorted[i] { + sorted[i], sorted[j] = sorted[j], sorted[i] + } + } + } + return sorted[len(sorted)/2] +} + var blobDownloadManager sync.Map type blobDownload struct { @@ -94,26 +132,127 @@ func (p *blobDownloadPart) UnmarshalJSON(b []byte) error { return nil } -const ( - numDownloadParts = 16 - minDownloadPartSize int64 = 100 * format.MegaByte - maxDownloadPartSize int64 = 1000 * format.MegaByte +var ( + downloadPartSize = int64(envInt("OLLAMA_DOWNLOAD_PART_SIZE", 64)) * format.MegaByte + downloadConcurrency = envInt("OLLAMA_DOWNLOAD_CONCURRENCY", 48) ) +func envInt(key string, defaultVal int) int { + if s := os.Getenv(key); s != "" { + if v, err := strconv.Atoi(s); err == nil { + return v + } + } + return defaultVal +} + +// streamHasher reads a file sequentially and hashes it as chunks complete. +// Memory usage: ~64KB (just the read buffer), regardless of file size or concurrency. +// Works by reading from OS page cache - data just written is still in RAM. 
+type streamHasher struct { + file *os.File + hasher hash.Hash + parts []*blobDownloadPart + total int64 // total bytes to hash + hashed atomic.Int64 + + mu sync.Mutex + cond *sync.Cond + completed []bool + done bool + err error +} + +func newStreamHasher(file *os.File, parts []*blobDownloadPart, total int64) *streamHasher { + h := &streamHasher{ + file: file, + hasher: sha256.New(), + parts: parts, + total: total, + completed: make([]bool, len(parts)), + } + h.cond = sync.NewCond(&h.mu) + return h +} + +// MarkComplete signals that a part has been written to disk. +func (h *streamHasher) MarkComplete(partIndex int) { + h.mu.Lock() + h.completed[partIndex] = true + h.cond.Broadcast() + h.mu.Unlock() +} + +// Run reads and hashes the file sequentially. Call in a goroutine. +func (h *streamHasher) Run() { + buf := make([]byte, 64*1024) // 64KB read buffer + var offset int64 + + for i, part := range h.parts { + // Wait for this part to be written + h.mu.Lock() + for !h.completed[i] && !h.done { + h.cond.Wait() + } + if h.done { + h.mu.Unlock() + return + } + h.mu.Unlock() + + // Read and hash this part (from page cache) + remaining := part.Size + for remaining > 0 { + n := int64(len(buf)) + if n > remaining { + n = remaining + } + nr, err := h.file.ReadAt(buf[:n], offset) + if err != nil && err != io.EOF { + h.mu.Lock() + h.err = err + h.mu.Unlock() + return + } + h.hasher.Write(buf[:nr]) + offset += int64(nr) + remaining -= int64(nr) + h.hashed.Store(offset) + } + } +} + +// Stop signals the hasher to exit early. +func (h *streamHasher) Stop() { + h.mu.Lock() + h.done = true + h.cond.Broadcast() + h.mu.Unlock() +} + +// Hashed returns bytes hashed so far. +func (h *streamHasher) Hashed() int64 { + return h.hashed.Load() +} + +// Digest returns the computed hash. +func (h *streamHasher) Digest() string { + return fmt.Sprintf("sha256:%x", h.hasher.Sum(nil)) +} + +// Err returns any error from hashing. 
+func (h *streamHasher) Err() error { + h.mu.Lock() + defer h.mu.Unlock() + return h.err +} + func (p *blobDownloadPart) Name() string { return strings.Join([]string{ p.blobDownload.Name, "partial", strconv.Itoa(p.N), }, "-") } -func (p *blobDownloadPart) StartsAt() int64 { - return p.Offset + p.Completed.Load() -} - -func (p *blobDownloadPart) StopsAt() int64 { - return p.Offset + p.Size -} - func (p *blobDownloadPart) Write(b []byte) (n int, err error) { n = len(b) p.blobDownload.Completed.Add(int64(n)) @@ -151,14 +290,7 @@ func (b *blobDownload) Prepare(ctx context.Context, requestURL *url.URL, opts *r b.Total, _ = strconv.ParseInt(resp.Header.Get("Content-Length"), 10, 64) - size := b.Total / numDownloadParts - switch { - case size < minDownloadPartSize: - size = minDownloadPartSize - case size > maxDownloadPartSize: - size = maxDownloadPartSize - } - + size := downloadPartSize var offset int64 for offset < b.Total { if offset+size > b.Total { @@ -220,9 +352,6 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *regis return err } defer file.Close() - setSparse(file) - - _ = file.Truncate(b.Total) directURL, err := func() (*url.URL, error) { ctx, cancel := context.WithTimeout(ctx, 30*time.Second) @@ -270,44 +399,106 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *regis return err } + // Download chunks to disk, hash by reading from page cache. + // Memory: ~64KB (hasher read buffer only), regardless of concurrency. + // The hasher follows behind the downloaders, reading recently-written + // data from OS page cache (RAM) rather than disk. 
+ sh := newStreamHasher(file, b.Parts, b.Total) + tracker := &speedTracker{} + + // Start hasher goroutine + hashDone := make(chan struct{}) + go func() { + sh.Run() + close(hashDone) + }() + + // Log progress periodically + // Page cache warning: if spread > 4GB, hasher may hit disk instead of RAM + const pageCacheWarningBytes = 4 << 30 // 4GB + progressDone := make(chan struct{}) + go func() { + ticker := time.NewTicker(time.Second) + defer ticker.Stop() + for { + select { + case <-ticker.C: + downloaded := b.Completed.Load() + hashed := sh.Hashed() + dlPct := int(downloaded * 100 / b.Total) + hPct := int(hashed * 100 / b.Total) + spread := dlPct - hPct + spreadBytes := downloaded - hashed + + slog.Debug(fmt.Sprintf("progress: downloaded %d%% | hashed %d%% | spread %d%%", dlPct, hPct, spread)) + if spreadBytes > pageCacheWarningBytes { + slog.Debug("page cache pressure", "ahead", fmt.Sprintf("%.1fGB", float64(spreadBytes)/(1<<30))) + } + case <-progressDone: + return + } + } + }() + g, inner := errgroup.WithContext(ctx) - g.SetLimit(numDownloadParts) + g.SetLimit(downloadConcurrency) for i := range b.Parts { part := b.Parts[i] if part.Completed.Load() == part.Size { + sh.MarkComplete(part.N) continue } g.Go(func() error { var err error + var slowRetries int for try := 0; try < maxRetries; try++ { - w := io.NewOffsetWriter(file, part.StartsAt()) - err = b.downloadChunk(inner, directURL, w, part) + // After 3 slow retries, stop checking slowness and let it complete + skipSlowCheck := slowRetries >= 3 + err = b.downloadChunkToDisk(inner, directURL, file, part, tracker, skipSlowCheck) switch { case errors.Is(err, context.Canceled), errors.Is(err, syscall.ENOSPC): - // return immediately if the context is canceled or the device is out of space return err case errors.Is(err, errPartStalled): try-- continue + case errors.Is(err, errPartSlow): + // Kill slow request, retry immediately (stays within concurrency limit) + slowRetries++ + try-- + continue case err != nil: 
sleep := time.Second * time.Duration(math.Pow(2, float64(try))) slog.Info(fmt.Sprintf("%s part %d attempt %d failed: %v, retrying in %s", b.Digest[7:19], part.N, try, err, sleep)) time.Sleep(sleep) continue default: + sh.MarkComplete(part.N) return nil } } - return fmt.Errorf("%w: %w", errMaxRetriesExceeded, err) }) } if err := g.Wait(); err != nil { + close(progressDone) + sh.Stop() return err } + // Wait for hasher to finish + <-hashDone + close(progressDone) + if err := sh.Err(); err != nil { + return err + } + + // Verify hash + if computed := sh.Digest(); computed != b.Digest { + return fmt.Errorf("digest mismatch: got %s, want %s", computed, b.Digest) + } + // explicitly close the file so we can rename it if err := file.Close(); err != nil { return err @@ -326,38 +517,69 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *regis return nil } -func (b *blobDownload) downloadChunk(ctx context.Context, requestURL *url.URL, w io.Writer, part *blobDownloadPart) error { +// downloadChunkToDisk streams a part directly to disk at its offset. +// Memory: ~32KB (read buffer only). +// If skipSlowCheck is true, don't flag slow parts (used after repeated slow retries). 
+func (b *blobDownload) downloadChunkToDisk(ctx context.Context, requestURL *url.URL, file *os.File, part *blobDownloadPart, tracker *speedTracker, skipSlowCheck bool) error { g, ctx := errgroup.WithContext(ctx) + startTime := time.Now() + var bytesAtLastCheck atomic.Int64 + g.Go(func() error { req, err := http.NewRequestWithContext(ctx, http.MethodGet, requestURL.String(), nil) if err != nil { return err } - req.Header.Set("Range", fmt.Sprintf("bytes=%d-%d", part.StartsAt(), part.StopsAt()-1)) + req.Header.Set("Range", fmt.Sprintf("bytes=%d-%d", part.Offset, part.Offset+part.Size-1)) resp, err := http.DefaultClient.Do(req) if err != nil { return err } defer resp.Body.Close() - n, err := io.CopyN(w, io.TeeReader(resp.Body, part), part.Size-part.Completed.Load()) - if err != nil && !errors.Is(err, context.Canceled) && !errors.Is(err, io.ErrUnexpectedEOF) { - // rollback progress - b.Completed.Add(-n) - return err + w := io.NewOffsetWriter(file, part.Offset) + buf := make([]byte, 32*1024) + + var written int64 + for written < part.Size { + n, err := resp.Body.Read(buf) + if n > 0 { + if _, werr := w.Write(buf[:n]); werr != nil { + return werr + } + written += int64(n) + b.Completed.Add(int64(n)) + bytesAtLastCheck.Store(written) + + part.lastUpdatedMu.Lock() + part.lastUpdated = time.Now() + part.lastUpdatedMu.Unlock() + } + if err == io.EOF { + break + } + if err != nil { + b.Completed.Add(-written) + return err + } } - part.Completed.Add(n) - if err := b.writePart(part.Name(), part); err != nil { - return err + // Record speed for this part + elapsed := time.Since(startTime).Seconds() + if elapsed > 0 { + tracker.Record(float64(part.Size) / elapsed) } - // return nil or context.Canceled or UnexpectedEOF (resumable) - return err + part.Completed.Store(part.Size) + return b.writePart(part.Name(), part) }) g.Go(func() error { ticker := time.NewTicker(time.Second) + defer ticker.Stop() + var lastBytes int64 + checksWithoutProgress := 0 + for { select { case <-ticker.C: 
@@ -365,19 +587,35 @@ func (b *blobDownload) downloadChunk(ctx context.Context, requestURL *url.URL, w return nil } - part.lastUpdatedMu.Lock() - lastUpdated := part.lastUpdated - part.lastUpdatedMu.Unlock() + currentBytes := bytesAtLastCheck.Load() - if !lastUpdated.IsZero() && time.Since(lastUpdated) > 30*time.Second { - const msg = "%s part %d stalled; retrying. If this persists, press ctrl-c to exit, then 'ollama pull' to find a faster connection." - slog.Info(fmt.Sprintf(msg, b.Digest[7:19], part.N)) - // reset last updated - part.lastUpdatedMu.Lock() - part.lastUpdated = time.Time{} - part.lastUpdatedMu.Unlock() - return errPartStalled + // Check for stall (no progress for 10 seconds) + if currentBytes == lastBytes { + checksWithoutProgress++ + if checksWithoutProgress >= 10 { + slog.Info(fmt.Sprintf("%s part %d stalled; retrying", b.Digest[7:19], part.N)) + return errPartStalled + } + } else { + checksWithoutProgress = 0 } + lastBytes = currentBytes + + // Check for slow speed after 5+ seconds (only for multi-part downloads) + // Skip if we've already retried for slowness too many times + elapsed := time.Since(startTime).Seconds() + if !skipSlowCheck && elapsed >= 5 && currentBytes > 0 && len(b.Parts) > 1 { + currentSpeed := float64(currentBytes) / elapsed + median := tracker.Median() + + // If we're below 10% of median speed, flag as slow + if median > 0 && currentSpeed < median*0.1 { + slog.Info(fmt.Sprintf("%s part %d slow (%.0f KB/s vs median %.0f KB/s); retrying", + b.Digest[7:19], part.N, currentSpeed/1024, median/1024)) + return errPartSlow + } + } + case <-ctx.Done(): return ctx.Err() } diff --git a/server/download_test.go b/server/download_test.go new file mode 100644 index 000000000..d45e1113a --- /dev/null +++ b/server/download_test.go @@ -0,0 +1,319 @@ +package server + +import ( + "crypto/rand" + "crypto/sha256" + "fmt" + "os" + "sync" + "testing" +) + +func TestSpeedTracker_Median(t *testing.T) { + s := &speedTracker{} + + // Less than 3 
samples returns 0 + s.Record(100) + s.Record(200) + if got := s.Median(); got != 0 { + t.Errorf("expected 0 with < 3 samples, got %f", got) + } + + // With 3+ samples, returns median + s.Record(300) + // Samples: [100, 200, 300] -> median = 200 + if got := s.Median(); got != 200 { + t.Errorf("expected median 200, got %f", got) + } + + // Add more samples + s.Record(50) + s.Record(250) + // Samples: [100, 200, 300, 50, 250] sorted = [50, 100, 200, 250, 300] -> median = 200 + if got := s.Median(); got != 200 { + t.Errorf("expected median 200, got %f", got) + } +} + +func TestSpeedTracker_RollingWindow(t *testing.T) { + s := &speedTracker{} + + // Add 105 samples (should keep only last 100) + for i := 0; i < 105; i++ { + s.Record(float64(i)) + } + + s.mu.Lock() + if len(s.speeds) != 100 { + t.Errorf("expected 100 samples, got %d", len(s.speeds)) + } + // First sample should be 5 (0-4 were dropped) + if s.speeds[0] != 5 { + t.Errorf("expected first sample to be 5, got %f", s.speeds[0]) + } + s.mu.Unlock() +} + +func TestSpeedTracker_Concurrent(t *testing.T) { + s := &speedTracker{} + + var wg sync.WaitGroup + for i := 0; i < 100; i++ { + wg.Add(1) + go func(v int) { + defer wg.Done() + s.Record(float64(v)) + s.Median() // concurrent read + }(i) + } + wg.Wait() + + // Should not panic, and should have reasonable state + s.mu.Lock() + if len(s.speeds) == 0 || len(s.speeds) > 100 { + t.Errorf("unexpected speeds length: %d", len(s.speeds)) + } + s.mu.Unlock() +} + +func TestStreamHasher_Sequential(t *testing.T) { + // Create temp file + f, err := os.CreateTemp("", "streamhasher_test") + if err != nil { + t.Fatal(err) + } + defer os.Remove(f.Name()) + defer f.Close() + + // Write test data + data := []byte("hello world, this is a test of the stream hasher") + if _, err := f.Write(data); err != nil { + t.Fatal(err) + } + + // Create parts + parts := []*blobDownloadPart{ + {Offset: 0, Size: int64(len(data))}, + } + + sh := newStreamHasher(f, parts, int64(len(data))) + + // 
Mark complete and run + sh.MarkComplete(0) + + done := make(chan struct{}) + go func() { + sh.Run() + close(done) + }() + <-done + + // Verify digest + expected := fmt.Sprintf("sha256:%x", sha256.Sum256(data)) + if got := sh.Digest(); got != expected { + t.Errorf("digest mismatch: got %s, want %s", got, expected) + } + + if err := sh.Err(); err != nil { + t.Errorf("unexpected error: %v", err) + } +} + +func TestStreamHasher_OutOfOrderCompletion(t *testing.T) { + // Create temp file + f, err := os.CreateTemp("", "streamhasher_test") + if err != nil { + t.Fatal(err) + } + defer os.Remove(f.Name()) + defer f.Close() + + // Write test data (3 parts of 10 bytes each) + data := []byte("0123456789ABCDEFGHIJabcdefghij") + if _, err := f.Write(data); err != nil { + t.Fatal(err) + } + + // Create 3 parts + parts := []*blobDownloadPart{ + {N: 0, Offset: 0, Size: 10}, + {N: 1, Offset: 10, Size: 10}, + {N: 2, Offset: 20, Size: 10}, + } + + sh := newStreamHasher(f, parts, int64(len(data))) + + done := make(chan struct{}) + go func() { + sh.Run() + close(done) + }() + + // Mark parts complete out of order: 2, 0, 1 + sh.MarkComplete(2) + sh.MarkComplete(0) // This should trigger hashing of part 0 + sh.MarkComplete(1) // This should trigger hashing of parts 1 and 2 + + <-done + + // Verify digest + expected := fmt.Sprintf("sha256:%x", sha256.Sum256(data)) + if got := sh.Digest(); got != expected { + t.Errorf("digest mismatch: got %s, want %s", got, expected) + } +} + +func TestStreamHasher_Stop(t *testing.T) { + // Create temp file + f, err := os.CreateTemp("", "streamhasher_test") + if err != nil { + t.Fatal(err) + } + defer os.Remove(f.Name()) + defer f.Close() + + parts := []*blobDownloadPart{ + {Offset: 0, Size: 100}, + } + + sh := newStreamHasher(f, parts, 100) + + done := make(chan struct{}) + go func() { + sh.Run() + close(done) + }() + + // Stop without completing any parts + sh.Stop() + <-done + + // Should exit cleanly without error + if err := sh.Err(); err != nil { + 
t.Errorf("unexpected error after Stop: %v", err) + } +} + +func TestStreamHasher_HashedProgress(t *testing.T) { + // Create temp file with known data + f, err := os.CreateTemp("", "streamhasher_test") + if err != nil { + t.Fatal(err) + } + defer os.Remove(f.Name()) + defer f.Close() + + data := make([]byte, 1000) + rand.Read(data) + if _, err := f.Write(data); err != nil { + t.Fatal(err) + } + + parts := []*blobDownloadPart{ + {N: 0, Offset: 0, Size: 500}, + {N: 1, Offset: 500, Size: 500}, + } + + sh := newStreamHasher(f, parts, 1000) + + // Initially no progress + if got := sh.Hashed(); got != 0 { + t.Errorf("expected 0 hashed initially, got %d", got) + } + + done := make(chan struct{}) + go func() { + sh.Run() + close(done) + }() + + // Complete part 0 + sh.MarkComplete(0) + + // Give hasher time to process + for i := 0; i < 100; i++ { + if sh.Hashed() >= 500 { + break + } + } + + // Complete part 1 + sh.MarkComplete(1) + <-done + + if got := sh.Hashed(); got != 1000 { + t.Errorf("expected 1000 hashed, got %d", got) + } +} + +func BenchmarkSpeedTracker_Record(b *testing.B) { + s := &speedTracker{} + b.ResetTimer() + for i := 0; i < b.N; i++ { + s.Record(float64(i)) + } +} + +func BenchmarkSpeedTracker_Median(b *testing.B) { + s := &speedTracker{} + // Pre-populate with 100 samples + for i := 0; i < 100; i++ { + s.Record(float64(i)) + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + s.Median() + } +} + +func BenchmarkStreamHasher(b *testing.B) { + // Create temp file with test data + f, err := os.CreateTemp("", "streamhasher_bench") + if err != nil { + b.Fatal(err) + } + defer os.Remove(f.Name()) + defer f.Close() + + size := 64 * 1024 * 1024 // 64MB + data := make([]byte, size) + rand.Read(data) + if _, err := f.Write(data); err != nil { + b.Fatal(err) + } + + parts := []*blobDownloadPart{ + {Offset: 0, Size: int64(size)}, + } + + b.SetBytes(int64(size)) + b.ResetTimer() + + for i := 0; i < b.N; i++ { + sh := newStreamHasher(f, parts, int64(size)) + 
sh.MarkComplete(0) + + done := make(chan struct{}) + go func() { + sh.Run() + close(done) + }() + <-done + } +} + +func BenchmarkHashThroughput(b *testing.B) { + // Baseline: raw SHA256 throughput on this machine + size := 256 * 1024 * 1024 // 256MB + data := make([]byte, size) + rand.Read(data) + + b.SetBytes(int64(size)) + b.ResetTimer() + + for i := 0; i < b.N; i++ { + h := sha256.New() + h.Write(data) + h.Sum(nil) + } +} diff --git a/server/images.go b/server/images.go index 951f7ac6e..d3de232b1 100644 --- a/server/images.go +++ b/server/images.go @@ -620,9 +620,8 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu layers = append(layers, manifest.Config) } - skipVerify := make(map[string]bool) for _, layer := range layers { - cacheHit, err := downloadBlob(ctx, downloadOpts{ + _, err := downloadBlob(ctx, downloadOpts{ mp: mp, digest: layer.Digest, regOpts: regOpts, @@ -631,31 +630,12 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu if err != nil { return err } - skipVerify[layer.Digest] = cacheHit delete(deleteMap, layer.Digest) } delete(deleteMap, manifest.Config.Digest) - fn(api.ProgressResponse{Status: "verifying sha256 digest"}) - for _, layer := range layers { - if skipVerify[layer.Digest] { - continue - } - if err := verifyBlob(layer.Digest); err != nil { - if errors.Is(err, errDigestMismatch) { - // something went wrong, delete the blob - fp, err := GetBlobsPath(layer.Digest) - if err != nil { - return err - } - if err := os.Remove(fp); err != nil { - // log this, but return the original error - slog.Info(fmt.Sprintf("couldn't remove file with digest mismatch '%s': %v", fp, err)) - } - } - return err - } - } + // Note: Digest verification now happens inline during download in blobDownload.run() + // via the orderedWriter, so no separate verification pass is needed. 
fn(api.ProgressResponse{Status: "writing manifest"}) diff --git a/server/sparse_common.go b/server/sparse_common.go deleted file mode 100644 index c88b2da0b..000000000 --- a/server/sparse_common.go +++ /dev/null @@ -1,8 +0,0 @@ -//go:build !windows - -package server - -import "os" - -func setSparse(*os.File) { -} diff --git a/server/sparse_windows.go b/server/sparse_windows.go deleted file mode 100644 index f21cbbda7..000000000 --- a/server/sparse_windows.go +++ /dev/null @@ -1,17 +0,0 @@ -package server - -import ( - "os" - - "golang.org/x/sys/windows" -) - -func setSparse(file *os.File) { - // exFat (and other FS types) don't support sparse files, so ignore errors - windows.DeviceIoControl( //nolint:errcheck - windows.Handle(file.Fd()), windows.FSCTL_SET_SPARSE, - nil, 0, - nil, 0, - nil, nil, - ) -} From 9a8c2a46354eed7657715797d39ed17f92d0cc03 Mon Sep 17 00:00:00 2001 From: jmorganca Date: Sat, 20 Dec 2025 22:00:35 -0800 Subject: [PATCH 2/4] revert unwanted changes --- server/download.go | 50 +++++----------------------------------- server/sparse_common.go | 8 +++++++ server/sparse_windows.go | 17 ++++++++++++++ 3 files changed, 31 insertions(+), 44 deletions(-) create mode 100644 server/sparse_common.go create mode 100644 server/sparse_windows.go diff --git a/server/download.go b/server/download.go index 2253c1532..382fdaa9f 100644 --- a/server/download.go +++ b/server/download.go @@ -15,6 +15,7 @@ import ( "net/url" "os" "path/filepath" + "slices" "strconv" "strings" "sync" @@ -59,16 +60,9 @@ func (s *speedTracker) Median() float64 { if len(s.speeds) < 3 { return 0 // not enough data } - // Simple median: sort a copy and take middle - sorted := make([]float64, len(s.speeds)) - copy(sorted, s.speeds) - for i := range sorted { - for j := i + 1; j < len(sorted); j++ { - if sorted[j] < sorted[i] { - sorted[i], sorted[j] = sorted[j], sorted[i] - } - } - } + + sorted := slices.Clone(s.speeds) + slices.Sort(sorted) return sorted[len(sorted)/2] } @@ -183,7 +177,7 @@ 
func (h *streamHasher) MarkComplete(partIndex int) { h.mu.Unlock() } -// Run reads and hashes the file sequentially. Call in a goroutine. +// Run reads and hashes the file sequentially func (h *streamHasher) Run() { buf := make([]byte, 64*1024) // 64KB read buffer var offset int64 @@ -399,47 +393,18 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *regis return err } - // Download chunks to disk, hash by reading from page cache. - // Memory: ~64KB (hasher read buffer only), regardless of concurrency. + // Download chunks to disk, hash sequentially // The hasher follows behind the downloaders, reading recently-written // data from OS page cache (RAM) rather than disk. sh := newStreamHasher(file, b.Parts, b.Total) tracker := &speedTracker{} - // Start hasher goroutine hashDone := make(chan struct{}) go func() { sh.Run() close(hashDone) }() - // Log progress periodically - // Page cache warning: if spread > 4GB, hasher may hit disk instead of RAM - const pageCacheWarningBytes = 4 << 30 // 4GB - progressDone := make(chan struct{}) - go func() { - ticker := time.NewTicker(time.Second) - defer ticker.Stop() - for { - select { - case <-ticker.C: - downloaded := b.Completed.Load() - hashed := sh.Hashed() - dlPct := int(downloaded * 100 / b.Total) - hPct := int(hashed * 100 / b.Total) - spread := dlPct - hPct - spreadBytes := downloaded - hashed - - slog.Debug(fmt.Sprintf("progress: downloaded %d%% | hashed %d%% | spread %d%%", dlPct, hPct, spread)) - if spreadBytes > pageCacheWarningBytes { - slog.Debug("page cache pressure", "ahead", fmt.Sprintf("%.1fGB", float64(spreadBytes)/(1<<30))) - } - case <-progressDone: - return - } - } - }() - g, inner := errgroup.WithContext(ctx) g.SetLimit(downloadConcurrency) for i := range b.Parts { @@ -482,14 +447,12 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *regis } if err := g.Wait(); err != nil { - close(progressDone) sh.Stop() return err } // Wait for hasher to finish 
<-hashDone - close(progressDone) if err := sh.Err(); err != nil { return err } @@ -518,7 +481,6 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *regis } // downloadChunkToDisk streams a part directly to disk at its offset. -// Memory: ~32KB (read buffer only). // If skipSlowCheck is true, don't flag slow parts (used after repeated slow retries). func (b *blobDownload) downloadChunkToDisk(ctx context.Context, requestURL *url.URL, file *os.File, part *blobDownloadPart, tracker *speedTracker, skipSlowCheck bool) error { g, ctx := errgroup.WithContext(ctx) diff --git a/server/sparse_common.go b/server/sparse_common.go new file mode 100644 index 000000000..c88b2da0b --- /dev/null +++ b/server/sparse_common.go @@ -0,0 +1,8 @@ +//go:build !windows + +package server + +import "os" + +func setSparse(*os.File) { +} diff --git a/server/sparse_windows.go b/server/sparse_windows.go new file mode 100644 index 000000000..f21cbbda7 --- /dev/null +++ b/server/sparse_windows.go @@ -0,0 +1,17 @@ +package server + +import ( + "os" + + "golang.org/x/sys/windows" +) + +func setSparse(file *os.File) { + // exFat (and other FS types) don't support sparse files, so ignore errors + windows.DeviceIoControl( //nolint:errcheck + windows.Handle(file.Fd()), windows.FSCTL_SET_SPARSE, + nil, 0, + nil, 0, + nil, nil, + ) +} From bdb9ea4772b40141b2a3eab88edbc7c861c61a48 Mon Sep 17 00:00:00 2001 From: jmorganca Date: Sat, 20 Dec 2025 22:24:31 -0800 Subject: [PATCH 3/4] cleanup --- server/download.go | 54 +++++++++++++++++------------------------ server/download_test.go | 49 +++++++++++++++++++------------------ server/images.go | 30 ++--------------------- 3 files changed, 49 insertions(+), 84 deletions(-) diff --git a/server/download.go b/server/download.go index 382fdaa9f..3996a8334 100644 --- a/server/download.go +++ b/server/download.go @@ -34,7 +34,7 @@ const maxRetries = 6 var ( errMaxRetriesExceeded = errors.New("max retries exceeded") errPartStalled = 
errors.New("part stalled") - errPartSlow = errors.New("part slow, racing") + errPartSlow = errors.New("part too slow") errMaxRedirectsExceeded = errors.New("maximum redirects exceeded (10) for directURL") ) @@ -48,8 +48,8 @@ func (s *speedTracker) Record(bytesPerSec float64) { s.mu.Lock() defer s.mu.Unlock() s.speeds = append(s.speeds, bytesPerSec) - // Keep last 100 samples - if len(s.speeds) > 100 { + // Keep last 30 samples (flushes stale speeds faster when conditions change) + if len(s.speeds) > 30 { s.speeds = s.speeds[1:] } } @@ -57,8 +57,8 @@ func (s *speedTracker) Record(bytesPerSec float64) { func (s *speedTracker) Median() float64 { s.mu.Lock() defer s.mu.Unlock() - if len(s.speeds) < 3 { - return 0 // not enough data + if len(s.speeds) < 10 { + return 0 // not enough data for reliable median } sorted := slices.Clone(s.speeds) @@ -90,9 +90,6 @@ type blobDownloadPart struct { Size int64 Completed atomic.Int64 - lastUpdatedMu sync.Mutex - lastUpdated time.Time - *blobDownload `json:"-"` } @@ -128,7 +125,7 @@ func (p *blobDownloadPart) UnmarshalJSON(b []byte) error { var ( downloadPartSize = int64(envInt("OLLAMA_DOWNLOAD_PART_SIZE", 64)) * format.MegaByte - downloadConcurrency = envInt("OLLAMA_DOWNLOAD_CONCURRENCY", 48) + downloadConcurrency = envInt("OLLAMA_DOWNLOAD_CONCURRENCY", 32) ) func envInt(key string, defaultVal int) int { @@ -142,7 +139,7 @@ func envInt(key string, defaultVal int) int { // streamHasher reads a file sequentially and hashes it as chunks complete. // Memory usage: ~64KB (just the read buffer), regardless of file size or concurrency. -// Works by reading from OS page cache - data just written is still in RAM. +// Works by trying to read from OS page cache - data just written should still be in RAM. type streamHasher struct { file *os.File hasher hash.Hash @@ -169,8 +166,8 @@ func newStreamHasher(file *os.File, parts []*blobDownloadPart, total int64) *str return h } -// MarkComplete signals that a part has been written to disk. 
-func (h *streamHasher) MarkComplete(partIndex int) { +// Done signals that a part has been written to disk. +func (h *streamHasher) Done(partIndex int) { h.mu.Lock() h.completed[partIndex] = true h.cond.Broadcast() @@ -194,7 +191,7 @@ func (h *streamHasher) Run() { } h.mu.Unlock() - // Read and hash this part (from page cache) + // Read and hash part remaining := part.Size for remaining > 0 { n := int64(len(buf)) @@ -250,9 +247,6 @@ func (p *blobDownloadPart) Name() string { func (p *blobDownloadPart) Write(b []byte) (n int, err error) { n = len(b) p.blobDownload.Completed.Add(int64(n)) - p.lastUpdatedMu.Lock() - p.lastUpdated = time.Now() - p.lastUpdatedMu.Unlock() return n, nil } @@ -410,7 +404,7 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *regis for i := range b.Parts { part := b.Parts[i] if part.Completed.Load() == part.Size { - sh.MarkComplete(part.N) + sh.Done(part.N) continue } @@ -420,7 +414,7 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *regis for try := 0; try < maxRetries; try++ { // After 3 slow retries, stop checking slowness and let it complete skipSlowCheck := slowRetries >= 3 - err = b.downloadChunkToDisk(inner, directURL, file, part, tracker, skipSlowCheck) + err = b.downloadChunk(inner, directURL, file, part, tracker, skipSlowCheck) switch { case errors.Is(err, context.Canceled), errors.Is(err, syscall.ENOSPC): return err @@ -438,7 +432,7 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *regis time.Sleep(sleep) continue default: - sh.MarkComplete(part.N) + sh.Done(part.N) return nil } } @@ -480,9 +474,9 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *regis return nil } -// downloadChunkToDisk streams a part directly to disk at its offset. +// downloadChunk streams a part directly to disk at its offset. // If skipSlowCheck is true, don't flag slow parts (used after repeated slow retries). 
-func (b *blobDownload) downloadChunkToDisk(ctx context.Context, requestURL *url.URL, file *os.File, part *blobDownloadPart, tracker *speedTracker, skipSlowCheck bool) error { +func (b *blobDownload) downloadChunk(ctx context.Context, requestURL *url.URL, file *os.File, part *blobDownloadPart, tracker *speedTracker, skipSlowCheck bool) error { g, ctx := errgroup.WithContext(ctx) startTime := time.Now() var bytesAtLastCheck atomic.Int64 @@ -512,10 +506,6 @@ func (b *blobDownload) downloadChunkToDisk(ctx context.Context, requestURL *url. written += int64(n) b.Completed.Add(int64(n)) bytesAtLastCheck.Store(written) - - part.lastUpdatedMu.Lock() - part.lastUpdated = time.Now() - part.lastUpdatedMu.Unlock() } if err == io.EOF { break @@ -663,21 +653,21 @@ type downloadOpts struct { } // downloadBlob downloads a blob from the registry and stores it in the blobs directory -func downloadBlob(ctx context.Context, opts downloadOpts) (cacheHit bool, _ error) { +func downloadBlob(ctx context.Context, opts downloadOpts) error { if opts.digest == "" { - return false, fmt.Errorf(("%s: %s"), opts.mp.GetNamespaceRepository(), "digest is empty") + return fmt.Errorf(("%s: %s"), opts.mp.GetNamespaceRepository(), "digest is empty") } fp, err := GetBlobsPath(opts.digest) if err != nil { - return false, err + return err } fi, err := os.Stat(fp) switch { case errors.Is(err, os.ErrNotExist): case err != nil: - return false, err + return err default: opts.fn(api.ProgressResponse{ Status: fmt.Sprintf("pulling %s", opts.digest[7:19]), @@ -686,7 +676,7 @@ func downloadBlob(ctx context.Context, opts downloadOpts) (cacheHit bool, _ erro Completed: fi.Size(), }) - return true, nil + return nil } data, ok := blobDownloadManager.LoadOrStore(opts.digest, &blobDownload{Name: fp, Digest: opts.digest}) @@ -696,12 +686,12 @@ func downloadBlob(ctx context.Context, opts downloadOpts) (cacheHit bool, _ erro requestURL = requestURL.JoinPath("v2", opts.mp.GetNamespaceRepository(), "blobs", opts.digest) if 
err := download.Prepare(ctx, requestURL, opts.regOpts); err != nil { blobDownloadManager.Delete(opts.digest) - return false, err + return err } //nolint:contextcheck go download.Run(context.Background(), requestURL, opts.regOpts) } - return false, download.Wait(ctx, opts.fn) + return download.Wait(ctx, opts.fn) } diff --git a/server/download_test.go b/server/download_test.go index d45e1113a..56842f237 100644 --- a/server/download_test.go +++ b/server/download_test.go @@ -12,40 +12,41 @@ import ( func TestSpeedTracker_Median(t *testing.T) { s := &speedTracker{} - // Less than 3 samples returns 0 - s.Record(100) - s.Record(200) + // Less than 10 samples returns 0 + for i := 0; i < 9; i++ { + s.Record(float64(100 + i*10)) + } if got := s.Median(); got != 0 { - t.Errorf("expected 0 with < 3 samples, got %f", got) + t.Errorf("expected 0 with < 10 samples, got %f", got) } - // With 3+ samples, returns median - s.Record(300) - // Samples: [100, 200, 300] -> median = 200 - if got := s.Median(); got != 200 { - t.Errorf("expected median 200, got %f", got) + // With 10+ samples, returns median + s.Record(190) + // Samples: [100, 110, 120, 130, 140, 150, 160, 170, 180, 190] -> median = 150 + if got := s.Median(); got != 150 { + t.Errorf("expected median 150, got %f", got) } // Add more samples s.Record(50) - s.Record(250) - // Samples: [100, 200, 300, 50, 250] sorted = [50, 100, 200, 250, 300] -> median = 200 - if got := s.Median(); got != 200 { - t.Errorf("expected median 200, got %f", got) + // Samples: [100, 110, 120, 130, 140, 150, 160, 170, 180, 190, 50] + // sorted = [50, 100, 110, 120, 130, 140, 150, 160, 170, 180, 190] -> median = 140 + if got := s.Median(); got != 140 { + t.Errorf("expected median 140, got %f", got) } } func TestSpeedTracker_RollingWindow(t *testing.T) { s := &speedTracker{} - // Add 105 samples (should keep only last 100) - for i := 0; i < 105; i++ { + // Add 35 samples (should keep only last 30) + for i := 0; i < 35; i++ { s.Record(float64(i)) } 
s.mu.Lock() - if len(s.speeds) != 100 { - t.Errorf("expected 100 samples, got %d", len(s.speeds)) + if len(s.speeds) != 30 { + t.Errorf("expected 30 samples, got %d", len(s.speeds)) } // First sample should be 5 (0-4 were dropped) if s.speeds[0] != 5 { @@ -99,7 +100,7 @@ func TestStreamHasher_Sequential(t *testing.T) { sh := newStreamHasher(f, parts, int64(len(data))) // Mark complete and run - sh.MarkComplete(0) + sh.Done(0) done := make(chan struct{}) go func() { @@ -150,9 +151,9 @@ func TestStreamHasher_OutOfOrderCompletion(t *testing.T) { }() // Mark parts complete out of order: 2, 0, 1 - sh.MarkComplete(2) - sh.MarkComplete(0) // This should trigger hashing of part 0 - sh.MarkComplete(1) // This should trigger hashing of parts 1 and 2 + sh.Done(2) + sh.Done(0) // This should trigger hashing of part 0 + sh.Done(1) // This should trigger hashing of parts 1 and 2 <-done @@ -228,7 +229,7 @@ func TestStreamHasher_HashedProgress(t *testing.T) { }() // Complete part 0 - sh.MarkComplete(0) + sh.Done(0) // Give hasher time to process for i := 0; i < 100; i++ { @@ -238,7 +239,7 @@ func TestStreamHasher_HashedProgress(t *testing.T) { } // Complete part 1 - sh.MarkComplete(1) + sh.Done(1) <-done if got := sh.Hashed(); got != 1000 { @@ -291,7 +292,7 @@ func BenchmarkStreamHasher(b *testing.B) { for i := 0; i < b.N; i++ { sh := newStreamHasher(f, parts, int64(size)) - sh.MarkComplete(0) + sh.Done(0) done := make(chan struct{}) go func() { diff --git a/server/images.go b/server/images.go index d3de232b1..4cec2233e 100644 --- a/server/images.go +++ b/server/images.go @@ -621,22 +621,18 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu } for _, layer := range layers { - _, err := downloadBlob(ctx, downloadOpts{ + if err := downloadBlob(ctx, downloadOpts{ mp: mp, digest: layer.Digest, regOpts: regOpts, fn: fn, - }) - if err != nil { + }); err != nil { return err } delete(deleteMap, layer.Digest) } delete(deleteMap, manifest.Config.Digest) - // 
Note: Digest verification now happens inline during download in blobDownload.run() - // via the orderedWriter, so no separate verification pass is needed. - fn(api.ProgressResponse{Status: "writing manifest"}) manifestJSON, err := json.Marshal(manifest) @@ -839,25 +835,3 @@ func parseRegistryChallenge(authStr string) registryChallenge { Scope: getValue(authStr, "scope"), } } - -var errDigestMismatch = errors.New("digest mismatch, file must be downloaded again") - -func verifyBlob(digest string) error { - fp, err := GetBlobsPath(digest) - if err != nil { - return err - } - - f, err := os.Open(fp) - if err != nil { - return err - } - defer f.Close() - - fileDigest, _ := GetSHA256Digest(f) - if digest != fileDigest { - return fmt.Errorf("%w: want %s, got %s", errDigestMismatch, digest, fileDigest) - } - - return nil -} From bf63d18b11cb6262e1228b30a94bcc19176e8c37 Mon Sep 17 00:00:00 2001 From: jmorganca Date: Sat, 20 Dec 2025 22:46:46 -0800 Subject: [PATCH 4/4] linter --- server/download.go | 3 +++ server/download_test.go | 28 ++++++++++++++-------------- server/routes.go | 1 - 3 files changed, 17 insertions(+), 15 deletions(-) diff --git a/server/download.go b/server/download.go index 3996a8334..cd6839f67 100644 --- a/server/download.go +++ b/server/download.go @@ -340,6 +340,9 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *regis return err } defer file.Close() + setSparse(file) + + _ = file.Truncate(b.Total) directURL, err := func() (*url.URL, error) { ctx, cancel := context.WithTimeout(ctx, 30*time.Second) diff --git a/server/download_test.go b/server/download_test.go index 56842f237..c97f37143 100644 --- a/server/download_test.go +++ b/server/download_test.go @@ -13,7 +13,7 @@ func TestSpeedTracker_Median(t *testing.T) { s := &speedTracker{} // Less than 10 samples returns 0 - for i := 0; i < 9; i++ { + for i := range 9 { s.Record(float64(100 + i*10)) } if got := s.Median(); got != 0 { @@ -40,7 +40,7 @@ func 
TestSpeedTracker_RollingWindow(t *testing.T) { s := &speedTracker{} // Add 35 samples (should keep only last 30) - for i := 0; i < 35; i++ { + for i := range 35 { s.Record(float64(i)) } @@ -59,7 +59,7 @@ func TestSpeedTracker_Concurrent(t *testing.T) { s := &speedTracker{} var wg sync.WaitGroup - for i := 0; i < 100; i++ { + for i := range 100 { wg.Add(1) go func(v int) { defer wg.Done() @@ -79,7 +79,7 @@ func TestSpeedTracker_Concurrent(t *testing.T) { func TestStreamHasher_Sequential(t *testing.T) { // Create temp file - f, err := os.CreateTemp("", "streamhasher_test") + f, err := os.CreateTemp(t.TempDir(), "streamhasher_test") if err != nil { t.Fatal(err) } @@ -122,7 +122,7 @@ func TestStreamHasher_Sequential(t *testing.T) { func TestStreamHasher_OutOfOrderCompletion(t *testing.T) { // Create temp file - f, err := os.CreateTemp("", "streamhasher_test") + f, err := os.CreateTemp(t.TempDir(), "streamhasher_test") if err != nil { t.Fatal(err) } @@ -166,7 +166,7 @@ func TestStreamHasher_OutOfOrderCompletion(t *testing.T) { func TestStreamHasher_Stop(t *testing.T) { // Create temp file - f, err := os.CreateTemp("", "streamhasher_test") + f, err := os.CreateTemp(t.TempDir(), "streamhasher_test") if err != nil { t.Fatal(err) } @@ -197,7 +197,7 @@ func TestStreamHasher_Stop(t *testing.T) { func TestStreamHasher_HashedProgress(t *testing.T) { // Create temp file with known data - f, err := os.CreateTemp("", "streamhasher_test") + f, err := os.CreateTemp(t.TempDir(), "streamhasher_test") if err != nil { t.Fatal(err) } @@ -232,7 +232,7 @@ func TestStreamHasher_HashedProgress(t *testing.T) { sh.Done(0) // Give hasher time to process - for i := 0; i < 100; i++ { + for range 100 { if sh.Hashed() >= 500 { break } @@ -250,7 +250,7 @@ func TestStreamHasher_HashedProgress(t *testing.T) { func BenchmarkSpeedTracker_Record(b *testing.B) { s := &speedTracker{} b.ResetTimer() - for i := 0; i < b.N; i++ { + for i := range b.N { s.Record(float64(i)) } } @@ -258,18 +258,18 @@ func 
BenchmarkSpeedTracker_Record(b *testing.B) { func BenchmarkSpeedTracker_Median(b *testing.B) { s := &speedTracker{} // Pre-populate with 100 samples - for i := 0; i < 100; i++ { + for i := range 100 { s.Record(float64(i)) } b.ResetTimer() - for i := 0; i < b.N; i++ { + for range b.N { s.Median() } } func BenchmarkStreamHasher(b *testing.B) { // Create temp file with test data - f, err := os.CreateTemp("", "streamhasher_bench") + f, err := os.CreateTemp(b.TempDir(), "streamhasher_bench") if err != nil { b.Fatal(err) } @@ -290,7 +290,7 @@ func BenchmarkStreamHasher(b *testing.B) { b.SetBytes(int64(size)) b.ResetTimer() - for i := 0; i < b.N; i++ { + for range b.N { sh := newStreamHasher(f, parts, int64(size)) sh.Done(0) @@ -312,7 +312,7 @@ func BenchmarkHashThroughput(b *testing.B) { b.SetBytes(int64(size)) b.ResetTimer() - for i := 0; i < b.N; i++ { + for range b.N { h := sha256.New() h.Write(data) h.Sum(nil) diff --git a/server/routes.go b/server/routes.go index b19a40fbc..dfa84db2b 100644 --- a/server/routes.go +++ b/server/routes.go @@ -2395,4 +2395,3 @@ func filterThinkTags(msgs []api.Message, m *Model) []api.Message { } return msgs } -