[a2av] Align length of major dimension in output of 2D a2av by kwen2501 · Pull Request #155172 · pytorch/pytorch · GitHub
Closed
Update
[ghstack-poisoned]
kwen2501 committed Jun 5, 2025
commit 93041d6e102037b1b2f88a70798824609948fdf7
7 changes: 5 additions & 2 deletions test/distributed/test_nvshmem.py
@@ -185,9 +185,12 @@ def test_nvshmem_all_to_all_vdev_2d(self) -> None:
expert_sum = 0
for j in range(self.world_size):
expert_sum += out_split_list[i][j]
align_pad = align - (expert_sum % align)
# Align up expert_sum
expert_sum_aligned = (expert_sum + align - 1) // align * align
# If 0, make it at least `align` (bc cutlass currently does not support empty bins)
expert_sum_aligned = max(expert_sum_aligned, align)
# last element absorbs the padding
out_split_list[i][-1] += align_pad
out_split_list[i][-1] += expert_sum_aligned - expert_sum

out_splits_padded = torch.tensor(out_split_list, device=self.device).reshape(-1)
out_offsets = torch.cumsum(out_splits_padded, dim=0) # inclusive scan
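The expectation the test above now encodes is the align-up-with-floor idiom: round each expert's output total up to a multiple of `align`, but never below `align` itself, so no expert ends up with an empty bin. A minimal standalone C++ sketch of that arithmetic (the helper name is ours, not from the PR):

#include <algorithm>
#include <cstdint>

// Round `len` up to the next multiple of `align`, with a floor of `align`
// so that a zero-length group still gets a non-empty, aligned bin.
int64_t align_up_nonempty(int64_t len, int64_t align) {
  int64_t aligned = (len + align - 1) / align * align;  // classic round-up
  return std::max<int64_t>(aligned, align);             // avoid empty bins
}

For instance, a total of 12 with align = 16 pads to 16, and a total of 0 also becomes 16 rather than 0; the last split in each row absorbs the difference, which is what the test's `out_split_list[i][-1] += expert_sum_aligned - expert_sum` does.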
61 changes: 34 additions & 27 deletions torch/csrc/distributed/c10d/nvshmem_extension.cu
@@ -15,9 +15,6 @@

namespace c10d::nvshmem_extension {

using namespace cooperative_groups;
namespace cg = cooperative_groups;

using c10d::symmetric_memory::StoreExchange;
static StoreExchange storeExchange = StoreExchange("nvshmem_ext");

@@ -350,8 +347,10 @@ __global__ void exchangeSplitAndOffset_2d(int64_t* in_out_splits, int mype, int
nvshmemx_barrier_all_block();
}

// This is an warp-scope, exclusive prefix sum.
__device__ void prefixSum_warp(int64_t *odata, int64_t *idata, int n) {
// This is an warp-scope, exclusive prefix sum. When called by a block of
// threads, each warp will perform an independent prefix sum, concurrently.
// Returns the sum of all elements in the warp.
__device__ int64_t prefixSum_warp(int64_t *odata, int64_t *idata, int n) {
CUDA_KERNEL_ASSERT(n <= WARP_SIZE);
constexpr int NUM_WARPS = THREADS_PER_BLOCK / WARP_SIZE;

@@ -365,11 +364,13 @@ __device__ void prefixSum_warp(int64_t *odata, int64_t *idata, int n) {
int64_t thread_data = (tid < n) ? idata[tid] : 0;

// Compute the warp-wide exclusive prefix sum
int64_t warp_aggregate;
int warp_id = threadIdx.x / WARP_SIZE;
WarpScan(temp_storage[warp_id]).ExclusiveSum(thread_data, thread_data);
WarpScan(temp_storage[warp_id]).ExclusiveSum(thread_data, thread_data, warp_aggregate);

// Store the result
odata[tid] = thread_data;
return warp_aggregate;
}

// This is for abstracting a thread-group-scope, exclusive prefix sum.
@@ -389,43 +390,46 @@ __global__ void allToAllV_2d(void *send_data, void *recv_data, int64_t* in_out_s

// Split the thread block into tiles
constexpr int NUM_TILES = THREADS_PER_BLOCK / A2AV_TILE_SIZE;
thread_group tile = cg::tiled_partition(this_thread_block(), A2AV_TILE_SIZE);
int tileId = tid / A2AV_TILE_SIZE;
int laneId = tid % A2AV_TILE_SIZE;
// Each tile calculates its own prefix sum
__shared__ int64_t tile_prefix_sums[NUM_TILES][A2AV_TILE_SIZE];
// A tile takes care of npes worth of splits
int nsplits_per_tile = min(npes, nsplits - tileId * npes);
Collaborator:
You also need ne < NUM_TILES?

Collaborator (Author):
correct. Let me add it.

// TODO: currently it is assumed that the number of PE's is smaller than `A2AV_TILE_SIZE`
CUDA_KERNEL_ASSERT(nsplits_per_tile <= A2AV_TILE_SIZE);
// TODO: currently it is assumed that the number of PE's is smaller than
// `A2AV_TILE_SIZE` bc the warp-scope prefix sum can only handle up to
// WARP_SIZE elements
CUDA_KERNEL_ASSERT(npes <= A2AV_TILE_SIZE);
Collaborator:
scan can easily do multiple elements per thread, so we can relax this assert, but it's fine for now
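The "multiple elements per thread" relaxation mentioned above is the usual scan-then-propagate pattern: each thread scans its own ITEMS elements locally, a warp scan runs over the per-thread totals, and each thread's offset is added back. A hedged, single-warp sketch, not part of the PR (it assumes the file's existing cub include and WARP_SIZE; a multi-warp version would need one TempStorage per warp, as in the existing code):

// Hypothetical sketch of the reviewer's suggestion; not the PR's code.
// Lets one warp scan up to WARP_SIZE * ITEMS splits instead of WARP_SIZE.
template <int ITEMS>
__device__ int64_t prefixSum_warp_multi(int64_t* odata, const int64_t* idata, int n) {
  using WarpScan = cub::WarpScan<int64_t>;
  __shared__ typename WarpScan::TempStorage temp_storage;  // single-warp sketch

  int lane = threadIdx.x % WARP_SIZE;

  // 1) Thread-local exclusive scan over this thread's ITEMS elements.
  int64_t local[ITEMS];
  int64_t thread_sum = 0;
  for (int i = 0; i < ITEMS; ++i) {
    int idx = lane * ITEMS + i;
    local[i] = thread_sum;                        // exclusive running sum
    thread_sum += (idx < n) ? idata[idx] : 0;
  }

  // 2) Warp-wide exclusive scan over the per-thread totals.
  int64_t thread_offset, warp_aggregate;
  WarpScan(temp_storage).ExclusiveSum(thread_sum, thread_offset, warp_aggregate);

  // 3) Propagate each thread's offset back into its local results.
  for (int i = 0; i < ITEMS; ++i) {
    int idx = lane * ITEMS + i;
    if (idx < n) odata[idx] = local[i] + thread_offset;
  }
  return warp_aggregate;
}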

// Similarly, the number of experts per rank is also assumed to be smaller
// than `NUM_TILES`
CUDA_KERNEL_ASSERT(ne <= NUM_TILES);

// Total length of each tile
__shared__ int64_t len_per_tile[NUM_TILES];
// Starting offset of each tile
__shared__ int64_t start_offset_per_tile[NUM_TILES];
// This tile does not need to do tile-wise prefix sum
if (nsplits_per_tile < 0) goto end_of_tile_prefix_sum;

// Each tile calculates its own prefix sum
prefixSum_warp(tile_prefix_sums[tileId], output_splits + tileId * npes, nsplits_per_tile);

// Last thread in each tile does the up aligning.
// Note: using the last thread to read the last sum from `tile_prefix_sums` so
// that we can save a __syncthreads(). This is safe because the last thread is
// the one that writes the last sum in the prefixSum function.
if (tile.thread_rank() == A2AV_TILE_SIZE - 1) {
auto my_tile_len = tile_prefix_sums[tileId][A2AV_TILE_SIZE - 1] + output_splits[tileId * npes + nsplits_per_tile - 1];
// Up align
len_per_tile[tileId] = (my_tile_len + major_align) / major_align * major_align;
// When `nsplits` is small, not every tile gets data to sum. They can skip
// this local prefix sum.
if (nsplits_per_tile > 0) {
// Each tile calculates its own prefix sum, return value is the sum of all elements in the tile.
int64_t my_tile_len = prefixSum_warp(tile_prefix_sums[tileId], output_splits + tileId * npes, nsplits_per_tile);
// Last thread in each tile does the up aligning.
if (laneId == A2AV_TILE_SIZE - 1) {
auto aligned_len = (my_tile_len + major_align - 1) / major_align * major_align;
// In case `aligned_len` is 0, we set it to `major_align` to avoid an
// empty bin, bc cutlass currently does not support it. See
// https://github.com/pytorch/pytorch/issues/152668.
len_per_tile[tileId] = max(aligned_len, major_align);
}
}
end_of_tile_prefix_sum:
__syncthreads();

// Starting offset of each tile
__shared__ int64_t start_offset_per_tile[NUM_TILES];
// Prefix sum again to get the tiles' start offsets. This is a block-wide prefix sum.
prefixSum(start_offset_per_tile, len_per_tile, NUM_TILES);
Collaborator:
only one warp has to do it?

Collaborator (Author) @kwen2501, Jun 5, 2025:
True, our thread block size is typically smaller than 1024, and 1024 / WARP_SIZE = 32 at max.
I am templating prefixSum_warp to

template <int NUM_WARPS>
__device__ int64_t prefixSum_warp(int64_t *odata, int64_t *idata, int n)

so that we can call either

  • prefixSum_warp<NUM_TILES> for concurrent, warp-wise prefix sums, or
  • prefixSum_warp<1> when only 1 warp is needed.
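A sketch of what that templated variant could look like, reusing the body the diff already shows (this only illustrates the comment above, assuming the file's existing cub usage; it is not necessarily the code that was merged):

// Hypothetical illustration of the NUM_WARPS templating described above.
template <int NUM_WARPS>
__device__ int64_t prefixSum_warp(int64_t *odata, int64_t *idata, int n) {
  CUDA_KERNEL_ASSERT(n <= WARP_SIZE);
  using WarpScan = cub::WarpScan<int64_t>;
  // Only as many TempStorage slots as warps that participate.
  __shared__ typename WarpScan::TempStorage temp_storage[NUM_WARPS];

  int warp_id = threadIdx.x / WARP_SIZE;
  int lane = threadIdx.x % WARP_SIZE;
  if (warp_id >= NUM_WARPS) return 0;  // surplus warps sit this scan out

  // Each participating warp scans up to n elements of its own idata.
  int64_t thread_data = (lane < n) ? idata[lane] : 0;

  int64_t warp_aggregate;
  WarpScan(temp_storage[warp_id]).ExclusiveSum(thread_data, thread_data, warp_aggregate);

  if (lane < n) odata[lane] = thread_data;
  return warp_aggregate;
}

Per the bullets above, prefixSum_warp<NUM_TILES> would presumably cover the concurrent per-tile scans, and prefixSum_warp<1> the single block-wide scan of the tile lengths.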

__syncthreads();

// Add tile offset to every element in the tile
tile_prefix_sums[tileId][tile.thread_rank()] += start_offset_per_tile[tileId];
tile_prefix_sums[tileId][laneId] += start_offset_per_tile[tileId];
__syncthreads();

// Target a different e based on bid
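Putting this hunk's pieces together: each tile computes an exclusive prefix sum over its own `npes` output splits, its total is rounded up to `major_align` (with a floor of `major_align`), and a second, block-wide prefix sum over the per-tile lengths turns the local offsets into global ones. A single-threaded, host-side C++ sketch of the same offset arithmetic (illustrative only; names are ours):

#include <algorithm>
#include <cstdint>
#include <vector>

// splits[tile][peer]: per-tile (per-expert) output splits.
// Returns the output offsets laid out the way the kernel computes them.
std::vector<std::vector<int64_t>> compute_out_offsets(
    const std::vector<std::vector<int64_t>>& splits, int64_t major_align) {
  size_t ntiles = splits.size();
  std::vector<std::vector<int64_t>> offsets(ntiles);
  std::vector<int64_t> len_per_tile(ntiles);

  // 1) Tile-local exclusive prefix sums and aligned tile lengths.
  for (size_t t = 0; t < ntiles; ++t) {
    int64_t running = 0;
    for (int64_t s : splits[t]) {
      offsets[t].push_back(running);  // exclusive: offset before adding s
      running += s;
    }
    int64_t aligned = (running + major_align - 1) / major_align * major_align;
    len_per_tile[t] = std::max<int64_t>(aligned, major_align);  // no empty bins
  }

  // 2) Exclusive prefix sum over tile lengths -> each tile's start offset.
  int64_t tile_start = 0;
  for (size_t t = 0; t < ntiles; ++t) {
    for (int64_t& o : offsets[t]) o += tile_start;
    tile_start += len_per_tile[t];
  }
  return offsets;
}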
@@ -486,6 +490,9 @@ at::Tensor nvshmem_all_to_all_vdev_2d(
length 7 (making a total of 12), and if the `major_align` is set to 16,
the output offset of c1 will be 16. Similar for c2 and c3. This value has
no effect on the offset of the minor dimension, i.e. d0, d1, d2 and d3.
Note: since cutlass does not support empty bins, we set the aligned length
to `major_align` if it is 0. See
https://github.com/pytorch/pytorch/issues/152668.
*/
auto input_hdl = c10d::symmetric_memory::rendezvous(input, group_name);
auto out_hdl = c10d::symmetric_memory::rendezvous(out, group_name);
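To make the docstring's numbers concrete: a major group totalling 12 rounds up to 16 when `major_align` is 16, so the next group (starting with c1) begins at output offset 16. A tiny check of that arithmetic (the 5 + 7 split is inferred from the quoted "total of 12"; the rest of the example sits above this hunk):

#include <cassert>
#include <cstdint>

int main() {
  int64_t major_align = 16;
  int64_t group_total = 5 + 7;  // chunks destined for the first expert, total 12
  int64_t next_group_offset =
      (group_total + major_align - 1) / major_align * major_align;
  assert(next_group_offset == 16);  // c1 starts at offset 16, as the comment says
  return 0;
}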