[FSDP] Fixed `device_mesh` and auto wrap (#119064) · pytorch/pytorch@ce40ee8 · GitHub

Commit ce40ee8

Andrew Gu authored and pytorchmergebot committed
[FSDP] Fixed device_mesh and auto wrap (#119064)
If the user passes `device_mesh`, then we should not forward the process groups to the children during auto wrap and should instead just rely on the `device_mesh` argument. This should fix #118906.

Pull Request resolved: #119064
Approved by: https://github.com/wz337
1 parent 18fc1ca commit ce40ee8
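
The fix matters for users who configure hybrid sharding through a 2D `device_mesh` while also using an auto wrap policy. Below is a minimal sketch of that scenario (illustrative, not taken from the PR; it assumes a multi-GPU job launched with torchrun, an initialized NCCL process group, and an 8-GPU mesh shape): after this change, the auto-wrapped submodules derive their intra- and inter-node process groups from `device_mesh` instead of receiving process groups forwarded by the root FSDP instance.

import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy
from torch.distributed.fsdp.wrap import ModuleWrapPolicy

# Illustrative 2D mesh: the first dim replicates across nodes, the second dim
# shards within a node. The shape (2, 4) is an assumption for an 8-GPU job.
mesh = init_device_mesh("cuda", (2, 4))

model = nn.TransformerEncoder(
    nn.TransformerEncoderLayer(d_model=64, nhead=4), num_layers=2
)

# With this fix, HYBRID_SHARD + device_mesh + auto wrap no longer forwards the
# root's process groups to the wrapped TransformerEncoderLayer children; their
# process groups are constructed from `mesh` instead.
fsdp_model = FSDP(
    model,
    sharding_strategy=ShardingStrategy.HYBRID_SHARD,
    device_mesh=mesh,
    auto_wrap_policy=ModuleWrapPolicy({nn.TransformerEncoderLayer}),
)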

2 files changed (+15, -2 lines)

test/distributed/fsdp/test_fsdp_hybrid_shard.py
Lines changed: 14 additions & 1 deletion

@@ -11,6 +11,7 @@
 import torch.distributed as dist
 import torch.distributed.fsdp._traversal_utils as traversal_utils
 import torch.nn as nn
+from torch.distributed.device_mesh import init_device_mesh
 from torch.distributed.distributed_c10d import _rank_not_in_group
 from torch.distributed.fsdp import (
     FullyShardedDataParallel as FSDP,
@@ -284,6 +285,7 @@ def test_fsdp_hybrid_shard_basic_setup(self):
                     ShardingStrategyMode.MIXED_HYBRID_FULL_SHARD,
                 ],
                 "use_orig_params": [False, True],
+                "use_device_mesh": [False, True],
             },
             self._test_fsdp_hybrid_shard_basic_setup,
         )
@@ -293,9 +295,17 @@ def _test_fsdp_hybrid_shard_basic_setup(
         hsdp_sharding_strategy: ShardingStrategy,
         sharding_strategy_mode: ShardingStrategyMode,
         use_orig_params: bool,
+        use_device_mesh: bool,
     ):
+        if use_device_mesh:
+            device_mesh = init_device_mesh("cuda", (1, self.world_size))
+        else:
+            device_mesh = None
         hsdp_model = self._init_hsdp_model(
-            hsdp_sharding_strategy, sharding_strategy_mode, use_orig_params
+            hsdp_sharding_strategy,
+            sharding_strategy_mode,
+            use_orig_params,
+            hsdp_device_mesh=device_mesh,
         )
         # All FSDP modules should have state.process_group as the process group over which to
         # shard (default process group), and state._inter_node_pg (process group containing only
@@ -428,7 +438,9 @@ def _init_hsdp_model(
         hsdp_process_groups: Optional[
             Tuple[dist.ProcessGroup, dist.ProcessGroup]
         ] = None,
+        hsdp_device_mesh: Optional = None,
     ):
+        assert hsdp_process_groups is None or hsdp_device_mesh is None
         auto_wrap_policy = ModuleWrapPolicy(
             {TransformerEncoderLayer, TransformerDecoderLayer},
         )
@@ -437,6 +449,7 @@ def _init_hsdp_model(
             "auto_wrap_policy": auto_wrap_policy,
             "sharding_strategy": hsdp_sharding_strategy,
             "use_orig_params": use_orig_params,
+            "device_mesh": hsdp_device_mesh,
         }
         if sharding_strategy_mode == ShardingStrategyMode.ALL_HYBRID_SHARD:
             hsdp_model = TransformerWithSharedParams.init(

torch/distributed/fsdp/fully_sharded_data_parallel.py
Lines changed: 1 addition & 1 deletion

@@ -472,7 +472,7 @@ def __init__(
                 "ignored_states": self._ignored_params,
                 "device_mesh": device_mesh,
             }
-            if sharding_strategy in HYBRID_SHARDING_STRATEGIES:
+            if sharding_strategy in HYBRID_SHARDING_STRATEGIES and device_mesh is None:
                 # Share root process groups with children to maintain
                 # the invariant that all FSDP modules will have the same
                 # process groups.
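
The one-line change above can be read as the following predicate (an illustrative restatement only; `should_forward_root_process_groups` is not an actual FSDP helper, and `_HYBRID_STRATEGIES` stands in for FSDP's internal `HYBRID_SHARDING_STRATEGIES`): the root forwards its process groups to auto-wrapped children only for hybrid sharding strategies and only when no `device_mesh` was supplied.

from typing import Optional

from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.fsdp import ShardingStrategy

# Stand-in for FSDP's internal HYBRID_SHARDING_STRATEGIES constant.
_HYBRID_STRATEGIES = {
    ShardingStrategy.HYBRID_SHARD,
    ShardingStrategy._HYBRID_SHARD_ZERO2,
}


def should_forward_root_process_groups(
    sharding_strategy: ShardingStrategy,
    device_mesh: Optional[DeviceMesh],
) -> bool:
    # After #119064: when a device_mesh is given, auto-wrapped children build
    # their process groups from the mesh, so the root must not forward its own.
    return sharding_strategy in _HYBRID_STRATEGIES and device_mesh is None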
