From 50dce51e9fcbf5d80bfdc1bfe2b41fc9a9e9f9cc Mon Sep 17 00:00:00 2001 From: Sebastian Berg Date: Mon, 21 Dec 2020 16:46:25 -0600 Subject: [PATCH] BUG: Fix concatenation when the output is "S" or "U" Previously, the dtype was used, this now assumes that we want to cast to a string of (unknown) length. This is a simplified version of what happens in `np.array()` or `arr.astype()` (it does never inspect the values, e.g. for object casts). This is more complex as I would like, and with the refactor of ResultType and similar can be cleaned up a bit more hopefully. Note that currently, object to "S" or "U" casts simply return length 64 strings, but with the new version, this will be an error (although the error message probably needs improvement). This is a behaviour inherited from other places however. --- numpy/core/src/multiarray/convert_datatype.c | 67 ++++++++++++++++++++ numpy/core/src/multiarray/convert_datatype.h | 4 ++ numpy/core/src/multiarray/multiarraymodule.c | 57 ++++++----------- numpy/core/tests/test_shape_base.py | 28 +++++++- 4 files changed, 116 insertions(+), 40 deletions(-) diff --git a/numpy/core/src/multiarray/convert_datatype.c b/numpy/core/src/multiarray/convert_datatype.c index f9dd35a73e18..5d5b69bd5c5b 100644 --- a/numpy/core/src/multiarray/convert_datatype.c +++ b/numpy/core/src/multiarray/convert_datatype.c @@ -871,6 +871,73 @@ PyArray_CastDescrToDType(PyArray_Descr *descr, PyArray_DTypeMeta *given_DType) } +/* + * Helper to find the target descriptor for multiple arrays given an input + * one that may be a DType class (e.g. "U" or "S"). + * Works with arrays, since that is what `concatenate` works with. However, + * unlike `np.array(...)` or `arr.astype()` we will never inspect the array's + * content, which means that object arrays can only be cast to strings if a + * fixed width is provided (same for string -> generic datetime). + * + * As this function uses `PyArray_ExtractDTypeAndDescriptor`, it should + * eventually be refactored to move the step to an earlier point. + */ +NPY_NO_EXPORT PyArray_Descr * +PyArray_FindConcatenationDescriptor( + npy_intp n, PyArrayObject **arrays, PyObject *requested_dtype) +{ + if (requested_dtype == NULL) { + return PyArray_ResultType(n, arrays, 0, NULL); + } + + PyArray_DTypeMeta *common_dtype; + PyArray_Descr *result = NULL; + if (PyArray_ExtractDTypeAndDescriptor( + requested_dtype, &result, &common_dtype) < 0) { + return NULL; + } + if (result != NULL) { + if (result->subarray != NULL) { + PyErr_Format(PyExc_TypeError, + "The dtype `%R` is not a valid dtype for concatenation " + "since it is a subarray dtype (the subarray dimensions " + "would be added as array dimensions).", result); + Py_DECREF(result); + return NULL; + } + goto finish; + } + assert(n > 0); /* concatenate requires at least one array input. */ + PyArray_Descr *descr = PyArray_DESCR(arrays[0]); + result = PyArray_CastDescrToDType(descr, common_dtype); + if (result == NULL || n == 1) { + goto finish; + } + /* + * This could short-cut a bit, calling `common_instance` directly and/or + * returning the `default_descr()` directly. Avoiding that (for now) as + * it would duplicate code from `PyArray_PromoteTypes`. + */ + for (npy_intp i = 1; i < n; i++) { + descr = PyArray_DESCR(arrays[i]); + PyArray_Descr *curr = PyArray_CastDescrToDType(descr, common_dtype); + if (curr == NULL) { + Py_SETREF(result, NULL); + goto finish; + } + Py_SETREF(result, PyArray_PromoteTypes(result, curr)); + Py_DECREF(curr); + if (result == NULL) { + goto finish; + } + } + + finish: + Py_DECREF(common_dtype); + return result; +} + + /** * This function defines the common DType operator. * diff --git a/numpy/core/src/multiarray/convert_datatype.h b/numpy/core/src/multiarray/convert_datatype.h index cc1930f77db2..97006b952543 100644 --- a/numpy/core/src/multiarray/convert_datatype.h +++ b/numpy/core/src/multiarray/convert_datatype.h @@ -49,6 +49,10 @@ npy_set_invalid_cast_error( NPY_NO_EXPORT PyArray_Descr * PyArray_CastDescrToDType(PyArray_Descr *descr, PyArray_DTypeMeta *given_DType); +NPY_NO_EXPORT PyArray_Descr * +PyArray_FindConcatenationDescriptor( + npy_intp n, PyArrayObject **arrays, PyObject *requested_dtype); + NPY_NO_EXPORT int PyArray_AddCastingImplmentation(PyBoundArrayMethodObject *meth); diff --git a/numpy/core/src/multiarray/multiarraymodule.c b/numpy/core/src/multiarray/multiarraymodule.c index 32c5ac0dc20c..c2c139798a18 100644 --- a/numpy/core/src/multiarray/multiarraymodule.c +++ b/numpy/core/src/multiarray/multiarraymodule.c @@ -448,17 +448,10 @@ PyArray_ConcatenateArrays(int narrays, PyArrayObject **arrays, int axis, /* Get the priority subtype for the array */ PyTypeObject *subtype = PyArray_GetSubType(narrays, arrays); - - if (dtype == NULL) { - /* Get the resulting dtype from combining all the arrays */ - dtype = (PyArray_Descr *)PyArray_ResultType( - narrays, arrays, 0, NULL); - if (dtype == NULL) { - return NULL; - } - } - else { - Py_INCREF(dtype); + PyArray_Descr *descr = PyArray_FindConcatenationDescriptor( + narrays, arrays, (PyObject *)dtype); + if (descr == NULL) { + return NULL; } /* @@ -467,7 +460,7 @@ PyArray_ConcatenateArrays(int narrays, PyArrayObject **arrays, int axis, * resolution rules matching that of the NpyIter. */ PyArray_CreateMultiSortedStridePerm(narrays, arrays, ndim, strideperm); - s = dtype->elsize; + s = descr->elsize; for (idim = ndim-1; idim >= 0; --idim) { int iperm = strideperm[idim]; strides[iperm] = s; @@ -475,17 +468,13 @@ PyArray_ConcatenateArrays(int narrays, PyArrayObject **arrays, int axis, } /* Allocate the array for the result. This steals the 'dtype' reference. */ - ret = (PyArrayObject *)PyArray_NewFromDescr(subtype, - dtype, - ndim, - shape, - strides, - NULL, - 0, - NULL); + ret = (PyArrayObject *)PyArray_NewFromDescr_int( + subtype, descr, ndim, shape, strides, NULL, 0, NULL, + NULL, 0, 1); if (ret == NULL) { return NULL; } + assert(PyArray_DESCR(ret) == descr); } /* @@ -575,32 +564,22 @@ PyArray_ConcatenateFlattenedArrays(int narrays, PyArrayObject **arrays, /* Get the priority subtype for the array */ PyTypeObject *subtype = PyArray_GetSubType(narrays, arrays); - if (dtype == NULL) { - /* Get the resulting dtype from combining all the arrays */ - dtype = (PyArray_Descr *)PyArray_ResultType( - narrays, arrays, 0, NULL); - if (dtype == NULL) { - return NULL; - } - } - else { - Py_INCREF(dtype); + PyArray_Descr *descr = PyArray_FindConcatenationDescriptor( + narrays, arrays, (PyObject *)dtype); + if (descr == NULL) { + return NULL; } - stride = dtype->elsize; + stride = descr->elsize; /* Allocate the array for the result. This steals the 'dtype' reference. */ - ret = (PyArrayObject *)PyArray_NewFromDescr(subtype, - dtype, - 1, - &shape, - &stride, - NULL, - 0, - NULL); + ret = (PyArrayObject *)PyArray_NewFromDescr_int( + subtype, descr, 1, &shape, &stride, NULL, 0, NULL, + NULL, 0, 1); if (ret == NULL) { return NULL; } + assert(PyArray_DESCR(ret) == descr); } /* diff --git a/numpy/core/tests/test_shape_base.py b/numpy/core/tests/test_shape_base.py index 4e56ace90fb6..9922c91731f2 100644 --- a/numpy/core/tests/test_shape_base.py +++ b/numpy/core/tests/test_shape_base.py @@ -343,7 +343,7 @@ def test_bad_out_shape(self): concatenate((a, b), out=np.empty(4)) @pytest.mark.parametrize("axis", [None, 0]) - @pytest.mark.parametrize("out_dtype", ["c8", "f4", "f8", ">f8", "i8"]) + @pytest.mark.parametrize("out_dtype", ["c8", "f4", "f8", ">f8", "i8", "S4"]) @pytest.mark.parametrize("casting", ['no', 'equiv', 'safe', 'same_kind', 'unsafe']) def test_out_and_dtype(self, axis, out_dtype, casting): @@ -369,6 +369,32 @@ def test_out_and_dtype(self, axis, out_dtype, casting): with assert_raises(TypeError): concatenate(to_concat, out=out, dtype=out_dtype, axis=axis) + @pytest.mark.parametrize("axis", [None, 0]) + @pytest.mark.parametrize("string_dt", ["S", "U", "S0", "U0"]) + @pytest.mark.parametrize("arrs", + [([0.],), ([0.], [1]), ([0], ["string"], [1.])]) + def test_dtype_with_promotion(self, arrs, string_dt, axis): + # Note that U0 and S0 should be deprecated eventually and changed to + # actually give the empty string result (together with `np.array`) + res = np.concatenate(arrs, axis=axis, dtype=string_dt, casting="unsafe") + assert res.dtype == np.promote_types("d", string_dt) + + @pytest.mark.parametrize("axis", [None, 0]) + def test_string_dtype_does_not_inspect(self, axis): + # The error here currently depends on NPY_USE_NEW_CASTINGIMPL as + # the new version rejects using the "default string length" of 64. + # The new behaviour is better, `np.array()` and `arr.astype()` would + # have to be used instead. (currently only raises due to unsafe cast) + with pytest.raises((ValueError, TypeError)): + np.concatenate(([None], [1]), dtype="S", axis=axis) + with pytest.raises((ValueError, TypeError)): + np.concatenate(([None], [1]), dtype="U", axis=axis) + + @pytest.mark.parametrize("axis", [None, 0]) + def test_subarray_error(self, axis): + with pytest.raises(TypeError, match=".*subarray dtype"): + np.concatenate(([1], [1]), dtype="(2,)i", axis=axis) + def test_stack(): # non-iterable input