8000 ENH: Support centering in LogisticRegression · kernc/scikit-learn@7449cca · GitHub
[go: up one dir, main page]

Skip to content

Commit 7449cca

Browse files
committed
ENH: Support centering in LogisticRegression
1 parent 113ee40 commit 7449cca

File tree

2 files changed

+33
-6
lines changed

2 files changed

+33
-6
lines changed

sklearn/linear_model/base.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -274,7 +274,10 @@ def _set_intercept(self, X_offset, y_offset, X_scale):
274274
"""
275275
if self.fit_intercept:
276276
self.coef_ = self.coef_ / X_scale
277-
self.intercept_ = y_offset - np.dot(X_offset, self.coef_.T)
277+
if hasattr(self, 'intercept_'):
278+
self.intercept_ = y_offset - self.intercept_
279+
else:
280+
self.intercept_ = y_offset - np.dot(X_offset, self.coef_.T)
278281
else:
279282
self.intercept_ = 0.
280283

sklearn/linear_model/logistic.py

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@
1616
import numpy as np
1717
from scipy import optimize, sparse
1818

19-
from .base import LinearClassifierMixin, SparseCoefMixin, BaseEstimator
19+
from .base import (LinearClassifierMixin, SparseCoefMixin, BaseEstimator,
20+
LinearModel)
2021
from .sag import sag_solver
2122
from ..feature_selection.from_model import _LearntSelectorMixin
2223
from ..preprocessing import LabelEncoder, LabelBinarizer
@@ -948,7 +949,7 @@ def _log_reg_scoring_path(X, y, train, test, pos_class=None, Cs=10,
948949

949950

950951
class LogisticRegression(BaseEstimator, LinearClassifierMixin,
951-
_LearntSelectorMixin, SparseCoefMixin):
952+
_LearntSelectorMixin, SparseCoefMixin, LinearModel):
952953
"""Logistic Regression (aka logit, MaxEnt) classifier.
953954
954955
In the multiclass case, the training algorithm uses the one-vs-rest (OvR)
@@ -1001,6 +1002,19 @@ class LogisticRegression(BaseEstimator, LinearClassifierMixin,
10011002
To lessen the effect of regularization on synthetic feature weight
10021003
(and therefore on the intercept) intercept_scaling has to be increased.
10031004
1005+
normalize : boolean, optional, default: False
1006+
If True, the regressors X will be normalized before regression.
1007+
This parameter is ignored when `fit_intercept` is set to False.
1008+
When the regressors are normalized, note that this makes the
1009+
hyperparameters learnt more robust and almost independent of the number
1010+
of samples. The same property is not valid for standardized data.
1011+
However, if you wish to standardize, please use
1012+
`preprocessing.StandardScaler` before calling `fit` on an estimator
1013+
with `normalize=False`.
1014+
1015+
copy_X : boolean, optional, default: True
1016+
If True, X will be copied; else, it may be overwritten.
1017+
10041018
class_weight : dict or 'balanced', default: None
10051019
Weights associated with classes in the form ``{class_label: weight}``.
10061020
If not given, all classes are supposed to have weight one.
@@ -1114,16 +1128,19 @@ class LogisticRegression(BaseEstimator, LinearClassifierMixin,
11141128
"""
11151129

11161130
def __init__(self, penalty='l2', dual=False, tol=1e-4, C=1.0,
1117-
fit_intercept=True, intercept_scaling=1, class_weight=None,
1118-
random_state=None, solver='liblinear', max_iter=100,
1119-
multi_class='ovr', verbose=0, warm_start=False, n_jobs=1):
1131+
fit_intercept=True, intercept_scaling=1, normalize=False,
1132+
copy_X=True, class_weight=None, random_state=None,
1133+
solver='liblinear', max_iter=100, multi_class='ovr',
1134+
verbose=0, warm_start=False, n_jobs=1):
11201135

11211136
self.penalty = penalty
11221137
self.dual = dual
11231138
self.tol = tol
11241139
self.C = C
11251140
self.fit_intercept = fit_intercept
11261141
self.intercept_scaling = intercept_scaling
1142+
self.normalize = normalize
1143+
self.copy_X = copy_X
11271144
self.class_weight = class_weight
11281145
self.random_state = random_state
11291146
self.solver = solver
@@ -1176,13 +1193,18 @@ def fit(self, X, y, sample_weight=None):
11761193
_check_solver_option(self.solver, self.multi_class, self.penalty,
11771194
self.dual)
11781195

1196+
X, y, X_offset, y_offset, X_scale = self._preprocess_data(
1197+
X, y, self.fit_intercept, self.normalize, self.copy_X,
1198+
sample_weight=sample_weight)
1199+
11791200
if self.solver == 'liblinear':
11801201
self.coef_, self.intercept_, n_iter_ = _fit_liblinear(
11811202
X, y, self.C, self.fit_intercept, self.intercept_scaling,
11821203
self.class_weight, self.penalty, self.dual, self.verbose,
11831204
self.max_iter, self.tol, self.random_state,
11841205
sample_weight=sample_weight)
11851206
self.n_iter_ = np.array([n_iter_])
1207+
self._set_intercept(X_offset, y_offset, X_scale)
11861208
return self
11871209

11881210
if self.solver == 'sag':
@@ -1252,6 +1274,8 @@ def fit(self, X, y, sample_weight=None):
12521274
self.intercept_ = self.coef_[:, -1]
12531275
self.coef_ = self.coef_[:, :-1]
12541276

1277+
self._set_intercept(X_offset, y_offset, X_scale)
1278+
12551279
return self
12561280

12571281
def predict_proba(self, X):

0 commit comments

Comments (0)