[core] support tensor-valued _extra_state values in `from_pretrained` by pstjohn · Pull Request #38155 · huggingface/transformers

[core] support tensor-valued _extra_state values in from_pretrained #38155


Merged: 2 commits, May 28, 2025


pstjohn (Contributor) commented May 15, 2025

What does this PR do?

TransformerEngine uses the PyTorch get/set_extra_state API to store FP8 layer config information as a Tensor in the _extra_state entry of the state dict. Recent changes to from_pretrained broke this functionality, and loading a model that uses this API no longer appears to work. This PR fixes the save/load pretrained functions for extra state entries that use a PyTorch tensor, and adds a (currently x-failing) test for a dictionary extra state.
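
As a rough illustration of the mechanism described above, the sketch below shows a module that stores its config as a bytes tensor through get_extra_state/set_extra_state. The class `ToyFP8Linear`, its `fp8_config` dict, and the JSON encoding are assumptions made for this example; they are not TransformerEngine's actual implementation.

```python
import json

import torch
from torch import nn


class ToyFP8Linear(nn.Module):
    """Hypothetical layer; not TransformerEngine's real implementation."""

    def __init__(self, in_features: int, out_features: int) -> None:
        super().__init__()
        self.linear = nn.Linear(in_features, out_features)
        # Illustrative config values only.
        self.fp8_config = {"amax_history_len": 16, "margin": 0}

    def get_extra_state(self) -> torch.Tensor:
        # Encode the config as bytes and wrap the bytes in a uint8 tensor.
        payload = json.dumps(self.fp8_config).encode("utf-8")
        return torch.tensor(list(payload), dtype=torch.uint8)

    def set_extra_state(self, state: torch.Tensor) -> None:
        # Decode the uint8 tensor back into the config dict.
        payload = bytes(state.cpu().tolist())
        self.fp8_config = json.loads(payload.decode("utf-8"))


model = ToyFP8Linear(4, 2)
state = model.state_dict()

# The extra state sits next to the regular parameters under "_extra_state".
extra = state["_extra_state"]

# Manually round-trip the tensor into a fresh module.
restored = ToyFP8Linear(4, 2)
restored.fp8_config = {}
restored.set_extra_state(extra)
```

Because the entry is a tensor under a key ending in `_extra_state` rather than a parameter or buffer, any loader that only matches state dict keys against parameters and buffers would need to handle it separately.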

Fixes #38154

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

RFR @Cyrilvallez, @ArthurZucker.

TransformerEngine uses the pytorch get/set_extra_state API to store FP8
layer config information as a bytes Tensor in the _extra_state entry in
the state dict. With recent changes to from_pretrained, this
functionality has broken and loading a model that uses this API doesn't
appear to work. This PR fixes the save/load pretrained functions for
extra state entries that use a pytorch tensor, and adds a (currently
x-failing) test for a dictionary extra state.

Signed-off-by: Peter St. John <pstjohn@nvidia.com>
@github-actions github-actions bot marked this pull request as draft May 15, 2025 14:33
github-actions bot (Contributor) commented

Hi 👋, thank you for opening this pull request! The pull request is converted to draft by default. The CI will be paused while the PR is in draft mode. When it is ready for review, please click the Ready for review button (at the bottom of the PR page). This will assign reviewers and trigger CI.

@pstjohn pstjohn changed the title from "Support tensor-valued _extra_state values" to "Support tensor-valued _extra_state values in from_pretrained" May 15, 2025
@pstjohn pstjohn marked this pull request as ready for review May 15, 2025 14:33
@github-actions github-actions bot requested review from Rocketknight1 and ydshieh May 15, 2025 14:33
@pstjohn pstjohn changed the title from "Support tensor-valued _extra_state values in from_pretrained" to "[core] support tensor-valued _extra_state values in from_pretrained" May 15, 2025
pstjohn (Contributor, Author) commented May 19, 2025

@S1ro1, maybe you'd want to review this? I saw you recently merged #37689, which is somewhat similar.

pstjohn (Contributor, Author) commented May 21, 2025

@Cyrilvallez, gentle ping for review. Thanks!

The test failure seemed like a timeout issue; let me know if you want me to try running this again.
Edit: rebased, and the tests are passing now.

ArthurZucker (Collaborator) left a comment

This actually looks like a nice feature!
I am just curious: don't we also need to call set_extra_state somewhere?
Also, is this bound to any particular version of torch?

pstjohn (Contributor, Author) commented May 26, 2025

Thanks for the review!

No, torch will internally call set_extra_state inside load_state_dict, which we call in _load_parameter_into_model (modeling_utils.py:717).
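
A minimal sketch of that behavior in plain PyTorch, outside of transformers: the module `Counter` and its `steps` and `set_calls` attributes are hypothetical names used only to make the internal call visible.

```python
import torch
from torch import nn


class Counter(nn.Module):
    """Hypothetical module that records how often set_extra_state runs."""

    def __init__(self) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(2))
        self.steps = 0
        self.set_calls = 0

    def get_extra_state(self) -> torch.Tensor:
        # Tensor-valued extra state, as in the case this PR addresses.
        return torch.tensor(self.steps)

    def set_extra_state(self, state: torch.Tensor) -> None:
        self.set_calls += 1
        self.steps = int(state.item())


src = Counter()
src.steps = 7

dst = Counter()
# No explicit set_extra_state call here: load_state_dict invokes it
# when it finds the "_extra_state" key in the incoming state dict.
result = dst.load_state_dict(src.state_dict())
```

After the load, `dst.steps` matches `src.steps` and `dst.set_calls` is 1, even though the calling code never touches set_extra_state directly.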

Extra states have been in PyTorch for ~4 years (pytorch/pytorch#62976), so I don't think this practically restricts you to any recent version of torch.

I haven't done a full bisection, but I think this is more of a bug fix than a new feature. IIUC, back when to/from pretrained leveraged more of PyTorch's state_dict machinery, this would have been supported out of the box.

Cyrilvallez (Member) left a comment

Alright, that works for me! Thanks! Merging 🤗

@Cyrilvallez Cyrilvallez merged commit bab40c6 into huggingface:main May 28, 2025
18 checks passed
Linked issue: Support extra_state attributes in from_pretrained
3 participants