13 | 13 | from torch._inductor.utils import fresh_inductor_cache, run_and_get_triton_code
14 | 14 | from torch.distributed._functional_collectives import all_gather_tensor
15 | 15 | from torch.distributed._symmetric_memory import (
16 |    | -     _fused_all_gather_matmul_fallback,
17 |    | -     _fused_all_gather_scaled_matmul_fallback,
18 |    | -     _fused_matmul_reduce_scatter_fallback,
19 | 16 |     _test_mode,
20 | 17 |     enable_symm_mem_for_group,
21 | 18 |     restride_A_for_fused_matmul_reduce_scatter,

41 | 38 |     TestCase,
42 | 39 | )
43 | 40 |
44 |    | -
   | 41 | +# When nullcontext is used, the dispatcher will dispatch the fused kernel calls to
   | 42 | +# the symmetric memory implementation. If _test_mode is enabled, the dispatcher will
   | 43 | +# dispatch the fused kernel calls to the "fallback" implementation, where symm_mem is
   | 44 | +# not used and the regular collectives and pytorch ops are used instead.
45 | 45 | test_contexts = [nullcontext, _test_mode]
46 | 46 |
47 | 47 |
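For context, the hunks below all apply the same comparison pattern that the comment above describes. A minimal sketch of that pattern is shown here, outside the diff: the helper name is hypothetical (the tests inline this loop directly in each test method), and it assumes a CUDA process group with symmetric memory already enabled, as in the tests.

```python
# Illustrative sketch only (not part of the PR): run the same fused op once under
# nullcontext (symmetric-memory implementation) and once under _test_mode
# (fallback implementation using regular collectives), then compare the results.
from contextlib import nullcontext

import torch
from torch.distributed._symmetric_memory import _test_mode

test_contexts = [nullcontext, _test_mode]


def compare_fused_all_gather_matmul(A_shard, Bs, gather_dim, group_name):
    # Hypothetical helper; assumes dist.init_process_group() and
    # enable_symm_mem_for_group(group_name) have already been called.
    results = []
    for context in test_contexts:
        with context():
            ag_output, mm_outputs = torch.ops.symm_mem.fused_all_gather_matmul(
                A_shard, Bs, gather_dim=gather_dim, group_name=group_name
            )
        results.append((ag_output, mm_outputs))

    # Both paths should produce numerically matching outputs with identical strides.
    (ag_0, mms_0), (ag_1, mms_1) = results
    assert torch.allclose(ag_0, ag_1)
    assert ag_0.stride() == ag_1.stride()
    for mm_0, mm_1 in zip(mms_0, mms_1):
        assert torch.allclose(mm_0, mm_1)
        assert mm_0.stride() == mm_1.stride()
```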
@@ -364,18 +364,21 @@ def test_fused_all_gather_matmul(self, gather_dim: int) -> None:
364 | 364 |         A_shard = torch.rand(BATCH, M // self.world_size, K, device="cuda")
365 | 365 |         Bs = [torch.rand(K, N, device="cuda") for _ in range(3)]
366 | 366 |
367 |     | -         ag_output_0, mm_outputs_0 = _fused_all_gather_matmul_fallback(
368 |     | -             A_shard, Bs, gather_dim=gather_dim, group_name=group.group_name
369 |     | -         )
370 |     | -         ag_output_1, mm_outputs_1 = torch.ops.symm_mem.fused_all_gather_matmul(
371 |     | -             A_shard, Bs, gather_dim=gather_dim, group_name=group.group_name
372 |     | -         )
    | 367 | +         ag_output_vec = []
    | 368 | +         mm_outputs_vec = []
    | 369 | +         for context in test_contexts:
    | 370 | +             with context():
    | 371 | +                 ag_output, mm_outputs = torch.ops.symm_mem.fused_all_gather_matmul(
    | 372 | +                     A_shard, Bs, gather_dim=gather_dim, group_name=group.group_name
    | 373 | +                 )
    | 374 | +             ag_output_vec.append(ag_output)
    | 375 | +             mm_outputs_vec.append(mm_outputs)
373 | 376 |
374 |     | -         assert torch.allclose(ag_output_0, ag_output_1)
375 |     | -         assert ag_output_0.stride() == ag_output_1.stride()
376 |     | -         for mm_output_0, mm_output_1 in zip(mm_outputs_0, mm_outputs_1):
    | 377 | +         assert torch.allclose(ag_output_vec[0], ag_output_vec[1])
    | 378 | +         assert ag_output_vec[0].stride() == ag_output_vec[1].stride()
    | 379 | +         for mm_output_0, mm_output_1 in zip(mm_outputs_vec[0], mm_outputs_vec[1]):
377 | 380 |             assert torch.allclose(mm_output_0, mm_output_1)
378 |     | -             assert mm_output_0.stride(), mm_output_1.stride()
    | 381 | +             assert mm_output_0.stride() == mm_output_1.stride()
379 | 382 |
380 | 383 |         dist.destroy_process_group()
381 | 384 |

@@ -418,9 +421,10 @@ def test_fused_all_gather_matmul_native(
418 | 421 |         else:
419 | 422 |             B = torch.rand(N, K, dtype=torch.bfloat16, device="cuda").t()
420 | 423 |
421 |     | -         ag_baseline, mm_baseline = _fused_all_gather_matmul_fallback(
422 |     | -             A_shard, [B], gather_dim=0, group_name=group_name
423 |     | -         )
    | 424 | +         with _test_mode():
    | 425 | +             ag_baseline, mm_baseline = torch.ops.symm_mem.fused_all_gather_matmul(
    | 426 | +                 A_shard, [B], gather_dim=0, group_name=group_name
    | 427 | +             )
424 | 428 |         with torch.profiler.profile(
425 | 429 |             activities=[
426 | 430 |                 torch.profiler.ProfilerActivity.CUDA,

@@ -458,9 +462,11 @@ def test_multimem_all_gather_matmul(self) -> None:
458 | 462 |
459 | 463 |         B = torch.rand(K, N, dtype=torch.bfloat16, device="cuda")
460 | 464 |
461 |     | -         ag_baseline, mm_baseline = _fused_all_gather_matmul_fallback(
462 |     | -             A_shard, [B], gather_dim=0, group_name=group_name, return_A=False
463 |     | -         )
    | 465 | +         with _test_mode():
    | 466 | +             ag_baseline, mm_baseline = torch.ops.symm_mem.fused_all_gather_matmul(
    | 467 | +                 A_shard, [B], gather_dim=0, group_name=group_name, return_A=False
    | 468 | +             )
    | 469 | +
464 | 470 |         with torch.profiler.profile(
465 | 471 |             activities=[
466 | 472 |                 torch.profiler.ProfilerActivity.CUDA,

@@ -524,39 +530,36 @@ def test_fused_all_gather_scaled_matmul(
524 | 530 |         else:
525 | 531 |             raise AssertionError(f"Invalid scale_mode: {scale_mode}")
526 | 532 |
527 |     | -         ag_output_0, mm_outputs_0 = _fused_all_gather_scaled_matmul_fallback(
528 |     | -             A_shard,
529 |     | -             Bs,
530 |     | -             A_scale,
531 |     | -             B_scales,
532 |     | -             gather_dim=gather_dim,
533 |     | -             group_name=group.group_name,
534 |     | -             biases=[None] * len(Bs),
535 |     | -             result_scales=[None] * len(Bs),
536 |     | -             out_dtypes=out_dtypes,
537 |     | -             use_fast_accum=[None] * len(Bs),
538 |     | -         )
539 |     | -         ag_output_1, mm_outputs_1 = torch.ops.symm_mem.fused_all_gather_scaled_matmul(
540 |     | -             A_shard,
541 |     | -             Bs,
542 |     | -             A_scale,
543 |     | -             B_scales,
544 |     | -             gather_dim=gather_dim,
545 |     | -             group_name=group.group_name,
546 |     | -             biases=[None] * len(Bs),
547 |     | -             result_scales=[None] * len(Bs),
548 |     | -             out_dtypes=out_dtypes,
549 |     | -             use_fast_accum=[None] * len(Bs),
550 |     | -         )
    | 533 | +         ag_output_vec = []
    | 534 | +         mm_outputs_vec = []
    | 535 | +         for context in test_contexts:
    | 536 | +             with context():
    | 537 | +                 (
    | 538 | +                     ag_output,
    | 539 | +                     mm_outputs,
    | 540 | +                 ) = torch.ops.symm_mem.fused_all_gather_scaled_matmul(
    | 541 | +                     A_shard,
    | 542 | +                     Bs,
    | 543 | +                     A_scale,
    | 544 | +                     B_scales,
    | 545 | +                     gather_dim=gather_dim,
    | 546 | +                     group_name=group.group_name,
    | 547 | +                     biases=[None] * len(Bs),
    | 548 | +                     result_scales=[None] * len(Bs),
    | 549 | +                     out_dtypes=out_dtypes,
    | 550 | +                     use_fast_accum=[None] * len(Bs),
    | 551 | +                 )
    | 552 | +             ag_output_vec.append(ag_output)
    | 553 | +             mm_outputs_vec.append(mm_outputs)
551 | 554 |
552 | 555 |         self.assertTrue(
553 | 556 |             torch.allclose(
554 |     | -                 ag_output_0.to(torch.float32),
555 |     | -                 ag_output_1.to(torch.float32),
    | 557 | +                 ag_output_vec[0].to(torch.float32),
    | 558 | +                 ag_output_vec[1].to(torch.float32),
556 | 559 |             )
557 | 560 |         )
558 |     | -         self.assertEqual(ag_output_0.stride(), ag_output_1.stride())
559 |     | -         for mm_output_0, mm_output_1 in zip(mm_outputs_0, mm_outputs_1):
    | 561 | +         self.assertEqual(ag_output_vec[0].stride(), ag_output_vec[1].stride())
    | 562 | +         for mm_output_0, mm_output_1 in zip(mm_outputs_vec[0], mm_outputs_vec[1]):
560 | 563 |             self.assertTrue(
561 | 564 |                 torch.allclose(
562 | 565 |                     mm_output_0.to(torch.float32), mm_output_1.to(torch.float32)

@@ -584,15 +587,21 @@ def test_fused_matmul_reduce_scatter(self, scatter_dim: int) -> None:
584 | 587 |         A = torch.rand(BATCH, M, K, device="cuda")
585 | 588 |         B = torch.rand(K, N, device="cuda")
586 | 589 |
587 |     | -         output_0 = _fused_matmul_reduce_scatter_fallback(
588 |     | -             A, B, "avg", scatter_dim=scatter_dim, group_name=group.group_name
589 |     | -         )
590 |     | -         output_1 = torch.ops.symm_mem.fused_matmul_reduce_scatter(
591 |     | -             A, B, "avg", scatter_dim=scatter_dim, group_name=group.group_name
592 |     | -         )
    | 590 | +         outputs = []
    | 591 | +         for context in test_contexts:
    | 592 | +             with context():
    | 593 | +                 outputs.append(
    | 594 | +                     torch.ops.symm_mem.fused_matmul_reduce_scatter(
    | 595 | +                         A,
    | 596 | +                         B,
    | 597 | +                         "avg",
    | 598 | +                         scatter_dim=scatter_dim,
    | 599 | +                         group_name=group.group_name,
    | 600 | +                     )
    | 601 | +                 )
593 | 602 |
594 |     | -         assert torch.allclose(output_0, output_1)
595 |     | -         assert output_0.stride() == output_1.stride()
    | 603 | +         assert torch.allclose(outputs[0], outputs[1])
    | 604 | +         assert outputs[0].stride() == outputs[1].stride()
596 | 605 |
597 | 606 |         dist.destroy_process_group()
598 | 607 |