8000 ENH add parameter subsample to KBinsDiscretizer (#21445) · thomasjpfan/scikit-learn@48e5423 · GitHub
[go: up one dir, main page]

Skip to content

Commit 48e5423

Browse files
amy12xx, fbidu, thomasjpfan, glemaitre
authored
ENH add parameter subsample to KBinsDiscretizer (scikit-learn#21445)
Co-authored-by: Felipe Bidu Rodrigues <felipe@felipevr.com>
Co-authored-by: Thomas J. Fan <thomasjpfan@gmail.com>
Co-authored-by: Guillaume Lemaitre <g.lemaitre58@gmail.com>
1 parent d5ce9c4 commit 48e5423

File tree

3 files changed

+153
-5
lines changed

3 files changed

+153
-5
lines changed

doc/whats_new/v1.1.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,11 @@ Changelog
138138
:mod:`sklearn.preprocessing`
139139
............................
140140

141+
- |Enhancement| Adds a `subsample` parameter to :class:`preprocessing.KBinsDiscretizer`.
142+
This allows specifying a maximum number of samples to be used while fitting
143+
the model. The option is only available when `strategy` is set to `quantile`.
144+
:pr:`21445` by :user:`Felipe Bidu <fbidu>` and :user:`Amanda Dsouza <amy12xx>`.
145+
141146
- |Fix| :class:`preprocessing.LabelBinarizer` now validates input parameters in `fit`
142147
instead of `__init__`.
143148
:pr:`21434` by :user:`Krum Arnaudov <krumeto>`.

sklearn/preprocessing/_discretization.py

Lines changed: 66 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,10 @@
1515
from ..base import BaseEstimator, TransformerMixin
1616
from ..utils.validation import check_array
1717
from ..utils.validation import check_is_fitted
18+
from ..utils.validation import check_random_state
1819
from ..utils.validation import _check_feature_names_in
20+
from ..utils.validation import check_scalar
21+
from ..utils import _safe_indexing
1922

2023

2124
class KBinsDiscretizer(TransformerMixin, BaseEstimator):
@@ -63,6 +66,27 @@ class KBinsDiscretizer(TransformerMixin, BaseEstimator):
6366
6467
.. versionadded:: 0.24
6568
69+
subsample : int or None (default='warn')
70+
Maximum number of samples, used to fit the model, for computational
71+
efficiency. Used when `strategy="quantile"`.
72+
`subsample=None` means that all the training samples are used when
73+
computing the quantiles that determine the binning thresholds.
74+
Since quantile computation relies on sorting each column of `X` and
75+
that sorting has an `n log(n)` time complexity,
76+
it is recommended to use subsampling on datasets with a
77+
very large number of samples.
78+
79+
.. deprecated:: 1.1
80+
In version 1.3 and onwards, `subsample=2e5` will be the default.
81+
82+
random_state : int, RandomState instance or None, default=None
83+
Determines random number generation for subsampling.
84+
Pass an int for reproducible results across multiple function calls.
85+
See the `subsample` parameter for more details.
86+
See :term:`Glossary <random_state>`.
87+
88+
.. versionadded:: 1.1
89+
6690
Attributes
6791
----------
6892
bin_edges_ : ndarray of ndarray of shape (n_features,)
@@ -136,11 +160,22 @@ class KBinsDiscretizer(TransformerMixin, BaseEstimator):
136160
[ 0.5, 3.5, -1.5, 1.5]])
137161
"""
138162

139-
def __init__(self, n_bins=5, *, encode="onehot", strategy="quantile", dtype=None):
163+
def __init__(
164+
self,
165+
n_bins=5,
166+
*,
167+
encode="onehot",
168+
strategy="quantile",
169+
dtype=None,
170+
subsample="warn",
171+
random_state=None,
172+
):
140173
self.n_bins = n_bins
141174
self.encode = encode
142175
self.strategy = strategy
143176
self.dtype = dtype
177+
self.subsample = subsample
178+
self.random_state = random_state
144179

145180
def fit(self, X, y=None):
146181
"""
@@ -174,6 +209,36 @@ def fit(self, X, y=None):
174209
" instead."
175210
)
176211

212+
n_samples, n_features = X.shape
213+
214+
if self.strategy == "quantile" and self.subsample is not None:
215+
if self.subsample == "warn":
216+
if n_samples > 2e5:
217+
warnings.warn(
218+
"In version 1.3 onwards, subsample=2e5 "
219+
"will be used by default. Set subsample explicitly to "
220+
"silence this warning in the mean time. Set "
221+
"subsample=None to disable subsampling explicitly.",
222+
FutureWarning,
223+
)
224+
else:
225+
self.subsample = check_scalar(
226+
self.subsample, "subsample", numbers.Integral, min_val=1
227+
)
228+
rng = check_random_state(self.random_state)
229+
if n_samples > self.subsample:
230+
subsample_idx = rng.choice(
231+
n_samples, size=self.subsample, replace=False
232+
)
233+
X = _safe_indexing(X, subsample_idx)
234+
elif self.strategy != "quantile" and isinstance(
235+
self.subsample, numbers.Integral
236+
):
237+
raise ValueError(
238+
f"Invalid parameter for `strategy`: {self.strategy}. "
239+
'`subsample` must be used with `strategy="quantile"`.'
240+
)
241+
177242
valid_encode = ("onehot", "onehot-dense", "ordinal")
178243
if self.encode not in valid_encode:
179244
raise ValueError(

sklearn/preprocessing/tests/test_discretization.py

Lines changed: 82 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import scipy.sparse as sp
44
import warnings
55

6+
from sklearn import clone
67
from sklearn.preprocessing import KBinsDiscretizer
78
from sklearn.preprocessing import OneHotEncoder
89
from sklearn.utils._testing import (
@@ -37,16 +38,16 @@ def test_valid_n_bins():
3738
def test_invalid_n_bins():
3839
est = KBinsDiscretizer(n_bins=1)
3940
err_msg = (
40-
"KBinsDiscretizer received an invalid "
41-
"number of bins. Received 1, expected at least 2."
41+
"KBinsDiscretizer received an invalid number of bins. Received 1, expected at"
42+
" least 2."
4243
)
4344
with pytest.raises(ValueError, match=err_msg):
4445
est.fit_transform(X)
4546

4647
est = KBinsDiscretizer(n_bins=1.1)
4748
err_msg = (
48-
"KBinsDiscretizer received an invalid "
49-
"n_bins type. Received float, expected int."
49+
"KBinsDiscretizer received an invalid n_bins type. Received float, expected"
50+
" int."
5051
)
5152
with pytest.raises(ValueError, match=err_msg):
5253
est.fit_transform(X)
@@ -357,3 +358,80 @@ def test_32_equal_64(input_dtype, encode):
357358
Xt_64 = kbd_64.transform(X_input)
358359

359360
assert_allclose_dense_sparse(Xt_32, Xt_64)
361+
362+
363+
# FIXME: remove the `filterwarnings` in 1.3
364+
@pytest.mark.filterwarnings("ignore:In version 1.3 onwards, subsample=2e5")
365+
@pytest.mark.parametrize("subsample", [None, "warn"])
366+
def test_kbinsdiscretizer_subsample_default(subsample):
367+
# Since the size of X is small (< 2e5), subsampling will not take place.
368+
X = np.array([-2, 1.5, -4, -1]).reshape(-1, 1)
369+
kbd_default = KBinsDiscretizer(n_bins=10, encode="ordinal", strategy="quantile")
370+
kbd_default.fit(X)
371+
372+
kbd_with_subsampling = clone(kbd_default)
373+
kbd_with_subsampling.set_params(subsample=subsample)
374+
kbd_with_subsampling.fit(X)
375+
376+
for bin_kbd_default, bin_kbd_with_subsampling in zip(
377+
kbd_default.bin_edges_[0], kbd_with_subsampling.bin_edges_[0]
378+
):
379+
np.testing.assert_allclose(bin_kbd_default, bin_kbd_with_subsampling)
380+
assert kbd_default.bin_edges_.shape == kbd_with_subsampling.bin_edges_.shape
381+
382+
383+
def test_kbinsdiscretizer_subsample_invalid_strategy():
384+
X = np.array([-2, 1.5, -4, -1]).reshape(-1, 1)
385+
kbd = KBinsDiscretizer(n_bins=10, encode="ordinal", strategy="uniform", subsample=3)
386+
387+
err_msg = '`subsample` must be used with `strategy="quantile"`.'
388+
with pytest.raises(ValueError, match=err_msg):
389+
kbd.fit(X)
390+
391+
392+
def test_kbinsdiscretizer_subsample_invalid_type():
393+
X = np.array([-2, 1.5, -4, -1]).reshape(-1, 1)
394+
kbd = KBinsDiscretizer(
395+
n_bins=10, encode="ordinal", strategy="quantile", subsample="full"
396+
)
397+
398+
msg = (
399+
"subsample must be an instance of <class 'numbers.Integral'>, not "
400+
"<class 'str'>."
401+
)
402+
with pytest.raises(TypeError, match=msg):
403+
kbd.fit(X)
404+
405+
406+
# TODO: Remove in 1.3
407+
def test_kbinsdiscretizer_subsample_warn():
408+
X = np.random.rand(200001, 1).reshape(-1, 1)
409+
kbd = KBinsDiscretizer(n_bins=100, encode="ordinal", strategy="quantile")
410+
411+
msg = "In version 1.3 onwards, subsample=2e5 will be used by default."
412+
with pytest.warns(FutureWarning, match=msg):
413+
kbd.fit(X)
414+
415+
416+
@pytest.mark.parametrize("subsample", [0, int(2e5)])
417+
def test_kbinsdiscretizer_subsample_values(subsample):
418+
X = np.random.rand(220000, 1).reshape(-1, 1)
419+
kbd_default = KBinsDiscretizer(n_bins=10, encode="ordinal", strategy="quantile")
420+
421+
kbd_with_subsampling = clone(kbd_default)
422+
kbd_with_subsampling.set_params(subsample=subsample)
423+
424+
if subsample == 0:
425+
with pytest.raises(ValueError, match="subsample == 0, must be >= 1."):
426+
kbd_with_subsampling.fit(X)
427+
else:
428+
# TODO: Remove in 1.3
429+
msg = "In version 1.3 onwards, subsample=2e5 will be used by default."
430+
with pytest.warns(FutureWarning, match=msg):
431+
kbd_default.fit(X)
432+
433+
kbd_with_subsampling.fit(X)
434+
assert not np.all(
435+
kbd_default.bin_edges_[0] == kbd_with_subsampling.bin_edges_[0]
436+
)
437+
assert kbd_default.bin_edges_.shape == kbd_with_subsampling.bin_edges_.shape

0 commit comments

Comments (0)