@@ -933,12 +933,15 @@ class LabelBinarizer(BaseEstimator, TransformerMixin):
933
933
`classes_` : array of shape [n_class]
934
934
Holds the label for each class.
935
935
936
+ `multilabel_` : bool
937
+ Whether the estimator was fitted for multi-label data.
938
+
936
939
Examples
937
940
--------
938
941
>>> from sklearn import preprocessing
939
942
>>> lb = preprocessing.LabelBinarizer()
940
943
>>> lb.fit([1, 2, 6, 4, 2])
941
- LabelBinarizer(neg_label=0, pos_label=1)
944
+ LabelBinarizer(classes=None, multilabel=None, neg_label=0, pos_label=1)
942
945
>>> lb.classes_
943
946
array([1, 2, 4, 6])
944
947
>>> lb.transform([1, 6])
@@ -952,18 +955,22 @@ class LabelBinarizer(BaseEstimator, TransformerMixin):
952
955
array([1, 2, 3])
953
956
"""
954
957
955
- def __init__ (self , neg_label = 0 , pos_label = 1 , classes = None ):
958
+ def __init__ (self , neg_label = 0 , pos_label = 1 , classes = None ,
959
+ multilabel = None ):
956
960
if neg_label >= pos_label :
957
961
raise ValueError ("neg_label must be strictly less than pos_label." )
958
962
959
963
self .neg_label = neg_label
960
964
self .pos_label = pos_label
961
965
self .classes = classes
966
+ self .multilabel = multilabel
962
967
963
968
def _check_fitted (self ):
964
969
if not hasattr (self , "classes_" ):
965
970
if self .classes is not None :
966
971
self .classes_ = np .unique (self .classes )
972
+ # default to not doing multi-label things
973
+ self .multilabel_ = bool (self .multilabel )
967
974
else :
968
975
raise ValueError ("LabelBinarizer was not fitted yet." )
969
976
@@ -983,8 +990,11 @@ def fit(self, y):
983
990
self : returns an instance of self.
984
991
985
992
"""
986
- self .multilabel = _is_multilabel (y )
987
- if self .multilabel :
993
+ self .multilabel_ = _is_multilabel (y )
994
+ if self .multilabel is not None and self .multilabel != self .multilabel_ :
995
+ raise ValueError ("Parameter multilabel was set explicity but "
996
+ "does not match the data." )
997
+ if self .multilabel_ :
988
998
self .indicator_matrix_ = _is_label_indicator_matrix (y )
989
999
if self .indicator_matrix_ :
990
1000
classes = np .arange (y .shape [1 ])
@@ -1024,7 +1034,7 @@ def transform(self, y):
1024
1034
"""
1025
1035
self ._check_fitted ()
1026
1036
1027
- if self .multilabel or len (self .classes_ ) > 2 :
1037
+ if self .multilabel_ or len (self .classes_ ) > 2 :
1028
1038
if _is_label_indicator_matrix (y ):
1029
1039
# nothing to do as y is already a label indicator matrix
1030
1040
return y
@@ -1037,11 +1047,11 @@ def transform(self, y):
1037
1047
1038
1048
y_is_multilabel = _is_multilabel (y )
1039
1049
1040
- if y_is_multilabel and not self .multilabel :
1050
+ if y_is_multilabel and not self .multilabel_ :
1041
1051
raise ValueError ("The object was not fitted with multilabel"
1042
1052
" input!" )
1043
1053
1044
- elif self .multilabel :
1054
+ elif self .multilabel_ :
1045
1055
if not _is_multilabel (y ):
1046
1056
raise ValueError ("y should be a list of label lists/tuples,"
1047
1057
"got %r" % (y ,))
0 commit comments