8000 MAINT Parameters validation for sklearn.random_projection.johnson_lin… · Anthony22-dev/scikit-learn@38d7fba · GitHub 65E1
[go: up one dir, main page]

Skip to content

Commit 38d7fba

Browse files
authored
MAINT Parameters validation for sklearn.random_projection.johnson_lindenstrauss_min_dim (scikit-learn#25278)
1 parent ba1d23d commit 38d7fba

File tree

3 files changed

+19
-7
lines changed

3 files changed

+19
-7
lines changed

sklearn/random_projection.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838
from .base import ClassNamePrefixFeaturesOutMixin
3939

4040
from .utils import check_random_state
41-
from .utils._param_validation import Interval, StrOptions
41+
from .utils._param_validation import Interval, StrOptions, validate_params
4242
from .utils.extmath import safe_sparse_dot
4343
from .utils.random import sample_without_replacement
4444
from .utils.validation import check_array, check_is_fitted
@@ -51,11 +51,17 @@
5151
]
5252

5353

54+
@validate_params(
55+
{
56+
"n_samples": ["array-like", Interval(Real, 1, None, closed="left")],
57+
"eps": ["array-like", Interval(Real, 0, 1, closed="neither")],
58+
}
59+
)
5460
def johnson_lindenstrauss_min_dim(n_samples, *, eps=0.1):
5561
"""Find a 'safe' number of components to randomly project to.
5662
5763
The distortion introduced by a random projection `p` only changes the
58-
distance between two points by a factor (1 +- eps) in an euclidean space
64+
distance between two points by a factor (1 +- eps) in a euclidean space
5965
with good probability. The projection `p` is an eps-embedding as defined
6066
by:
6167
@@ -81,12 +87,12 @@ def johnson_lindenstrauss_min_dim(n_samples, *, eps=0.1):
8187
Parameters
8288
----------
8389
n_samples : int or array-like of int
84-
Number of samples that should be a integer greater than 0. If an array
90+
Number of samples that should be an integer greater than 0. If an array
8591
is given, it will compute a safe number of components array-wise.
8692
87-
eps : float or ndarray of shape (n_components,), dtype=float, \
93+
eps : float or array-like of shape (n_components,), dtype=float, \
8894
default=0.1
89-
Maximum distortion rate in the range (0,1 ) as defined by the
95+
Maximum distortion rate in the range (0, 1) as defined by the
9096
Johnson-Lindenstrauss lemma. If an array is given, it will compute a
9197
safe number of components array-wise.
9298
@@ -123,7 +129,7 @@ def johnson_lindenstrauss_min_dim(n_samples, *, eps=0.1):
123129
if np.any(eps <= 0.0) or np.any(eps >= 1):
124130
raise ValueError("The JL bound is defined for eps in ]0, 1[, got %r" % eps)
125131

126-
if np.any(n_samples) <= 0:
132+
if np.any(n_samples <= 0):
127133
raise ValueError(
128134
"The JL bound is defined for n_samples greater than zero, got %r"
129135
% n_samples

sklearn/tests/test_public_functions.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,7 @@ def _check_function_param_validation(
120120
"sklearn.metrics.roc_curve",
121121
"sklearn.metrics.zero_one_loss",
122122
"sklearn.model_selection.train_test_split",
123+
"sklearn.random_projection.johnson_lindenstrauss_min_dim",
123124
"sklearn.svm.l1_min_c",
124125
]
125126

sklearn/tests/test_random_projection.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,12 @@ def densify(matrix):
6565

6666

6767
@pytest.mark.parametrize(
68-
"n_samples, eps", [(100, 1.1), (100, 0.0), (100, -0.1), (0, 0.5)]
68+
"n_samples, eps",
69+
[
70+
([100, 110], [0.9, 1.1]),
71+
([90, 100], [0.1, 0.0]),
72+
([50, -40], [0.1, 0.2]),
73+
],
6974
)
7075
def test_invalid_jl_domain(n_samples, eps):
7176
with pytest.raises(ValueError):

0 commit comments

Comments
 (0)
0