8000 Merge pull request #8843 from eric-wieser/more-AxisError · numpy/numpy@ab9b15e · GitHub
[go: up one dir, main page]

Skip to content

Commit ab9b15e

Browse files
authored
Merge pull request #8843 from eric-wieser/more-AxisError
MAINT: Use AxisError in more places
2 parents 31a8fd3 + 17466ad commit ab9b15e

File tree

12 files changed

+158
-106
lines changed

12 files changed

+158
-106
lines changed

numpy/add_newdocs.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6739,7 +6739,7 @@ def luf(lamdaexpr, *args, **kwargs):
67396739

67406740
add_newdoc('numpy.core.multiarray', 'normalize_axis_index',
67416741
"""
6742-
normalize_axis_index(axis, ndim)
6742+
normalize_axis_index(axis, ndim, msg_prefix=None)
67436743
67446744
Normalizes an axis index, `axis`, such that is a valid positive index into
67456745
the shape of array with `ndim` dimensions. Raises an AxisError with an
@@ -6756,6 +6756,8 @@ def luf(lamdaexpr, *args, **kwargs):
67566756
ndim : int
67576757
The number of dimensions of the array that `axis` should be normalized
67586758
against
6759+
msg_prefix : str
6760+
A prefix to put before the message, typically the name of the argument
67596761
67606762
Returns
67616763
-------
@@ -6780,6 +6782,10 @@ def luf(lamdaexpr, *args, **kwargs):
67806782
Traceback (most recent call last):
67816783
...
67826784
AxisError: axis 3 is out of bounds for array of dimension 3
6785+
>>> normalize_axis_index(-4, ndim=3, msg_prefix='axes_arg')
6786+
Traceback (most recent call last):
6787+
...
6788+
AxisError: axes_arg: axis -4 is out of bounds for array of dimension 3
67836789
""")
67846790

67856791
##############################################################################

numpy/core/_internal.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -631,4 +631,17 @@ class TooHardError(RuntimeError):
631631
pass
632632

633633
class AxisError(ValueError, IndexError):
634-
pass
634+
""" Axis supplied was invalid. """
635+
def __init__(self, axis, ndim=None, msg_prefix=None):
636+
# single-argument form just delegates to base class
637+
if ndim is None and msg_prefix is None:
638+
msg = axis
639+
640+
# do the string formatting here, to save work in the C code
641+
else:
642+
msg = ("axis {} is out of bounds for array of dimension {}"
643+
.format(axis, ndim))
644+
if msg_prefix is not None:
645+
msg = "{}: {}".format(msg_prefix, msg)
646+
647+
super(AxisError, self).__init__(msg)

numpy/core/numeric.py

Lines changed: 59 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -441,7 +441,7 @@ def count_nonzero(a, axis=None):
441441
nullstr = a.dtype.type('')
442442
return (a != nullstr).sum(axis=axis, dtype=np.intp)
443443

444-
axis = asarray(_validate_axis(axis, a.ndim, 'axis'))
444+
axis = asarray(normalize_axis_tuple(axis, a.ndim))
445445
counts = np.apply_along_axis(multiarray.count_nonzero, axis[0], a)
446446

447447
if axis.size == 1:
@@ -1460,16 +1460,14 @@ def roll(a, shift, axis=None):
14601460
return roll(a.ravel(), shift, 0).reshape(a.shape)
14611461

14621462
else:
1463+
axis = normalize_axis_tuple(axis, a.ndim, allow_duplicate=True)
14631464
broadcasted = broadcast(shift, axis)
14641465
if broadcasted.ndim > 1:
14651466
raise ValueError(
14661467
"'shift' and 'axis' should be scalars or 1D sequences")
14671468
shifts = {ax: 0 for ax in range(a.ndim)}
14681469
for sh, ax in broadcasted:
1469-
if -a.ndim <= ax < a.ndim:
1470-
shifts[ax % a.ndim] += sh
1471-
else:
1472-
raise ValueError("'axis' entry is out of bounds")
1470+
shifts[ax] += sh
14731471

14741472
rolls = [((slice(None), slice(None)),)] * a.ndim
14751473
for ax, offset in shifts.items():
@@ -1544,17 +1542,59 @@ def rollaxis(a, axis, start=0):
15441542
return a.transpose(axes)
15451543

15461544

1547-
def _validate_axis(axis, ndim, argname):
1545+
def normalize_axis_tuple(axis, ndim, argname=None, allow_duplicate=False):
1546+
"""
1547+
Normalizes an axis argument into a tuple of non-negative integer axes.
1548+
1549+
This handles shorthands such as ``1`` and converts them to ``(1,)``,
1550+
as well as performing the handling of negative indices covered by
1551+
`normalize_axis_index`.
1552+
1553+
By default, this forbids axes from being specified multiple times.
1554+
1555+
Used internally by multi-axis-checking logic.
1556+
1557+
.. versionadded:: 1.13.0
1558+
1559+
Parameters
1560+
----------
1561+
axis : int, iterable of int
1562+
The un-normalized index or indices of the axis.
1563+
ndim : int
1564+
The number of dimensions of the array that `axis` should be normalized
1565+
against.
1566+
argname : str, optional
1567+
A prefix to put before the error message, typically the name of the
1568+
argument.
1569+
allow_duplicate : bool, optional
1570+
If False, the default, disallow an axis from being specified twice.
1571+
1572+
Returns
1573+
-------
1574+
normalized_axes : tuple of int
1575+
The normalized axis index, such that `0 <= normalized_axis < ndim`
1576+
1577+
Raises
1578+
------
1579+
AxisError
1580+
If any axis provided is out of range
1581+
ValueError
1582+
If an axis is repeated
1583+
1584+
See also
1585+
--------
1586+
normalize_axis_index : normalizing a single scalar axis
1587+
"""
15481588
try:
15491589
axis = [operator.index(axis)]
15501590
except TypeError:
1551-
axis = list(axis)
1552-
axis = [a + ndim if a < 0 else a for a in axis]
1553-
if not builtins.all(0 <= a < ndim for a in axis):
1554-
raise AxisError('invalid axis for this array in `%s` argument' %
1555-
argname)
1556-
if len(set(axis)) != len(axis):
1557-
raise ValueError F438 ('repeated axis in `%s` argument' % argname)
1591+
axis = tuple(axis)
1592+
axis = tuple(normalize_axis_index(ax, ndim, argname) for ax in axis)
1593+
if not allow_duplicate and len(set(axis)) != len(axis):
1594+
if argname:
1595+
raise ValueError('repeated axis in `{}` argument'.format(argname))
1596+
else:
1597+
raise ValueError('repeated axis')
15581598
return axis
15591599

15601600

@@ -1614,8 +1654,8 @@ def moveaxis(a, source, destination):
16141654
a = asarray(a)
16151655
transpose = a.transpose
16161656

1617-
source = _validate_axis(source, a.ndim, 'source')
1618-
destination = _validate_axis(destination, a.ndim, 'destination')
1657+
source = normalize_axis_tuple(source, a.ndim, 'source')
1658+
destination = normalize_axis_tuple(destination, a.ndim, 'destination')
16191659
if len(source) != len(destination):
16201660
raise ValueError('`source` and `destination` arguments must have '
16211661
'the same number of elements')
@@ -1752,11 +1792,9 @@ def cross(a, b, axisa=-1, axisb=-1, axisc=-1, axis=None):
17521792
a = asarray(a)
17531793
b = asarray(b)
17541794
# Check axisa and axisb are within bounds
1755-
axis_msg = "'axis{0}' out of bounds"
1756-
if axisa < -a.ndim or axisa >= a.ndim:
1757-
raise ValueError(axis_msg.format('a'))
1758-
if axisb < -b.ndim or axisb >= b.ndim:
1759-
raise ValueError(axis_msg.format('b'))
1795+
axisa = normalize_axis_index(axisa, a.ndim, msg_prefix='axisa')
1796+
axisb = normalize_axis_index(axisb, b.ndim, msg_prefix='axisb')
1797+
17601798
# Move working axis to the end of the shape
17611799
a = rollaxis(a, axisa, a.ndim)
17621800
b = rollaxis(b, axisb, b.ndim)
@@ -1770,8 +1808,7 @@ def cross(a, b, axisa=-1, axisb=-1, axisc=-1, axis=None):
17701808
if a.shape[-1] == 3 or b.shape[-1] == 3:
17711809
shape += (3,)
17721810
# Check axisc is within bounds
1773-
if axisc < -len(shape) or axisc >= len(shape):
1774-
raise ValueError(axis_msg.format('c'))
1811+
axisc = normalize_axis_index(axisc, len(shape), msg_prefix='axisc')
17751812
dtype = promote_types(a.dtype, b.dtype)
17761813
cp = empty(shape, dtype)
17771814

numpy/core/src/multiarray/common.h

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -138,9 +138,11 @@ check_and_adjust_index(npy_intp *index, npy_intp max_item, int axis,
138138
* Returns -1 and sets an exception if *axis is an invalid axis for
139139
* an array of dimension ndim, otherwise adjusts it in place to be
140140
* 0 <= *axis < ndim, and returns 0.
141+
*
142+
* msg_prefix: borrowed reference, a string to prepend to the message
141143
*/
142144
static NPY_INLINE int
143-
check_and_adjust_axis(int *axis, int ndim)
145+
check_and_adjust_axis_msg(int *axis, int ndim, PyObject *msg_prefix)
144146
{
145147
/* Check that index is valid, taking into account negative indices */
146148
if (NPY_UNLIKELY((*axis < -ndim) || (*axis >= ndim))) {
@@ -149,6 +151,8 @@ check_and_adjust_axis(int *axis, int ndim)
149151
* we don't have access to npy_cache_import here
150152
*/
151153
static PyObject *AxisError_cls = NULL;
154+
PyObject *exc;
155+
152156
if (AxisError_cls == NULL) {
153157
PyObject *mod = PyImport_ImportModule("numpy.core._internal");
154158

@@ -158,9 +162,15 @@ check_and_adjust_axis(int *axis, int ndim)
158162
}
159163
}
160164

161-
PyErr_Format(AxisError_cls,
162-
"axis %d is out of bounds for array of dimension %d",
163-
*axis, ndim);
165+
/* Invoke the AxisError constructor */
166+
exc = PyObject_CallFunction(AxisError_cls, "iiO",
167+
*axis, ndim, msg_prefix);
168+
if (exc == NULL) {
169+
return -1;
170+
}
171+
PyErr_SetObject(AxisError_cls, exc);
172+
Py_DECREF(exc);
173+
164174
return -1;
165175
}
166176
/* adjust negative indices */
@@ -169,6 +179,11 @@ check_and_adjust_axis(int *axis, int ndim)
169179
}
170180
return 0;
171181
}
182+
static NPY_INLINE int
183+
check_and_adjust_axis(int *axis, int ndim)
184+
{
185+
return check_and_adjust_axis_msg(axis, ndim, Py_None);
186+
}
172187

173188

174189
/*

numpy/core/src/multiarray/multiarraymodule.c

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4106,16 +4106,16 @@ array_may_share_memory(PyObject *NPY_UNUSED(ignored), PyObject *args, PyObject *
41064106
static PyObject *
41074107
normalize_axis_index(PyObject *NPY_UNUSED(self), PyObject *args, PyObject *kwds)
41084108
{
4109-
static char *kwlist[] = {"axis", "ndim", NULL};
4109+
static char *kwlist[] = {"axis", "ndim", "msg_prefix", NULL};
41104110
int axis;
41114111
int ndim;
4112+
PyObject *msg_prefix = Py_None;
41124113

4113-
if (!PyArg_ParseTupleAndKeywords(args, kwds, "ii:normalize_axis_index",
4114-
kwlist, &axis, &ndim)) {
4114+
if (!PyArg_ParseTupleAndKeywords(args, kwds, "ii|O:normalize_axis_index",
4115+
kwlist, &axis, &ndim, &msg_prefix)) {
41154116
return NULL;
41164117
}
4117-
4118-
if(check_and_adjust_axis(&axis, ndim) < 0) {
4118+
if (check_and_adjust_axis_msg(&axis, ndim, msg_prefix) < 0) {
41194119
return NULL;
41204120
}
41214121

numpy/core/tests/test_numeric.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2425,11 +2425,11 @@ def test_move_multiples(self):
24252425

24262426
def test_errors(self):
24272427
x = np.random.randn(1, 2, 3)
2428-
assert_raises_regex(np.AxisError, 'invalid axis .* `source`',
2428+
assert_raises_regex(np.AxisError, 'source.*out of bounds',
24292429
np.moveaxis, x, 3, 0)
2430-
assert_raises_regex(np.AxisError, 'invalid axis .* `source`',
2430+
assert_raises_regex(np.AxisError, 'source.*out of bounds',
24312431
np.moveaxis, x, -4, 0)
2432-
assert_raises_regex(np.AxisError, 'invalid axis .* `destination`',
2432+
assert_raises_regex(np.AxisError, 'destination.*out of bounds',
24332433
np.moveaxis, x, 0, 5)
24342434
assert_raises_regex(ValueError, 'repeated axis in `source`',
24352435
np.moveaxis, x, [0, 0], [0, 1])
@@ -2517,13 +2517,13 @@ def test_broadcasting_shapes(self):
25172517
u = np.ones((10, 3, 5))
25182518
v = np.ones((2, 5))
25192519
assert_equal(np.cross(u, v, axisa=1, axisb=0).shape, (10, 5, 3))
2520-
assert_raises(ValueError, np.cross, u, v, axisa=1, axisb=2)
2521-
assert_raises(ValueError, np.cross, u, v, axisa=3, axisb=0)
2520+
assert_raises(np.AxisError, np.cross, u, v, axisa=1, axisb=2)
2521+
assert_raises(np.AxisError, np.cross, u, v, axisa=3, axisb=0)
25222522
u = np.ones((10, 3, 5, 7))
25232523
v = np.ones((5, 7, 2))
25242524
assert_equal(np.cross(u, v, axisa=1, axisc=2).shape, (10, 5, 3, 7))
2525-
assert_raises(ValueError, np.cross, u, v, axisa=-5, axisb=2)
2526-
assert_raises(ValueError, np.cross, u, v, axisa=1, axisb=-4)
2525+
assert_raises(np.AxisError, np.cross, u, v, axisa=-5, axisb=2)
2526+
assert_raises(np.AxisError, np.cross, u, v, axisa=1, axisb=-4)
25272527
# gh-5885
25282528
u = np.ones((3, 4, 2))
25292529
for axisc in range(-2, 2):

numpy/lib/function_base.py

Lines changed: 16 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from numpy.core.numeric import (
1313
ones, zeros, arange, concatenate, array, asarray, asanyarray, empty,
1414
empty_like, ndarray, around, floor, ceil, take, dot, where, intp,
15-
integer, isscalar, absolute
15+
integer, isscalar, absolute, AxisError
1616
)
1717
from numpy.core.umath import (
1818
pi, multiply, add, arctan2, frompyfunc, cos, less_equal, sqrt, sin,
@@ -1679,21 +1679,10 @@ def gradient(f, *varargs, **kwargs):
16791679
axes = kwargs.pop('axis', None)
16801680
if axes is None:
16811681
axes = tuple(range(N))
1682-
# check axes to have correct type and no duplicate entries
1683-
if isinstance(axes, int):
1684-
axes = (axes,)
1685-
if not isinstance(axes, tuple):
1686-
raise TypeError("A tuple of integers or a single integer is required")
1687-
1688-
# normalize axis values:
1689-
axes = tuple(x + N if x < 0 else x for x in axes)
1690-
if max(axes) >= N or min(axes) < 0:
1691-
raise ValueError("'axis' entry is out of bounds")
1682+
else:
1683+
axes = _nx.normalize_axis_tuple(axes, N)
16921684

16931685
len_axes = len(axes)
1694-
if len(set(axes)) != len_axes:
1695-
raise ValueError("duplicate value in 'axis'")
1696-
16971686
n = len(varargs)
16981687
if n == 0:
16991688
dx = [1.0] * len_axes
@@ -3983,21 +3972,15 @@ def _ureduce(a, func, **kwargs):
39833972
if axis is not None:
39843973
keepdim = list(a.shape)
39853974
nd = a.ndim
3986-
try:
3987-
axis = operator.index(axis)
3988-
if axis >= nd or axis < -nd:
3989-
raise IndexError("axis %d out of bounds (%d)" % (axis, a.ndim))
3990-
keepdim[axis] = 1
3991-
except TypeError:
3992-
sax = set()
3993-
for x in axis:
3994-
if x >= nd or x < -nd:
3995-
raise IndexError("axis %d out of bounds (%d)" % (x, nd))
3996-
if x in sax:
3997-
raise ValueError("duplicate value in axis")
3998-
sax.add(x % nd)
3999-
keepdim[x] = 1
4000-
keep = sax.symmetric_difference(frozenset(range(nd)))
3975+
axis = _nx.normalize_axis_tuple(axis, nd)
3976+
3977+
for ax in axis:
3978+
keepdim[ax] = 1
3979+
3980+
if len(axis) == 1:
3981+
kwargs['axis'] = axis[0]
3982+
else:
3983+
keep = set(range(nd)) - set(axis)
40013984
nkeep = len(keep)
40023985
# swap axis that should not be reduced to front
40033986
for i, s in enumerate(sorted(keep)):
@@ -4753,7 +4736,8 @@ def delete(arr, obj, axis=None):
47534736
if ndim != 1:
47544737
arr = arr.ravel()
47554738
ndim = arr.ndim
4756-
axis = ndim - 1
4739+
axis = -1
4740+
47574741
if ndim == 0:
47584742
# 2013-09-24, 1.9
47594743
warnings.warn(
@@ -4764,6 +4748,8 @@ def delete(arr, obj, axis=None):
47644748
else:
47654749
return arr.copy(order=arrorder)
47664750

4751+
axis = normalize_axis_index(axis, ndim)
4752+
47674753
slobj = [slice(None)]*ndim
47684754
N = arr.shape[axis]
47694755
newshape = list(arr.shape)

0 commit comments

Comments
 (0)
0