Compare commits

..

25 Commits

Author SHA1 Message Date
Patrick Devine
9a483dc7b7 refactor to use client.List instead of walking the filesystem 2024-01-26 18:34:42 -08:00
Patrick Devine
366b38460f fix linter 2024-01-26 18:34:42 -08:00
Patrick Devine
021b1bdc4a add --upgrade-all flag to refresh any stale models 2024-01-26 18:34:40 -08:00
Patrick Devine
b5cf31b460 add keep_alive to generate/chat/embedding api endpoints (#2146) 2024-01-26 14:28:02 -08:00
Daniel Hiltgen
cc4915e262 Merge pull request #2214 from dhiltgen/reject_cuda_without_avx
Detect lack of AVX and fallback to CPU mode
2024-01-26 12:06:44 -08:00
Daniel Hiltgen
667a2ba18a Detect lack of AVX and fallback to CPU mode
We build the GPU libraries with AVX enabled to ensure that if not all
layers fit on the GPU we get better performance in a mixed mode.
If the user is using a virtualization/emulation system that lacks AVX
this used to result in an illegal instruction error and crash before this
fix.  Now we will report a warning in the server log, and just use
CPU mode to ensure we don't crash.
2024-01-26 11:36:03 -08:00
Michael Yang
e054ebe059 Merge pull request #2212 from ollama/mxyng/fix-build
fix build
2024-01-26 11:19:08 -08:00
Michael Yang
9d3dcfd0ec fix logging 2024-01-26 11:04:27 -08:00
Michael Yang
6e0ea5ecc8 Merge pull request #1916 from ollama/mxyng/inactivity-monitor
download: add inactivity monitor
2024-01-26 10:56:00 -08:00
Daniel Hiltgen
a47d8b2557 Merge pull request #2197 from dhiltgen/remove_rocm_image
Add back ROCm container support
2024-01-26 09:34:23 -08:00
Daniel Hiltgen
30c43c285c Merge pull request #2195 from dhiltgen/rocm_real_gpus
Ignore AMD integrated GPUs
2024-01-26 09:30:24 -08:00
Daniel Hiltgen
23a7ea593b Merge pull request #2209 from dhiltgen/harden_mgmt
Fix crash on cuda ml init failure
2024-01-26 09:30:13 -08:00
Daniel Hiltgen
75c44aa319 Add back ROCm container support
This adds ROCm support back as a discrete image.
2024-01-26 09:24:29 -08:00
Daniel Hiltgen
9d7b5d6c91 Ignore AMD integrated GPUs
Detect and ignore integrated GPUs reported by rocm.
2024-01-26 09:21:35 -08:00
Daniel Hiltgen
5d9c4a5f5a Fix crash on cuda ml init failure
The new driver lookup code was triggering after init failure due to a missing return
2024-01-26 09:18:33 -08:00
Daniel Hiltgen
197e420a97 Merge pull request #2196 from dhiltgen/remove_rocm_image
Switch back to ubuntu base
2024-01-25 16:50:32 -08:00
Daniel Hiltgen
a34e1ad3cf Switch back to ubuntu base
The size increase for rocm support in the standard image is problematic
We'll revisit multiple tags for rocm support in a follow up PR.
2024-01-25 16:46:01 -08:00
Michael Yang
2ae0556292 Merge pull request #1679 from ollama/mxyng/build-gpus
build cuda and rocm
2024-01-25 16:38:14 -08:00
Jeffrey Morgan
5be9bdd444 Update modelfile.md 2024-01-25 16:29:48 -08:00
Jeffrey Morgan
b706794905 Update modelfile.md to include MESSAGE 2024-01-25 16:29:32 -08:00
Michael Yang
a8c5413d06 only generate gpu libs 2024-01-25 15:41:56 -08:00
Michael Yang
5580de4571 archive ollama binaries 2024-01-25 15:40:16 -08:00
Michael Yang
946431d5b0 build cuda and rocm 2024-01-25 15:40:15 -08:00
Michael Yang
0610126049 remove env setting 2024-01-25 15:39:43 -08:00
Michael Yang
27331ae3a8 download: add inactivity monitor
if a download part is inactive for some time, restart it
2024-01-12 15:23:15 -08:00
14 changed files with 348 additions and 102 deletions

View File

@@ -23,27 +23,71 @@ jobs:
with:
go-version: '1.21'
cache: true
- if: ${{ startsWith(matrix.os, 'windows-') }}
shell: pwsh
run: |
$path = vswhere -latest -products * -requires Microsoft.VisualStudio.Component.VC.Tools.x86.x64 -property installationPath
if ($path) {
$path = join-path $path 'Common7\Tools\vsdevcmd.bat'
if (test-path $path) {
cmd /s /c """$path"" $args && set" | where { $_ -match '(\w+)=(.*)' } | foreach {
echo "$($Matches[1])=$($Matches[2])" | Out-File -FilePath $Env:GITHUB_ENV -Encoding utf8 -Append
}
}
}
echo "C:\Program Files\Git\usr\bin" | Out-File -FilePath $Env:GITHUB_PATH -Encoding utf8 -Append
- run: go get ./...
- run: go generate -x ./...
- uses: actions/upload-artifact@v4
with:
name: ${{ matrix.os }}-${{ matrix.arch }}-libraries
path: |
llm/llama.cpp/build/**/lib/*
path: llm/llama.cpp/build/**/lib/*
generate-cuda:
strategy:
matrix:
cuda-version:
- '11.8.0'
runs-on: ubuntu-latest
container: nvidia/cuda:${{ matrix.cuda-version }}-devel-ubuntu20.04
steps:
- run: |
apt-get update && apt-get install -y git build-essential curl
curl -fsSL https://github.com/Kitware/CMake/releases/download/v3.28.1/cmake-3.28.1-linux-x86_64.tar.gz \
| tar -zx -C /usr --strip-components 1
env:
DEBIAN_FRONTEND: noninteractive
- uses: actions/checkout@v4
- uses: actions/setup-go@v4
with:
go-version: '1.21'
cache: true
- run: go get ./...
- run: |
git config --global --add safe.directory /__w/ollama/ollama
go generate -x ./...
env:
OLLAMA_SKIP_CPU_GENERATE: '1'
- uses: actions/upload-artifact@v4
with:
name: cuda-${{ matrix.cuda-version }}-libraries
path: llm/llama.cpp/build/**/lib/*
generate-rocm:
strategy:
matrix:
rocm-version:
- '5.7.1'
- '6.0'
runs-on: ubuntu-latest
container: rocm/dev-ubuntu-20.04:${{ matrix.rocm-version }}
steps:
- run: |
apt-get update && apt-get install -y git build-essential curl rocm-libs
curl -fsSL https://github.com/Kitware/CMake/releases/download/v3.28.1/cmake-3.28.1-linux-x86_64.tar.gz \
| tar -zx -C /usr --strip-components 1
env:
DEBIAN_FRONTEND: noninteractive
- uses: actions/checkout@v4
- uses: actions/setup-go@v4
with:
go-version: '1.21'
cache: true
- run: go get ./...
- run: |
git config --global --add safe.directory /__w/ollama/ollama
go generate -x ./...
env:
OLLAMA_SKIP_CPU_GENERATE: '1'
- uses: actions/upload-artifact@v4
with:
name: rocm-${{ matrix.rocm-version }}-libraries
path: llm/llama.cpp/build/**/lib/*
lint:
strategy:
matrix:
@@ -112,3 +156,7 @@ jobs:
path: llm/llama.cpp/build
- run: go build
- run: go test -v ./...
- uses: actions/upload-artifact@v4
with:
name: ${{ matrix.os }}-binaries
path: ollama

View File

@@ -109,17 +109,28 @@ ARG CGO_CFLAGS
RUN go build .
# Runtime stages
FROM --platform=linux/amd64 rocm/dev-centos-7:6.0-complete as runtime-amd64
FROM --platform=linux/amd64 ubuntu:22.04 as runtime-amd64
RUN apt-get update && apt-get install -y ca-certificates
COPY --from=build-amd64 /go/src/github.com/jmorganca/ollama/ollama /bin/ollama
FROM --platform=linux/arm64 ubuntu:22.04 as runtime-arm64
RUN apt-get update && apt-get install -y ca-certificates
COPY --from=build-arm64 /go/src/github.com/jmorganca/ollama/ollama /bin/ollama
# Radeon images are much larger so we keep it distinct from the CPU/CUDA image
FROM --platform=linux/amd64 rocm/dev-centos-7:5.7.1-complete as runtime-rocm
RUN update-pciids
COPY --from=build-amd64 /go/src/github.com/jmorganca/ollama/ollama /bin/ollama
EXPOSE 11434
ENV OLLAMA_HOST 0.0.0.0
ENTRYPOINT ["/bin/ollama"]
CMD ["serve"]
FROM runtime-$TARGETARCH
EXPOSE 11434
ENV OLLAMA_HOST 0.0.0.0
ENV PATH=/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin
ENV LD_LIBRARY_PATH=/usr/local/nvidia/lib:/usr/local/nvidia/lib64:/opt/rocm/lib:
ENV LD_LIBRARY_PATH=/usr/local/nvidia/lib:/usr/local/nvidia/lib64
ENV NVIDIA_DRIVER_CAPABILITIES=compute,utility
ENTRYPOINT ["/bin/ollama"]

View File

@@ -34,24 +34,26 @@ func (e StatusError) Error() string {
type ImageData []byte
type GenerateRequest struct {
Model string `json:"model"`
Prompt string `json:"prompt"`
System string `json:"system"`
Template string `json:"template"`
Context []int `json:"context,omitempty"`
Stream *bool `json:"stream,omitempty"`
Raw bool `json:"raw,omitempty"`
Format string `json:"format"`
Images []ImageData `json:"images,omitempty"`
Model string `json:"model"`
Prompt string `json:"prompt"`
System string `json:"system"`
Template string `json:"template"`
Context []int `json:"context,omitempty"`
Stream *bool `json:"stream,omitempty"`
Raw bool `json:"raw,omitempty"`
Format string `json:"format"`
KeepAlive *Duration `json:"keep_alive,omitempty"`
Images []ImageData `json:"images,omitempty"`
Options map[string]interface{} `json:"options"`
}
type ChatRequest struct {
Model string `json:"model"`
Messages []Message `json:"messages"`
Stream *bool `json:"stream,omitempty"`
Format string `json:"format"`
Model string `json:"model"`
Messages []Message `json:"messages"`
Stream *bool `json:"stream,omitempty"`
Format string `json:"format"`
KeepAlive *Duration `json:"keep_alive,omitempty"`
Options map[string]interface{} `json:"options"`
}
@@ -126,8 +128,9 @@ type Runner struct {
}
type EmbeddingRequest struct {
Model string `json:"model"`
Prompt string `json:"prompt"`
Model string `json:"model"`
Prompt string `json:"prompt"`
KeepAlive *Duration `json:"keep_alive,omitempty"`
Options map[string]interface{} `json:"options"`
}
@@ -180,11 +183,12 @@ type CopyRequest struct {
}
type PullRequest struct {
Model string `json:"model"`
Insecure bool `json:"insecure,omitempty"`
Username string `json:"username"`
Password string `json:"password"`
Stream *bool `json:"stream,omitempty"`
Model string `json:"model"`
Insecure bool `json:"insecure,omitempty"`
Username string `json:"username"`
Password string `json:"password"`
Stream *bool `json:"stream,omitempty"`
CurrentDigest string `json:"current_digest,omitempty"`
// Name is deprecated, see Model
Name string `json:"name"`
@@ -238,6 +242,7 @@ type GenerateResponse struct {
type ModelDetails struct {
ParentModel string `json:"parent_model"`
Digest string `json:"digest"`
Format string `json:"format"`
Family string `json:"family"`
Families []string `json:"families"`
@@ -413,14 +418,19 @@ func (d *Duration) UnmarshalJSON(b []byte) (err error) {
case float64:
if t < 0 {
t = math.MaxFloat64
d.Duration = time.Duration(t)
} else {
d.Duration = time.Duration(t * float64(time.Second))
}
d.Duration = time.Duration(t)
case string:
d.Duration, err = time.ParseDuration(t)
if err != nil {
return err
}
if d.Duration < 0 {
mf := math.MaxFloat64
d.Duration = time.Duration(mf)
}
}
return nil

View File

@@ -357,6 +357,42 @@ func CopyHandler(cmd *cobra.Command, args []string) error {
}
func PullHandler(cmd *cobra.Command, args []string) error {
upgradeAll, err := cmd.Flags().GetBool("upgrade-all")
if err != nil {
return err
}
if !upgradeAll {
if len(args) == 0 {
return fmt.Errorf("no model specified to pull")
}
return pull(cmd, args[0], "")
}
client, err := api.ClientFromEnvironment()
if err != nil {
return err
}
models, err := client.List(cmd.Context())
if err != nil {
return err
}
for _, m := range (*models).Models {
err = pull(cmd, m.Name, "sha256:"+m.Digest)
if err != nil {
if strings.Contains(err.Error(), "file does not exist") {
fmt.Printf("model '%s' is no longer available\n", m.Name)
continue
}
return err
}
}
return nil
}
func pull(cmd *cobra.Command, name string, currentDigest string) error {
insecure, err := cmd.Flags().GetBool("insecure")
if err != nil {
return err
@@ -368,7 +404,7 @@ func PullHandler(cmd *cobra.Command, args []string) error {
}
p := progress.NewProgress(os.Stderr)
defer p.Stop()
defer p.StopWithoutClear()
bars := make(map[string]*progress.Bar)
@@ -402,7 +438,7 @@ func PullHandler(cmd *cobra.Command, args []string) error {
return nil
}
request := api.PullRequest{Name: args[0], Insecure: insecure}
request := api.PullRequest{Name: name, Insecure: insecure, CurrentDigest: currentDigest}
if err := client.Pull(cmd.Context(), &request, fn); err != nil {
return err
}
@@ -884,12 +920,13 @@ func NewCLI() *cobra.Command {
pullCmd := &cobra.Command{
Use: "pull MODEL",
Short: "Pull a model from a registry",
Args: cobra.ExactArgs(1),
Args: cobra.RangeArgs(0, 1),
PreRunE: checkServerHeartbeat,
RunE: PullHandler,
}
pullCmd.Flags().Bool("insecure", false, "Use an insecure registry")
pullCmd.Flags().Bool("upgrade-all", false, "Upgrade all models if they're out of date")
pushCmd := &cobra.Command{
Use: "push MODEL",

View File

@@ -19,6 +19,7 @@ A model file is the blueprint to create and share models with Ollama.
- [SYSTEM](#system)
- [ADAPTER](#adapter)
- [LICENSE](#license)
- [MESSAGE](#message)
- [Notes](#notes)
## Format
@@ -38,6 +39,7 @@ INSTRUCTION arguments
| [`SYSTEM`](#system) | Specifies the system message that will be set in the template. |
| [`ADAPTER`](#adapter) | Defines the (Q)LoRA adapters to apply to the model. |
| [`LICENSE`](#license) | Specifies the legal license. |
| [`MESSAGE`](#message) | Specify message history. |
## Examples
@@ -205,6 +207,19 @@ LICENSE """
"""
```
### MESSAGE
The `MESSAGE` instruction allows you to specify a message history for the model to use when responding:
```modelfile
MESSAGE user Is Toronto in Canada?
MESSAGE assistant yes
MESSAGE user Is Sacramento in Canada?
MESSAGE assistant no
MESSAGE user Is Ontario in Canada?
MESSAGE assistant yes
```
## Notes
- the **`Modelfile` is not case sensitive**. In the examples, uppercase instructions are used to make it easier to distinguish it from arguments.

View File

@@ -16,6 +16,7 @@ import (
"os"
"path/filepath"
"runtime"
"strconv"
"strings"
"sync"
"unsafe"
@@ -121,9 +122,15 @@ func GetGPUInfo() GpuInfo {
initGPUHandles()
}
// All our GPU builds have AVX enabled, so fallback to CPU if we don't detect at least AVX
cpuVariant := GetCPUVariant()
if cpuVariant == "" {
slog.Warn("CPU does not have AVX or AVX2, disabling GPU support.")
}
var memInfo C.mem_info_t
resp := GpuInfo{}
if gpuHandles.cuda != nil {
if gpuHandles.cuda != nil && cpuVariant != "" {
C.cuda_check_vram(*gpuHandles.cuda, &memInfo)
if memInfo.err != nil {
slog.Info(fmt.Sprintf("error looking up CUDA GPU memory: %s", C.GoString(memInfo.err)))
@@ -142,12 +149,33 @@ func GetGPUInfo() GpuInfo {
slog.Info(fmt.Sprintf("CUDA GPU is too old. Falling back to CPU mode. Compute Capability detected: %d.%d", cc.major, cc.minor))
}
}
} else if gpuHandles.rocm != nil {
} else if gpuHandles.rocm != nil && cpuVariant != "" {
C.rocm_check_vram(*gpuHandles.rocm, &memInfo)
if memInfo.err != nil {
slog.Info(fmt.Sprintf("error looking up ROCm GPU memory: %s", C.GoString(memInfo.err)))
C.free(unsafe.Pointer(memInfo.err))
} else if memInfo.igpu_index >= 0 && memInfo.count == 1 {
// Only one GPU detected and it appears to be an integrated GPU - skip it
slog.Info("ROCm unsupported integrated GPU detected")
} else {
if memInfo.igpu_index >= 0 {
// We have multiple GPUs reported, and one of them is an integrated GPU
// so we have to set the env var to bypass it
// If the user has specified their own ROCR_VISIBLE_DEVICES, don't clobber it
val := os.Getenv("ROCR_VISIBLE_DEVICES")
if val == "" {
devices := []string{}
for i := 0; i < int(memInfo.count); i++ {
if i == int(memInfo.igpu_index) {
continue
}
devices = append(devices, strconv.Itoa(i))
}
val = strings.Join(devices, ",")
os.Setenv("ROCR_VISIBLE_DEVICES", val)
}
slog.Info(fmt.Sprintf("ROCm integrated GPU detected - ROCR_VISIBLE_DEVICES=%s", val))
}
resp.Library = "rocm"
var version C.rocm_version_resp_t
C.rocm_get_version(*gpuHandles.rocm, &version)
@@ -163,7 +191,7 @@ func GetGPUInfo() GpuInfo {
if resp.Library == "" {
C.cpu_check_ram(&memInfo)
resp.Library = "cpu"
resp.Variant = GetCPUVariant()
resp.Variant = cpuVariant
}
if memInfo.err != nil {
slog.Info(fmt.Sprintf("error looking up CPU memory: %s", C.GoString(memInfo.err)))
@@ -199,7 +227,9 @@ func CheckVRAM() (int64, error) {
if overhead < gpus*1024*1024*1024 {
overhead = gpus * 1024 * 1024 * 1024
}
return int64(gpuInfo.FreeMemory - overhead), nil
avail := int64(gpuInfo.FreeMemory - overhead)
slog.Debug(fmt.Sprintf("%s detected %d devices with %dM available memory", gpuInfo.Library, gpuInfo.DeviceCount, avail/1024/1024))
return avail, nil
}
return 0, fmt.Errorf("no GPU detected") // TODO - better handling of CPU based memory determiniation

View File

@@ -42,6 +42,7 @@ typedef struct mem_info {
uint64_t total;
uint64_t free;
unsigned int count;
int igpu_index; // If >= 0, we detected an integrated GPU to ignore
char *err; // If non-nill, caller responsible for freeing
} mem_info_t;

View File

@@ -70,6 +70,7 @@ void cuda_init(char *cuda_lib_path, cuda_init_resp_t *resp) {
resp->ch.handle = NULL;
snprintf(buf, buflen, "nvml vram init failure: %d", ret);
resp->err = strdup(buf);
return;
}
// Report driver version if we're in verbose mode, ignore errors

View File

@@ -77,6 +77,7 @@ void rocm_init(char *rocm_lib_path, rocm_init_resp_t *resp) {
void rocm_check_vram(rocm_handle_t h, mem_info_t *resp) {
resp->err = NULL;
resp->igpu_index = -1;
uint64_t totalMem = 0;
uint64_t usedMem = 0;
rsmi_status_t ret;
@@ -162,8 +163,14 @@ void rocm_check_vram(rocm_handle_t h, mem_info_t *resp) {
}
LOG(h.verbose, "[%d] ROCm totalMem %ld\n", i, totalMem);
LOG(h.verbose, "[%d] ROCm usedMem %ld\n", i, usedMem);
resp->total += totalMem;
resp->free += totalMem - usedMem;
if (totalMem < 1024 * 1024 * 1024) {
// Do not add up integrated GPU memory capacity, it's a bogus 512M, and actually uses system memory
LOG(h.verbose, "[%d] ROCm integrated GPU\n", i);
resp->igpu_index = i;
} else {
resp->total += totalMem;
resp->free += totalMem - usedMem;
}
}
}

View File

@@ -52,6 +52,10 @@ func (p *Progress) Stop() bool {
return stopped
}
func (p *Progress) StopWithoutClear() bool {
return p.stop()
}
func (p *Progress) StopAndClear() bool {
fmt.Fprint(p.w, "\033[?25l")
defer fmt.Fprint(p.w, "\033[?25h")

View File

@@ -13,3 +13,13 @@ docker build \
-f Dockerfile \
-t ollama/ollama:$VERSION \
.
docker build \
--load \
--platform=linux/amd64 \
--build-arg=VERSION \
--build-arg=GOFLAGS \
--target runtime-rocm \
-f Dockerfile \
-t ollama/ollama:$VERSION-rocm \
.

View File

@@ -25,6 +25,11 @@ import (
"github.com/jmorganca/ollama/format"
)
const maxRetries = 6
var errMaxRetriesExceeded = errors.New("max retries exceeded")
var errPartStalled = errors.New("part stalled")
var blobDownloadManager sync.Map
type blobDownload struct {
@@ -44,10 +49,11 @@ type blobDownload struct {
}
type blobDownloadPart struct {
N int
Offset int64
Size int64
Completed int64
N int
Offset int64
Size int64
Completed int64
lastUpdated time.Time
*blobDownload `json:"-"`
}
@@ -72,6 +78,13 @@ func (p *blobDownloadPart) StopsAt() int64 {
return p.Offset + p.Size
}
func (p *blobDownloadPart) Write(b []byte) (n int, err error) {
n = len(b)
p.blobDownload.Completed.Add(int64(n))
p.lastUpdated = time.Now()
return n, nil
}
func (b *blobDownload) Prepare(ctx context.Context, requestURL *url.URL, opts *RegistryOptions) error {
partFilePaths, err := filepath.Glob(b.Name + "-partial-*")
if err != nil {
@@ -157,6 +170,9 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *Regis
case errors.Is(err, context.Canceled), errors.Is(err, syscall.ENOSPC):
// return immediately if the context is canceled or the device is out of space
return err
case errors.Is(err, errPartStalled):
try--
continue
case err != nil:
sleep := time.Second * time.Duration(math.Pow(2, float64(try)))
slog.Info(fmt.Sprintf("%s part %d attempt %d failed: %v, retrying in %s", b.Digest[7:19], part.N, try, err, sleep))
@@ -195,28 +211,54 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *Regis
}
func (b *blobDownload) downloadChunk(ctx context.Context, requestURL *url.URL, w io.Writer, part *blobDownloadPart, opts *RegistryOptions) error {
headers := make(http.Header)
headers.Set("Range", fmt.Sprintf("bytes=%d-%d", part.StartsAt(), part.StopsAt()-1))
resp, err := makeRequestWithRetry(ctx, http.MethodGet, requestURL, headers, nil, opts)
if err != nil {
return err
}
defer resp.Body.Close()
g, ctx := errgroup.WithContext(ctx)
g.Go(func() error {
headers := make(http.Header)
headers.Set("Range", fmt.Sprintf("bytes=%d-%d", part.StartsAt(), part.StopsAt()-1))
resp, err := makeRequestWithRetry(ctx, http.MethodGet, requestURL, headers, nil, opts)
if err != nil {
return err
}
defer resp.Body.Close()
n, err := io.Copy(w, io.TeeReader(resp.Body, b))
if err != nil && !errors.Is(err, context.Canceled) && !errors.Is(err, io.ErrUnexpectedEOF) {
// rollback progress
b.Completed.Add(-n)
return err
}
n, err := io.Copy(w, io.TeeReader(resp.Body, part))
if err != nil && !errors.Is(err, context.Canceled) && !errors.Is(err, io.ErrUnexpectedEOF) {
// rollback progress
b.Completed.Add(-n)
return err
}
part.Completed += n
if err := b.writePart(part.Name(), part); err != nil {
return err
}
part.Completed += n
if err := b.writePart(part.Name(), part); err != nil {
return err
}
// return nil or context.Canceled or UnexpectedEOF (resumable)
return err
// return nil or context.Canceled or UnexpectedEOF (resumable)
return err
})
g.Go(func() error {
ticker := time.NewTicker(time.Second)
for {
select {
case <-ticker.C:
if part.Completed >= part.Size {
return nil
}
if !part.lastUpdated.IsZero() && time.Since(part.lastUpdated) > 5*time.Second {
slog.Info(fmt.Sprintf("%s part %d stalled; retrying", b.Digest[7:19], part.N))
// reset last updated
part.lastUpdated = time.Time{}
return errPartStalled
}
case <-ctx.Done():
return ctx.Err()
}
}
})
return g.Wait()
}
func (b *blobDownload) newPart(offset, size int64) error {
@@ -255,12 +297,6 @@ func (b *blobDownload) writePart(partName string, part *blobDownloadPart) error
return json.NewEncoder(partFile).Encode(part)
}
func (b *blobDownload) Write(p []byte) (n int, err error) {
n = len(p)
b.Completed.Add(int64(n))
return n, nil
}
func (b *blobDownload) acquire() {
b.references.Add(1)
}
@@ -279,20 +315,19 @@ func (b *blobDownload) Wait(ctx context.Context, fn func(api.ProgressResponse))
for {
select {
case <-ticker.C:
fn(api.ProgressResponse{
Status: fmt.Sprintf("pulling %s", b.Digest[7:19]),
Digest: b.Digest,
Total: b.Total,
Completed: b.Completed.Load(),
})
if b.done || b.err != nil {
return b.err
}
case <-ctx.Done():
return ctx.Err()
}
fn(api.ProgressResponse{
Status: fmt.Sprintf("pulling %s", b.Digest[7:19]),
Digest: b.Digest,
Total: b.Total,
Completed: b.Completed.Load(),
})
if b.done || b.err != nil {
return b.err
}
}
}
@@ -303,10 +338,6 @@ type downloadOpts struct {
fn func(api.ProgressResponse)
}
const maxRetries = 6
var errMaxRetriesExceeded = errors.New("max retries exceeded")
// downloadBlob downloads a blob from the registry and stores it in the blobs directory
func downloadBlob(ctx context.Context, opts downloadOpts) error {
fp, err := GetBlobsPath(opts.digest)

View File

@@ -471,7 +471,7 @@ func CreateModel(ctx context.Context, name, modelFileDir string, commands []pars
switch {
case errors.Is(err, os.ErrNotExist):
fn(api.ProgressResponse{Status: "pulling model"})
if err := PullModel(ctx, c.Args, &RegistryOptions{}, fn); err != nil {
if err := PullModel(ctx, c.Args, "", &RegistryOptions{}, fn); err != nil {
return err
}
@@ -1041,7 +1041,7 @@ func PushModel(ctx context.Context, name string, regOpts *RegistryOptions, fn fu
return nil
}
func PullModel(ctx context.Context, name string, regOpts *RegistryOptions, fn func(api.ProgressResponse)) error {
func PullModel(ctx context.Context, name, currentDigest string, regOpts *RegistryOptions, fn func(api.ProgressResponse)) error {
mp := ParseModelPath(name)
var manifest *ManifestV2
@@ -1069,13 +1069,23 @@ func PullModel(ctx context.Context, name string, regOpts *RegistryOptions, fn fu
return fmt.Errorf("insecure protocol http")
}
fn(api.ProgressResponse{Status: "pulling manifest"})
if currentDigest == "" {
fn(api.ProgressResponse{Status: "pulling manifest"})
}
manifest, err = pullModelManifest(ctx, mp, regOpts)
manifest, err = pullModelManifest(ctx, mp, currentDigest, regOpts)
if err != nil {
return fmt.Errorf("pull model manifest: %s", err)
}
if currentDigest != "" {
if manifest == nil {
// we already have the model
return nil
}
fn(api.ProgressResponse{Status: "upgrading " + mp.GetShortTagname()})
}
var layers []*Layer
layers = append(layers, manifest.Layers...)
layers = append(layers, manifest.Config)
@@ -1147,17 +1157,27 @@ func PullModel(ctx context.Context, name string, regOpts *RegistryOptions, fn fu
return nil
}
func pullModelManifest(ctx context.Context, mp ModelPath, regOpts *RegistryOptions) (*ManifestV2, error) {
func pullModelManifest(ctx context.Context, mp ModelPath, currentDigest string, regOpts *RegistryOptions) (*ManifestV2, error) {
requestURL := mp.BaseURL().JoinPath("v2", mp.GetNamespaceRepository(), "manifests", mp.Tag)
headers := make(http.Header)
headers.Set("Accept", "application/vnd.docker.distribution.manifest.v2+json")
if currentDigest != "" {
headers.Set("If-None-Match", currentDigest)
}
resp, err := makeRequestWithRetry(ctx, http.MethodGet, requestURL, headers, nil, regOpts)
if err != nil {
return nil, err
}
defer resp.Body.Close()
// todo we can potentially read the manifest locally and return it here
if resp.StatusCode == http.StatusNotModified {
return nil, nil
}
var m *ManifestV2
if err := json.NewDecoder(resp.Body).Decode(&m); err != nil {
return nil, err

View File

@@ -186,7 +186,13 @@ func GenerateHandler(c *gin.Context) {
return
}
sessionDuration := defaultSessionDuration
var sessionDuration time.Duration
if req.KeepAlive == nil {
sessionDuration = defaultSessionDuration
} else {
sessionDuration = req.KeepAlive.Duration
}
if err := load(c, model, opts, sessionDuration); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
@@ -378,7 +384,14 @@ func EmbeddingHandler(c *gin.Context) {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
sessionDuration := defaultSessionDuration
var sessionDuration time.Duration
if req.KeepAlive == nil {
sessionDuration = defaultSessionDuration
} else {
sessionDuration = req.KeepAlive.Duration
}
if err := load(c, model, opts, sessionDuration); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
@@ -438,7 +451,7 @@ func PullModelHandler(c *gin.Context) {
ctx, cancel := context.WithCancel(c.Request.Context())
defer cancel()
if err := PullModel(ctx, model, regOpts, fn); err != nil {
if err := PullModel(ctx, model, req.CurrentDigest, regOpts, fn); err != nil {
ch <- gin.H{"error": err.Error()}
}
}()
@@ -660,6 +673,7 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
modelDetails := api.ModelDetails{
ParentModel: model.ParentModel,
Digest: "sha256:" + model.Digest,
Format: model.Config.ModelFormat,
Family: model.Config.ModelFamily,
Families: model.Config.ModelFamilies,
@@ -1074,7 +1088,14 @@ func ChatHandler(c *gin.Context) {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
sessionDuration := defaultSessionDuration
var sessionDuration time.Duration
if req.KeepAlive == nil {
sessionDuration = defaultSessionDuration
} else {
sessionDuration = req.KeepAlive.Duration
}
if err := load(c, model, opts, sessionDuration); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return