[shard] use gather_object in gather API (#71624) · pytorch/pytorch@d0f9556 · GitHub

Commit d0f9556

Wanchao Liang authored and pytorchmergebot committed
[shard] use gather_object in gather API (#71624)
Summary: Pull Request resolved: #71624

Now that we have gather available in the NCCL process group, we can switch `sharded_tensor.gather` to use gather_object instead of all_gather_object, which reduces the communication overhead.

TODO: To further reduce the communication overhead, we need to figure out a way to avoid `gather_object` altogether, as both `gather_object` and `all_gather_object` incur a pickling copy between devices.

ghstack-source-id: 151007578

Test Plan: wait for CI

Reviewed By: pritamdamania87

Differential Revision: D33688907

fbshipit-source-id: 2073c5a46c33a7a2640a9e3599dc795d9e4c0a1e

(cherry picked from commit dbc983a)
1 parent 5b805a6 commit d0f9556
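
To make the switch concrete, here is a minimal, hedged sketch of the collection pattern the commit adopts: `dist.gather_object` delivers each rank's picklable object only to the destination rank, whereas the previous `dist.all_gather_object` replicated every object to every rank. The helper name and arguments below are illustrative, not part of the PyTorch source.

```python
# Illustrative sketch only (not the PyTorch source): the gather_object pattern
# adopted by this commit, assuming torch.distributed is already initialized.
from typing import Any, List, Optional

import torch.distributed as dist


def gather_objects_to_dst(local_obj: Any, dst: int = 0, group=None) -> Optional[List[Any]]:
    """Collect one picklable object per rank onto the dst rank only."""
    rank = dist.get_rank(group)
    world_size = dist.get_world_size(group)

    # Only the destination rank needs a receive buffer; other ranks pass an
    # empty list, mirroring the updated gather() code in the diff below.
    gathered: List[Optional[Any]] = [None] * world_size if rank == dst else []

    dist.gather_object(
        obj=local_obj,                # any picklable payload (here: a list of shards)
        object_gather_list=gathered,  # populated only on `dst`
        dst=dst,
        group=group,
    )
    # dst receives one entry per rank; every other rank gets nothing back.
    return gathered if rank == dst else None
```

Compared with all_gather_object, only one rank materializes the full list of objects, which is where the communication (and memory) saving described in the summary comes from.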

File tree

  • torch/distributed/_shard/sharded_tensor

1 file changed: +9 −7 lines changed

torch/distributed/_shard/sharded_tensor/api.py

Lines changed: 9 additions & 7 deletions
@@ -280,17 +280,19 @@ def gather(

         world_size = dist.get_world_size(self._process_group)

-        gathered_shards = [None] * world_size
-        # will revise this part with CPU support and use dist.gather()
-        # once NCCL support for gather() is ready
-        # https://github.com/pytorch/pytorch/issues/66187
-        dist.all_gather_object(
+        gathered_shards: List[Optional[List[Shard]]] = [None] * world_size if rank == dst else []
+        # TODO: see how we could use dist.gather() instead of dist.gather_object
+        # as the latter one involves pickling on CPU, see more context
+        # https://github.com/pytorch/pytorch/issues/73935
+        dist.gather_object(
             obj=local_shards,
-            object_list=gathered_shards,
+            object_gather_list=gathered_shards,
+            dst=dst,
             group=self._process_group,
         )
-
         if rank == dst:
+            if out is None:
+                raise ValueError("`out` Tensor must be provided on dst rank!")
             dims = len(full_size)
             for shards in gathered_shards:
                 if shards is None:
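
For context, here is a hedged usage sketch of the gather() method this hunk modifies, assuming the `gather(dst=0, out=None)` signature implied by the diff: the pre-allocated `out` tensor is required only on the destination rank and may stay None elsewhere. The names `st`, `full_size`, and `gather_to_rank0` are illustrative, not from the commit.

```python
# Hedged usage sketch, not part of the commit: calling ShardedTensor.gather()
# after this change. Assumes the gather(dst=0, out=None) signature shown above.
import torch
import torch.distributed as dist


def gather_to_rank0(st, full_size, device="cuda"):
    rank = dist.get_rank()
    # Per the new check in the diff, `out` must be provided on the dst rank;
    # it can be None on every other rank. For an NCCL-backed process group the
    # buffer should live on the local GPU.
    out = torch.empty(full_size, device=device) if rank == 0 else None
    st.gather(dst=0, out=out)
    return out  # holds the re-assembled tensor on rank 0, None elsewhere
```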

0 commit comments