8000 gh-112075: Fix race in constructing dict for instance (#118499) · python/cpython@636b8d9 · GitHub
[go: up one dir, main page]

Skip to content

Commit 636b8d9

Browse files
authored
gh-112075: Fix race in constructing dict for instance (#118499)
1 parent 430945d commit 636b8d9

File tree

4 files changed

+216
-71
lines changed

4 files changed

+216
-71
lines changed

Include/internal/pycore_dict.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,10 +105,10 @@ PyAPI_FUNC(PyObject *)_PyDict_LoadGlobal(PyDictObject *, PyDictObject *, PyObjec
105105

106106
/* Consumes references to key and value */
107107
PyAPI_FUNC(int) _PyDict_SetItem_Take2(PyDictObject *op, PyObject *key, PyObject *value);
108-
extern int _PyObjectDict_SetItem(PyTypeObject *tp, PyObject **dictptr, PyObject *name, PyObject *value);
109108
extern int _PyDict_SetItem_LockHeld(PyDictObject *dict, PyObject *name, PyObject *value);
110109
extern int _PyDict_GetItemRef_Unicode_LockHeld(PyDictObject *op, PyObject *key, PyObject **result);
111110
extern int _PyDict_GetItemRef_KnownHash(PyDictObject *op, PyObject *key, Py_hash_t hash, PyObject **result);
111+
extern int _PyObjectDict_SetItem(PyTypeObject *tp, PyObject *obj, PyObject **dictptr, PyObject *name, PyObject *value);
112112

113113
extern int _PyDict_Pop_KnownHash(
114114
PyDictObject *dict,
Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
import gc
2+
import time
3+
import unittest
4+
import weakref
5+
6+
from ast import Or
7+
from functools import partial
8+
from threading import Thread
9+
from unittest import TestCase
10+
11+
from test.support import threading_helper
12+
13+
14+
@threading_helper.requires_working_threading()
15+
class TestDict(TestCase):
16+
def test_racing_creation_shared_keys(self):
17+
"""Verify that creating dictionaries is thread safe when we
18+
have a type with shared keys"""
19+
class C(int):
20+
pass
21+
22+
self.racing_creation(C)
23+
24+
def test_racing_creation_no_shared_keys(self):
25+
"""Verify that creating dictionaries is thread safe when we
26+
have a type with an ordinary dict"""
27+
self.racing_creation(Or)
28+
29+
def test_racing_creation_inline_values_invalid(self):
30+
"""Verify that re-creating a dict after we have invalid inline values
31+
is thread safe"""
32+
class C:
33+
pass
34+
35+
def make_obj():
36+
a = C()
37+
# Make object, make inline values invalid, and then delete dict
38+
a.__dict__ = {}
39+
del a.__dict__
40+
return a
41+
42+
self.racing_creation(make_obj)
43+
44+
def test_racing_creation_nonmanaged_dict(self):
45+
"""Verify that explicit creation of an unmanaged dict is thread safe
46+
outside of the normal attribute setting code path"""
47+
def make_obj():
48+
def f(): pass
49+
return f
50+
51+
def set(func, name, val):
52+
# Force creation of the dict via PyObject_GenericGetDict
53+
func.__dict__[name] = val
54+
55+
self.racing_creation(make_obj, set)
56+
57+
def racing_creation(self, cls, set=setattr):
58+
objects = []
59+
processed = []
60+
61+
OBJECT_COUNT = 100
62+
THREAD_COUNT = 10
63+
CUR = 0
64+
65+
for i in range(OBJECT_COUNT):
66+
objects.append(cls())
67+
68+
def writer_func(name):
69+
last = -1
70+
while True:
71+
if CUR == last:
72+
continue
73+
elif CUR == OBJECT_COUNT:
74+
break
75+
76+
obj = objects[CUR]
77+
set(obj, name, name)
78+
last = CUR
79+
processed.append(name)
80+
81+
writers = []
82+
for x in range(THREAD_COUNT):
83+
writer = Thread(target=partial(writer_func, f"a{x:02}"))
84+
writers.append(writer)
85+
writer.start()
86+
87+
for i in range(OBJECT_COUNT):
88+
CUR = i
89+
while len(processed) != THREAD_COUNT:
90+
time.sleep(0.001)
91+
processed.clear()
92+
93+
CUR = OBJECT_COUNT
94+
95+
for writer in writers:
96+
writer.join()
97+
98+
for obj_idx, obj in enumerate(objects):
99+
assert (
100+
len(obj.__dict__) == THREAD_COUNT
101+
), f"{len(obj.__dict__)} {obj.__dict__!r} {obj_idx}"
102+
for i in range(THREAD_COUNT):
103+
assert f"a{i:02}" in obj.__dict__, f"a{i:02} missing at {obj_idx}"
104+
105+
def test_racing_set_dict(self):
106+
"""Races assigning to __dict__ should be thread safe"""
107+
108+
def f(): pass
109+
l = []
110+
THREAD_COUNT = 10
111+
class MyDict(dict): pass
112+
113+
def writer_func(l):
114+
for i in range(1000):
115+
d = MyDict()
116+
l.append(weakref.ref(d))
117+
f.__dict__ = d
118+
119+
lists = []
120+
writers = []
121+
for x in range(THREAD_COUNT):
122+
thread_list = []
123+
lists.append(thread_list)
124+
writer = Thread(target=partial(writer_func, thread_list))
125+
writers.append(writer)
126+
127+
for writer in writers:
128+
writer.start()
129+
130+
for writer in writers:
131+
writer.join()
132+
133+
f.__dict__ = {}
134+
gc.collect()
135+
136+
for thread_list in lists:
137+
for ref in thread_list:
138+
self.assertIsNone(ref())
139+
140+
if __name__ == "__main__":
141+
unittest.main()

Objects/dictobject.c

Lines changed: 71 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -924,16 +924,15 @@ new_dict(PyInterpreterState *interp,
924924
return (PyObject *)mp;
925925
}
926926

927-
/* Consumes a reference to the keys object */
928927
static PyObject *
929928
new_dict_with_shared_keys(PyInterpreterState *interp, PyDictKeysObject *keys)
930929
{
931930
size_t size = shared_keys_usable_size(keys);
932931
PyDictValues *values = new_values(size);
933932
if (values == NULL) {
934-
dictkeys_decref(interp, keys, false);
935933
return PyErr_NoMemory();
936934
}
935+
dictkeys_incref(keys);
937936
for (size_t i = 0; i < size; i++) {
938937
values->values[i] = NULL;
939938
}
@@ -6693,8 +6692,6 @@ materialize_managed_dict_lock_held(PyObject *obj)
66936692
{
66946693
_Py_CRITICAL_SECTION_ASSERT_OBJECT_LOCKED(obj);
66956694

6696-
OBJECT_STAT_INC(dict_materialized_on_request);
6697-
66986695
PyDictValues *values = _PyObject_InlineValues(obj);
66996696
PyInterpreterState *interp = _PyInterpreterState_GET();
67006697
PyDictKeysObject *keys = CACHED_KEYS(Py_TYPE(obj));
@@ -7186,35 +7183,77 @@ _PyDict_DetachFromObject(PyDictObject *mp, PyObject *obj)
71867183
return 0;
71877184
}
71887185

7189-
PyObject *
7190-
PyObject_GenericGetDict(PyObject *obj, void *context)
7186+
static inline PyObject *
7187+
ensure_managed_dict(PyObject *obj)
71917188
{
7192-
PyInterpreterState *interp = _PyInterpreterState_GET();
7193-
PyTypeObject *tp = Py_TYPE(obj);
7194-
PyDictObject *dict;
7195-
if (_PyType_HasFeature(tp, Py_TPFLAGS_MANAGED_DICT)) {
7196-
dict = _PyObject_GetManagedDict(obj);
7197-
if (dict == NULL &&
7198-
(tp->tp_flags & Py_TPFLAGS_INLINE_VALUES) &&
7189+
PyDictObject *dict = _PyObject_GetManagedDict(obj);
7190+
if (dict == NULL) {
7191+
PyTypeObject *tp = Py_TYPE(obj);
7192+
if ((tp->tp_flags & Py_TPFLAGS_INLINE_VALUES) &&
71997193
FT_ATOMIC_LOAD_UINT8(_PyObject_InlineValues(obj)->valid)) {
72007194
dict = _PyObject_MaterializeManagedDict(obj);
72017195
}
7202-
else if (dict == NULL) {
7203-
Py_BEGIN_CRITICAL_SECTION(obj);
7204-
7196+
else {
7197+
#ifdef Py_GIL_DISABLED
72057198
// Check again that we're not racing with someone else creating the dict
7199+
Py_BEGIN_CRITICAL_SECTION(obj);
72067200
dict = _PyObject_GetManagedDict(obj);
7207-
if (dict == NULL) {
7208-
OBJECT_STAT_INC(dict_materialized_on_request);
7209-
dictkeys_incref(CACHED_KEYS(tp));
7210-
dict = (PyDictObject *)new_dict_with_shared_keys(interp, CACHED_KEYS(tp));
7211-
FT_ATOMIC_STORE_PTR_RELEASE(_PyObject_ManagedDictPointer(obj)->dict,
7212-
(PyDictObject *)dict);
7201+
if (dict != NULL) {
7202+
goto done;
72137203
}
7204+
#endif
7205+
dict = (PyDictObject *)new_dict_with_shared_keys(_PyInterpreterState_GET(),
7206+
CACHED_KEYS(tp));
7207+
FT_ATOMIC_STORE_PTR_RELEASE(_PyObject_ManagedDictPointer(obj)->dict,
7208+
(PyDictObject *)dict);
72147209

7210+
#ifdef Py_GIL_DISABLED
7211+
done:
72157212
Py_END_CRITICAL_SECTION();
7213+
#endif
72167214
}
7217-
return Py_XNewRef((PyObject *)dict);
7215+
}
7216+
return (PyObject *)dict;
7217+
}
7218+
7219+
static inline PyObject *
7220+
ensure_nonmanaged_dict(PyObject *obj, PyObject **dictptr)
7221+
{
7222+
PyDictKeysObject *cached;
7223+
7224+
PyObject *dict = FT_ATOMIC_LOAD_PTR_ACQUIRE(*dictptr);
7225+
if (dict == NULL) {
7226+
#ifdef Py_GIL_DISABLED
7227+
Py_BEGIN_CRITICAL_SECTION(obj);
7228+
dict = *dictptr;
7229+
if (dict != NULL) {
7230+
goto done;
7231+
}
7232+
#endif
7233+
PyTypeObject *tp = Py_TYPE(obj);
7234+
if (_PyType_HasFeature(tp, Py_TPFLAGS_HEAPTYPE) && (cached = CACHED_KEYS(tp))) {
7235+
PyInterpreterState *interp = _PyInterpreterState_GET();
7236+
assert(!_PyType_HasFeature(tp, Py_TPFLAGS_INLINE_VALUES));
7237+
dict = new_dict_with_shared_keys(interp, cached);
7238+
}
7239+
else {
7240+
dict = PyDict_New();
7241+
}
7242+
FT_ATOMIC_STORE_PTR_RELEASE(*dictptr, dict);
7243+
#ifdef Py_GIL_DISABLED
7244+
done:
7245+
Py_END_CRITICAL_SECTION();
7246+
#endif
7247+
}
7248+
return dict;
7249+
}
7250+
7251+
PyObject *
7252+
PyObject_GenericGetDict(PyObject *obj, void *context)
7253+
{
7254+
PyTypeObject *tp = Py_TYPE(obj);
7255+
if (_PyType_HasFeature(tp, Py_TPFLAGS_MANAGED_DICT)) {
7256+
return Py_XNewRef(ensure_managed_dict(obj));
72187257
}
72197258
else {
72207259
PyObject **dictptr = _PyObject_ComputedDictPointer(obj);
@@ -7223,65 +7262,28 @@ PyObject_GenericGetDict(PyObject *obj, void *context)
72237262
"This object has no __dict__");
72247263
return NULL;
72257264
}
7226-
PyObject *dict = *dictptr;
7227-
if (dict == NULL) {
7228-
PyTypeObject *tp = Py_TYPE(obj);
7229-
if (_PyType_HasFeature(tp, Py_TPFLAGS_HEAPTYPE) && CACHED_KEYS(tp)) {
7230-
dictkeys_incref(CACHED_KEYS(tp));
7231-
*dictptr = dict = new_dict_with_shared_keys(
7232-
interp, CACHED_KEYS(tp));
7233-
}
7234-
else {
7235-
*dictptr = dict = PyDict_New();
7236-
}
7237-
}
7238-
return Py_XNewRef(dict);
7265+
7266+
return Py_XNewRef(ensure_nonmanaged_dict(obj, dictptr));
72397267
}
72407268
}
72417269

72427270
int
7243-
_PyObjectDict_SetItem(PyTypeObject *tp, PyObject **dictptr,
7271+
_PyObjectDict_SetItem(PyTypeObject *tp, PyObject *obj, PyObject **dictptr,
72447272
PyObject *key, PyObject *value)
72457273
{
72467274
PyObject *dict;
72477275
int res;
7248-
PyDictKeysObject *cached;
7249-
PyInterpreterState *interp = _PyInterpreterState_GET();
72507276

72517277
assert(dictptr != NULL);
7252-
if ((tp->tp_flags & Py_TPFLAGS_HEAPTYPE) && (cached = CACHED_KEYS(tp))) {
7253-
assert(dictptr != NULL);
7254-
dict = *dictptr;
7255-
if (dict == NULL) {
7256-
assert(!_PyType_HasFeature(tp, Py_TPFLAGS_INLINE_VALUES));
7257-
dictkeys_incref(cached);
7258-
dict = new_dict_with_shared_keys(interp, cached);
7259-
if (dict == NULL)
7260-
return -1;
7261-
*dictptr = dict;
7262-
}
7263-
if (value == NULL) {
7264-
res = PyDict_DelItem(dict, key);
7265-
}
7266-
else {
7267-
res = PyDict_SetItem(dict, key, value);
7268-
}
7269-
} else {
7270-
dict = *dictptr;
7271-
if (dict == NULL) {
7272-
dict = PyDict_New();
7273-
if (dict == NULL)
7274-
return -1;
7275-
*dictptr = dict;
7276-
}
7277-
if (value == NULL) {
7278-
res = PyDict_DelItem(dict, key);
7279-
} else {
7280-
res = PyDict_SetItem(dict, key, value);
7281-
}
7278+
dict = ensure_nonmanaged_dict(obj, dictptr);
7279+
if (dict == NULL) {
7280+
return -1;
72827281
}
72837282

7283+
Py_BEGIN_CRITICAL_SECTION(dict);
7284+
res = _PyDict_SetItem_LockHeld((PyDictObject *)dict, key, value);
72847285
ASSERT_CONSISTENT(dict);
7286+
Py_END_CRITICAL_SECTION();
72857287
return res;
72867288
}
72877289

Objects/object.c

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1731,7 +1731,7 @@ _PyObject_GenericSetAttrWithDict(PyObject *obj, PyObject *name,
17311731
goto done;
17321732
}
17331733
else {
1734-
res = _PyObjectDict_SetItem(tp, dictptr, name, value);
1734+
res = _PyObjectDict_SetItem(tp, obj, dictptr, name, value);
17351735
}
17361736
}
17371737
else {
@@ -1789,7 +1789,9 @@ PyObject_GenericSetDict(PyObject *obj, PyObject *value, void *context)
17891789
"not a '%.200s'", Py_TYPE(value)->tp_name);
17901790
return -1;
17911791
}
1792+
Py_BEGIN_CRITICAL_SECTION(obj);
17921793
Py_XSETREF(*dictptr, Py_NewRef(value));
1794+
Py_END_CRITICAL_SECTION();
17931795
return 0;
17941796
}
17951797

0 commit comments

Comments
 (0)
0