ENH Native support for missing values in GBDTs (#13911) · scikit-learn/scikit-learn@4b6273b · GitHub

Commit 4b6273b

NicolasHug authored and adrinjalali committed
ENH Native support for missing values in GBDTs (#13911)
Squashed commit history:

* Added NaN support in mapper
* pep
* WIP
* some more
* WIP
* WIP
* bug fix
* basic tests
* some doc
* avoid some interactions
* Added tag
* better test
* decent test + fix bug
* add missing_fraction param to benchmark
* bin training and validation data separately
* shorter test
* Map missing values to first bin instead of last
* pep8
* Added whats new entry
* avoid some python interactions
* make predict_binned work
* fixed bug due to offset in bin_thresholds_ attribute
* more sensible binning strat
* typo
* user name
* Add small test
* convert to fortran array in tests
* some doc
* Added function test
* pep8
* Bin validation data using binmaper of training data
* Allocate first bin for missing entries based on the whole data, not just training data.
* Addressed Thomas' comments
* Update sklearn/ensemble/_hist_gradient_boosting/tests/test_grower.py
* Addressed Guillaume's comments
* always allocate first bin for missing values
* reduce diff
* minor more consistent test
* typo
* WIP
* some doc
* reduce diff
* pep8
* minor
* remove prints
* towards nan only splits
* don't check right to left on split_on_nan
* cleaups
* format and comment
* Fixed bug + added more tests
* refactor tests
* put back n_threads to max value
* minor changes
* minor cleaning
* Add (failing) test that checks equivalence with min max imputation
* Decrease the likelihood of ties when training the trees
* More robust test
* Fix pytest parametrization
* Check bin thresholds in test
* Try to make the test even easier to see if the Linux 32bit build would pass in this case
* Don't check last non-missing bin if there's no nan
* Improve min-max imputation test
* FIX: _find_best_bin_to_split_right_to_left is still required even when left to right wants to split on nans
* comments
* remove split_on_nan
* ooops deleted useless files
* Got rid of individual checks in predictor code. +inf thresholds are only allowed in a split on nan situation. Thresholds that are computed as +inf are capped to a very high constant value
* can also remove special case in binning code
* minor typos + more consistent test
* renamed types -> common
* 1e300 -> almost inf
* added user guide section on missing values
* Addressed Olivier's comment + updated whatsnew
* addressed comments
* Fix doctest formatting
* Fix nan predictive doctest
1 parent a05c8d8 · commit 4b6273b

26 files changed: +1180 −321 lines changed
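For context before the file-by-file diff: the user-visible effect of this commit is that NaNs can be passed directly to fit and predict, with no imputation step. The snippet below simply restates the doctest added to doc/modules/ensemble.rst further down; no names are invented here.

    from sklearn.experimental import enable_hist_gradient_boosting  # noqa
    from sklearn.ensemble import HistGradientBoostingClassifier
    import numpy as np

    # One feature, one missing entry; the estimator handles the NaN natively.
    X = np.array([0, 1, 2, np.nan]).reshape(-1, 1)
    y = [0, 0, 1, 1]

    gbdt = HistGradientBoostingClassifier(min_samples_leaf=1).fit(X, y)
    print(gbdt.predict(X))  # [0 0 1 1]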

benchmarks/bench_hist_gradient_boosting.py

Lines changed: 7 additions & 0 deletions
@@ -2,6 +2,7 @@
 import argparse
 
 import matplotlib.pyplot as plt
+import numpy as np
 from sklearn.model_selection import train_test_split
 # To use this experimental feature, we need to explicitly ask for it:
 from sklearn.experimental import enable_hist_gradient_boosting  # noqa
@@ -25,6 +26,7 @@
 parser.add_argument('--learning-rate', type=float, default=.1)
 parser.add_argument('--problem', type=str, default='classification',
                     choices=['classification', 'regression'])
+parser.add_argument('--missing-fraction', type=float, default=0)
 parser.add_argument('--n-classes', type=int, default=2)
 parser.add_argument('--n-samples-max', type=int, default=int(1e6))
 parser.add_argument('--n-features', type=int, default=20)
@@ -52,6 +54,11 @@ def get_estimator_and_data():
 
 
 X, y, Estimator = get_estimator_and_data()
+if args.missing_fraction:
+    mask = np.random.binomial(1, args.missing_fraction, size=X.shape).astype(
+        np.bool)
+    X[mask] = np.nan
+
 X_train_, X_test_, y_train_, y_test_ = train_test_split(
     X, y, test_size=0.5, random_state=0)

doc/modules/ensemble.rst

Lines changed: 42 additions & 3 deletions
@@ -864,7 +864,7 @@ Usage
 Most of the parameters are unchanged from
 :class:`GradientBoostingClassifier` and :class:`GradientBoostingRegressor`.
 One exception is the ``max_iter`` parameter that replaces ``n_estimators``, and
-controls the number of iterations of the boosting process:
+controls the number of iterations of the boosting process::
 
   >>> from sklearn.experimental import enable_hist_gradient_boosting
   >>> from sklearn.ensemble import HistGradientBoostingClassifier
@@ -873,10 +873,10 @@ controls the number of iterations of the boosting process:
   >>> X, y = make_hastie_10_2(random_state=0)
   >>> X_train, X_test = X[:2000], X[2000:]
   >>> y_train, y_test = y[:2000], y[2000:]
-  >>> clf = HistGradientBoostingClassifier(max_iter=100).fit(X_train, y_train)
 
+  >>> clf = HistGradientBoostingClassifier(max_iter=100).fit(X_train, y_train)
   >>> clf.score(X_test, y_test)
-  0.8998
+  0.8965
 
 The size of the trees can be controlled through the ``max_leaf_nodes``,
 ``max_depth``, and ``min_samples_leaf`` parameters.
@@ -895,6 +895,45 @@ using an arbitrary :term:`scorer`, or just the training or validation loss. By
 default, early-stopping is performed using the default :term:`scorer` of
 the estimator on a validation set.
 
+Missing values support
+----------------------
+
+:class:`HistGradientBoostingClassifier` and
+:class:`HistGradientBoostingRegressor` have built-in support for missing
+values (NaNs).
+
+During training, the tree grower learns at each split point whether samples
+with missing values should go to the left or right child, based on the
+potential gain. When predicting, samples with missing values are assigned to
+the left or right child accordingly::
+
+  >>> from sklearn.experimental import enable_hist_gradient_boosting  # noqa
+  >>> from sklearn.ensemble import HistGradientBoostingClassifier
+  >>> import numpy as np
+
+  >>> X = np.array([0, 1, 2, np.nan]).reshape(-1, 1)
+  >>> y = [0, 0, 1, 1]
+
+  >>> gbdt = HistGradientBoostingClassifier(min_samples_leaf=1).fit(X, y)
+  >>> gbdt.predict(X)
+  array([0, 0, 1, 1])
+
+When the missingness pattern is predictive, the splits can be done on
+whether the feature value is missing or not::
+
+  >>> X = np.array([0, np.nan, 1, 2, np.nan]).reshape(-1, 1)
+  >>> y = [0, 1, 0, 0, 1]
+  >>> gbdt = HistGradientBoostingClassifier(min_samples_leaf=1,
+  ...                                       max_depth=2,
+  ...                                       learning_rate=1,
+  ...                                       max_iter=1).fit(X, y)
+  >>> gbdt.predict(X)
+  array([0, 1, 0, 0, 1])
+
+If no missing values were encountered for a given feature during training,
+then samples with missing values are mapped to whichever child has the most
+samples.
+
 Low-level parallelism
 ---------------------
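The last rule in the added section (features never missing at train time) can be exercised directly. A minimal sketch, assuming only the behavior documented above; the comment states the expected routing rather than asserting an exact output:

    from sklearn.experimental import enable_hist_gradient_boosting  # noqa
    from sklearn.ensemble import HistGradientBoostingClassifier
    import numpy as np

    # No NaN is ever seen during fit...
    X_train = np.array([0., 1., 2., 3., 4., 5.]).reshape(-1, 1)
    y_train = [0, 0, 0, 0, 1, 1]
    gbdt = HistGradientBoostingClassifier(min_samples_leaf=1).fit(X_train, y_train)

    # ...yet predict still accepts one: the sample is mapped to whichever
    # child has the most training samples at each split (here, most likely
    # the majority class-0 side).
    print(gbdt.predict(np.array([[np.nan]])))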

doc/whats_new/v0.22.rst

Lines changed: 30 additions & 16 deletions
@@ -23,10 +23,11 @@ random sampling procedures.
 - :class:`decomposition.SparseCoder` with `algorithm='lasso_lars'` |Fix|
 - :class:`decomposition.SparsePCA` where `normalize_components` has no effect
   due to deprecation.
-
 - :class:`linear_model.Ridge` when `X` is sparse. |Fix|
-
 - :class:`cluster.KMeans` when `n_jobs=1`. |Fix|
+- :class:`ensemble.HistGradientBoostingClassifier` and
+  :class:`ensemble.HistGradientBoostingRegressor` |Fix|, |Feature|,
+  |Enhancement|.
 
 Details are listed in the changelog below.
 
@@ -112,24 +113,31 @@ Changelog
 :mod:`sklearn.ensemble`
 .......................
 
-- |Feature| :class:`ensemble.HistGradientBoostingClassifier` and
-  :class:`ensemble.HistGradientBoostingRegressor` have an additional
-  parameter called `warm_start` that enables warm starting. :pr:`14012` by
-  :user:`Johann Faouzi <johannfaouzi>`.
-
-- |Fix| :class:`ensemble.HistGradientBoostingClassifier` and
-  :class:`ensemble.HistGradientBoostingRegressor` now bin the training and
-  validation data separately to avoid any data leak. :pr:`13933` by
-  `Nicolas Hug`_.
+- Many improvements were made to
+  :class:`ensemble.HistGradientBoostingClassifier` and
+  :class:`ensemble.HistGradientBoostingRegressor`:
+
+  - |MajorFeature| Estimators now natively support dense data with missing
+    values both for training and predicting. They also support infinite
+    values. :pr:`13911` and :pr:`14406` by `Nicolas Hug`_, `Adrin Jalali`_
+    and `Olivier Grisel`_.
+  - |Feature| Estimators now have an additional `warm_start` parameter that
+    enables warm starting. :pr:`14012` by :user:`Johann Faouzi <johannfaouzi>`.
+  - |Enhancement| For :class:`ensemble.HistGradientBoostingClassifier`, the
+    training loss or score is now monitored on a class-wise stratified
+    subsample to preserve the class balance of the original training set.
+    :pr:`14194` by :user:`Johann Faouzi <johannfaouzi>`.
+  - |Feature| :func:`inspection.partial_dependence` and
+    :func:`inspection.plot_partial_dependence` now support the fast
+    'recursion' method for both estimators. :pr:`13769` by `Nicolas Hug`_.
+  - |Fix| Estimators now bin the training and validation data separately to
+    avoid any data leak. :pr:`13933` by `Nicolas Hug`_.
+
+  Note that pickles from 0.21 will not work in 0.22.
 
 - |Fix| :func:`ensemble.VotingClassifier.predict_proba` will no longer be
   present when `voting='hard'`. :pr:`14287` by `Thomas Fan`_.
 
-- |Enhancement| :class:`ensemble.HistGradientBoostingClassifier` the training
-  loss or score is now monitored on a class-wise stratified subsample to
-  preserve the class balance of the original training set. :pr:`14194`
-  by :user:`Johann Faouzi <johannfaouzi>`.
-
 - |Fix| Run by default
   :func:`utils.estimator_checks.check_estimator` on both
   :class:`ensemble.VotingClassifier` and :class:`ensemble.VotingRegressor`. It
@@ -182,6 +190,12 @@ Changelog
   measure the importance of each feature in an arbitrary trained model with
   respect to a given scoring function. :issue:`13146` by `Thomas Fan`_.
 
+- |Feature| :func:`inspection.partial_dependence` and
+  :func:`inspection.plot_partial_dependence` now support the fast 'recursion'
+  method for :class:`ensemble.HistGradientBoostingClassifier` and
+  :class:`ensemble.HistGradientBoostingRegressor`. :pr:`13769` by
+  `Nicolas Hug`_.
+
 :mod:`sklearn.linear_model`
 ...........................
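Since the changelog highlights the fast 'recursion' method for partial dependence on these estimators, here is a minimal sketch of what that entry enables (assuming the scikit-learn 0.22 API; the toy data is illustrative only):

    from sklearn.experimental import enable_hist_gradient_boosting  # noqa
    from sklearn.ensemble import HistGradientBoostingRegressor
    from sklearn.inspection import partial_dependence
    import numpy as np

    rng = np.random.RandomState(0)
    X = rng.rand(200, 3)
    y = 2 * X[:, 0] + 0.1 * rng.rand(200)

    est = HistGradientBoostingRegressor().fit(X, y)
    # 'recursion' traverses the fitted trees directly instead of evaluating
    # the model on a grid, which is much faster for tree ensembles.
    averaged_predictions, values = partial_dependence(est, X, [0],
                                                      method='recursion')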

sklearn/ensemble/_hist_gradient_boosting/_binning.pyx

Lines changed: 12 additions & 7 deletions
@@ -12,11 +12,14 @@ import numpy as np
 cimport numpy as np
 from numpy.math cimport INFINITY
 from cython.parallel import prange
+from libc.math cimport isnan
 
-from .types cimport X_DTYPE_C, X_BINNED_DTYPE_C
+from .common cimport X_DTYPE_C, X_BINNED_DTYPE_C
 
-cpdef _map_to_bins(const X_DTYPE_C [:, :] data, list binning_thresholds,
-                   X_BINNED_DTYPE_C [::1, :] binned):
+def _map_to_bins(const X_DTYPE_C [:, :] data,
+                 list binning_thresholds,
+                 const unsigned char missing_values_bin_idx,
+                 X_BINNED_DTYPE_C [::1, :] binned):
     """Bin numerical values to discrete integer-coded levels.
 
     Parameters
@@ -35,11 +38,13 @@ cpdef _map_to_bins(const X_DTYPE_C [:, :] data, list binning_thresholds,
     for feature_idx in range(data.shape[1]):
         _map_num_col_to_bins(data[:, feature_idx],
                              binning_thresholds[feature_idx],
+                             missing_values_bin_idx,
                              binned[:, feature_idx])
 
 
 cdef void _map_num_col_to_bins(const X_DTYPE_C [:] data,
                                const X_DTYPE_C [:] binning_thresholds,
+                               const unsigned char missing_values_bin_idx,
                                X_BINNED_DTYPE_C [:] binned):
     """Binary search to find the bin index for each value in the data."""
     cdef:
@@ -49,11 +54,11 @@ cdef void _map_num_col_to_bins(const X_DTYPE_C [:] data,
         int middle
 
     for i in prange(data.shape[0], schedule='static', nogil=True):
-        if data[i] == INFINITY:
-            # Special case for +inf.
-            # -inf is handled properly by binary search.
-            binned[i] = binning_thresholds.shape[0]
+
+        if isnan(data[i]):
+            binned[i] = missing_values_bin_idx
         else:
+            # for known values, use binary search
             left, right = 0, binning_thresholds.shape[0]
             while left < right:
                 middle = (right + left - 1) // 2
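A pure-Python sketch of the binning rule implemented above, for readers who don't want to parse the Cython: names mirror the diff, and np.searchsorted stands in for the hand-written binary search (an assumption about equivalence, not the actual implementation):

    import numpy as np

    def map_num_col_to_bins(data, binning_thresholds, missing_values_bin_idx):
        """Mimic _map_num_col_to_bins: NaN -> dedicated bin, else binary search."""
        # side='left' puts a value v in bin i when thresholds[i-1] < v <= thresholds[i]
        binned = np.searchsorted(binning_thresholds, data,
                                 side='left').astype(np.uint8)
        binned[np.isnan(data)] = missing_values_bin_idx
        return binned

    data = np.array([0.1, 2.5, np.nan, -np.inf])
    thresholds = np.array([1.0, 2.0, 3.0])  # learned on training data
    print(map_num_col_to_bins(data, thresholds, missing_values_bin_idx=4))
    # [0 2 4 0]  -- the NaN lands in its own bin; -inf is handled by the search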

sklearn/ensemble/_hist_gradient_boosting/_gradient_boosting.pyx

Lines changed: 2 additions & 2 deletions
@@ -10,8 +10,8 @@ from cython.parallel import prange
 import numpy as np
 cimport numpy as np
 
-from .types import Y_DTYPE
-from .types cimport Y_DTYPE_C
+from .common import Y_DTYPE
+from .common cimport Y_DTYPE_C
 
 
 def _update_raw_predictions(

sklearn/ensemble/_hist_gradient_boosting/_loss.pyx

Lines changed: 2 additions & 2 deletions
@@ -12,8 +12,8 @@ cimport numpy as np
 
 from libc.math cimport exp
 
-from .types cimport Y_DTYPE_C
-from .types cimport G_H_DTYPE_C
+from .common cimport Y_DTYPE_C
+from .common cimport G_H_DTYPE_C
 
 
 def _update_gradients_least_squares(

sklearn/ensemble/_hist_gradient_boosting/_predictor.pyx

Lines changed: 26 additions & 14 deletions
@@ -7,15 +7,16 @@
 
 cimport cython
 from cython.parallel import prange
+from libc.math cimport isnan
 import numpy as np
 cimport numpy as np
 from numpy.math cimport INFINITY
 
-from .types cimport X_DTYPE_C
-from .types cimport Y_DTYPE_C
-from .types import Y_DTYPE
-from .types cimport X_BINNED_DTYPE_C
-from .types cimport node_struct
+from .common cimport X_DTYPE_C
+from .common cimport Y_DTYPE_C
+from .common import Y_DTYPE
+from .common cimport X_BINNED_DTYPE_C
+from .common cimport node_struct
 
 
 def _predict_from_numeric_data(
@@ -43,10 +44,12 @@ cdef inline Y_DTYPE_C _predict_one_from_numeric_data(
     while True:
         if node.is_leaf:
             return node.value
-        if numeric_data[row, node.feature_idx] == INFINITY:
-            # if data is +inf we always go to the right child, even when the
-            # threhsold is +inf
-            node = nodes[node.right]
+
+        if isnan(numeric_data[row, node.feature_idx]):
+            if node.missing_go_to_left:
+                node = nodes[node.left]
+            else:
+                node = nodes[node.right]
         else:
             if numeric_data[row, node.feature_idx] <= node.threshold:
                 node = nodes[node.left]
@@ -57,19 +60,22 @@ cdef inline Y_DTYPE_C _predict_one_from_numeric_data(
 def _predict_from_binned_data(
         node_struct [:] nodes,
         const X_BINNED_DTYPE_C [:, :] binned_data,
+        const unsigned char missing_values_bin_idx,
        Y_DTYPE_C [:] out):
 
     cdef:
         int i
 
     for i in prange(binned_data.shape[0], schedule='static', nogil=True):
-        out[i] = _predict_one_from_binned_data(nodes, binned_data, i)
+        out[i] = _predict_one_from_binned_data(nodes, binned_data, i,
+                                               missing_values_bin_idx)
 
 
 cdef inline Y_DTYPE_C _predict_one_from_binned_data(
         node_struct [:] nodes,
         const X_BINNED_DTYPE_C [:, :] binned_data,
-        const int row) nogil:
+        const int row,
+        const unsigned char missing_values_bin_idx) nogil:
     # Need to pass the whole array and the row index, else prange won't work.
     # See issue Cython #2798
 
@@ -79,10 +85,16 @@ cdef inline Y_DTYPE_C _predict_one_from_binned_data(
     while True:
         if node.is_leaf:
             return node.value
-        if binned_data[row, node.feature_idx] <= node.bin_threshold:
-            node = nodes[node.left]
+        if binned_data[row, node.feature_idx] == missing_values_bin_idx:
+            if node.missing_go_to_left:
+                node = nodes[node.left]
+            else:
+                node = nodes[node.right]
         else:
-            node = nodes[node.right]
+            if binned_data[row, node.feature_idx] <= node.bin_threshold:
+                node = nodes[node.left]
+            else:
+                node = nodes[node.right]
 
 def _compute_partial_dependence(
     node_struct [:] nodes,
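The routing logic added above is compact; here is a pure-Python sketch of the same traversal. Field names are taken from the node_struct used in the diff (is_leaf, value, feature_idx, threshold, missing_go_to_left, left, right), but the two-node stump itself is a hypothetical example:

    import math

    def predict_one(nodes, x):
        """Route one sample through a tree, mirroring _predict_one_from_numeric_data."""
        node = nodes[0]
        while True:
            if node['is_leaf']:
                return node['value']
            value = x[node['feature_idx']]
            if math.isnan(value):
                # Missing values follow the direction learned during training.
                child = node['left'] if node['missing_go_to_left'] else node['right']
            else:
                child = node['left'] if value <= node['threshold'] else node['right']
            node = nodes[child]

    # A hypothetical stump: split on feature 0 at 1.5, missing values go left.
    nodes = [
        {'is_leaf': False, 'feature_idx': 0, 'threshold': 1.5,
         'missing_go_to_left': True, 'left': 1, 'right': 2},
        {'is_leaf': True, 'value': -1.0},
        {'is_leaf': True, 'value': +1.0},
    ]
    print(predict_one(nodes, [0.5]))           # -1.0 (0.5 <= 1.5, goes left)
    print(predict_one(nodes, [float('nan')]))  # -1.0 (missing -> left)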
