8000 gh-115490: Make the interpreter.channels and interpreter.queues Modules Handle Reloading Properly by ericsnowcurrently · Pull Request #115493 · python/cpython · GitHub
[go: up one dir, main page]

Skip to content

gh-115490: Make the interpreter.channels and interpreter.queues Modules Handle Reloading Properly #115493

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Prev Previous commit
Next Next commit
Revert "Fix the channels module."
This reverts commit 5316560.
  • Loading branch information
ericsnowcurrently committed Feb 15, 2024
commit 18fa5d4868e028ab28b6952c32dfc8f744ca09d0
1 change: 1 addition & 0 deletions Lib/test/support/interpreters/channels.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,4 +167,5 @@ def close(self):
_channels.close(self._id, send=True)


# XXX This is causing leaks (gh-110318):
_channels._register_end_types(SendChannel, RecvChannel)
179 changes: 96 additions & 83 deletions Modules/_xxinterpchannelsmodule.c
Original file line number Diff line number Diff line change
Expand Up @@ -93,35 +93,35 @@ API.. The module does not create any objects that are shared globally.
PyMem_RawFree(VAR)


struct xid_types {
/* Added at runtime by interpreters module. */
PyTypeObject *SendChannel;
PyTypeObject *RecvChannel;

/* heap types */
PyTypeObject *ChannelID;
struct xid_class_registry {
size_t count;
#define MAX_XID_CLASSES 5
struct {
PyTypeObject *cls;
} added[MAX_XID_CLASSES];
};

static void
clear_xid_types(struct xid_types *types)
static int
register_xid_class(PyTypeObject *cls, crossinterpdatafunc shared,
struct xid_class_registry *classes)
{
/* external types */
if (types->SendChannel != NULL) {
(void)_PyCrossInterpreterData_UnregisterClass(
types->SendChannel);
Py_CLEAR(types->SendChannel);
}
if (types->RecvChannel != NULL) {
(void)_PyCrossInterpreterData_UnregisterClass(
types->RecvChannel);
Py_CLEAR(types->RecvChannel);
int res = ensure_xid_class(cls, shared);
if (res == 0) {
assert(classes->count < MAX_XID_CLASSES);
// The class has refs elsewhere, so we need to incref here.
classes->added[classes->count].cls = cls;
classes->count += 1;
}
return res;
}

/* heap types */
if (types->ChannelID != NULL) {
(void)_PyCrossInterpreterData_UnregisterClass(
types->ChannelID);
Py_CLEAR(types->ChannelID);
static void
clear_xid_class_registry(struct xid_class_registry *classes)
{
while (classes->count > 0) {
classes->count -= 1;
PyTypeObject *cls = classes->added[classes->count].cls;
_PyCrossInterpreterData_UnregisterClass(cls);
}
}

Expand Down Expand Up @@ -223,6 +223,28 @@ add_new_exception(PyObject *mod, const char *name, PyObject *base)
#define ADD_NEW_EXCEPTION(MOD, NAME, BASE) \
add_new_exception(MOD, MODULE_NAME_STR "." Py_STRINGIFY(NAME), BASE)

static PyTypeObject *
add_new_type(PyObject *mod, PyType_Spec *spec, crossinterpdatafunc shared,
struct xid_class_registry *classes)
{
PyTypeObject *cls = (PyTypeObject *)PyType_FromModuleAndSpec(
mod, spec, NULL);
if (cls == NULL) {
return NULL;
}
if (PyModule_AddType(mod, cls) < 0) {
Py_DECREF(cls);
return NULL;
}
if (shared != NULL) {
if (register_xid_class(cls, shared, classes)) {
Py_DECREF(cls);
return NULL;
}
}
return cls;
}

static int
wait_for_lock(PyThread_type_lock mutex, PY_TIMEOUT_T timeout)
{
Expand All @@ -247,10 +269,15 @@ wait_for_lock(PyThread_type_lock mutex, PY_TIMEOUT_T timeout)
/* module state *************************************************************/

typedef struct {
struct xid_types xid_types;
struct xid_class_registry xid_classes;

/* Added at runtime by interpreters module. */
PyTypeObject *send_channel_type;
PyTypeObject *recv_channel_type;

/* heap types */
PyTypeObject *ChannelInfoType;
PyTypeObject *ChannelIDType;

/* exceptions */
PyObject *ChannelError;
Expand Down Expand Up @@ -288,12 +315,12 @@ static int
traverse_module_state(module_state *state, visitproc visit, void *arg)
{
/* external types */
Py_VISIT(state->xid_types.SendChannel);
Py_VISIT(state->xid_types.RecvChannel);
Py_VISIT(state->send_channel_type);
Py_VISIT(state->recv_channel_type);

/* heap types */
Py_VISIT(state->ChannelInfoType);
Py_VISIT(state->xid_types.ChannelID);
Py_VISIT(state->ChannelIDType);

/* exceptions */
Py_VISIT(state->ChannelError);
Expand All @@ -308,10 +335,16 @@ traverse_module_state(module_state *state, visitproc visit, void *arg)
static int
clear_module_state(module_state *state)
{
clear_xid_types(&state->xid_types);
/* external types */
Py_CLEAR(state->send_channel_type);
Py_CLEAR(state->recv_channel_type);

/* heap types */
Py_CLEAR(state->ChannelInfoType);
if (state->ChannelIDType != NULL) {
(void)_PyCrossInterpreterData_UnregisterClass(state->ChannelIDType);
}
Py_CLEAR(state->ChannelIDType);

/* exceptions */
Py_CLEAR(state->ChannelError);
Expand Down Expand Up @@ -2177,7 +2210,7 @@ channel_id_converter(PyObject *arg, void *ptr)
struct channel_id_converter_data *data = ptr;
module_state *state = get_module_state(data->module);
assert(state != NULL);
if (PyObject_TypeCheck(arg, state->xid_types.ChannelID)) {
if (PyObject_TypeCheck(arg, state->ChannelIDType)) {
cid = ((channelid *)arg)->cid;
end = ((channelid *)arg)->end;
}
Expand Down Expand Up @@ -2373,14 +2406,14 @@ channelid_richcompare(PyObject *self, PyObject *other, int op)
goto done;
}

if (!PyObject_TypeCheck(self, state->xid_types.ChannelID)) {
if (!PyObject_TypeCheck(self, state->ChannelIDType)) {
res = Py_NewRef(Py_NotImplemented);
goto done;
}

channelid *cidobj = (channelid *)self;
int equal;
if (PyObject_TypeCheck(other, state->xid_types.ChannelID)) {
if (PyObject_TypeCheck(other, state->ChannelIDType)) {
channelid *othercidobj = (channelid *)other;
equal = (cidobj->end == othercidobj->end) && (cidobj->cid == othercidobj->cid);
}
Expand Down Expand Up @@ -2461,7 +2494,7 @@ _channelid_from_xid(_PyCrossInterpreterData *data)

// Note that we do not preserve the "resolve" flag.
PyObject *cidobj = NULL;
int err = newchannelid(state->xid_types.ChannelID, xid->cid, xid->end,
int err = newchannelid(state->ChannelIDType, xid->cid, xid->end,
_global_channels(), 0, 0,
(channelid **)&cidobj);
if (err != 0) {
Expand Down Expand Up @@ -2581,26 +2614,6 @@ static PyType_Spec channelid_typespec = {
.slots = channelid_typeslots,
};

static PyTypeObject *
add_channelid_type(PyObject *mod)
{
PyTypeObject *cls = (PyTypeObject *)PyType_FromModuleAndSpec(
mod, &channelid_typespec, NULL);
if (cls == NULL) {
return NULL;
}
if (ensure_xid_class(cls, _channelid_shared)) {
Py_DECREF(cls);
return NULL;
}
if (PyModule_AddType(mod, cls) < 0) {
(void)_PyCrossInterpreterData_UnregisterClass(cls);
Py_DECREF(cls);
return NULL;
}
return cls;
}


/* SendChannel and RecvChannel classes */

Expand All @@ -2615,11 +2628,11 @@ _get_current_channelend_type(int end)
}
PyTypeObject *cls;
if (end == CHANNEL_SEND) {
cls = state->xid_types.SendChannel;
cls = state->send_channel_type;
}
else {
assert(end == CHANNEL_RECV);
cls = state->xid_types.RecvChannel;
cls = state->recv_channel_type;
}
if (cls == NULL) {
// Force the module to be loaded, to register the type.
Expand All @@ -2633,10 +2646,10 @@ _get_current_channelend_type(int end)
}
Py_DECREF(highlevel);
if (end == CHANNEL_SEND) {
cls = state->xid_types.SendChannel;
cls = state->send_channel_type;
}
else {
cls = state->xid_types.RecvChannel;
cls = state->recv_channel_type;
}
assert(cls != NULL);
}
Expand Down Expand Up @@ -2684,30 +2697,23 @@ set_channelend_types(PyObject *mod, PyTypeObject *send, PyTypeObject *recv)
if (state == NULL) {
return -1;
}
struct xid_types *types = &state->xid_types;
struct xid_class_registry *xid_classes = &state->xid_classes;

// Clear the old values if the .py module was reloaded.
if (types->SendChannel != NULL) {
(void)_PyCrossInterpreterData_UnregisterClass(
types->SendChannel);
Py_CLEAR(types->SendChannel);
}
if (types->RecvChannel != NULL) {
(void)_PyCrossInterpreterData_UnregisterClass(
types->RecvChannel);
Py_CLEAR(types->RecvChannel);
if (state->send_channel_type != NULL
|| state->recv_channel_type != NULL)
{
PyErr_SetString(PyExc_TypeError, "already registered");
return -1;
}
state->send_channel_type = (PyTypeObject *)Py_NewRef(send);
state->recv_channel_type = (PyTypeObject *)Py_NewRef(recv);

// Add and register the new types.
if (ensure_xid_class(send, _channelend_shared) < 0) {
if (register_xid_class(send, _channelend_shared, xid_classes)) {
return -1;
}
if (ensure_xid_class(recv, _channelend_shared) < 0) {
(void)_PyCrossInterpreterData_UnregisterClass(send);
if (register_xid_class(recv, _channelend_shared, xid_classes)) {
return -1;
}
types->SendChannel = (PyTypeObject *)Py_NewRef(send);
types->RecvChannel = (PyTypeObject *)Py_NewRef(recv);

return 0;
}
Expand Down Expand Up @@ -2786,7 +2792,7 @@ channelsmod_create(PyObject *self, PyObject *Py_UNUSED(ignored))
return NULL;
}
PyObject *cidobj = NULL;
int err = newchannelid(state->xid_types.ChannelID, cid, 0,
int err = newchannelid(state->ChannelIDType, cid, 0,
&_globals.channels, 0, 0,
(channelid **)&cidobj);
if (handle_channel_error(err, self, cid)) {
Expand Down Expand Up @@ -2858,7 +2864,7 @@ channelsmod_list_all(PyObject *self, PyObject *Py_UNUSED(ignored))
int64_t *cur = cids;
for (int64_t i=0; i < count; cur++, i++) {
PyObject *cidobj = NULL;
int err = newchannelid(state->xid_types.ChannelID, *cur, 0,
int err = newchannelid(state->ChannelIDType, *cur, 0,
&_globals.channels, 0, 0,
(channelid **)&cidobj);
if (handle_channel_error(err, self, *cur)) {
Expand Down Expand Up @@ -3208,7 +3214,7 @@ channelsmod__channel_id(PyObject *self, PyObject *args, PyObject *kwds)
if (state == NULL) {
return NULL;
}
PyTypeObject *cls = state->xid_types.ChannelID;
PyTypeObject *cls = state->ChannelIDType;

PyObject *mod = get_module_from_owned_type(cls);
assert(mod == self);
Expand Down Expand Up @@ -3288,13 +3294,13 @@ module_exec(PyObject *mod)
if (_globals_init() != 0) {
return -1;
}
struct xid_types *xid_types = NULL;
struct xid_class_registry *xid_classes = NULL;

module_state *state = get_module_state(mod);
if (state == NULL) {
goto error;
}
xid_types = &state->xid_types;
xid_classes = &state->xid_classes;

/* Add exception types */
if (exceptions_init(mod) != 0) {
Expand All @@ -3313,8 +3319,9 @@ module_exec(PyObject *mod)
}

// ChannelID
xid_types->ChannelID = add_channelid_type(mod);
if (xid_types->ChannelID == NULL) {
state->ChannelIDType = add_new_type(
mod, &channelid_typespec, _channelid_shared, xid_classes);
if (state->ChannelIDType == NULL) {
goto error;
}

Expand All @@ -3325,8 +3332,8 @@ module_exec(PyObject *mod)
return 0;

error:
if (xid_types != NULL) {
clear_xid_types(xid_types);
if (xid_classes != NULL) {
clear_xid_class_registry(xid_classes);
}
_globals_fini();
return -1;
Expand All @@ -3353,6 +3360,9 @@ module_clear(PyObject *mod)
module_state *state = get_module_state(mod);
assert(state != NULL);

// Before clearing anything, we unregister the various XID types. */
clear_xid_class_registry(&state->xid_classes);

// Now we clear the module state.
clear_module_state(state);
return 0;
Expand All @@ -3364,6 +3374,9 @@ module_free(void *mod)
module_state *state = get_module_state(mod);
assert(state != NULL);

// Before clearing anything, we unregister the various XID types. */
clear_xid_class_registry(&state->xid_classes);

// Now we clear the module state.
clear_module_state(state);

Expand Down
0