8000 MISC replace two boolean parameters by a single string parameter. · amueller/scikit-learn@983bcd9 · GitHub
[go: up one dir, main page]

Skip to content

Commit 983bcd9

Browse files
committed
MISC replace two boolean parameters by a single string parameter.
1 parent 384df60 commit 983bcd9

File tree

1 file changed

+17
-17
lines changed

1 file changed

+17
-17
lines changed

sklearn/preprocessing.py

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -928,14 +928,13 @@ class LabelBinarizer(BaseEstimator, TransformerMixin):
928928
classes : ndarray if int or None (default)
929929
Array of possible classes.
930930
931-
multilabel : bool or None (default)
932-
Whether or not data will be multilabel.
933-
If None, it will be inferred during fitting.
934-
935-
indicator_matrix : bool or None (default)
936-
Whether ``inverse_transform`` will produce an indicator
937-
matrix encoding (if False it will return list of lists).
938-
If None, it will be inferred during fitting.
931+
label_type : string, default="auto"
932+
Possible values and expected forms of y are:
933+
- "multiclass", y is array of ints
934+
- "multilabel-indicator", y is indicator matrix of classes
935+
- "multiclass-list", y is list of lists of labels
936+
- "auto", form of y is determined during 'fit'. If 'fit' is not
937+
called, multiclass is assumed.
939938
940939
Attributes
941940
----------
@@ -970,23 +969,24 @@ class LabelBinarizer(BaseEstimator, TransformerMixin):
970969
"""
971970

972971
def __init__(self, neg_label=0, pos_label=1, classes=None,
973-
multilabel=None, indicator_matrix=None):
972+
label_type='auto'):
974973
if neg_label >= pos_label:
975974
raise ValueError("neg_label must be strictly less than pos_label.")
976975

977976
self.neg_label = neg_label
978977
self.pos_label = pos_label
979978
self.classes = classes
980-
self.multilabel = multilabel
981-
self.indicator_matrix = indicator_matrix
979+
self.label_type = label_type
982980

983981
def _check_fitted(self):
984982
if not hasattr(self, "classes_"):
985983
if self.classes is not None:
986984
self.classes_ = np.unique(self.classes)
987985
# default to not doing multi-label things
988-
self.multilabel_ = bool(self.multilabel)
989-
self.indicator_matrix_ = bool(self.indicator_matrix)
986+
self.multilabel_ = self.label_type in ["multilabel-indicator",
987+
"multilabel-list"]
988+
self.indicator_matrix_ = (self.label_type ==
989+
"multilabel-indicator")
990990
else:
991991
raise ValueError("LabelBinarizer was not fitted yet.")
992992

@@ -1007,9 +1007,9 @@ def fit(self, y):
10071007
10081008
"""
10091009
self.multilabel_ = _is_multilabel(y)
1010-
if self.multilabel is not None and self.multilabel != self.multilabel_:
1011-
raise ValueError("Parameter multilabel was set explicitly but "
1012-
"does not match the data.")
1010+
if self.multilabel_ and self.label_type == "multiclass":
1011+
raise ValueError("Multilabel y was passed but"
1012+
" label_type='multiclass'.")
10131013
if self.multilabel_:
10141014
self.indicator_matrix_ = _is_label_indicator_matrix(y)
10151015
if self.indicator_matrix_:
@@ -1068,7 +1068,7 @@ def transform(self, y):
10681068
" input!")
10691069

10701070
elif self.multilabel_:
1071-
if not _is_multilabel(y):
1071+
if not y_is_multilabel:
10721072
raise ValueError("y should be a list of label lists/tuples,"
10731073
"got %r" % (y,))
10741074

0 commit comments

Comments
 (0)
0