  return out;
} // end of nvshmem_all_to_all_vdev

// Start of `nvshmem_all_to_all_vdev_2d`
// This kernel is used to exchange output splits and source offsets between peers.
// For the meaning of `mype` and `npes`, see the docstring of `nvshmem_all_to_all_vdev_2d`.
// `in_out_splits` is of size (3, npes * ne) and contains:
//   - input splits (IN)
//   - output splits (OUT) and
//   - source offsets (OUT; later overwritten with the output offsets by `allToAllV_2d`).
__global__ void exchangeSplitAndOffset_2d(int64_t* in_out_splits, int mype, int npes, int ne) {
  int nsplits = npes * ne;
  auto input_splits = in_out_splits;
  auto output_splits = in_out_splits + nsplits;
  auto source_offsets = in_out_splits + nsplits * 2;
  int tid = threadIdx.x;

  __shared__ int64_t peer_offsets[THREADS_PER_BLOCK];

  // Scan input splits to get the source offsets
  prefixSum(peer_offsets, input_splits, nsplits);
  __syncthreads();

  // Use a single block to do the exchange
  if (tid < nsplits) {
    int peer = tid / ne;
    int e = tid % ne;
    // This does a transpose from rank-major order to expert-major order
    int dst_offset = e * npes + mype;
    nvshmem_int64_p(source_offsets + dst_offset, peer_offsets[tid], peer);
    nvshmem_int64_p(output_splits + dst_offset, input_splits[tid], peer);
  }
  // This barrier ensures that all remote PEs see the updated values
  nvshmemx_barrier_all_block();
}
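
// Note: `prefixSum` used above is a helper defined elsewhere in this file.
// For exposition, a minimal single-block exclusive scan with the same call
// shape might look like the sketch below (assumptions: one thread per element
// and `n <= THREADS_PER_BLOCK`; the name `prefixSumExclusiveSketch` is
// illustrative, not the actual helper).
__device__ void prefixSumExclusiveSketch(int64_t* odata, const int64_t* idata, int n) {
  int tid = threadIdx.x;
  if (tid < n) {
    // Exclusive scan: shift input right by one; element 0 gets the identity
    odata[tid] = (tid == 0) ? 0 : idata[tid - 1];
  }
  __syncthreads();
  // Hillis-Steele scan over shared memory
  for (int offset = 1; offset < n; offset *= 2) {
    int64_t val = 0;
    if (tid < n && tid >= offset) {
      val = odata[tid - offset];
    }
    __syncthreads();
    if (tid < n && tid >= offset) {
      odata[tid] += val;
    }
    __syncthreads();
  }
}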

// This kernel is used to do the actual data exchange.
// `in_out_splits` has the same definition as in `exchangeSplitAndOffset_2d`.
// `stride` is the stride at dim 0, in bytes.
// For the meaning of `mype` and `npes`, see the docstring of `nvshmem_all_to_all_vdev_2d`.
__global__ void allToAllV_2d(void* send_data, void* recv_data, int64_t* in_out_splits, size_t stride, int mype, int npes, int ne) {
  int nsplits = npes * ne;
  auto output_splits = in_out_splits + nsplits;
  auto source_offsets = in_out_splits + nsplits * 2;
  int bid = blockIdx.x;
  int tid = threadIdx.x;

  // Calculate the output offsets
  __shared__ int64_t e_offsets[THREADS_PER_BLOCK];
  prefixSum(e_offsets, output_splits, nsplits);
  __syncthreads();

  // Each block targets a different (expert, peer) chunk, strided by gridDim.x
  for (int eid = bid; eid < nsplits; eid += gridDim.x) {
    int peer = eid % npes;
    // Amount of data from `peer` for this expert
    auto peer_size = output_splits[eid] * stride;
    auto source_offset = source_offsets[eid] * stride;
    auto write_offset = e_offsets[eid] * stride;
    // Pull the chunk from `peer`'s symmetric send buffer into the local output
    nvshmemx_getmem_block(
        (char*)recv_data + write_offset,
        (char*)send_data + source_offset,
        peer_size,
        peer);
  }
  // Write out the output offsets (to the scratchpad line, which held the
  // source offsets and is no longer needed)
  if (bid == 0 && tid < nsplits) {
    source_offsets[tid] = e_offsets[tid];
  }
}
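
// To make the rank-major -> expert-major transpose concrete: on the sender,
// split index tid = peer * ne + e (rank-major); on the receiver it is stored
// at e * npes + mype (expert-major). A hypothetical helper pair (exposition
// only, not used by the kernels):
__host__ __device__ inline int rankMajorIndexSketch(int dst_pe, int e, int ne) {
  return dst_pe * ne + e; // sender-side layout: grouped by destination PE
}
__host__ __device__ inline int expertMajorIndexSketch(int e, int src_pe, int npes) {
  return e * npes + src_pe; // receiver-side layout: grouped by expert
}
// E.g. with npes = 2, ne = 2: PE 0's chunk for (dst_pe = 1, e = 1) sits at
// rankMajorIndexSketch(1, 1, 2) == 3 locally and lands at
// expertMajorIndexSketch(1, 0, 2) == 2 in PE 1's split rows.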

at::Tensor nvshmem_all_to_all_vdev_2d(
    at::Tensor& input,
    at::Tensor& out,
    at::Tensor& in_out_splits,
    std::string group_name) {
  /* Perform a 2D AllToAllv shuffle operation using NVSHMEM, with split information provided on device.
   * Arguments:
   *  - `input` is the input tensor
   *  - `out` is the output tensor
   *  - `in_out_splits` is a 2D tensor of size (3, `world_size` * `ne`). In the
        scenario of Mixture-of-Experts models, `ne` is the number of experts per
        rank. The rows of `in_out_splits` are (in order):
          input splits (IN)
          output splits (OUT) and
          output offsets (OUT).
   *  - `group_name` is the name of the group to use for the collective operation.

   * A 2D AllToAllv shuffle is illustrated below:
        (world_size = 2, ne = 2, total number of experts = 4)
        Source: |      Rank 0       |      Rank 1       |
                | c0 | c1 | c2 | c3 | d0 | d1 | d2 | d3 |

        Dest  : |      Rank 0       |      Rank 1       |
                | c0 | d0 | c1 | d1 | c2 | d2 | c3 | d3 |
     where each `c_i` / `d_i` is a slice of the `input` tensor, targeting
     expert `i`, with length indicated by the input splits (in
     `in_out_splits[0]`). That is, the 2D AllToAllv shuffle achieves a
     transpose from rank-major order at the input to expert-major order at the
     output.
   */
  auto input_hdl = c10d::symmetric_memory::rendezvous(input, group_name);
  auto out_hdl = c10d::symmetric_memory::rendezvous(out, group_name);
  auto splits_hdl = c10d::symmetric_memory::rendezvous(in_out_splits, group_name);
  int rank = input_hdl->get_rank();
  int world_size = input_hdl->get_world_size();

  void* input_ptr = input_hdl->get_buffer_ptrs()[rank];
  void* output_ptr = out_hdl->get_buffer_ptrs()[rank];
  int64_t* splits_ptr = (int64_t*)(splits_hdl->get_buffer_ptrs()[rank]);

  auto split_shape = in_out_splits.sizes();
  TORCH_CHECK(split_shape.size() == 2 && split_shape[0] == 3, "in_out_splits must be 2D with 3 rows");
  TORCH_CHECK(split_shape[1] % world_size == 0, "The length of each row of in_out_splits must be a multiple of world_size");
  // Number of experts per rank
  int ne = split_shape[1] / world_size;

  // Set device context for getting the stream and launching kernels below
  c10::cuda::CUDAGuard guard(input.device());
  auto stream = at::cuda::getCurrentCUDAStream();

  // Exchange output splits and source offsets
  // Use collective launch because the kernel involves an NVSHMEM barrier
  void* args0[] = {
      &splits_ptr,
      &rank,
      &world_size,
      &ne};
  nvshmemx_collective_launch(
      (const void*)exchangeSplitAndOffset_2d,
      dim3(1),
      dim3(THREADS_PER_BLOCK),
      args0,
      0,
      stream);

  // CTA tuning
  // Naive for now: use 1 block per expert.
  // The total number of blocks is capped at 64 (intra-node) or 8 (inter-node,
  // approximated here by world_size > 8).
  int num_blocks = std::min(world_size * ne, world_size > 8 ? 8 : 64);

  // Stride at dim 0, in bytes (assuming input is contiguous, TODO)
  size_t stride_bytes = input.stride(0) * input.element_size();

  // All-to-all data exchange
  void* args1[] = {
      &input_ptr,
      &output_ptr,
      &splits_ptr,
      &stride_bytes,
      &rank,
      &world_size,
      &ne};
  nvshmemx_collective_launch(
      (const void*)allToAllV_2d,
      dim3(num_blocks),
      dim3(THREADS_PER_BLOCK),
      args1,
      0,
      stream);
  return out;
}
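
// A CPU reference of the shuffle semantics, for exposition only (assumptions:
// one int64 element per token, `std::vector` available via existing includes;
// the name is hypothetical and this function is not used by the op). Given
// per-rank inputs and rank-major input splits, it produces each rank's
// expected output in expert-major order, mirroring the two kernels above:
static std::vector<std::vector<int64_t>> referenceAllToAllV2dSketch(
    const std::vector<std::vector<int64_t>>& inputs,       // [npes][elements]
    const std::vector<std::vector<int64_t>>& input_splits, // [npes][npes * ne]
    int npes,
    int ne) {
  // Exclusive prefix sums give each source PE's read offsets (rank-major),
  // as computed by `prefixSum` in `exchangeSplitAndOffset_2d`
  std::vector<std::vector<int64_t>> offsets(
      npes, std::vector<int64_t>(npes * ne, 0));
  for (int pe = 0; pe < npes; ++pe) {
    for (int i = 1; i < npes * ne; ++i) {
      offsets[pe][i] = offsets[pe][i - 1] + input_splits[pe][i - 1];
    }
  }
  // Each destination gathers expert-major: for each expert, then each source,
  // mirroring the `eid = e * npes + peer` ordering in `allToAllV_2d`
  std::vector<std::vector<int64_t>> outputs(npes);
  for (int dst = 0; dst < npes; ++dst) {
    for (int e = 0; e < ne; ++e) {
      for (int src = 0; src < npes; ++src) {
        int src_idx = dst * ne + e; // rank-major index on the source side
        for (int64_t k = 0; k < input_splits[src][src_idx]; ++k) {
          outputs[dst].push_back(inputs[src][offsets[src][src_idx] + k]);
        }
      }
    }
  }
  return outputs;
}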

} // namespace c10d::nvshmem_extension

TORCH_LIBRARY_IMPL(symm_mem, CUDA, m) {
  m.impl("nvshmem_broadcast", c10d::nvshmem_extension::nvshmem_broadcast);
  m.impl("nvshmem_all_to_all", c10d::nvshmem_extension::nvshmem_all_to_all);
  m.impl("nvshmem_all_to_all_vdev", c10d::nvshmem_extension::nvshmem_all_to_all_vdev);
  m.impl("nvshmem_all_to_all_vdev_2d", c10d::nvshmem_extension::nvshmem_all_to_all_vdev_2d);
}
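
// Note: with the CUDA impl registered above, and assuming the op schema for
// `nvshmem_all_to_all_vdev_2d` is declared in the matching
// `TORCH_LIBRARY(symm_mem, ...)` / `TORCH_LIBRARY_FRAGMENT` block elsewhere,
// the op should be reachable through the dispatcher (e.g. as
// `torch.ops.symm_mem.nvshmem_all_to_all_vdev_2d` from Python).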