[a2av] 2D all-to-all-vdev by kwen2501 · Pull Request #155058 · pytorch/pytorch · GitHub
Update
[ghstack-poisoned]
kwen2501 committed Jun 3, 2025
commit baa5600928e387053e2a40d00956af4313b6587c
11 changes: 7 additions & 4 deletions torch/csrc/distributed/c10d/nvshmem_extension.cu
@@ -419,11 +419,15 @@ at::Tensor nvshmem_all_to_all_vdev_2d(
   void* output_ptr = out_hdl->get_buffer_ptrs()[rank];
   int64_t* splits_ptr = (int64_t*)(splits_hdl->get_buffer_ptrs()[rank]);

+  auto split_shape = in_out_splits.sizes();
Collaborator:
I still don't see a check for in_out_splits.is_contiguous()

Collaborator Author:
Sorry, added.
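For reference, the requested check was added in a follow-up commit that is not part of this diff; a minimal sketch of what it would look like, with the message wording being an assumption:

```cpp
// Hedged sketch of the contiguity check the reviewer asked for; the exact
// message text actually committed may differ.
TORCH_CHECK(in_out_splits.is_contiguous(), "in_out_splits must be contiguous");
```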

+  TORCH_CHECK(split_shape.size() == 2 && split_shape[0] == 3, "in_out_splits must be 2D with 3 rows");
+  TORCH_CHECK(split_shape[1] % world_size == 0, "Each row of in_out_splits must be a multiple of world_size");
   // Number of experts per rank
-  int ne = in_out_splits.stride(0) / world_size;
-  TORCH_CHECK(ne * world_size == in_out_splits.stride(0), "Each row of in_out_splits must be a multiple of world_size");
+  int ne = split_shape[1] / world_size;
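To make the shape checks concrete, here is a small worked example; the per-row semantics (input splits, output splits, offsets) are inferred from the "Exchange output splits and source offsets" step below and should be treated as an assumption:

```cpp
// Assume world_size = 4 and a splits tensor of shape [3, 8]:
//   split_shape.size() == 2, split_shape[0] == 3  -> passes the first check
//   split_shape[1] % world_size == 8 % 4 == 0     -> passes the second check
//   ne = split_shape[1] / world_size = 8 / 4 = 2  -> 2 experts per rank
auto in_out_splits = at::zeros({3, 8}, at::kLong);
int ne = in_out_splits.size(1) / 4;  // ne == 2
```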

-  auto stream = at::cuda::getCurrentCUDAStream(input.device().index());
+  // Set device context for getting the stream and launching kernels below
+  c10::cuda::CUDAGuard guard(input.device());
+  auto stream = at::cuda::getCurrentCUDAStream();
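A note on why this change is more than cosmetic: getCurrentCUDAStream(index) only selects a stream object, while kernel launches always go to the current device. A minimal sketch of the RAII pattern now used (scope placement illustrative):

```cpp
{
  // Make input's device current for everything below.
  c10::cuda::CUDAGuard guard(input.device());
  // The stream query now refers to the guarded device.
  auto stream = at::cuda::getCurrentCUDAStream();
  // ... kernel launches here are issued on the guarded device ...
}  // guard's destructor restores the previous current device
```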

   // Exchange output splits and source offsets
   // Use collective launch because kernel involves nvshmem barrier
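For context, NVSHMEM requires kernels that synchronize across PEs (for example via a device-side barrier) to be started with a collective launch so that every PE enters the kernel together. A minimal sketch; the kernel name and argument list are hypothetical since they are not visible in this hunk:

```cpp
// Hypothetical collective launch; exchange_kernel and its arguments stand in
// for whatever this file actually launches at this point.
void* args[] = {&splits_ptr, &world_size};
nvshmemx_collective_launch(
    (const void*)exchange_kernel,  // hypothetical __global__ function
    /*gridDim=*/dim3(1), /*blockDim=*/dim3(world_size),
    args, /*sharedMem=*/0, stream);
```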
@@ -444,7 +448,6 @@ at::Tensor nvshmem_all_to_all_vdev_2d(
   // Naive for now, use 1 block per expert.
   // Total number of blocks is limited to 64 (intra-node) or 8 (inter-node).
   int num_blocks = std::min(world_size * ne, world_size > 8 ? 8 : 64);
-  LOG(INFO) << "num_blocks: " << num_blocks << ", ne: " << ne << ", world_size: " << world_size;
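As a quick sanity check on that formula: intra-node with world_size = 8 and ne = 2 gives min(16, 64) = 16 blocks, while inter-node with world_size = 16 and ne = 2 caps at min(32, 8) = 8 blocks.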

   // Stride at dim 0 (assuming input is contiguous, TODO)
   size_t stride_bytes = input.stride(0) * input.element_size();
Collaborator:
Here you are assuming that input.stride(0) == output.stride(0); you should check it.

Collaborator Author:
Added.
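Again for reference, the check the reviewer asked for (added in a later commit not shown here) would plausibly look like the sketch below; the output tensor's variable name is a guess:

```cpp
// Hedged sketch; `output` stands in for however this function names its
// output tensor, and the message wording is illustrative.
TORCH_CHECK(input.stride(0) == output.stride(0),
    "input and output must have the same stride at dim 0");
```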
