@@ -1014,7 +1014,7 @@ def codegen_range_tree(self):

    def call_kernel(self, name: str, node: Optional[ir.IRNode] = None):
        wrapper = V.graph.wrapper_code
-        _, call_args, _, arg_types = self.args.python_argdefs()
+        argdefs, call_args, signature, arg_types = self.args.python_argdefs()

        grid_args = ()
        if isinstance(self.grid_fn, SymbolicGridFn):
@@ -1036,13 +1036,144 @@ def call_kernel(self, name: str, node: Optional[ir.IRNode] = None):

        if self.workspace_arg is not None:
            wrapper.generate_workspace_allocation(self.workspace_arg)
-        wrapper.generate_kernel_call(
-            name,
-            call_args,
-            arg_types=arg_types,
-            triton_meta=self.triton_meta,
-            triton=True,
-        )
+
+        # Check if we have specialized kernels for divisibility
+        if hasattr(self, 'mod_div16') and self.mod_div16 is not None and hasattr(self, 'mod_nodiv16') and self.mod_nodiv16 is not None:
+            # Generate a wrapper function that checks divisibility at runtime and dispatches to the appropriate kernel
+            wrapper.add_import_once("import torch")
+
+            # Create a unique name for the wrapper function
+            wrapper_name = f"{name}_divisibility_wrapper"
+
+            # Get the SizeArg indices from the signature
+            size_arg_indices = []
+            argdefs, _, signature, _ = self.args.python_argdefs()
+            for i, arg in enumerate(signature):
+                if isinstance(arg, SizeArg) and arg.expr is not None:
+                    size_arg_indices.append(i)
+
+            # Generate the wrapper function
+            wrapper.writeline(f"def {wrapper_name}({', '.join(a.full_name() for a in argdefs)}):")
+            with wrapper.indent():
+                # Check if all SizeArgs are divisible by 16
+                if size_arg_indices:
+                    divisibility_checks = []
+                    for i in size_arg_indices:
+                        arg_name = argdefs[i].name
+                        divisibility_checks.append(f"{arg_name} % 16 == 0")
+
+                    wrapper.writeline(f"if {' and '.join(divisibility_checks)}:")
+                    with wrapper.indent():
+                        wrapper.writeline(f"return {name}_div16({', '.join(a.name for a in argdefs)})")
+                    wrapper.writeline("else:")
+                    with wrapper.indent():
+                        wrapper.writeline(f"return {name}_nodiv16({', '.join(a.name for a in argdefs)})")
+                else:
+                    # If there are no SizeArgs, just use the default kernel
+                    wrapper.writeline(f"return {name}({', '.join(a.name for a in argdefs)})")
+
+            # Generate the specialized kernel calls
+            wrapper.generate_kernel_call(
+                f"{name}_div16",
+                call_args,
+                arg_types=arg_types,
+                triton_meta=self.triton_meta,
+                triton=True,
+            )
+
+            wrapper.generate_kernel_call(
+                f"{name}_nodiv16",
+                call_args,
+                arg_types=arg_types,
+                triton_meta=self.triton_meta,
+                triton=True,
+            )
+
+            # Generate the default kernel call
+            # Check if we have specialized kernels for divisibility
+            if hasattr(self, 'mod_div16') and self.mod_div16 is not None and hasattr(self, 'mod_nodiv16') and self.mod_nodiv16 is not None:
+                # Generate a wrapper function that checks divisibility at runtime and dispatches to the appropriate kernel
+                wrapper.add_import_once("import torch")
+
+                # Create a unique name for the wrapper function
+                wrapper_name = f"{name}_divisibility_wrapper"
+
+                # Get the SizeArg indices from the signature
+                size_arg_indices = []
+                for i, arg in enumerate(signature):
+                    if isinstance(arg, SizeArg) and arg.expr is not None:
+                        size_arg_indices.append(i)
+
+                # Generate the wrapper function
+                wrapper.writeline(f"def {wrapper_name}({', '.join(a.full_name() for a in argdefs)}):")
+                with wrapper.indent():
+                    # Check if all SizeArgs are divisible by 16
+                    if size_arg_indices:
+                        divisibility_checks = []
+                        for i in size_arg_indices:
+                            arg_name = argdefs[i].name
+                            divisibility_checks.append(f"{arg_name} % 16 == 0")
+
+                        wrapper.writeline(f"if {' and '.join(divisibility_checks)}:")
+                        with wrapper.indent():
+                            wrapper.writeline(f"return {name}_div16({', '.join(a.name for a in argdefs)})")
+                        wrapper.writeline("else:")
+                        with wrapper.indent():
+                            wrapper.writeline(f"return {name}_nodiv16({', '.join(a.name for a in argdefs)})")
+                    else:
+                        # If there are no SizeArgs, just use the default kernel
+                        wrapper.writeline(f"return {name}({', '.join(a.name for a in argdefs)})")
+
+                # Generate the specialized kernel calls
+                wrapper.generate_kernel_call(
+                    f"{name}_div16",
+                    call_args,
+                    arg_types=arg_types,
+                    triton_meta=self.triton_meta,
+                    triton=True,
+                )
+
+                wrapper.generate_kernel_call(
+                    f"{name}_nodiv16",
+                    call_args,
+                    arg_types=arg_types,
+                    triton_meta=self.triton_meta,
+                    triton=True,
+                )
+
+                # Generate the default kernel call for backward compatibility
+                wrapper.generate_kernel_call(
+                    name,
+                    call_args,
+                    arg_types=arg_types,
+                    triton_meta=self.triton_meta,
+                    triton=True,
+                )
+
+                # Use the wrapper function instead of the direct kernel call
+                name = wrapper_name
+            else:
+                # Just generate the default kernel call
+                wrapper.generate_kernel_call(
+                    name,
+                    call_args,
+                    arg_types=arg_types,
+                    triton_meta=self.triton_meta,
+                    triton=True,
+                )
+
+            # Use the wrapper function instead of the direct kernel call
+            name = wrapper_name
+        else:
+            # Just generate the default kernel call
+            wrapper.generate_kernel_call(
+                name,
+                call_args,
+                arg_types=arg_types,
+                triton_meta=self.triton_meta,
+                triton=True,
+            )
+
        if self.workspace_arg is not None:
            wrapper.generate_workspace_deallocation(self.workspace_arg)

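For illustration, the Python that this codegen path is expected to emit would look roughly like the sketch below. The kernel name triton_mm and its argument list are hypothetical, and the _div16/_nodiv16 functions stand in for the two compiled kernel launchers.

# Sketch only: the shape of the emitted divisibility wrapper for a hypothetical
# kernel "triton_mm" with two SizeArgs (m, n); not generated verbatim by the PR.
def triton_mm_div16(a, b, m, n):
    return "div16"

def triton_mm_nodiv16(a, b, m, n):
    return "nodiv16"

def triton_mm_divisibility_wrapper(a, b, m, n):
    # Take the specialized kernel only when every SizeArg is a multiple of 16.
    if m % 16 == 0 and n % 16 == 0:
        return triton_mm_div16(a, b, m, n)
    else:
        return triton_mm_nodiv16(a, b, m, n)

print(triton_mm_divisibility_wrapper(None, None, 64, 128))  # div16
print(triton_mm_divisibility_wrapper(None, None, 65, 128))  # nodiv16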
@@ -1078,6 +1209,10 @@ class GenerateAndLoadResult(NamedTuple):
    prologue_supported_inputs: OrderedSet[str]
    kernel_args_sizevars_keys: tuple[sympy.Expr]
    kernel_options: dict[str, Any]
+    mod_div16: Optional[ModuleType] = None
+    mod_nodiv16: Optional[ModuleType] = None
+    mod_div16: Optional[ModuleType] = None
+    mod_nodiv16: Optional[ModuleType] = None


class TritonTemplate(KernelTemplate):
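Because the new fields are optional and default to None, existing positional construction of the result keeps working. A minimal standalone sketch of that pattern, assuming illustrative names (LoadResult, mod) rather than the real class:

# Toy sketch of widening a NamedTuple result with defaulted optional module fields,
# mirroring mod_div16 / mod_nodiv16.
import math
from types import ModuleType
from typing import NamedTuple, Optional

class LoadResult(NamedTuple):
    mod: ModuleType                           # default compiled module
    mod_div16: Optional[ModuleType] = None    # specialization assuming sizes % 16 == 0
    mod_nodiv16: Optional[ModuleType] = None  # specialization with no divisibility hint

r = LoadResult(math)               # existing call sites need no changes
print(r.mod_div16, r.mod_nodiv16)  # None None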
@@ -1125,6 +1260,8 @@ def generate_and_load(

        fake_out = ir.Buffer(name="buf_out", layout=layout)
        kernel_name = f"triton_{self.name}"
+        kernel_name_div16 = f"{kernel_name}_div16"
+        kernel_name_nodiv16 = f"{kernel_name}_nodiv16"

        numel = sympy_product(layout.size)
        buffers = itertools.chain(input_nodes, (fake_out,))
@@ -1164,7 +1301,7 @@ def make_kernel():
                **kernel_options,
            )

-        def generate_code(kernel) -> Optional[tuple[str, str]]:
+        def generate_code(kernel, divisible_by_16=None) -> Optional[tuple[str, str]]:
            def make_extra() -> str:
                extra_parts = [
                    f"{kwarg}={repr(kwargs[kwarg])}" for kwarg in sorted(kwargs.keys())
@@ -1183,13 +1320,16 @@ def make_extra() -> str:
                        f"num_buffers_warp_spec={num_buffers_warp_spec}",
                    ]
                )
+                if divisible_by_16 is not None:
+                    extra_parts.append(f"divisible_by_16={divisible_by_16}")
                extra = "-".join(extra_parts) + "-"
                return extra

            try:
-                template = kernel.render(self.template, kwargs)
-                with kernel.set_subgraph_body("<STORE_OUTPUT>"):
-                    code = template.finalize_all()
+                with patch.object(config.triton, "divisible_by_16", divisible_by_16) if divisible_by_16 is not None else contextlib.nullcontext():
+                    template = kernel.render(self.template, kwargs)
+                    with kernel.set_subgraph_body("<STORE_OUTPUT>"):
+                        code = template.finalize_all()
            except ZeroDivisionError:
                # TODO(nmacchioni): fix sympy division by zero
                return None
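The `patch.object(...) if ... else contextlib.nullcontext()` expression overrides the config flag only when an explicit value is requested and restores it afterwards. A self-contained sketch of that pattern, using a stand-in Cfg class rather than the real config.triton:

# Sketch of a conditional config override: patch the flag only when a value is given,
# otherwise run under a no-op context.
import contextlib
from unittest.mock import patch

class Cfg:
    divisible_by_16 = True

def render(divisible_by_16=None):
    ctx = (
        patch.object(Cfg, "divisible_by_16", divisible_by_16)
        if divisible_by_16 is not None
        else contextlib.nullcontext()
    )
    with ctx:
        # Whatever code generation reads the flag here sees the temporary value.
        return Cfg.divisible_by_16

print(render())             # True  -> config left untouched
print(render(False))        # False -> temporarily patched
print(Cfg.divisible_by_16)  # True  -> restored on exit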
@@ -1202,20 +1342,50 @@ def make_extra() -> str:
        # Generate code, extra.
        code: Optional[str] = None
        extra: Optional[str] = None
+        code_div16: Optional[str] = None
+        extra_div16: Optional[str] = None
+        code_nodiv16: Optional[str] = None
+        extra_nodiv16: Optional[str] = None
+
        with (
            patch.object(V.graph, "get_dtype", self._fake_get_dtype(fake_out)),
            V.graph.set_current_device(layout.device),
-            make_kernel() as kernel,
        ):
-            result = generate_code(kernel)
-            if not result:  # happens at ZeroDivisionError:
-                return None
-            code, extra = result
+            # Generate the default kernel (using config.triton.divisible_by_16 setting)
+            with make_kernel() as kernel:
+                result = generate_code(kernel)
+                if not result:  # happens at ZeroDivisionError:
+                    return None
+                code, extra = result
+
+            # Generate kernel with divisible_by_16=True
+            with make_kernel() as kernel:
+                kernel.kernel_name = kernel_name_div16  # Set the kernel name for the specialized kernel
+                result_div16 = generate_code(kernel, divisible_by_16=True)
+                if result_div16:
+                    code_div16, extra_div16 = result_div16
+
+            # Generate kernel with divisible_by_16=False
+            with make_kernel() as kernel:
+                kernel.kernel_name = kernel_name_nodiv16  # Set the kernel name for the specialized kernel
+                result_nodiv16 = generate_code(kernel, divisible_by_16=False)
+                if result_nodiv16:
+                    code_nodiv16, extra_nodiv16 = result_nodiv16

        assert code is not None and extra is not None

+        # Load the default kernel
        mod = PyCodeCache.load(code, extra)

+        # Load the specialized kernels if they were generated
+        mod_div16 = None
+        if code_div16 is not None and extra_div16 is not None:
+            mod_div16 = PyCodeCache.load(code_div16, extra_div16)
+
+        mod_nodiv16 = None
+        if code_nodiv16 is not None and extra_nodiv16 is not None:
+            mod_nodiv16 = PyCodeCache.load(code_nodiv16, extra_nodiv16)
+
        input_call_args = tuple(kernel.args.input_buffers.keys())
        prologue_supported_inputs = kernel.prologue_supported_inputs.copy()
        kernel_args_sizevars_keys = tuple(kernel.args.sizevars.keys())
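Appending divisible_by_16 to the extra string matters because PyCodeCache.load keys the cache on both the code and the extra value, so the three variants land in distinct cache entries. A toy cache below illustrates the idea; it is not the real PyCodeCache API.

# Toy illustration of a source-keyed module cache whose key includes the extra string.
_cache: dict[tuple[str, str], str] = {}

def load(code: str, extra: str) -> str:
    key = (code, extra)
    if key not in _cache:
        _cache[key] = f"module<{extra}>"  # stand-in for compiling `code`
    return _cache[key]

code = "def kernel(): ..."
print(load(code, "num_warps=4-"))
print(load(code, "num_warps=4-divisible_by_16=True-"))
print(load(code, "num_warps=4-divisible_by_16=False-"))
print(len(_cache))  # 3 distinct entries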
@@ -1227,6 +1397,8 @@ def make_extra() -> str:
            prologue_supported_inputs,
            kernel_args_sizevars_keys,
            kernel_options,
+            mod_div16,
+            mod_nodiv16,
        )

    def generate(  # type: ignore[override]
@@ -1376,7 +1548,7 @@ def make_kernel_render(out_node):
            output_tensor_meta=TensorMeta.from_irnodes(layout),
        )

-        return TritonTemplateCaller(
+        caller = TritonTemplateCaller(
            kernel_hash_name,
            full_input_nodes,
            layout,
@@ -1405,6 +1577,12 @@ def make_kernel_render(out_node):
            allowed_prologue_inps=result.prologue_supported_inputs,
        )

+        # Store the specialized kernels for divisibility checks
+        caller.mod_div16 = result.mod_div16
+        caller.mod_nodiv16 = result.mod_nodiv16
+
+        return caller
+

class ExternKernelChoice:
    def __init__(
@@ -1497,6 +1675,9 @@ def __init__(
        self.allowed_prologue_inps = (
            allowed_prologue_inps if allowed_prologue_inps is not None else OrderedSet()
        )
+        # Store specialized kernels for divisibility checks
+        self.mod_div16 = None
+        self.mod_nodiv16 = None

    def benchmark(self, *args, out):
        assert self.bmreq is not None
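How the stored specializations get consumed is not shown in these hunks. A hedged sketch of a caller-like object that carries the two optional modules and picks one by size divisibility; the pick_module helper is hypothetical and not part of the PR.

# Hypothetical consumer of mod_div16 / mod_nodiv16 on a toy caller object.
import cmath
import math
import statistics
from types import ModuleType
from typing import Optional

class ToyCaller:
    def __init__(self, mod: ModuleType):
        self.mod = mod
        # Store specialized kernels for divisibility checks (may remain None).
        self.mod_div16: Optional[ModuleType] = None
        self.mod_nodiv16: Optional[ModuleType] = None

    def pick_module(self, sizes: tuple) -> ModuleType:
        # Prefer the div16 specialization only when both variants exist and every size
        # is a multiple of 16; otherwise fall back to the default module.
        if self.mod_div16 is not None and self.mod_nodiv16 is not None:
            return self.mod_div16 if all(s % 16 == 0 for s in sizes) else self.mod_nodiv16
        return self.mod

c = ToyCaller(math)
c.mod_div16, c.mod_nodiv16 = cmath, statistics
print(c.pick_module((64, 128)).__name__)  # cmath      -> div16 specialization
print(c.pick_module((65, 128)).__name__)  # statistics -> nodiv16 specialization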