10000 Merge pull request #4214 from ogrisel/fix-empty-input-data · scikit-learn/scikit-learn@d0b9955 · GitHub
[go: up one dir, main page]

Skip to content

Commit d0b9955

Browse files
committed
Merge pull request #4214 from ogrisel/fix-empty-input-data
[MRG+2] add validation for non-empty input data
2 parents cd931fa + d16e9ee commit d0b9955

File tree

3 files changed

+108
-15
lines changed

3 files changed

+108
-15
lines changed

sklearn/dummy.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -435,7 +435,7 @@ def fit(self, X, y, sample_weight=None):
435435

436436
self.constant = check_array(self.constant,
437437
accept_sparse=['csr', 'csc', 'coo'],
438-
ensure_2d=False)
438+
ensure_2d=False, ensure_min_samples=0)
439439

440440
if self.output_2d_ and self.constant.shape[0] != y.shape[1]:
441441
raise ValueError(

sklearn/utils/tests/test_validation.py

Lines changed: 57 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from itertools import product
99

1010
from sklearn.utils import as_float_array, check_array, check_symmetric
11+
from sklearn.utils import check_X_y
1112

1213
from sklearn.utils.estimator_checks import NotAnArray
1314

@@ -19,12 +20,12 @@
1920
from sklearn.svm import SVR
2021

2122
from sklearn.datasets import make_blobs
22-
from sklearn.utils import as_float_array, check_array
23-
from sklearn.utils.estimator_checks import NotAnArray
2423
from sklearn.utils.validation import (
25-
NotFittedError,
26-
has_fit_parameter,
27-
check_is_fitted)
24+
NotFittedError,
25+
has_fit_parameter,
26+
check_is_fitted)
27+
28+
from sklearn.utils.testing import assert_raise_message
2829

2930

3031
def test_as_float_array():
@@ -177,7 +178,7 @@ def test_check_array():
177178
Xs = [X_csc, X_coo, X_dok, X_int, X_float]
178179
accept_sparses = [['csr', 'coo'], ['coo', 'dok']]
179180
for X, dtype, accept_sparse, copy in product(Xs, dtypes, accept_sparses,
180-
copys):
181+
copys):
181182
X_checked = check_array(X, dtype=dtype, accept_sparse=accept_sparse,
182183
copy=copy)
183184
if dtype is not None:
@@ -210,6 +211,55 @@ def test_check_array():
210211
assert_true(isinstance(result, np.ndarray))
211212

212213

214+
def test_check_array_min_samples_and_features_messages():
215+
# empty list is considered 2D by default:
216+
msg = "0 feature(s) (shape=(1, 0)) while a minimum of 1 is required."
217+
assert_raise_message(ValueError, msg, check_array, [])
218+
219+
# If considered a 1D collection when ensure_2d=False, then the minimum
220+
# number of samples will break:
221+
msg = "0 sample(s) (shape=(0,)) while a minimum of 1 is required."
222+
assert_raise_message(ValueError, msg, check_array, [], ensure_2d=False)
223+
224+
# Invalid edge case when checking the default minimum sample of a scalar
225+
msg = "Singleton array array(42) cannot be considered a valid collection."
226+
assert_raise_message(TypeError, msg, check_array, 42, ensure_2d=False)
227+
228+
# But this works if the input data is forced to look like a 2 array with
229+
# one sample and one feature:
230+
X_checked = check_array(42, ensure_2d=True)
231+
assert_array_equal(np.array([[42]]), X_checked)
232+
233+
# Simulate a model that would need at least 2 samples to be well defined
234+
X = np.ones((1, 10))
235+
y = np.ones(1)
236+
msg = "1 sample(s) (shape=(1, 10)) while a minimum of 2 is required."
237+
assert_raise_message(ValueError, msg, check_X_y, X, y,
238+
ensure_min_samples=2)
239+
240+
# Simulate a model that would require at least 3 features (e.g. SelectKBest
241+
# with k=3)
242+
X = np.ones((10, 2))
243+
y = np.ones(2)
244+
msg = "2 feature(s) (shape=(10, 2)) while a minimum of 3 is required."
245+
assert_raise_message(ValueError, msg, check_X_y, X, y,
246+
ensure_min_features=3)
247+
248+
# Simulate a case where a pipeline stage as trimmed all the features of a
249+
# 2D dataset.
250+
X = np.empty(0).reshape(10, 0)
251+
y = np.ones(10)
252+
msg = "0 feature(s) (shape=(10, 0)) while a minimum of 1 is required."
253+
assert_raise_message(ValueError, msg, check_X_y, X, y)
254+
255+
# nd-data is not checked for any minimum number of features by default:
256+
X = np.ones((10, 0, 28, 28))
257+
y = np.ones(10)
258+
X_checked, y_checked = check_X_y(X, y, allow_nd=True)
259+
assert_array_equal(X, X_checked)
260+
assert_array_equal(y, y_checked)
261+
262+
213263
def test_has_fit_parameter():
214264
assert_false(has_fit_parameter(KNeighborsClassifier, "sample_weight"))
215265
assert_true(has_fit_parameter(RandomForestRegressor, "sample_weight"))
@@ -274,6 +324,6 @@ def test_check_is_fitted():
274324

275325
ard.fit(*make_blobs())
276326
svr.fit(*make_blobs())
277-
327+
278328
assert_equal(None, check_is_fitted(ard, "coef_"))
279329
assert_equal(None, check_is_fitted(svr, "support_"))

sklearn/utils/validation.py

Lines changed: 50 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,13 @@ def _num_samples(x):
110110
x = np.asarray(x)
111111
else:
112112
raise TypeError("Expected sequence or array-like, got %r" % x)
113-
return x.shape[0] if hasattr(x, 'shape') else len(x)
113+
if hasattr(x, 'shape'):
114+
if len(x.shape) == 0:
115+
raise TypeError("Singleton array %r cannot be considered"
116+
" a valid collection." % x)
117+
return x.shape[0]
118+
else:
119+
return len(x)
114120

115121

116122
def check_consistent_length(*arrays):
@@ -222,10 +228,11 @@ def _ensure_sparse_format(spmatrix, accept_sparse, dtype, order, copy,
222228

223229

224230
def check_array(array, accept_sparse=None, dtype=None, order=None, copy=False,
225-
force_all_finite=True, ensure_2d=True, allow_nd=False):
231+
force_all_finite=True, ensure_2d=True, allow_nd=False,
232+
ensure_min_samples=1, ensure_min_features=1):
226233
"""Input validation on an array, list, sparse matrix or similar.
227234
228-
By default, the input is converted to an at least 2nd numpy array.
235+
By default, the input is converted to an at least 2d numpy array.
229236
230237
Parameters
231238
----------
@@ -257,6 +264,16 @@ def check_array(array, accept_sparse=None, dtype=None, order=None, copy=False,
257264
allow_nd : boolean (default=False)
258265
Whether to allow X.ndim > 2.
259266
267+
ensure_min_samples : int (default=1)
268+
Make sure that the array has a minimum number of samples in its first
269+
axis (rows for a 2D array). Setting to 0 disables this check.
270+
271+
ensure_min_features : int (default=1)
272+
Make sure that the 2D array has some minimum number of features
273+
(columns). The default value of 1 rejects empty datasets.
274+
This check is only enforced when ``ensure_2d`` is True and
275+
``allow_nd`` is False. Setting to 0 disables this check.
276+
260277
Returns
261278
-------
262279
X_converted : object
@@ -278,12 +295,26 @@ def check_array(array, accept_sparse=None, dtype=None, order=None, copy=False,
278295
if force_all_finite:
279296
_assert_all_finite(array)
280297

298+
if ensure_min_samples > 0:
299+
n_samples = _num_samples(array)
300+
if n_samples < ensure_min_samples:
301+
raise ValueError("Found array with %d sample(s) (shape=%r) while a"
302+
" minimum of %d is required."
303+
% (n_samples, array.shape, ensure_min_samples))
304+
305+
if ensure_min_features > 0 and ensure_2d and not allow_nd:
306+
n_features = array.shape[1]
307+
if n_features < ensure_min_features:
308+
raise ValueError("Found array with %d feature(s) (shape=%r) while"
309+
" a minimum of %d is required."
310+
% (n_features, array.shape, ensure_min_features))
281311
return array
282312

283313

284314
def check_X_y(X, y, accept_sparse=None, dtype=None, order=None, copy=False,
285315
force_all_finite=True, ensure_2d=True, allow_nd=False,
286-
multi_output=False):
316+
multi_output=False, ensure_min_samples=1,
317+
ensure_min_features=1):
287318
"""Input validation for standard estimators.
288319
289320
Checks X and y for consistent length, enforces X 2d and y 1d.
@@ -327,13 +358,24 @@ def check_X_y(X, y, accept_sparse=None, dtype=None, order=None, copy=False,
327358
Whether to allow 2-d y (array or sparse matrix). If false, y will be
328359
validated as a vector.
329360
361+
ensure_min_samples : int (default=1)
362+
Make sure that X has a minimum number of samples in its first
363+
axis (rows for a 2D array).
364+
365+
ensure_min_features : int (default=1)
366+
Make sure that the 2D X has some minimum number of features
367+
(columns). The default value of 1 rejects empty datasets.
368+
This check is only enforced when ``ensure_2d`` is True and
369+
``allow_nd`` is False.
370+
330371
Returns
331372
-------
332373
X_converted : object
333374
The converted and validated X.
334375
"""
335376
X = check_array(X, accept_sparse, dtype, order, copy, force_all_finite,
336-
ensure_2d, allow_nd)
377+
ensure_2d, allow_nd, ensure_min_samples,
378+
ensure_min_features)
337379
if multi_output:
338380
y = check_array(y, 'csr', force_all_finite=True, ensure_2d=False)
339381
else:
@@ -353,7 +395,7 @@ def column_or_1d(y, warn=False):
353395
y : array-like
354396
355397
warn : boolean, default False
356-
To control display of warnings.
398+
To control display of warnings.
357399
358400
Returns
359401
-------
@@ -406,6 +448,7 @@ def check_random_state(seed):
406448
raise ValueError('%r cannot be used to seed a numpy.random.RandomState'
407449
' instance' % seed)
408450

451+
409452
def has_fit_parameter(estimator, parameter):
410453
"""Checks whether the estimator's fit method supports the given parameter.
411454
@@ -512,4 +555,4 @@ def check_is_fitted(estimator, attributes, msg=None, all_or_any=all):
512555
attributes = [attributes]
513556

514557
if not all_or_any([hasattr(estimator, attr) for attr in attributes]):
515-
raise NotFittedError(msg % {'name' : type(estimator).__name__})
558+
raise NotFittedError(msg % {'name': type(estimator).__name__})

0 commit comments

Comments
 (0)
0