8000 ENH Use scipy Yeo-Johnson implementation in PowerTransformer for scip… · scikit-learn/scikit-learn@6c1d33f · GitHub
[go: up one dir, main page]

Skip to content

Commit 6c1d33f

Browse files
yaichmMohamed Yaichlorentzenchr
authored
ENH Use scipy Yeo-Johnson implementation in PowerTransformer for scipy >= 1.9 (#31227)
Co-authored-by: Mohamed Yaich <Mohamed.Yaich@grenoble-inp.org> Co-authored-by: Christian Lorentzen <lorentzen.ch@gmail.com>
1 parent 4f614da commit 6c1d33f

File tree

4 files changed

+94
-3
lines changed

4 files changed

+94
-3
lines changed
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
- Now using ``scipy.stats.yeojohnson`` instead of our own implementation of the Yeo-Johnson transform.
2+
Fixed numerical stability (mostly overflows) of the Yeo-Johnson transform with
3+
`PowerTransformer(method="yeo-johnson")` when scipy version is `>= 1.12`.
4+
Initial PR by :user:`Xuefeng Xu <xuefeng-xu>` completed by :user:`Mohamed Yaich <yaichm>`,
5+
:user:`Oussama Er-rabie <eroussama>`, :user:`Mohammed Yaslam Dlimi <Dlimim>`,
6+
:user:`Hamza Zaroual <HamzaLuffy>`, :user:`Amine Hannoun <AmineHannoun>` and :user:`Sylvain Marié <smarie>`.

sklearn/preprocessing/_data.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from numbers import Integral, Real
77

88
import numpy as np
9-
from scipy import optimize, sparse, stats
9+
from scipy import sparse, stats
1010
from scipy.special import boxcox, inv_boxcox
1111

1212
from sklearn.utils import metadata_routing
@@ -28,6 +28,7 @@
2828
)
2929
from ..utils._param_validation import Interval, Options, StrOptions, validate_params
3030
from ..utils.extmath import _incremental_mean_and_var, row_norms
31+
from ..utils.fixes import _yeojohnson_lambda
3132
from ..utils.sparsefuncs import (
3233
incr_mean_variance_axis,
3334
inplace_column_scale,
@@ -3542,8 +3543,8 @@ def _neg_log_likelihood(lmbda):
35423543
# the computation of lambda is influenced by NaNs so we need to
35433544
# get rid of them
35443545
x = x[~np.isnan(x)]
3545-
# choosing bracket -2, 2 like for boxcox
3546-
return optimize.brent(_neg_log_likelihood, brack=(-2, 2))
3546+
3547+
return _yeojohnson_lambda(_neg_log_likelihood, x)
35473548

35483549
def _check_input(self, X, in_fit, check_positive=False, check_shape=False):
35493550
"""Validate the input before fit and transform.

sklearn/preprocessing/tests/test_data.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from sklearn import config_context, datasets
1313
from sklearn.base import clone
1414
from sklearn.exceptions import NotFittedError
15+
from sklearn.externals._packaging.version import parse as parse_version
1516
from sklearn.metrics.pairwise import linear_kernel
1617
from sklearn.model_selection import cross_val_predict
1718
from sklearn.pipeline import Pipeline
@@ -62,6 +63,7 @@
6263
CSC_CONTAINERS,
6364
CSR_CONTAINERS,
6465
LIL_CONTAINERS,
66+
sp_version,
6567
)
6668
from sklearn.utils.sparsefuncs import mean_variance_axis
6769

@@ -2640,3 +2642,52 @@ def test_power_transformer_constant_feature(standardize):
26402642
assert_allclose(Xt_, np.zeros_like(X))
26412643
else:
26422644
assert_allclose(Xt_, X)
2645+
2646+
2647+
@pytest.mark.skipif(
2648+
sp_version < parse_version("1.12"),
2649+
reason="scipy version 1.12 required for stable yeo-johnson",
2650+
)
2651+
def test_power_transformer_no_warnings():
2652+
"""Verify that PowerTransformer operates without raising any warnings on valid data.
2653+
2654+
This test addresses numerical issues with floating point numbers (mostly
2655+
overflows) with the Yeo-Johnson transform, see
2656+
https://github.com/scikit-learn/scikit-learn/issues/23319#issuecomment-1464933635
2657+
"""
2658+
x = np.array(
2659+
[
2660+
2003.0,
2661+
1950.0,
2662+
1997.0,
2663+
2000.0,
2664+
2009.0,
2665+
2009.0,
2666+
1980.0,
2667+
1999.0,
2668+
2007.0,
2669+
1991.0,
2670+
]
2671+
)
2672+
2673+
def _test_no_warnings(data):
2674+
"""Internal helper to test for unexpected warnings."""
2675+
with warnings.catch_warnings(record=True) as caught_warnings:
2676+
warnings.simplefilter("always") # Ensure all warnings are captured
2677+
PowerTransformer(method="yeo-johnson", standardize=True).fit_transform(data)
2678+
2679+
assert not caught_warnings, "Unexpected warnings were raised:\n" + "\n".join(
2680+
str(w.message) for w in caught_warnings
2681+
)
2682+
2683+
# Full dataset: Should not trigger overflow in variance calculation.
2684+
_test_no_warnings(x.reshape(-1, 1))
2685+
2686+
# Subset of data: Should not trigger overflow in power calculation.
2687+
_test_no_warnings(x[:5].reshape(-1< 2364 /span>, 1))
2688+
2689+
2690+
def test_yeojohnson_for_different_scipy_version():
2691+
"""Check that the results are consistent across different SciPy versions."""
2692+
pt = PowerTransformer(method="yeo-johnson").fit(X_1col)
2693+
pt.lambdas_[0] == F438 pytest.approx(0.99546157, rel=1e-7)

sklearn/utils/fixes.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import scipy
1515
import scipy.sparse.linalg
1616
import scipy.stats
17+
from scipy import optimize
1718

1819
try:
1920
import pandas as pd
@@ -80,6 +81,38 @@ def _sparse_linalg_cg(A, b, **kwargs):
8081
return scipy.sparse.linalg.cg(A, b, **kwargs)
8182

8283

84+
# TODO : remove this when required minimum version of scipy >= 1.9.0
85+
def _yeojohnson_lambda(_neg_log_likelihood, x):
86+
"""Estimate the optimal Yeo-Johnson transformation parameter (lambda).
87+
88+
This function provides a compatibility workaround for versions of SciPy
89+
older than 1.9.0, where `scipy.stats.yeojohnson` did not return
90+
the estimated lambda directly.
91+
92+
Parameters
93+
----------
94+
_neg_log_likelihood : callable
95+
A function that computes the negative log-likelihood of the Yeo-Johnson
96+
transformation for a given lambda. Used only for SciPy versions < 1.9.0.
97+
98+
x : array-like
99+
Input data to estimate the Yeo-Johnson transformation parameter.
100+
101+
Returns
102+
-------
103+
lmbda : float
104+
The estimated lambda parameter for the Yeo-Johnson transformation.
105+
"""
106+
min_scipy_version = "1.9.0"
107+
108+
if sp_version < parse_version(min_scipy_version):
109+
# choosing bracket -2, 2 like for boxcox
110+
return optimize.brent(_neg_log_likelihood, brack=(-2, 2))
111+
112+
_, lmbda = scipy.stats.yeojohnson(x, lmbda=None)
113+
return lmbda
114+
115+
83116
# TODO: Fuse the modern implementations of _sparse_min_max and _sparse_nan_min_max
84117
# into the public min_max_axis function when Scipy 1.11 is the minimum supported
85118
# version and delete the backport in the else branch below.

0 commit comments

Comments
 (0)
0