@@ -156,6 +156,7 @@ def _pipelined_all_gather_and_consume(
     group_size = symm_mem.world_size
     rank = symm_mem.rank

+    symm_mem.barrier(channel=0)
     backend_stream = _get_backend_stream()
     backend_stream.wait_stream(torch.cuda.current_stream())
     local_p2p_buf = symm_mem.get_buffer(rank, shard.shape, shard.dtype)
@@ -169,7 +170,7 @@ def _pipelined_all_gather_and_consume(

     with torch.cuda.stream(backend_stream):
         local_p2p_buf.copy_(shard)
-        symm_mem.barrier(channel=0)
+        symm_mem.barrier(channel=1)
     torch.cuda.current_stream().wait_stream(backend_stream)

     # At this point, all ranks have copied their local shard to
@@ -186,9 +187,8 @@ def _pipelined_all_gather_and_consume(
             chunks[remote_rank].copy_(remote_p2p_buf)
             shard_consumer(chunks[remote_rank], remote_rank)

-    with torch.cuda.stream(backend_stream):
-        symm_mem.barrier(channel=group_size % 2)
     torch.cuda.current_stream().wait_stream(backend_stream)
+    symm_mem.barrier(channel=0)


 def _pipelined_produce_and_all2all(
@@ -212,6 +212,7 @@ def _pipelined_produce_and_all2all(
     group_size = symm_mem.world_size
     rank = symm_mem.rank

+    symm_mem.barrier(channel=0)
     backend_stream = _get_backend_stream()
     backend_stream.wait_stream(torch.cuda.current_stream())

@@ -251,6 +252,7 @@ def get_p2p_buf(rank: int, idx: int) -> torch.Tensor:

     chunk_producer(rank, out_chunks[rank])
     torch.cuda.current_stream().wait_stream(backend_stream)
+    symm_mem.barrier(channel=0)


 lib = torch.library.Library("symm_mem", "DEF")  # noqa: TOR901
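For reference, the barriers in these hunks follow a channel-separated pattern: channel 0 fences the symmetric-memory workspace at the entry and exit of the collective, while the barrier issued on backend_stream (which signals that the local shard has landed in the local p2p buffer) uses channel 1 so it cannot alias the channel-0 fences. The sketch below condenses that synchronization skeleton into a standalone function; get_symm_mem_workspace, the import path, and the single-stream copy loop are illustrative assumptions and are not part of this diff.

# Sketch only: mirrors the barrier/stream ordering shown in the hunks above.
# get_symm_mem_workspace and the single-stream copy loop are illustrative
# assumptions, not part of this change.
import torch
from torch.distributed._symmetric_memory import (
    _get_backend_stream,
    get_symm_mem_workspace,
)


def _all_gather_sync_sketch(
    shard: torch.Tensor, ag_out: torch.Tensor, group_name: str
) -> None:
    symm_mem = get_symm_mem_workspace(
        group_name, min_size=shard.numel() * shard.element_size()
    )
    rank, group_size = symm_mem.rank, symm_mem.world_size

    # Channel 0 at entry: all ranks arrive before anyone writes the workspace.
    symm_mem.barrier(channel=0)

    backend_stream = _get_backend_stream()
    backend_stream.wait_stream(torch.cuda.current_stream())
    local_p2p_buf = symm_mem.get_buffer(rank, shard.shape, shard.dtype)
    chunks = ag_out.chunk(group_size)

    with torch.cuda.stream(backend_stream):
        local_p2p_buf.copy_(shard)
        # Channel 1: "my shard is in my p2p buffer". A separate channel keeps
        # this signal independent of the channel-0 entry/exit fences.
        symm_mem.barrier(channel=1)
    torch.cuda.current_stream().wait_stream(backend_stream)

    # Safe to read peers' buffers: the channel-1 barrier ordered their copies
    # before this point.
    for step in range(1, group_size):
        remote_rank = (step + rank) % group_size
        remote_p2p_buf = symm_mem.get_buffer(remote_rank, shard.shape, shard.dtype)
        chunks[remote_rank].copy_(remote_p2p_buf)

    # Channel 0 at exit: no rank leaves (and lets the workspace be reused)
    # while a peer may still be reading its buffer.
    symm_mem.barrier(channel=0)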