8000 FIX Fix device detection when array API dispatch is disabled by lesteve · Pull Request #30454 · scikit-learn/scikit-learn · GitHub
[go: up one dir, main page]

Skip to content

FIX Fix device detection when array API dispatch is disabled #30454

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Conversation

lesteve
Copy link
Member
@lesteve lesteve commented Dec 10, 2024

Fix #29107, in particular #29107 (comment) which is a regression in 1.6.

Main change: with array_api_dispatch=False, the device is always None (I tried "cpu" originally but this doesn't 8000 work with array-api-strict). In this case np.asarray will be called and the resulting array will always be CPU. We don't want to check devices too early and prevent the np.asarray conversion to happen (there may be an error like np.asarray with a PyTorch array on a CUDA device).

Tests added:

  • calling metric with array API inputs and array API disabled, this is pretty much the new regression in 1.6 Incorrect invalid device error introduced in #25956 #29107 (comment)
  • smoke-test in estimator_checks that makes sure that when array API is disabled you can fit and call other methods on numpy-convertible arrays (for example PyTorch array on CPU)

Copy link
github-actions bot commented Dec 10, 2024

✔️ Linting Passed

All linting checks passed. Your pull request is in excellent shape! ☀️

Generated for commit: ce983e8. Link to the linter CI: here

@lesteve
Copy link
Member Author
lesteve commented Dec 10, 2024

Oh well some failures with array-api-strict that need to be investigated ...

@virchan
Copy link
Member
virchan commented Dec 11, 2024

About the error message for the test_device_inspection function, I think the following part:

        # Test expected value is returned otherwise
        array1 = Array("device")
        array2 = Array("device")

        assert array1.device == device(array1)
        assert array1.device == device(array1, array2)
        assert array1.device == device(array1, array1, array2)

should be inside the with config_context(array_api_dispatch=True): block. This is because device(array1) will always return "cpu" when array_api_dispatch=False, which is not the intended behaviour for this test.

@virchan
Copy link
Member
virchan commented Dec 11, 2024

Regarding the error message for the test_fill_or_add_to_diagonal function, I think the following code might address the issue:

    with config_context(array_api_dispatch=True):
        _fill_or_add_to_diagonal(array_xp, value=1, xp=xp, add_value=False, wrap=wrap)
    assert_array_equal(_convert_to_numpy(array_xp, xp=xp), array_np)

It suppresses the error message, but I'm not entirely sure if this is the correct fix...

@lesteve
Copy link
Member Author
lesteve commented Dec 11, 2024

CUDA CI passed on 0073b09

Copy link
Member
@ogrisel ogrisel left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need a changelog entry. A few cosmetic suggestions but otherwise the fix itself LGTM.

8000
# we end up doing X[train_indices] where X is a array-api-strict array
# and indices a numpy array. Probably not worth investigating for now,
# since using array-api-strict with array API disabled does not seem a
# very relevant.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

While I agree that we don't really care about supporting array-api-strict inputs when array API dispatch is disabled, I still think that those X[train_indices] are bugs in our code. We should probably always do X[xp.asarray(train_indices, device=device(X)] in our CV tools.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tried to tweak the comment a bit to be more self-explanatory. I am not sure it is so easy to tackle and discussing this in real-life this is probably not worth doing anything about it for now. All the array API library I could try (PyTorch, jax, CuPy) accept to index with numpy array.

For compleness, the proposed work-around would not work because the issue is that X is an array-api-strict array, train_indices is a numpy array and array-api-strict is insisting that to index X you need to pass array-api-strict array.

lesteve and others added 3 commits December 11, 2024 15:05
@jeremiedbb jeremiedbb added this to the 1.6.1 milestone Dec 11, 2024
lesteve and others added 2 commits December 11, 2024 17:31
Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
Copy link
Contributor
@OmarManzoor OmarManzoor left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks good. Thanks @lesteve

Just some minor suggestions.

Co-authored-by: Omar Salman <omar.salman2007@gmail.com>
@OmarManzoor OmarManzoor merged commit 9d59e8e into scikit-learn:main Dec 18, 2024
29 of 30 checks passed
jeremiedbb pushed a commit to jeremiedbb/scikit-learn that referenced this pull request Jan 8, 2025
…learn#30454)

Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
Co-authored-by: Omar Salman <omar.salman2007@gmail.com>
jeremiedbb pushed a commit that referenced this pull request Jan 9, 2025
Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
Co-authored-by: Omar Salman <omar.salman2007@gmail.com>
@lesteve lesteve deleted the fix-device-detection-array-api-dispatch-disabled branch January 15, 2025 15:05
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Incorrect invalid device error introduced in #25956
5 participants
0