[SymmMem] Add all-to-all (#151498) · pytorch/pytorch@d7961a1 · GitHub

Commit d7961a1

kwen2501 authored and pytorchmergebot committed
[SymmMem] Add all-to-all (#151498)
Add an all-to-all impl based on NVSHMEM's on-stream API `nvshmemx_alltoallmem_on_stream`.

Pull Request resolved: #151498
Approved by: https://github.com/fegin, https://github.com/fduwjj
ghstack dependencies: #151261
1 parent 7c3e679 commit d7961a1
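As a quick orientation before the diffs, here is a minimal usage sketch distilled from the test added in this commit. It assumes a process group is already initialized (e.g. launched via torchrun with one CUDA device per rank) and that `TORCH_SYMMMEM=NVSHMEM` is set; apart from that, only APIs exercised in the diff below are used.

    # Minimal sketch of calling the new op (assumes an initialized process group,
    # one CUDA device per rank, and TORCH_SYMMMEM=NVSHMEM).
    import torch
    import torch.distributed as dist
    import torch.distributed._symmetric_memory as symm_mem

    group_name = dist.group.WORLD.group_name
    symm_mem.enable_symm_mem_for_group(group_name)

    rank = dist.get_rank()
    world_size = dist.get_world_size()
    numel_per_peer = 10
    numel = world_size * numel_per_peer
    device = torch.device("cuda", rank)

    # Both buffers must be symmetric-memory allocations and rendezvoused first.
    inp = symm_mem.empty(numel, dtype=torch.float, device=device).fill_(rank)
    out = symm_mem.empty(numel, dtype=torch.float, device=device).fill_(-1)
    symm_mem.rendezvous(inp, group=group_name)
    symm_mem.rendezvous(out, group=group_name)

    # Exchange: rank r sends its i-th chunk to rank i and receives rank i's r-th chunk.
    torch.ops.symm_mem.nvshmem_all_to_all(inp, out, group_name)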

File tree: 4 files changed, +141 −0 lines changed

test/distributed/test_nvshmem.py

Lines changed: 115 additions & 0 deletions
@@ -0,0 +1,115 @@
# Owner(s): ["oncall: distributed"]

# To run:
# TORCH_SYMMMEM=NVSHMEM python test/distributed/test_nvshmem.py
# OR
# TORCH_SYMMMEM=NVSHMEM torchrun --nproc-per-node 4 test/distributed/test_nvshmem.py

import os
import sys
import tempfile

import torch
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem
from torch.testing._internal.common_distributed import (
    MultiProcContinousTest,
    TEST_SKIPS,
)
from torch.testing._internal.common_utils import (
    skip_but_pass_in_sandcastle_if,
    skipIfRocm,
)


symm_mem_backend = os.getenv("TORCH_SYMMMEM")

if symm_mem_backend != "NVSHMEM":
    print(
        "test_nvshmem requires setting `TORCH_SYMMMEM=NVSHMEM`, skipping tests",
        file=sys.stderr,
    )
    sys.exit(0)


# Decorator
def requires_nvshmem():
    return skip_but_pass_in_sandcastle_if(
        symm_mem_backend != "NVSHMEM",
        "test_nvshmem requires setting `TORCH_SYMMMEM=NVSHMEM`",
    )


# So that tests are written in device-agnostic way
device_type = "cuda"
device_module = torch.get_device_module(device_type)


@requires_nvshmem()
class NVSHMEMSymmetricMemoryTest(MultiProcContinousTest):
    def setUp(self) -> None:
        super().setUp()
        # TODO: relieve this (seems to hang if without)
        device_module.set_device(self.device)
        # NOTE: required for nvshmem allocation
        torch.empty(1, device=self.device)

    # Required by MultiProcContinousTest
    @classmethod
    def backend_str(cls) -> str:
        return "nccl"

    @property
    def world_size(self) -> int:
        return device_module.device_count()

    @property
    def device(self) -> torch.device:
        return torch.device(device_type, self.rank)

    @skipIfRocm
    def test_nvshmem_all_to_all(self) -> None:
        group_name = dist.group.WORLD.group_name
        symm_mem.enable_symm_mem_for_group(group_name)

        dtype = torch.float
        numel_per_peer = 10
        numel = self.world_size * numel_per_peer
        inp = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(self.rank)
        out = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(-1)

        symm_mem.rendezvous(inp, group=group_name)
        symm_mem.rendezvous(out, group=group_name)
        torch.ops.symm_mem.nvshmem_all_to_all(inp, out, group_name)

        expected = torch.cat(
            [
                torch.empty(numel_per_peer, dtype=dtype, device=self.device).fill_(i)
                for i in range(self.world_size)
            ]
        )
        torch.testing.assert_close(out, expected)


if __name__ == "__main__":
    if not device_module.is_available():
        sys.exit(TEST_SKIPS["no_cuda"].exit_code)

    # If launched by torchrun, these values would have been set
    rank = int(os.getenv("RANK", "-1"))
    world_size = int(os.getenv("WORLD_SIZE", "-1"))

    if rank != -1:
        # Launched with torchrun or other multi-proc launchers. Directly run the test.
        NVSHMEMSymmetricMemoryTest.run_rank(rank, world_size)
    else:
        # No external launcher, spawn N processes
        world_size = device_module.device_count()
        # Launched as a single process. Spawn subprocess to run the tests.
        # Also need a rendezvous file for `init_process_group` purpose.
        rdvz_file = tempfile.NamedTemporaryFile(delete=False).name
        torch.multiprocessing.spawn(
            NVSHMEMSymmetricMemoryTest.run_rank,
            nprocs=world_size,
            args=(world_size, rdvz_file),
        )
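For a concrete picture of what the test asserts: each rank fills its input with its own rank id, the flat buffer is split into world_size equal chunks, chunk i is delivered to rank i, and every rank's output becomes the concatenation of one chunk from each peer. A small single-process simulation of that data movement (illustration only, not part of the commit), using world_size = 2 and numel_per_peer = 2:

    # Single-process simulation of the expected all-to-all result (illustration only).
    import torch

    world_size, numel_per_peer = 2, 2
    inputs = [torch.full((world_size * numel_per_peer,), float(r)) for r in range(world_size)]

    # Rank src's chunk dst is delivered to rank dst's chunk src.
    chunks = [inp.chunk(world_size) for inp in inputs]
    outputs = [
        torch.cat([chunks[src][dst] for src in range(world_size)])
        for dst in range(world_size)
    ]

    print(outputs[0])  # tensor([0., 0., 1., 1.]) on every rank, matching `expected` in the test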

torch/csrc/distributed/c10d/SymmetricMemory.cpp

Lines changed: 2 additions & 0 deletions
@@ -276,6 +276,8 @@ TORCH_LIBRARY_FRAGMENT(symm_mem, m) {
       "memset32_(Tensor(a!) input, int offset, int val, int count) -> Tensor(a!)");
 
   m.def("nvshmem_broadcast(Tensor(a!) input, str group_name) -> Tensor(a!)");
+  m.def(
+      "nvshmem_all_to_all(Tensor input, Tensor(a!) out, str group_name) -> Tensor(a!)");
 }
 
 TORCH_LIBRARY_IMPL(symm_mem, Meta, m) {
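In the schema string above, the `(a!)` annotation on `out` (and on the return type) marks it as mutated in place and aliased by the returned tensor, while `input` is read-only. A quick way to inspect the registered schema from Python, assuming a build in which the `symm_mem` operator fragment is compiled in (the attribute path below is the standard torch.ops one, not something shown in this diff):

    import torch

    # Requires a PyTorch build where the symm_mem library fragment is registered.
    op = torch.ops.symm_mem.nvshmem_all_to_all.default
    print(op._schema)
    # symm_mem::nvshmem_all_to_all(Tensor input, Tensor(a!) out, str group_name) -> Tensor(a!)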

torch/csrc/distributed/c10d/nvshmem_extension.cu

Lines changed: 19 additions & 0 deletions
@@ -117,10 +117,29 @@ at::Tensor nvshmem_broadcast(at::Tensor& input, const std::string& group_name) {
   return input;
 }
 
+at::Tensor nvshmem_all_to_all(
+    at::Tensor& input,
+    at::Tensor& out,
+    std::string group_name) {
+  auto input_hdl = c10d::symmetric_memory::rendezvous(input, group_name);
+  auto out_hdl = c10d::symmetric_memory::rendezvous(out, group_name);
+  int rank = input_hdl->get_rank();
+  int world_size = input_hdl->get_world_size();
+  auto team = group_to_team(group_name, input_hdl->get_rank_to_global_rank());
+
+  void* input_ptr = input_hdl->get_buffer_ptrs()[rank];
+  void* output_ptr = out_hdl->get_buffer_ptrs()[rank];
+  size_t bytes_per_rank = input_hdl->get_buffer_size() / world_size;
+
+  auto stream = at::cuda::getCurrentCUDAStream(input.device().index());
+  nvshmemx_alltoallmem_on_stream(team, output_ptr, input_ptr, bytes_per_rank, stream);
+  return out;
+}
 
 } // namespace c10d::nvshmem_extension
 
 
 TORCH_LIBRARY_IMPL(symm_mem, CUDA, m) {
   m.impl("nvshmem_broadcast", c10d::nvshmem_extension::nvshmem_broadcast);
+  m.impl("nvshmem_all_to_all", c10d::nvshmem_extension::nvshmem_all_to_all);
 }
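Note the split size handed to NVSHMEM: `bytes_per_rank` is `get_buffer_size() / world_size`, so the implementation assumes each rank's symmetric buffer is evenly partitioned into world_size contiguous chunks (as in the test, where numel = world_size * numel_per_peer), and the exchange is enqueued on the current CUDA stream. A small arithmetic sketch of how that per-peer byte count relates to the test's sizing (illustration only; the buffer-size equality is an assumption, since the real value comes from the symmetric-memory handle):

    import torch

    # Assumes the symmetric buffer is exactly numel * element_size bytes, with no padding.
    world_size = 4
    numel_per_peer = 10
    dtype = torch.float

    elem_size = torch.empty((), dtype=dtype).element_size()  # 4 bytes for float32
    numel = world_size * numel_per_peer
    buffer_bytes = numel * elem_size             # stands in for input_hdl->get_buffer_size()
    bytes_per_rank = buffer_bytes // world_size  # the chunk each peer sends and receives
    assert bytes_per_rank == numel_per_peer * elem_size
    print(bytes_per_rank)  # 40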

torch/csrc/distributed/c10d/nvshmem_extension.cuh

Lines changed: 5 additions & 0 deletions
@@ -17,4 +17,9 @@ void* nvshmem_ptr(const void* dest, int pe);
 
 at::Tensor nvshmem_broadcast(at::Tensor& input, const std::string& group_name);
 
+at::Tensor nvshmem_all_to_all(
+    at::Tensor& input,
+    at::Tensor& out,
+    std::string group_name);
+
 } // namespace c10d::nvshmem_extension

0 commit comments
