From c8487ba92611af89e21fba7ba91c4fc60f29e07d Mon Sep 17 00:00:00 2001 From: Marten van Kerkwijk Date: Wed, 6 Jan 2016 11:37:47 -0500 Subject: [PATCH] BUG trace is not subclass aware, such that np.trace(ma) != ma.trace(). --- numpy/core/fromnumeric.py | 6 +++++- numpy/core/tests/test_multiarray.py | 27 +++++++++++++++++++++++++++ numpy/ma/tests/test_core.py | 1 + 3 files changed, 33 insertions(+), 1 deletion(-) diff --git a/numpy/core/fromnumeric.py b/numpy/core/fromnumeric.py index 10f4a98c513c..dcc6dad6ac0d 100644 --- a/numpy/core/fromnumeric.py +++ b/numpy/core/fromnumeric.py @@ -1367,7 +1367,11 @@ def trace(a, offset=0, axis1=0, axis2=1, dtype=None, out=None): (2, 3) """ - return asarray(a).trace(offset, axis1, axis2, dtype, out) + if isinstance(a, np.matrix): + # Get trace of matrix via an array to preserve backward compatibility. + return asarray(a).trace(offset, axis1, axis2, dtype, out) + else: + return asanyarray(a).trace(offset, axis1, axis2, dtype, out) def ravel(a, order='C'): diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py index e70b24565c2b..4743a25b4f06 100644 --- a/numpy/core/tests/test_multiarray.py +++ b/numpy/core/tests/test_multiarray.py @@ -2012,6 +2012,33 @@ def test_diagonal_memleak(self): a.diagonal() assert_(sys.getrefcount(a) < 50) + def test_trace(self): + a = np.arange(12).reshape((3, 4)) + assert_equal(a.trace(), 15) + assert_equal(a.trace(0), 15) + assert_equal(a.trace(1), 18) + assert_equal(a.trace(-1), 13) + + b = np.arange(8).reshape((2, 2, 2)) + assert_equal(b.trace(), [6, 8]) + assert_equal(b.trace(0), [6, 8]) + assert_equal(b.trace(1), [2, 3]) + assert_equal(b.trace(-1), [4, 5]) + assert_equal(b.trace(0, 0, 1), [6, 8]) + assert_equal(b.trace(0, 0, 2), [5, 9]) + assert_equal(b.trace(0, 1, 2), [3, 11]) + assert_equal(b.trace(offset=1, axis1=0, axis2=2), [1, 3]) + + def test_trace_subclass(self): + # The class would need to overwrite trace to ensure single-element + # output also has the right subclass. + class MyArray(np.ndarray): + pass + + b = np.arange(8).reshape((2, 2, 2)).view(MyArray) + t = b.trace() + assert isinstance(t, MyArray) + def test_put(self): icodes = np.typecodes['AllInteger'] fcodes = np.typecodes['AllFloat'] diff --git a/numpy/ma/tests/test_core.py b/numpy/ma/tests/test_core.py index a11fa0e0b9cd..e505f74469ca 100644 --- a/numpy/ma/tests/test_core.py +++ b/numpy/ma/tests/test_core.py @@ -3146,6 +3146,7 @@ def test_trace(self): assert_almost_equal(mX.trace(), X.trace() - sum(mXdiag.mask * X.diagonal(), axis=0)) + assert_equal(np.trace(mX), mX.trace()) def test_dot(self): # Tests dot on MaskedArrays.