8000 Merge pull request #5542 from arjoly/tree-no-sparse-y · scikit-learn/scikit-learn@91753dc · GitHub
[go: up one dir, main page]

Skip to content

Commit 91753dc

Browse files
committed
Merge pull request #5542 from arjoly/tree-no-sparse-y
[MRG+1] Raise appropriate error if y is sparse
2 parents 5e0db3c + daa243d commit 91753dc

File tree

2 files changed

+13
-1
lines changed

2 files changed

+13
-1
lines changed

sklearn/tree/tests/test_tree.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1376,3 +1376,15 @@ def check_decision_path(name):
13761376
def test_decision_path():
13771377
for name in ALL_TREES:
13781378
yield (check_decision_path, name)
1379+
1380+
1381+
def check_no_sparse_y_support(name):
1382+
X, y = X_multilabel, csr_matrix(y_multilabel)
1383+
TreeEstimator = ALL_TREES[name]
1384+
assert_raises(ValueError, TreeEstimator(random_state=0).fit, X, y)
1385+
1386+
1387+
def test_no_sparse_y_support():
1388+
# Currently we don't support sparse y
1389+
for name in ALL_TREES:
1390+
yield (check_decision_path, name)

sklearn/tree/tree.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,7 @@ def fit(self, X, y, sample_weight=None, check_input=True,
151151
random_state = check_random_state(self.random_state)
152152
if check_input:
153153
X = check_array(X, dtype=DTYPE, accept_sparse="csc")
154-
y = check_array(y, accept_sparse='csc', ensure_2d=False, dtype=None)
154+
y = check_array(y, ensure_2d=False, dtype=None)
155155
if issparse(X):
156156
X.sort_indices()
157157

0 commit comments

Comments
 (0)
0