10000 TST: Test class variance and string input · jwchennlp/scikit-learn@86b7051 · GitHub
[go: up one dir, main page]

Skip to content

Commit 86b7051

Browse files
committed
TST: Test class variance and string input
1 parent 10e384c commit 86b7051

File tree

1 file changed

+11
-0
lines changed

1 file changed

+11
-0
lines changed

sklearn/metrics/tests/test_metrics.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2372,6 +2372,17 @@ def test_log_loss():
23722372
loss = log_loss(y_true, y_pred, normalize=True, eps=.1)
23732373
assert_almost_equal(loss, log_loss(y_true, np.clip(y_pred, .1, .9)))
23742374

2375+
# raise error if number of classes are not equal.
2376+
y_true = [1, 0, 2]
2377+
y_pred = [[0.2, 0.7], [0.6, 0.5], [0.4, 0.1]]
2378+
assert_raises(ValueError, log_loss, y_true, y_pred)
2379+
2380+
# case when y_true is a string array object
2381+
y_true = ["ham", "spam", "spam", "ham"]
2382+
y_pred = [[0.2, 0.7], [0.6, 0.5], [0.4, 0.1], [0.7, 0.2]]
2383+
loss = log_loss(y_true, y_pred)
2384+
assert_almost_equal(loss, 1.0383217, decimal=6)
2385+
23752386

23762387
@ignore_warnings
23772388
def _check_averaging(metric, y_true, y_pred, y_true_binarize, y_pred_binarize,

0 commit comments

Comments
 (0)
0