diff --git a/numpy/core/src/multiarray/item_selection.c b/numpy/core/src/multiarray/item_selection.c index d6e0980c6173..2513bf5900d3 100644 --- a/numpy/core/src/multiarray/item_selection.c +++ b/numpy/core/src/multiarray/item_selection.c @@ -87,7 +87,8 @@ PyArray_TakeFrom(PyArrayObject *self0, PyObject *indices0, int axis, } else { int flags = NPY_ARRAY_CARRAY | - NPY_ARRAY_UPDATEIFCOPY; + NPY_ARRAY_UPDATEIFCOPY | + NPY_ARRAY_FORCECAST; if ((PyArray_NDIM(out) != nd) || !PyArray_CompareLists(PyArray_DIMS(out), shape, nd)) { @@ -104,8 +105,23 @@ PyArray_TakeFrom(PyArrayObject *self0, PyObject *indices0, int axis, */ flags |= NPY_ARRAY_ENSURECOPY; } + dtype = PyArray_DESCR(self); + /* + * The out array is converted to an array obj of type dtype and after + * the take operation is finished is update-if-copy-ed back. An + * explicit check that dtype can be safely cast to the type of out + * is needed, as the constructor of obj would check the opposite + * casting, were it not disabled by the NPY_ARRAY_FORCECAST flag. + */ + if (PyArray_CanCastTypeTo(dtype, PyArray_DESCR(out), + NPY_SAFE_CASTING) == 0) { + PyErr_SetString(PyExc_TypeError, + "output cannot be safely cast to out dtype"); + goto fail; + } Py_INCREF(dtype); + obj = (PyArrayObject *)PyArray_FromArray(out, dtype, flags); if (obj == NULL) { goto fail; diff --git a/numpy/core/tests/test_item_selection.py b/numpy/core/tests/test_item_selection.py index d8e9e6fd0faf..048d8da0d43b 100644 --- a/numpy/core/tests/test_item_selection.py +++ b/numpy/core/tests/test_item_selection.py @@ -3,6 +3,7 @@ import numpy as np from numpy.testing import * import sys, warnings +import itertools class TestTake(TestCase): @@ -45,7 +46,6 @@ def test_simple(self): res = ta.take(index_array, mode=mode, axis=1) assert_(res.shape == (2,) + index_array.shape) - def test_refcounting(self): objects = [object() for i in range(10)] for mode in ('raise', 'clip', 'wrap'): @@ -60,6 +60,17 @@ def test_refcounting(self): del a assert_(all(sys.getrefcount(o) == 3 for o in objects)) + def test_casting(self): + types = ''.join((np.typecodes['AllInteger'], np.typecodes['Float'])) + modes = ['raise', 'wrap', 'clip'] + for from_, to_, mode in itertools.product(types, types, modes): + a = np.array([0, 1, 2], dtype=np.dtype(from_)) + b = np.empty((3,), dtype=np.dtype(to_)) + if np.can_cast(from_, to_): + a.take([0, 1, 2], out=b, mode=mode) + else: + assert_raises(TypeError, a.take, [0, 1, 2], out=b, mode=mode) + def test_unicode_mode(self): d = np.arange(10) k = b'\xc3\xa4'.decode("UTF8")