8000 Fix static `py::object` leak with `py::gil_safe_call_once_and_store` · pytorch/pytorch@3dbb655 · GitHub
[go: up one dir, main page]

Skip to content

Commit 3dbb655

Browse files
committed
Fix static py::object leak with py::gil_safe_call_once_and_store
1 parent b1a00b7 commit 3dbb655

File tree

7 files changed

+88
-36
lines changed

7 files changed

+88
-36
lines changed

cmake/Dependencies.cmake

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -908,6 +908,12 @@ if(USE_SYSTEM_PYBIND11)
908908
if(NOT pybind11_FOUND)
909909
message(FATAL "Cannot find system pybind11")
910910
endif()
911+
if(${pybind11_VERSION} VERSION_LESS 2.12)
912+
message(FATAL_ERROR
913+
"Found pybind11 version ${pybind11_VERSION} which misses some features required by PyTorch. "
914+
"Please install pybind11 >= 2.12.0."
915+
)
916+
endif()
911917
else()
912918
message(STATUS "Using third_party/pybind11.")
913919
set(pybind11_INCLUDE_DIRS ${CMAKE_CURRENT_LIST_DIR}/../third_party/pybind11/include)

functorch/csrc/dim/dim.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -740,6 +740,8 @@ struct Tensor : public mpy::base<Tensor> {
740740
static mpy::obj<Tensor> create() {
741741
if (!TensorType) {
742742
TensorType = (PyTypeObject*) mpy::import("functorch.dim").attr("Tensor").ptr();
743+
// NB: leak
744+
Py_INCREF(TensorType);
743745
}
744746
return Tensor::alloc(TensorType);
745747
}

torch/csrc/dynamo/guards.cpp

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
#endif
2727

2828
#include <sstream>
29+
#include <tuple>
2930
#include <utility>
3031

3132
// For TupleIteratorGetItemAccessor, we need a fast way to retrieve the
@@ -2461,14 +2462,24 @@ std::unique_ptr<GuardManager> make_guard_manager(
24612462
std::string source,
24622463
py::handle example_value,
24632464
py::handle guard_manager_enum) {
2464-
static py::object guard_manager_enum_class =
2465-
py::module_::import("torch._dynamo.guards").attr("GuardManagerType");
2466-
static py::object base_guard_manager_enum =
2467-
guard_manager_enum_class.attr("GUARD_MANAGER");
2468-
static py::object dict_guard_manager_enum =
2469-
guard_manager_enum_class.attr("DICT_GUARD_MANAGER");
2470-
static py::object dict_subclass_guard_manager_enum =
2471-
guard_manager_enum_class.attr("DICT_SUBCLASS_GUARD_MANAGER");
2465+
using fourobjects =
2466 8000 +
std::tuple<py::object, py::object, py::object, py::object>;
2467+
PYBIND11_CONSTINIT static py::gil_safe_call_once_and_store<fourobjects>
2468+
storage;
2469+
2470+
auto& [guard_manager_enum_class, base_guard_manager_enum, dict_guard_manager_enum, dict_subclass_guard_manager_enum] =
2471+
storage
2472+
.call_once_and_store_result([]() -> fourobjects {
2473+
py::object guard_manager_enum_class =
2474+
py::module_::import("torch._dynamo.guards")
2475+
.attr("GuardManagerType");
2476+
return {
2477+
guard_manager_enum_class,
2478+
guard_manager_enum_class.attr("GUARD_MANAGER"),
2479+
guard_manager_enum_class.attr("DICT_GUARD_MANAGER"),
2480+
guard_manager_enum_class.attr("DICT_SUBCLASS_GUARD_MANAGER")};
2481+
})
2482+
.get_stored();
24722483
if (py::isinstance<py::dict>(example_value)) {
24732484
// The purpose of having both DictGuardManager and DictSubclassGuardManager
24742485
// is to handle the variability in how dictionaries and their subclasses

torch/csrc/jit/python/module_python.h

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,29 +3,44 @@
33
#include <pybind11/stl.h>
44
#include <torch/csrc/jit/api/module.h>
55
#include <torch/csrc/utils/pybind.h>
6+
#include <tuple>
67

78
namespace py = pybind11;
89

910
namespace torch::jit {
1011

1112
inline std::optional<Module> as_module(py::handle obj) {
12-
static py::handle ScriptModule =
13-
py::module::import("torch.jit").attr("ScriptModule");
13+
PYBIND11_CONSTINIT static py::gil_safe_call_once_and_store<py::object>
14+
storage;
15+
auto& ScriptModule =
16+
storage
17+
.call_once_and_store_result([]() -> py::object {
18+
return py::module_::import("torch.jit").attr("ScriptModule");
19+
})
20+
.get_stored();
1421
if (py::isinstance(obj, ScriptModule)) {
1522
return py::cast<Module>(obj.attr("_c"));
1623
}
1724
return std::nullopt;
1825
}
1926

2027
inline std::optional<Object> as_object(py::handle obj) {
21-
static py::handle ScriptObject =
22-
py::module::import("torch").attr("ScriptObject");
28+
PYBIND11_CONSTINIT static py::gil_safe_call_once_and_store<
29+
std::tuple<py::object, py::object>>
30+
storage;
31+
auto& [ScriptObject, RecursiveScriptClass] =
32+
storage
33+
.call_once_and_store_result(
34+
[]() -> std::tuple<py::object, py::object> {
35+
return {
36+
py::module_::import("torch").attr("ScriptObject"),
37+
py::module_::import("torch.jit")
38+
.attr("RecursiveScriptClass")};
39+
})
40+
.get_stored();
2341
if (py::isinstance(obj, ScriptObject)) {
2442
return py::cast<Object>(obj);
2543
}
26-
27-
static py::handle RecursiveScriptClass =
28-
py::module::import("torch.jit").attr("RecursiveScriptClass");
2944
if (py::isinstance(obj, RecursiveScriptClass)) {
3045
return py::cast<Object>(obj.attr("_c"));
3146
}

torch/csrc/jit/python/python_ivalue.h

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,15 @@ struct C10_EXPORT ConcretePyObjectHolder final : PyObjectHolder {
4949
// when using C++. The reason is unclear.
5050
try {
5151
pybind11::gil_scoped_acquire ag;
52-
static py::object& extractorFn = *new py::object(
53-
py::module::import("torch._jit_internal").attr("_extract_tensors"));
52+
PYBIND11_CONSTINIT static py::gil_safe_call_once_and_store<py::object>
53+
storage;
54+
auto& extractorFn =
55+
storage
56+
.call_once_and_store_result([]() -> py::object {
57+
return py::module_::import("torch._jit_internal")
58+
.attr("_extract_tensors");
59+
})
60+
.get_stored();
5461
return extractorFn(py_obj_).cast<std::vector<at::Tensor>>();
5562
} catch (py::error_already_set& e) {
5663
auto err = std::runtime_error(

torch/csrc/utils/python_arg_parser.cpp

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -263,13 +263,15 @@ static PyObject* get_type_of_overloaded_arg(PyObject* obj_or_type) {
263263
static py::object maybe_get_registered_torch_dispatch_rule(
264264
PyObject* torch_api_function,
265265
const py::object& torch_dispatch_object) {
266-
// This is a static object, so we must leak the Python object
267-
// "release()" is used here to preserve 1 refcount on the
268-
// object, preventing it from ever being de-allocated by CPython.
269-
static const py::handle find_torch_dispatch_rule =
270-
py::object(py::module_::import("torch._library.simple_registry")
271-
.attr("find_torch_dispatch_rule"))
272-
.release();
266+
PYBIND11_CONSTINIT static py::gil_safe_call_once_and_store<py::object>
267+
storage;
268+
py::object find_torch_dispatch_rule =
269+
storage
270+
.call_once_and_store_result([]() -> py::object {
271+
return py::module_::import("torch._library.simple_registry")
272+
.attr("find_torch_dispatch_rule");
273+
})
274+
.get_stored();
273275
auto result = find_torch_dispatch_rule(
274276
py::reinterpret_borrow<py::object>(torch_api_function),
275277
torch_dispatch_object.get_type());

torch/csrc/utils/python_symnode.cpp

Lines changed: 21 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,24 +3,33 @@
33
namespace torch {
44

55
py::handle get_symint_class() {
6-
// NB: leak
7-
static py::handle symint_class =
8-
py::object(py::module::import("torch").attr("SymInt")).release();
9-
return symint_class;
6+
PYBIND11_CONSTINIT static py::gil_safe_call_once_and_store<py::object>
7+
storage;
8+
return storage
9+
.call_once_and_store_result([]() -> py::object {
10+
return py::module::import("torch").attr("SymInt");
11+
})
12+
.get_stored();
1013
}
1114

1215
py::handle get_symfloat_class() {
13-
// NB: leak
14-
static py::handle symfloat_class =
15-
py::object(py::module::import("torch").attr("SymFloat")).release();
16-
return symfloat_class;
16+
PYBIND11_CONSTINIT static py::gil_safe_call_once_and_store<py::object>
17+
storage;
18+
return storage
19+
.call_once_and_store_result([]() -> py::object {
20+
return py::module::import("torch").attr("SymFloat");
21+
})
22+
.get_stored();
1723
}
1824

1925
py::handle get_symbool_class() {
20-
// NB: leak
21-
static py::handle symbool_class =
22-
py::object(py::module::import("torch").attr("SymBool")).release();
23-
return symbool_class;
26+
PYBIND11_CONSTINIT static py::gil_safe_call_once_and_store<py::object>
27+
storage;
28+
return storage
29+
.call_once_and_store_result([]() -> py::object {
30+
return py::module::import("torch").attr("SymBool");
31+
})
32+
.get_stored();
2433
}
2534

2635
} // namespace torch

0 commit comments

Comments
 (0)
0