ENH: simplify field indexing of structured arrays · numpy/numpy@3c1a13d · GitHub
[go: up one dir, main page]

Skip to content
/ numpy Public

Commit 3c1a13d

Browse files
committed
ENH: simplify field indexing of structured arrays
This commit simplifies the code in array_subscript and array_assign_subscript related to field access. This fixes #4806, and also removes a potential segfault, e.g. if the array is indexed using a sequence-like object that raises an exception in getitem. Also fixes #5631, related to creation of structured dtypes with no fields (an unusual and probably useless edge case). Also moves all imports in _internal.py to the top. Fixes #4806. Fixes #5631.
1 parent 8c86a0a commit 3c1a13d

File tree

6 files changed

+165
-126
lines changed

6 files changed

+165
-126
lines changed

numpy/core/_internal.py

Lines changed: 49 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -10,15 +10,17 @@
1010
import sys
1111
import warnings
1212

13-
from numpy.compat import asbytes, bytes
13+
from numpy.compat import asbytes, bytes, basestring
14+
from .multiarray import dtype, array, ndarray
15+
import ctypes
16+
from .numerictypes import object_
1417

1518
if (sys.byteorder == 'little'):
1619
_nbo = asbytes('<')
1720
else:
1821
_nbo = asbytes('>')
1922

2023
def _makenames_list(adict, align):
21-
from .multiarray import dtype
2224
allfields = []
2325
fnames = list(adict.keys())
2426
for fname in fnames:
@@ -52,7 +54,6 @@ def _makenames_list(adict, align):
5254
# a dictionary without "names" and "formats"
5355
# fields is used as a data-type descriptor.
5456
def _usefields(adict, align):
55-
from .multiarray import dtype
5657
try:
5758
names = adict[-1]
5859
except KeyError:
@@ -130,7 +131,6 @@ def _array_descr(descriptor):
130131
# so don't remove the name here, or you'll
131132
# break backward compatibility.
132133
def _reconstruct(subtype, shape, dtype):
133-
from .multiarray import ndarray
134134
return ndarray.__new__(subtype, shape, dtype)
135135

136136

@@ -194,12 +194,10 @@ def _commastring(astr):
194194
return result
195195

196196
def _getintp_ctype():
197-
from .multiarray import dtype
198197
val = _getintp_ctype.cache
199198
if val is not None:
200199
return val
201200
char = dtype('p').char
202-
import ctypes
203201
if (char == 'i'):
204202
val = ctypes.c_int
205203
elif char == 'l':
@@ -224,7 +222,6 @@ def c_void_p(self, num):
224222
class _ctypes(object):
225223
def __init__(self, array, ptr=None):
226224
try:
227-
import ctypes
228225
self._ctypes = ctypes
229226
except ImportError:
230227
self._ctypes = _missing_ctypes()
@@ -287,23 +284,55 @@ def _newnames(datatype, order):
287284
return tuple(list(order) + nameslist)
288285
raise ValueError("unsupported order value: %s" % (order,))
289286

290-
# Given an array with fields and a sequence of field names
291-
# construct a new array with just those fields copied over
292-
def _index_fields(ary, fields):
293-
from .multiarray import empty, dtype, array
287+
def _index_fields(ary, names):
288+
""" Given a structured array and a sequence of field names
289+
construct new array with just those fields.
290+
291+
Parameters
292+
----------
293+
ary : ndarray
294+
Structured array being subscripted
295+
names : string or list of strings
296+
Either a single field name, or a list of field names
297+
298+
Returns
299+
-------
300+
sub_ary : ndarray
301+
If `names` is a single field name, the return value is identical to
302+
ary.getfield, a writeable view into `ary`. If `names` is a list of
303+
field names the return value is a copy of `ary` containing only those
304+
fields. This is planned to return a view in the future.
305+
306+
Raises
307+
------
308+
ValueError
309+
If `ary` does not contain a field given in `names`.
310+
311+
"""
294312
dt = ary.dtype
295313

296-
names = [name for name in fields if name in dt.names]
297-
formats = [dt.fields[name][0] for name in fields if name in dt.names]
298-
offsets = [dt.fields[name][1] for name in fields if name in dt.names]
314+
#use getfield to index a single field
315+
if isinstance(names, basestring):
316+
try:
317+
return ary.getfield(dt.fields[names][0], dt.fields[names][1])
318+
except KeyError:
319+
raise ValueError("no field of name %s" % names)
320+
321+
for name in names:
322+
if name not in dt.fields:
323+
raise ValueError("no field of name %s" % name)
299324

300-
view_dtype = {'names':names, 'formats':formats, 'offsets':offsets, 'itemsize':dt.itemsize}
301-
view = ary.view(dtype=view_dtype)
325+
formats = [dt.fields[name][0] for name in names]
326+
offsets = [dt.fields[name][1] for name in names]
327+
328+
view_dtype = {'names': names, 'formats': formats,
329+
'offsets': offsets, 'itemsize': dt.itemsize}
330+
331+
# return copy for now (future plan to return ary.view(dtype=view_dtype))
332+
copy_dtype = {'names': view_dtype['names'],
333+
'formats': view_dtype['formats']}
334+
return array(ary.view(dtype=view_dtype), dtype=copy_dtype, copy=True)
302335

303-
# Return a copy for now until behavior is fully deprecated
304-
# in favor of returning view
305-
copy_dtype = {'names':view_dtype['names'], 'formats':view_dtype['formats']}
306-
return array(view, dtype=copy_dtype, copy=True)
307336

308337
def _get_all_field_offsets(dtype, base_offset=0):
309338
""" Returns the types and offsets of all fields in a (possibly structured)
@@ -363,8 +392,6 @@ def _check_field_overlap(new_fields, old_fields):
363392
If the new fields are incompatible with the old fields
364393
365394
"""
366-
from .numerictypes import object_
367-
from .multiarray import dtype
368395

369396
#first go byte by byte and check we do not access bytes not in old_fields
370397
new_bytes = set()
@@ -527,8 +554,6 @@ def _view_is_safe(oldtype, newtype):
527554
_pep3118_standard_typechars = ''.join(_pep3118_standard_map.keys())
528555

529556
def _dtype_from_pep3118(spec, byteorder='@', is_subdtype=False):
530-
from numpy.core.multiarray import dtype
531-
532557
fields = {}
533558
offset = 0
534559
explicit_name = False
@@ -694,8 +719,6 @@ def get_dummy_name():
694719

695720
def _add_trailing_padding(value, padding):
696721
"""Inject the specified number of padding bytes at the end of a dtype"""
697-
from numpy.core.multiarray import dtype
698-
699722
if value.fields is None:
700723
vfields = {'f0': (value, 0)}
701724
else:

numpy/core/src/multiarray/arraytypes.c.src

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -726,7 +726,7 @@ VOID_setitem(PyObject *op, char *ip, PyArrayObject *ap)
726726
PyObject *tup;
727727
int savedflags;
728728

729-
res = -1;
729+
res = 0;
730730
/* get the names from the fields dictionary*/
731731
names = descr->names;
732732
n = PyTuple_GET_SIZE(names);

numpy/core/src/multiarray/mapping.c

Lines changed: 88 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,7 @@ prepare_index(PyArrayObject *self, PyObject *index,
215215
}
216216
for (i = 0; i < n; i++) {
217217
PyObject *tmp_obj = PySequence_GetItem(index, i);
218+
/* if getitem fails (unusual) treat this as a single index */
218219
if (tmp_obj == NULL) {
219220
PyErr_Clear();
220221
make_tuple = 0;
@@ -1361,6 +1362,52 @@ array_subscript_asarray(PyArrayObject *self, PyObject *op)
13611362
return PyArray_EnsureAnyArray(array_subscript(self, op));
13621363
}
13631364

1365+
NPY_NO_EXPORT int
1366+
obj_is_string_or_stringlist(PyObject *op)
1367+
{
1368+
#if defined(NPY_PY3K)
1369+
if (PyUnicode_Check(op)) {
1370+
#else
1371+
if (PyString_Check(op) || PyUnicode_Check(op)) {
1372+
#endif
1373+
return 1;
1374+
}
1375+
else if (PySequence_Check(op) && !PyTuple_Check(op)) {
1376+
int seqlen, i;
1377+
PyObject *obj = NULL;
1378+
seqlen = PySequence_Size(op);
1379+
1380+
/* quit if we come across a 0-d array (seqlen==-1) or a 0-len array */
1381+
if (seqlen == -1) {
1382+
PyErr_Clear();
1383+
return 0;
1384+
}
1385+
if (seqlen == 0) {
1386+
return 0;
1387+
}
1388+
1389+
for (i = 0; i < seqlen; i++) {
1390+
obj = PySequence_GetItem(op, i);
1391+
if (obj == NULL) {
1392+
/* only happens for strange sequence objects. Silently fail */
1393+
PyErr_Clear();
1394+
return 0;
1395+
}
1396+
1397+
#if defined(NPY_PY3K)
1398+
if (!PyUnicode_Check(obj)) {
1399+
#else
1400+
if (!PyString_Check(obj) && !PyUnicode_Check(obj)) {
1401+
#endif
1402+
Py_DECREF(obj);
1403+
return 0;
1404+
}
1405+
Py_DECREF(obj);
1406+
}
1407+
return 1;
1408+
}
1409+
return 0;
1410+
}
13641411

13651412
/*
13661413
* General function for indexing a NumPy array with a Python object.
@@ -1382,76 +1429,26 @@ array_subscript(PyArrayObject *self, PyObject *op)
13821429

13831430
PyArrayMapIterObject * mit = NULL;
13841431

1385-
/* Check for multiple field access */
1386-
if (PyDataType_HASFIELDS(PyArray_DESCR(self))) {
1387-
/* Check for single field access */
1388-
/*
1389-
* TODO: Moving this code block into the HASFIELDS, means that
1390-
* string integers temporarily work as indices.
1391-
*/
1392-
if (PyString_Check(op) || PyUnicode_Check(op)) {
1393-
PyObject *temp, *obj;
1394-
1395-
if (PyDataType_HASFIELDS(PyArray_DESCR(self))) {
1396-
obj = PyDict_GetItem(PyArray_DESCR(self)->fields, op);
1397-
if (obj != NULL) {
1398-
PyArray_Descr *descr;
1399-
int offset;
1400-
PyObject *title;
1401-
1402-
if (PyArg_ParseTuple(obj, "Oi|O", &descr, &offset, &title)) {
1403-
Py_INCREF(descr);
1404-
return PyArray_GetField(self, descr, offset);
1405-
}
1406-
}
1407-
}
1432+
/* return fields if op is a string index */
1433+
if (PyDataType_HASFIELDS(PyArray_DESCR(self)) &&
1434+
obj_is_string_or_stringlist(op)) {
1435+
PyObject *obj;
1436+
static PyObject *indexfunc = NULL;
1437+
npy_cache_pyfunc("numpy.core._internal", "_index_fields", &indexfunc);
1438+
if (indexfunc == NULL) {
1439+
return NULL;
1440+
}
14081441

1409-
temp = op;
1410-
if (PyUnicode_Check(op)) {
1411-
temp = PyUnicode_AsUnicodeEscapeString(op);
1412-
}
1413-
PyErr_Format(PyExc_ValueError,
1414-
"field named %s not found",
1415-
PyBytes_AsString(temp));
1416-
if (temp != op) {
1417-
Py_DECREF(temp);
1418-
}
1442+
obj = PyObject_CallFunction(indexfunc, "OO", self, op);
1443+
if (obj == NULL) {
14191444
return NULL;
14201445
}
14211446

1422-
else if (PySequence_Check(op) && !PyTuple_Check(op)) {
1423-
int seqlen, i;
1424-
PyObject *obj;
1425-
seqlen = PySequence_Size(op);
1426-
for (i = 0; i < seqlen; i++) {
1427-
obj = PySequence_GetItem(op, i);
1428-
if (!PyString_Check(obj) && !PyUnicode_Check(obj)) {
1429-
Py_DECREF(obj);
1430-
break;
1431-
}
1432-
Py_DECREF(obj);
1433-
}
1434-
/*
1435-
* Extract multiple fields if all elements in sequence
1436-
* are either string or unicode (i.e. no break occurred).
1437-
*/
1438-
fancy = ((seqlen > 0) && (i == seqlen));
1439-
if (fancy) {
1440-
PyObject *_numpy_internal;
1441-
_numpy_internal = PyImport_ImportModule("numpy.core._internal");
1442-
if (_numpy_internal == NULL) {
1443-
return NULL;
1444-
}
1445-
obj = PyObject_CallMethod(_numpy_internal,
1446-
"_index_fields", "OO", self, op);
1447-
Py_DECREF(_numpy_internal);
1448-
if (obj == NULL) {
1449-
return NULL;
1450-
}
1451-
PyArray_ENABLEFLAGS((PyArrayObject*)obj, NPY_ARRAY_WARN_ON_WRITE);
1452-
return obj;
1453-
}
1447+
/* warn if writing to a copy. copies will have no base */
1448+
if (PyArray_BASE((PyArrayObject*)obj) == NULL) {
1449+
PyArray_ENABLEFLAGS((PyArrayObject*)obj, NPY_ARRAY_WARN_ON_WRITE);
14541450
}
1451+
return obj;
14551452
}
14561453

14571454
/* Prepare the indices */
@@ -1783,35 +1780,39 @@ array_assign_subscript(PyArrayObject *self, PyObject *ind, PyObject *op)
17831780
return -1;
17841781
}
17851782

1786-
/* Single field access */
1787-
if (PyDataType_HASFIELDS(PyArray_DESCR(self))) {
1788-
if (PyString_Check(ind) || PyUnicode_Check(ind)) {
1789-
PyObject *obj;
1790-
1791-
obj = PyDict_GetItem(PyArray_DESCR(self)->fields, ind);
1792-
if (obj != NULL) {
1793-
PyArray_Descr *descr;
1794-
int offset;
1795-
PyObject *title;
1783+
/* field access */
1784+
if (PyDataType_HASFIELDS(PyArray_DESCR(self)) &&
1785+
obj_is_string_or_stringlist(ind)) {
1786+
PyObject *obj;
1787+
static PyObject *indexfunc = NULL;
17961788

1797-
if (PyArg_ParseTuple(obj, "Oi|O", &descr, &offset, &title)) {
1798-
Py_INCREF(descr);
1799-
return PyArray_SetField(self, descr, offset, op);
1800-
}
1801-
}
18021789
#if defined(NPY_PY3K)
1803-
PyErr_Format(PyExc_ValueError,
1804-
"field named %S not found",
1805-
ind);
1790+
if (!PyUnicode_Check(ind)) {
18061791
#else
1807-
PyErr_Format(PyExc_ValueError,
1808-
"field named %s not found",
1809-
PyString_AsString(ind));
1792+
if (!PyString_Check(ind) && !PyUnicode_Check(ind)) {
18101793
#endif
1794+
PyErr_SetString(PyExc_ValueError,
1795+
"multi-field assignment is not supported");
1796+
}
1797+
1798+
npy_cache_pyfunc("numpy.core._internal", "_index_fields", &indexfunc);
1799+
if (indexfunc == NULL) {
1800+
return -1;
1801+
}
1802+
1803+
obj = PyObject_CallFunction(indexfunc, "OO", self, ind);
1804+
if (obj == NULL) {
18111805
return -1;
18121806
}
1813-
}
18141807

1808+
if (PyArray_CopyObject((PyArrayObject*)obj, op) < 0) {
1809+
Py_DECREF(obj);
1810+
return -1;
1811+
}
1812+
Py_DECREF(obj);
1813+
1814+
return 0;
1815+
}
18151816

18161817
/* Prepare the indices */
18171818
index_type = prepare_index(self, ind, indices, &index_num,

numpy/core/tests/test_indexing.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -409,6 +409,10 @@ def __getitem__(self, item):
409409
arr = np.arange(10)
410410
assert_array_equal(arr[SequenceLike()], arr[SequenceLike(),])
411411

412+
# also test that field indexing does not segfault
413+
# for a similar reason, by indexing a structured array
414+
arr = np.zeros((1,), dtype=[('f1', 'i8'), ('f2', 'i8')])
415+
assert_array_equal(arr[SequenceLike()], arr[SequenceLike(),])
412416

413417
class TestFieldIndexing(TestCase):
414418
def test_scalar_return_type(self):

0 commit comments

Comments
 (0)
0