8000 Review feedback. · scikit-learn/scikit-learn@78f3938 · GitHub
[go: up one dir, main page]

Skip to content

Commit 78f3938

Browse files
committed
Review feedback.
1 parent 2b7dab1 commit 78f3938

File tree

1 file changed

+26
-12
lines changed

1 file changed

+26
-12
lines changed

sklearn/utils/estimator_checks.py

Lines changed: 26 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1500,25 +1500,29 @@ def check_estimators_partial_fit_n_features(name, estimator_orig):
15001500

15011501
@ignore_warnings(category=(DeprecationWarning, FutureWarning))
15021502
def check_classifier_multioutput(name, estimator):
1503-
n_samples, n_labels = 42, 5
1503+
n_samples, n_labels, n_classes = 42, 5, 3
15041504
tags = _safe_tags(estimator)
15051505
estimator = clone(estimator)
15061506
X, y = make_multilabel_classification(random_state=42,
15071507
n_samples=n_samples,
1508-
n_labels=n_labels)
1508+
n_labels=n_labels,
1509+
n_classes=n_classes)
15091510
estimator.fit(X, y)
15101511
y_pred = estimator.predict(X)
15111512

1512-
assert (y_pred.shape == (n_samples, n_labels),
1513-
"The shape of the prediction for multioutput data is "
1514-
" incorrect. Expected {}, got {}."
1515-
.format((n_samples, n_labels), y_pred.shape))
1513+
assert y_pred.shape == (n_samples, n_classes), (
1514+
"The shape of the prediction for multioutput data is "
1515+
"incorrect. Expected {}, got {}."
1516+
.format((n_samples, n_labels), y_pred.shape))
15161517
assert y_pred.dtype.kind == 'i'
15171518

15181519
if hasattr(estimator, "decision_function"):
15191520
decision = estimator.decision_function(X)
15201521
assert isinstance(decision, np.ndarray)
1521-
assert decision.shape == (n_samples, n_labels)
1522+
assert decision.shape == (n_samples, n_classes), (
1523+
"The shape of the decision function output for "
1524+
"multioutput data is incorrect. Expected {}, got {}."
1525+
.format((n_samples, n_classes), decision.shape))
15221526

15231527
dec_pred = (decision > 0).astype(np.int)
15241528
dec_exp = estimator.classes_[dec_pred]
@@ -1528,14 +1532,20 @@ def check_classifier_multioutput(name, estimator):
15281532
y_prob = estimator.predict_proba(X)
15291533

15301534
if isinstance(y_prob, list) and not tags['poor_score']:
1531-
for i in range(n_labels):
1532-
y_prob[i].shape == (n_samples, n_labels)
1535+
for i in range(n_classes):
1536+
assert y_prob[i].shape == (n_samples, 2), (
1537+
"The shape of the probability for multioutput data is"
1538+
" incorrect. Expected {}, got {}."
1539+
.format((n_samples, 2), y_prob[i].shape))
15331540
assert_array_equal(
15341541
np.argmax(y_prob[i], axis=1).astype(np.int),
15351542
y_pred[:, i]
15361543
)
15371544
elif not tags['poor_score']:
1538-
y_prob.shape == (n_samples, n_labels)
1545+
assert y_prob.shape == (n_samples, n_labels), (
1546+
"The shape of the probability for multioutput data is"
1547+
" incorrect. Expected {}, got {}."
1548+
.format((n_samples, n_labels), y_prob.shape))
15391549
assert_array_equal(y_prob.round().astype(int), y_pred)
15401550

15411551
if (hasattr(estimator, "decision_function") and
@@ -1561,8 +1571,12 @@ def check_regressor_multioutput(name, estimator):
15611571
estimator.fit(X, y)
15621572
y_pred = estimator.predict(X)
15631573

1564-
assert y_pred.dtype == np.dtype('float')
1565-
assert y_pred.shape == y.shape
1574+
assert y_pred.dtype == np.dtype('float64'), (
1575+
"Multioutput predictions by a regressor are expected to be"
1576+
" floating-point precision. Got {} instead".format(y_pred.dtype))
1577+
assert y_pred.shape == y.shape, (
1578+
"The shape of the orediction for multioutput data is incorrect."
1579+
" Expected {}, got {}.")
15661580

15671581

15681582
@ignore_warnings(category=(DeprecationWarning, FutureWarning))

0 commit comments

Comments
 (0)
0