Merge 11e9dcca85 into 6c3faafed2
This commit is contained in:
commit
b62dee7d78
|
|
@ -590,7 +590,9 @@ type Runner struct {
|
|||
NumGPU int `json:"num_gpu,omitempty"`
|
||||
MainGPU int `json:"main_gpu,omitempty"`
|
||||
UseMMap *bool `json:"use_mmap,omitempty"`
|
||||
NumThread int `json:"num_thread,omitempty"`
|
||||
NumThread int `json:"num_thread,omitempty"`
|
||||
FitVRAM bool `json:"fit_vram,omitempty"`
|
||||
MaxVRAM uint64 `json:"max_vram,omitempty"`
|
||||
}
|
||||
|
||||
// EmbedRequest is the request passed to [Client.Embed].
|
||||
|
|
|
|||
67
cmd/cmd.go
67
cmd/cmd.go
|
|
@ -416,6 +416,26 @@ func RunHandler(cmd *cobra.Command, args []string) error {
|
|||
opts.KeepAlive = &api.Duration{Duration: d}
|
||||
}
|
||||
|
||||
fitVRAM, err := cmd.Flags().GetBool("fit-vram")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if fitVRAM {
|
||||
opts.Options["fit_vram"] = true
|
||||
}
|
||||
|
||||
maxVRAMStr, err := cmd.Flags().GetString("max-vram")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if maxVRAMStr != "" {
|
||||
maxVRAM, err := parseBytes(maxVRAMStr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid max-vram value: %w", err)
|
||||
}
|
||||
opts.Options["max_vram"] = maxVRAM
|
||||
}
|
||||
|
||||
prompts := args[1:]
|
||||
// prepend stdin to the prompt if provided
|
||||
if !term.IsTerminal(int(os.Stdin.Fd())) {
|
||||
|
|
@ -1754,6 +1774,8 @@ func NewCLI() *cobra.Command {
|
|||
runCmd.Flags().Bool("hidethinking", false, "Hide thinking output (if provided)")
|
||||
runCmd.Flags().Bool("truncate", false, "For embedding models: truncate inputs exceeding context length (default: true). Set --truncate=false to error instead")
|
||||
runCmd.Flags().Int("dimensions", 0, "Truncate output embeddings to specified dimension (embedding models only)")
|
||||
runCmd.Flags().Bool("fit-vram", false, "Fit num_ctx to VRAM budget (recalc on model switch)")
|
||||
runCmd.Flags().String("max-vram", "", "Max VRAM budget (e.g. 6GB)")
|
||||
|
||||
stopCmd := &cobra.Command{
|
||||
Use: "stop MODEL",
|
||||
|
|
@ -1979,3 +2001,48 @@ func renderToolCalls(toolCalls []api.ToolCall, plainText bool) string {
|
|||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func parseBytes(s string) (uint64, error) {
|
||||
if s == "" {
|
||||
return 0, nil
|
||||
}
|
||||
s = strings.ToUpper(strings.TrimSpace(s))
|
||||
|
||||
var multiplier uint64 = 1
|
||||
switch {
|
||||
case strings.HasSuffix(s, "TB"):
|
||||
multiplier = 1000 * 1000 * 1000 * 1000
|
||||
s = strings.TrimSuffix(s, "TB")
|
||||
case strings.HasSuffix(s, "GB"):
|
||||
multiplier = 1000 * 1000 * 1000
|
||||
s = strings.TrimSuffix(s, "GB")
|
||||
case strings.HasSuffix(s, "MB"):
|
||||
multiplier = 1000 * 1000
|
||||
s = strings.TrimSuffix(s, "MB")
|
||||
case strings.HasSuffix(s, "KB"):
|
||||
multiplier = 1000
|
||||
s = strings.TrimSuffix(s, "KB")
|
||||
case strings.HasSuffix(s, "B"):
|
||||
multiplier = 1
|
||||
s = strings.TrimSuffix(s, "B")
|
||||
case strings.HasSuffix(s, "TIB"):
|
||||
multiplier = 1024 * 1024 * 1024 * 1024
|
||||
s = strings.TrimSuffix(s, "TIB")
|
||||
case strings.HasSuffix(s, "GIB"):
|
||||
multiplier = 1024 * 1024 * 1024
|
||||
s = strings.TrimSuffix(s, "GIB")
|
||||
case strings.HasSuffix(s, "MIB"):
|
||||
multiplier = 1024 * 1024
|
||||
s = strings.TrimSuffix(s, "MIB")
|
||||
case strings.HasSuffix(s, "KIB"):
|
||||
multiplier = 1024
|
||||
s = strings.TrimSuffix(s, "KIB")
|
||||
}
|
||||
|
||||
val, err := strconv.ParseFloat(s, 64)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return uint64(val * float64(multiplier)), nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ func startApp(ctx context.Context, client *api.Client) error {
|
|||
}
|
||||
link, err := os.Readlink(exe)
|
||||
if err != nil {
|
||||
return err
|
||||
return errors.New("could not connect to ollama server, run 'ollama serve' to start it")
|
||||
}
|
||||
r := regexp.MustCompile(`^.*/Ollama\s?\d*.app`)
|
||||
m := r.FindStringSubmatch(link)
|
||||
|
|
|
|||
|
|
@ -257,6 +257,63 @@ func NewLlamaServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, modelPath st
|
|||
}
|
||||
}
|
||||
|
||||
if opts.FitVRAM {
|
||||
var availableVRAM uint64
|
||||
for _, gpu := range gpus {
|
||||
availableVRAM += gpu.FreeMemory
|
||||
}
|
||||
|
||||
// Reserve 15% headroom
|
||||
availableVRAM = uint64(float64(availableVRAM) * 0.85)
|
||||
|
||||
if opts.MaxVRAM > 0 && opts.MaxVRAM < availableVRAM {
|
||||
availableVRAM = opts.MaxVRAM
|
||||
}
|
||||
|
||||
trainCtx := f.KV().ContextLength()
|
||||
low := 2048
|
||||
high := int(trainCtx)
|
||||
if high <= 0 {
|
||||
high = 32768
|
||||
}
|
||||
if high < low {
|
||||
low = high
|
||||
}
|
||||
|
||||
best := low
|
||||
est := estimateMemoryUsage(f, low, opts.NumBatch, numParallel, loadRequest.KvCacheType, loadRequest.FlashAttention)
|
||||
if est > availableVRAM {
|
||||
slog.Warn("minimal context does not fit in VRAM", "num_ctx", low, "required", format.HumanBytes(int64(est)), "available", format.HumanBytes(int64(availableVRAM)))
|
||||
best = low
|
||||
} else {
|
||||
for low <= high {
|
||||
mid := (low + high) / 2
|
||||
// Align to 256
|
||||
mid = (mid / 256) * 256
|
||||
if mid < best {
|
||||
mid = best
|
||||
}
|
||||
|
||||
est := estimateMemoryUsage(f, mid, opts.NumBatch, numParallel, loadRequest.KvCacheType, loadRequest.FlashAttention)
|
||||
if est <= availableVRAM {
|
||||
best = mid
|
||||
low = mid + 256
|
||||
} else {
|
||||
high = mid - 256
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
slog.Info("auto-sized num_ctx", "original", opts.NumCtx, "new", best, "available_vram", format.HumanBytes(int64(availableVRAM)))
|
||||
opts.NumCtx = best
|
||||
loadRequest.KvSize = opts.NumCtx * numParallel
|
||||
|
||||
if opts.NumBatch > opts.NumCtx {
|
||||
opts.NumBatch = opts.NumCtx
|
||||
loadRequest.BatchSize = opts.NumBatch
|
||||
}
|
||||
}
|
||||
|
||||
gpuLibs := ml.LibraryPaths(gpus)
|
||||
status := NewStatusWriter(os.Stderr)
|
||||
cmd, port, err := StartRunner(
|
||||
|
|
@ -1897,3 +1954,37 @@ func (s *ollamaServer) GetDeviceInfos(ctx context.Context) []ml.DeviceInfo {
|
|||
}
|
||||
return devices
|
||||
}
|
||||
|
||||
func estimateMemoryUsage(f *ggml.GGML, numCtx int, batchSize int, numParallel int, kvCacheType string, fa ml.FlashAttentionType) uint64 {
|
||||
// 1. Calculate weights size
|
||||
var weights uint64
|
||||
layers := f.Tensors().GroupLayers()
|
||||
|
||||
// Sum all block layers
|
||||
for i := uint64(0); i < f.KV().BlockCount(); i++ {
|
||||
if blk, ok := layers[fmt.Sprintf("blk.%d", i)]; ok {
|
||||
weights += blk.Size()
|
||||
}
|
||||
}
|
||||
|
||||
// Add output/token embeddings
|
||||
if layer, ok := layers["output_norm"]; ok {
|
||||
weights += layer.Size()
|
||||
}
|
||||
if layer, ok := layers["output"]; ok {
|
||||
weights += layer.Size()
|
||||
} else if layer, ok := layers["token_embd"]; ok {
|
||||
weights += layer.Size()
|
||||
}
|
||||
|
||||
// 2. Calculate Graph & KV size
|
||||
kv, _, graphFull := f.GraphSize(uint64(numCtx), uint64(batchSize), numParallel, kvCacheType, fa)
|
||||
|
||||
var kvTotal uint64
|
||||
for _, k := range kv {
|
||||
kvTotal += k
|
||||
}
|
||||
|
||||
// Total estimate: Weights + KV + Graph Scratch
|
||||
return weights + kvTotal + graphFull
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue