From e1da0d68390702e483cd4fb160a0c0b46ba96661 Mon Sep 17 00:00:00 2001 From: crusaderky Date: Thu, 23 Jan 2025 11:52:30 +0000 Subject: [PATCH 1/2] ENH: astype: add device kwarg --- array_api_compat/common/_aliases.py | 7 +------ array_api_compat/cupy/_aliases.py | 22 ++++++++++++++++++---- array_api_compat/dask/array/_aliases.py | 2 +- array_api_compat/numpy/_aliases.py | 17 ++++++++++++++--- array_api_compat/torch/_aliases.py | 15 +++++++++++++-- cupy-xfails.txt | 1 - dask-xfails.txt | 1 - numpy-1-21-xfails.txt | 1 - numpy-1-26-xfails.txt | 1 - numpy-dev-xfails.txt | 1 - numpy-xfails.txt | 1 - torch-xfails.txt | 2 -- 12 files changed, 47 insertions(+), 24 deletions(-) diff --git a/array_api_compat/common/_aliases.py b/array_api_compat/common/_aliases.py index d5405745..4a884893 100644 --- a/array_api_compat/common/_aliases.py +++ b/array_api_compat/common/_aliases.py @@ -233,11 +233,6 @@ def unique_values(x: ndarray, /, xp) -> ndarray: **kwargs, ) -def astype(x: ndarray, dtype: Dtype, /, *, copy: bool = True) -> ndarray: - if not copy and dtype == x.dtype: - return x - return x.astype(dtype=dtype, copy=copy) - # These functions have different keyword argument names def std( @@ -549,7 +544,7 @@ def sign(x: ndarray, /, xp, **kwargs) -> ndarray: 'linspace', 'ones', 'ones_like', 'zeros', 'zeros_like', 'UniqueAllResult', 'UniqueCountsResult', 'UniqueInverseResult', 'unique_all', 'unique_counts', 'unique_inverse', 'unique_values', - 'astype', 'std', 'var', 'cumulative_sum', 'clip', 'permute_dims', + 'std', 'var', 'cumulative_sum', 'clip', 'permute_dims', 'reshape', 'argsort', 'sort', 'nonzero', 'ceil', 'floor', 'trunc', 'matmul', 'matrix_transpose', 'tensordot', 'vecdot', 'isdtype', 'unstack', 'sign'] diff --git a/array_api_compat/cupy/_aliases.py b/array_api_compat/cupy/_aliases.py index 3627fb6b..7b8000f5 100644 --- a/array_api_compat/cupy/_aliases.py +++ b/array_api_compat/cupy/_aliases.py @@ -2,7 +2,7 @@ import cupy as cp -from ..common import _aliases +from ..common import _aliases, _helpers from .._internal import get_xp from ._info import __array_namespace_info__ @@ -46,7 +46,6 @@ unique_counts = get_xp(cp)(_aliases.unique_counts) unique_inverse = get_xp(cp)(_aliases.unique_inverse) unique_values = get_xp(cp)(_aliases.unique_values) -astype = _aliases.astype std = get_xp(cp)(_aliases.std) var = get_xp(cp)(_aliases.var) cumulative_sum = get_xp(cp)(_aliases.cumulative_sum) @@ -110,6 +109,21 @@ def asarray( return cp.array(obj, dtype=dtype, **kwargs) + +def astype( + x: ndarray, + dtype: Dtype, + /, + *, + copy: bool = True, + device: Optional[Device] = None, +) -> ndarray: + if device is None: + return x.astype(dtype=dtype, copy=copy) + out = _helpers.to_device(x.astype(dtype=dtype, copy=False), device) + return out.copy() if copy and out is x else x + + # These functions are completely new here. If the library already has them # (i.e., numpy 2.0), use the library version instead of our wrapper. if hasattr(cp, 'vecdot'): @@ -127,10 +141,10 @@ def asarray( else: unstack = get_xp(cp)(_aliases.unstack) -__all__ = _aliases.__all__ + ['__array_namespace_info__', 'asarray', 'bool', +__all__ = _aliases.__all__ + ['__array_namespace_info__', 'asarray', 'astype', 'acos', 'acosh', 'asin', 'asinh', 'atan', 'atan2', 'atanh', 'bitwise_left_shift', 'bitwise_invert', 'bitwise_right_shift', - 'concat', 'pow', 'sign'] + 'bool', 'concat', 'pow', 'sign'] _all_ignore = ['cp', 'get_xp'] diff --git a/array_api_compat/dask/array/_aliases.py b/array_api_compat/dask/array/_aliases.py index 861b0bd0..a8ed6f26 100644 --- a/array_api_compat/dask/array/_aliases.py +++ b/array_api_compat/dask/array/_aliases.py @@ -233,7 +233,7 @@ def _isscalar(a): _common_aliases = [alias for alias in _aliases.__all__ if alias not in _da_unsupported] -__all__ = _common_aliases + ['__array_namespace_info__', 'asarray', 'acos', +__all__ = _common_aliases + ['__array_namespace_info__', 'asarray', 'astype', 'acos', 'acosh', 'asin', 'asinh', 'atan', 'atan2', 'atanh', 'bitwise_left_shift', 'bitwise_invert', 'bitwise_right_shift', 'concat', 'pow', 'iinfo', 'finfo', 'can_cast', diff --git a/array_api_compat/numpy/_aliases.py b/array_api_compat/numpy/_aliases.py index 2bfc98ff..789eefb3 100644 --- a/array_api_compat/numpy/_aliases.py +++ b/array_api_compat/numpy/_aliases.py @@ -46,7 +46,6 @@ unique_counts = get_xp(np)(_aliases.unique_counts) unique_inverse = get_xp(np)(_aliases.unique_inverse) unique_values = get_xp(np)(_aliases.unique_values) -astype = _aliases.astype std = get_xp(np)(_aliases.std) var = get_xp(np)(_aliases.var) cumulative_sum = get_xp(np)(_aliases.cumulative_sum) @@ -115,6 +114,18 @@ def asarray( return np.array(obj, copy=copy, dtype=dtype, **kwargs) + +def astype( + x: ndarray, + dtype: Dtype, + /, + *, + copy: bool = True, + device: Optional[Device] = None, +) -> ndarray: + return x.astype(dtype=dtype, copy=copy) + + # These functions are completely new here. If the library already has them # (i.e., numpy 2.0), use the library version instead of our wrapper. if hasattr(np, 'vecdot'): @@ -132,10 +143,10 @@ def asarray( else: unstack = get_xp(np)(_aliases.unstack) -__all__ = _aliases.__all__ + ['__array_namespace_info__', 'asarray', 'bool', +__all__ = _aliases.__all__ + ['__array_namespace_info__', 'asarray', 'astype', 'acos', 'acosh', 'asin', 'asinh', 'atan', 'atan2', 'atanh', 'bitwise_left_shift', 'bitwise_invert', 'bitwise_right_shift', - 'concat', 'pow'] + 'bool', 'concat', 'pow'] _all_ignore = ['np', 'get_xp'] diff --git a/array_api_compat/torch/_aliases.py b/array_api_compat/torch/_aliases.py index 5ac66bcb..f2ec7b17 100644 --- a/array_api_compat/torch/_aliases.py +++ b/array_api_compat/torch/_aliases.py @@ -613,8 +613,19 @@ def triu(x: array, /, *, k: int = 0) -> array: def expand_dims(x: array, /, *, axis: int = 0) -> array: return torch.unsqueeze(x, axis) -def astype(x: array, dtype: Dtype, /, *, copy: bool = True) -> array: - return x.to(dtype, copy=copy) + +def astype( + x: array, + dtype: Dtype, + /, + *, + copy: bool = True, + device: Optional[Device] = None, +) -> array: + if device is not None: + return x.to(device, dtype=dtype, copy=copy) + return x.to(dtype=dtype, copy=copy) + def broadcast_arrays(*arrays: array) -> List[array]: shape = torch.broadcast_shapes(*[a.shape for a in arrays]) diff --git a/cupy-xfails.txt b/cupy-xfails.txt index aa74e4d5..f30004c1 100644 --- a/cupy-xfails.txt +++ b/cupy-xfails.txt @@ -181,5 +181,4 @@ array_api_tests/test_fft.py::test_irfftn # cupy.ndaray cannot be specified as `repeats` argument. array_api_tests/test_manipulation_functions.py::test_repeat array_api_tests/test_signatures.py::test_func_signature[from_dlpack] -array_api_tests/test_signatures.py::test_func_signature[astype] array_api_tests/test_signatures.py::test_array_method_signature[__dlpack__] diff --git a/dask-xfails.txt b/dask-xfails.txt index 1e9c421c..1631ea12 100644 --- a/dask-xfails.txt +++ b/dask-xfails.txt @@ -154,4 +154,3 @@ array_api_tests/test_statistical_functions.py::test_prod # 2023.12 support array_api_tests/test_manipulation_functions.py::test_repeat array_api_tests/test_searching_functions.py::test_searchsorted -array_api_tests/test_signatures.py::test_func_signature[astype] diff --git a/numpy-1-21-xfails.txt b/numpy-1-21-xfails.txt index 459b33e3..f396b789 100644 --- a/numpy-1-21-xfails.txt +++ b/numpy-1-21-xfails.txt @@ -254,7 +254,6 @@ array_api_tests/test_special_cases.py::test_iop[__iadd__(x1_i is -0 and x2_i is # 2023.12 support array_api_tests/test_searching_functions.py::test_searchsorted array_api_tests/test_signatures.py::test_func_signature[from_dlpack] -array_api_tests/test_signatures.py::test_func_signature[astype] array_api_tests/test_signatures.py::test_array_method_signature[__dlpack__] # uint64 repeats not supported array_api_tests/test_manipulation_functions.py::test_repeat diff --git a/numpy-1-26-xfails.txt b/numpy-1-26-xfails.txt index 57b80e7e..2cb9fe4f 100644 --- a/numpy-1-26-xfails.txt +++ b/numpy-1-26-xfails.txt @@ -49,7 +49,6 @@ array_api_tests/test_statistical_functions.py::test_prod # 2023.12 support array_api_tests/test_searching_functions.py::test_searchsorted array_api_tests/test_signatures.py::test_func_signature[from_dlpack] -array_api_tests/test_signatures.py::test_func_signature[astype] array_api_tests/test_signatures.py::test_array_method_signature[__dlpack__] # uint64 repeats not supported array_api_tests/test_manipulation_functions.py::test_repeat diff --git a/numpy-dev-xfails.txt b/numpy-dev-xfails.txt index 23a83e1e..e904357d 100644 --- a/numpy-dev-xfails.txt +++ b/numpy-dev-xfails.txt @@ -20,7 +20,6 @@ array_api_tests/test_signatures.py::test_extension_func_signature[linalg.vecdot] # 2023.12 support # Argument 'device' missing from signature -array_api_tests/test_signatures.py::test_func_signature[astype] array_api_tests/test_signatures.py::test_func_signature[from_dlpack] array_api_tests/test_signatures.py::test_array_method_signature[__dlpack__] # uint64 repeats not supported diff --git a/numpy-xfails.txt b/numpy-xfails.txt index 1c9d98f6..3e898371 100644 --- a/numpy-xfails.txt +++ b/numpy-xfails.txt @@ -41,7 +41,6 @@ array_api_tests/test_signatures.py::test_extension_func_signature[linalg.vecdot] # 2023.12 support array_api_tests/test_searching_functions.py::test_searchsorted array_api_tests/test_signatures.py::test_func_signature[from_dlpack] -array_api_tests/test_signatures.py::test_func_signature[astype] array_api_tests/test_signatures.py::test_array_method_signature[__dlpack__] # uint64 repeats not supported array_api_tests/test_manipulation_functions.py::test_repeat diff --git a/torch-xfails.txt b/torch-xfails.txt index c972659e..44bef5af 100644 --- a/torch-xfails.txt +++ b/torch-xfails.txt @@ -202,5 +202,3 @@ array_api_tests/test_signatures.py::test_func_signature[repeat] array_api_tests/test_signatures.py::test_func_signature[from_dlpack] # Argument 'max_version' missing from signature array_api_tests/test_signatures.py::test_array_method_signature[__dlpack__] -# Argument 'device' missing from signature -array_api_tests/test_signatures.py::test_func_signature[astype] From 7c734dbbd640f2ad360651f4bb5a50b87a458767 Mon Sep 17 00:00:00 2001 From: Guido Imperiale Date: Fri, 24 Jan 2025 11:24:05 +0000 Subject: [PATCH 2/2] Update array_api_compat/cupy/_aliases.py Co-authored-by: Lucas Colley --- array_api_compat/cupy/_aliases.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/array_api_compat/cupy/_aliases.py b/array_api_compat/cupy/_aliases.py index 7b8000f5..8ab5629b 100644 --- a/array_api_compat/cupy/_aliases.py +++ b/array_api_compat/cupy/_aliases.py @@ -121,7 +121,7 @@ def astype( if device is None: return x.astype(dtype=dtype, copy=copy) out = _helpers.to_device(x.astype(dtype=dtype, copy=False), device) - return out.copy() if copy and out is x else x + return out.copy() if copy and out is x else out # These functions are completely new here. If the library already has them