ENH: add `pad` · Issue #69 · data-apis/array-api-extra · GitHub
ENH: add pad #69
Closed
@ev-br

Description

np.pad is not in the array API spec and is unlikely to be added. Still, it is a sometimes useful function that is available in some backends (numpy, cupy, jax.numpy) but not in others (torch). Thus it'd be great to add it to array-api-extra.

Implementing the full set of mode keywords is somewhat tricky, but the most useful one, mode="constant", is easy to implement even with pytorch. A ready implementation is available in the scipy PR:
https://github.com/scipy/scipy/pull/22122/files#diff-351836adc98d076c1552d17a57c52e6aa8ca43760bae44bea9190e55b4769b7fR873

The torch implementation derives from torch._numpy, https://github.com/pytorch/pytorch/blob/main/torch/_numpy/_funcs_impl.py#L2045
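
To illustrate why torch needs special handling: np.pad takes (before, after) pairs ordered from the first axis to the last, while torch.nn.functional.pad takes a flat sequence ordered from the last axis to the first. Below is a minimal sketch of that reordering, using only NumPy so it runs without torch installed. The helper name `to_torch_pad` is hypothetical and is not part of the scipy PR or of torch.

```python
import numpy as np

def to_torch_pad(pad_width, ndim):
    """Hypothetical helper: convert an np.pad-style pad_width into the
    flat, last-axis-first tuple that torch.nn.functional.pad expects."""
    # Broadcast scalars or a single (before, after) pair to one pair per axis.
    pairs = np.broadcast_to(np.asarray(pad_width), (ndim, 2))
    # Reverse the axis order, then flatten to (before, after, before, after, ...).
    return tuple(int(p) for p in np.flip(pairs, axis=0).ravel())

print(to_torch_pad(((1, 2), (3, 4)), ndim=2))  # (3, 4, 1, 2)
print(to_torch_pad(1, ndim=2))                 # (1, 1, 1, 1)
```

The resulting tuple is what would be passed as the second argument of torch.nn.functional.pad, together with `value=` for the fill value.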

The version from the scipy PR is reproduced below.

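import numpy as np

# `is_array_api_strict` and `is_torch` are namespace-detection helpers
# assumed to be in scope here.

def xp_pad(x, pad_width, mode='constant', *, xp, **kwargs):
    # xp.pad is available on numpy, cupy and jax.numpy; on torch, reuse
    # https://github.com/pytorch/pytorch/blob/main/torch/_numpy/_funcs_impl.py#L2045
    # for mode = 'constant'
    if mode != 'constant':
        raise NotImplementedError()

    # Fill value; only needed explicitly in the torch branch.
    value = kwargs.get("constant_values", 0)

    if is_array_api_strict(xp):
        # No pad in this namespace: round-trip through numpy.
        np_x = np.asarray(x)
        padded = np.pad(np_x, pad_width, mode=mode, **kwargs)
        return xp.asarray(padded)
    elif is_torch(xp):
        # torch.nn.functional.pad wants a flat sequence ordered from the
        # last axis to the first, so broadcast to (ndim, 2), flip, flatten.
        pad_width = xp.asarray(pad_width)
        pad_width = xp.broadcast_to(pad_width, (x.ndim, 2))
        pad_width = xp.flip(pad_width, axis=(0,)).flatten()
        return xp.nn.functional.pad(x, tuple(pad_width), value=value)
    else:
        return xp.pad(x, pad_width, mode=mode, **kwargs)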
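
For comparison, here is a rough sketch of how mode="constant" could be written without per-backend branches, by allocating the padded array with `xp.full` and copying the input into its interior. This is not taken from the scipy PR. The function name `pad_constant` is hypothetical, NumPy stands in for `xp`, and the approach relies on in-place slice assignment, so it would not work as-is for immutable arrays such as JAX's.

```python
import numpy as np

def pad_constant(x, pad_width, constant_values=0, *, xp=np):
    """Sketch: constant-mode padding via xp.full plus slice assignment."""
    # Normalize pad_width to one (before, after) pair per axis.
    pairs = np.broadcast_to(np.asarray(pad_width), (x.ndim, 2)).tolist()
    new_shape = tuple(b + n + a for (b, a), n in zip(pairs, x.shape))
    out = xp.full(new_shape, constant_values, dtype=x.dtype)
    # Copy the original data into the interior of the padded array.
    interior = tuple(slice(b, b + n) for (b, _), n in zip(pairs, x.shape))
    out[interior] = x
    return out

x = np.asarray([[1, 2], [3, 4]])
padded = pad_constant(x, ((1, 0), (0, 2)), constant_values=9)
# Cross-check against np.pad on the same input.
assert np.array_equal(padded, np.pad(x, ((1, 0), (0, 2)), constant_values=9))
```

The trade-off is that this version needs only `full`, slicing and `__setitem__`, at the cost of not dispatching to a backend's native pad when one exists.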
