147 | 147 | }
148 | 148 |
149 | 149 | THRESHOLDED_METRICS = {
    | 150 | +     "log_loss": log_loss,
    | 151 | +     "hinge_loss": hinge_loss,
150 | 152 |     "roc_auc_score": roc_auc_score,
151 | 153 |     "weighted_roc_auc": partial(roc_auc_score, average="weighted"),
152 | 154 |     "samples_roc_auc": partial(roc_auc_score, average="samples"),
198 | 200 |
199 | 201 | # Metrics with a "pos_label" argument
200 | 202 | METRICS_WITH_POS_LABEL = [
201 |     | -     "roc_curve",
    | 203 | +     "roc_curve", "hinge_loss",
202 | 204 |
203 | 205 |     "precision_score", "recall_score", "f1_score", "f2_score", "f0.5_score",
204 | 206 |
238 | 240 | # Threshold-based metrics with "multilabel-indicator" format support
239 | 241 | THRESHOLDED_MULTILABEL_METRICS = [
240 | 242 |     "roc_auc_score", "weighted_roc_auc", "samples_roc_auc",
241 |     | -     "micro_roc_auc", "macro_roc_auc",
    | 243 | +     "micro_roc_auc", "macro_roc_auc", "log_loss",
242 | 244 |
243 | 245 |     "average_precision_score", "weighted_average_precision_score",
244 | 246 |     "samples_average_precision_score", "micro_average_precision_score",
303 | 305 |     "micro_recall_score",
304 | 306 |
305 | 307 |     "macro_f0.5_score", "macro_f2_score", "macro_precision_score",
306 |     | -     "macro_recall_score",
    | 308 | +     "macro_recall_score", "log_loss", "hinge_loss"
307 | 309 | ]
308 | 310 |
309 | 311 | ###############################################################################
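For context, an editor's sketch (not part of this diff) of what registering a metric in THRESHOLDED_METRICS implies: each entry maps a display name to a callable invoked as metric(y_true, y_score), with functools.partial pinning keyword arguments, and the common tests iterate these name/callable pairs. The registry below is a trimmed stand-in, not the full dictionary from the file:

    # Trimmed stand-in for the THRESHOLDED_METRICS registry shown in the hunk above.
    from functools import partial

    from sklearn.metrics import hinge_loss, log_loss, roc_auc_score

    registry = {
        "log_loss": log_loss,
        "hinge_loss": hinge_loss,
        "weighted_roc_auc": partial(roc_auc_score, average="weighted"),
    }

    y_true = [0, 1, 1, 0]
    y_score = [0.1, 0.8, 0.65, 0.3]     # continuous scores, not hard predictions

    for name, metric in registry.items():
        # every registered callable accepts (y_true, y_score) positionally
        print(name, metric(y_true, y_score))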
@@ -1496,9 +1498,22 @@ def test_invariance_string_vs_numbers_labels():
1496 | 1498 |                            err_msg="{0} failed string vs number "
1497 | 1499 |                                    "invariance test".format(name))
1498 | 1500 |
1499 |      | -     # TODO Currently not supported
1500 |      | -     for name, metrics in THRESHOLDED_METRICS.items():
1501 |      | -         assert_raises(ValueError, metrics, y1_str, y2_str)
     | 1501 | +     for name, metric in THRESHOLDED_METRICS.items():
     | 1502 | +         if name in ("log_loss", "hinge_loss"):
     | 1503 | +             measure_with_number = metric(y1, y2)
     | 1504 | +             measure_with_str = metric(y1_str, y2)
     | 1505 | +             assert_array_equal(measure_with_number, measure_with_str,
     | 1506 | +                                err_msg="{0} failed string vs number invariance "
     | 1507 | +                                        "test".format(name))
     | 1508 | +
     | 1509 | +             measure_with_strobj = metric(y1_str.astype('O'), y2)
     | 1510 | +             assert_array_equal(measure_with_number, measure_with_strobj,
     | 1511 | +                                err_msg="{0} failed string object vs number "
     | 1512 | +                                        "invariance test".format(name))
     | 1513 | +         else:
     | 1514 | +             # TODO: these metrics do not support string labels yet
     | 1515 | +             assert_raises(ValueError, metric, y1_str, y2)
     | 1516 | +             assert_raises(ValueError, metric, y1_str.astype('O'), y2)
1502 | 1517 |
1503 | 1518 |
1504 | 1519 | @ignore_warnings
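The loop above replaces the blanket "not supported" assertion: log_loss and hinge_loss must now give the same result for numeric and string class labels (including object-dtype string arrays), while every other thresholded metric is still expected to raise ValueError. A standalone sketch of that contract, using illustrative values rather than the y1/y2 fixtures from the test:

    # Sketch of the behaviour asserted by the loop above (illustrative data only).
    import numpy as np
    from sklearn.metrics import hinge_loss, log_loss, roc_auc_score

    y_num = np.array([1, 0, 0, 1])
    y_str = np.array(["spam", "ham", "ham", "spam"])   # classes sort as ham < spam
    scores = np.array([0.9, 0.2, 0.3, 0.8])            # score for the "spam" class

    assert np.isclose(log_loss(y_num, scores), log_loss(y_str, scores))
    assert np.isclose(hinge_loss(y_num, scores), hinge_loss(y_str, scores))

    try:
        roc_auc_score(y_str, scores)     # string labels still raise for this metric
    except ValueError as exc:
        print("roc_auc_score:", exc)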
@@ -2370,6 +2385,17 @@ def test_log_loss():
2370 | 2385 |     loss = log_loss(y_true, y_pred, normalize=True, eps=.1)
2371 | 2386 |     assert_almost_equal(loss, log_loss(y_true, np.clip(y_pred, .1, .9)))
2372 | 2387 |
     | 2388 | +     # raise an error if y_true and y_pred disagree on the number of classes
     | 2389 | +     y_true = [1, 0, 2]
     | 2390 | +     y_pred = [[0.2, 0.7], [0.6, 0.5], [0.4, 0.1]]
     | 2391 | +     assert_raises(ValueError, log_loss, y_true, y_pred)
     | 2392 | +
     | 2393 | +     # y_true given as an array of string labels
     | 2394 | +     y_true = ["ham", "spam", "spam", "ham"]
     | 2395 | +     y_pred = [[0.2, 0.7], [0.6, 0.5], [0.4, 0.1], [0.7, 0.2]]
     | 2396 | +     loss = log_loss(y_true, y_pred)
     | 2397 | +     assert_almost_equal(loss, 1.0383217, decimal=6)
     | 2398 | +
2373 | 2399 |
2374 | 2400 | @ignore_warnings
2375 | 2401 | def _check_averaging(metric, y_true, y_pred, y_true_binarize, y_pred_binarize,
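A hand check of the new expected value 1.0383217 (editor's sketch, not part of the diff): with string targets the classes sort alphabetically, so column 0 of y_pred is read as "ham" and column 1 as "spam"; and since the rows of y_pred do not sum to one, the metric (in the version this test targets) renormalizes them row-wise before taking the log, which is what reproduces the asserted number:

    # Reproduce the expected log loss by hand for the string-label case above.
    import numpy as np

    y_true = ["ham", "spam", "spam", "ham"]
    y_pred = np.array([[0.2, 0.7], [0.6, 0.5], [0.4, 0.1], [0.7, 0.2]], dtype=float)

    y_pred /= y_pred.sum(axis=1, keepdims=True)   # rows renormalized to sum to 1
    true_col = np.array([0, 1, 1, 0])             # "ham" -> column 0, "spam" -> column 1
    loss = -np.log(y_pred[np.arange(len(y_true)), true_col]).mean()
    print(loss)                                   # ~1.0383218, i.e. 1.0383217 at 6 decimals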