10000 Merge pull request #5747 from jaimefrio/repeat_broadcast · numpy/numpy@f1f9e14 · GitHub
[go: up one dir, main page]

Skip to content

Commit f1f9e14

Browse files
committed
Merge pull request #5747 from jaimefrio/repeat_broadcast
BUG: np.repeat does not properly broadcast size 1 repeat arrays
2 parents e05b758 + 77e433a commit f1f9e14

File tree

2 files changed

+29
-22
lines changed

2 files changed

+29
-22
lines changed

numpy/core/src/multiarray/item_selection.c

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -546,9 +546,9 @@ NPY_NO_EXPORT PyObject *
546546
PyArray_Repeat(PyArrayObject *aop, PyObject *op, int axis)
547547
{
548548
npy_intp *counts;
549-
npy_intp n, n_outer, i, j, k, chunk, total;
550-
npy_intp tmp;
551-
int nd;
549+
npy_intp n, n_outer, i, j, k, chunk;
550+
npy_intp total = 0;
551+
npy_bool broadcast = NPY_FALSE;
552552
PyArrayObject *repeats = NULL;
553553
PyObject *ap = NULL;
554554
PyArrayObject *ret = NULL;
@@ -558,34 +558,35 @@ PyArray_Repeat(PyArrayObject *aop, PyObject *op, int axis)
558558
if (repeats == NULL) {
559559
return NULL;
560560
}
561-
nd = PyArray_NDIM(repeats);
561+
562+
/*
563+
* Scalar and size 1 'repeat' arrays broadcast to any shape, for all
564+
* other inputs the dimension must match exactly.
565+
*/
566+
if (PyArray_NDIM(repeats) == 0 || PyArray_SIZE(repeats) == 1) {
567+
broadcast = NPY_TRUE;
568+
}
569+
562570
counts = (npy_intp *)PyArray_DATA(repeats);
563571

564-
if ((ap=PyArray_CheckAxis(aop, &axis, NPY_ARRAY_CARRAY))==NULL) {
572+
if ((ap = PyArray_CheckAxis(aop, &axis, NPY_ARRAY_CARRAY)) == NULL) {
565573
Py_DECREF(repeats);
566574
return NULL;
567575
}
568576

569577
aop = (PyArrayObject *)ap;
570-
if (nd == 1) {
571-
n = PyArray_DIMS(repeats)[0];
572-
}
573-
else {
574-
/* nd == 0 */
575-
n = PyArray_DIMS(aop)[axis];
576-
}
577-
if (PyArray_DIMS(aop)[axis] != n) {
578-
PyErr_SetString(PyExc_ValueError,
579-
"a.shape[axis] != len(repeats)");
578+
n = PyArray_DIM(aop, axis);
579+
580+
if (!broadcast && PyArray_SIZE(repeats) != n) {
581+
PyErr_Format(PyExc_ValueError,
582+
"operands could not be broadcast together "
583+
"with shape (%zd,) (%zd,)", n, PyArray_DIM(repeats, 0));
580584
goto fail;
581585
}
582-
583-
if (nd == 0) {
584-
total = counts[0]*n;
586+
if (broadcast) {
587+
total = counts[0] * n;
585588
}
586589
else {
587-
588-
total = 0;
589590
for (j = 0; j < n; j++) {
590591
if (counts[j] < 0) {
591592
PyErr_SetString(PyExc_ValueError, "count < 0");
@@ -595,7 +596,6 @@ PyArray_Repeat(PyArrayObject *aop, PyObject *op, int axis)
595596
}
596597
}
597598

598-
599599
/* Construct new array */
600600
PyArray_DIMS(aop)[axis] = total;
601601
Py_INCREF(PyArray_DESCR(aop));
@@ -623,7 +623,7 @@ PyArray_Repeat(PyArrayObject *aop, PyObject *op, int axis)
623623
}
624624
for (i = 0; i < n_outer; i++) {
625625
for (j = 0; j < n; j++) {
626-
tmp = nd ? counts[j] : counts[0];
626+
npy_intp tmp = broadcast ? counts[0] : counts[j];
627627
for (k = 0; k < tmp; k++) {
628628
memcpy(new_data, old_data, chunk);
629629
new_data += chunk;

numpy/core/tests/test_regression.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import tempfile
1111
from os import path
1212
from io import BytesIO
13+
from itertools import chain
1314

1415
import numpy as np
1516
from numpy.testing import (
@@ -2118,6 +2119,12 @@ def passer(*args):
21182119

21192120
assert_raises(ValueError, np.frompyfunc, passer, 32, 1)
21202121

2122+
def test_repeat_broadcasting(self):
2123+
# gh-5743
2124+
a = np.arange(60).reshape(3, 4, 5)
2125+
for axis in chain(range(-a.ndim, a.ndim), [None]):
2126+
assert_equal(a.repeat(2, axis=axis), a.repeat([2], axis=axis))
2127+
21212128

21222129
if __name__ == "__main__":
21232130
run_module_suite()

0 commit commen 309C ts

Comments
 (0)
0