Inductor may permute inputs to flex attention, leading to assertion error #148827
@Aleko2286
The validation is not different, it just happens to result in a different memory layout after optimization for this particular model. I have other models now where it always happens though. I also have a reproducible example now:

from typing import Any, Dict, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.attention.flex_attention
import torch.nn.functional as F


class Attention(nn.Module):
    def __init__(
        self,
        q_ch: int,
        kv_ch: Optional[int] = None,
        qk_embed_dim: Optional[int] = None,
        v_embed_dim: Optional[int] = None,
        linear_bias: bool = False,
        num_heads: int = 1,
    ):
        self.q_ch = q_ch
        self.kv_ch = kv_ch or self.q_ch
        self.qk_embed_dim = qk_embed_dim or self.q_ch
        self.v_embed_dim = v_embed_dim or self.kv_ch
        self.num_heads = num_heads
        assert (
            not self.qk_embed_dim % num_heads and not self.v_embed_dim % num_heads
        ), "The dimension of the embeddings in Attention must be divisible by the number of heads."
        super().__init__()
        self.q_proj = nn.Linear(self.q_ch, self.qk_embed_dim, bias=linear_bias)
        self.kv_proj = nn.Linear(
            self.kv_ch, self.qk_embed_dim + self.v_embed_dim, bias=linear_bias
        )
        self.o_proj = nn.Linear(self.v_embed_dim, self.q_ch, bias=linear_bias)

    def scaled_dot_product_attention(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        block_mask: torch.nn.attention.flex_attention.BlockMask,
    ) -> torch.Tensor:
        return torch.nn.attention.flex_attention.flex_attention(
            q, k, v, block_mask=block_mask
        )

    def forward(
        self, x: torch.Tensor, block_mask: torch.nn.attention.flex_attention.BlockMask
    ) -> torch.Tensor:
        q = self.q_proj(x)
        kv = self.kv_proj(x)
        k = kv[..., : self.qk_embed_dim]
        v = kv[..., self.qk_embed_dim :]
        q = q.reshape((q.shape[0], q.shape[1], self.num_heads, -1)).transpose(1, 2)
        k = k.reshape((k.shape[0], k.shape[1], self.num_heads, -1)).transpose(1, 2)
        v = v.reshape((v.shape[0], v.shape[1], self.num_heads, -1)).transpose(1, 2)
        return self.o_proj(
            self.scaled_dot_product_attention(q, k, v, block_mask)
            .transpose(1, 2)
            .reshape((x.shape[0], x.shape[1], -1))
        )


def create_block_mask(b: int, h: int, q_seq_len: int, k_seq_len: int):
    def mask(b, h, q_idx, k_idx):
        return q_idx >= k_idx

    return torch.nn.attention.flex_attention.create_block_mask(
        mask, b, h, q_seq_len, k_seq_len
    )


@torch.compile
class TestModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.attn_layer = Attention(768, num_heads=24)
        self.convs = nn.Sequential(
            *(nn.Conv2d(768, 768, 3, padding=1) for _ in range(3))
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        block_mask = create_block_mask(8, 24, 256, 256)
        y = self.convs(x).view((x.shape[0], x.shape[1], -1)).transpose(1, 2)
        o = self.attn_layer(y, block_mask)
        return self.convs(o.transpose(1, 2).reshape(x.shape))


if __name__ == "__main__":
    test_model = TestModel().cuda()
    x = torch.randn((8, 768, 16, 16), device="cuda", requires_grad=True)
    z = test_model(x)
    z.sum().backward()
    torch.cuda.synchronize()

All my models use convolution somewhere, and inductor seems to like to change the memory format for those, while flex attention does not enforce its required memory format, but only errors if it doesn't fit. The example runs fine with TORCHINDUCTOR_LAYOUT_OPTIMIZATION=0. Nonetheless, I have attached the debug output for this example for TORCHINDUCTOR_LAYOUT_OPTIMIZATION=1. Since I was looking through inductor config for options to disable this layout optimization, I also noticed that custom/triton ops seem to have some flags to prevent the compiler from calling them with a changed memory layout:

pytorch/torch/_inductor/config.py Lines 123 to 136 in e0e8639
Something like that should probably also be implemented for flex attention.
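For anyone hitting this before a fix lands, a minimal sketch of the workaround mentioned above (disabling inductor's layout optimization). Only the environment variable is taken from the comment; the torch._inductor.config.layout_optimization attribute is my assumption about the config knob backing it:

import os

# Set before torch is imported so inductor's config picks it up.
os.environ["TORCHINDUCTOR_LAYOUT_OPTIMIZATION"] = "0"

import torch

# Alternative (assumption: this is the config attribute backing the env var);
# flip it before calling torch.compile:
# torch._inductor.config.layout_optimization = False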
Does flex attention have strict layout requirements on the inputs? @drisspg
The only restriction we have on layout is that the last dim's stride is 1. I have rarely seen this assert fail, though; do we need a minimal size/stride assert like what is done for sdpa?
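A minimal version of such a check, written as my own sketch of what is being described here rather than actual PyTorch code (the function name is made up):

import torch

def check_flex_attention_layout(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> None:
    # Sketch of a minimal size/stride assert in the spirit of the SDPA checks:
    # flex_attention's only layout requirement is that the innermost stride is 1.
    for name, t in (("query", q), ("key", k), ("value", v)):
        assert t.stride(-1) == 1, f"{name} must be contiguous in the last dimension"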
Inductor is generally able to change the strides of inputs to any operator, unless there is special handling for it in FlexAttention. We recently added a "don't change the strides of inputs to operator X" mechanism that we can apply here if FlexAttention doesn't have something similar yet.
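For context, a sketch of what such an opt-out looks like for user-defined ops; I am assuming the mechanism meant here is the needs_fixed_stride_order operator tag, and the op below is purely illustrative:

import torch

# Hypothetical custom op. The tag asks inductor to keep the eager-mode strides of
# the inputs instead of choosing its own layout (assumption: this tag is the
# "don't change the strides of inputs to operator X" mechanism referred to above).
torch.library.define(
    "mylib::passthrough",
    "(Tensor x) -> Tensor",
    tags=(torch.Tag.needs_fixed_stride_order,),
)

@torch.library.impl("mylib::passthrough", "CompositeExplicitAutograd")
def passthrough(x: torch.Tensor) -> torch.Tensor:
    return x.clone()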
@drisspg I think you are the right one to look into this, though it should be a one-line fix. We have a prim to force strides here:

pytorch/torch/_inductor/inductor_prims.py Lines 97 to 101 in 3b00ff8
You could also add the
… was causing assert failure"

# Summary
Fixes: #148827

This one is strange, I could have sworn this was a real constraint, but I verified and did some performance checks and this constraint isn't required.
Fixes: pytorch#148827
Pull Request resolved: pytorch#151959
Approved by: https://github.com/Chillee
ghstack dependencies: pytorch#151846
🐛 Describe the bug
For flex attention, inputs must be contiguous, but inductor seems to permute inputs under certain conditions, which then results in an assertion error.
When using an Attention layer looking somewhat like this:
I get a LoweringException under certain conditions. Sadly, it does not reproduce as a standalone example. In my model, this only happens if I do a validation iteration before doing a training iteration. If I directly start training, the compilation result seems to be different and the training runs without any issue. From the error message, it looks like inductor swaps the original memory format (B, L, H, C) [transposed to (B, H, L, C)] to (B, L, C, H), which results in non-contiguous q, k, and v. (B=1, H=24, L=1904, C=32)
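To make the stride mismatch concrete, here is an illustration using the sizes from the error log below; it is my own example, not code from the model:

import torch

B, H, L, C = 1, 24, 1904, 32

# Contiguous (B, H, L, C) query: the innermost stride is 1, which flex_attention accepts.
q_ok = torch.randn(B, H, L, C)
assert q_ok.stride() == (H * L * C, L * C, C, 1)

# The same logical (B, H, L, C) tensor backed by a (B, L, C, H) buffer, i.e. the layout
# inductor chose here: the strides become (1462272, 1, 768, 24) and the last-dim stride
# is no longer 1, which trips the lowering assert.
q_bad = torch.randn(B, L, C, H).permute(0, 3, 1, 2)
assert q_bad.shape == (B, H, L, C)
assert q_bad.stride() == (L * C * H, 1, C * H, H)  # (1462272, 1, 768, 24)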
There seems to be no straightforward way to fix this. For example, the following code will also run into the same problem:
Workarounds like this exist:
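The workaround snippet itself is not preserved in this page. A workaround of that general kind, written as my own sketch rather than the reporter's code, is to pin the layout on the eager side right before the flex_attention call (whether a given inductor version respects this needs to be verified):

import torch
import torch.nn.attention.flex_attention

def sdpa_contiguous(q, k, v, block_mask):
    # Sketch: force a contiguous (B, H, L, C) layout before handing q/k/v to
    # flex_attention so the compiler is not given a permuted buffer.
    q, k, v = q.contiguous(), k.contiguous(), v.contiguous()
    return torch.nn.attention.flex_attention.flex_attention(
        q, k, v, block_mask=block_mask
    )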
I think flex_attention should probably not error on non-contiguous tensors, but rather enforce contiguity itself. In any case, this is unexpected behavior from a user's perspective, since even when the eager version is contiguous, compilation may fail due to the query not being contiguous.
Error logs
torch._inductor.exc.InductorError: LoweringException: AssertionError: Query must be contiguous in the last dimension
target: flex_attention
args[0]: TensorBox(StorageBox(
ComputedBuffer(name='buf195', layout=FixedLayout('cuda:0', torch.float16, size=[1, 24, 1904, 32], stride=[1462272, 1, 768, 24]), data=Pointwise(device=device(type='cuda', index=0), dtype=torch.float16, inner_fn=<function pointwise_cat..inner_fn at 0x7f037043fd80>, ranges=[1, 24, 1904, 32]))
))
args[1]: TensorBox(StorageBox(
ComputedBuffer(name='buf196', layout=FixedLayout('cuda:0', torch.float16, size=[1, 24, 1904, 32], stride=[1462272, 1, 768, 24]), data=Pointwise(device=device(type='cuda', index=0), dtype=torch.float16, inner_fn=<function pointwise_cat..inner_fn at 0x7f0370485e40>, ranges=[1, 24, 1904, 32]))
))
args[2]: TensorBox(
ReinterpretView(
StorageBox(
ExternKernelOut(
python_kernel_name='extern_kernels.mm',
name=buf192,
layout=FixedLayout('cuda:0', torch.float16, size=[1904, 1536], stride=[1536, 1]),
inputs=[ReinterpretView(
StorageBox(
ComputedBuffer(name='buf189', layout=FixedLayout('cuda:0', torch.float16, size=[1, 1904, 768], stride=[1462272, 768, 1]), data=Pointwise(device=device(type='cuda', index=0), dtype=torch.float16, inner_fn=<function make_pointwise..inner..inner_fn at 0x7f03703a4e00>, ranges=[1, 1904, 768]))
),
FixedLayout('cuda:0', torch.float16, size=[1904, 768], stride=[768, 1]),
origins=OrderedSet([mm_1])
), ComputedBuffer(name='buf191', layout=FixedLayout('cuda:0', torch.float16, size=[768, 1536], stride=[1, 768]), data=Pointwise(device=device(type='cuda', index=0), dtype=torch.float16, inner_fn=<function BaseView.make_loader..loader at 0x7f03703a49a0>, ranges=[768, 1536]))],
constant_args=(),
kwargs={},
output_view=None,
python_kernel_name=extern_kernels.mm,
cpp_kernel_name=at::mm_out,
ordered_kwargs_for_cpp_kernel=(),
op_overload=None,
arg_properties=[{}, {}],
kwarg_properties=None,
unbacked_bindings={},
mutation_outputs=[],
origin_node=mm_1,
origins=OrderedSet([mm_1])
)
),
FixedLayout('cuda:0', torch.float16, size=[1, 24, 1904, 32], stride=[0, 32, 1536, 1], offset=768),
origins=OrderedSet([permute_60])
)
)
args[3]: Subgraph(name='sdpa_score0', graph_module=(), graph=None)
args[4]: (1904, 1904, TensorBox(StorageBox(
InputBuffer(name='primals_95', layout=FixedLayout('cuda:0', torch.int32, size=[4, 24, 15], stride=[360, 15, 1]))
)), TensorBox(StorageBox(
InputBuffer(name='primals_94', layout=FixedLayout('cuda:0', torch.int32, size=[4, 24, 15, 15], stride=[5400, 225, 15, 1]))
)), TensorBox(StorageBox(
InputBuffer(name='primals_96', layout=FixedLayout('cuda:0', torch.int32, size=[4, 24, 15], stride=[360, 15, 1]))
)), TensorBox(StorageBox(
InputBuffer(name='primals_97', layout=FixedLayout('cuda:0', torch.int32, size=[4, 24, 15, 15], stride=[5400, 225, 15, 1]))
)), TensorBox(StorageBox(
InputBuffer(name='primals_98', layout=FixedLayout('cuda:0', torch.int32, size=[4, 24, 15], stride=[360, 15, 1]))
)), TensorBox(StorageBox(
InputBuffer(name='primals_99', layout=FixedLayout('cuda:0', torch.int32, size=[4, 24, 15, 15], stride=[5400, 225, 15, 1]))
)), TensorBox(StorageBox(
InputBuffer(name='primals_100', layout=FixedLayout('cuda:0', torch.int32, size=[4, 24, 15], stride=[360, 15, 1]))
)), TensorBox(StorageBox(
InputBuffer(name='primals_101', layout=FixedLayout('cuda:0', torch.int32, size=[4, 24, 15, 15], stride=[5400, 225, 15, 1]))
)), 128, 128, Subgraph(name='sdpa_mask0', graph_module=(), graph=None))
args[5]: 0.17677669529663687
args[6]: {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'WRITE_DQ': True, 'OUTPUT_LOGSUMEXP': True}
args[7]: ()
args[8]: ()
Versions
PyTorch version: 2.7.0.dev20250308+cu128
Is debug build: False
CUDA used to build PyTorch: 12.8
ROCM used to build PyTorch: N/A
OS: Ubuntu 24.04.1 LTS (x86_64)
GCC version: (Ubuntu 13.3.0-6ubuntu2~24.04) 13.3.0
Clang version: Could not collect
CMake version: Could not collect
Libc version: glibc-2.39
Python version: 3.12.3 (main, Jan 17 2025, 18:03:48) [GCC 13.3.0] (64-bit runtime)
Python platform: Linux-6.13.5-1-default-x86_64-with-glibc2.39
Is CUDA available: True
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA GeForce RTX 4070
Nvidia driver version: 570.124.04
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 48 bits physical, 48 bits virtual
Byte Order: Little Endian
CPU(s): 12
On-line CPU(s) list: 0-11
Vendor ID: AuthenticAMD
Model name: AMD Ryzen 5 5600X 6-Core Processor
CPU family: 25
Model: 33
Thread(s) per core: 2
Core(s) per socket: 6
Socket(s): 1
Stepping: 0
Frequency boost: enabled
CPU(s) scaling MHz: 100%
CPU max MHz: 4651.0000
CPU min MHz: 550.0000
BogoMIPS: 7400.31
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 hw_pstate ssbd mba ibrs ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local user_shstk clzero irperf xsaveerptr rdpru wbnoinvd arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif v_spec_ctrl umip pku ospke vaes vpclmulqdq rdpid overflow_recov succor smca fsrm debug_swap
Virtualization: AMD-V
L1d cache: 192 KiB (6 instances)
L1i cache: 192 KiB (6 instances)
L2 cache: 3 MiB (6 instances)
L3 cache: 32 MiB (1 instance)
NUMA node(s): 1
NUMA node0 CPU(s): 0-11
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Reg file data sampling: Not affected
Vulnerability Retbleed: Not affected
Vulnerability Spec rstack overflow: Mitigation; Safe RET
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Retpolines; IBPB conditional; IBRS_FW; STIBP always-on; RSB filling; PBRSB-eIBRS Not affected; BHI Not affected
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected
Versions of relevant libraries:
[pip3] mypy-extensions==1.0.0
[pip3] numpy==2.2.2
[pip3] nvidia-cublas-cu12==12.8.3.14
[pip3] nvidia-cuda-cupti-cu12==12.8.57
[pip3] nvidia-cuda-nvrtc-cu12==12.8.61
[pip3] nvidia-cuda-runtime-cu12==12.8.57
[pip3] nvidia-cudnn-cu12==9.7.1.26
[pip3] nvidia-cufft-cu12==11.3.3.41
[pip3] nvidia-curand-cu12==10.3.9.55
[pip3] nvidia-cusolver-cu12==11.7.2.55
[pip3] nvidia-cusparse-cu12==12.5.7.53
[pip3] nvidia-cusparselt-cu12==0.6.3
[pip3] nvidia-nccl-cu12==2.25.1
[pip3] nvidia-nvjitlink-cu12==12.8.61
[pip3] nvidia-nvtx-cu12==12.8.55
[pip3] pytorch-triton==3.2.0+git4b3bb1f8
[pip3] torch==2.7.0.dev20250308+cu128
[pip3] torchaudio==2.6.0.dev20250308+cu128
[pip3] torchinfo==1.8.0
[pip3] torchvision==0.22.0.dev20250308+cu128
[pip3] triton==3.2.0
[conda] Could not collect
cc @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @aakhundov @zou3519 @ydwu4 @bdhirsh @Chillee @drisspg @yanboliang @BoyuanFeng