Merge pull request #16134 from seberg/concatenate-dtype · numpy/numpy@99bcf99 · GitHub
[go: up one dir, main page]

Skip to content

Commit 99bcf99

Browse files
authored
Merge pull request #16134 from seberg/concatenate-dtype
ENH: Implement concatenate dtype and casting keyword arguments
2 parents f6752db + c8eb9d4 commit 99bcf99

File tree

7 files changed

+193
-34
lines changed

7 files changed

+193
-34
lines changed
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
Same kind casting in concatenate with ``axis=None``
2+
---------------------------------------------------
3+
When `~numpy.concatenate` is called with ``axis=None``,
4+
the flattened arrays were cast with ``unsafe``. Any other axis
5+
choice uses "same kind". That different default
6+
has been deprecated and "same kind" casting will be used
7+
instead. The new ``casting`` keyword argument
8+
can be used to retain the old behaviour.
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
Concatenate supports providing an output dtype
2+
----------------------------------------------
3+
Support was added to `~numpy.concatenate` to provide
4+
an output ``dtype`` and ``casting`` using keyword
5+
arguments. The ``dtype`` argument cannot be provided
6+
in conjunction with the ``out`` one.

numpy/core/multiarray.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -141,9 +141,9 @@ def empty_like(prototype, dtype=None, order=None, subok=None, shape=None):
141141

142142

143143
@array_function_from_c_func_and_dispatcher(_multiarray_umath.concatenate)
144-
def concatenate(arrays, axis=None, out=None):
144+
def concatenate(arrays, axis=None, out=None, *, dtype=None, casting=None):
145145
"""
146-
concatenate((a1, a2, ...), axis=0, out=None)
146+
concatenate((a1, a2, ...), axis=0, out=None, dtype=None, casting="same_kind")
147147
148148
Join a sequence of arrays along an existing axis.
149149
@@ -159,6 +159,16 @@ def concatenate(arrays, axis=None, out=None):
159159
If provided, the destination to place the result. The shape must be
160160
correct, matching that of what concatenate would have returned if no
161161
out argument were specified.
162+
dtype : str or dtype
163+
If provided, the destination array will have this dtype. Cannot be
164+
provided together with `out`.
165+
166+
.. versionadded:: 1.20.0
167+
168+
casting : {'no', 'equiv', 'safe', 'same_kind', 'unsafe'}, optional
169+
Controls what kind of data casting may occur. Defaults to 'same_kind'.
170+
171+
.. versionadded:: 1.20.0
162172
163173
Returns
164174
-------

numpy/core/src/multiarray/ctors.c

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1578,6 +1578,7 @@ PyArray_CheckFromAny(PyObject *op, PyArray_Descr *descr, int min_depth,
15781578
return obj;
15791579
}
15801580

1581+
15811582
/*NUMPY_API
15821583
* steals reference to newtype --- acc. NULL
15831584
*/
@@ -2252,7 +2253,10 @@ PyArray_EnsureAnyArray(PyObject *op)
22522253
return PyArray_EnsureArray(op);
22532254
}
22542255

2255-
/* TODO: Put the order parameter in PyArray_CopyAnyInto and remove this */
2256+
/*
2257+
* Private implementation of PyArray_CopyAnyInto with an additional order
2258+
* parameter.
2259+
*/
22562260
NPY_NO_EXPORT int
22572261
PyArray_CopyAsFlat(PyArrayObject *dst, PyArrayObject *src, NPY_ORDER order)
22582262
{

numpy/core/src/multiarray/multiarraymodule.c

Lines changed: 115 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -362,7 +362,8 @@ PyArray_GetSubType(int narrays, PyArrayObject **arrays) {
362362
*/
363363
NPY_NO_EXPORT PyArrayObject *
364364
PyArray_ConcatenateArrays(int narrays, PyArrayObject **arrays, int axis,
365-
PyArrayObject* ret)
365+
PyArrayObject* ret, PyArray_Descr *dtype,
366+
NPY_CASTING casting)
366367
{
367368
int iarrays, idim, ndim;
368369
npy_intp shape[NPY_MAXDIMS];
@@ -426,6 +427,7 @@ PyArray_ConcatenateArrays(int narrays, PyArrayObject **arrays, int axis,
426427
}
427428

428429
if (ret != NULL) {
430+
assert(dtype == NULL);
429431
if (PyArray_NDIM(ret) != ndim) {
430432
PyErr_SetString(PyExc_ValueError,
431433
"Output array has wrong dimensionality");
@@ -445,10 +447,16 @@ PyArray_ConcatenateArrays(int narrays, PyArrayObject **arrays, int axis,
445447
/* Get the priority subtype for the array */
446448
PyTypeObject *subtype = PyArray_GetSubType(narrays, arrays);
447449

448-
/* Get the resulting dtype from combining all the arrays */
449-
PyArray_Descr *dtype = PyArray_ResultType(narrays, arrays, 0, NULL);
450450
if (dtype == NULL) {
451-
return NULL;
451+
/* Get the resulting dtype from combining all the arrays */
452+
dtype = (PyArray_Descr *)PyArray_ResultType(
453+
narrays, arrays, 0, NULL);
454+
if (dtype == NULL) {
455+
return NULL;
456+
}
457+
}
458+
else {
459+
Py_INCREF(dtype);
452460
}
453461

454462
/*
@@ -494,7 +502,7 @@ PyArray_ConcatenateArrays(int narrays, PyArrayObject **arrays, int axis,
494502

495503
/* Copy the data for this array */
496504
if (PyArray_AssignArray((PyArrayObject *)sliding_view, arrays[iarrays],
497-
NULL, NPY_SAME_KIND_CASTING) < 0) {
505+
NULL, casting) < 0) {
498506
Py_DECREF(sliding_view);
499507
Py_DECREF(ret);
500508
return NULL;
@@ -514,7 +522,9 @@ PyArray_ConcatenateArrays(int narrays, PyArrayObject **arrays, int axis,
514522
*/
515523
NPY_NO_EXPORT PyArrayObject *
516524
PyArray_ConcatenateFlattenedArrays(int narrays, PyArrayObject **arrays,
517-
NPY_ORDER order, PyArrayObject *ret)
525+
NPY_ORDER order, PyArrayObject *ret,
526+
PyArray_Descr *dtype, NPY_CASTING casting,
527+
npy_bool casting_not_passed)
518528
{
519529
int iarrays;
520530
npy_intp shape = 0;
@@ -541,7 +551,10 @@ PyArray_ConcatenateFlattenedArrays(int narrays, PyArrayObject **arrays,
541551
}
542552
}
543553

554+
int out_passed = 0;
544555
if (ret != NULL) {
556+
assert(dtype == NULL);
557+
out_passed = 1;
545558
if (PyArray_NDIM(ret) != 1) {
546559
PyErr_SetString(PyExc_ValueError,
547560
"Output array must be 1D");
@@ -560,10 +573,16 @@ PyArray_ConcatenateFlattenedArrays(int narrays, PyArrayObject **arrays,
560573
/* Get the priority subtype for the array */
561574
PyTypeObject *subtype = PyArray_GetSubType(narrays, arrays);
562575

563-
/* Get the resulting dtype from combining all the arrays */
564-
PyArray_Descr *dtype = PyArray_ResultType(narrays, arrays, 0, NULL);
565576
if (dtype == NULL) {
566-
return NULL;
577+
/* Get the resulting dtype from combining all the arrays */
578+
dtype = (PyArray_Descr *)PyArray_ResultType(
579+
narrays, arrays, 0, NULL);
580+
if (dtype == NULL) {
581+
return NULL;
582+
}
583+
}
584+
else {
585+
Py_INCREF(dtype);
567586
}
568587

569588
stride = dtype->elsize;
@@ -593,10 +612,37 @@ PyArray_ConcatenateFlattenedArrays(int narrays, PyArrayObject **arrays,
593612
return NULL;
594613
}
595614

615+
int give_deprecation_warning = 1; /* To give warning for just one input array. */
596616
for (iarrays = 0; iarrays < narrays; ++iarrays) {
597617
/* Adjust the window dimensions for this array */
598618
sliding_view->dimensions[0] = PyArray_SIZE(arrays[iarrays]);
599619

620+
if (!PyArray_CanCastArrayTo(
621+
arrays[iarrays], PyArray_DESCR(ret), casting)) {
622+
/* This should be an error, but was previously allowed here. */
623+
if (casting_not_passed && out_passed) {
624+
/* NumPy 1.20, 2020-09-03 */
625+
if (give_deprecation_warning && DEPRECATE(
626+
"concatenate() with `axis=None` will use same-kind "
627+
"casting by default in the future. Please use "
628+
"`casting='unsafe'` to retain the old behaviour. "
629+
"In the future this will be a TypeError.") < 0) {
630+
Py_DECREF(sliding_view);
631+
Py_DECREF(ret);
632+
return NULL;
633+
}
634+
give_deprecation_warning = 0;
635+
}
636+
else {
637+
npy_set_invalid_cast_error(
638+
PyArray_DESCR(arrays[iarrays]), PyArray_DESCR(ret),
639+
casting, PyArray_NDIM(arrays[iarrays]) == 0);
640+
Py_DECREF(sliding_view);
641+
Py_DECREF(ret);
642+
return NULL;
643+
}
644+
}
645+
600646
/* Copy the data for this array */
601647
if (PyArray_CopyAsFlat((PyArrayObject *)sliding_view, arrays[iarrays],
602648
order) < 0) {
@@ -614,8 +660,21 @@ PyArray_ConcatenateFlattenedArrays(int narrays, PyArrayObject **arrays,
614660
return ret;
615661
}
616662

663+
664+
/**
665+
* Implementation for np.concatenate
666+
*
667+
* @param op Sequence of arrays to concatenate
668+
* @param axis Axis to concatenate along
669+
* @param ret output array to fill
670+
* @param dtype Forced output array dtype (cannot be combined with ret)
671+
* @param casting Casting mode used
672+
* @param casting_not_passed Deprecation helper
673+
*/
617674
NPY_NO_EXPORT PyObject *
618-
PyArray_ConcatenateInto(PyObject *op, int axis, PyArrayObject *ret)
675+
PyArray_ConcatenateInto(PyObject *op,
676+
int axis, PyArrayObject *ret, PyArray_Descr *dtype,
677+
NPY_CASTING casting, npy_bool casting_not_passed)
619678
{
620679
int iarrays, narrays;
621680
PyArrayObject **arrays;
@@ -625,6 +684,12 @@ PyArray_ConcatenateInto(PyObject *op, int axis, PyArrayObject *ret)
625684
"The first input argument needs to be a sequence");
626685
return NULL;
627686
}
687+
if (ret != NULL && dtype != NULL) {
688+
PyErr_SetString(PyExc_TypeError,
689+
"concatenate() only takes `out` or `dtype` as an "
690+
"argument, but both were provided.");
691+
return NULL;
692+
}
628693

629694
/* Convert the input list into arrays */
630695
narrays = PySequence_Size(op);
@@ -651,10 +716,13 @@ PyArray_ConcatenateInto(PyObject *op, int axis, PyArrayObject *ret)
651716
}
652717

653718
if (axis >= NPY_MAXDIMS) {
654-
ret = PyArray_ConcatenateFlattenedArrays(narrays, arrays, NPY_CORDER, ret);
719+
ret = PyArray_ConcatenateFlattenedArrays(
720+
narrays, arrays, NPY_CORDER, ret, dtype,
721+
casting, casting_not_passed);
655722
}
656723
else {
657-
ret = PyArray_ConcatenateArrays(narrays, arrays, axis, ret);
724+
ret = PyArray_ConcatenateArrays(
725+
narrays, arrays, axis, ret, dtype, casting);
658726
}
659727

660728
for (iarrays = 0; iarrays < narrays; ++iarrays) {
@@ -686,7 +754,16 @@ PyArray_ConcatenateInto(PyObject *op, int axis, PyArrayObject *ret)
686754
NPY_NO_EXPORT PyObject *
687755
PyArray_Concatenate(PyObject *op, int axis)
688756
{
689-
return PyArray_ConcatenateInto(op, axis, NULL);
757+
/* retain legacy behaviour for casting */
758+
NPY_CASTING casting;
759+
if (axis >= NPY_MAXDIMS) {
760+
casting = NPY_UNSAFE_CASTING;
761+
}
762+
else {
763+
casting = NPY_SAME_KIND_CASTING;
764+
}
765+
return PyArray_ConcatenateInto(
766+
op, axis, NULL, NULL, casting, 0);
690767
}
691768

692769
static int
@@ -2259,11 +2336,27 @@ array_concatenate(PyObject *NPY_UNUSED(dummy), PyObject *args, PyObject *kwds)
22592336
{
22602337
PyObject *a0;
22612338
PyObject *out = NULL;
2339+
PyArray_Descr *dtype = NULL;
2340+
NPY_CASTING casting = NPY_SAME_KIND_CASTING;
2341+
PyObject *casting_obj = NULL;
2342+
PyObject *res;
22622343
int axis = 0;
2263-
static char *kwlist[] = {"seq", "axis", "out", NULL};
2264-
2265-
if (!PyArg_ParseTupleAndKeywords(args, kwds, "O|O&O:concatenate", kwlist,
2266-
&a0, PyArray_AxisConverter, &axis, &out)) {
2344+
static char *kwlist[] = {"seq", "axis", "out", "dtype", "casting", NULL};
2345+
if (!PyArg_ParseTupleAndKeywords(args, kwds, "O|O&O$O&O:concatenate", kwlist,
2346+
&a0, PyArray_AxisConverter, &axis, &out,
2347+
PyArray_DescrConverter2, &dtype, &casting_obj)) {
2348+
return NULL;
2349+
}
2350+
int casting_not_passed = 0;
2351+
if (casting_obj == NULL) {
2352+
/*
2353+
* Casting was not passed in, needed for deprecation only.
2354+
* This should be simplified once the deprecation is finished.
2355+
*/
2356+
casting_not_passed = 1;
2357+
}
2358+
else if (!PyArray_CastingConverter(casting_obj, &casting)) {
2359+
Py_XDECREF(dtype);
22672360
return NULL;
22682361
}
22692362
if (out != NULL) {
@@ -2272,10 +2365,14 @@ array_concatenate(PyObject *NPY_UNUSED(dummy), PyObject *args, PyObject *kwds)
22722365
}
22732366
else if (!PyArray_Check(out)) {
22742367
PyErr_SetString(PyExc_TypeError, "'out' must be an array");
2368+
Py_XDECREF(dtype);
22752369
return NULL;
22762370
}
22772371
}
2278-
return PyArray_ConcatenateInto(a0, axis, (PyArrayObject *)out);
2372+
res = PyArray_ConcatenateInto(a0, axis, (PyArrayObject *)out, dtype,
2373+
casting, casting_not_passed);
2374+
Py_XDECREF(dtype);
2375+
return res;
22792376
}
22802377

22812378
static PyObject *

numpy/core/tests/test_deprecations.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -707,3 +707,24 @@ def test_deprecated(self):
707707
self.assert_deprecated(lambda: np.array([arr, [0]], dtype=np.float64))
708708
self.assert_deprecated(lambda: np.array([[0], arr], dtype=np.float64))
709709

710+
711+
class FlatteningConcatenateUnsafeCast(_DeprecationTestCase):
712+
# NumPy 1.20, 2020-09-03
713+
message = "concatenate with `axis=None` will use same-kind casting"
714+
715+
def test_deprecated(self):
716+
self.assert_deprecated(np.concatenate,
717+
args=(([0.], [1.]),),
718+
kwargs=dict(axis=None, out=np.empty(2, dtype=np.int64)))
719+
720+
def test_not_deprecated(self):
721+
self.assert_not_deprecated(np.concatenate,
722+
args=(([0.], [1.]),),
723+
kwargs={'axis': None, 'out': np.empty(2, dtype=np.int64),
724+
'casting': "unsafe"})
725+
726+
with assert_raises(TypeError):
727+
# Tests should notice if the deprecation warning is given first...
728+
np.concatenate(([0.], [1.]), out=np.empty(2, dtype=np.int64),
729+
casting="same_kind")
730+

numpy/core/tests/test_shape_base.py

Lines changed: 26 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -342,19 +342,32 @@ def test_bad_out_shape(self):
342342
assert_raises(ValueError, concatenate, (a, b), out=np.empty((1,4)))
343343
concatenate((a, b), out=np.empty(4))
344344

345-
def test_out_dtype(self):
346-
out = np.empty(4, np.float32)
347-
res = concatenate((array([1, 2]), array([3, 4])), out=out)
348-
assert_(out is res)
349-
350-
out = np.empty(4, np.complex64)
351-
res = concatenate((array([0.1, 0.2]), array([0.3, 0.4])), out=out)
352-
assert_(out is res)
353-
354-
# invalid cast
355-
out = np.empty(4, np.int32)
356-
assert_raises(TypeError, concatenate,
357-
(array([0.1, 0.2]), array([0.3, 0.4])), out=out)
345+
@pytest.mark.parametrize("axis", [None, 0])
346+
@pytest.mark.parametrize("out_dtype", ["c8", "f4", "f8", ">f8", "i8"])
347+
@pytest.mark.parametrize("casting",
348+
['no', 'equiv', 'safe', 'same_kind', 'unsafe'])
349+
def test_out_and_dtype(self, axis, out_dtype, casting):
350+
# Compare usage of `out=out` with `dtype=out.dtype`
351+
out = np.empty(4, dtype=out_dtype)
352+
to_concat = (array([1.1, 2.2]), array([3.3, 4.4]))
353+
354+
if not np.can_cast(to_concat[0], out_dtype, casting=casting):
355+
with assert_raises(TypeError):
356+
concatenate(to_concat, out=out, axis=axis, casting=casting)
357+
with assert_raises(TypeError):
358+
concatenate(to_concat, dtype=out.dtype,
359+
axis=axis, casting=casting)
360+
else:
361+
res_out = concatenate(to_concat, out=out,
362+
axis=axis, casting=casting)
363+
res_dtype = concatenate(to_concat, dtype=out.dtype,
364+
axis=axis, casting=casting)
365+
assert res_out is out
366+
assert_array_equal(out, res_dtype)
367+
assert res_dtype.dtype == out_dtype
368+
369+
with assert_raises(TypeError):
370+
concatenate(to_concat, out=out, dtype=out_dtype, axis=axis)
358371

359372

360373
def test_stack():

0 commit comments

Comments
 (0)
0