[BE]: Remove redundant copy in torch chunk shard (#144269) · pytorch/pytorch@b5cf8e2 · GitHub
Commit b5cf8e2

Skylion007 authored and pytorchmergebot committed
[BE]: Remove redundant copy in torch chunk shard (#144269)
Fixes an issue noticed in a recent all_gather PR: some parts of the codebase perform a double copy via `clone().contiguous()`, which can be fused into a single copy op.

Pull Request resolved: #144269
Approved by: https://github.com/awgu
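For illustration, here is a minimal standalone sketch of the fusion this commit applies. It is not from the PR itself; the tensor `x` and the surrounding setup are hypothetical, chosen only to make the copy counts observable:

```python
import torch

# Hypothetical input: a transposed view, hence non-contiguous.
x = torch.arange(12.0).reshape(3, 4).t()
assert not x.is_contiguous()

# Before: clone() defaults to memory_format=torch.preserve_format, so on a
# non-contiguous input it produces a non-contiguous copy; .contiguous() then
# allocates and fills a second, contiguous copy.
two_copies = x.clone().contiguous()

# After: clone() writes directly into freshly allocated contiguous memory,
# performing a single copy.
one_copy = x.clone(memory_format=torch.contiguous_format)

# Both forms yield the same values and a contiguous result.
assert two_copies.is_contiguous() and one_copy.is_contiguous()
assert torch.equal(two_copies, one_copy)
```

Note that for an already-contiguous input, `.contiguous()` is a no-op, so the double copy only occurs on non-contiguous shards; the fused form is never worse.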
1 parent 1b8a943 · commit b5cf8e2

1 file changed: +3, −1 lines

torch/distributed/_shard/sharding_spec/chunk_sharding_spec.py

Lines changed: 3 additions & 1 deletion

```diff
@@ -162,7 +162,9 @@ def shard(
                     narrowed_tensor.detach().clone().resize_(scatter_shape)
                 )
             else:
-                tensor_to_scatter = narrowed_tensor.detach().clone().contiguous()
+                tensor_to_scatter = narrowed_tensor.detach().clone(
+                    memory_format=torch.contiguous_format
+                )
 
             tensors_to_scatter[
                 dist.get_group_rank(process_group, remote_global_rank)
```
