8000 Reproduce security issue · jnak/graphql-core@4cef037 · GitHub
[go: up one dir, main page]

Skip to content

Commit 4cef037

Browse files
committed
Reproduce security issue
1 parent a26e8aa commit 4cef037

File tree

1 file changed

+121
-0
lines changed

1 file changed

+121
-0
lines changed

tests/test_concurrency.py

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
from graphql.type import (
2+
GraphQLField,
3+
GraphQLObjectType,
4+
GraphQLSchema,
5+
GraphQLString,
6+
)
7+
8+
from graphql import graphql
9+
import threading
10+
from promise import dataloader, promise
11+
12+
13+
REQUEST_GLOBALS = threading.local()
14+
15+
16+
def viewer_id_resolver(root, info, **args):
17+
return REQUEST_GLOBALS.current_user_id
18+
19+
20+
def promise_viewer_id_resolver(root, info, **args):
21+
return promise.Promise.resolve(None).then(lambda x: REQUEST_GLOBALS.current_user_id)
22+
23+
24+
class UserIdLoader(dataloader.DataLoader):
25+
def batch_load_fn(self, user_ids):
26+
return promise.Promise.resolve(user_ids)
27+
28+
29+
user_id_loader = UserIdLoader()
30+
31+
32+
def dataloader_viewer_id_resolver(root, info, **args):
33+
return user_id_loader.load(REQUEST_GLOBALS.current_user_id)
34+
35+
36+
queryType = GraphQLObjectType(
37+
"Query",
38+
fields=lambda: {
39+
"viewerId": GraphQLField(
40+
GraphQLString,
41+
resolver=viewer_id_resolver,
42+
),
43+
"promiseViewerUserId": GraphQLField(
44+
GraphQLString,
45+
resolver=promise_viewer_id_resolver,
46+
),
47+
"dataloaderViewerUserId": GraphQLField(
48+
GraphQLString,
49+
resolver=dataloader_viewer_id_resolver,
50+
),
51+
52+
},
53+
)
54+
55+
Schema = GraphQLSchema(query=queryType)
56+
57+
58+
def handle_request(session, query, variables={}):
59+
# Authenticate requests and set global user id
60+
# https://django-globals.readthedocs.io/en/latest/#usage
61+
# https://flask.palletsprojects.com/en/1.1.x/appcontext/
62+
REQUEST_GLOBALS.current_user_id = session.get('userId')
63+
64+
return graphql(Schema, query, variables)
65+
66+
67+
def send_request(user_id, query):
68+
session = {"userId": user_id}
69+
i = 0
70+
while i < 1000:
71+
i += 1
72+
73+
result = handle_request(session, query)
74+
if result.errors:
75+
raise Exception('Execution error', result.errors)
76+
if 'viewerId' not in result.data:
77+
raise Exception('Missing data', result.data)
78+
assert result.data['viewerId'] == user_id, \
79+
"request #{}: logged in user {} - actual user {}".format(str(i), user_id, result.data['viewerId'])
80+
81+
82+
def simulate_concurrent_requests(query):
83+
threads = [threading.Thread(target=send_request, args=(user_id, query)) for user_id in ['1', '2']]
84+
85+
for thread in threads:
86+
thread.start()
87+
88+
for thread in threads:
89+
thread.join()
90+
91+
92+
def test_regular_field():
93+
simulate_concurrent_requests("""
94+
query {
95+
viewerId
96+
}
97+
""")
98+
99+
100+
def test_promised_field():
101+
simulate_concurrent_requests("""
102+
query {
103+
viewerId: promiseViewerUserId
104+
}
105+
""")
106+
107+
108+
def test_dataloader_field():
109+
simulate_concurrent_requests("""
110+
query {
111+
viewerId: dataloaderViewerUserId
112+
}
113+
""")
114+
115+
116+
# Run this directly
117+
# PyTest does not report exceptions happening in threads
118+
if __name__ == '__main__':
119+
test_regular_field()
120+
test_promised_field()
121+
test_dataloader_field()

0 commit comments

Comments
 (0)
0