WIP to be continued · scikit-learn/scikit-learn@fc0f0e7 · GitHub

Commit fc0f0e7

WIP to be continued
1 parent ee5d94e commit fc0f0e7

4 files changed: +115 -62 lines changed


sklearn/linear_model/_glm/glm.py

Lines changed: 16 additions & 13 deletions
@@ -207,10 +207,10 @@ def fit(self, X, y, sample_weight=None):
         loss_dtype = min(max(y.dtype, X.dtype), np.float64)
         y = check_array(y, dtype=loss_dtype, order="C", ensure_2d=False)
 
-        # TODO: We could support samples_weight=None as the losses support it.
-        # Note that _check_sample_weight calls check_array(order="C") required by
-        # losses.
-        sample_weight = _check_sample_weight(sample_weight, X, dtype=loss_dtype)
+        if sample_weight is not None:
+            # Note that _check_sample_weight calls check_array(order="C") required by
+            # losses.
+            sample_weight = _check_sample_weight(sample_weight, X, dtype=loss_dtype)
 
         n_samples, n_features = X.shape
         self._base_loss = self._get_loss()
@@ -228,17 +228,20 @@ def fit(self, X, y, sample_weight=None):
 
         # TODO: if alpha=0 check that X is not rank deficient
 
-        # IMPORTANT NOTE: Rescaling of sample_weight:
+        # NOTE: Rescaling of sample_weight:
         # We want to minimize
-        #     obj = 1/(2*sum(sample_weight)) * sum(sample_weight * deviance)
+        #     obj = 1/(2 * sum(sample_weight)) * sum(sample_weight * deviance)
         #           + 1/2 * alpha * L2,
         # with
         #     deviance = 2 * loss.
         # The objective is invariant to multiplying sample_weight by a constant. We
-        # choose this constant such that sum(sample_weight) = 1. Thus, we end up with
+        # could choose this constant such that sum(sample_weight) = 1 in order to end
+        # up with
         #     obj = sum(sample_weight * loss) + 1/2 * alpha * L2.
-        # Note that LinearModelLoss.loss() computes sum(sample_weight * loss).
-        sample_weight = sample_weight / sample_weight.sum()
+        # But LinearModelLoss.loss() already computes
+        #     average(loss, weights=sample_weight)
+        # Thus, without rescaling, we have
+        #     obj = LinearModelLoss.loss(...)
 
         if self.warm_start and hasattr(self, "coef_"):
             if self.fit_intercept:
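Aside (not part of the commit): the explicit rescaling can be dropped because a weighted average already carries the 1 / sum(sample_weight) factor, and the objective is invariant under rescaling the weights. A minimal numpy sketch of that equivalence, with made-up values standing in for the pointwise losses and for the fixed L2 term at a given coef:

    import numpy as np

    rng = np.random.default_rng(0)
    per_sample_loss = rng.random(10)           # stand-in for pointwise losses
    sample_weight = rng.uniform(0.5, 2.0, 10)
    penalty = 0.3                              # stand-in for 1/2 * alpha * ||coef||^2

    # Old convention: normalize the weights to sum to 1, then take a weighted sum.
    sw_rescaled = sample_weight / sample_weight.sum()
    obj_old = np.sum(sw_rescaled * per_sample_loss) + penalty

    # New convention: weighted average of the raw losses, no rescaling step.
    obj_new = np.average(per_sample_loss, weights=sample_weight) + penalty

    assert np.isclose(obj_old, obj_new)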
@@ -415,10 +418,10 @@ def score(self, X, y, sample_weight=None):
                 f" {base_loss.__name__}."
             )
 
-        # Note that constant_to_optimal_zero is already multiplied by sample_weight.
-        constant = np.mean(base_loss.constant_to_optimal_zero(y_true=y))
-        if sample_weight is not None:
-            constant *= sample_weight.shape[0] / np.sum(sample_weight)
+        constant = np.average(
+            base_loss.constant_to_optimal_zero(y_true=y, sample_weight=None),
+            weights=sample_weight,
+        )
 
         # Missing factor of 2 in deviance cancels out.
         deviance = base_loss(
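Aside (not from the diff, just the numpy behavior the new score() code relies on): np.average with weights=None reduces to a plain mean, so the old two-step "mean, then rescale when weighted" pattern collapses into a single call.

    import numpy as np

    values = np.array([1.0, 2.0, 4.0])
    sample_weight = np.array([1.0, 1.0, 2.0])

    assert np.isclose(np.average(values, weights=None), np.mean(values))
    assert np.isclose(
        np.average(values, weights=sample_weight),
        np.sum(sample_weight * values) / np.sum(sample_weight),
    )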

sklearn/linear_model/_linear_loss.py

Lines changed: 28 additions & 15 deletions
@@ -12,18 +12,19 @@ class LinearModelLoss:
 
     Note that raw_prediction is also known as linear predictor.
 
-    The loss is the sum of per sample losses and includes a term for L2
+    The loss is the average of per sample losses and includes a term for L2
     regularization::
 
-        loss = sum_i s_i loss(y_i, X_i @ coef + intercept)
+        loss = 1 / s_sum * sum_i s_i loss(y_i, X_i @ coef + intercept)
                + 1/2 * l2_reg_strength * ||coef||_2^2
 
-    with sample weights s_i=1 if sample_weight=None.
+    with sample weights s_i=1 if sample_weight=None and s_sum=sum_i s_i.
 
     Gradient and hessian, for simplicity without intercept, are::
 
-        gradient = X.T @ loss.gradient + l2_reg_strength * coef
-        hessian = X.T @ diag(loss.hessian) @ X + l2_reg_strength * identity
+        gradient = 1 / s_sum * X.T @ loss.gradient + l2_reg_strength * coef
+        hessian = 1 / s_sum * X.T @ diag(loss.hessian) @ X
+                  + l2_reg_strength * identity
 
     Conventions:
         if fit_intercept:
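Aside (illustrative only, not the sklearn implementation): the documented convention for the binary, no-intercept case translates directly into numpy. The per_sample_* arrays below are made-up stand-ins for what base_loss would return at a given raw_prediction.

    import numpy as np

    rng = np.random.default_rng(0)
    n_samples, n_features = 8, 3
    X = rng.standard_normal((n_samples, n_features))
    coef = rng.standard_normal(n_features)
    sample_weight = rng.uniform(0.5, 2.0, n_samples)
    l2_reg_strength = 0.1
    s_sum = sample_weight.sum()

    per_sample_loss = rng.random(n_samples)
    per_sample_gradient = rng.standard_normal(n_samples)
    per_sample_hessian = rng.random(n_samples)

    # loss = 1 / s_sum * sum_i s_i loss_i + 1/2 * l2_reg_strength * ||coef||^2
    loss = (
        np.sum(sample_weight * per_sample_loss) / s_sum
        + 0.5 * l2_reg_strength * coef @ coef
    )
    # gradient = 1 / s_sum * X.T @ (s_i-weighted pointwise gradients) + l2_reg_strength * coef
    gradient = (
        X.T @ (sample_weight * per_sample_gradient) / s_sum + l2_reg_strength * coef
    )
    # hessian = 1 / s_sum * X.T @ diag(s_i-weighted pointwise hessians) @ X + l2_reg_strength * I
    hessian = (
        X.T @ (sample_weight[:, None] * per_sample_hessian[:, None] * X) / s_sum
        + l2_reg_strength * np.eye(n_features)
    )
    print(loss, gradient.shape, hessian.shape)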
@@ -182,7 +183,7 @@ def loss(
         n_threads=1,
         raw_prediction=None,
     ):
-        """Compute the loss as sum over point-wise losses.
+        """Compute the loss as weighted average over point-wise losses.
 
         Parameters
         ----------
@@ -209,7 +210,7 @@ def loss(
         Returns
         -------
         loss : float
-            Sum of losses per sample plus penalty.
+            Weighted average of losses per sample, plus penalty.
         """
         if raw_prediction is None:
             weights, intercept, raw_prediction = self.weight_intercept_raw(coef, X)
@@ -219,10 +220,10 @@ def loss(
         loss = self.base_loss.loss(
             y_true=y,
             raw_prediction=raw_prediction,
-            sample_weight=sample_weight,
+            sample_weight=None,
             n_threads=n_threads,
         )
-        loss = loss.sum()
+        loss = np.average(loss, weights=sample_weight)
 
         return loss + self.l2_penalty(weights, l2_reg_strength)
 
@@ -263,12 +264,12 @@ def loss_gradient(
         Returns
         -------
         loss : float
-            Sum of losses per sample plus penalty.
+            Weighted average of losses per sample, plus penalty.
 
         gradient : ndarray of shape coef.shape
             The gradient of the loss.
         """
-        n_features, n_classes = X.shape[1], self.base_loss.n_classes
+        (n_samples, n_features), n_classes = X.shape, self.base_loss.n_classes
         n_dof = n_features + int(self.fit_intercept)
 
         if raw_prediction is None:
@@ -282,9 +283,12 @@ def loss_gradient(
             sample_weight=sample_weight,
             n_threads=n_threads,
         )
-        loss = loss.sum()
+        sw_sum = n_samples if sample_weight is None else np.sum(sample_weight)
+        loss = loss.sum() / sw_sum
         loss += self.l2_penalty(weights, l2_reg_strength)
 
+        grad_pointwise /= sw_sum
+
         if not self.base_loss.is_multiclass:
             grad = np.empty_like(coef, dtype=weights.dtype)
             grad[:n_features] = X.T @ grad_pointwise + l2_reg_strength * weights
@@ -340,7 +344,7 @@ def gradient(
         gradient : ndarray of shape coef.shape
             The gradient of the loss.
         """
-        n_features, n_classes = X.shape[1], self.base_loss.n_classes
+        (n_samples, n_features), n_classes = X.shape, self.base_loss.n_classes
         n_dof = n_features + int(self.fit_intercept)
 
         if raw_prediction is None:
@@ -354,6 +358,8 @@ def gradient(
             sample_weight=sample_weight,
             n_threads=n_threads,
         )
+        sw_sum = n_samples if sample_weight is None else np.sum(sample_weight)
+        grad_pointwise /= sw_sum
 
         if not self.base_loss.is_multiclass:
             grad = np.empty_like(coef, dtype=weights.dtype)
@@ -439,6 +445,9 @@ def gradient_hessian(
             sample_weight=sample_weight,
             n_threads=n_threads,
         )
+        sw_sum = n_samples if sample_weight is None else np.sum(sample_weight)
+        grad_pointwise /= sw_sum
+        hess_pointwise /= sw_sum
 
         # For non-canonical link functions and far away from the optimum, the pointwise
         # hessian can be negative. We take care that 75% of the hessian entries are
@@ -543,6 +552,7 @@ def gradient_hessian_product(
         (n_samples, n_features), n_classes = X.shape, self.base_loss.n_classes
         n_dof = n_features + int(self.fit_intercept)
         weights, intercept, raw_prediction = self.weight_intercept_raw(coef, X)
+        sw_sum = n_samples if sample_weight is None else np.sum(sample_weight)
 
         if not self.base_loss.is_multiclass:
             grad_pointwise, hess_pointwise = self.base_loss.gradient_hessian(
@@ -551,6 +561,8 @@ def gradient_hessian_product(
                 sample_weight=sample_weight,
                 n_threads=n_threads,
             )
+            grad_pointwise /= sw_sum
+            hess_pointwise /= sw_sum
             grad = np.empty_like(coef, dtype=weights.dtype)
             grad[:n_features] = X.T @ grad_pointwise + l2_reg_strength * weights
             if self.fit_intercept:
@@ -603,6 +615,7 @@ def hessp(s):
                 sample_weight=sample_weight,
                 n_threads=n_threads,
             )
+            grad_pointwise /= sw_sum
             grad = np.empty((n_classes, n_dof), dtype=weights.dtype, order="F")
             grad[:, :n_features] = grad_pointwise.T @ X + l2_reg_strength * weights
             if self.fit_intercept:
@@ -644,9 +657,9 @@ def hessp(s):
                 # hess_prod = empty_like(grad), but we ravel grad below and this
                 # function is run after that.
                 hess_prod = np.empty((n_classes, n_dof), dtype=weights.dtype, order="F")
-                hess_prod[:, :n_features] = tmp.T @ X + l2_reg_strength * s
+                hess_prod[:, :n_features] = (tmp.T @ X) / sw_sum + l2_reg_strength * s
                 if self.fit_intercept:
-                    hess_prod[:, -1] = tmp.sum(axis=0)
+                    hess_prod[:, -1] = tmp.sum(axis=0) / sw_sum
                 if coef.ndim == 1:
                     return hess_prod.ravel(order="F")
                 else:
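Aside (a self-contained check using a toy half squared error, not sklearn's loss classes): one can verify numerically that dividing the accumulated pointwise gradient by sw_sum is exactly the gradient of the weighted-average objective above.

    import numpy as np
    from scipy.optimize import check_grad

    rng = np.random.default_rng(0)
    n_samples, n_features = 20, 4
    X = rng.standard_normal((n_samples, n_features))
    y = rng.standard_normal(n_samples)
    sample_weight = rng.uniform(0.5, 2.0, n_samples)
    l2_reg_strength = 0.1
    sw_sum = sample_weight.sum()

    def objective(coef):
        raw = X @ coef
        per_sample = 0.5 * (y - raw) ** 2
        penalty = 0.5 * l2_reg_strength * coef @ coef
        return np.average(per_sample, weights=sample_weight) + penalty

    def gradient(coef):
        raw = X @ coef
        grad_pointwise = sample_weight * (raw - y)  # weighted pointwise gradients
        grad_pointwise /= sw_sum                    # the 1 / sw_sum scaling added above
        return X.T @ grad_pointwise + l2_reg_strength * coef

    # check_grad returns the norm of the difference to a finite-difference gradient.
    print(check_grad(objective, gradient, rng.standard_normal(n_features)))  # ~1e-7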

sklearn/linear_model/_logistic.py

Lines changed: 25 additions & 28 deletions
@@ -292,33 +292,27 @@ def _logistic_regression_path(
     # np.unique(y) gives labels in sorted order.
     pos_class = classes[1]
 
-    # If sample weights exist, convert them to array (support for lists)
-    # and check length
-    # Otherwise set them to 1 for all examples
-    sample_weight = _check_sample_weight(sample_weight, X, dtype=X.dtype, copy=True)
-
-    if solver == "newton-cholesky":
-        # IMPORTANT NOTE: Rescaling of sample_weight:
-        # Same as in _GeneralizedLinearRegressor.fit().
-        # We want to minimize
-        #     obj = 1/(2*sum(sample_weight)) * sum(sample_weight * deviance)
-        #           + 1/2 * alpha * L2,
-        # with
-        #     deviance = 2 * log_loss.
-        # The objective is invariant to multiplying sample_weight by a constant. We
-        # choose this constant such that sum(sample_weight) = 1. Thus, we end up with
-        #     obj = sum(sample_weight * loss) + 1/2 * alpha * L2.
-        # Note that LinearModelLoss.loss() computes sum(sample_weight * loss).
-        #
-        # This rescaling has to be done before multiplying by class_weights.
-        sw_sum = sample_weight.sum()  # needed to rescale penalty, nasty matter!
-        sample_weight = sample_weight / sw_sum
+    if sample_weight is not None or class_weight is not None:
+        sample_weight = _check_sample_weight(sample_weight, X, dtype=X.dtype, copy=True)
+    # IMPORTANT NOTE:
+    # All solvers relying on LinearModelLoss need to scale the penalty with n_samples
+    # or the sum of sample weights as the here implemented logistic regression
+    # objective is (unfortunately)
+    #     C * sum(pointwise_loss) + penalty
+    # instead of (as LinearModelLoss does)
+    #     mean(pointwise_loss) + 1/C * penalty
+    if solver in ["lbfgs", "newton-cg", "newton-cholesky"]:
+        # This needs to be calculated before sample_weight is multiplied by
+        # class_weight.
+        sw_sum = n_samples if sample_weight is None else np.sum(sample_weight)
 
     # If class_weights is a dict (provided by the user), the weights
     # are assigned to the original labels. If it is "balanced", then
     # the class_weights are assigned after masking the labels with a OvR.
     le = LabelEncoder()
-    if isinstance(class_weight, dict) or multi_class == "multinomial":
+    if isinstance(class_weight, dict) or (
+        multi_class == "multinomial" and class_weight is not None
+    ):
         class_weight_ = compute_class_weight(class_weight, classes=classes, y=y)
         sample_weight *= class_weight_[le.fit_transform(y)]
 
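Aside (made-up numbers, not part of the commit): the comment above is why every LinearModelLoss-based solver now uses l2_reg_strength = 1.0 / (C * sw_sum). Dividing the legacy objective C * sum(sample_weight * loss) + 0.5 * ||w||^2 by the positive constant C * sw_sum leaves the minimizer unchanged and yields the mean-style objective that LinearModelLoss computes:

    import numpy as np

    rng = np.random.default_rng(0)
    pointwise_loss = rng.random(50)           # stand-in for per-sample log loss at some w
    sample_weight = rng.uniform(0.5, 2.0, 50)
    w = rng.standard_normal(3)
    C = 0.7
    sw_sum = sample_weight.sum()

    legacy = C * np.sum(sample_weight * pointwise_loss) + 0.5 * w @ w
    rescaled = (
        np.average(pointwise_loss, weights=sample_weight)
        + 1.0 / (C * sw_sum) * 0.5 * w @ w
    )

    assert np.isclose(legacy / (C * sw_sum), rescaled)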

@@ -445,7 +439,7 @@ def _logistic_regression_path(
     n_iter = np.zeros(len(Cs), dtype=np.int32)
     for i, C in enumerate(Cs):
         if solver == "lbfgs":
-            l2_reg_strength = 1.0 / C
+            l2_reg_strength = 1.0 / (C * sw_sum)
             iprint = [-1, 50, 1, 100, 101][
                 np.searchsorted(np.array([0, 1, 2, 3]), verbose)
             ]
@@ -455,7 +449,12 @@ def _logistic_regression_path(
                 method="L-BFGS-B",
                 jac=True,
                 args=(X, target, sample_weight, l2_reg_strength, n_threads),
-                options={"iprint": iprint, "gtol": tol, "maxiter": max_iter},
+                options={
+                    "iprint": iprint,
+                    "gtol": tol,
+                    "maxiter": max_iter,
+                    "ftol": 64 * np.finfo(float).eps,
+                },
             )
             n_iter_i = _check_optimize_result(
                 solver,
@@ -465,15 +464,13 @@ def _logistic_regression_path(
             )
             w0, loss = opt_res.x, opt_res.fun
         elif solver == "newton-cg":
-            l2_reg_strength = 1.0 / C
+            l2_reg_strength = 1.0 / (C * sw_sum)
             args = (X, target, sample_weight, l2_reg_strength, n_threads)
             w0, n_iter_i = _newton_cg(
                 hess, func, grad, w0, args=args, maxiter=max_iter, tol=tol
             )
         elif solver == "newton-cholesky":
-            # The division by sw_sum is a consequence of the rescaling of
-            # sample_weight, see comment above.
-            l2_reg_strength = 1.0 / C / sw_sum
+            l2_reg_strength = 1.0 / (C * sw_sum)
             sol = NewtonCholeskySolver(
                 coef=w0,
                 linear_loss=loss,
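Aside on the new "ftol" entry (an assumption about intent, not stated in the commit): "ftol" is a standard option of scipy's L-BFGS-B method, and setting it to 64 * np.finfo(float).eps tightens the stopping criterion on relative objective decrease far below its default, so convergence is governed mainly by "gtol"/tol. A minimal sketch on a toy quadratic:

    import numpy as np
    from scipy.optimize import minimize

    res = minimize(
        lambda w: np.sum((w - 1.0) ** 2),
        x0=np.zeros(3),
        method="L-BFGS-B",
        options={"gtol": 1e-8, "maxiter": 100, "ftol": 64 * np.finfo(float).eps},
    )
    print(res.x)  # close to [1.0, 1.0, 1.0]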

sklearn/linear_model/tests/test_logistic.py

Lines changed: 46 additions & 6 deletions
@@ -702,14 +702,17 @@ def test_logistic_regression_solvers_multiclass():
     }
 
     for solver_1, solver_2 in itertools.combinations(regressors, r=2):
-        assert_array_almost_equal(
-            regressors[solver_1].coef_, regressors[solver_2].coef_, decimal=4
+        assert_allclose(
+            regressors[solver_1].coef_,
+            regressors[solver_2].coef_,
+            rtol=5e-3 if solver_2 == "saga" else 1e-3,
+            err_msg=f"{solver_1} vs {solver_2}",
         )
 
 
 @pytest.mark.parametrize("weight", [{0: 0.1, 1: 0.2}, {0: 0.1, 1: 0.2, 2: 0.5}])
 @pytest.mark.parametrize("class_weight", ["weight", "balanced"])
-def test_logistic_regressioncv_class_weights(weight, class_weight):
+def test_logistic_regressioncv_class_weights(weight, class_weight, global_random_seed):
     """Test class_weight for LogisticRegressionCV."""
     n_classes = len(weight)
     if class_weight == "weight":
@@ -722,23 +725,60 @@ def test_logistic_regressioncv_class_weights(weight, class_weight):
         n_informative=3,
         n_redundant=0,
         n_classes=n_classes,
-        random_state=0,
+        random_state=global_random_seed,
     )
     params = dict(
         Cs=1,
         fit_intercept=False,
         multi_class="ovr",
         class_weight=class_weight,
+        tol=1e-8,
     )
     clf_lbfgs = LogisticRegressionCV(solver="lbfgs", **params)
     clf_lbfgs.fit(X, y)
 
+    from sklearn.linear_model._linear_loss import LinearModelLoss
+    from sklearn._loss.loss import HalfMultinomialLoss, HalfBinomialLoss
+
+    if n_classes > 2:
+        loss = LinearModelLoss(
+            base_loss=HalfMultinomialLoss(n_classes=n_classes),
+            fit_intercept=False,
+        )
+    else:
+        loss = LinearModelLoss(
+            base_loss=HalfBinomialLoss(),
+            fit_intercept=False,
+        )
+    l_lbfgs = loss.loss(
+        coef=clf_lbfgs.coef_.squeeze(),
+        X=X,
+        y=LabelEncoder().fit_transform(y).astype(float),
+        sample_weight=None,
+        l2_reg_strength=1 / 20,
+    )
+    print(f"loss lbfgs = {l_lbfgs} C_={clf_lbfgs.C_}")
+
     for solver in set(SOLVERS) - set(["lbfgs"]):
         clf = LogisticRegressionCV(solver=solver, **params)
         if solver in ("sag", "saga"):
-            clf.set_params(tol=1e-5, max_iter=10000, random_state=0)
+            clf.set_params(
+                tol=1e-18, max_iter=10000, random_state=global_random_seed + 1
+            )
         clf.fit(X, y)
-        assert_allclose(clf.coef_, clf_lbfgs.coef_, rtol=1e-3)
+
+        l_solver = loss.loss(
+            coef=clf.coef_.squeeze(),
+            X=X,
+            y=LabelEncoder().fit_transform(y).astype(float),
+            sample_weight=None,
+            l2_reg_strength=1 / 20,
+        )
+        print(f"loss {solver} = {l_solver} C_={clf.C_}")
+
+        assert_allclose(
+            clf.coef_, clf_lbfgs.coef_, rtol=1e-3, err_msg=f"{solver} vs lbfgs"
        )
 
 
 def test_logistic_regression_sample_weights():
