BUG: Fix misuse of .names and .fields in various places (backport 14290 to 1.17) by eric-wieser · Pull Request #14339 · numpy/numpy

Merged · 13 commits · Aug 23, 2019
Changes from all commits
2 changes: 1 addition & 1 deletion numpy/core/_internal.py
@@ -459,7 +459,7 @@ def _getfield_is_safe(oldtype, newtype, offset):
if newtype.hasobject or oldtype.hasobject:
if offset == 0 and newtype == oldtype:
return
if oldtype.names:
if oldtype.names is not None:
for name in oldtype.names:
if (oldtype.fields[name][1] == offset and
oldtype.fields[name][0] == newtype):
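The whole backport hinges on the difference between a truthiness test and an explicit `None` check on `dtype.names`: a structured dtype with zero fields has `names == ()`, which is falsy, so `if oldtype.names:` wrongly treats it like an unstructured dtype. A minimal sketch of that distinction (plain public dtype API, not part of the patch):

```python
import numpy as np

unstructured = np.dtype(np.float64)
no_fields    = np.dtype([])                 # structured, but with zero fields
one_field    = np.dtype([('a', np.uint8)])

print(unstructured.names)  # None
print(no_fields.names)     # () -- falsy, yet still a structured dtype
print(one_field.names)     # ('a',)

# Truthiness conflates the first two cases; the explicit None check does not.
for dt in (unstructured, no_fields, one_field):
    print(bool(dt.names), dt.names is not None)
```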
2 changes: 1 addition & 1 deletion numpy/core/arrayprint.py
@@ -682,7 +682,7 @@ def array2string(a, max_line_width=None, precision=None,
if style is np._NoValue:
style = repr

if a.shape == () and not a.dtype.names:
if a.shape == () and a.dtype.names is None:
return style(a.item())
elif style is not np._NoValue:
# Deprecation 11-9-2017 v1.14
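For context, a small sketch (not part of the patch) of the case the new check targets: a zero-dimensional array whose dtype is structured but has no fields. The old condition routed it through `style(a.item())`; the new one keeps it on the regular structured-array path.

```python
import numpy as np

a = np.zeros((), dtype=np.dtype([]))   # 0-d array, structured dtype, zero fields
print(a.shape == ())                   # True
print(a.dtype.names)                   # ()
# Old check: `not a.dtype.names` is True  -> scalar path via style(a.item())
# New check: `a.dtype.names is None` is False -> structured formatting path
print(not a.dtype.names, a.dtype.names is None)
```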
18 changes: 9 additions & 9 deletions numpy/core/records.py
@@ -268,8 +268,8 @@ def __getattribute__(self, attr):
except AttributeError:
#happens if field is Object type
return obj
if dt.fields:
return obj.view((self.__class__, obj.dtype.fields))
if dt.names is not None:
return obj.view((self.__class__, obj.dtype))
return obj
else:
raise AttributeError("'record' object has no "
@@ -293,8 +293,8 @@ def __getitem__(self, indx):
obj = nt.void.__getitem__(self, indx)

# copy behavior of record.__getattribute__,
if isinstance(obj, nt.void) and obj.dtype.fields:
return obj.view((self.__class__, obj.dtype.fields))
if isinstance(obj, nt.void) and obj.dtype.names is not None:
return obj.view((self.__class__, obj.dtype))
else:
# return a single element
return obj
@@ -444,7 +444,7 @@ def __new__(subtype, shape, dtype=None, buf=None, offset=0, strides=None,
return self

def __array_finalize__(self, obj):
if self.dtype.type is not record and self.dtype.fields:
if self.dtype.type is not record and self.dtype.names is not None:
# if self.dtype is not np.record, invoke __setattr__ which will
# convert it to a record if it is a void dtype.
self.dtype = self.dtype
@@ -472,7 +472,7 @@ def __getattribute__(self, attr):
# with void type convert it to the same dtype.type (eg to preserve
# numpy.record type if present), since nested structured fields do not
# inherit type. Don't do this for non-void structures though.
if obj.dtype.fields:
if obj.dtype.names is not None:
if issubclass(obj.dtype.type, nt.void):
return obj.view(dtype=(self.dtype.type, obj.dtype))
return obj
@@ -487,7 +487,7 @@ def __setattr__(self, attr, val):

# Automatically convert (void) structured types to records
# (but not non-void structures, subarrays, or non-structured voids)
if attr == 'dtype' and issubclass(val.type, nt.void) and val.fields:
if attr == 'dtype' and issubclass(val.type, nt.void) and val.names is not None:
val = sb.dtype((record, val))

newattr = attr not in self.__dict__
@@ -521,7 +521,7 @@ def __getitem__(self, indx):
# copy behavior of getattr, except that here
# we might also be returning a single element
if isinstance(obj, ndarray):
if obj.dtype.fields:
if obj.dtype.names is not None:
obj = obj.view(type(self))
if issubclass(obj.dtype.type, nt.void):
return obj.view(dtype=(self.dtype.type, obj.dtype))
@@ -577,7 +577,7 @@ def field(self, attr, val=None):

if val is None:
obj = self.getfield(*res)
if obj.dtype.fields:
if obj.dtype.names is not None:
return obj
return obj.view(ndarray)
else:
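The switch from `obj.view((self.__class__, obj.dtype.fields))` to `obj.view((self.__class__, obj.dtype))` matters for padded nested dtypes: rebuilding a dtype from its `.fields` dict drops trailing padding, while viewing through the dtype itself keeps it. A short sketch of the behaviour that the new `test_nested_dtype_padding` test pins down:

```python
import numpy as np

dt = np.dtype([('a', np.uint8), ('b', np.uint8), ('c', np.uint8)])
dt_padded_end = dt[['a', 'b']]                 # sub-dtype keeps trailing padding
assert dt_padded_end.itemsize == dt.itemsize   # 3 bytes, not 2

data = np.zeros(3, dtype=[('inner', dt_padded_end)]).view(np.recarray)

# Viewing through the full dtype preserves the padded itemsize on access.
assert data['inner'].dtype.itemsize == dt.itemsize
assert data[0]['inner'].dtype.itemsize == dt.itemsize
```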
42 changes: 42 additions & 0 deletions numpy/core/tests/test_records.py
@@ -444,6 +444,48 @@ def test_fromarrays_nested_structured_arrays(self):
]
arr = np.rec.fromarrays(arrays) # ValueError?

@pytest.mark.parametrize('nfields', [0, 1, 2])
def test_assign_dtype_attribute(self, nfields):
dt = np.dtype([('a', np.uint8), ('b', np.uint8), ('c', np.uint8)][:nfields])
data = np.zeros(3, dt).view(np.recarray)

# the original and resulting dtypes differ on whether they are records
assert data.dtype.type == np.record
assert dt.type != np.record

# ensure that the dtype remains a record even when assigned
data.dtype = dt
assert data.dtype.type == np.record

@pytest.mark.parametrize('nfields', [0, 1, 2])
def test_nested_fields_are_records(self, nfields):
""" Test that nested structured types are treated as records too """
dt = np.dtype([('a', np.uint8), ('b', np.uint8), ('c', np.uint8)][:nfields])
dt_outer = np.dtype([('inner', dt)])

data = np.zeros(3, dt_outer).view(np.recarray)
assert isinstance(data, np.recarray)
assert isinstance(data['inner'], np.recarray)

data0 = data[0]
assert isinstance(data0, np.record)
assert isinstance(data0['inner'], np.record)

def test_nested_dtype_padding(self):
""" test that trailing padding is preserved """
# construct a dtype with padding at the end
dt = np.dtype([('a', np.uint8), ('b', np.uint8), ('c', np.uint8)])
dt_padded_end = dt[['a', 'b']]
assert dt_padded_end.itemsize == dt.itemsize

dt_outer = np.dtype([('inner', dt_padded_end)])

data = np.zeros(3, dt_outer).view(np.recarray)
assert_equal(data['inner'].dtype, dt_padded_end)

data0 = data[0]
assert_equal(data0['inner'].dtype, dt_padded_end)


def test_find_duplicate():
l1 = [1, 2, 3, 4, 5, 6]
2 changes: 1 addition & 1 deletion numpy/core/tests/test_regression.py
@@ -468,7 +468,7 @@ def test_pickle_py2_bytes_encoding(self):
result = pickle.loads(data, encoding='bytes')
assert_equal(result, original)

if isinstance(result, np.ndarray) and result.dtype.names:
if isinstance(result, np.ndarray) and result.dtype.names is not None:
for name in result.dtype.names:
assert_(isinstance(name, str))

2 changes: 1 addition & 1 deletion numpy/ctypeslib.py
@@ -321,7 +321,7 @@ def ndpointer(dtype=None, ndim=None, shape=None, flags=None):
# produce a name for the new type
if dtype is None:
name = 'any'
elif dtype.names:
elif dtype.names is not None:
name = str(id(dtype))
else:
name = dtype.str
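For illustration only (a sketch, not part of the patch): `ndpointer` derives the generated class name from `dtype.str` for non-structured dtypes and falls back to an `id()`-based name for structured ones, which with this change also covers the zero-field case.

```python
import numpy as np
from numpy.ctypeslib import ndpointer

p_plain  = ndpointer(dtype=np.float64)                    # name built from dtype.str
p_struct = ndpointer(dtype=np.dtype([('a', np.uint8)]))   # id-based name
p_empty  = ndpointer(dtype=np.dtype([]))                  # now also id-based

print(p_plain.__name__)   # e.g. 'ndpointer_<f8'
print(p_struct.__name__)  # e.g. 'ndpointer_1398...' (id of the dtype)
print(p_empty.__name__)
```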
13 changes: 6 additions & 7 deletions numpy/lib/_iotools.py
@@ -121,7 +121,7 @@ def has_nested_fields(ndtype):

"""
for name in ndtype.names or ():
if ndtype[name].names:
if ndtype[name].names is not None:
return True
return False

@@ -931,28 +931,27 @@ def easy_dtype(ndtype, names=None, defaultfmt="f%i", **validationargs):
names = validate(names, nbfields=nbfields, defaultfmt=defaultfmt)
ndtype = np.dtype(dict(formats=ndtype, names=names))
else:
nbtypes = len(ndtype)
# Explicit names
if names is not None:
validate = NameValidator(**validationargs)
if isinstance(names, basestring):
names = names.split(",")
# Simple dtype: repeat to match the nb of names
if nbtypes == 0:
if ndtype.names is None:
formats = tuple([ndtype.type] * len(names))
names = validate(names, defaultfmt=defaultfmt)
ndtype = np.dtype(list(zip(names, formats)))
# Structured dtype: just validate the names as needed
else:
ndtype.names = validate(names, nbfields=nbtypes,
ndtype.names = validate(names, nbfields=len(ndtype.names),
defaultfmt=defaultfmt)
# No implicit names
elif (nbtypes > 0):
elif ndtype.names is not None:
validate = NameValidator(**validationargs)
# Default initial names : should we change the format ?
if ((ndtype.names == tuple("f%i" % i for i in range(nbtypes))) and
if ((ndtype.names == tuple("f%i" % i for i in range(len(ndtype.names)))) and
(defaultfmt != "f%i")):
ndtype.names = validate([''] * nbtypes, defaultfmt=defaultfmt)
ndtype.names = validate([''] * len(ndtype.names), defaultfmt=defaultfmt)
# Explicit initial names : just validate
else:
ndtype.names = validate(ndtype.names, defaultfmt=defaultfmt)
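The `easy_dtype` change drops the `nbtypes = len(ndtype)` bookkeeping because `len()` of a dtype counts its fields and therefore cannot distinguish a plain scalar dtype from a structured dtype that happens to have zero fields; `.names is None` can. A quick illustration of why the old test was ambiguous:

```python
import numpy as np

print(len(np.dtype(float)))    # 0 -- no fields because it is not structured
print(len(np.dtype([])))       # 0 -- no fields, but it *is* structured
print(np.dtype(float).names)   # None
print(np.dtype([]).names)      # ()
```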
4 changes: 2 additions & 2 deletions numpy/lib/npyio.py
@@ -2168,7 +2168,7 @@ def encode_unicode_cols(row_tup):
outputmask = np.array(masks, dtype=mdtype)
else:
# Overwrite the initial dtype names if needed
if names and dtype.names:
if names and dtype.names is not None:
dtype.names = names
# Case 1. We have a structured type
if len(dtype_flat) > 1:
@@ -2218,7 +2218,7 @@ def encode_unicode_cols(row_tup):
#
output = np.array(data, dtype)
if usemask:
if dtype.names:
if dtype.names is not None:
mdtype = [(_, bool) for _ in dtype.names]
else:
mdtype = bool
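For context, a hedged usage sketch of the `usemask` branch touched above: with a structured dtype, `genfromtxt` builds a per-field boolean mask dtype whenever `dtype.names is not None`. The column names and sample data below are made up for illustration.

```python
import numpy as np
from io import StringIO

# Hypothetical two-column input
text = StringIO("1,2.5\n3,4.0\n")
out = np.genfromtxt(text, delimiter=',',
                    dtype=[('idx', int), ('val', float)], usemask=True)

print(out.dtype.names)   # ('idx', 'val')
print(out.mask.dtype)    # one boolean per field, e.g. [('idx', '?'), ('val', '?')]
```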
22 changes: 11 additions & 11 deletions numpy/lib/recfunctions.py
@@ -72,7 +72,7 @@ def recursive_fill_fields(input, output):
current = input[field]
except ValueError:
continue
if current.dtype.names:
if current.dtype.names is not None:
recursive_fill_fields(current, output[field])
else:
output[field][:len(current)] = current
@@ -139,11 +139,11 @@ def get_names(adtype):
names = adtype.names
for name in names:
current = adtype[name]
if current.names:
if current.names is not None:
listnames.append((name, tuple(get_names(current))))
else:
listnames.append(name)
return tuple(listnames) or None
return tuple(listnames)


def get_names_flat(adtype):
Expand Down Expand Up @@ -176,9 +176,9 @@ def get_names_flat(adtype):
for name in names:
listnames.append(name)
current = adtype[name]
if current.names:
if current.names is not None:
listnames.extend(get_names_flat(current))
return tuple(listnames) or None
return tuple(listnames)


def flatten_descr(ndtype):
Expand Down Expand Up @@ -215,8 +215,8 @@ def _zip_dtype(seqarrays, flatten=False):
else:
for a in seqarrays:
current = a.dtype
if current.names and len(current.names) <= 1:
# special case - dtypes of 0 or 1 field are flattened
if current.names is not None and len(current.names) == 1:
# special case - dtypes of 1 field are flattened
newdtype.extend(_get_fieldspec(current))
else:
newdtype.append(('', current))
@@ -268,7 +268,7 @@ def get_fieldstructure(adtype, lastname=None, parents=None,):
names = adtype.names
for name in names:
current = adtype[name]
if current.names:
if current.names is not None:
if lastname:
parents[name] = [lastname, ]
else:
@@ -281,7 +281,7 @@
elif lastname:
lastparent = [lastname, ]
parents[name] = lastparent or []
return parents or None
return parents


def _izip_fields_flat(iterable):
@@ -435,7 +435,7 @@ def merge_arrays(seqarrays, fill_value=-1, flatten=False,
if isinstance(seqarrays, (ndarray, np.void)):
seqdtype = seqarrays.dtype
# Make sure we have named fields
if not seqdtype.names:
if seqdtype.names is None:
seqdtype = np.dtype([('', seqdtype)])
if not flatten or _zip_dtype((seqarrays,), flatten=True) == seqdtype:
# Minimal processing needed: just make sure everythng's a-ok
@@ -653,7 +653,7 @@ def _recursive_rename_fields(ndtype, namemapper):
for name in ndtype.names:
newname = namemapper.get(name, name)
current = ndtype[name]
if current.names:
if current.names is not None:
newdtype.append(
(newname, _recursive_rename_fields(current, namemapper))
)
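With these changes, `get_names`, `get_names_flat`, and `get_fieldstructure` return empty containers for zero-field structured dtypes instead of `None`, keeping their return types consistent. A short sketch matching the new tests added below:

```python
import numpy as np
from numpy.lib.recfunctions import (get_names, get_names_flat,
                                    get_fieldstructure)

ndtype = np.dtype([('a', int), ('b', [('ba', float), ('bb', int)])])
print(get_names(ndtype))        # ('a', ('b', ('ba', 'bb')))
print(get_names_flat(ndtype))   # ('a', 'b', 'ba', 'bb')

# Zero-field structured dtypes now give empty containers, not None.
print(get_names(np.dtype([])))           # ()
print(get_names_flat(np.dtype([])))      # ()
print(get_fieldstructure(np.dtype([])))  # {}
```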
7 changes: 7 additions & 0 deletions numpy/lib/tests/test_io.py
@@ -1565,6 +1565,13 @@ def test_dtype_with_object(self):
test = np.genfromtxt(TextIO(data), delimiter=";",
dtype=ndtype, converters=converters)

# nested but empty fields also aren't supported
ndtype = [('idx', int), ('code', object), ('nest', [])]
with assert_raises_regex(NotImplementedError,
'Nested fields.* not supported.*'):
test = np.genfromtxt(TextIO(data), delimiter=";",
dtype=ndtype, converters=converters)

def test_userconverters_with_explicit_dtype(self):
# Test user_converters w/ explicit (standard) dtype
data = TextIO('skip,skip,2001-01-01,1.0,skip')
21 changes: 21 additions & 0 deletions numpy/lib/tests/test_recfunctions.py
@@ -115,6 +115,14 @@ def test_get_names(self):
test = get_names(ndtype)
assert_equal(test, ('a', ('b', ('ba', 'bb'))))

ndtype = np.dtype([('a', int), ('b', [])])
test = get_names(ndtype)
assert_equal(test, ('a', ('b', ())))

ndtype = np.dtype([])
test = get_names(ndtype)
assert_equal(test, ())

def test_get_names_flat(self):
# Test get_names_flat
ndtype = np.dtype([('A', '|S3'), ('B', float)])
@@ -125,6 +133,14 @@ def test_get_names_flat(self):
test = get_names_flat(ndtype)
assert_equal(test, ('a', 'b', 'ba', 'bb'))

ndtype = np.dtype([('a', int), ('b', [])])
test = get_names_flat(ndtype)
assert_equal(test, ('a', 'b'))

ndtype = np.dtype([])
test = get_names_flat(ndtype)
assert_equal(test, ())

def test_get_fieldstructure(self):
# Test get_fieldstructure

@@ -147,6 +163,11 @@
'BBA': ['B', 'BB'], 'BBB': ['B', 'BB']}
assert_equal(test, control)

# 0 fields
ndtype = np.dtype([])
test = get_fieldstructure(ndtype)
assert_equal(test, {})

def test_find_duplicates(self):
# Test find_duplicates
a = ma.array([(2, (2., 'B')), (1, (2., 'B')), (2, (2., 'B')),
2 changes: 1 addition & 1 deletion numpy/ma/mrecords.py
@@ -208,7 +208,7 @@ def __getattribute__(self, attr):
_localdict = ndarray.__getattribute__(self, '__dict__')
_data = ndarray.view(self, _localdict['_baseclass'])
obj = _data.getfield(*res)
if obj.dtype.fields:
if obj.dtype.names is not None:
raise NotImplementedError("MaskedRecords is currently limited to"
"simple records.")
# Get some special attributes