Add a scorer for model evaluation · gbolmier/scikit-learn@5f15950 · GitHub

Commit 5f15950

Add a scorer for model evaluation
# More detailed explanatory text, if necessary. Wrap it to about 72
# characters or so. In some contexts, the first line is treated as the
# subject of the commit and the rest of the text as the body. The
# blank line separating the summary from the body is critical (unless
# you omit the body entirely); various tools like `log`, `shortlog`
# and `rebase` can get confused if you run the two together.
#
# Explain the problem that this commit is solving. Focus on why you
# are making this change as opposed to how (the code explains that).
# Are there side effects or other unintuitive consequences of this
# change? Here's the place to explain them.
#
# Further paragraphs come after blank lines.
#
# - Bullet points are okay, too
#
# - Typically a hyphen or asterisk is used for the bullet, preceded
#   by a single space, with blank lines in between, but conventions
#   vary here
#
# If you use an issue tracker, put references to them at the bottom,
# like this:
#
# Resolves: scikit-learn#123
# See also: scikit-learn#456, scikit-learn#789
1 parent 51f8c15 commit 5f15950

File tree

3 files changed: +14 -5 lines

doc/modules/model_evaluation.rst

Lines changed: 1 addition & 0 deletions
@@ -60,6 +60,7 @@ Scoring Function
 **Classification**
 'accuracy' :func:`metrics.accuracy_score`
 'balanced_accuracy' :func:`metrics.balanced_accuracy_score`
+'top_k_accuracy' :func:`metrics.top_k_accuracy_score`
 'average_precision' :func:`metrics.average_precision_score`
 'neg_brier_score' :func:`metrics.brier_score_loss`
 'f1' :func:`metrics.f1_score` for binary targets
sklearn/metrics/_scorer.py

Lines changed: 9 additions & 4 deletions
@@ -27,9 +27,9 @@
 from . import (r2_score, median_absolute_error, max_error, mean_absolute_error,
                mean_squared_error, mean_squared_log_error,
                mean_poisson_deviance, mean_gamma_deviance, accuracy_score,
-               f1_score, roc_auc_score, average_precision_score,
-               precision_score, recall_score, log_loss,
-               balanced_accuracy_score, explained_variance_score,
+               top_k_accuracy_score, f1_score, roc_auc_score,
+               average_precision_score, precision_score, recall_score,
+               log_loss, balanced_accuracy_score, explained_variance_score,
                brier_score_loss, jaccard_score, mean_absolute_percentage_error)

 from .cluster import adjusted_rand_score
@@ -610,6 +610,9 @@ def make_scorer(score_func, *, greater_is_better=True, needs_proba=False,
 balanced_accuracy_scorer = make_scorer(balanced_accuracy_score)

 # Score functions that need decision values
+top_k_accuracy_scorer = make_scorer(top_k_accuracy_score,
+                                    greater_is_better=True,
+                                    needs_threshold=True)
 roc_auc_scorer = make_scorer(roc_auc_score, greater_is_better=True,
                              needs_threshold=True)
 average_precision_scorer = make_scorer(average_precision_score,
@@ -658,7 +661,9 @@ def make_scorer(score_func, *, greater_is_better=True, needs_proba=False,
                neg_root_mean_squared_error=neg_root_mean_squared_error_scorer,
                neg_mean_poisson_deviance=neg_mean_poisson_deviance_scorer,
                neg_mean_gamma_deviance=neg_mean_gamma_deviance_scorer,
-               accuracy=accuracy_scorer, roc_auc=roc_auc_scorer,
+               accuracy=accuracy_scorer,
+               top_k_accuracy=top_k_accuracy_scorer,
+               roc_auc=roc_auc_scorer,
                roc_auc_ovr=roc_auc_ovr_scorer,
                roc_auc_ovo=roc_auc_ovo_scorer,
                roc_auc_ovr_weighted=roc_auc_ovr_weighted_scorer,
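Passing needs_threshold=True makes the resulting scorer feed the metric continuous class scores (from decision_function, falling back to predict_proba) rather than hard predict labels, which top_k_accuracy_score needs in order to rank the classes. Roughly what the scorer does by hand, sketched on a hypothetical multiclass problem (the dataset and parameters below are illustrative, not from this commit):

from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import top_k_accuracy_score

# Hypothetical 5-class toy problem, only to obtain ranked class scores.
X, y = make_classification(n_samples=200, n_classes=5, n_informative=8,
                           random_state=0)
clf = LogisticRegression(max_iter=1000).fit(X, y)

# Roughly what the threshold-based scorer does: take continuous
# per-class scores and let the metric rank them.
y_score = clf.decision_function(X)
print(top_k_accuracy_score(y, y_score, k=2))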

sklearn/metrics/tests/test_score_objects.py

Lines changed: 4 additions & 1 deletion
@@ -53,7 +53,7 @@
                       'max_error', 'neg_mean_poisson_deviance',
                       'neg_mean_gamma_deviance']

-CLF_SCORERS = ['accuracy', 'balanced_accuracy',
+CLF_SCORERS = ['accuracy', 'balanced_accuracy', 'top_k_accuracy',
               'f1', 'f1_weighted', 'f1_macro', 'f1_micro',
               'roc_auc', 'average_precision', 'precision',
               'precision_weighted', 'precision_macro', 'precision_micro',
@@ -496,6 +496,9 @@ def test_classification_scorer_sample_weight():
         if name in REGRESSION_SCORERS:
             # skip the regression scores
             continue
+        if name == 'top_k_accuracy':
+            # in the binary case k > 1 will always lead to a perfect score
+            scorer._kwargs = {'k': 1}
         if name in MULTILABEL_ONLY_SCORERS:
             target = y_ml_test
         else:
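The k=1 override in the test exists because this particular test uses binary targets: with only two classes the true label is, by construction, always among the top 2, so the default k=2 yields a perfect score regardless of the model. A standalone illustration (not from the test suite, with made-up decision values):

import numpy as np
from sklearn.metrics import top_k_accuracy_score

y_true = np.array([0, 1, 1, 0])
y_score = np.array([2.3, -1.2, 0.8, -0.7])  # decision values for class 1

# With two classes, k=2 is trivially perfect; k=1 reduces to ordinary
# thresholded accuracy on the decision values.
print(top_k_accuracy_score(y_true, y_score, k=2))  # 1.0
print(top_k_accuracy_score(y_true, y_score, k=1))  # 0.5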

0 commit comments