[device_mesh] replace dim_group_info with group_name (#150898) · pytorch/pytorch@9df9d9d · GitHub
Commit 9df9d9d

wanchaol authored and pytorchmergebot committed
[device_mesh] replace dim_group_info with group_name (#150898)
As titled: there is no need to maintain a dim_group_info anymore; we can simply maintain a list of group_name instead. This simplifies the logic.

Pull Request resolved: #150898
Approved by: https://github.com/tianyu-l, https://github.com/fegin
1 parent 9c3cef4 commit 9df9d9d
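In effect, the per-dimension record shrinks from a (group_tag, subgroup_ranks, group_name) tuple to a single group_name string, since the tag and rank list can be re-derived from the ProcessGroup itself. A minimal sketch of the idea, using the private c10d helpers that appear in this diff; the helper name derived_group_info is ours, for illustration only:

import torch.distributed as dist
import torch.distributed.distributed_c10d as c10d

# Before this commit, DeviceMesh kept, per mesh dimension:
#     _dim_group_infos: list[tuple[str, list[int], str]]  # (tag, ranks, group_name)
# After, it keeps only:
#     _dim_group_names: list[str]                         # group_name
# because the tag and ranks are recoverable from the group itself:

def derived_group_info(group_name: str) -> tuple[str, list[int]]:
    # Hypothetical helper (not part of the commit), shown to make the point.
    pg = c10d._resolve_process_group(group_name)  # name -> ProcessGroup
    tag = c10d._get_group_tag(pg)                 # recover the group tag
    ranks = dist.get_process_group_ranks(pg)      # recover the rank list
    return tag, ranks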

File tree

4 files changed (+55, -88 lines)


test/distributed/_composable/fsdp/test_fully_shard_init.py

Lines changed: 1 addition & 7 deletions
@@ -840,7 +840,7 @@ def test_1d_process_group_init(self):
         # since the ref has a parent mesh, while the `from_group` one does not
         self.assertEqual(dp_mesh.mesh, ref_dp_mesh.mesh)
         self.assertEqual(dp_mesh._coordinate_on_dim, ref_dp_mesh._coordinate_on_dim)
-        self.assertEqual(dp_mesh._dim_group_infos, ref_dp_mesh._dim_group_infos)
+        self.assertEqual(dp_mesh._dim_group_names, ref_dp_mesh._dim_group_names)

         # Check 1D FSDP forward/backward parity over the DP mesh
         # NOTE: We cannot use 2D DTensor-based training here because the DP
@@ -916,12 +916,6 @@ def test_2d_process_group_init(self):
         )
         self.assertEqual(mesh.mesh, ref_mesh.mesh)
         self.assertEqual(mesh._coordinate_on_dim, ref_mesh._coordinate_on_dim)
-        for (_, ranks, _), (_, ref_ranks, _) in zip(
-            mesh._dim_group_infos, ref_mesh._dim_group_infos
-        ):
-            # Since we manually constructed new subgroups, the test and ref
-            # groups are not the same
-            self.assertEqual(ranks, ref_ranks)
         for mesh_dim_name in mesh_dim_names:
             child_mesh = mesh[mesh_dim_name]
             ref_child_mesh = ref_mesh[mesh_dim_name]
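The dropped 2D loop compared groups by their rank lists because the manually constructed test subgroups were never identical to the ref groups; with only names stored, the 1D check collapses to a plain list comparison. A hedged sketch of the new assertion shape (the helper name is ours):

# Sketch: comparing two meshes' per-dimension groups by name. Assumes both
# meshes were initialized through the same code path, so their names line up.
def assert_same_dim_group_names(mesh, ref_mesh) -> None:
    # _dim_group_names is the private DeviceMesh attribute this commit introduces.
    assert mesh._dim_group_names == ref_mesh._dim_group_names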

test/distributed/test_device_mesh.py

Lines changed: 20 additions & 22 deletions
@@ -3,6 +3,7 @@
 import os

 import torch
+import torch.distributed as dist
 import torch.distributed._functional_collectives as funcol
 from torch._subclasses.fake_tensor import FakeTensorMode
 from torch.distributed.device_mesh import _mesh_resources, DeviceMesh, init_device_mesh
@@ -197,7 +198,7 @@ def test_fake_pg_device_mesh(self):
         local_tensor = torch.randn(2, 8)
         global_tensor = funcol.all_gather_tensor(
             local_tensor, gather_dim=0, group=(mesh, 0)
-        )
+        ).wait()
         self.assertEqual(global_tensor.shape, (self.world_size * 2, 8))

     @with_comms
@@ -208,7 +209,7 @@ def test_from_group_with_global_pg(self):
         mesh_pg = ref_global_mesh.get_group()
         global_mesh = DeviceMesh.from_group(mesh_pg, self.device_type)
         self.assertEqual(ref_global_mesh, global_mesh)
-        self.assertEqual(ref_global_mesh._dim_group_infos, global_mesh._dim_group_infos)
+        self.assertEqual(ref_global_mesh._dim_group_names, global_mesh._dim_group_names)
         self.assertEqual(
             ref_global_mesh._coordinate_on_dim, global_mesh._coordinate_on_dim
         )
@@ -217,7 +218,7 @@ def test_from_group_with_global_pg(self):
             mesh_pg, self.device_type, mesh=torch.arange(self.world_size)
         )
         self.assertEqual(ref_global_mesh, global_mesh)
-        self.assertEqual(ref_global_mesh._dim_group_infos, global_mesh._dim_group_infos)
+        self.assertEqual(ref_global_mesh._dim_group_names, global_mesh._dim_group_names)
         self.assertEqual(
             ref_global_mesh._coordinate_on_dim, global_mesh._coordinate_on_dim
         )
@@ -396,24 +397,20 @@ def test_from_group_with_mesh_shape_3d(self):
             mesh_dim_names=("dp_replicate", "dp_shard"),
         )

-        ref_mesh_dp_dim_group_infos = ref_mesh._dim_group_infos[:2]
-        for (_, ref_ranks, _), (_, ranks, _) in zip(
-            ref_mesh_dp_dim_group_infos, dp_mesh._dim_group_infos
-        ):
-            self.assertEqual(ref_ranks, ranks)
+        ref_mesh_dp_dim_group_names = ref_mesh._dim_group_names[:2]
+        self.assertEqual(ref_mesh_dp_dim_group_names, dp_mesh._dim_group_names[:2])
         # Cannot check directly for mesh equality since parent meshes are not
         # the same since the ref's parent mesh is 3D
         self.assertEqual(dp_mesh["dp_replicate"].mesh, ref_mesh["dp_replicate"].mesh)
-        for (_, ref_ranks, _), (_, ranks, _) in zip(
-            dp_mesh["dp_replicate"]._dim_group_infos,
-            ref_mesh["dp_replicate"]._dim_group_infos,
-        ):
-            self.assertEqual(ref_ranks, ranks)
+        self.assertEqual(
+            dp_mesh["dp_replicate"]._dim_group_names,
+            ref_mesh["dp_replicate"]._dim_group_names,
+        )
         self.assertEqual(dp_mesh["dp_shard"].mesh, ref_mesh["dp_shard"].mesh)
-        for (_, ref_ranks, _), (_, ranks, _) in zip(
-            dp_mesh["dp_shard"]._dim_group_infos, ref_mesh["dp_shard"]._dim_group_infos
-        ):
-            self.assertEqual(ref_ranks, ranks)
+        self.assertEqual(
+            dp_mesh["dp_shard"]._dim_group_names,
+            ref_mesh["dp_shard"]._dim_group_names,
+        )

     @with_comms()
     def test_from_group_with_mesh_shape_2d(self):
@@ -456,12 +453,13 @@ def test_from_group_with_mesh_shape_2d(self):
             mesh_dim_names=("dp_replicate", "dp_shard"),
         )

-        ref_mesh_dp_dim_group_infos = ref_mesh._dim_group_infos[:2]
-        for (_, ref_ranks, _), (_, ranks, _) in zip(
-            ref_mesh_dp_dim_group_infos, dp_mesh._dim_group_infos
+        # self.assertEqual(ref_mesh._dim_group_names, dp_mesh._dim_group_names)
+        for mesh_dim_group, ref_mesh_dim_group in zip(
+            dp_mesh.get_all_groups(), ref_mesh.get_all_groups()
         ):
-            self.assertEqual(ref_ranks, ranks)
-
+            mesh_dim_group_ranks = dist.get_process_group_ranks(mesh_dim_group)
+            ref_mesh_dim_group_ranks = dist.get_process_group_ranks(ref_mesh_dim_group)
+            self.assertEqual(mesh_dim_group_ranks, ref_mesh_dim_group_ranks)
         # check both the 2d mesh and the submeshes are exactly the same.
         self.assertEqual(dp_mesh, ref_mesh)
         self.assertEqual(dp_mesh["dp_replicate"], ref_mesh["dp_replicate"])
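Since independently created subgroups get distinct names, the 2D test now compares groups by membership rather than by name. A small sketch of that pattern, assuming two meshes laid out over the same ranks (the helper name is ours):

import torch.distributed as dist

def assert_same_group_ranks(mesh, ref_mesh) -> None:
    # get_all_groups() returns one ProcessGroup per mesh dimension; compare
    # each pair by the ranks it contains, not by group name.
    for pg, ref_pg in zip(mesh.get_all_groups(), ref_mesh.get_all_groups()):
        assert dist.get_process_group_ranks(pg) == dist.get_process_group_ranks(ref_pg)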

torch/distributed/_functional_collectives.py

Lines changed: 8 additions & 4 deletions
@@ -731,8 +731,10 @@ def cast_listint(x):
             "Only 1D mesh is supported, pass in (DeviceMesh, int) together if mesh > 1D"
         )
         # TODO: it should run collective in the whole mesh instead of dim 0
-        tag, rankset, _ = group._dim_group_infos[0]
+        pg = group.get_group()
+        rankset = dist.get_process_group_ranks(pg)
         group_size = len(rankset)
+        tag = tag or c10d._get_group_tag(pg)
     elif isinstance(group, tuple):
         if (
             len(group) == 2
@@ -741,8 +743,10 @@ def cast_listint(x):
         ):
             dmesh = group[0]
             dim = group[1]
-            tag, rankset, _ = dmesh._dim_group_infos[dim]
+            pg = dmesh.get_group(dim)
+            rankset = dist.get_process_group_ranks(pg)
             group_size = len(rankset)
+            tag = tag or c10d._get_group_tag(pg)
         else:
             raise ValueError("Invalid tuple for group must be (DeviceMesh, int)")
     else:
@@ -767,7 +771,7 @@ def _resolve_group_name(group: RANK_TYPES, tag: str = "") -> str:
         assert group.ndim == 1, (
             "Only 1D mesh is supported, pass in (DeviceMesh, int) together if mesh > 1D"
         )
-        return group._dim_group_infos[0][2]
+        return group._dim_group_names[0]
     elif isinstance(group, tuple):
         if (
             len(group) == 2
@@ -776,7 +780,7 @@ def _resolve_group_name(group: RANK_TYPES, tag: str = "") -> str:
         ):
             dmesh = group[0]
             dim = group[1]
-            return dmesh._dim_group_infos[dim][2]
+            return dmesh._dim_group_names[dim]
         else:
             raise ValueError("Invalid tuple for group must be (DeviceMesh, int)")
     elif isinstance(group, list):
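The caller-facing contract of _resolve_group_name is unchanged: a 1D DeviceMesh, or a (DeviceMesh, dim) tuple for higher-dimensional meshes, still resolves to a group name, now read from _dim_group_names instead of the third tuple field. A hedged usage sketch, assuming torch.distributed is already initialized with 8 ranks; the mesh shape and dim names are illustrative:

import torch
import torch.distributed._functional_collectives as funcol
from torch.distributed.device_mesh import init_device_mesh

# Assumes an already-initialized 8-rank job with CUDA devices.
mesh = init_device_mesh("cuda", (2, 4), mesh_dim_names=("dp", "tp"))

local_tensor = torch.randn(2, 8, device="cuda")
# (DeviceMesh, dim) selects one mesh dimension; the collective resolves the
# group name internally via _dim_group_names[dim].
out = funcol.all_gather_tensor(local_tensor, gather_dim=0, group=(mesh, 1)).wait()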

torch/distributed/device_mesh.py

Lines changed: 26 additions & 55 deletions
@@ -38,9 +38,8 @@ def _init_device_mesh_stub():
 else:
     from torch._C._distributed_c10d import Backend as C10dBackend
     from torch.distributed.distributed_c10d import (
-        _find_pg_by_ranks_and_tag,
         _get_default_group,
-        _get_group_tag,
+        _resolve_process_group,
         get_backend,
         get_process_group_ranks,
         get_rank,
@@ -103,7 +102,7 @@ def create_sub_mesh(
             mesh_tensor = device_mesh.mesh
             # slice_dim_idx could be differnt from submesh_dims, as we may need to flatten out some dims.
             slice_dim_idx = []
-            slice_dim_group_info = []
+            slice_dim_group_name = []
             # keep track of the number of dims that have been flattened so we can get the correct slice_dim_idx in the
             # flattened mesh tensor.
             num_dims_flatten = 0
@@ -121,15 +120,15 @@ def create_sub_mesh(
                     # then the final slice_dim_idx should be [0, 1, 2].
                     slice_dim_idx.append(mesh_dim_indices[0] - num_dims_flatten)
                     num_dims_flatten += len(mesh_dim_indices) - 1
-                    slice_dim_group_info.append(
+                    slice_dim_group_name.append(
                         self.root_to_flatten_mapping[device_mesh][
                             mesh_dim_name
-                        ]._dim_group_infos[0]
+                        ]._dim_group_names[0]
                     )
                 else:
                     slice_dim_idx.append(mesh_dim_indices[0] - num_dims_flatten)
-                    slice_dim_group_info.append(
-                        device_mesh._dim_group_infos[mesh_dim_indices[0]]
+                    slice_dim_group_name.append(
+                        device_mesh._dim_group_names[mesh_dim_indices[0]]
                     )

             # mesh_tensor has already been flattened if needed. So mesh_tensor.ndim <= device_mesh.mesh.ndim now.
@@ -155,7 +154,7 @@ def create_sub_mesh(
                 if cur_rank in mesh_nd:
                     res_submesh = submesh

-            res_submesh._dim_group_infos = slice_dim_group_info  # type: ignore[possibly-undefined]
+            res_submesh._dim_group_names = slice_dim_group_name  # type: ignore[possibly-undefined]
             self.child_to_root_mapping[res_submesh] = device_mesh

             return res_submesh
@@ -360,8 +359,8 @@ def _get_all_submeshes(
                     mesh_dim_names=(mesh_dim_name,),
                     _init_backend=False,
                 )
-                submesh._dim_group_infos = (
-                    [device_mesh._dim_group_infos[mesh_dim]]
+                submesh._dim_group_names = (
+                    [device_mesh._dim_group_names[mesh_dim]]
                     if cur_rank in mesh_1d
                     else []
                 )
@@ -496,13 +495,10 @@ def _get_or_create_default_group(self):
             return _get_default_group()

         def _init_process_groups(self):
-            # tag/ranks/group_name associated with each mesh dimension, each
+            # group_name associated with each mesh dimension, each
             # mesh dimension should have one sub-group per rank
             #
-            # TODO(yifu): remove tag and ranks once we fully migrate to native
-            # functional collectives. See details in:
-            # https://github.com/pytorch/pytorch/issues/93173#issuecomment-1907095208
-            dim_group_infos: list[tuple[str, list[int], str]] = []
+            dim_group_names: list[str] = []
             default_group = _get_default_group()

             if self.mesh.ndim == 1 and self.mesh.numel() == get_world_size():
@@ -519,13 +515,7 @@ def _init_process_groups(self):
                     and get_backend(default_group) == "gloo"
                     else default_group
                 )
-                dim_group_infos.append(
-                    (
-                        _get_group_tag(dim_group),
-                        ranks,
-                        dim_group.group_name,
-                    )
-                )
+                dim_group_names.append(dim_group.group_name)
             else:
                 # create sub pgs base on the mesh argument specified
                 for dim in range(self.mesh.ndim):
@@ -579,10 +569,9 @@ def _init_process_groups(self):
                         has_split_group = True

                     # If the subgroup has been already created through `split_group`, we simply loop over `pg_ranks_by_dim`
-                    # and append the `(group_tag, subgroup_ranks, and group_name)` tuple to the `dim_group_infos` list when
-                    # the current rank is in the subgroup.
+                    # and append the `group_name` to the `dim_group_names` list when the current rank is in the subgroup.
                     # Otherwise, we use `new_group` instead of `split_group` to create subgroups by looping over `pg_ranks_by_dim`
-                    # along with appending information to the `dim_group_infos` list whenever necessary.
+                    # along with appending information to the `dim_group_names` list whenever necessary.
                     for dim_mesh in pg_ranks_by_dim:
                         subgroup_ranks = dim_mesh.tolist()

@@ -599,19 +588,13 @@ def _init_process_groups(self):

                         # only add to dim_groups if the current rank in the subgroup
                         if self.get_rank() in subgroup_ranks:
-                            if len(dim_group_infos) > dim:
+                            if len(dim_group_names) > dim:
                                 raise RuntimeError(
                                     f"Each device mesh dimension should get only one process group, but got {self.get_rank()} "
                                     f"in {subgroup_ranks}!"
                                 )
-                            dim_group_infos.append(
-                                (
-                                    _get_group_tag(not_none(dim_group)),
-                                    subgroup_ranks,
-                                    dim_group.group_name,
-                                )
-                            )
-            self._dim_group_infos = dim_group_infos
+                            dim_group_names.append(dim_group.group_name)
+            self._dim_group_names = dim_group_names

         def __enter__(self) -> "DeviceMesh":
             # set this mesh as the current mesh in mesh env
@@ -745,7 +728,7 @@ def get_group(self, mesh_dim: Optional[Union[int, str]] = None) -> ProcessGroup:
             Returns:
                 A :class:`ProcessGroup` object.
             """
-            if not hasattr(self, "_dim_group_infos"):
+            if not hasattr(self, "_dim_group_names"):
                 raise RuntimeError("DeviceMesh process groups not initialized!")

             if self.mesh.ndim > 1 and mesh_dim is None:
@@ -758,28 +741,25 @@ def get_group(self, mesh_dim: Optional[Union[int, str]] = None) -> ProcessGroup:

             # Quick return if the current device_mesh is a 1D mesh.
             if self.mesh.ndim == 1 and mesh_dim is None:
-                return not_none(
-                    _find_pg_by_ranks_and_tag(*self._dim_group_infos[0][:2])  # type: ignore[index]
-                )
+                return not_none(_resolve_process_group(self._dim_group_names[0]))

             root_mesh = _mesh_resources.get_root_mesh(self)
             root_to_flatten_mapping = _mesh_resources.root_to_flatten_mapping.get(
                 root_mesh, None
             )
             if root_to_flatten_mapping and mesh_dim in root_to_flatten_mapping.keys():
-                dim_group_infos = root_to_flatten_mapping[
+                dim_group_name = root_to_flatten_mapping[
                     mesh_dim  # type: ignore[index]
-                ]._dim_group_infos[0][:2]
-                return not_none(_find_pg_by_ranks_and_tag(*dim_group_infos))
+                ]._dim_group_names[0]
+                return not_none(_resolve_process_group(dim_group_name))
             else:
                 mesh_dim = (
                     _mesh_resources.get_mesh_dim_by_name(self, mesh_dim)
                     if isinstance(mesh_dim, str)
                     else mesh_dim
                 )
-                return not_none(
-                    _find_pg_by_ranks_and_tag(*self._dim_group_infos[mesh_dim][:2])  # type: ignore[index]
-                )
+                assert isinstance(mesh_dim, int)
+                return not_none(_resolve_process_group(self._dim_group_names[mesh_dim]))

         def get_all_groups(self) -> list[ProcessGroup]:
             """
@@ -852,9 +832,7 @@ def from_group(
                     mesh_dim_names=mesh_dim_names,
                     _init_backend=False,
                 )
-                device_mesh._dim_group_infos = [
-                    (_get_group_tag(group), group_ranks, group.group_name)
-                ]
+                device_mesh._dim_group_names = [group.group_name]
                 return device_mesh

             # nD scenario
@@ -880,14 +858,7 @@ def from_group(
             device_mesh = DeviceMesh(
                 device_type, mesh, mesh_dim_names=mesh_dim_names, _init_backend=False
             )
-            device_mesh._dim_group_infos = [
-                (
-                    _get_group_tag(group),
-                    get_process_group_ranks(group),
-                    group.group_name,
-                )
-                for group in groups
-            ]
+            device_mesh._dim_group_names = [group.group_name for group in groups]
             return device_mesh

         def size(self, mesh_dim: Optional[int] = None) -> int:
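With this change, DeviceMesh.from_group records only group.group_name per dimension, in both the 1D and nD branches. A hedged sketch of the nD path it exercises; the two process groups are assumed to exist already (one per mesh dimension of an 8-rank, 2x4 layout), and the function wrapper is ours:

import torch
import torch.distributed as dist
from torch.distributed.device_mesh import DeviceMesh

def build_2d_mesh_from_groups(
    replicate_pg: dist.ProcessGroup,  # assumed preexisting group for dim 0
    shard_pg: dist.ProcessGroup,      # assumed preexisting group for dim 1
) -> DeviceMesh:
    # from_group's nD path takes one group per dimension plus the rank layout.
    mesh = DeviceMesh.from_group(
        [replicate_pg, shard_pg],
        "cuda",
        mesh=torch.arange(8).reshape(2, 4),  # 8-rank world laid out as 2x4
        mesh_dim_names=("dp_replicate", "dp_shard"),
    )
    # After this commit the mesh stores only the names:
    #     mesh._dim_group_names == [replicate_pg.group_name, shard_pg.group_name]
    return mesh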
