8000
We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 27609f5 commit 593119fCopy full SHA for 593119f
torch/_inductor/fx_passes/micro_pipeline_tp.py
@@ -655,7 +655,9 @@ def _scatter_dim_after_reshape(
655
"reshape must produce 2D tensor for scaled_mm"
656
)
657
658
- reshape_op_input_tensor = _get_tensor(reshape_node.args[0])
+ assert len(reshape_node.args) >= 1, "reshape node must have at least 1 arg"
659
+ input_tensor_node = cast(torch.fx.Node, reshape_node.args[0])
660
+ reshape_op_input_tensor = _get_tensor(input_tensor_node)
661
assert reshape_op_input_tensor.ndim > reshape_op_output_tensor.ndim, (
662
"reshape must be from 3D+ to 2D"
663
0 commit comments