🚀 The feature, motivation and pitch
Nested tensors are supported by PyTorch's flash attention implementation (cf. https://gist.github.com/victoroliv2/3668f07e11a0757febb6e55a8d78592a), and this gives a marked (approx. 25%) speedup compared to alternative options. But extending this example to a full multi-head attention implementation does not work at the moment, since flash attention expects 3D tensors inside the nested_tensor while nn.Linear requires 2D tensors:
RuntimeError: Linear requires nested_tensor.dim == 3 and dense_matrix.dim == 2. Nested tensor dim: 4. Dense tensor dim: 2
This restriction on nn.Linear also seems odd to me. One could in principle construct the nested_tensor only after the projection, but since this involves a copy operation it is rather inefficient and would likely negate any benefit from using flash attention with nested tensors.
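For illustration, a rough sketch of that copy-based workaround (the shapes, the per-sample list input, and the projection size are assumptions made up for this example):

import torch

proj_q = torch.nn.Linear(64, 64, bias=False)

# Keep the inputs as a list of dense per-sample tensors of shape (heads, seq_len, dim),
# project each one through the regular Linear path, and only then pack the results
# into a nested tensor. The nested_tensor() constructor copies every constituent,
# which is the overhead mentioned above.
seqs = [torch.randn(8, 10, 64), torch.randn(8, 7, 64)]
q = torch.nested.nested_tensor([proj_q(s) for s in seqs])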
Alternatives
No response
Additional context
Here's a minimal example:
import torch


class AttentionHead(torch.nn.Module):
    def __init__(self, proj_dims):
        '''Attention head'''
        super().__init__()
        # self.proj = torch.nn.Linear(proj_dims[0], 3 * proj_dims[1], bias=False)
        self.proj_q = torch.nn.Linear(proj_dims[0], proj_dims[1], bias=False)
        self.proj_k = torch.nn.Linear(proj_dims[0], proj_dims[1], bias=False)
        self.proj_v = torch.nn.Linear(proj_dims[0], proj_dims[1], bias=False)
        self.softmax = torch.nn.Softmax(dim=-1)
        self.lnorm_q = torch.nn.LayerNorm(proj_dims[1], elementwise_affine=False)
        self.lnorm_k = torch.nn.LayerNorm(proj_dims[1], elementwise_affine=False)
        self.att = torch.nn.functional.scaled_dot_product_attention

    def forward(self, xs):
        # q, k, v = torch.tensor_split(self.proj(xs), 3, dim=-1)
        q, k, v = self.proj_q(xs), self.proj_k(xs), self.proj_v(xs)
        q, k = self.lnorm_q(q), self.lnorm_k(k)
        with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False,
                                            enable_mem_efficient=False):
            q_out = self.att(q, k, v)
        return q_out
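Calling this module on a nested tensor input reproduces the error above; a hypothetical invocation (the shapes are made up, each constituent is (heads, seq_len, dim), so the nested tensor itself is 4D):

head = AttentionHead((64, 64))
xs = torch.nested.nested_tensor([torch.randn(8, 10, 64),
                                 torch.randn(8, 7, 64)])
out = head(xs)
# RuntimeError: Linear requires nested_tensor.dim == 3 and dense_matrix.dim == 2.
# Nested tensor dim: 4. Dense tensor dim: 2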
cc @cpuhrsch @jbschlosser @bhosmer @drisspg @soulitzer @erichan1 @mikaylagawarecki