@@ -6,9 +6,9 @@
 import traceback
 import pickle
 from copy import deepcopy
-
 import numpy as np
 from scipy import sparse
+from scipy.stats import rankdata
 import struct
 
 from sklearn.externals.six.moves import zip
@@ -113,10 +113,10 @@ def _yield_classifier_checks(name, Classifier):
     # basic consistency testing
     yield check_classifiers_train
     yield check_classifiers_regression_target
-    if (name not in ["MultinomialNB", "LabelPropagation", "LabelSpreading"]
+    if (name not in
+            ["MultinomialNB", "LabelPropagation", "LabelSpreading"] and
         # TODO some complication with -1 label
-        and name not in ["DecisionTreeClassifier",
-                         "ExtraTreeClassifier"]):
+            name not in ["DecisionTreeClassifier", "ExtraTreeClassifier"]):
         # We don't raise a warning in these classifiers, as
         # the column y interface is used by the forests.
 
@@ -127,6 +127,8 @@ def _yield_classifier_checks(name, Classifier):
         yield check_class_weight_classifiers
 
     yield check_non_transformer_estimators_n_iter
+    # test if predict_proba is a monotonic transformation of decision_function
+    yield check_decision_proba_consistency
 
 
 @ignore_warnings(category=DeprecationWarning)
@@ -269,8 +271,7 @@ def set_testing_parameters(estimator):
     # set parameters to speed up some estimators and
     # avoid deprecated behaviour
     params = estimator.get_params()
-    if ("n_iter" in params
-            and estimator.__class__.__name__ != "TSNE"):
+    if ("n_iter" in params and estimator.__class__.__name__ != "TSNE"):
         estimator.set_params(n_iter=5)
     if "max_iter" in params:
         warnings.simplefilter("ignore", ConvergenceWarning)
@@ -1112,8 +1113,7 @@ def check_classifiers_train(name, Classifier):
                     assert_equal(decision.shape, (n_samples,))
                     dec_pred = (decision.ravel() > 0).astype(np.int)
                     assert_array_equal(dec_pred, y_pred)
-                if (n_classes is 3
-                        and not isinstance(classifier, BaseLibSVM)):
+                if (n_classes is 3 and not isinstance(classifier, BaseLibSVM)):
                     # 1on1 of LibSVM works differently
                     assert_equal(decision.shape, (n_samples, n_classes))
                     assert_array_equal(np.argmax(decision, axis=1), y_pred)
@@ -1574,9 +1574,9 @@ def check_parameters_default_constructible(name, Estimator):
     try:
         def param_filter(p):
             """Identify hyper parameters of an estimator"""
-            return (p.name != 'self'
-                    and p.kind != p.VAR_KEYWORD
-                    and p.kind != p.VAR_POSITIONAL)
+            return (p.name != 'self' and
+                    p.kind != p.VAR_KEYWORD and
+                    p.kind != p.VAR_POSITIONAL)
 
         init_params = [p for p in signature(init).parameters.values()
                        if param_filter(p)]
@@ -1721,3 +1721,25 @@ def check_classifiers_regression_target(name, Estimator):
     e = Estimator()
     msg = 'Unknown label type: '
     assert_raises_regex(ValueError, msg, e.fit, X, y)
+
+
+@ignore_warnings(category=DeprecationWarning)
+def check_decision_proba_consistency(name, Estimator):
+    # Check whether an estimator having both decision_function and
+    # predict_proba methods has outputs with perfect rank correlation.
+
+    centers = [(2, 2), (4, 4)]
+    X, y = make_blobs(n_samples=100, random_state=0, n_features=4,
+                      centers=centers, cluster_std=1.0, shuffle=True)
+    X_test = np.random.randn(20, 2) + 4
+    estimator = Estimator()
+
+    set_testing_parameters(estimator)
+
+    if (hasattr(estimator, "decision_function") and
+            hasattr(estimator, "predict_proba")):
+
+        estimator.fit(X, y)
+        a = estimator.predict_proba(X_test)[:, 1]
+        b = estimator.decision_function(X_test)
+        assert_array_equal(rankdata(a), rankdata(b))
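
Note (outside the diff): a minimal, self-contained sketch of what the new check_decision_proba_consistency check asserts, run against a single classifier. LogisticRegression is only an illustrative choice here, not something the diff names; any binary classifier exposing both decision_function and predict_proba is expected to satisfy the same rank-equality assertion.

import numpy as np
from scipy.stats import rankdata
from sklearn.datasets import make_blobs
from sklearn.linear_model import LogisticRegression

# Two well-separated blobs, mirroring the centers used by the new check.
X, y = make_blobs(n_samples=100, centers=[(2, 2), (4, 4)],
                  cluster_std=1.0, random_state=0)
X_test = np.random.RandomState(0).randn(20, 2) + 4

clf = LogisticRegression().fit(X, y)
proba = clf.predict_proba(X_test)[:, 1]   # probability of the positive class
score = clf.decision_function(X_test)     # raw decision value w.x + b

# For a binary linear model, predict_proba applies a monotonic (sigmoid)
# transform to decision_function, so both outputs rank the samples identically.
assert np.array_equal(rankdata(proba), rankdata(score))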