@@ -45,8 +45,21 @@ struct UsageStream {
   }
 };
 
+bool operator==(const UsageStream& lhs, const UsageStream& rhs) {
+  return (lhs.stream == rhs.stream) && (lhs.device == rhs.device);
+}
+
+struct UsageStreamHash {
+  size_t operator()(const UsageStream& us) const noexcept {
+    return std::hash<void*>{}(us.stream) + size_t(us.device);
+  }
+};
+
 struct PtrUsage {
-  std::vector<UsageStream> usage_streams;
+  // recorded_streams holds side usage streams added by record_stream calls.
+  // In other words, it does NOT include the original creation stream.
+  ska::flat_hash_set<UsageStream, UsageStreamHash> recorded_streams;
+  UsageStream creation_stream;
   uint64_t size;
   bool captured;
   PtrUsage(uint64_t s, bool c) : size(s), captured(c) {}
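Taken together, the equality operator and UsageStreamHash let UsageStream serve as a hash-set key, which is what capture_free_streams and the new recorded_streams member rely on. Below is a minimal standalone sketch of that behavior; the UsageStream field types are inferred from how the diff uses them, and std::unordered_set stands in for ska::flat_hash_set so the snippet has no third-party dependency. Note that repeated inserts of the same stream collapse into one entry, which is why a set fits repeated record_stream calls better than the old usage_streams vector.

// Hedged sketch: UsageStream as a hash-set key (std::unordered_set stands in
// for ska::flat_hash_set; field types are assumed, not shown in this diff).
#include <cuda_runtime.h>
#include <cstddef>
#include <functional>
#include <unordered_set>

struct UsageStream {
  cudaStream_t stream; // assumed: an opaque pointer type
  int device;          // assumed: plain device index
};

bool operator==(const UsageStream& lhs, const UsageStream& rhs) {
  return (lhs.stream == rhs.stream) && (lhs.device == rhs.device);
}

struct UsageStreamHash {
  size_t operator()(const UsageStream& us) const noexcept {
    return std::hash<void*>{}(us.stream) + size_t(us.device);
  }
};

int main() {
  std::unordered_set<UsageStream, UsageStreamHash> recorded_streams;
  UsageStream side{nullptr, 0};  // e.g. the default stream on device 0
  recorded_streams.insert(side);
  recorded_streams.insert(side); // duplicate record collapses into one entry
  return recorded_streams.size() == 1 ? 0 : 1;
}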
@@ -128,16 +141,6 @@ std::vector<size_t> pytorch_memory_limits;
  * carefully about the CPU overhead of remembering and rejoining
  * all free streams during capture. Maybe it's not a big deal.
  */
-bool operator==(const UsageStream& lhs, const UsageStream& rhs) {
-  return (lhs.stream == rhs.stream) && (lhs.device == rhs.device);
-}
-
-struct UsageStreamHash {
-  size_t operator()(const UsageStream& us) const noexcept {
-    return std::hash<void*>{}(us.stream) + size_t(us.device);
-  }
-};
-
 std::unordered_set<UsageStream, UsageStreamHash> capture_free_streams;
 bool capture_underway = false;
 
@@ -180,26 +183,36 @@ inline void lazy_init_device(int device) {
   }
 }
 
+inline void sync_raw(cudaStream_t dependency, cudaStream_t dependent) {
+  // CUDACachingAllocator.cpp uses raw cuda events, as do we.
+  cudaEvent_t event;
+  C10_CUDA_CHECK(cudaEventCreateWithFlags(&event, cudaEventDisableTiming));
+  C10_CUDA_CHECK(cudaEventRecord(event, dependency));
+  C10_CUDA_CHECK(cudaStreamWaitEvent(dependent, event));
+  C10_CUDA_CHECK(cudaEventDestroy(event));
+}
+
 // Assumes the caller holds general_mutex
 inline void free_impl(PtrInfo::iterator& it) {
   // Possible micro-optimization: If we did a value-copy here, we could move
   // ptr_info.erase(it) up here and drop the lock immediately.
-  const auto& usage_streams = it->second.usage_streams;
+  const auto& recorded_streams = it->second.recorded_streams;
+  const auto& creation_stream = it->second.creation_stream;
 
   // If the usage stream is a null (default) stream,
   // cudaFreeAsync infers the device from the ambient context,
   // so we need to set the right ambient context.
-  CUDAGuard g(usage_streams[0].device);
+  CUDAGuard g(creation_stream.device);
 
-  if (usage_streams.size() == 1) {
+  if (recorded_streams.size() == 0) {
     // ptr was only used on one stream, which must have been
     // the original allocation stream.
     // Frees ptr in the original allocation stream.
-    C10_CUDA_CHECK(cudaFreeAsync(it->first, usage_streams[0].stream));
+    C10_CUDA_CHECK(cudaFreeAsync(it->first, creation_stream.stream));
 
     if (C10_UNLIKELY(capture_underway)) {
       // See Note [Avoid dangling free streams during CUDA graph capture]
-      capture_free_streams.insert(usage_streams[0]);
+      capture_free_streams.insert(creation_stream);
     }
   } else {
     // ptr was used on many streams. We don't know which was the most recent.
@@ -212,23 +225,21 @@ inline void free_impl(PtrInfo::iterator& it) {
 
     // Retrieves the dummy "unifier" stream from the device
     // on which the pointer was originally allocated.
-    auto dummy_unifying_free_stream = dummy_unifying_free_streams[usage_streams[0].device];
-    TORCH_INTERNAL_ASSERT(dummy_unifying_free_stream.device == usage_streams[0].device);
+    auto dummy_unifying_free_stream = dummy_unifying_free_streams[creation_stream.device];
+    TORCH_INTERNAL_ASSERT(dummy_unifying_free_stream.device == creation_stream.device);
+
+    // we're already on creation_stream.device, no need to re-guard
+    sync_raw(creation_stream.stream, dummy_unifying_free_stream.stream);
 
     // The number of usage streams is typically small (low single digits)
-    for (const auto& usage_stream : usage_streams) {
+    for (const auto& recorded_stream : recorded_streams) {
       // Logic here accommodates the chance some of the usage streams were on other devices,
       // which is possible if some usage kernels accessed the memory via p2p.
 
       // cudaEventRecord requires that the input event and stream are on the same device.
-      CUDAGuard g_usage(usage_stream.device);
-
-      // CUDACachingAllocator.cpp uses raw cuda events, as do we.
-      cudaEvent_t event;
-      C10_CUDA_CHECK(cudaEventCreateWithFlags(&event, cudaEventDisableTiming));
-      C10_CUDA_CHECK(cudaEventRecord(event, usage_stream.stream));
-      C10_CUDA_CHECK(cudaStreamWaitEvent(dummy_unifying_free_stream.stream, event));
-      C10_CUDA_CHECK(cudaEventDestroy(event));
+      CUDAGuard g_usage(recorded_stream.device);
+
+      sync_raw(recorded_stream.stream, dummy_unifying_free_stream.stream);
     }
 
     // Frees ptr in the dummy "unifier" stream.
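For the multi-stream case, the pattern above is: join the creation stream and every recorded side stream into the per-device "dummy unifier" stream with raw events, then issue the free on the unifier so it executes only after all known work on the pointer. A minimal single-device sketch of that pattern follows (error checking, device guards, and p2p handling omitted; free_on_unifier is a hypothetical helper name, not part of this diff; cudaFreeAsync requires CUDA 11.2+).

// Hedged single-device sketch of the "unifier" free: every stream that touched
// the allocation is joined into one stream via an event, and cudaFreeAsync is
// issued there so the driver frees only after all joined work completes.
#include <cuda_runtime.h>
#include <vector>

// Same raw-event join as the diff's sync_raw, minus C10_CUDA_CHECK.
static void sync_raw(cudaStream_t dependency, cudaStream_t dependent) {
  cudaEvent_t event;
  cudaEventCreateWithFlags(&event, cudaEventDisableTiming);
  cudaEventRecord(event, dependency);
  cudaStreamWaitEvent(dependent, event, 0);
  cudaEventDestroy(event);
}

static void free_on_unifier(void* ptr,
                            cudaStream_t creation_stream,
                            const std::vector<cudaStream_t>& recorded_streams,
                            cudaStream_t unifier) {
  // The unifier must not run ahead of the creation stream...
  sync_raw(creation_stream, unifier);
  // ...nor of any side stream recorded via record_stream.
  for (cudaStream_t s : recorded_streams) {
    sync_raw(s, unifier);
  }
  // The free is ordered after everything joined above.
  cudaFreeAsync(ptr, unifier);
}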
@@ -240,10 +251,10 @@ inline void free_impl(PtrInfo::iterator& it) {
     // In theory, we could remove the need for the driver to do this tracking by e.g. replacing
     // cudaStreamWaitEvent(dummy_unifying_free_stream.stream, event);
     // with
-    // cudaStreamWaitEvent(usage_streams[0].stream, event);
-    // then cudaFreeAsyncing straight back into usage_streams[0];
-    // but this forces a potentially false dependency of usage_streams[0]
-    // on all the other usage_streams.
+    // cudaStreamWaitEvent(creation_stream.stream, event);
+    // then cudaFreeAsyncing straight back into creation_stream.stream,
+    // but this forces a potentially false dependency of creation_stream.stream
+    // on all the recorded_streams.
 
     if (C10_UNLIKELY(capture_underway)) {
       // See Note [Avoid dangling free streams during CUDA graph capture]
@@ -252,7 +263,7 @@ inline void free_impl(PtrInfo::iterator& it) {
     }
   }
 
-  pytorch_used_bytes[usage_streams[0].device] -= it->second.size;
+  pytorch_used_bytes[creation_stream.device] -= it->second.size;
 
   ptr_info.erase(it);
 }
@@ -263,8 +274,6 @@ void free(void* ptr) {
   auto it = ptr_info.find(ptr);
   TORCH_INTERNAL_ASSERT(it != ptr_info.end(),
                         "ptr not found in ptr_info");
-  TORCH_INTERNAL_ASSERT(it->second.usage_streams.size() != 0,
-                        "ptr's stream uses vector is empty");
 
   if (C10_UNLIKELY(capture_underway)) {
     if (!it->second.captured) {
@@ -354,7 +363,7 @@ void malloc(void** devPtr, int device, size_t size, cudaStream_t stream) {
                         "address returned by cudaMallocAsync already exists "
                         "in ptr_info");
 
-  inserted.first->second.usage_streams.emplace_back(stream, device);
+  inserted.first->second.creation_stream = {stream, device};
 
   pytorch_used_bytes[device] += size;
 }
@@ -394,7 +403,7 @@ Allocator* get(void) {
 // just set up for later calls to init per-device pools based
 // on the current device each later call sees.
 void init(int dev_count) {
-  static bool called = [] {;
+  static bool called = [](int dev_count) {;
     // Are there external guarantees init will be called before
     // any of the allocator's other functions?
     // std::lock_guard<std::mutex> lk(general_mutex);
@@ -404,7 +413,7 @@ void init(int dev_count) {
     pytorch_used_bytes.resize(dev_count);
     pytorch_memory_limits.resize(dev_count);
     return true;
-  }();
+  }(dev_count);
 }
 
 static inline void assertValidDevice(int device) {
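The init change is a compile-level fix: the one-shot initializer lambda is non-capturing, so it cannot read the enclosing dev_count; passing dev_count as a lambda parameter (and supplying it at the call site) makes the static initializer well-formed. A minimal sketch of the idiom, with hypothetical names:

// Hedged sketch of the one-time init idiom the diff fixes. The static
// initializer runs the lambda exactly once, on the first call to init().
#include <cstdio>

static int num_devices = 0;

void init(int dev_count) {
  // static bool called = [] { num_devices = dev_count; ... };
  //   would not compile: dev_count is neither captured nor a parameter.
  static bool called = [](int dev_count) {
    num_devices = dev_count; // runs exactly once
    return true;
  }(dev_count);
  (void)called;
}

int main() {
  init(4);
  init(8); // ignored: the static initializer already ran
  std::printf("%d\n", num_devices); // prints 4
  return 0;
}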
@@ -532,11 +541,14 @@ void recordStream(const DataPtr& ptr, cuda::CUDAStream stream) {
   auto it = ptr_info.find(ptr.get());
   TORCH_INTERNAL_ASSERT(it != ptr_info.end(),
                         "ptr not found in ptr_info");
-  TORCH_INTERNAL_ASSERT(it->second.usage_streams.size() != 0,
-                        "ptr's stream uses vector is empty");
 
-  it->second.usage_streams.emplace_back(stream.stream(),
-                                        stream.device_index());
+  UsageStream to_record{stream.stream(), stream.device_index()};
+  if (to_record == it->second.creation_stream) {
+    TORCH_WARN("Called record_stream on tensor whose original creation stream "
+               "matches the recorded stream. This is unnecessary and has no effect.");
+  } else {
+    it->second.recorded_streams.insert(to_record);
+  }
 
 }
 
 std::mutex* getFreeMutex() {
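On the caller side, this branch is reached through Tensor::record_stream, which tells the allocator a tensor's storage is also in use on a side stream. A hedged libtorch-level sketch follows; the stream helpers (c10::cuda::getStreamFromPool, CUDAStreamGuard) are the usual c10 ones and are assumed here, not shown in this diff.

// Hedged sketch of the caller-side pattern that reaches recordStream():
// work is launched on a side stream, then the tensor records that stream so
// its storage is not freed or reused before the side-stream work finishes.
#include <ATen/ATen.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>

void consume_on_side_stream(const at::Tensor& t) {
  // Assumes t lives on the current CUDA device.
  c10::cuda::CUDAStream side_stream = c10::cuda::getStreamFromPool();

  {
    c10::cuda::CUDAStreamGuard guard(side_stream);
    auto y = t * 2; // kernel enqueued on side_stream
    (void)y;
  }

  // With this patch, the allocator inserts {stream, device} into
  // recorded_streams (or warns and does nothing if it equals the
  // creation stream).
  t.record_stream(side_stream.unwrap());
}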
@@ -700,8 +712,6 @@ void notifyCaptureEnded(int device, CaptureId_t graph_id) {
     auto it = ptr_info.find(ptr);
     TORCH_INTERNAL_ASSERT(it != ptr_info.end(),
                           "ptr not found in ptr_info");
-    TORCH_INTERNAL_ASSERT(it->second.usage_streams.size() != 0,
-                          "ptr's stream uses vector is empty");
     free_impl(it);
   }
 