10000 WIP: MAINT: change list-of-array to tuple-of-array returns (Numba com… · numpy/numpy@6461197 · GitHub
[go: up one dir, main page]

Skip to content

Commit 6461197

Browse files
committed
WIP: MAINT: change list-of-array to tuple-of-array returns (Numba compat)
Functions in NumPy that return lists of arrays are problematic for Numba. See numba/numba#8008 for a detailed discussion on that. This changes the return types to tuples, which are easier to support for Numba, because tuples are immutable. This change is not backwards-compatible. Estimated impact: - idiomatic end user code should continue to work unchanged, - code that attempts to append or otherwise mutate the list will start raising an exception, which should be easy to fix, - user code that does `if isinstance(..., list):` on the return value of a function like `atleast1d` will break. This should be rare, but since it may not result in a clean error it is probably the place with the highest impact. - user code with explicit `list[NDArray]` type annotations will need to be updated. [skip cirrus] [skip circle]
1 parent 220f0ab commit 6461197

File tree

4 files changed

+52
-50
lines changed

4 files changed

+52
-50
lines changed

numpy/_core/shape_base.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def atleast_1d(*arys):
3636
Returns
3737
-------
3838
ret : ndarray
39-
An array, or list of arrays, each with ``a.ndim >= 1``.
39+
An array, or tuple of arrays, each with ``a.ndim >= 1``.
4040
Copies are made only if necessary.
4141
4242
See Also
@@ -57,7 +57,7 @@ def atleast_1d(*arys):
5757
True
5858
5959
>>> np.atleast_1d(1, [3, 4])
60-
[array([1]), array([3, 4])]
60+
(array([1]), array([3, 4]))
6161
6262
"""
6363
res = []
@@ -71,7 +71,7 @@ def atleast_1d(*arys):
7171
if len(res) == 1:
7272
return res[0]
7373
else:
74-
return res
74+
return tuple(res)
7575

7676

7777
def _atleast_2d_dispatcher(*arys):
@@ -93,7 +93,7 @@ def atleast_2d(*arys):
9393
Returns
9494
-------
9595
res, res2, ... : ndarray
96-
An array, or list of arrays, each with ``a.ndim >= 2``.
96+
An array, or tuple of arrays, each with ``a.ndim >= 2``.
9797
Copies are avoided where possible, and views with two or more
9898
dimensions are returned.
9999
@@ -113,7 +113,7 @@ def atleast_2d(*arys):
113113
True
114114
115115
>>> np.atleast_2d(1, [1, 2], [[1, 2]])
116-
[array([[1]]), array([[1, 2]]), array([[1, 2]])]
116+
(array([[1]]), array([[1, 2]]), array([[1, 2]]))
117117
118118
"""
119119
res = []
@@ -129,7 +129,7 @@ def atleast_2d(*arys):
129129
if len(res) == 1:
130130
return res[0]
131131
else:
132-
return res
132+
return tuple(res)
133133

134134

135135
def _atleast_3d_dispatcher(*arys):
@@ -151,7 +151,7 @@ def atleast_3d(*arys):
151151
Returns
152152
-------
153153
res1, res2, ... : ndarray
154-
An array, or list of arrays, each with ``a.ndim >= 3``. Copies are
154+
An array, or tuple of arrays, each with ``a.ndim >= 3``. Copies are
155155
avoided where possible, and views with three or more dimensions are
156156
returned. For example, a 1-D array of shape ``(N,)`` becomes a view
157157
of shape ``(1, N, 1)``, and a 2-D array of shape ``(M, N)`` becomes a
@@ -201,7 +201,7 @@ def atleast_3d(*arys):
201201
if len(res) == 1:
202202
return res[0]
203203
else:
204-
return res
204+
return tuple(res)
205205

206206

207207
def _arrays_for_stack_dispatcher(arrays):
@@ -282,7 +282,7 @@ def vstack(tup, *, dtype=None, casting="same_kind"):
282282
283283
"""
284284
arrs = atleast_2d(*tup)
285-
if not isinstance(arrs, list):
285+
if not isinstance(arrs, tuple):
286286
arrs = [arrs]
287287
return _nx.concatenate(arrs, 0, dtype=dtype, casting=casting)
288288

@@ -349,7 +349,7 @@ def hstack(tup, *, dtype=None, casting="same_kind"):
349349
350350
"""
351351
arrs = atleast_1d(*tup)
352-
if not isinstance(arrs, list):
352+
if not isinstance(arrs, tuple):
353353
arrs = [arrs]
354354
# As a special case, dimension 0 of 1-dimensional arrays is "horizontal"
355355
if arrs and arrs[0].ndim == 1:

numpy/lib/_shape_base_impl.py

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -726,7 +726,7 @@ def dstack(tup):
726726
727727
"""
728728
arrs = atleast_3d(*tup)
729-
if not isinstance(arrs, list):
729+
if not isinstance(arrs, tuple):
730730
arrs = [arrs]
731731
return _nx.concatenate(arrs, 2)
732732

@@ -764,11 +764,11 @@ def array_split(ary, indices_or_sections, axis=0):
764764
--------
765765
>>> x = np.arange(8.0)
766766
>>> np.array_split(x, 3)
767-
[array([0., 1., 2.]), array([3., 4., 5.]), array([6., 7.])]
767+
(array([0., 1., 2.]), array([3., 4., 5.]), array([6., 7.]))
768768
769769
>>> x = np.arange(9)
770770
>>> np.array_split(x, 4)
771-
[array([0, 1, 2]), array([3, 4]), array([5, 6]), array([7, 8])]
771+
(array([0, 1, 2]), array([3, 4]), array([5, 6]), array([7, 8]))
772772
773773
"""
774774
try:
@@ -797,7 +797,7 @@ def array_split(ary, indices_or_sections, axis=0):
797797
end = div_points[i + 1]
798798
sub_arys.append(_nx.swapaxes(sary[st:end], axis, 0))
799799

800-
return sub_arys
800+
return tuple(sub_arys)
801801

802802

803803
def _split_dispatcher(ary, indices_or_sections, axis=None):
@@ -833,8 +833,8 @@ def split(ary, indices_or_sections, axis=0):
833833
834834
Returns
835835
-------
836-
sub-arrays : list of ndarrays
837-
A list of sub-arrays as views into `ary`.
836+
sub-arrays : tuple of ndarrays
837+
A tuple of sub-arrays as views into `ary`.
838838
839839
Raises
840840
------
@@ -860,15 +860,15 @@ def split(ary, indices_or_sections, axis=0):
860860
--------
861861
>>> x = np.arange(9.0)
862862
>>> np.split(x, 3)
863-
[array([0., 1., 2.]), array([3., 4., 5.]), array([6., 7., 8.])]
863+
(array([0., 1., 2.]), array([3., 4., 5.]), array([6., 7., 8.]))
864864
865865
>>> x = np.arange(8.0)
866866
>>> np.split(x, [3, 5, 6, 10])
867-
[array([0., 1., 2.]),
867+
(array([0., 1., 2.]),
868868
array([3., 4.]),
869869
array([5.]),
870870
array([6., 7.]),
871-
array([], dtype=float64)]
871+
array([], dtype=float64))
872872
873873
"""
874874
try:
@@ -936,16 +936,16 @@ def hsplit(ary, indices_or_sections):
936936
[[4., 5.],
937937
[6., 7.]]])
938938
>>> np.hsplit(x, 2)
939-
[array([[[0., 1.]],
939+
(array([[[0., 1.]],
940940
[[4., 5.]]]),
941941
array([[[2., 3.]],
942-
[[6., 7.]]])]
942+
[[6., 7.]]]))
943943
944944
With a 1-D array, the split is along axis 0.
945945
946946
>>> x = np.array([0, 1, 2, 3, 4, 5])
947947
>>> np.hsplit(x, 2)
948-
[array([0, 1, 2]), array([3, 4, 5])]
948+
(array([0, 1, 2]), array([3, 4, 5]))
949949
950950
"""
951951
if _nx.ndim(ary) == 0:
@@ -978,13 +978,13 @@ def vsplit(ary, indices_or_sections):
978978
[ 8., 9., 10., 11.],
979979
[12., 13., 14., 15.]])
980980
>>> np.vsplit(x, 2)
981-
[array([[0., 1., 2., 3.],
981+
(array([[0., 1., 2., 3.],
982982
[4., 5., 6., 7.]]), array([[ 8., 9., 10., 11.],
983-
[12., 13., 14., 15.]])]
983+
[12., 13., 14., 15.]]))
984984
>>> np.vsplit(x, np.array([3, 6]))
985-
[array([[ 0., 1., 2., 3.],
985+
(array([[ 0., 1., 2., 3.],
986986
[ 4., 5., 6., 7.],
987-
[ 8., 9., 10., 11.]]), array([[12., 13., 14., 15.]]), array([], shape=(0, 4), dtype=float64)]
987+
[ 8., 9., 10., 11.]]), array([[12., 13., 14., 15.]]), array([], shape=(0, 4), dtype=float64))
988988
989989
With a higher dimensional array the split is still along the first axis.
990990
@@ -995,9 +995,9 @@ def vsplit(ary, indices_or_sections):
995995
[[4., 5.],
996996
[6., 7.]]])
997997
>>> np.vsplit(x, 2)
998-
[array([[[0., 1.],
998+
(array([[[0., 1.],
999999
[2., 3.]]]), array([[[4., 5.],
1000-
[6., 7.]]])]
1000+
[6., 7.]]]))
10011001
10021002
"""
10031003
if _nx.ndim(ary) < 2:
@@ -1027,23 +1027,23 @@ def dsplit(ary, indices_or_sections):
10271027
[[ 8., 9., 10., 11.],
10281028
[12., 13., 14., 15.]]])
10291029
>>> np.dsplit(x, 2)
1030-
[array([[[ 0., 1.],
1030+
(array([[[ 0., 1.],
10311031
[ 4., 5.]],
10321032
[[ 8., 9.],
10331033
[12., 13.]]]), array([[[ 2., 3.],
10341034
[ 6., 7.]],
10351035
[[10., 11.],
1036-
[14., 15.]]])]
1036+
[14., 15.]]]))
10371037
>>> np.dsplit(x, np.array([3, 6]))
1038-
[array([[[ 0., 1., 2.],
1038+
(array([[[ 0., 1., 2.],
10391039
[ 4., 5., 6.]],
10401040
[[ 8., 9., 10.],
10411041
[12., 13., 14.]]]),
10421042
array([[[ 3.],
10431043
[ 7.]],
10441044
[[11.],
10451045
[15.]]]),
1046-
array([], shape=(2, 2, 0), dtype=float64)]
1046+
array([], shape=(2, 2, 0), dtype=float64))
10471047
"""
10481048
if _nx.ndim(ary) < 3:
10491049
raise ValueError('dsplit only works on arrays of 3 or more dimensions')

numpy/lib/_shape_base_impl.pyi

Lines changed: 10 additions & 10 deletions
10000
Original file line numberDiff line numberDiff line change
@@ -115,59 +115,59 @@ def array_split(
115115
ary: _ArrayLike[_SCT],
116116
indices_or_sections: _ShapeLike,
117117
axis: SupportsIndex = ...,
118-
) -> list[NDArray[_SCT]]: ...
118+
) -> tuple[NDArray[_SCT]]: ...
119119
@overload
120120
def array_split(
121121
ary: ArrayLike,
122122
indices_or_sections: _ShapeLike,
123123
axis: SupportsIndex = ...,
124-
) -> list[NDArray[Any]]: ...
124+
) -> tuple[NDArray[Any]]: ...
125125

126126
@overload
127127
def split(
128128
ary: _ArrayLike[_SCT],
129129
indices_or_sections: _ShapeLike,
130130
axis: SupportsIndex = ...,
131-
) -> list[NDArray[_SCT]]: ...
131+
) -> tuple[NDArray[_SCT]]: ...
132132
@overload
133133
def split(
134134
ary: ArrayLike,
135135
indices_or_sections: _ShapeLike,
136136
axis: SupportsIndex = ...,
137-
) -> list[NDArray[Any]]: ...
137+
) -> tuple[NDArray[Any]]: ...
138138

139139
@overload
140140
def hsplit(
141141
ary: _ArrayLike[_SCT],
142142
indices_or_sections: _ShapeLike,
143-
) -> list[NDArray[_SCT]]: ...
143+
) -> tuple[NDArray[_SCT]]: ...
144144
@overload
145145
def hsplit(
146146
ary: ArrayLike,
147147
indices_or_sections: _ShapeLike,
148-
) -> list[NDArray[Any]]: ...
148+
) -> tuple[NDArray[Any]]: ...
149149

150150
@overload
151151
def vsplit(
152152
ary: _ArrayLike[_SCT],
153153
indices_or_sections: _ShapeLike,
154-
) -> list[NDArray[_SCT]]: ...
154+
) -> tuple[NDArray[_SCT]]: ...
155155
@overload
156156
def vsplit(
157157
ary: ArrayLike,
158158
indices_or_sections: _ShapeLike,
159-
) -> list[NDArray[Any]]: ...
159+
) -> tuple[NDArray[Any]]: ...
160160

161161
@overload
162162
def dsplit(
163163
ary: _ArrayLike[_SCT],
164164
indices_or_sections: _ShapeLike,
165-
) -> list[NDArray[_SCT]]: ...
165+
) -> tuple[NDArray[_SCT]]: ...
166166
@overload
167167
def dsplit(
168168
ary: ArrayLike,
169169
indices_or_sections: _ShapeLike,
170-
) -> list[NDArray[Any]]: ...
170+
) -> tuple[NDArray[Any]]: ...
171171

172172
@overload
173173
def get_array_wrap(*args: _SupportsArrayWrap) -> _ArrayWrap: ...

numpy/lib/_stride_tricks_impl.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -502,7 +502,7 @@ def broadcast_arrays(*args, subok=False):
502502
503503
Returns
504504
-------
505-
broadcasted : list of arrays
505+
broadcasted : tuple of arrays
506506
These arrays are views on the original arrays. They are typically
507507
not contiguous. Furthermore, more than one element of a
508508
broadcasted array may refer to a single memory location. If you need
@@ -526,17 +526,19 @@ def broadcast_arrays(*args, subok=False):
526526
>>> x = np.array([[1,2,3]])
527527
>>> y = np.array([[4],[5]])
528528
>>> np.broadcast_arrays(x, y)
529-
[array([[1, 2, 3],
530-
[1, 2, 3]]), array([[4, 4, 4],
531-
[5, 5, 5]])]
529+
(array([[1, 2, 3],
530+
[1, 2, 3]]),
531+
array([[4, 4, 4],
532+
[5, 5, 5]]))
532533
533534
Here is a useful idiom for getting contiguous copies instead of
534535
non-contiguous views.
535536
536537
>>> [np.array(a) for a in np.broadcast_arrays(x, y)]
537538
[array([[1, 2, 3],
538-
[1, 2, 3]]), array([[4, 4, 4],
539-
[5, 5, 5]])]
539+
[1, 2, 3]]),
540+
array([[4, 4, 4],
541+
[5, 5, 5]])]
540542
541543
"""
542544
# nditer is not used here to avoid the limit of 32 arrays.
@@ -552,5 +554,5 @@ def broadcast_arrays(*args, subok=False):
552554
# Common case where nothing needs to be broadcasted.
553555
return args
554556

555-
return [_broadcast_to(array, shape, subok=subok, readonly=False)
556-
for array in args]
557+
return tuple([_broadcast_to(array, shape, subok=subok, readonly=False)
558+
for array in args])

0 commit comments

Comments
 (0)
0