Pipeline Parallelism Fails when stage input does not produce gradients in all stages. #152827
Comments
Update on this issue: I was partially able to work around this by adding an additional `graph_connection` output to the pipeline stage outputs. I have tested this in various scenarios, and the workaround confirms that the disconnected gradients (producing None in the backward pass) in PP, PP + FSDP, and FSDP are the issue. The workaround does allow the code to run in the case of PP only and of FSDP + PP with a single level of FSDP wrapping. However, it does not solve the case with two levels of FSDP wrapping (as in TorchTitan, where the full model and the transformer blocks are each fully sharded). All of these workarounds introduce significant additional user code and may introduce unneeded communication from the extra graph connections. A proper fix is needed, so the issue still stands.
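For reference, a minimal sketch of the kind of `graph_connection` workaround described here (hypothetical module and argument names, not the exact code used): every stage input contributes a zero-weighted term to an extra output, so each input stays attached to the autograd graph and receives a zero gradient instead of None even when its conditional branch is skipped.

```python
# Hypothetical sketch of the graph_connection workaround (assumed names).
import torch
import torch.nn as nn

class StageWithGraphConnection(nn.Module):
    def __init__(self, block: nn.Module):
        super().__init__()
        self.block = block

    def forward(self, x, routed, graph_connection=None):
        out = self.block(x)
        # Zero-weighted sum of all inputs: numerically a no-op, but it keeps
        # every input (and the incoming graph_connection) in the autograd
        # graph, so their .grad is zeros rather than None after backward.
        connection = 0.0 * (x.sum() + routed.sum())
        if graph_connection is not None:
            connection = connection + 0.0 * graph_connection
        # The extra output is forwarded to the next stage, chaining the trick.
        return out, routed, connection
```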
@H-Huang can you take a look? Thanks!
Do you have a minimal example of this bug that we could use to aid the feature development (and eventually turn into a test for CI)? That would be really helpful! It sounds like the fix is to set gradients to 0 even if the input grads are None; that sounds doable to update in our PP framework. I'm assuming this use case only arises for stages that handle grads for multiple inputs? An alternative is to skip sending the gradient if no grads are detected, but that would also require the previous stage to know whether to set up the recv, which seems a bit more challenging, so I think the former is better.
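As a rough illustration of the direction suggested here (a sketch against assumed helper names, not the actual `torch.distributed.pipelining` internals), the backward send path could substitute a zero tensor whenever an input's gradient is None, so the peer's matching recv still completes:

```python
# Sketch only: zero-fill missing input gradients before the P2P send.
import torch
import torch.distributed as dist

def build_bwd_send_ops(stage_inputs, grad_recv_stage, group=None):
    ops = []
    for inp in stage_inputs:
        grad = inp.grad
        if grad is None:
            # Input was unused for this microbatch; send zeros instead of
            # raising, so the previous stage's irecv is still matched.
            grad = torch.zeros_like(inp)
        ops.append(dist.P2POp(dist.isend, grad, grad_recv_stage, group))
    return ops  # e.g. launched later via dist.batch_isend_irecv(ops)
```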
I can work on a minimal example of the bug after a few days. This bug happens in:
To work around this, I manually added the graph connection tensor I described earlier to force PyTorch to connect the tensors into the graph, and I added hooks to each input tensor and parameter tensor to check whether it was used. It ends up being very hairy when you perform nested `fully_shard` as in TorchTitan, and I still haven't finished the workaround for that case.

Setting the gradients to 0 sounds like the right approach, but I suggest being careful to avoid unnecessary communication overhead. Specifically, sending a full tensor of just zeros could waste megabytes or gigabytes of bandwidth depending on the size of the tensor, so there needs to be a way of sending a zero update, such as a special sentinel/marker or some other construct. We would also want to prevent additional computation from happening based on those zero gradients; otherwise, much of the benefit of mixture-of-experts models is lost.

Skipping the gradient send would be difficult, because the recv stage would not necessarily know beforehand which inputs/parameters are going to receive gradients until the forward pass on the send stage has completed.
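One possible shape for the sentinel idea above (hypothetical helpers, not an existing PyTorch API): exchange a one-element flag per input before the gradient itself, so a stage never ships a full tensor of zeros just to say "no gradient this microbatch", and the receiver zero-fills locally.

```python
# Hypothetical sentinel-based gradient exchange (assumed helper names).
import torch
import torch.distributed as dist

def send_grad_or_sentinel(inp: torch.Tensor, dst: int, group=None):
    has_grad = inp.grad is not None
    flag = torch.tensor([int(has_grad)], dtype=torch.uint8, device=inp.device)
    dist.send(flag, dst=dst, group=group)          # tiny message either way
    if has_grad:
        dist.send(inp.grad, dst=dst, group=group)  # full tensor only if needed

def recv_grad_or_sentinel(buf: torch.Tensor, src: int, group=None):
    flag = torch.empty(1, dtype=torch.uint8, device=buf.device)
    dist.recv(flag, src=src, group=group)
    if flag.item():
        dist.recv(buf, src=src, group=group)
    else:
        buf.zero_()  # "no gradient" handled locally, no bulk transfer
    return buf
```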
🐛 Describe the bug
TL;DR: Pipeline parallelism fails if a stage input does not have gradients produced.
Consider the case where the outputs from each pipeline stage are passed to the next stage, but whether a given output is actually used for a particular batch is conditional (based on the model's code). In many cases (such as conditional or mixture-of-experts models), these tensors may not be used in a particular stage, which results in an error from `get_bwd_send_ops`: `"[{self.stage_index}] for chunk {bwd_chunk_id} has gradients {grad} and is expecting to send gradients to stage {grad_recv_stage}"`.

As a result, such a model fails under pipeline parallelism, even though it runs without issue under FSDP (or with no parallelism at all). This happens because, for each input tensor to a stage, the stage raises that error if no gradient was produced for the tensor, even when the tensor would otherwise be passed on to a subsequent stage where gradients would be produced.
Currently the framework uses `dist.isend` to send the tensor, but in order to send None, a different asynchronous P2P communication operation would be needed, one that can asynchronously send or recv objects (which may or may not be tensors). It would be great if this could be implemented, as pipeline parallelism is critical to achieving high throughput in distributed execution, and conditional or mixture models are currently limited by this bug.
Versions
PyTorch 2.6+
cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k