FIX divide by zero in line search of GradientBoostingClassifier (#28095) · scikit-learn/scikit-learn@f3b13e5 · GitHub
Commit f3b13e5
FIX divide by zero in line search of GradientBoostingClassifier (#28095)
1 parent 54de830 commit f3b13e5

File tree

3 files changed (+42, -7 lines)

doc/whats_new/v1.4.rst

Lines changed: 1 addition & 1 deletion
@@ -470,7 +470,7 @@ Changelog
 - |Efficiency| :class:`ensemble.GradientBoostingClassifier` is faster,
   for binary and in particular for multiclass problems thanks to the private loss
   function module.
-  :pr:`26278` by :user:`Christian Lorentzen <lorentzenchr>`.
+  :pr:`26278` and :pr:`28095` by :user:`Christian Lorentzen <lorentzenchr>`.

 - |Efficiency| Improves runtime and memory usage for
   :class:`ensemble.GradientBoostingClassifier` and

sklearn/ensemble/_gb.py

Lines changed: 11 additions & 4 deletions
@@ -65,16 +65,23 @@

 def _safe_divide(numerator, denominator):
     """Prevents overflow and division by zero."""
-    try:
+    # This is used for classifiers where the denominator might become zero exactly.
+    # For instance for log loss, HalfBinomialLoss, if proba=0 or proba=1 exactly, then
+    # denominator = hessian = 0, and we should set the node value in the line search to
+    # zero as there is no improvement of the loss possible.
+    # For numerical safety, we do this already for extremely tiny values.
+    if abs(denominator) < 1e-150:
+        return 0.0
+    else:
+        # Cast to Python float to trigger Python errors, e.g. ZeroDivisionError,
+        # without relying on `np.errstate` that is not supported by Pyodide.
+        result = float(numerator) / float(denominator)
         # Cast to Python float to trigger a ZeroDivisionError without relying
         # on `np.errstate` that is not supported by Pyodide.
         result = float(numerator) / float(denominator)
         if math.isinf(result):
             warnings.warn("overflow encountered in _safe_divide", RuntimeWarning)
         return result
-    except ZeroDivisionError:
-        warnings.warn("divide by zero encountered in _safe_divide", RuntimeWarning)
-        return 0.0


 def _init_raw_predictions(X, estimator, loss, use_predict_proba):
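
To make the behavior change concrete, here is a minimal, self-contained sketch of the patched helper and the three cases it distinguishes. It mirrors the diff above but is not the scikit-learn source; the name safe_divide_sketch is ours:

import math
import warnings


def safe_divide_sketch(numerator, denominator):
    """Mirror of the patched _safe_divide."""
    # A (near-)zero denominator means the hessian is zero and the line search
    # cannot improve the loss, so return a zero node value without warning.
    if abs(denominator) < 1e-150:
        return 0.0
    # Python float division returns inf on overflow rather than raising,
    # hence the explicit isinf check below.
    result = float(numerator) / float(denominator)
    if math.isinf(result):
        warnings.warn("overflow encountered in _safe_divide", RuntimeWarning)
    return result


print(safe_divide_sketch(1e300, 0.0))  # 0.0, silent (warned "divide by zero" before the fix)
print(safe_divide_sketch(0.0, 0.0))    # 0.0, silent
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    print(safe_divide_sketch(1e300, 1e-10))  # inf
print(len(caught))  # 1 -- the overflow RuntimeWarning

The key behavioral change is the first branch: the old try/except only caught an exact-zero denominator and still emitted a RuntimeWarning, while the new guard also covers extremely tiny denominators (below 1e-150) and stays silent, a zero step being the correct line-search result there.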

sklearn/ensemble/tests/test_gradient_boosting.py

Lines changed: 30 additions & 2 deletions
@@ -1452,9 +1452,9 @@ def test_huber_vs_mean_and_median():

 def test_safe_divide():
     """Test that _safe_divide handles division by zero."""
-    with pytest.warns(RuntimeWarning, match="divide"):
+    with warnings.catch_warnings():
+        warnings.simplefilter("error")
         assert _safe_divide(np.float64(1e300), 0) == 0
-    with pytest.warns(RuntimeWarning, match="divide"):
         assert _safe_divide(np.float64(0.0), np.float64(0.0)) == 0
     with pytest.warns(RuntimeWarning, match="overflow"):
         # np.finfo(float).max = 1.7976931348623157e+308
@@ -1680,3 +1680,31 @@ def test_multinomial_error_exact_backward_compat():
         ]
     )
     assert_allclose(gbt.train_score_[-10:], train_score, rtol=1e-8)
+
+
+def test_gb_denominator_zero(global_random_seed):
+    """Test _update_terminal_regions denominator is not zero.
+
+    For instance for log loss based binary classification, the line search step
+    might become nan/inf as denominator = hessian = prob * (1 - prob) and
+    prob = 0 or 1 can happen.
+    Here, we create a situation where this happens (at least with roughly 80%
+    probability) based on the random seed.
+    """
+    X, y = datasets.make_hastie_10_2(n_samples=100, random_state=20)
+
+    params = {
+        "learning_rate": 1.0,
+        "subsample": 0.5,
+        "n_estimators": 100,
+        "max_leaf_nodes": 4,
+        "max_depth": None,
+        "random_state": global_random_seed,
+        "min_samples_leaf": 2,
+    }
+
+    clf = GradientBoostingClassifier(**params)
+    # Without the guard in _safe_divide, fitting would emit a RuntimeWarning.
+    with warnings.catch_warnings():
+        warnings.simplefilter("error")
+        clf.fit(X, y)
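
For context on what the new test provokes: as its docstring notes, for binary log loss the line-search denominator is the hessian prob * (1 - prob), which underflows to exactly zero once the predicted probability saturates in float64. A standalone numeric illustration (not scikit-learn internals):

import numpy as np

# Hessian of binary log loss w.r.t. the raw (log-odds) prediction: p * (1 - p).
# Once p rounds to exactly 1.0 in float64, the hessian is exactly 0.0, and the
# per-leaf Newton/line-search step would divide by zero.
raw = np.array([0.0, 10.0, 40.0])  # raw leaf predictions (log-odds)
p = 1.0 / (1.0 + np.exp(-raw))     # sigmoid
print(p)              # [0.5, ~0.99995, 1.0] -- the last entry saturates to exactly 1.0
print(p * (1.0 - p))  # [0.25, ~4.5e-05, 0.0] -- a saturated leaf yields a zero denominator

The test then relies on warnings.simplefilter("error") to turn any RuntimeWarning raised inside fit into a test failure; with the guard in _safe_divide, saturated leaves simply receive a zero update and fitting completes silently.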

0 commit comments