[a2av] 2D all-to-all-vdev #155058
Conversation
```cpp
int64_t* splits_ptr = (int64_t*)(splits_hdl->get_buffer_ptrs()[rank]);

// Number of experts per rank
int ne = in_out_splits.stride(0) / world_size;
```
Reviewer: You need to check that `in_out_splits` is contiguous, and check its sizes, not its strides.
Author: `in_out_splits` is expected to be of shape `[3, world_size * ne]`. I will add a check. Thanks!
Reviewer: Right, but if the tensor is not contiguous, its stride may still satisfy this check while its size won't.
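For illustration, a minimal sketch of the kind of validation being requested here (not the PR's final code; the helper name is made up, and it assumes the `[3, world_size * ne]` layout described above):

```cpp
#include <ATen/ATen.h>
#include <c10/util/Exception.h>

// Hedged sketch: validate in_out_splits before deriving `ne` from it.
inline int check_splits_and_get_ne(const at::Tensor& in_out_splits, int world_size) {
  TORCH_CHECK(in_out_splits.is_contiguous(), "in_out_splits must be contiguous");
  TORCH_CHECK(in_out_splits.dim() == 2 && in_out_splits.size(0) == 3,
              "in_out_splits must have shape [3, world_size * ne]");
  TORCH_CHECK(in_out_splits.size(1) % world_size == 0,
              "row length of in_out_splits must be a multiple of world_size");
  // Derive ne from size(1), not stride(0): a non-contiguous tensor could
  // satisfy a stride-based check while its actual row length does not.
  return in_out_splits.size(1) / world_size;
}
```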
```cpp
int ne = in_out_splits.stride(0) / world_size;
TORCH_CHECK(ne * world_size == in_out_splits.stride(0),
            "Each row of in_out_splits must be a multiple of world_size");

auto stream = at::cuda::getCurrentCUDAStream(input.device().index());
```
Reviewer: You need to check that the current device matches the input's device (unless device-guard code is generated automatically; I don't know how your function is registered), and then just use `getCurrentCUDAStream()` here, since you'd get the stream for the current device.
Author: I'd prefer that the API not require the user to set a device guard upfront. So I will do something like the following here:

```cpp
at::cuda::OptionalCUDAGuard gpuGuard(input.device());
auto stream = at::cuda::getCurrentCUDAStream();
nvshmemx_collective_launch(..., stream);
```
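Fleshed out, that pattern looks roughly like the sketch below. The kernel name, launch dimensions, and arguments are placeholders, and the launch call follows NVSHMEM's `nvshmemx_collective_launch` signature as I understand it:

```cpp
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <nvshmemx.h>

// Hedged sketch of the guard + stream + collective-launch pattern above.
void launch_on_inputs_device(const at::Tensor& input, const void* exampleKernel) {
  // Switch to the input's device for this scope; caller needs no guard.
  at::cuda::OptionalCUDAGuard gpuGuard(input.device());
  // With the guard active, this returns the stream for the input's device.
  auto stream = at::cuda::getCurrentCUDAStream();
  void* args[] = {nullptr};  // placeholder kernel arguments
  nvshmemx_collective_launch(
      exampleKernel, dim3(1), dim3(32), args, /*sharedMem=*/0, stream.stream());
}
```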
@ngimel Updated the PR with the changes mentioned above :)
```cpp
void* output_ptr = out_hdl->get_buffer_ptrs()[rank];
int64_t* splits_ptr = (int64_t*)(splits_hdl->get_buffer_ptrs()[rank]);

auto split_shape = in_out_splits.sizes();
```
Reviewer: I still don't see a check for `in_out_splits.is_contiguous()`.
Author: Sorry, added.
```cpp
   transpose from rank-major order at input to expert-major order at
   output.
*/
auto input_hdl = c10d::symmetric_memory::rendezvous(input, group_name);
```
Reviewer: You also need to check input/output dimensionality and contiguity.
Author: Added a contiguity check. I can't think of a dimensionality requirement here (it can be 1D, 2D, or n-D).
Reviewer: You also need to check dtypes: integer for the splits, and the same dtype for input and output.
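A sketch of the dtype checks being asked for (illustrative only; `at::kLong` matches the `int64_t*` splits buffer shown earlier, but the exact set of accepted dtypes is the PR's call):

```cpp
#include <ATen/ATen.h>

// Hedged sketch: dtype validation for the splits and input/output tensors.
inline void check_dtypes(const at::Tensor& input, const at::Tensor& output,
                         const at::Tensor& in_out_splits) {
  TORCH_CHECK(in_out_splits.scalar_type() == at::kLong,
              "in_out_splits must be int64");
  TORCH_CHECK(input.scalar_type() == output.scalar_type(),
              "input and output must have the same dtype");
}
```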
```cpp
int e = tid % ne;
// This does a transpose from rank-major order to expert-major order
int dst_offset = e * npes + mype;
nvshmem_int64_p(source_offsets + dst_offset, peer_offsets[tid], peer);
```
Reviewer: It would also make sense here to check that there are no negative numbers in the splits, and that the sum of the splits is no larger than the input size.
Author: Added a negative check.
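Since the splits live in device memory, one way to validate them without a host sync is inside a kernel, using PyTorch's device-side assert macro. A sketch under that assumption (kernel and parameter names are illustrative):

```cuda
#include <cstdint>
#include <c10/macros/Macros.h>

// Hedged sketch: device-side validation of the splits as they are read,
// avoiding a device-to-host copy and sync.
__global__ void checkSplitsNonNegative(const int64_t* splits, int nsplits) {
  for (int i = threadIdx.x; i < nsplits; i += blockDim.x) {
    // CUDA_KERNEL_ASSERT traps in-kernel if the condition fails.
    CUDA_KERNEL_ASSERT(splits[i] >= 0 && "splits must be non-negative");
  }
}
```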
```cpp
int num_blocks = std::min(world_size * ne, world_size > 8 ? 8 : 64);

// Stride at dim 0 (assuming input is contiguous, TODO)
size_t stride_bytes = input.stride(0) * input.element_size();
```
Reviewer: Here you are assuming that `input.stride(0) == output.stride(0)`; you should check it.
Author: Added.
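A sketch of that check (helper name is illustrative): since the kernel moves whole rows using a single stride value, input and output must agree on their dim-0 stride.

```cpp
#include <ATen/ATen.h>

// Hedged sketch: validate the shared row stride before computing byte strides.
inline size_t checked_stride_bytes(const at::Tensor& input, const at::Tensor& output) {
  TORCH_CHECK(input.stride(0) == output.stride(0),
              "input and output must have the same stride at dim 0");
  return input.stride(0) * input.element_size();
}
```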
```cpp
int peer = eid % npes;
// Amount from `peer` for `e`
auto peer_size = output_splits[eid] * stride;
auto source_offset = source_offsets[eid] * stride;
```
Reviewer: You need to check that these offsets are within the tensor, so there are no OOB reads.
Author: Added `block_aggregate` and a check against the input's size at dim 0.
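A sketch of the kind of bound check implied here, assuming `block_aggregate` is the block-wide total of the per-peer sizes produced by the prefix sum (all names illustrative, not the PR's exact code):

```cuda
#include <cstdint>
#include <c10/macros/Macros.h>

// Hedged sketch: before copying, assert the last element each peer will read
// stays inside the input. Offsets/sizes here are in elements, scaled by the
// row stride as in the excerpt above.
__device__ inline void checkWithinInput(int64_t source_offset, int64_t peer_size,
                                        int64_t input_dim0, int64_t stride) {
  CUDA_KERNEL_ASSERT(source_offset + peer_size <= input_dim0 * stride &&
                     "split offsets would read out of bounds");
}
```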
```diff
  // This is an exclusive prefix sum function that calculates read (or write) offsets for each peer.
- __device__ void prefixSum(int64_t *odata, int64_t *idata, int n) {
+ __device__ int64_t prefixSum(int64_t *odata, int64_t *idata, int n) {
```
Reviewer: By the way, I've previously seen an int64 scan slow a kernel down big time; you might want to check the performance.
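For reference, a self-contained sketch of the general technique under discussion, a warp-level exclusive prefix sum over `int64_t` using shuffle intrinsics (not the PR's kernel). One reason the int64 version can be noticeably slower than int32 is that 64-bit shuffles are executed as two 32-bit shuffles:

```cuda
#include <cstdint>

// Hedged sketch: exclusive prefix sum across one warp (32 lanes).
// Each lane contributes `val`; returns the sum over all lower-indexed lanes.
__device__ int64_t warpExclusiveScan(int64_t val) {
  const unsigned kFullMask = 0xffffffffu;
  const int lane = threadIdx.x & 31;
  int64_t sum = val;
  // Hillis-Steele inclusive scan: each step adds the value `offset` lanes below.
  for (int offset = 1; offset < 32; offset <<= 1) {
    int64_t n = __shfl_up_sync(kFullMask, sum, offset);
    if (lane >= offset) sum += n;
  }
  // Convert inclusive to exclusive by subtracting this lane's own input.
  return sum - val;
}
```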
@pytorchbot merge
Merge failed. Reason: Not merging any PRs at the moment because there is a merge-blocking https://github.com/pytorch/pytorch/labels/ci:%20sev issue open.
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Starting merge as part of PR stack under #155172
The downstream consumer of the 2D all-to-all-v is often a grouped GEMM. Today the GEMM often has an alignment requirement on the chunk sizes within the grouped sequence, where each chunk carries the tokens headed for one expert. For example, `torch._group_mm` requires an alignment of 8. This PR adds that alignment capability: when the user passes in a `major_align` argument, no extra padding step is needed. The key to supporting this is making the output offsets aligned to that value. (Output offsets are returned to the user in the 3rd row of `in_out_splits`, on device. The 2nd row, the output splits, is unaffected by this alignment value, i.e. it reflects the true number of tokens per expert.) The algorithm is as follows.

In the detailed implementation, we use a warp scan to calculate the prefix sum on the "block" illustrated above. As a result, the "block" size, i.e. `npes`, is currently limited to the warp size of 32.

Pull Request resolved: #155172
Approved by: https://github.com/ngimel
ghstack dependencies: #153653, #153677, #155058
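In isolation, the core of the alignment step is just rounding each output offset up to a multiple of `major_align`; a sketch (helper name is made up):

```cpp
#include <cstdint>

// Hedged sketch: round an output offset up to a multiple of `align`,
// as described above for the `major_align` argument.
inline int64_t alignUp(int64_t x, int64_t align) {
  return (x + align - 1) / align * align;
}
// e.g. alignUp(13, 8) == 16 and alignUp(16, 8) == 16: the next expert's chunk
// starts 8-aligned, and the gap between the true split (13) and the aligned
// offset (16) is padding.
```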
Stack from ghstack (oldest at bottom):
A 2D AllToAllv shuffle is illustrated below (`world_size` = 2, `ne` = 2, where `ne` is the number of experts per rank), where each `c_i`/`d_i` is a slice of the `input` tensor targeting expert `i`, with length indicated by the input splits (in `in_out_splits[0]`). That is, the 2D AllToAllv shuffle achieves a transpose from rank-major order at input to expert-major order at output.
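A concrete illustration of that index mapping, using the `dst_offset = e * npes + mype` formula from the kernel excerpt above (a sketch, not PR code):

```cpp
#include <cstdio>

// Hedged sketch: enumerate the rank-major -> expert-major chunk mapping
// for world_size (npes) = 2 and ne = 2 experts per rank.
int main() {
  const int npes = 2, ne = 2;
  for (int mype = 0; mype < npes; ++mype) {  // source rank
    for (int e = 0; e < ne; ++e) {           // expert index
      int src_slot = mype * ne + e;          // rank-major position at input
      int dst_slot = e * npes + mype;        // expert-major position at output
      std::printf("rank %d, expert %d: slot %d -> slot %d\n",
                  mype, e, src_slot, dst_slot);
    }
  }
  return 0;
}
```

With these sizes, slots 1 and 2 swap while 0 and 3 stay put, which is exactly the transpose of the 2x2 (rank, expert) index grid.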
cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k