[FSDP] Enable async collectives in FSDP with MPI backend for compute/comm and comm/comm overlap by nariaki3551 · Pull Request #153215 · pytorch/pytorch · GitHub

[FSDP] Enable async collectives in FSDP with MPI backend for compute/comm and comm/comm overlap #153215


Open · nariaki3551 wants to merge 5 commits into pytorch:main
5 changes: 5 additions & 0 deletions torch/distributed/fsdp/_common_utils.py
@@ -105,6 +105,11 @@ def __getattribute__(self, name: str, /) -> Any:


class _FSDPState(_State):
# NOTE: These are used as tokens to ensure that there is at most one in-flight
# all-gather and at most one in-flight reduce-scatter during the FSDP pass.
_all_gather_work_handle: Optional["FlatParamHandle"] = None
_reduce_scatter_work_handle: Optional["FlatParamHandle"] = None

def __init__(self) -> None:
# TODO: Move all the attributes to this class to enable typing for
# FSDP/fully_shard.
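The two class-level attributes act as tokens guaranteeing at most one in-flight all-gather and one in-flight reduce-scatter at a time: before a new async collective is recorded on a handle, whichever handle currently owns the token is drained. A minimal sketch of that invariant, under the assumption that `_Handle` and `record_all_gather` are hypothetical stand-ins and only the attribute and method names mirror the diff:

```python
from typing import Optional

import torch.distributed as dist


class _Handle:
    """Hypothetical stand-in for FlatParamHandle: owns at most one in-flight all-gather."""

    def __init__(self) -> None:
        self._all_gather_work: Optional[dist.Work] = None

    def wait_all_gather_work(self) -> None:
        # No-op when nothing is in flight, mirroring the diff's wait_all_gather_work.
        if self._all_gather_work is None:
            return
        self._all_gather_work.wait()
        self._all_gather_work = None


# Module-level token standing in for _FSDPState._all_gather_work_handle.
_all_gather_work_handle: Optional[_Handle] = None


def record_all_gather(handle: _Handle, work: dist.Work) -> None:
    """Drain the previously recorded all-gather (if any) before tracking a new one."""
    global _all_gather_work_handle
    if _all_gather_work_handle is not None:
        _all_gather_work_handle.wait_all_gather_work()
        _all_gather_work_handle = None
    handle._all_gather_work = work
    _all_gather_work_handle = handle
```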
37 changes: 35 additions & 2 deletions torch/distributed/fsdp/_flat_param.py
@@ -574,6 +574,10 @@ def __init__(
self._needs_pre_backward_unshard = False
# Was the handle prefetched? Set on successful _prefetch_handle and unshard
self._prefetched = False
# Allgather work for unsharding
self._all_gather_work: Optional[dist.Work] = None
# Reduce scatter work for gradient reduction
self._reduce_scatter_work: Optional[dist.Work] = None
# Optimistically assume a valid input `params` and set dtype attributes
# before `_init_flat_param()`, which performs the actual validation
self._orig_param_dtype = params[0].dtype
@@ -1346,6 +1350,20 @@ def unshard(self):
padded_unsharded_flat_param = self._all_gather_flat_param(unsharded_flat_param)
self._use_unsharded_flat_param(padded_unsharded_flat_param)

def wait_all_gather_work(self):
"""Wait for the unshard work to complete."""
if self._all_gather_work is None:
return # no-op when there is no all gather work to wait for
self._all_gather_work.wait()
self._all_gather_work = None

def wait_reduce_scatter_work(self):
"""Wait for the reduce scatter work to complete."""
if self._reduce_scatter_work is None:
return # no-op when there is no reduce scatter work to wait for
self._reduce_scatter_work.wait()
self._reduce_scatter_work = None

def needs_unshard(self) -> bool:
"""Return if the handle's flat parameter needs to be unsharded."""
if not self.uses_sharded_strategy:
@@ -1427,6 +1445,15 @@ def _all_gather_flat_param(
else self.process_group
)

if dist.get_backend() == "mpi":
# Set async_op to True to delay the wait for all_gather_work for unshard.
# This enables overlapping of communication and computation, and overlapping of
# reduce_scatter (for gradient collection) and all_gather (for unshard) when using
# the MPI backend, improving utilization of compute and network resources.
async_op = True
else:
async_op = False

# HACK this should be handled by C10D
if sharded_flat_param.is_cpu: # type: ignore[attr-defined]
tensor_list = list(
@@ -1435,13 +1462,19 @@
dist.get_world_size(pg), # type: ignore[arg-type]
)
)
dist.all_gather(tensor_list, sharded_flat_param, group=pg)
all_gather_work = dist.all_gather(
tensor_list, sharded_flat_param, group=pg, async_op=async_op
)
else:
dist.all_gather_into_tensor(
all_gather_work = dist.all_gather_into_tensor(
padded_unsharded_flat_param,
sharded_flat_param,
pg,
async_op=async_op,
)
if async_op:
self.wait_all_gather_work()
self._all_gather_work = all_gather_work

if self._offload_params:
# In case of offloading, `flat_param.data` (i.e. sharded param) is
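For reference, below is a standalone sketch of the async pattern the CPU branch of `_all_gather_flat_param` now uses: launch `dist.all_gather` with `async_op=True`, hold on to the returned `dist.Work`, and wait on it later (which is what `wait_all_gather_work` does). The tensor shapes and launch command are illustrative; it assumes a PyTorch build with MPI support, started under `mpirun`.

```python
import torch
import torch.distributed as dist

# Illustrative only: assumes a PyTorch build with MPI support,
# launched as e.g. `mpirun -np 2 python demo.py`.
dist.init_process_group(backend="mpi")
world_size = dist.get_world_size()

# Each rank contributes one shard of a flat parameter.
shard = torch.full((4,), float(dist.get_rank()))
gathered = [torch.empty(4) for _ in range(world_size)]

# async_op=True returns a dist.Work handle instead of blocking, so the caller
# can overlap computation (or another collective) with the all-gather.
work = dist.all_gather(gathered, shard, async_op=True)

# ... overlapped work would run here ...

work.wait()  # analogous to FlatParamHandle.wait_all_gather_work()
print(torch.cat(gathered))
dist.destroy_process_group()
```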
44 changes: 42 additions & 2 deletions torch/distributed/fsdp/_runtime_utils.py
@@ -290,16 +290,22 @@ def _unshard(
ran_pre_unshard = handle.pre_unshard()
if ran_pre_unshard:
unshard_stream.wait_stream(pre_unshard_stream)
if dist.get_backend() == "mpi":
pre_unshard_stream.synchronize()
if state.limit_all_gathers:
event = state._free_event_queue.dequeue_if_needed()
if event:
with torch.profiler.record_function(
"FullyShardedDataParallel.rate_limiter"
):
event.synchronize()
if _FSDPState._all_gather_work_handle is not None:
_FSDPState._all_gather_work_handle.wait_all_gather_work()
_FSDPState._all_gather_work_handle = None
with state._device_handle.stream(unshard_stream):
handle.unshard()
handle.post_unshard()
_FSDPState._all_gather_work_handle = handle


@no_type_check
@@ -400,6 +406,10 @@ def _pre_forward(
input_dtype: Optional[torch.dtype] = state.mixed_precision.param_dtype
args, kwargs = _cast_forward_inputs(input_dtype, *args, **kwargs)
_register_post_backward_reshard_only_hook(state, handle, args, kwargs)
_p_assert(
handle is None or handle._all_gather_work is None,
"handle.wait_all_gather_work() must be called to ensure unshard work is completed",
)
return args, kwargs


@@ -424,6 +434,8 @@ def _pre_forward_unshard(
state._unshard_event = None
else:
current_stream.wait_stream(state._unshard_stream)
# wait for the unshard to complete
handle.wait_all_gather_work()
with torch.profiler.record_function(
"FullyShardedDataParallel._pre_forward_prefetch"
):
@@ -679,7 +691,8 @@ def _pre_backward_hook(
# Don't wait during trace
if not torch.distributed._functional_collectives.is_torchdynamo_compiling():
state._device_handle.current_stream().wait_stream(state._unshard_stream)

# Explicitly wait to ensure unshard operation has completed
handle.wait_all_gather_work()
# Set this to `False` to ensure that a mistargeted prefetch does not
# actually unshard these handles
handle._needs_pre_backward_unshard = False
@@ -689,6 +702,10 @@
_prefetch_handle(state, handle, _PrefetchMode.BACKWARD)
handle.prepare_gradient_for_backward()
handle._ran_pre_backward_hook = True
_p_assert(
handle is None or handle._all_gather_work is None,
"handle.wait_all_gather_work() must be called to ensure unshard work is completed",
)
return grad


@@ -747,6 +764,10 @@ def _post_backward_hook(
state._post_backward_stream.wait_stream(
state._device_handle.current_stream()
)
if dist.get_backend() == "mpi":
# wait for the backward computation to complete
current_stream = state._device_handle.current_stream()
current_stream.synchronize()

with state._device_handle.stream(state._post_backward_stream):
autograd_computed_grad = flat_param.grad.data
@@ -851,11 +872,25 @@ def _reduce_grad(state: _FSDPState, handle: FlatParamHandle) -> None:
if handle._use_fake_reduce
else state.process_group
)
dist.reduce_scatter_tensor(
if _FSDPState._reduce_scatter_work_handle is not None:
_FSDPState._reduce_scatter_work_handle.wait_reduce_scatter_work()
_FSDPState._reduce_scatter_work_handle = None
if dist.get_backend() == "mpi" and not uses_hybrid_sharded_strategy:
# Set async_op to True to delay the wait for reduce_scatter work for gradient collection.
# This enables overlapping of communication and computation, and overlapping of
# reduce_scatter (for gradient collection) and all_gather (for unshard) when using
# the MPI backend, improving utilization of compute and network resources.
async_op = True
else:
async_op = False
reduce_scatter_work = dist.reduce_scatter_tensor(
new_sharded_grad,
padded_unsharded_grad,
group=pg,
async_op=async_op,
)
handle._reduce_scatter_work = reduce_scatter_work
_FSDPState._reduce_scatter_work_handle = handle
if uses_hybrid_sharded_strategy:
# Don't wait during trace
if not torch.distributed._functional_collectives.is_torchdynamo_compiling():
@@ -1098,6 +1133,9 @@ def _post_backward_final_callback(
# since it currently runs in the post-backward stream. That can be
# pushed to the next forward if run in a different stream
current_stream.wait_stream(root_state._post_backward_stream)
if _FSDPState._reduce_scatter_work_handle is not None:
_FSDPState._reduce_scatter_work_handle.wait_reduce_scatter_work()
_FSDPState._reduce_scatter_work_handle = None
if root_state._all_reduce_stream is not current_stream: # uses HSDP
current_stream.wait_stream(root_state._all_reduce_stream)
if root_state.cpu_offload.offload_params:
@@ -1552,6 +1590,8 @@ def _wait_for_computation_stream(
# do not leverage the pre-all-gather stream is tolerable since this only
# runs once per iteration
pre_unshard_stream.wait_stream(computation_stream) # type: ignore[attr-defined]
if dist.get_backend() == "mpi":
computation_stream.synchronize()


def _reset_flat_param_grad_info_if_needed(
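Finally, a sketch of the comm/comm overlap the PR is after: the reduce-scatter for one layer's gradients is launched asynchronously, the all-gather for the parameters needed next is issued while it is still in flight, and both are waited on before their results are consumed. Shapes are illustrative; it assumes an MPI-enabled build whose backend accepts `async_op=True` for both `reduce_scatter_tensor` and `all_gather_into_tensor`, as the diff does.

```python
import torch
import torch.distributed as dist

# Illustrative only: assumes an MPI-enabled PyTorch build (launched via mpirun)
# whose backend implements reduce_scatter_tensor and all_gather_into_tensor.
dist.init_process_group(backend="mpi")
world_size = dist.get_world_size()

# Reduce-scatter the just-computed unsharded gradient of the current layer.
unsharded_grad = torch.randn(4 * world_size)
sharded_grad = torch.empty(4)
rs_work = dist.reduce_scatter_tensor(sharded_grad, unsharded_grad, async_op=True)

# While that is in flight, all-gather (unshard) the parameters of the next
# layer to be processed: the comm/comm overlap.
next_shard = torch.randn(4)
next_unsharded = torch.empty(4 * world_size)
ag_work = dist.all_gather_into_tensor(next_unsharded, next_shard, async_op=True)

# Backward computation for that layer could also run here (compute/comm overlap).

ag_work.wait()  # before the unsharded parameters are used
rs_work.wait()  # before the optimizer reads the sharded gradient
dist.destroy_process_group()
```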