[a2av] Align length of major dimension in output of 2D a2av · pytorch/pytorch@eb3a3f4 · GitHub
Commit eb3a3f4

[a2av] Align length of major dimension in output of 2D a2av
Tile-based prefix sum
Working
Use cub for prefix sum of tile offset

ghstack-source-id: 96c6c13
Pull-Request-resolved: #155172
1 parent fb92ced commit eb3a3f4

File tree (4 files changed: +123 −22 lines)

test/distributed/test_nvshmem.py
torch/csrc/distributed/c10d/SymmetricMemory.cpp
torch/csrc/distributed/c10d/nvshmem_extension.cu
torch/csrc/distributed/c10d/nvshmem_extension.cuh


test/distributed/test_nvshmem.py

Lines changed: 31 additions & 12 deletions
@@ -142,19 +142,20 @@ def test_nvshmem_all_to_all_vdev_2d(self) -> None:
         nsplits = ne * self.world_size
         # Number of elements for an expert is random between [0, k)
         k = 3
+        # Align
+        align = 16
         inp_splits = torch.randint(k, (nsplits,), device=self.device)
         inp_numel = inp_splits.sum().item()
         # Exchange input splits to get output splits
         out_splits = torch.zeros_like(inp_splits)
         dist.all_to_all_single(out_splits, inp_splits)
         # We do a .t() here because there is a rank-major to expert-major shuffle
-        out_splits_t = out_splits.reshape(self.world_size, ne).t().reshape(-1)
+        out_splits_t = out_splits.reshape(self.world_size, ne).t()

         # Total number of output elements
         out_numel = out_splits.sum().item()
-        # Align up to make it bigger
-        align = 16
-        out_numel_max = (out_numel + align - 1) // align * align
+        # Align-up makes it bigger
+        out_numel_max = (out_numel + align * ne) // align * align

         inp = symm_mem.empty(inp_numel, dtype=dtype, device=self.device).fill_(
             self.rank
@@ -167,20 +168,34 @@ def test_nvshmem_all_to_all_vdev_2d(self) -> None:
         in_out_splits[0].copy_(inp_splits)

         torch.ops.symm_mem.nvshmem_all_to_all_vdev_2d(
-            inp, out, in_out_splits, group_name
+            inp, out, in_out_splits, group_name, align
         )
+        received_out_splits = in_out_splits[1]
+        received_out_offsets = in_out_splits[2]

         # Check input splits (row 0) -- should not change
         torch.testing.assert_close(in_out_splits[0], inp_splits)

         # Check output splits (row 1)
-        torch.testing.assert_close(in_out_splits[1], out_splits_t)
+        torch.testing.assert_close(received_out_splits, out_splits_t.reshape(-1))

         # Check output offsets (row 2)
-        out_offsets = torch.cumsum(out_splits_t, dim=0)  # inclusive scan
-        # output offsets from `nvshmem_all_to_all_vdev` is exclusive scan
-        self.assertEqual(in_out_splits[2][0], 0)
-        torch.testing.assert_close(in_out_splits[2][1:], out_offsets[:-1])
+        out_split_list = out_splits_t.tolist()
+        for i in range(ne):
+            expert_sum = 0
+            for j in range(self.world_size):
+                expert_sum += out_split_list[i][j]
+            align_pad = align - (expert_sum % align)
+            # last element absorbs the padding
+            out_split_list[i][-1] += align_pad
+
+        out_splits_padded = torch.tensor(out_split_list, device=self.device).reshape(-1)
+        out_offsets = torch.cumsum(out_splits_padded, dim=0)  # inclusive scan
+        # Make it exclusive scan because that's what `nvshmem_all_to_all_vdev_2d` returns
+        out_offsets = torch.cat(
+            [torch.zeros(1, device=self.device), out_offsets[:-1]]
+        ).to(torch.int64)
+        torch.testing.assert_close(received_out_offsets, out_offsets)

         # Check data
         expected = torch.empty(out_numel, dtype=dtype, device=self.device)
@@ -199,8 +214,12 @@ def test_nvshmem_all_to_all_vdev_2d(self) -> None:
             chunk = expected[offset - out_splits[chunk_id] : offset]
             result_list.append(chunk)

-        final = torch.cat(result_list)
-        torch.testing.assert_close(out[:out_numel], final)
+        # Do a chunk-wise comparison
+        for c, chunk in enumerate(result_list):
+            start = received_out_offsets[c].item()
+            split = received_out_splits[c].item()
+            received_chunk = out[start : start + split]
+            torch.testing.assert_close(received_chunk, chunk)


 if __name__ == "__main__":
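For reference, the expected-offset logic the updated test builds can be restated as a small CPU-only Python sketch. The split values below are made up (they match the worked example in the kernel docstring, with world_size = 2 and ne = 2); only `torch` is assumed.

import torch

world_size, ne, align = 2, 2, 16
# Expert-major splits: out_splits_t[i][j] = elements expert i receives from rank j.
out_splits_t = torch.tensor([[5, 7], [3, 4]])

out_split_list = out_splits_t.tolist()
for i in range(ne):
    expert_sum = sum(out_split_list[i])
    # Pad so the next expert starts at a multiple of `align`; the last rank's
    # split absorbs the padding. An exact multiple still gains a full `align`,
    # mirroring the kernel's (len + major_align) / major_align * major_align.
    out_split_list[i][-1] += align - (expert_sum % align)

padded = torch.tensor(out_split_list).reshape(-1)
# Exclusive scan of the padded splits = the offsets returned in in_out_splits[2].
offsets = torch.cat([torch.zeros(1, dtype=padded.dtype), torch.cumsum(padded, dim=0)[:-1]])
print(offsets.tolist())  # [0, 5, 16, 19] -> c1 starts at 16, as in the docstring example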

torch/csrc/distributed/c10d/SymmetricMemory.cpp

Lines changed: 1 addition & 1 deletion
@@ -281,7 +281,7 @@ TORCH_LIBRARY_FRAGMENT(symm_mem, m) {
   m.def(
       "nvshmem_all_to_all_vdev(Tensor input, Tensor(a!) out, Tensor(a!) in_out_splits, str group_name) -> Tensor(a!)");
   m.def(
-      "nvshmem_all_to_all_vdev_2d(Tensor input, Tensor(a!) out, Tensor(a!) in_out_splits, str group_name) -> Tensor(a!)");
+      "nvshmem_all_to_all_vdev_2d(Tensor input, Tensor(a!) out, Tensor(a!) in_out_splits, str group_name, int major_align) -> Tensor(a!)");
 }

 TORCH_LIBRARY_IMPL(symm_mem, Meta, m) {
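With the schema change above, the Python-visible op takes a fifth argument. The fragment below is only a hedged sketch of the call-site change, not runnable on its own: it assumes the NVSHMEM-backed process group and the symmetric-memory tensors `inp`, `out` and `in_out_splits` are already set up as in test/distributed/test_nvshmem.py.

major_align = 16  # new trailing argument introduced by this commit
torch.ops.symm_mem.nvshmem_all_to_all_vdev_2d(
    inp, out, in_out_splits, group_name, major_align
)
# Row 1 of in_out_splits holds the received output splits (unpadded);
# row 2 holds exclusive-scan output offsets whose per-expert starts are
# up-aligned to major_align.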

torch/csrc/distributed/c10d/nvshmem_extension.cu

Lines changed: 89 additions & 8 deletions
@@ -11,12 +11,18 @@
 #include <ATen/cuda/cub.cuh>
 #include <nvshmem.h>

+#include <cooperative_groups.h>
+
 namespace c10d::nvshmem_extension {

+using namespace cooperative_groups;
+namespace cg = cooperative_groups;
+
 using c10d::symmetric_memory::StoreExchange;
 static StoreExchange storeExchange = StoreExchange("nvshmem_ext");

 #define THREADS_PER_BLOCK 512
+#define WARP_SIZE 32

 // Bootstrap based on user's setting for NCCL
 // Long term, this may be a bit unclean; short term, it improves UX
@@ -344,20 +350,82 @@ __global__ void exchangeSplitAndOffset_2d(int64_t* in_out_splits, int mype, int
   nvshmemx_barrier_all_block();
 }

+// This is an warp-scope, exclusive prefix sum.
+__device__ void prefixSum_warp(int64_t *odata, int64_t *idata, int n) {
+  CUDA_KERNEL_ASSERT(n <= WARP_SIZE);
+  constexpr int NUM_WARPS = THREADS_PER_BLOCK / WARP_SIZE;
+
+  // Specialize WarpScan for type int
+  using WarpScan = at_cuda_detail::cub::WarpScan<int64_t>;
+  // Allocate WarpScan shared memory for N warps
+  __shared__ typename WarpScan::TempStorage temp_storage[NUM_WARPS];
+
+  // Obtain input item for each thread
+  int tid = threadIdx.x % WARP_SIZE;
+  int64_t thread_data = (tid < n) ? idata[tid] : 0;
+
+  // Compute the warp-wide exclusive prefix sum
+  int warp_id = threadIdx.x / WARP_SIZE;
+  WarpScan(temp_storage[warp_id]).ExclusiveSum(thread_data, thread_data);
+
+  // Store the result
+  odata[tid] = thread_data;
+}
+
+// This is for abstracting a thread-group-scope, exclusive prefix sum.
+// Since we use warp-scope prefix sum, the thread group size is limited to warp size.
+#define A2AV_TILE_SIZE WARP_SIZE
+
 // This kernel is used to do the actual data exchange.
 // `in_out_splits` has the same definition as in `exchangeSplitAndOffset`.
 // `stride` is the stride at dim 0, unit in byte.
 // For meaning of `mype` and `npes`, see the docstring of `nvshmem_all_to_all_vdev_2d`.
-__global__ void allToAllV_2d(void *send_data, void *recv_data, int64_t* in_out_splits, size_t stride, int mype, int npes, int ne) {
+__global__ void allToAllV_2d(void *send_data, void *recv_data, int64_t* in_out_splits, size_t stride, int mype, int npes, int ne, int64_t major_align) {
   int nsplits = npes * ne;
   auto output_splits = in_out_splits + nsplits;
   auto source_offsets = in_out_splits + nsplits * 2;
   int bid = blockIdx.x;
   int tid = threadIdx.x;

-  // Calculate the output offsets
-  __shared__ int64_t e_offsets[THREADS_PER_BLOCK];
-  prefixSum(e_offsets, output_splits, nsplits);
+  // Split the thread block into tiles
+  constexpr int NUM_TILES = THREADS_PER_BLOCK / A2AV_TILE_SIZE;
+  thread_group tile = cg::tiled_partition(this_thread_block(), A2AV_TILE_SIZE);
+  int tileId = tid / A2AV_TILE_SIZE;
+  // Each tile calculates its own prefix sum
+  __shared__ int64_t tile_prefix_sums[NUM_TILES][A2AV_TILE_SIZE];
+  // A tile takes care of npes worth of splits
+  int nsplits_per_tile = min(npes, nsplits - tileId * npes);
+  // TODO: currently it is assumed that the number of PE's is smaller than `A2AV_TILE_SIZE`
+  CUDA_KERNEL_ASSERT(nsplits_per_tile <= A2AV_TILE_SIZE);
+
+  // Total length of each tile
+  __shared__ int64_t len_per_tile[NUM_TILES];
+  // Starting offset of each tile
+  __shared__ int64_t start_offset_per_tile[NUM_TILES];
+  // This tile does not need to do tile-wise prefix sum
+  if (nsplits_per_tile < 0) goto end_of_tile_prefix_sum;
+
+  // Each tile calculates its own prefix sum
+  prefixSum_warp(tile_prefix_sums[tileId], output_splits + tileId * npes, nsplits_per_tile);
+
+  // Last thread in each tile does the up aligning.
+  // Note: using the last thread to read the last sum from `tile_prefix_sums` so
+  // that we can save a __syncthreads(). This is safe because the last thread is
+  // the one that writes the last sum in the prefixSum function.
+  if (tile.thread_rank() == A2AV_TILE_SIZE - 1) {
+    auto my_tile_len = tile_prefix_sums[tileId][A2AV_TILE_SIZE - 1] + output_splits[tileId * npes + nsplits_per_tile - 1];
+    // Up align
+    len_per_tile[tileId] = (my_tile_len + major_align) / major_align * major_align;
+  }
+end_of_tile_prefix_sum:
+  __syncthreads();
+
+  // Prefix sum again to get the tiles' start offsets. This is a block-wide prefix sum.
+  prefixSum(start_offset_per_tile, len_per_tile, NUM_TILES);
+  __syncthreads();
+
+  // Add tile offset to every element in the tile
+  tile_prefix_sums[tileId][tile.thread_rank()] += start_offset_per_tile[tileId];
   __syncthreads();

   // Target a different e based on bid
@@ -366,7 +434,8 @@ __global__ void allToAllV_2d(void *send_data, void *recv_data, int64_t* in_out_s
     // Amount from `peer` for `e`
     auto peer_size = output_splits[eid] * stride;
     auto source_offset = source_offsets[eid] * stride;
-    auto write_offset = e_offsets[eid] * stride;
+    auto e_offset = tile_prefix_sums[eid / npes][peer];
+    auto write_offset = e_offset * stride;
     nvshmemx_getmem_block(
         (char*)recv_data + write_offset,
         (char*)send_data + source_offset,
@@ -375,15 +444,16 @@ __global__ void allToAllV_2d(void *send_data, void *recv_data, int64_t* in_out_s
   }
   // Write out the output offsets (to the scratchpad line)
   if (bid == 0 && tid < nsplits) {
-    source_offsets[tid] = e_offsets[tid];
+    source_offsets[tid] = tile_prefix_sums[tid / npes][tid % npes];
   }
 }

 at::Tensor nvshmem_all_to_all_vdev_2d(
     at::Tensor& input,
     at::Tensor& out,
     at::Tensor& in_out_splits,
-    std::string group_name) {
+    std::string group_name,
+    int64_t major_align) {
   /* Perform a 2D AllToAllv shuffle operation using NVSHMEM, with split information provided on device.
   * Arguments:
   *  - `input` is the input tensor
@@ -395,6 +465,8 @@ at::Tensor nvshmem_all_to_all_vdev_2d(
           output splits (OUT) and
           output offsets (OUT).
   *  - `group_name` is the name of the group to use for the collective operation.
+  *  - `major_align` is the alignment of the "major dimension" of the output
+          sequence. See below for details.

   * A 2D AllToAllv shuffle is illustrated below:
         (world_size = 2, ne = 2, total number of experts = 4)
@@ -408,12 +480,20 @@ at::Tensor nvshmem_all_to_all_vdev_2d(
         `in_out_splits[0]`). That is, the 2D AllToAllv shuffle achives a
         transpose from rank-major order at input to expert-major order at
         output.
+
+  * If `major_align` is not 1, the output offsets of c1, c2, c3 will be
+        up-aligned to this value. For example, if c0 has length 5 and d0 has
+        length 7 (making a total of 12), and if the `major_align` is set to 16,
+        the output offset of c1 will be 16. Similar for c2 and c3. This value has
+        no effect on the offset of the minor dimension, i.e. d0, d1, d2 and d3.
   */
   auto input_hdl = c10d::symmetric_memory::rendezvous(input, group_name);
   auto out_hdl = c10d::symmetric_memory::rendezvous(out, group_name);
   auto splits_hdl = c10d::symmetric_memory::rendezvous(in_out_splits, group_name);
   int rank = input_hdl->get_rank();
   int world_size = input_hdl->get_world_size();
+  // TODO: world_size is currently limited by the number of elements in a WarpScan.
+  TORCH_CHECK(world_size <= A2AV_TILE_SIZE, "world_size must be smaller than A2AV_TILE_SIZE", A2AV_TILE_SIZE);

   void* input_ptr = input_hdl->get_buffer_ptrs()[rank];
   void* output_ptr = out_hdl->get_buffer_ptrs()[rank];
@@ -460,7 +540,8 @@ at::Tensor nvshmem_all_to_all_vdev_2d(
       &stride_bytes,
       &rank,
       &world_size,
-      &ne};
+      &ne,
+      &major_align};
   nvshmemx_collective_launch(
       (const void*)allToAllV_2d,
       dim3(num_blocks),
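The kernel's two-level offset computation (a warp-scope exclusive scan per expert tile, an up-aligned tile length, then a block-wide exclusive scan of tile starts) can be checked against a plain Python restatement. This is an illustrative sketch only, not code from the commit; the helper name `a2av_2d_offsets` and the example splits are made up.

from itertools import accumulate

def a2av_2d_offsets(output_splits, npes, major_align):
    """Sequential equivalent of the tile-based offset math in allToAllV_2d."""
    ne = len(output_splits) // npes
    offsets = []
    tile_start = 0
    for e in range(ne):
        splits = output_splits[e * npes : (e + 1) * npes]
        # Warp-scope exclusive scan (cub::WarpScan::ExclusiveSum in the kernel).
        excl = [0] + list(accumulate(splits))[:-1]
        offsets += [tile_start + x for x in excl]
        # Up-align the expert's total length, mirroring
        # (my_tile_len + major_align) / major_align * major_align.
        tile_start += (sum(splits) + major_align) // major_align * major_align
    return offsets

# Docstring example: c0 = 5, d0 = 7, major_align = 16 -> c1 starts at offset 16.
print(a2av_2d_offsets([5, 7, 3, 4], npes=2, major_align=16))  # [0, 5, 16, 19]

Since one warp-sized tile handles all `npes` splits of a single expert, the host-side TORCH_CHECK above limits world_size to A2AV_TILE_SIZE (the warp size, 32).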

torch/csrc/distributed/c10d/nvshmem_extension.cuh

Lines changed: 2 additions & 1 deletion
@@ -32,6 +32,7 @@ at::Tensor nvshmem_all_to_all_vdev_2d(
     at::Tensor& input,
     at::Tensor& out,
     at::Tensor& in_out_splits,
-    std::string group_name);
+    std::string group_name,
+    int64_t major_align = 1);

 } // namespace c10d::nvshmem_extension
