Description
`np.pad` is not in the array API spec, and is unlikely to get there. Still, it is a sometimes useful function, available in some backends (numpy, cupy, jax.numpy) but not others (torch). Thus it'd be great to add it to `array-api-extra`.
Implementing the full set of `mode` keywords is somewhat tricky, but the most useful one, `mode="constant"`, is easy to implement even with pytorch. A ready implementation is available in the scipy PR:
https://github.com/scipy/scipy/pull/22122/files#diff-351836adc98d076c1552d17a57c52e6aa8ca43760bae44bea9190e55b4769b7fR873
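For reference, here is a short NumPy illustration of what `mode="constant"` does. `pad_width` gives one `(before, after)` pair per axis, and the new cells are filled with `constant_values` (default 0). The array values are arbitrary example data.

```python
import numpy as np

x = np.array([[1, 2],
              [3, 4]])

# One (before, after) pair per axis: 1 row above, 0 below; 2 columns left, 1 right.
padded = np.pad(x, ((1, 0), (2, 1)), mode="constant", constant_values=9)
print(padded)
# [[9 9 9 9 9]
#  [9 9 1 2 9]
#  [9 9 3 4 9]]

# A single int pads every side of every axis by that amount; the default fill is 0.
print(np.pad(x, 1).shape)  # (4, 4)
```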
The torch implementation derives from `torch._numpy`: https://github.com/pytorch/pytorch/blob/main/torch/_numpy/_funcs_impl.py#L2045
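The key step in the torch path is translating NumPy's per-axis `pad_width` into the flat sequence `torch.nn.functional.pad` expects, which starts from the *last* axis. The sketch below shows that translation using NumPy only. `to_torch_pad` is a hypothetical helper for illustration, not part of numpy, torch, or array-api-extra.

```python
import numpy as np


def to_torch_pad(pad_width, ndim):
    """Hypothetical helper: NumPy-style pad_width -> torch.nn.functional.pad order.

    Mirrors the broadcast / flip / flatten steps of the torch branch, in NumPy.
    """
    # Normalize an int, a single (before, after) pair, or per-axis pairs to (ndim, 2).
    pw = np.broadcast_to(np.asarray(pad_width), (ndim, 2))
    # torch wants the LAST axis first: (last_before, last_after, prev_before, ...).
    return tuple(int(v) for v in np.flip(pw, axis=0).ravel())


print(to_torch_pad(((1, 0), (2, 1)), ndim=2))  # (2, 1, 1, 0)
print(to_torch_pad(3, ndim=2))                 # (3, 3, 3, 3)
print(to_torch_pad((1, 2), ndim=3))            # (1, 2, 1, 2, 1, 2)
```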
The version from the scipy PR is reproduced below.
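```python
import numpy as np
# scipy's private is_array_api_strict / is_torch helpers are swapped here
# for their public array_api_compat equivalents.
from array_api_compat import is_array_api_strict_namespace, is_torch_namespace


def xp_pad(x, pad_width, mode='constant', *, xp, **kwargs):
    # xp.pad is available on numpy, cupy and jax.numpy; on torch, reuse
    # https://github.com/pytorch/pytorch/blob/main/torch/_numpy/_funcs_impl.py#L2045
    # for mode='constant'
    if mode != 'constant':
        raise NotImplementedError(f"mode={mode!r} is not supported")
    value = kwargs.get("constant_values", 0)
    if is_array_api_strict_namespace(xp):
        # array_api_strict has no pad: round-trip through NumPy
        np_x = np.asarray(x)
        padded = np.pad(np_x, pad_width, mode=mode, **kwargs)
        return xp.asarray(padded)
    elif is_torch_namespace(xp):
        import torch
        # Normalize pad_width to shape (ndim, 2): one (before, after) pair per axis
        pad_width = torch.as_tensor(pad_width)
        pad_width = torch.broadcast_to(pad_width, (x.ndim, 2))
        # torch.nn.functional.pad expects a flat sequence starting from the
        # last axis: (last_before, last_after, prev_before, prev_after, ...)
        pad_width = torch.flip(pad_width, dims=(0,)).flatten()
        # Note: torch only accepts a scalar fill value here
        return torch.nn.functional.pad(x, tuple(pad_width.tolist()), value=value)
    else:
        return xp.pad(x, pad_width, mode=mode, **kwargs)
```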
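Since `mode="constant"` is the case that matters most, another possible route avoids `pad` altogether: allocate the output with `xp.full` and copy the input into its interior by slice assignment, using only array API standard calls. This is a sketch under stated assumptions, not the approach from the scipy PR: `xp` must support in-place `__setitem__` (so it would not work on JAX as written), all pad widths must be non-negative ints, and device placement is ignored (a real version would pass `device=` to `full`). The name `constant_pad` is hypothetical.

```python
import numpy as np


def constant_pad(x, pad_width, *, xp, constant_values=0):
    """Hypothetical mode='constant' pad built from xp.full + slice assignment.

    Assumes xp supports in-place slice assignment (not true for JAX), that all
    pad widths are non-negative ints, and ignores device placement.
    """
    # pad_width is small host-side metadata, so NumPy is used to normalize it
    # to one (before, after) pair per axis.
    pw = np.broadcast_to(np.asarray(pad_width), (x.ndim, 2)).tolist()
    out_shape = tuple(n + before + after for n, (before, after) in zip(x.shape, pw))
    # Output pre-filled with the constant, same dtype as x.
    out = xp.full(out_shape, constant_values, dtype=x.dtype)
    # Copy x into the interior region.
    interior = tuple(slice(before, before + n) for n, (before, _) in zip(x.shape, pw))
    out[interior] = x
    return out


x = np.array([[1, 2], [3, 4]])
out = constant_pad(x, ((1, 0), (2, 1)), xp=np, constant_values=9)
print(np.array_equal(out, np.pad(x, ((1, 0), (2, 1)), constant_values=9)))  # True
```

The trade-off is one extra allocation-sized fill versus needing a native `pad` in every backend; on backends with immutable arrays the assignment step would have to be replaced by that backend's functional update.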