@@ -928,14 +928,13 @@ class LabelBinarizer(BaseEstimator, TransformerMixin):
928
928
classes : ndarray if int or None (default)
929
929
Array of possible classes.
930
930
931
- multilabel : bool or None (default)
932
- Whether or not data will be multilabel.
933
- If None, it will be inferred during fitting.
934
-
935
- indicator_matrix : bool or None (default)
936
- Whether ``inverse_transform`` will produce an indicator
937
- matrix encoding (if False it will return list of lists).
938
- If None, it will be inferred during fitting.
931
+ label_type : string, default="auto"
932
+ Possible values and expected forms of y are:
933
+ - "multiclass", y is array of ints
934
+ - "multilabel-indicator", y is indicator matrix of classes
935
+ - "multiclass-list", y is list of lists of labels
936
+ - "auto", form of y is determined during 'fit'. If 'fit' is not
937
+ called, multiclass is assumed.
939
938
940
939
Attributes
941
940
----------
@@ -970,23 +969,24 @@ class LabelBinarizer(BaseEstimator, TransformerMixin):
970
969
"""
971
970
972
971
def __init__ (self , neg_label = 0 , pos_label = 1 , classes = None ,
973
- multilabel = None , indicator_matrix = None ):
972
+ label_type = 'auto' ):
974
973
if neg_label >= pos_label :
975
974
raise ValueError ("neg_label must be strictly less than pos_label." )
976
975
977
976
self .neg_label = neg_label
978
977
self .pos_label = pos_label
979
978
self .classes = classes
980
- self .multilabel = multilabel
981
- self .indicator_matrix = indicator_matrix
979
+ self .label_type = label_type
982
980
983
981
def _check_fitted (self ):
984
982
if not hasattr (self , "classes_" ):
985
983
if self .classes is not None :
986
984
self .classes_ = np .unique (self .classes )
987
985
# default to not doing multi-label things
988
- self .multilabel_ = bool (self .multilabel )
989
- self .indicator_matrix_ = bool (self .indicator_matrix )
986
+ self .multilabel_ = self .label_type in ["multilabel-indicator" ,
987
+ "multilabel-list" ]
988
+ self .indicator_matrix_ = (self .label_type ==
989
+ "multilabel-indicator" )
990
990
else :
991
991
raise ValueError ("LabelBinarizer was not fitted yet." )
992
992
@@ -1007,9 +1007,9 @@ def fit(self, y):
1007
1007
1008
1008
"""
1009
1009
self .multilabel_ = _is_multilabel (y )
1010
- if self .multilabel is not None and self .multilabel != self . multilabel_ :
1011
- raise ValueError ("Parameter multilabel was set explicitly but "
1012
- "does not match the data ." )
1010
+ if self .multilabel_ and self .label_type == "multiclass" :
1011
+ raise ValueError ("Multilabel y was passed but"
1012
+ " label_type='multiclass' ." )
1013
1013
if self .multilabel_ :
1014
1014
self .indicator_matrix_ = _is_label_indicator_matrix (y )
1015
1015
if self .indicator_matrix_ :
@@ -1068,7 +1068,7 @@ def transform(self, y):
1068
1068
" input!" )
1069
1069
1070
1070
elif self .multilabel_ :
1071
- if not _is_multilabel ( y ) :
1071
+ if not y_is_multilabel :
1072
1072
raise ValueError ("y should be a list of label lists/tuples,"
1073
1073
"got %r" % (y ,))
1074
1074
0 commit comments