fix: support grad clipping for TP through replicating non-sharded modules #36132
Conversation
Force-pushed from 289af7a to 2192a35
Can you have a look @kwen2501?
Force-pushed from 9bb5edd to ebfb17d
@weifengpy @mori360 do you mind having a look at the two concerns here? Thanks!

@weifengpy Do you recommend using …
Thanks for fixing gradient clipping with TP.
Functionality-wise, the code change looks reasonable.
I am consulting with my colleagues to see whether it is strictly necessary to explicitly annotate non-sharded modules as replicated.
"layers.*.self_attn.o_proj": "rowwise_output_dtensor", | ||
"layers.*.mlp.gate_proj": "colwise", | ||
"layers.*.mlp.up_proj": "colwise", | ||
"layers.*.mlp.down_proj": "rowwise", | ||
"layers.*.mlp.down_proj": "rowwise_output_dtensor", | ||
"embed_tokens": "replicateparallel_output_dtensor", | ||
"layers.*.post_attention_layernorm": "replicateparallel_output_dtensor", | ||
"layers.*.input_layernorm": "replicateparallel_output_dtensor", | ||
"norm": "replicateparallel_output_dtensor", |
Thanks for extending the configs here.
I wonder if some of these settings would be more interesting to training than to inference?
(On the other hand, I don't know much about HF's user profile -- training more or inference more?)
If some of the settings are specific to training, is it possible to separate them out? Or, shall we make the config somehow customizable at run time?
@kwen2501
True, these are needed for training (for the grad norm), but not so much for inference. Does using replicate incur costs?
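For illustration, here is a hedged sketch of how a per-module plan like the one above maps onto the torch TP styles; `TinyMLP` and the plan keys below are made-up placeholders, not the actual transformers code:

```python
# Illustrative sketch only: TinyMLP and the plan below are assumptions, not the
# actual transformers plan. Run under torchrun so WORLD_SIZE is set and NCCL works.
import os
import torch
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel, parallelize_module


class TinyMLP(nn.Module):
    def __init__(self, dim: int = 16):
        super().__init__()
        self.gate_proj = nn.Linear(dim, 4 * dim, bias=False)
        self.down_proj = nn.Linear(4 * dim, dim, bias=False)
        self.norm = nn.LayerNorm(dim)  # not in the TP plan, so it stays un-sharded

    def forward(self, x):
        return self.norm(self.down_proj(torch.relu(self.gate_proj(x))))


tp_mesh = init_device_mesh("cuda", (int(os.environ["WORLD_SIZE"]),))
model = parallelize_module(
    TinyMLP().cuda(),
    tp_mesh,
    {"gate_proj": ColwiseParallel(), "down_proj": RowwiseParallel()},
)
# gate_proj/down_proj now hold DTensor parameters, while norm still holds plain
# Tensors: the DTensor/Tensor mix that torch.nn.utils.clip_grad_norm_ rejects,
# which is why the plan above also annotates layernorms/embeddings as replicated.
```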
src/transformers/pytorch_utils.py
Outdated
class ReplicateParallel(ParallelStyle):
    """
    Replicate a nn.Module.
    Users can compose it together with other parallel styles like RowwiseParallel to achieve a fully distributed model.
    Fully distributed model is needed for gradient clipping.
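For reference, a minimal sketch of how such a `ReplicateParallel` style could look on top of the public DTensor APIs (an assumption for illustration, not necessarily the exact code in this PR):

```python
# Hedged sketch of a ReplicateParallel style. It assumes the public
# torch.distributed.tensor APIs (torch >= 2.4; older releases expose them
# under torch.distributed._tensor instead).
import torch.nn as nn
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor import DTensor, Replicate, distribute_module, distribute_tensor
from torch.distributed.tensor.parallel import ParallelStyle


class ReplicateParallel(ParallelStyle):
    """Turn a module's parameters into DTensors with a Replicate() placement."""

    @staticmethod
    def _replicate_module_fn(name: str, module: nn.Module, device_mesh: DeviceMesh):
        for param_name, param in module.named_parameters(recurse=False):
            if isinstance(param, DTensor):
                continue  # already distributed by another style
            replicated = nn.Parameter(distribute_tensor(param, device_mesh, [Replicate()]))
            module.register_parameter(param_name, replicated)

    def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
        # distribute_module walks the module tree and calls partition_fn on each submodule
        return distribute_module(module, device_mesh, partition_fn=self._replicate_module_fn)
```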
@weifengpy @wz337 @tianyu-l
I wonder if there is anything we can do on the DTensor side so that users don't have to annotate the entire model to perform gradient clipping?
@weifengpy @kwen2501 Should we be using implicit_replication() as an alternative?
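For context, a hedged sketch of what that alternative could look like; `model` and the exact import path are assumptions:

```python
# Hedged sketch: implicit_replication is an experimental DTensor context manager
# that treats plain torch.Tensors as replicated when they meet DTensors in an op.
# `model` is assumed to be a partially TP-parallelized model; the import path
# below is the one used by recent torch releases.
import torch
from torch.distributed.tensor.experimental import implicit_replication

with implicit_replication():
    # the parameter list mixes DTensor (TP-sharded) and plain Tensor parameters
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
```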
src/transformers/pytorch_utils.py
Outdated
# TODO need to add the __repr__ that shows that it is a colwise parallel
# See https://github.com/pytorch/pytorch/issues/145726
nit: keep this TODO?
Fixed, thanks.
@kmehant we finished the refactoring, if you still want to work on this!

Thanks for the update, I will rebase the PR by EoD.

@ArthurZucker I have rebased my PR, thanks.
Force-pushed from 6290f7f to 99a3817
@ArthurZucker I kept it consistent with the refactored code now. Waiting for @kwen2501 on whether the recommendation is to use …

Yeah, per @weifengpy's comment, I think …
Thanks!
I am not familiar with implicit_replication, but if it is recommended by @kwen2501, happy to use it! I suppose it requires a specific torch version check, no?
If so, using our own replicate would be a bit better AFAIK for broader support (available starting from torch 2.3, vs. only the versions where implicit_replication is defined).
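One possible way to handle the version concern, sketched here as an assumption (the import path of `implicit_replication` has moved between torch releases):

```python
# Hedged sketch of a guarded import; the module paths are assumptions based on
# where the experimental DTensor APIs have lived across torch releases.
try:
    from torch.distributed.tensor.experimental import implicit_replication  # newer torch
except ImportError:
    try:
        from torch.distributed._tensor.experimental import implicit_replication  # older torch
    except ImportError:
        implicit_replication = None  # fall back to an explicit replicate plan
```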
@@ -234,6 +234,7 @@
    AutocastKwargs,
    DistributedDataParallelKwargs,
    DistributedType,
    TorchTensorParallelPlugin,
not seeing this used!
much needed!
fix: support grad clipping for TP through replicating non-sharded modules (huggingface#36132)

* feat: fix tp grad norm
  Signed-off-by: Mehant Kammakomati <mehant.kammakomati2@ibm.com>
* feat: use implicit replication
  Signed-off-by: Mehant Kammakomati <mehant.kammakomati2@ibm.com>

---------

Signed-off-by: Mehant Kammakomati <mehant.kammakomati2@ibm.com>
Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
What does this PR do?
torch.nn.utils.clip_grad_norm_ does not support a heterogeneous set of parameters that mixes DTensors and plain Tensors. This PR enables gradient clipping by distributing the non-sharded modules that are not involved in TP: we replicate all such modules across the device mesh. The PR also adds a new parallel style, ReplicateParallel, so that the existing TP APIs can be used as-is for this module replication operation. We could think of contributing this back to PyTorch if it makes sense (cc: @kwen2501); otherwise we can maintain it in transformers.

⭐ Note: We will rebase this PR once #34194 is merged; some of the workflow changes that you see here will disappear once that PR is merged.
fixes: #36296
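For illustration, a hedged end-to-end sketch of the user-facing flow this PR targets; the `tp_plan="auto"` loading path, the model id, the sequence length, and the max_norm value are assumptions/placeholders:

```python
# Hedged sketch: load a TP-sharded model and clip gradients. Must be launched
# with torchrun; assumes the tp_plan="auto" path in recent transformers releases.
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.1-8B", tp_plan="auto")
input_ids = torch.randint(0, model.config.vocab_size, (1, 16), device=model.device)

loss = model(input_ids=input_ids, labels=input_ids).loss
loss.backward()

# With every parameter now a DTensor (sharded or replicated), clipping no longer
# fails on a DTensor/Tensor mix:
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
```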
Concerns
Concern 1
When we do two TP runs with gradient clipping and the exact same training settings, we don't reproduce exact loss parity between the runs, though both runs converge eventually. I am worried that Replicate sharding has something to do with this.
Concern 2
Grad norms are not the same on each rank. I would assume that in TP training the grad norms should come out the same across ranks, but that's not the case.
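A small hedged debugging sketch for checking that parity; it assumes the clipped total norm comes back as a DTensor when the parameters are DTensors, and `model` is the TP-parallelized model from the run above:

```python
# Hedged debugging sketch: print the total grad norm on every rank to check parity.
import torch
import torch.distributed as dist
from torch.distributed.tensor import DTensor

total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
if isinstance(total_norm, DTensor):
    total_norm = total_norm.full_tensor()  # materialize the global value
print(f"rank {dist.get_rank()}: grad norm = {total_norm.item()}")
```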
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
@ArthurZucker @muellerzr and @SunMarc
@kwen2501 from PyTorch