ENH fit intercept in ``ProxNewton`` solver (#77) · scikit-learn-contrib/skglm@db46d69 · GitHub

Commit db46d69

Badr-MOUFAD, PABannier, and mathurinm authored
ENH fit intercept in ProxNewton solver (#77)
Co-authored-by: PAB <pierreantoine.bannier@gmail.com>
Co-authored-by: mathurinm <mathurin.massias@gmail.com>
1 parent 8b08c09 commit db46d69

File tree

2 files changed: +118 / -59 lines
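In short: this commit lets ProxNewton fit an unpenalized intercept. With fit_intercept=True the solution vector gains one trailing entry holding the intercept, and the stopping criterion additionally checks that the gradient with respect to the intercept vanishes. A minimal usage sketch following the updated test (import paths match the repository layout at the time of the commit; dataset sizes and the regularization level are illustrative, not part of this change):

    import numpy as np

    from skglm.datafits import Logistic
    from skglm.penalties import L1
    from skglm.solvers.prox_newton import ProxNewton
    from skglm.utils import make_correlated_data, compiled_clone

    # toy L1-regularized logistic regression (sizes and alpha are illustrative)
    X, y, _ = make_correlated_data(50, 80, random_state=0)
    y = np.sign(y)
    alpha_max = np.linalg.norm(X.T @ y, ord=np.inf) / (2 * len(y))
    alpha = 0.01 * alpha_max

    datafit = compiled_clone(Logistic())
    penalty = compiled_clone(L1(alpha))

    # with fit_intercept=True the returned vector has n_features + 1 entries:
    # coefficients first, unpenalized intercept last
    w = ProxNewton(fit_intercept=True, tol=1e-8).solve(X, y, datafit, penalty)[0]
    coefs, intercept = w[:-1], w[-1]

Passing fit_intercept=False (or a w_init of size n_features) recovers the previous behaviour; a w_init of the wrong size now raises a ValueError, as the diff below shows.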


skglm/solvers/prox_newton.py

Lines changed: 107 additions & 37 deletions
@@ -24,6 +24,9 @@ class ProxNewton(BaseSolver):
     tol : float, default 1e-4
         Tolerance for convergence.
 
+    fit_intercept : bool, default True
+        If ``True``, fits an unpenalized intercept.
+
     verbose : bool, default False
         Amount of verbosity. 0/False is silent.
 
@@ -53,7 +56,8 @@ def __init__(self, p0=10, max_iter=20, max_pn_iter=1000, tol=1e-4,
 
     def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None):
         n_samples, n_features = X.shape
-        w = np.zeros(n_features) if w_init is None else w_init
+        fit_intercept = self.fit_intercept
+        w = np.zeros(n_features + fit_intercept) if w_init is None else w_init
         Xw = np.zeros(n_samples) if Xw_init is None else Xw_init
         all_features = np.arange(n_features)
         stop_crit = 0.
@@ -63,20 +67,38 @@ def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None):
         if is_sparse:
             X_bundles = (X.data, X.indptr, X.indices)
 
+        if len(w) != n_features + self.fit_intercept:
+            if self.fit_intercept:
+                val_error_message = (
+                    "w should be of size n_features + 1 when using fit_intercept=True: "
+                    f"expected {n_features + 1}, got {len(w)}.")
+            else:
+                val_error_message = (
+                    "w should be of size n_features: "
+                    f"expected {n_features}, got {len(w)}.")
+            raise ValueError(val_error_message)
+
         for t in range(self.max_iter):
             # compute scores
             if is_sparse:
                 grad = _construct_grad_sparse(
-                    *X_bundles, y, w, Xw, datafit, all_features)
+                    *X_bundles, y, w[:n_features], Xw, datafit, all_features)
             else:
-                grad = _construct_grad(X, y, w, Xw, datafit, all_features)
+                grad = _construct_grad(X, y, w[:n_features], Xw, datafit, all_features)
 
-            opt = penalty.subdiff_distance(w, grad, all_features)
+            opt = penalty.subdiff_distance(w[:n_features], grad, all_features)
+
+            # optimality of intercept
+            if fit_intercept:
+                # gradient w.r.t. intercept (constant features of ones)
+                intercept_opt = np.abs(np.sum(datafit.raw_grad(y, Xw)))
+            else:
+                intercept_opt = 0.
 
             # check convergences
-            stop_crit = np.max(opt)
+            stop_crit = max(np.max(opt), intercept_opt)
             if self.verbose:
-                p_obj = datafit.value(y, w, Xw) + penalty.value(w)
+                p_obj = datafit.value(y, w, Xw) + penalty.value(w[:n_features])
                 print(
                     "Iteration {}: {:.10f}, ".format(t+1, p_obj) +
                     "stopping crit: {:.2e}".format(stop_crit)
@@ -101,20 +123,22 @@ def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None):
                 # find descent direction
                 if is_sparse:
                     delta_w_ws, X_delta_w_ws = _descent_direction_s(
-                        *X_bundles, y, w, Xw, grad_ws, datafit,
+                        *X_bundles, y, w, Xw, fit_intercept, grad_ws, datafit,
                         penalty, ws, tol=EPS_TOL*tol_in)
                 else:
                     delta_w_ws, X_delta_w_ws = _descent_direction(
-                        X, y, w, Xw, grad_ws, datafit, penalty, ws, tol=EPS_TOL*tol_in)
+                        X, y, w, Xw, fit_intercept, grad_ws, datafit,
+                        penalty, ws, tol=EPS_TOL*tol_in)
 
                 # backtracking line search with inplace update of w, Xw
                 if is_sparse:
                     grad_ws[:] = _backtrack_line_search_s(
-                        *X_bundles, y, w, Xw, datafit, penalty, delta_w_ws,
-                        X_delta_w_ws, ws)
+                        *X_bundles, y, w, Xw, fit_intercept, datafit, penalty,
+                        delta_w_ws, X_delta_w_ws, ws)
                 else:
                     grad_ws[:] = _backtrack_line_search(
-                        X, y, w, Xw, datafit, penalty, delta_w_ws, X_delta_w_ws, ws)
+                        X, y, w, Xw, fit_intercept, datafit, penalty,
+                        delta_w_ws, X_delta_w_ws, ws)
 
                 # check convergence
                 opt_in = penalty.subdiff_distance(w, grad_ws, ws)
@@ -138,7 +162,7 @@ def solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None):
 
 
 @njit
-def _descent_direction(X, y, w_epoch, Xw_epoch, grad_ws, datafit,
+def _descent_direction(X, y, w_epoch, Xw_epoch, fit_intercept, grad_ws, datafit,
                        penalty, ws, tol):
     # Given:
     # 1) b = \nabla F(X w_epoch)
@@ -152,11 +176,16 @@ def _descent_direction(X, y, w_epoch, Xw_epoch, grad_ws, datafit,
     for idx, j in enumerate(ws):
         lipschitz[idx] = raw_hess @ X[:, j] ** 2
 
-    # for a less costly stopping criterion, we do no compute the exact gradient,
-    # but store each coordinate-wise gradient every time we upate one coordinate:
+    # for a less costly stopping criterion, we do not compute the exact gradient,
+    # but store each coordinate-wise gradient every time we update one coordinate
     past_grads = np.zeros(len(ws))
     X_delta_w_ws = np.zeros(X.shape[0])
-    w_ws = w_epoch[ws]
+    ws_intercept = np.append(ws, -1) if fit_intercept else ws
+    w_ws = w_epoch[ws_intercept]
+
+    if fit_intercept:
+        lipschitz_intercept = np.sum(raw_hess)
+        grad_intercept = np.sum(datafit.raw_grad(y, Xw_epoch))
 
     for cd_iter in range(MAX_CD_ITER):
         for idx, j in enumerate(ws):
@@ -174,22 +203,35 @@ def _descent_direction(X, y, w_epoch, Xw_epoch, grad_ws, datafit,
             if w_ws[idx] != old_w_idx:
                 X_delta_w_ws += (w_ws[idx] - old_w_idx) * X[:, j]
 
+        if fit_intercept:
+            past_grads_intercept = grad_intercept + raw_hess @ X_delta_w_ws
+            old_intercept = w_ws[-1]
+            w_ws[-1] -= past_grads_intercept / lipschitz_intercept
+
+            if w_ws[-1] != old_intercept:
+                X_delta_w_ws += w_ws[-1] - old_intercept
+
         if cd_iter % 5 == 0:
             # TODO: can be improved by passing in w_ws but breaks for WeightedL1
             current_w = w_epoch.copy()
-            current_w[ws] = w_ws
+            current_w[ws_intercept] = w_ws
             opt = penalty.subdiff_distance(current_w, past_grads, ws)
-            if np.max(opt) <= tol:
+            stop_crit = np.max(opt)
+
+            if fit_intercept:
+                stop_crit = max(stop_crit, np.abs(past_grads_intercept))
+
+            if stop_crit <= tol:
                 break
 
     # descent direction
-    return w_ws - w_epoch[ws], X_delta_w_ws
+    return w_ws - w_epoch[ws_intercept], X_delta_w_ws
 
 
-# sparse version of _compute_descent_direction
+# sparse version of _descent_direction
 @njit
 def _descent_direction_s(X_data, X_indptr, X_indices, y, w_epoch,
-                         Xw_epoch, grad_ws, datafit, penalty, ws, tol):
+                         Xw_epoch, fit_intercept, grad_ws, datafit, penalty, ws, tol):
     raw_hess = datafit.raw_hessian(y, Xw_epoch)
 
     lipschitz = np.zeros(len(ws))
@@ -201,7 +243,12 @@ def _descent_direction_s(X_data, X_indptr, X_indices, y, w_epoch,
     # see _descent_direction() comment
     past_grads = np.zeros(len(ws))
     X_delta_w_ws = np.zeros(len(y))
-    w_ws = w_epoch[ws]
+    ws_intercept = np.append(ws, -1) if fit_intercept else ws
+    w_ws = w_epoch[ws_intercept]
+
+    if fit_intercept:
+        lipschitz_intercept = np.sum(raw_hess)
+        grad_intercept = np.sum(datafit.raw_grad(y, Xw_epoch))
 
     for cd_iter in range(MAX_CD_ITER):
         for idx, j in enumerate(ws):
@@ -224,39 +271,57 @@ def _descent_direction_s(X_data, X_indptr, X_indices, y, w_epoch,
                 _update_X_delta_w(X_data, X_indptr, X_indices, X_delta_w_ws,
                                   w_ws[idx] - old_w_idx, j)
 
+        if fit_intercept:
+            past_grads_intercept = grad_intercept + raw_hess @ X_delta_w_ws
+            old_intercept = w_ws[-1]
+            w_ws[-1] -= past_grads_intercept / lipschitz_intercept
+
+            if w_ws[-1] != old_intercept:
+                X_delta_w_ws += w_ws[-1] - old_intercept
+
         if cd_iter % 5 == 0:
             # TODO: could be improved by passing in w_ws
             current_w = w_epoch.copy()
-            current_w[ws] = w_ws
+            current_w[ws_intercept] = w_ws
            opt = penalty.subdiff_distance(current_w, past_grads, ws)
-            if np.max(opt) <= tol:
+            stop_crit = np.max(opt)
+
+            if fit_intercept:
+                stop_crit = max(stop_crit, np.abs(past_grads_intercept))
+
+            if stop_crit <= tol:
                 break
 
     # descent direction
-    return w_ws - w_epoch[ws], X_delta_w_ws
+    return w_ws - w_epoch[ws_intercept], X_delta_w_ws
 
 
 @njit
-def _backtrack_line_search(X, y, w, Xw, datafit, penalty, delta_w_ws,
+def _backtrack_line_search(X, y, w, Xw, fit_intercept, datafit, penalty, delta_w_ws,
                            X_delta_w_ws, ws):
     # 1) find step in [0, 1] such that:
     #   penalty(w + step * delta_w) - penalty(w) +
     #   step * \nabla datafit(w + step * delta_w) @ delta_w < 0
     # ref: https://www.di.ens.fr/~aspremon/PDF/ENSAE/Newton.pdf
     # 2) inplace update of w and Xw and return grad_ws of the last w and Xw
     step, prev_step = 1., 0.
+    n_features = X.shape[1]
+    ws_intercept = np.append(ws, -1) if fit_intercept else ws
     # TODO: could be improved by passing in w[ws]
-    old_penalty_val = penalty.value(w)
+    old_penalty_val = penalty.value(w[:n_features])
 
     # try step = 1, 1/2, 1/4, ...
     for _ in range(MAX_BACKTRACK_ITER):
-        w[ws] += (step - prev_step) * delta_w_ws
+        w[ws_intercept] += (step - prev_step) * delta_w_ws
         Xw += (step - prev_step) * X_delta_w_ws
 
-        grad_ws = _construct_grad(X, y, w, Xw, datafit, ws)
+        grad_ws = _construct_grad(X, y, w[:n_features], Xw, datafit, ws)
         # TODO: could be improved by passing in w[ws]
-        stop_crit = penalty.value(w) - old_penalty_val
-        stop_crit += step * grad_ws @ delta_w_ws
+        stop_crit = penalty.value(w[:n_features]) - old_penalty_val
+        stop_crit += step * grad_ws @ delta_w_ws[:len(ws)]
+
+        if fit_intercept:
+            stop_crit += step * delta_w_ws[-1] * np.sum(datafit.raw_grad(y, Xw))
 
         if stop_crit < 0:
             break
@@ -272,21 +337,26 @@ def _backtrack_line_search(X, y, w, Xw, datafit, penalty, delta_w_ws,
 
 # sparse version of _backtrack_line_search
 @njit
-def _backtrack_line_search_s(X_data, X_indptr, X_indices, y, w, Xw, datafit,
-                             penalty, delta_w_ws, X_delta_w_ws, ws):
+def _backtrack_line_search_s(X_data, X_indptr, X_indices, y, w, Xw, fit_intercept,
+                             datafit, penalty, delta_w_ws, X_delta_w_ws, ws):
     step, prev_step = 1., 0.
+    n_features = len(X_indptr) - 1
+    ws_intercept = np.append(ws, -1) if fit_intercept else ws
     # TODO: could be improved by passing in w[ws]
-    old_penalty_val = penalty.value(w)
+    old_penalty_val = penalty.value(w[:n_features])
 
     for _ in range(MAX_BACKTRACK_ITER):
-        w[ws] += (step - prev_step) * delta_w_ws
+        w[ws_intercept] += (step - prev_step) * delta_w_ws
         Xw += (step - prev_step) * X_delta_w_ws
 
         grad_ws = _construct_grad_sparse(X_data, X_indptr, X_indices,
-                                         y, w, Xw, datafit, ws)
+                                         y, w[:n_features], Xw, datafit, ws)
         # TODO: could be improved by passing in w[ws]
-        stop_crit = penalty.value(w) - old_penalty_val
-        stop_crit += step * grad_ws.T @ delta_w_ws
+        stop_crit = penalty.value(w[:n_features]) - old_penalty_val
+        stop_crit += step * grad_ws.T @ delta_w_ws[:len(ws)]
+
+        if fit_intercept:
+            stop_crit += step * delta_w_ws[-1] * np.sum(datafit.raw_grad(y, Xw))
 
         if stop_crit < 0:
             break
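Implementation note: the intercept is treated as an implicit constant feature of ones, so its gradient is the sum of the datafit's per-sample raw gradients, its curvature (Lipschitz constant) is the sum of the raw Hessian terms, and it receives a plain unpenalized coordinate update inside _descent_direction. A self-contained sketch of that update for a mean logistic loss (the two helper functions below are illustrative stand-ins, not skglm's Logistic datafit):

    import numpy as np


    def raw_grad(y, Xw):
        # per-sample derivative of the mean logistic loss w.r.t. Xw
        return -y / (1 + np.exp(y * Xw)) / len(y)


    def raw_hessian(y, Xw):
        # per-sample second derivative of the mean logistic loss w.r.t. Xw
        p = 1 / (1 + np.exp(-y * Xw))
        return p * (1 - p) / len(y)


    rng = np.random.default_rng(0)
    X = rng.standard_normal((30, 5))
    y = np.sign(rng.standard_normal(30))
    w = rng.standard_normal(5)
    intercept = 0.3
    Xw = X @ w + intercept

    # seen as a constant feature of ones, the intercept's gradient and
    # curvature are plain sums over samples
    grad_intercept = np.sum(raw_grad(y, Xw))
    lipschitz_intercept = np.sum(raw_hessian(y, Xw))

    # unpenalized coordinate step, mirroring the fit_intercept block in
    # _descent_direction; the constant shift propagates to all predictions
    old_intercept = intercept
    intercept -= grad_intercept / lipschitz_intercept
    Xw += intercept - old_intercept

    # with no penalty on the intercept, optimality means a vanishing gradient,
    # which is what intercept_opt measures in solve()
    print(np.abs(np.sum(raw_grad(y, Xw))))

Because no penalty applies to the intercept, its optimality measure in solve() is simply the absolute value of this summed gradient, folded into stop_crit alongside the subdifferential distances of the penalized coordinates.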

skglm/tests/test_prox_newton.py

Lines changed: 11 additions & 22 deletions
@@ -9,25 +9,10 @@
 from skglm.utils import make_correlated_data, compiled_clone
 
 
-@pytest.mark.parametrize('X_density', [1, 0.5])
-def test_alpha_max(X_density):
-    n_samples, n_features = 10, 20
-    X, y, _ = make_correlated_data(
-        n_samples, n_features, X_density=X_density, random_state=2)
-    y = np.sign(y)
-
-    alpha_max = np.linalg.norm(X.T @ y, ord=np.inf) / (2 * n_samples)
-
-    log_datafit = compiled_clone(Logistic())
-    l1_penalty = compiled_clone(L1(alpha_max))
-    w = ProxNewton().solve(X, y, log_datafit, l1_penalty)[0]
-
-    np.testing.assert_equal(w, 0)
-
-
-@pytest.mark.parametrize("rho, X_density", product([1e-1, 1e-2], [1., 0.5]))
-def test_pn_vs_sklearn(rho, X_density):
-    n_samples, n_features = 11, 19
+@pytest.mark.parametrize("X_density, fit_intercept", product([1., 0.5], [True, False]))
+def test_pn_vs_sklearn(X_density, fit_intercept):
+    n_samples, n_features = 12, 25
+    rho = 1e-1
 
     X, y, _ = make_correlated_data(n_samples, n_features, random_state=0,
                                    X_density=X_density)
@@ -37,14 +22,18 @@ def test_pn_vs_sklearn(rho, X_density):
     alpha = rho * alpha_max
 
     sk_log_reg = LogisticRegression(penalty='l1', C=1/(n_samples * alpha),
-                                    fit_intercept=False, tol=1e-9, solver='liblinear')
+                                    fit_intercept=fit_intercept, random_state=0,
+                                    tol=1e-12, solver='saga', max_iter=1_000_000)
     sk_log_reg.fit(X, y)
 
     log_datafit = compiled_clone(Logistic())
     l1_penalty = compiled_clone(L1(alpha))
-    w = ProxNewton(tol=1e-9).solve(X, y, log_datafit, l1_penalty)[0]
+    prox_solver = ProxNewton(fit_intercept=fit_intercept, tol=1e-12)
+    w = prox_solver.solve(X, y, log_datafit, l1_penalty)[0]
 
-    np.testing.assert_allclose(w, sk_log_reg.coef_.flatten(), rtol=1e-6, atol=1e-6)
+    np.testing.assert_allclose(w[:n_features], sk_log_reg.coef_.flatten())
+    if fit_intercept:
+        np.testing.assert_allclose(w[-1], sk_log_reg.intercept_)
 
 
 if __name__ == '__main__':
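A note on the updated test: it checks ProxNewton against scikit-learn's LogisticRegression with C = 1 / (n_samples * alpha). Assuming the Logistic datafit is the mean logistic loss (as the alpha_max formula in the removed test suggests), this choice makes the two objectives proportional, so coefficients and intercept should coincide at the optimum. A small sketch of that proportionality (both objective functions below are illustrative re-implementations, not library code):

    import numpy as np


    def sklearn_objective(w, intercept, X, y, C):
        # LogisticRegression(penalty='l1') minimizes, up to a constant factor,
        # the summed log-loss plus ||w||_1 / C, with an unpenalized intercept
        z = y * (X @ w + intercept)
        return np.sum(np.log1p(np.exp(-z))) + np.sum(np.abs(w)) / C


    def skglm_objective(w, intercept, X, y, alpha):
        # mean log-loss (Logistic datafit) plus alpha * ||w||_1 (L1 penalty),
        # intercept left unpenalized
        z = y * (X @ w + intercept)
        return np.mean(np.log1p(np.exp(-z))) + alpha * np.sum(np.abs(w))


    rng = np.random.default_rng(0)
    X = rng.standard_normal((12, 25))
    y = np.sign(rng.standard_normal(12))
    w, intercept = rng.standard_normal(25), 0.1

    n_samples, alpha = X.shape[0], 0.05
    C = 1 / (n_samples * alpha)

    # the two objectives differ only by the constant factor n_samples,
    # so they share the same minimizer (coefficients and intercept)
    ratio = sklearn_objective(w, intercept, X, y, C) / skglm_objective(w, intercept, X, y, alpha)
    print(np.isclose(ratio, n_samples))  # True

Presumably the switch from liblinear to saga is what allows the intercept to be compared as well: saga leaves the intercept unpenalized (liblinear regularizes it through intercept_scaling), matching the solver's convention.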
