diff --git a/sklearn/linear_model/stochastic_gradient.py b/sklearn/linear_model/stochastic_gradient.py index b3f8a34fe6a8d..c6ad4c2754c51 100644 --- a/sklearn/linear_model/stochastic_gradient.py +++ b/sklearn/linear_model/stochastic_gradient.py @@ -802,8 +802,6 @@ def __init__(self, loss="hinge", penalty='l2', alpha=0.0001, l1_ratio=0.15, average=average, n_iter=n_iter) def _check_proba(self): - check_is_fitted(self, "t_") - if self.loss not in ("log", "modified_huber"): raise AttributeError("probability estimates are not available for" " loss=%r" % self.loss) @@ -848,6 +846,8 @@ def predict_proba(self): return self._predict_proba def _predict_proba(self, X): + check_is_fitted(self, "t_") + if self.loss == "log": return self._predict_proba_lr(X) diff --git a/sklearn/linear_model/tests/test_sgd.py b/sklearn/linear_model/tests/test_sgd.py index 80b3ca394f990..9f372f706ca71 100644 --- a/sklearn/linear_model/tests/test_sgd.py +++ b/sklearn/linear_model/tests/test_sgd.py @@ -1,5 +1,6 @@ import pickle import unittest +import pytest import numpy as np import scipy.sparse as sp @@ -467,6 +468,29 @@ def test_set_coef_multiclass(self): # Provided intercept_ does match dataset. clf = self.factory().fit(X2, Y2, intercept_init=np.zeros((3,))) + def test_sgd_predict_proba_method_access(self): + # Checks that SGDClassifier predict_proba and predict_log_proba methods + # can either be accessed or raise an appropriate error message + # otherwise. See + # https://github.com/scikit-learn/scikit-learn/issues/10938 for more + # details. + for loss in SGDClassifier.loss_functions: + clf = SGDClassifier(loss=loss) + if loss in ('log', 'modified_huber'): + assert hasattr(clf, 'predict_proba') + assert hasattr(clf, 'predict_log_proba') + else: + message = ("probability estimates are not " + "available for loss={!r}".format(loss)) + assert not hasattr(clf, 'predict_proba') + assert not hasattr(clf, 'predict_log_proba') + with pytest.raises(AttributeError, + message=message): + clf.predict_proba + with pytest.raises(AttributeError, + message=message): + clf.predict_log_proba + def test_sgd_proba(self): # Check SGD.predict_proba