10000 MAINT Add parameter validation to locally_linear_embedding (#25581) · jeremiedbb/scikit-learn@24452ef · GitHub
[go: up one dir, main page]

Skip to content

Commit 24452ef

Browse files
MAINT Add parameter validation to locally_linear_embedding (scikit-learn#25581)
Co-authored-by: jeremie du boisberranger <jeremiedbb@yahoo.fr>
1 parent a45d106 commit 24452ef

File tree

2 files changed

+159
-119
lines changed

2 files changed

+159
-119
lines changed

sklearn/manifold/_locally_linear.py

Lines changed: 157 additions & 118 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from ..neighbors import NearestNeighbors
2222
from ..utils import check_array, check_random_state
2323
from ..utils._arpack import _init_arpack_v0
24-
from ..utils._param_validation import Interval, StrOptions
24+
from ..utils._param_validation import Interval, StrOptions, validate_params
2525
from ..utils.extmath import stable_cumsum
2626
from ..utils.validation import FLOAT_DTYPES, check_is_fitted
2727

@@ -198,7 +198,7 @@ def null_space(
198198
raise ValueError("Unrecognized eigen_solver '%s'" % eigen_solver)
199199

200200

201-
def locally_linear_embedding(
201+
def _locally_linear_embedding(
202202
X,
203203
*,
204204
n_neighbors,
@@ -213,118 +213,6 @@ def locally_linear_embedding(
213213
random_state=None,
214214
8000 n_jobs=None,
215215
):
216-
"""Perform a Locally Linear Embedding analysis on the data.
217-
218-
Read more in the :ref:`User Guide <locally_linear_embedding>`.
219-
220-
Parameters
221-
----------
222-
X : {array-like, NearestNeighbors}
223-
Sample data, shape = (n_samples, n_features), in the form of a
224-
numpy array or a NearestNeighbors object.
225-
226-
n_neighbors : int
227-
Number of neighbors to consider for each point.
228-
229-
n_components : int
230-
Number of coordinates for the manifold.
231-
232-
reg : float, default=1e-3
233-
Regularization constant, multiplies the trace of the local covariance
234-
matrix of the distances.
235-
236-
eigen_solver : {'auto', 'arpack', 'dense'}, default='auto'
237-
auto : algorithm will attempt to choose the best method for input data
238-
239-
arpack : use arnoldi iteration in shift-invert mode.
240-
For this method, M may be a dense matrix, sparse matrix,
241-
or general linear operator.
242-
Warning: ARPACK can be unstable for some problems. It is
243-
best to try several random seeds in order to check results.
244-
245-
dense : use standard dense matrix operations for the eigenvalue
246-
decomposition. For this method, M must be an array
247-
or matrix type. This method should be avoided for
248-
large problems.
249-
250-
tol : float, default=1e-6
251-
Tolerance for 'arpack' method
252-
Not used if eigen_solver=='dense'.
253-
254-
max_iter : int, default=100
255-
Maximum number of iterations for the arpack solver.
256-
257-
method : {'standard', 'hessian', 'modified', 'ltsa'}, default='standard'
258-
standard : use the standard locally linear embedding algorithm.
259-
see reference [1]_
260-
hessian : use the Hessian eigenmap method. This method requires
261-
n_neighbors > n_components * (1 + (n_components + 1) / 2.
262-
see reference [2]_
263-
modified : use the modified locally linear embedding algorithm.
264-
see reference [3]_
265-
ltsa : use local tangent space alignment algorithm
266-
see reference [4]_
267-
268-
hessian_tol : float, default=1e-4
269-
Tolerance for Hessian eigenmapping method.
270-
Only used if method == 'hessian'.
271-
272-
modified_tol : float, default=1e-12
273-
Tolerance for modified LLE method.
274-
Only used if method == 'modified'.
275-
276-
random_state : int, RandomState instance, default=None
277-
Determines the random number generator when ``solver`` == 'arpack'.
278-
Pass an int for reproducible results across multiple function calls.
279-
See :term:`Glossary <random_state>`.
280-
281-
n_jobs : int or None, default=None
282-
The number of parallel jobs to run for neighbors search.
283-
``None`` means 1 unless in a :obj:`joblib.parallel_backend` context.
284-
``-1`` means using all processors. See :term:`Glossary <n_jobs>`
285-
for more details.
286-
287-
Returns
288-
-------
289-
Y : array-like, shape [n_samples, n_components]
290-
Embedding vectors.
291-
292-
squared_error : float
293-
Reconstruction error for the embedding vectors. Equivalent to
294-
``norm(Y - W Y, 'fro')**2``, where W are the reconstruction weights.
295-
296-
References
297-
----------
298-
299-
.. [1] Roweis, S. & Saul, L. Nonlinear dimensionality reduction
300-
by locally linear embedding. Science 290:2323 (2000).
301-
.. [2] Donoho, D. & Grimes, C. Hessian eigenmaps: Locally
302-
linear embedding techniques for high-dimensional data.
303-
Proc Natl Acad Sci U S A. 100:5591 (2003).
304-
.. [3] `Zhang, Z. & Wang, J. MLLE: Modified Locally Linear
305-
Embedding Using Multiple Weights.
306-
<https://citeseerx.ist.psu.edu/doc_view/pid/0b060fdbd92cbcc66b383bcaa9ba5e5e624d7ee3>`_
307-
.. [4] Zhang, Z. & Zha, H. Principal manifolds and nonlinear
308-
dimensionality reduction via tangent space alignment.
309-
Journal of Shanghai Univ. 8:406 (2004)
310-
311-
Examples
312-
--------
313-
>>> from sklearn.datasets import load_digits
314-
>>> from sklearn.manifold import locally_linear_embedding
315-
>>> X, _ = load_digits(return_X_y=True)
316-
>>> X.shape
317-
(1797, 64)
318-
>>> embedding, _ = locally_linear_embedding(X[:100],n_neighbors=5, n_components=2)
319-
>>> embedding.shape
320-
(100, 2)
321-
"""
322-
if eigen_solver not in ("auto", "arpack", "dense"):
323-
raise ValueError("unrecognized eigen_solver '%s'" % eigen_solver)
324-
325-
if method not in ("standard", "hessian", "modified", "ltsa"):
326-
raise ValueError("unrecognized method '%s'" % method)
327-
328216
nbrs = NearestNeighbors(n_neighbors=n_neighbors + 1, n_jobs=n_jobs)
329217
nbrs.fit(X)
330218
X = nbrs._fit_X
@@ -341,9 +229,6 @@ def locally_linear_embedding(
341229
% (N, n_neighbors)
342230
)
343231

344-
if n_neighbors <= 0:
345-
raise ValueError("n_neighbors must be positive")
346-
347232
M_sparse = eigen_solver != "dense"
348233

349234
if method == "standard":
@@ -561,6 +446,160 @@ def locally_linear_embedding(
561446
)
562447

563448

449+
@validate_params(
450+
{
451+
"X": ["array-like", NearestNeighbors],
452+
"n_neighbors": [Interval(Integral, 1, None, closed="left")],
453+
"n_components": [Interval(Integral, 1, None, closed="left")],
454+
"reg": [Interval(Real, 0, None, closed="left")],
455+
"eigen_solver": [StrOptions({"auto", "arpack", "dense"})],
456+
"tol": [Interval(Real, 0, None, closed="left")],
457+
"max_iter": [Interval(Integral, 1, None, closed="left")],
458+
"method": [StrOptions({"standard", "hessian", "modified", "ltsa"})],
459+
"hessian_tol": [Interval(Real, 0, None, closed="left")],
460+
"modified_tol": [Interval(Real, 0, None, closed="left")],
461+
"random_state": ["random_state"],
462+
"n_jobs": [None, Integral],
463+
},
464+
prefer_skip_nested_validation=True,
465+
)
466+
def locally_linear_embedding(
467+
X,
468+
*,
469+
n_neighbors,
470+
n_components,
471+
reg=1e-3,
472+
eigen_solver="auto",
473+
tol=1e-6,
474+
max_iter=100,
475+
method="standard",
476+
hessian_tol=1e-4,
477+
modified_tol=1e-12,
478+
random_state=None,
479+
n_jobs=None,
480+
):
481+
"""Perform a Locally Linear Embedding analysis on the data.
482+
483+
Read more in the :ref:`User Guide <locally_linear_embedding>`.
484+
485+
Parameters
486+
----------
487+
X : {array-like, NearestNeighbors}
488+
Sample data, shape = (n_samples, n_features), in the form of a
489+
numpy array or a NearestNeighbors object.
490+
491+
n_neighbors : int
492+
Number of neighbors to consider for each point.
493+
494+
n_components : int
495+
Number of coordinates for the manifold.
496+
497+
reg : float, default=1e-3
498+
Regularization constant, multiplies the trace of the local covariance
499+
matrix of the distances.
500+
501+
eigen_solver : {'auto', 'arpack', 'dense'}, default='auto'
502+
auto : algorithm will attempt to choose the best method for input data
503+
504+
arpack : use arnoldi iteration in shift-invert mode.
505+
For this method, M may be a dense matrix, sparse matrix,
506+
or general linear operator.
507+
Warning: ARPACK can be unstable for some problems. It is
508+
best to try several random seeds in order to check results.
509+
510+
dense : use standard dense matrix operations for the eigenvalue
511+
decomposition. For this method, M must be an array
512+
or matrix type. This method should be avoided for
513+
large problems.
514+
515+
tol : float, default=1e-6
516+
Tolerance for 'arpack' method
517+
Not used if eigen_solver=='dense'.
518+
519+
max_iter : int, default=100
520+
Maximum number of iterations for the arpack solver.
521+
522+
method : {'standard', 'hessian', 'modified', 'ltsa'}, default='standard'
523+
standard : use the standard locally linear embedding algorithm.
524+
see reference [1]_
525+
hessian : use the Hessian eigenmap method. This method requires
526+
n_neighbors > n_components * (1 + (n_components + 1) / 2.
527+
see reference [2]_
528+
modified : use the modified locally linear embedding algorithm.
529+
see reference [3]_
530+
ltsa : use local tangent space alignment algorithm
531+
see reference [4]_
532+
533+
hessian_tol : float, default=1e-4
534+
Tolerance for Hessian eigenmapping method.
535+
Only used if method == 'hessian'.
536+
537+
modified_tol : float, default=1e-12
538+
Tolerance for modified LLE method.
539+
Only used if method == 'modified'.
540+
541+
random_state : int, RandomState instance, default=None
542+
Determines the random number generator when ``solver`` == 'arpack'.
543+
Pass an int for reproducible results across multiple function calls.
544+
See :term:`Glossary <random_state>`.
545+
546+
n_jobs : int or None, default=None
547+
The number of parallel jobs to run for neighbors search.
548+
``None`` means 1 unless in a :obj:`joblib.parallel_backend` context.
549+
``-1`` means using all processors. See :term:`Glossary <n_jobs>`
550+
for more details.
551+
552+
Returns
553+
-------
554+
Y : ndarray of shape (n_samples, n_components)
555+
Embedding vectors.
556+
557+
squared_error : float
558+
Reconstruction error for the embedding vectors. Equivalent to
559+
``norm(Y - W Y, 'fro')**2``, where W are the reconstruction weights.
560+
561+
References
562+
----------
563+
564+
.. [1] Roweis, S. & Saul, L. Nonlinear dimensionality reduction
565+
by locally linear embedding. Science 290:2323 (2000).
566+
.. [2] Donoho, D. & Grimes, C. Hessian eigenmaps: Locally
567+
linear embedding techniques for high-dimensional data.
568+
Proc Natl Acad Sci U S A. 100:5591 (2003).
569+
.. [3] `Zhang, Z. & Wang, J. MLLE: Modified Locally Linear
570+
Embedding Using Multiple Weights.
571+
<https://citeseerx.ist.psu.edu/doc_view/pid/0b060fdbd92cbcc66b383bcaa9ba5e5e624d7ee3>`_
572+
.. [4] Zhang, Z. & Zha, H. Principal manifolds and nonlinear
573+
dimensionality reduction via tangent space alignment.
574+
Journal of Shanghai Univ. 8:406 (2004)
575+
576+
Examples
577+
--------
578+
>>> from sklearn.datasets import load_digits
579+
>>> from sklearn.manifold import locally_linear_embedding
580+
>>> X, _ = load_digits(return_X_y=True)
581+
>>> X.shape
582+
(1797, 64)
583+
>>> embedding, _ = locally_linear_embedding(X[:100],n_neighbors=5, n_components=2)
584+
>>> embedding.shape
585+
(100, 2)
586+
"""
587+
return _locally_linear_embedding(
588+
X=X,
589+
n_neighbors=n_neighbors,
590+
n_components=n_components,
591+
reg=reg,
592+
eigen_solver=eigen_solver,
593+
tol=tol,
594+
max_iter=max_iter,
595+
method=method,
596+
hessian_tol=hessian_tol,
597+
modified_tol=modified_tol,
598+
random_state=random_state,
599+
n_jobs=n_jobs,
600+
)
601+
602+
564603
class LocallyLinearEmbedding(
565604
ClassNamePrefixFeaturesOutMixin,
566605
TransformerMixin,
@@ -753,7 +792,7 @@ def _fit_transform(self, X):
753792
random_state = check_random_state(self.random_state)
754793
X = self._validate_data(X, dtype=float)
755794
self.nbrs_.fit(X)
756-
self.embedding_, self.reconstruction_error_ = locally_linear_embedding(
795+
self.embedding_, self.reconstruction_error_ = _locally_linear_embedding(
757796
X=self.nbrs_,
758797
n_neighbors=self.n_neighbors,
759798
n_components=self.n_components,

sklearn/tests/test_public_functions.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -207,9 +207,10 @@ def _check_function_param_validation(
207207
"sklearn.linear_model.orthogonal_mp",
208208
"sklearn.linear_model.orthogonal_mp_gram",
209209
"sklearn.linear_model.ridge_regression",
210+
"sklearn.manifold.locally_linear_embedding",
211+
"sklearn.manifold.smacof",
210212
"sklearn.manifold.trustworthiness",
211213
"sklearn.metrics.accuracy_score",
212-
"sklearn.manifold.smacof",
213214
"sklearn.metrics.auc",
214215
"sklearn.metrics.average_precision_score",
215216
"sklearn.metrics.balanced_accuracy_score",

0 commit comments

Comments
 (0)
0