8000 Merge pull request #11518 from mhvk/normalize-axis-tuple-speedup · numpy/numpy@a52634d · GitHub
[go: up one dir, main page]

Skip to content

Commit a52634d

Browse files
authored
Merge pull request #11518 from mhvk/normalize-axis-tuple-speedup
MAINT: Speed up normalize_axis_tuple by about 30%
2 parents 0da547d + 4bed228 commit a52634d

File tree

1 file changed

+8
-5
lines changed

1 file changed

+8
-5
lines changed

numpy/core/numeric.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1509,11 +1509,14 @@ def normalize_axis_tuple(axis, ndim, argname=None, allow_duplicate=False):
15091509
--------
15101510
normalize_axis_index : normalizing a single scalar axis
15111511
"""
1512-
try:
1513-
axis = [operator.index(axis)]
1514-
except TypeError:
1515-
axis = tuple(axis)
1516-
axis = tuple(normalize_axis_index(ax, ndim, argname) for ax in axis)
1512+
# Optimization to speed-up the most common cases.
1513+
if type(axis) not in (tuple, list):
1514+
try:
1515+
axis = [operator.index(axis)]
1516+
except TypeError:
1517+
pass
1518+
# Going via an iterator directly is slower than via list comprehension.
1519+
axis = tuple([normalize_axis_index(ax, ndim, argname) for ax in axis])
15171520
if not allow_duplicate and len(set(axis)) != len(axis):
15181521
if argname:
15191522
raise ValueError('repeated axis in `{}` argument'.format(argname))

0 commit comments

Comments
 (0)
0