From 1a418d3fa8de5f511168babb645abd213c2fb9e3 Mon Sep 17 00:00:00 2001 From: jaimefrio Date: Fri, 17 Oct 2014 17:39:44 -0700 Subject: [PATCH] MANT: more informative errors for 'axis' argument of ufuncs --- numpy/core/src/multiarray/conversion_utils.c | 4 +-- numpy/core/src/umath/ufunc_object.c | 32 ++++++++++++++++---- numpy/core/tests/test_ufunc.py | 9 ++++++ 3 files changed, 37 insertions(+), 8 deletions(-) diff --git a/numpy/core/src/multiarray/conversion_utils.c b/numpy/core/src/multiarray/conversion_utils.c index 096a363f10ed..49ee8626a7e8 100644 --- a/numpy/core/src/multiarray/conversion_utils.c +++ b/numpy/core/src/multiarray/conversion_utils.c @@ -258,7 +258,7 @@ PyArray_ConvertMultiAxis(PyObject *axis_in, int ndim, npy_bool *out_axis_flags) for (i = 0; i < naxes; ++i) { PyObject *tmp = PyTuple_GET_ITEM(axis_in, i); int axis = PyArray_PyIntAsInt_ErrMsg(tmp, - "integers are required for the axis tuple elements"); + "'axis' tuple entries must be integers"); int axis_orig = axis; if (error_converting(axis)) { return NPY_FAIL; @@ -289,7 +289,7 @@ PyArray_ConvertMultiAxis(PyObject *axis_in, int ndim, npy_bool *out_axis_flags) memset(out_axis_flags, 0, ndim); axis = PyArray_PyIntAsInt_ErrMsg(axis_in, - "an integer is required for the axis"); + "'axis' must be None, an integer or a tuple of integers"); axis_orig = axis; if (error_converting(axis)) { diff --git a/numpy/core/src/umath/ufunc_object.c b/numpy/core/src/umath/ufunc_object.c index b12fbf57fe90..df58dc2337a6 100644 --- a/numpy/core/src/umath/ufunc_object.c +++ b/numpy/core/src/umath/ufunc_object.c @@ -3745,7 +3745,16 @@ PyUFunc_GenericReduction(PyUFuncObject *ufunc, PyObject *args, for (i = 0; i < naxes; ++i) { PyObject *tmp = PyTuple_GET_ITEM(axes_in, i); int axis = PyArray_PyIntAsInt(tmp); - if (axis == -1 && PyErr_Occurred()) { + int axis_orig = axis; + PyObject *error = PyErr_Occurred(); + if (axis == -1 && error != NULL) { ; + /* + * Need to re-raise the exact same exception returned by + * PyArray_PyIntAsInt to not break the deprecation tests + */ + PyErr_Clear(); + PyErr_SetString(error, + "'axis' tuple entries must be integers"); Py_XDECREF(otype); Py_DECREF(mp); return NULL; @@ -3754,8 +3763,9 @@ PyUFunc_GenericReduction(PyUFuncObject *ufunc, PyObject *args, axis += ndim; } if (axis < 0 || axis >= ndim) { - PyErr_SetString(PyExc_ValueError, - "'axis' entry is out of bounds"); + PyErr_Format(PyExc_ValueError, + "'axis' entry %d is out of bounds [-%d, %d)", + axis_orig, ndim, ndim); Py_XDECREF(otype); Py_DECREF(mp); return NULL; @@ -3766,8 +3776,17 @@ PyUFunc_GenericReduction(PyUFuncObject *ufunc, PyObject *args, /* Try to interpret axis as an integer */ else { int axis = PyArray_PyIntAsInt(axes_in); + int axis_orig = axis; + PyObject *error = PyErr_Occurred(); /* TODO: PyNumber_Index would be good to use here */ - if (axis == -1 && PyErr_Occurred()) { + if (axis == -1 && error != NULL) { + /* + * Need to re-raise the exact same exception returned by + * PyArray_PyIntAsInt to not break the deprecation tests + */ + PyErr_Clear(); + PyErr_SetString(error, + "'axis' must be None, an integer or a tuple of integers"); Py_XDECREF(otype); Py_DECREF(mp); return NULL; @@ -3780,8 +3799,9 @@ PyUFunc_GenericReduction(PyUFuncObject *ufunc, PyObject *args, axis = 0; } else if (axis < 0 || axis >= ndim) { - PyErr_SetString(PyExc_ValueError, - "'axis' entry is out of bounds"); + PyErr_Format(PyExc_ValueError, + "'axis' entry %d is out of bounds [-%d, %d)", + axis_orig, ndim, ndim); Py_XDECREF(otype); Py_DECREF(mp); return NULL; diff --git a/numpy/core/tests/test_ufunc.py b/numpy/core/tests/test_ufunc.py index eacc266be791..d175a8b6ed47 100644 --- a/numpy/core/tests/test_ufunc.py +++ b/numpy/core/tests/test_ufunc.py @@ -1073,6 +1073,9 @@ def test_reduce_arguments(self): assert_equal(f(d), r) # a, axis=0, dtype=None, out=None, keepdims=False assert_equal(f(d, axis=0), r) + assert_equal(f(d, axis=None), 10) + assert_equal(f(d, axis=(0, 1)), 10) + assert_equal(f(d, axis=(-1, -2)), 10) assert_equal(f(d, 0), r) assert_equal(f(d, 0, dtype=None), r) assert_equal(f(d, 0, dtype='i'), r) @@ -1098,6 +1101,12 @@ def test_reduce_arguments(self): assert_raises(TypeError, f, d, axis="invalid") assert_raises(TypeError, f, d, axis="invalid", dtype=None, keepdims=True) + assert_raises(TypeError, f, d, axis=[-1, -2]) + assert_raises(TypeError, f, d, axis=(-1, 'invalid')) + assert_raises(ValueError, f, d, axis=(0, -3)) + assert_raises(ValueError, f, d, axis=(0, 2)) + + # invalid dtype assert_raises(TypeError, f, d, 0, "invalid") assert_raises(TypeError, f, d, dtype="invalid")