8000 Merge pull request #3443 from amueller/input_validation_refactoring · scikit-learn/scikit-learn@8dab222 · GitHub
[go: up one dir, main page]

Skip to content

Commit 8dab222

Browse files
committed
Merge pull request #3443 from amueller/input_validation_refactoring
[MRG] Input validation refactoring
2 parents 41d02e0 + f7549fd commit 8dab222

File tree

4 files changed

+266
-101
lines changed

4 files changed

+266
-101
lines changed

sklearn/feature_extraction/image.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from scipy import sparse
1616
from numpy.lib.stride_tricks import as_strided
1717

18-
from ..utils import array2d, check_random_state
18+
from ..utils import check_array, check_random_state
1919
from ..utils.fixes import astype
2020
from ..base import BaseEstimator
2121

@@ -349,7 +349,7 @@ def extract_patches_2d(image, patch_size, max_patches=None, random_state=None):
349349
i_h, i_w = image.shape[:2]
350350
p_h, p_w = patch_size
351351

352-
image = array2d(image)
352+
image = check_array(image, allow_nd=True)
353353
image = image.reshape((i_h, i_w, -1))
354354
n_colors = image.shape[-1]
355355

sklearn/utils/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,13 @@
1111
from .validation import (as_float_array, check_arrays, safe_asarray,
1212
assert_all_finite, array2d, atleast2d_or_csc,
1313
atleast2d_or_csr, warn_if_not_float,
14-
check_random_state, column_or_1d)
14+
check_random_state, column_or_1d, check_array)
1515
from .class_weight import compute_class_weight
1616
from sklearn.utils.sparsetools import minimum_spanning_tree
1717

1818

1919
__all__ = ["murmurhash3_32", "as_float_array", "check_arrays", "safe_asarray",
20-
"assert_all_finite", "array2d", "atleast2d_or_csc",
20+
"assert_all_finite", "array2d", "atleast2d_or_csc", "check_array",
2121
"atleast2d_or_csr",
2222
"warn_if_not_float",
2323
"check_random_state",

sklearn/utils/tests/test_validation.py

Lines changed: 105 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,13 @@
55
from numpy.testing import assert_array_equal
66
import scipy.sparse as sp
77
from nose.tools import assert_raises, assert_true, assert_false, assert_equal
8+
from itertools import product
89

910
from sklearn.utils import (array2d, as_float_array, atleast2d_or_csr,
10-
atleast2d_or_csc, check_arrays, safe_asarray)
11+
atleast2d_or_csc, check_arrays, safe_asarray,
12+
check_array)
13+
14+
from sklearn.utils.estimator_checks import NotAnArray
1115

1216
from sklearn.random_projection import sparse_random_matrix
1317

@@ -223,3 +227,103 @@ def test_check_arrays():
223227
# check that lists are passed through if force_arrays is true
224228
X_, Y_ = check_arrays(X, Y, force_arrays=False)
225229
assert_true(isinstance(X_, list))
230+
231+
232+
def test_check_array():
233+
# allowed_sparse == None
234+
# raise error on sparse inputs
235+
X = [[1, 2], [3, 4]]
236+
X_csr = sp.csr_matrix(X)
237+
assert_raises(TypeError, check_array, X_csr)
238+
# ensure_2d
239+
X_array = check_array([0, 1, 2])
240+
assert_equal(X_array.ndim, 2)
241+
X_array = check_array([0, 1, 2], ensure_2d=False)
242+
assert_equal(X_array.ndim, 1)
243+
# don't allow ndim > 3
244+
X_ndim = np.arange(8).reshape(2, 2, 2)
245+
assert_raises(ValueError, check_array, X_ndim)
246+
check_array(X_ndim, allow_nd=True) # doesn't raise
247+
# force_all_finite
248+
X_inf = np.arange(4).reshape(2, 2).astype(np.float)
249+
X_inf[0, 0] = np.inf
250+
assert_raises(ValueError, check_array, X_inf)
251+
check_array(X_inf, force_all_finite=False) # no raise
252+
# nan check
253+
X_nan = np.arange(4).reshape(2, 2).astype(np.float)
254+
X_nan[0, 0] = np.nan
255+
assert_raises(ValueError, check_array, X_nan)
256+
check_array(X_inf, force_all_finite=False) # no raise
257+
258+
# dtype and order enforcement.
259+
X_C = np.arange(4).reshape(2, 2).copy("C")
260+
X_F = X_C.copy("F")
261+
X_int = X_C.astype(np.int)
262+
X_float = X_C.astype(np.float)
263+
Xs = [X_C, X_F, X_int, X_float]
264+
dtypes = [np.int32, np.int, np.float, np.float32, None, np.bool, object]
265+
orders = ['C', 'F', None]
266+
copys = [True, False]
267+
268+
for X, dtype, order, copy in product(Xs, dtypes, orders, copys):
269+
X_checked = check_array(X, dtype=dtype, order=order, copy=copy)
270+
if dtype is not None:
271+
assert_equal(X_checked.dtype, dtype)
272+
else:
273+
assert_equal(X_checked.dtype, X.dtype)
274+
if order == 'C':
275+
assert_true(X_checked.flags['C_CONTIGUOUS'])
276+
assert_false(X_checked.flags['F_CONTIGUOUS'])
277+
elif order == 'F':
278+
assert_true(X_checked.flags['F_CONTIGUOUS'])
279+
assert_false(X_checked.flags['C_CONTIGUOUS'])
280+
if copy:
281+
assert_false(X is X_checked)
282+
else:
283+
# doesn't copy if it was already good
284+
if (X.dtype == X_checked.dtype and
285+
X_checked.flags['C_CONTIGUOUS'] == X.flags['C_CONTIGUOUS']
286+
and X_checked.flags['F_CONTIGUOUS'] == X.flags['F_CONTIGUOUS']):
287+
assert_true(X is X_checked)
288+
289+
# allowed sparse != None
290+
X_csc = sp.csc_matrix(X_C)
291+
X_coo = X_csc.tocoo()
292+
X_dok = X_csc.todok()
293+
X_int = X_csc.astype(np.int)
294+
X_float = X_csc.astype(np.float)
295+
296+
Xs = [X_csc, X_coo, X_dok, X_int, X_float]
297+
allowed_sparses = [['csr', 'coo'], ['coo', 'dok']]
298+
for X, dtype, allowed_sparse, copy in product(Xs, dtypes, allowed_sparses,
299+
copys):
300+
X_checked = check_array(X, dtype=dtype, allowed_sparse=allowed_sparse,
301+
copy=copy)
302+
if dtype is not None:
303+
assert_equal(X_checked.dtype, dtype)
304+
else:
305+
assert_equal(X_checked.dtype, X.dtype)
306+
if X.format in allowed_sparse:
307+
# no change if allowed
308+
assert_equal(X.format, X_checked.format)
309+
else:
310+
# got converted
311+
assert_equal(X_checked.format, allowed_sparse[0])
312+
if copy:
313+
assert_false(X is X_checked)
314+
else:
315+
# doesn't copy if it was already good
316+
if (X.dtype == X_checked.dtype and X.format == X_checked.format):
317+
assert_true(X is X_checked)
318+
319+
# other input formats
320+
# convert lists to arrays
321+
X_dense = check_array([[1, 2], [3, 4]])
322+
assert_true(isinstance(X_dense, np.ndarray))
323+
# raise on too deep lists
324+
assert_raises(ValueError, check_array, X_ndim.tolist())
325+
check_array(X_ndim.tolist(), allow_nd=True) # doesn't raise
326+
# convert weird stuff to arrays
327+
X_no_array = NotAnArray(X_dense)
328+
result = check_array(X_no_array)
329+
assert_true(isinstance(result, np.ndarray))

0 commit comments

Comments
 (0)
0