Update base for Update on "[inductor][cpp] GEMM template (infra and fp32)" · pytorch/pytorch@8035bb4 · GitHub

Commit 8035bb4

Author: Jiong Gong (committed)
Update base for Update on "[inductor][cpp] GEMM template (infra and fp32)"
This PR adds the Cpp template infrastructure and the initial FP32 GEMM template. See RFC #125683 for more background info.

1. Cpp template infrastructure
   Similar template abstractions to the CUTLASS template, i.e., `CppTemplate`, `CppTemplateKernel`, and `CppTemplateBuffer`, plus the `CppMicroGemm` micro-kernel abstraction that Cpp GEMM templates can use.

2. Initial FP32 GEMM template
   This adds a GEMM template implementation, `CppPackedGemmTemplate`, that supports GEMM with a constant weight (`B`), requiring `N` to be a multiple of the register blocking size while allowing static or dynamic sizes for `M` (the batch dim) of `A`. The `B` matrix is prepacked. This is a typical setting for inference workloads. The template handles the thread decomposition (via `thread_blocking`) and cache blocking (via `cache_blocking`), then invokes `CppMicroGemm`, which handles register blocking, instruction selection, and other CPU architecture-specific optimizations. A `CppMicroGemmFP32Vec` micro-kernel implementation is provided for fp32 matmuls, implemented with the ATen Vec abstraction.

3. Correctness and performance
   The changes have been validated with fp32 inference on the three benchmark suites (torchbench, huggingface and timm_models) with both static and dynamic shapes. Since this is an initial implementation, we are still working on further performance improvements in follow-up PRs, including kernel optimizations as well as fusions. Perf gains over the ATen kernels (implemented with MKL) are only observed on a selective set of models, and are more obvious with dynamic shapes since MKL only supports packed GEMM for static shapes. Details below.

Static shapes

| Benchmark | torchbench | huggingface | timm_models |
| --- | --- | --- | --- |
| Multi-threaded (baseline) | 1.47x | 1.36x | 1.91x |
| Multi-threaded (max-autotune) | 1.47x | 1.36x | 1.92x |
| Single-threaded (baseline) | 1.56x | 1.19x | 1.51x |
| Single-threaded (max-autotune) | 1.56x | 1.19x | 1.52x |

Key models being sped up: drq: 1.14x, soft_act: 1.12x, cait_m36_384: 1.18x

Dynamic shapes

| Benchmark | torchbench | huggingface | timm_models |
| --- | --- | --- | --- |
| Multi-threaded (baseline) | 1.43x | 1.28x | 1.85x |
| Multi-threaded (max-autotune) | 1.47x | 1.28x | 1.85x |
| Single-threaded (baseline) | 1.55x | 1.20x | 1.51x |
| Single-threaded (max-autotune) | 1.56x | 1.19x | 1.53x |

Key models being sped up: BERT_pytorch: 1.22x, pyhpc_turbulent: 1.13x, soft_actor_critic: 1.77x, BlenderbotForCausalLM: 1.09x, cait_m36_384: 1.17x

cc voznesenskym penguinwu EikanWang Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler amjames desertfire chauhang

[ghstack-poisoned]
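To make the division of labor described in the commit message concrete, here is a minimal standalone sketch of the two-level decomposition. The tile sizes and the round-robin thread partitioning are illustrative placeholders, not the heuristics used by `CppPackedGemmTemplate`; in the template itself the inner tile work is done by the generated `CppMicroGemm` kernel rather than Python loops.

```python
# Illustrative sketch only: made-up tile sizes and a naive round-robin split,
# standing in for thread_blocking/cache_blocking + CppMicroGemm in the template.
from typing import List, Tuple

def cache_blocking(M: int, N: int, Mc: int = 64, Nc: int = 96) -> List[Tuple[int, int]]:
    # Enumerate (m0, n0) tile origins small enough to keep A/B/C slices cache-resident.
    return [(m0, n0) for m0 in range(0, M, Mc) for n0 in range(0, N, Nc)]

def thread_blocking(tiles: List[Tuple[int, int]], num_threads: int) -> List[List[Tuple[int, int]]]:
    # Partition the output tiles across threads; each thread owns disjoint C tiles.
    return [tiles[t::num_threads] for t in range(num_threads)]

def micro_gemm(m0: int, n0: int, K: int, Kc: int = 256) -> None:
    # Placeholder for the register-blocked inner kernel (CppMicroGemm's role):
    # loop over K in Kc chunks and accumulate into the C tile held in registers.
    for k0 in range(0, K, Kc):
        pass  # vectorized FMAs over the (Mc x Kc) * (Kc x Nc) slice would go here

if __name__ == "__main__":
    M, N, K, num_threads = 512, 768, 1024, 8
    tiles = cache_blocking(M, N)
    for per_thread in thread_blocking(tiles, num_threads):  # one entry per thread
        for m0, n0 in per_thread:
            micro_gemm(m0, n0, K)
```

On the usage side, the template is exercised through Inductor's max-autotune path (e.g. `torch.compile(model, mode="max-autotune")`), which is the path the max-autotune rows in the tables above refer to.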
2 parents b4f772b + 5007312 commit 8035bb4

File tree: 4 files changed, +42 -7 lines changed


test/distributed/tensor/parallel/test_tp_style.py

Lines changed: 23 additions & 0 deletions
@@ -168,6 +168,29 @@ def test_rowwise_parallel_embedding(self):
             # no comm in bwd
             self.assertEqual(comm_mode.get_total_counts(), 1)
 
+        sharded_row_parallel = RowwiseParallel(
+            input_layouts=Replicate(), output_layouts=Shard(1)
+        )
+
+        rowwise_mod = parallelize_module(deepcopy(model), mesh, sharded_row_parallel)
+
+        inp_indices = torch.arange(8, device=self.device_type)
+        with comm_mode:
+            out = rowwise_mod(inp_indices)
+            # ensure output shard on the last dim
+            self.assertEqual(out.shape, (8, 16 // self.world_size))
+            # reduce scatter in fwd
+            self.assertEqual(comm_mode.get_total_counts(), 1)
+            self.assertEqual(
+                comm_mode.get_comm_counts()[c10d_functional.reduce_scatter_tensor], 1
+            )
+            out.sum().backward()
+            # allgather comm in bwd
+            self.assertEqual(comm_mode.get_total_counts(), 2)
+            self.assertEqual(
+                comm_mode.get_comm_counts()[c10d_functional.all_gather_into_tensor], 1
+            )
+
     @with_comms
     def test_prepare_module_input(self):
         mesh = init_device_mesh(self.device_type, (self.world_size,))
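For context, the new assertions exercise the rowwise-sharded embedding path end to end: with `input_layouts=Replicate()` and `output_layouts=Shard(1)`, the forward pass reduce-scatters the partial embedding lookups and the backward pass all-gathers the gradient. Below is a minimal standalone sketch of the same plan outside the test harness; it assumes a distributed launcher such as `torchrun` has set up the usual environment variables, and everything apart from the names taken from the test above is illustrative.

```python
# Minimal sketch mirroring the added test; run under e.g. `torchrun --nproc-per-node=2`.
import os
from copy import deepcopy

import torch
from torch.distributed._tensor import Replicate, Shard
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import RowwiseParallel, parallelize_module

world_size = int(os.environ["WORLD_SIZE"])     # set by the launcher
mesh = init_device_mesh("cpu", (world_size,))  # 1-D mesh over all ranks

model = torch.nn.Embedding(16, 16)
plan = RowwiseParallel(input_layouts=Replicate(), output_layouts=Shard(1))
rowwise_mod = parallelize_module(deepcopy(model), mesh, plan)

inp_indices = torch.arange(8)    # replicated token ids
out = rowwise_mod(inp_indices)   # local shard of shape (8, 16 // world_size)
out.sum().backward()             # triggers the all-gather in backward
```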

torch/_inductor/runtime/runtime_utils.py

Lines changed: 2 additions & 1 deletion
@@ -10,7 +10,6 @@
 import time
 
 import torch
-from torch._inductor.utils import is_cpu_device
 
 
 def conditional_product(*args):
@@ -72,6 +71,8 @@ def get_max_y_grid():
 
 
 def do_bench(fn, fn_args, fn_kwargs, **kwargs):
+    from torch._inductor.utils import is_cpu_device
+
     args = list(fn_args)
     args.extend(fn_kwargs.values())
     if is_cpu_device(args):
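The change above moves the `is_cpu_device` import from module scope into `do_bench`. This is the standard deferred-import pattern: the dependency is only resolved the first time the function runs, which keeps importing `runtime_utils` cheap and avoids import-order problems between `torch._inductor` modules (the precise motivation isn't stated in the diff, so take that as a reading, not a quote). A generic, non-PyTorch sketch of the pattern:

```python
# Generic illustration of a deferred (function-local) import; names are made up.
def encode(payload: dict) -> str:
    # Imported on first call and then cached in sys.modules, not at module import time.
    import json
    return json.dumps(payload, sort_keys=True)

print(encode({"ok": True}))
```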

torch/distributed/_tensor/dispatch.py

Lines changed: 11 additions & 3 deletions
@@ -187,10 +187,18 @@ def default_tensor(spec: DTensorSpec) -> torch.Tensor:
                 # Default to `OffsetBasedRNGTracker` if the parallelism API
                 # did not already construct one
                 random._rng_tracker = random.OffsetBasedRNGTracker(mesh.device_type)
+
+            first_arg, first_local_arg = cast(dtensor.DTensor, args[0]), cast(
+                torch.Tensor, local_tensor_args[0]
+            )
+            rng_context = (
+                random._rng_tracker._distribute_region(first_arg._spec)
+                if random._rng_tracker and not first_local_arg.is_meta
+                else contextlib.nullcontext()
+            )
+
             # For DTensor random operator, run it within a distribute region
-            with random._rng_tracker._distribute_region(
-                cast(dtensor.DTensor, args[0])._spec
-            ) if random._rng_tracker else contextlib.nullcontext():
+            with rng_context:
                 local_results = op_call(*local_tensor_args, **op_info.local_kwargs)
         else:
             local_results = op_call(*local_tensor_args, **op_info.local_kwargs)
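The rewrite above precomputes `rng_context` once and keeps a single unconditional `with`, additionally skipping the RNG distribute region when the first local argument is a meta tensor (presumably because meta tensors carry no data, so there is no RNG state worth advancing). A small generic sketch of this select-then-`with` pattern, with made-up names:

```python
# Generic sketch of choosing between a real context manager and a no-op one,
# mirroring the rng_context refactor above; rng_region/run are made-up names.
import contextlib

@contextlib.contextmanager
def rng_region(name):
    print(f"enter {name}")  # stands in for _distribute_region's bookkeeping
    yield
    print(f"exit {name}")

def run(op, region=None, is_meta=False):
    # Precompute the context once, then use a single unconditional `with`.
    ctx = region if region is not None and not is_meta else contextlib.nullcontext()
    with ctx:
        return op()

print(run(lambda: 1 + 1))                                            # no region -> nullcontext
print(run(lambda: 2 + 2, region=rng_region("rand")))                 # enters the region
print(run(lambda: 3 + 3, region=rng_region("rand"), is_meta=True))   # meta input skips it
```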

torch/distributed/_tensor/op_schema.py

Lines changed: 6 additions & 3 deletions
@@ -104,9 +104,12 @@ def input_spec(self, index: int = 0) -> DTensorSpec:
         return self.input_specs[index]
 
     def __str__(self) -> str:
-        input_specs_str = _pretty_print_spec(self.input_specs)
+        if self.input_specs is not None:
+            input_specs_str = f"{_pretty_print_spec(self.input_specs)} -> "
+        else:
+            input_specs_str = ""
         output_spec_str = _pretty_print_spec(self.output_specs)
-        return f"{input_specs_str} -> {output_spec_str}"
+        return f"{input_specs_str}{output_spec_str}"
 
 
 class StrategyType:
@@ -130,7 +133,7 @@ def __init__(self, strategies: List[PlacementStrategy]) -> None:
     def __str__(self) -> str:
         strategy_list_str = ", ".join([str(strategy) for strategy in self.strategies])
         mesh_shape = self.output_mesh_shape
-        return f"OpStrategy:[{strategy_list_str}] @ mesh: {mesh_shape}"
+        return f"[{strategy_list_str}] @ mesh: {mesh_shape}"
 
     def max_num_shards(self) -> int:
         """

0 commit comments
