8000 fix mypy errors · pytorch/pytorch@22c33ca · GitHub
[go: up one dir, main page]

Skip to content

Commit 22c33ca

Browse files
danielvegamyhrepytorchmergebot
authored andcommitted
fix mypy errors
1 parent 1adf1c2 commit 22c33ca

File tree

1 file changed

+8
-8
lines changed

1 file changed

+8
-8
lines changed

torch/_inductor/fx_passes/micro_pipeline_tp.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -445,10 +445,10 @@ def insert_reshape_op(node: torch.fx.Node):
445445
# Case 2: 3 node match (reshape -> mm -> reshape)
446446
# - match[0].args[0] will be the "A tensor" input to the reshape op
447447
# - Has 3D+ shape
448-
A_node = match[0].args[0]
449-
B_node = mm_node.args[1]
450-
A_scale_node = mm_node.args[2]
451-
B_scale_node = mm_node.args[3]
448+
A_node = cast(torch.fx.Node, match[0].args[0])
449+
B_node = cast(torch.fx.Node, mm_node.args[1])
450+
A_scale_node = cast(torch.fx.Node, mm_node.args[2])
451+
B_scale_node = cast(torch.fx.Node, mm_node.args[3])
452452

453453
A_ndim = _get_tensor(A_node).ndim
454454
A_scale_ndim = _get_tensor(A_scale_node).ndim
@@ -480,10 +480,10 @@ def insert_reshape_op(node: torch.fx.Node):
480480

481481
return _ScaledMatmul(
482482
nodes=match,
483-
A_node=cast(torch.fx.Node, A_node),
484-
B_node=cast(torch.fx.Node, B_node),
485-
A_scale_node=cast(torch.fx.Node, A_scale_node),
486-
B_scale_node=cast(torch.fx.Node, B_scale_node),
483+
A_node=A_node,
484+
B_node=B_node,
485+
A_scale_node=A_scale_node,
486+
B_scale_node=B_scale_node,
487487
bias_node=get_arg(mm_node, 4, None),
488488
result_scale_node=get_arg(mm_node, 5, None),
489489
out_dtype=get_arg(mm_node, 6, None),

0 commit comments

Comments
 (0)
0