8000 MAINT: Rewrite np.ma.apply_along_axis in terms of apply_along_axis · numpy/numpy@72f88cd · GitHub
[go: up one dir, main page]

Skip to content

Commit 72f88cd

Browse files
committed
MAINT: Rewrite np.ma.apply_along_axis in terms of apply_along_axis
This changes the behavior somewhat, but not in ways related to masked arrays It would previously try to find a suitable dtype for all values, now it just uses the first dtype.
1 parent 2aabeaf commit 72f88cd

File tree

2 files changed

+20
-77
lines changed

2 files changed

+20
-77
lines changed

numpy/ma/extras.py

Lines changed: 5 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -379,83 +379,11 @@ def apply_along_axis(func1d, axis, arr, *args, **kwargs):
379379
"""
380380
(This docstring should be overwritten)
381381
"""
382-
arr = array(arr, copy=False, subok=True)
383-
nd = arr.ndim
384-
axis = normalize_axis_index(axis, nd)
385-
ind = [0] * (nd - 1)
386-
i = np.zeros(nd, 'O')
387-
indlist = list(range(nd))
388-
indlist.remove(axis)
389-
i[axis] = slice(None, None)
390-
outshape = np.asarray(arr.shape).take(indlist)
391-
i.put(indlist, ind)
392-
j = i.copy()
393-
res = func1d(arr[tuple(i.tolist())], *args, **kwargs)
394-
# if res is a number, then we have a smaller output array
395-
asscalar = np.isscalar(res)
396-
if not asscalar:
397-
try:
398-
len(res)
399-
except TypeError:
400-
asscalar = True
401-
# Note: we shouldn't set the dtype of the output from the first result
402-
# so we force the type to object, and build a list of dtypes. We'll
403-
# just take the largest, to avoid some downcasting
404-
dtypes = []
405-
if asscalar:
406-
dtypes.append(np.asarray(res).dtype)
407-
outarr = zeros(outshape, object)
408-
outarr[tuple(ind)] = res
409-
Ntot = np.product(outshape)
410-
k = 1
411-
while k < Ntot:
412-
# increment the index
413-
ind[-1] += 1
414-
n = -1
415-
while (ind[n] >= outshape[n]) and (n > (1 - nd)):
416-
ind[n - 1] += 1
417-
ind[n] = 0
418-
n -= 1
419-
i.put(indlist, ind)
420-
res = func1d(arr[tuple(i.tolist())], *args, **kwargs)
421-
outarr[tuple(ind)] = res
422-
dtypes.append(asarray(res).dtype)
423-
k += 1
424-
else:
425-
res = array(res, copy=False, subok=True)
426-
j = i.copy()
427-
j[axis] = ([slice(None, None)] * res.ndim)
428-
j.put(indlist, ind)
429-
Ntot = np.product(outshape)
430-
holdshape = outshape
431-
outshape = list(arr.shape)
432-
outshape[axis] = res.shape
433-
dtypes.append(asarray(res).dtype)
434-
outshape = flatten_inplace(outshape)
435-
outarr = zeros(outshape, object)
436-
outarr[tuple(flatten_inplace(j.tolist()))] = res
437-
k = 1
438-
while k < Ntot:
439-
# increment the index
440-
ind[-1] += 1
441-
n = -1
442-
while (ind[n] >= holdshape[n]) and (n > (1 - nd)):
443-
ind[n - 1] += 1
444-
ind[n] = 0
445-
n -= 1
446-
i.put(indlist, ind)
447-
j.put(indlist, ind)
448-
res = func1d(arr[tuple(i.tolist())], *args, **kwargs)
449-
outarr[tuple(flatten_inplace(j.tolist()))] = res
450-
dtypes.append(asarray(res).dtype)
451-
k += 1
452-
max_dtypes = np.dtype(np.asarray(dtypes).max())
453-
if not hasattr(arr, '_mask'):
454-
result = np.asarray(outarr, dtype=max_dtypes)
455-
else:
456-
result = asarray(outarr, dtype=max_dtypes)
457-
result.fill_value = ma.default_fill_value(result)
458-
return result
382+
def wrapped_func(a, *args, **kwargs):
383+
res = func1d(a, *args, **kwargs)
384+
return np.asanyarray(res).view(masked_array)
385+
386+
return np.apply_along_axis(wrapped_func, axis, arr, *args, **kwargs)
459387
apply_along_axis.__doc__ = np.apply_along_axis.__doc__
460388

461389

numpy/ma/tests/test_extras.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -640,6 +640,21 @@ def myfunc(b, offset=0):
640640
xa = apply_along_axis(myfunc, 2, a, offset=1)
641641
assert_equal(xa, [[2, 5], [8, 11]])
642642

643+
def test_mask(self):
644+
x = np.arange(20).reshape(4, 5)
645+
m = np.ma.masked_where(x%2 == 0, x)
646+
647+
# note that this lambda sometime returns np.ma.masked
648+
# works fine with normal apply_along_axis
649+
row0 = np.ma.apply_along_axis(lambda x: x[0], 0, m)
650+
651+
# fails with normal apply_along_axis, which doesn't use a masked_array
652+
row1 = np.ma.apply_along_axis(lambda x: x[1], 0, m)
653+
654+
# note the first of these is a float, because that's the type of np.ma.masked...
655+
assert_equal(row0, np.ma.masked_array([-1, 1.0, -1, 3.0, -1], [1, 0, 1, 0, 1]))
656+
assert_equal(row1, np.ma.masked_array([5, -1, 7, -1, 9], [0, 1, 0, 1, 0]))
657+
643658

644659
class TestApplyOverAxes(TestCase):
645660
# Tests apply_over_axes

0 commit comments

Comments
 (0)
0