-
-
Notifications
You must be signed in to change notification settings - Fork 10.9k
MAINT: Use AxisError in more places #8843
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
7bf9115
4feb97e
09d4d35
cd01f01
efa1bd2
b6850e9
539d4f7
e3ed705
17466ad
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -441,7 +441,7 @@ def count_nonzero(a, axis=None): | |
nullstr = a.dtype.type('') | ||
return (a != nullstr).sum(axis=axis, dtype=np.intp) | ||
|
||
axis = asarray(_validate_axis(axis, a.ndim, 'axis')) | ||
axis = asarray(normalize_axis_tuple(axis, a.ndim)) | ||
counts = np.apply_along_axis(multiarray.count_nonzero, axis[0], a) | ||
|
||
if axis.size == 1: | ||
|
@@ -1460,16 +1460,14 @@ def roll(a, shift, axis=None): | |
return roll(a.ravel(), shift, 0).reshape(a.shape) | ||
|
||
else: | ||
axis = normalize_axis_tuple(axis, a.ndim, allow_duplicate=True) | ||
broadcasted = broadcast(shift, axis) | ||
if broadcasted.ndim > 1: | ||
raise ValueError( | ||
"'shift' and 'axis' should be scalars or 1D sequences") | ||
shifts = {ax: 0 for ax in range(a.ndim)} | ||
for sh, ax in broadcasted: | ||
if -a.ndim <= ax < a.ndim: | ||
shifts[ax % a.ndim] += sh | ||
else: | ||
raise ValueError("'axis' entry is out of bounds") | ||
shifts[ax] += sh | ||
|
||
rolls = [((slice(None), slice(None)),)] * a.ndim | ||
for ax, offset in shifts.items(): | ||
|
@@ -1544,17 +1542,59 @@ def rollaxis(a, axis, start=0): | |
return a.transpose(axes) | ||
|
||
|
||
def _validate_axis(axis, ndim, argname): | ||
def normalize_axis_tuple(axis, ndim, argname=None, allow_duplicate=False): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I slightly wonder -- mostly because of #8819 -- whether it would be useful to have this in C... There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Yes, I think it would - and I think we even already have it for |
||
""" | ||
Normalizes an axis argument into a tuple of non-negative integer axes. | ||
|
||
This handles shorthands such as ``1`` and converts them to ``(1,)``, | ||
as well as performing the handling of negative indices covered by | ||
`normalize_axis_index`. | ||
|
||
By default, this forbids axes from being specified multiple times. | ||
|
||
Used internally by multi-axis-checking logic. | ||
|
||
.. versionadded:: 1.13.0 | ||
|
||
Parameters | ||
---------- | ||
axis : int, iterable of int | ||
The un-normalized index or indices of the axis. | ||
ndim : int | ||
The number of dimensions of the array that `axis` should be normalized | ||
against. | ||
argname : str, optional | ||
A prefix to put before the error message, typically the name of the | ||
argument. | ||
allow_duplicate : bool, optional | ||
If False, the default, disallow an axis from being specified twice. | ||
|
||
Returns | ||
------- | ||
normalized_axes : tuple of int | ||
The normalized axis index, such that `0 <= normalized_axis < ndim` | ||
|
||
Raises | ||
------ | ||
AxisError | ||
If any axis provided is out of range | ||
ValueError | ||
If an axis is repeated | ||
|
||
See also | ||
-------- | ||
normalize_axis_index : normalizing a single scalar axis | ||
""" | ||
try: | ||
axis = [operator.index(axis)] | ||
except TypeError: | ||
axis = list(axis) | ||
axis = [a + ndim if a < 0 else a for a in axis] | ||
if not builtins.all(0 <= a < ndim for a in axis): | ||
raise AxisError('invalid axis for this array in `%s` argument' % | ||
argname) | ||
if len(set(axis)) != len(axis): | ||
raise ValueError('repeated axis in `%s` argument' % argname) | ||
axis = tuple(axis) | ||
axis = tuple(normalize_axis_index(ax, ndim, argname) for ax in axis) | ||
if not allow_duplicate and len(set(axis)) != len(axis): | ||
if argname: | ||
raise ValueError('repeated axis in `{}` argument'.format(argname)) | ||
else: | ||
raise ValueError('repeated axis') | ||
return axis | ||
|
||
|
||
|
@@ -1614,8 +1654,8 @@ def moveaxis(a, source, destination): | |
a = asarray(a) | ||
transpose = a.transpose | ||
|
||
source = _validate_axis(source, a.ndim, 'source') | ||
destination = _validate_axis(destination, a.ndim, 'destination') | ||
source = normalize_axis_tuple(source, a.ndim, 'source') | ||
destination = normalize_axis_tuple(destination, a.ndim, 'destination') | ||
if len(source) != len(destination): | ||
raise ValueError('`source` and `destination` arguments must have ' | ||
'the same number of elements') | ||
|
@@ -1752,11 +1792,9 @@ def cross(a, b, axisa=-1, axisb=-1, axisc=-1, axis=None): | |
a = asarray(a) | ||
b = asarray(b) | ||
# Check axisa and axisb are within bounds | ||
axis_msg = "'axis{0}' out of bounds" | ||
if axisa < -a.ndim or axisa >= a.ndim: | ||
raise ValueError(axis_msg.format('a')) | ||
if axisb < -b.ndim or axisb >= b.ndim: | ||
raise ValueError(axis_msg.format('b')) | ||
axisa = normalize_axis_index(axisa, a.ndim, msg_prefix='axisa') | ||
axisb = normalize_axis_index(axisb, b.ndim, msg_prefix='axisb') | ||
|
||
# Move working axis to the end of the shape | ||
a = rollaxis(a, axisa, a.ndim) | ||
b = rollaxis(b, axisb, b.ndim) | ||
|
@@ -1770,8 +1808,7 @@ def cross(a, b, axisa=-1, axisb=-1, axisc=-1, axis=None): | |
if a.shape[-1] == 3 or b.shape[-1] == 3: | ||
shape += (3,) | ||
# Check axisc is within bounds | ||
if axisc < -len(shape) or axisc >= len(shape): | ||
raise ValueError(axis_msg.format('c')) | ||
axisc = normalize_axis_index(axisc, len(shape), msg_prefix='axisc') | ||
dtype = promote_types(a.dtype, b.dtype) | ||
cp = empty(shape, dtype) | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -138,9 +138,11 @@ check_and_adjust_index(npy_intp *index, npy_intp max_item, int axis, | |
* Returns -1 and sets an exception if *axis is an invalid axis for | ||
* an array of dimension ndim, otherwise adjusts it in place to be | ||
* 0 <= *axis < ndim, and returns 0. | ||
* | ||
* msg_prefix: borrowed reference, a string to prepend to the message | ||
*/ | ||
static NPY_INLINE int | ||
check_and_adjust_axis(int *axis, int ndim) | ||
check_and_adjust_axis_msg(int *axis, int ndim, PyObject *msg_prefix) | ||
{ | ||
/* Check that index is valid, taking into account negative indices */ | ||
if (NPY_UNLIKELY((*axis < -ndim) || (*axis >= ndim))) { | ||
|
@@ -149,6 +151,8 @@ check_and_adjust_axis(int *axis, int ndim) | |
* we don't have access to npy_cache_import here | ||
*/ | ||
static PyObject *AxisError_cls = NULL; | ||
PyObject *exc; | ||
|
||
if (AxisError_cls == NULL) { | ||
PyObject *mod = PyImport_ImportModule("numpy.core._internal"); | ||
|
||
|
@@ -158,9 +162,15 @@ check_and_adjust_axis(int *axis, int ndim) | |
} | ||
} | ||
|
||
PyErr_Format(AxisError_cls, | ||
"axis %d is out of bounds for array of dimension %d", | ||
*axis, ndim); | ||
/* Invoke the AxisError constructor */ | ||
exc = PyObject_CallFunction(AxisError_cls, "iiO", | ||
*axis, ndim, msg_prefix); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This was the way I should have written this the first time... |
||
if (exc == NULL) { | ||
return -1; | ||
} | ||
PyErr_SetObject(AxisError_cls, exc); | ||
Py_DECREF(exc); | ||
|
||
return -1; | ||
} | ||
/* adjust negative indices */ | ||
|
@@ -169,6 +179,11 @@ check_and_adjust_axis(int *axis, int ndim) | |
} | ||
return 0; | ||
} | ||
static NPY_INLINE int | ||
check_and_adjust_axis(int *axis, int ndim) | ||
{ | ||
return check_and_adjust_axis_msg(axis, ndim, Py_None); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think you have to There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Using There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't need to here, because It took a lot of iteration to get this working, and attempts at doing what you describe only caused segfaults There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Why? We already need to handle The only thing that's unclear to me is whether I need to There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @eric-wieser How would it cause segfaults? You could just initialize it as There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Segfaults came from other attempts at refcounting. I wasn't aiming to be passing null originally. Sorry if that was unclear. You're correct, what you have would add support for passing null, and would work correctly assuming that this works correctly - but I see no purpose in adding that as a special case There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah, that's the reason why it's generally not a good idea to pass borrowed references around. But I didn't check the numpy source code, maybe it's common here. In that case please ignore my comments. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I guess my argument would be that if I manually inlined the contents of I'm not sure what you mean there - can you elaborate on exactly why it is a bad idea? (and perhaps link me to some recommendations for working with python refcounting) Counter-point to it generally not being a good idea - every function in a There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ok, this goes a bit off-topic. I still think the code is correct wrt reference counting. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Some confusion added here by my first comment being written before yours appeared (and not directed at you), I think :). |
||
} | ||
|
||
|
||
/* | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4106,16 +4106,16 @@ array_may_share_memory(PyObject *NPY_UNUSED(ignored), PyObject *args, PyObject * | |
static PyObject * | ||
normalize_axis_index(PyObject *NPY_UNUSED(self), PyObject *args, PyObject *kwds) | ||
{ | ||
static char *kwlist[] = {"axis", "ndim", NULL}; | ||
static char *kwlist[] = {"axis", "ndim", "msg_prefix", NULL}; | ||
int axis; | ||
int ndim; | ||
PyObject *msg_prefix = Py_None; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Here, again, I think you cannot just use There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is definitely correct. The question remains if I should |
||
|
||
if (!PyArg_ParseTupleAndKeywords(args, kwds, "ii:normalize_axis_index", | ||
kwlist, &axis, &ndim)) { | ||
if (!PyArg_ParseTupleAndKeywords(args, kwds, "ii|O:normalize_axis_index", | ||
kwlist, &axis, &ndim, &msg_prefix)) { | ||
return NULL; | ||
} | ||
|
||
if(check_and_adjust_axis(&axis, ndim) < 0) { | ||
if (check_and_adjust_axis_msg(&axis, ndim, msg_prefix) < 0) { | ||
return NULL; | ||
} | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2425,11 +2425,11 @@ def test_move_multiples(self): | |
|
||
def test_errors(self): | ||
x = np.random.randn(1, 2, 3) | ||
assert_raises_regex(np.AxisError, 'invalid axis .* `source`', | ||
assert_raises_regex(np.AxisError, 'source.*out of bounds', | ||
np.moveaxis, x, 3, 0) | ||
assert_raises_regex(np.AxisError, 'invalid axis .* `source`', | ||
assert_raises_regex(np.AxisError, 'source.*out of bounds', | ||
np.moveaxis, x, -4, 0) | ||
assert_raises_regex(np.AxisError, 'invalid axis .* `destination`', | ||
assert_raises_regex(np.AxisError, 'destination.*out of bounds', | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Error message comparison:
|
||
np.moveaxis, x, 0, 5) | ||
assert_raises_regex(ValueError, 'repeated axis in `source`', | ||
np.moveaxis, x, [0, 0], [0, 1]) | ||
|
@@ -2517,13 +2517,13 @@ def test_broadcasting_shapes(self): | |
u = np.ones((10, 3, 5)) | ||
v = np.ones((2, 5)) | ||
assert_equal(np.cross(u, v, axisa=1, axisb=0).shape, (10, 5, 3)) | ||
assert_raises(ValueError, np.cross, u, v, axisa=1, axisb=2) | ||
assert_raises(ValueError, np.cross, u, v, axisa=3, axisb=0) | ||
assert_raises(np.AxisError, np.cross, u, v, axisa=1, axisb=2) | ||
assert_raises(np.AxisError, np.cross, u, v, axisa=3, axisb=0) | ||
u = np.ones((10, 3, 5, 7)) | ||
v = np.ones((5, 7, 2)) | ||
assert_equal(np.cross(u, v, axisa=1, axisc=2).shape, (10, 5, 3, 7)) | ||
assert_raises(ValueError, np.cross, u, v, axisa=-5, axisb=2) | ||
assert_raises(ValueError, np.cross, u, v, axisa=1, axisb=-4) | ||
assert_raises(np.AxisError, np.cross, u, v, axisa=-5, axisb=2) | ||
assert_raises(np.AxisError, np.cross, u, v, axisa=1, axisb=-4) | ||
# gh-5885 | ||
u = np.ones((3, 4, 2)) | ||
for axisc in range(-2, 2): | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -12,7 +12,7 @@ | |
from numpy.core.numeric import ( | ||
ones, zeros, arange, concatenate, array, asarray, asanyarray, empty, | ||
empty_like, ndarray, around, floor, ceil, take, dot, where, intp, | ||
integer, isscalar, absolute | ||
integer, isscalar, absolute, AxisError | ||
) | ||
from numpy.core.umath import ( | ||
pi, multiply, add, arctan2, frompyfunc, cos, less_equal, sqrt, sin, | ||
|
@@ -1679,21 +1679,10 @@ def gradient(f, *varargs, **kwargs): | |
axes = kwargs.pop('axis', None) | ||
if axes is None: | ||
axes = tuple(range(N)) | ||
# check axes to have correct type and no duplicate entries | ||
if isinstance(axes, int): | ||
axes = (axes,) | ||
if not isinstance(axes, tuple): | ||
raise TypeError("A tuple of integers or a single integer is required") | ||
|
||
# normalize axis values: | ||
axes = tuple(x + N if x < 0 else x for x in axes) | ||
if max(axes) >= N or min(axes) < 0: | ||
raise ValueError("'axis' entry is out of bounds") | ||
else: | ||
axes = _nx.normalize_axis_tuple(axes, N) | ||
|
||
len_axes = len(axes) | ||
if len(set(axes)) != len_axes: | ||
raise ValueError("duplicate value in 'axis'") | ||
|
||
n = len(varargs) | ||
if n == 0: | ||
dx = [1.0] * len_axes | ||
|
@@ -3983,21 +3972,15 @@ def _ureduce(a, func, **kwargs): | |
if axis is not None: | ||
keepdim = list(a.shape) | ||
nd = a.ndim | ||
try: | ||
axis = operator.index(axis) | ||
if axis >= nd or axis < -nd: | ||
raise IndexError("axis %d out of bounds (%d)" % (axis, a.ndim)) | ||
keepdim[axis] = 1 | ||
except TypeError: | ||
sax = set() | ||
for x in axis: | ||
if x >= nd or x < -nd: | ||
raise IndexError("axis %d out of bounds (%d)" % (x, nd)) | ||
if x in sax: | ||
raise ValueError("duplicate value in axis") | ||
sax.add(x % nd) | ||
keepdim[x] = 1 | ||
keep = sax.symmetric_difference(frozenset(range(nd))) | ||
axis = _nx.normalize_axis_tuple(axis, nd) | ||
|
||
for ax in axis: | ||
keepdim[ax] = 1 | ||
|
||
if len(axis) == 1: | ||
kwargs['axis'] = axis[0] | ||
else: | ||
keep = set(range(nd)) - set(axis) | ||
nkeep = len(keep) | ||
# swap axis that should not be reduced to front | ||
for i, s in enumerate(sorted(keep)): | ||
|
@@ -4753,7 +4736,8 @@ def delete(arr, obj, axis=None): | |
if ndim != 1: | ||
arr = arr.ravel() | ||
ndim = arr.ndim | ||
axis = ndim - 1 | ||
axis = -1 | ||
|
||
if ndim == 0: | ||
# 2013-09-24, 1.9 | ||
warnings.warn( | ||
|
@@ -4764,6 +4748,8 @@ def delete(arr, obj, axis=None): | |
else: | ||
return arr.copy(order=arrorder) | ||
|
||
axis = normalize_axis_index(axis, ndim) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This guards against |
||
|
||
slobj = [slice(None)]*ndim | ||
N = arr.shape[axis] | ||
newshape = list(arr.shape) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Now with a docstring, which Ipython shows when you type an exception name. This is in a similar style to the builtin ones.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This should be in
numpy/_globals.py
.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
And imported into
numpy/__init__.py
.Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I put it here in #8584 because it needs to be imported from C code, and that's what
_internals
seems to be used for. Is it safe to load_globals.py
from C?AxisError
is already visible at the global scope asnp.AxisError
, as it goes through the__all__
chain.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You can import from practically anything. We also import from multiarray itself into C code. I'd import this from '"numpy"', The
_globals
is for singletons that may also be used by python code.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should
TooHardError
move there too?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
right now,
numeric.py
hasfrom ._internal import TooHardError, AxisError
core/__init__.py
also containsfrom . import _internal # for freeze programs
, for some reasonThere was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Probably, but it can be done in a separate PR as the imports also need to be changed. Maybe
_globals.py
should be renamed_global_singletons
or_global_exceptions
;) There are other errors scattered about that maybe we should move at some point.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Or maybe
_numpy_exceptions
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do you want me to do both in a separate PR?
AxisError
being in that file is not new to this PR.(either way, fix coming for something else, so don't merge yet!)