Remove renormalization of y_pred inside log_loss · scikit-learn/scikit-learn@f3261bd

Commit f3261bd

Remove renormalization of y_pred inside log_loss
1 parent 906faa5 commit f3261bd

File tree

1 file changed (+18, -2 lines)

sklearn/metrics/_classification.py

Lines changed: 18 additions & 2 deletions
@@ -2617,7 +2617,16 @@ def log_loss(
     y_pred = check_array(
         y_pred, ensure_2d=False, dtype=[np.float64, np.float32, np.float16]
     )
-    eps = np.finfo(y_pred.dtype).eps if eps == "auto" else eps
+    if eps == "auto":
+        eps = np.finfo(y_pred.dtype).eps
+    else:
+        # TODO: Remove user defined eps in 1.4
+        warnings.warn(
+            "Setting the eps parameter is deprecated and will "
+            "be removed in 1.4. Instead eps will always have "
+            "a default value of `np.finfo(y_pred.dtype).eps`.",
+            FutureWarning,
+        )
 
     check_consistent_length(y_pred, y_true, sample_weight)
     lb = LabelBinarizer()
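
The first hunk only changes how eps is resolved: "auto" still maps to the machine epsilon of y_pred's dtype, while any user-supplied value now additionally emits a FutureWarning. Below is a minimal sketch of how this surfaces to a caller, assuming a scikit-learn build that contains this commit; the data and variable names are illustrative, not taken from the commit.

```python
import warnings

import numpy as np
from sklearn.metrics import log_loss

y_true = [0, 1, 1]
# Rows sum to exactly 1, so only the eps handling is exercised here.
y_pred = np.array([[0.75, 0.25], [0.25, 0.75], [0.5, 0.5]], dtype=np.float32)

# Default eps="auto": resolved internally to np.finfo(np.float32).eps, no warning expected.
loss_default = log_loss(y_true, y_pred)

# Explicit eps: with this commit applied, a FutureWarning should be emitted.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    loss_explicit = log_loss(y_true, y_pred, eps=1e-15)

print(loss_default, loss_explicit)
print([type(w.message).__name__ for w in caught])  # expected to include 'FutureWarning'
```

Because the automatic value is np.finfo(y_pred.dtype).eps, the clipping threshold tracks the dtype of y_pred (here float32) rather than being a fixed 1e-15.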
@@ -2680,7 +2689,14 @@ def log_loss(
 
     # Renormalize
     y_pred_sum = y_pred.sum(axis=1)
-    y_pred = y_pred / y_pred_sum[:, np.newaxis]
+    if (y_pred_sum != 1).any():
+        warnings.warn(
+            "The y_pred values are not normalized. Starting from 1.3 this "
+            "would result in an error.",
+            UserWarning,
+        )
+    y_pred = y_pred / y_pred_sum[:, np.newaxis]
+
     loss = -xlogy(transformed_labels, y_pred).sum(axis=1)
 
     return _weighted_sum(loss, sample_weight, normalize)
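
The second hunk keeps the renormalization for now, but any row of y_pred whose sum is not exactly 1 triggers a UserWarning announcing the future error. Here is a short sketch of that behaviour under the same assumption about the installed version; the probabilities are chosen as exact binary fractions so that the comparison below is not blurred by rounding.

```python
import warnings

import numpy as np
from sklearn.metrics import log_loss

y_true = [0, 1, 1]
# Each row sums to 0.5, i.e. the rows are not valid probability distributions.
y_pred_unnormalized = np.array([[0.25, 0.25], [0.125, 0.375], [0.375, 0.125]])

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    loss_a = log_loss(y_true, y_pred_unnormalized)

# With this commit the rows are still renormalized internally, so normalizing
# them up front should give the same loss, just without the warning.
y_pred_normalized = y_pred_unnormalized / y_pred_unnormalized.sum(axis=1, keepdims=True)
loss_b = log_loss(y_true, y_pred_normalized)

print(np.isclose(loss_a, loss_b))                  # expected: True
print([type(w.message).__name__ for w in caught])  # expected to include 'UserWarning'
```

Note that the check is an exact comparison, (y_pred_sum != 1).any(), so rows that sum to 1 only up to floating-point rounding would also trigger the warning.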

0 commit comments
