-
-
Notifications
You must be signed in to change notification settings - Fork 25.9k
array API support for cosine_distances #29265
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
array API support for cosine_distances #29265
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Here is some feedback. Please also test that this works on cuda devices with PyTorch and CuPy using:
https://gist.github.com/EdAbati/ff3bdc06bafeb92452b3740686cc8d7c
Merging main to hopefully get a green CI on this PR. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM. Thanks @EmilyXiny!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the PR @EmilyXinyi. A few comments otherwise looks good.
da921c4
to
12d5569
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM. Thanks @EmilyXinyi
Reference Issues/PRs
towards #26024
What does this implement/fix? Explain your changes.
array API support for
cosine_distances
Any other comments?
.clip
is supported in the 2023.12 version but not the one that we are currently using, so I created a function insklearn.utils._array_api
for now.fill_diagonal
is not supported and I don't see it being in the plan of being supported in the array-api repo, so I created an alternative._fill_diagonal
insklearn.metrics.pairwise
Please let me know if there are any changes I should make, thanks!