@@ -731,7 +731,7 @@ def emit_sequence_constructor(self, name, type):
731
731
class PyTypesDeclareVisitor (PickleVisitor ):
732
732
733
733
def visitProduct (self , prod , name ):
734
- self .emit ("static PyObject* ast2obj_%s(struct ast_state *state, void*);" % name , 0 )
734
+ self .emit ("static PyObject* ast2obj_%s(struct ast_state *state, struct validator *vstate, void*);" % name , 0 )
735
735
if prod .attributes :
736
736
self .emit ("static const char * const %s_attributes[] = {" % name , 0 )
737
737
for a in prod .attributes :
@@ -752,7 +752,7 @@ def visitSum(self, sum, name):
752
752
ptype = "void*"
753
753
if is_simple (sum ):
754
754
ptype = get_c_type (name )
755
- self .emit ("static PyObject* ast2obj_%s(struct ast_state *state, %s);" % (name , ptype ), 0 )
755
+ self .emit ("static PyObject* ast2obj_%s(struct ast_state *state, struct validator *vstate, %s);" % (name , ptype ), 0 )
756
756
for t in sum .types :
757
757
self .visitConstructor (t , name )
758
758
@@ -984,15 +984,16 @@ def visitModule(self, mod):
984
984
985
985
/* Conversion AST -> Python */
986
986
987
- static PyObject* ast2obj_list(struct ast_state *state, asdl_seq *seq, PyObject* (*func)(struct ast_state *state, void*))
987
+ static PyObject* ast2obj_list(struct ast_state *state, struct validator *vstate, asdl_seq *seq,
988
+ PyObject* (*func)(struct ast_state *state, struct validator *vstate, void*))
988
989
{
989
990
Py_ssize_t i, n = asdl_seq_LEN(seq);
990
991
PyObject *result = PyList_New(n);
991
992
PyObject *value;
992
993
if (!result)
993
994
return NULL;
994
995
for (i = 0; i < n; i++) {
995
- value = func(state, asdl_seq_GET_UNTYPED(seq, i));
996
+ value = func(state, vstate, asdl_seq_GET_UNTYPED(seq, i));
996
997
if (!value) {
997
998
Py_DECREF(result);
998
999
return NULL;
@@ -1002,7 +1003,7 @@ def visitModule(self, mod):
1002
1003
return result;
1003
1004
}
1004
1005
1005
- static PyObject* ast2obj_object(struct ast_state *Py_UNUSED(state), void *o)
1006
+ static PyObject* ast2obj_object(struct ast_state *Py_UNUSED(state), struct validator *Py_UNUSED(vstate), void *o)
1006
1007
{
1007
1008
PyObject *op = (PyObject*)o;
1008
1009
if (!op) {
@@ -1014,7 +1015,7 @@ def visitModule(self, mod):
1014
1015
#define ast2obj_identifier ast2obj_object
1015
1016
#define ast2obj_string ast2obj_object
1016
1017
1017
- static PyObject* ast2obj_int(struct ast_state *Py_UNUSED(state), long b)
1018
+ static PyObject* ast2obj_int(struct ast_state *Py_UNUSED(state), struct validator *Py_UNUSED(vstate), long b)
1018
1019
{
1019
1020
return PyLong_FromLong(b);
1020
1021
}
@@ -1116,8 +1117,6 @@ def visitModule(self, mod):
1116
1117
for dfn in mod .dfns :
1117
1118
self .visit (dfn )
1118
1119
self .file .write (textwrap .dedent ('''
1119
- state->recursion_depth = 0;
1120
- state->recursion_limit = 0;
1121
1120
return 0;
1122
1121
}
1123
1122
''' ))
@@ -1260,25 +1259,25 @@ class ObjVisitor(PickleVisitor):
1260
1259
def func_begin (self , name ):
1261
1260
ctype = get_c_type (name )
1262
1261
self .emit ("PyObject*" , 0 )
1263
- self .emit ("ast2obj_%s(struct ast_state *state, void* _o)" % (name ), 0 )
1262
+ self .emit ("ast2obj_%s(struct ast_state *state, struct validator *vstate, void* _o)" % (name ), 0 )
1264
1263
self .emit ("{" , 0 )
1265
1264
self .emit ("%s o = (%s)_o;" % (ctype , ctype ), 1 )
1266
1265
self .emit ("PyObject *result = NULL, *value = NULL;" , 1 )
1267
1266
self .emit ("PyTypeObject *tp;" , 1 )
1268
1267
self .emit ('if (!o) {' , 1 )
1269
1268
self .emit ("Py_RETURN_NONE;" , 2 )
1270
1269
self .emit ("}" , 1 )
1271
- self .emit ("if (++state ->recursion_depth > state ->recursion_limit) {" , 1 )
1270
+ self .emit ("if (++vstate ->recursion_depth > vstate ->recursion_limit) {" , 1 )
1272
1271
self .emit ("PyErr_SetString(PyExc_RecursionError," , 2 )
1273
1272
self .emit ('"maximum recursion depth exceeded during ast construction");' , 3 )
1274
1273
self .emit ("return NULL;" , 2 )
1275
1274
self .emit ("}" , 1 )
1276
1275
1277
1276
def func_end (self ):
1278
- self .emit ("state ->recursion_depth--;" , 1 )
1277
+ self .emit ("vstate ->recursion_depth--;" , 1 )
1279
1278
self .emit ("return result;" , 1 )
1280
1279
self .emit ("failed:" , 0 )
1281
- self .emit ("state ->recursion_depth--;" , 1 )
1280
+ self .emit ("vstate ->recursion_depth--;" , 1 )
1282
1281
self .emit ("Py_XDECREF(value);" , 1 )
1283
1282
self .emit ("Py_XDECREF(result);" , 1 )
1284
1283
self .emit ("return NULL;" , 1 )
@@ -1296,15 +1295,15 @@ def visitSum(self, sum, name):
1296
1295
self .visitConstructor (t , i + 1 , name )
1297
1296
self .emit ("}" , 1 )
1298
1297
for a in sum .attributes :
1299
- self .emit ("value = ast2obj_%s(state, o->%s);" % (a .type , a .name ), 1 )
1298
+ self .emit ("value = ast2obj_%s(state, vstate, o->%s);" % (a .type , a .name ), 1 )
1300
1299
self .emit ("if (!value) goto failed;" , 1 )
1301
1300
self .emit ('if (PyObject_SetAttr(result, state->%s, value) < 0)' % a .name , 1 )
1302
1301
self .emit ('goto failed;' , 2 )
1303
1302
self .emit ('Py_DECREF(value);' , 1 )
1304
1303
self .func_end ()
1305
1304
1306
1305
def simpleSum (self , sum , name ):
1307
- self .emit ("PyObject* ast2obj_%s(struct ast_state *state, %s_ty o)" % (name , name ), 0 )
1306
+ self .emit ("PyObject* ast2obj_%s(struct ast_state *state, struct validator *vstate, %s_ty o)" % (name , name ), 0 )
1308
1307
self .emit ("{" , 0 )
1309
1308
self .emit ("switch(o) {" , 1 )
1310
1309
for t in sum .types :
@@ -1322,7 +1321,7 @@ def visitProduct(self, prod, name):
1322
1321
for field in prod .fields :
1323
1322
self .visitField (field , name , 1 , True )
1324
1323
for a in prod .attributes :
1325
- self .emit ("value = ast2obj_%s(state, o->%s);" % (a .type , a .name ), 1 )
1324
+ self .emit ("value = ast2obj_%s(state, vstate, o->%s);" % (a .type , a .name ), 1 )
1326
1325
self .emit ("if (!value) goto failed;" , 1 )
1327
1326
self .emit ("if (PyObject_SetAttr(result, state->%s, value) < 0)" % a .name , 1 )
1328
1327
self .emit ('goto failed;' , 2 )
@@ -1363,7 +1362,7 @@ def set(self, field, value, depth):
1363
1362
self .emit ("for(i = 0; i < n; i++)" , depth + 1 )
1364
1363
# This cannot fail, so no need for error handling
1365
1364
self .emit (
1366
- "PyList_SET_ITEM(value, i, ast2obj_{0}(state, ({0}_ty)asdl_seq_GET({1}, i)));" .format (
1365
+ "PyList_SET_ITEM(value, i, ast2obj_{0}(state, vstate, ({0}_ty)asdl_seq_GET({1}, i)));" .format (
1367
1366
field .type ,
1368
1367
value
1369
1368
),
@@ -1372,9 +1371,9 @@ def set(self, field, value, depth):
1372
1371
)
1373
1372
self .emit ("}" , depth )
1374
1373
else :
1375
- self .emit ("value = ast2obj_list(state, (asdl_seq*)%s, ast2obj_%s);" % (value , field .type ), depth )
1374
+ self .emit ("value = ast2obj_list(state, vstate, (asdl_seq*)%s, ast2obj_%s);" % (value , field .type ), depth )
1376
1375
else :
1377
- self .emit ("value = ast2obj_%s(state, %s);" % (field .type , value ), depth , reflow = False )
1376
+ self .emit ("value = ast2obj_%s(state, vstate, %s);" % (field .type , value ), depth , reflow = False )
1378
1377
1379
1378
1380
1379
class PartingShots (StaticVisitor ):
@@ -1394,18 +1393,19 @@ class PartingShots(StaticVisitor):
1394
1393
if (!tstate) {
1395
1394
return NULL;
1396
1395
}
1397
- state->recursion_limit = Py_C_RECURSION_LIMIT * COMPILER_STACK_FRAME_SCALE;
1396
+ struct validator vstate;
1397
+ vstate.recursion_limit = Py_C_RECURSION_LIMIT * COMPILER_STACK_FRAME_SCALE;
1398
1398
int recursion_depth = Py_C_RECURSION_LIMIT - tstate->c_recursion_remaining;
1399
1399
starting_recursion_depth = recursion_depth * COMPILER_STACK_FRAME_SCALE;
1400
- state-> recursion_depth = starting_recursion_depth;
1400
+ vstate. recursion_depth = starting_recursion_depth;
<
1009E
code>1401 1401
1402
- PyObject *result = ast2obj_mod(state, t);
1402
+ PyObject *result = ast2obj_mod(state, &vstate, t);
1403
1403
1404
1404
/* Check that the recursion depth counting balanced correctly */
1405
- if (result && state-> recursion_depth != starting_recursion_depth) {
1405
+ if (result && vstate. recursion_depth != starting_recursion_depth) {
1406
1406
PyErr_Format(PyExc_SystemError,
1407
1407
"AST constructor recursion depth mismatch (before=%d, after=%d)",
1408
- starting_recursion_depth, state-> recursion_depth);
1408
+ starting_recursion_depth, vstate. recursion_depth);
1409
1409
return NULL;
1410
1410
}
1411
1411
return result;
@@ -1475,8 +1475,6 @@ def generate_ast_state(module_state, f):
1475
1475
f .write ('struct ast_state {\n ' )
1476
1476
f .write (' _PyOnceFlag once;\n ' )
1477
1477
f .write (' int finalized;\n ' )
1478
- f .write (' int recursion_depth;\n ' )
1479
- f .write (' int recursion_limit;\n ' )
1480
1478
for s in module_state :
1481
1479
f .write (' PyObject *' + s + ';\n ' )
1482
1480
f .write ('};' )
@@ -1539,6 +1537,11 @@ def generate_module_def(mod, metadata, f, internal_h):
1539
1537
#include "pycore_pystate.h" // _PyInterpreterState_GET()
1540
1538
#include <stddef.h>
1541
1539
1540
+ struct validator {
1541
+ int recursion_depth; /* current recursion depth */
1542
+ int recursion_limit; /* recursion limit */
1543
+ };
1544
+
1542
1545
// Forward declaration
1543
1546
static int init_types(struct ast_state *state);
1544
1547
0 commit comments