Flex attention with NJT shape error #153371
@Modexus

Description


🐛 Describe the bug

I am trying to run the NJT + flex_attention example from the blog post, which only works for me if the head dimension D is the same as the batch size BATCH.
In the original blog post the batch size is 8 and the head dimension is 16, which throws an error for me even without a mask.

import torch
from torch.nn.attention.flex_attention import flex_attention

BATCH = 8
NUM_HEADS = 8
D = 16
device = "cuda"

sequence_lengths = [torch.randint(5, 30, ()).item() for _ in range(BATCH)]
query = torch.nested.nested_tensor(
    [
        torch.randn(seq_len, NUM_HEADS * D, device=device)
        for seq_len in sequence_lengths
    ],
    layout=torch.jagged,
)
key = torch.randn_like(query)
value = torch.randn_like(query)

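# Unflatten the last dim into (NUM_HEADS, D) and swap the head and jagged
# dims: (B, j, H * D) -> (B, H, j, D), the layout flex_attention expects.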
query = query.unflatten(-1, [NUM_HEADS, D]).transpose(1, 2)
key = key.unflatten(-1, [NUM_HEADS, D]).transpose(1, 2)
value = value.unflatten(-1, [NUM_HEADS, D]).transpose(1, 2)

output = flex_attention(query, key, value)

This throws the following error:
AssertionError: s13 (could be from ['<ephemeral: symint_visitor_fn>', '<ephemeral: symint_visitor_fn>', '<ephemeral: symint_visitor_fn>']) not in {s50: ["L['args'][0]._values.size()[0]", "L['args'][0]._values.size()[0]"], s81: ["L['args'][0]._values.size()[1]", "L['args'][0]._values.size()[1]"], s21: ["L['args'][0]._values.size()[2]", "L['args'][0]._values.stride()[0]", "L['args'][0]._values.size()[2]", "L['args'][0]._values.stride()[0]"], s95: ["L['args'][0]._min_seqlen_tensor.size()[0]", "L['args'][0]._min_seqlen_tensor.size()[0]", "L['args'][0]._min_seqlen_tensor.size()[0]", "L['args'][1]._min_seqlen_tensor.size()[0]", "L['args'][1]._min_seqlen_tensor.size()[0]", "L['args'][1]._base._min_seqlen_tensor.size()[0]", "L['args'][1]._base._min_seqlen_tensor.size()[0]", "L['args'][2]._min_seqlen_tensor.size()[0]", "L['args'][2]._min_seqlen_tensor.size()[0]", "L['args'][2]._base._min_seqlen_tensor.size()[0]", "L['args'][2]._base._min_seqlen_tensor.size()[0]"], s82: ["L['args'][0]._max_seqlen_tensor.size()[0]", "L['args'][0]._max_seqlen_tensor.size()[0]", "L['args'][0]._max_seqlen_tensor.size()[0]", "L['args'][1]._max_seqlen_tensor.size()[0]", "L['args'][1]._max_seqlen_tensor.size()[0]", "L['args'][1]._base._max_seqlen_tensor.size()[0]", "L['args'][1]._base._max_seqlen_tensor.size()[0]", "L['args'][2]._max_seqlen_tensor.size()[0]", "L['args'][2]._max_seqlen_tensor.size()[0]", "L['args'][2]._base._max_seqlen_tensor.size()[0]", "L['args'][2]._base._max_seqlen_tensor.size()[0]"], s22: ["L['args'][0].size()[0]", "L['args'][0].size()[1]", "L['args'][0]._values.size()[0]", "L['args'][1].size()[0]", "L['args'][1].size()[1]", "L['args'][1]._values.size()[0]", "L['args'][1]._base.size()[0]", "L['args'][2].size()[0]", "L['args'][2].size()[1]", "L['args'][2]._values.size()[0]", "L['args'][2]._base.size()[0]"], s10: ["L['args'][0].size()[2]", "L['args'][1]._base.size()[1]", "L['args'][2]._base.size()[1]"], s48: ["L['args'][0].stride()[2]", "L['args'][0]._values.stride()[1]", "L['args'][0]._base.size()[1]", "L['args'][0]._base.stride()[0]"], s85: ["L['args'][0]._values.size()[1]", "L['args'][0]._base.size()[0]"], s83: ["L['args'][1]._values.size()[0]", "L['args'][1]._values.size()[0]"], s27: ["L['args'][1]._values.size()[1]", "L['args'][1]._values.size()[1]"], s59: ["L['args'][1]._values.size()[2]", "L['args'][1]._values.stride()[0]", "L['args'][1]._values.size()[2]", "L['args'][1]._values.stride()[0]"], s7: ["L['args'][1].size()[2]"], s0: ["L['args'][1].stride()[2]", "L['args'][1]._values.stride()[1]", "L['args'][1]._base._values.size()[1]", "L['args'][1]._base._values.stride()[0]", "L['args'][1]._base.size()[2]", "L['args'][1]._base.stride()[1]", "L['args'][1]._base._values.size()[1]", "L['args'][1]._base._values.stride()[0]"], s26: ["L['args'][1]._values.size()[1]", "L['args'][1]._base._values.size()[0]", "L['args'][1]._base._values.size()[0]"], s29: ["L['args'][2]._values.size()[0]", "L['args'][2]._values.size()[0]"], s11: ["L['args'][2]._values.size()[1]", "L['args'][2]._values.size()[1]"], s15: ["L['args'][2]._values.size()[2]", "L['args'][2]._values.stride()[0]", "L['args'][2]._values.size()[2]", "L['args'][2]._values.stride()[0]"], s40: ["L['args'][2].size()[2]"], s64: ["L['args'][2].stride()[2]", "L['args'][2]._values.stride()[1]", "L['args'][2]._base._values.size()[1]", "L['args'][2]._base._values.stride()[0]", "L['args'][2]._base.size()[2]", "L['args'][2]._base.stride()[1]", "L['args'][2]._base._values.size()[1]", "L['args'][2]._base._values.stride()[0]"], s1: ["L['args'][2]._values.size()[1]", 
"L['args'][2]._base._values.size()[0]", "L['args'][2]._base._values.size()[0]"], s88: ["L['args'][4][0]"], s16: ["L['args'][4][1]"], s65: ["L['args'][4][10]"], s98: ["L['args'][4][11]"], zf73: ["___as_tensor(L['args'][5]).item()", "L['args'][5]"]}. If this assert is failing, it could be due to the issue described in https://github.com/pytorch/pytorch/pull/90665

It works if both BATCH and D are the same.
It also works if the tensors are constructed directly with the correct dimensions:

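# Build the NJT directly from (seq_len, NUM_HEADS, D) constituents,
# so no unflatten on the jagged tensor is needed afterwards.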
query = torch.nested.nested_tensor(
    [torch.randn(seq_len, NUM_HEADS, D, device=device) for seq_len in sequence_lengths],
    layout=torch.jagged,
)
key = torch.randn_like(query)
value = torch.randn_like(query)

query = query.transpose(1, 2)
key = key.transpose(1, 2)
value = value.transpose(1, 2)

So it seems to have something to do with the combination of the unflatten operation and the jagged tensor construction.

Am I missing some known limitation here?
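In case it helps, one workaround I considered (untested, and just a guess based on the _base entries in the guard error) would be to rebuild the NJT from its reshaped values buffer via torch.nested.nested_tensor_from_jagged instead of calling unflatten on the jagged tensor itself, so the inputs to flex_attention are not views of the flat NJT. The unflatten_heads helper below is only a sketch and not from the blog post:

from torch.nested import nested_tensor_from_jagged

def unflatten_heads(njt, num_heads, head_dim):
    # njt has shape (B, j, num_heads * head_dim); reshape its flat values
    # buffer and rebuild a fresh NJT from it instead of viewing the NJT.
    values = njt.values().reshape(-1, num_heads, head_dim)
    return nested_tensor_from_jagged(values, offsets=njt.offsets())

# (B, j, H, D) -> (B, H, j, D)
query = unflatten_heads(query, NUM_HEADS, D).transpose(1, 2)
key = unflatten_heads(key, NUM_HEADS, D).transpose(1, 2)
value = unflatten_heads(value, NUM_HEADS, D).transpose(1, 2)

output = flex_attention(query, key, value)

I have not verified whether this actually avoids the guard assertion.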

Versions

Tried it on both 2.7 and the 2.8 nightly.

cc @cpuhrsch @jbschlosser @bhosmer @drisspg @soulitzer @davidberard98 @YuqingJ @chauhang @penguinwu @zou3519 @ydwu4 @bdhirsh @Chillee @yanboliang @BoyuanFeng

Labels

module: flex attention, module: higher order operators (torch.cond and similar), module: nestedtensor (NestedTensor tag, see issue #25032), module: pt2-dispatcher (PT2 dispatcher-related issues, e.g. aotdispatch, functionalization, faketensor, custom-op), oncall: pt2, triaged
