[MRG+2] Adds helpful messages in all error assertions in estimator_ch… · scikit-learn/scikit-learn@9e606bf · GitHub
[go: up one dir, main page]

Skip to content

Commit 9e606bf

Browse files
thechargedneutron authored and
lesteve committed
[MRG+2] Adds helpful messages in all error assertions in estimator_checks (#9588)
1 parent 026e10a commit 9e606bf

File tree

1 file changed

+50
-15
lines changed

1 file changed

+50
-15
lines changed

sklearn/utils/estimator_checks.py

Lines changed: 50 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -688,7 +688,11 @@ def check_transformers_unfitted(name, transformer):
688688
X, y = _boston_subset()
689689

690690
transformer = clone(transformer)
691-
assert_raises((AttributeError, ValueError), transformer.transform, X)
691+
with assert_raises((AttributeError, ValueError), msg="The unfitted "
692+
"transformer {} does not raise an error when "
693+
"transform is called. Perhaps use "
694+
"check_is_fitted in transform.".format(name)):
695+
transformer.transform(X)
692696

693697

694698
def _check_transformer(name, transformer_orig, X, y):
@@ -760,7 +764,12 @@ def _check_transformer(name, transformer_orig, X, y):
760764
# raises error on malformed input for transform
761765
if hasattr(X, 'T'):
762766
# If it's not an array, it does not have a 'T' property
763-
assert_raises(ValueError, transformer.transform, X.T)
767+
with assert_raises(ValueError, msg="The transformer {} does "
768+
"not raise an error when the number of "
769+
"features in transform is different from"
770+
" the number of features in "
771+
"fit.".format(name)):
772+
transformer.transform(X.T)
764773

765774

766775
@ignore_warnings
@@ -853,7 +862,11 @@ def check_estimators_empty_data_messages(name, estimator_orig):
853862
X_zero_samples = np.empty(0).reshape(0, 3)
854863
# The precise message can change depending on whether X or y is
855864
# validated first. Let us test the type of exception only:
856-
assert_raises(ValueError, e.fit, X_zero_samples, [])
865+
with assert_raises(ValueError, msg="The estimator {} does not"
866+
" raise an error when an empty data is used "
867+
"to train. Perhaps use "
868+
"check_array in train.".format(name)):
869+
e.fit(X_zero_samples, [])
857870

858871
X_zero_features = np.empty(0).reshape(3, 0)
859872
# the following y should be accepted by both classifiers and regressors
@@ -988,7 +1001,12 @@ def check_estimators_partial_fit_n_features(name, estimator_orig):
9881001
except NotImplementedError:
9891002
return
9901003

991-
assert_raises(ValueError, estimator.partial_fit, X[:, :-1], y)
1004+
with assert_raises(ValueError,
1005+
msg="The estimator {} does not raise an"
1006+
" error when the number of features"
1007+
" changes between calls to "
1008+
"partial_fit.".format(name)):
1009+
estimator.partial_fit(X[:, :-1], y)
9921010

9931011

9941012
@ignore_warnings(category=(DeprecationWarning, FutureWarning))
@@ -1092,7 +1110,12 @@ def check_classifiers_train(name, classifier_orig):
10921110
X -= X.min()
10931111
set_random_state(classifier)
10941112
# raises error on malformed input for fit
1095-
assert_raises(ValueError, classifier.fit, X, y[:-1])
1113+
with assert_raises(ValueError, msg="The classifer {} does not"
1114+
" raise an error when incorrect/malformed input "
1115+
"data for fit is passed. The number of training "
1116+
"examples is not the same as the number of labels."
1117+
" Perhaps use check_X_y in fit.".format(name)):
1118+
classifier.fit(X, y[:-1])
10961119

10971120
# fit
10981121
classifier.fit(X, y)
@@ -1106,7 +1129,11 @@ def check_classifiers_train(name, classifier_orig):
11061129
assert_greater(accuracy_score(y, y_pred), 0.83)
11071130

11081131
# raises error on malformed input for predict
1109-
assert_raises(ValueError, classifier.predict, X.T)
1132+
with assert_raises(ValueError, msg="The classifier {} does not"
1133+
" raise an error when the number of features "
1134+
"in predict is different from the number of"
1135+
" features in fit.".format(name)):
1136+
classifier.predict(X.T)
11101137
if hasattr(classifier, "decision_function"):
11111138
try:
11121139
# decision_function agrees with predict
@@ -1121,12 +1148,13 @@ def check_classifiers_train(name, classifier_orig):
11211148
assert_equal(decision.shape, (n_samples, n_classes))
11221149
assert_array_equal(np.argmax(decision, axis=1), y_pred)
11231150

1124-
# raises error on malformed input
1125-
assert_raises(ValueError,
1126-
classifier.decision_function, X.T)
11271151
# raises error on malformed input for decision_function
1128-
assert_raises(ValueError,
1129-
classifier.decision_function, X.T)
1152+
with assert_raises(ValueError, msg="The classifier {} does"
1153+
" not raise an error when the number of "
1154+
"features in decision_function is "
1155+
"different from the number of features"
1156+
" in fit.".format(name)):
1157+
classifier.decision_function(X.T)
11301158
except NotImplementedError:
11311159
pass
11321160
if hasattr(classifier, "predict_proba"):
@@ -1136,10 +1164,12 @@ def check_classifiers_train(name, classifier_orig):
11361164
assert_array_equal(np.argmax(y_prob, axis=1), y_pred)
11371165
# check that probas for all classes sum to one
11381166
assert_allclose(np.sum(y_prob, axis=1), np.ones(n_samples))
1139-
# raises error on malformed input
1140-
assert_raises(ValueError, classifier.predict_proba, X.T)
11411167
# raises error on malformed input for predict_proba
1142-
assert_raises(ValueError, classifier.predict_proba, X.T)
1168+
with assert_raises(ValueError, msg="The classifier {} does not"
1169+
" raise an error when the number of features "
1170+
"in predict_proba is different from the number "
1171+
"of features in fit.".format(name)):
1172+
classifier.predict_proba(X.T)
11431173
if hasattr(classifier, "predict_log_proba"):
11441174
# predict_log_proba is a transformation of predict_proba
11451175
y_log_prob = classifier.predict_log_proba(X)
@@ -1303,7 +1333,12 @@ def check_regressors_train(name, regressor_orig):
13031333
regressor.C = 0.01
13041334

13051335
# raises error on malformed input for fit
1306-
assert_raises(ValueError, regressor.fit, X, y[:-1])
1336+
with assert_raises(ValueError, msg="The classifer {} does not"
1337+
" raise an error when incorrect/malformed input "
1338+
"data for fit is passed. The number of training "
1339+
"examples is not the same as the number of "
1340+
"labels. Perhaps use check_X_y in fit.".format(name)):
1341+
regressor.fit(X, y[:-1])
13071342
# fit
13081343
if name in CROSS_DECOMPOSITION:
13091344
y_ = np.vstack([y, 2 * y + rnd.randint(2, size=len(y))])

0 commit comments

Comments
 (0)
0