[a2av] Align length of major dimension in output of 2D a2av (#155172) · pytorch/pytorch@749757a · GitHub

Commit 749757a

kwen2501 authored and pytorchmergebot committed
[a2av] Align length of major dimension in output of 2D a2av (#155172)
The downstream consumer of the 2D all-to-all-v is often a grouped GEMM. Today that GEMM often has an alignment requirement on the chunk sizes within the grouped sequence, where each chunk carries the tokens headed for an expert. For example, `torch._group_mm` requires an alignment of 8.

This PR adds that alignment capability: when the user passes in a `major_align` argument, no extra padding step is needed. The key to supporting this is making the output offsets aligned to that value. (Output offsets are returned to the user in the 3rd row of `in_out_splits`, on device. The 2nd row, output splits, is unaffected by this alignment value -- i.e. it reflects the true number of tokens for each expert.)

The algorithm is as follows.

![502413288_678786854922438_530852083153996358_n](https://github.com/user-attachments/assets/557624a3-150e-4ab6-ba8b-1dbaa5ac01ac)

In the detailed implementation, we use a warp scan to calculate the prefix sum over the "block" illustrated above. As a result, the "block" size, i.e. `npes`, is currently limited to the warp size of 32.

Pull Request resolved: #155172
Approved by: https://github.com/ngimel
ghstack dependencies: #153653, #153677, #155058
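For intuition, here is a minimal host-side Python sketch of the alignment rule described above (an illustration only, not code from this PR; the helper name `aligned_offsets` is made up). It computes exclusive output offsets where each expert's chunk in the expert-major output starts at a multiple of `major_align`, with an empty chunk still reserving `major_align` slots:

```python
# Minimal host-side sketch of the output-offset alignment (hypothetical helper,
# not the kernel itself). splits_per_expert[i] holds the per-rank token counts
# destined for expert i, i.e. one "block" of the expert-major output sequence.
def aligned_offsets(splits_per_expert, major_align=8):
    offsets = []
    base = 0
    for splits in splits_per_expert:
        # Exclusive prefix sum within this expert's block
        running = base
        for s in splits:
            offsets.append(running)
            running += s
        # Round the block's total length up to `major_align`
        # (at least `major_align`, since empty bins are not supported downstream).
        block_len = max(-(-sum(splits) // major_align) * major_align, major_align)
        base += block_len
    return offsets

# Example: 2 experts, 2 ranks, major_align=8.
# Expert 0 receives 5 + 7 = 12 tokens -> its block occupies 16 slots,
# so expert 1's offsets start at 16 rather than 12.
print(aligned_offsets([[5, 7], [3, 4]], major_align=8))  # [0, 5, 16, 19]
```

This is the same reference computation the updated test below performs with `torch.cumsum` over padded splits.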
1 parent 1ccc57e commit 749757a

File tree: 4 files changed, +151 -23 lines changed

test/distributed/test_nvshmem.py

Lines changed: 37 additions & 13 deletions

@@ -13,6 +13,8 @@
 import torch.distributed._symmetric_memory as symm_mem
 from torch.testing._internal.common_distributed import MultiProcContinousTest
 from torch.testing._internal.common_utils import (
+    instantiate_parametrized_tests,
+    parametrize,
     run_tests,
     skip_but_pass_in_sandcastle_if,
     skipIfRocm,
@@ -42,6 +44,7 @@ def requires_nvshmem():
 device_module = torch.get_device_module(device_type)
 
 
+@instantiate_parametrized_tests
 @requires_nvshmem()
 class NVSHMEMSymmetricMemoryTest(MultiProcContinousTest):
     def _init_device(self) -> None:
@@ -129,7 +132,8 @@ def test_nvshmem_all_to_all_vdev(self) -> None:
         torch.testing.assert_close(out[:out_numel], expected)
 
     @skipIfRocm
-    def test_nvshmem_all_to_all_vdev_2d(self) -> None:
+    @parametrize("align", [1, 8, 16])  # `major_align` of output
+    def test_nvshmem_all_to_all_vdev_2d(self, align: int) -> None:
         torch.manual_seed(42 + self.rank)
         self._init_device()
 
@@ -148,13 +152,12 @@ def test_nvshmem_all_to_all_vdev_2d(self) -> None:
         out_splits = torch.zeros_like(inp_splits)
         dist.all_to_all_single(out_splits, inp_splits)
         # We do a .t() here because there is a rank-major to expert-major shuffle
-        out_splits_t = out_splits.reshape(self.world_size, ne).t().reshape(-1)
+        out_splits_t = out_splits.reshape(self.world_size, ne).t()
 
         # Total number of output elements
         out_numel = out_splits.sum().item()
-        # Align up to make it bigger
-        align = 16
-        out_numel_max = (out_numel + align - 1) // align * align
+        # Align-up makes it bigger
+        out_numel_max = (out_numel + align * ne) // align * align
 
         inp = symm_mem.empty(inp_numel, dtype=dtype, device=self.device).fill_(
             self.rank
@@ -167,20 +170,37 @@ def test_nvshmem_all_to_all_vdev_2d(self) -> None:
         in_out_splits[0].copy_(inp_splits)
 
         torch.ops.symm_mem.nvshmem_all_to_all_vdev_2d(
-            inp, out, in_out_splits, group_name
+            inp, out, in_out_splits, group_name, major_align=align
        
 )
+        received_out_splits = in_out_splits[1]
+        received_out_offsets = in_out_splits[2]
 
         # Check input splits (row 0) -- should not change
         torch.testing.assert_close(in_out_splits[0], inp_splits)
 
         # Check output splits (row 1)
-        torch.testing.assert_close(in_out_splits[1], out_splits_t)
+        torch.testing.assert_close(received_out_splits, out_splits_t.reshape(-1))
 
         # Check output offsets (row 2)
-        out_offsets = torch.cumsum(out_splits_t, dim=0)  # inclusive scan
-        # output offsets from `nvshmem_all_to_all_vdev` is exclusive scan
-        self.assertEqual(in_out_splits[2][0], 0)
-        torch.testing.assert_close(in_out_splits[2][1:], out_offsets[:-1])
+        out_split_list = out_splits_t.tolist()
+        for i in range(ne):
+            expert_sum = 0
+            for j in range(self.world_size):
+                expert_sum += out_split_list[i][j]
+            # Align up expert_sum
+            expert_sum_aligned = (expert_sum + align - 1) // align * align
+            # If 0, make it at least `align` (bc cutlass currently does not support empty bins)
+            expert_sum_aligned = max(expert_sum_aligned, align)
+            # last element absorbs the padding
+            out_split_list[i][-1] += expert_sum_aligned - expert_sum
+
+        out_splits_padded = torch.tensor(out_split_list, device=self.device).reshape(-1)
+        out_offsets = torch.cumsum(out_splits_padded, dim=0)  # inclusive scan
+        # Make it exclusive scan because that's what `nvshmem_all_to_all_vdev_2d` returns
+        out_offsets = torch.cat(
+            [torch.zeros(1, device=self.device), out_offsets[:-1]]
+        ).to(torch.int64)
+        torch.testing.assert_close(received_out_offsets, out_offsets)
 
         # Check data
         expected = torch.empty(out_numel, dtype=dtype, device=self.device)
@@ -199,8 +219,12 @@ def test_nvshmem_all_to_all_vdev_2d(self) -> None:
             chunk = expected[offset - out_splits[chunk_id] : offset]
            
 result_list.append(chunk)
 
-        final = torch.cat(result_list)
-        torch.testing.assert_close(out[:out_numel], final)
+        # Do a chunk-wise comparison
+        for c, chunk in enumerate(result_list):
+            start = received_out_offsets[c].item()
+            split = received_out_splits[c].item()
+            received_chunk = out[start : start + split]
+            torch.testing.assert_close(received_chunk, chunk)
 
 
 if __name__ == "__main__":
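One subtlety in the test above is the output capacity, `out_numel_max = (out_numel + align * ne) // align * align`. Each of the `ne` expert blocks can gain at most `align` extra elements from padding (an empty block is padded up to `align`), so `out_numel + align * ne`, rounded down to a multiple of `align`, is always sufficient. A small check with made-up numbers:

```python
# Hypothetical numbers illustrating the output-capacity bound from the test:
#   out_numel_max = (out_numel + align * ne) // align * align
align, ne = 8, 2
expert_sums = [12, 0]          # tokens per expert after the shuffle
out_numel = sum(expert_sums)   # 12

# Actual padded length: each expert block is rounded up to `align`,
# and an empty block still occupies `align` slots.
padded = sum(max(-(-s // align) * align, align) for s in expert_sums)  # 16 + 8 = 24

out_numel_max = (out_numel + align * ne) // align * align  # (12 + 16) // 8 * 8 = 24
assert padded <= out_numel_max
```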

torch/csrc/distributed/c10d/SymmetricMemory.cpp

Lines changed: 1 addition & 1 deletion

@@ -281,7 +281,7 @@ TORCH_LIBRARY_FRAGMENT(symm_mem, m) {
   m.def(
       "nvshmem_all_to_all_vdev(Tensor input, Tensor(a!) out, Tensor(a!) in_out_splits, str group_name) -> Tensor(a!)");
   m.def(
-      "nvshmem_all_to_all_vdev_2d(Tensor input, Tensor(a!) out, Tensor(a!) in_out_splits, str group_name) -> Tensor(a!)");
+      "nvshmem_all_to_all_vdev_2d(Tensor input, Tensor(a!) out, Tensor(a!) in_out_splits, str group_name, int? major_align=None) -> Tensor(a!)");
 }
 
 TORCH_LIBRARY_IMPL(symm_mem, Meta, m) {
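With the schema change above, `major_align` becomes an optional integer at the Python call site; omitting it keeps the previous behavior, since the implementation falls back to an alignment of 1 (see `major_align.value_or(1)` in the .cu change below). A hedged sketch of the call site, where `shuffle_tokens` is a hypothetical wrapper and the tensor arguments are assumed to be NVSHMEM-backed symmetric-memory buffers set up as in the test above:

```python
import torch

# Hypothetical wrapper around the op; not part of this PR.
def shuffle_tokens(inp, out, in_out_splits, group_name, align=None):
    # Passing align=None (the schema default) behaves like an alignment of 1,
    # i.e. the pre-existing, unpadded output layout.
    return torch.ops.symm_mem.nvshmem_all_to_all_vdev_2d(
        inp, out, in_out_splits, group_name, major_align=align
    )
```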

torch/csrc/distributed/c10d/nvshmem_extension.cu

Lines changed: 111 additions & 8 deletions

@@ -17,6 +17,7 @@ using c10d::symmetric_memory::StoreExchange;
 static StoreExchange storeExchange = StoreExchange("nvshmem_ext");
 
 #define THREADS_PER_BLOCK 512
+#define WARP_SIZE 32
 
 constexpr int MiB = 1024 * 1024;
 
@@ -354,20 +355,100 @@ __global__ void exchangeSplitAndOffset_2d(int64_t* in_out_splits, int mype, int
   nvshmemx_barrier_all_block();
 }
 
+// This is an warp-scope, exclusive prefix sum. When called by a block of
+// threads, each warp will perform an independent prefix sum, concurrently.
+// Returns the sum of all elements in the warp.
+// `NUM_WARPS` is the number of warps participating the concurrent prefix sum.
+template <int NUM_WARPS>
+__device__ int64_t prefixSum_warp(int64_t *odata, int64_t *idata, int n) {
+  CUDA_KERNEL_ASSERT(n <= WARP_SIZE);
+
+  // Specialize WarpScan for type int
+  using WarpScan = at_cuda_detail::cub::WarpScan<int64_t>;
+  // Allocate WarpScan shared memory for N warps
+  __shared__ typename WarpScan::TempStorage temp_storage[NUM_WARPS];
+
+  int warp_id = threadIdx.x / WARP_SIZE;
+  if (warp_id >= NUM_WARPS) {
+    return 0;
+  }
+
+  // Obtain input item for each thread
+  int tid = threadIdx.x % WARP_SIZE;
+  int64_t thread_data = (tid < n) ? idata[tid] : 0;
+
+  // Total sum of all elements in the warp
+  int64_t warp_aggregate;
+  // Compute the warp-wide exclusive prefix sum
+  WarpScan(temp_storage[warp_id]).ExclusiveSum(thread_data, thread_data, warp_aggregate);
+
+  // Store the result
+  odata[tid] = thread_data;
+  return warp_aggregate;
+}
+
+// This is for abstracting a thread-group-scope, exclusive prefix sum.
+// Since we use warp-scope prefix sum, the thread group size is limited to warp size.
+#define A2AV_TILE_SIZE WARP_SIZE
+
 // This kernel is used to do the actual data exchange.
 // `in_out_splits` has the same definition as in `exchangeSplitAndOffset`.
 // `stride` is the stride at dim 0, unit in byte.
 // For meaning of `mype` and `npes`, see the docstring of `nvshmem_all_to_all_vdev_2d`.
-__global__ void allToAllV_2d(void *send_data, void *recv_data, int64_t* in_out_splits, size_t stride, int mype, int npes, int ne) {
+__global__ void allToAllV_2d(void *send_data, void *recv_data, int64_t* in_out_splits, size_t stride, int mype, int npes, int ne, int64_t major_align) {
   int nsplits = npes * ne;
   auto output_splits = in_out_splits + nsplits;
   auto source_offsets = in_out_splits + nsplits * 2;
   int bid = blockIdx.x;
   int tid = threadIdx.x;
 
-  // Calculate the output offsets
-  __shared__ int64_t e_offsets[THREADS_PER_BLOCK];
-  prefixSum(e_offsets, output_splits, nsplits);
+  // Split the thread block into tiles
+  constexpr int NUM_TILES = THREADS_PER_BLOCK / A2AV_TILE_SIZE;
+  int tileId = tid / A2AV_TILE_SIZE;
+  int laneId = tid % A2AV_TILE_SIZE;
+  // Each tile calculates its own prefix sum
+  __shared__ int64_t tile_prefix_sums[NUM_TILES][A2AV_TILE_SIZE];
+  // A tile takes care of npes worth of splits
+  int nsplits_per_tile = min(npes, nsplits - tileId * npes);
+  // TODO: currently it is assumed that the number of PE's is smaller than
+  // `A2AV_TILE_SIZE` bc the warp-scope prefix sum can only handle up to
+  // WARP_SIZE elements
+  CUDA_KERNEL_ASSERT(npes <= A2AV_TILE_SIZE);
+  // Similarly, the number of experts per rank is also assumed to be smaller
+  // than `NUM_TILES`
+  CUDA_KERNEL_ASSERT(ne <= NUM_TILES);
+
+  // Total length of each tile
+  __shared__ int64_t len_per_tile[NUM_TILES];
+  // When `nsplits` is small, not every tile gets data to sum. They can skip
+  // this local prefix sum.
+  if (nsplits_per_tile > 0) {
+    // Each tile calculates its own prefix sum, return value is the sum of all elements in the tile.
+    int64_t my_tile_len = prefixSum_warp<NUM_TILES>(tile_prefix_sums[tileId], output_splits + tileId * npes, nsplits_per_tile);
+    // Last thread in each tile does the up aligning.
+    if (laneId == A2AV_TILE_SIZE - 1) {
+      auto aligned_len = (my_tile_len + major_align - 1) / major_align * major_align;
+      // In case `aligned_len` is 0, we set it to `major_align` to avoid an
+      // empty bin, bc cutlass currently does not support it. See
+      // https://github.com/pytorch/pytorch/issues/152668.
+      len_per_tile[tileId] = max(aligned_len, major_align);
+    }
+  }
+  __syncthreads();
 
+  // Starting offset of each tile
+  __shared__ int64_t start_offset_per_tile[NUM_TILES];
+  // Prefix sum again to get the tiles' start offsets.
+  // `NUM_TILES` is typically not greater than 32, because 32 tiles * 32 threads
+  // = 1024 threads, and this kernel is launched within 1024 threads. Thus, we
+  // can use warp-scope prefix sum.
+  static_assert(NUM_TILES <= WARP_SIZE);
+  // Only 1 warp is needed
+  prefixSum_warp<1>(start_offset_per_tile, len_per_tile, NUM_TILES);
+  __syncthreads();
+
+  // Add tile offset to every element in the tile
+  tile_prefix_sums[tileId][laneId] += start_offset_per_tile[tileId];
   __syncthreads();
 
   // Target a different e based on bid
@@ -376,7 +457,8 @@ __global__ void allToAllV_2d(void *send_data, void *recv_data, int64_t* in_out_s
     // Amount from `peer` for `e`
     auto peer_size = output_splits[eid] * stride;
     auto source_offset = source_offsets[eid] * stride;
-    auto write_offset = e_offsets[eid] * stride;
+    auto e_offset = tile_prefix_sums[eid / npes][peer];
+    auto write_offset = e_offset * stride;
     nvshmemx_getmem_block(
         (char*)recv_data + write_offset,
        
 (char*)send_data + source_offset,
@@ -385,15 +467,16 @@ __global__ void allToAllV_2d(void *send_data, void *recv_data, int64_t* in_out_s
   }
   // Write out the output offsets (to the scratchpad line)
   if (bid == 0 && tid < nsplits) {
-    source_offsets[tid] = e_offsets[tid];
+    source_offsets[tid] = tile_prefix_sums[tid / npes][tid % npes];
   }
 }
 
 at::Tensor nvshmem_all_to_all_vdev_2d(
     at::Tensor& input,
     at::Tensor& out,
     at::Tensor& in_out_splits,
-    std::string group_name) {
+    std::string group_name,
+    std::optional<int64_t> major_align) {
   /* Perform a 2D AllToAllv shuffle operation using NVSHMEM, with split information provided on device.
   * Arguments:
   *  - `input` is the input tensor
@@ -405,6 +488,8 @@ at::Tensor nvshmem_all_to_all_vdev_2d(
        output splits (OUT) and
        output offsets (OUT).
   *  - `group_name` is the name of the group to use for the collective operation.
+  *  - `major_align` is the alignment of the "major dimension" of the output
+       sequence. See below for details.
 
   * A 2D AllToAllv shuffle is illustrated below:
      (world_size = 2, ne = 2, total number of experts = 4)
@@ -418,12 +503,27 @@ at::Tensor nvshmem_all_to_all_vdev_2d(
     `in_out_splits[0]`). That is, the 2D AllToAllv shuffle achives a
     transpose from rank-major order at input to expert-major order at
     output.
+
+  * If `major_align` is not 1, the output offsets of c1, c2, c3 will be
+    up-aligned to this value. For example, if c0 has length 5 and d0 has
+    length 7 (making a total of 12), and if the `major_align` is set to 16,
+    the output offset of c1 will be 16. Similar for c2 and c3. This value has
+    no effect on the offset of the minor dimension, i.e. d0, d1, d2 and d3.
+    Note: since cutlass does not support empty bins, we set the aligned length
+    to `major_align` if it is 0. See
+    https://github.com/pytorch/pytorch/issues/152668.
  */
   auto input_hdl = c10d::symmetric_memory::rendezvous(input, group_name);
   auto out_hdl = c10d::symmetric_memory::rendezvous(out, group_name);
   auto splits_hdl = c10d::symmetric_memory::rendezvous(in_out_splits, group_name);
   int rank = input_hdl->get_rank();
   int world_size = input_hdl->get_world_size();
+  // TODO: world_size is currently limited by the number of elements in a WarpScan.
+  TORCH_CHECK(world_size <= A2AV_TILE_SIZE, "world_size must be smaller than A2AV_TILE_SIZE", A2AV_TILE_SIZE);
+
+  // If `major_align` is not provided, use 1 as the default value.
+  int64_t major_align_val = major_align.value_or(1);
+  TORCH_CHECK(major_align_val > 0, "major_align must be positive");
 
   void* input_ptr = input_hdl->get_buffer_ptrs()[rank];
   void* output_ptr = out_hdl->get_buffer_ptrs()[rank];
@@ -449,6 +549,8 @@ at::Tensor nvshmem_all_to_all_vdev_2d(
 
   // Number of experts per rank
   int ne = split_shape[1] / world_size;
+  constexpr int NUM_TILES = THREADS_PER_BLOCK / A2AV_TILE_SIZE;
+  TORCH_CHECK(ne <= NUM_TILES, "Number of experts must be smaller than NUM_TILES", NUM_TILES);
 
  // Set device context for getting the stream and launching kernels below
  c10::cuda::CUDAGuard guard(input.device());
@@ -487,7 +589,8 @@ at::Tensor nvshmem_all_to_all_vdev_2d(
      &stride_bytes,
      &rank,
      &world_size,
-      &ne};
+      &ne,
+      &major_align_val};
  nvshmemx_collective_launch(
      (const void*)allToAllV_2d,
      dim3(num_blocks),
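To summarize the offset computation in `allToAllV_2d`, here is a plain-Python reference of the two-level prefix sum it performs (an illustration only, not the CUDA code; `tiled_output_offsets` is a made-up name). Each tile, i.e. one expert's group of `npes` splits, does a local exclusive prefix sum; the tile's total is aligned up to `major_align` (minimum `major_align`); a second prefix sum over the aligned tile lengths yields each tile's start offset; and that start offset is added back to the tile's local scan:

```python
# Host-side reference of the tiled offset computation in allToAllV_2d
# (illustration only; a "tile" is one expert's group of `npes` splits).
def tiled_output_offsets(output_splits, npes, major_align):
    ne = len(output_splits) // npes          # experts per rank == number of tiles used
    tiles = [output_splits[i * npes:(i + 1) * npes] for i in range(ne)]

    # 1) per-tile exclusive prefix sum, plus the tile's aligned total length
    tile_scans, tile_lens = [], []
    for tile in tiles:
        scan, running = [], 0
        for s in tile:
            scan.append(running)
            running += s
        aligned = max(-(-running // major_align) * major_align, major_align)
        tile_scans.append(scan)
        tile_lens.append(aligned)

    # 2) exclusive prefix sum over aligned tile lengths -> tile start offsets
    starts, base = [], 0
    for length in tile_lens:
        starts.append(base)
        base += length

    # 3) add each tile's start offset back into its local scan
    return [starts[t] + off for t, scan in enumerate(tile_scans) for off in scan]

# world_size (npes) = 2, ne = 2, major_align = 8:
# expert 0 receives [5, 7] (total 12 -> padded to 16), expert 1 receives [3, 4].
print(tiled_output_offsets([5, 7, 3, 4], npes=2, major_align=8))  # [0, 5, 16, 19]
```

On device, `prefixSum_warp` plays the role of steps 1 and 2, which is why both `npes` and `NUM_TILES` are currently capped at `WARP_SIZE`.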

torch/csrc/distributed/c10d/nvshmem_extension.cuh

Lines changed: 2 additions & 1 deletion

@@ -32,6 +32,7 @@ at::Tensor nvshmem_all_to_all_vdev_2d(
     at::Tensor& input,
     at::Tensor& out,
     at::Tensor& in_out_splits,
-    std::string group_name);
+    std::string group_name,
+    std::optional<int64_t> major_align = std::nullopt);
 
 } // namespace c10d::nvshmem_extension
