Adding RoPE to pytorch core #149534


Open · manuelcandales opened this issue on Mar 19, 2025 · 3 comments
Labels: enhancement · module: nn · needs design · triaged

manuelcandales (Contributor) commented on Mar 19, 2025

The RoPE Python code is being copied and pasted over and over across multiple pytorch org repos. I propose we move the RoPE operation into pytorch core (e.g. under nn.functional) and also add a RotaryPositionalEmbeddings module. Some examples of the code duplication:

pytorch/ao:

pytorch/benchmark:

pytorch/torchchat:

pytorch/torchtune:

pytorch/xla:

pytorch/pytorch:

  • The same `apply_rotary_emb` function, which appears verbatim more than once:

```python
def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
    xshaped = x.float().reshape(*x.shape[:-1], -1, 2)
    freqs_cis = freqs_cis.view(1, xshaped.size(1), 1, xshaped.size(3), 2)
    x_out2 = torch.stack(
        [
            xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1],
            xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1],
        ],
        -1,
    )
    x_out2 = x_out2.flatten(3)
    return x_out2.type_as(x)
```
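
For illustration only, here is a minimal sketch of how the function above could be surfaced in core as a functional op plus the proposed RotaryPositionalEmbeddings module wrapper. The functional name `rotary_embedding`, the constructor arguments, and the frequency-table construction are assumptions made for the sake of the example, not a settled design:

```python
import torch
from torch import Tensor, nn


def rotary_embedding(x: Tensor, freqs_cis: Tensor) -> Tensor:
    # x: [batch, seq_len, n_heads, head_dim]; freqs_cis: [seq_len, head_dim // 2, 2]
    # (hypothetical functional entry point; the body mirrors the duplicated code above)
    xshaped = x.float().reshape(*x.shape[:-1], -1, 2)
    freqs_cis = freqs_cis.view(1, xshaped.size(1), 1, xshaped.size(3), 2)
    x_out = torch.stack(
        [
            xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1],
            xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1],
        ],
        dim=-1,
    ).flatten(3)
    return x_out.type_as(x)


class RotaryPositionalEmbeddings(nn.Module):
    # Thin module wrapper: precompute the cos/sin table once, apply the rotation
    # in forward(). Constructor arguments and defaults here are assumptions.
    def __init__(self, head_dim: int, max_seq_len: int = 4096, base: float = 10000.0) -> None:
        super().__init__()
        theta = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
        angles = torch.outer(torch.arange(max_seq_len).float(), theta)  # [max_seq_len, head_dim // 2]
        self.register_buffer(
            "freqs_cis", torch.stack([angles.cos(), angles.sin()], dim=-1), persistent=False
        )

    def forward(self, x: Tensor) -> Tensor:
        # x: [batch, seq_len, n_heads, head_dim]
        return rotary_embedding(x, self.freqs_cis[: x.size(1)])
```

A wrapper like this would be used as, e.g., `RotaryPositionalEmbeddings(head_dim=64)(torch.randn(2, 128, 8, 64))`, with the cos/sin table precomputed once in the constructor rather than rebuilt in every caller.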

cc @albanD @mruberry @jbschlosser @walterddr @mikaylagawarecki

manuelcandales added the module: nn label on Mar 19, 2025
jbschlosser (Contributor) commented on Mar 19, 2025

Relevant older issue: #24826

Note that we generally maintain a high bar for inclusion of new modules in torch.nn, as each addition comes with a steep maintenance cost. We reserve additions for modules that become ubiquitous (i.e., there is a widespread expectation that the module is provided by torch.nn) or for which there is a compelling performance-based reason to implement the module in torch.nn rather than requiring user-written variants. At this point, my opinion is that RoPE is widespread enough to pass the ubiquity test at least.

malfet added the triaged and enhancement labels on Mar 20, 2025
mikaylagawarecki (Contributor) commented

Hi @manuelcandales, we discussed this issue and concluded that we would accept a PR that adds this to torch.nn, with the following requirements:

  1. Before adding it, go over all the implementations linked above and ensure there are no subtle differences between them, with the motivation documented in the PR description.
  2. The added implementation integrates properly with the entire pytorch stack (eager + PT2).
    For example, there is an implementation that uses complex numbers and one that does not; for the complex version there might be some issues with inductor (see the issue in torchtitan). These should be investigated when adding this op (an illustrative comparison of the two formulations is sketched below).
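
To illustrate the difference mentioned in point 2, here is a hedged sketch contrasting the complex-number formulation with the purely real-valued one, together with a quick numerical equivalence check of the kind requirement 1 calls for. The function names and frequency construction here are hypothetical; this is not a proposal for the core implementation:

```python
import torch
from torch import Tensor


def apply_rotary_emb_complex(x: Tensor, freqs_cis: Tensor) -> Tensor:
    # freqs_cis is complex-valued: [seq_len, head_dim // 2]
    x_c = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
    freqs_cis = freqs_cis.view(1, x_c.size(1), 1, x_c.size(-1))
    return torch.view_as_real(x_c * freqs_cis).flatten(3).type_as(x)


def apply_rotary_emb_real(x: Tensor, cos_sin: Tensor) -> Tensor:
    # cos_sin stacks cos/sin on the last dim: [seq_len, head_dim // 2, 2]
    xs = x.float().reshape(*x.shape[:-1], -1, 2)
    cs = cos_sin.view(1, xs.size(1), 1, xs.size(3), 2)
    out = torch.stack(
        [
            xs[..., 0] * cs[..., 0] - xs[..., 1] * cs[..., 1],
            xs[..., 1] * cs[..., 0] + xs[..., 0] * cs[..., 1],
        ],
        dim=-1,
    )
    return out.flatten(3).type_as(x)


# Quick equivalence check on random inputs (head_dim = 64, seq_len = 16).
x = torch.randn(2, 16, 4, 64)
angles = torch.outer(
    torch.arange(16).float(),
    1.0 / (10000.0 ** (torch.arange(0, 64, 2).float() / 64)),
)
freqs_complex = torch.polar(torch.ones_like(angles), angles)    # complex64 table
freqs_real = torch.stack([angles.cos(), angles.sin()], dim=-1)  # real-valued table
torch.testing.assert_close(
    apply_rotary_emb_complex(x, freqs_complex), apply_rotary_emb_real(x, freqs_real)
)
```

The real-valued form avoids complex dtypes entirely, which may sidestep the inductor issue mentioned above, but that would need to be verified under torch.compile.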

vadimkantorov (Contributor) commented

It seems some variant has been designed and added to the ONNX opset:
