     get_timeout,
     init_multigpu_helper,
     MultiProcessTestCase,
-    requires_gloo,
     requires_multicast_support,
     requires_nccl,
     requires_nccl_version,
     skip_if_lt_x_gpu,
     skip_if_rocm_multiprocess,
-    sm_is_or_higher_than,
     TEST_SKIPS,
     with_dist_debug_levels,
     with_nccl_blocking_wait,
@@ -284,16 +282,17 @@ def opts(self, high_priority_stream=False):
 
     def setUp(self):
         super().setUp()
-        # Need to skip return code checking for these tests since the child
-        # processes don't exit cleanly in some cuda versions
-        self.skip_return_code_checks = [
-            self.test_nan_assert_float16.__wrapped__,
-            self.test_nan_assert_float32.__wrapped__,
-            self.test_nan_assert_float64.__wrapped__,
-            self.test_nan_assert_bfloat16.__wrapped__,
-            self.test_nan_assert_float8_e4m3fn.__wrapped__,
-            self.test_nan_assert_float8_e5m2.__wrapped__,
-        ]
+
+        # These tests are expected to terminate with SIGABRT (signal 6); the
+        # negative sign is used because the child process return code is -6.
+        self.special_return_code_checks = {
+            self.test_nan_assert_float16.__wrapped__: -signal.SIGABRT,
+            self.test_nan_assert_float32.__wrapped__: -signal.SIGABRT,
+            self.test_nan_assert_float64.__wrapped__: -signal.SIGABRT,
+            self.test_nan_assert_bfloat16.__wrapped__: -signal.SIGABRT,
+            self.test_nan_assert_float8_e4m3fn.__wrapped__: -signal.SIGABRT,
+            self.test_nan_assert_float8_e5m2.__wrapped__: -signal.SIGABRT,
+        }
 
         # TORCH_NCCL_BLOCKING_WAIT overrides TORCH_NCCL_ASYNC_ERROR_HANDLING hence tests
         # that use TORCH_NCCL_BLOCKING_WAIT will test it as expected.
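The -6 values above follow Python's own convention for child processes killed by a signal: on POSIX, multiprocessing reports exitcode as the negative signal number. A minimal, NCCL-free sketch of that behavior (illustrative only, not part of this PR):

    import os
    import signal
    import multiprocessing as mp

    def crash() -> None:
        # Raise SIGABRT in the child, the same way a failed NaN check aborts.
        os.abort()

    if __name__ == "__main__":
        p = mp.get_context("spawn").Process(target=crash)
        p.start()
        p.join()
        # On POSIX, a child terminated by a signal reports exitcode == -signum.
        assert p.exitcode == -signal.SIGABRT  # i.e. -6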
@@ -534,14 +533,14 @@ def test_nan_assert(self, type):
 
         # confirm enable/disable flag works
         backend._set_enable_nan_check(False)
-        pg.allreduce(nan_tensor)
+        # Note: using all-gather here because some NCCL/SM versions do not
+        # support FP8 reduction
+        pg._allgather_base(output, nan_tensor)
 
         backend._set_enable_nan_check(True)
-        with self.assertRaises(RuntimeError):
-            # Note: using all-gather here bc FP8 types do not support reduce ops
-            # at the moment
-            pg._allgather_base(output, nan_tensor)
+        pg._allgather_base(output, nan_tensor)
         dist.destroy_process_group()
+
         # reset env
         os.environ["TORCH_NCCL_NAN_CHECK"] = "0"
 
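The switch from allreduce to all-gather above works because all-gather only copies bytes, so it is legal for FP8 tensors even where NCCL has no FP8 reduction kernels. A rough sketch of that dtype-based choice, using a hypothetical helper name that is not part of this PR:

    import torch

    _FP8_DTYPES = {torch.float8_e4m3fn, torch.float8_e5m2}

    def run_nan_check_collective(pg, tensor: torch.Tensor, world_size: int) -> None:
        # Hypothetical helper, illustration only: pick a collective that is
        # valid for the tensor's dtype when exercising the NaN checker.
        if tensor.dtype in _FP8_DTYPES:
            # All-gather performs no arithmetic, so it works even where FP8
            # reductions are unsupported.
            output = torch.empty(
                tensor.numel() * world_size, dtype=tensor.dtype, device=tensor.device
            )
            pg._allgather_base(output, tensor)
        else:
            pg.allreduce(tensor)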
@@ -576,16 +575,13 @@ def test_nan_rank_filter(self):
     def test_nan_check(self):
         # Not expecting an error, NaN check should not make legit code fail
         device = torch.device(f"cuda:{self.rank:d}")
-        if not sm_is_or_higher_than(device, 8, 0):
-            self.skipTest("bf16 requires sm >= 8.0")
-
         os.environ["TORCH_NCCL_NAN_CHECK"] = "1"
         store = c10d.FileStore(self.file_name, self.world_size)
         c10d.init_process_group(
             backend="nccl", store=store, rank=self.rank, world_size=self.world_size
         )
-        x = torch.ones((10,), dtype=torch.bfloat16, device=device) * self.rank
-        t = torch.ones(3, 4, dtype=torch.bfloat16, device=device)
+        x = torch.ones((10,), device=device) * self.rank
+        t = torch.ones(3, 4, device=device)
         c10d.broadcast(x, src=0)
         c10d.all_reduce(t)
         c10d.barrier()
@@ -2775,14 +2771,6 @@ def hook(work_info: torch._C._distributed_c10d.WorkInfo):
 class NcclErrorHandlingTest(MultiProcessTestCase):
     def setUp(self):
         super().setUp()
-        # Need to skip return code checking for these tests since the child
-        # processes don't exit cleanly.
-        self.skip_return_code_checks = [
-            self.test_nccl_errors_blocking_abort.__wrapped__,
-            self.test_nccl_errors_blocking_sigkill.__wrapped__,
-            self.test_nccl_errors_blocking_sigterm.__wrapped__,
-            self.test_nccl_errors_blocking_nonzero_exit.__wrapped__,
-        ]
         # TORCH_NCCL_BLOCKING_WAIT overrides TORCH_NCCL_ASYNC_ERROR_HANDLING hence tests
         # that use TORCH_NCCL_BLOCKING_WAIT will test it as expected.
         os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "1"
@@ -2810,12 +2798,19 @@ def blocking_wait_error_msg(self):
     def _run_all_reduce(self, pg):
         pg.allreduce(torch.rand(10).cuda(self.rank))
 
+    def _reduce_timeout(self):
+        # Set the heartbeat timeout to a small value so that we don't wait
+        # too long for things to shut down.
+        os.environ["TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC"] = "4"
+        os.environ["TORCH_NCCL_WAIT_TIMEOUT_DUMP_MILSEC"] = "1000"
+
     @requires_nccl()
     @requires_nccl_version((2, 4, 0), "Need NCCL 2.4+ for error checking")
     @skip_if_lt_x_gpu(3)
     @skip_if_rocm_multiprocess
     @skip_but_pass_in_sandcastle("Test does not pass when run locally")
     def test_nccl_errors_nonblocking(self):
+        self._reduce_timeout()
         # Note: we unset and restore TORCH_NCCL_ASYNC_ERROR_HANDLING for this test
         # since test_c10d_common runs with async error handling by default, but this
         # tests behavior when it is not enabled.
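The knobs set in _reduce_timeout are plain environment variables, so they only help if they are exported before the NCCL process group is created (an assumption about when the backend reads them, flagged again in the comments below). A single-rank sketch of that ordering, illustrative only:

    import os
    import torch
    import torch.distributed as dist

    def init_nccl_with_short_watchdog_timeouts() -> None:
        # Assumption: these env vars are read when the NCCL backend is
        # constructed, so export them before init_process_group().
        os.environ["TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC"] = "4"
        os.environ["TORCH_NCCL_WAIT_TIMEOUT_DUMP_MILSEC"] = "1000"
        os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
        os.environ.setdefault("MASTER_PORT", "29501")
        # Single-rank group, just enough to show the ordering.
        dist.init_process_group(backend="nccl", rank=0, world_size=1)

    if __name__ == "__main__" and torch.cuda.is_available():
        init_nccl_with_short_watchdog_timeouts()
        dist.destroy_process_group()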
@@ -2846,30 +2841,24 @@ def test_nccl_errors_nonblocking(self):
                 "TORCH_NCCL_ASYNC_ERROR_HANDLING"
             ] = prev_nccl_async_error_handling
 
-    def _test_nccl_errors_blocking(self, func):
+    @requires_nccl()
+    @skip_if_lt_x_gpu(3)
+    @skip_if_rocm_multiprocess
+    def test_nccl_errors_blocking(self):
+        self._reduce_timeout()
+        os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "0"
         store = c10d.FileStore(self.file_name, self.world_size)
         process_group = c10d.ProcessGroupNCCL(
             store,
             self.rank,
             self.world_size,
-            timeout=timedelta(seconds=10),
         )
-        process_group.allreduce(torch.rand(10).cuda(self.rank))
+        x = torch.rand(1024 * 1024).cuda(self.rank)
+        process_group.allreduce(x)
         if self.rank == 0:
-            work = process_group.allreduce(torch.rand(10).cuda(self.rank))
+            work = process_group.allreduce(x)
             with self.assertRaisesRegex(dist.DistBackendError, ""):
-                # It seems the error message would be different depending on
-                # whether the test is run on CI machine and devGPU. Skipping
-                # the error message check to make both sides happy.
                 work.wait(timeout=timedelta(seconds=self.op_timeout_sec))
-            # Run some GPU operations to make sure cuda has not gotten stuck.
-            # It was observed cuda could get stuck if NCCL communicators were
-            # not properly aborted before throwing RuntimeError.
-            torch.rand(10).cuda(self.rank)
-        elif self.rank == 1:
-            # Clean up structures (ex: files for FileStore before going down)
-            del process_group
-            func()
 
     def _test_barrier_error(self):
         store = c10d.FileStore(self.file_name, self.world_size)
@@ -2889,60 +2878,19 @@ def _test_barrier_error(self):
                 timeout=timedelta(seconds=self.op_timeout_sec)
             )
 
-    @with_nccl_blocking_wait
-    @requires_nccl()
-    @requires_nccl_version((2, 4, 0), "Need NCCL 2.4+ for error checking")
-    @skip_if_lt_x_gpu(3)
-    @skip_if_rocm_multiprocess
-    def test_nccl_errors_blocking_clean_exit(self):
-        self._test_nccl_errors_blocking(lambda: sys.exit(0))
-
-    @with_nccl_blocking_wait
-    @requires_nccl()
-    @requires_nccl_version((2, 4, 0), "Need NCCL 2.4+ for error checking")
-    @skip_if_lt_x_gpu(3)
-    @skip_if_rocm_multiprocess
-    def test_nccl_errors_blocking_nonzero_exit(self):
-        self._test_nccl_errors_blocking(lambda: sys.exit(1))
-
-    @with_nccl_blocking_wait
-    @requires_nccl()
-    @requires_nccl_version((2, 4, 0), "Need NCCL 2.4+ for error checking")
-    @skip_if_lt_x_gpu(3)
-    @skip_if_rocm_multiprocess
-    @skip_but_pass_in_sandcastle(
-        "Frequently times out see https://github.com/pytorch/pytorch/issues/58920"
-    )
-    def test_nccl_errors_blocking_abort(self):
-        self._test_nccl_errors_blocking(lambda: os.abort())
-
-    @with_nccl_blocking_wait
-    @requires_nccl()
-    @requires_nccl_version((2, 4, 0), "Need NCCL 2.4+ for error checking")
-    @skip_if_lt_x_gpu(3)
-    @skip_if_rocm_multiprocess
-    def test_nccl_errors_blocking_sigkill(self):
-        self._test_nccl_errors_blocking(lambda: os.kill(os.getpid(), signal.SIGKILL))
-
-    @with_nccl_blocking_wait
-    @requires_nccl()
-    @requires_nccl_version((2, 4, 0), "Need NCCL 2.4+ for error checking")
-    @skip_if_lt_x_gpu(3)
-    @skip_if_rocm_multiprocess
-    def test_nccl_errors_blocking_sigterm(self):
-        self._test_nccl_errors_blocking(lambda: os.kill(os.getpid(), signal.SIGTERM))
-
     @with_nccl_blocking_wait
     @requires_nccl()
     @requires_nccl_version((2, 4, 0), "Need NCCL 2.4+ for error checking")
     @skip_if_lt_x_gpu(3)
     def test_nccl_blocking_wait_with_barrier(self):
+        self._reduce_timeout()
         self._test_barrier_error()
 
     @requires_nccl()
     @requires_nccl_version((2, 4, 0), "Need NCCL 2.4+ for error checking")
     @skip_if_lt_x_gpu(3)
     def test_nccl_non_blocking_wait_with_barrier(self):
+        self._reduce_timeout()
         # test the barrier behavior in the non blocking wait setting
         prev_nccl_async_error_handling = os.environ.get(
             "TORCH_NCCL_ASYNC_ERROR_HANDLING", None
@@ -3013,6 +2961,7 @@ def assert_fut_success(fut):
     @skip_if_rocm_multiprocess
     @skip_if_lt_x_gpu(3)
     def test_restart_pg_after_error(self):
+        self._reduce_timeout()
         # test the barrier behavior in the non blocking wait setting
         prev_nccl_async_error_handling = os.environ.get(
             "TORCH_NCCL_ASYNC_ERROR_HANDLING", None
@@ -3102,45 +3051,6 @@ def test_invalid_nccl_blocking_wait_env(self):
         self._run_invalid_nccl_blocking_wait_env("2147483647")
         self._run_invalid_nccl_blocking_wait_env("4294967295")
 
-    @with_nccl_blocking_wait
-    @requires_nccl()
-    @requires_gloo()
-    @skip_if_lt_x_gpu(3)
-    def test_nccl_timeout(self):
-        store = c10d.FileStore(self.file_name, self.world_size)
-
-        # Initialize process_group.
-        process_group = c10d.ProcessGroupNCCL(
-            store, self.rank, self.world_size, timeout=timedelta(seconds=10)
-        )
-        # Control gloo pg used as go-ahead signal/barrier
-        # to coordinate btwn ranks.
-        pg_gloo = c10d.ProcessGroupGloo(store, self.rank, self.world_size)
-        failed_collective_timeout = timedelta(milliseconds=100)
-        process_group.allreduce(torch.rand(10).cuda(self.rank)).wait(
-            timeout=timedelta(seconds=5)
-        )
-
-        if self.rank == 0:
-            # This should timeout in about 1 second.
-            # Watchdog may abort timed out work resulting in NCCL error instead of operation timed out.
-            with self.assertRaisesRegex(
-                dist.DistBackendError, self.blocking_wait_error_msg
-            ):
-                process_group.allreduce(torch.rand(10).cuda(self.rank)).wait(
-                    timeout=failed_collective_timeout
-                )
-            # Now do a barrier to tell other rank to go ahead.
-            pg_gloo.barrier().wait()
-        else:
-            # Wait on rank 0 to fail.
-            try:
-                pg_gloo.barrier().wait()
-            except Exception as e:
-                raise ValueError(
-                    f"Rank {self.rank} barrier timed out waiting for rank 0 with error: {str(e)}"
-                ) from e
-
 
 class NcclUserBufferRegistrationTest(MultiProcessTestCase):
     def setUp(self):