ENH add "indicator_matrix" parameter to LabelBinarizer, add test. · amueller/scikit-learn@eafd52d · GitHub
[go: up one dir, main page]

Skip to content

Commit eafd52d

Browse files
committed
ENH add "indicator_matrix" parameter to LabelBinarizer, add test.
1 parent 9690888 commit eafd52d

File tree

2 files changed

+22
-2
lines changed

2 files changed

+22
-2
lines changed

sklearn/preprocessing.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -932,6 +932,11 @@ class LabelBinarizer(BaseEstimator, TransformerMixin):
932932
Whether or not data will be multilabel.
933933
If None, it will be inferred during fitting.
934934
935+
indicator_matrix : bool or None (default)
936+
Whether ``inverse_transform`` will produce an indicator
937+
matrix encoding (if False it will return list of lists).
938+
If None, it will be inferred during fitting.
939+
935940
Attributes
936941
----------
937942
`classes_` : array of shape [n_class]
@@ -940,6 +945,10 @@ class LabelBinarizer(BaseEstimator, TransformerMixin):
940945
`multilabel_` : bool
941946
Whether the estimator was fitted for multi-label data.
942947
948+
`indicator_matrix_` : bool
949+
Whether the estimator was fitted with a label indicator matrix.
950+
This will determine the result of ``inverse_transform``.
951+
943952
Examples
944953
--------
945954
>>> from sklearn import preprocessing
@@ -960,21 +969,23 @@ class LabelBinarizer(BaseEstimator, TransformerMixin):
960969
"""
961970

962971
def __init__(self, neg_label=0, pos_label=1, classes=None,
963-
multilabel=None):
972+
multilabel=None, indicator_matrix=None):
964973
if neg_label >= pos_label:
965974
raise ValueError("neg_label must be strictly less than pos_label.")
966975

967976
self.neg_label = neg_label
968977
self.pos_label = pos_label
969978
self.classes = classes
970979
self.multilabel = multilabel
980+
self.indicator_matrix = indicator_matrix
971981

972982
def _check_fitted(self):
973983
if not hasattr(self, "classes_"):
974984
if self.classes is not None:
975985
self.classes_ = np.unique(self.classes)
976986
# default to not doing multi-label things
977987
self.multilabel_ = bool(self.multilabel)
988+
self.indicator_matrix_ = bool(self.indicator_matrix)
978989
else:
979990
raise ValueError("LabelBinarizer was not fitted yet.")
980991

sklearn/tests/test_preprocessing.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -657,15 +657,24 @@ def test_label_binarizer_classes():
657657
transformed = lb.transform(['see', 'see'])
658658
assert_equal(transformed.shape, (2, 3))
659659
assert_array_equal(np.argmax(transformed, axis=1), [2, 2])
660+
# test inverse transform
661+
assert_array_equal(['see', 'see'], lb.inverse_transform(transformed))
660662

661663
# also works with multilabel data if we say so:
662664
lb = LabelBinarizer(classes=np.arange(1, 3), multilabel=True)
663-
y = [[1, 2], [1], []]
665+
y = [(1, 2), (1,), ()]
664666
Y = np.array([[1, 1],
665667
[1, 0],
666668
[0, 0]])
667669
assert_array_equal(lb.transform(y), Y)
668670
assert_array_equal(lb.fit_transform(y), Y)
671+
# inverse transform of label indicator matrix to label
672+
assert_array_equal(lb.inverse_transform(Y), y)
673+
674+
# inverse transform with indicator_matrix=True
675+
lb = LabelBinarizer(classes=np.arange(1, 3), multilabel=True,
676+
indicator_matrix=True)
677+
assert_array_equal(lb.inverse_transform(Y), Y)
669678

670679
lb = LabelBinarizer(classes=np.arange(1, 3))
671680
assert_raise_message(ValueError, "not fitted with multilabel",

0 commit comments

Comments
 (0)
0