4
4
# Andreas Mueller <amueller@ais.uni-bonn.de>
5
5
# License: BSD 3 clause
6
6
7
+ import functools
7
8
import warnings
8
9
import numbers
9
10
19
20
from .utils .multiclass import unique_labels
20
21
from .utils .multiclass import is_multilabel
21
22
from .utils .multiclass import is_label_indicator_matrix
23
+ from .utils .multiclass import multilabel_vectorize
22
24
23
25
from .utils .sparsefuncs import inplace_csr_row_normalize_l1
24
26
from .utils .sparsefuncs import inplace_csr_row_normalize_l2
@@ -829,6 +831,15 @@ class LabelEncoder(BaseEstimator, TransformerMixin):
829
831
>>> list(le.inverse_transform([2, 2, 1]))
830
832
['tokyo', 'tokyo', 'paris']
831
833
834
+ It can also be used to transform multi-label sequences of sequences:
835
+
836
+ >>> le = preprocessing.LabelEncoder()
837
+ >>> targets = [["paris", "tokyo"], ["amsterdam", "paris"]]
838
+ >>> le.fit_transform(targets)
839
+ array([[1 2], [0 1]], dtype=object)
840
+ >>> list(map(list, le.inverse_transform([[1, 2], [0, 1]])))
841
+ [['paris', 'tokyo'], ['amsterdam', 'paris']]
842
+
832
843
"""
833
844
834
845
def _check_fitted (self ):
@@ -847,7 +858,7 @@ def fit(self, y):
847
858
-------
848
859
self : returns an instance of self.
849
860
"""
850
- self .classes_ = np . unique (y )
861
+ self .classes_ = unique_labels (y )
851
862
return self
852
863
853
864
def fit_transform (self , y ):
@@ -862,6 +873,9 @@ def fit_transform(self, y):
862
873
-------
863
874
y : array-like of shape [n_samples]
864
875
"""
876
+ if is_multilabel (y ):
877
+ self .fit (y )
878
+ return self .transform (y )
865
879
self .classes_ , y = unique (y , return_inverse = True )
866
880
return y
867
881
@@ -878,12 +892,19 @@ def transform(self, y):
878
892
y : array-like of shape [n_samples]
879
893
"""
880
894
self ._check_fitted ()
895
+ if is_multilabel (y ):
896
+ if is_label_indicator_matrix (y ):
897
+ raise ValueError (
898
+ '{} does not support label indicator matrices' .format (
899
+ self .__class__ .__name__ ))
900
+ return multilabel_vectorize (self ._transform )(y )
881
901
882
- classes = np .unique (y )
883
- if len (np .intersect1d (classes , self .classes_ )) < len (classes ):
884
- diff = np .setdiff1d (classes , self .classes_ )
885
- raise ValueError ("y contains new labels: %s" % str (diff ))
902
+ return self ._transform (y )
886
903
904
+ def _transform (self , y ):
905
+ diff = np .setdiff1d (y , self .classes_ )
906
+ if len (diff ):
907
+ raise ValueError ("y contains new labels: %s" % str (diff ))
887
908
return np .searchsorted (self .classes_ , y )
888
909
889
910
def inverse_transform (self , y ):
@@ -900,6 +921,10 @@ def inverse_transform(self, y):
900
921
"""
901
922
self ._check_fitted ()
902
923
924
+ if is_multilabel (y ):
925
+ # np.vectorize does not work with np.ndarray.take!
926
+ take = functools .partial (np .take , self .classes_ )
927
+ return multilabel_vectorize (take )(y )
903
928
y = np .asarray (y )
904
929
return self .classes_ [y ]
905
930
0 commit comments