10000 Merge pull request #6221 from dsquareindia/LabelBinarizer_fix · scikit-learn/scikit-learn@945cb7e · GitHub
[go: up one dir, main page]

Skip to content

Commit 945cb7e

Browse files
committed
Merge pull request #6221 from dsquareindia/LabelBinarizer_fix
[MRG+1] LabelBinarizer single label case now works for sparse and dense case
2 parents e228581 + e9492b7 commit 945cb7e

File tree

2 files changed

+19
-6
lines changed

2 files changed

+19
-6
lines changed

sklearn/preprocessing/label.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -472,10 +472,13 @@ def label_binarize(y, classes, neg_label=0, pos_label=1, sparse_output=False):
472472
classes = np.asarray(classes)
473473

474474
if y_type == "binary":
475-
if len(classes) == 1:
476-
Y = np.zeros((len(y), 1), dtype=np.int)
477-
Y += neg_label
478-
return Y
475+
if n_classes == 1:
476+
if sparse_output:
477+
return sp.csr_matrix((n_samples, 1), dtype=int)
478+
else:
479+
Y = np.zeros((len(y), 1), dtype=np.int)
480+
Y += neg_label
481+
return Y
479482
elif len(classes) >= 3:
480483
y_type = "multiclass"
481484

sklearn/preprocessing/tests/test_label.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
from sklearn.utils.testing import assert_array_equal
1313
from sklearn.utils.testing import assert_equal
14+
from sklearn.utils.testing import assert_true
1415
from sklearn.utils.testing import assert_raises
1516
from sklearn.utils.testing import assert_raise_message
1617
from sklearn.utils.testing import ignore_warnings
@@ -35,16 +36,25 @@ def toarray(a):
3536

3637

3738
def test_label_binarizer():
38-
lb = LabelBinarizer()
39-
4039
# one-class case defaults to negative label
40+
# For dense case:
4141
inp = ["pos", "pos", "pos", "pos"]
42+
lb = LabelBinarizer(sparse_output=False)
4243
expected = np.array([[0, 0, 0, 0]]).T
4344
got = lb.fit_transform(inp)
4445
assert_array_equal(lb.classes_, ["pos"])
4546
assert_array_equal(expected, got)
4647
assert_array_equal(lb.inverse_transform(got), inp)
4748

49+
# For sparse case:
50+
lb = LabelBinarizer(sparse_output=True)
51+
got = lb.fit_transform(inp)
52+
assert_true(issparse(got))
53+
assert_array_equal(lb.classes_, ["pos"])
54+
assert_array_equal(expected, got.toarray())
55+
assert_array_equal(lb.inverse_transform(got.toarray()), inp)
56+
57+
lb = LabelBinarizer(sparse_output=False)
4858
# two-class case
4959
inp = ["neg", "pos", "pos", "neg"]
5060
expected = np.array([[0, 1, 1, 0]]).T

0 commit comments

Comments
 (0)
0