# Author: Nicolas Hug

from abc import ABC, abstractmethod
+from collections.abc import Iterable
from functools import partial
from numbers import Real, Integral
import warnings
@@ -91,6 +92,7 @@ class BaseHistGradientBoosting(BaseEstimator, ABC):
        "min_samples_leaf": [Interval(Integral, 1, None, closed="left")],
        "l2_regularization": [Interval(Real, 0, None, closed="left")],
        "monotonic_cst": ["array-like", None],
+        "interaction_cst": [Iterable, None],
        "n_iter_no_change": [Interval(Integral, 1, None, closed="left")],
        "validation_fraction": [
            Interval(Real, 0, 1, closed="neither"),
@@ -121,6 +123,7 @@ def __init__(
        max_bins,
        categorical_features,
        monotonic_cst,
+        interaction_cst,
        warm_start,
        early_stopping,
        scoring,
@@ -139,6 +142,7 @@ def __init__(
        self.l2_regularization = l2_regularization
        self.max_bins = max_bins
        self.monotonic_cst = monotonic_cst
+        self.interaction_cst = interaction_cst
        self.categorical_features = categorical_features
        self.warm_start = warm_start
        self.early_stopping = early_stopping
@@ -252,6 +256,42 @@ def _check_categories(self, X):

        return is_categorical, known_categories

+    def _check_interaction_cst(self, n_features):
+        """Check and validate interaction constraints."""
+        if self.interaction_cst is None:
+            return None
+
+        if not (
+            isinstance(self.interaction_cst, Iterable)
+            and all(isinstance(x, Iterable) for x in self.interaction_cst)
+        ):
+            raise ValueError(
+                "Interaction constraints must be None or an iterable of iterables, "
+                f"got: {self.interaction_cst!r}."
+            )
+
+        invalid_indices = [
+            x
+            for cst_set in self.interaction_cst
+            for x in cst_set
+            if not (isinstance(x, Integral) and 0 <= x < n_features)
+        ]
+        if invalid_indices:
+            raise ValueError(
+                "Interaction constraints must consist of integer indices in [0,"
+                f" n_features - 1] = [0, {n_features - 1}], specifying the position of"
+                f" features, got invalid indices: {invalid_indices!r}"
+            )
+
+        constraints = [set(group) for group in self.interaction_cst]
+
+        # By default, add all features not listed in any constraint as their own group.
+        rest = set(range(n_features)) - set().union(*constraints)
+        if len(rest) > 0:
+            constraints.append(rest)
+
+        return constraints
+
    def fit(self, X, y, sample_weight=None):
        """Fit the gradient boosting model.

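For illustration, a minimal sketch (not part of the diff) of the normalization performed by the helper above, assuming a scikit-learn build that already contains this change (1.2 or a development build). `_check_interaction_cst` is a private helper, so this only demonstrates the behavior described in its comments:

from sklearn.ensemble import HistGradientBoostingRegressor

# With 5 features and a single declared group {0, 1}, the unlisted features
# (2, 3 and 4) are collected into one additional group.
est = HistGradientBoostingRegressor(interaction_cst=[{0, 1}])
print(est._check_interaction_cst(n_features=5))
# [{0, 1}, {2, 3, 4}]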
@@ -308,6 +348,9 @@ def fit(self, X, y, sample_weight=None):

        self.is_categorical_, known_categories = self._check_categories(X)

+        # Encode constraints into a list of sets of feature indices (integers).
+        interaction_cst = self._check_interaction_cst(self._n_features)
+
        # we need this stateful variable to tell raw_predict() that it was
        # called from fit() (this current method), and that the data it has
        # received is pre-binned.
@@ -595,6 +638,7 @@ def fit(self, X, y, sample_weight=None):
                has_missing_values=has_missing_values,
                is_categorical=self.is_categorical_,
                monotonic_cst=self.monotonic_cst,
+                interaction_cst=interaction_cst,
                max_leaf_nodes=self.max_leaf_nodes,
                max_depth=self.max_depth,
                min_samples_leaf=self.min_samples_leaf,
@@ -1193,6 +1237,22 @@ class HistGradientBoostingRegressor(RegressorMixin, BaseHistGradientBoosting):

        .. versionadded:: 0.23

+    interaction_cst : iterable of iterables of int, default=None
+        Specify interaction constraints, i.e. sets of features which can
+        only interact with each other in child node splits.
+
+        Each iterable materializes a constraint by the set of indices of
+        the features that are allowed to interact with each other.
+        If there are more features than specified in these constraints,
+        they are treated as if they were specified as an additional set.
+
+        For instance, with 5 features in total, `interaction_cst=[{0, 1}]`
+        is equivalent to `interaction_cst=[{0, 1}, {2, 3, 4}]`,
+        and specifies that each branch of a tree will either only split
+        on features 0 and 1 or only split on features 2, 3 and 4.
+
+        .. versionadded:: 1.2
+
    warm_start : bool, default=False
        When set to ``True``, reuse the solution of the previous call to fit
        and add more estimators to the ensemble. For results to be valid, the
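As a usage sketch of the parameter documented above (not part of the diff; the synthetic data and availability in scikit-learn 1.2+ are assumptions):

import numpy as np
from sklearn.ensemble import HistGradientBoostingRegressor

rng = np.random.RandomState(0)
X = rng.randn(200, 5)
y = X[:, 0] * X[:, 1] + X[:, 2]

# Equivalent to interaction_cst=[{0, 1}, {2, 3, 4}]: each tree branch may
# split only on features 0 and 1, or only on features 2, 3 and 4.
reg = HistGradientBoostingRegressor(interaction_cst=[{0, 1}], random_state=0)
reg.fit(X, y)
print(reg.predict(X[:3]))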
@@ -1317,6 +1377,7 @@ def __init__(
        max_bins=255,
        categorical_features=None,
        monotonic_cst=None,
+        interaction_cst=None,
        warm_start=False,
        early_stopping="auto",
        scoring="loss",
@@ -1336,6 +1397,7 @@ def __init__(
            l2_regularization=l2_regularization,
            max_bins=max_bins,
            monotonic_cst=monotonic_cst,
+            interaction_cst=interaction_cst,
            categorical_features=categorical_features,
            early_stopping=early_stopping,
            warm_start=warm_start,
@@ -1509,6 +1571,22 @@ class HistGradientBoostingClassifier(ClassifierMixin, BaseHistGradientBoosting):

        .. versionadded:: 0.23

+    interaction_cst : iterable of iterables of int, default=None
+        Specify interaction constraints, i.e. sets of features which can
+        only interact with each other in child node splits.
+
+        Each iterable materializes a constraint by the set of indices of
+        the features that are allowed to interact with each other.
+        If there are more features than specified in these constraints,
+        they are treated as if they were specified as an additional set.
+
+        For instance, with 5 features in total, `interaction_cst=[{0, 1}]`
+        is equivalent to `interaction_cst=[{0, 1}, {2, 3, 4}]`,
+        and specifies that each branch of a tree will either only split
+        on features 0 and 1 or only split on features 2, 3 and 4.
+
+        .. versionadded:: 1.2
1589
+
1512
1590
warm_start : bool, default=False
1513
1591
When set to ``True``, reuse the solution of the previous call to fit
1514
1592
and add more estimators to the ensemble. For results to be valid, the
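The classifier documents the same parameter. Per the validation added in `_check_interaction_cst`, any iterable of iterables of int is accepted (for example a list of lists), and each group is converted to a set internally. A short sketch (not part of the diff; synthetic data and scikit-learn 1.2+ assumed):

import numpy as np
from sklearn.ensemble import HistGradientBoostingClassifier

rng = np.random.RandomState(0)
X = rng.randn(200, 4)
y = (X[:, 0] + X[:, 1] > 0).astype(int)

# Two groups: a branch may combine features 0 and 1, or features 2 and 3,
# but never features from different groups.
clf = HistGradientBoostingClassifier(interaction_cst=[[0, 1], [2, 3]], random_state=0)
clf.fit(X, y)
print(clf.score(X, y))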
@@ -1657,6 +1735,7 @@ def __init__(
        max_bins=255,
        categorical_features=None,
        monotonic_cst=None,
+        interaction_cst=None,
        warm_start=False,
        early_stopping="auto",
        scoring="loss",
@@ -1678,6 +1757,7 @@ def __init__(
            max_bins=max_bins,
            categorical_features=categorical_features,
            monotonic_cst=monotonic_cst,
+            interaction_cst=interaction_cst,
            warm_start=warm_start,
            early_stopping=early_stopping,
            scoring=scoring,