@@ -571,31 +571,31 @@ def generate_kernel_call(
571
571
device_index , call_args
572
572
)
573
573
kernel_var_name = self .generate_load_kernel_once (kernel_name , V .graph )
574
- if triton_version_uses_attrs_dict ():
575
- signature = triton_meta [ "signature" ]
576
- arg_signatures = [
577
- val for val in signature . values () if val != "constexpr"
578
- ]
579
- call_args = [
580
- call_arg
581
- for call_arg , arg_name in zip ( call_args , signature )
582
- if signature [ arg_name ] != "constexpr"
583
- ]
584
- arg_types = [
585
- arg_type
586
- for arg_type , arg_name in zip ( arg_types , signature )
587
- if signature [ arg_name ] != "constexpr"
588
- ]
589
- else :
590
- # args with value 1 are added into equal_to_1 and constants
591
- # in triton_meta (in the Python codegen) which makes them
592
- # inlined in the PTX and compiled CUBIN
593
- arg_signatures = []
594
- if (
595
- triton_meta is not None
596
- and triton_meta . get ( "configs" )
597
- and triton_meta . get ( "signature" )
598
- ):
574
+ arg_signatures = []
575
+ if (
576
+ triton_meta is not None
577
+ and triton_meta . get ( "configs" )
578
+ and triton_meta . get ( "signature" )
579
+ ):
580
+ if triton_version_uses_attrs_dict ():
581
+ signatures = triton_meta [ " signature" ]
582
+ arg_signatures = [
583
+ val for val in signatures . values () if val != "constexpr"
584
+ ]
585
+ call_args = [
586
+ call_arg
587
+ for call_arg , arg_name in zip ( call_args , signatures )
588
+ if signatures [ arg_name ] != "constexpr"
589
+ ]
590
+ arg_types = [
591
+ arg_type
592
+ for arg_type , arg_name in zip ( arg_types , signatures )
593
+ if signatures [ arg_name ] != "constexpr"
594
+ ]
595
+ else :
596
+ # args with value 1 are added into equal_to_1 and constants
597
+ # in triton_meta (in the Python codegen) which makes them
598
+ # inlined in the PTX and compiled CUBIN
599
599
equal_to_1 = triton_meta ["configs" ][0 ].equal_to_1
600
600
call_args = [
601
601
arg for i , arg in enumerate (call_args ) if i not in equal_to_1
0 commit comments