8000 Group shape/dtype validation logic in image_resample. · matplotlib/matplotlib@736c370 · GitHub
[go: up one dir, main page]

Skip to content

Commit 736c370

Browse files
committed
Group shape/dtype validation logic in image_resample.
Move it all to a single place rather than having some of it interspersed with the dtype dispatch later. Also reorder the dtype dispatch to be consistent in the 2D and 3D cases, and remove _array from many local variable names.
1 parent 78442c4 commit 736c370

File tree

1 file changed

+71
-82
lines changed

1 file changed

+71
-82
lines changed

src/_image_wrapper.cpp

Lines changed: 71 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -121,14 +121,15 @@ resample(PyArrayObject* input, PyArrayObject* output, resample_params_t params)
121121
static PyObject *
122122
image_resample(PyObject *self, PyObject* args, PyObject *kwargs)
123123
{
124-
PyObject *py_input_array = NULL;
125-
PyObject *py_output_array = NULL;
124+
PyObject *py_input = NULL;
125+
PyObject *py_output = NULL;
126126
PyObject *py_transform = NULL;
127127
resample_params_t params;
128128

129-
PyArrayObject *input_array = NULL;
130-
PyArrayObject *output_array = NULL;
131-
PyArrayObject *transform_mesh_array = NULL;
129+
PyArrayObject *input = NULL;
130+
PyArrayObject *output = NULL;
131+
PyArrayObject *transform_mesh = NULL;
132+
int ndim;
132133

133134
params.interpolation = NEAREST;
134135
params.transform_mesh = NULL;
@@ -143,7 +144,7 @@ image_resample(PyObject *self, PyObject* args, PyObject *kwargs)
143144

144145
if (!PyArg_ParseTupleAndKeywords(
145146
args, kwargs, "OOO|iO&dO&d:resample", (char **)kwlist,
146-
&py_input_array, &py_output_array, &py_transform,
147+
&py_input, &py_output, &py_transform,
147148
&params.interpolation, &convert_bool, &params.resample,
148149
&params.alpha, &convert_bool, &params.norm, &params.radius)) {
149150
return NULL;
@@ -155,24 +156,39 @@ image_resample(PyObject *self, PyObject* args, PyObject *kwargs)
155156
goto error;
156157
}
157158

158-
input_array = (PyArrayObject *)PyArray_FromAny(
159-
py_input_array, NULL, 2, 3, NPY_ARRAY_C_CONTIGUOUS, NULL);
160-
if (input_array == NULL) {
159+
input = (PyArrayObject *)PyArray_FromAny(
160+
py_input, NULL, 2, 3, NPY_ARRAY_C_CONTIGUOUS, NULL);
161+
if (input == NULL) {
161162
goto error;
162163
}
164+
ndim = PyArray_NDIM(input);
163165

164-
if (!PyArray_Check(py_output_array)) {
166+
if (!PyArray_Check(py_output)) {
165167
PyErr_SetString(PyExc_ValueError, "output array must be a NumPy array");
166168
goto error;
167169
}
168-
output_array = (PyArrayObject *)py_output_array;
169-
if (!PyArray_IS_C_CONTIGUOUS(output_array)) {
170-
PyErr_SetString(PyExc_ValueError, "output array must be C-contiguous");
170+
output = (PyArrayObject *)py_output;
171+
if (PyArray_NDIM(output) != ndim) {
172+
PyErr_Format(
173+
PyExc_ValueError, "Mismatched number of dimensions. Got %d and %d.",
174+
ndim, PyArray_NDIM(output));
171175
goto error;
172176
}
173-
if (PyArray_NDIM(output_array) < 2 || PyArray_NDIM(output_array) > 3) {
174-
PyErr_SetString(PyExc_ValueError,
175-
"output array must be 2- or 3-dimensional");
177+
// PyArray_FromAny above checks that input is 2D or 3D.
178+
if (ndim == 3 && (PyArray_DIM(input, 2) != 4 || PyArray_DIM(output, 2) != 4)) {
179+
PyErr_Format(
180+
PyExc_ValueError,
181+
"If 3D, input and output arrays must be RGBA with shape (M, N, 4); "
182+
"got trailing dimensions of %" NPY_INTP_FMT " and %" NPY_INTP_FMT
183+
" respectively", PyArray_DIM(input, 2), PyArray_DIM(output, 2));
184+
goto error;
185+
}
186+
if (PyArray_TYPE(input) != PyArray_TYPE(output)) {
187+
PyErr_SetString(PyExc_ValueError, "Mismatched types");
188+
goto error;
189+
}
190+
if (!PyArray_IS_C_CONTIGUOUS(output)) {
191+
PyErr_SetString(PyExc_ValueError, "output array must be C-contiguous");
176192
goto error;
177193
}
178194

@@ -197,96 +213,69 @@ image_resample(PyObject *self, PyObject* args, PyObject *kwargs)
197213
}
198214
params.is_affine = true;
199215
} else {
200-
transform_mesh_array = _get_transform_mesh(
201-
py_transform, PyArray_DIMS(output_array));
202-
if (transform_mesh_array == NULL) {
216+
transform_mesh = _get_transform_mesh(
217+
py_transform, PyArray_DIMS(output));
218+
if (transform_mesh == NULL) {
203219
goto error;
204220
}
205-
params.transform_mesh = (double *)PyArray_DATA(transform_mesh_array);
221+
params.transform_mesh = (double *)PyArray_DATA(transform_mesh);
206222
params.is_affine = false;
207223
}
208224
}
209225

210-
if (PyArray_NDIM(input_array) != PyArray_NDIM(output_array)) {
211-
PyErr_Format(
212-
PyExc_ValueError,
213-
"Mismatched number of dimensions. Got %d and %d.",
214-
PyArray_NDIM(input_array), PyArray_NDIM(output_array));
215-
goto error;
216-
}
217-
218-
if (PyArray_TYPE(input_array) != PyArray_TYPE(output_array)) {
219-
PyErr_SetString(PyExc_ValueError, "Mismatched types");
220-
goto error;
221-
}
222-
223-
if (PyArray_NDIM(input_array) == 3) {
224-
if (PyArray_DIM(output_array, 2) != 4) {
226+
if (ndim == 3) {
227+
switch (PyArray_TYPE(input)) {
228+
case NPY_UINT8:
229+
case NPY_INT8:
230+
resample<agg::rgba8>(input, output, params);
231+
break;
232+
case NPY_UINT16:
233+
case NPY_INT16:
234+
resample<agg::rgba16>(input, output, params);
235+
break;
236+
case NPY_FLOAT32:
237+
resample<agg::rgba32>(input, output, params);
238+
break;
239+
case NPY_FLOAT64:
240+
resample<agg::rgba64>(input, output, params);
241+
break;
242+
default:
225243
PyErr_SetString(
226244
PyExc_ValueError,
227-
"Output array must be RGBA");
228-
goto error;
229-
}
230-
231-
if (PyArray_DIM(input_array, 2) == 4) {
232-
switch (PyArray_TYPE(input_array)) {
233-
case NPY_UINT8:
234-
case NPY_INT8:
235-
resample<agg::rgba8>(input_array, output_array, params);
236-
break;
237-
case NPY_UINT16:
238-
case NPY_INT16:
239-
resample<agg::rgba16>(input_array, output_array, params);
240-
break;
241-
case NPY_FLOAT32:
242-
resample<agg::rgba32>(input_array, output_array, params);
243-
break;
244-
case NPY_FLOAT64:
245-
resample<agg::rgba64>(input_array, output_array, params);
246-
break;
247-
default:
248-
PyErr_SetString(
249-
PyExc_ValueError,
250-
"3-dimensional arrays must be of dtype unsigned byte, "
251-
"unsigned short, float32 or float64");
252-
goto error;
253-
}
254-
} else {
255-
PyErr_Format(
256-
PyExc_ValueError,
257-
"If 3-dimensional, array must be RGBA. Got %" NPY_INTP_FMT " planes.",
258-
PyArray_DIM(input_array, 2));
245+
"arrays must be of dtype byte, short, float32 or float64");
259246
goto error;
260247
}
261-
} else { // NDIM == 2
262-
switch (PyArray_TYPE(input_array)) {
263-
case NPY_DOUBLE:
264-
resample<double>(input_array, output_array, params);
265-
break;
266-
case NPY_FLOAT:
267-
resample<float>(input_array, output_array, params);
268-
break;
248+
} else { // ndim == 2
249+
switch (PyArray_TYPE(input)) {
269250
case NPY_UINT8:
270251
case NPY_INT8:
271-
resample<unsigned char>(input_array, output_array, params);
252+
resample<unsigned char>(input, output, params);
272253
break;
273254
case NPY_UINT16:
274255
case NPY_INT16:
275-
resample<unsigned short>(input_array, output_array, params);
256+
resample<unsigned short>(input, output, params);
257+
break;
258+
case NPY_FLOAT32:
259+
resample<float>(input, output, params);
260+
break;
261+
case NPY_FLOAT64:
262+
resample<double>(input, output, params);
276263
break;
277264
default:
278-
PyErr_SetString(PyExc_ValueError, "Unsupported dtype");
265+
PyErr_SetString(
266+
PyExc_ValueError,
267+
"arrays must be of dtype byte, short, float32 or float64");
279268
goto error;
280269
}
281270
}
282271

283-
Py_DECREF(input_array);
284-
Py_XDECREF(transform_mesh_array);
272+
Py_DECREF(input);
273+
Py_XDECREF(transform_mesh);
285274
Py_RETURN_NONE;
286275

287276
error:
288-
Py_XDECREF(input_array);
289-
Py_XDECREF(transform_mesh_array);
277+
Py_XDECREF(input);
278+
Py_XDECREF(transform_mesh);
290279
return NULL;
291280
}
292281

0 commit comments

Comments
 (0)
0