8000 [a2av] 2D all-to-all-vdev by kwen2501 · Pull Request #155058 · pytorch/pytorch · GitHub
[go: up one dir, main page]

Skip to content
Closed
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 74 additions & 0 deletions test/distributed/test_nvshmem.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,80 @@ def test_nvshmem_all_to_all_vdev(self) -> None:
dist.all_to_all_single(expected, inp, out_splits.tolist(), inp_splits.tolist())
torch.testing.assert_close(out[:out_numel], expected)

@skipIfRocm
def test_nvshmem_all_to_all_vdev_2d(self) -> None:
torch.manual_seed(42 + self.rank)
self._init_device()

group_name = dist.group.WORLD.group_name
symm_mem.enable_symm_mem_for_group(group_name)

dtype = torch.float
# Number of experts per rank
ne = 4
nsplits = ne * self.world_size
# Number of elements for an expert is random between [0, k)
k = 3
inp_splits = torch.randint(k, (nsplits,), device=self.device)
inp_numel = inp_splits.sum().item()
# Exchange input splits to get output splits
out_splits = torch.zeros_like(inp_splits)
dist.all_to_all_single(out_splits, inp_splits)
# We do a .t() here because there is a rank-major to expert-major shuffle
out_splits_t = out_splits.reshape(self.world_size, ne).t().reshape(-1)

# Total number of output elements
out_numel = out_splits.sum().item()
# Align up to make it bigger
align = 16
out_numel_max = (out_numel + align - 1) // align * align

inp = symm_mem.empty(inp_numel, dtype=dtype, device=self.device).fill_(
self.rank
)
out = symm_mem.empty(out_numel_max, dtype=dtype, device=self.device).fill_(-1)
in_out_splits = symm_mem.empty(
(3, nsplits), dtype=torch.int64, device=self.device
).fill_(-1)
# Row 0 is input splits
in_out_splits[0].copy_(inp_splits)

torch.ops.symm_mem.nvshmem_all_to_all_vdev_2d(
inp, out, in_out_splits, group_name
)

# Check input splits (row 0) -- should not change
torch.testing.assert_close(in_out_splits[0], inp_splits)

# Check output splits (row 1)
torch.testing.assert_close(in_out_splits[1], out_splits_t)

# Check output offsets (row 2)
out_offsets = torch.cumsum(out_splits_t, dim=0) # inclusive scan
# output offsets from `nvshmem_all_to_all_vdev` is exclusive scan
self.assertEqual(in_out_splits[2][0], 0)
torch.testing.assert_close(in_out_splits[2][1:], out_offsets[:-1])

# Check data
expected = torch.empty(out_numel, dtype=dtype, device=self.device)
inp_splits_rank = inp_splits.reshape(self.world_size, ne).sum(1)
out_splits_rank = out_splits.reshape(self.world_size, ne).sum(1)
dist.all_to_all_single(
expected, inp, out_splits_rank.tolist(), inp_splits_rank.tolist()
)
# We still need to shuffle `expected`
out_offsets = torch.cumsum(out_splits, dim=0) # inclusive scan
result_list = []
for j in range(ne):
for i in range(self.world_size):
chunk_id = i * ne + j
offset = out_offsets[chunk_id]
chunk = expected[offset - out_splits[chunk_id] : offset]
result_list.append(chunk)

final = torch.cat(result_list)
torch.testing.assert_close(out[:out_numel], final)


if __name__ == "__main__":
run_tests()
2 changes: 2 additions & 0 deletions torch/csrc/distributed/c10d/SymmetricMemory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,8 @@ TORCH_LIBRARY_FRAGMENT(symm_mem, m) {
"nvshmem_all_to_all(Tensor input, Tensor(a!) out, str group_name) -> Tensor(a!)");
m.def(
"nvshmem_all_to_all_vdev(Tensor input, Tensor(a!) out, Tensor(a!) in_out_splits, str group_name) -> Tensor(a!)");
m.def(
"nvshmem_all_to_all_vdev_2d(Tensor input, Tensor(a!) out, Tensor(a!) in_out_splits, str group_name) -> Tensor(a!)");
}

TORCH_LIBRARY_IMPL(symm_mem, Meta, m) {
Expand Down
161 changes: 161 additions & 0 deletions torch/csrc/distributed/c10d/nvshmem_extension.cu
52E5
Original file line number Diff line number Diff line change
Expand Up @@ -311,11 +311,172 @@ at::Tensor nvshmem_all_to_all_vdev(
return out;
}

// Start of `nvshmem_all_to_all_vdev_2d`
// This kernel is used to exchange output splits and source offsets between peers.
// For meaning of `mype` and `npes`, see the docstring of `nvshmem_all_to_all_vdev_2d`.
// `in_out_splits` is of size (3, npes * ne) and contains:
// - input splits (IN)
// - output splits (OUT) and
// - source offsets (OUT).
__global__ void exchangeSplitAndOffset_2d(int64_t* in_out_splits, int mype, int npes, int ne) {
int nsplits = npes * ne;
auto input_splits = in_out_splits;
auto output_splits = in_out_splits + nsplits;
auto source_offsets = in_out_splits + nsplits * 2;
int tid = threadIdx.x;

__shared__ int64_t peer_offsets[THREADS_PER_BLOCK];

// Scan input splits to get the source offsets
prefixSum(peer_offsets, input_splits, nsplits);
__syncthreads();;

// Use 1 block to do the exchange
if (tid < nsplits) {
int peer = tid / ne;
int e = tid % ne;
// This does a transpose from rank-major order to expert-major order
int dst_offset = e * npes + mype;
nvshmem_int64_p(source_offsets + dst_offset, peer_offsets[tid], peer);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

here would also make sense to check that there are no negative numbers in splits, and that sum of splits is less than input size?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added a negative check.

nvshmem_int64_p(output_splits + dst_offset, input_splits[tid], peer);
}
// This barrier ensures that all remote PEs see the updated values
nvshmemx_barrier_all_block();
}

// This kernel is used to do the actual data exchange.
// `in_out_splits` has the same definition as in `exchangeSplitAndOffset`.
// `stride` is the stride at dim 0, unit in byte.
// For meaning of `mype` and `npes`, see the docstring of `nvshmem_all_to_all_vdev_2d`.
__global__ void allToAllV_2d(void *send_data, void *recv_data, int64_t* in_out_splits, size_t stride, int mype, int npes, int ne) {
int nsplits = npes * ne;
auto output_splits = in_out_splits + nsplits;
auto source_offsets = in_out_splits + nsplits * 2;
int bid = blockIdx.x;
int tid = threadIdx.x;

// Calculate the output offsets
__shared__ int64_t e_offsets[THREADS_PER_BLOCK];
prefixSum(e_offsets, output_splits, nsplits);
__syncthreads();

// Target a different e based on bid
for (int eid = bid; eid < nsplits; eid += gridDim.x) {
int peer = eid % npes;
// Amount from `peer` for `e`
auto peer_size = output_splits[eid] * stride;
auto source_offset = source_offsets[eid] * stride;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you need to check that these offsets are within tensor, so there are no OOB reads

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added block_aggregate and check against input size 0.

auto write_offset = e_offsets[eid] * stride;
nvshmemx_getmem_block(
(char*)recv_data + write_offset,
(char*)send_data + source_offset,
peer_size,
peer);
}
// Write out the output offsets (to the scratchpad line)
if (bid == 0 && tid < nsplits) {
source_offsets[tid] = e_offsets[tid];
}
}

at::Tensor nvshmem_all_to_all_vdev_2d(
at::Tensor& input,
at::Tensor& out,
at::Tensor& in_out_splits,
std::string group_name) {
/* Perform a 2D AllToAllv shuffle operation using NVSHMEM, with split information provided on device.
* Arguments:
* - `input` is the input tensor
* - `out` is the output tensor
* - `in_out_splits` is a 2D tensor of size (3, `world_size` * `ne`). In the
scenario of Mixture-of-Experts models, `ne` is the number of experts per
rank. The rows of `in_out_splits` are (in order):
input splits (IN)
output splits (OUT) and
output offsets (OUT).
* - `group_name` is the name of the group to use for the collective operation.

* A 2D AllToAllv shuffle is illustrated below:
(world_size = 2, ne = 2, total number of experts = 4)
Source: | Rank 0 | Rank 1 |
| c0 | c1 | c2 | c3 | d0 | d1 | d2 | d3 |

Dest : | Rank 0 | Rank 1 |
| c0 | d0 | c1 | d1 | c2 | d2 | c3 | d3 |
where each `c_i` / `d_i` are slices of the `input` tensor, targeting
expert `i`, with length indicated by input splits (in
`in_out_splits[0]`). That is, the 2D AllToAllv shuffle achives a
transpose from rank-major order at input to expert-major order at
output.
*/
auto input_hdl = c10d::symmetric_memory::rendezvous(input, group_name);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you also need to check input/output dimensionality and contiguity

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added a contiguity check. I can't think of a dimensionality requirement here (can be 1D, 2D or n-D).

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

also you need to check dtypes - integer for splits, same dtype for input output

auto out_hdl = c10d::symmetric_memory::rendezvous(out, group_name);
auto splits_hdl = c10d::symmetric_memory::rendezvous(in_out_splits, group_name);
int rank = input_hdl->get_rank();
int world_size = input_hdl->get_world_size();

void* input_ptr = input_hdl->get_buffer_ptrs()[rank];
void* output_ptr = out_hdl->get_buffer_ptrs()[rank];
int64_t* splits_ptr = (int64_t*)(splits_hdl->get_buffer_ptrs()[rank]);

auto split_shape = in_out_splits.sizes();
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I still don't see a check for in_out_splits.is_contiguous()

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry added.

TORCH_CHECK(split_shape.size() == 2 && split_shape[0] == 3, "in_out_splits must be 2D with 3 rows");
TORCH_CHECK(split_shape[1] % world_size == 0, "Each row of in_out_splits must be a multiple of world_size");
// Number of experts per rank
int ne = split_shape[1] / world_size;

// Set device context for getting the stream and launching kernels below
c10::cuda B8DB ::CUDAGuard guard(input.device());
auto stream = at::cuda::getCurrentCUDAStream();

// Exchange output splits and source offsets
// Use collective launch because kernel involves nvshmem barrier
void* args0[] = {
&splits_ptr,
&rank,
&world_size,
&ne};
nvshmemx_collective_launch(
(const void*)exchangeSplitAndOffset_2d,
dim3(1),
dim3(THREADS_PER_BLOCK),
args0,
0,
stream);

// CTA Tuning
// Naive for now, use 1 block per expert.
// Total number of blocks is limited to 64 (intra-node) or 8 (inter-node).
int num_blocks = std::min(world_size * ne, world_size > 8 ? 8 : 64);

// Stride at dim 0 (assuming input is contiguous, TODO)
size_t stride_bytes = input.stride(0) * input.element_size();
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

here you are assuming that input.stride(0) == output.stride(0), you should check it

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added


// All to all data exchange
void* args1[] = {
&input_ptr,
&output_ptr,
&splits_ptr,
&stride_bytes,
&rank,
&world_size,
&ne};
nvshmemx_collective_launch(
(const void*)allToAllV_2d,
dim3(num_blocks),
dim3(THREADS_PER_BLOCK),
args1,
0,
stream);
return out;
}

} // namespace c10d::nvshmem_extension


TORCH_LIBRARY_IMPL(symm_mem, CUDA, m) {
m.impl("nvshmem_broadcast", c10d::nvshmem_extension::nvshmem_broadcast);
m.impl("nvshmem_all_to_all", c10d::nvshmem_extension::nvshmem_all_to_all);
m.impl("nvshmem_all_to_all_vdev", c10d::nvshmem_extension::nvshmem_all_to_all_vdev);
m.impl("nvshmem_all_to_all_vdev_2d", c10d::nvshmem_extension::nvshmem_all_to_all_vdev_2d);
}
6 changes: 6 additions & 0 deletions torch/csrc/distributed/c10d/nvshmem_extension.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -28,4 +28,10 @@ at::Tensor nvshmem_all_to_all_vdev(
at::Tensor& in_out_splits,
std::string group_name);

at::Tensor nvshmem_all_to_all_vdev_2d(
at::Tensor& input,
at::Tensor& out,
at::Tensor& in_out_splits,
std::string group_name);

} // namespace c10d::nvshmem_extension
Loading
0