[async-tp] fix a race condition that can cause data corruption · pytorch/pytorch@2909b0f · GitHub
[go: up one dir, main page]

Skip to content

Commit 2909b0f

Browse files
author
Yifu Wang
committed
[async-tp] fix a race condition that can cause data corruption
ghstack-source-id: 25ae32a
Pull Request resolved: #137199
1 parent 0d1701f commit 2909b0f

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

torch/distributed/_symmetric_memory/__init__.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,7 @@ def _pipelined_all_gather_and_consume(
156156
group_size = symm_mem.world_size
157157
rank = symm_mem.rank
158158

159+
symm_mem.barrier(channel=0)
159160
backend_stream = _get_backend_stream()
160161
backend_stream.wait_stream(torch.cuda.current_stream())
161162
local_p2p_buf = symm_mem.get_buffer(rank, shard.shape, shard.dtype)
@@ -169,7 +170,7 @@ def _pipelined_all_gather_and_consume(
169170

170171
with torch.cuda.stream(backend_stream):
171172
local_p2p_buf.copy_(shard)
172-
symm_mem.barrier(channel=0)
173+
symm_mem.barrier(channel=1)
173174
torch.cuda.current_stream().wait_stream(backend_stream)
174175

175176
# At this point, all ranks have copied their local shard to
@@ -186,9 +187,8 @@ def _pipelined_all_gather_and_consume(
186187
chunks[remote_rank].copy_(remote_p2p_buf)
187188
shard_consumer(chunks[remote_rank], remote_rank)
188189

189-
with torch.cuda.stream(backend_stream):
190-
symm_mem.barrier(channel=group_size % 2)
191190
torch.cuda.current_stream().wait_stream(backend_stream)
191+
symm_mem.barrier(channel=0)
192192

193193

194194
def _pipelined_produce_and_all2all(
@@ -212,6 +212,7 @@ def _pipelined_produce_and_all2all(
212212
group_size = symm_mem.world_size
213213
rank = symm_mem.rank
214214

215+
symm_mem.barrier(channel=0)
215216
backend_stream = _get_backend_stream()
216217
backend_stream.wait_stream(torch.cuda.current_stream())
217218

@@ -251,6 +252,7 @@ def get_p2p_buf(rank: int, idx: int) -> torch.Tensor:
251252

252253
chunk_producer(rank, out_chunks[rank])
253254
torch.cuda.current_stream().wait_stream(backend_stream)
255+
symm_mem.barrier(channel=0)
254256

255257

256258
lib = torch.library.Library("symm_mem", "DEF") # noqa: TOR901

0 commit comments

Comments (0)