[c10d] Add support for testing SIGABRT return (#153167) · pytorch/pytorch@8c16d0e · GitHub
Commit 8c16d0e
kwen2501 authored and pytorchmergebot committed
[c10d] Add support for testing SIGABRT return (#153167)
`SIGABRT` is a common return signal for *negative* distributed tests, which check the effectiveness of the NaN assert, the watchdog throw, etc. These errors are not detectable by traditional statements like `with self.assertRaises(RuntimeError)`. Instead, we need to check the process's return code: e.g. a `SIGABRT(6)` death yields a return code of -6.

Pull Request resolved: #153167
Approved by: https://github.com/fduwjj
1 parent b04852e commit 8c16d0e
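For reference, here is a minimal standalone sketch (not part of this PR) of the return-code convention the commit message describes: on POSIX, Python's `multiprocessing` reports a child killed by a signal as the negated signal number, so a `SIGABRT` death surfaces as exit code -6.

```python
import multiprocessing as mp
import os
import signal


def crash():
    os.abort()  # raise SIGABRT in the child process


if __name__ == "__main__":
    p = mp.Process(target=crash)
    p.start()
    p.join()
    # multiprocessing encodes death-by-signal as the negated signal number
    assert p.exitcode == -signal.SIGABRT  # i.e. -6 on POSIX
```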

File tree

2 files changed: +63 −151 lines

test/distributed/test_c10d_nccl.py (37 additions, 126 deletions)
```diff
@@ -44,7 +44,6 @@
     get_timeout,
     init_multigpu_helper,
     MultiProcessTestCase,
-    requires_gloo,
     requires_multicast_support,
     requires_nccl,
     requires_nccl_version,
@@ -284,16 +283,17 @@ def opts(self, high_priority_stream=False):
 
     def setUp(self):
         super().setUp()
-        # Need to skip return code checking for these tests since the child
-        # processes don't exit cleanly in some cuda versions
-        self.skip_return_code_checks = [
-            self.test_nan_assert_float16.__wrapped__,
-            self.test_nan_assert_float32.__wrapped__,
-            self.test_nan_assert_float64.__wrapped__,
-            self.test_nan_assert_bfloat16.__wrapped__,
-            self.test_nan_assert_float8_e4m3fn.__wrapped__,
-            self.test_nan_assert_float8_e5m2.__wrapped__,
-        ]
+
+        # These tests are expected to throw SIGABRT(6); adding the negative sign
+        # bc the test return code is actually -6
+        self.special_return_code_checks = {
+            self.test_nan_assert_float16.__wrapped__: -signal.SIGABRT,
+            self.test_nan_assert_float32.__wrapped__: -signal.SIGABRT,
+            self.test_nan_assert_float64.__wrapped__: -signal.SIGABRT,
+            self.test_nan_assert_bfloat16.__wrapped__: -signal.SIGABRT,
+            self.test_nan_assert_float8_e4m3fn.__wrapped__: -signal.SIGABRT,
+            self.test_nan_assert_float8_e5m2.__wrapped__: -signal.SIGABRT,
+        }
 
         # TORCH_NCCL_BLOCKING_WAIT overrides TORCH_NCCL_ASYNC_ERROR_HANDLING hence tests
         # that use TORCH_NCCL_BLOCKING_WAIT will test it as expected.
@@ -534,14 +534,14 @@ def test_nan_assert(self, type):
 
         # confirm enable/disable flag works
         backend._set_enable_nan_check(False)
-        pg.allreduce(nan_tensor)
+        # Note: using all-gather here bc some NCCL/SM version does not support
+        # FP8 reduction
+        pg._allgather_base(output, nan_tensor)
 
         backend._set_enable_nan_check(True)
-        with self.assertRaises(RuntimeError):
-            # Note: using all-gather here bc FP8 types do not support reduce ops
-            # at the moment
-            pg._allgather_base(output, nan_tensor)
+        pg._allgather_base(output, nan_tensor)
         dist.destroy_process_group()
+
         # reset env
         os.environ["TORCH_NCCL_NAN_CHECK"] = "0"
 
```
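The removed `assertRaises(RuntimeError)` block above is the heart of what this commit addresses: once the NaN check aborts the child process, no Python exception ever reaches the test body. A tiny PyTorch-independent sketch of why an in-process handler cannot observe `SIGABRT`:

```python
import os

try:
    os.abort()  # terminates the interpreter with SIGABRT; never returns
except BaseException:
    print("never reached")  # no exception propagates from a signal death
```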
```diff
@@ -576,16 +576,13 @@ def test_nan_rank_filter(self):
     def test_nan_check(self):
         # Not expecting an error, NaN check should not make legit code fail
         device = torch.device(f"cuda:{self.rank:d}")
-        if not sm_is_or_higher_than(device, 8, 0):
-            self.skipTest("bf16 requires sm >= 8.0")
-
         os.environ["TORCH_NCCL_NAN_CHECK"] = "1"
         store = c10d.FileStore(self.file_name, self.world_size)
         c10d.init_process_group(
             backend="nccl", store=store, rank=self.rank, world_size=self.world_size
         )
-        x = torch.ones((10,), dtype=torch.bfloat16, device=device) * self.rank
-        t = torch.ones(3, 4, dtype=torch.bfloat16, device=device)
+        x = torch.ones((10,), device=device) * self.rank
+        t = torch.ones(3, 4, device=device)
         c10d.broadcast(x, src=0)
         c10d.all_reduce(t)
         c10d.barrier()
@@ -2775,14 +2772,6 @@ def hook(work_info: torch._C._distributed_c10d.WorkInfo):
 class NcclErrorHandlingTest(MultiProcessTestCase):
     def setUp(self):
         super().setUp()
-        # Need to skip return code checking for these tests since the child
-        # processes don't exit cleanly.
-        self.skip_return_code_checks = [
-            self.test_nccl_errors_blocking_abort.__wrapped__,
-            self.test_nccl_errors_blocking_sigkill.__wrapped__,
-            self.test_nccl_errors_blocking_sigterm.__wrapped__,
-            self.test_nccl_errors_blocking_nonzero_exit.__wrapped__,
-        ]
         # TORCH_NCCL_BLOCKING_WAIT overrides TORCH_NCCL_ASYNC_ERROR_HANDLING hence tests
         # that use TORCH_NCCL_BLOCKING_WAIT will test it as expected.
         os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "1"
@@ -2810,12 +2799,19 @@ def blocking_wait_error_msg(self):
     def _run_all_reduce(self, pg):
         pg.allreduce(torch.rand(10).cuda(self.rank))
 
+    def _reduce_timeout(self):
+        # set heartbeat timeout to a small value so that we don't wait too long
+        # for things to shutdown
+        os.environ["TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC"] = "4"
+        os.environ["TORCH_NCCL_WAIT_TIMEOUT_DUMP_MILSEC"] = "1000"
+
     @requires_nccl()
     @requires_nccl_version((2, 4, 0), "Need NCCL 2.4+ for error checking")
     @skip_if_lt_x_gpu(3)
     @skip_if_rocm_multiprocess
     @skip_but_pass_in_sandcastle("Test does not pass when run locally")
     def test_nccl_errors_nonblocking(self):
+        self._reduce_timeout()
         # Note: we unset and restore TORCH_NCCL_ASYNC_ERROR_HANDLING for this test
         # since test_c10d_common runs with async error handling by default, but this
         # tests behavior when it is not enabled.
```
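A note on ordering, as a sketch under the assumption that the NCCL watchdog reads these variables when the process group is constructed: each test calls `_reduce_timeout()` first, so the shortened heartbeat is in effect before any communicator exists.

```python
import os

# hypothetical test body: shrink the watchdog timeouts first ...
os.environ["TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC"] = "4"
os.environ["TORCH_NCCL_WAIT_TIMEOUT_DUMP_MILSEC"] = "1000"

# ... and only then construct the NCCL process group, e.g.:
# process_group = c10d.ProcessGroupNCCL(store, rank, world_size)
```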
```diff
@@ -2846,30 +2842,24 @@ def test_nccl_errors_nonblocking(self):
                 "TORCH_NCCL_ASYNC_ERROR_HANDLING"
             ] = prev_nccl_async_error_handling
 
-    def _test_nccl_errors_blocking(self, func):
+    @requires_nccl()
+    @skip_if_lt_x_gpu(3)
+    @skip_if_rocm_multiprocess
+    def test_nccl_errors_blocking(self):
+        self._reduce_timeout()
+        os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "0"
         store = c10d.FileStore(self.file_name, self.world_size)
         process_group = c10d.ProcessGroupNCCL(
             store,
             self.rank,
             self.world_size,
-            timeout=timedelta(seconds=10),
         )
-        process_group.allreduce(torch.rand(10).cuda(self.rank))
+        x = torch.rand(1024 * 1024).cuda(self.rank)
+        process_group.allreduce(x)
         if self.rank == 0:
-            work = process_group.allreduce(torch.rand(10).cuda(self.rank))
+            work = process_group.allreduce(x)
             with self.assertRaisesRegex(dist.DistBackendError, ""):
-                # It seems the error message would be different depending on
-                # whether the test is run on CI machine and devGPU. Skipping
-                # the error message check to make both sides happy.
                 work.wait(timeout=timedelta(seconds=self.op_timeout_sec))
-            # Run some GPU operations to make sure cuda has not gotten stuck.
-            # It was observed cuda could get stuck if NCCL communicators were
-            # not properly aborted before throwing RuntimeError.
-            torch.rand(10).cuda(self.rank)
-        elif self.rank == 1:
-            # Clean up structures (ex: files for FileStore before going down)
-            del process_group
-            func()
 
     def _test_barrier_error(self):
         store = c10d.FileStore(self.file_name, self.world_size)
@@ -2889,60 +2879,19 @@ def _test_barrier_error(self):
                 timeout=timedelta(seconds=self.op_timeout_sec)
             )
 
-    @with_nccl_blocking_wait
-    @requires_nccl()
-    @requires_nccl_version((2, 4, 0), "Need NCCL 2.4+ for error checking")
-    @skip_if_lt_x_gpu(3)
-    @skip_if_rocm_multiprocess
-    def test_nccl_errors_blocking_clean_exit(self):
-        self._test_nccl_errors_blocking(lambda: sys.exit(0))
-
-    @with_nccl_blocking_wait
-    @requires_nccl()
-    @requires_nccl_version((2, 4, 0), "Need NCCL 2.4+ for error checking")
-    @skip_if_lt_x_gpu(3)
-    @skip_if_rocm_multiprocess
-    def test_nccl_errors_blocking_nonzero_exit(self):
-        self._test_nccl_errors_blocking(lambda: sys.exit(1))
-
-    @with_nccl_blocking_wait
-    @requires_nccl()
-    @requires_nccl_version((2, 4, 0), "Need NCCL 2.4+ for error checking")
-    @skip_if_lt_x_gpu(3)
-    @skip_if_rocm_multiprocess
-    @skip_but_pass_in_sandcastle(
-        "Frequently times out see https://github.com/pytorch/pytorch/issues/58920"
-    )
-    def test_nccl_errors_blocking_abort(self):
-        self._test_nccl_errors_blocking(lambda: os.abort())
-
-    @with_nccl_blocking_wait
-    @requires_nccl()
-    @requires_nccl_version((2, 4, 0), "Need NCCL 2.4+ for error checking")
-    @skip_if_lt_x_gpu(3)
-    @skip_if_rocm_multiprocess
-    def test_nccl_errors_blocking_sigkill(self):
-        self._test_nccl_errors_blocking(lambda: os.kill(os.getpid(), signal.SIGKILL))
-
-    @with_nccl_blocking_wait
-    @requires_nccl()
-    @requires_nccl_version((2, 4, 0), "Need NCCL 2.4+ for error checking")
-    @skip_if_lt_x_gpu(3)
-    @skip_if_rocm_multiprocess
-    def test_nccl_errors_blocking_sigterm(self):
-        self._test_nccl_errors_blocking(lambda: os.kill(os.getpid(), signal.SIGTERM))
-
     @with_nccl_blocking_wait
     @requires_nccl()
     @requires_nccl_version((2, 4, 0), "Need NCCL 2.4+ for error checking")
     @skip_if_lt_x_gpu(3)
     def test_nccl_blocking_wait_with_barrier(self):
+        self._reduce_timeout()
         self._test_barrier_error()
 
     @requires_nccl()
     @requires_nccl_version((2, 4, 0), "Need NCCL 2.4+ for error checking")
     @skip_if_lt_x_gpu(3)
     def test_nccl_non_blocking_wait_with_barrier(self):
+        self._reduce_timeout()
         # test the barrier behavior in the non blocking wait setting
         prev_nccl_async_error_handling = os.environ.get(
             "TORCH_NCCL_ASYNC_ERROR_HANDLING", None
@@ -3013,6 +2962,7 @@ def assert_fut_success(fut):
     @skip_if_rocm_multiprocess
     @skip_if_lt_x_gpu(3)
     def test_restart_pg_after_error(self):
+        self._reduce_timeout()
         # test the barrier behavior in the non blocking wait setting
         prev_nccl_async_error_handling = os.environ.get(
             "TORCH_NCCL_ASYNC_ERROR_HANDLING", None
@@ -3102,45 +3052,6 @@ def test_invalid_nccl_blocking_wait_env(self):
         self._run_invalid_nccl_blocking_wait_env("2147483647")
         self._run_invalid_nccl_blocking_wait_env("4294967295")
 
-    @with_nccl_blocking_wait
-    @requires_nccl()
-    @requires_gloo()
-    @skip_if_lt_x_gpu(3)
-    def test_nccl_timeout(self):
-        store = c10d.FileStore(self.file_name, self.world_size)
-
-        # Initialize process_group.
-        process_group = c10d.ProcessGroupNCCL(
-            store, self.rank, self.world_size, timeout=timedelta(seconds=10)
-        )
-        # Control gloo pg used as go-ahead signal/barrier
-        # to coordinate btwn ranks.
-        pg_gloo = c10d.ProcessGroupGloo(store, self.rank, self.world_size)
-        failed_collective_timeout = timedelta(milliseconds=100)
-        process_group.allreduce(torch.rand(10).cuda(self.rank)).wait(
-            timeout=timedelta(seconds=5)
-        )
-
-        if self.rank == 0:
-            # This should timeout in about 1 second.
-            # Watchdog may abort timed out work resulting in NCCL error instead of operation timed out.
-            with self.assertRaisesRegex(
-                dist.DistBackendError, self.blocking_wait_error_msg
-            ):
-                process_group.allreduce(torch.rand(10).cuda(self.rank)).wait(
-                    timeout=failed_collective_timeout
-                )
-            # Now do a barrier to tell other rank to go ahead.
-            pg_gloo.barrier().wait()
-        else:
-            # Wait on rank 0 to fail.
-            try:
-                pg_gloo.barrier().wait()
-            except Exception as e:
-                raise ValueError(
-                    f"Rank {self.rank} barrier timed out waiting for rank 0 with error: {str(e)}"
-                ) from e
-
 
 class NcclUserBufferRegistrationTest(MultiProcessTestCase):
     def setUp(self):
```

torch/testing/_internal/common_distributed.py (26 additions, 25 deletions)
```diff
@@ -640,7 +640,15 @@ def __init__(
 
     def setUp(self) -> None:
         super().setUp()
-        self.skip_return_code_checks = []  # type: ignore[var-annotated]
+
+        # Used for tests that are expected to return a non-0 exit code, such as
+        # SIGABRT thrown by watchdog.
+        self.special_return_code_checks: dict = {}
+
+        # Used for tests that may return any exit code, which makes it hard to
+        # check. This is rare, use with caution.
+        self.skip_return_code_checks: list = []
+
         self.processes = []  # type: ignore[var-annotated]
         self.rank = self.MAIN_PROCESS_RANK
         self.file_name = tempfile.NamedTemporaryFile(delete=False).name
```
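As a hypothetical usage sketch (class and test names invented, not from this PR), a negative test registers its expected signal in `setUp`, mirroring the `test_c10d_nccl.py` change above:

```python
import os
import signal

from torch.testing._internal.common_distributed import MultiProcessTestCase


class MyNegativeTest(MultiProcessTestCase):
    def setUp(self):
        super().setUp()
        # Expect each child to die with SIGABRT(6), i.e. exit code -6.
        # (The real tests key on fn.__wrapped__ because their test methods
        # are wrapped by decorators such as @skip_if_lt_x_gpu.)
        self.special_return_code_checks = {
            self.test_child_aborts: -signal.SIGABRT,
        }

    def test_child_aborts(self):
        os.abort()  # spawned rank aborts; the harness expects -6
```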
```diff
@@ -865,28 +873,13 @@ def _join_processes(self, fn) -> None:
                 time.sleep(0.1)
 
             elapsed_time = time.time() - start_time
-
-            if fn in self.skip_return_code_checks:
-                self._check_no_test_errors(elapsed_time)
-            else:
-                self._check_return_codes(elapsed_time)
+            self._check_return_codes(fn, elapsed_time)
         finally:
             # Close all pipes
             for pipe in self.pid_to_pipe.values():
                 pipe.close()
 
-    def _check_no_test_errors(self, elapsed_time) -> None:
-        """
-        Checks that we didn't have any errors thrown in the child processes.
-        """
-        for i, p in enumerate(self.processes):
-            if p.exitcode is None:
-                raise RuntimeError(
-                    f"Process {i} timed out after {elapsed_time} seconds"
-                )
-            self.assertNotEqual(self.TEST_ERROR_EXIT_CODE, p.exitcode)
-
-    def _check_return_codes(self, elapsed_time) -> None:
+    def _check_return_codes(self, fn, elapsed_time) -> None:
         """
         Checks that the return codes of all spawned processes match, and skips
         tests if they returned a return code indicating a skipping condition.
@@ -928,11 +921,11 @@ def _check_return_codes(self, elapsed_time) -> None:
                 raise RuntimeError(
                     f"Process {i} terminated or timed out after {elapsed_time} seconds"
                 )
-            self.assertEqual(
-                p.exitcode,
-                first_process.exitcode,
-                msg=f"Expect process {i} exit code to match Process 0 exit code of {first_process.exitcode}, but got {p.exitcode}",
-            )
+
+        # Skip the test return code check
+        if fn in self.skip_return_code_checks:
+            return
+
         for skip in TEST_SKIPS.values():
             if first_process.exitcode == skip.exit_code:
                 if IS_SANDCASTLE:
@@ -948,10 +941,18 @@ def _check_return_codes(self, elapsed_time) -> None:
                     return
                 else:
                     raise unittest.SkipTest(skip.message)
+
+        # In most cases, we expect test to return exit code 0, standing for success.
+        expected_return_code = 0
+        # In some negative tests, we expect test to return non-zero exit code,
+        # such as watchdog throwing SIGABRT.
+        if fn in self.special_return_code_checks:
+            expected_return_code = self.special_return_code_checks[fn]
+
         self.assertEqual(
             first_process.exitcode,
-            0,
-            msg=f"Expected zero exit code but got {first_process.exitcode} for pid: {first_process.pid}",
+            expected_return_code,
+            msg=f"Expected exit code {expected_return_code} but got {first_process.exitcode} for pid: {first_process.pid}",
         )
 
     @property
```
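Putting the two containers together, the dispatch in `_check_return_codes` reduces to a small lookup; here is a standalone sketch (helper name invented) with the two interesting cases exercised. Keeping `skip_return_code_checks` alongside the new dict preserves the old escape hatch while making the common negative case, a specific expected signal, explicit.

```python
import signal


def expected_exit_code(fn, special_checks):
    # default: a clean exit; negative tests override per-test
    return special_checks.get(fn, 0)


checks = {"test_nan_assert": -signal.SIGABRT}
assert expected_exit_code("test_nan_assert", checks) == -signal.SIGABRT
assert expected_exit_code("test_all_reduce", checks) == 0
```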