add array api support in label binarizer #28626
Closed
jeromedockes wants to merge 18 commits into scikit-learn:main from jeromedockes:arrayapi-labelbinarizer
18 commits
8191596 add array api support in label binarizer (jeromedockes)
b56425d Merge remote-tracking branch 'upstream/main' into arrayapi-labelbinar… (jeromedockes)
1c08d5b update label_binarize (jeromedockes)
2a69ad3 Merge remote-tracking branch 'upstream/main' into arrayapi-labelbinar… (jeromedockes)
d3be4cc do all label binarizing in numpy (jeromedockes)
0e6b716 convert output of inverse_transform (jeromedockes)
d718c07 add test (jeromedockes)
c4a4941 fix inverse_transform for sparse Y (jeromedockes)
9ea0c55 update changelog and array_api.rst (jeromedockes)
ba117f3 add test for binary case (jeromedockes)
8e859a3 Merge remote-tracking branch 'upstream/main' into arrayapi-labelbinar… (jeromedockes)
35d31c0 Merge main + fix conflict in import statements (ogrisel)
a8a270e Fix broken test with pytorch on a non-CPU device (ogrisel)
86e7daf Apply suggestions from code review (jeromedockes)
21c441a Merge remote-tracking branch 'upstream/main' into arrayapi-labelbinar… (jeromedockes)
632f4e8 formatting (jeromedockes)
54ff366 add test for case where y is constant & for transform (in addition to… (jeromedockes)
9f4762b fix text removed from whatsnew in merge (jeromedockes)
@@ -17,6 +17,7 @@

  from ..base import BaseEstimator, TransformerMixin, _fit_context
  from ..utils import column_or_1d
+ from ..utils._array_api import _convert_to_numpy, device, get_namespace
  from ..utils._encode import _encode, _unique
  from ..utils._param_validation import Interval, validate_params
  from ..utils.multiclass import type_of_target, unique_labels
@@ -303,7 +304,8 @@ def fit(self, y):
  raise ValueError("y has 0 samples: %r" % y)

  self.sparse_input_ = sp.issparse(y)
- self.classes_ = unique_labels(y)
+ xp, _ = get_namespace(y)
+ self.classes_ = _convert_to_numpy(unique_labels(y), xp)
  return self

  def fit_transform(self, y):
@@ -396,6 +398,21 @@ def inverse_transform(self, Y, threshold=None):
  """
  check_is_fitted(self)

+ # LabelBinarizer supports Array API compatibility for convenience when
+ # used as a sub-component of a classifier that does. However
+ # label_binarize internally uses a NumPy copy of the data because
+ # all the operations are meant to construct the backing NumPy arrays of a
+ # scipy.sparse CSR datastructure even when sparse_output=False.
+ #
+ # In the future, we might consider a dedicated code path for the
+ # sparse_output=False case that would directly be implemented using Array
+ # API without the intermediate NumPy conversion and scipy.sparse
+ # datastructure.
+ xp, is_array_api_compliant = get_namespace(Y)
+ device_ = device(Y) if is_array_api_compliant else None
+ if not sp.issparse(Y):
+ Y = _convert_to_numpy(Y, xp)
+
  if threshold is None:
  threshold = (self.pos_label + self.neg_label) / 2.0

@@ -410,11 +427,13 @@ def inverse_transform(self, Y, threshold=None):
  y_inv = sp.csr_matrix(y_inv)
  elif sp.issparse(y_inv):
  y_inv = y_inv.toarray()
+ if is_array_api_compliant and not sp.issparse(y_inv):
+ y_inv = xp.asarray(y_inv, device=device_)

  return y_inv

  def _more_tags(self):
- return {"X_types": ["1dlabels"]}
+ return {"X_types": ["1dlabels"], "array_api_support": True}


  @validate_params(
@@ -487,6 +506,18 @@ def label_binarize(y, *, classes, neg_label=0, pos_label=1, sparse_output=False)
  [0],
  [1]])
  """
+ # label_binarize supports Array API compatibility for convenience when
+ # LabelBinarizer is used as a sub-component of a classifier that does.
+ # However label_binarize internally uses a NumPy copy of the data because
+ # all the operations are meant to construct the backing NumPy arrays of a
+ # scipy.sparse CSR datastructure even when sparse_output=False.
+ y_xp, y_is_array_api = get_namespace(y)
+ if y_is_array_api:
+ device_ = device(y)
+ y = _convert_to_numpy(y, y_xp)
+ classes_xp, classes_is_array_api = get_namespace(classes)
+ if classes_is_array_api:
+ classes = _convert_to_numpy(classes, classes_xp)
  if not isinstance(y, list):
  # XXX Workaround that will be removed when list of list format is
  # dropped
@@ -535,7 +566,10 @@ def label_binarize(y, *, classes, neg_label=0, pos_label=1, sparse_output=False)
  else:
  Y = np.zeros((len(y), 1), dtype=int)
  Y += neg_label
- return Y
+
+ if not y_is_array_api:
+ return Y
+ return y_xp.asarray(Y, device=device_)
Review comment: Could you please add a new test case to cover that line?
  elif len(classes) >= 3:
  y_type = "multiclass"

@@ -595,7 +629,9 @@ def label_binarize(y, *, classes, neg_label=0, pos_label=1, sparse_output=False)
  else:
  Y = Y[:, -1].reshape((-1, 1))

- return Y
+ if not y_is_array_api:
+ return Y
+ return y_xp.asarray(Y, device=device_)


  def _inverse_binarize_multiclass(y, classes):
Why do we need to convert classes_ to numpy? This is related to #26083 as well.
This right now is a bit confusing, since fitting an estimator which supports array API ends up with an object where some attributes are in the same space as input X, and some are still in numpy.
I think that so far the policy is: store the fitted attributes with the array type that makes most sense to be efficient at prediction time, assuming the prediction-time data container type will be consistent with the fit-time data container type. Since this PR only changes LabelBinarizer for convenience without actually delegating any computation to the underlying Array API namespace (see the inline comments), I think it's better to always keep classes_ as a numpy array for now.
If one day we decide to recode label_binarize to actually delegate some computation to the underlying namespace, then we can think about making the type of classes_ input dependent instead.
Then I'm confused. Doesn't check_array already convert data to numpy? So we already support non-numpy data.
If the point is to have the output of predict in the same space as what the user gives, shouldn't that be like a decorator or something around predict? Or simply a function call at the end of predict?
That is quite close to what is done here: we record the input namespace and device at the beginning of the function, convert to numpy and do everything in numpy, and convert back to the input format and device where the function returns.
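The round trip described in this comment (record the input container and device, compute on a host copy, convert back on return) could be sketched as a decorator. This is only an illustration, not the PR's code: to stay dependency-free, a hypothetical DeviceArray class stands in for a non-NumPy array, plain lists stand in for NumPy arrays, and roundtrip_through_host and binarize are made-up names.

```python
import functools


class DeviceArray:
    """Hypothetical stand-in for a non-NumPy array living on some device."""

    def __init__(self, values, device="cpu"):
        self.values = list(values)
        self.device = device


def roundtrip_through_host(func):
    """Run `func` on a plain host copy, then restore the input container.

    Assumed pattern: remember the input type and device, convert,
    compute, and convert back where the function returns.
    """

    @functools.wraps(func)
    def wrapper(y, *args, **kwargs):
        if not isinstance(y, DeviceArray):
            # Plain inputs pass straight through and keep their type.
            return func(y, *args, **kwargs)
        device = y.device  # remember where the input lives
        result = func(list(y.values), *args, **kwargs)  # compute on host copy
        return DeviceArray(result, device=device)  # convert back

    return wrapper


@roundtrip_through_host
def binarize(y, positive):
    """Toy stand-in for label_binarize in the binary case."""
    return [1 if label == positive else 0 for label in y]


out = binarize(DeviceArray(["a", "b", "a"], device="gpu:0"), positive="a")
print(type(out).__name__, out.device, out.values)
```

In a real implementation the stand-ins would presumably be replaced by the helpers already used in the diff (get_namespace, device, _convert_to_numpy and xp.asarray); the sketch only shows the control flow.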
As you say, it could probably be implemented as a decorator applied to fit, label_binarize and inverse_transform