8000 Merge pull request #18052 from seberg/concat-with-string-dtype · numpy/numpy@cd50d88 · GitHub
[go: up one dir, main page]

Skip to content

Commit cd50d88

Browse files
authored
Merge pull request #18052 from seberg/concat-with-string-dtype
BUG: Fix concatenation when the output is "S" or "U"
2 parents 073b9b9 + 50dce51 commit cd50d88

File tree

4 files changed

+116
-40
lines changed

4 files changed

+116
-40
lines changed

numpy/core/src/multiarray/convert_datatype.c

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -871,6 +871,73 @@ PyArray_CastDescrToDType(PyArray_Descr *descr, PyArray_DTypeMeta *given_DType)
871871
}
872872

873873

874+
/*
875+
* Helper to find the target descriptor for multiple arrays given an input
876+
* one that may be a DType class (e.g. "U" or "S").
877+
* Works with arrays, since that is what `concatenate` works with. However,
878+
* unlike `np.array(...)` or `arr.astype()` we will never inspect the array's
879+
* content, which means that object arrays can only be cast to strings if a
880+
* fixed width is provided (same for string -> generic datetime).
881+
*
882+
* As this function uses `PyArray_ExtractDTypeAndDescriptor`, it should
883+
* eventually be refactored to move the step to an earlier point.
884+
*/
885+
NPY_NO_EXPORT PyArray_Descr *
886+
PyArray_FindConcatenationDescriptor(
887+
npy_intp n, PyArrayObject **arrays, PyObject *requested_dtype)
888+
{
889+
if (requested_dtype == NULL) {
890+
return PyArray_ResultType(n, arrays, 0, NULL);
891+
}
892+
893+
PyArray_DTypeMeta *common_dtype;
894+
PyArray_Descr *result = NULL;
895+
if (PyArray_ExtractDTypeAndDescriptor(
896+
requested_dtype, &result, &common_dtype) < 0) {
897+
return NULL;
898+
}
899+
if (result != NULL) {
900+
if (result->subarray != NULL) {
901+
PyErr_Format(PyExc_TypeError,
902+
"The dtype `%R` is not a valid dtype for concatenation "
903+
"since it is a subarray dtype (the subarray dimensions "
904+
"would be added as array dimensions).", result);
905+
Py_DECREF(result);
906+
return NULL;
907+
}
908+
goto finish;
909+
}
910+
assert(n > 0); /* concatenate requires at least one array input. */
911+
PyArray_Descr *descr = PyArray_DESCR(arrays[0]);
912+
result = PyArray_CastDescrToDType(descr, common_dtype);
913+
if (result == NULL || n == 1) {
914+
goto finish;
915+
}
916+
/*
917+
* This could short-cut a bit, calling `common_instance` directly and/or
918+
* returning the `default_descr()` directly. Avoiding that (for now) as
919+
* it would duplicate code from `PyArray_PromoteTypes`.
920+
*/
921+
for (npy_intp i = 1; i < n; i++) {
922+
descr = PyArray_DESCR(arrays[i]);
923+
PyArray_Descr *curr = PyArray_CastDescrToDType(descr, common_dtype);
924+
if (curr == NULL) {
925+
Py_SETREF(result, NULL);
926+
goto finish;
927+
}
928+
Py_SETREF(result, PyArray_PromoteTypes(result, curr));
929+
Py_DECREF(curr);
930+
if (result == NULL) {
931+
goto finish;
932+
}
933+
}
934+
935+
finish:
936+
Py_DECREF(common_dtype);
937+
return result;
938+
}
939+
940+
874941
/**
875942
* This function defines the common DType operator.
876943
*

numpy/core/src/multiarray/convert_datatype.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,10 @@ npy_set_invalid_cast_error(
4949
NPY_NO_EXPORT PyArray_Descr *
5050
PyArray_CastDescrToDType(PyArray_Descr *descr, PyArray_DTypeMeta *given_DType);
5151

52+
NPY_NO_EXPORT PyArray_Descr *
53+
PyArray_FindConcatenationDescriptor(
54+
npy_intp n, PyArrayObject **arrays, PyObject *requested_dtype);
55+
5256
NPY_NO_EXPORT int
5357
PyArray_AddCastingImplmentation(PyBoundArrayMethodObject *meth);
5458

numpy/core/src/multiarray/multiarraymodule.c

Lines changed: 18 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -448,17 +448,10 @@ PyArray_ConcatenateArrays(int narrays, PyArrayObject **arrays, int axis,
448448

449449
/* Get the priority subtype for the array */
450450
PyTypeObject *subtype = PyArray_GetSubType(narrays, arrays);
451-
452-
if (dtype == NULL) {
453-
/* Get the resulting dtype from combining all the arrays */
454-
dtype = (PyArray_Descr *)PyArray_ResultType(
455-
narrays, arrays, 0, NULL);
456-
if (dtype == NULL) {
457-
return NULL;
458-
}
459-
}
460-
else {
461-
Py_INCREF(dtype);
451+
PyArray_Descr *descr = PyArray_FindConcatenationDescriptor(
452+
narrays, arrays, (PyObject *)dtype);
453+
if (descr == NULL) {
454+
return NULL;
462455
}
463456

464457
/*
@@ -467,25 +460,21 @@ PyArray_ConcatenateArrays(int narrays, PyArrayObject **arrays, int axis,
467460
* resolution rules matching that of the NpyIter.
468461
*/
469462
PyArray_CreateMultiSortedStridePerm(narrays, arrays, ndim, strideperm);
470-
s = dtype->elsize;
463+
s = descr->elsize;
471464
for (idim = ndim-1; idim >= 0; --idim) {
472465
int iperm = strideperm[idim];
473466
strides[iperm] = s;
474467
s *= shape[iperm];
475468
}
476469

477470
/* Allocate the array for the result. This steals the 'dtype' reference. */
478-
ret = (PyArrayObject *)PyArray_NewFromDescr(subtype,
479-
dtype,
480-
ndim,
481-
shape,
482-
strides,
483-
NULL,
484-
0,
485-
NULL);
471+
ret = (PyArrayObject *)PyArray_NewFromDescr_int(
472+
subtype, descr, ndim, shape, strides, NULL, 0, NULL,
473+
NULL, 0, 1);
486474
if (ret == NULL) {
487475
return NULL;
488476
}
477+
assert(PyArray_DESCR(ret) == descr);
489478
}
490479

491480
/*
@@ -575,32 +564,22 @@ PyArray_ConcatenateFlattenedArrays(int narrays, PyArrayObject **arrays,
575564
/* Get the priority subtype for the array */
576565
PyTypeObject *subtype = PyArray_GetSubType(narrays, arrays);
577566

578-
if (dtype == NULL) {
579-
/* Get the resulting dtype from combining all the arrays */
580-
dtype = (PyArray_Descr *)PyArray_ResultType(
581-
narrays, arrays, 0, NULL);
582-
if (dtype == NULL) {
583-
return NULL;
584-
}
585-
}
586-
else {
587-
Py_INCREF(dtype);
567+
PyArray_Descr *descr = PyArray_FindConcatenationDescriptor(
568+
narrays, arrays, (PyObject *)dtype);
569+
if (descr == NULL) {
570+
return NULL;
588571
}
589572

590-
stride = dtype->elsize;
573+
stride = descr->elsize;
591574

592575
/* Allocate the array for the result. This steals the 'dtype' reference. */
593-
ret = (PyArrayObject *)PyArray_NewFromDescr(subtype,
594-
dtype,
595-
1,
596-
&shape,
597-
&stride,
598-
NULL,
599-
0,
600-
NULL);
576+
ret = (PyArrayObject *)PyArray_NewFromDescr_int(
577+
subtype, descr, 1, &shape, &stride, NULL, 0, NULL,
578+
NULL, 0, 1);
601579
if (ret == NULL) {
602580
return NULL;
603581
}
582+
assert(PyArray_DESCR(ret) == descr);
604583
}
605584

606585
/*

numpy/core/tests/test_shape_base.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -343,7 +343,7 @@ def test_bad_out_shape(self):
343343
concatenate((a, b), out=np.empty(4))
344344

345345
@pytest.mark.parametrize("axis", [None, 0])
346-
@pytest.mark.parametrize("out_dtype", ["c8", "f4", "f8", ">f8", "i8"])
346+
@pytest.mark.parametrize("out_dtype", ["c8", "f4", "f8", ">f8", "i8", "S4"])
347347
@pytest.mark.parametrize("casting",
348348
['no', 'equiv', 'safe', 'same_kind', 'unsafe'])
349349
def test_out_and_dtype(self, axis, out_dtype, casting):
@@ -369,6 +369,32 @@ def test_out_and_dtype(self, axis, out_dtype, casting):
369369
with assert_raises(TypeError):
370370
concatenate(to_concat, out=out, dtype=out_dtype, axis=axis)
371371

372+
@pytest.mark.parametrize("axis", [None, 0])
373+
@pytest.mark.parametrize("string_dt", ["S", "U", "S0", "U0"])
374+
@pytest.mark.parametrize("arrs",
375+
[([0.],), ([0.], [1]), ([0], ["string"], [1.])])
376+
def test_dtype_with_promotion(self, arrs, string_dt, axis):
377+
# Note that U0 and S0 should be deprecated eventually and changed to
378+
# actually give the empty string result (together with `np.array`)
379+
res = np.concatenate(arrs, axis=axis, dtype=string_dt, casting="unsafe")
380+
assert res.dtype == np.promote_types("d", string_dt)
381+
382+
@pytest.mark.parametrize("axis", [None, 0])
383+
def test_string_dtype_does_not_inspect(self, axis):
384+
# The error here currently depends on NPY_USE_NEW_CASTINGIMPL as
385+
# the new version rejects using the "default string length" of 64.
386+
# The new behaviour is better, `np.array()` and `arr.astype()` would
387+
# have to be used instead. (currently only raises due to unsafe cast)
388+
with pytest.raises((ValueError, TypeError)):
389+
np.concatenate(([None], [1]), dtype="S", axis=axis)
390+
with pytest.raises((ValueError, TypeError)):
391+
np.concatenate(([None], [1]), dtype="U", axis=axis)
392+
393+
@pytest.mark.parametrize("axis", [None, 0])
394+
def test_subarray_error(self, axis):
395+
with pytest.raises(TypeError, match=".*subarray dtype"):
396+
np.concatenate(([1], [1]), dtype="(2,)i", axis=axis)
397+
372398

373399
def test_stack():
374400
# non-iterable input

0 commit comments

Comments
 (0)
0