[async TP] insert reshape node to handle "reshape -> scaled mm -> reshape pattern" in async TP with rowwise scales by danielvegamyhre · Pull Request #148001 · pytorch/pytorch · GitHub
[async TP] insert reshape node to handle "reshape -> scaled mm -> reshape pattern" in async TP with rowwise scales #148001


Closed
wants to merge 1 commit into from
63 changes: 63 additions & 0 deletions test/distributed/tensor/parallel/test_micro_pipeline_tp.py
@@ -399,6 +399,69 @@ def func(
self.assertIn("fused_scaled_matmul_reduce_scatter", code)
self.assertNotIn("reduce_scatter_tensor", code)

@skipIfRocm
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@parametrize("scatter_dim", [2])
@fresh_inductor_cache()
def test_fuse_scaled_matmul_reduce_scatter_rowwise_scales_reshape_mm_reshape(
self, scatter_dim
):
group = dist.group.WORLD

def reshape_mm_reshape(
A: torch.Tensor,
B: torch.Tensor,
A_scale: torch.Tensor,
B_scale: torch.Tensor,
out_dtype: torch.dtype,
) -> torch.Tensor:
"""
Performs a scaled_mm followed by a reduce scatter,
following the reshape -> scaled_mm -> reshape pattern.
"""
orig_shape = A.shape

# reshape tensor and scale together
A = A.reshape(-1, orig_shape[-1])
A_scale = A_scale.reshape(-1, A_scale.shape[-1])
A_scale = torch.reciprocal(A_scale)

C = torch._scaled_mm(A, B, A_scale, B_scale, out_dtype=out_dtype)

# reshape output to have same leading dims as original `A` tensor
C = C.view(*orig_shape[:-1], C.shape[-1])
return reduce_scatter_tensor(C, "sum", scatter_dim, group)

A = torch.rand(1, 16, 32, device="cuda").to(torch.float8_e4m3fn)
B = torch.rand(64, 32, device="cuda").to(torch.float8_e4m3fn).T

# A_scale = rowwise scales
A_scale = torch.full((1, 16, 1), 0.1, device="cuda")

# B_scale = rowwise scales transposed for A @ B^T
B_scale = torch.full((1, 64), 0.1, device="cuda")

gm = _make_post_grad_fx(
reshape_mm_reshape, A, B, A_scale, B_scale, torch.bfloat16
)

with _test_mode():
micro_pipeline_tp_pass(gm.graph)

self.assertIn("fused_scaled_matmul_reduce_scatter", str(gm.graph))
self.assertNotIn("reduce_scatter_tensor", str(gm.graph))

if torch.cuda.get_device_capability() < (8, 9):
return

with _test_mode():
compiled = torch.compile(reshape_mm_reshape)
code = run_and_get_triton_code(
compiled, A, B, A_scale, B_scale, torch.bfloat16
)
self.assertIn("fused_scaled_matmul_reduce_scatter", code)
self.assertNotIn("reduce_scatter_tensor", code)

@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@parametrize("shard_dim", [0, 1])
@fresh_inductor_cache()
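For intuition, here is a shape-only sketch of the reshape -> mm -> reshape pattern the new test exercises. It is illustrative only: a plain matmul stands in for torch._scaled_mm so it runs on CPU without fp8 support, the reduce-scatter step is omitted, and the helper name reshape_mm_reshape_shapes is not part of the PR.

import torch

def reshape_mm_reshape_shapes(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
    # A: (1, 16, 32) -> flatten leading dims -> (16, 32), as in the test above
    orig_shape = A.shape
    A_2d = A.reshape(-1, orig_shape[-1])
    # plain matmul standing in for torch._scaled_mm; B is (32, 64)
    C = A_2d @ B
    # restore the leading dims: (16, 64) -> (1, 16, 64)
    return C.view(*orig_shape[:-1], C.shape[-1])

A = torch.rand(1, 16, 32)
B = torch.rand(64, 32).T  # (32, 64), same layout as the test's B
print(reshape_mm_reshape_shapes(A, B).shape)  # torch.Size([1, 16, 64])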
100 changes: 96 additions & 4 deletions torch/_inductor/fx_passes/micro_pipeline_tp.py
@@ -386,12 +386,104 @@ def get_arg(node: torch.fx.Node, idx: int, default: Any) -> Any:
return default
return node.args[idx]

def insert_reshape_op(node: torch.fx.Node):
"""
Given a reciprocal node with a parent reshape node,
insert a reshape node after the reciprocal node which reshapes
the reciprocal output back to the original shape before the first reshape.

Before:
reshape (a,b,c) to (a*b,c) -> reciprocal

After:
reshape (a,b,c) to (a*b,c) -> reciprocal -> reshape (a*b,c) to (a,b,c)

Returns the new reshape node.
"""
# ensure the given node matches the pattern described in the docstring
assert node.target == aten.reciprocal.default, (
"Node must be a aten.reciprocal.default op"
)
assert len(node.all_input_nodes) == 1, "Node must have exactly one parent"

parent_node = node.all_input_nodes[0]
assert parent_node.target == aten.reshape.default, (
"Parent node must be a aten.reshape.default op"
)
assert len(parent_node.all_input_nodes) == 1, (
"Parent node must have exactly one input node"
)

parent_input_node = parent_node.all_input_nodes[0]
parent_input_shape = list(_get_tensor(parent_input_node).shape)

# insert reshape back to shape from before the parent reshape op
graph = node.graph
with graph.inserting_after(node):
reshape_node = graph.call_function(
aten.reshape.default, (node, parent_input_shape)
)

# ensure all users of original node (except the reshape node) now use the reshaped node instead
node_users = list(node.users)
for user in node_users:
if user != reshape_node:
user.replace_input_with(node, reshape_node)

return reshape_node

is_reshape_mm_reshape_pattern = match[0].target == aten.reshape.default
mm_node = match[1] if is_reshape_mm_reshape_pattern else match[0]

# `A_node` is pulled directly from match rather than `mm_node` because it needs to handle
# both of the following cases:
#
# Case 1: single node match (mm):
# - match[0].args[0] will be the "A tensor" node of scaled_mm
# - Has 2D shape
#
# Case 2: 3 node match (reshape -> mm -> reshape)
# - match[0].args[0] will be the "A tensor" input to the reshape op
# - Has 3D+ shape
A_node = cast(torch.fx.Node, match[0].args[0])
B_node = cast(torch.fx.Node, mm_node.args[1])
A_scale_node = cast(torch.fx.Node, mm_node.args[2])
B_scale_node = cast(torch.fx.Node, mm_node.args[3])

A_ndim = _get_tensor(A_node).ndim
A_scale_ndim = _get_tensor(A_scale_node).ndim
is_reciprocal_with_reshape_parent = (
A_scale_node.target == aten.reciprocal.default
and len(A_scale_node.all_input_nodes) == 1
and A_scale_node.all_input_nodes[0].target == aten.reshape.default
)
is_tensorwise_scaling = A_scale_ndim <= 1

# This is a temporary workaround to handle the reshape -> scaled_mm -> reshape
# pattern when scales are row-wise, and have been reshaped along with the target
# tensor. See https://github.com/pytorch/pytorch/pull/148001 for details.
#
# If tensor dim does not match scale dim, check if the scale node follows
# the "reshape -> reciprocal" pattern. If so, we can insert a reshape op after
# the reciprocal, to reshape the reciprocal back to the original shape before
# the first reshape op.
#
# TODO: remove this workaround once torch._scaled_matmul exists and can be used
# to implement more robust long-term support for 3D+ scaled matmuls.
if (
is_reshape_mm_reshape_pattern
and A_ndim != A_scale_ndim
and not is_tensorwise_scaling
and is_reciprocal_with_reshape_parent
):
A_scale_node = insert_reshape_op(A_scale_node)

return _ScaledMatmul(
nodes=match,
- A_node=cast(torch.fx.Node, match[0].args[0]),
- B_node=cast(torch.fx.Node, mm_node.args[1]),
- A_scale_node=cast(torch.fx.Node, mm_node.args[2]),
- B_scale_node=cast(torch.fx.Node, mm_node.args[3]),
+ A_node=A_node,
+ B_node=B_node,
+ A_scale_node=A_scale_node,
+ B_scale_node=B_scale_node,
bias_node=get_arg(mm_node, 4, None),
result_scale_node=get_arg(mm_node, 5, None),
out_dtype=get_arg(mm_node, 6, None),
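To illustrate the rewrite that insert_reshape_op performs, here is a minimal standalone FX sketch. It is not part of the PR: it traces a toy scale-preparation function matching the reshape -> reciprocal pattern, inserts a reshape after the reciprocal to restore the original (1, 16, 1) rowwise-scale shape from the test, and reroutes the other users, mirroring the logic above but with torch.reshape/torch.reciprocal in place of the aten ops.

import torch
from torch import fx

def scale_prep(a_scale: torch.Tensor) -> torch.Tensor:
    # rowwise scale handling as in the test: flatten leading dims, then reciprocal
    a_scale = a_scale.reshape(-1, a_scale.shape[-1])
    return torch.reciprocal(a_scale)

gm = fx.symbolic_trace(scale_prep)
orig_shape = [1, 16, 1]  # A_scale shape before the reshape, taken from the test

for node in list(gm.graph.nodes):
    if node.op == "call_function" and node.target is torch.reciprocal:
        # insert a reshape after the reciprocal, restoring the pre-reshape shape
        with gm.graph.inserting_after(node):
            restored = gm.graph.call_function(torch.reshape, (node, orig_shape))
        # reroute every other user of the reciprocal through the new reshape
        for user in list(node.users):
            if user is not restored:
                user.replace_input_with(node, restored)

gm.recompile()
out = gm(torch.full((1, 16, 1), 0.1))
print(out.shape)  # torch.Size([1, 16, 1]): scale ndim matches the 3D A tensor again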