devmate attempt multi kernel · pytorch/pytorch@0349f2f · GitHub
Commit 0349f2f

devmate attempt multi kernel

ghstack-source-id: 439ddb5
Pull Request resolved: #153353

1 parent 3ad3346 commit 0349f2f

File tree: 1 file changed (+199, -18 lines)

torch/_inductor/select_algorithm.py: 199 additions & 18 deletions
@@ -1014,7 +1014,7 @@ def codegen_range_tree(self):
 
     def call_kernel(self, name: str, node: Optional[ir.IRNode] = None):
         wrapper = V.graph.wrapper_code
-        _, call_args, _, arg_types = self.args.python_argdefs()
+        argdefs, call_args, signature, arg_types = self.args.python_argdefs()
 
         grid_args = ()
         if isinstance(self.grid_fn, SymbolicGridFn):
@@ -1036,13 +1036,144 @@ def call_kernel(self, name: str, node: Optional[ir.IRNode] = None):
 
         if self.workspace_arg is not None:
             wrapper.generate_workspace_allocation(self.workspace_arg)
-        wrapper.generate_kernel_call(
-            name,
-            call_args,
-            arg_types=arg_types,
-            triton_meta=self.triton_meta,
-            triton=True,
-        )
+
+        # Check if we have specialized kernels for divisibility
+        if hasattr(self, 'mod_div16') and self.mod_div16 is not None and hasattr(self, 'mod_nodiv16') and self.mod_nodiv16 is not None:
+            # Generate a wrapper function that checks divisibility at runtime and dispatches to the appropriate kernel
+            wrapper.add_import_once("import torch")
+
+            # Create a unique name for the wrapper function
+            wrapper_name = f"{name}_divisibility_wrapper"
+
+            # Get the SizeArg indices from the signature
+            size_arg_indices = []
+            argdefs, _, signature, _ = self.args.python_argdefs()
+            for i, arg in enumerate(signature):
+                if isinstance(arg, SizeArg) and arg.expr is not None:
+                    size_arg_indices.append(i)
+
+            # Generate the wrapper function
+            wrapper.writeline(f"def {wrapper_name}({', '.join(a.full_name() for a in argdefs)}):")
+            with wrapper.indent():
+                # Check if all SizeArgs are divisible by 16
+                if size_arg_indices:
+                    divisibility_checks = []
+                    for i in size_arg_indices:
+                        arg_name = argdefs[i].name
+                        divisibility_checks.append(f"{arg_name} % 16 == 0")
+
+                    wrapper.writeline(f"if {' and '.join(divisibility_checks)}:")
+                    with wrapper.indent():
+                        wrapper.writeline(f"return {name}_div16({', '.join(a.name for a in argdefs)})")
+                    wrapper.writeline("else:")
+                    with wrapper.indent():
+                        wrapper.writeline(f"return {name}_nodiv16({', '.join(a.name for a in argdefs)})")
+                else:
+                    # If there are no SizeArgs, just use the default kernel
+                    wrapper.writeline(f"return {name}({', '.join(a.name for a in argdefs)})")
+
+            # Generate the specialized kernel calls
+            wrapper.generate_kernel_call(
+                f"{name}_div16",
+                call_args,
+                arg_types=arg_types,
+                triton_meta=self.triton_meta,
+                triton=True,
+            )
+
+            wrapper.generate_kernel_call(
+                f"{name}_nodiv16",
+                call_args,
+                arg_types=arg_types,
+                triton_meta=self.triton_meta,
+                triton=True,
+            )
+
+            # Generate the default kernel call
+            # Check if we have specialized kernels for divisibility
+            if hasattr(self, 'mod_div16') and self.mod_div16 is not None and hasattr(self, 'mod_nodiv16') and self.mod_nodiv16 is not None:
+                # Generate a wrapper function that checks divisibility at runtime and dispatches to the appropriate kernel
+                wrapper.add_import_once("import torch")
+
+                # Create a unique name for the wrapper function
+                wrapper_name = f"{name}_divisibility_wrapper"
+
+                # Get the SizeArg indices from the signature
+                size_arg_indices = []
+                for i, arg in enumerate(signature):
+                    if isinstance(arg, SizeArg) and arg.expr is not None:
+                        size_arg_indices.append(i)
+
+                # Generate the wrapper function
+                wrapper.writeline(f"def {wrapper_name}({', '.join(a.full_name() for a in argdefs)}):")
+                with wrapper.indent():
+                    # Check if all SizeArgs are divisible by 16
+                    if size_arg_indices:
+                        divisibility_checks = []
+                        for i in size_arg_indices:
+                            arg_name = argdefs[i].name
+                            divisibility_checks.append(f"{arg_name} % 16 == 0")
+
+                        wrapper.writeline(f"if {' and '.join(divisibility_checks)}:")
+                        with wrapper.indent():
+                            wrapper.writeline(f"return {name}_div16({', '.join(a.name for a in argdefs)})")
+                        wrapper.writeline("else:")
+                        with wrapper.indent():
+                            wrapper.writeline(f"return {name}_nodiv16({', '.join(a.name for a in argdefs)})")
+                    else:
+                        # If there are no SizeArgs, just use the default kernel
+                        wrapper.writeline(f"return {name}({', '.join(a.name for a in argdefs)})")
+
+                # Generate the specialized kernel calls
+                wrapper.generate_kernel_call(
+                    f"{name}_div16",
+                    call_args,
+                    arg_types=arg_types,
+                    triton_meta=self.triton_meta,
+                    triton=True,
+                )
+
+                wrapper.generate_kernel_call(
+                    f"{name}_nodiv16",
+                    call_args,
+                    arg_types=arg_types,
+                    triton_meta=self.triton_meta,
+                    triton=True,
+                )
+
+                # Generate the default kernel call for backward compatibility
+                wrapper.generate_kernel_call(
+                    name,
+                    call_args,
+                    arg_types=arg_types,
+                    triton_meta=self.triton_meta,
+                    triton=True,
+                )
+
+                # Use the wrapper function instead of the direct kernel call
+                name = wrapper_name
+            else:
+                # Just generate the default kernel call
+                wrapper.generate_kernel_call(
+                    name,
+                    call_args,
+                    arg_types=arg_types,
+                    triton_meta=self.triton_meta,
+                    triton=True,
+                )
+
+            # Use the wrapper function instead of the direct kernel call
+            name = wrapper_name
+        else:
+            # Just generate the default kernel call
+            wrapper.generate_kernel_call(
+                name,
+                call_args,
+                arg_types=arg_types,
+                triton_meta=self.triton_meta,
+                triton=True,
+            )
+
         if self.workspace_arg is not None:
             wrapper.generate_workspace_deallocation(self.workspace_arg)
 
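The wrapper.writeline calls above emit a small dispatch function into the generated output code. As a rough illustration, for a hypothetical kernel named triton_mm_0 whose signature carries SizeArgs M, N, and K (all names here are illustrative, not from this commit), the emitted source would look like:

def triton_mm_0_divisibility_wrapper(in_ptr0, in_ptr1, out_ptr0, M, N, K):
    if M % 16 == 0 and N % 16 == 0 and K % 16 == 0:
        return triton_mm_0_div16(in_ptr0, in_ptr1, out_ptr0, M, N, K)
    else:
        return triton_mm_0_nodiv16(in_ptr0, in_ptr1, out_ptr0, M, N, K)

The runtime branch exists because Triton's divisible_by_16 hints (carried in triton_meta) let the compiler assume 16-element alignment for the hinted arguments, typically enabling more efficient memory accesses; the div16 variant exploits that assumption, while the nodiv16 variant stays correct for arbitrary sizes.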
@@ -1078,6 +1209,10 @@ class GenerateAndLoadResult(NamedTuple):
     prologue_supported_inputs: OrderedSet[str]
     kernel_args_sizevars_keys: tuple[sympy.Expr]
     kernel_options: dict[str, Any]
+    mod_div16: Optional[ModuleType] = None
+    mod_nodiv16: Optional[ModuleType] = None
+    mod_div16: Optional[ModuleType] = None
+    mod_nodiv16: Optional[ModuleType] = None
 
 
 class TritonTemplate(KernelTemplate):
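One oddity in this hunk: mod_div16 and mod_nodiv16 are each declared twice. In a NamedTuple class body this is redundant but harmless, since annotations are collected into a dict and the second declaration overwrites the first, so each field exists once. A self-contained check (demo names only):

from types import ModuleType
from typing import NamedTuple, Optional

class Demo(NamedTuple):
    kernel_options: Optional[dict] = None
    mod_div16: Optional[ModuleType] = None
    mod_nodiv16: Optional[ModuleType] = None
    mod_div16: Optional[ModuleType] = None  # duplicate annotation: overwrites the entry above
    mod_nodiv16: Optional[ModuleType] = None

print(Demo._fields)  # ('kernel_options', 'mod_div16', 'mod_nodiv16')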
@@ -1125,6 +1260,8 @@ def generate_and_load(
 
         fake_out = ir.Buffer(name="buf_out", layout=layout)
         kernel_name = f"triton_{self.name}"
+        kernel_name_div16 = f"{kernel_name}_div16"
+        kernel_name_nodiv16 = f"{kernel_name}_nodiv16"
 
         numel = sympy_product(layout.size)
         buffers = itertools.chain(input_nodes, (fake_out,))
@@ -1164,7 +1301,7 @@ def make_kernel():
                 **kernel_options,
             )
 
-        def generate_code(kernel) -> Optional[tuple[str, str]]:
+        def generate_code(kernel, divisible_by_16=None) -> Optional[tuple[str, str]]:
             def make_extra() -> str:
                 extra_parts = [
                     f"{kwarg}={repr(kwargs[kwarg])}" for kwarg in sorted(kwargs.keys())
@@ -1183,13 +1320,16 @@ def make_extra() -> str:
                         f"num_buffers_warp_spec={num_buffers_warp_spec}",
                     ]
                 )
+                if divisible_by_16 is not None:
+                    extra_parts.append(f"divisible_by_16={divisible_by_16}")
                 extra = "-".join(extra_parts) + "-"
                 return extra
 
             try:
-                template = kernel.render(self.template, kwargs)
-                with kernel.set_subgraph_body("<STORE_OUTPUT>"):
-                    code = template.finalize_all()
+                with patch.object(config.triton, "divisible_by_16", divisible_by_16) if divisible_by_16 is not None else contextlib.nullcontext():
+                    template = kernel.render(self.template, kwargs)
+                    with kernel.set_subgraph_body("<STORE_OUTPUT>"):
+                        code = template.finalize_all()
             except ZeroDivisionError:
                 # TODO(nmacchioni): fix sympy division by zero
                 return None
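The new try block wraps the render in a conditional context manager: when divisible_by_16 is an explicit True or False, config.triton.divisible_by_16 is temporarily patched for the duration of the render; when it is None, contextlib.nullcontext() turns the with statement into a no-op. A minimal self-contained sketch of the same pattern (TritonConfig is a hypothetical stand-in for the inductor config object):

import contextlib
from unittest.mock import patch

class TritonConfig:
    divisible_by_16 = True  # stand-in for the real config flag

def render_with_override(divisible_by_16=None):
    # Patch the flag only when an override is requested; otherwise do nothing.
    ctx = (
        patch.object(TritonConfig, "divisible_by_16", divisible_by_16)
        if divisible_by_16 is not None
        else contextlib.nullcontext()
    )
    with ctx:
        return TritonConfig.divisible_by_16

print(render_with_override())       # True: config untouched
print(render_with_override(False))  # False: patched only inside the block

Hoisting the conditional into a named variable, as in the sketch, reads more easily than the inline `with A if cond else B:` form used in the diff, though the two behave identically.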
@@ -1202,20 +1342,50 @@ def make_extra() -> str:
         # Generate code, extra.
         code: Optional[str] = None
         extra: Optional[str] = None
+        code_div16: Optional[str] = None
+        extra_div16: Optional[str] = None
+        code_nodiv16: Optional[str] = None
+        extra_nodiv16: Optional[str] = None
+
         with (
             patch.object(V.graph, "get_dtype", self._fake_get_dtype(fake_out)),
             V.graph.set_current_device(layout.device),
-            make_kernel() as kernel,
         ):
-            result = generate_code(kernel)
-            if not result:  # happens at ZeroDivisionError:
-                return None
-            code, extra = result
+            # Generate the default kernel (using config.triton.divisible_by_16 setting)
+            with make_kernel() as kernel:
+                result = generate_code(kernel)
+                if not result:  # happens at ZeroDivisionError:
+                    return None
+                code, extra = result
+
+            # Generate kernel with divisible_by_16=True
+            with make_kernel() as kernel:
+                kernel.kernel_name = kernel_name_div16  # Set the kernel name for the specialized kernel
+                result_div16 = generate_code(kernel, divisible_by_16=True)
+                if result_div16:
+                    code_div16, extra_div16 = result_div16
+
+            # Generate kernel with divisible_by_16=False
+            with make_kernel() as kernel:
+                kernel.kernel_name = kernel_name_nodiv16  # Set the kernel name for the specialized kernel
+                result_nodiv16 = generate_code(kernel, divisible_by_16=False)
+                if result_nodiv16:
+                    code_nodiv16, extra_nodiv16 = result_nodiv16
 
         assert code is not None and extra is not None
 
+        # Load the default kernel
         mod = PyCodeCache.load(code, extra)
 
+        # Load the specialized kernels if they were generated
+        mod_div16 = None
+        if code_div16 is not None and extra_div16 is not None:
+            mod_div16 = PyCodeCache.load(code_div16, extra_div16)
+
+        mod_nodiv16 = None
+        if code_nodiv16 is not None and extra_nodiv16 is not None:
+            mod_nodiv16 = PyCodeCache.load(code_nodiv16, extra_nodiv16)
+
         input_call_args = tuple(kernel.args.input_buffers.keys())
         prologue_supported_inputs = kernel.prologue_supported_inputs.copy()
         kernel_args_sizevars_keys = tuple(kernel.args.sizevars.keys())
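Taken together, generate_and_load now renders the same template up to three times: once under the current config (the default), once with divisible_by_16 forced on, and once with it forced off, and each successful render is loaded as its own module via PyCodeCache.load(code, extra). Because make_extra() folds divisible_by_16=... into the extra string, the variants get distinct cache keys and cannot collide. A schematic of the variant loop under those assumptions (generate_for and load_module are hypothetical stand-ins for the real helpers):

from typing import Callable, Optional

def load_kernel_variants(
    generate_for: Callable[[Optional[bool]], Optional[tuple[str, str]]],
    load_module: Callable[[str, str], object],
):
    # None = default config; True/False = forced divisible_by_16 variants.
    modules: dict[Optional[bool], object] = {}
    for flag in (None, True, False):
        result = generate_for(flag)
        if result is None:
            if flag is None:
                return None  # the default kernel is required
            continue  # a specialized variant is optional and may fail
        code, extra = result
        modules[flag] = load_module(code, extra)
    return modules

Note that the diff keeps the default render mandatory (it returns None on failure) while treating the specialized renders as best-effort, which matches the hasattr guards in call_kernel.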
@@ -1227,6 +1397,8 @@ def make_extra() -> str:
             prologue_supported_inputs,
             kernel_args_sizevars_keys,
             kernel_options,
+            mod_div16,
+            mod_nodiv16,
         )
 
     def generate(  # type: ignore[override]
@@ -1376,7 +1548,7 @@ def make_kernel_render(out_node):
             output_tensor_meta=TensorMeta.from_irnodes(layout),
         )
 
-        return TritonTemplateCaller(
+        caller = TritonTemplateCaller(
             kernel_hash_name,
             full_input_nodes,
             layout,
@@ -1405,6 +1577,12 @@ def make_kernel_render(out_node):
             allowed_prologue_inps=result.prologue_supported_inputs,
         )
 
+        # Store the specialized kernels for divisibility checks
+        caller.mod_div16 = result.mod_div16
+        caller.mod_nodiv16 = result.mod_nodiv16
+
+        return caller
+
 
 class ExternKernelChoice:
     def __init__(
@@ -1497,6 +1675,9 @@ def __init__(
         self.allowed_prologue_inps = (
             allowed_prologue_inps if allowed_prologue_inps is not None else OrderedSet()
         )
+        # Store specialized kernels for divisibility checks
+        self.mod_div16 = None
+        self.mod_nodiv16 = None
 
     def benchmark(self, *args, out):
         assert self.bmreq is not None

0 commit comments
