[Inductor][Schedule][Fusion] Ops are not fused due to reduction unroll #153346
@etaf

Description


🐛 Describe the bug

import torch
# device = "cpu"  : all ops are fused into one kernel
# device = "cuda"/"xpu": ops are fused into two kernels
device = "cuda"
loss = torch.compile(torch.nn.CrossEntropyLoss(), backend="inductor")
input = torch.randn(3, 5, device=device)
target = torch.empty(3, dtype=torch.long, device=device).random_(5)
for i in range(3):
    output = loss(input, target)

In the above reproducer, two Triton kernels are generated for torch.nn.CrossEntropyLoss() on CUDA/XPU, even though in theory they could be combined into a single kernel. On CPU, everything is properly fused into one kernel.
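
For reference, a minimal sketch for checking the kernel count programmatically, assuming the generated_kernel_count counter in torch._inductor.metrics (the one Inductor's own tests read) is available in this build:

import torch
import torch._inductor.metrics as metrics

device = "cuda"
loss = torch.compile(torch.nn.CrossEntropyLoss(), backend="inductor")
input = torch.randn(3, 5, device=device)
target = torch.empty(3, dtype=torch.long, device=device).random_(5)

metrics.reset()                          # clear counters before compilation
loss(input, target)                      # first call triggers compilation
print(metrics.generated_kernel_count)    # reported behavior: 2 on cuda/xpu, 1 on cpu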

Analysis:

The op torch.nn.CrossEntropyLoss is decomposed into the following graph:

fx graph: graph():
    %arg0_1 : [num_users=4] = placeholder[target=arg0_1]
    %arg1_1 : [num_users=2] = placeholder[target=arg1_1]
// op_0:
    %ne_1 : [num_users=1] = call_function[target=torch.ops.aten.ne.Scalar](args = (%arg0_1, -100), kwargs = {})
    %amax : [num_users=1] = call_function[target=torch.ops.aten.amax.default](args = (%arg1_1, [1], True), kwargs = {})
    %sub : [num_users=2] = call_function[target=torch.ops.aten.sub.Tensor](args = (%arg1_1, %amax), kwargs = {})
// op_1:
    %exp : [num_users=1] = call_function[target=torch.ops.aten.exp.default](args = (%sub,), kwargs = {})
    %sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%exp, [1], True), kwargs = {})
    %log : [num_users=1] = call_function[target=torch.ops.aten.log.default](args = (%sum_1,), kwargs = {})
    %sub_1 : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%sub, %log), kwargs = {})
// op_2:
    %ne : [num_users=1] = call_function[target=torch.ops.aten.ne.Scalar](args = (%arg0_1, -100), kwargs = {})
    %full_default : [num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([], 0), kwargs = {dtype: torch.int64, layout: torch.strided, device: xpu:0, pin_memory: False})
    %where : [num_users=1] = call_function[target=torch.ops.aten.where.self](args = (%ne, %arg0_1, %full_default), kwargs = {})
    %unsqueeze : [num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%where, 1), kwargs = {})
    %gather : [num_users=1] = call_function[target=torch.ops.aten.gather.default](args = (%sub_1, 1, %unsqueeze), kwargs = {})
    %squeeze : [num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%gather, 1), kwargs = {})
    %neg : [num_users=1] = call_function[target=torch.ops.aten.neg.default](args = (%squeeze,), kwargs = {})
    %full_default_1 : [num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([], 0.0), kwargs = {dtype: torch.float32, layout: torch.strided, device: xpu:0, pin_memory: False})
    %where_1 : [num_users=1] = call_function[target=torch.ops.aten.where.self](args = (%ne_1, %neg, %full_default_1), kwargs = {})
    %sum_3 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%where_1,), kwargs = {})
// op_3:
    %ne_2 : [num_users=1] = call_function[target=torch.ops.aten.ne.Scalar](args = (%arg0_1, -100), kwargs = {})
    %sum_2 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%ne_2,), kwargs = {})
    %convert_element_type : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_2, torch.float32), kwargs = {})
    %div : [num_users=1] = call_function[target=torch.ops.aten.div.Tensor](args = (%sum_3, %convert_element_type), kwargs = {})
    return (div,)
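
For anyone trying to reproduce the dumps in this issue, one way to get the decomposed graph and the scheduler/fusion traces is through torch._logging; the artifact names used here ("graph_code", "schedule", "fusion") are an assumption based on current Inductor logging and may differ between builds.

import torch._logging

# Roughly equivalent to running with TORCH_LOGS="graph_code,schedule,fusion":
# graph_code prints the post-decomposition FX graph, while schedule/fusion
# print the scheduler nodes and the fusion decisions discussed below.
torch._logging.set_logs(graph_code=True, schedule=True, fusion=True)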

Then the graph is lowered to the following scheduler nodes:

op0: SchedulerNode(ComputedBuffer)
op0.writes = [MemoryDep('buf0', c0, {c0: 3})]
op0.unmet_dependencies = []
...
op1: SchedulerNode(ComputedBuffer)
op1.writes = [MemoryDep('buf1', c0, {c0: 3})]
op1.unmet_dependencies = [MemoryDep('buf0', c0, {c0: 3})]
...
op2: SchedulerNode(ComputedBuffer)
op2.writes = [MemoryDep('buf2', 0, {})]
op2.unmet_dependencies =
    [   MemoryDep('buf0', 0, {}),
        MemoryDep('buf0', 1, {}),
        MemoryDep('buf0', 2, {}),
        MemoryDep('buf1', 0, {}),
        MemoryDep('buf1', 1, {}),
        MemoryDep('buf1', 2, {})]
...
op3: SchedulerNode(ComputedBuffer)
op3.writes = [MemoryDep('buf3', 0, {})]
op3.unmet_dependencies = [MemoryDep('buf2', 0, {})]

The first round of Inductor scheduler fusion fuses op0 + op1 => op0_op1 and op2 + op3 => op2_op3:


op0_op1: FusedSchedulerNode(SchedulerNode,SchedulerNode)
op0_op1.writes = [MemoryDep('buf0', c0, {c0: 3}), MemoryDep('buf1', c0, {c0: 3})]
op0_op1.unmet_dependencies = []

op2_op3: FusedSchedulerNode(SchedulerNode,SchedulerNode)
op2_op3.writes = [MemoryDep('buf2', 0, {}), MemoryDep('buf3', 0, {})]
op2_op3.unmet_dependencies =
    [   MemoryDep('buf0', 0, {}),
        MemoryDep('buf0', 1, {}),
        MemoryDep('buf0', 2, {}),
        MemoryDep('buf1', 0, {}),
        MemoryDep('buf1', 1, {}),
        MemoryDep('buf1', 2, {})]

From the above IR, it's easy to conclude that op0_op1 and op2_op3 are fusible from the memory-dependency perspective. However, the fusion score function score_fusion_memory() in the Inductor scheduler returns a score of 0, meaning the memory dependencies do not match, so the two nodes are not fused into a single kernel.
score_fusion_memory() only counts MemoryDep entries that are exactly equal when matching a read against a write. In reality, it should be sufficient for the read to be a subset of the write.

common_memory_deps = (node1.read_writes.reads | node1.read_writes.writes) & (
    node2.read_writes.reads | node2.read_writes.writes
)
return sum(self.dep_size_hint(dep) for dep in common_memory_deps)
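
To illustrate the subset idea, here is a hypothetical, simplified sketch (my own stand-in types, not the actual Inductor MemoryDep class or the proposed fix): a scalar read such as MemoryDep('buf0', 2, {}) should count as matching a write such as MemoryDep('buf0', c0, {c0: 3}), because the write covers elements 0..2.

from dataclasses import dataclass

@dataclass
class Dep:
    name: str      # buffer name, e.g. "buf0"
    index: object  # an int for a concrete element, or a symbol name like "c0"
    ranges: dict   # symbol -> extent, e.g. {"c0": 3}

def read_covered_by_write(read: Dep, write: Dep) -> bool:
    if read.name != write.name:
        return False
    if read == write:
        return True
    # A scalar read of element i is covered by a write whose index variable
    # sweeps 0..extent-1 and includes i.
    if isinstance(read.index, int) and not read.ranges and len(write.ranges) == 1:
        (extent,) = write.ranges.values()
        return 0 <= read.index < extent
    return False

# op2_op3 reads buf0[0..2]; op0_op1 writes buf0[c0] with c0 in range(3).
write = Dep("buf0", "c0", {"c0": 3})
assert all(read_covered_by_write(Dep("buf0", i, {}), write) for i in range(3))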

However, even if we compute a correct memory fusion score, the two nodes still cannot be fused, because the loop lengths of op2_op3 and op0_op1 do not match.

The real root cause is that the reduction loops of op2 and op3 are unrolled during the lowering stage.
Here is the sum_3 IR of op2: the reduction loop of length 3 is unrolled, and since ranges=[], there is no loop left here at all:

sum_3  =>  TensorBox(StorageBox(
  ComputedBuffer(name='buf2', layout=FlexibleLayout('xpu:0', torch.float32, size=[], stride=[]), data=Reduction(
    'xpu',
    torch.float32,
    def inner_fn(index, rindex):
        r0_0 = rindex
        tmp0 = ops.load(arg0_1, r0_0)
        tmp1 = ops.constant(-100, torch.int64)
        tmp2 = tmp0 != tmp1
        tmp3 = ops.load(arg0_1, r0_0)
        tmp4 = ops.constant(-100, torch.int64)
        tmp5 = tmp3 != tmp4
        tmp6 = ops.load(arg0_1, r0_0)
        tmp7 = ops.constant(0, torch.int64)
        tmp8 = ops.where(tmp5, tmp6, tmp7)
        tmp9 = ops.load(arg1_1, tmp8 + 5 * r0_0)
        tmp10 = ops.load(buf0, r0_0)
        tmp11 = tmp9 - tmp10
        tmp12 = ops.load(buf1, r0_0)
        tmp13 = tmp11 - tmp12
        tmp14 = -tmp13
        tmp15 = ops.constant(0.0, torch.float32)
        tmp16 = ops.where(tmp2, tmp14, tmp15)
        return tmp16
    ,
    ranges=[],
    reduction_ranges=[3],
    reduction_type=sum,
    origin_node=None,
    origins=OrderedSet([sum_3, where_1, ne_1, neg, full_default_1])
  ))
))

But op0 and op1 have a loop that is not unrolled.
The unrolling happens here:

pytorch/torch/_inductor/ir.py, lines 1478 to 1493 at commit 100ec0b:

if (
    isinstance(reduction_numel, Integer)
    and V.graph.sizevars.size_hint(reduction_numel)
    < config.unroll_reductions_threshold
    and (sympy_product(ranges) != 1 or is_gpu(device.type))
):
    # NB: This works around https://github.com/pytorch/pytorch/issues/140457
    # since turning reductions into pointwise ops can exacerbate this problem
    return Pointwise.create(
        device=device,
        dtype=dst_dtype,
        inner_fn=cls._unroll_reduction_fn(
            inner_fn, reduction_ranges, reduction_type, src_dtype
        ),
        ranges=ranges,
    )

When lowering %sum_3 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%where_1,), kwargs = {}) to an Inductor IR node, the reduction is unrolled here. %sum_3 reduces all dimensions down to a single value, so ranges is [] in the code above. The is_gpu check explains why we get the fusion on CPU but not on CUDA/XPU.
The GPU check was introduced by https://github.com/pytorch/pytorch/pull/140331/files#diff-cf6ca00beddc32a2a6a2933fb9913b6a2b925ffc3b745488967210e4343134acR1257,
which is a workaround for issue #140457.
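
To make the mismatch concrete, here is a conceptual stand-in for what the unrolling does (an illustration, not the actual _unroll_reduction_fn): the sum over reduction_ranges=[3] is folded into a pointwise body that evaluates the reduction body at each of the three indices and adds the results, so the resulting node has no loop for the scheduler to line up with op0_op1's loop.

# Conceptual stand-in for reduction unrolling (illustration only):
def make_unrolled_sum(inner_fn, reduction_len):
    """Rewrite a sum reduction of known small length as a pointwise expression."""
    def unrolled(index):
        acc = inner_fn(index, (0,))
        for r in range(1, reduction_len):
            acc = acc + inner_fn(index, (r,))
        return acc
    return unrolled

# Toy body mirroring the shape of sum_3's inner_fn: load one element per
# reduction index and negate it.
data = [1.0, 2.0, 3.0]
unrolled = make_unrolled_sum(lambda index, rindex: -data[rindex[0]], 3)
print(unrolled(()))  # -6.0, produced with ranges=[] and no reduction loop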

I'm trying to remove this workaround.
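
Until the workaround is removed, one knob to experiment with (a suggestion on my side, not a confirmed fix) is torch._inductor.config.unroll_reductions_threshold, which appears in the condition above with a default of 8; setting it at or below the reduction length skips the unroll branch for this case. Whether that alone restores the single-kernel behavior on CUDA/XPU still needs to be verified.

import torch
import torch._inductor.config as inductor_config

# The reduction length here is 3, so any threshold <= 3 makes the
# `reduction_numel < unroll_reductions_threshold` check fail and skips the
# unroll branch shown in ir.py above. This only disables the unroll; it does
# not by itself guarantee that op0_op1 and op2_op3 get fused.
inductor_config.unroll_reductions_threshold = 1

loss = torch.compile(torch.nn.CrossEntropyLoss(), backend="inductor")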

Versions

PyTorch version: 2.8.0a0+git1798b0d
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.5 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: Could not collect
CMake version: version 3.31.2
Libc version: glibc-2.35

Python version: 3.10.15 | packaged by conda-forge | (main, Oct 16 2024, 01:24:24) [GCC 13.3.0] (64-bit runtime)
Python platform: Linux-5.15.47+prerelease24.6.13-x86_64-with-glibc2.35

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov

Labels

module: inductor, oncall: pt2, triaged
