[a2av] Align length of major dimension in output of 2D a2av #155172
Changes from 1 commit
Diff (unified view; `-` marks removed lines, `+` marks added lines):

@@ -15,9 +15,6 @@
 namespace c10d::nvshmem_extension {
 
-using namespace cooperative_groups;
-namespace cg = cooperative_groups;
-
 using c10d::symmetric_memory::StoreExchange;
 static StoreExchange storeExchange = StoreExchange("nvshmem_ext");
@@ -350,8 +347,10 @@ __global__ void exchangeSplitAndOffset_2d(int64_t* in_out_splits, int mype, int
   nvshmemx_barrier_all_block();
 }
 
-// This is an warp-scope, exclusive prefix sum.
-__device__ void prefixSum_warp(int64_t *odata, int64_t *idata, int n) {
+// This is an warp-scope, exclusive prefix sum. When called by a block of
+// threads, each warp will perform an independent prefix sum, concurrently.
+// Returns the sum of all elements in the warp.
+__device__ int64_t prefixSum_warp(int64_t *odata, int64_t *idata, int n) {
   CUDA_KERNEL_ASSERT(n <= WARP_SIZE);
   constexpr int NUM_WARPS = THREADS_PER_BLOCK / WARP_SIZE;
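The comment above describes the new contract of `prefixSum_warp`: an exclusive prefix sum whose return value is the total of the warp's inputs. A minimal host-side sketch of that contract, for illustration only (the function and variable names here are made up, not from the PR):

```cuda
#include <cstdint>
#include <cstdio>
#include <vector>

// Serial host-side model of the contract: write the exclusive prefix sum of
// `idata` into `odata` and return the total of all elements.
int64_t exclusivePrefixSum(std::vector<int64_t>& odata, const std::vector<int64_t>& idata) {
  int64_t running = 0;
  odata.assign(idata.size(), 0);
  for (size_t i = 0; i < idata.size(); ++i) {
    odata[i] = running;   // sum of elements strictly before position i
    running += idata[i];
  }
  return running;         // the aggregate the device function now returns
}

int main() {
  std::vector<int64_t> splits = {3, 1, 4, 1};
  std::vector<int64_t> offsets;
  int64_t total = exclusivePrefixSum(offsets, splits);
  // offsets = {0, 3, 4, 8}, total = 9
  printf("total = %lld\n", (long long)total);
  return 0;
}
```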
@@ -365,11 +364,13 @@ __device__ void prefixSum_warp(int64_t *odata, int64_t *idata, int n) {
   int64_t thread_data = (tid < n) ? idata[tid] : 0;
 
   // Compute the warp-wide exclusive prefix sum
+  int64_t warp_aggregate;
   int warp_id = threadIdx.x / WARP_SIZE;
-  WarpScan(temp_storage[warp_id]).ExclusiveSum(thread_data, thread_data);
+  WarpScan(temp_storage[warp_id]).ExclusiveSum(thread_data, thread_data, warp_aggregate);
 
   // Store the result
   odata[tid] = thread_data;
+  return warp_aggregate;
 }
 
 // This is for abstracting a thread-group-scope, exclusive prefix sum.
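The new third argument to `ExclusiveSum` is CUB's aggregate output, which hands every lane of the warp the warp-wide total. A self-contained sketch of that overload in a standalone demo kernel (the kernel name, constants, and launch shape are illustrative, not taken from this file):

```cuda
#include <cub/cub.cuh>
#include <cstdint>
#include <cstdio>

constexpr int kWarpSize = 32;
constexpr int kThreadsPerBlock = 128;

// Illustrative kernel (not the PR's code): each warp computes an exclusive
// prefix sum of its inputs, and the third argument of ExclusiveSum reports
// the warp-wide total back to every lane.
__global__ void warpScanDemo(const int64_t* in, int64_t* out) {
  using WarpScan = cub::WarpScan<int64_t>;
  __shared__ typename WarpScan::TempStorage temp_storage[kThreadsPerBlock / kWarpSize];

  int tid = threadIdx.x;
  int warp_id = tid / kWarpSize;

  int64_t thread_data = in[blockIdx.x * blockDim.x + tid];
  int64_t warp_aggregate;
  WarpScan(temp_storage[warp_id]).ExclusiveSum(thread_data, thread_data, warp_aggregate);

  out[blockIdx.x * blockDim.x + tid] = thread_data;
  if (tid % kWarpSize == 0) {
    printf("block %d warp %d total = %lld\n", blockIdx.x, warp_id, (long long)warp_aggregate);
  }
}
```

This is the same aggregate-returning overload the diff switches to; the PR's `prefixSum_warp` additionally asserts `n <= WARP_SIZE` and writes its results into the caller's per-tile shared buffer.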
@@ -389,43 +390,46 @@ __global__ void allToAllV_2d(void *send_data, void *recv_data, int64_t* in_out_s
 
   // Split the thread block into tiles
   constexpr int NUM_TILES = THREADS_PER_BLOCK / A2AV_TILE_SIZE;
-  thread_group tile = cg::tiled_partition(this_thread_block(), A2AV_TILE_SIZE);
   int tileId = tid / A2AV_TILE_SIZE;
+  int laneId = tid % A2AV_TILE_SIZE;
   // Each tile calculates its own prefix sum
   __shared__ int64_t tile_prefix_sums[NUM_TILES][A2AV_TILE_SIZE];
   // A tile takes care of npes worth of splits
   int nsplits_per_tile = min(npes, nsplits - tileId * npes);
-  // TODO: currently it is assumed that the number of PE's is smaller than `A2AV_TILE_SIZE`
-  CUDA_KERNEL_ASSERT(nsplits_per_tile <= A2AV_TILE_SIZE);
+  // TODO: currently it is assumed that the number of PE's is smaller than
+  // `A2AV_TILE_SIZE` bc the warp-scope prefix sum can only handle up to
+  // WARP_SIZE elements
+  CUDA_KERNEL_ASSERT(npes <= A2AV_TILE_SIZE);
[Inline review comment] Collaborator: scan can easily do multiple elements per thread, so we can relax this assert, but it's fine for now
+  // Similarly, the number of experts per rank is also assumed to be smaller
+  // than `NUM_TILES`
+  CUDA_KERNEL_ASSERT(ne <= NUM_TILES);
 
   // Total length of each tile
   __shared__ int64_t len_per_tile[NUM_TILES];
-  // Starting offset of each tile
-  __shared__ int64_t start_offset_per_tile[NUM_TILES];
-  // This tile does not need to do tile-wise prefix sum
-  if (nsplits_per_tile < 0) goto end_of_tile_prefix_sum;
-
-  // Each tile calculates its own prefix sum
-  prefixSum_warp(tile_prefix_sums[tileId], output_splits + tileId * npes, nsplits_per_tile);
-
-  // Last thread in each tile does the up aligning.
-  // Note: using the last thread to read the last sum from `tile_prefix_sums` so
-  // that we can save a __syncthreads(). This is safe because the last thread is
-  // the one that writes the last sum in the prefixSum function.
-  if (tile.thread_rank() == A2AV_TILE_SIZE - 1) {
-    auto my_tile_len = tile_prefix_sums[tileId][A2AV_TILE_SIZE - 1] + output_splits[tileId * npes + nsplits_per_tile - 1];
-    // Up align
-    len_per_tile[tileId] = (my_tile_len + major_align) / major_align * major_align;
+  // When `nsplits` is small, not every tile gets data to sum. They can skip
+  // this local prefix sum.
+  if (nsplits_per_tile > 0) {
+    // Each tile calculates its own prefix sum, return value is the sum of all elements in the tile.
+    int64_t my_tile_len = prefixSum_warp(tile_prefix_sums[tileId], output_splits + tileId * npes, nsplits_per_tile);
+    // Last thread in each tile does the up aligning.
+    if (laneId == A2AV_TILE_SIZE - 1) {
+      auto aligned_len = (my_tile_len + major_align - 1) / major_align * major_align;
+      // In case `aligned_len` is 0, we set it to `major_align` to avoid an
+      // empty bin, bc cutlass currently does not support it. See
+      // https://github.com/pytorch/pytorch/issues/152668.
+      len_per_tile[tileId] = max(aligned_len, major_align);
+    }
   }
-  end_of_tile_prefix_sum:
   __syncthreads();
 
+  // Starting offset of each tile
+  __shared__ int64_t start_offset_per_tile[NUM_TILES];
   // Prefix sum again to get the tiles' start offsets. This is a block-wide prefix sum.
   prefixSum(start_offset_per_tile, len_per_tile, NUM_TILES);
   __syncthreads();
 
   // Add tile offset to every element in the tile
-  tile_prefix_sums[tileId][tile.thread_rank()] += start_offset_per_tile[tileId];
+  tile_prefix_sums[tileId][laneId] += start_offset_per_tile[tileId];
   __syncthreads();
 
   // Target a different e based on bid
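To see the net effect of this hunk, here is a host-side sketch of the same offset logic under made-up split values: sum each tile's splits, round the sum up to a multiple of `major_align` (with a floor of `major_align` so no bin is empty), then take an exclusive prefix sum of the aligned lengths to get each tile's start offset. Illustrative only; it mirrors the kernel's arithmetic, not its parallel structure:

```cuda
#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  const int64_t major_align = 16;
  // Made-up splits, one inner vector per tile (i.e. per expert on this rank).
  std::vector<std::vector<int64_t>> splits_per_tile = {{5, 7}, {0, 0}, {3, 9}};

  // Step 1: each tile sums its splits and rounds the sum up to a multiple of
  // major_align, never letting an empty tile produce a zero-length bin.
  std::vector<int64_t> len_per_tile;
  for (const auto& splits : splits_per_tile) {
    int64_t tile_len = 0;
    for (int64_t s : splits) tile_len += s;
    int64_t aligned_len = (tile_len + major_align - 1) / major_align * major_align;
    len_per_tile.push_back(std::max(aligned_len, major_align));
  }

  // Step 2: exclusive prefix sum of the aligned lengths gives each tile's
  // start offset along the major dimension.
  int64_t offset = 0;
  for (size_t t = 0; t < len_per_tile.size(); ++t) {
    printf("tile %zu: start = %lld, aligned len = %lld\n",
           t, (long long)offset, (long long)len_per_tile[t]);
    offset += len_per_tile[t];
  }
  // Prints: starts 0, 16, 32 with aligned lengths 16, 16, 16.
  return 0;
}
```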
@@ -486,6 +490,9 @@ at::Tensor nvshmem_all_to_all_vdev_2d(
     length 7 (making a total of 12), and if the `major_align` is set to 16,
     the output offset of c1 will be 16. Similar for c2 and c3. This value has
     no effect on the offset of the minor dimension, i.e. d0, d1, d2 and d3.
+    Note: since cutlass does not support empty bins, we set the aligned length
+    to `major_align` if it is 0. See
+    https://github.com/pytorch/pytorch/issues/152668.
     */
   auto input_hdl = c10d::symmetric_memory::rendezvous(input, group_name);
   auto out_hdl = c10d::symmetric_memory::rendezvous(out, group_name);
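Working the docstring's example through the new arithmetic: the chunks placed before c1 total 12 elements, and rounding 12 up to a multiple of 16 gives ceil(12/16) * 16 = 16, hence c1's output offset of 16. If that total had been 0, the added note applies: the aligned length is bumped from 0 to `major_align` (16), so the following chunk still starts at offset 16 rather than landing in an empty bin at offset 0.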
[Review thread] You also need ne < NUM_TILES?

Reply: correct. Let me add it.