[a2av] Align length of major dimension in output of 2D a2av by kwen2501 · Pull Request #155172 · pytorch/pytorch
Closed
Changes from 8 commits
46 changes: 34 additions & 12 deletions test/distributed/test_nvshmem.py
@@ -142,19 +142,20 @@ def test_nvshmem_all_to_all_vdev_2d(self) -> None:
nsplits = ne * self.world_size
# Number of elements for an expert is random between [0, k)
k = 3
# Align
align = 16
inp_splits = torch.randint(k, (nsplits,), device=self.device)
inp_numel = inp_splits.sum().item()
# Exchange input splits to get output splits
out_splits = torch.zeros_like(inp_splits)
dist.all_to_all_single(out_splits, inp_splits)
# We do a .t() here because there is a rank-major to expert-major shuffle
out_splits_t = out_splits.reshape(self.world_size, ne).t().reshape(-1)
out_splits_t = out_splits.reshape(self.world_size, ne).t()

# Total number of output elements
out_numel = out_splits.sum().item()
# Align up to make it bigger
align = 16
out_numel_max = (out_numel + align - 1) // align * align
# Align-up makes it bigger
out_numel_max = (out_numel + align * ne) // align * align

inp = symm_mem.empty(inp_numel, dtype=dtype, device=self.device).fill_(
self.rank
@@ -167,20 +167,37 @@ def test_nvshmem_all_to_all_vdev_2d(self) -> None:
in_out_splits[0].copy_(inp_splits)

torch.ops.symm_mem.nvshmem_all_to_all_vdev_2d(
inp, out, in_out_splits, group_name
inp, out, in_out_splits, group_name, align
)
received_out_splits = in_out_splits[1]
received_out_offsets = in_out_splits[2]

# Check input splits (row 0) -- should not change
torch.testing.assert_close(in_out_splits[0], inp_splits)

# Check output splits (row 1)
torch.testing.assert_close(in_out_splits[1], out_splits_t)
torch.testing.assert_close(received_out_splits, out_splits_t.reshape(-1))

# Check output offsets (row 2)
out_offsets = torch.cumsum(out_splits_t, dim=0) # inclusive scan
# output offsets from `nvshmem_all_to_all_vdev` is exclusive scan
self.assertEqual(in_out_splits[2][0], 0)
torch.testing.assert_close(in_out_splits[2][1:], out_offsets[:-1])
out_split_list = out_splits_t.tolist()
for i in range(ne):
expert_sum = 0
for j in range(self.world_size):
expert_sum += out_split_list[i][j]
# Align up expert_sum
expert_sum_aligned = (expert_sum + align - 1) // align * align
# If 0, make it at least `align` (bc cutlass currently does not support empty bins)
expert_sum_aligned = max(expert_sum_aligned, align)
# last element absorbs the padding
out_split_list[i][-1] += expert_sum_aligned - expert_sum

out_splits_padded = torch.tensor(out_split_list, device=self.device).reshape(-1)
out_offsets = torch.cumsum(out_splits_padded, dim=0) # inclusive scan
# Make it exclusive scan because that's what `nvshmem_all_to_all_vdev_2d` returns
out_offsets = torch.cat(
[torch.zeros(1, device=self.device), out_offsets[:-1]]
).to(torch.int64)
torch.testing.assert_close(received_out_offsets, out_offsets)

# Check data
expected = torch.empty(out_numel, dtype=dtype, device=self.device)
@@ -199,8 +217,12 @@ def test_nvshmem_all_to_all_vdev_2d(self) -> None:
chunk = expected[offset - out_splits[chunk_id] : offset]
result_list.append(chunk)

final = torch.cat(result_list)
torch.testing.assert_close(out[:out_numel], final)
# Do a chunk-wise comparison
for c, chunk in enumerate(result_list):
start = received_out_offsets[c].item()
split = received_out_splits[c].item()
received_chunk = out[start : start + split]
torch.testing.assert_close(received_chunk, chunk)


if __name__ == "__main__":
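For readers following the test change above: a minimal host-side sketch of the expected-output computation the updated test performs, using illustrative toy values (2 ranks, 2 experts, align = 16) rather than anything taken from the PR, and requiring no GPU or process group.

import torch

# Toy inputs (illustrative): out_splits_t[i][j] is the number of elements
# local expert i receives from rank j.
world_size, ne, align = 2, 2, 16
out_splits_t = torch.tensor([[5, 7], [3, 0]])
assert out_splits_t.shape == (ne, world_size)

out_split_list = out_splits_t.tolist()
for i in range(ne):
    expert_sum = sum(out_split_list[i])
    # Align each expert's total up to a multiple of `align`; an empty expert
    # still gets `align` elements (cutlass does not support empty bins).
    expert_sum_aligned = max((expert_sum + align - 1) // align * align, align)
    # The last chunk of the expert absorbs the padding.
    out_split_list[i][-1] += expert_sum_aligned - expert_sum

out_splits_padded = torch.tensor(out_split_list).reshape(-1)
# Exclusive scan: a chunk's offset is the sum of all padded splits before it.
out_offsets = torch.cat(
    [torch.zeros(1, dtype=torch.int64), torch.cumsum(out_splits_padded, dim=0)[:-1]]
)
print(out_splits_padded.tolist())  # [5, 11, 3, 13] -> expert totals 16 and 16
print(out_offsets.tolist())        # [0, 5, 16, 19]

Because the padding is folded into the last chunk of each expert, the chunk-wise comparison at the end of the test can index `out` directly with the received splits and offsets.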
2 changes: 1 addition & 1 deletion torch/csrc/distributed/c10d/SymmetricMemory.cpp
@@ -281,7 +281,7 @@ TORCH_LIBRARY_FRAGMENT(symm_mem, m) {
m.def(
"nvshmem_all_to_all_vdev(Tensor input, Tensor(a!) out, Tensor(a!) in_out_splits, str group_name) -> Tensor(a!)");
m.def(
"nvshmem_all_to_all_vdev_2d(Tensor input, Tensor(a!) out, Tensor(a!) in_out_splits, str group_name) -> Tensor(a!)");
"nvshmem_all_to_all_vdev_2d(Tensor input, Tensor(a!) out, Tensor(a!) in_out_splits, str group_name, int major_align) -> Tensor(a!)");
}

TORCH_LIBRARY_IMPL(symm_mem, Meta, m) {
116 changes: 108 additions & 8 deletions torch/csrc/distributed/c10d/nvshmem_extension.cu
@@ -17,6 +17,7 @@ using c10d::symmetric_memory::StoreExchange;
static StoreExchange storeExchange = StoreExchange("nvshmem_ext");

#define THREADS_PER_BLOCK 512
#define WARP_SIZE 32

// Bootstrap based on user's setting for NCCL
// Long term, this may be a bit unclean; short term, it improves UX
@@ -346,20 +347,100 @@ __global__ void exchangeSplitAndOffset_2d(int64_t* in_out_splits, int mype, int
nvshmemx_barrier_all_block();
}

// This is a warp-scope, exclusive prefix sum. When called by a block of
// threads, each warp will perform an independent prefix sum, concurrently.
// Returns the sum of all elements in the warp.
// `NUM_WARPS` is the number of warps participating in the concurrent prefix sum.
template <int NUM_WARPS>
__device__ int64_t prefixSum_warp(int64_t *odata, int64_t *idata, int n) {
CUDA_KERNEL_ASSERT(n <= WARP_SIZE);

// Specialize WarpScan for type int64_t
using WarpScan = at_cuda_detail::cub::WarpScan<int64_t>;
// Allocate WarpScan shared memory for N warps
__shared__ typename WarpScan::TempStorage temp_storage[NUM_WARPS];

int warp_id = threadIdx.x / WARP_SIZE;
if (warp_id >= NUM_WARPS) {
return 0;
}

// Obtain input item for each thread
int tid = threadIdx.x % WARP_SIZE;
int64_t thread_data = (tid < n) ? idata[tid] : 0;

// Total sum of all elements in the warp
int64_t warp_aggregate;
// Compute the warp-wide exclusive prefix sum
WarpScan(temp_storage[warp_id]).ExclusiveSum(thread_data, thread_data, warp_aggregate);

// Store the result
odata[tid] = thread_data;
return warp_aggregate;
}

// This is for abstracting a thread-group-scope, exclusive prefix sum.
// Since we use warp-scope prefix sum, the thread group size is limited to warp size.
#define A2AV_TILE_SIZE WARP_SIZE

// This kernel is used to do the actual data exchange.
// `in_out_splits` has the same definition as in `exchangeSplitAndOffset`.
// `stride` is the stride at dim 0, unit in byte.
// For meaning of `mype` and `npes`, see the docstring of `nvshmem_all_to_all_vdev_2d`.
__global__ void allToAllV_2d(void *send_data, void *recv_data, int64_t* in_out_splits, size_t stride, int mype, int npes, int ne) {
__global__ void allToAllV_2d(void *send_data, void *recv_data, int64_t* in_out_splits, size_t stride, int mype, int npes, int ne, int64_t major_align) {
int nsplits = npes * ne;
auto output_splits = in_out_splits + nsplits;
auto source_offsets = in_out_splits + nsplits * 2;
int bid = blockIdx.x;
int tid = threadIdx.x;

// Calculate the output offsets
__shared__ int64_t e_offsets[THREADS_PER_BLOCK];
prefixSum(e_offsets, output_splits, nsplits);
// Split the thread block into tiles
constexpr int NUM_TILES = THREADS_PER_BLOCK / A2AV_TILE_SIZE;
int tileId = tid / A2AV_TILE_SIZE;
int laneId = tid % A2AV_TILE_SIZE;
// Each tile calculates its own prefix sum
__shared__ int64_t tile_prefix_sums[NUM_TILES][A2AV_TILE_SIZE];
// A tile takes care of npes worth of splits
int nsplits_per_tile = min(npes, nsplits - tileId * npes);
Collaborator: You also need ne < NUM_TILES?

Collaborator (Author): correct. Let me add it.

// TODO: currently it is assumed that the number of PE's is smaller than
// `A2AV_TILE_SIZE` bc the warp-scope prefix sum can only handle up to
// WARP_SIZE elements
CUDA_KERNEL_ASSERT(npes <= A2AV_TILE_SIZE);
Collaborator: scan can easily do multiple elements per thread, so we can relax this assert, but it's fine for now

// Similarly, the number of experts per rank is also assumed to be smaller
// than `NUM_TILES`
CUDA_KERNEL_ASSERT(ne <= NUM_TILES);

// Total length of each tile
__shared__ int64_t len_per_tile[NUM_TILES];
// When `nsplits` is small, not every tile gets data to sum. They can skip
// this local prefix sum.
if (nsplits_per_tile > 0) {
// Each tile calculates its own prefix sum, return value is the sum of all elements in the tile.
int64_t my_tile_len = prefixSum_warp<NUM_TILES>(tile_prefix_sums[tileId], output_splits + tileId * npes, nsplits_per_tile);
// Last thread in each tile does the up aligning.
if (laneId == A2AV_TILE_SIZE - 1) {
auto aligned_len = (my_tile_len + major_align - 1) / major_align * major_align;
// In case `aligned_len` is 0, we set it to `major_align` to avoid an
// empty bin, bc cutlass currently does not support it. See
// https://github.com/pytorch/pytorch/issues/152668.
len_per_tile[tileId] = max(aligned_len, major_align);
}
}
__syncthreads();

// Starting offset of each tile
__shared__ int64_t start_offset_per_tile[NUM_TILES];
// Prefix sum again to get the tiles' start offsets.
// `NUM_TILES` is typically not greater than 32, because 32 tiles * 32 threads
// = 1024 threads, and this kernel is launched with at most 1024 threads per
// block. Thus, we can use warp-scope prefix sum.
static_assert(NUM_TILES <= WARP_SIZE);
// Only 1 warp is needed
prefixSum_warp<1>(start_offset_per_tile, len_per_tile, NUM_TILES);
__syncthreads();

// Add tile offset to every element in the tile
tile_prefix_sums[tileId][laneId] += start_offset_per_tile[tileId];
__syncthreads();

// Target a different e based on bid
@@ -368,7 +449,8 @@ __global__ void allToAllV_2d(void *send_data, void *recv_data, int64_t* in_out_s
// Amount from `peer` for `e`
auto peer_size = output_splits[eid] * stride;
auto source_offset = source_offsets[eid] * stride;
auto write_offset = e_offsets[eid] * stride;
auto e_offset = tile_prefix_sums[eid / npes][peer];
auto write_offset = e_offset * stride;
nvshmemx_getmem_block(
(char*)recv_data + write_offset,
(char*)send_data + source_offset,
Expand All @@ -377,15 +459,16 @@ __global__ void allToAllV_2d(void *send_data, void *recv_data, int64_t* in_out_s
}
// Write out the output offsets (to the scratchpad line)
if (bid == 0 && tid < nsplits) {
source_offsets[tid] = e_offsets[tid];
source_offsets[tid] = tile_prefix_sums[tid / npes][tid % npes];
}
}

at::Tensor nvshmem_all_to_all_vdev_2d(
at::Tensor& input,
at::Tensor& out,
at::Tensor& in_out_splits,
std::string group_name) {
std::string group_name,
int64_t major_align) {
/* Perform a 2D AllToAllv shuffle operation using NVSHMEM, with split information provided on device.
* Arguments:
* - `input` is the input tensor
@@ -397,6 +480,8 @@ at::Tensor nvshmem_all_to_all_vdev_2d(
output splits (OUT) and
output offsets (OUT).
* - `group_name` is the name of the group to use for the collective operation.
* - `major_align` is the alignment of the "major dimension" of the output
sequence. See below for details.

* A 2D AllToAllv shuffle is illustrated below:
(world_size = 2, ne = 2, total number of experts = 4)
@@ -410,12 +495,23 @@ at::Tensor nvshmem_all_to_all_vdev_2d(
`in_out_splits[0]`). That is, the 2D AllToAllv shuffle achieves a
transpose from rank-major order at input to expert-major order at
output.

* If `major_align` is not 1, the output offsets of c1, c2, c3 will be
up-aligned to this value. For example, if c0 has length 5 and d0 has
length 7 (making a total of 12), and if the `major_align` is set to 16,
the output offset of c1 will be 16. The same applies to c2 and c3. This value has
no effect on the offset of the minor dimension, i.e. d0, d1, d2 and d3.
Note: since cutlass does not support empty bins, we set the aligned length
to `major_align` if it is 0. See
https://github.com/pytorch/pytorch/issues/152668.
*/
auto input_hdl = c10d::symmetric_memory::rendezvous(input, group_name);
auto out_hdl = c10d::symmetric_memory::rendezvous(out, group_name);
auto splits_hdl = c10d::symmetric_memory::rendezvous(in_out_splits, group_name);
int rank = input_hdl->get_rank();
int world_size = input_hdl->get_world_size();
// TODO: world_size is currently limited by the number of elements in a WarpScan.
TORCH_CHECK(world_size <= A2AV_TILE_SIZE, "world_size must not exceed A2AV_TILE_SIZE: ", A2AV_TILE_SIZE);

void* input_ptr = input_hdl->get_buffer_ptrs()[rank];
void* output_ptr = out_hdl->get_buffer_ptrs()[rank];
@@ -428,8 +524,11 @@ at::Tensor nvshmem_all_to_all_vdev_2d(
"input, out and in_out_splits must be contiguous");
TORCH_CHECK(split_shape.size() == 2 && split_shape[0] == 3, "in_out_splits must be 2D with 3 rows");
TORCH_CHECK(split_shape[1] % world_size == 0, "Each row of in_out_splits must be a multiple of world_size");

// Number of experts per rank
int ne = split_shape[1] / world_size;
constexpr int NUM_TILES = THREADS_PER_BLOCK / A2AV_TILE_SIZE;
TORCH_CHECK(ne <= NUM_TILES, "Number of experts must not exceed NUM_TILES: ", NUM_TILES);

// Set device context for getting the stream and launching kernels below
c10::cuda::CUDAGuard guard(input.device());
@@ -466,7 +565,8 @@ at::Tensor nvshmem_all_to_all_vdev_2d(
&stride_bytes,
&rank,
&world_size,
&ne};
&ne,
&major_align};
nvshmemx_collective_launch(
(const void*)allToAllV_2d,
dim3(num_blocks),
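To summarize the new offset computation in `allToAllV_2d`: each tile (one per local expert) runs a warp-scope exclusive scan over its `npes` splits, the per-tile total is aligned up to `major_align` (and never allowed to be zero), and a second exclusive scan over the tile lengths yields each tile's start offset. Below is a host-side Python model of that two-level scan, a sketch rather than the kernel itself, with helper names invented for illustration.

def exclusive_scan(xs):
    # Exclusive prefix sum; also returns the total (the "warp aggregate").
    offsets, running = [], 0
    for x in xs:
        offsets.append(running)
        running += x
    return offsets, running

def a2av_2d_offsets(output_splits, npes, ne, major_align):
    # `output_splits` is expert-major: [e0<-rank0, e0<-rank1, ..., e1<-rank0, ...]
    tile_offsets, tile_lens = [], []
    for e in range(ne):  # each tile scans the npes splits of one expert
        offs, tile_len = exclusive_scan(output_splits[e * npes:(e + 1) * npes])
        aligned = (tile_len + major_align - 1) // major_align * major_align
        tile_lens.append(max(aligned, major_align))  # never an empty bin
        tile_offsets.append(offs)
    tile_starts, _ = exclusive_scan(tile_lens)  # second-level scan over tiles
    return [tile_starts[e] + off for e in range(ne) for off in tile_offsets[e]]

# Same toy splits as the test sketch above: 2 ranks, 2 experts, align = 16.
print(a2av_2d_offsets([5, 7, 3, 0], npes=2, ne=2, major_align=16))
# -> [0, 5, 16, 19]

The two levels correspond to `prefixSum_warp<NUM_TILES>` over `output_splits` and `prefixSum_warp<1>` over `len_per_tile` in the kernel.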
3 changes: 2 additions & 1 deletion torch/csrc/distributed/c10d/nvshmem_extension.cuh
@@ -32,6 +32,7 @@ at::Tensor nvshmem_all_to_all_vdev_2d(
at::Tensor& input,
at::Tensor& out,
at::Tensor& in_out_splits,
std::string group_name);
std::string group_name,
int64_t major_align = 1);

} // namespace c10d::nvshmem_extension
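The default `major_align = 1` declared here means existing callers keep getting un-padded offsets for non-empty experts (an empty expert is still bumped to at least one element, per the docstring in nvshmem_extension.cu). A tiny sketch with the same toy splits as above:

# With major_align = 1, the align-up leaves these (non-empty) expert totals
# unchanged, so the offsets are just the exclusive prefix sum of the
# expert-major output splits.
splits = [5, 7, 3, 0]  # illustrative: 2 experts x 2 ranks, expert-major
offsets, running = [], 0
for s in splits:
    offsets.append(running)
    running += s
print(offsets)  # [0, 5, 12, 15]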