8000 Merge pull request #14290 from eric-wieser/fix-if-fields · numpy/numpy@cd4cda8 · GitHub
[go: up one dir, main page]

Skip to content

Commit cd4cda8

Browse files
authored
Merge pull request #14290 from eric-wieser/fix-if-fields
BUG: Fix misuse of .names and .fields in various places
2 parents 3ca0eb1 + 0f5e376 commit cd4cda8

12 files changed

+103
-34
lines changed

numpy/core/_internal.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -459,7 +459,7 @@ def _getfield_is_safe(oldtype, newtype, offset):
459459
if newtype.hasobject or oldtype.hasobject:
460460
if offset == 0 and newtype == oldtype:
461461
return
462-
if oldtype.names:
462+
if oldtype.names is not None:
463463
for name in oldtype.names:
464464
if (oldtype.fields[name][1] == offset and
465465
oldtype.fields[name][0] == newtype):

numpy/core/arrayprint.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -685,7 +685,7 @@ def array2string(a, max_line_width=None, precision=None,
685685
if style is np._NoValue:
686686
style = repr
687687

688-
if a.shape == () and not a.dtype.names:
688+
if a.shape == () and a.dtype.names is None:
689689
return style(a.item())
690690
elif style is not np._NoValue:
691691
# Deprecation 11-9-2017 v1.14

numpy/core/records.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -268,8 +268,8 @@ def __getattribute__(self, attr):
268268
except AttributeError:
269269
#happens if field is Object type
270270
return obj
271-
if dt.fields:
272-
return obj.view((self.__class__, obj.dtype.fields))
271+
if dt.names is not None:
272+
return obj.view((self.__class__, obj.dtype))
273273
return obj
274274
else:
275275
raise AttributeError("'record' object has no "
@@ -293,8 +293,8 @@ def __getitem__(self, indx):
293293
obj = nt.void.__getitem__(self, indx)
294294

295295
# copy behavior of record.__getattribute__,
296-
if isinstance(obj, nt.void) and obj.dtype.fields:
297-
return obj.view((self.__class__, obj.dtype.fields))
296+
if isinstance(obj, nt.void) and obj.dtype.names is not None:
297+
return obj.view((self.__class__, obj.dtype))
298298
else:
299299
# return a single element
300300
return obj
@@ -444,7 +444,7 @@ def __new__(subtype, shape, dtype=None, buf=None, offset=0, strides=None,
444444
return self
445445

446446
def __array_finalize__(self, obj):
447-
if self.dtype.type is not record and self.dtype.fields:
447+
if self.dtype.type is not record and self.dtype.names is not None:
448448
# if self.dtype is not np.record, invoke __setattr__ which will
449449
# convert it to a record if it is a void dtype.
450450
self.dtype = self.dtype
@@ -472,7 +472,7 @@ def __getattribute__(self, attr):
472472
# with void type convert it to the same dtype.type (eg to preserve
473473
# numpy.record type if present), since nested structured fields do not
474474
# inherit type. Don't do this for non-void structures though.
475-
if obj.dtype.fields:
475+
if obj.dtype.names is not None:
476476
if issubclass(obj.dtype.type, nt.void):
477477
return obj.view(dtype=(self.dtype.type, obj.dtype))
478478
return obj
@@ -487,7 +487,7 @@ def __setattr__(self, attr, val):
487487

488488
# Automatically convert (void) structured types to records
489489
# (but not non-void structures, subarrays, or non-structured voids)
490-
if attr == 'dtype' and issubclass(val.type, nt.void) and val.fields:
490+
if attr == 'dtype' and issubclass(val.type, nt.void) and val.names is not None:
491491
val = sb.dtype((record, val))
492492

493493
newattr = attr not in self.__dict__
@@ -521,7 +521,7 @@ def __getitem__(self, indx):
521521
# copy behavior of getattr, except that here
522522
# we might also be returning a single element
523523
if isinstance(obj, ndarray):
524-
if obj.dtype.fields:
524+
if obj.dtype.names is not None:
525525
obj = obj.view(type(self))
526526
if issubclass(obj.dtype.type, nt.void):
527527
return obj.view(dtype=(self.dtype.type, obj.dtype))
@@ -577,7 +577,7 @@ def field(self, attr, val=None):
577577

578578
if val is None:
579579
obj = self.getfield(*res)
580-
if obj.dtype.fields:
580+
if obj.dtype.names is not None:
581581
return obj
582582
return obj.view(ndarray)
583583
else:

numpy/core/tests/test_records.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -444,6 +444,48 @@ def test_fromarrays_nested_structured_arrays(self):
444444
]
445445
arr = np.rec.fromarrays(arrays) # ValueError?
446446

447+
@pytest.mark.parametrize('nfields', [0, 1, 2])
448+
def test_assign_dtype_attribute(self, nfields):
449+
dt = np.dtype([('a', np.uint8), ('b', np.uint8), ('c', np.uint8)][:nfields])
450+
data = np.zeros(3, dt).view(np.recarray)
451+
452+
# the original and resulting dtypes differ on whether they are records
453+
assert data.dtype.type == np.record
454+
assert dt.type != np.record
455+
456+
# ensure that the dtype remains a record even when assigned
457+
data.dtype = dt
458+
assert data.dtype.type == np.record
459+
460+
@pytest.mark.parametrize('nfields', [0, 1, 2])
461+
def test_nested_fields_are_records(self, nfields):
462+
""" Test that nested structured types are treated as records too """
463+
dt = np.dtype([('a', np.uint8), ('b', np.uint8), ('c', np.uint8)][:nfields])
464+
dt_outer = np.dtype([('inner', dt)])
465+
466+
data = np.zeros(3, dt_outer).view(np.recarray)
467+
assert isinstance(data, np.recarray)
468+
assert isinstance(data['inner'], np.recarray)
469+
470+
data0 = data[0]
471+
assert isinstance(data0, np.record)
472+
assert isinstance(data0['inner'], np.record)
473+
474+
def test_nested_dtype_padding(self):
475+
""" test that trailing padding is preserved """
476+
# construct a dtype with padding at the end
477+
dt = np.dtype([('a', np.uint8), ('b', np.uint8), ('c', np.uint8)])
478+
dt_padded_end = dt[['a', 'b']]
479+
assert dt_padded_end.itemsize == dt.itemsize
480+
481+
dt_outer = np.dtype([('inner', dt_padded_end)])
482+
483+
data = np.zeros(3, dt_outer).view(np.recarray)
484+
assert_equal(data['inner'].dtype, dt_padded_end)
485+
486+
data0 = data[0]
487+
assert_equal(data0['inner'].dtype, dt_padded_end)
488+
447489

448490
def test_find_duplicate():
449491
l1 = [1, 2, 3, 4, 5, 6]

numpy/core/tests/test_regression.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -468,7 +468,7 @@ def test_pickle_py2_bytes_encoding(self):
468468
result = pickle.loads(data, encoding='bytes')
469469
assert_equal(result, original)
470470

471-
if isinstance(result, np.ndarray) and result.dtype.names:
471+
if isinstance(result, np.ndarray) and result.dtype.names is not None:
472472
for name in result.dtype.names:
473473
assert_(isinstance(name, str))
474474

numpy/ctypeslib.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -321,7 +321,7 @@ def ndpointer(dtype=None, ndim=None, shape=None, flags=None):
321321
# produce a name for the new type
322322
if dtype is None:
323323
name = 'any'
324-
elif dtype.names:
324+
elif dtype.names is not None:
325325
name = str(id(dtype))
326326
else:
327327
name = dtype.str

numpy/lib/_iotools.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ def has_nested_fields(ndtype):
121121
122122
"""
123123
for name in ndtype.names or ():
124-
if ndtype[name].names:
124+
if ndtype[name].names is not None:
125125
return True
126126
return False
127127

@@ -931,28 +931,27 @@ def easy_dtype(ndtype, names=None, defaultfmt="f%i", **validationargs):
931931
names = validate(names, nbfields=nbfields, defaultfmt=defaultfmt)
932932
ndtype = np.dtype(dict(formats=ndtype, names=names))
933933
else:
934-
nbtypes = len(ndtype)
935934
# Explicit names
936935
if names is not None:
937936
validate = NameValidator(**validationargs)
938937
if isinstance(names, basestring):
939938
names = names.split(",")
940939
# Simple dtype: repeat to match the nb of names
941-
if nbtypes == 0:
940+
if ndtype.names is None:
942941
formats = tuple([ndtype.type] * len(names))
943942
names = validate(names, defaultfmt=defaultfmt)
944943
ndtype = np.dtype(list(zip(names, formats)))
945944
# Structured dtype: just validate the names as needed
946945
else:
947-
ndtype.names = validate(names, nbfields=nbtypes,
946+
ndtype.names = validate(names, nbfields=len(ndtype.names),
948947
defaultfmt=defaultfmt)
949948
# No implicit names
950-
elif (nbtypes > 0):
949+
elif ndtype.names is not None:
951950
validate = NameValidator(**validationargs)
952951
# Default initial names : should we change the format ?
953-
if ((ndtype.names == tuple("f%i" % i for i in range(nbtypes))) and
952+
if ((ndtype.names == tuple("f%i" % i for i in range(len(ndtype.names)))) and
954953
(defaultfmt != "f%i")):
955-
ndtype.names = validate([''] * nbtypes, defaultfmt=defaultfmt)
954+
ndtype.names = validate([''] * len(ndtype.names), defaultfmt=defaultfmt)
956955
# Explicit initial names : just validate
957956
else:
958957
ndtype.names = validate(ndtype.names, defaultfmt=defaultfmt)

numpy/lib/npyio.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2180,7 +2180,7 @@ def encode_unicode_cols(row_tup):
21802180
outputmask = np.array(masks, dtype=mdtype)
21812181
else:
21822182
# Overwrite the initial dtype names if needed
2183-
if names and dtype.names:
2183+
if names and dtype.names is not None:
21842184
dtype.names = names
21852185
# Case 1. We have a structured type
21862186
if len(dtype_flat) > 1:
@@ -2230,7 +2230,7 @@ def encode_unicode_cols(row_tup):
22302230
#
22312231
output = np.array(data, dtype)
22322232
if usemask:
2233-
if dtype.names:
2233+
if dtype.names is not None:
22342234
mdtype = [(_, bool) for _ in dtype.names]
22352235
else:
22362236
mdtype = bool

numpy/lib/recfunctions.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ def recursive_fill_fields(input, output):
7272
current = input[field]
7373
except ValueError:
7474
continue
75-
if current.dtype.names:
75+
if current.dtype.names is not None:
7676
recursive_fill_fields(current, output[field])
7777
else:
7878
output[field][:len(current)] = current
@@ -139,11 +139,11 @@ def get_names(adtype):
139139
names = adtype.names
140140
for name in names:
141141
current = adtype[name]
142-
if current.names:
142+
if current.names is not None:
143143
listnames.append((name, tuple(get_names(current))))
144144
else:
145145
listnames.append(name)
146-
return tuple(listnames) or None
146+
return tuple(listnames)
147147

148148

149149
def get_names_flat(adtype):
@@ -176,9 +176,9 @@ def get_names_flat(adtype):
176176
for name in names:
177177
listnames.append(name)
178178
current = adtype[name]
179-
if current.names:
179+
if current.names is not None:
180180
listnames.extend(get_names_flat(current))
181-
return tuple(listnames) or None
181+
return tuple(listnames)
182182

183183

184184
def flatten_descr(ndtype):
@@ -215,8 +215,8 @@ def _zip_dtype(seqarrays, flatten=False):
215215
else:
216216
for a in seqarrays:
217217
current = a.dtype
218-
if current.names and len(current.names) <= 1:
219-
# special case - dtypes of 0 or 1 field are flattened
218+
if current.names is not None and len(current.names) == 1:
219+
# special case - dtypes of 1 field are flattened
220220
newdtype.extend(_get_fieldspec(current))
221221
else:
222222
newdtype.append(('', current))
@@ -268,7 +268,7 @@ def get_fieldstructure(adtype, lastname=None, parents=None,):
268268
names = adtype.names
269269
for name in names:
270270
current = adtype[name]
271-
if current.names:
271+
if current.names is not None:
272272
if lastname:
273273
parents[name] = [lastname, ]
274274
else:
@@ -281,7 +281,7 @@ def get_fieldstructure(adtype, lastname=None, parents=None,):
281281
elif lastname:
282282
lastparent = [lastname, ]
283283
parents[name] = lastparent or []
284-
return parents or None
284+
return parents
285285

286286

287287
def _izip_fields_flat(iterable):
@@ -435,7 +435,7 @@ def merge_arrays(seqarrays, fill_value=-1, flatten=False,
435435
if isinstance(seqarrays, (ndarray, np.void)):
436436
seqdtype = seqarrays.dtype
437437
# Make sure we have named fields
438-
if not seqdtype.names:
438+
if seqdtype.names is None:
439439
seqdtype = np.dtype([('', seqdtype)])
440440
if not flatten or _zip_dtype((seqarrays,), flatten=True) == seqdtype:
441441
# Minimal processing needed: just make sure everythng's a-ok
@@ -653,7 +653,7 @@ def _recursive_rename_fields(ndtype, namemapper):
653653
for name in ndtype.names:
654654
newname = namemapper.get(name, name)
655655
current = ndtype[name]
656-
if current.names:
656+
if current.names is not None:
657657
newdtype.append(
658658
(newname, _recursive_rename_fields(current, namemapper))
659659
)

numpy/lib/tests/test_io.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1565,6 +1565,13 @@ def test_dtype_with_object(self):
15651565
test = np.genfromtxt(TextIO(data), delimiter=";",
15661566
dtype=ndtype, converters=converters)
15671567

1568+
# nested but empty fields also aren't supported
1569+
ndtype = [('idx', int), ('code', object), ('nest', [])]
1570+
with assert_raises_regex(NotImplementedError,
1571+
'Nested fields.* not supported.*'):
1572+
test = np.genfromtxt(TextIO(data), delimiter=";",
1573+
dtype=ndtype, converters=converters)
1574+
15681575
def test_userconverters_with_explicit_dtype(self):
15691576
# Test user_converters w/ explicit (standard) dtype
15701577
data = TextIO('skip,skip,2001-01-01,1.0,skip')

numpy/lib/tests/test_recfunctions.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,14 @@ def test_get_names(self):
115115
test = get_names(ndtype)
116116
assert_equal(test, ('a', ('b', ('ba', 'bb'))))
117117

118+
ndtype = np.dtype([('a', int), ('b', [])])
119+
test = get_names(ndtype)
120+
assert_equal(test, ('a', ('b', ())))
121+
122+
ndtype = np.dtype([])
123+
test = get_names(ndtype)
124+
assert_equal(test, ())
125+
118126
def test_get_names_flat(self):
119127
# Test get_names_flat
120128
ndtype = np.dtype([('A', '|S3'), ('B', float)])
@@ -125,6 +133,14 @@ def test_get_names_flat(self):
125133
test = get_names_flat(ndtype)
126134
assert_equal(test, ('a', 'b', 'ba', 'bb'))
127135

136+
ndtype = np.dtype([('a', int), ('b', [])])
137+
test = get_names_flat(ndtype)
138+
assert_equal(test, ('a', 'b'))
139+
140+
ndtype = np.dtype([])
141+
test = get_names_flat(ndtype)
142+
assert_equal(test, ())
143+
128144
def test_get_fieldstructure(self):
129145
# Test get_fieldstructure
130146

@@ -147,6 +163,11 @@ def test_get_fieldstructure(self):
147163
'BBA': ['B', 'BB'], 'BBB': ['B', 'BB']}
148164
assert_equal(test, control)
149165

166+
# 0 fields
167+
ndtype = np.dtype([])
168+
test = get_fieldstructure(ndtype)
169+
assert_equal(test, {})
170+
150171
def test_find_duplicates(self):
151172
# Test find_duplicates
152173
a = ma.array([(2, (2., 'B')), (1, (2., 'B')), (2, (2., 'B')),

numpy/ma/mrecords.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,7 @@ def __getattribute__(self, attr):
208208
_localdict = ndarray.__getattribute__(self, '__dict__')
209209
_data = ndarray.view(self, _localdict['_baseclass'])
210210
obj = _data.getfield(*res)
211-
if obj.dtype.fields:
211+
if obj.dtype.names is not None:
212212
raise NotImplementedError("MaskedRecords is currently limited to"
213213
"simple records.")
214214
# Get some special attributes

0 commit comments

Comments
 (0)
0