|
7 | 7 | from functools import partial
|
8 | 8 | from inspect import isfunction
|
9 | 9 |
|
10 |
| -from sklearn import config_context |
| 10 | +from sklearn import clone, config_context |
11 | 11 | from sklearn.calibration import CalibratedClassifierCV
|
12 | 12 | from sklearn.cluster import (
|
13 | 13 | HDBSCAN,
|
|
33 | 33 | FactorAnalysis,
|
34 | 34 | FastICA,
|
35 | 35 | IncrementalPCA,
|
| 36 | + KernelPCA, |
36 | 37 | LatentDirichletAllocation,
|
37 | 38 | MiniBatchDictionaryLearning,
|
38 | 39 | MiniBatchNMF,
|
|
41 | 42 | SparsePCA,
|
42 | 43 | TruncatedSVD,
|
43 | 44 | )
|
| 45 | +from sklearn.discriminant_analysis import LinearDiscriminantAnalysis |
44 | 46 | from sklearn.dummy import DummyClassifier
|
45 | 47 | from sklearn.ensemble import (
|
46 | 48 | AdaBoostClassifier,
|
|
72 | 74 | SelectKBest,
|
73 | 75 | SequentialFeatureSelector,
|
74 | 76 | )
|
| 77 | +from sklearn.kernel_approximation import ( |
| 78 | + Nystroem, |
| 79 | + PolynomialCountSketch, |
| 80 | + RBFSampler, |
| 81 | + SkewedChi2Sampler, |
| 82 | +) |
75 | 83 | from sklearn.linear_model import (
|
76 | 84 | ARDRegression,
|
77 | 85 | BayesianRidge,
|
|
105 | 113 | TheilSenRegressor,
|
106 | 114 | TweedieRegressor,
|
107 | 115 | )
|
108 |
| -from sklearn.manifold import MDS, TSNE, LocallyLinearEmbedding, SpectralEmbedding |
| 116 | +from sklearn.manifold import ( |
| 117 | + MDS, |
| 118 | + TSNE, |
| 119 | + Isomap, |
| 120 | + LocallyLinearEmbedding, |
| 121 | + SpectralEmbedding, |
| 122 | +) |
109 | 123 | from sklearn.mixture import BayesianGaussianMixture, GaussianMixture
|
110 | 124 | from sklearn.model_selection import (
|
111 | 125 | FixedThresholdClassifier,
|
|
457 | 471 | ),
|
458 | 472 | }
|
459 | 473 |
|
# Check-specific constructor parameters, keyed first by estimator class and then
# by the name of the check function. The value stored for a check is either one
# dict of parameters or a list of such dicts; with a list, the same check is run
# once per parameter set, each time on a fresh clone of the estimator.
# The special key "*" is intended to apply the parameters to all checks.
# TODO(devtools): allow third-party developers to pass test specific params to checks
PER_ESTIMATOR_CHECK_PARAMS: dict = {
    # TODO(devtools): check that function names here exist in checks for the estimator
    # TODO(devtools): write a test for the same thing with tags._xfail_checks
    AgglomerativeClustering: {"check_dict_unchanged": {"n_clusters": 1}},
    BayesianGaussianMixture: {"check_dict_unchanged": {"max_iter": 5, "n_init": 2}},
    BernoulliRBM: {"check_dict_unchanged": {"n_components": 1, "n_iter": 5}},
    Birch: {"check_dict_unchanged": {"n_clusters": 1}},
    BisectingKMeans: {
        "check_dict_unchanged": {"max_iter": 5, "n_clusters": 1, "n_init": 2}
    },
    CCA: {"check_dict_unchanged": {"max_iter": 5, "n_components": 1}},
    DictionaryLearning: {
        "check_dict_unchanged": {
            "max_iter": 20,
            "n_components": 1,
            "transform_algorithm": "lasso_lars",
        }
    },
    FactorAnalysis: {"check_dict_unchanged": {"max_iter": 5, "n_components": 1}},
    FastICA: {"check_dict_unchanged": {"max_iter": 5, "n_components": 1}},
    FeatureAgglomeration: {"check_dict_unchanged": {"n_clusters": 1}},
    GaussianMixture: {"check_dict_unchanged": {"max_iter": 5, "n_init": 2}},
    GaussianRandomProjection: {"check_dict_unchanged": {"n_components": 1}},
    IncrementalPCA: {"check_dict_unchanged": {"batch_size": 10, "n_components": 1}},
    Isomap: {"check_dict_unchanged": {"n_components": 1}},
    KMeans: {"check_dict_unchanged": {"max_iter": 5, "n_clusters": 1, "n_init": 2}},
    KernelPCA: {"check_dict_unchanged": {"n_components": 1}},
    LatentDirichletAllocation: {
        "check_dict_unchanged": {"batch_size": 10, "max_iter": 5, "n_components": 1}
    },
    LinearDiscriminantAnalysis: {"check_dict_unchanged": {"n_components": 1}},
    LocallyLinearEmbedding: {
        "check_dict_unchanged": {"max_iter": 5, "n_components": 1}
    },
    MDS: {"check_dict_unchanged": {"max_iter": 5, "n_components": 1, "n_init": 2}},
    MiniBatchDictionaryLearning: {
        "check_dict_unchanged": {"batch_size": 10, "max_iter": 5, "n_components": 1}
    },
    MiniBatchKMeans: {
        "check_dict_unchanged": {
            "batch_size": 10,
            "max_iter": 5,
            "n_clusters": 1,
            "n_init": 2,
        }
    },
    MiniBatchNMF: {
        "check_dict_unchanged": {
            "batch_size": 10,
            "fresh_restarts": True,
            "max_iter": 20,
            "n_components": 1,
        }
    },
    MiniBatchSparsePCA: {
        "check_dict_unchanged": {"batch_size": 10, "max_iter": 5, "n_components": 1}
    },
    NMF: {"check_dict_unchanged": {"max_iter": 500, "n_components": 1}},
    NeighborhoodComponentsAnalysis: {
        "check_dict_unchanged": {"max_iter": 5, "n_components": 1}
    },
    Nystroem: {"check_dict_unchanged": {"n_components": 1}},
    PCA: {"check_dict_unchanged": {"n_components": 1}},
    PLSCanonical: {"check_dict_unchanged": {"max_iter": 5, "n_components": 1}},
    PLSRegression: {"check_dict_unchanged": {"max_iter": 5, "n_components": 1}},
    PLSSVD: {"check_dict_unchanged": {"n_components": 1}},
    PolynomialCountSketch: {"check_dict_unchanged": {"n_components": 1}},
    RBFSampler: {"check_dict_unchanged": {"n_components": 1}},
    SkewedChi2Sampler: {"check_dict_unchanged": {"n_components": 1}},
    SparsePCA: {"check_dict_unchanged": {"max_iter": 5, "n_components": 1}},
    SparseRandomProjection: {"check_dict_unchanged": {"n_components": 1}},
    SpectralBiclustering: {
        "check_dict_unchanged": {
            "n_best": 1,
            "n_clusters": 1,
            "n_components": 1,
            "n_init": 2,
        }
    },
    SpectralClustering: {
        "check_dict_unchanged": {"n_clusters": 1, "n_components": 1, "n_init": 2}
    },
    SpectralCoclustering: {"check_dict_unchanged": {"n_clusters": 1, "n_init": 2}},
    SpectralEmbedding: {
        "check_dict_unchanged": {"eigen_tol": 1e-05, "n_components": 1}
    },
    TSNE: {"check_dict_unchanged": {"n_components": 1, "perplexity": 2}},
    TruncatedSVD: {"check_dict_unchanged": {"n_components": 1}},
}
| 546 | + |
460 | 547 |
|
461 | 548 | def _tested_estimators(type_filter=None):
|
462 | 549 | for name, Estimator in all_estimators(type_filter=type_filter):
|
@@ -527,3 +614,38 @@ def _get_check_estimator_ids(obj):
|
527 | 614 | if hasattr(obj, "get_params"):
|
528 | 615 | with config_context(print_changed_only=True):
|
529 | 616 | return re.sub(r"\s", "", str(obj))
|
| 617 | + |
| 618 | + |
def _yield_instances_for_check(check, estimator_orig):
    """Yield the estimator instances a check should be run against.

    For most estimators, this is a no-op: ``estimator_orig`` itself is yielded
    unchanged.

    For estimators which have an entry in PER_ESTIMATOR_CHECK_PARAMS, the
    parameter sets registered under the check's function name are used. When
    there is no entry for that name, the special key ``"*"`` (parameters for
    all checks) is used instead, if present. A check-specific entry therefore
    takes precedence over ``"*"``. One clone of ``estimator_orig`` is yielded
    per parameter set, with those parameters applied through ``set_params``.

    Parameters
    ----------
    check : callable
        The check function, or a wrapper around it (such as a
        ``functools.partial``) exposing the wrapped callable as ``func``.

    estimator_orig : estimator instance
        The estimator the check was requested for. It is never modified; only
        clones receive the registered parameters.

    Yields
    ------
    estimator : estimator instance
        Either ``estimator_orig`` itself (no registered parameters apply) or a
        configured clone for each registered parameter set.

    Raises
    ------
    AttributeError
        If ``check`` has neither a ``__name__`` nor a ``func`` attribute.
    """
    # TODO(devtools): enable this behavior for third party estimators as well
    check_params = PER_ESTIMATOR_CHECK_PARAMS.get(type(estimator_orig))
    if check_params is None:
        yield estimator_orig
        return

    # Partial tests carry no ``__name__``; follow ``func`` down to the wrapped
    # check function to recover the name used as key in the registry.
    named_check = check
    while not hasattr(named_check, "__name__"):
        named_check = named_check.func
    check_name = named_check.__name__

    if check_name in check_params:
        param_set = check_params[check_name]
    elif "*" in check_params:
        # Wildcard entry: parameters registered for every check.
        param_set = check_params["*"]
    else:
        yield estimator_orig
        return

    # A single dict is shorthand for a one-element list of parameter sets.
    if isinstance(param_set, dict):
        param_set = [param_set]

    for params in param_set:
        estimator = clone(estimator_orig)
        estimator.set_params(**params)
        yield estimator
0 commit comments