@@ -17,6 +17,7 @@ using c10d::symmetric_memory::StoreExchange;
 static StoreExchange storeExchange = StoreExchange("nvshmem_ext");
 
 #define THREADS_PER_BLOCK 512
+#define WARP_SIZE 32
 
 // Bootstrap based on user's setting for NCCL
 // Long term, this may be a bit unclean; short term, it improves UX
@@ -347,20 +348,100 @@ __global__ void exchangeSplitAndOffset_2d(int64_t* in_out_splits, int mype, int
   nvshmemx_barrier_all_block();
 }
 
+// This is a warp-scope, exclusive prefix sum. When called by a block of
+// threads, each warp performs an independent prefix sum, concurrently.
+// Returns the sum of all elements in the warp.
+// `NUM_WARPS` is the number of warps participating in the concurrent prefix sum.
+template <int NUM_WARPS>
+__device__ int64_t prefixSum_warp(int64_t* odata, int64_t* idata, int n) {
+  CUDA_KERNEL_ASSERT(n <= WARP_SIZE);
+
+  // Specialize WarpScan for type int64_t
+  using WarpScan = at_cuda_detail::cub::WarpScan<int64_t>;
+  // Allocate WarpScan shared memory for NUM_WARPS warps
+  __shared__ typename WarpScan::TempStorage temp_storage[NUM_WARPS];
+
+  int warp_id = threadIdx.x / WARP_SIZE;
+  if (warp_id >= NUM_WARPS) {
+    return 0;
+  }
+
+  // Obtain input item for each thread
+  int tid = threadIdx.x % WARP_SIZE;
+  int64_t thread_data = (tid < n) ? idata[tid] : 0;
+
+  // Total sum of all elements in the warp
+  int64_t warp_aggregate;
+  // Compute the warp-wide exclusive prefix sum
+  WarpScan(temp_storage[warp_id]).ExclusiveSum(thread_data, thread_data, warp_aggregate);
+
+  // Store the result
+  odata[tid] = thread_data;
+  return warp_aggregate;
+}
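+// Example: with n = 4 and idata = {3, 1, 4, 2}, the calling warp writes
+// odata = {0, 3, 4, 8} and returns a warp aggregate of 10.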
+
+// This is for abstracting a thread-group-scope, exclusive prefix sum.
+// Since we use a warp-scope prefix sum, the thread group size is limited to the warp size.
+#define A2AV_TILE_SIZE WARP_SIZE
+
 // This kernel is used to do the actual data exchange.
 // `in_out_splits` has the same definition as in `exchangeSplitAndOffset`.
 // `stride` is the stride at dim 0, in bytes.
 // For the meaning of `mype` and `npes`, see the docstring of `nvshmem_all_to_all_vdev_2d`.
-__global__ void allToAllV_2d(void *send_data, void *recv_data, int64_t* in_out_splits, size_t stride, int mype, int npes, int ne) {
+__global__ void allToAllV_2d(void *send_data, void *recv_data, int64_t* in_out_splits, size_t stride, int mype, int npes, int ne, int64_t major_align) {
   int nsplits = npes * ne;
   auto output_splits = in_out_splits + nsplits;
   auto source_offsets = in_out_splits + nsplits * 2;
   int bid = blockIdx.x;
   int tid = threadIdx.x;
 
-  // Calculate the output offsets
-  __shared__ int64_t e_offsets[THREADS_PER_BLOCK];
-  prefixSum(e_offsets, output_splits, nsplits);
+  // Split the thread block into tiles
+  constexpr int NUM_TILES = THREADS_PER_BLOCK / A2AV_TILE_SIZE;
+  int tileId = tid / A2AV_TILE_SIZE;
+  int laneId = tid % A2AV_TILE_SIZE;
+  // Each tile calculates its own prefix sum
+  __shared__ int64_t tile_prefix_sums[NUM_TILES][A2AV_TILE_SIZE];
+  // A tile takes care of npes worth of splits
+  int nsplits_per_tile = min(npes, nsplits - tileId * npes);
+  // TODO: currently it is assumed that the number of PEs is at most
+  // `A2AV_TILE_SIZE`, because the warp-scope prefix sum can only handle up to
+  // WARP_SIZE elements
+  CUDA_KERNEL_ASSERT(npes <= A2AV_TILE_SIZE);
+  // Similarly, the number of experts per rank is assumed to be at most
+  // `NUM_TILES`
+  CUDA_KERNEL_ASSERT(ne <= NUM_TILES);
+
+  // Total length of each tile
+  __shared__ int64_t len_per_tile[NUM_TILES];
+  // When `nsplits` is small, not every tile gets data to sum. Such tiles can
+  // skip this local prefix sum.
+  if (nsplits_per_tile > 0) {
+    // Each tile calculates its own prefix sum; the return value is the sum of
+    // all elements in the tile.
+    int64_t my_tile_len = prefixSum_warp<NUM_TILES>(tile_prefix_sums[tileId], output_splits + tileId * npes, nsplits_per_tile);
+    // Last thread in each tile does the up-aligning.
+    if (laneId == A2AV_TILE_SIZE - 1) {
+      auto aligned_len = (my_tile_len + major_align - 1) / major_align * major_align;
+      // In case `aligned_len` is 0, we set it to `major_align` to avoid an
+      // empty bin, because cutlass currently does not support empty bins. See
+      // https://github.com/pytorch/pytorch/issues/152668.
+      len_per_tile[tileId] = max(aligned_len, major_align);
+    }
+  }
+  __syncthreads();
+
+  // Starting offset of each tile
+  __shared__ int64_t start_offset_per_tile[NUM_TILES];
+  // Prefix sum again to get the tiles' start offsets.
+  // `NUM_TILES` is typically not greater than 32, because 32 tiles * 32 threads
+  // = 1024 threads, and this kernel is launched with at most 1024 threads.
+  // Thus, we can use a warp-scope prefix sum.
+  static_assert(NUM_TILES <= WARP_SIZE);
+  // Only 1 warp is needed
+  prefixSum_warp<1>(start_offset_per_tile, len_per_tile, NUM_TILES);
+  __syncthreads();
+
+  // Add the tile's start offset to every element in the tile
+  tile_prefix_sums[tileId][laneId] += start_offset_per_tile[tileId];
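+  // Worked example (illustrative values): npes = 2, ne = 2, major_align = 16,
+  // output_splits = {5, 7, 3, 2}. The tile-local prefix sums are {0, 5} and
+  // {0, 3}; tile lengths 12 and 5 are up-aligned to 16 and 16, giving tile
+  // start offsets {0, 16} and final output offsets {0, 5, 16, 19}.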
   __syncthreads();
 
   // Target a different e based on bid
@@ -369,7 +450,8 @@ __global__ void allToAllV_2d(void *send_data, void *recv_data, int64_t* in_out_s
     // Amount from `peer` for `e`
     auto peer_size = output_splits[eid] * stride;
     auto source_offset = source_offsets[eid] * stride;
-    auto write_offset = e_offsets[eid] * stride;
+    auto e_offset = tile_prefix_sums[eid / npes][peer];
+    auto write_offset = e_offset * stride;
     nvshmemx_getmem_block(
         (char*)recv_data + write_offset,
         (char*)send_data + source_offset,
@@ -378,15 +460,16 @@ __global__ void allToAllV_2d(void *send_data, void *recv_data, int64_t* in_out_s
   }
   // Write out the output offsets (to the scratchpad line)
   if (bid == 0 && tid < nsplits) {
-    source_offsets[tid] = e_offsets[tid];
+    source_offsets[tid] = tile_prefix_sums[tid / npes][tid % npes];
   }
 }
 
 at::Tensor nvshmem_all_to_all_vdev_2d(
     at::Tensor& input,
     at::Tensor& out,
     at::Tensor& in_out_splits,
-    std::string group_name) {
+    std::string group_name,
+    std::optional<int64_t> major_align) {
   /* Perform a 2D AllToAllv shuffle operation using NVSHMEM, with split information provided on device.
    * Arguments:
    *  - `input` is the input tensor
@@ -398,6 +481,8 @@ at::Tensor nvshmem_all_to_all_vdev_2d(
            output splits (OUT) and
            output offsets (OUT).
    *  - `group_name` is the name of the group to use for the collective operation.
+   *  - `major_align` is the alignment of the "major dimension" of the output
+       sequence. See below for details.
 
    * A 2D AllToAllv shuffle is illustrated below:
      (world_size = 2, ne = 2, total number of experts = 4)
@@ -411,12 +496,27 @@ at::Tensor nvshmem_all_to_all_vdev_2d(
      `in_out_splits[0]`). That is, the 2D AllToAllv shuffle achieves a
      transpose from rank-major order at input to expert-major order at
      output.
+
+   * If `major_align` is not 1, the output offsets of c1, c2, c3 will be
+     up-aligned to this value. For example, if c0 has length 5 and d0 has
+     length 7 (making a total of 12), and if `major_align` is set to 16,
+     the output offset of c1 will be 16. Similarly for c2 and c3. This value
+     has no effect on the offsets of the minor dimension, i.e. d0, d1, d2 and d3.
+     Note: since cutlass does not support empty bins, we set the aligned length
+     to `major_align` if it is 0. See
+     https://github.com/pytorch/pytorch/issues/152668.
   */
   auto input_hdl = c10d::symmetric_memory::rendezvous(input, group_name);
   auto out_hdl = c10d::symmetric_memory::rendezvous(out, group_name);
   auto splits_hdl = c10d::symmetric_memory::rendezvous(in_out_splits, group_name);
   int rank = input_hdl->get_rank();
   int world_size = input_hdl->get_world_size();
+  // TODO: world_size is currently limited by the number of elements in a WarpScan.
+  TORCH_CHECK(world_size <= A2AV_TILE_SIZE, "world_size must be at most A2AV_TILE_SIZE: ", A2AV_TILE_SIZE);
+
+  // If `major_align` is not provided, use 1 as the default value.
+  int64_t major_align_val = major_align.value_or(1);
+  TORCH_CHECK(major_align_val > 0, "major_align must be positive");
 
   void* input_ptr = input_hdl->get_buffer_ptrs()[rank];
   void* output_ptr = out_hdl->get_buffer_ptrs()[rank];
@@ -442,6 +542,8 @@ at::Tensor nvshmem_all_to_all_vdev_2d(
 
   // Number of experts per rank
   int ne = split_shape[1] / world_size;
+  constexpr int NUM_TILES = THREADS_PER_BLOCK / A2AV_TILE_SIZE;
+  TORCH_CHECK(ne <= NUM_TILES, "Number of experts per rank must be at most NUM_TILES: ", NUM_TILES);
 
   // Set device context for getting the stream and launching kernels below
   c10::cuda::CUDAGuard guard(input.device());
@@ -480,7 +582,8 @@ at::Tensor nvshmem_all_to_all_vdev_2d(
       &stride_bytes,
       &rank,
       &world_size,
-      &ne};
+      &ne,
+      &major_align_val};
   nvshmemx_collective_launch(
       (const void*)allToAllV_2d,
       dim3(num_blocks),
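
For reference, the offset math behind the new `major_align` behavior can be sketched on the host. The snippet below is a minimal, illustrative C++ reconstruction, not part of this change: the helper name `outputOffsets` and the expert-1 split values (3, 2) are made up for the example. It mirrors what the kernel computes with its two-level warp-scope prefix sums: an exclusive prefix sum within each expert's splits, each expert's total up-aligned to `major_align` (with empty experts still occupying one `major_align`-sized bin), and an exclusive prefix sum over the aligned totals for the experts' start offsets.

```cpp
#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <vector>

// Hypothetical host-side mirror of allToAllV_2d's output-offset computation.
std::vector<int64_t> outputOffsets(
    const std::vector<int64_t>& output_splits, // expert-major, size ne * npes
    int npes,
    int64_t major_align) {
  int ne = static_cast<int>(output_splits.size()) / npes;
  std::vector<int64_t> offsets(output_splits.size());
  int64_t expert_start = 0;
  for (int e = 0; e < ne; ++e) {
    // Exclusive prefix sum over this expert's splits (one "tile" in the kernel).
    int64_t running = 0;
    for (int p = 0; p < npes; ++p) {
      offsets[e * npes + p] = expert_start + running;
      running += output_splits[e * npes + p];
    }
    // Up-align the expert's total; an empty expert still gets one bin
    // (see https://github.com/pytorch/pytorch/issues/152668).
    int64_t aligned = (running + major_align - 1) / major_align * major_align;
    expert_start += std::max(aligned, major_align);
  }
  return offsets;
}

int main() {
  // Docstring example: c0 = 5 and d0 = 7 for expert 0 (total 12), so with
  // major_align = 16 expert 1 starts at offset 16. Expert 1's splits (3, 2)
  // are arbitrary illustrative values.
  auto offsets = outputOffsets({5, 7, 3, 2}, /*npes=*/2, /*major_align=*/16);
  for (int64_t o : offsets) {
    printf("%lld ", static_cast<long long>(o)); // prints: 0 5 16 19
  }
  printf("\n");
  return 0;
}
```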