8000 gh-89811: Check for valid tp_version_tag in specializer (GH-113558) · python/cpython@f653caa · GitHub
[go: up one dir, main page]

Skip to content

Commit f653caa

Browse files
authored
gh-89811: Check for valid tp_version_tag in specializer (GH-113558)
1 parent c65ae26 commit f653caa

File tree

4 files changed

+243
-3
lines changed

4 files changed

+243
-3
lines changed

Lib/test/test_type_cache.py

Lines changed: 183 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
""" Tests for the internal type cache in CPython. """
22
import unittest
3+
import dis
34
from test import support
45
from test.support import import_helper
56
try:
@@ -8,8 +9,11 @@
89
_clear_type_cache = None
910

1011
# Skip this test if the _testcapi module isn't available.
11-
type_get_version = import_helper.import_module('_testcapi').type_get_version
12-
type_assign_version = import_helper.import_module('_testcapi').type_assign_version
12+
_testcapi = import_helper.import_module("_testcapi")
13+
type_get_version = _testcapi.type_get_version
14+
type_assign_specific_version_unsafe = _testcapi.type_assign_specific_version_unsafe
15+
type_assign_version = _testcapi.type_assign_version
16+
type_modified = _testcapi.type_modified
1317

1418

1519
@support.cpython_only
@@ -56,6 +60,183 @@ class C:
5660
self.assertNotEqual(type_get_version(C), 0)
5761
self.assertNotEqual(type_get_version(C), c_ver)
5862

63+
def test_type_assign_specific_version(self):
64+
"""meta-test for type_assign_specific_version_unsafe"""
65+
class C:
66+
pass
67+
68+
type_assign_version(C)
69+
orig_version = type_get_version(C)
70+
self.assertNotEqual(orig_version, 0)
71+
72+
type_modified(C)
73+
type_assign_specific_version_unsafe(C, orig_version + 5)
74+
type_assign_version(C) # this should do nothing
75+
76+
new_version = type_get_version(C)
77+
self.assertEqual(new_version, orig_version + 5)
78+
79+
_clear_type_cache()
80+
81+
82+
@support.cpython_only
83+
class TypeCacheWithSpecializationTests(unittest.TestCase):
84+
def tearDown(self):
85+
_clear_type_cache()
86+
87+
def _assign_and_check_valid_version(self, user_type):
88+
type_modified(user_type)
89+
type_assign_version(user_type)
90+
self.assertNotEqual(type_get_version(user_type), 0)
91+
92+
def _assign_and_check_version_0(self, user_type):
93+
type_modified(user_type)
94+
type_assign_specific_version_unsafe(user_type, 0)
95+
self.assertEqual(type_get_version(user_type), 0)
96+
97+
def _all_opnames(self, func):
98+
return set(instr.opname for instr in dis.Bytecode(func, adaptive=True))
99+
100+
def _check_specialization(self, func, arg, opname, *, should_specialize):
101+
self.assertIn(opname, self._all_opnames(func))
102+
103+
for _ in range(100):
104+
func(arg)
105+
106+
if should_specialize:
107+
self.assertNotIn(opname, self._all_opnames(func))
108+
else:
109+
self.assertIn(opname, self._all_opnames(func))
110+
111+
def test_class_load_attr_specialization_user_type(self):
112+
class A:
113+
def foo(self):
114+
pass
115+
116+
self._assign_and_check_valid_version(A)
117+
118+
def load_foo_1(type_):
119+
type_.foo
120+
121+
self._check_specialization(load_foo_1, A, "LOAD_ATTR", should_specialize=True)
122+
del load_foo_1
123+
124+
self._assign_and_check_version_0(A)
125+
126+
def load_foo_2(type_):
127+
return type_.foo
128+
129+
self._check_specialization(load_foo_2, A, "LOAD_ATTR", should_specialize=False)
130+
131+
def test_class_load_attr_specialization_static_type(self):
132+
self._assign_and_check_valid_version(str)
133+
self._assign_and_check_valid_version(bytes)
134+
135+
def get_capitalize_1(type_):
136+
return type_.capitalize
137+
138+
self._check_specialization(get_capitalize_1, str, "LOAD_ATTR", should_specialize=True)
139+
self.assertEqual(get_capitalize_1(str)('hello'), 'Hello')
140+
self.assertEqual(get_capitalize_1(bytes)(b'hello'), b'Hello')
141+
del get_capitalize_1
142+
143+
# Permanently overflow the static type version counter, and force str and bytes
144+
# to have tp_version_tag == 0
145+
for _ in range(2**16):
146+
type_modified(str)
147+
type_assign_version(str)
148+
type_modified(bytes)
149+
type_assign_version(bytes)
150+
151+
self.assertEqual(type_get_version(str), 0)
152+
self.assertEqual(type_get_version(bytes), 0)
153+
154+
def get_capitalize_2(type_):
155+
return type_.capitalize
156+
157+
self._check_specialization(get_capitalize_2, str, "LOAD_ATTR", should_specialize=False)
158+
self.assertEqual(get_capitalize_2(str)('hello'), 'Hello')
159+
self.assertEqual(get_capitalize_2(bytes)(b'hello'), b'Hello')
160+
161+
def test_property_load_attr_specialization_user_type(self):
162+
class G:
163+
@property
164+
def x(self):
165+
return 9
166+
167+
self._assign_and_check_valid_version(G)
168+
169+
def load_x_1(instance):
170+
instance.x
171+
172+
self._check_specialization(load_x_1, G(), "LOAD_ATTR", should_specialize=True)
173+
del load_x_1
174+
175+
self._assign_and_check_version_0(G)
176+
177+
def load_x_2(instance):
178+
instance.x
179+
180+
self._check_specialization(load_x_2, G(), "LOAD_ATTR", should_specialize=False)
181+
182+
def test_store_attr_specialization_user_type(self):
183+
class B:
184+
__slots__ = ("bar",)
185+
186+
self._assign_and_check_valid_version(B)
187+
188+
def store_bar_1(type_):
189+
type_.bar = 10
190+
191+
self._check_specialization(store_bar_1, B(), "STORE_ATTR", should_specialize=True)
192+
del store_bar_1
193+
194+
self._assign_and_check_version_0(B)
195+
196+
def store_bar_2(type_):
197+
type_.bar = 10
198+
199+
self._check_specialization(store_bar_2, B(), "STORE_ATTR", should_specialize=False)
200+
201+
def test_class_call_specialization_user_type(self):
202+
class F:
203+
def __init__(self):
204+
pass
205+
206+
self._assign_and_check_valid_version(F)
207+
208+
def call_class_1(type_):
209+
type_()
210+
211+
self._check_specialization(call_class_1, F, "CALL", should_specialize=True)
212+
del call_class_1
213+
214+
self._assign_and_check_version_0(F)
215+
216+
def call_class_2(type_):
217+
type_()
218+
219+
self._check_specialization(call_class_2, F, "CALL", should_specialize=False)
220+
221+
def test_to_bool_specialization_user_type(self):
222+
class H:
223+
pass
224+
225+
self._assign_and_check_valid_version(H)
226+
227+
def to_bool_1(instance):
228+
not instance
229+
230+
self._check_specialization(to_bool_1, H(), "TO_BOOL", should_specialize=True)
231+
del to_bool_1
232+
233+
self._assign_and_check_version_0(H)
234+
235+
def to_bool_2(instance):
236+
not instance
237+
238+
self._check_specialization(to_bool_2, H(), "TO_BOOL", should_specialize=False)
239+
59240

60241
if __name__ == "__main__":
61242
unittest.main()
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
Check for a valid ``tp_version_tag`` before performing bytecode specializations that
2+
rely on this value being usable.

Modules/_testcapimodule.c

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2409,6 +2409,32 @@ type_get_version(PyObject *self, PyObject *type)
24092409
return res;
24102410
}
24112411

2412+
static PyObject *
2413+
type_modified(PyObject *self, PyObject *type)
2414+
{
2415+
if (!PyType_Check(type)) {
2416+
PyErr_SetString(PyExc_TypeError, "argument must be a type");
2417+
return NULL;
2418+
}
2419+
PyType_Modified((PyTypeObject *)type);
2420+
Py_RETURN_NONE;
2421+
}
2422+
2423+
// Circumvents standard version assignment machinery - use with caution and only on
2424+
// short-lived heap types
2425+
static PyObject *
2426+
type_assign_specific_version_unsafe(PyObject *self, PyObject *args)
2427+
{
2428+
PyTypeObject *type;
2429+
unsigned int version;
2430+
if (!PyArg_ParseTuple(args, "Oi:type_assign_specific_version_unsafe", &type, &version)) {
2431+
return NULL;
2432+
}
2433+
assert(!PyType_HasFeature(type, Py_TPFLAGS_IMMUTABLETYPE));
2434+
type->tp_version_tag = version;
2435+
type->tp_flags |= Py_TPFLAGS_VALID_VERSION_TAG;
2436+
Py_RETURN_NONE;
2437+
}
24122438

24132439
static PyObject *
24142440
type_assign_version(PyObject *self, PyObject *type)
@@ -3342,6 +3368,9 @@ static PyMethodDef TestMethods[] = {
33423368
{"test_py_is_macros", test_py_is_macros, METH_NOARGS},
33433369
{"test_py_is_funcs", test_py_is_funcs, METH_NOARGS},
33443370
{"type_get_version", type_get_version, METH_O, PyDoc_STR("type->tp_version_tag")},
3371+
{"type_modified", type_modified, METH_O, PyDoc_STR("PyType_Modified")},
3372+
{"type_assign_specific_version_unsafe", type_assign_specific_version_unsafe, METH_VARARGS,
3373+
PyDoc_STR("forcefully assign type->tp_version_tag")},
33453374
{"type_assign_version", type_assign_version, METH_O, PyDoc_STR("PyUnstable_Type_AssignVersionTag")},
33463375
{"type_get_tp_bases", type_get_tp_bases, METH_O},
33473376
{"type_get_tp_mro", type_get_tp_mro, METH_O},

Python/specialize.c

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -586,6 +586,7 @@ _PyCode_Quicken(PyCodeObject *code)
586586
static int function_kind(PyCodeObject *code);
587587
static bool function_check_args(PyObject *o, int expected_argcount, int opcode);
588588
static uint32_t function_get_version(PyObject *o, int opcode);
589+
static uint32_t type_get_version(PyTypeObject *t, int opcode);
589590

590591
static int
591592
specialize_module_load_attr(
@@ -874,6 +875,9 @@ _Py_Specialize_LoadAttr(PyObject *owner, _Py_CODEUNIT *instr, PyObject *name)
874875
PyObject *descr = NULL;
875876
DescriptorClassification kind = analyze_descriptor(type, name, &descr, 0);
876877
assert(descr != NULL || kind == ABSENT || kind == GETSET_OVERRIDDEN);
878+
if (type_get_version(type, LOAD_ATTR) == 0) {
879+
goto fail;
880+
}
877881
switch(kind) {
878882
case OVERRIDING:
879883
SPECIALIZATION_FAIL(LOAD_ATTR, SPEC_FAIL_ATTR_OVERRIDING_DESCRIPTOR);
@@ -1057,6 +1061,9 @@ _Py_Specialize_StoreAttr(PyObject *owner, _Py_CODEUNIT *instr, PyObject *name)
10571061
}
10581062
PyObject *descr;
10591063
DescriptorClassification kind = analyze_descriptor(type, name, &descr, 1);
1064+
if (type_get_version(type, STORE_ATTR) == 0) {
1065+
goto fail;
1066+
}
10601067
switch(kind) {
10611068
case OVERRIDING:
10621069
SPECIALIZATION_FAIL(STORE_ATTR, SPEC_FAIL_ATTR_OVERRIDING_DESCRIPTOR);
@@ -1183,6 +1190,9 @@ specialize_class_load_attr(PyObject *owner, _Py_CODEUNIT *instr,
11831190
PyObject *descr = NULL;
11841191
DescriptorClassification kind = 0;
11851192
kind = analyze_descriptor((PyTypeObject *)owner, name, &descr, 0);
1193+
if (type_get_version((PyTypeObject *)owner, LOAD_ATTR) == 0) {
1194+
return -1;
1195+
}
11861196
switch (kind) {
11871197
case METHOD:
11881198
case NON_DESCRIPTOR:
@@ -1455,6 +1465,18 @@ function_get_version(PyObject *o, int opcode)
14551465
return version;
14561466
}
14571467

1468+
/* Returning 0 indicates a failure. */
1469+
static uint32_t
1470+
type_get_version(PyTypeObject *t, int opcode)
1471+
{
1472+
uint32_t version = t->tp_version_tag;
1473+
if (version == 0) {
1474+
SPECIALIZATION_FAIL(opcode, SPEC_FAIL_OUT_OF_VERSIONS);
1475+
return 0;
1476+
}
1477+
return version;
1478+
}
1479+
14581480
void
14591481
_Py_Specialize_BinarySubscr(
14601482
PyObject *container, PyObject *sub, _Py_CODEUNIT *instr)
@@ -1726,6 +1748,9 @@ specialize_class_call(PyObject *callable, _Py_CODEUNIT *instr, int nargs)
17261748
}
17271749
if (tp->tp_new == PyBaseObject_Type.tp_new) {
17281750
PyFunctionObject *init = get_init_for_simple_managed_python_class(tp);
1751+
if (type_get_version(tp, CALL) == 0) {
1752+
return -1;
1753+
}
17291754
if (init != NULL) {
17301755
if (((PyCodeObject *)init->func_code)->co_argcount != nargs+1) {
17311756
SPECIALIZATION_FAIL(CALL, SPEC_FAIL_WRONG_NUMBER_ARGUMENTS);
@@ -2466,7 +2491,10 @@ _Py_Specialize_ToBool(PyObject *value, _Py_CODEUNIT *instr)
24662491
SPECIALIZATION_FAIL(TO_BOOL, SPEC_FAIL_OUT_OF_VERSIONS);
24672492
goto failure;
24682493
}
2469-
uint32_t version = Py_TYPE(value)->tp_version_tag;
2494+
uint32_t version = type_get_version(Py_TYPE(value), TO_BOOL);
2495+
if (version == 0) {
2496+
goto failure;
2497+
}
24702498
instr->op.code = TO_BOOL_ALWAYS_TRUE;
24712499
write_u32(cache->version, version);
24722500
assert(version);

0 commit comments

Comments
 (0)
0