8000 gh-106905: Use separate structs to track recursion depth in each PyAS… · python/cpython@48c4973 · GitHub
[go: up one dir, main page]

Skip to content

Commit 48c4973

Browse files
yileigpshead
andauthored
gh-106905: Use separate structs to track recursion depth in each PyAST_mod2obj call. (GH-113035)
Co-authored-by: Gregory P. Smith [Google LLC] <greg@krypto.org>
1 parent 3f5eb3e commit 48c4973

File tree

4 files changed

+412
-339
lines changed

4 files changed

+412
-339
lines changed

Include/internal/pycore_ast_state.h

Lines changed: 0 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
Use per AST-parser state rather than global state to track recursion depth
2+
within the AST parser to prevent potential race condition due to
3+
simultaneous parsing.
4+
5+
The issue primarily showed up in 3.11 by multithreaded users of
6+
:func:`ast.parse`. In 3.12 a change to when garbage collection can be
7+
triggered prevented the race condition from occurring.

Parser/asdl_c.py

Lines changed: 28 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -731,7 +731,7 @@ def emit_sequence_constructor(self, name, type):
731731
class PyTypesDeclareVisitor(PickleVisitor):
732732

733733
def visitProduct(self, prod, name):
734-
self.emit("static PyObject* ast2obj_%s(struct ast_state *state, void*);" % name, 0)
734+
self.emit("static PyObject* ast2obj_%s(struct ast_state *state, struct validator *vstate, void*);" % name, 0)
735735
if prod.attributes:
736736
self.emit("static const char * const %s_attributes[] = {" % name, 0)
737737
for a in prod.attributes:
@@ -752,7 +752,7 @@ def visitSum(self, sum, name):
752752
ptype = "void*"
753753
if is_simple(sum):
754754
ptype = get_c_type(name)
755-
self.emit("static PyObject* ast2obj_%s(struct ast_state *state, %s);" % (name, ptype), 0)
755+
self.emit("static PyObject* ast2obj_%s(struct ast_state *state, struct validator *vstate, %s);" % (name, ptype), 0)
756756
for t in sum.types:
757757
self.visitConstructor(t, name)
758758

@@ -984,15 +984,16 @@ def visitModule(self, mod):
984984
985985
/* Conversion AST -> Python */
986986
987-
static PyObject* ast2obj_list(struct ast_state *state, asdl_seq *seq, PyObject* (*func)(struct ast_state *state, void*))
987+
static PyObject* ast2obj_list(struct ast_state *state, struct validator *vstate, asdl_seq *seq,
988+
PyObject* (*func)(struct ast_state *state, struct validator *vstate, void*))
988989
{
989990
Py_ssize_t i, n = asdl_seq_LEN(seq);
990991
PyObject *result = PyList_New(n);
991992
PyObject *value;
992993
if (!result)
993994
return NULL;
994995
for (i = 0; i < n; i++) {
995-
value = func(state, asdl_seq_GET_UNTYPED(seq, i));
996+
value = func(state, vstate, asdl_seq_GET_UNTYPED(seq, i));
996997
if (!value) {
997998
Py_DECREF(result);
998999
return NULL;
@@ -1002,7 +1003,7 @@ def visitModule(self, mod):
10021003
return result;
10031004
}
10041005
1005-
static PyObject* ast2obj_object(struct ast_state *Py_UNUSED(state), void *o)
1006+
static PyObject* ast2obj_object(struct ast_state *Py_UNUSED(state), struct validator *Py_UNUSED(vstate), void *o)
10061007
{
10071008
PyObject *op = (PyObject*)o;
10081009
if (!op) {
@@ -1014,7 +1015,7 @@ def visitModule(self, mod):
10141015
#define ast2obj_identifier ast2obj_object
10151016
#define ast2obj_string ast2obj_object
10161017
1017-
static PyObject* ast2obj_int(struct ast_state *Py_UNUSED(state), long b)
1018+
static PyObject* ast2obj_int(struct ast_state *Py_UNUSED(state), struct validator *Py_UNUSED(vstate), long b)
10181019
{
10191020
return PyLong_FromLong(b);
10201021
}
@@ -1116,8 +1117,6 @@ def visitModule(self, mod):
11161117
for dfn in mod.dfns:
11171118
self.visit(dfn)
11181119
self.file.write(textwrap.dedent('''
1119-
state->recursion_depth = 0;
1120-
state->recursion_limit = 0;
11211120
return 0;
11221121
}
11231122
'''))
@@ -1260,25 +1259,25 @@ class ObjVisitor(PickleVisitor):
12601259
def func_begin(self, name):
12611260
ctype = get_c_type(name)
12621261
self.emit("PyObject*", 0)
1263-
self.emit("ast2obj_%s(struct ast_state *state, void* _o)" % (name), 0)
1262+
self.emit("ast2obj_%s(struct ast_state *state, struct validator *vstate, void* _o)" % (name), 0)
12641263
self.emit("{", 0)
12651264
self.emit("%s o = (%s)_o;" % (ctype, ctype), 1)
12661265
self.emit("PyObject *result = NULL, *value = NULL;", 1)
12671266
self.emit("PyTypeObject *tp;", 1)
12681267
self.emit('if (!o) {', 1)
12691268
self.emit("Py_RETURN_NONE;", 2)
12701269
self.emit("}", 1)
1271-
self.emit("if (++state->recursion_depth > state->recursion_limit) {", 1)
1270+
self.emit("if (++vstate->recursion_depth > vstate->recursion_limit) {", 1)
12721271
self.emit("PyErr_SetString(PyExc_RecursionError,", 2)
12731272
self.emit('"maximum recursion depth exceeded during ast construction");', 3)
12741273
self.emit("return NULL;", 2)
12751274
self.emit("}", 1)
12761275

12771276
def func_end(self):
1278-
self.emit("state->recursion_depth--;", 1)
1277+
self.emit("vstate->recursion_depth--;", 1)
12791278
self.emit("return result;", 1)
12801279
self.emit("failed:", 0)
1281-
self.emit("state->recursion_depth--;", 1)
1280+
self.emit("vstate->recursion_depth--;", 1)
12821281
self.emit("Py_XDECREF(value);", 1)
12831282
self.emit("Py_XDECREF(result);", 1)
12841283
self.emit("return NULL;", 1)
@@ -1296,15 +1295,15 @@ def visitSum(self, sum, name):
12961295
self.visitConstructor(t, i + 1, name)
12971296
self.emit("}", 1)
12981297
for a in sum.attributes:
1299-
self.emit("value = ast2obj_%s(state, o->%s);" % (a.type, a.name), 1)
1298+
self.emit("value = ast2obj_%s(state, vstate, o->%s);" % (a.type, a.name), 1)
13001299
self.emit("if (!value) goto failed;", 1)
13011300
self.emit('if (PyObject_SetAttr(result, state->%s, value) < 0)' % a.name, 1)
13021301
self.emit('goto failed;', 2)
13031302
self.emit('Py_DECREF(value);', 1)
13041303
self.func_end()
13051304

13061305
def simpleSum(self, sum, name):
1307-
self.emit("PyObject* ast2obj_%s(struct ast_state *state, %s_ty o)" % (name, name), 0)
1306+
self.emit("PyObject* ast2obj_%s(struct ast_state *state, struct validator *vstate, %s_ty o)" % (name, name), 0)
13081307
self.emit("{", 0)
13091308
self.emit("switch(o) {", 1)
13101309
for t in sum.types:
@@ -1322,7 +1321,7 @@ def visitProduct(self, prod, name):
13221321
for field in prod.fields:
13231322
self.visitField(field, name, 1, True)
13241323
for a in prod.attributes:
1325-
self.emit("value = ast2obj_%s(state, o->%s);" % (a.type, a.name), 1)
1324+
self.emit("value = ast2obj_%s(state, vstate, o->%s);" % (a.type, a.name), 1)
13261325
self.emit("if (!value) goto failed;", 1)
13271326
self.emit("if (PyObject_SetAttr(result, state->%s, value) < 0)" % a.name, 1)
13281327
self.emit('goto failed;', 2)
@@ -1363,7 +1362,7 @@ def set(self, field, value, depth):
13631362
self.emit("for(i = 0; i < n; i++)", depth+1)
13641363
# This cannot fail, so no need for error handling
13651364
self.emit(
1366-
"PyList_SET_ITEM(value, i, ast2obj_{0}(state, ({0}_ty)asdl_seq_GET({1}, i)));".format(
1365+
"PyList_SET_ITEM(value, i, ast2obj_{0}(state, vstate, ({0}_ty)asdl_seq_GET({1}, i)));".format(
13671366
field.type,
13681367
value
13691368
),
@@ -1372,9 +1371,9 @@ def set(self, field, value, depth):
13721371
)
13731372
self.emit("}", depth)
13741373
else:
1375-
self.emit("value = ast2obj_list(state, (asdl_seq*)%s, ast2obj_%s);" % (value, field.type), depth)
1374+
self.emit("value = ast2obj_list(state, vstate, (asdl_seq*)%s, ast2obj_%s);" % (value, field.type), depth)
13761375
else:
1377-
self.emit("value = ast2obj_%s(state, %s);" % (field.type, value), depth, reflow=False)
1376+
self.emit("value = ast2obj_%s(state, vstate, %s);" % (field.type, value), depth, reflow=False)
13781377

13791378

13801379
class PartingShots(StaticVisitor):
@@ -1394,18 +1393,19 @@ class PartingShots(StaticVisitor):
13941393
if (!tstate) {
13951394
return NULL;
13961395
}
1397-
state->recursion_limit = Py_C_RECURSION_LIMIT * COMPILER_STACK_FRAME_SCALE;
1396+
struct validator vstate;
1397+
vstate.recursion_limit = Py_C_RECURSION_LIMIT * COMPILER_STACK_FRAME_SCALE;
13981398
int recursion_depth = Py_C_RECURSION_LIMIT - tstate->c_recursion_remaining;
13991399
starting_recursion_depth = recursion_depth * COMPILER_STACK_FRAME_SCALE;
1400-
state->recursion_depth = starting_recursion_depth;
1400+
vstate.recursion_depth = starting_recursion_depth;
< 1009E code>14011401
1402-
PyObject *result = ast2obj_mod(state, t);
1402+
PyObject *result = ast2obj_mod(state, &vstate, t);
14031403
14041404
/* Check that the recursion depth counting balanced correctly */
1405-
if (result && state->recursion_depth != starting_recursion_depth) {
1405+
if (result && vstate.recursion_depth != starting_recursion_depth) {
14061406
PyErr_Format(PyExc_SystemError,
14071407
"AST constructor recursion depth mismatch (before=%d, after=%d)",
1408-
starting_recursion_depth, state->recursion_depth);
1408+
starting_recursion_depth, vstate.recursion_depth);
14091409
return NULL;
14101410
}
14111411
return result;
@@ -1475,8 +1475,6 @@ def generate_ast_state(module_state, f):
14751475
f.write('struct ast_state {\n')
14761476
f.write(' _PyOnceFlag once;\n')
14771477
f.write(' int finalized;\n')
1478-
f.write(' int recursion_depth;\n')
1479-
f.write(' int recursion_limit;\n')
14801478
for s in module_state:
14811479
f.write(' PyObject *' + s + ';\n')
14821480
f.write('};')
@@ -1539,6 +1537,11 @@ def generate_module_def(mod, metadata, f, internal_h):
15391537
#include "pycore_pystate.h" // _PyInterpreterState_GET()
15401538
#include <stddef.h>
15411539
1540+
struct validator {
1541+
int recursion_depth; /* current recursion depth */
1542+
int recursion_limit; /* recursion limit */
1543+
};
1544+
15421545
// Forward declaration
15431546
static int init_types(struct ast_state *state);
15441547

0 commit comments

Comments
 (0)
0