Description
SGDClassifier's predict_proba() cannot be used through MultiOutputClassifier's predict_proba(), even when SGDClassifier is configured with a loss that supports probability estimates (loss="log" or loss="modified_huber").
The incompatibility occurs because hasattr(clf, "predict_proba") returns False for an SGDClassifier instance; thus, when the classifier is wrapped by MultiOutputClassifier, calling predict_proba() raises an error.
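The mechanism can be illustrated without scikit-learn. The sketch below is a minimal, stdlib-only toy: ToySGD, ToyLogReg and the local NotFittedError are hypothetical stand-ins, not scikit-learn's classes. It assumes, based on my reading of the 0.19 source, that SGDClassifier exposes predict_proba as a property that raises AttributeError (or a NotFittedError, which subclasses AttributeError) when the model is unfitted or the loss does not support probabilities, whereas LogisticRegression defines an ordinary method. Because hasattr() returns False whenever the attribute lookup raises AttributeError, the check can fail even for loss="log".

```python
class NotFittedError(ValueError, AttributeError):
    """Hypothetical stand-in: subclasses AttributeError, so hasattr() swallows it."""


class ToySGD(object):
    """Hypothetical stand-in for a property-based predict_proba (assumption)."""

    def __init__(self, loss="hinge"):
        self.loss = loss

    def fit(self):
        self.t_ = 1  # marks the toy model as fitted
        return self

    @property
    def predict_proba(self):
        # Any AttributeError raised here makes hasattr(obj, "predict_proba") False.
        if not hasattr(self, "t_"):
            raise NotFittedError("estimator is not fitted yet")
        if self.loss not in ("log", "modified_huber"):
            raise AttributeError("probability estimates are not available "
                                 "for loss=%r" % self.loss)
        return self._predict_proba

    def _predict_proba(self, X):
        return [[0.5, 0.5] for _ in X]  # dummy probabilities


class ToyLogReg(object):
    """Hypothetical stand-in: predict_proba is a plain method, always present."""

    def predict_proba(self, X):
        return [[0.5, 0.5] for _ in X]


print(hasattr(ToySGD(loss="log"), "predict_proba"))          # False: not fitted
print(hasattr(ToySGD(loss="log").fit(), "predict_proba"))    # True
print(hasattr(ToySGD(loss="hinge").fit(), "predict_proba"))  # False: wrong loss
print(hasattr(ToyLogReg(), "predict_proba"))                 # True
```

Under this assumption, whether hasattr() succeeds depends on the estimator's state, not only on its class.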
The error is raised in this file:
https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/multioutput.py
by this check:
if not hasattr(self.estimator, "predict_proba"):
    raise ValueError("The base estimator should implement"
                     "predict_proba method")
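One possible direction for a fix, sketched with hypothetical toy classes rather than scikit-learn's actual code: check predict_proba on the fitted per-output copies instead of on the unfitted template passed to the constructor. The wrapper below stores its fitted copies in a list named estimators_, mirroring MultiOutputClassifier's attribute of that name; ToySGD and ToyMultiOutput are illustrative assumptions only.

```python
import copy


class ToySGD(object):
    """Hypothetical stand-in: predict_proba only exists once fitted with a suitable loss."""

    def __init__(self, loss="log"):
        self.loss = loss

    def fit(self):
        self.t_ = 1  # marks the toy model as fitted
        return self

    @property
    def predict_proba(self):
        if not hasattr(self, "t_") or self.loss not in ("log", "modified_huber"):
            raise AttributeError("predict_proba is not available")
        return lambda X: [[0.5, 0.5] for _ in X]  # dummy probabilities


class ToyMultiOutput(object):
    """Hypothetical wrapper: one fitted deep copy of the template per output."""

    def __init__(self, estimator):
        self.estimator = estimator  # unfitted template, left untouched

    def fit(self, n_outputs):
        self.estimators_ = [copy.deepcopy(self.estimator).fit()
                            for _ in range(n_outputs)]
        return self

    def predict_proba(self, X):
        # Check the fitted copies, so a property that requires fitting
        # does not produce a false negative on the unfitted template.
        if not all(hasattr(est, "predict_proba") for est in self.estimators_):
            raise ValueError("The base estimator should implement "
                             "predict_proba method")
        return [est.predict_proba(X) for est in self.estimators_]


wrapper = ToyMultiOutput(ToySGD(loss="log")).fit(n_outputs=2)
print(hasattr(wrapper.estimator, "predict_proba"))  # False: template is unfitted
print(wrapper.predict_proba([1, 2]))  # one list of per-sample probabilities per output
```

With loss="hinge" the fitted copies still lack predict_proba, so the toy wrapper keeps rejecting losses that cannot produce probabilities.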
By contrast, LogisticRegression does have the attribute, and it works correctly when wrapped by MultiOutputClassifier, as the simplified example below shows.
Steps/Code to Reproduce
from sklearn.linear_model import SGDClassifier as online
from sklearn.linear_model import LogisticRegression as log

# Either loss enables predict_proba() on a standalone SGDClassifier:
clf_test = online(loss="log", penalty="l2")
#clf_test = online(loss="modified_huber", penalty="l2")

# The problematic check in MultiOutputClassifier's predict_proba():
if not hasattr(clf_test, "predict_proba"):
    print("Don't allow predict_proba() when wrapped by MultiOutputClassifier.")
else:
    print("Allow predict_proba() when wrapped by MultiOutputClassifier.")

# By contrast, the logistic regression classifier passes the same check.
clf_test = log()
if not hasattr(clf_test, "predict_proba"):
    print("Don't allow predict_proba() when wrapped by MultiOutputClassifier.")
else:
    print("Allow predict_proba() when wrapped by MultiOutputClassifier.")
Expected Results
Allow predict_proba() when wrapped by MultiOutputClassifier.
Allow predict_proba() when wrapped by MultiOutputClassifier.
Actual Results
Don't allow predict_proba() when wrapped by MultiOutputClassifier.
Allow predict_proba() when wrapped by MultiOutputClassifier.
Versions
Windows-10-10.0.15063
('Python', '2.7.11 |Anaconda custom (32-bit)| (default, Mar 4 2016, 15:18:41) [MSC v.1500 32 bit (Intel)]')
('NumPy', '1.10.4')
('SciPy', '0.17.0')
('Scikit-Learn', '0.19.1')