torch.utils.flop_counter.FlopCounterMode broke with torch-2.4 #134242


Closed
stas00 opened this issue Aug 22, 2024 · 9 comments
Labels: high priority, module: flop counter (FlopCounterMode mode), module: regression (It used to work, and now it doesn't), triage review
Milestone: 2.4.1

Comments

stas00 (Contributor) commented Aug 22, 2024

🐛 Describe the bug

We have been successfully using torch.utils.flop_counter.FlopCounterMode up to torch-2.4 and now it breaks and is impossible to use.

It either warns:
The module hierarchy tracking seems to be messed up.Please file a bug to PyTorch
or crashes with:
The Module hierarchy tracking is wrong. Report a bug to PyTorch

The relevant part of the trace is:

[:1]:[rank1]:   File "/home/stas/anaconda3/envs/py310-pt22/lib/python3.10/site-packages/torch/_tensor.py", line 521, in backward
[:1]:[rank1]:     torch.autograd.backward(
[:1]:[rank1]:   File "/home/stas/anaconda3/envs/py310-pt22/lib/python3.10/site-packages/torch/autograd/__init__.py", line 289, in backward
[:1]:[rank1]:     _engine_run_backward(
[:1]:[rank1]:   File "/home/stas/anaconda3/envs/py310-pt22/lib/python3.10/site-packages/torch/autograd/graph.py", line 768, in _engine_run_backward
[:1]:[rank1]:     return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
[:1]:[rank1]:   File "/home/stas/anaconda3/envs/py310-pt22/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 1116, in unpack_hook
[:1]:[rank1]:     frame.recompute_fn(*args)
[:1]:[rank1]:   File "/home/stas/anaconda3/envs/py310-pt22/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 1400, in recompute_fn
[:1]:[rank1]:     fn(*args, **kwargs)
[:1]:[rank1]:   File "/home/stas/anaconda3/envs/py310-pt22/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
[:1]:[rank1]:     return self._call_impl(*args, **kwargs)
[:1]:[rank1]:   File "/home/stas/anaconda3/envs/py310-pt22/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1592, in _call_impl
[:1]:[rank1]:     args_result = hook(self, args)
[:1]:[rank1]:   File "/home/stas/anaconda3/envs/py310-pt22/lib/python3.10/site-packages/torch/utils/module_tracker.py", line 120, in _fw_pre_hook
[:1]:[rank1]:     self._get_append_fn(name, False)()
[:1]:[rank1]:   File "/home/stas/anaconda3/envs/py310-pt22/lib/python3.10/site-packages/torch/utils/module_tracker.py", line 96, in fn

Here is how we use it:

from contextlib import contextmanager
from torch.utils.flop_counter import FlopCounterMode

class MatMulFlopCounter:
    def __init__(self, display=False, target_iter=2):
        self.target_iter = target_iter
        self.flop_counter = FlopCounterMode(display=display)
        self.mm_tflops = 0

    @contextmanager
    def __call__(self, current_iter):
        if current_iter == self.target_iter:
            with self.flop_counter:
                yield
            self.mm_tflops = self.flop_counter.get_total_flops()
        else:
            yield

    def get_total_flops(self):
        return self.mm_tflops / 1e12
[...]
                    mm_flop_counter = MatMulFlopCounter()
                    with mm_flop_counter(iter_since_job_start), self.accelerator.accumulate(self.model):
                        (loss_total, output) = self.do_batch(...)

This happens with any HF transformers model I tried - BERT, Llama, Mistral - so clearly the models themselves are perfectly fine.

Rolling back to 2.3.1 restores the functionality.

Questions:

  1. What is the workaround to unblock us so we can keep using FlopCounterMode with pt-2.4+?
  2. What is the long-term solution?

Suggestion:
If I may suggest: the warning/error is meaningless to the user. What does "messed up" mean?
In particular this one:

https://github.com/pytorch/pytorch/blob/3c5b246d3c6461ef59fa38e8c4265b2c6b223412/torch/distributed/_tools/mod_tracker.py#L175C10-L177C72

"The module hierarchy tracking maybe be messed up. Please file a bug to PyTorch, if it is the case"- how can a user tell if "it is the case"?

Versions

The problem happens on multiple setups - the only common ground is pt-2.4.0.

@albanD, @Chillee

cc @ezyang @gchanan @zou3519 @kadeng @msaroufim

malfet (Contributor) commented Aug 22, 2024

Adding triage review as it feels like this belongs to oncall: profiler, but it's not...

zou3519 added the module: flop counter label Aug 22, 2024
ezyang (Contributor) commented Aug 23, 2024

@sanketpurandare @wanchaol can you please take a look? The module tracker warning is affected by your recent PR which is in the correct regression timeframe: 2e5366f

sanketpurandare (Contributor) commented Aug 23, 2024

@sanketpurandare @wanchaol can you please take a look? The module tracker warning is affected by your recent PR which is in the correct regression timeframe: 2e5366f

from .module_tracker import ModuleTracker

The FlopCounter does not use the ModTracker in torch.distributed._tools. It uses the one by @albanD in torch.utils?
https://github.com/pytorch/pytorch/blob/main/torch/utils/module_tracker.py

cc: @wanchaol @stas00

stas00 (Contributor, Author) commented Aug 23, 2024

Thank you for clarifying, @sanketpurandare, that these are two very similar but different copies of the code. I was just flagging at the end of the OP that the error message isn't user-friendly or actionable - and that particular one happens to be in torch.distributed.

To repeat the specifics:

"The module hierarchy tracking maybe be messed up. Please file a bug to PyTorch, if it is the case"

How is the user to know "if it is the case"? What does "may be messed up" mean?

So this could probably be improved as well, to tell the user where the problem is, how to identify it, or give some such guidance.


In other words, I ended up asking for two related things in this OP:

  1. overcoming the regression
  2. helping users make sense of these cryptic warnings and errors

Thank you.

JonasGeiping commented Aug 23, 2024

Just to add a small voice of support here. We're also using the FlopCounter on nightly (2.5) and it's a great tool! But I can confirm the same problem, concerning the warning "The module hierarchy tracking seems to be messed up.Please file a bug to PyTorch".

For what it's worth, the counted FLOPs seem correct, even with the warning?

EDIT: Additional info: For us, the warning (but not the error) is always caused by turning on activation checkpointing (for which it would be amazing to count flops correctly).

albanD (Collaborator) commented Aug 23, 2024

In https://github.com/pytorch/pytorch/blob/main/torch/utils/module_tracker.py there is no warning at the moment.

A similar error message is "print"-ed only when you enter the same Module multiple times. We should definitely update that message to say this.

There is a (not much more helpful) error thrown when you exit the Module during the forward pass (it is ignored during the backward):

"The Module hierarchy tracking is wrong. Report a bug to PyTorch"

These messages exist because tracking the entry/exit of modules is quite hard and brittle, and we wanted to provide as much of a hint there as we could. We can also remove the print/error if you think they are not helpful.

For this particular error, it will depend on the model and what is being done.
One known case (and the reason there is no error in backward) is during the backward pass when the input doesn't require gradients.
I think there are issues with calling the same Module recursively, and I know there are issues when re-using the same input for multiple modules (I need to prepare a PR for this, but it is tricky if we want to catch Tensors inside data structures passed as inputs).
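
To make the failure modes above concrete, here is a minimal sketch (a made-up toy model, not code from this issue) of driving the underlying tracker directly; it assumes the ModuleTracker.parents attribute exposed by torch.utils.module_tracker:

import torch
from torch.utils.module_tracker import ModuleTracker

# Hypothetical toy model, only for illustration.
model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU())
tracker = ModuleTracker()

def show_stack(module, args):
    # tracker.parents is the set of fully-qualified names of the modules
    # currently running their forward; mismatched entry/exit of this set
    # is what produces the "hierarchy tracking" messages discussed above.
    print(sorted(tracker.parents))

for m in model.modules():
    m.register_forward_pre_hook(show_stack)

with tracker:
    model(torch.randn(2, 4))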

stas00 (Contributor, Author) commented Aug 23, 2024

FWIW, I found the cause - it's gradient_checkpointing (torch.utils.checkpoint) that triggers it.

When it is activated, the modules get re-entered, so this should be normal/expected.

Weirdly, I lost the use case where this led to the exception; at the moment I can only reproduce the warning, repeated on all ranks.
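
For reference, a minimal repro sketch of that combination (a made-up toy module, not the code from this report): FlopCounterMode wrapping a forward/backward pass that goes through torch.utils.checkpoint, so the checkpointed module is re-entered during recomputation:

import torch
from torch.utils.checkpoint import checkpoint
from torch.utils.flop_counter import FlopCounterMode

class Block(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = torch.nn.Linear(64, 64)

    def forward(self, x):
        # checkpointing discards activations and re-runs self.lin during
        # backward - the module re-entry described above
        return checkpoint(self.lin, x, use_reentrant=False)

model = Block()
x = torch.randn(8, 64, requires_grad=True)

flop_counter = FlopCounterMode(display=False)
with flop_counter:
    # on 2.4.0 this forward/backward reportedly triggers the hierarchy warning
    model(x).sum().backward()
print(flop_counter.get_total_flops())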

stas00 (Contributor, Author) commented Aug 26, 2024

@albanD, thank you for the quick fix - and for adding it to the 2.4.1 milestone!

albanD added a commit to albanD/pytorch that referenced this issue Aug 26, 2024
Fixes pytorch#134242

Make sure to never raise an error when confused. Logs for confusion can be enabled with `TORCH_LOGS="torch.utils.module_tracker"` or the usual python systems.

Pull Request resolved: pytorch#134467
Approved by: https://github.com/malfet
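
For reference, the commit message above mentions the TORCH_LOGS switch; here is a short sketch of both ways to surface the now-logged message (the programmatic logger name is an assumption, not something confirmed in this thread):

# From the commit message: enable module-tracker logs via the env var, e.g.
#   TORCH_LOGS="torch.utils.module_tracker" python train.py
# Programmatic sketch (assuming the logger is named after the module path):
import logging
logging.getLogger("torch.utils.module_tracker").setLevel(logging.DEBUG)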
atalman pushed a commit that referenced this issue Aug 27, 2024

* Move module_tracker to logging for confused hierarchy (#134467)

Fixes #134242

Make sure to never raise an error when confused. Logs for confusion can be enabled with `TORCH_LOGS="torch.utils.module_tracker"` or the usual python systems.

Pull Request resolved: #134467
Approved by: https://github.com/malfet

* Fix bad merge conflict resolution
PaliC (Contributor) commented Sep 4, 2024

Validated: the fix seems to work for 2.4.1.

pytorch-bot bot pushed a commit that referenced this issue Sep 13, 2024
Chao1Han pushed a commit to Chao1Han/pytorch that referenced this issue Sep 20, 2024