8000 [inductor][triton 3.3] Fix cpp_wrapper w/ TMA in triton 3.3 (#149973) · pytorch/pytorch@a8d0c5c · GitHub
Commit a8d0c5c

davidberard98 authored and pytorchmergebot committed
[inductor][triton 3.3] Fix cpp_wrapper w/ TMA in triton 3.3 (#149973)
Fixes #148938

Context: In triton 3.3, triton kernels expect a global scratch space arg to be passed in. This was addressed in #148051, which fixed most of the AOTI/cpp_wrapper failures; the fix is to inject a (null) global scratch space arg passed as an argument to all kernels.

But in the case of TMA, we need to call a non-triton-generated function: init1DTMADescriptor. The same `generate_args_decl` function used for calling triton kernels (and modified in #148051 to insert a global scratch space) is also used to prepare the arguments to init1DTMADescriptor, so it received an extra global scratch space arg. A null pointer was therefore passed into init1DTMADescriptor, resulting in an illegal memory access (IMA) later on, when the kernel uses the TMA descriptor.

This PR adds an option to `generate_args_decl` to specify whether the callee is a triton kernel (in which case we should add the global scratch space arg) or not (in which case we shouldn't add the extra arg).

Note: this doesn't appear in CI because we don't run these tests on Hopper machines in CI.

Pull Request resolved: #149973
Approved by: https://github.com/drisspg
1 parent 1b373f6 commit a8d0c5c
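The guarded injection described in the commit message can be sketched as follows. This is a simplified, hypothetical model of the codegen logic, not the real Inductor internals: `cpp_global_scratch`, the `global_scratch_N` naming, and the list-of-strings "code buffer" are all stand-ins for illustration.

```python
import itertools

# Hypothetical stand-in for DeviceCodegen.cpp_global_scratch: returns a
# (definition, variable-name) pair, or None when no scratch space is needed.
def cpp_global_scratch(idx: int):
    var = f"global_scratch_{idx}"
    return (f"CUdeviceptr {var} = 0;", var)

arg_var_id = itertools.count()

def generate_args_decl(code: list, call_args: list, is_triton_kernel: bool = True) -> str:
    """Collect arg expressions; for triton kernels only, emit a scratch-space
    declaration and append the extra arg (mirroring the idea of the fix:
    non-triton callees such as init1DTMADescriptor must NOT get the extra arg)."""
    new_args = list(call_args)
    if (
        is_triton_kernel
        and (global_scratch := cpp_global_scratch(next(arg_var_id))) is not None
    ):
        scratch_def, scratch_var = global_scratch
        code.append(scratch_def)  # side effect: declaration line in the wrapper
        new_args.append(f"&{scratch_var}")
    return ", ".join(new_args)

code_lines: list = []
# Triton kernel call: gets the injected scratch arg.
triton_args = generate_args_decl(code_lines, ["var_0", "var_1"])
# TMA descriptor init: must not get the extra arg.
tma_args = generate_args_decl(code_lines, ["var_2"], is_triton_kernel=False)
```

With this shape, only the triton call path emits the declaration and the extra `&global_scratch_0` argument; the TMA path passes its args through unchanged.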

File tree

1 file changed: +29, -4 lines changed


torch/_inductor/codegen/cpp_wrapper_gpu.py

Lines changed: 29 additions & 4 deletions
@@ -326,6 +326,8 @@ def generate_tma_descriptor(self, desc):
             call_args=[self.val_to_arg_str(desc.tensor)],
             arg_types=[desc.tensor.get_dtype()],
             arg_signatures=[None],
+            # these args are passed to initNDTMADescriptor, which is NOT a triton kernel
+            is_triton_kernel=False,
         )
 
         desc_name = desc.name
@@ -344,8 +346,27 @@ def generate_tma_descriptor(self, desc):
         self.writeline(f"{fn}({args});")
 
     def generate_args_decl(
-        self, code: Union[IndentedBuffer, Self], call_args, arg_types, arg_signatures
+        self,
+        code: Union[IndentedBuffer, Self],
+        call_args,
+        arg_types,
+        arg_signatures,
+        is_triton_kernel=True,
     ):
+        """
+        Generates any declarations of args to pass into a kernel call, and then returns the arg names.
+
+        In more detail:
+        * declarations: e.g. this function has a side effect of generating lines like `auto var_0 = ...;`
+        * returns: a string with the list of args, e.g. "var_0, var_1"
+
+        call_args: list of call arguments
+        arg_types: list of argument types
+        arg_signatures: list with signatures of all the args
+        is_triton_kernel: whether these are passed into a triton kernel or not. In particular,
+            calls to triton kernels will have an additional global scratch space
+            arg injected at the front of the arg list.
+        """
         new_args: list[str] = []
 
         # Add more cases for other types as needed
@@ -398,10 +419,14 @@ def process_args(arg, arg_type, arg_signature=None):
             process_args(arg, arg_type, arg_signature)
 
         if (
-            global_scratch := self.device_codegen.cpp_global_scratch(
-                next(self.arg_var_id)
+            is_triton_kernel
+            and (
+                global_scratch := self.device_codegen.cpp_global_scratch(
+                    next(self.arg_var_id)
+                )
             )
-        ) is not None:
+            is not None
+        ):
             global_scratch_def, global_scratch_var = global_scratch
             code.writeline(global_scratch_def)
             new_args.append(f"&{global_scratch_var}")
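One detail worth noting in the rewritten condition: placing `is_triton_kernel` before the `and` means the walrus assignment (and the call to `cpp_global_scratch`) is skipped entirely for non-triton callees, thanks to short-circuit evaluation. The toy sketch below (hypothetical names, not the real Inductor code) demonstrates that short-circuit behavior:

```python
calls: list = []

def make_scratch():
    # Records that scratch-space setup ran, then returns a (def, var) pair,
    # standing in for cpp_global_scratch in the real code.
    calls.append("alloc")
    return ("CUdeviceptr scratch = 0;", "scratch")

def maybe_inject(is_triton_kernel: bool):
    # Mirrors the shape of the rewritten condition: the walrus assignment
    # only runs when is_triton_kernel is True.
    if (
        is_triton_kernel
        and (scratch := make_scratch()) is not None
    ):
        return f"&{scratch[1]}"
    return None

no_arg = maybe_inject(False)   # make_scratch is never called here
extra_arg = maybe_inject(True)
```

This is why the fix also avoids consuming a variable id (`next(self.arg_var_id)`) for non-triton calls, not just the extra argument itself.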
