[Async TP] More robust support for rowwise scales when fusing matmul reduce-scatter by danielvegamyhre · Pull Request #149247 · pytorch/pytorch
Closed · wants to merge 1 commit
test/distributed/tensor/parallel/test_micro_pipeline_tp.py (74 changes: 71 additions & 3 deletions)
@@ -16,6 +16,7 @@
 from torch._inductor.utils import fresh_inductor_cache, run_and_get_triton_code
 from torch.distributed._functional_collectives import (
     all_gather_tensor,
+    all_reduce,
     reduce_scatter_tensor,
 )
 from torch.distributed._symmetric_memory import _test_mode
@@ -401,7 +402,7 @@ def func(

     @skipIfRocm
     @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
-    @parametrize("scatter_dim", [2])
+    @parametrize("scatter_dim", [0, 1, 2])
     @fresh_inductor_cache()
     def test_fuse_scaled_matmul_reduce_scatter_rowwise_scales_reshape_mm_reshape(
         self, scatter_dim
@@ -432,11 +433,11 @@ def reshape_mm_reshape(
             C = C.view(*orig_shape[:-1], C.shape[-1])
             return reduce_scatter_tensor(C, "sum", scatter_dim, group)

-        A = torch.rand(1, 16, 32, device="cuda").to(torch.float8_e4m3fn)
+        A = torch.rand(2, 16, 32, device="cuda").to(torch.float8_e4m3fn)
         B = torch.rand(64, 32, device="cuda").to(torch.float8_e4m3fn).T

         # A_scale = rowwise scales
-        A_scale = torch.full((1, 16, 1), 0.1, device="cuda")
+        A_scale = torch.full((2, 16, 1), 0.1, device="cuda")

         # B_scale = rowwise scales transposed for A @ B^T
         B_scale = torch.full((1, 64), 0.1, device="cuda")
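For reference, the fusion exercised by this test replaces an eager reshape -> scaled matmul -> reshape -> reduce-scatter chain. Below is a minimal sketch of that eager pattern (not code from the PR), assuming a recent PyTorch build and a GPU on which torch._scaled_mm accepts per-row / per-column float32 scales; the tensor shapes mirror the test above.

import torch

# Sketch only: eager rowwise-scaled matmul feeding a reduce-scatter, i.e. the
# pattern the micro-pipeline TP pass rewrites into a single fused op.
A = torch.rand(2, 16, 32, device="cuda").to(torch.float8_e4m3fn)  # (B, M, K)
B = torch.rand(64, 32, device="cuda").to(torch.float8_e4m3fn).T   # (K, N), column-major

A_scale = torch.full((2, 16, 1), 0.1, device="cuda")  # one scale per row of A
B_scale = torch.full((1, 64), 0.1, device="cuda")     # one scale per column of the (K, N) operand

# Collapse the leading dims so the rowwise scale lines up with the 2D matmul;
# the fusion pass also has to account for this reshape of the scale.
A_2d = A.reshape(-1, A.shape[-1])    # (B*M, K) = (32, 32)
A_scale_2d = A_scale.reshape(-1, 1)  # (B*M, 1) = (32, 1)

C = torch._scaled_mm(
    A_2d, B, scale_a=A_scale_2d, scale_b=B_scale, out_dtype=torch.bfloat16
)                                    # (B*M, N) = (32, 64)
C = C.view(2, 16, -1)                # restore (B, M, N) before the reduce-scatter

When compiled, the reduce_scatter_tensor that follows this matmul is expected to be folded into the fused_scaled_matmul_reduce_scatter op asserted on in the next hunk.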
@@ -462,6 +463,73 @@ def reshape_mm_reshape(
         self.assertIn("fused_scaled_matmul_reduce_scatter", code)
         self.assertNotIn("reduce_scatter_tensor", code)

+    @skipIfRocm
+    @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
+    @fresh_inductor_cache()
+    def test_no_all_gathers_or_reduce_scatters(self):
+        group = dist.group.WORLD
+
+        def no_matching_pattern(
+            A: torch.Tensor,
+            B: torch.Tensor,
+        ) -> torch.Tensor:
+            """
+            Performs some ops which will not have any all-gather-matmul or matmul-reduce-scatter patterns.
+            """
+            C = A * B
+            return all_reduce(C, "sum", group)
+
+        A = torch.rand(2, 16, 32, device="cuda").to(torch.bfloat16)
+        B = torch.rand(16, 32, device="cuda").to(torch.bfloat16)
+
+        gm = _make_post_grad_fx(no_matching_pattern, A, B)
+
+        with _test_mode():
+            self.assertRaisesRegex(
+                AssertionError,
+                "async TP found no matching all-gather/reduce-scatter patterns for fusion",
+                micro_pipeline_tp_pass,
+                gm.graph,
+            )
+
+    @skipIfRocm
+    @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
+    @fresh_inductor_cache()
+    def test_unsuccessful_fusion(self):
+        group = dist.group.WORLD
+        scatter_dim = 0
+
+        def no_matching_pattern(
+            A: torch.Tensor,
+            B: torch.Tensor,
+        ) -> torch.Tensor:
+            """
+            Performs 'reshape -> reciprocal -> mm -> reshape -> reduce scatter' pattern,
+            so the extra 'reciprocal' op in the middle should cause pattern matching to fail.
+            """
+            out_shape = [*A.shape[:-1], B.shape[-1]]
+            A = A.reshape(-1, A.shape[-1])
+
+            # insert extra op after reshape that will cause pattern matching to fail
+            A = torch.reciprocal(A)
+
+            C = A @ B
+            C = C.view(out_shape)
+            return reduce_scatter_tensor(C, "sum", scatter_dim, group)
+
+        A = torch.rand(2, 16, 32, device="cuda").to(torch.bfloat16)
+        B = torch.rand(16, 32, device="cuda").to(torch.bfloat16).T
+
+        gm = _make_post_grad_fx(no_matching_pattern, A, B)
+
+        with _test_mode():
+            self.assertRaisesRegex(
+                AssertionError,
+                "no successful fusions of matul-reduce-scatters",
+                micro_pipeline_tp_pass,
+                gm.graph,
+            )
+
     @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
     @parametrize("shard_dim", [0, 1])
     @fresh_inductor_cache()
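For contrast with test_unsuccessful_fusion: dropping the extra reciprocal leaves exactly the reshape -> mm -> reshape -> reduce-scatter chain the pass is meant to match. A minimal sketch of that fusable shape (not code from the PR, and matching_pattern is an illustrative name), reusing the tensor shapes from the new tests:

import torch
from torch.distributed._functional_collectives import reduce_scatter_tensor

def matching_pattern(
    A: torch.Tensor, B: torch.Tensor, group, scatter_dim: int = 0
) -> torch.Tensor:
    # Same chain as no_matching_pattern above, minus torch.reciprocal:
    # flatten the leading dims, matmul, restore the 3D shape, reduce-scatter.
    out_shape = [*A.shape[:-1], B.shape[-1]]
    A = A.reshape(-1, A.shape[-1])
    C = A @ B
    C = C.view(out_shape)
    return reduce_scatter_tensor(C, "sum", scatter_dim, group)

Captured with _make_post_grad_fx and run through micro_pipeline_tp_pass under _test_mode(), this version should have its reduce-scatter rewritten into the fused matmul-reduce-scatter op rather than raising either assertion message above.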