jmorganca 2025-12-20 16:54:51 -08:00
parent bddb27ab5b
commit 7c5b656bb3
3 changed files with 318 additions and 51 deletions


@@ -2,9 +2,11 @@ package server
import (
"context"
"crypto/sha256"
"encoding/json"
"errors"
"fmt"
"hash"
"io"
"log/slog"
"math"
@@ -94,12 +96,100 @@ func (p *blobDownloadPart) UnmarshalJSON(b []byte) error {
return nil
}
const (
numDownloadParts = 16
minDownloadPartSize int64 = 100 * format.MegaByte
maxDownloadPartSize int64 = 1000 * format.MegaByte
const numDownloadParts = 16
// Download tuning. Override via environment variables for different environments:
// - Developer laptop: defaults are fine
// - Fast server (10Gbps+): OLLAMA_DOWNLOAD_CONCURRENCY=32 or higher
// - Memory constrained: reduce OLLAMA_DOWNLOAD_PART_SIZE and OLLAMA_DOWNLOAD_BUFFER_SIZE
var (
// downloadPartSize is the size of each download part.
// Smaller = less memory, more HTTP requests.
// Default: 16MB. Override with OLLAMA_DOWNLOAD_PART_SIZE (in MB).
downloadPartSize = int64(getEnvInt("OLLAMA_DOWNLOAD_PART_SIZE", 16)) * format.MegaByte
// downloadConcurrency limits concurrent part downloads.
// Higher = faster on fast connections, more memory.
// Default: 8. Override with OLLAMA_DOWNLOAD_CONCURRENCY.
downloadConcurrency = getEnvInt("OLLAMA_DOWNLOAD_CONCURRENCY", 8)
// downloadBufferSize is the max bytes buffered in orderedWriter before
// Submit blocks (backpressure). This bounds memory usage.
// Default: 128MB. Override with OLLAMA_DOWNLOAD_BUFFER_SIZE (in MB).
// Total memory ≈ (concurrency × part_size) + buffer_size
// Default: (8 × 16MB) + 128MB = 256MB max
downloadBufferSize = int64(getEnvInt("OLLAMA_DOWNLOAD_BUFFER_SIZE", 128)) * format.MegaByte
)
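// Illustrative example (numbers are assumptions, not recommendations): a
// 10Gbps server might set OLLAMA_DOWNLOAD_CONCURRENCY=32 and
// OLLAMA_DOWNLOAD_PART_SIZE=64, for (32 × 64MB) + 128MB ≈ 2.1GB peak.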
func getEnvInt(key string, defaultVal int) int {
if s := os.Getenv(key); s != "" {
if v, err := strconv.Atoi(s); err == nil {
return v
}
}
return defaultVal
}
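// The worst-case bound above, as a sketch (hypothetical helper, not part of
// this change): every in-flight download holds one full part while the
// orderedWriter's pending buffer sits at its cap.
func maxDownloadMemory() int64 {
	return int64(downloadConcurrency)*downloadPartSize + downloadBufferSize
}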
// orderedWriter buffers out-of-order parts and writes them sequentially
// through a hasher and file. This allows parallel downloads while computing
// the hash incrementally without a post-download verification pass.
type orderedWriter struct {
mu sync.Mutex
cond *sync.Cond
next int // next expected part index
pending map[int][]byte // out-of-order parts waiting to be written
pendingSize int64 // total bytes in pending
out io.Writer // destination (typically MultiWriter(file, hasher))
hasher hash.Hash // for computing final digest
}
func newOrderedWriter(file io.Writer, hasher hash.Hash) *orderedWriter {
w := &orderedWriter{
pending: make(map[int][]byte),
out: io.MultiWriter(file, hasher),
hasher: hasher,
}
w.cond = sync.NewCond(&w.mu)
return w
}
// Submit adds a part to the writer. Parts are written in order; if this part
// is out of order, it's buffered until earlier parts arrive. Blocks if the
// pending buffer exceeds downloadBufferSize (backpressure), unless this is the
// next expected part (which will drain the buffer).
func (w *orderedWriter) Submit(partIndex int, data []byte) error {
w.mu.Lock()
defer w.mu.Unlock()
// Backpressure: wait if buffer is too full, unless we're the next part
// (the next part will drain the buffer, so it must always proceed)
for w.pendingSize+int64(len(data)) > downloadBufferSize && partIndex != w.next {
w.cond.Wait()
}
w.pending[partIndex] = data
w.pendingSize += int64(len(data))
// Write all consecutive parts starting from next
for w.pending[w.next] != nil {
data := w.pending[w.next]
if _, err := w.out.Write(data); err != nil {
return err
}
w.pendingSize -= int64(len(data))
delete(w.pending, w.next) // drop the buffered slice so GC can reclaim it
w.next++
w.cond.Broadcast() // wake any blocked submitters
}
return nil
}
// Digest returns the computed hash after all parts have been written.
func (w *orderedWriter) Digest() string {
return fmt.Sprintf("sha256:%x", w.hasher.Sum(nil))
}
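// A minimal usage sketch (hypothetical part variables; run() below wires this
// up against the real download loop):
//
//	ow := newOrderedWriter(file, sha256.New())
//	_ = ow.Submit(1, partOne)  // buffered: part 0 hasn't arrived yet
//	_ = ow.Submit(0, partZero) // writes part 0, then flushes part 1
//	digest := ow.Digest()      // "sha256:..." over the in-order bytes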
func (p *blobDownloadPart) Name() string {
return strings.Join([]string{
p.blobDownload.Name, "partial", strconv.Itoa(p.N),
@@ -153,10 +243,10 @@ func (b *blobDownload) Prepare(ctx context.Context, requestURL *url.URL, opts *r
size := b.Total / numDownloadParts
switch {
case size < minDownloadPartSize:
size = minDownloadPartSize
case size > maxDownloadPartSize:
size = maxDownloadPartSize
case size < downloadPartSize:
size = downloadPartSize
case size > downloadPartSize:
size = downloadPartSize
}
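// Net effect: size is pinned to downloadPartSize, so part count scales with
// the blob, e.g. a 2GB blob at the 16MB default yields 128 parts.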
var offset int64
@@ -220,9 +310,6 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *regis
return err
}
defer file.Close()
setSparse(file)
_ = file.Truncate(b.Total)
directURL, err := func() (*url.URL, error) {
ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
@@ -270,8 +357,13 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *regis
return err
}
// Download chunks in parallel, hashing while writing via the ordered writer.
// Memory: (concurrency × part_size) + buffer, i.e. with defaults
// (8 × 16MB) + 128MB = 256MB max.
ow := newOrderedWriter(file, sha256.New())
g, inner := errgroup.WithContext(ctx)
g.SetLimit(numDownloadParts)
g.SetLimit(downloadConcurrency)
for i := range b.Parts {
part := b.Parts[i]
if part.Completed.Load() == part.Size {
@@ -279,13 +371,12 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *regis
}
g.Go(func() error {
var data []byte
var err error
for try := 0; try < maxRetries; try++ {
w := io.NewOffsetWriter(file, part.StartsAt())
err = b.downloadChunk(inner, directURL, w, part)
data, err = b.downloadChunkToBuffer(inner, directURL, part)
switch {
case errors.Is(err, context.Canceled), errors.Is(err, syscall.ENOSPC):
// return immediately if the context is canceled or the device is out of space
return err
case errors.Is(err, errPartStalled):
try--
@@ -296,10 +387,11 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *regis
time.Sleep(sleep)
continue
default:
return nil
err := ow.Submit(part.N, data)
data = nil // help GC
return err
}
}
return fmt.Errorf("%w: %w", errMaxRetriesExceeded, err)
})
}
@@ -308,6 +400,12 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *regis
return err
}
// Verify the digest: no re-read pass needed; the hash was computed while writing.
computed := ow.Digest()
if computed != b.Digest {
return fmt.Errorf("digest mismatch: got %s, want %s", computed, b.Digest)
}
// explicitly close the file so we can rename it
if err := file.Close(); err != nil {
return err
@@ -326,38 +424,58 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *regis
return nil
}
func (b *blobDownload) downloadChunk(ctx context.Context, requestURL *url.URL, w io.Writer, part *blobDownloadPart) error {
// downloadChunkToBuffer downloads a part to a buffer, tracking progress and detecting stalls.
func (b *blobDownload) downloadChunkToBuffer(ctx context.Context, requestURL *url.URL, part *blobDownloadPart) ([]byte, error) {
g, ctx := errgroup.WithContext(ctx)
var data []byte
g.Go(func() error {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, requestURL.String(), nil)
if err != nil {
return err
}
req.Header.Set("Range", fmt.Sprintf("bytes=%d-%d", part.StartsAt(), part.StopsAt()-1))
req.Header.Set("Range", fmt.Sprintf("bytes=%d-%d", part.Offset, part.Offset+part.Size-1))
resp, err := http.DefaultClient.Do(req)
if err != nil {
return err
}
defer resp.Body.Close()
n, err := io.CopyN(w, io.TeeReader(resp.Body, part), part.Size-part.Completed.Load())
if err != nil && !errors.Is(err, context.Canceled) && !errors.Is(err, io.ErrUnexpectedEOF) {
// rollback progress
b.Completed.Add(-n)
return err
// Pre-allocate buffer for the part
data = make([]byte, 0, part.Size)
buf := make([]byte, 32*1024) // 32KB read buffer
for {
n, err := resp.Body.Read(buf)
if n > 0 {
data = append(data, buf[:n]...)
b.Completed.Add(int64(n))
part.lastUpdatedMu.Lock()
part.lastUpdated = time.Now()
part.lastUpdatedMu.Unlock()
}
if err == io.EOF {
break
}
if err != nil {
// rollback progress
b.Completed.Add(-int64(len(data)))
return err
}
}
part.Completed.Add(n)
part.Completed.Store(part.Size)
if err := b.writePart(part.Name(), part); err != nil {
return err
}
// return nil or context.Canceled or UnexpectedEOF (resumable)
return err
return nil
})
g.Go(func() error {
ticker := time.NewTicker(time.Second)
defer ticker.Stop()
for {
select {
case <-ticker.C:
@@ -384,7 +502,10 @@ func (b *blobDownload) downloadChunk(ctx context.Context, requestURL *url.URL, w
}
})
return g.Wait()
if err := g.Wait(); err != nil {
return nil, err
}
return data, nil
}
func (b *blobDownload) newPart(offset, size int64) error {

server/download_test.go (new file, 166 additions)

@@ -0,0 +1,166 @@
package server
import (
"bytes"
"crypto/rand"
"crypto/sha256"
"fmt"
"io"
"testing"
)
func TestOrderedWriter_InOrder(t *testing.T) {
var buf bytes.Buffer
hasher := sha256.New()
ow := newOrderedWriter(&buf, hasher)
// Submit parts in order
for i := 0; i < 5; i++ {
data := []byte{byte(i), byte(i), byte(i)}
if err := ow.Submit(i, data); err != nil {
t.Fatalf("Submit(%d) failed: %v", i, err)
}
}
// Verify output
expected := []byte{0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4}
if !bytes.Equal(buf.Bytes(), expected) {
t.Errorf("got %v, want %v", buf.Bytes(), expected)
}
}
func TestOrderedWriter_OutOfOrder(t *testing.T) {
var buf bytes.Buffer
hasher := sha256.New()
ow := newOrderedWriter(&buf, hasher)
// Submit parts out of order: 2, 4, 1, 0, 3
order := []int{2, 4, 1, 0, 3}
for _, i := range order {
data := []byte{byte(i), byte(i), byte(i)}
if err := ow.Submit(i, data); err != nil {
t.Fatalf("Submit(%d) failed: %v", i, err)
}
}
// Verify output is still in correct order
expected := []byte{0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4}
if !bytes.Equal(buf.Bytes(), expected) {
t.Errorf("got %v, want %v", buf.Bytes(), expected)
}
}
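// The backpressure path (Submit blocking once pendingSize would exceed
// downloadBufferSize) is not covered above. A hedged sketch of such a test,
// assuming downloadBufferSize remains a package var that a test may override
// (add "time" to the imports):
func TestOrderedWriter_Backpressure(t *testing.T) {
	old := downloadBufferSize
	downloadBufferSize = 4 // tiny buffer to force blocking
	defer func() { downloadBufferSize = old }()
	ow := newOrderedWriter(io.Discard, sha256.New())
	done := make(chan struct{})
	go func() {
		ow.Submit(1, []byte("aaaa")) // fills the buffer exactly
		ow.Submit(2, []byte("bbbb")) // should block until part 0 drains
		close(done)
	}()
	time.Sleep(50 * time.Millisecond) // give the goroutine time to block
	// The next expected part always proceeds, draining parts 0-1 and
	// unblocking the Submit(2) above.
	if err := ow.Submit(0, []byte("cccc")); err != nil {
		t.Fatal(err)
	}
	<-done
}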
func TestOrderedWriter_Digest(t *testing.T) {
var buf bytes.Buffer
hasher := sha256.New()
ow := newOrderedWriter(&buf, hasher)
// Submit some data
data := []byte("hello world")
if err := ow.Submit(0, data); err != nil {
t.Fatalf("Submit failed: %v", err)
}
// Verify digest format and correctness
got := ow.Digest()
if len(got) != 71 { // "sha256:" + 64 hex chars
t.Errorf("digest has wrong length: %d, got: %s", len(got), got)
}
if got[:7] != "sha256:" {
t.Errorf("digest doesn't start with sha256: %s", got)
}
// Verify it matches expected hash
expectedHash := sha256.Sum256(data)
want := "sha256:" + fmt.Sprintf("%x", expectedHash[:])
if got != want {
t.Errorf("digest mismatch: got %s, want %s", got, want)
}
}
func BenchmarkOrderedWriter_InOrder(b *testing.B) {
// Benchmark throughput when parts arrive in order (best case)
partSize := 64 * 1024 * 1024 // 64MB parts
numParts := 4
data := make([]byte, partSize)
rand.Read(data)
b.SetBytes(int64(partSize * numParts))
b.ResetTimer()
for i := 0; i < b.N; i++ {
ow := newOrderedWriter(io.Discard, sha256.New())
for p := 0; p < numParts; p++ {
if err := ow.Submit(p, data); err != nil {
b.Fatal(err)
}
}
}
}
func BenchmarkOrderedWriter_OutOfOrder(b *testing.B) {
// Benchmark throughput when parts arrive out of order (worst case)
partSize := 64 * 1024 * 1024 // 64MB parts
numParts := 4
data := make([]byte, partSize)
rand.Read(data)
// Reverse order: 3, 2, 1, 0
order := make([]int, numParts)
for i := 0; i < numParts; i++ {
order[i] = numParts - 1 - i
}
b.SetBytes(int64(partSize * numParts))
b.ResetTimer()
for i := 0; i < b.N; i++ {
ow := newOrderedWriter(io.Discard, sha256.New())
for _, p := range order {
if err := ow.Submit(p, data); err != nil {
b.Fatal(err)
}
}
}
}
func BenchmarkHashThroughput(b *testing.B) {
// Baseline: raw SHA256 throughput on this machine
size := 256 * 1024 * 1024 // 256MB
data := make([]byte, size)
rand.Read(data)
b.SetBytes(int64(size))
b.ResetTimer()
for i := 0; i < b.N; i++ {
h := sha256.New()
h.Write(data)
h.Sum(nil)
}
}
func BenchmarkOrderedWriter_Memory(b *testing.B) {
// Measure memory when buffering out-of-order parts
partSize := 64 * 1024 * 1024 // 64MB parts
numParts := 4
data := make([]byte, partSize)
rand.Read(data)
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
ow := newOrderedWriter(io.Discard, sha256.New())
// Submit all except part 0 (forces buffering)
for p := 1; p < numParts; p++ {
if err := ow.Submit(p, data); err != nil {
b.Fatal(err)
}
}
// Submit part 0 to flush
if err := ow.Submit(0, data); err != nil {
b.Fatal(err)
}
}
}


@@ -620,9 +620,8 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
layers = append(layers, manifest.Config)
}
skipVerify := make(map[string]bool)
for _, layer := range layers {
cacheHit, err := downloadBlob(ctx, downloadOpts{
_, err := downloadBlob(ctx, downloadOpts{
mp: mp,
digest: layer.Digest,
regOpts: regOpts,
@@ -631,31 +630,12 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
if err != nil {
return err
}
skipVerify[layer.Digest] = cacheHit
delete(deleteMap, layer.Digest)
}
delete(deleteMap, manifest.Config.Digest)
fn(api.ProgressResponse{Status: "verifying sha256 digest"})
for _, layer := range layers {
if skipVerify[layer.Digest] {
continue
}
if err := verifyBlob(layer.Digest); err != nil {
if errors.Is(err, errDigestMismatch) {
// something went wrong, delete the blob
fp, err := GetBlobsPath(layer.Digest)
if err != nil {
return err
}
if err := os.Remove(fp); err != nil {
// log this, but return the original error
slog.Info(fmt.Sprintf("couldn't remove file with digest mismatch '%s': %v", fp, err))
}
}
return err
}
}
// Note: Digest verification now happens inline during download in blobDownload.run()
// via the orderedWriter, so no separate verification pass is needed.
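// On a mismatch, blobDownload.run() returns an error before the partial file
// is renamed into the blob store, so the delete-on-mismatch cleanup that
// previously lived here is no longer required.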
fn(api.ProgressResponse{Status: "writing manifest"})