[a2av] 2D all-to-all-vdev · pytorch/pytorch@1f99ef1 · GitHub

Commit 1f99ef1

[a2av] 2D all-to-all-vdev

ghstack-source-id: 5d2a22a
Pull-Request-resolved: #155058

Add device guard

1 parent b146ebd commit 1f99ef1

File tree: 4 files changed, +243 -0 lines changed

test/distributed/test_nvshmem.py
torch/csrc/distributed/c10d/SymmetricMemory.cpp
torch/csrc/distributed/c10d/nvshmem_extension.cu
torch/csrc/distributed/c10d/nvshmem_extension.cuh
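
Before the per-file diffs, a quick orientation: the new op performs a rank-major to expert-major AllToAllv shuffle (see the docstring in nvshmem_extension.cu below). Here is a minimal single-process CPU sketch of the same semantics in plain torch ops; `all_to_all_vdev_2d_ref` is a hypothetical reference helper for illustration, not part of this commit, and the real op instead runs on symmetric memory via NVSHMEM.

import torch

# CPU reference for the 2D AllToAllv shuffle (illustration only).
# inputs: list of 1D tensors, one per rank
# splits: (world_size, world_size * ne) input splits, rank-major within a row
def all_to_all_vdev_2d_ref(inputs, splits, ne):
    world_size = len(inputs)
    # Chop each rank's input into per-(dst_rank, expert) chunks
    chunks = [list(torch.split(inputs[r], splits[r].tolist())) for r in range(world_size)]
    outputs = []
    for dst in range(world_size):
        # Expert-major on the destination: expert e first, then source rank
        out = [chunks[src][dst * ne + e] for e in range(ne) for src in range(world_size)]
        outputs.append(torch.cat(out))
    return outputs

# world_size = 2, ne = 2, every chunk of length 1: reproduces the
# c0 d0 c1 d1 / c2 d2 c3 d3 picture from the docstring below.
inputs = [torch.arange(4.0), torch.arange(10.0, 14.0)]
splits = torch.ones(2, 4, dtype=torch.int64)
print(all_to_all_vdev_2d_ref(inputs, splits, ne=2))
# [tensor([ 0., 10.,  1., 11.]), tensor([ 2., 12.,  3., 13.])]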

test/distributed/test_nvshmem.py
Lines changed: 74 additions & 0 deletions

@@ -128,6 +128,80 @@ def test_nvshmem_all_to_all_vdev(self) -> None:
         dist.all_to_all_single(expected, inp, out_splits.tolist(), inp_splits.tolist())
         torch.testing.assert_close(out[:out_numel], expected)
 
+    @skipIfRocm
+    def test_nvshmem_all_to_all_vdev_2d(self) -> None:
+        torch.manual_seed(42 + self.rank)
+        self._init_device()
+
+        group_name = dist.group.WORLD.group_name
+        symm_mem.enable_symm_mem_for_group(group_name)
+
+        dtype = torch.float
+        # Number of experts per rank
+        ne = 4
+        nsplits = ne * self.world_size
+        # Number of elements for an expert is random between [0, k)
+        k = 3
+        inp_splits = torch.randint(k, (nsplits,), device=self.device)
+        inp_numel = inp_splits.sum().item()
+        # Exchange input splits to get output splits
+        out_splits = torch.zeros_like(inp_splits)
+        dist.all_to_all_single(out_splits, inp_splits)
+        # We do a .t() here because there is a rank-major to expert-major shuffle
+        out_splits_t = out_splits.reshape(self.world_size, ne).t().reshape(-1)
+
+        # Total number of output elements
+        out_numel = out_splits.sum().item()
+        # Align up to make it bigger
+        align = 16
+        out_numel_max = (out_numel + align - 1) // align * align
+
+        inp = symm_mem.empty(inp_numel, dtype=dtype, device=self.device).fill_(
+            self.rank
+        )
+        out = symm_mem.empty(out_numel_max, dtype=dtype, device=self.device).fill_(-1)
+        in_out_splits = symm_mem.empty(
+            (3, nsplits), dtype=torch.int64, device=self.device
+        ).fill_(-1)
+        # Row 0 is input splits
+        in_out_splits[0].copy_(inp_splits)
+
+        torch.ops.symm_mem.nvshmem_all_to_all_vdev_2d(
+            inp, out, in_out_splits, group_name
+        )
+
+        # Check input splits (row 0) -- should not change
+        torch.testing.assert_close(in_out_splits[0], inp_splits)
+
+        # Check output splits (row 1)
+        torch.testing.assert_close(in_out_splits[1], out_splits_t)
+
+        # Check output offsets (row 2)
+        out_offsets = torch.cumsum(out_splits_t, dim=0)  # inclusive scan
+        # Output offsets from `nvshmem_all_to_all_vdev_2d` are an exclusive scan
+        self.assertEqual(in_out_splits[2][0], 0)
+        torch.testing.assert_close(in_out_splits[2][1:], out_offsets[:-1])
+
+        # Check data
+        expected = torch.empty(out_numel, dtype=dtype, device=self.device)
+        inp_splits_rank = inp_splits.reshape(self.world_size, ne).sum(1)
+        out_splits_rank = out_splits.reshape(self.world_size, ne).sum(1)
+        dist.all_to_all_single(
+            expected, inp, out_splits_rank.tolist(), inp_splits_rank.tolist()
+        )
+        # We still need to shuffle `expected` from rank-major to expert-major order
+        out_offsets = torch.cumsum(out_splits, dim=0)  # inclusive scan
+        result_list = []
+        for j in range(ne):
+            for i in range(self.world_size):
+                chunk_id = i * ne + j
+                offset = out_offsets[chunk_id]
+                chunk = expected[offset - out_splits[chunk_id] : offset]
+                result_list.append(chunk)
+
+        final = torch.cat(result_list)
+        torch.testing.assert_close(out[:out_numel], final)
+
 
 if __name__ == "__main__":
     run_tests()
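
Two conventions exercised by the checks above, shown standalone (illustration only, not part of the commit): the output splits come back transposed into expert-major order, and row 2 of `in_out_splits` holds an exclusive scan of those transposed splits.

import torch

# Expert-major transpose of the exchanged splits, as done with .t() in the test
world_size, ne = 2, 2
out_splits = torch.tensor([3, 1, 2, 4])  # rank-major: index = src_rank * ne + expert
out_splits_t = out_splits.reshape(world_size, ne).t().reshape(-1)
print(out_splits_t)  # tensor([3, 2, 1, 4]) -- index = expert * world_size + src_rank

# Row 2 of in_out_splits is the exclusive scan of the expert-major splits
inclusive = torch.cumsum(out_splits_t, dim=0)  # tensor([ 3,  5,  6, 10])
exclusive = torch.cat([inclusive.new_zeros(1), inclusive[:-1]])
print(exclusive)  # tensor([0, 3, 5, 6]) -- matches the in_out_splits[2] checks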

torch/csrc/distributed/c10d/SymmetricMemory.cpp
Lines changed: 2 additions & 0 deletions

@@ -280,6 +280,8 @@ TORCH_LIBRARY_FRAGMENT(symm_mem, m) {
       "nvshmem_all_to_all(Tensor input, Tensor(a!) out, str group_name) -> Tensor(a!)");
   m.def(
       "nvshmem_all_to_all_vdev(Tensor input, Tensor(a!) out, Tensor(a!) in_out_splits, str group_name) -> Tensor(a!)");
+  m.def(
+      "nvshmem_all_to_all_vdev_2d(Tensor input, Tensor(a!) out, Tensor(a!) in_out_splits, str group_name) -> Tensor(a!)");
 }
 
 TORCH_LIBRARY_IMPL(symm_mem, Meta, m) {
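
The schema added above is what the test calls as `torch.ops.symm_mem.nvshmem_all_to_all_vdev_2d`. A condensed usage sketch, assuming an NVSHMEM-enabled build, a `torchrun` launch with one GPU per rank, and the `torch.distributed._symmetric_memory` Python module that the test imports as `symm_mem` (illustration, not part of this diff):

import torch
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem  # assumed import path

dist.init_process_group("nccl")
rank, world_size = dist.get_rank(), dist.get_world_size()
device = torch.device("cuda", rank)
torch.cuda.set_device(device)

group_name = dist.group.WORLD.group_name
symm_mem.enable_symm_mem_for_group(group_name)

ne = 4                         # experts per rank
nsplits = ne * world_size
# One element per (dst_rank, expert) chunk keeps sizes static across ranks
inp = symm_mem.empty(nsplits, dtype=torch.float, device=device).fill_(rank)
out = symm_mem.empty(nsplits, dtype=torch.float, device=device).fill_(-1)
in_out_splits = symm_mem.empty((3, nsplits), dtype=torch.int64, device=device)
in_out_splits[0].fill_(1)      # row 0: input splits (IN)

torch.ops.symm_mem.nvshmem_all_to_all_vdev_2d(inp, out, in_out_splits, group_name)
# Row 1 now holds output splits; row 2 holds exclusive-scan output offsets.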

torch/csrc/distributed/c10d/nvshmem_extension.cu
Lines changed: 161 additions & 0 deletions

@@ -311,11 +311,172 @@ at::Tensor nvshmem_all_to_all_vdev(
   return out;
 }
 
+// Start of `nvshmem_all_to_all_vdev_2d`
+// This kernel is used to exchange output splits and source offsets between peers.
+// For the meaning of `mype` and `npes`, see the docstring of `nvshmem_all_to_all_vdev_2d`.
+// `in_out_splits` is of size (3, npes * ne) and contains:
+// - input splits (IN)
+// - output splits (OUT) and
+// - source offsets (OUT).
+__global__ void exchangeSplitAndOffset_2d(int64_t* in_out_splits, int mype, int npes, int ne) {
+  int nsplits = npes * ne;
+  auto input_splits = in_out_splits;
+  auto output_splits = in_out_splits + nsplits;
+  auto source_offsets = in_out_splits + nsplits * 2;
+  int tid = threadIdx.x;
+
+  __shared__ int64_t peer_offsets[THREADS_PER_BLOCK];
+
+  // Scan input splits to get the source offsets
+  prefixSum(peer_offsets, input_splits, nsplits);
+  __syncthreads();
+
+  // Use 1 block to do the exchange
+  if (tid < nsplits) {
+    int peer = tid / ne;
+    int e = tid % ne;
+    // This does a transpose from rank-major order to expert-major order
+    int dst_offset = e * npes + mype;
+    nvshmem_int64_p(source_offsets + dst_offset, peer_offsets[tid], peer);
+    nvshmem_int64_p(output_splits + dst_offset, input_splits[tid], peer);
+  }
+  // This barrier ensures that all remote PEs see the updated values
+  nvshmemx_barrier_all_block();
+}
+
+// This kernel is used to do the actual data exchange.
+// `in_out_splits` has the same definition as in `exchangeSplitAndOffset_2d`.
+// `stride` is the stride at dim 0, in bytes.
+// For the meaning of `mype` and `npes`, see the docstring of `nvshmem_all_to_all_vdev_2d`.
+__global__ void allToAllV_2d(void *send_data, void *recv_data, int64_t* in_out_splits, size_t stride, int mype, int npes, int ne) {
+  int nsplits = npes * ne;
+  auto output_splits = in_out_splits + nsplits;
+  auto source_offsets = in_out_splits + nsplits * 2;
+  int bid = blockIdx.x;
+  int tid = threadIdx.x;
+
+  // Calculate the output offsets
+  __shared__ int64_t e_offsets[THREADS_PER_BLOCK];
+  prefixSum(e_offsets, output_splits, nsplits);
+  __syncthreads();
+
+  // Target a different e based on bid
+  for (int eid = bid; eid < nsplits; eid += gridDim.x) {
+    int peer = eid % npes;
+    // Amount from `peer` for `e`
+    auto peer_size = output_splits[eid] * stride;
+    auto source_offset = source_offsets[eid] * stride;
+    auto write_offset = e_offsets[eid] * stride;
+    nvshmemx_getmem_block(
+        (char*)recv_data + write_offset,
+        (char*)send_data + source_offset,
+        peer_size,
+        peer);
+  }
+  // Write out the output offsets (to the scratchpad line)
+  if (bid == 0 && tid < nsplits) {
+    source_offsets[tid] = e_offsets[tid];
+  }
+}
+
+at::Tensor nvshmem_all_to_all_vdev_2d(
+    at::Tensor& input,
+    at::Tensor& out,
+    at::Tensor& in_out_splits,
+    std::string group_name) {
+  /* Perform a 2D AllToAllv shuffle operation using NVSHMEM, with split information provided on device.
+   * Arguments:
+   *  - `input` is the input tensor
+   *  - `out` is the output tensor
+   *  - `in_out_splits` is a 2D tensor of size (3, `world_size` * `ne`). In the
+        scenario of Mixture-of-Experts models, `ne` is the number of experts per
+        rank. The rows of `in_out_splits` are (in order):
+          input splits (IN)
+          output splits (OUT) and
+          output offsets (OUT).
+   *  - `group_name` is the name of the group to use for the collective operation.
+
+   * A 2D AllToAllv shuffle is illustrated below:
+       (world_size = 2, ne = 2, total number of experts = 4)
+       Source: |       Rank 0      |       Rank 1      |
+               | c0 | c1 | c2 | c3 | d0 | d1 | d2 | d3 |
+
+       Dest  : |       Rank 0      |       Rank 1      |
+               | c0 | d0 | c1 | d1 | c2 | d2 | c3 | d3 |
+     where each `c_i` / `d_i` is a slice of the `input` tensor, targeting
+     expert `i`, with length indicated by the input splits (in
+     `in_out_splits[0]`). That is, the 2D AllToAllv shuffle achieves a
+     transpose from rank-major order at input to expert-major order at
+     output.
+   */
+  auto input_hdl = c10d::symmetric_memory::rendezvous(input, group_name);
+  auto out_hdl = c10d::symmetric_memory::rendezvous(out, group_name);
+  auto splits_hdl = c10d::symmetric_memory::rendezvous(in_out_splits, group_name);
+  int rank = input_hdl->get_rank();
+  int world_size = input_hdl->get_world_size();
+
+  void* input_ptr = input_hdl->get_buffer_ptrs()[rank];
+  void* output_ptr = out_hdl->get_buffer_ptrs()[rank];
+  int64_t* splits_ptr = (int64_t*)(splits_hdl->get_buffer_ptrs()[rank]);
+
+  auto split_shape = in_out_splits.sizes();
+  TORCH_CHECK(split_shape.size() == 2 && split_shape[0] == 3, "in_out_splits must be 2D with 3 rows");
+  TORCH_CHECK(split_shape[1] % world_size == 0, "Each row of in_out_splits must be a multiple of world_size");
+  // Number of experts per rank
+  int ne = split_shape[1] / world_size;
+
+  // Set device context for getting the stream and launching kernels below
+  c10::cuda::CUDAGuard guard(input.device());
+  auto stream = at::cuda::getCurrentCUDAStream();
+
+  // Exchange output splits and source offsets
+  // Use collective launch because the kernel involves an nvshmem barrier
+  void* args0[] = {
+      &splits_ptr,
+      &rank,
+      &world_size,
+      &ne};
+  nvshmemx_collective_launch(
+      (const void*)exchangeSplitAndOffset_2d,
+      dim3(1),
+      dim3(THREADS_PER_BLOCK),
+      args0,
+      0,
+      stream);
+
+  // CTA tuning
+  // Naive for now: use 1 block per expert.
+  // Total number of blocks is limited to 64 (intra-node) or 8 (inter-node).
+  int num_blocks = std::min(world_size * ne, world_size > 8 ? 8 : 64);
+
+  // Stride at dim 0 (assuming input is contiguous, TODO)
+  size_t stride_bytes = input.stride(0) * input.element_size();
+
+  // All-to-all data exchange
+  void* args1[] = {
+      &input_ptr,
+      &output_ptr,
+      &splits_ptr,
+      &stride_bytes,
+      &rank,
+      &world_size,
+      &ne};
+  nvshmemx_collective_launch(
+      (const void*)allToAllV_2d,
+      dim3(num_blocks),
+      dim3(THREADS_PER_BLOCK),
+      args1,
+      0,
+      stream);
+  return out;
+}
+
 } // namespace c10d::nvshmem_extension
 
 
 TORCH_LIBRARY_IMPL(symm_mem, CUDA, m) {
   m.impl("nvshmem_broadcast", c10d::nvshmem_extension::nvshmem_broadcast);
   m.impl("nvshmem_all_to_all", c10d::nvshmem_extension::nvshmem_all_to_all);
   m.impl("nvshmem_all_to_all_vdev", c10d::nvshmem_extension::nvshmem_all_to_all_vdev);
+  m.impl("nvshmem_all_to_all_vdev_2d", c10d::nvshmem_extension::nvshmem_all_to_all_vdev_2d);
 }
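
On the two-kernel structure above: `exchangeSplitAndOffset_2d` publishes each rank's input splits and source offsets into expert-major slots (`e * npes + mype`) on every peer via `nvshmem_int64_p`, and `allToAllV_2d` then pulls each chunk with `nvshmemx_getmem_block`. Below is a host-side Python emulation of the first kernel's index math; it is an illustration only, and it assumes the `prefixSum` helper computes an exclusive scan, which is consistent with the `in_out_splits[2][0] == 0` check in the test.

import torch

# Host-side emulation (illustration) of the scatter in exchangeSplitAndOffset_2d.
# input_splits: (npes, npes * ne); row `mype` is rank mype's splits, rank-major.
def exchange_split_and_offset_ref(input_splits, ne):
    npes = input_splits.shape[0]
    # Exclusive scan per row: where each (peer, expert) chunk starts in the sender
    src_off = torch.cumsum(input_splits, dim=1) - input_splits
    output_splits = torch.empty_like(input_splits)
    source_offsets = torch.empty_like(input_splits)
    for mype in range(npes):              # each rank runs the kernel...
        for tid in range(npes * ne):      # ...with one thread per split entry
            peer, e = tid // ne, tid % ne
            dst = e * npes + mype         # rank-major -> expert-major slot
            # nvshmem_int64_p writes into `peer`'s copy of rows 1 and 2
            output_splits[peer, dst] = input_splits[mype, tid]
            source_offsets[peer, dst] = src_off[mype, tid]
    return output_splits, source_offsets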

torch/csrc/distributed/c10d/nvshmem_extension.cuh
Lines changed: 6 additions & 0 deletions

@@ -28,4 +28,10 @@ at::Tensor nvshmem_all_to_all_vdev(
     at::Tensor& in_out_splits,
     std::string group_name);
 
+at::Tensor nvshmem_all_to_all_vdev_2d(
+    at::Tensor& input,
+    at::Tensor& out,
+    at::Tensor& in_out_splits,
+    std::string group_name);
+
 } // namespace c10d::nvshmem_extension

0 commit comments
