8000 FIX Tests and add more test cases · scikit-learn/scikit-learn@49d6e18 · GitHub
[go: up one dir, main page]

Skip to content 8000

Commit 49d6e18

Browse files
committed
FIX Tests and add more test cases
1 parent 2ff0464 commit 49d6e18

File tree

1 file changed

+9
-10
lines changed

1 file changed

+9
-10
lines changed

sklearn/tree/tests/test_tree.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -844,28 +844,27 @@ def test_min_impurity_split():
844844
def test_min_impurity_decrease():
845845
# test if min_impurity_decrease ensure that a split is made only if
846846
# if the impurity decrease is atleast that value
847-
X = np.asfortranarray(iris.data.astype(tree._tree.DTYPE))
848-
y = iris.target
847+
X, y = datasets.make_classification(n_samples=10000, random_state=42)
849848

850849
# test both DepthFirstTreeBuilder and BestFirstTreeBuilder
851850
# by setting max_leaf_nodes
852851
for max_leaf_nodes, name in product((None, 1000), ALL_TREES.keys()):
853852
TreeEstimator = ALL_TREES[name]
854853

855854
# Check default value of min_impurity_decrease, 1e-7
856-
est1 = TreeEstimator(max_leaf_nodes=max_leaf_nodes,
857-
random_state=0)
855+
est1 = TreeEstimator(max_leaf_nodes=max_leaf_nodes, random_state=0)
858856
# Check with explicit value of 0.05
859857
est2 = TreeEstimator(max_leaf_nodes=max_leaf_nodes,
860-
min_impurity_decrease=0.05,
861-
random_state=0)
862-
# Check with a much lower value of 0.00001
858+
min_impurity_decrease=0.05, random_state=0)
859+
# Check with a much lower value of 0.0001
863860
est3 = TreeEstimator(max_leaf_nodes=max_leaf_nodes,
864-
min_impurity_decrease=0.00001,
865-
random_state=0)
861+
min_impurity_decrease=0.0001, random_state=0)
862+
# Check with a much lower value of 0.1
863+
est4 = TreeEstimator(max_leaf_nodes=max_leaf_nodes,
864+
min_impurity_decrease=0.1, random_state=0)
866865

867866
for est, expected_decrease in ((est1, 1e-7), (est2, 0.05),
868-
(est3, 0.00001)):
867+
(est3, 0.0001), (est4, 0.1)):
869868
assert_less_equal(est.min_impurity_decrease, expected_decrease,
870869
"Failed, min_impurity_decrease = {0} > {1}"
871870
.format(est.min_impurity_decrease,

0 commit comments

Comments
 (0)
0