Revert "Sag handle numerical error outside of cython (#13389)" · xhluca/scikit-learn@3cc43e4 · GitHub
Commit 3cc43e4

Author: Xing

Revert "Sag handle numerical error outside of cython (scikit-learn#13389)"

This reverts commit b6c2b95.

1 parent f9a665e · commit 3cc43e4
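Note on the mechanics: the change being reverted had replaced in-loop exceptions with integer status codes so the nogil Cython loops never re-acquire the GIL on the error path (see the comment removed from sag_fast.pyx.tp below); this revert restores raising a ValueError at the point of failure. A minimal pure-Python sketch of the two styles, illustrative only, not the Cython template itself:

import math

def lagged_update_status(weights, n_iter=0):
    # Status-code style (introduced by the commit being reverted): report
    # failure through the return value so the caller raises once, after the
    # hot loop; in Cython this keeps the nogil loop free of GIL re-acquisition.
    for w in weights:
        if not math.isfinite(w):
            return -1
    return 0

def lagged_update_raise(weights, n_iter=0):
    # Raise-in-place style (restored by this revert): fail immediately where
    # the non-finite value is detected, as raise_infinite_error does below.
    for w in weights:
        if not math.isfinite(w):
            raise ValueError("Floating-point under-/overflow occurred at "
                             "epoch #%d." % (n_iter + 1))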

File tree

3 files changed: +61 −115 lines changed

doc/whats_new/v0.21.rst

Lines changed: 0 additions & 5 deletions

@@ -191,11 +191,6 @@ Support for Python 3.4 and below has been officially dropped.
 :mod:`sklearn.linear_model`
 ...........................
 
-- |Fix| Fixed a performance issue of ``saga`` and ``sag`` solvers when called
-  in a :class:`joblib.Parallel` setting with ``n_jobs > 1`` and
-  ``backend="threading"``, causing them to perform worse than in the sequential
-  case. :issue:`13389` by :user:`Pierre Glaser <pierreglaser>`.
-
 - |Feature| :class:`linear_model.LogisticRegression` and
   :class:`linear_model.LogisticRegressionCV` now support Elastic-Net penalty,
   with the 'saga' solver. :issue:`11646` by :user:`Nicolas Hug <NicolasHug>`.
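The |Fix| entry removed above described sag/saga fits running concurrently under joblib's threading backend. A hedged sketch of that setting (dataset shape, worker and job counts are illustrative assumptions, not from the changelog):

from joblib import Parallel, delayed
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression

X, y = make_classification(n_samples=500, n_features=20, random_state=0)

# Several saga fits running concurrently in threads: the n_jobs > 1,
# backend="threading" configuration the removed entry refers to.
models = Parallel(n_jobs=2, backend="threading")(
    delayed(LogisticRegression(solver="saga", max_iter=1000).fit)(X, y)
    for _ in range(4)
)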

sklearn/linear_model/sag_fast.pyx.tp

Lines changed: 60 additions & 88 deletions

@@ -59,6 +59,12 @@ from ..utils.seq_dataset cimport SequentialDataset32, SequentialDataset64
 
 from libc.stdio cimport printf
 
+cdef void raise_infinite_error(int n_iter):
+    raise ValueError("Floating-point under-/overflow occurred at "
+                     "epoch #%d. Lowering the step_size or "
+                     "scaling the input data with StandardScaler "
+                     "or MinMaxScaler might help." % (n_iter + 1))
+
 
 
 {{for name, c_type, np_type in get_dispatch(dtypes)}}

@@ -343,9 +349,6 @@ def sag{{name}}(SequentialDataset{{name}} dataset,
     # the scalar used for multiplying z
     cdef {{c_type}} wscale = 1.0
 
-    # return value (-1 if an error occurred, 0 otherwise)
-    cdef int status = 0
-
     # the cumulative sums for each iteration for the sparse implementation
     cumulative_sums[0] = 0.0
 

@@ -399,19 +402,16 @@ def sag{{name}}(SequentialDataset{{name}} dataset,
 
             # make the weight updates
             if sample_itr > 0:
-                status = lagged_update{{name}}(weights, wscale, xnnz,
-                                               n_samples, n_classes,
-                                               sample_itr,
-                                               cumulative_sums,
-                                               cumulative_sums_prox,
-                                               feature_hist,
-                                               prox,
-                                               sum_gradient,
-                                               x_ind_ptr,
-                                               False,
-                                               n_iter)
-                if status == -1:
-                    break
+                lagged_update{{name}}(weights, wscale, xnnz,
+                                      n_samples, n_classes, sample_itr,
+                                      cumulative_sums,
+                                      cumulative_sums_prox,
+                                      feature_hist,
+                                      prox,
+                                      sum_gradient,
+                                      x_ind_ptr,
+                                      False,
+                                      n_iter)
 
             # find the current prediction
             predict_sample{{name}}(x_data_ptr, x_ind_ptr, xnnz, weights, wscale,

@@ -460,12 +460,8 @@ def sag{{name}}(SequentialDataset{{name}} dataset,
 
                     # check to see that the intercept is not inf or NaN
                     if not skl_isfinite{{name}}(intercept[class_ind]):
-                        status = -1
-                        break
-            # Break from the n_samples outer loop if an error happened
-            # in the fit_intercept n_classes inner loop
-            if status == -1:
-                break
+                        with gil:
+                            raise_infinite_error(n_iter)
 
             # update the gradient memory for this sample
             for class_ind in range(n_classes):

@@ -488,32 +484,21 @@ def sag{{name}}(SequentialDataset{{name}} dataset,
                 if verbose:
                     with gil:
                         print("rescaling...")
-                status = scale_weights{{name}}(
-                    weights, &wscale, n_features, n_samples, n_classes,
+                wscale = scale_weights{{name}}(
+                    weights, wscale, n_features, n_samples, n_classes,
                     sample_itr, cumulative_sums,
                     cumulative_sums_prox,
                     feature_hist,
                     prox, sum_gradient, n_iter)
-                if status == -1:
-                    break
-
-        # Break from the n_iter outer loop if an error happened in the
-        # n_samples inner loop
-        if status == -1:
-            break
 
         # we scale the weights every n_samples iterations and reset the
        # just-in-time update system for numerical stability.
-        status = scale_weights{{name}}(weights, &wscale, n_features,
-                                       n_samples,
-                                       n_classes, n_samples - 1,
-                                       cumulative_sums,
-                                       cumulative_sums_prox,
-                                       feature_hist,
-                                       prox, sum_gradient, n_iter)
-
-        if status == -1:
-            break
+        wscale = scale_weights{{name}}(weights, wscale, n_features, n_samples,
+                                       n_classes, n_samples - 1, cumulative_sums,
+                                       cumulative_sums_prox,
+                                       feature_hist,
+                                       prox, sum_gradient, n_iter)
+
         # check if the stopping criteria is reached
         max_change = 0.0
         max_weight = 0.0

@@ -535,13 +520,6 @@ def sag{{name}}(SequentialDataset{{name}} dataset,
                 printf('Epoch %d, change: %.8f\n', n_iter + 1,
                        max_change / max_weight)
         n_iter += 1
-    # We do the error treatment here based on error code in status to avoid
-    # re-acquiring the GIL within the cython code, which slows the computation
-    # when the sag/saga solver is used concurrently in multiple Python threads.
-    if status == -1:
-        raise ValueError(("Floating-point under-/overflow occurred at epoch"
-                          " #%d. Scaling input data with StandardScaler or"
-                          " MinMaxScaler might help.") % n_iter)
 
     if verbose and n_iter >= max_iter:
        end_time = time(NULL)

@@ -555,15 +533,14 @@ def sag{{name}}(SequentialDataset{{name}} dataset,
 
 {{for name, c_type, np_type in get_dispatch(dtypes)}}
 
-cdef int scale_weights{{name}}({{c_type}}* weights, {{c_type}}* wscale,
-                               int n_features,
-                               int n_samples, int n_classes, int sample_itr,
-                               {{c_type}}* cumulative_sums,
-                               {{c_type}}* cumulative_sums_prox,
-                               int* feature_hist,
-                               bint prox,
-                               {{c_type}}* sum_gradient,
-                               int n_iter) nogil:
+cdef {{c_type}} scale_weights{{name}}({{c_type}}* weights, {{c_type}} wscale, int n_features,
+                                      int n_samples, int n_classes, int sample_itr,
+                                      {{c_type}}* cumulative_sums,
+                                      {{c_type}}* cumulative_sums_prox,
+                                      int* feature_hist,
+                                      bint prox,
+                                      {{c_type}}* sum_gradient,
+                                      int n_iter) nogil:
     """Scale the weights with wscale for numerical stability.
 
     wscale = (1 - step_size * alpha) ** (n_iter * n_samples + sample_itr)

@@ -573,37 +550,34 @@ cdef int scale_weights{{name}}({{c_type}}* weights, {{c_type}}* wscale,
     This also limits the size of `cumulative_sums`.
     """
 
-    cdef int status
-    status = lagged_update{{name}}(weights, wscale[0], n_features,
-                                   n_samples, n_classes, sample_itr + 1,
-                                   cumulative_sums,
-                                   cumulative_sums_prox,
-                                   feature_hist,
-                                   prox,
-                                   sum_gradient,
-                                   NULL,
-                                   True,
-                                   n_iter)
-    # if lagged update succeeded, reset wscale to 1.0
-    if status == 0:
-        wscale[0] = 1.0
-    return status
+    lagged_update{{name}}(weights, wscale, n_features,
+                          n_samples, n_classes, sample_itr + 1,
+                          cumulative_sums,
+                          cumulative_sums_prox,
+                          feature_hist,
+                          prox,
+                          sum_gradient,
+                          NULL,
+                          True,
+                          n_iter)
+    # reset wscale to 1.0
+    return 1.0
 
 {{endfor}}
 
 
 {{for name, c_type, np_type in get_dispatch(dtypes)}}
 
-cdef int lagged_update{{name}}({{c_type}}* weights, {{c_type}} wscale, int xnnz,
-                               int n_samples, int n_classes, int sample_itr,
-                               {{c_type}}* cumulative_sums,
-                               {{c_type}}* cumulative_sums_prox,
-                               int* feature_hist,
-                               bint prox,
-                               {{c_type}}* sum_gradient,
-                               int* x_ind_ptr,
-                               bint reset,
-                               int n_iter) nogil:
+cdef void lagged_update{{name}}({{c_type}}* weights, {{c_type}} wscale, int xnnz,
+                                int n_samples, int n_classes, int sample_itr,
+                                {{c_type}}* cumulative_sums,
+                                {{c_type}}* cumulative_sums_prox,
+                                int* feature_hist,
+                                bint prox,
+                                {{c_type}}* sum_gradient,
+                                int* x_ind_ptr,
+                                bint reset,
+                                int n_iter) nogil:
     """Hard perform the JIT updates for non-zero features of present sample.
     The updates that awaits are kept in memory using cumulative_sums,
     cumulative_sums_prox, wscale and feature_hist. See original SAGA paper

@@ -631,9 +605,8 @@ cdef int lagged_update{{name}}({{c_type}}* weights, {{c_type}} wscale, int xnnz,
                 if reset:
                     weights[idx] *= wscale
                     if not skl_isfinite{{name}}(weights[idx]):
-                        # returning here does not require the gil as the return
-                        # type is a C integer
-                        return -1
+                        with gil:
+                            raise_infinite_error(n_iter)
         else:
             for class_ind in range(n_classes):
                 idx = f_idx + class_ind

@@ -667,7 +640,8 @@ cdef int lagged_update{{name}}({{c_type}}* weights, {{c_type}} wscale, int xnnz,
                     weights[idx] *= wscale
                     # check to see that the weight is not inf or NaN
                     if not skl_isfinite{{name}}(weights[idx]):
-                        return -1
+                        with gil:
+                            raise_infinite_error(n_iter)
         if reset:
             feature_hist[feature_ind] = sample_itr % n_samples
         else:

@@ -678,8 +652,6 @@ cdef int lagged_update{{name}}({{c_type}}* weights, {{c_type}} wscale, int xnnz,
         if prox:
             cumulative_sums_prox[sample_itr - 1] = 0.0
 
-    return 0
-
 {{endfor}}
 
 

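Context for the scale_weights change: the function now returns the fresh scale factor (1.0) instead of a status code, with wscale passed by value rather than by pointer. Per the docstring above, the per-step shrinkage (1 - step_size * alpha) is accumulated in the scalar wscale so the dense weight vector only needs an occasional full rescale. A small numpy sketch of this lazy-scaling idea (function name and constants are illustrative, not the template's API):

import numpy as np

def lazily_scaled_shrinkage(weights, step_size, alpha, n_steps):
    # Accumulate the multiplicative shrinkage in one scalar instead of
    # touching every weight at every step.
    wscale = 1.0
    for _ in range(n_steps):
        wscale *= 1.0 - step_size * alpha
        # Fold the factor into the weights before it underflows, as the
        # template does when wscale gets too small; resetting to 1.0
        # mirrors scale_weights returning 1.0.
        if wscale < 1e-9:
            weights *= wscale
            wscale = 1.0
    weights *= wscale  # apply whatever factor is still pending
    return weights

print(lazily_scaled_shrinkage(np.ones(3), step_size=0.01, alpha=1.0, n_steps=5000))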
sklearn/linear_model/tests/test_sag.py

Lines changed: 1 addition & 22 deletions

@@ -24,7 +24,7 @@
 from sklearn.utils import compute_class_weight
 from sklearn.utils import check_random_state
 from sklearn.preprocessing import LabelEncoder, LabelBinarizer
-from sklearn.datasets import make_blobs, load_iris, make_classification
+from sklearn.datasets import make_blobs, load_iris
 from sklearn.base import clone
 
 iris = load_iris()

@@ -826,24 +826,3 @@ def test_multinomial_loss_ground_truth():
                           [-0.903942, +5.258745, -4.354803]])
     assert_almost_equal(loss_1, loss_gt)
     assert_array_almost_equal(grad_1, grad_gt)
-
-
-@pytest.mark.parametrize("solver", ["sag", "saga"])
-def test_sag_classifier_raises_error(solver):
-    # Following #13316, the error handling behavior changed in cython sag. This
-    # is simply a non-regression test to make sure numerical errors are
-    # properly raised.
-
-    # Train a classifier on a simple problem
-    rng = np.random.RandomState(42)
-    X, y = make_classification(random_state=rng)
-    clf = LogisticRegression(solver=solver, random_state=rng, warm_start=True)
-    clf.fit(X, y)
-
-    # Trigger a numerical error by:
-    # - corrupting the fitted coefficients of the classifier
-    # - fit it again starting from its current state thanks to warm_start
-    clf.coef_[:] = np.nan
-
-    with pytest.raises(ValueError, match="Floating-point under-/overflow"):
-        clf.fit(X, y)
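Although the non-regression test is removed by this revert, the reproduction recipe in its comments still applies: the restored code path raises the same "Floating-point under-/overflow" ValueError, now from inside the nogil loop. A standalone sketch of that recipe (make_blobs and the solver choice are illustrative assumptions):

import numpy as np
from sklearn.datasets import make_blobs
from sklearn.linear_model import LogisticRegression

X, y = make_blobs(n_samples=100, centers=2, random_state=0)
clf = LogisticRegression(solver="saga", warm_start=True).fit(X, y)

# Corrupt the fitted coefficients, then refit from that state via warm_start:
# the solver encounters non-finite weights and raises the ValueError.
clf.coef_[:] = np.nan
try:
    clf.fit(X, y)
except ValueError as exc:
    print(exc)  # Floating-point under-/overflow occurred at epoch #1. ...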
