8000 [MRG+1] Fix SGDClassifier never has the attribute "predict_proba" (even with log or modified_huber loss) by rebekahkim · Pull Request #12222 · scikit-learn/scikit-learn · GitHub
[go: up one dir, main page]

Skip to content

[MRG+1] Fix SGDClassifier never has the attribute "predict_proba" (even with log or modified_huber loss) #12222

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 14 commits into from
Apr 19, 2019
Merged
8 changes: 8 additions & 0 deletions doc/whats_new/v0.21.rst
Original file line number Diff line number Diff line change
Expand Up @@ -506,6 +506,14 @@ Support for Python 3.4 and below has been officially dropped.
containing this same sample due to the scaling used in decision_function.
:issue:`10440` by :user:`Jonathan Ohayon <Johayon>`.

:mod:`sklearn.multioutput`
........................

- |Fix| Fixed a bug in :class:`multiout.MultiOutputClassifier` where the
`predict_proba` method incorrectly checked for `predict_proba` attribute in
the estimator object.
:issue:`12222` by :user:`Rebekah Kim <rebekahkim>`

:mod:`sklearn.neighbors`
........................

Expand Down
15 changes: 10 additions & 5 deletions sklearn/multioutput.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def fit(self, X, y, sample_weight=None):

if not hasattr(self.estimator, "fit"):
raise ValueError("The base estimator should implement"
" a fit method")
" a fit method")

X, y = check_X_y(X, y,
multi_output=True,
Expand Down Expand Up @@ -186,7 +186,8 @@ def predict(self, X):
"""
check_is_fitted(self, 'estimators_')
if not hasattr(self.estimator, "predict"):
raise ValueError("The base estimator should implement a predict method")
raise ValueError("The base estimator should implement"
" a predict method")

X = check_array(X, accept_sparse=True)

Expand Down Expand Up @@ -327,6 +328,9 @@ def predict_proba(self, X):
"""Probability estimates.
Returns prediction probabilities for each class of each output.

This method will raise a ``ValueError`` if any of the
estimators do not have ``predict_proba``.

Parameters
----------
X : array-like, shape (n_samples, n_features)
Expand All @@ -340,16 +344,17 @@ def predict_proba(self, X):
classes corresponds to that in the attribute `classes_`.
"""
check_is_fitted(self, 'estimators_')
if not hasattr(self.estimator, "predict_proba"):
raise ValueError("The base estimator should implement"
if not all([hasattr(estimator, "predict_proba")
for estimator in self.estimators_]):
raise ValueError("The base estimator should implement "
"predict_proba method")

results = [estimator.predict_proba(X) for estimator in
self.estimators_]
return results

def score(self, X, y):
""""Returns the mean accuracy on the given test data and labels.
"""Returns the mean accuracy on the given test data and labels.

Parameters
----------
Expand Down
29 changes: 29 additions & 0 deletions sklearn/tests/test_multioutput.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from sklearn.svm import LinearSVC
from sklearn.base import ClassifierMixin
from sklearn.utils import shuffle
from sklearn.model_selection import GridSearchCV


def test_multi_target_regression():
Expand Down Expand Up @@ -176,6 +177,34 @@ def test_multi_output_classification_partial_fit_parallelism():
assert est1 is not est2


# check predict_proba passes
def test_multi_output_predict_proba():
sgd_linear_clf = SGDClassifier(random_state=1, max_iter=5, tol=1e-3)
param = {'loss': ('hinge', 'log', 'modified_huber')}

# inner function for custom scoring
def custom_scorer(estimator, X, y):
if hasattr(estimator, "predict_proba"):
return 1.0
else:
return 0.0
grid_clf = GridSearchCV(sgd_linear_clf, param_grid=param,
scoring=custom_scorer, cv=3, error_score=np.nan)
multi_target_linear = MultiOutputClassifier(grid_clf)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would like to assert not hasattr(..., 'predict_proba') before doing this fit, so that the intention of the test is a bit clearer.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You mean for multi_target_linear.estimator, right? Technically, the estimator still wouldn't have predict_proba after fit because the underlying estimator (SGDClassifier with default loss='hinge') doesn't have predict_proba. But all estimators in estimators_ here would (after fit, of course).

If you mean for the multi_target_linear itself, it would have predict_proba before and after fit; it just won't be valid (raises ValueError)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jnothman thoughts?

multi_target_linear.fit(X, y)

multi_target_linear.predict_proba(X)

# SGDClassifier defaults to loss='hinge' which is not a probabilistic
# loss function; therefore it does not expose a predict_proba method
sgd_linear_clf = SGDClassifier(random_state=1, max_iter=5, tol=1e-3)
multi_target_linear = MultiOutputClassifier(sgd_linear_clf)
multi_target_linear.fit(X, y)
err_msg = "The base estimator should implement predict_proba method"
with pytest.raises(ValueError, match=err_msg):
multi_target_linear.predict_proba(X)


# 0.23. warning about tol not having its correct default value.
@pytest.mark.filterwarnings('ignore:max_iter and tol parameters have been')
def test_multi_output_classification_partial_fit():
Expand Down
0