DistributedModelParallel resharding Interface (#2945) · pytorch/torchrec@7b44948 · GitHub
[go: up one dir, main page]

Skip to content

Commit 7b44948

Browse files
committed
DistributedModelParallel resharding Interface (#2945)
Summary: Pull Request resolved: #2945 Finally! DMP interface for resharding, most of the changes here are to enable proper testing of DMP. ## Main changes: ### 1. DMP reshard API: * which calls the underlying sharder for sharded module to reshard ### 2. Proper Testing: * A multi-rank test which generates a full Model and utilizes DMP interface. Currently only tests TW. * This test is called from `test_dynamic_sharding.py` -> `test_model_parallel.py` -> `test_sharding.py`, which follows the same structure as current DMP unit tests * This is how the test tests for correctness: ``` 1. Generate global model and inputs 2. Create 2 identical local models based on global model 3. Use planner to generate sharding plan for local model 4. Based on planner output, generate a second, different sharding plan 5. Shard both local models 1 and 2 through DMP with plan 1 and 2 respectively 6. Reshard (dynamic sharding API) model 1 with plan 2 7. Generate predictions for local models and compare them to global model prediction. Expect to be the same. ``` * This tests for `optimizer` being correctly saved in resharding * The test is setup with other variables to-be-set once more functionalities are enabled with dynamic sharding, e.g. `variable_batch_size` etc. ### 3. Helper functions for testing * `get_sharding_constructor_from_type` to enable setting sharding_type for each unit test. * `compare_model_pred_one_step` only used for debugging to get more information on whether models are identical after resharding/running initial step * `compare_model_weights` also for debugging ### 4. Small refactoring in `update_shards` call. Differential Revision: D73049934
1 parent 5b1dbe7 commit 7b44948

File tree

6 files changed

+721
-29
lines changed

6 files changed

+721
-29
lines changed

torchrec/distributed/embeddingbag.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1516,15 +1516,9 @@ def update_shards(
15161516
current_state = self.state_dict()
15171517
# TODO: Save Optimizers
15181518

1519-
saved_weights = {}
15201519
# TODO: Saving lookups tensors to CPU to eventually avoid recreating them completely again
1521-
for i, lookup in enumerate(self._lookups):
1522-
for attribute, tbe_module in lookup.named_modules():
1523-
if type(tbe_module) is DenseTableBatchedEmbeddingBagsCodegen:
1524-
saved_weights[str(i) + "." + attribute] = tbe_module.weights.cpu()
1525-
# Note: lookup.purge should delete tbe_module and weights
1526-
# del tbe_module.weights
1527-
# del tbe_module
1520+
# TODO: Ensure lookup tensors are actually being deleted
1521+
for _, lookup in enumerate(self._lookups):
15281522
# pyre-ignore
15291523
lookup.purge()
15301524

torchrec/distributed/model_parallel.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
from torchrec.distributed.types import (
3636
EnumerableShardingSpec,
3737
ModuleSharder,
38+
ParameterSharding,
3839
ShardedModule,
3940
ShardingEnv,
4041
ShardingEnv2D,
@@ -612,6 +613,43 @@ def _reset_parameters(module: nn.Module) -> None:
612613
if hasattr(m, "reset_parameters"):
613614
m.reset_parameters()
614615

616+
def reshard(
617+
self,
618+
path_to_sharded_module: str,
619+
changed_shard_to_params: Dict[str, ParameterSharding],
620+
) -> None:
621+
"""
622+
Reshards a module in the DMP. This is useful when the sharding plan for a module
623+
changes during training.
624+
625+
Args:
626+
path_to_sharded_module (str): The path to the sharded module in the DMP.
627+
changed_shard_to_params (Dict[str, ParameterSharding]): The delta between original sharding plan
628+
and new sharding plan for the module.
629+
"""
630+
steps = path_to_sharded_module.split(".")
631+
sharded_module = self.module
632+
for s in steps:
633+
sharded_module = getattr(sharded_module, s)
634+
635+
assert isinstance(sharded_module, ShardedModule)
636+
assert changed_shard_to_params is not None
637+
sharder_key = sharded_module.unsharded_module_type
638+
sharder = self._sharder_map[sharder_key]
639+
assert hasattr(
640+
sharder, "reshard"
641+
), "reshard is not implemented for this sharder"
642+
sharded_module = sharder.reshard( # pyre-ignore
643+
sharded_module,
644+
changed_shard_to_params,
645+
self._env,
646+
self.device,
647+
)
648+
649+
self._optim: CombinedOptimizer = self._init_optim(self._dmp_wrapped_module)
650+
self._plan.plan[path_to_sharded_module] = sharded_module.module_sharding_plan
651+
return sharded_module
652+
615653

616654
class DMPCollection(DistributedModelParallel):
617655
"""

torchrec/distributed/sharding_plan.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -410,6 +410,20 @@ def _get_parameter_sharding(
410410
]
411411

412412

413+
def get_sharding_constructor_from_type(
414+
sharding_type: ShardingType,
415+
) -> Callable[..., ParameterShardingGenerator]:
416+
sharding_type_to_constructor = {
417+
ShardingType.TABLE_WISE: table_wise,
418+
ShardingType.ROW_WISE: row_wise,
419+
ShardingType.COLUMN_WISE: column_wise,
420+
ShardingType.TABLE_ROW_WISE: table_row_wise,
421+
ShardingType.GRID_SHARD: grid_shard,
422+
ShardingType.DATA_PARALLEL: data_parallel,
423+
}
424+
return sharding_type_to_constructor[sharding_type]
425+
426+
413427
def data_parallel() -> ParameterShardingGenerator:
414428
"""
415429
Returns a generator of ParameterShardingPlan for `ShardingType::DATA_PARALLEL` for construct_module_sharding_plan.

torchrec/distributed/test_utils/test_model_parallel.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from torchrec.distributed.test_utils.test_model import TestSparseNN, TestSparseNNBase
2323
from torchrec.distributed.test_utils.test_sharding import (
2424
create_test_sharder,
25+
dynamic_sharding_test,
2526
SharderType,
2627
sharding_single_rank_test,
2728
)
@@ -186,6 +187,78 @@ def _test_sharding(
186187
lengths_dtype=lengths_dtype,
187188
)
188189

190+
def _test_dynamic_sharding(
191+
self,
192+
sharders: List[ModuleSharder[nn.Module]],
193+
backend: str = "gloo",
194+
world_size: int = 2,
195+
local_size: Optional[int] = None,
196+
world_size_2D: Optional[int] = None,
197+
node_group_size: Optional[int] = None,
198+
model_class: Type[TestSparseNNBase] = TestSparseNN,
199+
qcomms_config: Optional[QCommsConfig] = None,
200+
apply_optimizer_in_backward_config: Optional[
201+
Dict[str, Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]]
202+
] = None,
203+
variable_batch_size: bool = False,
204+
variable_batch_per_feature: bool = False,
205+
has_weighted_tables: bool = True,
206+
global_constant_batch: bool = False,
207+
pooling: PoolingType = PoolingType.SUM,
208+
data_type: DataType = DataType.FP32,
209+
use_inter_host_allreduce: bool = False,
210+
allow_zero_batch_size: bool = False,
211+
custom_all_reduce: bool = False,
212+
use_offsets: bool = False,
213+
indices_dtype: torch.dtype = torch.int64,
214+
offsets_dtype: torch.dtype = torch.int64,
215+
lengths_dtype: torch.dtype = torch.int64,
216+
sharding_type: ShardingType = None, # pyre-ignore
217+
random_seed: int = 0,
218+
) -> None:
219+
"""
220+
Tests the reshard API with dynamic_sharding_test, which creates 2 identical models
221+
one of which is resharded, and then compares the predictions of the 2 models.
222+
"""
223+
self._build_tables_and_groups(data_type=data_type)
224+
constraints = {}
225+
if sharding_type is not None:
226+
for table in self.tables:
227+
name = table.name
228+
# Default sharding type constraints
229+
constraints[name] = ParameterConstraints(
230+
sharding_types=[sharding_type.value],
231+
)
232+
233+
self._run_multi_process_test(
234+
callable=dynamic_sharding_test,
235+
world_size=world_size,
236+
local_size=local_size,
237+
world_size_2D=world_size_2D,
238+
node_group_size=node_group_size,
239+
model_class=model_class,
240+
tables=self.tables if pooling == PoolingType.SUM else self.mean_tables,
241+
weighted_tables=self.weighted_tables if has_weighted_tables else None,
242+
embedding_groups=self.embedding_groups,
243+
sharders=sharders,
244+
backend=backend,
245+
optim=EmbOptimType.EXACT_SGD,
246+
constraints=constraints,
247+
qcomms_config=qcomms_config,
248+
variable_batch_size=variable_batch_size,
249+
apply_optimizer_in_backward_config=apply_optimizer_in_backward_config,
250+
variable_batch_per_feature=variable_batch_per_feature,
251+
global_constant_batch=global_constant_batch,
252+
use_inter_host_allreduce=use_inter_host_allreduce,
253+
allow_zero_batch_size=allow_zero_batch_size,
254+
custom_all_reduce=custom_all_reduce,
255+
use_offsets=use_offsets,
256+
indices_dtype=indices_dtype,
257+
offsets_dtype=offsets_dtype,
258+
lengths_dtype=lengths_dtype,
259+
random_seed=random_seed,
260+
)
261+
189262

190263
@skip_if_asan_class
191264
class ModelParallelBase(ModelParallelTestShared):

0 commit comments

Comments (0)