8000 ENH: allow numpy.apply_along_axis() to work with ndarray subclasses (… · numpy/numpy@84b11f5 · GitHub
[go: up one dir, main page]

Skip to content

Commit 84b11f5

Browse files
bennyrowlandshoyer
authored andcommitted
ENH: allow numpy.apply_along_axis() to work with ndarray subclasses (#7918)
This commit modifies the numpy.apply_along_axis() function so that if it is called with an ndarray subclass, the internal func1d calls receive subclass instances and the overall function returns an instance of the subclass. There are two new tests for these two behaviours.
1 parent dbb7094 commit 84b11f5

File tree

2 files changed

+38
-3
lines changed

2 files changed

+38
-3
lines changed

numpy/lib/shape_base.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ def apply_along_axis(func1d, axis, arr, *args, **kwargs):
7474
[2, 5, 6]])
7575
7676
"""
77-
arr = asarray(arr)
77+
arr = asanyarray(arr)
7878
nd = arr.ndim
7979
if axis < 0:
8080
axis += nd
@@ -109,11 +109,13 @@ def apply_along_axis(func1d, axis, arr, *args, **kwargs):
109109
k += 1
110110
return outarr
111111
else:
112+
res = asanyarray(res)
112113
Ntot = product(outshape)
113114
holdshape = outshape
114115
outshape = list(arr.shape)
115-
outshape[axis] = len(res)
116-
outarr = zeros(outshape, asarray(res).dtype)
116+
outshape[axis] = res.size
117+
outarr = zeros(outshape, res.dtype)
118+
outarr = res.__array_wrap__(outarr)
117119
outarr[tuple(i.tolist())] = res
118120
k = 1
119121
while k < Ntot:
@@ -128,6 +130,8 @@ def apply_along_axis(func1d, axis, arr, *args, **kwargs):
128130
res = func1d(arr[tuple(i.tolist())], *args, **kwargs)
129131
outarr[tuple(i.tolist())] = res
130132
k += 1
133+
if res.shape == ():
134+
outarr = outarr.squeeze(axis)
131135
return outarr
132136

133137

numpy/lib/tests/test_shape_base.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,37 @@ def test_3d(self):
2727
assert_array_equal(apply_along_axis(np.sum, 0, a),
2828
[[27, 30, 33], [36, 39, 42], [45, 48, 51]])
2929

30+
def test_preserve_subclass(self):
31+
def double(row):
32+
return row * 2
33+
m = np.matrix([[0, 1], [2, 3]])
34+
result = apply_along_axis(double, 0, m)
35+
assert isinstance(result, np.matrix)
36+
assert_array_equal(
37+
result, np.matrix([[0, 2], [4 8000 , 6]])
38+
)
39+
40+
def test_subclass(self):
41+
class MinimalSubclass(np.ndarray):
42+
data = 1
43+
44+
def minimal_function(array):
45+
return array.data
46+
47+
a = np.zeros((6, 3)).view(MinimalSubclass)
48+
49+
assert_array_equal(
50+
apply_along_axis(minimal_function, 0, a), np.array([1, 1, 1])
51+
)
52+
53+
def test_scalar_array(self):
54+
class MinimalSubclass(np.ndarray):
55+
pass
56+
a = np.ones((6, 3)).view(MinimalSubclass)
57+
res = apply_along_axis(np.sum, 0, a)
58+
assert isinstance(res, MinimalSubclass)
59+
assert_array_equal(res, np.array([6, 6, 6]).view(MinimalSubc 4C39 lass))
60+
3061

3162
class TestApplyOverAxes(TestCase):
3263
def test_simple(self):

0 commit comments

Comments
 (0)
0