🐛 Describe the bug
import torch
# device = "cpu": all ops are fused into one kernel
# device = "cuda"/"xpu": the ops are fused into two kernels
device = "cuda"
loss = torch.compile(torch.nn.CrossEntropyLoss(), backend="inductor")
input = torch.randn(3, 5, device=device)
target = torch.empty(3, dtype=torch.long, device=device).random_(5)
for i in range(3):
    output = loss(input, target)
In the above reproducer, two Triton kernels are generated for torch.nn.CrossEntropyLoss() on CUDA/XPU, even though in theory the two fused groups could be combined into a single kernel. On CPU, everything is properly fused into one kernel.
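As a side note, a quick way to count the generated kernels from Python is PyTorch's run_and_get_triton_code test utility (a sketch; the helper lives in torch._inductor.utils and is mainly used by PyTorch's own tests, so its behaviour may change between versions):

from torch._inductor.utils import run_and_get_triton_code

compiled = torch.compile(torch.nn.CrossEntropyLoss(), backend="inductor")
# Run once and capture the generated source; each Triton kernel carries one
# @triton.jit decorator, so this gives a rough kernel count.
code = run_and_get_triton_code(compiled, input, target)
print(code.count("@triton.jit"))  # 2 on CUDA/XPU today, ideally 1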
Analysis:
The op torch.nn.CrossEntropyLoss is decomposed into the following graph:
fx graph: graph():
%arg0_1 : [num_users=4] = placeholder[target=arg0_1]
%arg1_1 : [num_users=2] = placeholder[target=arg1_1]
// op_0:
%ne_1 : [num_users=1] = call_function[target=torch.ops.aten.ne.Scalar](args = (%arg0_1, -100), kwargs = {})
%amax : [num_users=1] = call_function[target=torch.ops.aten.amax.default](args = (%arg1_1, [1], True), kwargs = {})
%sub : [num_users=2] = call_function[target=torch.ops.aten.sub.Tensor](args = (%arg1_1, %amax), kwargs = {})
// op_1:
%exp : [num_users=1] = call_function[target=torch.ops.aten.exp.default](args = (%sub,), kwargs = {})
%sum_1 : [num_users=1] = call_function[target=torch.ops.aten.sum.dim_IntList](args = (%exp, [1], True), kwargs = {})
%log : [num_users=1] = call_function[target=torch.ops.aten.log.default](args = (%sum_1,), kwargs = {})
%sub_1 : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%sub, %log), kwargs = {})
// op_2:
%ne : [num_users=1] = call_function[target=torch.ops.aten.ne.Scalar](args = (%arg0_1, -100), kwargs = {})
%full_default : [num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([], 0), kwargs = {dtype: torch.int64, layout: torch.strided, device: xpu:0, pin_memory: False})
%where : [num_users=1] = call_function[target=torch.ops.aten.where.self](args = (%ne, %arg0_1, %full_default), kwargs = {})
%unsqueeze : [num_users=1] = call_function[target=torch.ops.aten.unsqueeze.default](args = (%where, 1), kwargs = {})
%gather : [num_users=1] = call_function[target=torch.ops.aten.gather.default](args = (%sub_1, 1, %unsqueeze), kwargs = {})
%squeeze : [num_users=1] = call_function[target=torch.ops.aten.squeeze.dim](args = (%gather, 1), kwargs = {})
%neg : [num_users=1] = call_function[target=torch.ops.aten.neg.default](args = (%squeeze,), kwargs = {})
%full_default_1 : [num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([], 0.0), kwargs = {dtype: torch.float32, layout: torch.strided, device: xpu:0, pin_memory: False})
%where_1 : [num_users=1] = call_function[target=torch.ops.aten.where.self](args = (%ne_1, %neg, %full_default_1), kwargs = {})
%sum_3 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%where_1,), kwargs = {})
// op_3:
%ne_2 : [num_users=1] = call_function[target=torch.ops.aten.ne.Scalar](args = (%arg0_1, -100), kwargs = {})
%sum_2 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%ne_2,), kwargs = {})
%convert_element_type : [num_users=1] = call_function[target=torch.ops.prims.convert_element_type.default](args = (%sum_2, torch.float32), kwargs = {})
%div : [num_users=1] = call_function[target=torch.ops.aten.div.Tensor](args = (%sum_3, %convert_element_type), kwargs = {})
return (div,)
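For readability, the decomposed graph above corresponds roughly to the following eager-mode computation (a sketch that ignores dtype-promotion details; arg0_1 is target and arg1_1 is input, and the variable names below are mine, not from the decomposition):

# op_0: validity mask for ignore_index (-100) and max-subtraction for numerical stability
valid = target != -100
shifted = input - input.amax(dim=1, keepdim=True)
# op_1: log-softmax
log_softmax = shifted - shifted.exp().sum(dim=1, keepdim=True).log()
# op_2: gather the target class (ignored targets clamped to 0), negate, mask, sum
safe_target = torch.where(valid, target, torch.zeros_like(target))
nll = -log_softmax.gather(1, safe_target.unsqueeze(1)).squeeze(1)
loss_sum = torch.where(valid, nll, torch.zeros_like(nll)).sum()
# op_3: divide by the number of non-ignored targets (mean reduction)
loss_val = loss_sum / valid.sum().to(torch.float32)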
Then the graph is lowered to the following scheduler nodes:
op0: SchedulerNode(ComputedBuffer)
op0.writes = [MemoryDep('buf0', c0, {c0: 3})]
op0.unmet_dependencies = []
...
op1: SchedulerNode(ComputedBuffer)
op1.writes = [MemoryDep('buf1', c0, {c0: 3})]
op1.unmet_dependencies = [MemoryDep('buf0', c0, {c0: 3})]
...
op2: SchedulerNode(ComputedBuffer)
op2.writes = [MemoryDep('buf2', 0, {})]
op2.unmet_dependencies =
[ MemoryDep('buf0', 0, {}),
MemoryDep('buf0', 1, {}),
MemoryDep('buf0', 2, {}),
MemoryDep('buf1', 0, {}),
MemoryDep('buf1', 1, {}),
MemoryDep('buf1', 2, {})]
...
op3: SchedulerNode(ComputedBuffer)
op3.writes = [MemoryDep('buf3', 0, {})]
op3.unmet_dependencies = [MemoryDep('buf2', 0, {})]
In the first round of Inductor scheduler fusion, op0 + op1 are fused into op0_op1 and op2 + op3 into op2_op3:
op0_op1: FusedSchedulerNode(SchedulerNode,SchedulerNode)
op0_op1.writes = [MemoryDep('buf0', c0, {c0: 3}), MemoryDep('buf1', c0, {c0: 3})]
op0_op1.unmet_dependencies = []
op2_op3: FusedSchedulerNode(SchedulerNode,SchedulerNode)
op2_op3.writes = [MemoryDep('buf2', 0, {}), MemoryDep('buf3', 0, {})]
op2_op3.unmet_dependencies =
[ MemoryDep('buf0', 0, {}),
MemoryDep('buf0', 1, {}),
MemoryDep('buf0', 2, {}),
MemoryDep('buf1', 0, {}),
MemoryDep('buf1', 1, {}),
MemoryDep('buf1', 2, {})]
From the above IR, it's easy to conclude that op0_op1 and op2_op3 are fusible from the memory dependency perspective. However, the fusion score function score_fusion_memory() in the Inductor scheduler gives a score of 0, which means that the memory dependencies do not match, and therefore the operations cannot be fused into a single kernel.
The score_fusion_memory() function only checks whether two MemoryDep entries are exactly equal to determine if a read and a write match. In reality, it should be sufficient if the read is a subset of the write.
pytorch/torch/_inductor/scheduler.py
Lines 3873 to 3878 in c51bdf5
common_memory_deps = (node1.read_writes.reads | node1.read_writes.writes) & (
    node2.read_writes.reads | node2.read_writes.writes
)
return sum(self.dep_size_hint(dep) for dep in common_memory_deps)
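A minimal sketch of the subset-matching idea, assuming a hypothetical dep_covers(d1, d2) predicate that checks whether one MemoryDep only addresses a subset of what another MemoryDep on the same buffer covers (this is not the existing scheduler API, just an illustration):

def score_fusion_memory_subset(node1, node2, dep_size_hint, dep_covers):
    # Collect all memory deps of both nodes, as score_fusion_memory() does.
    deps1 = node1.read_writes.reads | node1.read_writes.writes
    deps2 = node2.read_writes.reads | node2.read_writes.writes
    score = 0
    for d1 in deps1:
        for d2 in deps2:
            # Count a pair not only on exact equality (current behaviour) but
            # also when one dep touches a subset of what the other covers;
            # dep_covers is a hypothetical helper, not an existing API.
            if d1 == d2 or dep_covers(d1, d2) or dep_covers(d2, d1):
                score += dep_size_hint(d1)
    return score

With such a rule, the per-element reads MemoryDep('buf0', 0..2, {}) of op2_op3 would match the write MemoryDep('buf0', c0, {c0: 3}) of op0_op1, and the memory score would no longer be 0.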
However, even if we get a correct memory fusion score, the two nodes still cannot be fused because the loop lengths do not match between op0_op1 and op2_op3.
The real root cause is that the reduction loops of op2 and op3 are unrolled during the lowering stage.
Here is the sum_3 IR of op2: the reduction loop of length 3 is unrolled and the node's ranges is [], so there is no loop here:
sum_3 => TensorBox(StorageBox(
  ComputedBuffer(name='buf2', layout=FlexibleLayout('xpu:0', torch.float32, size=[], stride=[]), data=Reduction(
    'xpu',
    torch.float32,
    def inner_fn(index, rindex):
        r0_0 = rindex
        tmp0 = ops.load(arg0_1, r0_0)
        tmp1 = ops.constant(-100, torch.int64)
        tmp2 = tmp0 != tmp1
        tmp3 = ops.load(arg0_1, r0_0)
        tmp4 = ops.constant(-100, torch.int64)
        tmp5 = tmp3 != tmp4
        tmp6 = ops.load(arg0_1, r0_0)
        tmp7 = ops.constant(0, torch.int64)
        tmp8 = ops.where(tmp5, tmp6, tmp7)
        tmp9 = ops.load(arg1_1, tmp8 + 5 * r0_0)
        tmp10 = ops.load(buf0, r0_0)
        tmp11 = tmp9 - tmp10
        tmp12 = ops.load(buf1, r0_0)
        tmp13 = tmp11 - tmp12
        tmp14 = -tmp13
        tmp15 = ops.constant(0.0, torch.float32)
        tmp16 = ops.where(tmp2, tmp14, tmp15)
        return tmp16
    ,
    ranges=[],
    reduction_ranges=[3],
    reduction_type=sum,
    origin_node=None,
    origins=OrderedSet([sum_3, where_1, ne_1, neg, full_default_1])
  ))
))
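What the unrolling does to this node can be illustrated with a small self-contained sketch (my own simplification of what _unroll_reduction_fn produces, not Inductor's actual code): the 3-element sum is rewritten as a chain of pointwise adds over hard-coded reduction indices, so no reduction dimension survives for the scheduler to match against the length-3 loop of op0/op1.

def make_unrolled_sum(inner_fn, reduction_len):
    # inner_fn(index, r) -> value, where r is a reduction index in [0, reduction_len)
    def unrolled(index):
        acc = inner_fn(index, 0)
        for r in range(1, reduction_len):
            acc = acc + inner_fn(index, r)
        return acc
    return unrolled  # a purely pointwise function of `index`, no reduction loop left

# For the sum_3 node above, reduction_ranges=[3] and ranges=[] (a 0-d output):
unrolled = make_unrolled_sum(lambda index, r: float(r + 1), 3)
assert unrolled(()) == 6.0  # 1 + 2 + 3, evaluated without any reduction range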
But op0 and op1 keep a loop that is not unrolled.
The unrolling happens here:
Lines 1478 to 1493 in 100ec0b
if (
    isinstance(reduction_numel, Integer)
    and V.graph.sizevars.size_hint(reduction_numel)
    < config.unroll_reductions_threshold
    and (sympy_product(ranges) != 1 or is_gpu(device.type))
):
    # NB: This works around https://github.com/pytorch/pytorch/issues/140457
    # since turning reductions into pointwise ops can exacerbate this problem
    return Pointwise.create(
        device=device,
        dtype=dst_dtype,
        inner_fn=cls._unroll_reduction_fn(
            inner_fn, reduction_ranges, reduction_type, src_dtype
        ),
        ranges=ranges,
    )
When lowering %sum_3 : [num_users=1] = call_function[target=torch.ops.aten.sum.default](args = (%where_1,), kwargs = {}) to an Inductor IR node, the reduction is unrolled here. %sum_3 reduces all dimensions down to a single value, so ranges is [] in the code above; since sympy_product([]) == 1, the unrolling condition is only satisfied because of the is_gpu check, which explains why we get the fusion on CPU but not on CUDA/XPU.
The GPU check was introduced by https://github.com/pytorch/pytorch/pull/140331/files#diff-cf6ca00beddc32a2a6a2933fb9913b6a2b925ffc3b745488967210e4343134acR1257 as a workaround for issue #140457.
I'm trying to remove this workaround.
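One way to double-check this analysis (a hedged experiment, not a fix) is to lower torch._inductor.config.unroll_reductions_threshold below the reduction length, so the size_hint(reduction_numel) < threshold condition above stops firing and the small reductions stay as Reduction nodes on CUDA/XPU, just as on CPU:

import torch
import torch._inductor.config as inductor_config

# The reduction length in this reproducer is 3 and the default threshold is
# larger (otherwise the unrolling above would not trigger), so 1 disables it.
inductor_config.unroll_reductions_threshold = 1

device = "cuda"
loss = torch.compile(torch.nn.CrossEntropyLoss(), backend="inductor")
input = torch.randn(3, 5, device=device)
target = torch.empty(3, dtype=torch.long, device=device).random_(5)
output = loss(input, target)  # per the analysis above, this should now follow the CPU-like schedule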
Versions
PyTorch version: 2.8.0a0+git1798b0d
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A
OS: Ubuntu 22.04.5 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: Could not collect
CMake version: version 3.31.2
Libc version: glibc-2.35
Python version: 3.10.15 | packaged by conda-forge | (main, Oct 16 2024, 01:24:24) [GCC 13.3.0] (64-bit runtime)
Python platform: Linux-5.15.47+prerelease24.6.13-x86_64-with-glibc2.35
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov