BUG Fixes sample weights when there are missing values in DecisionTrees (#26376) · scikit-learn/scikit-learn@43cf7d4 · GitHub
[go: up one dir, main page]

Skip to content

Commit 43cf7d4

Browse files
authored
BUG Fixes sample weights when there are missing values in DecisionTrees (#26376)
1 parent 4a5f954 commit 43cf7d4

File tree

3 files changed

+43
-6
lines changed

3 files changed

+43
-6
lines changed

doc/whats_new/v1.3.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -516,7 +516,7 @@ Changelog
516516
:class:`tree.DecisionTreeClassifier` support missing values when
517517
`splitter='best'` and criterion is `gini`, `entropy`, or `log_loss`,
518518
for classification or `squared_error`, `friedman_mse`, or `poisson`
519-
for regression. :pr:`23595` by `Thomas Fan`_.
519+
for regression. :pr:`23595`, :pr:`26376` by `Thomas Fan`_.
520520

521521
- |Enhancement| Adds a `class_names` parameter to
522522
:func:`tree.export_text`. This allows specifying the parameter `class_names`

sklearn/tree/_criterion.pyx

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -838,7 +838,9 @@ cdef class RegressionCriterion(Criterion):
838838
self.sample_indices[-n_missing:]
839839
"""
840840
cdef SIZE_t i, p, k
841-
cdef DOUBLE_t w = 0.0
841+
cdef DOUBLE_t y_ik
842+
cdef DOUBLE_t w_y_ik
843+
cdef DOUBLE_t w = 1.0
842844

843845
self.n_missing = n_missing
844846
if n_missing == 0:
@@ -855,7 +857,9 @@ cdef class RegressionCriterion(Criterion):
855857
w = self.sample_weight[i]
856858

857859
for k in range(self.n_outputs):
858-
self.sum_missing[k] += w
860+
y_ik = self.y[i, k]
861+
w_y_ik = w * y_ik
862+
self.sum_missing[k] += w_y_ik
859863

860864
self.weighted_n_missing += w
861865

sklearn/tree/tests/test_tree.py

Lines changed: 36 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2549,7 +2549,8 @@ def test_missing_values_poisson():
25492549
(datasets.make_classification, DecisionTreeClassifier),
25502550
],
25512551
)
2552-
def test_missing_values_is_resilience(make_data, Tree):
2552+
@pytest.mark.parametrize("sample_weight_train", [None, "ones"])
2553+
def test_missing_values_is_resilience(make_data, Tree, sample_weight_train):
25532554
"""Check that trees can deal with missing values and have decent performance."""
25542555

25552556
rng = np.random.RandomState(0)
@@ -2563,15 +2564,18 @@ def test_missing_values_is_resilience(make_data, Tree):
25632564
X_missing, y, random_state=0
25642565
)
25652566

2567+
if sample_weight_train == "ones":
2568+
sample_weight_train = np.ones(X_missing_train.shape[0])
2569+
25662570
# Train tree with missing values
25672571
tree_with_missing = Tree(random_state=rng)
2568-
tree_with_missing.fit(X_missing_train, y_train)
2572+
tree_with_missing.fit(X_missing_train, y_train, sample_weight=sample_weight_train)
25692573
score_with_missing = tree_with_missing.score(X_missing_test, y_test)
25702574

25712575
# Train tree without missing values
25722576
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
25732577
tree = Tree(random_state=rng)
2574-
tree.fit(X_train, y_train)
2578+
tree.fit(X_train, y_train, sample_weight=sample_weight_train)
25752579
score_without_missing = tree.score(X_test, y_test)
25762580

25772581
# Score is still 90 percent of the tree's score that had no missing values
@@ -2601,3 +2605,32 @@ def test_missing_value_is_predictive():
26012605

26022606
assert tree.score(X_train, y_train) >= 0.85
26032607
assert tree.score(X_test, y_test) >= 0.85
2608+
2609+
2610+
@pytest.mark.parametrize(
2611+
"make_data, Tree",
2612+
[
2613+
(datasets.make_regression, DecisionTreeRegressor),
2614+
(datasets.make_classification, DecisionTreeClassifier),
2615+
],
2616+
)
2617+
def test_sample_weight_non_uniform(make_data, Tree):
2618+
"""Check sample weight is correctly handled with missing values."""
2619+
rng = np.random.RandomState(0)
2620+
n_samples, n_features = 1000, 10
2621+
X, y = make_data(n_samples=n_samples, n_features=n_features, random_state=rng)
2622+
2623+
# Create dataset with missing values
2624+
X[rng.choice([False, True], size=X.shape, p=[0.9, 0.1])] = np.nan
2625+
2626+
# Zero sample weight is the same as removing the sample
2627+
sample_weight = np.ones(X.shape[0])
2628+
sample_weight[::2] = 0.0
2629+
2630+
tree_with_sw = Tree(random_state=0)
2631+
tree_with_sw.fit(X, y, sample_weight=sample_weight)
2632+
2633+
tree_samples_removed = Tree(random_state=0)
2634+
tree_samples_removed.fit(X[1::2, :], y[1::2])
2635+
2636+
assert_allclose(tree_samples_removed.predict(X), tree_with_sw.predict(X))

0 commit comments

Comments
 (0)
0