Ensure convergence and mark two new estimators as XFAIL · scikit-learn/scikit-learn@a0371ad · GitHub
Commit a0371ad

Ensure convergence and mark two new estimators as XFAIL
1 parent 8816d3f commit a0371ad

File tree

4 files changed: +63, -14 lines changed


sklearn/linear_model/_huber.py

Lines changed: 13 additions & 0 deletions
@@ -351,3 +351,16 @@ def fit(self, X, y, sample_weight=None):
         residual = np.abs(y - safe_sparse_dot(X, self.coef_) - self.intercept_)
         self.outliers_ = residual > self.scale_ * self.epsilon
         return self
+
+    def __sklearn_tags__(self):
+        tags = super().__sklearn_tags__()
+        tags._xfail_checks.update(
+            {
+                # TODO: fix sample_weight handling of this estimator, see
+                # meta-issue #16298
+                "check_sample_weight_equivalence": (
+                    "sample_weight is not equivalent to removing/repeating samples."
+                ),
+            }
+        )
+        return tags
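
For context, here is a minimal sketch (not part of the commit) of how an _xfail_checks entry like the one added above is typically consumed: the tag maps the name of a common check to a human-readable reason, and the test runner downgrades a failure of that check to an expected failure. It assumes the development version of scikit-learn shown in this diff, where __sklearn_tags__() exposes the _xfail_checks mapping.

from sklearn.linear_model import HuberRegressor

tags = HuberRegressor().__sklearn_tags__()
reason = tags._xfail_checks.get("check_sample_weight_equivalence")
if reason is not None:
    # In the common-test suite a failure of this check would become an
    # expected failure (e.g. pytest.xfail(reason)); here we only report it.
    print(f"check_sample_weight_equivalence is expected to fail: {reason}")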

sklearn/linear_model/_logistic.py

Lines changed: 11 additions & 1 deletion
@@ -1463,9 +1463,19 @@ def __sklearn_tags__(self):
             {
                 "check_non_transformer_estimators_n_iter": (
                     "n_iter_ cannot be easily accessed."
-                )
+                ),
             }
         )
+        if self.solver in ("lbfgs", "liblinear"):
+            tags._xfail_checks.update(
+                {
+                    # TODO: fix sample_weight handling of this estimator, see
+                    # meta-issue #16298
+                    "check_sample_weight_equivalence": (
+                        "sample_weight is not equivalent to removing/repeating samples."
+                    ),
+                }
+            )
         return tags
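
The hunk above makes the xfail conditional on the solver: only lbfgs and liblinear opt the estimator out of check_sample_weight_equivalence, while the Newton-based solvers stay subject to it. A hedged illustration follows; it assumes the method shown here belongs to LogisticRegressionCV (suggested by the "n_iter_ cannot be easily accessed" message, but not visible in this excerpt) and the development version of scikit-learn from this diff.

from sklearn.linear_model import LogisticRegressionCV

for solver in ("lbfgs", "liblinear", "newton-cg", "newton-cholesky"):
    tags = LogisticRegressionCV(solver=solver).__sklearn_tags__()
    xfailed = "check_sample_weight_equivalence" in tags._xfail_checks
    print(f"solver={solver!r}: xfails sample_weight equivalence -> {xfailed}")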

sklearn/utils/_test_common/instance_generator.py

Lines changed: 26 additions & 10 deletions
@@ -516,33 +516,49 @@
             max_iter=20, n_components=1, transform_algorithm="lasso_lars"
         )
     },
+    ElasticNetCV: {"check_sample_weight_equivalence": dict(max_iter=100, tol=1e-2)},
     FactorAnalysis: {"check_dict_unchanged": dict(max_iter=5, n_components=1)},
     FastICA: {"check_dict_unchanged": dict(max_iter=5, n_components=1)},
     FeatureAgglomeration: {"check_dict_unchanged": dict(n_clusters=1)},
     GammaRegressor: {
         "check_sample_weight_equivalence": [
-            dict(solver="newton-cholesky"),
-            dict(solver="lbfgs"),
+            dict(solver="newton-cholesky", max_iter=1_000, tol=1e-12),
+            dict(solver="lbfgs", max_iter=1_000, tol=1e-12),
         ]
     },
     GaussianMixture: {"check_dict_unchanged": dict(max_iter=5, n_init=2)},
     GaussianRandomProjection: {"check_dict_unchanged": dict(n_components=1)},
+    HuberRegressor: {
+        "check_sample_weight_equivalence": dict(tol=1e-12, max_iter=1_000)
+    },
     IncrementalPCA: {"check_dict_unchanged": dict(batch_size=10, n_components=1)},
     Isomap: {"check_dict_unchanged": dict(n_components=1)},
     KMeans: {"check_dict_unchanged": dict(max_iter=5, n_clusters=1, n_init=2)},
     KernelPCA: {"check_dict_unchanged": dict(n_components=1)},
     LassoLars: {"check_non_transformer_estimators_n_iter": dict(alpha=0.0)},
+    LassoCV: {"check_sample_weight_equivalence": dict(max_iter=100, tol=1e-2)},
     LatentDirichletAllocation: {
         "check_dict_unchanged": dict(batch_size=10, max_iter=5, n_components=1)
     },
     LinearDiscriminantAnalysis: {"check_dict_unchanged": dict(n_components=1)},
     LocallyLinearEmbedding: {"check_dict_unchanged": dict(max_iter=5, n_components=1)},
     LogisticRegression: {
         "check_sample_weight_equivalence": [
-            dict(solver="lbfgs"),
-            dict(solver="liblinear"),
-            dict(solver="newton-cg"),
-            dict(solver="newton-cholesky"),
+            dict(solver="lbfgs", max_iter=1_000, tol=1e-12),
+            # liblinear has more problems with higher regularization apparently...
+            dict(solver="liblinear", C=0.01, max_iter=1_000, tol=1e-12),
+            dict(solver="newton-cg", max_iter=1_000, tol=1e-12),
+            dict(solver="newton-cholesky", max_iter=1_000, tol=1e-12),
+        ]
+    },
+    LogisticRegressionCV: {
+        "check_sample_weight_equivalence": [
+            dict(
+                solver="newton-cholesky",
+                Cs=np.logspace(-3, 3, 5),
+                max_iter=1_000,
+                tol=1e-12,
+            ),
         ]
     },
     MDS: {"check_dict_unchanged": dict(max_iter=5, n_components=1, n_init=2)},

@@ -571,8 +587,8 @@
     PLSSVD: {"check_dict_unchanged": dict(n_components=1)},
     PoissonRegressor: {
         "check_sample_weight_equivalence": [
-            dict(solver="newton-cholesky"),
-            dict(solver="lbfgs"),
+            dict(solver="newton-cholesky", max_iter=100),
+            dict(solver="lbfgs", max_iter=100),
         ]
     },
     PolynomialCountSketch: {"check_dict_unchanged": dict(n_components=1)},

@@ -626,8 +642,8 @@
     TruncatedSVD: {"check_dict_unchanged": dict(n_components=1)},
     TweedieRegressor: {
         "check_sample_weight_equivalence": [
-            dict(solver="newton-cholesky"),
-            dict(solver="lbfgs"),
+            dict(solver="newton-cholesky", max_iter=1_000, tol=1e-12),
+            dict(solver="lbfgs", max_iter=1_000, tol=1e-12),
         ]
     },
 }
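
The entries above override constructor parameters per common check (tighter tol, larger max_iter, sometimes a specific C) so that the estimators genuinely converge during the test. Below is a minimal sketch of how such a mapping can be consumed; the dictionary name, helper function, and the trimmed-down set of entries are illustrative, not the actual instance_generator.py internals.

from sklearn.linear_model import LogisticRegression

# Hypothetical, trimmed-down version of the per-check parameter mapping.
CHECK_PARAMS = {
    LogisticRegression: {
        "check_sample_weight_equivalence": [
            dict(solver="lbfgs", max_iter=1_000, tol=1e-12),
            dict(solver="liblinear", C=0.01, max_iter=1_000, tol=1e-12),
        ]
    },
}


def instances_for_check(estimator_class, check_name):
    """Yield estimator instances configured for one common check."""
    params = CHECK_PARAMS.get(estimator_class, {}).get(check_name, [{}])
    if isinstance(params, dict):  # a single configuration is also allowed
        params = [params]
    for kwargs in params:
        yield estimator_class(**kwargs)


for estimator in instances_for_check(LogisticRegression, "check_sample_weight_equivalence"):
    print(estimator)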

sklearn/utils/estimator_checks.py

Lines changed: 13 additions & 3 deletions
@@ -33,7 +33,12 @@
     make_multilabel_classification,
     make_regression,
 )
-from ..exceptions import DataConversionWarning, NotFittedError, SkipTestWarning
+from ..exceptions import (
+    ConvergenceWarning,
+    DataConversionWarning,
+    NotFittedError,
+    SkipTestWarning,
+)
 from ..linear_model._base import LinearClassifierMixin
 from ..metrics import accuracy_score, adjusted_rand_score, f1_score
 from ..metrics.pairwise import linear_kernel, pairwise_distances, rbf_kernel

@@ -1161,8 +1166,13 @@ def check_sample_weight_equivalence(name, estimator_orig):
     y_weighted = _enforce_estimator_tags_y(estimator_weighted, y_weighted)
     y_repeated = _enforce_estimator_tags_y(estimator_repeated, y_repeated)

-    estimator_repeated.fit(X_repeated, y=y_repeated, sample_weight=None)
-    estimator_weighted.fit(X_weigthed, y=y_weighted, sample_weight=sw)
+    with warnings.catch_warnings(record=True):
+        # Ensure we converge, otherwise debugging sample_weight equivalence
+        # failures can be very misleading.
+        warnings.simplefilter("error", category=ConvergenceWarning)
+
+        estimator_repeated.fit(X_repeated, y=y_repeated, sample_weight=None)
+        estimator_weighted.fit(X_weigthed, y=y_weighted, sample_weight=sw)

     X_test = rng.uniform(low=X.min(), high=X.max(), size=(300, n_features))
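
The change above turns any ConvergenceWarning raised during fit into an error, so a silently non-converged model can no longer produce a misleading sample_weight-equivalence failure. The stand-alone snippet below shows the same warnings-filter pattern outside the common-check machinery; the toy data and the deliberately tiny max_iter are illustrative only.

import warnings

import numpy as np
from sklearn.exceptions import ConvergenceWarning
from sklearn.linear_model import LogisticRegression

rng = np.random.RandomState(0)
X = rng.uniform(size=(50, 3))
y = (X[:, 0] > 0.5).astype(int)

with warnings.catch_warnings():
    # Escalate convergence warnings to errors, as the modified check does.
    warnings.simplefilter("error", category=ConvergenceWarning)
    try:
        # max_iter=1 is almost certainly too small for lbfgs to converge.
        LogisticRegression(max_iter=1).fit(X, y)
    except ConvergenceWarning:
        print("fit did not converge; increase max_iter or loosen tol")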

0 commit comments
