Incorporate coalesce analysis in codegen · pytorch/pytorch@07eced7 · GitHub
[go: up one dir, main page]

Skip to content

Commit 07eced7

Browse files
committed
Incorporate coalesce analysis in codegen
ghstack-source-id: b6a6026 Pull Request resolved: #153751
1 parent 86b115f commit 07eced7

File tree

6 files changed

+419
-43
lines changed

6 files changed

+419
-43
lines changed

test/inductor/test_loop_ordering.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from torch._inductor.scheduler import SchedulerNode
1818
from torch._inductor.test_case import run_tests, TestCase
1919
from torch._inductor.test_operators import realize
20-
from torch._inductor.utils import sympy_index_symbol
20+
from torch._inductor.utils import run_and_get_code, sympy_index_symbol
2121
from torch._inductor.virtualized import ops, V
2222
from torch.testing import FileCheck
2323
from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FP8
@@ -520,9 +520,8 @@ def f(x):
520520

521521
@inductor_config.patch(
522522
{
523-
"benchmark_kernel": True,
524-
"loop_ordering_after_fusion": True,
525523
"triton.unique_kernel_names": True,
524+
"triton.max_tiles": 3,
526525
}
527526
)
528527
@instantiate_parametrized_tests
@@ -867,6 +866,8 @@ def fn(nodes):
867866
coalesce_analysis = tiling_utils.analyze_memory_coalescing(nodes[0])
868867
self.assertEqual(coalesce_analysis.suggested_split.tiling_factor, 64)
869868

869+
return nodes
870+
870871
with torch._inductor.config.patch(_post_fusion_custom_pass=fn), torch.no_grad():
871872

872873
def forward(permute):
@@ -883,7 +884,10 @@ def forward(permute):
883884
arg0_1 = torch.randn([XDIM, YDIM], device="cuda", dtype=torch.bfloat16)
884885
permute = torch.ops.aten.permute.default(arg0_1, [1, 0])
885886

886-
torch.compile(forward)(permute)
887+
out, code = run_and_get_code(torch.compile(forward), (permute))
888+
889+
self.assertEqual(out, forward(permute))
890+
FileCheck().check("YBLOCK").check("XBLOCK").run(code[0])
887891

888892

889893
if __name__ == "__main__":

torch/_inductor/codegen/simd.py

Lines changed: 248 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
import torch
2020
import torch._logging
21+
from torch._inductor.tiling_utils import analyze_memory_coalescing
2122
from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols
2223
from torch.fx.immutable_collections import immutable_dict
2324
from torch.utils._ordered_set import OrderedSet
@@ -58,6 +59,7 @@
5859
from .simd_kernel_features import (
5960
DisableReduction,
6061
EnableReduction,
62+
NodeScheduleEntry,
6163
NodeScheduleMarker,
6264
SIMDKernelFeatures,
6365
)
@@ -66,6 +68,8 @@
6668
if TYPE_CHECKING:
6769
from collections.abc import Iterable, Iterator, Sequence
6870

71+
from torch._inductor.tiling_utils import CoalesceVarAnalysis
72+
6973

7074
log = logging.getLogger(__name__)
7175
perf_hint_log = torch._logging.getArtifactLogger(__name__, "perf_hints")
@@ -679,6 +683,7 @@ def getter(flat_vars: list[sympy.Expr]) -> sympy.Expr:
679683
size, remaining[current_group]
680684
):
681685
raise CantSplit
686+
682687
size1 = remaining[current_group]
683688
size2 = FloorDiv(size, remaining[current_group])
684689
return_getters.append(
@@ -1344,13 +1349,14 @@ def codegen_node(
13441349

13451350
nodes: list[scheduler.SchedulerNode] = node.get_nodes() # type: ignore[assignment]
13461351

1352+
coalesce_analysis = analyze_memory_coalescing(node)
13471353
_, (numel, rnumel) = max(nodes, key=lambda x: int(x.is_reduction())).group
13481354

13491355
node_schedule = self.generate_node_schedule(nodes, numel, rnumel)
13501356
schedule_log.debug("Schedule:\n %s", node_schedule)
13511357

13521358
return self.codegen_node_schedule(
1353-
SIMDKernelFeatures(node_schedule, numel, rnumel)
1359+
SIMDKernelFeatures(node_schedule, numel, rnumel, coalesce_analysis)
13541360
)
13551361

13561362
@staticmethod
@@ -1383,11 +1389,17 @@ def can_use_32bit_indexing(
13831389

13841390
def codegen_node_schedule(self, kernel_features: SIMDKernelFeatures):
13851391
node_schedule = kernel_features.node_schedule
1386-
tiling = self.select_tiling(
1387-
node_schedule, kernel_features.numel, kernel_features.reduction_numel
1392+
1393+
tiling, tiling_score = self.get_tiling_and_scores(
1394+
node_schedule,
1395+
kernel_features.numel,
1396+
kernel_features.reduction_numel,
1397+
kernel_features.coalesce_analysis,
13881398
)
13891399
kernels = self.create_kernel_choices(
1390-
kernel_features, [tiling], {"features": kernel_features}
1400+
kernel_features,
1401+
[tiling],
1402+
{"features": kernel_features, "tiling_scores": tiling_score},
13911403
)
13921404
for kernel in kernels:
13931405
self.codegen_node_schedule_with_kernel(node_schedule, kernel)
@@ -1995,10 +2007,225 @@ def get_nd_tilings(
19952007

19962008
return ranked_tilings
19972009

2010+
@classmethod
def compute_tiling_strategy(
    cls,
    node_schedule: list[NodeScheduleEntry],
    pointwise_numel: sympy.Expr,
    reduction_numel: sympy.Expr,
    coalesce_analysis: CoalesceVarAnalysis,
) -> tuple[dict[str, sympy.Expr], Optional[dict[str, sympy.Expr]]]:
    """
    Pick a tiling using the coalescing analysis, and report a score per tile.

    Several candidate tilings are generated: the default (unsplit) tiling,
    a tiling that splits on the analysis' suggested variable (if any), and
    one tiling per pointwise variable that appears in
    ``coalesce_analysis.coalesced_by_var``. Candidates are visited from the
    highest total score to the lowest, and the first acceptable one wins.

    Returns:
        ``(tiling, tiling_scores)``. ``tiling_scores`` is built with
        ``create_tiling`` from the per-split scores of the chosen candidate.
        It is ``None`` when no candidate was accepted and the plain default
        tiling is returned instead.
    """
    suggested_split = coalesce_analysis.suggested_split
    tiling_var: Optional[sympy.Expr] = (
        suggested_split.var if suggested_split else None
    )

    norm_read_writes = coalesce_analysis.norm_read_writes
    all_iter_vars = norm_read_writes.index_vars
    all_red_vars = norm_read_writes.reduce_vars
    var_ranges = norm_read_writes.var_ranges
    coalesced_by_var = coalesce_analysis.coalesced_by_var

    pw_ranges = [var_ranges[v] for v in all_iter_vars]
    red_ranges = [var_ranges[v] for v in all_red_vars]

    # The analysed ranges must multiply out to the numels we were handed.
    torch._check(
        sympy_product(pw_ranges) == pointwise_numel,
        lambda: f"{pw_ranges}, {pointwise_numel}, {node_schedule}",
    )
    torch._check(
        sympy_product(red_ranges) == reduction_numel,
        lambda: f"{red_ranges}, {reduction_numel}, {node_schedule}",
    )

    # score of a pointwise or reduction split (memo for process_node_vars)
    scored_sub_split: dict[Any, tuple[list[int], list[int]]] = {}

    def process_node_vars(
        vars_to_use: tuple[sympy.Expr, ...] = (),
        use_split_var: bool = False,
        is_pointwise: bool = False,
    ) -> tuple[list[int], list[int]]:
        """
        Build the split sizes, and one score per split, for either the
        pointwise or the reduction dimension given the vars to split on.
        """
        if is_pointwise:
            dim_ranges, target_numel = pw_ranges, pointwise_numel
            splitting_vars = all_iter_vars
        else:
            dim_ranges, target_numel = red_ranges, reduction_numel
            splitting_vars = all_red_vars

        # Some kernels have no reduction ranges, and a reduction numel of 1
        if not dim_ranges:
            return ([target_numel], []) if target_numel else ([], [])

        key = (repr(vars_to_use), use_split_var, is_pointwise)
        cached = scored_sub_split.get(key, None)
        if cached:
            return cached

        splits = []
        split_scores = []
        running = 1
        trailing_score = 0

        for var, var_range in zip(splitting_vars, dim_ranges):
            running *= var_range

            if var not in vars_to_use:
                # Not splitting here: remember this var's score for the
                # trailing split that will absorb it.
                trailing_score = coalesced_by_var.get(var, 0)
                continue

            if not (use_split_var and var == tiling_var):
                # splitting on this var
                splits.append(running)
                split_scores.append(coalesced_by_var.get(var, 0))
                running = 1
                continue

            # This is the suggested split variable and we are using it as such
            assert suggested_split is not None
            # Add the original range up to this point
            if running > 1:
                splits.append(running // suggested_split.tiling_factor)
                split_scores.append(suggested_split.score)

            running = suggested_split.tiling_factor  # Remaining size
            # if we end up splitting on this, var will be coalesced as well
            trailing_score = coalesced_by_var.get(var, 0)

        splits.append(running)
        split_scores.append(trailing_score)

        result = (splits, split_scores)
        scored_sub_split[key] = result
        return result

    # Each entry pairs a (splits, scores) result for the pointwise dimension
    # with one for the reduction dimension. Start with the default tiling.
    score_split: list[
        tuple[tuple[list[int], list[int]], tuple[list[int], list[int]]]
    ] = [
        (
            process_node_vars(is_pointwise=True),
            process_node_vars(is_pointwise=False),
        )
    ]

    if tiling_var:
        pw_part = process_node_vars(
            (tiling_var,), use_split_var=True, is_pointwise=True
        )
        red_part = process_node_vars(is_pointwise=False)
        score_split.append((pw_part, red_part))

    # TODO, add tests, reduction splits if config.triton.tile_reductions
    # TODO: we should ignore tiny increases in score for extra splits
    overlapping_iter_vars = all_iter_vars & coalesced_by_var.keys()
    for var in overlapping_iter_vars:
        pw_part = process_node_vars((var,), is_pointwise=True)
        red_part = process_node_vars(is_pointwise=False)
        score_split.append((pw_part, red_part))

    tilings: list[tuple[CandidateTiling, dict[str, sympy.Expr]]] = [
        (
            CandidateTiling(
                cls.create_tiling(pw_split, red_split),
                score=sum(pw_score) + sum(red_score),
            ),
            cls.create_tiling(pw_score, red_score),
        )
        for (pw_split, pw_score), (red_split, red_score) in score_split
    ]

    default_tiling = cls.create_tiling([pointwise_numel], [reduction_numel])

    # Visit candidates from the highest total score to the lowest.
    ranked = sorted(tilings, key=lambda entry: -entry[0].score)
    for cand, tiling_score in ranked:
        compatible = cls.tiling_is_compatible(
            node_schedule, pointwise_numel, reduction_numel, cand.tiling
        )

        if not compatible:
            # surprisingly, the default tiling is not always read as compatible by `tiling_is_compatible`
            # TODO - look into, occurs with dynamic shapes often
            if cand.tiling == default_tiling:
                return cand.tiling, tiling_score
            continue

        max_tiles = torch._inductor.config.triton.max_tiles
        if len(cand.tiling) > max_tiles:
            perf_hint_log.info(
                "Found optimal tiling with %s tiles but torch._inductor.config.triton.max_tiles "
                "set to %s. Consider increasing",
                len(cand.tiling),
                torch._inductor.config.triton.max_tiles,
            )
            continue

        return cand.tiling, tiling_score

    return default_tiling, None
2177+
2178+
@classmethod
def tiling_is_compatible(
    cls,
    node_schedule: list[NodeScheduleEntry],
    numel: sympy.Expr,
    reduction_numel: sympy.Expr,
    tiling: dict[str, sympy.Expr],
):
    """
    Return True when every ``SchedulerNode`` in ``node_schedule`` passes
    ``SIMDKernel.is_compatible`` for the sizes in ``tiling``.

    Entries that are not ``scheduler.SchedulerNode`` instances are ignored.
    ``numel`` is accepted but not read by this method.
    """
    assert isinstance(tiling, dict)
    for entry in node_schedule:
        if not isinstance(entry, scheduler.SchedulerNode):
            continue
        compatible = SIMDKernel.is_compatible(
            tiling.values(), entry.get_ranges(), reduction_numel=reduction_numel
        )
        if not compatible:
            return False
    return True
2194+
2195+
@classmethod
def get_first_compatible_tiling(
    cls,
    node_schedule: list[NodeScheduleEntry],
    numel: sympy.Expr,
    reduction_numel: sympy.Expr,
    ranked_tilings: list[dict[str, sympy.Expr]],
):
    """
    Return the first tiling in ``ranked_tilings`` (in list order) that
    ``tiling_is_compatible`` accepts for ``node_schedule``, or ``None`` when
    none of them does.
    """
    compatible_tilings = (
        candidate
        for candidate in ranked_tilings
        if cls.tiling_is_compatible(
            node_schedule, numel, reduction_numel, candidate
        )
    )
    # The generator is lazy, so checking stops at the first match.
    return next(compatible_tilings, None)
2208+
19982209
@classmethod
def select_tiling(
    cls,
    node_schedule,
    numel,
    reduction_numel=sympy.S.One,
    coalesce_analysis: Optional[CoalesceVarAnalysis] = None,
) -> dict[str, sympy.Expr]:
    """
    Return only the tiling chosen by ``get_tiling_and_scores``.

    The per-tile scores that ``get_tiling_and_scores`` also returns are
    discarded here.
    """
    tiling_and_scores = cls.get_tiling_and_scores(
        node_schedule, numel, reduction_numel, coalesce_analysis
    )
    return tiling_and_scores[0]
2220+
2221+
@classmethod
2222+
def get_tiling_and_scores(
2223+
cls,
2224+
node_schedule,
2225+
numel,
2226+
reduction_numel=sympy.S.One,
2227+
coalesce_analysis: Optional[CoalesceVarAnalysis] = None,
2228+
) -> tuple[dict[str, sympy.Expr], Optional[dict[str, sympy.Expr]]]:
20022229
"""
20032230
Heuristics to decide how to tile kernels.
20042231
Currently, we tile based on stride-1 dimensions.
@@ -2012,6 +2239,15 @@ def select_tiling(
20122239

20132240
# Tiled reductions are gated by a config flag.
20142241
default_tiling = cls.create_tiling([numel], [reduction_numel])
2242+
2243+
# TODO: enable by default
2244+
if (
2245+
torch._inductor.config.test_configs.global_tiling_analysis
2246+
and coalesce_analysis
2247+
):
2248+
return cls.compute_tiling_strategy(
2249+
node_schedule, numel, reduction_numel, coalesce_analysis
2250+
)
20152251
if (
20162252
not is_pointwise and not config.triton.tile_reductions
20172253
) or config.triton.max_tiles <= 1:
@@ -2031,7 +2267,8 @@ def select_tiling(
20312267
)
20322268
)
20332269
break
2034-
return default_tiling
2270+
2271+
return default_tiling, None
20352272

20362273
seen_names = OrderedSet[str]()
20372274
candidate_tiles: Counter[CandidateTiling] = collections.Counter()
@@ -2101,18 +2338,12 @@ def convert_tiling_to_3d(
21012338
+ ranked_tilings
21022339
)
21032340

2104-
for tiling in ranked_tilings:
2105-
assert isinstance(tiling, dict)
2106-
if all(
2107-
SIMDKernel.is_compatible(
2108-
tiling.values(), node.get_ranges(), reduction_numel=reduction_numel
2109-
)
2110-
for node in node_schedule
2111-
if isinstance(node, scheduler.SchedulerNode)
2112-
):
2113-
return tiling
2341+
if tiling := cls.get_first_compatible_tiling(
2342+
node_schedule, numel, reduction_numel, ranked_tilings
2343+
):
2344+
return tiling, None
21142345

2115-
return default_tiling
2346+
return default_tiling, None
21162347

21172348
def flush(self):
21182349
pass

0 commit comments

Comments
 (0)
0