|
12 | 12 |
|
13 | 13 | from sklearn.externals.six.moves import zip
|
14 | 14 | from sklearn.utils.testing import assert_raises
|
| 15 | +from sklearn.utils.testing import assert_raise_message |
15 | 16 | from sklearn.utils.testing import assert_equal
|
16 | 17 | from sklearn.utils.testing import assert_true
|
17 | 18 | from sklearn.utils.testing import assert_false
|
@@ -318,6 +319,23 @@ def check_estimators_dtypes(name, Estimator):
|
318 | 319 | pass
|
319 | 320 |
|
320 | 321 |
|
| 322 | +def check_estimators_empty_data_messages(name, Estimator): |
| 323 | + with warnings.catch_warnings(record=True): |
| 324 | + e = Estimator() |
| 325 | + set_fast_parameters(e) |
| 326 | + set_random_state(e, 1) |
| 327 | + |
| 328 | + X_zero_samples = np.empty(0).reshape(0, 3) |
| 329 | + y = [] |
| 330 | + msg = "0 sample(s) (shape=(0, 3)) while a minimum of 1 is required." |
| 331 | + assert_raise_message(ValueError, msg, e.fit, X_zero_samples, y) |
| 332 | + |
| 333 | + X_zero_features = np.empty(0).reshape(3, 0) |
| 334 | + y = [1, 2, 3] |
| 335 | + msg = "0 feature(s) (shape=(3, 0)) while a minimum of 1 is required." |
| 336 | + assert_raise_message(ValueError, msg, e.fit, X_zero_features, y) |
| 337 | + |
| 338 | + |
321 | 339 | def check_estimators_nan_inf(name, Estimator):
|
322 | 340 | rnd = np.random.RandomState(0)
|
323 | 341 | X_train_finite = rnd.uniform(size=(10, 3))
|
|
0 commit comments