8000 BUG: fix _array2string for structured array (issue #5692) by skwbc · Pull Request #8160 · numpy/numpy · GitHub
[go: up one dir, main page]

Skip to content

BUG: fix _array2string for structured array (issue #5692) #8160

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Oct 18, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 66 additions & 33 deletions numpy/core/arrayprint.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,27 +234,23 @@ def _boolFormatter(x):
def repr_format(x):
return repr(x)

def _array2string(a, max_line_width, precision, suppress_small, separator=' ',
prefix="", formatter=None):

if max_line_width is None:
max_line_width = _line_width

if precision is None:
precision = _float_output_precision

if suppress_small is None:
suppress_small = _float_output_suppress_small

if formatter is None:
formatter = _formatter

if a.size > _summaryThreshold:
summary_insert = "..., "
data = _leading_trailing(a)
else:
summary_insert = ""
data = ravel(asarray(a))
def _get_format_function(data, precision, suppress_small, formatter):
"""
find the right formatting function for the dtype_
"""
dtype_ = data.dtype
if dtype_.fields is not None:
format_functions = []
for descr in dtype_.descr:
field_name = descr[0]
field_values = data[field_name]
if len(field_values.shape) <= 1:
format_function = _get_format_function(
field_values, precision, suppress_small, formatter)
else:
format_function = repr_format
format_functions.append(format_function)
return StructureFormat(format_functions)

formatdict = {'bool': _boolFormatter,
'int': IntegerFormat(data),
Expand Down Expand Up @@ -289,31 +285,56 @@ def _array2string(a, max_line_width, precision, suppress_small, separator=' ',
if key in fkeys:
formatdict[key] = formatter[key]

# find the right formatting function for the array
dtypeobj = a.dtype.type
dtypeobj = dtype_.type
if issubclass(dtypeobj, _nt.bool_):
format_function = formatdict['bool']
return formatdict['bool']
elif issubclass(dtypeobj, _nt.integer):
if issubclass(dtypeobj, _nt.timedelta64):
format_function = formatdict['timedelta']
return formatdict['timedelta']
else:
format_function = formatdict['int']
return formatdict['int']
elif issubclass(dtypeobj, _nt.floating):
if issubclass(dtypeobj, _nt.longfloat):
format_function = formatdict['longfloat']
return formatdict['longfloat']
else:
format_function = formatdict['float']
return formatdict['float']
elif issubclass(dtypeobj, _nt.complexfloating):
if issubclass(dtypeobj, _nt.clongfloat):
format_function = formatdict['longcomplexfloat']
return formatdict['longcomplexfloat']
else:
format_function = formatdict['complexfloat']
return formatdict['complexfloat']
elif issubclass(dtypeobj, (_nt.unicode_, _nt.string_)):
format_function = formatdict['numpystr']
return formatdict['numpystr']
elif issubclass(dtypeobj, _nt.datetime64):
format_function = formatdict['datetime']
return formatdict['datetime']
else:
format_function = formatdict['numpystr']
return formatdict['numpystr']

def _array2string(a, max_line_width, precision, suppress_small, separator=' ',
prefix="", formatter=None):

if max_line_width is None:
max_line_width = _line_width

if precision is None:
precision = _float_output_precision

if suppress_small is None:
suppress_small = _float_output_suppress_small

if formatter is None:
formatter = _formatter

if a.size > _summaryThreshold:
summary_insert = "..., "
data = _leading_trailing(a)
else:
summary_insert = ""
data = ravel(asarray(a))

# find the right formatting function for the array
format_function = _get_format_function(data, precision,
suppress_small, formatter)

# skip over "["
next_line_prefix = " "
Expand Down Expand Up @@ -758,3 +779,15 @@ def __call__(self, x):
return self._nat
else:
return self.format % x.astype('i8')


class StructureFormat(object):
def __init__(self, format_functions):
self.format_functions = format_functions
self.num_fields = len(format_functions)

def __call__(self, x):
s = "("
for field, format_function in zip(x, self.format_functions):
s += format_function(field) + ", "
return (s[:-2] if 1 < self.num_fields else s[:-1]) + ")"
15 changes: 15 additions & 0 deletions numpy/core/tests/test_arrayprint.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,21 @@ def _format_function(x):
assert_(np.array2string(s, formatter={'numpystr':lambda s: s*2}) ==
'[abcabc defdef]')

def test_structure_format(self):
dt = np.dtype([('name', np.str_, 16), ('grades', np.float64, (2,))])
x = np.array([('Sarah', (8.0, 7.0)), ('John', (6.0, 7.0))], dtype=dt)
assert_equal(np.array2string(x),
"[('Sarah', array([ 8., 7.])) ('John', array([ 6., 7.]))]")

# for issue #5692
A = np.zeros(shape=10, dtype=[("A", "M8[s]")])
A[5:].fill(np.nan)
assert_equal(np.array2string(A),
"[('1970-01-01T00:00:00',) ('1970-01-01T00:00:00',) " +
"('1970-01-01T00:00:00',)\n ('1970-01-01T00:00:00',) " +
"('1970-01-01T00:00:00',) ('NaT',) ('NaT',)\n " +
"('NaT',) ('NaT',) ('NaT',)]")


class TestPrintOptions:
"""Test getting and setting global print options."""
Expand Down
0