kvcache: Skip computing causal mask for worst case graph reservation

Computing an attention mask for a large context and max batch is
expensive - over 100ms. Models like Gemma3 that have multiple types
of caches and custom attention masks need to do this 4 times, so this
adds approximately 500ms to startup time when using a 128k context.

When we are reserving the worst case graph, we don't need the mask,
only its shape, so we can skip this.
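
Conceptually, the pattern is small. The sketch below is a minimal, self-contained
illustration of the idea and does not use ollama's ml package; the Tensor type, the
buildMask signature, and the simplified causal rule are assumptions made for the
example, not the actual implementation.

package main

import (
	"fmt"
	"math"
)

// Tensor is a stand-in for an ml.Tensor: it carries a shape and, optionally,
// materialized data. Illustrative only.
type Tensor struct {
	Rows, Cols int
	Data       []float32 // nil when only the shape matters
}

// buildMask mirrors the pattern in the patch: when reserve is true we only
// need a tensor with the correct shape so the worst case graph can be sized,
// and the O(batchSize*length) fill is skipped entirely.
func buildMask(batchSize, length int, positions []int32, reserve bool) Tensor {
	if reserve {
		// Shape only; no per-element work.
		return Tensor{Rows: batchSize, Cols: length}
	}

	negInf := float32(math.Inf(-1))
	data := make([]float32, batchSize*length)
	for i := 0; i < batchSize; i++ {
		for j := 0; j < length; j++ {
			// Simplified causal rule: a token may not attend to cells
			// past its own position.
			if int32(j) > positions[i] {
				data[i*length+j] = negInf
			}
		}
	}
	return Tensor{Rows: batchSize, Cols: length, Data: data}
}

func main() {
	// Reservation pass: the shape is enough to build the graph.
	reserved := buildMask(4, 131072, nil, true)
	fmt.Println("reserve:", reserved.Rows, "x", reserved.Cols, "filled elements:", len(reserved.Data))

	// Normal forward pass: the mask values are actually computed.
	full := buildMask(2, 4, []int32{1, 3}, false)
	fmt.Println("forward:", full.Data)
}
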
Jesse Gross 2025-05-27 13:33:57 -07:00 committed by Ryan Schumacher
parent 8d989025e2
commit 9c5c197393
1 changed file with 12 additions and 1 deletion


@@ -30,6 +30,11 @@ type Causal struct {
 	// ** current forward pass **
 
+	// curReserve indicates that this forward pass is only for
+	// memory reservation and we should not update our metadata
+	// based on it.
+	curReserve bool
+
 	// the active layer for Get and Put
 	curLayer int
@@ -159,12 +164,13 @@ func (c *Causal) Close() {
 }
 
 func (c *Causal) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error {
+	c.curReserve = reserve
 	c.curBatchSize = len(batch.Positions)
 	c.curSequences = batch.Sequences
 	c.curPositions = batch.Positions
 	c.opts.Except = nil
 
-	if !reserve {
+	if !c.curReserve {
 		c.updateSlidingWindow()
 
 		var err error
@@ -304,6 +310,11 @@ func (c *Causal) buildMask(ctx ml.Context) ml.Tensor {
 	c.curCellRange.max = roundUp(c.curCellRange.max+1, c.config.CachePadding) - 1
 
 	length := c.curCellRange.max - c.curCellRange.min + 1
+
+	if c.curReserve {
+		return ctx.Input().Empty(c.config.MaskDType, length, batchSize)
+	}
+
 	mask := make([]float32, batchSize*length)
 
 	for i := range c.curBatchSize {
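
The reserve flag is recorded on the cache as curReserve in StartForward so that
buildMask, which runs later in the same forward pass, can see it without a signature
change. During a reservation pass the mask comes back from ctx.Input().Empty with the
expected dtype and shape, which is all the worst case graph needs, while the
per-element fill of the batchSize*length float32 slice and the metadata updates
guarded by !c.curReserve are skipped.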