@@ -932,6 +932,11 @@ class LabelBinarizer(BaseEstimator, TransformerMixin):
932
932
Whether or not data will be multilabel.
933
933
If None, it will be inferred during fitting.
934
934
935
+ indicator_matrix : bool or None (default)
936
+ Whether ``inverse_transform`` will produce an indicator
937
+ matrix encoding (if False it will return list of lists).
938
+ If None, it will be inferred during fitting.
939
+
935
940
Attributes
936
941
----------
937
942
`classes_` : array of shape [n_class]
@@ -940,6 +945,10 @@ class LabelBinarizer(BaseEstimator, TransformerMixin):
940
945
`multilabel_` : bool
941
946
Whether the estimator was fitted for multi-label data.
942
947
948
+ `indicator_matrix_` : bool
949
+ Whether the estimator was fitted with a label indicator matrix.
950
+ This will determine the result of ``inverse_transform``.
951
+
943
952
Examples
944
953
--------
945
954
>>> from sklearn import preprocessing
@@ -960,21 +969,23 @@ class LabelBinarizer(BaseEstimator, TransformerMixin):
960
969
"""
961
970
962
971
def __init__ (self , neg_label = 0 , pos_label = 1 , classes = None ,
963
- multilabel = None ):
972
+ multilabel = None , indicator_matrix = None ):
964
973
if neg_label >= pos_label :
965
974
raise ValueError ("neg_label must be strictly less than pos_label." )
966
975
967
976
self .neg_label = neg_label
968
977
self .pos_label = pos_label
969
978
self .classes = classes
970
979
self .multilabel = multilabel
980
+ self .indicator_matrix = indicator_matrix
971
981
972
982
def _check_fitted (self ):
973
983
if not hasattr (self , "classes_" ):
974
984
if self .classes is not None :
975
985
self .classes_ = np .unique (self .classes )
976
986
# default to not doing multi-label things
977
987
self .multilabel_ = bool (self .multilabel )
988
+ self .indicator_matrix_ = bool (self .indicator_matrix )
978
989
else :
979
990
raise ValueError ("LabelBinarizer was not fitted yet." )
980
991
0 commit comments