MAINT Parameters validation for sklearn.model_selection.learning_curve (#26227) · scikit-learn/scikit-learn@4caa4ab · GitHub
[go: up one dir, main page]

Skip to content

Commit 4caa4ab

Browse files
MAINT Parameters validation for sklearn.model_selection.learning_curve (#26227)
Co-authored-by: jeremie du boisberranger <jeremiedbb@yahoo.fr>
1 parent 144324b commit 4caa4ab

File tree

3 files changed

+29
-15
lines changed

3 files changed

+29
-15
lines changed

sklearn/model_selection/_validation.py

Lines changed: 28 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -329,7 +329,7 @@ def cross_validate(
329329

330330
_warn_or_raise_about_fit_failures(results, error_score)
331331

332-
# For callabe scoring, the return type is only know after calling. If the
332+
# For callable scoring, the return type is only known after calling. If the
333333
# return type is a dictionary, the error scores can now be inserted with
334334
# the correct key.
335335
if callable(scoring):
@@ -1432,6 +1432,26 @@ def _shuffle(y, groups, random_state):
14321432
return _safe_indexing(y, indices)
14331433

14341434

1435+
@validate_params(
1436+
{
1437+
"estimator": [HasMethods(["fit"])],
1438+
"X": ["array-like", "sparse matrix"],
1439+
"y": ["array-like", None],
1440+
"groups": ["array-like", None],
1441+
"train_sizes": ["array-like"],
1442+
"cv": ["cv_object"],
1443+
"scoring": [StrOptions(set(get_scorer_names())), callable, None],
1444+
"exploit_incremental_learning": ["boolean"],
1445+
"n_jobs": [Integral, None],
1446+
"pre_dispatch": [Integral, str],
1447+
"verbose": ["verbose"],
1448+
"shuffle": ["boolean"],
1449+
"random_state": ["random_state"],
1450+
"error_score": [StrOptions({"raise"}), Real],
1451+
"return_times": ["boolean"],
1452+
"fit_params": [dict, None],
1453+
}
1454+
)
14351455
def learning_curve(
14361456
estimator,
14371457
X,
@@ -1466,18 +1486,20 @@ def learning_curve(
14661486
14671487
Parameters
14681488
----------
1469-
estimator : object type that implements the "fit" and "predict" methods
1470-
An object of that type which is cloned for each validation.
1489+
estimator : object type that implements the "fit" method
1490+
An object of that type which is cloned for each validation. It must
1491+
also implement "predict" unless `scoring` is a callable that doesn't
1492+
rely on "predict" to compute a score.
14711493
1472-
X : array-like of shape (n_samples, n_features)
1494+
X : {array-like, sparse matrix} of shape (n_samples, n_features)
14731495
Training vector, where `n_samples` is the number of samples and
14741496
`n_features` is the number of features.
14751497
1476-
y : array-like of shape (n_samples,) or (n_samples, n_outputs)
1498+
y : array-like of shape (n_samples,) or (n_samples, n_outputs) or None
14771499
Target relative to X for classification or regression;
14781500
None for unsupervised learning.
14791501
1480-
groups : array-like of shape (n_samples,), default=None
1502+
groups : array-like of shape (n_samples,), default=None
14811503
Group labels for the samples used while splitting the dataset into
14821504
train/test set. Only used in conjunction with a "Group" :term:`cv`
14831505
instance (e.g., :class:`GroupKFold`).

sklearn/model_selection/tests/test_validation.py

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2094,7 +2094,6 @@ def test_fit_and_score_failing():
20942094
failing_clf = FailingClassifier(FailingClassifier.FAILING_PARAMETER)
20952095
# dummy X data
20962096
X = np.arange(1, 10)
2097-
y = np.ones(9)
20982097
fit_and_score_args = [failing_clf, X, None, dict(), None, None, 0, None, None]
20992098
# passing error score to trigger the warning message
21002099
fit_and_score_kwargs = {"error_score": "raise"}
@@ -2103,21 +2102,13 @@ def test_fit_and_score_failing():
21032102
_fit_and_score(*fit_and_score_args, **fit_and_score_kwargs)
21042103

21052104
# check that functions upstream pass error_score param to _fit_and_score
2106-
error_message = re.escape(
2107-
"error_score must be the string 'raise' or a numeric value. (Hint: if "
2108-
"using 'raise', please make sure that it has been spelled correctly.)"
2109-
)
2110-
21112105
error_message_cross_validate = (
21122106
"The 'error_score' parameter of cross_validate must be .*. Got .* instead."
21132107
)
21142108

21152109
with pytest.raises(ValueError, match=error_message_cross_validate):
21162110
cross_val_score(failing_clf, X, cv=3, error_score="unvalid-string")
21172111

2118-
with pytest.raises(ValueError, match=error_message):
2119-
learning_curve(failing_clf, X, y, cv=3, error_score="unvalid-string")
2120-
21212112
assert failing_clf.score() == 0.0 # FailingClassifier coverage
21222113

21232114

sklearn/tests/test_public_functions.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -256,6 +256,7 @@ def _check_function_param_validation(
256256
"sklearn.metrics.v_measure_score",
257257
"sklearn.metrics.zero_one_loss",
258258
"sklearn.model_selection.cross_validate",
259+
"sklearn.model_selection.learning_curve",
259260
"sklearn.model_selection.permutation_test_score",
260261
"sklearn.model_selection.train_test_split",
261262
"sklearn.model_selection.validation_curve",

0 commit comments

Comments
 (0)
0