8000 Fix static `py::object` leak with `py::gil_safe_call_once_and_store` · pytorch/pytorch@43c7a4f · GitHub
[go: up one dir, main page]

Skip to content

Commit 43c7a4f

Browse files
committed
Fix static py::object leak with py::gil_safe_call_once_and_store
ghstack-source-id: dfe5990 Pull Request resolved: #130721
1 parent 6c1f156 commit 43c7a4f

File tree

6 files changed

+82
-36
lines changed

6 files changed

+82
-36
lines changed

functorch/csrc/dim/dim.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -740,6 +740,8 @@ struct Tensor : public mpy::base<Tensor> {
740740
static mpy::obj<Tensor> create() {
741741
if (!TensorType) {
742742
TensorType = (PyTypeObject*) mpy::import("functorch.dim").attr("Tensor").ptr();
743+
// NB: leak
744+
Py_INCREF(TensorType);
743745
}
744746
return Tensor::alloc(TensorType);
745747
}

torch/csrc/dynamo/guards.cpp

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#endif
1919

2020
#include <sstream>
21+
#include <tuple>
2122
#include <utility>
2223

2324
// For TupleIteratorGetItemAccessor, we need a fast way to retrieve the
@@ -2427,14 +2428,24 @@ std::unique_ptr<GuardManager> make_guard_manager(
24272428
std::string source,
24282429
py::handle example_value,
24292430
py::handle guard_manager_enum) {
2430-
static py::object guard_manager_enum_class =
2431-
py::module_::import("torch._dynamo.guards").attr("GuardManagerType");
2432-
static py::object base_guard_manager_enum =
2433-
guard_manager_enum_class.attr("GUARD_MANAGER");
2434-
static py::object dict_guard_manager_enum =
2435-
guard_manager_enum_class.attr("DICT_GUARD_MANAGER");
2436-
static py::object dict_subclass_guard_manager_enum =
2437-
guard_manager_enum_class.attr("DICT_SUBCLASS_GUARD_MANAGER");
2431+
using fourobjects =
2432+
std::tuple<py::object, py::object, py::object, py::object>;
2433+
PYBIND11_CONSTINIT static py::gil_safe_call_once_and_store<fourobjects>
2434+
storage;
2435+
2436+
auto& [guard_manager_enum_class, base_guard_manager_enum, dict_guard_manager_enum, dict_subclass_guard_manager_enum] =
2437+
storage
2438+
.call_once_and_store_result([]() -> fourobjects {
2439+
py::object guard_manager_enum_class =
2440+
py::module_::import("torch._dynamo.guards")
2441+
.attr("GuardManagerType");
2442+
return {
2443+
guard_manager_enum_class,
2444+
guard_manager_enum_class.attr("GUARD_MANAGER"),
2445+
guard_manager_enum_class.attr("DICT_GUARD_MANAGER"),
2446+
guard_manager_enum_class.attr("DICT_SUBCLASS_GUARD_MANAGER")};
2447+
})
2448+
.get_stored();
24382449
if (py::isinstance<py::dict>(example_value)) {
24392450
// The purpose of having both DictGuardManager and DictSubclassGuardManager
24402451
// is to handle the variability in how dictionaries and their subclasses

torch/csrc/jit/python/module_python.h

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,29 +3,44 @@
33
#include <pybind11/stl.h>
44
#include <torch/csrc/jit/api/module.h>
55
#include <torch/csrc/utils/pybind.h>
6+
#include <tuple>
67

78
namespace py = pybind11;
89

910
namespace torch::jit {
1011

1112
inline std::optional<Module> as_module(py::handle obj) {
12-
static py::handle ScriptModule =
13-
py::module::import("torch.jit").attr("ScriptModule");
13+
PYBIND11_CONSTINIT static py::gil_safe_call_once_and_store<py::object>
14+
storage;
15+
auto& ScriptModule =
16+
storage
17+
.call_once_and_store_result([]() -> py::object {
18+
return py::module_::import("torch.jit").attr("ScriptModule");
19+
})
20+
.get_stored();
1421
if (py::isinstance(obj, ScriptModule)) {
1522
return py::cast<Module>(obj.attr("_c"));
1623
}
1724
return std::nullopt;
1825
}
1926

2027
inline std::optional<Object> as_object(py::handle obj) {
21-
static py::handle ScriptObject =
22-
py::module::import("torch").attr("ScriptObject");
28+
PYBIND11_CONSTINIT static py::gil_safe_call_once_and_store<
29+
std::tuple<py::object, py::object>>
30+
storage;
31+
auto& [ScriptObject, RecursiveScriptClass] =
32+
storage
33+
.call_once_and_store_result(
34+
[]() -> std::tuple<py::object, py::object> {
35+
return {
36+
py::module_::import("torch").attr("ScriptObject"),
37+
py::module_::import("torch.jit")
38+
.attr("RecursiveScriptClass")};
39+
})
40+
.get_stored();
2341
if (py::isinstance(obj, ScriptObject)) {
2442
return py::cast<Object>(obj);
2543
}
26-
27-
static py::handle RecursiveScriptClass =
28-
py::module::import("torch.jit").attr("RecursiveScriptClass");
2944
if (py::isinstance(obj, RecursiveScriptClass)) {
3045
return py::cast<Object>(obj.attr("_c"));
3146
}

torch/csrc/jit/python/python_ivalue.h

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,15 @@ struct C10_EXPORT ConcretePyObjectHolder final : PyObjectHolder {
4949
// when using C++. The reason is unclear.
5050
try {
5151
pybind11::gil_scoped_acquire ag;
52-
static py::object& extractorFn = *new py::object(
53-
py::module::import("torch._jit_internal").attr("_extract_tensors"));
52+
PYBIND11_CONSTINIT static py::gil_safe_call_once_and_store<py::object>
53+
storage;
54+
auto& extractorFn =
55+
storage
56+
.call_once_and_store_result([]() -> py::object {
57+
return py::module_::import("torch._jit_internal")
58+
.attr("_extract_tensors");
59+
})
60+
.get_stored();
5461
return extractorFn(py_obj_).cast<std::vector<at::Tensor>>();
5562
} catch (py::error_already_set& e) {
5663
auto err = std::runtime_error(

torch/csrc/utils/python_arg_parser.cpp

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -263,13 +263,15 @@ static PyObject* get_type_of_overloaded_arg(PyObject* obj_or_type) {
263263
static py::object maybe_get_registered_torch_dispatch_rule(
264264
PyObject* torch_api_function,
265265
const py::object& torch_dispatch_object) {
266-
// This is a static object, so we must leak the Python object
267-
// "release()" is used here to preserve 1 refcount on the
268-
// object, preventing it from ever being de-allocated by CPython.
269-
static const py::handle find_torch_dispatch_rule =
270-
py::object(py::module_::import("torch._library.simple_registry")
271-
.attr("find_torch_dispatch_rule"))
272-
.release();
266+
PYBIND11_CONSTINIT static py::gil_safe_call_once_and_store<py::object>
267+
storage;
268+
py::object find_torch_dispatch_rule =
269+
storage
270+
.call_once_and_store_result([]() -> py::object {
271+
return py::module_::import("torch._library.simple_registry")
272+
.attr("find_torch_dispatch_rule");
273+
})
274+
.get_stored();
273275
auto result = find_torch_dispatch_rule(
274276
py::reinterpret_borrow<py::object>(torch_api_function),
275277
torch_dispatch_object.get_type());

torch/csrc/utils/python_symnode.cpp

Lines changed: 21 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,24 +3,33 @@
33
namespace torch {
44

55
py::handle get_symint_class() {
6-
// NB: leak
7-
static py::handle symint_class =
8-
py::object(py::module::import("torch").attr("SymInt")).release();
9-
return symint_class;
6+
PYBIND11_CONSTINIT static py::gil_safe_call_once_and_store<py::object>
7+
storage;
8+
return storage
9+
.call_once_and_store_result([]() -> py::object {
10+
return py::module::import("torch").attr("SymInt");
11+
})
12+
.get_stored();
1013
}
1114

1215
py::handle get_symfloat_class() {
13-
// NB: leak
14-
static py::handle symfloat_class =
15-
py::object(py::module::import("torch").attr("SymFloat")).release();
16-
return symfloat_class;
16+
PYBIND11_CONSTINIT static py::gil_safe_call_once_and_store<py::object>
17+
storage;
18+
return storage
19+
.call_once_and_store_result([]() -> py::object {
20+
return py::module::import("torch").attr("SymFloat");
21+
})
22+
.get_stored();
1723
}
1824

1925
py::handle get_symbool_class() {
20-
// NB: leak
21-
static py::handle symbool_class =
22-
py::object(py::module::import("torch").attr("SymBool")).release();
23-
return symbool_class;
26+
PYBIND11_CONSTINIT static py::gil_safe_call_once_and_store<py::object>
27+
storage;
28+
return storage
29+
.call_once_and_store_result([]() -> py::object {
30+
return py::module::import("torch").attr("SymBool");
31+
})
32+
.get_stored();
2433
}
2534

2635
} // namespace torch

0 commit comments

Comments
 (0)
0