server: stream hash verification during download

Hash blob data while it downloads (reading just-written bytes back from the OS page
cache wherever possible) instead of in a separate pass afterwards, improving effective
download speeds. Add configurable download concurrency (default 48) and part size
(default 64MB) for faster downloads on high-bandwidth connections.
jmorganca 2025-12-20 16:16:34 -08:00
parent 172b5924af
commit 2aee6c172b
5 changed files with 611 additions and 99 deletions
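
For illustration, here is a minimal, self-contained sketch of the pattern this commit introduces (hypothetical names and toy sizes, not code from the commit): parts are written concurrently at their file offsets, while a single hasher goroutine follows behind in part order, re-reading bytes that are normally still resident in the OS page cache.

package main

import (
	"crypto/sha256"
	"fmt"
	"os"
	"sync"
)

type part struct{ off, size int64 }

func main() {
	const partSize = 8 // toy part size for the demo
	data := []byte("hello, streaming hash verification!")

	f, _ := os.CreateTemp("", "demo")
	defer os.Remove(f.Name())

	var parts []part
	for off := int64(0); off < int64(len(data)); off += partSize {
		n := min(int64(len(data))-off, partSize)
		parts = append(parts, part{off, n})
	}

	// One "part is on disk" signal per part.
	done := make([]chan struct{}, len(parts))
	for i := range done {
		done[i] = make(chan struct{})
	}

	// "Download" parts concurrently (here: copy from memory).
	var wg sync.WaitGroup
	for i, p := range parts {
		wg.Add(1)
		go func(i int, p part) {
			defer wg.Done()
			f.WriteAt(data[p.off:p.off+p.size], p.off)
			close(done[i])
		}(i, p)
	}

	// Hash sequentially as parts complete; reads are served from the page cache.
	h := sha256.New()
	buf := make([]byte, partSize)
	for i, p := range parts {
		<-done[i]
		n, _ := f.ReadAt(buf[:p.size], p.off)
		h.Write(buf[:n])
	}
	wg.Wait()

	fmt.Printf("got  sha256:%x\n", h.Sum(nil))
	fmt.Printf("want sha256:%x\n", sha256.Sum256(data))
}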

server/download.go

@@ -2,9 +2,11 @@ package server
import (
"context"
"crypto/sha256"
"encoding/json"
"errors"
"fmt"
"hash"
"io"
"log/slog"
"math"
@@ -31,9 +33,45 @@ const maxRetries = 6
var (
errMaxRetriesExceeded = errors.New("max retries exceeded")
errPartStalled = errors.New("part stalled")
errPartSlow = errors.New("part slow, racing")
errMaxRedirectsExceeded = errors.New("maximum redirects exceeded (10) for directURL")
)
// speedTracker tracks download speeds and computes a rolling median.
type speedTracker struct {
mu sync.Mutex
speeds []float64 // bytes per second
}
func (s *speedTracker) Record(bytesPerSec float64) {
s.mu.Lock()
defer s.mu.Unlock()
s.speeds = append(s.speeds, bytesPerSec)
// Keep last 100 samples
if len(s.speeds) > 100 {
s.speeds = s.speeds[1:]
}
}
func (s *speedTracker) Median() float64 {
s.mu.Lock()
defer s.mu.Unlock()
if len(s.speeds) < 3 {
return 0 // not enough data
}
// Simple median: sort a copy and take middle
sorted := make([]float64, len(s.speeds))
copy(sorted, s.speeds)
for i := range sorted {
for j := i + 1; j < len(sorted); j++ {
if sorted[j] < sorted[i] {
sorted[i], sorted[j] = sorted[j], sorted[i]
}
}
}
return sorted[len(sorted)/2]
}
var blobDownloadManager sync.Map

type blobDownload struct {
@@ -94,26 +132,127 @@ func (p *blobDownloadPart) UnmarshalJSON(b []byte) error {
return nil
}
-const (
-numDownloadParts = 16
-minDownloadPartSize int64 = 100 * format.MegaByte
-maxDownloadPartSize int64 = 1000 * format.MegaByte
-)
var (
downloadPartSize = int64(envInt("OLLAMA_DOWNLOAD_PART_SIZE", 64)) * format.MegaByte
downloadConcurrency = envInt("OLLAMA_DOWNLOAD_CONCURRENCY", 48)
)
func envInt(key string, defaultVal int) int {
if s := os.Getenv(key); s != "" {
if v, err := strconv.Atoi(s); err == nil {
return v
}
}
return defaultVal
}
// streamHasher reads a file sequentially and hashes it as chunks complete.
// Memory usage: ~64KB (just the read buffer), regardless of file size or concurrency.
// Works by reading from OS page cache - data just written is still in RAM.
type streamHasher struct {
file *os.File
hasher hash.Hash
parts []*blobDownloadPart
total int64 // total bytes to hash
hashed atomic.Int64
mu sync.Mutex
cond *sync.Cond
completed []bool
done bool
err error
}
func newStreamHasher(file *os.File, parts []*blobDownloadPart, total int64) *streamHasher {
h := &streamHasher{
file: file,
hasher: sha256.New(),
parts: parts,
total: total,
completed: make([]bool, len(parts)),
}
h.cond = sync.NewCond(&h.mu)
return h
}
// MarkComplete signals that a part has been written to disk.
func (h *streamHasher) MarkComplete(partIndex int) {
h.mu.Lock()
h.completed[partIndex] = true
h.cond.Broadcast()
h.mu.Unlock()
}
// Run reads and hashes the file sequentially. Call in a goroutine.
func (h *streamHasher) Run() {
buf := make([]byte, 64*1024) // 64KB read buffer
var offset int64
for i, part := range h.parts {
// Wait for this part to be written
h.mu.Lock()
for !h.completed[i] && !h.done {
h.cond.Wait()
}
if h.done {
h.mu.Unlock()
return
}
h.mu.Unlock()
// Read and hash this part (from page cache)
remaining := part.Size
for remaining > 0 {
n := int64(len(buf))
if n > remaining {
n = remaining
}
nr, err := h.file.ReadAt(buf[:n], offset)
if err != nil && err != io.EOF {
h.mu.Lock()
h.err = err
h.mu.Unlock()
return
}
h.hasher.Write(buf[:nr])
offset += int64(nr)
remaining -= int64(nr)
h.hashed.Store(offset)
}
}
}
// Stop signals the hasher to exit early.
func (h *streamHasher) Stop() {
h.mu.Lock()
h.done = true
h.cond.Broadcast()
h.mu.Unlock()
}
// Hashed returns bytes hashed so far.
func (h *streamHasher) Hashed() int64 {
return h.hashed.Load()
}
// Digest returns the computed hash.
func (h *streamHasher) Digest() string {
return fmt.Sprintf("sha256:%x", h.hasher.Sum(nil))
}
// Err returns any error from hashing.
func (h *streamHasher) Err() error {
h.mu.Lock()
defer h.mu.Unlock()
return h.err
}
func (p *blobDownloadPart) Name() string {
return strings.Join([]string{
p.blobDownload.Name, "partial", strconv.Itoa(p.N),
}, "-")
}
func (p *blobDownloadPart) StartsAt() int64 {
return p.Offset + p.Completed.Load()
}
func (p *blobDownloadPart) StopsAt() int64 {
return p.Offset + p.Size
}
func (p *blobDownloadPart) Write(b []byte) (n int, err error) {
n = len(b)
p.blobDownload.Completed.Add(int64(n))
@@ -151,14 +290,7 @@ func (b *blobDownload) Prepare(ctx context.Context, requestURL *url.URL, opts *r
b.Total, _ = strconv.ParseInt(resp.Header.Get("Content-Length"), 10, 64)
-size := b.Total / numDownloadParts
size := downloadPartSize
-switch {
-case size < minDownloadPartSize:
-size = minDownloadPartSize
-case size > maxDownloadPartSize:
-size = maxDownloadPartSize
-}
var offset int64
for offset < b.Total {
if offset+size > b.Total {
@@ -220,9 +352,6 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *regis
return err
}
defer file.Close()
-setSparse(file)
-_ = file.Truncate(b.Total)
directURL, err := func() (*url.URL, error) {
ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
@@ -270,44 +399,106 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *regis
return err
}
// Download chunks to disk, hash by reading from page cache.
// Memory: ~64KB (hasher read buffer only), regardless of concurrency.
// The hasher follows behind the downloaders, reading recently-written
// data from OS page cache (RAM) rather than disk.
sh := newStreamHasher(file, b.Parts, b.Total)
tracker := &speedTracker{}
// Start hasher goroutine
hashDone := make(chan struct{})
go func() {
sh.Run()
close(hashDone)
}()
// Log progress periodically
// Page cache warning: if spread > 4GB, hasher may hit disk instead of RAM
const pageCacheWarningBytes = 4 << 30 // 4GB
progressDone := make(chan struct{})
go func() {
ticker := time.NewTicker(time.Second)
defer ticker.Stop()
for {
select {
case <-ticker.C:
downloaded := b.Completed.Load()
hashed := sh.Hashed()
dlPct := int(downloaded * 100 / b.Total)
hPct := int(hashed * 100 / b.Total)
spread := dlPct - hPct
spreadBytes := downloaded - hashed
slog.Debug(fmt.Sprintf("progress: downloaded %d%% | hashed %d%% | spread %d%%", dlPct, hPct, spread))
if spreadBytes > pageCacheWarningBytes {
slog.Debug("page cache pressure", "ahead", fmt.Sprintf("%.1fGB", float64(spreadBytes)/(1<<30)))
}
case <-progressDone:
return
}
}
}()
g, inner := errgroup.WithContext(ctx)
-g.SetLimit(numDownloadParts)
g.SetLimit(downloadConcurrency)
for i := range b.Parts {
part := b.Parts[i]
if part.Completed.Load() == part.Size {
sh.MarkComplete(part.N)
continue
}
g.Go(func() error {
var err error
var slowRetries int
for try := 0; try < maxRetries; try++ {
-w := io.NewOffsetWriter(file, part.StartsAt())
-err = b.downloadChunk(inner, directURL, w, part)
// After 3 slow retries, stop checking slowness and let it complete
skipSlowCheck := slowRetries >= 3
err = b.downloadChunkToDisk(inner, directURL, file, part, tracker, skipSlowCheck)
switch {
case errors.Is(err, context.Canceled), errors.Is(err, syscall.ENOSPC):
-// return immediately if the context is canceled or the device is out of space
return err
case errors.Is(err, errPartStalled):
try--
continue
case errors.Is(err, errPartSlow):
// Kill slow request, retry immediately (stays within concurrency limit)
slowRetries++
try--
continue
case err != nil:
sleep := time.Second * time.Duration(math.Pow(2, float64(try)))
slog.Info(fmt.Sprintf("%s part %d attempt %d failed: %v, retrying in %s", b.Digest[7:19], part.N, try, err, sleep))
time.Sleep(sleep)
continue
default:
sh.MarkComplete(part.N)
return nil
}
}
return fmt.Errorf("%w: %w", errMaxRetriesExceeded, err)
})
}
if err := g.Wait(); err != nil {
close(progressDone)
sh.Stop()
return err
}
// Wait for hasher to finish
<-hashDone
close(progressDone)
if err := sh.Err(); err != nil {
return err
}
// Verify hash
if computed := sh.Digest(); computed != b.Digest {
return fmt.Errorf("digest mismatch: got %s, want %s", computed, b.Digest)
}
// explicitly close the file so we can rename it
if err := file.Close(); err != nil {
return err
@@ -326,38 +517,69 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *regis
return nil
}
-func (b *blobDownload) downloadChunk(ctx context.Context, requestURL *url.URL, w io.Writer, part *blobDownloadPart) error {
// downloadChunkToDisk streams a part directly to disk at its offset.
// Memory: ~32KB (read buffer only).
// If skipSlowCheck is true, don't flag slow parts (used after repeated slow retries).
func (b *blobDownload) downloadChunkToDisk(ctx context.Context, requestURL *url.URL, file *os.File, part *blobDownloadPart, tracker *speedTracker, skipSlowCheck bool) error {
g, ctx := errgroup.WithContext(ctx)
startTime := time.Now()
var bytesAtLastCheck atomic.Int64
g.Go(func() error {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, requestURL.String(), nil)
if err != nil {
return err
}
-req.Header.Set("Range", fmt.Sprintf("bytes=%d-%d", part.StartsAt(), part.StopsAt()-1))
req.Header.Set("Range", fmt.Sprintf("bytes=%d-%d", part.Offset, part.Offset+part.Size-1))
resp, err := http.DefaultClient.Do(req)
if err != nil {
return err
}
defer resp.Body.Close()
-n, err := io.CopyN(w, io.TeeReader(resp.Body, part), part.Size-part.Completed.Load())
-if err != nil && !errors.Is(err, context.Canceled) && !errors.Is(err, io.ErrUnexpectedEOF) {
-// rollback progress
-b.Completed.Add(-n)
-return err
-}
-part.Completed.Add(n)
-if err := b.writePart(part.Name(), part); err != nil {
-return err
-}
-// return nil or context.Canceled or UnexpectedEOF (resumable)
-return err
w := io.NewOffsetWriter(file, part.Offset)
buf := make([]byte, 32*1024)
var written int64
for written < part.Size {
n, err := resp.Body.Read(buf)
if n > 0 {
if _, werr := w.Write(buf[:n]); werr != nil {
return werr
}
written += int64(n)
b.Completed.Add(int64(n))
bytesAtLastCheck.Store(written)
part.lastUpdatedMu.Lock()
part.lastUpdated = time.Now()
part.lastUpdatedMu.Unlock()
}
if err == io.EOF {
break
}
if err != nil {
b.Completed.Add(-written)
return err
}
}
// Record speed for this part
elapsed := time.Since(startTime).Seconds()
if elapsed > 0 {
tracker.Record(float64(part.Size) / elapsed)
}
part.Completed.Store(part.Size)
return b.writePart(part.Name(), part)
})
g.Go(func() error {
ticker := time.NewTicker(time.Second)
defer ticker.Stop()
var lastBytes int64
checksWithoutProgress := 0
for {
select {
case <-ticker.C:
@@ -365,19 +587,35 @@ func (b *blobDownload) downloadChunk(ctx context.Context, requestURL *url.URL, w
return nil
}
-part.lastUpdatedMu.Lock()
-lastUpdated := part.lastUpdated
-part.lastUpdatedMu.Unlock()
-if !lastUpdated.IsZero() && time.Since(lastUpdated) > 30*time.Second {
-const msg = "%s part %d stalled; retrying. If this persists, press ctrl-c to exit, then 'ollama pull' to find a faster connection."
-slog.Info(fmt.Sprintf(msg, b.Digest[7:19], part.N))
-// reset last updated
-part.lastUpdatedMu.Lock()
-part.lastUpdated = time.Time{}
-part.lastUpdatedMu.Unlock()
-return errPartStalled
-}
currentBytes := bytesAtLastCheck.Load()
// Check for stall (no progress for 10 seconds)
if currentBytes == lastBytes {
checksWithoutProgress++
if checksWithoutProgress >= 10 {
slog.Info(fmt.Sprintf("%s part %d stalled; retrying", b.Digest[7:19], part.N))
return errPartStalled
}
} else {
checksWithoutProgress = 0
}
lastBytes = currentBytes
// Check for slow speed after 5+ seconds (only for multi-part downloads)
// Skip if we've already retried for slowness too many times
elapsed := time.Since(startTime).Seconds()
if !skipSlowCheck && elapsed >= 5 && currentBytes > 0 && len(b.Parts) > 1 {
currentSpeed := float64(currentBytes) / elapsed
median := tracker.Median()
// If we're below 10% of median speed, flag as slow
if median > 0 && currentSpeed < median*0.1 {
slog.Info(fmt.Sprintf("%s part %d slow (%.0f KB/s vs median %.0f KB/s); retrying",
b.Digest[7:19], part.N, currentSpeed/1024, median/1024))
return errPartSlow
}
}
case <-ctx.Done():
return ctx.Err()
}
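
A note on the defaults above, as a back-of-the-envelope sketch using only values from the diff: 48 concurrent parts of 64 MB mean up to 48 × 64 MB = 3 GB can be in flight from active parts alone, and the downloader-to-hasher spread can grow further once completed parts queue up faster than SHA-256 drains them, which is why the progress logger warns at the 4 GB threshold. The envInt helper below is copied from the diff; the main function is illustrative only.

package main

import (
	"fmt"
	"os"
	"strconv"
)

// Same override pattern as envInt in the diff above.
func envInt(key string, def int) int {
	if s := os.Getenv(key); s != "" {
		if v, err := strconv.Atoi(s); err == nil {
			return v
		}
	}
	return def
}

func main() {
	partMB := envInt("OLLAMA_DOWNLOAD_PART_SIZE", 64)
	conc := envInt("OLLAMA_DOWNLOAD_CONCURRENCY", 48)
	// 3072 MB with the defaults, comfortably under the 4 GB warning threshold.
	fmt.Printf("worst-case bytes in flight from active parts: %d MB\n", partMB*conc)
}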

server/download_test.go (new file)

@@ -0,0 +1,319 @@
package server
import (
"crypto/rand"
"crypto/sha256"
"fmt"
"os"
"sync"
"testing"
)
func TestSpeedTracker_Median(t *testing.T) {
s := &speedTracker{}
// Less than 3 samples returns 0
s.Record(100)
s.Record(200)
if got := s.Median(); got != 0 {
t.Errorf("expected 0 with < 3 samples, got %f", got)
}
// With 3+ samples, returns median
s.Record(300)
// Samples: [100, 200, 300] -> median = 200
if got := s.Median(); got != 200 {
t.Errorf("expected median 200, got %f", got)
}
// Add more samples
s.Record(50)
s.Record(250)
// Samples: [100, 200, 300, 50, 250] sorted = [50, 100, 200, 250, 300] -> median = 200
if got := s.Median(); got != 200 {
t.Errorf("expected median 200, got %f", got)
}
}
func TestSpeedTracker_RollingWindow(t *testing.T) {
s := &speedTracker{}
// Add 105 samples (should keep only last 100)
for i := 0; i < 105; i++ {
s.Record(float64(i))
}
s.mu.Lock()
if len(s.speeds) != 100 {
t.Errorf("expected 100 samples, got %d", len(s.speeds))
}
// First sample should be 5 (0-4 were dropped)
if s.speeds[0] != 5 {
t.Errorf("expected first sample to be 5, got %f", s.speeds[0])
}
s.mu.Unlock()
}
func TestSpeedTracker_Concurrent(t *testing.T) {
s := &speedTracker{}
var wg sync.WaitGroup
for i := 0; i < 100; i++ {
wg.Add(1)
go func(v int) {
defer wg.Done()
s.Record(float64(v))
s.Median() // concurrent read
}(i)
}
wg.Wait()
// Should not panic, and should have reasonable state
s.mu.Lock()
if len(s.speeds) == 0 || len(s.speeds) > 100 {
t.Errorf("unexpected speeds length: %d", len(s.speeds))
}
s.mu.Unlock()
}
func TestStreamHasher_Sequential(t *testing.T) {
// Create temp file
f, err := os.CreateTemp("", "streamhasher_test")
if err != nil {
t.Fatal(err)
}
defer os.Remove(f.Name())
defer f.Close()
// Write test data
data := []byte("hello world, this is a test of the stream hasher")
if _, err := f.Write(data); err != nil {
t.Fatal(err)
}
// Create parts
parts := []*blobDownloadPart{
{Offset: 0, Size: int64(len(data))},
}
sh := newStreamHasher(f, parts, int64(len(data)))
// Mark complete and run
sh.MarkComplete(0)
done := make(chan struct{})
go func() {
sh.Run()
close(done)
}()
<-done
// Verify digest
expected := fmt.Sprintf("sha256:%x", sha256.Sum256(data))
if got := sh.Digest(); got != expected {
t.Errorf("digest mismatch: got %s, want %s", got, expected)
}
if err := sh.Err(); err != nil {
t.Errorf("unexpected error: %v", err)
}
}
func TestStreamHasher_OutOfOrderCompletion(t *testing.T) {
// Create temp file
f, err := os.CreateTemp("", "streamhasher_test")
if err != nil {
t.Fatal(err)
}
defer os.Remove(f.Name())
defer f.Close()
// Write test data (3 parts of 10 bytes each)
data := []byte("0123456789ABCDEFGHIJabcdefghij")
if _, err := f.Write(data); err != nil {
t.Fatal(err)
}
// Create 3 parts
parts := []*blobDownloadPart{
{N: 0, Offset: 0, Size: 10},
{N: 1, Offset: 10, Size: 10},
{N: 2, Offset: 20, Size: 10},
}
sh := newStreamHasher(f, parts, int64(len(data)))
done := make(chan struct{})
go func() {
sh.Run()
close(done)
}()
// Mark parts complete out of order: 2, 0, 1
sh.MarkComplete(2)
sh.MarkComplete(0) // This should trigger hashing of part 0
sh.MarkComplete(1) // This should trigger hashing of parts 1 and 2
<-done
// Verify digest
expected := fmt.Sprintf("sha256:%x", sha256.Sum256(data))
if got := sh.Digest(); got != expected {
t.Errorf("digest mismatch: got %s, want %s", got, expected)
}
}
func TestStreamHasher_Stop(t *testing.T) {
// Create temp file
f, err := os.CreateTemp("", "streamhasher_test")
if err != nil {
t.Fatal(err)
}
defer os.Remove(f.Name())
defer f.Close()
parts := []*blobDownloadPart{
{Offset: 0, Size: 100},
}
sh := newStreamHasher(f, parts, 100)
done := make(chan struct{})
go func() {
sh.Run()
close(done)
}()
// Stop without completing any parts
sh.Stop()
<-done
// Should exit cleanly without error
if err := sh.Err(); err != nil {
t.Errorf("unexpected error after Stop: %v", err)
}
}
func TestStreamHasher_HashedProgress(t *testing.T) {
// Create temp file with known data
f, err := os.CreateTemp("", "streamhasher_test")
if err != nil {
t.Fatal(err)
}
defer os.Remove(f.Name())
defer f.Close()
data := make([]byte, 1000)
rand.Read(data)
if _, err := f.Write(data); err != nil {
t.Fatal(err)
}
parts := []*blobDownloadPart{
{N: 0, Offset: 0, Size: 500},
{N: 1, Offset: 500, Size: 500},
}
sh := newStreamHasher(f, parts, 1000)
// Initially no progress
if got := sh.Hashed(); got != 0 {
t.Errorf("expected 0 hashed initially, got %d", got)
}
done := make(chan struct{})
go func() {
sh.Run()
close(done)
}()
// Complete part 0
sh.MarkComplete(0)
// Give hasher time to process
for i := 0; i < 100; i++ {
if sh.Hashed() >= 500 {
break
}
}
// Complete part 1
sh.MarkComplete(1)
<-done
if got := sh.Hashed(); got != 1000 {
t.Errorf("expected 1000 hashed, got %d", got)
}
}
func BenchmarkSpeedTracker_Record(b *testing.B) {
s := &speedTracker{}
b.ResetTimer()
for i := 0; i < b.N; i++ {
s.Record(float64(i))
}
}
func BenchmarkSpeedTracker_Median(b *testing.B) {
s := &speedTracker{}
// Pre-populate with 100 samples
for i := 0; i < 100; i++ {
s.Record(float64(i))
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
s.Median()
}
}
func BenchmarkStreamHasher(b *testing.B) {
// Create temp file with test data
f, err := os.CreateTemp("", "streamhasher_bench")
if err != nil {
b.Fatal(err)
}
defer os.Remove(f.Name())
defer f.Close()
size := 64 * 1024 * 1024 // 64MB
data := make([]byte, size)
rand.Read(data)
if _, err := f.Write(data); err != nil {
b.Fatal(err)
}
parts := []*blobDownloadPart{
{Offset: 0, Size: int64(size)},
}
b.SetBytes(int64(size))
b.ResetTimer()
for i := 0; i < b.N; i++ {
sh := newStreamHasher(f, parts, int64(size))
sh.MarkComplete(0)
done := make(chan struct{})
go func() {
sh.Run()
close(done)
}()
<-done
}
}
func BenchmarkHashThroughput(b *testing.B) {
// Baseline: raw SHA256 throughput on this machine
size := 256 * 1024 * 1024 // 256MB
data := make([]byte, size)
rand.Read(data)
b.SetBytes(int64(size))
b.ResetTimer()
for i := 0; i < b.N; i++ {
h := sha256.New()
h.Write(data)
h.Sum(nil)
}
}
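
To make the racing heuristic concrete, here is a hypothetical example test (not part of the commit) against the speedTracker above: once the rolling median settles near 12 MB/s, a part averaging 1 MB/s after its 5-second warmup falls below the 10% threshold and is cancelled with errPartSlow. Such retries decrement try, so they do not count against maxRetries, and after three of them skipSlowCheck lets the part finish however slowly.

package server

import "fmt"

func Example_slowPartHeuristic() {
	tracker := &speedTracker{}
	for i := 0; i < 10; i++ {
		tracker.Record(12 * 1024 * 1024) // ten healthy parts at ~12 MB/s
	}
	median := tracker.Median()    // 12 MB/s
	current := 1.0 * 1024 * 1024 // a part crawling at 1 MB/s
	fmt.Println(median > 0 && current < median*0.1)
	// Output: true
}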

server/images.go

@@ -620,9 +620,8 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
layers = append(layers, manifest.Config)
}
-skipVerify := make(map[string]bool)
for _, layer := range layers {
-cacheHit, err := downloadBlob(ctx, downloadOpts{
_, err := downloadBlob(ctx, downloadOpts{
mp: mp,
digest: layer.Digest,
regOpts: regOpts,
@@ -631,31 +630,12 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
if err != nil {
return err
}
-skipVerify[layer.Digest] = cacheHit
delete(deleteMap, layer.Digest)
}
delete(deleteMap, manifest.Config.Digest)
-fn(api.ProgressResponse{Status: "verifying sha256 digest"})
-for _, layer := range layers {
-if skipVerify[layer.Digest] {
-continue
-}
-if err := verifyBlob(layer.Digest); err != nil {
-if errors.Is(err, errDigestMismatch) {
-// something went wrong, delete the blob
-fp, err := GetBlobsPath(layer.Digest)
-if err != nil {
-return err
-}
-if err := os.Remove(fp); err != nil {
-// log this, but return the original error
-slog.Info(fmt.Sprintf("couldn't remove file with digest mismatch '%s': %v", fp, err))
-}
-}
-return err
-}
-}
// Note: digest verification now happens inline during download in blobDownload.run()
// via the streamHasher, so no separate verification pass is needed.
fn(api.ProgressResponse{Status: "writing manifest"})
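
One way to see the payoff of dropping the separate verification pass is a simple overlap model (illustrative, with assumed numbers; the hash rate would come from something like BenchmarkHashThroughput above): total pull time is now bounded by the slower of the network and the hasher rather than by their sum.

package main

import "fmt"

func main() {
	const blobGB = 50.0  // hypothetical model size
	const netGBps = 1.0  // assumed network throughput
	const hashGBps = 2.0 // assumed SHA-256 throughput

	oldTime := blobGB/netGBps + blobGB/hashGBps // old: download, then verify
	newTime := blobGB / netGBps                 // new: hashing overlaps the download
	if hashGBps < netGBps {
		newTime = blobGB / hashGBps // hash-bound when the network is faster
	}
	fmt.Printf("old: %.0fs  new: %.0fs\n", oldTime, newTime) // old: 75s  new: 50s
}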

server/sparse_common.go (deleted)

@@ -1,8 +0,0 @@
-//go:build !windows
-package server
-import "os"
-func setSparse(*os.File) {
-}

server/sparse_windows.go (deleted)

@@ -1,17 +0,0 @@
-package server
-import (
-"os"
-"golang.org/x/sys/windows"
-)
-func setSparse(file *os.File) {
-// exFat (and other FS types) don't support sparse files, so ignore errors
-windows.DeviceIoControl( //nolint:errcheck
-windows.Handle(file.Fd()), windows.FSCTL_SET_SPARSE,
-nil, 0,
-nil, 0,
-nil, nil,
-)
-}