ENH: support PyTorch `device='meta'` by crusaderky · Pull Request #300 · data-apis/array-api-extra · GitHub
[go: up one dir, main page]

Skip to content

ENH: support PyTorch device='meta' #300

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
May 19, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/array_api_extra/_lib/_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ def _apply_where( # type: ignore[explicit-any] # numpydoc ignore=PR01,RT01
) -> Array:
"""Helper of `apply_where`. On Dask, this runs on a single chunk."""

if not capabilities(xp)["boolean indexing"]:
if not capabilities(xp, device=_compat.device(cond))["boolean indexing"]:
# jax.jit does not support assignment by boolean mask
return xp.where(cond, f1(*args), f2(*args) if f2 is not None else fill_value)

Expand Down Expand Up @@ -716,7 +716,7 @@ def nunique(x: Array, /, *, xp: ModuleType | None = None) -> Array:
# 2. backend has unique_counts and it returns a None-sized array;
# e.g. Dask, ndonnx
# 3. backend does not have unique_counts; e.g. wrapped JAX
if capabilities(xp)["data-dependent shapes"]:
if capabilities(xp, device=_compat.device(x))["data-dependent shapes"]:
# xp has unique_counts; O(n) complexity
_, counts = xp.unique_counts(x)
n = _compat.size(counts)
Expand Down
44 changes: 35 additions & 9 deletions src/array_api_extra/_lib/_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
is_jax_namespace,
is_numpy_namespace,
is_pydata_sparse_namespace,
is_torch_array,
is_torch_namespace,
to_device,
)
Expand Down Expand Up @@ -62,18 +63,28 @@ def _check_ns_shape_dtype(
msg = f"namespaces do not match: {actual_xp} != f{desired_xp}"
assert actual_xp == desired_xp, msg

if check_shape:
actual_shape = actual.shape
desired_shape = desired.shape
if is_dask_namespace(desired_xp):
# Dask uses nan instead of None for unknown shapes
if any(math.isnan(i) for i in cast(tuple[float, ...], actual_shape)):
actual_shape = actual.compute().shape # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
if any(math.isnan(i) for i in cast(tuple[float, ...], desired_shape)):
desired_shape = desired.compute().shape # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
# Dask uses nan instead of None for unknown shapes
actual_shape = cast(tuple[float, ...], actual.shape)
desired_shape = cast(tuple[float, ...], desired.shape)
assert None not in actual_shape # Requires explicit support
assert None not in desired_shape
if is_dask_namespace(desired_xp):
if any(math.isnan(i) for i in actual_shape):
actual_shape = actual.compute().shape # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
if any(math.isnan(i) for i in desired_shape):
desired_shape = desired.compute().shape # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]

if check_shape:
msg = f"shapes do not match: {actual_shape} != f{desired_shape}"
assert actual_shape == desired_shape, msg
else:
# Ignore shape, but check flattened size. This is normally done by
# np.testing.assert_array_equal etc even when strict=False, but not for
# non-materializable arrays.
actual_size = math.prod(actual_shape) # pyright: ignore[reportUnknownArgumentType]
desired_size = math.prod(desired_shape) # pyright: ignore[reportUnknownArgumentType]
msg = f"sizes do not match: {actual_size} != f{desired_size}"
assert actual_size == desired_size, msg

if check_dtype:
msg = f"dtypes do not match: {actual.dtype} != {desired.dtype}"
Expand All @@ -90,6 +101,15 @@ def _check_ns_shape_dtype(
return desired_xp


def _is_materializable(x: Array) -> bool:
    """
    Return True if `as_numpy_array(x)` can be called on `x`; False otherwise.

    The only non-materializable case handled here is a PyTorch array on the
    'meta' device, which carries shape/dtype metadata but no actual data.
    """
    # Important: here we assume that we're not tracing -
    # e.g. we're not inside `jax.jit` nor `cupy.cuda.Stream.begin_capture`.
    if is_torch_array(x):
        return x.device.type != "meta"  # type: ignore[attr-defined]  # pyright: ignore[reportAttributeAccessIssue]
    return True


def as_numpy_array(array: Array, *, xp: ModuleType) -> np.typing.NDArray[Any]: # type: ignore[explicit-any]
"""
Convert array to NumPy, bypassing GPU-CPU transfer guards and densification guards.
Expand Down Expand Up @@ -146,6 +166,8 @@ def xp_assert_equal(
numpy.testing.assert_array_equal : Similar function for NumPy arrays.
"""
xp = _check_ns_shape_dtype(actual, desired, check_dtype, check_shape, check_scalar)
if not _is_materializable(actual):
return
actual_np = as_numpy_array(actual, xp=xp)
desired_np = as_numpy_array(desired, xp=xp)
np.testing.assert_array_equal(actual_np, desired_np, err_msg=err_msg)
Expand Down Expand Up @@ -181,6 +203,8 @@ def xp_assert_less(
numpy.testing.assert_array_equal : Similar function for NumPy arrays.
"""
xp = _check_ns_shape_dtype(x, y, check_dtype, check_shape, check_scalar)
if not _is_materializable(x):
return
x_np = as_numpy_array(x, xp=xp)
y_np = as_numpy_array(y, xp=xp)
np.testing.assert_array_less(x_np, y_np, err_msg=err_msg)
Expand Down Expand Up @@ -229,6 +253,8 @@ def xp_assert_close(
The default `atol` and `rtol` differ from `xp.all(xpx.isclose(a, b))`.
"""
xp = _check_ns_shape_dtype(actual, desired, check_dtype, check_shape, check_scalar)
if not _is_materializable(actual):
return

if rtol is None:
if xp.isdtype(actual.dtype, ("real floating", "complex floating")):
Expand Down
14 changes: 12 additions & 2 deletions src/array_api_extra/_lib/_utils/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,9 @@
is_jax_namespace,
is_numpy_array,
is_pydata_sparse_namespace,
is_torch_namespace,
)
from ._typing import Array
from ._typing import Array, Device

if TYPE_CHECKING: # pragma: no cover
# TODO import from typing (requires Python >=3.12 and >=3.13)
Expand Down Expand Up @@ -300,7 +301,7 @@ def meta_namespace(
return array_namespace(*metas)


def capabilities(xp: ModuleType) -> dict[str, int]:
def capabilities(xp: ModuleType, *, device: Device | None = None) -> dict[str, int]:
"""
Return patched ``xp.__array_namespace_info__().capabilities()``.

Expand All @@ -311,6 +312,8 @@ def capabilities(xp: ModuleType) -> dict[str, int]:
----------
xp : array_namespace
The standard-compatible namespace.
device : Device, optional
The device to use.

Returns
-------
Expand All @@ -326,6 +329,13 @@ def capabilities(xp: ModuleType) -> dict[str, int]:
# Fixed in jax >=0.6.0
out = out.copy()
out["boolean indexing"] = False
if is_torch_namespace(xp):
# FIXME https://github.com/data-apis/array-api/issues/945
device = xp.get_default_device() if device is None else xp.device(device)
if device.type == "meta": # type: ignore[union-attr] # pyright: ignore[reportAttributeAccessIssue,reportOptionalMemberAccess]
out = out.copy()
out["boolean indexing"] = False
out["data-dependent shapes"] = False
return out


Expand Down
8 changes: 5 additions & 3 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,9 @@ def device(
Where possible, return a device that is not the default one.
"""
if library == Backend.ARRAY_API_STRICT:
d = xp.Device("device1")
assert get_device(xp.empty(0)) != d
return d
return xp.Device("device1")
if library == Backend.TORCH:
return xp.device("meta")
if library == Backend.TORCH_GPU:
return xp.device("cpu")
return get_device(xp.empty(0))
6 changes: 3 additions & 3 deletions tests/test_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -731,9 +731,6 @@ def test_device(self, xp: ModuleType, device: Device, equal_nan: bool):
b = xp.asarray([1e-9, 1e-4, xp.nan], device=device)
res = isclose(a, b, equal_nan=equal_nan)
assert get_device(res) == device
xp_assert_equal(
isclose(a, b, equal_nan=equal_nan), xp.asarray([True, False, equal_nan])
)


class TestKron:
Expand Down Expand Up @@ -996,6 +993,9 @@ def test_all_python_scalars(self, assume_unique: bool):
_ = setdiff1d(0, 0, assume_unique=assume_unique)

@assume_unique
@pytest.mark.skip_xp_backend(
Backend.TORCH, reason="device='meta' does not support unknown shapes"
)
def test_device(self, xp: ModuleType, device: Device, assume_unique: bool):
x1 = xp.asarray([3, 8, 20], device=device)
x2 = xp.asarray([2, 3, 4], device=device)
Expand Down
30 changes: 25 additions & 5 deletions tests/test_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,11 +212,31 @@ def test_xp(self, xp: ModuleType):
assert meta_namespace(*args, xp=xp) in (xp, np_compat)


def test_capabilities(xp: ModuleType):
expect = {"boolean indexing", "data-dependent shapes"}
if xp.__array_api_version__ >= "2024.12":
expect.add("max dimensions")
assert capabilities(xp).keys() == expect
class TestCapabilities:
    """Tests for `capabilities()` and its `device=` parameter."""

    def test_basic(self, xp: ModuleType):
        """The returned dict exposes exactly the expected capability keys."""
        expect = {"boolean indexing", "data-dependent shapes"}
        if xp.__array_api_version__ >= "2024.12":
            # "max dimensions" was added to the standard in 2024.12
            expect.add("max dimensions")
        assert capabilities(xp).keys() == expect

    def test_device(self, xp: ModuleType, library: Backend, device: Device):
        """Passing device= does not change the set of keys; on PyTorch the
        values are device-specific (e.g. device='meta')."""
        expect_keys = {"boolean indexing", "data-dependent shapes"}
        if xp.__array_api_version__ >= "2024.12":
            expect_keys.add("max dimensions")
        assert capabilities(xp, device=device).keys() == expect_keys

        if library.like(Backend.TORCH):
            # The output of capabilities is device-specific.

            # Test that device=None gets the current default device.
            expect = capabilities(xp, device=device)
            with xp.device(device):
                actual = capabilities(xp)
            assert actual == expect

            # Test that we're accepting anything that is accepted by the
            # device= parameter in other functions
            actual = capabilities(xp, device=device.type)  # type: ignore[attr-defined]  # pyright: ignore[reportUnknownArgumentType,reportAttributeAccessIssue]
            # FIX: the result was previously discarded; equivalent device
            # spellings (device object vs. its string type) must agree.
            assert actual == expect


class Wrapper(Generic[T]):
Expand Down
3 changes: 3 additions & 0 deletions tests/test_lazy.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,9 @@ def test_lazy_apply_none_shape_broadcast(xp: ModuleType):
Backend.ARRAY_API_STRICT, reason="device->host copy"
),
pytest.mark.skip_xp_backend(Backend.CUPY, reason="device->host copy"),
pytest.mark.skip_xp_backend(
Backend.TORCH, reason="materialize 'meta' device"
),
pytest.mark.skip_xp_backend(
Backend.TORCH_GPU, reason="device->host copy"
),
Expand Down
31 changes: 26 additions & 5 deletions tests/test_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,17 @@
# pyright: reportUnknownParameterType=false,reportMissingParameterType=false


def test_as_numpy_array(xp: ModuleType, device: Device):
x = xp.asarray([1, 2, 3], device=device)
y = as_numpy_array(x, xp=xp)
assert isinstance(y, np.ndarray)
class TestAsNumPyArray:
    """Tests for `as_numpy_array`, which converts any backend array to NumPy
    while bypassing GPU-CPU transfer guards and densification guards."""

    def test_basic(self, xp: ModuleType):
        """Round-trip a small array on the default device."""
        x = xp.asarray([1, 2, 3])
        y = as_numpy_array(x, xp=xp)
        xp_assert_equal(y, np.asarray([1, 2, 3]))  # type: ignore[arg-type]  # pyright: ignore[reportArgumentType]

    # torch device='meta' holds no data, so it cannot be materialized to NumPy
    @pytest.mark.skip_xp_backend(Backend.TORCH, reason="materialize 'meta' device")
    def test_device(self, xp: ModuleType, device: Device):
        """Round-trip a small array on a non-default device."""
        x = xp.asarray([1, 2, 3], device=device)
        y = as_numpy_array(x, xp=xp)
        xp_assert_equal(y, np.asarray([1, 2, 3]))  # type: ignore[arg-type]  # pyright: ignore[reportArgumentType]


class TestAssertEqualCloseLess:
Expand Down Expand Up @@ -80,7 +87,7 @@ def test_check_shape(self, xp: ModuleType, func: Callable[..., None]):
func(a, b, check_shape=False)
with pytest.raises(AssertionError, match="Mismatched elements"):
func(a, c, check_shape=False)
with pytest.raises(AssertionError, match=r"shapes \(1,\), \(2,\) mismatch"):
with pytest.raises(AssertionError, match="sizes do not match"):
func(a, d, check_shape=False)

@pytest.mark.parametrize("func", [xp_assert_equal, pr_assert_close, xp_assert_less])
Expand Down Expand Up @@ -169,6 +176,20 @@ def test_none_shape(self, xp: ModuleType, func: Callable[..., None]):
with pytest.raises(AssertionError, match="Mismatched elements"):
func(xp.asarray([4]), a)

# FIX: `pr_assert_close` is not a defined name; the assertion helper under
# test is `xp_assert_close` (see _lib/_testing.py).
@pytest.mark.parametrize("func", [xp_assert_equal, xp_assert_close, xp_assert_less])
def test_device(self, xp: ModuleType, device: Device, func: Callable[..., None]):
    """Shape/size mismatches are reported even for arrays on non-default
    devices (including torch device='meta', which can't be materialized)."""
    # xp_assert_less needs a strictly smaller first argument
    a = xp.asarray([1] if func is xp_assert_less else [2], device=device)
    b = xp.asarray([2], device=device)
    c = xp.asarray([2, 2], device=device)

    func(a, b)
    with pytest.raises(AssertionError, match="shapes do not match"):
        func(a, c)
    # This is normally performed by np.testing.assert_array_equal etc.
    # but in case of torch device='meta' we have to do it manually
    with pytest.raises(AssertionError, match="sizes do not match"):
        func(a, c, check_shape=False)


def good_lazy(x: Array) -> Array:
"""A function that behaves well in Dask and jax.jit"""
Expand Down
0