8000 Fixing bug in `get_namespace_and_device`. by drivanov · Pull Request #30647 · scikit-learn/scikit-learn · GitHub
[go: up one dir, main page]

Skip to content

Fixing bug in get_namespace_and_device. #30647

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from

Conversation

drivanov
Copy link

Reference Issues/PRs

What does this implement/fix? Explain your changes.

When we upgraded to scikit-learn version 1.6.0 we encountered a bug:

Traceback (most recent call last):
  File "/workspace/examples/graph_sage_unsup.py", line 80, in <module>
    val_acc, test_acc = test()
                        ^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/examples/graph_sage_unsup.py", line 70, in test
    val_acc = clf.score(out[data.val_mask], data.y[data.val_mask])
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/sklearn/base.py", line 572, in score
    return accuracy_score(y, self.predict(X), sample_weight=sample_weight)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/sklearn/utils/_param_validation.py", line 216, in wrapper
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/sklearn/metrics/_classification.py", line 224, in accuracy_score
    xp, _, device = get_namespace_and_device(y_true, y_pred, sample_weight)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/sklearn/utils/_array_api.py", line 614, in get_namespace_and_device
    arrays_device = device(*array_list, **skip_remove_kwargs)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/sklearn/utils/_array_api.py", line 178, in device
    raise ValueError(
Valu
8000
eError: Input arrays use different devices: cpu, cpu

It appears that device(*array_list, **skip_remove_kwargs) should only be called when is_array_api is True, as it was correctly handled in version 1.5.2:

    xp, is_array_api = get_namespace(*array_list, **skip_remove_kwargs)
    if is_array_api:
        return (
            xp,
            is_array_api,
            device(*array_list, **skip_remove_kwargs),
        )
    else:
        return xp, False, None

The proposed changes corrected this error.

Any other comments?

Copy link

✔️ Linting Passed

All linting checks passed. Your pull request is in excellent shape! ☀️

Generated for commit: 0ce7003. Link to the linter CI: here

@lesteve
Copy link
Member
lesteve commented Jan 15, 2025

Thanks for the PR, can you make sure it has not been fixed in 1.6.1? It seems somewhat similar to #30454.

@drivanov
Copy link
Author

You were right - this was fixed in version 1.6.1. Closing.

@drivanov drivanov closed this Jan 15, 2025
@lesteve
Copy link
Member
lesteve commented Jan 16, 2025

Great to hear that it was the same issue, thanks for your feed-back!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants
0