[triton 3.3] Fix aoti cpp wrapper remaining 5 issue. (following #148051) (#148117) · pytorch/pytorch@4e160d5

Commit 4e160d5

YUNQIUGUO authored and pytorchmergebot committed
[triton 3.3] Fix aoti cpp wrapper remaining 5 issue. (following #148051) (#148117)
Summary: Fix the following 5 failing tests on A100:

- test_foreach_cpp_wrapper_cuda_gpu_wrapper
- test_enable_dynamic_shapes_cpp_wrapper_cuda_gpu_wrapper
- test_dynamic_shapes_persistent_reduction_mixed_x_dim_cuda_gpu_wrapper
- test_enable_dynamic_shapes_cpp_wrapper_cuda_dynamic_shapes_gpu_wrapper
- test_dynamic_shapes_persistent_reduction_mixed_x_dim_cuda_dynamic_shapes_gpu_wrapper

Test Plan (OSS):

```
TORCHINDUCTOR_COMPILE_THREADS=1 TORCHINDUCTOR_MAX_AUTOTUNE_GEMM_BACKENDS=TRITON TORCH_LOGS="+inductor, output_code" TORCHINDUCTOR_FORCE_DISABLE_CACHES=1 CPLUS_INCLUDE_PATH=/usr/local/cuda-12.6/include:$CPLUS_INCLUDE_PATH python test/inductor/test_gpu_cpp_wrapper.py -k test_foreach_cpp_wrapper_cuda_gpu_wrapper
```

@diff-train-skip-merge

Pull Request resolved: #148117
Approved by: https://github.com/davidberard98, https://github.com/chenyang78
1 parent ea12fc8 · commit 4e160d5
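For context: with the Triton 3.3 series this commit targets, Inductor's `triton_meta["signature"]` is an ordered mapping from argument name to type string in which compile-time constants are tagged `"constexpr"`; such arguments are baked into the compiled CUBIN and must not be passed at launch, which is what the filtering in the diff below implements. A hypothetical example of such a dict (names and values are illustrative, not taken from the commit):

```python
# Hypothetical triton_meta for a simple elementwise kernel (illustrative
# values; the real dict is emitted by Inductor's Python codegen).
triton_meta = {
    "signature": {
        "in_ptr0": "*fp32",     # pointer arg, passed at launch
        "out_ptr0": "*fp32",    # pointer arg, passed at launch
        "xnumel": "i32",        # scalar arg, passed at launch
        "XBLOCK": "constexpr",  # compile-time constant: not passed at launch
    },
    "configs": [...],  # AttrsDescriptor-style config objects
}
```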

1 file changed (+25 −25 lines)


torch/_inductor/codegen/cpp_wrapper_gpu.py

Lines changed: 25 additions & 25 deletions
```diff
@@ -571,31 +571,31 @@ def generate_kernel_call(
             device_index, call_args
         )
         kernel_var_name = self.generate_load_kernel_once(kernel_name, V.graph)
-        if triton_version_uses_attrs_dict():
-            signature = triton_meta["signature"]
-            arg_signatures = [
-                val for val in signature.values() if val != "constexpr"
-            ]
-            call_args = [
-                call_arg
-                for call_arg, arg_name in zip(call_args, signature)
-                if signature[arg_name] != "constexpr"
-            ]
-            arg_types = [
-                arg_type
-                for arg_type, arg_name in zip(arg_types, signature)
-                if signature[arg_name] != "constexpr"
-            ]
-        else:
-            # args with value 1 are added into equal_to_1 and constants
-            # in triton_meta (in the Python codegen) which makes them
-            # inlined in the PTX and compiled CUBIN
-            arg_signatures = []
-            if (
-                triton_meta is not None
-                and triton_meta.get("configs")
-                and triton_meta.get("signature")
-            ):
+        arg_signatures = []
+        if (
+            triton_meta is not None
+            and triton_meta.get("configs")
+            and triton_meta.get("signature")
+        ):
+            if triton_version_uses_attrs_dict():
+                signatures = triton_meta["signature"]
+                arg_signatures = [
+                    val for val in signatures.values() if val != "constexpr"
+                ]
+                call_args = [
+                    call_arg
+                    for call_arg, arg_name in zip(call_args, signatures)
+                    if signatures[arg_name] != "constexpr"
+                ]
+                arg_types = [
+                    arg_type
+                    for arg_type, arg_name in zip(arg_types, signatures)
+                    if signatures[arg_name] != "constexpr"
+                ]
+            else:
+                # args with value 1 are added into equal_to_1 and constants
+                # in triton_meta (in the Python codegen) which makes them
+                # inlined in the PTX and compiled CUBIN
                 equal_to_1 = triton_meta["configs"][0].equal_to_1
                 call_args = [
                     arg for i, arg in enumerate(call_args) if i not in equal_to_1
```
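What changed: previously the `triton_version_uses_attrs_dict()` branch ran unconditionally, so the constexpr filtering would dereference `triton_meta["signature"]` even when `triton_meta` was `None` or lacked `signature`/`configs`; the patch moves that branch inside the existing guard so both paths are protected. Below is a minimal, self-contained sketch of the resulting control flow, assuming the `triton_meta` layout shown earlier. `filter_constexpr_args` is a hypothetical name for illustration only (in PyTorch the logic is inline in `generate_kernel_call`), and the boolean parameter stands in for the real `triton_version_uses_attrs_dict()` helper.

```python
from types import SimpleNamespace
from typing import Any, Optional


def filter_constexpr_args(
    call_args: list[str],
    arg_types: list[Any],
    triton_meta: Optional[dict],
    uses_attrs_dict: bool,  # stand-in for triton_version_uses_attrs_dict()
) -> tuple[list[str], list[Any], list[str]]:
    """Drop arguments that Triton bakes into the compiled kernel, mirroring
    the patched hunk. Hypothetical helper, not the PyTorch API."""
    arg_signatures: list[str] = []
    # The guard the patch puts around BOTH paths: without configs and a
    # signature there is nothing to filter against, so pass args through.
    if (
        triton_meta is not None
        and triton_meta.get("configs")
        and triton_meta.get("signature")
    ):
        if uses_attrs_dict:
            # Newer Triton: constexpr args are tagged in the signature dict.
            signatures = triton_meta["signature"]
            arg_signatures = [v for v in signatures.values() if v != "constexpr"]
            call_args = [
                arg
                for arg, name in zip(call_args, signatures)
                if signatures[name] != "constexpr"
            ]
            arg_types = [
                typ
                for typ, name in zip(arg_types, signatures)
                if signatures[name] != "constexpr"
            ]
        else:
            # Older Triton: args with value 1 land in equal_to_1 and are
            # inlined into the PTX/CUBIN, so drop them by position.
            equal_to_1 = triton_meta["configs"][0].equal_to_1
            call_args = [a for i, a in enumerate(call_args) if i not in equal_to_1]
            arg_types = [t for i, t in enumerate(arg_types) if i not in equal_to_1]
    return call_args, arg_types, arg_signatures


# Usage with the illustrative triton_meta from above:
meta = {
    "configs": [SimpleNamespace(equal_to_1=())],  # placeholder config object
    "signature": {
        "in_ptr0": "*fp32",
        "out_ptr0": "*fp32",
        "xnumel": "i32",
        "XBLOCK": "constexpr",
    },
}
args, typs, sigs = filter_constexpr_args(
    ["buf0", "buf1", "1024", "128"],
    ["*fp32", "*fp32", "i32", "constexpr"],
    meta,
    uses_attrs_dict=True,
)
print(args)  # ['buf0', 'buf1', '1024'] -- the constexpr XBLOCK is dropped
print(sigs)  # ['*fp32', '*fp32', 'i32']
```

Under this guard, kernels whose metadata lacks a usable signature (the situation the five failing cpp-wrapper tests apparently hit) fall through with their arguments untouched instead of erroring.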
