8000 FIX property for predict_proba in MultiOutput Classifier (#15490) · rasbt/scikit-learn@981fa7b · GitHub
[go: up one dir, main page]

Skip to content

Commit 981fa7b

Browse files
rebekahkimjnothman
authored andcommitted
FIX property for predict_proba in MultiOutput Classifier (scikit-learn#15490)
1 parent faaeba4 commit 981fa7b

File tree

3 files changed

+29
-5
lines changed

3 files changed

+29
-5
lines changed

doc/whats_new/v0.22.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -657,6 +657,10 @@ Changelog
657657
- |Fix| :class:`multioutput.MultiOutputClassifier` now has attribute
658658
``classes_``. :pr:`14629` by :user:`Agamemnon Krasoulis <agamemnonc>`.
659659

660+
- |Fix| :class:`multioutput.MultiOutputClassifier` now has `predict_proba`
661+
as property and can be checked with `hasattr`.
662+
:issue:`15488` :pr:`15490` by :user:`Rebekah Kim <rebekahkim>`
663+
660664
:mod:`sklearn.naive_bayes`
661665
...............................
662666

sklearn/multioutput.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -360,7 +360,8 @@ def fit(self, X, Y, sample_weight=None):
360360
self.classes_ = [estimator.classes_ for estimator in self.estimators_]
361361
return self
362362

363-
def predict_proba(self, X):
363+
@property
364+
def predict_proba(self):
364365
"""Probability estimates.
365366
Returns prediction probabilities for each class of each output.
366367
@@ -382,9 +383,11 @@ def predict_proba(self, X):
382383
check_is_fitted(self)
383384
if not all([hasattr(estimator, "predict_proba")
384385
for estimator in self.estimators_]):
385-
raise ValueError("The base estimator should implement "
386-
"predict_proba method")
386+
raise AttributeError("The base estimator should "
387+
"implement predict_proba method")
388+
return self._predict_proba
387389

390+
def _predict_proba(self, X):
388391
results = [estimator.predict_proba(X) for estimator in
389392
self.estimators_]
390393
return results

sklearn/tests/test_multioutput.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,22 @@ def test_multi_output_classification_partial_fit_parallelism():
175175
assert est1 is not est2
176176

177177

178+
# check multioutput has predict_proba
179+
def test_hasattr_multi_output_predict_proba():
180+
# default SGDClassifier has loss='hinge'
181+
# which does not expose a predict_proba method
182+
sgd_linear_clf = SGDClassifier(random_state=1, max_iter=5)
183+
multi_target_linear = MultiOutputClassifier(sgd_linear_clf)
184+
multi_target_linear.fit(X, y)
185+
assert not hasattr(multi_target_linear, "predict_proba")
186+
187+
# case where predict_proba attribute exists
188+
sgd_linear_clf = SGDClassifier(loss='log', random_state=1, max_iter=5)
189+
multi_target_linear = MultiOutputClassifier(sgd_linear_clf)
190+
multi_target_linear.fit(X, y)
191+
assert hasattr(multi_target_linear, "predict_proba")
192+
193+
178194
# check predict_proba passes
179195
def test_multi_output_predict_proba():
180196
sgd_linear_clf = SGDClassifier(random_state=1, max_iter=5)
@@ -199,7 +215,7 @@ def custom_scorer(estimator, X, y):
199215
multi_target_linear = MultiOutputClassifier(sgd_linear_clf)
200216
multi_target_linear.fit(X, y)
201217
err_msg = "The base estimator should implement predict_proba method"
202-
with pytest.raises(ValueError, match=err_msg):
218+
with pytest.raises(AttributeError, match=err_msg):
203219
multi_target_linear.predict_proba(X)
204220

205221

@@ -378,7 +394,8 @@ def test_multi_output_exceptions():
378394
# and predict_proba are called
379395
moc = MultiOutputClassifier(LinearSVC(random_state=0))
380396
assert_raises(NotFittedError, moc.predict, y)
381-
assert_raises(NotFittedError, moc.predict_proba, y)
397+
with pytest.raises(NotFittedError):
398+
moc.predict_proba
382399
assert_raises(NotFittedError, moc.score, X, y)
383400
# ValueError when number of outputs is different
384401
# for fit and score

0 commit comments

Comments
 (0)
0