diff --git a/ml/backend/ggml/ggml.go b/ml/backend/ggml/ggml.go index 919203549..98f8dfbfd 100644 --- a/ml/backend/ggml/ggml.go +++ b/ml/backend/ggml/ggml.go @@ -884,7 +884,7 @@ func shapeToGGML(shape []int) *C.int64_t { return &sh[0] } -func pad(length, pad C.size_t) C.size_t { +func pad[T C.size_t | int](length, pad T) T { return ((length + pad - 1) / pad) * pad } @@ -1789,7 +1789,7 @@ func (t *Tensor) SDPA(ctx ml.Context, key, value ml.Tensor, fns ...func(*attenti } if opts.Mask != nil { - if padSize := int(pad(C.size_t(opts.Mask.Dim(1)), C.size_t(config.MaskBatchPadding))) - opts.Mask.Dim(1); padSize > 0 { + if padSize := pad(opts.Mask.Dim(1), config.MaskBatchPadding) - opts.Mask.Dim(1); padSize > 0 { opts.Mask = opts.Mask.Pad(ctx, 0, padSize, 0, 0) }