[RFC] Add Cpp Template for GEMM related ops via max-autotune for Inductor CPU #125683
Status: Open (9 of 18 tasks completed)
Labels
- module: inductor
- oncall: cpu inductor (CPU Inductor issues for Intel team to triage)
- oncall: pt2
- triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
Comments
jgong5 pushed a commit that referenced this issue on May 8, 2024:
…p32)" This PR adds the Cpp template infrastructure and the initial FP32 gemm template. See RFC #125683 for more background info. 1. Cpp template infrastructure Similar template abstractions as the CUTLASS template, i.e., `CppTemplate`, `CppTemplateKernel`, `CppTemplateBuffer`. The MicroGemm micro-kernel abstraction that can be used by Cpp GEMM templates. 2. Initial FP32 gemm template This involves a GEMM template implementation `CppPackedGemmTemplate` that supports GEMM with constant weight (`B`) requiring `N` to be a multiple of register blocking while allows the static or dynamic sizes for the `M` (batch dim) of `A`. The `B` matrix would be prepacked. This is a typical setting for inference workloads. The template handles the thread decomposition (via `thread_blocking`) and cache blocking (via `cache_blocking`). Then it invokes `CppMicroGemm` which handles register blocking, instruction selection, and other CPU architecture-specific optimizations. A `CppMicroGemmFP32Vec` micro-kernel implementation is provided for fp32 matmuls implemented with ATen vec abstraction. 3. Correctness and performance The changes have been validated with fp32 inference on the three benchmark suites (torchbench, huggingface and timm_models) with both static shape and dynamic shapes. Since it is an initial implementation, we are still working on further performance improves with follow-up PRs including the optimizations in kernels as well as fusions. The perf gains are only observed from a selective number of models compared to the ATen kernels which are implemented with MKL. The perf gains are more obvious with dynamic shapes since MKL only supports packed gemm for static shapes. Below are details. Static shapes | Benchmark | torchbench | huggingface | timm_models | |------------|-------------|--------------|--------------| | Multi-threaded (baseline) | 1.47x | 1.36x | 1.91x | | Multi-threaded (max-autotune) | 1.47x | 1.36x | 1.92x | | Single-threaded (baseline) | 1.56x | 1.19x | 1.51x | | Single-threaded (max-autotune) | 1.56x | 1.19x | 1.52x | Key models being sped up: drq: 1.14x soft_act: 1.12 cait_m36_384: 1.18x Dynamic shapes | Benchmark | torchbench | huggingface | timm_models | | --- | --- | --- | --- | | Multi-threaded (baseline) | 1.43x | 1.28x | 1.85x | | Multi-threaded (max-autotune) | 1.47x | 1.28x | 1.85x | | Single-threaded (baseline) | 1.55x | 1.20x | 1.51x | | Single-threaded (max-autotune) | 1.56x | 1.19x | 1.53x | Key models being sped up: BERT_pytorch: 1.22x pyhpc_turbulent: 1.13x soft_actor_critic: 1.77x BlenderbotForCausalLM: 1.09x cait_m36_384: 1.17x cc voznesenskym penguinwu EikanWang Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler amjames desertfire chauhang [ghstack-poisoned]
pytorchmergebot pushed a commit that referenced this issue on May 12, 2024:
Same change as described above (Cpp template infrastructure and the initial FP32 GEMM template). Pull Request resolved: #124021. Approved by: https://github.com/jansel
jgong5 pushed a commit that referenced this issue on May 14, 2024:
…template" As part of #125683, this PR adds the epilogue support for c++ gemm template by reusing the c++ vector codegen on sub-slices of tensors. This is implemented by retracing the epilogue IR nodes with new ranges and offsets. The new `codegen_loop_bodies` and `codegen_functions` methods are added to c++ vector codegen for this purpose. This is leveraged by the `store_output` method of the template kernel for epilogue codegen and store to the final result. cc voznesenskym penguinwu EikanWang Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 ColinPeppler amjames desertfire chauhang [ghstack-poisoned]
jgong5 pushed a commit that referenced this issue on Jun 1, 2024:
As part of #125683, this PR adds initial bf16/fp16 GEMM template support, with the micro-gemm implemented with fused type casting and fp32 computation. It does not provide epilogue fusion support yet; that will be added in the next PR.

cc voznesenskym penguinwu EikanWang Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 ColinPeppler amjames desertfire chauhang

Differential Revision: [D58017580](https://our.internmc.facebook.com/intern/diff/D58017580)
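A small sketch of the user-facing side of this, under the assumption that the same max-autotune setup shown earlier is in effect; the model and shapes are illustrative only. Internally, the micro-gemm casts the bf16 inputs and accumulates in fp32.

```python
# Hypothetical bf16 linear compiled with max-autotune on CPU; the fused
# type casting and fp32 accumulation described above happen inside the
# generated micro-gemm, not in this user code.
import torch

model = torch.nn.Linear(256, 256).to(torch.bfloat16).eval()
x = torch.randn(8, 256, dtype=torch.bfloat16)

with torch.no_grad():
    compiled = torch.compile(model, mode="max-autotune")
    y = compiled(x)  # y stays bf16; accumulation is done in fp32
```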
jgong5 pushed a commit that referenced this issue on Jun 1, 2024:
…usion" As part of #125683, this PR adds epilogue fusion support for bf16/fp16 gemms. The key changes are as follows: 1. bf16 linear w/ epilogue fusion of some ops was originally supported via ATen oneDNN linear pointwise ops. In order to match the ATen op semantics, in-template epilogue support is added to the cpp gemm template so that we would have: "gemm + in-template epilogues -> template buffer". If the template is chosen for codegen, the in-template epilogues will be concatenated with the out-of-template epilogues that are appended during the scheduling. 2. Support bf16/fp16 legalization for `codegen_loop_bodies` which is used to generate the epilogue loops. 3. We used to leverage the in-place buffer mechanism to handle the in-place buffers in the epilogue codegen, in particular, for the reuses for output buffers of GEMM, template and epilogues. This is not correct since the output buffer is an "output" not an "in-place" buffer of the template kernel itself. Now, we use a dedicated "aliases" dict to manage such buffer reuses and the intermediate aliasing buffers are removed after codegen. 4. Add `localize_buffer` method to `LocalBufferScope` to allow the replacement of a global buffer with a local one in the given inductor IR nodes. This helps the fused loops to work on smaller-sized local buffers for better data locality. cc voznesenskym penguinwu EikanWang Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 ColinPeppler amjames desertfire chauhang [ghstack-poisoned]
jgong5
pushed a commit
that referenced
this issue
Jun 1, 2024
…plate epilogue fusion" As part of #125683, this PR adds epilogue fusion support for bf16/fp16 gemms. The key changes are as follows: 1. bf16 linear w/ epilogue fusion of some ops was originally supported via ATen oneDNN linear pointwise ops. In order to match the ATen op semantics, in-template epilogue support is added to the cpp gemm template so that we would have: "gemm + in-template epilogues -> template buffer". If the template is chosen for codegen, the in-template epilogues will be concatenated with the out-of-template epilogues that are appended during the scheduling. 2. Support bf16/fp16 legalization for `codegen_loop_bodies` which is used to generate the epilogue loops. 3. We used to leverage the in-place buffer mechanism to handle the in-place buffers in the epilogue codegen, in particular, for the reuses for output buffers of GEMM, template and epilogues. This is not correct since the output buffer is an "output" not an "in-place" buffer of the template kernel itself. Now, we use a dedicated "aliases" dict to manage such buffer reuses and the intermediate aliasing buffers are removed after codegen. 4. Add `localize_buffer` method to `LocalBufferScope` to allow the replacement of a global buffer with a local one in the given inductor IR nodes. This helps the fused loops to work on smaller-sized local buffers for better data locality. cc voznesenskym penguinwu EikanWang Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 ColinPeppler amjames desertfire chauhang [ghstack-poisoned]
jgong5
pushed a commit
that referenced
this issue
Jun 1, 2024
…usion" As part of #125683, this PR adds epilogue fusion support for bf16/fp16 gemms. The key changes are as follows: 1. bf16 linear w/ epilogue fusion of some ops was originally supported via ATen oneDNN linear pointwise ops. In order to match the ATen op semantics, in-template epilogue support is added to the cpp gemm template so that we would have: "gemm + in-template epilogues -> template buffer". If the template is chosen for codegen, the in-template epilogues will be concatenated with the out-of-template epilogues that are appended during the scheduling. 2. Support bf16/fp16 legalization for `codegen_loop_bodies` which is used to generate the epilogue loops. 3. We used to leverage the in-place buffer mechanism to handle the in-place buffers in the epilogue codegen, in particular, for the reuses for output buffers of GEMM, template and epilogues. This is not correct since the output buffer is an "output" not an "in-place" buffer of the template kernel itself. Now, we use a dedicated "aliases" dict to manage such buffer reuses and the intermediate aliasing buffers are removed after codegen. 4. Add `localize_buffer` method to `LocalBufferScope` to allow the replacement of a global buffer with a local one in the given inductor IR nodes. This helps the fused loops to work on smaller-sized local buffers for better data locality. cc voznesenskym penguinwu EikanWang Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 ColinPeppler amjames desertfire chauhang [ghstack-poisoned]
jgong5
pushed a commit
that referenced
this issue
Jun 1, 2024
… w/o epilogue fusion" As part of #125683, this PR adds the initial bf16/fp16 gemm template support with micro-gemm implemented with fused type casting and fp32 computation. It doesn't provide epilogue fusion support yet which will be added in the next PR. cc voznesenskym penguinwu EikanWang Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 ColinPeppler amjames desertfire chauhang Differential Revision: [D58017580](https://our.internmc.facebook.com/intern/diff/D58017580) [ghstack-poisoned]
jgong5
pushed a commit
that referenced
this issue
Jun 1, 2024
…mputed with fp32 w/o epilogue fusion" As part of #125683, this PR adds the initial bf16/fp16 gemm template support with micro-gemm implemented with fused type casting and fp32 computation. It doesn't provide epilogue fusion support yet which will be added in the next PR. cc voznesenskym penguinwu EikanWang Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 ColinPeppler amjames desertfire chauhang Differential Revision: [D58017580](https://our.internmc.facebook.com/intern/diff/D58017580) [ghstack-poisoned]
jgong5
pushed a commit
that referenced
this issue
Jun 1, 2024
…plate epilogue fusion" As part of #125683, this PR adds epilogue fusion support for bf16/fp16 gemms. The key changes are as follows: 1. bf16 linear w/ epilogue fusion of some ops was originally supported via ATen oneDNN linear pointwise ops. In order to match the ATen op semantics, in-template epilogue support is added to the cpp gemm template so that we would have: "gemm + in-template epilogues -> template buffer". If the template is chosen for codegen, the in-template epilogues will be concatenated with the out-of-template epilogues that are appended during the scheduling. 2. Support bf16/fp16 legalization for `codegen_loop_bodies` which is used to generate the epilogue loops. 3. We used to leverage the in-place buffer mechanism to handle the in-place buffers in the epilogue codegen, in particular, for the reuses for output buffers of GEMM, template and epilogues. This is not correct since the output buffer is an "output" not an "in-place" buffer of the template kernel itself. Now, we use a dedicated "aliases" dict to manage such buffer reuses and the intermediate aliasing buffers are removed after codegen. 4. Add `localize_buffer` method to `LocalBufferScope` to allow the replacement of a global buffer with a local one in the given inductor IR nodes. This helps the fused loops to work on smaller-sized local buffers for better data locality. cc voznesenskym penguinwu EikanWang Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 ColinPeppler amjames desertfire chauhang [ghstack-poisoned]
jgong5
pushed a commit
that referenced
this issue
Jun 1, 2024
… w/o epilogue fusion" As part of #125683, this PR adds the initial bf16/fp16 gemm template support with micro-gemm implemented with fused type casting and fp32 computation. It doesn't provide epilogue fusion support yet which will be added in the next PR. cc voznesenskym penguinwu EikanWang Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 ColinPeppler amjames desertfire chauhang Differential Revision: [D58017580](https://our.internmc.facebook.com/intern/diff/D58017580) [ghstack-poisoned]
jgong5
pushed a commit
that referenced
this issue
Jun 1, 2024
…usion" As part of #125683, this PR adds epilogue fusion support for bf16/fp16 gemms. The key changes are as follows: 1. bf16 linear w/ epilogue fusion of some ops was originally supported via ATen oneDNN linear pointwise ops. In order to match the ATen op semantics, in-template epilogue support is added to the cpp gemm template so that we would have: "gemm + in-template epilogues -> template buffer". If the template is chosen for codegen, the in-template epilogues will be concatenated with the out-of-template epilogues that are appended during the scheduling. 2. Support bf16/fp16 legalization for `codegen_loop_bodies` which is used to generate the epilogue loops. 3. We used to leverage the in-place buffer mechanism to handle the in-place buffers in the epilogue codegen, in particular, for the reuses for output buffers of GEMM, template and epilogues. This is not correct since the output buffer is an "output" not an "in-place" buffer of the template kernel itself. Now, we use a dedicated "aliases" dict to manage such buffer reuses and the intermediate aliasing buffers are removed after codegen. 4. Add `localize_buffer` method to `LocalBufferScope` to allow the replacement of a global buffer with a local one in the given inductor IR nodes. This helps the fused loops to work on smaller-sized local buffers for better data locality. cc voznesenskym penguinwu EikanWang Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 ColinPeppler amjames desertfire chauhang [ghstack-poisoned]
pytorchmergebot
pushed a commit
that referenced
this issue
Jun 13, 2024
) As part of #125683, this PR adds epilogue fusion support for bf16/fp16 GEMMs. The key changes are as follows:
1. bf16 linear with epilogue fusion of some ops was originally supported via the ATen oneDNN linear pointwise ops. To match the ATen op semantics, in-template epilogue support is added to the cpp GEMM template so that we have "gemm + in-template epilogues -> template buffer". If the template is chosen for codegen, the in-template epilogues are concatenated with the out-of-template epilogues appended during scheduling.
2. Support bf16/fp16 legalization for `codegen_loop_bodies`, which is used to generate the epilogue loops.
3. We used to leverage the in-place buffer mechanism to handle buffer reuse in the epilogue codegen, in particular the reuse of the output buffers of the GEMM template and epilogues. This is not correct, since the output buffer is an "output", not an "in-place" buffer, of the template kernel itself. Now a dedicated "aliases" dict manages such buffer reuses, and the intermediate aliasing buffers are removed after codegen.
4. Add a `localize_buffer` method to `LocalBufferScope` that allows replacing a global buffer with a local one in the given Inductor IR nodes. This lets the fused loops work on smaller local buffers for better data locality.
Pull Request resolved: #126545 Approved by: https://github.com/jansel
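Point 4 above is about data locality rather than semantics. The conceptual Python sketch below (not Inductor's actual codegen or the `LocalBufferScope` API) illustrates why applying the epilogue to a small local accumulator before storing to the global output helps:

```
# Conceptual illustration of the local-buffer idea: compute a tile into a small,
# cache-resident accumulator, apply the epilogue there, and store to the global
# output once, instead of round-tripping the epilogue through the global buffer.
import torch

M, K, N, TILE = 384, 96, 64, 32
X = torch.randn(M, K)
W = torch.randn(K, N)
Y = torch.empty(M, N)

for m0 in range(0, M, TILE):
    local_acc = X[m0:m0 + TILE] @ W      # small local buffer (one M-tile of the output)
    local_acc = torch.relu(local_acc)    # epilogue fused on the local buffer
    Y[m0:m0 + TILE] = local_acc          # single store to the global output
```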
pytorchmergebot
pushed a commit
that referenced
this issue
Aug 10, 2024
## Summary
As part of #125683, this PR modifies the existing CPU GEMM cpp template & micro-kernel template to enable int8 WoQ GEMM auto-tuning with the AVX2, AVX512 & AMX ISAs (the latter is only available on Xeon 4th generation & beyond). WoQ GEMM takes FP16/BF16 activations, int8 weights, and a scale of the same dtype as the activations. The operation is equivalent to `torch.nn.functional.linear(x, w.to(x.dtype)) * scale`, which is essentially what the ATen op `torch.ops.aten._weight_int8pack_mm` currently does (except that weights are not cached by it). Weights are treated as constant & cached, so this implementation is suitable for inference, not QAT. `scale` is supported as a `mul` epilogue.
Only BF16 activations are supported in this PR because, for FP16 & FP32, the weight is dequantized during the constant-folding pass of freezing, and after auto-tuning the performance with a large `M` dimension may then be better than either `torch.ops.aten._weight_int8pack_mm` or the WoQ micro-kernel support introduced in this PR, which dequantizes `w` within the micro-kernel. While even BF16 activations with a large `M` dimension may benefit from dequantizing `w` beforehand, for now they use the WoQ support in the GEMM templates for auto-tuning; a subsequent PR will add logic for deciding whether or not to dequantize weights beforehand.
### Performance
#### AMX
Op-level speedup due to the AMX micro-kernel (selected during auto-tuning) on 32 physical cores of Intel(R) Xeon(R) Platinum 8468H (of the Xeon 4th generation series, codenamed Sapphire Rapids) vs. the ATen kernel `torch.ops.aten._weight_int8pack_mm`. Intel OpenMP & tcmalloc were preloaded. In a few cases with an odd `K`, the implementation added in this PR may not perform as well as the ATen kernel; this is unrelated to this PR, since `test_linear_amx` also exhibits similar datapoints. In those cases, the AMX micro-kernel might be slower than the AVX512 micro-kernel, so if such sets of shapes are used for auto-tuning, either the AVX512 micro-kernel implementation or the ATen kernel would be chosen instead. Benchmarked with unit tests. Tabular data at https://gist.github.com/sanchitintel/294811a86c8ff6b867c668ae2107c405?permalink_comment_id=5142442#gistcomment-5142442 The AVX512 micro-kernel was disabled to collect data for the AMX micro-kernel.
#### AVX2/AVX512 micro-kernels
Tabular data at https://gist.github.com/sanchitintel/52b5fa9c66f791be19e48e2aa6423dc4?permalink_comment_id=5142437#gistcomment-5142437
### Follow-up
1. An int4 WoQ GEMM micro-kernel will also be added in a separate PR.
2. A subsequent PR will add logic for deciding whether or not to dequantize weights beforehand. E2E perf measurement should be done with #131310.
Pull Request resolved: #131887 Approved by: https://github.com/jgong5, https://github.com/leslie-fang-intel, https://github.com/jansel
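The equivalence stated in the summary can be checked directly; the sketch below follows the formula from the PR description. The argument layout of `torch.ops.aten._weight_int8pack_mm` (int8 weight of shape `[N, K]`, per-output-channel scales) is an assumption on my part and may not hold on every build.

```
# WoQ GEMM semantics sketch: bf16 activations, int8 weights, bf16 per-channel scales.
import torch

M, K, N = 384, 96, 64
x = torch.randn(M, K, dtype=torch.bfloat16)
w_int8 = torch.randint(-128, 127, (N, K), dtype=torch.int8)   # assumed [N, K] layout
scale = torch.rand(N, dtype=torch.bfloat16)

# Reference semantics from the PR description.
ref = torch.nn.functional.linear(x, w_int8.to(x.dtype)) * scale

# ATen weight-only-quantized GEMM that the cpp template competes against at autotune time.
out = torch.ops.aten._weight_int8pack_mm(x, w_int8, scale)
print(torch.allclose(out.float(), ref.float(), rtol=1e-2, atol=1e-1))
```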
pytorchmergebot
pushed a commit
that referenced
this issue
Dec 6, 2024
This PR adds the Cpp template for BMM, for FP32, FP16, and BF16. See #125683 for more background.
1. Adds a `CppBmmTemplate` class which inherits from `CppPackedGemmTemplate`. Given the number of worker threads `num_threads` and the batch size `B`, the GEMM kernel is executed as follows: for the first `B - (B % num_threads)` batch inputs, one sub-gemm problem is run per thread; for the remaining `B % num_threads` sub-gemms, each subproblem is executed with the parallelized GEMM kernel (summarized in the scheduling sketch after the example code below). To manage this, the `GEMM_TEMPLATE` from `CppPackedGemmTemplate` is rendered twice, once for a single thread and once including the parallel OMP pragma.
2. Adapts `CppPackedGemmTemplate` to allow for subclassing. The `GEMM_TEMPLATE` is separated into different strings to allow rendering by the child class. Slicing/indexing are adapted to allow for 3D BMM inputs. Additional methods `get_options()` and `_get_params_for_choices()` are added to reduce code duplication.
The BMM within the `dlrm` benchmark has a single input buffer which is used for both the X and W inputs; this is currently not supported in this PR.
### Performance
On Granite/Sapphire Rapids, the cpp_bmm template code uses AMX, which requires an expensive transpose operation, so the BMM op is rarely selected as faster than the existing external bmm kernel. As a result, speedup on SPR is identical with and without the BMM code. The pass rate matches the rates for main exactly.
#### Test Summary on Granite Rapids
Test Scenario | Comp Item | Date | Compiler | torchbench | huggingface | timm_models
-- | -- | -- | -- | -- | -- | --
Single Socket Multi-Threads | Pass Rate | gemm autotune | inductor | 91%, 73/80 | 100%, 46/46 | 100%, 61/61
  |   | bmm + gemm autotune | inductor | 91%, 73/80 | 100%, 46/46 | 100%, 61/61
  | Geomean Speedup | gemm autotune | inductor | 2.15x | 1.91x | 2.52x
  |   | bmm + gemm autotune | inductor | 2.15x | 1.96x | 2.53x
Single Core Single-Thread | Pass Rate | gemm autotune | inductor | 91%, 73/80 | 100%, 46/46 | 100%, 61/61
  |   | bmm + gemm autotune | inductor | 91%, 73/80 | 100%, 46/46 | 100%, 61/61
  | Geomean Speedup | inductor_locally_benchmark_586 | inductor | 2.43x | 1.56x | 2.60x
  |   | inductor_locally_benchmark_585 | inductor | 2.45x | 1.56x | 2.63x
This is not the case on an older Skylake Xeon machine. For the BMM ops contained in the torchbench models, bmm performance improves by 1.10-2.64x.
#### BF16 28-core Skylake Xeon | Model | Inductor | GemmAutotune | Gemm+BMM Autotune | |--------|--------|--------|--------| | BERT_pytorch | 1.233x | 2.597x | 2.608x | | hf_DistilBert | 1.128x | 2.242x | 2.368x | | hf_Reformer | 1.124x | 1.419x | 1.590x | | hf_T5_base | 1.012x | 1.257x | 1.382x | | hf_T5_large | 1.085x | 2.228x | 2.345x | ## Example BMM Code ``` #include <c10/util/Unroll.h> #include <torch/csrc/inductor/aoti_torch/c/shim.h> template <bool accum> inline void cpp_bmm_micro_gemm_amx_kernel_32_2( AMXState& amx_state, const bfloat16* __restrict__ A, const bfloat16* __restrict__ B, float* __restrict__ C, int64_t K, int64_t lda, int64_t ldb, int64_t ldc, uint8_t tilecfg_rows ) { // TODO(jgong5): add prefetch hint for A, B, C auto loadconfig = [](const amx_tilecfg& cfg) { _tile_loadconfig(&cfg); }; const auto last_k_offset = K / 32 * 32; const auto tail_k_size = K - last_k_offset; if C10_LIKELY (last_k_offset > 0) { amx_state.configure(tilecfg_rows, 64, 32 / 16, 2, loadconfig); } else { amx_state.configure(tilecfg_rows, tail_k_size * sizeof(bfloat16), 32 / 16, 2, loadconfig); } auto load_c = [&]() { _tile_loadd(0, C + 0 * ldc + 0, ldc * sizeof(float)); _tile_loadd(1, C + 0 * ldc + 16, ldc * sizeof(float)); _tile_loadd(2, C + 16 * ldc + 0, ldc * sizeof(float)); _tile_loadd(3, C + 16 * ldc + 16, ldc * sizeof(float)); }; auto zero_c = [&]() { _tile_zero(0); _tile_zero(1); _tile_zero(2); _tile_zero(3); }; if constexpr (accum) { load_c(); } else { zero_c(); } auto compute = [&](int k) { _tile_stream_loadd(4, A + 0 * lda + k, lda * sizeof(bfloat16)); _tile_loadd(6, B + k * ldb + 0, ldb * 2 * sizeof(bfloat16)); _tile_dpbf16ps(0, 4, 6); _tile_loadd(7, B + k * ldb + 32, ldb * 2 * sizeof(bfloat16)); _tile_dpbf16ps(1, 4, 7); _tile_stream_loadd(5, A + 16 * lda + k, lda * sizeof(bfloat16)); _tile_dpbf16ps(2, 5, 6); _tile_dpbf16ps(3, 5, 7); }; #pragma GCC unroll 4 for (int k = 0; k < last_k_offset; k += 32) { compute(k); } auto store_c = [&]() { // store to C _tile_stored(0, C + 0 * ldc + 0, ldc * sizeof(float)); _tile_stored(1, C + 0 * ldc + 16, ldc * sizeof(float)); _tile_stored(2, C + 16 * ldc + 0, ldc * sizeof(float)); _tile_stored(3, C + 16 * ldc + 16, ldc * sizeof(float)); }; // TODO(jgong5): move tail k computation to separate loopnest to save tile configuration overhead if C10_UNLIKELY (tail_k_size > 0) { if C10_LIKELY (last_k_offset > 0) { store_c(); amx_state.configure(tilecfg_rows, tail_k_size * sizeof(bfloat16), 32 / 16, 2, loadconfig); load_c(); } compute(last_k_offset); } store_c(); } template <bool accum> inline void cpp_bmm_micro_gemm_amx_kernel_16_2( AMXState& amx_state, const bfloat16* __restrict__ A, const bfloat16* __restrict__ B, float* __restrict__ C, int64_t K, int64_t lda, int64_t ldb, int64_t ldc, uint8_t tilecfg_rows ) { // TODO(jgong5): add prefetch hint for A, B, C auto loadconfig = [](const amx_tilecfg& cfg) { _tile_loadconfig(&cfg); }; const auto last_k_offset = K / 32 * 32; const auto tail_k_size = K - last_k_offset; if C10_LIKELY (last_k_offset > 0) { amx_state.configure(tilecfg_rows, 64, 16 / 16, 2, loadconfig); } else { amx_state.configure(tilecfg_rows, tail_k_size * sizeof(bfloat16), 16 / 16, 2, loadconfig); } auto load_c = [&]() { _tile_loadd(0, C + 0 * ldc + 0, ldc * sizeof(float)); _tile_loadd(1, C + 0 * ldc + 16, ldc * sizeof(float)); }; auto zero_c = [&]() { _tile_zero(0); _tile_zero(1); }; if constexpr (accum) { load_c(); } else { zero_c(); } auto compute = [&](int k) { _tile_stream_loadd(2, A + 0 * lda + k, lda * sizeof(bfloat16)); _tile_loadd(3, B + k 
* ldb + 0, ldb * 2 * sizeof(bfloat16)); _tile_dpbf16ps(0, 2, 3); _tile_loadd(4, B + k * ldb + 32, ldb * 2 * sizeof(bfloat16)); _tile_dpbf16ps(1, 2, 4); }; #pragma GCC unroll 4 for (int k = 0; k < last_k_offset; k += 32) { compute(k); } auto store_c = [&]() { // store to C _tile_stored(0, C + 0 * ldc + 0, ldc * sizeof(float)); _tile_stored(1, C + 0 * ldc + 16, ldc * sizeof(float)); }; // TODO(jgong5): move tail k computation to separate loopnest to save tile configuration overhead if C10_UNLIKELY (tail_k_size > 0) { if C10_LIKELY (last_k_offset > 0) { store_c(); amx_state.configure(tilecfg_rows, tail_k_size * sizeof(bfloat16), 16 / 16, 2, loadconfig); load_c(); } compute(last_k_offset); } store_c(); } template <bool accum> inline void cpp_bmm_micro_gemm( AMXState& amx_state, const bfloat16* __restrict__ A, const bfloat16* __restrict__ B, float* __restrict__ C, int64_t M, int64_t N, int64_t K, int64_t lda, int64_t ldb, int64_t ldc ) { AOTI_TORCH_CHECK(N % 32 == 0, "N dimension must be multiple of 32"); AOTI_TORCH_CHECK(K % 2 == 0, "K dimension must be multiple of 2"); // TODO(jgong5): loop unroll for M and N for (int64_t n = 0; n < N; n += 32) { for (int64_t m = 0; m < M; m += 32) { int64_t block_m = std::min<int64_t>(M - m, 32); int64_t m_tail = m; if (block_m >= 32) { cpp_bmm_micro_gemm_amx_kernel_32_2<accum>( amx_state, A + m * lda, B + n, C + m * ldc + n, K, lda, ldb, ldc, 16 ); block_m -= 32; m_tail += 32; } else if (block_m >= 16) { cpp_bmm_micro_gemm_amx_kernel_16_2<accum>( amx_state, A + m * lda, B + n, C + m * ldc + n, K, lda, ldb, ldc, 16 ); block_m -= 16; m_tail += 16; } if (block_m > 0) { cpp_bmm_micro_gemm_amx_kernel_16_2<accum>( amx_state, A + m_tail * lda, B + n, C + m_tail * ldc + n, K, lda, ldb, ldc, block_m ); } } } } void threaded_mm(const bfloat16* X, const bfloat16* W, bfloat16* Y, const int64_t ks_b_index) { constexpr int64_t num_threads = 48; constexpr int64_t N = 64; constexpr int64_t K = 96; constexpr int64_t Mr = 32; constexpr int64_t Nr = 32; constexpr int64_t Kr = 32; constexpr int64_t Nr_blocks = (N + Nr - 1) / Nr; constexpr int64_t Kr_blocks = (K + Kr - 1) / Kr; constexpr int64_t M = static_cast<int64_t>(384L); constexpr int64_t Mr_blocks = (M + Mr - 1) / Mr; constexpr int64_t Mt_blocks = 1; constexpr int64_t Nt_blocks = 1; constexpr int64_t Kt_blocks = 3; constexpr int64_t Mc_blocks = 1; constexpr int64_t Nc_blocks = 1; constexpr int64_t Kc_blocks = 3; constexpr int64_t num_Mc_blocks = (Mr_blocks + Mc_blocks - 1) / Mc_blocks; constexpr int64_t num_Nc_blocks = (Nr_blocks + Nc_blocks - 1) / Nc_blocks; constexpr int64_t num_Mt_blocks = (Mr_blocks + Mt_blocks - 1) / Mt_blocks; constexpr int64_t num_Nt_blocks = (Nr_blocks + Nt_blocks - 1) / Nt_blocks; constexpr int64_t num_Kt_blocks = (Kr_blocks + Kt_blocks - 1) / Kt_blocks; // make sure all partitions are assigned AOTI_TORCH_CHECK( Mt_blocks * Nt_blocks * Kt_blocks * 48 >= Mr_blocks * Nr_blocks * Kr_blocks, "Not all partitions are assigned."
); #pragma omp parallel num_threads(48) { const int tid = omp_get_thread_num(); const int64_t k_group_id = tid / num_Kt_blocks; const int64_t k_slice_id = tid % num_Kt_blocks; const int64_t n_group_id = k_group_id / num_Nt_blocks; const int64_t n_slice_id = k_group_id % num_Nt_blocks; const int64_t k_block_start = k_slice_id * Kt_blocks; const int64_t k_block_end = std::min(k_block_start + Kt_blocks, Kr_blocks); const int64_t n_block_start = n_slice_id * Nt_blocks; const int64_t n_block_end = std::min(n_block_start + Nt_blocks, Nr_blocks); const int64_t m_block_start = std::min(n_group_id * Mt_blocks, Mr_blocks); const int64_t m_block_end = std::min(m_block_start + Mt_blocks, Mr_blocks); const int64_t num_Mc_blocks_per_thread = (m_block_end - m_block_start + Mc_blocks - 1) / Mc_blocks; AMXState amx_state; auto _local_acc_buf = std::make_unique<float[]>(static_cast<int64_t>(Mc_blocks*Mr*Nc_blocks*Nr)); auto local_acc_buf = _local_acc_buf.get(); for (int64_t mc_block_id = 0; mc_block_id < num_Mc_blocks_per_thread; mc_block_id++) { const int64_t my_mc_block_id = (mc_block_id + n_slice_id) % num_Mc_blocks_per_thread; const int64_t mc = m_block_start + my_mc_block_id * Mc_blocks; const int64_t m_start = mc * Mr; const int64_t m_end = std::min(std::min(mc + Mc_blocks, m_block_end) * Mr, M); const int64_t m_size = m_end - m_start; for (int64_t nc = n_block_start; nc < n_block_end; nc += Nc_blocks) { const int64_t n_start = nc * Nr; const int64_t n_end = std::min(std::min(nc + Nc_blocks, n_block_end) * Nr, N); const int64_t n_size = n_end - n_start; // NB: assume we pad N, nc_block_end won't exceed padded N here. const int64_t nc_block_end = std::min(nc + Nc_blocks, n_block_end); if (_local_acc_buf == nullptr) { _local_acc_buf = std::make_unique<float[]>(static_cast<int64_t>(Mc_blocks*Mr*Nc_blocks*Nr)); local_acc_buf = _local_acc_buf.get(); } for (int64_t kc = k_block_start; kc < k_block_end; kc += Kc_blocks) { int64_t k_start = kc * Kr; int64_t k_end = std::min(std::min(kc + Kc_blocks, k_block_end) * Kr, K); for (int64_t nci = nc; nci < nc_block_end; nci++) { if (kc == k_block_start) { cpp_bmm_micro_gemm<static_cast<bool>(false)>( amx_state, &(X[static_cast<int64_t>(k_start + (96L*m_start) + (36864L*ks_b_index))]), &(W[static_cast<int64_t>((32L*k_start) + (3072L*nci) + (6144L*ks_b_index))]), &(local_acc_buf[static_cast<int64_t>((Nr*nci) + ((-1L)*Nr*nc))]), static_cast<int64_t>(m_end + ((-1L)*m_start)), static_cast<int64_t>(Nr), static_cast<int64_t>(k_end + ((-1L)*k_start)), static_cast<int64_t>(96L), static_cast<int64_t>(32L), static_cast<int64_t>(Nc_blocks*Nr) ); } else { cpp_bmm_micro_gemm<static_cast<bool>(true)>( amx_state, &(X[static_cast<int64_t>(k_start + (96L*m_start) + (36864L*ks_b_index))]), &(W[static_cast<int64_t>((32L*k_start) + (3072L*nci) + (6144L*ks_b_index))]), &(local_acc_buf[static_cast<int64_t>((Nr*nci) + ((-1L)*Nr*nc))]), static_cast<int64_t>(m_end + ((-1L)*m_start)), static_cast<int64_t>(Nr), static_cast<int64_t>(k_end + ((-1L)*k_start)), static_cast<int64_t>(96L), static_cast<int64_t>(32L), static_cast<int64_t>(Nc_blocks*Nr) ); } } } { { #pragma GCC ivdep for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(m_end + ((-1L)*m_start)); x0+=static_cast<int64_t>(1L)) { for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(16L*(c10::div_floor_integer(static_cast<int64_t>((n_end + ((-1L)*n_start))), static_cast<int64_t>(16L)))); x1+=static_cast<int64_t>(16L)) { auto tmp0 = at::vec::Vectorized<float>::loadu(local_acc_buf + static_cast<int64_t>(x1 + 
(Nc_blocks*Nr*x0)), static_cast<int64_t>(16)); auto tmp1 = at::vec::convert<bfloat16>(tmp0); tmp1.store(Y + static_cast<int64_t>(n_start + x1 + (64L*m_start) + (64L*x0) + (24576L*ks_b_index)), static_cast<int64_t>(16)); } for(int64_t x1=static_cast<int64_t>(16L*(c10::div_floor_integer(static_cast<int64_t>((n_end + ((-1L)*n_start))), static_cast<int64_t>(16L)))); x1<static_cast<int64_t>(n_end + ((-1L)*n_start)); x1+=(static_cast<int64_t>(n_end + ((-1L)*n_start) + ((-16L)*(c10::div_floor_integer(static_cast<int64_t>((n_end + ((-1L)*n_start))), static_cast<int64_t>(16L))))) == 0 ? 1 : static_cast<int64_t>(n_end + ((-1L)*n_start) + ((-16L)*(c10::div_floor_integer(static_cast<int64_t>((n_end + ((-1L)*n_start))), static_cast<int64_t>(16L))))))) { auto tmp0 = at::vec::Vectorized<float>::loadu(local_acc_buf + static_cast<int64_t>(x1 + (Nc_blocks*Nr*x0)), static_cast<int64_t>(n_end + ((-1L)*n_start) + ((-16L)*(c10::div_floor_integer(static_cast<int64_t>((n_end + ((-1L)*n_start))), static_cast<int64_t>(16L)))))); auto tmp1 = at::vec::convert<bfloat16>(tmp0); tmp1.store(Y + static_cast<int64_t>(n_start + x1 + (64L*m_start) + (64L*x0) + (24576L*ks_b_index)), static_cast<int64_t>(n_end + ((-1L)*n_start) + ((-16L)*(c10::div_floor_integer(static_cast<int64_t>((n_end + ((-1L)*n_start))), static_cast<int64_t>(16L)))))); } } } } } } amx_state.release([]() { _tile_release(); }); } } void single_thread_mm(const bfloat16* X, const bfloat16* W, bfloat16* Y, const int64_t ks_b_index) { constexpr int64_t num_threads = 1; constexpr int64_t N = 64; constexpr int64_t K = 96; constexpr int64_t Mr = 32; constexpr int64_t Nr = 32; constexpr int64_t Kr = 32; constexpr int64_t Nr_blocks = (N + Nr - 1) / Nr; constexpr int64_t Kr_blocks = (K + Kr - 1) / Kr; constexpr int64_t M = static_cast<int64_t>(384L); constexpr int64_t Mr_blocks = (M + Mr - 1) / Mr; constexpr int64_t Mt_blocks = 12; constexpr int64_t Nt_blocks = 2; constexpr int64_t Kt_blocks = 3; constexpr int64_t Mc_blocks = 12; constexpr int64_t Nc_blocks = 1; constexpr int64_t Kc_blocks = 3; constexpr int64_t num_Mc_blocks = (Mr_blocks + Mc_blocks - 1) / Mc_blocks; constexpr int64_t num_Nc_blocks = (Nr_blocks + Nc_blocks - 1) / Nc_blocks; constexpr int64_t num_Mt_blocks = (Mr_blocks + Mt_blocks - 1) / Mt_blocks; constexpr int64_t num_Nt_blocks = (Nr_blocks + Nt_blocks - 1) / Nt_blocks; constexpr int64_t num_Kt_blocks = (Kr_blocks + Kt_blocks - 1) / Kt_blocks; // make sure all partitions are assigned AOTI_TORCH_CHECK( Mt_blocks * Nt_blocks * Kt_blocks * 1 >= Mr_blocks * Nr_blocks * Kr_blocks, "Not all partitions are assigned." 
); { constexpr int tid = 0; constexpr int64_t k_group_id = 0; constexpr int64_t k_slice_id = 0; constexpr int64_t n_group_id = 0; constexpr int64_t n_slice_id = 0; constexpr int64_t m_block_start = 0; constexpr int64_t n_block_start = 0; constexpr int64_t n_block_end = Nr_blocks; constexpr int64_t k_block_start = 0; constexpr int64_t k_block_end = Kr_blocks; constexpr int64_t num_Mc_blocks_per_thread = num_Mc_blocks; constexpr int64_t m_block_end = Mr_blocks; AMXState amx_state; auto _local_acc_buf = std::make_unique<float[]>(static_cast<int64_t>(Mc_blocks*Mr*Nc_blocks*Nr)); auto local_acc_buf = _local_acc_buf.get(); for (int64_t mc_block_id = 0; mc_block_id < num_Mc_blocks_per_thread; mc_block_id++) { const int64_t my_mc_block_id = (mc_block_id + n_slice_id) % num_Mc_blocks_per_thread; const int64_t mc = m_block_start + my_mc_block_id * Mc_blocks; const int64_t m_start = mc * Mr; const int64_t m_end = std::min(std::min(mc + Mc_blocks, m_block_end) * Mr, M); const int64_t m_size = m_end - m_start; for (int64_t nc = n_block_start; nc < n_block_end; nc += Nc_blocks) { const int64_t n_start = nc * Nr; const int64_t n_end = std::min(std::min(nc + Nc_blocks, n_block_end) * Nr, N); const int64_t n_size = n_end - n_start; // NB: assume we pad N, nc_block_end won't exceed padded N here. const int64_t nc_block_end = std::min(nc + Nc_blocks, n_block_end); if (_local_acc_buf == nullptr) { _local_acc_buf = std::make_unique<float[]>(static_cast<int64_t>(Mc_blocks*Mr*Nc_blocks*Nr)); local_acc_buf = _local_acc_buf.get(); } for (int64_t kc = k_block_start; kc < k_block_end; kc += Kc_blocks) { int64_t k_start = kc * Kr; int64_t k_end = std::min(std::min(kc + Kc_blocks, k_block_end) * Kr, K); for (int64_t nci = nc; nci < nc_block_end; nci++) { if (kc == k_block_start) { cpp_bmm_micro_gemm<static_cast<bool>(false)>( amx_state, &(X[static_cast<int64_t>(k_start + (96L*m_start) + (36864L*ks_b_index))]), &(W[static_cast<int64_t>((32L*k_start) + (3072L*nci) + (6144L*ks_b_index))]), &(local_acc_buf[static_cast<int64_t>((Nr*nci) + ((-1L)*Nr*nc))]), static_cast<int64_t>(m_end + ((-1L)*m_start)), static_cast<int64_t>(Nr), static_cast<int64_t>(k_end + ((-1L)*k_start)), static_cast<int64_t>(96L), static_cast<int64_t>(32L), static_cast<int64_t>(Nc_blocks*Nr) ); } else { cpp_bmm_micro_gemm<static_cast<bool>(true)>( amx_state, &(X[static_cast<int64_t>(k_start + (96L*m_start) + (36864L*ks_b_index))]), &(W[static_cast<int64_t>((32L*k_start) + (3072L*nci) + (6144L*ks_b_index))]), &(local_acc_buf[static_cast<int64_t>((Nr*nci) + ((-1L)*Nr*nc))]), static_cast<int64_t>(m_end + ((-1L)*m_start)), static_cast<int64_t>(Nr), static_cast<int64_t>(k_end + ((-1L)*k_start)), static_cast<int64_t>(96L), static_cast<int64_t>(32L), static_cast<int64_t>(Nc_blocks*Nr) ); } } } { { #pragma GCC ivdep for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(m_end + ((-1L)*m_start)); x0+=static_cast<int64_t>(1L)) { for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(16L*(c10::div_floor_integer(static_cast<int64_t>((n_end + ((-1L)*n_start))), static_cast<int64_t>(16L)))); x1+=static_cast<int64_t>(16L)) { auto tmp0 = at::vec::Vectorized<float>::loadu(local_acc_buf + static_cast<int64_t>(x1 + (Nc_blocks*Nr*x0)), static_cast<int64_t>(16)); auto tmp1 = at::vec::convert<bfloat16>(tmp0); tmp1.store(Y + static_cast<int64_t>(n_start + x1 + (64L*m_start) + (64L*x0) + (24576L*ks_b_index)), static_cast<int64_t>(16)); } for(int64_t x1=static_cast<int64_t>(16L*(c10::div_floor_integer(static_cast<int64_t>((n_end + ((-1L)*n_start))), 
static_cast<int64_t>(16L)))); x1<static_cast<int64_t>(n_end + ((-1L)*n_start)); x1+=(static_cast<int64_t>(n_end + ((-1L)*n_start) + ((-16L)*(c10::div_floor_integer(static_cast<int64_t>((n_end + ((-1L)*n_start))), static_cast<int64_t>(16L))))) == 0 ? 1 : static_cast<int64_t>(n_end + ((-1L)*n_start) + ((-16L)*(c10::div_floor_integer(static_cast<int64_t>((n_end + ((-1L)*n_start))), static_cast<int64_t>(16L))))))) { auto tmp0 = at::vec::Vectorized<float>::loadu(local_acc_buf + static_cast<int64_t>(x1 + (Nc_blocks*Nr*x0)), static_cast<int64_t>(n_end + ((-1L)*n_start) + ((-16L)*(c10::div_floor_integer(static_cast<int64_t>((n_end + ((-1L)*n_start))), static_cast<int64_t>(16L)))))); auto tmp1 = at::vec::convert<bfloat16>(tmp0); tmp1.store(Y + static_cast<int64_t>(n_start + x1 + (64L*m_start) + (64L*x0) + (24576L*ks_b_index)), static_cast<int64_t>(n_end + ((-1L)*n_start) + ((-16L)*(c10::div_floor_integer(static_cast<int64_t>((n_end + ((-1L)*n_start))), static_cast<int64_t>(16L)))))); } } } } } } amx_state.release([]() { _tile_release(); }); } } extern "C" void cpp_bmm(const bfloat16* X, const bfloat16* W, bfloat16* Y) { const int64_t B = static_cast<int64_t>(5L); constexpr int64_t num_threads = 48; int64_t B_single_thread_block = (B / num_threads) * num_threads; #pragma omp parallel for num_threads(48) for (int64_t b_start = 0; b_start < B_single_thread_block; ++b_start) { single_thread_mm(X, W, Y, b_start); } for (int64_t b_start = B_single_thread_block; b_start < B; ++b_start) { threaded_mm(X, W, Y, b_start); } } ``` Pull Request resolved: #129772 Approved by: https://github.com/jgong5, https://github.com/leslie-fang-intel, https://github.com/jansel
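As referenced in item 1 of the description above, the batch decomposition performed by `cpp_bmm` at the end of the generated code can be summarized with the following Python pseudocode. `single_thread_mm` and `threaded_mm` stand in for the two renderings of the GEMM template, and the thread pool is only a stand-in for the OpenMP parallel region.

```
# Scheduling sketch of CppBmmTemplate's batch decomposition (illustrative pseudocode).
from concurrent.futures import ThreadPoolExecutor

def schedule_bmm(B, num_threads, single_thread_mm, threaded_mm):
    b_single = (B // num_threads) * num_threads
    # Phase 1: the first B - (B % num_threads) sub-gemms run one-per-thread,
    # each using the single-threaded rendering of the GEMM template.
    with ThreadPoolExecutor(max_workers=num_threads) as pool:
        list(pool.map(single_thread_mm, range(b_single)))
    # Phase 2: the remaining B % num_threads sub-gemms each use the rendering
    # that parallelizes the GEMM itself across all threads.
    for b in range(b_single, B):
        threaded_mm(b)
```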
(Nc_blocks*Nr*x0)), static_cast<int64_t>(16)); auto tmp1 = at::vec::convert<bfloat16>(tmp0); tmp1.store(Y + static_cast<int64_t>(n_start + x1 + (64L*m_start) + (64L*x0) + (24576L*ks_b_index)), static_cast<int64_t>(16)); } for(int64_t x1=static_cast<int64_t>(16L*(c10::div_floor_integer(static_cast<int64_t>((n_end + ((-1L)*n_start))), static_cast<int64_t>(16L)))); x1<static_cast<int64_t>(n_end + ((-1L)*n_start)); x1+=(static_cast<int64_t>(n_end + ((-1L)*n_start) + ((-16L)*(c10::div_floor_integer(static_cast<int64_t>((n_end + ((-1L)*n_start))), static_cast<int64_t>(16L))))) == 0 ? 1 : static_cast<int64_t>(n_end + ((-1L)*n_start) + ((-16L)*(c10::div_floor_integer(static_cast<int64_t>((n_end + ((-1L)*n_start))), static_cast<int64_t>(16L))))))) { auto tmp0 = at::vec::Vectorized<float>::loadu(local_acc_buf + static_cast<int64_t>(x1 + (Nc_blocks*Nr*x0)), static_cast<int64_t>(n_end + ((-1L)*n_start) + ((-16L)*(c10::div_floor_integer(static_cast<int64_t>((n_end + ((-1L)*n_start))), static_cast<int64_t>(16L)))))); auto tmp1 = at::vec::convert<bfloat16>(tmp0); tmp1.store(Y + static_cast<int64_t>(n_start + x1 + (64L*m_start) + (64L*x0) + (24576L*ks_b_index)), static_cast<int64_t>(n_end + ((-1L)*n_start) + ((-16L)*(c10::div_floor_integer(static_cast<int64_t>((n_end + ((-1L)*n_start))), static_cast<int64_t>(16L)))))); } } } } } } amx_state.release([]() { _tile_release(); }); } } void single_thread_mm(const bfloat16* X, const bfloat16* W, bfloat16* Y, const int64_t ks_b_index) { constexpr int64_t num_threads = 1; constexpr int64_t N = 64; constexpr int64_t K = 96; constexpr int64_t Mr = 32; constexpr int64_t Nr = 32; constexpr int64_t Kr = 32; constexpr int64_t Nr_blocks = (N + Nr - 1) / Nr; constexpr int64_t Kr_blocks = (K + Kr - 1) / Kr; constexpr int64_t M = static_cast<int64_t>(384L); constexpr int64_t Mr_blocks = (M + Mr - 1) / Mr; constexpr int64_t Mt_blocks = 12; constexpr int64_t Nt_blocks = 2; constexpr int64_t Kt_blocks = 3; constexpr int64_t Mc_blocks = 12; constexpr int64_t Nc_blocks = 1; constexpr int64_t Kc_blocks = 3; constexpr int64_t num_Mc_blocks = (Mr_blocks + Mc_blocks - 1) / Mc_blocks; constexpr int64_t num_Nc_blocks = (Nr_blocks + Nc_blocks - 1) / Nc_blocks; constexpr int64_t num_Mt_blocks = (Mr_blocks + Mt_blocks - 1) / Mt_blocks; constexpr int64_t num_Nt_blocks = (Nr_blocks + Nt_blocks - 1) / Nt_blocks; constexpr int64_t num_Kt_blocks = (Kr_blocks + Kt_blocks - 1) / Kt_blocks; // make sure all partitions are assigned AOTI_TORCH_CHECK( Mt_blocks * Nt_blocks * Kt_blocks * 1 >= Mr_blocks * Nr_blocks * Kr_blocks, "Not all partitions are assigned." 
); { constexpr int tid = 0; constexpr int64_t k_group_id = 0; constexpr int64_t k_slice_id = 0; constexpr int64_t n_group_id = 0; constexpr int64_t n_slice_id = 0; constexpr int64_t m_block_start = 0; constexpr int64_t n_block_start = 0; constexpr int64_t n_block_end = Nr_blocks; constexpr int64_t k_block_start = 0; constexpr int64_t k_block_end = Kr_blocks; constexpr int64_t num_Mc_blocks_per_thread = num_Mc_blocks; constexpr int64_t m_block_end = Mr_blocks; AMXState amx_state; auto _local_acc_buf = std::make_unique<float[]>(static_cast<int64_t>(Mc_blocks*Mr*Nc_blocks*Nr)); auto local_acc_buf = _local_acc_buf.get(); for (int64_t mc_block_id = 0; mc_block_id < num_Mc_blocks_per_thread; mc_block_id++) { const int64_t my_mc_block_id = (mc_block_id + n_slice_id) % num_Mc_blocks_per_thread; const int64_t mc = m_block_start + my_mc_block_id * Mc_blocks; const int64_t m_start = mc * Mr; const int64_t m_end = std::min(std::min(mc + Mc_blocks, m_block_end) * Mr, M); const int64_t m_size = m_end - m_start; for (int64_t nc = n_block_start; nc < n_block_end; nc += Nc_blocks) { const int64_t n_start = nc * Nr; const int64_t n_end = std::min(std::min(nc + Nc_blocks, n_block_end) * Nr, N); const int64_t n_size = n_end - n_start; // NB: assume we pad N, nc_block_end won't exceed padded N here. const int64_t nc_block_end = std::min(nc + Nc_blocks, n_block_end); if (_local_acc_buf == nullptr) { _local_acc_buf = std::make_unique<float[]>(static_cast<int64_t>(Mc_blocks*Mr*Nc_blocks*Nr)); local_acc_buf = _local_acc_buf.get(); } for (int64_t kc = k_block_start; kc < k_block_end; kc += Kc_blocks) { int64_t k_start = kc * Kr; int64_t k_end = std::min(std::min(kc + Kc_blocks, k_block_end) * Kr, K); for (int64_t nci = nc; nci < nc_block_end; nci++) { if (kc == k_block_start) { cpp_bmm_micro_gemm<static_cast<bool>(false)>( amx_state, &(X[static_cast<int64_t>(k_start + (96L*m_start) + (36864L*ks_b_index))]), &(W[static_cast<int64_t>((32L*k_start) + (3072L*nci) + (6144L*ks_b_index))]), &(local_acc_buf[static_cast<int64_t>((Nr*nci) + ((-1L)*Nr*nc))]), static_cast<int64_t>(m_end + ((-1L)*m_start)), static_cast<int64_t>(Nr), static_cast<int64_t>(k_end + ((-1L)*k_start)), static_cast<int64_t>(96L), static_cast<int64_t>(32L), static_cast<int64_t>(Nc_blocks*Nr) ); } else { cpp_bmm_micro_gemm<static_cast<bool>(true)>( amx_state, &(X[static_cast<int64_t>(k_start + (96L*m_start) + (36864L*ks_b_index))]), &(W[static_cast<int64_t>((32L*k_start) + (3072L*nci) + (6144L*ks_b_index))]), &(local_acc_buf[static_cast<int64_t>((Nr*nci) + ((-1L)*Nr*nc))]), static_cast<int64_t>(m_end + ((-1L)*m_start)), static_cast<int64_t>(Nr), static_cast<int64_t>(k_end + ((-1L)*k_start)), static_cast<int64_t>(96L), static_cast<int64_t>(32L), static_cast<int64_t>(Nc_blocks*Nr) ); } } } { { #pragma GCC ivdep for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(m_end + ((-1L)*m_start)); x0+=static_cast<int64_t>(1L)) { for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(16L*(c10::div_floor_integer(static_cast<int64_t>((n_end + ((-1L)*n_start))), static_cast<int64_t>(16L)))); x1+=static_cast<int64_t>(16L)) { auto tmp0 = at::vec::Vectorized<float>::loadu(local_acc_buf + static_cast<int64_t>(x1 + (Nc_blocks*Nr*x0)), static_cast<int64_t>(16)); auto tmp1 = at::vec::convert<bfloat16>(tmp0); tmp1.store(Y + static_cast<int64_t>(n_start + x1 + (64L*m_start) + (64L*x0) + (24576L*ks_b_index)), static_cast<int64_t>(16)); } for(int64_t x1=static_cast<int64_t>(16L*(c10::div_floor_integer(static_cast<int64_t>((n_end + ((-1L)*n_start))), 
static_cast<int64_t>(16L)))); x1<static_cast<int64_t>(n_end + ((-1L)*n_start)); x1+=(static_cast<int64_t>(n_end + ((-1L)*n_start) + ((-16L)*(c10::div_floor_integer(static_cast<int64_t>((n_end + ((-1L)*n_start))), static_cast<int64_t>(16L))))) == 0 ? 1 : static_cast<int64_t>(n_end + ((-1L)*n_start) + ((-16L)*(c10::div_floor_integer(static_cast<int64_t>((n_end + ((-1L)*n_start))), static_cast<int64_t>(16L))))))) { auto tmp0 = at::vec::Vectorized<float>::loadu(local_acc_buf + static_cast<int64_t>(x1 + (Nc_blocks*Nr*x0)), static_cast<int64_t>(n_end + ((-1L)*n_start) + ((-16L)*(c10::div_floor_integer(static_cast<int64_t>((n_end + ((-1L)*n_start))), static_cast<int64_t>(16L)))))); auto tmp1 = at::vec::convert<bfloat16>(tmp0); tmp1.store(Y + static_cast<int64_t>(n_start + x1 + (64L*m_start) + (64L*x0) + (24576L*ks_b_index)), static_cast<int64_t>(n_end + ((-1L)*n_start) + ((-16L)*(c10::div_floor_integer(static_cast<int64_t>((n_end + ((-1L)*n_start))), static_cast<int64_t>(16L)))))); } } } } } } amx_state.release([]() { _tile_release(); }); } } extern "C" void cpp_bmm(const bfloat16* X, const bfloat16* W, bfloat16* Y) { const int64_t B = static_cast<int64_t>(5L); constexpr int64_t num_threads = 48; int64_t B_single_thread_block = (B / num_threads) * num_threads; #pragma omp parallel for num_threads(48) for (int64_t b_start = 0; b_start < B_single_thread_block; ++b_start) { single_thread_mm(X, W, Y, b_start); } for (int64_t b_start = B_single_thread_block; b_start < B; ++b_start) { threaded_mm(X, W, Y, b_start); } } ``` Pull Request resolved: pytorch#129772 Approved by: https://github.com/jgong5, https://github.com/leslie-fang-intel, https://github.com/jansel
pytorch-bot bot
pushed a commit
that referenced
this issue
Dec 9, 2024
This PR adds the Cpp template for BMM, for FP32, FP16, and BF16. See #125683 for more background.

1. Adds the `CppBmmTemplate` class, which inherits from `CppPackedGemmTemplate`. Given a number of worker threads `num_threads` and a batch size `B`, it executes the GEMM kernel as follows: for the first `B - (B % num_threads)` batch inputs, one sub-GEMM problem is run per thread; the remaining `B % num_threads` sub-GEMMs are each executed with the parallelized GEMM kernel. To manage this code, the `GEMM_TEMPLATE` from `CppPackedGemmTemplate` is rendered twice, once for a single thread and once including the parallel OMP pragma.
2. Adapts `CppPackedGemmTemplate` to allow for child classes. The `GEMM_TEMPLATE` is split into separate strings so that it can be rendered by the child class. Slicing/indexing are adapted to allow for 3D BMM inputs. Additional methods `get_options()` and `_get_params_for_choices()` are added to reduce code duplication.

The BMM within the `dlrm` benchmark has a single input buffer that is used for both the X and W inputs; this case is not yet supported by this PR.

### Performance
On Granite Rapids/Sapphire Rapids, the cpp_bmm template code uses AMX, which requires an expensive transpose operation, so the BMM op is rarely selected as faster than the existing external bmm kernel. As a result, the speedup on SPR is identical with and without the BMM code. The pass rate matches the rates for main exactly.

#### Test Summary on Granite Rapids
| Test Scenario | Comp Item | Config | Compiler | torchbench | huggingface | timm_models |
|---|---|---|---|---|---|---|
| Single Socket Multi-Threads | Pass Rate | gemm autotune | inductor | 91%, 73/80 | 100%, 46/46 | 100%, 61/61 |
| | | bmm + gemm autotune | inductor | 91%, 73/80 | 100%, 46/46 | 100%, 61/61 |
| | Geomean Speedup | gemm autotune | inductor | 2.15x | 1.91x | 2.52x |
| | | bmm + gemm autotune | inductor | 2.15x | 1.96x | 2.53x |
| Single Core Single-Thread | Pass Rate | gemm autotune | inductor | 91%, 73/80 | 100%, 46/46 | 100%, 61/61 |
| | | bmm + gemm autotune | inductor | 91%, 73/80 | 100%, 46/46 | 100%, 61/61 |
| | Geomean Speedup | inductor_locally_benchmark_586 | inductor | 2.43x | 1.56x | 2.60x |
| | | inductor_locally_benchmark_585 | inductor | 2.45x | 1.56x | 2.63x |

This is not the case on an older Skylake Xeon machine. For the BMM ops contained in torchbench models, bmm performance improves by 1.10x-2.64x.
#### BF16 28-core Skylake Xeon | Model | Inductor | GemmAutotune | Gemm+BMM Autotune | |--------|--------|--------|--------| | BERT_pytorch | 1.233x | 2.597x | 2.608x | | hf_DistilBert | 1.128x | 2.242x | 2.368x | | hf_Reformer | 1.124x | 1.419x | 1.590x | | hf_T5_base | 1.012x | 1.257x | 1.382x | | hf_T5_large | 1.085x | 2.228x | 2.345x | ## Example BMM Code ``` #include <c10/util/Unroll.h> #include <torch/csrc/inductor/aoti_torch/c/shim.h> template <bool accum> inline void cpp_bmm_micro_gemm_amx_kernel_32_2( AMXState& amx_state, const bfloat16* __restrict__ A, const bfloat16* __restrict__ B, float* __restrict__ C, int64_t K, int64_t lda, int64_t ldb, int64_t ldc, uint8_t tilecfg_rows ) { // TODO(jgong5): add prefetch hint for A, B, C auto loadconfig = [](const amx_tilecfg& cfg) { _tile_loadconfig(&cfg); }; const auto last_k_offset = K / 32 * 32; const auto tail_k_size = K - last_k_offset; if C10_LIKELY (last_k_offset > 0) { amx_state.configure(tilecfg_rows, 64, 32 / 16, 2, loadconfig); } else { amx_state.configure(tilecfg_rows, tail_k_size * sizeof(bfloat16), 32 / 16, 2, loadconfig); } auto load_c = [&]() { _tile_loadd(0, C + 0 * ldc + 0, ldc * sizeof(float)); _tile_loadd(1, C + 0 * ldc + 16, ldc * sizeof(float)); _tile_loadd(2, C + 16 * ldc + 0, ldc * sizeof(float)); _tile_loadd(3, C + 16 * ldc + 16, ldc * sizeof(float)); }; auto zero_c = [&]() { _tile_zero(0); _tile_zero(1); _tile_zero(2); _tile_zero(3); }; if constexpr (accum) { load_c(); } else { zero_c(); } auto compute = [&](int k) { _tile_stream_loadd(4, A + 0 * lda + k, lda * sizeof(bfloat16)); _tile_loadd(6, B + k * ldb + 0, ldb * 2 * sizeof(bfloat16)); _tile_dpbf16ps(0, 4, 6); _tile_loadd(7, B + k * ldb + 32, ldb * 2 * sizeof(bfloat16)); _tile_dpbf16ps(1, 4, 7); _tile_stream_loadd(5, A + 16 * lda + k, lda * sizeof(bfloat16)); _tile_dpbf16ps(2, 5, 6); _tile_dpbf16ps(3, 5, 7); }; #pragma GCC unroll 4 for (int k = 0; k < last_k_offset; k += 32) { compute(k); } auto store_c = [&]() { // store to C _tile_stored(0, C + 0 * ldc + 0, ldc * sizeof(float)); _tile_stored(1, C + 0 * ldc + 16, ldc * sizeof(float)); _tile_stored(2, C + 16 * ldc + 0, ldc * sizeof(float)); _tile_stored(3, C + 16 * ldc + 16, ldc * sizeof(float)); }; // TODO(jgong5): move tail k computation to separate loopnest to save tile configuration overhead if C10_UNLIKELY (tail_k_size > 0) { if C10_LIKELY (last_k_offset > 0) { store_c(); amx_state.configure(tilecfg_rows, tail_k_size * sizeof(bfloat16), 32 / 16, 2, loadconfig); load_c(); } compute(last_k_offset); } store_c(); } template <bool accum> inline void cpp_bmm_micro_gemm_amx_kernel_16_2( AMXState& amx_state, const bfloat16* __restrict__ A, const bfloat16* __restrict__ B, float* __restrict__ C, int64_t K, int64_t lda, int64_t ldb, int64_t ldc, uint8_t tilecfg_rows ) { // TODO(jgong5): add prefetch hint for A, B, C auto loadconfig = [](const amx_tilecfg& cfg) { _tile_loadconfig(&cfg); }; const auto last_k_offset = K / 32 * 32; const auto tail_k_size = K - last_k_offset; if C10_LIKELY (last_k_offset > 0) { amx_state.configure(tilecfg_rows, 64, 16 / 16, 2, loadconfig); } else { amx_state.configure(tilecfg_rows, tail_k_size * sizeof(bfloat16), 16 / 16, 2, loadconfig); } auto load_c = [&]() { _tile_loadd(0, C + 0 * ldc + 0, ldc * sizeof(float)); _tile_loadd(1, C + 0 * ldc + 16, ldc * sizeof(float)); }; auto zero_c = [&]() { _tile_zero(0); _tile_zero(1); }; if constexpr (accum) { load_c(); } else { zero_c(); } auto compute = [&](int k) { _tile_stream_loadd(2, A + 0 * lda + k, lda * sizeof(bfloat16)); _tile_loadd(3, B + k 
* ldb + 0, ldb * 2 * sizeof(bfloat16)); _tile_dpbf16ps(0, 2, 3); _tile_loadd(4, B + k * ldb + 32, ldb * 2 * sizeof(bfloat16)); _tile_dpbf16ps(1, 2, 4); }; #pragma GCC unroll 4 for (int k = 0; k < last_k_offset; k += 32) { compute(k); } auto store_c = [&]() { // store to C _tile_stored(0, C + 0 * ldc + 0, ldc * sizeof(float)); _tile_stored(1, C + 0 * ldc + 16, ldc * sizeof(float)); }; // TODO(jgong5): move tail k computation to separate loopnest to save tile configuration overhead if C10_UNLIKELY (tail_k_size > 0) { if C10_LIKELY (last_k_offset > 0) { store_c(); amx_state.configure(tilecfg_rows, tail_k_size * sizeof(bfloat16), 16 / 16, 2, loadconfig); load_c(); } compute(last_k_offset); } store_c(); } template <bool accum> inline void cpp_bmm_micro_gemm( AMXState& amx_state, const bfloat16* __restrict__ A, const bfloat16* __restrict__ B, float* __restrict__ C, int64_t M, int64_t N, int64_t K, int64_t lda, int64_t ldb, int64_t ldc ) { AOTI_TORCH_CHECK(N % 32 == 0, "N dimension must be multiple of 32"); AOTI_TORCH_CHECK(K % 2 == 0, "K dimension must be multiple of 2"); // TODO(jgong5): loop unroll for M and N for (int64_t n = 0; n < N; n += 32) { for (int64_t m = 0; m < M; m += 32) { int64_t block_m = std::min<int64_t>(M - m, 32); int64_t m_tail = m; if (block_m >= 32) { cpp_bmm_micro_gemm_amx_kernel_32_2<accum>( amx_state, A + m * lda, B + n, C + m * ldc + n, K, lda, ldb, ldc, 16 ); block_m -= 32; m_tail += 32; } else if (block_m >= 16) { cpp_bmm_micro_gemm_amx_kernel_16_2<accum>( amx_state, A + m * lda, B + n, C + m * ldc + n, K, lda, ldb, ldc, 16 ); block_m -= 16; m_tail += 16; } if (block_m > 0) { cpp_bmm_micro_gemm_amx_kernel_16_2<accum>( amx_state, A + m_tail * lda, B + n, C + m_tail * ldc + n, K, lda, ldb, ldc, block_m ); } } } } void threaded_mm(const bfloat16* X, const bfloat16* W, bfloat16* Y, const int64_t ks_b_index) { constexpr int64_t num_threads = 48; constexpr int64_t N = 64; constexpr int64_t K = 96; constexpr int64_t Mr = 32; constexpr int64_t Nr = 32; constexpr int64_t Kr = 32; constexpr int64_t Nr_blocks = (N + Nr - 1) / Nr; constexpr int64_t Kr_blocks = (K + Kr - 1) / Kr; constexpr int64_t M = static_cast<int64_t>(384L); constexpr int64_t Mr_blocks = (M + Mr - 1) / Mr; constexpr int64_t Mt_blocks = 1; constexpr int64_t Nt_blocks = 1; constexpr int64_t Kt_blocks = 3; constexpr int64_t Mc_blocks = 1; constexpr int64_t Nc_blocks = 1; constexpr int64_t Kc_blocks = 3; constexpr int64_t num_Mc_blocks = (Mr_blocks + Mc_blocks - 1) / Mc_blocks; constexpr int64_t num_Nc_blocks = (Nr_blocks + Nc_blocks - 1) / Nc_blocks; constexpr int64_t num_Mt_blocks = (Mr_blocks + Mt_blocks - 1) / Mt_blocks; constexpr int64_t num_Nt_blocks = (Nr_blocks + Nt_blocks - 1) / Nt_blocks; constexpr int64_t num_Kt_blocks = (Kr_blocks + Kt_blocks - 1) / Kt_blocks; // make sure all partitions are assigned AOTI_TORCH_CHECK( Mt_blocks * Nt_blocks * Kt_blocks * 48 >= Mr_blocks * Nr_blocks * Kr_blocks, "Not all partitions are assigned." 
); #pragma omp parallel num_threads(48) { const int tid = omp_get_thread_num(); const int64_t k_group_id = tid / num_Kt_blocks; const int64_t k_slice_id = tid % num_Kt_blocks; const int64_t n_group_id = k_group_id / num_Nt_blocks; const int64_t n_slice_id = k_group_id % num_Nt_blocks; const int64_t k_block_start = k_slice_id * Kt_blocks; const int64_t k_block_end = std::min(k_block_start + Kt_blocks, Kr_blocks); const int64_t n_block_start = n_slice_id * Nt_blocks; const int64_t n_block_end = std::min(n_block_start + Nt_blocks, Nr_blocks); const int64_t m_block_start = std::min(n_group_id * Mt_blocks, Mr_blocks); const int64_t m_block_end = std::min(m_block_start + Mt_blocks, Mr_blocks); const int64_t num_Mc_blocks_per_thread = (m_block_end - m_block_start + Mc_blocks - 1) / Mc_blocks; AMXState amx_state; auto _local_acc_buf = std::make_unique<float[]>(static_cast<int64_t>(Mc_blocks*Mr*Nc_blocks*Nr)); auto local_acc_buf = _local_acc_buf.get(); for (int64_t mc_block_id = 0; mc_block_id < num_Mc_blocks_per_thread; mc_block_id++) { const int64_t my_mc_block_id = (mc_block_id + n_slice_id) % num_Mc_blocks_per_thread; const int64_t mc = m_block_start + my_mc_block_id * Mc_blocks; const int64_t m_start = mc * Mr; const int64_t m_end = std::min(std::min(mc + Mc_blocks, m_block_end) * Mr, M); const int64_t m_size = m_end - m_start; for (int64_t nc = n_block_start; nc < n_block_end; nc += Nc_blocks) { const int64_t n_start = nc * Nr; const int64_t n_end = std::min(std::min(nc + Nc_blocks, n_block_end) * Nr, N); const int64_t n_size = n_end - n_start; // NB: assume we pad N, nc_block_end won't exceed padded N here. const int64_t nc_block_end = std::min(nc + Nc_blocks, n_block_end); if (_local_acc_buf == nullptr) { _local_acc_buf = std::make_unique<float[]>(static_cast<int64_t>(Mc_blocks*Mr*Nc_blocks*Nr)); local_acc_buf = _local_acc_buf.get(); } for (int64_t kc = k_block_start; kc < k_block_end; kc += Kc_blocks) { int64_t k_start = kc * Kr; int64_t k_end = std::min(std::min(kc + Kc_blocks, k_block_end) * Kr, K); for (int64_t nci = nc; nci < nc_block_end; nci++) { if (kc == k_block_start) { cpp_bmm_micro_gemm<static_cast<bool>(false)>( amx_state, &(X[static_cast<int64_t>(k_start + (96L*m_start) + (36864L*ks_b_index))]), &(W[static_cast<int64_t>((32L*k_start) + (3072L*nci) + (6144L*ks_b_index))]), &(local_acc_buf[static_cast<int64_t>((Nr*nci) + ((-1L)*Nr*nc))]), static_cast<int64_t>(m_end + ((-1L)*m_start)), static_cast<int64_t>(Nr), static_cast<int64_t>(k_end + ((-1L)*k_start)), static_cast<int64_t>(96L), static_cast<int64_t>(32L), static_cast<int64_t>(Nc_blocks*Nr) ); } else { cpp_bmm_micro_gemm<static_cast<bool>(true)>( amx_state, &(X[static_cast<int64_t>(k_start + (96L*m_start) + (36864L*ks_b_index))]), &(W[static_cast<int64_t>((32L*k_start) + (3072L*nci) + (6144L*ks_b_index))]), &(local_acc_buf[static_cast<int64_t>((Nr*nci) + ((-1L)*Nr*nc))]), static_cast<int64_t>(m_end + ((-1L)*m_start)), static_cast<int64_t>(Nr), static_cast<int64_t>(k_end + ((-1L)*k_start)), static_cast<int64_t>(96L), static_cast<int64_t>(32L), static_cast<int64_t>(Nc_blocks*Nr) ); } } } { { #pragma GCC ivdep for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(m_end + ((-1L)*m_start)); x0+=static_cast<int64_t>(1L)) { for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(16L*(c10::div_floor_integer(static_cast<int64_t>((n_end + ((-1L)*n_start))), static_cast<int64_t>(16L)))); x1+=static_cast<int64_t>(16L)) { auto tmp0 = at::vec::Vectorized<float>::loadu(local_acc_buf + static_cast<int64_t>(x1 + 
(Nc_blocks*Nr*x0)), static_cast<int64_t>(16)); auto tmp1 = at::vec::convert<bfloat16>(tmp0); tmp1.store(Y + static_cast<int64_t>(n_start + x1 + (64L*m_start) + (64L*x0) + (24576L*ks_b_index)), static_cast<int64_t>(16)); } for(int64_t x1=static_cast<int64_t>(16L*(c10::div_floor_integer(static_cast<int64_t>((n_end + ((-1L)*n_start))), static_cast<int64_t>(16L)))); x1<static_cast<int64_t>(n_end + ((-1L)*n_start)); x1+=(static_cast<int64_t>(n_end + ((-1L)*n_start) + ((-16L)*(c10::div_floor_integer(static_cast<int64_t>((n_end + ((-1L)*n_start))), static_cast<int64_t>(16L))))) == 0 ? 1 : static_cast<int64_t>(n_end + ((-1L)*n_start) + ((-16L)*(c10::div_floor_integer(static_cast<int64_t>((n_end + ((-1L)*n_start))), static_cast<int64_t>(16L))))))) { auto tmp0 = at::vec::Vectorized<float>::loadu(local_acc_buf + static_cast<int64_t>(x1 + (Nc_blocks*Nr*x0)), static_cast<int64_t>(n_end + ((-1L)*n_start) + ((-16L)*(c10::div_floor_integer(static_cast<int64_t>((n_end + ((-1L)*n_start))), static_cast<int64_t>(16L)))))); auto tmp1 = at::vec::convert<bfloat16>(tmp0); tmp1.store(Y + static_cast<int64_t>(n_start + x1 + (64L*m_start) + (64L*x0) + (24576L*ks_b_index)), static_cast<int64_t>(n_end + ((-1L)*n_start) + ((-16L)*(c10::div_floor_integer(static_cast<int64_t>((n_end + ((-1L)*n_start))), static_cast<int64_t>(16L)))))); } } } } } } amx_state.release([]() { _tile_release(); }); } } void single_thread_mm(const bfloat16* X, const bfloat16* W, bfloat16* Y, const int64_t ks_b_index) { constexpr int64_t num_threads = 1; constexpr int64_t N = 64; constexpr int64_t K = 96; constexpr int64_t Mr = 32; constexpr int64_t Nr = 32; constexpr int64_t Kr = 32; constexpr int64_t Nr_blocks = (N + Nr - 1) / Nr; constexpr int64_t Kr_blocks = (K + Kr - 1) / Kr; constexpr int64_t M = static_cast<int64_t>(384L); constexpr int64_t Mr_blocks = (M + Mr - 1) / Mr; constexpr int64_t Mt_blocks = 12; constexpr int64_t Nt_blocks = 2; constexpr int64_t Kt_blocks = 3; constexpr int64_t Mc_blocks = 12; constexpr int64_t Nc_blocks = 1; constexpr int64_t Kc_blocks = 3; constexpr int64_t num_Mc_blocks = (Mr_blocks + Mc_blocks - 1) / Mc_blocks; constexpr int64_t num_Nc_blocks = (Nr_blocks + Nc_blocks - 1) / Nc_blocks; constexpr int64_t num_Mt_blocks = (Mr_blocks + Mt_blocks - 1) / Mt_blocks; constexpr int64_t num_Nt_blocks = (Nr_blocks + Nt_blocks - 1) / Nt_blocks; constexpr int64_t num_Kt_blocks = (Kr_blocks + Kt_blocks - 1) / Kt_blocks; // make sure all partitions are assigned AOTI_TORCH_CHECK( Mt_blocks * Nt_blocks * Kt_blocks * 1 >= Mr_blocks * Nr_blocks * Kr_blocks, "Not all partitions are assigned." 
); { constexpr int tid = 0; constexpr int64_t k_group_id = 0; constexpr int64_t k_slice_id = 0; constexpr int64_t n_group_id = 0; constexpr int64_t n_slice_id = 0; constexpr int64_t m_block_start = 0; constexpr int64_t n_block_start = 0; constexpr int64_t n_block_end = Nr_blocks; constexpr int64_t k_block_start = 0; constexpr int64_t k_block_end = Kr_blocks; constexpr int64_t num_Mc_blocks_per_thread = num_Mc_blocks; constexpr int64_t m_block_end = Mr_blocks; AMXState amx_state; auto _local_acc_buf = std::make_unique<float[]>(static_cast<int64_t>(Mc_blocks*Mr*Nc_blocks*Nr)); auto local_acc_buf = _local_acc_buf.get(); for (int64_t mc_block_id = 0; mc_block_id < num_Mc_blocks_per_thread; mc_block_id++) { const int64_t my_mc_block_id = (mc_block_id + n_slice_id) % num_Mc_blocks_per_thread; const int64_t mc = m_block_start + my_mc_block_id * Mc_blocks; const int64_t m_start = mc * Mr; const int64_t m_end = std::min(std::min(mc + Mc_blocks, m_block_end) * Mr, M); const int64_t m_size = m_end - m_start; for (int64_t nc = n_block_start; nc < n_block_end; nc += Nc_blocks) { const int64_t n_start = nc * Nr; const int64_t n_end = std::min(std::min(nc + Nc_blocks, n_block_end) * Nr, N); const int64_t n_size = n_end - n_start; // NB: assume we pad N, nc_block_end won't exceed padded N here. const int64_t nc_block_end = std::min(nc + Nc_blocks, n_block_end); if (_local_acc_buf == nullptr) { _local_acc_buf = std::make_unique<float[]>(static_cast<int64_t>(Mc_blocks*Mr*Nc_blocks*Nr)); local_acc_buf = _local_acc_buf.get(); } for (int64_t kc = k_block_start; kc < k_block_end; kc += Kc_blocks) { int64_t k_start = kc * Kr; int64_t k_end = std::min(std::min(kc + Kc_blocks, k_block_end) * Kr, K); for (int64_t nci = nc; nci < nc_block_end; nci++) { if (kc == k_block_start) { cpp_bmm_micro_gemm<static_cast<bool>(false)>( amx_state, &(X[static_cast<int64_t>(k_start + (96L*m_start) + (36864L*ks_b_index))]), &(W[static_cast<int64_t>((32L*k_start) + (3072L*nci) + (6144L*ks_b_index))]), &(local_acc_buf[static_cast<int64_t>((Nr*nci) + ((-1L)*Nr*nc))]), static_cast<int64_t>(m_end + ((-1L)*m_start)), static_cast<int64_t>(Nr), static_cast<int64_t>(k_end + ((-1L)*k_start)), static_cast<int64_t>(96L), static_cast<int64_t>(32L), static_cast<int64_t>(Nc_blocks*Nr) ); } else { cpp_bmm_micro_gemm<static_cast<bool>(true)>( amx_state, &(X[static_cast<int64_t>(k_start + (96L*m_start) + (36864L*ks_b_index))]), &(W[static_cast<int64_t>((32L*k_start) + (3072L*nci) + (6144L*ks_b_index))]), &(local_acc_buf[static_cast<int64_t>((Nr*nci) + ((-1L)*Nr*nc))]), static_cast<int64_t>(m_end + ((-1L)*m_start)), static_cast<int64_t>(Nr), static_cast<int64_t>(k_end + ((-1L)*k_start)), static_cast<int64_t>(96L), static_cast<int64_t>(32L), static_cast<int64_t>(Nc_blocks*Nr) ); } } } { { #pragma GCC ivdep for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(m_end + ((-1L)*m_start)); x0+=static_cast<int64_t>(1L)) { for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(16L*(c10::div_floor_integer(static_cast<int64_t>((n_end + ((-1L)*n_start))), static_cast<int64_t>(16L)))); x1+=static_cast<int64_t>(16L)) { auto tmp0 = at::vec::Vectorized<float>::loadu(local_acc_buf + static_cast<int64_t>(x1 + (Nc_blocks*Nr*x0)), static_cast<int64_t>(16)); auto tmp1 = at::vec::convert<bfloat16>(tmp0); tmp1.store(Y + static_cast<int64_t>(n_start + x1 + (64L*m_start) + (64L*x0) + (24576L*ks_b_index)), static_cast<int64_t>(16)); } for(int64_t x1=static_cast<int64_t>(16L*(c10::div_floor_integer(static_cast<int64_t>((n_end + ((-1L)*n_start))), 
static_cast<int64_t>(16L)))); x1<static_cast<int64_t>(n_end + ((-1L)*n_start)); x1+=(static_cast<int64_t>(n_end + ((-1L)*n_start) + ((-16L)*(c10::div_floor_integer(static_cast<int64_t>((n_end + ((-1L)*n_start))), static_cast<int64_t>(16L))))) == 0 ? 1 : static_cast<int64_t>(n_end + ((-1L)*n_start) + ((-16L)*(c10::div_floor_integer(static_cast<int64_t>((n_end + ((-1L)*n_start))), static_cast<int64_t>(16L))))))) { auto tmp0 = at::vec::Vectorized<float>::loadu(local_acc_buf + static_cast<int64_t>(x1 + (Nc_blocks*Nr*x0)), static_cast<int64_t>(n_end + ((-1L)*n_start) + ((-16L)*(c10::div_floor_integer(static_cast<int64_t>((n_end + ((-1L)*n_start))), static_cast<int64_t>(16L)))))); auto tmp1 = at::vec::convert<bfloat16>(tmp0); tmp1.store(Y + static_cast<int64_t>(n_start + x1 + (64L*m_start) + (64L*x0) + (24576L*ks_b_index)), static_cast<int64_t>(n_end + ((-1L)*n_start) + ((-16L)*(c10::div_floor_integer(static_cast<int64_t>((n_end + ((-1L)*n_start))), static_cast<int64_t>(16L)))))); } } } } } } amx_state.release([]() { _tile_release(); }); } } extern "C" void cpp_bmm(const bfloat16* X, const bfloat16* W, bfloat16* Y) { const int64_t B = static_cast<int64_t>(5L); constexpr int64_t num_threads = 48; int64_t B_single_thread_block = (B / num_threads) * num_threads; #pragma omp parallel for num_threads(48) for (int64_t b_start = 0; b_start < B_single_thread_block; ++b_start) { single_thread_mm(X, W, Y, b_start); } for (int64_t b_start = B_single_thread_block; b_start < B; ++b_start) { threaded_mm(X, W, Y, b_start); } } ``` Pull Request resolved: #129772 Approved by: https://github.com/jgong5, https://github.com/leslie-fang-intel, https://github.com/jansel
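The batch-decomposition scheme described in the commit message above boils down to the dispatch below. This is a simplified sketch rather than the generated kernel (the full generated code appears in the example above); `single_threaded_gemm` and `multi_threaded_gemm` are hypothetical stand-ins for the two rendered variants of the GEMM template.

```
#include <cstdint>

// Hypothetical stand-ins for the two rendered GEMM template variants.
void single_threaded_gemm(int64_t b) { /* whole sub-GEMM b on one thread */ }
void multi_threaded_gemm(int64_t b)  { /* all threads cooperate on sub-GEMM b */ }

void bmm_dispatch(int64_t B, int64_t num_threads) {
  // First B - (B % num_threads) batches: one whole sub-GEMM per thread.
  const int64_t B_single_thread_block = (B / num_threads) * num_threads;
  #pragma omp parallel for
  for (int64_t b = 0; b < B_single_thread_block; ++b) {
    single_threaded_gemm(b);
  }
  // Remaining B % num_threads batches: fall back to the thread-parallel GEMM.
  for (int64_t b = B_single_thread_block; b < B; ++b) {
    multi_threaded_gemm(b);
  }
}
```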
🚀 The feature, motivation and pitch
Motivation
`torch.compile` provides the "max-autotune" mode. For CUDA, the Inductor backend leverages online benchmark results to select the best-performing kernels from various options, including ATen kernels and template-based kernels implemented with Triton and CUTLASS. These kernels primarily accelerate GEMM-related operations. However, for CPU, the "max-autotune" mechanism is not yet supported, and only ATen kernels are currently used. This RFC proposes similar template-based code generation support for GEMM-related operations on CPU, implemented in C++ and activated through the "max-autotune" mode of `torch.compile`. By leveraging Inductor's autotuning mechanism, users are expected to achieve better performance for GEMM-related operations than the ATen-based implementations provide.
Approaches
At a high level, the autotuning and template infrastructure from CUDA is mature enough to be adapted for CPU use. We plan to extend the existing autotuning code to support CPU and to develop the C++ template abstraction by referencing its CUTLASS counterpart. Additionally, CPU-specific challenges need to be addressed for optimal performance, such as thread decomposition, data layout arrangement (e.g., weight prepacking), and data blocking at the various levels of the memory hierarchy. Based on our previous experience, we employ a two-level abstraction to implement GEMMs: an outer loop that manages thread decomposition and cache blocking, and an inner micro-kernel that handles register blocking and other CPU architecture-specific optimizations. This approach allows flexible performance tuning at multiple levels and direct use of low-level CPU hardware acceleration.
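As a rough illustration of this two-level structure (a minimal sketch, not the actual template code), the example below uses placeholder cache/register block sizes and a plain scalar micro-kernel; the real template substitutes an ISA-specific micro-kernel (e.g., AVX-512 or AMX), prepacks the weight, and partitions the outer block ranges across threads.

```
#include <algorithm>
#include <cstdint>

// Inner level: micro-kernel computing an Mr x Nr register tile of C += A * B.
// Scalar here for clarity; the template would swap in AVX-512/AMX variants.
template <int Mr, int Nr>
void micro_gemm(const float* A, const float* B, float* C,
                int64_t K, int64_t lda, int64_t ldb, int64_t ldc) {
  float acc[Mr][Nr] = {};
  for (int64_t k = 0; k < K; ++k) {
    for (int m = 0; m < Mr; ++m) {
      for (int n = 0; n < Nr; ++n) {
        acc[m][n] += A[m * lda + k] * B[k * ldb + n];
      }
    }
  }
  for (int m = 0; m < Mr; ++m) {
    for (int n = 0; n < Nr; ++n) {
      C[m * ldc + n] += acc[m][n];
    }
  }
}

// Outer level: cache blocking over M/N/K; the generated code additionally
// assigns disjoint block ranges to OpenMP threads.
// Assumes row-major A (MxK), B (KxN), C (MxN) with C zero-initialized,
// and M, N divisible by Mr, Nr for brevity (the template handles tails).
void gemm_two_level(const float* A, const float* B, float* C,
                    int64_t M, int64_t N, int64_t K) {
  constexpr int Mr = 4, Nr = 16;                 // register blocking
  constexpr int64_t Mc = 64, Nc = 64, Kc = 256;  // cache blocking
  #pragma omp parallel for collapse(2)
  for (int64_t mc = 0; mc < M; mc += Mc) {
    for (int64_t nc = 0; nc < N; nc += Nc) {
      for (int64_t kc = 0; kc < K; kc += Kc) {
        const int64_t kb = std::min(Kc, K - kc);
        for (int64_t m = mc; m < std::min(mc + Mc, M); m += Mr) {
          for (int64_t n = nc; n < std::min(nc + Nc, N); n += Nr) {
            micro_gemm<Mr, Nr>(A + m * K + kc, B + kc * N + n, C + m * N + n,
                               kb, /*lda=*/K, /*ldb=*/N, /*ldc=*/N);
          }
        }
      }
    }
  }
}
```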
Key Components
Task Breakdowns
Alternatives
No response
Additional context
No response
cc @ezyang @msaroufim @bdhirsh @anijain2305 @chauhang @voznesenskym @penguinwu @EikanWang @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire