override predict_proba in log_reg · scikit-learn/scikit-learn@491f0ad · GitHub
[go: up one dir, main page]

Skip to content

Commit 491f0ad

Browse files
committed
override predict_proba in log_reg
1 parent c3cfebe commit 491f0ad

File tree

2 files changed

+20
-29
lines changed

2 files changed

+20
-29
lines changed

sklearn/linear_model/base.py

Lines changed: 8 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -238,32 +238,16 @@ def _predict_proba_lr(self, X):
238238
1. / (1. + np.exp(-self.decision_function(X)));
239239
multiclass is handled by normalizing that over all classes.
240240
"""
241-
from sklearn.linear_model.logistic import (
242-
LogisticRegression, LogisticRegressionCV)
243-
244-
calculate_ovr = True
245241
prob = self.decision_function(X)
246-
binary = len(prob.shape) == 1
247-
if (isinstance(self, LogisticRegression) or
248-
isinstance(self, LogisticRegressionCV)) and (
249-
self.multi_class == "multinomial" and not binary):
250-
calculate_ovr = False
251-
if calculate_ovr:
252-
prob *= -1
253-
np.exp(prob, prob)
254-
prob += 1
255-
np.reciprocal(prob, prob)
256-
if binary:
257-
return np.vstack([1 - prob, prob]).T
258-
else:
259-
# OvR normalization, like LibLinear's predict_probability
260-
prob /= prob.sum(axis=1).reshape((prob.shape[0], -1))
261-
return prob
262-
242+
prob *= -1
243+
np.exp(prob, prob)
244+
prob += 1
245+
np.reciprocal(prob, prob)
246+
if prob.ndim == 1:
247+
return np.vstack([1 - prob, prob]).T
263248
else:
264-
np.exp(prob, prob)
265-
sum_prob = np.sum(prob, axis=1).reshape((-1, 1))
266-
prob /= sum_prob
249+
# OvR normalization, like LibLinear's predict_probability
250+
prob /= prob.sum(axis=1).reshape((prob.shape[0], -1))
267251
return prob
268252

269253

sklearn/linear_model/logistic.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1091,10 +1091,9 @@ def predict_proba(self, X):
10911091
For a multi_class problem, if multi_class is set to be "multinomial"
10921092
the softmax function is used to find the predicted probability of
10931093
each class.
1094-
Else use a one-vs-rest approach, i.e calculating the probability
1095-
of each class assuming it to be positive using th logistic function.
1096-
Normalize across all the classes at the end such that the sum of
1097-
probabilities is 1.
1094+
Else use a one-vs-rest approach, i.e. calculate the probability
1095+
of each class assuming it to be positive using the logistic function,
1096+
and normalize these values across all the classes.
10981097
10991098
Parameters
11001099
----------
@@ -1106,7 +1105,15 @@ def predict_proba(self, X):
11061105
Returns the probability of the sample for each class in the model,
11071106
where classes are ordered as they are in ``self.classes_``.
11081107
"""
1109-
return self._predict_proba_lr(X)
1108+
calculate_ovr = self.coef_.shape[0] == 1 or self.multi_class == "ovr"
1109+
if calculate_ovr:
1110+
return super(LogisticRegression, self)._predict_proba_lr(X)
1111+
else:
1112+
prob = self.decision_function(X)
1113+
np.exp(prob, prob)
1114+
sum_prob = np.sum(prob, axis=1).reshape((-1, 1))
1115+
prob /= sum_prob
1116+
return prob
11101117

11111118
def predict_log_proba(self, X):
11121119
"""Log of probability estimates.

0 commit comments

Comments
 (0)
0