TST add test_binomial_vs_alternative_formulation · lorentzenchr/scikit-learn@64098db · GitHub

Commit 64098db

TST add test_binomial_vs_alternative_formulation
1 parent bfb76c9 commit 64098db

1 file changed

sklearn/_loss/tests/test_loss.py

Lines changed: 39 additions & 0 deletions
@@ -1,3 +1,4 @@
+from itertools import product
 import pickle
 
 import numpy as np
@@ -981,6 +982,44 @@ def test_binomial_and_multinomial_loss(global_random_seed):
     )
 
 
+def test_binomial_vs_alternative_formulation():
+    """Test that both formulations of the binomial deviance agree.
+
+    Often, the binomial deviance or log loss is written in terms of a variable
+    z in {-1, +1}, but we use y in {0, 1}, hence z = 2 * y - 1.
+    ESL II Eq. (10.18):
+
+        -loglike(z, f) = log(1 + exp(-2 * z * f))
+
+    Note:
+    - ESL 2*f = raw_prediction, hence the factor 2 of ESL disappears.
+    - Deviance = -2*loglike + .., but HalfBinomialLoss is half of the
+      deviance, hence the factor of 2 cancels in the comparison.
+    """
+
+    def alt_loss(y, raw_pred):
+        z = 2 * y - 1
+        return np.mean(np.log(1 + np.exp(-z * raw_pred)))
+
+    bin_loss = HalfBinomialLoss()
+
+    # Materialize as a list because test_data is iterated over twice below.
+    test_data = list(product(
+        (np.array([0.0, 0, 0]), np.array([1.0, 1, 1])),
+        (np.array([-5.0, -5, -5]), np.array([3.0, 3, 3])),
+    ))
+    for datum in test_data:
+        assert bin_loss(*datum) == approx(alt_loss(*datum))
+
+    # Check the gradient against the negative gradient formula from ESL II.
+    def alt_gradient(y, raw_pred):
+        z = 2 * y - 1
+        return z / (1 + np.exp(z * raw_pred))
+
+    for datum in test_data:
+        assert bin_loss.gradient(*datum) == approx(-alt_gradient(*datum))
+
+
 @pytest.mark.parametrize("loss", LOSS_INSTANCES, ids=loss_instance_name)
 def test_predict_proba(loss, global_random_seed):
     """Test that predict_proba and gradient_proba work as expected."""
