8000 MAINT: improvements for void scalar getfield/setfield · numpy/numpy@6c7e75d · GitHub
[go: up one dir, main page]

8000
Skip to content

Commit 6c7e75d

Browse files
committed
MAINT: improvements for void scalar getfield/setfield
This commit modifiesvoidtype_get/setfield to that they call the ndarray get/setfield. This solves bugs related to void-scalar assignment. This fixes issues #3126, #3561.
1 parent 4cba531 commit 6c7e75d

File tree

3 files changed

+100
-74
lines changed

3 files changed

+100
-74
lines changed

numpy/core/src/multiarray/scalartypes.c.src

Lines changed: 73 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -1652,33 +1652,12 @@ gentype_@name@(PyObject *self, PyObject *args, PyObject *kwds)
16521652
static PyObject *
16531653
voidtype_getfield(PyVoidScalarObject *self, PyObject *args, PyObject *kwds)
16541654
{
1655-
PyObject *ret, *newargs;
1656-
1657-
newargs = PyTuple_GetSlice(args, 0, 2);
1658-
if (newargs == NULL) {
1659-
return NULL;
1660-
}
1661-
ret = gentype_generic_method((PyObject *)self, newargs, kwds, "getfield");
1662-
Py_DECREF(newargs);
1663-
if (!ret) {
1664-
return ret;
1665-
}
1666-
if (PyArray_IsScalar(ret, Generic) && \
1667-
(!PyArray_IsScalar(ret, Void))) {
1668-
PyArray_Descr *new;
1669-
void *ptr;
1670-
if (!PyArray_ISNBO(self->descr->byteorder)) {
1671-
new = PyArray_DescrFromScalar(ret);
1672-
ptr = scalar_value(ret, new);
1673-
byte_swap_vector(ptr, 1, new->elsize);
1674-
Py_DECREF(new);
1675-
}
1676-
}
1677-
return ret;
1655+
return gentype_generic_method((PyObject *)self, args, kwds, "getfield");
16781656
}
16791657

16801658
static PyObject *
1681-
gentype_setfield(PyObject *NPY_UNUSED(self), PyObject *NPY_UNUSED(args), PyObject *NPY_UNUSED(kwds))
1659+
gentype_setfield(PyObject *NPY_UNUSED(self), PyObject *NPY_UNUSED(args),
1660+
PyObject *NPY_UNUSED(kwds))
16821661
{
16831662
PyErr_SetString(PyExc_TypeError,
16841663
"Can't set fields in a non-void array scalar.");
@@ -1688,59 +1667,71 @@ gentype_setfield(PyObject *NPY_UNUSED(self), PyObject *NPY_UNUSED(args), PyObjec
16881667
static PyObject *
16891668
voidtype_setfield(PyVoidScalarObject *self, PyObject *args, PyObject *kwds)
16901669
{
1691-
PyArray_Descr *typecode = NULL;
1692-
int offset = 0;
1693-
PyObject *value;
1694-
PyArrayObject *src;
1695-
int mysize;
1696-
char *dptr;
1697-
static char *kwlist[] = {"value", "dtype", "offset", 0};
1698-
1699-
if ((self->flags & NPY_ARRAY_WRITEABLE) != NPY_ARRAY_WRITEABLE) {
1700-
PyErr_SetString(PyExc_RuntimeError, "Can't write to memory");
1670+
/*
1671+
* we can't simply use ndarray's setfield because of the case where self is
1672+
* an object array and the value being assigned is an ndarray. We want
1673+
* the lines below to behave identically:
1674+
*
1675+
* b = np.zeros(1, dtype=[('x', 'O')])
1676+
* b[0]['x'] = arange(3)
1677+
* b['x'][0] = arange(3)
1678+
*
1679+
* Ndarray's setfield would broadcast the ndarray. Instead we use ndarray
1680+
* getfield to get the field safely, then setitem to set the value without
1681+
* broadcast. Note we also want subarrays to be set properly, ie
1682+
*
1683+
* a = np.zeros(1, dtype=[('x', 'i', 5)])
1684+
* a[0]['x'] = 1
1685+
*
1686+
* sets all values to 1. Setitem does this.
1687+
*/
1688+
PyObject *getfield_args, *value, *arr, *meth, *arr_field, *emptytuple;
1689+
1690+
value = PyTuple_GetItem(args, 0);
1691+
if (value == NULL) {
17011692
return NULL;
17021693
}
1703-
if (!PyArg_ParseTupleAndKeywords(args, kwds, "OO&|i", kwlist,
1704-
&value,
1705-
PyArray_DescrConverter,
1706-
&typecode, &offset)) {
1707-
Py_XDECREF(typecode);
1694+
getfield_args = PyTuple_GetSlice(args, 1, 3);
1695+
if (getfield_args == NULL) {
17081696
return NULL;
17091697
}
17101698

1711-
mysize = Py_SIZE(self);
1712-
1713-
if (offset < 0 || (offset + typecode->elsize) > mysize) {
1714-
PyErr_Format(PyExc_ValueError,
1715-
"Need 0 <= offset <= %d for requested type " \
1716-
"but received offset = %d",
1717-
mysize-typecode->elsize, offset);
1718-
Py_DECREF(typecode);
1699+
/* convert to 0-d array and use getfield */
1700+
arr = PyArray_FromScalar(self, NULL);
1701+
if (arr == NULL) {
1702+
Py_DECREF(getfield_args);
17191703
return NULL;
17201704
}
1721-
1722-
dptr = self->obval + offset;
1723-
1724-
if (typecode->type_num == NPY_OBJECT) {
1725-
PyObject *temp;
1726-
Py_INCREF(value);
1727-
NPY_COPY_PYOBJECT_PTR(&temp, dptr);
1728-
Py_XDECREF(temp);
1729-
NPY_COPY_PYOBJECT_PTR(dptr, &value);
1730-
Py_DECREF(typecode);
1705+
meth = PyObject_GetAttrString(arr, "getfield");
1706+
if (meth == NULL) {
1707+
Py_DECREF(getfield_args);
1708+
Py_DECREF(arr);
1709+
return NULL;
1710+
}
1711+
if (kwds == NULL) {
1712+
arr_field = PyObject_CallObject(meth, getfield_args);
17311713
}
17321714
else {
1733-
/* Copy data from value to correct place in dptr */
1734-
src = (PyArrayObject *)PyArray_FromAny(value, typecode,
1735-
0, 0, NPY_ARRAY_CARRAY, NULL);
1736-
if (src == NULL) {
1737-
return NULL;
1738-
}
1739-
typecode->f->copyswap(dptr, PyArray_DATA(src),
1740-
!PyArray_ISNBO(self->descr->byteorder),
1741-
src);
1742-
Py_DECREF(src);
1715+
arr_field = PyObject_Call(meth, getfield_args, kwds);
17431716
}
1717+
Py_DECREF(getfield_args);
1718+
Py_DECREF(meth);
1719+
Py_DECREF(arr);
1720+
1721+
if(arr_field == NULL){
1722+
return NULL;
1723+
}
1724+
1725+
/* fill the resulting array using setitem */
1726+
emptytuple = PyTuple_New(0);
1727+
if (PyObject_SetItem(arr_field, emptytuple, value) < 0) {
1728+
Py_DECREF(arr_field);
1729+
Py_DECREF(emptytuple);
1730+
return NULL;
1731+
}
1732+
Py_DECREF(arr_field);
1733+
Py_DECREF(emptytuple);
1734+
17441735
Py_INCREF(Py_None);
17451736
return Py_None;
17461737
}
@@ -2170,7 +2161,7 @@ static PyObject *
21702161
voidtype_item(PyVoidScalarObject *self, Py_ssize_t n)
21712162
{
21722163
npy_intp m;
2173-
PyObject *flist=NULL, *fieldinfo;
2164+
PyObject *flist=NULL, *fieldind, *fieldparam, *fieldinfo, *ret;
21742165

21752166
if (!(PyDataType_HASFIELDS(self->descr))) {
21762167
PyErr_SetString(PyExc_IndexError,
@@ -2186,9 +2177,13 @@ voidtype_item(PyVoidScalarObject *self, Py_ssize_t n)
21862177
PyErr_Format(PyExc_IndexError, "invalid index (%d)", (int) n);
21872178
return NULL;
21882179
}
2189-
fieldinfo = PyDict_GetItem(self->descr->fields,
2190-
PyTuple_GET_ITEM(flist, n));
2191-
return voidtype_getfield(self, fieldinfo, NULL);
2180+
/* no error checking needed: descr->names is well structured */
2181+
fieldind = PyTuple_GET_ITEM(flist, n);
2182+
fieldparam = PyDict_GetItem(self->descr->fields, fieldind);
2183+
fieldinfo = PyTuple_GetSlice(fieldparam, 0, 2);
2184+
ret = voidtype_getfield(self, fieldinfo, NULL);
2185+
Py_DECREF(fieldinfo);
2186+
return ret;
21922187
}
21932188

21942189

@@ -2210,12 +2205,16 @@ voidtype_subscript(PyVoidScalarObject *self, PyObject *ind)
22102205
#else
22112206
if (PyBytes_Check(ind) || PyUnicode_Check(ind)) {
22122207
#endif
2208+
PyObject *ret, *fieldparam;
22132209
/* look up in fields */
2214-
fieldinfo = PyDict_GetItem(self->descr->fields, ind);
2215-
if (!fieldinfo) {
2210+
fieldparam = PyDict_GetItem(self->descr->fields, ind);
2211+
if (!fieldparam) {
22162212
goto fail;
22172213
}
2218-
return voidtype_getfield(self, fieldinfo, NULL);
2214+
fieldinfo = PyTuple_GetSlice(fieldparam, 0, 2);
2215+
ret = voidtype_getfield(self, fieldinfo, NULL);
2216+
Py_DECREF(fieldinfo);
2217+
return ret;
22192218
}
22202219

22212220
/* try to convert it to a number */

numpy/core/tests/test_multiarray.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -797,6 +797,28 @@ def test_casting(self):
797797
t = [('a', '>i4'), ('b', '<f8'), ('c', 'i4')]
798798
assert_(not np.can_cast(a.dtype, t, casting=casting))
799799

800+
def test_setfield(self):
801+
# https://github.com/numpy/numpy/issues/3126
802+
struct_dt = np.dtype([('elem', 'i4', 5),])
803+
dt = np.dtype([('field', 'i4', 10),('struct', struct_dt)])
804+
x = np.zeros(1, dt)
805+
x[0]['field'] = np.ones(10, dtype='i4')
806+
x[0]['struct'] = np.ones(1, dtype=struct_dt)
807+
assert_equal(x[0]['field'], np.ones(10, dtype='i4'))
808+
809+
def test_setfield_object(self):
810+
# make sure object field assignment with ndarray value
811+
# on void scalar mimics setitem behavior
812+
b = np.zeros(1, dtype=[('x', 'O')])
813+
# next line should work identically to b['x'][0] = np.arange(3)
814+
b[0]['x'] = np.arange(3)
815+
assert_equal(b[0]['x'], np.arange(3))
816+
817+
#check that broadcasting check still works
818+
c = np.zeros(1, dtype=[('x', 'O', 5)])
819+
def testassign():
820+
c[0]['x'] = np.arange(3)
821+
assert_raises(ValueError, testassign)
800822

801823
class TestBool(TestCase):
802824
def test_test_interning(self):

numpy/core/tests/test_records.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,11 @@ def test_pickle_2(self):
206206
assert_equal(a, pickle.loads(pickle.dumps(a)))
207207
assert_equal(a[0], pickle.loads(pickle.dumps(a[0])))
208208

209+
def test_record_scalar_setitem(self):
210+
# https://github.com/numpy/numpy/issues/3561
211+
rec = np.recarray(1, dtype=[('x', float, 5)])
212+
rec[0].x = 1
213+
assert_equal(rec[0].x, np.ones(5))
209214

210215
def test_find_duplicate():
211216
l1 = [1, 2, 3, 4, 5, 6]

0 commit comments

Comments
 (0)
0