|
18 | 18 | # Arnaud Joly <arnaud.v.joly@gmail.com>
|
19 | 19 | # License: Simplified BSD
|
20 | 20 |
|
21 |
| -from collections.abc import Iterable |
22 | 21 | from functools import partial
|
23 | 22 | from collections import Counter
|
24 | 23 | from traceback import format_exc
|
|
65 | 64 |
|
66 | 65 | from ..utils.multiclass import type_of_target
|
67 | 66 | from ..base import is_regressor
|
68 |
| -from ..utils._param_validation import validate_params |
| 67 | +from ..utils._param_validation import HasMethods, StrOptions, validate_params |
69 | 68 |
|
70 | 69 |
|
71 | 70 | def _cached_call(cache, estimator, method, *args, **kwargs):
|
@@ -451,79 +450,6 @@ def _passthrough_scorer(estimator, *args, **kwargs):
|
451 | 450 | return estimator.score(*args, **kwargs)
|
452 | 451 |
|
453 | 452 |
|
454 |
| -def check_scoring(estimator, scoring=None, *, allow_none=False): |
455 |
| - """Determine scorer from user options. |
456 |
| -
|
457 |
| - A TypeError will be thrown if the estimator cannot be scored. |
458 |
| -
|
459 |
| - Parameters |
460 |
| - ---------- |
461 |
| - estimator : estimator object implementing 'fit' |
462 |
| - The object to use to fit the data. |
463 |
| -
|
464 |
| - scoring : str or callable, default=None |
465 |
| - A string (see model evaluation documentation) or |
466 |
| - a scorer callable object / function with signature |
467 |
| - ``scorer(estimator, X, y)``. |
468 |
| - If None, the provided estimator object's `score` method is used. |
469 |
| -
|
470 |
| - allow_none : bool, default=False |
471 |
| - If no scoring is specified and the estimator has no score function, we |
472 |
| - can either return None or raise an exception. |
473 |
| -
|
474 |
| - Returns |
475 |
| - ------- |
476 |
| - scoring : callable |
477 |
| - A scorer callable object / function with signature |
478 |
| - ``scorer(estimator, X, y)``. |
479 |
| - """ |
480 |
| - if not hasattr(estimator, "fit"): |
481 |
| - raise TypeError( |
482 |
| - "estimator should be an estimator implementing 'fit' method, %r was passed" |
483 |
| - % estimator |
484 |
| - ) |
485 |
| - if isinstance(scoring, str): |
486 |
| - return get_scorer(scoring) |
487 |
| - elif callable(scoring): |
488 |
| - # Heuristic to ensure user has not passed a metric |
489 |
| - module = getattr(scoring, "__module__", None) |
490 |
| - if ( |
491 |
| - hasattr(module, "startswith") |
492 |
| - and module.startswith("sklearn.metrics.") |
493 |
| - and not module.startswith("sklearn.metrics._scorer") |
494 |
| - and not module.startswith("sklearn.metrics.tests.") |
495 |
| - ): |
496 |
| - raise ValueError( |
497 |
| - "scoring value %r looks like it is a metric " |
498 |
| - "function rather than a scorer. A scorer should " |
499 |
| - "require an estimator as its first parameter. " |
500 |
| - "Please use `make_scorer` to convert a metric " |
501 |
| - "to a scorer." % scoring |
502 |
| - ) |
503 |
| - return get_scorer(scoring) |
504 |
| - elif scoring is None: |
505 |
| - if hasattr(estimator, "score"): |
506 |
| - return _passthrough_scorer |
507 |
| - elif allow_none: |
508 |
| - return None |
509 |
| - else: |
510 |
| - raise TypeError( |
511 |
| - "If no scoring is specified, the estimator passed should " |
512 |
| - "have a 'score' method. The estimator %r does not." % estimator |
513 |
| - ) |
514 |
| - elif isinstance(scoring, Iterable): |
515 |
| - raise ValueError( |
516 |
| - "For evaluating multiple scores, use " |
517 |
| - "sklearn.model_selection.cross_validate instead. " |
518 |
| - "{0} was passed.".format(scoring) |
519 |
| - ) |
520 |
| - else: |
521 |
| - raise ValueError( |
522 |
| - "scoring value should either be a callable, string or None. %r was passed" |
523 |
| - % scoring |
524 |
| - ) |
525 |
| - |
526 |
| - |
527 | 453 | def _check_multimetric_scoring(estimator, scoring):
|
528 | 454 | """Check the scoring parameter in cases when multiple metrics are allowed.
|
529 | 455 |
|
@@ -882,3 +808,67 @@ def get_scorer_names():
|
882 | 808 | _SCORERS[qualified_name] = make_scorer(metric, pos_label=None, average=average)
|
883 | 809 |
|
884 | 810 | SCORERS = _DeprecatedScorers(_SCORERS)
|
| 811 | + |
| 812 | + |
| 813 | +@validate_params( |
| 814 | + { |
| 815 | + "estimator": [HasMethods("fit")], |
| 816 | + "scoring": [StrOptions(set(get_scorer_names())), callable, None], |
| 817 | + "allow_none": ["boolean"], |
| 818 | + } |
| 819 | +) |
| 820 | +def check_scoring(estimator, scoring=None, *, allow_none=False): |
| 821 | + """Determine scorer from user options. |
| 822 | +
|
| 823 | + A TypeError will be thrown if the estimator cannot be scored. |
| 824 | +
|
| 825 | + Parameters |
| 826 | + ---------- |
| 827 | + estimator : estimator object implementing 'fit' |
| 828 | + The object to use to fit the data. |
| 829 | +
|
| 830 | + scoring : str or callable, default=None |
| 831 | + A string (see model evaluation documentation) or |
| 832 | + a scorer callable object / function with signature |
| 833 | + ``scorer(estimator, X, y)``. |
| 834 | + If None, the provided estimator object's `score` method is used. |
| 835 | +
|
| 836 | + allow_none : bool, default=False |
| 837 | + If no scoring is specified and the estimator has no score function, we |
| 838 | + can either return None or raise an exception. |
| 839 | +
|
| 840 | + Returns |
| 841 | + ------- |
| 842 | + scoring : callable |
| 843 | + A scorer callable object / function with signature |
| 844 | + ``scorer(estimator, X, y)``. |
| 845 | + """ |
| 846 | + if isinstance(scoring, str): |
| 847 | + return get_scorer(scoring) |
| 848 | + if callable(scoring): |
| 849 | + # Heuristic to ensure user has not passed a metric |
| 850 | + module = getattr(scoring, "__module__", None) |
| 851 | + if ( |
| 852 | + hasattr(module, "startswith") |
| 853 | + and module.startswith("sklearn.metrics.") |
| 854 | + and not module.startswith("sklearn.metrics._scorer") |
| 855 | + and not module.startswith("sklearn.metrics.tests.") |
| 856 | + ): |
| 857 | + raise ValueError( |
| 858 | + "scoring value %r looks like it is a metric " |
| 859 | + "function rather than a scorer. A scorer should " |
| 860 | + "require an estimator as its first parameter. " |
| 861 | + "Please use `make_scorer` to convert a metric " |
| 862 | + "to a scorer." % scoring |
| 863 | + ) |
| 864 | + return get_scorer(scoring) |
| 865 | + if scoring is None: |
| 866 | + if hasattr(estimator, "score"): |
| 867 | + return _passthrough_scorer |
| 868 | + elif allow_none: |
| 869 | + return None |
| 870 | + else: |
| 871 | + raise TypeError( |
| 872 | + "If no scoring is specified, the estimator passed should " |
| 873 | + "have a 'score' method. The estimator %r does not." % estimator |
| 874 | + ) |
0 commit comments