8000 MAINT Parameter validation for linear_model.orthogonal_mp (#25817) · scikit-learn/scikit-learn@263b428 · GitHub
[go: up one dir, main page]

Skip to content

Commit 263b428

Browse files
MAINT Parameter validation for linear_model.orthogonal_mp (#25817)
1 parent 59a48db commit 263b428

File tree

3 files changed

+16
-6
lines changed

3 files changed

+16
-6
lines changed

sklearn/linear_model/_omp.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from ..utils import as_float_array, check_array
1919
from ..utils.parallel import delayed, Parallel
2020
from ..utils._param_validation import Hidden, Interval, StrOptions
21+
from ..utils._param_validation import validate_params
2122
from ..model_selection import check_cv
2223

2324
premature = (
@@ -281,6 +282,18 @@ def _gram_omp(
281282
return gamma, indices[:n_active], n_active
282283

283284

285+
@validate_params(
286+
{
287+
"X": ["array-like"],
288+
"y": [np.ndarray],
289+
"n_nonzero_coefs": [Interval(Integral, 1, None, closed="left"), None],
290+
"tol": [Interval(Real, 0, None, closed="left"), None],
291+
"precompute": ["boolean", StrOptions({"auto"})],
292+
"copy_X": ["boolean"],
293+
"return_path": ["boolean"],
294+
"return_n_iter": ["boolean"],
295+
}
296+
)
284297
def orthogonal_mp(
285298
X,
286299
y,
@@ -308,7 +321,7 @@ def orthogonal_mp(
308321
309322
Parameters
310323
----------
311-
X : ndarray of shape (n_samples, n_features)
324+
X : array-like of shape (n_samples, n_features)
312325
Input data. Columns are assumed to have unit norm.
313326
314327
y : ndarray of shape (n_samples,) or (n_samples, n_targets)
@@ -380,10 +393,6 @@ def orthogonal_mp(
380393
# default for n_nonzero_coefs is 0.1 * n_features
381394
# but at least one.
382395
n_nonzero_coefs = max(int(0.1 * X.shape[1]), 1)
383-
if tol is not None and tol < 0:
384-
raise ValueError("Epsilon cannot be negative")
385-
if tol is None and n_nonzero_coefs <= 0:
386-
raise ValueError("The number of atoms must be positive")
387396
if tol is None and n_nonzero_coefs > X.shape[1]:
388397
raise ValueError(
389398
"The number of atoms cannot be more than the number of features"

sklearn/linear_model/tests/test_omp.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ def test_unreachable_accuracy():
120120
@pytest.mark.parametrize("positional_params", [(X, y), (G, Xy)])
121121
@pytest.mark.parametrize(
122122
"keyword_params",
123-
[{"tol": -1}, {"n_nonzero_coefs": -1}, {"n_nonzero_coefs": n_features + 1}],
123+
[{"n_nonzero_coefs": n_features + 1}],
124124
)
125125
def test_bad_input(positional_params, keyword_params):
126126
with pytest.raises(ValueError):

sklearn/tests/test_public_functions.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,7 @@ 5C39 def _check_function_param_validation(
126126
"sklearn.feature_selection.f_regression",
127127
"sklearn.feature_selection.mutual_info_classif",
128128
"sklearn.feature_selection.r_regression",
129+
"sklearn.linear_model.orthogonal_mp",
129130
"sklearn.metrics.accuracy_score",
130131
"sklearn.metrics.auc",
131132
"sklearn.metrics.average_precision_score",

0 commit comments

Comments
 (0)
0