Avoid calling fallback directly for symmetric memory tests · pytorch/pytorch@52c29cb · GitHub

Commit 52c29cb

Avoid calling fallback directly for symmetric memory tests
Since we can just use _test_mode to dispatch the calls, we should also use this method to verify that the function signature is consistent.

ghstack-source-id: 4cd25df
Pull-Request-resolved: #153520
1 parent 81dfde0 commit 52c29cb
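
In short, instead of importing and calling the private *_fallback helpers, each test now invokes the public torch.ops.symm_mem op under every context in test_contexts, so both the symmetric-memory path and the fallback path go through the same operator signature. A minimal, self-contained sketch of the pattern (the run_op helper is hypothetical, standing in for the fused ops used in the real tests; not code from this patch):

import torch
from contextlib import nullcontext

from torch.distributed._symmetric_memory import _test_mode

# Under nullcontext the dispatcher routes the call to the symmetric-memory
# kernel; under _test_mode it routes the same call to the fallback.
test_contexts = [nullcontext, _test_mode]


def run_op() -> torch.Tensor:
    # Hypothetical stand-in: the real tests call e.g.
    # torch.ops.symm_mem.fused_matmul_reduce_scatter(...) here.
    return torch.ones(2, 2)


outputs = []
for context in test_contexts:
    with context():
        outputs.append(run_op())

# Both dispatch paths must agree on values and on memory layout.
assert torch.allclose(outputs[0], outputs[1])
assert outputs[0].stride() == outputs[1].stride()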

File tree

1 file changed: +65 −56 lines

test/distributed/test_symmetric_memory.py

Lines changed: 65 additions & 56 deletions
@@ -13,9 +13,6 @@
 from torch._inductor.utils import fresh_inductor_cache, run_and_get_triton_code
 from torch.distributed._functional_collectives import all_gather_tensor
 from torch.distributed._symmetric_memory import (
-    _fused_all_gather_matmul_fallback,
-    _fused_all_gather_scaled_matmul_fallback,
-    _fused_matmul_reduce_scatter_fallback,
     _test_mode,
     enable_symm_mem_for_group,
     restride_A_for_fused_matmul_reduce_scatter,
@@ -41,7 +38,10 @@
     TestCase,
 )

-
+# When nullcontext is used, the dispatcher will dispatch the fused kernel calls to
+# the symmetric memory implementation. If _test_mode is enabled, the dispatcher will
+# dispatch the fused kernel calls to the "fallback" implementation, where symm_mem is
+# not used and the regular collectives and pytorch ops are used instead.
 test_contexts = [nullcontext, _test_mode]

@@ -364,18 +364,21 @@ def test_fused_all_gather_matmul(self, gather_dim: int) -> None:
         A_shard = torch.rand(BATCH, M // self.world_size, K, device="cuda")
         Bs = [torch.rand(K, N, device="cuda") for _ in range(3)]

-        ag_output_0, mm_outputs_0 = _fused_all_gather_matmul_fallback(
-            A_shard, Bs, gather_dim=gather_dim, group_name=group.group_name
-        )
-        ag_output_1, mm_outputs_1 = torch.ops.symm_mem.fused_all_gather_matmul(
-            A_shard, Bs, gather_dim=gather_dim, group_name=group.group_name
-        )
+        ag_output_vec = []
+        mm_outputs_vec = []
+        for context in test_contexts:
+            with context():
+                ag_output, mm_outputs = torch.ops.symm_mem.fused_all_gather_matmul(
+                    A_shard, Bs, gather_dim=gather_dim, group_name=group.group_name
+                )
+                ag_output_vec.append(ag_output)
+                mm_outputs_vec.append(mm_outputs)

-        assert torch.allclose(ag_output_0, ag_output_1)
-        assert ag_output_0.stride() == ag_output_1.stride()
-        for mm_output_0, mm_output_1 in zip(mm_outputs_0, mm_outputs_1):
+        assert torch.allclose(ag_output_vec[0], ag_output_vec[1])
+        assert ag_output_vec[0].stride() == ag_output_vec[1].stride()
+        for mm_output_0, mm_output_1 in zip(mm_outputs_vec[0], mm_outputs_vec[1]):
             assert torch.allclose(mm_output_0, mm_output_1)
-            assert mm_output_0.stride(), mm_output_1.stride()
+            assert mm_output_0.stride() == mm_output_1.stride()

         dist.destroy_process_group()

@@ -418,9 +421,10 @@ def test_fused_all_gather_matmul_native(
         else:
             B = torch.rand(N, K, dtype=torch.bfloat16, device="cuda").t()

-        ag_baseline, mm_baseline = _fused_all_gather_matmul_fallback(
-            A_shard, [B], gather_dim=0, group_name=group_name
-        )
+        with _test_mode():
+            ag_baseline, mm_baseline = torch.ops.symm_mem.fused_all_gather_matmul(
+                A_shard, [B], gather_dim=0, group_name=group_name
+            )
         with torch.profiler.profile(
             activities=[
                 torch.profiler.ProfilerActivity.CUDA,
@@ -458,9 +462,11 @@ def test_multimem_all_gather_matmul(self) -> None:

         B = torch.rand(K, N, dtype=torch.bfloat16, device="cuda")

-        ag_baseline, mm_baseline = _fused_all_gather_matmul_fallback(
-            A_shard, [B], gather_dim=0, group_name=group_name, return_A=False
-        )
+        with _test_mode():
+            ag_baseline, mm_baseline = torch.ops.symm_mem.fused_all_gather_matmul(
+                A_shard, [B], gather_dim=0, group_name=group_name, return_A=False
+            )
+
         with torch.profiler.profile(
             activities=[
                 torch.profiler.ProfilerActivity.CUDA,
@@ -524,39 +530,36 @@ def test_fused_all_gather_scaled_matmul(
         else:
             raise AssertionError(f"Invalid scale_mode: {scale_mode}")

-        ag_output_0, mm_outputs_0 = _fused_all_gather_scaled_matmul_fallback(
-            A_shard,
-            Bs,
-            A_scale,
-            B_scales,
-            gather_dim=gather_dim,
-            group_name=group.group_name,
-            biases=[None] * len(Bs),
-            result_scales=[None] * len(Bs),
-            out_dtypes=out_dtypes,
-            use_fast_accum=[None] * len(Bs),
-        )
-        ag_output_1, mm_outputs_1 = torch.ops.symm_mem.fused_all_gather_scaled_matmul(
-            A_shard,
-            Bs,
-            A_scale,
-            B_scales,
-            gather_dim=gather_dim,
-            group_name=group.group_name,
-            biases=[None] * len(Bs),
-            result_scales=[None] * len(Bs),
-            out_dtypes=out_dtypes,
-            use_fast_accum=[None] * len(Bs),
-        )
+        ag_output_vec = []
+        mm_outputs_vec = []
+        for context in test_contexts:
+            with context():
+                (
+                    ag_output,
+                    mm_outputs,
+                ) = torch.ops.symm_mem.fused_all_gather_scaled_matmul(
+                    A_shard,
+                    Bs,
+                    A_scale,
+                    B_scales,
+                    gather_dim=gather_dim,
+                    group_name=group.group_name,
+                    biases=[None] * len(Bs),
+                    result_scales=[None] * len(Bs),
+                    out_dtypes=out_dtypes,
+                    use_fast_accum=[None] * len(Bs),
+                )
+                ag_output_vec.append(ag_output)
+                mm_outputs_vec.append(mm_outputs)

         self.assertTrue(
             torch.allclose(
-                ag_output_0.to(torch.float32),
-                ag_output_1.to(torch.float32),
+                ag_output_vec[0].to(torch.float32),
+                ag_output_vec[1].to(torch.float32),
             )
         )
-        self.assertEqual(ag_output_0.stride(), ag_output_1.stride())
-        for mm_output_0, mm_output_1 in zip(mm_outputs_0, mm_outputs_1):
+        self.assertEqual(ag_output_vec[0].stride(), ag_output_vec[1].stride())
+        for mm_output_0, mm_output_1 in zip(mm_outputs_vec[0], mm_outputs_vec[1]):
             self.assertTrue(
                 torch.allclose(
                     mm_output_0.to(torch.float32), mm_output_1.to(torch.float32)
@@ -584,15 +587,21 @@ def test_fused_matmul_reduce_scatter(self, scatter_dim: int) -> None:
         A = torch.rand(BATCH, M, K, device="cuda")
         B = torch.rand(K, N, device="cuda")

-        output_0 = _fused_matmul_reduce_scatter_fallback(
-            A, B, "avg", scatter_dim=scatter_dim, group_name=group.group_name
-        )
-        output_1 = torch.ops.symm_mem.fused_matmul_reduce_scatter(
-            A, B, "avg", scatter_dim=scatter_dim, group_name=group.group_name
-        )
+        outputs = []
+        for context in test_contexts:
+            with context():
+                outputs.append(
+                    torch.ops.symm_mem.fused_matmul_reduce_scatter(
+                        A,
+                        B,
+                        "avg",
+                        scatter_dim=scatter_dim,
+                        group_name=group.group_name,
+                    )
+                )

-        assert torch.allclose(output_0, output_1)
-        assert output_0.stride() == output_1.stride()
+        assert torch.allclose(outputs[0], outputs[1])
+        assert outputs[0].stride() == outputs[1].stride()

         dist.destroy_process_group()