FIX float16 overflow on accumulator operations in StandardScaler (#13010) · scikit-learn/scikit-learn@1f5bcae · GitHub
Commit 1f5bcae

baluyotraf authored and rth committed
FIX float16 overflow on accumulator operations in StandardScaler (#13010)
1 parent e8045a7 commit 1f5bcae

4 files changed: +70 lines, −9 lines

doc/whats_new/v0.21.rst

Lines changed: 4 additions & 0 deletions
@@ -222,6 +222,10 @@ Support for Python 3.4 and below has been officially dropped.
   in the dense case. Also added a new parameter ``order`` which controls output
   order for further speed performances. :issue:`12251` by `Tom Dupre la Tour`_.
 
+- |Fix| Fixed the calculation overflow when using a float16 dtype with
+  :class:`preprocessing.StandardScaler`. :issue:`13007` by
+  :user:`Raffaello Baluyot <baluyotraf>`
+
 :mod:`sklearn.tree`
 ...................
 - |Feature| Decision Trees can now be plotted with matplotlib using
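For context, a minimal sketch of the scenario this entry describes, adapted from the regression test added in this commit: a float16 column whose sum far exceeds the float16 maximum (about 65504) previously overflowed inside StandardScaler, whereas with this fix the accumulation is done in float64 and the output stays finite.

    import numpy as np
    from sklearn.preprocessing import StandardScaler

    rng = np.random.RandomState(0)
    # 200000 values in [5, 10): the column sum (~1.5e6) is far above the
    # float16 maximum, so a float16 accumulator would overflow to inf.
    X = rng.uniform(5, 10, (200000, 1)).astype(np.float16)

    X_scaled = StandardScaler().fit_transform(X)
    print(np.all(np.isfinite(X_scaled)))  # True with this fix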

sklearn/preprocessing/tests/test_data.py

Lines changed: 25 additions & 0 deletions
@@ -450,6 +450,31 @@ def test_scaler_2d_arrays():
     assert X_scaled is not X
 
 
+def test_scaler_float16_overflow():
+    # Test that the scaler does not overflow on float16 numpy arrays
+    rng = np.random.RandomState(0)
+    # float16 has a maximum of about 65500.0. In the worst case 5 * 200000 is
+    # 1000000, which is more than enough to overflow the data type
+    X = rng.uniform(5, 10, [200000, 1]).astype(np.float16)
+
+    with np.errstate(over='raise'):
+        scaler = StandardScaler().fit(X)
+        X_scaled = scaler.transform(X)
+
+    # Calculate the float64 equivalent to verify the result
+    X_scaled_f64 = StandardScaler().fit_transform(X.astype(np.float64))
+
+    # Overflow calculations may cause -inf, inf, or nan. Since there is no nan
+    # input, all of the outputs should be finite. This may be redundant since a
+    # FloatingPointError exception will be thrown on overflow above.
+    assert np.all(np.isfinite(X_scaled))
+
+    # The standardized values are very unlikely to exceed 4 in magnitude. At
+    # 4.0-8.0 the float16 precision is 2**-8, which is around 0.004. Thus only
+    # 2 decimals are checked to account for precision differences.
+    assert_array_almost_equal(X_scaled, X_scaled_f64, decimal=2)
+
+
 def test_handle_zeros_in_scale():
     s1 = np.array([0, 1, 2, 3])
     s2 = _handle_zeros_in_scale(s1, copy=True)
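The underlying failure mode this test guards against is NumPy's default behaviour of accumulating in the input dtype. A small illustration (exact output can vary with the NumPy version, but the rationale in the commit is that a float16 accumulator can overflow to inf while a float64 accumulator cannot):

    import numpy as np

    x = np.full(200000, 5.0, dtype=np.float16)
    # Summing in the input dtype can exceed the float16 maximum (~65504)
    # and saturate to inf.
    print(np.nansum(x))                    # inf on affected NumPy versions
    # Forcing a float64 accumulator, as the fix does, gives the exact sum.
    print(np.nansum(x, dtype=np.float64))  # 1000000.0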

sklearn/utils/extmath.py

Lines changed: 35 additions & 7 deletions
@@ -658,6 +658,38 @@ def make_nonnegative(X, min_value=0):
     return X
 
 
+# Use at least float64 for the accumulating functions to avoid precision issues;
+# see https://github.com/numpy/numpy/issues/9393. The float64 result is also
+# retained as-is in case the original float dtype would overflow.
+def _safe_accumulator_op(op, x, *args, **kwargs):
+    """
+    This function provides numpy accumulator functions with a float64 dtype
+    when used on a floating point input. This prevents accumulator overflow on
+    smaller floating point dtypes.
+
+    Parameters
+    ----------
+    op : function
+        A numpy accumulator function such as np.mean or np.sum
+    x : numpy array
+        A numpy array to apply the accumulator function to
+    *args : positional arguments
+        Positional arguments passed to the accumulator function after the
+        input x
+    **kwargs : keyword arguments
+        Keyword arguments passed to the accumulator function
+
+    Returns
+    -------
+    result : The output of the accumulator function passed to this function
+    """
+    if np.issubdtype(x.dtype, np.floating) and x.dtype.itemsize < 8:
+        result = op(x, *args, **kwargs, dtype=np.float64)
+    else:
+        result = op(x, *args, **kwargs)
+    return result
+
+
 def _incremental_mean_and_var(X, last_mean, last_variance, last_sample_count):
     """Calculate mean update and a Youngs and Cramer variance update.
 
@@ -708,12 +740,7 @@ def _incremental_mean_and_var(X, last_mean, last_variance, last_sample_count):
     # new = the current increment
     # updated = the aggregated stats
     last_sum = last_mean * last_sample_count
-    if np.issubdtype(X.dtype, np.floating) and X.dtype.itemsize < 8:
-        # Use at least float64 for the accumulator to avoid precision issues;
-        # see https://github.com/numpy/numpy/issues/9393
-        new_sum = np.nansum(X, axis=0, dtype=np.float64).astype(X.dtype)
-    else:
-        new_sum = np.nansum(X, axis=0)
+    new_sum = _safe_accumulator_op(np.nansum, X, axis=0)
 
     new_sample_count = np.sum(~np.isnan(X), axis=0)
     updated_sample_count = last_sample_count + new_sample_count
 
@@ -723,7 +750,8 @@ def _incremental_mean_and_var(X, last_mean, last_variance, last_sample_count):
     if last_variance is None:
         updated_variance = None
     else:
-        new_unnormalized_variance = np.nanvar(X, axis=0) * new_sample_count
+        new_unnormalized_variance = (
+            _safe_accumulator_op(np.nanvar, X, axis=0) * new_sample_count)
         last_unnormalized_variance = last_variance * last_sample_count
 
     with np.errstate(divide='ignore', invalid='ignore'):
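A short usage sketch of the new helper (note that _safe_accumulator_op is private API and may change without notice): sub-64-bit floating point inputs are reduced with a float64 accumulator, while float64 and non-float inputs pass through unchanged.

    import numpy as np
    from sklearn.utils.extmath import _safe_accumulator_op

    X16 = np.full((200000, 1), 5.0, dtype=np.float16)

    # float16 input: the reduction is performed with dtype=np.float64.
    col_sum = _safe_accumulator_op(np.nansum, X16, axis=0)
    col_var = _safe_accumulator_op(np.nanvar, X16, axis=0)

    # float64 input: the call is forwarded to numpy untouched.
    col_sum64 = _safe_accumulator_op(np.nansum, X16.astype(np.float64), axis=0)

    assert np.allclose(col_sum, col_sum64)  # both equal 1000000.0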

sklearn/utils/validation.py

Lines changed: 6 additions & 2 deletions
@@ -34,14 +34,18 @@
 
 def _assert_all_finite(X, allow_nan=False):
     """Like assert_all_finite, but only for ndarray."""
+    # validation is also imported in extmath
+    from .extmath import _safe_accumulator_op
+
     if _get_config()['assume_finite']:
         return
     X = np.asanyarray(X)
     # First try an O(n) time, O(1) space solution for the common case that
     # everything is finite; fall back to O(n) space np.isfinite to prevent
-    # false positives from overflow in sum method.
+    # false positives from overflow in sum method. The sum is also calculated
+    # safely to reduce dtype induced overflows.
     is_float = X.dtype.kind in 'fc'
-    if is_float and np.isfinite(X.sum()):
+    if is_float and (np.isfinite(_safe_accumulator_op(np.sum, X))):
         pass
     elif is_float:
         msg_err = "Input contains {} or a value too large for {!r}."
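A sketch of the path this hunk hardens: _assert_all_finite first tries a cheap check that sums the array and tests the sum for finiteness. For large but perfectly finite float16 input that sum could itself overflow to inf, so the check reported non-finite values that were not there (the bug in issue 13007). Routing the sum through _safe_accumulator_op removes that false positive:

    import numpy as np
    from sklearn.utils.extmath import _safe_accumulator_op

    X = np.full((200000, 1), 5.0, dtype=np.float16)  # every element is finite

    # The old fast path could see an overflowed (non-finite) sum here.
    print(np.isfinite(X.sum()))                          # may be False
    # The safe accumulator sums in float64, so the fast path stays usable.
    print(np.isfinite(_safe_accumulator_op(np.sum, X)))  # True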

0 comments