[a2av] Align length of major dimension in output of 2D a2av #155172
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Closed
Closed
```diff
@@ -17,6 +17,7 @@ using c10d::symmetric_memory::StoreExchange;
 static StoreExchange storeExchange = StoreExchange("nvshmem_ext");
 
 #define THREADS_PER_BLOCK 512
+#define WARP_SIZE 32
 
 // Bootstrap based on user's setting for NCCL
 // Long term, this may be a bit unclean; short term, it improves UX
```
```diff
@@ -346,20 +347,100 @@ __global__ void exchangeSplitAndOffset_2d(int64_t* in_out_splits, int mype, int
   nvshmemx_barrier_all_block();
 }
 
+// This is a warp-scope, exclusive prefix sum. When called by a block of
+// threads, each warp will perform an independent prefix sum, concurrently.
+// Returns the sum of all elements in the warp.
+// `NUM_WARPS` is the number of warps participating in the concurrent prefix sum.
+template <int NUM_WARPS>
+__device__ int64_t prefixSum_warp(int64_t *odata, int64_t *idata, int n) {
+  CUDA_KERNEL_ASSERT(n <= WARP_SIZE);
+
+  // Specialize WarpScan for type int64_t
+  using WarpScan = at_cuda_detail::cub::WarpScan<int64_t>;
+  // Allocate WarpScan shared memory for NUM_WARPS warps
+  __shared__ typename WarpScan::TempStorage temp_storage[NUM_WARPS];
+
+  int warp_id = threadIdx.x / WARP_SIZE;
+  if (warp_id >= NUM_WARPS) {
+    return 0;
+  }
+
+  // Obtain the input item for each thread
+  int tid = threadIdx.x % WARP_SIZE;
+  int64_t thread_data = (tid < n) ? idata[tid] : 0;
+
+  // Total sum of all elements in the warp
+  int64_t warp_aggregate;
+  // Compute the warp-wide exclusive prefix sum
+  WarpScan(temp_storage[warp_id]).ExclusiveSum(thread_data, thread_data, warp_aggregate);
+
+  // Store the result
+  odata[tid] = thread_data;
+  return warp_aggregate;
+}
+
+// This is for abstracting a thread-group-scope, exclusive prefix sum.
+// Since we use a warp-scope prefix sum, the thread group size is limited to the warp size.
+#define A2AV_TILE_SIZE WARP_SIZE
```
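For intuition, `prefixSum_warp` turns a warp's split sizes into exclusive offsets and reports the warp total (the offset of element i is the sum of elements 0..i-1). Below is a minimal host-side C++ sketch of the same semantics, purely illustrative and not part of this PR; the helper name is made up.

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

// Host-side reference of what prefixSum_warp computes on the GPU:
// odata[i] = sum of idata[0..i-1], and the return value is the total.
int64_t exclusive_prefix_sum(std::vector<int64_t>& odata, const std::vector<int64_t>& idata) {
  int64_t running = 0;
  odata.resize(idata.size());
  for (size_t i = 0; i < idata.size(); ++i) {
    odata[i] = running;   // offset of element i
    running += idata[i];  // accumulate the total
  }
  return running;         // aggregate, i.e. sum of all elements
}

int main() {
  std::vector<int64_t> splits = {3, 1, 4, 2};
  std::vector<int64_t> offsets;
  int64_t total = exclusive_prefix_sum(offsets, splits);
  // offsets == {0, 3, 4, 8}, total == 10
  std::cout << "total = " << total << "\n";
}
```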
```diff
 // This kernel is used to do the actual data exchange.
 // `in_out_splits` has the same definition as in `exchangeSplitAndOffset`.
 // `stride` is the stride at dim 0, in bytes.
 // For meaning of `mype` and `npes`, see the docstring of `nvshmem_all_to_all_vdev_2d`.
-__global__ void allToAllV_2d(void *send_data, void *recv_data, int64_t* in_out_splits, size_t stride, int mype, int npes, int ne) {
+__global__ void allToAllV_2d(void *send_data, void *recv_data, int64_t* in_out_splits, size_t stride, int mype, int npes, int ne, int64_t major_align) {
   int nsplits = npes * ne;
   auto output_splits = in_out_splits + nsplits;
   auto source_offsets = in_out_splits + nsplits * 2;
   int bid = blockIdx.x;
   int tid = threadIdx.x;
 
-  // Calculate the output offsets
-  __shared__ int64_t e_offsets[THREADS_PER_BLOCK];
-  prefixSum(e_offsets, output_splits, nsplits);
+  // Split the thread block into tiles
+  constexpr int NUM_TILES = THREADS_PER_BLOCK / A2AV_TILE_SIZE;
+  int tileId = tid / A2AV_TILE_SIZE;
+  int laneId = tid % A2AV_TILE_SIZE;
+  // Each tile calculates its own prefix sum
+  __shared__ int64_t tile_prefix_sums[NUM_TILES][A2AV_TILE_SIZE];
+  // A tile takes care of npes worth of splits
+  int nsplits_per_tile = min(npes, nsplits - tileId * npes);
+  // TODO: currently it is assumed that the number of PEs is smaller than
+  // `A2AV_TILE_SIZE` because the warp-scope prefix sum can only handle up to
+  // WARP_SIZE elements
+  CUDA_KERNEL_ASSERT(npes <= A2AV_TILE_SIZE);
```
**Collaborator** commented on `CUDA_KERNEL_ASSERT(npes <= A2AV_TILE_SIZE);`:

> scan can easily do multiple elements per thread, so we can relax this assert, but it's fine for now
```diff
+  // Similarly, the number of experts per rank is also assumed to be smaller
+  // than `NUM_TILES`
+  CUDA_KERNEL_ASSERT(ne <= NUM_TILES);
+
+  // Total length of each tile
+  __shared__ int64_t len_per_tile[NUM_TILES];
+  // When `nsplits` is small, not every tile gets data to sum. They can skip
+  // this local prefix sum.
+  if (nsplits_per_tile > 0) {
+    // Each tile calculates its own prefix sum; the return value is the sum of all elements in the tile.
+    int64_t my_tile_len = prefixSum_warp<NUM_TILES>(tile_prefix_sums[tileId], output_splits + tileId * npes, nsplits_per_tile);
+    // Last thread in each tile does the up-aligning.
+    if (laneId == A2AV_TILE_SIZE - 1) {
+      auto aligned_len = (my_tile_len + major_align - 1) / major_align * major_align;
+      // In case `aligned_len` is 0, we set it to `major_align` to avoid an
+      // empty bin, because cutlass currently does not support it. See
+      // https://github.com/pytorch/pytorch/issues/152668.
+      len_per_tile[tileId] = max(aligned_len, major_align);
+    }
+  }
+  __syncthreads();
+
+  // Starting offset of each tile
+  __shared__ int64_t start_offset_per_tile[NUM_TILES];
+  // Prefix sum again to get the tiles' start offsets.
+  // `NUM_TILES` is typically not greater than 32, because 32 tiles * 32 threads
+  // = 1024 threads, and this kernel is launched with at most 1024 threads. Thus,
+  // we can use the warp-scope prefix sum.
+  static_assert(NUM_TILES <= WARP_SIZE);
+  // Only 1 warp is needed
+  prefixSum_warp<1>(start_offset_per_tile, len_per_tile, NUM_TILES);
+  __syncthreads();
+
+  // Add the tile offset to every element in the tile
+  tile_prefix_sums[tileId][laneId] += start_offset_per_tile[tileId];
+  __syncthreads();
 
   // Target a different e based on bid
```
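The up-aligning arithmetic above rounds each tile's total up to a multiple of `major_align` and never lets a bucket be shorter than `major_align`. A quick host-side sketch of that rounding logic (illustrative only; `align_up_nonempty` is a made-up helper, not code from this PR):

```cpp
#include <algorithm>
#include <cstdint>
#include <cstdio>

// Round `len` up to a multiple of `align`, but never below `align` itself,
// mirroring the empty-bin workaround in the kernel above.
int64_t align_up_nonempty(int64_t len, int64_t align) {
  int64_t aligned = (len + align - 1) / align * align;
  return std::max(aligned, align);
}

int main() {
  printf("%lld\n", (long long)align_up_nonempty(12, 16)); // 16
  printf("%lld\n", (long long)align_up_nonempty(17, 16)); // 32
  printf("%lld\n", (long long)align_up_nonempty(0, 16));  // 16, avoids an empty bin
}
```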
```diff
@@ -368,7 +449,8 @@ __global__ void allToAllV_2d(void *send_data, void *recv_data, int64_t* in_out_s
   // Amount from `peer` for `e`
   auto peer_size = output_splits[eid] * stride;
   auto source_offset = source_offsets[eid] * stride;
-  auto write_offset = e_offsets[eid] * stride;
+  auto e_offset = tile_prefix_sums[eid / npes][peer];
+  auto write_offset = e_offset * stride;
   nvshmemx_getmem_block(
       (char*)recv_data + write_offset,
       (char*)send_data + source_offset,
```
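The indexing above relies on the output splits being expert-major: the flat index `eid` decomposes as `expert = eid / npes` (which tile) and `peer = eid % npes` (which lane within the tile). A tiny illustrative sketch of that mapping (not part of the PR):

```cpp
#include <cstdio>

int main() {
  const int npes = 2, ne = 2;  // 2 ranks, 2 experts per rank
  for (int eid = 0; eid < npes * ne; ++eid) {
    int expert = eid / npes;   // selects the tile (expert bucket)
    int peer = eid % npes;     // selects the lane (source rank) within the tile
    printf("eid=%d -> expert=%d, peer=%d\n", eid, expert, peer);
  }
}
```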
```diff
@@ -377,15 +459,16 @@ __global__ void allToAllV_2d(void *send_data, void *recv_data, int64_t* in_out_s
   }
   // Write out the output offsets (to the scratchpad line)
   if (bid == 0 && tid < nsplits) {
-    source_offsets[tid] = e_offsets[tid];
+    source_offsets[tid] = tile_prefix_sums[tid / npes][tid % npes];
   }
 }
 
 at::Tensor nvshmem_all_to_all_vdev_2d(
     at::Tensor& input,
     at::Tensor& out,
     at::Tensor& in_out_splits,
-    std::string group_name) {
+    std::string group_name,
+    int64_t major_align) {
   /* Perform a 2D AllToAllv shuffle operation using NVSHMEM, with split information provided on device.
    * Arguments:
    * - `input` is the input tensor
```
```diff
@@ -397,6 +480,8 @@ at::Tensor nvshmem_all_to_all_vdev_2d(
            output splits (OUT) and
            output offsets (OUT).
    * - `group_name` is the name of the group to use for the collective operation.
+   * - `major_align` is the alignment of the "major dimension" of the output
+       sequence. See below for details.
 
    * A 2D AllToAllv shuffle is illustrated below:
      (world_size = 2, ne = 2, total number of experts = 4)
```
```diff
@@ -410,12 +495,23 @@ at::Tensor nvshmem_all_to_all_vdev_2d(
      `in_out_splits[0]`). That is, the 2D AllToAllv shuffle achieves a
      transpose from rank-major order at input to expert-major order at
      output.
 
+   * If `major_align` is not 1, the output offsets of c1, c2, c3 will be
+     up-aligned to this value. For example, if c0 has length 5 and d0 has
+     length 7 (making a total of 12), and if `major_align` is set to 16,
+     the output offset of c1 will be 16. Similarly for c2 and c3. This value
+     has no effect on the offsets of the minor dimension, i.e. d0, d1, d2
+     and d3. Note: since cutlass does not support empty bins, we set the
+     aligned length to `major_align` if it is 0. See
+     https://github.com/pytorch/pytorch/issues/152668.
    */
   auto input_hdl = c10d::symmetric_memory::rendezvous(input, group_name);
   auto out_hdl = c10d::symmetric_memory::rendezvous(out, group_name);
   auto splits_hdl = c10d::symmetric_memory::rendezvous(in_out_splits, group_name);
   int rank = input_hdl->get_rank();
   int world_size = input_hdl->get_world_size();
+  // TODO: world_size is currently limited by the number of elements in a WarpScan.
+  TORCH_CHECK(world_size <= A2AV_TILE_SIZE, "world_size must not exceed A2AV_TILE_SIZE: ", A2AV_TILE_SIZE);
 
   void* input_ptr = input_hdl->get_buffer_ptrs()[rank];
   void* output_ptr = out_hdl->get_buffer_ptrs()[rank];
```
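To make the documented example concrete, here is a small host-side reference (illustrative only; not part of the PR, variable names are my own) of the offset computation the kernel performs: each expert bucket's total is rounded up to `major_align` (and clamped to at least `major_align`), while offsets inside a bucket remain plain running sums.

```cpp
#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  const int npes = 2;               // world size
  const int ne = 2;                 // local experts per rank
  const int64_t major_align = 16;
  // Output splits in expert-major order: {c0, d0, c1, d1}.
  // Using the docstring's example, c0 = 5 and d0 = 7 for expert 0.
  std::vector<int64_t> output_splits = {5, 7, 9, 3};

  std::vector<int64_t> output_offsets(npes * ne);
  int64_t bucket_start = 0;
  for (int e = 0; e < ne; ++e) {
    int64_t running = 0;
    for (int p = 0; p < npes; ++p) {
      // Minor-dimension offsets: running sum inside the expert bucket.
      output_offsets[e * npes + p] = bucket_start + running;
      running += output_splits[e * npes + p];
    }
    // Major-dimension alignment: round the bucket up, never below major_align.
    int64_t aligned = (running + major_align - 1) / major_align * major_align;
    bucket_start += std::max(aligned, major_align);
  }
  // Prints: 0 5 16 25 -- c1 starts at offset 16 even though c0 + d0 = 12.
  for (int64_t off : output_offsets) printf("%lld ", (long long)off);
  printf("\n");
}
```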
```diff
@@ -428,8 +524,11 @@ at::Tensor nvshmem_all_to_all_vdev_2d(
       "input, out and in_out_splits must be contiguous");
   TORCH_CHECK(split_shape.size() == 2 && split_shape[0] == 3, "in_out_splits must be 2D with 3 rows");
   TORCH_CHECK(split_shape[1] % world_size == 0, "Each row of in_out_splits must be a multiple of world_size");
 
   // Number of experts per rank
   int ne = split_shape[1] / world_size;
+  constexpr int NUM_TILES = THREADS_PER_BLOCK / A2AV_TILE_SIZE;
+  TORCH_CHECK(ne <= NUM_TILES, "Number of experts per rank must not exceed NUM_TILES: ", NUM_TILES);
 
   // Set device context for getting the stream and launching kernels below
   c10::cuda::CUDAGuard guard(input.device());
```
```diff
@@ -466,7 +565,8 @@ at::Tensor nvshmem_all_to_all_vdev_2d(
       &stride_bytes,
       &rank,
       &world_size,
-      &ne};
+      &ne,
+      &major_align};
   nvshmemx_collective_launch(
       (const void*)allToAllV_2d,
       dim3(num_blocks),
```
**Reviewer:** You also need `ne < NUM_TILES`?

**Author:** Correct. Let me add it.