Adds cudaMallocAsync as an alternative backend for the CUDA allocator #65365
@@ -15,10 +15,10 @@ namespace cuda {
 namespace CUDACachingAllocator {
 namespace CudaMallocAsync {

-// Allocator that uses cudaMallocAsync to implement the same interface
-// as CUDACachingAllocator.cpp.
-//
-// cudaMallocAsync works transparently with CUDA graphs.
+// CUDA device allocator that uses cudaMallocAsync to implement
+// the same interface as CUDACachingAllocator.cpp.
+
+// Designed to be safe for CUDA graph capture.

 // Implementation details, not declared in CUDACachingAllocator.h
 namespace {
@@ -31,8 +31,8 @@ int device_count = 0;
 std::vector<bool> devs_initialized_flags;
 std::vector<CUDAStream> dummy_unifying_free_streams;

-// Potential future micro-optimization:
-// Some accesses to usage_streams_each_ptr are read-only.
+// Possible micro-optimization:
+// Some accesses to ptr_info are read-only.
 // We could let those be concurrent with a shared_mutex and
 // have concurrent calls take a shared_lock.
 // Keeping it simple with an ordinary mutex for now.
@@ -47,14 +47,28 @@ std::mutex general_mutex;
  * sure no uncaptured tensor will ever have its destructor called
  * in a capturing region.
  * We avoid errors by
- * 1. tracking captured and uncaptured allocated pointers separately
+ * 1. remembering if allocated pointers were captured or uncaptured
  * 2. during capture, if we detect an attempt to free an uncaptured
  *    allocation on a capturing stream, don't free it immediately,
  *    just remember it and defer its cudaFreeAsync call to after
  *    the end of capture (specifically, to notifyCaptureEnded).
  */
-std::unordered_map<void*, std::vector<CUDAStream>> usage_streams_ungraphed_ptrs;
-std::unordered_map<void*, std::vector<CUDAStream>> usage_streams_graphed_ptrs;
+struct UsageStream {
+  cudaStream_t stream;
+  int device;
+  UsageStream(cudaStream_t s, int d) : stream(s), device(d) {}
+};
+
+struct PtrUsage {
+  std::vector<UsageStream> usage_streams;
+  uint64_t size;
+  bool captured;
+  PtrUsage(uint64_t s, bool c) : size(s), captured(c) {}
+};
+
+using PtrInfo = std::unordered_map<void*, PtrUsage>;
+PtrInfo ptr_info;
 std::vector<void*> ungraphed_ptrs_defer_free_until_no_capture;

 // Graph-specific helpers

Review thread on the PtrInfo map:

Reviewer: You sure you want an […]? Called every malloc, looks pretty hot to me.

Author: Changed to ska::flat_hash_map. Do you know for sure flat_hash_map is a drop-in replacement for std::unordered_map? Does it implement the same interface? The code builds with flat_hash_map, but that doesn't guarantee the semantics of member functions are the same...

Reviewer: It might have different guarantees on iterator invalidation, but otherwise should be the same (I couldn't find the docs right away).
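The review thread above asks whether ska::flat_hash_map can stand in for std::unordered_map here. A minimal sketch of the map operations the allocator actually relies on, which any replacement table must preserve: emplace with duplicate-key rejection, find, and erase-by-iterator. All names below are hypothetical stand-ins, with cudaStream_t mocked as void* so the sketch compiles without CUDA headers.

```cpp
#include <cstdint>
#include <unordered_map>
#include <vector>

// Mocked stand-ins for the PR's types (not the PR's actual code).
using MockStream = void*;

struct MockPtrUsage {
  std::vector<MockStream> usage_streams;
  uint64_t size;
  bool captured;
  MockPtrUsage(uint64_t s, bool c) : size(s), captured(c) {}
};

using MockPtrInfo = std::unordered_map<void*, MockPtrUsage>;

// Exercises exactly the member functions the allocator uses:
// emplace (duplicate keys rejected), find, and erase by iterator.
bool exercise_map_interface() {
  MockPtrInfo info;
  int dummy = 0;
  void* p = &dummy;

  auto inserted = info.emplace(p, MockPtrUsage(256, false));
  if (!inserted.second) return false;  // first insert must succeed
  inserted.first->second.usage_streams.push_back(nullptr);

  auto dup = info.emplace(p, MockPtrUsage(512, true));
  if (dup.second) return false;        // duplicate key must be rejected

  auto it = info.find(p);
  if (it == info.end()) return false;
  info.erase(it);                      // erase by iterator, as free_impl does
  return info.find(p) == info.end();
}
```

Any table claiming drop-in compatibility must match these semantics, plus (as the reviewer notes) its own iterator-invalidation rules, which is exactly why the code defers frees by raw pointer rather than by iterator.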
@@ -94,7 +108,7 @@ std::vector<void*> ungraphed_ptrs_defer_free_until_no_capture;
  * carefully about the CPU overhead of remembering and rejoining
  * all free streams during capture. Maybe it's not a big deal.
  */
-std::unordered_set<CUDAStream> capture_free_streams;
+std::unordered_set<UsageStream> capture_free_streams;
 bool capture_underway = false;

 // Implementation functions
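std::unordered_set&lt;UsageStream&gt; in the hunk above compiles only if UsageStream has an equality comparison and a hasher; the diff shown here doesn't include them. A hedged sketch of what such a hasher could look like, using mocked types (MockUsageStream and void*-for-cudaStream_t are assumptions, not the PR's code):

```cpp
#include <cstddef>
#include <functional>
#include <unordered_set>

// cudaStream_t mocked as void* so this compiles without CUDA headers.
using MockCudaStream = void*;

struct MockUsageStream {
  MockCudaStream stream;
  int device;
};

// unordered_set needs equality...
bool operator==(const MockUsageStream& a, const MockUsageStream& b) {
  return a.stream == b.stream && a.device == b.device;
}

// ...and a hash. Any reasonable mix of the two fields works for a
// set that holds a handful of streams.
struct MockUsageStreamHash {
  std::size_t operator()(const MockUsageStream& us) const noexcept {
    return std::hash<void*>{}(us.stream) ^
           (std::hash<int>{}(us.device) << 1);
  }
};

using MockFreeStreamSet =
    std::unordered_set<MockUsageStream, MockUsageStreamHash>;

// Set semantics: inserting the same stream twice keeps one copy.
std::size_t insert_twice_size() {
  MockFreeStreamSet s;
  MockUsageStream a{nullptr, 0};
  s.insert(a);
  s.insert(a);
  return s.size();
}
```

The set semantics matter here: a stream that frees many pointers during one capture is recorded once, so the capture-end join loop records one event per distinct free stream.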
@@ -128,31 +142,24 @@ inline void lazy_init_device(int device) {
   }
 }

-void free(void* ptr) {
-  std::lock_guard<std::mutex> lk(general_mutex);
-
-  auto it = usage_streams_each_ptr.find(ptr);
-  TORCH_INTERNAL_ASSERT(it != usage_streams_each_ptr.end(),
-                        "ptr not represented in usage_streams_each_ptr");
-  TORCH_INTERNAL_ASSERT(it->second.size() != 0,
-                        "ptr's stream uses vector is empty");
-
+// Assumes the caller holds general_mutex
+inline void free_impl(PtrInfo::iterator& it) {
   // Possible micro-optimization: If we did a value-copy here, we could move
-  // usage_streams_each_ptr.erase(it) up here and drop the lock immediately.
-  const auto& usage_streams = it->second;
+  // ptr_info.erase(it) up here and drop the lock immediately.
+  const auto& usage_streams = it->second.usage_streams;

   // If the usage stream is a null (default) stream,
   // cudaFreeAsync infers the device from the ambient context,
   // so we need to set the right ambient context.
-  CUDAGuard g(usage_streams[0].device_index());
+  CUDAGuard g(usage_streams[0].device);

   if (usage_streams.size() == 1) {
     // ptr was only used on one stream, which must have been
     // the original allocation stream.
     // Frees ptr in the original allocation stream.
-    C10_CUDA_CHECK(cudaFreeAsync(ptr, usage_streams[0]));
+    C10_CUDA_CHECK(cudaFreeAsync(ptr, usage_streams[0].stream));

-    if (C10_UNLIKELY(captures_underway)) {
+    if (C10_UNLIKELY(capture_underway)) {
       // See Note [Avoid dangling free streams during CUDA graph capture]
       capture_free_streams.insert(usage_streams[0]);
     }
@@ -167,20 +174,20 @@ void free(void* ptr) {

   // Retrieves the dummy "unifier" stream from the device
   // on which the pointer was originally allocated.
-  auto dummy_unifying_free_stream = dummy_unifying_free_streams[usage_streams[0].device_index()];
+  auto dummy_unifying_free_stream = dummy_unifying_free_streams[usage_streams[0].device];

   // The number of usage streams is typically small (low single digits)
   for (const auto& usage_stream : usage_streams) {
     // Logic here accommodates the chance some of the usage streams were on other devices,
     // which is possible if some usage kernels accessed the memory via p2p.

     // cudaEventRecord requires that the input event and stream are on the same device.
-    CUDAGuard g_usage(usage_stream.device_index());
+    CUDAGuard g_usage(usage_stream.device);

     // CUDACachingAllocator.cpp uses raw cuda events, as do we.
     cudaEvent_t event;
     C10_CUDA_CHECK(cudaEventCreateWithFlags(&event, cudaEventDisableTiming));
-    C10_CUDA_CHECK(cudaEventRecord(event, usage_stream.stream()));
+    C10_CUDA_CHECK(cudaEventRecord(event, usage_stream.stream));
     C10_CUDA_CHECK(cudaStreamWaitEvent(dummy_unifying_free_stream.stream(), event));
     C10_CUDA_CHECK(cudaEventDestroy(event));
   }
@@ -199,13 +206,46 @@ void free(void* ptr) {
     // but this forces a potentially false dependency of usage_streams[0]
     // on all the other usage_streams.

-    if (C10_UNLIKELY(captures_underway)) {
+    if (C10_UNLIKELY(capture_underway)) {
       // See Note [Avoid dangling free streams during CUDA graph capture]
-      capture_free_streams.insert(dummy_unifying_free_stream);
+      capture_free_streams.insert({dummy_unifying_free_stream.stream(),
+                                   dummy_unifying_free_stream.device_index()});
     }
   }

-  usage_streams_each_ptr.erase(it);
+  ptr_info.erase(it);
+}
+
+void free(void* ptr) {
+  std::lock_guard<std::mutex> lk(general_mutex);
+
+  auto it = ptr_info.find(ptr);
+  TORCH_INTERNAL_ASSERT(it != ptr_info.end(),
+                        "ptr not found in ptr_info");
+  TORCH_INTERNAL_ASSERT(it->second.usage_streams.size() != 0,
+                        "ptr's stream uses vector is empty");
+
+  if (C10_UNLIKELY(capture_underway)) {
+    if (!it->second.captured) {
+      // See Note [Avoid freeing uncaptured ptrs during CUDA graph capture]
+      // Remembers the raw pointer, not the iterator.
+      // This forces notifyCaptureEnded to do another lookup,
+      // but avoids the risk the iterator might be invalidated
+      // between now and then.
+      ungraphed_ptrs_defer_free_until_no_capture.push_back(ptr);
+      return;
+    }
+  } else if (C10_UNLIKELY(it->second.captured)) {
+    TORCH_WARN("Attempting uncaptured free of a captured allocation. "
+               "This is technically allowed, but may indicate you are losing "
+               "the last user-visible tensor through which the allocation can "
+               "be accessed, so you'll have no way to view the data after "
+               "future replays of the owning graph.");
+  }
+
+  free_impl(it);
 }
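The free() logic above (defer frees of uncaptured pointers while a capture is underway, then flush them once capture ends, per the Note) can be modeled in plain C++ without any CUDA calls. A minimal sketch of the control flow only; DeferModel and its members are hypothetical, not the PR's types:

```cpp
#include <cstddef>
#include <unordered_set>
#include <vector>

// Minimal model of the defer-until-capture-ends pattern.
struct DeferModel {
  bool capture_underway = false;
  std::unordered_set<void*> live;    // stands in for ptr_info
  std::vector<void*> deferred;       // stands in for the defer list

  void alloc(void* p) { live.insert(p); }

  void free(void* p) {
    if (capture_underway) {
      // During capture, just remember the raw pointer; the real
      // cudaFreeAsync happens after capture ends.
      deferred.push_back(p);
      return;
    }
    live.erase(p);
  }

  void capture_end() {
    capture_underway = false;
    for (void* p : deferred) live.erase(p);  // flush deferred frees
    deferred.clear();
  }
};

// Scenario: free during capture is deferred, then executes at capture end.
std::size_t run_defer_scenario() {
  static int a, b;
  DeferModel m;
  m.alloc(&a);
  m.alloc(&b);
  m.capture_underway = true;
  m.free(&a);                           // deferred; &a stays live for now
  std::size_t during = m.live.size();   // both pointers still live
  m.capture_end();                      // deferred free executes here
  return during * 10 + m.live.size();
}
```

Note the model defers by raw pointer, mirroring the comment in the diff: holding an iterator across other map mutations would risk invalidation, so notifyCaptureEnded pays for a second lookup instead.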
// Symmetric with THCCachingAllocator::malloc for now,

@@ -228,11 +268,12 @@ void malloc(void** devPtr, int device, size_t size, cudaStream_t stream) {

   lazy_init_device(device);

-  auto inserted = usage_streams_each_ptr.insert(std::make_pair(*devPtr, std::vector<CUDAStream>{}));
+  auto inserted = ptr_info.emplace(*devPtr, PtrUsage(size, capture_underway));
   TORCH_INTERNAL_ASSERT(inserted.second,
                         "address returned by cudaMallocAsync already exists "
-                        "in usage_streams_each_ptr");
-  inserted.first->second.push_back(stream);
+                        "in ptr_info");
+
+  inserted.first->second.usage_streams.emplace_back(stream, device);
 }

 } // anonymous namespace
@@ -251,7 +292,7 @@ struct CudaCachingAllocator : public Allocator {
     C10_CUDA_CHECK(cudaGetDevice(&device));
     void* r = nullptr;
     if (size != 0) {
-      malloc(&r, device, size, getCurrentCUDAStream(device));
+      malloc(&r, device, size, cuda::getCurrentCUDAStream(device));
     }
     return {r, r, &raw_delete, Device(DeviceType::CUDA, device)};
   }
@@ -319,13 +360,13 @@ void recordStream(const DataPtr& ptr, cuda::CUDAStream stream) {
   std::lock_guard<std::mutex> lk(general_mutex);

   // The pointer should exist in the map already.
-  auto it = usage_streams_each_ptr.find(ptr.get());
-  TORCH_INTERNAL_ASSERT(it != usage_streams_each_ptr.end(),
-                        "ptr not represented in usage_streams_each_ptr");
-  TORCH_INTERNAL_ASSERT(it->second.size() != 0,
+  auto it = ptr_info.find(ptr.get());
+  TORCH_INTERNAL_ASSERT(it != ptr_info.end(),
+                        "ptr not found in ptr_info");
+  TORCH_INTERNAL_ASSERT(it->second.usage_streams.size() != 0,
                         "ptr's stream uses vector is empty");

-  it->second.push_back(stream);
+  it->second.usage_streams.emplace_back(stream.stream(), stream.device_index());
 }

 std::mutex* getFreeMutex() {
@@ -404,22 +445,34 @@ std::vector<SegmentInfo> snapshot() {
 }

 // CUDAGraph interactions
-void notifyCaptureBegin(CaptureId_t graph_id, MempoolId_t mempool_id) {} // no-op
+void notifyCaptureBegin(CaptureId_t graph_id, MempoolId_t mempool_id) {
+  std::lock_guard<std::mutex> lk(general_mutex);
+
+  TORCH_CHECK(!capture_underway,
+              "Only one capture at a time is allowed in a process.");
+  capture_underway = true;
+}

 void notifyCaptureAboutToEnd(int device, CaptureId_t graph_id) {
   assertValidDevice(device);

+  std::lock_guard<std::mutex> lk(general_mutex);
+
+  TORCH_CHECK(capture_underway,
+              "CudaMallocAsync::notifyCaptureAboutToEnd called, "
+              "but CudaMallocAsync::capture_underway is false");
+
+  auto capture_stream = cuda::getCurrentCUDAStream(device);
+
   // See Note [Avoid dangling free streams during CUDA graph capture]
   for (const auto& free_stream : capture_free_streams) {
     // cudaEventRecord requires that the input event and stream are on the same device.
-    CUDAGuard g(free_stream.device_index());
+    CUDAGuard g(free_stream.device);

     // CUDACachingAllocator.cpp uses raw cuda events, as do we.
     cudaEvent_t event;
     C10_CUDA_CHECK(cudaEventCreateWithFlags(&event, cudaEventDisableTiming));
-    C10_CUDA_CHECK(cudaEventRecord(event, free_stream.stream()));
+    C10_CUDA_CHECK(cudaEventRecord(event, free_stream.stream));
     C10_CUDA_CHECK(cudaStreamWaitEvent(capture_stream.stream(), event));
     C10_CUDA_CHECK(cudaEventDestroy(event));
   }

Reviewer: You don't have to fix this in this PR (because I think this affects the original caching allocator as well), but now I'm kind of wondering what happens if an error does occur in this cleanup. It seems to me like the capture data structures will be left in a wonky intermediate state. Maybe we shouldn't raise an exception if these calls fail :/
@@ -428,6 +481,21 @@ void notifyCaptureAboutToEnd(int device, CaptureId_t graph_id) {
 void notifyCaptureEnded(int device, CaptureId_t graph_id) {
   assertValidDevice(device);

+  std::lock_guard<std::mutex> lk(general_mutex);
+
+  TORCH_CHECK(capture_underway,
+              "CudaMallocAsync::notifyCaptureEnded called, "
+              "but CudaMallocAsync::capture_underway is false");
+  capture_underway = false;
+
+  // See Note [Avoid freeing uncaptured ptrs during CUDA graph capture]
+  for (const auto ptr : ungraphed_ptrs_defer_free_until_no_capture) {
+    auto it = ptr_info.find(ptr);
+    TORCH_INTERNAL_ASSERT(it != ptr_info.end(),
+                          "ptr not found in ptr_info");
+    TORCH_INTERNAL_ASSERT(it->second.usage_streams.size() != 0,
+                          "ptr's stream uses vector is empty");
+    free_impl(it);
+  }
+}

 void notifyCaptureDestroy(int device, MempoolId_t mempool_id) {} // no-op
Review thread on deferred frees during capture:

Reviewer: This seems... suboptimal. Why isn't there an option to do an untracked free?

Author: It's a restriction of the CUDA API, because ungraphed allocations are meant to be one-off, while graphs are meant to be replayed many times. I think their philosophy is: if CUDA silently allowed ungraphed allocations to be freed during capture, an ungraphed allocation address that happened to get freed during capture would be actually-freed on the first replay, then later replays would attempt to actually-free the same address, at which point it would either no longer be live or be associated with some totally different allocation (and in either scenario an actual free is a bad thing to do).

Reviewer: So the CUDA API is inconsistent in that on replays it will ignore allocations recorded during graph capture (with relaxed mode or whatever), but won't ignore frees?

Author: The CUDA folks reconcile this "inconsistency" in their heads by regarding […] as "auto free on launch", i.e., when a replay sees some captured VA that is unfreed since the last launch, it (conceptually or literally, I'm not sure) frees and immediately reallocs that VA. The graph should not ignore frees of graphed memory, because graphed VA is intended to get reallocated as many times as it's freed. For example, we want: […]