@@ -2549,7 +2549,8 @@ def test_missing_values_poisson():
25492549 (datasets .make_classification , DecisionTreeClassifier ),
25502550 ],
25512551)
2552- def test_missing_values_is_resilience (make_data , Tree ):
2552+ @pytest .mark .parametrize ("sample_weight_train" , [None , "ones" ])
2553+ def test_missing_values_is_resilience (make_data , Tree , sample_weight_train ):
25532554 """Check that trees can deal with missing values and have decent performance."""
25542555
25552556 rng = np .random .RandomState (0 )
@@ -2563,15 +2564,18 @@ def test_missing_values_is_resilience(make_data, Tree):
25632564 X_missing , y , random_state = 0
25642565 )
25652566
2567+ if sample_weight_train == "ones" :
2568+ sample_weight_train = np .ones (X_missing_train .shape [0 ])
2569+
25662570 # Train tree with missing values
25672571 tree_with_missing = Tree (random_state = rng )
2568- tree_with_missing .fit (X_missing_train , y_train )
2572+ tree_with_missing .fit (X_missing_train , y_train , sample_weight = sample_weight_train )
25692573 score_with_missing = tree_with_missing .score (X_missing_test , y_test )
25702574
25712575 # Train tree without missing values
25722576 X_train , X_test , y_train , y_test = train_test_split (X , y , random_state = 0 )
25732577 tree = Tree (random_state = rng )
2574- tree .fit (X_train , y_train )
2578+ tree .fit (X_train , y_train , sample_weight = sample_weight_train )
25752579 score_without_missing = tree .score (X_test , y_test )
25762580
25772581 # Score is still 90 percent of the tree's score that had no missing values
@@ -2601,3 +2605,32 @@ def test_missing_value_is_predictive():
26012605
26022606 assert tree .score (X_train , y_train ) >= 0.85
26032607 assert tree .score (X_test , y_test ) >= 0.85
2608+
2609+
@pytest.mark.parametrize(
    "make_data, Tree",
    [
        (datasets.make_regression, DecisionTreeRegressor),
        (datasets.make_classification, DecisionTreeClassifier),
    ],
)
def test_sample_weight_non_uniform(make_data, Tree):
    """Check sample weight is correctly handled with missing values."""
    rng = np.random.RandomState(0)
    n_samples, n_features = 1000, 10
    X, y = make_data(n_samples=n_samples, n_features=n_features, random_state=rng)

    # Introduce missing values: each entry becomes NaN with probability 0.1.
    missing_mask = rng.choice([False, True], size=X.shape, p=[0.9, 0.1])
    X[missing_mask] = np.nan

    # A zero sample weight must be equivalent to removing that sample,
    # so zero out every even-indexed sample...
    sample_weight = np.ones(X.shape[0])
    sample_weight[::2] = 0.0

    tree_with_sw = Tree(random_state=0)
    tree_with_sw.fit(X, y, sample_weight=sample_weight)

    # ...and compare against a tree trained on the odd-indexed samples only.
    tree_samples_removed = Tree(random_state=0)
    tree_samples_removed.fit(X[1::2, :], y[1::2])

    assert_allclose(tree_samples_removed.predict(X), tree_with_sw.predict(X))
0 commit comments