[a2av] Align length of major dimension in output of 2D a2av by kwen2501 · Pull Request #155172 · pytorch/pytorch · GitHub

Conversation

@kwen2501 (Collaborator) commented on Jun 4, 2025

Stack from ghstack (oldest at bottom):

A downstream consumer of the 2D all-to-all-v is often a group GEMM.
Today the GEMM often has an alignment requirement on the chunk sizes within the grouped sequence, where each chunk carries the tokens headed for an expert. For example, torch._group_mm requires an alignment of 8.

This PR adds that alignment capability: when the user passes in a major_align argument, no extra padding step is needed downstream.

The key to supporting this is making the output offsets aligned to that value. (Output offsets are returned to the user in the 3rd row of in_out_splits, on device. The 2nd row, the output splits, is unaffected by this alignment value, i.e., it reflects the true number of tokens for each expert.)
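As a worked illustration (the numbers are made up for this example, not from the PR): suppose three experts receive 5, 0, and 8 tokens and major_align = 8. The output-splits row still reads 5, 0, 8, while the offsets row becomes 0, 8, 16, because each expert's footprint is rounded up to a multiple of 8 (and, per the review discussion below, an empty chunk still occupies one alignment unit). A minimal host-side sketch of the rule, with align_up as a hypothetical helper:

#include <cstdint>
#include <cstdio>

// Hypothetical model of the alignment rule; not the PR's device code.
int64_t align_up(int64_t len, int64_t major_align) {
  if (len == 0) return major_align;  // empty bins still occupy one unit
  return (len + major_align - 1) / major_align * major_align;
}

int main() {
  const int64_t major_align = 8;
  const int64_t splits[3] = {5, 0, 8};  // true token counts (2nd row)
  int64_t offsets[3];                   // aligned offsets (3rd row)
  int64_t off = 0;
  for (int i = 0; i < 3; ++i) {
    offsets[i] = off;
    off += align_up(splits[i], major_align);
  }
  // Prints: 0 8 16
  printf("%lld %lld %lld\n", (long long)offsets[0], (long long)offsets[1],
         (long long)offsets[2]);
}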

The algorithm is as follows.

[Figure: diagram of the alignment algorithm]

In the detailed implementation, we use a warp scan to calculate the prefix sum on the "block" illustrated above. As a result, the "block" size, i.e. npes, is currently limited to the warp size of 32.
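For concreteness, a minimal sketch of that warp-scope scan using cub::WarpScan; the kernel name, 128-thread block, and layout are assumptions for this example, not the PR's actual code:

#include <cub/cub.cuh>

constexpr int WARP_SIZE = 32;
constexpr int NUM_WARPS = 4;  // assumes a 128-thread block, one warp per tile

__global__ void tileScanSketch(const int64_t* splits, int64_t* offsets) {
  using WarpScan = cub::WarpScan<int64_t>;
  // One TempStorage per warp so tiles can scan concurrently.
  __shared__ typename WarpScan::TempStorage temp[NUM_WARPS];
  int warpId = threadIdx.x / WARP_SIZE;  // which tile this thread serves
  int laneId = threadIdx.x % WARP_SIZE;  // which split within the tile
  int64_t my_split = splits[warpId * WARP_SIZE + laneId];
  int64_t my_offset;
  // Exclusive sum within the warp: lane i receives the sum of lanes 0..i-1.
  WarpScan(temp[warpId]).ExclusiveSum(my_split, my_offset);
  offsets[warpId * WARP_SIZE + laneId] = my_offset;
}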

cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k

[ghstack-poisoned]
@pytorch-bot bot commented on Jun 4, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/155172

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (2 Unrelated Failures)

As of commit fd596f4 with merge base fa85434:

FLAKY - The following job failed but was likely due to flakiness present on trunk:

UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot bot added labels on Jun 4, 2025: oncall: distributed (add this issue/PR to distributed oncall triage queue), release notes: distributed (c10d)
[ghstack-poisoned]
@kwen2501 kwen2501 requested a review from ngimel June 5, 2025 00:05
@kwen2501 kwen2501 requested a review from lessw2020 June 5, 2025 00:09
[ghstack-poisoned]
kwen2501 added a commit that referenced this pull request Jun 5, 2025
Tile-based prefix sum

Working

Use cub for prefix sum of tile offset

ghstack-source-id: 96c6c13
Pull-Request-resolved: #155172
prefixSum(e_offsets, output_splits, nsplits);
// Split the thread block into tiles
constexpr int NUM_TILES = THREADS_PER_BLOCK / A2AV_TILE_SIZE;
thread_group tile = cg::tiled_partition(this_thread_block(), A2AV_TILE_SIZE);
ngimel (Collaborator):

tbh, doesn't look like you really need thread_group, you could just compute warpId and laneId and use them?

kwen2501 (Collaborator, Author):

Yep, I was hoping there would be a cg-wise scan. But for now, let me revert it to basic form.

// Starting offset of each tile
__shared__ int64_t start_offset_per_tile[NUM_TILES];
// This tile does not need to do tile-wise prefix sum
if (nsplits_per_tile < 0) goto end_of_tile_prefix_sum;
ngimel (Collaborator):

can you put this in the conditional instead? Also, if nsplits_per_tile is 0, you should also skip the prefix sum?
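A sketch of the suggested restructure, reusing identifiers from the excerpts above (illustrative; the merged code may differ):

// Guard the tile-wise scan with a conditional instead of a goto; tiles with
// zero splits (or past the end of nsplits) skip the scan entirely.
if (nsplits_per_tile > 0) {
  prefixSum_warp(tile_prefix_sums[tileId], output_splits + tileId * npes,
                 nsplits_per_tile);
}
// all tiles then proceed to the block-wide steps that follow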

// Each tile calculates its own prefix sum
__shared__ int64_t tile_prefix_sums[NUM_TILES][A2AV_TILE_SIZE];
// A tile takes care of npes worth of splits
int nsplits_per_tile = min(npes, nsplits - tileId * npes);
ngimel (Collaborator):

You also need ne < NUM_TILES?

kwen2501 (Collaborator, Author):

correct. Let me add it.
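For reference, a one-line sketch of such a guard, mirroring the existing npes assert shown further below (the exact bound and spelling in the merged code may differ):

// Each tile covers npes splits, so the number of split groups must fit
// within the block's tiles.
CUDA_KERNEL_ASSERT(ne <= NUM_TILES);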

// that we can save a __syncthreads(). This is safe because the last thread is
// the one that writes the last sum in the prefixSum function.
if (tile.thread_rank() == A2AV_TILE_SIZE - 1) {
auto my_tile_len = tile_prefix_sums[tileId][A2AV_TILE_SIZE - 1] + output_splits[tileId * npes + nsplits_per_tile - 1];
ngimel (Collaborator):

you should have computed it in prefixSum_warp using an overload with warp_aggregate https://nvidia.github.io/cccl/cub/api/classcub_1_1WarpScan.html#_CPPv4N3cub8WarpScan12ExclusiveSumE1TR1TR1T

kwen2501 (Collaborator, Author):

Nice!
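A sketch of the overload in question: cub::WarpScan::ExclusiveSum can return the warp-wide aggregate alongside each lane's exclusive sum, so the tile length falls out of the scan itself (identifiers here are illustrative):

using WarpScan = cub::WarpScan<int64_t>;
__shared__ typename WarpScan::TempStorage temp_storage;
int64_t my_offset, my_tile_len;
// The third argument receives the sum of all lanes' inputs, i.e. the tile
// length, so no extra read of tile_prefix_sums/output_splits is needed.
WarpScan(temp_storage).ExclusiveSum(my_split, my_offset, my_tile_len);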

if (tile.thread_rank() == A2AV_TILE_SIZE - 1) {
auto my_tile_len = tile_prefix_sums[tileId][A2AV_TILE_SIZE - 1] + output_splits[tileId * npes + nsplits_per_tile - 1];
// Up align
len_per_tile[tileId] = (my_tile_len + major_align) / major_align * major_align;
ngimel (Collaborator):

If my_tile_len was already aligned, this will increment it by an unneeded major_align.

kwen2501 (Collaborator, Author):

Does cutlass support zero length now? If supported, I can use + major_align - 1 here.
I was using + major_align instead of if-else for the ease of documentation.

ngimel (Collaborator):

It still doesn't (they promise to support it in 4.0), but it still makes sense to increment only empty bins and leave non-empty ones as is.
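Putting the two points together, a sketch of the resolution (illustrative, not the merged diff): round up with + major_align - 1 so already-aligned lengths stay put, and bump only empty bins to one alignment unit, since cutlass does not yet accept zero-length groups.

// Up-align without over-padding; empty tiles still get one alignment unit.
auto aligned_len = (my_tile_len == 0)
    ? major_align
    : (my_tile_len + major_align - 1) / major_align * major_align;
len_per_tile[tileId] = aligned_len;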

[ghstack-poisoned]
@kwen2501 (Collaborator, Author) commented on Jun 5, 2025

@ngimel all comments accommodated. Thanks!

[ghstack-poisoned]
// TODO: currently it is assumed that the number of PE's is smaller than
// `A2AV_TILE_SIZE` bc the warp-scope prefix sum can only handle up to
// WARP_SIZE elements
CUDA_KERNEL_ASSERT(npes <= A2AV_TILE_SIZE);
ngimel (Collaborator):

scan can easily do multiple elements per thread, so we can relax this assert, but it's fine for now
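For context, a hedged sketch of that relaxation: give each thread ITEMS splits, warp-scan the per-thread sums, then expand locally, lifting the cap from WARP_SIZE to WARP_SIZE * ITEMS (illustrative; assumes a single warp uses the shared storage):

#include <cub/cub.cuh>

template <int ITEMS>
__device__ void warpScanItems(int64_t (&vals)[ITEMS]) {
  // Serial sum of this thread's own items.
  int64_t local = 0;
  for (int i = 0; i < ITEMS; ++i) local += vals[i];
  // Warp-scan the per-thread sums to get each thread's base offset.
  using WarpScan = cub::WarpScan<int64_t>;
  __shared__ typename WarpScan::TempStorage temp;
  int64_t base;
  WarpScan(temp).ExclusiveSum(local, base);
  // Expand: replace each item with its exclusive prefix.
  for (int i = 0; i < ITEMS; ++i) {
    int64_t v = vals[i];
    vals[i] = base;
    base += v;
  }
}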

// Starting offset of each tile
__shared__ int64_t start_offset_per_tile[NUM_TILES];
// Prefix sum again to get the tiles' start offsets. This is a block-wide prefix sum.
prefixSum(start_offset_per_tile, len_per_tile, NUM_TILES);
ngimel (Collaborator):

only one warp has to do it?

kwen2501 (Collaborator, Author), Jun 5, 2025:

True, our thread block size is typically smaller than 1024, and 1024 / WARP_SIZE = 32 at max.
I am templating prefixSum_warp to

template <int NUM_WARPS>
__device__ int64_t prefixSum_warp(int64_t *odata, int64_t *idata, int n)

so that we can call either

  • prefixSum_warp<NUM_TILES> for concurrent, warp-wise prefix sums, or
  • prefixSum_warp<1> when only 1 warp is needed.
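A hedged sketch of how the templated helper could look under that signature; the body is illustrative, not the merged code, and n is taken here as the per-warp element count:

#include <cub/cub.cuh>

constexpr int WARP_SIZE = 32;

template <int NUM_WARPS>
__device__ int64_t prefixSum_warp(int64_t* odata, int64_t* idata, int n) {
  using WarpScan = cub::WarpScan<int64_t>;
  __shared__ typename WarpScan::TempStorage temp[NUM_WARPS];
  int warpId = threadIdx.x / WARP_SIZE;
  int laneId = threadIdx.x % WARP_SIZE;
  if (warpId >= NUM_WARPS) return 0;  // only NUM_WARPS warps participate
  // Lanes past n contribute zero so short inputs still scan correctly.
  int64_t in = (laneId < n) ? idata[warpId * WARP_SIZE + laneId] : 0;
  int64_t out, warp_total;
  WarpScan(temp[warpId]).ExclusiveSum(in, out, warp_total);
  if (laneId < n) odata[warpId * WARP_SIZE + laneId] = out;
  return warp_total;  // warp aggregate, handy for the alignment step
}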

[ghstack-poisoned]
[ghstack-poisoned]
@kwen2501 kwen2501 added the ciflow/trunk Trigger trunk jobs on your pull request label Jun 5, 2025
[ghstack-poisoned]
[ghstack-poisoned]
@kwen2501 (Collaborator, Author) commented on Jun 6, 2025

@pytorchbot merge

@pytorchmergebot:

Merge failed

Reason: Not merging any PRs at the moment because there is a merge blocking https://github.com/pytorch/pytorch/labels/ci:%20sev issue open at:
#155265


[ghstack-poisoned]
kwen2501 added a commit that referenced this pull request Jun 6, 2025
Tile-based prefix sum

Working

Use cub for prefix sum of tile offset

ghstack-source-id: 9378ffc
Pull-Request-resolved: #155172

Comments

warp_aggregate

Handle zero bin case without waste

Templates prefixSum_warp
@kwen2501 (Collaborator, Author) commented on Jun 6, 2025

@pytorchbot merge

@pytorchmergebot:

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@pytorchmergebot:

Merge failed

Reason: Command git -C /home/runner/work/pytorch/pytorch cherry-pick -x ed9cfdd209d88c2687c06e185ff372f1ec288a77 returned non-zero exit code 1

Auto-merging test/distributed/test_nvshmem.py
CONFLICT (content): Merge conflict in test/distributed/test_nvshmem.py
Auto-merging torch/csrc/distributed/c10d/SymmetricMemory.cpp
CONFLICT (content): Merge conflict in torch/csrc/distributed/c10d/SymmetricMemory.cpp
Auto-merging torch/csrc/distributed/c10d/nvshmem_extension.cu
CONFLICT (content): Merge conflict in torch/csrc/distributed/c10d/nvshmem_extension.cu
Auto-merging torch/csrc/distributed/c10d/nvshmem_extension.cuh
CONFLICT (content): Merge conflict in torch/csrc/distributed/c10d/nvshmem_extension.cuh
error: could not apply ed9cfdd209d... [a2av] Align length of major dimension in output of 2D a2av
hint: After resolving the conflicts, mark them with
hint: "git add/rm <pathspec>", then run
hint: "git cherry-pick --continue".
hint: You can instead skip this commit with "git cherry-pick --skip".
hint: To abort and get back to the state before "git cherry-pick",
hint: run "git cherry-pick --abort".
hint: Disable this message with "git config set advice.mergeConflict false"

@kwen2501 (Collaborator, Author) commented on Jun 6, 2025

@pytorchbot merge

@pytorchmergebot:

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@github-actions github-actions bot deleted the gh/kwen2501/165/head branch July 9, 2025 02:20

Labels

ciflow/trunk (Trigger trunk jobs on your pull request) · Merged · oncall: distributed (Add this issue/PR to distributed oncall triage queue) · release notes: distributed (c10d) (release notes category)
