[a2av] 2D all-to-all-vdev by kwen2501 · Pull Request #155058 · pytorch/pytorch · GitHub
Changes from 1 commit
Update
[ghstack-poisoned]
kwen2501 committed Jun 5, 2025
commit cc1537fef122f9c0279534c1f4ac1d39afcbaff8
8 changes: 7 additions & 1 deletion torch/csrc/distributed/c10d/nvshmem_extension.cu
@@ -337,8 +337,10 @@ __global__ void exchangeSplitAndOffset_2d(int64_t* in_out_splits, int mype, int
int e = tid % ne;
// This does a transpose from rank-major order to expert-major order
int dst_offset = e * npes + mype;
+ auto split_val = input_splits[tid];
+ CUDA_KERNEL_ASSERT(split_val >= 0);
nvshmem_int64_p(source_offsets + dst_offset, peer_offsets[tid], peer);
Collaborator:
Would it also make sense here to check that there are no negative numbers in the splits, and that the sum of the splits is less than the input size?
Collaborator Author:
Added a negative check.
- nvshmem_int64_p(output_splits + dst_offset, input_splits[tid], peer);
+ nvshmem_int64_p(output_splits + dst_offset, split_val, peer);
}
// This barrier ensures that all remote PEs see the updated values
nvshmemx_barrier_all_block();
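
For context (not part of the diff): the kernel maps a rank-major splits layout on the sender to an expert-major layout on the receiver via `dst_offset = e * npes + mype`, and the review asked for two validations, of which only the non-negative check landed in this commit. A rough host-side sketch of both ideas, with hypothetical helper names, might look like this:

```cpp
// Hypothetical host-side sketch, not part of the PR. It mirrors the kernel's
// rank-major -> expert-major index transpose (dst_offset = e * npes + mype)
// and the validation suggested in the review: splits must be non-negative and
// their sum must not exceed the input size. In the commit itself, only the
// non-negative check exists (CUDA_KERNEL_ASSERT(split_val >= 0)).
#include <cstdint>
#include <stdexcept>
#include <vector>

// Slot on the receiving PE for the split destined to expert `e`, written by
// rank `mype`, with `npes` ranks in total (expert-major order).
inline int64_t expert_major_slot(int64_t e, int64_t mype, int64_t npes) {
  return e * npes + mype;
}

// Suggested validation over the local input_splits row (length ne * npes).
inline void validate_splits(const std::vector<int64_t>& input_splits,
                            int64_t input_rows) {
  int64_t total = 0;
  for (int64_t s : input_splits) {
    if (s < 0) {
      throw std::invalid_argument("input_splits must be non-negative");
    }
    total += s;
  }
  if (total > input_rows) {
    throw std::invalid_argument("sum of input_splits exceeds input size");
  }
}
```
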
@@ -420,6 +422,10 @@ at::Tensor nvshmem_all_to_all_vdev_2d(
int64_t* splits_ptr = (int64_t*)(splits_hdl->get_buffer_ptrs()[rank]);

auto split_shape = in_out_splits.sizes();
Collaborator:
I still don't see a check for in_out_splits.is_contiguous()
Collaborator Author:
Sorry, added.
+ TORCH_CHECK(in_out_splits.is_contiguous()
+     && input.is_contiguous()
+     && out.is_contiguous(),
+     "input, out and in_out_splits must be contiguous");
TORCH_CHECK(split_shape.size() == 2 && split_shape[0] == 3, "in_out_splits must be 2D with 3 rows");
TORCH_CHECK(split_shape[1] % world_size == 0, "Each row of in_out_splits must be a multiple of world_size");
// Number of experts per rank
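
As a usage note (assumptions only, not from the PR): to get past the checks above, a caller would need `in_out_splits` to be a contiguous 2D int64 tensor with 3 rows and a column count divisible by `world_size`. A minimal libtorch sketch of that shape setup follows; it ignores the symmetric-memory allocation that `splits_hdl->get_buffer_ptrs()` implies the real buffer requires.

```cpp
// Hypothetical caller-side sketch, not part of the PR. It only illustrates the
// shape/contiguity requirements enforced by the TORCH_CHECKs above; the real
// buffer presumably must come from the symmetric-memory allocator, which this
// sketch does not do.
#include <torch/torch.h>

torch::Tensor make_in_out_splits(int64_t experts_per_rank, int64_t world_size) {
  // 3 rows as required by the check; the per-row meaning (e.g. input splits,
  // output splits, output offsets) is an assumption, not spelled out in this hunk.
  auto splits = torch::zeros(
      {3, experts_per_rank * world_size},
      torch::dtype(torch::kInt64).device(torch::kCUDA));
  // torch::zeros returns a contiguous tensor, so splits.is_contiguous() holds,
  // and size(1) = experts_per_rank * world_size is a multiple of world_size.
  return splits;
}
```
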