MAINT split sklearn/metrics/metrics.py · scikit-learn/scikit-learn@6d881d3 · GitHub

Commit 6d881d3

MAINT split sklearn/metrics/metrics.py
1 parent ce0484f commit 6d881d3

File tree

12 files changed: +4514 -4327 lines


sklearn/metrics/__init__.py

Lines changed: 92 additions & 88 deletions
@@ -3,97 +3,101 @@
 and pairwise metrics and distance computations.
 """

-from .metrics import (accuracy_score,
-                      average_precision_score,
-                      auc,
-                      roc_auc_score,
-                      classification_report,
-                      confusion_matrix,
-                      explained_variance_score,
-                      f1_score,
-                      fbeta_score,
-                      hamming_loss,
-                      hinge_loss,
-                      jaccard_similarity_score,
-                      label_ranking_average_precision_score,
-                      log_loss,
-                      matthews_corrcoef,
-                      mean_squared_error,
-                      mean_absolute_error,
-                      precision_recall_curve,
-                      precision_recall_fscore_support,
-                      precision_score,
-                      recall_score,
-                      r2_score,
-                      roc_curve,
-                      zero_one_loss)
+from .ranking import auc
+from .ranking import average_precision_score
+from .ranking import label_ranking_average_precision_score
+from .ranking import log_loss
+from .ranking import precision_recall_curve
+from .ranking import roc_auc_score
+from .ranking import roc_curve
+from .ranking import hinge_loss

+from .classification import accuracy_score
+from .classification import classification_report
+from .classification import confusion_matrix
+from .classification import f1_score
+from .classification import fbeta_score
+from .classification import hamming_loss
+from .classification import jaccard_similarity_score
+from .classification import matthews_corrcoef
+from .classification import precision_recall_fscore_support
+from .classification import precision_score
+from .classification import recall_score
+from .classification import zero_one_loss

-# Deprecated in 0.16
-from .metrics import auc_score
+from . import cluster
+from .cluster import adjusted_mutual_info_score
+from .cluster import adjusted_rand_score
+from .cluster import completeness_score
+from .cluster import consensus_score
+from .cluster import homogeneity_completeness_v_measure
+from .cluster import homogeneity_score
+from .cluster import mutual_info_score
+from .cluster import normalized_mutual_info_score
+from .cluster import silhouette_samples
+from .cluster import silhouette_score
+from .cluster import v_measure_score

-from .scorer import make_scorer, SCORERS
+from .pairwise import euclidean_distances
+from .pairwise import pairwise_distances
+from .pairwise import pairwise_distances_argmin
+from .pairwise import pairwise_distances_argmin_min
+from .pairwise import pairwise_kernels

-from . import cluster
-from .cluster import (adjusted_rand_score,
-                      adjusted_mutual_info_score,
-                      completeness_score,
-                      homogeneity_completeness_v_measure,
-                      homogeneity_score,
-                      mutual_info_score,
-                      normalized_mutual_info_score,
-                      silhouette_score,
-                      silhouette_samples,
-                      v_measure_score,
-                      consensus_score)
+from .regression import explained_variance_score
+from .regression import mean_absolute_error
+from .regression import mean_squared_error
+from .regression import r2_score

-from .pairwise import (euclidean_distances,
-                       pairwise_distances,
-                       pairwise_distances_argmin_min,
-                       pairwise_distances_argmin,
-                       pairwise_kernels)
+from .scorer import make_scorer
+from .scorer import SCORERS
+
+# Deprecated in 0.16
+from .ranking import auc_score

-__all__ = ['accuracy_score',
-           'adjusted_mutual_info_score',
-           'adjusted_rand_score',
-           'auc',
-           'roc_auc_score',
-           'average_precision_score',
-           'classification_report',
-           'cluster',
-           'completeness_score',
-           'confusion_matrix',
-           'euclidean_distances',
-           'pairwise_distances_argmin_min',
-           'explained_variance_score',
-           'f1_score',
-           'fbeta_score',
-           'hamming_loss',
-           'hinge_loss',
-           'homogeneity_completeness_v_measure',
-           'homogeneity_score',
-           'jaccard_similarity_score',
-           'label_ranking_average_precision_score',
-           'log_loss',
-           'matthews_corrcoef',
-           'mean_squared_error',
-           'mean_absolute_error',
-           'mutual_info_score',
-           'normalized_mutual_info_score',
-           'pairwise_distances',
-           'pairwise_distances_argmin',
-           'pairwise_distances_argmin_min',
-           'pairwise_kernels',
-           'precision_recall_curve',
-           'precision_recall_fscore_support',
-           'precision_score',
-           'r2_score',
-           'recall_score',
-           'roc_curve',
-           'silhouette_score',
-           'silhouette_samples',
-           'v_measure_score',
-           'consensus_score',
-           'zero_one_loss',
-           'make_scorer',
-           'SCORERS']
+__all__ = [
+    'accuracy_score',
+    'adjusted_mutual_info_score',
+    'adjusted_rand_score',
+    'auc',
+    'average_precision_score',
+    'classification_report',
+    'cluster',
+    'completeness_score',
+    'confusion_matrix',
+    'consensus_score',
+    'euclidean_distances',
+    'explained_variance_score',
+    'f1_score',
+    'fbeta_score',
+    'hamming_loss',
+    'hinge_loss',
+    'homogeneity_completeness_v_measure',
+    'homogeneity_score',
+    'jaccard_similarity_score',
+    'label_ranking_average_precision_score',
+    'log_loss',
+    'make_scorer',
+    'matthews_corrcoef',
+    'mean_absolute_error',
+    'mean_squared_error',
+    'mutual_info_score',
+    'normalized_mutual_info_score',
+    'pairwise_distances',
+    'pairwise_distances_argmin',
+    'pairwise_distances_argmin_min',
+    'pairwise_distances_argmin_min',
+    'pairwise_kernels',
+    'precision_recall_curve',
+    'precision_recall_fscore_support',
+    'precision_score',
+    'r2_score',
+    'recall_score',
+    'roc_auc_score',
+    'roc_curve',
+    'SCORERS',
+    'silhouette_samples',
+    'silhouette_score',
+    'v_measure_score',
+    'zero_one_loss',
+]
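
The re-exports above keep the public surface of sklearn.metrics unchanged: user code keeps importing metric functions from the package, while the implementations now live in ranking.py, classification.py, cluster/, pairwise.py, regression.py and scorer.py. A quick sanity check of that, as a hypothetical snippet that is not part of the commit:

    # Not part of the commit: package-level imports still resolve after the split.
    from sklearn.metrics import roc_auc_score, f1_score, mean_squared_error, SCORERS

    print(roc_auc_score([0, 1, 1, 0], [0.1, 0.8, 0.7, 0.3]))   # now defined in metrics/ranking.py
    print(f1_score([0, 1, 1, 0], [0, 1, 1, 1]))                # now defined in metrics/classification.py
    print(mean_squared_error([1.0, 2.0], [1.1, 1.9]))          # now defined in metrics/regression.py
    print(sorted(SCORERS))                                     # scorer registry, re-exported from metrics/scorer.py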

sklearn/metrics/base.py

Lines changed: 123 additions & 0 deletions
@@ -0,0 +1,123 @@
+"""
+Common code for all metrics
+
+"""
+# Authors: Alexandre Gramfort <alexandre.gramfort@inria.fr>
+#          Mathieu Blondel <mathieu@mblondel.org>
+#          Olivier Grisel <olivier.grisel@ensta.org>
+#          Arnaud Joly <a.joly@ulg.ac.be>
+#          Jochen Wersdorfer <jochen@wersdoerfer.de>
+#          Lars Buitinck <L.J.Buitinck@uva.nl>
+#          Joel Nothman <joel.nothman@gmail.com>
+#          Noel Dawe <noel@dawe.me>
+# License: BSD 3 clause
+
+from __future__ import division
+
+import numpy as np
+
+from ..utils import check_arrays
+from ..utils.multiclass import type_of_target
+
+
+class UndefinedMetricWarning(UserWarning):
+    pass
+
+
+def _average_binary_score(binary_metric, y_true, y_score, average,
+                          sample_weight=None):
+    """Average a binary metric for multilabel classification
+
+    Parameters
+    ----------
+    y_true : array, shape = [n_samples] or [n_samples, n_classes]
+        True binary labels in binary label indicators.
+
+    y_score : array, shape = [n_samples] or [n_samples, n_classes]
+        Target scores, can either be probability estimates of the positive
+        class, confidence values, or binary decisions.
+
+    average : string, [None, 'micro', 'macro' (default), 'samples', 'weighted']
+        If ``None``, the scores for each class are returned. Otherwise,
+        this determines the type of averaging performed on the data:
+
+        ``'micro'``:
+            Calculate metrics globally by considering each element of the label
+            indicator matrix as a label.
+        ``'macro'``:
+            Calculate metrics for each label, and find their unweighted
+            mean. This does not take label imbalance into account.
+        ``'weighted'``:
+            Calculate metrics for each label, and find their average, weighted
+            by support (the number of true instances for each label).
+        ``'samples'``:
+            Calculate metrics for each instance, and find their average.
+
+    sample_weight : array-like of shape = [n_samples], optional
+        Sample weights.
+
+    Return
+    ------
+    score : float or array of shape [n_classes]
+        If not ``None``, average the score, else return the score for each
+        classes.
+
+    """
+    average_options = (None, 'micro', 'macro', 'weighted', 'samples')
+    if average not in average_options:
+        raise ValueError('average has to be one of {0}'
+                         ''.format(average_options))
+
+    y_type = type_of_target(y_true)
+    if y_type not in ("binary", "multilabel-indicator"):
+        raise ValueError("{0} format is not supported".format(y_type))
+
+    if y_type == "binary":
+        return binary_metric(y_true, y_score, sample_weight=sample_weight)
+
+    y_true, y_score = check_arrays(y_true, y_score)
+
+    not_average_axis = 1
+    score_weight = sample_weight
+    average_weight = None
+
+    if average == "micro":
+        if score_weight is not None:
+            score_weight = np.repeat(score_weight, y_true.shape[1])
+        y_true = y_true.ravel()
+        y_score = y_score.ravel()
+
+    elif average == 'weighted':
+        if score_weight is not None:
+            average_weight = np.sum(np.multiply(
+                y_true, np.reshape(score_weight, (-1, 1))), axis=0)
+        else:
+            average_weight = np.sum(y_true, axis=0)
+        if average_weight.sum() == 0:
+            return 0
+
+    elif average == 'samples':
+        # swap average_weight <-> score_weight
+        average_weight = score_weight
+        score_weight = None
+        not_average_axis = 0
+
+    if y_true.ndim == 1:
+        y_true = y_true.reshape((-1, 1))
+
+    if y_score.ndim == 1:
+        y_score = y_score.reshape((-1, 1))
+
+    n_classes = y_score.shape[not_average_axis]
+    score = np.zeros((n_classes,))
+    for c in range(n_classes):
+        y_true_c = y_true.take([c], axis=not_average_axis).ravel()
+        y_score_c = y_score.take([c], axis=not_average_axis).ravel()
+        score[c] = binary_metric(y_true_c, y_score_c,
+                                 sample_weight=score_weight)
+
+    # Average the results
+    if average is not None:
+        return np.average(score, weights=average_weight)
+    else:
+        return score
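
For context, a hypothetical sketch (not part of the commit) of how _average_binary_score is meant to be driven: the caller supplies any binary metric with the signature binary_metric(y_true, y_score, sample_weight=None), and the helper handles the per-column iteration and the averaging over a multilabel indicator matrix. The binary_accuracy function below is a made-up toy metric used only for illustration:

    import numpy as np
    from sklearn.metrics.base import _average_binary_score

    def binary_accuracy(y_true, y_score, sample_weight=None):
        # Toy binary metric: fraction of correct predictions at a 0.5 threshold.
        y_pred = (np.asarray(y_score) >= 0.5).astype(int)
        return np.average(np.asarray(y_true) == y_pred, weights=sample_weight)

    # A 3-sample, 3-label indicator matrix with per-label scores.
    y_true = np.array([[1, 0, 1],
                       [0, 1, 1],
                       [1, 1, 0]])
    y_score = np.array([[0.9, 0.2, 0.8],
                        [0.1, 0.7, 0.6],
                        [0.8, 0.4, 0.3]])

    per_label = _average_binary_score(binary_accuracy, y_true, y_score, average=None)
    macro = _average_binary_score(binary_accuracy, y_true, y_score, average='macro')
    weighted = _average_binary_score(binary_accuracy, y_true, y_score, average='weighted')

Here per_label is an array of one score per column, macro is their unweighted mean, and weighted weights each label's score by its support (the column sums of y_true).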

0 commit comments