8000 [MRG] Fix predict_proba not fitted check in SGDClassifier by aniruddhadave · Pull Request #10961 · scikit-learn/scikit-learn · GitHub
[go: up one dir, main page]

Skip to content

[MRG] Fix predict_proba not fitted check in SGDClassifier #10961

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Apr 17, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions sklearn/linear_model/stochastic_gradient.py
Original file line number Diff line number Diff line change
Expand Up @@ -802,8 +802,6 @@ def __init__(self, loss="hinge", penalty='l2', alpha=0.0001, l1_ratio=0.15,
average=average, n_iter=n_iter)

def _check_proba(self):
check_is_fitted(self, "t_")

if self.loss not in ("log", "modified_huber"):
raise AttributeError("probability estimates are not available for"
" loss=%r" % self.loss)
Expand Down Expand Up @@ -848,6 +846,8 @@ def predict_proba(self):
return self._predict_proba

def _predict_proba(self, X):
check_is_fitted(self, "t_")

if self.loss == "log":
return self._predict_proba_lr(X)

Expand Down
24 changes: 24 additions & 0 deletions sklearn/linear_model/tests/test_sgd.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import pickle
import unittest
import pytest

import numpy as np
import scipy.sparse as sp
Expand Down Expand Up @@ -467,6 +468,29 @@ def test_set_coef_multiclass(self):
# Provided intercept_ does match dataset.
clf = self.factory().fit(X2, Y2, intercept_init=np.zeros((3,)))

def test_sgd_predict_proba_method_access(self):
# Checks that SGDClassifier predict_proba and predict_log_proba methods
# can either be accessed or raise an appropriate error message
# otherwise. See
# https://github.com/scikit-learn/scikit-learn/issues/10938 for more
# details.
for loss in SGDClassifier.loss_functions:
clf = SGDClassifier(loss=loss)
if loss in ('log', 'modified_huber'):
assert hasattr(clf,  6525 9;predict_proba')
assert hasattr(clf, 'predict_log_proba')
else:
message = ("probability estimates are not "
"available for loss={!r}".format(loss))
assert not hasattr(clf, 'predict_proba')
assert not hasattr(clf, 'predict_log_proba')
with pytest.raises(AttributeError,
message=message):
clf.predict_proba
with pytest.raises(AttributeError,
message=message):
clf.predict_log_proba

def test_sgd_proba(self):
# Check SGD.predict_proba

Expand Down
0