@@ -38,9 +38,8 @@ def _init_device_mesh_stub():
38
38
else :
39
39
from torch ._C ._distributed_c10d import Backend as C10dBackend
40
40
from torch .distributed .distributed_c10d import (
41
- _find_pg_by_ranks_and_tag ,
42
41
_get_default_group ,
43
- _get_group_tag ,
42
+ _resolve_process_group ,
44
43
get_backend ,
45
44
get_process_group_ranks ,
46
45
get_rank ,
@@ -103,7 +102,7 @@ def create_sub_mesh(
103
102
mesh_tensor = device_mesh .mesh
104
103
# slice_dim_idx could be differnt from submesh_dims, as we may need to flatten out some dims.
105
104
slice_dim_idx = []
106
- slice_dim_group_info = []
105
+ slice_dim_group_name = []
107
106
# keep track of the number of dims that have been flattened so we can get the correct slice_dim_idx in the
108
107
# flattened mesh tensor.
109
108
num_dims_flatten = 0
@@ -121,15 +120,15 @@ def create_sub_mesh(
121
120
# then the final slice_dim_idx should be [0, 1, 2].
122
121
slice_dim_idx .append (mesh_dim_indices [0 ] - num_dims_flatten )
123
122
num_dims_flatten += len (mesh_dim_indices ) - 1
124
- slice_dim_group_info .append (
123
+ slice_dim_group_name .append (
125
124
self .root_to_flatten_mapping [device_mesh ][
126
125
mesh_dim_name
127
- ]._dim_group_infos [0 ]
126
+ ]._dim_group_names [0 ]
128
127
)
129
128
else :
130
129
slice_dim_idx .append (mesh_dim_indices [0 ] - num_dims_flatten )
131
- slice_dim_group_info .append (
132
- device_mesh ._dim_group_infos [mesh_dim_indices [0 ]]
130
+ slice_dim_group_name .append (
131
+ device_mesh ._dim_group_names [mesh_dim_indices [0 ]]
133
132
)
134
133
135
134
# mesh_tensor has already been flattened if needed. So mesh_tensor.ndim <= device_mesh.mesh.ndim now.
@@ -155,7 +154,7 @@ def create_sub_mesh(
155
154
if cur_rank in mesh_nd :
156
155
res_submesh = submesh
157
156
158
- res_submesh ._dim_group_infos = slice_dim_group_info # type: ignore[possibly-undefined]
157
+ res_submesh ._dim_group_names = slice_dim_group_name # type: ignore[possibly-undefined]
159
158
self .child_to_root_mapping [res_submesh ] = device_mesh
160
F438
159
161
160
return res_submesh
@@ -360,8 +359,8 @@ def _get_all_submeshes(
360
359
mesh_dim_names = (mesh_dim_name ,),
361
360
_init_backend = False ,
362
361
)
363
- submesh ._dim_group_infos = (
364
- [device_mesh ._dim_group_infos [mesh_dim ]]
362
+ submesh ._dim_group_names = (
363
+ [device_mesh ._dim_group_names [mesh_dim ]]
365
364
if cur_rank in mesh_1d
366
365
else []
367
366
)
@@ -496,13 +495,10 @@ def _get_or_create_default_group(self):
496
495
return _get_default_group ()
497
496
498
497
def _init_process_groups (self ):
499
- # tag/ranks/ group_name associated with each mesh dimension, each
498
+ # group_name associated with each mesh dimension, each
500
499
# mesh dimension should have one sub-group per rank
501
500
#
502
- # TODO(yifu): remove tag and ranks once we fully migrate to native
503
- # functional collectives. See details in:
504
- # https://github.com/pytorch/pytorch/issues/93173#issuecomment-1907095208
505
- dim_group_infos : list [tuple [str , list [int ], str ]] = []
501
+ dim_group_names : list [str ] = []
506
502
default_group = _get_default_group ()
507
503
508
504
if self .mesh .ndim == 1 and self .mesh .numel () == get_world_size ():
@@ -519,13 +515,7 @@ def _init_process_groups(self):
519
515
and get_backend (default_group ) == "gloo"
520
516
else default_group
521
517
)
522
- dim_group_infos .append (
523
- (
524
- _get_group_tag (dim_group ),
525
- ranks ,
526
- dim_group .group_name ,
527
- )
528
- )
518
+ dim_group_names .append (dim_group .group_name )
529
519
else :
530
520
# create sub pgs base on the mesh argument specified
531
521
for dim in range (self .mesh .ndim ):
@@ -579,10 +569,9 @@ def _init_process_groups(self):
579
569
has_split_group = True
580
570
581
571
# If the subgroup has been already created through `split_group`, we simply loop over `pg_ranks_by_dim`
582
- # and append the `(group_tag, subgroup_ranks, and group_name)` tuple to the `dim_group_infos` list when
583
- # the current rank is in the subgroup.
572
+ # and append the `group_name` to the `dim_group_names` list when the current rank is in the subgroup.
584
573
# Otherwise, we use `new_group` instead of `split_group` to create subgroups by looping over `pg_ranks_by_dim`
585
- # along with appending information to the `dim_group_infos ` list whenever necessary.
574
+ # along with appending information to the `dim_group_names ` list whenever necessary.
586
575
for dim_mesh in pg_ranks_by_dim :
587
576
subgroup_ranks = dim_mesh .tolist ()
588
577
@@ -599,19 +588,13 @@ def _init_process_groups(self):
599
588
600
589
# only add to dim_groups if the current rank in the subgroup
601
590
if self .get_rank () in subgroup_ranks :
602
- if len (dim_group_infos ) > dim :
591
+ if len (dim_group_names ) > dim :
603
592
raise RuntimeError (
604
593
f"Each device mesh dimension should get only one process group, but got { self .get_rank ()} "
605
594
f"in { subgroup_ranks } !"
606
595
)
607
- dim_group_infos .append (
608
- (
609
- _get_group_tag (not_none (dim_group )),
610
- subgroup_ranks ,
611
- dim_group .group_name ,
612
- )
613
- )
614
- self ._dim_group_infos = dim_group_infos
596
+ dim_group_names .append (dim_group .group_name )
597
+ self ._dim_group_names = dim_group_names
615
598
616
599
def __enter__ (self ) -> "DeviceMesh" :
617
600
# set this mesh as the current mesh in mesh env
@@ -745,7 +728,7 @@ def get_group(self, mesh_dim: Optional[Union[int, str]] = None) -> ProcessGroup:
745
728
Returns:
746
729
A :class:`ProcessGroup` object.
747
730
"""
748
- if not hasattr (self , "_dim_group_infos " ):
731
+ if not hasattr (self , "_dim_group_names " ):
749
732
raise RuntimeError ("DeviceMesh process groups not initialized!" )
750
733
751
734
if self .mesh .ndim > 1 and mesh_dim is None :
@@ -758,28 +741,25 @@ def get_group(self, mesh_dim: Optional[Union[int, str]] = None) -> ProcessGroup:
758
741
759
742
# Quick return if the current device_mesh is a 1D mesh.
760
743
if self .mesh .ndim == 1 and mesh_dim is None :
761
- return not_none (
762
- _find_pg_by_ranks_and_tag (* self ._dim_group_infos [0 ][:2 ]) # type: ignore[index]
763
- )
744
+ return not_none (_resolve_process_group (self ._dim_group_names [0 ]))
764
745
765
746
root_mesh = _mesh_resources .get_root_mesh (self )
766
747
root_to_flatten_mapping = _mesh_resources .root_to_flatten_mapping .get (
767
748
root_mesh , None
768
749
)
769
750
if root_to_flatten_mapping and mesh_dim in root_to_flatten_mapping .keys ():
770
- dim_group_infos = root_to_flatten_mapping [
751
+ dim_group_name = root_to_flatten_mapping [
771
752
mesh_dim # type: ignore[index]
772
- ]._dim_group_infos [ 0 ][: 2 ]
773
- return not_none (_find_pg_by_ranks_and_tag ( * dim_group_infos ))
753
+ ]._dim_group_names [ 0 ]
754
+ return not_none (_resolve_process_group ( dim_group_name ))
774
755
else :
775
756
mesh_dim = (
776
757
_mesh_resources .get_mesh_dim_by_name (self , mesh_dim )
777
758
if isinstance (mesh_dim , str )
778
759
else mesh_dim
779
760
)
780
- return not_none (
781
- _find_pg_by_ranks_and_tag (* self ._dim_group_infos [mesh_dim ][:2 ]) # type: ignore[index]
782
- )
761
+ assert isinstance (mesh_dim , int )
762
+ return not_none (_resolve_process_group (self ._dim_group_names [mesh_dim ]))
783
763
784
764
def get_all_groups (self ) -> list [ProcessGroup ]:
785
765
"""
@@ -852,9 +832,7 @@ def from_group(
852
832
mesh_dim_names = mesh_dim_names ,
853
833
_init_backend = False ,
854
834
)
855
- device_mesh ._dim_group_infos = [
856
- (_get_group_tag (group ), group_ranks , group .group_name )
857
- ]
835
+ device_mesh ._dim_group_names = [group .group_name ]
858
836
return device_mesh
859
837
860
838
# nD scenario
@@ -880,14 +858,7 @@ def from_group(
880
858
device_mesh = DeviceMesh (
881
859
device_type , mesh , mesh_dim_names = mesh_dim_names , _init_backend = False
882
860
)
883
- device_mesh ._dim_group_infos = [
884
- (
885
- _get_group_tag (group ),
886
- get_process_group_ranks (group ),
887
- group .group_name ,
888
- )
889
- for group in groups
890
- ]
861
+ device_mesh ._dim_group_names = [group .group_name for group in groups ]
891
862
return device_mesh
892
863
893
864
def size (self , mesh_dim : Optional [int ] = None ) -> int :
0 commit comments