8000 MRG FIX: order of values of self.quantiles_ in QuantileTransformer (#… · scikit-learn/scikit-learn@0c0b6da · GitHub
[go: up one dir, main page]

Skip to content

Commit 0c0b6da

Browse files
tirthasheshpatelogrisel
authored andcommitted
MRG FIX: order of values of self.quantiles_ in QuantileTransformer (#15751)
1 parent fa46467 commit 0c0b6da

File tree

3 files changed

+35
-0
lines changed

3 files changed

+35
-0
lines changed

doc/whats_new/v0.22.rst

+4
Original file line numberDiff line numberDiff line change
@@ -817,6 +817,10 @@ Changelog
817817
:class:`preprocessing.KernelCenterer`
818818
:pr:`14336` by :user:`Gregory Dexter <gdex1>`.
819819

820+
- |Fix| :class:`preprocessing.QuantileTransformer` now guarantees the
821+
`quantiles_` attribute to be completely sorted in non-decreasing manner.
822+
:pr:`15751` by :user:`Tirth Patel <tirthasheshpatel>`.
823+
820824
:mod:`sklearn.model_selection`
821825
..............................
822826

sklearn/preprocessing/_data.py

+10
Original file line numberDiff line numberDiff line change
@@ -2262,6 +2262,11 @@ def _dense_fit(self, X, random_state):
22622262
col = col.take(subsample_idx, mode='clip')
22632263
self.quantiles_.append(np.nanpercentile(col, references))
22642264
self.quantiles_ = np.transpose(self.quantiles_)
2265+
# Due to floating-point precision error in `np.nanpercentile`,
2266+
# make sure that quantiles are monotonically increasing.
2267+
# Upstream issue in numpy:
2268+
# https://github.com/numpy/numpy/issues/14685
2269+
self.quantiles_ = np.maximum.accumulate(self.quantiles_)
22652270

22662271
def _sparse_fit(self, X, random_state):
22672272
"""Compute percentiles for sparse matrices.
@@ -2305,6 +2310,11 @@ def _sparse_fit(self, X, random_state):
23052310
self.quantiles_.append(
23062311
np.nanpercentile(column_data, references))
23072312
self.quantiles_ = np.transpose(self.quantiles_)
2313+
# due to floating-point precision error in `np.nanpercentile`,
2314+
# make sure the quantiles are monotonically increasing
2315+
# Upstream issue in numpy:
2316+
# https://github.com/numpy/numpy/issues/14685
2317+
self.quantiles_ = np.maximum.accumulate(self.quantiles_)
23082318

23092319
def fit(self, X, 10000 y=None):
23102320
"""Compute the quantiles used for transforming.

sklearn/preprocessing/tests/test_data.py

+21
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from sklearn.utils._testing import assert_allclose
2626
from sklearn.utils._testing import assert_allclose_dense_sparse
2727
from sklearn.utils._testing import skip_if_32bit
28+
from sklearn.utils._testing import _convert_container
2829

2930
from sklearn.utils.sparsefuncs import mean_variance_axis
3031
from sklearn.preprocessing._data import _handle_zeros_in_scale
@@ -1532,6 +1533,26 @@ def test_quantile_transform_nan():
15321533
assert not np.isnan(transformer.quantiles_[:, 1:]).any()
15331534

15341535

1536+
@pytest.mark.parametrize("array_type", ['array', 'sparse'])
1537+
def test_quantile_transformer_sorted_quantiles(array_type):
1538+
# Non-regression test for:
1539+
# https://github.com/scikit-learn/scikit-learn/issues/15733
1540+
# Taken from upstream bug report:
1541+
# https://github.com/numpy/numpy/issues/14685
1542+
X = np.array([0, 1, 1, 2, 2, 3, 3, 4, 5, 5, 1, 1, 9, 9, 9, 8, 8, 7] * 10)
1543+
X = 0.1 * X.reshape(-1, 1)
1544+
X = _convert_container(X, array_type)
1545+
1546+
n_quantiles = 100
1547+
qt = QuantileTransformer(n_quantiles=n_quantiles).fit(X)
1548+
1549+
# Check that the estimated quantile threasholds are monotically
1550+
# increasing:
1551+
quantiles = qt.quantiles_[:, 0]
1552+
assert len(quantiles) == 100
1553+
assert all(np.diff(quantiles) >= 0)
1554+
1555+
15351556
def test_robust_scaler_invalid_range():
15361557
for range_ in [
15371558
(-1, 90),

0 commit comments

Comments
 (0)
0