8000 Merge pull request #26905 from seberg/strict-typing-nep50 · numpy/numpy@0819378 · GitHub
[go: up one dir, main page]

Skip to content

Commit 0819378

Browse files
authored
Merge pull request #26905 from seberg/strict-typing-nep50
API: Do not consider subclasses for NEP 50 weak promotion
2 parents a87fb26 + 1e92917 commit 0819378

File tree

4 files changed

+25
-47
lines changed

4 files changed

+25
-47
lines changed

numpy/_core/src/multiarray/abstractdtypes.h

Lines changed: 3 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -41,34 +41,23 @@ static inline int
4141
npy_mark_tmp_array_if_pyscalar(
4242
PyObject *obj, PyArrayObject *arr, PyArray_DTypeMeta **dtype)
4343
{
44-
/*
45-
* We check the array dtype for two reasons: First, booleans are
46-
* integer subclasses. Second, an int, float, or complex could have
47-
* a custom DType registered, and then we should use that.
48-
* Further, `np.float64` is a double subclass, so must reject it.
49-
*/
50-
// TODO,NOTE: This function should be changed to do exact long checks
51-
// For NumPy 2.1!
52-
if (PyLong_Check(obj)
53-
&& (PyArray_ISINTEGER(arr) || PyArray_ISOBJECT(arr))) {
44+
if (PyLong_CheckExact(obj)) {
5445
((PyArrayObject_fields *)arr)->flags |= NPY_ARRAY_WAS_PYTHON_INT;
5546
if (dtype != NULL) {
5647
Py_INCREF(&PyArray_PyLongDType);
5748
Py_SETREF(*dtype, &PyArray_PyLongDType);
5849
}
5950
return 1;
6051
}
61-
else if (PyFloat_Check(obj) && !PyArray_IsScalar(obj, Double)
62-
&& PyArray_TYPE(arr) == NPY_DOUBLE) {
52+
else if (PyFloat_CheckExact(obj)) {
6353
((PyArrayObject_fields *)arr)->flags |= NPY_ARRAY_WAS_PYTHON_FLOAT;
6454
if (dtype != NULL) {
6555
Py_INCREF(&PyArray_PyFloatDType);
6656
Py_SETREF(*dtype, &PyArray_PyFloatDType);
6757
}
6858
return 1;
6959
}
70-
else if (PyComplex_Check(obj) && !PyArray_IsScalar(obj, CDouble)
71-
&& PyArray_TYPE(arr) == NPY_CDOUBLE) {
60+
else if (PyComplex_CheckExact(obj)) {
7261
((PyArrayObject_fields *)arr)->flags |= NPY_ARRAY_WAS_PYTHON_COMPLEX;
7362
if (dtype != NULL) {
7463
Py_INCREF(&PyArray_PyComplexDType);

numpy/_core/src/multiarray/scalartypes.c.src

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -191,9 +191,9 @@ find_binary_operation_path(
191191
*self_op = NULL;
192192

193193
if (PyArray_IsScalar(other, Generic) ||
194-
PyLong_Check(other) ||
195-
PyFloat_Check(other) ||
196-
PyComplex_Check(other) ||
194+
PyLong_CheckExact(other) ||
195+
PyFloat_CheckExact(other) ||
196+
PyComplex_CheckExact(other) ||
197197
PyBool_Check(other)) {
198198
/*
199199
* The other operand is ready for the operation already. Must pass on

numpy/_core/src/umath/scalarmath.c.src

Lines changed: 6 additions & 24 deletions
10000
Original file line numberDiff line numberDiff line change
@@ -954,15 +954,7 @@ convert_to_@name@(PyObject *value, @type@ *result, npy_bool *may_need_deferring)
954954
return CONVERSION_SUCCESS;
955955
}
956956

957-
if (PyFloat_Check(value)) {
958-
if (!PyFloat_CheckExact(value)) {
959-
/* A NumPy double is a float subclass, but special. */
960-
if (PyArray_IsScalar(value, Double)) {
961-
descr = PyArray_DescrFromType(NPY_DOUBLE);
962-
goto numpy_scalar;
963-
}
964-
*may_need_deferring = NPY_TRUE;
965-
}
957+
if (PyFloat_CheckExact(value)) {
966958
if (!IS_SAFE(NPY_DOUBLE, NPY_@TYPE@)) {
967959
if (get_npy_promotion_state() != NPY_USE_WEAK_PROMOTION) {
968960
/* Legacy promotion and weak-and-warn not handled here */
@@ -978,10 +970,7 @@ convert_to_@name@(PyObject *value, @type@ *result, npy_bool *may_need_deferring)
978970
return CONVERSION_SUCCESS;
979971
}
980972

981-
if (PyLong_Check(value)) {
982-
if (!PyLong_CheckExact(value)) {
983-
*may_need_deferring = NPY_TRUE;
984-
}
973+
if (PyLong_CheckExact(value)) {
985974
if (!IS_SAFE(NPY_LONG, NPY_@TYPE@)) {
986975
/*
987976
* long -> (c)longdouble is safe, so `OTHER_IS_UNKNOWN_OBJECT` will
@@ -1009,15 +998,7 @@ convert_to_@name@(PyObject *value, @type@ *result, npy_bool *may_need_deferring)
1009998
return CONVERSION_SUCCESS;
1010999
}
10111000

1012-
if (PyComplex_Check(value)) {
1013-
if (!PyComplex_CheckExact(value)) {
1014-
/* A NumPy complex double is a float subclass, but special. */
1015-
if (PyArray_IsScalar(value, CDouble)) {
1016-
descr = PyArray_DescrFromType(NPY_CDOUBLE);
1017-
goto numpy_scalar;
1018-
}
1019-
*may_need_deferring = NPY_TRUE;
1020-
}
1001+
if (PyComplex_CheckExact(value)) {
10211002
if (!IS_SAFE(NPY_CDOUBLE, NPY_@TYPE@)) {
10221003
if (get_npy_promotion_state() != NPY_USE_WEAK_PROMOTION) {
10231004
/* Legacy promotion and weak-and-warn not handled here */
@@ -1079,7 +1060,6 @@ convert_to_@name@(PyObject *value, @type@ *result, npy_bool *may_need_deferring)
10791060
return OTHER_IS_UNKNOWN_OBJECT;
10801061
}
10811062

1082-
numpy_scalar:
10831063
if (descr->typeobj != Py_TYPE(value)) {
10841064
/*
10851065
* This is a subclass of a builtin type, we may continue normally,
@@ -1409,7 +1389,8 @@ static PyObject *
14091389
npy_bool may_need_deferring;
14101390
conversion_result res = convert_to_@name@(
14111391
other, &other_val_conv, &may_need_deferring);
1412-
other_val = other_val_conv; /* Need a float value */
1392+
/* Actual float cast `other_val` is set below on success. */
1393+
14131394
if (res == CONVERSION_ERROR) {
14141395
return NULL; /* an error occurred (should never happen) */
14151396
}
@@ -1420,6 +1401,7 @@ static PyObject *
14201401
case DEFER_TO_OTHER_KNOWN_SCALAR:
14211402
Py_RETURN_NOTIMPLEMENTED;
14221403
case CONVERSION_SUCCESS:
1404+
other_val = other_val_conv; /* Need a float value */
14231405
break; /* successfully extracted value we can proceed */
14241406
case OTHER_IS_UNKNOWN_OBJECT:
14251407
case PROMOTION_REQUIRED:

numpy/_core/tests/test_scalarmath.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1073,6 +1073,9 @@ def test_longdouble_complex():
10731073
@pytest.mark.parametrize("subtype", [float, int, complex, np.float16])
10741074
@np._no_nep50_warning()
10751075
def test_pyscalar_subclasses(subtype, __op__, __rop__, op, cmp):
1076+
# This tests that python scalar subclasses behave like a float64 (if they
1077+
# don't override it).
1078+
# In an earlier version of NEP 50, they behaved like the Python buildins.
10761079
def op_func(self, other):
10771080
return __op__
10781081

@@ -1095,25 +1098,29 @@ def rop_func(self, other):
10951098

10961099
# When no deferring is indicated, subclasses are handled normally.
10971100
myt = type("myt", (subtype,), {__rop__: rop_func})
1101+
behaves_like = lambda x: np.array(subtype(x))[()]
10981102

10991103
# Check for float32, as a float subclass float64 may behave differently
11001104
res = op(myt(1), np.float16(2))
1101-
expected = op(subtype(1), np.float16(2))
1105+
expected = op(behaves_like(1), np.float16(2))
11021106
assert res == expected
11031107
assert type(res) == type(expected)
11041108
res = op(np.float32(2), myt(1))
1105-
expected = op(np.float32(2), subtype(1))
1109+
expected = op(np.float32(2), behaves_like(1))
11061110
assert res == expected
11071111
assert type(res) == type(expected)
11081112

1109-
# Same check for longdouble:
1113+
# Same check for longdouble (compare via dtype to accept float64 when
1114+
# longdouble has the identical size), which is currently not perfectly
1115+
# consistent.
11101116
res = op(myt(1), np.longdouble(2))
1111-
expected = op(subtype(1), np.longdouble(2))
1117+
expected = op(behaves_like(1), np.longdouble(2))
11121118
assert res == expected
1113-
assert type(res) == type(expected)
1119+
assert np.dtype(type(res)) == np.dtype(type(expected))
11141120
res = op(np.float32(2), myt(1))
1115-
expected = op(np.longdouble(2), subtype(1))
1121+
expected = op(np.float32(2), behaves_like(1))
11161122
assert res == expected
1123+
assert np.dtype(type(res)) == np.dtype(type(expected))
11171124

11181125

11191126
def test_truediv_int():

0 commit comments

Comments
 (0)
0