ENH added "multilabel" option to LabelBinarizer · amueller/scikit-learn@c6bac77 · GitHub
[go: up one dir, main page]

Skip to content

Commit c6bac77

Browse files
committed
ENH added "multilabel" option to LabelBinarizer
1 parent 9beffbd commit c6bac77

File tree

1 file changed

+17
-7
lines changed

1 file changed

+17
-7
lines changed

sklearn/preprocessing.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -933,12 +933,15 @@ class LabelBinarizer(BaseEstimator, TransformerMixin):
933933
`classes_` : array of shape [n_class]
934934
Holds the label for each class.
935935
936+
`multilabel_` : bool
937+
Whether the estimator was fitted for multi-label data.
938+
936939
Examples
937940
--------
938941
>>> from sklearn import preprocessing
939942
>>> lb = preprocessing.LabelBinarizer()
940943
>>> lb.fit([1, 2, 6, 4, 2])
941-
LabelBinarizer(neg_label=0, pos_label=1)
944+
LabelBinarizer(classes=None, multilabel=None, neg_label=0, pos_label=1)
942945
>>> lb.classes_
943946
array([1, 2, 4, 6])
944947
>>> lb.transform([1, 6])
@@ -952,18 +955,22 @@ class LabelBinarizer(BaseEstimator, TransformerMixin):
952955
array([1, 2, 3])
953956
"""
954957

955-
def __init__(self, neg_label=0, pos_label=1, classes=None):
958+
def __init__(self, neg_label=0, pos_label=1, classes=None,
959+
multilabel=None):
956960
if neg_label >= pos_label:
957961
raise ValueError("neg_label must be strictly less than pos_label.")
958962

959963
self.neg_label = neg_label
960964
self.pos_label = pos_label
961965
self.classes = classes
966+
self.multilabel = multilabel
962967

963968
def _check_fitted(self):
964969
if not hasattr(self, "classes_"):
965970
if self.classes is not None:
966971
self.classes_ = np.unique(self.classes)
972+
# default to not doing multi-label things
973+
self.multilabel_ = bool(self.multilabel)
967974
else:
968975
raise ValueError("LabelBinarizer was not fitted yet.")
969976

@@ -983,8 +990,11 @@ def fit(self, y):
983990
self : returns an instance of self.
984991
985992
"""
986-
self.multilabel = _is_multilabel(y)
987-
if self.multilabel:
993+
self.multilabel_ = _is_multilabel(y)
994+
if self.multilabel is not None and self.multilabel != self.multilabel_:
995+
raise ValueError("Parameter multilabel was set explicity but "
996+
"does not match the data.")
997+
if self.multilabel_:
988998
self.indicator_matrix_ = _is_label_indicator_matrix(y)
989999
if self.indicator_matrix_:
9901000
classes = np.arange(y.shape[1])
@@ -1024,7 +1034,7 @@ def transform(self, y):
10241034
"""
10251035
self._check_fitted()
10261036

1027-
if self.multilabel or len(self.classes_) > 2:
1037+
if self.multilabel_ or len(self.classes_) > 2:
10281038
if _is_label_indicator_matrix(y):
10291039
# nothing to do as y is already a label indicator matrix
10301040
return y
@@ -1037,11 +1047,11 @@ def transform(self, y):
10371047

10381048
y_is_multilabel = _is_multilabel(y)
10391049

1040-
if y_is_multilabel and not self.multilabel:
1050+
if y_is_multilabel and not self.multilabel_:
10411051
raise ValueError("The object was not fitted with multilabel"
10421052
" input!")
10431053

1044-
elif self.multilabel:
1054+
elif self.multilabel_:
10451055
if not _is_multilabel(y):
10461056
raise ValueError("y should be a list of label lists/tuples,"
10471057
"got %r" % (y,))

0 commit comments

Comments
 (0)
0