[AOTI] Conv-BN folding on CPU not working anymore after benchmark script change in https://github.com/pytorch/pytorch/pull/123403 · Issue #127513 · pytorch/pytorch · GitHub
Closed
chunyuan-w opened this issue May 30, 2024 · 8 comments
Labels: module: aotinductor, module: inductor, oncall: pt2, triaged

chunyuan-w (Collaborator) commented May 30, 2024

🐛 Describe the bug

Before #123403, when running the Dynamo benchmark with AOTI for inference on CPU, Conv and BN could be folded. However, after this PR changed the benchmark script common.py from `gm = torch.export._trace._export_to_torch_ir` to `gm = torch.export._trace._export(xxx).module()`, Conv-BN folding on CPU no longer works. This causes a performance gap between AOTI and Inductor for many CNN models in the benchmark.
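For context on what is lost: in inference mode, a BatchNorm that directly follows a convolution is a fixed per-channel affine map, so it can be folded into the convolution's weight and bias. The sketch below shows only that arithmetic in pure Python, on a hypothetical single-channel 1-D convolution (`conv1d`, `batchnorm_eval` and `fold_bn_into_conv` are illustrative helpers, not Inductor's implementation of the pass).

```python
import math

def conv1d(x, w, b):
    """Valid (no padding, stride 1) 1-D convolution, one input and one output channel."""
    k = len(w)
    return [sum(w[j] * x[i + j] for j in range(k)) + b for i in range(len(x) - k + 1)]

def batchnorm_eval(y, gamma, beta, mean, var, eps=1e-5):
    """Inference-mode BatchNorm: normalize with running stats, then scale and shift."""
    return [gamma * (v - mean) / math.sqrt(var + eps) + beta for v in y]

def fold_bn_into_conv(w, b, gamma, beta, mean, var, eps=1e-5):
    """Return (w', b') such that conv1d(x, w', b') equals batchnorm_eval(conv1d(x, w, b), ...)."""
    scale = gamma / math.sqrt(var + eps)
    return [wi * scale for wi in w], (b - mean) * scale + beta

x = [0.5, -1.0, 2.0, 3.0, -0.5]
w, b = [0.2, -0.4, 0.1], 0.0  # a conv with bias=False behaves like b = 0
bn = dict(gamma=1.5, beta=-0.3, mean=0.25, var=2.0)

reference = batchnorm_eval(conv1d(x, w, b), **bn)  # conv followed by BN
w_f, b_f = fold_bn_into_conv(w, b, **bn)
folded = conv1d(x, w_f, b_f)                       # a single conv with folded params

# Both paths agree up to floating-point rounding.
assert all(math.isclose(r, f, rel_tol=1e-9, abs_tol=1e-12) for r, f in zip(reference, folded))
```

Note that folding gives the conv a bias even when it had none, which fits the first log below, where the folded conv receives `l__self___conv_bias`.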

Reproducer:
test.py, at the bottom of this issue, reproduces the problem.
Running `TORCH_LOGS="+inductor" python -u test.py`, which simulates the behavior before this PR, shows in the output log that Conv and BN have been folded:

```
V0530 00:24:04.224243 139790722356864 torch/_inductor/freezing.py:117] TRACED GRAPH
V0530 00:24:04.224243 139790722356864 torch/_inductor/freezing.py:117]  ===== FROZEN GRAPH =====
V0530 00:24:04.224243 139790722356864 torch/_inductor/freezing.py:117]  /home/chunyuan/inductor/pytorch/torch/fx/_lazy_graph_module.py class <lambda>(torch.nn.Module):
V0530 00:24:04.224243 139790722356864 torch/_inductor/freezing.py:117]     def forward(self, arg7_1: "f32[1, 3, 224, 224]"):
V0530 00:24:04.224243 139790722356864 torch/_inductor/freezing.py:117]         # No stacktrace found for following nodes
V0530 00:24:04.224243 139790722356864 torch/_inductor/freezing.py:117]         l__self___conv_bias: "f32[3]" = self.L__self___conv.bias
V0530 00:24:04.224243 139790722356864 torch/_inductor/freezing.py:117]         _frozen_param1 = self._frozen_param1
V0530 00:24:04.224243 139790722356864 torch/_inductor/freezing.py:117]         
V0530 00:24:04.224243 139790722356864 torch/_inductor/freezing.py:117]         # File: /home/chunyuan/inductor/pytorch/test/my-repro-conv-bn-folding.py:110 in forward, code: return self.bn(self.conv(x))
V0530 00:24:04.224243 139790722356864 torch/_inductor/freezing.py:117]         _convolution_pointwise_default: "f32[1, 3, 222, 222]" = torch.ops.mkldnn._convolution_pointwise.default(arg7_1, _frozen_param1, l__self___conv_bias, [0, 0], [1, 1], [1, 1], 1, 'none', [], '');  arg7_1 = _frozen_param1 = l__self___conv_bias = None
V0530 00:24:04.224243 139790722356864 torch/_inductor/freezing.py:117]         
V0530 00:24:04.224243 139790722356864 torch/_inductor/freezing.py:117]         # No stacktrace found for following nodes
V0530 00:24:04.224243 139790722356864 torch/_inductor/freezing.py:117]         inductor_force_stride_order_default: "f32[1, 3, 222, 222]" = torch.ops.prims.inductor_force_stride_order.default(_convolution_pointwise_default, (147852, 49284, 222, 1));  _convolution_pointwise_default = None
V0530 00:24:04.224243 139790722356864 torch/_inductor/freezing.py:117]         return (inductor_force_stride_order_default,)
V0530 00:24:04.224243 139790722356864 torch/_inductor/freezing.py:117]         
V0530 00:24:04.224243 139790722356864 torch/_inductor/freezing.py:117] 
```

Running `TORCH_LOGS="+inductor" python -u test.py -a`, which reproduces the situation after the above PR, shows in the output log that Conv and BN are not folded: there is an extra `full` and `mul` computation:

```
V0530 00:33:02.637904 140003124713088 torch/_inductor/freezing.py:117] TRACED GRAPH
V0530 00:33:02.637904 140003124713088 torch/_inductor/freezing.py:117]  ===== FROZEN GRAPH =====
V0530 00:33:02.637904 140003124713088 torch/_inductor/freezing.py:117]  /home/chunyuan/inductor/pytorch/torch/fx/_lazy_graph_module.py class <lambda>(torch.nn.Module):
V0530 00:33:02.637904 140003124713088 torch/_inductor/freezing.py:117]     def forward(self, arg5_1: "f32[1, 3, 224, 224]"):
V0530 00:33:02.637904 140003124713088 torch/_inductor/freezing.py:117]         # No stacktrace found for following nodes
V0530 00:33:02.637904 140003124713088 torch/_inductor/freezing.py:117]         _frozen_param5 = self._frozen_param5
V0530 00:33:02.637904 140003124713088 torch/_inductor/freezing.py:117]         
V0530 00:33:02.637904 140003124713088 torch/_inductor/freezing.py:117]         # File: /home/chunyuan/inductor/pytorch/test/my-repro-conv-bn-folding.py:110 in forward, code: return self.bn(self.conv(x))
V0530 00:33:02.637904 140003124713088 torch/_inductor/freezing.py:117]         _convolution_pointwise_default: "f32[1, 3, 222, 222]" = torch.ops.mkldnn._convolution_pointwise.default(arg5_1, _frozen_param5, None, [0, 0], [1, 1], [1, 1], 1, 'none', [], '');  arg5_1 = _frozen_param5 = None
V0530 00:33:02.637904 140003124713088 torch/_inductor/freezing.py:117]         full_default_1: "f32[3, 1, 1]" = torch.ops.aten.full.default([3, 1, 1], 0.9999949932098389, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
V0530 00:33:02.637904 140003124713088 torch/_inductor/freezing.py:117]         mul_1: "f32[1, 3, 222, 222]" = torch.ops.aten.mul.Tensor(_convolution_pointwise_default, full_default_1);  _convolution_pointwise_default = full_default_1 = None
V0530 00:33:02.637904 140003124713088 torch/_inductor/freezing.py:117]         
V0530 00:33:02.637904 140003124713088 torch/_inductor/freezing.py:117]         # No stacktrace found for following nodes
V0530 00:33:02.637904 140003124713088 torch/_inductor/freezing.py:117]         inductor_force_stride_order_default: "f32[1, 3, 222, 222]" = torch.ops.prims.inductor_force_stride_order.default(mul_1, (147852, 49284, 222, 1));  mul_1 = None
V0530 00:33:02.637904 140003124713088 torch/_inductor/freezing.py:117]         return (inductor_force_stride_order_default,)
```
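The constant in the extra `full`/`mul` is consistent with an unfolded BatchNorm scale. The repro never trains the model, so `BatchNorm2d` presumably keeps its default state (weight 1, bias 0, running mean 0, running variance 1, eps 1e-5), making its per-channel eval-mode scale 1/sqrt(1 + 1e-5) and its shift 0. A quick check under those assumptions, rounding the scale to float32 with the standard `struct` module (`to_float32` is a hypothetical helper):

```python
import math
import struct

def to_float32(x: float) -> float:
    """Round a Python float (a double) to the nearest IEEE-754 float32 value."""
    return struct.unpack("f", struct.pack("f", x))[0]

# Assumed default state of a freshly constructed, untrained torch.nn.BatchNorm2d.
gamma, beta, running_mean, running_var, eps = 1.0, 0.0, 0.0, 1.0, 1e-5

# Eval-mode BatchNorm is the per-channel affine map y = scale * x + shift.
scale = gamma / math.sqrt(running_var + eps)
shift = beta - running_mean * scale

print(to_float32(scale))  # 0.9999949932098389, the value in full_default_1
print(shift)              # 0.0
```

A zero shift may be why only a `mul`, and no `add`, survives in the unfolded graph.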
```python
# test.py
import argparse
import torch
import torch._inductor.config as config
from typing import (
    Any,
    Mapping,
    Tuple,
)
import copy
import dataclasses
import functools
import weakref
from torch.utils import _pytree as pytree

def main(use_api_after_the_change):

    def _normalize_bench_inputs(example_inputs) -> Tuple[Tuple[Any], Mapping[str, Any]]:
        # NOTE(bowbao): For huggingface benchmark, example_inputs are formatted as dictionary,
        # and consumed like `model(**example_inputs)`.
        # For other benchmarks, example_inputs are formatted as tuple and consumed
        # like `model(*example_inputs)`.
        if isinstance(example_inputs, dict):
            return (), example_inputs
        else:
            return tuple(example_inputs), {}

    def _register_dataclass_output_as_pytree(example_outputs) -> None:
        # NOTE(angelayi): For huggingface benchmark, some example outputs are
        # formatted as a dataclass which pytree cannot consume. So we want
        # to register the pytree implementation here
        example_outputs_flat = pytree.tree_leaves(example_outputs)
        output_dataclass_types = [
            type(out) for out in example_outputs_flat if dataclasses.is_dataclass(type(out))
        ]
        for output_type in output_dataclass_types:
            from torch._export.utils import register_dataclass_as_pytree_node

            register_dataclass_as_pytree_node(
                output_type,
                serialized_type_name=f"{output_type.__module__}.{output_type.__name__}",
            )

    class AOTInductorModelCache:
        cache = dict()

        @classmethod
        def load(cls, model, example_inputs, device):
            import torch._inductor
            import torch.export._trace

            key = weakref.ref(model)
            if key not in cls.cache:
                # Register the output dataclass to pytree
                example_args, example_kwargs = _normalize_bench_inputs(example_inputs)
                with torch.no_grad():
                    # copy.deepcopy is required to prevent any surprising side-effect,
                    # see https://github.com/pytorch/pytorch/issues/113029
                    example_outputs = copy.deepcopy(model)(*example_args, **example_kwargs)

                if pytree._is_namedtuple_instance(example_outputs):
                    typ = type(example_outputs)
                    pytree._register_namedtuple(
                        typ,
                        serialized_type_name=f"{typ.__module__}.{typ.__name__}",
                    )
                else:
                    _register_dataclass_output_as_pytree(example_outputs)

                if use_api_after_the_change:
                    # Export path used by the benchmark after #123403 (no Conv-BN folding)
                    gm = torch.export._trace._export(
                        model,
                        example_args,
                        example_kwargs,
                        pre_dispatch=True,
                    ).module()
                else:
                    # Export path used by the benchmark before #123403 (Conv-BN folded)
                    gm = torch.export._trace._export_to_torch_ir(
                        model,
                        example_args,
                        example_kwargs,
                    )

                with torch.no_grad():
                    so_path = torch._inductor.aot_compile(
                        gm, example_args, example_kwargs
                    )  # type: ignore[arg-type]

                cls.cache[key] = torch._export.aot_load(so_path, device)

            return cls.cache[key]

    def export_aot_inductor(model, example_inputs, device):
        optimized = AOTInductorModelCache.load(model, example_inputs, device)

        def opt_aot_inductor(_, example_inputs, collect_outputs=False):
            example_args, example_kwargs = _normalize_bench_inputs(example_inputs)
            return optimized(*example_args, **example_kwargs)

        return opt_aot_inductor


    class Model(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 3, 3, bias=False)
            self.bn = torch.nn.BatchNorm2d(3)

        def forward(self, x):
            return self.bn(self.conv(x))

    optimize_ctx = functools.partial(
        export_aot_inductor, device="cpu"
    )
    optimized_model_iter_fn = optimize_ctx
    example_inputs = (torch.randn(1, 3, 224, 224),)
    model = Model()
    model.eval()
    with torch.no_grad(), config.patch({"freezing": True}):
        for _ in range(3):
            optimized_model_iter_fn(model, example_inputs)

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--after", "-a", action="store_true", help="use the api after the regression"
    )
    args = parser.parse_args()
    main(args.after)
    print("done")
```

Versions

Collecting environment information...
PyTorch version: 2.4.0a0+git669560d
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: CentOS Stream 8 (x86_64)
GCC version: (GCC) 11.2.1 20220127 (Red Hat 11.2.1-9)
Clang version: Could not collect
CMake version: version 3.26.4
Libc version: glibc-2.28

Python version: 3.9.18 (main, Sep 11 2023, 13:41:44) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-5.16.0-x86_64-with-glibc2.28
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Byte Order: Little Endian
CPU(s): 240
On-line CPU(s) list: 0-239
Thread(s) per core: 2
Core(s) per socket: 60
Socket(s): 2
NUMA node(s): 2
Vendor ID: GenuineIntel
CPU family: 6
Model: 143
Model name: Intel(R) Xeon(R) Platinum 8490H
Stepping: 8
CPU MHz: 1900.000
CPU max MHz: 3500.0000
CPU min MHz: 800.0000
BogoMIPS: 3800.00
Virtualization: VT-x
L1d cache: 48K
L1i cache: 32K
L2 cache: 2048K
L3 cache: 115200K
NUMA node0 CPU(s): 0-59,120-179
NUMA node1 CPU(s): 60-119,180-239
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc art arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf tsc_known_freq pni pclmulqdq dtes64 monitor ds_cpl vmx smx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid dca sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch cpuid_fault epb cat_l3 cat_l2 cdp_l3 invpcid_single intel_ppin cdp_l2 ssbd mba ibrs ibpb stibp ibrs_enhanced tpr_shadow vnmi flexpriority ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb intel_pt avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local split_lock_detect avx_vnni avx512_bf16 wbnoinvd dtherm ida arat pln pts hwp hwp_act_window hwp_epp hwp_pkg_req avx512vbmi umip pku ospke waitpkg avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg tme avx512_vpopcntdq la57 rdpid bus_lock_detect cldemote movdiri movdir64b enqcmd fsrm md_clear serialize tsxldtrk pconfig arch_lbr avx512_fp16 amx_tile flush_l1d arch_capabilities

Versions of relevant libraries:
[pip3] flake8==6.1.0
[pip3] flake8-bugbear==23.3.23
[pip3] flake8-comprehensions==3.12.0
[pip3] flake8-executable==2.1.3
[pip3] flake8-logging-format==0.9.0
[pip3] flake8-pyi==23.3.1
[pip3] flake8-simplify==0.19.3
[pip3] mypy==1.9.0
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.26.0
[pip3] optree==0.11.0
[pip3] torch==2.4.0a0+git669560d
[pip3] torchvision==0.19.0a0+f0c94cd
[pip3] vit-pytorch==0.40.2
[conda] blas 1.0 mkl
[conda] mkl 2023.1.0 h213fc3f_46343
[conda] mkl-include 2023.2.0 pypi_0 pypi
[conda] mkl-service 2.4.0 py39h5eee18b_1
[conda] mkl-static 2023.2.0 pypi_0 pypi
[conda] mkl_fft 1.3.8 py39h5eee18b_0
[conda] mkl_random 1.2.4 py39hdb19cb5_0
[conda] numpy 1.26.0 pypi_0 pypi
[conda] numpy-base 1.26.4 py39hb5e798b_0
[conda] optree 0.11.0 pypi_0 pypi
[conda] torch 2.4.0a0+git669560d dev_0
[conda] torchfix 0.4.0 pypi_0 pypi
[conda] torchvision 0.19.0a0+f0c94cd dev_0
[conda] vit-pytorch 0.40.2 dev_0

cc @ezyang @msaroufim @bdhirsh @anijain2305 @chauhang @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire

zou3519 added the triaged, oncall: pt2, module: inductor and module: aotinductor labels May 30, 2024
chunyuan-w (Collaborator, Author) commented:

The graph also differs after this change for hf_T5 and hf_T5_base in the TorchBench suite, which has also caused a performance regression. It seems that other optimizations besides Conv-BN folding are also missing after switching from `gm = torch.export._trace._export_to_torch_ir` to `gm = torch.export._trace._export(xxx).module()`.

chunyuan-w (Collaborator, Author) commented:

cc @tugsbayasgalan

desertfire (Contributor) commented:

@angelayi, any comments?

angelayi (Contributor) commented Jun 5, 2024

Sorry, it looks like this is also a pattern-matching issue, where we're missing some patterns in the fuse_conv_bn pass.
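To illustrate this class of failure with a toy model (the op names are hypothetical placeholders, not the ops either export path actually emits, and this is not the real FX pattern matcher): a pass that only recognizes registered op sequences silently skips a graph that expresses the same computation with a different op, until a pattern for that form is added.

```python
# Toy model of a pattern-based rewrite pass. Op names are hypothetical;
# the real fuse_conv_bn pass matches FX graph nodes, not strings.

def find_foldable_pairs(ops, patterns):
    """Return the indices i where (ops[i], ops[i + 1]) is a registered pattern."""
    return [i for i in range(len(ops) - 1) if (ops[i], ops[i + 1]) in patterns]

patterns = {("conv", "batch_norm")}

graph_old_export = ["conv", "batch_norm", "relu"]
graph_new_export = ["conv", "batch_norm_other_variant", "relu"]  # same math, different op

print(find_foldable_pairs(graph_old_export, patterns))  # [0]
print(find_foldable_pairs(graph_new_export, patterns))  # []

# Registering the missing pattern makes the second graph foldable as well.
patterns.add(("conv", "batch_norm_other_variant"))
print(find_foldable_pairs(graph_new_export, patterns))  # [0]
```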

chunyuan-w (Collaborator, Author) commented:

> Sorry, it looks like this is also a pattern-matching issue, where we're missing some patterns in the fuse_conv_bn pass.

Thanks for the comment. May I know if there's a plan to fix this issue or any other suggestions? Since conv-bn folding is not working anymore after the change in #123403, many CNN models have regressions using AOTI on CPU. We're wondering if there's a quick fix possible to address this issue before the PT 2.4 branch cut on 6/10.

chunyuan-w (Collaborator, Author) commented:

> > Sorry, it looks like this is also a pattern-matching issue, where we're missing some patterns in the fuse_conv_bn pass.
>
> Thanks for the comment. May I know if there's a plan to fix this issue or any other suggestions? Since conv-bn folding is not working anymore after the change in #123403, many CNN models have regressions using AOTI on CPU. We're wondering if there's a quick fix possible to address this issue before the PT 2.4 branch cut on 6/10.

Hi @angelayi, I just want to follow up on whether there is a plan to fix this issue: after this change we observed a 10%-50% performance regression on multiple CNN models in the Dynamo benchmark suite on CPU. We fixed freezing support in AOTI on CPU (#124350) and would like to propose for the coming PT 2.4 release that AOTI can match or exceed the performance of its torch.compile counterparts. However, because of this issue there will be a large performance gap between AOTI and Inductor on CNN models, since AOTI no longer does Conv-BN folding.

For example, the following TIMM models are impacted:

- mobilenetv2_100
- lcnet_050
- fbnetv3_b
- tf_efficientnet_b0
- cspdarknet53
- tinynet_a
- rexnet_100
- eca_botnext26ts_256
- mobilenetv3_large_100
- eca_halonext26ts
- mobilevit_s
- tf_mixnet_l
- levit_128
- gmlp_s16_224

angelayi (Contributor) commented:

@chunyuan-w I'm not sure what a quick way to fix this before the 2.4 release would be, so I'm OK with changing it back to using `_export_to_torch_ir`. Could you leave a TODO pointing to this issue?

chunyuan-w (Collaborator, Author) commented:

> @chunyuan-w I'm not sure what a quick way to fix this before the 2.4 release would be, so I'm OK with changing it back to using `_export_to_torch_ir`. Could you leave a TODO pointing to this issue?

Sure, I've created #128377 to revert the change in #123403, and added a comment as a TODO pointing to this issue.

I see that the 2.4 branch cut has already happened. I'll submit a cherry-pick PR to release/2.4 once #128377 lands on main.

chunyuan-w added a commit that referenced this issue Jun 12, 2024
…#123403)""


This reverts commit d78991a.

This PR reverts #123403 to fix the performance regression as discussed in #127513 (comment).

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang

[ghstack-poisoned]
pytorchmergebot pushed a commit that referenced this issue Jun 12, 2024
chunyuan-w added a commit to chunyuan-w/pytorch that referenced this issue Jun 12, 2024
…23403)" (pytorch#128377)

This reverts commit d78991a.

This PR reverts pytorch#123403 to fix the performance regression as discussed in pytorch#127513 (comment).

Pull Request resolved: pytorch#128377
Approved by: https://github.com/jgong5, https://github.com/angelayi, https://github.com/desertfire

(cherry picked from commit 5ef70fa)
facebook-github-bot pushed a commit to pytorch/benchmark that referenced this issue Jun 13, 2024
… (#128377)

Summary:
This reverts commit d78991a7381adb3df5e9b63c365db4506643edce.

This PR reverts pytorch/pytorch#123403 to fix the performance regression as discussed in pytorch/pytorch#127513 (comment).

X-link: pytorch/pytorch#128377
Approved by: https://github.com/jgong5, https://github.com/angelayi, https://github.com/desertfire

Reviewed By: clee2000

Differential Revision: D58501783

fbshipit-source-id: 1e55cc2c0b315ed6869195ee6730e72eb6be9da9
TharinduRusira pushed a commit to TharinduRusira/pytorch that referenced this issue Jun 14, 2024
ignaciobartol pushed a commit to ignaciobartol/pytorch that referenced this issue Jun 14, 2024
atalman pushed a commit that referenced this issue Jun 19, 2024
#128511)

Revert "Make torch_geometric models compatible with export (#123403)" (#128377)

This reverts commit d78991a.

This PR reverts #123403 to fix the performance regression as discussed in #127513 (comment).

Pull Request resolved: #128377
Approved by: https://github.com/jgong5, https://github.com/angelayi, https://github.com/desertfire

(cherry picked from commit 5ef70fa)