wip

parent bddb27ab5b
commit 7c5b656bb3

@@ -2,9 +2,11 @@ package server
 
 import (
 	"context"
+	"crypto/sha256"
 	"encoding/json"
 	"errors"
 	"fmt"
+	"hash"
 	"io"
 	"log/slog"
 	"math"

@@ -94,12 +96,100 @@ func (p *blobDownloadPart) UnmarshalJSON(b []byte) error {
 	return nil
 }
 
-const (
-	numDownloadParts          = 16
-	minDownloadPartSize int64 = 100 * format.MegaByte
-	maxDownloadPartSize int64 = 1000 * format.MegaByte
-)
+const numDownloadParts = 16
+
+// Download tuning. Override via environment variables for different environments:
+//   - Developer laptop: defaults are fine
+//   - Fast server (10Gbps+): OLLAMA_DOWNLOAD_CONCURRENCY=32 or higher
+//   - Memory constrained: reduce OLLAMA_DOWNLOAD_PART_SIZE and OLLAMA_DOWNLOAD_BUFFER_SIZE
+var (
+	// downloadPartSize is the size of each download part.
+	// Smaller = less memory, more HTTP requests.
+	// Default: 16MB. Override with OLLAMA_DOWNLOAD_PART_SIZE (in MB).
+	downloadPartSize = int64(getEnvInt("OLLAMA_DOWNLOAD_PART_SIZE", 16)) * format.MegaByte
+
+	// downloadConcurrency limits concurrent part downloads.
+	// Higher = faster on fast connections, more memory.
+	// Default: 8. Override with OLLAMA_DOWNLOAD_CONCURRENCY.
+	downloadConcurrency = getEnvInt("OLLAMA_DOWNLOAD_CONCURRENCY", 8)
+
+	// downloadBufferSize is the max bytes buffered in orderedWriter before
+	// Submit blocks (backpressure). This bounds memory usage.
+	// Default: 128MB. Override with OLLAMA_DOWNLOAD_BUFFER_SIZE (in MB).
+	// Total memory ≈ (concurrency × part_size) + buffer_size
+	// Default: (8 × 16MB) + 128MB = 256MB max
+	downloadBufferSize = int64(getEnvInt("OLLAMA_DOWNLOAD_BUFFER_SIZE", 128)) * format.MegaByte
+)
+
+func getEnvInt(key string, defaultVal int) int {
+	if s := os.Getenv(key); s != "" {
+		if v, err := strconv.Atoi(s); err == nil {
+			return v
+		}
+	}
+	return defaultVal
+}
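
A quick sanity check of the memory formula in the comments above. This is a standalone sketch (a hypothetical helper, not part of the commit) showing how the worst-case bound scales with the three knobs:

	package main

	import "fmt"

	// Worst-case download memory ≈ (concurrency × partSize) + bufferSize,
	// mirroring the formula in the tuning comment above. Values in MB.
	func memoryBoundMB(concurrency, partSizeMB, bufferSizeMB int64) int64 {
		return concurrency*partSizeMB + bufferSizeMB
	}

	func main() {
		fmt.Println(memoryBoundMB(8, 16, 128))  // defaults: 256 MB
		fmt.Println(memoryBoundMB(32, 16, 128)) // OLLAMA_DOWNLOAD_CONCURRENCY=32: 640 MB
		fmt.Println(memoryBoundMB(4, 8, 32))    // constrained settings: 64 MB
	}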
+
+// orderedWriter buffers out-of-order parts and writes them sequentially
+// through a hasher and file. This allows parallel downloads while computing
+// the hash incrementally without a post-download verification pass.
+type orderedWriter struct {
+	mu          sync.Mutex
+	cond        *sync.Cond
+	next        int            // next expected part index
+	pending     map[int][]byte // out-of-order parts waiting to be written
+	pendingSize int64          // total bytes in pending
+	out         io.Writer      // destination (typically MultiWriter(file, hasher))
+	hasher      hash.Hash      // for computing final digest
+}
+
+func newOrderedWriter(file io.Writer, hasher hash.Hash) *orderedWriter {
+	w := &orderedWriter{
+		pending: make(map[int][]byte),
+		out:     io.MultiWriter(file, hasher),
+		hasher:  hasher,
+	}
+	w.cond = sync.NewCond(&w.mu)
+	return w
+}
+
+// Submit adds a part to the writer. Parts are written in order; if this part
+// is out of order, it's buffered until earlier parts arrive. Blocks if the
+// pending buffer exceeds downloadBufferSize (backpressure), unless this is the
+// next expected part (which will drain the buffer).
+func (w *orderedWriter) Submit(partIndex int, data []byte) error {
+	w.mu.Lock()
+	defer w.mu.Unlock()
+
+	// Backpressure: wait if buffer is too full, unless we're the next part
+	// (the next part will drain the buffer, so it must always proceed)
+	for w.pendingSize+int64(len(data)) > downloadBufferSize && partIndex != w.next {
+		w.cond.Wait()
+	}
+
+	w.pending[partIndex] = data
+	w.pendingSize += int64(len(data))
+
+	// Write all consecutive parts starting from next
+	for w.pending[w.next] != nil {
+		data := w.pending[w.next]
+		if _, err := w.out.Write(data); err != nil {
+			return err
+		}
+		w.pendingSize -= int64(len(data))
+		delete(w.pending, w.next) // releases the slice for GC
+		w.next++
+		w.cond.Broadcast() // wake any blocked submitters
+	}
+	return nil
+}
+
+// Digest returns the computed hash after all parts have been written.
+func (w *orderedWriter) Digest() string {
+	return fmt.Sprintf("sha256:%x", w.hasher.Sum(nil))
+}
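
A usage sketch for orderedWriter (toy data, not the real download path): workers may Submit parts in any order, and the writer still emits bytes — and therefore the hash — sequentially. Assumes the file's errgroup import plus bytes:

	func exampleOrderedWriter() (string, error) {
		var out bytes.Buffer
		ow := newOrderedWriter(&out, sha256.New())

		g := new(errgroup.Group)
		parts := [][]byte{[]byte("aa"), []byte("bb"), []byte("cc")}
		for i := range parts {
			i := i // capture per iteration (pre-Go 1.22)
			g.Go(func() error { return ow.Submit(i, parts[i]) })
		}
		if err := g.Wait(); err != nil {
			return "", err
		}
		// out.String() == "aabbcc" regardless of goroutine scheduling
		return ow.Digest(), nil
	}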
 
 func (p *blobDownloadPart) Name() string {
 	return strings.Join([]string{
 		p.blobDownload.Name, "partial", strconv.Itoa(p.N),

@@ -153,10 +243,10 @@ func (b *blobDownload) Prepare(ctx context.Context, requestURL *url.URL, opts *registryOptions) error {
 
 	size := b.Total / numDownloadParts
 	switch {
-	case size < minDownloadPartSize:
-		size = minDownloadPartSize
-	case size > maxDownloadPartSize:
-		size = maxDownloadPartSize
+	case size < downloadPartSize:
+		size = downloadPartSize
+	case size > downloadPartSize:
+		size = downloadPartSize
 	}
 
 	var offset int64
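
Note that both branches of the new switch assign downloadPartSize, so part size is effectively pinned to downloadPartSize regardless of b.Total. Worked example: a 2GB blob at the default 16MB part size splits into 2048 / 16 = 128 parts, where the old 100MB–1000MB clamp would have produced 16 parts of 128MB each.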

@@ -220,9 +310,6 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *registryOptions) error {
 		return err
 	}
 	defer file.Close()
-	setSparse(file)
-
-	_ = file.Truncate(b.Total)
 
 	directURL, err := func() (*url.URL, error) {
 		ctx, cancel := context.WithTimeout(ctx, 30*time.Second)

@@ -270,8 +357,13 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *registryOptions) error {
 		return err
 	}
 
+	// Download chunks in parallel, hash while writing via ordered writer.
+	// Memory: (concurrency × part_size) + buffer. With defaults:
+	// (8 × 16MB) + 128MB = 256MB max.
+	ow := newOrderedWriter(file, sha256.New())
+
 	g, inner := errgroup.WithContext(ctx)
-	g.SetLimit(numDownloadParts)
+	g.SetLimit(downloadConcurrency)
 	for i := range b.Parts {
 		part := b.Parts[i]
 		if part.Completed.Load() == part.Size {
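
errgroup's SetLimit caps in-flight goroutines: once downloadConcurrency part downloads are running, further g.Go calls block until one finishes, which is what keeps at most concurrency × part_size part bytes in flight. Minimal illustration (fetchPart is hypothetical):

	g := new(errgroup.Group)
	g.SetLimit(2) // at most two fetches in flight
	for i := 0; i < 5; i++ {
		i := i
		g.Go(func() error { // blocks here once two goroutines are active
			return fetchPart(i) // hypothetical fetch
		})
	}
	err := g.Wait()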
@@ -279,13 +371,12 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *registryOptions) error {
 	}
 
 	g.Go(func() error {
+		var data []byte
 		var err error
 		for try := 0; try < maxRetries; try++ {
-			w := io.NewOffsetWriter(file, part.StartsAt())
-			err = b.downloadChunk(inner, directURL, w, part)
+			data, err = b.downloadChunkToBuffer(inner, directURL, part)
 			switch {
 			case errors.Is(err, context.Canceled), errors.Is(err, syscall.ENOSPC):
 				// return immediately if the context is canceled or the device is out of space
 				return err
 			case errors.Is(err, errPartStalled):
 				try--

@@ -296,10 +387,11 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *registryOptions) error {
 				time.Sleep(sleep)
 				continue
 			default:
-				return nil
+				err := ow.Submit(part.N, data)
+				data = nil // help GC
+				return err
 			}
 		}
 
 		return fmt.Errorf("%w: %w", errMaxRetriesExceeded, err)
 	})
 }
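
The sleep used in the retry branch above is computed in lines elided from this hunk. Since "math" is among the file's imports, an exponential backoff along these lines is the likely shape (a sketch, not the verbatim code):

	sleep := time.Second * time.Duration(math.Pow(2, float64(try)))
	time.Sleep(sleep)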
@@ -308,6 +400,12 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *registryOptions) error {
 		return err
 	}
 
+	// Verify hash - no re-read needed, hash was computed while writing
+	computed := ow.Digest()
+	if computed != b.Digest {
+		return fmt.Errorf("digest mismatch: got %s, want %s", computed, b.Digest)
+	}
+
 	// explicitly close the file so we can rename it
 	if err := file.Close(); err != nil {
 		return err
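
The inline verification works because feeding parts through the hasher in write order is equivalent to hashing the finished file in one pass, so the post-download re-read of the whole blob can be dropped. A toy illustration:

	h := sha256.New()
	h.Write([]byte("part0"))
	h.Write([]byte("part1"))
	streamed := fmt.Sprintf("sha256:%x", h.Sum(nil))

	oneShot := sha256.Sum256([]byte("part0part1"))
	fmt.Println(streamed == fmt.Sprintf("sha256:%x", oneShot[:])) // true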
@@ -326,38 +424,58 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *registryOptions) error {
 	return nil
 }
 
-func (b *blobDownload) downloadChunk(ctx context.Context, requestURL *url.URL, w io.Writer, part *blobDownloadPart) error {
+// downloadChunkToBuffer downloads a part to a buffer, tracking progress and detecting stalls.
+func (b *blobDownload) downloadChunkToBuffer(ctx context.Context, requestURL *url.URL, part *blobDownloadPart) ([]byte, error) {
 	g, ctx := errgroup.WithContext(ctx)
+
+	var data []byte
 	g.Go(func() error {
 		req, err := http.NewRequestWithContext(ctx, http.MethodGet, requestURL.String(), nil)
 		if err != nil {
 			return err
 		}
-		req.Header.Set("Range", fmt.Sprintf("bytes=%d-%d", part.StartsAt(), part.StopsAt()-1))
+		req.Header.Set("Range", fmt.Sprintf("bytes=%d-%d", part.Offset, part.Offset+part.Size-1))
 		resp, err := http.DefaultClient.Do(req)
 		if err != nil {
 			return err
 		}
 		defer resp.Body.Close()
 
-		n, err := io.CopyN(w, io.TeeReader(resp.Body, part), part.Size-part.Completed.Load())
-		if err != nil && !errors.Is(err, context.Canceled) && !errors.Is(err, io.ErrUnexpectedEOF) {
-			// rollback progress
-			b.Completed.Add(-n)
-			return err
+		// Pre-allocate buffer for the part
+		data = make([]byte, 0, part.Size)
+		buf := make([]byte, 32*1024) // 32KB read buffer
+
+		for {
+			n, err := resp.Body.Read(buf)
+			if n > 0 {
+				data = append(data, buf[:n]...)
+				b.Completed.Add(int64(n))
+
+				part.lastUpdatedMu.Lock()
+				part.lastUpdated = time.Now()
+				part.lastUpdatedMu.Unlock()
+			}
+			if err == io.EOF {
+				break
+			}
+			if err != nil {
+				// rollback progress
+				b.Completed.Add(-int64(len(data)))
+				return err
+			}
 		}
 
-		part.Completed.Add(n)
+		part.Completed.Store(part.Size)
 		if err := b.writePart(part.Name(), part); err != nil {
 			return err
 		}
 
-		// return nil or context.Canceled or UnexpectedEOF (resumable)
-		return err
+		return nil
 	})
 
 	g.Go(func() error {
 		ticker := time.NewTicker(time.Second)
 		defer ticker.Stop()
 		for {
 			select {
 			case <-ticker.C:

@@ -384,7 +502,10 @@ func (b *blobDownload) downloadChunk(ctx context.Context, requestURL *url.URL, w io.Writer, part *blobDownloadPart) error {
 			}
 		})
 
-	return g.Wait()
+	if err := g.Wait(); err != nil {
+		return nil, err
+	}
+	return data, nil
 }
 
 func (b *blobDownload) newPart(offset, size int64) error {

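One detail worth noting in downloadChunkToBuffer: the HTTP Range header is inclusive on both ends, and the new form always requests the whole part from part.Offset, where the old form resumed from part.StartsAt() (offset plus already-completed bytes). Worked example for the second part at the default 16MB size:

	offset, size := int64(16*1024*1024), int64(16*1024*1024) // part.N == 1
	fmt.Printf("bytes=%d-%d\n", offset, offset+size-1)
	// bytes=16777216-33554431 → exactly 16,777,216 bytes
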
@@ -0,0 +1,166 @@
+package server
+
+import (
+	"bytes"
+	"crypto/rand"
+	"crypto/sha256"
+	"fmt"
+	"io"
+	"testing"
+)
+
+func TestOrderedWriter_InOrder(t *testing.T) {
+	var buf bytes.Buffer
+	hasher := sha256.New()
+	ow := newOrderedWriter(&buf, hasher)
+
+	// Submit parts in order
+	for i := 0; i < 5; i++ {
+		data := []byte{byte(i), byte(i), byte(i)}
+		if err := ow.Submit(i, data); err != nil {
+			t.Fatalf("Submit(%d) failed: %v", i, err)
+		}
+	}
+
+	// Verify output
+	expected := []byte{0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4}
+	if !bytes.Equal(buf.Bytes(), expected) {
+		t.Errorf("got %v, want %v", buf.Bytes(), expected)
+	}
+}
+
+func TestOrderedWriter_OutOfOrder(t *testing.T) {
+	var buf bytes.Buffer
+	hasher := sha256.New()
+	ow := newOrderedWriter(&buf, hasher)
+
+	// Submit parts out of order: 2, 4, 1, 0, 3
+	order := []int{2, 4, 1, 0, 3}
+	for _, i := range order {
+		data := []byte{byte(i), byte(i), byte(i)}
+		if err := ow.Submit(i, data); err != nil {
+			t.Fatalf("Submit(%d) failed: %v", i, err)
+		}
+	}
+
+	// Verify output is still in correct order
+	expected := []byte{0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4}
+	if !bytes.Equal(buf.Bytes(), expected) {
+		t.Errorf("got %v, want %v", buf.Bytes(), expected)
+	}
+}
+
+func TestOrderedWriter_Digest(t *testing.T) {
+	var buf bytes.Buffer
+	hasher := sha256.New()
+	ow := newOrderedWriter(&buf, hasher)
+
+	// Submit some data
+	data := []byte("hello world")
+	if err := ow.Submit(0, data); err != nil {
+		t.Fatalf("Submit failed: %v", err)
+	}
+
+	// Verify digest format and correctness
+	got := ow.Digest()
+	if len(got) != 71 { // "sha256:" + 64 hex chars
+		t.Errorf("digest has wrong length: %d, got: %s", len(got), got)
+	}
+	if got[:7] != "sha256:" {
+		t.Errorf("digest doesn't start with sha256: %s", got)
+	}
+
+	// Verify it matches expected hash
+	expectedHash := sha256.Sum256(data)
+	want := "sha256:" + fmt.Sprintf("%x", expectedHash[:])
+	if got != want {
+		t.Errorf("digest mismatch: got %s, want %s", got, want)
+	}
+}
+
+func BenchmarkOrderedWriter_InOrder(b *testing.B) {
+	// Benchmark throughput when parts arrive in order (best case)
+	partSize := 64 * 1024 * 1024 // 64MB parts
+	numParts := 4
+	data := make([]byte, partSize)
+	rand.Read(data)
+
+	b.SetBytes(int64(partSize * numParts))
+	b.ResetTimer()
+
+	for i := 0; i < b.N; i++ {
+		ow := newOrderedWriter(io.Discard, sha256.New())
+		for p := 0; p < numParts; p++ {
+			if err := ow.Submit(p, data); err != nil {
+				b.Fatal(err)
+			}
+		}
+	}
+}
+
+func BenchmarkOrderedWriter_OutOfOrder(b *testing.B) {
+	// Benchmark throughput when parts arrive out of order (worst case).
+	// Parts must be small enough that the buffered out-of-order parts stay
+	// under downloadBufferSize (128MB default); larger parts would make the
+	// single-threaded Submit loop block on backpressure and deadlock.
+	partSize := 16 * 1024 * 1024 // 16MB parts
+	numParts := 4
+	data := make([]byte, partSize)
+	rand.Read(data)
+
+	// Reverse order: 3, 2, 1, 0
+	order := make([]int, numParts)
+	for i := 0; i < numParts; i++ {
+		order[i] = numParts - 1 - i
+	}
+
+	b.SetBytes(int64(partSize * numParts))
+	b.ResetTimer()
+
+	for i := 0; i < b.N; i++ {
+		ow := newOrderedWriter(io.Discard, sha256.New())
+		for _, p := range order {
+			if err := ow.Submit(p, data); err != nil {
+				b.Fatal(err)
+			}
+		}
+	}
+}
+
+func BenchmarkHashThroughput(b *testing.B) {
+	// Baseline: raw SHA256 throughput on this machine
+	size := 256 * 1024 * 1024 // 256MB
+	data := make([]byte, size)
+	rand.Read(data)
+
+	b.SetBytes(int64(size))
+	b.ResetTimer()
+
+	for i := 0; i < b.N; i++ {
+		h := sha256.New()
+		h.Write(data)
+		h.Sum(nil)
+	}
+}
+
+func BenchmarkOrderedWriter_Memory(b *testing.B) {
+	// Measure memory when buffering out-of-order parts. As above, parts are
+	// sized so the pending buffer stays under downloadBufferSize.
+	partSize := 16 * 1024 * 1024 // 16MB parts
+	numParts := 4
+	data := make([]byte, partSize)
+	rand.Read(data)
+
+	b.ReportAllocs()
+	b.ResetTimer()
+
+	for i := 0; i < b.N; i++ {
+		ow := newOrderedWriter(io.Discard, sha256.New())
+		// Submit all except part 0 (forces buffering)
+		for p := 1; p < numParts; p++ {
+			if err := ow.Submit(p, data); err != nil {
+				b.Fatal(err)
+			}
+		}
+		// Submit part 0 to flush
+		if err := ow.Submit(0, data); err != nil {
+			b.Fatal(err)
+		}
+	}
+}
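
Assuming the tests land next to the download code in the server package, they run with the standard toolchain:

	go test ./server -run OrderedWriter -v
	go test ./server -run '^$' -bench 'OrderedWriter|HashThroughput' -benchmem
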
@@ -620,9 +620,8 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn func(api.ProgressResponse)) error {
 		layers = append(layers, manifest.Config)
 	}
 
-	skipVerify := make(map[string]bool)
 	for _, layer := range layers {
-		cacheHit, err := downloadBlob(ctx, downloadOpts{
+		_, err := downloadBlob(ctx, downloadOpts{
 			mp:      mp,
 			digest:  layer.Digest,
 			regOpts: regOpts,

@@ -631,31 +630,12 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn func(api.ProgressResponse)) error {
 		if err != nil {
 			return err
 		}
-		skipVerify[layer.Digest] = cacheHit
 		delete(deleteMap, layer.Digest)
 	}
 	delete(deleteMap, manifest.Config.Digest)
 
-	fn(api.ProgressResponse{Status: "verifying sha256 digest"})
-	for _, layer := range layers {
-		if skipVerify[layer.Digest] {
-			continue
-		}
-		if err := verifyBlob(layer.Digest); err != nil {
-			if errors.Is(err, errDigestMismatch) {
-				// something went wrong, delete the blob
-				fp, err := GetBlobsPath(layer.Digest)
-				if err != nil {
-					return err
-				}
-				if err := os.Remove(fp); err != nil {
-					// log this, but return the original error
-					slog.Info(fmt.Sprintf("couldn't remove file with digest mismatch '%s': %v", fp, err))
-				}
-			}
-			return err
-		}
-	}
+	// Note: Digest verification now happens inline during download in blobDownload.run()
+	// via the orderedWriter, so no separate verification pass is needed.
 
 	fn(api.ProgressResponse{Status: "writing manifest"})