FIX make sure to reinitialize criterion even when n_missing=0 #28295

Merged: 8 commits, Jan 30, 2024

glemaitre
Member
@glemaitre glemaitre commented Jan 27, 2024

closes #28254

The criterion was not reinitialized at each split and could wrongly retain information from a previous split that contained missing values. This led to wrong statistics being reported by the criterion.
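To make the failure mode concrete, here is a minimal pure-Python sketch of stale criterion state. ToyCriterion, left_mean, and evaluate are hypothetical names invented for illustration and do not reproduce scikit-learn's Cython implementation. The only assumptions taken from this thread are that the criterion keeps missing-value statistics between splits and that its init_missing step returns early when n_missing == 0.

```python
class ToyCriterion:
    """Toy stand-in for a tree criterion (hypothetical, not scikit-learn's class)."""

    def __init__(self):
        self.n_missing = 0
        self.sum_missing = 0.0

    def init_missing(self, y_missing):
        # Always record the count, then exit early when nothing is missing.
        self.n_missing = len(y_missing)
        if self.n_missing == 0:
            return
        self.sum_missing = sum(y_missing)

    def left_mean(self, y_left):
        # Mean target of the left child when missing samples are sent left.
        total = sum(y_left)
        count = len(y_left)
        if self.n_missing > 0:
            total += self.sum_missing
            count += self.n_missing
        return total / count


def evaluate(features, always_reinit):
    """Evaluate one candidate split per feature, reusing a single criterion."""
    criterion = ToyCriterion()
    means = []
    for y_left, y_missing in features:
        # Buggy variant: only reinitialize when the feature has missing values.
        if always_reinit or y_missing:
            criterion.init_missing(y_missing)
        means.append(criterion.left_mean(y_left))
    return means


features = [
    ([1.0, 3.0], [10.0, 10.0]),  # feature 0: two samples with missing values
    ([1.0, 3.0], []),            # feature 1: no missing values
]
buggy = evaluate(features, always_reinit=False)
fixed = evaluate(features, always_reinit=True)
print(buggy, fixed)
```

In the buggy variant, the second feature inherits the missing-value count of the first one, so its left-child mean is contaminated. Calling init_missing unconditionally resets the count, and the early exit keeps that call cheap.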

@glemaitre glemaitre changed the title from "Is/28254" to "FIX make sure to reinitialize criterion even when n_missing=0" Jan 27, 2024
@glemaitre glemaitre marked this pull request as draft January 27, 2024 23:43
@glemaitre
Member Author
glemaitre commented Jan 27, 2024

Not enough. Still have a bug with the 4th estimator in the following forest:

import numpy as np
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split

X, y = load_iris(as_frame=True, return_X_y=True)

# Mask entries at random, with a probability that grows with petal length.
rng = np.random.RandomState(42)
X_missing = X.copy()
mask = rng.binomial(n=np.array([1, 1, 1, 1]).reshape(1, -1),
                    p=(X['petal length (cm)'] / 8).values.reshape(-1, 1)).astype(bool)
X_missing[mask] = np.nan

# Fit a small forest on the data with missing values.
X_train, X_test, y_train, y_test = train_test_split(X_missing, y, random_state=13)
clf = RandomForestClassifier(n_estimators=4, random_state=1).fit(X_train, y_train)

The criterion is still nan for node #5. Need to debug it.
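One way to locate such a node is to scan the per-node impurities for non-finite values. The sketch below uses a hypothetical helper, non_finite_nodes, and a hand-written list of stand-in impurities (not output from the forest above). For a fitted scikit-learn tree, the same helper could presumably be applied to the tree_.impurity array, for example clf.estimators_[3].tree_.impurity for the 4th estimator.

```python
import math


def non_finite_nodes(impurities):
    """Return the indices of nodes whose impurity is NaN or infinite."""
    return [i for i, value in enumerate(impurities) if not math.isfinite(value)]


# Hand-written stand-in values, one impurity per node id (illustrative only).
impurities = [0.66, 0.5, 0.0, 0.44, 0.0, float("nan"), 0.0]
bad = non_finite_nodes(impurities)
print(bad)
```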

@glemaitre glemaitre marked this pull request as ready for review January 28, 2024 11:21
@glemaitre
Member Author

I'm more confident now that we have solved the original issue. The integration test on Ames housing shows that we perform as well as in the imputation case, and I have 2 hand-crafted toy examples that allowed me to debug the code and find the issue: a wrong computation of the impurity.

I still have the issue with np.inf when building the tree, but I would tackle it in a separate issue to understand whether it is actually a bug.
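For readers unfamiliar with how an impurity is computed when some feature values are missing, the following is a small worked example in the spirit of the hand-crafted toy cases mentioned above. It is not one of the PR's actual tests. The data, the threshold, and the helper names mse and weighted_child_impurity are made up for illustration, and the sketch assumes the splitter compares sending all missing samples to the left child against sending them to the right child.

```python
def mse(values):
    """Mean squared error of the values around their mean (regression impurity)."""
    mean = sum(values) / len(values)
    return sum((v - mean) ** 2 for v in values) / len(values)


def weighted_child_impurity(left, right):
    """Impurity of a split: child impurities weighted by child size."""
    n = len(left) + len(right)
    return (len(left) * mse(left) + len(right) * mse(right)) / n


# Each sample is (feature value or None when missing, target).
samples = [(1.0, 0.0), (2.0, 0.0), (3.0, 4.0), (None, 4.0)]
threshold = 2.5

left = [y for x, y in samples if x is not None and x <= threshold]
right = [y for x, y in samples if x is not None and x > threshold]
missing = [y for x, y in samples if x is None]

# Evaluate both ways of routing the missing samples.
missing_left = weighted_child_impurity(left + missing, right)
missing_right = weighted_child_impurity(left, right + missing)
print(missing_left, missing_right)
```

Here sending the missing sample to the right child gives pure children, so that routing has the lower impurity. If the missing-sample statistics were stale, both numbers would be computed from the wrong targets.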

@glemaitre
Member Author

@thomasjpfan If you have a bit of time, it would be great if you could take a look. In the end it is only a 2-line diff :).

@@ -1854,3 +1860,70 @@ def test_non_supported_criterion_raises_error_with_missing_values():
    msg = "RandomForestRegressor does not accept missing values"
    with pytest.raises(ValueError, match=msg):
        forest.fit(X, y)


@skip_if_no_network
@glemaitre
Member Author

This requires fetching the Ames housing dataset. I don't know whether we could cache the dataset to make sure that we run the test instead of skipping it.

@adrinjalali
Member

I would very much rather have a toy dataset made for this test rather than downloading one.

@thomasjpfan
Member

Although this test is useful, I prefer to hold off on adding it because running fetch_openml can lead to race conditions when running with pytest-xdist. There is no race condition now because fetch_openml(("house_prices"), but I would rather not establish the pattern.

Note, much of dataset_fetchers = { was written to "download all the datasets before pytest runs".

For this specific PR, I think test_regression_tree_missing_values_toy is sufficient as a non-regression test.

@adam2392
Member

Chiming in here w/ some additional context:

I saw some issues when implementing missing-value support for ExtraTrees in #27966. After applying this fix, empirical performance went up for ExtraTreeRegressor (one of the unit tests was having a hard time passing, and I could not figure out why). So this fix is definitely needed :)

There shouldn't be any noticeable performance degradation either since init_missing exits early when n_missing == 0.

Member
@adrinjalali adrinjalali left a comment

Otherwise LGTM.

Member
@thomasjpfan thomasjpfan left a comment

Thank you for the fix! This PR looks right to me.

@glemaitre
Member Author

> I would very much rather have a toy dataset made for this test rather than downloading one.

We actually already have 2 toy datasets in the two other toy tests. This one is more of an integration test. Actually, it would fit better in an example that shows the native support for missing values.

@glemaitre
Member Author

I removed the integration test. I want to keep this PR focused on the bug fix. Once it is merged, I propose to go through the different examples to remove the SimpleImputer, update the narrative, and see if I can add this new example back.

@thomasjpfan thomasjpfan added the To backport label Jan 30, 2024
@thomasjpfan thomasjpfan added this to the 1.4.1 milestone Jan 30, 2024
@thomasjpfan thomasjpfan enabled auto-merge (squash) January 30, 2024 15:47
@thomasjpfan thomasjpfan merged commit 0fb5295 into scikit-learn:main Jan 30, 2024
28 checks passed
Labels
cython, module:tree, To backport
Projects
None yet
Development

Successfully merging this pull request may close these issues.

DecisionTree does not handle properly missing values in criterion partitioning
4 participants