[MRG] Adds FutureWarning changing default solver to 'lbfgs' and multi_class to 'multinomial' in LogisticRegression #10001
Changes from all commits
sklearn/linear_model/logistic.py

@@ -449,9 +449,9 @@ def _check_solver_option(solver, multi_class, penalty, dual):

 def logistic_regression_path(X, y, pos_class=None, Cs=10, fit_intercept=True,
                              max_iter=100, tol=1e-4, verbose=0,
-                             solver='lbfgs', coef=None,
+                             solver='default', coef=None,
                              class_weight=None, dual=False, penalty='l2',
-                             intercept_scaling=1., multi_class='ovr',
+                             intercept_scaling=1., multi_class='default',
                              random_state=None, check_input=True,
                              max_squared_sum=None, sample_weight=None):
     """Compute a Logistic Regression model for a list of regularization
@@ -500,6 +500,7 @@ def logistic_regression_path(X, y, pos_class=None, Cs=10, fit_intercept=True,
         number for verbosity.

     solver : {'lbfgs', 'newton-cg', 'liblinear', 'sag', 'saga'}
+        default: 'default'. Will be changed to 'auto' solver in 0.22.
         Numerical solver to use.

     coef : array-like, shape (n_features,), default None
@@ -540,6 +541,7 @@ def logistic_regression_path(X, y, pos_class=None, Cs=10, fit_intercept=True,
         (and therefore on the intercept) intercept_scaling has to be increased.

     multi_class : str, {'ovr', 'multinomial'}
+        default: 'default'. Will be changed to 'multinomial' in 0.22.
         Multiclass option can be either 'ovr' or 'multinomial'. If the option
         chosen is 'ovr', then a binary problem is fit for each label. Else
         the loss minimised is the multinomial loss fit across
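To make the distinction concrete, here is a small illustration (not part of this diff) of the two options on iris; it assumes a scikit-learn version where both options are available and uses 'lbfgs' because it supports the multinomial loss:

```python
# 'ovr' fits one binary problem per class; 'multinomial' minimises a single
# softmax (cross-entropy) loss over all classes jointly.
import numpy as np
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression

X, y = load_iris(return_X_y=True)

ovr = LogisticRegression(solver='lbfgs', multi_class='ovr').fit(X, y)
multi = LogisticRegression(solver='lbfgs', multi_class='multinomial').fit(X, y)

# Both expose one coefficient row per class, but the rows are estimated
# independently for 'ovr' and jointly for 'multinomial'.
print(ovr.coef_.shape, multi.coef_.shape)   # (3, 4) (3, 4)
print(np.allclose(ovr.coef_, multi.coef_))  # False in general
```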
@@ -587,6 +589,19 @@ def logistic_regression_path(X, y, pos_class=None, Cs=10, fit_intercept=True,
     .. versionchanged:: 0.19
         The "copy" parameter was removed.
     """
+    if solver == 'default':
+        solver = 'lbfgs'
+        warnings.warn("Default solver will be changed from 'lbfgs' "
+                      "to 'auto' solver in 0.22", FutureWarning)
+    elif solver == 'auto':
+        if penalty == 'l1':
+            solver = 'saga'
+        if penalty == 'l2':
+            solver = 'lbfgs'
+    if multi_class == 'default':
+        multi_class = 'ovr'
+        warnings.warn("Default multi_class will be changed from 'ovr' to"
+                      " 'multinomial' in 0.22", FutureWarning)
     if isinstance(Cs, numbers.Integral):
         Cs = np.logspace(-4, 4, Cs)

Review comment: We should probably only warn if auto would choose a solver other than lbfgs.
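Regarding the review note above: one way to warn only when the future 'auto' default would actually change behaviour could look roughly like the sketch below. It assumes 'auto' keeps the penalty-based rule used elsewhere in this diff ('saga' for 'l1', 'lbfgs' for 'l2'); the helper name is illustrative, not part of the PR.

```python
import warnings

def choose_solver(solver, penalty):
    # Illustrative helper, not part of this diff.
    auto_choice = 'saga' if penalty == 'l1' else 'lbfgs'
    if solver == 'default':
        # Warn only if the future 'auto' default would pick something other
        # than the current 'lbfgs' default for this penalty.
        if auto_choice != 'lbfgs':
            warnings.warn("Default solver will be changed from 'lbfgs' to "
                          "'auto' in 0.22; 'auto' would select %r here"
                          % auto_choice, FutureWarning)
        return 'lbfgs'
    if solver == 'auto':
        return auto_choice
    return solver
```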
@@ -1043,7 +1058,7 @@ class LogisticRegression(BaseEstimator, LinearClassifierMixin,
         'liblinear'.

     solver : {'newton-cg', 'lbfgs', 'liblinear', 'sag', 'saga'},
-        default: 'liblinear'
+        default: 'default'. Will be changed to 'auto' solver in 0.22.
         Algorithm to use in the optimization problem.

         - For small datasets, 'liblinear' is a good choice, whereas 'sag' and
@@ -1058,16 +1073,25 @@ class LogisticRegression(BaseEstimator, LinearClassifierMixin,
         features with approximately the same scale. You can
         preprocess the data with a scaler from sklearn.preprocessing.

+        The solver 'auto' selects 'lbfgs' if penalty is 'l2' and 'saga' if
+        penalty is 'l1'. Note that 'saga' may suffer from slow convergence
+        issues on small datasets. The only other solver supporting 'l1' is
+        'liblinear', which requires multiclass='ovr' and which unfortunately
+        regularizes the intercept (see 'intercept_scaling').
+
         .. versionadded:: 0.17
            Stochastic Average Gradient descent solver.
         .. versionadded:: 0.19
            SAGA solver.
+        .. versionadded:: 0.20
+           auto solver

     max_iter : int, default: 100
         Useful only for the newton-cg, sag and lbfgs solvers.
         Maximum number of iterations taken for the solvers to converge.

-    multi_class : str, {'ovr', 'multinomial'}, default: 'ovr'
+    multi_class : str, {'ovr', 'multinomial'},
+        default: 'default'. Will be changed to 'multinomial' in 0.22.
         Multiclass option can be either 'ovr' or 'multinomial'. If the option
         chosen is 'ovr', then a binary problem is fit for each label. Else
         the loss minimised is the multinomial loss fit across
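Under this change, code that passes solver and multi_class explicitly keeps its current behaviour and stays warning-free; only callers relying on the defaults see the FutureWarning. A small usage sketch (assuming the diff as written):

```python
import warnings
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression

X, y = load_iris(return_X_y=True)

with warnings.catch_warnings():
    warnings.simplefilter("error", FutureWarning)
    # Explicit values: no FutureWarning should be raised by this PR.
    LogisticRegression(solver='liblinear', multi_class='ovr').fit(X, y)

# LogisticRegression().fit(X, y) would instead emit the two FutureWarnings
# about the upcoming 'auto' solver and 'multinomial' defaults.
```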
@@ -1160,8 +1184,8 @@ class LogisticRegression(BaseEstimator, LinearClassifierMixin,

     def __init__(self, penalty='l2', dual=False, tol=1e-4, C=1.0,
                  fit_intercept=True, intercept_scaling=1, class_weight=None,
-                 random_state=None, solver='liblinear', max_iter=100,
-                 multi_class='ovr', verbose=0, warm_start=False, n_jobs=1):
+                 random_state=None, solver='default', max_iter=100,
+                 multi_class='default', verbose=0, warm_start=False, n_jobs=1):

         self.penalty = penalty
         self.dual = dual
@@ -1201,6 +1225,23 @@ def fit(self, X, y, sample_weight=None):
         -------
         self : object
         """
+        if self.solver == 'default':
+            _solver = 'liblinear'
+            warnings.warn("Default solver will be changed from 'liblinear' to "
+                          "'auto' solver in 0.22", FutureWarning)
+        elif self.solver == 'auto':
+            if self.penalty == 'l1':
+                _solver = 'saga'
+            if self.penalty == 'l2':
+                _solver = 'lbfgs'
+        else:
+            _solver = self.solver
+        if self.multi_class == 'default':
+            _multi_class = 'ovr'
+            warnings.warn("Default multi_class will be changed from 'ovr' to "
+                          "'multinomial' in 0.22", FutureWarning)
+        else:
+            _multi_class = self.multi_class
         if not isinstance(self.C, numbers.Number) or self.C < 0:
             raise ValueError("Penalty term must be positive; got (C=%r)"
                              % self.C)
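For users who want to keep today's behaviour without the noise during the deprecation cycle, the warnings introduced here can be filtered in the usual way; a sketch that assumes the message wording used in this diff:

```python
import warnings

# Silence only the deprecation notices added by this change; other
# FutureWarnings stay visible. The pattern must match the emitted messages.
warnings.filterwarnings(
    "ignore",
    message="Default (solver|multi_class) will be changed",
    category=FutureWarning,
)
```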
@@ -1211,7 +1252,7 @@ def fit(self, X, y, sample_weight=None):
             raise ValueError("Tolerance for stopping criteria must be "
                              "positive; got (tol=%r)" % self.tol)

-        if self.solver in ['newton-cg']:
+        if _solver in ['newton-cg']:
             _dtype = [np.float64, np.float32]
         else:
             _dtype = np.float64
@@ -1222,10 +1263,10 @@ def fit(self, X, y, sample_weight=None):
         self.classes_ = np.unique(y)
         n_samples, n_features = X.shape

-        _check_solver_option(self.solver, self.multi_class, self.penalty,
+        _check_solver_option(_solver, _multi_class, self.penalty,
                              self.dual)

-        if self.solver == 'liblinear':
+        if _solver == 'liblinear':
             if self.n_jobs != 1:
                 warnings.warn("'n_jobs' > 1 does not have any effect when"
                               " 'solver' is set to 'liblinear'. Got 'n_jobs'"
@@ -1238,7 +1279,7 @@ def fit(self, X, y, sample_weight=None):
             self.n_iter_ = np.array([n_iter_])
             return self

-        if self.solver in ['sag', 'saga']:
+        if _solver in ['sag', 'saga']:
             max_squared_sum = row_norms(X, squared=True).max()
         else:
             max_squared_sum = None
@@ -1267,7 +1308,7 @@ def fit(self, X, y, sample_weight=None):
         self.intercept_ = np.zeros(n_classes)

         # Hack so that we iterate only once for the multinomial case.
-        if self.multi_class == 'multinomial':
+        if _multi_class == 'multinomial':
             classes_ = [None]
             warm_start_coef = [warm_start_coef]
         if warm_start_coef is None:
@@ -1277,16 +1318,16 @@ def fit(self, X, y, sample_weight=None):

         # The SAG solver releases the GIL so it's more efficient to use
         # threads for this solver.
-        if self.solver in ['sag', 'saga']:
+        if _solver in ['sag', 'saga']:
             backend = 'threading'
         else:
             backend = 'multiprocessing'
         fold_coefs_ = Parallel(n_jobs=self.n_jobs, verbose=self.verbose,
                                backend=backend)(
             path_func(X, y, pos_class=class_, Cs=[self.C],
                       fit_intercept=self.fit_intercept, tol=self.tol,
-                      verbose=self.verbose, solver=self.solver,
-                      multi_class=self.multi_class, max_iter=self.max_iter,
+                      verbose=self.verbose, solver=_solver,
+                      multi_class=_multi_class, max_iter=self.max_iter,
                       class_weight=self.class_weight, check_input=False,
                       random_state=self.random_state, coef=warm_start_coef_,
                       penalty=self.penalty,
@@ -1297,7 +1338,7 @@ def fit(self, X, y, sample_weight=None):
         fold_coefs_, _, n_iter_ = zip(*fold_coefs_)
         self.n_iter_ = np.asarray(n_iter_, dtype=np.int32)[:, 0]

-        if self.multi_class == 'multinomial':
+        if _multi_class == 'multinomial':
             self.coef_ = fold_coefs_[0][0]
         else:
             self.coef_ = np.asarray(fold_coefs_)
@@ -1335,7 +1376,11 @@ def predict_proba(self, X):
         """
         if not hasattr(self, "coef_"):
             raise NotFittedError("Call fit before prediction")
-        if self.multi_class == "ovr":
+        if self.multi_class == 'default':
+            _multi_class = 'ovr'
+        else:
+            _multi_class = self.multi_class
+        if _multi_class == "ovr":
             return super(LogisticRegression, self)._predict_proba_lr(X)
         else:
             decision = self.decision_function(X)
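For reference, the two branches above compute roughly the following for the multiclass case (a simplified numpy sketch; the actual implementation goes through _predict_proba_lr for 'ovr' and a softmax of the decision values for 'multinomial'):

```python
import numpy as np

def predict_proba_sketch(decision, multi_class):
    """decision: (n_samples, n_classes) output of decision_function."""
    if multi_class == 'ovr':
        # Per-class sigmoid scores, renormalised so each row sums to one.
        prob = 1.0 / (1.0 + np.exp(-decision))
        return prob / prob.sum(axis=1, keepdims=True)
    # 'multinomial': softmax over the decision values.
    shifted = decision - decision.max(axis=1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)
```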
@@ -1425,7 +1470,7 @@ class LogisticRegressionCV(LogisticRegression, BaseEstimator,
         default scoring option used is 'accuracy'.

     solver : {'newton-cg', 'lbfgs', 'liblinear', 'sag', 'saga'},
-        default: 'lbfgs'
+        default: 'default'. Will be changed to 'auto' solver in 0.22.
         Algorithm to use in the optimization problem.

         - For small datasets, 'liblinear' is a good choice, whereas 'sag' and
@@ -1442,6 +1487,14 @@ class LogisticRegressionCV(LogisticRegression, BaseEstimator,
         features with approximately the same scale. You can preprocess the data
         with a scaler from sklearn.preprocessing.

+        The solver 'auto' selects 'lbfgs' if penalty is 'l2' and 'saga' if
+        penalty is 'l1'. Note that 'saga' may suffer from slow convergence
+        issues on small datasets. The only other solver supporting 'l1' is
+        'liblinear', which requires multiclass='ovr' and which unfortunately
+        regularizes the intercept (see 'intercept_scaling').
+
+        .. versionadded:: 0.20
+           auto solver
         .. versionadded:: 0.17
            Stochastic Average Gradient descent solver.
         .. versionadded:: 0.19
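A usage sketch of the documented rule for LogisticRegressionCV (this assumes the 'auto' option from this diff; with penalty='l1' only 'saga' and 'liblinear' support the penalty, and 'auto' resolves to 'saga'):

```python
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegressionCV

X, y = make_classification(n_samples=200, n_features=20, random_state=0)

# With this PR, solver='auto' resolves to 'saga' because penalty='l1';
# 'saga' may need more iterations to converge on small datasets.
clf = LogisticRegressionCV(penalty='l1', solver='auto', cv=3,
                           max_iter=1000).fit(X, y)
```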
@@ -1496,6 +1549,7 @@ class LogisticRegressionCV(LogisticRegression, BaseEstimator,
         (and therefore on the intercept) intercept_scaling has to be increased.

     multi_class : str, {'ovr', 'multinomial'}
+        default: 'default'. Will be changed to 'multinomial' in 0.22
         Multiclass option can be either 'ovr' or 'multinomial'. If the option
         chosen is 'ovr', then a binary problem is fit for each label. Else
         the loss minimised is the multinomial loss fit across
@@ -1565,9 +1619,9 @@ class LogisticRegressionCV(LogisticRegression, BaseEstimator,
     """

     def __init__(self, Cs=10, fit_intercept=True, cv=None, dual=False,
-                 penalty='l2', scoring=None, solver='lbfgs', tol=1e-4,
+                 penalty='l2', scoring=None, solver='default', tol=1e-4,
                  max_iter=100, class_weight=None, n_jobs=1, verbose=0,
-                 refit=True, intercept_scaling=1., multi_class='ovr',
+                 refit=True, intercept_scaling=1., multi_class='default',
                  random_state=None):
         self.Cs = Cs
         self.fit_intercept = fit_intercept
@@ -1606,7 +1660,24 @@ def fit(self, X, y, sample_weight=None):
         -------
         self : object
         """
-        _check_solver_option(self.solver, self.multi_class, self.penalty,
+        if self.solver == 'default':
+            _solver = 'lbfgs'
+            warnings.warn("Default solver will be changed from 'lbfgs' "
+                          "to 'auto' solver in 0.22", FutureWarning)
+        elif self.solver == 'auto':
+            if self.penalty == 'l1':
+                _solver = 'saga'
+            if self.penalty == 'l2':
+                _solver = 'lbfgs'
+        else:
+            _solver = self.solver
+        if self.multi_class == 'default':
+            _multi_class = 'ovr'
+            warnings.warn("Default multi_class will be changed from 'ovr' to"
+                          " 'multinomial' in 0.22", FutureWarning)
+        else:
+            _multi_class = self.multi_class
+        _check_solver_option(_solver, _multi_class, self.penalty,
                              self.dual)

         if not isinstance(self.max_iter, numbers.Number) or self.max_iter < 0:

Review comment: I find it strange to have these preceding underscores. This is just a plain old local variable.
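On the naming comment above: the resolution logic is duplicated between LogisticRegression.fit and LogisticRegressionCV.fit, so one option would be a single module-level helper with plain local names, e.g. (illustrative only, not part of this diff):

```python
import warnings

def resolve_solver_options(solver, multi_class, penalty, current_default):
    # Hypothetical shared helper mirroring the blocks added in both fit
    # methods; 'current_default' is 'liblinear' for LogisticRegression and
    # 'lbfgs' for LogisticRegressionCV.
    if solver == 'default':
        solver = current_default
        warnings.warn("Default solver will be changed from %r to 'auto' "
                      "in 0.22" % current_default, FutureWarning)
    elif solver == 'auto':
        solver = 'saga' if penalty == 'l1' else 'lbfgs'
    if multi_class == 'default':
        multi_class = 'ovr'
        warnings.warn("Default multi_class will be changed from 'ovr' to "
                      "'multinomial' in 0.22", FutureWarning)
    return solver, multi_class
```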
@@ -1633,7 +1704,7 @@ def fit(self, X, y, sample_weight=None):
         classes = self.classes_ = label_encoder.classes_
         encoded_labels = label_encoder.transform(label_encoder.classes_)

-        if self.solver in ['sag', 'saga']:
+        if _solver in ['sag', 'saga']:
             max_squared_sum = row_norms(X, squared=True).max()
         else:
             max_squared_sum = None
@@ -1659,7 +1730,7 @@ def fit(self, X, y, sample_weight=None):

         # We need this hack to iterate only once over labels, in the case of
         # multi_class = multinomial, without changing the value of the labels.
-        if self.multi_class == 'multinomial':
+        if _multi_class == 'multinomial':
             iter_encoded_labels = iter_classes = [None]
         else:
             iter_encoded_labels = encoded_labels
@@ -1676,18 +1747,18 @@ def fit(self, X, y, sample_weight=None):

         # The SAG solver releases the GIL so it's more efficient to use
         # threads for this solver.
-        if self.solver in ['sag', 'saga']:
+        if _solver in ['sag', 'saga']:
             backend = 'threading'
         else:
             backend = 'multiprocessing'
         fold_coefs_ = Parallel(n_jobs=self.n_jobs, verbose=self.verbose,
                                backend=backend)(
             path_func(X, y, train, test, pos_class=label, Cs=self.Cs,
                       fit_intercept=self.fit_intercept, penalty=self.penalty,
-                      dual=self.dual, solver=self.solver, tol=self.tol,
+                      dual=self.dual, solver=_solver, tol=self.tol,
                       max_iter=self.max_iter, verbose=self.verbose,
                       class_weight=class_weight, scoring=self.scoring,
-                      multi_class=self.multi_class,
+                      multi_class=_multi_class,
                       intercept_scaling=self.intercept_scaling,
                       random_state=self.random_state,
                       max_squared_sum=max_squared_sum,
@@ -1696,7 +1767,7 @@ def fit(self, X, y, sample_weight=None):
             for label in iter_encoded_labels
             for train, test in folds)

-        if self.multi_class == 'multinomial':
+        if _multi_class == 'multinomial':
             multi_coefs_paths, Cs, multi_scores, n_iter_ = zip(*fold_coefs_)
             multi_coefs_paths = np.asarray(multi_coefs_paths)
             multi_scores = np.asarray(multi_scores)
@@ -1733,14 +1804,14 @@ def fit(self, X, y, sample_weight=None):
         self.intercept_ = np.zeros(n_classes)

         # hack to iterate only once for multinomial case.
-        if self.multi_class == 'multinomial':
+        if _multi_class == 'multinomial':
             scores = multi_scores
             coefs_paths = multi_coefs_paths

         for index, (cls, encoded_label) in enumerate(
                 zip(iter_classes, iter_encoded_labels)):

-            if self.multi_class == 'ovr':
+            if _multi_class == 'ovr':
                 # The scores_ / coefs_paths_ dict have unencoded class
                 # labels as their keys
                 scores = self.scores_[cls]
@@ -1751,7 +1822,7 @@ def fit(self, X, y, sample_weight=None):

                 C_ = self.Cs_[best_index]
                 self.C_.append(C_)
-                if self.multi_class == 'multinomial':
+                if _multi_class == 'multinomial':
                     coef_init = np.mean(coefs_paths[:, best_index, :, :],
                                         axis=0)
                 else:
@@ -1760,12 +1831,12 @@ def fit(self, X, y, sample_weight=None):
                 # Note that y is label encoded and hence pos_class must be
                 # the encoded label / None (for 'multinomial')
                 w, _, _ = logistic_regression_path(
-                    X, y, pos_class=encoded_label, Cs=[C_], solver=self.solver,
+                    X, y, pos_class=encoded_label, Cs=[C_], solver=_solver,
                     fit_intercept=self.fit_intercept, coef=coef_init,
                     max_iter=self.max_iter, tol=self.tol,
                     penalty=self.penalty,
                     class_weight=class_weight,
-                    multi_class=self.multi_class,
+                    multi_class=_multi_class,
                     verbose=max(0, self.verbose - 1),
                     random_state=self.random_state,
                     check_input=False, max_squared_sum=max_squared_sum,
@@ -1780,7 +1851,7 @@ def fit(self, X, y, sample_weight=None):
                              for i in range(len(folds))], axis=0)
                 self.C_.append(np.mean(self.Cs_[best_indices]))

-            if self.multi_class == 'multinomial':
+            if _multi_class == 'multinomial':
                 self.C_ = np.tile(self.C_, n_classes)
                 self.coef_ = w[:, :X.shape[1]]
                 if self.fit_intercept:
Review comment: default 'default' doesn't mean anything. If you're going to say such a thing, you need to explain what it means. Better to say it is lbfgs by default, even if the string 'default' is used as a placeholder. 'auto' needs to be added to the list of options and described.
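Concretely, that suggestion could translate into parameter documentation along these lines (a sketch of possible wording, not text adopted by this PR):

```
solver : {'newton-cg', 'lbfgs', 'liblinear', 'sag', 'saga', 'auto'}
    default: 'liblinear' (the string 'default' is only a placeholder used
    during the deprecation cycle). The default will change to 'auto' in
    0.22. 'auto' selects 'lbfgs' when penalty is 'l2' and 'saga' when
    penalty is 'l1'.
    Algorithm to use in the optimization problem.
```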