8000 ENH: structured datatype safety checks · ahaldane/numpy@432ddcd · GitHub
[go: up one dir, main page]

Skip to content

Commit 432ddcd

Browse files
committed
ENH: structured datatype safety checks
Previously views of structured arrays containing objects were completely disabled. This commit adds more lenient check for whether an object-array view is allowed, and adds similar checks to getfield/setfield Fixes numpy#2346. Fixes numpy#3256. Fixes numpy#2599. Fixes numpy#3253. Fixes numpy#3286. Fixes numpy#5762.
1 parent 30d755d commit 432ddcd

File tree

7 files changed

+331
-37
lines changed

7 files changed

+331
-37
lines changed

numpy/core/_internal.py

Lines changed: 168 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -305,6 +305,174 @@ def _index_fields(ary, fields):
305305
copy_dtype = {'names':view_dtype['names'], 'formats':view_dtype['formats']}
306306
return array(view, dtype=copy_dtype, copy=True)
307307

308+
def _get_all_field_offsets(dtype, base_offset=0):
309+
""" Returns the types and offsets of all fields in a (possibly structured)
310+
data type, including nested fields and subarrays.
311+
312+
Parameters
313+
----------
314+
dtype : data-type
315+
Data type to extract fields from.
316+
base_offset : int, optional
317+
Additional offset to add to all field offsets.
318+
319+
Returns
320+
-------
321+
fields : list of (data-type, int) pairs
322+
A flat list of (dtype, byte offset) pairs.
323+
324+
"""
325+
fields = []
326+
if dtype.fields is not None:
327+
for name in dtype.names:
328+
sub_dtype = dtype.fields[name][0]
329+
sub_offset = dtype.fields[name][1] + base_offset
330+
fields.extend(_get_all_field_offsets(sub_dtype, sub_offset))
331+
else:
332+
if dtype.shape:
333+
sub_offsets = _get_all_field_offsets(dtype.base, base_offset)
334+
count = 1
335+
for dim in dtype.shape:
336+
count *= dim
337+
fields.extend((typ, off + dtype.base.itemsize*j)
338+
for j in range(count) for (typ, off) in sub_offsets)
339+
else:
340+
fields.append((dtype, base_offset))
341+
return fields
342+
343+
def _check_field_overlap(new_fields, old_fields):
344+
""" Perform object memory overlap tests for two data-types (see
345+
_view_is_safe).
346+
347+
This function checks that new fields only access memory contained in old
348+
fields, and that non-object fields are not interpreted as objects and vice
349+
versa.
350+
351+
Parameters
352+
----------
353+
new_fields : list of (data-type, int) pairs
354+
Flat list of (dtype, byte offset) pairs for the new data type, as
355+
returned by _get_all_field_offsets.
356+
old_fields: list of (data-type, int) pairs
357+
Flat list of (dtype, byte offset) pairs for the old data type, as
358+
returned by _get_all_field_offsets.
359+
360+
Raises
361+
------
362+
TypeError
363+
If the new fields are incompatible with the old fields
364+
365+
"""
366+
from .numerictypes import object_
367+
from .multiarray import dtype
368+
369+
#first go byte by byte and check we do not access bytes not in old_fields
370+
new_bytes = set()
371+
for tp, off in new_fields:
372+
new_bytes.update(set(range(off, off+tp.itemsize)))
373+
old_bytes = set()
374+
for tp, off in old_fields:
375+
old_bytes.update(set(range(off, off+tp.itemsize)))
376+
if new_bytes.difference(old_bytes):
377+
raise TypeError("view would access data parent array doesn't own")
378+
379+
#next check that we do not interpret non-Objects as Objects, and vv
380+
obj_offsets = [off for (tp, off) in old_fields if tp.type is object_]
381+
obj_size = dtype(object_).itemsize
382+
383+
for fld_dtype, fld_offset in new_fields:
384+
if fld_dtype.type is object_:
385+
# check we do not create object views where
386+
# there are no objects.
387+
if fld_offset not in obj_offsets:
388+
raise TypeError("cannot view non-Object data as Object type")
389+
else:
390+
# next check we do not create non-object views
391+
# where there are already objects.
392+
# see validate_object_field_overlap for a similar computation.
393+
for obj_offset in obj_offsets:
394+
if (fld_offset < obj_offset + obj_size and
395+
obj_offset < fld_offset + fld_dtype.itemsize):
396+
raise TypeError("cannot view Object as non-Object type")
397+
398+
def _getfield_is_safe(oldtype, newtype, offset):
399+
""" Checks safety of getfield for object arrays.
400+
401+
As in _view_is_safe, we need to check that memory containing objects is not
402+
reinterpreted as a non-object datatype and vice versa.
403+
404+
Parameters
405+
----------
406+
oldtype : data-type
407+
Data type of the original ndarray.
408+
newtype : data-type
409+
Data type of the field being accessed by ndarray.getfield
410+
offset : int
411+
Offset of the field being accessed by ndarray.getfield
412+
413+
Raises
414+
------
415+
TypeError
416+
If the field access is invalid
417+
418+
"""
419+
new_fields = _get_all_field_offsets(newtype, offset)
420+
old_fields = _get_all_field_offsets(oldtype)
421+
# raises if there is a problem
422+
_check_field_overlap(new_fields, old_fields)
423+
424+
def _view_is_safe(oldtype, newtype):
425+
""" Checks safety of a view involving object arrays, for example when
426+
doing::
427+
428+
np.zeros(10, dtype=oldtype).view(newtype)
429+
430+
We need to check that
431+
1) No memory that is not an object will be interpreted as a object,
432+
2) No memory containing an object will be interpreted as an arbitrary type.
433+
Both cases can cause segfaults, eg in the case the view is written to.
434+
Strategy here is to also disallow views where newtype has any field in a
435+
place oldtype doesn't.
436+
437+
Parameters
438+
----------
439+
oldtype : data-type
440+
Data type of original ndarray
441+
newtype : data-type
442+
Data type of the view
443+
444+
Raises
445+
------
446+
TypeError
447+
If the new type is incompatible with the old type.
448+
449+
"""
450+
new_fields = _get_all_field_offsets(newtype)
451+
new_size = newtype.itemsize
452+
453+
old_fields = _get_all_field_offsets(oldtype)
454+
old_size = oldtype.itemsize
455+
456+
# if the itemsizes are not equal, we need to check that all the
457+
# 'tiled positions' of the object match up. Here, we allow
458+
# for arbirary itemsizes (even those possibly disallowed
459+
# due to stride/data length issues).
460+
if old_size == new_size:
461+
new_num = old_num = 1
462+
else:
463+
gcd_new_old = _gcd(new_size, old_size)
464+
new_num = old_size // gcd_new_old
465+
old_num = new_size // gcd_new_old
466+
467+
# get position of fields within the tiling
468+
new_fieldtile = [(tp, off + new_size*j)
469+
for j in range(new_num) for (tp, off) in new_fields]
470+
old_fieldtile = [(tp, off + old_size*j)
471+
for j in range(old_num) for (tp, off) in old_fields]
472+
473+
# raises if there is a problem
474+
_check_field_overlap(new_fieldtile, old_fieldtile)
475+
308476
# Given a string containing a PEP 3118 format specifier,
309477
# construct a Numpy dtype
310478

numpy/core/src/multiarray/getset.c

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -434,6 +434,13 @@ array_descr_set(PyArrayObject *self, PyObject *arg)
434434
npy_intp newdim;
435435
int i;
436436
char *msg = "new type not compatible with array.";
437+
PyObject *safe;
438+
static PyObject *checkfunc = NULL;
439+
440+
npy_cache_pyfunc("numpy.core._internal", "_view_is_safe", &checkfunc);
441+
if (checkfunc == NULL) {
442+
return NULL;
443+
}
437444

438445
if (arg == NULL) {
439446
PyErr_SetString(PyExc_AttributeError,
@@ -448,15 +455,14 @@ array_descr_set(PyArrayObject *self, PyObject *arg)
448455
return -1;
449456
}
450457

451-
if (PyDataType_FLAGCHK(newtype, NPY_ITEM_HASOBJECT) ||
452-
PyDataType_FLAGCHK(newtype, NPY_ITEM_IS_POINTER) ||
453-
PyDataType_FLAGCHK(PyArray_DESCR(self), NPY_ITEM_HASOBJECT) ||
454-
PyDataType_FLAGCHK(PyArray_DESCR(self), NPY_ITEM_IS_POINTER)) {
455-
PyErr_SetString(PyExc_TypeError,
456-
"Cannot change data-type for object array.");
458+
/* check that we are not reinterpreting memory containing Objects */
459+
safe = PyObject_CallFunction(checkfunc, "OO", PyArray_DESCR(self), newtype);
460+
if (safe == NULL) {
457461
Py_DECREF(newtype);
458462
return -1;
459463
}
464+
Py_DECREF(safe);
465+
460466

461467
if (newtype->elsize == 0) {
462468
/* Allow a void view */

numpy/core/src/multiarray/methods.c

Lines changed: 16 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -358,15 +358,23 @@ NPY_NO_EXPORT PyObject *
358358
PyArray_GetField(PyArrayObject *self, PyArray_Descr *typed, int offset)
359359
{
360360
PyObject *ret = NULL;
361+
PyObject *safe;
362+
static PyObject *checkfunc = NULL;
361363

362-
if (offset < 0 || (offset + typed->elsize) > PyArray_DESCR(self)->elsize) {
363-
PyErr_Format(PyExc_ValueError,
364-
"Need 0 <= offset <= %d for requested type "
365-
"but received offset = %d",
366-
PyArray_DESCR(self)->elsize-typed->elsize, offset);
367-
Py_DECREF(typed);
364+
npy_cache_pyfunc("numpy.core._internal", "_getfield_is_safe", &checkfunc);
365+
if (checkfunc == NULL) {
368366
return NULL;
369367
}
368+
369+
/* check that we are not reinterpreting memory containing Objects */
370+
/* only returns True or raises */
371+
safe = PyObject_CallFunction(checkfunc, "OOi", PyArray_DESCR(self),
372+
typed, offset);
373+
if (safe == NULL) {
374+
return NULL;
375+
}
376+
Py_DECREF(safe);
377+
370378
ret = PyArray_NewFromDescr(Py_TYPE(self),
371379
typed,
372380
PyArray_NDIM(self), PyArray_DIMS(self),
@@ -417,23 +425,12 @@ PyArray_SetField(PyArrayObject *self, PyArray_Descr *dtype,
417425
PyObject *ret = NULL;
418426
int retval = 0;
419427

420-
if (offset < 0 || (offset + dtype->elsize) > PyArray_DESCR(self)->elsize) {
421-
PyErr_Format(PyExc_ValueError,
422-
"Need 0 <= offset <= %d for requested type "
423-
"but received offset = %d",
424-
PyArray_DESCR(self)->elsize-dtype->elsize, offset);
425-
Py_DECREF(dtype);
426-
return -1;
427-
}
428-
ret = PyArray_NewFromDescr(Py_TYPE(self),
429-
dtype, PyArray_NDIM(self), PyArray_DIMS(self),
430-
PyArray_STRIDES(self), PyArray_BYTES(self) + offset,
431-
PyArray_FLAGS(self), (PyObject *)self);
428+
/* getfield returns a view we can write to */
429+
ret = PyArray_GetField(self, dtype, offset);
432430
if (ret == NULL) {
433431
return -1;
434432
}
435433

436-
PyArray_UpdateFlags((PyArrayObject *)ret, NPY_ARRAY_UPDATE_ALL);
437434
retval = PyArray_CopyObject((PyArrayObject *)ret, val);
438435
Py_DECREF(ret);
439436
return retval;
@@ -455,13 +452,6 @@ array_setfield(PyArrayObject *self, PyObject *args, PyObject *kwds)
455452
return NULL;
456453
}
457454

458-
if (PyDataType_REFCHK(PyArray_DESCR(self))) {
459-
PyErr_SetString(PyExc_RuntimeError,
460-
"cannot call setfield on an object array");
461-
Py_DECREF(dtype);
462-
return NULL;
463-
}
464-
465455
if (PyArray_SetField(self, dtype, offset, value) < 0) {
466456
return NULL;
467457
}

numpy/core/tests/test_multiarray.py

Lines changed: 94 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -809,6 +809,14 @@ def test_casting(self):
809809
t = [('a', '>i4'), ('b', '<f8'), ('c', 'i4')]
810810
assert_(not np.can_cast(a.dtype, t, casting=casting))
811811

812+
def test_objview(self):
813+
# https://github.com/numpy/numpy/issues/3286
814+
a = np.array([], dtype=[('a', 'f'), ('b', 'f'), ('c', 'O')])
815+
a[['a', 'b']] # TypeError?
816+
817+
# https://github.com/numpy/numpy/issues/3253
818+
dat2 = np.zeros(3, [('A', 'i'), ('B', '|O')])
819+
new2 = dat2[['B', 'A']] # TypeError?
812820

813821
class TestBool(TestCase):
814822
def test_test_interning(self):
@ F438 @ -3577,8 +3585,9 @@ def test_record_no_hash(self):
35773585

35783586
class TestView(TestCase):
35793587
def test_basic(self):
3580-
x = np.array([(1, 2, 3, 4), (5, 6, 7, 8)], dtype=[('r', np.int8), ('g', np.int8),
3581-
('b', np.int8), ('a', np.int8)])
3588+
x = np.array([(1, 2, 3, 4), (5, 6, 7, 8)],
3589+
dtype=[('r', np.int8), ('g', np.int8),
3590+
('b', np.int8), ('a', np.int8)])
35823591
# We must be specific about the endianness here:
35833592
y = x.view(dtype='<i4')
35843593
# ... and again without the keyword.
@@ -5523,6 +5532,89 @@ def test_collections_hashable(self):
55235532
x = np.array([])
55245533
self.assertFalse(isinstance(x, collections.Hashable))
55255534

5535+
from numpy.core._internal import _view_is_safe
5536+
5537+
class TestObjViewSafetyFuncs:
5538+
def test_view_safety(self):
5539+
psize = dtype('p').itemsize
5540+
5541+
# creates dtype but with extra character code - for missing 'p' fields
5542+
def mtype(s):
5543+
n, offset, fields = 0, 0, []
5544+
for c in s.split(','): #subarrays won't work
5545+
if c != '-':
5546+
fields.append(('f{0}'.format(n), c, offset))
5547+
n += 1
5548+
offset += dtype(c).itemsize if c != '-' else psize
5549+
5550+
names, formats, offsets = zip(*fields)
5551+
return dtype({'names': names, 'formats': formats,
5552+
'offsets': offsets, 'itemsize': offset})
5553+
5554+
# test nonequal itemsizes with objects:
5555+
# these should succeed:
5556+
_view_is_safe(dtype('O,p,O,p'), dtype('O,p,O,p,O,p'))
5557+
_view_is_safe(dtype('O,O'), dtype('O,O,O'))
5558+
# these should fail:
5559+
assert_raises(TypeError, _view_is_safe, dtype('O,O,p'), dtype('O,O'))
5560+
assert_raises(TypeError, _view_is_safe, dtype('O,O,p'), dtype('O,p'))
5561+
assert_raises(TypeError, _view_is_safe, dtype('O,O,p'), dtype('p,O'))
5562+
5563+
# test nonequal itemsizes with missing fields:
5564+
# these should succeed:
5565+
_view_is_safe(mtype('-,p,-,p'), mtype('-,p,-,p,-,p'))
5566+
_view_is_safe(dtype('p,p'), dtype('p,p,p'))
5567+
# these should fail:
5568+
assert_raises(TypeError, _view_is_safe, mtype('p,p,-'), mtype('p,p'))
5569+
assert_raises(TypeError, _view_is_safe, mtype('p,p,-'), mtype('p,-'))
5570+
assert_raises(TypeError, _view_is_safe, mtype('p,p,-'), mtype('-,p'))
5571+
5572+
# scans through positions at which we can view a type
5573+
def scanView(d1, otype):
5574+
goodpos = []
5575+
for shift in range(d1.itemsize - dtype(otype).itemsize+1):
5576+
d2 = dtype({'names': ['f0'], 'formats': [otype],
5577+
'offsets': [shift], 'itemsize': d1.itemsize})
5578+
try:
5579+
_view_is_safe(d1, d2)
5580+
except TypeError:
5581+
pass
5582+
else:
5583+
goodpos.append(shift)
5584+
return goodpos
5585+
5586+
# test partial overlap with object field
5587+
assert_equal(scanView(dtype('p,O,p,p,O,O'), 'p'),
5588+
[0] + list(range(2*psize, 3*psize+1)))
5589+
assert_equal(scanView(dtype('p,O,p,p,O,O'), 'O'),
5590+
[psize, 4*psize, 5*psize])
5591+
5592+
# test partial overlap with missing field
5593+
assert_equal(scanView(mtype('p,-,p,p,-,-'), 'p'),
5594+
[0] + list(range(2*psize, 3*psize+1)))
5595+
5596+
# test nested structures with objects:
5597+
nestedO = dtype([('f0', 'p'), ('f1', 'p,O,p')])
5598+
assert_equal(scanView(nestedO, 'p'), list(range(psize+1)) + [3*psize])
5599+
assert_equal(scanView(nestedO, 'O'), [2*psize])
5600+
5601+
# test nested structures with missing fields:
5602+
nestedM = dtype([('f0', 'p'), ('f1', mtype('p,-,p'))])
5603+
assert_equal(scanView(nestedM, 'p'), list(range(psize+1)) + [3*psize])
5604+
5605+
# test subarrays with objects
5606+
subarrayO = dtype('p,(2,3)O,p')
5607+
assert_equal(scanView(subarrayO, 'p'), [0, 7*psize])
5608+
assert_equal(scanView(subarrayO, 'O'),
5609+
list(range(psize, 6*psize+1, psize)))
5610+
5611+
#test dtype with overlapping fields
5612+
overlapped = dtype({'names': ['f0', 'f1', 'f2', 'f3'],
5613+
'formats': ['p', 'p', 'p', 'p'],
5614+
'offsets': [0, 1, 3*psize-1, 3*psize],
5615+
'itemsize': 4*psize})
5616+
assert_equal(scanView(overlapped, 'p'), [0, 1, 3*psize-1, 3*psize])
5617+
55265618

55275619
if __name__ == "__main__":
55285620
run_module_suite()

0 commit comments

Comments
 (0)
0