From 1200e427f729b0786781321f05594fe2aff26108 Mon Sep 17 00:00:00 2001 From: Jesse Gross Date: Mon, 8 Dec 2025 14:16:23 -0800 Subject: [PATCH] ollamarunner: Automatically enable flash attention If a user hasn't explicitly either enabled or disabled flash attention, automatically enable flash attention if the model supports it and it would not trigger a fallback to CPU. This supports text, vision and embedding models as well as automatic handling of KV cache quantization (which requires flash attention). If a model does not call the fast fused attention operation, this is detected and disables any operations that depend on it. --- fs/ggml/ggml.go | 32 +------ kvcache/causal_test.go | 2 +- llm/server.go | 79 ++++------------ ml/backend.go | 2 +- ml/backend/ggml/ggml.go | 147 +++++++++++++++++++++++++++++- runner/llamarunner/runner.go | 4 +- runner/ollamarunner/multimodal.go | 5 +- runner/ollamarunner/runner.go | 134 +++++++++++++++------------ 8 files changed, 249 insertions(+), 156 deletions(-) diff --git a/fs/ggml/ggml.go b/fs/ggml/ggml.go index 44a48511c..d3055ff53 100644 --- a/fs/ggml/ggml.go +++ b/fs/ggml/ggml.go @@ -813,43 +813,13 @@ func (f GGML) SupportsKVCacheType(cacheType string) bool { } // KVCacheTypeIsQuantized checks if the requested cache type is a quantized type -func (f GGML) KVCacheTypeIsQuantized(cacheType string) bool { +func KVCacheTypeIsQuantized(cacheType string) bool { if cacheType == "" || cacheType == "f16" || cacheType == "f32" || cacheType == "bf16" { return false } return true } -// SupportsFlashAttention checks if the model supports flash attention -func (f GGML) SupportsFlashAttention() bool { - _, isEmbedding := f.KV()[fmt.Sprintf("%s.pooling_type", f.KV().Architecture())] - if isEmbedding { - return false - } - - if arch := f.KV().Architecture(); slices.Contains([]string{"gemma2"}, arch) { - return false - } - - // Check head counts match and are non-zero - headCountK := f.KV().EmbeddingHeadCountK() - headCountV := f.KV().EmbeddingHeadCountV() - return headCountK != 0 && headCountV != 0 && headCountK == headCountV -} - -// FlashAttention checks if the model should enable flash attention -func (f GGML) FlashAttention() bool { - return slices.Contains([]string{ - "bert", - "gemma3", - "gptoss", "gpt-oss", - "mistral3", - "olmo3", - "qwen3", "qwen3moe", - "qwen3vl", "qwen3vlmoe", - }, f.KV().String("general.architecture")) -} - // kvCacheBytesPerElement returns the number of bytes per element for a given KV cache type func kvCacheBytesPerElement(cacheType string) float64 { switch cacheType { diff --git a/kvcache/causal_test.go b/kvcache/causal_test.go index aeda93bc6..d9a102d50 100644 --- a/kvcache/causal_test.go +++ b/kvcache/causal_test.go @@ -696,7 +696,7 @@ func (c *testContext) Forward(...ml.Tensor) ml.Context { return c } func (c *testContext) Compute(...ml.Tensor) {} -func (c *testContext) Reserve() {} +func (c *testContext) Reserve() error { return nil } func (c *testContext) MaxGraphNodes() int { return 10 diff --git a/llm/server.go b/llm/server.go index a89027b06..f591b7deb 100644 --- a/llm/server.go +++ b/llm/server.go @@ -188,73 +188,26 @@ func NewLlamaServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, modelPath st if len(projectors) > 0 && llamaModel != nil { loadRequest.ProjectorPath = projectors[0] } - // Determine if the user has forced FA on or off - faUserSet := false - if envconfig.FlashAttention(true) == envconfig.FlashAttention(false) { - faUserSet = true - } - fa := envconfig.FlashAttention(f.FlashAttention()) + // Determine if the 
user has forced FA on or off + if envconfig.FlashAttention(true) != envconfig.FlashAttention(false) { + loadRequest.FlashAttention = ml.FlashAttentionAuto + } else if envconfig.FlashAttention(false) { + loadRequest.FlashAttention = ml.FlashAttentionEnabled + } // This will disable flash attention unless all GPUs on the system support it, even if we end up selecting a subset - // that can handle it. - if fa && !ml.FlashAttentionSupported(gpus) { + // that can handle it. There are still holes in GGML's hardware detection for flash attention. + if loadRequest.FlashAttention != ml.FlashAttentionDisabled && !ml.FlashAttentionSupported(gpus) { slog.Warn("flash attention enabled but not supported by gpu") - fa = false - } - - if fa && !f.SupportsFlashAttention() { - slog.Warn("flash attention enabled but not supported by model") - fa = false + loadRequest.FlashAttention = ml.FlashAttentionDisabled } kvct := strings.ToLower(envconfig.KvCacheType()) - - if textProcessor == nil { - flashAttention := ml.FlashAttentionAuto - if faUserSet { - if fa { - flashAttention = ml.FlashAttentionEnabled - } else { - flashAttention = ml.FlashAttentionDisabled - } - } - - if kvct != "" { - if f.KVCacheTypeIsQuantized(kvct) { - if flashAttention != ml.FlashAttentionEnabled { - slog.Warn("OLLAMA_FLASH_ATTENTION must be enabled to use a quantized OLLAMA_KV_CACHE_TYPE", "type", kvct) - loadRequest.KvCacheType = "" - } else if f.SupportsKVCacheType(kvct) { - loadRequest.KvCacheType = kvct - } else { - slog.Warn("unsupported OLLAMA_KV_CACHE_TYPE", "type", kvct) - } - } else { - if f.SupportsKVCacheType(kvct) { - loadRequest.KvCacheType = kvct - } else { - slog.Warn("unsupported OLLAMA_KV_CACHE_TYPE", "type", kvct) - } - } - } - loadRequest.FlashAttention = flashAttention + if f.SupportsKVCacheType(kvct) { + loadRequest.KvCacheType = kvct } else { - // For Ollama engine, use our SupportsFlashAttention logic - if fa { - slog.Info("enabling flash attention") - loadRequest.FlashAttention = ml.FlashAttentionEnabled - - // Flash Attention also supports kv cache quantization - // Enable if the requested and kv cache type is supported by the model - if f.SupportsKVCacheType(kvct) { - loadRequest.KvCacheType = kvct - } else { - slog.Warn("kv cache type not supported by model", "type", kvct) - } - } else if kvct != "" && kvct != "f16" { - slog.Warn("quantized kv cache requested but flash attention disabled", "type", kvct) - } + slog.Warn("unsupported OLLAMA_KV_CACHE_TYPE", "type", kvct) } gpuLibs := ml.LibraryPaths(gpus) @@ -487,6 +440,7 @@ type LoadRequest struct { type LoadResponse struct { Success bool + Request LoadRequest // The original request with fields updated that the runner had to modify Memory ml.BackendMemory } @@ -511,6 +465,11 @@ func (s *llamaServer) Load(ctx context.Context, systemInfo ml.SystemInfo, system s.mem.GPUs[i].Cache = make([]uint64, s.totalLayers) } + if s.loadRequest.FlashAttention != ml.FlashAttentionEnabled && ggml.KVCacheTypeIsQuantized(s.loadRequest.KvCacheType) { + slog.Warn("OLLAMA_FLASH_ATTENTION must be enabled to use a quantized OLLAMA_KV_CACHE_TYPE", "type", s.loadRequest.KvCacheType) + s.loadRequest.KvCacheType = "" + } + // Check if embedding model and adjust batch size accordingly _, isEmbedding := s.ggml.KV()[fmt.Sprintf("%s.pooling_type", s.ggml.KV().Architecture())] if isEmbedding && s.loadRequest.BatchSize < s.options.NumCtx { @@ -769,6 +728,7 @@ nextOperation: resp.Memory.Log(slog.LevelDebug) slog.Debug("memory", "success", resp.Success, "required", resp.Memory) + s.loadRequest = 
resp.Request // Incorporate any adjustments from the runner to avoid needing to do them again pastAllocations[gpuLayers.Hash()] = struct{}{} s.mem = &resp.Memory @@ -822,6 +782,7 @@ nextOperation: resp.Memory.Log(slog.LevelDebug) slog.Debug("memory", "success", resp.Success, "required", resp.Memory) + s.loadRequest = resp.Request if resp.Success { verifyGPULayers, err := s.createLayout(systemInfo, gpus, &resp.Memory, requireFull, backoff) diff --git a/ml/backend.go b/ml/backend.go index 1e781fa7f..a232855bc 100644 --- a/ml/backend.go +++ b/ml/backend.go @@ -118,7 +118,7 @@ type Context interface { // graph, simply preallocates memory. Typically called with a // worst case graph to ensure all resources are available for // for future inference. - Reserve() + Reserve() error MaxGraphNodes() int Close() diff --git a/ml/backend/ggml/ggml.go b/ml/backend/ggml/ggml.go index 6a044260a..4cc27c25a 100644 --- a/ml/backend/ggml/ggml.go +++ b/ml/backend/ggml/ggml.go @@ -684,7 +684,7 @@ func (b *Backend) NewContextSize(n int) ml.Context { } func (b *Backend) CacheConfig() ml.CacheConfig { - if b.flashAttention == ml.FlashAttentionEnabled { + if b.flashAttention != ml.FlashAttentionDisabled { return ml.CacheConfig{CachePadding: 256, MaskDType: ml.DTypeF16, MaskBatchPadding: C.GGML_KQ_MASK_PAD} } else { return ml.CacheConfig{CachePadding: 256, PermutedV: true} @@ -842,11 +842,16 @@ func (c *Context) ComputeWithNotify(cb func(), tensors ...ml.Tensor) { } } -func (c *Context) Reserve() { +func (c *Context) Reserve() error { if c.batchSize > 0 { C.ggml_backend_sched_set_batch_size(c.b.sched, C.int(c.batchSize)) } + flashBackendAssignments, err := validateGraph(c.graph, c.b.flashAttention) + if err != nil { + return err + } + reserved := C.ggml_backend_sched_reserve(c.b.sched, c.graph) slog.Debug("compute graph", "nodes", C.ggml_graph_n_nodes(c.graph), "splits", C.ggml_backend_sched_get_n_splits(c.b.sched)) @@ -867,6 +872,142 @@ func (c *Context) Reserve() { if !reserved { panic(ml.ErrNoMem{BackendMemory: *c.b.requiredMemory}) } + + // If flash attention is in auto mode, ensure that the scheduler placed the flash attention on the same + // (or higher priority) backend as we originally loaded the weights. If it's a lower priority backend (i.e. CPU), + // that means that backend likely does not support flash attention for this graph. 
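+	// The expected placement for each flash attention node was recorded by validateGraph before the
+	// scheduler ran. Priorities are compared by index into the scheduler's buffer type list
+	// (schedBufts), where a lower index is a higher priority backend. For example, if a layer's
+	// weights sit on a GPU buffer type at index 0 but the scheduler placed that layer's
+	// GGML_OP_FLASH_ATTN_EXT node on the CPU buffer type at a higher index, Reserve returns an
+	// error and the runner can retry with flash attention disabled.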
+ if c.b.flashAttention == ml.FlashAttentionAuto { + flashIdx := 0 + for i := range C.ggml_graph_n_nodes(c.graph) { + node := C.ggml_graph_node(c.graph, i) + if node.op != C.GGML_OP_FLASH_ATTN_EXT { + continue + } + + if flashIdx >= len(flashBackendAssignments) { + slog.Debug("flash attention assignment missing", + "index", flashIdx, + "tensor", C.GoString(C.ggml_get_name(node))) + return errors.New("flash attention not supported by backend") + } + + assignedBT := flashBackendAssignments[flashIdx] + flashIdx++ + + if node.buffer == nil || assignedBT == nil { + continue + } + + bufferType := C.ggml_backend_buffer_get_type(node.buffer) + + actualPriority := bufferTypePriority(bufferType, c.b.schedBufts) + expectedPriority := bufferTypePriority(assignedBT, c.b.schedBufts) + + // A lower numbered priority is better here + if actualPriority > expectedPriority { + slog.Debug("flash attention not supported by backend", + "tensor", C.GoString(C.ggml_get_name(node)), + "assigned_buffer_type", C.GoString(C.ggml_backend_buft_name(bufferType)), + "assigned_priority", actualPriority, + "expected_buffer_type", C.GoString(C.ggml_backend_buft_name(assignedBT)), + "expected_priority", expectedPriority) + return errors.New("flash attention not supported by backend") + } + } + } + + return nil +} + +func bufferTypePriority(buft C.ggml_backend_buffer_type_t, schedBufts []C.ggml_backend_buffer_type_t) int { + for i, b := range schedBufts { + if b == buft { + return i + } + } + + return len(schedBufts) +} + +// Check that there are no illegal operations and build a mapping of flash attention operation locations +// from before the scheduler runs to compare to the result afterwards. +func validateGraph(graph *C.struct_ggml_cgraph, flashAttention ml.FlashAttentionType) ([]C.ggml_backend_buffer_type_t, error) { + var assignments []C.ggml_backend_buffer_type_t + + for i := range C.ggml_graph_n_nodes(graph) { + node := C.ggml_graph_node(graph, i) + + switch node.op { + // Only flash attention supports quantized KV cache, so if we have a matmul that uses a quantized input (other than weights), + // it means that the model is using its own implementation of attention. + case C.GGML_OP_MUL_MAT: + for srcIndex := range int(C.GGML_MAX_SRC) { + src := node.src[srcIndex] + if src == nil { + continue + } + + var quantized *C.struct_ggml_tensor + for current := src; current != nil; current = current.view_src { + if C.ggml_is_quantized(current._type) { + quantized = current + break + } + } + + // If matmul has a quantized input, it is only supported if it is weights (due to uniform stride) + if quantized != nil && + !(quantized.buffer != nil && C.ggml_backend_buffer_get_usage(quantized.buffer) == C.GGML_BACKEND_BUFFER_USAGE_WEIGHTS) { + slog.Debug("unsupported quantized matmul input", + "tensor", C.GoString(C.ggml_get_name(node)), + "src", C.GoString(C.ggml_get_name(src)), + "type", C.GoString(C.ggml_type_name(quantized._type))) + return nil, errors.New("unsupported quantized matmul input") + } + } + + // Build a mapping of flash attention operations to their most direct weight input. We do this before the scheduler runs + // because the graph is fully connected. After scheduling, the graph is hard to trace because it is broken up into splits. + // We index by flash attention number (more or less equivalent to layer) since that is persistent across scheduling. 
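+		// In auto mode this records, for each GGML_OP_FLASH_ATTN_EXT node in graph order, the buffer
+		// type of the nearest weight tensor feeding it (found by the breadth-first search below), so a
+		// model with 32 attention layers typically yields 32 entries. Reserve later compares these
+		// against where the scheduler actually placed each flash attention node.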
+ case C.GGML_OP_FLASH_ATTN_EXT: + if flashAttention == ml.FlashAttentionAuto { + // Breadth-first search for the first ancestor that has a buffer with weights + queue := []*C.struct_ggml_tensor{node} + visited := make(map[*C.struct_ggml_tensor]struct{}) + + var ancestor *C.struct_ggml_tensor + for len(queue) > 0 { + current := queue[0] + queue = queue[1:] + + if _, ok := visited[current]; ok { + continue + } + visited[current] = struct{}{} + + // Only use weights as reference points - we don't want to use inputs like the cache mask, which are always on the CPU + if current.buffer != nil && C.ggml_backend_buffer_get_usage(current.buffer) == C.GGML_BACKEND_BUFFER_USAGE_WEIGHTS { + ancestor = current + break + } + + for srcIndex := range int(C.GGML_MAX_SRC) { + if src := current.src[srcIndex]; src != nil { + queue = append(queue, src) + } + } + } + + var bufferType C.ggml_backend_buffer_type_t + if ancestor != nil { + bufferType = C.ggml_backend_buffer_get_type(ancestor.buffer) + } + assignments = append(assignments, bufferType) + } + } + } + + return assignments, nil } func (c *Context) MaxGraphNodes() int { @@ -1679,7 +1820,7 @@ func (t *Tensor) ScaledDotProductAttention(ctx ml.Context, key, value, mask, sin query := t.Permute(ctx, 0, 2, 1, 3) key = key.Permute(ctx, 0, 2, 1, 3) - if t.b.flashAttention == ml.FlashAttentionEnabled { + if t.b.flashAttention != ml.FlashAttentionDisabled { value = value.Permute(ctx, 0, 2, 1, 3) kqv := C.ggml_flash_attn_ext(ctx.(*Context).ctx, query.(*Tensor).t, key.(*Tensor).t, value.(*Tensor).t, kqMask, C.float(scale), 0, 0) diff --git a/runner/llamarunner/runner.go b/runner/llamarunner/runner.go index de9d718b3..e20e9a113 100644 --- a/runner/llamarunner/runner.go +++ b/runner/llamarunner/runner.go @@ -935,13 +935,13 @@ func (s *Server) load(w http.ResponseWriter, r *http.Request) { case llm.LoadOperationClose: // No-op for us - if err := json.NewEncoder(w).Encode(&llm.LoadResponse{}); err != nil { + if err := json.NewEncoder(w).Encode(&llm.LoadResponse{Request: req}); err != nil { http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError) } return } - resp := llm.LoadResponse{Success: true} + resp := llm.LoadResponse{Success: true, Request: req} if err := json.NewEncoder(w).Encode(&resp); err != nil { http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError) return diff --git a/runner/ollamarunner/multimodal.go b/runner/ollamarunner/multimodal.go index 6af89021c..98c7759b6 100644 --- a/runner/ollamarunner/multimodal.go +++ b/runner/ollamarunner/multimodal.go @@ -98,7 +98,10 @@ func (m multimodalStore) getTensor(backend ml.Backend, ctx ml.Context, in ml.Ten } } } else { - computeCtx.Reserve() + err := computeCtx.Reserve() + if err != nil { + return nil, err + } } } diff --git a/runner/ollamarunner/runner.go b/runner/ollamarunner/runner.go index a756cba23..0cbba5e5c 100644 --- a/runner/ollamarunner/runner.go +++ b/runner/ollamarunner/runner.go @@ -1160,22 +1160,12 @@ func (s *Server) reserveWorstCaseGraph(prompt bool) error { } ctx.SetBatchSize(batchSize) - ctx.Forward(t).Reserve() - - return nil + return ctx.Forward(t).Reserve() } // allocModel pre-allocates the maximum needed memory for a model // based on the given parameters -func (s *Server) allocModel( - mpath string, - params ml.BackendParams, - loraPath []string, - parallel int, - kvCacheType string, - kvSize int, - multiUserCache bool, -) (panicErr error) { +func (s *Server) allocModel(mpath string, req *llm.LoadRequest) 
(panicErr error) { // Convert memory allocation panics to errors defer func() { if r := recover(); r != nil { @@ -1192,43 +1182,73 @@ func (s *Server) allocModel( } }() - var err error - s.model, err = model.New(mpath, params) - if err != nil { - return err - } - - // TODO(jessegross): LoRA loading - if len(loraPath) > 0 { - return errors.New("loras are not yet implemented") - } - - if s.model.Config().Cache == nil { - if parallel > 1 { - parallel = 1 - slog.Warn("model does not support caching, disabling parallel processing") +reload: + for range 2 { + params := ml.BackendParams{ + AllocMemory: req.Operation != llm.LoadOperationFit, + NumThreads: req.NumThreads, + GPULayers: req.GPULayers, + FlashAttention: req.FlashAttention, } - if s.batchSize < kvSize { - s.batchSize = kvSize - slog.Warn("model does not support caching, setting batch size to context length", "batch_size", kvSize) + + var err error + s.model, err = model.New(mpath, params) + if err != nil { + return err } - } - s.cache, err = NewInputCache(s.model, kvCacheType, int32(kvSize), parallel, s.batchSize, multiUserCache) - if err != nil { - return err - } + // TODO(jessegross): LoRA loading + if len(req.LoraPath) > 0 { + return errors.New("loras are not yet implemented") + } - s.parallel = parallel - s.seqs = make([]*Sequence, s.parallel) - s.seqsSem = semaphore.NewWeighted(int64(s.parallel)) + if params.FlashAttention == ml.FlashAttentionDisabled && ggml.KVCacheTypeIsQuantized(req.KvCacheType) { + slog.Warn("quantized kv cache requested but flash attention disabled", "type", req.KvCacheType) + req.KvCacheType = "" + } + + if s.model.Config().Cache == nil { + if req.Parallel > 1 { + req.Parallel = 1 + slog.Warn("model does not support caching, disabling parallel processing") + } + if req.BatchSize < req.KvSize { + req.BatchSize = req.KvSize + slog.Warn("model does not support caching, setting batch size to context length", "batch_size", req.KvSize) + } + } + + s.cache, err = NewInputCache(s.model, req.KvCacheType, int32(req.KvSize), req.Parallel, req.BatchSize, req.MultiUserCache) + if err != nil { + return err + } + + s.batchSize = req.BatchSize + s.parallel = req.Parallel + s.seqs = make([]*Sequence, s.parallel) + s.seqsSem = semaphore.NewWeighted(int64(s.parallel)) + + for _, prompt := range []bool{true, false} { + if err := s.reserveWorstCaseGraph(prompt); err != nil { + if req.FlashAttention != ml.FlashAttentionDisabled { + slog.Warn("flash attention enabled but not supported by model") + req.FlashAttention = ml.FlashAttentionDisabled + s.closeModel() + continue reload + } + + return err + } + } + + if req.FlashAttention == ml.FlashAttentionAuto { + req.FlashAttention = ml.FlashAttentionEnabled + } - err = s.reserveWorstCaseGraph(true) - if err != nil { return nil } - return s.reserveWorstCaseGraph(false) + return errors.New("unable to allocate model") } // closeModel frees all memory associated with a model @@ -1243,7 +1263,15 @@ func (s *Server) closeModel() { // loadModel loads the weights for a model. 
The memory must already // have been allocated with allocModel -func (s *Server) loadModel() { +func (s *Server) loadModel(req llm.LoadRequest) { + if req.FlashAttention != ml.FlashAttentionDisabled { + slog.Info("enabling flash attention") + } + + if ggml.KVCacheTypeIsQuantized(req.KvCacheType) { + slog.Info("enabling kv cache quantization", "type", req.KvCacheType) + } + err := s.model.Backend().Load(context.TODO(), func(progress float32) { s.progress = progress @@ -1279,7 +1307,7 @@ func (s *Server) load(w http.ResponseWriter, r *http.Request) { if req.Operation == llm.LoadOperationClose { s.closeModel() - if err := json.NewEncoder(w).Encode(&llm.LoadResponse{}); err != nil { + if err := json.NewEncoder(w).Encode(&llm.LoadResponse{Request: req}); err != nil { http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError) } return @@ -1288,27 +1316,16 @@ func (s *Server) load(w http.ResponseWriter, r *http.Request) { s.lastLoad.Operation = req.Operation loadModel := s.model == nil || !reflect.DeepEqual(req, s.lastLoad) - s.lastLoad = req - if loadModel { s.closeModel() - params := ml.BackendParams{ - AllocMemory: req.Operation != llm.LoadOperationFit, - NumThreads: req.NumThreads, - GPULayers: req.GPULayers, - FlashAttention: req.FlashAttention, - } - - s.batchSize = req.BatchSize - - err := s.allocModel(s.modelPath, params, req.LoraPath, req.Parallel, req.KvCacheType, req.KvSize, req.MultiUserCache) + err := s.allocModel(s.modelPath, &req) if err != nil { s.closeModel() var noMem ml.ErrNoMem if errors.As(err, &noMem) { - resp := llm.LoadResponse{Success: false, Memory: noMem.BackendMemory} + resp := llm.LoadResponse{Success: false, Request: req, Memory: noMem.BackendMemory} if err := json.NewEncoder(w).Encode(&resp); err != nil { http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError) } @@ -1321,6 +1338,7 @@ func (s *Server) load(w http.ResponseWriter, r *http.Request) { } } + s.lastLoad = req mem := s.model.Backend().BackendMemory() switch req.Operation { @@ -1332,10 +1350,10 @@ func (s *Server) load(w http.ResponseWriter, r *http.Request) { case llm.LoadOperationCommit: s.status = llm.ServerStatusLoadingModel - go s.loadModel() + go s.loadModel(req) } - resp := llm.LoadResponse{Success: true, Memory: mem} + resp := llm.LoadResponse{Success: true, Request: req, Memory: mem} if err := json.NewEncoder(w).Encode(&resp); err != nil { http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError) return