Adding RoPE to pytorch core #149534
Open
@manuelcandales

Description

The RoPE Python code is being copied and pasted over and over across multiple pytorch org repos. I propose we move the RoPE operation to PyTorch core (e.g. under nn.functional) and also add a RotaryPositionalEmbeddings module. Some examples of the code duplication are listed below (a rough sketch of what the proposed API could look like follows the examples):

pytorch/ao
pytorch/benchmark
pytorch/torchchat
pytorch/torchtune
pytorch/xla
pytorch/pytorch:

def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
    xshaped = x.float().reshape(*x.shape[:-1], -1, 2)
    freqs_cis = freqs_cis.view(1, xshaped.size(1), 1, xshaped.size(3), 2)
    x_out2 = torch.stack(
        [
            xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1],
            xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1],
        ],
        -1,
    )
    x_out2 = x_out2.flatten(3)
    return x_out2.type_as(x)
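
For discussion, here is a rough sketch of what this could look like in core. The names `rotary_positional_embedding` and `RotaryPositionalEmbeddings`, the constructor arguments, and the precomputation of `freqs_cis` are assumptions derived from the duplicated code above, not a settled API.

```python
# Hypothetical sketch only: these names/signatures do not exist in torch today.
import torch
from torch import Tensor, nn


def rotary_positional_embedding(x: Tensor, freqs_cis: Tensor) -> Tensor:
    # x: [batch, seq_len, n_heads, head_dim]
    # freqs_cis: [seq_len, head_dim // 2, 2] holding (cos, sin) per position and
    # frequency. Same math as the duplicated snippet above.
    xshaped = x.float().reshape(*x.shape[:-1], -1, 2)
    freqs_cis = freqs_cis.view(1, xshaped.size(1), 1, xshaped.size(3), 2)
    x_out = torch.stack(
        [
            xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1],
            xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1],
        ],
        dim=-1,
    ).flatten(3)
    return x_out.type_as(x)


class RotaryPositionalEmbeddings(nn.Module):
    # Precomputes freqs_cis once so each model doesn't re-derive it.
    def __init__(self, head_dim: int, max_seq_len: int = 4096, base: float = 10000.0) -> None:
        super().__init__()
        theta = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
        angles = torch.outer(torch.arange(max_seq_len).float(), theta)
        freqs_cis = torch.stack([angles.cos(), angles.sin()], dim=-1)
        self.register_buffer("freqs_cis", freqs_cis, persistent=False)

    def forward(self, x: Tensor) -> Tensor:
        # x: [batch, seq_len, n_heads, head_dim] -> same shape, rotated
        return rotary_positional_embedding(x, self.freqs_cis[: x.size(1)])


# Usage
rope = RotaryPositionalEmbeddings(head_dim=64)
q = torch.randn(2, 128, 8, 64)
q_rot = rope(q)  # shape preserved: [2, 128, 8, 64]
```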

cc @albanD @mruberry @jbschlosser @walterddr @mikaylagawarecki

Metadata

Labels

enhancement: Not as big of a feature, but technically not a bug. Should be easy to fix
module: nn: Related to torch.nn
needs design: We want to add this feature but we need to figure out how first
triaged: This issue has been looked at by a team member, and triaged and prioritized into an appropriate module
