Merge pull request #2717 from Manoj-Kumar-S/test_log_loss · jwchennlp/scikit-learn@8f2d8b9

Commit 8f2d8b9 (2 parents: 834b375 + c9714b5)

Merge pull request scikit-learn#2717 from Manoj-Kumar-S/test_log_loss

Testing log_loss and hinge_loss under THRESHOLDED_METRICS

File tree: 2 files changed (+61, -9 lines)

sklearn/metrics/metrics.py

Lines changed: 29 additions & 3 deletions
@@ -257,6 +257,7 @@ def hinge_loss(y_true, pred_decision, pos_label=None, neg_label=None):
                   "release 0.15.", DeprecationWarning)
 
     # TODO: multi-class hinge-loss
+    y_true, pred_decision = check_arrays(y_true, pred_decision)
 
     # the rest of the code assumes that positive and negative labels
     # are encoded as +1 and -1 respectively
@@ -265,7 +266,14 @@ def hinge_loss(y_true, pred_decision, pos_label=None, neg_label=None):
     else:
         y_true = LabelBinarizer(neg_label=-1).fit_transform(y_true)[:, 0]
 
-    margin = y_true * np.asarray(pred_decision)
+    if pred_decision.ndim == 2 and pred_decision.shape[1] != 1:
+        raise ValueError("Multi-class hinge loss not supported")
+    pred_decision = np.ravel(pred_decision)
+
+    try:
+        margin = y_true * pred_decision
+    except TypeError:
+        raise ValueError("pred_decision should be an array of floats.")
     losses = 1 - margin
     # The hinge doesn't penalize good enough predictions.
     losses[losses <= 0] = 0
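
For orientation, the arithmetic these new guards protect is the plain binary hinge. A standalone numpy sketch mirroring the margin/losses lines above (made-up numbers, not a call into scikit-learn):

    import numpy as np

    # Labels already encoded as +1 / -1, decision values as floats;
    # a 2-d multi-column pred_decision would now be rejected upstream.
    y_true = np.array([-1, 1, 1, -1])
    pred_decision = np.array([-2.2, 1.3, 0.4, 0.7])

    margin = y_true * pred_decision     # positive margin = correct side
    losses = 1 - margin
    losses[losses <= 0] = 0             # good-enough predictions cost nothing
    print(np.mean(losses))              # 0.575
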
@@ -1015,10 +1023,28 @@ def log_loss(y_true, y_pred, eps=1e-15, normalize=True):
     if T.shape[1] == 1:
         T = np.append(1 - T, T, axis=1)
 
-    # Clip and renormalize
+    # Clipping
     Y = np.clip(y_pred, eps, 1 - eps)
-    Y /= Y.sum(axis=1)[:, np.newaxis]
 
+    # This happens in cases when elements in y_pred have type "str".
+    if not isinstance(Y, np.ndarray):
+        raise ValueError("y_pred should be an array of floats.")
+
+    # If y_pred is of single dimension, assume y_true to be binary
+    # and then check.
+    if Y.ndim == 1:
+        Y = Y[:, np.newaxis]
+    if Y.shape[1] == 1:
+        Y = np.append(1 - Y, Y, axis=1)
+
+    # Check if dimensions are consistent.
+    T, Y = check_arrays(T, Y)
+    if T.shape[1] != Y.shape[1]:
+        raise ValueError("y_true and y_pred have different number of classes "
+                         "%d, %d" % (T.shape[1], Y.shape[1]))
+
+    # Renormalize
+    Y /= Y.sum(axis=1)[:, np.newaxis]
     loss = -(T * np.log(Y)).sum()
     return loss / T.shape[0] if normalize else loss
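
The subtle branch above is the 1-d case: a single probability column is treated as P(class 1) and expanded to two columns before the loss is taken. A plain-numpy sketch of that path (made-up probabilities, not sklearn itself):

    import numpy as np

    p = np.array([0.1, 0.9, 0.8])             # P(class 1) for three samples
    Y = p[:, np.newaxis]                      # shape (3,) -> (3, 1)
    Y = np.append(1 - Y, Y, axis=1)           # [[0.9, 0.1], [0.1, 0.9], [0.2, 0.8]]

    T = np.array([[1, 0], [0, 1], [0, 1]])    # binarized y_true = [0, 1, 1]
    Y /= Y.sum(axis=1)[:, np.newaxis]         # no-op here: rows already sum to 1
    print(-(T * np.log(Y)).sum() / T.shape[0])  # ~0.144622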

sklearn/metrics/tests/test_metrics.py

Lines changed: 32 additions & 6 deletions
@@ -147,6 +147,8 @@
 }
 
 THRESHOLDED_METRICS = {
+    "log_loss": log_loss,
+    "hinge_loss": hinge_loss,
     "roc_auc_score": roc_auc_score,
     "weighted_roc_auc": partial(roc_auc_score, average="weighted"),
     "samples_roc_auc": partial(roc_auc_score, average="samples"),
@@ -198,7 +200,7 @@
 
 # Metrics with a "pos_label" argument
 METRICS_WITH_POS_LABEL = [
-    "roc_curve",
+    "roc_curve", "hinge_loss",
 
     "precision_score", "recall_score", "f1_score", "f2_score", "f0.5_score",
 
@@ -238,7 +240,7 @@
 # Threshold-based metrics with "multilabel-indicator" format support
 THRESHOLDED_MULTILABEL_METRICS = [
     "roc_auc_score", "weighted_roc_auc", "samples_roc_auc",
-    "micro_roc_auc", "macro_roc_auc",
+    "micro_roc_auc", "macro_roc_auc", "log_loss",
 
     "average_precision_score", "weighted_average_precision_score",
     "samples_average_precision_score", "micro_average_precision_score",
@@ -303,7 +305,7 @@
     "micro_recall_score",
 
     "macro_f0.5_score", "macro_f2_score", "macro_precision_score",
-    "macro_recall_score",
+    "macro_recall_score", "log_loss", "hinge_loss"
 ]
 
 ###############################################################################
@@ -1496,9 +1498,22 @@ def test_invariance_string_vs_numbers_labels():
                                 err_msg="{0} failed string vs number "
                                         "invariance test".format(name))
 
-    # TODO Currently not supported
-    for name, metrics in THRESHOLDED_METRICS.items():
-        assert_raises(ValueError, metrics, y1_str, y2_str)
+    for name, metric in THRESHOLDED_METRICS.items():
+        if name in ("log_loss", "hinge_loss"):
+            measure_with_number = metric(y1, y2)
+            measure_with_str = metric(y1_str, y2)
+            assert_array_equal(measure_with_number, measure_with_str,
+                               err_msg="{0} failed string vs number invariance "
+                                       "test".format(name))
+
+            measure_with_strobj = metric(y1_str.astype('O'), y2)
+            assert_array_equal(measure_with_number, measure_with_strobj,
+                               err_msg="{0} failed string object vs number "
+                                       "invariance test".format(name))
+        else:
+            # TODO: these metrics don't support string labels yet
+            assert_raises(ValueError, metric, y1_str, y2)
+            assert_raises(ValueError, metric, y1_str.astype('O'), y2)
 
 
 @ignore_warnings
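
Outside the test suite, the new invariance is easy to reproduce; the scores below are made up, and this assumes the post-merge log_loss behavior (string labels binarize the same way as their numeric counterparts):

    import numpy as np
    from sklearn.metrics import log_loss

    y_num = np.array([0, 1, 1, 0])
    y_str = np.array(["ham", "spam", "spam", "ham"])
    scores = np.array([0.1, 0.8, 0.9, 0.3])

    # String, object-dtype, and integer labels now give the same loss.
    assert np.isclose(log_loss(y_num, scores), log_loss(y_str, scores))
    assert np.isclose(log_loss(y_num, scores), log_loss(y_str.astype('O'), scores))
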
@@ -2370,6 +2385,17 @@ def test_log_loss():
     loss = log_loss(y_true, y_pred, normalize=True, eps=.1)
     assert_almost_equal(loss, log_loss(y_true, np.clip(y_pred, .1, .9)))
 
+    # Raise an error if the numbers of classes in y_true and y_pred differ.
+    y_true = [1, 0, 2]
+    y_pred = [[0.2, 0.7], [0.6, 0.5], [0.4, 0.1]]
+    assert_raises(ValueError, log_loss, y_true, y_pred)
+
+    # Case where y_true is an array of string objects.
+    y_true = ["ham", "spam", "spam", "ham"]
+    y_pred = [[0.2, 0.7], [0.6, 0.5], [0.4, 0.1], [0.7, 0.2]]
+    loss = log_loss(y_true, y_pred)
+    assert_almost_equal(loss, 1.0383217, decimal=6)
+
 
 @ignore_warnings
 def _check_averaging(metric, y_true, y_pred, y_true_binarize, y_pred_binarize,
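
The expected constant 1.0383217 in the new test can be reproduced by hand from the renormalization step in metrics.py. A plain-numpy check; the [1, 0] row for "ham" assumes LabelBinarizer's sorted class order ("ham" < "spam"):

    import numpy as np

    T = np.array([[1, 0], [0, 1], [0, 1], [1, 0]])   # "ham"/"spam" binarized
    Y = np.array([[0.2, 0.7], [0.6, 0.5], [0.4, 0.1], [0.7, 0.2]])

    Y = np.clip(Y, 1e-15, 1 - 1e-15)
    Y /= Y.sum(axis=1)[:, np.newaxis]   # row sums were 0.9, 1.1, 0.5, 0.9
    loss = -(T * np.log(Y)).sum() / T.shape[0]
    print(round(loss, 6))               # 1.038322, matching the test to 6 decimals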
