[PyTorch][NCCL PG][Resubmit D67193887] Change getNCCLCommDumpMap to u… · pytorch/pytorch@b3c444f · GitHub
[go: up one dir, main page]

Skip to content

Commit b3c444f

Browse files
jiayulu authored and facebook-github-bot committed
[PyTorch][NCCL PG][Resubmit D67193887] Change getNCCLCommDumpMap to use new ncclCommDumpAll API
Summary: see D67193887
Test Plan: https://www.internalfb.com/mlhub/pipelines/runs/mast/aps-dlrm_fsdp2_test-7711e496fc?job_attempt=0&version=0&tab=execution_details&env=PRODUCTION
Differential Revision: D74820576
1 parent 480ae2d commit b3c444f

File tree

2 files changed

+23
-4
lines changed

2 files changed

+23
-4
lines changed

torch/csrc/distributed/c10d/NCCLUtils.hpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -280,6 +280,13 @@ class NCCLComm {
280280

281281
ncclUniqueId getNcclId();
282282
at::DeviceIndex getDeviceIndex();
283+
#if defined(IS_NCCLX) && defined(NCCL_COMM_GET_UNIQUE_HASH)
// Returns the 64-bit unique hash that ncclCommGetUniqueHash reports for this
// object's communicator (ncclComm_).
//
// Only compiled when both IS_NCCLX and NCCL_COMM_GET_UNIQUE_HASH are defined.
// NOTE(review): the caller in getNCCLCommDumpMap (ProcessGroupNCCL.cpp) is
// guarded by (IS_NCCLX || USE_ROCM) && NCCL_COMM_DUMP &&
// NCCL_COMM_GET_UNIQUE_HASH, which is wider than this guard. If a USE_ROCM
// build ever defines NCCL_COMM_GET_UNIQUE_HASH without IS_NCCLX, this method
// would be missing at the call site -- confirm the guards are meant to differ.
284+
uint64_t getNcclUniqueHash() {
285+
// Initialised to 0 so the out-parameter has a defined value before the call.
uint64_t ncclUniqueHash = 0;
286+
// C10D_NCCL_CHECK is defined outside this view; presumably it raises on a
// non-success NCCL result -- TODO confirm. std::nullopt is passed as its
// second argument (no extra failure-reason text is supplied here).
C10D_NCCL_CHECK(ncclCommGetUniqueHash(ncclComm_, &ncclUniqueHash), std::nullopt);
287+
return ncclUniqueHash;
288+
}
289+
#endif
283290

284291
// Must not be copyable
285292
NCCLComm(const NCCLComm&) = delete;

torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -364,11 +364,20 @@ static void attachAllocatorHooks() {
364364
static std::
365365
unordered_map<std::string, std::unordered_map<std::string, std::string>>
366366
getNCCLCommDumpMap() {
367-
#if (defined(IS_NCCLX) || defined(USE_ROCM)) && defined(NCCL_COMM_DUMP)
367+
#if (defined(IS_NCCLX) || defined(USE_ROCM)) && defined(NCCL_COMM_DUMP) && \
368+
defined(NCCL_COMM_GET_UNIQUE_HASH)
368369
std::unordered_map<
369-
std::string /* ncclUniqueID */,
370+
std::string /* CommHash */,
370371
std::unordered_map<std::string, std::string> /* dump from this comm */>
371372
ncclDumpMap;
373+
#ifdef NCCL_COMM_DUMP_ALL
374+
auto res = ncclCommDumpAll(ncclDumpMap);
375+
if (res == ncclSuccess) {
376+
return ncclDumpMap;
377+
}
378+
// Fall back to dump from each comm if ncclCommDumpAll failed
379+
#endif // NCCL_COMM_DUMP_ALL
380+
372381
// dump_nccl_trace is only called from the default PG (local_id_=0), but we
373382
// want to dump from all comms so we need to iterate over ncclCommMemPoolMap,
374383
// which is static
@@ -382,8 +391,11 @@ static std::
382391
}
383392
}
384393
for (auto& ncclComm : allNCCLComms) {
385-
std::string ncclUniqueIDStr = buildNcclUniqueIdStr(ncclComm->getNcclId());
386-
ncclDumpMap[ncclUniqueIDStr] = ncclComm->ncclCommDump();
394+
std::stringstream ss;
395+
ss << std::hex << ncclComm->getNcclUniqueHash();
396+
std::string ncclUniqueHashStr = ss.str();
397+
398+
ncclDumpMap[ncclUniqueHashStr] = ncclComm->ncclCommDump();
387399
}
388400
return ncclDumpMap;
389401
#else

0 commit comments

Comments
 (0)
0