8000 Fix static `py::object` dangling pointer with `py::gil_safe_call_once_and_store` by XuehaiPan · Pull Request #130341 · pytorch/pytorch · GitHub
[go: up one dir, main page]

Skip to content

Fix static py::object dangling pointer with py::gil_safe_call_once_and_store #130341

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions cmake/Dependencies.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -908,6 +908,12 @@ if(USE_SYSTEM_PYBIND11)
if(NOT pybind11_FOUND)
message(FATAL "Cannot find system pybind11")
endif()
if(${pybind11_VERSION} VERSION_LESS 2.12) # for pybind11::gil_safe_call_once_and_store
message(FATAL_ERROR
"Found pybind11 version ${pybind11_VERSION} which misses some features required by PyTorch. "
"Please install pybind11 >= 2.12.0."
)
endif()
else()
message(STATUS "Using third_party/pybind11.")
set(pybind11_INCLUDE_DIRS ${CMAKE_CURRENT_LIST_DIR}/../third_party/pybind11/include)
Expand Down
2 changes: 2 additions & 0 deletions functorch/csrc/dim/dim.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -740,6 +740,8 @@ struct Tensor : public mpy::base<Tensor> {
static mpy::obj<Tensor> create() {
if (!TensorType) {
TensorType = (PyTypeObject*) mpy::import("functorch.dim").attr("Tensor").ptr();
// NB: leak
Py_INCREF(TensorType);
}
return Tensor::alloc(TensorType);
}
Expand Down
27 changes: 19 additions & 8 deletions torch/csrc/dynamo/guards.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#endif

#include <sstream>
#include <tuple>
#include <utility>

// For TupleIteratorGetItemAccessor, we need a fast way to retrieve the
Expand Down Expand Up @@ -2461,14 +2462,24 @@ std::unique_ptr<GuardManager> make_guard_manager(
std::string source,
py::handle example_value,
py::handle guard_manager_enum) {
static py::object guard_manager_enum_class =
py::module_::import("torch._dynamo.guards").attr("GuardManagerType");
static py::object base_guard_manager_enum =
guard_manager_enum_class.attr("GUARD_MANAGER");
static py::object dict_guard_manager_enum =
guard_manager_enum_class.attr("DICT_GUARD_MANAGER");
static py::object dict_subclass_guard_manager_enum =
guard_manager_enum_class.attr("DICT_SUBCLASS_GUARD_MANAGER");
using fourobjects =
std::tuple<py::object, py::object, py::object, py::object>;
PYBIND11_CONSTINIT static py::gil_safe_call_once_and_store<fourobjects>
storage;

auto& [guard_manager_enum_class, base_guard_manager_enum, dict_guard_manager_enum, dict_subclass_guard_manager_enum] =
storage
.call_once_and_store_result([]() -> fourobjects {
py::object guard_manager_enum_class =
py::module_::import("torch._dynamo.guards")
.attr("GuardManagerType");
return {
guard_manager_enum_class,
guard_manager_enum_class.attr("GUARD_MANAGER"),
guard_manager_enum_class.attr("DICT_GUARD_MANAGER"),
guard_manager_enum_class.attr("DICT_SUBCLASS_GUARD_MANAGER")};
})
.get_stored();
if (py::isinstance<py::dict>(example_value)) {
// The purpose of having both DictGuardManager and DictSubclassGuardManager
// is to handle the variability in how dictionaries and their subclasses
Expand Down
29 changes: 22 additions & 7 deletions torch/csrc/jit/python/module_python.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,29 +3,44 @@
#include <pybind11/stl.h>
#include <torch/csrc/jit/api/module.h>
#include <torch/csrc/utils/pybind.h>
#include <tuple>

namespace py = pybind11;

namespace torch::jit {

< 10000 /span>
inline std::optional<Module> as_module(py::handle obj) {
static py::handle ScriptModule =
py::module::import("torch.jit").attr("ScriptModule");
PYBIND11_CONSTINIT static py::gil_safe_call_once_and_store<py::object>
storage;
auto& ScriptModule =
storage
.call_once_and_store_result([]() -> py::object {
return py::module_::import("torch.jit").attr("ScriptModule");
})
.get_stored();
if (py::isinstance(obj, ScriptModule)) {
return py::cast<Module>(obj.attr("_c"));
}
return std::nullopt;
}

inline std::optional<Object> as_object(py::handle obj) {
static py::handle ScriptObject =
py::module::import("torch").attr("ScriptObject");
PYBIND11_CONSTINIT static py::gil_safe_call_once_and_store<
std::tuple<py::object, py::object>>
storage;
auto& [ScriptObject, RecursiveScriptClass] =
storage
.call_once_and_store_result(
[]() -> std::tuple<py::object, py::object> {
return {
py::module_::import("torch").attr("ScriptObject"),
py::module_::import("torch.jit")
.attr("RecursiveScriptClass")};
})
.get_stored();
if (py::isinstance(obj, ScriptObject)) {
return py::cast<Object>(obj);
}

static py::handle RecursiveScriptClass =
py::module::import("torch.jit").attr("RecursiveScriptClass");
if (py::isinstance(obj, RecursiveScriptClass)) {
return py::cast<Object>(obj.attr("_c"));
}
Expand Down
11 changes: 9 additions & 2 deletions torch/csrc/jit/python/python_ivalue.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,15 @@ struct C10_EXPORT ConcretePyObjectHolder final : PyObjectHolder {
// when using C++. The reason is unclear.
try {
pybind11::gil_scoped_acquire ag;
static py::object& extractorFn = *new py::object(
py::module::import("torch._jit_internal").attr("_extract_tensors"));
PYBIND11_CONSTINIT static py::gil_safe_call_once_and_store<py::object>
storage;
auto& extractorFn =
storage
.call_once_and_store_result([]() -> py::object {
return py::module_::import("torch._jit_internal")
.attr("_extract_tensors");
})
.get_stored();
return extractorFn(py_obj_).cast<std::vector<at::Tensor>>();
} catch (py::error_already_set& e) {
auto err = std::runtime_error(
Expand Down
16 changes: 9 additions & 7 deletions torch/csrc/utils/python_arg_parser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -263,13 +263,15 @@ static PyObject* get_type_of_overloaded_arg(PyObject* obj_or_type) {
static py::object maybe_get_registered_torch_dispatch_rule(
PyObject* torch_api_function,
const py::object& torch_dispatch_object) {
// This is a static object, so we must leak the Python object
// "release()" is used here to preserve 1 refcount on the
// object, preventing it from ever being de-allocated by CPython.
static const py::handle find_torch_dispatch_rule =
py::object(py::module_::import("torch._library.simple_registry")
.attr("find_torch_dispatch_rule"))
.release();
PYBIND11_CONSTINIT static py::gil_safe_call_once_and_store<py::object>
storage;
py::object find_torch_dispatch_rule =
storage
.call_once_and_store_result([]() -> py::object {
return py::module_::import("torch._library.simple_registry")
.attr("find_torch_dispatch_rule");
})
.get_stored();
auto result = find_torch_dispatch_rule(
py::reinterpret_borrow<py::object>(torch_api_function),
torch_dispatch_object.get_type());
Expand Down
33 changes: 21 additions & 12 deletions torch/csrc/utils/python_symnode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,24 +3,33 @@
namespace torch {

py::handle get_symint_class() {
// NB: leak
static py::handle symint_class =
py::object(py::module::import("torch").attr("SymInt")).release();
return symint_class;
PYBIND11_CONSTINIT static py::gil_safe_call_once_and_store<py::object>
storage;
return storage
.call_once_and_store_result([]() -> py::object {
return py::module::import("torch").attr("SymInt");
})
.get_stored();
}

py::handle get_symfloat_class() {
// NB: leak
static py::handle symfloat_class =
py::object(py::module::import("torch").attr("SymFloat")).release();
return symfloat_class;
PYBIND11_CONSTINIT static py::gil_safe_call_once_and_store<py::object>
storage;
return storage
.call_once_and_store_result([]() -> py::object {
return py::module::import("torch").attr("SymFloat");
})
.get_stored();
}

py::handle get_symbool_class() {
// NB: leak
static py::handle symbool_class =
py::object(py::module::import("torch").attr("SymBool")).release();
return symbool_class;
PYBIND11_CONSTINIT static py::gil_safe_call_once_and_store<py::object>
storage;
return storage
.call_once_and_store_result([]() -> py::object {
return py::module::import("torch").attr("SymBool");
})
.get_stored();
}

} // namespace torch
Loading
0