Avoid calling fallback directly for symmetric memory tests by fegin · Pull Request #153520 · pytorch/pytorch

Avoid calling fallback directly for symmetric memory tests #153520


Open: wants to merge 1 commit into base gh/fegin/306/base
121 changes: 65 additions & 56 deletions test/distributed/test_symmetric_memory.py
@@ -13,9 +13,6 @@
from torch._inductor.utils import fresh_inductor_cache, run_and_get_triton_code
from torch.distributed._functional_collectives import all_gather_tensor
from torch.distributed._symmetric_memory import (
_fused_all_gather_matmul_fallback,
_fused_all_gather_scaled_matmul_fallback,
_fused_matmul_reduce_scatter_fallback,
_test_mode,
enable_symm_mem_for_group,
restride_A_for_fused_matmul_reduce_scatter,
@@ -41,7 +38,10 @@
TestCase,
)


# When nullcontext is used, the dispatcher will dispatch the fused kernel calls to
# the symmetric memory implementation. If _test_mode is enabled, the dispatcher will
# dispatch the fused kernel calls to the "fallback" implementation, where symm_mem is
# not used and the regular collectives and pytorch ops are used instead.
test_contexts = [nullcontext, _test_mode]


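The comment above describes `_test_mode` as a toggle that reroutes the fused symm_mem ops to a fallback built on regular collectives and PyTorch ops. As a rough illustration of that kind of dispatch toggle (not PyTorch's actual dispatcher; the flag and helper names below are made up for the example), a context-manager-based switch could look like this:

```python
# Illustrative sketch only: _IN_TEST_MODE, fused_op, _symm_mem_impl and
# _fallback_impl are hypothetical names, not part of the PyTorch API.
from contextlib import contextmanager, nullcontext

_IN_TEST_MODE = False  # module-level flag the dispatching code consults


@contextmanager
def _test_mode():
    """Route fused ops to the fallback implementation while active."""
    global _IN_TEST_MODE
    _IN_TEST_MODE = True
    try:
        yield
    finally:
        _IN_TEST_MODE = False


def _symm_mem_impl(x):
    return ("symm_mem", x)


def _fallback_impl(x):
    return ("fallback", x)


def fused_op(x):
    # Pick the implementation based on the flag, mirroring how the tests
    # toggle between nullcontext and _test_mode.
    impl = _fallback_impl if _IN_TEST_MODE else _symm_mem_impl
    return impl(x)


for context in (nullcontext, _test_mode):
    with context():
        print(fused_op(42))  # ('symm_mem', 42) then ('fallback', 42)
```

The benefit of this pattern for the tests below is that both code paths go through the identical public `torch.ops.symm_mem.*` call, so the stride and numerics comparisons exercise exactly what users would invoke.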
@@ -364,18 +364,21 @@ def test_fused_all_gather_matmul(self, gather_dim: int) -> None:
A_shard = torch.rand(BATCH, M // self.world_size, K, device="cuda")
Bs = [torch.rand(K, N, device="cuda") for _ in range(3)]

ag_output_0, mm_outputs_0 = _fused_all_gather_matmul_fallback(
A_shard, Bs, gather_dim=gather_dim, group_name=group.group_name
)
ag_output_1, mm_outputs_1 = torch.ops.symm_mem.fused_all_gather_matmul(
A_shard, Bs, gather_dim=gather_dim, group_name=group.group_name
)
ag_output_vec = []
mm_outputs_vec = []
for context in test_contexts:
with context():
Comment on lines +367 to +370 (Contributor):
nit: it seems a little bit obscure to wrap the test in a loop (and use vector).

Since there is only 1 special context here, how about:

    ag_output_0, mm_outputs_0 = torch.ops.symm_mem.fused_all_gather_matmul(
        A_shard, Bs, gather_dim=gather_dim, group_name=group.group_name
    )

    with _test_mode():
        ag_output_1, mm_outputs_1 = torch.ops.symm_mem.fused_all_gather_matmul(
            A_shard, Bs, gather_dim=gather_dim, group_name=group.group_name
        )

    assert torch.allclose(ag_output_0, ag_output_1)

Contributor:
(Like you did for the other test below)

ag_output, mm_outputs = torch.ops.symm_mem.fused_all_gather_matmul(
A_shard, Bs, gather_dim=gather_dim, group_name=group.group_name
)
ag_output_vec.append(ag_output)
mm_outputs_vec.append(mm_outputs)

assert torch.allclose(ag_output_0, ag_output_1)
assert ag_output_0.stride() == ag_output_1.stride()
for mm_output_0, mm_output_1 in zip(mm_outputs_0, mm_outputs_1):
assert torch.allclose(ag_output_vec[0], ag_output_vec[1])
assert ag_output_vec[0].stride() == ag_output_vec[1].stride()
for mm_output_0, mm_output_1 in zip(mm_outputs_vec[0], mm_outputs_vec[1]):
assert torch.allclose(mm_output_0, mm_output_1)
assert mm_output_0.stride(), mm_output_1.stride()
assert mm_output_0.stride() == mm_output_1.stride()

dist.destroy_process_group()

@@ -418,9 +421,10 @@ def test_fused_all_gather_matmul_native(
else:
B = torch.rand(N, K, dtype=torch.bfloat16, device="cuda").t()

ag_baseline, mm_baseline = _fused_all_gather_matmul_fallback(
A_shard, [B], gather_dim=0, group_name=group_name
)
with _test_mode():
ag_baseline, mm_baseline = torch.ops.symm_mem.fused_all_gather_matmul(
A_shard, [B], gather_dim=0, group_name=group_name
)
with torch.profiler.profile(
activities=[
torch.profiler.ProfilerActivity.CUDA,
@@ -458,9 +462,11 @@ def test_multimem_all_gather_matmul(self) -> None:

B = torch.rand(K, N, dtype=torch.bfloat16, device="cuda")

ag_baseline, mm_baseline = _fused_all_gather_matmul_fallback(
A_shard, [B], gather_dim=0, group_name=group_name, return_A=False
)
with _test_mode():
ag_baseline, mm_baseline = torch.ops.symm_mem.fused_all_gather_matmul(
A_shard, [B], gather_dim=0, group_name=group_name, return_A=False
)

with torch.profiler.profile(
activities=[
torch.profiler.ProfilerActivity.CUDA,
@@ -524,39 +530,36 @@ def test_fused_all_gather_scaled_matmul(
else:
raise AssertionError(f"Invalid scale_mode: {scale_mode}")

ag_output_0, mm_outputs_0 = _fused_all_gather_scaled_matmul_fallback(
A_shard,
Bs,
A_scale,
B_scales,
gather_dim=gather_dim,
group_name=group.group_name,
biases=[None] * len(Bs),
result_scales=[None] * len(Bs),
out_dtypes=out_dtypes,
use_fast_accum=[None] * len(Bs),
)
ag_output_1, mm_outputs_1 = torch.ops.symm_mem.fused_all_gather_scaled_matmul(
A_shard,
Bs,
A_scale,
B_scales,
gather_dim=gather_dim,
group_name=group.group_name,
biases=[None] * len(Bs),
result_scales=[None] * len(Bs),
out_dtypes=out_dtypes,
use_fast_accum=[None] * len(Bs),
)
ag_output_vec = []
mm_outputs_vec = []
for context in test_contexts:
with context():
(
ag_output,
mm_outputs,
) = torch.ops.symm_mem.fused_all_gather_scaled_matmul(
A_shard,
Bs,
A_scale,
B_scales,
gather_dim=gather_dim,
group_name=group.group_name,
biases=[None] * len(Bs),
result_scales=[None] * len(Bs),
out_dtypes=out_dtypes,
use_fast_accum=[None] * len(Bs),
)
ag_output_vec.append(ag_output)
mm_outputs_vec.append(mm_outputs)

self.assertTrue(
torch.allclose(
ag_output_0.to(torch.float32),
ag_output_1.to(torch.float32),
ag_output_vec[0].to(torch.float32),
ag_output_vec[1].to(torch.float32),
)
)
self.assertEqual(ag_output_0.stride(), ag_output_1.stride())
for mm_output_0, mm_output_1 in zip(mm_outputs_0, mm_outputs_1):
self.assertEqual(ag_output_vec[0].stride(), ag_output_vec[1].stride())
for mm_output_0, mm_output_1 in zip(mm_outputs_vec[0], mm_outputs_vec[1]):
self.assertTrue(
torch.allclose(
mm_output_0.to(torch.float32), mm_output_1.to(torch.float32)
@@ -584,15 +587,21 @@ def test_fused_matmul_reduce_scatter(self, scatter_dim: int) -> None:
A = torch.rand(BATCH, M, K, device="cuda")
B = torch.rand(K, N, device="cuda")

output_0 = _fused_matmul_reduce_scatter_fallback(
A, B, "avg", scatter_dim=scatter_dim, group_name=group.group_name
)
output_1 = torch.ops.symm_mem.fused_matmul_reduce_scatter(
A, B, "avg", scatter_dim=scatter_dim, group_name=group.group_name
)
outputs = []
for context in test_contexts:
with context():
outputs.append(
torch.ops.symm_mem.fused_matmul_reduce_scatter(
A,
B,
"avg",
scatter_dim=scatter_dim,
group_name=group.group_name,
)
)

assert torch.allclose(output_0, output_1)
assert output_0.stride() == output_1.stride()
assert torch.allclose(outputs[0], outputs[1])
assert outputs[0].stride() == outputs[1].stride()

dist.destroy_process_group()
