[a2av] 2D all-to-all-vdev by kwen2501 · Pull Request #155058 · pytorch/pytorch · GitHub
74 changes: 74 additions & 0 deletions test/distributed/test_nvshmem.py
@@ -128,6 +128,80 @@ def test_nvshmem_all_to_all_vdev(self) -> None:
dist.all_to_all_single(expected, inp, out_splits.tolist(), inp_splits.tolist())
torch.testing.assert_close(out[:out_numel], expected)

@skipIfRocm
def test_nvshmem_all_to_all_vdev_2d(self) -> None:
torch.manual_seed(42 + self.rank)
self._init_device()

group_name = dist.group.WORLD.group_name
symm_mem.enable_symm_mem_for_group(group_name)

dtype = torch.float
# Number of experts per rank
ne = 4
nsplits = ne * self.world_size
# Number of elements for an expert is random, in [0, k)
k = 3
inp_splits = torch.randint(k, (nsplits,), device=self.device)
inp_numel = inp_splits.sum().item()
# Exchange input splits to get output splits
out_splits = torch.zeros_like(inp_splits)
dist.all_to_all_single(out_splits, inp_splits)
# We do a .t() here because there is a rank-major to expert-major shuffle
out_splits_t = out_splits.reshape(self.world_size, ne).t().reshape(-1)

# Total number of output elements
out_numel = out_splits.sum().item()
# Align up to make it bigger
align = 16
out_numel_max = (out_numel + align - 1) // align * align

inp = symm_mem.empty(inp_numel, dtype=dtype, device=self.device).fill_(
self.rank
)
out = symm_mem.empty(out_numel_max, dtype=dtype, device=self.device).fill_(-1)
in_out_splits = symm_mem.empty(
(3, nsplits), dtype=torch.int64, device=self.device
).fill_(-1)
# Row 0 is input splits
in_out_splits[0].copy_(inp_splits)

torch.ops.symm_mem.nvshmem_all_to_all_vdev_2d(
inp, out, in_out_splits, group_name
)

# Check input splits (row 0) -- should not change
torch.testing.assert_close(in_out_splits[0], inp_splits)

# Check output splits (row 1)
torch.testing.assert_close(in_out_splits[1], out_splits_t)

# Check output offsets (row 2)
out_offsets = torch.cumsum(out_splits_t, dim=0) # inclusive scan
# output offsets from `nvshmem_all_to_all_vdev_2d` are an exclusive scan
self.assertEqual(in_out_splits[2][0], 0)
torch.testing.assert_close(in_out_splits[2][1:], out_offsets[:-1])

# Check data
expected = torch.empty(out_numel, dtype=dtype, device=self.device)
inp_splits_rank = inp_splits.reshape(self.world_size, ne).sum(1)
out_splits_rank = out_splits.reshape(self.world_size, ne).sum(1)
dist.all_to_all_single(
expected, inp, out_splits_rank.tolist(), inp_splits_rank.tolist()
)
# We still need to shuffle `expected`
out_offsets = torch.cumsum(out_splits, dim=0) # inclusive scan
result_list = []
for j in range(ne):
for i in range(self.world_size):
chunk_id = i * ne + j
offset = out_offsets[chunk_id]
chunk = expected[offset - out_splits[chunk_id] : offset]
result_list.append(chunk)

final = torch.cat(result_list)
torch.testing.assert_close(out[:out_numel], final)


if __name__ == "__main__":
run_tests()
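
As a side illustration (not part of the diff), the rank-major-to-expert-major reshuffle that the test verifies can be reproduced on a single process with plain PyTorch. The sizes and split values below are made up; the snippet only mirrors the chunk-reordering loop at the end of the test.

# A minimal, single-process sketch of the rank-major -> expert-major shuffle.
# Sizes and splits are hypothetical.
import torch

world_size, ne = 2, 2
# out_splits[r * ne + e]: number of elements received from rank r for expert e
out_splits = torch.tensor([1, 2, 3, 4])
# `expected` as all_to_all_single would produce it: rank-major order
expected = torch.arange(out_splits.sum().item())

offsets = torch.cumsum(out_splits, dim=0)  # inclusive scan, as in the test
chunks = []
for e in range(ne):                        # expert-major outer loop
    for r in range(world_size):
        cid = r * ne + e                   # chunk id in rank-major order
        chunks.append(expected[offsets[cid] - out_splits[cid] : offsets[cid]])
final = torch.cat(chunks)                  # | c0 | d0 | c1 | d1 |, expert-major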
2 changes: 2 additions & 0 deletions torch/csrc/distributed/c10d/SymmetricMemory.cpp
@@ -280,6 +280,8 @@ TORCH_LIBRARY_FRAGMENT(symm_mem, m) {
"nvshmem_all_to_all(Tensor input, Tensor(a!) out, str group_name) -> Tensor(a!)");
m.def(
"nvshmem_all_to_all_vdev(Tensor input, Tensor(a!) out, Tensor(a!) in_out_splits, str group_name) -> Tensor(a!)");
m.def(
"nvshmem_all_to_all_vdev_2d(Tensor input, Tensor(a!) out, Tensor(a!) in_out_splits, str group_name) -> Tensor(a!)");
}

TORCH_LIBRARY_IMPL(symm_mem, Meta, m) {
191 changes: 186 additions & 5 deletions torch/csrc/distributed/c10d/nvshmem_extension.cu
@@ -141,7 +141,7 @@ at::Tensor nvshmem_all_to_all(
}

// This is an exclusive prefix sum function that calculates read (or write) offsets for each peer.
__device__ void prefixSum(int64_t *odata, int64_t *idata, int n) {
__device__ int64_t prefixSum(int64_t *odata, int64_t *idata, int n) {
Collaborator:
btw previously I've seen int64 scan slow down kernel big time, you might want to check the performance

// Specialize BlockScan for a 1D block of threads, of type int64_t.
// - `BLOCK_SCAN_WARP_SCANS` is a low-latency scan algorithm (instead of high
// throughput which we don't need here).
@@ -159,12 +159,12 @@ __device__ void prefixSum(int64_t *odata, int64_t *idata, int n) {
int64_t thread_data = (tid < n) ? idata[tid] : 0;

// Collectively compute the block-wide exclusive prefix sum
BlockScanT(temp_storage).ExclusiveSum(thread_data, thread_data);
int64_t block_aggregate;
BlockScanT(temp_storage).ExclusiveSum(thread_data, thread_data, block_aggregate);

// Store the result
if (tid < n) {
odata[tid] = thread_data;
}
odata[tid] = thread_data;
return block_aggregate;
}
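
As an illustration (not part of the diff): the kernel math here is just an exclusive prefix sum plus the total of the scanned values (the block aggregate), which the function now returns so callers can bound-check against the input size. A rough host-side equivalent in PyTorch, with hypothetical values:

import torch

splits = torch.tensor([3, 0, 2, 5], dtype=torch.int64)  # hypothetical per-chunk splits
offsets = torch.cumsum(splits, dim=0) - splits           # exclusive scan: [0, 3, 3, 5]
aggregate = splits.sum().item()                          # block aggregate: 10
# The kernel returns the aggregate so callers can assert it fits in the input,
# e.g. aggregate <= input.size(0).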

// This kernel is used to exchange output splits and source offsets between peers.
@@ -311,11 +311,192 @@ at::Tensor nvshmem_all_to_all_vdev(
return out;
}

// Start of `nvshmem_all_to_all_vdev_2d`
// This kernel is used to exchange output splits and source offsets between peers.
// For meaning of `mype` and `npes`, see the docstring of `nvshmem_all_to_all_vdev_2d`.
// `in_out_splits` is of size (3, npes * ne) and contains:
// - input splits (IN)
// - output splits (OUT) and
// - source offsets (OUT).
__global__ void exchangeSplitAndOffset_2d(int64_t* in_out_splits, int mype, int npes, int ne, size_t input_dim0) {
int nsplits = npes * ne;
auto input_splits = in_out_splits;
auto output_splits = in_out_splits + nsplits;
auto source_offsets = in_out_splits + nsplits * 2;
int tid = threadIdx.x;

__shared__ int64_t peer_offsets[THREADS_PER_BLOCK];

// Scan input splits to get the source offsets
auto sum_of_splits = prefixSum(peer_offsets, input_splits, nsplits);
__syncthreads();
CUDA_KERNEL_ASSERT(sum_of_splits <= input_dim0);

// Use 1 block to do the exchange
if (tid < nsplits) {
int peer = tid / ne;
int e = tid % ne;
// This does a transpose from rank-major order to expert-major order
int dst_offset = e * npes + mype;
auto split_val = input_splits[tid];
CUDA_KERNEL_ASSERT(split_val >= 0);
nvshmem_int64_p(source_offsets + dst_offset, peer_offsets[tid], peer);
Collaborator:
here it would also make sense to check that there are no negative numbers in splits, and that the sum of splits is less than the input size?

Collaborator Author:
Added a negative check.

nvshmem_int64_p(output_splits + dst_offset, split_val, peer);
}
// This barrier ensures that all remote PEs see the updated values
nvshmemx_barrier_all_block();
}
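
To make the transpose indexing above concrete, here is a single-process simulation of the split/offset exchange (an illustration, not part of the diff; plain PyTorch with hypothetical splits). The real kernel performs the same scatter with nvshmem_int64_p puts into each peer's symmetric buffer.

import torch

npes, ne = 2, 2
# Hypothetical input splits per rank, rank-major: in_splits[mype][peer * ne + e]
in_splits = [torch.tensor([1, 2, 3, 4]), torch.tensor([5, 6, 7, 8])]

out_splits = [torch.zeros(npes * ne, dtype=torch.int64) for _ in range(npes)]
src_offsets = [torch.zeros(npes * ne, dtype=torch.int64) for _ in range(npes)]

for mype in range(npes):
    offs = torch.cumsum(in_splits[mype], dim=0) - in_splits[mype]  # exclusive scan
    for tid in range(npes * ne):
        peer, e = tid // ne, tid % ne
        dst = e * npes + mype              # rank-major -> expert-major transpose
        out_splits[peer][dst] = in_splits[mype][tid]
        src_offsets[peer][dst] = offs[tid]

# On rank 0, entries are now grouped by expert:
# out_splits[0]  == tensor([1, 5, 2, 6])   # (r0,e0), (r1,e0), (r0,e1), (r1,e1)
# src_offsets[0] == tensor([0, 0, 1, 5])   # read offsets into each peer's input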

// This kernel is used to do the actual data exchange.
// `in_out_splits` has the same definition as in `exchangeSplitAndOffset_2d`.
// `stride` is the stride at dim 0, unit in byte.
// For meaning of `mype` and `npes`, see the docstring of `nvshmem_all_to_all_vdev_2d`.
__global__ void allToAllV_2d(void *send_data, void *recv_data, int64_t* in_out_splits, size_t stride, int mype, int npes, int ne) {
int nsplits = npes * ne;
auto output_splits = in_out_splits + nsplits;
auto source_offsets = in_out_splits + nsplits * 2;
int bid = blockIdx.x;
int tid = threadIdx.x;

// Calculate the output offsets
__shared__ int64_t e_offsets[THREADS_PER_BLOCK];
prefixSum(e_offsets, output_splits, nsplits);
__syncthreads();

// Target a different e based on bid
for (int eid = bid; eid < nsplits; eid += gridDim.x) {
int peer = eid % npes;
// Amount from `peer` for `e`
auto peer_size = output_splits[eid] * stride;
auto source_offset = source_offsets[eid] * stride;
Collaborator:
you need to check that these offsets are within the tensor, so there are no OOB reads

Collaborator Author:
Added block_aggregate and a check against the input's size at dim 0.

auto write_offset = e_offsets[eid] * stride;
nvshmemx_getmem_block(
(char*)recv_data + write_offset,
(char*)send_data + source_offset,
peer_size,
peer);
}
// Write out the output offsets (to the scratchpad line)
if (bid == 0 && tid < nsplits) {
source_offsets[tid] = e_offsets[tid];
}
}
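
A matching sketch of the data-movement step (again an illustration, not part of the diff): given the expert-major output splits and source offsets produced by the exchange (hard-coded below to the hypothetical values from the previous sketch), each chunk is pulled from the owning peer's send buffer into a write offset given by an exclusive scan of the output splits. The real kernel does the copy with nvshmemx_getmem_block; this is a purely local simulation.

import torch

npes, ne = 2, 2
# Values as they would appear on rank 0 after the exchange (hypothetical).
out_splits = torch.tensor([1, 5, 2, 6])    # expert-major: (r0,e0), (r1,e0), (r0,e1), (r1,e1)
src_offsets = torch.tensor([0, 0, 1, 5])   # read offsets into each peer's send buffer
send = [torch.arange(10.), 100 + torch.arange(26.)]  # per-rank send buffers (sizes match splits)

write_offs = torch.cumsum(out_splits, dim=0) - out_splits  # exclusive scan of output splits
recv = torch.empty(int(out_splits.sum()))
for eid in range(npes * ne):
    peer = eid % npes                      # source rank of this expert-major chunk
    n = out_splits[eid]
    recv[write_offs[eid] : write_offs[eid] + n] = \
        send[peer][src_offsets[eid] : src_offsets[eid] + n]
# recv now holds all expert-0 data first, then expert-1 data (expert-major order).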

at::Tensor nvshmem_all_to_all_vdev_2d(
at::Tensor& input,
at::Tensor& out,
at::Tensor& in_out_splits,
std::string group_name) {
/* Perform a 2D AllToAllv shuffle operation using NVSHMEM, with split information provided on device.
* Arguments:
* - `input` is the input tensor
* - `out` is the output tensor
* - `in_out_splits` is a 2D tensor of size (3, `world_size` * `ne`). In the
scenario of Mixture-of-Experts models, `ne` is the number of experts per
rank. The rows of `in_out_splits` are (in order):
input splits (IN)
output splits (OUT) and
output offsets (OUT).
* - `group_name` is the name of the group to use for the collective operation.

* A 2D AllToAllv shuffle is illustrated below:
(world_size = 2, ne = 2, total number of experts = 4)
Source: | Rank 0 | Rank 1 |
| c0 | c1 | c2 | c3 | d0 | d1 | d2 | d3 |

Dest : | Rank 0 | Rank 1 |
| c0 | d0 | c1 | d1 | c2 | d2 | c3 | d3 |
where each `c_i` / `d_i` is a slice of the `input` tensor targeting
expert `i`, with its length indicated by the input splits (in
`in_out_splits[0]`). That is, the 2D AllToAllv shuffle achieves a
transpose from rank-major order at the input to expert-major order at
the output.
*/
auto input_hdl = c10d::symmetric_memory::rendezvous(input, group_name);
Collaborator:
you also need to check input/output dimensionality and contiguity

Collaborator Author:
Added a contiguity check. I can't think of a dimensionality requirement here (can be 1D, 2D or n-D).

Collaborator:
also you need to check dtypes - integer for splits, same dtype for input and output

auto out_hdl = c10d::symmetric_memory::rendezvous(out, group_name);
auto splits_hdl = c10d::symmetric_memory::rendezvous(in_out_splits, group_name);
int rank = input_hdl->get_rank();
int world_size = input_hdl->get_world_size();

void* input_ptr = input_hdl->get_buffer_ptrs()[rank];
void* output_ptr = out_hdl->get_buffer_ptrs()[rank];
int64_t* splits_ptr = (int64_t*)(splits_hdl->get_buffer_ptrs()[rank]);

// Shape checks
auto split_shape = in_out_splits.sizes();
Collaborator:
I still don't see a check for in_out_splits.is_contiguous()

Collaborator Author:
Sorry, added.

TORCH_CHECK(in_out_splits.is_contiguous()
&& input.is_contiguous()
&& out.is_contiguous(),
"input, out and in_out_splits must be contiguous");
TORCH_CHECK(split_shape.size() == 2
&& split_shape[0] == 3
&& split_shape[1] % world_size == 0,
"in_out_splits must be 2D with 3 rows, "
"each row must be a multiple of world_size");

// Consistency checks
TORCH_CHECK(input.dtype() == out.dtype()
&& input.stride(0) == out.stride(0),
"input and out must have the same dtype and same stride at dim 0");
TORCH_CHECK(in_out_splits.scalar_type() == at::kLong, "in_out_splits must be int64");

// Number of experts per rank
int ne = split_shape[1] / world_size;

// Set device context for getting the stream and launching kernels below
c10::cuda::CUDAGuard guard(input.device());
auto stream = at::cuda::getCurrentCUDAStream();

// Exchange output splits and source offsets
auto input_dim0 = input.size(0);
// Use collective launch because kernel involves nvshmem barrier
void* args0[] = {
&splits_ptr,
&rank,
&world_size,
&ne,
&input_dim0};
nvshmemx_collective_launch(
(const void*)exchangeSplitAndOffset_2d,
dim3(1),
dim3(THREADS_PER_BLOCK),
args0,
0,
stream);

// CTA Tuning
// Naive for now, use 1 block per expert.
// Total number of blocks is limited to 64 (intra-node) or 8 (inter-node).
int num_blocks = std::min(world_size * ne, world_size > 8 ? 8 : 64);

// Stride at dim 0
size_t stride_bytes = input.stride(0) * input.element_size();
Collaborator:
here you are assuming that input.stride(0) == output.stride(0), you should check it

Collaborator Author:
Added


// All to all data exchange
void* args1[] = {
&input_ptr,
&output_ptr,
&splits_ptr,
&stride_bytes,
&rank,
&world_size,
&ne};
nvshmemx_collective_launch(
(const void*)allToAllV_2d,
dim3(num_blocks),
dim3(THREADS_PER_BLOCK),
args1,
0,
stream);
return out;
}

} // namespace c10d::nvshmem_extension


TORCH_LIBRARY_IMPL(symm_mem, CUDA, m) {
m.impl("nvshmem_broadcast", c10d::nvshmem_extension::nvshmem_broadcast);
m.impl("nvshmem_all_to_all", c10d::nvshmem_extension::nvshmem_all_to_all);
m.impl("nvshmem_all_to_all_vdev", c10d::nvshmem_extension::nvshmem_all_to_all_vdev);
m.impl("nvshmem_all_to_all_vdev_2d", c10d::nvshmem_extension::nvshmem_all_to_all_vdev_2d);
}
6 changes: 6 additions & 0 deletions torch/csrc/distributed/c10d/nvshmem_extension.cuh
@@ -28,4 +28,10 @@ at::Tensor nvshmem_all_to_all_vdev(
at::Tensor& in_out_splits,
std::string group_name);

at::Tensor nvshmem_all_to_all_vdev_2d(
at::Tensor& input,
at::Tensor& out,
at::Tensor& in_out_splits,
std::string group_name);

} // namespace c10d::nvshmem_extension