more update · pytorch/pytorch@2df7332 · GitHub

Commit 2df7332

more update
1 parent cf174bf commit 2df7332

File tree

5 files changed: +17 -49 lines changed

torch/csrc/cuda/nccl.cpp

+1 -1

@@ -118,7 +118,7 @@ ncclDataType_t to_nccl_data_type(c10::ScalarType type) {
       return ncclDataType_t::ncclUint8;
 #endif

-#if HAS_NCCL_BF16_DATATYPE
+#ifdef HAS_NCCL_BF16_DATATYPE
     case at::kBFloat16:
       return ncclDataType_t::ncclBfloat16;
 #endif
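The switch from #if to #ifdef here tracks the header change below: HAS_NCCL_BF16_DATATYPE stops being defined to 0 or 1 and becomes a presence-only macro. A minimal standalone sketch of why the test form has to change (the macro names here are illustrative, not from the commit):

// Illustrative only (these macro names are not from the commit):
// a value-style macro is always defined, so it must be tested with #if ...
#define HAS_FEATURE_VALUE 0
#if HAS_FEATURE_VALUE
// compiled only when the value is non-zero
#endif
// ... while a presence-style macro, the form adopted here, carries no value.
// Testing it with #if would be a preprocessor error once it expands to
// nothing, so every check becomes #ifdef.
#define HAS_FEATURE_PRESENT
#ifdef HAS_FEATURE_PRESENT
// compiled whenever the macro is defined at all
#endif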

torch/csrc/cuda/nccl.h

+12 -25

@@ -9,52 +9,39 @@

 // NCCL BFloat16 is enabled only for CUDA 11+ and NCCL versions 2.10+, or for
 // HIP 3.1+
+#if defined(NCCL_MAJOR) && \
+    ((NCCL_MAJOR > 2) || (NCCL_MAJOR == 2) && (NCCL_MINOR >= 10))
 #if defined(__CUDA_BF16_TYPES_EXIST__)
-#define HAS_NCCL_BF16_DATATYPE \
-  ((NCCL_MAJOR > 2) || (NCCL_MAJOR == 2) && (NCCL_MINOR >= 10))
+#define HAS_NCCL_BF16_DATATYPE
+#endif // defined(__CUDA_BF16_TYPES_EXIST__)
+#define NCCL_HAS_AVG
 #elif defined(USE_ROCM) && (TORCH_HIP_VERSION >= 301)
-#define HAS_NCCL_BF16_DATATYPE 1
-#else
-#define HAS_NCCL_BF16_DATATYPE 0
-#endif
-
-// Error checking is enabled only for NCCL versions 2.4+ since ncclCommAbort()
-// and ncclCommGetAsyncError() are not supported in earlier versions.
-#if defined(NCCL_MAJOR) && defined(NCCL_MINOR) && \
-    (NCCL_MAJOR > 2 || (NCCL_MAJOR == 2 && NCCL_MINOR >= 4))
-#define ENABLE_NCCL_ERROR_CHECKING
-#endif
-
-// P2P is enabled only for NCCL versions 2.7+ since ncclSend()
-// and ncclRecv() are not supported in earlier versions.
-#if defined(NCCL_MAJOR) && defined(NCCL_MINOR) && \
-    (NCCL_MAJOR > 2 || (NCCL_MAJOR == 2 && NCCL_MINOR >= 7))
-#define ENABLE_NCCL_P2P_SUPPORT
-#endif
+#define HAS_NCCL_BF16_DATATYPE
+#endif // NCCL >= 2.10

-#if defined(NCCL_MAJOR) && defined(NCCL_MINOR) && \
+#if defined(NCCL_MAJOR) && \
     (NCCL_MAJOR > 2 || (NCCL_MAJOR == 2 && NCCL_MINOR >= 11))
 #define ENABLE_NCCL_PREMUL_SUM_SUPPORT
 #endif

-#if defined(NCCL_MAJOR) && defined(NCCL_MINOR) && \
+#if defined(NCCL_MAJOR) && \
     (NCCL_MAJOR > 2 || (NCCL_MAJOR == 2 && NCCL_MINOR >= 13))
 #define NCCL_HAS_REMOTE_ERROR
 #define ENABLE_NCCL_GET_LAST_ERROR
 #endif

-#if defined(NCCL_MAJOR) && defined(NCCL_MINOR) && \
+#if defined(NCCL_MAJOR) && \
     (NCCL_MAJOR > 2 || (NCCL_MAJOR == 2 && NCCL_MINOR >= 14))
 #define NCCL_HAS_COMM_NONBLOCKING
 #endif

-#if defined(NCCL_MAJOR) && defined(NCCL_MINOR) && \
+#if defined(NCCL_MAJOR) && \
     (NCCL_MAJOR > 2 || (NCCL_MAJOR == 2 && NCCL_MINOR >= 17))
 #define NCCL_HAS_COMM_CTA_CGA
 #define NCCL_HAS_COMM_SPLIT
 #endif

-#if defined(NCCL_MAJOR) && defined(NCCL_MINOR) && \
+#if defined(NCCL_MAJOR) && \
     (NCCL_MAJOR > 2 || (NCCL_MAJOR == 2 && NCCL_MINOR >= 19))
 #define NCCL_HAS_COMM_REGISTER
 #endif
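Net effect of this header change, as far as the diff shows: HAS_NCCL_BF16_DATATYPE and NCCL_HAS_AVG become presence-only macros gated on NCCL 2.10+, while the ENABLE_NCCL_ERROR_CHECKING and ENABLE_NCCL_P2P_SUPPORT gates are dropped entirely, presumably because every NCCL version PyTorch still builds against already satisfies the 2.4/2.7 minimums they guarded. A minimal sketch of how a consumer of this header would test the new presence-only macro (the helper function is illustrative, not from the commit):

#include <torch/csrc/cuda/nccl.h>

// Hypothetical helper: presence check only, no value comparison.
bool bf16NcclReductionsSupported() {
#ifdef HAS_NCCL_BF16_DATATYPE
  return true;
#else
  return false;
#endif
}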

torch/csrc/distributed/c10d/NCCLUtils.cpp

-10

@@ -245,7 +245,6 @@ void NCCLComm::destroy() {
 void NCCLComm::abort(std::optional<std::string> commFailureReason) {
   LockType lock(mutex_);
   at::cuda::OptionalCUDAGuard gpuGuard(deviceIndex_);
-#ifdef ENABLE_NCCL_ERROR_CHECKING
   if (aborted_ && !initialized_) {
     // Should not abort twice.
     return;
@@ -285,10 +284,6 @@ void NCCLComm::abort(std::optional<std::string> commFailureReason) {
   if (ncclAsyncErr_ == ncclSuccess) {
     ncclAsyncErr_ = ncclSystemError;
   }
-#else
-  // This is a NOOP, if error checks are disabled.
-  return;
-#endif
 }

 bool NCCLComm::isInitialized() const {
@@ -307,17 +302,12 @@ uint64_t NCCLComm::getCommSplitCounter() const {

 ncclResult_t NCCLComm::checkForNcclError() {
   LockType lock(mutex_);
-#ifdef ENABLE_NCCL_ERROR_CHECKING
   if (ncclAsyncErr_ != ncclSuccess) {
     return ncclAsyncErr_;
   }
   C10D_NCCL_CHECK(
       ncclCommGetAsyncError(ncclComm_, &ncclAsyncErr_), commFailureReason_);
   return ncclAsyncErr_;
-#else
-  // Always return success, if error checks are disabled.
-  return ncclSuccess;
-#endif
 }

 ncclResult_t NCCLComm::registerSegment(void* ptr, size_t size) {
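With the ENABLE_NCCL_ERROR_CHECKING guard gone, abort() and checkForNcclError() no longer have compiled-out fallbacks; the async-error path is always built. A minimal sketch of the underlying NCCL call the retained code relies on (the helper name and comm argument are illustrative, not from this file):

#include <nccl.h>

// Poll a communicator for an asynchronous error, as checkForNcclError() does.
// ncclCommGetAsyncError() has existed since NCCL 2.4, which is below the
// versions the updated header assumes, so no compile-time guard is needed.
ncclResult_t pollAsyncError(ncclComm_t comm) {
  ncclResult_t asyncErr = ncclSuccess;
  ncclResult_t rc = ncclCommGetAsyncError(comm, &asyncErr);
  return (rc != ncclSuccess) ? rc : asyncErr;
}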

torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp

+1 -10

@@ -38,11 +38,6 @@ constexpr const char* const kNCCLAbortedCommStoreKey = "NCCLABORTEDCOMM";

 namespace {

-#if defined(NCCL_MAJOR) && \
-    ((NCCL_MAJOR > 2) || (NCCL_MAJOR == 2) && (NCCL_MINOR >= 10))
-#define NCCL_HAS_AVG 1
-#endif // NCCL version >= 2.10
-
 // NCCL op mapping
 const std::map<ReduceOp::RedOpType, ncclRedOp_t> ncclOp = {
     {ReduceOp::MIN, ncclMin},
@@ -68,7 +63,7 @@ std::map<at::ScalarType, ncclDataType_t> ncclDataType = {
     {at::kFloat8_e4m3fn, ncclUint8},
     {at::kFloat8_e4m3fnuz, ncclUint8},
     {at::kFloat8_e5m2fnuz, ncclUint8},
-#if HAS_NCCL_BF16_DATATYPE
+#ifdef HAS_NCCL_BF16_DATATYPE
     {at::kBFloat16, ncclBfloat16},
 #endif // HAS_NCCL_BF16_DATATYPE
 };
@@ -928,10 +923,8 @@ ProcessGroupNCCL::ProcessGroupNCCL(
   PrefixStore* prefixStore = dynamic_cast<PrefixStore*>(store_.get());
   globalStore_ =
       prefixStore ? prefixStore->getUnderlyingNonPrefixStore() : store_;
-#ifdef ENABLE_NCCL_ERROR_CHECKING
   enableTiming_.store(
       getCvarBool(TORCH_NCCL_ENABLE_TIMING, false) || desyncDebug_);
-#endif // ENABLE_NCCL_ERROR_CHECKING
   avoidRecordStreams_ = getCvarBool(TORCH_NCCL_AVOID_RECORD_STREAMS, false);
 #ifdef NCCL_HAS_COMM_REGISTER
   useTensorRegisterAllocatorHook_ =
@@ -960,15 +953,13 @@ ProcessGroupNCCL::ProcessGroupNCCL(
     }
   }

-#ifdef ENABLE_NCCL_ERROR_CHECKING
   // in blockingWait mode, we don't need to enable the watchdog thread to check
   // the timeout or nccl error because the main thread would throw an exception
   // and it is the user's responsibility to handle the exception.
   if (!blockingWait_) {
     ncclCommWatchdogThread_ =
         std::thread(&ProcessGroupNCCL::ncclCommWatchdog, this);
   }
-#endif // ENABLE_NCCL_ERROR_CHECKING

   init();
   const std::string OFF = "OFF";
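The local NCCL_HAS_AVG definition in the anonymous namespace is removed in favor of the presence-only definition that now lives in torch/csrc/cuda/nccl.h, so code that tests it with #ifdef keeps working unchanged. An illustrative sketch of that gating pattern (the helper is hypothetical, not from this file):

#include <nccl.h>

// Hypothetical helper: pick the reduction op for ReduceOp::AVG depending on
// whether the relocated presence macro is defined.
ncclRedOp_t avgOpOrFallback() {
#ifdef NCCL_HAS_AVG
  return ncclAvg;  // native average, available in NCCL 2.10+
#else
  return ncclSum;  // caller divides by world size afterwards
#endif
}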

torch/csrc/distributed/c10d/quantization/quantization_gpu.cu

+3 -3

@@ -66,7 +66,7 @@ at::Tensor _float_to_bfloat16_cuda(const at::Tensor& input) {

   auto output = at::empty(
       {nrows, ncols},
-#if HAS_NCCL_BF16_DATATYPE
+#ifdef HAS_NCCL_BF16_DATATYPE
       input.options().dtype(at::kBFloat16));
 #else
       input.options().dtype(at::kHalf));
@@ -92,7 +92,7 @@ at::Tensor _float_to_bfloat16_cuda(const at::Tensor& input) {
       input.const_data_ptr<float>(),
       nrows,
       ncols,
-#if HAS_NCCL_BF16_DATATYPE
+#ifdef HAS_NCCL_BF16_DATATYPE
       reinterpret_cast<uint16_t*>(output.mutable_data_ptr<at::BFloat16>())
 #else
       reinterpret_cast<uint16_t*>(output.mutable_data_ptr<at::Half>())
@@ -137,7 +137,7 @@ at::Tensor _bfloat16_to_float_cuda(const at::Tensor& input) {
       blockDim,
       0,
       at::cuda::getCurrentCUDAStream()>>>(
-#if HAS_NCCL_BF16_DATATYPE
+#ifdef HAS_NCCL_BF16_DATATYPE
       reinterpret_cast<const uint16_t*>(input.const_data_ptr<at::BFloat16>()),
 #else
       reinterpret_cast<const uint16_t*>(input.const_data_ptr<at::Half>()),
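These kernels keep their existing half-precision fallback when bfloat16 is unavailable; only the form of the check changes. A hypothetical round-trip usage sketch, where the header path and namespace are assumptions based on this file's location, and the packed dtype (at::kBFloat16 vs at::kHalf) follows the #ifdef blocks above:

#include <ATen/ATen.h>
#include <torch/csrc/distributed/c10d/quantization/quantization_gpu.h>

void roundTripExample() {
  at::Tensor src = at::rand({4, 8}, at::device(at::kCUDA).dtype(at::kFloat));
  // Packed dtype is at::kBFloat16 when HAS_NCCL_BF16_DATATYPE was defined at
  // build time, at::kHalf otherwise.
  at::Tensor packed =
      torch::distributed::c10d::quantization::_float_to_bfloat16_cuda(src);
  at::Tensor restored =
      torch::distributed::c10d::quantization::_bfloat16_to_float_cuda(packed);
}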

0 commit comments
