Support multilabel case in LabelBinarizer. · scikit-learn/scikit-learn@ecb869c · GitHub
[go: up one dir, main page]

Skip to content

Commit ecb869c

Browse files
committed
Support multilabel case in LabelBinarizer.
1 parent 138e688 commit ecb869c

File tree

2 files changed

+49
-2
lines changed

2 files changed

+49
-2
lines changed

scikits/learn/preprocessing/__init__.py

Lines changed: 38 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,10 @@ def transform(self, X, copy=True):
128128
return X
129129

130130

131+
def _is_multilabel(y):
132+
return isinstance(y[0], tuple) or isinstance(y[0], list)
133+
134+
131135
class LabelBinarizer(BaseEstimator, TransformerMixin):
132136
"""Binarize labels in a one-vs-all fashion.
133137
@@ -160,6 +164,10 @@ class LabelBinarizer(BaseEstimator, TransformerMixin):
160164
>>> clf.transform([1, 6])
161165
array([[ 1., 0., 0., 0.],
162166
[ 0., 0., 0., 1.]])
167+
168+
>>> clf.fit_transform([(1,2),(3,)])
169+
array([[ 1., 1., 0.],
170+
[ 0., 0., 1.]])
163171
"""
164172

165173
def fit(self, y):
@@ -174,7 +182,11 @@ def fit(self, y):
174182
-------
175183
self : returns an instance of self.
176184
"""
177-
self.classes_ = np.unique(y)
185+
self.multilabel = _is_multilabel(y)
186+
if self.multilabel:
187+
self.classes_ = np.unique(reduce(lambda a,b:a+b, y))
188+
else:
189+
self.classes_ = np.unique(y)
178190
return self
179191

180192
def transform(self, y):
@@ -192,13 +204,30 @@ def transform(self, y):
192204
-------
193205
Y : numpy array of shape [n_samples, n_classes]
194206
"""
207+
195208
if len(self.classes_) == 2:
196209
Y = np.zeros((len(y), 1))
210+
else:
211+
Y = np.zeros((len(y), len(self.classes_)))
212+
213+
if self.multilabel:
214+
if not _is_multilabel(y):
215+
raise ValueError, "y should be a list of label lists/tuples"
216+
217+
# inverse map: label => column index
218+
imap = dict((v,k) for k,v in enumerate(self.classes_))
219+
220+
for i, label_tuple in enumerate(y):
221+
for label in label_tuple:
222+
Y[i, imap[label]] = 1
223+
224+
return Y
225+
226+
elif len(self.classes_) == 2:
197227
Y[y == self.classes_[1], 0] = 1
198228
return Y
199229

200230
elif len(self.classes_) >= 2:
201-
Y = np.zeros((len(y), len(self.classes_)))
202231
for i, k in enumerate(self.classes_):
203232
Y[y == k, i] = 1
204233
return Y
@@ -225,8 +254,15 @@ def inverse_transform(self, Y):
225254
this allows to use the output of a linear model's decision_function
226255
method directly as the input of inverse_transform.
227256
"""
257+
if self.multilabel:
258+
Y = np.array(Y > 0, dtype=int)
259+
return [tuple(self.classes_[np.flatnonzero(Y[i])])
260+
for i in range(Y.shape[0])]
261+
228262
if len(Y.shape) == 1 or Y.shape[1] == 1:
229263
y = np.array(Y.ravel() > 0, dtype=int)
264+
230265
else:
231266
y = Y.argmax(axis=1)
267+
232268
return self.classes_[y]

scikits/learn/preprocessing/tests/test_preprocessing.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,17 @@ def test_label_binarizer():
150150
assert_array_equal(expected, got)
151151
assert_array_equal(lb.inverse_transform(got), inp)
152152

153+
def test_label_binarizer_multilabel():
    """Round-trip a multilabel target through LabelBinarizer.

    Fits and transforms a list of label tuples, checks the resulting
    indicator matrix, then checks that ``inverse_transform`` recovers
    the original label tuples.
    """
    label_sets = [(2, 3), (1,), (1, 2)]
    # One row per sample; a 1 marks each label present in that sample.
    indicator = np.array([[0, 1, 1], [1, 0, 0], [1, 1, 0]])

    binarizer = LabelBinarizer()
    encoded = binarizer.fit_transform(label_sets)

    assert_array_equal(indicator, encoded)
    assert_equal(binarizer.inverse_transform(encoded), label_sets)
163+
153164
def test_label_binarizer_iris():
154165
lb = LabelBinarizer()
155166
Y = lb.fit_transform(iris.target)

0 commit comments

Comments
 (0)
0