8000 Not allow None base estimator · scikit-learn/scikit-learn@e63657c · GitHub
[go: up one dir, main page]

Skip to content

Commit e63657c

Browse files
author
Prokopios Gryllos
committed
Not allow None base estimator
1 parent 2e5a9bb commit e63657c

File tree

1 file changed

+5
-4
lines changed

1 file changed

+5
-4
lines changed

sklearn/calibration.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ class CutoffClassifier(BaseEstimator, ClassifierMixin):
8787
Decision threshold for the positive class. Determines the output of
8888
predict
8989
"""
90-
def __init__(self, base_estimator=None, method='roc', pos_label=1, cv=3,
90+
def __init__(self, base_estimator, method='roc', pos_label=1, cv=3,
9191
min_val_tnr=None, min_val_tpr=None):
9292
self.base_estimator = base_estimator
9393
self.method = method
@@ -112,15 +112,16 @@ def fit(self, X, y):
112112
self : object
113113
Instance of self.
114114
"""
115+
if not isinstance(self.base_estimator, BaseEstimator):
116+
raise AttributeError('Base estimator must be of type BaseEstimator;'
117+
'got %s instead' % type(self.base_estimator))
118+
115119
X, y = check_X_y(X, y)
116120

117121
self.label_encoder = LabelEncoder().fit(y)
118122
y = self.label_encoder.transform(y)
119123
self.pos_label = self.label_encoder.transform([self.pos_label])[0]
120124

121-
if not self.base_estimator:
122-
self.base_estimator = LinearSVC(random_state=0)
123-
124125
if self.cv == 'prefit':
125126
self.threshold = _CutoffClassifier(
126127
self.base_estimator, self.method, self.pos_label,

0 commit comments

Comments
 (0)
0