8000 FIX yield right exception on empty input data · scikit-learn/scikit-learn@7805ec7 · GitHub
[go: up one dir, main page]

Skip to content

Commit 7805ec7

Browse files
committed
FIX yield right exception on empty input data
1 parent 549d98c commit 7805ec7

File tree

5 files changed

+23
-7
lines changed

5 files changed

+23
-7
lines changed

sklearn/linear_model/coordinate_descent.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -984,6 +984,8 @@ def fit(self, X, y):
984984
Target values
985985
"""
986986
y = np.asarray(y, dtype=np.float64)
987+
if y.shape[0] == 0:
988+
raise ValueError("y has 0 samples: %r" % y)
987989

988990
if hasattr(self, 'l1_ratio'):
989991
model_str = 'ElasticNet'

sklearn/linear_model/ridge.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -743,7 +743,8 @@ def fit(self, X, y, sample_weight=None):
743743
-------
744744
self : Returns self.
745745
"""
746-
X, y = check_X_y(X, y, ['csr', 'csc', 'coo'], dtype=np.float, multi_output=True)
746+
X, y = check_X_y(X, y, ['csr', 'csc', 'coo'], dtype=np.float,
747+
multi_output=True)
747748

748749
n_samples, n_features = X.shape
749750

sklearn/preprocessing/label.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from ..utils.fixes import in1d
2323
from ..utils import deprecated, column_or_1d
2424
from ..utils.validation import check_array
25+
from ..utils.validation import _num_samples
2526
from ..utils.multiclass import unique_labels
2627
from ..utils.multiclass import type_of_target
2728

@@ -315,6 +316,8 @@ def fit(self, y):
315316
if 'multioutput' in self.y_type_:
316317
raise ValueError("Multioutput target data is not supported with "
317318
"label binarization")
319+
if _num_samples(y) == 0:
320+
raise ValueError('y has 0 samples: %r' % y)
318321

319322
self.sparse_input_ = sp.issparse(y)
320323
self.classes_ = unique_labels(y)
@@ -465,6 +468,9 @@ def label_binarize(y, classes, neg_label=0, pos_label=1,
465468
# XXX Workaround that will be removed when list of list format is
466469
# dropped
467470
y = check_array(y, accept_sparse='csr', ensure_2d=False)
471+
else:
472+
if _num_samples(y) == 0:
473+
raise ValueError('y has 0 samples: %r' % y)
468474
if neg_label >= pos_label:
469475
raise ValueError("neg_label={0} must be strictly less than "
470476
"pos_label={1}.".format(neg_label, pos_label))

sklearn/tests/test_common.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,9 @@ def test_non_meta_estimators():
106106
if name not in CROSS_DECOMPOSITION + ['Imputer']:
107107
# Check that all estimator yield informative messages when
108108
# trained on empty datasets
109-
yield check_estimators_empty_data_messages, name, Estimator
109+
multi_output = name.startswith('MultiTask')
110+
yield (check_estimators_empty_data_messages,
111+
name, Estimator, multi_output)
110112

111113
# Test that all estimators check their input for NaN's and infs
112114
yield check_estimators_nan_inf, name, Estimator

sklearn/utils/estimator_checks.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -319,19 +319,24 @@ def check_estimators_dtypes(name, Estimator):
319319
pass
320< D74D /td>320

321321

322-
def check_estimators_empty_data_messages(name, Estimator):
322+
def check_estimators_empty_data_messages(name, Estimator, multi_output=False):
323323
with warnings.catch_warnings(record=True):
324324
e = Estimator()
325325
set_fast_parameters(e)
326326
set_random_state(e, 1)
327327

328328
X_zero_samples = np.empty(0).reshape(0, 3)
329-
y = []
330-
msg = "0 sample(s) (shape=(0, 3)) while a minimum of 1 is required."
331-
assert_raise_message(ValueError, msg, e.fit, X_zero_samples, y)
329+
# The precise message can change depending on whether X or y is
330+
# validated first. Let us test the type of exception only:
331+
assert_raises(ValueError, e.fit, X_zero_samples, [])
332332

333333
X_zero_features = np.empty(0).reshape(3, 0)
334-
y = [1, 2, 3]
334+
# the following y should be accepted by both classifiers and regressors
335+
# and ignored by unsupervised models
336+
if multi_output:
337+
y = np.array([[1, 1], [0, 1], [0, 1]])
338+
else:
339+
y = [1, 0, 1]
335340
msg = "0 feature(s) (shape=(3, 0)) while a minimum of 1 is required."
336341
assert_raise_message(ValueError, msg, e.fit, X_zero_features, y)
337342

0 commit comments

Comments
 (0)
0