feat: Auto-size num_ctx to VRAM budget (Issue #12353)

ljluestc 2026-01-03 23:34:38 -08:00
parent 18fdcc94e5
commit 11e9dcca85
4 changed files with 162 additions and 2 deletions

View File

@ -457,7 +457,9 @@ type Runner struct {
    NumGPU    int   `json:"num_gpu,omitempty"`
    MainGPU   int   `json:"main_gpu,omitempty"`
    UseMMap   *bool `json:"use_mmap,omitempty"`
    NumThread int   `json:"num_thread,omitempty"`
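    // FitVRAM asks the server to auto-size num_ctx to the available VRAM
    // budget; MaxVRAM optionally caps that budget, in bytes (0 means no cap).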
    FitVRAM bool   `json:"fit_vram,omitempty"`
    MaxVRAM uint64 `json:"max_vram,omitempty"`
}
// EmbedRequest is the request passed to [Client.Embed].

View File

@ -416,6 +416,26 @@ func RunHandler(cmd *cobra.Command, args []string) error {
        opts.KeepAlive = &api.Duration{Duration: d}
    }

    fitVRAM, err := cmd.Flags().GetBool("fit-vram")
    if err != nil {
        return err
    }
    if fitVRAM {
        opts.Options["fit_vram"] = true
    }

    maxVRAMStr, err := cmd.Flags().GetString("max-vram")
    if err != nil {
        return err
    }
    if maxVRAMStr != "" {
        maxVRAM, err := parseBytes(maxVRAMStr)
        if err != nil {
            return fmt.Errorf("invalid max-vram value: %w", err)
        }
        opts.Options["max_vram"] = maxVRAM
    }

    prompts := args[1:]
    // prepend stdin to the prompt if provided
    if !term.IsTerminal(int(os.Stdin.Fd())) {
@ -1754,6 +1774,8 @@ func NewCLI() *cobra.Command {
    runCmd.Flags().Bool("hidethinking", false, "Hide thinking output (if provided)")
    runCmd.Flags().Bool("truncate", false, "For embedding models: truncate inputs exceeding context length (default: true). Set --truncate=false to error instead")
    runCmd.Flags().Int("dimensions", 0, "Truncate output embeddings to specified dimension (embedding models only)")
    runCmd.Flags().Bool("fit-vram", false, "Fit num_ctx to the VRAM budget (recalculated on model switch)")
    runCmd.Flags().String("max-vram", "", "Max VRAM budget, e.g. 6GB or 6GiB")
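    // Example invocation (hypothetical model name; flags as registered above):
    //   ollama run llama3 --fit-vram --max-vram 6GB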
    stopCmd := &cobra.Command{
        Use: "stop MODEL",
@ -1979,3 +2001,48 @@ func renderToolCalls(toolCalls []api.ToolCall, plainText bool) string {
    }
    return out
}
func parseBytes(s string) (uint64, error) {
    if s == "" {
        return 0, nil
    }
    s = strings.ToUpper(strings.TrimSpace(s))
    var multiplier uint64 = 1
    // Binary (IEC) suffixes must be checked before the decimal ones:
    // otherwise "B" matches first and "6GiB" is left as the unparseable "6GI".
    switch {
    case strings.HasSuffix(s, "TIB"):
        multiplier = 1024 * 1024 * 1024 * 1024
        s = strings.TrimSuffix(s, "TIB")
    case strings.HasSuffix(s, "GIB"):
        multiplier = 1024 * 1024 * 1024
        s = strings.TrimSuffix(s, "GIB")
    case strings.HasSuffix(s, "MIB"):
        multiplier = 1024 * 1024
        s = strings.TrimSuffix(s, "MIB")
    case strings.HasSuffix(s, "KIB"):
        multiplier = 1024
        s = strings.TrimSuffix(s, "KIB")
    case strings.HasSuffix(s, "TB"):
        multiplier = 1000 * 1000 * 1000 * 1000
        s = strings.TrimSuffix(s, "TB")
    case strings.HasSuffix(s, "GB"):
        multiplier = 1000 * 1000 * 1000
        s = strings.TrimSuffix(s, "GB")
    case strings.HasSuffix(s, "MB"):
        multiplier = 1000 * 1000
        s = strings.TrimSuffix(s, "MB")
    case strings.HasSuffix(s, "KB"):
        multiplier = 1000
        s = strings.TrimSuffix(s, "KB")
    case strings.HasSuffix(s, "B"):
        s = strings.TrimSuffix(s, "B")
    }
    val, err := strconv.ParseFloat(strings.TrimSpace(s), 64)
    if err != nil {
        return 0, err
    }
    if val < 0 {
        return 0, fmt.Errorf("negative size: %s", s)
    }
    return uint64(val * float64(multiplier)), nil
}
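Not part of the diff: a minimal sanity sketch of the suffix semantics, assuming the parseBytes above is in scope (decimal suffixes are powers of 1000, binary suffixes powers of 1024, bare numbers are bytes):

func sanityCheckParseBytes() {
    for _, tc := range []struct {
        in   string
        want uint64
    }{
        {"6GB", 6_000_000_000},
        {"6GiB", 6 * 1024 * 1024 * 1024}, // 6442450944
        {"512MB", 512_000_000},
        {"1.5TB", 1_500_000_000_000},
        {"1024", 1024}, // no suffix: plain bytes
    } {
        got, err := parseBytes(tc.in)
        fmt.Println(tc.in, got, got == tc.want, err)
    }
}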

View File

@ -17,7 +17,7 @@ func startApp(ctx context.Context, client *api.Client) error {
    }
    link, err := os.Readlink(exe)
    if err != nil {
        return errors.New("could not connect to ollama server, run 'ollama serve' to start it")
    }
    r := regexp.MustCompile(`^.*/Ollama\s?\d*.app`)
    m := r.FindStringSubmatch(link)

View File

@ -257,6 +257,63 @@ func NewLlamaServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, modelPath st
        }
    }

    if opts.FitVRAM {
        var availableVRAM uint64
        for _, gpu := range gpus {
            availableVRAM += gpu.FreeMemory
        }
        // Reserve 15% headroom
        availableVRAM = uint64(float64(availableVRAM) * 0.85)
        if opts.MaxVRAM > 0 && opts.MaxVRAM < availableVRAM {
            availableVRAM = opts.MaxVRAM
        }

        trainCtx := f.KV().ContextLength()
        low := 2048
        high := int(trainCtx)
        if high <= 0 {
            high = 32768
        }
        if high < low {
            low = high
        }

        best := low
        est := estimateMemoryUsage(f, low, opts.NumBatch, numParallel, loadRequest.KvCacheType, loadRequest.FlashAttention)
        if est > availableVRAM {
            slog.Warn("minimal context does not fit in VRAM", "num_ctx", low, "required", format.HumanBytes(int64(est)), "available", format.HumanBytes(int64(availableVRAM)))
        } else {
            // Binary search for the largest 256-aligned num_ctx that fits the budget.
            for low <= high {
                mid := (low + high) / 2
                // Align to 256
                mid = (mid / 256) * 256
                if mid < best {
                    mid = best
                }
                est := estimateMemoryUsage(f, mid, opts.NumBatch, numParallel, loadRequest.KvCacheType, loadRequest.FlashAttention)
                if est <= availableVRAM {
                    best = mid
                    low = mid + 256
                } else {
                    high = mid - 256
                }
            }
        }

        slog.Info("auto-sized num_ctx", "original", opts.NumCtx, "new", best, "available_vram", format.HumanBytes(int64(availableVRAM)))
        opts.NumCtx = best
        loadRequest.KvSize = opts.NumCtx * numParallel
        if opts.NumBatch > opts.NumCtx {
            opts.NumBatch = opts.NumCtx
            loadRequest.BatchSize = opts.NumBatch
        }
    }
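    // Worked example of the budget math above (illustrative numbers only):
    //   free VRAM across GPUs: 8 GiB = 8589934592 bytes
    //   after 15% headroom:    ~7301444403 bytes
    //   with --max-vram 6GB:   budget = 6000000000 bytes (the lower value wins)
    // The search then settles on the largest 256-aligned num_ctx in
    // [2048, trainCtx] whose estimate fits that budget.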
    gpuLibs := ml.LibraryPaths(gpus)
    status := NewStatusWriter(os.Stderr)
    cmd, port, err := StartRunner(
@ -1897,3 +1954,37 @@ func (s *ollamaServer) GetDeviceInfos(ctx context.Context) []ml.DeviceInfo {
    }
    return devices
}
func estimateMemoryUsage(f *ggml.GGML, numCtx int, batchSize int, numParallel int, kvCacheType string, fa ml.FlashAttentionType) uint64 {
    // 1. Weights: sum all block layers
    var weights uint64
    layers := f.Tensors().GroupLayers()
    for i := uint64(0); i < f.KV().BlockCount(); i++ {
        if blk, ok := layers[fmt.Sprintf("blk.%d", i)]; ok {
            weights += blk.Size()
        }
    }
    // Add the output norm, plus the output head (falling back to the token
    // embeddings when the head is tied to them)
    if layer, ok := layers["output_norm"]; ok {
        weights += layer.Size()
    }
    if layer, ok := layers["output"]; ok {
        weights += layer.Size()
    } else if layer, ok := layers["token_embd"]; ok {
        weights += layer.Size()
    }
    // 2. KV cache and compute-graph scratch at the requested context size
    kv, _, graphFull := f.GraphSize(uint64(numCtx), uint64(batchSize), numParallel, kvCacheType, fa)
    var kvTotal uint64
    for _, k := range kv {
        kvTotal += k
    }
    // Total estimate: weights + KV + graph scratch
    return weights + kvTotal + graphFull
}
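Not part of the diff: a short sketch of how the estimator could be probed across candidate context sizes, assuming the function above plus a parsed model in f and a flash-attention setting fa from the caller; the batch size, parallelism, and "f16" cache type are placeholder values.

// Sketch: log the fit curve for a few candidate contexts against a budget.
func probeFitCurve(f *ggml.GGML, fa ml.FlashAttentionType, budget uint64) {
    for _, ctx := range []int{2048, 4096, 8192, 16384, 32768} {
        est := estimateMemoryUsage(f, ctx, 512, 1, "f16", fa)
        slog.Info("fit probe",
            "num_ctx", ctx,
            "estimate", format.HumanBytes(int64(est)),
            "fits", est <= budget)
    }
}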