8000 FIX _safe_divide should handle zero-division with numpy scalar (#27312) · scikit-learn/scikit-learn@bbc73cf · GitHub
[go: up one dir, main page]

Skip to content

Commit bbc73cf

Browse files
FIX _safe_divide should handle zero-division with numpy scalar (#27312)
Co-authored-by: Christian Lorentzen <lorentzen.ch@gmail.com>
1 parent c634b8a commit bbc73cf

File tree

2 files changed

+19
-9
lines changed

2 files changed

+19
-9
lines changed

sklearn/ensemble/_gb.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
# Arnaud Joly, Jacob Schreiber
2121
# License: BSD 3 clause
2222

23+
import math
2324
import warnings
2425
from abc import ABCMeta, abstractmethod
2526
from numbers import Integral, Real
@@ -64,11 +65,16 @@
6465

6566
def _safe_divide(numerator, denominator):
6667
"""Prevents overflow and division by zero."""
67-
with np.errstate(divide="raise"):
68-
try:
69-
return numerator / denominator
70-
except FloatingPointError:
71-
return 0.0
68+
try:
69+
# Cast to Python float to trigger a ZeroDivisionError without relying
70+
# on `np.errstate` that is not supported by Pyodide.
71+
result = float(numerator) / float(denominator)
72+
if math.isinf(result):
73+
warnings.warn("overflow encountered in _safe_divide", RuntimeWarning)
74+
return result
75+
except ZeroDivisionError:
76+
warnings.warn("divide by zero encountered in _safe_divide", RuntimeWarning)
77+
return 0.0
7278

7379

7480
def _init_raw_predictions(X, estimator, loss, use_predict_proba):
@@ -235,7 +241,9 @@ def compute_update(y_, indices, neg_gradient, raw_prediction, k):
235241

236242
# update each leaf (= perform line search)
237243
for leaf in np.nonzero(tree.children_left == TREE_LEAF)[0]:
238-
indices = np.nonzero(terminal_regions == leaf)[0] # of terminal regions
244+
indices = np.nonzero(masked_terminal_regions == leaf)[
245+
0
246+
] # of terminal regions
239247
y_ = y.take(indices, axis=0)
240248
sw = None if sample_weight is None else sample_weight[indices]
241249
update = compute_update(y_, indices, neg_gradient, raw_prediction, k)

sklearn/ensemble/tests/test_gradient_boosting.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1445,11 +1445,13 @@ def test_huber_vs_mean_and_median():
14451445

14461446
def test_safe_divide():
14471447
"""Test that _safe_divide handles division by zero."""
1448-
assert _safe_divide(np.array([1e300]), 0) == 0
1449-
1448+
with pytest.warns(RuntimeWarning, match="divide"):
1449+
assert _safe_divide(np.float64(1e300), 0) == 0
1450+
with pytest.warns(RuntimeWarning, match="divide"):
1451+
assert _safe_divide(np.float64(0.0), np.float64(0.0)) == 0
14501452
with pytest.warns(RuntimeWarning, match="overflow"):
14511453
# np.finfo(float).max = 1.7976931348623157e+308
1452-
_safe_divide(np.array([1e300]), 1e-10)
1454+
_safe_divide(np.float64(1e300), 1e-10)
14531455

14541456

14551457
def test_squared_error_exact_backward_compat():

0 commit comments

Comments
 (0)
0