8000 cast to node to appease linter · pytorch/pytorch@593119f · GitHub
[go: up one dir, main page]

Skip to content

Commit 593119f

Browse files
cast to node to appease linter
1 parent 27609f5 commit 593119f

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

torch/_inductor/fx_passes/micro_pipeline_tp.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -655,7 +655,9 @@ def _scatter_dim_after_reshape(
655655
"reshape must produce 2D tensor for scaled_mm"
656656
)
657657

658-
reshape_op_input_tensor = _get_tensor(reshape_node.args[0])
658+
assert len(reshape_node.args) >= 1, "reshape node must have at least 1 arg"
659+
input_tensor_node = cast(torch.fx.Node, reshape_node.args[0])
660+
reshape_op_input_tensor = _get_tensor(input_tensor_node)
659661
assert reshape_op_input_tensor.ndim > reshape_op_output_tensor.ndim, (
660662
"reshape must be from 3D+ to 2D"
661663
)

0 commit comments

Comments
 (0)
0