[a2av] Align length of major dimension in output of 2D a2av by kwen2501 · Pull Request #155172 · pytorch/pytorch · GitHub
Closed
Update
[ghstack-poisoned]
kwen2501 committed Jun 5, 2025
commit ba6b78db3dc8531bc9495c695aeed0280f4b6644
3 changes: 3 additions & 0 deletions torch/csrc/distributed/c10d/nvshmem_extension.cu
@@ -518,8 +518,11 @@ at::Tensor nvshmem_all_to_all_vdev_2d(
auto split_shape = in_out_splits.sizes();
TORCH_CHECK(split_shape.size() == 2 && split_shape[0] == 3, "in_out_splits must be 2D with 3 rows");
TORCH_CHECK(split_shape[1] % world_size == 0, "The row length of in_out_splits must be a multiple of world_size");

// Number of experts per rank
int ne = split_shape[1] / world_size;
constexpr int NUM_TILES = THREADS_PER_BLOCK / A2AV_TILE_SIZE;
TORCH_CHECK(ne <= NUM_TILES, "Number of experts per rank must not exceed NUM_TILES (", NUM_TILES, ")");

// Set device context for getting the stream and launching kernels below
c10::cuda::CUDAGuard guard(input.device());
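
The shape checks in this hunk can be exercised on their own, without CUDA or NVSHMEM. The following is a minimal host-side sketch, not the code from the PR: the values of `THREADS_PER_BLOCK` and `A2AV_TILE_SIZE` are assumptions (the real definitions are outside this hunk), the helper names `experts_per_rank`, `is_valid` and `example` are hypothetical, and `std::invalid_argument` stands in for `TORCH_CHECK`.

```cpp
#include <cassert>
#include <cstdint>
#include <stdexcept>
#include <vector>

// ASSUMED values for illustration only; the real constants are defined
// elsewhere in nvshmem_extension.cu and are not shown in this hunk.
constexpr int THREADS_PER_BLOCK = 512;
constexpr int A2AV_TILE_SIZE = 32;
constexpr int NUM_TILES = THREADS_PER_BLOCK / A2AV_TILE_SIZE;  // 16 here

// Hypothetical helper mirroring the three checks in the hunk.
// Returns the number of experts per rank, or throws if a check fails.
int experts_per_rank(const std::vector<int64_t>& split_shape, int world_size) {
  if (world_size <= 0) {  // extra guard, not part of the original checks
    throw std::invalid_argument("world_size must be positive");
  }
  if (split_shape.size() != 2 || split_shape[0] != 3) {
    throw std::invalid_argument("in_out_splits must be 2D with 3 rows");
  }
  if (split_shape[1] % world_size != 0) {
    throw std::invalid_argument("row length must be a multiple of world_size");
  }
  const int ne = static_cast<int>(split_shape[1] / world_size);
  if (ne > NUM_TILES) {  // same condition as TORCH_CHECK(ne <= NUM_TILES, ...)
    throw std::invalid_argument("experts per rank must not exceed NUM_TILES");
  }
  return ne;
}

// Hypothetical convenience wrapper: true if all checks pass.
bool is_valid(const std::vector<int64_t>& split_shape, int world_size) {
  try {
    experts_per_rank(split_shape, world_size);
    return true;
  } catch (const std::invalid_argument&) {
    return false;
  }
}

// Example usage under the assumed constants (NUM_TILES == 16).
void example() {
  assert(experts_per_rank({3, 32}, 4) == 8);  // 32 columns / 4 ranks
  assert(is_valid({3, 64}, 4));               // 16 experts: at the limit
  assert(!is_valid({3, 68}, 4));              // 17 experts: over the limit
}
```

Note that the limit is inclusive: `ne == NUM_TILES` passes, which matches the `ne <= NUM_TILES` condition in the hunk.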