18
18
19
19
import torch
20
20
import torch ._logging
21
+ from torch ._inductor .tiling_utils import analyze_memory_coalescing
21
22
from torch .fx .experimental .symbolic_shapes import free_unbacked_symbols
22
23
from torch .fx .immutable_collections import immutable_dict
23
24
from torch .utils ._ordered_set import OrderedSet
58
59
from .simd_kernel_features import (
59
60
DisableReduction ,
60
61
EnableReduction ,
62
+ NodeScheduleEntry ,
61
63
NodeScheduleMarker ,
62
64
SIMDKernelFeatures ,
63
65
)
66
68
if TYPE_CHECKING :
67
69
from collections .abc import Iterable , Iterator , Sequence
68
70
71
+ from torch ._inductor .tiling_utils import CoalesceVarAnalysis
72
+
69
73
70
74
log = logging .getLogger (__name__ )
71
75
perf_hint_log = torch ._logging .getArtifactLogger (__name__ , "perf_hints" )
@@ -679,6 +683,7 @@ def getter(flat_vars: list[sympy.Expr]) -> sympy.Expr:
679
683
size , remaining [current_group ]
680
684
):
681
685
raise CantSplit
686
+
682
687
size1 = remaining [current_group ]
683
688
size2 = FloorDiv (size , remaining [current_group ])
684
689
return_getters .append (
@@ -1344,13 +1349,14 @@ def codegen_node(
1344
1349
1345
1350
nodes : list [scheduler .SchedulerNode ] = node .get_nodes () # type: ignore[assignment]
1346
1351
1352
+ coalesce_analysis = analyze_memory_coalescing (node )
1347
1353
_ , (numel , rnumel ) = max (nodes , key = lambda x : int (x .is_reduction ())).group
1348
1354
1349
1355
node_schedule = self .generate_node_schedule (nodes , numel , rnumel )
1350
1356
schedule_log .debug ("Schedule:\n %s" , node_schedule )
1351
1357
1352
1358
return self .codegen_node_schedule (
1353
- SIMDKernelFeatures (node_schedule , numel , rnumel )
1359
+ SIMDKernelFeatures (node_schedule , numel , rnumel , coalesce_analysis )
1354
1360
)
1355
1361
1356
1362
@staticmethod
@@ -1383,11 +1389,17 @@ def can_use_32bit_indexing(
1383
1389
1384
1390
def codegen_node_schedule (self , kernel_features : SIMDKernelFeatures ):
1385
1391
node_schedule = kernel_features .node_schedule
1386
- tiling = self .select_tiling (
1387
- node_schedule , kernel_features .numel , kernel_features .reduction_numel
1392
+
1393
+ tiling , tiling_score = self .get_tiling_and_scores (
1394
+ node_schedule ,
1395
+ kernel_features .numel ,
1396
+ kernel_features .reduction_numel ,
1397
+ kernel_features .coalesce_analysis ,
1388
1398
)
1389
1399
kernels = self .create_kernel_choices (
1390
- kernel_features , [tiling ], {"features" : kernel_features }
1400
+ kernel_features ,
1401
+ [tiling ],
1402
+ {"features" : kernel_features , "tiling_scores" : tiling_score },
1391
1403
)
1392
1404
for kernel in kernels :
1393
1405
self .codegen_node_schedule_with_kernel (node_schedule , kernel )
@@ -1995,10 +2007,225 @@ def get_nd_tilings(
1995
2007
1996
2008
return ranked_tilings
1997
2009
2010
+ @classmethod
2011
+ def compute_tiling_strategy (
2012
+ cls ,
2013
+ node_schedule : list [NodeScheduleEntry ],
2014
+ pointwise_numel : sympy .Expr ,
2015
+ reduction_numel : sympy .Expr ,
2016
+ coalesce_analysis : CoalesceVarAnalysis ,
2017
+ ) -> tuple [dict [str , sympy .Expr ], Optional [dict [str , sympy .Expr ]]]:
2018
+ """
2019
+ Generates a tiling, and a score of each tile according to each tile's coalesced memory accesses.
2020
+ """
2021
+ tiling_var : Optional [sympy .Expr ] = (
2022
+ None
2023
+ if not coalesce_analysis .suggested_split
2024
+ else coalesce_analysis .suggested_split .var
2025
+ )
2026
+
2027
+ all_iter_vars = coalesce_analysis .norm_read_writes .index_vars
2028
+ all_red_vars = coalesce_analysis .norm_read_writes .reduce_vars
2029
+ ranges = coalesce_analysis .norm_read_writes .var_ranges
2030
+
2031
+ pw_ranges = [ranges [v ] for v in all_iter_vars ]
2032
+ red_ranges = [ranges [v ] for v in all_red_vars ]
2033
+
2034
+ torch ._check (
2035
+ sympy_product (pw_ranges ) == pointwise_numel ,
2036
+ lambda : f"{ pw_ranges } , { pointwise_numel } , { node_schedule } " ,
2037
+ )
2038
+ torch ._check (
2039
+ sympy_product (red_ranges ) == reduction_numel ,
2040
+ lambda : f"{ red_ranges } , { reduction_numel } , { node_schedule } " ,
2041
+ )
2042
+
2043
+ # score of a pointwise or reduction split
2044
+ scored_sub_split : dict [Any , tuple [list [int ], list [int ]]] = {}
2045
+
2046
+ score_split : list [
2047
+ tuple [tuple [list [int ], list [int ]], tuple [list [int ], list [int ]]]
2048
+ ] = []
2049
+
2050
+ def process_node_vars (
2051
+ vars_to_use : tuple [sympy .Expr , ...] = (),
2052
+ use_split_var : bool = False ,
2053
+ is_pointwise : bool = False ,
2054
+ ) -> tuple [list [int ], list [int ]]:
2055
+ """
2056
+ Generate a tiling, and a tiling score, given vars to use as splits.
2057
+ """
2058
+
2059
+ ranges = pw_ranges if is_pointwise else red_ranges
2060
+ target_numel = pointwise_numel if is_pointwise else reduction_numel
2061
+ # Some kernels have no reduction ranges, and a reduction numel of 1
2062
+ if not ranges :
2063
+ if target_numel :
2064
+ return ([target_numel ], [])
2065
+ else :
2066
+ return ([], [])
2067
+
2068
+ key = (repr (vars_to_use ), use_split_var , is_pointwise )
2069
+ if out := scored_sub_split .get (key , None ):
2070
+ return out
2071
+
2072
+ splitting_vars = all_iter_vars if is_pointwise else all_red_vars
2073
+
2074
+ splits = []
2075
+ split_scores = []
2076
+ prod = 1
2077
+ prev_var_coalesced_score = 0
2078
+
2079
+ for v , v_range in zip (splitting_vars , ranges ):
2080
+ prod *= v_range
2081
+ if v not in vars_to_use :
2082
+ prev_var_coalesced_score = coalesce_analysis .coalesced_by_var .get (
2083
+ v , 0
2084
+ )
2085
+ continue
2086
+
2087
+ # If this is the split variable and we're using it as such
2088
+ if use_split_var and v == tiling_var :
2089
+ var_tiling = coalesce_analysis .suggested_split
2090
+ assert var_tiling is not None
2091
+ # Add the original range up to this point
2092
+ if prod > 1 :
2093
+ splits .append (prod // var_tiling .tiling_factor )
2094
+ split_scores .append (var_tiling .score )
2095
+
2096
+ prod = var_tiling .tiling_factor # Remaining size
2097
+ # if we end up splitting on this, v will be coalesced as well
2098
+ prev_var_coalesced_score = coalesce_analysis .coalesced_by_var .get (
2099
+ v , 0
2100
+ )
2101
+
2102
+ else :
2103
+ # splitting on this var
2104
+ splits .append (prod )
2105
+ split_scores .append (coalesce_analysis .coalesced_by_var .get (v , 0 ))
2106
+ prod = 1
2107
+
2108
+ splits .append (prod )
2109
+ split_scores .append (prev_var_coalesced_score )
2110
+
2111
+ scored_sub_split [key ] = (splits , split_scores )
2112
+ return (splits , split_scores )
2113
+
2114
+ # add the default tiling
2115
+ score_split .append (
2116
+ (
2117
+ process_node_vars (is_pointwise = True ),
2118
+ process_node_vars (is_pointwise = False ),
2119
+ )
2120
+ )
2121
+
2122
+ if tiling_var :
2123
+ score_split .append (
2124
+ (
2125
+ process_node_vars (
2126
+ (tiling_var ,), use_split_var = True , is_pointwise = True
2127
+ ),
2128
+ process_node_vars (is_pointwise = False ),
2129
+ )
2130
+ )
2131
+
2132
+ # TODO, add tests, reduction splits if config.triton.tile_reductions
2133
+ # TODO: we should ignore tiny increases in score for extra splits
2134
+ overlapping_iter_vars = (
2135
+ all_iter_vars & coalesce_analysis .coalesced_by_var .keys ()
2136
+ )
2137
+ for v in overlapping_iter_vars :
2138
+ score_split .append (
2139
+ (
2140
+ process_node_vars ((v ,), is_pointwise = True ),
2141
+ process_node_vars (is_pointwise = False ),
2142
+ )
2143
+ )
2144
+
2145
+ tilings : list [tuple [CandidateTiling , dict [str , sympy .Expr ]]] = []
2146
+ for (pw_split , pw_score ), (red_split , red_score ) in score_split :
2147
+ candidate = CandidateTiling (
2148
+ cls .create_tiling (pw_split , red_split ),
2149
+ score = sum (pw_score ) + sum (red_score ),
2150
+ )
2151
+ tiling_score = cls .create_tiling (pw_score , red_score )
2152
+ tilings .append ((candidate , tiling_score ))
2153
+
2154
+ default_tiling = cls .create_tiling ([pointwise_numel ], [reduction_numel ])
2155
+
2156
+ for cand , tiling_score in sorted (tilings , key = lambda t : - t [0 ].score ):
2157
+ if cls .tiling_is_compatible (
2158
+ node_schedule , pointwise_numel , reduction_numel , cand .tiling
2159
+ ):
2160
+ if len (cand .tiling ) > torch ._inductor .config .triton .max_tiles :
2161
+ perf_hint_log .info (
2162
+ "Found optimal tiling with %s tiles but torch._inductor.config.triton.max_tiles "
2163
+ "set to %s. Consider increasing" ,
2164
+ len (cand .tiling ),
2165
+ torch ._inductor .config .triton .max_tiles ,
2166
+ )
2167
+ continue
2168
+
2169
+ return cand .tiling , tiling_score
2170
+
2171
+ # surprisingly, the default tiling is not always read as compatible by `tiling_is_compatible`
2172
+ # TODO - look into, occurs with dynamic shapes often
2173
+ if cand .tiling == default_tiling :
2174
+ return cand .tiling , tiling_score
2175
+
2176
+ return default_tiling , None
2177
+
2178
+ @classmethod
2179
+ def tiling_is_compatible (
2180
+ cls ,
2181
+ node_schedule : list [NodeScheduleEntry ],
2182
+ numel : sympy .Expr ,
2183
+ reduction_numel : sympy .Expr ,
2184
+ tiling : dict [str , sympy .Expr ],
2185
+ ):
2186
+ assert isinstance (tiling , dict )
2187
+ return all (
2188
+ SIMDKernel .is_compatible (
2189
+ tiling .values (), node .get_ranges (), reduction_numel = reduction_numel
2190
+ )
2191
+ for node in node_schedule
2192
+ if isinstance (node , scheduler .SchedulerNode )
2193
+ )
2194
+
2195
+ @classmethod
2196
+ def get_first_compatible_tiling (
2197
+ cls ,
2198
+ node_schedule : list [NodeScheduleEntry ],
2199
+ numel : sympy .Expr ,
2200
+ reduction_numel : sympy .Expr ,
2201
+ ranked_tilings : list [dict [str , sympy .Expr ]],
2202
+ ):
2203
+ for tiling in ranked_tilings :
2204
+ if cls .tiling_is_compatible (node_schedule , numel , reduction_numel , tiling ):
2205
+ return tiling
2206
+
2207
+ return None
2208
+
1998
2209
@classmethod
1999
2210
def select_tiling (
2000
- cls , node_schedule , numel , reduction_numel = sympy .S .One
2211
+ cls ,
2212
+ node_schedule ,
2213
+ numel ,
2214
+ reduction_numel = sympy .S .One ,
2215
+ coalesce_analysis : Optional [CoalesceVarAnalysis ] = None ,
2001
2216
) -> dict [str , sympy .Expr ]:
2217
+ return cls .get_tiling_and_scores (
2218
+ node_schedule , numel , reduction_numel , coalesce_analysis
2219
+ )[0 ]
2220
+
2221
+ @classmethod
2222
+ def get_tiling_and_scores (
2223
+ cls ,
2224
+ node_schedule ,
2225
+ numel ,
2226
+ reduction_numel = sympy .S .One ,
2227
+ coalesce_analysis : Optional [CoalesceVarAnalysis ] = None ,
2228
+ ) -> tuple [dict [str , sympy .Expr ], Optional [dict [str , sympy .Expr ]]]:
2002
2229
"""
2003
2230
Heuristics to decide how to tile kernels.
2004
2231
Currently, we tile based on stride-1 dimensions.
@@ -2012,6 +2239,15 @@ def select_tiling(
2012
2239
2013
2240
# Tiled reductions are gated by a config flag.
2014
2241
default_tiling = cls .create_tiling ([numel ], [reduction_numel ])
2242
+
2243
+ # TODO: enable by default
2244
+ if (
2245
+ torch ._inductor .config .test_configs .global_tiling_analysis
2246
+ and coalesce_analysis
2247
+ ):
2248
+ return cls .compute_tiling_strategy (
2249
+ node_schedule , numel , reduction_numel , coalesce_analysis
2250
+ )
2015
2251
if (
2016
2252
not is_pointwise and not config .triton .tile_reductions
2017
2253
) or config .triton .max_tiles <= 1 :
@@ -2031,7 +2267,8 @@ def select_tiling(
2031
2267
)
2032
2268
)
2033
2269
break
2034
- return default_tiling
2270
+
2271
+ return default_tiling , None
2035
2272
2036
2273
seen_names = OrderedSet [str ]()
2037
2274
candidate_tiles : Counter [CandidateTiling ] = collections .Counter ()
@@ -2101,18 +2338,12 @@ def convert_tiling_to_3d(
2101
2338
+ ranked_tilings
2102
2339
)
2103
2340
2104
- for tiling in ranked_tilings :
2105
- assert isinstance (tiling , dict )
2106
- if all (
2107
- SIMDKernel .is_compatible (
2108
- tiling .values (), node .get_ranges (), reduction_numel = reduction_numel
2109
- )
2110
- for node in node_schedule
2111
- if isinstance (node , scheduler .SchedulerNode )
2112
- ):
2113
- return tiling
2341
+ if tiling := cls .get_first_compatible_tiling (
2342
+ node_schedule , numel , reduction_numel , ranked_tilings
2343
+ ):
2344
+ return tiling , None
2114
2345
2115
- return default_tiling
2346
+ return default_tiling , None
2116
2347
2117
2348
def flush (self ):
2118
2349
pass
0 commit comments