pytorch/pytorch · commit 3d23053 (1 parent: a006a53)

Add Python-facing torch.cuda.get_allocator_backend()
5 files changed (+34, -3 lines)

docs/source/cuda.rst (1 addition, 0 deletions)

@@ -111,6 +111,7 @@ Memory management
     reset_peak_memory_stats
     caching_allocator_alloc
     caching_allocator_delete
+    get_allocator_backend
 .. FIXME The following doesn't seem to exist. Is it supposed to?
    https://github.com/pytorch/pytorch/issues/27785
 .. autofunction:: reset_max_memory_reserved

docs/source/notes/cuda.rst (2 additions, 0 deletions)

@@ -365,6 +365,8 @@ Available options:
   implementation, and ``cudaMallocAsync``, which uses
   `CUDA's built-in asynchronous allocator`_.
   ``cudaMallocAsync`` requires CUDA 11.4 or newer. The default is ``native``.
+  ``backend`` applies to all devices used by the process, and can't be
+  specified on a per-device basis.
 * ``max_split_size_mb`` prevents the native allocator
   from splitting blocks larger than this size (in MB). This can reduce
   fragmentation and may allow some borderline workloads to complete without
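The note added above means the backend choice is made once for the whole process. A minimal sketch of how that looks in practice, assuming the environment variable is set before the first CUDA allocation (the tensor and the print are placeholders):

    import os

    # Process-wide backend selection; it cannot differ per device and must be
    # in place before the CUDA caching allocator is initialized.
    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "backend:cudaMallocAsync"

    import torch  # imported after the env var so the allocator sees it

    if torch.cuda.is_available():
        torch.empty(8, device="cuda")              # first allocation initializes the allocator
        print(torch.cuda.get_allocator_backend())  # expected: "cudaMallocAsync" (CUDA 11.4+)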

test/test_cuda.py (1 addition, 3 deletions)

@@ -44,9 +44,7 @@
     print('CUDA not available, skipping tests', file=sys.stderr)
     TestCase = object  # noqa: F811

-TEST_CUDAMALLOCASYNC = ((os.getenv("PYTORCH_CUDA_ALLOC_CONF") is not None) and
-                        ("backend:cudaMallocAsync" in os.getenv("PYTORCH_CUDA_ALLOC_CONF")))
-
+TEST_CUDAMALLOCASYNC = (torch.cuda.get_allocator_backend() == "cudaMallocAsync")
 TEST_LARGE_TENSOR = TEST_CUDA
 TEST_MEDIUM_TENSOR = TEST_CUDA
 TEST_CUDNN = TEST_CUDA
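For context, a flag like TEST_CUDAMALLOCASYNC is typically consumed with unittest.skipIf. The test below is a hypothetical illustration of that pattern, not a test from the suite:

    import unittest
    import torch

    TEST_CUDA = torch.cuda.is_available()
    TEST_CUDAMALLOCASYNC = TEST_CUDA and (torch.cuda.get_allocator_backend() == "cudaMallocAsync")

    class TestAllocatorSpecific(unittest.TestCase):
        @unittest.skipIf(not TEST_CUDA, "requires CUDA")
        @unittest.skipIf(TEST_CUDAMALLOCASYNC, "relies on native-allocator behavior")
        def test_native_allocator_stats(self):
            # Hypothetical check that only applies to the native caching allocator.
            torch.empty(1024, device="cuda")
            stats = torch.cuda.memory_stats()
            self.assertIn("active_bytes.all.current", stats)

    if __name__ == "__main__":
        unittest.main()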

torch/csrc/cuda/Module.cpp (19 additions, 0 deletions)

@@ -225,6 +225,24 @@ PyObject * THCPModule_cudaCachingAllocator_raw_delete(PyObject *_unused, PyObjec
   END_HANDLE_TH_ERRORS
 }

+PyObject * THCPModule_getAllocatorBackend(PyObject *_unused, PyObject *noargs)
+{
+  HANDLE_TH_ERRORS
+  using c10::cuda::CUDACachingAllocator::AllocatorBackend;
+  AllocatorBackend backend = c10::cuda::CUDACachingAllocator::allocatorBackend();
+  // this call should be uncommon, don't bother interning strings
+  switch (backend) {
+    case AllocatorBackend::NATIVE:
+      return THPUtils_packString("native");
+    case AllocatorBackend::CUDAMALLOCASYNC:
+      return THPUtils_packString("cudaMallocAsync");
+    default:
+      THPUtils_assert(false, "Unexpected value for backend");
+      return nullptr;
+  }
+  END_HANDLE_TH_ERRORS
+}
+
 PyObject * THCPModule_cudaSynchronize(PyObject *_unused, PyObject *noargs)
 {
   HANDLE_TH_ERRORS

@@ -590,6 +608,7 @@ static struct PyMethodDef _THCPModule_methods[] = {
   {"_cuda_cudaHostAllocator", THCPModule_cudaHostAllocator, METH_NOARGS, nullptr},
   {"_cuda_cudaCachingAllocator_raw_alloc", THCPModule_cudaCachingAllocator_raw_alloc, METH_VARARGS, nullptr},
   {"_cuda_cudaCachingAllocator_raw_delete", THCPModule_cudaCachingAllocator_raw_delete, METH_O, nullptr},
+  {"_cuda_getAllocatorBackend", THCPModule_getAllocatorBackend, METH_NOARGS, nullptr},
   {"_cuda_synchronize", THCPModule_cudaSynchronize, METH_NOARGS, nullptr},
   {"_cuda_ipc_collect", THCPModule_cudaIPCCollect, METH_NOARGS, nullptr},
   {"_cuda_sleep", THCPModule_cudaSleep, METH_O, nullptr},

torch/cuda/memory.py (11 additions, 0 deletions)

@@ -587,3 +587,14 @@ def mem_get_info(device: Union[Device, int] = None) -> int:
         device = torch.cuda.current_device()
     device = _get_device_index(device)
     return torch.cuda.cudart().cudaMemGetInfo(device)
+
+def get_allocator_backend() -> str:
+    r"""Returns a string describing the active allocator backend as set by
+    ``PYTORCH_CUDA_ALLOC_CONF``. Currently available backends are
+    ``native`` (PyTorch's native caching allocator) and ``cudaMallocAsync``
+    (CUDA's built-in asynchronous allocator).
+
+    .. note::
+        See :ref:`cuda-memory-management` for details on choosing the allocator backend.
+    """
+    return torch._C._cuda_getAllocatorBackend()
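A short usage sketch of the new public API; the branches are placeholders for whatever backend-specific behavior a caller might want:

    import torch

    if torch.cuda.is_available():
        backend = torch.cuda.get_allocator_backend()
        if backend == "native":
            # e.g. options such as max_split_size_mb only apply to the native allocator
            print("using PyTorch's native caching allocator")
        else:  # "cudaMallocAsync"
            print("using CUDA's built-in asynchronous allocator")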
