@@ -546,9 +546,9 @@ NPY_NO_EXPORT PyObject *
546
546
PyArray_Repeat (PyArrayObject * aop , PyObject * op , int axis )
547
547
{
548
548
npy_intp * counts ;
549
- npy_intp n , n_outer , i , j , k , chunk , total ;
550
- npy_intp tmp ;
551
- int nd ;
549
+ npy_intp n , n_outer , i , j , k , chunk ;
550
+ npy_intp total = 0 ;
551
+ npy_bool broadcast = NPY_FALSE ;
552
552
PyArrayObject * repeats = NULL ;
553
553
PyObject * ap = NULL ;
554
554
PyArrayObject * ret = NULL ;
@@ -558,34 +558,35 @@ PyArray_Repeat(PyArrayObject *aop, PyObject *op, int axis)
558
558
if (repeats == NULL ) {
559
559
return NULL ;
560
560
}
561
- nd = PyArray_NDIM (repeats );
561
+
562
+ /*
563
+ * Scalar and size 1 'repeat' arrays broadcast to any shape, for all
564
+ * other inputs the dimension must match exactly.
565
+ */
566
+ if (PyArray_NDIM (repeats ) == 0 || PyArray_SIZE (repeats ) == 1 ) {
567
+ broadcast = NPY_TRUE ;
568
+ }
569
+
562
570
counts = (npy_intp * )PyArray_DATA (repeats );
563
571
564
- if ((ap = PyArray_CheckAxis (aop , & axis , NPY_ARRAY_CARRAY ))== NULL ) {
572
+ if ((ap = PyArray_CheckAxis (aop , & axis , NPY_ARRAY_CARRAY )) == NULL ) {
565
573
Py_DECREF (repeats );
566
574
return NULL ;
567
575
}
568
576
569
577
aop = (PyArrayObject * )ap ;
570
- if (nd == 1 ) {
571
- n = PyArray_DIMS (repeats )[0 ];
572
- }
573
- else {
574
- /* nd == 0 */
575
- n = PyArray_DIMS (aop )[axis ];
576
- }
577
- if (PyArray_DIMS (aop )[axis ] != n ) {
578
- PyErr_SetString (PyExc_ValueError ,
579
- "a.shape[axis] != len(repeats)" );
578
+ n = PyArray_DIM (aop , axis );
579
+
580
+ if (!broadcast && PyArray_SIZE (repeats ) != n ) {
581
+ PyErr_Format (PyExc_ValueError ,
582
+ "operands could not be broadcast together "
583
+ "with shape (%zd,) (%zd,)" , n , PyArray_DIM (repeats , 0 ));
580
584
goto fail ;
581
585
}
582
-
583
- if (nd == 0 ) {
584
- total = counts [0 ]* n ;
586
+ if (broadcast ) {
587
+ total = counts [0 ] * n ;
585
588
}
586
589
else {
587
-
588
- total = 0 ;
589
590
for (j = 0 ; j < n ; j ++ ) {
590
591
if (counts [j ] < 0 ) {
591
592
PyErr_SetString (PyExc_ValueError , "count < 0" );
@@ -595,7 +596,6 @@ PyArray_Repeat(PyArrayObject *aop, PyObject *op, int axis)
595
596
}
596
597
}
597
598
598
-
599
599
/* Construct new array */
600
600
PyArray_DIMS (aop )[axis ] = total ;
601
601
Py_INCREF (PyArray_DESCR (aop ));
@@ -623,7 +623,7 @@ PyArray_Repeat(PyArrayObject *aop, PyObject *op, int axis)
623
623
}
624
624
for (i = 0 ; i < n_outer ; i ++ ) {
625
625
for (j = 0 ; j < n ; j ++ ) {
626
- tmp = nd ? counts [j ] : counts [0 ];
626
+ npy_intp tmp = broadcast ? counts [0 ] : counts [j ];
627
627
for (k = 0 ; k < tmp ; k ++ ) {
628
628
memcpy (new_data , old_data , chunk );
629
629
new_data += chunk ;
0 commit commen 309C ts