diff --git a/numpy/core/arrayprint.py b/numpy/core/arrayprint.py index cd618d72aace..7a84eb7c2d49 100644 --- a/numpy/core/arrayprint.py +++ b/numpy/core/arrayprint.py @@ -234,27 +234,23 @@ def _boolFormatter(x): def repr_format(x): return repr(x) -def _array2string(a, max_line_width, precision, suppress_small, separator=' ', - prefix="", formatter=None): - - if max_line_width is None: - max_line_width = _line_width - - if precision is None: - precision = _float_output_precision - - if suppress_small is None: - suppress_small = _float_output_suppress_small - - if formatter is None: - formatter = _formatter - - if a.size > _summaryThreshold: - summary_insert = "..., " - data = _leading_trailing(a) - else: - summary_insert = "" - data = ravel(asarray(a)) +def _get_format_function(data, precision, suppress_small, formatter): + """ + find the right formatting function for the dtype_ + """ + dtype_ = data.dtype + if dtype_.fields is not None: + format_functions = [] + for descr in dtype_.descr: + field_name = descr[0] + field_values = data[field_name] + if len(field_values.shape) <= 1: + format_function = _get_format_function( + field_values, precision, suppress_small, formatter) + else: + format_function = repr_format + format_functions.append(format_function) + return StructureFormat(format_functions) formatdict = {'bool': _boolFormatter, 'int': IntegerFormat(data), @@ -289,31 +285,56 @@ def _array2string(a, max_line_width, precision, suppress_small, separator=' ', if key in fkeys: formatdict[key] = formatter[key] - # find the right formatting function for the array - dtypeobj = a.dtype.type + dtypeobj = dtype_.type if issubclass(dtypeobj, _nt.bool_): - format_function = formatdict['bool'] + return formatdict['bool'] elif issubclass(dtypeobj, _nt.integer): if issubclass(dtypeobj, _nt.timedelta64): - format_function = formatdict['timedelta'] + return formatdict['timedelta'] else: - format_function = formatdict['int'] + return formatdict['int'] elif issubclass(dtypeobj, _nt.floating): if issubclass(dtypeobj, _nt.longfloat): - format_function = formatdict['longfloat'] + return formatdict['longfloat'] else: - format_function = formatdict['float'] + return formatdict['float'] elif issubclass(dtypeobj, _nt.complexfloating): if issubclass(dtypeobj, _nt.clongfloat): - format_function = formatdict['longcomplexfloat'] + return formatdict['longcomplexfloat'] else: - format_function = formatdict['complexfloat'] + return formatdict['complexfloat'] elif issubclass(dtypeobj, (_nt.unicode_, _nt.string_)): - format_function = formatdict['numpystr'] + return formatdict['numpystr'] elif issubclass(dtypeobj, _nt.datetime64): - format_function = formatdict['datetime'] + return formatdict['datetime'] else: - format_function = formatdict['numpystr'] + return formatdict['numpystr'] + +def _array2string(a, max_line_width, precision, suppress_small, separator=' ', + prefix="", formatter=None): + + if max_line_width is None: + max_line_width = _line_width + + if precision is None: + precision = _float_output_precision + + if suppress_small is None: + suppress_small = _float_output_suppress_small + + if formatter is None: + formatter = _formatter + + if a.size > _summaryThreshold: + summary_insert = "..., " + data = _leading_trailing(a) + else: + summary_insert = "" + data = ravel(asarray(a)) + + # find the right formatting function for the array + format_function = _get_format_function(data, precision, + suppress_small, formatter) # skip over "[" next_line_prefix = " " @@ -758,3 +779,15 @@ def __call__(self, x): return self._nat else: return self.format % x.astype('i8') + + +class StructureFormat(object): + def __init__(self, format_functions): + self.format_functions = format_functions + self.num_fields = len(format_functions) + + def __call__(self, x): + s = "(" + for field, format_function in zip(x, self.format_functions): + s += format_function(field) + ", " + return (s[:-2] if 1 < self.num_fields else s[:-1]) + ")" diff --git a/numpy/core/tests/test_arrayprint.py b/numpy/core/tests/test_arrayprint.py index 991ead97363f..97b5420ca508 100644 --- a/numpy/core/tests/test_arrayprint.py +++ b/numpy/core/tests/test_arrayprint.py @@ -113,6 +113,21 @@ def _format_function(x): assert_(np.array2string(s, formatter={'numpystr':lambda s: s*2}) == '[abcabc defdef]') + def test_structure_format(self): + dt = np.dtype([('name', np.str_, 16), ('grades', np.float64, (2,))]) + x = np.array([('Sarah', (8.0, 7.0)), ('John', (6.0, 7.0))], dtype=dt) + assert_equal(np.array2string(x), + "[('Sarah', array([ 8., 7.])) ('John', array([ 6., 7.]))]") + + # for issue #5692 + A = np.zeros(shape=10, dtype=[("A", "M8[s]")]) + A[5:].fill(np.nan) + assert_equal(np.array2string(A), + "[('1970-01-01T00:00:00',) ('1970-01-01T00:00:00',) " + + "('1970-01-01T00:00:00',)\n ('1970-01-01T00:00:00',) " + + "('1970-01-01T00:00:00',) ('NaT',) ('NaT',)\n " + + "('NaT',) ('NaT',) ('NaT',)]") + class TestPrintOptions: """Test getting and setting global print options."""