8000 BUG: Fix concatenation when the output is "S" or "U" by seberg · Pull Request #18052 · numpy/numpy · GitHub
[go: up one dir, main page]

Skip to content

BUG: Fix concatenation when the output is "S" or "U" #18052

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Dec 23, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 67 additions & 0 deletions numpy/core/src/multiarray/convert_datatype.c
Original file line number Diff line number Diff line change
Expand Up @@ -871,6 +871,73 @@ PyArray_CastDescrToDType(PyArray_Descr *descr, PyArray_DTypeMeta *given_DType)
}


/*
* Helper to find the target descriptor for multiple arrays given an input
* one that may be a DType class (e.g. "U" or "S").
* Works with arrays, since that is what `concatenate` works with. However,
* unlike `np.array(...)` or `arr.astype()` we will never inspect the array's
* content, which means that object arrays can only be cast to strings if a
* fixed width is provided (same for string -> generic datetime).
*
* As this function uses `PyArray_ExtractDTypeAndDescriptor`, it should
* eventually be refactored to move the step to an earlier point.
*/
NPY_NO_EXPORT PyArray_Descr *
PyArray_FindConcatenationDescriptor(
npy_intp n, PyArrayObject **arrays, PyObject *requested_dtype)
{
if (requested_dtype == NULL) {
return PyArray_ResultType(n, arrays, 0, NULL);
}

PyArray_DTypeMeta *common_dtype;
PyArray_Descr *result = NULL;
if (PyArray_ExtractDTypeAndDescriptor(
requested_dtype, &result, &common_dtype) < 0) {
return NULL;
}
if (result != NULL) {
if (result->subarray != NULL) {
PyErr_Format(PyExc_TypeError,
"The dtype `%R` is not a valid dtype for concatenation "
"since it is a subarray dtype (the subarray dimensions "
"would be added as array dimensions).", result);
Py_DECREF(result);
return NULL;
}
goto finish;
}
assert(n > 0); /* concatenate requires at least one array input. */
PyArray_Descr *descr = PyArray_DESCR(arrays[0]);
result = PyArray_CastDescrToDType(descr, common_dtype);
if (result == NULL || n == 1) {
goto finish;
}
/*
* This could short-cut a bit, calling `common_instance` directly and/or
* returning the `default_descr()` directly. Avoiding that (for now) as
* it would duplicate code from `PyArray_PromoteTypes`.
*/
for (npy_intp i = 1; i < n; i++) {
descr = PyArray_DESCR(arrays[i]);
PyArray_Descr *curr = PyArray_CastDescrToDType(descr, common_dtype);
if (curr == NULL) {
Py_SETREF(result, NULL);
goto finish;
}
Py_SETREF(result, PyArray_PromoteTypes(result, curr));
Py_DECREF(curr);
if (result == NULL) {
goto finish;
}
}

finish:
Py_DECREF(common_dtype);
return result;
}


/**
* This function defines the common DType operator.
*
Expand Down
4 changes: 4 additions & 0 deletions numpy/core/src/multiarray/convert_datatype.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,10 @@ npy_set_invalid_cast_error(
NPY_NO_EXPORT PyArray_Descr *
PyArray_CastDescrToDType(PyArray_Descr *descr, PyArray_DTypeMeta *given_DType);

NPY_NO_EXPORT PyArray_Descr *
PyArray_FindConcatenationDescriptor(
npy_intp n, PyArrayObject **arrays, PyObject *requested_dtype);

NPY_NO_EXPORT int
PyArray_AddCastingImplmentation(PyBoundArrayMethodObject *meth);

Expand Down
57 changes: 18 additions & 39 deletions numpy/core/src/multiarray/multiarraymodule.c
Original file line number Diff line number Diff line change
Expand Up @@ -448,17 +448,10 @@ PyArray_ConcatenateArrays(int narrays, PyArrayObject **arrays, int axis,

/* Get the priority subtype for the array */
PyTypeObject *subtype = PyArray_GetSubType(narrays, arrays);

if (dtype == NULL) {
/* Get the resulting dtype from combining all the arrays */
dtype = (PyArray_Descr *)PyArray_ResultType(
narrays, arrays, 0, NULL);
if (dtype == NULL) {
return NULL;
}
}
else {
Py_INCREF(dtype);
PyArray_Descr *descr = PyArray_FindConcatenationDescriptor(
narrays, arrays, (PyObject *)dtype);
if (descr == NULL) {
return NULL;
}

/*
Expand All @@ -467,25 +460,21 @@ PyArray_ConcatenateArrays(int narrays, PyArrayObject **arrays, int axis,
* resolution rules matching that of the NpyIter.
*/
PyArray_CreateMultiSortedStridePerm(narrays, arrays, ndim, strideperm);
s = dtype->elsize;
s = descr->elsize;
for (idim = ndim-1; idim >= 0; --idim) {
int iperm = strideperm[idim];
strides[iperm] = s;
s *= shape[iperm];
}

/* Allocate the array for the result. This steals the 'dtype' reference. */
ret = (PyArrayObject *)PyArray_NewFromDescr(subtype,
dtype,
ndim,
shape,
strides,
NULL,
0,
NULL);
ret = (PyArrayObject *)PyArray_NewFromDescr_int(
subtype, descr, ndim, shape, strides, NULL, 0, NULL,
NULL, 0, 1);
if (ret == NULL) {
return NULL;
}
assert(PyArray_DESCR(ret) == descr);
}

/*
Expand Down Expand Up @@ -575,32 +564,22 @@ PyArray_ConcatenateFlattenedArrays(int narrays, PyArrayObject **arrays,
/* Get the priority subtype for the array */
PyTypeObject *subtype = PyArray_GetSubType(narrays, arrays);

if (dtype == NULL) {
/* Get the resulting dtype from combining all the arrays */
dtype = (PyArray_Descr *)PyArray_ResultType(
narrays, arrays, 0, NULL);
if (dtype == NULL) {
return NULL;
}
}
else {
Py_INCREF(dtype);
PyArray_Descr *descr = PyArray_FindConcatenationDescriptor(
narrays, arrays, (PyObject *)dtype);
if (descr == NULL) {
return NULL;
}

stride = dtype->elsize;
stride = descr->elsize;

/* Allocate the array for the result. This steals the 'dtype' reference. */
ret = (PyArrayObject *)PyArray_NewFromDescr(subtype,
dtype,
1,
&shape,
&stride,
NULL,
0,
NULL);
ret = (PyArrayObject *)PyArray_NewFromDescr_int(
subtype, descr, 1, &shape, &stride, NULL, 0, NULL,
NULL, 0, 1);
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This actually changes nothing, but eventually, we may want to not do weird things about dtype="S0" and then it will be necessary.

if (ret == NULL) {
return NULL;
}
assert(PyArray_DESCR(ret) == descr);
}

/*
Expand Down
28 changes: 27 additions & 1 deletion numpy/core/tests/test_shape_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,7 @@ def test_bad_out_shape(self):
concatenate((a, b), out=np.empty(4))

@pytest.mark.parametrize("axis", [None, 0])
@pytest.mark.parametrize("out_dtype", ["c8", "f4", "f8", ">f8", "i8"])
@pytest.mark.parametrize("out_dtype", ["c8", "f4", "f8", ">f8", "i8", "S4"])
@pytest.mark.parametrize("casting",
['no', 'equiv', 'safe', 'same_kind', 'unsafe'])
def test_out_and_dtype(self, axis, out_dtype, casting):
Expand All @@ -369,6 +369,32 @@ def test_out_and_dtype(self, axis, out_dtype, casting):
with assert_raises(TypeError):
concatenate(to_concat, out=out, dtype=out_dtype, axis=axis)

@pytest.mark.parametrize("axis", [None, 0])
@pytest.mark.parametrize("string_dt", ["S", "U", "S0", "U0"])
@pytest.mark.parametrize("arrs",
[([0.],), ([0.], [1]), ([0], ["string"], [1.])])
def test_dtype_with_promotion(self, arrs, string_dt, axis):
# Note that U0 and S0 should be deprecated eventually and changed to
# actually give the empty string result (together with `np.array`)
res = np.concatenate(arrs, axis=axis, dtype=string_dt, casting="unsafe")
assert res.dtype == np.promote_types("d", string_dt)

@pytest.mark.parametrize("axis", [None, 0])
def test_string_dtype_does_not_inspect(self, axis):
# The error here currently depends on NPY_USE_NEW_CASTINGIMPL as
# the new version rejects using the "default string length" of 64.
# The new behaviour is better, `np.array()` and `arr.astype()` would
# have to be used instead. (currently only raises due to unsafe cast)
with pytest.raises((ValueError, TypeError)):
np.concatenate(([None], [1]), dtype="S", axis=axis)
with pytest.raises((ValueError, TypeError)):
np.concatenate(([None], [1]), dtype="U", axis=axis)

@pytest.mark.parametrize("axis", [None, 0])
def test_subarray_error(self, axis):
with pytest.raises(TypeError, match=".*subarray dtype"):
np.concatenate(([1], [1]), dtype="(2,)i", axis=axis)


def test_stack():
# non-iterable input
Expand Down
0