
[device_mesh] add back the private init backend option #124780


Closed · wants to merge 3 commits

Conversation

@wanchaol (Collaborator) commented Apr 23, 2024

Stack from ghstack (oldest at bottom):

This PR adds back the private init backend option (we had `_init_process_groups` before) to tackle issues with submesh creation. This is a regression fix for 2.3: we removed the `_init_process_groups` option in 2.3, which triggers many more sub process group creations and can cause memory spikes.

In device mesh slicing we don't want to create process groups again, so being able to explicitly turn group creation off is useful.

Also, there may be more submesh creation functionality coming, so having this flag ensures that no new groups are created.

cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu @penguinwu @fegin @XilunWu @fduwjj @wz337 @tianyu-l @wconstab @yf225 @chauhang @d4l3k

Differential Revision: D56497780
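For illustration, a minimal sketch of the intended usage; the calls below are assumptions based on this PR's diff, and running it requires an already-initialized 4-rank process group:

```python
import torch
from torch.distributed.device_mesh import DeviceMesh

# Parent 2-D mesh: initializes the default group and per-dim sub process
# groups as usual (_init_backend defaults to True).
mesh_2d = DeviceMesh(
    "cuda", torch.arange(4).reshape(2, 2), mesh_dim_names=("dp", "tp")
)

# 1-D child mesh sliced from a column of the parent. _init_backend=False
# skips backend init, so the slice reuses the parent's existing groups
# instead of creating new sub process groups.
mesh_dp = DeviceMesh(
    "cuda", mesh_2d.mesh[:, 0], mesh_dim_names=("dp",), _init_backend=False
)
```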

pytorch-bot bot commented Apr 23, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/124780

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 08f0bae with merge base afa78ad:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added ci-td-distributed oncall: distributed Add this issue/PR to distributed oncall triage queue labels Apr 23, 2024
@awgu (Collaborator) commented Apr 23, 2024

Oof, this is helpful. I was just writing the same code locally 😅

@awgu awgu left a comment:

SGTM!

@wanchaol wanchaol added the ciflow/trunk Trigger trunk jobs on your pull request label Apr 23, 2024
```diff
         # TODO(yeounoh) implement DeviceMesh backend and register XLA backend.
-        if device_type != "xla":
+        if device_type != "xla" and _init_backend:
```
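For context, a simplified sketch of the guarded `__init__` path after this change; the method names are taken from this file and the discussion below, but the body here is heavily abridged:

```python
def __init__(self, device_type, mesh, *, mesh_dim_names=None, _init_backend=True):
    ...
    # TODO(yeounoh) implement DeviceMesh backend and register XLA backend.
    if device_type != "xla" and _init_backend:
        # Both steps are skipped when _init_backend=False, e.g. for child
        # meshes created during slicing.
        self._get_or_create_default_group()
        self._init_process_groups()
```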
Collaborator:

If we skip _get_or_create_default_group(), then self._coordinate_on_dim will not be defined when _init_backend=False.

Collaborator:

Nevermind! This is intentional from the unit tests.

Collaborator Author:

Oh, I think we should fix this; I'll do it as a follow-up PR.

Contributor:

I think I did a similar thing for nD submesh slicing, and not initializing the sub pgs again didn't work for the mesh slice, as there are no dim_group_infos for the slice. IIRC, that's why I needed to revert.

Collaborator:

Yes, upon more investigation, I think we need to populate both self._coordinate_on_dim and self._dim_group_infos, so this current change should break unit tests IIUC.

Contributor:

It won't break any test yet, as _init_backend defaults to True, but a follow-up is needed for most of the API to work with _init_backend=False.

Collaborator Author:

I think we would always want _init_backend to default to True; only when we do submesh slicing would we manually turn it off.

```diff
@@ -89,6 +89,7 @@ def create_child_mesh(
             device_mesh.device_type,
             mesh_1d,
             mesh_dim_names=(mesh_dim_name,),
+            _init_backend=False,
```
Collaborator:

@wz337 Even though _init_backend defaults to True, here we are passing it as False for child meshes. For FSDP2 2D training, I was thinking this means the child meshes are now missing self._coordinate_on_dim and self._dim_group_infos.

Collaborator Author:

There seem to be some test failures, so I'm trying to fix them directly in this PR.

Collaborator Author:

Tested locally; this should fix those tests.

Collaborator:

How about self._dim_group_infos? I think we need that to support .get_group().

Collaborator Author:

Those are already populated when we create the submesh? https://github.com/pytorch/pytorch/blob/main/torch/distributed/device_mesh.py#L96

Collaborator:

Ah, got it! Makes sense!

Collaborator Author:

Basically, when using _init_backend=False, I think there would be some manual pg creation or pg reuse, so we would need to explicitly attach the dim_group_info after the DeviceMesh constructor.

I.e., for things like from_group: inside this API, after the DeviceMesh is constructed, we need to attach the dim_group_info to the created DeviceMesh. The dim_group_info needs to be generated in a similar way to https://github.com/pytorch/pytorch/blob/main/torch/distributed/device_mesh.py#L314-L319
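A hedged sketch of the pattern described here: construct the mesh with backend init off, then attach the group info by hand. The tuple layout of `_dim_group_infos` is an assumption based on the linked lines of device_mesh.py, and `mesh_from_existing_group` is a hypothetical helper:

```python
import torch.distributed as dist
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.distributed_c10d import _get_group_tag

def mesh_from_existing_group(group, device_type: str) -> DeviceMesh:
    ranks = dist.get_process_group_ranks(group)
    # Reuse the existing pg instead of letting the constructor create one.
    mesh = DeviceMesh(device_type, ranks, _init_backend=False)
    # Attach what _init_process_groups() would have recorded for this dim.
    mesh._dim_group_infos = [(_get_group_tag(group), ranks)]
    return mesh
```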

@wz337 wz337 added the ciflow/periodic Trigger jobs ran periodically on master (periodic.yml) on the PR label Apr 23, 2024
@wz337 (Contributor) commented Apr 23, 2024

@wz337 has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@wanchaol wanchaol added the release notes: distributed (dtensor) release notes category label Apr 24, 2024
@wanchaol (Collaborator Author):

@pytorchbot merge

@pytorchmergebot (Collaborator):

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team


awgu pushed a commit to awgu/pytorch that referenced this pull request Apr 24, 2024
[device_mesh] add a private init backend option (#124780)

Differential Revision: [D56497780](https://our.internmc.facebook.com/intern/diff/D56497780)
Pull Request resolved: pytorch#124780
Approved by: https://github.com/awgu
ghstack-source-id: d8a4e07
wz337 added a commit to wz337/pytorch that referenced this pull request Apr 24, 2024
Summary: The mapping is no longer needed after pytorch#124780, as we are not going to re-create the pgs during mesh slicing.

Test Plan: CI

Differential Revision: D56499001
pytorchmergebot pushed a commit that referenced this pull request Apr 24, 2024
This PR adds a `DeviceMesh.from_group()` static method to convert an existing process group to a device mesh.

Motivation: We need `DeviceMesh.from_group()` to allow FSDP2 to interoperate with distributed libraries that do not use `DeviceMesh` for all parallelisms.

Pull Request resolved: #124787
Approved by: https://github.com/wanchaol
ghstack dependencies: #124651, #124741, #124767, #124768, #124780
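A rough usage sketch of that API (the argument order follows the PR description; the setup boilerplate here is illustrative, not from the PR):

```python
import torch.distributed as dist
from torch.distributed.device_mesh import DeviceMesh

# Suppose some external library already created this process group.
dist.init_process_group("nccl")
pg = dist.distributed_c10d._get_default_group()

# Wrap the existing group as a 1-D mesh without creating new process
# groups, so FSDP2 can consume it.
mesh = DeviceMesh.from_group(pg, "cuda")
```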
pytorchmergebot pushed a commit that referenced this pull request Apr 26, 2024
(The D56499001 change above, now landed.)

Pull Request resolved: #124890
Approved by: https://github.com/awgu
alat-rights pushed a commit to alat-rights/pytorch that referenced this pull request Apr 26, 2024
(Mirror of the DeviceMesh.from_group() change above; Pull Request resolved: pytorch#124787.)
pytorchmergebot pushed a commit that referenced this pull request Apr 29, 2024
This PR renames the `FSDP` class to `FSDPModule`. This is a BC-breaking change. The rationale is that `FSDPModule` is more descriptive: since `fully_shard` is a module-level API (applied to a `module` arg), the `FSDP` class will always correspond to a module.

Also, users commonly import `FullyShardedDataParallel` as `FSDP`, so this can help avoid some name conflict in some cases.

Pull Request resolved: #124955
Approved by: https://github.com/wanchaol, https://github.com/wconstab
ghstack dependencies: #124651, #124741, #124767, #124768, #124780, #124787
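For illustration, after the rename the two names no longer collide on import (module paths as of this era; treat them as assumptions):

```python
# FSDP1: users commonly alias FullyShardedDataParallel as FSDP.
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# FSDP2: the per-module class now has its own distinct name.
from torch.distributed._composable.fsdp import FSDPModule, fully_shard
```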
@mvpatel2000 (Contributor) commented:
@wanchaol can we include this in torch 2.3? #125425 it would be nice to fix this regression

petrex pushed a commit to petrex/pytorch that referenced this pull request May 3, 2024
(Mirror of the #124780 change above; Pull Request resolved: pytorch#124780.)
petrex pushed a commit to petrex/pytorch that referenced this pull request May 3, 2024
(Mirror of the DeviceMesh.from_group() change above; Pull Request resolved: pytorch#124787.)
pytorch-bot bot pushed a commit that referenced this pull request May 3, 2024
(Mirror of the #124890 change above; Pull Request resolved: #124890.)
petrex pushed a commit to petrex/pytorch that referenced this pull request May 3, 2024
(Mirror of the FSDPModule rename above; Pull Request resolved: pytorch#124955.)
@wanchaol wanchaol changed the title [device_mesh] add a private init backend option [device_mesh] add back the private init backend option May 3, 2024
@wanchaol (Collaborator Author) commented May 3, 2024

> @wanchaol can we include this in torch 2.3? #125425 it would be nice to fix this regression

@mvpatel2000 Yes, thanks for the reminder! I'll submit a cherry-pick soon.

wanchaol added a commit that referenced this pull request May 14, 2024
(The #124780 change above.)
atalman pushed a commit that referenced this pull request May 14, 2024
[device_mesh] add a private init backend option (#124780) (#126147)

(The #124780 change above, landed again as #126147.)

Differential Revision: [D56497780](https://our.internmc.facebook.com/intern/diff/D56497780)
Pull Request resolved: #124780
Approved by: https://github.com/awgu
Labels

- ci-td-distributed
- ciflow/periodic (Trigger jobs ran periodically on master (periodic.yml) on the PR)
- ciflow/trunk (Trigger trunk jobs on your pull request)
- Merged
- oncall: distributed (Add this issue/PR to distributed oncall triage queue)
- release notes: distributed (dtensor) (release notes category)