diff --git a/cmake/Dependencies.cmake b/cmake/Dependencies.cmake index 1dc33efec7b878..7a878b9e1f27f4 100644 --- a/cmake/Dependencies.cmake +++ b/cmake/Dependencies.cmake @@ -908,6 +908,12 @@ if(USE_SYSTEM_PYBIND11) if(NOT pybind11_FOUND) message(FATAL "Cannot find system pybind11") endif() + if(${pybind11_VERSION} VERSION_LESS 2.12) # for pybind11::gil_safe_call_once_and_store + message(FATAL_ERROR + "Found pybind11 version ${pybind11_VERSION} which misses some features required by PyTorch. " + "Please install pybind11 >= 2.12.0." + ) + endif() else() message(STATUS "Using third_party/pybind11.") set(pybind11_INCLUDE_DIRS ${CMAKE_CURRENT_LIST_DIR}/../third_party/pybind11/include) diff --git a/functorch/csrc/dim/dim.cpp b/functorch/csrc/dim/dim.cpp index 722618efbb090b..64e39e09d84d38 100644 --- a/functorch/csrc/dim/dim.cpp +++ b/functorch/csrc/dim/dim.cpp @@ -740,6 +740,8 @@ struct Tensor : public mpy::base { static mpy::obj create() { if (!TensorType) { TensorType = (PyTypeObject*) mpy::import("functorch.dim").attr("Tensor").ptr(); + // NB: leak + Py_INCREF(TensorType); } return Tensor::alloc(TensorType); } diff --git a/torch/csrc/dynamo/guards.cpp b/torch/csrc/dynamo/guards.cpp index 0c2cf51a2cbbbb..097f1f03d3a7e5 100644 --- a/torch/csrc/dynamo/guards.cpp +++ b/torch/csrc/dynamo/guards.cpp @@ -26,6 +26,7 @@ #endif #include +#include #include // For TupleIteratorGetItemAccessor, we need a fast way to retrieve the @@ -2461,14 +2462,24 @@ std::unique_ptr make_guard_manager( std::string source, py::handle example_value, py::handle guard_manager_enum) { - static py::object guard_manager_enum_class = - py::module_::import("torch._dynamo.guards").attr("GuardManagerType"); - static py::object base_guard_manager_enum = - guard_manager_enum_class.attr("GUARD_MANAGER"); - static py::object dict_guard_manager_enum = - guard_manager_enum_class.attr("DICT_GUARD_MANAGER"); - static py::object dict_subclass_guard_manager_enum = - guard_manager_enum_class.attr("DICT_SUBCLASS_GUARD_MANAGER"); + using fourobjects = + std::tuple; + PYBIND11_CONSTINIT static py::gil_safe_call_once_and_store + storage; + + auto& [guard_manager_enum_class, base_guard_manager_enum, dict_guard_manager_enum, dict_subclass_guard_manager_enum] = + storage + .call_once_and_store_result([]() -> fourobjects { + py::object guard_manager_enum_class = + py::module_::import("torch._dynamo.guards") + .attr("GuardManagerType"); + return { + guard_manager_enum_class, + guard_manager_enum_class.attr("GUARD_MANAGER"), + guard_manager_enum_class.attr("DICT_GUARD_MANAGER"), + guard_manager_enum_class.attr("DICT_SUBCLASS_GUARD_MANAGER")}; + }) + .get_stored(); if (py::isinstance(example_value)) { // The purpose of having both DictGuardManager and DictSubclassGuardManager // is to handle the variability in how dictionaries and their subclasses diff --git a/torch/csrc/jit/python/module_python.h b/torch/csrc/jit/python/module_python.h index b1ddf6f37c6786..2be4a27293d07a 100644 --- a/torch/csrc/jit/python/module_python.h +++ b/torch/csrc/jit/python/module_python.h @@ -3,14 +3,21 @@ #include #include #include +#include namespace py = pybind11; namespace torch::jit { inline std::optional as_module(py::handle obj) { - static py::handle ScriptModule = - py::module::import("torch.jit").attr("ScriptModule"); + PYBIND11_CONSTINIT static py::gil_safe_call_once_and_store + storage; + auto& ScriptModule = + storage + .call_once_and_store_result([]() -> py::object { + return py::module_::import("torch.jit").attr("ScriptModule"); + }) + .get_stored(); if (py::isinstance(obj, ScriptModule)) { return py::cast(obj.attr("_c")); } @@ -18,14 +25,22 @@ inline std::optional as_module(py::handle obj) { } inline std::optional as_object(py::handle obj) { - static py::handle ScriptObject = - py::module::import("torch").attr("ScriptObject"); + PYBIND11_CONSTINIT static py::gil_safe_call_once_and_store< + std::tuple> + storage; + auto& [ScriptObject, RecursiveScriptClass] = + storage + .call_once_and_store_result( + []() -> std::tuple { + return { + py::module_::import("torch").attr("ScriptObject"), + py::module_::import("torch.jit") + .attr("RecursiveScriptClass")}; + }) + .get_stored(); if (py::isinstance(obj, ScriptObject)) { return py::cast(obj); } - - static py::handle RecursiveScriptClass = - py::module::import("torch.jit").attr("RecursiveScriptClass"); if (py::isinstance(obj, RecursiveScriptClass)) { return py::cast(obj.attr("_c")); } diff --git a/torch/csrc/jit/python/python_ivalue.h b/torch/csrc/jit/python/python_ivalue.h index a5475bfb849960..b8ed28dba7aca7 100644 --- a/torch/csrc/jit/python/python_ivalue.h +++ b/torch/csrc/jit/python/python_ivalue.h @@ -49,8 +49,15 @@ struct C10_EXPORT ConcretePyObjectHolder final : PyObjectHolder { // when using C++. The reason is unclear. try { pybind11::gil_scoped_acquire ag; - static py::object& extractorFn = *new py::object( - py::module::import("torch._jit_internal").attr("_extract_tensors")); + PYBIND11_CONSTINIT static py::gil_safe_call_once_and_store + storage; + auto& extractorFn = + storage + .call_once_and_store_result([]() -> py::object { + return py::module_::import("torch._jit_internal") + .attr("_extract_tensors"); + }) + .get_stored(); return extractorFn(py_obj_).cast>(); } catch (py::error_already_set& e) { auto err = std::runtime_error( diff --git a/torch/csrc/utils/python_arg_parser.cpp b/torch/csrc/utils/python_arg_parser.cpp index 181a66d2a13823..b52ca76830bffa 100644 --- a/torch/csrc/utils/python_arg_parser.cpp +++ b/torch/csrc/utils/python_arg_parser.cpp @@ -263,13 +263,15 @@ static PyObject* get_type_of_overloaded_arg(PyObject* obj_or_type) { static py::object maybe_get_registered_torch_dispatch_rule( PyObject* torch_api_function, const py::object& torch_dispatch_object) { - // This is a static object, so we must leak the Python object - // "release()" is used here to preserve 1 refcount on the - // object, preventing it from ever being de-allocated by CPython. - static const py::handle find_torch_dispatch_rule = - py::object(py::module_::import("torch._library.simple_registry") - .attr("find_torch_dispatch_rule")) - .release(); + PYBIND11_CONSTINIT static py::gil_safe_call_once_and_store + storage; + py::object find_torch_dispatch_rule = + storage + .call_once_and_store_result([]() -> py::object { + return py::module_::import("torch._library.simple_registry") + .attr("find_torch_dispatch_rule"); + }) + .get_stored(); auto result = find_torch_dispatch_rule( py::reinterpret_borrow(torch_api_function), torch_dispatch_object.get_type()); diff --git a/torch/csrc/utils/python_symnode.cpp b/torch/csrc/utils/python_symnode.cpp index f8f3d79cf3494d..d853e361139c55 100644 --- a/torch/csrc/utils/python_symnode.cpp +++ b/torch/csrc/utils/python_symnode.cpp @@ -3,24 +3,33 @@ namespace torch { py::handle get_symint_class() { - // NB: leak - static py::handle symint_class = - py::object(py::module::import("torch").attr("SymInt")).release(); - return symint_class; + PYBIND11_CONSTINIT static py::gil_safe_call_once_and_store + storage; + return storage + .call_once_and_store_result([]() -> py::object { + return py::module::import("torch").attr("SymInt"); + }) + .get_stored(); } py::handle get_symfloat_class() { - // NB: leak - static py::handle symfloat_class = - py::object(py::module::import("torch").attr("SymFloat")).release(); - return symfloat_class; + PYBIND11_CONSTINIT static py::gil_safe_call_once_and_store + storage; + return storage + .call_once_and_store_result([]() -> py::object { + return py::module::import("torch").attr("SymFloat"); + }) + .get_stored(); } py::handle get_symbool_class() { - // NB: leak - static py::handle symbool_class = - py::object(py::module::import("torch").attr("SymBool")).release(); - return symbool_class; + PYBIND11_CONSTINIT static py::gil_safe_call_once_and_store + storage; + return storage + .call_once_and_store_result([]() -> py::object { + return py::module::import("torch").attr("SymBool"); + }) + .get_stored(); } } // namespace torch