@@ -224,13 +224,15 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
224
224
const c10::intrusive_ptr<::c10d::ProcessGroup>&,
225
225
const c10::intrusive_ptr<::c10d::ReduceOp>&,
226
226
const std::optional<at::Tensor>& sparse_indices,
227
+ bool ,
227
228
int64_t )>();
228
229
229
230
auto work = std::get<1 >(op.call (
230
231
tensors,
231
232
c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning (this ),
232
233
c10::make_intrusive<ReduceOp>(opts.reduceOp ),
233
234
opts.sparseIndices ,
235
+ opts.asyncOp ,
234
236
opts.timeout .count ()));
235
237
236
238
if (c10d::allow_inflight_collective_as_graph_input ()) {
@@ -250,12 +252,14 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
250
252
at::TensorList,
251
253
const c10::intrusive_ptr<::c10d::ProcessGroup>&,
252
254
const c10::intrusive_ptr<::c10d::ReduceOp>&,
255
+ bool ,
253
256
int64_t )>();
254
257
255
258
auto work = op.call (
256
259
tensors,
257
260
c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning (this ),
258
261
c10::make_intrusive<ReduceOp>(opts.reduceOp ),
262
+ opts.asyncOp ,
259
263
opts.timeout .count ());
260
264
261
265
if (c10d::allow_inflight_collective_as_graph_input ()) {
@@ -277,13 +281,15 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
277
281
const c10::intrusive_ptr<::c10d::ReduceOp>&,
278
282
int64_t ,
279
283
int64_t ,
284
+ bool ,
280
285
int64_t )>();
281
286
auto work = op.call (
282
287
tensors,
283
288
c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning (this ),
284
289
c10::make_intrusive<ReduceOp>(opts.reduceOp ),
285
290
opts.rootRank ,
286
291
opts.rootTensor ,
292
+ opts.asyncOp ,
287
293
opts.timeout .count ());
288
294
289
295
if (c10d::allow_inflight_collective_as_graph_input ()) {
@@ -306,12 +312,14 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
306
312
const std::vector<std::vector<at::Tensor>>&,
307
313
at::TensorList,
308
314
const c10::intrusive_ptr<::c10d::ProcessGroup>&,
315
+ bool ,
309
316
int64_t )>();
310
317
311
318
auto work = std::get<1 >(op.call (
312
319
outputTensors,
313
320
inputTensors,
314
321
c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning (this ),
322
+ opts.asyncOp ,
315
323
opts.timeout .count ()));
316
324
317
325
if (c10d::allow_inflight_collective_as_graph_input ()) {
@@ -363,18 +371,19 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
363
371
std::vector<std::vector<at::Tensor>>& outputTensorLists,
364
372
std::vector<at::Tensor>& inputTensors,
365
373
const AllgatherOptions& opts = AllgatherOptions()) {
366
- static auto op =
367
- c10::Dispatcher::singleton ( )
368
- . findSchemaOrThrow ( " c10d::allgather_coalesced_ " , " " )
369
- . typed <c10::intrusive_ptr<Work>(
370
- const std::vector<std::vector< at::Tensor>> &,
371
- const at::TensorList &,
372
- const c10::intrusive_ptr<::c10d::ProcessGroup>& )>();
374
+ static auto op = c10::Dispatcher::singleton ()
375
+ . findSchemaOrThrow ( " c10d::allgather_coalesced_ " , " " )
376
+ . typed <c10::intrusive_ptr<Work>(
377
+ const std::vector<std::vector<at::Tensor>>&,
378
+ const at::TensorList &,
379
+ const c10::intrusive_ptr<::c10d::ProcessGroup> &,
380
+ bool )>();
373
381
374
382
auto work = op.call (
375
383
outputTensorLists,
376
384
inputTensors,
377
- c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning (this ));
385
+ c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning (this ),
386
+ opts.asyncOp );
378
387
379
388
if (c10d::allow_inflight_collective_as_graph_input ()) {
380
389
for (const auto & tensor_list : outputTensorLists) {
@@ -399,12 +408,14 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
399
408
.typed <c10::intrusive_ptr<Work>(
400
409
const at::TensorList,
401
410
const at::TensorList,
402
- const c10::intrusive_ptr<::c10d::ProcessGroup>&)>();
411
+ const c10::intrusive_ptr<::c10d::ProcessGroup>&,
412
+ bool )>();
403
413
404
414
auto work = op.call (
405
415
outputTensors,
406
416
inputTensors,
407
- c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning (this ));
417
+ c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning (this ),
418
+ opts.asyncOp );
408
419
409
420
if (c10d::allow_inflight_collective_as_graph_input ()) {
410
421
for (const auto & tensor : outputTensors) {
@@ -425,12 +436,14 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
425
436
const at::TensorList&,
426
437
const c10::intrusive_ptr<::c10d::ProcessGroup>&,
427
438
int64_t ,
439
+ bool ,
428
440
int64_t )>();
429
441
auto work = op.call (
430
442
outputTensors,
431
443
inputTensors,
432
444
c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning (this ),
433
445
opts.rootRank ,
446
+ opts.asyncOp ,
434
447
opts.timeout .count ());
435
448
436
449
if (c10d::allow_inflight_collective_as_graph_input ()) {
@@ -487,12 +500,14 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
487
500
const std::vector<std::vector<at::Tensor>>&,
488
501
const c10::intrusive_ptr<::c10d::ProcessGroup>&,
489
502
const c10::intrusive_ptr<::c10d::ReduceOp>&,
503
+ bool ,
490
504
int64_t )>();
491
505
auto work = std::get<1 >(op.call (
492
506
outputTensors,
493
507
inputTensors,
494
508
c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning (this ),
495
509
c10::make_intrusive<::c10d::ReduceOp>(opts.reduceOp ),
510
+ opts.asyncOp ,
496
511
opts.timeout .count ()));
497
512
498
513
if (c10d::allow_inflight_collective_as_graph_input ()) {
@@ -546,13 +561,15 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
546
561
const at::TensorList,
547
562
const c10::intrusive_ptr<::c10d::ProcessGroup>&,
548
563
const c10::intrusive_ptr<::c10d::ReduceOp>&,
564
+ bool ,
549
565
int64_t )>();
550
566
551
567
auto work = op.call (
552
568
outputTensors,
553
569
inputTensors,
554
570
c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning (this ),
555
571
c10::make_intrusive<::c10d::ReduceOp>(opts.reduceOp ),
572
+ opts.asyncOp ,
556
573
opts.timeout .count ());
557
574
558
575
if (c10d::allow_inflight_collective_as_graph_input ()) {
@@ -577,13 +594,15 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
577
594
const c10::intrusive_ptr<::c10d::ProcessGroup>&,
578
595
std::vector<int64_t >,
579
596
std::vector<int64_t >,
597
+ bool ,
580
598
int64_t )>();
581
599
auto work = op.call (
582
600
outputBuffer,
583
601
inputBuffer,
584
602
c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning (this ),
585
603
outputSplitSizes,
586
604
inputSplitSizes,
605
+ opts.asyncOp ,
587
606
opts.timeout .count ());
588
607
589
608
if (c10d::allow_inflight_collective_as_graph_input ()) {
@@ -604,11 +623,13 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
604
623
const at::TensorList&,
605
624
const at::TensorList&,
606
625
const c10::intrusive_ptr<::c10d::ProcessGroup>&,
626
+ bool ,
607
627
int64_t )>();
608
628
auto work = std::get<1 >(op.call (
609
629
outputTensors,
610
630
inputTensors,
611
631
c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning (this ),
632
+ opts.asyncOp ,
612
633
opts.timeout .count ()));
613
634
614
635
if (c10d::allow_inflight_collective_as_graph_input ()) {
@@ -778,12 +799,14 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder {
778
799
at::Tensor,
779
800
const c10::intrusive_ptr<::c10d::ProcessGroup>&,
780
801
const std::vector<int64_t >&,
802
+ bool ,
781
803
int64_t )>();
782
804
783
805
auto work = op.call (
784
806
tensor,
785
807
c10::intrusive_ptr<ProcessGroup>::unsafe_reclaim_from_nonowning (this ),
786
808
opts.device_ids ,
809
+ opts.asyncOp ,
787
810
opts.timeout .count ());
788
811
if (c10d::allow_inflight_collective_as_graph_input ()) {
789
812
c10d::register_work (tensor, work);
0 commit comments