8000 FIX Fixes tree and forest classification for non-numeric multi-target… · xhluca/scikit-learn@f95ffe6 · GitHub
[go: up one dir, main page]

Skip to content

Commit f95ffe6

Browse files
mitarXing
authored andcommitted
FIX Fixes tree and forest classification for non-numeric multi-target (scikit-learn#11458)
* Fixes tree and forest classification for non-numeric multi-target. Fixes scikit-learn#11451. * Renaming test functions, adding dtype to predictions array in tree.py. * Fixing flake8 issue. * Adding ignore warning to test_forest.py. * Switching to iris data for tests.
1 parent 31ad4db commit f95ffe6

File tree

3 files changed

+45
-2
lines changed

3 files changed

+45
-2
lines changed

sklearn/ensemble/tests/test_forest.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1337,3 +1337,25 @@ def test_backend_respected():
13371337
clf.predict_proba(X)
13381338

13391339
assert ba.count == 0
1340+
1341+
1342+
@pytest.mark.filterwarnings('ignore:The default value of n_estimators')
1343+
@pytest.mark.parametrize('name', FOREST_CLASSIFIERS)
1344+
@pytest.mark.parametrize('oob_score', (True, False))
1345+
def test_multi_target(name, oob_score):
1346+
ForestClassifier = FOREST_CLASSIFIERS[name]
1347+
1348+
clf = ForestClassifier(bootstrap=True, oob_score=oob_score)
1349+
1350+
X = iris.data
1351+
1352+
# Make multi column mixed type target.
1353+
y = np.vstack([
1354+
iris.target.astype(float),
1355+
iris.target.astype(int),
1356+
iris.target.astype(str),
1357+
]).T
1358+
1359+
# Try to fit and predict.
1360+
clf.fit(X, y)
1361+
clf.predict(X)

sklearn/tree/tests/test_tree.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1828,3 +1828,23 @@ def test_empty_leaf_infinite_threshold():
18281828
infinite_threshold = np.where(~np.isfinite(tree.tree_.threshold))[0]
18291829
assert len(infinite_threshold) == 0
18301830
assert len(empty_leaf) == 0
1831+
1832+
1833+
@pytest.mark.parametrize('name', CLF_TREES)
1834+
def test_multi_target(name):
1835+
Tree = CLF_TREES[name]
1836+
1837+
clf = Tree()
1838+
1839+
X = iris.data
1840+
1841+
# Make multi column mixed type target.
1842+
y = np.vstack([
1843+
iris.target.astype(float),
1844+
iris.target.astype(int),
1845+
iris.target.astype(str),
1846+
]).T
1847+
1848+
# Try to fit and predict.
1849+
clf.fit(X, y)
1850+
clf.predict(X)

sklearn/tree/tree.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -436,8 +436,9 @@ def predict(self, X, check_input=True):
436436
return self.classes_.take(np.argmax(proba, axis=1), axis=0)
437437

438438
else:
439-
predictions = np.zeros((n_samples, self.n_outputs_))
440-
439+
class_type = self.classes_[0].dtype
440+
predictions = np.zeros((n_samples, self.n_outputs_),
441+
dtype=class_type)
441442
for k in range(self.n_outputs_):
442443
predictions[:, k] = self.classes_[k].take(
443444
np.argmax(proba[:, k], axis=1),

0 commit comments

Comments
 (0)
0