8000 ENH: structured datatype safety checks · ahaldane/numpy@b7fece9 · GitHub
[go: up one dir, main page]

Skip to content

Commit b7fece9

Browse files
committed
ENH: structured datatype safety checks
Previously, views of structured arrays containing objects were completely disabled. This commit adds a more lenient check for whether an object-array view is allowed, and adds similar checks to getfield/setfield. Fixes numpy#2346. Fixes numpy#3256. Fixes numpy#2599. Fixes numpy#3253. Fixes numpy#3286. Fixes numpy#5762.
1 parent 36b9404 commit b7fece9

File tree

8 files changed

+273
-37
lines changed

8 files changed

+273
-37
lines changed

numpy/core/_internal.py

Lines changed: 109 additions & 0 deletions
10000
Original file line numberDiff line numberDiff line change
@@ -305,6 +305,115 @@ def _index_fields(ary, fields):
305305
copy_dtype = {'names':view_dtype['names'], 'formats':view_dtype['formats']}
306306
return array(view, dtype=copy_dtype, copy=True)
307307

308+
def _get_all_field_offsets(dtype, base_offset=0):
    """ Flatten a dtype into its leaf fields.

    Walks structured fields (recursively, so nested structures are
    expanded) and subarrays (one entry per subarray element) and returns
    a flat list of ``(dtype, offset)`` pairs.  Every offset is shifted by
    `base_offset`.
    """
    # Structured dtype: expand each named member in declaration order.
    if dtype.fields is not None:
        collected = []
        for name in dtype.names:
            member = dtype.fields[name]
            collected += _get_all_field_offsets(member[0],
                                                member[1] + base_offset)
        return collected

    # Plain scalar dtype: it is its own single leaf.
    if not dtype.shape:
        return [(dtype, base_offset)]

    # Subarray dtype: flatten one element, then repeat it for every
    # position, stepping by the element's itemsize.
    element = dtype.base
    per_element = _get_all_field_offsets(element, base_offset)
    n_elements = 1
    for extent in dtype.shape:
        n_elements *= extent
    stride = element.itemsize
    return [(typ, off + stride*idx)
            for idx in range(n_elements)
            for (typ, off) in per_element]
330+
331+
def _check_field_overlap(new_fields, old_fields):
    """ Run the object memory-overlap tests (see _view_is_safe).

    `new_fields` and `old_fields` are lists of ``(dtype, offset)`` pairs.
    Raises TypeError if the new fields touch a byte that no old field
    covers, if a new object field does not start exactly where an old
    object field starts, or if a new non-object field overlaps an old
    object field.  Returns None when no problem is found.
    """
    from .numerictypes import object_
    from .multiarray import dtype

    def _covered_bytes(fields):
        # Set of every byte position spanned by the given fields.
        covered = set()
        for tp, off in fields:
            covered.update(range(off, off + tp.itemsize))
        return covered

    # First, byte by byte: the new layout may only touch bytes that the
    # old layout also touches.
    touched_by_new = _covered_bytes(new_fields)
    touched_by_old = _covered_bytes(old_fields)
    if not touched_by_new <= touched_by_old:
        raise TypeError("view would access data parent array doesn't own")

    # Second: never interpret non-Objects as Objects, nor the reverse.
    old_obj_starts = [off for (tp, off) in old_fields if tp.type is object_]
    ptr_width = dtype(object_).itemsize

    for new_tp, new_off in new_fields:
        if new_tp.type is object_:
            # An object may only be viewed exactly where one already lives.
            if new_off not in old_obj_starts:
                raise TypeError("cannot view non-Object data as Object type")
            continue

        # A non-object field must not intersect any existing object.
        # see validate_object_field_overlap for a similar computation.
        new_end = new_off + new_tp.itemsize
        for obj_start in old_obj_starts:
            if new_off < obj_start + ptr_width and obj_start < new_end:
                raise TypeError("cannot view Object as non-Object type")
369+
370+
def _getfield_is_safe(oldtype, newtype, offset):
    """ Check that getfield is safe on arrays that may hold objects.

    As in _view_is_safe, memory containing objects must not be
    reinterpreted as a non-object datatype, nor the other way round.
    `newtype` is laid out at `offset` inside `oldtype` and the two layouts
    are compared.  Returns True; any problem is reported by the exception
    raised from _check_field_overlap.
    """
    requested = _get_all_field_offsets(newtype, offset)
    existing = _get_all_field_offsets(oldtype)
    # No return value to inspect: the check raises on failure.
    _check_field_overlap(requested, existing)
    return True
380+
381+
def _view_is_safe(oldtype, newtype):
    """ Check that viewing `oldtype` data as `newtype` is safe w.r.t. objects.

    Two things must hold:
      1) no memory that is not an object is interpreted as an object,
      2) no memory containing an object is interpreted as an arbitrary type.
    Both cases can cause segfaults, e.g. when the view is written to.
    The strategy used here additionally disallows views in which
    `newtype` has any field in a place where `oldtype` does not.

    `oldtype` and `newtype` are the input dtypes.  Returns True; problems
    are reported by the exception raised from _check_field_overlap.
    """
    new_fields = _get_all_field_offsets(newtype)
    new_size = newtype.itemsize

    old_fields = _get_all_field_offsets(oldtype)
    old_size = oldtype.itemsize

    # When the itemsizes differ, all 'tiled positions' of the objects have
    # to match up, so each layout is repeated until both span the same
    # number of bytes.  Arbitrary itemsizes are allowed here (even those
    # possibly disallowed due to stride/data length issues).
    new_num = old_num = 1
    if old_size != new_size:
        common = _gcd(new_size, old_size)
        new_num = old_size // common
        old_num = new_size // common

    def _tile(fields, size, repeats):
        # Position of every field within `repeats` back-to-back copies.
        tiled = []
        for rep in range(repeats):
            for tp, off in fields:
                tiled.append((tp, off + size*rep))
        return tiled

    new_fieldtile = _tile(new_fields, new_size, new_num)
    old_fieldtile = _tile(old_fields, old_size, old_num)

    # No return value to inspect: the check raises on failure.
    _check_field_overlap(new_fieldtile, old_fieldtile)
    return True
416+
308417
# Given a string containing a PEP 3118 format specifier,
309418
# construct a Numpy dtype
310419

numpy/core/records.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -551,6 +551,7 @@ def view(self, dtype=None, type=None):
551551
return ndarray.view(self, dtype, type)
552552

553553

554+
554555
def fromarrays(arrayList, dtype=None, shape=None, formats=None,
555556
names=None, titles=None, aligned=False, byteorder=None):
556557
""" create a record array from a (flat) list of arrays

numpy/core/src/multiarray/getset.c

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -434,6 +434,7 @@ array_descr_set(PyArrayObject *self, PyObject *arg)
434434
npy_intp newdim;
435435
int i;
436436
char *msg = "new type not compatible with array.";
437+
PyObject *_numpy_internal, *safe;
437438

438439
if (arg == NULL) {
439440
PyErr_SetString(PyExc_AttributeError,
@@ -448,15 +449,21 @@ array_descr_set(PyArrayObject *self, PyObject *arg)
448449
return -1;
449450
}
450451

451-
if (PyDataType_FLAGCHK(newtype, NPY_ITEM_HASOBJECT) ||
452-
PyDataType_FLAGCHK(newtype, NPY_ITEM_IS_POINTER) ||
453-
PyDataType_FLAGCHK(PyArray_DESCR(self), NPY_ITEM_HASOBJECT) ||
454-
PyDataType_FLAGCHK(PyArray_DESCR(self), NPY_ITEM_IS_POINTER)) {
455-
PyErr_SetString(PyExc_TypeError,
456-
"Cannot change data-type for object array.");
452+
/* check that we are not reinterpreting memory containing Objects */
453+
_numpy_internal = PyImport_ImportModule("numpy.core._internal");
454+
if (_numpy_internal == NULL) {
455+
Py_DECREF(newtype);
456+
return -1;
457+
}
458+
safe = PyObject_CallMethod(_numpy_internal, "_view_is_safe", "OO",
459+
PyArray_DESCR(self), newtype);
460+
Py_DECREF(_numpy_internal);
461+
if (safe == NULL) {
457462
Py_DECREF(newtype);
458463
return -1;
459464
}
465+
Py_DECREF(safe);
466+
460467

461468
if (newtype->elsize == 0) {
462469
/* Allow a void view */

numpy/core/src/multiarray/methods.c

Lines changed: 15 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -358,15 +358,22 @@ NPY_NO_EXPORT PyObject *
358358
PyArray_GetField(PyArrayObject *self, PyArray_Descr *typed, int offset)
359359
{
360360
PyObject *ret = NULL;
361+
PyObject *_numpy_internal, *safe;
361362

362-
if (offset < 0 || (offset + typed->elsize) > PyArray_DESCR(self)->elsize) {
363-
PyErr_Format(PyExc_ValueError,
364-
"Need 0 <= offset <= %d for requested type "
365-
"but received offset = %d",
366-
PyArray_DESCR(self)->elsize-typed->elsize, offset);
367-
Py_DECREF(typed);
363+
/* check that we are not reinterpreting memory containing Objects */
364+
_numpy_internal = PyImport_ImportModule("numpy.core._internal");
365+
if (_numpy_internal == NULL) {
366+
return NULL;
367+
}
368+
/* only returns True or raises */
369+
safe = PyObject_CallMethod(_numpy_internal,
370+
"_getfield_is_safe", "OOi", PyArray_DESCR(self), typed, offset);
371+
Py_DECREF(_numpy_internal);
372+
if (safe == NULL) {
368373
return NULL;
369374
}
375+
Py_DECREF(safe);
376+
370377
ret = PyArray_NewFromDescr(Py_TYPE(self),
371378
typed,
372379
PyArray_NDIM(self), PyArray_DIMS(self),
@@ -417,23 +424,12 @@ PyArray_SetField(PyArrayObject *self, PyArray_Descr *dtype,
417424
PyObject *ret = NULL;
418425
int retval = 0;
419426

420-
if (offset < 0 || (offset + dtype->elsize) > PyArray_DESCR(self)->elsize) {
421-
PyErr_Format(PyExc_ValueError,
422-
"Need 0 <= offset <= %d for requested type "
423-
"but received offset = %d",
424-
PyArray_DESCR(self)->elsize-dtype->elsize, offset);
425-
Py_DECREF(dtype);
426-
return -1;
427-
}
428-
ret = PyArray_NewFromDescr(Py_TYPE(self),
429-
dtype, PyArray_NDIM(self), PyArray_DIMS(self),
430-
PyArray_STRIDES(self), PyArray_BYTES(self) + offset,
431-
PyArray_FLAGS(self), (PyObject *)self);
427+
/* getfield returns a view we can write to */
428+
ret = PyArray_GetField(self, dtype, offset);
432429
if (ret == NULL) {
433430
return -1;
434431
}
435432

436-
PyArray_UpdateFlags((PyArrayObject *)ret, NPY_ARRAY_UPDATE_ALL);
437433
retval = PyArray_CopyObject((PyArrayObject *)ret, val);
438434
Py_DECREF(ret);
439435
return retval;
@@ -455,13 +451,6 @@ array_setfield(PyArrayObject *self, PyObject *args, PyObject *kwds)
455451
return NULL;
456452
}
457453

458-
if (PyDataType_REFCHK(PyArray_DESCR(self))) {
459-
PyErr_SetString(PyExc_RuntimeError,
460-
"cannot call setfield on an object array");
461-
Py_DECREF(dtype);
462-
return NULL;
463-
}
464-
465454
if (PyArray_SetField(self, dtype, offset, value) < 0) {
466455
return NULL;
467456
}

numpy/core/tests/test_multiarray.py

Lines changed: 94 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -808,6 +808,14 @@ def test_casting(self):
808808
t = [('a', '>i4'), ('b', '<f8'), ('c', 'i4')]
809809
assert_(not np.can_cast(a.dtype, t, casting=casting))
810810

811+
def test_objview(self):
812+
# https://github.com/numpy/numpy/issues/3286
813+
a = np.array([], dtype=[('a', 'f'), ('b', 'f'), ('c', 'O')])
814+
a[['a', 'b']] # TypeError?
815+
816+
# https://github.com/numpy/numpy/issues/3253
817+
dat2 = np.zeros(3, [('A', 'i'), ('B', '|O')])
818+
new2 = dat2[['B', 'A']] # TypeError?
811819

812820
class TestBool(TestCase):
813821
def test_test_interning(self):
@@ -3576,8 +3584,9 @@ def test_record_no_hash(self):
35763584

35773585
class TestView(TestCase):
35783586
def test_basic(self):
3579-
x = np.array([(1, 2, 3, 4), (5, 6, 7, 8)], dtype=[('r', np.int8), ('g', np.int8),
3580-
('b', np.int8), ('a', np.int8)])
3587+
x = np.array([(1, 2, 3, 4), (5, 6, 7, 8)],
3588+
dtype=[('r', np.int8), ('g', np.int8),
3589+
('b', np.int8), ('a', np.int8)])
35813590
# We must be specific about the endianness here:
35823591
y = x.view(dtype='<i4')
35833592
# ... and again without the keyword.
@@ -5242,6 +5251,89 @@ def test_collections_hashable(self):
52425251
x = np.array([])
52435252
self.assertFalse(isinstance(x, collections.Hashable))
52445253

5254+
from numpy.core._internal import _view_is_safe
5255+
5256+
class TestObjViewSafetyFuncs:
    """Checks of ``_view_is_safe`` on layouts with objects, gaps and overlaps."""

    def test_view_safety(self):
        ptr = dtype('p').itemsize

        # Builds a dtype from comma-separated type codes, where the extra
        # code '-' stands for a missing pointer-sized field (a gap).
        def mtype(s):
            placed = []
            cursor = 0
            for code in s.split(','):  # subarrays won't work
                if code == '-':
                    cursor += ptr
                    continue
                placed.append(('f{0}'.format(len(placed)), code, cursor))
                cursor += dtype(code).itemsize

            names, formats, offsets = zip(*placed)
            return dtype({'names': names, 'formats': formats,
                          'offsets': offsets, 'itemsize': cursor})

        # Non-equal itemsizes with objects.
        # These should succeed:
        _view_is_safe(dtype('O,p,O,p'), dtype('O,p,O,p,O,p'))
        _view_is_safe(dtype('O,O'), dtype('O,O,O'))
        # These should fail:
        for bad_new in ('O,O', 'O,p', 'p,O'):
            assert_raises(TypeError, _view_is_safe,
                          dtype('O,O,p'), dtype(bad_new))

        # Non-equal itemsizes with missing fields.
        # These should succeed:
        _view_is_safe(mtype('-,p,-,p'), mtype('-,p,-,p,-,p'))
        _view_is_safe(dtype('p,p'), dtype('p,p,p'))
        # These should fail:
        for bad_new in ('p,p', 'p,-', '-,p'):
            assert_raises(TypeError, _view_is_safe,
                          mtype('p,p,-'), mtype(bad_new))

        # Tries a single `otype` field at every byte shift inside d1 and
        # returns the shifts for which the view is accepted.
        def scan_view(d1, otype):
            width = dtype(otype).itemsize
            accepted = []
            for shift in range(d1.itemsize - width + 1):
                candidate = dtype({'names': ['f0'], 'formats': [otype],
                                   'offsets': [shift],
                                   'itemsize': d1.itemsize})
                try:
                    _view_is_safe(d1, candidate)
                except TypeError:
                    continue
                accepted.append(shift)
            return accepted

        # Partial overlap with an object field.
        assert_equal(scan_view(dtype('p,O,p,p,O,O'), 'p'),
                     [0] + list(range(2*ptr, 3*ptr+1)))
        assert_equal(scan_view(dtype('p,O,p,p,O,O'), 'O'),
                     [ptr, 4*ptr, 5*ptr])

        # Partial overlap with a missing field.
        assert_equal(scan_view(mtype('p,-,p,p,-,-'), 'p'),
                     [0] + list(range(2*ptr, 3*ptr+1)))

        # Nested structures with objects.
        nestedO = dtype([('f0', 'p'), ('f1', 'p,O,p')])
        assert_equal(scan_view(nestedO, 'p'),
                     list(range(ptr+1)) + [3*ptr])
        assert_equal(scan_view(nestedO, 'O'), [2*ptr])

        # Nested structures with missing fields.
        nestedM = dtype([('f0', 'p'), ('f1', mtype('p,-,p'))])
        assert_equal(scan_view(nestedM, 'p'),
                     list(range(ptr+1)) + [3*ptr])

        # Subarrays with objects.
        subarrayO = dtype('p,(2,3)O,p')
        assert_equal(scan_view(subarrayO, 'p'), [0, 7*ptr])
        assert_equal(scan_view(subarrayO, 'O'),
                     list(range(ptr, 6*ptr+1, ptr)))

        # A dtype whose fields overlap each other.
        overlapped = dtype({'names': ['f0', 'f1', 'f2', 'f3'],
                            'formats': ['p', 'p', 'p', 'p'],
                            'offsets': [0, 1, 3*ptr-1, 3*ptr],
                            'itemsize': 4*ptr})
        assert_equal(scan_view(overlapped, 'p'), [0, 1, 3*ptr-1, 3*ptr])
5336+
52455337

52465338
if __name__ == "__main__":
52475339
run_module_suite()

numpy/core/tests/test_records.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,15 @@ def test_pickle_2(self):
219219
assert_equal(a, pickle.loads(pickle.dumps(a)))
220220
assert_equal(a[0], pickle.loads(pickle.dumps(a[0])))
221221

222+
def test_objview_record(self):
223+
# https://github.com/numpy/numpy/issues/2599
224+
dt = np.dtype([('foo', 'i8'), ('bar', 'O')])
225+
r = np.zeros((1,3), dtype=dt).view(np.recarray)
226+
r.foo = np.array([1, 2, 3]) # TypeError?
227+
228+
# https://github.com/numpy/numpy/issues/3256
229+
ra = np.recarray((2,), dtype=[('x', object), ('y', float), ('z', int)])
230+
ra[['x','y']] #TypeError?
222231

223232
def test_find_duplicate():
224233
l1 = [1, 2, 3, 4, 5, 6]

numpy/core/tests/test_regression.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1739,10 +1739,9 @@ def test_zerosize_accumulate(self):
17391739
assert_equal(np.add.accumulate(x[:-1, 0]), [])
17401740

17411741
def test_objectarray_setfield(self):
1742-
# Setfield directly manipulates the raw array data,
1743-
# so is invalid for object arrays.
1742+
# Setfield should not overwrite Object fields with non-Object data
17441743
x = np.array([1, 2, 3], dtype=object)
1745-
assert_raises(RuntimeError, x.setfield, 4, np.int32, 0)
1744+
assert_raises(TypeError, x.setfield, 4, np.int32, 0)
17461745

17471746
def test_setting_rank0_string(self):
17481747
"Ticket #1736"

0 commit comments

Comments
 (0)
0