diff --git a/doc/release/1.17.0-notes.rst b/doc/release/1.17.0-notes.rst index 1fa6fb9d6a8b..0007fe1a638c 100644 --- a/doc/release/1.17.0-notes.rst +++ b/doc/release/1.17.0-notes.rst @@ -240,6 +240,18 @@ Floating point scalars implement ``as_integer_ratio`` to match the builtin float This returns a (numerator, denominator) pair, which can be used to construct a `fractions.Fraction`. +structured ``dtype`` objects can be indexed with multiple fields names +---------------------------------------------------------------------- +``arr.dtype[['a', 'b']]`` now returns a dtype that is equivalent to +``arr[['a', 'b']].dtype``, for consistency with +``arr.dtype['a'] == arr['a'].dtype``. + +Like the dtype of structured arrays indexed with a list of fields, this dtype +has the same `itemsize` as the original, but only keeps a subset of the fields. + +This means that `arr[['a', 'b']]` and ``arr.view(arr.dtype[['a', 'b']])`` are +equivalent. + ``.npy`` files support unicode field names ------------------------------------------ A new format version of 3.0 has been introduced, which enables structured types @@ -435,4 +447,9 @@ Additionally, there are some corner cases with behavior changes: ------------------------------------------------------ The interface may use an ``offset`` value that was mistakenly ignored. +Structured arrays indexed with non-existent fields raise ``KeyError`` not ``ValueError`` +---------------------------------------------------------------------------------------- +``arr['bad_field']`` on a structured type raises ``KeyError``, for consistency +with ``dict['bad_field']``. + .. _`NEP 18` : http://www.numpy.org/neps/nep-0018-array-function-protocol.html diff --git a/numpy/core/src/multiarray/descriptor.c b/numpy/core/src/multiarray/descriptor.c index e3e0da24ac49..24716aecf31c 100644 --- a/numpy/core/src/multiarray/descriptor.c +++ b/numpy/core/src/multiarray/descriptor.c @@ -3370,6 +3370,117 @@ _subscript_by_index(PyArray_Descr *self, Py_ssize_t i) return ret; } +static npy_bool +_is_list_of_strings(PyObject *obj) +{ + int seqlen, i; + if (!PyList_CheckExact(obj)) { + return NPY_FALSE; + } + seqlen = PyList_GET_SIZE(obj); + for (i = 0; i < seqlen; i++) { + PyObject *item = PyList_GET_ITEM(obj, i); + if (!PyBaseString_Check(item)) { + return NPY_FALSE; + } + } + + return NPY_TRUE; +} + +NPY_NO_EXPORT PyArray_Descr * +arraydescr_field_subset_view(PyArray_Descr *self, PyObject *ind) +{ + int seqlen, i; + PyObject *fields = NULL; + PyObject *names = NULL; + PyArray_Descr *view_dtype; + + seqlen = PySequence_Size(ind); + if (seqlen == -1) { + return NULL; + } + + fields = PyDict_New(); + if (fields == NULL) { + goto fail; + } + names = PyTuple_New(seqlen); + if (names == NULL) { + goto fail; + } + + for (i = 0; i < seqlen; i++) { + PyObject *name; + PyObject *tup; + + name = PySequence_GetItem(ind, i); + if (name == NULL) { + goto fail; + } + + /* Let the names tuple steal a reference now, so we don't need to + * decref name if an error occurs further on. + */ + PyTuple_SET_ITEM(names, i, name); + + tup = PyDict_GetItem(self->fields, name); + if (tup == NULL) { + PyErr_SetObject(PyExc_KeyError, name); + goto fail; + } + + /* disallow use of titles as index */ + if (PyTuple_Size(tup) == 3) { + PyObject *title = PyTuple_GET_ITEM(tup, 2); + int titlecmp = PyObject_RichCompareBool(title, name, Py_EQ); + if (titlecmp < 0) { + goto fail; + } + if (titlecmp == 1) { + /* if title == name, we were given a title, not a field name */ + PyErr_SetString(PyExc_KeyError, + "cannot use field titles in multi-field index"); + goto fail; + } + if (PyDict_SetItem(fields, title, tup) < 0) { + goto fail; + } + } + /* disallow duplicate field indices */ + if (PyDict_Contains(fields, name)) { + PyObject *msg = NULL; + PyObject *fmt = PyUString_FromString( + "duplicate field of name {!r}"); + if (fmt != NULL) { + msg = PyObject_CallMethod(fmt, "format", "O", name); + Py_DECREF(fmt); + } + PyErr_SetObject(PyExc_ValueError, msg); + Py_XDECREF(msg); + goto fail; + } + if (PyDict_SetItem(fields, name, tup) < 0) { + goto fail; + } + } + + view_dtype = PyArray_DescrNewFromType(NPY_VOID); + if (view_dtype == NULL) { + goto fail; + } + view_dtype->elsize = self->elsize; + view_dtype->names = names; + view_dtype->fields = fields; + view_dtype->flags = self->flags; + return view_dtype; + +fail: + Py_XDECREF(fields); + Py_XDECREF(names); + return NULL; +} + static PyObject * descr_subscript(PyArray_Descr *self, PyObject *op) { @@ -3380,6 +3491,9 @@ descr_subscript(PyArray_Descr *self, PyObject *op) if (PyBaseString_Check(op)) { return _subscript_by_name(self, op); } + else if (_is_list_of_strings(op)) { + return (PyObject *)arraydescr_field_subset_view(self, op); + } else { Py_ssize_t i = PyArray_PyIntAsIntp(op); if (error_converting(i)) { @@ -3387,7 +3501,8 @@ descr_subscript(PyArray_Descr *self, PyObject *op) PyObject *err = PyErr_Occurred(); if (PyErr_GivenExceptionMatches(err, PyExc_TypeError)) { PyErr_SetString(PyExc_TypeError, - "Field key must be an integer, string, or unicode."); + "Field key must be an integer field offset, " + "single field name, or list of field names."); } return NULL; } diff --git a/numpy/core/src/multiarray/descriptor.h b/numpy/core/src/multiarray/descriptor.h index a5f3b8cdf1f1..de749736349a 100644 --- a/numpy/core/src/multiarray/descriptor.h +++ b/numpy/core/src/multiarray/descriptor.h @@ -14,6 +14,18 @@ _arraydescr_from_dtype_attr(PyObject *obj); NPY_NO_EXPORT int is_dtype_struct_simple_unaligned_layout(PyArray_Descr *dtype); +/* + * Filter the fields of a dtype to only those in the list of strings, ind. + * + * No type checking is performed on the input. + * + * Raises: + * ValueError - if a field is repeated + * KeyError - if an invalid field name (or any field title) is used + */ +NPY_NO_EXPORT PyArray_Descr * +arraydescr_field_subset_view(PyArray_Descr *self, PyObject *ind); + extern NPY_NO_EXPORT char *_datetime_strings[]; #endif diff --git a/numpy/core/src/multiarray/mapping.c b/numpy/core/src/multiarray/mapping.c index 10206c03ee41..1109e6dd5076 100644 --- a/numpy/core/src/multiarray/mapping.c +++ b/numpy/core/src/multiarray/mapping.c @@ -15,6 +15,7 @@ #include "common.h" #include "ctors.h" +#include "descriptor.h" #include "iterators.h" #include "mapping.h" #include "lowlevel_strided_loops.h" @@ -1393,9 +1394,9 @@ array_subscript_asarray(PyArrayObject *self, PyObject *op) /* * Attempts to subscript an array using a field name or list of field names. * - * If an error occurred, return 0 and set view to NULL. If the subscript is not - * a string or list of strings, return -1 and set view to NULL. Otherwise - * return 0 and set view to point to a new view into arr for the given fields. + * ret = 0, view != NULL: view points to the requested fields of arr + * ret = 0, view == NULL: an error occurred + * ret = -1, view == NULL: unrecognized input, this is not a field index. */ NPY_NO_EXPORT int _get_field_view(PyArrayObject *arr, PyObject *ind, PyArrayObject **view) @@ -1438,111 +1439,44 @@ _get_field_view(PyArrayObject *arr, PyObject *ind, PyArrayObject **view) } return 0; } + /* next check for a list of field names */ else if (PySequence_Check(ind) && !PyTuple_Check(ind)) { - int seqlen, i; - PyObject *name = NULL, *tup; - PyObject *fields, *names; + npy_intp seqlen, i; PyArray_Descr *view_dtype; seqlen = PySequence_Size(ind); - /* quit if have a 0-d array (seqlen==-1) or a 0-len array */ + /* quit if have a fake sequence-like, which errors on len()*/ if (seqlen == -1) { PyErr_Clear(); return -1; } + /* 0-len list is handled elsewhere as an integer index */ if (seqlen == 0) { return -1; } - fields = PyDict_New(); - if (fields == NULL) { - return 0; - } - names = PyTuple_New(seqlen); - if (names == NULL) { - Py_DECREF(fields); - return 0; - } - + /* check the items are strings */ for (i = 0; i < seqlen; i++) { - name = PySequence_GetItem(ind, i); - if (name == NULL) { - /* only happens for strange sequence objects */ + npy_bool is_string; + PyObject *item = PySequence_GetItem(ind, i); + if (item == NULL) { PyErr_Clear(); - Py_DECREF(fields); - Py_DECREF(names); return -1; } - - if (!PyBaseString_Check(name)) { - Py_DECREF(name); - Py_DECREF(fields); - Py_DECREF(names); + is_string = PyBaseString_Check(item); + Py_DECREF(item); + if (!is_string) { return -1; } - - tup = PyDict_GetItem(PyArray_DESCR(arr)->fields, name); - if (tup == NULL){ - PyObject *errmsg = PyUString_FromString("no field of name "); - PyUString_ConcatAndDel(&errmsg, name); - PyErr_SetObject(PyExc_ValueError, errmsg); - Py_DECREF(errmsg); - Py_DECREF(fields); - Py_DECREF(names); - return 0; - } - /* disallow use of titles as index */ - if (PyTuple_Size(tup) == 3) { - PyObject *title = PyTuple_GET_ITEM(tup, 2); - int titlecmp = PyObject_RichCompareBool(title, name, Py_EQ); - if (titlecmp == 1) { - /* if title == name, we got a title, not a field name */ - PyErr_SetString(PyExc_KeyError, - "cannot use field titles in multi-field index"); - } - if (titlecmp != 0 || PyDict_SetItem(fields, title, tup) < 0) { - Py_DECREF(name); - Py_DECREF(fields); - Py_DECREF(names); - return 0; - } - } - /* disallow duplicate field indices */ - if (PyDict_Contains(fields, name)) { - PyObject *errmsg = PyUString_FromString( - "duplicate field of name "); - PyUString_ConcatAndDel(&errmsg, name); - PyErr_SetObject(PyExc_ValueError, errmsg); - Py_DECREF(errmsg); - Py_DECREF(fields); - Py_DECREF(names); - return 0; - } - if (PyDict_SetItem(fields, name, tup) < 0) { - Py_DECREF(name); - Py_DECREF(fields); - Py_DECREF(names); - return 0; - } - if (PyTuple_SetItem(names, i, name) < 0) { - Py_DECREF(fields); - Py_DECREF(names); - return 0; - } } - view_dtype = PyArray_DescrNewFromType(NPY_VOID); + /* Call into the dtype subscript */ + view_dtype = arraydescr_field_subset_view(PyArray_DESCR(arr), ind); if (view_dtype == NULL) { - Py_DECREF(fields); - Py_DECREF(names); return 0; } - view_dtype->elsize = PyArray_DESCR(arr)->elsize; - view_dtype->names = names; - view_dtype->fields = fields; - view_dtype->flags = PyArray_DESCR(arr)->flags; *view = (PyArrayObject*)PyArray_NewFromDescr_int( Py_TYPE(arr), diff --git a/numpy/core/tests/test_dtype.py b/numpy/core/tests/test_dtype.py index f4736d6940c7..8f33a8daf126 100644 --- a/numpy/core/tests/test_dtype.py +++ b/numpy/core/tests/test_dtype.py @@ -316,15 +316,66 @@ def test_fields_by_index(self): assert_raises(IndexError, lambda: dt[-3]) assert_raises(TypeError, operator.getitem, dt, 3.0) - assert_raises(TypeError, operator.getitem, dt, []) assert_equal(dt[1], dt[np.int8(1)]) + @pytest.mark.parametrize('align_flag',[False, True]) + def test_multifield_index(self, align_flag): + # indexing with a list produces subfields + # the align flag should be preserved + dt = np.dtype([ + (('title', 'col1'), '