22# Author: Nicolas Hug
33
44from abc import ABC , abstractmethod
5+ from collections .abc import Iterable
56from functools import partial
67from numbers import Real , Integral
78import warnings
@@ -91,6 +92,7 @@ class BaseHistGradientBoosting(BaseEstimator, ABC):
9192 "min_samples_leaf" : [Interval (Integral , 1 , None , closed = "left" )],
9293 "l2_regularization" : [Interval (Real , 0 , None , closed = "left" )],
9394 "monotonic_cst" : ["array-like" , None ],
95+ "interaction_cst" : [Iterable , None ],
9496 "n_iter_no_change" : [Interval (Integral , 1 , None , closed = "left" )],
9597 "validation_fraction" : [
9698 Interval (Real , 0 , 1 , closed = "neither" ),
@@ -121,6 +123,7 @@ def __init__(
121123 max_bins ,
122124 categorical_features ,
123125 monotonic_cst ,
126+ interaction_cst ,
124127 warm_start ,
125128 early_stopping ,
126129 scoring ,
@@ -139,6 +142,7 @@ def __init__(
139142 self .l2_regularization = l2_regularization
140143 self .max_bins = max_bins
141144 self .monotonic_cst = monotonic_cst
145+ self .interaction_cst = interaction_cst
142146 self .categorical_features = categorical_features
143147 self .warm_start = warm_start
144148 self .early_stopping = early_stopping
@@ -252,6 +256,42 @@ def _check_categories(self, X):
252256
253257 return is_categorical , known_categories
254258
259+ def _check_interaction_cst (self , n_features ):
260+ """Check and validation for interaction constraints."""
261+ if self .interaction_cst is None :
262+ return None
263+
264+ if not (
265+ isinstance (self .interaction_cst , Iterable )
266+ and all (isinstance (x , Iterable ) for x in self .interaction_cst )
267+ ):
268+ raise ValueError (
269+ "Interaction constraints must be None or an iterable of iterables, "
270+ f"got: { self .interaction_cst !r} ."
271+ )
272+
273+ invalid_indices = [
274+ x
275+ for cst_set in self .interaction_cst
276+ for x in cst_set
277+ if not (isinstance (x , Integral ) and 0 <= x < n_features )
278+ ]
279+ if invalid_indices :
280+ raise ValueError (
281+ "Interaction constraints must consist of integer indices in [0,"
282+ f" n_features - 1] = [0, { n_features - 1 } ], specifying the position of"
283+ f" features, got invalid indices: { invalid_indices !r} "
284+ )
285+
286+ constraints = [set (group ) for group in self .interaction_cst ]
287+
288+ # Add all not listed features as own group by default.
289+ rest = set (range (n_features )) - set ().union (* constraints )
290+ if len (rest ) > 0 :
291+ constraints .append (rest )
292+
293+ return constraints
294+
255295 def fit (self , X , y , sample_weight = None ):
256296 """Fit the gradient boosting model.
257297
@@ -308,6 +348,9 @@ def fit(self, X, y, sample_weight=None):
308348
309349 self .is_categorical_ , known_categories = self ._check_categories (X )
310350
351+ # Encode constraints into a list of sets of features indices (integers).
352+ interaction_cst = self ._check_interaction_cst (self ._n_features )
353+
311354 # we need this stateful variable to tell raw_predict() that it was
312355 # called from fit() (this current method), and that the data it has
313356 # received is pre-binned.
@@ -595,6 +638,7 @@ def fit(self, X, y, sample_weight=None):
595638 has_missing_values = has_missing_values ,
596639 is_categorical = self .is_categorical_ ,
597640 monotonic_cst = self .monotonic_cst ,
641+ interaction_cst = interaction_cst ,
598642 max_leaf_nodes = self .max_leaf_nodes ,
599643 max_depth = self .max_depth ,
600644 min_samples_leaf = self .min_samples_leaf ,
@@ -1193,6 +1237,22 @@ class HistGradientBoostingRegressor(RegressorMixin, BaseHistGradientBoosting):
11931237
11941238 .. versionadded:: 0.23
11951239
1240+ interaction_cst : iterable of iterables of int, default=None
1241+ Specify interaction constraints, i.e. sets of features which can
1242+ only interact with each other in child nodes splits.
1243+
1244+ Each iterable materializes a constraint by the set of indices of
1245+ the features that are allowed to interact with each other.
1246+ If there are more features than specified in these constraints,
1247+ they are treated as if they were specified as an additional set.
1248+
1249+ For instance, with 5 features in total, `interaction_cst=[{0, 1}]`
1250+ is equivalent to `interaction_cst=[{0, 1}, {2, 3, 4}]`,
1251+ and specifies that each branch of a tree will either only split
1252+ on features 0 and 1 or only split on features 2, 3 and 4.
1253+
1254+ .. versionadded:: 1.2
1255+
11961256 warm_start : bool, default=False
11971257 When set to ``True``, reuse the solution of the previous call to fit
11981258 and add more estimators to the ensemble. For results to be valid, the
@@ -1317,6 +1377,7 @@ def __init__(
13171377 max_bins = 255 ,
13181378 categorical_features = None ,
13191379 monotonic_cst = None ,
1380+ interaction_cst = None ,
13201381 warm_start = False ,
13211382 early_stopping = "auto" ,
13221383 scoring = "loss" ,
@@ -1336,6 +1397,7 @@ def __init__(
13361397 l2_regularization = l2_regularization ,
13371398 max_bins = max_bins ,
13381399 monotonic_cst = monotonic_cst ,
1400+ interaction_cst = interaction_cst ,
13391401 categorical_features = categorical_features ,
13401402 early_stopping = early_stopping ,
13411403 warm_start = warm_start ,
@@ -1509,6 +1571,22 @@ class HistGradientBoostingClassifier(ClassifierMixin, BaseHistGradientBoosting):
15091571
15101572 .. versionadded:: 0.23
15111573
1574+ interaction_cst : iterable of iterables of int, default=None
1575+ Specify interaction constraints, i.e. sets of features which can
1576+ only interact with each other in child nodes splits.
1577+
1578+ Each iterable materializes a constraint by the set of indices of
1579+ the features that are allowed to interact with each other.
1580+ If there are more features than specified in these constraints,
1581+ they are treated as if they were specified as an additional set.
1582+
1583+ For instance, with 5 features in total, `interaction_cst=[{0, 1}]`
1584+ is equivalent to `interaction_cst=[{0, 1}, {2, 3, 4}]`,
1585+ and specifies that each branch of a tree will either only split
1586+ on features 0 and 1 or only split on features 2, 3 and 4.
1587+
1588+ .. versionadded:: 1.2
1589+
15121590 warm_start : bool, default=False
15131591 When set to ``True``, reuse the solution of the previous call to fit
15141592 and add more estimators to the ensemble. For results to be valid, the
@@ -1657,6 +1735,7 @@ def __init__(
16571735 max_bins = 255 ,
16581736 categorical_features = None ,
16591737 monotonic_cst = None ,
1738+ interaction_cst = None ,
16601739 warm_start = False ,
16611740 early_stopping = "auto" ,
16621741 scoring = "loss" ,
@@ -1678,6 +1757,7 @@ def __init__(
16781757 max_bins = max_bins ,
16791758 categorical_features = categorical_features ,
16801759 monotonic_cst = monotonic_cst ,
1760+ interaction_cst = interaction_cst ,
16811761 warm_start = warm_start ,
16821762 early_stopping = early_stopping ,
16831763 scoring = scoring ,
0 commit comments