8000 MAINT: Use the same exception for all bad axis requests by eric-wieser · Pull Request #8584 · numpy/numpy · GitHub
[go: up one dir, main page]

Skip to content

MAINT: Use the same exception for all bad axis requests #8584

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Feb 21, 2017
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
MAINT: Use normalize_axis_index in all python axis checking
As a result, some exceptions change from ValueError to IndexError

This also changes the exception types raised in places where
normalize_axis_index is not quite appropriate
  • Loading branch information
eric-wieser committed Feb 20, 2017
commit 370b6506f128460371484a50c813d66e64582f44
11 changes: 4 additions & 7 deletions numpy/core/numeric.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
inner, int_asbuffer, lexsort, matmul, may_share_memory,
min_scalar_type, ndarray, nditer, nested_iters, promote_types,
putmask, result_type, set_numeric_ops, shares_memory, vdot, where,
zeros)
zeros, normalize_axis_index)
if sys.version_info[0] < 3:
from .multiarray import newbuffer, getbuffer

Expand Down Expand Up @@ -1527,15 +1527,12 @@ def rollaxis(a, axis, start=0):

"""
n = a.ndim
if axis < 0:
axis += n
axis = normalize_axis_index(axis, n)
if start < 0:
start += n
msg = "'%s' arg requires %d <= %s < %d, but %d was passed in"
if not (0 <= axis < n):
raise ValueError(msg % ('axis', -n, 'axis', n, axis))
if not (0 <= start < n + 1):
raise ValueError(msg % ('start', -n, 'start', n + 1, start))
raise IndexError(msg % ('start', -n, 'start', n + 1, start))
if axis < start:
# it's been removed
start -= 1
Expand All @@ -1554,7 +1551,7 @@ def _validate_axis(axis, ndim, argname):
axis = list(axis)
axis = [a + ndim if a < 0 else a for a in axis]
if not builtins.all(0 <= a < ndim for a in axis):
raise ValueError('invalid axis for this array in `%s` argument' %
raise IndexError('invalid axis for this array in `%s` argument' %
argname)
if len(set(axis)) != len(axis):
raise ValueError('repeated axis in `%s` argument' % argname)
Expand Down
7 changes: 2 additions & 5 deletions numpy/core/shape_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from . import numeric as _nx
from .numeric import asanyarray, newaxis
from .multiarray import normalize_axis_index

def atleast_1d(*arys):
"""
Expand Down Expand Up @@ -347,11 +348,7 @@ def stack(arrays, axis=0):
raise ValueError('all input arrays must have the same shape')

result_ndim = arrays[0].ndim + 1
if not -result_ndim <= axis < result_ndim:
msg = 'axis {0} out of bounds [-{1}, {1})'.format(axis, result_ndim)
raise IndexError(msg)
if axis < 0:
axis += result_ndim
axis = normalize_axis_index(axis, result_ndim)

sl = (slice(None),) * axis + (_nx.newaxis,)
expanded_arrays = [arr[sl] for arr in arrays]
Expand Down
16 changes: 8 additions & 8 deletions numpy/core/tests/test_multiarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -2013,13 +2013,13 @@ def test_partition(self):
d = np.array([2, 1])
d.partition(0, kind=k)
assert_raises(ValueError, d.partition, 2)
assert_raises(ValueError, d.partition, 3, axis=1)
assert_raises(IndexError, d.partition, 3, axis=1)
assert_raises(ValueError, np.partition, d, 2)
assert_raises(ValueError, np.partition, d, 2, axis=1)
assert_raises(IndexError, np.partition, d, 2, axis=1)
assert_raises(ValueError, d.argpartition, 2)
assert_raises(ValueError, d.argpartition, 3, axis=1)
assert_raises(IndexError, d.argpartition, 3, axis=1)
assert_raises(ValueError, np.argpartition, d, 2)
assert_raises(ValueError, np.argpartition, d, 2, axis=1)
assert_raises(IndexError, np.argpartition, d, 2, axis=1)
d = np.arange(10).reshape((2, 5))
d.partition(1, axis=0, kind=k)
d.partition(4, axis=1, kind=k)
Expand Down Expand Up @@ -3522,16 +3522,16 @@ def test_object_argmin_with_NULLs(self):
class TestMinMax(TestCase):

def test_scalar(self):
assert_raises(ValueError, np.amax, 1, 1)
assert_raises(ValueError, np.amin, 1, 1)
assert_raises(IndexError, np.amax, 1, 1)
assert_raises(IndexError, np.amin, 1, 1)

assert_equal(np.amax(1, axis=0), 1)
assert_equal(np.amin(1, axis=0), 1)
assert_equal(np.amax(1, axis=None), 1)
assert_equal(np.amin(1, axis=None), 1)

def test_axis(self):
assert_raises(ValueError, np.amax, [1, 2, 3], 1000)
assert_raises(IndexError, np.amax, [1, 2, 3], 1000)
assert_equal(np.amax([[1, 2, 3]], axis=1), 3)

def test_datetime(self):
Expand Down Expand Up @@ -3793,7 +3793,7 @@ def test_object(self): # gh-6312

def test_invalid_axis(self): # gh-7528
x = np.linspace(0., 1., 42*3).reshape(42, 3)
assert_raises(ValueError, np.lexsort, x, axis=2)
assert_raises(IndexError, np.lexsort, x, axis=2)

class TestIO(object):
"""Test tofile, fromfile, tobytes, and fromstring"""
Expand Down
16 changes: 8 additions & 8 deletions numpy/core/tests/test_numeric.py
Original file line number Diff line number Diff line change
Expand Up @@ -1010,7 +1010,7 @@ def test_count_nonzero_axis(self):

assert_raises(ValueError, np.count_nonzero, m, axis=(1, 1))
assert_raises(TypeError, np.count_nonzero, m, axis='foo')
assert_raises(ValueError, np.count_nonzero, m, axis=3)
assert_raises(IndexError, np.count_nonzero, m, axis=3)
assert_raises(TypeError, np.count_nonzero,
m, axis=np.array([[1], [2]]))

Expand Down Expand Up @@ -2323,10 +2323,10 @@ class TestRollaxis(TestCase):

def test_exceptions(self):
a = np.arange(1*2*3*4).reshape(1, 2, 3, 4)
assert_raises(ValueError, np.rollaxis, a, -5, 0)
assert_raises(ValueError, np.rollaxis, a, 0, -5)
assert_raises(ValueError, np.rollaxis, a, 4, 0)
assert_raises(ValueError, np.rollaxis, a, 0, 5)
assert_raises(IndexError, np.rollaxis, a, -5, 0)
assert_raises(IndexError, np.rollaxis, a, 0, -5)
assert_raises(IndexError, np.rollaxis, a, 4, 0)
assert_raises(IndexError, np.rollaxis, a, 0, 5)

def test_results(self):
a = np.arange(1*2*3*4).reshape(1, 2, 3, 4).copy()
Expand Down Expand Up @@ -2413,11 +2413,11 @@ def test_move_multiples(self):

def test_errors(self):
x = np.random.randn(1, 2, 3)
assert_raises_regex(ValueError, 'invalid axis .* `source`',
assert_raises_regex(IndexError, 'invalid axis .* `source`',
np.moveaxis, x, 3, 0)
assert_raises_regex(ValueError, 'invalid axis .* `source`',
assert_raises_regex(IndexError, 'invalid axis .* `source`',
np.moveaxis, x, -4, 0)
assert_raises_regex(ValueError, 'invalid axis .* `destination`',
assert_raises_regex(IndexError, 'invalid axis .* `destination`',
np.moveaxis, x, 0, 5)
assert_raises_regex(ValueError, 'repeated axis in `source`',
np.moveaxis, x, [0, 0], [0, 1])
Expand Down
8 changes: 4 additions & 4 deletions numpy/core/tests/test_ufunc.py
Original file line number Diff line number Diff line change
Expand Up @@ -703,14 +703,14 @@ def test_zerosize_reduction(self):

def test_axis_out_of_bounds(self):
a = np.array([False, False])
assert_raises(ValueError, a.all, axis=1)
assert_raises(IndexError, a.all, axis=1)
a = np.array([False, False])
assert_raises(ValueError, a.all, axis=-2)
assert_raises(IndexError, a.all, axis=-2)

a = np.array([False, False])
assert_raises(ValueError, a.any, axis=1)
assert_raises(IndexError, a.any, axis=1)
a = np.array([False, False])
assert_raises(ValueError, a.any, axis=-2)
assert_raises(IndexError, a.any, axis=-2)

def test_scalar_reduction(self):
# The functions 'sum', 'prod', etc allow specifying axis=0
Expand Down
13 changes: 4 additions & 9 deletions numpy/lib/function_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from numpy.lib.twodim_base import diag
from .utils import deprecate
from numpy.core.multiarray import (
_insert, add_docstring, digitize, bincount,
_insert, add_docstring, digitize, bincount, normalize_axis_index,
interp as compiled_interp, interp_complex as compiled_interp_complex
)
from numpy.core.umath import _add_newdoc_ufunc as add_newdoc_ufunc
Expand Down Expand Up @@ -4828,14 +4828,7 @@ def insert(arr, obj, values, axis=None):
arr = arr.ravel()
ndim = arr.ndim
axis = ndim - 1
else:
if ndim > 0 and (axis < -ndim or axis >= ndim):
raise IndexError(
"axis %i is out of bounds for an array of "
"dimension %i" % (axis, ndim))
if (axis < 0):
axis += ndim
if (ndim == 0):
elif ndim == 0:
# 2013-09-24, 1.9
warnings.warn(
"in the future the special handling of scalars will be removed "
Expand All @@ -4846,6 +4839,8 @@ def insert(arr, obj, values, axis=None):
return wrap(arr)
else:
return arr
else:
axis = normalize_axis_index(axis, ndim)
slobj = [slice(None)]*ndim
N = arr.shape[axis]
newshape = list(arr.shape)
Expand Down
9 changes: 3 additions & 6 deletions numpy/lib/shape_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
asarray, zeros, outer, concatenate, isscalar, array, asanyarray
)
from numpy.core.fromnumeric import product, reshape, transpose
from numpy.core.multiarray import normalize_axis_index
from numpy.core import vstack, atleast_3d
from numpy.lib.index_tricks import ndindex
from numpy.matrixlib.defmatrix import matrix # this raises all the right alarm bells
Expand Down Expand Up @@ -96,10 +97,7 @@ def apply_along_axis(func1d, axis, arr, *args, **kwargs):
# handle negative axes
arr = asanyarray(arr)
nd = arr.ndim
if not (-nd <= axis < nd):
raise IndexError('axis {0} out of bounds [-{1}, {1})'.format(axis, nd))
if axis < 0:
axis += nd
axis = normalize_axis_index(axis, nd)

# arr, with the iteration axis at the end
in_dims = list(range(nd))
Expand Down Expand Up @@ -289,8 +287,7 @@ def expand_dims(a, axis):
"""
a = asarray(a)
shape = a.shape
if axis < 0:
axis = axis + len(shape) + 1
axis = normalize_axis_index(axis, a.ndim + 1)
return a.reshape(shape[:axis] + (1,) + shape[axis:])

row_stack = vstack
Expand Down
10 changes: 3 additions & 7 deletions numpy/linalg/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
finfo, errstate, geterrobj, longdouble, rollaxis, amin, amax, product, abs,
broadcast, atleast_2d, intp, asanyarray, isscalar, object_
)
from numpy.core.multiarray import normalize_axis_index
from numpy.lib import triu, asfarray
from numpy.linalg import lapack_lite, _umath_linalg
from numpy.matrixlib.defmatrix import matrix_power
Expand Down Expand Up @@ -2225,13 +2226,8 @@ def norm(x, ord=None, axis=None, keepdims=False):
return add.reduce(absx, axis=axis, keepdims=keepdims) ** (1.0 / ord)
elif len(axis) == 2:
row_axis, col_axis = axis
if row_axis < 0:
row_axis += nd
if col_axis < 0:
col_axis += nd
if not (0 <= row_axis < nd and 0 <= col_axis < nd):
raise ValueError('Invalid axis %r for an array with shape %r' %
(axis, x.shape))
row_axis = normalize_axis_index(row_axis, nd)
col_axis = normalize_axis_index(col_axis, nd)
if row_axis == col_axis:
raise ValueError('Duplicate axes given.')
if ord == 2:
Expand Down
4 changes: 2 additions & 2 deletions numpy/linalg/tests/test_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -1102,8 +1102,8 @@ def test_bad_args(self):
assert_raises(ValueError, norm, B, order, (1, 2))

# Invalid axis
assert_raises(ValueError, norm, B, None, 3)
assert_raises(ValueError, norm, B, None, (2, 3))
assert_raises(IndexError, norm, B, None, 3)
assert_raises(IndexError, norm, B, None, (2, 3))
assert_raises(ValueError, norm, B, None, (0, 1, 2))


Expand Down
15 changes: 9 additions & 6 deletions numpy/ma/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
getargspec, formatargspec, long, basestring, unicode, bytes, sixu
)
from numpy import expand_dims as n_expand_dims
from numpy.core.multiarray import normalize_axis_index


if sys.version_info[0] >= 3:
Expand Down Expand Up @@ -3902,7 +3903,9 @@ def __eq__(self, other):
axis = None
try:
mask = mask.view((bool_, len(self.dtype))).all(axis)
except ValueError:
except (ValueError, IndexError):
# TODO: what error are we trying to catch here?
# invalid axis, or invalid view?
mask = np.all([[f[n].all() for n in mask.dtype.names]
for f in mask], axis=axis)
check._mask = mask
Expand Down Expand Up @@ -3938,7 +3941,9 @@ def __ne__(self, other):
axis = None
try:
mask = mask.view((bool_, len(self.dtype))).all(axis)
except ValueError:
except (ValueError, IndexError):
# TODO: what error are we trying to catch here?
# invalid axis, or invalid view?
mask = np.all([[f[n].all() for n in mask.dtype.names]
for f in mask], axis=axis)
check._mask = mask
Expand Down Expand Up @@ -4340,19 +4345,17 @@ def count(self, axis=None, keepdims=np._NoValue):

if self.shape is ():
if axis not in (None, 0):
raise ValueError("'axis' entry is out of bounds")
raise IndexError("'axis' entry is out of bounds")
return 1
elif axis is None:
if kwargs.get('keepdims', False):
return np.array(self.size, dtype=np.intp, ndmin=self.ndim)
return self.size

axes = axis if isinstance(axis, tuple) else (axis,)
axes = tuple(a if a >= 0 else self.ndim + a for a in axes)
axes = tuple(normalize_axis_index(a, self.ndim) for a in axes)
if len(axes) != len(set(axes)):
raise ValueError("duplicate value in 'axis'")
if builtins.any(a < 0 or a >= self.ndim for a in axes):
raise ValueError("'axis' entry is out of bounds")
items = 1
for ax in axes:
items *= self.shape[ax]
Expand Down
11 changes: 4 additions & 7 deletions numpy/ma/extras.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import numpy as np
from numpy import ndarray, array as nxarray
import numpy.core.umath as umath
from numpy.core.multiarray import normalize_axis_index
from numpy.lib.function_base import _ureduce
from numpy.lib.index_tricks import AxisConcatenator

Expand Down Expand Up @@ -380,11 +381,7 @@ def apply_along_axis(func1d, axis, arr, *args, **kwargs):
"""
arr = array(arr, copy=False, subok=True)
nd = arr.ndim
if axis < 0:
axis += nd
if (axis >= nd):
raise ValueError("axis must be less than arr.ndim; axis=%d, rank=%d."
% (axis, nd))
axis = normalize_axis_index(axis, nd)
ind = [0] * (nd - 1)
i = np.zeros(nd, 'O')
indlist = list(range(nd))
Expand Down Expand Up @@ -717,8 +714,8 @@ def _median(a, axis=None, out=None, overwrite_input=False):

if axis is None:
axis = 0
elif axis < 0:
axis += asorted.ndim
else:
axis = normalize_axis_index(axis, asorted.ndim)

if asorted.ndim == 1:
counts = count(asorted)
Expand Down
8 changes: 4 additions & 4 deletions numpy/ma/tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1030,7 +1030,7 @@ def test_count_func(self):
res = count(ott, 0)
assert_(isinstance(res, ndarray))
assert_(res.dtype.type is np.intp)
assert_raises(ValueError, ott.count, axis=1)
assert_raises(IndexError, ott.count, axis=1)

def test_minmax_func(self):
# Tests minimum and maximum.
Expand Down Expand Up @@ -4409,7 +4409,7 @@ def test_count(self):
assert_equal(count(a, axis=(0,1), keepdims=True), 4*ones((1,1,4)))
assert_equal(count(a, axis=-2), 2*ones((2,4)))
assert_raises(ValueError, count, a, axis=(1,1))
assert_raises(ValueError, count, a, axis=3)
assert_raises(IndexError, count, a, axis=3)

# check the 'nomask' path
a = np.ma.array(d, mask=nomask)
Expand All @@ -4423,13 +4423,13 @@ def test_count(self):
assert_equal(count(a, axis=(0,1), keepdims=True), 6*ones((1,1,4)))
assert_equal(count(a, axis=-2), 3*ones((2,4)))
assert_raises(ValueError, count, a, axis=(1,1))
assert_raises(ValueError, count, a, axis=3)
assert_raises(IndexError, count, a, axis=3)

# check the 'masked' singleton
assert_equal(count(np.ma.masked), 0)

# check 0-d arrays do not allow axis > 0
assert_raises(ValueError, count, np.ma.array(1), axis=1)
assert_raises(IndexError, count, np.ma.array(1), axis=1)


class TestMaskedConstant(TestCase):
Expand Down
Loading
0