Simplified error message in get_scorer() function in sklearn.metrics.scorer.py file by princejha95 · Pull Request #11062 · scikit-learn/scikit-learn · GitHub
[go: up one dir, main page]

Skip to content

Simplified error message in get_scorer() function in sklearn.metrics.scorer.py file #11062

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 3 commits into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 4 additions & 19 deletions sklearn/metrics/scorer.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,15 @@
# License: Simplified BSD

from abc import ABCMeta, abstractmethod
from collections import Iterable
import warnings

import numpy as np

from . import (r2_score, median_absolute_error, mean_absolute_error,
mean_squared_error, mean_squared_log_error, accuracy_score,
f1_score, roc_auc_score, average_precision_score,
precision_score, recall_score, log_loss, balanced_accuracy_score,
explained_variance_score, brier_score_loss)
precision_score, recall_score, log_loss,
explained_variance_score)

from .cluster import adjusted_rand_score
from .cluster import homogeneity_score
Expand Down Expand Up @@ -136,10 +135,7 @@ def __call__(self, clf, X, y, sample_weight=None):
"""
super(_ProbaScorer, self).__call__(clf, X, y,
sample_weight=sample_weight)
y_type = type_of_target(y)
y_pred = clf.predict_proba(X)
if y_type == "binary":
y_pred = y_pred[:, 1]
if sample_weight is not None:
return self._sign * self._score_func(y, y_pred,
sample_weight=sample_weight,
Expand Down Expand Up @@ -212,7 +208,6 @@ def __call__(self, clf, X, y, sample_weight=None):
def _factory_args(self):
return ", needs_threshold=True"


def get_scorer(scoring):
"""Get a scorer from string

Expand All @@ -236,8 +231,8 @@ def get_scorer(scoring):
valid = False # Don't raise here to make the error message elegant
if not valid:
raise ValueError('%r is not a valid scoring value. '
'Valid options are %s'
% (scoring, sorted(scorers)))
'For valid options use sorted(SCORERS.keys())'
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we should make this sklearn.metrics.SCORERS instead of just SCORERS

% (scoring))
else:
scorer = scoring
return scorer
Expand Down Expand Up @@ -301,10 +296,6 @@ def check_scoring(estimator, scoring=None, allow_none=False):
"If no scoring is specified, the estimator passed should "
"have a 'score' method. The estimator %r does not."
% estimator)
elif isinstance(scoring, Iterable):
raise ValueError("For evaluating multiple scores, use "
"sklearn.model_selection.cross_validate instead. "
"{0} was passed.".format(scoring))
else:
raise ValueError("scoring value should either be a callable, string or"
" None. %r was passed" % scoring)
Expand Down Expand Up @@ -505,7 +496,6 @@ def make_scorer(score_func, greater_is_better=True, needs_proba=False,
# Standard Classification Scores
accuracy_scorer = make_scorer(accuracy_score)
f1_scorer = make_scorer(f1_score)
balanced_accuracy_scorer = make_scorer(balanced_accuracy_score)

# Score functions that need decision values
roc_auc_scorer = make_scorer(roc_auc_score, greater_is_better=True,
Expand All @@ -523,9 +513,6 @@ def make_scorer(score_func, greater_is_better=True, needs_proba=False,
log_loss_scorer = make_scorer(log_loss, greater_is_better=False,
needs_proba=True)
log_loss_scorer._deprecation_msg = deprecation_msg
brier_score_loss_scorer = make_scorer(brier_score_loss,
greater_is_better=False,
needs_proba=True)


# Clustering scores
Expand All @@ -549,11 +536,9 @@ def make_scorer(score_func, greater_is_better=True, needs_proba=False,
mean_absolute_error=mean_absolute_error_scorer,
mean_squared_error=mean_squared_error_scorer,
accuracy=accuracy_scorer, roc_auc=roc_auc_scorer,
balanced_accuracy=balanced_accuracy_scorer,
average_precision=average_precision_scorer,
log_loss=log_loss_scorer,
neg_log_loss=neg_log_loss_scorer,
brier_score_loss=brier_score_loss_scorer,
# Cluster metrics that use supervised evaluation
adjusted_rand_score=adjusted_rand_scorer,
homogeneity_score=homogeneity_scorer,
Expand Down
0