8000 TST common test for empty input data · scikit-learn/scikit-learn@25898de · GitHub
[go: up one dir, main page]

Skip to content

Commit 25898de

Browse files
committed
TST common test for empty input data
1 parent 5f95154 commit 25898de

File tree

2 files changed

+23
-0
lines changed

2 files changed

+23
-0
lines changed

sklearn/tests/test_common.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
check_regressors_pickle,
4141
check_transformer_pickle,
4242
check_transformers_unfitted,
43+
check_estimators_empty_data_messages,
4344
check_estimators_nan_inf,
4445
check_estimators_unfitted,
4546
check_classifiers_one_label,
@@ -103,6 +104,10 @@ def test_non_meta_estimators():
103104
yield check_pipeline_consistency, name, Estimator
104105

105106
if name not in CROSS_DECOMPOSITION + ['Imputer']:
107+
# Check that all estimator yield informative messages when
108+
# trained on empty datasets
109+
yield check_estimators_empty_data_messages, name, Estimator
110+
106111
# Test that all estimators check their input for NaN's and infs
107112
yield check_estimators_nan_inf, name, Estimator
108113

sklearn/utils/estimator_checks.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
from sklearn.externals.six.moves import zip
1414
from sklearn.utils.testing import assert_raises
15+
from sklearn.utils.testing import assert_raise_message
1516
from sklearn.utils.testing import assert_equal
1617
from sklearn.utils.testing import assert_true
1718
from sklearn.utils.testing import assert_false
@@ -318,6 +319,23 @@ def check_estimators_dtypes(name, Estimator):
318319
pass
319320

320321

322+
def check_estimators_empty_data_messages(name, Estimator):
323+
with warnings.catch_warnings(record=True):
324+
e = Estimator()
325+
set_fast_parameters(e)
326+
set_random_state(e, 1)
327+
328+
X_zero_samples = np.empty(0).reshape(0, 3)
329+
y = []
330+
msg = "0 sample(s) (shape=(0, 3)) while a minimum of 1 is required."
331+
assert_raise_message(ValueError, msg, e.fit, X_zero_samples, y)
332+
333+
X_zero_features = np.empty(0).reshape(3, 0)
334+
y = [1, 2, 3]
335+
msg = "0 feature(s) (shape=(3, 0)) while a minimum of 1 is required."
336+
assert_raise_message(ValueError, msg, e.fit, X_zero_features, y)
337+
338+
321339
def check_estimators_nan_inf(name, Estimator):
322340
rnd = np.random.RandomState(0)
323341
X_train_finite = rnd.uniform(size=(10, 3))

0 commit comments

Comments
 (0)
0