8000 Merge pull request #26580 from JuliaPoo/issue-26196-asanyarray-copy-d… · mathomp4/numpy@e6ebcc7 · GitHub
[go: up one dir, main page]

Skip to content

Commit e6ebcc7

Browse files
authored
Merge pull request numpy#26580 from JuliaPoo/issue-26196-asanyarray-copy-device
ENH: Add copy and device keyword to np.asanyarray to match np.asarray
2 parents 09779f9 + a9b773d commit e6ebcc7

File tree

5 files changed

+54
-4
lines changed

5 files changed

+54
-4
lines changed
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
* `numpy.asanyarray` now supports ``copy`` and ``device`` arguments, matching `numpy.asarray`.

numpy/_core/_add_newdocs.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1030,6 +1030,22 @@
10301030
'A' (any) means 'F' if `a` is Fortran contiguous, 'C' otherwise
10311031
'K' (keep) preserve input order
10321032
Defaults to 'C'.
1033+
device : str, optional
1034+
The device on which to place the created array. Default: None.
1035+
For Array-API interoperability only, so must be ``"cpu"`` if passed.
1036+
1037+
.. versionadded:: 2.1.0
1038+
1039+
copy : bool, optional
1040+
If ``True``, then the object is copied. If ``None`` then the object is
1041+
copied only if needed, i.e. if ``__array__`` returns a copy, if obj
1042+
is a nested sequence, or if a copy is needed to satisfy any of
1043+
the other requirements (``dtype``, ``order``, etc.).
1044+
For ``False`` it raises a ``ValueError`` if a copy cannot be avoided.
1045+
Default: ``None``.
1046+
1047+
.. versionadded:: 2.1.0
1048+
10331049
${ARRAY_FUNCTION_LIKE}
10341050
10351051
.. versionadded:: 1.20.0

numpy/_core/multiarray.pyi

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -525,6 +525,8 @@ def asanyarray(
525525
dtype: None = ...,
526526
order: _OrderKACF = ...,
527527
*,
528+
device: None | L["cpu"] = ...,
529+
copy: None | bool = ...,
528530
like: None | _SupportsArrayFunc = ...,
529531
) -> _ArrayType: ...
530532
@overload
@@ -533,6 +535,8 @@ def asanyarray(
533535
dtype: None = ...,
534536
order: _OrderKACF = ...,
535537
*,
538+
device: None | L["cpu"] = ...,
539+
copy: None | bool = ...,
536540
like: None | _SupportsArrayFunc = ...,
537541
) -> NDArray[_SCT]: ...
538542
@overload
@@ -541,6 +545,8 @@ def asanyarray(
541545
dtype: None = ...,
542546
order: _OrderKACF = ...,
543547
*,
548+
device: None | L["cpu"] = ...,
549+
copy: None | bool = ...,
544550
like: None | _SupportsArrayFunc = ...,
545551
) -> NDArray[Any]: ...
546552
@overload
@@ -549,6 +555,8 @@ def asanyarray(
549555
dtype: _DTypeLike[_SCT],
550556
order: _OrderKACF = ...,
551557
*,
558+
device: None | L["cpu"] = ...,
559+
copy: None | bool = ...,
552560
like: None | _SupportsArrayFunc = ...,
553561
) -> NDArray[_SCT]: ...
554562
@overload
@@ -557,6 +565,8 @@ def asanyarray(
557565
dtype: DTypeLike,
558566
order: _OrderKACF = ...,
559567
*,
568+
device: None | L["cpu"] = ...,
569+
copy: None | bool = ...,
560570
like: None | _SupportsArrayFunc = ...,
561571
) -> NDArray[Any]: ...
562572

numpy/_core/src/multiarray/multiarraymodule.c

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1781,8 +1781,10 @@ array_asanyarray(PyObject *NPY_UNUSED(ignored),
17811781
PyObject *const *args, Py_ssize_t len_args, PyObject *kwnames)
17821782
{
17831783
PyObject *op;
1784+
NPY_COPYMODE copy = NPY_COPY_IF_NEEDED;
17841785
npy_dtype_info dt_info = {NULL, NULL};
17851786
NPY_ORDER order = NPY_KEEPORDER;
1787+
NPY_DEVICE device = NPY_DEVICE_CPU;
17861788
PyObject *like = Py_None;
17871789
NPY_PREPARE_ARGPARSER;
17881790

@@ -1791,6 +1793,8 @@ array_asanyarray(PyObject *NPY_UNUSED(ignored),
17911793
"a", NULL, &op,
17921794
"|dtype", &PyArray_DTypeOrDescrConverterOptional, &dt_info,
17931795
"|order", &PyArray_OrderConverter, &order,
1796+
"$device", &PyArray_DeviceConverterOptional, &device,
1797+
"$copy", &PyArray_CopyConverter, &copy,
17941798
"$like", NULL, &like,
17951799
NULL, NULL, NULL) < 0) {
17961800
Py_XDECREF(dt_info.descr);
@@ -1812,7 +1816,7 @@ array_asanyarray(PyObject *NPY_UNUSED(ignored),
18121816
}
18131817

18141818
PyObject *res = _array_fromobject_generic(
1815-
op, dt_info.descr, dt_info.dtype, NPY_COPY_IF_NEEDED, order, NPY_TRUE, 0);
1819+
op, dt_info.descr, dt_info.dtype, copy, order, NPY_TRUE, 0);
18161820
Py_XDECREF(dt_info.descr);
18171821
Py_XDECREF(dt_info.dtype);
18181822
return res;

numpy/_core/tests/test_multiarray.py

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10241,10 +10241,29 @@ class TestDevice:
1024110241
"""
1024210242
Test arr.device attribute and arr.to_device() method.
1024310243
"""
10244-
def test_device(self):
10245-
arr = np.arange(5)
10246-
10244+
@pytest.mark.parametrize("func, arg", [
10245+
(np.arange, 5),
10246+
(np.empty_like, []),
10247+
(np.zeros, 5),
10248+
(np.empty, (5, 5)),
10249+
(np.asarray, []),
10250+
(np.asanyarray, []),
10251+
])
10252+
def test_device(self, func, arg):
10253+
arr = func(arg)
10254+
assert arr.device == "cpu"
10255+
arr = func(arg, device=None)
10256+
assert arr.device == "cpu"
10257+
arr = func(arg, device="cpu")
1024710258
assert arr.device == "cpu"
10259+
10260+
with assert_raises_regex(
10261+
ValueError,
10262+
r"Device not understood. Only \"cpu\" is allowed, "
10263+
r"but received: nonsense"
10264+
):
10265+
func(arg, device="nonsense")
10266+
1024810267
with assert_raises_regex(
1024910268
AttributeError,
1025010269
r"attribute 'device' of '(numpy.|)ndarray' objects is "

0 commit comments

Comments
 (0)
0