@@ -128,6 +128,10 @@ def transform(self, X, copy=True):
128
128
return X
129
129
130
130
131
+ def _is_multilabel (y ):
132
+ return isinstance (y [0 ], tuple ) or isinstance (y [0 ], list )
133
+
134
+
131
135
class LabelBinarizer (BaseEstimator , TransformerMixin ):
132
136
"""Binarize labels in a one-vs-all fashion.
133
137
@@ -160,6 +164,10 @@ class LabelBinarizer(BaseEstimator, TransformerMixin):
160
164
>>> clf.transform([1, 6])
161
165
array([[ 1., 0., 0., 0.],
162
166
[ 0., 0., 0., 1.]])
167
+
168
+ >>> clf.fit_transform([(1,2),(3,)])
169
+ array([[ 1., 1., 0.],
170
+ [ 0., 0., 1.]])
163
171
"""
164
172
165
173
def fit (self , y ):
@@ -174,7 +182,11 @@ def fit(self, y):
174
182
-------
175
183
self : returns an instance of self.
176
184
"""
177
- self .classes_ = np .unique (y )
185
+ self .multilabel = _is_multilabel (y )
186
+ if self .multilabel :
187
+ self .classes_ = np .unique (reduce (lambda a ,b :a + b , y ))
188
+ else :
189
+ self .classes_ = np .unique (y )
178
190
return self
179
191
180
192
def transform (self , y ):
@@ -192,13 +204,30 @@ def transform(self, y):
192
204
-------
193
205
Y : numpy array of shape [n_samples, n_classes]
194
206
"""
207
+
195
208
if len (self .classes_ ) == 2 :
196
209
Y = np .zeros ((len (y ), 1 ))
210
+ else :
211
+ Y = np .zeros ((len (y ), len (self .classes_ )))
212
+
213
+ if self .multilabel :
214
+ if not _is_multilabel (y ):
215
+ raise ValueError , "y should be a list of label lists/tuples"
216
+
217
+ # inverse map: label => column index
218
+ imap = dict ((v ,k ) for k ,v in enumerate (self .classes_ ))
219
+
220
+ for i , label_tuple in enumerate (y ):
221
+ for label in label_tuple :
222
+ Y [i , imap [label ]] = 1
223
+
224
+ return Y
225
+
226
+ elif len (self .classes_ ) == 2 :
197
227
Y [y == self .classes_ [1 ], 0 ] = 1
198
228
return Y
199
229
200
230
elif len (self .classes_ ) >= 2 :
201
- Y = np .zeros ((len (y ), len (self .classes_ )))
202
231
for i , k in enumerate (self .classes_ ):
203
232
Y [y == k , i ] = 1
204
233
return Y
@@ -225,8 +254,15 @@ def inverse_transform(self, Y):
225
254
this allows to use the output of a linear model's decision_function
226
255
method directly as the input of inverse_transform.
227
256
"""
257
+ if self .multilabel :
258
+ Y = np .array (Y > 0 , dtype = int )
259
+ return [tuple (self .classes_ [np .flatnonzero (Y [i ])])
260
+ for i in range (Y .shape [0 ])]
261
+
228
262
if len (Y .shape ) == 1 or Y .shape [1 ] == 1 :
229
263
y = np .array (Y .ravel () > 0 , dtype = int )
264
+
230
265
else :
231
266
y = Y .argmax (axis = 1 )
267
+
232
268
return self .classes_ [y ]
0 commit comments