[a2av] Align length of major dimension in output of 2D a2av · pytorch/pytorch@ed9cfdd · GitHub
Commit ed9cfdd
[a2av] Align length of major dimension in output of 2D a2av
Tile-based prefix sum
Working
Use cub for prefix sum of tile offset
Comments
warp_aggregate
Handle zero bin case without waste
Templates prefixSum_warp

ghstack-source-id: 9378ffc
Pull-Request-resolved: #155172
1 parent d4a4ec2 commit ed9cfdd
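
What the commit does, in a nutshell: each expert's ("major dimension") total output length is rounded up to a multiple of `major_align`, and an empty expert still occupies `major_align` slots, so every expert's chunk starts on an aligned boundary. A minimal host-side sketch of that rule (the helper below is illustrative, not part of the commit):

# Hypothetical helper mirroring the alignment rule introduced by this commit
# (not part of the commit itself): round an expert's total up to a multiple
# of `major_align`, and never let it be zero (cutlass needs non-empty bins).
def aligned_len(expert_total: int, major_align: int) -> int:
    up = (expert_total + major_align - 1) // major_align * major_align
    return max(up, major_align)

# Example from the kernel docstring below: c0 has 5 elements and d0 has 7,
# so expert 0 holds 12 elements in total. With major_align=16, expert 1
# starts at offset 16 instead of 12.
assert aligned_len(5 + 7, 16) == 16
assert aligned_len(0, 16) == 16  # an empty expert still occupies one aligned bin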

File tree

4 files changed: +151 -23 lines changed

test/distributed/test_nvshmem.py

Lines changed: 37 additions & 13 deletions
@@ -13,6 +13,8 @@
 import torch.distributed._symmetric_memory as symm_mem
 from torch.testing._internal.common_distributed import MultiProcContinousTest
 from torch.testing._internal.common_utils import (
+    instantiate_parametrized_tests,
+    parametrize,
     run_tests,
     skip_but_pass_in_sandcastle_if,
     skipIfRocm,
@@ -42,6 +44,7 @@ def requires_nvshmem():
 device_module = torch.get_device_module(device_type)


+@instantiate_parametrized_tests
 @requires_nvshmem()
 class NVSHMEMSymmetricMemoryTest(MultiProcContinousTest):
     def _init_device(self) -> None:
@@ -129,7 +132,8 @@ def test_nvshmem_all_to_all_vdev(self) -> None:
         torch.testing.assert_close(out[:out_numel], expected)

     @skipIfRocm
-    def test_nvshmem_all_to_all_vdev_2d(self) -> None:
+    @parametrize("align", [1, 8, 16])  # `major_align` of output
+    def test_nvshmem_all_to_all_vdev_2d(self, align: int) -> None:
         torch.manual_seed(42 + self.rank)
         self._init_device()

@@ -148,13 +152,12 @@ def test_nvshmem_all_to_all_vdev_2d(self) -> None:
         out_splits = torch.zeros_like(inp_splits)
         dist.all_to_all_single(out_splits, inp_splits)
         # We do a .t() here because there is a rank-major to expert-major shuffle
-        out_splits_t = out_splits.reshape(self.world_size, ne).t().reshape(-1)
+        out_splits_t = out_splits.reshape(self.world_size, ne).t()

         # Total number of output elements
         out_numel = out_splits.sum().item()
-        # Align up to make it bigger
-        align = 16
-        out_numel_max = (out_numel + align - 1) // align * align
+        # Align-up makes it bigger
+        out_numel_max = (out_numel + align * ne) // align * align

         inp = symm_mem.empty(inp_numel, dtype=dtype, device=self.device).fill_(
             self.rank
@@ -167,20 +170,37 @@ def test_nvshmem_all_to_all_vdev_2d(self) -> None:
         in_out_splits[0].copy_(inp_splits)

         torch.ops.symm_mem.nvshmem_all_to_all_vdev_2d(
-            inp, out, in_out_splits, group_name
+            inp, out, in_out_splits, group_name, major_align=align
         )
+        received_out_splits = in_out_splits[1]
+        received_out_offsets = in_out_splits[2]

         # Check input splits (row 0) -- should not change
         torch.testing.assert_close(in_out_splits[0], inp_splits)

         # Check output splits (row 1)
-        torch.testing.assert_close(in_out_splits[1], out_splits_t)
+        torch.testing.assert_close(received_out_splits, out_splits_t.reshape(-1))

         # Check output offsets (row 2)
-        out_offsets = torch.cumsum(out_splits_t, dim=0)  # inclusive scan
-        # output offsets from `nvshmem_all_to_all_vdev` is exclusive scan
-        self.assertEqual(in_out_splits[2][0], 0)
-        torch.testing.assert_close(in_out_splits[2][1:], out_offsets[:-1])
+        out_split_list = out_splits_t.tolist()
+        for i in range(ne):
+            expert_sum = 0
+            for j in range(self.world_size):
+                expert_sum += out_split_list[i][j]
+            # Align up expert_sum
+            expert_sum_aligned = (expert_sum + align - 1) // align * align
+            # If 0, make it at least `align` (bc cutlass currently does not support empty bins)
+            expert_sum_aligned = max(expert_sum_aligned, align)
+            # last element absorbs the padding
+            out_split_list[i][-1] += expert_sum_aligned - expert_sum
+
+        out_splits_padded = torch.tensor(out_split_list, device=self.device).reshape(-1)
+        out_offsets = torch.cumsum(out_splits_padded, dim=0)  # inclusive scan
+        # Make it exclusive scan because that's what `nvshmem_all_to_all_vdev_2d` returns
+        out_offsets = torch.cat(
+            [torch.zeros(1, device=self.device), out_offsets[:-1]]
+        ).to(torch.int64)
+        torch.testing.assert_close(received_out_offsets, out_offsets)

         # Check data
         expected = torch.empty(out_numel, dtype=dtype, device=self.device)
@@ -199,8 +219,12 @@ def test_nvshmem_all_to_all_vdev_2d(self) -> None:
             chunk = expected[offset - out_splits[chunk_id] : offset]
             result_list.append(chunk)

-        final = torch.cat(result_list)
-        torch.testing.assert_close(out[:out_numel], final)
+        # Do a chunk-wise comparison
+        for c, chunk in enumerate(result_list):
+            start = received_out_offsets[c].item()
+            split = received_out_splits[c].item()
+            received_chunk = out[start : start + split]
+            torch.testing.assert_close(received_chunk, chunk)


 if __name__ == "__main__":
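
A note on the `out_numel_max` change in the test above: each of the `ne` expert chunks can gain at most `align` elements of padding (an empty chunk is bumped to `align`; a non-empty one gains at most `align - 1`), so rounding `out_numel + align * ne` down to a multiple of `align` always leaves enough room for the aligned output. A minimal sanity check with made-up numbers (not from the test):

# Hypothetical per-expert totals and alignment, to sanity-check the headroom
# formula used in the test above.
align, totals = 8, [5, 0, 13, 8]          # ne = 4 experts
padded = [max((t + align - 1) // align * align, align) for t in totals]
out_numel = sum(totals)                    # 26
out_numel_max = (out_numel + align * len(totals)) // align * align  # 56
assert sum(padded) <= out_numel_max        # padded output (40) fits in the buffer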

torch/csrc/distributed/c10d/SymmetricMemory.cpp

Lines changed: 1 addition & 1 deletion
@@ -281,7 +281,7 @@ TORCH_LIBRARY_FRAGMENT(symm_mem, m) {
   m.def(
       "nvshmem_all_to_all_vdev(Tensor input, Tensor(a!) out, Tensor(a!) in_out_splits, str group_name) -> Tensor(a!)");
   m.def(
-      "nvshmem_all_to_all_vdev_2d(Tensor input, Tensor(a!) out, Tensor(a!) in_out_splits, str group_name) -> Tensor(a!)");
+      "nvshmem_all_to_all_vdev_2d(Tensor input, Tensor(a!) out, Tensor(a!) in_out_splits, str group_name, int? major_align=None) -> Tensor(a!)");
 }

 TORCH_LIBRARY_IMPL(symm_mem, Meta, m) {
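
With the schema above, `major_align` becomes an optional trailing argument of the op; omitting it preserves the old behavior (the kernel defaults it to 1). A sketch of the Python call, assuming `inp`, `out` and `in_out_splits` are symmetric-memory tensors already rendezvous'd under `group_name`, as in the test:

# Old behavior: no alignment of the major dimension (major_align defaults to 1).
torch.ops.symm_mem.nvshmem_all_to_all_vdev_2d(inp, out, in_out_splits, group_name)

# New behavior: expert-major output offsets are rounded up to multiples of 16.
torch.ops.symm_mem.nvshmem_all_to_all_vdev_2d(
    inp, out, in_out_splits, group_name, major_align=16
)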

torch/csrc/distributed/c10d/nvshmem_extension.cu

Lines changed: 111 additions & 8 deletions
@@ -17,6 +17,7 @@ using c10d::symmetric_memory::StoreExchange;
 static StoreExchange storeExchange = StoreExchange("nvshmem_ext");

 #define THREADS_PER_BLOCK 512
+#define WARP_SIZE 32

 // Bootstrap based on user's setting for NCCL
 // Long term, this may be a bit unclean; short term, it improves UX
@@ -347,20 +348,100 @@ __global__ void exchangeSplitAndOffset_2d(int64_t* in_out_splits, int mype, int
   nvshmemx_barrier_all_block();
 }

+// This is an warp-scope, exclusive prefix sum. When called by a block of
+// threads, each warp will perform an independent prefix sum, concurrently.
+// Returns the sum of all elements in the warp.
+// `NUM_WARPS` is the number of warps participating the concurrent prefix sum.
+template <int NUM_WARPS>
+__device__ int64_t prefixSum_warp(int64_t *odata, int64_t *idata, int n) {
+  CUDA_KERNEL_ASSERT(n <= WARP_SIZE);
+
+  // Specialize WarpScan for type int
+  using WarpScan = at_cuda_detail::cub::WarpScan<int64_t>;
+  // Allocate WarpScan shared memory for N warps
+  __shared__ typename WarpScan::TempStorage temp_storage[NUM_WARPS];
+
+  int warp_id = threadIdx.x / WARP_SIZE;
+  if (warp_id >= NUM_WARPS) {
+    return 0;
+  }
+
+  // Obtain input item for each thread
+  int tid = threadIdx.x % WARP_SIZE;
+  int64_t thread_data = (tid < n) ? idata[tid] : 0;
+
+  // Total sum of all elements in the warp
+  int64_t warp_aggregate;
+  // Compute the warp-wide exclusive prefix sum
+  WarpScan(temp_storage[warp_id]).ExclusiveSum(thread_data, thread_data, warp_aggregate);
+
+  // Store the result
+  odata[tid] = thread_data;
+  return warp_aggregate;
+}
+
+// This is for abstracting a thread-group-scope, exclusive prefix sum.
+// Since we use warp-scope prefix sum, the thread group size is limited to warp size.
+#define A2AV_TILE_SIZE WARP_SIZE
+
 // This kernel is used to do the actual data exchange.
 // `in_out_splits` has the same definition as in `exchangeSplitAndOffset`.
 // `stride` is the stride at dim 0, unit in byte.
 // For meaning of `mype` and `npes`, see the docstring of `nvshmem_all_to_all_vdev_2d`.
-__global__ void allToAllV_2d(void *send_data, void *recv_data, int64_t* in_out_splits, size_t stride, int mype, int npes, int ne) {
+__global__ void allToAllV_2d(void *send_data, void *recv_data, int64_t* in_out_splits, size_t stride, int mype, int npes, int ne, int64_t major_align) {
   int nsplits = npes * ne;
   auto output_splits = in_out_splits + nsplits;
   auto source_offsets = in_out_splits + nsplits * 2;
   int bid = blockIdx.x;
   int tid = threadIdx.x;

-  // Calculate the output offsets
-  __shared__ int64_t e_offsets[THREADS_PER_BLOCK];
-  prefixSum(e_offsets, output_splits, nsplits);
+  // Split the thread block into tiles
+  constexpr int NUM_TILES = THREADS_PER_BLOCK / A2AV_TILE_SIZE;
+  int tileId = tid / A2AV_TILE_SIZE;
+  int laneId = tid % A2AV_TILE_SIZE;
+  // Each tile calculates its own prefix sum
+  __shared__ int64_t tile_prefix_sums[NUM_TILES][A2AV_TILE_SIZE];
+  // A tile takes care of npes worth of splits
+  int nsplits_per_tile = min(npes, nsplits - tileId * npes);
+  // TODO: currently it is assumed that the number of PE's is smaller than
+  // `A2AV_TILE_SIZE` bc the warp-scope prefix sum can only handle up to
+  // WARP_SIZE elements
+  CUDA_KERNEL_ASSERT(npes <= A2AV_TILE_SIZE);
+  // Similarly, the number of experts per rank is also assumed to be smaller
+  // than `NUM_TILES`
+  CUDA_KERNEL_ASSERT(ne <= NUM_TILES);
+
+  // Total length of each tile
+  __shared__ int64_t len_per_tile[NUM_TILES];
+  // When `nsplits` is small, not every tile gets data to sum. They can skip
+  // this local prefix sum.
+  if (nsplits_per_tile > 0) {
+    // Each tile calculates its own prefix sum, return value is the sum of all elements in the tile.
+    int64_t my_tile_len = prefixSum_warp<NUM_TILES>(tile_prefix_sums[tileId], output_splits + tileId * npes, nsplits_per_tile);
+    // Last thread in each tile does the up aligning.
+    if (laneId == A2AV_TILE_SIZE - 1) {
+      auto aligned_len = (my_tile_len + major_align - 1) / major_align * major_align;
+      // In case `aligned_len` is 0, we set it to `major_align` to avoid an
+      // empty bin, bc cutlass currently does not support it. See
+      // https://github.com/pytorch/pytorch/issues/152668.
+      len_per_tile[tileId] = max(aligned_len, major_align);
+    }
+  }
+  __syncthreads();
+
+  // Starting offset of each tile
+  __shared__ int64_t start_offset_per_tile[NUM_TILES];
+  // Prefix sum again to get the tiles' start offsets.
+  // `NUM_TILES` is typically not greater than 32, because 32 tiles * 32 threads
+  // = 1024 threads, and this kernel is launched within 1024 threads. Thus, we
+  // can use warp-scope prefix sum.
+  static_assert(NUM_TILES <= WARP_SIZE);
+  // Only 1 warp is needed
+  prefixSum_warp<1>(start_offset_per_tile, len_per_tile, NUM_TILES);
+  __syncthreads();
+
+  // Add tile offset to every element in the tile
+  tile_prefix_sums[tileId][laneId] += start_offset_per_tile[tileId];
   __syncthreads();

   // Target a different e based on bid
@@ -369,7 +450,8 @@ __global__ void allToAllV_2d(void *send_data, void *recv_data, int64_t* in_out_s
     // Amount from `peer` for `e`
     auto peer_size = output_splits[eid] * stride;
     auto source_offset = source_offsets[eid] * stride;
-    auto write_offset = e_offsets[eid] * stride;
+    auto e_offset = tile_prefix_sums[eid / npes][peer];
+    auto write_offset = e_offset * stride;
     nvshmemx_getmem_block(
         (char*)recv_data + write_offset,
         (char*)send_data + source_offset,
@@ -378,15 +460,16 @@ __global__ void allToAllV_2d(void *send_data, void *recv_data, int64_t* in_out_s
   }
   // Write out the output offsets (to the scratchpad line)
   if (bid == 0 && tid < nsplits) {
-    source_offsets[tid] = e_offsets[tid];
+    source_offsets[tid] = tile_prefix_sums[tid / npes][tid % npes];
   }
 }

 at::Tensor nvshmem_all_to_all_vdev_2d(
     at::Tensor& input,
     at::Tensor& out,
     at::Tensor& in_out_splits,
-    std::string group_name) {
+    std::string group_name,
+    std::optional<int64_t> major_align) {
   /* Perform a 2D AllToAllv shuffle operation using NVSHMEM, with split information provided on device.
    * Arguments:
    *  - `input` is the input tensor
@@ -398,6 +481,8 @@ at::Tensor nvshmem_all_to_all_vdev_2d(
         output splits (OUT) and
         output offsets (OUT).
    *  - `group_name` is the name of the group to use for the collective operation.
+   *  - `major_align` is the alignment of the "major dimension" of the output
+        sequence. See below for details.

    * A 2D AllToAllv shuffle is illustrated below:
        (world_size = 2, ne = 2, total number of experts = 4)
@@ -411,12 +496,27 @@ at::Tensor nvshmem_all_to_all_vdev_2d(
       `in_out_splits[0]`). That is, the 2D AllToAllv shuffle achives a
       transpose from rank-major order at input to expert-major order at
       output.
+
+   * If `major_align` is not 1, the output offsets of c1, c2, c3 will be
+     up-aligned to this value. For example, if c0 has length 5 and d0 has
+     length 7 (making a total of 12), and if the `major_align` is set to 16,
+     the output offset of c1 will be 16. Similar for c2 and c3. This value has
+     no effect on the offset of the minor dimension, i.e. d0, d1, d2 and d3.
+     Note: since cutlass does not support empty bins, we set the aligned length
+     to `major_align` if it is 0. See
+     https://github.com/pytorch/pytorch/issues/152668.
    */
   auto input_hdl = c10d::symmetric_memory::rendezvous(input, group_name);
   auto out_hdl = c10d::symmetric_memory::rendezvous(out, group_name);
   auto splits_hdl = c10d::symmetric_memory::rendezvous(in_out_splits, group_name);
   int rank = input_hdl->get_rank();
   int world_size = input_hdl->get_world_size();
+  // TODO: world_size is currently limited by the number of elements in a WarpScan.
+  TORCH_CHECK(world_size <= A2AV_TILE_SIZE, "world_size must be smaller than A2AV_TILE_SIZE", A2AV_TILE_SIZE);
+
+  // If `major_align` is not provided, use 1 as the default value.
+  int64_t major_align_val = major_align.value_or(1);
+  TORCH_CHECK(major_align_val > 0, "major_align must be positive");

   void* input_ptr = input_hdl->get_buffer_ptrs()[rank];
   void* output_ptr = out_hdl->get_buffer_ptrs()[rank];
@@ -442,6 +542,8 @@ at::Tensor nvshmem_all_to_all_vdev_2d(

   // Number of experts per rank
   int ne = split_shape[1] / world_size;
+  constexpr int NUM_TILES = THREADS_PER_BLOCK / A2AV_TILE_SIZE;
+  TORCH_CHECK(ne <= NUM_TILES, "Number of experts must be smaller than NUM_TILES", NUM_TILES);

   // Set device context for getting the stream and launching kernels below
   c10::cuda::CUDAGuard guard(input.device());
@@ -480,7 +582,8 @@ at::Tensor nvshmem_all_to_all_vdev_2d(
       &stride_bytes,
       &rank,
       &world_size,
-      &ne};
+      &ne,
+      &major_align_val};
   nvshmemx_collective_launch(
       (const void*)allToAllV_2d,
       dim3(num_blocks),
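
For a host-side view of what the tile-based prefix sum in `allToAllV_2d` computes, here is a small Python illustration (not part of the commit): each expert (tile) does an exclusive prefix sum over its per-peer splits, per-expert totals are aligned up to `major_align` (and floored at `major_align` for empty bins), a second exclusive prefix sum over the aligned totals gives each expert's start offset, and the final offset of (expert, peer) is the expert's start plus the within-expert prefix:

# Host-side illustration of the offset computation done by `allToAllV_2d`
# (hypothetical splits; out_splits[e][p] = elements expert `e` receives from peer `p`).
def output_offsets(out_splits, major_align):
    offsets = []
    tile_start = 0
    for per_peer in out_splits:                  # one "tile" per expert
        running = 0
        for s in per_peer:                       # analogue of the warp-scope exclusive prefix sum
            offsets.append(tile_start + running)
            running += s
        aligned = (running + major_align - 1) // major_align * major_align
        tile_start += max(aligned, major_align)  # empty bins still get one aligned slot
    return offsets

# world_size = 2, ne = 2: expert 0 receives [5, 7], expert 1 receives [0, 3].
print(output_offsets([[5, 7], [0, 3]], 16))      # [0, 5, 16, 16]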

torch/csrc/distributed/c10d/nvshmem_extension.cuh

Lines changed: 2 additions & 1 deletion
@@ -32,6 +32,7 @@ at::Tensor nvshmem_all_to_all_vdev_2d(
     at::Tensor& input,
     at::Tensor& out,
     at::Tensor& in_out_splits,
-    std::string group_name);
+    std::string group_name,
+    std::optional<int64_t> major_align = std::nullopt);

 } // namespace c10d::nvshmem_extension

0 commit comments