From aca35f74b0ff0798508af6bab47810cc96ce33c6 Mon Sep 17 00:00:00 2001 From: Nathan Goldbaum Date: Thu, 14 Mar 2024 12:39:21 -0600 Subject: [PATCH 1/2] BUG: raise error trying to coerce timedelta64('NaT') --- numpy/_core/src/multiarray/stringdtype/dtype.c | 3 +++ numpy/_core/tests/test_stringdtype.py | 6 ++++++ 2 files changed, 9 insertions(+) diff --git a/numpy/_core/src/multiarray/stringdtype/dtype.c b/numpy/_core/src/multiarray/stringdtype/dtype.c index 807184c3c26a..be6c460b5e3f 100644 --- a/numpy/_core/src/multiarray/stringdtype/dtype.c +++ b/numpy/_core/src/multiarray/stringdtype/dtype.c @@ -638,6 +638,9 @@ stringdtype_is_known_scalar_type(PyArray_DTypeMeta *NPY_UNUSED(cls), if (pytype == &PyDatetimeArrType_Type) { return 1; } + if (pytype == &PyTimedeltaArrType_Type) { + return 1; + } return 0; } diff --git a/numpy/_core/tests/test_stringdtype.py b/numpy/_core/tests/test_stringdtype.py index b856c667c021..9f53ce1cb734 100644 --- a/numpy/_core/tests/test_stringdtype.py +++ b/numpy/_core/tests/test_stringdtype.py @@ -916,6 +916,12 @@ def test_nat_casts(): np.array([output_object]*arr.size, dtype=dtype)) +def test_nat_conversion(): + for nat in [np.datetime64("NaT", "s"), np.timedelta64("NaT", "s")]: + with pytest.raises(ValueError, match="string coercion is disabled"): + np.array(["a", nat], dtype=StringDType(coerce=False)) + + def test_growing_strings(dtype): # growing a string leads to a heap allocation, this tests to make sure # we do that bookkeeping correctly for all possible starting cases From d2e36f2fcec883cf1dfdbc001a1ff668392c3ed6 Mon Sep 17 00:00:00 2001 From: Nathan Goldbaum Date: Thu, 14 Mar 2024 14:23:50 -0600 Subject: [PATCH 2/2] MNT: eliminate branching and duplication in is_known_scalar_type --- numpy/_core/src/multiarray/dtypemeta.c | 23 ++-- .../_core/src/multiarray/stringdtype/dtype.c | 107 +++++------------- 2 files changed, 33 insertions(+), 97 deletions(-) diff --git a/numpy/_core/src/multiarray/dtypemeta.c b/numpy/_core/src/multiarray/dtypemeta.c index 626b3bde1032..acee68bad54f 100644 --- a/numpy/_core/src/multiarray/dtypemeta.c +++ b/numpy/_core/src/multiarray/dtypemeta.c @@ -838,22 +838,13 @@ python_builtins_are_known_scalar_types( * This is necessary only for python scalar classes which we discover * as valid DTypes. */ - if (pytype == &PyFloat_Type) { - return 1; - } - if (pytype == &PyLong_Type) { - return 1; - } - if (pytype == &PyBool_Type) { - return 1; - } - if (pytype == &PyComplex_Type) { - return 1; - } - if (pytype == &PyUnicode_Type) { - return 1; - } - if (pytype == &PyBytes_Type) { + if (pytype == &PyFloat_Type || + pytype == &PyLong_Type || + pytype == &PyBool_Type || + pytype == &PyComplex_Type || + pytype == &PyUnicode_Type || + pytype == &PyBytes_Type) + { return 1; } return 0; diff --git a/numpy/_core/src/multiarray/stringdtype/dtype.c b/numpy/_core/src/multiarray/stringdtype/dtype.c index be6c460b5e3f..252888150d6e 100644 --- a/numpy/_core/src/multiarray/stringdtype/dtype.c +++ b/numpy/_core/src/multiarray/stringdtype/dtype.c @@ -554,91 +554,36 @@ stringdtype_get_clear_loop(void *NPY_UNUSED(traverse_context), } static int -stringdtype_is_known_scalar_type(PyArray_DTypeMeta *NPY_UNUSED(cls), +stringdtype_is_known_scalar_type(PyArray_DTypeMeta *cls, PyTypeObject *pytype) { - if (pytype == &PyFloat_Type) { + if (python_builtins_are_known_scalar_types(cls, pytype)) { return 1; } - if (pytype == &PyLong_Type) { - return 1; - } - if (pytype == &PyBool_Type) { - return 1; - } - if (pytype == &PyComplex_Type) { - return 1; - } - if (pytype == &PyUnicode_Type) { - return 1; - } - if (pytype == &PyBytes_Type) { - return 1; - } - if (pytype == &PyBoolArrType_Type) { - return 1; - } - if (pytype == &PyByteArrType_Type) { - return 1; - } - if (pytype == &PyShortArrType_Type) { - return 1; - } - if (pytype == &PyIntArrType_Type) { - return 1; - } - if (pytype == &PyLongArrType_Type) { - return 1; - } - if (pytype == &PyLongLongArrType_Type) { - return 1; - } - if (pytype == &PyUByteArrType_Type) { - return 1; - } - if (pytype == &PyUShortArrType_Type) { - return 1; - } - if (pytype == &PyUIntArrType_Type) { - return 1; - } - if (pytype == &PyULongArrType_Type) { - return 1; - } - if (pytype == &PyULongLongArrType_Type) { - return 1; - } - if (pytype == &PyHalfArrType_Type) { - return 1; - } - if (pytype == &PyFloatArrType_Type) { - return 1; - } - if (pytype == &PyDoubleArrType_Type) { - return 1; - } - if (pytype == &PyLongDoubleArrType_Type) { - return 1; - } - if (pytype == &PyCFloatArrType_Type) { - return 1; - } - if (pytype == &PyCDoubleArrType_Type) { - return 1; - } - if (pytype == &PyCLongDoubleArrType_Type) { - return 1; - } - if (pytype == &PyIntpArrType_Type) { - return 1; - } - if (pytype == &PyUIntpArrType_Type) { - return 1; - } - if (pytype == &PyDatetimeArrType_Type) { - return 1; - } - if (pytype == &PyTimedeltaArrType_Type) { + // accept every built-in numpy dtype + else if (pytype == &PyBoolArrType_Type || + pytype == &PyByteArrType_Type || + pytype == &PyShortArrType_Type || + pytype == &PyIntArrType_Type || + pytype == &PyLongArrType_Type || + pytype == &PyLongLongArrType_Type || + pytype == &PyUByteArrType_Type || + pytype == &PyUShortArrType_Type || + pytype == &PyUIntArrType_Type || + pytype == &PyULongArrType_Type || + pytype == &PyULongLongArrType_Type || + pytype == &PyHalfArrType_Type || + pytype == &PyFloatArrType_Type || + pytype == &PyDoubleArrType_Type || + pytype == &PyLongDoubleArrType_Type || + pytype == &PyCFloatArrType_Type || + pytype == &PyCDoubleArrType_Type || + pytype == &PyCLongDoubleArrType_Type || + pytype == &PyIntpArrType_Type || + pytype == &PyUIntpArrType_Type || + pytype == &PyDatetimeArrType_Type || + pytype == &PyTimedeltaArrType_Type) + { return 1; } return 0;