@@ -925,6 +925,19 @@ constexpr const char* MULTI_DEVICE_ERROR_MSG =
    "https://pytorch.org/docs/stable/distributed.html#multi-gpu-collective-functions. "
    "ProcessGroupNCCL continues supporting multi-process and multi-thread modes.";

+std::atomic<bool> ProcessGroupNCCL::terminateHeartbeatMonitorThread_{false};
+c10::once_flag ProcessGroupNCCL::initFlag_;
+std::atomic_uint64_t ProcessGroupNCCL::heartbeat_;
+int ProcessGroupNCCL::heartbeatTimeoutInSec_;
+int ProcessGroupNCCL::waitTimeoutDumpInMilSec_;
+int ProcessGroupNCCL::coordCheckIntervalMilSec_;
+std::atomic<bool> ProcessGroupNCCL::watchdogHeartbeatMonitorEnabled_;
+std::mutex ProcessGroupNCCL::monitorMutex_;
+std::condition_variable ProcessGroupNCCL::monitorWakeUpCV_;
+bool ProcessGroupNCCL::dumpOnTimeoutOrEx_;
+std::string ProcessGroupNCCL::globalLogPrefix_;
+std::thread ProcessGroupNCCL::ncclHeartbeatMonitorThread_;
+
ProcessGroupNCCL::ProcessGroupNCCL(
    c10::intrusive_ptr<Store> store,
    int rank,
@@ -934,7 +947,6 @@ ProcessGroupNCCL::ProcessGroupNCCL(
      store_(std::move(store)),
      options_(std::move(options)),
      terminateProcessGroup_(false),
-      terminateHeartbeatMonitorThread_(false),
      local_id_(process_group_id++),
      intraNodeComm_(initIntraNodeComm()) {
  TORCH_CHECK_WITH(
@@ -956,33 +968,30 @@ ProcessGroupNCCL::ProcessGroupNCCL(
  desyncDebug_ = getCvarBool(TORCH_NCCL_DESYNC_DEBUG, false) ||
      (dist_debug_level_ >= DebugLevel::Detail);
  rethrowCUDAErrors_ = getCvarBool(TORCH_NCCL_RETHROW_CUDA_ERRORS, true);
-  // TODO, we should either deprecate TORCH_NCCL_DUMP_ON_TIMEOUT
-  // or change its name to reflect that dump happens on exception including
-  // both timeout and other errors.
-  dumpOnTimeoutOrEx_ = getCvarBool(TORCH_NCCL_DUMP_ON_TIMEOUT, true) ||
-      (dist_debug_level_ >= DebugLevel::Detail);
  propagatePgError_ = getCvarBool(TORCH_NCCL_PROPAGATE_ERROR, false);
  // logging C++ stack isn't safe. Introduce a variable to control it.
  logCppStackOnUncleanShutdown_ =
      getCvarBool(TORCH_NCCL_LOG_CPP_STACK_ON_UNCLEAN_SHUTDOWN, true);
  enableNanCheck_ = getCvarBool(TORCH_NCCL_NAN_CHECK, false);
-  heartbeat_ = 1ULL;
-  monitorThreadEnabled_.store(getCvarBool(TORCH_NCCL_ENABLE_MONITORING, true));
  cudaEventCacheEnabled_.store(getCvarBool(TORCH_NCCL_CUDA_EVENT_CACHE, true));
-  heartbeatTimeoutInSec_ =
-      getCvarInt(TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC, 60 * 8 /*8 Mins*/);
-  waitTimeoutDumpInMilSec_ =
-      getCvarInt(TORCH_NCCL_WAIT_TIMEOUT_DUMP_MILSEC, 15 * 1000 /*15 Sec*/);
-  coordCheckIntervalMilSec_ = getCvarInt(TORCH_NCCL_COORD_CHECK_MILSEC, 1000);
  traceBufferSize_ = getCvarInt(TORCH_NCCL_TRACE_BUFFER_SIZE, 2000);
  enableCollectiveHashDebug_ = (dist_debug_level_ >= DebugLevel::Detail);
-  // store_ usually is wrapped with PrefixStore and the prefix is different
-  // across different ProcessGroupNCCL(PG) instances. We need to get the
-  // underlying non-PrefixStore for sharing global information shared across
-  // different PGs.
-  PrefixStore* prefixStore = dynamic_cast<PrefixStore*>(store_.get());
-  globalStore_ =
-      prefixStore ? prefixStore->getUnderlyingNonPrefixStore() : store_;
+  c10::call_once(initFlag_, [&]() {
+    heartbeat_ = 1ULL;
+    heartbeatTimeoutInSec_ =
+        getCvarInt(TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC, 60 * 8 /*8 Mins*/);
+    waitTimeoutDumpInMilSec_ =
+        getCvarInt(TORCH_NCCL_WAIT_TIMEOUT_DUMP_MILSEC, 15 * 1000 /*15 Sec*/);
+    coordCheckIntervalMilSec_ = getCvarInt(TORCH_NCCL_COORD_CHECK_MILSEC, 1000);
+    // TODO, we should either deprecate TORCH_NCCL_DUMP_ON_TIMEOUT
+    // or change its name to reflect that dump happens on exception including
+    // both timeout and other errors.
+    dumpOnTimeoutOrEx_ = getCvarBool(TORCH_NCCL_DUMP_ON_TIMEOUT, true) ||
+        (dist_debug_level_ >= DebugLevel::Detail);
+    watchdogHeartbeatMonitorEnabled_.store(
+        getCvarBool(TORCH_NCCL_ENABLE_MONITORING, true));
+    globalLogPrefix_ = c10::str("[Global Rank ", globalRank(), "] ");
+  });
#ifdef ENABLE_NCCL_ERROR_CHECKING
  enableTiming_.store(
      getCvarBool(TORCH_NCCL_ENABLE_TIMING, false) || desyncDebug_);
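The hunks above turn the heartbeat-monitor state into static, process-wide members and guard their setup with c10::call_once, so the first ProcessGroupNCCL constructed in a process reads the environment knobs exactly once and every later instance reuses them. A minimal standalone sketch of that once-per-process initialization pattern follows; it uses std::call_once and hypothetical names (Config, readEnvInt, Group), not the actual PyTorch helpers.

// Sketch only: once-per-process initialization of shared static state,
// analogous to the c10::call_once(initFlag_, ...) block in the diff.
#include <cstdlib>
#include <iostream>
#include <mutex>

namespace {
constexpr int kDefaultTimeoutSec = 60 * 8; // mirrors the 8-minute default

// Hypothetical stand-in for the settings the diff makes static.
struct Config {
  int heartbeatTimeoutSec = 0;
  bool monitorEnabled = false;
};

Config g_config;            // shared by every instance in the process
std::once_flag g_initFlag;  // guards the one-time read of the environment

int readEnvInt(const char* name, int def) {
  const char* v = std::getenv(name);
  return v ? std::atoi(v) : def;
}
} // namespace

class Group {
 public:
  Group() {
    // Only the first constructed Group pays the cost; later ones reuse g_config.
    std::call_once(g_initFlag, [] {
      g_config.heartbeatTimeoutSec =
          readEnvInt("TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC", kDefaultTimeoutSec);
      g_config.monitorEnabled =
          readEnvInt("TORCH_NCCL_ENABLE_MONITORING", 1) != 0;
    });
  }
  int timeout() const { return g_config.heartbeatTimeoutSec; }
};

int main() {
  Group a, b; // the second construction does not re-read the environment
  std::cout << a.timeout() << " " << b.timeout() << "\n";
}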
@@ -1046,7 +1055,7 @@ ProcessGroupNCCL::ProcessGroupNCCL(
      << shouldAllCommunicatorsRegisterAllTensors()
#endif // NCCL_HAS_COMM_REGISTER
      << ", TORCH_NCCL_ENABLE_MONITORING: "
-      << monitorThreadEnabled_.load()
+      << watchdogHeartbeatMonitorEnabled_.load()
      << ", TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC: " << heartbeatTimeoutInSec_
      << ", TORCH_NCCL_TRACE_BUFFER_SIZE: " << traceBufferSize_
      << ", TORCH_NCCL_COORD_CHECK_MILSEC: " << coordCheckIntervalMilSec_
@@ -1521,6 +1530,25 @@ void ProcessGroupNCCL::shutdown() {
  LOG(INFO) << logPrefix() << "Destroy complete.";
}

+void ProcessGroupNCCL::startMonitorThread() {
+  static c10::once_flag startFlag;
+  c10::call_once(startFlag, [this]() {
+    ncclHeartbeatMonitorThread_ =
+        std::thread(&ProcessGroupNCCL::heartbeatMonitor, this);
+  });
+}
+
+void ProcessGroupNCCL::waitMonitorThread() {
+  static c10::once_flag shutdownFlag;
+  c10::call_once(shutdownFlag, [this]() {
+    if (ncclHeartbeatMonitorThread_.joinable()) {
+      ncclHeartbeatMonitorThread_.join();
+      LOG(INFO) << logPrefix()
+                << "ProcessGroupNCCL heart beat monitor thread joined.";
+    }
+  });
+}
+
// NOLINTNEXTLINE(bugprone-exception-escape)
ProcessGroupNCCL::~ProcessGroupNCCL() {
  LOG(INFO) << logPrefix() << "ProcessGroupNCCL destructor entered.";
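startMonitorThread() and waitMonitorThread() wrap thread creation and joining in their own once_flags, so that with many process groups per process the single shared monitor thread is started and joined exactly once, no matter which PG triggers it. A small self-contained sketch of that start-once/join-once lifecycle follows; the Monitor class and its names are illustrative, not the PyTorch implementation.

// Sketch only: a background thread that many owners may try to start or join,
// but which is actually started and joined exactly once per process.
#include <atomic>
#include <chrono>
#include <iostream>
#include <mutex>
#include <thread>

class Monitor {
 public:
  static void start() {
    static std::once_flag startFlag;
    std::call_once(startFlag, [] { thread_ = std::thread(run); });
  }

  static void stopAndJoin() {
    static std::once_flag joinFlag;
    std::call_once(joinFlag, [] {
      terminate_.store(true);
      if (thread_.joinable()) {
        thread_.join();
        std::cout << "monitor thread joined\n";
      }
    });
  }

 private:
  static void run() {
    while (!terminate_.load()) {
      std::this_thread::sleep_for(std::chrono::milliseconds(50));
    }
  }
  static inline std::thread thread_;
  static inline std::atomic<bool> terminate_{false};
};

int main() {
  Monitor::start();
  Monitor::start();        // no-op: already started
  Monitor::stopAndJoin();
  Monitor::stopAndJoin();  // no-op: already joined
}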
@@ -1571,11 +1599,7 @@ ProcessGroupNCCL::~ProcessGroupNCCL() {
    ncclCommWatchdogThread_.join();
    LOG(INFO) << logPrefix() << "ProcessGroupNCCL watchdog thread joined.";
  }
-  if (ncclHeartbeatMonitorThread_.joinable()) {
-    ncclHeartbeatMonitorThread_.join();
-    LOG(INFO) << logPrefix()
-              << "ProcessGroupNCCL heart beat monitor thread joined.";
-  }
+  waitMonitorThread();
  if (onCompletionHookThread_.joinable()) {
    onCompletionHookThread_.join();
    LOG(INFO) << logPrefix()
@@ -1631,11 +1655,6 @@ std::string ProcessGroupNCCL::getNCCLWatchdogTimeoutErrorMsg(
      "Received a dump signal due to a collective timeout from ",
      extraMsg,
      " and we will try our best to dump the debug info. ",
-      "Last enqueued NCCL work: ",
-      pgStatus_->lastEnqueuedSeq,
-      ", last completed NCCL work: ",
-      pgStatus_->lastCompletedSeq,
-      ".",
      "This is most likely caused by incorrect usages of collectives, e.g., wrong ",
      "sizes used across ranks, the order of collectives is not same for all ranks ",
      "or the scheduled collective, for some reason, didn't run. Additionally, ",
@@ -1660,28 +1679,21 @@ void ProcessGroupNCCL::heartbeatMonitor() {
  c10::setThreadName("pt_nccl_heartbt");

  uint64_t heartBeatCounter = 0ULL;
+  bool watchdogThreadHang = false;
  std::string errorMsg;
  std::string exitReason;
-  bool checkDumpSignal = (dumpOnTimeoutOrEx_ && local_id_ == 0);
-  int monitorPollInterval = checkDumpSignal || propagatePgError_
-      ? coordCheckIntervalMilSec_
-      : heartbeatTimeoutInSec_ * 1000;
  auto lastTimePollStore = std::chrono::steady_clock::now();
  auto lastTimeHeartBeatCheck = std::chrono::steady_clock::now();
  std::optional<DumpPipe> dumpPipe = std::nullopt;
-  if (local_id_ == 0) {
-    // DumpPipe is one per-trainer process, and its convenient to name them
-    // after 'global' ranks in the system, So we assume processgroup (uid)==0 is
-    // the global PG and has globally unique rank ids across trainers.
-    dumpPipe.emplace(rank_);
-  }
+  // To use the pipe correctly, one needs to initiate the default PG first.
+  dumpPipe.emplace(globalRank());
  while (true) {
    // This won't have any lock since this lock is only used here.
    // Please be aware that mutex `monitorMutex_` should not be used
    // somewhere else to avoid the deadlock.
    std::unique_lock<std::mutex> lock(monitorMutex_);
    if (monitorWakeUpCV_.wait_for(
-            lock, std::chrono::milliseconds(monitorPollInterval), [&] {
+            lock, std::chrono::milliseconds(coordCheckIntervalMilSec_), [&] {
              return terminateHeartbeatMonitorThread_.load();
            })) {
      // For the normal complete or user interception, monitorWakeUpCV_
@@ -1690,19 +1702,14 @@ void ProcessGroupNCCL::heartbeatMonitor() {
    }
    auto currentTime = std::chrono::steady_clock::now();

-    if (propagatePgError_) {
-      // Check and set remote error if it has not been set before
-      checkAndSetRemoteError();
-    }
-
    // We put extra functionality in the thread for the default PG (aka,
    // local_id_=0) because the signal is same across different PGs. We only
    // need to run once per process to avoid duplicate things performed in too
    // many separate threads. For example, we check a global flag on the
    // TCPStore periodically to see if any PG on any rank observed a timeout and
    // signaled peers to dump debugging info, and we avoid hammering the
    // TCPStore from all PGs on the same rank.
-    if (checkDumpSignal) {
+    if (dumpOnTimeoutOrEx_) {
      // There are two scenarios where monitor thread will dump on timeout:
      // 1. The current rank is the first to observe a timeout in watchdog.
      // (shouldDump_ was set to true by the watchdog thread).
@@ -1724,7 +1731,7 @@ void ProcessGroupNCCL::heartbeatMonitor() {
        lastTimePollStore = currentTime;
        auto handleError = [&](const std::string& errorMessage) {
          LOG(WARNING)
-              << logPrefix()
+              << globalLogPrefix()
              << "Failed to check the \"should dump\" flag on TCPStore, "
              << "(maybe TCPStore server has shut down too early), with error: "
              << errorMessage;
@@ -1736,7 +1743,7 @@ void ProcessGroupNCCL::heartbeatMonitor() {
        bool checkExceptionDump = false;
        try {
          checkExceptionDump =
-              globalStore_->check({std::string(kStoreDumpKey)});
+              globalStore()->check({std::string(kStoreDumpKey)});
        } catch (const c10::DistNetworkError& e) {
          handleError(e.msg());
        } catch (const std::exception& e) {
@@ -1747,12 +1754,12 @@ void ProcessGroupNCCL::heartbeatMonitor() {
        int timeOutRank = -1;
        if (!shouldDump_.load()) {
          LOG(ERROR)
-              << logPrefix()
+              << globalLogPrefix()
              << "Observed flight recorder dump signal from another rank via TCPStore.";
        }
        shouldDump_.store(true);
        try {
-          auto vec = globalStore_->get(std::string(kStoreDumpKey));
+          auto vec = globalStore()->get(std::string(kStoreDumpKey));
          TORCH_CHECK_WITH(
              DistBackendError,
              vec.size() == sizeof(int),
@@ -1782,7 +1789,7 @@ void ProcessGroupNCCL::heartbeatMonitor() {
        shouldDump_.store(true);
        // Watchdog heartbeat timeout.
        errorMsg = c10::str(
-            logPrefix(),
+            globalLogPrefix(),
            "ProcessGroupNCCL's watchdog got stuck for ",
            heartbeatTimeoutInSec_,
            " seconds without making progress in monitoring enqueued collectives. ",
@@ -1818,7 +1825,7 @@ void ProcessGroupNCCL::heartbeatMonitor() {
  // TORCH_NCCL_LOG_CPP_STACK_ON_UNCLEAN_SHUTDOWN=0).

  // Dump the nccl trace (flight recorder).
-  if (checkDumpSignal && shouldDump_.load()) {
+  if (dumpOnTimeoutOrEx_ && shouldDump_.load()) {
    // Store debug info to storage if no other thread does it. (By default to
    // local disk)
    bool dumpStackTrace = true;
@@ -1844,7 +1851,7 @@ void ProcessGroupNCCL::heartbeatMonitor() {

      if (complete) {
        LOG(INFO)
-            << logPrefix()
+            << globalLogPrefix()
            << "Finished flight recorder successfully. Output can be analyzed using the fr_trace script.";
        if (i > 0) {
          debugLog.strings["exception_msg"] = "Dump with stack trace failed.";
@@ -1872,22 +1879,23 @@ void ProcessGroupNCCL::heartbeatMonitor() {
          futStatus != std::future_status::deferred,
          "Expected the future to have been launched eagerly.");
      LOG(ERROR)
-          << logPrefix()
+          << globalLogPrefix()
          << "Could not acquire GIL within 300 ms on exit, possible GIL induced hang";
    }
  } else {
    VLOG(2)
-        << logPrefix()
+        << globalLogPrefix()
        << "GIL checker was not registered, perhaps this is a no-python build?";
  }

  // Dump the c++ stacktraces.
  auto& cpp_dumper = get_cpp_trace_dumper();
  if (logCppStackOnUncleanShutdown_ && cpp_dumper.has_value()) {
-    LOG(INFO) << logPrefix() << "Dumping c++ stacktraces:";
-    cpp_dumper.value()(
-        [&](const std::string& line) { LOG(INFO) << logPrefix() << line; });
-    LOG(INFO) << logPrefix() << "Finished c++ stacktraces dump.";
+    LOG(INFO) << globalLogPrefix() << "Dumping c++ stacktraces:";
+    cpp_dumper.value()([&](const std::string& line) {
+      LOG(INFO) << globalLogPrefix() << line;
+    });
+    LOG(INFO) << globalLogPrefix() << "Finished c++ stacktraces dump.";
  }

  // There are two possible cases for the watchdog thread exit:
@@ -1898,13 +1906,15 @@ void ProcessGroupNCCL::heartbeatMonitor() {
  // Case two: desync might be slow or get stuck. Or we get stuck in
  // destructors, we will sleep for some time before calling std::abort() to
  // kill the whole process.
-  if ((terminateProcessGroup_.load() || desyncDebug_ || shouldDump_.load()) &&
-      !terminateHeartbeatMonitorThread_.load()) {
-    // Leave another two mins for desync report generation or process group
-    // destroy.
+  auto currentTime = std::chrono::steady_clock::now();
+  if (!watchdogThreadHang && !terminateHeartbeatMonitorThread_.load() &&
+      (computeDeltaMS(lastTimeHeartBeatCheck, currentTime) <
+       heartbeatTimeoutInSec_ * 1000l)) {
    std::this_thread::sleep_for(std::chrono::seconds(heartbeatTimeoutInSec_));
-    LOG(INFO) << logPrefix() << "slept for " << heartbeatTimeoutInSec_
-              << " waiting for desync report or process group destroy.";
+    LOG(INFO)
+        << globalLogPrefix() << "slept for " << heartbeatTimeoutInSec_
+        << " because we want to wait longer to verify there is indeed a watchdog hang.";
+    watchdogThreadHang = (heartbeat_.load() == heartBeatCounter);
  }

  // At this point, we either already sleep for another `heartbeatTimeoutInSec_`
@@ -1917,19 +1927,19 @@ void ProcessGroupNCCL::heartbeatMonitor() {
  // check the return value here. We mainly use a future so we can exit early
  // if done.

-  if (!terminateHeartbeatMonitorThread_.load()) {
+  if (!terminateHeartbeatMonitorThread_.load() && watchdogThreadHang) {
    // Create a error message reported from MonitorThread, so
    // we throw exception and make the whole process to be killed.
    // TODO(fduwjj): After having a hang debug wiki, we need to update the wiki
    // url here.
-    if (monitorThreadEnabled_.load()) {
+    if (watchdogHeartbeatMonitorEnabled_.load()) {
      terminateProcess(getNCCLWatchdogTimeoutExitMsg(exitReason));
    } else {
      // Ideally we want to merge this one with the above one, but we are going
      // to remove the kill switch for monitor thread soon, so we keep this one
      // for now.
      LOG(ERROR)
-          << logPrefix()
+          << globalLogPrefix()
          << "ProcessGroupNCCL monitor thread is disabled, but would have terminated the process"
          << " after attempting to dump debug info, due to " << exitReason
          << ".";
@@ -1942,8 +1952,7 @@ void ProcessGroupNCCL::ncclCommWatchdog() {

  try {
    VLOG(2) << logPrefix() << "Process group watchdog thread started!";
-    ncclHeartbeatMonitorThread_ =
-        std::thread(&ProcessGroupNCCL::heartbeatMonitor, this);
+    startMonitorThread();
    watchdogHandler();
    VLOG(2) << logPrefix()
            << "Process group watchdog thread terminated normally";
@@ -2098,11 +2107,20 @@ const std::string& ProcessGroupNCCL::logPrefix() const {
  return logPrefix_;
}

+const std::string& ProcessGroupNCCL::globalLogPrefix() {
+  return globalLogPrefix_;
+}
+
const int& ProcessGroupNCCL::globalRank() const {
  static int globalRank = rank_;
  return globalRank;
}

+const c10::intrusive_ptr<Store>& ProcessGroupNCCL::globalStore() const {
+  static c10::intrusive_ptr<Store> globalStore = store_;
+  return globalStore;
+}
+
const std::vector<uint64_t>& ProcessGroupNCCL::groupRanks() const {
  if (options_->global_ranks_in_group.empty() && local_id_ == 0) {
    static std::vector<uint64_t> globalRanks(size_);
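Both globalRank() and the new globalStore() rely on a function-local static that is initialized from whichever ProcessGroupNCCL instance calls the accessor first (normally the default/world group), so that value becomes the process-wide one. A minimal sketch of this "first caller wins" pattern follows; the class and member names are illustrative.

// Sketch only: a function-local static captures the value from the first
// instance that calls the accessor; later instances see the same value.
#include <iostream>

class Group {
 public:
  explicit Group(int rank) : rank_(rank) {}

  // The first call initializes the static from this instance's rank_;
  // subsequent calls, from any instance, return that same value.
  const int& globalRank() const {
    static int globalRank = rank_;
    return globalRank;
  }

 private:
  int rank_;
};

int main() {
  Group world(3); // constructed and queried first
  Group sub(0);
  std::cout << world.globalRank() << " " << sub.globalRank() << "\n"; // 3 3
}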
@@ -2182,7 +2200,8 @@ int ProcessGroupNCCL::getSignalSrcRank(

void ProcessGroupNCCL::broadcastDumpSignal() {
  // broadcast dump signal to all other global ranks.
-  broadcastSignal(globalStore_, std::string(kStoreDumpKey), globalRank());
+  auto global_store = globalStore();
+  broadcastSignal(global_store, std::string(kStoreDumpKey), globalRank());
  // signal the local rank to start dumping
  if (!shouldDump_.load()) {
    LOG(ERROR) << logPrefix() << "First PG on this rank to signal dumping.";
@@ -2300,6 +2319,11 @@ void ProcessGroupNCCL::watchdogHandler() {
      lastStatusUpdateTime = std::chrono::steady_clock::now();
    }

+    if (propagatePgError_) {
+      // Check and set remote error if it has not been set before
+      checkAndSetRemoteError();
+    }
+
    for (auto it = workMetaList_.begin(); it != workMetaList_.end();
         /* no increment */) {
      auto& work = *it;