8000 Fix PyObject_GenericGetDict, PyObject_GenericSetDict, and add test cases · python/cpython@56a79e6 · GitHub
[go: up one dir, main page]

Skip to content

Commit 56a79e6

Browse files
committed
Fix PyObject_GenericGetDict, PyObject_GenericSetDict, and add test cases
1 parent ae6d925 commit 56a79e6

File tree

3 files changed

+206
-51
lines changed

3 files changed

+206
-51
lines changed
Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
import gc
2+
import time
3+
import unittest
4+
import weakref
5+
6+
from ast import Or
7+
from functools import partial
8+
from threading import Thread
9+
from unittest import TestCase
10+
11+
from test.support import is_wasi
12+
13+
14+
@unittest.skipIf(is_wasi, "WASI has no threads.")
15+
class TestDict(TestCase):
16+
def test_racing_creation_shared_keys(self):
17+
"""Verify that creating dictionaries is thread safe when we
18+
have a type with shared keys"""
19+
class C(int):
20+
pass
21+
22+
self.racing_creation(C)
23+
24+
def test_racing_creation_no_shared_keys(self):
25+
"""Verify that creating dictionaries is thread safe when we
26+
have a type with an ordinary dict"""
27+
self.racing_creation(Or)
28+
29+
def test_racing_creation_inline_values_invalid(self):
30+
"""Verify that re-creating a dict after we have invalid inline values
31+
is thread safe"""
32+
class C:
33+
pass
34+
35+
def make_obj():
36+
a = C()
37+
# Make object, make inline values invalid, and then delete dict
38+
a.__dict__ = {}
39+
del a.__dict__
40+
return a
41+
42+
self.racing_creation(make_obj)
43+
44+
def test_racing_creation_nonmanaged_dict(self):
45+
"""Verify that explicit creation of an unmanaged dict is thread safe
46+
outside of the normal attribute setting code path"""
47+
def make_obj():
48+
def f(): pass
49+
return f
50+
51+
def set(func, name, val):
52+
# Force creation of the dict via PyObject_GenericGetDict
53+
func.__dict__[name] = val
54+
55+
self.racing_creation(make_obj, set)
56+
57+
def racing_creation(self, cls, set=setattr):
58+
objects = []
59+
processed = []
60+
61+
OBJECT_COUNT = 100
62+
THREAD_COUNT = 10
63+
CUR = 0
64+
65+
for i in range(OBJECT_COUNT):
66+
objects.append(cls())
67+
68+
def writer_func(name):
69+
last = -1
70+
while True:
71+
if CUR == last:
72+
continue
73+
elif CUR == OBJECT_COUNT:
74+
break
75+
76+
obj = objects[CUR]
77+
set(obj, name, name)
78+
last = CUR
79+
processed.append(name)
80+
81+
writers = []
82+
for x in range(THREAD_COUNT):
83+
writer = Thread(target=partial(writer_func, f"a{x:02}"))
84+
writers.append(writer)
85+
writer.start()
86+
87+
for i in range(OBJECT_COUNT):
88+
CUR = i
89+
while len(processed) != THREAD_COUNT:
90+
time.sleep(0.001)
91+
processed.clear()
92+
93+
CUR = OBJECT_COUNT
94+
95+
for writer in writers:
96+
writer.join()
97+
98+
for obj_idx, obj in enumerate(objects):
99+
assert (
100+
len(obj.__dict__) == THREAD_COUNT
101+
), f"{len(obj.__dict__)} {obj.__dict__!r} {obj_idx}"
102+
for i in range(THREAD_COUNT):
103+
assert f"a{i:02}" in obj.__dict__, f"a{i:02} missing at {obj_idx}"
104+
105+
def test_racing_set_dict(self):
106+
"""Races assigning to __dict__ should be thread safe"""
107+
108+
def f(): pass
109+
l = []
110+
THREAD_COUNT = 10
111+
class MyDict(dict): pass
112+
113+
def writer_func(l):
114+
for i in range(1000):
115+
d = MyDict()
116+
l.append(weakref.ref(d))
117+
f.__dict__ = d
118+
119+
lists = []
120+
writers = []
121+
for x in range(THREAD_COUNT):
122+
thread_list = []
123+
lists.append(thread_list)
124+
writer = Thread(target=partial(writer_func, thread_list))
125+
writers.append(writer)
126+
127+
for writer in writers:
128+
writer.start()
129+
130+
for writer in writers:
131+
writer.join()
132+
133+
f.__dict__ = {}
134+
gc.collect()
135+
136+
for thread_list in lists:
137+
for ref in thread_list:
138+
self.assertIsNone(ref())
139+
140+
if __name__ == "__main__":
141+
unittest.main()

Objects/dictobject.c

Lines changed: 61 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -7159,70 +7159,47 @@ _PyDict_DetachFromObject(PyDictObject *mp, PyObject *obj)
71597159
return 0;
71607160
}
71617161

7162-
PyObject *
7163-
PyObject_GenericGetDict(PyObject *obj, void *context)
7162+
static inline PyObject *
7163+
ensure_managed_dict(PyObject *obj)
71647164
{
7165-
PyInterpreterState *interp = _PyInterpreterState_GET();
7166-
PyTypeObject *tp = Py_TYPE(obj);
7167-
PyDictObject *dict;
7168-
if (_PyType_HasFeature(tp, Py_TPFLAGS_MANAGED_DICT)) {
7169-
dict = _PyObject_GetManagedDict(obj);
7170-
if (dict == NULL &&
7171-
(tp->tp_flags & Py_TPFLAGS_INLINE_VALUES) &&
7165+
PyDictObject *dict = _PyObject_GetManagedDict(obj);
7166+
if (dict == NULL) {
7167+
PyTypeObject *tp = Py_TYPE(obj);
7168+
if ((tp->tp_flags & Py_TPFLAGS_INLINE_VALUES) &&
71727169
FT_ATOMIC_LOAD_UINT8(_PyObject_InlineValues(obj)->valid)) {
71737170
dict = _PyObject_MaterializeManagedDict(obj);
71747171
}
7175-
else if (dict == NULL) {
7176-
Py_BEGIN_CRITICAL_SECTION(obj);
7177-
7172+
else {
7173+
#ifdef Py_GIL_DISABLED
71787174
// Check again that we're not racing with someone else creating the dict
7175+
Py_BEGIN_CRITICAL_SECTION(obj);
71797176
dict = _PyObject_GetManagedDict(obj);
7180-
if (dict == NULL) {
7181-
OBJECT_STAT_INC(dict_materialized_on_request);
7182-
dictkeys_incref(CACHED_KEYS(tp));
7183-
dict = (PyDictObject *)new_dict_with_shared_keys(interp, CACHED_KEYS(tp));
7184-
FT_ATOMIC_STORE_PTR_RELEASE(_PyObject_ManagedDictPointer(obj)->dict,
7185-
(PyDictObject *)dict);
7177+
if (dict != NULL) {
7178+
goto done;
71867179
}
7180+
#endif
7181+
OBJECT_STAT_INC(dict_materialized_on_request);
7182+
dictkeys_incref(CACHED_KEYS(tp));
7183+
dict = (PyDictObject *)new_dict_with_shared_keys(_PyInterpreterState_GET(),
7184+
CACHED_KEYS(tp));
7185+
FT_ATOMIC_STORE_PTR_RELEASE(_PyObject_ManagedDictPointer(obj)->dict,
7186+
(PyDictObject *)dict);
71877187

7188+
#ifdef Py_GIL_DISABLED
7189+
done:
71887190
Py_END_CRITICAL_SECTION();
7191+
#endif
71897192
}
7190-
return Py_XNewRef((PyObject *)dict);
7191-
}
7192-
else {
7193-
PyObject **dictptr = _PyObject_ComputedDictPointer(obj);
7194-
if (dictptr == NULL) {
7195-
PyErr_SetString(PyExc_AttributeError,
7196-
"This object has no __dict__");
7197-
return NULL;
7198-
}
7199-
PyObject *dict = *dictptr;
7200-
if (dict == NULL) {
7201-
PyTypeObject *tp = Py_TYPE(obj);
7202-
if (_PyType_HasFeature(tp, Py_TPFLAGS_HEAPTYPE) && CACHED_KEYS(tp)) {
7203-
dictkeys_incref(CACHED_KEYS(tp));
7204-
*dictptr = dict = new_dict_with_shared_keys(
7205-
interp, CACHED_KEYS(tp));
7206-
}
7207-
else {
7208-
*dictptr = dict = PyDict_New();
7209-
}
7210-
}
7211-
return Py_XNewRef(dict);
72127193
}
7194+
return (PyObject *)dict;
72137195
}
72147196

7215-
int
7216-
_PyObjectDict_SetItem(PyTypeObject *tp, PyObject *obj, PyObject **dictptr,
7217-
PyObject *key, PyObject *value)
7197+
static inline PyObject *
7198+
ensure_nonmanaged_dict(PyObject *obj, PyObject **dictptr)
72187199
{
7219-
PyObject *dict;
7220-
int res;
72217200
PyDictKeysObject *cached;
7222-
PyInterpreterState *interp = _PyInterpreterState_GET();
72237201

7224-
assert(dictptr != NULL);
7225-
dict = *dictptr;
7202+
PyObject *dict = FT_ATOMIC_LOAD_PTR_RELAXED(*dictptr);
72267203
if (dict == NULL) {
72277204
#ifdef Py_GIL_DISABLED
72287205
Py_BEGIN_CRITICAL_SECTION(obj);
@@ -7231,7 +7208,9 @@ _PyObjectDict_SetItem(PyTypeObject *tp, PyObject *obj, PyObject **dictptr,
72317208
goto done;
72327209
}
72337210
#endif
7211+
PyTypeObject *tp = Py_TYPE(obj);
72347212
if ((tp->tp_flags & Py_TPFLAGS_HEAPTYPE) && (cached = CACHED_KEYS(tp))) {
7213+
PyInterpreterState *interp = _PyInterpreterState_GET();
72357214
assert(!_PyType_HasFeature(tp, Py_TPFLAGS_INLINE_VALUES));
72367215
dictkeys_incref(cached);
72377216
dict = new_dict_with_shared_keys(interp, cached);
@@ -7242,14 +7221,45 @@ _PyObjectDict_SetItem(PyTypeObject *tp, PyObject *obj, PyObject **dictptr,
72427221
else {
72437222
dict = PyDict_New();
72447223
}
7245-
*dictptr = dict;
7224+
FT_ATOMIC_STORE_PTR_RELAXED(*dictptr, dict);
72467225
#ifdef Py_GIL_DISABLED
72477226
done:
72487227
Py_END_CRITICAL_SECTION();
72497228
#endif
7250-
if (dict == NULL) {
7251-
return -1;
7229+
}
7230+
return dict;
7231+
}
7232+
7233+
PyObject *
7234+
PyObject_GenericGetDict(PyObject *obj, void *context)
7235+
{
7236+
PyTypeObject *tp = Py_TYPE(obj);
7237+
if (_PyType_HasFeature(tp, Py_TPFLAGS_MANAGED_DICT)) {
7238+
return Py_XNewRef(ensure_managed_dict(obj));
7239+
}
7240+
else {
7241+
PyObject **dictptr = _PyObject_ComputedDictPointer(obj);
7242+
if (dictptr == NULL) {
7243+
PyErr_SetString(PyExc_AttributeError,
7244+
"This object has no __dict__");
7245+
return NULL;
72527246
}
7247+
7248+
return Py_XNewRef(ensure_nonmanaged_dict(obj, dictptr));
7249+
}
7250+
}
7251+
7252+
int
7253+
_PyObjectDict_SetItem(PyTypeObject *tp, PyObject *obj, PyObject **dictptr,
7254+
PyObject *key, PyObject *value)
7255+
{
7256+
PyObject *dict;
7257+
int res;
7258+
7259+
assert(dictptr != NULL);
7260+
dict = ensure_nonmanaged_dict(obj, dictptr);
7261+
if (dict == NULL) {
7262+
return -1;
72537263
}
72547264

72557265
Py_BEGIN_CRITICAL_SECTION(dict);

Objects/object.c

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1797,7 +1797,11 @@ PyObject_GenericSetDict(PyObject *obj, PyObject *value, void *context)
17971797
"not a '%.200s'", Py_TYPE(value)->tp_name);
17981798
return -1;
17991799
}
1800+
#ifdef Py_GIL_DISABLED
1801+
Py_XDECREF(_Py_atomic_exchange_ptr(dictptr, Py_NewRef(value)));
1802+
#else
18001803
Py_XSETREF(*dictptr, Py_NewRef(value));
1804+
#endif
18011805
return 0;
18021806
}
18031807

0 commit comments

Comments
 (0)
0