Update on "[inductor][cpp] GEMM template (infra and fp32)" · pytorch/pytorch@92f4ac4 · GitHub

Commit 92f4ac4

Author: Jiong Gong (committed)
Update on "[inductor][cpp] GEMM template (infra and fp32)"
This PR adds the Cpp template infrastructure and the initial FP32 GEMM template. See RFC #125683 for more background info.

1. Cpp template infrastructure
   Similar template abstractions as the CUTLASS template, i.e., `CppTemplate`, `CppTemplateKernel`, `CppTemplateBuffer`, plus the `MicroGemm` micro-kernel abstraction that can be used by Cpp GEMM templates.

2. Initial FP32 GEMM template
   This involves a GEMM template implementation, `CppPackedGemmTemplate`, that supports GEMM with a constant weight (`B`), requiring `N` to be a multiple of the register blocking while allowing static or dynamic sizes for the `M` (batch dim) of `A`. The `B` matrix is prepacked. This is a typical setting for inference workloads. The template handles thread decomposition (via `thread_blocking`) and cache blocking (via `cache_blocking`), then invokes `CppMicroGemm`, which handles register blocking, instruction selection, and other CPU architecture-specific optimizations. A `CppMicroGemmFP32Vec` micro-kernel implementation is provided for fp32 matmuls, implemented with the ATen vec abstraction.

3. Correctness and performance
   The changes have been validated with fp32 inference on the three benchmark suites (torchbench, huggingface and timm_models) with both static and dynamic shapes. Since this is an initial implementation, we are still working on further performance improvements with follow-up PRs, including optimizations in the kernels as well as fusions. The perf gains are only observed on a selective number of models compared to the ATen kernels, which are implemented with MKL. The gains are more obvious with dynamic shapes since MKL only supports packed GEMM for static shapes. Details are below.

Static shapes

| Benchmark | torchbench | huggingface | timm_models |
|------------|-------------|--------------|--------------|
| Multi-threaded (baseline) | 1.47x | 1.36x | 1.91x |
| Multi-threaded (max-autotune) | 1.47x | 1.36x | 1.92x |
| Single-threaded (baseline) | 1.56x | 1.19x | 1.51x |
| Single-threaded (max-autotune) | 1.56x | 1.19x | 1.52x |

Key models being sped up:
drq: 1.14x
soft_act: 1.12x
cait_m36_384: 1.18x

Dynamic shapes

| Benchmark | torchbench | huggingface | timm_models |
| --- | --- | --- | --- |
| Multi-threaded (baseline) | 1.43x | 1.28x | 1.85x |
| Multi-threaded (max-autotune) | 1.47x | 1.28x | 1.85x |
| Single-threaded (baseline) | 1.55x | 1.20x | 1.51x |
| Single-threaded (max-autotune) | 1.56x | 1.19x | 1.53x |

Key models being sped up:
BERT_pytorch: 1.22x
pyhpc_turbulent: 1.13x
soft_actor_critic: 1.77x
BlenderbotForCausalLM: 1.09x
cait_m36_384: 1.17x

cc voznesenskym penguinwu EikanWang Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler amjames desertfire chauhang

[ghstack-poisoned]
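As a rough usage sketch (not part of this commit), the snippet below shows how one might exercise the new Cpp GEMM template: compile a `Linear` with a constant weight under max-autotune so the autotuner can compare the ATen (MKL) kernel against `CppPackedGemmTemplate`. The `max_autotune_gemm_backends` value `"CPP,ATEN"` is an assumption about the backend-selection knob, not something spelled out in this commit message.

```python
import torch
import torch._inductor.config as inductor_config

# Assumed knobs: the "CPP" entry for max_autotune_gemm_backends is an assumption
# tied to this template work, not documented in this commit.
inductor_config.max_autotune = True
inductor_config.max_autotune_gemm_backends = "CPP,ATEN"

# Constant weight (B) that the template can prepack; M (batch dim of A) may be
# static or dynamic, while N must be a multiple of the register blocking.
mod = torch.nn.Linear(in_features=256, out_features=512, bias=False).eval()
x = torch.randn(64, 256)  # fp32 inference input

with torch.no_grad():
    compiled = torch.compile(mod, mode="max-autotune")
    y = compiled(x)  # autotuning benchmarks ATen vs. the Cpp GEMM template choice
```
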
2 parents 0162cf6 + b4f772b commit 92f4ac4

File tree

12 files changed: +116 −27 lines


cmake/Dependencies.cmake

Lines changed: 16 additions & 0 deletions
@@ -1841,6 +1841,8 @@ if(USE_KINETO)
       set(CUPTI_LIB_NAME "cupti.lib")
     endif()
 
+    set(NVPERF_HOST_LIB_NAME "libnvperf_host.so")
+
     find_library(CUPTI_LIBRARY_PATH ${CUPTI_LIB_NAME} PATHS
       ${CUDA_SOURCE_DIR}
       ${CUDA_SOURCE_DIR}/extras/CUPTI/lib64
@@ -1855,13 +1857,27 @@ if(USE_KINETO)
       ${CUDA_SOURCE_DIR}/include
       NO_DEFAULT_PATH)
 
+    find_library(NVPERF_HOST_LIBRARY_PATH ${NVPERF_HOST_LIB_NAME} PATHS
+      ${CUDA_SOURCE_DIR}
+      ${CUDA_SOURCE_DIR}/lib
+      ${CUDA_SOURCE_DIR}/lib64
+      ${CUDA_SOURCE_DIR}/extras/CUPTI/lib64
+      NO_DEFAULT_PATH)
+
     if(CUPTI_LIBRARY_PATH AND CUPTI_INCLUDE_DIR)
       message(STATUS " CUPTI_INCLUDE_DIR = ${CUPTI_INCLUDE_DIR}")
       set(CUDA_cupti_LIBRARY ${CUPTI_LIBRARY_PATH})
       message(STATUS " CUDA_cupti_LIBRARY = ${CUDA_cupti_LIBRARY}")
+      # CUPTI Range Profiler requires the NVPerf library
+      # for configuring metrics
+      if(NVPERF_HOST_LIBRARY_PATH)
+        set(CUDA_nvperf_host_LIBRARY ${NVPERF_HOST_LIBRARY_PATH})
+        message(STATUS " CUDA_nvperf_host_LIBRARY = ${NVPERF_HOST_LIBRARY_PATH}")
+      endif()
       message(STATUS "Found CUPTI")
       set(LIBKINETO_NOCUPTI OFF CACHE STRING "" FORCE)
 
+
       # I've only tested this sanity check on Linux; if someone
       # runs into this bug on another platform feel free to
       # generalize it accordingly

test/profiler/test_profiler.py

Lines changed: 42 additions & 0 deletions
@@ -699,6 +699,48 @@ def create_mkldnn_tensor():
         if torch.cuda.is_available():
             check_metrics(stats, "device_memory_usage", deallocs=["[memory]"])
 
+    @unittest.skipIf(not kineto_available(), "Kineto is required")
+    @unittest.skipIf(not torch.cuda.is_available(), "CUDA is required")
+    def test_kineto_cupti_range_profiler(self):
+        """CUPTI provides a newer Profiling API from CUDA 10.0 that enables measuring
+        performance events for the GPU. This is supported as an experimental pytorch profiler feature.
+        Read more here https://docs.nvidia.com/cupti/r_main.html#r_profiler.
+        """
+        exp_config = _ExperimentalConfig(
+            profiler_metrics=[
+                # Metrics list at https://docs.nvidia.com/cupti/r_main.html#r_profiler
+                # or use kineto__tensor_core_insts, kineto__cuda_core_flops
+                "kineto__tensor_core_insts",
+                "dram__bytes_read.sum",
+                "dram__bytes_write.sum",
+            ],
+            profiler_measure_per_kernel=True,
+        )
+        with _profile(
+            use_cuda=True, use_kineto=True, experimental_config=exp_config
+        ) as p:
+            self.payload(use_cuda=True)
+
+        def check_trace(fname):
+            with open(fname) as f:
+                trace = json.load(f)
+                self.assertTrue("traceEvents" in trace)
+                events = trace["traceEvents"]
+                found_cupti_profiler_events = False
+                for evt in events:
+                    self.assertTrue("name" in evt)
+                    if "__cupti_profiler__" in evt["name"]:
+                        found_cupti_profiler_events = True
+                # PyTorch OSS CI runs in docker containers where the Range Profiler
+                # does not have sufficient privilege level (CUPTI_ERROR_INSUFFICIENT_PRIVILEGES).
+                # We can check that the profiler does not crash the job and the trace is not
+                # malformed, however do not check the actual presence of data.
+                self.assertTrue(1 or found_cupti_profiler_events)
+
+        with TemporaryFileName(mode="w+") as fname:
+            p.export_chrome_trace(fname)
+            check_trace(fname)
+
     @unittest.skipIf(
         IS_JETSON, "Jetson has a guard against OOM since host and gpu memory are shared"
     )

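For reference, the same experimental CUPTI Range Profiler configuration used in the test above can also be passed through the public `torch.profiler` API. A minimal sketch, assuming a CUDA build with Kineto and sufficient CUPTI profiling privileges; the metric names come from the test, and the output filename is arbitrary:

```python
import torch
from torch._C._profiler import _ExperimentalConfig
from torch.profiler import profile, ProfilerActivity

exp_config = _ExperimentalConfig(
    profiler_metrics=["kineto__tensor_core_insts", "dram__bytes_read.sum"],
    profiler_measure_per_kernel=True,
)

x = torch.randn(1024, 1024, device="cuda")
with profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    experimental_config=exp_config,
) as prof:
    torch.mm(x, x)

# Events whose names contain "__cupti_profiler__" carry the collected metrics.
prof.export_chrome_trace("range_profiler_trace.json")
```
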
torch/_inductor/autotune_process.py

Lines changed: 2 additions & 2 deletions
@@ -44,7 +44,7 @@
     from torch._inductor.select_algorithm import TritonTemplateCaller
 
 from . import config
-from .runtime.runtime_utils import do_bench, do_bench_cpu
+from .runtime.runtime_utils import do_bench_cpu, do_bench_gpu
 from .virtualized import V
 
 CUDA_VISIBLE_DEVICES = "CUDA_VISIBLE_DEVICES"
@@ -592,7 +592,7 @@ def do_bench(
         device_idx = torch.cuda.current_device()
 
         with torch.cuda.device(device_idx):
-            out = do_bench(fn)
+            out = do_bench_gpu(fn)
             torch.cuda.synchronize()  # shake out any CUDA errors
 
         return out

torch/_inductor/codegen/multi_kernel.py

Lines changed: 2 additions & 2 deletions
@@ -6,7 +6,7 @@
 
 from .. import config
 from ..codecache import PyCodeCache, TritonFuture
-from ..runtime.runtime_utils import do_bench
+from ..runtime.runtime_utils import do_bench_gpu
 from ..utils import cache_on_self
 from ..virtualized import V
 from .common import TensorArg
@@ -339,7 +339,7 @@ def benchmark_sub_kernels(kernel_calls):
         be picked.
         """
         return [
-            do_bench(lambda: kernel_call(True), rep=40, fast_flush=True)
+            do_bench_gpu(lambda: kernel_call(True), rep=40, fast_flush=True)
             for kernel_call in kernel_calls
         ]

torch/_inductor/codegen/triton.py

Lines changed: 4 additions & 4 deletions
@@ -49,7 +49,7 @@
 from ..optimize_indexing import indexing_dtype_strength_reduction
 from ..runtime.hints import ReductionHint, TRITON_MAX_BLOCK
 from ..runtime.runtime_utils import (
-    do_bench,
+    do_bench_gpu,
     get_max_y_grid,
     green_text,
     next_power_of_2,
@@ -2651,7 +2651,7 @@ def codegen_kernel_benchmark(self, num_gb, grid=None):
 
         result.writeline("args = get_args()")
         result.writeline(
-            "ms = do_bench(lambda: call(args), rep=40, fast_flush=True)"
+            "ms = do_bench_gpu(lambda: call(args), rep=40, fast_flush=True)"
         )
         result.writeline(f"num_gb = {num_gb}")
         result.writeline("gb_per_s = num_gb / (ms / 1e3)")
@@ -4034,13 +4034,13 @@ def store_cache():
     else:
         # We have to clone the inplace updated arguments to avoid earlier calls
         # generating out of range indices for later calls.
-        ms = do_bench(lambda: call(wrapped_jit_function.clone_args(*args)[0]))
+        ms = do_bench_gpu(lambda: call(wrapped_jit_function.clone_args(*args)[0]))
 
         # overhead of cloning args gives bias for fusing the kernel
         # in the case of mutating/in-placeable second fusion
         # TODO - would be better as a hook in triton do_bench that reset
        # the input values between benchmarking
-        ms = ms - do_bench(lambda: wrapped_jit_function.clone_args(*args))
+        ms = ms - do_bench_gpu(lambda: wrapped_jit_function.clone_args(*args))
 
     log.debug(
         "The fused kernel for %s took %.3f ms to run",

torch/_inductor/fx_passes/pad_mm.py

Lines changed: 1 addition & 1 deletion
@@ -251,7 +251,7 @@ def should_pad_bench(
             return False
 
     do_bench = functools.partial(
-        torch._inductor.runtime.runtime_utils.do_bench,
+        torch._inductor.runtime.runtime_utils.do_bench_gpu,
         warmup=5,
     )

torch/_inductor/ir.py

Lines changed: 2 additions & 6 deletions
@@ -70,7 +70,7 @@
 )
 from .ops_handler import OpCounterCSE
 from .runtime.hints import ReductionHint
-from .runtime.runtime_utils import do_bench, do_bench_cpu
+from .runtime.runtime_utils import do_bench
 from .utils import (
     argsort,
     cache_on_self,
@@ -79,7 +79,6 @@
     convert_shape_to_symint,
     developer_warning,
     get_kernel_metadata,
-    is_cpu_device,
     is_dynamic,
     is_gpu,
     pad_listlike,
@@ -3628,10 +3627,7 @@ def __init__(self, name, input_nodes, layout):
 
     def benchmark(self, *args, out) -> float:
         algo = self.to_callable()
-        if is_cpu_device(args):
-            return do_bench_cpu(lambda: algo(*args, out=out))
-        else:
-            return do_bench(lambda: algo(*args, out=out))
+        return do_bench(algo, args, {"out": out})
 
     def call_name(self) -> str:
         raise NotImplementedError

torch/_inductor/runtime/runtime_utils.py

Lines changed: 12 additions & 2 deletions
@@ -10,6 +10,7 @@
 import time
 
 import torch
+from torch._inductor.utils import is_cpu_device
 
 
 def conditional_product(*args):
@@ -70,7 +71,16 @@ def get_max_y_grid():
     return 65535
 
 
-def do_bench(*args, **kwargs):
+def do_bench(fn, fn_args, fn_kwargs, **kwargs):
+    args = list(fn_args)
+    args.extend(fn_kwargs.values())
+    if is_cpu_device(args):
+        return do_bench_cpu(lambda: fn(*fn_args, **fn_kwargs), **kwargs)
+    else:
+        return do_bench_gpu(lambda: fn(*fn_args, **fn_kwargs), **kwargs)
+
+
+def do_bench_gpu(fn, **kwargs):
     @functools.lru_cache(None)
     def load_triton():
         try:
@@ -98,7 +108,7 @@ def load_triton():
 
     if quantile_field_name not in kwargs:
         kwargs[quantile_field_name] = (0.5, 0.2, 0.8)
-    return triton_do_bench(*args, **kwargs)[0]
+    return triton_do_bench(fn, **kwargs)[0]
 
 
 def do_bench_cpu(fn, warmup=5, times=20):

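With the change above, `do_bench` becomes a device-dispatching wrapper: callers hand it the function plus its positional and keyword arguments so it can inspect the tensors and route to `do_bench_cpu` or the Triton-backed `do_bench_gpu` (matching the updated call sites in `ir.py` and `select_algorithm.py`). A minimal sketch of the two call styles, assuming a build that includes this commit:

```python
import torch
from torch._inductor.runtime.runtime_utils import do_bench, do_bench_gpu

def algo(a, b, *, out):
    return torch.mm(a, b, out=out)

a, b = torch.randn(128, 128), torch.randn(128, 128)
out = torch.empty(128, 128)

# Dispatching wrapper: routes to do_bench_cpu here because all tensors are on CPU.
ms_cpu = do_bench(algo, (a, b), {"out": out})

# Explicit GPU path (requires CUDA + Triton): benchmark a zero-arg callable.
if torch.cuda.is_available():
    a_gpu, b_gpu = a.cuda(), b.cuda()
    ms_gpu = do_bench_gpu(lambda: torch.mm(a_gpu, b_gpu), rep=40, fast_flush=True)
```
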
torch/_inductor/runtime/triton_heuristics.py

Lines changed: 2 additions & 2 deletions
@@ -32,7 +32,7 @@
     ceildiv,
     conditional_product,
     create_bandwidth_info_str,
-    do_bench,
+    do_bench_gpu,
     dynamo_timed,
     get_first_attr,
     get_max_y_grid,
@@ -628,7 +628,7 @@ def kernel_call():
                 stream=stream,
             )
 
         return do_bench_gpu(kernel_call, rep=40, fast_flush=True)
 
     def clone_args(self, *args, **kwargs) -> Tuple[List[Any], Dict[str, Any]]:
         from ..compile_fx import clone_preserve_strides

torch/_inductor/select_algorithm.py

Lines changed: 2 additions & 6 deletions
@@ -38,10 +38,9 @@
 from .exc import CUDACompileError
 from .ir import ChoiceCaller, PrimitiveInfoType
 from .runtime.hints import DeviceProperties
-from .runtime.runtime_utils import do_bench, do_bench_cpu
+from .runtime.runtime_utils import do_bench
 from .utils import (
     get_dtype_size,
-    is_cpu_device,
     Placeholder,
     restore_stdout_stderr,
     sympy_dot,
@@ -847,10 +846,7 @@ def benchmark(self, *args, out):
             out_new, tuple(out.size()), tuple(out.stride())
         )
         out.copy_(out_new)  # for correctness checking
-        if is_cpu_device(args):
-            return do_bench_cpu(lambda: algo(*args))
-        else:
-            return do_bench(lambda: algo(*args))
+        return do_bench(algo, args, {})
 
     def to_callable(self):
         fn = self.choice.to_callable()
