ENH: Support data centering in LogisticRegression · kernc/scikit-learn@d9040ff · GitHub
[go: up one dir, main page]

Skip to content

Commit d9040ff

Browse files
committed
ENH: Support data centering in LogisticRegression
1 parent 113ee40 commit d9040ff

File tree

2 files changed

+35
-7
lines changed

2 files changed

+35
-7
lines changed

sklearn/linear_model/base.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -269,12 +269,15 @@ def predict(self, X):
269269

270270
_preprocess_data = staticmethod(_preprocess_data)
271271

272-
def _set_intercept(self, X_offset, y_offset, X_scale):
272+
def _set_intercept(self, X_offset, y_offset, X_scale, intercept_=None):
273273
"""Set the intercept_
274274
"""
275275
if self.fit_intercept:
276276
self.coef_ = self.coef_ / X_scale
277-
self.intercept_ = y_offset - np.dot(X_offset, self.coef_.T)
277+
if intercept_ is None:
278+
self.intercept_ = y_offset - np.dot(X_offset, self.coef_.T)
279+
else:
280+
self.intercept_ = y_offset - intercept_
278281
else:
279282
self.intercept_ = 0.
280283

sklearn/linear_model/logistic.py

Lines changed: 30 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@
1616
import numpy as np
1717
from scipy import optimize, sparse
1818

19-
from .base import LinearClassifierMixin, SparseCoefMixin, BaseEstimator
19+
from .base import (LinearClassifierMixin, SparseCoefMixin, BaseEstimator,
20+
LinearModel)
2021
from .sag import sag_solver
2122
from ..feature_selection.from_model import _LearntSelectorMixin
2223
from ..preprocessing import LabelEncoder, LabelBinarizer
@@ -987,7 +988,9 @@ class LogisticRegression(BaseEstimator, LinearClassifierMixin,
987988
988989
fit_intercept : bool, default: True
989990
Specifies if a constant (a.k.a. bias or intercept) should be
990-
added to the decision function.
991+
added to the decision function. If set to False, no intercept
992+
will be used in calculations (e.g. data is expected to be already
993+
centered).
991994
992995
intercept_scaling : float, default: 1
993996
Useful only if solver is liblinear.
@@ -1001,6 +1004,19 @@ class LogisticRegression(BaseEstimator, LinearClassifierMixin,
10011004
To lessen the effect of regularization on synthetic feature weight
10021005
(and therefore on the intercept) intercept_scaling has to be increased.
10031006
1007+
normalize : boolean, optional, default: False
1008+
If True, the regressors X will be normalized before regression.
1009+
This parameter is ignored when `fit_intercept` is set to False.
1010+
When the regressors are normalized, note that this makes the
1011+
hyperparameters learnt more robust and almost independent of the number
1012+
of samples. The same property is not valid for standardized data.
1013+
However, if you wish to standardize, please use
1014+
`preprocessing.StandardScaler` before calling `fit` on an estimator
1015+
with `normalize=False`.
1016+
1017+
copy_X : boolean, optional, default: True
1018+
If True, X will be copied; else, it may be overwritten.
1019+
10041020
class_weight : dict or 'balanced', default: None
10051021
Weights associated with classes in the form ``{class_label: weight}``.
10061022
If not given, all classes are supposed to have weight one.
@@ -1114,16 +1130,19 @@ class LogisticRegression(BaseEstimator, LinearClassifierMixin,
11141130
"""
11151131

11161132
def __init__(self, penalty='l2', dual=False, tol=1e-4, C=1.0,
1117-
fit_intercept=True, intercept_scaling=1, class_weight=None,
1118-
random_state=None, solver='liblinear', max_iter=100,
1119-
multi_class='ovr', verbose=0, warm_start=False, n_jobs=1):
1133+
fit_intercept=True, intercept_scaling=1, normalize=False,
1134+
copy_X=True, class_weight=None, random_state=None,
1135+
solver='liblinear', max_iter=100, multi_class='ovr',
1136+
verbose=0, warm_start=False, n_jobs=1):
11201137

11211138
self.penalty = penalty
11221139
self.dual = dual
11231140
self.tol = tol
11241141
self.C = C
11251142
self.fit_intercept = fit_intercept
11261143
self.intercept_scaling = intercept_scaling
1144+
self.normalize = normalize
1145+
self.copy_X = copy_X
11271146
self.class_weight = class_weight
11281147
self.random_state = random_state
11291148
self.solver = solver
@@ -1176,13 +1195,18 @@ def fit(self, X, y, sample_weight=None):
11761195
_check_solver_option(self.solver, self.multi_class, self.penalty,
11771196
self.dual)
11781197

1198+
X, y, X_offset, y_offset, X_scale = self._preprocess_data(
1199+
X, y, self.fit_intercept, self.normalize, self.copy_X,
1200+
sample_weight=sample_weight)
1201+
11791202
if self.solver == 'liblinear':
11801203
self.coef_, self.intercept_, n_iter_ = _fit_liblinear(
11811204
X, y, self.C, self.fit_intercept, self.intercept_scaling,
11821205
self.class_weight, self.penalty, self.dual, self.verbose,
11831206
self.max_iter, self.tol, self.random_state,
11841207
sample_weight=sample_weight)
11851208
self.n_iter_ = np.array([n_iter_])
1209+
self._set_intercept(X_offset, y_offset, X_scale, self.intercept_)
11861210
return self
11871211

11881212
if self.solver == 'sag':
@@ -1251,6 +1275,7 @@ def fit(self, X, y, sample_weight=None):
12511275
if self.fit_intercept:
12521276
self.intercept_ = self.coef_[:, -1]
12531277
self.coef_ = self.coef_[:, :-1]
1278+
self._set_intercept(X_offset, y_offset, X_scale, self.intercept_)
12541279

12551280
return self
12561281

0 commit comments

Comments
 (0)
0