[Re-land 90265] [inductor] add conv_transpose2d unary fusion for cpu … · pytorch/pytorch@cc49f5a · GitHub
Commit cc49f5a

chunyuan-w authored and pytorchmergebot committed
[Re-land 90265] [inductor] add conv_transpose2d unary fusion for cpu in inference mode (#91954)
Re-land #90265. Depends on an internal ideep upgrade. [Update]: the internal ideep upgrade issue is resolved in #92239.

Pull Request resolved: #91954
Approved by: https://github.com/jgong5, https://github.com/desertfire
1 parent 3870fda commit cc49f5a
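
For context, a minimal sketch (not part of this commit) of the pattern the fusion targets: a ConvTranspose2d followed by a unary op, run on CPU in inference mode through the inductor backend. The module, shapes, and use of torch.compile as the entry point are illustrative assumptions.

    import torch

    class Deconv(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.deconv = torch.nn.ConvTranspose2d(3, 32, kernel_size=3)

        def forward(self, x):
            # conv_transpose2d followed by a unary op (relu here) is the
            # pattern this commit lets the inductor CPU backend fuse into
            # a single mkldnn kernel
            return torch.relu(self.deconv(x))

    mod = Deconv().eval()
    x = torch.randn(1, 3, 28, 28)
    with torch.no_grad():
        out = torch.compile(mod)(x)  # fusion applies on the CPU inference path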

File tree

4 files changed: +232 -2 lines changed

test/inductor/test_torchinductor.py

Lines changed: 65 additions & 0 deletions
@@ -1979,6 +1979,71 @@ def forward(self, x, y):
         with torch.no_grad():
             self.common(mod, (v, other), atol=2e-3, rtol=0.016)
 
+    @unittest.skipIf(HAS_CUDA, "only support cpu conv_transpose2d unary test")
+    def test_conv_transpose2d_unary(self):
+        class M(torch.nn.Module):
+            def __init__(
+                self,
+                unary_fn,
+                in_channels,
+                out_channels,
+                **kwargs,
+            ):
+                super(M, self).__init__()
+                self.conv_transpose2d = torch.nn.ConvTranspose2d(
+                    in_channels,
+                    out_channels,
+                    **kwargs,
+                )
+                self.unary_fn = unary_fn
+
+            def forward(self, x):
+                x = self.conv_transpose2d(x)
+                return self.unary_fn(x)
+
+        test_memory_format = [torch.contiguous_format, torch.channels_last]
+        options = itertools.product(
+            unary_list,
+            [True, False],
+            [1, 3],
+            [1, 2],
+            [1, 4],
+            [0, 1],
+            test_memory_format,
+        )
+
+        for (
+            unary_fn,
+            bias,
+            kernel_size,
+            dilation,
+            groups,
+            padding,
+            memory_format,
+        ) in options:
+            oC = 32 * groups
+            iC = 3 * groups
+            x_shape = (1, iC, 28, 28)
+            mod = M(
+                unary_fn,
+                iC,
+                oC,
+                kernel_size=kernel_size,
+                padding=padding,
+                dilation=dilation,
+                groups=groups,
+                bias=bias,
+            ).eval()
+
+            v = torch.randn(x_shape, dtype=torch.float32).to(
+                memory_format=memory_format
+            )
+            with torch.no_grad():
+                self.common(
+                    mod,
+                    (v,),
+                )
+
     def test_gather1(self):
         def fn(a, b):
             return (
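
As a standalone illustration of what one iteration of the sweep above exercises: unary_list and self.common are defined elsewhere in this test file, so this sketch substitutes an explicit eager-vs-compiled comparison with illustrative tolerances.

    import torch

    mod = torch.nn.Sequential(
        torch.nn.ConvTranspose2d(3, 32, kernel_size=3, padding=1, bias=True),
        torch.nn.ReLU(),  # stand-in for one entry of unary_list
    ).eval()

    v = torch.randn(1, 3, 28, 28).to(memory_format=torch.channels_last)
    with torch.no_grad():
        expected = mod(v)               # eager reference
        actual = torch.compile(mod)(v)  # inductor CPU path with the fusion
        torch.testing.assert_close(actual, expected, atol=2e-3, rtol=1e-2)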

torch/_inductor/ir.py

Lines changed: 63 additions & 2 deletions
@@ -3305,6 +3305,8 @@ def _prepare_convolution_fusion_create(
     stride_: List[int],
     dilation_: List[int],
     groups: int,
+    transposed: bool = False,
+    output_padding_: List[int] = None,
 ):
     """
     This function is a helper function to prepare inputs, layout and constant args
 
@@ -3317,6 +3319,7 @@ def _prepare_convolution_fusion_create(
     padding = tuple(padding_)
     dilation = tuple(dilation_)
     assert isinstance(groups, int)
+    output_padding = tuple(output_padding_) if output_padding_ else (0, 0)
     with V.graph.fake_mode:
         x_fake = ir_node_to_tensor(x, guard_shape=True)
         weight_fake = ir_node_to_tensor(weight, guard_shape=True)
 
@@ -3330,8 +3333,8 @@ def _prepare_convolution_fusion_create(
             stride,
             padding,
             dilation,
-            False,
-            [0, 0],
+            transposed,
+            output_padding,
             groups,
         )
         output_size = output.size()
 
@@ -3350,6 +3353,8 @@ def _prepare_convolution_fusion_create(
         convert_shape_to_inductor(output_stride),
     )
     constant_args = [padding, stride, dilation, groups]
+    if transposed:
+        constant_args.insert(1, output_padding)
 
     if bias is not None:
         inputs.append(bias)
 
@@ -3684,6 +3689,62 @@ def apply_constraint(self):
         pass
 
 
+class ConvolutionTransposeUnary(ExternKernelAlloc):
+    kernel = "torch.ops.mkldnn._convolution_transpose_pointwise"
+
+    def __init__(
+        self,
+        layout,
+        inputs,
+        constant_args=(),
+        kernel="torch.ops.mkldnn._convolution_transpose_pointwise",
+    ):
+        super().__init__(layout, inputs, constant_args)
+        self.kernel = kernel
+
+    def codegen(self, wrapper):
+        wrapper.writeline(
+            f"{self.get_name()} = {self.kernel}({', '.join(self.codegen_args())})"
+        )
+
+    @classmethod
+    def create(
+        cls,
+        x: "TensorBox",
+        weight: "TensorBox",
+        bias: "TensorBox",
+        padding_: List[int],
+        output_padding_: List[int],
+        stride_: List[int],
+        dilation_: List[int],
+        groups_: int,
+        attr,
+        scalars,
+        algorithm,
+    ):
+        kernel = "torch.ops.mkldnn._convolution_transpose_pointwise"
+        transposed = True
+        (inputs, constant_args, kernel_layout, _,) = _prepare_convolution_fusion_create(
+            cls,
+            x,
+            weight,
+            bias,
+            padding_,
+            stride_,
+            dilation_,
+            groups_,
+            transposed,
+            output_padding_,
+        )
+        constant_args = constant_args + [attr, scalars, algorithm]
+        return ConvolutionTransposeUnary(
+            layout=kernel_layout,
+            inputs=inputs,
+            constant_args=constant_args,
+            kernel=kernel,
+        )
+
+
 @dataclasses.dataclass
 class MutableBox(IRNode):
     """

torch/_inductor/lowering.py

Lines changed: 30 additions & 0 deletions
@@ -954,6 +954,36 @@ def linear_unary(
     def linear_binary(x: TensorBox, y: TensorBox, w: TensorBox, b: TensorBox, attr):
         return TensorBox.create(ir.LinearBinary.create(x, y, w, b, attr))
 
+    @register_lowering(torch.ops.mkldnn._convolution_transpose_pointwise)
+    def convolution_transpose_unary(
+        x: TensorBox,
+        weight: TensorBox,
+        bias: TensorBox,
+        padding,
+        output_padding,
+        stride,
+        dilation,
+        groups,
+        attr,
+        scalars,
+        algorithm,
+    ):
+        return TensorBox.create(
+            ir.ConvolutionTransposeUnary.create(
+                x,
+                weight,
+                bias,
+                padding,
+                output_padding,
+                stride,
+                dilation,
+                groups,
+                attr,
+                scalars,
+                algorithm,
+            )
+        )
+
     if torch._C.has_mkl:
 
         @register_lowering(torch.ops.mkl._mkl_linear)
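
The lowering forwards its arguments positionally, so their order must match the op schema shown above. A hedged sketch of calling the fused op directly with that order; this bypasses inductor, assumes an mkldnn-enabled CPU build, and the post-op values "relu" / [] / "" are illustrative guesses at the attr/scalars/algorithm encoding.

    import torch

    x = torch.randn(1, 3, 28, 28)
    weight = torch.randn(3, 32, 3, 3)  # ConvTranspose2d weight: (iC, oC // groups, kH, kW)
    bias = torch.randn(32)

    out = torch.ops.mkldnn._convolution_transpose_pointwise(
        x, weight, bias,
        [0, 0],  # padding
        [0, 0],  # output_padding
        [1, 1],  # stride
        [1, 1],  # dilation
        1,       # groups
        "relu",  # attr: the fused unary post-op
        [],      # scalars
        "",      # algorithm
    )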

torch/_inductor/mkldnn.py

Lines changed: 74 additions & 0 deletions
@@ -417,6 +417,69 @@ def forward(self, input, other):
         return y
 
 
+class ConvTransposeUnary2d(nn.ConvTranspose2d):
+    def __init__(
+        self,
+        conv_transpose: nn.Module,
+        unary: nn.Module,
+    ):
+        super(ConvTransposeUnary2d, self).__init__(
+            conv_transpose.in_channels,
+            conv_transpose.out_channels,
+            conv_transpose.kernel_size,
+            conv_transpose.stride,
+            conv_transpose.padding,
+            conv_transpose.output_padding,
+            conv_transpose.groups,
+            conv_transpose.bias is not None,
+            conv_transpose.dilation,
+            conv_transpose.padding_mode,
+            conv_transpose.weight.device,
+            conv_transpose.weight.dtype,
+        )
+        self._update_module_params(conv_transpose, unary)
+
+    def _update_module_params(self, conv_transpose, unary):
+        self.__dict__ = copy.deepcopy(conv_transpose.__dict__)
+        self.attr, self.scalars, self.algorithm = unary_modules_map[unary.__class__](
+            unary
+        )
+
+    def _conv_transpose_forward(self, input, weight, bias):
+        if self.padding_mode != "zeros":
+            return torch.ops.mkldnn._convolution_transpose_pointwise(
+                F.pad(
+                    input, self._reversed_padding_repeated_twice, mode=self.padding_mode
+                ),
+                weight,
+                bias,
+                _pair(0),
+                self.output_padding,
+                self.stride,
+                self.dilation,
+                self.groups,
+                self.attr,
+                self.scalars,
+                self.algorithm,
+            )
+        return torch.ops.mkldnn._convolution_transpose_pointwise(
+            input,
+            weight,
+            bias,
+            self.padding,
+            self.output_padding,
+            self.stride,
+            self.dilation,
+            self.groups,
+            self.attr,
+            self.scalars,
+            self.algorithm,
+        )
+
+    def forward(self, input):
+        return self._conv_transpose_forward(input, self.weight, self.bias)
+
+
 def packed_conv_eval(conv: nn.Module, input_size: list):
     assert not (conv.training), "Fusion only for eval!"
     return ConvUnary2d(
 
@@ -486,6 +549,16 @@ def fused_linear_binary_eval(linear: nn.Module, attr: str, input_size: list):
     return linear_binary
 
 
+def fused_conv_transpose_unary_eval(
+    conv_transpose: nn.Module, unary: nn.Module, input_size: list
+):
+    assert not (conv_transpose.training), "Fusion only for eval!"
+    return ConvTransposeUnary2d(
+        conv_transpose,
+        unary,
+    )
+
+
 def mkldnn_fuse_fx(gm: torch.fx.GraphModule, example_inputs):
     is_cpu = all(
         example_input.device == torch.device("cpu")
 
@@ -753,6 +826,7 @@ def pack_module(gm: torch.fx.GraphModule):
     nn.Linear: fused_linear_unary_eval,
     ConvBinary2d: fused_conv_binary_unary_eval,
     ConvBinaryInplace2d: fused_conv_binary_unary_eval,
+    nn.ConvTranspose2d: fused_conv_transpose_unary_eval,
 }
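
With the map entry above in place, mkldnn_fuse_fx can rewrite a ConvTranspose2d followed by a supported unary module into a single ConvTransposeUnary2d. A sketch of the before-state at the FX level; the modules are illustrative, and the fusion itself is performed by inductor's pass, not by this snippet.

    import torch
    import torch.fx

    model = torch.nn.Sequential(
        torch.nn.ConvTranspose2d(3, 32, kernel_size=3),
        torch.nn.ReLU(),
    ).eval()

    gm = torch.fx.symbolic_trace(model)
    print(gm.graph)  # two call_module nodes: conv_transpose, then relu
    # After the fusion pass, those two nodes collapse into one fused module
    # whose forward calls torch.ops.mkldnn._convolution_transpose_pointwise.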