8000 simplify logic · pytorch/pytorch@1adf1c2 · GitHub
[go: up one dir, main page]

Skip to content

Commit 1adf1c2

Browse files
danielvegamyhrepytorchmergebot
authored andcommitted
simplify logic
1 parent 0e0c498 commit 1adf1c2

File tree

1 file changed

+9
-11
lines changed

1 file changed

+9
-11
lines changed

torch/_inductor/fx_passes/micro_pipeline_tp.py

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -452,8 +452,12 @@ def insert_reshape_op(node: torch.fx.Node):
452452

453453
A_ndim = _get_tensor(A_node).ndim
454454
A_scale_ndim = _get_tensor(A_scale_node).ndim
455-
A_scale_has_single_parent = len(A_scale_node.all_input_nodes) == 1
456-
tensorwise_scaling = A_scale_ndim <= 1
455+
is_reciprocal_with_reshape_parent = (
456+
A_scale_node.target == aten.reciprocal.default
457+
and len(A_scale_node.all_input_nodes) == 1
458+
and A_scale_node.all_input_nodes[0].target == aten.reshape.default
459+
)
460+
is_tensorwise_scaling = A_scale_ndim <= 1
457461

458462
# This is a temporary workaround to handle the reshape -> scaled_mm -> reshape
459463
# pattern when scales are row-wise, and have been reshaped along with the target
@@ -469,16 +473,10 @@ def insert_reshape_op(node: torch.fx.Node):
469473
if (
470474
is_reshape_mm_reshape_pattern
471475
and A_ndim != A_scale_ndim
472-
and A_scale_has_single_parent
473-
and not tensorwise_scaling
476+
and not is_tensorwise_scaling
477+
and is_reciprocal_with_reshape_parent
474478
):
475-
A_scale_parent = A_scale_node.all_input_nodes[0]
476-
477-
if (
478-
A_scale_parent.target == aten.reshape.default
479-
and A_scale_node.target == aten.reciprocal.default
480-
):
481-
A_scale_node = insert_reshape_op(A_scale_node)
479+
A_scale_node = insert_reshape_op(A_scale_node)
482480

483481
return _ScaledMatmul(
484482
nodes=match,

0 commit comments

Comments
 (0)
0