8000 Fix PyObject_GenericGetDict, PyObject_GenericSetDict, and add test cases · python/cpython@0404e6d · GitHub
[go: up one dir, main page]

Skip to content

Commit 0404e6d

Browse files
committed
Fix PyObject_GenericGetDict, PyObject_GenericSetDict, and add test cases
1 parent a59ebf7 commit 0404e6d

File tree

3 files changed

+206
-51
lines changed

3 files changed

+206
-51
lines changed
Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
import gc
2+
import time
3+
import unittest
4+
import weakref
5+
6+
from ast import Or
7+
from functools import partial
8+
from threading import Thread
9+
from unittest import TestCase
10+
11+
from test.support import is_wasi
12+
13+
14+
@unittest.skipIf(is_wasi, "WASI has no threads.")
15+
class TestDict(TestCase):
16+
def test_racing_creation_shared_keys(self):
17+
"""Verify that creating dictionaries is thread safe when we
18+
have a type with shared keys"""
19+
class C(int):
20+
pass
21+
22+
self.racing_creation(C)
23+
24+
def test_racing_creation_no_shared_keys(self):
25+
"""Verify that creating dictionaries is thread safe when we
26+
have a type with an ordinary dict"""
27+
self.racing_creation(Or)
28+
29+
def test_racing_creation_inline_values_invalid(self):
30+
"""Verify that re-creating a dict after we have invalid inline values
31+
is thread safe"""
32+
class C:
33+
pass
34+
35+
def make_obj():
36+
a = C()
37+
# Make object, make inline values invalid, and then delete dict
38+
a.__dict__ = {}
39+
del a.__dict__
40+
return a
41+
42+
self.racing_creation(make_obj)
43+
44+
def test_racing_creation_nonmanaged_dict(self):
45+
"""Verify that explicit creation of an unmanaged dict is thread safe
46+
outside of the normal attribute setting code path"""
47+
def make_obj():
48+
def f(): pass
49+
return f
50+
51+
def set(func, name, val):
52+
# Force creation of the dict via PyObject_GenericGetDict
53+
func.__dict__[name] = val
54+
55+
self.racing_creation(make_obj, set)
56+
57+
def racing_creation(self, cls, set=setattr):
58+
objects = []
59+
processed = []
60+
61+
OBJECT_COUNT = 100
62+
THREAD_COUNT = 10
63+
CUR = 0
64+
65+
for i in range(OBJECT_COUNT):
66+
objects.append(cls())
67+
68+
def writer_func(name):
69+
last = -1
70+
while True:
71+
if CUR == last:
72+
continue
73+
elif CUR == OBJECT_COUNT:
74+
break
75+
76+
obj = objects[CUR]
77+
set(obj, name, name)
78+
last = CUR
79+
processed.append(name)
80+
81+
writers = []
82+
for x in range(THREAD_COUNT):
83+
writer = Thread(target=partial(writer_func, f"a{x:02}"))
84+
writers.append(writer)
85+
writer.start()
86+
87+
for i in range(OBJECT_COUNT):
88+
CUR = i
89+
while len(processed) != THREAD_COUNT:
90+
time.sleep(0.001)
91+
processed.clear()
92+
93+
CUR = OBJECT_COUNT
94+
95+
for writer in writers:
96+
writer.join()
97+
98+
for obj_idx, obj in enumerate(objects):
99+
assert (
100+
len(obj.__dict__) == THREAD_COUNT
101+
), f"{len(obj.__dict__)} {obj.__dict__!r} {obj_idx}"
102+
for i in range(THREAD_COUNT):
103+
assert f"a{i:02}" in obj.__dict__, f"a{i:02} missing at {obj_idx}"
104+
105+
def test_racing_set_dict(self):
106+
"""Races assigning to __dict__ should be thread safe"""
107+
108+
def f(): pass
109+
l = []
110+
THREAD_COUNT = 10
111+
class MyDict(dict): pass
112+
113+
def writer_func(l):
114+
for i in range(1000):
115+
d = MyDict()
116+
l.append(weakref.ref(d))
117+
f.__dict__ = d
118+
119+
lists = []
120+
writers = []
121+
for x in range(THREAD_COUNT):
122+
thread_list = []
123+
lists.append(thread_list)
124+
writer = Thread(target=partial(writer_func, thread_list))
125+
writers.append(writer)
126+
127+
for writer in writers:
128+
writer.start()
129+
130+
for writer in writers:
131+
writer.join()
132+
133+
f.__dict__ = {}
134+
gc.collect()
135+
136+
for thread_list in lists:
137+
for ref in thread_list:
138+
self.assertIsNone(ref())
139+
140+
if __name__ == "__main__":
141+
unittest.main()

Objects/dictobject.c

Lines changed: 61 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -7186,70 +7186,47 @@ _PyDict_DetachFromObject(PyDictObject *mp, PyObject *obj)
71867186
return 0;
71877187
}
71887188

7189-
PyObject *
7190-
PyObject_GenericGetDict(PyObject *obj, void *context)
7189+
static inline PyObject *
7190+
ensure_managed_dict(PyObject *obj)
71917191
{
7192-
PyInterpreterState *interp = _PyInterpreterState_GET();
7193-
PyTypeObject *tp = Py_TYPE(obj);
7194-
PyDictObject *dict;
7195-
if (_PyType_HasFeature(tp, Py_TPFLAGS_MANAGED_DICT)) {
7196-
dict = _PyObject_GetManagedDict(obj);
7197-
if (dict == NULL &&
7198-
(tp->tp_flags & Py_TPFLAGS_INLINE_VALUES) &&
7192+
PyDictObject *dict = _PyObject_GetManagedDict(obj);
7193+
if (dict == NULL) {
7194+
PyTypeObject *tp = Py_TYPE(obj);
7195+
if ((tp->tp_flags & Py_TPFLAGS_INLINE_VALUES) &&
71997196
FT_ATOMIC_LOAD_UINT8(_PyObject_InlineValues(obj)->valid)) {
72007197
dict = _PyObject_MaterializeManagedDict(obj);
72017198
}
7202-
else if (dict == NULL) {
7203-
Py_BEGIN_CRITICAL_SECTION(obj);
7204-
7199+
else {
7200+
#ifdef Py_GIL_DISABLED
72057201
// Check again that we're not racing with someone else creating the dict
7202+
Py_BEGIN_CRITICAL_SECTION(obj);
72067203
dict = _PyObject_GetManagedDict(obj);
7207-
if (dict == NULL) {
7208-
OBJECT_STAT_INC(dict_materialized_on_request);
7209-
dictkeys_incref(CACHED_KEYS(tp));
7210-
dict = (PyDictObject *)new_dict_with_shared_keys(interp, CACHED_KEYS(tp));
7211-
FT_ATOMIC_STORE_PTR_RELEASE(_PyObject_ManagedDictPointer(obj)->dict,
7212-
(PyDictObject *)dict);
7204+
if (dict != NULL) {
7205+
goto done;
72137206
}
7207+
#endif
7208+
OBJECT_STAT_INC(dict_materialized_on_request);
7209+
dictkeys_incref(CACHED_KEYS(tp));
7210+
dict = (PyDictObject *)new_dict_with_shared_keys(_PyInterpreterState_GET(),
7211+
CACHED_KEYS(tp));
7212+
FT_ATOMIC_STORE_PTR_RELEASE(_PyObject_ManagedDictPointer(obj)->dict,
7213+
(PyDictObject *)dict);
72147214

7215+
#ifdef Py_GIL_DISABLED
7216+
done:
72157217
Py_END_CRITICAL_SECTION();
7218+
#endif
72167219
}
7217-
return Py_XNewRef((PyObject *)dict);
7218-
}
7219-
else {
7220-
PyObject **dictptr = _PyObject_ComputedDictPointer(obj);
7221-
if (dictptr == NULL) {
7222-
PyErr_SetString(PyExc_AttributeError,
7223-
"This object has no __dict__");
7224-
return NULL;
7225-
}
7226-
PyObject *dict = *dictptr;
7227-
if (dict == NULL) {
7228-
PyTypeObject *tp = Py_TYPE(obj);
7229-
if (_PyType_HasFeature(tp, Py_TPFLAGS_HEAPTYPE) && CACHED_KEYS(tp)) {
7230-
dictkeys_incref(CACHED_KEYS(tp));
7231-
*dictptr = dict = new_dict_with_shared_keys(
7232-
interp, CACHED_KEYS(tp));
7233-
}
7234-
else {
7235-
*dictptr = dict = PyDict_New();
7236-
}
7237-
}
7238-
return Py_XNewRef(dict);
72397220
}
7221+
return (PyObject *)dict;
72407222
}
72417223

7242-
int
7243-
_PyObjectDict_SetItem(PyTypeObject *tp, PyObject *obj, PyObject **dictptr,
7244-
PyObject *key, PyObject *value)
7224+
static inline PyObject *
7225+
ensure_nonmanaged_dict(PyObject *obj, PyObject **dictptr)
72457226
{
7246-
PyObject *dict;
7247-
int res;
72487227
PyDictKeysObject *cached;
7249-
PyInterpreterState *interp = _PyInterpreterState_GET();
72507228

7251-
assert(dictptr != NULL);
7252-
dict = *dictptr;
7229+
PyObject *dict = FT_ATOMIC_LOAD_PTR_RELAXED(*dictptr);
72537230
if (dict == NULL) {
72547231
#ifdef Py_GIL_DISABLED
72557232
Py_BEGIN_CRITICAL_SECTION(obj);
@@ -7258,7 +7235,9 @@ _PyObjectDict_SetItem(PyTypeObject *tp, PyObject *obj, PyObject **dictptr,
72587235
goto done;
72597236
}
72607237
#endif
7238+
PyTypeObject *tp = Py_TYPE(obj);
72617239
if ((tp->tp_flags & Py_TPFLAGS_HEAPTYPE) && (cached = CACHED_KEYS(tp))) {
7240+
PyInterpreterState *interp = _PyInterpreterState_GET();
72627241
assert(!_PyType_HasFeature(tp, Py_TPFLAGS_INLINE_VALUES));
72637242
dictkeys_incref(cached);
72647243
dict = new_dict_with_shared_keys(interp, cached);
@@ -7269,14 +7248,45 @@ _PyObjectDict_SetItem(PyTypeObject *tp, PyObject *obj, PyObject **dictptr,
72697248
else {
72707249
dict = PyDict_New();
72717250
}
7272-
*dictptr = dict;
7251+
FT_ATOMIC_STORE_PTR_RELAXED(*dictptr, dict);
72737252
#ifdef Py_GIL_DISABLED
72747253
done:
72757254
Py_END_CRITICAL_SECTION();
72767255
#endif
7277-
if (dict == NULL) {
7278-
return -1;
7256+
}
7257+
return dict;
7258+
}
7259+
7260+
PyObject *
7261+
PyObject_GenericGetDict(PyObject *obj, void *context)
7262+
{
7263+
PyTypeObject *tp = Py_TYPE(obj);
7264+
if (_PyType_HasFeature(tp, Py_TPFLAGS_MANAGED_DICT)) {
7265+
return Py_XNewRef(ensure_managed_dict(obj));
7266+
}
7267+
else {
7268+
PyObject **dictptr = _PyObject_ComputedDictPointer(obj);
7269+
if (dictptr == NULL) {
7270+
PyErr_SetString(PyExc_AttributeError,
7271+
"This object has no __dict__");
7272+
return NULL;
72797273
}
7274+
7275+
return Py_XNewRef(ensure_nonmanaged_dict(obj, dictptr));
7276+
}
7277+
}
7278+
7279+
int
7280+
_PyObjectDict_SetItem(PyTypeObject *tp, PyObject *obj, PyObject **dictptr,
7281+
PyObject *key, PyObject *value)
7282+
{
7283+
PyObject *dict;
7284+
int res;
7285+
7286+
assert(dictptr != NULL);
7287+
dict = ensure_nonmanaged_dict(obj, dictptr);
7288+
if (dict == NULL) {
7289+
return -1;
72807290
}
72817291

72827292
Py_BEGIN_CRITICAL_SECTION(dict);

Objects/object.c

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1789,7 +1789,11 @@ PyObject_GenericSetDict(PyObject *obj, PyObject *value, void *context)
17891789
"not a '%.200s'", Py_TYPE(value)->tp_name);
17901790
return -1;
17911791
}
1792+
#ifdef Py_GIL_DISABLED
1793+
Py_XDECREF(_Py_atomic_exchange_ptr(dictptr, Py_NewRef(value)));
1794+
#else
17921795
Py_XSETREF(*dictptr, Py_NewRef(value));
1796+
#endif
17931797
return 0;
17941798
}
17951799

0 commit comments

Comments
 (0)
0