[a2av] Align length of major dimension in output of 2D a2av by kwen2501 · Pull Request #155172 · pytorch/pytorch · GitHub
Closed
Update
[ghstack-poisoned]
kwen2501 committed Jun 4, 2025
commit ecfeaf4cb2eb9396a6fee9e3b39b77d9916746a3
12 changes: 6 additions & 6 deletions torch/csrc/distributed/c10d/nvshmem_extension.cu
@@ -466,7 +466,7 @@ at::Tensor nvshmem_all_to_all_vdev_2d(
output offsets (OUT).
* - `group_name` is the name of the group to use for the collective operation.
* - `major_align` is the alignment of the "major dimension" of the output
- tensor. See below for details.
+ sequence. See below for details.

* A 2D AllToAllv shuffle is illustrated below:
(world_size = 2, ne = 2, total number of experts = 4)
@@ -481,11 +481,11 @@ at::Tensor nvshmem_all_to_all_vdev_2d(
transpose from rank-major order at input to expert-major order at
output.

- * If `major_align` is not 1, the output offset of c1 will be up-aligned to
- this value. For example, if c0 has length 5 and d0 has length 7 (making a
- total of 12), if and the `major_align` is 16, the output offset of c1 will
- be 16. Similar for c2 and c3. This value has no effect on the offset of
- the minor dimensions, i.e. d0, d1, d2 and d3.
+ * If `major_align` is not 1, the output offsets of c1, c2, c3 will be
+ up-aligned to this value. For example, if c0 has length 5 and d0 has
+ length 7 (making a total of 12), and if the `major_align` is set to 16,
+ the output offset of c1 will be 16. Similar for c2 and c3. This value has
+ no effect on the offset of the minor dimension, i.e. d0, d1, d2 and d3.
*/
auto input_hdl = c10d::symmetric_memory::rendezvous(input, group_name);
auto out_hdl = c10d::symmetric_memory::rendezvous(out, group_name);
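
The alignment rule described in the comment above can be sketched on the host side as follows. This is only an illustration of the offset arithmetic, not the kernel code in `nvshmem_extension.cu`: the helper names `round_up` and `aligned_output_offsets` are hypothetical, the chunk lengths are passed in expert-major order (one inner vector per c/d group), and the lengths used for c1 and d1 in the usage note are assumed values.

```cpp
#include <cstdint>
#include <vector>

// Round x up to the nearest multiple of align (align must be >= 1).
inline int64_t round_up(int64_t x, int64_t align) {
  return (x + align - 1) / align * align;
}

// Hypothetical helper (not a PyTorch API): given chunk lengths in
// expert-major order, e.g. {{len(c0), len(d0)}, {len(c1), len(d1)}},
// return the output offset of every chunk. Only the first chunk of each
// group (the "major" boundary) is up-aligned to major_align; the chunks
// inside a group (the "minor" dimension) are packed back to back.
std::vector<std::vector<int64_t>> aligned_output_offsets(
    const std::vector<std::vector<int64_t>>& lens, int64_t major_align) {
  std::vector<std::vector<int64_t>> offsets;
  int64_t cursor = 0;
  for (const auto& group : lens) {
    // Align the start of the group, i.e. the offset of c0, c1, c2, ...
    cursor = round_up(cursor, major_align);
    std::vector<int64_t> row;
    for (int64_t len : group) {
      row.push_back(cursor);  // minor chunks are not aligned
      cursor += len;
    }
    offsets.push_back(row);
  }
  return offsets;
}
```

With c0 of length 5, d0 of length 7, and `major_align` of 16, calling `aligned_output_offsets({{5, 7}, {3, 4}}, 16)` places c1 at offset 16 rather than 12, while d0 stays at offset 5 and d1 directly follows c1. With `major_align` of 1 the layout is fully packed.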