8000 API: Add device and to_device to ndarray · numpy/numpy@89184df · GitHub
[go: up one dir, main page]

Skip to content

Commit 89184df

Browse files
committed
API: Add device and to_device to ndarray
1 parent 06d7bdf commit 89184df

22 files changed

+340
-43
lines changed
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
``ndarray.device`` and ``ndarray.to_device``
2+
--------------------------------------------
3+
4+
``ndarray.device`` attribute and ``ndarray.to_device`` function were
5+
added to `numpy.ndarray` class for Array API compatibility.
6+
7+
Additionally, ``device`` keyword-only arguments were added to:
8+
`numpy.asarray`, `numpy.arange`, `numpy.empty`, `numpy.empty_like`,
9+
`numpy.eye`, `numpy.full`, `numpy.full_like`, `numpy.linspace`,
10+
`numpy.ones`, `numpy.ones_like`, `numpy.zeros`, and `numpy.zeros_like`.

doc/source/reference/array_api.rst

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -588,12 +588,6 @@ Creation functions differences
588588
* - ``copy`` keyword argument to ``asarray``
589589
- **Compatible**
590590
-
591-
* - New ``device`` keyword argument to all array creation functions
592-
(``asarray``, ``arange``, ``empty``, ``empty_like``, ``eye``, ``full``,
593-
``full_like``, ``linspace``, ``ones``, ``ones_like``, ``zeros``, and
594-
``zeros_like``).
595-
- **Compatible**
596-
- ``device`` would effectively do nothing, since NumPy is CPU only.
597591

598592
Elementwise functions differences
599593
---------------------------------

numpy/__init__.pyi

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2465,6 +2465,11 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeType, _DType_co]):
24652465

24662466
def __array_namespace__(self, *, api_version: str = ...) -> Any: ...
24672467

2468+
def to_device(self, device: L["cpu"], /, *, stream: None | int | Any = ...) -> NDArray[Any]: ...
2469+
2470+
@property
2471+
def device(self) -> L["cpu"]: ...
2472+
24682473
def bitwise_count(
24692474
self,
24702475
out: None | NDArray[Any] = ...,

numpy/_core/_add_newdocs.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -912,7 +912,7 @@
912912

913913
add_newdoc('numpy._core.multiarray', 'asarray',
914914
"""
915-
asarray(a, dtype=None, order=None, *, like=None)
915+
asarray(a, dtype=None, order=None, *, device=None, like=None)
916916
917917
Convert the input to an array.
918918
@@ -931,6 +931,10 @@
931931
'A' (any) means 'F' if `a` is Fortran contiguous, 'C' otherwise
932932
'K' (keep) preserve input order
933933
Defaults to 'K'.
934+
device : str, optional
935+
The device on which to place the created array. Default: None.
936+
937+
.. versionadded:: 2.0.0
934938
${ARRAY_FUNCTION_LIKE}
935939
936940
.. versionadded:: 1.20.0
@@ -1184,7 +1188,7 @@
11841188

11851189
add_newdoc('numpy._core.multiarray', 'empty',
11861190
"""
1187-
empty(shape, dtype=float, order='C', *, like=None)
1191+
empty(shape, dtype=float, order='C', *, device=None, like=None)
11881192
11891193
Return a new array of given shape and type, without initializing entries.
11901194
@@ -1199,6 +1203,10 @@
11991203
Whether to store multi-dimensional data in row-major
12001204
(C-style) or column-major (Fortran-style) order in
12011205
memory.
1206+
device : str, optional
1207+
The device on which to place the created array. Default: None.
1208+
1209+
.. versionadded:: 2.0.0
12021210
${ARRAY_FUNCTION_LIKE}
12031211
12041212
.. versionadded:: 1.20.0
@@ -1676,7 +1684,7 @@
16761684

16771685
add_newdoc('numpy._core.multiarray', 'arange',
16781686
"""
1679-
arange([start,] stop[, step,], dtype=None, *, like=None)
1687+
arange([start,] stop[, step,], dtype=None, *, device=None, like=None)
16801688
16811689
Return evenly spaced values within a given interval.
16821690
@@ -1717,6 +1725,10 @@
17171725
dtype : dtype, optional
17181726
The type of the output array. If `dtype` is not given, infer the data
17191727
type from the other input arguments.
1728+
device : str, optional
1729+
The device on which to place the created array. Default: None.
1730+
1731+
.. versionadded:: 2.0.0
17201732
${ARRAY_FUNCTION_LIKE}
17211733
17221734
.. versionadded:: 1.20.0

numpy/_core/function_base.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,13 @@
1717

1818

1919
def _linspace_dispatcher(start, stop, num=None, endpoint=None, retstep=None,
20-
dtype=None, axis=None):
20+
dtype=None, axis=None, *, device=None):
2121
return (start, stop)
2222

2323

2424
@array_function_dispatch(_linspace_dispatcher)
2525
def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None,
26-
axis=0):
26+
axis=0, *, device=None):
2727
"""
2828
Return evenly spaced numbers over a specified interval.
2929
@@ -64,13 +64,16 @@ def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None,
6464
array of integers.
6565
6666
.. versionadded:: 1.9.0
67-
6867
axis : int, optional
6968
The axis in the result to store the samples. Relevant only if start
7069
or stop are array-like. By default (0), the samples will be along a
7170
new axis inserted at the beginning. Use -1 to get an axis at the end.
7271
7372
.. versionadded:: 1.16.0
73+
device : str, optional
74+
The device on which to place the created array. Default: None.
75+
76+
.. versionadded:: 2.0.0
7477
7578
Returns
7679
-------
@@ -119,6 +122,11 @@ def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None,
119122
>>> plt.show()
120123
121124
"""
125+
if device not in ["cpu", None]:
126+
raise ValueError(
127+
f"Unsupported device: {device}. Only \"cpu\" is allowed."
128+
)
129+
122130
num = operator.index(num)
123131
if num < 0:
124132
raise ValueError(

numpy/_core/function_base.pyi

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@ def linspace(
2828
retstep: L[False] = ...,
2929
dtype: None = ...,
3030
axis: SupportsIndex = ...,
31+
*,
32+
device: None | L["cpu"] = ...,
3133
) -> NDArray[floating[Any]]: ...
3234
@overload
3335
def linspace(
@@ -38,6 +40,8 @@ def linspace(
3840
retstep: L[False] = ...,
3941
dtype: None = ...,
4042
axis: SupportsIndex = ...,
43+
*,
44+
device: None | L["cpu"] = ...,
4145
) -> NDArray[complexfloating[Any, Any]]: ...
4246
@overload
4347
def linspace(
@@ -48,6 +52,8 @@ def linspace(
4852
retstep: L[False] = ...,
4953
dtype: _DTypeLike[_SCT] = ...,
5054
axis: SupportsIndex = ...,
55+
*,
56+
device: None | L["cpu"] = ...,
5157
) -> NDArray[_SCT]: ...
5258
@overload
5359
def linspace(
@@ -58,6 +64,8 @@ def linspace(
5864
retstep: L[False] = ...,
5965
dtype: DTypeLike = ...,
6066
axis: SupportsIndex = ...,
67+
*,
68+
device: None | L["cpu"] = ...,
6169
) -> NDArray[Any]: ...
6270
@overload
6371
def linspace(
@@ -68,6 +76,8 @@ def linspace(
6876
retstep: L[True] = ...,
6977
dtype: None = ...,
7078
axis: SupportsIndex = ...,
79+
*,
80+
device: None | L["cpu"] = ...,
7181
) -> tuple[NDArray[floating[Any]], floating[Any]]: ...
7282
@overload
7383
def linspace(
@@ -78,6 +88,8 @@ def linspace(
7888
retstep: L[True] = ...,
7989
dtype: None = ...,
8090
axis: SupportsIndex = ...,
91+
*,
92+
device: None | L["cpu"] = ...,
8193
) -> tuple[NDArray[complexfloating[Any, Any]], complexfloating[Any, Any]]: ...
8294
@overload
8395
def linspace(
@@ -88,6 +100,8 @@ def linspace(
88100
retstep: L[True] = ...,
89101
dtype: _DTypeLike[_SCT] = ...,
90102
axis: SupportsIndex = ...,
103+
*,
104+
device: None | L["cpu"] = ...,
91105
) -> tuple[NDArray[_SCT], _SCT]: ...
92106
@overload
93107
def linspace(
@@ -98,6 +112,8 @@ def linspace(
98112
retstep: L[True] = ...,
99113
dtype: DTypeLike = ...,
100114
axis: SupportsIndex = ...,
115+
*,
116+
device: None | L["cpu"] = ...,
101117
) -> tuple[NDArray[Any], Any]: ...
102118

103119
@overload

numpy/_core/include/numpy/ndarraytypes.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -300,6 +300,11 @@ typedef enum {
300300
NPY_BUSDAY_RAISE
301301
} NPY_BUSDAY_ROLL;
302302

303+
/* Device enum for Array API compatibility */
304+
typedef enum {
305+
NPY_DEVICE_CPU = 0,
306+
} NPY_DEVICE;
307+
303308
/************************************************************
304309
* NumPy Auxiliary Data for inner loops, sort functions, etc.
305310
************************************************************/

numpy/_core/multiarray.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,9 +81,12 @@
8181

8282

8383
@array_function_from_c_func_and_dispatcher(_multiarray_umath.empty_like)
84-
def empty_like(prototype, dtype=None, order=None, subok=None, shape=None):
84+
def empty_like(
85+
prototype, dtype=None, order=None, subok=None, shape=None, *, device=None
86+
):
8587
"""
86-
empty_like(prototype, dtype=None, order='K', subok=True, shape=None)
88+
empty_like(prototype, dtype=None, order='K', subok=True, shape=None,
89+
shape=None, *, device=None)
8790
8891
Return a new array with the same shape and type as a given array.
8992
@@ -113,6 +116,10 @@ def empty_like(prototype, dtype=None, order=None, subok=None, shape=None):
113116
order='C' is implied.
114117
115118
.. versionadded:: 1.17.0
119+
device : str, optional
120+
The device on which to place the created array. Default: None.
121+
122+
.. versionadded:: 2.0.0
116123
117124
Returns
118125
-------

0 commit comments

Comments
 (0)
0