ENH: simplify field indexing of structured arrays by ahaldane · Pull Request #5636 · numpy/numpy · GitHub
[go: up one dir, main page]

Skip to content

ENH: simplify field indexing of structured arrays #5636

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 17, 2015
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 49 additions & 26 deletions numpy/core/_internal.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,17 @@
import sys
import warnings

from numpy.compat import asbytes, bytes
from numpy.compat import asbytes, bytes, basestring
from .multiarray import dtype, array, ndarray
import ctypes
from .numerictypes import object_

if (sys.byteorder == 'little'):
_nbo = asbytes('<')
else:
_nbo = asbytes('>')

def _makenames_list(adict, align):
from .multiarray import dtype
allfields = []
fnames = list(adict.keys())
for fname in fnames:
Expand Down Expand Up @@ -52,7 +54,6 @@ def _makenames_list(adict, align):
# a dictionary without "names" and "formats"
# fields is used as a data-type descriptor.
def _usefields(adict, align):
from .multiarray import dtype
try:
names = adict[-1]
except KeyError:
Expand Down Expand Up @@ -130,7 +131,6 @@ def _array_descr(descriptor):
# so don't remove the name here, or you'll
# break backward compatibilty.
def _reconstruct(subtype, shape, dtype):
from .multiarray import ndarray
return ndarray.__new__(subtype, shape, dtype)


Expand Down Expand Up @@ -194,12 +194,10 @@ def _commastring(astr):
return result

def _getintp_ctype():
from .multiarray import dtype
val = _getintp_ctype.cache
if val is not None:
return val
char = dtype('p').char
import ctypes
if (char == 'i'):
val = ctypes.c_int
elif char == 'l':
Expand All @@ -224,7 +222,6 @@ def c_void_p(self, num):
class _ctypes(object):
def __init__(self, array, ptr=None):
try:
import ctypes
self._ctypes = ctypes
except ImportError:
self._ctypes = _missing_ctypes()
Expand Down Expand Up @@ -287,23 +284,55 @@ def _newnames(datatype, order):
return tuple(list(order) + nameslist)
raise ValueError("unsupported order value: %s" % (order,))

# Given an array with fields and a sequence of field names
# construct a new array with just those fields copied over
def _index_fields(ary, fields):
from .multiarray import empty, dtype, array
def _index_fields(ary, names):
""" Given a structured array and a sequence of field names
construct new array with just those fields.

Parameters
----------
ary : ndarray
Structured array being subscripted
names : string or list of strings
Either a single field name, or a list of field names

Returns
-------
sub_ary : ndarray
If `names` is a single field name, the return value is identical to
ary.getfield, a writeable view into `ary`. If `names` is a list of
field names the return value is a copy of `ary` containing only those
fields. This is planned to return a view in the future.

Raises
------
ValueError
If `ary` does not contain a field given in `names`.

"""
dt = ary.dtype

names = [name for name in fields if name in dt.names]
formats = [dt.fields[name][0] for name in fields if name in dt.names]
offsets = [dt.fields[name][1] for name in fields if name in dt.names]
#use getfield to index a single field
if isinstance(names, basestring):
try:
return ary.getfield(dt.fields[names][0], dt.fields[names][1])
except KeyError:
raise ValueError("no field of name %s" % names)

for name in names:
if name not in dt.fields:
raise ValueError("no field of name %s" % name)

view_dtype = {'names':names, 'formats':formats, 'offsets':offsets, 'itemsize':dt.itemsize}
view = ary.view(dtype=view_dtype)
formats = [dt.fields[name][0] for name in names]
offsets = [dt.fields[name][1] for name in names]

view_dtype = {'names': names, 'formats': formats,
'offsets': offsets, 'itemsize': dt.itemsize}

# return copy for now (future plan to return ary.view(dtype=view_dtype))
copy_dtype = {'names': view_dtype['names'],
'formats': view_dtype['formats']}
return array(ary.view(dtype=view_dtype), dtype=copy_dtype, copy=True)

# Return a copy for now until behavior is fully deprecated
# in favor of returning view
copy_dtype = {'names':view_dtype['names'], 'formats':view_dtype['formats']}
return array(view, dtype=copy_dtype, copy=True)

def _get_all_field_offsets(dtype, base_offset=0):
""" Returns the types and offsets of all fields in a (possibly structured)
Expand Down Expand Up @@ -363,8 +392,6 @@ def _check_field_overlap(new_fields, old_fields):
If the new fields are incompatible with the old fields

"""
from .numerictypes import object_
from .multiarray import dtype

#first go byte by byte and check we do not access bytes not in old_fields
new_bytes = set()
Expand Down Expand Up @@ -527,8 +554,6 @@ def _view_is_safe(oldtype, newtype):
_pep3118_standard_typechars = ''.join(_pep3118_standard_map.keys())

def _dtype_from_pep3118(spec, byteorder='@', is_subdtype=False):
from numpy.core.multiarray import dtype

fields = {}
offset = 0
explicit_name = False
Expand Down Expand Up @@ -694,8 +719,6 @@ def get_dummy_name():

def _add_trailing_padding(value, padding):
"""Inject the specified number of padding bytes at the end of a dtype"""
from numpy.core.multiarray import dtype

if value.fields is None:
vfields = {'f0': (value, 0)}
else:
Expand Down
2 changes: 1 addition & 1 deletion numpy/core/src/multiarray/arraytypes.c.src
Original file line number Diff line number Diff line change
Expand Up @@ -726,7 +726,7 @@ VOID_setitem(PyObject *op, char *ip, PyArrayObject *ap)
PyObject *tup;
int savedflags;

res = -1;
res = 0;
/* get the names from the fields dictionary*/
names = descr->names;
n = PyTuple_GET_SIZE(names);
Expand Down
175 changes: 88 additions & 87 deletions numpy/core/src/multiarray/mapping.c
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,7 @@ prepare_index(PyArrayObject *self, PyObject *index,
}
for (i = 0; i < n; i++) {
PyObject *tmp_obj = PySequence_GetItem(index, i);
/* if getitem fails (unusual) treat this as a single index */
if (tmp_obj == NULL) {
PyErr_Clear();
make_tuple = 0;
Expand Down Expand Up @@ -1361,6 +1362,52 @@ array_subscript_asarray(PyArrayObject *self, PyObject *op)
return PyArray_EnsureAnyArray(array_subscript(self, op));
}

NPY_NO_EXPORT int
obj_is_string_or_stringlist(PyObject *op)
{
#if defined(NPY_PY3K)
if (PyUnicode_Check(op)) {
#else
if (PyString_Check(op) || PyUnicode_Check(op)) {
#endif
return 1;
}
else if (PySequence_Check(op) && !PyTuple_Check(op)) {
int seqlen, i;
PyObject *obj = NULL;
seqlen = PySequence_Size(op);

/* quit if we come across a 0-d array (seqlen==-1) or a 0-len array */
if (seqlen == -1) {
PyErr_Clear();
return 0;
}
if (seqlen == 0) {
return 0;
}

for (i = 0; i < seqlen; i++) {
obj = PySequence_GetItem(op, i);
if (obj == NULL) {
/* only happens for strange sequence objects. Silently fail */
PyErr_Clear();
return 0;
}

#if defined(NPY_PY3K)
if (!PyUnicode_Check(obj)) {
#else
if (!PyString_Check(obj) && !PyUnicode_Check(obj)) {
#endif
Py_DECREF(obj);
return 0;
}
Py_DECREF(obj);
}
return 1;
}
return 0;
}

/*
* General function for indexing a NumPy array with a Python object.
Expand All @@ -1382,76 +1429,26 @@ array_subscript(PyArrayObject *self, PyObject *op)

PyArrayMapIterObject * mit = NULL;

/* Check for multiple field access */
if (PyDataType_HASFIELDS(PyArray_DESCR(self))) {
/* Check for single field access */
/*
* TODO: Moving this code block into the HASFIELDS, means that
* string integers temporarily work as indices.
*/
if (PyString_Check(op) || PyUnicode_Check(op)) {
PyObject *temp, *obj;

if (PyDataType_HASFIELDS(PyArray_DESCR(self))) {
obj = PyDict_GetItem(PyArray_DESCR(self)->fields, op);
if (obj != NULL) {
PyArray_Descr *descr;
int offset;
PyObject *title;

if (PyArg_ParseTuple(obj, "Oi|O", &descr, &offset, &title)) {
Py_INCREF(descr);
return PyArray_GetField(self, descr, offset);
}
}
}
/* return fields if op is a string index */
if (PyDataType_HASFIELDS(PyArray_DESCR(self)) &&
obj_is_string_or_stringlist(op)) {
PyObject *obj;
static PyObject *indexfunc = NULL;
npy_cache_pyfunc("numpy.core._internal", "_index_fields", &indexfunc);
if (indexfunc == NULL) {
return NULL;
}

temp = op;
if (PyUnicode_Check(op)) {
temp = PyUnicode_AsUnicodeEscapeString(op);
}
PyErr_Format(PyExc_ValueError,
"field named %s not found",
PyBytes_AsString(temp));
if (temp != op) {
Py_DECREF(temp);
}
obj = PyObject_CallFunction(indexfunc, "OO", self, op);
if (obj == NULL) {
return NULL;
}

else if (PySequence_Check(op) && !PyTuple_Check(op)) {
int seqlen, i;
PyObject *obj;
seqlen = PySequence_Size(op);
for (i = 0; i < seqlen; i++) {
obj = PySequence_GetItem(op, i);
if (!PyString_Check(obj) && !PyUnicode_Check(obj)) {
Py_DECREF(obj);
break;
}
Py_DECREF(obj);
}
/*
* Extract multiple fields if all elements in sequence
* are either string or unicode (i.e. no break occurred).
*/
fancy = ((seqlen > 0) && (i == seqlen));
if (fancy) {
PyObject *_numpy_internal;
_numpy_internal = PyImport_ImportModule("numpy.core._internal");
if (_numpy_internal == NULL) {
return NULL;
}
obj = PyObject_CallMethod(_numpy_internal,
"_index_fields", "OO", self, op);
Py_DECREF(_numpy_internal);
if (obj == NULL) {
return NULL;
}
PyArray_ENABLEFLAGS((PyArrayObject*)obj, NPY_ARRAY_WARN_ON_WRITE);
return obj;
}
/* warn if writing to a copy. copies will have no base */
if (PyArray_BASE((PyArrayObject*)obj) == NULL) {
PyArray_ENABLEFLAGS((PyArrayObject*)obj, NPY_ARRAY_WARN_ON_WRITE);
}
return obj;
}

/* Prepare the indices */
Expand Down Expand Up @@ -1783,35 +1780,39 @@ array_assign_subscript(PyArrayObject *self, PyObject *ind, PyObject *op)
return -1;
}

/* Single field access */
if (PyDataType_HASFIELDS(PyArray_DESCR(self))) {
if (PyString_Check(ind) || PyUnicode_Check(ind)) {
PyObject *obj;

obj = PyDict_GetItem(PyArray_DESCR(self)->fields, ind);
if (obj != NULL) {
PyArray_Descr *descr;
int offset;
PyObject *title;
/* field access */
if (PyDataType_HASFIELDS(PyArray_DESCR(self)) &&
obj_is_string_or_stringlist(ind)) {
PyObject *obj;
static PyObject *indexfunc = NULL;

if (PyArg_ParseTuple(obj, "Oi|O", &descr, &offset, &title)) {
Py_INCREF(descr);
return PyArray_SetField(self, descr, offset, op);
}
}
#if defined(NPY_PY3K)
PyErr_Format(PyExc_ValueError,
"field named %S not found",
ind);
if (!PyUnicode_Check(ind)) {
#else
PyErr_Format(PyExc_ValueError,
"field named %s not found",
PyString_AsString(ind));
if (!PyString_Check(ind) && !PyUnicode_Check(ind)) {
#endif
PyErr_SetString(PyExc_ValueError,
"multi-field assignment is not supported");
}

npy_cache_pyfunc("numpy.core._internal", "_index_fields", &indexfunc);
if (indexfunc == NULL) {
return -1;
}

obj = PyObject_CallFunction(indexfunc, "OO", self, ind);
if (obj == NULL) {
return -1;
}
}

if (PyArray_CopyObject((PyArrayObject*)obj, op) < 0) {
Py_DECREF(obj);
return -1;
}
Py_DECREF(obj);

return 0;
}

/* Prepare the indices */
index_type = prepare_index(self, ind, indices, &index_num,
Expand Down
4 changes: 4 additions & 0 deletions numpy/core/tests/test_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,6 +409,10 @@ def __getitem__(self, item):
arr = np.arange(10)
assert_array_equal(arr[SequenceLike()], arr[SequenceLike(),])

# also test that field indexing does not segfault
# for a similar reason, by indexing a structured array
arr = np.zeros((1,), dtype=[('f1', 'i8'), ('f2', 'i8')])
assert_array_equal(arr[SequenceLike()], arr[SequenceLike(),])

class TestFieldIndexing(TestCase):
def test_scalar_return_type(self):
Expand Down
Loading