ENH support multilabel targets in LabelEncoder · jnothman/scikit-learn@96dbd77 · GitHub
[go: up one dir, main page]

Skip to content

Commit 96dbd77

Browse files
committed
ENH support multilabel targets in LabelEncoder
Also, support 1d-array of sequences as a multilabel format
1 parent 5b0afce commit 96dbd77

File tree

2 files changed

+79
-8
lines changed

2 files changed

+79
-8
lines changed

sklearn/preprocessing.py

Lines changed: 30 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
# Andreas Mueller <amueller@ais.uni-bonn.de>
55
# License: BSD 3 clause
66

7+
import functools
78
import warnings
89
import numbers
910

@@ -19,6 +20,7 @@
1920
from .utils.multiclass import unique_labels
2021
from .utils.multiclass import is_multilabel
2122
from .utils.multiclass import is_label_indicator_matrix
23+
from .utils.multiclass import multilabel_vectorize
2224

2325
from .utils.sparsefuncs import inplace_csr_row_normalize_l1
2426
from .utils.sparsefuncs import inplace_csr_row_normalize_l2
@@ -829,6 +831,15 @@ class LabelEncoder(BaseEstimator, TransformerMixin):
829831
>>> list(le.inverse_transform([2, 2, 1]))
830832
['tokyo', 'tokyo', 'paris']
831833
834+
It can also be used to transform multi-label sequences of sequences:
835+
836+
>>> le = preprocessing.LabelEncoder()
837+
>>> targets = [["paris", "tokyo"], ["amsterdam", "paris"]]
838+
>>> le.fit_transform(targets)
839+
array([[1 2], [0 1]], dtype=object)
840+
>>> list(map(list, le.inverse_transform([[1, 2], [0, 1]])))
841+
[['paris', 'tokyo'], ['amsterdam', 'paris']]
842+
832843
"""
833844

834845
def _check_fitted(self):
@@ -847,7 +858,7 @@ def fit(self, y):
847858
-------
848859
self : returns an instance of self.
849860
"""
850-
self.classes_ = np.unique(y)
861+
self.classes_ = unique_labels(y)
851862
return self
852863

853864
def fit_transform(self, y):
@@ -862,6 +873,9 @@ def fit_transform(self, y):
862873
-------
863874
y : array-like of shape [n_samples]
864875
"""
876+
if is_multilabel(y):
877+
self.fit(y)
878+
return self.transform(y)
865879
self.classes_, y = unique(y, return_inverse=True)
866880
return y
867881

@@ -878,12 +892,19 @@ def transform(self, y):
878892
y : array-like of shape [n_samples]
879893
"""
880894
self._check_fitted()
895+
if is_multilabel(y):
896+
if is_label_indicator_matrix(y):
897+
raise ValueError(
898+
'{} does not support label indicator matrices'.format(
899+
self.__class__.__name__))
900+
return multilabel_vectorize(self._transform)(y)
881901

882-
classes = np.unique(y)
883-
if len(np.intersect1d(classes, self.classes_)) < len(classes):
884-
diff = np.setdiff1d(classes, self.classes_)
885-
raise ValueError("y contains new labels: %s" % str(diff))
902+
return self._transform(y)
886903

904+
def _transform(self, y):
    """Encode a flat collection of labels as indices into ``classes_``.

    Parameters
    ----------
    y : array-like
        Labels to encode; each must already be present in
        ``self.classes_``.

    Returns
    -------
    Result of ``np.searchsorted(self.classes_, y)``.

    Raises
    ------
    ValueError
        If ``y`` holds any value that is absent from ``self.classes_``.
    """
    # Labels in y that were never seen during fit.
    unseen = np.setdiff1d(y, self.classes_)
    if unseen.size == 0:
        # NOTE(review): searchsorted presumes classes_ is sorted -- the
        # fitting code is expected to guarantee that; confirm.
        return np.searchsorted(self.classes_, y)
    raise ValueError("y contains new labels: %s" % str(unseen))
888909

889910
def inverse_transform(self, y):
@@ -900,6 +921,10 @@ def inverse_transform(self, y):
900921
"""
901922
self._check_fitted()
902923

924+
if is_multilabel(y):
925+
# np.vectorize does not work with np.ndarray.take!
926+
take = functools.partial(np.take, self.classes_)
927+
return multilabel_vectorize(take)(y)
903928
y = np.asarray(y)
904929
return self.classes_[y]
905930

sklearn/utils/multiclass.py

Lines changed: 49 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,9 @@ def is_multilabel(y):
120120
False
121121
>>> is_multilabel([[1], [0, 2], []])
122122
True
123-
>>> is_multilabel(np.array([[1, 0], [0, 0]]))
123+
>>> is_multilabel(np.array([np.array([1]), np.array([0, 2])]))
124+
True
125+
>>> is_multilabel(np.array([[1, 0], [0, 0]])) # label indicator matrix
124126
True
125127
>>> is_multilabel(np.array([[1], [0], [0]]))
126128
False
@@ -130,5 +132,49 @@ def is_multilabel(y):
130132
"""
131133
# the explicit check for ndarray is for forward compatibility; future
132134
# versions of Numpy might want to register ndarray as a Sequence
133-
return (not isinstance(y[0], np.ndarray) and isinstance(y[0], Sequence) and
134-
not isinstance(y[0], string_types) or is_label_indicator_matrix(y))
135+
if getattr(y, 'ndim', 1) != 1:
136+
return is_label_indicator_matrix(y)
137+
return ((isinstance(y[0], Sequence) and not isinstance(y[0], string_types))
138+
or isinstance(y[0], np.ndarray))
139+
140+
141+
def multilabel_as_array(y):
    """Transform a sequence of sequences into a 1d array of sequences

    Parameters
    ----------
    y : sequence or array of sequences
        Target values. In the multilabel case the nested sequences can
        have variable lengths. Label indicator matrices are not supported.

    Returns
    -------
    out : numpy array of shape [len(y)]
        The elements of the returned array correspond to the elements of y.
        If y exposes ``__array__`` it is passed through ``np.asarray``
        (so an ndarray is returned without copying).
    """
    if hasattr(y, '__array__'):
        return np.asarray(y)
    out = np.empty(len(y), dtype=object)
    # Fill one slot at a time.  A bulk ``out[:] = y`` lets NumPy convert y
    # to an array first; when every nested sequence has the same length
    # that yields a 2d array which cannot be broadcast into the 1d output.
    for i, labels in enumerate(y):
        out[i] = labels
    return out
161+
162+
163+
def multilabel_vectorize(func, otypes='O'):
    """Build an element-wise version of ``func`` for sequence-of-sequence data

    Parameters
    ----------
    func : callable
        The function to vectorize.
    otypes : str, default 'O'
        Forwarded to ``np.vectorize`` as the output dtype specification;
        the default requests object output.

    Returns
    -------
    out : callable
        A function that passes each positional argument through
        ``multilabel_as_array`` and then calls the ``np.vectorize``-wrapped
        ``func`` on the results.
    """
    vectorized = np.vectorize(func, otypes=otypes)

    def wrapper(*args):
        # Coerce every argument before handing it to the vectorized func.
        return vectorized(*map(multilabel_as_array, args))

    return wrapper

0 commit comments

Comments
 (0)
0