-
Notifications
You must be signed in to change notification settings - Fork 24.2k
[DSD] Fix to remove non_persistent buffer in distributed state dict #125337
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/125337
Note: Links to docs will display an error until the docs builds have been completed. ❌ 2 New Failures, 1 Unrelated FailureAs of commit bab5c42 with merge base 746da87 ( NEW FAILURES - The following jobs have failed:
FLAKY - The following job failed but was likely due to flakiness present on trunk:
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
"dont_save_me", torch.rand(100, device="cuda"), persistent=False | ||
) | ||
ddp_model = DDP(copy.deepcopy(model)) | ||
set_model_state_dict(ddp_model, get_model_state_dict(ddp_model)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
IIUC, set_model_state_dict(module, get_model_state_dict(module))
should be a no-op. Is this just testing that set_model_state_dict()
does not error?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, just ensure that there is no error for set when there is non_persistent buffer. The actual value comparison to the single rank model is done below.
@@ -215,6 +215,8 @@ def recurse(module: nn.Module, curr_fqn: str) -> Generator: | |||
for name, obj in chain( | |||
module.named_buffers(recurse=False), module.named_parameters(recurse=False) | |||
): | |||
if name in module._non_persistent_buffers_set: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I might have missed some discussion. Could you remind me why we use named_buffers()
rather than some logic that relies only on the keys in the state dict itself?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It will trigger all_gather for FSDP. Since many users still use FSDP not FSDP2, we will have to ensure no performance penalty for this API.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I see. The issue is that both full and sharded state dict all-gather?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yes
@pytorchbot merge -f "The failing tests are not related." |
Merge startedYour change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
…ytorch#125337) Summary: Fixes pytorch#122792 state_dict includes only persistent buffers, while named_buffers() would include non_persistent buffers. Pull Request resolved: pytorch#125337 Approved by: https://github.com/awgu ghstack dependencies: pytorch#125333, pytorch#125501, pytorch#125334, pytorch#125335, pytorch#125336
…125337) (#127219) * [DSD] Fix to remove non_persistent buffer in distributed state dict (#125337) Summary: Fixes #122792 state_dict includes only persistent buffers, while named_buffers() would include non_persistent buffers. Pull Request resolved: #125337 Approved by: https://github.com/awgu ghstack dependencies: #125333, #125501, #125334, #125335, #125336 * lintrunner * lint --------- Co-authored-by: Chien-Chin Huang <chienchin@fb.com> Co-authored-by: Andrey Talman <atalman@fb.com>
Stack from ghstack (oldest at bottom):
Summary:
Fixes #122792
state_dict includes only persistent buffers, while named_buffers() would
include non_persistent buffers.
cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu @penguinwu @XilunWu @wanchaol @fduwjj @wz337 @tianyu-l @wconstab @yf225 @chauhang @d4l3k @LucasLLC