MAINT Parameters validation for sklearn.model_selection.learning_curve by Charlie-XIAO · Pull Request #26227 · scikit-learn/scikit-learn · GitHub
[go: up one dir, main page]

Skip to content
34 changes: 28 additions & 6 deletions sklearn/model_selection/_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,7 +329,7 @@ def cross_validate(

_warn_or_raise_about_fit_failures(results, error_score)

# For callabe scoring, the return type is only know after calling. If the
# For callable scoring, the return type is only known after calling. If the
# return type is a dictionary, the error scores can now be inserted with
# the correct key.
if callable(scoring):
Expand Down Expand Up @@ -1432,6 +1432,26 @@ def _shuffle(y, groups, random_state):
return _safe_indexing(y, indices)


@validate_params(
{
"estimator": [HasMethods(["fit"])],
"X": ["array-like", "sparse matrix"],
"y": ["array-like", None],
"groups": ["array-like", None],
"train_sizes": ["array-like"],
"cv": ["cv_object"],
"scoring": [StrOptions(set(get_scorer_names())), callable, None],
"exploit_incremental_learning": ["boolean"],
"n_jobs": [Integral, None],
"pre_dispatch": [Integral, str],
"verbose": ["verbose"],
"shuffle": ["boolean"],
"random_state": ["random_state"],
"error_score": [StrOptions({"raise"}), Real],
"return_times": ["boolean"],
"fit_params": [dict, None],
}
)
def learning_curve(
estimator,
X,
Expand Down Expand Up @@ -1466,18 +1486,20 @@ def learning_curve(

Parameters
----------
estimator : object type that implements the "fit" and "predict" methods
An object of that type which is cloned for each validation.
estimator : object type that implements the "fit" method
An object of that type which is cloned for each validation. It must
also implement "predict" unless `scoring` is a callable that doesn't
rely on "predict" to compute a score.

X : array-like of shape (n_samples, n_features)
X : {array-like, sparse matrix} of shape (n_samples, n_features)
Training vector, where `n_samples` is the number of samples and
`n_features` is the number of features.

y : array-like of shape (n_samples,) or (n_samples, n_outputs)
y : array-like of shape (n_samples,) or (n_samples, n_outputs) or None
Target relative to X for classification or regression;
None for unsupervised learning.

groups : array-like of shape (n_samples,), default=None
groups : array-like of shape (n_samples,), default=None
Group labels for the samples used while splitting the dataset into
train/test set. Only used in conjunction with a "Group" :term:`cv`
instance (e.g., :class:`GroupKFold`).
Expand Down
9 changes: 0 additions & 9 deletions sklearn/model_selection/tests/test_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2094,7 +2094,6 @@ def test_fit_and_score_failing():
failing_clf = FailingClassifier(FailingClassifier.FAILING_PARAMETER)
# dummy X data
X = np.arange(1, 10)
y = np.ones(9)
fit_and_score_args = [failing_clf, X, None, dict(), None, None, 0, None, None]
# passing error score to trigger the warning message
fit_and_score_kwargs = {"error_score": "raise"}
Expand All @@ -2103,21 +2102,13 @@ def test_fit_and_score_failing():
_fit_and_score(*fit_and_score_args, **fit_and_score_kwargs)

# check that functions upstream pass error_score param to _fit_and_score
error_message = re.escape(
"error_score must be the string 'raise' or a numeric value. (Hint: if "
"using 'raise', please make sure that it has been spelled correctly.)"
)

error_message_cross_validate = (
"The 'error_score' parameter of cross_validate must be .*. Got .* instead."
)

with pytest.raises(ValueError, match=error_message_cross_validate):
cross_val_score(failing_clf, X, cv=3, error_score="unvalid-string")

with pytest.raises(ValueError, match=error_message):
learning_curve(failing_clf, X, y, cv=3, error_score="unvalid-string")

assert failing_clf.score() == 0.0 # FailingClassifier coverage


Expand Down
1 change: 1 addition & 0 deletions sklearn/tests/test_public_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,7 @@ def _check_function_param_validation(
"sklearn.metrics.v_measure_score",
"sklearn.metrics.zero_one_loss",
"sklearn.model_selection.cross_validate",
"sklearn.model_selection.learning_curve",
"sklearn.model_selection.permutation_test_score",
"sklearn.model_selection.train_test_split",
"sklearn.model_selection.validation_curve",
Expand Down
0