fix: support grad clipping for TP through replicating non-sharded modules by kmehant · Pull Request #36132 · huggingface/transformers · GitHub

fix: support grad clipping for TP through replicating non-sharded modules #36132


Merged
merged 3 commits into huggingface:main on Jun 6, 2025

Conversation

@kmehant (Contributor) commented Feb 11, 2025

What does this PR do?

torch.nn.utils.clip_grad_norm_ does not support a heterogeneous set of parameters containing a mix of DTensors and plain Tensors. This PR enables gradient clipping by distributing the non-sharded modules that are not involved in TP: we replicate all such modules across the device mesh.

The PR also adds a new parallel style, ReplicateParallel, so that the existing TP APIs can be used as-is for this module-replication operation. We could consider contributing this back to PyTorch if it makes sense (cc: @kwen2501); otherwise we can maintain it in transformers.
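
As a rough illustration, a replicate style could look something like the sketch below. This is not this PR's exact implementation; the use of distribute_module and the import paths (which follow recent PyTorch) are assumptions based on the public DTensor APIs.

import torch.nn as nn
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor import distribute_module
from torch.distributed.tensor.parallel import ParallelStyle

class ReplicateParallel(ParallelStyle):
    """Mark a module as replicated so all of its parameters become DTensors."""

    def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
        # distribute_module with no partition_fn replicates every parameter on
        # the mesh, so the whole model exposes a homogeneous set of DTensor
        # parameters that clip_grad_norm_ can reduce over.
        return distribute_module(module, device_mesh)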

⭐ Note: We will rebase this PR once #34194 is merged; some of the workflow changes that you see here will disappear once that PR lands.

fixes: #36296

Concerns

Concern 1

When we do two TP runs with gradient clipping under the exact same training settings, we don't get exact loss parity between the runs, though both runs eventually converge. I am worried that the Replicate sharding has something to do with this.
(Screenshot from 2025-02-11 comparing the loss curves of the two runs.)

Concern 2

Grad norms are not the same on each rank. I would assume that in TP training the grad norms should come out the same across ranks; however, that's not the case.
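
For reference, one way this can be checked (a diagnostic sketch, not part of this PR; the helper name is made up):

import torch.distributed as dist

def gather_grad_norms(local_norm: float) -> list:
    # local_norm: e.g. the value returned by clip_grad_norm_, as a Python float.
    # Collect the grad norm reported on every rank so they can be compared.
    norms = [None] * dist.get_world_size()
    dist.all_gather_object(norms, local_norm)
    return norms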

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@ArthurZucker @muellerzr and @SunMarc
@kwen2501 from PyTorch

@SunMarc (Member) commented Feb 20, 2025

Can you have a look @kwen2501 ?

@kmehant force-pushed the tp-gradnorm branch 2 times, most recently from 9bb5edd to ebfb17d on February 20, 2025 14:18
@kwen2501 (Contributor):

@weifengpy @mori360 do you mind having a look at the two concerns here? Thanks!

@weifengpy:

> @weifengpy @mori360 do you mind having a look at the two concerns here? Thanks!

> does not support a heterogeneous set of parameters containing a mix of DTensors and plain Tensors

implicit_replication was introduced to mix DTensors with plain tensors. Maybe it's cleaner to use it here.

from torch.distributed._tensor.experimental import implicit_replication

with implicit_replication():
    # call gradient clipping here, e.g. torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
    ...

code pointer: https://github.com/pytorch/pytorch/blob/8b818ab58f635f999de2c8a5bf8e6c01d0c122ed/test/distributed/tensor/parallel/test_tp_examples.py#L262-L264

@kmehant (Contributor, Author) commented Feb 21, 2025

@weifengpy Do you recommend using implicit_replication instead?

@kwen2501 (Contributor) left a comment:


Thanks for fixing gradient clipping with TP.
Functionality-wise, the code change looks reasonable.
I am consulting with my colleagues to see whether it is strictly necessary to explicitly annotate non-sharded modules as Replicate.

Comment on lines 120 to 127 of the TP plan config:
"layers.*.self_attn.o_proj": "rowwise_output_dtensor",
"layers.*.mlp.gate_proj": "colwise",
"layers.*.mlp.up_proj": "colwise",
"layers.*.mlp.down_proj": "rowwise",
"layers.*.mlp.down_proj": "rowwise_output_dtensor",
"embed_tokens": "replicateparallel_output_dtensor",
"layers.*.post_attention_layernorm": "replicateparallel_output_dtensor",
"layers.*.input_layernorm": "replicateparallel_output_dtensor",
"norm": "replicateparallel_output_dtensor",
@kwen2501 (Contributor):


Thanks for extending the configs here.
I wonder if some of these settings would be more interesting to training than to inference?
(On the other hand, I don't know much about HF's user profile -- training more or inference more?)
If some of the settings are specific to training, is it possible to separate them out? Or, shall we make the config somehow customizable at run time?

@kmehant (Contributor, Author):


@kwen2501 True, these are needed for training (for the grad norm), but not really needed for inference. Does using Replicate incur costs?

Comment on lines 354 to 358
class ReplicateParallel(ParallelStyle):
"""
Replicate an nn.Module.
Users can compose it with other parallel styles like RowwiseParallel to achieve a fully distributed model.
A fully distributed model is needed for gradient clipping.
@kwen2501 (Contributor):


@weifengpy @wz337 @tianyu-l
I wonder if there is anything we can do on DTensor side so that users don't have to annotate the entire model to perform gradient clipping?

@kmehant (Contributor, Author):


@weifengpy @kwen2501 Should we be using implicit_replication() as an alternative?

Comment on lines 347 to 342
# TODO need to add the __repr__ that shows that it is a colwise parallel
# See https://github.com/pytorch/pytorch/issues/145726
@kwen2501 (Contributor):


nit: keep this TODO?

@kmehant (Contributor, Author):


Fixed, thanks.

@ArthurZucker (Collaborator) left a comment:


Very nice, waiting on @kwen2501's feedback, but make sure to rebase since we just merged #36335!

@ArthurZucker (Collaborator):

@kmehant we finished the refactoring, if you still want to work on this!

@kmehant (Contributor, Author) commented Mar 11, 2025

#36132 (comment)

Thanks for the update, I will rebase the PR by EoD.

@kmehant (Contributor, Author) commented Mar 11, 2025

@ArthurZucker I have rebased my PR, thanks

@kmehant force-pushed the tp-gradnorm branch 9 times, most recently from 6290f7f to 99a3817 on March 17, 2025 11:39
@kmehant (Contributor, Author) commented Mar 17, 2025

#36132 (comment)

@ArthurZucker I kept it consistent with the refactored code now. Waiting on @kwen2501 to confirm whether the recommendation is to use implicit_replication() from torch instead of introducing a new ParallelStyle module for replication.

@kwen2501 (Contributor):

Yeah, per @weifengpy's comment, I think implicit_replication() is preferred over creating new strategies. You can limit its usage to the minimum possible if there is a risk concern -- e.g., only when clip_grad_norm_ is used.
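
For illustration, a minimal sketch of scoping implicit_replication() to just the clipping call; the surrounding training-step names (model, optimizer, max_grad_norm) are assumptions, not the code this PR adds.

import torch
from torch.distributed._tensor.experimental import implicit_replication

def clip_and_step(model, optimizer, max_grad_norm):
    # Only inside this context are plain (non-DTensor) gradients treated as
    # replicated DTensors, so the mixed parameter set clips cleanly.
    with implicit_replication():
        total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
    optimizer.step()
    optimizer.zero_grad()
    return total_norm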

kmehant added 2 commits March 25, 2025 16:25
Signed-off-by: Mehant Kammakomati <mehant.kammakomati2@ibm.com>
Signed-off-by: Mehant Kammakomati <mehant.kammakomati2@ibm.com>
@kmehant mentioned this pull request on May 12, 2025
@ArthurZucker (Collaborator) left a comment:


Thanks!
I am not familiar with implicit_replication, but if it is recommended by @kwen2501 I am happy to use it! I suppose it requires a specific torch version check, no?

If so, using our own replicate style would be a bit better AFAIK for broader support (starting from torch 2.3, vs. only the versions where implicit_replication is defined).
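
For illustration, a minimal sketch of what such a gate could look like; the try/except import fallback and the no-op context are assumptions, not the code this PR ended up with.

import contextlib

try:
    # newer torch exposes the experimental API under torch.distributed.tensor
    from torch.distributed.tensor.experimental import implicit_replication
except ImportError:
    try:
        # older torch keeps it under the private _tensor namespace
        from torch.distributed._tensor.experimental import implicit_replication
    except ImportError:
        implicit_replication = None

def maybe_implicit_replication():
    # Fall back to a no-op context on torch versions without the API.
    return implicit_replication() if implicit_replication is not None else contextlib.nullcontext()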

@@ -234,6 +234,7 @@
AutocastKwargs,
DistributedDataParallelKwargs,
DistributedType,
TorchTensorParallelPlugin,
@ArthurZucker (Collaborator):


not seeing this used!

@SunMarc (Member) left a comment:


Much needed!

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@SunMarc merged commit 3d15606 into huggingface:main on Jun 6, 2025
20 checks passed
bvantuan pushed a commit to bvantuan/transformers that referenced this pull request Jun 12, 2025
fix: support grad clipping for TP through replicating non-sharded modules (huggingface#36132)

* feat: fix tp grad norm:

Signed-off-by: Mehant Kammakomati <mehant.kammakomati2@ibm.com>

* feat: use implicit replication

Signed-off-by: Mehant Kammakomati <mehant.kammakomati2@ibm.com>

---------

Signed-off-by: Mehant Kammakomati <mehant.kammakomati2@ibm.com>
Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
Labels: None yet
Projects: None yet
Successfully merging this pull request may close these issues: tensor parallel training bug
6 participants