TST allow setting per test settings for estimators (#29820) · scikit-learn/scikit-learn@7bcae6c · GitHub

Commit 7bcae6c

TST allow setting per test settings for estimators (#29820)
1 parent d4ab9ed commit 7bcae6c

File tree

sklearn/cluster/_bicluster.py
sklearn/utils/_test_common/instance_generator.py
sklearn/utils/estimator_checks.py

3 files changed: +148 −26 lines changed

sklearn/cluster/_bicluster.py

Lines changed: 11 additions & 0 deletions
@@ -362,6 +362,17 @@ def _fit(self, X):
             [self.column_labels_ == c for c in range(self.n_clusters)]
         )
 
+    def __sklearn_tags__(self):
+        tags = super().__sklearn_tags__()
+        tags._xfail_checks.update(
+            {
+                # ValueError: Found array with 0 feature(s) (shape=(23, 0))
+                # while a minimum of 1 is required.
+                "check_dict_unchanged": "FIXME",
+            }
+        )
+        return tags
+
 
 class SpectralBiclustering(BaseSpectral):
     """Spectral biclustering (Kluger, 2003).

sklearn/utils/_test_common/instance_generator.py

Lines changed: 124 additions & 2 deletions
@@ -7,7 +7,7 @@
 from functools import partial
 from inspect import isfunction
 
-from sklearn import config_context
+from sklearn import clone, config_context
 from sklearn.calibration import CalibratedClassifierCV
 from sklearn.cluster import (
     HDBSCAN,
@@ -33,6 +33,7 @@
     FactorAnalysis,
     FastICA,
     IncrementalPCA,
+    KernelPCA,
     LatentDirichletAllocation,
     MiniBatchDictionaryLearning,
     MiniBatchNMF,
@@ -41,6 +42,7 @@
     SparsePCA,
     TruncatedSVD,
 )
+from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
 from sklearn.dummy import DummyClassifier
 from sklearn.ensemble import (
     AdaBoostClassifier,
@@ -72,6 +74,12 @@
     SelectKBest,
     SequentialFeatureSelector,
 )
+from sklearn.kernel_approximation import (
+    Nystroem,
+    PolynomialCountSketch,
+    RBFSampler,
+    SkewedChi2Sampler,
+)
 from sklearn.linear_model import (
     ARDRegression,
     BayesianRidge,
@@ -105,7 +113,13 @@
     TheilSenRegressor,
     TweedieRegressor,
 )
-from sklearn.manifold import MDS, TSNE, LocallyLinearEmbedding, SpectralEmbedding
+from sklearn.manifold import (
+    MDS,
+    TSNE,
+    Isomap,
+    LocallyLinearEmbedding,
+    SpectralEmbedding,
+)
 from sklearn.mixture import BayesianGaussianMixture, GaussianMixture
 from sklearn.model_selection import (
     FixedThresholdClassifier,
@@ -457,6 +471,79 @@
     ),
 }
 
+# This dictionary stores parameters for specific checks. It also enables running the
+# same check with multiple instances of the same estimator with different parameters.
+# The special key "*" allows to apply the parameters to all checks.
+# TODO(devtools): allow third-party developers to pass test specific params to checks
+PER_ESTIMATOR_CHECK_PARAMS: dict = {
+    # TODO(devtools): check that function names here exist in checks for the estimator
+    # TODO(devtools): write a test for the same thing with tags._xfail_checks
+    AgglomerativeClustering: {"check_dict_unchanged": dict(n_clusters=1)},
+    BayesianGaussianMixture: {"check_dict_unchanged": dict(max_iter=5, n_init=2)},
+    BernoulliRBM: {"check_dict_unchanged": dict(n_components=1, n_iter=5)},
+    Birch: {"check_dict_unchanged": dict(n_clusters=1)},
+    BisectingKMeans: {"check_dict_unchanged": dict(max_iter=5, n_clusters=1, n_init=2)},
+    CCA: {"check_dict_unchanged": dict(max_iter=5, n_components=1)},
+    DictionaryLearning: {
+        "check_dict_unchanged": dict(
+            max_iter=20, n_components=1, transform_algorithm="lasso_lars"
+        )
+    },
+    FactorAnalysis: {"check_dict_unchanged": dict(max_iter=5, n_components=1)},
+    FastICA: {"check_dict_unchanged": dict(max_iter=5, n_components=1)},
+    FeatureAgglomeration: {"check_dict_unchanged": dict(n_clusters=1)},
+    GaussianMixture: {"check_dict_unchanged": dict(max_iter=5, n_init=2)},
+    GaussianRandomProjection: {"check_dict_unchanged": dict(n_components=1)},
+    IncrementalPCA: {"check_dict_unchanged": dict(batch_size=10, n_components=1)},
+    Isomap: {"check_dict_unchanged": dict(n_components=1)},
+    KMeans: {"check_dict_unchanged": dict(max_iter=5, n_clusters=1, n_init=2)},
+    KernelPCA: {"check_dict_unchanged": dict(n_components=1)},
+    LatentDirichletAllocation: {
+        "check_dict_unchanged": dict(batch_size=10, max_iter=5, n_components=1)
+    },
+    LinearDiscriminantAnalysis: {"check_dict_unchanged": dict(n_components=1)},
+    LocallyLinearEmbedding: {"check_dict_unchanged": dict(max_iter=5, n_components=1)},
+    MDS: {"check_dict_unchanged": dict(max_iter=5, n_components=1, n_init=2)},
+    MiniBatchDictionaryLearning: {
+        "check_dict_unchanged": dict(batch_size=10, max_iter=5, n_components=1)
+    },
+    MiniBatchKMeans: {
+        "check_dict_unchanged": dict(batch_size=10, max_iter=5, n_clusters=1, n_init=2)
+    },
+    MiniBatchNMF: {
+        "check_dict_unchanged": dict(
+            batch_size=10, fresh_restarts=True, max_iter=20, n_components=1
+        )
+    },
+    MiniBatchSparsePCA: {
+        "check_dict_unchanged": dict(batch_size=10, max_iter=5, n_components=1)
+    },
+    NMF: {"check_dict_unchanged": dict(max_iter=500, n_components=1)},
+    NeighborhoodComponentsAnalysis: {
+        "check_dict_unchanged": dict(max_iter=5, n_components=1)
+    },
+    Nystroem: {"check_dict_unchanged": dict(n_components=1)},
+    PCA: {"check_dict_unchanged": dict(n_components=1)},
+    PLSCanonical: {"check_dict_unchanged": dict(max_iter=5, n_components=1)},
+    PLSRegression: {"check_dict_unchanged": dict(max_iter=5, n_components=1)},
+    PLSSVD: {"check_dict_unchanged": dict(n_components=1)},
+    PolynomialCountSketch: {"check_dict_unchanged": dict(n_components=1)},
+    RBFSampler: {"check_dict_unchanged": dict(n_components=1)},
+    SkewedChi2Sampler: {"check_dict_unchanged": dict(n_components=1)},
+    SparsePCA: {"check_dict_unchanged": dict(max_iter=5, n_components=1)},
+    SparseRandomProjection: {"check_dict_unchanged": dict(n_components=1)},
+    SpectralBiclustering: {
+        "check_dict_unchanged": dict(n_best=1, n_clusters=1, n_components=1, n_init=2)
+    },
+    SpectralClustering: {
+        "check_dict_unchanged": dict(n_clusters=1, n_components=1, n_init=2)
+    },
+    SpectralCoclustering: {"check_dict_unchanged": dict(n_clusters=1, n_init=2)},
+    SpectralEmbedding: {"check_dict_unchanged": dict(eigen_tol=1e-05, n_components=1)},
+    TSNE: {"check_dict_unchanged": dict(n_components=1, perplexity=2)},
+    TruncatedSVD: {"check_dict_unchanged": dict(n_components=1)},
+}
+
 
 def _tested_estimators(type_filter=None):
     for name, Estimator in all_estimators(type_filter=type_filter):
@@ -527,3 +614,38 @@ def _get_check_estimator_ids(obj):
     if hasattr(obj, "get_params"):
         with config_context(print_changed_only=True):
             return re.sub(r"\s", "", str(obj))
+
+
+def _yield_instances_for_check(check, estimator_orig):
+    """Yield instances for a check.
+
+    For most estimators, this is a no-op.
+
+    For estimators which have an entry in PER_ESTIMATOR_CHECK_PARAMS, this will yield
+    an estimator for each parameter set in PER_ESTIMATOR_CHECK_PARAMS[estimator].
+    """
+    # TODO(devtools): enable this behavior for third party estimators as well
+    if type(estimator_orig) not in PER_ESTIMATOR_CHECK_PARAMS:
+        yield estimator_orig
+        return
+
+    check_params = PER_ESTIMATOR_CHECK_PARAMS[type(estimator_orig)]
+
+    try:
+        check_name = check.__name__
+    except AttributeError:
+        # partial tests
+        check_name = check.func.__name__
+
+    if check_name not in check_params:
+        yield estimator_orig
+        return
+
+    param_set = check_params[check_name]
+    if isinstance(param_set, dict):
+        param_set = [param_set]
+
+    for params in param_set:
+        estimator = clone(estimator_orig)
+        estimator.set_params(**params)
+        yield estimator
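
Behaviourally, the new helper yields the original instance untouched for estimators without an entry, and one configured clone per parameter set otherwise. A rough usage sketch (the import path is the private test-support module modified above, shown for illustration only rather than as a public API):

from functools import partial

from sklearn.cluster import KMeans
from sklearn.utils._test_common.instance_generator import _yield_instances_for_check
from sklearn.utils.estimator_checks import check_dict_unchanged

# check_dict_unchanged is registered for KMeans with
# dict(max_iter=5, n_clusters=1, n_init=2), so the helper yields a clone
# carrying those parameters instead of the default instance.
check = partial(check_dict_unchanged, "KMeans")
for estimator in _yield_instances_for_check(check, KMeans()):
    print(estimator.n_clusters, estimator.max_iter, estimator.n_init)  # 1 5 2

Configuring instances in the generator replaces the hasattr-based mutation of n_components, n_clusters and n_best that was previously hard-coded inside check_dict_unchanged (removed in the third file below).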

sklearn/utils/estimator_checks.py

Lines changed: 13 additions & 24 deletions
@@ -59,6 +59,7 @@
 from ._test_common.instance_generator import (
     CROSS_DECOMPOSITION,
     _get_check_estimator_ids,
+    _yield_instances_for_check,
 )
 from ._testing import (
     SkipTest,
@@ -81,7 +82,6 @@
 
 
 def _yield_api_checks(estimator):
-    yield check_estimator_cloneable
     yield check_estimator_repr
     yield check_no_attributes_set_in_init
     yield check_fit_score_takes_y
@@ -509,10 +509,14 @@ def parametrize_with_checks(estimators, *, legacy=True):
 
     def checks_generator():
         for estimator in estimators:
+            # First check that the estimator is cloneable which is needed for the rest
+            # of the checks to run
             name = type(estimator).__name__
+            yield estimator, partial(check_estimator_cloneable, name)
             for check in _yield_all_checks(estimator, legacy=legacy):
-                check = partial(check, name)
-                yield _maybe_mark_xfail(estimator, check, pytest)
+                check_with_name = partial(check, name)
+                for check_instance in _yield_instances_for_check(check, estimator):
+                    yield _maybe_mark_xfail(check_instance, check_with_name, pytest)
 
     return pytest.mark.parametrize(
         "estimator, check", checks_generator(), ids=_get_check_estimator_ids
@@ -597,9 +601,13 @@ def check_estimator(estimator=None, generate_only=False, *, legacy=True):
     name = type(estimator).__name__
 
     def checks_generator():
+        # we first need to check if the estimator is cloneable for the rest of the tests
+        # to run
+        yield estimator, partial(check_estimator_cloneable, name)
        for check in _yield_all_checks(estimator, legacy=legacy):
             check = _maybe_skip(estimator, check)
-            yield estimator, partial(check, name)
+            for check_instance in _yield_instances_for_check(check, estimator):
+                yield check_instance, partial(check, name)
 
     if generate_only:
         return checks_generator()
@@ -1257,32 +1265,13 @@ def check_complex_data(name, estimator_orig):
 
 @ignore_warnings
 def check_dict_unchanged(name, estimator_orig):
-    # this estimator raises
-    # ValueError: Found array with 0 feature(s) (shape=(23, 0))
-    # while a minimum of 1 is required.
-    # error
-    if name in ["SpectralCoclustering"]:
-        return
     rnd = np.random.RandomState(0)
-    if name in ["RANSACRegressor"]:
-        X = 3 * rnd.uniform(size=(20, 3))
-    else:
-        X = 2 * rnd.uniform(size=(20, 3))
-
+    X = 3 * rnd.uniform(size=(20, 3))
     X = _enforce_estimator_tags_X(estimator_orig, X)
 
     y = X[:, 0].astype(int)
     estimator = clone(estimator_orig)
     y = _enforce_estimator_tags_y(estimator, y)
-    if hasattr(estimator, "n_components"):
-        estimator.n_components = 1
-
-    if hasattr(estimator, "n_clusters"):
-        estimator.n_clusters = 1
-
-    if hasattr(estimator, "n_best"):
-        estimator.n_best = 1
-
     set_random_state(estimator, 1)
 
     estimator.fit(X, y)
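
For users of the public entry point nothing changes in how tests are written; a typical pytest module still looks like the sketch below, while each estimator is now additionally checked for cloneability first and, where PER_ESTIMATOR_CHECK_PARAMS applies, re-instantiated with check-specific parameters:

from sklearn.cluster import KMeans
from sklearn.decomposition import PCA
from sklearn.utils.estimator_checks import parametrize_with_checks


# parametrize_with_checks expands into one pytest case per (estimator, check)
# pair; the cloneability check is yielded first for every estimator.
@parametrize_with_checks([KMeans(), PCA()])
def test_sklearn_compatible_estimator(estimator, check):
    check(estimator)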

0 commit comments