8000 FIX skip array API tests when running with device="mps" without the PYTORCH_ENABLE_MPS_FALLBACK env var by betatim · Pull Request #27199 · scikit-learn/scikit-learn · GitHub
[go: up one dir, main page]

Skip to content

FIX skip array API tests when running with device="mps" without the PYTORCH_ENABLE_MPS_FALLBACK env var #27199

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Sep 7, 2023

Conversation

betatim
Copy link
Member
@betatim betatim commented Aug 29, 2023

PCA uses parts of PyTorch that are not supported on the MPS device, this changes the estimator to raise an explicit exception with information on what to do.

I had to change the common Array API specific tests to handle the exception. There are more tests in the general common tests that would need adjusting. Not quite sure what to do.

  • sklearn/utils/tests/test_estimator_checks.py::test_check_estimator_clones
  • sklearn/tests/test_common.py::test_estimators[PCA()-check_array_api_input(array_namespace=torch,dtype=float32,device=mps)]

How is this handled for other cases where an estimator doesn't support a particular setup that is tested in the common tests?

PCA uses parts of PyTorch that are not supported on the MPS device, this
changes the estimator to raise an explicit exception with information on
what to do.
@github-actions
Copy link
github-actions bot commented Aug 29, 2023

✔️ Linting Passed

All linting checks passed. Your pull request is in excellent shape! ☀️

Generated for commit: f80e120. Link to the linter CI: here

@betatim betatim changed the title Change PCA to raise when MPS device is used FIX Change PCA to raise when MPS device is used Aug 30, 2023
@ogrisel
Copy link
Member
ogrisel commented Aug 31, 2023

I would rather not have pytorch specific things in the scikit-learn code base.

I think the lower level error message returned by PyTorch is explicit enough, no?

@ogrisel
Copy link
Member
ogrisel commented Aug 31, 2023

As discussed at the bi-weekly Array API meeting, it would be better to find a way to skip the device="mps" tests when os.getenv("PYTORCH_ENABLE_MPS_FALLBACK") != "1".

Not 100% sure how to achieve that without adding more boilerplate though.

EDIT: it should be easy to add a condition to the _array_api_for_tests helper function.

8000
@betatim
Copy link
Member Author
betatim commented Sep 5, 2023

Good point. I've updated the code

Copy link
Member
@ogrisel ogrisel left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks for the update.

@ogrisel ogrisel added Quick Review For PRs that are quick to review Waiting for Second Reviewer First reviewer is done, need a second one! labels Sep 5, 2023
@ogrisel ogrisel changed the title FIX Change PCA to raise when MPS device is used FIX skip array API tests when running with device="mps" without the PYTORCH_ENABLE_MPS_FALLBACK env var Sep 5, 2023
@ogrisel
Copy link
Member
ogrisel commented Sep 7, 2023

/cc @thomasjpfan for a quick review.

@ogrisel
Copy link
Member
ogrisel commented Sep 7, 2023

Ooops I was wrong with the conflict resolution via the github UI. Let me push a fix quickly.

@ogrisel
Copy link
Member
ogrisel commented Sep 7, 2023

It should be good now.

@thomasjpfan thomasjpfan merged commit b7d80a0 into scikit-learn:main Sep 7, 2023
@betatim betatim deleted the array-api-mps-fixup branch September 8, 2023 07:42
REDVM pushed a commit to REDVM/scikit-learn that referenced this pull request Nov 16, 2023
…YTORCH_ENABLE_MPS_FALLBACK env var (scikit-learn#27199)

Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Array API module:decomposition Quick Review For PRs that are quick to review Waiting for Second Reviewer First reviewer is done, need a second one!
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants
0