[NT] Implementing Multi-Head Attention with NestedTensors #125214
@clessig

Description

🚀 The feature, motivation and pitch

Nested tensors are supported by PyTorch's flash attention implementation (cf. https://gist.github.com/victoroliv2/3668f07e11a0757febb6e55a8d78592a), which gives a notable speedup (approx. 25%) compared to alternative options. But extending this example to a full multi-head attention implementation does not work at the moment, since flash attention expects the constituents of the nested_tensor to be 3D while nn.Linear requires them to be 2D:

RuntimeError: Linear requires nested_tensor.dim == 3 and dense_matrix.dim == 2. Nested tensor dim: 4. Dense tensor dim: 2

This restriction on nn.Linear also seems odd to me. One could in principle construct the nested_tensor only after the projection, but since this involves a copy operation it is rather inefficient and will likely negate any benefit from flash attention with nested tensors; see the sketch below.
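
For illustration, a rough sketch of that copy-based workaround (the sizes, the per-head split, and the variable names are assumptions for the example): the projection is applied to each jagged sequence as an ordinary 2D tensor, and the 4D nested tensor is only built afterwards via torch.nested.nested_tensor, which copies its constituent tensors.

import torch

# Hypothetical sizes: 2 jagged sequences, embed dim 64, 4 heads of dim 16.
seqs = [torch.randn(10, 64), torch.randn(7, 64)]
num_heads, head_dim = 4, 16

proj_q = torch.nn.Linear(64, num_heads * head_dim, bias=False)

# Project per sequence, split out the head dimension, then build the nested
# tensor. torch.nested.nested_tensor copies its inputs, which is the
# inefficiency described above.
q_nt = torch.nested.nested_tensor(
    [proj_q(s).reshape(-1, num_heads, head_dim).transpose(0, 1) for s in seqs]
)  # 4D nested tensor: (batch, heads, jagged seq, head_dim)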

Alternatives

No response

Additional context

Here's a minimal example:

import torch


class AttentionHead(torch.nn.Module):

    def __init__(self, proj_dims):
        '''Attention head with separate q/k/v projections.'''

        super().__init__()

        # Fused alternative: a single projection, split into q, k, v in forward().
        # self.proj = torch.nn.Linear(proj_dims[0], 3 * proj_dims[1], bias=False)
        self.proj_q = torch.nn.Linear(proj_dims[0], proj_dims[1], bias=False)
        self.proj_k = torch.nn.Linear(proj_dims[0], proj_dims[1], bias=False)
        self.proj_v = torch.nn.Linear(proj_dims[0], proj_dims[1], bias=False)

        self.softmax = torch.nn.Softmax(dim=-1)

        self.lnorm_q = torch.nn.LayerNorm(proj_dims[1], elementwise_affine=False)
        self.lnorm_k = torch.nn.LayerNorm(proj_dims[1], elementwise_affine=False)

        self.att = torch.nn.functional.scaled_dot_product_attention

    def forward(self, xs):

        # The projections fail here when xs is a 4D nested tensor (see the error above).
        # q, k, v = torch.tensor_split(self.proj(xs), 3, dim=-1)
        q, k, v = self.proj_q(xs), self.proj_k(xs), self.proj_v(xs)
        q, k = self.lnorm_q(q), self.lnorm_k(k)

        # Restrict scaled_dot_product_attention to the flash attention backend.
        with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False,
                                            enable_mem_efficient=False):
            q_out = self.att(q, k, v)

        return q_out
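
For completeness, a hypothetical driver that triggers the reported error (the sizes, dtype, and device are assumptions for the sketch): each constituent tensor is 3D (heads, sequence length, head dim), as flash attention expects, so the nested tensor itself is 4D and the projection in forward() fails.

import torch

head = AttentionHead((16, 16)).cuda().half()

# Two sequences of different lengths, 4 heads, 16-dim per head; each
# constituent is 3D, so the nested tensor has dim 4.
xs = torch.nested.nested_tensor(
    [torch.randn(4, 10, 16), torch.randn(4, 7, 16)],
    device="cuda", dtype=torch.half,
)

head(xs)
# RuntimeError: Linear requires nested_tensor.dim == 3 and dense_matrix.dim == 2.
# Nested tensor dim: 4. Dense tensor dim: 2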

cc @cpuhrsch @jbschlosser @bhosmer @drisspg @soulitzer @erichan1 @mikaylagawarecki

Metadata

    Labels

    module: nestedtensor — NestedTensor tag, see issue #25032
    triaged — This issue has been looked at by a team member, and triaged and prioritized into an appropriate module
