[DeviceMesh] Add some documentation for from_group API and add a 2D test #146364
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/146364
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit 2842508 with merge base 810d2a3.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
@@ -823,11 +846,17 @@ def from_group(
            (_get_group_tag(group), group_ranks, group.group_name)
        ]
        return device_mesh

    # nD scenario
    groups = list(group)
    if len(groups) == 0:
        raise ValueError("Expects at least one ProcessGroup to be passed")
    if mesh is None:
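The nD branch above can be approximated in plain Python. This is a hypothetical sketch, not the actual `torch.distributed.device_mesh` code: the diff cuts off at `if mesh is None:`, so the sketch assumes that branch raises, consistent with the reviewers' note below that `mesh` is kept required for now.

```python
def validate_nd_groups(group, mesh):
    """Sketch of the validation in the nD branch of DeviceMesh.from_group.

    `group` is any iterable of process groups; `mesh` is the tensor/array
    describing the device layout. Names and the second error message are
    illustrative assumptions, not the real implementation.
    """
    groups = list(group)
    if len(groups) == 0:
        raise ValueError("Expects at least one ProcessGroup to be passed")
    if mesh is None:
        # Assumed behavior: the diff ends here, but `mesh` is required
        # in the nD scenario per the review discussion.
        raise ValueError("Expects `mesh` when passing multiple ProcessGroups")
    return groups
```

Under these assumptions, passing an empty list or omitting `mesh` fails fast, while a well-formed call returns the materialized list of groups.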
Not a blocker, but I have a question about this condition: do we really need it? Can't we infer this, i.e. just take the sizes of the passed-in ProcessGroups as the mesh? Also, if mesh is passed in, we don't need mesh_dim_names, right? They are OR, not AND?
Discussed offline: we are keeping the mesh tensor required for now, but this is subject to change.
Stamp to unblock; the unit test looks good to me. But I have some questions regarding the args needed.
Force-pushed from 79e8922 to c13af6e
Left some comments, may need to rethink the API signature.
        of the default process group. Default is None.
    mesh_dim_names (tuple[str], optional): A tuple of mesh dimension names to assign
        to each dimension of the multi-dimensional array describing the layout of devices.
        Its length must match the length of `mesh_shape`. Each string in `mesh_dim_names`
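The docstring contract above (the number of names must equal the number of mesh dimensions) can be sketched as a standalone check. This is an illustrative helper, not the actual `device_mesh.py` code; the function name and error wording are assumptions.

```python
def check_dim_names(mesh_shape, mesh_dim_names):
    """Validate that mesh_dim_names, if given, has one name per mesh dim.

    Hypothetical sketch of the documented contract; not PyTorch internals.
    """
    if mesh_dim_names is not None and len(mesh_dim_names) != len(mesh_shape):
        raise RuntimeError(
            f"mesh_dim_names has length {len(mesh_dim_names)} but "
            f"the mesh has {len(mesh_shape)} dimensions"
        )

# One name per dimension of a 2D mesh passes the check:
check_dim_names((2, 4), ("dp", "tp"))
```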
`mesh.shape`?
`mesh` and `mesh_dim_names` are asymmetrical: one can be positional while the other must be keyword, yet they are used together.
Yeah... it is inconsistent. This is because when we originally introduced `init_device_mesh`, `mesh_dim_names` was not required, and it is also an optional keyword arg...
https://github.com/pytorch/pytorch/blob/main/torch/distributed/device_mesh.py#L940-L945
Also, the `mesh` tensor should be `mesh_shape` in order to be consistent with `init_device_mesh`, so it may be better to make the breaking change now rather than later.
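The asymmetry the reviewers describe can be illustrated with a toy signature: `mesh` accepted positionally, `mesh_dim_names` keyword-only. This is a hypothetical stand-in for illustration, not the real `DeviceMesh.from_group` signature.

```python
def from_group_sketch(group, device_type, mesh=None, *, mesh_dim_names=None):
    """Toy signature mirroring the positional/keyword asymmetry discussed.

    Returns the (mesh, mesh_dim_names) pair so the two call styles can be
    compared; everything here is illustrative, not PyTorch's API.
    """
    return (mesh, mesh_dim_names)

# `mesh` may be positional or keyword, but `mesh_dim_names` must be keyword:
a = from_group_sketch(["pg"], "cpu", [[0, 1]], mesh_dim_names=("dp",))
b = from_group_sketch(["pg"], "cpu", mesh=[[0, 1]], mesh_dim_names=("dp",))
```

Both calls produce the same result, which is the reviewers' point: the two parameters travel together, so making one positional and the other keyword-only is inconsistent.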
@pytorchmergebot rebase
@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here.
Successfully rebased.
Force-pushed from c13af6e to ede80d2
@pytorchmergebot rebase
@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here.
Successfully rebased.
Force-pushed from ede80d2 to 4e62c31
@pytorchmergebot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Fixes #ISSUE_NUMBER
cc @H-Huang @awgu @kwen2501 @wanchaol @fegin @fduwjj @wconstab @d4l3k @c-p-i-o @tianyu-l @XilunWu