8000 MAINT Parameters validation for `sklearn.manifold.spectral_embedding… · jeremiedbb/scikit-learn@a25e3ef · GitHub
[go: up one dir, main page]

Skip to content

Commit a25e3ef

Browse files
MAINT Parameters validation for sklearn.manifold.spectral_embedding (scikit-learn#25579)
Co-authored-by: jeremie du boisberranger <jeremiedbb@yahoo.fr>
1 parent 24452ef commit a25e3ef

File tree

3 files changed

+40
-11
lines changed

3 files changed

+40
-11
lines changed

sklearn/cluster/_spectral.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from scipy.sparse import csc_matrix
1515

1616
from ..base import BaseEstimator, ClusterMixin, _fit_context
17-
from ..manifold import spectral_embedding
17+
from ..manifold._spectral_embedding import _spectral_embedding
1818
from ..metrics.pairwise import KERNEL_PARAMS, pairwise_kernels
1919
from ..neighbors import NearestNeighbors, kneighbors_graph
2020
from ..utils import as_float_array, check_random_state
@@ -741,7 +741,7 @@ def fit(self, X, y=None):
741741
# The first eigenvector is constant only for fully connected graphs
742742
# and should be kept for spectral clustering (drop_first = False)
743743
# See spectral_embedding documentation.
744-
maps = spectral_embedding(
744+
maps = _spectral_embedding(
745745
self.affinity_matrix_,
746746
n_components=n_components,
747747
eigen_solver=self.eigen_solver,

sklearn/manifold/_spectral_embedding.py

Lines changed: 37 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
check_symmetric,
2424
)
2525
from ..utils._arpack import _init_arpack_v0
26-
from ..utils._param_validation import Interval, StrOptions
26+
from ..utils._param_validation import Interval, StrOptions, validate_params
2727
from ..utils.extmath import _deterministic_vector_sign_flip
2828
from ..utils.fixes import laplacian as csgraph_laplacian
2929
from ..utils.fixes import parse_version, sp_version
@@ -152,6 +152,18 @@ def _set_diag(laplacian, value, norm_laplacian):
152152
return laplacian
153153

154154

155+
@validate_params(
156+
{
157+
"adjacency": ["array-like", "sparse matrix"],
158+
"n_components": [Interval(Integral, 1, None, closed="left")],
159+
"eigen_solver": [StrOptions({"arpack", "lobpcg", "amg"}), None],
160+
"random_state": ["random_state"],
161+
"eigen_tol": [Interval(Real, 0, None, closed="left"), StrOptions({"auto"})],
162+
"norm_laplacian": ["boolean"],
163+
"drop_first": ["boolean"],
164+
},
165+
prefer_skip_nested_validation=True,
166+
)
155167
def spectral_embedding(
156168
adjacency,
157169
*,
@@ -272,6 +284,29 @@ def spectral_embedding(
272284
>>> embedding.shape
273285
(100, 2)
274286
"""
287+
random_state = check_random_state(random_state)
288+
289+
return _spectral_embedding(
290+
adjacency,
291+
n_components=n_components,
292+
eigen_solver=eigen_solver,
293+
random_state=random_state,
294+
eigen_tol=eigen_tol,
295+
norm_laplacian=norm_laplacian,
296+
drop_first=drop_first,
297+
)
298+
299+
300+
def _spectral_embedding(
301+
adjacency,
302+
*,
303+
n_components=8,
304+
eigen_solver=None,
305+
random_state=None,
306+
eigen_tol="auto",
307+
norm_laplacian=True,
308+
drop_first=True,
309+
):
275310
adjacency = check_symmetric(adjacency)
276311

277312
if eigen_solver == "amg":
@@ -284,13 +319,6 @@ def spectral_embedding(
284319

285320
if eigen_solver is None:
286321
eigen_solver = "arpack"
287-
elif eigen_solver not in ("arpack", "lobpcg", "amg"):
288-
raise ValueError(
289-
"Unknown value for eigen_solver: '%s'."
290-
"Should be 'amg', 'arpack', or 'lobpcg'" % eigen_solver
291-
)
292-
293-
random_state = check_random_state(random_state)
294322

295323
n_nodes = adjacency.shape[0]
296324
# Whether to drop the first eigenvector
@@ -714,7 +742,7 @@ def fit(self, X, y=None):
714742
random_state = check_random_state(self.random_state)
715743

716744
affinity_matrix = self._get_affinity_matrix(X)
717-
self.embedding_ = spectral_embedding(
745+
self.embedding_ = _spectral_embedding(
718746
affinity_matrix,
719747
n_components=self.n_components,
720748
eigen_solver=self.eigen_solver,

sklearn/tests/test_public_functions.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,7 @@ def _check_function_param_validation(
209209
"sklearn.linear_model.ridge_regression",
210210
"sklearn.manifold.locally_linear_embedding",
211211
"sklearn.manifold.smacof",
212+
"sklearn.manifold.spectral_embedding",
212213
"sklearn.manifold.trustworthiness",
213214
"sklearn.metrics.accuracy_score",
214215
"sklearn.metrics.auc",

0 commit comments

Comments
 (0)
0