From 8f27d12f5f8f4c8acd3e902bed818663ab48a646 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lo=C3=AFc=20Est=C3=A8ve?= Date: Thu, 10 Mar 2016 19:37:16 +0100 Subject: [PATCH] FIX ufunc called on memmap return a ndarray Special case for reduction functions (e.g. np.sum with axis=None) that return a numpy scalar. Keep original memmap subclasses behavior to be on the safe side. --- doc/release/1.12.0-notes.rst | 7 +++++ numpy/core/memmap.py | 21 ++++++++++++++ numpy/core/tests/test_memmap.py | 49 ++++++++++++++++++++++++++++++++- 3 files changed, 76 insertions(+), 1 deletion(-) diff --git a/doc/release/1.12.0-notes.rst b/doc/release/1.12.0-notes.rst index 058bdaac7079..41d15d114a89 100644 --- a/doc/release/1.12.0-notes.rst +++ b/doc/release/1.12.0-notes.rst @@ -177,6 +177,13 @@ were doing. This caused a complication in the downstream 'pandas' library that encountered an issue with 'numpy' compatibility. Now, all array-like methods in this module are called with keyword arguments instead. +Operations on np.memmap objects return numpy arrays in most cases +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +Previously operations on a memmap (e.g. adding 1 to it) object would +misleadingly return a memmap instance even if the result was actually +not memmapped. Also reduction of a memmap (e.g. ``.sum(axis=None``) +return a numpy scalar instead of a 0d memmap. + Deprecations ============ diff --git a/numpy/core/memmap.py b/numpy/core/memmap.py index 70d7b72b4755..827909c4712d 100644 --- a/numpy/core/memmap.py +++ b/numpy/core/memmap.py @@ -309,3 +309,24 @@ def flush(self): """ if self.base is not None and hasattr(self.base, 'flush'): self.base.flush() + + def __array_wrap__(self, arr, context=None): + arr = super(memmap, self).__array_wrap__(arr, context) + + # Return a memmap if a memmap was given as the output of the + # ufunc. Leave the arr class unchanged if self is not a memmap + # to keep original memmap subclasses behavior + if self is arr or type(self) is not memmap: + return arr + # Return scalar instead of 0d memmap, e.g. for np.sum with + # axis=None + if arr.shape == (): + return arr[()] + # Return ndarray otherwise + return arr.view(np.ndarray) + + def __getitem__(self, index): + res = super(memmap, self).__getitem__(index) + if type(res) is memmap and res._mmap is None: + return res.view(type=ndarray) + return res diff --git a/numpy/core/tests/test_memmap.py b/numpy/core/tests/test_memmap.py index e41758c51033..47f58ea7ea51 100644 --- a/numpy/core/tests/test_memmap.py +++ b/numpy/core/tests/test_memmap.py @@ -5,7 +5,9 @@ import shutil from tempfile import NamedTemporaryFile, TemporaryFile, mktemp, mkdtemp -from numpy import memmap +from numpy import ( + memmap, sum, average, product, ndarray, isscalar, add, subtract, multiply) + from numpy import arange, allclose, asarray from numpy.testing import ( TestCase, run_module_suite, assert_, assert_equal, assert_array_equal, @@ -126,5 +128,50 @@ def test_view(self): new_array = asarray(fp) assert_(new_array.base is fp) + def test_ufunc_return_ndarray(self): + fp = memmap(self.tmpfp, dtype=self.dtype, shape=self.shape) + fp[:] = self.data + + for unary_op in [sum, average, product]: + result = unary_op(fp) + assert_(isscalar(result)) + assert_(result.__class__ is self.data[0, 0].__class__) + + assert_(unary_op(fp, axis=0).__class__ is ndarray) + assert_(unary_op(fp, axis=1).__class__ is ndarray) + + for binary_op in [add, subtract, multiply]: + assert_(binary_op(fp, self.data).__class__ is ndarray) + assert_(binary_op(self.data, fp).__class__ is ndarray) + assert_(binary_op(fp, fp).__class__ is ndarray) + + fp += 1 + assert(fp.__class__ is memmap) + add(fp, 1, out=fp) + assert(fp.__class__ is memmap) + + def test_getitem(self): + fp = memmap(self.tmpfp, dtype=self.dtype, shape=self.shape) + fp[:] = self.data + + assert_(fp[1:, :-1].__class__ is memmap) + # Fancy indexing returns a copy that is not memmapped + assert_(fp[[0, 1]].__class__ is ndarray) + + def test_memmap_subclass(self): + class MemmapSubClass(memmap): + pass + + fp = MemmapSubClass(self.tmpfp, dtype=self.dtype, shape=self.shape) + fp[:] = self.data + + # We keep previous behavior for subclasses of memmap, i.e. the + # ufunc and __getitem__ output is never turned into a ndarray + assert_(sum(fp, axis=0).__class__ is MemmapSubClass) + assert_(sum(fp).__class__ is MemmapSubClass) + assert_(fp[1:, :-1].__class__ is MemmapSubClass) + assert(fp[[0, 1]].__class__ is MemmapSubClass) + + if __name__ == "__main__": run_module_suite()