[c10d] Add support for testing SIGABRT return (#153167) · pytorch/pytorch@03e102d · GitHub

Commit 03e102d
kwen2501 authored and pytorchmergebot committed
[c10d] Add support for testing SIGABRT return (#153167)
`SIGABRT` is a common exit signal for *negative* distributed tests, i.e. tests that check the effectiveness of the NaN assert, the watchdog throw, etc. These errors are not detectable with traditional statements like `with self.assertRaises(RuntimeError)`. Instead, we need to check the process's return code: a process killed by `SIGABRT(6)`, for example, has a return code of -6.

Pull Request resolved: #153167
Approved by: https://github.com/fduwjj
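The check relies on standard POSIX semantics as surfaced by Python's `multiprocessing`: a child terminated by signal N reports `exitcode == -N`. A minimal standalone sketch (illustrative only, not part of this commit):

    import multiprocessing
    import os
    import signal


    def crash():
        # os.abort() raises SIGABRT in the child, much like a failed
        # c10d NaN assert or a watchdog-triggered abort would.
        os.abort()


    if __name__ == "__main__":
        p = multiprocessing.Process(target=crash)
        p.start()
        p.join()
        # A child killed by signal N has exitcode -N, so SIGABRT(6) -> -6.
        assert p.exitcode == -signal.SIGABRT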
1 parent 10c51b1 commit 03e102d

File tree

2 files changed: +63 −152 lines changed


test/distributed/test_c10d_nccl.py

Lines changed: 37 additions & 127 deletions
@@ -44,13 +44,11 @@
     get_timeout,
     init_multigpu_helper,
     MultiProcessTestCase,
-    requires_gloo,
     requires_multicast_support,
     requires_nccl,
     requires_nccl_version,
     skip_if_lt_x_gpu,
     skip_if_rocm_multiprocess,
-    sm_is_or_higher_than,
     TEST_SKIPS,
     with_dist_debug_levels,
     with_nccl_blocking_wait,
@@ -284,16 +282,17 @@ def opts(self, high_priority_stream=False):
 
     def setUp(self):
         super().setUp()
-        # Need to skip return code checking for these tests since the child
-        # processes don't exit cleanly in some cuda versions
-        self.skip_return_code_checks = [
-            self.test_nan_assert_float16.__wrapped__,
-            self.test_nan_assert_float32.__wrapped__,
-            self.test_nan_assert_float64.__wrapped__,
-            self.test_nan_assert_bfloat16.__wrapped__,
-            self.test_nan_assert_float8_e4m3fn.__wrapped__,
-            self.test_nan_assert_float8_e5m2.__wrapped__,
-        ]
+
+        # These tests are expected to throw SIGABRT(6); adding the negative sign
+        # bc the test return code is actually -6
+        self.special_return_code_checks = {
+            self.test_nan_assert_float16.__wrapped__: -signal.SIGABRT,
+            self.test_nan_assert_float32.__wrapped__: -signal.SIGABRT,
+            self.test_nan_assert_float64.__wrapped__: -signal.SIGABRT,
+            self.test_nan_assert_bfloat16.__wrapped__: -signal.SIGABRT,
+            self.test_nan_assert_float8_e4m3fn.__wrapped__: -signal.SIGABRT,
+            self.test_nan_assert_float8_e5m2.__wrapped__: -signal.SIGABRT,
+        }
 
         # TORCH_NCCL_BLOCKING_WAIT overrides TORCH_NCCL_ASYNC_ERROR_HANDLING hence tests
         # that use TORCH_NCCL_BLOCKING_WAIT will test it as expected.
@@ -534,14 +533,14 @@ def test_nan_assert(self, type):
 
         # confirm enable/disable flag works
         backend._set_enable_nan_check(False)
-        pg.allreduce(nan_tensor)
+        # Note: using all-gather here bc some NCCL/SM version does not support
+        # FP8 reduction
+        pg._allgather_base(output, nan_tensor)
 
         backend._set_enable_nan_check(True)
-        with self.assertRaises(RuntimeError):
-            # Note: using all-gather here bc FP8 types do not support reduce ops
-            # at the moment
-            pg._allgather_base(output, nan_tensor)
+        pg._allgather_base(output, nan_tensor)
         dist.destroy_process_group()
+
 
         # reset env
         os.environ["TORCH_NCCL_NAN_CHECK"] = "0"
@@ -576,16 +575,13 @@ def test_nan_rank_filter(self):
     def test_nan_check(self):
         # Not expecting an error, NaN check should not make legit code fail
         device = torch.device(f"cuda:{self.rank:d}")
-        if not sm_is_or_higher_than(device, 8, 0):
-            self.skipTest("bf16 requires sm >= 8.0")
-
         os.environ["TORCH_NCCL_NAN_CHECK"] = "1"
         store = c10d.FileStore(self.file_name, self.world_size)
         c10d.init_process_group(
             backend="nccl", store=store, rank=self.rank, world_size=self.world_size
         )
-        x = torch.ones((10,), dtype=torch.bfloat16, device=device) * self.rank
-        t = torch.ones(3, 4, dtype=torch.bfloat16, device=device)
+        x = torch.ones((10,), device=device) * self.rank
+        t = torch.ones(3, 4, device=device)
         c10d.broadcast(x, src=0)
         c10d.all_reduce(t)
         c10d.barrier()
@@ -2775,14 +2771,6 @@ def hook(work_info: torch._C._distributed_c10d.WorkInfo):
 class NcclErrorHandlingTest(MultiProcessTestCase):
     def setUp(self):
         super().setUp()
-        # Need to skip return code checking for these tests since the child
-        # processes don't exit cleanly.
-        self.skip_return_code_checks = [
-            self.test_nccl_errors_blocking_abort.__wrapped__,
-            self.test_nccl_errors_blocking_sigkill.__wrapped__,
-            self.test_nccl_errors_blocking_sigterm.__wrapped__,
-            self.test_nccl_errors_blocking_nonzero_exit.__wrapped__,
-        ]
         # TORCH_NCCL_BLOCKING_WAIT overrides TORCH_NCCL_ASYNC_ERROR_HANDLING hence tests
         # that use TORCH_NCCL_BLOCKING_WAIT will test it as expected.
         os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "1"
@@ -2810,12 +2798,19 @@ def blocking_wait_error_msg(self):
     def _run_all_reduce(self, pg):
         pg.allreduce(torch.rand(10).cuda(self.rank))
 
+    def _reduce_timeout(self):
+        # set heartbeat timeout to a small value so that we don't wait too long
+        # for things to shutdown
+        os.environ["TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC"] = "4"
+        os.environ["TORCH_NCCL_WAIT_TIMEOUT_DUMP_MILSEC"] = "1000"
+
     @requires_nccl()
     @requires_nccl_version((2, 4, 0), "Need NCCL 2.4+ for error checking")
     @skip_if_lt_x_gpu(3)
     @skip_if_rocm_multiprocess
     @skip_but_pass_in_sandcastle("Test does not pass when run locally")
     def test_nccl_errors_nonblocking(self):
+        self._reduce_timeout()
         # Note: we unset and restore TORCH_NCCL_ASYNC_ERROR_HANDLING for this test
         # since test_c10d_common runs with async error handling by default, but this
         # tests behavior when it is not enabled.
@@ -2846,30 +2841,24 @@ def test_nccl_errors_nonblocking(self):
                 "TORCH_NCCL_ASYNC_ERROR_HANDLING"
             ] = prev_nccl_async_error_handling
 
-    def _test_nccl_errors_blocking(self, func):
+    @requires_nccl()
+    @skip_if_lt_x_gpu(3)
+    @skip_if_rocm_multiprocess
+    def test_nccl_errors_blocking(self):
+        self._reduce_timeout()
+        os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "0"
         store = c10d.FileStore(self.file_name, self.world_size)
         process_group = c10d.ProcessGroupNCCL(
             store,
             self.rank,
             self.world_size,
-            timeout=timedelta(seconds=10),
         )
-        process_group.allreduce(torch.rand(10).cuda(self.rank))
+        x = torch.rand(1024 * 1024).cuda(self.rank)
+        process_group.allreduce(x)
         if self.rank == 0:
-            work = process_group.allreduce(torch.rand(10).cuda(self.rank))
+            work = process_group.allreduce(x)
             with self.assertRaisesRegex(dist.DistBackendError, ""):
-                # It seems the error message would be different depending on
-                # whether the test is run on CI machine and devGPU. Skipping
-                # the error message check to make both sides happy.
                 work.wait(timeout=timedelta(seconds=self.op_timeout_sec))
-            # Run some GPU operations to make sure cuda has not gotten stuck.
-            # It was observed cuda could get stuck if NCCL communicators were
-            # not properly aborted before throwing RuntimeError.
-            torch.rand(10).cuda(self.rank)
-        elif self.rank == 1:
-            # Clean up structures (ex: files for FileStore before going down)
-            del process_group
-            func()
 
     def _test_barrier_error(self):
         store = c10d.FileStore(self.file_name, self.world_size)
@@ -2889,60 +2878,19 @@ def _test_barrier_error(self):
                 timeout=timedelta(seconds=self.op_timeout_sec)
             )
 
-    @with_nccl_blocking_wait
-    @requires_nccl()
-    @requires_nccl_version((2, 4, 0), "Need NCCL 2.4+ for error checking")
-    @skip_if_lt_x_gpu(3)
-    @skip_if_rocm_multiprocess
-    def test_nccl_errors_blocking_clean_exit(self):
-        self._test_nccl_errors_blocking(lambda: sys.exit(0))
-
-    @with_nccl_blocking_wait
-    @requires_nccl()
-    @requires_nccl_version((2, 4, 0), "Need NCCL 2.4+ for error checking")
-    @skip_if_lt_x_gpu(3)
-    @skip_if_rocm_multiprocess
-    def test_nccl_errors_blocking_nonzero_exit(self):
-        self._test_nccl_errors_blocking(lambda: sys.exit(1))
-
-    @with_nccl_blocking_wait
-    @requires_nccl()
-    @requires_nccl_version((2, 4, 0), "Need NCCL 2.4+ for error checking")
-    @skip_if_lt_x_gpu(3)
-    @skip_if_rocm_multiprocess
-    @skip_but_pass_in_sandcastle(
-        "Frequently times out see https://github.com/pytorch/pytorch/issues/58920"
-    )
-    def test_nccl_errors_blocking_abort(self):
-        self._test_nccl_errors_blocking(lambda: os.abort())
-
-    @with_nccl_blocking_wait
-    @requires_nccl()
-    @requires_nccl_version((2, 4, 0), "Need NCCL 2.4+ for error checking")
-    @skip_if_lt_x_gpu(3)
-    @skip_if_rocm_multiprocess
-    def test_nccl_errors_blocking_sigkill(self):
-        self._test_nccl_errors_blocking(lambda: os.kill(os.getpid(), signal.SIGKILL))
-
-    @with_nccl_blocking_wait
-    @requires_nccl()
-    @requires_nccl_version((2, 4, 0), "Need NCCL 2.4+ for error checking")
-    @skip_if_lt_x_gpu(3)
-    @skip_if_rocm_multiprocess
-    def test_nccl_errors_blocking_sigterm(self):
-        self._test_nccl_errors_blocking(lambda: os.kill(os.getpid(), signal.SIGTERM))
-
     @with_nccl_blocking_wait
     @requires_nccl()
     @requires_nccl_version((2, 4, 0), "Need NCCL 2.4+ for error checking")
     @skip_if_lt_x_gpu(3)
     def test_nccl_blocking_wait_with_barrier(self):
+        self._reduce_timeout()
         self._test_barrier_error()
 
     @requires_nccl()
     @requires_nccl_version((2, 4, 0), "Need NCCL 2.4+ for error checking")
     @skip_if_lt_x_gpu(3)
     def test_nccl_non_blocking_wait_with_barrier(self):
+        self._reduce_timeout()
         # test the barrier behavior in the non blocking wait setting
         prev_nccl_async_error_handling = os.environ.get(
             "TORCH_NCCL_ASYNC_ERROR_HANDLING", None
@@ -3013,6 +2961,7 @@ def assert_fut_success(fut):
     @skip_if_rocm_multiprocess
     @skip_if_lt_x_gpu(3)
     def test_restart_pg_after_error(self):
+        self._reduce_timeout()
         # test the barrier behavior in the non blocking wait setting
         prev_nccl_async_error_handling = os.environ.get(
             "TORCH_NCCL_ASYNC_ERROR_HANDLING", None
@@ -3102,45 +3051,6 @@ def test_invalid_nccl_blocking_wait_env(self):
         self._run_invalid_nccl_blocking_wait_env("2147483647")
         self._run_invalid_nccl_blocking_wait_env("4294967295")
 
-    @with_nccl_blocking_wait
-    @requires_nccl()
-    @requires_gloo()
-    @skip_if_lt_x_gpu(3)
-    def test_nccl_timeout(self):
-        store = c10d.FileStore(self.file_name, self.world_size)
-
-        # Initialize process_group.
-        process_group = c10d.ProcessGroupNCCL(
-            store, self.rank, self.world_size, timeout=timedelta(seconds=10)
-        )
-        # Control gloo pg used as go-ahead signal/barrier
-        # to coordinate btwn ranks.
-        pg_gloo = c10d.ProcessGroupGloo(store, self.rank, self.world_size)
-        failed_collective_timeout = timedelta(milliseconds=100)
-        process_group.allreduce(torch.rand(10).cuda(self.rank)).wait(
-            timeout=timedelta(seconds=5)
-        )
-
-        if self.rank == 0:
-            # This should timeout in about 1 second.
-            # Watchdog may abort timed out work resulting in NCCL error instead of operation timed out.
-            with self.assertRaisesRegex(
-                dist.DistBackendError, self.blocking_wait_error_msg
-            ):
-                process_group.allreduce(torch.rand(10).cuda(self.rank)).wait(
-                    timeout=failed_collective_timeout
-                )
-            # Now do a barrier to tell other rank to go ahead.
-            pg_gloo.barrier().wait()
-        else:
-            # Wait on rank 0 to fail.
-            try:
-                pg_gloo.barrier().wait()
-            except Exception as e:
-                raise ValueError(
-                    f"Rank {self.rank} barrier timed out waiting for rank 0 with error: {str(e)}"
-                ) from e
-
 
 class NcclUserBufferRegistrationTest(MultiProcessTestCase):
     def setUp(self):
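
For context, here is a hypothetical negative test written against the pattern above. The class and test names are invented for illustration; `MultiProcessTestCase` and the decorators are the real helpers imported at the top of this file:

    import signal

    from torch.testing._internal.common_distributed import (
        MultiProcessTestCase,
        requires_nccl,
        skip_if_lt_x_gpu,
    )


    class HypotheticalNegativeTest(MultiProcessTestCase):
        def setUp(self):
            super().setUp()
            # Register the expected SIGABRT exit. `.__wrapped__` is used
            # because the test method is wrapped by decorators such as
            # @skip_if_lt_x_gpu.
            self.special_return_code_checks = {
                self.test_abort_on_nan.__wrapped__: -signal.SIGABRT,
            }

        @requires_nccl()
        @skip_if_lt_x_gpu(2)
        def test_abort_on_nan(self):
            # Run a collective that trips the NaN assert; the child process
            # is expected to abort rather than raise a catchable exception.
            ...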

torch/testing/_internal/common_distributed.py

Lines changed: 26 additions & 25 deletions
@@ -642,7 +642,15 @@ def __init__(
 
     def setUp(self) -> None:
         super().setUp()
-        self.skip_return_code_checks = []  # type: ignore[var-annotated]
+
+        # Used for tests that are expected to return a non-0 exit code, such as
+        # SIGABRT thrown by watchdog.
+        self.special_return_code_checks: dict = {}
+
+        # Used for tests that may return any exit code, which makes it hard to
+        # check. This is rare, use with caution.
+        self.skip_return_code_checks: list = []
+
         self.processes = []  # type: ignore[var-annotated]
         self.rank = self.MAIN_PROCESS_RANK
         self.file_name = tempfile.NamedTemporaryFile(delete=False).name
@@ -862,28 +870,13 @@ def _join_processes(self, fn) -> None:
                     time.sleep(0.1)
 
             elapsed_time = time.time() - start_time
-
-            if fn in self.skip_return_code_checks:
-                self._check_no_test_errors(elapsed_time)
-            else:
-                self._check_return_codes(elapsed_time)
+            self._check_return_codes(fn, elapsed_time)
         finally:
             # Close all pipes
             for pipe in self.pid_to_pipe.values():
                 pipe.close()
 
-    def _check_no_test_errors(self, elapsed_time) -> None:
-        """
-        Checks that we didn't have any errors thrown in the child processes.
-        """
-        for i, p in enumerate(self.processes):
-            if p.exitcode is None:
-                raise RuntimeError(
-                    f"Process {i} timed out after {elapsed_time} seconds"
-                )
-            self.assertNotEqual(self.TEST_ERROR_EXIT_CODE, p.exitcode)
-
-    def _check_return_codes(self, elapsed_time) -> None:
+    def _check_return_codes(self, fn, elapsed_time) -> None:
         """
         Checks that the return codes of all spawned processes match, and skips
         tests if they returned a return code indicating a skipping condition.
@@ -925,11 +918,11 @@ def _check_return_codes(self, elapsed_time) -> None:
                 raise RuntimeError(
                     f"Process {i} terminated or timed out after {elapsed_time} seconds"
                 )
-            self.assertEqual(
-                p.exitcode,
-                first_process.exitcode,
-                msg=f"Expect process {i} exit code to match Process 0 exit code of {first_process.exitcode}, but got {p.exitcode}",
-            )
+
+        # Skip the test return code check
+        if fn in self.skip_return_code_checks:
+            return
+
 
         for skip in TEST_SKIPS.values():
             if first_process.exitcode == skip.exit_code:
@@ -945,10 +938,18 @@ def _check_return_codes(self, elapsed_time) -> None:
                     return
                 else:
                     raise unittest.SkipTest(skip.message)
+
+        # In most cases, we expect test to return exit code 0, standing for success.
+        expected_return_code = 0
+        # In some negative tests, we expect test to return non-zero exit code,
+        # such as watchdog throwing SIGABRT.
+        if fn in self.special_return_code_checks:
+            expected_return_code = self.special_return_code_checks[fn]
+
         self.assertEqual(
             first_process.exitcode,
-            0,
-            msg=f"Expected zero exit code but got {first_process.exitcode} for pid: {first_process.pid}",
+            expected_return_code,
+            msg=f"Expected exit code {expected_return_code} but got {first_process.exitcode} for pid: {first_process.pid}",
         )
 
     @property
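
Condensed, the return-code policy that `_check_return_codes` now implements is a three-way triage. The sketch below restates it outside the harness, using the same attribute names introduced above; it is a paraphrase, not a verbatim excerpt:

    def check_exit_code(self, fn, first_exitcode):
        # 1) Tests registered in skip_return_code_checks tolerate any exit
        #    code; this is rare and should be used with caution.
        if fn in self.skip_return_code_checks:
            return
        # 2) Negative tests map their wrapped test fn to a signal-derived
        #    code, e.g. -signal.SIGABRT; 3) everything else must exit 0.
        expected = self.special_return_code_checks.get(fn, 0)
        assert first_exitcode == expected, (
            f"Expected exit code {expected} but got {first_exitcode}"
        )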

0 commit comments