8000 FIX validate in `fit` for `LabelBinarizer` estimator by krumetoft · Pull Request #21434 · scikit-learn/scikit-learn · GitHub
[go: up one dir, main page]

Skip to content

FIX validate in fit for LabelBinarizer estimator #21434

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Nov 2, 2021
7 changes: 7 additions & 0 deletions doc/whats_new/v1.1.rst
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,13 @@ Changelog
Setting a transformer to "passthrough" will pass the features unchanged.
:pr:`20860` by :user:`Shubhraneel Pal <shubhraneel>`.

:mod:`sklearn.preprocessing`
............................

- |Fix| :class:`preprocessing.LabelBinarizer` now validates input parameters in `fit`
instead of `__init__`.
:pr:`21434` by :user:`Krum Arnaudov <krumeto>`.

:mod:`sklearn.utils`
....................

Expand Down
29 changes: 15 additions & 14 deletions sklearn/preprocessing/_label.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,20 +256,6 @@ class LabelBinarizer(TransformerMixin, BaseEstimator):
"""

def __init__(self, *, neg_label=0, pos_label=1, sparse_output=False):
if neg_label >= pos_label:
raise ValueError(
"neg_label={0} must be strictly less than pos_label={1}.".format(
neg_label, pos_label
)
)

if sparse_output and (pos_label == 0 or neg_label != 0):
raise ValueError(
"Sparse binarization is only supported with non "
"zero pos_label and zero neg_label, got "
"pos_label={0} and neg_label={1}"
"".format(pos_label, neg_label)
)

self.neg_label = neg_label
self.pos_label = pos_label
Expand All @@ -289,7 +275,22 @@ def fit(self, y):
self : object
Returns the instance itself.
"""

if self.neg_label >= self.pos_label:
raise ValueError(
f"neg_label={self.neg_label} must be strictly less than "
f"pos_label={self.pos_label}."
)

if self.sparse_output and (self.pos_label == 0 or self.neg_label != 0):
raise ValueError(
"Sparse binarization is only supported with non "
"zero pos_label and zero neg_label, got "
f"pos_label={self.pos_label} and neg_label={self.neg_label}"
)

self.y_type_ = type_of_target(y, input_name="y")

if "multioutput" in self.y_type_:
raise ValueError(
"Multioutput target data is not supported with label binarization"
Expand Down
48 changes: 32 additions & 16 deletions sklearn/preprocessing/tests/test_label.py
8000
Original file line number Diff line number Diff line change
Expand Up @@ -124,25 +124,37 @@ def test_label_binarizer_errors():
lb = LabelBinarizer().fit(one_class)

multi_label = [(2, 3), (0,), (0, 2)]
with pytest.raises(ValueError):
err_msg = "You appear to be using a legacy multi-label data representation."
with pytest.raises(ValueError, match=err_msg):
lb.transform(multi_label)

lb = LabelBinarizer()
with pytest.raises(ValueError):
err_msg = "This LabelBinarizer instance is not fitted yet"
with pytest.raises(ValueError, match=err_msg):
lb.transform([])
with pytest.raises(ValueError):
with pytest.raises(ValueError, match=err_msg):
lb.inverse_transform([])

with pytest.raises(ValueError):
LabelBinarizer(neg_label=2, pos_label=1)
with pytest.raises(ValueError):
LabelBinarizer(neg_label=2, pos_label=2)

with pytest.raises(ValueError):
LabelBinarizer(neg_label=1, pos_label=2, sparse_output=True)
input_labels = [0, 1, 0, 1]
err_msg = "neg_label=2 must be strictly less than pos_label=1."
lb = LabelBinarizer(neg_label=2, pos_label=1)
with pytest.raises(ValueError, match=err_msg):
lb.fit(input_labels)
err_msg = "neg_label=2 must be strictly less than pos_label=2."
lb = LabelBinarizer(neg_label=2, pos_label=2)
with pytest.raises(ValueError, match=err_msg):
lb.fit(input_labels)
err_msg = (
"Sparse binarization is only supported with non zero pos_label and zero "
"neg_label, got pos_label=2 and neg_label=1"
)
lb = LabelBinarizer(neg_label=1, pos_label=2, sparse_output=True)
with pytest.raises(ValueError, match=err_msg):
lb.fit(input_labels)

# Fail on y_type
with pytest.raises(ValueError):
err_msg = "foo format is not supported"
with pytest.raises(ValueError, match=err_msg):
_inverse_binarize_thresholding(
y=csr_matrix([[1, 2], [2, 1]]),
output_type="foo",
Expand All @@ -152,11 +164,13 @@ def test_label_binarizer_errors():

# Sequence of seq type should raise ValueError
y_seq_of_seqs = [[], [1, 2], [3], [0, 1, 3], [2]]
with pytest.raises(ValueError):
err_msg = "You appear to be using a legacy multi-label data representation"
with pytest.raises(ValueError, match=err_msg):
LabelBinarizer().fit_transform(y_seq_of_seqs)

# Fail on the number of classes
with pytest.raises(ValueError):
err_msg = "The number of class is not equal to the number of dimension of y."
with pytest.raises(ValueError, match=err_msg):
_inverse_binarize_thresholding(
y=csr_matrix([[1, 2], [2, 1]]),
output_type="foo",
Expand All @@ -165,7 +179,8 @@ def test_label_binarizer_errors():
)

# Fail on the dimension of 'binary'
with pytest.raises(ValueError):
err_msg = "output_type='binary', but y.shape"
with pytest.raises(ValueError, match=err_msg):
_inverse_binarize_thresholding(
y=np.array([[1, 2, 3], [2, 1, 3]]),
output_type="binary",
Expand All @@ -174,9 +189,10 @@ def test_label_binarizer_errors():
)

# Fail on multioutput data
with pytest.raises(ValueError):
err_msg = "Multioutput target data is not supported with label binarization"
with pytest.raises(ValueError, match=err_msg):
LabelBinarizer().fit(np.array([[1, 3], [2, 1]]))
with pytest.raises(ValueError):
with pytest.raises(ValueError, match=err_msg):
label_binarize(np.array([[1, 3], [2, 1]]), classes=[1, 2, 3])


Expand Down
1 change: 0 additions & 1 deletion sklearn/tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,7 +412,6 @@ def test_transformers_get_feature_names_out(transformer):
"GridSearchCV",
"HalvingGridSearchCV",
"KernelPCA",
"LabelBinarizer",
"NuSVC",
"NuSVR",
"OneClassSVM",
Expand Down
0