Adding XPU support to DTensor examples #153213
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/153213.
Note: Links to docs will display an error until the docs builds have been completed.
❌ 2 New Failures, 1 Cancelled Job as of commit 825747f with merge base efbf07e. NEW FAILURES: the following jobs have failed:
This comment was automatically generated by Dr. CI and updates every 15 minutes.
@pytorchbot label "topic: not user facing"
Please help fix the linter failure.
torch/distributed/tensor/examples/comm_mode_features_example.py
Pull Request Overview
This PR adds support for XPU in DTensor examples by updating how the device type is determined in the visualization and communication mode feature examples.
- Updated device mesh initialization in visualize_sharding_example.py using the accelerator’s current type.
- Removed the legacy get_device_type() function in comm_mode_features_example.py and replaced it with torch.accelerator.current_accelerator().type.
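The "use the current accelerator if present, else fall back to CPU" logic at the heart of this change can be sketched as follows. `pick_device_type` is a hypothetical helper (not part of the PR); the accelerator module is passed in as a parameter so the behavior can be exercised without a GPU or XPU installed.

```python
from types import SimpleNamespace

def pick_device_type(accel) -> str:
    """Hypothetical helper mirroring the PR's device selection: return the
    current accelerator's type when one is available, else "cpu".
    `accel` stands in for the torch.accelerator module."""
    if accel is not None and accel.is_available():
        return accel.current_accelerator().type
    return "cpu"

# Fake accelerator objects emulating an XPU machine and a CPU-only machine.
xpu = SimpleNamespace(
    is_available=lambda: True,
    current_accelerator=lambda: SimpleNamespace(type="xpu"),
)
cpu_only = SimpleNamespace(
    is_available=lambda: False,
    current_accelerator=lambda: None,
)

print(pick_device_type(xpu))       # -> xpu
print(pick_device_type(cpu_only))  # -> cpu
```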
Reviewed Changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated no comments.
| File | Description |
|---|---|
| `torch/distributed/tensor/examples/visualize_sharding_example.py` | Replaces hardcoded `"cuda"` with the current accelerator type for device mesh creation. |
| `torch/distributed/tensor/examples/comm_mode_features_example.py` | Removes the `get_device_type()` function and uses the accelerator's device type directly. |
Comments suppressed due to low confidence (2)
`torch/distributed/tensor/examples/visualize_sharding_example.py:20`
- The assignment to `device_type` assumes that `torch.accelerator.current_accelerator()` always returns a valid accelerator; consider adding documentation or a fallback mechanism to handle cases where no accelerator is available.

  ```python
  device_type = torch.accelerator.current_accelerator().type
  ```

`torch/distributed/tensor/examples/comm_mode_features_example.py:44`
- Directly assigning `self.device_type` based on `torch.accelerator.current_accelerator().type` relies on the presence of an accelerator; consider clarifying this assumption or providing fallback behavior for environments without an XPU.

  ```python
  self.device_type = torch.accelerator.current_accelerator().type
  ```
```diff
@@ -49,7 +41,7 @@ class CommDebugModeExample:
     def __init__(self, world_size: int, rank: int) -> None:
         self.world_size = world_size
         self.rank = rank
-        self.device_type = get_device_type()
+        self.device_type = torch.accelerator.current_accelerator().type if torch.accelerator.current_accelerator() and torch.accelerator.device_count() else 'cpu'
```
@githubsgi , for CUDA, the example requires the device count to be greater than or equal to 4: `torch.cuda.device_count() >= 4`. Should `torch.accelerator.device_count()` be greater than or equal to 4 as well?
There was an assert added above, which makes `torch.cuda.device_count() >= 4` redundant:

```python
assert int(os.getenv("WORLD_SIZE", "1")) >= 4, "We need at least 4 devices"
```
`WORLD_SIZE` may span multiple nodes, while `torch.cuda.device_count()` implies a single node with multiple devices. It may be okay for the example now. @kwen2501 , any comments?
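The distinction can be illustrated with a small sketch (hypothetical helper names, not from the PR): `WORLD_SIZE` is set by the launcher and counts ranks across all nodes, while a device count is per node, so the two checks are not interchangeable.

```python
import os

def enough_ranks(min_ranks: int = 4) -> bool:
    # WORLD_SIZE is set by the launcher (e.g. torchrun) and counts ranks
    # across the whole job, possibly spanning several nodes.
    return int(os.getenv("WORLD_SIZE", "1")) >= min_ranks

def enough_local_devices(local_device_count: int, min_devices: int = 4) -> bool:
    # Stand-in for torch.cuda.device_count(), which only sees local devices.
    return local_device_count >= min_devices

# Example: 2 nodes x 2 GPUs each -> WORLD_SIZE is 4, but each node
# only sees 2 devices locally, so the two checks disagree.
os.environ["WORLD_SIZE"] = "4"
print(enough_ranks())           # True: 4 ranks across the job
print(enough_local_devices(2))  # False: only 2 devices on this node
```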
Not sure why the lint checkers are complaining about the following. @EikanWang @colesbury , do you have any insight?
Any update on my question above?
```diff
@@ -17,6 +17,8 @@
 assert int(os.getenv("WORLD_SIZE", "1")) >= 4, "We need at least 4 devices"
 rank = int(os.environ["RANK"])
 
+device_type = 'cpu' if not torch.accelerator.current_accelerator() else torch.accelerator.current_accelerator().type
```
```diff
-device_type = 'cpu' if not torch.accelerator.current_accelerator() else torch.accelerator.current_accelerator().type
+device_type = 'cpu' if not torch.accelerator.is_available() else torch.accelerator.current_accelerator().type
```
```diff
@@ -49,7 +42,7 @@ class CommDebugModeExample:
     def __init__(self, world_size: int, rank: int) -> None:
         self.world_size = world_size
         self.rank = rank
-        self.device_type = get_device_type()
+        self.device_type = 'cpu' if not torch.accelerator.current_accelerator() else torch.accelerator.current_accelerator().type
```
```diff
-self.device_type = 'cpu' if not torch.accelerator.current_accelerator() else torch.accelerator.current_accelerator().type
+self.device_type = get_device_type()
```
I think this changes the semantics of `get_device_type`. The better way is:

```python
def get_device_type() -> str:
    return (
        torch.accelerator.current_accelerator().type
        if torch.accelerator.device_count() > 4
        else "cpu"
    )
```
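To make the semantic point concrete, here is a hedged side-by-side sketch (hypothetical function names, with accelerator state passed in as plain values, not from the PR): the original `get_device_type()` gated on the device count, whereas the PR's inline expression gates only on the presence of an accelerator, so the two can disagree on under-provisioned nodes.

```python
def device_type_count_gated(local_device_count: int, accel_type: str) -> str:
    # Count-gated behavior: use the accelerator only when the node has
    # enough devices (threshold per the discussion above), else "cpu".
    return accel_type if local_device_count >= 4 else "cpu"

def device_type_presence_gated(has_accelerator: bool, accel_type: str) -> str:
    # Presence-gated behavior (the PR's inline expression): use the
    # accelerator whenever one is present at all.
    return accel_type if has_accelerator else "cpu"

# With only 2 XPUs on the node, the two versions disagree:
print(device_type_count_gated(2, "xpu"))        # -> cpu
print(device_type_presence_gated(True, "xpu"))  # -> xpu
```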
Adds XPU support to `visualize_sharding_example.py` and `comm_mode_features_example.py`.
topic: not user facing
cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k
rebasing of #152973