pytorch/pytorch · commit a006a53
switches to flat_hash_set for recorded streams, restores original max_split_size() retrieval to avoid int64<->size_t comparison warnings
1 parent 1cc5d02 commit a006a53

3 files changed: +77 −69 lines

c10/cuda/CUDACachingAllocator.cpp (23 additions, 24 deletions)
@@ -310,11 +310,13 @@ cudaError_t cudaMallocMaybeCapturing(void** p, size_t size) {
 #endif
 }
 
-} // namespace
+} // anonymous namespace
+} // namespace Native
 
 // Environment config parser
 // Defined here, rather than its own .cpp file,
 // because parseArgs needs to know kLargeBuffer.
+// Defined outside namespace Native because it's not Native-specific.
 class CachingAllocatorConfig {
  public:
   static AllocatorBackend allocator_backend() {
@@ -379,11 +381,11 @@ class CachingAllocatorConfig {
       if (kv[0].compare("max_split_size_mb") == 0) {
         size_t val2 = stoi(kv[1]);
         TORCH_CHECK(
-            val2 > kLargeBuffer / (1024 * 1024),
+            val2 > Native::kLargeBuffer / (1024 * 1024),
             "CachingAllocator option max_split_size_mb too small, must be > ",
-            kLargeBuffer / (1024 * 1024),
+            Native::kLargeBuffer / (1024 * 1024),
             "");
-        val2 = std::max(val2, kLargeBuffer / (1024 * 1024));
+        val2 = std::max(val2, Native::kLargeBuffer / (1024 * 1024));
         val2 = std::min(
             val2, (std::numeric_limits<size_t>::max() / (1024 * 1024)));
         m_max_split_size = val2 * 1024 * 1024;
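Aside: the parser clamps the user-supplied megabyte count into [kLargeBuffer/1MiB, SIZE_MAX/1MiB] before converting to bytes. A minimal standalone sketch of that arithmetic, assuming kLargeBuffer is 20 MiB as in this file (names are local to the sketch, not the allocator's):

#include <algorithm>
#include <cstddef>
#include <cstdio>
#include <limits>

// Hypothetical stand-in for Native::kLargeBuffer (20 MiB).
constexpr size_t kLargeBuffer = 20971520;

size_t parse_max_split_size_mb(size_t mb) {
  // Raise values below the large-buffer threshold...
  mb = std::max(mb, kLargeBuffer / (1024 * 1024));
  // ...and cap so the later * 1024 * 1024 cannot overflow size_t.
  mb = std::min(mb, std::numeric_limits<size_t>::max() / (1024 * 1024));
  return mb * 1024 * 1024; // stored in bytes
}

int main() {
  std::printf("%zu\n", parse_max_split_size_mb(512)); // 536870912
}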
@@ -452,6 +454,8 @@ class CachingAllocatorConfig {
   }
 };
 
+namespace Native {
+
 class DeviceCachingAllocator {
  private:
   // lock around all operations
@@ -509,7 +513,7 @@ class DeviceCachingAllocator {
   DeviceCachingAllocator()
       : large_blocks(BlockComparator, /*is_small=*/false),
         small_blocks(BlockComparator, /*is_small=*/true) {
-    stats.max_split_size = CUDACachingAllocator::maxSplitSize();
+    stats.max_split_size = CachingAllocatorConfig::max_split_size();
  }
 
  // All public methods (except the above) acquire the allocator mutex.
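Aside: the int64<->size_t note in the commit message refers to mixing a signed cached stat (the stats structs here hold max_split_size as int64_t) with the unsigned size_t getter; reading CachingAllocatorConfig::max_split_size() at each comparison site keeps both operands unsigned. A hypothetical reduction of the warning, with stand-in types rather than the allocator's real ones:

#include <cstddef>
#include <cstdint>

struct Stats {
  int64_t max_split_size = 0; // signed, as in the DeviceStats-style structs
};

// size_t vs. int64_t: the signed operand converts to unsigned and
// -Wsign-compare fires on this comparison.
bool oversize_warns(const Stats& stats, size_t block_size) {
  return block_size >= stats.max_split_size;
}

// Stand-in for CachingAllocatorConfig::max_split_size().
size_t max_split_size() {
  return 1024 * 1024;
}

// Both operands are size_t, so the comparison is warning-free.
bool oversize_clean(size_t block_size) {
  return block_size >= max_split_size();
}

int main() {
  Stats s;
  (void)oversize_warns(s, 1024);
  (void)oversize_clean(1024);
}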
@@ -676,7 +680,7 @@ class DeviceCachingAllocator {
       update_stat(stats.active[stat_type], 1);
       update_stat(stats.active_bytes[stat_type], block->size);
     });
-    if (block->size >= stats.max_split_size)
+    if (block->size >= CachingAllocatorConfig::max_split_size())
       update_stat(stats.oversize_allocations, 1);
 
     c10::reportMemoryUsageToProfiler(
@@ -707,7 +711,7 @@ class DeviceCachingAllocator {
       update_stat(stats.allocation[stat_type], -1);
       update_stat(stats.allocated_bytes[stat_type], -block->size);
     });
-    if (block->size >= stats.max_split_size)
+    if (block->size >= CachingAllocatorConfig::max_split_size())
       update_stat(stats.oversize_allocations, -1);
 
     if (!block->stream_uses.empty()) {
@@ -1133,7 +1137,7 @@ class DeviceCachingAllocator {
     if (block->pool->is_small) {
       return remaining >= kMinBlockSize;
     } else {
-      return (size < stats.max_split_size) &&
+      return (size < CachingAllocatorConfig::max_split_size()) &&
           (remaining > kSmallSize);
     }
   }
@@ -1162,11 +1166,11 @@ class DeviceCachingAllocator {
     if (it == pool.blocks.end() || (*it)->stream != p.stream())
       return false;
     // Do not return an oversized block for a large request
-    if ((p.size() < stats.max_split_size) &&
-        ((*it)->size >= stats.max_split_size))
+    if ((p.size() < CachingAllocatorConfig::max_split_size()) &&
+        ((*it)->size >= CachingAllocatorConfig::max_split_size()))
       return false;
     // Allow oversized block size to be rounded up but within a limit
-    if ((p.size() >= stats.max_split_size) &&
+    if ((p.size() >= CachingAllocatorConfig::max_split_size()) &&
         ((*it)->size >= p.size() + kLargeBuffer))
       return false;
     p.block = *it;
@@ -1288,7 +1292,7 @@ class DeviceCachingAllocator {
       update_stat(stats.segment[stat_type], 1);
       update_stat(stats.reserved_bytes[stat_type], size);
     });
-    if (size >= stats.max_split_size)
+    if (size >= CachingAllocatorConfig::max_split_size())
       update_stat(stats.oversize_segments, 1);
 
     // p.block came from new, not cudaMalloc. It should not be nullptr here.
@@ -1300,13 +1304,13 @@ class DeviceCachingAllocator {
    * **/
   /** to satisfy the target size **/
   bool release_available_cached_blocks(const AllocParams& p) {
-    if (stats.max_split_size ==
+    if (CachingAllocatorConfig::max_split_size() ==
         std::numeric_limits<size_t>::max())
       return false;
     BlockPool& pool = *p.pool;
     Block key = p.search_key;
-    key.size = (key.size < stats.max_split_size)
-        ? stats.max_split_size
+    key.size = (key.size < CachingAllocatorConfig::max_split_size())
+        ? CachingAllocatorConfig::max_split_size()
         : key.size;
     auto it = pool.blocks.lower_bound(&key);
     if (it == pool.blocks.end() || (*it)->stream != p.stream()) {
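Aside: this routine looks up the smallest cached block at or above max_split_size in the size-ordered pool, then walks the oversize blocks. A standalone sketch of that lower_bound pattern over a sorted set of pointers (toy types, not the allocator's real Block/BlockPool):

#include <cstddef>
#include <cstdio>
#include <set>

struct Block { size_t size; };

struct BlockSizeLess {
  bool operator()(const Block* a, const Block* b) const {
    return a->size < b->size;
  }
};

int main() {
  std::set<Block*, BlockSizeLess> pool;
  Block a{512}, b{4096}, c{1 << 20};
  pool.insert(&a);
  pool.insert(&b);
  pool.insert(&c);

  const size_t max_split_size = 2048;
  Block key{max_split_size};
  // First cached block whose size is >= max_split_size, i.e. the smallest
  // "oversize" block eligible for release.
  auto it = pool.lower_bound(&key);
  if (it != pool.end())
    std::printf("smallest oversize block: %zu bytes\n", (*it)->size); // 4096
}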
@@ -1318,7 +1322,7 @@ class DeviceCachingAllocator {
     --it; // Back up one item. Now on the largest block for the correct
           // stream
     while ((totalReleased < key.size) &&
-           ((*it)->size >= stats.max_split_size) &&
+           ((*it)->size >= CachingAllocatorConfig::max_split_size()) &&
            ((*it)->stream == p.stream())) {
       auto cur = it;
       totalReleased += (*it)->size;
@@ -1383,7 +1387,7 @@ class DeviceCachingAllocator {
       update_stat(stats.segment[stat_type], -1);
       update_stat(stats.reserved_bytes[stat_type], -block->size);
     });
-    if (block->size >= stats.max_split_size)
+    if (block->size >= CachingAllocatorConfig::max_split_size())
       update_stat(stats.oversize_segments, -1);
 
     pool->blocks.erase(block);
@@ -1870,18 +1874,13 @@ std::shared_ptr<void> getIpcDevPtr(std::string handle) {
 // General caching allocator utilities
 
 // External config interface (declared in CUDACachingAllocator.h)
-// Should we bother having these two functions?
-// They are basically useless layers of indirection, but a minor
-// code-cleanliness benefit is they alleviate the need to define
+// This is a useless layer of indirection with a minor
+// code-cleanliness benefit: it alleviates the need to define
 // CachingAllocatorConfig itself in CUDACachingAllocator.h.
 AllocatorBackend allocatorBackend() {
   return CachingAllocatorConfig::allocator_backend();
 }
 
-size_t maxSplitSize() {
-  return CachingAllocatorConfig::max_split_size();
-}
-
 // Size pretty-printer
 inline std::string format_size(uint64_t size) {
   std::ostringstream os;

c10/cuda/CUDACachingAllocator.h (1 addition, 2 deletions)
@@ -125,10 +125,9 @@ enum struct AllocatorBackend : uint8_t {
 };
 
 C10_CUDA_API AllocatorBackend allocatorBackend();
-C10_CUDA_API size_t maxSplitSize();
 
 // Size pretty-printer
-inline std::string format_size(uint64_t size);
+std::string format_size(uint64_t size);
 
 #define CUDA_ALLOCATOR_BACKEND_INTERFACE \
   C10_CUDA_API void* raw_alloc(size_t nbytes); \
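Aside: dropping inline from the declaration likely matters because an inline function must be defined in every translation unit that odr-uses it; an inline declaration in a header whose definition lives in a single .cpp leaves every other TU without one (typically surfacing as an undefined reference at link time). A hypothetical reduction:

#include <cstdint>
#include <string>

// If a header said:
//
//   inline std::string format_size(uint64_t size); // defined only in one .cpp
//
// then any other TU calling format_size would odr-use an inline function it
// never sees defined. An ordinary declaration gives external linkage instead:
std::string format_size(uint64_t size); // defined once, in the .cpp

int main() {} // compiles standalone; calling format_size needs the defining TU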

c10/cuda/CUDAMallocAsyncAllocator.cpp (53 additions, 43 deletions)
@@ -45,8 +45,21 @@ struct UsageStream {
   }
 };
 
+bool operator==(const UsageStream& lhs, const UsageStream& rhs) {
+  return (lhs.stream == rhs.stream) && (lhs.device == rhs.device);
+}
+
+struct UsageStreamHash {
+  size_t operator()(const UsageStream& us) const noexcept {
+    return std::hash<void*>{}(us.stream) + size_t(us.device);
+  }
+};
+
 struct PtrUsage {
-  std::vector<UsageStream> usage_streams;
+  // recorded_streams holds side usage streams added by record_stream calls.
+  // In other words, it does NOT include the original creation stream.
+  ska::flat_hash_set<UsageStream, UsageStreamHash> recorded_streams;
+  UsageStream creation_stream;
   uint64_t size;
   bool captured;
   PtrUsage(uint64_t s, bool c) : size(s), captured(c) {}
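Aside: ska::flat_hash_set comes from the vendored flat_hash_map library and is API-compatible with std::unordered_set, so the same custom-hash pattern can be sketched with the standard container (toy types; the real stream member is a cudaStream_t). Unlike the old vector, the set also deduplicates repeated record_stream calls for free:

#include <cstddef>
#include <cstdio>
#include <functional>
#include <unordered_set>

struct UsageStream {
  void* stream; // stand-in for cudaStream_t
  int device;
};

bool operator==(const UsageStream& lhs, const UsageStream& rhs) {
  return (lhs.stream == rhs.stream) && (lhs.device == rhs.device);
}

struct UsageStreamHash {
  size_t operator()(const UsageStream& us) const noexcept {
    return std::hash<void*>{}(us.stream) + size_t(us.device);
  }
};

int main() {
  // std::unordered_set used as a drop-in stand-in for ska::flat_hash_set.
  std::unordered_set<UsageStream, UsageStreamHash> recorded_streams;
  UsageStream s{nullptr, 1};
  recorded_streams.insert(s);
  recorded_streams.insert(s); // duplicate record: deduplicated by the set
  std::printf("%zu\n", recorded_streams.size()); // 1
}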
@@ -128,16 +141,6 @@ std::vector<size_t> pytorch_memory_limits;
  * carefully about the CPU overhead of remembering and rejoining
  * all free streams during capture. Maybe it's not a big deal.
  */
-bool operator==(const UsageStream& lhs, const UsageStream& rhs) {
-  return (lhs.stream == rhs.stream) && (lhs.device == rhs.device);
-}
-
-struct UsageStreamHash {
-  size_t operator()(const UsageStream& us) const noexcept {
-    return std::hash<void*>{}(us.stream) + size_t(us.device);
-  }
-};
-
 std::unordered_set<UsageStream, UsageStreamHash> capture_free_streams;
 bool capture_underway = false;
 
@@ -180,26 +183,36 @@ inline void lazy_init_device(int device) {
   }
 }
 
+inline void sync_raw(cudaStream_t dependency, cudaStream_t dependent) {
+  // CUDACachingAllocator.cpp uses raw cuda events, as do we.
+  cudaEvent_t event;
+  C10_CUDA_CHECK(cudaEventCreateWithFlags(&event, cudaEventDisableTiming));
+  C10_CUDA_CHECK(cudaEventRecord(event, dependency));
+  C10_CUDA_CHECK(cudaStreamWaitEvent(dependent, event));
+  C10_CUDA_CHECK(cudaEventDestroy(event));
+}
+
 // Assumes the caller holds general_mutex
 inline void free_impl(PtrInfo::iterator& it) {
   // Possible micro-optimization: If we did a value-copy here, we could move
   // ptr_info.erase(it) up here and drop the lock immediately.
-  const auto& usage_streams = it->second.usage_streams;
+  const auto& recorded_streams = it->second.recorded_streams;
+  const auto& creation_stream = it->second.creation_stream;
 
   // If the usage stream is a null (default) stream,
   // cudaFreeAsync infers the device from the ambient context,
   // so we need to set the right ambient context.
-  CUDAGuard g(usage_streams[0].device);
+  CUDAGuard g(creation_stream.device);
 
-  if (usage_streams.size() == 1) {
+  if (recorded_streams.size() == 0) {
     // ptr was only used on one stream, which must have been
     // the original allocation stream.
     // Frees ptr in the original allocation stream.
-    C10_CUDA_CHECK(cudaFreeAsync(it->first, usage_streams[0].stream));
+    C10_CUDA_CHECK(cudaFreeAsync(it->first, creation_stream.stream));
 
     if (C10_UNLIKELY(capture_underway)) {
       // See Note [Avoid dangling free streams during CUDA graph capture]
-      capture_free_streams.insert(usage_streams[0]);
+      capture_free_streams.insert(creation_stream);
     }
   } else {
     // ptr was used on many streams. We don't know which was the most recent.
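Aside: sync_raw is the stock event-based stream-ordering pattern that the hunks below reuse. A self-contained sketch with plain CUDA runtime calls, dropping the C10_CUDA_CHECK wrappers (assumes a CUDA-capable device):

#include <cuda_runtime.h>

// Make `dependent` wait for all work currently enqueued on `dependency`,
// the same pattern the new sync_raw helper wraps.
void sync_raw(cudaStream_t dependency, cudaStream_t dependent) {
  cudaEvent_t event;
  cudaEventCreateWithFlags(&event, cudaEventDisableTiming);
  cudaEventRecord(event, dependency);
  cudaStreamWaitEvent(dependent, event);
  cudaEventDestroy(event);
}

int main() {
  cudaStream_t a, b;
  cudaStreamCreate(&a);
  cudaStreamCreate(&b);
  // ... enqueue work on a ...
  sync_raw(a, b); // anything later enqueued on b now runs after a's work
  cudaStreamDestroy(a);
  cudaStreamDestroy(b);
}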
@@ -212,23 +225,21 @@ inline void free_impl(PtrInfo::iterator& it) {
 
     // Retrieves the dummy "unifier" stream from the device
     // on which the pointer was originally allocated.
-    auto dummy_unifying_free_stream = dummy_unifying_free_streams[usage_streams[0].device];
-    TORCH_INTERNAL_ASSERT(dummy_unifying_free_stream.device == usage_streams[0].device);
+    auto dummy_unifying_free_stream = dummy_unifying_free_streams[creation_stream.device];
+    TORCH_INTERNAL_ASSERT(dummy_unifying_free_stream.device == creation_stream.device);
+
+    // we're already on creation_stream.device, no need to re-guard
+    sync_raw(creation_stream.stream, dummy_unifying_free_stream.stream);
 
     // The number of usage streams is typically small (low single digits)
-    for (const auto& usage_stream : usage_streams) {
+    for (const auto& recorded_stream : recorded_streams) {
       // Logic here accommodates the chance some of the usage streams were on other devices,
       // which is possible if some usage kernels accessed the memory via p2p.
 
       // cudaEventRecord requires that the input event and stream are on the same device.
-      CUDAGuard g_usage(usage_stream.device);
-
-      // CUDACachingAllocator.cpp uses raw cuda events, as do we.
-      cudaEvent_t event;
-      C10_CUDA_CHECK(cudaEventCreateWithFlags(&event, cudaEventDisableTiming));
-      C10_CUDA_CHECK(cudaEventRecord(event, usage_stream.stream));
-      C10_CUDA_CHECK(cudaStreamWaitEvent(dummy_unifying_free_stream.stream, event));
-      C10_CUDA_CHECK(cudaEventDestroy(event));
+      CUDAGuard g_usage(recorded_stream.device);
+
+      sync_raw(recorded_stream.stream, dummy_unifying_free_stream.stream);
     }
 
     // Frees ptr in the dummy "unifier" stream.
@@ -240,10 +251,10 @@ inline void free_impl(PtrInfo::iterator& it) {
     // In theory, we could remove the need for the driver to do this tracking by e.g. replacing
     // cudaStreamWaitEvent(dummy_unifying_free_stream.stream, event);
     // with
-    // cudaStreamWaitEvent(usage_streams[0].stream, event);
-    // then cudaFreeAsyncing straight back into usage_streams[0];
-    // but this forces a potentially false dependency of usage_streams[0]
-    // on all the other usage_streams.
+    // cudaStreamWaitEvent(creation_stream.stream, event);
+    // then cudaFreeAsyncing straight back into creation_stream.stream,
+    // but this forces a potentially false dependency of creation_stream.stream
+    // on all the recorded_streams.
 
     if (C10_UNLIKELY(capture_underway)) {
       // See Note [Avoid dangling free streams during CUDA graph capture]
@@ -252,7 +263,7 @@ inline void free_impl(PtrInfo::iterator& it) {
     }
   }
 
-  pytorch_used_bytes[usage_streams[0].device] -= it->second.size;
+  pytorch_used_bytes[creation_stream.device] -= it->second.size;
 
   ptr_info.erase(it);
 }
@@ -263,8 +274,6 @@ void free(void* ptr) {
   auto it = ptr_info.find(ptr);
   TORCH_INTERNAL_ASSERT(it != ptr_info.end(),
                         "ptr not found in ptr_info");
-  TORCH_INTERNAL_ASSERT(it->second.usage_streams.size() != 0,
-                        "ptr's stream uses vector is empty");
 
   if (C10_UNLIKELY(capture_underway)) {
     if (!it->second.captured) {
@@ -354,7 +363,7 @@ void malloc(void** devPtr, int device, size_t size, cudaStream_t stream) {
                         "address returned by cudaMallocAsync already exists "
                         "in ptr_info");
 
-  inserted.first->second.usage_streams.emplace_back(stream, device);
+  inserted.first->second.creation_stream = {stream, device};
 
   pytorch_used_bytes[device] += size;
 }
@@ -394,7 +403,7 @@ Allocator* get(void) {
 // just set up for later calls to init per-device pools based
 // on the current device each later call sees.
 void init(int dev_count) {
-  static bool called = [] {;
+  static bool called = [](int dev_count) {;
     // Are there external guarantees init will be called before
     // any of the allocator's other functions?
     // std::lock_guard<std::mutex> lk(general_mutex);
@@ -404,7 +413,7 @@ void init(int dev_count) {
     pytorch_used_bytes.resize(dev_count);
     pytorch_memory_limits.resize(dev_count);
     return true;
-  }();
+  }(dev_count);
 }
 
 static inline void assertValidDevice(int device) {
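Aside: the init change is the call-once idiom. A captureless lambda cannot name the enclosing function's dev_count, so the value is passed as an argument; because the result initializes a function-local static, the body still runs exactly once. A minimal sketch, simplified from this file (hypothetical reduction, not the real init):

#include <cstddef>
#include <cstdio>
#include <vector>

std::vector<std::size_t> pytorch_used_bytes;

void init(int dev_count) {
  // Runs once, on the first call; later calls skip the lambda entirely.
  // The parameter forwards dev_count into the captureless lambda.
  static bool called = [](int n) {
    pytorch_used_bytes.resize(n);
    return true;
  }(dev_count);
  (void)called;
}

int main() {
  init(4);
  init(8); // ignored: the static initializer already ran
  std::printf("%zu\n", pytorch_used_bytes.size()); // 4
}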
@@ -532,11 +541,14 @@ void recordStream(const DataPtr& ptr, cuda::CUDAStream stream) {
   auto it = ptr_info.find(ptr.get());
   TORCH_INTERNAL_ASSERT(it != ptr_info.end(),
                         "ptr not found in ptr_info");
-  TORCH_INTERNAL_ASSERT(it->second.usage_streams.size() != 0,
-                        "ptr's stream uses vector is empty");
 
-  it->second.usage_streams.emplace_back(stream.stream(),
-                                        stream.device_index());
+  UsageStream to_record{stream.stream(), stream.device_index()};
+  if (to_record == it->second.creation_stream) {
+    TORCH_WARN("Called record_stream on tensor whose original creation stream "
+               "matches the recorded stream. This is unnecessary and has no effect.");
+  } else {
+    it->second.recorded_streams.insert(to_record);
+  }
 }
 
 std::mutex* getFreeMutex() {
@@ -700,8 +712,6 @@ void notifyCaptureEnded(int device, CaptureId_t graph_id) {
     auto it = ptr_info.find(ptr);
     TORCH_INTERNAL_ASSERT(it != ptr_info.end(),
                           "ptr not found in ptr_info");
-    TORCH_INTERNAL_ASSERT(it->second.usage_streams.size() != 0,
-                          "ptr's stream uses vector is empty");
     free_impl(it);
   }
 
0 commit comments