ollama/server/preload.go

250 lines
5.1 KiB
Go

package server
import (
"context"
"fmt"
"log/slog"
"net/url"
"slices"
"strconv"
"strings"
"time"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/types/model"
)
// preloadModelSpec describes one model to preload, as produced by
// parsePreloadEntry from a single "name?key=value&..." entry.
type preloadModelSpec struct {
// Name is the model name: the trimmed text before the first "?" of the entry.
Name string
// Prompt comes from the "prompt" parameter. preloadModels sends it as the
// generate prompt or the embed input; embedding models fall back to "init"
// when it is empty.
Prompt string
// KeepAlive comes from the "keepalive"/"keep_alive" parameter and is nil when
// the parameter was not supplied. It is forwarded unchanged on the request.
KeepAlive *api.Duration
// Options holds every parameter that is not one of the reserved keys, with
// values converted by parseValue. It is nil when no such parameter was given.
Options map[string]any
// Think comes from the "think" parameter and is nil when not supplied.
// preloadModels only forwards it on generate requests.
Think *api.ThinkValue
}
// parsePreloadSpecs turns a comma-separated list of preload entries into
// specs. A blank input yields (nil, nil). Entries that are empty after
// trimming are skipped, and the first entry that fails to parse aborts the
// whole call with an error naming that entry.
func parsePreloadSpecs(raw string) ([]preloadModelSpec, error) {
	if strings.TrimSpace(raw) == "" {
		return nil, nil
	}

	// One more entry than there are separators.
	parsed := make([]preloadModelSpec, 0, strings.Count(raw, ",")+1)

	// Walk the list one comma at a time; `more` turns false once the final
	// entry (the one with no trailing comma) has been consumed.
	for rest, more := raw, true; more; {
		var entry string
		entry, rest, more = strings.Cut(rest, ",")

		entry = strings.TrimSpace(entry)
		if entry == "" {
			continue
		}

		spec, err := parsePreloadEntry(entry)
		if err != nil {
			return nil, fmt.Errorf("could not parse %q: %w", entry, err)
		}
		parsed = append(parsed, spec)
	}
	return parsed, nil
}
// parsePreloadEntry parses a single preload entry of the form
// "name?key=value&key=value" into a preloadModelSpec.
//
// The trimmed text before the first "?" is the model name and is required.
// The remainder, if any, is decoded with url.ParseQuery:
//
//   - "prompt", "keepalive"/"keep_alive" and "think" (matched
//     case-insensitively) fill the corresponding spec fields. If one of them
//     is repeated, the last value wins rather than leaking into Options.
//   - Every other key is stored in spec.Options under its original spelling.
//     A single value is converted with parseValue; repeated values become a
//     []any of converted values.
//
// An error is returned when the name is empty, the query string is malformed,
// or a keepalive/think value cannot be parsed. On error the returned spec is
// the zero value.
func parsePreloadEntry(raw string) (preloadModelSpec, error) {
	namePart, optionsPart, hasOptions := strings.Cut(raw, "?")
	name := strings.TrimSpace(namePart)
	if name == "" {
		return preloadModelSpec{}, fmt.Errorf("model name is required")
	}

	spec := preloadModelSpec{Name: name}
	if !hasOptions || strings.TrimSpace(optionsPart) == "" {
		return spec, nil
	}

	query, err := url.ParseQuery(optionsPart)
	if err != nil {
		return preloadModelSpec{}, fmt.Errorf("invalid parameters: %w", err)
	}

	// Map iteration order is random. Visit keys in sorted order so that alias
	// precedence (keep_alive vs keepalive) and the error reported for an
	// entry with several bad parameters are the same on every run.
	keys := make([]string, 0, len(query))
	for key := range query {
		keys = append(keys, key)
	}
	slices.Sort(keys)

	opts := make(map[string]any, len(keys))
	for _, key := range keys {
		vals := query[key]
		if len(vals) == 0 {
			continue
		}
		last := vals[len(vals)-1]

		switch strings.ToLower(key) {
		case "prompt":
			spec.Prompt = last
		case "keepalive", "keep_alive":
			d, err := parseDurationValue(last)
			if err != nil {
				return preloadModelSpec{}, fmt.Errorf("keepalive for %s: %w", name, err)
			}
			spec.KeepAlive = &api.Duration{Duration: d}
		case "think":
			tv, err := parseThinkValue(last)
			if err != nil {
				return preloadModelSpec{}, fmt.Errorf("think for %s: %w", name, err)
			}
			spec.Think = tv
		default:
			if len(vals) == 1 {
				opts[key] = parseValue(vals[0])
			} else {
				// A repeated model option is passed through as a list.
				list := make([]any, 0, len(vals))
				for _, v := range vals {
					list = append(list, parseValue(v))
				}
				opts[key] = list
			}
		}
	}

	// Leave Options nil when only reserved keys were supplied.
	if len(opts) > 0 {
		spec.Options = opts
	}
	return spec, nil
}
// parseDurationValue parses a keep-alive style duration.
//
// It accepts anything time.ParseDuration understands (for example "5m" or
// "1h30m"), and otherwise a bare base-10 integer interpreted as a number of
// seconds (for example "30" or "-1"). An error is returned when raw is neither,
// or when the number of seconds does not fit in a time.Duration.
func parseDurationValue(raw string) (time.Duration, error) {
	if d, err := time.ParseDuration(raw); err == nil {
		return d, nil
	}

	seconds, err := strconv.ParseInt(raw, 10, 64)
	if err != nil {
		return 0, fmt.Errorf("invalid duration %q", raw)
	}

	// time.Duration is an int64 nanosecond count; multiplying an oversized
	// second count would wrap around silently (possibly flipping the sign),
	// so reject values outside the representable range instead.
	const maxSeconds = (1<<63 - 1) / int64(time.Second)
	if seconds > maxSeconds || seconds < -maxSeconds {
		return 0, fmt.Errorf("invalid duration %q: out of range", raw)
	}
	return time.Duration(seconds) * time.Second, nil
}
// parseThinkValue converts the "think" parameter into an api.ThinkValue.
// Matching is case-insensitive: "true"/"false" produce a bool value and
// "high"/"medium"/"low" produce the lowercased string. Anything else is an
// error that quotes the original input.
func parseThinkValue(raw string) (*api.ThinkValue, error) {
	normalized := strings.ToLower(raw)

	if normalized == "high" || normalized == "medium" || normalized == "low" {
		return &api.ThinkValue{Value: normalized}, nil
	}

	if normalized == "true" || normalized == "false" {
		// ParseBool cannot fail for these two literals, so the error is ignored.
		enabled, _ := strconv.ParseBool(normalized)
		return &api.ThinkValue{Value: enabled}, nil
	}

	return nil, fmt.Errorf("invalid think value %q", raw)
}
// parseValue converts a raw parameter value to the most specific type it can
// be read as: int64, then float64, then bool, falling back to the unchanged
// string.
//
// Numbers are tried before booleans because strconv.ParseBool also accepts
// "1" and "0"; checking it first would turn numeric options such as
// temperature=0 or seed=1 into booleans. ParseBool's other short forms
// ("t", "T", "f", "F") are still treated as booleans.
func parseValue(raw string) any {
	if i, err := strconv.ParseInt(raw, 10, 64); err == nil {
		return i
	}
	if f, err := strconv.ParseFloat(raw, 64); err == nil {
		return f
	}
	if b, err := strconv.ParseBool(raw); err == nil {
		return b
	}
	return raw
}
// waitForServerReady blocks until client.Heartbeat succeeds or ctx is done.
// The first heartbeat is sent immediately; after each failure it waits for the
// next 500ms tick before trying again. It returns nil on the first successful
// heartbeat and ctx.Err() if the context ends first.
func waitForServerReady(ctx context.Context, client *api.Client) error {
	const pollInterval = 500 * time.Millisecond

	poll := time.NewTicker(pollInterval)
	defer poll.Stop()

	for client.Heartbeat(ctx) != nil {
		select {
		case <-poll.C:
			// Time for another attempt.
		case <-ctx.Done():
			return ctx.Err()
		}
	}
	return nil
}
// preloadModels loads every model described by envconfig.PreloadedModels().
//
// It returns an error only when the specs cannot be parsed, the API client
// cannot be created, the server never becomes ready, or ctx is done between
// models. Once the loop starts, a failure for an individual model is logged
// and the remaining models are still attempted.
//
// For each model it asks the server for the model's capabilities: embedding
// models are loaded with an embed request (using "init" as input when no
// prompt was configured), all others with a non-streaming generate request.
func preloadModels(ctx context.Context) error {
	specs, err := parsePreloadSpecs(envconfig.PreloadedModels())
	if err != nil || len(specs) == 0 {
		// err is nil here when there is simply nothing to preload.
		return err
	}

	client, err := api.ClientFromEnvironment()
	if err == nil {
		err = waitForServerReady(ctx, client)
	}
	if err != nil {
		return err
	}

	for _, spec := range specs {
		if cancelErr := ctx.Err(); cancelErr != nil {
			return cancelErr
		}

		name := spec.Name
		slog.Info("preloading model", "model", name)

		info, err := client.Show(ctx, &api.ShowRequest{Name: name})
		if err != nil {
			slog.Error("unable to describe model for preloading", "model", name, "error", err)
			continue
		}

		// Requests always carry a non-nil options map.
		options := spec.Options
		if options == nil {
			options = map[string]any{}
		}

		var (
			loadErr error
			failure string
		)
		if slices.Contains(info.Capabilities, model.CapabilityEmbedding) {
			failure = "preloading embedding model failed"

			input := spec.Prompt
			if input == "" {
				input = "init"
			}
			_, loadErr = client.Embed(ctx, &api.EmbedRequest{
				Model:     name,
				Input:     input,
				KeepAlive: spec.KeepAlive,
				Options:   options,
			})
		} else {
			failure = "preloading model failed"

			streaming := false
			discard := func(api.GenerateResponse) error { return nil }
			loadErr = client.Generate(ctx, &api.GenerateRequest{
				Model:     name,
				Prompt:    spec.Prompt,
				KeepAlive: spec.KeepAlive,
				Options:   options,
				Stream:    &streaming,
				Think:     spec.Think,
			}, discard)
		}

		if loadErr != nil {
			slog.Error(failure, "model", name, "error", loadErr)
			continue
		}
		slog.Info("model preloaded", "model", name)
	}
	return nil
}