Update base for Update on "[inductor][cpp] GEMM template (infra and fp32)" · pytorch/pytorch@41057e8 · GitHub

Commit 41057e8

Jiong Gong committed
Update base for Update on "[inductor][cpp] GEMM template (infra and fp32)"
This PR adds the Cpp template infrastructure and the initial FP32 GEMM template. See RFC #125683 for more background info.

1. Cpp template infrastructure
Similar template abstractions as the CUTLASS template, i.e., `CppTemplate`, `CppTemplateKernel`, `CppTemplateBuffer`, plus the `MicroGemm` micro-kernel abstraction that can be used by Cpp GEMM templates.

2. Initial FP32 GEMM template
This involves a GEMM template implementation, `CppPackedGemmTemplate`, that supports GEMM with a constant weight (`B`), requiring `N` to be a multiple of the register blocking while allowing static or dynamic sizes for `M` (the batch dim) of `A`. The `B` matrix is prepacked. This is a typical setting for inference workloads. The template handles the thread decomposition (via `thread_blocking`) and cache blocking (via `cache_blocking`), then invokes `CppMicroGemm`, which handles register blocking, instruction selection, and other CPU architecture-specific optimizations. A `CppMicroGemmFP32Vec` micro-kernel implementation is provided for fp32 matmuls, implemented with the ATen vec abstraction.

3. Correctness and performance
The changes have been validated with fp32 inference on the three benchmark suites (torchbench, huggingface and timm_models) with both static and dynamic shapes. Since this is an initial implementation, we are still working on further performance improvements in follow-up PRs, including optimizations in the kernels as well as fusions. Perf gains are observed on only a selective number of models compared to the ATen kernels, which are implemented with MKL. The gains are more obvious with dynamic shapes, since MKL only supports packed GEMM for static shapes. Details are below.

Static shapes

| Benchmark | torchbench | huggingface | timm_models |
|------------|-------------|--------------|--------------|
| Multi-threaded (baseline) | 1.47x | 1.36x | 1.91x |
| Multi-threaded (max-autotune) | 1.47x | 1.36x | 1.92x |
| Single-threaded (baseline) | 1.56x | 1.19x | 1.51x |
| Single-threaded (max-autotune) | 1.56x | 1.19x | 1.52x |

Key models being sped up: drq: 1.14x, soft_act: 1.12x, cait_m36_384: 1.18x

Dynamic shapes

| Benchmark | torchbench | huggingface | timm_models |
| --- | --- | --- | --- |
| Multi-threaded (baseline) | 1.43x | 1.28x | 1.85x |
| Multi-threaded (max-autotune) | 1.47x | 1.28x | 1.85x |
| Single-threaded (baseline) | 1.55x | 1.20x | 1.51x |
| Single-threaded (max-autotune) | 1.56x | 1.19x | 1.53x |

Key models being sped up: BERT_pytorch: 1.22x, pyhpc_turbulent: 1.13x, soft_actor_critic: 1.77x, BlenderbotForCausalLM: 1.09x, cait_m36_384: 1.17x

cc voznesenskym penguinwu EikanWang Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler amjames desertfire chauhang

[ghstack-poisoned]
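As a usage illustration (not code from this commit): the template is exercised through `torch.compile` with autotuning, which benchmarks candidate GEMM implementations and picks the fastest. A minimal sketch, assuming a CPU build; the module and sizes below are illustrative, not taken from the PR.

```python
import torch

class TinyMLP(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # The Linear weight plays the role of the constant `B` matrix that the
        # template prepacks; N=512 (out_features) is chosen here as a value
        # likely to be a multiple of the register blocking.
        self.fc = torch.nn.Linear(256, 512, bias=False)

    def forward(self, x):
        return self.fc(x)

model = TinyMLP().eval()
# "max-autotune" lets inductor benchmark GEMM candidates (ATen/MKL kernels
# and, with this PR, the C++ packed GEMM template) and keep the winner.
compiled = torch.compile(model, mode="max-autotune")
with torch.no_grad():
    # M=32 is the batch dim of `A`; per the description it may also be dynamic.
    y = compiled(torch.randn(32, 256))
```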
2 parents 8035bb4 + aaa2f93 commit 41057e8

File tree

161 files changed: +5640 -2527 lines changed

.github/ci_commit_pins/vision.txt

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-06ad737628abc3a1e617571dc03cbdd5b36ea96a
+d23a6e1664d20707c11781299611436e1f0c104f

.github/scripts/amd/patch_triton_wheel.sh

Lines changed: 5 additions & 1 deletion
@@ -1,7 +1,11 @@
 #!/bin/bash
 set -x

-WHEELHOUSE_DIR=/artifacts
+if [ -z "$1" ]; then
+    echo "Need wheel location argument" && exit 1
+fi
+
+WHEELHOUSE_DIR=$1
 PATCHELF_BIN=patchelf
 ROCM_LIB=backends/amd/lib
 ROCM_LD=backends/amd/llvm/bin

.github/scripts/build_triton_wheel.py

Lines changed: 2 additions & 2 deletions
@@ -157,10 +157,10 @@ def build_triton(

     if build_rocm:
         check_call(
-            [f"{SCRIPT_DIR}/amd/patch_triton_wheel.sh"],
+            [f"{SCRIPT_DIR}/amd/patch_triton_wheel.sh", Path.cwd()],
             cwd=triton_basedir,
-            shell=True,
         )
+
     return Path.cwd() / whl_path.name
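A note on why dropping `shell=True` goes hand in hand with passing the new argument: in `subprocess`, combining `shell=True` with a list runs only the first element (via the shell), and the remaining elements become arguments to the shell itself, so the script would never receive `Path.cwd()`. A standalone illustration of this general Python behavior (POSIX semantics, not code from the commit):

```python
import subprocess

# With shell=True and a list, "ignored" is handed to /bin/sh itself
# (it becomes $0), not to echo -- this prints an empty line on POSIX.
subprocess.check_call(["echo", "ignored"], shell=True)

# Without shell=True, each element is its own argv entry -- prints "kept".
subprocess.check_call(["echo", "kept"])
```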
.gitignore

Lines changed: 1 addition & 1 deletion
@@ -87,7 +87,7 @@ torch/csrc/api/include/torch/version.h
 torch/csrc/cudnn/cuDNN.cpp
 torch/csrc/generated
 torch/csrc/generic/TensorMethods.cpp
-torch/csrc/inductor/aoti_torch/generated/*
+torch/csrc/inductor/aoti_torch/generated/*.cpp
 torch/csrc/jit/generated/*
 torch/csrc/jit/fuser/config.h
 torch/csrc/nn/THCUNN.cpp

.lintrunner.toml

Lines changed: 1 addition & 0 deletions
@@ -78,6 +78,7 @@ exclude_patterns = [
     'aten/src/ATen/native/vulkan/api/vk_mem_alloc.h',
     'c10/util/strong_type.h',
     '**/fb/**',
+    'torch/csrc/inductor/aoti_torch/generated/**',
     'torch/csrc/jit/serialization/mobile_bytecode_generated.h',
     'torch/csrc/utils/pythoncapi_compat.h',
     'aten/src/ATen/dlpack.h',

BUILD.bazel

Lines changed: 4 additions & 731 deletions
Large diffs are not rendered by default.

CMakeLists.txt

Lines changed: 0 additions & 1 deletion
@@ -232,7 +232,6 @@ option(USE_GFLAGS "Use GFLAGS" OFF)
 option(USE_GLOG "Use GLOG" OFF)
 option(USE_LITE_PROTO "Use lite protobuf instead of full." OFF)
 option(USE_MAGMA "Use MAGMA" ON)
-option(USE_METAL "Use Metal for Caffe2 iOS build" ON)
 option(USE_PYTORCH_METAL "Use Metal for PyTorch iOS build" OFF)
 option(USE_PYTORCH_METAL_EXPORT "Export Metal models on MacOSX desktop" OFF)
 option(USE_NATIVE_ARCH "Use -march=native" OFF)

aten/src/ATen/core/interned_strings.h

Lines changed: 1 addition & 0 deletions
@@ -227,6 +227,7 @@ namespace c10 {
   _(aten, is_autocast_enabled) \
   _(aten, is_autocast_cpu_enabled) \
   _(aten, is_autocast_xla_enabled) \
+  _(aten, get_autocast_dtype) \
   FORALL_ATEN_BASE_SYMBOLS(_) \
   _(onnx, Add) \
   _(onnx, Concat) \
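For context, the new interned symbol sits alongside the `is_autocast_*_enabled` family. A hedged sketch of the Python-level query it corresponds to; the binding name `torch.get_autocast_dtype` is an assumption based on that neighboring family, not established by this diff:

```python
import torch

# Hypothetical usage: query the dtype autocast would cast to for a device.
# Only the interned symbol appears in this diff; the binding name is assumed.
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    print(torch.get_autocast_dtype("cpu"))  # expected: torch.bfloat16
```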

aten/src/ATen/cuda/CUDABlas.cpp

Lines changed: 4 additions & 4 deletions
@@ -236,7 +236,7 @@ namespace at::cuda::blas {
   CUDABLAS_NONNEGINT_CHECK(bgemm<Dtype>, num_batches); \
 } while (0)

-#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
+#if !defined(USE_ROCM) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)

 #if defined(USE_ROCM) && ROCM_VERSION >= 50700 && ROCM_VERSION < 60000
 // only for rocm 5.7 where we first supported hipblaslt, it was difficult
@@ -375,7 +375,7 @@ class CuBlasLtMatmulPreference : public CuBlasLtDescriptor<

 template <typename Dtype>
 inline void bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES(Dtype)) {
-#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 60000)
+#if !defined(USE_ROCM) || (defined(USE_ROCM) && ROCM_VERSION >= 60000)
   cudaDataType_t abcType = CUDA_R_32F;
   cublasComputeType_t computeType = CUBLAS_COMPUTE_32F;
   cudaDataType_t scaleType = CUDA_R_32F;
@@ -1235,7 +1235,7 @@ void gemm<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16)) {
   }
 }

-#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
+#if !defined(USE_ROCM) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)

 template <typename Dtype>
 void gemm_and_bias(
@@ -1745,7 +1745,7 @@ void int8_gemm(
   TORCH_CHECK(false, "int8_gemm is only supported for ROCm 6.0 and above");
 #endif // !defined(USE_ROCM) || (defined(USE_ROCM) && ROCM_VERSION >= 60000)
 }
-#endif // (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
+#endif // !defined(USE_ROCM) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)

 // ROCm 5.6 hipblas matches the const Dtype *A API, but prior hipblas does not.
 #if defined(USE_ROCM) && ROCM_VERSION < 50600

aten/src/ATen/cuda/CUDABlas.h

Lines changed: 1 addition & 1 deletion
@@ -82,7 +82,7 @@ void gemm_internal<at::Half>(CUDABLAS_GEMM_ARGTYPES(at::Half));
 template <>
 void gemm_internal<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16));

-#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
+#if !defined(USE_ROCM) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
 enum GEMMAndBiasActivationEpilogue {
   None,
   RELU,

0 commit comments