8000 Fix MultiOutputClassifier checking for predict_proba method of base e… · jeremiedbb/scikit-learn@3b54222 · GitHub
[go: up one dir, main page]

Skip to content

Commit 3b54222

Browse files
rebekahkimjeremiedbb
authored andcommitted
Fix MultiOutputClassifier checking for predict_proba method of base estimator (scikit-learn#12222)
#WiMLDS
1 parent 62b5a85 commit 3b54222

File tree

3 files changed

+47
-5
lines changed

3 files changed

+47
-5
lines changed

doc/whats_new/v0.21.rst

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -511,6 +511,14 @@ Support for Python 3.4 and below has been officially dropped.
511511
containing this same sample due to the scaling used in decision_function.
512512
:issue:`10440` by :user:`Jonathan Ohayon <Johayon>`.
513513

514+
:mod:`sklearn.multioutput`
515+
........................
516+
517+
- |Fix| Fixed a bug in :class:`multiout.MultiOutputClassifier` where the
518+
`predict_proba` method incorrectly checked for `predict_proba` attribute in
519+
the estimator object.
520+
:issue:`12222` by :user:`Rebekah Kim <rebekahkim>`
521+
514522
:mod:`sklearn.neighbors`
515523
........................
516524

sklearn/multioutput.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ def fit(self, X, y, sample_weight=None):
145145

146146
if not hasattr(self.estimator, "fit"):
147147
raise ValueError("The base estimator should implement"
148-
" a fit method")
148+
" a fit method")
149149

150150
X, y = check_X_y(X, y,
151151
multi_output=True,
@@ -186,7 +186,8 @@ def predict(self, X):
186186
"""
187187
check_is_fitted(self, 'estimators_')
188188
if not hasattr(self.estimator, "predict"):
189-
raise ValueError("The base estimator should implement a predict method")
189+
raise ValueError("The base estimator should implement"
190+
" a predict method")
190191

191192
X = check_array(X, accept_sparse=True)
192193

@@ -327,6 +328,9 @@ def predict_proba(self, X):
327328
"""Probability estimates.
328329
Returns prediction probabilities for each class of each output.
329330
331+
This method will raise a ``ValueError`` if any of the
332+
estimators do not have ``predict_proba``.
333+
330334
Parameters
331335
----------
332336
X : array-like, shape (n_samples, n_features)
@@ -340,16 +344,17 @@ def predict_proba(self, X):
340344
classes corresponds to that in the attribute `classes_`.
341345
"""
342346
check_is_fitted(self, 'estimators_')
343-
if not hasattr(self.estimator, "predict_proba"):
344-
raise ValueError("The base estimator should implement"
347+
if not all([hasattr(estimator, "predict_proba")
348+
for estimator in self.estimators_]):
349+
raise ValueError("The base estimator should implement "
345350
"predict_proba method")
346351

347352
results = [estimator.predict_proba(X) for estimator in
348353
self.estimators_]
349354
return results
350355

351356
def score(self, X, y):
352-
""""Returns the mean accuracy on the given test data and labels.
357+
"""Returns the mean accuracy on the given test data and labels.
353358
354359
Parameters
355360
----------

sklearn/tests/test_multioutput.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
from sklearn.svm import LinearSVC
3232
from sklearn.base import ClassifierMixin
3333
from sklearn.utils import shuffle
34+
from sklearn.model_selection import GridSearchCV
3435

3536

3637
def test_multi_target_regression():
@@ -176,6 +177,34 @@ def test_multi_output_classification_partial_fit_parallelism():
176177
assert est1 is not est2
177178

178179

180+
# check predict_proba passes
181+
def test_multi_output_predict_proba():
182+
sgd_linear_clf = SGDClassifier(random_state=1, max_iter=5, tol=1e-3)
183+
param = {'loss': ('hinge', 'log', 'modified_huber')}
184+
185+
# inner function for custom scoring
186+
def custom_scorer(estimator, X, y):
187+
if hasattr(estimator, "predict_proba"):
188+
return 1.0
189+
else:
190+
return 0.0
191+
grid_clf = GridSearchCV(sgd_linear_clf, param_grid=param,
192+
scoring=custom_scorer, cv=3, error_score=np.nan)
193+
multi_target_linear = MultiOutputClassifier(grid_clf)
194+
multi_target_linear.fit(X, y)
195+
196+
multi_target_linear.predict_proba(X)
197+
198+
# SGDClassifier defaults to loss='hinge' which is not a probabilistic
199+
# loss function; therefore it does not expose a predict_proba method
200+
sgd_linear_clf = SGDClassifier(random_state=1, max_iter=5, tol=1e-3)
201+
multi_target_linear = MultiOutputClassifier(sgd_linear_clf)
202+
multi_target_linear.fit(X, y)
203+
err_msg = "The base estimator should implement predict_proba method"
204+
with pytest.raises(ValueError, match=err_msg):
205+
multi_target_linear.predict_proba(X)
206+
207+
179208
# 0.23. warning about tol not having its correct default value.
180209
@pytest.mark.filterwarnings('ignore:max_iter and tol parameters have been')
181210
def test_multi_output_classification_partial_fit():

0 commit comments

Comments
 (0)
0