[4/N] Test NaN checker against broadcast (#134701) · pytorch/pytorch@d9d95dc · GitHub
[go: up one dir, main page]

Skip to content

Commit d9d95dc

Browse files
kwen2501 authored and pytorchmergebot committed
[4/N] Test NaN checker against broadcast (#134701)
Pull Request resolved: #134701. Approved by: https://github.com/wconstab. ghstack dependencies: #134345, #134357.
1 parent ab646cd commit d9d95dc

File tree

1 file changed

+7
-3
lines changed

1 file changed

+7
-3
lines changed

test/distributed/test_c10d_nccl.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -369,7 +369,7 @@ def test_nan_assert(self, type):
369369

370370
@requires_nccl()
371371
@skip_if_lt_x_gpu(2)
372-
def test_nan_p2p(self):
372+
def test_nan_rank_filter(self):
373373
# Putting NaN at recv buffer, program should not fail as NaN checker
374374
# should not check on receive buffer
375375
os.environ["TORCH_NCCL_NAN_CHECK"] = "1"
@@ -379,11 +379,15 @@ def test_nan_p2p(self):
379379
backend="nccl", store=store, rank=self.rank, world_size=self.world_size
380380
)
381381
t = torch.ones(3, 4, dtype=torch.bfloat16, device=device)
382+
if self.rank != 0:
383+
# Putting NaN at recv buffer
384+
t[1, 1] = float("nan")
385+
# Against broadcast
386+
c10d.broadcast(t, 0)
387+
# Against P2P
382388
if self.rank == 0:
383389
c10d.send(t, 1)
384390
elif self.rank == 1:
385-
# Putting NaN at recv buffer
386-
t[1, 1] = float("nan")
387391
c10d.recv(t, 0)
388392
c10d.destroy_process_group()
389393
# reset env

0 commit comments

Comments
 (0)
0