8000 fix bug with broadcasting to higher dimension · numpy/numpy@94f711d · GitHub
[go: up one dir, main page]

Skip to content

Commit 94f711d

Browse files
committed
fix bug with broadcasting to higher dimension
2,1 + 1, -> 2,2 which cannot be done inplace, so check for equal dimensionality
1 parent 89c3e46 commit 94f711d

File tree

2 files changed

+15
-10
lines changed

2 files changed

+15
-10
lines changed

numpy/core/src/multiarray/number.c

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -612,19 +612,14 @@ can_elide_temp(PyArrayObject * alhs, PyObject * orhs, int * cannot)
612612
return 0;
613613
}
614614

615-
/* too large to fit into left hand side */
616-
if (PyArray_NDIM(arhs) > PyArray_NDIM(alhs)) {
617-
Py_DECREF(arhs);
618-
return 0;
619-
}
620-
621615
/*
622616
* if rhs is not a scalar dimensions must match
623-
* todo: one could allow full broadcasting on equal types
617+
* TODO: one could allow broadcasting on equal types
624618
*/
625-
if (PyArray_NDIM(arhs) > 0 &&
626-
!PyArray_CompareLists(PyArray_DIMS(alhs), PyArray_DIMS(arhs),
627-
PyArray_NDIM(arhs))) {
619+
if (!(PyArray_NDIM(arhs) == 0 ||
620+
(PyArray_NDIM(arhs) == PyArray_NDIM(alhs) &&
621+
PyArray_CompareLists(PyArray_DIMS(alhs), PyArray_DIMS(arhs),
622+
PyArray_NDIM(arhs))))) {
628623
Py_DECREF(arhs);
629624
return 0;
630625
}

numpy/core/tests/test_multiarray.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2741,6 +2741,16 @@ def test_temporary_with_cast(self):
27412741
d = f.astype(np.float64)
27422742
assert_equal(((f + f) + d).dtype, np.dtype('f8'))
27432743

2744+
def test_elide_broadcast(self):
2745+
# test no elision on broadcast to higher dimension
2746+
# only triggers elision code path in debug mode as triggering it in
2747+
# normal mode needs 256kb large matching dimension, so a lot of memory
2748+
d = np.ones((2000, 1), dtype=int)
2749+
b = np.ones((2000), dtype=np.bool)
2750+
r = (1 - d) + b
2751+
assert_equal(r, 1)
2752+
assert_equal(r.shape, (2000, 2000))
2753+
27442754
def test_ufunc_override_rop_precedence(self):
27452755
# 2016-01-29: NUMPY_UFUNC_DISABLED
27462756
return

0 commit comments

Comments
 (0)
0