8000 [MRG+1] Fixes #7578 added check_decision_proba_consistency in estimat… · massich/scikit-learn@cf903c3 · GitHub
[go: up one dir, main page]

Skip to content

Commit cf903c3

Browse files
Shubham BhardwajJoan Massich
authored andcommitted
[MRG+1] Fixes scikit-learn#7578 added check_decision_proba_consistency in estimator_checks (scikit-learn#8253)
1 parent 8b37b58 commit cf903c3

File tree

2 files changed

+40
-11
lines changed

2 files changed

+40
-11
lines changed

doc/whats_new.rst

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,13 @@ API changes summary
270270
selection classes to be used with tools such as
271271
:func:`sklearn.model_selection.cross_val_predict`.
272272
:issue:`2879` by :user:`Stephen Hoover <stephen-hoover>`.
273+
274+
- Estimators with both methods ``decision_function`` and ``predict_proba``
275+
are now required to have a monotonic relation between them. The
276+
method ``check_decision_proba_consistency`` has been added in
277+
**sklearn.utils.estimator_checks** to check their consistency.
278+
:issue:`7578` by :user:`Shubham Bhardwaj <shubham0704>`
279+
273280

274281
.. _changes_0_18_1:
275282

sklearn/utils/estimator_checks.py

Lines changed: 33 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,9 @@
66
import traceback
77
import pickle
88
from copy import deepcopy
9-
109
import numpy as np
1110
from scipy import sparse
11+
from scipy.stats import rankdata
1212
import struct
1313

1414
from sklearn.externals.six.moves import zip
@@ -113,10 +113,10 @@ def _yield_classifier_checks(name, Classifier):
113113
# basic consistency testing
114114
yield check_classifiers_train
115115
yield check_classifiers_regression_target
116-
if (name not in ["MultinomialNB", "LabelPropagation", "LabelSpreading"]
116+
if (name not in
117+
["MultinomialNB", "LabelPropagation", "LabelSpreading"] and
117118
# TODO some complication with -1 label
118-
and name not in ["DecisionTreeClassifier",
119-
"ExtraTreeClassifier"]):
119+
name not in ["DecisionTreeClassifier", "ExtraTreeClassifier"]):
120120
# We don't raise a warning in these classifiers, as
121121
# the column y interface is used by the forests.
122122

@@ -127,6 +127,8 @@ def _yield_classifier_checks(name, Classifier):
127127
yield check_class_weight_classifiers
128128

129129
yield check_non_transformer_estimators_n_iter
130+
# test if predict_proba is a monotonic transformation of decision_function
131+
yield check_decision_proba_consistency
130132

131133

132134
@ignore_warnings(category=DeprecationWarning)
@@ -269,8 +271,7 @@ def set_testing_parameters(estimator):
269271
# set parameters to speed up some estimators and
270272
# avoid deprecated behaviour
271273
params = estimator.get_params()
272-
if ("n_iter" in params
273-
and estimator.__class__.__name__ != "TSNE"):
274+
if ("n_iter" in params and estimator.__class__.__name__ != "TSNE"):
274275
estimator.set_params(n_iter=5)
275276
if "max_iter" in params:
276277
warnings.simplefilter("ignore", ConvergenceWarning)
@@ -1112,8 +1113,7 @@ def check_classifiers_train(name, Classifier):
11121113
assert_equal(decision.shape, (n_samples,))
11131114
dec_pred = (decision.ravel() > 0).astype(np.int)
11141115
assert_array_equal(dec_pred, y_pred)
1115-
if (n_classes is 3
1116-
and not isinstance(classifier, BaseLibSVM)):
1116+
if (n_classes is 3 and not isinstance(classifier, BaseLibSVM)):
11171117
# 1on1 of LibSVM works differently
11181118
assert_equal(decision.shape, (n_samples, n_classes))
11191119
assert_array_equal(np.argmax(decision, axis=1), y_pred)
@@ -1574,9 +1574,9 @@ def check_parameters_default_constructible(name, Estimator):
15741574
try:
15751575
def param_filter(p):
15761576
"""Identify hyper parameters of an estimator"""
1577-
return (p.name != 'self'
1578-
and p.kind != p.VAR_KEYWORD
1579-
and p.kind != p.VAR_POSITIONAL)
1577+
return (p.name != 'self' and
1578+
p.kind != p.VAR_KEYWORD and
1579+
p.kind != p.VAR_POSITIONAL)
15801580

15811581
init_params = [p for p in signature(init).parameters.values()
15821582
if param_filter(p)]
@@ -1721,3 +1721,25 @@ def check_classifiers_regression_target(name, Estimator):
17211721
e = Estimator()
17221722
msg = 'Unknown label type: '
17231723
assert_raises_regex(ValueError, msg, e.fit, X, y)
1724+
1725+
1726+
@ignore_warnings(category=DeprecationWarning)
1727+
def check_decision_proba_consistency(name, Estimator):
1728+
# Check whether an estimator having both decision_function and
1729+
# predict_proba methods has outputs with perfect rank correlation.
1730+
1731+
centers = [(2, 2), (4, 4)]
1732+
X, y = make_blobs(n_samples=100, random_state=0, n_features=4,
1733+
centers=centers, cluster_std=1.0, shuffle=True)
1734+
X_test = np.random.randn(20, 2) + 4
1735+
estimator = Estimator()
1736+
1737+
set_testing_parameters(estimator)
1738+
1739+
if (hasattr(estimator, "decision_function") and
1740+
hasattr(estimator, "predict_proba")):
1741+
1742+
estimator.fit(X, y)
1743+
a = estimator.predict_proba(X_test)[:, 1]
1744+
b = estimator.decision_function(X_test)
1745+
assert_array_equal(rankdata(a), rankdata(b))

0 commit comments

Comments
 (0)
0