8000 FIX float16 overflow on accumulator operations in StandardScaler (#13… · thomasjpfan/scikit-learn@8ebf67d · GitHub
[go: up one dir, main page]

Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 8ebf67d

Browse files
baluyotrafthomasjpfan
authored andcommitted
FIX float16 overflow on accumulator operations in StandardScaler (scikit-learn#13010)
1 parent eadc983 commit 8ebf67d

File tree

4 files changed

+70
-9
lines changed

4 files changed

+70
-9
lines changed

doc/whats_new/v0.21.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,10 @@ Support for Python 3.4 and below has been officially dropped.
227227
in the dense case. Also added a new parameter ``order`` which controls output
228228
order for further speed performances. :issue:`12251` by `Tom Dupre la Tour`_.
229229

230+
- |Fix| Fixed the calculation overflow when using a float16 dtype with
231+
:class:`preprocessing.StandardScaler`. :issue:`13007` by
232+
:user:`Raffaello Baluyot <baluyotraf>`
233+
230234
:mod:`sklearn.tree`
231235
...................
232236
- |Feature| Decision Trees can now be plotted with matplotlib using

sklearn/preprocessing/tests/test_data.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -450,6 +450,31 @@ def test_scaler_2d_arrays():
450450
assert X_scaled is not X
451451

452452

453+
def test_scaler_float16_overflow():
454+
# Test if the scaler will not overflow on float16 numpy arrays
455+
rng = np.random.RandomState(0)
456+
# float16 has a maximum of 65500.0. On the worst case 5 * 200000 is 100000
457+
# which is enough to overflow the data type
458+
X = rng.uniform(5, 10, [200000, 1]).astype(np.float16)
459+
460+
with np.errstate(over='raise'):
461+
scaler = StandardScaler().fit(X)
462+
X_scaled = scaler.transform(X)
463+
464+
# Calculate the float64 equivalent to verify result
465+
X_scaled_f64 = StandardScaler().fit_transform(X.astype(np.float64))
466+
467+
# Overflow calculations may cause -inf, inf, or nan. Since there is no nan
468+
# input, all of the outputs should be finite. This may be redundant since a
469+
# FloatingPointError exception will be thrown on overflow above.
470+
assert np.all(np.isfinite(X_scaled))
471+
472+
# The normal distribution is very unlikely to go above 4. At 4.0-8.0 the
473+
# float16 precision is 2^-8 which is around 0.004. Thus only 2 decimals are
474+
# checked to account for precision differences.
475+
assert_array_almost_equal(X_scaled, X_scaled_f64, decimal=2)
476+
477+
453478
def test_handle_zeros_in_scale():
454479
s1 = np.array([0, 1, 2, 3])
455480
s2 = _handle_zeros_in_scale(s1, copy=True)

sklearn/utils/extmath.py

Lines changed: 35 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -658,6 +658,38 @@ def make_nonnegative(X, min_value=0):
658658
return X
659659

660660

661+
# Use at least float64 for the accumulating functions to avoid precision issue
662+
# see https://github.com/numpy/numpy/issues/9393. The float64 is also retained
663+
# as it is in case the float overflows
664+
def _safe_accumulator_op(op, x, *args, **kwargs):
665+
"""
666+
This function provides numpy accumulator functions with a float64 dtype
667+
when used on a floating point input. This prevents accumulator overflow on
668+
smaller floating point dtypes.
669+
670+
Parameters
671+
----------
672+
op : function
673+
A numpy accumulator function such as np.mean or np.sum
674+
x : numpy array
675+
A numpy array to apply the accumulator function
676+
*args : positional arguments
677+
Positional arguments passed to the accumulator function after the
678+
input x
679+
**kwargs : keyword arguments
680+
Keyword arguments passed to the accumulator function
681+
682+
Returns
683+
-------
684+
result : The output of the accumulator function passed to this function
685+
"""
686+
if np.issubdtype(x.dtype, np.floating) and x.dtype.itemsize < 8:
687+
result = op(x, *args, **kwargs, dtype=np.float64)
688+
else:
689+
result = op(x, *args, **kwargs)
690+
return result
691+
692+
661693
def _incremental_mean_and_var(X, last_mean, last_variance, last_sample_count):
662694
"""Calculate mean update and a Youngs and Cramer variance update.
663695
@@ -708,12 +740,7 @@ def _incremental_mean_and_var(X, last_mean, last_variance, last_sample_count):
708740
# new = the current increment
709741
# updated = the aggregated stats
710742
last_sum = last_mean * last_sample_count
711-
if np.issubdtype(X.dtype, np.floating) and X.dtype.itemsize < 8:
712-
# Use at least float64 for the accumulator to avoid precision issues;
713-
# see https://github.com/numpy/numpy/issues/9393
714-
new_sum = np.nansum(X, axis=0, dtype=np.float64).astype(X.dtype)
715-
else:
716-
new_sum = np.nansum(X, axis=0)
743+
new_sum = _safe_accumulator_op(np.nansum, X, axis=0)
717744

718745
new_sample_count = np.sum(~np.isnan(X), axis=0)
719746
updated_sample_count = last_sample_count + new_sample_count
@@ -723,7 +750,8 @@ def _incremental_mean_and_var(X, last_mean, last_variance, last_sample_count):
723750
if last_variance is None:
724751
updated_variance = None
725752
else:
726-
new_unnormalized_variance = np.nanvar(X, axis=0) * new_sample_count
753+
new_unnormalized_variance = (
754+
_safe_accumulator_op(np.nanvar, X, axis=0) * new_sample_count)
727755
last_unnormalized_variance = last_variance * last_sample_count
728756

729757
with np.errstate(divide='ignore', invalid='ignore'):

sklearn/utils/validation.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,14 +34,18 @@
3434

3535
def _assert_all_finite(X, allow_nan=False):
3636
"""Like assert_all_finite, but only for ndarray."""
37+
# validation is also imported in extmath
38+
from .extmath import _safe_accumulator_op
39+
3740
if _get_config()['assume_finite']:
3841
return
3942
X = np.asanyarray(X)
4043
# First try an O(n) time, O(1) space solution for the common case that
4144
# everything is finite; fall back to O(n) space np.isfinite to prevent
42-
# false positives from overflow in sum method.
45+
# false positives from overflow in sum method. The sum is also calculated
46+
# safely to reduce dtype induced overflows.
4347
is_float = X.dtype.kind in 'fc'
44-
if is_float and np.isfinite(X.sum()):
48+
if is_float and (np.isfinite(_safe_accumulator_op(np.sum, X))):
4549
pass
4650
elif is_float:
4751
msg_err = "Input contains {} or a value too large for {!r}."

0 commit comments

Comments
 (0)
0