-
-
Notifications
You must be signed in to change notification settings - Fork 25.9k
[MRG] Input validation refactoring #3443
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -11,13 +11,13 @@ | |
from .validation import (as_float_array, check_arrays, safe_asarray, | ||
assert_all_finite, array2d, atleast2d_or_csc, | ||
atleast2d_or_csr, warn_if_not_float, | ||
check_random_state, column_or_1d) | ||
check_random_state, column_or_1d, check_array) | ||
from .class_weight import compute_class_weight | ||
from sklearn.utils.sparsetools import minimum_spanning_tree | ||
|
||
|
||
__all__ = ["murmurhash3_32", "as_float_array", "check_arrays", "safe_asarray", | ||
"assert_all_finite", "array2d", "atleast2d_or_csc", | ||
"assert_all_finite", "array2d", "atleast2d_or_csc", "check_array", | ||
"atleast2d_or_csr", | ||
"warn_if_not_float", | ||
"check_random_state", | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Can you remove deprecated stuff? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Not in this PR, but in the next PR which will touch all files. |
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,9 +5,13 @@ | |
from numpy.testing import assert_array_equal | ||
import scipy.sparse as sp | ||
from nose.tools import assert_raises, assert_true, assert_false, assert_equal | ||
from itertools import product | ||
|
||
from sklearn.utils import (array2d, as_float_array, atleast2d_or_csr, | ||
atleast2d_or_csc, check_arrays, safe_asarray) | ||
atleast2d_or_csc, check_arrays, safe_asarray, | ||
check_array) | ||
|
||
from sklearn.utils.estimator_checks import NotAnArray | ||
|
||
from sklearn.random_projection import sparse_random_matrix | ||
|
||
|
@@ -223,3 +227,103 @@ def test_check_arrays(): | |
# check that lists are passed through if force_arrays is false | ||
X_, Y_ = check_arrays(X, Y, force_arrays=False) | ||
assert_true(isinstance(X_, list)) | ||
|
||
|
||
def test_check_array():
    """Exercise ``check_array`` on dense, sparse and list-like inputs.

    The cases below assert that:

    * sparse input raises ``TypeError`` when no sparse format is allowed;
    * 1-D input is promoted to 2-D unless ``ensure_2d=False``;
    * 3-D input raises ``ValueError`` unless ``allow_nd=True``;
    * ``inf`` / ``nan`` raise ``ValueError`` unless
      ``force_all_finite=False``;
    * the requested ``dtype`` / ``order`` / ``copy`` are honoured for dense
      arrays, and ``dtype`` / ``allowed_sparse`` / ``copy`` for sparse ones;
    * nested lists and array-like objects come back as ``np.ndarray``.
    """
    # allowed_sparse == None
    # raise error on sparse inputs
    X = [[1, 2], [3, 4]]
    X_csr = sp.csr_matrix(X)
    assert_raises(TypeError, check_array, X_csr)

    # ensure_2d
    X_array = check_array([0, 1, 2])
    assert_equal(X_array.ndim, 2)
    X_array = check_array([0, 1, 2], ensure_2d=False)
    assert_equal(X_array.ndim, 1)

    # don't allow ndim > 2 unless allow_nd is set
    X_ndim = np.arange(8).reshape(2, 2, 2)
    assert_raises(ValueError, check_array, X_ndim)
    check_array(X_ndim, allow_nd=True)  # doesn't raise

    # force_all_finite
    # Builtin ``float`` is used instead of the removed ``np.float`` alias.
    X_inf = np.arange(4).reshape(2, 2).astype(float)
    X_inf[0, 0] = np.inf
    assert_raises(ValueError, check_array, X_inf)
    check_array(X_inf, force_all_finite=False)  # no raise

    # nan check
    X_nan = np.arange(4).reshape(2, 2).astype(float)
    X_nan[0, 0] = np.nan
    assert_raises(ValueError, check_array, X_nan)
    # Use X_nan here (not X_inf) so the NaN pass-through is actually tested.
    check_array(X_nan, force_all_finite=False)  # no raise

    # dtype and order enforcement.
    X_C = np.arange(4).reshape(2, 2).copy("C")
    X_F = X_C.copy("F")
    X_int = X_C.astype(int)
    X_float = X_C.astype(float)
    Xs = [X_C, X_F, X_int, X_float]
    dtypes = [np.int32, int, float, np.float32, None, bool, object]
    orders = ['C', 'F', None]
    copys = [True, False]

    for X, dtype, order, copy in product(Xs, dtypes, orders, copys):
        X_checked = check_array(X, dtype=dtype, order=order, copy=copy)
        if dtype is not None:
            assert_equal(X_checked.dtype, dtype)
        else:
            assert_equal(X_checked.dtype, X.dtype)
        if order == 'C':
            assert_true(X_checked.flags['C_CONTIGUOUS'])
            assert_false(X_checked.flags['F_CONTIGUOUS'])
        elif order == 'F':
            assert_true(X_checked.flags['F_CONTIGUOUS'])
            assert_false(X_checked.flags['C_CONTIGUOUS'])
        if copy:
            assert_false(X is X_checked)
        elif (X.dtype == X_checked.dtype and
              X_checked.flags['C_CONTIGUOUS'] == X.flags['C_CONTIGUOUS']
              and X_checked.flags['F_CONTIGUOUS'] == X.flags['F_CONTIGUOUS']):
            # doesn't copy if it was already good
            assert_true(X is X_checked)

    # allowed sparse != None
    X_csc = sp.csc_matrix(X_C)
    X_coo = X_csc.tocoo()
    X_dok = X_csc.todok()
    X_int = X_csc.astype(int)
    X_float = X_csc.astype(float)

    Xs = [X_csc, X_coo, X_dok, X_int, X_float]
    allowed_sparses = [['csr', 'coo'], ['coo', 'dok']]
    for X, dtype, allowed_sparse, copy in product(Xs, dtypes, allowed_sparses,
                                                  copys):
        X_checked = check_array(X, dtype=dtype, allowed_sparse=allowed_sparse,
                                copy=copy)
        if dtype is not None:
            assert_equal(X_checked.dtype, dtype)
        else:
            assert_equal(X_checked.dtype, X.dtype)
        if X.format in allowed_sparse:
            # no change if allowed
            assert_equal(X.format, X_checked.format)
        else:
            # got converted to the first allowed format
            assert_equal(X_checked.format, allowed_sparse[0])
        if copy:
            assert_false(X is X_checked)
        elif X.dtype == X_checked.dtype and X.format == X_checked.format:
            # doesn't copy if it was already good
            assert_true(X is X_checked)

    # other input formats
    # convert lists to arrays
    X_dense = check_array([[1, 2], [3, 4]])
    assert_true(isinstance(X_dense, np.ndarray))

    # raise on too deep lists
    assert_raises(ValueError, check_array, X_ndim.tolist())
    check_array(X_ndim.tolist(), allow_nd=True)  # doesn't raise

    # convert weird stuff to arrays
    X_no_array = NotAnArray(X_dense)
    result = check_array(X_no_array)
    assert_true(isinstance(result, np.ndarray))
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hum, that change is really surprising to me: I would read the two lines (the one removed and the one added) as doing very different things. It's probably just a question of the choice of argument names.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
that is because the previous behavior was surprising ;)