8000 gh-115490: Make the interpreter.channels and interpreter.queues Modules Handle Reloading Properly by ericsnowcurrently · Pull Request #115493 · python/cpython · GitHub
[go: up one dir, main page]

Skip to content

gh-115490: Make the interpreter.channels and interpreter.queues Modules Handle Reloading Properly #115493

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
1 change: 0 additions & 1 deletion Lib/test/support/interpreters/channels.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,5 +167,4 @@ def close(self):
_channels.close(self._id, send=True)


# XXX This is causing leaks (gh-110318):
_channels._register_end_types(SendChannel, RecvChannel)
179 changes: 83 additions & 96 deletions Modules/_xxinterpchannelsmodule.c
< 10000 td id="diff-78fc17a03e8a02168b63278bea03622a9ac267330c5d73226ec8fc36f5550906L119" data-line-number="119" class="blob-num blob-num-deletion js-linkable-line-number">
Original file line number Diff line number Diff line change
Expand Up @@ -93,35 +93,35 @@ API.. The module does not create any objects that are shared globally.
PyMem_RawFree(VAR)


struct xid_class_registry {
size_t count;
#define MAX_XID_CLASSES 5
struct {
PyTypeObject *cls;
} added[MAX_XID_CLASSES];
struct xid_types {
/* Added at runtime by interpreters module. */
PyTypeObject *SendChannel;
PyTypeObject *RecvChannel;

/* heap types */
PyTypeObject *ChannelID;
};

static int
register_xid_class(PyTypeObject *cls, crossinterpdatafunc shared,
struct xid_class_registry *classes)
static void
clear_xid_types(struct xid_types *types)
{
int res = ensure_xid_class(cls, shared);
if (res == 0) {
assert(classes->count < MAX_XID_CLASSES);
// The class has refs elsewhere, so we need to incref here.
classes->added[classes->count].cls = cls;
classes->count += 1;
/* external types */
if (types->SendChannel != NULL) {
(void)_PyCrossInterpreterData_UnregisterClass(
types->SendChannel);
Py_CLEAR(types->SendChannel);
}
if (types->RecvChannel != NULL) {
(void)_PyCrossInterpreterData_UnregisterClass(
types->RecvChannel);
Py_CLEAR(types->RecvChannel);
}
return res;
}

static void
clear_xid_class_registry(struct xid_class_registry *classes)
{
while (classes->count > 0) {
classes->count -= 1;
PyTypeObject *cls = classes->added[classes->count].cls;
_PyCrossInterpreterData_UnregisterClass(cls);
/* heap types */
if (types->ChannelID != NULL) {
(void)_PyCrossInterpreterData_UnregisterClass(
types->ChannelID);
Py_CLEAR(types->ChannelID);
}
}

Expand Down Expand Up @@ -223,28 +223,6 @@ add_new_exception(PyObject *mod, const char *name, PyObject *base)
#define ADD_NEW_EXCEPTION(MOD, NAME, BASE) \
add_new_exception(MOD, MODULE_NAME_STR "." Py_STRINGIFY(NAME), BASE)

static PyTypeObject *
add_new_type(PyObject *mod, PyType_Spec *spec, crossinterpdatafunc shared,
struct xid_class_registry *classes)
{
PyTypeObject *cls = (PyTypeObject *)PyType_FromModuleAndSpec(
mod, spec, NULL);
if (cls == NULL) {
return NULL;
}
if (PyModule_AddType(mod, cls) < 0) {
Py_DECREF(cls);
return NULL;
}
if (shared != NULL) {
if (register_xid_class(cls, shared, classes)) {
Py_DECREF(cls);
return NULL;
}
}
return cls;
}

static int
wait_for_lock(PyThread_type_lock mutex, PY_TIMEOUT_T timeout)
{
Expand All @@ -269,15 +247,10 @@ wait_for_lock(PyThread_type_lock mutex, PY_TIMEOUT_T timeout)
/* module state *************************************************************/

typedef struct {
struct xid_class_registry xid_classes;

/* Added at runtime by interpreters module. */
PyTypeObject *send_channel_type;
PyTypeObject *recv_channel_type;
struct xid_types xid_types;

/* heap types */
PyTypeObject *ChannelInfoType;
PyTypeObject *ChannelIDType;

/* exceptions */
PyObject *ChannelError;
Expand Down Expand Up @@ -315,12 +288,12 @@ static int
traverse_module_state(module_state *state, visitproc visit, void *arg)
{
/* external types */
Py_VISIT(state->send_channel_type);
Py_VISIT(state->recv_channel_type);
Py_VISIT(state->xid_types.SendChannel);
Py_VISIT(state->xid_types.RecvChannel);

/* heap types */
Py_VISIT(state->ChannelInfoType);
Py_VISIT(state->ChannelIDType);
Py_VISIT(state->xid_types.ChannelID);

/* exceptions */
Py_VISIT(state->ChannelError);
Expand All @@ -335,16 +308,10 @@ traverse_module_state(module_state *state, visitproc visit, void *arg)
static int
clear_module_state(module_state *state)
{
/* external types */
Py_CLEAR(state->send_channel_type);
Py_CLEAR(state->recv_channel_type);
clear_xid_types(&state->xid_types);

/* heap types */
Py_CLEAR(state->ChannelInfoType);
if (state->ChannelIDType != NULL) {
(void)_PyCrossInterpreterData_UnregisterClass(state->ChannelIDType);
}
Py_CLEAR(state->ChannelIDType);

/* exceptions */
Py_CLEAR(state->ChannelError);
Expand Down Expand Up @@ -2210,7 +2177,7 @@ channel_id_converter(PyObject *arg, void *ptr)
struct channel_id_converter_data *data = ptr;
module_state *state = get_module_state(data->module);
assert(state != NULL);
if (PyObject_TypeCheck(arg, state->ChannelIDType)) {
if (PyObject_TypeCheck(arg, state->xid_types.ChannelID)) {
cid = ((channelid *)arg)->cid;
end = ((channelid *)arg)->end;
}
Expand Down Expand Up @@ -2406,14 +2373,14 @@ channelid_richcompare(PyObject *self, PyObject *other, int op)
goto done;
}

if (!PyObject_TypeCheck(self, state->ChannelIDType)) {
if (!PyObject_TypeCheck(self, state->xid_types.ChannelID)) {
res = Py_NewRef(Py_NotImplemented);
goto done;
}

channelid *cidobj = (channelid *)self;
int equal;
if (PyObject_TypeCheck(other, state->ChannelIDType)) {
if (PyObject_TypeCheck(other, state->xid_types.ChannelID)) {
channelid *othercidobj = (channelid *)other;
equal = (cidobj->end == othercidobj->end) && (cidobj->cid == othercidobj->cid);
}
Expand Down Expand Up @@ -2494,7 +2461,7 @@ _channelid_from_xid(_PyCrossInterpreterData *data)

// Note that we do not preserve the "resolve" flag.
PyObject *cidobj = NULL;
int err = newchannelid(state->ChannelIDType, xid->cid, xid->end,
int err = newchannelid(state->xid_types.ChannelID, xid->cid, xid->end,
_global_channels(), 0, 0,
(channelid **)&cidobj);
if (err != 0) {
Expand Down Expand Up @@ -2614,6 +2581,26 @@ static PyType_Spec channelid_typespec = {
.slots = channelid_typeslots,
};

static PyTypeObject *
add_channelid_type(PyObject *mod)
{
PyTypeObject *cls = (PyTypeObject *)PyType_FromModuleAndSpec(
mod, &channelid_typespec, NULL);
if (cls == NULL) {
return NULL;
}
if (ensure_xid_class(cls, _channelid_shared)) {
Py_DECREF(cls);
return NULL;
}
if (PyModule_AddType(mod, cls) < 0) {
(void)_PyCrossInterpreterData_UnregisterClass(cls);
Py_DECREF(cls);
return NULL;
}
return cls;
}


/* SendChannel and RecvChannel classes */

Expand All @@ -2628,11 +2615,11 @@ _get_current_channelend_type(int end)
}
PyTypeObject *cls;
if (end == CHANNEL_SEND) {
cls = state->send_channel_type;
cls = state->xid_types.SendChannel;
}
else {
assert(end == CHANNEL_RECV);
cls = state->recv_channel_type;
cls = state->xid_types.RecvChannel;
}
if (cls == NULL) {
// Force the module to be loaded, to register the type.
Expand All @@ -2646,10 +2633,10 @@ _get_current_channelend_type(int end)
}
Py_DECREF(highlevel);
if (end == CHANNEL_SEND) {
cls = state->send_channel_type;
cls = state->xid_types.SendChannel;
}
else {
cls = state->recv_channel_type;
cls = state->xid_types.RecvChannel;
}
assert(cls != NULL);
}
Expand Down Expand Up @@ -2697,23 +2684,30 @@ set_channelend_types(PyObject *mod, PyTypeObject *send, PyTypeObject *recv)
if (state == NULL) {
return -1;
}
struct xid_class_registry *xid_classes = &state->xid_classes;
struct xid_types *types = &state->xid_types;

if (state->send_channel_type != NULL
|| state->recv_channel_type != NULL)
{
PyErr_SetString(PyExc_TypeError, "already registered");
return -1;
// Clear the old values if the .py module was reloaded.
if (types->SendChannel != NULL) {
(void)_PyCrossInterpreterData_UnregisterClass(
types->SendChannel);
Py_CLEAR(types->SendChannel);
}
if (types->RecvChannel != NULL) {
(void)_PyCrossInterpreterData_UnregisterClass(
types->RecvChannel);
Py_CLEAR(types->RecvChannel);
}
state->send_channel_type = (PyTypeObject *)Py_NewRef(send);
state->recv_channel_type = (PyTypeObject *)Py_NewRef(recv);

if (register_xid_class(send, _channelend_shared, xid_classes)) {
// Add and register the new types.
if (ensure_xid_class(send, _channelend_shared) < 0) {
return -1;
}
if (register_xid_class(recv, _ch 9E7A annelend_shared, xid_classes)) {
if (ensure_xid_class(recv, _channelend_shared) < 0) {
(void)_PyCrossInterpreterData_UnregisterClass(send);
return -1;
}
types->SendChannel = (PyTypeObject *)Py_NewRef(send);
types->RecvChannel = (PyTypeObject *)Py_NewRef(recv);

return 0;
}
Expand Down Expand Up @@ -2792,7 +2786,7 @@ channelsmod_create(PyObject *self, PyObject *Py_UNUSED(ignored))
return NULL;
}
PyObject *cidobj = NULL;
int err = newchannelid(state->ChannelIDType, cid, 0,
int err = newchannelid(state->xid_types.ChannelID, cid, 0,
&_globals.channels, 0, 0,
(channelid **)&cidobj);
if (handle_channel_error(err, self, cid)) {
Expand Down Expand Up @@ -2864,7 +2858,7 @@ channelsmod_list_all(PyObject *self, PyObject *Py_UNUSED(ignored))
int64_t *cur = cids;
for (int64_t i=0; i < count; cur++, i++) {
PyObject *cidobj = NULL;
int err = newchannelid(state->ChannelIDType, *cur, 0,
int err = newchannelid(state->xid_types.ChannelID, *cur, 0,
&_globals.channels, 0, 0,
(channelid **)&cidobj);
if (handle_channel_error(err, self, *cur)) {
Expand Down Expand Up @@ -3214,7 +3208,7 @@ channelsmod__channel_id(PyObject *self, PyObject *args, PyObject *kwds)
if (state == NULL) {
return NULL;
}
PyTypeObject *cls = state->ChannelIDType;
PyTypeObject *cls = state->xid_types.ChannelID;

PyObject *mod = get_module_from_owned_type(cls);
assert(mod == self);
Expand Down Expand Up @@ -3294,13 +3288,13 @@ module_exec(PyObject *mod)
if (_globals_init() != 0) {
return -1;
}
struct xid_class_registry *xid_classes = NULL;
struct xid_types *xid_types = NULL;

module_state *state = get_module_state(mod);
if (state == NULL) {
goto error;
}
xid_classes = &state->xid_classes;
xid_types = &state->xid_types;

/* Add exception types */
if (exceptions_init(mod) != 0) {
Expand All @@ -3319,9 +3313,8 @@ module_exec(PyObject *mod)
}

// ChannelID
state->ChannelIDType = add_new_type(
mod, &channelid_typespec, _channelid_shared, xid_classes);
if (state->ChannelIDType == NULL) {
xid_types->ChannelID = add_channelid_type(mod);
if (xid_types->ChannelID == NULL) {
goto error;
}

Expand All @@ -3332,8 +3325,8 @@ module_exec(PyObject *mod)
return 0;

error:
if (xid_classes != NULL) {
clear_xid_class_registry(xid_classes);
if (xid_types != NULL) {
clear_xid_types(xid_types);
}
_globals_fini();
return -1;
Expand All @@ -3360,9 +3353,6 @@ module_clear(PyObject *mod)
module_state *state = get_module_state(mod);
assert(state != NULL);

// Before clearing anything, we unregister the various XID types. */
clear_xid_class_registry(&state->xid_classes);

// Now we clear the module state.
clear_module_state(state);
return 0;
Expand All @@ -3374,9 +3364,6 @@ module_free(void *mod)
module_state *state = get_module_state(mod);
assert(state != NULL);

// Before clearing anything, we unregister the various XID types. */
clear_xid_class_registry(&state->xid_classes);

// Now we clear the module state.
clear_module_state(state);

Expand Down
Loading
0