Update on "[PGNCCL] Launch kernel on current stream & remove `record_stream` entirely" · pytorch/pytorch@ff1f4c4 · GitHub

Commit ff1f4c4

Update on "[PGNCCL] Launch kernel on current stream & remove record_stream entirely"
This PR has multiple changes to `ProcessGroupNCCL` (which unfortunately are related):

1. When `async_op=False`, we directly launch the collective on the "current" stream instead of on a trampoline stream that is joined back afterwards.
   - Resolves #147729
   - Resolves #146881
   - Also saves two event syncs (which have overhead in the HIP case) and one pybind call when we invoke `work.wait()` in `distributed_c10d.py` on behalf of the user.
2. Entirely remove `record_stream` and use CPU-side stashing to manage tensor lifetime against allocator recycling.
   - Resolves #147168
3. Remove tensor-lifetime management when `async_op=False`; only use it when `async_op=True`.
4. To guard against a user never calling `work.wait()`, we ask the watchdog to unstash tensors after it detects completion of a collective, so that we do not hold references to tensors forever. This is a safety net rather than a service guarantee; see the discussion [here](#147168 (comment)).
5. Profiles in `async_op=False` mode will look different: collective kernels now show up on the same line as compute kernels.

Joint work with cenzhaometa, who wants to remove the event-sync overhead.

Cc: ngimel awgu Aidyn-A skyw wconstab leonardo0lyj

cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k c-p-i-o

[ghstack-poisoned]
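To make the new semantics concrete, here is a minimal sketch (not part of this commit; the script name and launch command are illustrative) of what the two modes mean for a `torch.distributed` user, assuming an NCCL backend and one GPU per rank on a single node:

```python
# demo.py -- run with e.g.: torchrun --nproc-per-node=2 demo.py
import torch
import torch.distributed as dist

def main():
    dist.init_process_group("nccl")
    torch.cuda.set_device(dist.get_rank())  # single-node assumption
    x = torch.ones(1024, device="cuda")

    # async_op=False: after this PR, the collective kernel is launched
    # directly on the current stream, so it is ordered with surrounding
    # compute kernels by stream order -- no trampoline stream, no extra
    # event syncs, and no CPU-side stashing of `x`.
    dist.all_reduce(x, op=dist.ReduceOp.SUM)

    # async_op=True: the collective may run on a separate stream; the
    # returned Work keeps `x` alive via CPU-side stashing (instead of
    # record_stream) until wait() -- or, as a safety net, until the
    # watchdog observes completion and unstashes it.
    work = dist.all_reduce(x, op=dist.ReduceOp.SUM, async_op=True)
    work.wait()  # joins the collective back into the current stream

    dist.destroy_process_group()

if __name__ == "__main__":
    main()
```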
2 parents 510dead + 37a9b05 commit ff1f4c4

File tree

1 file changed (+16, −16)


torch/csrc/distributed/c10d/Ops.cpp

Lines changed: 16 additions & 16 deletions
@@ -17,37 +17,37 @@ TORCH_LIBRARY(c10d, m) {
       .def("wait", [](const c10::intrusive_ptr<Work>& self) { self->wait(); });
   m.class_<ReduceOp>("ReduceOp").def(torch::init<>());
   m.def(
-      "broadcast_(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, int root_rank, int root_tensor, bool asyncOp, int timeout) -> (Tensor[], __torch__.torch.classes.c10d.Work)");
+      "broadcast_(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, int root_rank, int root_tensor, bool async_op=True, int timeout=-1) -> (Tensor[], __torch__.torch.classes.c10d.Work)");
   m.def(
-      "allreduce_(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, __torch__.torch.classes.c10d.ReduceOp reduce_op, Tensor? sparse_indices, bool asyncOp, int timeout) -> (Tensor[], __torch__.torch.classes.c10d.Work)");
+      "allreduce_(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, __torch__.torch.classes.c10d.ReduceOp reduce_op, Tensor? sparse_indices, bool async_op=True, int timeout=-1) -> (Tensor[], __torch__.torch.classes.c10d.Work)");
   m.def(
-      "allreduce_coalesced_(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, __torch__.torch.classes.c10d.ReduceOp reduce_op, bool asyncOp, int timeout) -> __torch__.torch.classes.c10d.Work");
+      "allreduce_coalesced_(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, __torch__.torch.classes.c10d.ReduceOp reduce_op, bool async_op=True, int timeout=-1) -> __torch__.torch.classes.c10d.Work");
   m.def(
-      "allgather_(Tensor[][] output_tensors, Tensor[] input_tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, bool asyncOp, int timeout) -> (Tensor[][], __torch__.torch.classes.c10d.Work)");
+      "allgather_(Tensor[][] output_tensors, Tensor[] input_tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, bool async_op=True, int timeout=-1) -> (Tensor[][], __torch__.torch.classes.c10d.Work)");
   m.def(
-      "_allgather_base_(Tensor output_tensor, Tensor input_tensor, __torch__.torch.classes.c10d.ProcessGroup process_group, bool asyncOp, int timeout) -> (Tensor, __torch__.torch.classes.c10d.Work)");
+      "_allgather_base_(Tensor output_tensor, Tensor input_tensor, __torch__.torch.classes.c10d.ProcessGroup process_group, bool async_op=True, int timeout=-1) -> (Tensor, __torch__.torch.classes.c10d.Work)");
   m.def(
-      "allgather_coalesced_(Tensor[][] output_lists, Tensor[] input_list, __torch__.torch.classes.c10d.ProcessGroup process_group, bool asyncOp) -> __torch__.torch.classes.c10d.Work");
+      "allgather_coalesced_(Tensor[][] output_lists, Tensor[] input_list, __torch__.torch.classes.c10d.ProcessGroup process_group, bool async_op=True) -> __torch__.torch.classes.c10d.Work");
   m.def(
-      "allgather_into_tensor_coalesced_(Tensor[] outputs, Tensor[] inputs, __torch__.torch.classes.c10d.ProcessGroup process_group, bool asyncOp) -> __torch__.torch.classes.c10d.Work");
+      "allgather_into_tensor_coalesced_(Tensor[] outputs, Tensor[] inputs, __torch__.torch.classes.c10d.ProcessGroup process_group, bool async_op=True) -> __torch__.torch.classes.c10d.Work");
   m.def(
-      "reduce_scatter_(Tensor[] output_tensors, Tensor[][] input_tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, __torch__.torch.classes.c10d.ReduceOp reduce_op, bool asyncOp, int timeout) -> (Tensor[], __torch__.torch.classes.c10d.Work)");
+      "reduce_scatter_(Tensor[] output_tensors, Tensor[][] input_tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, __torch__.torch.classes.c10d.ReduceOp reduce_op, bool async_op=True, int timeout=-1) -> (Tensor[], __torch__.torch.classes.c10d.Work)");
   m.def(
-      "_reduce_scatter_base_(Tensor output_tensor, Tensor input_tensor, __torch__.torch.classes.c10d.ProcessGroup process_group, __torch__.torch.classes.c10d.ReduceOp reduce_op, bool asyncOp, int timeout) -> (Tensor, __torch__.torch.classes.c10d.Work)");
+      "_reduce_scatter_base_(Tensor output_tensor, Tensor input_tensor, __torch__.torch.classes.c10d.ProcessGroup process_group, __torch__.torch.classes.c10d.ReduceOp reduce_op, bool async_op=True, int timeout=-1) -> (Tensor, __torch__.torch.classes.c10d.Work)");
   m.def(
-      "reduce_scatter_tensor_coalesced_(Tensor[] outputs, Tensor[] inputs, __torch__.torch.classes.c10d.ProcessGroup process_group, __torch__.torch.classes.c10d.ReduceOp reduce_op, bool asyncOp, int timeout) -> __torch__.torch.classes.c10d.Work");
+      "reduce_scatter_tensor_coalesced_(Tensor[] outputs, Tensor[] inputs, __torch__.torch.classes.c10d.ProcessGroup process_group, __torch__.torch.classes.c10d.ReduceOp reduce_op, bool async_op=True, int timeout=-1) -> __torch__.torch.classes.c10d.Work");
   m.def(
-      "reduce_(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, __torch__.torch.classes.c10d.ReduceOp reduce_op, int root_rank, int root_tensor, bool asyncOp, int timeout) -> __torch__.torch.classes.c10d.Work");
+      "reduce_(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, __torch__.torch.classes.c10d.ReduceOp reduce_op, int root_rank, int root_tensor, bool async_op=True, int timeout=-1) -> __torch__.torch.classes.c10d.Work");
   m.def(
-      "gather_(Tensor[][] output_tensors, Tensor[] input_tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, int root_rank, bool asyncOp, int timeout) -> __torch__.torch.classes.c10d.Work");
+      "gather_(Tensor[][] output_tensors, Tensor[] input_tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, int root_rank, bool async_op=True, int timeout=-1) -> __torch__.torch.classes.c10d.Work");
   m.def(
-      "scatter_(Tensor[] output_tensors, Tensor[][] input_tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, int root_rank, bool asyncOp, int timeout) -> (Tensor[], __torch__.torch.classes.c10d.Work)");
+      "scatter_(Tensor[] output_tensors, Tensor[][] input_tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, int root_rank, bool async_op=True, int timeout=-1) -> (Tensor[], __torch__.torch.classes.c10d.Work)");
   m.def(
-      "alltoall_(Tensor[] output_tensors, Tensor[] input_tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, bool asyncOp, int timeout) -> (Tensor[], __torch__.torch.classes.c10d.Work)");
+      "alltoall_(Tensor[] output_tensors, Tensor[] input_tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, bool async_op=True, int timeout=-1) -> (Tensor[], __torch__.torch.classes.c10d.Work)");
   m.def(
-      "alltoall_base_(Tensor output, Tensor input, __torch__.torch.classes.c10d.ProcessGroup process_group, int[] output_split_sizes, int[] input_split_sizes, bool asyncOp, int timeout) -> __torch__.torch.classes.c10d.Work");
+      "alltoall_base_(Tensor output, Tensor input, __torch__.torch.classes.c10d.ProcessGroup process_group, int[] output_split_sizes, int[] input_split_sizes, bool async_op=True, int timeout=-1) -> __torch__.torch.classes.c10d.Work");
   m.def(
-      "barrier(Tensor tensor, __torch__.torch.classes.c10d.ProcessGroup process_group, int[] device_ids, bool asyncOp, int timeout) -> __torch__.torch.classes.c10d.Work");
+      "barrier(Tensor tensor, __torch__.torch.classes.c10d.ProcessGroup process_group, int[] device_ids, bool async_op=True, int timeout=-1) -> __torch__.torch.classes.c10d.Work");
   m.def(
       "monitored_barrier_(Tensor tensor, __torch__.torch.classes.c10d.ProcessGroup process_group, int[] device_ids, int timeout, bool wait_all_ranks) -> ()");
   m.def(
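Besides renaming `asyncOp` to `async_op`, the schema changes above give it and `timeout` defaults (`True` and `-1`), so callers of these dispatcher ops may omit the trailing arguments. As a self-contained illustration of how defaults in a dispatcher schema string behave (using a hypothetical `demo::scale` op, not the c10d ops themselves):

```python
import torch

# Register a toy operator whose schema carries defaults, mirroring the
# `bool async_op=True, int timeout=-1` pattern added in this diff.
lib = torch.library.Library("demo", "DEF")  # hypothetical namespace
lib.define("scale(Tensor t, int factor=2, bool verbose=False) -> Tensor")

def scale_impl(t, factor, verbose):
    # The dispatcher fills in schema defaults before calling the impl.
    if verbose:
        print(f"scaling by {factor}")
    return t * factor

lib.impl("scale", scale_impl, "CompositeExplicitAutograd")

x = torch.ones(3)
print(torch.ops.demo.scale(x))           # defaults: factor=2, verbose=False
print(torch.ops.demo.scale(x, 3, True))  # override both defaults
```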
