🚀 The feature, motivation and pitch
TorchInductor currently relies on hand-written Triton templates for matrix multiply variants such as `mm` and `bmm`. While these templates are effective, they make it difficult to fuse surrounding operations, even though Inductor supports prologue/epilogue fusion (see #142315).
Proposed Feature
This proposal enables Inductor to generate performant matrix multiplication kernels directly, without relying on the hand-written templates. A prototype implementation is available here:
main...nullplay:pytorch:jaeyeon_fuse
Implementation Overview
1. Emit Tensor Core `tl.dot` for Matrix Multiply Patterns
When Inductor is forced to generate a matmul kernel using `(A.unsqueeze(2) * B.unsqueeze(0)).sum(dim=1)`, it currently emits something like the following:
```python
@triton.jit
def triton_v1(A, B, C):
    YXBLOCK : tl.constexpr = 32 * 32
    RBLOCK : tl.constexpr = 32
    yxoffset = tl.program_id(0) * YXBLOCK
    yx = yxoffset + tl.arange(0, YXBLOCK)[:, None]     # (YX, 1)
    r_base = tl.arange(0, RBLOCK)[None, :]             # (1, R)
    y = yx // 2048
    x = yx % 2048
    acc = tl.full([YXBLOCK, RBLOCK], 0.0, tl.float32)  # (YX, R)
    for r_offset in range(0, 2048, RBLOCK):
        r = r_offset + r_base                          # (1, R)
        A_yr = tl.load(A + 2048 * y + r)               # (YX, R)
        B_rx = tl.load(B + 2048 * r + x)               # (YX, R)
        acc += A_yr * B_rx                             # (YX, R)
    acc = tl.sum(acc, 1)[:, None]                      # (YX, R) → (YX, 1)
    tl.store(C + yx, acc)
```
Here, matrix multiplication is expressed as a loop with elementwise multiplication and sum, without using `tl.dot`.
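For reference, the decomposed pattern above is numerically equivalent to a plain matmul; a quick check with arbitrary shapes:

```python
import torch

A = torch.randn(64, 48)   # (Y, R)
B = torch.randn(48, 32)   # (R, X)
# (Y, R, 1) * (1, R, X) -> (Y, R, X), then reduce over R.
decomposed = (A.unsqueeze(2) * B.unsqueeze(0)).sum(dim=1)
# Loose tolerance: the reduction order differs from torch.matmul's.
torch.testing.assert_close(decomposed, A @ B, rtol=1e-4, atol=1e-4)
```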
To address this, a new ops.dot
node is introduced in Inductor IR to capture the matmul pattern, enabling codegen to emit tl.dot
instead. The resulting kernel looks like:
```python
@triton.jit
def triton_v2(A, B, C):
    YBLOCK : tl.constexpr = 32
    XBLOCK : tl.constexpr = 32
    RBLOCK : tl.constexpr = 32
    yoffset = tl.program_id(1) * YBLOCK
    xoffset = tl.program_id(0) * XBLOCK
    y = yoffset + tl.arange(0, YBLOCK)[:, None, None]   # (Y, 1, 1)
    x = xoffset + tl.arange(0, XBLOCK)[None, :, None]   # (1, X, 1)
    r_base = tl.arange(0, RBLOCK)[None, None, :]        # (1, 1, R)
    acc = tl.full([YBLOCK, XBLOCK], 0.0, tl.float32)    # (Y, X)
    for r_offset in range(0, 2048, RBLOCK):
        r = r_offset + r_base                           # (1, 1, R)
        A_yxr = tl.load(A + 2048 * y + r)               # (Y, 1, R)
        B_yxr = tl.load(B + 2048 * r + x)               # (1, X, R)
        A_yr = tl.view(A_yxr, [YBLOCK, RBLOCK])         # (Y, R)
        B_xr = tl.view(B_yxr, [XBLOCK, RBLOCK])         # (X, R)
        acc += tl.dot(A_yr, tl.trans(B_xr))             # (Y, R) x (R, X) → (Y, X)
    acc = acc[:, :, None]                               # (Y, X) → (Y, X, 1)
    tl.store(C + 2048 * y + x, acc)
```
This version uses `tl.dot` and reshapes the inputs appropriately, with the accumulator remaining output-stationary.
2. Lazy Broadcasting to Avoid Reshape and Transpose
To match the performance of PyTorch's hand-written Triton templates, it's important to avoid reshapes and transposes. Instead of eagerly broadcasting across all axes (i.e., assigning each loop dimension to a distinct Triton axis), we lazily broadcast only the reduction axis (RBLOCK) to align with `tl.dot` semantics. For example:
```python
@triton.jit
def triton_v3(A, B, C):
    YBLOCK : tl.constexpr = 32
    XBLOCK : tl.constexpr = 32
    RBLOCK : tl.constexpr = 32
    yoffset = tl.program_id(1) * YBLOCK
    xoffset = tl.program_id(0) * XBLOCK
    y = yoffset + tl.arange(0, YBLOCK)[:, None]         # (Y, 1) -- eager broadcast
    x = xoffset + tl.arange(0, XBLOCK)[None, :]         # (1, X) -- eager broadcast
    r_base = tl.arange(0, RBLOCK)                       # (R)
    acc = tl.full([YBLOCK, XBLOCK], 0.0, tl.float32)    # (Y, X)
    for r_offset in range(0, 2048, RBLOCK):
        r = r_offset + r_base
        A_yr = tl.load(A + 2048 * y + r[None, :])       # (Y, R) -- lazy broadcast
        B_rx = tl.load(B + 2048 * r[:, None] + x)       # (R, X) -- lazy broadcast
        acc += tl.dot(A_yr, B_rx)                       # (Y, R) x (R, X) → (Y, X)
    tl.store(C + 2048 * y + x, acc)
```
This approach eliminates the need for transposing and reshaping inputs, while still matching the expected layout for `tl.dot`.
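For concreteness, a minimal host-side harness for the `triton_v3` sketch above might look like the following. This is illustrative only; it assumes Triton is installed (with `import triton` / `import triton.language as tl` in scope for the kernel), a CUDA device, and the 2048×2048 fp32 problem size hard-coded in the kernel:

```python
import torch
import triton  # the kernel above also needs `import triton.language as tl`

def matmul_v3(A, B):
    # Launch one program per 32x32 output tile (block sizes are fixed inside the kernel).
    C = torch.empty(2048, 2048, device="cuda", dtype=torch.float32)
    grid = (2048 // 32, 2048 // 32)  # program_id(0) walks X tiles, program_id(1) walks Y tiles
    triton_v3[grid](A, B, C)
    return C

A = torch.randn(2048, 2048, device="cuda")
B = torch.randn(2048, 2048, device="cuda")
# Loose tolerance: tl.dot on fp32 inputs may use TF32 on recent GPUs.
torch.testing.assert_close(matmul_v3(A, B), A @ B, rtol=1e-2, atol=1e-1)
```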
A nice aspect of this feature is that it enables automatic fusion of operations around `tl.dot`, without requiring major changes to Inductor. For instance, consider the following PyTorch program:
```python
# z[w[m], n] += x[w[m], k] * y[k, n] + 3
def f(x, y, z, w):
    intm = x[w, :] @ y + 3
    return z.index_add_(dim=0, index=w, source=intm)
```
With this feature enabled, TorchInductor generates a fully fused Triton kernel that uses Tensor Cores:
```python
@triton.jit
def triton_red_fused_add_index_add_mm_0(in_ptr0, in_ptr1, in_ptr2, out_ptr1, ynumel, xnumel, r0_numel, YBLOCK : tl.constexpr, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
    ynumel = 128
    xnumel = 128
    r0_numel = 128
    rnumel = r0_numel
    RBLOCK: tl.constexpr = R0_BLOCK
    yoffset = tl.program_id(1) * YBLOCK
    yindex = yoffset + tl.arange(0, YBLOCK)[:,None]
    ymask = yindex < ynumel
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[None,:]
    xmask = xindex < xnumel
    r0_base = tl.arange(0, R0_BLOCK)
    rbase = r0_base
    y0 = yindex
    tmp0 = tl.load(in_ptr0 + (y0), ymask, eviction_policy='evict_last')
    x1 = xindex
    _tmp10 = tl.full([YBLOCK, XBLOCK], 0, tl.float32)
    for r0_offset in range(0, r0_numel, R0_BLOCK):
        r0_index = r0_offset + r0_base
        r0_mask = r0_index < r0_numel
        roffset = r0_offset
        rindex = r0_index
        r0_2 = r0_index
        tmp7 = tl.load(in_ptr2 + (x1 + 128*r0_2[:,None]), r0_mask[:,None] & xmask, eviction_policy='evict_last', other=0.0)
        tmp1 = 128
        tmp2 = tmp0 + tmp1
        tmp3 = tmp0 < 0
        tmp4 = tl.where(tmp3, tmp2, tmp0)
        tl.device_assert(((0 <= tmp4) & (tmp4 < 128)) | ~(ymask), "index out of bounds: 0 <= tmp4 < 128")
        tmp6 = tl.load(in_ptr1 + (r0_2[None,:] + 128*tmp4), ymask & r0_mask[None,:], eviction_policy='evict_first', other=0.0)
        tmp8 = tl.dot(tmp6, tmp7, allow_tf32=False)
        tmp9 = tl.broadcast_to(tmp8, [YBLOCK, XBLOCK])
        tmp11 = _tmp10 + tmp9
        _tmp10 = tmp11
    tmp10 = _tmp10
    tmp12 = 128
    tmp13 = tmp0 + tmp12
    tmp14 = tmp0 < 0
    tmp15 = tl.where(tmp14, tmp13, tmp0)
    tl.device_assert(((0 <= tmp15) & (tmp15 < 128)) | ~(ymask), "index out of bounds: 0 <= tmp15 < 128")
    tmp17 = 3.0
    tmp18 = tmp10 + tmp17
    tl.atomic_add(out_ptr1 + (x1 + 128*tmp15), tmp18, ymask & xmask, sem='relaxed')
```
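For reference, one way to exercise this example end to end is sketched below. It assumes the prototype branch is installed, a CUDA device, the `use_dot_reduction` flag described under "How to Enable", and the 128×128 shapes baked into the kernel above; the tolerances are illustrative.

```python
import torch
import torch._inductor.config as inductor_config

inductor_config.triton.use_dot_reduction = True  # prototype flag, see "How to Enable" below

def f(x, y, z, w):  # same definition as above
    intm = x[w, :] @ y + 3
    return z.index_add_(dim=0, index=w, source=intm)

x = torch.randn(128, 128, device="cuda")
y = torch.randn(128, 128, device="cuda")
w = torch.randint(0, 128, (128,), device="cuda")

ref = f(x, y, torch.zeros(128, 128, device="cuda"), w)                 # eager reference
out = torch.compile(f)(x, y, torch.zeros(128, 128, device="cuda"), w)  # fused kernel path
# Atomic adds can reorder floating-point accumulation, so compare with a small tolerance.
torch.testing.assert_close(out, ref, rtol=1e-4, atol=1e-4)
```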
Performance and Benefits
- Matches the performance of the hand-written `mm` and `bmm` templates on both `fp16` and `fp32`
- Can generate fused kernels for compound expressions such as `A @ B + C @ D`
- Achieves up to 5–10× speedup on gather–matmul–scatter patterns by eliminating intermediate tensors
- Supports multiple dtypes (`fp16`, `fp32`, `bf16`), though not exhaustively tested
- (Maybe) a more maintainable alternative to hardcoded templates
How to Enable
You can test this feature by setting:

```python
torch._inductor.config.triton.use_dot_reduction = True
```
Prototype fork: https://github.com/nullplay/pytorch/tree/jaeyeon_fuse
Test cases: https://github.com/nullplay/pytorch/blob/jaeyeon_fuse/test/inductor/test_dot_reduction.py
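As an additional illustration (not part of the prototype's test suite), a minimal sketch for trying the flag on the compound-expression case listed under Performance and Benefits, assuming the prototype branch and a CUDA device:

```python
import torch
import torch._inductor.config as inductor_config

inductor_config.triton.use_dot_reduction = True  # enable native tl.dot codegen (prototype flag)

def g(a, b, c, d):
    return a @ b + c @ d  # compound matmul expression from the benefits list above

a, b, c, d = (torch.randn(512, 512, device="cuda", dtype=torch.float16) for _ in range(4))
out = torch.compile(g)(a, b, c, d)
torch.testing.assert_close(out, g(a, b, c, d), rtol=2e-2, atol=2e-2)  # loose fp16 tolerance
```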
Since this is a prototype, there are some limitations:
1. The prototype is implemented in a hacky way and needs refactoring.
2. Excessive fusion can sometimes reduce performance (better fusion heuristics are needed).
3. Autotuning for these kernels needs to be made more robust.
I would appreciate feedback from PyTorch developers on this direction. Do you think enabling native `tl.dot` codegen in Inductor is a reasonable and maintainable path forward for high-performance matmul fusion?
@jansel @eellison @drisspg @blaine-rister
Alternatives
No response
Additional context
No response
cc @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @aakhundov