From 7c5b656bb3724ba2ae9f03476d4119738283e4f5 Mon Sep 17 00:00:00 2001
From: jmorganca
Date: Sat, 20 Dec 2025 16:54:51 -0800
Subject: [PATCH] wip

---
 server/download.go      | 177 +++++++++++++++++++++++++++++++++-------
 server/download_test.go | 166 +++++++++++++++++++++++++++++++++++++
 server/images.go        |  26 +-----
 3 files changed, 318 insertions(+), 51 deletions(-)
 create mode 100644 server/download_test.go

diff --git a/server/download.go b/server/download.go
index 42d713c09..1265460ea 100644
--- a/server/download.go
+++ b/server/download.go
@@ -2,9 +2,11 @@ package server
 
 import (
 	"context"
+	"crypto/sha256"
 	"encoding/json"
 	"errors"
 	"fmt"
+	"hash"
 	"io"
 	"log/slog"
 	"math"
@@ -94,12 +96,100 @@ func (p *blobDownloadPart) UnmarshalJSON(b []byte) error {
 	return nil
 }
 
-const (
-	numDownloadParts          = 16
-	minDownloadPartSize int64 = 100 * format.MegaByte
-	maxDownloadPartSize int64 = 1000 * format.MegaByte
-)
+const numDownloadParts = 16
+
+// Download tuning. Override via environment variables for different environments:
+//   - Developer laptop: defaults are fine
+//   - Fast server (10Gbps+): OLLAMA_DOWNLOAD_CONCURRENCY=32 or higher
+//   - Memory constrained: reduce OLLAMA_DOWNLOAD_PART_SIZE and OLLAMA_DOWNLOAD_BUFFER_SIZE
+var (
+	// downloadPartSize is the size of each download part.
+	// Smaller = less memory, more HTTP requests.
+	// Default: 16MB. Override with OLLAMA_DOWNLOAD_PART_SIZE (in MB).
+	downloadPartSize = int64(getEnvInt("OLLAMA_DOWNLOAD_PART_SIZE", 16)) * format.MegaByte
+
+	// downloadConcurrency limits concurrent part downloads.
+	// Higher = faster on fast connections, more memory.
+	// Default: 8. Override with OLLAMA_DOWNLOAD_CONCURRENCY.
+	downloadConcurrency = getEnvInt("OLLAMA_DOWNLOAD_CONCURRENCY", 8)
+
+	// downloadBufferSize is the max bytes buffered in orderedWriter before
+	// Submit blocks (backpressure). This bounds memory usage.
+	// Default: 128MB. Override with OLLAMA_DOWNLOAD_BUFFER_SIZE (in MB).
+	// Total memory ≈ (concurrency × part_size) + buffer_size
+	// Default: (8 × 16MB) + 128MB = 256MB max
+	downloadBufferSize = int64(getEnvInt("OLLAMA_DOWNLOAD_BUFFER_SIZE", 128)) * format.MegaByte
+)
+
+func getEnvInt(key string, defaultVal int) int {
+	if s := os.Getenv(key); s != "" {
+		// Ignore unparsable or non-positive values; a zero or negative
+		// size would break part splitting and backpressure.
+		if v, err := strconv.Atoi(s); err == nil && v > 0 {
+			return v
+		}
+	}
+	return defaultVal
+}
+
+// orderedWriter buffers out-of-order parts and writes them sequentially
+// through a hasher and file. This allows parallel downloads while computing
+// the hash incrementally without a post-download verification pass.
+type orderedWriter struct {
+	mu          sync.Mutex
+	cond        *sync.Cond
+	next        int            // next expected part index
+	pending     map[int][]byte // out-of-order parts waiting to be written
+	pendingSize int64          // total bytes in pending
+	out         io.Writer      // destination (typically MultiWriter(file, hasher))
+	hasher      hash.Hash      // for computing final digest
+	err         error          // first write error; later Submits fail fast
+}
+
+func newOrderedWriter(file io.Writer, hasher hash.Hash) *orderedWriter {
+	w := &orderedWriter{
+		pending: make(map[int][]byte),
+		out:     io.MultiWriter(file, hasher),
+		hasher:  hasher,
+	}
+	w.cond = sync.NewCond(&w.mu)
+	return w
+}
+
+// Submit adds a part to the writer. Parts are written in order; if this part
+// is out of order, it's buffered until earlier parts arrive. Blocks if the
+// pending buffer exceeds downloadBufferSize (backpressure), unless this is the
+// next expected part (which will drain the buffer). Returns the first write
+// error encountered by any submitter.
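+//
+// Intended call pattern, as a sketch (fetchPart stands in for the actual
+// part download; it is not a function in this package):
+//
+//	ow := newOrderedWriter(file, sha256.New())
+//	g, ctx := errgroup.WithContext(ctx)
+//	for i := range parts {
+//		g.Go(func() error {
+//			data, err := fetchPart(ctx, parts[i])
+//			if err != nil {
+//				return err
+//			}
+//			return ow.Submit(i, data)
+//		})
+//	}
+//	// after g.Wait() succeeds, ow.Digest() covers every byte written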
+func (w *orderedWriter) Submit(partIndex int, data []byte) error {
+	w.mu.Lock()
+	defer w.mu.Unlock()
+
+	// Backpressure: wait while the buffer is too full, unless this is the next
+	// part (the next part drains the buffer, so it must always proceed).
+	for w.err == nil && w.pendingSize+int64(len(data)) > downloadBufferSize && partIndex != w.next {
+		w.cond.Wait()
+	}
+	if w.err != nil {
+		// A previous write failed; fail fast instead of buffering more data.
+		return w.err
+	}
+
+	w.pending[partIndex] = data
+	w.pendingSize += int64(len(data))
+
+	// Write all consecutive parts starting from next
+	for w.pending[w.next] != nil {
+		buf := w.pending[w.next]
+		if _, err := w.out.Write(buf); err != nil {
+			// Record the error and wake blocked submitters; otherwise they
+			// would wait forever on a buffer that will never drain.
+			w.err = err
+			w.cond.Broadcast()
+			return err
+		}
+		w.pendingSize -= int64(len(buf))
+		delete(w.pending, w.next) // drops the last reference so the slice can be freed
+		w.next++
+		w.cond.Broadcast() // wake any blocked submitters
+	}
+	return nil
+}
+
+// Digest returns the computed hash after all parts have been written.
+func (w *orderedWriter) Digest() string {
+	return fmt.Sprintf("sha256:%x", w.hasher.Sum(nil))
+}
+
 func (p *blobDownloadPart) Name() string {
 	return strings.Join([]string{
 		p.blobDownload.Name, "partial", strconv.Itoa(p.N),
@@ -153,10 +243,10 @@ func (b *blobDownload) Prepare(ctx context.Context, requestURL *url.URL, opts *r
-	size := b.Total / numDownloadParts
-	switch {
-	case size < minDownloadPartSize:
-		size = minDownloadPartSize
-	case size > maxDownloadPartSize:
-		size = maxDownloadPartSize
-	}
+	// Every part is a fixed downloadPartSize bytes (the last may run short),
+	// so the part count, not the part size, scales with b.Total.
+	size := downloadPartSize
 
 	var offset int64
@@ -220,9 +310,6 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *regis
 		return err
 	}
 	defer file.Close()
-	setSparse(file)
-
-	_ = file.Truncate(b.Total)
 
 	directURL, err := func() (*url.URL, error) {
 		ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
@@ -270,8 +357,13 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *regis
 		return err
 	}
 
+	// Download parts in parallel, hashing while writing via the ordered writer.
+	// Memory is bounded by (concurrency × part_size) + buffer_size:
+	// 8 × 16MB + 128MB = 256MB with the defaults, regardless of blob size.
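+	// For example, OLLAMA_DOWNLOAD_CONCURRENCY=16 together with
+	// OLLAMA_DOWNLOAD_PART_SIZE=32 raises the bound to 16 × 32MB + 128MB = 640MB.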
+	ow := newOrderedWriter(file, sha256.New())
+
 	g, inner := errgroup.WithContext(ctx)
-	g.SetLimit(numDownloadParts)
+	g.SetLimit(downloadConcurrency)
 	for i := range b.Parts {
 		part := b.Parts[i]
 		if part.Completed.Load() == part.Size {
 			continue
 		}
 
 		g.Go(func() error {
+			var data []byte
 			var err error
 			for try := 0; try < maxRetries; try++ {
-				w := io.NewOffsetWriter(file, part.StartsAt())
-				err = b.downloadChunk(inner, directURL, w, part)
+				data, err = b.downloadChunkToBuffer(inner, directURL, part)
 				switch {
 				case errors.Is(err, context.Canceled), errors.Is(err, syscall.ENOSPC):
-					// return immediately if the context is canceled or the device is out of space
 					return err
 				case errors.Is(err, errPartStalled):
 					try--
 					continue
 				case err != nil:
 					sleep := time.Second * time.Duration(math.Pow(2, float64(try)))
 					time.Sleep(sleep)
 					continue
 				default:
-					return nil
+					err := ow.Submit(part.N, data)
+					data = nil // drop this goroutine's reference; ow owns the buffer now
+					return err
 				}
 			}
-
 			return fmt.Errorf("%w: %w", errMaxRetriesExceeded, err)
 		})
 	}
 
 	if err := g.Wait(); err != nil {
 		return err
 	}
 
+	// Verify the digest. The hash was computed incrementally while writing,
+	// so no separate read-back pass over the file is needed.
+	computed := ow.Digest()
+	if computed != b.Digest {
+		// TODO(wip): also remove the partial file and part metadata here;
+		// otherwise a retry would resume from the corrupt parts and fail again.
+		return fmt.Errorf("%w: got %s, want %s", errDigestMismatch, computed, b.Digest)
+	}
+
 	// explicitly close the file so we can rename it
 	if err := file.Close(); err != nil {
 		return err
 	}
@@ -326,38 +424,58 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *regis
 	return nil
 }
 
-func (b *blobDownload) downloadChunk(ctx context.Context, requestURL *url.URL, w io.Writer, part *blobDownloadPart) error {
+// downloadChunkToBuffer downloads a single part into an in-memory buffer,
+// tracking overall progress and feeding the stall detector as bytes arrive.
+func (b *blobDownload) downloadChunkToBuffer(ctx context.Context, requestURL *url.URL, part *blobDownloadPart) ([]byte, error) {
 	g, ctx := errgroup.WithContext(ctx)
+
+	var data []byte
 	g.Go(func() error {
 		req, err := http.NewRequestWithContext(ctx, http.MethodGet, requestURL.String(), nil)
 		if err != nil {
 			return err
 		}
 		req.Header.Set("Range", fmt.Sprintf("bytes=%d-%d", part.StartsAt(), part.StopsAt()-1))
 		resp, err := http.DefaultClient.Do(req)
 		if err != nil {
 			return err
 		}
 		defer resp.Body.Close()
 
-		n, err := io.CopyN(w, io.TeeReader(resp.Body, part), part.Size-part.Completed.Load())
-		if err != nil && !errors.Is(err, context.Canceled) && !errors.Is(err, io.ErrUnexpectedEOF) {
-			// rollback progress
-			b.Completed.Add(-n)
-			return err
-		}
+		// Pre-allocate the full part buffer so appends never reallocate.
+		data = make([]byte, 0, part.Size)
+		buf := make([]byte, 32*1024) // 32KB read buffer
+
+		for {
+			n, err := resp.Body.Read(buf)
+			if n > 0 {
+				data = append(data, buf[:n]...)
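+				// Count received bytes toward overall progress and refresh
+				// the stall-detection timestamp for this part.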
+				b.Completed.Add(int64(n))
+
+				part.lastUpdatedMu.Lock()
+				part.lastUpdated = time.Now()
+				part.lastUpdatedMu.Unlock()
+			}
+			if err == io.EOF {
+				break
+			}
+			if err != nil {
+				// roll back progress for this attempt
+				b.Completed.Add(-int64(len(data)))
+				return err
+			}
+		}
+
+		if int64(len(data)) != part.Size {
+			// Clean EOF before the full part arrived: roll back progress and
+			// return a resumable error so the retry loop tries this part again.
+			b.Completed.Add(-int64(len(data)))
+			return io.ErrUnexpectedEOF
+		}
 
-		part.Completed.Add(n)
+		part.Completed.Store(part.Size)
 		if err := b.writePart(part.Name(), part); err != nil {
 			return err
 		}
 
-		// return nil or context.Canceled or UnexpectedEOF (resumable)
-		return err
+		return nil
 	})
 
 	g.Go(func() error {
 		ticker := time.NewTicker(time.Second)
+		defer ticker.Stop()
 		for {
 			select {
 			case <-ticker.C:
@@ -384,7 +502,10 @@ func (b *blobDownload) downloadChunk(ctx context.Context, requestURL *url.URL, w
 		}
 	})
 
-	return g.Wait()
+	if err := g.Wait(); err != nil {
+		return nil, err
+	}
+	return data, nil
 }
 
 func (b *blobDownload) newPart(offset, size int64) error {
diff --git a/server/download_test.go b/server/download_test.go
new file mode 100644
index 000000000..cd54b7e84
--- /dev/null
+++ b/server/download_test.go
@@ -0,0 +1,172 @@
+package server
+
+import (
+	"bytes"
+	"crypto/rand"
+	"crypto/sha256"
+	"fmt"
+	"io"
+	"testing"
+)
+
+func TestOrderedWriter_InOrder(t *testing.T) {
+	var buf bytes.Buffer
+	hasher := sha256.New()
+	ow := newOrderedWriter(&buf, hasher)
+
+	// Submit parts in order
+	for i := 0; i < 5; i++ {
+		data := []byte{byte(i), byte(i), byte(i)}
+		if err := ow.Submit(i, data); err != nil {
+			t.Fatalf("Submit(%d) failed: %v", i, err)
+		}
+	}
+
+	// Verify output
+	expected := []byte{0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4}
+	if !bytes.Equal(buf.Bytes(), expected) {
+		t.Errorf("got %v, want %v", buf.Bytes(), expected)
+	}
+}
+
+func TestOrderedWriter_OutOfOrder(t *testing.T) {
+	var buf bytes.Buffer
+	hasher := sha256.New()
+	ow := newOrderedWriter(&buf, hasher)
+
+	// Submit parts out of order: 2, 4, 1, 0, 3
+	order := []int{2, 4, 1, 0, 3}
+	for _, i := range order {
+		data := []byte{byte(i), byte(i), byte(i)}
+		if err := ow.Submit(i, data); err != nil {
+			t.Fatalf("Submit(%d) failed: %v", i, err)
+		}
+	}
+
+	// Verify output is still in correct order
+	expected := []byte{0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4}
+	if !bytes.Equal(buf.Bytes(), expected) {
+		t.Errorf("got %v, want %v", buf.Bytes(), expected)
+	}
+}
+
+func TestOrderedWriter_Digest(t *testing.T) {
+	var buf bytes.Buffer
+	hasher := sha256.New()
+	ow := newOrderedWriter(&buf, hasher)
+
+	// Submit some data
+	data := []byte("hello world")
+	if err := ow.Submit(0, data); err != nil {
+		t.Fatalf("Submit failed: %v", err)
+	}
+
+	// Verify digest format and correctness
+	got := ow.Digest()
+	if len(got) != 71 { // "sha256:" + 64 hex chars
+		t.Errorf("digest has wrong length: %d, got: %s", len(got), got)
+	}
+	if got[:7] != "sha256:" {
+		t.Errorf("digest doesn't start with sha256: %s", got)
+	}
+
+	// Verify it matches expected hash
+	expectedHash := sha256.Sum256(data)
+	want := "sha256:" + fmt.Sprintf("%x", expectedHash[:])
+	if got != want {
+		t.Errorf("digest mismatch: got %s, want %s", got, want)
+	}
+}
+
+func BenchmarkOrderedWriter_InOrder(b *testing.B) {
+	// Benchmark throughput when parts arrive in order (best case).
+	// In-order parts never buffer, so partSize isn't limited by downloadBufferSize.
+	partSize := 64 * 1024 * 1024 // 64MB parts
+	numParts := 4
+	data := make([]byte, partSize)
+	rand.Read(data)
+
+	b.SetBytes(int64(partSize * numParts))
+	b.ResetTimer()
+
+	for i := 0; i < b.N; i++ {
+		ow := newOrderedWriter(io.Discard, sha256.New())
+		for p := 0; p < numParts; p++ {
+			if err := ow.Submit(p, data); err != nil {
+				b.Fatal(err)
+			}
+		}
+	}
+}
+
+func BenchmarkOrderedWriter_OutOfOrder(b 
*testing.B) {
+	// Benchmark throughput when parts arrive out of order (worst case).
+	// Keep (numParts-1) × partSize under downloadBufferSize (128MB default):
+	// a single submitting goroutine would otherwise block forever on
+	// backpressure, waiting for a drain that only Submit(0) can trigger.
+	partSize := 32 * 1024 * 1024 // 32MB parts
+	numParts := 4
+	data := make([]byte, partSize)
+	rand.Read(data)
+
+	// Reverse order: 3, 2, 1, 0
+	order := make([]int, numParts)
+	for i := 0; i < numParts; i++ {
+		order[i] = numParts - 1 - i
+	}
+
+	b.SetBytes(int64(partSize * numParts))
+	b.ResetTimer()
+
+	for i := 0; i < b.N; i++ {
+		ow := newOrderedWriter(io.Discard, sha256.New())
+		for _, p := range order {
+			if err := ow.Submit(p, data); err != nil {
+				b.Fatal(err)
+			}
+		}
+	}
+}
+
+func BenchmarkHashThroughput(b *testing.B) {
+	// Baseline: raw SHA256 throughput on this machine
+	size := 256 * 1024 * 1024 // 256MB
+	data := make([]byte, size)
+	rand.Read(data)
+
+	b.SetBytes(int64(size))
+	b.ResetTimer()
+
+	for i := 0; i < b.N; i++ {
+		h := sha256.New()
+		h.Write(data)
+		h.Sum(nil)
+	}
+}
+
+func BenchmarkOrderedWriter_Memory(b *testing.B) {
+	// Measure allocations when buffering out-of-order parts. Three buffered
+	// 32MB parts stay under the 128MB default downloadBufferSize, so a
+	// single goroutine can submit them without blocking on backpressure.
+	partSize := 32 * 1024 * 1024 // 32MB parts
+	numParts := 4
+	data := make([]byte, partSize)
+	rand.Read(data)
+
+	b.ReportAllocs()
+	b.ResetTimer()
+
+	for i := 0; i < b.N; i++ {
+		ow := newOrderedWriter(io.Discard, sha256.New())
+		// Submit all except part 0 (forces buffering)
+		for p := 1; p < numParts; p++ {
+			if err := ow.Submit(p, data); err != nil {
+				b.Fatal(err)
+			}
+		}
+		// Submit part 0 to flush
+		if err := ow.Submit(0, data); err != nil {
+			b.Fatal(err)
+		}
+	}
+}
diff --git a/server/images.go b/server/images.go
index 951f7ac6e..d3de232b1 100644
--- a/server/images.go
+++ b/server/images.go
@@ -620,9 +620,8 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
 		layers = append(layers, manifest.Config)
 	}
 
-	skipVerify := make(map[string]bool)
 	for _, layer := range layers {
-		cacheHit, err := downloadBlob(ctx, downloadOpts{
+		_, err := downloadBlob(ctx, downloadOpts{
 			mp:      mp,
 			digest:  layer.Digest,
 			regOpts: regOpts,
@@ -631,31 +630,12 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
 		if err != nil {
 			return err
 		}
-		skipVerify[layer.Digest] = cacheHit
 		delete(deleteMap, layer.Digest)
 	}
 	delete(deleteMap, manifest.Config.Digest)
 
-	fn(api.ProgressResponse{Status: "verifying sha256 digest"})
-	for _, layer := range layers {
-		if skipVerify[layer.Digest] {
-			continue
-		}
-		if err := verifyBlob(layer.Digest); err != nil {
-			if errors.Is(err, errDigestMismatch) {
-				// something went wrong, delete the blob
-				fp, err := GetBlobsPath(layer.Digest)
-				if err != nil {
-					return err
-				}
-				if err := os.Remove(fp); err != nil {
-					// log this, but return the original error
-					slog.Info(fmt.Sprintf("couldn't remove file with digest mismatch '%s': %v", fp, err))
-				}
-			}
-			return err
-		}
-	}
+	// Note: Digest verification now happens inline during download in blobDownload.run()
+	// via the orderedWriter, so no separate verification pass is needed.
 
 	fn(api.ProgressResponse{Status: "writing manifest"})