-
-
Notifications
You must be signed in to change notification settings - Fork 25.9k
FIX skip array API tests when running with device="mps" without the PYTORCH_ENABLE_MPS_FALLBACK env var #27199
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
PCA uses parts of PyTorch that are not supported on the MPS device, this changes the estimator to raise an explicit exception with information on what to do.
I would rather not have pytorch specific things in the scikit-learn code base. I think the lower level error message returned by PyTorch is explicit enough, no? |
As discussed at the bi-weekly Array API meeting, it would be better to find a way to skip the Not 100% sure how to achieve that without adding more boilerplate though. EDIT: it should be easy to add a condition to the |
Good point. I've updated the code |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thanks for the update.
/cc @thomasjpfan for a quick review. |
Ooops I was wrong with the conflict resolution via the github UI. Let me push a fix quickly. |
It should be good now. |
…YTORCH_ENABLE_MPS_FALLBACK env var (scikit-learn#27199) Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
PCA uses parts of PyTorch that are not supported on the MPS device, this changes the estimator to raise an explicit exception with information on what to do.
I had to change the common Array API specific tests to handle the exception. There are more tests in the general common tests that would need adjusting. Not quite sure what to do.
sklearn/utils/tests/test_estimator_checks.py::test_check_estimator_clones
sklearn/tests/test_common.py::test_estimators[PCA()-check_array_api_input(array_namespace=torch,dtype=float32,device=mps)]
How is this handled for other cases where an estimator doesn't support a particular setup that is tested in the common tests?