diff --git a/sklearn/feature_selection/_from_model.py b/sklearn/feature_selection/_from_model.py
index 5fb519a2bd798..0357fa5892cf2 100644
--- a/sklearn/feature_selection/_from_model.py
+++ b/sklearn/feature_selection/_from_model.py
@@ -144,7 +144,7 @@ class SelectFromModel(MetaEstimatorMixin, SelectorMixin, BaseEstimator):
     >>> y = [0, 1, 0, 1]
     >>> selector = SelectFromModel(estimator=LogisticRegression()).fit(X, y)
     >>> selector.estimator_.coef_
-    array([[-0.3252302 , 0.83462377, 0.49750423]])
+    array([[-0.32... , 0.83..., 0.49...]])
     >>> selector.threshold_
     0.55245...
     >>> selector.get_support()
diff --git a/sklearn/linear_model/_logistic.py b/sklearn/linear_model/_logistic.py
index 1afa06637b04a..e547e198407ee 100644
--- a/sklearn/linear_model/_logistic.py
+++ b/sklearn/linear_model/_logistic.py
@@ -26,7 +26,7 @@
 from ..utils import check_random_state
 from ..utils.extmath import (log_logistic, safe_sparse_dot, softmax,
                              squared_norm)
-from ..utils.extmath import row_norms
+from ..utils.extmath import row_norms, _weighted_mean_std
 from ..utils.optimize import _newton_cg, _check_optimize_result
 from ..utils.validation import check_is_fitted, _check_sample_weight
 from ..utils.validation import _deprecate_positional_args
@@ -44,7 +44,7 @@


 # .. some helper functions for logistic_regression_path ..
-def _intercept_dot(w, X, y):
+def _intercept_dot(w, X, y, X_offset=None):
     """Computes y * np.dot(X, w).

     It takes into consideration if the intercept should be fit or not.
@@ -60,6 +60,11 @@ def _intercept_dot(w, X, y):
     y : ndarray of shape (n_samples,)
         Array of labels.

+    X_offset : ndarray, shape (n_features,) or None
+        Offset to use for X to avoid subtracting the mean from sparse
+        matrices when preconditioning. Should be None in the dense case,
+        as the mean was actually subtracted.
+
     Returns
     -------
     w : ndarray of shape (n_features,)
@@ -78,11 +83,14 @@ def _intercept_dot(w, X, y):
         w = w[:-1]

     z = safe_sparse_dot(X, w) + c
+    if X_offset is not None:
+        z += np.inner(X_offset, w)
     yz = y * z
     return w, c, yz


-def _logistic_loss_and_grad(w, X, y, alpha, sample_weight=None):
+def _logistic_loss_and_grad(w, X, y, alpha, sample_weight=None, X_scale=None,
+                            X_offset=None):
     """Computes the logistic loss and gradient.

     Parameters
@@ -103,6 +111,15 @@ def _logistic_loss_and_grad(w, X, y, alpha, sample_weight=None):
         Array of weights that are assigned to individual samples.
         If not provided, then each sample is given unit weight.

+    X_scale : ndarray, shape (n_features,) or None
+        Rescaling that was applied to X for preconditioning.
+        Needed to correctly compute the penalty term.
+
+    X_offset : ndarray, shape (n_features,) or None
+        Offset to use for X to avoid subtracting the mean from sparse
+        matrices when preconditioning. Should be None in the dense case,
+        as the mean was actually subtracted.
+
     Returns
     -------
     out : float
@@ -114,18 +131,27 @@ def _logistic_loss_and_grad(w, X, y, alpha, sample_weight=None):
     n_samples, n_features = X.shape
     grad = np.empty_like(w)

-    w, c, yz = _intercept_dot(w, X, y)
+    w, c, yz = _intercept_dot(w, X, y, X_offset)

     if sample_weight is None:
         sample_weight = np.ones(n_samples)

+    v = w
+    if X_scale is not None:
+        v = w / X_scale
+
     # Logistic loss is the negative of the log of the logistic function.
-    out = -np.sum(sample_weight * log_logistic(yz)) + .5 * alpha * np.dot(w, w)
+    out = -np.sum(sample_weight * log_logistic(yz)) + .5 * alpha * np.dot(v, v)

     z = expit(yz)
     z0 = sample_weight * (z - 1) * y
+    if X_scale is not None:
+        grad[:n_features] = safe_sparse_dot(X.T, z0) + alpha * (w / X_scale**2)
+    else:
+        grad[:n_features] = safe_sparse_dot(X.T, z0) + alpha * w

-    grad[:n_features] = safe_sparse_dot(X.T, z0) + alpha * w
+    if X_offset is not None:
+        grad[:n_features] += X_offset * z0.sum()

     # Case where we fit the intercept.
     if grad.shape[0] > n_features:
@@ -133,7 +159,8 @@ def _logistic_loss_and_grad(w, X, y, alpha, sample_weight=None):
     return out, grad


-def _logistic_loss(w, X, y, alpha, sample_weight=None):
+def _logistic_loss(w, X, y, alpha, sample_weight=None, X_scale=None,
+                   X_offset=None):
     """Computes the logistic loss.

     Parameters
@@ -154,18 +181,30 @@ def _logistic_loss(w, X, y, alpha, sample_weight=None):
         Array of weights that are assigned to individual samples.
         If not provided, then each sample is given unit weight.

+    X_scale : ndarray, shape (n_features,) or None
+        Rescaling that was applied to X for preconditioning.
+        Needed to correctly compute the penalty term.
+
+    X_offset : ndarray, shape (n_features,) or None
+        Offset to use for X to avoid subtracting the mean from sparse
+        matrices when preconditioning. Should be None in the dense case,
+        as the mean was actually subtracted.
+
     Returns
     -------
     out : float
         Logistic loss.
     """
-    w, c, yz = _intercept_dot(w, X, y)
+    w, c, yz = _intercept_dot(w, X, y, X_offset)

     if sample_weight is None:
         sample_weight = np.ones(y.shape[0])

     # Logistic loss is the negative of the log of the logistic function.
-    out = -np.sum(sample_weight * log_logistic(yz)) + .5 * alpha * np.dot(w, w)
+    v = w
+    if X_scale is not None:
+        v = w / X_scale
+    out = -np.sum(sample_weight * log_logistic(yz)) + .5 * alpha * np.dot(v, v)
     return out
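
Note (illustrative, not part of the patch): the X_scale/X_offset handling above relies on a reparameterization that leaves the penalized objective unchanged. A minimal sketch of that identity, assuming the keyword arguments introduced in this diff:

import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model._logistic import _logistic_loss

X, y = make_classification(n_samples=50, n_features=5, random_state=0)
y_pm1 = 2 * y - 1                 # the private helpers expect labels in {-1, 1}
rng = np.random.RandomState(0)
w, c = rng.randn(X.shape[1]), 0.3  # arbitrary coefficients and intercept

X_mean, X_std = X.mean(axis=0), X.std(axis=0)
X_pre = (X - X_mean) / X_std       # dense case: the mean really is subtracted
# coefficients as the solver would see them in the preconditioned space
w_pre = np.hstack([w * X_std, c + np.inner(w, X_mean)])

loss_plain = _logistic_loss(np.hstack([w, c]), X, y_pm1, alpha=1.)
loss_pre = _logistic_loss(w_pre, X_pre, y_pm1, alpha=1., X_scale=X_std)
np.testing.assert_allclose(loss_plain, loss_pre)  # same objective value

The test test_logistic_loss_preconditioning added later in this diff exercises the same identity end to end.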
@@ -246,7 +285,8 @@ def Hs(s):
     return grad, Hs


-def _multinomial_loss(w, X, Y, alpha, sample_weight):
+def _multinomial_loss(w, X, Y, alpha, sample_weight, X_scale=None,
+                      X_offset=None):
     """Computes multinomial loss and class probabilities.

     Parameters
@@ -267,6 +307,15 @@ def _multinomial_loss(w, X, Y, alpha, sample_weight):
     sample_weight : array-like of shape (n_samples,)
         Array of weights that are assigned to individual samples.

+    X_scale : ndarray, shape (n_features,) or None
+        Rescaling that was applied to X for preconditioning.
+        Needed to correctly compute the penalty term.
+
+    X_offset : ndarray, shape (n_features,) or None
+        Offset to use for X to avoid subtracting the mean from sparse
+        matrices when preconditioning. Should be None in the dense case,
+        as the mean was actually subtracted.
+
     Returns
     -------
     loss : float
@@ -293,16 +342,23 @@ def _multinomial_loss(w, X, Y, alpha, sample_weight):
         w = w[:, :-1]
     else:
         intercept = 0
+
     p = safe_sparse_dot(X, w.T)
     p += intercept
+    if X_offset is not None:
+        p += np.dot(X_offset, w.T)
     p -= logsumexp(p, axis=1)[:, np.newaxis]
     loss = -(sample_weight * Y * p).sum()
-    loss += 0.5 * alpha * squared_norm(w)
+    v = w
+    if X_scale is not None:
+        v = w / X_scale
+    loss += 0.5 * alpha * squared_norm(v)
     p = np.exp(p, p)
     return loss, p, w


-def _multinomial_loss_grad(w, X, Y, alpha, sample_weight):
+def _multinomial_loss_grad(w, X, Y, alpha, sample_weight, X_scale=None,
+                           X_offset=None):
     """Computes the multinomial loss, gradient and class probabilities.

     Parameters
@@ -323,6 +379,15 @@ def _multinomial_loss_grad(w, X, Y, alpha, sample_weight):
     sample_weight : array-like of shape (n_samples,)
         Array of weights that are assigned to individual samples.

+    X_scale : ndarray, shape (n_features,) or None
+        Rescaling that was applied to X for preconditioning.
+        Needed to correctly compute the penalty term.
+
+    X_offset : ndarray, shape (n_features,) or None
+        Offset to use for X to avoid subtracting the mean from sparse
+        matrices when preconditioning. Should be None in the dense case,
+        as the mean was actually subtracted.
+
     Returns
     -------
     loss : float
@@ -345,11 +410,17 @@ def _multinomial_loss_grad(w, X, Y, alpha, sample_weight):
     fit_intercept = (w.size == n_classes * (n_features + 1))
     grad = np.zeros((n_classes, n_features + bool(fit_intercept)),
                     dtype=X.dtype)
-    loss, p, w = _multinomial_loss(w, X, Y, alpha, sample_weight)
+    loss, p, w = _multinomial_loss(w, X, Y, alpha, sample_weight,
+                                   X_scale=X_scale, X_offset=X_offset)
     sample_weight = sample_weight[:, np.newaxis]
     diff = sample_weight * (p - Y)
     grad[:, :n_features] = safe_sparse_dot(diff.T, X)
-    grad[:, :n_features] += alpha * w
+    if X_offset is not None:
+        grad[:, :n_features] += np.outer(diff.T.sum(axis=1), X_offset)
+    if X_scale is not None:
+        grad[:, :n_features] += alpha * (w / X_scale**2)
+    else:
+        grad[:, :n_features] += alpha * w
     if fit_intercept:
         grad[:, -1] = diff.sum(axis=0)
     return loss, grad.ravel(), p
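
Note (illustrative, not part of the patch): because the sparse path keeps X un-centered, the linear term becomes X_pre w + (X_offset . w), which is where the extra np.outer(diff.T.sum(axis=1), X_offset) term in the gradient above comes from. A quick finite-difference check of the modified multinomial gradient, along the lines of the test added later in this diff:

import numpy as np
from scipy import optimize
from sklearn.datasets import make_classification
from sklearn.preprocessing import label_binarize
from sklearn.linear_model._logistic import _multinomial_loss_grad

X, y = make_classification(n_classes=3, n_informative=6, random_state=0)
Y = label_binarize(y, classes=[0, 1, 2])
sample_weight = np.ones(X.shape[0])
X_offset = X.mean(axis=0)
rng = np.random.RandomState(0)
w = 0.1 * rng.randn(3, X.shape[1] + 1)   # coefficients plus an intercept column

loss, grad, _ = _multinomial_loss_grad(w, X, Y, 1., sample_weight,
                                       X_scale=None, X_offset=X_offset)
approx = optimize.approx_fprime(
    w.ravel(), lambda v: _multinomial_loss_grad(v, X, Y, 1., sample_weight,
                                                X_scale=None,
                                                X_offset=X_offset)[0], 1e-5)
np.testing.assert_allclose(grad, approx, atol=1e-2)  # loose: forward differences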
@@ -454,7 +525,6 @@ def _check_solver(solver, penalty, dual):
         raise ValueError(
             "penalty='none' is not supported for the liblinear solver"
         )
-
     return solver


@@ -482,7 +552,7 @@ def _logistic_regression_path(X, y, pos_class=None, Cs=10, fit_intercept=True,
                               intercept_scaling=1., multi_class='auto',
                               random_state=None, check_input=True,
                               max_squared_sum=None, sample_weight=None,
-                              l1_ratio=None):
+                              l1_ratio=None, precondition='auto'):
     """Compute a Logistic Regression model for a list of regularization
     parameters.
@@ -606,6 +676,12 @@ def _logistic_regression_path(X, y, pos_class=None, Cs=10, fit_intercept=True,
         to using ``penalty='l1'``. For ``0 < l1_ratio <1``, the penalty is a
         combination of L1 and L2.

+    precondition : boolean or 'auto', default='auto'
+        Whether to use preconditioning when solving the optimization problem.
+        A diagonal preconditioner based on the data standard deviation is
+        used. If 'auto', preconditioning is used when ``solver='lbfgs'``,
+        which is the only solver that currently supports it.
+
     Returns
     -------
     coefs : ndarray of shape (n_cs, n_features) or (n_cs, n_features + 1)
@@ -634,6 +710,12 @@ def _logistic_regression_path(X, y, pos_class=None, Cs=10, fit_intercept=True,

     solver = _check_solver(solver, penalty, dual)

+    if precondition == 'auto':
+        precondition = solver == 'lbfgs'
+    if precondition and solver != 'lbfgs':
+        raise ValueError("precondition=True only supported with"
+                         " solver='lbfgs'")
+
     # Preprocessing.
     if check_input:
         X = check_array(X, accept_sparse='csr', dtype=np.float64,
@@ -697,6 +779,36 @@ def _logistic_regression_path(X, y, pos_class=None, Cs=10, fit_intercept=True,
         w0 = np.zeros((classes.size, n_features + int(fit_intercept)),
                       order='F', dtype=X.dtype)

+    # Preconditioning for lbfgs.
+    # Subtract the mean and divide by the standard deviation, but keep the
+    # scaling and mean around so that the original problem can be solved.
+    # The scaling is required in the gradient computation for the penalty.
+    # Both scaling and mean are later used to transform the optimization
+    # results back to the original space.
+    # In the sparse case, the mean cannot be subtracted and the
+    # correction is carried along as X_offset.
+    X_pre = X
+    X_scale = None
+    X_offset = None
+    if precondition:
+        # FIXME this duplicates some code from _preprocess_data
+        # and should be refactored
+        X_mean, X_scale = _weighted_mean_std(X, sample_weight)
+        if sparse.issparse(X):
+            X_scale[X_scale == 0] = 1
+            if fit_intercept:
+                X_offset = -X_mean / X_scale
+            # FIXME old scipy requires conversion to sparse matrix
+            # before calling multiply
+            X_pre = X_pre.multiply(sparse.csr_matrix(1 / X_scale))
+
+        else:
+            if fit_intercept:
+                X_pre = X - X_mean
+            X_scale[X_scale == 0] = 1
+            X_pre = X_pre / X_scale
+
+    # warm starting
     if coef is not None:
         # it must work both giving the bias term and not
         if multi_class == 'ovr':
@@ -705,6 +817,11 @@ def _logistic_regression_path(X, y, pos_class=None, Cs=10, fit_intercept=True,
                     'Initialization coef is of shape %d, expected shape '
                     '%d or %d' % (coef.size, n_features, w0.size))
             w0[:coef.size] = coef
+            if precondition:
+                if fit_intercept:
+                    w0[-1] += np.inner(w0[:n_features], X_mean)
+                w0[:n_features] *= X_scale
+
         else:
             # For binary problems coef.shape[0] should be 1, otherwise it
             # should be classes.size.
@@ -725,6 +842,10 @@ def _logistic_regression_path(X, y, pos_class=None, Cs=10, fit_intercept=True,
                 w0[1, :coef.shape[1]] = coef
             else:
                 w0[:, :coef.shape[1]] = coef
+            if precondition:
+                if fit_intercept:
+                    w0[:, -1] += np.dot(w0[:, :n_features], X_mean)
+                w0[:, :n_features] *= X_scale

     if multi_class == 'multinomial':
         # scipy.optimize.minimize and newton-cg accepts only
@@ -740,6 +861,7 @@ def grad(x, *args): return _multinomial_loss_grad(x, *args)[1]
         hess = _multinomial_grad_hess
         warm_start_sag = {'coef': w0.T}
     else:
+        # binary logistic regression
         target = y_bin
         if solver == 'lbfgs':
             func = _logistic_loss_and_grad
@@ -751,19 +873,27 @@ def grad(x, *args): return _logistic_loss_and_grad(x, *args)[1]

     coefs = list()
     n_iter = np.zeros(len(Cs), dtype=np.int32)
+
+    loss_value = None
     for i, C in enumerate(Cs):
         if solver == 'lbfgs':
             iprint = [-1, 50, 1, 100, 101][
                 np.searchsorted(np.array([0, 1, 2, 3]), verbose)]
             opt_res = optimize.minimize(
                 func, w0, method="L-BFGS-B", jac=True,
-                args=(X, target, 1. / C, sample_weight),
+                args=(X_pre, target, 1. / C, sample_weight, X_scale, X_offset),
                 options={"iprint": iprint, "gtol": tol, "maxiter": max_iter}
             )
             n_iter_i = _check_optimize_result(
                 solver, opt_res, max_iter,
                 extra_warning_msg=_LOGISTIC_SOLVER_CONVERGENCE_MSG)
-            w0, loss = opt_res.x, opt_res.fun
+            w0, loss_value = opt_res.x, opt_res.fun
+            if precondition and multi_class != 'multinomial':
+                # adjust weight scale for rescaling
+                w0[:n_features] = w0[:n_features] / X_scale
+                # adjust intercept for mean subtraction
+                if fit_intercept:
+                    w0[-1] = w0[-1] - np.inner(w0[:-1], X_mean)
         elif solver == 'newton-cg':
             args = (X, target, 1. / C, sample_weight)
             w0, n_iter_i = _newton_cg(hess, func, grad, w0, args=args,
@@ -808,6 +938,15 @@ def grad(x, *args): return _logistic_loss_and_grad(x, *args)[1]
         if multi_class == 'multinomial':
             n_classes = max(2, classes.size)
             multi_w0 = np.reshape(w0, (n_classes, -1))
+            if precondition:
+                if fit_intercept:
+                    multi_w0[:, :-1] = multi_w0[:, :-1] / X_scale
+                    # adjust intercept for preconditioning
+                    multi_w0[:, -1] = (multi_w0[:, -1]
+                                       - np.dot(multi_w0[:, :-1], X_mean))
+                else:
+                    multi_w0 = multi_w0 / X_scale
+
             if n_classes == 2:
                 multi_w0 = multi_w0[1][np.newaxis, :]
             coefs.append(multi_w0.copy())
@@ -816,7 +955,7 @@ def grad(x, *args): return _logistic_loss_and_grad(x, *args)[1]

         n_iter[i] = n_iter_i

-    return np.array(coefs), np.array(Cs), n_iter
+    return np.array(coefs), np.array(Cs), n_iter, loss_value


 # helper function for LogisticCV
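
Note (illustrative, not part of the patch): the post-processing of w0 above undoes the preconditioning, i.e. it maps the solution found on the standardized data back to the original feature space. The round trip, in plain NumPy with no new API assumed:

import numpy as np

rng = np.random.RandomState(0)
X = rng.randn(200, 3) * np.array([1., 100., 0.01]) + np.array([0., 50., -5.])
w_orig, c_orig = np.array([0.5, -0.02, 3.0]), 0.7

X_mean, X_scale = X.mean(axis=0), X.std(axis=0)
X_pre = (X - X_mean) / X_scale

# parameters as the solver sees them in the preconditioned space
w_pre = w_orig * X_scale
c_pre = c_orig + np.inner(w_orig, X_mean)

# the decision function is identical in both parameterizations ...
np.testing.assert_allclose(X @ w_orig + c_orig, X_pre @ w_pre + c_pre,
                           atol=1e-8)

# ... and the back-transform applied after L-BFGS recovers the original one
w_back = w_pre / X_scale
c_back = c_pre - np.inner(w_back, X_mean)
np.testing.assert_allclose(w_back, w_orig)
np.testing.assert_allclose(c_back, c_orig)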
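
Note (usage sketch, not part of the patch): `precondition` only exists on this branch. On badly scaled or badly centered data, the preconditioned lbfgs fit typically needs far fewer iterations, and the new `objective_value_` attribute lets the two fits be compared on the same penalized loss:

import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression

X, y = make_classification(n_samples=200, n_features=20, random_state=0)
X[:, 0] *= 1e4      # badly scaled feature
X[:, 1] += 1e4      # badly centered feature

clf_pre = LogisticRegression(precondition=True).fit(X, y)
clf_raw = LogisticRegression(precondition=False, max_iter=10000).fit(X, y)

print(clf_pre.n_iter_, clf_pre.objective_value_)
print(clf_raw.n_iter_, clf_raw.objective_value_)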
diff --git a/sklearn/linear_model/tests/test_logistic.py b/sklearn/linear_model/tests/test_logistic.py
index e215400b53b80..3ae479f9f188e 100644
--- a/sklearn/linear_model/tests/test_logistic.py
+++ b/sklearn/linear_model/tests/test_logistic.py
@@ -16,7 +16,7 @@
 from sklearn.model_selection import GridSearchCV
 from sklearn.model_selection import train_test_split
 from sklearn.model_selection import cross_val_score
-from sklearn.preprocessing import LabelEncoder, StandardScaler
+from sklearn.preprocessing import LabelEncoder, StandardScaler, label_binarize
 from sklearn.utils import compute_class_weight, _IS_32BIT
 from sklearn.utils._testing import assert_raise_message
 from sklearn.utils._testing import assert_raises
@@ -34,7 +34,7 @@
     _logistic_regression_path, LogisticRegressionCV,
     _logistic_loss_and_grad, _logistic_grad_hess,
     _multinomial_grad_hess, _logistic_loss,
-    _log_reg_scoring_path)
+    _log_reg_scoring_path, _multinomial_loss_grad)

 X = [[-1, 0], [0, 1], [1, 1]]
 X_sp = sp.csr_matrix(X)
@@ -354,7 +354,7 @@ def test_consistency_path():
     # can't test with fit_intercept=True since LIBLINEAR
     # penalizes the intercept
     for solver in ['sag', 'saga']:
-        coefs, Cs, _ = f(_logistic_regression_path)(
+        coefs, Cs, _, _ = f(_logistic_regression_path)(
             X, y, Cs=Cs, fit_intercept=False, tol=1e-5, solver=solver,
             max_iter=1000, multi_class='ovr', random_state=0)
         for i, C in enumerate(Cs):
@@ -369,7 +369,7 @@ def test_consistency_path():
     # test for fit_intercept=True
     for solver in ('lbfgs', 'newton-cg', 'liblinear', 'sag', 'saga'):
         Cs = [1e3]
-        coefs, Cs, _ = f(_logistic_regression_path)(
+        coefs, Cs, _, _ = f(_logistic_regression_path)(
             X, y, Cs=Cs, tol=1e-6, solver=solver,
             intercept_scaling=10000., random_state=0, multi_class='ovr')
         lr = LogisticRegression(C=Cs[0], tol=1e-4,
@@ -424,33 +424,34 @@ def test_liblinear_dual_random_state():


 def test_logistic_loss_and_grad():
-    X_ref, y = make_classification(n_samples=20, random_state=0)
-    n_features = X_ref.shape[1]
-
+    X_ref, y = make_classification(n_samples=21, random_state=0)
     X_sp = X_ref.copy()
     X_sp[X_sp < .1] = 0
     X_sp = sp.csr_matrix(X_sp)
+    clf = LogisticRegression(random_state=0).fit(X_ref, y)
     for X in (X_ref, X_sp):
-        w = np.zeros(n_features)
-
-        # First check that our derivation of the grad is correct
-        loss, grad = _logistic_loss_and_grad(w, X, y, alpha=1.)
-        approx_grad = optimize.approx_fprime(
-            w, lambda w: _logistic_loss_and_grad(w, X, y, alpha=1.)[0], 1e-3
-        )
-        assert_array_almost_equal(grad, approx_grad, decimal=2)
-
-        # Second check that our intercept implementation is good
-        w = np.zeros(n_features + 1)
-        loss_interp, grad_interp = _logistic_loss_and_grad(
-            w, X, y, alpha=1.
-        )
-        assert_array_almost_equal(loss, loss_interp)
-
-        approx_grad = optimize.approx_fprime(
-            w, lambda w: _logistic_loss_and_grad(w, X, y, alpha=1.)[0], 1e-3
-        )
-        assert_array_almost_equal(grad_interp, approx_grad, decimal=2)
+        for X_offset in (None, np.asarray(X.mean(axis=0)).squeeze()):
+            w = clf.coef_.copy().ravel()
+
+            # First check that our derivation of the grad is correct
+            loss, grad = _logistic_loss_and_grad(w, X, y, alpha=1.,
+                                                 X_offset=X_offset)
+            approx_grad = optimize.approx_fprime(
+                w, lambda w: _logistic_loss_and_grad(
+                    w, X, y, alpha=1., X_offset=X_offset)[0], 1e-3
+            )
+            assert_array_almost_equal(grad, approx_grad, decimal=2)
+
+            # Second check that our intercept implementation is good
+            w = np.hstack([clf.coef_.copy().ravel(), clf.intercept_])
+            loss_interp, grad_interp = _logistic_loss_and_grad(
+                w, X, y, alpha=1., X_offset=X_offset
+            )
+            approx_grad = optimize.approx_fprime(
+                w, lambda w: _logistic_loss_and_grad(
+                    w, X, y, alpha=1., X_offset=X_offset)[0], 1e-3
+            )
+            assert_array_almost_equal(grad_interp, approx_grad, decimal=2)


 def test_logistic_grad_hess():
@@ -502,6 +503,33 @@ def test_logistic_grad_hess():
     assert_array_almost_equal(grad_interp, grad_interp_2)


+def test_multinomial_loss_grad():
+    n_features = 10
+    n_classes = 3
+    X_ref, y = make_classification(n_features=n_features, n_classes=n_classes,
+                                   random_state=0, n_informative=6)
+
+    X_sp = X_ref.copy()
+    X_sp[X_sp < .1] = 0
+    X_sp = sp.csr_matrix(X_sp)
+    sample_weight = np.ones(X_ref.shape[0])
+    Y = label_binarize(y, [0, 1, 2])
+    lr = LogisticRegression(random_state=0).fit(X_ref, y)
+    for X in (X_ref, X_sp):
+        for X_offset in (None, X.mean(axis=0)):
+
+            w = np.hstack([lr.coef_, lr.intercept_.reshape(-1, 1)])
+            loss, grad, p = _multinomial_loss_grad(
+                w, X, Y, alpha=1., X_scale=None, sample_weight=sample_weight,
+                X_offset=X_offset)
+            approx_grad = optimize.approx_fprime(
+                w.ravel(), lambda w: _multinomial_loss_grad(
+                    w, X, Y, alpha=1., X_scale=None, X_offset=X_offset,
+                    sample_weight=sample_weight)[0], 1e-5
+            )
+            assert_array_almost_equal(grad, approx_grad, decimal=3)
+
+
 def test_logistic_cv():
     # test for LogisticRegressionCV object
     n_samples, n_features = 50, 5
@@ -952,9 +980,10 @@ def test_logistic_regression_multinomial():
         assert clf_w.coef_.shape == (n_classes, n_features)

         # Compare solutions between lbfgs and the other solvers
-        assert_allclose(ref_i.coef_, clf_i.coef_, rtol=1e-2)
-        assert_allclose(ref_w.coef_, clf_w.coef_, rtol=1e-2)
-        assert_allclose(ref_i.intercept_, clf_i.intercept_, rtol=1e-2)
+        assert_allclose(ref_i.coef_, clf_i.coef_, rtol=1e-1, atol=1e-4)
+        assert_allclose(ref_w.coef_, clf_w.coef_, rtol=1e-1, atol=1e-4)
+        assert_allclose(ref_i.intercept_, clf_i.intercept_, rtol=1e-1,
+                        atol=1e-4)

     # Test that the path give almost the same results. However since in this
     # case we take the average of the coefs after fitting across all the
@@ -1674,7 +1703,7 @@ def test_logistic_regression_path_coefs_multinomial():
                                n_redundant=0, n_clusters_per_class=1,
                                random_state=0, n_features=2)
     Cs = [.00001, 1, 10000]
-    coefs, _, _ = _logistic_regression_path(X, y, penalty='l1', Cs=Cs,
+    coefs, _, _, _ = _logistic_regression_path(X, y, penalty='l1', Cs=Cs,
                                             solver='saga', random_state=0,
                                             multi_class='multinomial')
@@ -1827,6 +1856,45 @@ def test_scores_attribute_layout_elasticnet():
             assert avg_scores_lrcv[i, j] == pytest.approx(avg_score_lr)


+def test_illconditioned_lbfgs():
+    # check that lbfgs converges even with ill-conditioned X
+    X, y = make_classification(n_samples=100, n_features=60, random_state=0)
+    X[:, 1] += 10000
+    X[:, 0] *= 10000
+    lr_pre = LogisticRegression(random_state=0, precondition=True)
+    with pytest.warns(None) as record:
+        lr_pre.fit(X, y)
+    assert len(record) == 0
+    loss_pre = _logistic_loss(
+        np.hstack([lr_pre.coef_.ravel(), lr_pre.intercept_]),
+        X, 2 * y - 1, 1)
+
+    lr = LogisticRegression(random_state=0, precondition=False)
+    with pytest.warns(ConvergenceWarning):
+        lr.fit(X, y)
+    loss = _logistic_loss(np.hstack([lr.coef_.ravel(), lr.intercept_]),
+                          X, 2 * y - 1, 1)
+    assert loss_pre < loss
+
+
+def test_logistic_loss_preconditioning():
+    # check that _logistic_loss is invariant wrt whether we precondition.
+    X, y = make_classification(n_samples=100, n_features=60, random_state=0)
+    X[:, 1] += 10000
+    lr = LogisticRegression(random_state=0, precondition=True, max_iter=1000)
+    lr.fit(X, y)
+    loss = _logistic_loss(np.hstack([lr.coef_.ravel(), lr.intercept_]),
+                          X, 2 * y - 1, 1)
+    assert_almost_equal(loss, lr.objective_value_)
+    # do full preconditioning
+    X_mean = X.mean(axis=0)
+    X_std = X.std(axis=0)
+    X_pre = (X - X_mean) / X_std
+    w_scaled = lr.coef_.ravel() * X_std
+    w_pre = np.hstack([w_scaled, lr.intercept_ + np.inner(lr.coef_, X_mean)])
+    loss_pre = _logistic_loss(w_pre, X_pre, 2 * y - 1, 1, X_scale=X_std)
+    assert_almost_equal(loss, loss_pre)
+
+
 @pytest.mark.parametrize("fit_intercept", [False, True])
 def test_multinomial_identifiability_on_iris(fit_intercept):
     """Test that the multinomial classification is identifiable.
diff --git a/sklearn/utils/extmath.py b/sklearn/utils/extmath.py
index ba8ce9e2879b4..f9341735e2447 100644
--- a/sklearn/utils/extmath.py
+++ b/sklearn/utils/extmath.py
@@ -912,3 +912,31 @@ def stable_cumsum(arr, axis=None, rtol=1e-05, atol=1e-08):
                       'its last element does not correspond to sum',
                       RuntimeWarning)
     return out
+
+
+def _weighted_mean_std(X, sample_weight):
+    """Compute weighted mean and standard deviation.
+
+    Parameters
+    ----------
+    X : array-like or sparse matrix, shape (n_samples, n_features)
+        Input array.
+    sample_weight : ndarray, shape (n_samples,)
+        Weights.
+
+    Returns
+    -------
+    mean : ndarray, shape (n_features,)
+        Weighted mean.
+    std : ndarray, shape (n_features,)
+        Weighted standard deviation.
+    """
+    if sparse.issparse(X):
+        normed_weights = sample_weight / sample_weight.sum()
+        sq_sum = safe_sparse_dot(normed_weights, X.multiply(X))
+        mean = safe_sparse_dot(normed_weights, X)
+        var = sq_sum - mean ** 2
+    else:
+        mean = np.average(X, weights=sample_weight, axis=0)
+        var = np.average(X**2, weights=sample_weight, axis=0) - mean ** 2
+    return mean, np.sqrt(var)
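
Note (illustrative, not part of the patch): `_weighted_mean_std` is private; a quick way to cross-check it against a direct weighted computation, including the sparse branch that uses E[x^2] - E[x]^2 instead of centering:

import numpy as np
from scipy import sparse
from sklearn.utils.extmath import _weighted_mean_std

rng = np.random.RandomState(0)
X = rng.rand(50, 4)
w = rng.rand(50)

mean, std = _weighted_mean_std(X, w)
mean_ref = np.average(X, axis=0, weights=w)
var_ref = np.average((X - mean_ref) ** 2, axis=0, weights=w)
np.testing.assert_allclose(mean, mean_ref)
np.testing.assert_allclose(std, np.sqrt(var_ref))

mean_sp, std_sp = _weighted_mean_std(sparse.csr_matrix(X), w)
np.testing.assert_allclose(np.ravel(np.asarray(mean_sp)), mean)
np.testing.assert_allclose(np.ravel(np.asarray(std_sp)), std)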
+ """ + if sparse.issparse(X): + normed_weights = sample_weight / sample_weight.sum() + sq_sum = safe_sparse_dot(normed_weights, X.multiply(X)) + mean = safe_sparse_dot(normed_weights, X) + var = sq_sum - mean ** 2 + else: + mean = np.average(X, weights=sample_weight, axis=0) + var = np.average(X**2, weights=sample_weight, axis=0) - mean ** 2 + return mean, np.sqrt(var) diff --git a/sklearn/utils/sparsefuncs.py b/sklearn/utils/sparsefuncs.py index 7fb9163bf517b..299231482ed87 100644 --- a/sklearn/utils/sparsefuncs.py +++ b/sklearn/utils/sparsefuncs.py @@ -64,7 +64,7 @@ def inplace_csr_row_scale(X, scale): def mean_variance_axis(X, axis): - """Compute mean and variance along an axix on a CSR or CSC matrix + """Compute mean and variance along an axis on a CSR or CSC matrix Parameters ---------- diff --git a/sklearn/utils/tests/test_extmath.py b/sklearn/utils/tests/test_extmath.py index cd0b1f3fd7f70..a9d2761b9a8f0 100644 --- a/sklearn/utils/tests/test_extmath.py +++ b/sklearn/utils/tests/test_extmath.py @@ -35,6 +35,7 @@ from sklearn.utils.extmath import softmax from sklearn.utils.extmath import stable_cumsum from sklearn.utils.extmath import safe_sparse_dot +from sklearn.utils.extmath import _weighted_mean_std from sklearn.datasets import make_low_rank_matrix @@ -809,3 +810,25 @@ def test_safe_sparse_dot_dense_output(dense_output): if dense_output: expected = expected.toarray() assert_allclose_dense_sparse(actual, expected) + + +def test_weighted_mean_std(): + rng = np.random.RandomState(0) + X = rng.normal(size=(100, 10)) + weights = rng.uniform(size=(100,)) + mean_dense, std_dense = _weighted_mean_std(X, weights) + mean_sparse, std_sparse = _weighted_mean_std( + sparse.csr_matrix(X), weights) + assert_allclose_dense_sparse(mean_dense, mean_sparse) + assert_allclose_dense_sparse(std_dense, std_sparse) + # with ones + weights = np.ones(100) + mean_dense, std_dense = _weighted_mean_std(X, weights) + mean_sparse, std_sparse = _weighted_mean_std( + sparse.csr_matrix(X), weights) + mean_expected = X.mean(axis=0) + std_expected = X.std(axis=0) + assert_allclose_dense_sparse(mean_dense, mean_expected) + assert_allclose_dense_sparse(std_dense, std_expected) + assert_allclose_dense_sparse(mean_sparse, mean_expected) + assert_allclose_dense_sparse(std_sparse, std_expected)