8000 Merge pull request #8806 from eric-wieser/error-on-ternary-pow · numpy/numpy@987b95c · GitHub
[go: up one dir, main page]

Skip to content

Commit 987b95c

Browse files
authored
Merge pull request #8806 from eric-wieser/error-on-ternary-pow
BUG: Raise TypeError on ternary power
2 parents a2c4d32 + a7203f9 commit 987b95c

File tree

4 files changed

+52
-12
lines changed

4 files changed

+52
-12
lines changed

numpy/core/src/multiarray/number.c

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -614,10 +614,14 @@ fast_scalar_power(PyArrayObject *a1, PyObject *o2, int inplace)
614614
}
615615

616616
static PyObject *
617-
array_power(PyArrayObject *a1, PyObject *o2, PyObject *NPY_UNUSED(modulo))
617+
array_power(PyArrayObject *a1, PyObject *o2, PyObject *modulo)
618618
{
619-
/* modulo is ignored! */
620619
PyObject *value;
620+
if (modulo != Py_None) {
621+
/* modular exponentiation is not implemented (gh-8804) */
622+
Py_INCREF(Py_NotImplemented);
623+
return Py_NotImplemented;
624+
}
621625
GIVE_UP_IF_HAS_RIGHT_BINOP(a1, o2, "__pow__", "__rpow__", 0, nb_power);
622626
value = fast_scalar_power(a1, o2, 0);
623627
if (!value) {

numpy/core/src/multiarray/scalartypes.c.src

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -149,11 +149,17 @@ gentype_free(PyObject *v)
149149

150150

151151
static PyObject *
152-
gentype_power(PyObject *m1, PyObject *m2, PyObject *NPY_UNUSED(m3))
152+
gentype_power(PyObject *m1, PyObject *m2, PyObject *modulo)
153153
{
154154
PyObject *arr, *ret, *arg2;
155155
char *msg="unsupported operand type(s) for ** or pow()";
156156

157+
if (modulo != Py_None) {
158+
/* modular exponentiation is not implemented (gh-8804) */
159+
Py_INCREF(Py_NotImplemented);
160+
return Py_NotImplemented;
161+
}
162+
157163
if (!PyArray_IsScalar(m1, Generic)) {
158164
if (PyArray_Check(m1)) {
159165
ret = Py_TYPE(m1)->tp_as_number->nb_power(m1,m2, Py_None);

numpy/core/src/umath/scalarmath.c.src

Lines changed: 28 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -956,7 +956,7 @@ static PyObject *
956956

957957
#if @cmplx@
958958
static PyObject *
959-
@name@_power(PyObject *a, PyObject *b, PyObject *NPY_UNUSED(c))
959+
@name@_power(PyObject *a, PyObject *b, PyObject *modulo)
960960
{
961961
PyObject *ret;
962962
@type@ arg1, arg2;
@@ -969,13 +969,13 @@ static PyObject *
969969
break;
970970
case -1:
971971
/* can't cast both safely mixed-types? */
972-
return PyArray_Type.tp_as_number->nb_power(a,b,NULL);
972+
return PyArray_Type.tp_as_number->nb_power(a,b,modulo);
973973
case -2:
974974
/* use default handling */
975975
if (PyErr_Occurred()) {
976976
return NULL;
977977
}
978-
return PyGenericArrType_Type.tp_as_number->nb_power(a,b,NULL);
978+
return PyGenericArrType_Type.tp_as_number->nb_power(a,b,modulo);
979979
case -3:
980980
default:
981981
/*
@@ -986,6 +986,12 @@ static PyObject *
986986
return Py_NotImplemented;
987987
}
988988

989+
if (modulo != Py_None) {
990+
/* modular exponentiation is not implemented (gh-8804) */
991+
Py_INCREF(Py_NotImplemented);
992+
return Py_NotImplemented;
993+
}
994+
989995
PyUFunc_clearfperr();
990996

991997
/*
@@ -1030,7 +1036,7 @@ static PyObject *
10301036
#elif @isint@
10311037

10321038
static PyObject *
1033-
@name@_power(PyObject *a, PyObject *b, PyObject *NPY_UNUSED(c))
1039+
@name@_power(PyObject *a, PyObject *b, PyObject *modulo)
10341040
{
10351041
PyObject *ret;
10361042
@type@ arg1, arg2, out;
@@ -1040,13 +1046,13 @@ static PyObject *
10401046
break;
10411047
case -1:
10421048
/* can't cast both safely mixed-types? */
1043-
return PyArray_Type.tp_as_number->nb_power(a,b,NULL);
1049+
return PyArray_Type.tp_as_number->nb_power(a,b,modulo);
10441050
case -2:
10451051
/* use default handling */
10461052
if (PyErr_Occurred()) {
10471053
return NULL;
10481054
}
1049-
return PyGenericArrType_Type.tp_as_number->nb_power(a,b,NULL);
1055+
return PyGenericArrType_Type.tp_as_number->nb_power(a,b,modulo);
10501056
case -3:
10511057
default:
10521058
/*
@@ -1056,6 +1062,13 @@ static PyObject *
10561062
Py_INCREF(Py_NotImplemented);
10571063
return Py_NotImplemented;
10581064
}
1065+
1066+
if (modulo != Py_None) {
1067+
/* modular exponentiation is not implemented (gh-8804) */
1068+
Py_INCREF(Py_NotImplemented);
1069+
return Py_NotImplemented;
1070+
}
1071+
10591072
PyUFunc_clearfperr();
10601073

10611074
/*
@@ -1081,7 +1094,7 @@ static PyObject *
10811094
#else
10821095

10831096
static PyObject *
1084-
@name@_power(PyObject *a, PyObject *b, PyObject *NPY_UNUSED(c))
1097+
@name@_power(PyObject *a, PyObject *b, PyObject *modulo)
10851098
{
10861099
PyObject *ret;
10871100
@type@ arg1, arg2;
@@ -1094,13 +1107,13 @@ static PyObject *
10941107
break;
10951108
case -1:
10961109
/* can't cast both safely mixed-types? */
1097-
return PyArray_Type.tp_as_number->nb_power(a,b,NULL);
1110+
return PyArray_Type.tp_as_number->nb_power(a,b,modulo);
10981111
case -2:
10991112
/* use default handling */
11001113
if (PyErr_Occurred()) {
11011114
return NULL;
11021115
}
1103-
return PyGenericArrType_Type.tp_as_number->nb_power(a,b,NULL);
1116+
return PyGenericArrType_Type.tp_as_number->nb_power(a,b,modulo);
11041117
case -3:
11051118
default:
11061119
/*
@@ -1111,6 +1124,12 @@ static PyObject *
11111124
return Py_NotImplemented;
11121125
}
11131126

1127+
if (modulo != Py_None) {
1128+
/* modular exponentiation is not implemented (gh-8804) */
1129+
Py_INCREF(Py_NotImplemented);
1130+
return Py_NotImplemented;
1131+
}
1132+
11141133
PyUFunc_clearfperr();
11151134

11161135
/*

numpy/core/tests/test_scalarmath.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,17 @@ def test_mixed_types(self):
177177
else:
178178
assert_almost_equal(result, 9, err_msg=msg)
179179

180+
def test_modular_power(self):
181+
# modular power is not implemented, so ensure it errors
182+
a = 5
183+
b = 4
184+
c = 10
185+
expected = pow(a, b, c)
186+
for t in (np.int32, np.float32, np.complex64):
187+
# note that 3-operand power only dispatches on the first argument
188+
assert_raises(TypeError, operator.pow, t(a), b, c)
189+
assert_raises(TypeError, operator.pow, np.array(t(a)), b, c)
190+
180191

181192
class TestModulus(TestCase):
182193

0 commit comments

Comments
 (0)
0