ENH FEA add interaction constraints to HGBT (#21020) · scikit-learn/scikit-learn@5ceb8a6
Co-authored-by: Loïc Estève <loic.esteve@ymail.com>
1 parent 97057d3 commit 5ceb8a6

9 files changed: +627 −36 lines changed

doc/modules/ensemble.rst — 42 additions, 4 deletions

@@ -317,9 +317,9 @@ to the prediction function.

 .. topic:: References

-   .. [L2014] G. Louppe,
-              "Understanding Random Forests: From Theory to Practice",
-              PhD Thesis, U. of Liege, 2014.
+   .. [L2014] G. Louppe, :arxiv:`"Understanding Random Forests: From Theory to
+              Practice" <1407.7502>`,
+              PhD Thesis, U. of Liege, 2014.

 .. _random_trees_embedding:

@@ -711,7 +711,7 @@ space.
    accurate enough: the tree can only output integer values. As a result, the
    leaves values of the tree :math:`h_m` are modified once the tree is
    fitted, such that the leaves values minimize the loss :math:`L_m`. The
-   update is loss-dependent: for the absolute error loss, the value of
+   update is loss-dependent: for the absolute error loss, the value of
    a leaf is updated to the median of the samples in that leaf.

@@ -1174,6 +1174,44 @@ Also, monotonic constraints are not supported for multiclass classification.

 * :ref:`sphx_glr_auto_examples_ensemble_plot_monotonic_constraints.py`

+.. _interaction_cst_hgbt:
+
+Interaction constraints
+-----------------------
+
+A priori, the histogram gradient boosting trees are allowed to use any feature
+to split a node into child nodes. This creates so-called interactions between
+features, i.e. the usage of different features as splits along a branch.
+Sometimes, one wants to restrict the possible interactions, see [Mayer2022]_.
+This can be done by the parameter ``interaction_cst``, where one can specify
+the indices of features that are allowed to interact.
+For instance, with 3 features in total, ``interaction_cst=[{0}, {1}, {2}]``
+forbids all interactions.
+The constraints ``[{0, 1}, {1, 2}]`` specify two groups of possibly
+interacting features. Features 0 and 1 may interact with each other, as well
+as features 1 and 2. But note that features 0 and 2 are forbidden to interact.
+The following depicts a tree and the possible splits of the tree:
+
+.. code-block:: none
+
+       1      <- Both constraint groups could be applied from now on
+      / \
+     1   2    <- Left split still fulfills both constraint groups.
+    / \ / \      Right split at feature 2 has only group {1, 2} from now on.
+
+LightGBM uses the same logic for overlapping groups.
+
+Note that features not listed in ``interaction_cst`` are automatically
+assigned an interaction group for themselves. With again 3 features, this
+means that ``[{0}]`` is equivalent to ``[{0}, {1, 2}]``.
+
+.. topic:: References
+
+   .. [Mayer2022] M. Mayer, S.C. Bourassa, M. Hoesli, and D.F. Scognamiglio.
+      2022. :doi:`Machine Learning Applications to Land and Structure Valuation
+      <10.3390/jrfm15050193>`.
+      Journal of Risk and Financial Management 15, no. 5: 193.
+
 Low-level parallelism
 ---------------------

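The branch-propagation rule illustrated by the tree diagram above can be sketched in a few lines of plain Python (a simplified illustration, not scikit-learn's internal implementation; the function name `allowed_features` is hypothetical): after splitting on a feature, only the constraint groups containing that feature remain valid for the subtree, and the features still allowed for further splits are the union of the surviving groups.

```python
def allowed_features(groups, split_path):
    """Features still allowed after splitting on the features in split_path.

    groups: list of sets of feature indices (the interaction constraints).
    split_path: features used along the branch so far, in order.
    """
    for f in split_path:
        # Keep only the constraint groups that contain the feature just used.
        groups = [g for g in groups if f in g]
    return set().union(*groups) if groups else set()


groups = [{0, 1}, {1, 2}]
print(allowed_features(groups, []))      # {0, 1, 2}: any first split is fine
print(allowed_features(groups, [1]))     # {0, 1, 2}: feature 1 is in both groups
print(allowed_features(groups, [2]))     # {1, 2}: only group {1, 2} survives
print(allowed_features(groups, [2, 1]))  # {1, 2}: feature 0 stays forbidden
```

This reproduces the diagram's commentary: splitting the root on feature 1 keeps both groups available, while splitting on feature 2 pins the branch to group {1, 2}.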
doc/whats_new/v1.2.rst — 6 additions, 0 deletions

@@ -255,6 +255,12 @@ Changelog
 :mod:`sklearn.ensemble`
 .......................

+- |Feature| :class:`~sklearn.ensemble.HistGradientBoostingClassifier` and
+  :class:`~sklearn.ensemble.HistGradientBoostingRegressor` now support
+  interaction constraints via the argument `interaction_cst` of their
+  constructors.
+  :pr:`21020` by :user:`Christian Lorentzen <lorentzenchr>`.
+
 - |Feature| Adds `class_weight` to :class:`ensemble.HistGradientBoostingClassifier`.
   :pr:`22014` by `Thomas Fan`_.

examples/inspection/plot_partial_dependence.py — 79 additions, 0 deletions

@@ -255,6 +255,85 @@
 # house age, whereas for values less than two there is a strong dependence on
 # age.
 #
+# Interaction constraints
+# .......................
+#
+# The histogram gradient boosters have an interesting option to constrain
+# possible interactions among features. In the following, we do not allow any
+# interactions and thus render the model as a version of a tree-based boosted
+# generalized additive model (GAM). This makes the model more interpretable,
+# as the effect of each feature can be investigated independently of all others.
+#
+# We train the :class:`~sklearn.ensemble.HistGradientBoostingRegressor` again,
+# now with `interaction_cst`, where we pass for each feature a list containing
+# only its own index, e.g. `[[0], [1], [2], ..]`.
+
+print("Training interaction constraint HistGradientBoostingRegressor...")
+tic = time()
+est_no_interactions = HistGradientBoostingRegressor(
+    interaction_cst=[[i] for i in range(X_train.shape[1])]
+)
+est_no_interactions.fit(X_train, y_train)
+print(f"done in {time() - tic:.3f}s")
+
+# %%
+# The easiest way to show the effect of forbidden interactions is again the
+# ICE plots.
+
+print("Computing partial dependence plots...")
+tic = time()
+display = PartialDependenceDisplay.from_estimator(
+    est_no_interactions,
+    X_train,
+    ["MedInc", "AveOccup", "HouseAge", "AveRooms"],
+    kind="both",
+    subsample=50,
+    n_jobs=3,
+    grid_resolution=20,
+    random_state=0,
+    ice_lines_kw={"color": "tab:blue", "alpha": 0.2, "linewidth": 0.5},
+    pd_line_kw={"color": "tab:orange", "linestyle": "--"},
+)
+
+print(f"done in {time() - tic:.3f}s")
+display.figure_.suptitle(
+    "Partial dependence of house value with Gradient Boosting\n"
+    "and no interactions allowed"
+)
+display.figure_.subplots_adjust(wspace=0.4, hspace=0.3)
+
+# %%
+# All 4 plots have parallel ICE lines, meaning there is no interaction in the
+# model.
+# Let us also have a look at the corresponding 2D-plot.
+
+print("Computing partial dependence plots...")
+tic = time()
+_, ax = plt.subplots(ncols=3, figsize=(9, 4))
+display = PartialDependenceDisplay.from_estimator(
+    est_no_interactions,
+    X_train,
+    ["AveOccup", "HouseAge", ("AveOccup", "HouseAge")],
+    kind="average",
+    n_jobs=3,
+    grid_resolution=20,
+    ax=ax,
+)
+print(f"done in {time() - tic:.3f}s")
+display.figure_.suptitle(
+    "Partial dependence of house value with Gradient Boosting\n"
+    "and no interactions allowed"
+)
+display.figure_.subplots_adjust(wspace=0.4, hspace=0.3)
+
+# %%
+# Although the 2D-plot shows much less interaction compared with the 2D-plot
+# from above, it is much harder to come to the conclusion that there is no
+# interaction at all. This might be a consequence of the discrete predictions
+# of trees in combination with the numerical precision of partial dependence.
+# We also observe that the univariate dependence plots have slightly changed,
+# as the model tries to compensate for the forbidden interactions.
+#
 # 3D interaction plots
 # --------------------
 #

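The example's claim that parallel ICE lines indicate an interaction-free (additive) model can be sanity-checked with a tiny pure-Python sketch (the functions `f_additive`, `f_interacting`, and `ice_gap` are hypothetical stand-ins for a fitted model, not part of the example above): for an additive function, the gap between ICE curves taken at two fixed values of the other feature is constant.

```python
# Sketch: for an additive model f(x0, x1) = g0(x0) + g1(x1), the ICE curves
# along x0 at different fixed values of x1 are parallel (constant gap).
def f_additive(x0, x1):  # hypothetical additive "model"
    return x0**2 + 3 * x1  # g0(x0) = x0**2, g1(x1) = 3 * x1


def f_interacting(x0, x1):  # hypothetical model with an x0 * x1 interaction
    return x0**2 + 3 * x1 + x0 * x1


grid = [0.0, 0.5, 1.0, 1.5]  # grid of x0 values, as in a PD/ICE plot


def ice_gap(f, x1_a, x1_b):
    """Gap between the ICE curves along x0 for x1 fixed at x1_a vs. x1_b."""
    return [f(x0, x1_a) - f(x0, x1_b) for x0 in grid]


print(ice_gap(f_additive, 1.0, 2.0))     # [-3.0, -3.0, -3.0, -3.0]: parallel
print(ice_gap(f_interacting, 1.0, 2.0))  # [-3.0, -3.5, -4.0, -4.5]: not parallel
```

The constant gap for the additive model is exactly what "parallel ICE lines" means in the plots above; the interaction term makes the gap depend on `x0`.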
sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py — 80 additions, 0 deletions

@@ -2,6 +2,7 @@
 # Author: Nicolas Hug

 from abc import ABC, abstractmethod
+from collections.abc import Iterable
 from functools import partial
 from numbers import Real, Integral
 import warnings
@@ -91,6 +92,7 @@ class BaseHistGradientBoosting(BaseEstimator, ABC):
     "min_samples_leaf": [Interval(Integral, 1, None, closed="left")],
     "l2_regularization": [Interval(Real, 0, None, closed="left")],
     "monotonic_cst": ["array-like", None],
+    "interaction_cst": [Iterable, None],
     "n_iter_no_change": [Interval(Integral, 1, None, closed="left")],
     "validation_fraction": [
         Interval(Real, 0, 1, closed="neither"),
@@ -121,6 +123,7 @@ def __init__(
         max_bins,
         categorical_features,
         monotonic_cst,
+        interaction_cst,
         warm_start,
         early_stopping,
         scoring,
@@ -139,6 +142,7 @@
         self.l2_regularization = l2_regularization
         self.max_bins = max_bins
         self.monotonic_cst = monotonic_cst
+        self.interaction_cst = interaction_cst
         self.categorical_features = categorical_features
         self.warm_start = warm_start
         self.early_stopping = early_stopping
@@ -252,6 +256,42 @@ def _check_categories(self, X):

         return is_categorical, known_categories

+    def _check_interaction_cst(self, n_features):
+        """Check and validate interaction constraints."""
+        if self.interaction_cst is None:
+            return None
+
+        if not (
+            isinstance(self.interaction_cst, Iterable)
+            and all(isinstance(x, Iterable) for x in self.interaction_cst)
+        ):
+            raise ValueError(
+                "Interaction constraints must be None or an iterable of iterables, "
+                f"got: {self.interaction_cst!r}."
+            )
+
+        invalid_indices = [
+            x
+            for cst_set in self.interaction_cst
+            for x in cst_set
+            if not (isinstance(x, Integral) and 0 <= x < n_features)
+        ]
+        if invalid_indices:
+            raise ValueError(
+                "Interaction constraints must consist of integer indices in [0,"
+                f" n_features - 1] = [0, {n_features - 1}], specifying the position of"
+                f" features, got invalid indices: {invalid_indices!r}"
+            )
+
+        constraints = [set(group) for group in self.interaction_cst]
+
+        # Add all features not listed in any group as their own group by default.
+        rest = set(range(n_features)) - set().union(*constraints)
+        if len(rest) > 0:
+            constraints.append(rest)
+
+        return constraints
+
     def fit(self, X, y, sample_weight=None):
         """Fit the gradient boosting model.
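The fill-in rule at the end of `_check_interaction_cst` can be mimicked standalone (a sketch of the normalization step only, without the estimator plumbing or type checks; the function name `normalize_interaction_cst` is hypothetical):

```python
def normalize_interaction_cst(interaction_cst, n_features):
    """Normalize constraint groups: any feature not listed in some group is
    collected into one extra group, mirroring the default described above."""
    constraints = [set(group) for group in interaction_cst]
    # Features not covered by any group form one additional group.
    rest = set(range(n_features)) - set().union(*constraints)
    if rest:
        constraints.append(rest)
    return constraints


# With 3 features, [{0}] is equivalent to [{0}, {1, 2}].
print(normalize_interaction_cst([{0}], n_features=3))  # [{0}, {1, 2}]
```

This matches both the docs' 3-feature example and the 5-feature example in the docstrings below, where `[{0, 1}]` becomes `[{0, 1}, {2, 3, 4}]`.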
@@ -308,6 +348,9 @@ def fit(self, X, y, sample_weight=None):

         self.is_categorical_, known_categories = self._check_categories(X)

+        # Encode constraints into a list of sets of feature indices (integers).
+        interaction_cst = self._check_interaction_cst(self._n_features)
+
         # we need this stateful variable to tell raw_predict() that it was
         # called from fit() (this current method), and that the data it has
         # received is pre-binned.
@@ -595,6 +638,7 @@
             has_missing_values=has_missing_values,
             is_categorical=self.is_categorical_,
             monotonic_cst=self.monotonic_cst,
+            interaction_cst=interaction_cst,
             max_leaf_nodes=self.max_leaf_nodes,
             max_depth=self.max_depth,
             min_samples_leaf=self.min_samples_leaf,
@@ -1193,6 +1237,22 @@ class HistGradientBoostingRegressor(RegressorMixin, BaseHistGradientBoosting):

         .. versionadded:: 0.23

+    interaction_cst : iterable of iterables of int, default=None
+        Specify interaction constraints, i.e. sets of features which can
+        only interact with each other in child node splits.
+
+        Each iterable materializes a constraint by the set of indices of
+        the features that are allowed to interact with each other.
+        If there are more features than specified in these constraints,
+        they are treated as if they were specified as an additional set.
+
+        For instance, with 5 features in total, `interaction_cst=[{0, 1}]`
+        is equivalent to `interaction_cst=[{0, 1}, {2, 3, 4}]`,
+        and specifies that each branch of a tree will either only split
+        on features 0 and 1 or only split on features 2, 3 and 4.
+
+        .. versionadded:: 1.2
+
     warm_start : bool, default=False
         When set to ``True``, reuse the solution of the previous call to fit
         and add more estimators to the ensemble. For results to be valid, the
@@ -1317,6 +1377,7 @@ def __init__(
         max_bins=255,
         categorical_features=None,
         monotonic_cst=None,
+        interaction_cst=None,
         warm_start=False,
         early_stopping="auto",
         scoring="loss",
@@ -1336,6 +1397,7 @@
             l2_regularization=l2_regularization,
             max_bins=max_bins,
             monotonic_cst=monotonic_cst,
+            interaction_cst=interaction_cst,
             categorical_features=categorical_features,
             early_stopping=early_stopping,
             warm_start=warm_start,
@@ -1509,6 +1571,22 @@ class HistGradientBoostingClassifier(ClassifierMixin, BaseHistGradientBoosting):

         .. versionadded:: 0.23

+    interaction_cst : iterable of iterables of int, default=None
+        Specify interaction constraints, i.e. sets of features which can
+        only interact with each other in child node splits.
+
+        Each iterable materializes a constraint by the set of indices of
+        the features that are allowed to interact with each other.
+        If there are more features than specified in these constraints,
+        they are treated as if they were specified as an additional set.
+
+        For instance, with 5 features in total, `interaction_cst=[{0, 1}]`
+        is equivalent to `interaction_cst=[{0, 1}, {2, 3, 4}]`,
+        and specifies that each branch of a tree will either only split
+        on features 0 and 1 or only split on features 2, 3 and 4.
+
+        .. versionadded:: 1.2
+
     warm_start : bool, default=False
         When set to ``True``, reuse the solution of the previous call to fit
         and add more estimators to the ensemble. For results to be valid, the
@@ -1657,6 +1735,7 @@ def __init__(
         max_bins=255,
         categorical_features=None,
         monotonic_cst=None,
+        interaction_cst=None,
         warm_start=False,
         early_stopping="auto",
         scoring="loss",
@@ -1678,6 +1757,7 @@
             max_bins=max_bins,
             categorical_features=categorical_features,
             monotonic_cst=monotonic_cst,
+            interaction_cst=interaction_cst,
             warm_start=warm_start,
             early_stopping=early_stopping,
             scoring=scoring,
