@@ -2793,61 +2793,59 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::allreduce_sparse(
   tensor = tensor.coalesce();
   at::Tensor outputTensor =
       torch::zeros(tensor.sizes(), tensor.options().layout(torch::kStrided));
-  }
-  int dev_in_group = 0;
-  auto work = collective(
-      tensor,
-      outputTensor,
-      [&](at::Tensor& input,
-          at::Tensor& output,
-          ncclComm_t comm,
-          at::cuda::CUDAStream& stream) {
-        auto ncclDataType = getNcclDataType(input.scalar_type());
-        auto ncclReduceOp =
-            getNcclReduceOp(opts.reduceOp, input, ncclDataType, comm);
-
-        size_t num_elements = output.numel();
-        auto indices = input.indices();
-        auto sizes = input.sizes();
-        int colSize = sizes[1];
-        auto rows = indices[0];
-        size_t blockCount = rows.sizes()[0];
-        auto recvIndices = indices[0] * colSize;
-
-        // prevent output and recvIndices from being freed
-        c10::cuda::CUDACachingAllocator::recordStream(
-            output.storage().data_ptr(), stream);
-        c10::cuda::CUDACachingAllocator::recordStream(
-            recvIndices.storage().data_ptr(), stream);
-        auto result = ncclAllReduceSparseBlock(
-            input._values().data_ptr(), // sendbuff
-            recvIndices.data_ptr<int64_t>(), // recv_indices
-            blockCount, // block_count
-            colSize, // block_length
-            output.data_ptr(), // recvbuff
-            output.numel(), // recv_count
-            ncclDataType,
-            ncclReduceOp,
-            comm,
-            stream.stream());
-        return result;
-      },
-      [](at::cuda::CUDAStream& ncclStream,
-         c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL>& work) {},
-      [&](at::cuda::CUDAStream& ncclStream,
-          c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL>& work) {
-        // Convert output tensors to sparse and back into tensors.
-        at::cuda::CUDAStreamGuard guard(ncclStream);
-        if (opts.sparseIndices.has_value()) {
-          tensor = at::sparse_coo_tensor(
-              opts.sparseIndices.value(), outputTensor, tensor.sizes());
-        } else {
-          tensor = outputTensor.to_sparse();
-        }
-      },
-      OpType::_ALLREDUCE_SPARSE,
-      "nccl:all_reduce_sparse");
-  return work;
+  auto work = collective(
+      tensor,
+      outputTensor,
+      [&](at::Tensor& input,
+          at::Tensor& output,
+          ncclComm_t comm,
+          at::cuda::CUDAStream& stream) {
+        auto ncclDataType = getNcclDataType(input.scalar_type());
+        auto ncclReduceOp =
+            getNcclReduceOp(opts.reduceOp, input, ncclDataType, comm);
+
+        size_t num_elements = output.numel();
+        auto indices = input.indices();
+        auto sizes = input.sizes();
+        int colSize = sizes[1];
+        auto rows = indices[0];
+        size_t blockCount = rows.sizes()[0];
+        auto recvIndices = indices[0] * colSize;
+
+        // prevent output and recvIndices from being freed
+        c10::cuda::CUDACachingAllocator::recordStream(
+            output.storage().data_ptr(), stream);
+        c10::cuda::CUDACachingAllocator::recordStream(
+            recvIndices.storage().data_ptr(), stream);
+        auto result = ncclAllReduceSparseBlock(
+            input._values().data_ptr(), // sendbuff
+            recvIndices.data_ptr<int64_t>(), // recv_indices
+            blockCount, // block_count
+            colSize, // block_length
+            output.data_ptr(), // recvbuff
+            output.numel(), // recv_count
+            ncclDataType,
+            ncclReduceOp,
+            comm,
+            stream.stream());
+        return result;
+      },
+      [](at::cuda::CUDAStream& ncclStream,
+         c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL>& work) {},
+      [&](at::cuda::CUDAStream& ncclStream,
+          c10::intrusive_ptr<ProcessGroupNCCL::WorkNCCL>& work) {
+        // Convert output tensors to sparse and back into tensors.
+        at::cuda::CUDAStreamGuard guard(ncclStream);
+        if (opts.sparseIndices.has_value()) {
+          tensor = at::sparse_coo_tensor(
+              opts.sparseIndices.value(), outputTensor, tensor.sizes());
+        } else {
+          tensor = outputTensor.to_sparse();
+        }
+      },
+      OpType::_ALLREDUCE_SPARSE,
+      "nccl:all_reduce_sparse");
+  return work;
 #else
   // If the nccl branch is not "exp" then we just error
   C10_THROW_ERROR(
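For reference, a minimal standalone libtorch sketch (not part of this diff) of the index arithmetic inside the collective lambda. It assumes the row-sparse layout (sparse_dim = 1, one whole row per value block) that the block semantics of ncclAllReduceSparseBlock appear to expect; the tensor values and names below are illustrative only.

#include <torch/torch.h>
#include <iostream>

int main() {
  // Row-sparse 4x3 COO tensor: indices hold row ids, values hold whole rows.
  auto indices = torch::tensor({{1, 3}}, torch::dtype(torch::kInt64));
  auto values = torch::rand({2, 3});
  auto sparse = torch::sparse_coo_tensor(indices, values, {4, 3}).coalesce();

  auto sizes = sparse.sizes();
  int64_t colSize = sizes[1];        // block_length: elements per row block
  auto rows = sparse.indices()[0];   // one entry per non-zero row: [1, 3]
  int64_t blockCount = rows.size(0); // block_count: number of row blocks
  auto recvIndices = rows * colSize; // flat offsets into the dense recvbuff

  // Row 1 starts at flat element 3 and row 3 at flat element 9 of the
  // 4x3 strided output, which is where each colSize-element block of
  // values lands in the dense buffer.
  std::cout << recvIndices << std::endl;
  return 0;
}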