Inductor may permute inputs to flex attention, leading to assertion error · Issue #148827 · pytorch/pytorch · GitHub
Inductor may permute inputs to flex attention, leading to assertion error #148827

Closed
Aleko2286 opened this issue Mar 8, 2025 · 6 comments
Labels: module: flex attention, module: higher order operators, module: inductor, module: pt2-dispatcher, oncall: pt2, triaged

Comments

Aleko2286 commented Mar 8, 2025

🐛 Describe the bug

For flex attention, inputs must be contiguous, but inductor seems to permute inputs under certain conditions which then results in an assertion error.

When using an Attention layer looking somewhat like this:

from typing import Optional

import torch
import torch.nn as nn
import torch.nn.attention.flex_attention


class Attention(nn.Module):
    def __init__(
        self,
        q_ch: int,
        kv_ch: Optional[int] = None,
        qk_embed_dim: Optional[int] = None,
        v_embed_dim: Optional[int] = None,
        linear_bias: bool = False,
        num_heads: int = 1,
    ):
        self.q_ch = q_ch
        self.kv_ch = kv_ch or self.q_ch
        self.qk_embed_dim = qk_embed_dim or self.q_ch
        self.v_embed_dim = v_embed_dim or self.kv_ch
        self.num_heads = num_heads
        assert (
            not self.qk_embed_dim % num_heads and not self.v_embed_dim % num_heads
        ), "The dimension of the embeddings in Attention must be divisible by the number of heads."
        super().__init__()

        self.q_proj = nn.Linear(self.q_ch, self.qk_embed_dim, bias=linear_bias)
        self.kv_proj = nn.Linear(
            self.kv_ch, self.qk_embed_dim + self.v_embed_dim, bias=linear_bias
        )
        self.o_proj = nn.Linear(self.v_embed_dim, self.q_ch, bias=linear_bias)

    def scaled_dot_product_attention(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        block_mask: torch.nn.attention.flex_attention.BlockMask,
    ) -> torch.Tensor:
        return torch.nn.attention.flex_attention.flex_attention(
            q, k, v, block_mask=block_mask
        )

    def forward(
        self, x: torch.Tensor, block_mask: torch.nn.attention.flex_attention.BlockMask
    ) -> torch.Tensor:
        q = self.q_proj(x)
        kv = self.kv_proj(x)
        k = kv[..., : self.qk_embed_dim]
        v = kv[..., self.qk_embed_dim :]
        q = q.reshape((q.shape[0], q.shape[1], self.num_heads, -1)).transpose(1, 2)
        k = k.reshape((k.shape[0], k.shape[1], self.num_heads, -1)).transpose(1, 2)
        v = v.reshape((v.shape[0], v.shape[1], self.num_heads, -1)).transpose(1, 2)
        return self.o_proj(
            self.scaled_dot_product_attention(q, k, v, block_mask)
            .transpose(1, 2)
            .reshape((x.shape[0], x.shape[1], -1))
        )

I get a LoweringException under certain conditions. Sadly, it does not reproduce as a standalone example. In my model, this only happens if I do a validation iteration before the first training iteration. If I start training directly, the compilation result seems to be different and training runs without any issue. From the error message, it looks like inductor swaps the original memory format (B, L, H, C) [transposed to (B, H, L, C)] to (B, L, C, H), which results in non-contiguous q, k and v (B=1, H=24, L=1904, C=32).
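To illustrate the layout (a sketch I constructed from the strides in the error log below, not a repro of the compilation issue itself):

import torch

# A buffer materialized in (B, L, C, H) order and then viewed as (B, H, L, C):
# the last dimension ends up with stride H (= 24) instead of 1, which is exactly
# what the flex attention lowering asserts against.
B, H, L, C = 1, 24, 1904, 32
buf = torch.empty(B, L, C, H, dtype=torch.float16)
q = buf.permute(0, 3, 1, 2)    # logical shape (B, H, L, C)
print(q.shape)                 # torch.Size([1, 24, 1904, 32])
print(q.stride())              # (1462272, 1, 768, 24) -- the strides from the error log
print(q.stride(-1) == 1)       # False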

There seems to be no straightforward way to fix this. For example, the following code also runs into the same problem:

def scaled_dot_product_attention(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        block_mask: torch.nn.attention.flex_attention.BlockMask,
    ) -> torch.Tensor:
        return torch.nn.attention.flex_attention.flex_attention(
            q.contiguous(), k.contiguous(), v.contiguous(), block_mask=block_mask
        )

Workarounds like this exist:

def scaled_dot_product_attention(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        block_mask: torch.nn.attention.flex_attention.BlockMask,
    ) -> torch.Tensor:
        print("", end="")
        return torch.nn.attention.flex_attention.flex_attention(
            q, k, v, block_mask=block_mask
        )
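
Presumably the print() helps because it forces a graph break, so flex_attention ends up seeing the eager (contiguous) tensors. If that assumption is right, an explicit but equally hacky variant would be:

def scaled_dot_product_attention(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        block_mask: torch.nn.attention.flex_attention.BlockMask,
    ) -> torch.Tensor:
        # Assumption: requesting the graph break directly has the same effect as the
        # print() above, keeping flex_attention out of the problematic inductor graph.
        torch._dynamo.graph_break()
        return torch.nn.attention.flex_attention.flex_attention(
            q, k, v, block_mask=block_mask
        )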

I think flex_attention should probably not error on non-contiguous tensors, but rather enforce contiguity itself. In any case, this is unexpected behavior from a user's perspective, since even when the tensors are contiguous in the eager version, compilation may fail because the query is not contiguous.

Error logs

torch._inductor.exc.InductorError: LoweringException: AssertionError: Query must be contiguous in the last dimension
target: flex_attention
args[0]: TensorBox(StorageBox(
ComputedBuffer(name='buf195', layout=FixedLayout('cuda:0', torch.float16, size=[1, 24, 1904, 32], stride=[1462272, 1, 768, 24]), data=Pointwise(device=device(type='cuda', index=0), dtype=torch.float16, inner_fn=<function pointwise_cat..inner_fn at 0x7f037043fd80>, ranges=[1, 24, 1904, 32]))
))
args[1]: TensorBox(StorageBox(
ComputedBuffer(name='buf196', layout=FixedLayout('cuda:0', torch.float16, size=[1, 24, 1904, 32], stride=[1462272, 1, 768, 24]), data=Pointwise(device=device(type='cuda', index=0), dtype=torch.float16, inner_fn=<function pointwise_cat..inner_fn at 0x7f0370485e40>, ranges=[1, 24, 1904, 32]))
))
args[2]: TensorBox(
ReinterpretView(
StorageBox(
ExternKernelOut(
python_kernel_name='extern_kernels.mm',
name=buf192,
layout=FixedLayout('cuda:0', torch.float16, size=[1904, 1536], stride=[1536, 1]),
inputs=[ReinterpretView(
StorageBox(
ComputedBuffer(name='buf189', layout=FixedLayout('cuda:0', torch.float16, size=[1, 1904, 768], stride=[1462272, 768, 1]), data=Pointwise(device=device(type='cuda', index=0), dtype=torch.float16, inner_fn=<function make_pointwise..inner..inner_fn at 0x7f03703a4e00>, ranges=[1, 1904, 768]))
),
FixedLayout('cuda:0', torch.float16, size=[1904, 768], stride=[768, 1]),
origins=OrderedSet([mm_1])
), ComputedBuffer(name='buf191', layout=FixedLayout('cuda:0', torch.float16, size=[768, 1536], stride=[1, 768]), data=Pointwise(device=device(type='cuda', index=0), dtype=torch.float16, inner_fn=<function BaseView.make_loader..loader at 0x7f03703a49a0>, ranges=[768, 1536]))],
constant_args=(),
kwargs={},
output_view=None,
python_kernel_name=extern_kernels.mm,
cpp_kernel_name=at::mm_out,
ordered_kwargs_for_cpp_kernel=(),
op_overload=None,
arg_properties=[{}, {}],
kwarg_properties=None,
unbacked_bindings={},
mutation_outputs=[],
origin_node=mm_1,
origins=OrderedSet([mm_1])
)
),
FixedLayout('cuda:0', torch.float16, size=[1, 24, 1904, 32], stride=[0, 32, 1536, 1], offset=768),
origins=OrderedSet([permute_60])
)
)
args[3]: Subgraph(name='sdpa_score0', graph_module=(), graph=None)
args[4]: (1904, 1904, TensorBox(StorageBox(
InputBuffer(name='primals_95', layout=FixedLayout('cuda:0', torch.int32, size=[4, 24, 15], stride=[360, 15, 1]))
)), TensorBox(StorageBox(
InputBuffer(name='primals_94', layout=FixedLayout('cuda:0', torch.int32, size=[4, 24, 15, 15], stride=[5400, 225, 15, 1]))
)), TensorBox(StorageBox(
InputBuffer(name='primals_96', layout=FixedLayout('cuda:0', torch.int32, size=[4, 24, 15], stride=[360, 15, 1]))
)), TensorBox(StorageBox(
InputBuffer(name='primals_97', layout=FixedLayout('cuda:0', torch.int32, size=[4, 24, 15, 15], stride=[5400, 225, 15, 1]))
)), TensorBox(StorageBox(
InputBuffer(name='primals_98', layout=FixedLayout('cuda:0', torch.int32, size=[4, 24, 15], stride=[360, 15, 1]))
)), TensorBox(StorageBox(
InputBuffer(name='primals_99', layout=FixedLayout('cuda:0', torch.int32, size=[4, 24, 15, 15], stride=[5400, 225, 15, 1]))
)), TensorBox(StorageBox(
InputBuffer(name='primals_100', layout=FixedLayout('cuda:0', torch.int32, size=[4, 24, 15], stride=[360, 15, 1]))
)), TensorBox(StorageBox(
InputBuffer(name='primals_101', layout=FixedLayout('cuda:0', torch.int32, size=[4, 24, 15, 15], stride=[5400, 225, 15, 1]))
)), 128, 128, Subgraph(name='sdpa_mask0', graph_module=(), graph=None))
args[5]: 0.17677669529663687
args[6]: {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True}
args[7]: ()
args[8]: ()

Versions

PyTorch version: 2.7.0.dev20250308+cu128
Is debug build: False
CUDA used to build PyTorch: 12.8
ROCM used to build PyTorch: N/A

OS: Ubuntu 24.04.1 LTS (x86_64)
GCC version: (Ubuntu 13.3.0-6ubuntu2~24.04) 13.3.0
Clang version: Could not collect
CMake version: Could not collect
Libc version: glibc-2.39

Python version: 3.12.3 (main, Jan 17 2025, 18:03:48) [GCC 13.3.0] (64-bit runtime)
Python platform: Linux-6.13.5-1-default-x86_64-with-glibc2.39
Is CUDA available: True
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA GeForce RTX 4070
Nvidia driver version: 570.124.04
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 48 bits physical, 48 bits virtual
Byte Order: Little Endian
CPU(s): 12
On-line CPU(s) list: 0-11
Vendor ID: AuthenticAMD
Model name: AMD Ryzen 5 5600X 6-Core Processor
CPU family: 25
Model: 33
Thread(s) per core: 2
Core(s) per socket: 6
Socket(s): 1
Stepping: 0
Frequency boost: enabled
CPU(s) scaling MHz: 100%
CPU max MHz: 4651.0000
CPU min MHz: 550.0000
BogoMIPS: 7400.31
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 hw_pstate ssbd mba ibrs ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local user_shstk clzero irperf xsaveerptr rdpru wbnoinvd arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif v_spec_ctrl umip pku ospke vaes vpclmulqdq rdpid overflow_recov succor smca fsrm debug_swap
Virtualization: AMD-V
L1d cache: 192 KiB (6 instances)
L1i cache: 192 KiB (6 instances)
L2 cache: 3 MiB (6 instances)
L3 cache: 32 MiB (1 instance)
NUMA node(s): 1
NUMA node0 CPU(s): 0-11
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Reg file data sampling: Not affected
Vulnerability Retbleed: Not affected
Vulnerability Spec rstack overflow: Mitigation; Safe RET
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Retpolines; IBPB conditional; IBRS_FW; STIBP always-on; RSB filling; PBRSB-eIBRS Not affected; BHI Not affected
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected

Versions of relevant libraries:
[pip3] mypy-extensions==1.0.0
[pip3] numpy==2.2.2
[pip3] nvidia-cublas-cu12==12.8.3.14
[pip3] nvidia-cuda-cupti-cu12==12.8.57
[pip3] nvidia-cuda-nvrtc-cu12==12.8.61
[pip3] nvidia-cuda-runtime-cu12==12.8.57
[pip3] nvidia-cudnn-cu12==9.7.1.26
[pip3] nvidia-cufft-cu12==11.3.3.41
[pip3] nvidia-curand-cu12==10.3.9.55
[pip3] nvidia-cusolver-cu12==11.7.2.55
[pip3] nvidia-cusparse-cu12==12.5.7.53
[pip3] nvidia-cusparselt-cu12==0.6.3
[pip3] nvidia-nccl-cu12==2.25.1
[pip3] nvidia-nvjitlink-cu12==12.8.61
[pip3] nvidia-nvtx-cu12==12.8.55
[pip3] pytorch-triton==3.2.0+git4b3bb1f8
[pip3] torch==2.7.0.dev20250308+cu128
[pip3] torchaudio==2.6.0.dev20250308+cu128
[pip3] torchinfo==1.8.0
[pip3] torchvision==0.22.0.dev20250308+cu128
[pip3] triton==3.2.0
[conda] Could not collect

cc @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @aakhundov @zou3519 @ydwu4 @bdhirsh @Chillee @drisspg @yanboliang @BoyuanFeng

@pytorch-bot added the module: higher order operators and module: pt2-dispatcher labels Mar 18, 2025
IvanKobzarev (Contributor) commented:

@Aleko2286
Could you please record the graphs in the failure case with TORCH_LOGS="graph_code,aot,output_code"?
How different are the 'validation' inputs from the training inputs? Do they have the same strides?

Aleko2286 (Author) commented:

The validation input is not different; it just happens to result in a different memory layout after optimization for this particular model. I now have other models where it always happens, though.

I also have a reproducible example now:

from typing import Any, Dict, Optional, Tuple


import torch
import torch.nn as nn
import torch.nn.attention.flex_attention
import torch.nn.functional as F


class Attention(nn.Module):
    def __init__(
        self,
        q_ch: int,
        kv_ch: Optional[int] = None,
        qk_embed_dim: Optional[int] = None,
        v_embed_dim: Optional[int] = None,
        linear_bias: bool = False,
        num_heads: int = 1,
    ):
        self.q_ch = q_ch
        self.kv_ch = kv_ch or self.q_ch
        self.qk_embed_dim = qk_embed_dim or self.q_ch
        self.v_embed_dim = v_embed_dim or self.kv_ch
        self.num_heads = num_heads
        assert (
            not self.qk_embed_dim % num_heads and not self.v_embed_dim % num_heads
        ), "The dimension of the embeddings in Attention must be divisible by the number of heads."
        super().__init__()

        self.q_proj = nn.Linear(self.q_ch, self.qk_embed_dim, bias=linear_bias)
        self.kv_proj = nn.Linear(
            self.kv_ch, self.qk_embed_dim + self.v_embed_dim, bias=linear_bias
        )
        self.o_proj = nn.Linear(self.v_embed_dim, self.q_ch, bias=linear_bias)

    def scaled_dot_product_attention(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        block_mask: torch.nn.attention.flex_attention.BlockMask,
    ) -> torch.Tensor:
        return torch.nn.attention.flex_attention.flex_attention(
            q, k, v, block_mask=block_mask
        )

    def forward(
        self, x: torch.Tensor, block_mask: torch.nn.attention.flex_attention.BlockMask
    ) -> torch.Tensor:
        q = self.q_proj(x)
        kv = self.kv_proj(x)
        k = kv[..., : self.qk_embed_dim]
        v = kv[..., self.qk_embed_dim :]
        q = q.reshape((q.shape[0], q.shape[1], self.num_heads, -1)).transpose(1, 2)
        k = k.reshape((k.shape[0], k.shape[1], self.num_heads, -1)).transpose(1, 2)
        v = v.reshape((v.shape[0], v.shape[1], self.num_heads, -1)).transpose(1, 2)
        return self.o_proj(
            self.scaled_dot_product_attention(q, k, v, block_mask)
            .transpose(1, 2)
            .reshape((x.shape[0], x.shape[1], -1))
        )


def create_block_mask(b: int, h: int, q_seq_len: int, k_seq_len: int):
    def mask(b, h, q_idx, k_idx):
        return q_idx >= k_idx

    return torch.nn.attention.flex_attention.create_block_mask(
        mask, b, h, q_seq_len, k_seq_len
    )


@torch.compile
class TestModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.attn_layer = Attention(768, num_heads=24)
        self.convs = nn.Sequential(
            *(nn.Conv2d(768, 768, 3, padding=1) for _ in range(3))
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        block_mask = create_block_mask(8, 24, 256, 256)
        y = self.convs(x).view((x.shape[0], x.shape[1], -1)).transpose(1, 2)
        o = self.attn_layer(y, block_mask)
        return self.convs(o.transpose(1, 2).reshape(x.shape))


if __name__ == "__main__":
    test_model = TestModel().cuda()
    x = torch.randn((8, 768, 16, 16), device="cuda", requires_grad=True)
    z = test_model(x)
    z.sum().backward()
    torch.cuda.synchronize()

All my models use convolutions somewhere, and inductor seems to like changing the memory format for those, while flex attention does not enforce its required memory format but only errors when the layout doesn't fit.

The example runs fine with TORCHINDUCTOR_LAYOUT_OPTIMIZATION=0. Nonetheless, I have attached the debug output for this example with TORCHINDUCTOR_LAYOUT_OPTIMIZATION=1.
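
For completeness, the same knob can be flipped from code (assuming the layout_optimization flag in torch._inductor.config, which is what the environment variable maps to):

import torch._inductor.config as inductor_config

# Workaround sketch: globally disable inductor's layout optimization before compiling.
# This also gives up the channels-last convolution optimization, so it is a
# workaround rather than a fix.
inductor_config.layout_optimization = False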

Since I was looking through the inductor config for options to disable this layout optimization, I also noticed that custom/Triton ops have flags to prevent the compiler from calling them with a changed memory layout:

# The default layout constraint for custom operators.
# This must be the name of one of the layout constraint tags
# (that is, one of {"needs_fixed_stride_order", "flexible_layout"}),
# If the custom op does not have a layout constraint tag already
# then we assume the following applies.
custom_op_default_layout_constraint: Literal[
    "needs_fixed_stride_order", "flexible_layout"
] = "needs_fixed_stride_order"
# The default layout constraint for user-defined triton kernels.
# See "The default layout constraint for custom operators" for options.
triton_kernel_default_layout_constraint: Literal[
    "needs_fixed_stride_order", "flexible_layout"
] = "needs_fixed_stride_order"

Something like that should probably also be implemented for flex attention.

debug_logs.zip

zou3519 (Contributor) commented Mar 18, 2025

Does flex attention have strict layout requirements on the inputs? @drisspg

drisspg (Contributor) commented Mar 18, 2025

The only restriction we have on layout is that the last dim's stride is 1, as the error message (torch._inductor.exc.InductorError: LoweringException: AssertionError: Query must be contiguous in the last dimension) suggests.

I have rarely seen this assert fail. Do we need a minimal size/strides assert like what is done for sdpa?

zou3519 (Contributor) commented Mar 19, 2025

Inductor is generally able to change the strides of inputs to any operator, unless there is special handling for it in FlexAttention. We recently added a "don't change the strides of inputs to operator X" mechanism that we can apply here if FlexAttention doesn't have something similar yet.

@jansel added the triaged label Mar 29, 2025
jansel (Contributor) commented Mar 29, 2025

@drisspg I think you are the right one to look into this, though it should be a one-line fix. We have a prim to force strides here:

force_stride_order = make_prim(
    "inductor_force_stride_order(Tensor input, SymInt[] stride) -> Tensor",
    eager_force_stride,
    doc="Force the stride order for input tensor. No-op if the input tensor already has the stride. Do a copy otherwise",
)

You could also add the torch._C.Tag.needs_fixed_stride_order tag to the op to get inductor to match eager strides.
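
For illustration, this is how the tag is attached to a regular custom op (a sketch with a hypothetical mylib::my_attention op; flex_attention is a higher-order op, so the actual wiring may differ):

import torch

# Hypothetical custom op: the needs_fixed_stride_order tag tells inductor to keep
# the eager-mode input strides instead of permuting them during layout optimization.
torch.library.define(
    "mylib::my_attention",
    "(Tensor q, Tensor k, Tensor v) -> Tensor",
    tags=(torch._C.Tag.needs_fixed_stride_order,),
)

@torch.library.impl("mylib::my_attention", "CompositeExplicitAutograd")
def my_attention(q, k, v):
    return torch.nn.functional.scaled_dot_product_attention(q, k, v)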

drisspg added a commit that referenced this issue Apr 21, 2025
… was causing assert failure"



# Summary
Fixes: #148827

This one is strange, I could have sworn this was a real constraint, but I verified and did some performance checks and this constraint isn't required. 

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov Chillee yanboliang BoyuanFeng

[ghstack-poisoned]
drisspg added a commit that referenced this issue Apr 21, 2025
…ert failure"



# Summary
Fixes: #148827

This one is strange, I could have sworn this was a real constraint, but I verified and did some performance checks and this constraint isn't required. 

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov Chillee yanboliang BoyuanFeng

[ghstack-poisoned]
wangkuiyi pushed a commit to wangkuiyi/pytorch that referenced this issue Apr 25, 2025