Commit de7af81

danielvegamyhre authored and pytorchmergebot committed
[async TP] insert reshape node to handle "reshape -> scaled mm -> reshape pattern" in async TP with rowwise scales (#148001)
Fixes pytorch/torchtitan#864

## Summary

While testing torchtitan float8 training with rowwise scaling + async TP, a [bug](pytorch/torchtitan#864) was discovered. The symptom was that the scaling factor dims did not match the dims of the tensor the scales were to be applied to. My [root cause analysis](pytorch/torchtitan#864 (comment)) determined that when the async TP graph manipulation constructs the `fused_scaled_matmul_reduce_scatter` op, it does not yet handle the "reshape -> scaled mm -> reshape" pattern used in torchao [here](https://github.com/pytorch/ao/blob/ed361ff5c7dd33aba9b4a0da2bd744de5a5debfb/torchao/float8/float8_linear.py#L122-L124) - specifically when row-wise scales are being used.

## TL;DR of root cause

- When a Float8Tensor is reshaped, the scale is reshaped along with it so the dimensions stay aligned.
- In the graph manipulation logic of the micropipeline TP post-grad pass, the scaled_mm `A tensor` node references the tensor from _before_ the reshape op, but the `A_scale` node from _after_ the reshape op.

## Example

- `A tensor` is a Float8Tensor with shape (1,8192,2048) and scale of shape (1,8192,1) when a matmul op is called in torchao [here](https://github.com/pytorch/ao/blob/8706d3f3b087b876d625c720e98236c265c0ba98/torchao/float8/float8_linear.py#L70). Torchao does a reshape -> scaled mm -> reshape [here](https://github.com/pytorch/ao/blob/ed361ff5c7dd33aba9b4a0da2bd744de5a5debfb/torchao/float8/float8_linear.py#L122). When a Float8Tensor is reshaped, its scale is reshaped along with it [here](https://github.com/pytorch/ao/blob/8706d3f3b087b876d625c720e98236c265c0ba98/torchao/float8/float8_ops.py#L152). So the first reshape takes the `A tensor` from (1,8192,2048) to (8192,2048) and the scale from (1,8192,1) to (8192,1).
- During the post-grad pass in async TP:
  - `A_node` has shape (1,8192,2048) (the tensor from before this [reshape](https://github.com/pytorch/ao/blob/ed361ff5c7dd33aba9b4a0da2bd744de5a5debfb/torchao/float8/float8_linear.py#L122)).
  - `A_scale` has shape (8192,1) (due to the reshape op above, which reshaped the scale from (1,8192,1) to (8192,1)).

## Solution

**Note:** the compiler inserts a `reciprocal` op after the reshape, so we can't simply use the node before the reshape as the `A_scale_node`, otherwise it will affect the numerics.

- Short-term solution: if the specific pattern shown below is detected, insert a reshape node after the reciprocal, which reshapes the reciprocal output back to the original shape from before the reshape.
  - reshape is just a view, so there should be no impact on performance

```
Before: reshape (a,b,c) to (a*b,c) -> reciprocal

After:  reshape (a,b,c) to (a*b,c) -> reciprocal -> reshape (a*b,c) to (a,b,c)
```

- Long-term solution: implement a `torch._scaled_matmul` which can support 3D+ `A tensor`

## Test plan

- Added a unit test which exercises this new path
- Manually tested with torchtitan with float8 rowwise + async TP

Pull Request resolved: #148001
Approved by: https://github.com/yifuwang
1 parent ce2f680 commit de7af81
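
To make the shape bookkeeping in the Example and Solution sections concrete, here is a small standalone sketch using plain tensors in place of `Float8Tensor` (the shapes are taken from the example above; the snippet is illustrative only and is not part of the pass or of torchao):

```python
import torch

A = torch.randn(1, 8192, 2048)          # "A tensor" before torchao's reshape
A_scale = torch.rand(1, 8192, 1) + 0.1  # rowwise scale, dims aligned with A

# torchao's reshape -> scaled_mm -> reshape flattens the leading dims of both the
# tensor and its scale, and the compiler then takes the reciprocal of the scale,
# so the traced graph sees a 2D scale node ...
A_2d = A.reshape(-1, A.shape[-1])                  # (1, 8192, 2048) -> (8192, 2048)
scale_2d = A_scale.reshape(-1, A_scale.shape[-1])  # (1, 8192, 1)    -> (8192, 1)
recip_2d = torch.reciprocal(scale_2d)              # what the pass sees as A_scale_node

# ... while the fused op is built around the 3D A_node of shape (1, 8192, 2048).
# The workaround reshapes the reciprocal output back to the pre-reshape scale shape;
# reshape is a view, so the values are untouched and only the dims re-align.
recip_3d = recip_2d.reshape(A_scale.shape)         # (8192, 1) -> (1, 8192, 1)

assert torch.equal(recip_3d, torch.reciprocal(A_scale))  # numerics unchanged
print(A.shape, recip_3d.shape)  # torch.Size([1, 8192, 2048]) torch.Size([1, 8192, 1])
```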

File tree

2 files changed: +159, -4 lines changed

test/distributed/tensor/parallel/test_micro_pipeline_tp.py

Lines changed: 63 additions & 0 deletions
@@ -399,6 +399,69 @@ def func(
         self.assertIn("fused_scaled_matmul_reduce_scatter", code)
         self.assertNotIn("reduce_scatter_tensor", code)
 
+    @skipIfRocm
+    @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
+    @parametrize("scatter_dim", [2])
+    @fresh_inductor_cache()
+    def test_fuse_scaled_matmul_reduce_scatter_rowwise_scales_reshape_mm_reshape(
+        self, scatter_dim
+    ):
+        group = dist.group.WORLD
+
+        def reshape_mm_reshape(
+            A: torch.Tensor,
+            B: torch.Tensor,
+            A_scale: torch.Tensor,
+            B_scale: torch.Tensor,
+            out_dtype: torch.dtype,
+        ) -> torch.Tensor:
+            """
+            Performs a scaled_mm followed by a reduce scatter,
+            following the reshape -> scaled_mm -> reshape pattern.
+            """
+            orig_shape = A.shape
+
+            # reshape tensor and scale together
+            A = A.reshape(-1, orig_shape[-1])
+            A_scale = A_scale.reshape(-1, A_scale.shape[-1])
+            A_scale = torch.reciprocal(A_scale)
+
+            C = torch._scaled_mm(A, B, A_scale, B_scale, out_dtype=out_dtype)
+
+            # reshape output to have same leading dims as original `A` tensor
+            C = C.view(*orig_shape[:-1], C.shape[-1])
+            return reduce_scatter_tensor(C, "sum", scatter_dim, group)
+
+        A = torch.rand(1, 16, 32, device="cuda").to(torch.float8_e4m3fn)
+        B = torch.rand(64, 32, device="cuda").to(torch.float8_e4m3fn).T
+
+        # A_scale = rowwise scales
+        A_scale = torch.full((1, 16, 1), 0.1, device="cuda")
+
+        # B_scale = rowwise scales transposed for A @ B^T
+        B_scale = torch.full((1, 64), 0.1, device="cuda")
+
+        gm = _make_post_grad_fx(
+            reshape_mm_reshape, A, B, A_scale, B_scale, torch.bfloat16
+        )
+
+        with _test_mode():
+            micro_pipeline_tp_pass(gm.graph)
+
+        self.assertIn("fused_scaled_matmul_reduce_scatter", str(gm.graph))
+        self.assertNotIn("reduce_scatter_tensor", str(gm.graph))
+
+        if torch.cuda.get_device_capability() < (8, 9):
+            return
+
+        with _test_mode():
+            compiled = torch.compile(reshape_mm_reshape)
+            code = run_and_get_triton_code(
+                compiled, A, B, A_scale, B_scale, torch.bfloat16
+            )
+        self.assertIn("fused_scaled_matmul_reduce_scatter", code)
+        self.assertNotIn("reduce_scatter_tensor", code)
+
     @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
     @parametrize("shard_dim", [0, 1])
     @fresh_inductor_cache()

torch/_inductor/fx_passes/micro_pipeline_tp.py

Lines changed: 96 additions & 4 deletions
@@ -386,12 +386,104 @@ def get_arg(node: torch.fx.Node, idx: int, default: Any) -> Any:
                 return default
             return node.args[idx]
 
+        def insert_reshape_op(node: torch.fx.Node):
+            """
+            Given a reciprocal node with a parent reshape node,
+            insert a reshape node after the reciprocal node which reshapes
+            the reciprocal output back to the original shape before the first reshape.
+
+            Before:
+                reshape (a,b,c) to (a*b,c) -> reciprocal
+
+            After:
+                reshape (a,b,c) to (a*b,c) -> reciprocal -> reshape (a*b,c) to (a,b,c)
+
+            Returns the new reshape node.
+            """
+            # ensure the given node matches the pattern described in the docstring
+            assert node.target == aten.reciprocal.default, (
+                "Node must be an aten.reciprocal.default op"
+            )
+            assert len(node.all_input_nodes) == 1, "Node must have exactly one parent"
+
+            parent_node = node.all_input_nodes[0]
+            assert parent_node.target == aten.reshape.default, (
+                "Parent node must be an aten.reshape.default op"
+            )
+            assert len(parent_node.all_input_nodes) == 1, (
+                "Parent node must have exactly one input node"
+            )
+
+            parent_input_node = parent_node.all_input_nodes[0]
+            parent_input_shape = list(_get_tensor(parent_input_node).shape)
+
+            # insert reshape back to shape from before the parent reshape op
+            graph = node.graph
+            with graph.inserting_after(node):
+                reshape_node = graph.call_function(
+                    aten.reshape.default, (node, parent_input_shape)
+                )
+
+            # ensure all users of the original node (except the reshape node) now use the reshaped node instead
+            node_users = list(node.users)
+            for user in node_users:
+                if user != reshape_node:
+                    user.replace_input_with(node, reshape_node)
+
+            return reshape_node
+
+        is_reshape_mm_reshape_pattern = match[0].target == aten.reshape.default
+        mm_node = match[1] if is_reshape_mm_reshape_pattern else match[0]
+
+        # `A_node` is pulled directly from match rather than `mm_node` because it needs to handle
+        # both of the following cases:
+        #
+        # Case 1: single node match (mm):
+        #   - match[0].args[0] will be the "A tensor" node of scaled_mm
+        #   - Has 2D shape
+        #
+        # Case 2: 3 node match (reshape -> mm -> reshape)
+        #   - match[0].args[0] will be the "A tensor" input to the reshape op
+        #   - Has 3D+ shape
+        A_node = cast(torch.fx.Node, match[0].args[0])
+        B_node = cast(torch.fx.Node, mm_node.args[1])
+        A_scale_node = cast(torch.fx.Node, mm_node.args[2])
+        B_scale_node = cast(torch.fx.Node, mm_node.args[3])
+
+        A_ndim = _get_tensor(A_node).ndim
+        A_scale_ndim = _get_tensor(A_scale_node).ndim
+        is_reciprocal_with_reshape_parent = (
+            A_scale_node.target == aten.reciprocal.default
+            and len(A_scale_node.all_input_nodes) == 1
+            and A_scale_node.all_input_nodes[0].target == aten.reshape.default
+        )
+        is_tensorwise_scaling = A_scale_ndim <= 1
+
+        # This is a temporary workaround to handle the reshape -> scaled_mm -> reshape
+        # pattern when scales are row-wise, and have been reshaped along with the target
+        # tensor. See https://github.com/pytorch/pytorch/pull/148001 for details.
+        #
+        # If the tensor dim does not match the scale dim, check if the scale node follows
+        # the "reshape -> reciprocal" pattern. If so, we can insert a reshape op after
+        # the reciprocal, to reshape the reciprocal output back to the original shape
+        # from before the first reshape op.
+        #
+        # TODO: remove this workaround once torch._scaled_matmul exists and can be used
+        # to implement more robust long-term support for 3D+ scaled matmuls.
+        if (
+            is_reshape_mm_reshape_pattern
+            and A_ndim != A_scale_ndim
+            and not is_tensorwise_scaling
+            and is_reciprocal_with_reshape_parent
+        ):
+            A_scale_node = insert_reshape_op(A_scale_node)
+
         return _ScaledMatmul(
             nodes=match,
-            A_node=cast(torch.fx.Node, match[0].args[0]),
-            B_node=cast(torch.fx.Node, mm_node.args[1]),
-            A_scale_node=cast(torch.fx.Node, mm_node.args[2]),
-            B_scale_node=cast(torch.fx.Node, mm_node.args[3]),
+            A_node=A_node,
+            B_node=B_node,
+            A_scale_node=A_scale_node,
+            B_scale_node=B_scale_node,
             bias_node=get_arg(mm_node, 4, None),
             result_scale_node=get_arg(mm_node, 5, None),
             out_dtype=get_arg(mm_node, 6, None),