ENH FEA add interaction constraints to HGBT by lorentzenchr · Pull Request #21020 · scikit-learn/scikit-learn · GitHub

Merged

Changes from all commits (72 commits)

244c409
DOC add attribues to TreeGrower
lorentzenchr Sep 12, 2021
b31eea0
ENH add interaction_cst
lorentzenchr Sep 12, 2021
d9b273a
use a set in _get_allowed_features
lorentzenchr Sep 12, 2021
1cc1cb5
complete overhaul
lorentzenchr Sep 13, 2021
7baf695
TST test_split_interaction_constraints
lorentzenchr Sep 14, 2021
8aced52
DOC add is_leaf to Attributes section
lorentzenchr Sep 14, 2021
9a9862c
DOC improve interaction_cst_idx
lorentzenchr Sep 14, 2021
ed31a7e
FIX fix logic
lorentzenchr Sep 14, 2021
f2a0679
TST add test_grower_interaction_constraints
lorentzenchr Sep 15, 2021
eb1e255
CLN make allowed_features an instance variable
lorentzenchr Sep 16, 2021
ec48945
TST restructure test_grower_interaction_constraints
lorentzenchr Sep 17, 2021
1ed28d2
CLN improve logic
lorentzenchr Sep 17, 2021
c7c8c3f
TST improve test
lorentzenchr Sep 17, 2021
764cdf5
DOC add docstring for interaction_cst
lorentzenchr Sep 18, 2021
5a26f6e
ENH add validation of interaction_cst
lorentzenchr Sep 18, 2021
eb75a30
TST test input validation
lorentzenchr Sep 18, 2021
3ea3829
DEBUG uncomment if condition
lorentzenchr Sep 18, 2021
0570f61
address review comments
lorentzenchr Sep 21, 2021
16fc0b8
Revert "DEBUG uncomment if condition"
lorentzenchr Sep 21, 2021
2b7e1e2
TST increase max_depth and n_samples in test_grower_interaction_const…
lorentzenchr Sep 21, 2021
ead3b0c
ENH add default group to interaction constraints
lorentzenchr Sep 21, 2021
5a35ab7
TST test_check_interaction_cst
lorentzenchr Sep 21, 2021
c93d3f0
DOC udpate docstring of interaction_cst with default group
lorentzenchr Sep 21, 2021
5092f6b
DOC add whatsnew
lorentzenchr Sep 22, 2021
a18b5ee
DEBUG
lorentzenchr Sep 23, 2021
6a02058
TST make test_split_interaction_constraints more tighter
lorentzenchr Sep 28, 2021
aa21d16
better comments and less typos
lorentzenchr Sep 28, 2021
63191c0
Merge branch 'main' into hgbt_interaction_constraints
lorentzenchr Sep 28, 2021
299f31b
Revert "DEBUG"
lorentzenchr Sep 28, 2021
9ec7b04
Merge branch 'main' into hgbt_interaction_constraints
lorentzenchr Oct 22, 2021
4c9e1a3
DOC address review comments for docstrings
lorentzenchr Oct 22, 2021
c09ba91
TST reviewer suggestion for improved grower test
lorentzenchr Oct 22, 2021
ba78cb9
TST check interaction constraints numerically
lorentzenchr Oct 22, 2021
c7a6ebe
Merge branch 'main' into hgbt_interaction_constraints
lorentzenchr Oct 24, 2021
c8a3a30
EXA add interaction constraints to partial dependence
lorentzenchr Oct 24, 2021
3b6703a
CLN colon in example
lorentzenchr Oct 24, 2021
255646a
CLN fix whatsnew
lorentzenchr Oct 24, 2021
7100600
TST better error messages
lorentzenchr Oct 25, 2021
31c6c3e
EXA add 1D ice plots to see parallel lines
lorentzenchr Oct 25, 2021
bd62aea
TST rely more on default values
lorentzenchr Oct 25, 2021
ec66be7
Merge branch 'main' into hgbt_interaction_constraints
lorentzenchr Nov 15, 2021
eed05ac
DOC add blank lines in whats_new
lorentzenchr Nov 15, 2021
d66f40a
DOC remove 1.0.1 entry in whats_new 1.1
lorentzenchr Nov 15, 2021
fb9f0b1
Merge branch 'main' of https://github.com/scikit-learn/scikit-learn i…
lesteve Apr 6, 2022
0fe1227
Merge branch 'main' into hgbt_interaction_constraints
lorentzenchr Apr 11, 2022
22ecd8d
Merge branch 'main' into hgbt_interaction_constraints
lorentzenchr Jun 28, 2022
ee86a77
CLN fix merge with parameter validation
lorentzenchr Jun 28, 2022
13b0aaf
DOC move whatsnew to v1.2
lorentzenchr Jun 29, 2022
1c75630
CLN move missing docstring additions to other PR
lorentzenchr Aug 16, 2022
b9d880b
DOC add user guide entry
lorentzenchr Aug 17, 2022
10023c0
MNT change versionadded to 1.2
lorentzenchr Aug 17, 2022
4265d23
DOC use code-block:: text
lorentzenchr Aug 17, 2022
a7559b1
DOC allowed_features has dtype uint32
lorentzenchr Aug 17, 2022
3bacb79
DOC remove None from interaction_cst_indices
lorentzenchr Aug 22, 2022
cf4eb15
Merge branch 'main' into hgbt_interaction_constraints
lorentzenchr Aug 25, 2022
6653a4e
DOC fix typo
lorentzenchr Aug 25, 2022
8d02553
DOC try none to switch off language highlightning
lorentzenchr Aug 25, 2022
4989a26
DOC address features by numbers
lorentzenchr Aug 25, 2022
4a90e0f
Merge branch 'main' into hgbt_interaction_constraints
lorentzenchr Aug 30, 2022
61b1e06
address reviewer comments
lorentzenchr Sep 6, 2022
38caedb
DOC add note about LightGBM logic
lorentzenchr Sep 6, 2022
45d178d
Merge branch 'main' into hgbt_interaction_constraints
lorentzenchr Sep 6, 2022
9667937
DOC fix typo
lorentzenchr Sep 7, 2022
ca270f5
Merge branch 'main' into hgbt_interaction_constraints
lorentzenchr Sep 26, 2022
9fb3e55
CLN better comment on test construction
lorentzenchr Sep 26, 2022
5240d9f
EXA review comments
lorentzenchr Oct 9, 2022
295aeee
ENH improvements from Thomas review comments
lorentzenchr Oct 9, 2022
9560ea7
CLN Julien's review comments
lorentzenchr Oct 10, 2022
28c4578
TST fix test_grower_interaction_constraints
lorentzenchr Oct 10, 2022
4d4b80a
DOC add reference Mayer 2022
lorentzenchr Oct 11, 2022
461cd6a
CLN remove if node.is_leaf in for loop
lorentzenchr Oct 11, 2022
e0e8220
CLN fix test_grower_interaction_constraints
lorentzenchr Oct 11, 2022
46 changes: 42 additions & 4 deletions doc/modules/ensemble.rst
@@ -317,9 +317,9 @@ to the prediction function.

.. topic:: References

.. [L2014] G. Louppe,
"Understanding Random Forests: From Theory to Practice",
PhD Thesis, U. of Liege, 2014.
.. [L2014] G. Louppe, :arxiv:`"Understanding Random Forests: From Theory to
Practice" <1407.7502>`,
PhD Thesis, U. of Liege, 2014.

.. _random_trees_embedding:

@@ -711,7 +711,7 @@ space.
accurate enough: the tree can only output integer values. As a result, the
leaves values of the tree :math:`h_m` are modified once the tree is
fitted, such that the leaves values minimize the loss :math:`L_m`. The
update is loss-dependent: for the absolute error loss, the value of
update is loss-dependent: for the absolute error loss, the value of
a leaf is updated to the median of the samples in that leaf.

Classification
@@ -1174,6 +1174,44 @@ Also, monotonic constraints are not supported for multiclass classification.

* :ref:`sphx_glr_auto_examples_ensemble_plot_monotonic_constraints.py`

.. _interaction_cst_hgbt:

Interaction constraints
-----------------------

A priori, histogram-based gradient boosting trees are allowed to use any
feature to split a node into child nodes. This creates so-called interactions
between features, i.e. the use of different features as splits along a branch.
Sometimes one wants to restrict the possible interactions, see [Mayer2022]_.
This can be done with the parameter ``interaction_cst``, which specifies the
groups of feature indices that are allowed to interact.
For instance, with 3 features in total, ``interaction_cst=[{0}, {1}, {2}]``
forbids all interactions.
The constraints ``[{0, 1}, {1, 2}]`` specify two groups of possibly
interacting features. Features 0 and 1 may interact with each other, as may
features 1 and 2; features 0 and 2, however, are forbidden to interact.
The following depicts a tree and the possible splits of the tree:

.. code-block:: none

      1      <- Both constraint groups could be applied from now on
     / \
    1   2    <- Left split still fulfills both constraint groups.
   / \ / \      Right split at feature 2 has only group {1, 2} from now on.

LightGBM uses the same logic for overlapping groups.
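The group-tracking rule depicted above can be sketched as a small standalone function. This is a toy illustration, not the grower's actual code; the helper names `child_groups` and `allowed_features` are invented here. Each child node keeps only the constraint groups that contain the feature its parent split on, and a node may split on the union of its remaining groups.

```python
def child_groups(parent_groups, split_feature, constraints):
    # A child keeps only the groups that contain the feature just split on.
    return [i for i in parent_groups if split_feature in constraints[i]]


def allowed_features(groups, constraints):
    # A node may split on any feature in the union of its remaining groups.
    out = set()
    for i in groups:
        out |= constraints[i]
    return out


constraints = [{0, 1}, {1, 2}]
root = [0, 1]                               # both groups apply at the root
left = child_groups(root, 1, constraints)   # split on 1: both groups survive
right = child_groups(root, 2, constraints)  # split on 2: only {1, 2} survives
print(allowed_features(left, constraints))  # {0, 1, 2}
print(allowed_features(right, constraints)) # {1, 2}
```

This reproduces the diagram: after a split on feature 2, only features 1 and 2 remain available along that branch.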

Note that features not listed in ``interaction_cst`` are automatically
assigned an interaction group of their own. Again with 3 features, this
means that ``[{0}]`` is equivalent to ``[{0}, {1, 2}]``.

.. topic:: References

.. [Mayer2022] M. Mayer, S.C. Bourassa, M. Hoesli, and D.F. Scognamiglio.
2022. :doi:`Machine Learning Applications to Land and Structure Valuation
<10.3390/jrfm15050193>`.
Journal of Risk and Financial Management 15, no. 5: 193

Low-level parallelism
---------------------

6 changes: 6 additions & 0 deletions doc/whats_new/v1.2.rst
@@ -242,6 +242,12 @@ Changelog
:mod:`sklearn.ensemble`
.......................

- |Feature| :class:`~sklearn.ensemble.HistGradientBoostingClassifier` and
:class:`~sklearn.ensemble.HistGradientBoostingRegressor` now support
interaction constraints via the argument `interaction_cst` of their
constructors.
:pr:`21020` by :user:`Christian Lorentzen <lorentzenchr>`.

- |Feature| Adds `class_weight` to :class:`ensemble.HistGradientBoostingClassifier`.
:pr:`22014` by `Thomas Fan`_.

79 changes: 79 additions & 0 deletions examples/inspection/plot_partial_dependence.py
@@ -255,6 +255,85 @@
# house age, whereas for values less than two there is a strong dependence on
# age.
#
# Interaction constraints
# .......................
#
# The histogram gradient boosters have an interesting option to constrain
# possible interactions among features. In the following, we do not allow any
# interactions and thus render the model as a version of a tree-based boosted
# generalized additive model (GAM). This makes the model more interpretable
# as the effect of each feature can be investigated independently of all others.
#
# We train the :class:`~sklearn.ensemble.HistGradientBoostingRegressor` again,
# now with `interaction_cst`, where we pass for each feature a list containing
# only its own index, i.e. `[[0], [1], [2], ...]`.

print("Training interaction constraint HistGradientBoostingRegressor...")
tic = time()
est_no_interactions = HistGradientBoostingRegressor(
interaction_cst=[[i] for i in range(X_train.shape[1])]
)
est_no_interactions.fit(X_train, y_train)
print(f"done in {time() - tic:.3f}s")

# %%
# The easiest way to show the effect of forbidden interactions is again the
# ICE plots.

print("Computing partial dependence plots...")
tic = time()
display = PartialDependenceDisplay.from_estimator(
est_no_interactions,
X_train,
["MedInc", "AveOccup", "HouseAge", "AveRooms"],
kind="both",
subsample=50,
n_jobs=3,
grid_resolution=20,
random_state=0,
ice_lines_kw={"color": "tab:blue", "alpha": 0.2, "linewidth": 0.5},
pd_line_kw={"color": "tab:orange", "linestyle": "--"},
)

print(f"done in {time() - tic:.3f}s")
display.figure_.suptitle(
"Partial dependence of house value with Gradient Boosting\n"
"and no interactions allowed"
)
display.figure_.subplots_adjust(wspace=0.4, hspace=0.3)

# %%
# All 4 plots show parallel ICE lines, meaning there is no interaction in the
# model.
# Let us also have a look at the corresponding 2D-plot.

print("Computing partial dependence plots...")
tic = time()
_, ax = plt.subplots(ncols=3, figsize=(9, 4))
display = PartialDependenceDisplay.from_estimator(
est_no_interactions,
X_train,
["AveOccup", "HouseAge", ("AveOccup", "HouseAge")],
kind="average",
n_jobs=3,
grid_resolution=20,
ax=ax,
)
print(f"done in {time() - tic:.3f}s")
display.figure_.suptitle(
"Partial dependence of house value with Gradient Boosting\n"
"and no interactions allowed"
)
display.figure_.subplots_adjust(wspace=0.4, hspace=0.3)

# %%
# Although the 2D plot shows much less interaction compared with the 2D plot
# from above, it is much harder to conclude that there is no interaction at
# all. This might be a consequence of the discrete predictions of trees in
# combination with the numerical precision of partial dependence.
# We also observe that the univariate dependence plots have slightly changed,
# as the model tries to compensate for the forbidden interactions.
#
# 3D interaction plots
# --------------------
#
80 changes: 80 additions & 0 deletions sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py
@@ -2,6 +2,7 @@
# Author: Nicolas Hug

from abc import ABC, abstractmethod
from collections.abc import Iterable
from functools import partial
from numbers import Real, Integral
import warnings
@@ -91,6 +92,7 @@ class BaseHistGradientBoosting(BaseEstimator, ABC):
"min_samples_leaf": [Interval(Integral, 1, None, closed="left")],
"l2_regularization": [Interval(Real, 0, None, closed="left")],
"monotonic_cst": ["array-like", None],
"interaction_cst": [Iterable, None],
"n_iter_no_change": [Interval(Integral, 1, None, closed="left")],
"validation_fraction": [
Interval(Real, 0, 1, closed="neither"),
@@ -121,6 +123,7 @@ def __init__(
max_bins,
categorical_features,
monotonic_cst,
interaction_cst,
warm_start,
early_stopping,
scoring,
@@ -139,6 +142,7 @@
self.l2_regularization = l2_regularization
self.max_bins = max_bins
self.monotonic_cst = monotonic_cst
self.interaction_cst = interaction_cst
self.categorical_features = categorical_features
self.warm_start = warm_start
self.early_stopping = early_stopping
@@ -252,6 +256,42 @@ def _check_categories(self, X):

return is_categorical, known_categories

def _check_interaction_cst(self, n_features):
"""Check and validate interaction constraints."""
if self.interaction_cst is None:
return None

if not (
isinstance(self.interaction_cst, Iterable)
and all(isinstance(x, Iterable) for x in self.interaction_cst)
):
raise ValueError(
"Interaction constraints must be None or an iterable of iterables, "
f"got: {self.interaction_cst!r}."
)

invalid_indices = [
x
for cst_set in self.interaction_cst
for x in cst_set
if not (isinstance(x, Integral) and 0 <= x < n_features)
]
if invalid_indices:
raise ValueError(
"Interaction constraints must consist of integer indices in [0,"
f" n_features - 1] = [0, {n_features - 1}], specifying the position of"
f" features, got invalid indices: {invalid_indices!r}"
)

constraints = [set(group) for group in self.interaction_cst]

# Add all not listed features as own group by default.
rest = set(range(n_features)) - set().union(*constraints)
if len(rest) > 0:
constraints.append(rest)

return constraints
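The default-group behavior of this helper can be illustrated with a standalone sketch; `normalize_interaction_cst` below is a toy reimplementation for illustration, not the method itself, and it omits the input validation:

```python
def normalize_interaction_cst(interaction_cst, n_features):
    """Toy version of the normalization above: listed groups become sets,
    and all unlisted features are collected into one extra group."""
    if interaction_cst is None:
        return None
    constraints = [set(group) for group in interaction_cst]
    # Add all not listed features as their own group by default.
    rest = set(range(n_features)) - set().union(*constraints)
    if rest:
        constraints.append(rest)
    return constraints


print(normalize_interaction_cst([{0}], 3))     # [{0}, {1, 2}]
print(normalize_interaction_cst([{0, 1}], 5))  # [{0, 1}, {2, 3, 4}]
```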

def fit(self, X, y, sample_weight=None):
"""Fit the gradient boosting model.

@@ -308,6 +348,9 @@ def fit(self, X, y, sample_weight=None):

self.is_categorical_, known_categories = self._check_categories(X)

# Encode constraints into a list of sets of feature indices (integers).
interaction_cst = self._check_interaction_cst(self._n_features)

# we need this stateful variable to tell raw_predict() that it was
# called from fit() (this current method), and that the data it has
# received is pre-binned.
@@ -595,6 +638,7 @@ def fit(self, X, y, sample_weight=None):
has_missing_values=has_missing_values,
is_categorical=self.is_categorical_,
monotonic_cst=self.monotonic_cst,
interaction_cst=interaction_cst,
max_leaf_nodes=self.max_leaf_nodes,
max_depth=self.max_depth,
min_samples_leaf=self.min_samples_leaf,
@@ -1191,6 +1235,22 @@ class HistGradientBoostingRegressor(RegressorMixin, BaseHistGradientBoosting):

.. versionadded:: 0.23

interaction_cst : iterable of iterables of int, default=None
Review thread:

Member: So specifying a GAM means `interaction_cst=[{i} for i in range(X.shape[1])]`, and doing all pairwise interactions would be `interaction_cst=[{i} for i in range(X.shape[1])] + [{i, j} for i, j in itertools.combinations(range(X.shape[1]), 2)]` or something like that?

Member: Talking with @ogrisel and @adrinjalali, maybe an option would be to have string special cases for univariate and bivariate, and then this would be a good first step :)

Contributor: @amueller: In my view, this is not necessary. These two cases are actually less important than one might think when working with interaction constraints. In practice you would simply limit the number of terminal nodes (2 resp. 3) instead of using constraints. The interesting cases are asymmetric ones, e.g. some variables are forced to act additively and others not. There, the intended interface is actually very convenient.

amueller (Member, Oct 23, 2021):

> In practice you would simply limit the number of terminal nodes (2 resp. 3) instead of using constraints

Why? That's not equivalent at all, is it?

> These two cases are actually less important than one might think when working with interaction constraints.

I think the motivation for me is interpretable models, and you can get interpretable models that are using deeper trees, which is not the same as boosting stumps as far as I can see.

Member: Hm, I just realized that this interface doesn't allow restricting to interactions of two features, right? Passing all tuples can still result in trees using more than 2 features, right?

mayer79 (Contributor, Oct 23, 2021): The resulting trees are not identical, but the resulting model structure is, in the sense that the interaction constraints are fulfilled in both cases.

mayer79 (Contributor, Oct 23, 2021): Regarding the pairwise case: if Christian implemented it correctly (I don't doubt!), then each branch in each tree will use only features of one constraint set. As such, each tree prediction will use only two features. But the tree will usually use three features.

Member Author:

> Talking with @ogrisel and @adrinjalali maybe an option would be to have string special cases for univariate and bivariate, and then this would be a good first step :)

+1. I had the same thought 😏 Maybe we can do that in a follow-up PR? All pairwise interactions would just be `interaction_cst = list(itertools.combinations(range(n_features), 2))`.

Member:

> As such, each tree prediction will use only two features. But the tree will use many.

Interesting. We need to check how other libraries do it. I wonder if this is a significantly different inductive bias compared to having each tree work only with a small subset of features.

> In practice you would simply limit the number of terminal nodes (2 resp. 3) instead of using constraints.

I think allowing deep trees (or deep branches) with a large number of splits but on a small subset of the features (typically 1 or 2) is an interesting inductive bias (similar to GAMs): it allows for decision functions with complex non-linear feature-wise parts but very decoupled inter-feature decisions. Relying on sequential decision stumps via more iterations of the gradient boosting algorithm is probably quite different from an inductive-bias point of view.

Member:

> The resulting trees are not identical, but the resulting model structure is in the sense that the interaction constraints are fulfilled in both cases.

The constraint is fulfilled but the models are not equivalent in any way, right? Aka what @ogrisel said, it's quite a different model.

> But the tree will usually use three features.

One of the reasons people restrict to pairwise interactions is so that the full model can be visualized. That's much harder with three features. There is no way to achieve trees that are on pairs of features with this PR, right?


Specify interaction constraints, i.e. sets of features which can
only interact with each other in child nodes splits.

Each iterable materializes a constraint by the set of indices of
the features that are allowed to interact with each other.
If there are more features than specified in these constraints,
they are treated as if they were specified as an additional set.

For instance, with 5 features in total, `interaction_cst=[{0, 1}]`
is equivalent to `interaction_cst=[{0, 1}, {2, 3, 4}]`,
and specifies that each branch of a tree will either only split
on features 0 and 1 or only split on features 2, 3 and 4.

.. versionadded:: 1.2

warm_start : bool, default=False
When set to ``True``, reuse the solution of the previous call to fit
and add more estimators to the ensemble. For results to be valid, the
@@ -1315,6 +1375,7 @@ def __init__(
max_bins=255,
categorical_features=None,
monotonic_cst=None,
interaction_cst=None,
warm_start=False,
early_stopping="auto",
scoring="loss",
@@ -1334,6 +1395,7 @@
l2_regularization=l2_regularization,
max_bins=max_bins,
monotonic_cst=monotonic_cst,
interaction_cst=interaction_cst,
categorical_features=categorical_features,
early_stopping=early_stopping,
warm_start=warm_start,
@@ -1505,6 +1567,22 @@ class HistGradientBoostingClassifier(ClassifierMixin, BaseHistGradientBoosting):

.. versionadded:: 0.23

interaction_cst : iterable of iterables of int, default=None
Specify interaction constraints, i.e. sets of features which can
only interact with each other in child nodes splits.

Each iterable materializes a constraint by the set of indices of
the features that are allowed to interact with each other.
If there are more features than specified in these constraints,
they are treated as if they were specified as an additional set.

For instance, with 5 features in total, `interaction_cst=[{0, 1}]`
is equivalent to `interaction_cst=[{0, 1}, {2, 3, 4}]`,
and specifies that each branch of a tree will either only split
on features 0 and 1 or only split on features 2, 3 and 4.

.. versionadded:: 1.2

warm_start : bool, default=False
When set to ``True``, reuse the solution of the previous call to fit
and add more estimators to the ensemble. For results to be valid, the
@@ -1653,6 +1731,7 @@ def __init__(
max_bins=255,
categorical_features=None,
monotonic_cst=None,
interaction_cst=None,
warm_start=False,
early_stopping="auto",
scoring="loss",
@@ -1674,6 +1753,7 @@
max_bins=max_bins,
categorical_features=categorical_features,
monotonic_cst=monotonic_cst,
interaction_cst=interaction_cst,
warm_start=warm_start,
early_stopping=early_stopping,
scoring=scoring,