@@ -17,6 +17,7 @@ using c10d::symmetric_memory::StoreExchange;
 static StoreExchange storeExchange = StoreExchange("nvshmem_ext");
 
 #define THREADS_PER_BLOCK 512
+#define WARP_SIZE 32
 
 constexpr int MiB = 1024 * 1024;
 
@@ -354,20 +355,100 @@ __global__ void exchangeSplitAndOffset_2d(int64_t* in_out_splits, int mype, int
   nvshmemx_barrier_all_block();
 }
 
+// This is a warp-scope, exclusive prefix sum. When called by a block of
+// threads, each warp performs an independent prefix sum, concurrently.
+// Returns the sum of all elements in the warp.
+// `NUM_WARPS` is the number of warps participating in the concurrent prefix sum.
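+// For example (illustrative values only): given idata = {3, 1, 4, 1, 5, 9} and
+// n = 6, a warp writes odata = {0, 3, 4, 8, 9, 14} and returns 23.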
+template <int NUM_WARPS>
+__device__ int64_t prefixSum_warp(int64_t* odata, int64_t* idata, int n) {
+  CUDA_KERNEL_ASSERT(n <= WARP_SIZE);
+
+  // Specialize WarpScan for type int64_t
+  using WarpScan = at_cuda_detail::cub::WarpScan<int64_t>;
+  // Allocate WarpScan shared memory for NUM_WARPS warps
+  __shared__ typename WarpScan::TempStorage temp_storage[NUM_WARPS];
+
+  int warp_id = threadIdx.x / WARP_SIZE;
+  if (warp_id >= NUM_WARPS) {
+    return 0;
+  }
+
+  // Obtain the input item for each thread
+  int tid = threadIdx.x % WARP_SIZE;
+  int64_t thread_data = (tid < n) ? idata[tid] : 0;
+
+  // Total sum of all elements in the warp
+  int64_t warp_aggregate;
+  // Compute the warp-wide exclusive prefix sum
+  WarpScan(temp_storage[warp_id]).ExclusiveSum(thread_data, thread_data, warp_aggregate);
+
+  // Store the result
+  odata[tid] = thread_data;
+  return warp_aggregate;
+}
+
+// `A2AV_TILE_SIZE` abstracts the thread-group scope of the exclusive prefix sum.
+// Since we use a warp-scope prefix sum, the thread group size is limited to the warp size.
+#define A2AV_TILE_SIZE WARP_SIZE
+
 // This kernel is used to do the actual data exchange.
 // `in_out_splits` has the same definition as in `exchangeSplitAndOffset`.
 // `stride` is the stride at dim 0, unit in byte.
 // For meaning of `mype` and `npes`, see the docstring of `nvshmem_all_to_all_vdev_2d`.
-__global__ void allToAllV_2d(void *send_data, void *recv_data, int64_t* in_out_splits, size_t stride, int mype, int npes, int ne) {
+__global__ void allToAllV_2d(void *send_data, void *recv_data, int64_t* in_out_splits, size_t stride, int mype, int npes, int ne, int64_t major_align) {
   int nsplits = npes * ne;
   auto output_splits = in_out_splits + nsplits;
   auto source_offsets = in_out_splits + nsplits * 2;
   int bid = blockIdx.x;
   int tid = threadIdx.x;
 
-  // Calculate the output offsets
-  __shared__ int64_t e_offsets[THREADS_PER_BLOCK];
-  prefixSum(e_offsets, output_splits, nsplits);
+  // Split the thread block into tiles
+  constexpr int NUM_TILES = THREADS_PER_BLOCK / A2AV_TILE_SIZE;
+  int tileId = tid / A2AV_TILE_SIZE;
+  int laneId = tid % A2AV_TILE_SIZE;
+  // Each tile calculates its own prefix sum
+  __shared__ int64_t tile_prefix_sums[NUM_TILES][A2AV_TILE_SIZE];
+  // A tile takes care of npes worth of splits
+  int nsplits_per_tile = min(npes, nsplits - tileId * npes);
+  // TODO: currently it is assumed that the number of PEs is no larger than
+  // `A2AV_TILE_SIZE`, because the warp-scope prefix sum can only handle up to
+  // WARP_SIZE elements
+  CUDA_KERNEL_ASSERT(npes <= A2AV_TILE_SIZE);
+  // Similarly, the number of experts per rank is assumed to be no larger
+  // than `NUM_TILES`
+  CUDA_KERNEL_ASSERT(ne <= NUM_TILES);
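+  // With THREADS_PER_BLOCK = 512 and A2AV_TILE_SIZE = 32, NUM_TILES is 16;
+  // tile `t` covers output_splits[t * npes .. (t + 1) * npes - 1], i.e. the
+  // splits of local expert `t`.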
+
+  // Total length of each tile
+  __shared__ int64_t len_per_tile[NUM_TILES];
+  // When `nsplits` is small, not every tile gets data to sum. Such tiles can
+  // skip this local prefix sum.
+  if (nsplits_per_tile > 0) {
+    // Each tile calculates its own prefix sum; the return value is the sum of all elements in the tile.
+    int64_t my_tile_len = prefixSum_warp<NUM_TILES>(tile_prefix_sums[tileId], output_splits + tileId * npes, nsplits_per_tile);
+    // Last thread in each tile does the up-aligning.
+    if (laneId == A2AV_TILE_SIZE - 1) {
+      auto aligned_len = (my_tile_len + major_align - 1) / major_align * major_align;
+      // In case `aligned_len` is 0, we set it to `major_align` to avoid an
+      // empty bin, because cutlass currently does not support it. See
+      // https://github.com/pytorch/pytorch/issues/152668.
+      len_per_tile[tileId] = max(aligned_len, major_align);
+    }
+  }
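+  // For example, if a tile's splits sum to 12 and `major_align` is 16, its
+  // aligned length becomes 16 (cf. the c0/d0 example in the docstring of
+  // `nvshmem_all_to_all_vdev_2d`).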
+  __syncthreads();
+
+  // Starting offset of each tile
+  __shared__ int64_t start_offset_per_tile[NUM_TILES];
+  // Prefix sum again to get the tiles' start offsets.
+  // `NUM_TILES` is typically not greater than 32, because 32 tiles * 32 threads
+  // = 1024 threads, and this kernel is launched with at most 1024 threads.
+  // Thus, we can use the warp-scope prefix sum.
+  static_assert(NUM_TILES <= WARP_SIZE);
+  // Only 1 warp is needed
+  prefixSum_warp<1>(start_offset_per_tile, len_per_tile, NUM_TILES);
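+  // E.g., if the first four tiles have len_per_tile = {16, 16, 32, 16}, their
+  // start offsets become {0, 16, 32, 64} (illustrative values).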
+  __syncthreads();
+
+  // Add tile offset to every element in the tile
+  tile_prefix_sums[tileId][laneId] += start_offset_per_tile[tileId];
   __syncthreads();
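+  // tile_prefix_sums[e][p] now holds the output offset (in units of dim-0
+  // elements) of the chunk received from rank `p` for local expert `e`.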
 
   // Target a different e based on bid
@@ -376,7 +457,8 @@ __global__ void allToAllV_2d(void *send_data, void *recv_data, int64_t* in_out_s
     // Amount from `peer` for `e`
     auto peer_size = output_splits[eid] * stride;
     auto source_offset = source_offsets[eid] * stride;
-    auto write_offset = e_offsets[eid] * stride;
+    auto e_offset = tile_prefix_sums[eid / npes][peer];
+    auto write_offset = e_offset * stride;
     nvshmemx_getmem_block(
         (char *)recv_data + write_offset,
         (char *)send_data + source_offset,
@@ -385,15 +467,16 @@ __global__ void allToAllV_2d(void *send_data, void *recv_data, int64_t* in_out_s
   }
   // Write out the output offsets (to the scratchpad line)
   if (bid == 0 && tid < nsplits) {
-    source_offsets[tid] = e_offsets[tid];
+    source_offsets[tid] = tile_prefix_sums[tid / npes][tid % npes];
   }
 }
 
 at::Tensor nvshmem_all_to_all_vdev_2d(
     at::Tensor& input,
     at::Tensor& out,
     at::Tensor& in_out_splits,
-    std::string group_name) {
+    std::string group_name,
+    std::optional<int64_t> major_align) {
   /* Perform a 2D AllToAllv shuffle operation using NVSHMEM, with split information provided on device.
    * Arguments:
    *  - `input` is the input tensor
@@ -405,6 +488,8 @@ at::Tensor nvshmem_all_to_all_vdev_2d(
         output splits (OUT) and
         output offsets (OUT).
    *  - `group_name` is the name of the group to use for the collective operation.
+   *  - `major_align` is the alignment of the "major dimension" of the output
+        sequence. See below for details.
 
    * A 2D AllToAllv shuffle is illustrated below:
      (world_size = 2, ne = 2, total number of experts = 4)
@@ -418,12 +503,27 @@ at::Tensor nvshmem_all_to_all_vdev_2d(
      `in_out_splits[0]`). That is, the 2D AllToAllv shuffle achieves a
      transpose from rank-major order at input to expert-major order at
      output.
+
+   * If `major_align` is not 1, the output offsets of c1, c2, c3 will be
+     up-aligned to this value. For example, if c0 has length 5 and d0 has
+     length 7 (making a total of 12), and if `major_align` is set to 16, the
+     output offset of c1 will be 16. Similarly for c2 and c3. This value has
+     no effect on the offsets of the minor dimension, i.e. d0, d1, d2 and d3.
+     Note: since cutlass does not support empty bins, we set the aligned
+     length to `major_align` if it is 0. See
+     https://github.com/pytorch/pytorch/issues/152668.
    */
   auto input_hdl = c10d::symmetric_memory::rendezvous(input, group_name);
   auto out_hdl = c10d::symmetric_memory::rendezvous(out, group_name);
   auto splits_hdl = c10d::symmetric_memory::rendezvous(in_out_splits, group_name);
   int rank = input_hdl->get_rank();
   int world_size = input_hdl->get_world_size();
+  // TODO: world_size is currently limited by the number of elements in a WarpScan.
+  TORCH_CHECK(world_size <= A2AV_TILE_SIZE, "world_size must not exceed A2AV_TILE_SIZE: ", A2AV_TILE_SIZE);
+
+  // If `major_align` is not provided, use 1 as the default value.
+  int64_t major_align_val = major_align.value_or(1);
+  TORCH_CHECK(major_align_val > 0, "major_align must be positive");
 
   void* input_ptr = input_hdl->get_buffer_ptrs()[rank];
   void* output_ptr = out_hdl->get_buffer_ptrs()[rank];
@@ -449,6 +549,8 @@ at::Tensor nvshmem_all_to_all_vdev_2d(
 
   // Number of experts per rank
   int ne = split_shape[1] / world_size;
+  constexpr int NUM_TILES = THREADS_PER_BLOCK / A2AV_TILE_SIZE;
+  TORCH_CHECK(ne <= NUM_TILES, "Number of experts per rank must not exceed NUM_TILES: ", NUM_TILES);
 
   // Set device context for getting the stream and launching kernels below
   c10::cuda::CUDAGuard guard(input.device());
@@ -487,7 +589,8 @@ at::Tensor nvshmem_all_to_all_vdev_2d(
       &stride_bytes,
       &rank,
       &world_size,
-      &ne};
+      &ne,
+      &major_align_val};
   nvshmemx_collective_launch(
       (const void *)allToAllV_2d,
       dim3(num_blocks),