38
38
from .base import ClassNamePrefixFeaturesOutMixin
39
39
40
40
from .utils import check_random_state
41
- from .utils ._param_validation import Interval , StrOptions
41
+ from .utils ._param_validation import Interval , StrOptions , validate_params
42
42
from .utils .extmath import safe_sparse_dot
43
43
from .utils .random import sample_without_replacement
44
44
from .utils .validation import check_array , check_is_fitted
51
51
]
52
52
53
53
54
+ @validate_params (
55
+ {
56
+ "n_samples" : ["array-like" , Interval (Real , 1 , None , closed = "left" )],
57
+ "eps" : ["array-like" , Interval (Real , 0 , 1 , closed = "neither" )],
58
+ }
59
+ )
54
60
def johnson_lindenstrauss_min_dim (n_samples , * , eps = 0.1 ):
55
61
"""Find a 'safe' number of components to randomly project to.
56
62
57
63
The distortion introduced by a random projection `p` only changes the
58
- distance between two points by a factor (1 +- eps) in an euclidean space
64
+ distance between two points by a factor (1 +- eps) in a euclidean space
59
65
with good probability. The projection `p` is an eps-embedding as defined
60
66
by:
61
67
@@ -81,12 +87,12 @@ def johnson_lindenstrauss_min_dim(n_samples, *, eps=0.1):
81
87
Parameters
82
88
----------
83
89
n_samples : int or array-like of int
84
- Number of samples that should be a integer greater than 0. If an array
90
+ Number of samples that should be an integer greater than 0. If an array
85
91
is given, it will compute a safe number of components array-wise.
86
92
87
- eps : float or ndarray of shape (n_components,), dtype=float, \
93
+ eps : float or array-like of shape (n_components,), dtype=float, \
88
94
default=0.1
89
- Maximum distortion rate in the range (0,1 ) as defined by the
95
+ Maximum distortion rate in the range (0, 1 ) as defined by the
90
96
Johnson-Lindenstrauss lemma. If an array is given, it will compute a
91
97
safe number of components array-wise.
92
98
@@ -123,7 +129,7 @@ def johnson_lindenstrauss_min_dim(n_samples, *, eps=0.1):
123
129
if np .any (eps <= 0.0 ) or np .any (eps >= 1 ):
124
130
raise ValueError ("The JL bound is defined for eps in ]0, 1[, got %r" % eps )
125
131
126
- if np .any (n_samples ) <= 0 :
132
+ if np .any (n_samples <= 0 ) :
127
133
raise ValueError (
128
134
"The JL bound is defined for n_samples greater than zero, got %r"
129
135
% n_samples
0 commit comments