diff --git a/sklearn/linear_model/_omp.py b/sklearn/linear_model/_omp.py index f0bd04568c473..b1dc1e352fd62 100644 --- a/sklearn/linear_model/_omp.py +++ b/sklearn/linear_model/_omp.py @@ -18,6 +18,7 @@ from ..utils import as_float_array, check_array from ..utils.parallel import delayed, Parallel from ..utils._param_validation import Hidden, Interval, StrOptions +from ..utils._param_validation import validate_params from ..model_selection import check_cv premature = ( @@ -281,6 +282,18 @@ def _gram_omp( return gamma, indices[:n_active], n_active +@validate_params( + { + "X": ["array-like"], + "y": [np.ndarray], + "n_nonzero_coefs": [Interval(Integral, 1, None, closed="left"), None], + "tol": [Interval(Real, 0, None, closed="left"), None], + "precompute": ["boolean", StrOptions({"auto"})], + "copy_X": ["boolean"], + "return_path": ["boolean"], + "return_n_iter": ["boolean"], + } +) def orthogonal_mp( X, y, @@ -308,7 +321,7 @@ def orthogonal_mp( Parameters ---------- - X : ndarray of shape (n_samples, n_features) + X : array-like of shape (n_samples, n_features) Input data. Columns are assumed to have unit norm. y : ndarray of shape (n_samples,) or (n_samples, n_targets) @@ -380,10 +393,6 @@ def orthogonal_mp( # default for n_nonzero_coefs is 0.1 * n_features # but at least one. n_nonzero_coefs = max(int(0.1 * X.shape[1]), 1) - if tol is not None and tol < 0: - raise ValueError("Epsilon cannot be negative") - if tol is None and n_nonzero_coefs <= 0: - raise ValueError("The number of atoms must be positive") if tol is None and n_nonzero_coefs > X.shape[1]: raise ValueError( "The number of atoms cannot be more than the number of features" diff --git a/sklearn/linear_model/tests/test_omp.py b/sklearn/linear_model/tests/test_omp.py index 6aecd9abc6505..599e2940f9403 100644 --- a/sklearn/linear_model/tests/test_omp.py +++ b/sklearn/linear_model/tests/test_omp.py @@ -120,7 +120,7 @@ def test_unreachable_accuracy(): @pytest.mark.parametrize("positional_params", [(X, y), (G, Xy)]) @pytest.mark.parametrize( "keyword_params", - [{"tol": -1}, {"n_nonzero_coefs": -1}, {"n_nonzero_coefs": n_features + 1}], + [{"n_nonzero_coefs": n_features + 1}], ) def test_bad_input(positional_params, keyword_params): with pytest.raises(ValueError): diff --git a/sklearn/tests/test_public_functions.py b/sklearn/tests/test_public_functions.py index 3aaeaf0abff27..a9fff191b06a7 100644 --- a/sklearn/tests/test_public_functions.py +++ b/sklearn/tests/test_public_functions.py @@ -126,6 +126,7 @@ def _check_function_param_validation( "sklearn.feature_selection.f_regression", "sklearn.feature_selection.mutual_info_classif", "sklearn.feature_selection.r_regression", + "sklearn.linear_model.orthogonal_mp", "sklearn.metrics.accuracy_score", "sklearn.metrics.auc", "sklearn.metrics.average_precision_score",