Generalize torch._C._set_allocator_settings to be generic (#156175) · pytorch/pytorch@d3ce450 · GitHub

Commit d3ce450

guangyey authored and pytorchmergebot committed
Generalize torch._C._set_allocator_settings to be generic (#156175)
# Motivation
This PR moves the implementation of `torch.cuda.memory._set_allocator_settings` to `torch._C._accelerator_setAllocatorSettings`. Since the original API was intended as a temporary/internal utility, I am not exposing the new function as a public API.

Pull Request resolved: #156175
Approved by: https://github.com/albanD
ghstack dependencies: #149601, #157908, #150312, #156165
1 parent 1fc010a commit d3ce450
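As a quick illustration of the rename, here is a minimal sketch of the new entry point next to the old one; the settings string is only an illustrative value, both functions are private APIs, and this assumes a build that includes this change:

```python
import torch

# New device-generic entry point introduced by this PR (still a private API).
# The settings string follows the caching-allocator config syntax; the value
# here is only an illustrative example.
torch._C._accelerator_setAllocatorSettings("max_split_size_mb:128")

# The old helper is kept as a backward-compatibility alias, so existing
# callers keep working unchanged.
torch.cuda.memory._set_allocator_settings("max_split_size_mb:128")
```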

File tree

9 files changed: +26 −34 lines changed

c10/core/AllocatorConfig.cpp

Lines changed: 7 additions & 7 deletions
```diff
@@ -45,7 +45,7 @@ size_t AcceleratorAllocatorConfig::roundup_power2_divisions(size_t size) {
       63 - llvm::countLeadingZeros(kRoundUpPowerOfTwoStart);
   const size_t interval_end =
       63 - llvm::countLeadingZeros(kRoundUpPowerOfTwoEnd);
-  TORCH_CHECK(
+  TORCH_CHECK_VALUE(
       interval_end - interval_start == kRoundUpPowerOfTwoIntervals,
       "kRoundUpPowerOfTwoIntervals mismatch");

@@ -64,7 +64,7 @@ size_t AcceleratorAllocatorConfig::parseMaxSplitSize(
       std::numeric_limits<size_t>::max() / kMB;

   size_t val_env = tokenizer.toSizeT(++i);
-  TORCH_CHECK(
+  TORCH_CHECK_VALUE(
       val_env >= min_allowed_split_size_mb,
       "CachingAllocator option max_split_size_mb too small, must be >= ",
       min_allowed_split_size_mb);
@@ -83,7 +83,7 @@ size_t AcceleratorAllocatorConfig::parseMaxNonSplitRoundingSize(
       std::numeric_limits<size_t>::max() / kMB;

   size_t val_env = tokenizer.toSizeT(++i);
-  TORCH_CHECK(
+  TORCH_CHECK_VALUE(
       val_env >= min_allowed_split_size_mb,
       "CachingAllocator option max_non_split_rounding_mb too small, must be >= ",
       min_allowed_split_size_mb);
@@ -98,7 +98,7 @@ size_t AcceleratorAllocatorConfig::parseGarbageCollectionThreshold(
     size_t i) {
   tokenizer.checkToken(++i, ":");
   double val_env = tokenizer.toDouble(++i);
-  TORCH_CHECK(
+  TORCH_CHECK_VALUE(
       val_env > 0 && val_env < 1.0,
       "garbage_collect_threshold is invalid, set it in (0.0, 1.0)");
   garbage_collection_threshold_ = val_env;
@@ -119,7 +119,7 @@ size_t AcceleratorAllocatorConfig::parseRoundUpPower2Divisions(
   size_t value_index = i;
   tokenizer.checkToken(++i, ":");
   size_t value = tokenizer.toSizeT(++i);
-  TORCH_CHECK(
+  TORCH_CHECK_VALUE(
       value == 0 || llvm::isPowerOf2_64(value),
       "For roundups, the divisions has to be power of 2 or 0 to disable roundup ");

@@ -133,7 +133,7 @@ size_t AcceleratorAllocatorConfig::parseRoundUpPower2Divisions(
       value);
   } else {
     size_t boundary = tokenizer.toSizeT(value_index);
-    TORCH_CHECK(
+    TORCH_CHECK_VALUE(
         llvm::isPowerOf2_64(boundary),
         "For roundups, the intervals have to be power of 2 ");

@@ -163,7 +163,7 @@ size_t AcceleratorAllocatorConfig::parseRoundUpPower2Divisions(
       "Expected closing bracket ']' in ConfigTokenizer but reached end of config");
   } else { // Keep this for backwards compatibility
     size_t value = tokenizer.toSizeT(i);
-    TORCH_CHECK(
+    TORCH_CHECK_VALUE(
         llvm::isPowerOf2_64(value),
         "For roundups, the divisions has to be power of 2 ");
     std::fill(
```
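The practical effect of switching `TORCH_CHECK` to `TORCH_CHECK_VALUE` is the Python-visible exception type: invalid option values now surface as `ValueError` instead of the generic `RuntimeError` (see the updated assertions in test/test_cuda.py below). A minimal sketch, assuming a build with this change:

```python
import torch

try:
    # garbage_collection_threshold must lie strictly in (0.0, 1.0), so 1.2
    # trips the TORCH_CHECK_VALUE added above and surfaces as ValueError.
    torch.cuda.memory._set_allocator_settings("garbage_collection_threshold:1.2")
except ValueError as e:
    print(f"rejected: {e}")
```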

c10/core/AllocatorConfig.h

Lines changed: 1 addition & 1 deletion
```diff
@@ -76,7 +76,7 @@ class ConfigTokenizer {
     } else if (token == "False") {
       return false;
     } else {
-      TORCH_CHECK(
+      TORCH_CHECK_VALUE(
           false,
           "Expected 'True' or 'False' at index ",
           i,
```
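As the boolean-parsing branch shown above makes clear, boolean options accept only the literal tokens `True` and `False`; any other token now fails with `ValueError`. A minimal sketch, using an option name taken from the tests below:

```python
import torch

# Boolean allocator options must be spelled exactly "True" or "False".
torch.cuda.memory._set_allocator_settings("pinned_use_cuda_host_register:False")

try:
    # "none" is not a valid boolean token, so parsing fails with ValueError.
    torch.cuda.memory._set_allocator_settings("pinned_use_cuda_host_register:none")
except ValueError as e:
    print(f"invalid boolean token: {e}")
```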

c10/cuda/CUDAAllocatorConfig.cpp

Lines changed: 3 additions & 3 deletions
```diff
@@ -22,7 +22,7 @@ size_t CUDAAllocatorConfig::parseAllocatorConfig(
 #define PYTORCH_TOKEN2 "hipMallocAsync"
   tokenizer.checkToken(++i, ":");
   i++; // Move to the value after the colon
-  TORCH_CHECK(
+  TORCH_CHECK_VALUE(
       ((tokenizer[i] == "native") || (tokenizer[i] == PYTORCH_TOKEN1) ||
        (tokenizer[i] == PYTORCH_TOKEN2)),
       "Unknown allocator backend, "
@@ -134,12 +134,12 @@ size_t CUDAAllocatorConfig::parsePinnedNumRegisterThreads(
     size_t i) {
   tokenizer.checkToken(++i, ":");
   size_t val2 = tokenizer.toSizeT(++i);
-  TORCH_CHECK(
+  TORCH_CHECK_VALUE(
       llvm::isPowerOf2_64(val2),
       "Number of register threads has to be power of 2 ",
       "");
   auto maxThreads = CUDAAllocatorConfig::pinned_max_register_threads();
-  TORCH_CHECK(
+  TORCH_CHECK_VALUE(
       val2 <= maxThreads,
       "Number of register threads should be less than or equal to " +
           std::to_string(maxThreads),
```
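Both pinned-memory checks are observable from Python: the register-thread count must be a power of two and must not exceed `pinned_max_register_threads()`. A minimal sketch, with values chosen to exercise each check (1024 exceeds the maximum, per the test update below):

```python
import torch

# Valid: 8 is a power of two and within the allowed maximum.
torch.cuda.memory._set_allocator_settings("pinned_num_register_threads:8")

try:
    # 1024 is a power of two but exceeds pinned_max_register_threads(),
    # so the second TORCH_CHECK_VALUE fires as ValueError.
    torch.cuda.memory._set_allocator_settings("pinned_num_register_threads:1024")
except ValueError as e:
    print(f"too many register threads: {e}")
```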

test/test_cuda.py

Lines changed: 6 additions & 6 deletions
```diff
@@ -4471,28 +4471,28 @@ def power2_div(size, div_factor):
         with self.assertRaises(RuntimeError):
             torch.cuda.memory._set_allocator_settings("foo:1,bar:2")

-        with self.assertRaises(RuntimeError):
+        with self.assertRaises(ValueError):
             torch.cuda.memory._set_allocator_settings(
                 "garbage_collection_threshold:1.2"
             )

-        with self.assertRaises(RuntimeError):
+        with self.assertRaises(ValueError):
             torch.cuda.memory._set_allocator_settings("max_split_size_mb:2")

-        with self.assertRaises(RuntimeError):
+        with self.assertRaises(ValueError):
             torch.cuda.memory._set_allocator_settings("release_lock_on_cudamalloc:none")

-        with self.assertRaises(RuntimeError):
+        with self.assertRaises(ValueError):
             torch.cuda.memory._set_allocator_settings(
                 "pinned_use_cuda_host_register:none"
             )

-        with self.assertRaises(RuntimeError):
+        with self.assertRaises(ValueError):
             torch.cuda.memory._set_allocator_settings(
                 "pinned_num_register_threads:none"
             )

-        with self.assertRaises(RuntimeError):
+        with self.assertRaises(ValueError):
             torch.cuda.memory._set_allocator_settings(
                 "pinned_num_register_threads:1024"
             )
```
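Note the distinction the updated test encodes: an unrecognized option name (the first, unchanged assertion) still raises `RuntimeError`, while malformed or out-of-range values for known options now raise `ValueError`. A minimal sketch of that split:

```python
import torch

# Unknown option names still surface as RuntimeError (unchanged behavior).
try:
    torch.cuda.memory._set_allocator_settings("foo:1,bar:2")
except RuntimeError:
    print("unknown option name -> RuntimeError")

# A known option with an invalid value now surfaces as ValueError.
try:
    torch.cuda.memory._set_allocator_settings("max_split_size_mb:2")
except ValueError:
    print("bad value for known option -> ValueError")
```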

torch/_C/__init__.pyi.in

Lines changed: 1 addition & 1 deletion
```diff
@@ -2017,7 +2017,6 @@ def _cuda_cudaHostAllocator() -> _int: ...
 def _cuda_cudaCachingAllocator_raw_alloc(size: _int, cuda_stream: _int) -> _int: ...
 def _cuda_cudaCachingAllocator_raw_delete(ptr: _int) -> None: ...
 def _cuda_cudaCachingAllocator_enable(val: _bool) -> None: ...
-def _cuda_cudaCachingAllocator_set_allocator_settings(env: str) -> None: ...
 def _cuda_beginAllocateToPool(device: _int, mempool_id: tuple[_int, _int]) -> None: ...
 def _cuda_beginAllocateCurrentThreadToPool(
     device: _int,
@@ -2435,6 +2434,7 @@ def _accelerator_getStream(device_index: _int) -> Stream: ...
 def _accelerator_synchronizeDevice(device_index: _int) -> None: ...
 def _accelerator_exchangeDevice(device_index: _int) -> _int: ...
 def _accelerator_maybeExchangeDevice(device_index: _int) -> _int: ...
+def _accelerator_setAllocatorSettings(env: str) -> None: ...

 # Defined in torch/csrc/jit/python/python_tracer.cpp
 class TracingState:
```

torch/_dynamo/trace_rules.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -447,6 +447,7 @@
     "torch._C._accelerator_getAccelerator",
     "torch._C._accelerator_getDeviceIndex",
     "torch._C._accelerator_getStream",
+    "torch._C._accelerator_setAllocatorSettings",
     "torch._C._accelerator_setStream",
     "torch._C._accelerator_synchronizeDevice",
     "torch._C._activate_gpu_trace",
@@ -503,7 +504,6 @@
     "torch._C._cuda_clearCublasWorkspaces",
     "torch._C._cuda_cudaCachingAllocator_raw_alloc",
     "torch._C._cuda_cudaCachingAllocator_raw_delete",
-    "torch._C._cuda_cudaCachingAllocator_set_allocator_settings",
     "torch._C._cuda_cudaHostAllocator",
     "torch._C._cuda_customAllocator",
     "torch._C._cuda_emptyCache",
```

torch/csrc/DeviceAccelerator.cpp

Lines changed: 5 additions & 0 deletions
```diff
@@ -1,3 +1,4 @@
+#include <c10/core/AllocatorConfig.h>
 #include <torch/csrc/DeviceAccelerator.h>
 #include <torch/csrc/utils/device_lazy_init.h>

@@ -72,6 +73,10 @@ void initModule(PyObject* module) {
     torch::utils::maybe_initialize_device(device_type);
     return at::accelerator::maybeExchangeDevice(device_index);
   });
+
+  m.def("_accelerator_setAllocatorSettings", [](std::string env) {
+    c10::CachingAllocator::setAllocatorSettings(env);
+  });
 }

 } // namespace torch::accelerator
```
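Because the binding now registers in the generic accelerator module and calls `c10::CachingAllocator::setAllocatorSettings`, the same entry point serves whichever accelerator backend is active rather than being CUDA-only. A minimal sketch, assuming a build with this change and an available accelerator:

```python
import torch

# Works for the active accelerator backend, since the underlying
# c10::CachingAllocator::setAllocatorSettings is device-agnostic.
if torch.accelerator.is_available():
    torch._C._accelerator_setAllocatorSettings("garbage_collection_threshold:0.8")
```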

torch/csrc/cuda/Module.cpp

Lines changed: 0 additions & 13 deletions
```diff
@@ -422,15 +422,6 @@ PyObject* THCPModule_cudaCachingAllocator_enable(
   END_HANDLE_TH_ERRORS
 }

-PyObject* THCPModule_cudaCachingAllocator_set_allocator_settings(
-    PyObject* _unused,
-    PyObject* env) {
-  HANDLE_TH_ERRORS
-  c10::CachingAllocator::setAllocatorSettings(THPUtils_unpackString(env));
-  Py_RETURN_NONE;
-  END_HANDLE_TH_ERRORS
-}
-
 PyObject* THCPModule_getAllocatorBackend(PyObject* _unused, PyObject* noargs) {
   HANDLE_TH_ERRORS
   return THPUtils_packString(c10::cuda::CUDACachingAllocator::name());
@@ -2052,10 +2043,6 @@ static struct PyMethodDef _THCPModule_methods[] = {
      THCPModule_cudaCachingAllocator_enable,
      METH_O,
      nullptr},
-    {"_cuda_cudaCachingAllocator_set_allocator_settings",
-     THCPModule_cudaCachingAllocator_set_allocator_settings,
-     METH_O,
-     nullptr},
     {"_cuda_getAllocatorBackend",
      THCPModule_getAllocatorBackend,
      METH_NOARGS,
```

torch/cuda/memory.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -1075,8 +1075,8 @@ def _save_memory_usage(filename="output.svg", snapshot=None):
         f.write(_memory(snapshot))


-def _set_allocator_settings(env: str):
-    return torch._C._cuda_cudaCachingAllocator_set_allocator_settings(env)
+# Keep for BC only
+_set_allocator_settings = torch._C._accelerator_setAllocatorSettings


 def get_allocator_backend() -> str:
```
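With this change, the backward-compatibility alias and the generic binding are literally the same function object; a quick sanity check (a minimal sketch):

```python
import torch

# The alias is a direct reference, not a wrapper, so identity holds.
assert (
    torch.cuda.memory._set_allocator_settings
    is torch._C._accelerator_setAllocatorSettings
)
```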

0 commit comments