8000 DEP deprecate "mae" criterion in GadientBoosting estimators (#18326) · scikit-learn/scikit-learn@4de0f97 · GitHub
[go: up one dir, main page]

Skip to content

Commit 4de0f97

Browse files
authored
DEP deprecate "mae" criterion in GadientBoosting estimators (#18326)
1 parent 5008c28 commit 4de0f97

File tree

3 files changed

+50
-0
lines changed

3 files changed

+50
-0
lines changed

doc/whats_new/v0.24.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,11 @@ Changelog
199199
:class:`ensemble.GradientBoostingRegressor` and returns `1`.
200200
:pr:`17702` by :user:`Simona Maggio <simonamaggio>`.
201201

202+
- |API|: Mean absolute error ('mae') is now deprecated for the parameter
203+
``criterion`` in :class:`ensemble.GradientBoostingRegressor` and
204+
:class:`ensemble.GradientBoostingClassifier`.
205+
:pr:`18326` by :user:`Madhura Jayaratne <madhuracj>`.
206+
202207
:mod:`sklearn.exceptions`
203208
.........................
204209

sklearn/ensemble/_gb.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -358,6 +358,10 @@ def _check_initialized(self):
358358
"""Check that the estimator is initialized, raising an error if not.& 8000 quot;""
359359
check_is_fitted(self)
360360

361+
@abstractmethod
362+
def _warn_mae_for_criterion(self):
363+
pass
364+
361365
def fit(self, X, y, sample_weight=None, monitor=None):
362366
"""Fit the gradient boosting model.
363367
@@ -393,6 +397,10 @@ def fit(self, X, y, sample_weight=None, monitor=None):
393397
-------
394398
self : object
395399
"""
400+
if self.criterion == 'mae':
401+
# TODO: This should raise an error from 0.26
402+
self._warn_mae_for_criterion()
403+
396404
# if not warmstart - clear the estimator state
397405
if not self.warm_start:
398406
self._clear_state()
@@ -802,6 +810,10 @@ class GradientBoostingClassifier(ClassifierMixin, BaseGradientBoosting):
802810
some cases.
803811
804812
.. versionadded:: 0.18
813+
.. deprecated:: 0.24
814+
`criterion='mae'` is deprecated and will be removed in version
815+
0.26. Use `criterion='friedman_mse'` or `'mse'` instead, as trees
816+
should use a least-square criterion in Gradient Boosting.
805817
806818
min_samples_split : int or float, default=2
807819
The minimum number of samples required to split an internal node:
@@ -1102,6 +1114,14 @@ def _validate_y(self, y, sample_weight):
11021114
self.n_classes_ = self._n_classes
11031115
return y
11041116

1117+
def _warn_mae_for_criterion(self):
1118+
# TODO: This should raise an error from 0.26
1119+
warnings.warn("criterion='mae' was deprecated in version 0.24 and "
1120+
"will be removed in version 0.26. Use "
1121+
"criterion='friedman_mse' or 'mse' instead, as trees "
1122+
"should use a least-square criterion in Gradient "
1123+
"Boosting.", FutureWarning)
1124+
11051125
def decision_function(self, X):
11061126
"""Compute the decision function of ``X``.
11071127
@@ -1320,6 +1340,10 @@ class GradientBoostingRegressor(RegressorMixin, BaseGradientBoosting):
13201340
some cases.
13211341
13221342
.. versionadded:: 0.18
1343+
.. deprecated:: 0.24
1344+
`criterion='mae'` is deprecated and will be removed in version
1345+
0.26. The correct way of minimizing the absolute error is to use
1346+
`loss='lad'` instead.
13231347
13241348
min_samples_split : int or float, default=2
13251349
The minimum number of samples required to split an internal node:
@@ -1601,6 +1625,13 @@ def _validate_y(self, y, sample_weight=None):
16011625
y = y.astype(DOUBLE)
16021626
return y
16031627

1628+
def _warn_mae_for_criterion(self):
1629+
# TODO: This should raise an error from 0.26
1630+
warnings.warn("criterion='mae' was deprecated in version 0.24 and "
1631+
"will be removed in version 0.26. The correct way of "
1632+
"minimizing the absolute error is to use loss='lad' "
1633+
"instead.", FutureWarning)
1634+
16041635
def predict(self, X):
16051636
"""Predict regression target for X.
16061637

sklearn/ensemble/tests/test_gradient_boosting.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1333,3 +1333,17 @@ def test_attr_error_raised_if_not_fitted():
13331333
)
13341334
with pytest.raises(AttributeError, match=msg):
13351335
gbr.n_classes_
1336+
1337+
1338+
# TODO: Update in 0.26 to check for the error raised
1339+
@pytest.mark.parametrize('estimator', [
1340+
GradientBoostingClassifier(criterion='mae'),
1341+
GradientBoostingRegressor(criterion='mae')
1342+
])
1343+
def test_criterion_mae_deprecation(estimator):
1344+
# checks whether a deprecation warning is issues when criterion='mae'
1345+
# is used.
1346+
msg = ("criterion='mae' was deprecated in version 0.24 and "
1347+
"will be removed in version 0.26.")
1348+
with pytest.warns(FutureWarning, match=msg):
1349+
estimator.fit(X, y)

0 commit comments

Comments
 (0)
0