8000 Merge pull request #11406 from mattip/einsum-out-is-res · numpy/numpy@c4924b0 · GitHub
[go: up one dir, main page]

Skip to content 8000

Commit c4924b0

Browse files
authored
Merge pull request #11406 from mattip/einsum-out-is-res
BUG: ensure ret is out in einsum
2 parents 166b39f + ae5000f commit c4924b0

File tree

2 files changed

+15
-18
lines changed

2 files changed

+15
-18
lines changed

numpy/core/src/multiarray/einsum.c.src

Lines changed: 10 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -2767,11 +2767,11 @@ PyArray_EinsteinSum(char *subscripts, npy_intp nop,
27672767
goto fail;
27682768
}
27692769

2770-
/* Initialize the output to all zeros and reset the iterator */
2770+
/* Initialize the output to all zeros */
27712771
ret = NpyIter_GetOperandArray(iter)[nop];
2772-
Py_INCREF(ret);
2773-
PyArray_AssignZero(ret, NULL);
2774-
2772+
if (PyArray_AssignZero(ret, NULL) < 0) {
2773+
goto fail;
2774+
}
27752775

27762776
/***************************/
27772777
/*
@@ -2785,16 +2785,12 @@ PyArray_EinsteinSum(char *subscripts, npy_intp nop,
27852785
case 1:
27862786
if (ndim == 2) {
27872787
if (unbuffered_loop_nop1_ndim2(iter) < 0) {
2788-
Py_DECREF(ret);
2789-
ret = NULL;
27902788
goto fail;
27912789
}
27922790
goto finish;
27932791
}
27942792
else if (ndim == 3) {
27952793
if (unbuffered_loop_nop1_ndim3(iter) < 0) {
2796-
Py_DECREF(ret);
2797-
ret = NULL;
27982794
goto fail;
27992795
}
28002796
goto finish;
@@ -2803,16 +2799,12 @@ PyArray_EinsteinSum(char *subscripts, npy_intp nop,
28032799
case 2:
28042800
if (ndim == 2) {
28052801
if (unbuffered_loop_nop2_ndim2(iter) < 0) {
2806-
Py_DECREF(ret);
2807-
ret = NULL;
28082802
goto fail;
28092803
}
28102804
goto finish;
28112805
}
28122806
else if (ndim == 3) {
28132807
if (unbuffered_loop_nop2_ndim3(iter) < 0) {
2814-
Py_DECREF(ret);
2815-
ret = NULL;
28162808
goto fail;
28172809
}
28182810
goto finish;
@@ -2823,7 +2815,6 @@ PyArray_EinsteinSum(char *subscripts, npy_intp nop,
28232815
/***************************/
28242816

28252817
if (NpyIter_Reset(iter, NULL) != NPY_SUCCEED) {
2826-
Py_DECREF(ret);
28272818
goto fail;
28282819
}
28292820

@@ -2845,8 +2836,6 @@ PyArray_EinsteinSum(char *subscripts, npy_intp nop,
28452836
if (sop == NULL) {
28462837
PyErr_SetString(PyExc_TypeError,
28472838
"invalid data type for einsum");
2848-
Py_DECREF(ret);
2849-
ret = NULL;
28502839
}
28512840
else if (NpyIter_GetIterSize(iter) != 0) {
28522841
NpyIter_IterNextFunc *iternext;
@@ -2858,7 +2847,6 @@ PyArray_EinsteinSum(char *subscripts, npy_intp nop,
28582847
iternext = NpyIter_GetIterNext(iter, NULL);
28592848
if (iternext == NULL) {
28602849
NpyIter_Deallocate(iter);
2861-
Py_DECREF(ret);
28622850
goto fail;
28632851
}
28642852
dataptr = NpyIter_GetDataPtrArray(iter);
@@ -2874,12 +2862,16 @@ PyArray_EinsteinSum(char *subscripts, npy_intp nop,
28742862

28752863
/* If the API was needed, it may have thrown an error */
28762864
if (NpyIter_IterationNeedsAPI(iter) && PyErr_Occurred()) {
2877-
Py_DECREF(ret);
2878-
ret = NULL;
2865+
goto fail;
28792866
}
28802867
}
28812868

28822869
finish:
2870+
if (out != NULL) {
2871+
ret = out;
2872+
}
2873+
Py_INCREF(ret);
2874+
28832875
NpyIter_Deallocate(iter);
28842876
for (iop = 0; iop < nop; ++iop) {
28852877
Py_DECREF(op[iop]);

numpy/core/tests/test_einsum.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -730,6 +730,11 @@ def test_small_boolean_arrays(self):
730730
res = np.einsum('...ij,...jk->...ik', a, a, out=out)
731731
assert_equal(res, tgt)
732732

733+
def test_out_is_res(self):
734+
a = np.arange(9).reshape(3, 3)
735+
res = np.einsum('...ij,...jk->...ik', a, a, out=a)
736+
assert res is a
737+
733738
def optimize_compare(self, string):
734739
# Tests all paths of the optimization function against
735740
# conventional einsum

0 commit comments

Comments
 (0)
0