8000 MAINT: Use the same exception for all bad axis requests by eric-wieser · Pull Request #8584 · numpy/numpy · GitHub
[go: up one dir, main page]

Skip to content

MAINT: Use the same exception for all bad axis requests #8584

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Feb 21, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 45 additions & 0 deletions numpy/add_newdocs.py
Original file line number Diff line number Diff line change
Expand Up @@ -6728,6 +6728,51 @@ def luf(lamdaexpr, *args, **kwargs):
53
""")

add_newdoc('numpy.core.multiarray', 'normalize_axis_index',
"""
normalize_axis_index(axis, ndim)

Normalizes an axis index, `axis`, such that is a valid positive index into
the shape of array with `ndim` dimensions. Raises an AxisError with an
appropriate message if this is not possible.

Used internally by all axis-checking logic.

.. versionadded:: 1.13.0

Parameters
----------
axis : int
The un-normalized index of the axis. Can be negative
ndim : int
The number of dimensions of the array that `axis` should be normalized
against

Returns
-------
normalized_axis : int
The normalized axis index, such that `0 <= normalized_axis < ndim`

Raises
------
AxisError
If the axis index is invalid, when `-ndim <= axis < ndim` is false.

Examples
--------
>>> normalize_axis_index(0, ndim=3)
0
>>> normalize_axis_index(1, ndim=3)
1
>>> normalize_axis_index(-1, ndim=3)
2

>>> normalize_axis_index(3, ndim=3)
Traceback (most recent call last):
...
AxisError: axis 3 is out of bounds for array of dimension 3
""")

##############################################################################
#
# nd_grid instances
Expand Down
3 changes: 3 additions & 0 deletions numpy/core/_internal.py
Original file line number Diff line number Diff line change
Expand Up @@ -630,3 +630,6 @@ def _gcd(a, b):
# Exception used in shares_memory()
class TooHardError(RuntimeError):
pass

class AxisError(ValueError, IndexError):
pass
15 changes: 6 additions & 9 deletions numpy/core/numeric.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
inner, int_asbuffer, lexsort, matmul, may_share_memory,
min_scalar_type, ndarray, nditer, nested_iters, promote_types,
putmask, result_type, set_numeric_ops, shares_memory, vdot, where,
zeros)
zeros, normalize_axis_index)
if sys.version_info[0] < 3:
from .multiarray import newbuffer, getbuffer

Expand All @@ -27,7 +27,7 @@
ERR_DEFAULT, PINF, NAN)
from . import numerictypes
from .numerictypes import longlong, intc, int_, float_, complex_, bool_
from ._internal import TooHardError
from ._internal import TooHardError, AxisError

bitwise_not = invert
ufunc = type(sin)
Expand Down Expand Up @@ -65,7 +65,7 @@
'True_', 'bitwise_not', 'CLIP', 'RAISE', 'WRAP', 'MAXDIMS', 'BUFSIZE',
'ALLOW_THREADS', 'ComplexWarning', 'full', 'full_like', 'matmul',
'shares_memory', 'may_share_memory', 'MAY_SHARE_BOUNDS', 'MAY_SHARE_EXACT',
'TooHardError',
'TooHardError', 'AxisError'
]


Expand Down Expand Up @@ -1527,15 +1527,12 @@ def rollaxis(a, axis, start=0):

"""
n = a.ndim
if axis < 0:
axis += n
axis = normalize_axis_index(axis, n)
if start < 0:
start += n
msg = "'%s' arg requires %d <= %s < %d, but %d was passed in"
if not (0 <= axis < n):
raise ValueError(msg % ('axis', -n, 'axis', n, axis))
if not (0 <= start < n + 1):
raise ValueError(msg % ('start', -n, 'start', n + 1, start))
raise AxisError(msg % ('start', -n, 'start', n + 1, start))
if axis < start:
# it's been removed
start -= 1
Expand All @@ -1554,7 +1551,7 @@ def _validate_axis(axis, ndim, argname):
axis = list(axis)
axis = [a + ndim if a < 0 else a for a in axis]
if not builtins.all(0 <= a < ndim for a in axis):
raise ValueError('invalid axis for this array in `%s` argument' %
raise AxisError('invalid axis for this array in `%s` argument' %
argname)
if len(set(axis)) != len(axis):
raise ValueError('repeated axis in `%s` argument' % argname)
Expand Down
7 changes: 2 additions & 5 deletions numpy/core/shape_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from . import numeric as _nx
from .numeric import asanyarray, newaxis
from .multiarray import normalize_axis_index

def atleast_1d(*arys):
"""
Expand Down Expand Up @@ -347,11 +348,7 @@ def stack(arrays, axis=0):
raise ValueError('all input arrays must have the same shape')

result_ndim = arrays[0].ndim + 1
if not -result_ndim <= axis < result_ndim:
msg = 'axis {0} out of bounds [-{1}, {1})'.format(axis, result_ndim)
raise IndexError(msg)
if axis < 0:
axis += result_ndim
axis = normalize_axis_index(axis, result_ndim)

sl = (slice(None),) * axis + (_nx.newaxis,)
expanded_arrays = [arr[sl] for arr in arrays]
Expand Down
36 changes: 36 additions & 0 deletions numpy/core/src/multiarray/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,42 @@ check_and_adjust_index(npy_intp *index, npy_intp max_item, int axis,
return 0;
}

/*
* Returns -1 and sets an exception if *axis is an invalid axis for
* an array of dimension ndim, otherwise adjusts it in place to be
* 0 <= *axis < ndim, and returns 0.
*/
static NPY_INLINE int
check_and_adjust_axis(int *axis, int ndim)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

adjust seems to imply something always happens; how about check_and_normalize_axis. Or just normalize_axis?

{
/* Check that index is valid, taking into account negative indices */
if (NPY_UNLIKELY((*axis < -ndim) || (*axis >= ndim))) {
/*
* Load the exception type, if we don't already have it. Unfortunately
* we don't have access to npy_cache_import here
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

More for my education than anything else, is there a reason not just to #include <npy_import.h> at the top?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a public header, but I think npy_import is private. So when another module includes this header, it can no longer find npy_import.

That could be fixed by moving this to common.c, but I don't know if the inlining would then work

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since every problem supposedly is solvable with another layer of redirection, could one call an error-producing function in common.c here, where that function raises the error? (But maybe this is very much not worth it...)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Arguably, the inlining is not important here, as this isn't nearly as critical a path as check_and_adjust_index

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, how about then moving it to common.c? Somewhat nicer anyway to have c code reside in *.c files...

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Main argument for leaving it here is it gives clear contrast with check_and_adjust_index

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

True; fine to keep as is too even if it duplicates a bit of code.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@eric-wieser did you try to include npy_import.h? It should work. Private here means private to numpy. The files therein should be available to everything in numpy/core/src.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the problem is that something outside numpy/core/src tries to include "common.h", which then doesn't have access to npy_include. I remember it not building, but I forget the exact error.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And just to be clear, the "public" headers are in numpy/core/include/numpy.

*/
static PyObject *AxisError_cls = NULL;
if (AxisError_cls == NULL) {
PyObject *mod = PyImport_ImportModule("numpy.core._internal");

if (mod != NULL) {
AxisError_cls = PyObject_GetAttrString(mod, "AxisError");
Py_DECREF(mod);
}
}

PyErr_Format(AxisError_cls,
"axis %d is out of bounds for array of dimension %d",
*axis, ndim);
return -1;
}
/* adjust negative indices */
if (*axis < 0) {
*axis += ndim;
}
return 0;
}


/*
* return true if pointer is aligned to 'alignment'
Expand Down
20 changes: 3 additions & 17 deletions numpy/core/src/multiarray/conversion_utils.c
Original file line number Diff line number Diff line change
Expand Up @@ -259,17 +259,10 @@ PyArray_ConvertMultiAxis(PyObject *axis_in, int ndim, npy_bool *out_axis_flags)
PyObject *tmp = PyTuple_GET_ITEM(axis_in, i);
int axis = PyArray_PyIntAsInt_ErrMsg(tmp,
"integers are required for the axis tuple elements");
int axis_orig = axis;
if (error_converting(axis)) {
return NPY_FAIL;
}
if (axis < 0) {
axis += ndim;
}
if (axis < 0 || axis >= ndim) {
PyErr_Format(PyExc_ValueError,
"'axis' entry %d is out of bounds [-%d, %d)",
axis_orig, ndim, ndim);
if (check_and_adjust_axis(&axis, ndim) < 0) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here and below: check for status using != 0? Or just omit the comparison altogether (as above in error_converting(axis)

return NPY_FAIL;
}
if (out_axis_flags[axis]) {
Expand All @@ -284,20 +277,16 @@ PyArray_ConvertMultiAxis(PyObject *axis_in, int ndim, npy_bool *out_axis_flags)
}
/* Try to interpret axis as an integer */
else {
int axis, axis_orig;
int axis;

memset(out_axis_flags, 0, ndim);

axis = PyArray_PyIntAsInt_ErrMsg(axis_in,
"an integer is required for the axis");
axis_orig = axis;

if (error_converting(axis)) {
return NPY_FAIL;
B41A }
if (axis < 0) {
axis += ndim;
}
/*
* Special case letting axis={-1,0} slip through for scalars,
* for backwards compatibility reasons.
Expand All @@ -306,10 +295,7 @@ PyArray_ConvertMultiAxis(PyObject *axis_in, int ndim, npy_bool *out_axis_flags)
return NPY_SUCCEED;
}

if (axis < 0 || axis >= ndim) {
PyErr_Format(PyExc_ValueError,
"'axis' entry %d is out of bounds [-%d, %d)",
axis_orig, ndim, ndim);
if (check_and_adjust_axis(&axis, ndim) < 0) {
return NPY_FAIL;
}

Expand Down
8 changes: 1 addition & 7 deletions numpy/core/src/multiarray/ctors.c
Original file line number Diff line number Diff line change
Expand Up @@ -2793,7 +2793,6 @@ PyArray_CheckAxis(PyArrayObject *arr, int *axis, int flags)
{
PyObject *temp1, *temp2;
int n = PyArray_NDIM(arr);
int axis_orig = *axis;

if (*axis == NPY_MAXDIMS || n == 0) {
if (n != 1) {
Expand Down Expand Up @@ -2831,12 +2830,7 @@ PyArray_CheckAxis(PyArrayObject *arr, int *axis, int flags)
temp2 = (PyObject *)temp1;
}
n = PyArray_NDIM((PyArrayObject *)temp2);
if (*axis < 0) {
*axis += n;
}
if ((*axis < 0) || (*axis >= n)) {
PyErr_Format(PyExc_ValueError,
"axis(=%d) out of bounds", axis_orig);
if (check_and_adjust_axis(axis, n) < 0) {
Py_DECREF(temp2);
return NULL;
}
Expand Down
25 changes: 6 additions & 19 deletions numpy/core/src/multiarray/item_selection.c
Original file line number Diff line number Diff line change
Expand Up @@ -1101,16 +1101,12 @@ NPY_NO_EXPORT int
PyArray_Sort(PyArrayObject *op, int axis, NPY_SORTKIND which)
{
PyArray_SortFunc *sort;
int axis_orig = axis;
int n = PyArray_NDIM(op);
int n = PyArray_NDIM(op);

if (axis < 0) {
axis += n;
}
if (axis < 0 || axis >= n) {
PyErr_Format(PyExc_ValueError, "axis(=%d) out of bounds", axis_orig);
if (check_and_adjust_axis(&axis, n) < 0) {
return -1;
}

if (PyArray_FailUnlessWriteable(op, "sort array") < 0) {
return -1;
}
Expand Down Expand Up @@ -1212,17 +1208,13 @@ PyArray_Partition(PyArrayObject *op, PyArrayObject * ktharray, int axis,
PyArrayObject *kthrvl;
PyArray_PartitionFunc *part;
PyArray_SortFunc *sort;
int axis_orig = axis;
int n = PyArray_NDIM(op);
int ret;

if (axis < 0) {
axis += n;
}
if (axis < 0 || axis >= n) {
PyErr_Format(PyExc_ValueError, "axis(=%d) out of bounds", axis_orig);
if (check_and_adjust_axis(&axis, n) < 0) {
return -1;
}

if (PyArray_FailUnlessWriteable(op, "partition array") < 0) {
return -1;
}
Expand Down Expand Up @@ -1455,12 +1447,7 @@ PyArray_LexSort(PyObject *sort_keys, int axis)
*((npy_intp *)(PyArray_DATA(ret))) = 0;
goto finish;
}
if (axis < 0) {
axis += nd;
}
if ((axis < 0) || (axis >= nd)) {
PyErr_Format(PyExc_ValueError,
"axis(=%d) out of bounds", axis);
if (check_and_adjust_axis(&axis, nd) < 0) {
goto fail;
}

Expand Down
29 changes: 21 additions & 8 dele 10000 tions numpy/core/src/multiarray/multiarraymodule.c
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,6 @@ PyArray_ConcatenateArrays(int narrays, PyArrayObject **arrays, int axis)
PyArray_Descr *dtype = NULL;
PyArrayObject *ret = NULL;
PyArrayObject_fields *sliding_view = NULL;
int orig_axis = axis;

if (narrays <= 0) {
PyErr_SetString(PyExc_ValueError,
Expand All @@ -345,13 +344,7 @@ PyArray_ConcatenateArrays(int narrays, PyArrayObject **arrays, int axis)
}

/* Handle standard Python negative indexing */
if (axis < 0) {
axis += ndim;
}

if (axis < 0 || axis >= ndim) {
PyErr_Format(PyExc_IndexError,
"axis %d out of bounds [0, %d)", orig_axis, ndim);
if (check_and_adjust_axis(&axis, ndim) < 0) {
return NULL;
}

Expand Down Expand Up @@ -4109,6 +4102,24 @@ array_may_share_memory(PyObject *NPY_UNUSED(ignored), PyObject *args, PyObject *
return array_shares_memory_impl(args, kwds, NPY_MAY_SHARE_BOUNDS, 0);
}

static PyObject *
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I presume this is for follow up in python code? If so, I think it should be part of a follow-up PR, not be done here (if only to give a chance to discuss whether it is best to go through C for something like this).

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've added the python stuff to this PR now...

normalize_axis_index(PyObject *NPY_UNUSED(self), PyObject *args, PyObject *kwds)
{
static char *kwlist[] = {"axis", "ndim", NULL};
int axis;
int ndim;

if (!PyArg_ParseTupleAndKeywords(args, kwds, "ii", kwlist,
&axis, &ndim)) {
return NULL;
}

if(check_and_adjust_axis(&axis, ndim) < 0) {
return NULL;
}

return PyInt_FromLong(axis);
}

static struct PyMethodDef array_module_methods[] = {
{"_get_ndarray_c_version",
Expand Down Expand Up @@ -4284,6 +4295,8 @@ static struct PyMethodDef array_module_methods[] = {
METH_VARARGS | METH_KEYWORDS, NULL},
{"unpackbits", (PyCFunction)io_unpack,
METH_VARARGS | METH_KEYWORDS, NULL},
{"normalize_axis_index", (PyCFunction)normalize_axis_index,
METH_VARARGS | METH_KEYWORDS, NULL},
{NULL, NULL, 0, NULL} /* sentinel */
};

Expand Down
7 changes: 1 addition & 6 deletions numpy/core/src/multiarray/shape.c
Original file line number Diff line number Diff line change
Expand Up @@ -705,12 +705,7 @@ PyArray_Transpose(PyArrayObject *ap, PyArray_Dims *permute)
}
for (i = 0; i < n; i++) {
axis = axes[i];
if (axis < 0) {
axis = PyArray_NDIM(ap) + axis;
}
if (axis < 0 || axis >= PyArray_NDIM(ap)) {
PyErr_SetString(PyExc_ValueError,
"invalid axis for this array");
if (check_and_adjust_axis(&axis, PyArray_NDIM(ap)) < 0) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

axis is npy_intp, not int.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe useful to double check this everywhere.

Copy link
Member Author
@eric-wieser eric-wieser Feb 23, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Which do you think I should accept in check_and_adjust_axis? Sometimes axis can be int, sometimes it is intp. Obviously I can introduce some locals for conversion, but I still need to choose which is right

return NULL;
}
if (reverse_permutation[axis] != -1) {
Expand Down
Loading
0