10000 Merge pull request #19269 from charris/backport-19228 · numpy/numpy@143d45f · GitHub
[go: up one dir, main page]

Skip to content

Commit 143d45f

Browse files
authored
Merge pull request #19269 from charris/backport-19228
BUG: Invalid dtypes comparison should not raise TypeError
2 parents a070e5d + d80e473 commit 143d45f

File tree

3 files changed

+21
-3
lines changed

3 files changed

+21
-3
lines changed

numpy/__init__.pyi

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1075,8 +1075,6 @@ class dtype(Generic[_DTypeScalar_co]):
10751075
# literals as of mypy 0.800. Set the return-type to `Any` for now.
10761076
def __rmul__(self, value: int) -> Any: ...
10771077

1078-
def __eq__(self, other: DTypeLike) -> bool: ...
1079-
def __ne__(self, other: DTypeLike) -> bool: ...
10801078
def __gt__(self, other: DTypeLike) -> bool: ...
10811079
def __ge__(self, other: DTypeLike) -> bool: ...
10821080
def __lt__(self, other: DTypeLike) -> bool: ...

numpy/core/src/multiarray/descriptor.c

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3228,7 +3228,9 @@ arraydescr_richcompare(PyArray_Descr *self, PyObject *other, int cmp_op)
32283228
{
32293229
PyArray_Descr *new = _convert_from_any(other, 0);
32303230
if (new == NULL) {
3231-
return NULL;
3231+
/* Cannot convert `other` to dtype */
3232+
PyErr_Clear();
3233+
Py_RETURN_NOTIMPLEMENTED;
32323234
}
32333235

32343236
npy_bool ret;

numpy/core/tests/test_dtype.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,24 @@ def test_invalid_types(self):
8888
assert_raises(TypeError, np.dtype, 'q8')
8989
assert_raises(TypeError, np.dtype, 'Q8')
9090

91+
def test_richcompare_invalid_dtype_equality(self):
92+
# Make sure objects that cannot be converted to valid
93+
# dtypes results in False/True when compared to valid dtypes.
94+
# Here 7 cannot be converted to dtype. No exceptions should be raised
95+
96+
assert not np.dtype(np.int32) == 7, "dtype richcompare failed for =="
97+
assert np.dtype(np.int32) != 7, "dtype richcompare failed for !="
98+
99+
@pytest.mark.parametrize(
100+
'operation',
101+
[operator.le, operator.lt, operator.ge, operator.gt])
102+
def test_richcompare_invalid_dtype_comparison(self, operation):
103+
# Make sure TypeError is raised for comparison operators
104+
# for invalid dtypes. Here 7 is an invalid dtype.
105+
106+
with pytest.raises(TypeError):
107+
operation(np.dtype(np.int32), 7)
108+
91109
@pytest.mark.parametrize("dtype",
92110
['Bool', 'Complex32', 'Complex64', 'Float16', 'Float32', 'Float64',
93111
'Int8', 'Int16', 'Int32', 'Int64', 'Object0', 'Timedelta64',

0 commit comments

Comments
 (0)
0