[a2av] 2D all-to-all-vdev #155058

Changes from all commits
```diff
@@ -141,7 +141,7 @@ at::Tensor nvshmem_all_to_all(
 }
 
 // This is an exclusive prefix sum function that calculates read (or write) offsets for each peer.
-__device__ void prefixSum(int64_t *odata, int64_t *idata, int n) {
+__device__ int64_t prefixSum(int64_t *odata, int64_t *idata, int n) {
   // Specialize BlockScan for a 1D block of threads, of type int64_t.
   // - `BLOCK_SCAN_WARP_SCANS` is a low-latency scan algorithm (instead of high
   //   throughput which we don't need here).
```
```diff
@@ -159,12 +159,12 @@ __device__ void prefixSum(int64_t *odata, int64_t *idata, int n) {
   int64_t thread_data = (tid < n) ? idata[tid] : 0;
 
   // Collectively compute the block-wide exclusive prefix sum
-  BlockScanT(temp_storage).ExclusiveSum(thread_data, thread_data);
+  int64_t block_aggregate;
+  BlockScanT(temp_storage).ExclusiveSum(thread_data, thread_data, block_aggregate);
 
   // Store the result
-  if (tid < n) {
-    odata[tid] = thread_data;
-  }
+  odata[tid] = thread_data;
+  return block_aggregate;
 }
 
 // This kernel is used to exchange output splits and source offsets between peers.
```
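For reference, the overload used above is CUB's `BlockScan::ExclusiveSum(input, output, block_aggregate)`, whose third argument returns the block-wide total; since inactive threads contribute 0, that total is exactly the sum of the `n` inputs. A minimal standalone sketch of the pattern (hypothetical demo kernel, assuming 256 threads per block):

```cpp
#include <cub/block/block_scan.cuh>

constexpr int kThreads = 256;  // stand-in for THREADS_PER_BLOCK

__global__ void exclusiveSumDemo(const int64_t* in, int64_t* out,
                                 int64_t* total, int n) {
  using BlockScanT =
      cub::BlockScan<int64_t, kThreads, cub::BLOCK_SCAN_WARP_SCANS>;
  __shared__ typename BlockScanT::TempStorage temp_storage;

  int tid = threadIdx.x;
  // Threads beyond `n` contribute 0, so the aggregate equals the input sum.
  int64_t v = (tid < n) ? in[tid] : 0;

  int64_t block_aggregate;
  BlockScanT(temp_storage).ExclusiveSum(v, v, block_aggregate);

  if (tid < n) out[tid] = v;               // exclusive prefix sums
  if (tid == 0) *total = block_aggregate;  // sum of all inputs
}
```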
```diff
@@ -311,11 +311,192 @@ at::Tensor nvshmem_all_to_all_vdev(
   return out;
 }
 
+// Start of `nvshmem_all_to_all_vdev_2d`
+// This kernel is used to exchange output splits and source offsets between peers.
+// For meaning of `mype` and `npes`, see the docstring of `nvshmem_all_to_all_vdev_2d`.
+// `in_out_splits` is of size (3, npes * ne) and contains:
+// - input splits (IN)
+// - output splits (OUT) and
+// - source offsets (OUT).
+__global__ void exchangeSplitAndOffset_2d(int64_t* in_out_splits, int mype, int npes, int ne, size_t input_dim0) {
+  int nsplits = npes * ne;
+  auto input_splits = in_out_splits;
+  auto output_splits = in_out_splits + nsplits;
+  auto source_offsets = in_out_splits + nsplits * 2;
+  int tid = threadIdx.x;
+
+  __shared__ int64_t peer_offsets[THREADS_PER_BLOCK];
+
+  // Scan input splits to get the source offsets
+  auto sum_of_splits = prefixSum(peer_offsets, input_splits, nsplits);
+  __syncthreads();
+  CUDA_KERNEL_ASSERT(sum_of_splits <= input_dim0);
+
+  // Use 1 block to do the exchange
+  if (tid < nsplits) {
+    int peer = tid / ne;
+    int e = tid % ne;
+    // This does a transpose from rank-major order to expert-major order
+    int dst_offset = e * npes + mype;
+    auto split_val = input_splits[tid];
+    CUDA_KERNEL_ASSERT(split_val >= 0);
+    nvshmem_int64_p(source_offsets + dst_offset, peer_offsets[tid], peer);
```
Collaborator: Here it would also make sense to check that there are no negative numbers in splits, and that the sum of splits is less than the input size?

Author: Added a negative check.
```diff
+    nvshmem_int64_p(output_splits + dst_offset, split_val, peer);
+  }
+  // This barrier ensures that all remote PEs see the updated values
+  nvshmemx_barrier_all_block();
+}
```
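The line `int dst_offset = e * npes + mype` is what realizes the rank-major-to-expert-major transpose: on the destination rank, split entries are grouped by local expert first, then by source rank. A hypothetical host-side sketch (plain C++, names invented for illustration) that prints this mapping for the `world_size = 2, ne = 2` case pictured in the docstring further below:

```cpp
#include <cstdio>

int main() {
  int npes = 2, ne = 2;  // 2 ranks, 2 experts per rank
  for (int mype = 0; mype < npes; ++mype) {
    for (int tid = 0; tid < npes * ne; ++tid) {
      int peer = tid / ne;               // destination rank for this split
      int e = tid % ne;                  // local expert index on that rank
      int dst_offset = e * npes + mype;  // expert-major slot on `peer`
      printf("rank %d, input split %d -> rank %d, slot %d\n",
             mype, tid, peer, dst_offset);
    }
  }
  return 0;
}
```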
```diff
+// This kernel is used to do the actual data exchange.
+// `in_out_splits` has the same definition as in `exchangeSplitAndOffset`.
+// `stride` is the stride at dim 0, in bytes.
+// For meaning of `mype` and `npes`, see the docstring of `nvshmem_all_to_all_vdev_2d`.
+__global__ void allToAllV_2d(void *send_data, void *recv_data, int64_t* in_out_splits, size_t stride, int mype, int npes, int ne) {
+  int nsplits = npes * ne;
+  auto output_splits = in_out_splits + nsplits;
+  auto source_offsets = in_out_splits + nsplits * 2;
+  int bid = blockIdx.x;
+  int tid = threadIdx.x;
+
+  // Calculate the output offsets
+  __shared__ int64_t e_offsets[THREADS_PER_BLOCK];
+  prefixSum(e_offsets, output_splits, nsplits);
+  __syncthreads();
+
+  // Target a different `e` based on `bid`
+  for (int eid = bid; eid < nsplits; eid += gridDim.x) {
+    int peer = eid % npes;
+    // Amount of data from `peer` for expert `e`
+    auto peer_size = output_splits[eid] * stride;
+    auto source_offset = source_offsets[eid] * stride;
```
Collaborator: You need to check that these offsets are within the tensor, so there are no OOB reads.

Author: Added.
```diff
+    auto write_offset = e_offsets[eid] * stride;
+    nvshmemx_getmem_block(
+        (char*)recv_data + write_offset,
+        (char*)send_data + source_offset,
+        peer_size,
+        peer);
+  }
+  // Write out the output offsets (to the scratchpad line)
+  if (bid == 0 && tid < nsplits) {
+    source_offsets[tid] = e_offsets[tid];
+  }
+}
```
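The copy loop above uses a grid-stride pattern, so any grid size covers all `nsplits` chunks: block `bid` handles chunks `bid`, `bid + gridDim.x`, and so on, issuing one block-cooperative `nvshmemx_getmem_block` per chunk. A minimal local-memory sketch of just that distribution pattern (hypothetical demo; plain byte copies stand in for the NVSHMEM gets):

```cpp
__global__ void gridStrideCopyDemo(const char* src, char* dst,
                                   const int64_t* chunk_bytes,
                                   const int64_t* src_off,
                                   const int64_t* dst_off,
                                   int nchunks) {
  // Each block owns chunks bid, bid + gridDim.x, bid + 2 * gridDim.x, ...
  for (int c = blockIdx.x; c < nchunks; c += gridDim.x) {
    // All threads of the block cooperate on one chunk (byte-wise here;
    // the real kernel calls nvshmemx_getmem_block instead).
    for (int64_t i = threadIdx.x; i < chunk_bytes[c]; i += blockDim.x) {
      dst[dst_off[c] + i] = src[src_off[c] + i];
    }
  }
}
```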
```diff
+at::Tensor nvshmem_all_to_all_vdev_2d(
+    at::Tensor& input,
+    at::Tensor& out,
+    at::Tensor& in_out_splits,
+    std::string group_name) {
+  /* Perform a 2D AllToAllv shuffle operation using NVSHMEM, with split information provided on device.
+   * Arguments:
+   *  - `input` is the input tensor
+   *  - `out` is the output tensor
+   *  - `in_out_splits` is a 2D tensor of size (3, `world_size` * `ne`). In the
+   *    scenario of Mixture-of-Experts models, `ne` is the number of experts per
+   *    rank. The rows of `in_out_splits` are (in order):
+   *      input splits (IN)
+   *      output splits (OUT) and
+   *      output offsets (OUT).
+   *  - `group_name` is the name of the group to use for the collective operation.
+   *
+   * A 2D AllToAllv shuffle is illustrated below:
+   *   (world_size = 2, ne = 2, total number of experts = 4)
+   *   Source: |       Rank 0        |       Rank 1        |
+   *           | c0 | c1 | c2 | c3 | d0 | d1 | d2 | d3 |
+   *
+   *   Dest  : |       Rank 0        |       Rank 1        |
+   *           | c0 | d0 | c1 | d1 | c2 | d2 | c3 | d3 |
+   *   where each `c_i` / `d_i` is a slice of the `input` tensor, targeting
+   *   expert `i`, with length indicated by the input splits (in
+   *   `in_out_splits[0]`). That is, the 2D AllToAllv shuffle achieves a
+   *   transpose from rank-major order at input to expert-major order at
+   *   output.
+   */
+  auto input_hdl = c10d::symmetric_memory::rendezvous(input, group_name);
```
Collaborator: You also need to check input/output dimensionality and contiguity.

Author: Added a contiguity check. I can't think of a dimensionality requirement here (it can be 1D, 2D or n-D).

Collaborator: Also you need to check dtypes: integer for splits, same dtype for input and output.
```diff
+  auto out_hdl = c10d::symmetric_memory::rendezvous(out, group_name);
+  auto splits_hdl = c10d::symmetric_memory::rendezvous(in_out_splits, group_name);
+  int rank = input_hdl->get_rank();
+  int world_size = input_hdl->get_world_size();
+
+  void* input_ptr = input_hdl->get_buffer_ptrs()[rank];
+  void* output_ptr = out_hdl->get_buffer_ptrs()[rank];
+  int64_t* splits_ptr = (int64_t*)(splits_hdl->get_buffer_ptrs()[rank]);
+
+  // Shape checks
+  auto split_shape = in_out_splits.sizes();
```
Collaborator: I still don't see a check for

Author: Sorry, added.
```diff
+  TORCH_CHECK(in_out_splits.is_contiguous()
+      && input.is_contiguous()
+      && out.is_contiguous(),
+      "input, out and in_out_splits must be contiguous");
+  TORCH_CHECK(split_shape.size() == 2
+      && split_shape[0] == 3
+      && split_shape[1] % world_size == 0,
+      "in_out_splits must be 2D with 3 rows, "
+      "each row must be a multiple of world_size");
+
+  // Consistency checks
+  TORCH_CHECK(input.dtype() == out.dtype()
+      && input.stride(0) == out.stride(0),
+      "input and out must have the same dtype and same stride at dim 0");
+  TORCH_CHECK(in_out_splits.scalar_type() == at::kLong, "in_out_splits must be int64");
+
+  // Number of experts per rank
+  int ne = split_shape[1] / world_size;
+
+  // Set device context for getting the stream and launching kernels below
+  c10::cuda::CUDAGuard guard(input.device());
+  auto stream = at::cuda::getCurrentCUDAStream();
+
+  // Exchange output splits and source offsets
+  auto input_dim0 = input.size(0);
+  // Use collective launch because kernel involves nvshmem barrier
+  void* args0[] = {
+      &splits_ptr,
+      &rank,
+      &world_size,
+      &ne,
+      &input_dim0};
+  nvshmemx_collective_launch(
+      (const void*)exchangeSplitAndOffset_2d,
+      dim3(1),
+      dim3(THREADS_PER_BLOCK),
+      args0,
+      0,
+      stream);
+
+  // CTA Tuning
+  // Naive for now, use 1 block per expert.
+  // Total number of blocks is limited to 64 (intra-node) or 8 (inter-node).
+  int num_blocks = std::min(world_size * ne, world_size > 8 ? 8 : 64);
+
+  // Stride at dim 0, in bytes
+  size_t stride_bytes = input.stride(0) * input.element_size();
```
Collaborator: Here you are assuming that input.stride(0) == output.stride(0); you should check it.

Author: Added.
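As a concrete reading of the CTA-tuning lines in the block above: for a single node with `world_size = 4` and `ne = 8`, `num_blocks = min(4 * 8, 64) = 32`, i.e. one block per expert; once `world_size` exceeds 8 (taken here as the inter-node case), the grid is capped at `min(world_size * ne, 8) = 8` blocks and the grid-stride loop in `allToAllV_2d` covers the remaining splits.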
```diff
+
+  // All to all data exchange
+  void* args1[] = {
+      &input_ptr,
+      &output_ptr,
+      &splits_ptr,
+      &stride_bytes,
+      &rank,
+      &world_size,
+      &ne};
+  nvshmemx_collective_launch(
+      (const void*)allToAllV_2d,
+      dim3(num_blocks),
+      dim3(THREADS_PER_BLOCK),
+      args1,
+      0,
+      stream);
+  return out;
+}
+
 } // namespace c10d::nvshmem_extension
 
 
 TORCH_LIBRARY_IMPL(symm_mem, CUDA, m) {
   m.impl("nvshmem_broadcast", c10d::nvshmem_extension::nvshmem_broadcast);
   m.impl("nvshmem_all_to_all", c10d::nvshmem_extension::nvshmem_all_to_all);
   m.impl("nvshmem_all_to_all_vdev", c10d::nvshmem_extension::nvshmem_all_to_all_vdev);
+  m.impl("nvshmem_all_to_all_vdev_2d", c10d::nvshmem_extension::nvshmem_all_to_all_vdev_2d);
 }
```
Collaborator: Btw, previously I've seen an int64 scan slow down a kernel big time; you might want to check the performance.
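If the scan does show up in profiles, one hedged option (an assumption, not part of this PR: it only holds if every split and the block-wide sum fit in 32 bits) is to run the scan on `int` and widen only on the store. A standalone sketch:

```cpp
#include <cub/block/block_scan.cuh>

constexpr int kThreads = 256;  // stand-in for THREADS_PER_BLOCK

// Hypothetical 32-bit variant of prefixSum: the scan runs on int and the
// results widen to int64_t on store. Valid only if all splits and their
// sum fit in 32 bits (an assumption this PR does not check).
__global__ void prefixSum32(int64_t* odata, const int64_t* idata, int n,
                            int64_t* total) {
  using BlockScanT = cub::BlockScan<int, kThreads, cub::BLOCK_SCAN_WARP_SCANS>;
  __shared__ typename BlockScanT::TempStorage temp_storage;

  int tid = threadIdx.x;
  int v = (tid < n) ? static_cast<int>(idata[tid]) : 0;

  int block_aggregate;
  BlockScanT(temp_storage).ExclusiveSum(v, v, block_aggregate);

  if (tid < n) odata[tid] = v;             // widen back to int64_t
  if (tid == 0) *total = block_aggregate;  // sum of all inputs
}
```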