8000 FIX test for consistent handling on empty input data · scikit-learn/scikit-learn@feceab6 · GitHub
[go: up one dir, main page]

Skip to content

Commit feceab6

Browse files
committed
FIX test for consistent handling on empty input data
1 parent 2a7a213 commit feceab6

File tree

9 files changed

+64
-17
lines changed

9 files changed

+64
-17
lines changed

doc/whats_new.rst

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -401,7 +401,9 @@ API changes summary
401401
- Estimators will treat input with dtype object as numeric when possible.
402402
By `Andreas Müller`_
403403

404-
404+
- Estimators now raise `ValueError` consistently when fitted on empty
405+
data (less than 1 sample or less than 1 feature for 2D input).
406+
By `Olivier Grisel`_.
405407

406408
.. _changes_0_15_2:
407409

sklearn/ensemble/forest.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -191,10 +191,8 @@ def fit(self, X, y, sample_weight=None):
191191
self : object
192192
Returns self.
193193
"""
194-
# Convert data
195-
# ensure_2d=False because there are actually unit test checking we fail
196-
# for 1d. FIXME make this consistent in the future.
197-
X = check_array(X, dtype=DTYPE, ensure_2d=False, accept_sparse="csc")
194+
# Validate or convert input data
195+
X = check_array(X, dtype=DTYPE, accept_sparse="csc")
198196
if issparse(X):
199197
# Pre-sort indices to avoid that each individual tree of the
200198
# ensemble sorts the indices.
@@ -207,7 +205,7 @@ def fit(self, X, y, sample_weight=None):
207205
if y.ndim == 2 and y.shape[1] == 1:
208206
warn("A column-vector y was passed when a 1d array was"
209207
" expected. Please change the shape of y to "
210-
"(n_samples, ), for example using ravel().",
208+
"(n_samples,), for example using ravel().",
211209
DataConversionWarning, stacklevel=2)
212210

213211
if y.ndim == 1:

sklearn/kernel_approximation.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -437,10 +437,8 @@ def fit(self, X, y=None):
437437
X : array-like, shape=(n_samples, n_features)
438438
Training data.
439439
"""
440-
440+
X = check_array(X, accept_sparse='csr')
441441
rnd = check_random_state(self.random_state)
442-
if not sp.issparse(X):
443-
X = np.asarray(X)
444442
n_samples = X.shape[0]
445443

446444
# get basis vectors
@@ -487,6 +485,7 @@ def transform(self, X):
487485
Transformed data.
488486
"""
489487
check_is_fitted(self, 'components_')
488+
X = check_array(X, accept_sparse='csr')
490489

491490
kernel_params = self._get_kernel_params()
492491
embedded = pairwise_kernels(X, self.components_,

sklearn/linear_model/coordinate_descent.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -984,6 +984,8 @@ def fit(self, X, y):
984984
Target values
985985
"""
986986
y = np.asarray(y, dtype=np.float64)
987+
if y.shape[0] == 0:
988+
raise ValueError("y has 0 samples: %r" % y)
987989

988990
if hasattr(self, 'l1_ratio'):
989991
model_str = 'ElasticNet'

sklearn/preprocessing/label.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from ..utils.fixes import in1d
2323
from ..utils import deprecated, column_or_1d
2424
from ..utils.validation import check_array
25+
from ..utils.validation import _num_samples
2526
from ..utils.multiclass import unique_labels
2627
from ..utils.multiclass import type_of_target
2728

@@ -315,6 +316,8 @@ def fit(self, y):
315316
if 'multioutput' in self.y_type_:
316317
raise ValueError("Multioutput target data is not supported with "
317318
"label binarization")
319+
if _num_samples(y) == 0:
320+
raise ValueError('y has 0 samples: %r' % y)
318321

319322
self.sparse_input_ = sp.issparse(y)
320323
self.classes_ = unique_labels(y)
@@ -465,6 +468,9 @@ def label_binarize(y, classes, neg_label=0, pos_label=1,
465468
# XXX Workaround that will be removed when list of list format is
466469
# dropped
467470
y = check_array(y, accept_sparse='csr', ensure_2d=False, dtype=None)
471+
else:
472+
if _num_samples(y) == 0:
473+
raise ValueError('y has 0 samples: %r' % y)
468474
if neg_label >= pos_label:
469475
raise ValueError("neg_label={0} must be strictly less than "
470476
"pos_label={1}.".format(neg_label, pos_label))

sklearn/tests/test_common.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
check_regressors_pickle,
4242
check_transformer_pickle,
4343
check_transformers_unfitted,
44+
check_estimators_empty_data_messages,
4445
check_estimators_nan_inf,
4546
check_estimators_unfitted,
4647
check_classifiers_one_label,
@@ -99,6 +100,10 @@ def test_non_meta_estimators():
99100
yield check_fit_score_takes_y, name, Estimator
100101
yield check_dtype_object, name, Estimator
101102

103+
# Check that all estimators yield informative messages when
104+
# trained on empty datasets
105+
yield check_estimators_empty_data_messages, name, Estimator
106+
102107
if name not in CROSS_DECOMPOSITION + ['SpectralEmbedding']:
103108
# SpectralEmbedding is non-deterministic,
104109
# see issue #4236

sklearn/utils/estimator_checks.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
from sklearn.externals.six.moves import zip
1414
from sklearn.utils.testing import assert_raises
15+
from sklearn.utils.testing import assert_raise_message
1516
from sklearn.utils.testing import assert_equal
1617
from sklearn.utils.testing import assert_true
1718
from sklearn.utils.testing import assert_false
@@ -346,6 +347,24 @@ def check_estimators_dtypes(name, Estimator):
346347
pass
347348

348349

350+
def check_estimators_empty_data_messages(name, Estimator):
351+
e = Estimator()
352+
set_fast_parameters(e)
353+
set_random_state(e, 1)
354+
355+
X_zero_samples = np.empty(0).reshape(0, 3)
356+
# The precise message can change depending on whether X or y is
357+
# validated first. Let us test the type of exception only:
358+
assert_raises(ValueError, e.fit, X_zero_samples, [])
359+
360+
X_zero_features = np.empty(0).reshape(3, 0)
361+
# the following y should be accepted by both classifiers and regressors
362+
# and ignored by unsupervised models
363+
y = multioutput_estimator_convert_y_2d(name, np.array([1, 0, 1]))
364+
msg = "0 feature(s) (shape=(3, 0)) while a minimum of 1 is required."
365+
assert_raise_message(ValueError, msg, e.fit, X_zero_features, y)
366+
367+
349368
def check_estimators_nan_inf(name, Estimator):
350369
rnd = np.random.RandomState(0)
351370
X_train_finite = rnd.uniform(size=(10, 3))

sklearn/utils/tests/test_validation.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -237,6 +237,11 @@ def test_check_array_min_samples_and_features_messages():
237237
assert_raise_message(ValueError, msg, check_X_y, X, y,
238238
ensure_min_samples=2)
239239

240+
# The same message is raised if the data has 2 dimensions even if this is
241+
# not mandatory
242+
assert_raise_message(ValueError, msg, check_X_y, X, y,
243+
ensure_min_samples=2, ensure_2d=False)
244+
240245
# Simulate a model that would require at least 3 features (e.g. SelectKBest
241246
# with k=3)
242247
X = np.ones((10, 2))
@@ -245,6 +250,11 @@ def test_check_array_min_samples_and_features_messages():
245250
assert_raise_message(ValueError, msg, check_X_y, X, y,
246251
ensure_min_features=3)
247252

253+
# Only the feature check is enabled whenever the number of dimensions is 2
254+
# even if allow_nd is enabled:
255+
assert_raise_message(ValueError, msg, check_X_y, X, y,
256+
ensure_min_features=3, allow_nd=True)
257+
248258
# Simulate a case where a pipeline stage has trimmed all the features of a
249259
# 2D dataset.
250260
X = np.empty(0).reshape(10, 0)

sklearn/utils/validation.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -307,8 +307,9 @@ def check_array(array, accept_sparse=None, dtype="numeric", order=None, copy=Fal
307307
ensure_min_features : int (default=1)
308308
Make sure that the 2D array has some minimum number of features
309309
(columns). The default value of 1 rejects empty datasets.
310-
This check is only enforced when ``ensure_2d`` is True and
311-
``allow_nd`` is False. Setting to 0 disables this check.
310+
This check is only enforced when the input data has effectively 2
311+
dimensions or is originally 1D and ``ensure_2d`` is True. Setting to 0
312+
disables this check.
312313
313314
Returns
314315
-------
@@ -347,7 +348,8 @@ def check_array(array, accept_sparse=None, dtype="numeric", order=None, copy=Fal
347348
" minimum of %d is required."
348349
% (n_samples, shape_repr, ensure_min_samples))
349350

350-
if ensure_min_features > 0 and ensure_2d and not allow_nd:
351+
352+
if ensure_min_features > 0 and array.ndim == 2:
351353
n_features = array.shape[1]
352354
if n_features < ensure_min_features:
353355
raise ValueError("Found array with %d feature(s) (shape=%s) while"
@@ -411,13 +413,16 @@ def check_X_y(X, y, accept_sparse=None, dtype="numeric", order=None, copy=False,
411413
axis (rows for a 2D array).
412414
413415
ensure_min_features : int (default=1)
414-
Make sure that the 2D X has some minimum number of features
416+
Make sure that the 2D array has some minimum number of features
415417
(columns). The default value of 1 rejects empty datasets.
416-
This check is only enforced when ``ensure_2d`` is True and
417-
``allow_nd`` is False.
418+
This check is only enforced when X has effectively 2 dimensions or
419+
is originally 1D and ``ensure_2d`` is True. Setting to 0 disables
420+
this check.
421+
418422
y_numeric : boolean (default=False)
419423
Whether to ensure that y has a numeric type. If dtype of y is object,
420-
it is converted to float64. Should only be used for regression algorithms.
424+
it is converted to float64. Should only be used for regression
425+
algorithms.
421426
422427
Returns
423428
-------
@@ -428,7 +433,8 @@ def check_X_y(X, y, accept_sparse=None, dtype="numeric", order=None, copy=False,
428433
ensure_2d, allow_nd, ensure_min_samples,
429434
ensure_min_features)
430435
if multi_output:
431-
y = check_array(y, 'csr', force_all_finite=True, ensure_2d=False, dtype=None)
436+
y = check_array(y, 'csr', force_all_finite=True, ensure_2d=False,
437+
dtype=None)
432438
else:
433439
y = column_or_1d(y, warn=True)
434440
_assert_all_finite(y)

0 commit comments

Comments
 (0)
0