[WIP] Consolidate watchdog and monitoring thread · pytorch/pytorch@05a260d · GitHub


Commit 05a260d

[WIP] Consolidate watchdog and monitoring thread
ghstack-source-id: 777a504
Pull Request resolved: #153668
1 parent 7243c69 commit 05a260d
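
This change moves the heartbeat-monitor state (heartbeat counter, timeout and poll-interval settings, enable flag, mutex, condition variable, and the monitor thread handle) from per-instance members to static members of ProcessGroupNCCL, initialized once per process via c10::call_once, so all process groups share a single monitor instead of each owning its own. A minimal standalone sketch of that pattern, using std::call_once in place of c10::call_once and illustrative names rather than the real PyTorch members:

#include <atomic>
#include <cstdint>
#include <mutex>
#include <thread>

// Sketch only: "MonitorState" and its members stand in for the static
// ProcessGroupNCCL members introduced by this commit.
struct MonitorState {
  static std::once_flag initFlag;
  static std::atomic<std::uint64_t> heartbeat;
  static int heartbeatTimeoutSec;
  static std::atomic<bool> monitorEnabled;
  static std::thread monitorThread;

  // Every process-group constructor may call this; only the first call runs
  // the lambda, so the settings are read exactly once per process.
  static void initOnce(int timeoutSecFromEnv, bool enabledFromEnv) {
    std::call_once(initFlag, [&] {
      heartbeat.store(1);
      heartbeatTimeoutSec = timeoutSecFromEnv;
      monitorEnabled.store(enabledFromEnv);
    });
  }
};

std::once_flag MonitorState::initFlag;
std::atomic<std::uint64_t> MonitorState::heartbeat{0};
int MonitorState::heartbeatTimeoutSec = 0;
std::atomic<bool> MonitorState::monitorEnabled{false};
std::thread MonitorState::monitorThread;

With the state shared process-wide, the per-instance initializer for terminateHeartbeatMonitorThread_ is no longer needed, which is exactly what the constructor hunks below remove.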

File tree

2 files changed (+121, -89 lines)


torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp

Lines changed: 97 additions & 73 deletions
@@ -925,6 +925,19 @@ constexpr const char* MULTI_DEVICE_ERROR_MSG =
     "https://pytorch.org/docs/stable/distributed.html#multi-gpu-collective-functions. "
     "ProcessGroupNCCL continues supporting multi-process and multi-thread modes.";
 
+std::atomic<bool> ProcessGroupNCCL::terminateHeartbeatMonitorThread_{false};
+c10::once_flag ProcessGroupNCCL::initFlag_;
+std::atomic_uint64_t ProcessGroupNCCL::heartbeat_;
+int ProcessGroupNCCL::heartbeatTimeoutInSec_;
+int ProcessGroupNCCL::waitTimeoutDumpInMilSec_;
+int ProcessGroupNCCL::coordCheckIntervalMilSec_;
+std::atomic<bool> ProcessGroupNCCL::watchdogHeartbeatMonitorEnabled_;
+std::mutex ProcessGroupNCCL::monitorMutex_;
+std::condition_variable ProcessGroupNCCL::monitorWakeUpCV_;
+bool ProcessGroupNCCL::dumpOnTimeoutOrEx_;
+std::string ProcessGroupNCCL::globalLogPrefix_;
+std::thread ProcessGroupNCCL::ncclHeartbeatMonitorThread_;
+
 ProcessGroupNCCL::ProcessGroupNCCL(
     c10::intrusive_ptr<Store> store,
     int rank,
@@ -934,7 +947,6 @@ ProcessGroupNCCL::ProcessGroupNCCL(
       store_(std::move(store)),
       options_(std::move(options)),
       terminateProcessGroup_(false),
-      terminateHeartbeatMonitorThread_(false),
       local_id_(process_group_id++),
       intraNodeComm_(initIntraNodeComm()) {
   TORCH_CHECK_WITH(
@@ -956,33 +968,30 @@ ProcessGroupNCCL::ProcessGroupNCCL(
   desyncDebug_ = getCvarBool(TORCH_NCCL_DESYNC_DEBUG, false) ||
       (dist_debug_level_ >= DebugLevel::Detail);
   rethrowCUDAErrors_ = getCvarBool(TORCH_NCCL_RETHROW_CUDA_ERRORS, true);
-  // TODO, we should either deprecate TORCH_NCCL_DUMP_ON_TIMEOUT
-  // or change its name to reflect that dump happens on exception including
-  // both timeout and other errors.
-  dumpOnTimeoutOrEx_ = getCvarBool(TORCH_NCCL_DUMP_ON_TIMEOUT, true) ||
-      (dist_debug_level_ >= DebugLevel::Detail);
   propagatePgError_ = getCvarBool(TORCH_NCCL_PROPAGATE_ERROR, false);
   // logging C++ stack isn't safe. Introduce a variable to control it.
   logCppStackOnUncleanShutdown_ =
       getCvarBool(TORCH_NCCL_LOG_CPP_STACK_ON_UNCLEAN_SHUTDOWN, true);
   enableNanCheck_ = getCvarBool(TORCH_NCCL_NAN_CHECK, false);
-  heartbeat_ = 1ULL;
-  monitorThreadEnabled_.store(getCvarBool(TORCH_NCCL_ENABLE_MONITORING, true));
   cudaEventCacheEnabled_.store(getCvarBool(TORCH_NCCL_CUDA_EVENT_CACHE, true));
-  heartbeatTimeoutInSec_ =
-      getCvarInt(TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC, 60 * 8 /*8 Mins*/);
-  waitTimeoutDumpInMilSec_ =
-      getCvarInt(TORCH_NCCL_WAIT_TIMEOUT_DUMP_MILSEC, 15 * 1000 /*15 Sec*/);
-  coordCheckIntervalMilSec_ = getCvarInt(TORCH_NCCL_COORD_CHECK_MILSEC, 1000);
   traceBufferSize_ = getCvarInt(TORCH_NCCL_TRACE_BUFFER_SIZE, 2000);
   enableCollectiveHashDebug_ = (dist_debug_level_ >= DebugLevel::Detail);
-  // store_ usually is wrapped with PrefixStore and the prefix is different
-  // across different ProcessGroupNCCL(PG) instances. We need to get the
-  // underlying non-PrefixStore for sharing global information shared across
-  // different PGs.
-  PrefixStore* prefixStore = dynamic_cast<PrefixStore*>(store_.get());
-  globalStore_ =
-      prefixStore ? prefixStore->getUnderlyingNonPrefixStore() : store_;
+  c10::call_once(initFlag_, [&]() {
+    heartbeat_ = 1ULL;
+    heartbeatTimeoutInSec_ =
+        getCvarInt(TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC, 60 * 8 /*8 Mins*/);
+    waitTimeoutDumpInMilSec_ =
+        getCvarInt(TORCH_NCCL_WAIT_TIMEOUT_DUMP_MILSEC, 15 * 1000 /*15 Sec*/);
+    coordCheckIntervalMilSec_ = getCvarInt(TORCH_NCCL_COORD_CHECK_MILSEC, 1000);
+    // TODO, we should either deprecate TORCH_NCCL_DUMP_ON_TIMEOUT
+    // or change its name to reflect that dump happens on exception including
+    // both timeout and other errors.
+    dumpOnTimeoutOrEx_ = getCvarBool(TORCH_NCCL_DUMP_ON_TIMEOUT, true) ||
+        (dist_debug_level_ >= DebugLevel::Detail);
+    watchdogHeartbeatMonitorEnabled_.store(
+        getCvarBool(TORCH_NCCL_ENABLE_MONITORING, true));
+    globalLogPrefix_ = c10::str("[Global Rank ", globalRank(), "] ");
+  });
 #ifdef ENABLE_NCCL_ERROR_CHECKING
   enableTiming_.store(
       getCvarBool(TORCH_NCCL_ENABLE_TIMING, false) || desyncDebug_);
@@ -1046,7 +1055,7 @@ ProcessGroupNCCL::ProcessGroupNCCL(
           << shouldAllCommunicatorsRegisterAllTensors()
 #endif // NCCL_HAS_COMM_REGISTER
           << ", TORCH_NCCL_ENABLE_MONITORING: "
-          << monitorThreadEnabled_.load()
+          << watchdogHeartbeatMonitorEnabled_.load()
           << ", TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC: " << heartbeatTimeoutInSec_
           << ", TORCH_NCCL_TRACE_BUFFER_SIZE: " << traceBufferSize_
           << ", TORCH_NCCL_COORD_CHECK_MILSEC: " << coordCheckIntervalMilSec_
@@ -1521,6 +1530,25 @@ void ProcessGroupNCCL::shutdown() {
   LOG(INFO) << logPrefix() << "Destroy complete.";
 }
 
+void ProcessGroupNCCL::startMonitorThread() {
+  static c10::once_flag startFlag;
+  c10::call_once(startFlag, [this]() {
+    ncclHeartbeatMonitorThread_ =
+        std::thread(&ProcessGroupNCCL::heartbeatMonitor, this);
+  });
+}
+
+void ProcessGroupNCCL::waitMonitorThread() {
+  static c10::once_flag shutdownFlag;
+  c10::call_once(shutdownFlag, [this]() {
+    if (ncclHeartbeatMonitorThread_.joinable()) {
+      ncclHeartbeatMonitorThread_.join();
+      LOG(INFO) << logPrefix()
+                << "ProcessGroupNCCL heart beat monitor thread joined.";
+    }
+  });
+}
+
 // NOLINTNEXTLINE(bugprone-exception-escape)
 ProcessGroupNCCL::~ProcessGroupNCCL() {
   LOG(INFO) << logPrefix() << "ProcessGroupNCCL destructor entered.";
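
The startMonitorThread() and waitMonitorThread() helpers added above make launching and joining the shared monitor thread idempotent: each is guarded by a static c10::once_flag, so only the first caller actually starts or joins the thread, and every destructor can call waitMonitorThread() safely. A standalone sketch of the same idiom with std:: primitives (illustrative names, not the PyTorch source; the sketch also flips the terminate flag inside the join helper only to stay self-contained):

#include <atomic>
#include <chrono>
#include <mutex>
#include <thread>

namespace sketch {

std::thread monitorThread;
std::atomic<bool> terminateFlag{false};

void heartbeatMonitor() {
  // Placeholder loop; the real monitor polls heartbeats and dump signals.
  while (!terminateFlag.load()) {
    std::this_thread::sleep_for(std::chrono::milliseconds(100));
  }
}

void startMonitorThread() {
  static std::once_flag startFlag;
  // Only the first caller launches the thread; later calls are no-ops.
  std::call_once(startFlag, [] { monitorThread = std::thread(heartbeatMonitor); });
}

void waitMonitorThread() {
  static std::once_flag joinFlag;
  // Only one caller joins, so any number of destructors may call this.
  std::call_once(joinFlag, [] {
    terminateFlag.store(true);
    if (monitorThread.joinable()) {
      monitorThread.join();
    }
  });
}

} // namespace sketch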
@@ -1571,11 +1599,7 @@ ProcessGroupNCCL::~ProcessGroupNCCL() {
     ncclCommWatchdogThread_.join();
     LOG(INFO) << logPrefix() << "ProcessGroupNCCL watchdog thread joined.";
   }
-  if (ncclHeartbeatMonitorThread_.joinable()) {
-    ncclHeartbeatMonitorThread_.join();
-    LOG(INFO) << logPrefix()
-              << "ProcessGroupNCCL heart beat monitor thread joined.";
-  }
+  waitMonitorThread();
   if (onCompletionHookThread_.joinable()) {
     onCompletionHookThread_.join();
     LOG(INFO) << logPrefix()
@@ -1631,11 +1655,6 @@ std::string ProcessGroupNCCL::getNCCLWatchdogTimeoutErrorMsg(
       "Received a dump signal due to a collective timeout from ",
       extraMsg,
       " and we will try our best to dump the debug info. ",
-      "Last enqueued NCCL work: ",
-      pgStatus_->lastEnqueuedSeq,
-      ", last completed NCCL work: ",
-      pgStatus_->lastCompletedSeq,
-      ".",
       "This is most likely caused by incorrect usages of collectives, e.g., wrong ",
       "sizes used across ranks, the order of collectives is not same for all ranks ",
       "or the scheduled collective, for some reason, didn't run. Additionally, ",
@@ -1660,28 +1679,21 @@ void ProcessGroupNCCL::heartbeatMonitor() {
   c10::setThreadName("pt_nccl_heartbt");
 
   uint64_t heartBeatCounter = 0ULL;
+  bool watchdogThreadHang = false;
   std::string errorMsg;
   std::string exitReason;
-  bool checkDumpSignal = (dumpOnTimeoutOrEx_ && local_id_ == 0);
-  int monitorPollInterval = checkDumpSignal || propagatePgError_
-      ? coordCheckIntervalMilSec_
-      : heartbeatTimeoutInSec_ * 1000;
   auto lastTimePollStore = std::chrono::steady_clock::now();
   auto lastTimeHeartBeatCheck = std::chrono::steady_clock::now();
   std::optional<DumpPipe> dumpPipe = std::nullopt;
-  if (local_id_ == 0) {
-    // DumpPipe is one per-trainer process, and its convenient to name them
-    // after 'global' ranks in the system, So we assume processgroup (uid)==0 is
-    // the global PG and has globally unique rank ids across trainers.
-    dumpPipe.emplace(rank_);
-  }
+  // To use the pipe correctly, one needs to initiate the default PG first.
+  dumpPipe.emplace(globalRank());
   while (true) {
     // This won't have any lock since this lock is only used here.
     // Please be aware that mutex `monitorMutex_` should not be used
     // somewhere else to avoid the deadlock.
     std::unique_lock<std::mutex> lock(monitorMutex_);
     if (monitorWakeUpCV_.wait_for(
-            lock, std::chrono::milliseconds(monitorPollInterval), [&] {
+            lock, std::chrono::milliseconds(coordCheckIntervalMilSec_), [&] {
              return terminateHeartbeatMonitorThread_.load();
            })) {
       // For the normal complete or user interception, monitorWakeUpCV_
@@ -1690,19 +1702,14 @@ void ProcessGroupNCCL::heartbeatMonitor() {
     }
     auto currentTime = std::chrono::steady_clock::now();
 
-    if (propagatePgError_) {
-      // Check and set remote error if it has not been set before
-      checkAndSetRemoteError();
-    }
-
     // We put extra functionality in the thread for the default PG (aka,
     // local_id_=0) because the signal is same across different PGs. We only
     // need to run once per process to avoid duplicate things performed in too
     // many separate threads. For example, we check a global flag on the
     // TCPStore periodically to see if any PG on any rank observed a timeout and
     // signaled peers to dump debugging info, and we avoid hammering the
     // TCPStore from all PGs on the same rank.
-    if (checkDumpSignal) {
+    if (dumpOnTimeoutOrEx_) {
       // There are two scenarios where monitor thread will dump on timeout:
       // 1. The current rank is the first to observe a timeout in watchdog.
       // (shouldDump_ was set to true by the watchdog thread).
@@ -1724,7 +1731,7 @@ void ProcessGroupNCCL::heartbeatMonitor() {
         lastTimePollStore = currentTime;
         auto handleError = [&](const std::string& errorMessage) {
           LOG(WARNING)
-              << logPrefix()
+              << globalLogPrefix()
              << "Failed to check the \"should dump\" flag on TCPStore, "
              << "(maybe TCPStore server has shut down too early), with error: "
              << errorMessage;
@@ -1736,7 +1743,7 @@
         bool checkExceptionDump = false;
         try {
           checkExceptionDump =
-              globalStore_->check({std::string(kStoreDumpKey)});
+              globalStore()->check({std::string(kStoreDumpKey)});
         } catch (const c10::DistNetworkError& e) {
           handleError(e.msg());
         } catch (const std::exception& e) {
@@ -1747,12 +1754,12 @@
         int timeOutRank = -1;
         if (!shouldDump_.load()) {
           LOG(ERROR)
-              << logPrefix()
+              << globalLogPrefix()
              << "Observed flight recorder dump signal from another rank via TCPStore.";
         }
         shouldDump_.store(true);
         try {
-          auto vec = globalStore_->get(std::string(kStoreDumpKey));
+          auto vec = globalStore()->get(std::string(kStoreDumpKey));
           TORCH_CHECK_WITH(
               DistBackendError,
               vec.size() == sizeof(int),
@@ -1782,7 +1789,7 @@
         shouldDump_.store(true);
         // Watchdog heartbeat timeout.
         errorMsg = c10::str(
-            logPrefix(),
+            globalLogPrefix(),
             "ProcessGroupNCCL's watchdog got stuck for ",
             heartbeatTimeoutInSec_,
             " seconds without making progress in monitoring enqueued collectives. ",
@@ -1818,7 +1825,7 @@
   // TORCH_NCCL_LOG_CPP_STACK_ON_UNCLEAN_SHUTDOWN=0).
 
   // Dump the nccl trace (flight recorder).
-  if (checkDumpSignal && shouldDump_.load()) {
+  if (dumpOnTimeoutOrEx_ && shouldDump_.load()) {
     // Store debug info to storage if no other thread does it. (By default to
     // local disk)
     bool dumpStackTrace = true;
@@ -1844,7 +1851,7 @@
 
       if (complete) {
         LOG(INFO)
-            << logPrefix()
+            << globalLogPrefix()
            << "Finished flight recorder successfully. Output can be analyzed using the fr_trace script.";
         if (i > 0) {
           debugLog.strings["exception_msg"] = "Dump with stack trace failed.";
@@ -1872,22 +1879,23 @@
           futStatus != std::future_status::deferred,
           "Expected the future to have been launched eagerly.");
       LOG(ERROR)
-          << logPrefix()
+          << globalLogPrefix()
          << "Could not acquire GIL within 300 ms on exit, possible GIL induced hang";
     }
   } else {
     VLOG(2)
-        << logPrefix()
+        << globalLogPrefix()
        << "GIL checker was not registered, perhaps this is a no-python build?";
   }
 
   // Dump the c++ stacktraces.
   auto& cpp_dumper = get_cpp_trace_dumper();
   if (logCppStackOnUncleanShutdown_ && cpp_dumper.has_value()) {
-    LOG(INFO) << logPrefix() << "Dumping c++ stacktraces:";
-    cpp_dumper.value()(
-        [&](const std::string& line) { LOG(INFO) << logPrefix() << line; });
-    LOG(INFO) << logPrefix() << "Finished c++ stacktraces dump.";
+    LOG(INFO) << globalLogPrefix() << "Dumping c++ stacktraces:";
+    cpp_dumper.value()([&](const std::string& line) {
+      LOG(INFO) << globalLogPrefix() << line;
+    });
+    LOG(INFO) << globalLogPrefix() << "Finished c++ stacktraces dump.";
   }
 
   // There are two possible cases for the watchdog thread exit:
@@ -1898,13 +1906,15 @@ void ProcessGroupNCCL::heartbeatMonitor() {
   // Case two: desync might be slow or get stuck. Or we get stuck in
   // destructors, we will sleep for some time before calling std::abort() to
   // kill the whole process.
-  if ((terminateProcessGroup_.load() || desyncDebug_ || shouldDump_.load()) &&
-      !terminateHeartbeatMonitorThread_.load()) {
-    // Leave another two mins for desync report generation or process group
-    // destroy.
+  auto currentTime = std::chrono::steady_clock::now();
+  if (!watchdogThreadHang && !terminateHeartbeatMonitorThread_.load() &&
+      (computeDeltaMS(lastTimeHeartBeatCheck, currentTime) <
+       heartbeatTimeoutInSec_ * 1000l)) {
     std::this_thread::sleep_for(std::chrono::seconds(heartbeatTimeoutInSec_));
-    LOG(INFO) << logPrefix() << "slept for " << heartbeatTimeoutInSec_
-              << " waiting for desync report or process group destroy.";
+    LOG(INFO)
+        << globalLogPrefix() << "slept for " << heartbeatTimeoutInSec_
+        << " because we want to wait longer to verify there is indeed a watchdog hang.";
+    watchdogThreadHang = (heartbeat_.load() == heartBeatCounter);
   }
 
   // At this point, we either already sleep for another `heartbeatTimeoutInSec_`
@@ -1917,19 +1927,19 @@ void ProcessGroupNCCL::heartbeatMonitor() {
   // check the return value here. We mainly use a future so we can exit early
   // if done.
 
-  if (!terminateHeartbeatMonitorThread_.load()) {
+  if (!terminateHeartbeatMonitorThread_.load() && watchdogThreadHang) {
     // Create a error message reported from MonitorThread, so
     // we throw exception and make the whole process to be killed.
     // TODO(fduwjj): After having a hang debug wiki, we need to update the wiki
     // url here.
-    if (monitorThreadEnabled_.load()) {
+    if (watchdogHeartbeatMonitorEnabled_.load()) {
       terminateProcess(getNCCLWatchdogTimeoutExitMsg(exitReason));
     } else {
       // Ideally we want to merge this one with the above one, but we are going
      // to remove the kill switch for monitor thread soon, so we keep this one
      // for now.
      LOG(ERROR)
-          << logPrefix()
+          << globalLogPrefix()
          << "ProcessGroupNCCL monitor thread is disabled, but would have terminated the process"
          << "after attempting to dump debug info, due to " << exitReason
          << ".";
@@ -1942,8 +1952,7 @@ void ProcessGroupNCCL::ncclCommWatchdog() {
 
   try {
     VLOG(2) << logPrefix() << "Process group watchdog thread started!";
-    ncclHeartbeatMonitorThread_ =
-        std::thread(&ProcessGroupNCCL::heartbeatMonitor, this);
+    startMonitorThread();
     watchdogHandler();
     VLOG(2) << logPrefix()
             << "Process group watchdog thread terminated normally";
@@ -2098,11 +2107,20 @@ const std::string& ProcessGroupNCCL::logPrefix() const {
   return logPrefix_;
 }
 
+const std::string& ProcessGroupNCCL::globalLogPrefix() {
+  return globalLogPrefix_;
+}
+
 const int& ProcessGroupNCCL::globalRank() const {
   static int globalRank = rank_;
   return globalRank;
 }
 
+const c10::intrusive_ptr<Store>& ProcessGroupNCCL::globalStore() const {
+  static c10::intrusive_ptr<Store> globalStore = store_;
+  return globalStore;
+}
+
 const std::vector<uint64_t>& ProcessGroupNCCL::groupRanks() const {
   if (options_->global_ranks_in_group.empty() && local_id_ == 0) {
     static std::vector<uint64_t> globalRanks(size_);
@@ -2182,7 +2200,8 @@ int ProcessGroupNCCL::getSignalSrcRank(
 
 void ProcessGroupNCCL::broadcastDumpSignal() {
   // broadcast dump signal to all other global ranks.
-  broadcastSignal(globalStore_, std::string(kStoreDumpKey), globalRank());
+  auto global_store = globalStore();
+  broadcastSignal(global_store, std::string(kStoreDumpKey), globalRank());
   // signal the local rank to start dumping
   if (!shouldDump_.load()) {
     LOG(ERROR) << logPrefix() << "First PG on this rank to signal dumping.";
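
The broadcastDumpSignal() change above pairs with the globalStore()->check()/get() polling in heartbeatMonitor(): the first rank to detect a timeout writes its global rank under a well-known key on the shared store, and every other rank's monitor notices the key and starts its own flight-recorder dump. A self-contained sketch of that signal-over-store pattern (the KvStore class below is an illustrative stand-in, not the c10d Store API):

#include <cstdint>
#include <cstring>
#include <map>
#include <mutex>
#include <string>
#include <vector>

// Toy in-process key-value store; the real code talks to a TCPStore shared by
// all ranks.
class KvStore {
 public:
  void set(const std::string& key, const std::vector<uint8_t>& value) {
    std::lock_guard<std::mutex> g(mu_);
    data_[key] = value;
  }
  bool check(const std::string& key) {
    std::lock_guard<std::mutex> g(mu_);
    return data_.count(key) != 0;
  }
  std::vector<uint8_t> get(const std::string& key) {
    std::lock_guard<std::mutex> g(mu_);
    return data_.at(key);
  }

 private:
  std::mutex mu_;
  std::map<std::string, std::vector<uint8_t>> data_;
};

// Writer side: the rank that observed the timeout publishes its rank.
void broadcastDumpSignal(KvStore& store, int srcRank) {
  std::vector<uint8_t> payload(sizeof(int));
  std::memcpy(payload.data(), &srcRank, sizeof(int));
  store.set("dump_signal", payload);
}

// Reader side: each monitor polls the key and recovers the source rank.
bool pollDumpSignal(KvStore& store, int* srcRankOut) {
  if (!store.check("dump_signal")) {
    return false;
  }
  auto vec = store.get("dump_signal");
  std::memcpy(srcRankOut, vec.data(), sizeof(int));
  return true;
}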
@@ -2300,6 +2319,11 @@ void ProcessGroupNCCL::watchdogHandler() {
       lastStatusUpdateTime = std::chrono::steady_clock::now();
     }
 
+    if (propagatePgError_) {
+      // Check and set remote error if it has not been set before
+      checkAndSetRemoteError();
+    }
+
     for (auto it = workMetaList_.begin(); it != workMetaList_.end();
          /* no increment */) {
       auto& work = *it;
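
Taken together, the heartbeatMonitor() changes make the loop process-wide: it always wakes on coordCheckIntervalMilSec_, checks the store-based dump signal whenever dumpOnTimeoutOrEx_ is set, and only escalates after observing that the watchdog heartbeat counter made no progress over a heartbeatTimeoutInSec_ window, sleeping one extra window to confirm the hang before terminating. A simplified standalone sketch of that hang check (assumed names and intervals, not the PyTorch source):

#include <atomic>
#include <chrono>
#include <condition_variable>
#include <cstdint>
#include <mutex>
#include <thread>

std::atomic<std::uint64_t> heartbeat{1};   // incremented by the watchdog loop
std::atomic<bool> terminateMonitor{false}; // set during clean shutdown
std::mutex monitorMutex;
std::condition_variable monitorWakeUpCV;

void heartbeatMonitor(std::chrono::milliseconds pollInterval,
                      std::chrono::seconds heartbeatTimeout) {
  std::uint64_t lastSeen = 0;
  auto lastCheck = std::chrono::steady_clock::now();
  while (true) {
    std::unique_lock<std::mutex> lock(monitorMutex);
    if (monitorWakeUpCV.wait_for(lock, pollInterval,
                                 [] { return terminateMonitor.load(); })) {
      return; // clean shutdown was requested
    }
    auto now = std::chrono::steady_clock::now();
    if (now - lastCheck >= heartbeatTimeout) {
      lastCheck = now;
      auto current = heartbeat.load();
      bool watchdogHang = (current == lastSeen); // no progress in a full window
      lastSeen = current;
      if (watchdogHang) {
        // The real code dumps debug info and may abort the process here;
        // the sketch simply exits.
        return;
      }
    }
  }
}

The real loop additionally drives the flight-recorder dump and honors the watchdogHeartbeatMonitorEnabled_ kill switch; the sketch shows only the progress check.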
