generic pad
This commit is contained in:
parent
baae175ebe
commit
f49797fbdb
|
|
@ -884,7 +884,7 @@ func shapeToGGML(shape []int) *C.int64_t {
|
|||
return &sh[0]
|
||||
}
|
||||
|
||||
func pad(length, pad C.size_t) C.size_t {
|
||||
func pad[T C.size_t | int](length, pad T) T {
|
||||
return ((length + pad - 1) / pad) * pad
|
||||
}
|
||||
|
||||
|
|
@ -1789,7 +1789,7 @@ func (t *Tensor) SDPA(ctx ml.Context, key, value ml.Tensor, fns ...func(*attenti
|
|||
}
|
||||
|
||||
if opts.Mask != nil {
|
||||
if padSize := int(pad(C.size_t(opts.Mask.Dim(1)), C.size_t(config.MaskBatchPadding))) - opts.Mask.Dim(1); padSize > 0 {
|
||||
if padSize := pad(opts.Mask.Dim(1), config.MaskBatchPadding) - opts.Mask.Dim(1); padSize > 0 {
|
||||
opts.Mask = opts.Mask.Pad(ctx, 0, padSize, 0, 0)
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue