10000 bpo-45648 / gh-89811: check for valid tp_version_tag in specializer · python/cpython@98480f0 · GitHub
[go: up one dir, main page]

Skip to content

Commit 98480f0

Browse files
committed
bpo-45648 / gh-89811: check for valid tp_version_tag in specializer
1 parent 6ca0e67 commit 98480f0

File tree

3 files changed

+147
-3
lines changed

3 files changed

+147
-3
lines changed

Lib/test/test_type_cache.py

Lines changed: 89 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
""" Tests for the internal type cache in CPython. """
22
import unittest
3+
import dis
34
from test import support
45
from test.support import import_helper
56
try:
@@ -8,8 +9,11 @@
89
_clear_type_cache = None
910

1011
# Skip this test if the _testcapi module isn't available.
11-
type_get_version = import_helper.import_module('_testcapi').type_get_version
12-
type_assign_version = import_helper.import_module('_testcapi').type_assign_version
12+
_testcapi = import_helper.import_module("_testcapi")
13+
type_get_version = _testcapi.type_get_version
14+
type_assign_specific_version_unsafe = _testcapi.type_assign_specific_version_unsafe
15+
type_assign_version = _testcapi.type_assign_version
16+
type_modified = _testcapi.type_modified
1317

1418

1519
@support.cpython_only
@@ -56,6 +60,89 @@ class C:
5660
self.assertNotEqual(type_get_version(C), 0)
5761
self.assertNotEqual(type_get_version(C), c_ver)
5862

63+
def test_type_assign_specific_version(self):
64+
"""meta-test for type_assign_specific_version_unsafe"""
65+
class C:
66+
pass
67+
68+
type_assign_version(C)
69+
orig_version = type_get_version(C)
70+
self.assertNotEqual(orig_version, 0)
71+
72+
type_modified(C)
73+
type_assign_specific_version_unsafe(C, orig_version + 5)
74+
type_assign_version(C) # this should do nothing
75+
76+
new_version = type_get_version(C)
77+
self.assertEqual(new_version, orig_version + 5)
78+
79+
def test_specialization_user_type_no_tag_overflow(self):
80+
class A:
81+
def foo(self):
82+
pass
83+
84+
class B:
85+
def foo(self):
86+
pass
87+
88+
type_modified(A)
89+
type_assign_version(A)
90+
type_modified(B)
91+
type_assign_version(B)
92+
self.assertNotEqual(type_get_version(A), 0)
93+
self.assertNotEqual(type_get_version(B), 0)
94+
self.assertNotEqual(type_get_version(A), type_get_version(B))
95+
96+
def get_foo(type_):
97+
return type_.foo
98+
99+
self.assertIn(
100+
"LOAD_ATTR",
101+
[instr.opname for instr in dis.Bytecode(get_foo, adaptive=True)],
102+
)
103+
104+
get_foo(A)
105+
get_foo(A)
106+
107+
# check that specialization has occurred
108+
self.assertNotIn(
109+
"LOAD_ATTR",
110+
[instr.opname for instr in dis.Bytecode(get_foo, adaptive=True)],
111+
)
112+
113+
def test_specialization_user_type_tag_overflow(self):
114+
class A:
115+
def foo(self):
116+
pass
117+
118+
class B:
119+
def foo(self):
120+
pass
121+
122+
type_modified(A)
123+
type_assign_specific_version_unsafe(A, 0)
124+
type_modified(B)
125+
type_assign_specific_version_unsafe(B, 0)
126+
self.assertEqual(type_get_version(A), 0)
127+
self.assertEqual(type_get_version(B), 0)
128+
129+
def get_foo(type_):
130+
return type_.foo
131+
132+
self.assertIn(
133+
"LOAD_ATTR",
134+
[instr.opname for instr in dis.Bytecode(get_foo, adaptive=True)],
135+
)
136+
137+
get_foo(A)
138+
get_foo(A)
139+
140+
# check that specialization has not occurred due to version tag == 0
141+
self.assertIn(
142+
"LOAD_ATTR",
143+
[instr.opname for instr in dis.Bytecode(get_foo, adaptive=True)],
144+
)
145+
59146

60147
if __name__ == "__main__":
61148
unittest.main()

Modules/_testcapimodule.c

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2392,6 +2392,32 @@ type_get_version(PyObject *self, PyObject *type)
23922392
return res;
23932393
}
23942394

2395+
static PyObject *
2396+
type_modified(PyObject *self, PyObject *type)
2397+
{
2398+
if (!PyType_Check(type)) {
2399+
PyErr_SetString(PyExc_TypeError, "argument must be a type");
2400+
return NULL;
2401+
}
2402+
PyType_Modified((PyTypeObject *)type);
2403+
Py_RETURN_NONE;
2404+
}
2405+
2406+
// Circumvents standard version assignment machinery - use with caution and only on
2407+
// short-lived heap types
2408+
static PyObject *
2409+
type_assign_specific_version_unsafe(PyObject *self, PyObject *args)
2410+
{
2411+
PyTypeObject *type;
2412+
unsigned int version;
2413+
if (!PyArg_ParseTuple(args, "Oi:type_assign_specific_version_unsafe", &type, &version)) {
2414+
return NULL;
2415+
}
2416+
assert(!PyType_HasFeature(type, Py_TPFLAGS_IMMUTABLETYPE));
2417+
type->tp_version_tag = version;
2418+
type->tp_flags |= Py_TPFLAGS_VALID_VERSION_TAG;
2419+
Py_RETURN_NONE;
2420+
}
23952421

23962422
static PyObject *
23972423
type_assign_version(PyObject *self, PyObject *type)
@@ -3325,6 +3351,9 @@ static PyMethodDef TestMethods[] = {
33253351
{"test_py_is_macros", test_py_is_macros, METH_NOARGS},
33263352
{"test_py_is_funcs", test_py_is_funcs, METH_NOARGS},
33273353
{"type_get_version", type_get_version, METH_O, PyDoc_STR("type->tp_version_tag")},
3354+
{"type_modified", type_modified, METH_O, PyDoc_STR("PyType_Modified")},
3355+
{"type_assign_specific_version_unsafe", type_assign_specific_version_unsafe, METH_VARARGS,
3356+
PyDoc_STR("forcefully assign type->tp_version_tag")},
33283357
{"type_assign_version", type_assign_version, METH_O, PyDoc_STR("PyUnstable_Type_AssignVersionTag")},
33293358
{"type_get_tp_bases", type_get_tp_bases, METH_O},
33303359
{"type_get_tp_mro", type_get_tp_mro, METH_O},

Python/specialize.c

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -586,6 +586,7 @@ _PyCode_Quicken(PyCodeObject *code)
586586
static int function_kind(PyCodeObject *code);
587587
static bool function_check_args(PyObject *o, int expected_argcount, int opcode);
588588
static uint32_t function_get_version(PyObject *o, int opcode);
589+
static uint32_t type_get_version(PyTypeObject *t, int opcode);
589590

590591
static int
591592
specialize_module_load_attr(
@@ -874,6 +875,9 @@ _Py_Specialize_LoadAttr(PyObject *owner, _Py_CODEUNIT *instr, PyObject *name)
874875
PyObject *descr = NULL;
875876
DescriptorClassification kind = analyze_descriptor(type, name, &descr, 0);
876877
assert(descr != NULL || kind == ABSENT || kind == GETSET_OVERRIDDEN);
878+
if (type_get_version(type, LOAD_ATTR) == 0) {
879+
goto fail;
880+
}
877881
switch(kind) {
878882
case OVERRIDING:
879883
SPECIALIZATION_FAIL(LOAD_ATTR, SPEC_FAIL_ATTR_OVERRIDING_DESCRIPTOR);
@@ -1057,6 +1061,9 @@ _Py_Specialize_StoreAttr(PyObject *owner, _Py_CODEUNIT *instr, PyObject *name)
10571061
}
10581062
PyObject *descr;
10591063
DescriptorClassification kind = analyze_descriptor(type, name, &descr, 1);
1064+
if (type_get_version(type, STORE_ATTR) == 0) {
1065+
goto fail;
1066+
}
10601067
switch(kind) {
10611068
case OVERRIDING:
10621069
SPECIALIZATION_FAIL(STORE_ATTR, SPEC_FAIL_ATTR_OVERRIDING_DESCRIPTOR);
@@ -1183,6 +1190,9 @@ specialize_class_load_attr(PyObject *owner, _Py_CODEUNIT *instr,
11831190
PyObject *descr = NULL;
11841191
DescriptorClassification kind = 0;
11851192
kind = analyze_descriptor((PyTypeObject *)owner, name, &descr, 0);
1193+
if (type_get_version((PyTypeObject *)owner, LOAD_ATTR) == 0) {
1194+
return -1;
1195+
}
11861196
switch (kind) {
11871197
case METHOD:
11881198
case NON_DESCRIPTOR:
@@ -1455,6 +1465,18 @@ function_get_version(PyObject *o, int opcode)
14551465
return version;
14561466
}
14571467

1468+
/* Returning 0 indicates a failure. */
1469+
static uint32_t
1470+
type_get_version(PyTypeObject *t, int opcode)
1471+
{
1472+
uint32_t version = t->tp_version_tag;
1473+
if (version == 0) {
1474+
SPECIALIZATION_FAIL(opcode, SPEC_FAIL_OUT_OF_VERSIONS);
1475+
return 0;
1476+
}
1477+
return version;
1478+
}
1479+
14581480
void
14591481
_Py_Specialize_BinarySubscr(
14601482
PyObject *container, PyObject *sub, _Py_CODEUNIT *instr)
@@ -1726,6 +1748,9 @@ specialize_class_call(PyObject *callable, _Py_CODEUNIT *instr, int nargs)
17261748
}
17271749
if (tp->tp_new == PyBaseObject_Type.tp_new) {
17281750
PyFunctionObject *init = get_init_for_simple_managed_python_class(tp);
1751+
if (type_get_version(tp, CALL) == 0) {
1752+
return -1;
1753+
}
17291754
if (init != NULL) {
17301755
if (((PyCodeObject *)init->func_code)->co_argcount != nargs+1) {
17311756
SPECIALIZATION_FAIL(CALL, SPEC_FAIL_WRONG_NUMBER_ARGUMENTS);
@@ -2466,7 +2491,10 @@ _Py_Specialize_ToBool(PyObject *value, _Py_CODEUNIT *instr)
24662491
SPECIALIZATION_FAIL(TO_BOOL, SPEC_FAIL_OUT_OF_VERSIONS);
24672492
goto failure;
24682493
}
2469-
uint32_t version = Py_TYPE(value)->tp_version_tag;
2494+
uint32_t version = type_get_version(Py_TYPE(value), TO_BOOL);
2495+
if (version == 0) {
2496+
goto failure;
2497+
}
24702498
instr->op.code = TO_BOOL_ALWAYS_TRUE;
24712499
write_u32(cache->version, version);
24722500
assert(version);

0 commit comments

Comments
 (0)
0