8000 MAINT Parameters validation for sklearn.model_selection.permutation_t… · scikit-learn/scikit-learn@e20b3a6 · GitHub
[go: up one dir, main page]

Skip to content

Commit e20b3a6

Browse files
authored
MAINT Parameters validation for sklearn.model_selection.permutation_test_score (#26230)
1 parent ae78c25 commit e20b3a6

File tree

2 files changed

+17
-0
lines changed

2 files changed

+17
-0
lines changed

sklearn/model_selection/_validation.py

+16
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
from ..utils.metaestimators import _safe_split
3333
from ..utils._param_validation import (
3434
HasMethods,
35+
Interval,
3536
Integral,
3637
StrOptions,
3738
validate_params,
@@ -1235,6 +1236,21 @@ def _check_is_permutation(indices, n_samples):
12351236
return True
12361237

12371238

1239+
@validate_params(
1240+
{
1241+
"estimator": [HasMethods("fit")],
1242+
"X": ["array-like", "sparse matrix"],
1243+
"y": ["array-like", None],
1244+
"groups": ["array-like", None],
1245+
"cv": ["cv_object"],
1246+
"n_permutations": [Interval(Integral, 1, None, closed="left")],
1247+
"n_jobs": [Integral, None],
1248+
"random_state": ["random_state"],
1249+
"verbose": ["verbose"],
1250+
"scoring": [StrOptions(set(get_scorer_names())), callable, None],
1251+
"fit_params": [dict, None],
1252+
}
1253+
)
12381254
def permutation_test_score(
12391255
estimator,
12401256
X,

sklearn/tests/test_public_functions.py

+1
Original file line numberDiff line numberDiff line change
@@ -247,6 +247,7 @@ def _check_function_param_validation(
247247
"sklearn.metrics.top_k_accuracy_score",
248248
"sklearn.metrics.zero_one_loss",
249249
"sklearn.model_selection.cross_validate",
250+
"sklearn.model_selection.permutation_test_score",
250251
"sklearn.model_selection.train_test_split",
251252
"sklearn.neighbors.sort_graph_by_row_values",
252253
"sklearn.preprocessing.add_dummy_feature",

0 commit comments

Comments
 (0)
0