[a2av] 2D all-to-all-vdev by kwen2501 · Pull Request #155058 · pytorch/pytorch · GitHub
Closed
Update
[ghstack-poisoned]
kwen2501 committed Jun 3, 2025
commit dc5013783cd096d53c1852f2c276bf61e2302d95
23 changes: 14 additions & 9 deletions torch/csrc/distributed/c10d/nvshmem_extension.cu
@@ -389,20 +389,25 @@ at::Tensor nvshmem_all_to_all_vdev_2d(
   * - `input` is the input tensor
   * - `out` is the output tensor
   * - `in_out_splits` is a 2D tensor of size (3, `world_size` * `ne`). In the
-  *   scenario of Mixture-of-Experts models, `ne` is the number of experts per
-  *   rank. The rows are (in order):
+      scenario of Mixture-of-Experts models, `ne` is the number of experts per
+      rank. The rows of `in_out_splits` are (in order):
       input splits (IN)
       output splits (OUT) and
       output offsets (OUT).
   * - `group_name` is the name of the group to use for the collective operation.
+
   * A 2D AllToAllv shuffle is illustrated below:
-      (world_size = 2, ne = 2)
-      Source: |         Rank 0        |         Rank 1        |
-              | e00 | e01 | e02 | e03 | e10 | e11 | e12 | e13 |
-      Dest:   |         Rank 0        |         Rank 1        |
-              | e00 | e10 | e01 | e11 | e02 | e12 | e03 | e13 |
-      where each eij is a slice of the `input` tensor, with length indicated by input splits (`in_out_splits[0]`) at index `i * ne + j`.
-      That is, the 2D AllToAllv shuffle achieves a transpose from rank-major order at input to expert-major order at output.
+      (world_size = 2, ne = 2, total number of experts = 4)
+      Source: |       Rank 0      |       Rank 1      |
+              | c0 | c1 | c2 | c3 | d0 | d1 | d2 | d3 |
+
+      Dest  : |       Rank 0      |       Rank 1      |
+              | c0 | d0 | c1 | d1 | c2 | d2 | c3 | d3 |
+      where each `c_i` / `d_i` is a slice of the `input` tensor, targeting
+      expert `i`, with length indicated by input splits (in
+      `in_out_splits[0]`). That is, the 2D AllToAllv shuffle achieves a
+      transpose from rank-major order at input to expert-major order at
+      output.
   */
   auto input_hdl = c10d::symmetric_memory::rendezvous(input, group_name);
Collaborator

you also need to check input/output dimensionality and contiguity

Collaborator Author

Added a contiguity check. I can't think of a dimensionality requirement here (can be 1D, 2D or n-D).

Collaborator

also you need to check dtypes - integer for splits, same dtype for input output
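
The checks requested in this thread (contiguity, an integer dtype for the splits, matching dtypes for input and output) could look roughly like the sketch below. It is plain C++ so it can be compiled on its own: `DType`, `TensorMeta` and `validate_a2av_2d_args` are hypothetical stand-ins, not PyTorch types, and the 3-row shape check is taken from the doc comment in the diff. In the actual extension the same conditions would presumably be `TORCH_CHECK` calls on the `at::Tensor` arguments.

```cpp
#include <cstdint>
#include <string>
#include <vector>

// Hypothetical stand-ins for the tensor properties discussed in the review.
enum class DType { Float32, BFloat16, Int32, Int64 };

struct TensorMeta {
  DType dtype;
  bool contiguous;
  std::vector<int64_t> sizes;
};

// Returns an empty string when all checks pass, otherwise a message
// describing the first failed check.
std::string validate_a2av_2d_args(
    const TensorMeta& input,
    const TensorMeta& out,
    const TensorMeta& in_out_splits) {
  // Contiguity: slices are addressed by flat offsets.
  if (!input.contiguous || !out.contiguous) {
    return "input and out must be contiguous";
  }
  // Input and output must share a dtype.
  if (input.dtype != out.dtype) {
    return "input and out must have the same dtype";
  }
  // Splits and offsets are counts, so they must be integers.
  if (in_out_splits.dtype != DType::Int32 &&
      in_out_splits.dtype != DType::Int64) {
    return "in_out_splits must have an integer dtype";
  }
  // Per the doc comment: a 2D tensor with 3 rows.
  if (in_out_splits.sizes.size() != 2 || in_out_splits.sizes[0] != 3) {
    return "in_out_splits must be 2D with 3 rows";
  }
  return "";
}
```

No dimensionality check is applied to `input` or `out`, matching the author's note that they can be 1D, 2D or n-D.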

auto out_hdl = c10d::symmetric_memory::rendezvous(out, group_name);
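
For readers following the doc comment, the rank-major to expert-major transpose can be sketched as a host-side reference. This is only an illustration of the data layout as read from the diagram in the diff, not the NVSHMEM kernel. The function name `all_to_all_v_2d_reference` and its per-rank vector arguments are assumptions made for the example, and it does not fill in the output splits or output offsets rows.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Host-side reference of the 2D AllToAllv layout (illustrative only).
// inputs[s] : flat data held by source rank s, chunks ordered by global expert.
// splits[s] : chunk lengths on source rank s, one per global expert
//             (world_size * ne entries, i.e. row 0 of in_out_splits).
// Returns outputs[r]: data on destination rank r, ordered by local expert
// first and by source rank second.
std::vector<std::vector<int>> all_to_all_v_2d_reference(
    const std::vector<std::vector<int>>& inputs,
    const std::vector<std::vector<int64_t>>& splits,
    int ne) {
  const int world_size = static_cast<int>(inputs.size());
  const int num_experts = world_size * ne;
  assert(static_cast<int>(splits.size()) == world_size);

  // Exclusive prefix sum of each rank's splits gives the chunk start offsets.
  std::vector<std::vector<int64_t>> offsets(world_size);
  for (int s = 0; s < world_size; ++s) {
    assert(static_cast<int>(splits[s].size()) == num_experts);
    offsets[s].assign(num_experts, 0);
    for (int g = 1; g < num_experts; ++g) {
      offsets[s][g] = offsets[s][g - 1] + splits[s][g - 1];
    }
  }

  std::vector<std::vector<int>> outputs(world_size);
  for (int r = 0; r < world_size; ++r) {      // destination rank
    for (int e = 0; e < ne; ++e) {            // local expert on rank r
      const int g = r * ne + e;               // global expert id
      for (int s = 0; s < world_size; ++s) {  // source rank
        const auto begin = inputs[s].begin() + offsets[s][g];
        outputs[r].insert(outputs[r].end(), begin, begin + splits[s][g]);
      }
    }
  }
  return outputs;
}
```

With `world_size = 2`, `ne = 2` and unit-length chunks, rank 0 holding `{c0, c1, c2, c3}` and rank 1 holding `{d0, d1, d2, d3}` yields `{c0, d0, c1, d1}` on rank 0 and `{c2, d2, c3, d3}` on rank 1, matching the diagram.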