Adding XPU support to DTensor examples #153213
base: main
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/153213
Note: Links to docs will display an error until the docs builds have been completed.
❗ 1 Active SEV: there is 1 currently active SEV. If your PR is affected, please view it below.
❌ 3 New Failures, 14 Unrelated Failures as of commit 1add9bb with merge base 7cb5c75:
- NEW FAILURES: the following jobs have failed.
- FLAKY: the following jobs failed, but likely due to flakiness present on trunk.
- BROKEN TRUNK: the following jobs failed but were already failing on the merge base. 👉 Rebase onto the `viable/strict` branch to avoid these failures.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
@pytorchbot label "topic: not user facing"
Please help fix the linter failure.
torch/distributed/tensor/examples/comm_mode_features_example.py
Pull Request Overview
This PR adds support for XPU in DTensor examples by updating how the device type is determined in the visualization and communication mode feature examples.
- Updated device mesh initialization in visualize_sharding_example.py using the accelerator’s current type.
- Removed the legacy get_device_type() function in comm_mode_features_example.py and replaced it with torch.accelerator.current_accelerator().type.
Reviewed Changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated no comments.
| File | Description |
|---|---|
| torch/distributed/tensor/examples/visualize_sharding_example.py | Replaces hardcoded "cuda" with the current accelerator type for device mesh creation. |
| torch/distributed/tensor/examples/comm_mode_features_example.py | Removes the `get_device_type()` function and uses the accelerator's device type directly. |
Comments suppressed due to low confidence (2)
torch/distributed/tensor/examples/visualize_sharding_example.py:20
- The assignment to `device_type` assumes that `torch.accelerator.current_accelerator()` always returns a valid accelerator; consider adding documentation or a fallback mechanism to handle cases where no accelerator is available.
  `device_type = torch.accelerator.current_accelerator().type`

torch/distributed/tensor/examples/comm_mode_features_example.py:44
- Directly assigning `self.device_type` based on `torch.accelerator.current_accelerator().type` relies on the presence of an accelerator; consider clarifying this assumption or providing fallback behavior for environments without an XPU.
  `self.device_type = torch.accelerator.current_accelerator().type`
```diff
@@ -49,7 +41,7 @@ class CommDebugModeExample:
     def __init__(self, world_size: int, rank: int) -> None:
         self.world_size = world_size
         self.rank = rank
-        self.device_type = get_device_type()
+        self.device_type = torch.accelerator.current_accelerator().type if torch.accelerator.current_accelerator() and torch.accelerator.device_count() else 'cpu'
```
@githubsgi, for CUDA this example requires the device count to be at least 4 (`torch.cuda.device_count() >= 4`). Should `torch.accelerator.device_count()` be required to be greater than or equal to 4 as well?
There was an assert added above, which makes `torch.cuda.device_count() >= 4` redundant:
`assert int(os.getenv("WORLD_SIZE", "1")) >= 4, "We need at least 4 devices"`
`WORLD_SIZE` may span multiple nodes, while `torch.cuda.device_count()` implies a single node with multiple devices. It may be okay for the example for now. @kwen2501, any comments?
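The distinction raised here can be sketched with a small torch-free example (the helper names are hypothetical; `WORLD_SIZE` is the standard `torch.distributed` environment variable, and the per-node device count is passed in rather than read from `torch.cuda`):

```python
import os


def world_size_from_env(default: int = 1) -> int:
    """Total number of ranks across ALL nodes, as set by the launcher."""
    return int(os.getenv("WORLD_SIZE", str(default)))


def enough_devices_for_example(local_device_count: int, required: int = 4) -> bool:
    """The example's assert only checks the global WORLD_SIZE. On a
    multi-node launch that check can pass even when each individual node
    has fewer than `required` local devices, which is the concern above.
    This sketch checks both conditions."""
    return world_size_from_env() >= required and local_device_count >= required


# Two nodes with 2 devices each: WORLD_SIZE=4 satisfies the assert,
# but each node only sees 2 local devices.
os.environ["WORLD_SIZE"] = "4"
print(world_size_from_env())                              # 4
print(enough_devices_for_example(local_device_count=2))   # False
print(enough_devices_for_example(local_device_count=4))   # True
```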
Not sure why the lint checkers are complaining about the following. @EikanWang @colesbury, do you have any insight?
Adds XPU support to `visualize_sharding_example.py` and `comm_mode_features_example.py`.
topic: not user facing
cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k
Rebase of #152973.