[a2av] 2D all-to-all-vdev by kwen2501 · Pull Request #155058 · pytorch/pytorch · GitHub

Conversation
@kwen2501 (Collaborator) commented Jun 3, 2025

Stack from ghstack (oldest at bottom):

A 2D AllToAllv shuffle is illustrated below:
(world_size = 2, ne = 2, where ne is the number of experts per rank)

        Source: |       Rank 0      |       Rank 1      |
                | c0 | c1 | c2 | c3 | d0 | d1 | d2 | d3 |

        Dest  : |       Rank 0      |       Rank 1      |
                | c0 | d0 | c1 | d1 | c2 | d2 | c3 | d3 |

where each c_i / d_i is a slice of the input tensor targeting expert i, with its length indicated by the input splits (in in_out_splits[0]).

That is, the 2D AllToAllv shuffle achieves a transpose from rank-major order at input to expert-major order at output.
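
For illustration, a minimal standalone sketch (not part of this PR) that enumerates the same chunk permutation:

// Sketch: print the rank-major -> expert-major permutation for the example
// above (world_size = 2, ne = 2, hence world_size * ne = 4 experts in total).
// A chunk sent by rank r for expert e sits at rank-major slot
// r * (world_size * ne) + e in the concatenated input and lands at
// expert-major slot e * world_size + r in the concatenated output.
#include <cstdio>

int main() {
  const int world_size = 2, ne = 2;
  const int num_experts = world_size * ne;
  for (int r = 0; r < world_size; ++r) {
    for (int e = 0; e < num_experts; ++e) {
      int in_slot = r * num_experts + e;   // rank-major (source) order
      int out_slot = e * world_size + r;   // expert-major (destination) order
      printf("rank %d, expert %d: input slot %d -> output slot %d\n",
             r, e, in_slot, out_slot);
    }
  }
  return 0;
}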

cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k

[ghstack-poisoned]
@pytorch-bot bot commented Jun 3, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/155058

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (2 Unrelated Failures)

As of commit d849727 with merge base fa85434:

FLAKY - The following jobs failed but were likely due to flakiness present on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

kwen2501 added a commit that referenced this pull request Jun 3, 2025
ghstack-source-id: d0c5cfa
Pull-Request-resolved: #155058
@pytorch-bot pytorch-bot bot added the oncall: distributed and release notes: distributed (c10d) labels Jun 3, 2025
@kwen2501 kwen2501 requested review from fduwjj, fegin and ngimel June 3, 2025 19:01
[ghstack-poisoned]
int64_t* splits_ptr = (int64_t*)(splits_hdl->get_buffer_ptrs()[rank]);

// Number of experts per rank
int ne = in_out_splits.stride(0) / world_size;
Collaborator:

you need to check that in_out_splits is contiguous, and check its sizes, not strides

@kwen2501 (Collaborator, Author) Jun 3, 2025:

in_out_splits is expected to be of shape [3, world_size * ne].
I will add a check. Thanks!

Collaborator:

Right, but if the tensor is not contiguous, its stride may satisfy this check while its size won't.

int ne = in_out_splits.stride(0) / world_size;
TORCH_CHECK(ne * world_size == in_out_splits.stride(0), "The row length of in_out_splits must be a multiple of world_size");
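
For reference, a minimal sketch of the checks being discussed (the message text and the final form in the PR may differ):

// Require a contiguous [3, world_size * ne] splits tensor and derive ne
// from its size rather than its stride.
TORCH_CHECK(in_out_splits.is_contiguous(), "in_out_splits must be contiguous");
TORCH_CHECK(in_out_splits.dim() == 2 && in_out_splits.size(0) == 3,
            "in_out_splits must have shape [3, world_size * ne]");
TORCH_CHECK(in_out_splits.size(1) % world_size == 0,
            "The row length of in_out_splits must be a multiple of world_size");
int ne = in_out_splits.size(1) / world_size;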

auto stream = at::cuda::getCurrentCUDAStream(input.device().index());
Collaborator:

you need to check that the current device matches the input's device (unless device guard code is generated automatically; I don't know how your function is registered), and then just use getCurrentCUDAStream() here, because you'd get the stream for the current device

Collaborator (Author):

I'd prefer that the API not require the user to set a device guard upfront.
So I will do something like the following here:

at::cuda::OptionalCUDAGuard gpuGuard(input.device());
auto stream = at::cuda::getCurrentCUDAStream();
nvshmemx_collective_launch(..., stream);

[ghstack-poisoned]
kwen2501 added a commit that referenced this pull request Jun 3, 2025
ghstack-source-id: 5d2a22a
Pull-Request-resolved: #155058

Add device guard
@kwen2501 (Collaborator, Author) commented Jun 3, 2025

@ngimel Updated the PR with the changes mentioned above :)

@kwen2501 kwen2501 added the ciflow/trunk label Jun 5, 2025
void* output_ptr = out_hdl->get_buffer_ptrs()[rank];
int64_t* splits_ptr = (int64_t*)(splits_hdl->get_buffer_ptrs()[rank]);

auto split_shape = in_out_splits.sizes();
Collaborator:

I still don't see a check for in_out_splits.is_contiguous()

Collaborator (Author):

Sorry, added.

transpose from rank-major order at input to expert-major order at
output.
*/
auto input_hdl = c10d::symmetric_memory::rendezvous(input, group_name);
Collaborator:

you also need to check input/output dimensionality and contiguity

Collaborator (Author):

Added a contiguity check. I can't think of a dimensionality requirement here (can be 1D, 2D or n-D).

Collaborator:

also you need to check dtypes: integer for splits, and the same dtype for input and output
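
For reference, a minimal sketch of such dtype checks (the output tensor name `out` is an assumption; the PR's final code may differ):

// Splits are read as int64_t in the kernel, so require int64 here;
// input and output must share a dtype.
TORCH_CHECK(in_out_splits.scalar_type() == at::kLong, "in_out_splits must be int64");
TORCH_CHECK(input.scalar_type() == out.scalar_type(),
            "input and output must have the same dtype");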

int e = tid % ne;
// This does a transpose from rank-major order to expert-major order
int dst_offset = e * npes + mype;
nvshmem_int64_p(source_offsets + dst_offset, peer_offsets[tid], peer);
Collaborator:

it would also make sense here to check that there are no negative numbers in splits, and that the sum of splits does not exceed the input size?

Collaborator (Author):

Added a negative check.
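
For reference, a device-side sketch of such a guard (the name `input_splits` is illustrative; the PR's exact form may differ):

// Reject negative split values before using them as copy lengths.
CUDA_KERNEL_ASSERT(input_splits[tid] >= 0 && "splits must be non-negative");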

[ghstack-poisoned]
int num_blocks = std::min(world_size * ne, world_size > 8 ? 8 : 64);

// Stride at dim 0 (assuming input is contiguous, TODO)
size_t stride_bytes = input.stride(0) * input.element_size();
Collaborator:

here you are assuming that input.stride(0) == output.stride(0); you should check it

Collaborator (Author):

Added
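
For reference, a one-line sketch of that check (the output tensor name `out` is an assumption):

TORCH_CHECK(input.stride(0) == out.stride(0),
            "input and output must have the same stride at dim 0");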

int peer = eid % npes;
// Amount from `peer` for `e`
auto peer_size = output_splits[eid] * stride;
auto source_offset = source_offsets[eid] * stride;
Collaborator:

you need to check that these offsets are within the tensor, so there are no OOB reads

Collaborator (Author):

Added block_aggregate and a check against the input's size(0).
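
For reference, a sketch of the idea (the names `input_splits` and `input_dim0_size` are illustrative):

// Compute per-peer read offsets, get the block-wide total back from the
// prefix sum, and guard against out-of-bounds reads.
int64_t total = prefixSum(source_offsets, input_splits, npes * ne);
CUDA_KERNEL_ASSERT(total <= input_dim0_size && "splits exceed input size");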

[ghstack-poisoned]

// This is an exclusive prefix sum function that calculates read (or write) offsets for each peer.
__device__ void prefixSum(int64_t *odata, int64_t *idata, int n) {
__device__ int64_t prefixSum(int64_t *odata, int64_t *idata, int n) {
Collaborator:

btw, previously I've seen an int64 scan slow down a kernel big time; you might want to check the performance
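
For reference, a minimal warp-level sketch of an exclusive prefix sum that also returns the aggregate (assumes it is called by a full, converged warp with n <= 32; not the PR's implementation):

__device__ int64_t warpExclusiveScan(int64_t* odata, const int64_t* idata, int n) {
  int lane = threadIdx.x % 32;
  int64_t val = (lane < n) ? idata[lane] : 0;
  int64_t sum = val;
  // Inclusive scan across the warp via shuffles.
  for (int offset = 1; offset < 32; offset <<= 1) {
    int64_t up = __shfl_up_sync(0xffffffff, sum, offset);
    if (lane >= offset) sum += up;
  }
  if (lane < n) odata[lane] = sum - val;  // convert to exclusive
  // Broadcast the total (the inclusive sum at the last active lane).
  return __shfl_sync(0xffffffff, sum, n - 1);
}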

@kwen2501 (Collaborator, Author) commented Jun 6, 2025

@pytorchbot merge

@pytorchmergebot:

Merge failed

Reason: Not merging any PRs at the moment because there is a merge blocking https://github.com/pytorch/pytorch/labels/ci:%20sev issue open at:
#155265


@kwen2501 (Collaborator, Author) commented Jun 6, 2025

@pytorchbot merge

@pytorchmergebot:

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status.

@pytorchmergebot:

Starting merge as part of PR stack under #155172

pytorchmergebot pushed a commit that referenced this pull request Jun 6, 2025
The downstream consumer of the 2D all-to-all-v is often a grouped GEMM.
Today such GEMMs often have an alignment requirement on the chunk sizes within the grouped sequence, where each chunk carries the tokens headed for an expert. For example, `torch._group_mm` requires an alignment of 8.

This PR adds that alignment capability when the user passes in a `major_align` argument, so that no extra padding step is needed.

The key to supporting that is making the output offsets aligned to this value. (Output offsets are returned to the user in the 3rd row of `in_out_splits`, on device. The 2nd row, the output splits, is unaffected by this alignment value, i.e. it reflects the true number of tokens per expert.)

The algorithm is as follows.

![502413288_678786854922438_530852083153996358_n](https://github.com/user-attachments/assets/557624a3-150e-4ab6-ba8b-1dbaa5ac01ac)

In the detailed implementation, we use a warp scan to calculate the prefix sum over the "block" illustrated above. As a result, the "block" size, i.e. `npes`, is currently limited to the warp size of 32.
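
For reference, the alignment step amounts to rounding each output write offset up to a multiple of `major_align` (a sketch, not the PR's exact code):

// Round an offset up to the next multiple of `align` (align > 0);
// the split values themselves are left untouched.
__host__ __device__ inline int64_t align_up(int64_t x, int64_t align) {
  return (x + align - 1) / align * align;
}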

Pull Request resolved: #155172
Approved by: https://github.com/ngimel
ghstack dependencies: #153653, #153677, #155058
@github-actions github-actions bot deleted the gh/kwen2501/164/head branch July 9, 2025 02:19

Labels

ciflow/trunk · Merged · oncall: distributed · release notes: distributed (c10d)
