8000 MAINT Parameter validation for linear_model.orthogonal_mp by choudharynishu · Pull Request #25817 · scikit-learn/scikit-learn · GitHub
[go: up one dir, main page]

Skip to content

MAINT Parameter validation for linear_model.orthogonal_mp #25817

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
17 commits
Select commit Hold shift + click to select a range
bbda940
MAINT Parameter Validation for linear_model.orthogonal_mp
choudharynishu Mar 10, 2023
dee4ebc
Fixed tol parameter validation dictionary
choudharynishu Mar 10, 2023
0869598
Fixed n_nonzero_coefs parameter validation
choudharynishu Mar 10, 2023
946801b
Edited linear_model orthogonal file
choudharynishu Mar 10, 2023
de0c605
"merge main"
choudharynishu Mar 10, 2023
126c030
Edited tol parameter validation
choudharynishu Mar 10, 2023
1730480
Edited range for tol parameter in linear_model.orthogonal_mp
choudharynishu Mar 10, 2023
c960105
Added sklearn.linear_model.orthogonal_mp to test public functions list
choudharynishu Mar 10, 2023
df75106
Merge branch 'main' of github.com:scikit-learn/scikit-learn into para…
choudharynishu Mar 13, 2023
3e61c9a
validation for linear_model.orthogonal_mp changed X from ndarray to a…
choudharynishu Mar 13, 2023
ea7de7b
Merge branch 'main' into param_validation_linearmodel
choudharynishu Mar 14, 2023
3698c55
Merge branch 'main' into param_validation_linearmodel
choudharynishu Mar 14, 2023
cc70faa
Merge branch 'main' of github.com:scikit-learn/scikit-learn into para…
choudharynishu Mar 14, 2023
98ea492
Merge branch 'main' of github.com:scikit-learn/scikit-learn into para…
choudharynishu Mar 14, 2023
f868af8
Merge branch 'main' of github.com:scikit-learn/scikit-learn into para…
choudharynishu Mar 14, 2023
e2005ae
Removed outdated validation for 'tol' and 'n_nonzero_coefs' in linear…
choudharynishu Mar 14, 2023
6796ebf
Merge branch 'main' of github.com:scikit-learn/scikit-le 10000 arn into para…
choudharynishu Mar 14, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 14 additions & 5 deletions sklearn/linear_model/_omp.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from ..utils import as_float_array, check_array
from ..utils.parallel import delayed, Parallel
from ..utils._param_validation import Hidden, Interval, StrOptions
from ..utils._param_validation import validate_params
from ..model_selection import check_cv

premature = (
Expand Down Expand Up @@ -281,6 +282,18 @@ def _gram_omp(
return gamma, indices[:n_active], n_active


@validate_params(
{
"X": ["array-like"],
"y": [np.ndarray],
"n_nonzero_coefs": [Interval(Integral, 1, None, closed="left"), None],
"tol": [Interval(Real, 0, None, closed="left"), None],
"precompute": ["boolean", StrOptions({"auto"})],
"copy_X": ["boolean"],
"return_path": ["boolean"],
"return_n_iter": ["boolean"],
}
)
def orthogonal_mp(
X,
y,
Expand Down Expand Up @@ -308,7 +321,7 @@ def orthogonal_mp(

Parameters
----------
X : ndarray of shape (n_samples, n_features)
X : array-like of shape (n_samples, n_features)
Input data. Columns are assumed to have unit norm.

y : ndarray of shape (n_samples,) or (n_samples, n_targets)
Expand Down Expand Up @@ -380,10 +393,6 @@ def orthogonal_mp(
# default for n_nonzero_coefs is 0.1 * n_features
# but at least one.
n_nonzero_coefs = max(int(0.1 * X.shape[1]), 1)
if tol is not None and tol < 0:
raise ValueError("Epsilon cannot be negative")
if tol is None and n_nonzero_coefs <= 0:
raise ValueError("The number of atoms must be positive")
if tol is None and n_nonzero_coefs > X.shape[1]:
raise ValueError(
"The number of atoms cannot be more than the number of features"
Expand Down
2 changes: 1 addition & 1 deletion sklearn/linear_model/tests/test_omp.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def test_unreachable_accuracy():
@pytest.mark.parametrize("positional_params", [(X, y), (G, Xy)])
@pytest.mark.parametrize(
"keyword_params",
[{"tol": -1}, {"n_nonzero_coefs": -1}, {"n_nonzero_coefs": n_features + 1}],
[{"n_nonzero_coefs": n_features + 1}],
)
def test_bad_input(positional_params, keyword_params):
with pytest.raises(ValueError):
Expand Down
1 change: 1 addition & 0 deletions sklearn/tests/test_public_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ def _check_function_param_validation(
"sklearn.feature_selection.f_regression",
"sklearn.feature_selection.mutual_info_classif",
"sklearn.feature_selection.r_regression",
"sklearn.linear_model.orthogonal_mp",
"sklearn.metrics.accuracy_score",
"sklearn.metrics.auc",
"sklearn.metrics.average_precision_score",
Expand Down
0