MAINT: Use AxisError in more places by eric-wieser · Pull Request #8843 · numpy/numpy · GitHub

MAINT: Use AxisError in more places #8843

Merged
merged 9 commits into from
Mar 29, 2017
8 changes: 7 additions & 1 deletion numpy/add_newdocs.py
@@ -6739,7 +6739,7 @@ def luf(lamdaexpr, *args, **kwargs):

add_newdoc('numpy.core.multiarray', 'normalize_axis_index',
"""
normalize_axis_index(axis, ndim)
normalize_axis_index(axis, ndim, msg_prefix=None)

Normalizes an axis index, `axis`, such that is a valid positive index into
the shape of array with `ndim` dimensions. Raises an AxisError with an
@@ -6756,6 +6756,8 @@ def luf(lamdaexpr, *args, **kwargs):
ndim : int
The number of dimensions of the array that `axis` should be normalized
against
msg_prefix : str
A prefix to put before the message, typically the name of the argument

Returns
-------
@@ -6780,6 +6782,10 @@ def luf(lamdaexpr, *args, **kwargs):
Traceback (most recent call last):
...
AxisError: axis 3 is out of bounds for array of dimension 3
>>> normalize_axis_index(-4, ndim=3, msg_prefix='axes_arg')
Traceback (most recent call last):
...
AxisError: axes_arg: axis -4 is out of bounds for array of dimension 3
""")

##############################################################################
15 changes: 14 additions & 1 deletion numpy/core/_internal.py
@@ -631,4 +631,17 @@ class TooHardError(RuntimeError):
pass

class AxisError(ValueError, IndexError):
pass
""" Axis supplied was invalid. """
Member Author

Now with a docstring, which IPython shows when you type an exception name. This is in a similar style to the builtin ones.

Member

This should be in numpy/_globals.py.

Member

And imported into numpy/__init__.py.

Member Author
@eric-wieser eric-wieser Mar 28, 2017

I put it here in #8584 because it needs to be imported from C code, and that's what _internal seems to be used for. Is it safe to load _globals.py from C?

AxisError is already visible at the global scope as np.AxisError, as it goes through the __all__ chain.

Member

You can import from practically anything. We also import from multiarray itself into C code. I'd import this from "numpy". The _globals module is for singletons that may also be used by Python code.

Member Author

Should TooHardError move there too?

Member Author

Right now, numeric.py has from ._internal import TooHardError, AxisError

core/__init__.py also contains from . import _internal # for freeze programs, for some reason

Member

Probably, but it can be done in a separate PR as the imports also need to be changed. Maybe _globals.py should be renamed _global_singletons or _global_exceptions ;) There are other errors scattered about that maybe we should move at some point.

Member

Or maybe _numpy_exceptions

Member Author
@eric-wieser eric-wieser Mar 28, 2017

Do you want me to do both in a separate PR? AxisError being in that file is not new to this PR.

(either way, fix coming for something else, so don't merge yet!)

def __init__(self, axis, ndim=None, msg_prefix=None):
# single-argument form just delegates to base class
if ndim is None and msg_prefix is None:
msg = axis

# do the string formatting here, to save work in the C code
else:
msg = ("axis {} is out of bounds for array of dimension {}"
.format(axis, ndim))
if msg_prefix is not None:
msg = "{}: {}".format(msg_prefix, msg)

super(AxisError, self).__init__(msg)
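
Note for readers: the constructor above distinguishes three call forms. The following is a standalone sketch of it, re-typed so that it runs without NumPy; the zero-argument `super()` and the example values are my own substitutions, not part of the diff.

```python
class AxisError(ValueError, IndexError):
    """Axis supplied was invalid."""

    def __init__(self, axis, ndim=None, msg_prefix=None):
        # Single-argument form: `axis` is treated as a ready-made message.
        if ndim is None and msg_prefix is None:
            msg = axis
        else:
            # Formatting happens in Python so the C caller only passes values.
            msg = ("axis {} is out of bounds for array of dimension {}"
                   .format(axis, ndim))
            if msg_prefix is not None:
                msg = "{}: {}".format(msg_prefix, msg)
        super().__init__(msg)


plain = AxisError("some pre-formatted message")   # illustrative message
formatted = AxisError(3, 3)
prefixed = AxisError(-4, 3, "axes_arg")

print(plain)
print(formatted)
print(prefixed)
# Callers that catch either base class keep working.
print(isinstance(prefixed, ValueError), isinstance(prefixed, IndexError))
```

Because the class derives from both ValueError and IndexError, call sites that previously raised either of the two can switch to AxisError without breaking existing `except` clauses.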
81 changes: 59 additions & 22 deletions numpy/core/numeric.py
@@ -441,7 +441,7 @@ def count_nonzero(a, axis=None):
nullstr = a.dtype.type('')
return (a != nullstr).sum(axis=axis, dtype=np.intp)

axis = asarray(_validate_axis(axis, a.ndim, 'axis'))
axis = asarray(normalize_axis_tuple(axis, a.ndim))
counts = np.apply_along_axis(multiarray.count_nonzero, axis[0], a)

if axis.size == 1:
@@ -1460,16 +1460,14 @@ def roll(a, shift, axis=None):
return roll(a.ravel(), shift, 0).reshape(a.shape)

else:
axis = normalize_axis_tuple(axis, a.ndim, allow_duplicate=True)
broadcasted = broadcast(shift, axis)
if broadcasted.ndim > 1:
raise ValueError(
"'shift' and 'axis' should be scalars or 1D sequences")
shifts = {ax: 0 for ax in range(a.ndim)}
for sh, ax in broadcasted:
if -a.ndim <= ax < a.ndim:
shifts[ax % a.ndim] += sh
else:
raise ValueError("'axis' entry is out of bounds")
shifts[ax] += sh

rolls = [((slice(None), slice(None)),)] * a.ndim
for ax, offset in shifts.items():
@@ -1544,17 +1542,59 @@ def rollaxis(a, axis, start=0):
return a.transpose(axes)


def _validate_axis(axis, ndim, argname):
def normalize_axis_tuple(axis, ndim, argname=None, allow_duplicate=False):
Contributor

I slightly wonder -- mostly because of #8819 -- whether it would be useful to have this in C...

Member Author
@eric-wieser eric-wieser Mar 27, 2017

whether it would be useful to have this in C

Yes, I think it would - and I think we even already have it for ufunc.reduce. I think this PR is a good stepping stone to implementing that, since it pulls all the python codepaths through one point for parsing axes - that point can now be moved to C later.

"""
Normalizes an axis argument into a tuple of non-negative integer axes.

This handles shorthands such as ``1`` and converts them to ``(1,)``,
as well as performing the handling of negative indices covered by
`normalize_axis_index`.

By default, this forbids axes from being specified multiple times.

Used internally by multi-axis-checking logic.

.. versionadded:: 1.13.0

Parameters
----------
axis : int, iterable of int
The un-normalized index or indices of the axis.
ndim : int
The number of dimensions of the array that `axis` should be normalized
against.
argname : str, optional
A prefix to put before the error message, typically the name of the
argument.
allow_duplicate : bool, optional
If False, the default, disallow an axis from being specified twice.

Returns
-------
normalized_axes : tuple of int
The normalized axis index, such that `0 <= normalized_axis < ndim`

Raises
------
AxisError
If any axis provided is out of range
ValueError
If an axis is repeated

See also
--------
normalize_axis_index : normalizing a single scalar axis
"""
try:
axis = [operator.index(axis)]
except TypeError:
axis = list(axis)
axis = [a + ndim if a < 0 else a for a in axis]
if not builtins.all(0 <= a < ndim for a in axis):
raise AxisError('invalid axis for this array in `%s` argument' %
argname)
if len(set(axis)) != len(axis):
raise ValueError('repeated axis in `%s` argument' % argname)
axis = tuple(axis)
axis = tuple(normalize_axis_index(ax, ndim, argname) for ax in axis)
if not allow_duplicate and len(set(axis)) != len(axis):
if argname:
raise ValueError('repeated axis in `{}` argument'.format(argname))
else:
raise ValueError('repeated axis')
return axis
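
Note for readers: a rough, dependency-free sketch of the new version of this function, to show how every axis is funnelled through one checking point. The `normalize_axis_index` below is a hypothetical pure-Python stand-in for the C function and raises a plain IndexError instead of AxisError; the example inputs are illustrative.

```python
import operator


def normalize_axis_index(axis, ndim, msg_prefix=None):
    # Hypothetical stand-in for the C implementation (assumption: same
    # bounds rule, but IndexError instead of AxisError).
    if not -ndim <= axis < ndim:
        msg = ("axis {} is out of bounds for array of dimension {}"
               .format(axis, ndim))
        if msg_prefix is not None:
            msg = "{}: {}".format(msg_prefix, msg)
        raise IndexError(msg)
    return axis + ndim if axis < 0 else axis


def normalize_axis_tuple(axis, ndim, argname=None, allow_duplicate=False):
    # Accept a bare integer as shorthand for a one-element sequence.
    try:
        axis = [operator.index(axis)]
    except TypeError:
        axis = list(axis)
    # Every entry goes through the single-axis check.
    axis = tuple(normalize_axis_index(ax, ndim, argname) for ax in axis)
    if not allow_duplicate and len(set(axis)) != len(axis):
        if argname:
            raise ValueError('repeated axis in `{}` argument'.format(argname))
        raise ValueError('repeated axis')
    return axis


single = normalize_axis_tuple(-1, 3)
several = normalize_axis_tuple([0, -1], 3)
repeated = normalize_axis_tuple((0, 0, -3), 3, allow_duplicate=True)
print(single, several, repeated)
```

Note that duplicates are detected after normalization, so `0` and `-3` count as the same axis for a 3-dimensional array.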


@@ -1614,8 +1654,8 @@ def moveaxis(a, source, destination):
a = asarray(a)
transpose = a.transpose

source = _validate_axis(source, a.ndim, 'source')
destination = _validate_axis(destination, a.ndim, 'destination')
source = normalize_axis_tuple(source, a.ndim, 'source')
destination = normalize_axis_tuple(destination, a.ndim, 'destination')
if len(source) != len(destination):
raise ValueError('`source` and `destination` arguments must have '
'the same number of elements')
@@ -1752,11 +1792,9 @@ def cross(a, b, axisa=-1, axisb=-1, axisc=-1, axis=None):
a = asarray(a)
b = asarray(b)
# Check axisa and axisb are within bounds
axis_msg = "'axis{0}' out of bounds"
if axisa < -a.ndim or axisa >= a.ndim:
raise ValueError(axis_msg.format('a'))
if axisb < -b.ndim or axisb >= b.ndim:
raise ValueError(axis_msg.format('b'))
axisa = normalize_axis_index(axisa, a.ndim, msg_prefix='axisa')
axisb = normalize_axis_index(axisb, b.ndim, msg_prefix='axisb')

# Move working axis to the end of the shape
a = rollaxis(a, axisa, a.ndim)
b = rollaxis(b, axisb, b.ndim)
@@ -1770,8 +1808,7 @@ def cross(a, b, axisa=-1, axisb=-1, axisc=-1, axis=None):
if a.shape[-1] == 3 or b.shape[-1] == 3:
shape += (3,)
# Check axisc is within bounds
if axisc < -len(shape) or axisc >= len(shape):
raise ValueError(axis_msg.format('c'))
axisc = normalize_axis_index(axisc, len(shape), msg_prefix='axisc')
dtype = promote_types(a.dtype, b.dtype)
cp = empty(shape, dtype)

23 changes: 19 additions & 4 deletions numpy/core/src/multiarray/common.h
@@ -138,9 +138,11 @@ check_and_adjust_index(npy_intp *index, npy_intp max_item, int axis,
* Returns -1 and sets an exception if *axis is an invalid axis for
* an array of dimension ndim, otherwise adjusts it in place to be
* 0 <= *axis < ndim, and returns 0.
*
* msg_prefix: borrowed reference, a string to prepend to the message
*/
static NPY_INLINE int
check_and_adjust_axis(int *axis, int ndim)
check_and_adjust_axis_msg(int *axis, int ndim, PyObject *msg_prefix)
{
/* Check that index is valid, taking into account negative indices */
if (NPY_UNLIKELY((*axis < -ndim) || (*axis >= ndim))) {
@@ -149,6 +151,8 @@ check_and_adjust_axis(int *axis, int ndim)
* we don't have access to npy_cache_import here
*/
static PyObject *AxisError_cls = NULL;
PyObject *exc;

if (AxisError_cls == NULL) {
PyObject *mod = PyImport_ImportModule("numpy.core._internal");

@@ -158,9 +162,15 @@ check_and_adjust_axis(int *axis, int ndim)
}
}

PyErr_Format(AxisError_cls,
"axis %d is out of bounds for array of dimension %d",
*axis, ndim);
/* Invoke the AxisError constructor */
exc = PyObject_CallFunction(AxisError_cls, "iiO",
*axis, ndim, msg_prefix);
Member Author

This was the way I should have written this the first time...

if (exc == NULL) {
return -1;
}
PyErr_SetObject(AxisError_cls, exc);
Py_DECREF(exc);

return -1;
}
/* adjust negative indices */
@@ -169,6 +179,11 @@ check_and_adjust_axis(int *axis, int ndim)
}
return 0;
}
static NPY_INLINE int
check_and_adjust_axis(int *axis, int ndim)
{
return check_and_adjust_axis_msg(axis, ndim, Py_None);
Contributor

I think you have to Py_INCREF(Py_None) before passing it on. Since you then would have to Py_DECREF it as well, I suggest passing in NULL here, and doing the translation to Py_None above.

Contributor

Py_None is guaranteed to exist, and as long as it doesn't escape to the user and isn't decref'd, everything should be fine. PyObject_CallFunction increments its reference count for the exception, so what the user gets is a new reference.

Using NULL seems cleaner though.

Member Author
@eric-wieser eric-wieser Mar 28, 2017

I don't need to here, because check_and_adjust_axis_msg takes a borrowed reference. It doesn't actually need a refcount at all until it makes it to the python code, and PyObject_CallFunction() increments the refcount for me, AFAICT.

It took a lot of iteration to get this working, and attempts at doing what you describe only caused segfaults.

Member Author
@eric-wieser eric-wieser Mar 28, 2017

Using NULL seems cleaner though.

Why? We already need to handle Py_None being passed in from user code. Furthermore, when we get Py_None from user code, it's a borrowed reference, so, like here, it doesn't need an incref unless it escapes to Python code.

The only thing that's unclear to me is whether I need to INCREF msg_prefix on the line immediately before PyObject_CallFunction. I think the answer is no (it was no when I used Py_BuildValue here previously), but that's not clear from the documentation.

Contributor

@eric-wieser How would it cause segfaults? You could just initialize it as NULL and when you call the "exception" you do PyObject_CallFunction(AxisError_cls, "iiO", *axis, ndim, msg_prefix?msg_prefix:Py_None);

Member Author
@eric-wieser eric-wieser Mar 28, 2017

Segfaults came from other attempts at refcounting. I wasn't aiming to pass NULL originally. Sorry if that was unclear. You're correct: what you have would add support for passing NULL, and would work correctly assuming that this works correctly, but I see no purpose in adding that as a special case.

Contributor

Yeah, that's the reason why it's generally not a good idea to pass borrowed references around. But I didn't check the numpy source code, maybe it's common here. In that case please ignore my comments.

Member Author
@eric-wieser eric-wieser Mar 28, 2017

I guess my argument would be that if I manually inlined the contents of check_and_adjust_axis_msg into normalize_axis_index, you would not expect me to add any additional Py_INCREFs.

I'm not sure what you mean there - can you elaborate on exactly why it is a bad idea? (and perhaps link me to some recommendations for working with python refcounting)

Counter-point to it generally not being a good idea: every function in a PyMethodDef object is passed borrowed references, so Python seems to favor it internally.

Contributor

OK, this goes a bit off-topic. I still think the code is correct wrt reference counting.

Member Author

Some confusion added here by my first comment being written before yours appeared (and not directed at you), I think :).

}


/*
10 changes: 5 additions & 5 deletions numpy/core/src/multiarray/multiarraymodule.c
@@ -4106,16 +4106,16 @@ array_may_share_memory(PyObject *NPY_UNUSED(ignored), PyObject *args, PyObject *
static PyObject *
normalize_axis_index(PyObject *NPY_UNUSED(self), PyObject *args, PyObject *kwds)
{
static char *kwlist[] = {"axis", "ndim", NULL};
static char *kwlist[] = {"axis", "ndim", "msg_prefix", NULL};
int axis;
int ndim;
PyObject *msg_prefix = Py_None;
Contributor

Here, again, I think you cannot just use Py_None. Also solve by initializing to NULL?

Member Author

This is definitely correct. PyArg_ParseTuple does not INCREF O arguments, so this line is exactly equivalent to passing None from python.

The question remains if I should INCREF before calling check_and_adjust_axis_msg, but let's discuss that above.


if (!PyArg_ParseTupleAndKeywords(args, kwds, "ii:normalize_axis_index",
kwlist, &axis, &ndim)) {
if (!PyArg_ParseTupleAndKeywords(args, kwds, "ii|O:normalize_axis_index",
kwlist, &axis, &ndim, &msg_prefix)) {
return NULL;
}

if(check_and_adjust_axis(&axis, ndim) < 0) {
if (check_and_adjust_axis_msg(&axis, ndim, msg_prefix) < 0) {
return NULL;
}

14 changes: 7 additions & 7 deletions numpy/core/tests/test_numeric.py
@@ -2425,11 +2425,11 @@ def test_move_multiples(self):

def test_errors(self):
x = np.random.randn(1, 2, 3)
assert_raises_regex(np.AxisError, 'invalid axis .* `source`',
assert_raises_regex(np.AxisError, 'source.*out of bounds',
np.moveaxis, x, 3, 0)
assert_raises_regex(np.AxisError, 'invalid axis .* `source`',
assert_raises_regex(np.AxisError, 'source.*out of bounds',
np.moveaxis, x, -4, 0)
assert_raises_regex(np.AxisError, 'invalid axis .* `destination`',
assert_raises_regex(np.AxisError, 'destination.*out of bounds',
Member Author

Error message comparison:

ValueError: invalid axis for this array in `destination` argument
AxisError: destination: axis 5 is out of bounds for array of dimension 3

np.moveaxis, x, 0, 5)
assert_raises_regex(ValueError, 'repeated axis in `source`',
np.moveaxis, x, [0, 0], [0, 1])
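
Note for readers: the test patterns had to change because the message text changed. The sketch below checks the old and new patterns against the two messages quoted in the comment above, assuming `assert_raises_regex` uses regular-expression search semantics; the variable names are illustrative.

```python
import re

# Messages quoted in the review comment above.
old_msg = "invalid axis for this array in `destination` argument"
new_msg = "destination: axis 5 is out of bounds for array of dimension 3"

# Patterns from the old and new versions of the test.
old_pattern = 'invalid axis .* `destination`'
new_pattern = 'destination.*out of bounds'

old_matches_old = re.search(old_pattern, old_msg) is not None
old_matches_new = re.search(old_pattern, new_msg) is not None
new_matches_new = re.search(new_pattern, new_msg) is not None

print(old_matches_old, old_matches_new, new_matches_new)
```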
@@ -2517,13 +2517,13 @@ def test_broadcasting_shapes(self):
u = np.ones((10, 3, 5))
v = np.ones((2, 5))
assert_equal(np.cross(u, v, axisa=1, axisb=0).shape, (10, 5, 3))
assert_raises(ValueError, np.cross, u, v, axisa=1, axisb=2)
assert_raises(ValueError, np.cross, u, v, axisa=3, axisb=0)
assert_raises(np.AxisError, np.cross, u, v, axisa=1, axisb=2)
assert_raises(np.AxisError, np.cross, u, v, axisa=3, axisb=0)
u = np.ones((10, 3, 5, 7))
v = np.ones((5, 7, 2))
assert_equal(np.cross(u, v, axisa=1, axisc=2).shape, (10, 5, 3, 7))
assert_raises(ValueError, np.cross, u, v, axisa=-5, axisb=2)
assert_raises(ValueError, np.cross, u, v, axisa=1, axisb=-4)
assert_raises(np.AxisError, np.cross, u, v, axisa=-5, axisb=2)
assert_raises(np.AxisError, np.cross, u, v, axisa=1, axisb=-4)
# gh-5885
u = np.ones((3, 4, 2))
for axisc in range(-2, 2):
46 changes: 16 additions & 30 deletions numpy/lib/function_base.py
@@ -12,7 +12,7 @@
from numpy.core.numeric import (
ones, zeros, arange, concatenate, array, asarray, asanyarray, empty,
empty_like, ndarray, around, floor, ceil, take, dot, where, intp,
integer, isscalar, absolute
integer, isscalar, absolute, AxisError
)
from numpy.core.umath import (
pi, multiply, add, arctan2, frompyfunc, cos, less_equal, sqrt, sin,
@@ -1679,21 +1679,10 @@ def gradient(f, *varargs, **kwargs):
axes = kwargs.pop('axis', None)
if axes is None:
axes = tuple(range(N))
# check axes to have correct type and no duplicate entries
if isinstance(axes, int):
axes = (axes,)
if not isinstance(axes, tuple):
raise TypeError("A tuple of integers or a single integer is required")

# normalize axis values:
axes = tuple(x + N if x < 0 else x for x in axes)
if max(axes) >= N or min(axes) < 0:
raise ValueError("'axis' entry is out of bounds")
else:
axes = _nx.normalize_axis_tuple(axes, N)

len_axes = len(axes)
if len(set(axes)) != len_axes:
raise ValueError("duplicate value in 'axis'")

n = len(varargs)
if n == 0:
dx = [1.0] * len_axes
@@ -3983,21 +3972,15 @@ def _ureduce(a, func, **kwargs):
if axis is not None:
keepdim = list(a.shape)
nd = a.ndim
try:
axis = operator.index(axis)
if axis >= nd or axis < -nd:
raise IndexError("axis %d out of bounds (%d)" % (axis, a.ndim))
keepdim[axis] = 1
except TypeError:
sax = set()
for x in axis:
if x >= nd or x < -nd:
raise IndexError("axis %d out of bounds (%d)" % (x, nd))
if x in sax:
raise ValueError("duplicate value in axis")
sax.add(x % nd)
keepdim[x] = 1
keep = sax.symmetric_difference(frozenset(range(nd)))
axis = _nx.normalize_axis_tuple(axis, nd)

for ax in axis:
keepdim[ax] = 1

if len(axis) == 1:
kwargs['axis'] = axis[0]
else:
keep = set(range(nd)) - set(axis)
nkeep = len(keep)
# swap axis that should not be reduced to front
for i, s in enumerate(sorted(keep)):
@@ -4753,7 +4736,8 @@ def delete(arr, obj, axis=None):
if ndim != 1:
arr = arr.ravel()
ndim = arr.ndim
axis = ndim - 1
axis = -1

if ndim == 0:
# 2013-09-24, 1.9
warnings.warn(
@@ -4764,6 +4748,8 @@ def delete(arr, obj, axis=None):
else:
return arr.copy(order=arrorder)

axis = normalize_axis_index(axis, ndim)
Member Author

This guards against IndexError


slobj = [slice(None)]*ndim
N = arr.shape[axis]
newshape = list(arr.shape)
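
Note for readers: my reading of the comment above is that, without the added call, an out-of-range axis would presumably surface as a bare IndexError from `arr.shape[axis]`. The sketch below illustrates that difference with a plain tuple in place of `arr.shape`; `check_axis` is a hypothetical stand-in for `normalize_axis_index` and raises IndexError only because AxisError is not defined here.

```python
shape = (2, 3)     # stands in for arr.shape
ndim = len(shape)
axis = 5           # deliberately out of bounds


def check_axis(axis, ndim):
    # Hypothetical stand-in for normalize_axis_index.
    if not -ndim <= axis < ndim:
        raise IndexError(
            "axis {} is out of bounds for array of dimension {}"
            .format(axis, ndim))
    return axis % ndim


# Without the guard, the failure comes from tuple indexing.
try:
    shape[axis]
    unguarded = None
except IndexError as exc:
    unguarded = str(exc)

# With the guard, the check fails first and the message names the axis.
try:
    shape[check_axis(axis, ndim)]
    guarded = None
except IndexError as exc:
    guarded = str(exc)

print(unguarded)
print(guarded)
```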