25
25
squared_norm )
26
26
from ..utils .optimize import newton_cg
27
27
from ..utils .validation import (as_float_array , DataConversionWarning ,
28
- check_X_y )
28
+ check_X_y , NotFittedError )
29
29
from ..utils .fixes import expit
30
30
from ..externals .joblib import Parallel , delayed
31
31
from ..cross_validation import check_cv
@@ -1091,10 +1091,9 @@ def predict_proba(self, X):
1091
1091
For a multi_class problem, if multi_class is set to be "multinomial"
1092
1092
the softmax function is used to find the predicted probability of
1093
1093
each class.
1094
- Else use a one-vs-rest approach, i.e calculating the probability
1095
- of each class assuming it to be positive using th logistic function.
1096
- Normalize across all the classes at the end such that the sum of
1097
- probabilities is 1.
1094
+ Else use a one-vs-rest approach, i.e calculate the probability
1095
+ of each class assuming it to be positive using the logistic function.
1096
+ and normalize these values across all the classes.
1098
1097
1099
1098
Parameters
1100
1099
----------
@@ -1106,7 +1105,17 @@ def predict_proba(self, X):
1106
1105
Returns the probability of the sample for each class in the model,
1107
1106
where classes are ordered as they are in ``self.classes_``.
1108
1107
"""
1109
- return self ._predict_proba_lr (X )
1108
+ if not hasattr (self , "coef_" ):
1109
+ raise NotFittedError ("Call fit before prediction" )
1110
+ calculate_ovr = self .coef_ .shape [0 ] == 1 or self .multi_class == "ovr"
1111
+ if calculate_ovr :
1112
+ return super (LogisticRegression , self )._predict_proba_lr (X )
1113
+ else :
1114
+ prob = self .decision_function (X )
1115
+ np .exp (prob , prob )
1116
+ sum_prob = np .sum (prob , axis = 1 ).reshape ((- 1 , 1 ))
1117
+ prob /= sum_prob
1118
+ return prob
1110
1119
1111
1120
def predict_log_proba (self , X ):
1112
1121
"""Log of probability estimates.
0 commit comments