Commit fd596f4 · pytorch/pytorch

Update
[ghstack-poisoned]
1 parent 741b5a8 commit fd596f4

4 files changed: +14 −8 lines changed


test/distributed/test_nvshmem.py

Lines changed: 6 additions & 4 deletions
@@ -13,6 +13,8 @@
 import torch.distributed._symmetric_memory as symm_mem
 from torch.testing._internal.common_distributed import MultiProcContinousTest
 from torch.testing._internal.common_utils import (
+    instantiate_parametrized_tests,
+    parametrize,
     run_tests,
     skip_but_pass_in_sandcastle_if,
     skipIfRocm,
@@ -42,6 +44,7 @@ def requires_nvshmem():
 device_module = torch.get_device_module(device_type)
 
 
+@instantiate_parametrized_tests
 @requires_nvshmem()
 class NVSHMEMSymmetricMemoryTest(MultiProcContinousTest):
     def _init_device(self) -> None:
@@ -129,7 +132,8 @@ def test_nvshmem_all_to_all_vdev(self) -> None:
         torch.testing.assert_close(out[:out_numel], expected)
 
     @skipIfRocm
-    def test_nvshmem_all_to_all_vdev_2d(self) -> None:
+    @parametrize("align", [1, 8, 16])  # `major_align` of output
+    def test_nvshmem_all_to_all_vdev_2d(self, align: int) -> None:
         torch.manual_seed(42 + self.rank)
         self._init_device()
 
@@ -142,8 +146,6 @@ def test_nvshmem_all_to_all_vdev_2d(self) -> None:
         nsplits = ne * self.world_size
         # Number of elements for an expert is random between [0, k)
         k = 3
-        # Align
-        align = 16
         inp_splits = torch.randint(k, (nsplits,), device=self.device)
         inp_numel = inp_splits.sum().item()
         # Exchange input splits to get output splits
@@ -168,7 +170,7 @@ def test_nvshmem_all_to_all_vdev_2d(self) -> None:
         in_out_splits[0].copy_(inp_splits)
 
         torch.ops.symm_mem.nvshmem_all_to_all_vdev_2d(
-            inp, out, in_out_splits, group_name, align
+            inp, out, in_out_splits, group_name, major_align=align
        )
         received_out_splits = in_out_splits[1]
         received_out_offsets = in_out_splits[2]
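Side note on the test change: `@parametrize` plus `@instantiate_parametrized_tests` expands the single test body into one concrete test per `align` value. A minimal standalone sketch of that mechanism (the `AlignDemoTest` class and its assertion are illustrative, not part of this commit):

# Illustrative sketch (not part of this commit): how the parametrized
# test expands. Each value in the list becomes its own test case,
# named roughly test_align_align_1, test_align_align_8, ...
from torch.testing._internal.common_utils import (
    TestCase,
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
)


@instantiate_parametrized_tests
class AlignDemoTest(TestCase):
    @parametrize("align", [1, 8, 16])
    def test_align(self, align: int) -> None:
        self.assertGreater(align, 0)


if __name__ == "__main__":
    run_tests()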

torch/csrc/distributed/c10d/SymmetricMemory.cpp

Lines changed: 1 addition & 1 deletion
@@ -281,7 +281,7 @@ TORCH_LIBRARY_FRAGMENT(symm_mem, m) {
   m.def(
       "nvshmem_all_to_all_vdev(Tensor input, Tensor(a!) out, Tensor(a!) in_out_splits, str group_name) -> Tensor(a!)");
   m.def(
-      "nvshmem_all_to_all_vdev_2d(Tensor input, Tensor(a!) out, Tensor(a!) in_out_splits, str group_name, int major_align) -> Tensor(a!)");
+      "nvshmem_all_to_all_vdev_2d(Tensor input, Tensor(a!) out, Tensor(a!) in_out_splits, str group_name, int? major_align=None) -> Tensor(a!)");
 }
 
 TORCH_LIBRARY_IMPL(symm_mem, Meta, m) {
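With the schema now declaring `int? major_align=None`, callers may omit the argument entirely; per the default applied in nvshmem_extension.cu below, omitting it behaves as an alignment of 1. A hedged usage sketch, assuming `inp`, `out`, `in_out_splits`, and `group_name` are set up as in the test above:

# Usage sketch; `inp`, `out`, `in_out_splits`, and `group_name` are
# assumed to be set up as in test_nvshmem_all_to_all_vdev_2d above.

# Previously `major_align` was a required argument; now it can be
# omitted (equivalent to major_align=1):
torch.ops.symm_mem.nvshmem_all_to_all_vdev_2d(
    inp, out, in_out_splits, group_name
)

# Or passed explicitly as a keyword:
torch.ops.symm_mem.nvshmem_all_to_all_vdev_2d(
    inp, out, in_out_splits, group_name, major_align=16
)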

torch/csrc/distributed/c10d/nvshmem_extension.cu

Lines changed: 6 additions & 2 deletions
@@ -469,7 +469,7 @@ at::Tensor nvshmem_all_to_all_vdev_2d(
     at::Tensor& out,
     at::Tensor& in_out_splits,
     std::string group_name,
-    int64_t major_align) {
+    std::optional<int64_t> major_align) {
   /* Perform a 2D AllToAllv shuffle operation using NVSHMEM, with split information provided on device.
    * Arguments:
    *  - `input` is the input tensor
@@ -514,6 +514,10 @@ at::Tensor nvshmem_all_to_all_vdev_2d(
   // TODO: world_size is currently limited by the number of elements in a WarpScan.
   TORCH_CHECK(world_size <= A2AV_TILE_SIZE, "world_size must be smaller than A2AV_TILE_SIZE", A2AV_TILE_SIZE);
 
+  // If `major_align` is not provided, use 1 as the default value.
+  int64_t major_align_val = major_align.value_or(1);
+  TORCH_CHECK(major_align_val > 0, "major_align must be positive");
+
   void* input_ptr = input_hdl->get_buffer_ptrs()[rank];
   void* output_ptr = out_hdl->get_buffer_ptrs()[rank];
   int64_t* splits_ptr = (int64_t*)(splits_hdl->get_buffer_ptrs()[rank]);
@@ -579,7 +583,7 @@ at::Tensor nvshmem_all_to_all_vdev_2d(
       &rank,
       &world_size,
       &ne,
-      &major_align};
+      &major_align_val};
   nvshmemx_collective_launch(
       (const void*)allToAllV_2d,
       dim3(num_blocks),
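The host wrapper unwraps the optional with `value_or(1)` and hands the kernel a plain `int64_t`. For intuition, here is a rough host-side sketch of what aligning each chunk's output offset to a multiple of `major_align` could look like; this only illustrates the test's "`major_align` of output" comment, and the actual layout is computed on device inside `allToAllV_2d`:

# Illustrative sketch only: rounding each chunk's output offset up to a
# multiple of `major_align`. The real computation happens on device in
# the allToAllV_2d kernel; splits below are made-up example values.
def align_up(n: int, align: int) -> int:
    """Round `n` up to the nearest multiple of `align` (align >= 1)."""
    return (n + align - 1) // align * align

splits = [3, 5, 2]  # example per-chunk element counts
align = 8
offsets, cursor = [], 0
for s in splits:
    offsets.append(cursor)
    cursor = align_up(cursor + s, align)

print(offsets)  # [0, 8, 16] -- each chunk starts at an 8-aligned offset

With `align = 1` the loop degenerates to ordinary prefix sums of the splits, which matches the default behavior when `major_align` is omitted.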

torch/csrc/distributed/c10d/nvshmem_extension.cuh

Lines changed: 1 addition & 1 deletion
@@ -33,6 +33,6 @@ at::Tensor nvshmem_all_to_all_vdev_2d(
     at::Tensor& out,
     at::Tensor& in_out_splits,
     std::string group_name,
-    int64_t major_align = 1);
+    std::optional<int64_t> major_align = std::nullopt);
 
 } // namespace c10d::nvshmem_extension

0 commit comments