diff --git a/cmd/cmd.go b/cmd/cmd.go
index 35074ad2b..00c3d82c0 100644
--- a/cmd/cmd.go
+++ b/cmd/cmd.go
@@ -1877,6 +1877,7 @@ func NewCLI() *cobra.Command {
 				envVars["OLLAMA_KEEP_ALIVE"],
 				envVars["OLLAMA_MAX_LOADED_MODELS"],
 				envVars["OLLAMA_MAX_QUEUE"],
+				envVars["OLLAMA_PRELOAD_MODELS"],
 				envVars["OLLAMA_MODELS"],
 				envVars["OLLAMA_NUM_PARALLEL"],
 				envVars["OLLAMA_NOPRUNE"],
diff --git a/docs/faq.mdx b/docs/faq.mdx
index 4237da41f..ae99fc555 100644
--- a/docs/faq.mdx
+++ b/docs/faq.mdx
@@ -266,6 +266,28 @@ To preload a model using the CLI, use the command:
 ollama run llama3.2 ""
 ```
 
+You can also have Ollama preload models automatically at startup with the `OLLAMA_PRELOAD_MODELS` environment variable. Provide a comma-separated list of models, each with optional query-style parameters, to tune the warm-up request:
+
+```shell
+OLLAMA_PRELOAD_MODELS="llama3.2,phi3?temperature=0.2&num_ctx=4096&keepalive=30m" ollama serve
+```
+
+- Each entry is processed in order. Parameters such as `temperature` or `num_ctx` are passed through as model options on the warm-up request.
+- Use `keepalive` (a duration such as `30m`, or a number of seconds) to control how long the preloaded model stays in memory; `keepalive=-1` keeps it loaded indefinitely.
+- An optional `prompt` sends text during warm-up; embedding models default to `"init"` so they receive the required input without you having to specify one.
+
+For embedding models, a simple configuration like the following warms the model with the implicit `"init"` prompt and leaves it loaded indefinitely:
+
+```shell
+OLLAMA_PRELOAD_MODELS="mxbai-embed-large?keepalive=-1" ollama serve
+```
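+
+Models that support thinking also accept a `think` parameter (`true`, `false`, or an effort level of `low`, `medium`, or `high`), for example:
+
+```shell
+OLLAMA_PRELOAD_MODELS="qwen3?think=false&keepalive=1h" ollama serve
+```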
+
 ## How do I keep a model loaded in memory or make it unload immediately?
 
 By default models are kept in memory for 5 minutes before being unloaded. This allows for quicker response times if you're making numerous requests to the LLM. If you want to immediately unload a model from memory, use the `ollama stop` command:
diff --git a/envconfig/config.go b/envconfig/config.go
index 238e5e6e1..ce817ac77 100644
--- a/envconfig/config.go
+++ b/envconfig/config.go
@@ -148,6 +148,11 @@ func Remotes() []string {
 	return r
 }
 
+// PreloadedModels returns the raw comma-separated list of models to preload on startup.
+func PreloadedModels() string {
+	return Var("OLLAMA_PRELOAD_MODELS")
+}
+
 func BoolWithDefault(k string) func(defaultValue bool) bool {
 	return func(defaultValue bool) bool {
 		if s := Var(k); s != "" {
@@ -283,6 +288,7 @@
 		"OLLAMA_LOAD_TIMEOUT":      {"OLLAMA_LOAD_TIMEOUT", LoadTimeout(), "How long to allow model loads to stall before giving up (default \"5m\")"},
 		"OLLAMA_MAX_LOADED_MODELS": {"OLLAMA_MAX_LOADED_MODELS", MaxRunners(), "Maximum number of loaded models per GPU"},
 		"OLLAMA_MAX_QUEUE":         {"OLLAMA_MAX_QUEUE", MaxQueue(), "Maximum number of queued requests"},
+		"OLLAMA_PRELOAD_MODELS":    {"OLLAMA_PRELOAD_MODELS", PreloadedModels(), "Comma-separated models to preload on startup"},
 		"OLLAMA_MODELS":            {"OLLAMA_MODELS", Models(), "The path to the models directory"},
 		"OLLAMA_NOHISTORY":         {"OLLAMA_NOHISTORY", NoHistory(), "Do not preserve readline history"},
 		"OLLAMA_NOPRUNE":           {"OLLAMA_NOPRUNE", NoPrune(), "Do not prune model blobs on startup"},
diff --git a/server/preload.go b/server/preload.go
new file mode 100644
index 000000000..2696dcf97
--- /dev/null
+++ b/server/preload.go
@@ -0,0 +1,259 @@
+package server
+
+import (
+	"context"
+	"fmt"
+	"log/slog"
+	"net/url"
+	"slices"
+	"strconv"
+	"strings"
+	"time"
+
+	"github.com/ollama/ollama/api"
+	"github.com/ollama/ollama/envconfig"
+	"github.com/ollama/ollama/types/model"
+)
+
+// preloadModelSpec describes one model to warm up at startup.
+type preloadModelSpec struct {
+	Name      string
+	Prompt    string
+	KeepAlive *api.Duration
+	Options   map[string]any
+	Think     *api.ThinkValue
+}
+
+// parsePreloadSpecs parses the comma-separated OLLAMA_PRELOAD_MODELS value.
+func parsePreloadSpecs(raw string) ([]preloadModelSpec, error) {
+	if strings.TrimSpace(raw) == "" {
+		return nil, nil
+	}
+
+	parts := strings.Split(raw, ",")
+	specs := make([]preloadModelSpec, 0, len(parts))
+	for _, part := range parts {
+		part = strings.TrimSpace(part)
+		if part == "" {
+			continue
+		}
+
+		spec, err := parsePreloadEntry(part)
+		if err != nil {
+			return nil, fmt.Errorf("could not parse %q: %w", part, err)
+		}
+
+		specs = append(specs, spec)
+	}
+
+	return specs, nil
+}
+
+// parsePreloadEntry parses a single "model?key=value&..." entry.
+func parsePreloadEntry(raw string) (preloadModelSpec, error) {
+	var spec preloadModelSpec
+
+	namePart, optionsPart, hasOptions := strings.Cut(raw, "?")
+	spec.Name = strings.TrimSpace(namePart)
+	if spec.Name == "" {
+		return spec, fmt.Errorf("model name is required")
+	}
+
+	if !hasOptions || strings.TrimSpace(optionsPart) == "" {
+		return spec, nil
+	}
+
+	values, err := url.ParseQuery(optionsPart)
+	if err != nil {
+		return spec, fmt.Errorf("invalid parameters: %w", err)
+	}
+
+	opts := map[string]any{}
+	for key, val := range values {
+		if len(val) == 0 {
+			continue
+		}
+
+		// a single value is either a reserved key or a scalar model option
+		if len(val) == 1 {
+			switch strings.ToLower(key) {
+			case "prompt":
+				spec.Prompt = val[0]
+				continue
+			case "keepalive", "keep_alive":
+				d, err := parseDurationValue(val[0])
+				if err != nil {
+					return spec, fmt.Errorf("keepalive for %s: %w", spec.Name, err)
+				}
+				spec.KeepAlive = &api.Duration{Duration: d}
+				continue
+			case "think":
+				tv, err := parseThinkValue(val[0])
+				if err != nil {
+					return spec, fmt.Errorf("think for %s: %w", spec.Name, err)
+				}
+				spec.Think = tv
+				continue
+			}
+
+			opts[key] = parseValue(val[0])
+			continue
+		}
+
+		// a repeated key becomes a list-valued option
+		list := make([]any, 0, len(val))
+		for _, vv := range val {
+			list = append(list, parseValue(vv))
+		}
+		opts[key] = list
+	}
+
+	if len(opts) > 0 {
+		spec.Options = opts
+	}
+
+	return spec, nil
+}
+
+// parseDurationValue accepts either a Go duration string or a bare number of seconds.
+func parseDurationValue(raw string) (time.Duration, error) {
+	if d, err := time.ParseDuration(raw); err == nil {
+		return d, nil
+	}
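+
+	// fall back to interpreting a bare integer as seconds
+	// (e.g. keepalive=300; negative values keep the model loaded indefinitely)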
+	seconds, err := strconv.ParseInt(raw, 10, 64)
+	if err != nil {
+		return 0, fmt.Errorf("invalid duration %q", raw)
+	}
+
+	return time.Duration(seconds) * time.Second, nil
+}
+
+// parseThinkValue parses think=true|false|low|medium|high.
+func parseThinkValue(raw string) (*api.ThinkValue, error) {
+	lowered := strings.ToLower(raw)
+	switch lowered {
+	case "true", "false":
+		b, _ := strconv.ParseBool(lowered)
+		return &api.ThinkValue{Value: b}, nil
+	case "high", "medium", "low":
+		return &api.ThinkValue{Value: lowered}, nil
+	default:
+		return nil, fmt.Errorf("invalid think value %q", raw)
+	}
+}
+
+// parseValue coerces an option string to a bool, int, or float where possible.
+func parseValue(raw string) any {
+	if b, err := strconv.ParseBool(raw); err == nil {
+		return b
+	}
+
+	if i, err := strconv.ParseInt(raw, 10, 64); err == nil {
+		return i
+	}
+
+	if f, err := strconv.ParseFloat(raw, 64); err == nil {
+		return f
+	}
+
+	return raw
+}
+
+// waitForServerReady polls the server until it answers a heartbeat or ctx is done.
+func waitForServerReady(ctx context.Context, client *api.Client) error {
+	ticker := time.NewTicker(500 * time.Millisecond)
+	defer ticker.Stop()
+
+	for {
+		if err := client.Heartbeat(ctx); err == nil {
+			return nil
+		}
+
+		select {
+		case <-ctx.Done():
+			return ctx.Err()
+		case <-ticker.C:
+		}
+	}
+}
+
+// preloadModels warms up every model listed in OLLAMA_PRELOAD_MODELS.
+func preloadModels(ctx context.Context) error {
+	specs, err := parsePreloadSpecs(envconfig.PreloadedModels())
+	if err != nil {
+		return err
+	}
+
+	if len(specs) == 0 {
+		return nil
+	}
+
+	client, err := api.ClientFromEnvironment()
+	if err != nil {
+		return err
+	}
+
+	if err := waitForServerReady(ctx, client); err != nil {
+		return err
+	}
+
+	for _, spec := range specs {
+		if ctx.Err() != nil {
+			return ctx.Err()
+		}
+
+		slog.Info("preloading model", "model", spec.Name)
+
+		info, err := client.Show(ctx, &api.ShowRequest{Name: spec.Name})
+		if err != nil {
+			slog.Error("unable to describe model for preloading", "model", spec.Name, "error", err)
+			continue
+		}
+
+		isEmbedding := slices.Contains(info.Capabilities, model.CapabilityEmbedding)
+		prompt := spec.Prompt
+		if prompt == "" && isEmbedding {
+			prompt = "init"
+		}
+
+		if spec.Options == nil {
+			spec.Options = map[string]any{}
+		}
+
+		if isEmbedding {
+			req := &api.EmbedRequest{
+				Model:     spec.Name,
+				Input:     prompt,
+				KeepAlive: spec.KeepAlive,
+				Options:   spec.Options,
+			}
+
+			if _, err := client.Embed(ctx, req); err != nil {
+				slog.Error("preloading embedding model failed", "model", spec.Name, "error", err)
+				continue
+			}
+		} else {
+			stream := false
+			req := &api.GenerateRequest{
+				Model:     spec.Name,
+				Prompt:    prompt,
+				KeepAlive: spec.KeepAlive,
+				Options:   spec.Options,
+				Stream:    &stream,
+				Think:     spec.Think,
+			}
+
+			if err := client.Generate(ctx, req, func(api.GenerateResponse) error { return nil }); err != nil {
+				slog.Error("preloading model failed", "model", spec.Name, "error", err)
+				continue
+			}
+		}
+
+		slog.Info("model preloaded", "model", spec.Name)
+	}
+
+	return nil
+}
diff --git a/server/routes.go b/server/routes.go
index 977a13ff2..39e1c1580 100644
--- a/server/routes.go
+++ b/server/routes.go
@@ -1658,6 +1658,15 @@
 		slog.Info("entering low vram mode", "total vram", format.HumanBytes2(totalVRAM), "threshold", format.HumanBytes2(lowVRAMThreshold))
 	}
 
+	// warm up any requested models in the background so startup isn't blocked
+	preloadCtx, preloadCancel := context.WithCancel(ctx)
+	defer preloadCancel()
+	go func() {
+		if err := preloadModels(preloadCtx); err != nil && !errors.Is(err, context.Canceled) {
+			slog.Error("failed to preload models", "error", err)
+		}
+	}()
+
 	err = srvr.Serve(ln)
 	// If server is closed from the signal handler, wait for the ctx to be done
 	// otherwise error out quickly