     get_timeout,
     init_multigpu_helper,
     MultiProcessTestCase,
-    requires_gloo,
     requires_multicast_support,
     requires_nccl,
     requires_nccl_version,
@@ -284,16 +283,17 @@ def opts(self, high_priority_stream=False):
 
     def setUp(self):
         super().setUp()
-        # Need to skip return code checking for these tests since the child
-        # processes don't exit cleanly in some cuda versions
-        self.skip_return_code_checks = [
-            self.test_nan_assert_float16.__wrapped__,
-            self.test_nan_assert_float32.__wrapped__,
-            self.test_nan_assert_float64.__wrapped__,
-            self.test_nan_assert_bfloat16.__wrapped__,
-            self.test_nan_assert_float8_e4m3fn.__wrapped__,
-            self.test_nan_assert_float8_e5m2.__wrapped__,
-        ]
+
+        # These tests are expected to be killed by SIGABRT (signal 6); the sign
+        # is negative because the child's return code is reported as -6.
+        self.special_return_code_checks = {
+            self.test_nan_assert_float16.__wrapped__: -signal.SIGABRT,
+            self.test_nan_assert_float32.__wrapped__: -signal.SIGABRT,
+            self.test_nan_assert_float64.__wrapped__: -signal.SIGABRT,
+            self.test_nan_assert_bfloat16.__wrapped__: -signal.SIGABRT,
+            self.test_nan_assert_float8_e4m3fn.__wrapped__: -signal.SIGABRT,
+            self.test_nan_assert_float8_e5m2.__wrapped__: -signal.SIGABRT,
+        }
 
         # TORCH_NCCL_BLOCKING_WAIT overrides TORCH_NCCL_ASYNC_ERROR_HANDLING hence tests
         # that use TORCH_NCCL_BLOCKING_WAIT will test it as expected.
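The expected values in `special_return_code_checks` above are negative because that is how Python's `multiprocessing` reports a child terminated by a signal. A minimal, standalone sketch of that convention, using a hypothetical `crash` child that simply aborts:

```python
import multiprocessing as mp
import os
import signal


def crash():
    # Hypothetical stand-in for a child process that hits the NCCL NaN assert.
    os.abort()  # terminates the process with SIGABRT


if __name__ == "__main__":
    p = mp.Process(target=crash)
    p.start()
    p.join()
    # A child killed by signal N reports exitcode == -N,
    # hence the expected -signal.SIGABRT (i.e. -6) above.
    assert p.exitcode == -signal.SIGABRT, p.exitcode
```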
@@ -534,14 +534,14 @@ def test_nan_assert(self, type):
 
         # confirm enable/disable flag works
         backend._set_enable_nan_check(False)
-        pg.allreduce(nan_tensor)
+        # Note: using all-gather here because some NCCL/SM versions do not
+        # support FP8 reduction
+        pg._allgather_base(output, nan_tensor)
 
         backend._set_enable_nan_check(True)
-        with self.assertRaises(RuntimeError):
-            # Note: using all-gather here bc FP8 types do not support reduce ops
-            # at the moment
-            pg._allgather_base(output, nan_tensor)
+        pg._allgather_base(output, nan_tensor)
         dist.destroy_process_group()
+
         # reset env
         os.environ["TORCH_NCCL_NAN_CHECK"] = "0"
 
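With `_set_enable_nan_check(True)` restored, the second all-gather above is expected to trip the NaN checker and abort the process, which the SIGABRT return codes registered in `setUp` account for. As a rough illustration of the guard itself, here is a hypothetical wrapper (not a PyTorch API) that checks a tensor for NaNs before launching a collective:

```python
import torch
import torch.distributed as dist


def checked_allgather(output: torch.Tensor, inp: torch.Tensor, enabled: bool = True):
    # Hypothetical sketch: refuse to launch the collective if the input holds
    # NaNs, loosely mirroring what TORCH_NCCL_NAN_CHECK=1 guards against.
    if enabled and torch.isnan(inp).any():
        raise RuntimeError("NaN detected in collective input")
    dist.all_gather_into_tensor(output, inp)
```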
@@ -576,16 +576,13 @@ def test_nan_rank_filter(self):
     def test_nan_check(self):
         # Not expecting an error, NaN check should not make legit code fail
         device = torch.device(f"cuda:{self.rank:d}")
-        if not sm_is_or_higher_than(device, 8, 0):
-            self.skipTest("bf16 requires sm >= 8.0")
-
         os.environ["TORCH_NCCL_NAN_CHECK"] = "1"
         store = c10d.FileStore(self.file_name, self.world_size)
         c10d.init_process_group(
             backend="nccl", store=store, rank=self.rank, world_size=self.world_size
         )
-        x = torch.ones((10,), dtype=torch.bfloat16, device=device) * self.rank
-        t = torch.ones(3, 4, dtype=torch.bfloat16, device=device)
+        x = torch.ones((10,), device=device) * self.rank
+        t = torch.ones(3, 4, device=device)
         c10d.broadcast(x, src=0)
         c10d.all_reduce(t)
         c10d.barrier()
@@ -2775,14 +2772,6 @@ def hook(work_info: torch._C._distributed_c10d.WorkInfo):
 class NcclErrorHandlingTest(MultiProcessTestCase):
     def setUp(self):
         super().setUp()
-        # Need to skip return code checking for these tests since the child
-        # processes don't exit cleanly.
-        self.skip_return_code_checks = [
-            self.test_nccl_errors_blocking_abort.__wrapped__,
-            self.test_nccl_errors_blocking_sigkill.__wrapped__,
-            self.test_nccl_errors_blocking_sigterm.__wrapped__,
-            self.test_nccl_errors_blocking_nonzero_exit.__wrapped__,
-        ]
         # TORCH_NCCL_BLOCKING_WAIT overrides TORCH_NCCL_ASYNC_ERROR_HANDLING hence tests
         # that use TORCH_NCCL_BLOCKING_WAIT will test it as expected.
         os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "1"
@@ -2810,12 +2799,19 @@ def blocking_wait_error_msg(self):
     def _run_all_reduce(self, pg):
         pg.allreduce(torch.rand(10).cuda(self.rank))
 
+    def _reduce_timeout(self):
+        # Set the heartbeat timeout to a small value so that we don't wait too
+        # long for things to shut down.
+        os.environ["TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC"] = "4"
+        os.environ["TORCH_NCCL_WAIT_TIMEOUT_DUMP_MILSEC"] = "1000"
+
     @requires_nccl()
     @requires_nccl_version((2, 4, 0), "Need NCCL 2.4+ for error checking")
     @skip_if_lt_x_gpu(3)
     @skip_if_rocm_multiprocess
     @skip_but_pass_in_sandcastle("Test does not pass when run locally")
     def test_nccl_errors_nonblocking(self):
+        self._reduce_timeout()
         # Note: we unset and restore TORCH_NCCL_ASYNC_ERROR_HANDLING for this test
         # since test_c10d_common runs with async error handling by default, but this
         # tests behavior when it is not enabled.
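`_reduce_timeout` above mutates process-wide environment variables, and the surrounding tests also unset and restore `TORCH_NCCL_ASYNC_ERROR_HANDLING` by hand. One way to express that save-and-restore pattern is a small context manager; `scoped_env` below is a hypothetical helper, not part of the test suite:

```python
import os
from contextlib import contextmanager


@contextmanager
def scoped_env(**overrides):
    # Apply env overrides for the duration of a block, then restore the
    # previous values (or drop keys that were previously unset).
    saved = {key: os.environ.get(key) for key in overrides}
    os.environ.update({key: str(val) for key, val in overrides.items()})
    try:
        yield
    finally:
        for key, old in saved.items():
            if old is None:
                os.environ.pop(key, None)
            else:
                os.environ[key] = old


# Usage: shorten the watchdog timeouts only inside the block.
# with scoped_env(TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC="4"):
#     run_test()
```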
@@ -2846,30 +2842,24 @@ def test_nccl_errors_nonblocking(self):
28462842 "TORCH_NCCL_ASYNC_ERROR_HANDLING"
28472843 ] = prev_nccl_async_error_handling
28482844
2849- def _test_nccl_errors_blocking (self , func ):
2845+ @requires_nccl ()
2846+ @skip_if_lt_x_gpu (3 )
2847+ @skip_if_rocm_multiprocess
2848+ def test_nccl_errors_blocking (self ):
2849+ self ._reduce_timeout ()
2850+ os .environ ["TORCH_NCCL_ASYNC_ERROR_HANDLING" ] = "0"
28502851 store = c10d .FileStore (self .file_name , self .world_size )
28512852 process_group = c10d .ProcessGroupNCCL (
28522853 store ,
28532854 self .rank ,
28542855 self .world_size ,
2855- timeout = timedelta (seconds = 10 ),
28562856 )
2857- process_group .allreduce (torch .rand (10 ).cuda (self .rank ))
2857+ x = torch .rand (1024 * 1024 ).cuda (self .rank )
2858+ process_group .allreduce (x )
28582859 if self .rank == 0 :
2859- work = process_group .allreduce (torch . rand ( 10 ). cuda ( self . rank ) )
2860+ work = process_group .allreduce (x )
28602861 with self .assertRaisesRegex (dist .DistBackendError , "" ):
2861- # It seems the error message would be different depending on
2862- # whether the test is run on CI machine and devGPU. Skipping
2863- # the error message check to make both sides happy.
28642862 work .wait (timeout = timedelta (seconds = self .op_timeout_sec ))
2865- # Run some GPU operations to make sure cuda has not gotten stuck.
2866- # It was observed cuda could get stuck if NCCL communicators were
2867- # not properly aborted before throwing RuntimeError.
2868- torch .rand (10 ).cuda (self .rank )
2869- elif self .rank == 1 :
2870- # Clean up structures (ex: files for FileStore before going down)
2871- del process_group
2872- func ()
28732863
     def _test_barrier_error(self):
         store = c10d.FileStore(self.file_name, self.world_size)
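The rewritten `test_nccl_errors_blocking` above relies on a deliberately mismatched collective: rank 0 posts a second allreduce that the other ranks never join, so its `wait()` should surface a `DistBackendError`. A sketch of that shape, assuming a NCCL process group has already been initialized on every rank (e.g. via `torchrun`); the helper name is illustrative only:

```python
from datetime import timedelta

import torch
import torch.distributed as dist


def mismatched_allreduce(rank: int, op_timeout_sec: int = 30) -> None:
    # Assumes dist.init_process_group("nccl") has already run on every rank.
    x = torch.rand(1024 * 1024, device=f"cuda:{rank}")
    dist.all_reduce(x)  # matched: every rank participates

    if rank == 0:
        # Only rank 0 posts this collective, so it can never complete; the
        # blocking wait / watchdog is expected to turn it into an error.
        work = dist.all_reduce(x, async_op=True)
        try:
            work.wait(timeout=timedelta(seconds=op_timeout_sec))
        except dist.DistBackendError as exc:
            print(f"rank {rank} observed expected failure: {exc}")
```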
@@ -2889,60 +2879,19 @@ def _test_barrier_error(self):
                     timeout=timedelta(seconds=self.op_timeout_sec)
                 )
 
-    @with_nccl_blocking_wait
-    @requires_nccl()
-    @requires_nccl_version((2, 4, 0), "Need NCCL 2.4+ for error checking")
-    @skip_if_lt_x_gpu(3)
-    @skip_if_rocm_multiprocess
-    def test_nccl_errors_blocking_clean_exit(self):
-        self._test_nccl_errors_blocking(lambda: sys.exit(0))
-
-    @with_nccl_blocking_wait
-    @requires_nccl()
-    @requires_nccl_version((2, 4, 0), "Need NCCL 2.4+ for error checking")
-    @skip_if_lt_x_gpu(3)
-    @skip_if_rocm_multiprocess
-    def test_nccl_errors_blocking_nonzero_exit(self):
-        self._test_nccl_errors_blocking(lambda: sys.exit(1))
-
-    @with_nccl_blocking_wait
-    @requires_nccl()
-    @requires_nccl_version((2, 4, 0), "Need NCCL 2.4+ for error checking")
-    @skip_if_lt_x_gpu(3)
-    @skip_if_rocm_multiprocess
-    @skip_but_pass_in_sandcastle(
-        "Frequently times out see https://github.com/pytorch/pytorch/issues/58920"
-    )
-    def test_nccl_errors_blocking_abort(self):
-        self._test_nccl_errors_blocking(lambda: os.abort())
-
-    @with_nccl_blocking_wait
-    @requires_nccl()
-    @requires_nccl_version((2, 4, 0), "Need NCCL 2.4+ for error checking")
-    @skip_if_lt_x_gpu(3)
-    @skip_if_rocm_multiprocess
-    def test_nccl_errors_blocking_sigkill(self):
-        self._test_nccl_errors_blocking(lambda: os.kill(os.getpid(), signal.SIGKILL))
-
-    @with_nccl_blocking_wait
-    @requires_nccl()
-    @requires_nccl_version((2, 4, 0), "Need NCCL 2.4+ for error checking")
-    @skip_if_lt_x_gpu(3)
-    @skip_if_rocm_multiprocess
-    def test_nccl_errors_blocking_sigterm(self):
-        self._test_nccl_errors_blocking(lambda: os.kill(os.getpid(), signal.SIGTERM))
-
     @with_nccl_blocking_wait
     @requires_nccl()
     @requires_nccl_version((2, 4, 0), "Need NCCL 2.4+ for error checking")
     @skip_if_lt_x_gpu(3)
     def test_nccl_blocking_wait_with_barrier(self):
+        self._reduce_timeout()
         self._test_barrier_error()
 
     @requires_nccl()
     @requires_nccl_version((2, 4, 0), "Need NCCL 2.4+ for error checking")
     @skip_if_lt_x_gpu(3)
     def test_nccl_non_blocking_wait_with_barrier(self):
+        self._reduce_timeout()
         # test the barrier behavior in the non blocking wait setting
         prev_nccl_async_error_handling = os.environ.get(
             "TORCH_NCCL_ASYNC_ERROR_HANDLING", None
@@ -3013,6 +2962,7 @@ def assert_fut_success(fut):
     @skip_if_rocm_multiprocess
     @skip_if_lt_x_gpu(3)
     def test_restart_pg_after_error(self):
+        self._reduce_timeout()
         # test the barrier behavior in the non blocking wait setting
         prev_nccl_async_error_handling = os.environ.get(
             "TORCH_NCCL_ASYNC_ERROR_HANDLING", None
@@ -3102,45 +3052,6 @@ def test_invalid_nccl_blocking_wait_env(self):
         self._run_invalid_nccl_blocking_wait_env("2147483647")
         self._run_invalid_nccl_blocking_wait_env("4294967295")
 
-    @with_nccl_blocking_wait
-    @requires_nccl()
-    @requires_gloo()
-    @skip_if_lt_x_gpu(3)
-    def test_nccl_timeout(self):
-        store = c10d.FileStore(self.file_name, self.world_size)
-
-        # Initialize process_group.
-        process_group = c10d.ProcessGroupNCCL(
-            store, self.rank, self.world_size, timeout=timedelta(seconds=10)
-        )
-        # Control gloo pg used as go-ahead signal/barrier
-        # to coordinate btwn ranks.
-        pg_gloo = c10d.ProcessGroupGloo(store, self.rank, self.world_size)
-        failed_collective_timeout = timedelta(milliseconds=100)
-        process_group.allreduce(torch.rand(10).cuda(self.rank)).wait(
-            timeout=timedelta(seconds=5)
-        )
-
-        if self.rank == 0:
-            # This should timeout in about 1 second.
-            # Watchdog may abort timed out work resulting in NCCL error instead of operation timed out.
-            with self.assertRaisesRegex(
-                dist.DistBackendError, self.blocking_wait_error_msg
-            ):
-                process_group.allreduce(torch.rand(10).cuda(self.rank)).wait(
-                    timeout=failed_collective_timeout
-                )
-            # Now do a barrier to tell other rank to go ahead.
-            pg_gloo.barrier().wait()
-        else:
-            # Wait on rank 0 to fail.
-            try:
-                pg_gloo.barrier().wait()
-            except Exception as e:
-                raise ValueError(
-                    f"Rank {self.rank} barrier timed out waiting for rank 0 with error: {str(e)}"
-                ) from e
-
 
 class NcclUserBufferRegistrationTest(MultiProcessTestCase):
     def setUp(self):