From 3c1a13dea6a7e189675977ad65ea230ce4816061 Mon Sep 17 00:00:00 2001 From: Allan Haldane Date: Thu, 5 Mar 2015 16:58:18 -0500 Subject: [PATCH] ENH: simplify field indexing of structured arrays This commit simplifies the code in array_subscript and array_assign_subscript related to field access. This fixes #4806, and also removes a potential segfault, e.g. if the array is indexed using a sequence-like object that raises an exception in getitem. Also fixes #5631, related to creation of structured dtypes with no fields (an unusual and probably useless edge case). Also moves all imports in _internal.py to the top. Fixes #4806. Fixes #5631. --- numpy/core/_internal.py | 75 ++++++--- numpy/core/src/multiarray/arraytypes.c.src | 2 +- numpy/core/src/multiarray/mapping.c | 175 +++++++++++---------- numpy/core/tests/test_indexing.py | 4 + numpy/core/tests/test_multiarray.py | 30 ++-- numpy/core/tests/test_records.py | 5 + 6 files changed, 165 insertions(+), 126 deletions(-) diff --git a/numpy/core/_internal.py b/numpy/core/_internal.py index e80c22dfe0d3..a20bf10e4982 100644 --- a/numpy/core/_internal.py +++ b/numpy/core/_internal.py @@ -10,7 +10,10 @@ import sys import warnings -from numpy.compat import asbytes, bytes +from numpy.compat import asbytes, bytes, basestring +from .multiarray import dtype, array, ndarray +import ctypes +from .numerictypes import object_ if (sys.byteorder == 'little'): _nbo = asbytes('<') @@ -18,7 +21,6 @@ _nbo = asbytes('>') def _makenames_list(adict, align): - from .multiarray import dtype allfields = [] fnames = list(adict.keys()) for fname in fnames: @@ -52,7 +54,6 @@ def _makenames_list(adict, align): # a dictionary without "names" and "formats" # fields is used as a data-type descriptor. def _usefields(adict, align): - from .multiarray import dtype try: names = adict[-1] except KeyError: @@ -130,7 +131,6 @@ def _array_descr(descriptor): # so don't remove the name here, or you'll # break backward compatibilty. 
def _reconstruct(subtype, shape, dtype): - from .multiarray import ndarray return ndarray.__new__(subtype, shape, dtype) @@ -194,12 +194,10 @@ def _commastring(astr): return result def _getintp_ctype(): - from .multiarray import dtype val = _getintp_ctype.cache if val is not None: return val char = dtype('p').char - import ctypes if (char == 'i'): val = ctypes.c_int elif char == 'l': @@ -224,7 +222,6 @@ def c_void_p(self, num): class _ctypes(object): def __init__(self, array, ptr=None): try: - import ctypes self._ctypes = ctypes except ImportError: self._ctypes = _missing_ctypes() @@ -287,23 +284,55 @@ def _newnames(datatype, order): return tuple(list(order) + nameslist) raise ValueError("unsupported order value: %s" % (order,)) -# Given an array with fields and a sequence of field names -# construct a new array with just those fields copied over -def _index_fields(ary, fields): - from .multiarray import empty, dtype, array +def _index_fields(ary, names): + """ Given a structured array and a sequence of field names + construct new array with just those fields. + + Parameters + ---------- + ary : ndarray + Structured array being subscripted + names : string or list of strings + Either a single field name, or a list of field names + + Returns + ------- + sub_ary : ndarray + If `names` is a single field name, the return value is identical to + ary.getfield, a writeable view into `ary`. If `names` is a list of + field names the return value is a copy of `ary` containing only those + fields. This is planned to return a view in the future. + + Raises + ------ + ValueError + If `ary` does not contain a field given in `names`. 
+ + """ dt = ary.dtype - names = [name for name in fields if name in dt.names] - formats = [dt.fields[name][0] for name in fields if name in dt.names] - offsets = [dt.fields[name][1] for name in fields if name in dt.names] + #use getfield to index a single field + if isinstance(names, basestring): + try: + return ary.getfield(dt.fields[names][0], dt.fields[names][1]) + except KeyError: + raise ValueError("no field of name %s" % names) + + for name in names: + if name not in dt.fields: + raise ValueError("no field of name %s" % name) - view_dtype = {'names':names, 'formats':formats, 'offsets':offsets, 'itemsize':dt.itemsize} - view = ary.view(dtype=view_dtype) + formats = [dt.fields[name][0] for name in names] + offsets = [dt.fields[name][1] for name in names] + + view_dtype = {'names': names, 'formats': formats, + 'offsets': offsets, 'itemsize': dt.itemsize} + + # return copy for now (future plan to return ary.view(dtype=view_dtype)) + copy_dtype = {'names': view_dtype['names'], + 'formats': view_dtype['formats']} + return array(ary.view(dtype=view_dtype), dtype=copy_dtype, copy=True) - # Return a copy for now until behavior is fully deprecated - # in favor of returning view - copy_dtype = {'names':view_dtype['names'], 'formats':view_dtype['formats']} - return array(view, dtype=copy_dtype, copy=True) def _get_all_field_offsets(dtype, base_offset=0): """ Returns the types and offsets of all fields in a (possibly structured) @@ -363,8 +392,6 @@ def _check_field_overlap(new_fields, old_fields): If the new fields are incompatible with the old fields """ - from .numerictypes import object_ - from .multiarray import dtype #first go byte by byte and check we do not access bytes not in old_fields new_bytes = set() @@ -527,8 +554,6 @@ def _view_is_safe(oldtype, newtype): _pep3118_standard_typechars = ''.join(_pep3118_standard_map.keys()) def _dtype_from_pep3118(spec, byteorder='@', is_subdtype=False): - from numpy.core.multiarray import dtype - fields = {} offset = 0 
explicit_name = False @@ -694,8 +719,6 @@ def get_dummy_name(): def _add_trailing_padding(value, padding): """Inject the specified number of padding bytes at the end of a dtype""" - from numpy.core.multiarray import dtype - if value.fields is None: vfields = {'f0': (value, 0)} else: diff --git a/numpy/core/src/multiarray/arraytypes.c.src b/numpy/core/src/multiarray/arraytypes.c.src index abe3d749d35a..307f9a0c0c48 100644 --- a/numpy/core/src/multiarray/arraytypes.c.src +++ b/numpy/core/src/multiarray/arraytypes.c.src @@ -726,7 +726,7 @@ VOID_setitem(PyObject *op, char *ip, PyArrayObject *ap) PyObject *tup; int savedflags; - res = -1; + res = 0; /* get the names from the fields dictionary*/ names = descr->names; n = PyTuple_GET_SIZE(names); diff --git a/numpy/core/src/multiarray/mapping.c b/numpy/core/src/multiarray/mapping.c index a703f3d845c9..de9a2d444296 100644 --- a/numpy/core/src/multiarray/mapping.c +++ b/numpy/core/src/multiarray/mapping.c @@ -215,6 +215,7 @@ prepare_index(PyArrayObject *self, PyObject *index, } for (i = 0; i < n; i++) { PyObject *tmp_obj = PySequence_GetItem(index, i); + /* if getitem fails (unusual) treat this as a single index */ if (tmp_obj == NULL) { PyErr_Clear(); make_tuple = 0; @@ -1361,6 +1362,52 @@ array_subscript_asarray(PyArrayObject *self, PyObject *op) return PyArray_EnsureAnyArray(array_subscript(self, op)); } +NPY_NO_EXPORT int +obj_is_string_or_stringlist(PyObject *op) +{ +#if defined(NPY_PY3K) + if (PyUnicode_Check(op)) { +#else + if (PyString_Check(op) || PyUnicode_Check(op)) { +#endif + return 1; + } + else if (PySequence_Check(op) && !PyTuple_Check(op)) { + int seqlen, i; + PyObject *obj = NULL; + seqlen = PySequence_Size(op); + + /* quit if we come across a 0-d array (seqlen==-1) or a 0-len array */ + if (seqlen == -1) { + PyErr_Clear(); + return 0; + } + if (seqlen == 0) { + return 0; + } + + for (i = 0; i < seqlen; i++) { + obj = PySequence_GetItem(op, i); + if (obj == NULL) { + /* only happens for strange sequence 
objects. Silently fail */ + PyErr_Clear(); + return 0; + } + +#if defined(NPY_PY3K) + if (!PyUnicode_Check(obj)) { +#else + if (!PyString_Check(obj) && !PyUnicode_Check(obj)) { +#endif + Py_DECREF(obj); + return 0; + } + Py_DECREF(obj); + } + return 1; + } + return 0; +} /* * General function for indexing a NumPy array with a Python object. @@ -1382,76 +1429,26 @@ array_subscript(PyArrayObject *self, PyObject *op) PyArrayMapIterObject * mit = NULL; - /* Check for multiple field access */ - if (PyDataType_HASFIELDS(PyArray_DESCR(self))) { - /* Check for single field access */ - /* - * TODO: Moving this code block into the HASFIELDS, means that - * string integers temporarily work as indices. - */ - if (PyString_Check(op) || PyUnicode_Check(op)) { - PyObject *temp, *obj; - - if (PyDataType_HASFIELDS(PyArray_DESCR(self))) { - obj = PyDict_GetItem(PyArray_DESCR(self)->fields, op); - if (obj != NULL) { - PyArray_Descr *descr; - int offset; - PyObject *title; - - if (PyArg_ParseTuple(obj, "Oi|O", &descr, &offset, &title)) { - Py_INCREF(descr); - return PyArray_GetField(self, descr, offset); - } - } - } + /* return fields if op is a string index */ + if (PyDataType_HASFIELDS(PyArray_DESCR(self)) && + obj_is_string_or_stringlist(op)) { + PyObject *obj; + static PyObject *indexfunc = NULL; + npy_cache_pyfunc("numpy.core._internal", "_index_fields", &indexfunc); + if (indexfunc == NULL) { + return NULL; + } - temp = op; - if (PyUnicode_Check(op)) { - temp = PyUnicode_AsUnicodeEscapeString(op); - } - PyErr_Format(PyExc_ValueError, - "field named %s not found", - PyBytes_AsString(temp)); - if (temp != op) { - Py_DECREF(temp); - } + obj = PyObject_CallFunction(indexfunc, "OO", self, op); + if (obj == NULL) { return NULL; } - else if (PySequence_Check(op) && !PyTuple_Check(op)) { - int seqlen, i; - PyObject *obj; - seqlen = PySequence_Size(op); - for (i = 0; i < seqlen; i++) { - obj = PySequence_GetItem(op, i); - if (!PyString_Check(obj) && !PyUnicode_Check(obj)) { - 
Py_DECREF(obj); - break; - } - Py_DECREF(obj); - } - /* - * Extract multiple fields if all elements in sequence - * are either string or unicode (i.e. no break occurred). - */ - fancy = ((seqlen > 0) && (i == seqlen)); - if (fancy) { - PyObject *_numpy_internal; - _numpy_internal = PyImport_ImportModule("numpy.core._internal"); - if (_numpy_internal == NULL) { - return NULL; - } - obj = PyObject_CallMethod(_numpy_internal, - "_index_fields", "OO", self, op); - Py_DECREF(_numpy_internal); - if (obj == NULL) { - return NULL; - } - PyArray_ENABLEFLAGS((PyArrayObject*)obj, NPY_ARRAY_WARN_ON_WRITE); - return obj; - } + /* warn if writing to a copy. copies will have no base */ + if (PyArray_BASE((PyArrayObject*)obj) == NULL) { + PyArray_ENABLEFLAGS((PyArrayObject*)obj, NPY_ARRAY_WARN_ON_WRITE); } + return obj; } /* Prepare the indices */ @@ -1783,35 +1780,39 @@ array_assign_subscript(PyArrayObject *self, PyObject *ind, PyObject *op) return -1; } - /* Single field access */ - if (PyDataType_HASFIELDS(PyArray_DESCR(self))) { - if (PyString_Check(ind) || PyUnicode_Check(ind)) { - PyObject *obj; - - obj = PyDict_GetItem(PyArray_DESCR(self)->fields, ind); - if (obj != NULL) { - PyArray_Descr *descr; - int offset; - PyObject *title; + /* field access */ + if (PyDataType_HASFIELDS(PyArray_DESCR(self)) && + obj_is_string_or_stringlist(ind)) { + PyObject *obj; + static PyObject *indexfunc = NULL; - if (PyArg_ParseTuple(obj, "Oi|O", &descr, &offset, &title)) { - Py_INCREF(descr); - return PyArray_SetField(self, descr, offset, op); - } - } #if defined(NPY_PY3K) - PyErr_Format(PyExc_ValueError, - "field named %S not found", - ind); + if (!PyUnicode_Check(ind)) { #else - PyErr_Format(PyExc_ValueError, - "field named %s not found", - PyString_AsString(ind)); + if (!PyString_Check(ind) && !PyUnicode_Check(ind)) { #endif + PyErr_SetString(PyExc_ValueError, + "multi-field assignment is not supported"); + } + + npy_cache_pyfunc("numpy.core._internal", "_index_fields", &indexfunc); + if 
(indexfunc == NULL) { + return -1; + } + + obj = PyObject_CallFunction(indexfunc, "OO", self, ind); + if (obj == NULL) { return -1; } - } + if (PyArray_CopyObject((PyArrayObject*)obj, op) < 0) { + Py_DECREF(obj); + return -1; + } + Py_DECREF(obj); + + return 0; + } /* Prepare the indices */ index_type = prepare_index(self, ind, indices, &index_num, diff --git a/numpy/core/tests/test_indexing.py b/numpy/core/tests/test_indexing.py index e55c212b78e1..d412c44fb313 100644 --- a/numpy/core/tests/test_indexing.py +++ b/numpy/core/tests/test_indexing.py @@ -409,6 +409,10 @@ def __getitem__(self, item): arr = np.arange(10) assert_array_equal(arr[SequenceLike()], arr[SequenceLike(),]) + # also test that field indexing does not segfault + # for a similar reason, by indexing a structured array + arr = np.zeros((1,), dtype=[('f1', 'i8'), ('f2', 'i8')]) + assert_array_equal(arr[SequenceLike()], arr[SequenceLike(),]) class TestFieldIndexing(TestCase): def test_scalar_return_type(self): diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py index 6f45972c3c53..ac645f01322c 100644 --- a/numpy/core/tests/test_multiarray.py +++ b/numpy/core/tests/test_multiarray.py @@ -3484,7 +3484,7 @@ def test_bytes_fields(self): assert_raises(ValueError, dt.__getitem__, asbytes('a')) x = np.array([(1,), (2,), (3,)], dtype=dt) - assert_raises(ValueError, x.__getitem__, asbytes('a')) + assert_raises(IndexError, x.__getitem__, asbytes('a')) y = x[0] assert_raises(IndexError, y.__getitem__, asbytes('a')) @@ -3517,8 +3517,8 @@ def test_field_names(self): if is_py3: funcs = (str,) # byte string indexing fails gracefully - assert_raises(ValueError, a.__setitem__, asbytes('f1'), 1) - assert_raises(ValueError, a.__getitem__, asbytes('f1')) + assert_raises(IndexError, a.__setitem__, asbytes('f1'), 1) + assert_raises(IndexError, a.__getitem__, asbytes('f1')) assert_raises(IndexError, a['f1'].__setitem__, asbytes('sf1'), 1) assert_raises(IndexError, a['f1'].__getitem__, 
asbytes('sf1')) else: @@ -3564,7 +3564,7 @@ def test_field_names(self): def test_field_names_deprecation(self): - def collect_warning_types(f, *args, **kwargs): + def collect_warnings(f, *args, **kwargs): with warnings.catch_warnings(record=True) as log: warnings.simplefilter("always") f(*args, **kwargs) @@ -3585,20 +3585,19 @@ def collect_warning_types(f, *args, **kwargs): # All the different functions raise a warning, but not an error, and # 'a' is not modified: - assert_equal(collect_warning_types(a[['f1', 'f2']].__setitem__, 0, (10, 20)), + assert_equal(collect_warnings(a[['f1', 'f2']].__setitem__, 0, (10, 20)), [FutureWarning]) assert_equal(a, b) # Views also warn subset = a[['f1', 'f2']] subset_view = subset.view() - assert_equal(collect_warning_types(subset_view['f1'].__setitem__, 0, 10), + assert_equal(collect_warnings(subset_view['f1'].__setitem__, 0, 10), [FutureWarning]) # But the write goes through: assert_equal(subset['f1'][0], 10) - # Only one warning per multiple field indexing, though (even if there are - # multiple views involved): - assert_equal(collect_warning_types(subset['f1'].__setitem__, 0, 10), - []) + # Only one warning per multiple field indexing, though (even if there + # are multiple views involved): + assert_equal(collect_warnings(subset['f1'].__setitem__, 0, 10), []) def test_record_hash(self): a = np.array([(1, 2), (1, 2)], dtype='i1,i2') @@ -3616,11 +3615,18 @@ def test_record_no_hash(self): a = np.array([(1, 2), (1, 2)], dtype='i1,i2') self.assertRaises(TypeError, hash, a[0]) + def test_empty_structure_creation(self): + # make sure these do not raise errors (gh-5631) + array([()], dtype={'names': [], 'formats': [], + 'offsets': [], 'itemsize': 12}) + array([(), (), (), (), ()], dtype={'names': [], 'formats': [], + 'offsets': [], 'itemsize': 12}) + class TestView(TestCase): def test_basic(self): x = np.array([(1, 2, 3, 4), (5, 6, 7, 8)], - dtype=[('r', np.int8), ('g', np.int8), - ('b', np.int8), ('a', np.int8)]) + dtype=[('r', 
np.int8), ('g', np.int8), + ('b', np.int8), ('a', np.int8)]) # We must be specific about the endianness here: y = x.view(dtype='