8000 ENH Add custom loss support for HistGradientBoosting (#16908) · scikit-learn/scikit-learn@9d366a4 · GitHub
[go: up one dir, main page]

Skip to content

Commit 9d366a4

Browse files
authored
ENH Add custom loss support for HistGradientBoosting (#16908)
1 parent cb9ddbb commit 9d366a4

File tree

2 files changed

+29
-2
lines changed

2 files changed

+29
-2
lines changed

sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from .binning import _BinMapper
2424
from .grower import TreeGrower
2525
from .loss import _LOSSES
26+
from .loss import BaseLoss
2627

2728

2829
class BaseHistGradientBoosting(BaseEstimator, ABC):
@@ -58,7 +59,8 @@ def _validate_parameters(self):
5859
The parameters that are directly passed to the grower are checked in
5960
TreeGrower."""
6061

61-
if self.loss not in self._VALID_LOSSES:
62+
if (self.loss not in self._VALID_LOSSES and
63+
not isinstance(self.loss, BaseLoss)):
6264
raise ValueError(
6365
"Loss {} is not supported for {}. Accepted losses: "
6466
"{}.".format(self.loss, self.__class__.__name__,
@@ -150,7 +152,11 @@ def fit(self, X, y, sample_weight=None):
150152
# data.
151153
self._in_fit = True
152154

153-
self.loss_ = self._get_loss(sample_weight=sample_weight)
155+
if isinstance(self.loss, str):
156+
self.loss_ = self._get_loss(sample_weight=sample_weight)
157+
elif isinstance(self.loss, BaseLoss):
158+
self.loss_ = self.loss
159+
154160
if self.early_stopping == 'auto':
155161
self.do_early_stopping_ = n_samples > 10000
156162
else:

sklearn/ensemble/_hist_gradient_boosting/tests/test_gradient_boosting.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
from sklearn.ensemble import HistGradientBoostingRegressor
1313
from sklearn.ensemble import HistGradientBoostingClassifier
1414
from sklearn.ensemble._hist_gradient_boosting.loss import _LOSSES
15+
from sklearn.ensemble._hist_gradient_boosting.loss import LeastSquares
16+
from sklearn.ensemble._hist_gradient_boosting.loss import BinaryCrossEntropy
1517
from sklearn.ensemble._hist_gradient_boosting.grower import TreeGrower
1618
from sklearn.ensemble._hist_gradient_boosting.binning import _BinMapper
1719
from sklearn.utils import shuffle
@@ -681,3 +683,22 @@ def test_single_node_trees(Est):
681683
for predictor in est._predictors)
682684
# Still gives correct predictions thanks to the baseline prediction
683685
assert_allclose(est.predict(X), y)
686+
687+
688+
@pytest.mark.parametrize('Est, loss, X, y', [
689+
(
690+
HistGradientBoostingClassifier,
691+
BinaryCrossEntropy(sample_weight=None),
692+
X_classification,
693+
y_classification
694+
),
695+
(
696+
HistGradientBoostingRegressor,
697+
LeastSquares(sample_weight=None),
698+
X_regression,
699+
y_regression
700+
)
701+
])
702+
def test_custom_loss(Est, loss, X, y):
703+
est = Est(loss=loss, max_iter=20)
704+
est.fit(X, y)

0 commit comments

Comments (0)