@@ -452,8 +452,12 @@ def insert_reshape_op(node: torch.fx.Node):
452
452
453
453
A_ndim = _get_tensor (A_node ).ndim
454
454
A_scale_ndim = _get_tensor (A_scale_node ).ndim
455
- A_scale_has_single_parent = len (A_scale_node .all_input_nodes ) == 1
456
- tensorwise_scaling = A_scale_ndim <= 1
455
+ is_reciprocal_with_reshape_parent = (
456
+ A_scale_node .target == aten .reciprocal .default
457
+ and len (A_scale_node .all_input_nodes ) == 1
458
+ and A_scale_node .all_input_nodes [0 ].target == aten .reshape .default
459
+ )
460
+ is_tensorwise_scaling = A_scale_ndim <= 1
457
461
458
462
# This is a temporary workaround to handle the reshape -> scaled_mm -> reshape
459
463
# pattern when scales are row-wise, and have been reshaped along with the target
@@ -469,16 +473,10 @@ def insert_reshape_op(node: torch.fx.Node):
469
473
if (
470
474
is_reshape_mm_reshape_pattern
471
475
and A_ndim != A_scale_ndim
472
- and A_scale_has_single_parent
473
- and not tensorwise_scaling
476
+ and not is_tensorwise_scaling
477
+ and is_reciprocal_with_reshape_parent
474
478
):
475
- A_scale_parent = A_scale_node .all_input_nodes [0 ]
476
-
477
- if (
478
- A_scale_parent .target == aten .reshape .default
479
- and A_scale_node .target == aten .reciprocal .default
480
- ):
481
- A_scale_node = insert_reshape_op (A_scale_node )
479
+ A_scale_node = insert_reshape_op (A_scale_node )
482
480
483
481
return _ScaledMatmul (
484
482
nodes = match ,
0 commit comments