[a2av] Align length of major dimension in output of 2D a2av by kwen2501 · Pull Request #155172 · pytorch/pytorch · GitHub
Closed
Changes from 1 commit
Update
[ghstack-poisoned]
kwen2501 committed Jun 5, 2025
commit 8a889c0a99a3bfbe60c5d9d004f8b87ed8f12515
23 changes: 17 additions & 6 deletions torch/csrc/distributed/c10d/nvshmem_extension.cu
@@ -348,22 +348,28 @@ __global__ void exchangeSplitAndOffset_2d(int64_t* in_out_splits, int mype, int
 // This is a warp-scope, exclusive prefix sum. When called by a block of
 // threads, each warp will perform an independent prefix sum, concurrently.
 // Returns the sum of all elements in the warp.
+// `NUM_WARPS` is the number of warps participating in the concurrent prefix sum.
+template <int NUM_WARPS>
 __device__ int64_t prefixSum_warp(int64_t *odata, int64_t *idata, int n) {
   CUDA_KERNEL_ASSERT(n <= WARP_SIZE);
-  constexpr int NUM_WARPS = THREADS_PER_BLOCK / WARP_SIZE;
+
   // Specialize WarpScan for type int64_t
   using WarpScan = at_cuda_detail::cub::WarpScan<int64_t>;
   // Allocate WarpScan shared memory for NUM_WARPS warps
   __shared__ typename WarpScan::TempStorage temp_storage[NUM_WARPS];
 
+  int warp_id = threadIdx.x / WARP_SIZE;
+  if (warp_id >= NUM_WARPS) {
+    return 0;
+  }
+
   // Obtain input item for each thread
   int tid = threadIdx.x % WARP_SIZE;
   int64_t thread_data = (tid < n) ? idata[tid] : 0;
 
+  // Compute the warp-wide exclusive prefix sum
   // Total sum of all elements in the warp
   int64_t warp_aggregate;
-  int warp_id = threadIdx.x / WARP_SIZE;
-  // Compute the warp-wide exclusive prefix sum
   WarpScan(temp_storage[warp_id]).ExclusiveSum(thread_data, thread_data, warp_aggregate);
 
   // Store the result
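Side note for readers unfamiliar with CUB's warp scan: the standalone sketch below (not part of this PR; it uses plain `cub::` rather than PyTorch's `at_cuda_detail::cub::` wrapper, and `warp_exclusive_sum_demo` is a hypothetical test name) shows the `ExclusiveSum` semantics that `prefixSum_warp` builds on. Each lane receives the sum of all lower lanes, and every participating lane also receives the warp-wide aggregate.

```cuda
#include <cstdio>
#include <cub/cub.cuh>

// One warp computes an exclusive prefix sum of `in` and its total.
__global__ void warp_exclusive_sum_demo(const long long* in, long long* out,
                                        long long* total) {
  using WarpScan = cub::WarpScan<long long>;
  __shared__ typename WarpScan::TempStorage temp_storage;

  long long thread_data = in[threadIdx.x];
  long long aggregate;
  // Lane i gets in[0] + ... + in[i-1]; lane 0 gets 0. `aggregate` is the
  // sum over all 32 lanes, returned to every lane.
  WarpScan(temp_storage).ExclusiveSum(thread_data, thread_data, aggregate);

  out[threadIdx.x] = thread_data;
  if (threadIdx.x == 0) {
    *total = aggregate;
  }
}

int main() {
  long long h_in[32], h_out[32], h_total = 0;
  for (int i = 0; i < 32; ++i) h_in[i] = 1;

  long long *d_in, *d_out, *d_total;
  cudaMalloc(&d_in, sizeof(h_in));
  cudaMalloc(&d_out, sizeof(h_out));
  cudaMalloc(&d_total, sizeof(h_total));
  cudaMemcpy(d_in, h_in, sizeof(h_in), cudaMemcpyHostToDevice);

  warp_exclusive_sum_demo<<<1, 32>>>(d_in, d_out, d_total);

  cudaMemcpy(h_out, d_out, sizeof(h_out), cudaMemcpyDeviceToHost);
  cudaMemcpy(&h_total, d_total, sizeof(h_total), cudaMemcpyDeviceToHost);
  printf("out[31] = %lld, total = %lld\n", h_out[31], h_total);  // 31, 32
  return 0;
}
```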
@@ -408,7 +414,7 @@ __global__ void allToAllV_2d(void *send_data, void *recv_data, int64_t* in_out_s
   // this local prefix sum.
   if (nsplits_per_tile > 0) {
     // Each tile calculates its own prefix sum; the return value is the sum of all elements in the tile.
-    int64_t my_tile_len = prefixSum_warp(tile_prefix_sums[tileId], output_splits + tileId * npes, nsplits_per_tile);
+    int64_t my_tile_len = prefixSum_warp<NUM_TILES>(tile_prefix_sums[tileId], output_splits + tileId * npes, nsplits_per_tile);
     // Last thread in each tile does the up-aligning.
     if (laneId == A2AV_TILE_SIZE - 1) {
       auto aligned_len = (my_tile_len + major_align - 1) / major_align * major_align;
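The `aligned_len` expression is the usual round-up-to-a-multiple idiom. A small host-side sketch with hypothetical values (`align_up` is an illustrative helper, not a function in this PR):

```cuda
#include <cassert>
#include <cstdint>

// Round n up to the next multiple of a (a > 0).
constexpr int64_t align_up(int64_t n, int64_t a) {
  return (n + a - 1) / a * a;
}

int main() {
  assert(align_up(10, 16) == 16);  // 10-element tile padded up to 16
  assert(align_up(32, 16) == 32);  // already a multiple: unchanged
  assert(align_up(0, 16) == 0);    // empty tile stays empty
  return 0;
}
```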
@@ -422,8 +428,13 @@ __global__ void allToAllV_2d(void *send_data, void *recv_data, int64_t* in_out_s
 
   // Starting offset of each tile
   __shared__ int64_t start_offset_per_tile[NUM_TILES];
-  // Prefix sum again to get the tiles' start offsets. This is a block-wide prefix sum.
-  prefixSum(start_offset_per_tile, len_per_tile, NUM_TILES);
+  // Prefix sum again to get the tiles' start offsets.
+  // `NUM_TILES` is typically not greater than 32, because 32 tiles * 32 threads
+  // = 1024 threads, and this kernel is launched with at most 1024 threads.
+  // Thus, we can use a warp-scope prefix sum.
+  static_assert(NUM_TILES <= WARP_SIZE);
+  // Only 1 warp is needed
+  prefixSum_warp<1>(start_offset_per_tile, len_per_tile, NUM_TILES);
   __syncthreads();
 
   // Add tile offset to every element in the tile
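To see how the two passes compose, here is a host-side model with hypothetical lengths (not code from this PR): pass 1 rounds each tile's length up to `major_align`, and pass 2 exclusive-scans the aligned lengths into per-tile start offsets, which is what `prefixSum_warp<1>` computes on device.

```cuda
#include <cstdint>
#include <cstdio>

int main() {
  constexpr int NUM_TILES = 4;
  constexpr int64_t major_align = 16;
  int64_t len_per_tile[NUM_TILES] = {10, 32, 70, 0};

  // Pass 1: round each tile length up to the alignment.
  for (auto& len : len_per_tile) {
    len = (len + major_align - 1) / major_align * major_align;  // 16, 32, 80, 0
  }

  // Pass 2: exclusive prefix sum of aligned lengths -> start offsets.
  int64_t start_offset_per_tile[NUM_TILES];
  int64_t running = 0;
  for (int t = 0; t < NUM_TILES; ++t) {
    start_offset_per_tile[t] = running;  // 0, 16, 48, 128
    running += len_per_tile[t];
  }

  for (int t = 0; t < NUM_TILES; ++t) {
    printf("tile %d: start %lld, len %lld\n",
           t, (long long)start_offset_per_tile[t], (long long)len_per_tile[t]);
  }
  return 0;
}
```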