devmate attempt multi kernel by bobrenjc93 · Pull Request #153353 · pytorch/pytorch · GitHub
devmate attempt multi kernel #153353


Status: Closed · wants to merge 1 commit
devmate attempt multi kernel
[ghstack-poisoned]
bobrenjc93 committed May 12, 2025
commit a93d42bd5d2a2b3735cee6ab6ec201c28ad0f956
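In outline, the change makes TritonTemplate.generate_and_load compile two extra specializations of each template kernel, one rendered with config.triton.divisible_by_16 forced to True and one with it forced to False, and makes call_kernel emit a small Python wrapper that checks each size argument against % 16 at runtime and dispatches to the matching variant.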
217 changes: 199 additions & 18 deletions torch/_inductor/select_algorithm.py
@@ -1014,7 +1014,7 @@

def call_kernel(self, name: str, node: Optional[ir.IRNode] = None):
wrapper = V.graph.wrapper_code
_, call_args, _, arg_types = self.args.python_argdefs()
argdefs, call_args, signature, arg_types = self.args.python_argdefs()

grid_args = ()
if isinstance(self.grid_fn, SymbolicGridFn):
@@ -1036,13 +1036,144 @@

if self.workspace_arg is not None:
wrapper.generate_workspace_allocation(self.workspace_arg)
wrapper.generate_kernel_call(
name,
call_args,
arg_types=arg_types,
triton_meta=self.triton_meta,
triton=True,
)

# Check if we have specialized kernels for divisibility
if hasattr(self, 'mod_div16') and self.mod_div16 is not None and hasattr(self, 'mod_nodiv16') and self.mod_nodiv16 is not None:
# Generate a wrapper function that checks divisibility at runtime and dispatches to the appropriate kernel
wrapper.add_import_once("import torch")

# Create a unique name for the wrapper function
wrapper_name = f"{name}_divisibility_wrapper"

# Get the SizeArg indices from the signature
size_arg_indices = []
argdefs, _, signature, _ = self.args.python_argdefs()
for i, arg in enumerate(signature):
if isinstance(arg, SizeArg) and arg.expr is not None:

Check failure on line 1052 in torch/_inductor/select_algorithm.py (GitHub Actions / lintrunner-noclang / linux-job): MYPY [name-defined]: Name "SizeArg" is not defined
size_arg_indices.append(i)

# Generate the wrapper function
wrapper.writeline(f"def {wrapper_name}({', '.join(a.full_name() for a in argdefs)}):")
with wrapper.indent():

Check failure on line 1057 in torch/_inductor/select_algorithm.py (GitHub Actions / lintrunner-noclang / linux-job): MYPY [attr-defined]: "PythonWrapperCodegen" has no attribute "indent"
# Check if all SizeArgs are divisible by 16
if size_arg_indices:
divisibility_checks = []
for i in size_arg_indices:
arg_name = argdefs[i].name
divisibility_checks.append(f"{arg_name} % 16 == 0")

wrapper.writeline(f"if {' and '.join(divisibility_checks)}:")
with wrapper.indent():

Check failure on line 1066 in torch/_inductor/select_algorithm.py (GitHub Actions / lintrunner-noclang / linux-job): MYPY [attr-defined]: "PythonWrapperCodegen" has no attribute "indent"
wrapper.writeline(f"return {name}_div16({', '.join(a.name for a in argdefs)})")
wrapper.writeline("else:")
with wrapper.indent():

Check failure on line 1069 in torch/_inductor/select_algorithm.py (GitHub Actions / lintrunner-noclang / linux-job): MYPY [attr-defined]: "PythonWrapperCodegen" has no attribute "indent"
wrapper.writeline(f"return {name}_nodiv16({', '.join(a.name for a in argdefs)})")
else:
# If there are no SizeArgs, just use the default kernel
wrapper.writeline(f"return {name}({', '.join(a.name for a in argdefs)})")

# Generate the specialized kernel calls
wrapper.generate_kernel_call(
f"{name}_div16",
call_args,
arg_types=arg_types,
triton_meta=self.triton_meta,
triton=True,
)

wrapper.generate_kernel_call(
f"{name}_nodiv16",
call_args,
arg_types=arg_types,
triton_meta=self.triton_meta,
triton=True,
)

# Generate the default kernel call
# Check if we have specialized kernels for divisibility
if hasattr(self, 'mod_div16') and self.mod_div16 is not None and hasattr(self, 'mod_nodiv16') and self.mod_nodiv16 is not None:
# Generate a wrapper function that checks divisibility at runtime and dispatches to the appropriate kernel
wrapper.add_import_once("import torch")

# Create a unique name for the wrapper function
wrapper_name = f"{name}_divisibility_wrapper"

# Get the SizeArg indices from the signature
size_arg_indices = []
for i, arg in enumerate(signature):
if isinstance(arg, SizeArg) and arg.expr is not None:

Check failure on line 1104 in torch/_inductor/select_algorithm.py (GitHub Actions / lintrunner-noclang / linux-job): MYPY [name-defined]: Name "SizeArg" is not defined
size_arg_indices.append(i)

# Generate the wrapper function
wrapper.writeline(f"def {wrapper_name}({', '.join(a.full_name() for a in argdefs)}):")
with wrapper.indent():

Check failure on line 1109 in torch/_inductor/select_algorithm.py (GitHub Actions / lintrunner-noclang / linux-job): MYPY [attr-defined]: "PythonWrapperCodegen" has no attribute "indent"
# Check if all SizeArgs are divisible by 16
if size_arg_indices:
divisibility_checks = []
for i in size_arg_indices:
arg_name = argdefs[i].name
divisibility_checks.append(f"{arg_name} % 16 == 0")

wrapper.writeline(f"if {' and '.join(divisibility_checks)}:")
with wrapper.indent():

Check failure on line 1118 in torch/_inductor/select_algorithm.py (GitHub Actions / lintrunner-noclang / linux-job): MYPY [attr-defined]: "PythonWrapperCodegen" has no attribute "indent"
wrapper.writeline(f"return {name}_div16({', '.join(a.name for a in argdefs)})")
wrapper.writeline("else:")
with wrapper.indent():

Check failure on line 1121 in torch/_inductor/select_algorithm.py (GitHub Actions / lintrunner-noclang / linux-job): MYPY [attr-defined]: "PythonWrapperCodegen" has no attribute "indent"
wrapper.writeline(f"return {name}_nodiv16({', '.join(a.name for a in argdefs)})")
else:
# If there are no SizeArgs, just use the default kernel
wrapper.writeline(f"return {name}({', '.join(a.name for a in argdefs)})")

# Generate the specialized kernel calls
wrapper.generate_kernel_call(
f"{name}_div16",
call_args,
arg_types=arg_types,
triton_meta=self.triton_meta,
triton=True,
)

wrapper.generate_kernel_call(
f"{name}_nodiv16",
call_args,
arg_types=arg_types,
triton_meta=self.triton_meta,
triton=True,
)

# Generate the default kernel call for backward compatibility
wrapper.generate_kernel_call(
name,
call_args,
arg_types=arg_types,
triton_meta=self.triton_meta,
triton=True,
)

# Use the wrapper function instead of the direct kernel call
name = wrapper_name
else:
# Just generate the default kernel call
wrapper.generate_kernel_call(
name,
call_args,
arg_types=arg_types,
triton_meta=self.triton_meta,
triton=True,
)

# Use the wrapper function instead of the direct kernel call
name = wrapper_name
else:
# Just generate the default kernel call
wrapper.generate_kernel_call(
name,
call_args,
arg_types=arg_types,
triton_meta=self.triton_meta,
triton=True,
)

if self.workspace_arg is not None:
wrapper.generate_workspace_deallocation(self.workspace_arg)
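For context, the dispatch wrapper that the writeline calls above generate looks roughly like the sketch below; the kernel name triton_mm and its argument list are hypothetical stand-ins, and the two kernel functions are placeholders for the compiled Triton launches.

def triton_mm_div16(a, b, m, n, k):
    return "div16"  # placeholder for launching the divisible-by-16 kernel

def triton_mm_nodiv16(a, b, m, n, k):
    return "nodiv16"  # placeholder for launching the fallback kernel

def triton_mm_divisibility_wrapper(a, b, m, n, k):
    # Every SizeArg must be a multiple of 16 to take the specialized path.
    if m % 16 == 0 and n % 16 == 0 and k % 16 == 0:
        return triton_mm_div16(a, b, m, n, k)
    else:
        return triton_mm_nodiv16(a, b, m, n, k)

print(triton_mm_divisibility_wrapper(None, None, 64, 128, 32))  # div16
print(triton_mm_divisibility_wrapper(None, None, 65, 128, 32))  # nodiv16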

@@ -1067,7 +1198,7 @@
return None


class GenerateAndLoadResult(NamedTuple):

Check failure on line 1201 in torch/_inductor/select_algorithm.py (GitHub Actions / lintrunner-noclang / linux-job): MYPY [no-redef]: Name "mod_nodiv16" already defined (possibly by an import)
Check failure on line 1201 in torch/_inductor/select_algorithm.py (GitHub Actions / lintrunner-noclang / linux-job): MYPY [no-redef]: Name "mod_div16" already defined (possibly by an import)
"""
Return type of TritonTemplate.generate_and_load.
"""
@@ -1078,6 +1209,10 @@
prologue_supported_inputs: OrderedSet[str]
kernel_args_sizevars_keys: tuple[sympy.Expr]
kernel_options: dict[str, Any]
mod_div16: Optional[ModuleType] = None
mod_nodiv16: Optional[ModuleType] = None
mod_div16: Optional[ModuleType] = None
mod_nodiv16: Optional[ModuleType] = None
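As the two no-redef failures above flag, the PR declares mod_div16 and mod_nodiv16 twice. A sketch with each field declared once, defaulted to None so construction sites that omit them keep working (field list abbreviated; GenerateAndLoadResultSketch is a hypothetical name):

from types import ModuleType
from typing import NamedTuple, Optional

class GenerateAndLoadResultSketch(NamedTuple):
    mod: ModuleType  # the default compiled kernel module
    # ... remaining GenerateAndLoadResult fields elided ...
    mod_div16: Optional[ModuleType] = None    # declared exactly once
    mod_nodiv16: Optional[ModuleType] = None  # declared exactly once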


class TritonTemplate(KernelTemplate):
@@ -1125,6 +1260,8 @@

fake_out = ir.Buffer(name="buf_out", layout=layout)
kernel_name = f"triton_{self.name}"
kernel_name_div16 = f"{kernel_name}_div16"
kernel_name_nodiv16 = f"{kernel_name}_nodiv16"

numel = sympy_product(layout.size)
buffers = itertools.chain(input_nodes, (fake_out,))
@@ -1164,7 +1301,7 @@
**kernel_options,
)

def generate_code(kernel) -> Optional[tuple[str, str]]:
def generate_code(kernel, divisible_by_16=None) -> Optional[tuple[str, str]]:
def make_extra() -> str:
extra_parts = [
f"{kwarg}={repr(kwargs[kwarg])}" for kwarg in sorted(kwargs.keys())
@@ -1183,13 +1320,16 @@
f"num_buffers_warp_spec={num_buffers_warp_spec}",
]
)
if divisible_by_16 is not None:
extra_parts.append(f"divisible_by_16={divisible_by_16}")
extra = "-".join(extra_parts) + "-"
return extra

try:
template = kernel.render(self.template, kwargs)
with kernel.set_subgraph_body("<STORE_OUTPUT>"):
code = template.finalize_all()
with patch.object(config.triton, "divisible_by_16", divisible_by_16) if divisible_by_16 is not None else contextlib.nullcontext():
template = kernel.render(self.template, kwargs)
with kernel.set_subgraph_body("<STORE_OUTPUT>"):
code = template.finalize_all()
except ZeroDivisionError:
# TODO(nmacchioni): fix sympy division by zero
return None
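The conditional context manager in the try block above is a standard idiom: patch the config only when a specialization is requested, otherwise enter a no-op context. A self-contained sketch, using a SimpleNamespace as a stand-in for config.triton:

import contextlib
from types import SimpleNamespace
from unittest.mock import patch

triton_cfg = SimpleNamespace(divisible_by_16=True)  # stand-in for config.triton

def render(divisible_by_16=None):
    ctx = (
        patch.object(triton_cfg, "divisible_by_16", divisible_by_16)
        if divisible_by_16 is not None
        else contextlib.nullcontext()
    )
    with ctx:
        return triton_cfg.divisible_by_16  # the render sees the override

print(render())       # True: default config untouched
print(render(False))  # False: patched for the duration of the call
print(render())       # True: restored on exit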
@@ -1202,20 +1342,50 @@
# Generate code, extra.
code: Optional[str] = None
extra: Optional[str] = None
code_div16: Optional[str] = None
extra_div16: Optional[str] = None
code_nodiv16: Optional[str] = None
extra_nodiv16: Optional[str] = None

with (
patch.object(V.graph, "get_dtype", self._fake_get_dtype(fake_out)),
V.graph.set_current_device(layout.device),
make_kernel() as kernel,
):
result = generate_code(kernel)
if not result: # happens at ZeroDivisionError:
return None
code, extra = result
# Generate the default kernel (using config.triton.divisible_by_16 setting)
with make_kernel() as kernel:
result = generate_code(kernel)
if not result: # happens at ZeroDivisionError:
return None
code, extra = result

# Generate kernel with divisible_by_16=True
with make_kernel() as kernel:
kernel.kernel_name = kernel_name_div16 # Set the kernel name for the specialized kernel
result_div16 = generate_code(kernel, divisible_by_16=True)
if result_div16:
code_div16, extra_div16 = result_div16

# Generate kernel with divisible_by_16=False
with make_kernel() as kernel:
kernel.kernel_name = kernel_name_nodiv16 # Set the kernel name for the specialized kernel
result_nodiv16 = generate_code(kernel, divisible_by_16=False)
if result_nodiv16:
code_nodiv16, extra_nodiv16 = result_nodiv16

assert code is not None and extra is not None

# Load the default kernel
mod = PyCodeCache.load(code, extra)

# Load the specialized kernels if they were generated
mod_div16 = None
if code_div16 is not None and extra_div16 is not None:
mod_div16 = PyCodeCache.load(code_div16, extra_div16)

mod_nodiv16 = None
if code_nodiv16 is not None and extra_nodiv16 is not None:
mod_nodiv16 = PyCodeCache.load(code_nodiv16, extra_nodiv16)
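The loading sequence above is best-effort: the default module is mandatory, while either specialization may be missing if its generation failed. A condensed sketch of the pattern, where build stands in for generate_code plus PyCodeCache.load:

from typing import Callable, Optional

def load_variants(build: Callable[[Optional[bool]], Optional[object]]):
    default = build(None)
    assert default is not None, "the default kernel must always compile"
    # Specializations are optional; if one failed to generate, the runtime
    # dispatch wrapper simply will not be emitted for this kernel.
    div16 = build(True)
    nodiv16 = build(False)
    return default, div16, nodiv16

# Toy builder in which the div16 specialization "fails":
print(load_variants(lambda flag: None if flag is True else f"mod[{flag}]"))
# ('mod[None]', None, 'mod[False]')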

input_call_args = tuple(kernel.args.input_buffers.keys())
prologue_supported_inputs = kernel.prologue_supported_inputs.copy()
kernel_args_sizevars_keys = tuple(kernel.args.sizevars.keys())
@@ -1227,6 +1397,8 @@
prologue_supported_inputs,
kernel_args_sizevars_keys,
kernel_options,
mod_div16,
mod_nodiv16,
)

def generate( # type: ignore[override]
@@ -1376,7 +1548,7 @@
output_tensor_meta=TensorMeta.from_irnodes(layout),
)

return TritonTemplateCaller(
caller = TritonTemplateCaller(
kernel_hash_name,
full_input_nodes,
layout,
@@ -1405,6 +1577,12 @@
allowed_prologue_inps=result.prologue_supported_inputs,
)

# Store the specialized kernels for divisibility checks
caller.mod_div16 = result.mod_div16
caller.mod_nodiv16 = result.mod_nodiv16

return caller


class ExternKernelChoice:
def __init__(
@@ -1497,6 +1675,9 @@
self.allowed_prologue_inps = (
allowed_prologue_inps if allowed_prologue_inps is not None else OrderedSet()
)
# Store specialized kernels for divisibility checks
self.mod_div16 = None
self.mod_nodiv16 = None

def benchmark(self, *args, out):
assert self.bmreq is not None