torch.utils.flop_counter.FlopCounterMode broke with torch-2.4 #134242
Comments
Adding triage review as it feels like this belongs to
@sanketpurandare @wanchaol can you please take a look? The module tracker warning is affected by your recent PR, which lands in the correct regression timeframe: 2e5366f
pytorch/torch/utils/flop_counter.py Line 5 in 0583024
The FlopCounter does not use the
Thank you for clarifying, @sanketpurandare, that they are two very similar but different copies of the code. I was flagging at the end of the OP that the error message isn't user-friendly or actionable, and that one of the copies is in torch.dist.

To repeat the specifics: the message "The module hierarchy tracking maybe be messed up. Please file a bug to PyTorch, if it is the case" raises two questions. How is the user to know "if it is the case"? And what does "may be messed up" mean? The message could probably be improved to tell the user where the problem is, or how to identify it, or some such guidance.

In other words, I ended up asking for two related things in this OP.
Thank you.
Just to add a small voice of support here. We're also using the FlopCounter on nightly (2.5) and it's a great tool! But I can confirm the same problem concerning the warning. For what it's worth, the counted FLOPs seem correct, even with the warning.

EDIT: Additional info: for us, the warning (but not the error) is always caused by turning on activation checkpointing (for which it would be amazing to count FLOPs correctly).
In https://github.com/pytorch/pytorch/blob/main/torch/utils/module_tracker.py there is no warning at the moment. A similar error message is "print"-ed only when you enter the same Module multiple times. We should definitely update that message to say this. There is a (not much more helpful) error thrown when you exit the Module during the forward pass (it is ignored during the backward): pytorch/torch/utils/module_tracker.py Line 112 in 75c22dd
These messages exist because tracking module entry/exit is quite hard and brittle, and we wanted to provide as much guidance there as we could. We can also remove the print/error if you think they are not helpful. For this particular error, it will depend on the model and what is done.
FWIW, I found the cause: it's gradient_checkpointing ( When activated, the modules are re-entered, and that should be normal/expected. Weirdly, I lost the use case where this led to the exception; at the moment I can only reproduce the warning, repeated on all ranks.
@albanD, thank you for the quick fix - and for adding it to the 2.4.1 milestone!
Fixes pytorch#134242 Make sure to never raise an error when confused. Logs for confusion can be enabled with `TORCH_LOGS="torch.utils.module_tracker"` or the usual python systems. Pull Request resolved: pytorch#134467 Approved by: https://github.com/malfet
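The commit message above mentions two ways to surface the confusion logs once the fix lands; a sketch (the script name is hypothetical, and this assumes a torch build containing #134467):

```python
# 1) Environment variable, set before the process starts:
#      TORCH_LOGS="torch.utils.module_tracker" python train.py
#
# 2) The usual Python logging machinery, from inside the script:
import logging

# Raise the module_tracker logger to DEBUG so the (now non-fatal)
# hierarchy-confusion messages become visible.
logging.getLogger("torch.utils.module_tracker").setLevel(logging.DEBUG)
```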
* Move module_tracker to logging for confused hierarchy (#134467)
* Fix bad merge conflict resolution
Validated: the fix seems to work for 2.4.1.
🐛 Describe the bug
We have been successfully using torch.utils.flop_counter.FlopCounterMode up to torch-2.4, and now it breaks and is impossible to use. It either warns:
The module hierarchy tracking seems to be messed up. Please file a bug to PyTorch
or crashes with:
The Module hierarchy tracking is wrong. Report a bug to PyTorch
The relevant part of the trace is:
Here is how we use it:
This happens with any HF transformers model I tried - BERT, Llama, Mistral - clearly their models are perfectly fine.
Rolling back to 2.3.1 restores the functionality.
Questions:
Suggestion:
If I may suggest: the warning/error is meaningless to the user. What does "messed up" mean?
In particular this one:
https://github.com/pytorch/pytorch/blob/3c5b246d3c6461ef59fa38e8c4265b2c6b223412/torch/distributed/_tools/mod_tracker.py#L175C10-L177C72
"The module hierarchy tracking maybe be messed up. Please file a bug to PyTorch, if it is the case" - how can a user tell if "it is the case"?

Versions
The problem happens on multiple setups - the only common ground is pt-2.4.0.
@albanD, @Chillee
cc @ezyang @gchanan @zou3519 @kadeng @msaroufim