[Async TP] Fuse matmul-reduce-scatters when reduce scatters have multiple users, and save fused node for backward instead of reduce_scatter node #149946

Closed · danielvegamyhre wants to merge 8 commits
42 changes: 2 additions & 40 deletions test/distributed/tensor/parallel/test_micro_pipeline_tp.py
@@ -157,11 +157,11 @@ def func(inp: torch.Tensor) -> torch.Tensor:
"placeholder",
)
self.assertEqual(
reduce_scatter.rs_node.target,
reduce_scatter.reduce_scatter_node.target,
torch.ops._c10d_functional.reduce_scatter_tensor.default,
)
self.assertEqual(
reduce_scatter.res_node.target,
reduce_scatter.wait_tensor_node.target,
torch.ops._c10d_functional.wait_tensor.default,
)
self.assertEqual(reduce_scatter.group_name, group.group_name)
@@ -492,44 +492,6 @@ def no_matching_pattern(
gm.graph,
)

@skipIfRocm
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@fresh_inductor_cache()
def test_unsuccessful_fusion(self):
group = dist.group.WORLD
scatter_dim = 0

def no_matching_pattern(
A: torch.Tensor,
B: torch.Tensor,
) -> torch.Tensor:
"""
Performs 'reshape -> reciprocal -> mm -> reshape -> reduce scatter' pattern,
so the extra 'reciprocal' op in the middle should cause pattern matching to fail.
"""
out_shape = [*A.shape[:-1], B.shape[-1]]
A = A.reshape(-1, A.shape[-1])

# insert extra op after reshape that will cause pattern matching to fail
A = torch.reciprocal(A)

C = A @ B
C = C.view(out_shape)
return reduce_scatter_tensor(C, "sum", scatter_dim, group)

A = torch.rand(2, 16, 32, device="cuda").to(torch.bfloat16)
B = torch.rand(16, 32, device="cuda").to(torch.bfloat16).T

gm = _make_post_grad_fx(no_matching_pattern, A, B)

with _test_mode():
self.assertRaisesRegex(
AssertionError,
"no successful fusions of matul-reduce-scatters",
micro_pipeline_tp_pass,
gm.graph,
)

@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@parametrize("shard_dim", [0, 1])
@fresh_inductor_cache()
169 changes: 121 additions & 48 deletions torch/_inductor/fx_passes/micro_pipeline_tp.py
@@ -204,14 +204,37 @@ def make_cat_pattern(splits):
class _ReduceScatterMatch:
match: Match
input_node: torch.fx.Node
rs_node: torch.fx.Node
res_node: torch.fx.Node
reduce_scatter_node: torch.fx.Node
wait_tensor_node: torch.fx.Node
reduce_op: str
scatter_dim: int
group_name: str

def replace_with(self, new_node: torch.fx.Node) -> None:
self.res_node.replace_all_uses_with(new_node)
# Replace all uses of the result node (wait_tensor) with the fused node.
self.wait_tensor_node.replace_all_uses_with(new_node)

# If the reduce-scatter result is saved for backward, save the fused node for backward instead.
self._update_save_for_backward(new_node)

def _update_save_for_backward(self, new_node: torch.fx.Node) -> None:
"""
If the output node is a user of the reduce_scatter node (indicating the reduce_scatter
result is saved for backward), this method will update the output node to use the fused node instead.
"""
output_node = None
for user in self.reduce_scatter_node.users:
if user.target == "output":
output_node = user
break
if output_node is not None:
output_node.replace_input_with(self.reduce_scatter_node, new_node)
Contributor:

Is it true that the original reduce_scatter_node should have only one user at this moment (the original wait_tensor)? IIRC, only the wait_tensor_node will be used by other nodes. So can we also assert that there is only one user for reduce_scatter_node?

Contributor Author:

No, the reduce-scatter may have 2 users: the wait_tensor, and the final output node (if it is saved for backward).

Contributor:

@fegin I think one alternative here is that we do something to force the partitioner to never save a collective for bw directly, and to always save its corresponding wait_tensor. We could try to do this. Two things though:

(1) You could imagine cases where a collective is run in the forward, but its result is not actually needed until the backward. In that case, it would actually be more profitable to delay the sync until the backward, when the collective is actually used. I can pretty easily construct a case like this, but I'm not sure how likely it is to show up in practice.

(2) I'm not sure how difficult an invariant this would be to maintain in the partitioner. So the tradeoff here is probably more around "increased complexity in the partitioner" vs. "increased complexity in the pattern matcher".

Contributor:

@danielvegamyhre I didn't express my question clearly. I meant that after output_node.replace_input_with(), there should be only one user left for the ORIGINAL reduce_scatter_node.

@bdhirsh I do believe there are some cases where collective waits are intentionally delayed until the backward, so making the partitioner carry that assumption is not great. I simply want to add some check AFTER the replacement.

Contributor Author:

> @danielvegamyhre I didn't express my question clearly. I meant that after output_node.replace_input_with(), there should be only one user left for the ORIGINAL reduce_scatter_node.

@fegin Sure, that would be a helpful check - I added an assertion here and reran torchtitan training runs for bf16, float8 tensorwise, and float8 rowwise, and validated that everything still works as expected (with the exception of all-gather-matmuls not fusing properly for rowwise scales, which is a known issue I'm already tracking in #149990).


# Assert that now the reduce scatter node has only one user (the wait_tensor) and it's not
# saved for backward anymore.
assert len(self.reduce_scatter_node.users) == 1, (
"Reduce scatter node has multiple users, this is not expected"
)
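
To make the discussion above concrete, here is a minimal, hypothetical torch.fx sketch (not code from the PR) of the two-user situation: the reduce_scatter node is consumed both by its wait_tensor and by the graph output (saved for backward), and the _update_save_for_backward step rewires the output to the fused node so that only the wait_tensor remains as a user. The identity lambdas are stand-ins for the real collective and fused ops.

```python
import torch.fx as fx

graph = fx.Graph()
inp = graph.placeholder("inp")

# Stand-ins for reduce_scatter_tensor and wait_tensor (identity lambdas).
rs = graph.call_function(lambda t: t, args=(inp,))    # "reduce_scatter"
wait = graph.call_function(lambda t: t, args=(rs,))   # "wait_tensor"
graph.output((wait, rs))  # rs is also saved for backward -> rs has 2 users

with graph.inserting_before(wait):
    fused = graph.call_function(lambda t: t, args=(inp,))  # "fused_matmul_reduce_scatter"

# The _update_save_for_backward logic: point the output node at the fused node instead.
output_node = next(u for u in rs.users if u.op == "output")
output_node.replace_input_with(rs, fused)

assert len(rs.users) == 1  # only the wait_tensor still consumes the original reduce_scatter
```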

def erase(self) -> None:
for node in reversed(self.match.nodes):
@@ -222,7 +245,7 @@ def erase(self) -> None:
def find_reduce_scatter_patterns(graph: torch.fx.Graph):
c10d = torch.ops._c10d_functional

def reduce_scatter_template(inp: PatternExpr):
def reduce_scatter_template(inp: PatternExpr, users: int):
return CallFunction(
c10d.wait_tensor.default,
CallFunction(
Expand All @@ -231,14 +254,43 @@ def reduce_scatter_template(inp: PatternExpr):
KeywordArg("reduce_op"),
Ignored(),
KeywordArg("group_name"),
_users=users,
),
)

# Matches funcol.reduce_scatter_tensor with scatter_dim == 0
zero_dim_reduce_scatter_pattern = reduce_scatter_template(KeywordArg("input"))
zero_dim_reduce_scatter_pattern_single_user = reduce_scatter_template(
KeywordArg("input"), users=1
)

# Two users will occur when the reduce-scatter result is saved for backward
zero_dim_reduce_scatter_pattern_multi_user = reduce_scatter_template(
KeywordArg("input"), users=2
)

# Matches funcol.reduce_scatter_tensor with scatter_dim > 0
non_zero_dim_reduce_scatter_pattern = reduce_scatter_template(
non_zero_dim_reduce_scatter_pattern_single_user = reduce_scatter_template(
CallFunction(
aten.cat.default,
ListOf(
CallFunction(
operator.getitem,
CallFunction(
aten.split.Tensor,
KeywordArg("input"),
Ignored(),
KeywordArg("scatter_dim"),
_users=MULTIPLE,
),
Ignored(),
)
),
),
users=1,
)

# Two users will occur when the reduce-scatter result is saved for backward
non_zero_dim_reduce_scatter_pattern_multi_user = reduce_scatter_template(
CallFunction(
aten.cat.default,
ListOf(
Expand All @@ -255,32 +307,59 @@ def reduce_scatter_template(inp: PatternExpr):
)
),
),
users=2,
)

reduce_scatters = []
for node in reversed(graph.nodes):
if node.target == c10d.wait_tensor.default:
if match := non_zero_dim_reduce_scatter_pattern.match(node):
if match := non_zero_dim_reduce_scatter_pattern_single_user.match(node):
assert isinstance(match, Match)
reduce_scatters.append(
_ReduceScatterMatch(
match=match,
input_node=match.kwargs["input"],
reduce_scatter_node=match.nodes[-2],
wait_tensor_node=node,
reduce_op=match.kwargs["reduce_op"],
scatter_dim=match.kwargs["scatter_dim"],
group_name=match.kwargs["group_name"],
)
)
elif match := zero_dim_reduce_scatter_pattern_single_user.match(node):
assert isinstance(match, Match)
reduce_scatters.append(
_ReduceScatterMatch(
match=match,
input_node=match.kwargs["input"],
reduce_scatter_node=match.nodes[0],
wait_tensor_node=node,
reduce_op=match.kwargs["reduce_op"],
scatter_dim=0,
group_name=match.kwargs["group_name"],
)
)
elif match := non_zero_dim_reduce_scatter_pattern_multi_user.match(node):
assert isinstance(match, Match)
reduce_scatters.append(
_ReduceScatterMatch(
match=match,
input_node=match.kwargs["input"],
rs_node=match.nodes[-2],
res_node=node,
reduce_scatter_node=match.nodes[-2],
wait_tensor_node=node,
reduce_op=match.kwargs["reduce_op"],
scatter_dim=match.kwargs["scatter_dim"],
group_name=match.kwargs["group_name"],
)
)
elif match := zero_dim_reduce_scatter_pattern.match(node):
elif match := zero_dim_reduce_scatter_pattern_multi_user.match(node):
assert isinstance(match, Match)
reduce_scatters.append(
_ReduceScatterMatch(
match=match,
input_node=match.kwargs["input"],
rs_node=match.nodes[0],
res_node=node,
reduce_scatter_node=match.nodes[0],
wait_tensor_node=node,
reduce_op=match.kwargs["reduce_op"],
scatter_dim=0,
group_name=match.kwargs["group_name"],
@@ -757,7 +836,7 @@ def _insert_fused_matmul_reduce_scatter(
raise AssertionError(f"Unexpected matmul match type: {type(matmul)}")


def fuse_matmul_reduce_scatter(reduce_scatter: _ReduceScatterMatch) -> bool:
def fuse_matmul_reduce_scatter(reduce_scatter: _ReduceScatterMatch) -> None:
"""
Fused the pattern

@@ -771,58 +850,57 @@ def fuse_matmul_reduce_scatter(reduce_scatter: _ReduceScatterMatch) -> bool:

Returns boolean indicating if fusion was successful or not.
"""
if (
not torch.distributed.is_available()
or not torch.distributed.is_nccl_available()
):
log.debug(
"torch.distributed is not available, skipping fuse_matmul_reduce_scatter fusion"
)
return False
assert torch.distributed.is_available() and torch.distributed.is_nccl_available(), (
"torch.distributed and NCCL must be available to use async tensor parallelism"
)

from torch.distributed._symmetric_memory import (
is_symm_mem_enabled_for_group,
restride_A_for_fused_matmul_reduce_scatter,
)

input_node, _rs_node, rs_res_node, reduce_op, orig_scatter_dim, group_name = (
(
input_node,
_reduce_scatter_node,
rs_wait_tensor_node,
reduce_op,
orig_scatter_dim,
group_name,
) = (
reduce_scatter.input_node,
reduce_scatter.rs_node,
reduce_scatter.res_node,
reduce_scatter.reduce_scatter_node,
reduce_scatter.wait_tensor_node,
reduce_scatter.reduce_op,
reduce_scatter.scatter_dim,
reduce_scatter.group_name,
)

if not is_symm_mem_enabled_for_group(group_name):
log.debug(
"symmetric memory is not enabled for process group %s, skipping fuse_matmul_reduce_scatter fusion",
group_name,
)
return False
assert is_symm_mem_enabled_for_group(group_name), (
f"symmetric memory is not enabled for process group {group_name}, this is required for async TP"
)

# Currently fused_matmul_reduce_scatter doesn't return the matmul result,
# so we can't apply the fusion if the matmul result is used by multiple
# users. This is not a fundamental limitation of the fused op and can be
# addressed if needed.
if len(input_node.users) != 1:
log.debug(
log.warning(
"matmul result has more than one user, skipping fused_matmul_reduce_scatter fusion."
)
return False
return

matmul = _find_producer_matmul(input_node)
if matmul is None:
log.debug(
log.warning(
"no producer matmul found for reduce scatter, skipping fuse_matmul_reduce_scatter fusion"
)
return False
return

if rs_res_node in matmul.arg_ancestor_nodes:
log.debug(
if rs_wait_tensor_node in matmul.arg_ancestor_nodes:
log.warning(
"reduce-scatter result node is an ancestor of matmul, skipping fuse_matmul_reduce_scatter fusion"
)
return False
return

# We need to track 3 values for the fused scaled mm reduce scatter implementation:
# 1. The scatter dim before the reshape, which was assigned using the original (a,b,c) @ (c,d) = (a,b,d) dims.
@@ -848,8 +926,8 @@ def fuse_matmul_reduce_scatter(reduce_scatter: _ReduceScatterMatch) -> bool:
B_shape = list(_get_tensor(matmul.B_node).shape)
output_shape = [*A_orig_shape[:-1], B_shape[-1]]

graph = rs_res_node.graph
with graph.inserting_before(rs_res_node):
graph = rs_wait_tensor_node.graph
with graph.inserting_before(rs_wait_tensor_node):
# Restride A tensor before fused op, for optimal perf in fused matmul reduce scatter
if "val" in matmul.A_node.meta:
restrided = restride_A_for_fused_matmul_reduce_scatter(
@@ -885,7 +963,6 @@ def fuse_matmul_reduce_scatter(reduce_scatter: _ReduceScatterMatch) -> bool:
fused_node.prepend(node)

log.debug("successfully fused matmul reduce scatter")
return True


def _get_node_to_ancestors(
@@ -984,22 +1061,18 @@ def micro_pipeline_tp_pass(graph: torch.fx.Graph):
unexposed_collectives = _get_unexposed_collectives(graph)
all_gathers = [x for x in all_gathers if x.ag_node not in unexposed_collectives]
reduce_scatters = [
x for x in reduce_scatters if x.rs_node not in unexposed_collectives
x
for x in reduce_scatters
if x.reduce_scatter_node not in unexposed_collectives
]

if not all_gathers and not reduce_scatters:
raise AssertionError(
"async TP found no matching all-gather/reduce-scatter patterns for fusion"
)

# TODO: raise an exception if we're using async TP but failed to fuse any all-gather-matmuls
for all_gather in all_gathers:
fuse_all_gather_matmul(all_gather)

fused_reduce_scatters = False
for reduce_scatter in reduce_scatters:
if fuse_matmul_reduce_scatter(reduce_scatter):
fused_reduce_scatters = True

if reduce_scatters and not fused_reduce_scatters:
raise AssertionError("no successful fusions of matul-reduce-scatters")
fuse_matmul_reduce_scatter(reduce_scatter)
Contributor:

Curious: why did we decide to change the check of the return value? If no successful fusion occurred, do we still want to raise?

Contributor Author (danielvegamyhre, Mar 28, 2025):

Yeah, I found that there were only valid scaled_mm-reduce-scatter patterns in the forward graph, but not in the backward graph, so we can't assert this here. In the backward graph, the reduce-scatters receive as input the addition of various scaled_mms, so it's not what we are looking to fuse. See diagram below:

[Screenshot (2025-03-28): diagram of the backward-graph reduce-scatter pattern]

Contributor:

Ah, thanks, this makes sense.
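
As a hedged illustration of the backward-graph shape described in this thread (an assumption drawn from the comment above, not code from the PR), the reduce-scatter input there is the sum of several matmul/scaled_mm outputs rather than a single matmul, so no single producer matmul exists for the pass to fuse with, and such reduce-scatters are now simply skipped instead of triggering an assertion:

```python
import torch
from torch.distributed._functional_collectives import reduce_scatter_tensor

def backward_like_pattern(
    A: torch.Tensor, B1: torch.Tensor, B2: torch.Tensor, scatter_dim: int, group
) -> torch.Tensor:
    # The reduce-scatter consumes an add of two matmuls rather than a single
    # matmul, so there is no single producer matmul for the fusion to target.
    C = A @ B1 + A @ B2
    return reduce_scatter_tensor(C, "sum", scatter_dim, group)
```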
