diff --git a/ml/backend/ggml/ggml.go b/ml/backend/ggml/ggml.go index ebcc1d86f..75fce7a68 100644 --- a/ml/backend/ggml/ggml.go +++ b/ml/backend/ggml/ggml.go @@ -851,19 +851,19 @@ func (c *Context) Reserve() { slog.Debug("compute graph", "nodes", C.ggml_graph_n_nodes(c.graph), "splits", C.ggml_backend_sched_get_n_splits(c.b.sched)) - // Reserve may get called multiple times for different graphs - we just want the last run, which will contain the max allocations - for _, bt := range c.b.schedBufts { - c.b.btDeviceMemory[bt].Graph = 0 - } - + graphs := make(map[C.ggml_backend_buffer_type_t]uint64) for i := range c.b.schedBackends { bufferSize := C.ggml_backend_sched_get_attempted_buffer_size(c.b.sched, c.b.schedBackends[i]) - c.b.btDeviceMemory[c.b.schedBufts[i]].Graph += uint64(bufferSize) + graphs[c.b.schedBufts[i]] += uint64(bufferSize) logutil.Trace("compute graph", "backend", C.GoString(C.ggml_backend_name(c.b.schedBackends[i])), "buffer_type", C.GoString(C.ggml_backend_buft_name(c.b.schedBufts[i])), "size", format.HumanBytes2(uint64(bufferSize))) } + for bt, size := range graphs { + c.b.btDeviceMemory[bt].Graph = max(c.b.btDeviceMemory[bt].Graph, size) + } + if !reserved { panic(ml.ErrNoMem{BackendMemory: *c.b.requiredMemory}) }