FSDP + DTensor Loss Flatlines Randomly · Issue #117471 · pytorch/pytorch · GitHub

Closed
mvpatel2000 opened this issue Jan 14, 2024 · 5 comments
Labels
module: dtensor distributed tensor tag module: fsdp triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module
Milestone
2.2.1

Comments

mvpatel2000 (Contributor) commented Jan 14, 2024

🐛 Describe the bug

We have been training with DTensor off torch nightly (in anticipation of 2.2), and we are very often seeing the loss flatline. We do not see this at all on the current nightly (as of 4 days ago), and at this point we are very confident there is a regression/bug in the current release candidate (for 2.2) that breaks FSDP training (at least with DTensor).
Our best guess is that one of the two linked PRs fixes it.


Versions

Torch 2.2 branch

cc @zhaojuanmao @mrshenli @rohan-varma @awgu @fegin @penguinwu @kwen2501 @wanchaol @XilunWu @tianyu-l

Skylion007 (Collaborator)

We confirmed this affects the 2.2.0 final RC.

atalman (Contributor) commented Jan 15, 2024

mvpatel2000 (Contributor, Author)

@atalman unfortunately I do not have a minimal repro, nor am I able to share the code for this run at this time :(

We run a transformer model with DTensor + FSDP (passing in a device mesh). The only different thing we do is that some weights are manually wrapped with DTensor and pre-sharded before FSDP -- I'm pretty sure this won't matter, so reproducing on your end shouldn't be too hard, but I'm not 100% confident.
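
For reference, a minimal sketch of roughly what this setup looks like (not our actual code; the transformer, the weight-selection rule, and the Shard(0) placement below are illustrative assumptions only):

```python
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed._tensor import Shard, distribute_tensor
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
mesh = init_device_mesh("cuda", (dist.get_world_size(),))

# Hypothetical stand-in for the transformer we actually train.
model = torch.nn.TransformerEncoder(
    torch.nn.TransformerEncoderLayer(d_model=512, nhead=8), num_layers=4
).cuda()

# Manually pre-shard a subset of weights as DTensors before FSDP wrapping.
# The selection rule is made up; this also assumes every rank starts from
# identical weights (e.g. same seed), so chunking the full tensor is valid.
for module in model.modules():
    for name, param in list(module.named_parameters(recurse=False)):
        if name.endswith("weight") and param.dim() == 2:
            dt = distribute_tensor(param.detach(), mesh, [Shard(0)])
            module.register_parameter(name, torch.nn.Parameter(dt))

# Hand FSDP the same device mesh the DTensors were sharded over.
fsdp_model = FSDP(model, device_mesh=mesh, use_orig_params=True)
```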

wanchaol (Collaborator)

@atalman I just checked our release branch: in addition to #117020, we'll also need #116122 together with it to resolve the merge conflicts.

I can also confirm that I hit similar numeric issues (not a loss flatline, but a loss NaN problem, which looks similar to the issue @mvpatel2000 hit). These two fixes resolved the NaN problem for me; it would be great if we could include both in our release branch :)

@awgu awgu added triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module module: fsdp module: dtensor distributed tensor tag labels Jan 16, 2024
@atalman atalman modified the milestones: 2.2.0, 2.2.1 Jan 18, 2024
mvpatel2000 (Contributor, Author)

Fixed in dev

@github-project-automation github-project-automation bot moved this to Review Required in Release Milestone Review Feb 1, 2024
@atalman atalman moved this from Review Required to Validation Required in Release Milestone Review Feb 14, 2024
@atalman atalman moved this from Validation Required to Done in Release Milestone Review Feb 19, 2024
Projects
None yet
Development

No branches or pull requests

5 participants