FIX adaboost return nan in feature importance (#20415) · rusdes/scikit-learn@b903486
Commit b903486

MaxwellLZH, ogrisel, glemaitre, and cmarmo authored
FIX adaboost return nan in feature importance (scikit-learn#20415)
Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
Co-authored-by: Guillaume Lemaitre <g.lemaitre58@gmail.com>
Co-authored-by: Chiara Marmo <cmarmo@users.noreply.github.com>
1 parent 122f5fe commit b903486

File tree: 3 files changed, +35 -2 lines changed

doc/whats_new/v1.2.rst

Lines changed: 7 additions & 0 deletions
@@ -183,6 +183,13 @@ Changelog
 - |Efficiency| Improve runtime performance of :class:`ensemble.IsolationForest`
   by avoiding data copies. :pr:`23252` by :user:`Zhehao Liu <MaxwellLZH>`.
 
+- |Fix| Fixed the issue where :class:`ensemble.AdaBoostClassifier` outputs
+  NaN in feature importance when fitted with very small sample weight.
+  :pr:`20415` by :user:`Zhehao Liu <MaxwellLZH>`.
+
+:mod:`sklearn.feature_selection`
+................................
+
 :mod:`sklearn.decomposition`
 ............................
 
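Editorial note: the failure described in this changelog entry can be reproduced with a short sketch derived from the non-regression test added further down in this commit; this snippet is illustrative and not part of the commit itself. Before the fix, fitting with extremely small (but non-zero) sample weights produced NaN feature importances:

import numpy as np
from sklearn.ensemble import AdaBoostClassifier
from sklearn.tree import DecisionTreeClassifier

# Setup mirrors the non-regression test below (issue scikit-learn#20320).
rng = np.random.RandomState(42)
X = rng.normal(size=(1000, 10))
y = rng.choice([0, 1], size=1000)
sample_weight = np.full(y.shape, 1e-263)  # tiny but non-zero weights

clf = AdaBoostClassifier(
    base_estimator=DecisionTreeClassifier(max_depth=10, random_state=12),
    n_estimators=20,
    random_state=12,
)
clf.fit(X, y, sample_weight=sample_weight)
# True before this fix, False afterwards.
print(np.isnan(clf.feature_importances_).any())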

sklearn/ensemble/_weight_boosting.py

Lines changed: 8 additions & 1 deletion
@@ -146,8 +146,15 @@ def fit(self, X, y, sample_weight=None):
         # Initialization of the random number instance that will be used to
         # generate a seed at each iteration
         random_state = check_random_state(self.random_state)
+        epsilon = np.finfo(sample_weight.dtype).eps
 
+        zero_weight_mask = sample_weight == 0.0
         for iboost in range(self.n_estimators):
+            # avoid extremely small sample weight, for details see issue #20320
+            sample_weight = np.clip(sample_weight, a_min=epsilon, a_max=None)
+            # do not clip sample weights that were exactly zero originally
+            sample_weight[zero_weight_mask] = 0.0
+
             # Boosting step
             sample_weight, estimator_weight, estimator_error = self._boost(
                 iboost, X, y, sample_weight, random_state

@@ -635,7 +642,7 @@ def _boost_discrete(self, iboost, X, y, sample_weight, random_state):
             np.log((1.0 - estimator_error) / estimator_error) + np.log(n_classes - 1.0)
         )
 
-        # Only boost the weights if I will fit again
+        # Only boost the weights if it will fit again
         if not iboost == self.n_estimators - 1:
             # Only boost positive weights
             sample_weight = np.exp(
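Editorial note: the new clipping step can be read in isolation as the sketch below. The helper name _clip_tiny_weights is hypothetical (the commit inlines this logic in the boosting loop). The idea is that weights that have underflowed toward zero are raised to the dtype's machine epsilon, while weights that start at exactly zero stay zero, so deliberately excluded samples remain excluded:

import numpy as np

def _clip_tiny_weights(sample_weight):
    # Hypothetical standalone version of the clipping inlined in fit().
    epsilon = np.finfo(sample_weight.dtype).eps
    zero_weight_mask = sample_weight == 0.0  # remember deliberate zeros
    # Raise underflowing weights to machine epsilon (see issue #20320)...
    sample_weight = np.clip(sample_weight, a_min=epsilon, a_max=None)
    # ...but keep weights that were exactly zero at zero.
    sample_weight[zero_weight_mask] = 0.0
    return sample_weight

print(_clip_tiny_weights(np.array([0.0, 1e-300, 0.5])))
# [0.00000000e+00 2.22044605e-16 5.00000000e-01]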

sklearn/ensemble/tests/test_weight_boosting.py

Lines changed: 20 additions & 1 deletion
@@ -309,7 +309,7 @@ def test_base_estimator():
 
 def test_sample_weights_infinite():
     msg = "Sample weights have reached infinite values"
-    clf = AdaBoostClassifier(n_estimators=30, learning_rate=5.0, algorithm="SAMME")
+    clf = AdaBoostClassifier(n_estimators=30, learning_rate=23.0, algorithm="SAMME")
     with pytest.warns(UserWarning, match=msg):
         clf.fit(iris.data, iris.target)
 

@@ -575,3 +575,22 @@ def test_adaboost_negative_weight_error(model, X, y):
     err_msg = "Negative values in data passed to `sample_weight`"
     with pytest.raises(ValueError, match=err_msg):
         model.fit(X, y, sample_weight=sample_weight)
+
+
+def test_adaboost_numerically_stable_feature_importance_with_small_weights():
+    """Check that we don't create NaN feature importance with numerically
+    instable inputs.
+
+    Non-regression test for:
+    https://github.com/scikit-learn/scikit-learn/issues/20320
+    """
+    rng = np.random.RandomState(42)
+    X = rng.normal(size=(1000, 10))
+    y = rng.choice([0, 1], size=1000)
+    sample_weight = np.ones_like(y) * 1e-263
+    tree = DecisionTreeClassifier(max_depth=10, random_state=12)
+    ada_model = AdaBoostClassifier(
+        base_estimator=tree, n_estimators=20, random_state=12
+    )
+    ada_model.fit(X, y, sample_weight=sample_weight)
+    assert np.isnan(ada_model.feature_importances_).sum() == 0
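Editorial note on the first hunk: with tiny weights now clipped each iteration, learning_rate=5.0 presumably no longer drove the sample weights to infinity within 30 boosting rounds, so the test bumps the rate to 23.0 to keep exercising the overflow warning. A sketch of the path this test exercises, assuming the iris fixture that the test module loads:

import warnings
from sklearn.datasets import load_iris
from sklearn.ensemble import AdaBoostClassifier

iris = load_iris()
clf = AdaBoostClassifier(n_estimators=30, learning_rate=23.0, algorithm="SAMME")
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    clf.fit(iris.data, iris.target)  # weights overflow at this high rate
assert any(
    "Sample weights have reached infinite values" in str(w.message)
    for w in caught
)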

0 commit comments