|
27 | 27 | from . import (r2_score, median_absolute_error, max_error, mean_absolute_error,
|
28 | 28 | mean_squared_error, mean_squared_log_error,
|
29 | 29 | mean_poisson_deviance, mean_gamma_deviance, accuracy_score,
|
30 |
| - f1_score, roc_auc_score, average_precision_score, |
31 |
| - precision_score, recall_score, log_loss, |
32 |
| - balanced_accuracy_score, explained_variance_score, |
| 30 | + top_k_accuracy_score, f1_score, roc_auc_score, |
| 31 | + average_precision_score, precision_score, recall_score, |
| 32 | + log_loss, balanced_accuracy_score, explained_variance_score, |
33 | 33 | brier_score_loss, jaccard_score, mean_absolute_percentage_error)
|
34 | 34 |
|
35 | 35 | from .cluster import adjusted_rand_score
|
@@ -610,6 +610,9 @@ def make_scorer(score_func, *, greater_is_better=True, needs_proba=False,
|
610 | 610 | balanced_accuracy_scorer = make_scorer(balanced_accuracy_score)
|
611 | 611 |
|
612 | 612 | # Score functions that need decision values
|
| 613 | +top_k_accuracy_scorer = make_scorer(top_k_accuracy_score, |
| 614 | + greater_is_better=True, |
| 615 | + needs_threshold=True) |
613 | 616 | roc_auc_scorer = make_scorer(roc_auc_score, greater_is_better=True,
|
614 | 617 | needs_threshold=True)
|
615 | 618 | average_precision_scorer = make_scorer(average_precision_score,
|
@@ -658,7 +661,9 @@ def make_scorer(score_func, *, greater_is_better=True, needs_proba=False,
|
658 | 661 | neg_root_mean_squared_error=neg_root_mean_squared_error_scorer,
|
659 | 662 | neg_mean_poisson_deviance=neg_mean_poisson_deviance_scorer,
|
660 | 663 | neg_mean_gamma_deviance=neg_mean_gamma_deviance_scorer,
|
661 |
| - accuracy=accuracy_scorer, roc_auc=roc_auc_scorer, |
| 664 | + accuracy=accuracy_scorer, |
| 665 | + top_k_accuracy=top_k_accuracy_scorer, |
| 666 | + roc_auc=roc_auc_scorer, |
662 | 667 | roc_auc_ovr=roc_auc_ovr_scorer,
|
663 | 668 | roc_auc_ovo=roc_auc_ovo_scorer,
|
664 | 669 | roc_auc_ovr_weighted=roc_auc_ovr_weighted_scorer,
|
|
0 commit comments