From 716925fc2da0103f2b546eaf71b22fc780fd84df Mon Sep 17 00:00:00 2001 From: Pooja Subramaniam Date: Fri, 7 Apr 2023 10:13:00 +0200 Subject: [PATCH 1/2] validating parameters for metrics.check_scoring function --- sklearn/metrics/_scorer.py | 9 ++++++++- sklearn/tests/test_public_functions.py | 1 + 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/sklearn/metrics/_scorer.py b/sklearn/metrics/_scorer.py index d085cd66d8232..c3dcd9405b237 100644 --- a/sklearn/metrics/_scorer.py +++ b/sklearn/metrics/_scorer.py @@ -65,7 +65,7 @@ from ..utils.multiclass import type_of_target from ..base import is_regressor -from ..utils._param_validation import validate_params +from ..utils._param_validation import validate_params, HasMethods def _cached_call(cache, estimator, method, *args, **kwargs): @@ -451,6 +451,13 @@ def _passthrough_scorer(estimator, *args, **kwargs): return estimator.score(*args, **kwargs) +@validate_params( + { + "estimator": [HasMethods(["fit"])], + "scoring": [str, callable, None], + "allow_none": ["boolean"], + } +) def check_scoring(estimator, scoring=None, *, allow_none=False): """Determine scorer from user options. diff --git a/sklearn/tests/test_public_functions.py b/sklearn/tests/test_public_functions.py index 55c33295e1d3d..b11265c9dbfba 100644 --- a/sklearn/tests/test_public_functions.py +++ b/sklearn/tests/test_public_functions.py @@ -166,6 +166,7 @@ def _check_function_param_validation( "sklearn.metrics.average_precision_score", "sklearn.metrics.balanced_accuracy_score", "sklearn.metrics.brier_score_loss", + "sklearn.metrics.check_scoring", "sklearn.metrics.class_likelihood_ratios", "sklearn.metrics.classification_report", "sklearn.metrics.cluster.adjusted_mutual_info_score", From f0914e46e25340f2c86dc2bbd487a0fb996d8e2b Mon Sep 17 00:00:00 2001 From: Pooja Subramaniam Date: Fri, 7 Apr 2023 10:47:30 +0200 Subject: [PATCH 2/2] removing separate test validating estimator has 'fit' method --- sklearn/metrics/tests/test_score_objects.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/sklearn/metrics/tests/test_score_objects.py b/sklearn/metrics/tests/test_score_objects.py index 7f3e804f68d46..710ba11147b00 100644 --- a/sklearn/metrics/tests/test_score_objects.py +++ b/sklearn/metrics/tests/test_score_objects.py @@ -228,12 +228,6 @@ def test_all_scorers_repr(): def check_scoring_validator_for_single_metric_usecases(scoring_validator): # Test all branches of single metric usecases - estimator = EstimatorWithoutFit() - pattern = ( - r"estimator should be an estimator implementing 'fit' method," r" .* was passed" - ) - with pytest.raises(TypeError, match=pattern): - scoring_validator(estimator) estimator = EstimatorWithFitAndScore() estimator.fit([[1]], [1])