FIX validate in fit for LabelBinarizer estimator #21434
Conversation
I think that we should add an entry in the changelog since it could have an effect on third-party libraries. Please add an entry to the change log at doc/whats_new/v1.1.rst:

- |Fix| :class:`preprocessing.LabelBinarizer` now validates input parameters in `fit`
  instead of `__init__`.
  :pr:`21434` by :user:`krumetoft <krumetoft>`.

It should go in the :mod:`sklearn.preprocessing` section.
You will need to edit the tests where we were testing the pattern `with pytest.raises(...): LabelBinarizer(...)` by also calling `fit`. Here is the diff:

diff --git a/sklearn/preprocessing/_label.py b/sklearn/preprocessing/_label.py
index 72bd8c6d65..12f6d08b5c 100644
--- a/sklearn/preprocessing/_label.py
+++ b/sklearn/preprocessing/_label.py
@@ -275,6 +275,19 @@ class LabelBinarizer(TransformerMixin, BaseEstimator):
self : object
Returns the instance itself.
"""
+ if self.neg_label >= self.pos_label:
+ raise ValueError(
+ f"neg_label={self.neg_label} must be strictly less than "
+ f"pos_label={self.pos_label}."
+ )
+
+ if self.sparse_output and (self.pos_label == 0 or self.neg_label != 0):
+ raise ValueError(
+ "Sparse binarization is only supported with non "
+ "zero pos_label and zero neg_label, got "
+ f"pos_label={self.pos_label} and neg_label={self.neg_label}"
+ )
+
self.y_type_ = type_of_target(y)
if "multioutput" in self.y_type_:
raise ValueError(
diff --git a/sklearn/preprocessing/tests/test_label.py b/sklearn/preprocessing/tests/test_label.py
index 5142144bcb..3dc4afaf89 100644
--- a/sklearn/preprocessing/tests/test_label.py
+++ b/sklearn/preprocessing/tests/test_label.py
@@ -124,25 +124,35 @@ def test_label_binarizer_errors():
lb = LabelBinarizer().fit(one_class)
multi_label = [(2, 3), (0,), (0, 2)]
- with pytest.raises(ValueError):
+ err_msg = "You appear to be using a legacy multi-label data representation."
+ with pytest.raises(ValueError, match=err_msg):
lb.transform(multi_label)
lb = LabelBinarizer()
- with pytest.raises(ValueError):
+
+ err_msg = "This LabelBinarizer instance is not fitted yet"
+ with pytest.raises(ValueError, match=err_msg):
lb.transform([])
- with pytest.raises(ValueError):
+ with pytest.raises(ValueError, match=err_msg):
lb.inverse_transform([])
- with pytest.raises(ValueError):
- LabelBinarizer(neg_label=2, pos_label=1)
- with pytest.raises(ValueError):
- LabelBinarizer(neg_label=2, pos_label=2)
-
- with pytest.raises(ValueError):
- LabelBinarizer(neg_label=1, pos_label=2, sparse_output=True)
+ input_labels = [0, 1, 0, 1]
+ err_msg = "neg_label=2 must be strictly less than pos_label=1."
+ with pytest.raises(ValueError, match=err_msg):
+ LabelBinarizer(neg_label=2, pos_label=1).fit(input_labels)
+ err_msg = "neg_label=2 must be strictly less than pos_label=2."
+ with pytest.raises(ValueError, match=err_msg):
+ LabelBinarizer(neg_label=2, pos_label=2).fit(input_labels)
+ err_msg = (
+ "Sparse binarization is only supported with non zero pos_label and zero "
+ "neg_label, got pos_label=2 and neg_label=1"
+ )
+ with pytest.raises(ValueError, match=err_msg):
+ LabelBinarizer(neg_label=1, pos_label=2, sparse_output=True).fit(input_labels)
# Fail on y_type
- with pytest.raises(ValueError):
+ err_msg = "foo format is not supported"
+ with pytest.raises(ValueError, match=err_msg):
_inverse_binarize_thresholding(
y=csr_matrix([[1, 2], [2, 1]]),
output_type="foo",
@@ -152,11 +162,13 @@ def test_label_binarizer_errors():
# Sequence of seq type should raise ValueError
y_seq_of_seqs = [[], [1, 2], [3], [0, 1, 3], [2]]
- with pytest.raises(ValueError):
+ err_msg = "You appear to be using a legacy multi-label data representation"
+ with pytest.raises(ValueError, match=err_msg):
LabelBinarizer().fit_transform(y_seq_of_seqs)
# Fail on the number of classes
- with pytest.raises(ValueError):
+ err_msg = "The number of class is not equal to the number of dimension of y."
+ with pytest.raises(ValueError, match=err_msg):
_inverse_binarize_thresholding(
y=csr_matrix([[1, 2], [2, 1]]),
output_type="foo",
@@ -165,7 +177,8 @@ def test_label_binarizer_errors():
)
# Fail on the dimension of 'binary'
- with pytest.raises(ValueError):
+ err_msg = "output_type='binary', but y.shape"
+ with pytest.raises(ValueError, match=err_msg):
_inverse_binarize_thresholding(
y=np.array([[1, 2, 3], [2, 1, 3]]),
output_type="binary",
@@ -174,9 +187,10 @@ def test_label_binarizer_errors():
)
# Fail on multioutput data
- with pytest.raises(ValueError):
+ err_msg = "Multioutput target data is not supported with label binarization"
+ with pytest.raises(ValueError, match=err_msg):
LabelBinarizer().fit(np.array([[1, 3], [2, 1]]))
- with pytest.raises(ValueError):
+ with pytest.raises(ValueError, match=err_msg):
label_binarize(np.array([[1, 3], [2, 1]]), classes=[1, 2, 3])
You can use it to edit your PR.
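For illustration, here is a minimal sketch of the behaviour change this diff introduces (assuming the patch above is applied): constructing the estimator with inconsistent parameters no longer raises, and the error is deferred until `fit` is called.

from sklearn.preprocessing import LabelBinarizer

# No error at construction time any more: the parameter checks moved out of __init__.
lb = LabelBinarizer(neg_label=2, pos_label=1)

# Validation now happens when fitting.
try:
    lb.fit([0, 1, 0, 1])
except ValueError as exc:
    print(exc)  # neg_label=2 must be strictly less than pos_label=1.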
Thank you, @glemaitre! I've just updated doc/whats_new/v1.1 and test_label.py. On a second review, I am wondering whether it is an issue that the input will be validated in transform, but not in fit (as the checks are performed by label_binarize in transform only)?
If you check my diff, I moved the validation into `fit` (see the `_label.py` diff above). Otherwise, we don't do any validation.
Apologies, I missed that - corrected now!
Can you run black on the 2 files that are detected by our linter: https://dev.azure.com/scikit-learn/scikit-learn/_build/results?buildId=34025&view=logs&jobId=32e2e1bb-a28f-5b18-6cfc-3f01273f5609&j=32e2e1bb-a28f-5b18-6cfc-3f01273f5609&t=fc67071d-c3d4-58b8-d38e-cafc0d3c731a In the future, you might want to install pre-commit.
Thank you for your patience and suggestion - I will use pre-commit.
LGTM.
Thank you, @krumetoft, for this first-time contribution!
Thank you, @jjerphan! I need to thank Guillaume for his patience and guidance.
You can change a local git repository config to use another identity (e.g. @krumeto in this case) using `git config user.name` and `git config user.email`.
LGTM
Minor comments about the tests. Otherwise LGTM.
Hey @thomasjpfan, thank you, it makes sense! I changed test_label.py accordingly.
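For context, a minimal sketch of the `pytest.raises(..., match=...)` pattern the reviewers asked for in the updated tests; the estimator, parameters, and message come from the diff above, while using `re.escape` is just one way of handling regex metacharacters in the expected message.

import re

import pytest

from sklearn.preprocessing import LabelBinarizer


def test_label_binarizer_neg_label_error():
    # `match` is interpreted as a regular expression, so escape the literal message.
    err_msg = re.escape("neg_label=2 must be strictly less than pos_label=1.")
    with pytest.raises(ValueError, match=err_msg):
        LabelBinarizer(neg_label=2, pos_label=1).fit([0, 1, 0, 1])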
Regarding your last commit, @krumetoft, you can use the ...
Thank you, @jjerphan! I saw the piece of code in one of Olivier's comments and decided to give it a run :)
Arff, I did not see that there is a conflict. @krumetoft @krumeto, could you solve the merge conflict?
Attempt to resolve merge conflict
Co-authored-by: Thomas J. Fan <thomasjpfan@gmail.com>
Co-authored-by: Thomas J. Fan <thomasjpfan@gmail.com>
Force-pushed 646119c to 11d80f1.
Hey @glemaitre, the merge conflict is resolved (line 292 in _label.py).
Thank you, @krumeto!
Reference Issues/PRs
This closes the LabelBinarizer part of #21406
What does this implement/fix? Explain your changes.
Removed a parameter check from `__init__` of `LabelBinarizer` that is also performed by the `label_binarize` function during `transform`.
#DataUmbrella sprint
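For reference, a minimal sketch of the check on the function side (assuming a scikit-learn version where `label_binarize` validates `neg_label`/`pos_label` itself), which is why the duplicated check in `__init__` could be dropped:

from sklearn.preprocessing import label_binarize

# label_binarize performs the same neg_label/pos_label consistency check,
# so keeping a copy of it in LabelBinarizer.__init__ was redundant.
try:
    label_binarize([0, 1, 1, 0], classes=[0, 1], neg_label=2, pos_label=1)
except ValueError as exc:
    print(exc)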