Triton Compilation Error in Generated Code due to possible float division in index #153375
@hypernicon

Description

🐛 Describe the bug

I got a compilation error from Triton (Triton 3.3.0 and PyTorch 2.7.0, running on an H100) when using torch.compile(model). It appears to come from one piece of a fused convolution backward kernel. I don't know exactly which code triggered it, but it's probably something like this module:

class ResidualUnit(torch.nn.Module):
    def __init__(self, channels: int, dilation: int, kernel_size: int = 7):
        super().__init__()
        # Causal padding: pad only on the left so conv1 preserves the length.
        padding: tuple[int, int] = (dilation * (kernel_size - 1), 0)
        self.padding = padding
        self.dilation = dilation

        self.conv1 = torch.nn.Conv1d(
            in_channels=channels,
            out_channels=channels,
            kernel_size=kernel_size,
            padding=0,
            dilation=dilation
        )
        self.conv2 = torch.nn.Conv1d(
            in_channels=channels,
            out_channels=channels,
            kernel_size=1,
            padding=0,
        )

    def forward(self, x):
        residual = x
        x = torch.nn.functional.pad(x, self.padding, "constant", 0)
        x = self.conv1(x)
        x = torch.nn.functional.elu(x)
        x = self.conv2(x)
        x = torch.nn.functional.elu(x)
        x = x + residual
        return x
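
I don't have an exact repro yet, but a driver along these lines should exercise the same fused convolution backward (just a sketch; the channel count, batch size, dilation, and sequence lengths here are guesses, since I don't know which configuration produced the failing kernel):

import torch

model = ResidualUnit(channels=32, dilation=3).cuda().to(torch.bfloat16)
compiled = torch.compile(model)

# Running two different sequence lengths makes Dynamo recompile with dynamic
# shapes, which is where the symbolic index expressions (and presumably the
# fractional coefficients) come from.
for length in (640, 1280):
    x = torch.randn(4, 32, length, device="cuda", dtype=torch.bfloat16,
                    requires_grad=True)
    out = compiled(x)
    out.sum().backward()  # the failing kernel is part of the conv backward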

The generated Triton code (from running with bf16 precision) was:

def triton_red_fused_convolution_backward_2(in_ptr0, out_ptr0, ks0, ks1, ks2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
    xnumel = 58
    rnumel = r0_numel
    RBLOCK: tl.constexpr = R0_BLOCK
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    r0_base = tl.arange(0, R0_BLOCK)[None, :]
    rbase = r0_base
    x0 = xindex
    _tmp2 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
    for r0_offset in range(0, r0_numel, R0_BLOCK):
        r0_index = r0_offset + r0_base
        r0_mask = r0_index < r0_numel
        roffset = r0_offset
        rindex = r0_index
        r0_1 = r0_index
        tmp0 = tl.load(in_ptr0 + (6*((((r0_1 + (3/29)*ks1*x0 + (160/29)*ks1*ks2*x0) // ks0) % ks1)) + 320*ks2*((((r0_1 + (3/29)*ks1*x0 + (160/29)*ks1*ks2*x0) // ks0) % ks1)) + (((r0_1 + (3/29)*ks1*x0 + (160/29)*ks1*ks2*x0) % ks0))), xmask & r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
        tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK])
        tmp3 = _tmp2 + tmp1
        _tmp2 = tl.where(r0_mask & xmask, tmp3, _tmp2)
    tmp2 = tl.sum(_tmp2, 1)[:, None]
    tl.store(out_ptr0 + (x0), tmp2, xmask)

I'm not 100% sure about this, but I don't see why the index expression contains fractional divisions, e.g. (3/29) and (160/29), or how they could ever work: / is true division, so those terms evaluate to floating point, and the resulting fp32 index then dies when added to in_ptr0, which is *bf16.
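
As a sanity check on the type issue (plain Python below, but Triton's / has the same true-division semantics, which matches the fp32 in the error), my guess is that Inductor printed a rational coefficient from its symbolic index expression verbatim into the kernel source:

print(3 / 29)   # 0.10344827586206896 -> becomes fp32 inside the kernel
print(3 // 29)  # 0 -> an integer, which is what an index term needs

# So a term like (3/29)*ks1*x0 evaluates to fp32, and adding an fp32 offset
# to the *bf16 pointer in_ptr0 is exactly the TypeError('unexpected type fp32')
# reported below.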

E0512 05:39:38.257000 29662 torch/_inductor/runtime/triton_heuristics.py:539] Triton compilation failed: triton_red_fused_convolution_backward_2647876  
E0512 05:39:38.257000 29662 torch/_inductor/runtime/triton_heuristics.py:539] def triton_red_fused_convolution_backward_2(in_ptr0, out_ptr0, ks0, ks1, ks2, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
E0512 05:39:38.257000 29662 torch/_inductor/runtime/triton_heuristics.py:539]     xnumel = 58
E0512 05:39:38.257000 29662 torch/_inductor/runtime/triton_heuristics.py:539]     rnumel = r0_numel
E0512 05:39:38.257000 29662 torch/_inductor/runtime/triton_heuristics.py:539]     RBLOCK: tl.constexpr = R0_BLOCK
E0512 05:39:38.257000 29662 torch/_inductor/runtime/triton_heuristics.py:539]     xoffset = tl.program_id(0) * XBLOCK
E0512 05:39:38.257000 29662 torch/_inductor/runtime/triton_heuristics.py:539]     xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
E0512 05:39:38.257000 29662 torch/_inductor/runtime/triton_heuristics.py:539]     xmask = xindex < xnumel
E0512 05:39:38.257000 29662 torch/_inductor/runtime/triton_heuristics.py:539]     r0_base = tl.arange(0, R0_BLOCK)[None, :]
E0512 05:39:38.257000 29662 torch/_inductor/runtime/triton_heuristics.py:539]     rbase = r0_base
E0512 05:39:38.257000 29662 torch/_inductor/runtime/triton_heuristics.py:539]     x0 = xindex
E0512 05:39:38.257000 29662 torch/_inductor/runtime/triton_heuristics.py:539]     _tmp2 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
E0512 05:39:38.257000 29662 torch/_inductor/runtime/triton_heuristics.py:539]     for r0_offset in range(0, r0_numel, R0_BLOCK):
E0512 05:39:38.257000 29662 torch/_inductor/runtime/triton_heuristics.py:539]         r0_index = r0_offset + r0_base
E0512 05:39:38.257000 29662 torch/_inductor/runtime/triton_heuristics.py:539]         r0_mask = r0_index < r0_numel
E0512 05:39:38.257000 29662 torch/_inductor/runtime/triton_heuristics.py:539]         roffset = r0_offset
E0512 05:39:38.257000 29662 torch/_inductor/runtime/triton_heuristics.py:539]         rindex = r0_index
E0512 05:39:38.257000 29662 torch/_inductor/runtime/triton_heuristics.py:539]         r0_1 = r0_index
E0512 05:39:38.257000 29662 torch/_inductor/runtime/triton_heuristics.py:539]         tmp0 = tl.load(in_ptr0 + (6*((((r0_1 + (3/29)*ks1*x0 + (160/29)*ks1*ks2*x0) // ks0) % ks1)) + 320*ks2*((((r0_1 + (3/29)*ks1*x0 + (160/29)*ks1*ks2*x0) // ks0) % ks1)) + (((r0_1 + (3/29)*ks1*x0 + (160/29)*ks1*ks2*x0) % ks0))), xmask & r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
E0512 05:39:38.257000 29662 torch/_inductor/runtime/triton_heuristics.py:539]         tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK])
E0512 05:39:38.257000 29662 torch/_inductor/runtime/triton_heuristics.py:539]         tmp3 = _tmp2 + tmp1
E0512 05:39:38.257000 29662 torch/_inductor/runtime/triton_heuristics.py:539]         _tmp2 = tl.where(r0_mask & xmask, tmp3, _tmp2)
E0512 05:39:38.257000 29662 torch/_inductor/runtime/triton_heuristics.py:539]     tmp2 = tl.sum(_tmp2, 1)[:, None]
E0512 05:39:38.257000 29662 torch/_inductor/runtime/triton_heuristics.py:539]     tl.store(out_ptr0 + (x0), tmp2, xmask)
E0512 05:39:38.257000 29662 torch/_inductor/runtime/triton_heuristics.py:539] 
E0512 05:39:38.257000 29662 torch/_inductor/runtime/triton_heuristics.py:539] metadata: {'signature': {'in_ptr0': '*bf16', 'out_ptr0': '*fp32', 'ks0': 'i32', 'ks1': 'i32', 'ks2': 'i32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': 0, 'constants': {'XBLOCK': 1, 'R0_BLOCK': 2048}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]]}], 'device_type': 'cuda', 'num_warps': 16, 'num_stages': 1, 'debug': True, 'cc': 90}
E0512 05:39:38.257000 29662 torch/_inductor/runtime/triton_heuristics.py:539] Traceback (most recent call last):
E0512 05:39:38.257000 29662 torch/_inductor/runtime/triton_heuristics.py:539]   File "/root/venv/lib/python3.10/site-packages/torch/_inductor/runtime/triton_heuristics.py", line 537, in _precompile_config
E0512 05:39:38.257000 29662 torch/_inductor/runtime/triton_heuristics.py:539]     binary = triton.compile(*compile_args, **compile_kwargs)
E0512 05:39:38.257000 29662 torch/_inductor/runtime/triton_heuristics.py:539]   File "/root/venv/lib/python3.10/site-packages/triton/compiler/compiler.py", line 278, in compile
E0512 05:39:38.257000 29662 torch/_inductor/runtime/triton_heuristics.py:539]     module = src.make_ir(options, codegen_fns, module_map, context)
E0512 05:39:38.257000 29662 torch/_inductor/runtime/triton_heuristics.py:539]   File "/root/venv/lib/python3.10/site-packages/triton/compiler/compiler.py", line 81, in make_ir
E0512 05:39:38.257000 29662 torch/_inductor/runtime/triton_heuristics.py:539]     return ast_to_ttir(self.fn, self, context=context, options=options, codegen_fns=codegen_fns,
E0512 05:39:38.257000 29662 torch/_inductor/runtime/triton_heuristics.py:539] triton.compiler.errors.CompilationError: at 18:39:
E0512 05:39:38.257000 29662 torch/_inductor/runtime/triton_heuristics.py:539]     xmask = xindex < xnumel
E0512 05:39:38.257000 29662 torch/_inductor/runtime/triton_heuristics.py:539]     r0_base = tl.arange(0, R0_BLOCK)[None, :]
E0512 05:39:38.257000 29662 torch/_inductor/runtime/triton_heuristics.py:539]     rbase = r0_base
E0512 05:39:38.257000 29662 torch/_inductor/runtime/triton_heuristics.py:539]     x0 = xindex
E0512 05:39:38.257000 29662 torch/_inductor/runtime/triton_heuristics.py:539]     _tmp2 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
E0512 05:39:38.257000 29662 torch/_inductor/runtime/triton_heuristics.py:539]     for r0_offset in range(0, r0_numel, R0_BLOCK):
E0512 05:39:38.257000 29662 torch/_inductor/runtime/triton_heuristics.py:539]         r0_index = r0_offset + r0_base
E0512 05:39:38.257000 29662 torch/_inductor/runtime/triton_heuristics.py:539]         r0_mask = r0_index < r0_numel
E0512 05:39:38.257000 29662 torch/_inductor/runtime/triton_heuristics.py:539]         roffset = r0_offset
E0512 05:39:38.257000 29662 torch/_inductor/runtime/triton_heuristics.py:539]         rindex = r0_index
E0512 05:39:38.257000 29662 torch/_inductor/runtime/triton_heuristics.py:539]         r0_1 = r0_index
E0512 05:39:38.257000 29662 torch/_inductor/runtime/triton_heuristics.py:539]         tmp0 = tl.load(in_ptr0 + (6*((((r0_1 + (3/29)*ks1*x0 + (160/29)*ks1*ks2*x0) // ks0) % ks1)) + 320*ks2*((((r0_1 + (3/29)*ks1*x0 + (160/29)*ks1*ks2*x0) // ks0) % ks1)) + (((r0_1 + (3/29)*ks1*x0 + (160/29)*ks1*ks2*x0) % ks0))), xmask & r0_mask, eviction_policy='evict_last', other=0.0).to(tl.float32)
E0512 05:39:38.257000 29662 torch/_inductor/runtime/triton_heuristics.py:539]                                        ^
E0512 05:39:38.257000 29662 torch/_inductor/runtime/triton_heuristics.py:539] TypeError('unexpected type fp32')
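
In case it helps anyone else hitting this: forcing static shapes may sidestep the symbolic indexing entirely (not verified as a general fix, and it costs a recompile per input shape):

# Possible workaround: specialize on static shapes so Inductor never emits
# symbolic (and here fractional) index coefficients. Recompiles per shape.
compiled = torch.compile(model, dynamic=False)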

Versions

Collecting environment information...
PyTorch version: 2.7.0+cu126
Is debug build: False
CUDA used to build PyTorch: 12.6
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.4 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: Could not collect
CMake version: Could not collect
Libc version: glibc-2.35

Python version: 3.10.12 (main, Feb 4 2025, 14:57:36) [GCC 11.4.0] (64-bit runtime)
Python platform: Linux-5.15.0-139-generic-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: 12.1.105
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA H100 80GB HBM3
Nvidia driver version: 535.247.01
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 46 bits physical, 57 bits virtual
Byte Order: Little Endian
CPU(s): 20
On-line CPU(s) list: 0-19
Vendor ID: GenuineIntel
Model name: Intel(R) Xeon(R) Platinum 8468
CPU family: 6
Model: 143
Thread(s) per core: 1
Core(s) per socket: 20
Socket(s): 1
Stepping: 8
BogoMIPS: 4200.00
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc arch_perfmon rep_good nopl xtopology cpuid tsc_known_freq pni pclmulqdq vmx ssse3 fma cx16 pdcm pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch cpuid_fault invpcid_single ssbd ibrs ibpb stibp ibrs_enhanced tpr_shadow vnmi flexpriority ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves avx_vnni avx512_bf16 wbnoinvd arat avx512vbmi umip pku ospke waitpkg avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq la57 rdpid bus_lock_detect cldemote movdiri movdir64b fsrm md_clear serialize tsxldtrk avx512_fp16 arch_capabilities
Virtualization: VT-x
Hypervisor vendor: KVM
Virtualization type: full
L1d cache: 640 KiB (20 instances)
L1i cache: 640 KiB (20 instances)
L2 cache: 80 MiB (20 instances)
NUMA node(s): 1
NUMA node0 CPU(s): 0-19
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Unknown: No mitigations
Vulnerability Reg file data sampling: Not affected
Vulnerability Retbleed: Not affected
Vulnerability Spec rstack overflow: Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Enhanced / Automatic IBRS; IBPB conditional; RSB filling; PBRSB-eIBRS SW sequence; BHI SW loop, KVM SW loop
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Mitigation; TSX disabled

Versions of relevant libraries:
[pip3] numpy==2.2.5
[pip3] nvidia-cublas-cu12==12.6.4.1
[pip3] nvidia-cuda-cupti-cu12==12.6.80
[pip3] nvidia-cuda-nvrtc-cu12==12.6.77
[pip3] nvidia-cuda-runtime-cu12==12.6.77
[pip3] nvidia-cudnn-cu12==9.5.1.17
[pip3] nvidia-cufft-cu12==11.3.0.4
[pip3] nvidia-curand-cu12==10.3.7.77
[pip3] nvidia-cusolver-cu12==11.7.1.2
[pip3] nvidia-cusparse-cu12==12.5.4.2
[pip3] nvidia-cusparselt-cu12==0.6.3
[pip3] nvidia-nccl-cu12==2.26.2
[pip3] nvidia-nvjitlink-cu12==12.6.85
[pip3] nvidia-nvtx-cu12==12.6.77
[pip3] torch==2.7.0
[pip3] torchaudio==2.7.0
[pip3] torchdata==0.11.0
[pip3] torchvision==0.22.0
[pip3] triton==3.3.0
[conda] Could not collect

cc @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @aakhundov

Metadata

Labels: module: inductor, needs reproduction, oncall: pt2, triaged