import torch.distributed as dist
import torch.distributed.fsdp._traversal_utils as traversal_utils
import torch.nn as nn
+from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.distributed_c10d import _rank_not_in_group
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
@@ -284,6 +285,7 @@ def test_fsdp_hybrid_shard_basic_setup(self):
                    ShardingStrategyMode.MIXED_HYBRID_FULL_SHARD,
                ],
                "use_orig_params": [False, True],
+                "use_device_mesh": [False, True],
            },
            self._test_fsdp_hybrid_shard_basic_setup,
        )
@@ -293,9 +295,17 @@ def _test_fsdp_hybrid_shard_basic_setup(
        hsdp_sharding_strategy: ShardingStrategy,
        sharding_strategy_mode: ShardingStrategyMode,
        use_orig_params: bool,
+        use_device_mesh: bool,
    ):
+        if use_device_mesh:
+            device_mesh = init_device_mesh("cuda", (1, self.world_size))
+        else:
+            device_mesh = None
        hsdp_model = self._init_hsdp_model(
-            hsdp_sharding_strategy, sharding_strategy_mode, use_orig_params
+            hsdp_sharding_strategy,
+            sharding_strategy_mode,
+            use_orig_params,
+            hsdp_device_mesh=device_mesh,
        )
        # All FSDP modules should have state.process_group as the process group over which to
        # shard (default process group), and state._inter_node_pg (process group containing only
@@ -428,7 +438,9 @@ def _init_hsdp_model(
        hsdp_process_groups: Optional[
            Tuple[dist.ProcessGroup, dist.ProcessGroup]
        ] = None,
+        hsdp_device_mesh: Optional = None,
    ):
+        assert hsdp_process_groups is None or hsdp_device_mesh is None
        auto_wrap_policy = ModuleWrapPolicy(
            {TransformerEncoderLayer, TransformerDecoderLayer},
        )
@@ -437,6 +449,7 @@ def _init_hsdp_model(
            "auto_wrap_policy": auto_wrap_policy,
            "sharding_strategy": hsdp_sharding_strategy,
            "use_orig_params": use_orig_params,
+            "device_mesh": hsdp_device_mesh,
        }
        if sharding_strategy_mode == ShardingStrategyMode.ALL_HYBRID_SHARD:
            hsdp_model = TransformerWithSharedParams.init(
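
For context on what the new `use_device_mesh=True` path exercises: instead of handing FSDP an explicit `(intra_node_pg, inter_node_pg)` pair, the test builds a 2-D `DeviceMesh` and passes it via the `device_mesh` argument. Below is a minimal standalone sketch of that usage, assuming a single-host `torchrun` launch and a toy `nn.Sequential` model; the script name and model are illustrative and not part of this PR.

```python
# Hypothetical standalone usage of the API exercised by the test above.
# Launch with: torchrun --nproc_per_node=<num_gpus> demo_hsdp_device_mesh.py
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy

dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

# 2-D mesh: the outer dim replicates, the inner dim shards.
# Shape (1, world_size) mirrors the single-replica setup used in the test.
mesh = init_device_mesh("cuda", (1, dist.get_world_size()))

model = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 16)).cuda()
fsdp_model = FSDP(
    model,
    sharding_strategy=ShardingStrategy.HYBRID_SHARD,
    device_mesh=mesh,  # replaces passing (intra_node_pg, inter_node_pg) by hand
    use_orig_params=True,
)

out = fsdp_model(torch.randn(4, 16, device="cuda"))
out.sum().backward()
dist.destroy_process_group()
```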