-
Notifications
You must be signed in to change notification settings - Fork 24.7k
[device_mesh] add back the private init backend option #124780
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
This PR adds a private init backend option, to tackle the issues sub mesh creation: in device mesh slicing we don't want to create process groups again, so explicitly turn the group creation off it's useful Also I think there might be more submesh creation functionality so having this flag would ensure that there's no new group created [ghstack-poisoned]
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/124780
Note: Links to docs will display an error until the docs builds have been completed. ✅ No FailuresAs of commit 08f0bae with merge base afa78ad ( This comment was automatically generated by Dr. CI and updates every 15 minutes. |
Oof, this is helpful. I was just writing the same code locally 😅 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
SGTM!
torch/distributed/device_mesh.py
Outdated
# TODO(yeounoh) implement DeviceMesh backend and register XLA backend. | ||
if device_type != "xla": | ||
if device_type != "xla" and _init_backend: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If we skip _get_or_create_default_group()
, then self._coordinate_on_dim
will not be defined when _init_backend=False
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nevermind! This is intentional from the unit tests.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
oh I think we should fix this it seems, I'll do it as a follow up PR
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think I did similar thing for nD submesh slicing and not initing the sub pgs again for the mesh slice won't work as there is dim_group_infos for the mesh slice. iirc, that's why I need to revert.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, upon more investigation, I think we need to populate both self._coordinate_on_dim
and self._dim_group_infos
, so this current change should break unit tests IIUC.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It won't break any test yet as _init_backend is default to True, but a followup is needed for most of the API to work with _init_backend = False.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we would always want the default _init_backend
to be True, only when we do submesh slicing we would manually turn it off
@@ -89,6 +89,7 @@ def create_child_mesh( | |||
device_mesh.device_type, | |||
mesh_1d, | |||
mesh_dim_names=(mesh_dim_name,), | |||
_init_backend=False, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@wz337 Even though _init_backend
is defaulted to True
, here we are passing it as False
for child meshes. For FSDP2 2D training, I was thinking that this means the child meshes are now missing the self._coordinate_on_dim
and self._dim_group_infos
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
there seems to have some test failures, so trying to fix it directly in this PR
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
tested locally should fixed those tests
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How about _self._dim_group_infos
? I think we need that to support .get_group()
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
those are already populated when we create the submesh?https://github.com/pytorch/pytorch/blob/main/torch/distributed/device_mesh.py#L96
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah, got it! Makes sense!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
basically when use init_backend=False
I think there would be some manual pg creation or pg reuse, so we would need to explicitly attach the dim_group_info after DeviceMesh
constructor?
i.e. for things like from_group
, inside this API after DeviceMesh
constructed, we need to attach the dim_group_info to the DeviceMesh created, the dim_group_info
need to be generated in a similar way as https://github.com/pytorch/pytorch/blob/main/torch/distributed/device_mesh.py#L314-L319
This PR adds a private init backend option, to tackle the issues sub mesh creation: in device mesh slicing we don't want to create process groups again, so explicitly turn the group creation off it's useful Also I think there might be more submesh creation functionality so having this flag would ensure that there's no new group created cc mrshenli pritamdamania87 zhaojuanmao satgera rohan-varma gqchen aazzolini osalpekar jiayisuse H-Huang kwen2501 awgu penguinwu fegin XilunWu fduwjj wz337 tianyu-l wconstab yf225 chauhang d4l3k [ghstack-poisoned]
This PR adds a private init backend option, to tackle the issues sub mesh creation: in device mesh slicing we don't want to create process groups again, so explicitly turn the group creation off it's useful Also I think there might be more submesh creation functionality so having this flag would ensure that there's no new group created cc mrshenli pritamdamania87 zhaojuanmao satgera rohan-varma gqchen aazzolini osalpekar jiayisuse H-Huang kwen2501 awgu penguinwu fegin XilunWu fduwjj wz337 tianyu-l wconstab yf225 chauhang d4l3k [ghstack-poisoned]
@wz337 has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator. |
@pytorchbot merge |
Merge startedYour change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
This PR adds a private init backend option, to tackle the issues sub mesh creation: in device mesh slicing we don't want to create process groups again, so explicitly turn the group creation off it's useful Also I think there might be more submesh creation functionality so having this flag would ensure that there's no new group created Differential Revision: [D56497780](https://our.internmc.facebook.com/intern/diff/D56497780) Pull Request resolved: pytorch#124780 Approved by: https://github.com/awgu ghstack-source-id: d8a4e07
Summary: The mapping is no longer needed after pytorch#124780, as we are not going to re-create the pgs during mesh slicing. Test Plan: CI Differential Revision: D56499001
This PR adds a `DeviceMesh.from_group()` static method to convert an existing process group to a device mesh. Motivation: We need `DeviceMesh.from_group()` to allow FSDP2 to interoperate with distributed libraries that do not use `DeviceMesh` for all parallelisms. Pull Request resolved: #124787 Approved by: https://github.com/wanchaol ghstack dependencies: #124651, #124741, #124767, #124768, #124780
…#124890) Summary: The mapping is no longer needed after #124780, as we are not going to re-create the pgs during mesh slicing. Test Plan: CI Differential Revision: D56499001 Pull Request resolved: #124890 Approved by: https://github.com/awgu
This PR adds a `DeviceMesh.from_group()` static method to convert an existing process group to a device mesh. Motivation: We need `DeviceMesh.from_group()` to allow FSDP2 to interoperate with distributed libraries that do not use `DeviceMesh` for all parallelisms. Pull Request resolved: pytorch#124787 Approved by: https://github.com/wanchaol ghstack dependencies: pytorch#124651, pytorch#124741, pytorch#124767, pytorch#124768, pytorch#124780
This PR renames the `FSDP` class to `FSDPModule`. This is a BC breaking change. The rationale is that `FSDPModule` is more descriptive since `fully_shard` is a module-level API (applied to a `module` arg), so the `FSDP` class will always correspond to a module. Also, users commonly import `FullyShardedDataParallel` as `FSDP`, so this can help avoid some name conflict in some cases. Pull Request resolved: #124955 Approved by: https://github.com/wanchaol, https://github.com/wconstab ghstack dependencies: #124651, #124741, #124767, #124768, #124780, #124787
This PR adds a private init backend option, to tackle the issues sub mesh creation: in device mesh slicing we don't want to create process groups again, so explicitly turn the group creation off it's useful Also I think there might be more submesh creation functionality so having this flag would ensure that there's no new group created Differential Revision: [D56497780](https://our.internmc.facebook.com/intern/diff/D56497780) Pull Request resolved: pytorch#124780 Approved by: https://github.com/awgu
This PR adds a `DeviceMesh.from_group()` static method to convert an existing process group to a device mesh. Motivation: We need `DeviceMesh.from_group()` to allow FSDP2 to interoperate with distributed libraries that do not use `DeviceMesh` for all parallelisms. Pull Request resolved: pytorch#124787 Approved by: https://github.com/wanchaol ghstack dependencies: pytorch#124651, pytorch#124741, pytorch#124767, pytorch#124768, pytorch#124780
…#124890) Summary: The mapping is no longer needed after #124780, as we are not going to re-create the pgs during mesh slicing. Test Plan: CI Differential Revision: D56499001 Pull Request resolved: #124890 Approved by: https://github.com/awgu
This PR renames the `FSDP` class to `FSDPModule`. This is a BC breaking change. The rationale is that `FSDPModule` is more descriptive since `fully_shard` is a module-level API (applied to a `module` arg), so the `FSDP` class will always correspond to a module. Also, users commonly import `FullyShardedDataParallel` as `FSDP`, so this can help avoid some name conflict in some cases. Pull Request resolved: pytorch#124955 Approved by: https://github.com/wanchaol, https://github.com/wconstab ghstack dependencies: pytorch#124651, pytorch#124741, pytorch#124767, pytorch#124768, pytorch#124780, pytorch#124787
@mvpatel2000 Yes thanks for reminding! I'll submit a cherry pick soon |
This PR adds a private init backend option, to tackle the issues sub mesh creation: in device mesh slicing we don't want to create process groups again, so explicitly turn the group creation off it's useful Also I think there might be more submesh creation functionality so having this flag would ensure that there's no new group created Differential Revision: [D56497780](https://our.internmc.facebook.com/intern/diff/D56497780) Pull Request resolved: #124780 Approved by: https://github.com/awgu
…124780) (#126147) [device_mesh] add a private init backend option (#124780) This PR adds a private init backend option, to tackle the issues sub mesh creation: in device mesh slicing we don't want to create process groups again, so explicitly turn the group creation off it's useful Also I think there might be more submesh creation functionality so having this flag would ensure that there's no new group created Differential Revision: [D56497780](https://our.internmc.facebook.com/intern/diff/D56497780) Pull Request resolved: #124780 Approved by: https://github.com/awgu
Stack from ghstack (oldest at bottom):
This PR adds back the private init backend option (we had
_init_process_groups
before), to tackle the issues submesh creation. This is a regression fix to 2.3 as we removed the
_init_process_groups
option in 2.3, which triggers a lot more sub process group creations, potentially causing memory spikesin device mesh slicing we don't want to create process groups again,
so explicitly turn the group creation off it's useful
Also I think there might be more submesh creation functionality so
having this flag would ensure that there's no new group created
cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu @penguinwu @fegin @XilunWu @fduwjj @wz337 @tianyu-l @wconstab @yf225 @chauhang @d4l3k
Differential Revision: D56497780