3
3
4
4
from abc import ABC , abstractmethod
5
5
from functools import partial
6
+ import itertools
6
7
from numbers import Real , Integral
7
8
import warnings
8
9
@@ -92,7 +93,12 @@ class BaseHistGradientBoosting(BaseEstimator, ABC):
92
93
"min_samples_leaf" : [Interval (Integral , 1 , None , closed = "left" )],
93
94
"l2_regularization" : [Interval (Real , 0 , None , closed = "left" )],
94
95
"monotonic_cst" : ["array-like" , dict , None ],
95
- "interaction_cst" : [list , tuple , None ],
96
+ "interaction_cst" : [
97
+ list ,
98
+ tuple ,
99
+ StrOptions ({"pairwise" , "no_interactions" }),
100
+ None ,
101
+ ],
96
102
"n_iter_no_change" : [Interval (Integral , 1 , None , closed = "left" )],
97
103
"validation_fraction" : [
98
104
Interval (Real , 0 , 1 , closed = "neither" ),
@@ -288,8 +294,15 @@ def _check_interaction_cst(self, n_features):
288
294
if self .interaction_cst is None :
289
295
return None
290
296
297
+ if self .interaction_cst == "no_interactions" :
298
+ interaction_cst = [[i ] for i in range (n_features )]
299
+ elif self .interaction_cst == "pairwise" :
300
+ interaction_cst = itertools .combinations (range (n_features ), 2 )
301
+ else :
302
+ interaction_cst = self .interaction_cst
303
+
291
304
try :
292
- constraints = [set (group ) for group in self . interaction_cst ]
305
+ constraints = [set (group ) for group in interaction_cst ]
293
306
except TypeError :
294
307
raise ValueError (
295
308
"Interaction constraints must be a sequence of tuples or lists, got:"
@@ -1275,7 +1288,8 @@ class HistGradientBoostingRegressor(RegressorMixin, BaseHistGradientBoosting):
1275
1288
.. versionchanged:: 1.2
1276
1289
Accept dict of constraints with feature names as keys.
1277
1290
1278
- interaction_cst : sequence of lists/tuples/sets of int, default=None
1291
+ interaction_cst : {"pairwise", "no_interaction"} or sequence of lists/tuples/sets \
1292
+ of int, default=None
1279
1293
Specify interaction constraints, the sets of features which can
1280
1294
interact with each other in child node splits.
1281
1295
@@ -1284,6 +1298,9 @@ class HistGradientBoostingRegressor(RegressorMixin, BaseHistGradientBoosting):
1284
1298
specified in these constraints, they are treated as if they were
1285
1299
specified as an additional set.
1286
1300
1301
+ The strings "pairwise" and "no_interactions" are shorthands for
1302
+ allowing only pairwise or no interactions, respectively.
1303
+
1287
1304
For instance, with 5 features in total, `interaction_cst=[{0, 1}]`
1288
1305
is equivalent to `interaction_cst=[{0, 1}, {2, 3, 4}]`,
1289
1306
and specifies that each branch of a tree will either only split
@@ -1623,7 +1640,8 @@ class HistGradientBoostingClassifier(ClassifierMixin, BaseHistGradientBoosting):
1623
1640
.. versionchanged:: 1.2
1624
1641
Accept dict of constraints with feature names as keys.
1625
1642
1626
- interaction_cst : sequence of lists/tuples/sets of int, default=None
1643
+ interaction_cst : {"pairwise", "no_interaction"} or sequence of lists/tuples/sets \
1644
+ of int, default=None
1627
1645
Specify interaction constraints, the sets of features which can
1628
1646
interact with each other in child node splits.
1629
1647
@@ -1632,6 +1650,9 @@ class HistGradientBoostingClassifier(ClassifierMixin, BaseHistGradientBoosting):
1632
1650
specified in these constraints, they are treated as if they were
1633
1651
specified as an additional set.
1634
1652
1653
+ The strings "pairwise" and "no_interactions" are shorthands for
1654
+ allowing only pairwise or no interactions, respectively.
1655
+
1635
1656
For instance, with 5 features in total, `interaction_cst=[{0, 1}]`
1636
1657
is equivalent to `interaction_cst=[{0, 1}, {2, 3, 4}]`,
1637
1658
and specifies that each branch of a tree will either only split
0 commit comments