[DSD] Fix to remove non_persistent buffer in distributed state dict by fegin · Pull Request #125337 · pytorch/pytorch


Closed
fegin wants to merge 3 commits from gh/fegin/234/head

Conversation

[ghstack-poisoned]
@pytorch-bot pytorch-bot bot added the module: distributed_checkpoint and oncall: distributed labels May 1, 2024
pytorch-bot bot commented May 1, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/125337

Note: Links to docs will display an error until the docs builds have been completed.

❌ 2 New Failures, 1 Unrelated Failure

As of commit bab5c42 with merge base 746da87:

NEW FAILURES - The following jobs have failed:

FLAKY - The following job failed but was likely due to flakiness present on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@fegin fegin requested a review from wz337 May 1, 2024 21:31
@fegin fegin added the ciflow/trunk and ciflow/periodic labels May 1, 2024
@fegin fegin requested review from awgu and LucasLLC May 1, 2024 21:32
[ghstack-poisoned]
"dont_save_me", torch.rand(100, device="cuda"), persistent=False
)
ddp_model = DDP(copy.deepcopy(model))
set_model_state_dict(ddp_model, get_model_state_dict(ddp_model))
Collaborator
IIUC, set_model_state_dict(module, get_model_state_dict(module)) should be a no-op. Is this just testing that set_model_state_dict() does not error?

Contributor Author

Yes, this just ensures that set_model_state_dict() does not error when there is a non-persistent buffer. The actual value comparison against the single-rank model is done below.
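
For context, the behavior under test is standard nn.Module semantics: a buffer registered with persistent=False is tracked by the module and returned by named_buffers(), but excluded from state_dict(). A minimal single-process sketch (plain CPU module, no DDP; the Model class here is illustrative only):

import torch
import torch.nn as nn

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)
        # Tracked by the module, but never written to state_dict().
        self.register_buffer("dont_save_me", torch.rand(4), persistent=False)

m = Model()
print("dont_save_me" in m.state_dict())                 # False
print("dont_save_me" in dict(m.named_buffers()))        # True
print("dont_save_me" in m._non_persistent_buffers_set)  # True

The PR's test then checks that the distributed state dict APIs tolerate this asymmetry instead of erroring on the missing key.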

@@ -215,6 +215,8 @@ def recurse(module: nn.Module, curr_fqn: str) -> Generator:
     for name, obj in chain(
         module.named_buffers(recurse=False), module.named_parameters(recurse=False)
     ):
+        if name in module._non_persistent_buffers_set:
+            continue
Collaborator

I might have missed some discussion. Could you remind me why we use named_buffers() rather than some logic that relies only on the keys in the state dict itself?

Contributor Author

Relying on the state dict itself would trigger an all-gather for FSDP. Since many users still use FSDP rather than FSDP2, we have to ensure there is no performance penalty for this API.

Collaborator

I see. The issue is that both full and sharded state dicts trigger an all-gather?

Contributor Author

yes
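
To summarize the thread: the traversal works from named_buffers()/named_parameters() plus the module's _non_persistent_buffers_set rather than from the state dict's own keys, because materializing the state dict under FSDP (full or sharded) triggers an all-gather. A simplified sketch of the traversal shape being patched (iter_fqns is a hypothetical stand-in for the internal recursion, not the actual DSD code):

from itertools import chain
from typing import Generator

import torch.nn as nn

def iter_fqns(module: nn.Module, curr_fqn: str = "") -> Generator:
    # Walk the module tree and yield (FQN, tensor) pairs without ever
    # calling state_dict(), which would all-gather under FSDP.
    prefix = f"{curr_fqn}." if curr_fqn else ""
    for child_name, child in module.named_children():
        yield from iter_fqns(child, f"{prefix}{child_name}")
    for name, obj in chain(
        module.named_buffers(recurse=False), module.named_parameters(recurse=False)
    ):
        if name in module._non_persistent_buffers_set:
            continue  # the fix: these names never appear in state_dict()
        yield f"{prefix}{name}", obj

With the skip in place, the FQNs produced here line up with the keys state_dict() actually returns, so set_model_state_dict() no longer looks up a key that does not exist.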

[ghstack-poisoned]
fegin commented May 7, 2024

@pytorchbot merge -f "The failing tests are not related."

pytorchmergebot commented

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use -f as last resort and instead consider -i/--ignore-current to continue the merge ignoring current failures. This will allow currently pending tests to finish and report signal before the merge.

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

antoinebrl pushed a commit to antoinebrl/pytorch that referenced this pull request May 27, 2024
[DSD] Fix to remove non_persistent buffer in distributed state dict (pytorch#125337)

Summary:
Fixes pytorch#122792

state_dict includes only persistent buffers, while named_buffers() would
include non_persistent buffers.

Pull Request resolved: pytorch#125337
Approved by: https://github.com/awgu
ghstack dependencies: pytorch#125333, pytorch#125501, pytorch#125334, pytorch#125335, pytorch#125336
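
For reference, a hedged end-to-end sketch of the scenario the fix unblocks, adapted from the PR's unit test but shrunk to a single-rank CPU process group (gloo backend, fixed port; the real test uses multiple ranks and CUDA):

import copy
import os

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.checkpoint.state_dict import (
    get_model_state_dict,
    set_model_state_dict,
)
from torch.nn.parallel import DistributedDataParallel as DDP

# Single-rank process group so DDP and the DSD APIs can run locally.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

model = nn.Linear(4, 4)
model.register_buffer("dont_save_me", torch.rand(4), persistent=False)
ddp_model = DDP(copy.deepcopy(model))

# Before this fix, the round trip raised a missing-key error because the
# FQN traversal yielded "dont_save_me", which has no state_dict entry.
set_model_state_dict(ddp_model, get_model_state_dict(ddp_model))

dist.destroy_process_group()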
huydhn pushed a commit that referenced this pull request May 27, 2024
[DSD] Fix to remove non_persistent buffer in distributed state dict (#125337) (#127219)

* [DSD] Fix to remove non_persistent buffer in distributed state dict (#125337)

Summary:
Fixes #122792

state_dict includes only persistent buffers, while named_buffers() would
include non_persistent buffers.

Pull Request resolved: #125337
Approved by: https://github.com/awgu
ghstack dependencies: #125333, #125501, #125334, #125335, #125336

* lintrunner

* lint

---------

Co-authored-by: Chien-Chin Huang <chienchin@fb.com>
Co-authored-by: Andrey Talman <atalman@fb.com>
@github-actions github-actions bot deleted the gh/fegin/234/head branch June 7, 2024 01:55