From 11e9dcca857d7562d467375535714155e2202910 Mon Sep 17 00:00:00 2001
From: ljluestc
Date: Sat, 3 Jan 2026 23:34:38 -0800
Subject: [PATCH] feat: Auto-size num_ctx to VRAM budget (Issue #12353)

---
 api/types.go        |  4 +-
 cmd/cmd.go          | 67 +++++++++++++++++++++++++++++++++
 cmd/start_darwin.go |  2 +-
 llm/server.go       | 91 +++++++++++++++++++++++++++++++++++++++++++++
 4 files changed, 162 insertions(+), 2 deletions(-)

diff --git a/api/types.go b/api/types.go
index 63b898975..a5dbc4567 100644
--- a/api/types.go
+++ b/api/types.go
@@ -457,7 +457,9 @@ type Runner struct {
 	NumGPU    int   `json:"num_gpu,omitempty"`
 	MainGPU   int   `json:"main_gpu,omitempty"`
 	UseMMap   *bool `json:"use_mmap,omitempty"`
-	NumThread int   `json:"num_thread,omitempty"`
+	NumThread int    `json:"num_thread,omitempty"`
+	FitVRAM   bool   `json:"fit_vram,omitempty"`
+	MaxVRAM   uint64 `json:"max_vram,omitempty"`
 }
 
 // EmbedRequest is the request passed to [Client.Embed].
diff --git a/cmd/cmd.go b/cmd/cmd.go
index 35074ad2b..03cbb0b59 100644
--- a/cmd/cmd.go
+++ b/cmd/cmd.go
@@ -416,6 +416,26 @@ func RunHandler(cmd *cobra.Command, args []string) error {
 		opts.KeepAlive = &api.Duration{Duration: d}
 	}
 
+	fitVRAM, err := cmd.Flags().GetBool("fit-vram")
+	if err != nil {
+		return err
+	}
+	if fitVRAM {
+		opts.Options["fit_vram"] = true
+	}
+
+	maxVRAMStr, err := cmd.Flags().GetString("max-vram")
+	if err != nil {
+		return err
+	}
+	if maxVRAMStr != "" {
+		maxVRAM, err := parseBytes(maxVRAMStr)
+		if err != nil {
+			return fmt.Errorf("invalid max-vram value: %w", err)
+		}
+		opts.Options["max_vram"] = maxVRAM
+	}
+
 	prompts := args[1:]
 	// prepend stdin to the prompt if provided
 	if !term.IsTerminal(int(os.Stdin.Fd())) {
@@ -1754,6 +1774,8 @@ func NewCLI() *cobra.Command {
 	runCmd.Flags().Bool("hidethinking", false, "Hide thinking output (if provided)")
 	runCmd.Flags().Bool("truncate", false, "For embedding models: truncate inputs exceeding context length (default: true). Set --truncate=false to error instead")
 	runCmd.Flags().Int("dimensions", 0, "Truncate output embeddings to specified dimension (embedding models only)")
+	runCmd.Flags().Bool("fit-vram", false, "Fit num_ctx to the available VRAM budget (recalculated on model switch)")
+	runCmd.Flags().String("max-vram", "", "Maximum VRAM budget for --fit-vram (e.g. 6GB)")
 
 	stopCmd := &cobra.Command{
 		Use:   "stop MODEL",
@@ -1979,3 +2001,48 @@ func renderToolCalls(toolCalls []api.ToolCall, plainText bool) string {
 	}
 	return out
 }
+
+func parseBytes(s string) (uint64, error) {
+	if s == "" {
+		return 0, nil
+	}
+	s = strings.ToUpper(strings.TrimSpace(s))
+
+	var multiplier uint64 = 1
+	switch {
+	case strings.HasSuffix(s, "TIB"):
+		multiplier = 1024 * 1024 * 1024 * 1024
+		s = strings.TrimSuffix(s, "TIB")
+	case strings.HasSuffix(s, "GIB"):
+		multiplier = 1024 * 1024 * 1024
+		s = strings.TrimSuffix(s, "GIB")
+	case strings.HasSuffix(s, "MIB"):
+		multiplier = 1024 * 1024
+		s = strings.TrimSuffix(s, "MIB")
+	case strings.HasSuffix(s, "KIB"):
+		multiplier = 1024
+		s = strings.TrimSuffix(s, "KIB")
+	case strings.HasSuffix(s, "TB"):
+		multiplier = 1000 * 1000 * 1000 * 1000
+		s = strings.TrimSuffix(s, "TB")
+	case strings.HasSuffix(s, "GB"):
+		multiplier = 1000 * 1000 * 1000
+		s = strings.TrimSuffix(s, "GB")
+	case strings.HasSuffix(s, "MB"):
+		multiplier = 1000 * 1000
+		s = strings.TrimSuffix(s, "MB")
+	case strings.HasSuffix(s, "KB"):
+		multiplier = 1000
+		s = strings.TrimSuffix(s, "KB")
+	case strings.HasSuffix(s, "B"):
+		multiplier = 1
+		s = strings.TrimSuffix(s, "B")
+	}
+
+	val, err := strconv.ParseFloat(s, 64)
+	if err != nil {
+		return 0, err
+	}
+
+	return uint64(val * float64(multiplier)), nil
+}
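Note: a quick sanity sketch of the new flags, not part of the patch. The bare
"B" case must be checked last in parseBytes (as above), otherwise it matches
the tail of KiB/MiB/GiB/TiB first and the binary suffixes become unreachable.
The values below follow directly from the multipliers in the code; the model
name is only a placeholder.

  // parseBytes("6GB")    -> 6000000000 (6 * 1000^3)
  // parseBytes("6GiB")   -> 6442450944 (6 * 1024^3; input is upper-cased first)
  // parseBytes("512MiB") -> 536870912  (512 * 1024^2)
  // parseBytes("1.5GB")  -> 1500000000 (fractional values parse via ParseFloat)
  //
  // $ ollama run llama3.2 --fit-vram --max-vram 6GB
  //   sets opts.Options["fit_vram"] = true and opts.Options["max_vram"] = 6000000000,
  //   which the server decodes into the new Runner.FitVRAM / Runner.MaxVRAM fields.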
diff --git a/cmd/start_darwin.go b/cmd/start_darwin.go
index 05a4551e1..4c39da148 100644
--- a/cmd/start_darwin.go
+++ b/cmd/start_darwin.go
@@ -17,7 +17,7 @@ func startApp(ctx context.Context, client *api.Client) error {
 	}
 	link, err := os.Readlink(exe)
 	if err != nil {
-		return err
+		return errors.New("could not connect to ollama server, run 'ollama serve' to start it")
 	}
 	r := regexp.MustCompile(`^.*/Ollama\s?\d*.app`)
 	m := r.FindStringSubmatch(link)
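Note: the llm/server.go hunk below picks num_ctx by binary search over
256-aligned context lengths. A self-contained sketch of the same search, with
estimateMemoryUsage stubbed out and toy numbers (nothing here comes from real
model files):

package main

import "fmt"

// fitCtx returns the largest 256-aligned context length in [low, high] whose
// estimated footprint fits the budget, mirroring the search in the patch.
func fitCtx(budget uint64, low, high int, estimate func(int) uint64) int {
	best := low
	if estimate(low) > budget {
		return low // even the minimum does not fit; the patch warns and keeps it
	}
	for low <= high {
		mid := (low + high) / 2 / 256 * 256 // align down to a multiple of 256
		if mid < best {
			mid = best
		}
		if estimate(mid) <= budget {
			best = mid
			low = mid + 256
		} else {
			high = mid - 256
		}
	}
	return best
}

func main() {
	// Toy model: 4 GiB of weights plus 1 MiB of KV cache per 256 tokens.
	estimate := func(numCtx int) uint64 {
		return 4<<30 + uint64(numCtx/256)<<20
	}
	// With a 6 GiB budget everything up to the 131072 cap fits, so this
	// prints 131072; shrink the budget and the result steps down by 256s.
	fmt.Println(fitCtx(6<<30, 2048, 131072, estimate))
}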
diff --git a/llm/server.go b/llm/server.go
index c83bd5a40..64f086971 100644
--- a/llm/server.go
+++ b/llm/server.go
@@ -257,6 +257,63 @@ func NewLlamaServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, modelPath st
 		}
 	}
 
+	if opts.FitVRAM {
+		var availableVRAM uint64
+		for _, gpu := range gpus {
+			availableVRAM += gpu.FreeMemory
+		}
+
+		// Reserve 15% of free VRAM as headroom
+		availableVRAM = uint64(float64(availableVRAM) * 0.85)
+
+		if opts.MaxVRAM > 0 && opts.MaxVRAM < availableVRAM {
+			availableVRAM = opts.MaxVRAM
+		}
+
+		trainCtx := f.KV().ContextLength()
+		low := 2048
+		high := int(trainCtx)
+		if high <= 0 {
+			high = 32768
+		}
+		if high < low {
+			low = high
+		}
+
+		best := low
+		est := estimateMemoryUsage(f, low, opts.NumBatch, numParallel, loadRequest.KvCacheType, loadRequest.FlashAttention)
+		if est > availableVRAM {
+			slog.Warn("minimal context does not fit in VRAM", "num_ctx", low, "required", format.HumanBytes(int64(est)), "available", format.HumanBytes(int64(availableVRAM)))
+			best = low
+		} else {
+			for low <= high {
+				mid := (low + high) / 2
+				// Align down to a multiple of 256
+				mid = (mid / 256) * 256
+				if mid < best {
+					mid = best
+				}
+
+				est := estimateMemoryUsage(f, mid, opts.NumBatch, numParallel, loadRequest.KvCacheType, loadRequest.FlashAttention)
+				if est <= availableVRAM {
+					best = mid
+					low = mid + 256
+				} else {
+					high = mid - 256
+				}
+			}
+		}
+
+		slog.Info("auto-sized num_ctx", "original", opts.NumCtx, "new", best, "available_vram", format.HumanBytes(int64(availableVRAM)))
+		opts.NumCtx = best
+		loadRequest.KvSize = opts.NumCtx * numParallel
+
+		if opts.NumBatch > opts.NumCtx {
+			opts.NumBatch = opts.NumCtx
+			loadRequest.BatchSize = opts.NumBatch
+		}
+	}
+
 	gpuLibs := ml.LibraryPaths(gpus)
 	status := NewStatusWriter(os.Stderr)
 	cmd, port, err := StartRunner(
@@ -1897,3 +1954,37 @@ func (s *ollamaServer) GetDeviceInfos(ctx context.Context) []ml.DeviceInfo {
 	}
 	return devices
 }
+
+func estimateMemoryUsage(f *ggml.GGML, numCtx int, batchSize int, numParallel int, kvCacheType string, fa ml.FlashAttentionType) uint64 {
+	// 1. Weights size
+	var weights uint64
+	layers := f.Tensors().GroupLayers()
+
+	// Sum all block layers
+	for i := uint64(0); i < f.KV().BlockCount(); i++ {
+		if blk, ok := layers[fmt.Sprintf("blk.%d", i)]; ok {
+			weights += blk.Size()
+		}
+	}
+
+	// Add the output norm and output (or tied token embedding) tensors
+	if layer, ok := layers["output_norm"]; ok {
+		weights += layer.Size()
+	}
+	if layer, ok := layers["output"]; ok {
+		weights += layer.Size()
+	} else if layer, ok := layers["token_embd"]; ok {
+		weights += layer.Size()
+	}
+
+	// 2. KV cache and compute graph size at the requested context
+	kv, _, graphFull := f.GraphSize(uint64(numCtx), uint64(batchSize), numParallel, kvCacheType, fa)
+
+	var kvTotal uint64
+	for _, k := range kv {
+		kvTotal += k
+	}
+
+	// Total estimate: weights + KV cache + graph scratch
+	return weights + kvTotal + graphFull
+}
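Note: a back-of-envelope cross-check of the KV term in estimateMemoryUsage,
not part of the patch. For a plain transformer with an f16 cache, the KV cache
needs 2 (K and V) * blockCount * numCtx * kvHeads * headDim * 2 bytes per
token of context; f.GraphSize accounts for the model's actual layout and cache
type, so this is only a rough plausibility check. The shape constants below
are hypothetical (an 8B-class GQA model), not from any specific checkpoint.

package main

import "fmt"

// kvBytes estimates f16 KV-cache size for a standard attention layout.
func kvBytes(blockCount, numCtx, kvHeads, headDim int) uint64 {
	const bytesPerElem = 2 // f16
	return 2 * uint64(blockCount) * uint64(numCtx) * uint64(kvHeads) * uint64(headDim) * bytesPerElem
}

func main() {
	// 32 blocks, 8 KV heads, head_dim 128 at an 8k context: exactly 1 GiB,
	// which is why halving num_ctx is such an effective lever for fitting.
	fmt.Printf("%.2f GiB\n", float64(kvBytes(32, 8192, 8, 128))/(1<<30))
}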