BUG: Fix concatenation when the output is "S" or "U" · numpy/numpy@50dce51 · GitHub
[go: up one dir, main page]

Skip to content

Commit 50dce51

Browse files
committed
BUG: Fix concatenation when the output is "S" or "U"
Previously, the dtype was used; this now assumes that we want to cast to a string of (unknown) length. This is a simplified version of what happens in `np.array()` or `arr.astype()` (it never inspects the values, e.g. for object casts). This is more complex than I would like, and with the refactor of ResultType and similar it can hopefully be cleaned up a bit more. Note that currently, object to "S" or "U" casts simply return length 64 strings, but with the new version, this will be an error (although the error message probably needs improvement). This behaviour is, however, inherited from other places.
1 parent 74e135b commit 50dce51

File tree

4 files changed

+116
-40
lines changed

4 files changed

+116
-40
lines changed

numpy/core/src/multiarray/convert_datatype.c

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -871,6 +871,73 @@ PyArray_CastDescrToDType(PyArray_Descr *descr, PyArray_DTypeMeta *given_DType)
871871
}
872872

873873

874+
/*
875+
* Helper to find the target descriptor for multiple arrays given an input
876+
* one that may be a DType class (e.g. "U" or "S").
877+
* Works with arrays, since that is what `concatenate` works with. However,
878+
* unlike `np.array(...)` or `arr.astype()` we will never inspect the array's
879+
* content, which means that object arrays can only be cast to strings if a
880+
* fixed width is provided (same for string -> generic datetime).
881+
*
882+
* As this function uses `PyArray_ExtractDTypeAndDescriptor`, it should
883+
* eventually be refactored to move the step to an earlier point.
884+
*/
885+
NPY_NO_EXPORT PyArray_Descr *
886+
PyArray_FindConcatenationDescriptor(
887+
npy_intp n, PyArrayObject **arrays, PyObject *requested_dtype)
888+
{
889+
if (requested_dtype == NULL) {
890+
return PyArray_ResultType(n, arrays, 0, NULL);
891+
}
892+
893+
PyArray_DTypeMeta *common_dtype;
894+
PyArray_Descr *result = NULL;
895+
if (PyArray_ExtractDTypeAndDescriptor(
896+
requested_dtype, &result, &common_dtype) < 0) {
897+
return NULL;
898+
}
899+
if (result != NULL) {
900+
if (result->subarray != NULL) {
901+
PyErr_Format(PyExc_TypeError,
902+
"The dtype `%R` is not a valid dtype for concatenation "
903+
"since it is a subarray dtype (the subarray dimensions "
904+
"would be added as array dimensions).", result);
905+
Py_DECREF(result);
906+
return NULL;
907+
}
908+
goto finish;
909+
}
910+
assert(n > 0); /* concatenate requires at least one array input. */
911+
PyArray_Descr *descr = PyArray_DESCR(arrays[0]);
912+
result = PyArray_CastDescrToDType(descr, common_dtype);
913+
if (result == NULL || n == 1) {
914+
goto finish;
915+
}
916+
/*
917+
* This could short-cut a bit, calling `common_instance` directly and/or
918+
* returning the `default_descr()` directly. Avoiding that (for now) as
919+
* it would duplicate code from `PyArray_PromoteTypes`.
920+
*/
921+
for (npy_intp i = 1; i < n; i++) {
922+
descr = PyArray_DESCR(arrays[i]);
923+
PyArray_Descr *curr = PyArray_CastDescrToDType(descr, common_dtype);
924+
if (curr == NULL) {
925+
Py_SETREF(result, NULL);
926+
goto finish;
927+
}
928+
Py_SETREF(result, PyArray_PromoteTypes(result, curr));
929+
Py_DECREF(curr);
930+
if (result == NULL) {
931+
goto finish;
932+
}
933+
}
934+
935+
finish:
936+
Py_DECREF(common_dtype);
937+
return result;
938+
}
939+
940+
874941
/**
875942
* This function defines the common DType operator.
876943
*

numpy/core/src/multiarray/convert_datatype.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,10 @@ npy_set_invalid_cast_error(
4949
NPY_NO_EXPORT PyArray_Descr *
5050
PyArray_CastDescrToDType(PyArray_Descr *descr, PyArray_DTypeMeta *given_DType);
5151

52+
NPY_NO_EXPORT PyArray_Descr *
53+
PyArray_FindConcatenationDescriptor(
54+
npy_intp n, PyArrayObject **arrays, PyObject *requested_dtype);
55+
5256
NPY_NO_EXPORT int
5357
PyArray_AddCastingImplmentation(PyBoundArrayMethodObject *meth);
5458

numpy/core/src/multiarray/multiarraymodule.c

Lines changed: 18 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -448,17 +448,10 @@ PyArray_ConcatenateArrays(int narrays, PyArrayObject **arrays, int axis,
448448

449449
/* Get the priority subtype for the array */
450450
PyTypeObject *subtype = PyArray_GetSubType(narrays, arrays);
451-
452-
if (dtype == NULL) {
453-
/* Get the resulting dtype from combining all the arrays */
454-
dtype = (PyArray_Descr *)PyArray_ResultType(
455-
narrays, arrays, 0, NULL);
456-
if (dtype == NULL) {
457-
return NULL;
458-
}
459-
}
460-
else {
461-
Py_INCREF(dtype);
451+
PyArray_Descr *descr = PyArray_FindConcatenationDescriptor(
452+
narrays, arrays, (PyObject *)dtype);
453+
if (descr == NULL) {
454+
return NULL;
462455
}
463456

464457
/*
@@ -467,25 +460,21 @@ PyArray_ConcatenateArrays(int narrays, PyArrayObject **arrays, int axis,
467460
* resolution rules matching that of the NpyIter.
468461
*/
469462
PyArray_CreateMultiSortedStridePerm(narrays, arrays, ndim, strideperm);
470-
s = dtype->elsize;
463+
s = descr->elsize;
471464
for (idim = ndim-1; idim >= 0; --idim) {
472465
int iperm = strideperm[idim];
473466
strides[iperm] = s;
474467
s *= shape[iperm];
475468
}
476469

477470
/* Allocate the array for the result. This steals the 'dtype' reference. */
478-
ret = (PyArrayObject *)PyArray_NewFromDescr(subtype,
479-
dtype,
480-
ndim,
481-
shape,
482-
strides,
483-
NULL,
484-
0,
485-
NULL);
471+
ret = (PyArrayObject *)PyArray_NewFromDescr_int(
472+
subtype, descr, ndim, shape, strides, NULL, 0, NULL,
473+
NULL, 0, 1);
486474
if (ret == NULL) {
487475
return NULL;
488476
}
477+
assert(PyArray_DESCR(ret) == descr);
489478
}
490479

491480
/*
@@ -575,32 +564,22 @@ PyArray_ConcatenateFlattenedArrays(int narrays, PyArrayObject **arrays,
575564
/* Get the priority subtype for the array */
576565
PyTypeObject *subtype = PyArray_GetSubType(narrays, arrays);
577566

578-
if (dtype == NULL) {
579-
/* Get the resulting dtype from combining all the arrays */
580-
dtype = (PyArray_Descr *)PyArray_ResultType(
581-
narrays, arrays, 0, NULL);
582-
if (dtype == NULL) {
583-
return NULL;
584-
}
585-
}
586-
else {
587-
Py_INCREF(dtype);
567+
PyArray_Descr *descr = PyArray_FindConcatenationDescriptor(
568+
narrays, arrays, (PyObject *)dtype);
569+
if (descr == NULL) {
570+
return NULL;
588571
}
589572

590-
stride = dtype->elsize;
573+
stride = descr->elsize;
591574

592575
/* Allocate the array for the result. This steals the 'dtype' reference. */
593-
ret = (PyArrayObject *)PyArray_NewFromDescr(subtype,
594-
dtype,
595-
1,
596-
&shape,
597-
&stride,
598-
NULL,
599-
0,
600-
NULL);
576+
ret = (PyArrayObject *)PyArray_NewFromDescr_int(
577+
subtype, descr, 1, &shape, &stride, NULL, 0, NULL,
578+
NULL, 0, 1);
601579
if (ret == NULL) {
602580
return NULL;
603581
}
582+
assert(PyArray_DESCR(ret) == descr);
604583
}
605584

606585
/*

numpy/core/tests/test_shape_base.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -343,7 +343,7 @@ def test_bad_out_shape(self):
343343
concatenate((a, b), out=np.empty(4))
344344

345345
@pytest.mark.parametrize("axis", [None, 0])
346-
@pytest.mark.parametrize("out_dtype", ["c8", "f4", "f8", ">f8", "i8"])
346+
@pytest.mark.parametrize("out_dtype", ["c8", "f4", "f8", ">f8", "i8", "S4"])
347347
@pytest.mark.parametrize("casting",
348348
['no', 'equiv', 'safe', 'same_kind', 'unsafe'])
349349
def test_out_and_dtype(self, axis, out_dtype, casting):
@@ -369,6 +369,32 @@ def test_out_and_dtype(self, axis, out_dtype, casting):
369369
with assert_raises(TypeError):
370370
concatenate(to_concat, out=out, dtype=out_dtype, axis=axis)
371371

372+
@pytest.mark.parametrize("axis", [None, 0])
373+
@pytest.mark.parametrize("string_dt", ["S", "U", "S0", "U0"])
374+
@pytest.mark.parametrize("arrs",
375+
[([0.],), ([0.], [1]), ([0], ["string"], [1.])])
376+
def test_dtype_with_promotion(self, arrs, string_dt, axis):
377+
# Note that U0 and S0 should be deprecated eventually and changed to
378+
# actually give the empty string result (together with `np.array`)
379+
res = np.concatenate(arrs, axis=axis, dtype=string_dt, casting="unsafe")
380+
assert res.dtype == np.promote_types("d", string_dt)
381+
382+
@pytest.mark.parametrize("axis", [None, 0])
383+
def test_string_dtype_does_not_inspect(self, axis):
384+
# The error here currently depends on NPY_USE_NEW_CASTINGIMPL as
385+
# the new version rejects using the "default string length" of 64.
386+
# The new behaviour is better, `np.array()` and `arr.astype()` would
387+
# have to be used instead. (currently only raises due to unsafe cast)
388+
with pytest.raises((ValueError, TypeError)):
389+
np.concatenate(([None], [1]), dtype="S", axis=axis)
390+
with pytest.raises((ValueError, TypeError)):
391+
np.concatenate(([None], [1]), dtype="U", axis=axis)
392+
393+
@pytest.mark.parametrize("axis", [None, 0])
394+
def test_subarray_error(self, axis):
395+
with pytest.raises(TypeError, match=".*subarray dtype"):
396+
np.concatenate(([1], [1]), dtype="(2,)i", axis=axis)
397+
372398

373399
def test_stack():
374400
# non-iterable input

0 commit comments

Comments
 (0)
0