 import traceback
 import pickle
 from copy import deepcopy
-
 import numpy as np
 from scipy import sparse
+from scipy.stats import rankdata
 import struct

 from sklearn.externals.six.moves import zip
@@ -113,10 +113,10 @@ def _yield_classifier_checks(name, Classifier):
     # basic consistency testing
     yield check_classifiers_train
     yield check_classifiers_regression_target
-    if (name not in ["MultinomialNB", "LabelPropagation", "LabelSpreading"]
+    if (name not in
+            ["MultinomialNB", "LabelPropagation", "LabelSpreading"] and
         # TODO some complication with -1 label
-            and name not in ["DecisionTreeClassifier",
-                             "ExtraTreeClassifier"]):
+            name not in ["DecisionTreeClassifier", "ExtraTreeClassifier"]):
         # We don't raise a warning in these classifiers, as
         # the column y interface is used by the forests.
@@ -127,6 +127,8 @@ def _yield_classifier_checks(name, Classifier):
     yield check_class_weight_classifiers

     yield check_non_transformer_estimators_n_iter
+    # test if predict_proba is a monotonic transformation of decision_function
+    yield check_decision_proba_consistency


 @ignore_warnings(category=DeprecationWarning)
@@ -269,8 +271,7 @@ def set_testing_parameters(estimator):
     # set parameters to speed up some estimators and
     # avoid deprecated behaviour
     params = estimator.get_params()
-    if ("n_iter" in params
-            and estimator.__class__.__name__ != "TSNE"):
+    if ("n_iter" in params and estimator.__class__.__name__ != "TSNE"):
         estimator.set_params(n_iter=5)
     if "max_iter" in params:
         warnings.simplefilter("ignore", ConvergenceWarning)
@@ -1112,8 +1113,7 @@ def check_classifiers_train(name, Classifier):
                     assert_equal(decision.shape, (n_samples,))
                     dec_pred = (decision.ravel() > 0).astype(np.int)
                     assert_array_equal(dec_pred, y_pred)
-                if (n_classes is 3
-                        and not isinstance(classifier, BaseLibSVM)):
+                if (n_classes is 3 and not isinstance(classifier, BaseLibSVM)):
                     # 1on1 of LibSVM works differently
                     assert_equal(decision.shape, (n_samples, n_classes))
                     assert_array_equal(np.argmax(decision, axis=1), y_pred)
@@ -1574,9 +1574,9 @@ def check_parameters_default_constructible(name, Estimator):
     try:
         def param_filter(p):
             """Identify hyper parameters of an estimator"""
-            return (p.name != 'self'
-                    and p.kind != p.VAR_KEYWORD
-                    and p.kind != p.VAR_POSITIONAL)
+            return (p.name != 'self' and
+                    p.kind != p.VAR_KEYWORD and
+                    p.kind != p.VAR_POSITIONAL)

         init_params = [p for p in signature(init).parameters.values()
                        if param_filter(p)]
@@ -1721,3 +1721,25 @@ def check_classifiers_regression_target(name, Estimator):
     e = Estimator()
     msg = 'Unknown label type: '
     assert_raises_regex(ValueError, msg, e.fit, X, y)
+
+
+@ignore_warnings(category=DeprecationWarning)
+def check_decision_proba_consistency(name, Estimator):
+    # Check whether an estimator having both decision_function and
+    # predict_proba methods has outputs with perfect rank correlation.
+
+    centers = [(2, 2), (4, 4)]
+    X, y = make_blobs(n_samples=100, random_state=0, n_features=4,
+                      centers=centers, cluster_std=1.0, shuffle=True)
+    X_test = np.random.randn(20, 2) + 4
+    estimator = Estimator()
+
+    set_testing_parameters(estimator)
+
+    if (hasattr(estimator, "decision_function") and
+            hasattr(estimator, "predict_proba")):
+
+        estimator.fit(X, y)
+        a = estimator.predict_proba(X_test)[:, 1]
+        b = estimator.decision_function(X_test)
+        assert_array_equal(rankdata(a), rankdata(b))
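
The new check only asserts that, for an estimator exposing both methods, predict_proba and decision_function rank the test samples identically. A minimal standalone sketch of the same idea, using LogisticRegression purely as an illustrative stand-in for the arbitrary Estimator class the common test receives:

    # Illustrative sketch only: LogisticRegression stands in for the arbitrary
    # Estimator that check_decision_proba_consistency is handed.
    import numpy as np
    from scipy.stats import rankdata
    from sklearn.datasets import make_blobs
    from sklearn.linear_model import LogisticRegression

    X, y = make_blobs(n_samples=100, centers=[(2, 2), (4, 4)],
                      cluster_std=1.0, random_state=0)
    X_test = np.random.RandomState(0).randn(20, 2) + 4

    clf = LogisticRegression().fit(X, y)
    proba = clf.predict_proba(X_test)[:, 1]   # probability of the positive class
    scores = clf.decision_function(X_test)    # signed distance to the hyperplane

    # predict_proba is the logistic sigmoid of decision_function here, a
    # strictly increasing map, so both outputs must rank the samples identically.
    assert np.array_equal(rankdata(proba), rankdata(scores))

Comparing rankdata outputs rather than the raw values is what makes the check a pure monotonicity test: any strictly increasing mapping between the two outputs passes, while a reordering fails.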