8000 MAINT Parameters validation for sklearn.model_selection.validation_cu… · scikit-learn/scikit-learn@144324b · GitHub
[go: up one dir, main page]

Skip to content

Commit 144324b

Browse files
MAINT Parameters validation for sklearn.model_selection.validation_curve (#26229)
Co-authored-by: Jérémie du Boisberranger <34657725+jeremiedbb@users.noreply.github.com>
1 parent a5b4b14 commit 144324b

File tree

3 files changed

+23
-14
lines changed

3 files changed

+23
-14
lines changed

sklearn/model_selection/_validation.py

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1818,6 +1818,23 @@ def _incremental_fit_estimator(
18181818
return np.array(ret).T
18191819

18201820

1821+
@validate_params(
1822+
{
1823+
"estimator": [HasMethods(["fit"])],
1824+
"X": ["array-like", "sparse matrix"],
1825+
"y": ["array-like", None],
1826+
"param_name": [str],
1827+
"param_range": ["array-like"],
1828+
"groups": ["array-like", None],
1829+
"cv": ["cv_object"],
1830+
"scoring": [StrOptions(set(get_scorer_names())), callable, None],
1831+
"n_jobs": [Integral, None],
1832+
"pre_dispatch": [Integral, str],
1833+
"verbose": ["verbose"],
1834+
"error_score": [StrOptions({"raise"}), Real],
1835+
"fit_params": [dict, None],
1836+
}
1837+
)
18211838
def validation_curve(
18221839
estimator,
18231840
X,
@@ -1847,10 +1864,12 @@ def validation_curve(
18471864
18481865
Parameters
18491866
----------
1850-
estimator : object type that implements the "fit" and "predict" methods
1851-
An object of that type which is cloned for each validation.
1867+
estimator : object type that implements the "fit" method
1868+
An object of that type which is cloned for each validation. It must
1869+
also implement "predict" unless `scoring` is a callable that doesn't
1870+
rely on "predict" to compute a score.
18521871
1853-
X : array-like of shape (n_samples, n_features)
1872+
X : {array-like, sparse matrix} of shape (n_samples, n_features)
18541873
Training vector, where `n_samples` is the number of samples and
18551874
`n_features` is the number of features.
18561875

sklearn/model_selection/tests/test_validation.py

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2118,17 +2118,6 @@ def test_fit_and_score_failing():
21182118
with pytest.raises(ValueError, match=error_message):
21192119
learning_curve(failing_clf, X, y, cv=3, error_score="unvalid-string")
21202120

2121-
with pytest.raises(ValueError, match=error_message):
2122-
validation_curve(
2123-
failing_clf,
2124-
X,
2125-
y,
2126-
param_name="parameter",
2127-
param_range=[FailingClassifier.FAILING_PARAMETER],
2128-
cv=3,
2129-
error_score="unvalid-string",
2130-
)
2131-
21322121
assert failing_clf.score() == 0.0 # FailingClassifier coverage
21332122

21342123

sklearn/tests/test_public_functions.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -258,6 +258,7 @@ def _check_function_param_validation(
258258
"sklearn.model_selection.cross_validate",
259259
"sklearn.model_selection.permutation_test_score",
260260
"sklearn.model_selection.train_test_split",
261+
"sklearn.model_selection.validation_curve",
261262
"sklearn.neighbors.sort_graph_by_row_values",
262263
"sklearn.preprocessing.add_dummy_feature",
263264
"sklearn.preprocessing.binarize",

0 commit comments

Comments
 (0)
0