Enable TorchInductor to Generate Matmuls Natively via tl.dot #151705
@nullplay

Description

🚀 The feature, motivation and pitch

TorchInductor currently relies on hand-written Triton templates for matrix multiply variants such as mm and bmm.

While these templates are effective, they make it difficult to fuse surrounding operations, even with Inductor's prologue/epilogue fusion support (see #142315).

Proposed Feature

This proposal enables Inductor to generate performant matrix multiplication kernels directly, without relying on the hand-written templates. A prototype implementation is available here:
main...nullplay:pytorch:jaeyeon_fuse

Implementation Overview

1. Emit Tensor Core tl.dot for Matrix Multiply Patterns

When Inductor is forced to generate a matmul kernel using (A.unsqueeze(2) * B.unsqueeze(0)).sum(dim=1), it currently emits something like the following:

@triton.jit
def triton_v1(A, B, C):
    YXBLOCK : tl.constexpr = 32 * 32
    RBLOCK : tl.constexpr = 32

    yxoffset = tl.program_id(0) * YXBLOCK
    yx = yxoffset + tl.arange(0, YXBLOCK)[:, None]       # (YX, 1) 
    r_base = tl.arange(0, RBLOCK)[None, :]               # (1, R) 

    y = yx // 2048
    x = yx % 2048

    acc = tl.full([YXBLOCK, RBLOCK], 0.0, tl.float32)        # (YX, R)
    for r_offset in range(0, 2048, RBLOCK):
        r = r_offset + r_base                                # (1, R)
        A_yr = tl.load(A + 2048 * y + r)                     # (YX, R)
        B_rx = tl.load(B + 2048 * r + x)                     # (YX, R)
        acc += A_yr * B_rx                                   # (YX, R)

    acc = tl.sum(acc, 1)[:, None]                            # (YX, R) → (YX, 1)
    tl.store(C + yx, acc)

Here, matrix multiplication is expressed as a loop with elementwise multiplication and sum, without using tl.dot.
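For reference, a minimal PyTorch-level repro that exercises this decomposition could look like the following (a sketch, not taken from the prototype; shapes chosen to match the 2048 strides above):

import torch

# Sketch: express C = A @ B as a broadcast multiply plus a sum, so Inductor
# lowers it as a pointwise op followed by a reduction rather than dispatching
# to a hand-written matmul template.
def decomposed_mm(A, B):
    # A: (M, K) -> (M, K, 1); B: (K, N) -> (1, K, N); product: (M, K, N)
    # summing over dim=1 (the K axis) gives the (M, N) matmul result
    return (A.unsqueeze(2) * B.unsqueeze(0)).sum(dim=1)

A = torch.randn(2048, 2048, device="cuda")
B = torch.randn(2048, 2048, device="cuda")
out = torch.compile(decomposed_mm)(A, B)   # numerically matches A @ B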

To address this, a new ops.dot node is introduced in Inductor IR to capture the matmul pattern, enabling codegen to emit tl.dot instead. The resulting kernel looks like:

@triton.jit
def triton_v2(A, B, C):
    YBLOCK : tl.constexpr = 32
    XBLOCK : tl.constexpr = 32
    RBLOCK : tl.constexpr = 32

    yoffset = tl.program_id(1) * YBLOCK
    xoffset = tl.program_id(0) * XBLOCK
    y = yoffset + tl.arange(0, YBLOCK)[:, None, None]     # (Y, 1, 1)
    x = xoffset + tl.arange(0, XBLOCK)[None, :, None]     # (1, X, 1)
    r_base = tl.arange(0, RBLOCK)[None, None, :]          # (1, 1, R)

    acc = tl.full([YBLOCK, XBLOCK], 0.0, tl.float32)      # (Y, X)
    for r_offset in range(0, 2048, RBLOCK):
        r = r_offset + r_base                             # (1, 1, R)
        A_yxr = tl.load(A + 2048 * y + r)                 # (Y, 1, R)
        B_yxr = tl.load(B + 2048 * r + x)                 # (1, X, R)

        A_yr = tl.view(A_yxr, [YBLOCK, RBLOCK])           # (Y, R)
        B_xr = tl.view(B_yxr, [XBLOCK, RBLOCK])           # (X, R)
        acc += tl.dot(A_yr, tl.trans(B_xr))               # (Y, R) x (R, X) → (Y, X)

    acc = acc[:, :, None]                                 # (Y, X) → (Y, X, 1)
    tl.store(C + 2048 * y + x, acc)

This version uses tl.dot and reshapes inputs appropriately, with the output accumulator remaining output-stationary.
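For intuition, the per-tile update in triton_v2 is just a matmul of the loaded tiles over the shared reduction axis; the following PyTorch check (illustrative only, not part of the prototype) spells out the equivalence:

import torch

# Illustrative tile-level check: acc += tl.dot(A_yr, tl.trans(B_xr)) computes
# A_tile @ B_tile.T, i.e. a contraction over the shared reduction axis r.
YBLOCK, XBLOCK, RBLOCK = 32, 32, 32
A_tile = torch.randn(YBLOCK, RBLOCK)              # (Y, R)
B_tile = torch.randn(XBLOCK, RBLOCK)              # (X, R)
ref = torch.einsum("yr,xr->yx", A_tile, B_tile)   # (Y, X)
torch.testing.assert_close(ref, A_tile @ B_tile.T)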

2. Lazy Broadcasting to Avoid Reshape and Transpose

To match the performance of PyTorch’s hand-written Triton templates, it's important to avoid reshapes and transposes. Instead of eagerly broadcasting across all axes (i.e., assigning each loop dimension to a distinct Triton axis), we lazily broadcast only the reduction axis (RBLOCK) to align with tl.dot semantics. For example:

@triton.jit
def triton_v3(A, B, C):
    YBLOCK : tl.constexpr = 32
    XBLOCK : tl.constexpr = 32
    RBLOCK : tl.constexpr = 32

    yoffset = tl.program_id(1) * YBLOCK
    xoffset = tl.program_id(0) * XBLOCK
    y = yoffset + tl.arange(0, YBLOCK)[:, None]           # (Y, 1) -- eager broadcast
    x = xoffset + tl.arange(0, XBLOCK)[None, :]           # (1, X) -- eager broadcast
    r_base = tl.arange(0, RBLOCK)                         # (R)

    acc = tl.full([YBLOCK, XBLOCK], 0.0, tl.float32)      # (Y, X)
    for r_offset in range(0, 2048, RBLOCK):
        r = r_offset + r_base

        A_yr = tl.load(A + 2048 * y + r[None, :])         # (Y, R) — lazy broadcast
        B_rx = tl.load(B + 2048 * r[:, None] + x)         # (R, X) — lazy broadcast

        acc += tl.dot(A_yr, B_rx)                         # (Y, R) x (R, X) → (Y, X)

    tl.store(C + 2048 * y + x, acc)

This approach eliminates the need for transposing and reshaping inputs, while still matching the expected layout for tl.dot.
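The shape bookkeeping behind this can be sanity-checked outside Triton; the following snippet (illustrative only) mirrors the index arithmetic in triton_v3:

import torch

# Illustrative check of the lazy-broadcast index shapes in triton_v3: y stays
# (Y, 1), x stays (1, X), and r is only broadcast at its point of use, so the
# loaded tiles already have the (Y, R) and (R, X) layouts that tl.dot expects.
YBLOCK, XBLOCK, RBLOCK, K = 32, 32, 32, 2048
y = torch.arange(YBLOCK)[:, None]   # (Y, 1)
x = torch.arange(XBLOCK)[None, :]   # (1, X)
r = torch.arange(RBLOCK)            # (R,)

A_idx = K * y + r[None, :]          # (Y, R) addresses, no reshape/transpose
B_idx = K * r[:, None] + x          # (R, X) addresses, no reshape/transpose
assert A_idx.shape == (YBLOCK, RBLOCK) and B_idx.shape == (RBLOCK, XBLOCK)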

A nice property of this feature is that it enables automatic fusion of operations around tl.dot without requiring major changes to Inductor. For instance, consider the following PyTorch program:

# z[w[m],n] += x[w[m],k] * y[k,n] + 3
def f(x,y,z,w):
    intm = x[w,:] @ y + 3 
    return z.index_add_(dim=0, index=w, source=intm)

With this feature enabled, TorchInductor generates a fully fused Triton kernel that uses Tensor Cores:

@triton.jit
def triton_red_fused_add_index_add_mm_0(in_ptr0, in_ptr1, in_ptr2, out_ptr1, ynumel, xnumel, r0_numel, YBLOCK : tl.constexpr, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
    ynumel = 128
    xnumel = 128
    r0_numel = 128
    rnumel = r0_numel
    RBLOCK: tl.constexpr = R0_BLOCK
    yoffset = tl.program_id(1) * YBLOCK
    yindex = yoffset + tl.arange(0, YBLOCK)[:,None]
    ymask = yindex < ynumel
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[None,:]
    xmask = xindex < xnumel
    r0_base = tl.arange(0, R0_BLOCK)
    rbase = r0_base
    y0 = yindex
    tmp0 = tl.load(in_ptr0 + (y0), ymask, eviction_policy='evict_last')
    x1 = xindex
    _tmp10 = tl.full([YBLOCK, XBLOCK], 0, tl.float32)
    for r0_offset in range(0, r0_numel, R0_BLOCK):
        r0_index = r0_offset + r0_base
        r0_mask = r0_index < r0_numel
        roffset = r0_offset
        rindex = r0_index
        r0_2 = r0_index
        tmp7 = tl.load(in_ptr2 + (x1 + 128*r0_2[:,None]), r0_mask[:,None] & xmask, eviction_policy='evict_last', other=0.0)
        tmp1 = 128
        tmp2 = tmp0 + tmp1
        tmp3 = tmp0 < 0
        tmp4 = tl.where(tmp3, tmp2, tmp0)
        tl.device_assert(((0 <= tmp4) & (tmp4 < 128)) | ~(ymask), "index out of bounds: 0 <= tmp4 < 128")
        tmp6 = tl.load(in_ptr1 + (r0_2[None,:] + 128*tmp4), ymask & r0_mask[None,:], eviction_policy='evict_first', other=0.0)
        tmp8 = tl.dot(tmp6, tmp7, allow_tf32=False)
        tmp9 = tl.broadcast_to(tmp8, [YBLOCK, XBLOCK])
        tmp11 = _tmp10 + tmp9
        _tmp10 = tmp11
    tmp10 = _tmp10
    tmp12 = 128
    tmp13 = tmp0 + tmp12
    tmp14 = tmp0 < 0
    tmp15 = tl.where(tmp14, tmp13, tmp0)
    tl.device_assert(((0 <= tmp15) & (tmp15 < 128)) | ~(ymask), "index out of bounds: 0 <= tmp15 < 128")
    tmp17 = 3.0
    tmp18 = tmp10 + tmp17
    tl.atomic_add(out_ptr1 + (x1 + 128*tmp15), tmp18, ymask & xmask, sem='relaxed')

Performance and Benefits

  • Matches the performance of the hand-written mm and bmm templates on both fp16 and fp32
  • Can generate fused kernels for compound expressions such as A @ B + C @ D (see the sketch after this list)
  • Achieves up to 5–10× speedup on gather–matmul–scatter patterns by eliminating intermediate tensors
  • Supports multiple dtypes (fp16, fp32, bf16), though not exhaustively tested
  • Potentially a more maintainable alternative to hard-coded templates
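As a concrete example of the compound-expression fusion mentioned above, the following sketch (assumes the prototype fork is installed, the config flag from the next section is set, and a CUDA device is available) compiles A @ B + C @ D without falling back to the hand-written templates:

import torch

torch._inductor.config.triton.use_dot_reduction = True  # flag from "How to Enable"

# Sketch: a compound expression that would otherwise be lowered as separate
# matmul calls, but that native tl.dot codegen can fuse.
def compound(A, B, C, D):
    return A @ B + C @ D

args = [torch.randn(1024, 1024, device="cuda", dtype=torch.float16)
        for _ in range(4)]
out = torch.compile(compound)(*args)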

How to Enable

You can test this feature by setting:

torch._inductor.config.triton.use_dot_reduction = True

Prototype fork: https://github.com/nullplay/pytorch/tree/jaeyeon_fuse
Test cases: https://github.com/nullplay/pytorch/blob/jaeyeon_fuse/test/inductor/test_dot_reduction.py
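For example, an end-to-end run of the gather-matmul-scatter program from above might look like this (a sketch; assumes the prototype fork is installed and a CUDA device is available):

import torch

torch._inductor.config.triton.use_dot_reduction = True

# z[w[m], n] += x[w[m], k] * y[k, n] + 3  (same example as above)
def f(x, y, z, w):
    intm = x[w, :] @ y + 3
    return z.index_add_(dim=0, index=w, source=intm)

n = 128
x = torch.randn(n, n, device="cuda")
y = torch.randn(n, n, device="cuda")
z = torch.zeros(n, n, device="cuda")
w = torch.randint(0, n, (n,), device="cuda")

out = torch.compile(f)(x, y, z, w)   # single fused kernel: tl.dot + tl.atomic_add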

Since this is a prototype, there are some limitations:
1. The implementation is hacky and needs refactoring.
2. Excessive fusion can sometimes reduce performance (better fusion heuristics are needed).
3. Autotuning for these kernels needs to be made more robust.


I would appreciate feedback from PyTorch developers on this direction. Do you think enabling native tl.dot codegen in Inductor is a reasonable and maintainable path forward for high-performance matmul fusion?

@jansel @eellison @drisspg @blaine-rister

Alternatives

No response

Additional context

No response

cc @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @aakhundov
