From e4cc10e45d552aa62d8db020d944bfb37f51febc Mon Sep 17 00:00:00 2001
From: gbolmier
Date: Sun, 12 Apr 2020 18:50:55 -0400
Subject: [PATCH] Add custom loss support for HistGradientBoosting

Add custom loss support for HistGradientBoostingClassifier and
HistGradientBoostingRegressor as a private API without any
documentation. A `BaseLoss` object can now be passed as the `loss`
parameter.

Resolves: #15841
---
 .../gradient_boosting.py            | 10 +++++++--
 .../tests/test_gradient_boosting.py | 21 +++++++++++++++++++
 2 files changed, 29 insertions(+), 2 deletions(-)

diff --git a/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py b/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py
index 796f4f060dda5..6087adb0b6575 100644
--- a/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py
+++ b/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py
@@ -23,6 +23,7 @@
 from .binning import _BinMapper
 from .grower import TreeGrower
 from .loss import _LOSSES
+from .loss import BaseLoss
 
 
 class BaseHistGradientBoosting(BaseEstimator, ABC):
@@ -58,7 +59,8 @@ def _validate_parameters(self):
         The parameters that are directly passed to the grower are checked in
         TreeGrower."""
 
-        if self.loss not in self._VALID_LOSSES:
+        if (self.loss not in self._VALID_LOSSES and
+                not isinstance(self.loss, BaseLoss)):
             raise ValueError(
                 "Loss {} is not supported for {}. Accepted losses: "
                 "{}.".format(self.loss, self.__class__.__name__,
@@ -150,7 +152,11 @@ def fit(self, X, y, sample_weight=None):
         # data.
         self._in_fit = True
 
-        self.loss_ = self._get_loss(sample_weight=sample_weight)
+        if isinstance(self.loss, str):
+            self.loss_ = self._get_loss(sample_weight=sample_weight)
+        elif isinstance(self.loss, BaseLoss):
+            self.loss_ = self.loss
+
         if self.early_stopping == 'auto':
             self.do_early_stopping_ = n_samples > 10000
         else:
diff --git a/sklearn/ensemble/_hist_gradient_boosting/tests/test_gradient_boosting.py b/sklearn/ensemble/_hist_gradient_boosting/tests/test_gradient_boosting.py
index 1b61e65793422..6fc412942d180 100644
--- a/sklearn/ensemble/_hist_gradient_boosting/tests/test_gradient_boosting.py
+++ b/sklearn/ensemble/_hist_gradient_boosting/tests/test_gradient_boosting.py
@@ -12,6 +12,8 @@
 from sklearn.ensemble import HistGradientBoostingRegressor
 from sklearn.ensemble import HistGradientBoostingClassifier
 from sklearn.ensemble._hist_gradient_boosting.loss import _LOSSES
+from sklearn.ensemble._hist_gradient_boosting.loss import LeastSquares
+from sklearn.ensemble._hist_gradient_boosting.loss import BinaryCrossEntropy
 from sklearn.ensemble._hist_gradient_boosting.grower import TreeGrower
 from sklearn.ensemble._hist_gradient_boosting.binning import _BinMapper
 from sklearn.utils import shuffle
@@ -681,3 +683,22 @@ def test_single_node_trees(Est):
                for predictor in est._predictors)
     # Still gives correct predictions thanks to the baseline prediction
     assert_allclose(est.predict(X), y)
+
+
+@pytest.mark.parametrize('Est, loss, X, y', [
+    (
+        HistGradientBoostingClassifier,
+        BinaryCrossEntropy(sample_weight=None),
+        X_classification,
+        y_classification
+    ),
+    (
+        HistGradientBoostingRegressor,
+        LeastSquares(sample_weight=None),
+        X_regression,
+        y_regression
+    )
+])
+def test_custom_loss(Est, loss, X, y):
+    est = Est(loss=loss, max_iter=20)
+    est.fit(X, y)
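
Note for reviewers: a minimal usage sketch of the private API this patch
enables, mirroring the new test. The synthetic data and max_iter value are
illustrative only; LeastSquares is the existing built-in loss class behind
loss='least_squares'.

    # HistGradientBoosting was still experimental at this point, so the
    # enabling import is required before the estimator can be imported.
    import numpy as np
    from sklearn.experimental import enable_hist_gradient_boosting  # noqa
    from sklearn.ensemble import HistGradientBoostingRegressor
    from sklearn.ensemble._hist_gradient_boosting.loss import LeastSquares

    rng = np.random.RandomState(0)
    X = rng.normal(size=(100, 4))
    y = X @ rng.normal(size=4)

    # Pass a BaseLoss instance instead of a loss string. A custom loss
    # would subclass BaseLoss and provide the same interface that
    # LeastSquares implements.
    est = HistGradientBoostingRegressor(loss=LeastSquares(sample_weight=None),
                                        max_iter=20)
    est.fit(X, y)

With the patch applied, _validate_parameters accepts any BaseLoss instance,
and fit uses it directly instead of looking the loss up in _LOSSES.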