test_out_of_memory_retry · pytorch/pytorch@a208bdb · GitHub

Commit a208bdb

test_out_of_memory_retry

1 parent a240ec8 commit a208bdb

File tree: 2 files changed, +12 -1 lines changed

c10/cuda/CUDACachingAllocator.cpp

Lines changed: 1 addition & 1 deletion
@@ -370,7 +370,7 @@ class DeviceCachingAllocator {
   //
   // Q. Why skip process_events if a capture might be underway?
   // A. process_events involves cudaEventQueries, illegal during CUDA graph
-  // capture.
+  // capture.
   // Dumb simple solution: defer reclaiming these allocations until after
   // capture. Cross-stream memory use is uncommon, so the deferral's
   // effect on memory use during capture should be small.
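
The comment being touched explains why the allocator skips process_events while a CUDA graph capture might be underway: cudaEventQuery is not allowed during capture, so reclaiming cross-stream blocks is deferred until capture ends. As a rough, illustrative sketch of that capture constraint from the Python side (assuming a CUDA build with graph support and following the documented warmup-then-capture pattern; tensor shapes here are arbitrary):

    import torch

    # Sketch only: record a simple op into a CUDA graph. While capture is in
    # flight, event queries and synchronization are disallowed, which is why
    # the caching allocator defers reclaiming cross-stream blocks until after
    # capture.
    static_input = torch.zeros(8, device='cuda')

    # Warm up on a side stream before capturing, per the CUDA graphs docs.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        static_output = static_input * 2
    torch.cuda.current_stream().wait_stream(s)

    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):
        static_output = static_input * 2  # captured, not executed here

    static_input.fill_(3.0)
    g.replay()              # runs the captured work
    print(static_output)    # expected: a tensor of 6s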

test/test_cuda.py

Lines changed: 11 additions & 0 deletions
@@ -377,6 +377,17 @@ def test_out_of_memory(self):
         tensor.fill_(1)
         self.assertTrue((tensor == 1).all())

+    def test_out_of_memory_retry(self):
+        total_memory = torch.cuda.get_device_properties(0).total_memory
+        oom_regex = "would exceed allowed memory" if TEST_CUDAMALLOCASYNC else \
+            "Tried to allocate"
+        size = int(total_memory * 0.5)
+        a = torch.empty(size, dtype=torch.int8, device='cuda')
+        with self.assertRaisesRegex(RuntimeError, oom_regex):
+            b = torch.empty(size, dtype=torch.int8, device='cuda')
+        del a
+        b = torch.empty(size, dtype=torch.int8, device='cuda')
+
     def test_set_per_process_memory_fraction(self):
         # test invalid fraction value.
         with self.assertRaisesRegex(TypeError, "Invalid type"):
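
The new test exercises the allocator's OOM retry path: the second half-of-device-memory allocation fails while the first tensor is alive, then succeeds once that tensor is freed. A standalone sketch of the same behavior outside the unittest harness (an illustration only: it assumes a CUDA device with no other significant memory users; sizes and dtype mirror the test above):

    import torch

    # Sketch of the behavior covered by test_out_of_memory_retry: the second
    # allocation OOMs while `a` is alive, then succeeds after `a` is freed and
    # the caching allocator can reuse the block.
    total_memory = torch.cuda.get_device_properties(0).total_memory
    size = int(total_memory * 0.5)

    a = torch.empty(size, dtype=torch.int8, device='cuda')
    try:
        b = torch.empty(size, dtype=torch.int8, device='cuda')
    except RuntimeError as e:
        print("first attempt raised as expected:", type(e).__name__)

    del a  # free the first block so the allocator can satisfy the retry
    b = torch.empty(size, dtype=torch.int8, device='cuda')  # now succeeds
    print("second attempt allocated", b.numel(), "bytes")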

0 commit comments