ENH: Allow dtype objects to be indexed with multiple fields at once by eric-wieser · Pull Request #10417 · numpy/numpy · GitHub
[go: up one dir, main page]

Skip to content

ENH: Allow dtype objects to be indexed with multiple fields at once #10417

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jun 7, 2019
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
MAINT: Reuse code from dtype.__getitem__ in mapping.c:_get_field_view
  • Loading branch information
eric-wieser committed May 11, 2019
commit 90f710be699b26b99899a73dd6cc840f9908e4b6
8 changes: 4 additions & 4 deletions numpy/core/src/multiarray/descriptor.c
Original file line number Diff line number Diff line change
Expand Up @@ -3381,8 +3381,8 @@ _is_list_of_strings(PyObject *obj)
return NPY_TRUE;
}

static PyObject *
_subscript_by_list_of_strings(PyArray_Descr *self, PyObject *ind)
NPY_NO_EXPORT PyArray_Descr *
arraydescr_field_subset_view(PyArray_Descr *self, PyObject *ind)
{
int seqlen, i;
PyObject *fields = NULL;
Expand Down Expand Up @@ -3466,7 +3466,7 @@ _subscript_by_list_of_strings(PyArray_Descr *self, PyObject *ind)
view_dtype->names = names;
view_dtype->fields = fields;
view_dtype->flags = self->flags;
return (PyObject *)view_dtype;
return view_dtype;

fail:
Py_XDECREF(fields);
Expand All @@ -3485,7 +3485,7 @@ descr_subscript(PyArray_Descr *self, PyObject *op)
return _subscript_by_name(self, op);
}
else if (_is_list_of_strings(op)) {
return _subscript_by_list_of_strings(self, op);
return (PyObject *)arraydescr_field_subset_view(self, op);
}
else {
Py_ssize_t i = PyArray_PyIntAsIntp(op);
Expand Down
12 changes: 12 additions & 0 deletions numpy/core/src/multiarray/descriptor.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,18 @@ _arraydescr_from_dtype_attr(PyObject *obj);
NPY_NO_EXPORT int
is_dtype_struct_simple_unaligned_layout(PyArray_Descr *dtype);

/*
* Filter the fields of a dtype to only those in the list of strings, ind.
*
* No type checking is performed on the input.
*
* Raises:
* ValueError - if a field is repeated
* KeyError - if an invalid field name (or any field title) is used
*/
NPY_NO_EXPORT PyArray_Descr *
arraydescr_field_subset_view(PyArray_Descr *self, PyObject *ind);

extern NPY_NO_EXPORT char *_datetime_strings[];

#endif
100 changes: 17 additions & 83 deletions numpy/core/src/multiarray/mapping.c
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

#include "common.h"
#include "ctors.h"
#include "descriptor.h"
#include "iterators.h"
#include "mapping.h"
#include "lowlevel_strided_loops.h"
Expand Down Expand Up @@ -1393,9 +1394,9 @@ array_subscript_asarray(PyArrayObject *self, PyObject *op)
/*
* Attempts to subscript an array using a field name or list of field names.
*
* If an error occurred, return 0 and set view to NULL. If the subscript is not
* a string or list of strings, return -1 and set view to NULL. Otherwise
* return 0 and set view to point to a new view into arr for the given fields.
* ret = 0, view != NULL: view points to the requested fields of arr
* ret = 0, view == NULL: an error occurred
* ret = -1, view == NULL: unrecognized input, this is not a field index.
*/
NPY_NO_EXPORT int
_get_field_view(PyArrayObject *arr, PyObject *ind, PyArrayObject **view)
Expand Down Expand Up @@ -1438,111 +1439,44 @@ _get_field_view(PyArrayObject *arr, PyObject *ind, PyArrayObject **view)
}
return 0;
}

/* next check for a list of field names */
else if (PySequence_Check(ind) && !PyTuple_Check(ind)) {
int seqlen, i;
PyObject *name = NULL, *tup;
PyObject *fields, *names;
npy_intp seqlen, i;
PyArray_Descr *view_dtype;

seqlen = PySequence_Size(ind);

/* quit if have a 0-d array (seqlen==-1) or a 0-len array */
/* quit if have a fake sequence-like, which errors on len()*/
if (seqlen == -1) {
PyErr_Clear();
return -1;
}
/* 0-len list is handled elsewhere as an integer index */
if (seqlen == 0) {
return -1;
}

fields = PyDict_New();
if (fields == NULL) {
return 0;
}
names = PyTuple_New(seqlen);
if (names == NULL) {
Py_DECREF(fields);
return 0;
}

/* check the items are strings */
for (i = 0; i < seqlen; i++) {
name = PySequence_GetItem(ind, i);
if (name == NULL) {
/* only happens for strange sequence objects */
npy_bool is_string;
PyObject *item = PySequence_GetItem(ind, i);
if(item == NULL) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if(item == NULL) {
if (item == NULL) {

Just nitpicking for no reason.

PyErr_Clear();
Py_DECREF(fields);
Py_DECREF(names);
return -1;
}

if (!PyBaseString_Check(name)) {
Py_DECREF(name);
Py_DECREF(fields);
Py_DECREF(names);
is_string = PyBaseString_Check(item);
Py_DECREF(item);
if (!is_string) {
return -1;
}

tup = PyDict_GetItem(PyArray_DESCR(arr)->fields, name);
if (tup == NULL){
PyObject *errmsg = PyUString_FromString("no field of name ");
PyUString_ConcatAndDel(&errmsg, name);
PyErr_SetObject(PyExc_ValueError, errmsg);
Py_DECREF(errmsg);
Py_DECREF(fields);
Py_DECREF(names);
return 0;
}
/* disallow use of titles as index */
if (PyTuple_Size(tup) == 3) {
PyObject *title = PyTuple_GET_ITEM(tup, 2);
int titlecmp = PyObject_RichCompareBool(title, name, Py_EQ);
if (titlecmp == 1) {
/* if title == name, we got a title, not a field name */
PyErr_SetString(PyExc_KeyError,
"cannot use field titles in multi-field index");
}
if (titlecmp != 0 || PyDict_SetItem(fields, title, tup) < 0) {
Py_DECREF(name);
Py_DECREF(fields);
Py_DECREF(names);
return 0;
}
}
/* disallow duplicate field indices */
if (PyDict_Contains(fields, name)) {
PyObject *errmsg = PyUString_FromString(
"duplicate field of name ");
PyUString_ConcatAndDel(&errmsg, name);
PyErr_SetObject(PyExc_ValueError, errmsg);
Py_DECREF(errmsg);
Py_DECREF(fields);
Py_DECREF(names);
return 0;
}
if (PyDict_SetItem(fields, name, tup) < 0) {
Py_DECREF(name);
Py_DECREF(fields);
Py_DECREF(names);
return 0;
}
if (PyTuple_SetItem(names, i, name) < 0) {
Py_DECREF(fields);
Py_DECREF(names);
return 0;
}
}

view_dtype = PyArray_DescrNewFromType(NPY_VOID);
/* Call into the dtype subscript */
view_dtype = arraydescr_field_subset_view(PyArray_DESCR(arr), ind);
if (view_dtype == NULL) {
Py_DECREF(fields);
Py_DECREF(names);
return 0;
}
view_dtype->elsize = PyArray_DESCR(arr)->elsize;
view_dtype->names = names;
view_dtype->fields = fields;
view_dtype->flags = PyArray_DESCR(arr)->flags;

*view = (PyArrayObject*)PyArray_NewFromDescr_int(
Py_TYPE(arr),
Expand Down
2 changes: 1 addition & 1 deletion numpy/core/tests/test_records.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,7 +435,7 @@ def test_record_scalar_setitem(self):
def test_missing_field(self):
# https://github.com/numpy/numpy/issues/4806
arr = np.zeros((3,), dtype=[('x', int), ('y', int)])
assert_raises(ValueError, lambda: arr[['nofield']])
assert_raises(KeyError, lambda: arr[['nofield']])

def test_fromarrays_nested_structured_arrays(self):
arrays = [
Expand Down
0