8000 Add torch.accelerator.device_index as accelerator's device switch context by guangyey · Pull Request #148864 · pytorch/pytorch · GitHub
[go: up one dir, main page]

Skip to content

Add torch.accelerator.device_index as accelerator's device switch context #148864

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 22 commits into from
Prev Previous commit
Next Next commit
Update
[ghstack-poisoned]
  • Loading branch information
guangyey committed Apr 23, 2025
commit fe60ddcfc2c738d90ae32947ac99180c6f7c0544
4 changes: 3 additions & 1 deletion torch/accelerator/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,10 +197,13 @@ class device_index:
Temporarily changes the current device index to the specified value for the duration
of the context, and automatically restores the previous device index when exiting
the context.

Args:
device (Optional[int]): a given device index to temporarily set. If None,
no device index switching occurs.

Examples:

>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
>>> # Set device 0 as the current device temporarily
>>> with torch.accelerator.device_index(0):
Expand All @@ -212,7 +215,6 @@ class device_index:
... # No device switching occurs
... pass
"""

def __init__(self, device: Optional[int], /) -> None:
self.idx = device
self.prev_idx = -1
Expand Down
Loading
0