-
Notifications
You must be signed in to change notification settings - Fork 26.7k
[a2av] Align length of major dimension in output of 2D a2av #155172
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 2 commits
79094e2
ecfeaf4
754d6bd
93041d6
b8db459
8a889c0
ba6b78d
75cc801
741b5a8
fd596f4
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -11,12 +11,18 @@ | |
| #include <ATen/cuda/cub.cuh> | ||
| #include <nvshmem.h> | ||
|
|
||
| #include <cooperative_groups.h> | ||
|
|
||
| namespace c10d::nvshmem_extension { | ||
|
|
||
| using namespace cooperative_groups; | ||
| namespace cg = cooperative_groups; | ||
|
|
||
| using c10d::symmetric_memory::StoreExchange; | ||
| static StoreExchange storeExchange = StoreExchange("nvshmem_ext"); | ||
|
|
||
| #define THREADS_PER_BLOCK 512 | ||
| #define WARP_SIZE 32 | ||
|
|
||
| // Bootstrap based on user's setting for NCCL | ||
| // Long term, this may be a bit unclean; short term, it improves UX | ||
|
|
@@ -344,20 +350,82 @@ __global__ void exchangeSplitAndOffset_2d(int64_t* in_out_splits, int mype, int | |
| nvshmemx_barrier_all_block(); | ||
| } | ||
|
|
||
| // This is a warp-scope, exclusive prefix sum. | ||
| __device__ void prefixSum_warp(int64_t *odata, int64_t *idata, int n) { | ||
| CUDA_KERNEL_ASSERT(n <= WARP_SIZE); | ||
| constexpr int NUM_WARPS = THREADS_PER_BLOCK / WARP_SIZE; | ||
|
|
||
| // Specialize WarpScan for type int64_t | ||
| using WarpScan = at_cuda_detail::cub::WarpScan<int64_t>; | ||
| // Allocate WarpScan shared memory for N warps | ||
| __shared__ typename WarpScan::TempStorage temp_storage[NUM_WARPS]; | ||
|
|
||
| // Obtain input item for each thread | ||
| int tid = threadIdx.x % WARP_SIZE; | ||
| int64_t thread_data = (tid < n) ? idata[tid] : 0; | ||
|
|
||
| // Compute the warp-wide exclusive prefix sum | ||
| int warp_id = threadIdx.x / WARP_SIZE; | ||
| WarpScan(temp_storage[warp_id]).ExclusiveSum(thread_data, thread_data); | ||
|
|
||
| // Store the result | ||
| odata[tid] = thread_data; | ||
| } | ||
|
|
||
| // This is for abstracting a thread-group-scope, exclusive prefix sum. | ||
| // Since we use warp-scope prefix sum, the thread group size is limited to warp size. | ||
| #define A2AV_TILE_SIZE WARP_SIZE | ||
|
|
||
| // This kernel is used to do the actual data exchange. | ||
| // `in_out_splits` has the same definition as in `exchangeSplitAndOffset`. | ||
| // `stride` is the stride at dim 0, unit in byte. | ||
| // For meaning of `mype` and `npes`, see the docstring of `nvshmem_all_to_all_vdev_2d`. | ||
| __global__ void allToAllV_2d(void *send_data, void *recv_data, int64_t* in_out_splits, size_t stride, int mype, int npes, int ne) { | ||
| __global__ void allToAllV_2d(void *send_data, void *recv_data, int64_t* in_out_splits, size_t stride, int mype, int npes, int ne, int64_t major_align) { | ||
| int nsplits = npes * ne; | ||
| auto output_splits = in_out_splits + nsplits; | ||
| auto source_offsets = in_out_splits + nsplits * 2; | ||
| int bid = blockIdx.x; | ||
| int tid = threadIdx.x; | ||
|
|
||
| // Calculate the output offsets | ||
| __shared__ int64_t e_offsets[THREADS_PER_BLOCK]; | ||
| prefixSum(e_offsets, output_splits, nsplits); | ||
| // Split the thread block into tiles | ||
| constexpr int NUM_TILES = THREADS_PER_BLOCK / A2AV_TILE_SIZE; | ||
| thread_group tile = cg::tiled_partition(this_thread_block(), A2AV_TILE_SIZE); | ||
|
||
| int tileId = tid / A2AV_TILE_SIZE; | ||
| // Each tile calculates its own prefix sum | ||
| __shared__ int64_t tile_prefix_sums[NUM_TILES][A2AV_TILE_SIZE]; | ||
| // A tile takes care of npes worth of splits | ||
| int nsplits_per_tile = min(npes, nsplits - tileId * npes); | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You also need ne < NUM_TILES?
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. correct. Let me add it. |
||
| // TODO: currently it is assumed that the number of PE's is smaller than `A2AV_TILE_SIZE` | ||
| CUDA_KERNEL_ASSERT(nsplits_per_tile <= A2AV_TILE_SIZE); | ||
|
|
||
| // Total length of each tile | ||
| __shared__ int64_t len_per_tile[NUM_TILES]; | ||
| // Starting offset of each tile | ||
| __shared__ int64_t start_offset_per_tile[NUM_TILES]; | ||
| // This tile does not need to do tile-wise prefix sum | ||
| if (nsplits_per_tile < 0) goto end_of_tile_prefix_sum; | ||
|
||
|
|
||
| // Each tile calculates its own prefix sum | ||
| prefixSum_warp(tile_prefix_sums[tileId], output_splits + tileId * npes, nsplits_per_tile); | ||
|
|
||
| // Last thread in each tile does the up aligning. | ||
| // Note: using the last thread to read the last sum from `tile_prefix_sums` so | ||
| // that we can save a __syncthreads(). This is safe because the last thread is | ||
| // the one that writes the last sum in the prefixSum function. | ||
| if (tile.thread_rank() == A2AV_TILE_SIZE - 1) { | ||
| auto my_tile_len = tile_prefix_sums[tileId][A2AV_TILE_SIZE - 1] + output_splits[tileId * npes + nsplits_per_tile - 1]; | ||
|
||
| // Up align | ||
| len_per_tile[tileId] = (my_tile_len + major_align) / major_align * major_align; | ||
|
||
| } | ||
| end_of_tile_prefix_sum: | ||
| __syncthreads(); | ||
|
|
||
| // Prefix sum again to get the tiles' start offsets. This is a block-wide prefix sum. | ||
| prefixSum(start_offset_per_tile, len_per_tile, NUM_TILES); | ||
|
||
| __syncthreads(); | ||
|
|
||
| // Add tile offset to every element in the tile | ||
| tile_prefix_sums[tileId][tile.thread_rank()] += start_offset_per_tile[tileId]; | ||
| __syncthreads(); | ||
|
|
||
| // Target a different e based on bid | ||
|
|
@@ -366,7 +434,8 @@ __global__ void allToAllV_2d(void *send_data, void *recv_data, int64_t* in_out_s | |
| // Amount from `peer` for `e` | ||
| auto peer_size = output_splits[eid] * stride; | ||
| auto source_offset = source_offsets[eid] * stride; | ||
| auto write_offset = e_offsets[eid] * stride; | ||
| auto e_offset = tile_prefix_sums[eid / npes][peer]; | ||
| auto write_offset = e_offset * stride; | ||
| nvshmemx_getmem_block( | ||
| (char*)recv_data + write_offset, | ||
| (char*)send_data + source_offset, | ||
|
|
@@ -375,15 +444,16 @@ __global__ void allToAllV_2d(void *send_data, void *recv_data, int64_t* in_out_s | |
| } | ||
| // Write out the output offsets (to the scratchpad line) | ||
| if (bid == 0 && tid < nsplits) { | ||
| source_offsets[tid] = e_offsets[tid]; | ||
| source_offsets[tid] = tile_prefix_sums[tid / npes][tid % npes]; | ||
| } | ||
| } | ||
|
|
||
| at::Tensor nvshmem_all_to_all_vdev_2d( | ||
| at::Tensor& input, | ||
| at::Tensor& out, | ||
| at::Tensor& in_out_splits, | ||
| std::string group_name) { | ||
| std::string group_name, | ||
| int64_t major_align) { | ||
| /* Perform a 2D AllToAllv shuffle operation using NVSHMEM, with split information provided on device. | ||
| * Arguments: | ||
| * - `input` is the input tensor | ||
|
|
@@ -395,6 +465,8 @@ at::Tensor nvshmem_all_to_all_vdev_2d( | |
| output splits (OUT) and | ||
| output offsets (OUT). | ||
| * - `group_name` is the name of the group to use for the collective operation. | ||
| * - `major_align` is the alignment of the "major dimension" of the output | ||
| sequence. See below for details. | ||
|
|
||
| * A 2D AllToAllv shuffle is illustrated below: | ||
| (world_size = 2, ne = 2, total number of experts = 4) | ||
|
|
@@ -408,12 +480,20 @@ at::Tensor nvshmem_all_to_all_vdev_2d( | |
| `in_out_splits[0]`). That is, the 2D AllToAllv shuffle achieves a | ||
| transpose from rank-major order at input to expert-major order at | ||
| output. | ||
|
|
||
| * If `major_align` is not 1, the output offsets of c1, c2, c3 will be | ||
| up-aligned to this value. For example, if c0 has length 5 and d0 has | ||
| length 7 (making a total of 12), and if the `major_align` is set to 16, | ||
| the output offset of c1 will be 16. Similarly for c2 and c3. This value has | ||
| no effect on the offset of the minor dimension, i.e. d0, d1, d2 and d3. | ||
| */ | ||
| auto input_hdl = c10d::symmetric_memory::rendezvous(input, group_name); | ||
| auto out_hdl = c10d::symmetric_memory::rendezvous(out, group_name); | ||
| auto splits_hdl = c10d::symmetric_memory::rendezvous(in_out_splits, group_name); | ||
| int rank = input_hdl->get_rank(); | ||
| int world_size = input_hdl->get_world_size(); | ||
| // TODO: world_size is currently limited by the number of elements in a WarpScan. | ||
| TORCH_CHECK(world_size <= A2AV_TILE_SIZE, "world_size must be smaller than A2AV_TILE_SIZE", A2AV_TILE_SIZE); | ||
|
|
||
| void* input_ptr = input_hdl->get_buffer_ptrs()[rank]; | ||
| void* output_ptr = out_hdl->get_buffer_ptrs()[rank]; | ||
|
|
@@ -460,7 +540,8 @@ at::Tensor nvshmem_all_to_all_vdev_2d( | |
| &stride_bytes, | ||
| &rank, | ||
| &world_size, | ||
| &ne}; | ||
| &ne, | ||
| &major_align}; | ||
| nvshmemx_collective_launch( | ||
| (const void*)allToAllV_2d, | ||
| dim3(num_blocks), | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.