8000 ENH: change list-of-array to tuple-of-array returns (Numba compat) (#… · numpy/numpy@010e841 · GitHub
[go: up one dir, main page]

Skip to content
8000

Commit 010e841

Browse files
authored
ENH: change list-of-array to tuple-of-array returns (Numba compat) (#25570)
Functions in NumPy that return lists of arrays are problematic for Numba. See numba/numba#8008 for a detailed discussion on that. This changes the return types to tuples, which are easier to support for Numba, because tuples are immutable. This change is not backwards-compatible. Estimated impact: - idiomatic end user code should continue to work unchanged, - code that attempts to append or otherwise mutate the list will start raising an exception, which should be easy to fix, - user code that does `if isinstance(..., list):` on the return value of a function like `atleast1d` will break. This should be rare, but since it may not result in a clean error it is probably the place with the highest impact. - user code with explicit `list[NDArray]` type annotations will need to be updated. Includes some small docstring improvements to make clearer what is returned.
1 parent 6bd3abf commit 010e841

19 files changed

+139
-113
lines changed
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
Functions that returned a list of ndarrays have been changed to return a tuple
2+
of ndarrays instead. Returning tuples consistently whenever a sequence of
3+
arrays is returned makes it easier for JIT compilers like Numba, as well as for
4+
static type checkers in some cases, to support these functions. Changed
5+
functions are: ``atleast_1d``, ``atleast_2d``, ``atleast_3d``, ``broadcast_arrays``,
6+
``array_split``, ``split``, ``hsplit``, ``vsplit``, ``dsplit``, ``meshgrid``,
7+
``ogrid``, ``histogramdd``.

numpy/_core/shape_base.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def atleast_1d(*arys):
3636
Returns
3737
-------
3838
ret : ndarray
39-
An array, or list of arrays, each with ``a.ndim >= 1``.
39+
An array, or tuple of arrays, each with ``a.ndim >= 1``.
4040
Copies are made only if necessary.
4141
4242
See Also
@@ -57,7 +57,7 @@ def atleast_1d(*arys):
5757
True
5858
5959
>>> np.atleast_1d(1, [3, 4])
60-
[array([1]), array([3, 4])]
60+
(array([1]), array([3, 4]))
6161
6262
"""
6363
res = []
@@ -71,7 +71,7 @@ def atleast_1d(*arys):
7171
if len(res) == 1:
7272
return res[0]
7373
else:
74-
return res
74+
return tuple(res)
7575

7676

7777
def _atleast_2d_dispatcher(*arys):
@@ -93,7 +93,7 @@ def atleast_2d(*arys):
9393
Returns
9494
-------
9595
res, res2, ... : ndarray
96-
An array, or list of arrays, each with ``a.ndim >= 2``.
96+
An array, or tuple of arrays, each with ``a.ndim >= 2``.
9797
Copies are avoided where possible, and views with two or more
9898
dimensions are returned.
9999
@@ -113,7 +113,7 @@ def atleast_2d(*arys):
113113
True
114114
115115
>>> np.atleast_2d(1, [1, 2], [[1, 2]])
116-
[array([[1]]), array([[1, 2]]), array([[1, 2]])]
116+
(array([[1]]), array([[1, 2]]), array([[1, 2]]))
117117
118118
"""
119119
res = []
@@ -129,7 +129,7 @@ def atleast_2d(*arys):
129129
if len(res) == 1:
130130
return res[0]
131131
else:
132-
return res
132+
return tuple(res)
133133

134134

135135
def _atleast_3d_dispatcher(*arys):
@@ -151,7 +151,7 @@ def atleast_3d(*arys):
151151
Returns
152152
-------
153153
res1, res2, ... : ndarray
154-
An array, or list of arrays, each with ``a.ndim >= 3``. Copies are
154+
An array, or tuple of arrays, each with ``a.ndim >= 3``. Copies are
155155
avoided where possible, and views with three or more dimensions are
156156
returned. For example, a 1-D array of shape ``(N,)`` becomes a view
157157
of shape ``(1, N, 1)``, and a 2-D array of shape ``(M, N)`` becomes a
@@ -201,7 +201,7 @@ def atleast_3d(*arys):
201201
if len(res) == 1:
202202
return res[0]
203203
else:
204-
return res
204+
return tuple(res)
205205

206206

207207
def _arrays_for_stack_dispatcher(arrays):
@@ -282,8 +282,8 @@ def vstack(tup, *, dtype=None, casting="same_kind"):
282282
283283
"""
284284
arrs = atleast_2d(*tup)
285-
if not isinstance(arrs, list):
286-
arrs = [arrs]
285+
if not isinstance(arrs, tuple):
286+
arrs = (arrs,)
287287
return _nx.concatenate(arrs, 0, dtype=dtype, casting=casting)
288288

289289

@@ -349,8 +349,8 @@ def hstack(tup, *, dtype=None, casting="same_kind"):
349349
350350
"""
351351
arrs = atleast_1d(*tup)
352-
if not isinstance(arrs, list):
353-
arrs = [arrs]
352+
if not isinstance(arrs, tuple):
353+
arrs = (arrs,)
354354
# As a special case, dimension 0 of 1-dimensional arrays is "horizontal"
355355
if arrs and arrs[0].ndim == 1:
356356
return _nx.concatenate(arrs, 0, dtype=dtype, casting=casting)

numpy/_core/shape_base.pyi

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,21 +20,21 @@ def atleast_1d(arys: _ArrayLike[_SCT], /) -> NDArray[_SCT]: ...
2020
@overload
2121
def atleast_1d(arys: ArrayLike, /) -> NDArray[Any]: ...
2222
@overload
23-
def atleast_1d(*arys: ArrayLike) -> list[NDArray[Any]]: ...
23+
def atleast_1d(*arys: ArrayLike) -> tuple[NDArray[Any], ...]: ...
2424

2525
@overload
2626
def atleast_2d(arys: _ArrayLike[_SCT], /) -> NDArray[_SCT]: ...
2727
@overload
2828
def atleast_2d(arys: ArrayLike, /) -> NDArray[Any]: ...
2929
@overload
30-
def atleast_2d(*arys: ArrayLike) -> list[NDArray[Any]]: ...
30+
def atleast_2d(*arys: ArrayLike) -> tuple[NDArray[Any], ...]: ...
3131

3232
@overload
3333
def atleast_3d(arys: _ArrayLike[_SCT], /) -> NDArray[_SCT]: ...
3434
@overload
3535
def atleast_3d(arys: ArrayLike, /) -> NDArray[Any]: ...
3636
@overload
37-
def atleast_3d(*arys: ArrayLike) -> list[NDArray[Any]]: ...
37+
def atleast_3d(*arys 10000 : ArrayLike) -> tuple[NDArray[Any], ...]: ...
3838

3939
@overload
4040
def vstack(

numpy/lib/_function_base_impl.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4986,7 +4986,7 @@ def _meshgrid_dispatcher(*xi, copy=None, sparse=None, indexing=None):
49864986
@array_function_dispatch(_meshgrid_dispatcher)
49874987
def meshgrid(*xi, copy=True, sparse=False, indexing='xy'):
49884988
"""
4989-
Return a list of coordinate matrices from coordinate vectors.
4989+
Return a tuple of coordinate matrices from coordinate vectors.
49904990
49914991
Make N-D coordinate arrays for vectorized evaluations of
49924992
N-D scalar/vector fields over N-D grids, given
@@ -5027,7 +5027,7 @@ def meshgrid(*xi, copy=True, sparse=False, indexing='xy'):
50275027
50285028
Returns
50295029
-------
5030-
X1, X2,..., XN : list of ndarrays
5030+
X1, X2,..., XN : tuple of ndarrays
50315031
For vectors `x1`, `x2`,..., `xn` with lengths ``Ni=len(xi)``,
50325032
returns ``(N1, N2, N3,..., Nn)`` shaped arrays if indexing='ij'
50335033
or ``(N2, N1, N3,..., Nn)`` shaped arrays if indexing='xy'
@@ -5136,7 +5136,7 @@ def meshgrid(*xi, copy=True, sparse=False, indexing='xy'):
51365136
output = np.broadcast_arrays(*output, subok=True)
51375137

51385138
if copy:
5139-
output = [x.copy() for x in output]
5139+
output = tuple(x.copy() for x in output)
51405140

51415141
return output
51425142

numpy/lib/_function_base_impl.pyi

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -626,7 +626,7 @@ def meshgrid(
626626
copy: bool = ...,
627627
sparse: bool = ...,
628628
indexing: L["xy", "ij"] = ...,
629-
) -> list[NDArray[Any]]: ...
629+
) -> tuple[NDArray[Any], ...]: ...
630630

631631
@overload
632632
def delete(

numpy/lib/_histograms_impl.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -957,8 +957,8 @@ def histogramdd(sample, bins=10, range=None, density=None, weights=None):
957957
H : ndarray
958958
The multidimensional histogram of sample x. See density and weights
959959
for the different possible semantics.
960-
edges : list
961-
A list of D arrays describing the bin edges for each dimension.
960+
edges : tuple of ndarrays
961+
A tuple of D arrays describing the bin edges for each dimension.
962962
963963
See Also
964964
--------

numpy/lib/_histograms_impl.pyi

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,4 +44,4 @@ def histogramdd(
4444
range: Sequence[tuple[float, float]] = ...,
4545
density: None | bool = ...,
4646
weights: None | ArrayLike = ...,
47-
) -> tuple[NDArray[Any], list[NDArray[Any]]]: ...
47+
) -> tuple[NDArray[Any], tuple[NDArray[Any], ...]]: ...

numpy/lib/_index_tricks_impl.py

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,8 @@ def __getitem__(self, key):
189189
slobj[k] = slice(None, None)
190190
nn[k] = nn[k][tuple(slobj)]
191191
slobj[k] = _nx.newaxis
192-
return nn
192+
return tuple(nn) # ogrid -> tuple of arrays
193+
return nn # mgrid -> ndarray
193194
except (IndexError, TypeError):
194195
step = key.step
195196
stop = key.stop
@@ -225,8 +226,9 @@ class MGridClass(nd_grid):
225226
226227
Returns
227228
-------
228-
mesh-grid
229-
`ndarray`\\ s all of the same dimensions
229+
mesh-grid : ndarray
230+
A single array, containing a set of `ndarray`\\ s all of the same
231+
dimensions. stacked along the first axis.
230232
231233
See Also
232234
--------
@@ -251,6 +253,13 @@ class MGridClass(nd_grid):
251253
>>> np.mgrid[-1:1:5j]
252254
array([-1. , -0.5, 0. , 0.5, 1. ])
253255
256+
>>> np.mgrid[0:4].shape
257+
(4,)
258+
>>> np.mgrid[0:4, 0:5].shape
259+
(2, 4, 5)
260+
>>> np.mgrid[0:4, 0:5, 0:6].shape
261+
(3, 4, 5, 6)
262+
254263
"""
255264

256265
def __init__(self):
@@ -277,8 +286,10 @@ class OGridClass(nd_grid):
277286
278287
Returns
279288
-------
280-
mesh-grid
281-
`ndarray`\\ s with only one dimension not equal to 1
289+
mesh-grid : ndarray or tuple of ndarrays
290+
If the input is a single slice, returns an array.
291+
If the input is multiple slices, returns a tuple of arrays, with
292+
only one dimension not equal to 1.
282293
283294
See Also
284295
--------
@@ -292,12 +303,13 @@ class OGridClass(nd_grid):
292303
>>> from numpy import ogrid
293304
>>> ogrid[-1:1:5j]
294305
array([-1. , -0.5, 0. , 0.5, 1. ])
295-
>>> ogrid[0:5,0:5]
296-
[array([[0],
306+
>>> ogrid[0:5, 0:5]
307+
(array([[0],
297308
[1],
298309
[2],
299310
[3],
300-
[4]]), array([[0, 1, 2, 3, 4]])]
311+
[4]]),
312+
array([[0, 1, 2, 3, 4]]))
301313
302314
"""
303315

numpy/lib/_index_tricks_impl.pyi

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ class nd_grid(Generic[_BoolType]):
7474
def __getitem__(
7575
self: nd_grid[Literal[True]],
7676
key: slice | Sequence[slice],
77-
) -> list[NDArray[Any]]: ...
77+
) -> tuple[NDArray[Any], ...]: ...
7878

7979
class MGridClass(nd_grid[Literal[False]]):
8080
def __init__(self) -> None: ...

0 commit comments

Comments
 (0)
0