|
5 | 5 | # License: BSD 3 clause
|
6 | 6 |
|
7 | 7 |
|
| 8 | +from numbers import Integral, Real |
8 | 9 | import warnings
|
9 | 10 |
|
10 | 11 | import numpy as np
|
|
22 | 23 | )
|
23 | 24 | from ..utils._arpack import _init_arpack_v0
|
24 | 25 | from ..utils.extmath import _deterministic_vector_sign_flip
|
| 26 | +from ..utils._param_validation import Interval, StrOptions |
25 | 27 | from ..utils.fixes import lobpcg
|
26 | 28 | from ..metrics.pairwise import rbf_kernel
|
27 | 29 | from ..neighbors import kneighbors_graph, NearestNeighbors
|
@@ -542,6 +544,27 @@ class SpectralEmbedding(BaseEstimator):
|
542 | 544 | (100, 2)
|
543 | 545 | """
|
544 | 546 |
|
| 547 | + _parameter_constraints: dict = { |
| 548 | + "n_components": [Interval(Integral, 1, None, closed="left")], |
| 549 | + "affinity": [ |
| 550 | + StrOptions( |
| 551 | + { |
| 552 | + "nearest_neighbors", |
| 553 | + "rbf", |
| 554 | + "precomputed", |
| 555 | + "precomputed_nearest_neighbors", |
| 556 | + }, |
| 557 | + ), |
| 558 | + callable, |
| 559 | + ], |
| 560 | + "gamma": [Interval(Real, 0, None, closed="left"), None], |
| 561 | + "random_state": ["random_state"], |
| 562 | + "eigen_solver": [StrOptions({"arpack", "lobpcg", "amg"}), None], |
| 563 | + "eigen_tol": [Interval(Real, 0, None, closed="left"), StrOptions({"auto"})], |
| 564 | + "n_neighbors": [Interval(Integral, 1, None, closed="left"), None], |
| 565 | + "n_jobs": [None, Integral], |
| 566 | + } |
| 567 | + |
545 | 568 | def __init__(
|
546 | 569 | self,
|
547 | 570 | n_components=2,
|
@@ -649,28 +672,11 @@ def fit(self, X, y=None):
|
649 | 672 | self : object
|
650 | 673 | Returns the instance itself.
|
651 | 674 | """
|
| 675 | + self._validate_params() |
652 | 676 |
|
653 | 677 | X = self._validate_data(X, accept_sparse="csr", ensure_min_samples=2)
|
654 | 678 |
|
655 | 679 | random_state = check_random_state(self.random_state)
|
656 |
| - if isinstance(self.affinity, str): |
657 |
| - if self.affinity not in { |
658 |
| - "nearest_neighbors", |
659 |
| - "rbf", |
660 |
| - "precomputed", |
661 |
| - "precomputed_nearest_neighbors", |
662 |
| - }: |
663 |
| - raise ValueError( |
664 |
| - "%s is not a valid affinity. Expected " |
665 |
| - "'precomputed', 'rbf', 'nearest_neighbors' " |
666 |
| - "or a callable." |
667 |
| - % self.affinity |
668 |
| - ) |
669 |
| - elif not callable(self.affinity): |
670 |
| - raise ValueError( |
671 |
| - "'affinity' is expected to be an affinity name or a callable. Got: %s" |
672 |
| - % self.affinity |
673 |
| - ) |
674 | 680 |
|
675 | 681 | affinity_matrix = self._get_affinity_matrix(X)
|
676 | 682 | self.embedding_ = spectral_embedding(
|
|
0 commit comments