pytorch/pytorch · Pull Request #65365

Adds cudaMallocAsync as an alternative backend for the CUDA allocator #65365


Closed
mcarilli wants to merge 73 commits
Changes from 1 commit
Commits (73)
f5f5cde
torch.cuda.backends.allocator
mcarilli Sep 20, 2021
16ea79b
cudaAllocatorModule
mcarilli Sep 20, 2021
f1f803d
remove backends.cuda binding, want envvar instead
mcarilli Sep 21, 2021
a73877e
use PYTORCH_CUDA_ALLOC_CONF
mcarilli Sep 21, 2021
35b322f
Stashing for visibility of this idea
mcarilli Oct 11, 2021
6ada155
Taking shape enough to be worth showing
mcarilli Oct 31, 2021
a6f271e
Separated allocator config so functions can be shared
mcarilli Nov 8, 2021
9e2b3b7
docstring
mcarilli Nov 8, 2021
3026470
Fix multiple definition errors, simplify config
mcarilli Nov 8, 2021
430e43c
Merge remote-tracking branch 'upstream/master' into cudaMallocAsync
mcarilli Nov 8, 2021
57aa340
uncovered graph can of worms
mcarilli Nov 9, 2021
44a9c72
Merge remote-tracking branch 'upstream/master' into cudaMallocAsync
mcarilli Nov 14, 2021
bae5608
Should fix dangling streams issue
mcarilli Nov 15, 2021
0b6ffa1
approach to avoid freeing ungraphed pointers during capture
mcarilli Nov 16, 2021
f2038a3
Hash and operator== for UsageStream
mcarilli Nov 16, 2021
6b9d832
stashing work
mcarilli Nov 19, 2021
253851d
Merge remote-tracking branch 'upstream/master' into cudaMallocAsync
mcarilli Nov 19, 2021
cfd624e
Almost!
mcarilli Nov 19, 2021
bfaae65
Merge remote-tracking branch 'upstream/master' into cudaMallocAsync
mcarilli Nov 19, 2021
ddcec61
fix annoying link error
mcarilli Nov 24, 2021
06d58b5
Error messages for compile or runtime CUDA version < 11.4
mcarilli Nov 29, 2021
fdaaa9f
Completely avoid cudaGraphInstantiateFlagAutoFreeOnLaunch if cuda < 11.4
mcarilli Nov 30, 2021
8f94458
Fix CUDA_VERSION usage
mcarilli Nov 30, 2021
daf188f
resolve conflict in CUDACachingAllocator.cpp
mcarilli Dec 7, 2021
eeeac81
Merge remote-tracking branch 'upstream/master' into cudaMallocAsync
mcarilli Dec 7, 2021
4d7388b
Merge remote-tracking branch 'upstream/master' into cudaMallocAsync
mcarilli Dec 7, 2021
89b03b3
fix warning string
mcarilli Dec 7, 2021
eef5b31
reset properly
mcarilli Dec 9, 2021
27a2d68
Merge remote-tracking branch 'upstream/master' into cudaMallocAsync
mcarilli Dec 13, 2021
991d303
halfassed but defensible attempt at stat handling
mcarilli Dec 17, 2021
1815536
For discussion, highlight strange interaction between cacheInfo and c…
mcarilli Dec 17, 2021
4d8dca7
fdsa
mcarilli Dec 17, 2021
363fe3c
Let's see how this makes the docs look
mcarilli Dec 20, 2021
b13f118
before i forget
mcarilli Jan 1, 2022
4089885
Merge remote-tracking branch 'upstream/master' into cudaMallocAsync
mcarilli Jan 14, 2022
368a0de
typo
mcarilli Jan 14, 2022
c80a05f
All graph tests in test_cuda.py pass except for test_graph_cudnn_dropout
mcarilli Jan 14, 2022
db53e41
better skip
mcarilli Jan 16, 2022
2293a94
Merge remote-tracking branch 'upstream/master' into cudaMallocAsync
mcarilli Jan 17, 2022
bdc6d30
test_graph_cudnn_dropout now passes too
mcarilli Jan 18, 2022
7e7c12b
worth a try
mcarilli Jan 20, 2022
c90a3a0
enable p2p transfers for p2p-capable devices
mcarilli Jan 25, 2022
c2d84ea
test_cuda.py passes on my machine
mcarilli Jan 26, 2022
aa2bee7
TEST_CUDAMALLOCASYNC
mcarilli Jan 27, 2022
6fbb8cc
fixes OOM handling and test_record_stream
mcarilli Jan 28, 2022
78eb46e
fix regex for test_set_per_process_memory_fraction
mcarilli Jan 29, 2022
a240ec8
Merge remote-tracking branch 'upstream/master' into cudaMallocAsync
mcarilli Jan 31, 2022
a208bdb
test_out_of_memory_retry
mcarilli Feb 1, 2022
1189ff7
s/THC/Native requested by natalia
mcarilli Feb 1, 2022
05c4554
Resolve conflicts in CUDACachingAllocator.cpp
mcarilli Feb 2, 2022
c33ce86
Resolve conflict in CUDACachingAllocator.cpp
mcarilli Feb 9, 2022
c85cb9c
fix test_set_per_process_memory_fraction failure caused by knock-on e…
mcarilli Feb 10, 2022
7fe0e75
typo
mcarilli Feb 10, 2022
b534129
fix signature for unavailable version
mcarilli Feb 10, 2022
2f7d1b5
comment
mcarilli Feb 17, 2022
bc55994
Merge remote-tracking branch 'upstream/master' into cudaMallocAsync
mcarilli Feb 17, 2022
e0ec118
temporary test disables to unblock local CI run
mcarilli Feb 17, 2022
09709d5
resolve conflict in Copy.cu
mcarilli Mar 8, 2022
4cbde6f
#if CUDA_VERSION >= 11040
mcarilli Mar 8, 2022
6bbf293
use backend enum, use 71668 config parser, de-inline format_size, rem…
mcarilli Mar 16, 2022
e810e0f
avoid compiling pool-specific stuff with cuda < 11.4
mcarilli Mar 18, 2022
95af048
Resolves conflicts with #74213
mcarilli Mar 18, 2022
1cc5d02
Resolving conflicts
mcarilli Mar 26, 2022
a006a53
switches to flat_hash_set for recorded streams, restores original max…
mcarilli Mar 28, 2022
3d23053
Add Python-facing torch.cuda.get_allocator_backend()
mcarilli Mar 30, 2022
49d4332
Pytorch->PyTorch again lmao
mcarilli Mar 30, 2022
3cc7a1f
rename malloc and free, avoid gratuitous exception-catching in cacheInfo
mcarilli Apr 5, 2022
5401784
typos
mcarilli Apr 5, 2022
ec5b6ff
Merge remote-tracking branch 'upstream/master' into cudaMallocAsync
mcarilli Apr 5, 2022
c4a9acf
flake8 and mypy
mcarilli Apr 5, 2022
4c79ed8
let's see if this cleans up some failures
mcarilli Apr 5, 2022
2b1e0b2
un-static allocatorBackend checks in p2p and copy path
mcarilli Apr 5, 2022
9a47eff
Implements backend load time initialization. Builds and import succeeds.
mcarilli Apr 22, 2022
approach to avoid freeing ungraphed pointers during capture
mcarilli committed Nov 16, 2021
commit 0b6ffa18cc1c6680b1773b27e541a1dd0a58ea37
11 changes: 9 additions & 2 deletions c10/cuda/CUDACachingAllocator.cpp
@@ -1,6 +1,4 @@

#include <c10/cuda/CUDACachingAllocator.h>

#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAFunctions.h>
#include <c10/cuda/CUDAGuard.h>
@@ -1619,6 +1617,9 @@ void parseArgs() {
std::sregex_token_iterator end;
std::vector<std::string> options(it, end);

bool used_max_split_size_mb(false);
bool used_cudaMallocAsync(false);

for (auto option : options) {
std::regex exp2("[:]+");
std::sregex_token_iterator it2(option.begin(), option.end(), exp2, -1);
Expand All @@ -1637,17 +1638,23 @@ void parseArgs() {
val2 = std::min(
val2, (std::numeric_limits<size_t>::max() / (1024 * 1024)));
m_max_split_size = val2 * 1024 * 1024;
used_max_split_size_mb = true;
} else if (kv[0].compare("backend") == 0) {
TORCH_CHECK(((kv[1].compare("native") == 0) ||
(kv[1].compare("cudaMallocAsync") == 0)),
"Unknown allocator backend, "
"options are native and cudaMallocAsync");
m_allocator_backend = kv[1];
used_cudaMallocAsync = (kv[1].compare("cudaMallocAsync") == 0);
} else {
TORCH_CHECK(false, "Unrecognized CachingAllocator option: ", kv[0]);
}
}
}

if (used_max_split_size_mb && used_cudaMallocAsync) {
TORCH_WARN("backend:cudaMallocAsync ignores max_split_size_mb");
}
}
}

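For orientation, a minimal usage sketch of the backend option parsed above (assumptions: PYTORCH_CUDA_ALLOC_CONF is read lazily at the first CUDA allocation, and torch.cuda.get_allocator_backend() is the Python-facing query a later commit in this PR adds):

    import os

    # Set the config before importing torch so it is already in place when the
    # allocator reads PYTORCH_CUDA_ALLOC_CONF at the first CUDA allocation.
    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "backend:cudaMallocAsync"

    import torch

    x = torch.empty(1024, device="cuda")        # served by cudaMallocAsync
    print(torch.cuda.get_allocator_backend())   # expected: "cudaMallocAsync"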
150 changes: 109 additions & 41 deletions c10/cuda/CUDAMallocAsyncAllocator.cpp
@@ -15,10 +15,10 @@ namespace cuda {
namespace CUDACachingAllocator {
namespace CudaMallocAsync {

// Allocator that uses cudaMallocAsync to implement the same interface
// as CUDACachingAllocator.cpp.
//
// cudaMallocAsync works transparently with CUDA graphs.
// CUDA device allocator that uses cudaMallocAsync to implement
// the same interface as CUDACachingAllocator.cpp.

// Designed to be safe for CUDA graph capture.

// Implementation details, not declared in CUDACachingAllocator.h
namespace {
@@ -31,8 +31,8 @@ int device_count = 0;
std::vector<bool> devs_initialized_flags;
std::vector<CUDAStream> dummy_unifying_free_streams;

// Potential future micro-optimization:
// Some accesses to usage_streams_each_ptr are read-only.
// Possible micro-optimization:
// Some accesses to ptr_info are read-only.
// We could let those be concurrent with a shared_mutex and
// have concurrent calls take a shared_lock.
// Keeping it simple with an ordinary mutex for now.
@@ -47,14 +47,28 @@ std::mutex general_mutex;
* sure no uncaptured tensor will ever have its destructor called
* in a capturing region.
* We avoid errors by
* 1. tracking captured and uncaptured allocated pointers separately
* 1. remembering if allocated pointers were captured or uncaptured
* 2. during capture, if we detect an attempt to free an uncaptured
* allocation on a capturing stream, don't free it immediately,
* just remember it and defer its cudaFreeAsync call to after
* the end of capture (specifically, to notifyCaptureEnded).
Contributor:

This seems... suboptimal. Why isn't there an option to do an untracked free?

mcarilli (Collaborator, Author) · Mar 7, 2022:

It's a restriction of the CUDA API because ungraphed allocations are meant to be one-off, while graphs are meant to be replayed many times. I think their philosophy is: if CUDA silently allowed ungraphed allocations to be freed during capture, an ungraphed allocation address that happened to get freed during capture would be actually-freed on the first replay, then later replays would attempt to actually-free the same address, at which point it would either no longer be live or be associated with some totally different allocation (and in either scenario an actual free is a bad thing to do).

Collaborator:

So the CUDA API is inconsistent in that on replays it will ignore allocations recorded during graph capture (with relaxed mode or whatever), but won't ignore frees?

mcarilli (Collaborator, Author) · Mar 15, 2022:

The CUDA folks reconcile this "inconsistency" by regarding

> on replays it will ignore allocations recorded during graph capture

as "auto free on launch", i.e., when the replay sees some captured VA that is unfreed since the last launch, it (conceptually or literally, I'm not sure) frees and immediately reallocs that VA.

The graph should not ignore frees of graphed memory, because graphed VA is intended to get reallocated as many times as it's freed. For example, we want:

graph1.replay(); // e.g. forward pass of a make_graphed_callable, allocates some activations
graph2.replay(); // e.g. backward pass of make_graphed_callable, consumes and frees the activations
// graph2.replay()'s frees should have a real effect, because out here we shouldn't regard the activations as still allocated and their PA as unusable by later cudaMallocAsyncs.

*/
std::unordered_map<void*, std::vector<CUDAStream>> usage_streams_ungraphed_ptrs;
std::unordered_map<void*, std::vector<CUDAStream>> usage_streams_graphed_ptrs;

struct UsageStream {
cudaStream_t stream;
int device;
UsageStream(cudaStream_t s, int d) : stream(s), device(d) {}
};

struct PtrUsage {
std::vector<UsageStream> usage_streams;
uint64_t size;
bool captured;
PtrUsage(uint64_t s, bool c) : size(s), captured(c) {}
};

using PtrInfo = std::unordered_map<void*, PtrUsage>;
Contributor:

You sure you want an unordered_map? Is this hot?

Contributor:

Called every malloc, looks pretty hot to me

mcarilli (Collaborator, Author):

Changed to ska::flat_hash_map. Do you know for sure whether flat_hash_map is a drop-in replacement for std::unordered_map? Does it implement the same interface? The code builds with flat_hash_map, but that doesn't guarantee the semantics of member functions are the same...

Collaborator:

It might have different guarantees on iterator invalidation, but otherwise it should be the same (I couldn't find the docs right away).

PtrInfo ptr_info;
std::vector<void*> ungraphed_ptrs_defer_free_until_no_capture;

// Graph-specific helpers
@@ -94,7 +108,7 @@ std::vector<void*> ungraphed_ptrs_defer_free_until_no_capture;
* carefully about the CPU overhead of remembering and rejoining
* all free streams during capture. Maybe it's not a big deal.
*/
std::unordered_set<CUDAStream> capture_free_streams;
std::unordered_set<UsageStream> capture_free_streams;
bool capture_underway = false;

// Implementation functions
@@ -128,31 +142,24 @@ inline void lazy_init_device(int device) {
}
}

void free(void* ptr) {
std::lock_guard<std::mutex> lk(general_mutex);

auto it = usage_streams_each_ptr.find(ptr);
TORCH_INTERNAL_ASSERT(it != usage_streams_each_ptr.end(),
"ptr not represented in usage_streams_each_ptr");
TORCH_INTERNAL_ASSERT(it->second.size() != 0,
"ptr's stream uses vector is empty");

// Assumes the caller holds general_mutex
inline void free_impl(PtrInfo::iterator& it) {
// Possible micro-optimization: If we did a value-copy here, we could move
// usage_streams_each_ptr.erase(it) up here and drop the lock immediately.
const auto& usage_streams = it->second;
// ptr_info.erase(it) up here and drop the lock immediately.
const auto& usage_streams = it->second.usage_streams;

// If the usage stream is a null (default) stream,
// cudaFreeAsync infers the device from the ambient context,
// so we need to set the right ambient context.
CUDAGuard g(usage_streams[0].device_index());
CUDAGuard g(usage_streams[0].device);

if (usage_streams.size() == 1) {
// ptr was only used on one stream, which must have been
// the original allocation stream.
// Frees ptr in the original allocation stream.
C10_CUDA_CHECK(cudaFreeAsync(ptr, usage_streams[0]));
C10_CUDA_CHECK(cudaFreeAsync(ptr, usage_streams[0].stream));

if (C10_UNLIKELY(captures_underway)) {
if (C10_UNLIKELY(capture_underway)) {
// See Note [Avoid dangling free streams during CUDA graph capture]
capture_free_streams.insert(usage_streams[0]);
}
Expand All @@ -167,20 +174,20 @@ void free(void* ptr) {

// Retrieves the dummy "unifier" stream from the device
// on which the pointer was originally allocated.
auto dummy_unifying_free_stream = dummy_unifying_free_streams[usage_streams[0].device_index()];
auto dummy_unifying_free_stream = dummy_unifying_free_streams[usage_streams[0].device];

// The number of usage streams is typically small (low single digits)
for (const auto& usage_stream : usage_streams) {
// Logic here accommodates the chance some of the usage streams were on other devices,
// which is possible if some usage kernels accessed the memory via p2p.

// cudaEventRecord requires that the input event and stream are on the same device.
CUDAGuard g_usage(usage_stream.device_index());
CUDAGuard g_usage(usage_stream.device);

// CUDACachingAllocator.cpp uses raw cuda events, as do we.
cudaEvent_t event;
C10_CUDA_CHECK(cudaEventCreateWithFlags(&event, cudaEventDisableTiming));
C10_CUDA_CHECK(cudaEventRecord(event, usage_stream.stream()));
C10_CUDA_CHECK(cudaEventRecord(event, usage_stream.stream));
C10_CUDA_CHECK(cudaStreamWaitEvent(dummy_unifying_free_stream.stream(), event));
C10_CUDA_CHECK(cudaEventDestroy(event));
}
@@ -199,13 +206,46 @@ void free(void* ptr) {
// but this forces a potentially false dependency of usage_streams[0]
// on all the other usage_streams.

if (C10_UNLIKELY(captures_underway)) {
if (C10_UNLIKELY(capture_underway)) {
// See Note [Avoid dangling free streams during CUDA graph capture]
capture_free_streams.insert(dummy_unifying_free_stream);
capture_free_streams.insert({dummy_unifying_free_stream.stream,
dummy_unifying_free_stream.device});
}
}

usage_streams_each_ptr.erase(it);
ptr_info.erase(it);
}

void free(void* ptr) {
std::lock_guard<std::mutex> lk(general_mutex);

auto it = ptr_info.find(ptr);
TORCH_INTERNAL_ASSERT(it != ptr_info.end(),
"ptr not found in ptr_info");
TORCH_INTERNAL_ASSERT(it->second.usage_streams.size() != 0,
"ptr's stream uses vector is empty");

if (C10_UNLIKELY(capture_underway)) {
if (it->second.captured) {
// See Note [Avoid freeing uncaptured ptrs during CUDA graph capture]
// Remembers the raw pointer, not the iterator.
// This forces notifyCaptureEnded to do another lookup,
// but avoids the risk the iterator might be invalidated
// between now and then.
ungraphed_ptrs_defer_free_until_no_capture.push_back(ptr);
return;
}
}

if (C10_UNLIKELY(it->second.captured)) {
TORCH_WARN("Attempting uncaptured free of a captured allocation. "
"This is technically allowed, but may indicate you are losing "
"the last user-visible tensor through which the allocation can "
"be accessed, so you'll have no way to view the data after "
"future replays of the owning graph.");
}

free_impl(it);
}

// Symmetric with THCCachingAllocator::malloc for now,
@@ -228,11 +268,12 @@ void malloc(void** devPtr, int device, size_t size, cudaStream_t stream) {

lazy_init_device(device);

auto inserted = usage_streams_each_ptr.insert(std::make_pair(devPtr, {});
auto inserted = ptr_info.emplace(*devPtr, PtrUsage(size, capture_underway));
TORCH_INTERNAL_ASSERT(inserted.second,
"address returned by cudaMallocAsync already exists "
"in usage_streams_each_ptr");
inserted.first->second.push_back();

inserted.first->second.usage_streams.emplace_back(stream, device);
}

} // anonymous namespace
@@ -251,7 +292,7 @@ struct CudaCachingAllocator : public Allocator {
C10_CUDA_CHECK(cudaGetDevice(&device));
void* r = nullptr;
if (size != 0) {
malloc(&r, device, size, getCurrentCUDAStream(device));
malloc(&r, device, size, cuda::getCurrentCUDAStream(device));
}
return {r, r, &raw_delete, Device(DeviceType::CUDA, device)};
}
@@ -319,13 +360,13 @@ void recordStream(const DataPtr& ptr, cuda::CUDAStream stream) {
std::lock_guard<std::mutex> lk(general_mutex);

// The pointer should exist in the map already.
auto it = usage_streams_each_ptr.find(ptr.get());
TORCH_INTERNAL_ASSERT(it != usage_streams_each_ptr.end(),
"ptr not represented in usage_streams_each_ptr");
TORCH_INTERNAL_ASSERT(it->second.size() != 0,
auto it = ptr_info.find(ptr.get());
TORCH_INTERNAL_ASSERT(it != ptr_info.end(),
"ptr not found in ptr_info");
TORCH_INTERNAL_ASSERT(it->second.usage_streams.size() != 0,
"ptr's stream uses vector is empty");

it->second.push_back(stream);
it->second.usage_streams.emplace_back(stream.stream(), stream.device_index());
}

std::mutex* getFreeMutex() {
@@ -404,22 +445,34 @@ std::vector<SegmentInfo> snapshot() {
}

// CUDAGraph interactions
void notifyCaptureBegin(CaptureId_t graph_id, MempoolId_t mempool_id) {} // no-op
void notifyCaptureBegin(CaptureId_t graph_id, MempoolId_t mempool_id) {
std::lock_guard<std::mutex> lk(general_mutex);

TORCH_CHECK(!capture_underway,
"Only one capture at a time is allowed in a process.");
capture_underway = true;
}

void notifyCaptureAboutToEnd(int device, CaptureId_t graph_id) {
assertValidDevice(device);

std::lock_guard<std::mutex> lk(general_mutex);

TORCH_CHECK(capture_underway,
"CudaMallocAsync::notifyCaptureAboutToEnd called, "
"but CudaMallocAsync::capture_underway is false");

auto capture_stream = cuda::getCurrentCUDAStream(device);

// See Note [Avoid dangling free streams during CUDA graph capture]
for (const auto& free_stream : capture_free_streams) {
// cudaEventRecord requires that the input event and stream are on the same device.
CUDAGuard g(free_stream.device_index());
CUDAGuard g(free_stream.device);

// CUDACachingAllocator.cpp uses raw cuda events, as do we.
cudaEvent_t event;
C10_CUDA_CHECK(cudaEventCreateWithFlags(&event, cudaEventDisableTiming));
C10_CUDA_CHECK(cudaEventRecord(event, free_stream.stream()));
C10_CUDA_CHECK(cudaEventRecord(event, free_stream.stream));
C10_CUDA_CHECK(cudaStreamWaitEvent(capture_stream.stream(), event));
C10_CUDA_CHECK(cudaEventDestroy(event));
Contributor:

You don't have to fix this in this PR (because I think this affects the original caching allocator as well), but now I'm kind of wondering if an error does happen in this cleanup. It seems to me like the capture data structures will be left in a wonky intermediate state. Maybe we shouldn't raise an exception if these calls fail :/

}
@@ -428,6 +481,21 @@ void notifyCaptureAboutToEnd(int device, CaptureId_t graph_id) {
void notifyCaptureEnded(int device, CaptureId_t graph_id) {
assertValidDevice(device);

std::lock_guard<std::mutex> lk(general_mutex);

TORCH_CHECK(capture_underway,
"CudaMallocAsync::notifyCaptureEnded called, "
"but CudaMallocAsync::capture_underway is false");
capture_underway = false;

for (const auto ptr : ungraphed_ptrs_defer_free_until_no_capture ) {
auto it = ptr_info.find(ptr);
TORCH_INTERNAL_ASSERT(it != ptr_info.end(),
"ptr not found in ptr_info");
TORCH_INTERNAL_ASSERT(it->second.usage_streams.size() != 0,
"ptr's stream uses vector is empty");
free_impl(it);
}
}

void notifyCaptureDestroy(int device, MempoolId_t mempool_id) {} // no-op
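To make the capture-time bookkeeping above concrete, here is a hedged, user-level sketch (not a test from this PR; the shapes and the use of torch.cuda.graph are illustrative assumptions) of the case Note [Avoid freeing uncaptured ptrs during CUDA graph capture] handles: an allocation made before capture loses its last reference inside the capturing region, so free() stashes it in ungraphed_ptrs_defer_free_until_no_capture and notifyCaptureEnded() performs the real free_impl() afterwards.

    import torch

    scratch = torch.empty(64, 1024, device="cuda")    # uncaptured allocation

    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):
        out = torch.ones(64, 1024, device="cuda")     # captured allocation
        # Dropping the last reference to an uncaptured tensor while capture is
        # underway: with backend:cudaMallocAsync, the cudaFreeAsync is deferred
        # to notifyCaptureEnded instead of being issued on a capturing stream.
        del scratch

    g.replay()    # replays the captured work; the deferred free already happened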
2 changes: 1 addition & 1 deletion docs/source/notes/cuda.rst
@@ -352,7 +352,7 @@ Use of a caching allocator can interfere with memory checking tools such as

The behavior of caching allocator can be controlled via environment variable
``PYTORCH_CUDA_ALLOC_CONF``.
The format is ``PYTORCH_CUDA_ALLOC_CONF=<option>:<value>,<option2><value2>...``
The format is ``PYTORCH_CUDA_ALLOC_CONF=<option>:<value>,<option2>:<value2>...``
Available options:

* ``max_split_size_mb`` prevents the allocator from splitting blocks larger
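A concrete instance of the corrected format (a sketch; the warning text is the one this commit adds to parseArgs()):

    import os

    # Multiple options are comma-separated, each written as <option>:<value>.
    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128,backend:cudaMallocAsync"

    # max_split_size_mb only applies to the native caching allocator, so this
    # combination is accepted but warns:
    #   "backend:cudaMallocAsync ignores max_split_size_mb"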