generic pad

This commit is contained in:
Michael Yang 2025-12-15 16:31:35 -08:00
parent baae175ebe
commit f49797fbdb
1 changed files with 2 additions and 2 deletions

View File

@ -884,7 +884,7 @@ func shapeToGGML(shape []int) *C.int64_t {
return &sh[0]
}
func pad(length, pad C.size_t) C.size_t {
func pad[T C.size_t | int](length, pad T) T {
return ((length + pad - 1) / pad) * pad
}
@ -1789,7 +1789,7 @@ func (t *Tensor) SDPA(ctx ml.Context, key, value ml.Tensor, fns ...func(*attenti
}
if opts.Mask != nil {
if padSize := int(pad(C.size_t(opts.Mask.Dim(1)), C.size_t(config.MaskBatchPadding))) - opts.Mask.Dim(1); padSize > 0 {
if padSize := pad(opts.Mask.Dim(1), config.MaskBatchPadding) - opts.Mask.Dim(1); padSize > 0 {
opts.Mask = opts.Mask.Pad(ctx, 0, padSize, 0, 0)
}