-
Notifications
You must be signed in to change notification settings - Fork 24.8k
Description
🐛 Describe the bug
I am trying to run the NJT + flex_attention example from the blog post, but it only works for me when the head dimension D equals the batch size BATCH.
In the original blog post the batch size is 8 and the head dimension is 16; this configuration throws an error for me, even without a mask.
# Minimal reproduction: flex_attention on nested jagged tensors (NJTs) whose
# (NUM_HEADS, D) dims are obtained by unflattening a packed NUM_HEADS * D dim.
# Reported to fail with a dynamo AssertionError unless D == BATCH
# (observed on 2.7 and the 2.8 nightly).
from torch.nn.attention.flex_attention import flex_attention
import torch
BATCH = 8
NUM_HEADS = 8
D = 16  # head dim; the reported failure occurs when D != BATCH (here 16 vs 8)
device = "cuda"
# One random length in [5, 30) per batch element -> ragged sequence dim.
sequence_lengths = [torch.randint(5, 30, ()).item() for _ in range(BATCH)]
# Each component is (seq_len, NUM_HEADS * D): heads are packed into the last dim.
query = torch.nested.nested_tensor(
[
torch.randn(seq_len, NUM_HEADS * D, device=device)
for seq_len in sequence_lengths
],
layout=torch.jagged,
)
# key/value share query's nested structure (same per-sequence lengths).
key = torch.randn_like(query)
value = torch.randn_like(query)
# Split the packed last dim into (NUM_HEADS, D), then swap heads with the ragged
# sequence dim: (BATCH, j, NUM_HEADS * D) -> (BATCH, j, NUM_HEADS, D)
# -> (BATCH, NUM_HEADS, j, D). The unflatten step on the jagged tensor is what
# distinguishes this failing path from the working construction.
query = query.unflatten(-1, [NUM_HEADS, D]).transpose(1, 2)
key = key.unflatten(-1, [NUM_HEADS, D]).transpose(1, 2)
value = value.unflatten(-1, [NUM_HEADS, D]).transpose(1, 2)
# No score_mod / block_mask: the error occurs even without a mask.
output = flex_attention(query, key, value)
This throws the following error:
AssertionError: s13 (could be from ['<ephemeral: symint_visitor_fn>', '<ephemeral: symint_visitor_fn>', '<ephemeral: symint_visitor_fn>']) not in {s50: ["L['args'][0]._values.size()[0]", "L['args'][0]._values.size()[0]"], s81: ["L['args'][0]._values.size()[1]", "L['args'][0]._values.size()[1]"], s21: ["L['args'][0]._values.size()[2]", "L['args'][0]._values.stride()[0]", "L['args'][0]._values.size()[2]", "L['args'][0]._values.stride()[0]"], s95: ["L['args'][0]._min_seqlen_tensor.size()[0]", "L['args'][0]._min_seqlen_tensor.size()[0]", "L['args'][0]._min_seqlen_tensor.size()[0]", "L['args'][1]._min_seqlen_tensor.size()[0]", "L['args'][1]._min_seqlen_tensor.size()[0]", "L['args'][1]._base._min_seqlen_tensor.size()[0]", "L['args'][1]._base._min_seqlen_tensor.size()[0]", "L['args'][2]._min_seqlen_tensor.size()[0]", "L['args'][2]._min_seqlen_tensor.size()[0]", "L['args'][2]._base._min_seqlen_tensor.size()[0]", "L['args'][2]._base._min_seqlen_tensor.size()[0]"], s82: ["L['args'][0]._max_seqlen_tensor.size()[0]", "L['args'][0]._max_seqlen_tensor.size()[0]", "L['args'][0]._max_seqlen_tensor.size()[0]", "L['args'][1]._max_seqlen_tensor.size()[0]", "L['args'][1]._max_seqlen_tensor.size()[0]", "L['args'][1]._base._max_seqlen_tensor.size()[0]", "L['args'][1]._base._max_seqlen_tensor.size()[0]", "L['args'][2]._max_seqlen_tensor.size()[0]", "L['args'][2]._max_seqlen_tensor.size()[0]", "L['args'][2]._base._max_seqlen_tensor.size()[0]", "L['args'][2]._base._max_seqlen_tensor.size()[0]"], s22: ["L['args'][0].size()[0]", "L['args'][0].size()[1]", "L['args'][0]._values.size()[0]", "L['args'][1].size()[0]", "L['args'][1].size()[1]", "L['args'][1]._values.size()[0]", "L['args'][1]._base.size()[0]", "L['args'][2].size()[0]", "L['args'][2].size()[1]", "L['args'][2]._values.size()[0]", "L['args'][2]._base.size()[0]"], s10: ["L['args'][0].size()[2]", "L['args'][1]._base.size()[1]", "L['args'][2]._base.size()[1]"], s48: ["L['args'][0].stride()[2]", "L['args'][0]._values.stride()[1]", 
"L['args'][0]._base.size()[1]", "L['args'][0]._base.stride()[0]"], s85: ["L['args'][0]._values.size()[1]", "L['args'][0]._base.size()[0]"], s83: ["L['args'][1]._values.size()[0]", "L['args'][1]._values.size()[0]"], s27: ["L['args'][1]._values.size()[1]", "L['args'][1]._values.size()[1]"], s59: ["L['args'][1]._values.size()[2]", "L['args'][1]._values.stride()[0]", "L['args'][1]._values.size()[2]", "L['args'][1]._values.stride()[0]"], s7: ["L['args'][1].size()[2]"], s0: ["L['args'][1].stride()[2]", "L['args'][1]._values.stride()[1]", "L['args'][1]._base._values.size()[1]", "L['args'][1]._base._values.stride()[0]", "L['args'][1]._base.size()[2]", "L['args'][1]._base.stride()[1]", "L['args'][1]._base._values.size()[1]", "L['args'][1]._base._values.stride()[0]"], s26: ["L['args'][1]._values.size()[1]", "L['args'][1]._base._values.size()[0]", "L['args'][1]._base._values.size()[0]"], s29: ["L['args'][2]._values.size()[0]", "L['args'][2]._values.size()[0]"], s11: ["L['args'][2]._values.size()[1]", "L['args'][2]._values.size()[1]"], s15: ["L['args'][2]._values.size()[2]", "L['args'][2]._values.stride()[0]", "L['args'][2]._values.size()[2]", "L['args'][2]._values.stride()[0]"], s40: ["L['args'][2].size()[2]"], s64: ["L['args'][2].stride()[2]", "L['args'][2]._values.stride()[1]", "L['args'][2]._base._values.size()[1]", "L['args'][2]._base._values.stride()[0]", "L['args'][2]._base.size()[2]", "L['args'][2]._base.stride()[1]", "L['args'][2]._base._values.size()[1]", "L['args'][2]._base._values.stride()[0]"], s1: ["L['args'][2]._values.size()[1]", "L['args'][2]._base._values.size()[0]", "L['args'][2]._base._values.size()[0]"], s88: ["L['args'][4][0]"], s16: ["L['args'][4][1]"], s65: ["L['args'][4][10]"], s98: ["L['args'][4][11]"], zf73: ["___as_tensor(L['args'][5]).item()", "L['args'][5]"]}. If this assert is failing, it could be due to the issue described in https://github.com/pytorch/pytorch/pull/90665
It works if BATCH and D are equal.
It also works if the tensors are constructed directly with the final (NUM_HEADS, D) dimensions, so that no unflatten is needed:
# Working variant: build each component as (seq_len, NUM_HEADS, D) up front, so
# the jagged tensor never goes through unflatten; only the transpose remains.
# This fragment reuses sequence_lengths, device, NUM_HEADS and D from the
# reproduction script earlier in this report.
query = torch.nested.nested_tensor(
[torch.randn(seq_len, NUM_HEADS, D, device=device) for seq_len in sequence_lengths],
layout=torch.jagged,
)
key = torch.randn_like(query)
value = torch.randn_like(query)
# (BATCH, j, NUM_HEADS, D) -> (BATCH, NUM_HEADS, j, D): same final layout as the
# failing unflatten-based path.
query = query.transpose(1, 2)
key = key.transpose(1, 2)
value = value.transpose(1, 2)
So the problem appears to be related to the unflatten operation combined with how the jagged tensor is constructed.
Am I missing some known limitation?
Versions
I tried it on both 2.7 and the 2.8 nightly.
cc @cpuhrsch @jbschlosser @bhosmer @drisspg @soulitzer @davidberard98 @YuqingJ @chauhang @penguinwu @zou3519 @ydwu4 @bdhirsh @Chillee @yanboliang @BoyuanFeng