@@ -445,10 +445,10 @@ def insert_reshape_op(node: torch.fx.Node):
445
445
# Case 2: 3 node match (reshape -> mm -> reshape)
446
446
# - match[0].args[0] will be the "A tensor" input to the reshape op
447
447
# - Has 3D+ shape
448
- A_node = match [0 ].args [0 ]
449
- B_node = mm_node .args [1 ]
450
- A_scale_node = mm_node .args [2 ]
451
- B_scale_node = mm_node .args [3 ]
448
+ A_node = cast ( torch . fx . Node , match [0 ].args [0 ])
449
+ B_node = cast ( torch . fx . Node , mm_node .args [1 ])
450
+ A_scale_node = cast ( torch . fx . Node , mm_node .args [2 ])
451
+ B_scale_node = cast ( torch . fx . Node , mm_node .args [3 ])
452
452
453
453
A_ndim = _get_tensor (A_node ).ndim
454
454
A_scale_ndim = _get_tensor (A_scale_node ).ndim
@@ -480,10 +480,10 @@ def insert_reshape_op(node: torch.fx.Node):
480
480
481
481
return _ScaledMatmul (
482
482
nodes = match ,
483
- A_node = cast ( torch . fx . Node , A_node ) ,
484
- B_node = cast ( torch . fx . Node , B_node ) ,
485
- A_scale_node = cast ( torch . fx . Node , A_scale_node ) ,
486
- B_scale_node = cast ( torch . fx . Node , B_scale_node ) ,
483
+ A_node = A_node ,
484
+ B_node = B_node ,
485
+ A_scale_node = A_scale_node ,
486
+ B_scale_node = B_scale_node ,
487
487
bias_node = get_arg (mm_node , 4 , None ),
488
488
result_scale_node = get_arg (mm_node , 5 , None ),
489
489
out_dtype = get_arg (mm_node , 6 , None ),
0 commit comments