8000 Renaming test functions, adding dtype to predictions array in tree.py. · scikit-learn/scikit-learn@b26e943 · GitHub
[go: up one dir, main page]

Skip to content

Commit b26e943

Browse files
committed
Renaming test functions, adding dtype to predictions array in tree.py.
1 parent 2ee5ddc commit b26e943

File tree

3 files changed

+7
-15
lines changed
  • sklearn
    • ensemble/tests
  • tree
  • 3 files changed

    +7
    -15
    lines changed

    sklearn/ensemble/tests/test_forest.py

    Lines changed: 3 additions & 7 deletions
    Original file line numberDiff line numberDiff line change
    @@ -1339,7 +1339,9 @@ def test_backend_respected():
    13391339
    assert ba.count == 0
    13401340

    13411341

    1342-
    def check_multi_target(name, oob_score):
    1342+
    @pytest.mark.parametrize('name', FOREST_CLASSIFIERS)
    1343+
    @pytest.mark.parametrize('oob_score', (True, False))
    1344+
    def test_multi_target(name, oob_score):
    13431345
    ForestClassifier = FOREST_CLASSIFIERS[name]
    13441346

    13451347
    clf = ForestClassifier(bootstrap=True, oob_score=oob_score)
    @@ -1354,9 +1356,3 @@ def check_multi_target(name, oob_score):
    13541356
    # Try to fix and predict.
    13551357
    clf.fit(X, ys)
    13561358
    clf.predict(X)
    1357-
    1358-
    1359-
    @pytest.mark.parametrize('name', FOREST_CLASSIFIERS)
    1360-
    @pytest.mark.parametrize('oob_score', (True, False))
    1361-
    def test_multi_target(name, oob_score):
    1362-
    check_multi_target(name, oob_score)

    sklearn/tree/tests/test_tree.py

    Lines changed: 2 additions & 6 deletions
    Original file line numberDiff line numberDiff line change
    @@ -1830,7 +1830,8 @@ def test_empty_leaf_infinite_threshold():
    18301830
    assert len(empty_leaf) == 0
    18311831

    18321832

    1833-
    def check_multi_target(name):
    1833+
    @pytest.mark.parametrize('name', CLF_TREES)
    1834+
    def test_multi_target(name):
    18341835
    Tree = CLF_TREES[name]
    18351836

    18361837
    clf = Tree()
    @@ -1845,8 +1846,3 @@ def check_multi_target(name):
    18451846
    # Try to fix and predict.
    18461847
    clf.fit(X, ys)
    18471848
    clf.predict(X)
    1848-
    1849-
    1850-
    @pytest.mark.parametrize('name', CLF_TREES)
    1851-
    def test_multi_target(name):
    1852-
    check_multi_target(name)

    sklearn/tree/tree.py

    Lines changed: 2 additions & 2 deletions
    Original file line numberDiff line numberDiff line change
    @@ -436,8 +436,8 @@ def predict(self, X, check_input=True):
    436436
    return self.classes_.take(np.argmax(proba, axis=1), axis=0)
    437437

    438438
    else:
    439-
    predictions = np.zeros((n_samples, self.n_outputs_))
    440-
    439+
    class_type = self.classes_[0].dtype
    440+
    predictions = np.zeros((n_samples, self.n_outputs_), dtype=class_type)
    441441
    for k in range(self.n_outputs_):
    442442
    predictions[:, k] = self.classes_[k].take(
    443443
    np.argmax(proba[:, k], axis=1),

    0 commit comments

    Comments
     (0)
    0