24
24
from .utils .fixes import signature
25
25
from .isotonic import IsotonicRegression
26
26
from .svm import LinearSVC
27
+ from .linear_model import LogisticRegression
27
28
from .model_selection import check_cv
28
29
from .metrics .classification import _check_binary_probabilistic_predictions
29
30
from .metrics .pairwise import euclidean_distances
@@ -87,7 +88,7 @@ class CutoffClassifier(BaseEstimator, ClassifierMixin):
87
88
Decision threshold for the positive class. Determines the output of
88
89
predict
89
90
"""
90
- def __init__ (self , base_estimator = None , method = 'roc' , pos_label = 1 , cv = 3 ,
91
+ def __init__ (self , base_estimator , method = 'roc' , pos_label = 1 , cv = 3 ,
91
92
min_val_tnr = None , min_val_tpr = None ):
92
93
self .base_estimator = base_estimator
93
94
self .method = method
@@ -112,15 +113,16 @@ def fit(self, X, y):
112
113
self : object
113
114
Instance of self.
114
115
"""
116
+ if not isinstance (self .base_estimator , BaseEstimator ):
117
+ raise AttributeError ('Base estimator must be of type BaseEstimator;'
118
+ 'got %s instead' % type (self .base_estimator ))
119
+
115
120
X , y = check_X_y (X , y )
116
121
117
122
self .label_encoder = LabelEncoder ().fit (y )
118
123
y = self .label_encoder .transform (y )
119
124
self .pos_label = self .label_encoder .transform ([self .pos_label ])[0 ]
120
125
121
- if not self .base_estimator :
122
- self .base_estimator = LinearSVC (random_state = 0 )
123
-
124
126
if self .cv == 'prefit' :
125
127
self .threshold = _CutoffClassifier (
126
128
self .base_estimator , self .method , self .pos_label ,
0 commit comments