Remove renormalization of y_pred inside log_loss · scikit-learn/scikit-learn@788e8a4 · GitHub
[go: up one dir, main page]

Skip to content

Commit 788e8a4

Browse files
committed
Remove renormalization of y_pred inside log_loss
1 parent 906faa5 commit 788e8a4

File tree

1 file changed

+19
-2
lines changed

1 file changed

+19
-2
lines changed

sklearn/metrics/_classification.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2617,7 +2617,17 @@ def log_loss(
26172617
y_pred = check_array(
26182618
y_pred, ensure_2d=False, dtype=[np.float64, np.float32, np.float16]
26192619
)
2620-
eps = np.finfo(y_pred.dtype).eps if eps == "auto" else eps
2620+
if eps == "auto":
2621+
eps = np.finfo(y_pred.dtype).eps
2622+
else:
2623+
# TODO: Remove user defined eps in 1.4
2624+
warnings.warn(
2625+
"Setting the eps parameter is deprecated and will "
2626+
"be removed in 1.4. Instead eps will always have"
2627+
"a default value of `np.finfo(y_pred.dtype).eps`.",
2628+
FutureWarning,
2629+
)
2630+
eps = eps
26212631

26222632
check_consistent_length(y_pred, y_true, sample_weight)
26232633
lb = LabelBinarizer()
@@ -2680,7 +2690,14 @@ def log_loss(
26802690

26812691
# Renormalize
26822692
y_pred_sum = y_pred.sum(axis=1)
2683-
y_pred = y_pred / y_pred_sum[:, np.newaxis]
2693+
if y_pred_sum.any() != 1:
2694+
warnings.warn(
2695+
"The y_pred values are not normalized. Starting from 1.3 this"
2696+
"would result in an error.",
2697+
UserWarning,
2698+
)
2699+
y_pred = y_pred / y_pred_sum[:, np.newaxis]
2700+
26842701
loss = -xlogy(transformed_labels, y_pred).sum(axis=1)
26852702

26862703
return _weighted_sum(loss, sample_weight, normalize)

0 commit comments

Comments (0)