8000 Merge pull request #4926 from juliantaylor/concatenate-error · numpy/numpy@6cafbfd · GitHub
[go: up one dir, main page]

Skip to content

Commit 6cafbfd

Browse files
committed
Merge pull request #4926 from juliantaylor/concatenate-error
ENH: better error message for invalid axis and concatenate inputs
2 parents b1d6361 + 3f0cb83 commit 6cafbfd

File tree

2 files changed

+36
-11
lines changed

2 files changed

+36
-11
lines changed

numpy/core/src/multiarray/conversion_utils.c

Lines changed: 30 additions & 11 deletions
Original file line numberDiff line numberDiff line change
< 8000 button class="Button Button--iconOnly Button--invisible ExpandableHunkHeaderDiffLine-module__expand-button-line--rnQN5 ExpandableHunkHeaderDiffLine-module__expand-button-unified--j86KQ" aria-label="Expand file up from line 16" data-direction="up" aria-hidden="true" tabindex="-1">
@@ -16,6 +16,11 @@
1616

1717
#include "conversion_utils.h"
1818

19+
static int
20+
PyArray_PyIntAsInt_ErrMsg(PyObject *o, const char * msg) NPY_GCC_NONNULL(2);
21+
static npy_intp
22+
PyArray_PyIntAsIntp_ErrMsg(PyObject *o, const char * msg) NPY_GCC_NONNULL(2);
23+
1924
/****************************************************************
2025
* Useful function for conversion when used with PyArg_ParseTuple
2126
****************************************************************/
@@ -215,8 +220,9 @@ PyArray_AxisConverter(PyObject *obj, int *axis)
215220
*axis = NPY_MAXDIMS;
216221
}
217222
else {
218-
*axis = PyArray_PyIntAsInt(obj);
219-
if (PyErr_Occurred()) {
223+
*axis = PyArray_PyIntAsInt_ErrMsg(obj,
224+
"an integer is required for the axis");
225+
if (error_converting(*axis)) {
220226
return NPY_FAIL;
221227
}
222228
}
@@ -251,7 +257,8 @@ PyArray_ConvertMultiAxis(PyObject *axis_in, int ndim, npy_bool *out_axis_flags)
251257
}
252258
for (i = 0; i < naxes; ++i) {
253259
PyObject *tmp = PyTuple_GET_ITEM(axis_in, i);
254-
int axis = PyArray_PyIntAsInt(tmp);
260+
int axis = PyArray_PyIntAsInt_ErrMsg(tmp,
261+
"integers are required for the axis tuple elements");
255262
int axis_orig = axis;
256263
if (error_converting(axis)) {
257264
return NPY_FAIL;
@@ -281,7 +288,8 @@ PyArray_ConvertMultiAxis(PyObject *axis_in, int ndim, npy_bool *out_axis_flags)
281288

282289
memset(out_axis_flags, 0, ndim);
283290

284-
axis = PyArray_PyIntAsInt(axis_in);
291+
axis = PyArray_PyIntAsInt_ErrMsg(axis_in,
292+
"an integer is required for the axis");
285293
axis_orig = axis;
286294

287295
if (error_converting(axis)) {
@@ -736,13 +744,12 @@ PyArray_CastingConverter(PyObject *obj, NPY_CASTING *casting)
736744
* Other conversion functions
737745
*****************************/
738746

739-
/*NUMPY_API*/
740-
NPY_NO_EXPORT int
741-
PyArray_PyIntAsInt(PyObject *o)
747+
static int
748+
PyArray_PyIntAsInt_ErrMsg(PyObject *o, const char * msg)
742749
{
743750
npy_intp long_value;
744751
/* This assumes that NPY_SIZEOF_INTP >= NPY_SIZEOF_INT */
745-
long_value = PyArray_PyIntAsIntp(o);
752+
long_value = PyArray_PyIntAsIntp_ErrMsg(o, msg);
746753

747754
#if (NPY_SIZEOF_INTP > NPY_SIZEOF_INT)
748755
if ((long_value < INT_MIN) || (long_value > INT_MAX)) {
@@ -754,16 +761,21 @@ PyArray_PyIntAsInt(PyObject *o)
754761
}
755762

756763
/*NUMPY_API*/
757-
NPY_NO_EXPORT npy_intp
758-
PyArray_PyIntAsIntp(PyObject *o)
764+
NPY_NO_EXPORT int
765+
PyArray_PyIntAsInt(PyObject *o)
766+
{
767+
return PyArray_PyIntAsInt_ErrMsg(o, "an integer is required");
768+
}
769+
770+
static npy_intp
771+
PyArray_PyIntAsIntp_ErrMsg(PyObject *o, const char * msg)
759772
{
760773
#if (NPY_SIZEOF_LONG < NPY_SIZEOF_INTP)
761774
long long long_value = -1;
762775
#else
763776
long long_value = -1;
764777
#endif
765778
PyObject *obj, *err;
766-
static char *msg = "an integer is required";
767779

768780
if (!o) {
769781
PyErr_SetString(PyExc_TypeError, msg);
@@ -909,6 +921,13 @@ PyArray_PyIntAsIntp(PyObject *o)
909921
return long_value;
910922
}
911923

924+
/*NUMPY_API*/
925+
NPY_NO_EXPORT npy_intp
926+
PyArray_PyIntAsIntp(PyObject *o)
927+
{
928+
return PyArray_PyIntAsIntp_ErrMsg(o, "an integer is required");
929+
}
930+
912931

913932
/*
914933
* PyArray_IntpFromIndexSequence

numpy/core/src/multiarray/multiarraymodule.c

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -576,6 +576,12 @@ PyArray_Concatenate(PyObject *op, int axis)
576576
PyArrayObject **arrays;
577577
PyArrayObject *ret;
578578

579+
if (!PySequence_Check(op)) {
580+
PyErr_SetString(PyExc_TypeError,
581+
"The first input argument needs to be a sequence");
582+
return NULL;
583+
}
584+
579585
/* Convert the input list into arrays */
580586
narrays = PySequence_Size(op);
581587
if (narrays < 0) {

0 commit comments

Comments
 (0)
0