8000 Merge pull request #9133 from charris/deprecate-expand_dims-bad-axis · numpy/numpy@65c1a50 · GitHub
[go: up one dir, main page]

Skip to content

Commit 65c1a50

Browse files
authored
Merge pull request #9133 from charris/deprecate-expand_dims-bad-axis
DEP: Deprecate incorrect behavior of expand_dims.
2 parents 1b53503 + 7415596 commit 65c1a50
< 10000 /div>

File tree

3 files changed

+43
-4
lines changed

3 files changed

+43
-4
lines changed

doc/release/1.13.0-notes.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,9 @@ Deprecations
5353
with ``np.minimum``.
5454
* Calling ``ndarray.conjugate`` on non-numeric dtypes is deprecated (it
5555
should match the behavior of ``np.conjugate``, which throws an error).
56+
* Calling ``expand_dims`` when the ``axis`` keyword does not satisfy
57+
``-a.ndim - 1 <= axis <= a.ndim``, where ``a`` is the array being reshaped,
58+
is deprecated.
5659

5760

5861
Future Changes

numpy/lib/shape_base.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -240,14 +240,20 @@ def expand_dims(a, axis):
240240
"""
241241
Expand the shape of an array.
242242
243-
Insert a new axis, corresponding to a given position in the array shape.
243+
Insert a new axis that will appear at the `axis` position in the expanded
244+
array shape.
245+
246+
.. note:: Previous to NumPy 1.13.0, neither ``axis < -a.ndim - 1`` nor
247+
``axis > a.ndim`` raised errors or put the new axis where documented.
248+
Those axis values are now deprecated and will raise an AxisError in the
249+
future.
244250
245251
Parameters
246252
----------
247253
a : array_like
248254
Input array.
249255
axis : int
250-
Position (amongst axes) where new axis is to be inserted.
256+
Position in the expanded axes where the new axis is placed.
251257
252258
Returns
253259
-------
@@ -291,7 +297,16 @@ def expand_dims(a, axis):
291297
"""
292298
a = asarray(a)
293299
shape = a.shape
294-
axis = normalize_axis_index(axis, a.ndim + 1)
300+
if axis > a.ndim or axis < -a.ndim - 1:
301+
# 2017-05-17, 1.13.0
302+
warnings.warn("Both axis > a.ndim and axis < -a.ndim - 1 are "
303+
"deprecated and will raise an AxisError in the future.",
304+
DeprecationWarning, stacklevel=2)
305+
# When the deprecation period expires, delete this if block,
306+
if axis < 0:
307+
axis = axis + a.ndim + 1
308+
# and uncomment the following line.
309+
# axis = normalize_axis_index(axis, a.ndim + 1)
295310
return a.reshape(shape[:axis] + (1,) + shape[axis:])
296311

297312
row_stack = vstack

numpy/lib/tests/test_shape_base.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
from __future__ import division, absolute_import, print_function
22

33
import numpy as np
4+
import warnings
5+
46
from numpy.lib.shape_base import (
57
apply_along_axis, apply_over_axes, array_split, split, hsplit, dsplit,
6-
vsplit, dstack, column_stack, kron, tile
8+
vsplit, dstack, column_stack, kron, tile, expand_dims,
79
)
810
from numpy.testing import (
911
run_module_suite, TestCase, assert_, assert_equal, assert_array_equal,
@@ -182,6 +184,25 @@ def test_simple(self):
182184
assert_array_equal(aoa_a, np.array([[[60], [92], [124]]]))
183185

184186

187+
class TestExpandDims(TestCase):
188+
def test_functionality(self):
189+
s = (2, 3, 4, 5)
190+
a = np.empty(s)
191+
for axis in range(-5, 4):
192+
b = expand_dims(a, axis)
193+
assert_(b.shape[axis] == 1)
194+
assert_(np.squeeze(b).shape == s)
195+
196+
def test_deprecations(self):
197+
# 2017-05-17, 1.13.0
198+
s = (2, 3, 4, 5)
199+
a = np.empty(s)
200+
with warnings.catch_warnings():
201+
warnings.simplefilter("always")
202+
assert_warns(DeprecationWarning, expand_dims, a, -6)
203+
assert_warns(DeprecationWarning, expand_dims, a, 5)
204+
205+
185206
class TestArraySplit(TestCase):
186207
def test_integer_0_split(self):
187208
a = np.arange(10)

0 commit comments

Comments
 (0)
0