@@ -143,7 +143,7 @@ at::Tensor nvshmem_all_to_all(
 }
 
 // This is an exclusive prefix sum function that calculates read (or write) offsets for each peer.
-__device__ void prefixSum(int64_t *odata, int64_t *idata, int n) {
+__device__ int64_t prefixSum(int64_t *odata, int64_t *idata, int n) {
   // Specialize BlockScan for a 1D block of threads, of type int64_t.
   // - `BLOCK_SCAN_WARP_SCANS` is a low-latency scan algorithm (instead of high
   // throughput, which we don't need here).
@@ -161,12 +161,12 @@ __device__ void prefixSum(int64_t *odata, int64_t *idata, int n) {
   int64_t thread_data = (tid < n) ? idata[tid] : 0;
 
   // Collectively compute the block-wide exclusive prefix sum
-  BlockScanT(temp_storage).ExclusiveSum(thread_data, thread_data);
+  int64_t block_aggregate;
+  BlockScanT(temp_storage).ExclusiveSum(thread_data, thread_data, block_aggregate);
 
   // Store the result
-  if (tid < n) {
-    odata[tid] = thread_data;
-  }
+  odata[tid] = thread_data;
+  return block_aggregate;
 }
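For reference, below is a minimal standalone sketch of the CUB `BlockScan` pattern that `prefixSum` now uses: an exclusive prefix sum whose block-wide total is reported through the extra `block_aggregate` argument and returned to the caller. `kThreads` and the demo kernel are made-up names for illustration, not part of this change.

```cuda
// Illustrative only: a self-contained exclusive-sum demo with a block aggregate.
#include <cub/cub.cuh>
#include <cstdio>

constexpr int kThreads = 128;  // stand-in for THREADS_PER_BLOCK

__global__ void exclusiveSumDemo(const int64_t* in, int64_t* out, int64_t* total, int n) {
  using BlockScanT = cub::BlockScan<int64_t, kThreads, cub::BLOCK_SCAN_WARP_SCANS>;
  __shared__ typename BlockScanT::TempStorage temp_storage;

  int tid = threadIdx.x;
  int64_t val = (tid < n) ? in[tid] : 0;  // out-of-range lanes contribute 0
  int64_t aggregate;
  BlockScanT(temp_storage).ExclusiveSum(val, val, aggregate);

  if (tid < n) out[tid] = val;       // out[i] = in[0] + ... + in[i-1]
  if (tid == 0) *total = aggregate;  // block-wide sum of all n inputs
}

int main() {
  // With in = {3, 1, 4, 1}: out = {0, 3, 4, 8}, total = 9.
  int64_t h_in[4] = {3, 1, 4, 1}, h_out[4], h_total;
  int64_t *d_in, *d_out, *d_total;
  cudaMalloc(&d_in, sizeof(h_in));
  cudaMalloc(&d_out, sizeof(h_out));
  cudaMalloc(&d_total, sizeof(int64_t));
  cudaMemcpy(d_in, h_in, sizeof(h_in), cudaMemcpyHostToDevice);
  exclusiveSumDemo<<<1, kThreads>>>(d_in, d_out, d_total, 4);
  cudaMemcpy(h_out, d_out, sizeof(h_out), cudaMemcpyDeviceToHost);
  cudaMemcpy(&h_total, d_total, sizeof(int64_t), cudaMemcpyDeviceToHost);
  printf("out = {%lld, %lld, %lld, %lld}, total = %lld\n",
         (long long)h_out[0], (long long)h_out[1], (long long)h_out[2],
         (long long)h_out[3], (long long)h_total);
  return 0;
}
```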
 
 // This kernel is used to exchange output splits and source offsets between peers.
@@ -318,11 +318,192 @@ at::Tensor nvshmem_all_to_all_vdev(
   return out;
 }
 
+// Start of `nvshmem_all_to_all_vdev_2d`
+// This kernel is used to exchange output splits and source offsets between peers.
+// For the meaning of `mype` and `npes`, see the docstring of `nvshmem_all_to_all_vdev_2d`.
+// `in_out_splits` is of size (3, npes * ne) and contains:
+// - input splits (IN)
+// - output splits (OUT) and
+// - source offsets (OUT).
+__global__ void exchangeSplitAndOffset_2d(int64_t* in_out_splits, int mype, int npes, int ne, size_t input_dim0) {
+  int nsplits = npes * ne;
+  auto input_splits = in_out_splits;
+  auto output_splits = in_out_splits + nsplits;
+  auto source_offsets = in_out_splits + nsplits * 2;
+  int tid = threadIdx.x;
+
+  __shared__ int64_t peer_offsets[THREADS_PER_BLOCK];
+
+  // Scan input splits to get the source offsets
+  auto sum_of_splits = prefixSum(peer_offsets, input_splits, nsplits);
+  __syncthreads();
+  CUDA_KERNEL_ASSERT(sum_of_splits <= input_dim0);
+
+  // Use 1 block to do the exchange
+  if (tid < nsplits) {
+    int peer = tid / ne;
+    int e = tid % ne;
+    // This does a transpose from rank-major order to expert-major order
+    int dst_offset = e * npes + mype;
+    auto split_val = input_splits[tid];
+    CUDA_KERNEL_ASSERT(split_val >= 0);
+    nvshmem_int64_p(source_offsets + dst_offset, peer_offsets[tid], peer);
+    nvshmem_int64_p(output_splits + dst_offset, split_val, peer);
+  }
+  // This barrier ensures that all remote PEs see the updated values
+  nvshmemx_barrier_all_block();
+}
+
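The rank-major to expert-major transpose in the loop above can be sanity-checked on the host; this toy C++ sketch (illustrative only, sizes assumed) prints where each local (peer, expert) slot lands on the destination rank:

```cpp
#include <cstdio>

int main() {
  int npes = 2, ne = 2, mype = 0;  // assumed toy sizes
  for (int tid = 0; tid < npes * ne; ++tid) {
    int peer = tid / ne;              // destination rank
    int e = tid % ne;                 // local expert index on that rank
    int dst_offset = e * npes + mype; // expert-major slot on the destination
    printf("local slot %d (peer=%d, expert=%d) -> remote slot %d on rank %d\n",
           tid, peer, e, dst_offset, peer);
  }
  return 0;
}
```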
+// This kernel is used to do the actual data exchange.
+// `in_out_splits` has the same definition as in `exchangeSplitAndOffset_2d`.
+// `stride` is the stride at dim 0, in bytes.
+// For the meaning of `mype` and `npes`, see the docstring of `nvshmem_all_to_all_vdev_2d`.
+__global__ void allToAllV_2d(void* send_data, void* recv_data, int64_t* in_out_splits, size_t stride, int mype, int npes, int ne) {
+  int nsplits = npes * ne;
+  auto output_splits = in_out_splits + nsplits;
+  auto source_offsets = in_out_splits + nsplits * 2;
+  int bid = blockIdx.x;
+  int tid = threadIdx.x;
+
+  // Calculate the output offsets
+  __shared__ int64_t e_offsets[THREADS_PER_BLOCK];
+  prefixSum(e_offsets, output_splits, nsplits);
+  __syncthreads();
+
+  // Target a different `e` based on `bid`
+  for (int eid = bid; eid < nsplits; eid += gridDim.x) {
+    int peer = eid % npes;
+    // Amount from `peer` for `e`
+    auto peer_size = output_splits[eid] * stride;
+    auto source_offset = source_offsets[eid] * stride;
+    auto write_offset = e_offsets[eid] * stride;
+    nvshmemx_getmem_block(
+        (char*)recv_data + write_offset,
+        (char*)send_data + source_offset,
+        peer_size,
+        peer);
+  }
+  // Write out the output offsets (to the scratchpad row)
+  if (bid == 0 && tid < nsplits) {
+    source_offsets[tid] = e_offsets[tid];
+  }
+}
+
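To summarize the copy pattern above in plain C++: for each expert-major slot `eid`, the kernel pulls `output_splits[eid]` rows from rank `eid % npes`, starting at that rank's `source_offsets[eid]`, and packs them at the running exclusive-prefix-sum offset in the local output. A single-process reference sketch (hypothetical helper name, flat byte buffers assumed):

```cpp
#include <cstring>
#include <cstdint>
#include <vector>

// Illustrative only: single-process emulation of the allToAllV_2d copy pattern.
void all_to_all_vdev_2d_ref(
    const std::vector<std::vector<char>>& send,   // one flat byte buffer per rank
    std::vector<char>& recv,                      // this rank's output buffer
    const std::vector<int64_t>& output_splits,    // expert-major, length npes * ne
    const std::vector<int64_t>& source_offsets,   // expert-major, in rows
    int npes, int ne, size_t stride) {            // stride = bytes per row
  int64_t write_offset = 0;                       // running exclusive prefix sum
  for (int eid = 0; eid < npes * ne; ++eid) {
    int peer = eid % npes;                        // expert-major: source rank varies fastest
    std::memcpy(recv.data() + write_offset * stride,
                send[peer].data() + source_offsets[eid] * stride,
                output_splits[eid] * stride);
    write_offset += output_splits[eid];
  }
}
```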
+at::Tensor nvshmem_all_to_all_vdev_2d(
+    at::Tensor& input,
+    at::Tensor& out,
+    at::Tensor& in_out_splits,
+    std::string group_name) {
+  /* Perform a 2D AllToAllv shuffle operation using NVSHMEM, with split information provided on device.
+   * Arguments:
+   *  - `input` is the input tensor
+   *  - `out` is the output tensor
+   *  - `in_out_splits` is a 2D tensor of size (3, `world_size` * `ne`). In the
+        scenario of Mixture-of-Experts models, `ne` is the number of experts per
+        rank. The rows of `in_out_splits` are (in order):
+          input splits (IN)
+          output splits (OUT) and
+          output offsets (OUT).
+   *  - `group_name` is the name of the group to use for the collective operation.
+
+   * A 2D AllToAllv shuffle is illustrated below:
+        (world_size = 2, ne = 2, total number of experts = 4)
+        Source: | Rank 0            | Rank 1            |
+                | c0 | c1 | c2 | c3 | d0 | d1 | d2 | d3 |
+
+        Dest  : | Rank 0            | Rank 1            |
+                | c0 | d0 | c1 | d1 | c2 | d2 | c3 | d3 |
+     where each `c_i` / `d_i` is a slice of the `input` tensor targeting
+     expert `i`, with length indicated by the input splits (in
+     `in_out_splits[0]`). That is, the 2D AllToAllv shuffle achieves a
+     transpose from rank-major order at the input to expert-major order at
+     the output.
+   */
+  auto input_hdl = c10d::symmetric_memory::rendezvous(input, group_name);
+  auto out_hdl = c10d::symmetric_memory::rendezvous(out, group_name);
+  auto splits_hdl = c10d::symmetric_memory::rendezvous(in_out_splits, group_name);
+  int rank = input_hdl->get_rank();
+  int world_size = input_hdl->get_world_size();
+
+  void* input_ptr = input_hdl->get_buffer_ptrs()[rank];
+  void* output_ptr = out_hdl->get_buffer_ptrs()[rank];
+  int64_t* splits_ptr = (int64_t*)(splits_hdl->get_buffer_ptrs()[rank]);
+
+  // Shape checks
+  auto split_shape = in_out_splits.sizes();
+  TORCH_CHECK(in_out_splits.is_contiguous()
+      && input.is_contiguous()
+      && out.is_contiguous(),
+      "input, out and in_out_splits must be contiguous");
+  TORCH_CHECK(split_shape.size() == 2
+      && split_shape[0] == 3
+      && split_shape[1] % world_size == 0,
+      "in_out_splits must be 2D with 3 rows, "
+      "each row must be a multiple of world_size");
+
+  // Consistency checks
+  TORCH_CHECK(input.dtype() == out.dtype()
+      && input.stride(0) == out.stride(0),
+      "input and out must have the same dtype and same stride at dim 0");
+  TORCH_CHECK(in_out_splits.scalar_type() == at::kLong, "in_out_splits must be int64");
+
+  // Number of experts per rank
+  int ne = split_shape[1] / world_size;
+
+  // Set device context for getting the stream and launching kernels below
+  c10::cuda::CUDAGuard guard(input.device());
+  auto stream = at::cuda::getCurrentCUDAStream();
+
+  // Exchange output splits and source offsets
+  auto input_dim0 = input.size(0);
+  // Use collective launch because the kernel involves an nvshmem barrier
+  void* args0[] = {
+      &splits_ptr,
+      &rank,
+      &world_size,
+      &ne,
+      &input_dim0};
+  nvshmemx_collective_launch(
+      (const void*)exchangeSplitAndOffset_2d,
+      dim3(1),
+      dim3(THREADS_PER_BLOCK),
+      args0,
+      0,
+      stream);
+
+  // CTA tuning
+  // Naive for now: use 1 block per expert.
+  // The total number of blocks is limited to 64 (intra-node) or 8 (inter-node).
+  int num_blocks = std::min(world_size * ne, world_size > 8 ? 8 : 64);
+
+  // Stride at dim 0, in bytes
+  size_t stride_bytes = input.stride(0) * input.element_size();
+
+  // All-to-all data exchange
+  void* args1[] = {
+      &input_ptr,
+      &output_ptr,
+      &splits_ptr,
+      &stride_bytes,
+      &rank,
+      &world_size,
+      &ne};
+  nvshmemx_collective_launch(
+      (const void*)allToAllV_2d,
+      dim3(num_blocks),
+      dim3(THREADS_PER_BLOCK),
+      args1,
+      0,
+      stream);
+  return out;
+}
+
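To make the expert-major output layout concrete, here is a small host-side sketch (toy split values, illustrative only, not part of this change) that derives each destination rank's output splits from everyone's rank-major input splits, reproducing the `| c0 | d0 | c1 | d1 |` ordering in the docstring:

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  // Toy sizes matching the docstring illustration: 2 ranks, 2 experts per rank.
  int npes = 2, ne = 2;
  // input_splits[src][dst * ne + e]: rank-major splits on each source rank,
  // e.g. |c0|c1|c2|c3| on rank 0 and |d0|d1|d2|d3| on rank 1.
  std::vector<std::vector<int64_t>> input_splits = {{1, 2, 3, 4}, {5, 6, 7, 8}};

  for (int dst = 0; dst < npes; ++dst) {
    printf("rank %d output splits (expert-major):", dst);
    for (int e = 0; e < ne; ++e) {            // local expert on `dst`
      for (int src = 0; src < npes; ++src) {  // source rank varies fastest
        printf(" %lld", (long long)input_splits[src][dst * ne + e]);
      }
    }
    printf("\n");  // rank 0: 1 5 2 6 (c0 d0 c1 d1);  rank 1: 3 7 4 8 (c2 d2 c3 d3)
  }
  return 0;
}
```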
 } // namespace c10d::nvshmem_extension
 
 
 TORCH_LIBRARY_IMPL(symm_mem, CUDA, m) {
   m.impl("nvshmem_broadcast", c10d::nvshmem_extension::nvshmem_broadcast);
   m.impl("nvshmem_all_to_all", c10d::nvshmem_extension::nvshmem_all_to_all);
   m.impl("nvshmem_all_to_all_vdev", c10d::nvshmem_extension::nvshmem_all_to_all_vdev);
+  m.impl("nvshmem_all_to_all_vdev_2d", c10d::nvshmem_extension::nvshmem_all_to_all_vdev_2d);
 }