8000 ENH Add interaction constraint shortcuts to HistGradientBoosting* (#2… · scikit-learn/scikit-learn@df14322 · GitHub
[go: up one dir, main page]

Skip to content

Commit df14322

Browse files
betatimjeremiedbb
andauthored
ENH Add interaction constraint shortcuts to HistGradientBoosting* (#24849)
Co-authored-by: jeremie du boisberranger <jeremiedbb@yahoo.fr>
1 parent f76ea1b commit df14322

File tree

4 files changed

+36
-7
lines changed

4 files changed

+36
-7
lines changed

doc/whats_new/v1.2.rst

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -328,6 +328,12 @@ Changelog
328328
as value to specify monotonicity constraints for each feature.
329329
:pr:`24855` by :user:`Olivier Grisel <ogrisel>`.
330330

331+
- |Enhancement| Interaction constraints for
332+
:class:`ensemble.HistGradientBoostingClassifier`
333+
and :class:`ensemble.HistGradientBoostingRegressor` can now be specified
334+
as strings for two common cases: "no_interactions" and "pairwise" interactions.
335+
:pr:`24849` by :user:`Tim Head <betatim>`.
336+
331337
- |Fix| Fixed the issue where :class:`ensemble.AdaBoostClassifier` outputs
332338
NaN in feature importance when fitted with very small sample weight.
333339
:pr:`20415` by :user:`Zhehao Liu <MaxwellLZH>`.

examples/inspection/plot_partial_dependence.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -270,9 +270,7 @@
270270

271271
print("Training interaction constraint HistGradientBoostingRegressor...")
272272
tic = time()
273-
est_no_interactions = HistGradientBoostingRegressor(
274-
interaction_cst=[[i] for i in range(X_train.shape[1])]
275-
)
273+
est_no_interactions = HistGradientBoostingRegressor(interaction_cst="no_interactions")
276274
est_no_interactions.fit(X_train, y_train)
277275
print(f"done in {time() - tic:.3f}s")
278276

sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
from abc import ABC, abstractmethod
55
from functools import partial
6+
import itertools
67
from numbers import Real, Integral
78
import warnings
89

@@ -92,7 +93,12 @@ class BaseHistGradientBoosting(BaseEstimator, ABC):
9293
"min_samples_leaf": [Interval(Integral, 1, None, closed="left")],
9394
"l2_regularization": [Interval(Real, 0, None, closed="left")],
9495
"monotonic_cst": ["array-like", dict, None],
95-
"interaction_cst": [list, tuple, None],
96+
"interaction_cst": [
97+
list,
98+
tuple,
99+
StrOptions({"pairwise", "no_interactions"}),
100+
None,
101+
],
96102
"n_iter_no_change": [Interval(Integral, 1, None, closed="left")],
97103
"validation_fraction": [
98104
Interval(Real, 0, 1, closed="neither"),
@@ -288,8 +294,15 @@ def _check_interaction_cst(self, n_features):
288294
if self.interaction_cst is None:
289295
return None
290296

297+
if self.interaction_cst == "no_interactions":
298+
interaction_cst = [[i] for i in range(n_features)]
299+
elif self.interaction_cst == "pairwise":
300+
interaction_cst = itertools.combinations(range(n_features), 2)
301+
else:
302+
interaction_cst = self.interaction_cst
303+
291304
try:
292-
constraints = [set(group) for group in self.interaction_cst]
305+
constraints = [set(group) for group in interaction_cst]
293306
except TypeError:
294307
raise ValueError(
295308
"Interaction constraints must be a sequence of tuples or lists, got:"
@@ -1275,7 +1288,8 @@ class HistGradientBoostingRegressor(RegressorMixin, BaseHistGradientBoosting):
12751288
.. versionchanged:: 1.2
12761289
Accept dict of constraints with feature names as keys.
12771290
1278-
interaction_cst : sequence of lists/tuples/sets of int, default=None
1291+
interaction_cst : {"pairwise", "no_interaction"} or sequence of lists/tuples/sets \
1292+
of int, default=None
12791293
Specify interaction constraints, the sets of features which can
12801294
interact with each other in child node splits.
12811295
@@ -1284,6 +1298,9 @@ class HistGradientBoostingRegressor(RegressorMixin, BaseHistGradientBoosting):
12841298
specified in these constraints, they are treated as if they were
12851299
specified as an additional set.
12861300
1301+
The strings "pairwise" and "no_interactions" are shorthands for
1302+
allowing only pairwise or no interactions, respectively.
1303+
12871304
For instance, with 5 features in total, `interaction_cst=[{0, 1}]`
12881305
is equivalent to `interaction_cst=[{0, 1}, {2, 3, 4}]`,
12891306
and specifies that each branch of a tree will either only split
@@ -1623,7 +1640,8 @@ class HistGradientBoostingClassifier(ClassifierMixin, BaseHistGradientBoosting):
16231640
.. versionchanged:: 1.2
16241641
Accept dict of constraints with feature names as keys.
16251642
1626-
interaction_cst : sequence of lists/tuples/sets of int, default=None
1643+
interaction_cst : {"pairwise", "no_interaction"} or sequence of lists/tuples/sets \
1644+
of int, default=None
16271645
Specify interaction constraints, the sets of features which can
16281646
interact with each other in child node splits.
16291647
@@ -1632,6 +1650,9 @@ class HistGradientBoostingClassifier(ClassifierMixin, BaseHistGradientBoosting):
16321650
specified in these constraints, they are treated as if they were
16331651
specified as an additional set.
16341652
1653+
The strings "pairwise" and "no_interactions" are shorthands for
1654+
allowing only pairwise or no interactions, respectively.
1655+
16351656
For instance, with 5 features in total, `interaction_cst=[{0, 1}]`
16361657
is equivalent to `interaction_cst=[{0, 1}, {2, 3, 4}]`,
16371658
and specifies that each branch of a tree will either only split

8BFF sklearn/ensemble/_hist_gradient_boosting/tests/test_gradient_boosting.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1187,6 +1187,10 @@ def test_uint8_predict(Est):
11871187
[
11881188
(None, 931, None),
11891189
([{0, 1}], 2, [{0, 1}]),
1190+
("pairwise", 2, [{0, 1}]),
1191+
("pairwise", 4, [{0, 1}, {0, 2}, {0, 3}, {1, 2}, {1, 3}, {2, 3}]),
1192+
("no_interactions", 2, [{0}, {1}]),
1193+
("no_interactions", 4, [{0}, {1}, {2}, {3}]),
11901194
([(1, 0), [5, 1]], 6, [{0, 1}, {1, 5}, {2, 3, 4}]),
11911195
],
11921196
)

0 commit comments

Comments
 (0)
0