Incorporate coalesce analysis in codegen by eellison · Pull Request #153751 · pytorch/pytorch · GitHub
[go: up one dir, main page]

Skip to content

Incorporate coalesce analysis in codegen #153751

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: gh/eellison/793/base
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions test/inductor/test_loop_ordering.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from torch._inductor.scheduler import SchedulerNode
from torch._inductor.test_case import run_tests, TestCase
from torch._inductor.test_operators import realize
from torch._inductor.utils import sympy_index_symbol
from torch._inductor.utils import run_and_get_code, sympy_index_symbol
from torch._inductor.virtualized import ops, V
from torch.testing import FileCheck
from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FP8
Expand Down Expand Up @@ -520,9 +520,8 @@ def f(x):

@inductor_config.patch(
{
"benchmark_kernel": True,
"loop_ordering_after_fusion": True,
"triton.unique_kernel_names": True,
"triton.max_tiles": 3,
}
)
@instantiate_parametrized_tests
Expand Down Expand Up @@ -867,6 +866,8 @@ def fn(nodes):
coalesce_analysis = tiling_utils.analyze_memory_coalescing(nodes[0])
self.assertEqual(coalesce_analysis.suggested_split.tiling_factor, 64)

return nodes

with torch._inductor.config.patch(_post_fusion_custom_pass=fn), torch.no_grad():

def forward(permute):
Expand All @@ -883,7 +884,10 @@ def forward(permute):
arg0_1 = torch.randn([XDIM, YDIM], device="cuda", dtype=torch.bfloat16)
permute = torch.ops.aten.permute.default(arg0_1, [1, 0])

torch.compile(forward)(permute)
out, code = run_and_get_code(torch.compile(forward), (permute))

self.assertEqual(out, forward(permute))
FileCheck().check("YBLOCK").check("XBLOCK").run(code[0])


if __name__ == "__main__":
Expand Down
265 changes: 248 additions & 17 deletions torch/_inductor/codegen/simd.py
< 8000 td id="diff-11df6999ae1e58f4dcb00977e32f4e56d1d2d530119d574c6e9abbcb1626e5a3R72" data-line-number="72" class="blob-num blob-num-addition js-linkable-line-number js-blob-rnum">
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import torch
import torch._logging
from torch._inductor.tiling_utils import analyze_memory_coalescing
from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols
from torch.fx.immutable_collections import immutable_dict
from torch.utils._ordered_set import OrderedSet
Expand Down Expand Up @@ -58,6 +59,7 @@
from .simd_kernel_features import (
DisableReduction,
EnableReduction,
NodeScheduleEntry,
NodeScheduleMarker,
SIMDKernelFeatures,
)
Expand All @@ -66,6 +68,8 @@
if TYPE_CHECKING:
from collections.abc import Iterable, Iterator, Sequence

from torch._inductor.tiling_utils import CoalesceVarAnalysis


log = logging.getLogger(__name__)
perf_hint_log = torch._logging.getArtifactLogger(__name__, "perf_hints")
Expand Down Expand Up @@ -679,6 +683,7 @@ def getter(flat_vars: list[sympy.Expr]) -> sympy.Expr:
size, remaining[current_group]
):
raise CantSplit

size1 = remaining[current_group]
size2 = FloorDiv(size, remaining[current_group])
return_getters.append(
Expand Down Expand Up @@ -1344,13 +1349,14 @@ def codegen_node(

nodes: list[scheduler.SchedulerNode] = node.get_nodes() # type: ignore[assignment]

coalesce_analysis = analyze_memory_coalescing(node)
_, (numel, rnumel) = max(nodes, key=lambda x: int(x.is_reduction())).group

node_schedule = self.generate_node_schedule(nodes, numel, rnumel)
schedule_log.debug("Schedule:\n %s", node_schedule)

return self.codegen_node_schedule(
SIMDKernelFeatures(node_schedule, numel, rnumel)
SIMDKernelFeatures(node_schedule, numel, rnumel, coalesce_analysis)
)

@staticmethod
Expand Down Expand Up @@ -1383,11 +1389,17 @@ def can_use_32bit_indexing(

def codegen_node_schedule(self, kernel_features: SIMDKernelFeatures):
node_schedule = kernel_features.node_schedule
tiling = self.select_tiling(
node_schedule, kernel_features.numel, kernel_features.reduction_numel

tiling, tiling_score = self.get_tiling_and_scores(
node_schedule,
kernel_features.numel,
kernel_features.reduction_numel,
kernel_features.coalesce_analysis,
)
kernels = self.create_kernel_choices(
kernel_features, [tiling], {"features": kernel_features}
kernel_features,
[tiling],
{"features": kernel_features, "tiling_scores": tiling_score},
)
for kernel in kernels:
self.codegen_node_schedule_with_kernel(node_schedule, kernel)
Expand Down Expand Up @@ -1995,10 +2007,225 @@ def get_nd_tilings(

return ranked_tilings

@classmethod
def compute_tiling_strategy(
    cls,
    node_schedule: list[NodeScheduleEntry],
    pointwise_numel: sympy.Expr,
    reduction_numel: sympy.Expr,
    coalesce_analysis: CoalesceVarAnalysis,
) -> tuple[dict[str, sympy.Expr], Optional[dict[str, sympy.Expr]]]:
    """
    Generates a tiling, and a score of each tile according to each tile's coalesced memory accesses.

    Candidate tilings are built from the variables reported by
    ``coalesce_analysis``: an unsplit default, optionally one that splits the
    suggested variable by its tiling factor, and one per pointwise variable
    that has an entry in ``coalesced_by_var``.  Candidates are tried from
    highest to lowest total score.

    Args:
        node_schedule: schedule entries; used for the compatibility check and
            included in the ``torch._check`` failure messages.
        pointwise_numel: expected product of all pointwise variable ranges.
        reduction_numel: expected product of all reduction variable ranges.
        coalesce_analysis: source of ``suggested_split``, ``norm_read_writes``
            and ``coalesced_by_var``.

    Returns:
        ``(tiling, tiling_score)`` for the first accepted candidate, where
        ``tiling_score`` is built by ``create_tiling`` from the per-split
        scores, or ``(default_tiling, None)`` if no candidate is accepted.
    """
    tiling_var: Optional[sympy.Expr] = (
        None
        if not coalesce_analysis.suggested_split
        else coalesce_analysis.suggested_split.var
    )

    all_iter_vars = coalesce_analysis.norm_read_writes.index_vars
    all_red_vars = coalesce_analysis.norm_read_writes.reduce_vars
    ranges = coalesce_analysis.norm_read_writes.var_ranges

    pw_ranges = [ranges[v] for v in all_iter_vars]
    red_ranges = [ranges[v] for v in all_red_vars]

    # Sanity check: the analysed ranges must multiply out to the kernel's numels.
    torch._check(
        sympy_product(pw_ranges) == pointwise_numel,
        lambda: f"{pw_ranges}, {pointwise_numel}, {node_schedule}",
    )
    torch._check(
        sympy_product(red_ranges) == reduction_numel,
        lambda: f"{red_ranges}, {reduction_numel}, {node_schedule}",
    )

    # score of a pointwise or reduction split
    scored_sub_split: dict[Any, tuple[list[int], list[int]]] = {}

    score_split: list[
        tuple[tuple[list[int], list[int]], tuple[list[int], list[int]]]
    ] = []

    def process_node_vars(
        vars_to_use: tuple[sympy.Expr, ...] = (),
        use_split_var: bool = False,
        is_pointwise: bool = False,
    ) -> tuple[list[int], list[int]]:
        """
        Generate a tiling, and a tiling score, given vars to use as splits.
        """

        ranges = pw_ranges if is_pointwise else red_ranges
        target_numel = pointwise_numel if is_pointwise else reduction_numel
        # Some kernels have no reduction ranges, and a reduction numel of 1
        if not ranges:
            if target_numel:
                return ([target_numel], [])
            else:
                return ([], [])

        # Memoize on the arguments; results are reused across candidates.
        key = (repr(vars_to_use), use_split_var, is_pointwise)
        if out := scored_sub_split.get(key, None):
            return out

        splitting_vars = all_iter_vars if is_pointwise else all_red_vars

        splits = []
        split_scores = []
        prod = 1
        prev_var_coalesced_score = 0

        for v, v_range in zip(splitting_vars, ranges):
            prod *= v_range
            if v not in vars_to_use:
                # Not a split point: keep accumulating into the current tile.
                prev_var_coalesced_score = coalesce_analysis.coalesced_by_var.get(
                    v, 0
                )
                continue

            # If this is the split variable and we're using it as such
            if use_split_var and v == tiling_var:
                var_tiling = coalesce_analysis.suggested_split
                assert var_tiling is not None
                # Add the original range up to this point
                if prod > 1:
                    splits.append(prod // var_tiling.tiling_factor)
                    split_scores.append(var_tiling.score)

                prod = var_tiling.tiling_factor  # Remaining size
                # if we end up splitting on this, v will be coalesced as well
                prev_var_coalesced_score = coalesce_analysis.coalesced_by_var.get(
                    v, 0
                )

            else:
                # splitting on this var
                splits.append(prod)
                split_scores.append(coalesce_analysis.coalesced_by_var.get(v, 0))
                prod = 1

        # Whatever is left after the last split point forms the final tile.
        splits.append(prod)
        split_scores.append(prev_var_coalesced_score)

        scored_sub_split[key] = (splits, split_scores)
        return (splits, split_scores)

    # The reduction side is never split here, so compute it once and share it
    # between all candidates (the result is only read afterwards).
    default_pw_split = process_node_vars(is_pointwise=True)
    default_red_split = process_node_vars(is_pointwise=False)

    # add the default tiling
    score_split.append((default_pw_split, default_red_split))

    if tiling_var:
        score_split.append(
            (
                process_node_vars(
                    (tiling_var,), use_split_var=True, is_pointwise=True
                ),
                default_red_split,
            )
        )

    # TODO, add tests, reduction splits if config.triton.tile_reductions
    # TODO: we should ignore tiny increases in score for extra splits
    overlapping_iter_vars = (
        all_iter_vars & coalesce_analysis.coalesced_by_var.keys()
    )
    for v in overlapping_iter_vars:
        score_split.append(
            (
                process_node_vars((v,), is_pointwise=True),
                default_red_split,
            )
        )

    tilings: list[tuple[CandidateTiling, dict[str, sympy.Expr]]] = []
    for (pw_split, pw_score), (red_split, red_score) in score_split:
        candidate = CandidateTiling(
            cls.create_tiling(pw_split, red_split),
            score=sum(pw_score) + sum(red_score),
        )
        tiling_score = cls.create_tiling(pw_score, red_score)
        tilings.append((candidate, tiling_score))

    default_tiling = cls.create_tiling([pointwise_numel], [reduction_numel])

    # Try candidates from highest to lowest total score.
    for cand, tiling_score in sorted(tilings, key=lambda t: -t[0].score):
        if cls.tiling_is_compatible(
            node_schedule, pointwise_numel, reduction_numel, cand.tiling
        ):
            if len(cand.tiling) > torch._inductor.config.triton.max_tiles:
                perf_hint_log.info(
                    "Found optimal tiling with %s tiles but torch._inductor.config.triton.max_tiles "
                    "set to %s. Consider increasing",
                    len(cand.tiling),
                    torch._inductor.config.triton.max_tiles,
                )
                continue

            return cand.tiling, tiling_score

        # surprisingly, the default tiling is not always read as compatible by `tiling_is_compatible`
        # TODO - look into, occurs with dynamic shapes often
        if cand.tiling == default_tiling:
            return cand.tiling, tiling_score

    return default_tiling, None

@classmethod
def tiling_is_compatible(
    cls,
    node_schedule: list[NodeScheduleEntry],
    numel: sympy.Expr,
    reduction_numel: sympy.Expr,
    tiling: dict[str, sympy.Expr],
):
    """
    Report whether ``tiling`` is accepted by every scheduler node in the schedule.

    Only entries that are ``scheduler.SchedulerNode`` instances are checked;
    all other entries are skipped.  Each check passes the tiling's values and
    the node's ranges to ``SIMDKernel.is_compatible``.  ``numel`` is accepted
    but not read by this method.

    Returns:
        ``True`` if no checked node rejects the tiling (including when there
        are no nodes to check), otherwise ``False``.
    """
    assert isinstance(tiling, dict)
    for entry in node_schedule:
        if not isinstance(entry, scheduler.SchedulerNode):
            continue
        # Stop at the first node that rejects this tiling.
        if not SIMDKernel.is_compatible(
            tiling.values(), entry.get_ranges(), reduction_numel=reduction_numel
        ):
            return False
    return True

@classmethod
def get_first_compatible_tiling(
    cls,
    node_schedule: list[NodeScheduleEntry],
    numel: sympy.Expr,
    reduction_numel: sympy.Expr,
    ranked_tilings: list[dict[str, sympy.Expr]],
):
    """
    Return the earliest entry of ``ranked_tilings`` that passes ``tiling_is_compatible``.

    Tilings are tested in the order given and the search stops at the first
    match.  Returns ``None`` when no tiling is compatible or the list is empty.
    """
    compatible = (
        candidate
        for candidate in ranked_tilings
        if cls.tiling_is_compatible(
            node_schedule, numel, reduction_numel, candidate
        )
    )
    # The generator is lazy, so later tilings are not checked after a match.
    return next(compatible, None)

@classmethod
def select_tiling(
cls, node_schedule, numel, reduction_numel=sympy.S.One
cls,
node_schedule,
numel,
reduction_numel=sympy.S.One,
coalesce_analysis: Optional[CoalesceVarAnalysis] = None,
) -> dict[str, sympy.Expr]:
return cls.get_tiling_and_scores(
node_schedule, numel, reduction_numel, coalesce_analysis
)[0]

@classmethod
def get_tiling_and_scores(
cls,
node_schedule,
numel,
reduction_numel=sympy.S.One,
coalesce_analysis: Optional[CoalesceVarAnalysis] = None,
) -> tuple[dict[str, sympy.Expr], Optional[dict[str, sympy.Expr]]]:
"""
Heuristics to decide how to tile kernels.
Currently, we tile based on stride-1 dimensions.
Expand All @@ -2012,6 +2239,15 @@ def select_tiling(

# Tiled reductions are gated by a config flag.
default_tiling = cls.create_tiling([numel], [reduction_numel])

# TODO: enable by default
if (
torch._inductor.config.test_configs.global_tiling_analysis
and coalesce_analysis
):
return cls.compute_tiling_strategy(
node_schedule, numel, reduction_numel, coalesce_analysis
)
if (
not is_pointwise and not config.triton.tile_reductions
) or config.triton.max_tiles <= 1:
Expand All @@ -2031,7 +2267,8 @@ def select_tiling(
)
)
break
return default_tiling

return default_tiling, None

seen_names = OrderedSet[str]()
candidate_tiles: Counter[CandidateTiling] = collections.Counter()
Expand Down Expand Up @@ -2101,18 +2338,12 @@ def convert_tiling_to_3d(
+ ranked_tilings
)

for tiling in ranked_tilings:
assert isinstance(tiling, dict)
if all(
SIMDKernel.is_compatible(
tiling.values(), node.get_ranges(), reduction_numel=reduction_numel
)
for node in node_schedule
if isinstance(node, scheduler.SchedulerNode)
):
return tiling
if tiling := cls.get_first_compatible_tiling(
node_schedule, numel, reduction_numel, ranked_tilings
):
return tiling, None

return default_tiling
return default_tiling, None

def flush(self):
pass
Expand Down
Loading
Loading
0