diff --git a/numpy/add_newdocs.py b/numpy/add_newdocs.py index 09f4e40c4d7a..3916d13045e7 100644 --- a/numpy/add_newdocs.py +++ b/numpy/add_newdocs.py @@ -6728,6 +6728,51 @@ def luf(lamdaexpr, *args, **kwargs): 53 """) +add_newdoc('numpy.core.multiarray', 'normalize_axis_index', + """ + normalize_axis_index(axis, ndim) + + Normalizes an axis index, `axis`, such that is a valid positive index into + the shape of array with `ndim` dimensions. Raises an AxisError with an + appropriate message if this is not possible. + + Used internally by all axis-checking logic. + + .. versionadded:: 1.13.0 + + Parameters + ---------- + axis : int + The un-normalized index of the axis. Can be negative + ndim : int + The number of dimensions of the array that `axis` should be normalized + against + + Returns + ------- + normalized_axis : int + The normalized axis index, such that `0 <= normalized_axis < ndim` + + Raises + ------ + AxisError + If the axis index is invalid, when `-ndim <= axis < ndim` is false. + + Examples + -------- + >>> normalize_axis_index(0, ndim=3) + 0 + >>> normalize_axis_index(1, ndim=3) + 1 + >>> normalize_axis_index(-1, ndim=3) + 2 + + >>> normalize_axis_index(3, ndim=3) + Traceback (most recent call last): + ... + AxisError: axis 3 is out of bounds for array of dimension 3 + """) + ############################################################################## # # nd_grid instances diff --git a/numpy/core/_internal.py b/numpy/core/_internal.py index 741c8bb5fbcb..d73cdcc55ab6 100644 --- a/numpy/core/_internal.py +++ b/numpy/core/_internal.py @@ -630,3 +630,6 @@ def _gcd(a, b): # Exception used in shares_memory() class TooHardError(RuntimeError): pass + +class AxisError(ValueError, IndexError): + pass diff --git a/numpy/core/numeric.py b/numpy/core/numeric.py index 97d19f008a78..066697f3eda6 100644 --- a/numpy/core/numeric.py +++ b/numpy/core/numeric.py @@ -17,7 +17,7 @@ inner, int_asbuffer, lexsort, matmul, may_share_memory, min_scalar_type, ndarray, nditer, nested_iters, promote_types, putmask, result_type, set_numeric_ops, shares_memory, vdot, where, - zeros) + zeros, normalize_axis_index) if sys.version_info[0] < 3: from .multiarray import newbuffer, getbuffer @@ -27,7 +27,7 @@ ERR_DEFAULT, PINF, NAN) from . import numerictypes from .numerictypes import longlong, intc, int_, float_, complex_, bool_ -from ._internal import TooHardError +from ._internal import TooHardError, AxisError bitwise_not = invert ufunc = type(sin) @@ -65,7 +65,7 @@ 'True_', 'bitwise_not', 'CLIP', 'RAISE', 'WRAP', 'MAXDIMS', 'BUFSIZE', 'ALLOW_THREADS', 'ComplexWarning', 'full', 'full_like', 'matmul', 'shares_memory', 'may_share_memory', 'MAY_SHARE_BOUNDS', 'MAY_SHARE_EXACT', - 'TooHardError', + 'TooHardError', 'AxisError' ] @@ -1527,15 +1527,12 @@ def rollaxis(a, axis, start=0): """ n = a.ndim - if axis < 0: - axis += n + axis = normalize_axis_index(axis, n) if start < 0: start += n msg = "'%s' arg requires %d <= %s < %d, but %d was passed in" - if not (0 <= axis < n): - raise ValueError(msg % ('axis', -n, 'axis', n, axis)) if not (0 <= start < n + 1): - raise ValueError(msg % ('start', -n, 'start', n + 1, start)) + raise AxisError(msg % ('start', -n, 'start', n + 1, start)) if axis < start: # it's been removed start -= 1 @@ -1554,7 +1551,7 @@ def _validate_axis(axis, ndim, argname): axis = list(axis) axis = [a + ndim if a < 0 else a for a in axis] if not builtins.all(0 <= a < ndim for a in axis): - raise ValueError('invalid axis for this array in `%s` argument' % + raise AxisError('invalid axis for this array in `%s` argument' % argname) if len(set(axis)) != len(axis): raise ValueError('repeated axis in `%s` argument' % argname) diff --git a/numpy/core/shape_base.py b/numpy/core/shape_base.py index 70afdb7465ec..58b0dcaac827 100644 --- a/numpy/core/shape_base.py +++ b/numpy/core/shape_base.py @@ -5,6 +5,7 @@ from . import numeric as _nx from .numeric import asanyarray, newaxis +from .multiarray import normalize_axis_index def atleast_1d(*arys): """ @@ -347,11 +348,7 @@ def stack(arrays, axis=0): raise ValueError('all input arrays must have the same shape') result_ndim = arrays[0].ndim + 1 - if not -result_ndim <= axis < result_ndim: - msg = 'axis {0} out of bounds [-{1}, {1})'.format(axis, result_ndim) - raise IndexError(msg) - if axis < 0: - axis += result_ndim + axis = normalize_axis_index(axis, result_ndim) sl = (slice(None),) * axis + (_nx.newaxis,) expanded_arrays = [arr[sl] for arr in arrays] diff --git a/numpy/core/src/multiarray/common.h b/numpy/core/src/multiarray/common.h index 5e14b80a71ca..625ca9d76887 100644 --- a/numpy/core/src/multiarray/common.h +++ b/numpy/core/src/multiarray/common.h @@ -134,6 +134,42 @@ check_and_adjust_index(npy_intp *index, npy_intp max_item, int axis, return 0; } +/* + * Returns -1 and sets an exception if *axis is an invalid axis for + * an array of dimension ndim, otherwise adjusts it in place to be + * 0 <= *axis < ndim, and returns 0. + */ +static NPY_INLINE int +check_and_adjust_axis(int *axis, int ndim) +{ + /* Check that index is valid, taking into account negative indices */ + if (NPY_UNLIKELY((*axis < -ndim) || (*axis >= ndim))) { + /* + * Load the exception type, if we don't already have it. Unfortunately + * we don't have access to npy_cache_import here + */ + static PyObject *AxisError_cls = NULL; + if (AxisError_cls == NULL) { + PyObject *mod = PyImport_ImportModule("numpy.core._internal"); + + if (mod != NULL) { + AxisError_cls = PyObject_GetAttrString(mod, "AxisError"); + Py_DECREF(mod); + } + } + + PyErr_Format(AxisError_cls, + "axis %d is out of bounds for array of dimension %d", + *axis, ndim); + return -1; + } + /* adjust negative indices */ + if (*axis < 0) { + *axis += ndim; + } + return 0; +} + /* * return true if pointer is aligned to 'alignment' diff --git a/numpy/core/src/multiarray/conversion_utils.c b/numpy/core/src/multiarray/conversion_utils.c index c016bb8d10d4..8ed08a366177 100644 --- a/numpy/core/src/multiarray/conversion_utils.c +++ b/numpy/core/src/multiarray/conversion_utils.c @@ -259,17 +259,10 @@ PyArray_ConvertMultiAxis(PyObject *axis_in, int ndim, npy_bool *out_axis_flags) PyObject *tmp = PyTuple_GET_ITEM(axis_in, i); int axis = PyArray_PyIntAsInt_ErrMsg(tmp, "integers are required for the axis tuple elements"); - int axis_orig = axis; if (error_converting(axis)) { return NPY_FAIL; } - if (axis < 0) { - axis += ndim; - } - if (axis < 0 || axis >= ndim) { - PyErr_Format(PyExc_ValueError, - "'axis' entry %d is out of bounds [-%d, %d)", - axis_orig, ndim, ndim); + if (check_and_adjust_axis(&axis, ndim) < 0) { return NPY_FAIL; } if (out_axis_flags[axis]) { @@ -284,20 +277,16 @@ PyArray_ConvertMultiAxis(PyObject *axis_in, int ndim, npy_bool *out_axis_flags) } /* Try to interpret axis as an integer */ else { - int axis, axis_orig; + int axis; memset(out_axis_flags, 0, ndim); axis = PyArray_PyIntAsInt_ErrMsg(axis_in, "an integer is required for the axis"); - axis_orig = axis; if (error_converting(axis)) { return NPY_FAIL; } - if (axis < 0) { - axis += ndim; - } /* * Special case letting axis={-1,0} slip through for scalars, * for backwards compatibility reasons. @@ -306,10 +295,7 @@ PyArray_ConvertMultiAxis(PyObject *axis_in, int ndim, npy_bool *out_axis_flags) return NPY_SUCCEED; } - if (axis < 0 || axis >= ndim) { - PyErr_Format(PyExc_ValueError, - "'axis' entry %d is out of bounds [-%d, %d)", - axis_orig, ndim, ndim); + if (check_and_adjust_axis(&axis, ndim) < 0) { return NPY_FAIL; } diff --git a/numpy/core/src/multiarray/ctors.c b/numpy/core/src/multiarray/ctors.c index 349b59c5f87d..ee6b66eef461 100644 --- a/numpy/core/src/multiarray/ctors.c +++ b/numpy/core/src/multiarray/ctors.c @@ -2793,7 +2793,6 @@ PyArray_CheckAxis(PyArrayObject *arr, int *axis, int flags) { PyObject *temp1, *temp2; int n = PyArray_NDIM(arr); - int axis_orig = *axis; if (*axis == NPY_MAXDIMS || n == 0) { if (n != 1) { @@ -2831,12 +2830,7 @@ PyArray_CheckAxis(PyArrayObject *arr, int *axis, int flags) temp2 = (PyObject *)temp1; } n = PyArray_NDIM((PyArrayObject *)temp2); - if (*axis < 0) { - *axis += n; - } - if ((*axis < 0) || (*axis >= n)) { - PyErr_Format(PyExc_ValueError, - "axis(=%d) out of bounds", axis_orig); + if (check_and_adjust_axis(axis, n) < 0) { Py_DECREF(temp2); return NULL; } diff --git a/numpy/core/src/multiarray/item_selection.c b/numpy/core/src/multiarray/item_selection.c index 08b9c5965240..3c0f0782e586 100644 --- a/numpy/core/src/multiarray/item_selection.c +++ b/numpy/core/src/multiarray/item_selection.c @@ -1101,16 +1101,12 @@ NPY_NO_EXPORT int PyArray_Sort(PyArrayObject *op, int axis, NPY_SORTKIND which) { PyArray_SortFunc *sort; - int axis_orig = axis; - int n = PyArray_NDIM(op); + int n = PyArray_NDIM(op); - if (axis < 0) { - axis += n; - } - if (axis < 0 || axis >= n) { - PyErr_Format(PyExc_ValueError, "axis(=%d) out of bounds", axis_orig); + if (check_and_adjust_axis(&axis, n) < 0) { return -1; } + if (PyArray_FailUnlessWriteable(op, "sort array") < 0) { return -1; } @@ -1212,17 +1208,13 @@ PyArray_Partition(PyArrayObject *op, PyArrayObject * ktharray, int axis, PyArrayObject *kthrvl; PyArray_PartitionFunc *part; PyArray_SortFunc *sort; - int axis_orig = axis; int n = PyArray_NDIM(op); int ret; - if (axis < 0) { - axis += n; - } - if (axis < 0 || axis >= n) { - PyErr_Format(PyExc_ValueError, "axis(=%d) out of bounds", axis_orig); + if (check_and_adjust_axis(&axis, n) < 0) { return -1; } + if (PyArray_FailUnlessWriteable(op, "partition array") < 0) { return -1; } @@ -1455,12 +1447,7 @@ PyArray_LexSort(PyObject *sort_keys, int axis) *((npy_intp *)(PyArray_DATA(ret))) = 0; goto finish; } - if (axis < 0) { - axis += nd; - } - if ((axis < 0) || (axis >= nd)) { - PyErr_Format(PyExc_ValueError, - "axis(=%d) out of bounds", axis); + if (check_and_adjust_axis(&axis, nd) < 0) { goto fail; } diff --git a/numpy/core/src/multiarray/multiarraymodule.c b/numpy/core/src/multiarray/multiarraymodule.c index f00de46c45aa..1c8d9b5e47e3 100644 --- a/numpy/core/src/multiarray/multiarraymodule.c +++ b/numpy/core/src/multiarray/multiarraymodule.c @@ -327,7 +327,6 @@ PyArray_ConcatenateArrays(int narrays, PyArrayObject **arrays, int axis) PyArray_Descr *dtype = NULL; PyArrayObject *ret = NULL; PyArrayObject_fields *sliding_view = NULL; - int orig_axis = axis; if (narrays <= 0) { PyErr_SetString(PyExc_ValueError, @@ -345,13 +344,7 @@ PyArray_ConcatenateArrays(int narrays, PyArrayObject **arrays, int axis) } /* Handle standard Python negative indexing */ - if (axis < 0) { - axis += ndim; - } - - if (axis < 0 || axis >= ndim) { - PyErr_Format(PyExc_IndexError, - "axis %d out of bounds [0, %d)", orig_axis, ndim); + if (check_and_adjust_axis(&axis, ndim) < 0) { return NULL; } @@ -4109,6 +4102,24 @@ array_may_share_memory(PyObject *NPY_UNUSED(ignored), PyObject *args, PyObject * return array_shares_memory_impl(args, kwds, NPY_MAY_SHARE_BOUNDS, 0); } +static PyObject * +normalize_axis_index(PyObject *NPY_UNUSED(self), PyObject *args, PyObject *kwds) +{ + static char *kwlist[] = {"axis", "ndim", NULL}; + int axis; + int ndim; + + if (!PyArg_ParseTupleAndKeywords(args, kwds, "ii", kwlist, + &axis, &ndim)) { + return NULL; + } + + if(check_and_adjust_axis(&axis, ndim) < 0) { + return NULL; + } + + return PyInt_FromLong(axis); +} static struct PyMethodDef array_module_methods[] = { {"_get_ndarray_c_version", @@ -4284,6 +4295,8 @@ static struct PyMethodDef array_module_methods[] = { METH_VARARGS | METH_KEYWORDS, NULL}, {"unpackbits", (PyCFunction)io_unpack, METH_VARARGS | METH_KEYWORDS, NULL}, + {"normalize_axis_index", (PyCFunction)normalize_axis_index, + METH_VARARGS | METH_KEYWORDS, NULL}, {NULL, NULL, 0, NULL} /* sentinel */ }; diff --git a/numpy/core/src/multiarray/shape.c b/numpy/core/src/multiarray/shape.c index 3bee562be123..5207513bf531 100644 --- a/numpy/core/src/multiarray/shape.c +++ b/numpy/core/src/multiarray/shape.c @@ -705,12 +705,7 @@ PyArray_Transpose(PyArrayObject *ap, PyArray_Dims *permute) } for (i = 0; i < n; i++) { axis = axes[i]; - if (axis < 0) { - axis = PyArray_NDIM(ap) + axis; - } - if (axis < 0 || axis >= PyArray_NDIM(ap)) { - PyErr_SetString(PyExc_ValueError, - "invalid axis for this array"); + if (check_and_adjust_axis(&axis, PyArray_NDIM(ap)) < 0) { return NULL; } if (reverse_permutation[axis] != -1) { diff --git a/numpy/core/src/umath/ufunc_object.c b/numpy/core/src/umath/ufunc_object.c index 0bae2d591a84..af4ce12dbdf7 100644 --- a/numpy/core/src/umath/ufunc_object.c +++ b/numpy/core/src/umath/ufunc_object.c @@ -4036,12 +4036,7 @@ PyUFunc_GenericReduction(PyUFuncObject *ufunc, PyObject *args, Py_DECREF(mp); return NULL; } - if (axis < 0) { - axis += ndim; - } - if (axis < 0 || axis >= ndim) { - PyErr_SetString(PyExc_ValueError, - "'axis' entry is out of bounds"); + if (check_and_adjust_axis(&axis, ndim) < 0) { Py_XDECREF(otype); Py_DECREF(mp); return NULL; @@ -4058,18 +4053,11 @@ PyUFunc_GenericReduction(PyUFuncObject *ufunc, PyObject *args, Py_DECREF(mp); return NULL; } - if (axis < 0) { - axis += ndim; - } /* Special case letting axis={0 or -1} slip through for scalars */ if (ndim == 0 && (axis == 0 || axis == -1)) { axis = 0; } - else if (axis < 0 || axis >= ndim) { - PyErr_SetString(PyExc_ValueError, - "'axis' entry is out of bounds"); - Py_XDECREF(otype); - Py_DECREF(mp); + else if (check_and_adjust_axis(&axis, ndim) < 0) { return NULL; } axes[0] = (int)axis; diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py index 8229f1e1ab14..fa5051ba73fa 100644 --- a/numpy/core/tests/test_multiarray.py +++ b/numpy/core/tests/test_multiarray.py @@ -2013,13 +2013,13 @@ def test_partition(self): d = np.array([2, 1]) d.partition(0, kind=k) assert_raises(ValueError, d.partition, 2) - assert_raises(ValueError, d.partition, 3, axis=1) + assert_raises(np.AxisError, d.partition, 3, axis=1) assert_raises(ValueError, np.partition, d, 2) - assert_raises(ValueError, np.partition, d, 2, axis=1) + assert_raises(np.AxisError, np.partition, d, 2, axis=1) assert_raises(ValueError, d.argpartition, 2) - assert_raises(ValueError, d.argpartition, 3, axis=1) + assert_raises(np.AxisError, d.argpartition, 3, axis=1) assert_raises(ValueError, np.argpartition, d, 2) - assert_raises(ValueError, np.argpartition, d, 2, axis=1) + assert_raises(np.AxisError, np.argpartition, d, 2, axis=1) d = np.arange(10).reshape((2, 5)) d.partition(1, axis=0, kind=k) d.partition(4, axis=1, kind=k) @@ -3522,8 +3522,8 @@ def test_object_argmin_with_NULLs(self): class TestMinMax(TestCase): def test_scalar(self): - assert_raises(ValueError, np.amax, 1, 1) - assert_raises(ValueError, np.amin, 1, 1) + assert_raises(np.AxisError, np.amax, 1, 1) + assert_raises(np.AxisError, np.amin, 1, 1) assert_equal(np.amax(1, axis=0), 1) assert_equal(np.amin(1, axis=0), 1) @@ -3531,7 +3531,7 @@ def test_scalar(self): assert_equal(np.amin(1, axis=None), 1) def test_axis(self): - assert_raises(ValueError, np.amax, [1, 2, 3], 1000) + assert_raises(np.AxisError, np.amax, [1, 2, 3], 1000) assert_equal(np.amax([[1, 2, 3]], axis=1), 3) def test_datetime(self): @@ -3793,7 +3793,7 @@ def test_object(self): # gh-6312 def test_invalid_axis(self): # gh-7528 x = np.linspace(0., 1., 42*3).reshape(42, 3) - assert_raises(ValueError, np.lexsort, x, axis=2) + assert_raises(np.AxisError, np.lexsort, x, axis=2) class TestIO(object): """Test tofile, fromfile, tobytes, and fromstring""" diff --git a/numpy/core/tests/test_numeric.py b/numpy/core/tests/test_numeric.py index 4aa6bed3366a..906280e15353 100644 --- a/numpy/core/tests/test_numeric.py +++ b/numpy/core/tests/test_numeric.py @@ -1010,7 +1010,7 @@ def test_count_nonzero_axis(self): assert_raises(ValueError, np.count_nonzero, m, axis=(1, 1)) assert_raises(TypeError, np.count_nonzero, m, axis='foo') - assert_raises(ValueError, np.count_nonzero, m, axis=3) + assert_raises(np.AxisError, np.count_nonzero, m, axis=3) assert_raises(TypeError, np.count_nonzero, m, axis=np.array([[1], [2]])) @@ -2323,10 +2323,10 @@ class TestRollaxis(TestCase): def test_exceptions(self): a = np.arange(1*2*3*4).reshape(1, 2, 3, 4) - assert_raises(ValueError, np.rollaxis, a, -5, 0) - assert_raises(ValueError, np.rollaxis, a, 0, -5) - assert_raises(ValueError, np.rollaxis, a, 4, 0) - assert_raises(ValueError, np.rollaxis, a, 0, 5) + assert_raises(np.AxisError, np.rollaxis, a, -5, 0) + assert_raises(np.AxisError, np.rollaxis, a, 0, -5) + assert_raises(np.AxisError, np.rollaxis, a, 4, 0) + assert_raises(np.AxisError, np.rollaxis, a, 0, 5) def test_results(self): a = np.arange(1*2*3*4).reshape(1, 2, 3, 4).copy() @@ -2413,11 +2413,11 @@ def test_move_multiples(self): def test_errors(self): x = np.random.randn(1, 2, 3) - assert_raises_regex(ValueError, 'invalid axis .* `source`', + assert_raises_regex(np.AxisError, 'invalid axis .* `source`', np.moveaxis, x, 3, 0) - assert_raises_regex(ValueError, 'invalid axis .* `source`', + assert_raises_regex(np.AxisError, 'invalid axis .* `source`', np.moveaxis, x, -4, 0) - assert_raises_regex(ValueError, 'invalid axis .* `destination`', + assert_raises_regex(np.AxisError, 'invalid axis .* `destination`', np.moveaxis, x, 0, 5) assert_raises_regex(ValueError, 'repeated axis in `source`', np.moveaxis, x, [0, 0], [0, 1]) diff --git a/numpy/core/tests/test_shape_base.py b/numpy/core/tests/test_shape_base.py index ac8dc1eea511..727608a175fa 100644 --- a/numpy/core/tests/test_shape_base.py +++ b/numpy/core/tests/test_shape_base.py @@ -184,8 +184,8 @@ def test_exceptions(self): for ndim in [1, 2, 3]: a = np.ones((1,)*ndim) np.concatenate((a, a), axis=0) # OK - assert_raises(IndexError, np.concatenate, (a, a), axis=ndim) - assert_raises(IndexError, np.concatenate, (a, a), axis=-(ndim + 1)) + assert_raises(np.AxisError, np.concatenate, (a, a), axis=ndim) + assert_raises(np.AxisError, np.concatenate, (a, a), axis=-(ndim + 1)) # Scalars cannot be concatenated assert_raises(ValueError, concatenate, (0,)) @@ -294,8 +294,8 @@ def test_stack(): expected_shapes = [(10, 3), (3, 10), (3, 10), (10, 3)] for axis, expected_shape in zip(axes, expected_shapes): assert_equal(np.stack(arrays, axis).shape, expected_shape) - assert_raises_regex(IndexError, 'out of bounds', stack, arrays, axis=2) - assert_raises_regex(IndexError, 'out of bounds', stack, arrays, axis=-3) + assert_raises_regex(np.AxisError, 'out of bounds', stack, arrays, axis=2) + assert_raises_regex(np.AxisError, 'out of bounds', stack, arrays, axis=-3) # all shapes for 2d input arrays = [np.random.randn(3, 4) for _ in range(10)] axes = [0, 1, 2, -1, -2, -3] diff --git a/numpy/core/tests/test_ufunc.py b/numpy/core/tests/test_ufunc.py index 3fea68700a5b..f7b66f90c0df 100644 --- a/numpy/core/tests/test_ufunc.py +++ b/numpy/core/tests/test_ufunc.py @@ -703,14 +703,14 @@ def test_zerosize_reduction(self): def test_axis_out_of_bounds(self): a = np.array([False, False]) - assert_raises(ValueError, a.all, axis=1) + assert_raises(np.AxisError, a.all, axis=1) a = np.array([False, False]) - assert_raises(ValueError, a.all, axis=-2) + assert_raises(np.AxisError, a.all, axis=-2) a = np.array([False, False]) - assert_raises(ValueError, a.any, axis=1) + assert_raises(np.AxisError, a.any, axis=1) a = np.array([False, False]) - assert_raises(ValueError, a.any, axis=-2) + assert_raises(np.AxisError, a.any, axis=-2) def test_scalar_reduction(self): # The functions 'sum', 'prod', etc allow specifying axis=0 diff --git a/numpy/lib/function_base.py b/numpy/lib/function_base.py index ae1420b726a8..4d1ffbccc277 100644 --- a/numpy/lib/function_base.py +++ b/numpy/lib/function_base.py @@ -25,7 +25,7 @@ from numpy.lib.twodim_base import diag from .utils import deprecate from numpy.core.multiarray import ( - _insert, add_docstring, digitize, bincount, + _insert, add_docstring, digitize, bincount, normalize_axis_index, interp as compiled_interp, interp_complex as compiled_interp_complex ) from numpy.core.umath import _add_newdoc_ufunc as add_newdoc_ufunc @@ -4828,14 +4828,7 @@ def insert(arr, obj, values, axis=None): arr = arr.ravel() ndim = arr.ndim axis = ndim - 1 - else: - if ndim > 0 and (axis < -ndim or axis >= ndim): - raise IndexError( - "axis %i is out of bounds for an array of " - "dimension %i" % (axis, ndim)) - if (axis < 0): - axis += ndim - if (ndim == 0): + elif ndim == 0: # 2013-09-24, 1.9 warnings.warn( "in the future the special handling of scalars will be removed " @@ -4846,6 +4839,8 @@ def insert(arr, obj, values, axis=None): return wrap(arr) else: return arr + else: + axis = normalize_axis_index(axis, ndim) slobj = [slice(None)]*ndim N = arr.shape[axis] newshape = list(arr.shape) diff --git a/numpy/lib/shape_base.py b/numpy/lib/shape_base.py index 58e13533bf1f..62798286f32c 100644 --- a/numpy/lib/shape_base.py +++ b/numpy/lib/shape_base.py @@ -7,6 +7,7 @@ asarray, zeros, outer, concatenate, isscalar, array, asanyarray ) from numpy.core.fromnumeric import product, reshape, transpose +from numpy.core.multiarray import normalize_axis_index from numpy.core import vstack, atleast_3d from numpy.lib.index_tricks import ndindex from numpy.matrixlib.defmatrix import matrix # this raises all the right alarm bells @@ -96,10 +97,7 @@ def apply_along_axis(func1d, axis, arr, *args, **kwargs): # handle negative axes arr = asanyarray(arr) nd = arr.ndim - if not (-nd <= axis < nd): - raise IndexError('axis {0} out of bounds [-{1}, {1})'.format(axis, nd)) - if axis < 0: - axis += nd + axis = normalize_axis_index(axis, nd) # arr, with the iteration axis at the end in_dims = list(range(nd)) @@ -289,8 +287,7 @@ def expand_dims(a, axis): """ a = asarray(a) shape = a.shape - if axis < 0: - axis = axis + len(shape) + 1 + axis = normalize_axis_index(axis, a.ndim + 1) return a.reshape(shape[:axis] + (1,) + shape[axis:]) row_stack = vstack diff --git a/numpy/lib/tests/test_function_base.py b/numpy/lib/tests/test_function_base.py index f69c24d59c5f..d914260add3c 100644 --- a/numpy/lib/tests/test_function_base.py +++ b/numpy/lib/tests/test_function_base.py @@ -466,8 +466,8 @@ def test_multidim(self): insert(a, 1, a[:, 2,:], axis=1)) # invalid axis value - assert_raises(IndexError, insert, a, 1, a[:, 2, :], axis=3) - assert_raises(IndexError, insert, a, 1, a[:, 2, :], axis=-4) + assert_raises(np.AxisError, insert, a, 1, a[:, 2, :], axis=3) + assert_raises(np.AxisError, insert, a, 1, a[:, 2, :], axis=-4) # negative axis value a = np.arange(24).reshape((2, 3, 4)) diff --git a/numpy/linalg/linalg.py b/numpy/linalg/linalg.py index 03f456601954..119326912b96 100644 --- a/numpy/linalg/linalg.py +++ b/numpy/linalg/linalg.py @@ -25,6 +25,7 @@ finfo, errstate, geterrobj, longdouble, rollaxis, amin, amax, product, abs, broadcast, atleast_2d, intp, asanyarray, isscalar, object_ ) +from numpy.core.multiarray import normalize_axis_index from numpy.lib import triu, asfarray from numpy.linalg import lapack_lite, _umath_linalg from numpy.matrixlib.defmatrix import matrix_power @@ -2225,13 +2226,8 @@ def norm(x, ord=None, axis=None, keepdims=False): return add.reduce(absx, axis=axis, keepdims=keepdims) ** (1.0 / ord) elif len(axis) == 2: row_axis, col_axis = axis - if row_axis < 0: - row_axis += nd - if col_axis < 0: - col_axis += nd - if not (0 <= row_axis < nd and 0 <= col_axis < nd): - raise ValueError('Invalid axis %r for an array with shape %r' % - (axis, x.shape)) + row_axis = normalize_axis_index(row_axis, nd) + col_axis = normalize_axis_index(col_axis, nd) if row_axis == col_axis: raise ValueError('Duplicate axes given.') if ord == 2: diff --git a/numpy/linalg/tests/test_linalg.py b/numpy/linalg/tests/test_linalg.py index a353271de9fd..b0a1f04d07e9 100644 --- a/numpy/linalg/tests/test_linalg.py +++ b/numpy/linalg/tests/test_linalg.py @@ -1102,8 +1102,8 @@ def test_bad_args(self): assert_raises(ValueError, norm, B, order, (1, 2)) # Invalid axis - assert_raises(ValueError, norm, B, None, 3) - assert_raises(ValueError, norm, B, None, (2, 3)) + assert_raises(np.AxisError, norm, B, None, 3) + assert_raises(np.AxisError, norm, B, None, (2, 3)) assert_raises(ValueError, norm, B, None, (0, 1, 2)) diff --git a/numpy/ma/core.py b/numpy/ma/core.py index a6f474b954c9..1b25725d1301 100644 --- a/numpy/ma/core.py +++ b/numpy/ma/core.py @@ -41,6 +41,7 @@ getargspec, formatargspec, long, basestring, unicode, bytes, sixu ) from numpy import expand_dims as n_expand_dims +from numpy.core.multiarray import normalize_axis_index if sys.version_info[0] >= 3: @@ -3902,7 +3903,9 @@ def __eq__(self, other): axis = None try: mask = mask.view((bool_, len(self.dtype))).all(axis) - except ValueError: + except (ValueError, np.AxisError): + # TODO: what error are we trying to catch here? + # invalid axis, or invalid view? mask = np.all([[f[n].all() for n in mask.dtype.names] for f in mask], axis=axis) check._mask = mask @@ -3938,7 +3941,9 @@ def __ne__(self, other): axis = None try: mask = mask.view((bool_, len(self.dtype))).all(axis) - except ValueError: + except (ValueError, np.AxisError): + # TODO: what error are we trying to catch here? + # invalid axis, or invalid view? mask = np.all([[f[n].all() for n in mask.dtype.names] for f in mask], axis=axis) check._mask = mask @@ -4340,7 +4345,7 @@ def count(self, axis=None, keepdims=np._NoValue): if self.shape is (): if axis not in (None, 0): - raise ValueError("'axis' entry is out of bounds") + raise np.AxisError("'axis' entry is out of bounds") return 1 elif axis is None: if kwargs.get('keepdims', False): @@ -4348,11 +4353,9 @@ def count(self, axis=None, keepdims=np._NoValue): return self.size axes = axis if isinstance(axis, tuple) else (axis,) - axes = tuple(a if a >= 0 else self.ndim + a for a in axes) + axes = tuple(normalize_axis_index(a, self.ndim) for a in axes) if len(axes) != len(set(axes)): raise ValueError("duplicate value in 'axis'") - if builtins.any(a < 0 or a >= self.ndim for a in axes): - raise ValueError("'axis' entry is out of bounds") items = 1 for ax in axes: items *= self.shape[ax] diff --git a/numpy/ma/extras.py b/numpy/ma/extras.py index 29a15633d95c..7149b525bad1 100644 --- a/numpy/ma/extras.py +++ b/numpy/ma/extras.py @@ -36,6 +36,7 @@ import numpy as np from numpy import ndarray, array as nxarray import numpy.core.umath as umath +from numpy.core.multiarray import normalize_axis_index from numpy.lib.function_base import _ureduce from numpy.lib.index_tricks import AxisConcatenator @@ -380,11 +381,7 @@ def apply_along_axis(func1d, axis, arr, *args, **kwargs): """ arr = array(arr, copy=False, subok=True) nd = arr.ndim - if axis < 0: - axis += nd - if (axis >= nd): - raise ValueError("axis must be less than arr.ndim; axis=%d, rank=%d." - % (axis, nd)) + axis = normalize_axis_index(axis, nd) ind = [0] * (nd - 1) i = np.zeros(nd, 'O') indlist = list(range(nd)) @@ -717,8 +714,8 @@ def _median(a, axis=None, out=None, overwrite_input=False): if axis is None: axis = 0 - elif axis < 0: - axis += asorted.ndim + else: + axis = normalize_axis_index(axis, asorted.ndim) if asorted.ndim == 1: counts = count(asorted) diff --git a/numpy/ma/tests/test_core.py b/numpy/ma/tests/test_core.py index f72ddc5eace8..9d8002ed09e4 100644 --- a/numpy/ma/tests/test_core.py +++ b/numpy/ma/tests/test_core.py @@ -1030,7 +1030,7 @@ def test_count_func(self): res = count(ott, 0) assert_(isinstance(res, ndarray)) assert_(res.dtype.type is np.intp) - assert_raises(ValueError, ott.count, axis=1) + assert_raises(np.AxisError, ott.count, axis=1) def test_minmax_func(self): # Tests minimum and maximum. @@ -4409,7 +4409,7 @@ def test_count(self): assert_equal(count(a, axis=(0,1), keepdims=True), 4*ones((1,1,4))) assert_equal(count(a, axis=-2), 2*ones((2,4))) assert_raises(ValueError, count, a, axis=(1,1)) - assert_raises(ValueError, count, a, axis=3) + assert_raises(np.AxisError, count, a, axis=3) # check the 'nomask' path a = np.ma.array(d, mask=nomask) @@ -4423,13 +4423,13 @@ def test_count(self): assert_equal(count(a, axis=(0,1), keepdims=True), 6*ones((1,1,4))) assert_equal(count(a, axis=-2), 3*ones((2,4))) assert_raises(ValueError, count, a, axis=(1,1)) - assert_raises(ValueError, count, a, axis=3) + assert_raises(np.AxisError, count, a, axis=3) # check the 'masked' singleton assert_equal(count(np.ma.masked), 0) # check 0-d arrays do not allow axis > 0 - assert_raises(ValueError, count, np.ma.array(1), axis=1) + assert_raises(np.AxisError, count, np.ma.array(1), axis=1) class TestMaskedConstant(TestCase): diff --git a/numpy/polynomial/chebyshev.py b/numpy/polynomial/chebyshev.py index 3babb8fc2e95..49d0302e0e75 100644 --- a/numpy/polynomial/chebyshev.py +++ b/numpy/polynomial/chebyshev.py @@ -90,6 +90,7 @@ import warnings import numpy as np import numpy.linalg as la +from numpy.core.multiarray import normalize_axis_index from . import polyutils as pu from ._polybase import ABCPolyBase @@ -936,10 +937,7 @@ def chebder(c, m=1, scl=1, axis=0): raise ValueError("The order of derivation must be non-negative") if iaxis != axis: raise ValueError("The axis must be integer") - if not -c.ndim <= iaxis < c.ndim: - raise ValueError("The axis is out of range") - if iaxis < 0: - iaxis += c.ndim + iaxis = normalize_axis_index(iaxis, c.ndim) if cnt == 0: return c @@ -1064,10 +1062,7 @@ def chebint(c, m=1, k=[], lbnd=0, scl=1, axis=0): raise ValueError("Too many integration constants") if iaxis != axis: raise ValueError("The axis must be integer") - if not -c.ndim <= iaxis < c.ndim: - raise ValueError("The axis is out of range") - if iaxis < 0: - iaxis += c.ndim + iaxis = normalize_axis_index(iaxis, c.ndim) if cnt == 0: return c diff --git a/numpy/polynomial/hermite.py b/numpy/polynomial/hermite.py index 0ebae2027792..a03fe722cb01 100644 --- a/numpy/polynomial/hermite.py +++ b/numpy/polynomial/hermite.py @@ -62,6 +62,7 @@ import warnings import numpy as np import numpy.linalg as la +from numpy.core.multiarray import normalize_axis_index from . import polyutils as pu from ._polybase import ABCPolyBase @@ -700,10 +701,7 @@ def hermder(c, m=1, scl=1, axis=0): raise ValueError("The order of derivation must be non-negative") if iaxis != axis: raise ValueError("The axis must be integer") - if not -c.ndim <= iaxis < c.ndim: - raise ValueError("The axis is out of range") - if iaxis < 0: - iaxis += c.ndim + iaxis = normalize_axis_index(iaxis, c.ndim) if cnt == 0: return c @@ -822,10 +820,7 @@ def hermint(c, m=1, k=[], lbnd=0, scl=1, axis=0): raise ValueError("Too many integration constants") if iaxis != axis: raise ValueError("The axis must be integer") - if not -c.ndim <= iaxis < c.ndim: - raise ValueError("The axis is out of range") - if iaxis < 0: - iaxis += c.ndim + iaxis = normalize_axis_index(iaxis, c.ndim) if cnt == 0: return c diff --git a/numpy/polynomial/hermite_e.py b/numpy/polynomial/hermite_e.py index a09b66670a25..2a29d61cf6fd 100644 --- a/numpy/polynomial/hermite_e.py +++ b/numpy/polynomial/hermite_e.py @@ -62,6 +62,7 @@ import warnings import numpy as np import numpy.linalg as la +from numpy.core.multiarray import normalize_axis_index from . import polyutils as pu from ._polybase import ABCPolyBase @@ -699,10 +700,7 @@ def hermeder(c, m=1, scl=1, axis=0): raise ValueError("The order of derivation must be non-negative") if iaxis != axis: raise ValueError("The axis must be integer") - if not -c.ndim <= iaxis < c.ndim: - raise ValueError("The axis is out of range") - if iaxis < 0: - iaxis += c.ndim + iaxis = normalize_axis_index(iaxis, c.ndim) if cnt == 0: return c @@ -821,10 +819,7 @@ def hermeint(c, m=1, k=[], lbnd=0, scl=1, axis=0): raise ValueError("Too many integration constants") if iaxis != axis: raise ValueError("The axis must be integer") - if not -c.ndim <= iaxis < c.ndim: - raise ValueError("The axis is out of range") - if iaxis < 0: - iaxis += c.ndim + iaxis = normalize_axis_index(iaxis, c.ndim) if cnt == 0: return c diff --git a/numpy/polynomial/laguerre.py b/numpy/polynomial/laguerre.py index dfa997254eb8..c9e1302e133f 100644 --- a/numpy/polynomial/laguerre.py +++ b/numpy/polynomial/laguerre.py @@ -62,6 +62,7 @@ import warnings import numpy as np import numpy.linalg as la +from numpy.core.multiarray import normalize_axis_index from . import polyutils as pu from ._polybase import ABCPolyBase @@ -697,10 +698,7 @@ def lagder(c, m=1, scl=1, axis=0): raise ValueError("The order of derivation must be non-negative") if iaxis != axis: raise ValueError("The axis must be integer") - if not -c.ndim <= iaxis < c.ndim: - raise ValueError("The axis is out of range") - if iaxis < 0: - iaxis += c.ndim + iaxis = normalize_axis_index(iaxis, c.ndim) if cnt == 0: return c @@ -822,10 +820,7 @@ def lagint(c, m=1, k=[], lbnd=0, scl=1, axis=0): raise ValueError("Too many integration constants") if iaxis != axis: raise ValueError("The axis must be integer") - if not -c.ndim <= iaxis < c.ndim: - raise ValueError("The axis is out of range") - if iaxis < 0: - iaxis += c.ndim + iaxis = normalize_axis_index(iaxis, c.ndim) if cnt == 0: return c diff --git a/numpy/polynomial/legendre.py b/numpy/polynomial/legendre.py index fdaa56e0c397..fa578360e435 100644 --- a/numpy/polynomial/legendre.py +++ b/numpy/polynomial/legendre.py @@ -86,6 +86,7 @@ import warnings import numpy as np import numpy.linalg as la +from numpy.core.multiarray import normalize_axis_index from . import polyutils as pu from ._polybase import ABCPolyBase @@ -736,10 +737,7 @@ def legder(c, m=1, scl=1, axis=0): raise ValueError("The order of derivation must be non-negative") if iaxis != axis: raise ValueError("The axis must be integer") - if not -c.ndim <= iaxis < c.ndim: - raise ValueError("The axis is out of range") - if iaxis < 0: - iaxis += c.ndim + iaxis = normalize_axis_index(iaxis, c.ndim) if cnt == 0: return c @@ -864,10 +862,7 @@ def legint(c, m=1, k=[], lbnd=0, scl=1, axis=0): raise ValueError("Too many integration constants") if iaxis != axis: raise ValueError("The axis must be integer") - if not -c.ndim <= iaxis < c.ndim: - raise ValueError("The axis is out of range") - if iaxis < 0: - iaxis += c.ndim + iaxis = normalize_axis_index(iaxis, c.ndim) if cnt == 0: return c diff --git a/numpy/polynomial/polynomial.py b/numpy/polynomial/polynomial.py index 19b085eaf20d..c357b48c9603 100644 --- a/numpy/polynomial/polynomial.py +++ b/numpy/polynomial/polynomial.py @@ -66,6 +66,7 @@ import warnings import numpy as np import numpy.linalg as la +from numpy.core.multiarray import normalize_axis_index from . import polyutils as pu from ._polybase import ABCPolyBase @@ -540,10 +541,7 @@ def polyder(c, m=1, scl=1, axis=0): raise ValueError("The order of derivation must be non-negative") if iaxis != axis: raise ValueError("The axis must be integer") - if not -c.ndim <= iaxis < c.ndim: - raise ValueError("The axis is out of range") - if iaxis < 0: - iaxis += c.ndim + iaxis = normalize_axis_index(iaxis, c.ndim) if cnt == 0: return c @@ -658,10 +656,7 @@ def polyint(c, m=1, k=[], lbnd=0, scl=1, axis=0): raise ValueError("Too many integration constants") if iaxis != axis: raise ValueError("The axis must be integer") - if not -c.ndim <= iaxis < c.ndim: - raise ValueError("The axis is out of range") - if iaxis < 0: - iaxis += c.ndim + iaxis = normalize_axis_index(iaxis, c.ndim) if cnt == 0: return c