[MRG] Native support for missing values in GBDTs by NicolasHug · Pull Request #13911 · scikit-learn/scikit-learn

Merged · 103 commits · Aug 21, 2019

Commits
e279161
Added NaN support in mapper
NicolasHug May 20, 2019
91105a6
pep
NicolasHug May 20, 2019
000ab9a
WIP
NicolasHug May 20, 2019
66c2502
some more
NicolasHug May 20, 2019
810b7b0
WIP
NicolasHug May 21, 2019
5fd59cb
WIP
NicolasHug May 21, 2019
670566b
bug fix
NicolasHug May 21, 2019
e338e0a
basic tests
NicolasHug May 21, 2019
d288518
some doc
NicolasHug May 21, 2019
2d1659b
avoid some interactions
NicolasHug May 21, 2019
f2a83a0
Added tag
NicolasHug May 21, 2019
cd1de3c
better test
NicolasHug May 22, 2019
5cd8e59
decent test + fix bug
NicolasHug May 22, 2019
af1558a
Merge branch 'master' of github.com:scikit-learn/scikit-learn into mi…
NicolasHug May 23, 2019
d6b73ed
add missing_fraction param to benchmark
NicolasHug May 23, 2019
5e06fa7
bin training and validation data separately
NicolasHug May 23, 2019
1a34856
shorter test
NicolasHug May 23, 2019
aae10a2
Map missing values to first bin instead of last
NicolasHug May 23, 2019
35eda6e
pep8
NicolasHug May 23, 2019
1fa9b26
Merge branch 'master' of github.com:scikit-learn/scikit-learn into mi…
NicolasHug May 23, 2019
1f63282
Added whats new entry
NicolasHug May 23, 2019
e3d34a9
avoid some python interactions
NicolasHug May 23, 2019
542cb25
make predict_binned work
NicolasHug May 23, 2019
bf822b4
fixed bug due to offset in bin_thresholds_ attribute
NicolasHug May 23, 2019
112b400
more sensible binning strat
NicolasHug May 23, 2019
21a3ee3
typo
NicolasHug May 23, 2019
28c15b2
user name
NicolasHug May 23, 2019
5a5f39d
Add small test
NicolasHug May 24, 2019
a4da8d0
Merge branch 'master' of github.com:scikit-learn/scikit-learn into mi…
NicolasHug May 24, 2019
b07fed9
convert to fortran array in tests
NicolasHug May 24, 2019
b78e96b
some doc
NicolasHug May 24, 2019
a9f878c
Added function test
NicolasHug May 24, 2019
71b64e8
pep8
NicolasHug May 24, 2019
0af212f
Merge branch 'master' of github.com:scikit-learn/scikit-learn into bi…
NicolasHug May 25, 2019
2c2373e
Bin validation data using binmaper of training data
NicolasHug May 26, 2019
0e8edd1
Merge branch 'bin_train_val_separately' into missing_value_gbdt
NicolasHug May 26, 2019
deda348
Allocate first bin for missing entries based on the whole data, not just
NicolasHug May 27, 2019
e8fcc31
Merge branch 'master' of github.com:scikit-learn/scikit-learn into mi…
NicolasHug May 28, 2019
1a471ce
Merge branch 'master' of github.com:scikit-learn/scikit-learn into mi…
NicolasHug May 28, 2019
1f78807
Merge branch 'master' of github.com:scikit-learn/scikit-learn into mi…
NicolasHug May 29, 2019
3fed0ab
Addressed Thomas' comments
NicolasHug May 30, 2019
3e7bb7d
Merge branch 'master' of github.com:scikit-learn/scikit-learn into mi…
NicolasHug May 30, 2019
7ad5bce
Update sklearn/ensemble/_hist_gradient_boosting/tests/test_grower.py
NicolasHug May 30, 2019
4a9dc3a
Merge branch 'master' of github.com:scikit-learn/scikit-learn into mi…
NicolasHug Jun 3, 2019
489b861
Merge branch 'missing_value_gbdt' of github.com:NicolasHug/scikit-lea…
NicolasHug Jun 3, 2019
e83b39e
Addressed Guillaume's comments
NicolasHug Jun 3, 2019
c80b250
Merge branch 'master' of github.com:scikit-learn/scikit-learn into mi…
NicolasHug Jun 15, 2019
d2de00b
always allocate first bin for missing values
NicolasHug Jun 15, 2019
26b66ab
reduce diff
NicolasHug Jun 15, 2019
f370a71
minor more consistent test
NicolasHug Jun 18, 2019
ec57171
typo
NicolasHug Jun 18, 2019
beae859
Merge branch 'master' of github.com:scikit-learn/scikit-learn into mi…
NicolasHug Jun 21, 2019
e82a5f4
Merge branch 'master' of github.com:scikit-learn/scikit-learn into mi…
NicolasHug Jun 26, 2019
92f3e28
Merge branch 'master' of github.com:scikit-learn/scikit-learn into mi…
NicolasHug Jun 27, 2019
2dfaad8
WIP
NicolasHug Jun 27, 2019
457e720
some doc
NicolasHug Jun 28, 2019
af5ef38
Merge branch 'master' of github.com:scikit-learn/scikit-learn into mi…
NicolasHug Jun 28, 2019
45c5068
reduce diff
NicolasHug Jun 28, 2019
5a8fbe5
pep8
NicolasHug Jun 28, 2019
bc9c0df
Merge branch 'master' of github.com:scikit-learn/scikit-learn into mi…
NicolasHug Jun 28, 2019
889835a
minor
NicolasHug Jun 28, 2019
8995db4
Merge branch 'master' of github.com:scikit-learn/scikit-learn into mi…
NicolasHug Jul 1, 2019
8d5e36e
remove prints
NicolasHug Jul 1, 2019
d28ab14
towards nan only splits
adrinjalali Jul 5, 2019
48fa149
don't check right to left on split_on_nan
adrinjalali Jul 11, 2019
76e18f8
cleaups
adrinjalali Jul 11, 2019
eb0f7e6
format and comment
adrinjalali Jul 11, 2019
e0abc50
Fixed bug + added more tests
NicolasHug Jul 12, 2019
77846a3
refactor tests
NicolasHug Jul 12, 2019
14d444f
put back n_threads to max value
NicolasHug Jul 12, 2019
8fb80fd
minor changes
NicolasHug Jul 12, 2019
0440398
minor cleaning
NicolasHug Jul 12, 2019
0fde968
Support splitting on nans
NicolasHug Jul 12, 2019
5301c5a
Merge branch 'master' of github.com:scikit-learn/scikit-learn into mi…
NicolasHug Jul 12, 2019
210f90f
Merge branch 'missing_value_gbdt' of github.com:NicolasHug/scikit-lea…
NicolasHug Jul 12, 2019
4b0176a
Add (failing) test that checks equivalence with min max imputation
ogrisel Jul 12, 2019
d38881c
Decrease the likelihood of ties when training the trees
ogrisel Jul 15, 2019
f5e8e45
More robust test
ogrisel Jul 15, 2019
a0963fb
Fix pytest parametrization
ogrisel Jul 15, 2019
d0be6cb
Check bin thresholds in test
ogrisel Jul 15, 2019
9c9d7e5
Try to make the test even easier to see if the Linux 32bit build woul…
ogrisel Jul 15, 2019
191cfc6
Merge branch 'master' of github.com:scikit-learn/scikit-learn into mi…
NicolasHug Jul 16, 2019
5fc7453
Merge branch 'missing_value_gbdt' of github.com:NicolasHug/scikit-lea…
NicolasHug Jul 16, 2019
75dc126
Don't check last non-missing bin if there's no nan
NicolasHug Jul 17, 2019
3b2075c
Improve min-max imputation test
ogrisel Jul 19, 2019
a66103c
FIX: _find_best_bin_to_split_right_to_left is still required even whe…
ogrisel Jul 19, 2019
0bda5d1
Merge branch 'master' of github.com:scikit-learn/scikit-learn into mi…
NicolasHug Jul 19, 2019
49140c2
comments
NicolasHug Jul 19, 2019
e39f48e
remove split_on_nan
NicolasHug Jul 19, 2019
f89c1c5
ooops deleted useless files
NicolasHug Jul 19, 2019
299d3e0
Got rid of individual checks in predictor code
NicolasHug Jul 22, 2019
9540f99
can also remove special case in binning code
NicolasHug Jul 23, 2019
cb3936d
minor typos + more consistent test
NicolasHug Jul 23, 2019
c8f6409
renamed types -> common
NicolasHug Jul 23, 2019
6f0e191
1e300 -> almost inf
NicolasHug Jul 23, 2019
a56db0b
Merge branch 'master' of github.com:scikit-learn/scikit-learn into mi…
NicolasHug Aug 5, 2019
c112335
added user guide section on missing values
NicolasHug Aug 5, 2019
0c5dc90
Merge branch 'master' into missing_value_gbdt
ogrisel Aug 19, 2019
ef5cce2
Merge branch 'master' of github.com:scikit-learn/scikit-learn into mi…
ogrisel Aug 20, 2019
3b0c2ba
Addressed Olivier's comment + updated whatsnew
NicolasHug Aug 20, 2019
7c868ae
addressed comments
NicolasHug Aug 20, 2019
876f538
Fix doctest formatting
ogrisel Aug 21, 2019
601dc22
Fix nan predictive doctest
ogrisel Aug 21, 2019
Files changed
7 changes: 7 additions & 0 deletions benchmarks/bench_hist_gradient_boosting.py
@@ -2,6 +2,7 @@
import argparse

import matplotlib.pyplot as plt
import numpy as np
from sklearn.model_selection import train_test_split
# To use this experimental feature, we need to explicitly ask for it:
from sklearn.experimental import enable_hist_gradient_boosting # noqa
@@ -25,6 +26,7 @@
parser.add_argument('--learning-rate', type=float, default=.1)
parser.add_argument('--problem', type=str, default='classification',
choices=['classification', 'regression'])
parser.add_argument('--missing-fraction', type=float, default=0)
parser.add_argument('--n-classes', type=int, default=2)
parser.add_argument('--n-samples-max', type=int, default=int(1e6))
parser.add_argument('--n-features', type=int, default=20)
@@ -52,6 +54,11 @@ def get_estimator_and_data():


X, y, Estimator = get_estimator_and_data()
if args.missing_fraction:
mask = np.random.binomial(1, args.missing_fraction, size=X.shape).astype(
np.bool)
X[mask] = np.nan

X_train_, X_test_, y_train_, y_test_ = train_test_split(
X, y, test_size=0.5, random_state=0)

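The new --missing-fraction flag can be exercised directly from the command line. A typical invocation (the flag comes from the diff above; the other values are arbitrary):

    python benchmarks/bench_hist_gradient_boosting.py --problem regression --missing-fraction 0.2

With a non-zero fraction, the script replaces that proportion of the entries of X with NaN before the train/test split, so both fitting and prediction exercise the new missing-value code paths.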
45 changes: 42 additions & 3 deletions doc/modules/ensemble.rst
@@ -864,7 +864,7 @@ Usage
Most of the parameters are unchanged from
:class:`GradientBoostingClassifier` and :class:`GradientBoostingRegressor`.
One exception is the ``max_iter`` parameter that replaces ``n_estimators``, and
controls the number of iterations of the boosting process:
controls the number of iterations of the boosting process::

>>> from sklearn.experimental import enable_hist_gradient_boosting
>>> from sklearn.ensemble import HistGradientBoostingClassifier
@@ -873,10 +873,10 @@ controls the number of iterations of the boosting process:
>>> X, y = make_hastie_10_2(random_state=0)
>>> X_train, X_test = X[:2000], X[2000:]
>>> y_train, y_test = y[:2000], y[2000:]
>>> clf = HistGradientBoostingClassifier(max_iter=100).fit(X_train, y_train)

>>> clf = HistGradientBoostingClassifier(max_iter=100).fit(X_train, y_train)
>>> clf.score(X_test, y_test)
0.8998
0.8965

The size of the trees can be controlled through the ``max_leaf_nodes``,
``max_depth``, and ``min_samples_leaf`` parameters.
@@ -895,6 +895,45 @@ using an arbitrary :term:`scorer`, or just the training or validation loss. By
default, early-stopping is performed using the default :term:`scorer` of
the estimator on a validation set.

Missing values support
----------------------

:class:`HistGradientBoostingClassifier` and
:class:`HistGradientBoostingRegressor` have built-in support for missing
values (NaNs).

During training, the tree grower learns at each split point whether samples
with missing values should go to the left or right child, based on the
potential gain. When predicting, samples with missing values are assigned to
the left or right child accordingly::

>>> from sklearn.experimental import enable_hist_gradient_boosting # noqa
>>> from sklearn.ensemble import HistGradientBoostingClassifier
>>> import numpy as np

>>> X = np.array([0, 1, 2, np.nan]).reshape(-1, 1)
>>> y = [0, 0, 1, 1]

>>> gbdt = HistGradientBoostingClassifier(min_samples_leaf=1).fit(X, y)
>>> gbdt.predict(X)
array([0, 0, 1, 1])

When the missingness pattern is predictive, the splits can be done on
whether the feature value is missing or not::

>>> X = np.array([0, np.nan, 1, 2, np.nan]).reshape(-1, 1)
>>> y = [0, 1, 0, 0, 1]
>>> gbdt = HistGradientBoostingClassifier(min_samples_leaf=1,
... max_depth=2,
... learning_rate=1,
... max_iter=1).fit(X, y)
>>> gbdt.predict(X)
array([0, 1, 0, 0, 1])

If no missing values were encountered for a given feature during training,
then samples with missing values are mapped to whichever child has the most
samples.

Low-level parallelism
---------------------

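To complement the new user-guide section above, here is a small hedged sketch (not part of the PR; the toy data is arbitrary and the exact output depends on the fitted trees) probing the last rule documented there: when a feature contains no missing values during fit, a NaN seen at predict time is sent to the child with the most samples.

    import numpy as np
    from sklearn.experimental import enable_hist_gradient_boosting  # noqa
    from sklearn.ensemble import HistGradientBoostingClassifier

    # No NaN in this feature at fit time.
    X = np.array([0., 1., 2., 3.]).reshape(-1, 1)
    y = [0, 0, 1, 1]
    gbdt = HistGradientBoostingClassifier(min_samples_leaf=1).fit(X, y)

    # A NaN at predict time is routed to the child with the most samples,
    # so the prediction is still well defined.
    print(gbdt.predict(np.array([[np.nan]])))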
46 changes: 30 additions & 16 deletions doc/whats_new/v0.22.rst
@@ -23,10 +23,11 @@ random sampling procedures.
- :class:`decomposition.SparseCoder` with `algorithm='lasso_lars'` |Fix|
- :class:`decomposition.SparsePCA` where `normalize_components` has no effect
due to deprecation.

- :class:`linear_model.Ridge` when `X` is sparse. |Fix|

- :class:`cluster.KMeans` when `n_jobs=1`. |Fix|
- :class:`ensemble.HistGradientBoostingClassifier` and
:class:`ensemble.HistGradientBoostingRegressor` |Fix|, |Feature|,
|Enhancement|.

Details are listed in the changelog below.

@@ -112,24 +113,31 @@ Changelog
:mod:`sklearn.ensemble`
.......................

- |Feature| :class:`ensemble.HistGradientBoostingClassifier` and
:class:`ensemble.HistGradientBoostingRegressor` have an additional
parameter called `warm_start` that enables warm starting. :pr:`14012` by
:user:`Johann Faouzi <johannfaouzi>`.

- |Fix| :class:`ensemble.HistGradientBoostingClassifier` and
:class:`ensemble.HistGradientBoostingRegressor` now bin the training and
validation data separately to avoid any data leak. :pr:`13933` by
`Nicolas Hug`_.
- Many improvements were made to
:class:`ensemble.HistGradientBoostingClassifier` and
:class:`ensemble.HistGradientBoostingRegressor`:

- |MajorFeature| Estimators now natively support dense data with missing
values both for training and predicting. They also support infinite
values. :pr:`13911` and :pr:`14406` by `Nicolas Hug`_, `Adrin Jalali`_
and `Olivier Grisel`_.
- |Feature| Estimators now have an additional `warm_start` parameter that
enables warm starting. :pr:`14012` by :user:`Johann Faouzi <johannfaouzi>`.
- |Enhancement| for :class:`ensemble.HistGradientBoostingClassifier` the
training loss or score is now monitored on a class-wise stratified
subsample to preserve the class balance of the original training set.
:pr:`14194` by :user:`Johann Faouzi <johannfaouzi>`.
- |Feature| :func:`inspection.partial_dependence` and
:func:`inspection.plot_partial_dependence` now support the fast 'recursion'
method for both estimators. :pr:`13769` by `Nicolas Hug`_.
- |Fix| Estimators now bin the training and validation data separately to
avoid any data leak. :pr:`13933` by `Nicolas Hug`_.

Note that pickles from 0.21 will not work in 0.22.

- |Fix| :func:`ensemble.VotingClassifier.predict_proba` will no longer be
present when `voting='hard'`. :pr:`14287` by `Thomas Fan`_.

- |Enhancement| :class:`ensemble.HistGradientBoostingClassifier` the training
loss or score is now monitored on a class-wise stratified subsample to
preserve the class balance of the original training set. :pr:`14194`
by :user:`Johann Faouzi <johannfaouzi>`.

- |Fix| Run by default
:func:`utils.estimator_checks.check_estimator` on both
:class:`ensemble.VotingClassifier` and :class:`ensemble.VotingRegressor`. It
@@ -182,6 +190,12 @@ Changelog
measure the importance of each feature in an arbitrary trained model with
respect to a given scoring function. :issue:`13146` by `Thomas Fan`_.

- |Feature| :func:`inspection.partial_dependence` and
:func:`inspection.plot_partial_dependence` now support the fast 'recursion'
method for :class:`ensemble.HistGradientBoostingClassifier` and
:class:`ensemble.HistGradientBoostingRegressor`. :pr:`13769` by
`Nicolas Hug`_.

:mod:`sklearn.linear_model`
...........................

19 changes: 12 additions & 7 deletions sklearn/ensemble/_hist_gradient_boosting/_binning.pyx
@@ -12,11 +12,14 @@ import numpy as np
cimport numpy as np
from numpy.math cimport INFINITY
from cython.parallel import prange
from libc.math cimport isnan

from .types cimport X_DTYPE_C, X_BINNED_DTYPE_C
from .common cimport X_DTYPE_C, X_BINNED_DTYPE_C

cpdef _map_to_bins(const X_DTYPE_C [:, :] data, list binning_thresholds,
X_BINNED_DTYPE_C [::1, :] binned):
def _map_to_bins(const X_DTYPE_C [:, :] data,
list binning_thresholds,
const unsigned char missing_values_bin_idx,
X_BINNED_DTYPE_C [::1, :] binned):
"""Bin numerical values to discrete integer-coded levels.

Parameters
@@ -35,11 +38,13 @@
for feature_idx in range(data.shape[1]):
_map_num_col_to_bins(data[:, feature_idx],
binning_thresholds[feature_idx],
missing_values_bin_idx,
binned[:, feature_idx])


cdef void _map_num_col_to_bins(const X_DTYPE_C [:] data,
const X_DTYPE_C [:] binning_thresholds,
const unsigned char missing_values_bin_idx,
X_BINNED_DTYPE_C [:] binned):
"""Binary search to find the bin index for each value in the data."""
cdef:
@@ -49,11 +54,11 @@ cdef void _map_num_col_to_bins(const X_DTYPE_C [:] data,
int middle

for i in prange(data.shape[0], schedule='static', nogil=True):
if data[i] == INFINITY:
# Special case for +inf.
# -inf is handled properly by binary search.
binned[i] = binning_thresholds.shape[0]

if isnan(data[i]):
binned[i] = missing_values_bin_idx
else:
# for known values, use binary search
left, right = 0, binning_thresholds.shape[0]
while left < right:
middle = (right + left - 1) // 2
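For readers less familiar with Cython, here is a minimal NumPy sketch of the kernel above (hypothetical helper name; the real kernel runs in parallel under nogil and writes into a preallocated Fortran-ordered array):

    import numpy as np

    def map_num_col_to_bins_py(data, binning_thresholds, missing_values_bin_idx):
        """Pure-Python sketch of _map_num_col_to_bins (assumed semantics)."""
        binned = np.empty(data.shape[0], dtype=np.uint8)
        for i in range(data.shape[0]):
            if np.isnan(data[i]):
                # Missing values get their own dedicated bin.
                binned[i] = missing_values_bin_idx
            else:
                # Same result as the binary search in the Cython loop:
                # index of the first threshold that is >= data[i].
                binned[i] = np.searchsorted(binning_thresholds, data[i],
                                            side='left')
        return binned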
sklearn/ensemble/_hist_gradient_boosting/_gradient_boosting.pyx
@@ -10,8 +10,8 @@ from cython.parallel import prange
import numpy as np
cimport numpy as np

from .types import Y_DTYPE
from .types cimport Y_DTYPE_C
from .common import Y_DTYPE
from .common cimport Y_DTYPE_C


def _update_raw_predictions(
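Only the import is touched in this file (the `types` module was renamed to `common`), but for context, a hedged sketch of what `_update_raw_predictions` does (attribute names are hypothetical):

    def update_raw_predictions_py(raw_predictions, grower):
        # After a tree is grown, add each sample's leaf value to its raw
        # prediction, reusing the sample partition computed while growing.
        for leaf in grower.finalized_leaves:        # hypothetical attribute
            for sample_idx in leaf.sample_indices:  # samples in this leaf
                raw_predictions[sample_idx] += leaf.value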
4 changes: 2 additions & 2 deletions sklearn/ensemble/_hist_gradient_boosting/_loss.pyx
@@ -12,8 +12,8 @@

from libc.math cimport exp

from .types cimport Y_DTYPE_C
from .types cimport G_H_DTYPE_C
from .common cimport Y_DTYPE_C
from .common cimport G_H_DTYPE_C


def _update_gradients_least_squares(
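This file again only sees the `types` → `common` rename. As a reminder of what `_update_gradients_least_squares` computes, a NumPy sketch (assumed semantics):

    import numpy as np

    def update_gradients_least_squares_py(gradients, y_true, raw_predictions):
        # For the least-squares loss l(y, p) = 0.5 * (y - p) ** 2, the
        # gradient w.r.t. the raw prediction p is (p - y); the hessians
        # are constant (1) and need no update.
        gradients[:] = raw_predictions - y_true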
40 changes: 26 additions & 14 deletions sklearn/ensemble/_hist_gradient_boosting/_predictor.pyx
@@ -7,15 +7,16 @@

cimport cython
from cython.parallel import prange
from libc.math cimport isnan
import numpy as np
cimport numpy as np
from numpy.math cimport INFINITY

from .types cimport X_DTYPE_C
from .types cimport Y_DTYPE_C
from .types import Y_DTYPE
from .types cimport X_BINNED_DTYPE_C
from .types cimport node_struct
from .common cimport X_DTYPE_C
from .common cimport Y_DTYPE_C
from .common import Y_DTYPE
from .common cimport X_BINNED_DTYPE_C
from .common cimport node_struct

def _predict_from_numeric_data(
@@ -43,10 +44,12 @@ cdef inline Y_DTYPE_C _predict_one_from_numeric_data(
while True:
if node.is_leaf:
return node.value
if numeric_data[row, node.feature_idx] == INFINITY:
# if data is +inf we always go to the right child, even when the
# threhsold is +inf
node = nodes[node.right]

if isnan(numeric_data[row, node.feature_idx]):
if node.missing_go_to_left:
node = nodes[node.left]
else:
node = nodes[node.right]
else:
if numeric_data[row, node.feature_idx] <= node.threshold:
node = nodes[node.left]
@@ -57,19 +60,22 @@
def _predict_from_binned_data(
node_struct [:] nodes,
const X_BINNED_DTYPE_C [:, :] binned_data,
const unsigned char missing_values_bin_idx,
Y_DTYPE_C [:] out):

cdef:
int i

for i in prange(binned_data.shape[0], schedule='static', nogil=True):
out[i] = _predict_one_from_binned_data(nodes, binned_data, i)
out[i] = _predict_one_from_binned_data(nodes, binned_data, i,
missing_values_bin_idx)


cdef inline Y_DTYPE_C _predict_one_from_binned_data(
node_struct [:] nodes,
const X_BINNED_DTYPE_C [:, :] binned_data,
const int row) nogil:
const int row,
const unsigned char missing_values_bin_idx) nogil:
# Need to pass the whole array and the row index, else prange won't work.
# See issue Cython #2798

@@ -79,10 +85,16 @@ cdef inline Y_DTYPE_C _predict_one_from_binned_data(
while True:
if node.is_leaf:
return node.value
if binned_data[row, node.feature_idx] <= node.bin_threshold:
node = nodes[node.left]
if binned_data[row, node.feature_idx] == missing_values_bin_idx:
if node.missing_go_to_left:
node = nodes[node.left]
else:
node = nodes[node.right]
else:
node = nodes[node.right]
if binned_data[row, node.feature_idx] <= node.bin_threshold:
node = nodes[node.left]
else:
node = nodes[node.right]

def _compute_partial_dependence(
node_struct [:] nodes,
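Putting the traversal logic together, a compact Python sketch of how a single sample is routed through a tree (node fields mirror the node_struct used above; this is an illustration, not the PR's code):

    import math

    def predict_one_py(nodes, x):
        """Sketch of _predict_one_from_numeric_data for one sample x."""
        node = nodes[0]  # root
        while not node["is_leaf"]:
            value = x[node["feature_idx"]]
            if math.isnan(value):
                # Missing values follow the direction learned at fit time.
                child = "left" if node["missing_go_to_left"] else "right"
            elif value <= node["threshold"]:
                child = "left"
            else:
                child = "right"
            node = nodes[node[child]]
        return node["value"]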