8000 TST Gini is equivalent to mse in binary classification · mohitsingh1007/scikit-learn@260bdb9 · GitHub
[go: up one dir, main page]

8000
Skip to content

Commit 260bdb9

Browse files
committed
TST Gini is equivalent to mse in binary classification
1 parent 8f76f48 commit 260bdb9

File tree

1 file changed

+21
-0
lines changed

1 file changed

+21
-0
lines changed

sklearn/tree/tests/test_tree.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -282,6 +282,27 @@ def test_importances_raises():
282282
clf.feature_importances_
283283

284284

285+
def test_importances_gini_equal_mse():
286+
"""Check that gini is equivalent to mse for binary output variable"""
287+
288+
X, y = datasets.make_classification(n_samples=2000,
289+
n_features=10,
290+
n_informative=3,
291+
n_redundant=0,
292+
n_repeated=0,
293+
shuffle=False,
294+
random_state=0)
295+
296+
clf = DecisionTreeClassifier(criterion="gini", random_state=0).fit(X, y)
297+
reg = DecisionTreeRegressor(criterion="mse", random_state=0).fit(X, y)
298+
299+
assert_almost_equal(clf.feature_importances_, reg.feature_importances_)
300+
assert_array_equal(clf.tree_.feature, reg.tree_.feature)
301+
assert_array_equal(clf.tree_.children_left, reg.tree_.children_left)
302+
assert_array_equal(clf.tree_.children_right, reg.tree_.children_right)
303+
assert_array_equal(clf.tree_.n_node_samples, reg.tree_.n_node_samples)
304+
305+
285306
def test_max_features():
286307
"""Check max_features."""
287308
for name, TreeRegressor in REG_TREES.items():

0 commit comments

Comments
 (0)
0