diff --git a/benchmarks/bench_hist_gradient_boosting.py b/benchmarks/bench_hist_gradient_boosting.py
index 5f070bd45708d..9bfd6d743ee4f 100644
--- a/benchmarks/bench_hist_gradient_boosting.py
+++ b/benchmarks/bench_hist_gradient_boosting.py
@@ -26,6 +26,7 @@
 parser.add_argument('--learning-rate', type=float, default=.1)
 parser.add_argument('--problem', type=str, default='classification',
                     choices=['classification', 'regression'])
+parser.add_argument('--loss', type=str, default='default')
 parser.add_argument('--missing-fraction', type=float, default=0)
 parser.add_argument('--n-classes', type=int, default=2)
 parser.add_argument('--n-samples-max', type=int, default=int(1e6))
@@ -81,6 +82,17 @@ def one_run(n_samples):
                     n_iter_no_change=None,
                     random_state=0,
                     verbose=0)
+    loss = args.loss
+    if args.problem == 'classification':
+        if loss == 'default':
+            # loss='auto' does not work with get_equivalent_estimator()
+            loss = 'binary_crossentropy' if args.n_classes == 2 else \
+                'categorical_crossentropy'
+    else:
+        # regression
+        if loss == 'default':
+            loss = 'least_squares'
+    est.set_params(loss=loss)
     est.fit(X_train, y_train)
     sklearn_fit_duration = time() - tic
     tic = time()
@@ -95,11 +107,6 @@ def one_run(n_samples):
     lightgbm_score_duration = None
     if args.lightgbm:
         print("Fitting a LightGBM model...")
-        # get_lightgbm does not accept loss='auto'
-        if args.problem == 'classification':
-            loss = 'binary_crossentropy' if args.n_classes == 2 else \
-                'categorical_crossentropy'
-        est.set_params(loss=loss)
         lightgbm_est = get_equivalent_estimator(est, lib='lightgbm')

         tic = time()
@@ -117,11 +124,6 @@ def one_run(n_samples):
     xgb_score_duration = None
     if args.xgboost:
         print("Fitting an XGBoost model...")
-        # get_xgb does not accept loss='auto'
-        if args.problem == 'classification':
-            loss = 'binary_crossentropy' if args.n_classes == 2 else \
-                'categorical_crossentropy'
-        est.set_params(loss=loss)
         xgb_est = get_equivalent_estimator(est, lib='xgboost')

         tic = time()
@@ -139,11 +141,6 @@ def one_run(n_samples):
     cat_score_duration = None
     if args.catboost:
         print("Fitting a CatBoost model...")
-        # get_cat does not accept loss='auto'
-        if args.problem == 'classification':
-            loss = 'binary_crossentropy' if args.n_classes == 2 else \
-                'categorical_crossentropy'
-        est.set_params(loss=loss)
         cat_est = get_equivalent_estimator(est, lib='catboost')

         tic = time()
diff --git a/doc/modules/ensemble.rst b/doc/modules/ensemble.rst
index fde8f40db6c8c..da5aaebfb6870 100644
--- a/doc/modules/ensemble.rst
+++ b/doc/modules/ensemble.rst
@@ -878,6 +878,13 @@ controls the number of iterations of the boosting process::
     >>> clf.score(X_test, y_test)
     0.8965

+Available losses for regression are 'least_squares' and
+'least_absolute_deviation', the latter being less sensitive to outliers. For
+classification, 'binary_crossentropy' is used for binary classification and
+'categorical_crossentropy' is used for multiclass classification. By default,
+the classification loss is 'auto', which selects the appropriate loss
+depending on :term:`y` passed to :term:`fit`.
+
 The size of the trees can be controlled through the ``max_leaf_nodes``,
 ``max_depth``, and ``min_samples_leaf`` parameters.

diff --git a/doc/whats_new/v0.22.rst b/doc/whats_new/v0.22.rst
index 462a420a0d3ce..3268f0761ad86 100644
--- a/doc/whats_new/v0.22.rst
+++ b/doc/whats_new/v0.22.rst
@@ -130,6 +130,8 @@ Changelog
 - |Feature| :func:`inspection.partial_dependence` and
   :func:`inspection.plot_partial_dependence` now support the fast 'recursion'
   method for both estimators. :pr:`13769` by `Nicolas Hug`_.
+- |Enhancement| :class:`ensemble.HistGradientBoostingRegressor` now supports
+  the 'least_absolute_deviation' loss. :pr:`13896` by `Nicolas Hug`_.
 - |Fix| Estimators now bin the training and validation data separately to
   avoid any data leak. :pr:`13933` by `Nicolas Hug`_.
 - |Fix| Fixed a bug where early stopping would break with string targets.
diff --git a/sklearn/ensemble/_hist_gradient_boosting/_loss.pyx b/sklearn/ensemble/_hist_gradient_boosting/_loss.pyx
index ff17654840005..418a9124d37fa 100644
--- a/sklearn/ensemble/_hist_gradient_boosting/_loss.pyx
+++ b/sklearn/ensemble/_hist_gradient_boosting/_loss.pyx
@@ -27,12 +27,24 @@ def _update_gradients_least_squares(

     n_samples = raw_predictions.shape[0]
     for i in prange(n_samples, schedule='static', nogil=True):
-        # Note: a more correct exp is 2 * (raw_predictions - y_true) but
-        # since we use 1 for the constant hessian value (and not 2) this
-        # is strictly equivalent for the leaves values.
        gradients[i] = raw_predictions[i] - y_true[i]


+def _update_gradients_least_absolute_deviation(
+        G_H_DTYPE_C [::1] gradients,  # OUT
+        const Y_DTYPE_C [::1] y_true,  # IN
+        const Y_DTYPE_C [::1] raw_predictions):  # IN
+
+    cdef:
+        int n_samples
+        int i
+
+    n_samples = raw_predictions.shape[0]
+    for i in prange(n_samples, schedule='static', nogil=True):
+        # gradient = sign(raw_predictions[i] - y_true[i])
+        gradients[i] = 2 * (y_true[i] - raw_predictions[i] < 0) - 1
+
+
 def _update_gradients_hessians_binary_crossentropy(
         G_H_DTYPE_C [::1] gradients,  # OUT
         G_H_DTYPE_C [::1] hessians,  # OUT
diff --git a/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py b/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py
index 70a507d09c1c6..3ba58f700c062 100644
--- a/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py
+++ b/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py
@@ -322,6 +322,10 @@ def fit(self, X, y):
                 acc_find_split_time += grower.total_find_split_time
                 acc_compute_hist_time += grower.total_compute_hist_time

+                if self.loss_.need_update_leaves_values:
+                    self.loss_.update_leaves_values(grower, y_train,
+                                                    raw_predictions[k, :])
+
                 predictor = grower.make_predictor(
                     bin_thresholds=self.bin_mapper_.bin_thresholds_
                 )
@@ -672,7 +676,8 @@ class HistGradientBoostingRegressor(BaseHistGradientBoosting, RegressorMixin):

     Parameters
     ----------
-    loss : {'least_squares'}, optional (default='least_squares')
+    loss : {'least_squares', 'least_absolute_deviation'}, \
+        optional (default='least_squares')
         The loss function to use in the boosting process. Note that the
         "least squares" loss actually implements an "half least squares
         loss" to simplify the computation of the gradient.
@@ -770,7 +775,7 @@ class HistGradientBoostingRegressor(BaseHistGradientBoosting, RegressorMixin):
     0.98...
""" - _VALID_LOSSES = ('least_squares',) + _VALID_LOSSES = ('least_squares', 'least_absolute_deviation') def __init__(self, loss='least_squares', learning_rate=0.1, max_iter=100, max_leaf_nodes=31, max_depth=None, diff --git a/sklearn/ensemble/_hist_gradient_boosting/loss.py b/sklearn/ensemble/_hist_gradient_boosting/loss.py index 9e00187d62425..bcfec023b5571 100644 --- a/sklearn/ensemble/_hist_gradient_boosting/loss.py +++ b/sklearn/ensemble/_hist_gradient_boosting/loss.py @@ -18,6 +18,7 @@ from .common import Y_DTYPE from .common import G_H_DTYPE from ._loss import _update_gradients_least_squares +from ._loss import _update_gradients_least_absolute_deviation from ._loss import _update_gradients_hessians_binary_crossentropy from ._loss import _update_gradients_hessians_categorical_crossentropy @@ -25,6 +26,16 @@ class BaseLoss(ABC): """Base class for a loss.""" + # This variable indicates whether the loss requires the leaves values to + # be updated once the tree has been trained. The trees are trained to + # predict a Newton-Raphson step (see grower._finalize_leaf()). But for + # some losses (e.g. least absolute deviation) we need to adjust the tree + # values to account for the "line search" of the gradient descent + # procedure. See the original paper Greedy Function Approximation: A + # Gradient Boosting Machine by Friedman + # (https://statweb.stanford.edu/~jhf/ftp/trebst.pdf) for the theory. + need_update_leaves_values = False + def init_gradients_and_hessians(self, n_samples, prediction_dim): """Return initial gradients and hessians. @@ -53,9 +64,10 @@ def init_gradients_and_hessians(self, n_samples, prediction_dim): shape = (prediction_dim, n_samples) gradients = np.empty(shape=shape, dtype=G_H_DTYPE) if self.hessians_are_constant: - # if the hessians are constant, we consider they are equal to 1. - # this is correct as long as we adjust the gradients. See e.g. LS - # loss + # If the hessians are constant, we consider they are equal to 1. + # - This is correct for the half LS loss + # - For LAD loss, hessians are actually 0, but they are always + # ignored anyway. hessians = np.ones(shape=(1, 1), dtype=G_H_DTYPE) else: hessians = np.empty(shape=shape, dtype=G_H_DTYPE) @@ -141,6 +153,63 @@ def update_gradients_and_hessians(self, gradients, hessians, y_true, _update_gradients_least_squares(gradients, y_true, raw_predictions) +class LeastAbsoluteDeviation(BaseLoss): + """Least asbolute deviation, for regression. + + For a given sample x_i, the loss is defined as:: + + loss(x_i) = |y_true_i - raw_pred_i| + """ + + hessians_are_constant = True + # This variable indicates whether the loss requires the leaves values to + # be updated once the tree has been trained. The trees are trained to + # predict a Newton-Raphson step (see grower._finalize_leaf()). But for + # some losses (e.g. least absolute deviation) we need to adjust the tree + # values to account for the "line search" of the gradient descent + # procedure. See the original paper Greedy Function Approximation: A + # Gradient Boosting Machine by Friedman + # (https://statweb.stanford.edu/~jhf/ftp/trebst.pdf) for the theory. + need_update_leaves_values = True + + def __call__(self, y_true, raw_predictions, average=True): + # shape (1, n_samples) --> (n_samples,). reshape(-1) is more likely to + # return a view. 
+        raw_predictions = raw_predictions.reshape(-1)
+        loss = np.abs(y_true - raw_predictions)
+        return loss.mean() if average else loss
+
+    def get_baseline_prediction(self, y_train, prediction_dim):
+        return np.median(y_train)
+
+    @staticmethod
+    def inverse_link_function(raw_predictions):
+        return raw_predictions
+
+    def update_gradients_and_hessians(self, gradients, hessians, y_true,
+                                      raw_predictions):
+        # shape (1, n_samples) --> (n_samples,). reshape(-1) is more likely to
+        # return a view.
+        raw_predictions = raw_predictions.reshape(-1)
+        gradients = gradients.reshape(-1)
+        _update_gradients_least_absolute_deviation(gradients, y_true,
+                                                   raw_predictions)
+
+    def update_leaves_values(self, grower, y_true, raw_predictions):
+        # Update the values predicted by the tree with
+        # median(y_true - raw_predictions).
+        # See note about need_update_leaves_values in BaseLoss.
+
+        # TODO: ideally this should be computed in parallel over the leaves
+        # using something similar to _update_raw_predictions(), but this
+        # requires a cython version of median()
+        for leaf in grower.finalized_leaves:
+            indices = leaf.sample_indices
+            median_res = np.median(y_true[indices] - raw_predictions[indices])
+            leaf.value = grower.shrinkage * median_res
+            # Note that the regularization is ignored here
+
+
 class BinaryCrossEntropy(BaseLoss):
     """Binary cross-entropy loss, for binary classification.

@@ -242,6 +311,7 @@ def predict_proba(self, raw_predictions):

 _LOSSES = {
     'least_squares': LeastSquares,
+    'least_absolute_deviation': LeastAbsoluteDeviation,
     'binary_crossentropy': BinaryCrossEntropy,
     'categorical_crossentropy': CategoricalCrossEntropy
 }
diff --git a/sklearn/ensemble/_hist_gradient_boosting/tests/test_compare_lightgbm.py b/sklearn/ensemble/_hist_gradient_boosting/tests/test_compare_lightgbm.py
index 63d8c8fb1059d..32bb5dee4b197 100644
--- a/sklearn/ensemble/_hist_gradient_boosting/tests/test_compare_lightgbm.py
+++ b/sklearn/ensemble/_hist_gradient_boosting/tests/test_compare_lightgbm.py
@@ -39,6 +39,13 @@ def test_same_predictions_regression(seed, min_samples_leaf, n_samples,
     #   and max_leaf_nodes is low enough.
     # - To ignore discrepancies caused by small differences the binning
     #   strategy, data is pre-binned if n_samples > 255.
+    # - We don't check the least_absolute_deviation loss here. This is
+    #   because LightGBM's computation of the median (used for the initial
+    #   value of raw_prediction) is a bit off (they'll e.g. return midpoints
+    #   when there is no need to). Since these tests only run 1 iteration,
+    #   the discrepancy between the initial values leads to fairly large
+    #   differences in the predictions. These differences are much smaller
+    #   with more iterations.

     rng = np.random.RandomState(seed=seed)
     n_samples = n_samples
diff --git a/sklearn/ensemble/_hist_gradient_boosting/tests/test_gradient_boosting.py b/sklearn/ensemble/_hist_gradient_boosting/tests/test_gradient_boosting.py
index 5de49ef740295..0574b045523e7 100644
--- a/sklearn/ensemble/_hist_gradient_boosting/tests/test_gradient_boosting.py
+++ b/sklearn/ensemble/_hist_gradient_boosting/tests/test_gradient_boosting.py
@@ -155,6 +155,15 @@ def test_should_stop(scores, n_iter_no_change, tol, stopping):
     assert gbdt._should_stop(scores) == stopping


+def test_least_absolute_deviation():
+    # For coverage only.
+    X, y = make_regression(n_samples=500, random_state=0)
+    gbdt = HistGradientBoostingRegressor(loss='least_absolute_deviation',
+                                         random_state=0)
+    gbdt.fit(X, y)
+    assert gbdt.score(X, y) > .9
+
+
 def test_binning_train_validation_are_separated():
     # Make sure training and validation data are binned separately.
     # See issue 13926
diff --git a/sklearn/ensemble/_hist_gradient_boosting/tests/test_loss.py b/sklearn/ensemble/_hist_gradient_boosting/tests/test_loss.py
index b49acc52b6e40..8c300db993d3d 100644
--- a/sklearn/ensemble/_hist_gradient_boosting/tests/test_loss.py
+++ b/sklearn/ensemble/_hist_gradient_boosting/tests/test_loss.py
@@ -32,9 +32,12 @@ def get_hessians(y_true, raw_predictions):

     if loss.__class__.__name__ == 'LeastSquares':
         # hessians aren't updated because they're constant:
-        # the value is 1 because the loss is actually an half
+        # the value is 1 (and not 2) because the loss is actually a half
         # least squares loss.
         hessians = np.full_like(raw_predictions, fill_value=1)
+    elif loss.__class__.__name__ == 'LeastAbsoluteDeviation':
+        # hessians aren't updated because they're constant
+        hessians = np.full_like(raw_predictions, fill_value=0)

     return hessians

@@ -81,6 +84,7 @@ def fprime2(x):

 @pytest.mark.parametrize('loss, n_classes, prediction_dim', [
     ('least_squares', 0, 1),
+    ('least_absolute_deviation', 0, 1),
     ('binary_crossentropy', 2, 1),
     ('categorical_crossentropy', 3, 3),
 ])
@@ -94,7 +98,7 @@ def test_numerical_gradients(loss, n_classes, prediction_dim, seed=0):
     rng = np.random.RandomState(seed)

     n_samples = 100
-    if loss == 'least_squares':
+    if loss in ('least_squares', 'least_absolute_deviation'):
         y_true = rng.normal(size=n_samples).astype(Y_DTYPE)
     else:
         y_true = rng.randint(0, n_classes, size=n_samples).astype(Y_DTYPE)
@@ -128,11 +132,8 @@ def test_numerical_gradients(loss, n_classes, prediction_dim, seed=0):
         f = loss(y_true, raw_predictions, average=False)
         numerical_hessians = (f_plus_eps + f_minus_eps - 2 * f) / eps**2

-    def relative_error(a, b):
-        return np.abs(a - b) / np.maximum(np.abs(a), np.abs(b))
-
-    assert_allclose(numerical_gradients, gradients, rtol=1e-4)
-    assert_allclose(numerical_hessians, hessians, rtol=1e-4)
+    assert_allclose(numerical_gradients, gradients, rtol=1e-4, atol=1e-7)
+    assert_allclose(numerical_hessians, hessians, rtol=1e-4, atol=1e-7)


 def test_baseline_least_squares():
@@ -145,6 +146,22 @@ def test_baseline_least_squares():
     assert baseline_prediction.dtype == y_train.dtype
     # Make sure baseline prediction is the mean of all targets
     assert_almost_equal(baseline_prediction, y_train.mean())
+    assert np.allclose(loss.inverse_link_function(baseline_prediction),
+                       baseline_prediction)
+
+
+def test_baseline_least_absolute_deviation():
+    rng = np.random.RandomState(0)
+
+    loss = _LOSSES['least_absolute_deviation']()
+    y_train = rng.normal(size=100)
+    baseline_prediction = loss.get_baseline_prediction(y_train, 1)
+    assert baseline_prediction.shape == tuple()  # scalar
+    assert baseline_prediction.dtype == y_train.dtype
+    # Make sure baseline prediction is the median of all targets
+    assert np.allclose(loss.inverse_link_function(baseline_prediction),
+                       baseline_prediction)
+    assert baseline_prediction == pytest.approx(np.median(y_train))


 def test_baseline_binary_crossentropy():
diff --git a/sklearn/ensemble/_hist_gradient_boosting/utils.pyx b/sklearn/ensemble/_hist_gradient_boosting/utils.pyx
index 291c015fec5d3..4b1188b87e69e 100644
--- a/sklearn/ensemble/_hist_gradient_boosting/utils.pyx
+++ b/sklearn/ensemble/_hist_gradient_boosting/utils.pyx
@@ -43,6 +43,7 @@ def get_equivalent_estimator(estimator, lib='lightgbm'):

     lightgbm_loss_mapping = {
         'least_squares': 'regression_l2',
+        'least_absolute_deviation': 'regression_l1',
         'binary_crossentropy': 'binary',
         'categorical_crossentropy': 'multiclass'
     }
@@ -75,6 +76,7 @@ def get_equivalent_estimator(estimator, lib='lightgbm'):
     # XGB
     xgboost_loss_mapping = {
         'least_squares': 'reg:linear',
+        'least_absolute_deviation': 'LEAST_ABSOLUTE_DEV_NOT_SUPPORTED',
         'binary_crossentropy': 'reg:logistic',
         'categorical_crossentropy': 'multi:softmax'
     }
@@ -98,6 +100,8 @@ def get_equivalent_estimator(estimator, lib='lightgbm'):
     # Catboost
     catboost_loss_mapping = {
         'least_squares': 'RMSE',
+        # catboost does not support MAE when leaf_estimation_method is Newton
+        'least_absolute_deviation': 'LEAST_ABSOLUTE_DEV_NOT_SUPPORTED',
         'binary_crossentropy': 'Logloss',
         'categorical_crossentropy': 'MultiClass'
     }
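
Reviewer note, not part of the patch: the sketch below shows the new 'least_absolute_deviation' loss in use, assuming a scikit-learn build that includes the changes above (the experimental import is still required at this point). The toy data, the outlier corruption, and the comparison against 'least_squares' are illustrative assumptions, not taken from the patch or its tests.

# Illustrative sketch only -- assumes this patch is applied to scikit-learn.
import numpy as np
from sklearn.experimental import enable_hist_gradient_boosting  # noqa
from sklearn.ensemble import HistGradientBoostingRegressor

rng = np.random.RandomState(0)
X = rng.normal(size=(1000, 5))
y = X[:, 0] - 2 * X[:, 1] + rng.normal(scale=0.1, size=1000)
# Corrupt a small fraction of the targets with large outliers.
y[:20] += 100

# The absolute loss fits leaf values with medians of the residuals (see
# update_leaves_values above), so it should typically be less affected by
# the corrupted targets than the (half) least squares loss.
lad = HistGradientBoostingRegressor(loss='least_absolute_deviation',
                                    random_state=0).fit(X, y)
lsq = HistGradientBoostingRegressor(loss='least_squares',
                                    random_state=0).fit(X, y)

# Compare mean absolute error on the uncorrupted samples.
print(np.mean(np.abs(lad.predict(X[20:]) - y[20:])))
print(np.mean(np.abs(lsq.predict(X[20:]) - y[20:])))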