8000 Not allow None base estimator · scikit-learn/scikit-learn@39e23c6 · GitHub
[go: up one dir, main page]

Skip to content

Commit 39e23c6

Browse files
author
Prokopios Gryllos
committed
Not allow None base estimator
1 parent 2e5a9bb commit 39e23c6

File tree

1 file changed

+6
-4
lines changed

1 file changed

+6
-4
lines changed

sklearn/calibration.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from .utils.fixes import signature
2525
from .isotonic import IsotonicRegression
2626
from .svm import LinearSVC
27+
from .linear_model import LogisticRegression
2728
from .model_selection import check_cv
2829
from .metrics.classification import _check_binary_probabilistic_predictions
2930
from .metrics.pairwise import euclidean_distances
@@ -87,7 +88,7 @@ class CutoffClassifier(BaseEstimator, ClassifierMixin):
8788
Decision threshold for the positive class. Determines the output of
8889
predict
8990
"""
90-
def __init__(self, base_estimator=None, method='roc', pos_label=1, cv=3,
91+
def __init__(self, base_estimator, method='roc', pos_label=1, cv=3,
9192
min_val_tnr=None, min_val_tpr=None):
9293
self.base_estimator = base_estimator
9394
self.method = method
@@ -112,15 +113,16 @@ def fit(self, X, y):
112113
self : object
113114
Instance of self.
114115
"""
116+
if not isinstance(self.base_estimator, BaseEstimator):
117+
raise AttributeError('Base estimator must be of type BaseEstimator;'
118+
'got %s instead' % type(self.base_estimator))
119+
115120
X, y = check_X_y(X, y)
116121

117122
self.label_encoder = LabelEncoder().fit(y)
118123
y = self.label_encoder.transform(y)
119124
self.pos_label = self.label_encoder.transform([self.pos_label])[0]
120125

121-
if not self.base_estimator:
122-
self.base_estimator = LinearSVC(random_state=0)
123-
124126
if self.cv == 'prefit':
125127
self.threshold = _CutoffClassifier(
126128
self.base_estimator, self.method, self.pos_label,

0 commit comments

Comments
 (0)
0