Update on "[async-tp] fix a race condition that can cause silent correctness issue" · pytorch/pytorch@4f2957e · GitHub

Commit 4f2957e

Authored and committed by Yifu Wang
Update on "[async-tp] fix a race condition that can cause silent correctness issue"
Details described in #137171: ![image](https://github.com/user-attachments/assets/8247b4f1-7805-4585-9d72-05e9475f218b)

Fix: we introduce the following invariants in `_pipelined_all_gather_and_consume` and `_pipelined_produce_and_all2all`:

- Before any stream writes to/reads from p2p buffers, perform a barrier on channel 0 on the launch stream.
- After all streams have completed writing to/reading from p2p buffers, perform a barrier on channel 0 on the launch stream.

NOTE: This fix only focuses on addressing the race condition. Some barriers are exposed (they can be hidden by computation), and we'll optimize them in subsequent PRs.

cc XilunWu H-Huang awgu kwen2501 wanchaol fegin fduwjj wz337 wconstab d4l3k c-p-i-o

[ghstack-poisoned]
2 parents 9537bc7 + 8472482 commit 4f2957e
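To make the two invariants above concrete, here is a minimal Python sketch of the barrier pattern. It is not the actual `_pipelined_all_gather_and_consume` implementation: the `symm_mem` handle with a collective `barrier(channel=...)` method, the `streams` list, and the `consume` callback are assumptions for illustration only.

```python
import torch

def pipelined_consume_sketch(symm_mem, p2p_bufs, chunks, consume, streams):
    # Hypothetical sketch: `symm_mem` is assumed to expose a collective
    # barrier(channel=...) across ranks; `streams` are the side streams that
    # copy into / read from the peer-to-peer (p2p) buffers.
    launch_stream = torch.cuda.current_stream()

    # Invariant 1: before ANY stream writes to or reads from the p2p buffers,
    # barrier on channel 0 from the launch stream, so no rank starts
    # overwriting a buffer that a peer is still using from a previous call.
    symm_mem.barrier(channel=0)

    for i, stream in enumerate(streams):
        stream.wait_stream(launch_stream)  # side streams start after the barrier
        with torch.cuda.stream(stream):
            p2p_bufs[i].copy_(chunks[i])   # write into the shared p2p buffer
            consume(p2p_bufs[i])           # read it back / feed the consumer

    # Join all side streams back onto the launch stream.
    for stream in streams:
        launch_stream.wait_stream(stream)

    # Invariant 2: after ALL streams are done with the p2p buffers, barrier on
    # channel 0 again so no rank races ahead into the next use of the buffers.
    symm_mem.barrier(channel=0)
```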


42 files changed: +910 −690 lines

aten/src/ATen/native/cuda/ForeachBinaryOpList.cu

Lines changed: 39 additions & 63 deletions
@@ -285,64 +285,44 @@ struct Copy<dst_t, c10::complex<float>> {
 }
 };
 
-#define AT_DISPATCH_SOURCE_TYPES(TYPE, NAME, ...) \
-  AT_DISPATCH_SWITCH( \
-      TYPE, \
-      NAME, \
-      AT_PRIVATE_CASE_TYPE_USING_HINT( \
-          at::ScalarType::Byte, \
-          src_t, \
-          __VA_ARGS__) AT_PRIVATE_CASE_TYPE_USING_HINT(at::ScalarType::Char, src_t, __VA_ARGS__) \
-      AT_PRIVATE_CASE_TYPE_USING_HINT( \
-          at::ScalarType::Long, src_t, __VA_ARGS__) \
-      AT_PRIVATE_CASE_TYPE_USING_HINT( \
-          at::ScalarType::Short, src_t, __VA_ARGS__) \
-      AT_PRIVATE_CASE_TYPE_USING_HINT( \
-          at::ScalarType::Int, src_t, __VA_ARGS__) \
-      AT_PRIVATE_CASE_TYPE_USING_HINT( \
-          at::ScalarType::Double, src_t, __VA_ARGS__) \
-      AT_PRIVATE_CASE_TYPE_USING_HINT( \
-          at::ScalarType::Float, src_t, __VA_ARGS__) \
-      AT_PRIVATE_CASE_TYPE_USING_HINT( \
-          at::ScalarType::ComplexDouble, \
-          src_t, \
-          __VA_ARGS__) \
-      AT_PRIVATE_CASE_TYPE_USING_HINT( \
-          at::ScalarType::ComplexFloat, \
-          src_t, \
-          __VA_ARGS__) \
-      AT_PRIVATE_CASE_TYPE_USING_HINT( \
-          at::ScalarType::Half, \
-          src_t, \
-          __VA_ARGS__) \
-      AT_PRIVATE_CASE_TYPE_USING_HINT( \
-          at::ScalarType::BFloat16, \
-          src_t, \
-          __VA_ARGS__) \
-      AT_PRIVATE_CASE_TYPE_USING_HINT( \
-          at::ScalarType::Bool, \
-          src_t, \
-          __VA_ARGS__) \
-      AT_PRIVATE_CASE_TYPE_USING_HINT( \
-          at::ScalarType:: \
-              Float8_e4m3fn, \
-          src_t, \
-          __VA_ARGS__) \
-      AT_PRIVATE_CASE_TYPE_USING_HINT( \
-          at::ScalarType:: \
-              Float8_e4m3fnuz, \
-          src_t, \
-          __VA_ARGS__) \
-      AT_PRIVATE_CASE_TYPE_USING_HINT( \
-          at::ScalarType:: \
-              Float8_e5m2, \
-          src_t, \
-          __VA_ARGS__) \
-      AT_PRIVATE_CASE_TYPE_USING_HINT( \
-          at::ScalarType:: \
-              Float8_e5m2fnuz, \
-          src_t, \
-          __VA_ARGS__))
+#define AT_DISPATCH_SOURCE_TYPES(TYPE, NAME, ...) \
+  AT_DISPATCH_SWITCH( \
+      TYPE, \
+      NAME, \
+      AT_PRIVATE_CASE_TYPE_USING_HINT( \
+          at::ScalarType::Byte, src_t, __VA_ARGS__) \
+      AT_PRIVATE_CASE_TYPE_USING_HINT( \
+          at::ScalarType::Char, src_t, __VA_ARGS__) \
+      AT_PRIVATE_CASE_TYPE_USING_HINT( \
+          at::ScalarType::Long, src_t, __VA_ARGS__) \
+      AT_PRIVATE_CASE_TYPE_USING_HINT( \
+          at::ScalarType::Short, src_t, __VA_ARGS__) \
+      AT_PRIVATE_CASE_TYPE_USING_HINT( \
+          at::ScalarType::Int, src_t, __VA_ARGS__) \
+      AT_PRIVATE_CASE_TYPE_USING_HINT( \
+          at::ScalarType::Double, src_t, __VA_ARGS__) \
+      AT_PRIVATE_CASE_TYPE_USING_HINT( \
+          at::ScalarType::Float, src_t, __VA_ARGS__) \
+      AT_PRIVATE_CASE_TYPE_USING_HINT( \
+          at::ScalarType::ComplexDouble, \
+          src_t, \
+          __VA_ARGS__) \
+      AT_PRIVATE_CASE_TYPE_USING_HINT( \
+          at::ScalarType::ComplexFloat, \
+          src_t, \
+          __VA_ARGS__) \
+      AT_PRIVATE_CASE_TYPE_USING_HINT( \
+          at::ScalarType::Half, \
+          src_t, \
+          __VA_ARGS__) \
+      AT_PRIVATE_CASE_TYPE_USING_HINT( \
+          at::ScalarType::BFloat16, \
+          src_t, \
+          __VA_ARGS__) \
+      AT_PRIVATE_CASE_TYPE_USING_HINT( \
+          at::ScalarType::Bool, \
+          src_t, \
+          __VA_ARGS__))
 
 namespace {
 
@@ -430,14 +410,10 @@ void foreach_tensor_copy_list_kernel_cuda_(
 
   std::vector<std::vector<at::Tensor>> tensor_lists{src.vec(), self.vec()};
 
-  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND7(
+  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
       ScalarType::Half,
       ScalarType::BFloat16,
       ScalarType::Bool,
-      ScalarType::Float8_e4m3fn,
-      ScalarType::Float8_e4m3fnuz,
-      ScalarType::Float8_e5m2,
-      ScalarType::Float8_e5m2fnuz,
      self[0].scalar_type(),
      "foreach_tensor_copy",
      [&]() {
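As background for the dispatch change above, here is a minimal sketch of the Python-level op these kernels back, `torch._foreach_copy_`. The shapes and dtypes are illustrative only; this snippet is not part of the commit.

```python
import torch

# torch._foreach_copy_(self, src) copies each src[i] into self[i] using fused
# CUDA kernels; the AT_DISPATCH_* switch above selects the instantiation for the
# destination dtype. After this change the Float8 variants are no longer part of
# that dispatch, while the standard types plus Half/BFloat16/Bool (the AND3 set)
# remain covered.
if torch.cuda.is_available():
    dtypes = (torch.float16, torch.bfloat16)
    dst = [torch.zeros(2, 2, device="cuda", dtype=d) for d in dtypes]
    src = [torch.rand(2, 2, device="cuda").to(d) for d in dtypes]
    torch._foreach_copy_(dst, src)
    assert all(torch.equal(a, b) for a, b in zip(dst, src))
```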

c10/cuda/CUDACachingAllocator.cpp

Lines changed: 21 additions & 52 deletions
@@ -3127,28 +3127,12 @@ class DeviceCachingAllocator {
 // Returns whether to force all allocations to bypass the caching allocator and
 // go straight to cudaMalloc. This setting is useful when debugging GPU memory
 // errors, since the caching allocator foils cuda-memcheck.
-static bool forceUncachedAllocator() {
-  // Allow either CUDA or HIP name for env var for maximum user comfort
-  // the CUDA env var avoids being hipified in cuda_to_hip_mappings.py
-  static const char* cuda_env = getenv("PYTORCH_NO_CUDA_MEMORY_CACHING");
-  static const char* rocm_env = getenv("PYTORCH_NO_HIP_MEMORY_CACHING");
-  static bool force_uncached = (cuda_env != nullptr) || (rocm_env != nullptr);
+bool forceUncachedAllocator() {
+  static bool force_uncached =
+      getenv("PYTORCH_NO_CUDA_MEMORY_CACHING") != nullptr;
   return force_uncached;
 }
 
-static void* uncached_allocate(size_t size) {
-  void* devPtr = nullptr;
-  // Deliberately don't use cudaMallocMaybeCapturing here, to force an error
-  // if someone tries to use forceUncachedAllocator while capturing.
-  C10_CUDA_CHECK(cudaMalloc(&devPtr, size));
-  const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
-  if (C10_UNLIKELY(interp)) {
-    (*interp)->trace_gpu_memory_allocation(
-        c10::kCUDA, reinterpret_cast<uintptr_t>(devPtr));
-  }
-  return devPtr;
-}
-
 static void uncached_delete(void* ptr) {
   if (TORCH_SDT_IS_ENABLED(free)) {
     TORCH_SDT_WITH_SEMAPHORE(free, ptr);
@@ -3166,9 +3150,6 @@ void local_raw_delete(void* ptr);
 
 class NativeCachingAllocator : public CUDAAllocator {
  private:
-  // allows this allocator to be turned on and off programmatically
-  bool enable_ = true;
-
  // Shard allocation region to have independent mutexes to reduce contention.
  static constexpr size_t kNumMutexShard = 67;
 
@@ -3343,14 +3324,6 @@ class NativeCachingAllocator : public CUDAAllocator {
    da->emptyCache();
  }
 
-  void enable(bool value) override {
-    enable_ = value;
-  }
-
-  bool isEnabled() const override {
-    return enable_;
-  }
-
  void* getBaseAllocation(void* ptr, size_t* outSize) override {
    Block* block = get_allocated_block(ptr);
    if (!block) {
@@ -3485,9 +3458,17 @@ class NativeCachingAllocator : public CUDAAllocator {
    void (*deleteFunc)(void*) = &local_raw_delete;
    CUDAStream stream = cuda::getCurrentCUDAStream(device);
 
-    if (forceUncachedAllocator() || !isEnabled()) {
+    if (forceUncachedAllocator()) {
      deleteFunc = &uncached_delete;
-      devPtr = uncached_allocate(size);
+
+      // Deliberately don't use cudaMallocMaybeCapturing here, to force an error
+      // if someone tries to use forceUncachedAllocator while capturing.
+      C10_CUDA_CHECK(cudaMalloc(&devPtr, size));
+      const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
+      if (C10_UNLIKELY(interp)) {
+        (*interp)->trace_gpu_memory_allocation(
+            c10::kCUDA, reinterpret_cast<uintptr_t>(devPtr));
+      }
    } else {
      if (size != 0) {
        this->malloc(&devPtr, device, size, stream);
@@ -3501,7 +3482,7 @@ class NativeCachingAllocator : public CUDAAllocator {
    return {devPtr, devPtr, deleteFunc, Device(DeviceType::CUDA, device)};
  }
  DeleterFnPtr raw_deleter() const override {
-    if (forceUncachedAllocator() || !isEnabled()) {
+    if (forceUncachedAllocator()) {
      return &uncached_delete;
    } else {
      return &local_raw_delete;
@@ -3558,29 +3539,21 @@ class NativeCachingAllocator : public CUDAAllocator {
    if (nbytes == 0) {
      return nullptr;
    }
+    c10::DeviceIndex device = 0;
+    C10_CUDA_CHECK(c10::cuda::GetDevice(&device));
    void* r = nullptr;
-    if (forceUncachedAllocator() || !isEnabled()) {
-      r = uncached_allocate(nbytes);
-    } else {
-      c10::DeviceIndex device = 0;
-      C10_CUDA_CHECK(c10::cuda::GetDevice(&device));
-      malloc(&r, device, nbytes, cuda::getCurrentCUDAStream(device));
-    }
+    malloc(&r, device, nbytes, cuda::getCurrentCUDAStream(device));
    return r;
  }
 
  void* raw_alloc_with_stream(size_t nbytes, cudaStream_t stream) override {
    if (nbytes == 0) {
      return nullptr;
    }
+    c10::DeviceIndex device = 0;
+    C10_CUDA_CHECK(c10::cuda::GetDevice(&device));
    void* r = nullptr;
-    if (forceUncachedAllocator() || !isEnabled()) {
-      r = uncached_allocate(nbytes);
-    } else {
-      c10::DeviceIndex device = 0;
-      C10_CUDA_CHECK(c10::cuda::GetDevice(&device));
-      malloc(&r, device, nbytes, stream);
-    }
+    malloc(&r, device, nbytes, stream);
    return r;
  }
 
@@ -3625,11 +3598,7 @@ class NativeCachingAllocator : public CUDAAllocator {
  }
 
  void raw_delete(void* ptr) override {
-    if (forceUncachedAllocator() || !isEnabled()) {
-      uncached_delete(ptr);
-    } else {
-      this->free(ptr);
-    }
+    this->free(ptr);
  }
 
  // In CUDA IPC, sender sends a tensor to receiver via shareIPCHandle,
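For reference, a minimal sketch of how the uncached path guarded by `forceUncachedAllocator()` is typically exercised from Python. This assumes a CUDA build; the `memory_reserved()` expectation is illustrative and not something this diff asserts.

```python
import os

# Must be set before the allocator is first used: forceUncachedAllocator() caches
# the getenv() result in a function-local static on its first call.
os.environ["PYTORCH_NO_CUDA_MEMORY_CACHING"] = "1"

import torch

x = torch.empty(1024, device="cuda")   # served directly by cudaMalloc (uncached path)
del x                                  # released right away via uncached_delete -> cudaFree
print(torch.cuda.memory_reserved())    # expected to remain 0, since nothing is cached
```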

c10/cuda/CUDACachingAllocator.h

Lines changed: 0 additions & 10 deletions
@@ -206,8 +206,6 @@ class CUDAAllocator : public Allocator {
  virtual bool initialized() = 0;
  virtual void setMemoryFraction(double fraction, c10::DeviceIndex device) = 0;
  virtual void emptyCache() = 0;
-  virtual void enable(bool value) = 0;
-  virtual bool isEnabled() const = 0;
  virtual void cacheInfo(c10::DeviceIndex device, size_t* largestBlock) = 0;
  virtual void* getBaseAllocation(void* ptr, size_t* size) = 0;
  virtual void recordStream(const DataPtr&, CUDAStream stream) = 0;
@@ -329,14 +327,6 @@ inline void emptyCache() {
  return get()->emptyCache();
 }
 
-inline void enable(bool value) {
-  return get()->enable(value);
-}
-
-inline bool isEnabled() {
-  return get()->isEnabled();
-}
-
 inline void cacheInfo(c10::DeviceIndex device, size_t* largestBlock) {
  return get()->cacheInfo(device, largestBlock);
 }

c10/cuda/CUDAMallocAsyncAllocator.cpp

Lines changed: 0 additions & 8 deletions
@@ -496,14 +496,6 @@ struct CudaMallocAsyncAllocator : public CUDAAllocator {
    }
  }
 
-  void enable(bool) override {
-    // cannot disable
-  }
-
-  bool isEnabled() const override {
-    return true;
-  }
-
  void cacheInfo(c10::DeviceIndex device, size_t* maxWorkspaceGuess) override {
    // The only consumer of cacheInfo is getMaxWorkspaceSize in Conv_v7.cpp.
    // Afaict, the role of cacheInfo is to give getMaxWorkspaceSize a reasonable

docs/source/cuda.rst

Lines changed: 0 additions & 9 deletions
@@ -123,15 +123,6 @@ Memory management
     MemPool
     MemPoolContext
 
-.. currentmodule:: torch.cuda.memory
-
-.. autosummary::
-    :toctree: generated
-    :nosignatures:
-
-    caching_allocator_enable
-
-.. currentmodule:: torch.cuda
 .. autoclass:: torch.cuda.use_mem_pool
 
 .. FIXME The following doesn't seem to exist. Is it supposed to?

test/distributed/_composable/fully_shard/test_fully_shard_util.py

Lines changed: 1 addition & 37 deletions
@@ -2,8 +2,6 @@
 
 import sys
 
-import pytest
-
 import torch
 import torch.distributed as dist
 from torch.distributed._composable import fully_shard
@@ -14,14 +12,7 @@
 from torch.testing._internal.common_dist_composable import CompositeModel, UnitModule
 from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
 from torch.testing._internal.common_fsdp import FSDPTest
-from torch.testing._internal.common_utils import (
-    run_tests,
-    TEST_WITH_DEV_DBG_ASAN,
-    TestCase,
-)
-
-
-is_cuda_8_9 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)
+from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN
 
 
 if not dist.is_available():
@@ -121,32 +112,5 @@ def test_get_sharded_module_tree_with_module_name_to_fqns(self):
         )
 
 
-class TestUtilsSingleDevice(TestCase):
-    @pytest.mark.skipif(not is_cuda_8_9, reason="requires SM89 compatible machine")
-    def test_foreach_copy_float8(self):
-        for dtype in [
-            torch.float8_e4m3fn,
-            torch.float8_e4m3fnuz,
-            torch.float8_e5m2,
-            torch.float8_e5m2fnuz,
-        ]:
-            src = [torch.rand(2, 2, device="cuda").to(dtype)] * 2
-            dst = [torch.zeros(2, 2, device="cuda").to(dtype)] * 2
-            # needed by fully_shard(Float8Linear)
-            torch._foreach_copy_(src, dst)
-            for s, d in zip(src, dst):
-                self.assertEqual(s, d)
-                torch.equal(src[0], dst[0])
-
-            src = [torch.rand(2, 2, device="cpu").to(dtype)] * 2
-            dst = [torch.zeros(2, 2, device="cpu").to(dtype)] * 2
-            # needed by fully_shard(Float8Linear)
-            torch._foreach_copy_(src, dst)
-            for s, d in zip(src, dst):
-                # did not use torch.equal because
-                # "equal_cpu" not implemented
-                assert torch.all(s == d).item()
-
-
 if __name__ == "__main__":
     run_tests()

test/distributed/_tensor/test_dtensor_ops.py

Lines changed: 0 additions & 5 deletions
@@ -314,12 +314,8 @@ def wrapped(fn):
     xfail("nn.functional.huber_loss"),
     xfail("nn.functional.instance_norm"),
     xfail("nn.functional.interpolate", "area"),
-    xfail("nn.functional.interpolate", "bicubic"),
-    xfail("nn.functional.interpolate", "bilinear"),
-    xfail("nn.functional.interpolate", "linear"),
     xfail("nn.functional.interpolate", "nearest"),
     xfail("nn.functional.interpolate", "nearest-exact"),
-    xfail("nn.functional.interpolate", "trilinear"),
     xfail("nn.functional.leaky_relu"),
     xfail("nn.functional.linear"),
     xfail("nn.functional.local_response_norm"),
@@ -361,7 +357,6 @@ def wrapped(fn):
     xfail("nn.functional.triplet_margin_loss"),
     xfail("nn.functional.triplet_margin_with_distance_loss"),
     xfail("nn.functional.unfold"),
-    xfail("nn.functional.upsample_bilinear"),
     xfail("nn.functional.upsample_nearest"),
     xfail("nonzero"),
     xfail("normal"),

0 commit comments