8000 * Add __eq__ and __ne__ for support of flexible arrays. · numpy/numpy@a5da87c · GitHub
[go: up one dir, main page]

Skip to content

Commit a5da87c

Browse files
author
pierregm
committed
* Add __eq__ and __ne__ for support of flexible arrays.
* Fixed .filled for nested structures
1 parent 99f428e commit a5da87c

File tree

2 files changed

+129
-4
lines changed

2 files changed

+129
-4
lines changed

numpy/ma/core.py

Lines changed: 85 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -857,6 +857,7 @@ def __str__ (self):
857857
#####--------------------------------------------------------------------------
858858
#---- --- Mask creation functions ---
859859
#####--------------------------------------------------------------------------
860+
860861
def _recursive_make_descr(datatype, newtype=bool_):
861862
"Private function allowing recursion in make_descr."
862863
# Do we have some name fields ?
@@ -1134,6 +1135,7 @@ def masked_where(condition, a, copy=True):
11341135
result._mask = cond
11351136
return result
11361137

1138+
11371139
def masked_greater(x, value, copy=True):
11381140
"""
11391141
Return the array `x` masked where (x > value).
@@ -1142,22 +1144,27 @@ def masked_greater(x, value, copy=True):
11421144
"""
11431145
return masked_where(greater(x, value), x, copy=copy)
11441146

1147+
11451148
def masked_greater_equal(x, value, copy=True):
11461149
"Shortcut to masked_where, with condition = (x >= value)."
11471150
return masked_where(greater_equal(x, value), x, copy=copy)
11481151

1152+
11491153
def masked_less(x, value, copy=True):
11501154
"Shortcut to masked_where, with condition = (x < value)."
11511155
return masked_where(less(x, value), x, copy=copy)
11521156

1157+
11531158
def masked_less_equal(x, value, copy=True):
11541159
"Shortcut to masked_where, with condition = (x <= value)."
11551160
return masked_where(less_equal(x, value), x, copy=copy)
11561161

1162+
11571163
def masked_not_equal(x, value, copy=True):
11581164
"Shortcut to masked_where, with condition = (x != value)."
11591165
return masked_where(not_equal(x, value), x, copy=copy)
11601166

1167+
11611168
def masked_equal(x, value, copy=True):
11621169
"""
11631170
Shortcut to masked_where, with condition = (x == value). For
@@ -1171,6 +1178,7 @@ def masked_equal(x, value, copy=True):
11711178
# return array(d, mask=m, copy=copy)
11721179
return masked_where(equal(x, value), x, copy=copy)
11731180

1181+
11741182
def masked_inside(x, v1, v2, copy=True):
11751183
"""
11761184
Shortcut to masked_where, where ``condition`` is True for x inside
@@ -1188,6 +1196,7 @@ def masked_inside(x, v1, v2, copy=True):
11881196
condition = (xf >= v1) & (xf <= v2)
11891197
return masked_where(condition, x, copy=copy)
11901198

1199+
11911200
def masked_outside(x, v1, v2, copy=True):
11921201
"""
11931202
Shortcut to ``masked_where``, where ``condition`` is True for x outside
@@ -1205,7 +1214,7 @@ def masked_outside(x, v1, v2, copy=True):
12051214
condition = (xf < v1) | (xf > v2)
12061215
return masked_where(condition, x, copy=copy)
12071216

1208-
#
1217+
12091218
def masked_object(x, value, copy=True, shrink=True):
12101219
"""
12111220
Mask the array `x` where the data are exactly equal to value.
@@ -1234,6 +1243,7 @@ def masked_object(x, value, copy=True, shrink=True):
12341243
mask = mask_or(mask, make_mask(condition, shrink=shrink))
12351244
return masked_array(x, mask=mask, copy=copy, fill_value=value)
12361245

1246+
12371247
def masked_values(x, value, rtol=1.e-5, atol=1.e-8, copy=True, shrink=True):
12381248
"""
12391249
Mask the array x where the data are approximately equal in
@@ -1271,6 +1281,7 @@ def masked_values(x, value, rtol=1.e-5, atol=1.e-8, copy=True, shrink=True):
12711281
mask = mask_or(mask, make_mask(condition, shrink=shrink))
12721282
return masked_array(xnew, mask=mask, copy=copy, fill_value=value)
12731283

1284+
12741285
def masked_invalid(a, copy=True):
12751286
"""
12761287
Mask the array for invalid values (NaNs or infs).
@@ -1292,6 +1303,7 @@ def masked_invalid(a, copy=True):
12921303
#####--------------------------------------------------------------------------
12931304
#---- --- Printing options ---
12941305
#####--------------------------------------------------------------------------
1306+
12951307
class _MaskedPrintOption:
12961308
"""
12971309
Handle the string used to represent missing data in a masked array.
@@ -1372,6 +1384,20 @@ def _recursive_printoption(result, mask, printopt):
13721384
#---- --- MaskedArray class ---
13731385
#####--------------------------------------------------------------------------
13741386

1387+
def _recursive_filled(a, mask, fill_value):
1388+
"""
1389+
Recursively fill `a` with `fill_value`.
1390+
Private function
1391+
"""
1392+
names = a.dtype.names
1393+
for name in names:
1394+
current = a[name]
1395+
print "Name: %s : %s" % (name, current)
1396+
if current.dtype.names:
1397+
_recursive_filled(current, mask[name], fill_value[name])
1398+
else:
1399+
np.putmask(current, mask[name], fill_value[name])
1400+
13751401
#...............................................................................
13761402
class _arraymethod(object):
13771403
"""
@@ -2013,6 +2039,7 @@ def _getrecordmask(self):
20132039
try:
20142040
return _mask.view((bool_, len(self.dtype))).all(axis)
20152041
except ValueError:
2042+
# In case we have nested fields...
20162043
return np.all([[f[n].all() for n in _mask.dtype.names]
20172044
for f in _mask], axis=axis)
20182045

@@ -2106,6 +2133,7 @@ def set_fill_value(self, value=None):
21062133
fill_value = property(fget=get_fill_value, fset=set_fill_value,
21072134
doc="Filling value.")
21082135

2136+
21092137
def filled(self, fill_value=None):
21102138
"""Return a copy of self._data, where masked values are filled
21112139
with fill_value.
@@ -2140,9 +2168,10 @@ def filled(self, fill_value=None):
21402168
#
21412169
if m.dtype.names:
21422170
result = self._data.copy()
2143-
for n in result.dtype.names:
2144-
field = result[n]
2145-
np.putmask(field, self._mask[n], fill_value[n])
2171+
_recursive_filled(result, self._mask, fill_value)
2172+
# for n in result.dtype.names:
2173+
# field = result[n]
2174+
# np.putmask(field, self._mask[n], fill_value[n])
21462175
elif not m.any():
21472176
return self._data
21482177
else:
@@ -2287,6 +2316,58 @@ def __repr__(self):
22872316
return _print_templates['short'] % parameters
22882317
return _print_templates['long'] % parameters
22892318
#............................................
2319+
def __eq__(self, other):
2320+
"Check whether other equals self elementwise"
2321+
omask = getattr(other, '_mask', nomask)
2322+
if omask is nomask:
2323+
check = ndarray.__eq__(self.filled(0), other).view(type(self))
2324+
check._mask = self._mask
2325+
else:
2326+
odata = filled(other, 0)
2327+
check = ndarray.__eq__(self.filled(0), odata).view(type(self))
2328+
if self._mask is nomask:
2329+
check._mask = omask
2330+
else:
2331+
mask = mask_or(self._mask, omask)
2332+
if mask.dtype.names:
2333+
if mask.size > 1:
2334+
axis = 1
2335+
else:
2336+
axis = None
2337+
try:
2338+
mask = mask.view((bool_, len(self.dtype))).all(axis)
2339+
except ValueError:
2340+
mask = np.all([[f[n].all() for n in mask.dtype.names]
2341+
for f in mask], axis=axis)
2342+
check._mask = mask
2343+
return check
2344+
#
2345+
def __ne__(self, other):
2346+
"Check whether other doesn't equal self elementwise"
2347+
omask = getattr(other, '_mask', nomask)
2348+
if omask is nomask:
2349+
check = ndarray.__ne__(self.filled(0), other).view(type(self))
2350+
check._mask = self._mask
2351+
else:
2352+
odata = filled(other, 0)
2353+
check = ndarray.__ne__(self.filled(0), odata).view(type(self))
2354+
if self._mask is nomask:
2355+
check._mask = omask
2356+
else:
2357+
mask = mask_or(self._mask, omask)
2358+
if mask.dtype.names:
2359+
if mask.size > 1:
2360+
axis = 1
2361+
else:
2362+
axis = None
2363+
try:
2364+
mask = mask.view((bool_, len(self.dtype))).all(axis)
2365+
except ValueError:
2366+
mask = np.all([[f[n].all() for n in mask.dtype.names]
2367+
for f in mask], axis=axis)
2368+
check._mask = mask
2369+
return check
2370+
#
22902371
def __add__(self, other):
22912372
"Add other to self, and return a new masked array."
22922373
return add(self, other)

numpy/ma/tests/test_core.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -474,6 +474,16 @@ def test_filled_w_flexible_dtype(self):
474474
np.array([(1, '1', 1.)], dtype=flexi.dtype))
475475

476476

477+
def test_filled_w_nested_dtype(self):
478+
"Test filled w/ nested dtype"
479+
ndtype = [('A', int), ('B', [('BA', int), ('BB', int)])]
480+
a = array([(1, (1, 1)), (2, (2, 2))],
481+
mask=[(0, (1, 0)), (0, (0, 1))], dtype=ndtype)
482+
test = a.filled(0)
483+
control = np.array([(1, (0, 1)), (2, (2, 0))], dtype=ndtype)
484+
assert_equal(test, control)
485+
486+
477487
def test_optinfo_propagation(self):
478488
"Checks that _optinfo dictionary isn't back-propagated"
479489
x = array([1,2,3,], dtype=float)
@@ -884,6 +894,40 @@ def test_methods_with_output(self):
884894
self.failUnless(output[0] is masked)
885895

886896

897+
def test_eq_on_structured(self):
898+
"Test the equality of structured arrays"
899+
ndtype = [('A', int), ('B', int)]
900+
a = array([(1, 1), (2, 2)], mask=[(0, 1), (0, 0)], dtype=ndtype)
901+
test = (a == a)
902+
assert_equal(test, [True, True])
903+
assert_equal(test.mask, [False, False])
904+
b = array([(1, 1), (2, 2)], mask=[(1, 0), (0, 0)], dtype=ndtype)
905+
test = (a == b)
906+
assert_equal(test, [False, True])
907+
assert_equal(test.mask, [True, False])
908+
b = array([(1, 1), (2, 2)], mask=[(0, 1), (1, 0)], dtype=ndtype)
909+
test = (a == b)
910+
assert_equal(test, [True, False])
911+
assert_equal(test.mask, [False, False])
912+
913+
914+
def test_ne_on_structured(self):
915+
"Test the equality of structured arrays"
916+
ndtype = [('A', int), ('B', int)]
917+
a = array([(1, 1), (2, 2)], mask=[(0, 1), (0, 0)], dtype=ndtype)
918+
test = (a != a)
919+
assert_equal(test, [False, False])
920+
assert_equal(test.mask, [False, False])
921+
b = array([(1, 1), (2, 2)], mask=[(1, 0), (0, 0)], dtype=ndtype)
922+
test = (a != b)
923+
assert_equal(test, [True, False])
924+
assert_equal(test.mask, [True, False])
925+
b = array([(1, 1), (2, 2)], mask=[(0, 1), (1, 0)], dtype=ndtype)
926+
test = (a != b)
927+
assert_equal(test, [False, True])
928+
assert_equal(test.mask, [False, False])
929+
930+
887931
#------------------------------------------------------------------------------
888932

889933
class TestMaskedArrayAttributes(TestCase):

0 commit comments

Comments
 (0)
0