8000 MAINT Parameters validation for `SpectralEmbedding` (#24103) · scikit-learn/scikit-learn@5944077 · GitHub
[go: up one dir, main page]

Skip to content

Commit 5944077

Browse files
MAINT Parameters validation for SpectralEmbedding (#24103)
Co-authored-by: jeremie du boisberranger <jeremiedbb@yahoo.fr>
1 parent 42fa09e commit 5944077

File tree

3 files changed

+24
-30
lines changed

3 files changed

+24
-30
lines changed

sklearn/manifold/_spectral_embedding.py

Lines changed: 24 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
# License: BSD 3 clause
66

77

8+
from numbers import Integral, Real
89
import warnings
910

1011
import numpy as np
@@ -22,6 +23,7 @@
2223
)
2324
from ..utils._arpack import _init_arpack_v0
2425
from ..utils.extmath import _deterministic_vector_sign_flip
26+
from ..utils._param_validation import Interval, StrOptions
2527
from ..utils.fixes import lobpcg
2628
from ..metrics.pairwise import rbf_kernel
2729
from ..neighbors import kneighbors_graph, NearestNeighbors
@@ -542,6 +544,27 @@ class SpectralEmbedding(BaseEstimator):
542544
(100, 2)
543545
"""
544546

547+
_parameter_constraints: dict = {
548+
"n_components": [Interval(Integral, 1, None, closed="left")],
549+
"affinity": [
550+
StrOptions(
551+
{
552+
"nearest_neighbors",
553+
"rbf",
554+
"precomputed",
555+
"precomputed_nearest_neighbors",
556+
},
557+
),
558+
callable,
559+
],
560+
"gamma": [Interval(Real, 0, None, closed="left"), None],
561+
"random_state": ["random_state"],
562+
"eigen_solver": [StrOptions({"arpack", "lobpcg", "amg"}), None],
563+
"eigen_tol": [Interval(Real, 0, None, closed="left"), StrOptions({"auto"})],
564+
"n_neighbors": [Interval(Integral, 1, None, closed="left"), None],
565+
"n_jobs": [None, Integral],
566+
}
567+
545568
def __init__(
546569
self,
547570
n_components=2,
@@ -649,28 +672,11 @@ def fit(self, X, y=None):
649672
self : object
650673
Returns the instance itself.
651674
"""
675+
self._validate_params()
652676

653677
X = self._validate_data(X, accept_sparse="csr", ensure_min_samples=2)
654678

655679
random_state = check_random_state(self.random_state)
656-
if isinstance(self.affinity, str):
657-
if self.affinity not in {
658-
"nearest_neighbors",
659-
"rbf",
660-
"precomputed",
661-
"precomputed_nearest_neighbors",
662-
}:
663-
raise ValueError(
664-
"%s is not a valid affinity. Expected "
665-
"'precomputed', 'rbf', 'nearest_neighbors' "
666-
"or a callable."
667-
% self.affinity
668-
)
669-
elif not callable(self.affinity):
670-
raise ValueError(
671-
"'affinity' is expected to be an affinity name or a callable. Got: %s"
672-
% self.affinity
673-
)
674680

675681
affinity_matrix = self._get_affinity_matrix(X)
676682
self.embedding_ = spectral_embedding(

sklearn/manifold/tests/test_spectral_embedding.py

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -351,17 +351,6 @@ def test_spectral_embedding_unknown_eigensolver(seed=36):
351351
se.fit(S)
352352

353353

354-
def test_spectral_embedding_unknown_affinity(seed=36):
355-
# Test that SpectralClustering fails with an unknown affinity type
356-
se = SpectralEmbedding(
357-
n_components=1,
358-
affinity="<unknown>",
359-
random_state=np.random.RandomState(seed),
360-
)
361-
with pytest.raises(ValueError):
362-
se.fit(S)
363-
364-
365354
def test_connectivity(seed=36):
366355
# Test that graph connectivity test works as expected
367356
graph = np.array(

sklearn/tests/test_common.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -484,7 +484,6 @@ def test_estimators_do_not_raise_errors_in_init_or_set_params(Estimator):
484484
"SelectFromModel",
485485
"SpectralBiclustering",
486486
"SpectralCoclustering",
487-
"SpectralEmbedding",
488487
]
489488

490489

0 commit comments

Comments
 (0)
0