8000 Group shape/dtype validation logic in image_resample. · matplotlib/matplotlib@1fdd4b6 · GitHub
[go: up one dir, main page]

Skip to content

Commit 1fdd4b6

Browse files
committed
Group shape/dtype validation logic in image_resample.
Move it all to a single place rather than having some of it interspersed with the dtype dispatch later. Also reorder the dtype dispatch to be consistent in the 2D and 3D cases, and remove _array from many local variable names.
1 parent 78442c4 commit 1fdd4b6

File tree

2 files changed

+91
-86
lines changed

2 files changed

+91
-86
lines changed

lib/matplotlib/tests/test_image.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from contextlib import ExitStack
22
from copy import copy
3+
import functools
34
import io
45
import os
56
from pathlib import Path
@@ -1453,3 +1454,17 @@ def test_str_norms(fig_test, fig_ref):
14531454
assert type(axts[0].images[0].norm) == colors.LogNorm # Exactly that class
14541455
with pytest.raises(ValueError):
14551456
axts[0].imshow(t, norm="foobar")
1457+
1458+
1459+
def test__resample_valid_output():
1460+
resample = functools.partial(mpl._image.resample, transform=Affine2D())
1461+
with pytest.raises(ValueError, match="must be a NumPy array"):
1462+
resample(np.zeros((9, 9)), None)
1463+
with pytest.raises(ValueError, match="different dimensionalities"):
1464+
resample(np.zeros((9, 9)), np.zeros((9, 9, 4)))
1465+
with pytest.raises(ValueError, match="must be RGBA"):
1466+
resample(np.zeros((9, 9, 4)), np.zeros((9, 9, 3)))
1467+
with pytest.raises(ValueError, match="Mismatched types"):
1468+
resample(np.zeros((9, 9), np.uint8), np.zeros((9, 9)))
1469+
with pytest.raises(ValueError, match="must be C-contiguous"):
1470+
resample(np.zeros((9, 9)), np.zeros((9, 9)).T)

src/_image_wrapper.cpp

Lines changed: 76 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
* */
1010

1111
const char* image_resample__doc__ =
12-
"resample(input_array, output_array, matrix, interpolation=NEAREST, alpha=1.0, norm=False, radius=1)\n"
12+
"resample(input_array, output_array, transform, interpolation=NEAREST, alpha=1.0, norm=False, radius=1)\n"
1313
"--\n\n"
1414

1515
"Resample input_array, blending it in-place into output_array, using an\n"
@@ -121,14 +121,15 @@ resample(PyArrayObject* input, PyArrayObject* output, resample_params_t params)
121121
static PyObject *
122122
image_resample(PyObject *self, PyObject* args, PyObject *kwargs)
123123
{
124-
PyObject *py_input_array = NULL;
125-
PyObject *py_output_array = NULL;
124+
PyObject *py_input = NULL;
125+
PyObject *py_output = NULL;
126126
PyObject *py_transform = NULL;
127127
resample_params_t params;
128128

129-
PyArrayObject *input_array = NULL;
130-
PyArrayObject *output_array = NULL;
131-
PyArrayObject *transform_mesh_array = NULL;
129+
PyArrayObject *input = NULL;
130+
PyArrayObject *output = NULL;
131+
PyArrayObject *transform_mesh = NULL;
132+
int ndim;
132133

133134
params.interpolation = NEAREST;
134135
params.transform_mesh = NULL;
@@ -143,36 +144,52 @@ image_resample(PyObject *self, PyObject* args, PyObject *kwargs)
143144

144145
if (!PyArg_ParseTupleAndKeywords(
145146
args, kwargs, "OOO|iO&dO&d:resample", (char **)kwlist,
146-
&py_input_array, &py_output_array, &py_transform,
147+
&py_input, &py_output, &py_transform,
147148
&params.interpolation, &convert_bool, &params.resample,
148149
&params.alpha, &convert_bool, &params.norm, &params.radius)) {
149150
return NULL;
150151
}
151152

152153
if (params.interpolation < 0 || params.interpolation >= _n_interpolation) {
153-
PyErr_Format(PyExc_ValueError, "invalid interpolation value %d",
154+
PyErr_Format(PyExc_ValueError, "Invalid interpolation value %d",
154155
params.interpolation);
155156
goto error;
156157
}
157158

158-
input_array = (PyArrayObject *)PyArray_FromAny(
159-
py_input_array, NULL, 2, 3, NPY_ARRAY_C_CONTIGUOUS, NULL);
160-
if (input_array == NULL) {
159+
input = (PyArrayObject *)PyArray_FromAny(
160+
py_input, NULL, 2, 3, NPY_ARRAY_C_CONTIGUOUS, NULL);
161+
if (!input) {
161162
goto error;
162163
}
164+
ndim = PyArray_NDIM(input);
163165

164-
if (!PyArray_Check(py_output_array)) {
165-
PyErr_SetString(PyExc_ValueError, "output array must be a NumPy array");
166+
if (!PyArray_Check(py_output)) {
167+
PyErr_SetString(PyExc_ValueError, "Output array must be a NumPy array");
166168
goto error;
167169
}
168-
output_array = (PyArrayObject *)py_output_array;
169-
if (!PyArray_IS_C_CONTIGUOUS(output_array)) {
170-
PyErr_SetString(PyExc_ValueError, "output array must be C-contiguous");
170+
output = (PyArrayObject *)py_output;
171+
if (PyArray_NDIM(output) != ndim) {
172+
PyErr_Format(
173+
PyExc_ValueError,
174+
"Input (%dD) and output (%dD) have different dimensionalities.",
175+
ndim, PyArray_NDIM(output));
176+
goto error;
177+
}
178+
// PyArray_FromAny above checks that input is 2D or 3D.
179+
if (ndim == 3 && (PyArray_DIM(input, 2) != 4 || PyArray_DIM(output, 2) != 4)) {
180+
PyErr_Format(
181+
PyExc_ValueError,
182+
"If 3D, input and output arrays must be RGBA with shape (M, N, 4); "
183+
"got trailing dimensions of %" NPY_INTP_FMT " and %" NPY_INTP_FMT
184+
" respectively", PyArray_DIM(input, 2), PyArray_DIM(output, 2));
171185
goto error;
172186
}
173-
if (PyArray_NDIM(output_array) < 2 || PyArray_NDIM(output_array) > 3) {
174-
PyErr_SetString(PyExc_ValueError,
175-
"output array must be 2- or 3-dimensional");
187+
if (PyArray_TYPE(input) != PyArray_TYPE(output)) {
188+
PyErr_SetString(PyExc_ValueError, "Mismatched types");
189+
goto error;
190+
}
191+
if (!PyArray_IS_C_CONTIGUOUS(output)) {
192+
PyErr_SetString(PyExc_ValueError, "Output array must be C-contiguous");
176193
goto error;
177194
}
178195

@@ -182,7 +199,7 @@ image_resample(PyObject *self, PyObject* args, PyObject *kwargs)
182199
PyObject *py_is_affine;
183200
int py_is_affine2;
184201
py_is_affine = PyObject_GetAttrString(py_transform, "is_affine");
185-
if (py_is_affine == NULL) {
202+
if (!py_is_affine) {
186203
goto error;
187204
}
188205

@@ -197,96 +214,69 @@ image_resample(PyObject *self, PyObject* args, PyObject *kwargs)
197214
}
198215
params.is_affine = true;
199216
} else {
200-
transform_mesh_array = _get_transform_mesh(
201-
py_transform, PyArray_DIMS(output_array));
202-
if (transform_mesh_array == NULL) {
217+
transform_mesh = _get_transform_mesh(
218+
py_transform, PyArray_DIMS(output));
219+
if (!transform_mesh) {
203220
goto error;
204221
}
205-
params.transform_mesh = (double *)PyArray_DATA(transform_mesh_array);
222+
params.transform_mesh = (double *)PyArray_DATA(transform_mesh);
206223
params.is_affine = false;
207224
}
208225
}
209226

210-
if (PyArray_NDIM(input_array) != PyArray_NDIM(output_array)) {
211-
PyErr_Format(
212-
PyExc_ValueError,
213-
"Mismatched number of dimensions. Got %d and %d.",
214-
PyArray_NDIM(input_array), PyArray_NDIM(output_array));
215-
goto error;
216-
}
217-
218-
if (PyArray_TYPE(input_array) != PyArray_TYPE(output_array)) {
219-
PyErr_SetString(PyExc_ValueError, "Mismatched types");
220-
goto error;
221-
}
222-
223-
if (PyArray_NDIM(input_array) == 3) {
224-
if (PyArray_DIM(output_array, 2) != 4) {
227+
if (ndim == 3) {
228+
switch (PyArray_TYPE(input)) {
229+
case NPY_UINT8:
230+
case NPY_INT8:
231+
resample<agg::rgba8>(input, output, params);
232+
break;
233+
case NPY_UINT16:
234+
case NPY_INT16:
235+
resample<agg::rgba16>(input, output, params);
236+
break;
237+
case NPY_FLOAT32:
238+
resample<agg::rgba32>(input, output, params);
239+
break;
240+
case NPY_FLOAT64:
241+
resample<agg::rgba64>(input, output, params);
242+
break;
243+
default:
225244
PyErr_SetString(
226245
PyExc_ValueError,
227-
"Output array must be RGBA");
228-
goto error;
229-
}
230-
231-
if (PyArray_DIM(input_array, 2) == 4) {
232-
switch (PyArray_TYPE(input_array)) {
233-
case NPY_UINT8:
234-
case NPY_INT8:
235-
resample<agg::rgba8>(input_array, output_array, params);
236-
break;
237-
case NPY_UINT16:
238-
case NPY_INT16:
239-
resample<agg::rgba16>(input_array, output_array, params);
240-
break;
241-
case NPY_FLOAT32:
242-
resample<agg::rgba32>(input_array, output_array, params);
243-
break;
244-
case NPY_FLOAT64:
245-
resample<agg::rgba64>(input_array, output_array, params);
246-
break;
247-
default:
248-
PyErr_SetString(
249-
PyExc_ValueError,
250-
"3-dimensional arrays must be of dtype unsigned byte, "
251-
"unsigned short, float32 or float64");
252-
goto error;
253-
}
254-
} else {
255-
PyErr_Format(
256-
PyExc_ValueError,
257-
"If 3-dimensional, array must be RGBA. Got %" NPY_INTP_FMT " planes.",
258-
PyArray_DIM(input_array, 2));
246+
"arrays must be of dtype byte, short, float32 or float64");
259247
goto error;
260248
}
261-
} else { // NDIM == 2
262-
switch (PyArray_TYPE(input_array)) {
263-
case NPY_DOUBLE:
264-
resample<double>(input_array, output_array, params);
265-
break;
266-
case NPY_FLOAT:
267-
resample<float>(input_array, output_array, params);
268-
break;
249+
} else { // ndim == 2
250+
switch (PyArray_TYPE(input)) {
269251
case NPY_UINT8:
270252
case NPY_INT8:
271-
resample<unsigned char>(input_array, output_array, params);
253+
resample<unsigned char>(input, output, params);
272254
break;
273255
case NPY_UINT16:
274256
case NPY_INT16:
275-
resample<unsigned short>(input_array, output_array, params);
257+
resample<unsigned short>(input, output, params);
258+
break;
259+
case NPY_FLOAT32:
260+
resample<float>(input, output, params);
261+
break;
262+
case NPY_FLOAT64:
263+
resample<double>(input, output, params);
276264
break;
277265
default:
278-
PyErr_SetString(PyExc_ValueError, "Unsupported dtype");
266+
PyErr_SetString(
267+
PyExc_ValueError,
268+
"arrays must be of dtype byte, short, float32 or float64");
279269
goto error;
280270
}
281271
}
282272

283-
Py_DECREF(input_array);
284-
Py_XDECREF(transform_mesh_array);
273+
Py_DECREF(input);
274+
Py_XDECREF(transform_mesh);
285275
Py_RETURN_NONE;
286276

287277
error:
288-
Py_XDECREF(input_array);
289-
Py_XDECREF(transform_mesh_array);
278+
Py_XDECREF(input);
279+
Py_XDECREF(transform_mesh);
290280
return NULL;
291281
}
292282

0 commit comments

Comments
 (0)
0