8000
File tree Expand file tree Collapse file tree 1 file changed +21
-0
lines changed Expand file tree Collapse file tree 1 file changed +21
-0
lines changed Original file line number Diff line number Diff line change @@ -282,6 +282,27 @@ def test_importances_raises():
282
282
clf .feature_importances_
283
283
284
284
285
+ def test_importances_gini_equal_mse ():
286
+ """Check that gini is equivalent to mse for binary output variable"""
287
+
288
+ X , y = datasets .make_classification (n_samples = 2000 ,
289
+ n_features = 10 ,
290
+ n_informative = 3 ,
291
+ n_redundant = 0 ,
292
+ n_repeated = 0 ,
293
+ shuffle = False ,
294
+ random_state = 0 )
295
+
296
+ clf = DecisionTreeClassifier (criterion = "gini" , random_state = 0 ).fit (X , y )
297
+ reg = DecisionTreeRegressor (criterion = "mse" , random_state = 0 ).fit (X , y )
298
+
299
+ assert_almost_equal (clf .feature_importances_ , reg .feature_importances_ )
300
+ assert_array_equal (clf .tree_ .feature , reg .tree_ .feature )
301
+ assert_array_equal (clf .tree_ .children_left , reg .tree_ .children_left )
302
+ assert_array_equal (clf .tree_ .children_right , reg .tree_ .children_right )
303
+ assert_array_equal (clf .tree_ .n_node_samples , reg .tree_ .n_node_samples )
304
+
305
+
285
306
def test_max_features ():
286
307
"""Check max_features."""
287
308
for name , TreeRegressor in REG_TREES .items ():
You can’t perform that action at this time.
0 commit comments