diff --git a/Lib/test/test_decimal.py b/Lib/test/test_decimal.py index 02d3fa985e75b9..0e15eb3693888d 100644 --- a/Lib/test/test_decimal.py +++ b/Lib/test/test_decimal.py @@ -4481,6 +4481,15 @@ def test_implicit_context(self): self.assertIs(Decimal("NaN").fma(7, 1).is_nan(), True) # three arg power self.assertEqual(pow(Decimal(10), 2, 7), 2) + if self.decimal == C: + self.assertEqual(pow(10, Decimal(2), 7), 2) + self.assertEqual(pow(10, 2, Decimal(7)), 2) + else: + # XXX: Three-arg power doesn't use __rpow__. + self.assertRaises(TypeError, pow, 10, Decimal(2), 7) + # XXX: There is no special method to dispatch on the + # third arg of three-arg power. + self.assertRaises(TypeError, pow, 10, 2, Decimal(7)) # exp self.assertEqual(Decimal("1.01").exp(), 3) # is_normal diff --git a/Misc/NEWS.d/next/Library/2025-02-17-21-16-51.gh-issue-130230.9ta9P9.rst b/Misc/NEWS.d/next/Library/2025-02-17-21-16-51.gh-issue-130230.9ta9P9.rst new file mode 100644 index 00000000000000..20327fd5f25b43 --- /dev/null +++ b/Misc/NEWS.d/next/Library/2025-02-17-21-16-51.gh-issue-130230.9ta9P9.rst @@ -0,0 +1 @@ +Fix crash in :func:`pow` with only :class:`~decimal.Decimal` third argument. diff --git a/Modules/_decimal/_decimal.c b/Modules/_decimal/_decimal.c index 3dcb3e9870c8a4..8a24d8c12cab8a 100644 --- a/Modules/_decimal/_decimal.c +++ b/Modules/_decimal/_decimal.c @@ -147,6 +147,24 @@ find_state_left_or_right(PyObject *left, PyObject *right) return (decimal_state *)state; } +static inline decimal_state * +find_state_ternary(PyObject *left, PyObject *right, PyObject *modulus) +{ + PyTypeObject *base; + if (PyType_GetBaseByToken(Py_TYPE(left), &dec_spec, &base) != 1) { + assert(!PyErr_Occurred()); + if (PyType_GetBaseByToken(Py_TYPE(right), &dec_spec, &base) != 1) { + assert(!PyErr_Occurred()); + PyType_GetBaseByToken(Py_TYPE(modulus), &dec_spec, &base); + } + } + assert(base != NULL); + void *state = _PyType_GetModuleState(base); + assert(state != NULL); + Py_DECREF(base); + return (decimal_state *)state; +} + #if !defined(MPD_VERSION_HEX) || MPD_VERSION_HEX < 0x02050000 #error "libmpdec version >= 2.5.0 required" @@ -4407,7 +4425,7 @@ nm_mpd_qpow(PyObject *base, PyObject *exp, PyObject *mod) PyObject *context; uint32_t status = 0; - decimal_state *state = find_state_left_or_right(base, exp); + decimal_state *state = find_state_ternary(base, exp, mod); CURRENT_CONTEXT(state, context); CONVERT_BINOP(&a, &b, base, exp, context);