From 6f755487eef5f6dceba6560591081b5ce1e6628d Mon Sep 17 00:00:00 2001 From: Yifu Wang Date: Wed, 2 Oct 2024 11:35:03 -0700 Subject: [PATCH] [async-tp] fix a race condition that can cause data corruption [ghstack-poisoned] --- torch/distributed/_symmetric_memory/__init__.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/torch/distributed/_symmetric_memory/__init__.py b/torch/distributed/_symmetric_memory/__init__.py index 4773bbb930d8..042bd14f5984 100644 --- a/torch/distributed/_symmetric_memory/__init__.py +++ b/torch/distributed/_symmetric_memory/__init__.py @@ -156,6 +156,7 @@ def _pipelined_all_gather_and_consume( group_size = symm_mem.world_size rank = symm_mem.rank + symm_mem.barrier(channel=0) backend_stream = _get_backend_stream() backend_stream.wait_stream(torch.cuda.current_stream()) local_p2p_buf = symm_mem.get_buffer(rank, shard.shape, shard.dtype) @@ -169,7 +170,7 @@ def _pipelined_all_gather_and_consume( with torch.cuda.stream(backend_stream): local_p2p_buf.copy_(shard) - symm_mem.barrier(channel=0) + symm_mem.barrier(channel=1) torch.cuda.current_stream().wait_stream(backend_stream) # At this point, all ranks have copied their local shard to @@ -186,9 +187,8 @@ def _pipelined_all_gather_and_consume( chunks[remote_rank].copy_(remote_p2p_buf) shard_consumer(chunks[remote_rank], remote_rank) - with torch.cuda.stream(backend_stream): - symm_mem.barrier(channel=group_size % 2) torch.cuda.current_stream().wait_stream(backend_stream) + symm_mem.barrier(channel=0) def _pipelined_produce_and_all2all( @@ -212,6 +212,7 @@ def _pipelined_produce_and_all2all( group_size = symm_mem.world_size rank = symm_mem.rank + symm_mem.barrier(channel=0) backend_stream = _get_backend_stream() backend_stream.wait_stream(torch.cuda.current_stream()) @@ -251,6 +252,7 @@ def get_p2p_buf(rank: int, idx: int) -> torch.Tensor: chunk_producer(rank, out_chunks[rank]) torch.cuda.current_stream().wait_stream(backend_stream) + symm_mem.barrier(channel=0) lib = torch.library.Library("symm_mem", "DEF") # noqa: TOR901