Merge pull request #7003 from gfyoung/place_str_fix · numpy/numpy@1429c60 · GitHub
[go: up one dir, main page]

Skip to content

Commit 1429c60

Browse files
committed
Merge pull request #7003 from gfyoung/place_str_fix
BUG: Fix string copying for np.place
2 parents e034b86 + 9128ed5 commit 1429c60

File tree

3 files changed: 78 additions and 154 deletions.

numpy/core/src/multiarray/compiled_base.c

Lines changed: 71 additions & 153 deletions
Original file line numberDiff line numberDiff line change
@@ -301,194 +301,112 @@ arr_digitize(PyObject *NPY_UNUSED(self), PyObject *args, PyObject *kwds)
301301
return ret;
302302
}
303303

304-
/*
305-
* Insert values from an input array into an output array, at positions
306-
* indicated by a mask. If the arrays are of dtype object (indicated by
307-
* the objarray flag), take care of reference counting.
308-
*
309-
* This function implements the copying logic of arr_insert() defined
310-
* below.
311-
*/
312-
static void
313-
arr_insert_loop(char *mptr, char *vptr, char *input_data, char *zero,
314-
char *avals_data, int melsize, int delsize, int objarray,
315-
int totmask, int numvals, int nd, npy_intp *instrides,
316-
npy_intp *inshape)
317-
{
318-
int mindx, rem_indx, indx, i, copied;
319-
320-
/*
321-
* Walk through mask array, when non-zero is encountered
322-
* copy next value in the vals array to the input array.
323-
* If we get through the value array, repeat it as necessary.
324-
*/
325-
copied = 0;
326-
for (mindx = 0; mindx < totmask; mindx++) {
327-
if (memcmp(mptr,zero,melsize) != 0) {
328-
/* compute indx into input array */
329-
rem_indx = mindx;
330-
indx = 0;
331-
for (i = nd - 1; i > 0; --i) {
332-
indx += (rem_indx % inshape[i]) * instrides[i];
333-
rem_indx /= inshape[i];
334-
}
335-
indx += rem_indx * instrides[0];
336-
/* fprintf(stderr, "mindx = %d, indx=%d\n", mindx, indx); */
337-
/* Copy value element over to input array */
338-
memcpy(input_data+indx,vptr,delsize);
339-
if (objarray) {
340-
Py_INCREF(*((PyObject **)vptr));
341-
}
342-
vptr += delsize;
343-
copied += 1;
344-
/* If we move past value data. Reset */
345-
if (copied >= numvals) {
346-
vptr = avals_data;
347-
copied = 0;
348-
}
349-
}
350-
mptr += melsize;
351-
}
352-
}
353-
354304
/*
355305
* Returns input array with values inserted sequentially into places
356306
* indicated by the mask
357307
*/
358308
NPY_NO_EXPORT PyObject *
359309
arr_insert(PyObject *NPY_UNUSED(self), PyObject *args, PyObject *kwdict)
360310
{
361-
PyObject *mask = NULL, *vals = NULL;
362-
PyArrayObject *ainput = NULL, *amask = NULL, *avals = NULL, *tmp = NULL;
363-
int numvals, totmask, sameshape;
364-
char *input_data, *mptr, *vptr, *zero = NULL;
365-
int melsize, delsize, nd, objarray, k;
366-
npy_intp *instrides, *inshape;
311+
char *src, *dest;
312+
npy_bool *mask_data;
313+
PyArray_Descr *dtype;
314+
PyArray_CopySwapFunc *copyswap;
315+
PyObject *array0, *mask0, *values0;
316+
PyArrayObject *array, *mask, *values;
317+
npy_intp i, j, chunk, nm, ni, nv;
367318

368319
static char *kwlist[] = {"input", "mask", "vals", NULL};
320+
NPY_BEGIN_THREADS_DEF;
321+
values = mask = NULL;
369322

370-
if (!PyArg_ParseTupleAndKeywords(args, kwdict, "O&OO", kwlist,
371-
PyArray_Converter, &ainput,
372-
&mask, &vals)) {
373-
goto fail;
323+
if (!PyArg_ParseTupleAndKeywords(args, kwdict, "O!OO:place", kwlist,
324+
&PyArray_Type, &array0, &mask0, &values0)) {
325+
return NULL;
374326
}
375327

376-
amask = (PyArrayObject *)PyArray_FROM_OF(mask, NPY_ARRAY_CARRAY);
377-
if (amask == NULL) {
328+
array = (PyArrayObject *)PyArray_FromArray((PyArrayObject *)array0, NULL,
329+
NPY_ARRAY_CARRAY | NPY_ARRAY_UPDATEIFCOPY);
330+
if (array == NULL) {
378331
goto fail;
379332
}
380-
/* Cast an object array */
381-
if (PyArray_DESCR(amask)->type_num == NPY_OBJECT) {
382-
tmp = (PyArrayObject *)PyArray_Cast(amask, NPY_INTP);
383-
if (tmp == NULL) {
384-
goto fail;
385-
}
386-
Py_DECREF(amask);
387-
amask = tmp;
388-
}
389333

390-
sameshape = 1;
391-
if (PyArray_NDIM(amask) == PyArray_NDIM(ainput)) {
392-
for (k = 0; k < PyArray_NDIM(amask); k++) {
393-
if (PyArray_DIMS(amask)[k] != PyArray_DIMS(ainput)[k]) {
394-
sameshape = 0;
395-
}
396-
}
397-
}
398-
else {
399-
/* Test to see if amask is 1d */
400-
if (PyArray_NDIM(amask) != 1) {
401-
sameshape = 0;
402-
}
403-
else if ((PyArray_SIZE(ainput)) != PyArray_SIZE(amask)) {
404-
sameshape = 0;
405-
}
406-
}
407-
if (!sameshape) {
408-
PyErr_SetString(PyExc_TypeError,
409-
"mask array must be 1-d or same shape as input array");
334+
ni = PyArray_SIZE(array);
335+
dest = PyArray_DATA(array);
336+
chunk = PyArray_DESCR(array)->elsize;
337+
mask = (PyArrayObject *)PyArray_FROM_OTF(mask0, NPY_BOOL,
338+
NPY_ARRAY_CARRAY | NPY_ARRAY_FORCECAST);
339+
if (mask == NULL) {
410340
goto fail;
411341
}
412342

413-
avals = (PyArrayObject *)PyArray_FromObject(vals,
414-
PyArray_DESCR(ainput)->type_num, 0, 1);
415-
if (avals == NULL) {
343+
nm = PyArray_SIZE(mask);
344+
if (nm != ni) {
345+
PyErr_SetString(PyExc_ValueError,
346+
"place: mask and data must be "
347+
"the same size");
416348
goto fail;
417349
}
418-
numvals = PyArray_SIZE(avals);
419-
nd = PyArray_NDIM(ainput);
420-
input_data = PyArray_DATA(ainput);
421-
mptr = PyArray_DATA(amask);
422-
melsize = PyArray_DESCR(amask)->elsize;
423-
vptr = PyArray_DATA(avals);
424-
delsize = PyArray_DESCR(avals)->elsize;
425-
zero = PyArray_Zero(amask);
426-
if (zero == NULL) {
350+
351+
mask_data = PyArray_DATA(mask);
352+
dtype = PyArray_DESCR(array);
353+
Py_INCREF(dtype);
354+
355+
values = (PyArrayObject *)PyArray_FromAny(values0, dtype,
356+
0, 0, NPY_ARRAY_CARRAY, NULL);
357+
if (values == NULL) {
427358
goto fail;
428359
}
429-
objarray = (PyArray_DESCR(ainput)->type_num == NPY_OBJECT);
430360

431-
if (!numvals) {
432-
/* nothing to insert! fail unless none of mask is true */
433-
const char *iter = mptr;
434-
const char *const last = iter + PyArray_NBYTES(amask);
435-
while (iter != last && !memcmp(iter, zero, melsize)) {
436-
iter += melsize;
361+
nv = PyArray_SIZE(values); /* zero if null array */
362+
if (nv <= 0) {
363+
npy_bool allFalse = 1;
364+
i = 0;
365+
366+
while (allFalse && i < ni) {
367+
if (mask_data[i]) {
368+
allFalse = 0;
369+
} else {
370+
i++;
371+
}
437372
}
438-
if (iter != last) {
373+
if (!allFalse) {
439374
PyErr_SetString(PyExc_ValueError,
440-
"Cannot insert from an empty array!");
375+
"Cannot insert from an empty array!");
441376
goto fail;
377+
} else {
378+
Py_XDECREF(values);
379+
Py_XDECREF(mask);
380+
Py_RETURN_NONE;
442381
}
443-
goto finish;
444382
}
445383

446-
/* Handle zero-dimensional case separately */
447-
if (nd == 0) {
448-
if (memcmp(mptr,zero,melsize) != 0) {
449-
/* Copy value element over to input array */
450-
memcpy(input_data,vptr,delsize);
451-
if (objarray) {
452-
Py_INCREF(*((PyObject **)vptr));
384+
src = PyArray_DATA(values);
385+
j = 0;
386+
387+
copyswap = PyArray_DESCR(array)->f->copyswap;
388+
NPY_BEGIN_THREADS_DESCR(PyArray_DESCR(array));
389+
for (i = 0; i < ni; i++) {
390+
if (mask_data[i]) {
391+
if (j >= nv) {
392+
j = 0;
453393
}
454-
}
455-
Py_DECREF(amask);
456-
Py_DECREF(avals);
457-
PyDataMem_FREE(zero);
458-
Py_DECREF(ainput);
459-
Py_RETURN_NONE;
460-
}
461394

462-
totmask = (int) PyArray_SIZE(amask);
463-
instrides = PyArray_STRIDES(ainput);
464-
inshape = PyArray_DIMS(ainput);
465-
if (objarray) {
466-
/* object array, need to refcount, can't release the GIL */
467-
arr_insert_loop(mptr, vptr, input_data, zero, PyArray_DATA(avals),
468-
melsize, delsize, objarray, totmask, numvals, nd,
469-
instrides, inshape);
470-
}
471-
else {
472-
/* No increfs take place in arr_insert_loop, so release the GIL */
473-
NPY_BEGIN_ALLOW_THREADS;
474-
arr_insert_loop(mptr, vptr, input_data, zero, PyArray_DATA(avals),
475-
melsize, delsize, objarray, totmask, numvals, nd,
476-
instrides, inshape);
477-
NPY_END_ALLOW_THREADS;
395+
copyswap(dest + i*chunk, src + j*chunk, 0, array);
396+
j++;
397+
}
478398
}
399+
NPY_END_THREADS;
479400

480-
finish:
481-
Py_DECREF(amask);
482-
Py_DECREF(avals);
483-
PyDataMem_FREE(zero);
484-
Py_DECREF(ainput);
401+
Py_XDECREF(values);
402+
Py_XDECREF(mask);
403+
Py_DECREF(array);
485404
Py_RETURN_NONE;
486405

487-
fail:
488-
PyDataMem_FREE(zero);
489-
Py_XDECREF(ainput);
490-
Py_XDECREF(amask);
491-
Py_XDECREF(avals);
406+
fail:
407+
Py_XDECREF(mask);
408+
Py_XDECREF(array);
409+
Py_XDECREF(values);
492410
return NULL;
493411
}
494412

numpy/lib/function_base.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2026,7 +2026,8 @@ def place(arr, mask, vals):
20262026
vals : 1-D sequence
20272027
Values to put into `a`. Only the first N elements are used, where
20282028
N is the number of True values in `mask`. If `vals` is smaller
2029-
than N it will be repeated.
2029+
than N, it will be repeated, and if elements of `a` are to be masked,
2030+
this sequence must be non-empty.
20302031
20312032
See Also
20322033
--------

numpy/lib/tests/test_function_base.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -805,6 +805,11 @@ def test_place(self):
805805
assert_raises_regex(ValueError, "Cannot insert from an empty array",
806806
lambda: place(a, [0, 0, 0, 0, 0, 1, 0], []))
807807

808+
# See Issue #6974
809+
a = np.array(['12', '34'])
810+
place(a, [0, 1], '9')
811+
assert_array_equal(a, ['12', '9'])
812+
808813
def test_both(self):
809814
a = rand(10)
810815
mask = a > 0.5

0 commit comments

Comments
 (0)
0