8000 Adding json_serialize and json_deserialize to requests transport (#466) · graphql-python/gql@e5c7c8f · GitHub
[go: up one dir, main page]

Skip to content

Commit e5c7c8f

Browse files
authored
Adding json_serialize and json_deserialize to requests transport (#466)
1 parent a3f0bd9 commit e5c7c8f

File tree

2 files changed

+123
-6
lines changed

2 files changed

+123
-6
lines changed

gql/transport/requests.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import io
22
import json
33
import logging
4-
from typing import Any, Collection, Dict, List, Optional, Tuple, Type, Union
4+
from typing import Any, Callable, Collection, Dict, List, Optional, Tuple, Type, Union
55

66
import requests
77
from graphql import DocumentNode, ExecutionResult, print_ast
@@ -47,6 +47,8 @@ def __init__(
4747
method: str = "POST",
4848
retry_backoff_factor: float = 0.1,
4949
retry_status_forcelist: Collection[int] = _default_retry_codes,
50+
json_serialize: Callable = json.dumps,
51+
json_deserialize: Callable = json.loads,
5052
**kwargs: Any,
5153
):
5254
"""Initialize the transport with the given request parameters.
@@ -73,6 +75,10 @@ def __init__(
7375
should force a retry on. A retry is initiated if the request method is
7476
in allowed_methods and the response status code is in status_forcelist.
7577
(Default: [429, 500, 502, 503, 504])
78+
:param json_serialize: Json serializer callable.
79+
By default json.dumps() function
80+
:param json_deserialize: Json deserializer callable.
81+
By default json.loads() function
7682
:param kwargs: Optional arguments that ``request`` takes.
7783
These can be seen at the `requests`_ source code or the official `docs`_
7884
@@ -90,6 +96,8 @@ def __init__(
9096
self.method = method
9197
self.retry_backoff_factor = retry_backoff_factor
9298
self.retry_status_forcelist = retry_status_forcelist
99+
self.json_serialize: Callable = json_serialize
100+
self.json_deserialize: Callable = json_deserialize
93101
self.kwargs = kwargs
94102

95103
self.session = None
@@ -174,7 +182,7 @@ def execute( # type: ignore
174182
payload["variables"] = nulled_variable_values
175183

176184
# Add the payload to the operations field
177-
operations_str = json.dumps(payload)
185+
operations_str = self.json_serialize(payload)
178186
log.debug("operations %s", operations_str)
179187

180188
# Generate the file map
@@ -188,7 +196,7 @@ def execute( # type: ignore
188196
file_streams = {str(i): files[path] for i, path in enumerate(files)}
189197

190198
# Add the file map field
191-
file_map_str = json.dumps(file_map)
199+
file_map_str = self.json_serialize(file_map)
192200
log.debug("file_map %s", file_map_str)
193201

194202
fields = {"operations": operations_str, "map": file_map_str}
@@ -224,7 +232,7 @@ def execute( # type: ignore
224232

225233
# Log the payload
226234
if log.isEnabledFor(logging.INFO):
227-
log.info(">>> %s", json.dumps(payload))
235+
log.info(">>> %s", self.json_serialize(payload))
228236

229237
# Pass kwargs to requests post method
230238
post_args.update(self.kwargs)
@@ -257,7 +265,10 @@ def raise_response_error(resp: requests.Response, reason: str):
257265
)
258266

259267
try:
260-
result = response.json()
268+
if self.json_deserialize == json.loads:
269+
result = response.json()
270+
else:
271+
result = self.json_deserialize(response.text)
261272

262273
if log.isEnabledFor(logging.INFO):
263274
log.info("<<< %s", response.text)
@@ -396,7 +407,7 @@ def _build_batch_post_args(
396407

397408
# Log the payload
398409
if log.isEnabledFor(logging.INFO):
399-
log.info(">>> %s", json.dumps(post_args[data_key]))
410+
log.info(">>> %s", self.json_serialize(post_args[data_key]))
400411

401412
# Pass kwargs to requests post method
402413
post_args.update(self.kwargs)

tests/test_requests.py

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -923,3 +923,109 @@ def test_code():
923923
assert transport.session is None
924924

925925
await run_sync_test(event_loop, server, test_code)
926+
927+
928+
@pytest.mark.aiohttp
929+
@pytest.mark.asyncio
930+
async def test_requests_json_serializer(
931+
event_loop, aiohttp_server, run_sync_test, caplog
932+
):
933+
import json
934+
from aiohttp import web
935+
from gql.transport.requests import RequestsHTTPTransport
936+
937+
async def handler(request):
938+
939+
request_text = await request.text()
940+
print("Received on backend: " + request_text)
941+
942+
return web.Response(
943+
text=query1_server_answer,
944+
content_type="application/json",
945+
)
946+
947+
app = web.Application()
948+
app.router.add_route("POST", "/", handler)
949+
server = await aiohttp_server(app)
950+
951+
url = server.make_url("/")
952+
953+
def test_code():
954+
transport = RequestsHTTPTransport(
955+
url=url,
956+
json_serialize=lambda e: json.dumps(e, separators=(",", ":")),
957+
)
958+
959+
with Client(transport=transport) as session:
960+
961+
query = gql(query1_str)
962+
963+
# Execute query asynchronously
964+
result = session.execute(query)
965+
966+
continents = result["continents"]
967+
968+
africa = continents[0]
969+
970+
assert africa["code"] == "AF"
971+
972+
# Checking that there is no space after the colon in the log
973+
expected_log = '"query":"query getContinents'
974+
assert expected_log in caplog.text
975+
976+
await run_sync_test(event_loop, server, test_code)
977+
978+
979+
query_float_str = """
980+
query getPi {
981+
pi
982+
}
983+
"""
984+
985+
query_float_server_answer_data = '{"pi": 3.141592653589793238462643383279502884197}'
986+
987+
query_float_server_answer = f'{{"data":{query_float_server_answer_data}}}'
988+
989+
990+
@pytest.mark.aiohttp
991+
@pytest.mark.asyncio
992+
async def test_requests_json_deserializer(event_loop, aiohttp_server, run_sync_test):
993+
import json
994+
from aiohttp import web
995+
from decimal import Decimal
996+
from functools import partial
997+
from gql.transport.requests import RequestsHTTPTransport
998+
999+
async def handler(request):
1000+
return web.Response(
1001+
text=query_float_server_answer,
1002+
content_type="application/json",
1003+
)
1004+
1005+
app = web.Application()
1006+
app.router.add_route("POST", "/", handler)
1007+
server = await aiohttp_server(app)
1008+
1009+
url = server.make_url("/")
1010+
1011+
def test_code():
1012+
1013+
json_loads = partial(json.loads, parse_float=Decimal)
1014+
1015+
transport = RequestsHTTPTransport(
1016+
url=url,
1017+
json_deserialize=json_loads,
1018+
)
1019+
1020+
with Client(transport=transport) as session:
1021+
1022+
query = gql(query_float_str)
1023+
1024+
# Execute query asynchronously
1025+
result = session.execute(query)
1026+
1027+
pi = result["pi"]
1028+
1029+
assert pi == Decimal("3.141592653589793238462643383279502884197")
1030+
1031+
await run_sync_test(event_loop, server, test_code)

0 commit comments

Comments
 (0)
0