Implementation of automatic batching for async by leszekhanusz · Pull Request #554 · graphql-python/gql · GitHub

Implementation of automatic batching for async #554


Merged
Changes from 1 commit
Implementation of automatic batching for async
leszekhanusz committed May 26, 2025
commit 1480d3b81fb8fea877e5e205512a881b87458952
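
A minimal sketch of the intended usage, assuming the pre-existing batch_interval / batch_max options on Client that the new code reads (the URL and queries are placeholders, not part of the diff). With batching enabled, concurrent execute() calls on one async session should be coalesced into a single batched POST:

import asyncio

from gql import Client, gql
from gql.transport.aiohttp import AIOHTTPTransport


async def main():
    transport = AIOHTTPTransport(url="https://example.com/graphql")  # placeholder URL

    # batch_interval > 0 enables batching; batch_max caps the batch size
    client = Client(transport=transport, batch_interval=0.01, batch_max=10)

    async with client as session:
        # Both executes land in the session's batch queue and should be
        # sent to the server together in one batched request
        result1, result2 = await asyncio.gather(
            session.execute(gql("query { a }")),  # placeholder query
            session.execute(gql("query { b }")),  # placeholder query
        )
        print(result1, result2)


asyncio.run(main())
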
gql/client.py: 177 changes (160 additions, 17 deletions)
@@ -829,15 +829,11 @@ async def connect_async(self, reconnecting=False, **kwargs):
 
         if reconnecting:
             self.session = ReconnectingAsyncClientSession(client=self, **kwargs)
-            await self.session.start_connecting_task()
         else:
-            try:
-                await self.transport.connect()
-            except Exception as e:
-                await self.transport.close()
-                raise e
             self.session = AsyncClientSession(client=self)
 
+        await self.session.connect()
+
         # Get schema from transport if needed
         try:
             if self.fetch_schema_from_transport and not self.schema:
@@ -846,18 +842,15 @@ async def connect_async(self, reconnecting=False, **kwargs):
             # we don't know what type of exception is thrown here because it
             # depends on the underlying transport; we just make sure that the
             # transport is closed and re-raise the exception
-            await self.transport.close()
+            await self.session.close()
             raise
 
         return self.session
 
     async def close_async(self):
         """Close the async transport and stop the optional reconnecting task."""
 
-        if isinstance(self.session, ReconnectingAsyncClientSession):
-            await self.session.stop_connecting_task()
-
-        await self.transport.close()
+        await self.session.close()
 
     async def __aenter__(self):
         return await self.connect_async()
@@ -1564,12 +1557,17 @@ async def _execute(
         ):
             request = request.serialize_variable_values(self.client.schema)
 
-        # Execute the query with the transport with a timeout
-        with fail_after(self.client.execute_timeout):
-            result = await self.transport.execute(
-                request,
-                **kwargs,
-            )
+        # Check if batching is enabled
+        if self.client.batching_enabled:
+            future_result = await self._execute_future(request)
+            result = await future_result
+        else:
+            # Execute the query with the transport with a timeout
+            with fail_after(self.client.execute_timeout):
+                result = await self.transport.execute(
+                    request,
+                    **kwargs,
+                )
 
         # Unserialize the result if requested
         if self.client.schema:
@@ -1828,6 +1826,134 @@ async def execute_batch(
 
         return cast(List[Dict[str, Any]], [result.data for result in results])
 
+    async def _batch_loop(self) -> None:
+        """Main loop of the task used to wait for requests
+        and execute them in a batch."""
+
+        stop_loop = False
+
+        while not stop_loop:
+            # Gather the requests (and their futures) for this batch
+            requests_and_futures: List[Tuple[GraphQLRequest, asyncio.Future]] = []
+
+            # Wait for the first request
+            request_and_future: Optional[Tuple[GraphQLRequest, asyncio.Future]] = (
+                await self.batch_queue.get()
+            )
+
+            if request_and_future is None:
+                # None is our sentinel value to stop the loop
+                break
+
+            requests_and_futures.append(request_and_future)
+
+            # Then wait for the requested batch interval, except if we already
+            # have the maximum number of requests in the queue
+            if self.batch_queue.qsize() < self.client.batch_max - 1:
+                # Wait for the batch interval
+                await asyncio.sleep(self.client.batch_interval)
+
+            # Then get the requests which were made during that wait interval
+            for _ in range(self.client.batch_max - 1):
+                try:
+                    # Use get_nowait since we don't want to wait here
+                    request_and_future = self.batch_queue.get_nowait()
+
+                    if request_and_future is None:
+                        # Sentinel value - stop after processing current batch
+                        stop_loop = True
+                        break
+
+                    requests_and_futures.append(request_and_future)
+
+                except asyncio.QueueEmpty:
+                    # No more requests in queue, that's fine
+                    break
+
+            # Extract requests and futures
+            requests = [request for request, _ in requests_and_futures]
+            futures = [future for _, future in requests_and_futures]
+
+            # Execute the batch
+            try:
+                results: List[ExecutionResult] = await self._execute_batch(
+                    requests,
+                    serialize_variables=False,  # already done
+                    parse_result=False,  # will be done later
+                    validate_document=False,  # already validated
+                )
+
+                # Set the result for each future
+                for result, future in zip(results, futures):
+                    if not future.cancelled():
+                        future.set_result(result)
+
+            except Exception as exc:
+                # If batch execution fails, propagate the error to all futures
+                for future in futures:
+                    if not future.cancelled():
+                        future.set_exception(exc)
+
+        # Signal that the task has stopped
+        self._batch_task_stopped_event.set()
+
+    async def _execute_future(
+        self,
+        request: GraphQLRequest,
+    ) -> asyncio.Future:
+        """If batching is enabled, this method will put a request in the batching queue
+        instead of executing it directly so that the requests can be put in a batch.
+        """
+
+        assert hasattr(self, "batch_queue"), "Batching is not enabled"
+        assert not self._batch_task_stop_requested, "Batching task has been stopped"
+
+        future: asyncio.Future = asyncio.Future()
+        await self.batch_queue.put((request, future))
+
+        return future
+
+    async def _batch_init(self):
+        """Initialize the batch task loop if batching is enabled."""
+        if self.client.batching_enabled:
+            self.batch_queue: asyncio.Queue = asyncio.Queue()
+            self._batch_task_stop_requested = False
+            self._batch_task_stopped_event = asyncio.Event()
+            self._batch_task = asyncio.create_task(self._batch_loop())
+
+    async def _batch_cleanup(self):
+        """Clean up the batching task if batching is enabled."""
+        if hasattr(self, "_batch_task_stopped_event"):
+            # Send None into the queue to indicate that the batching task must stop
+            # after having processed the remaining requests in the queue
+            self._batch_task_stop_requested = True
+            await self.batch_queue.put(None)
+
+            # Wait for the task to process remaining requests and stop
+            await self._batch_task_stopped_event.wait()
+
+    async def connect(self):
+        """Connect the transport and initialize the batch task loop if batching
+        is enabled."""
+
+        await self._batch_init()
+
+        try:
+            await self.transport.connect()
+        except Exception as e:
+            await self.transport.close()
+            raise e
+
+    async def close(self):
+        """Close the transport and clean up the batching task if batching is enabled.
+
+        Will wait until all the remaining requests in the batch processing queue
+        have been executed.
+        """
+        await self._batch_cleanup()
+
+        await self.transport.close()
+
     async def fetch_schema(self) -> None:
         """Fetch the GraphQL schema explicitly using introspection.
 
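The coordination pattern implemented by _batch_loop, _execute_future, and _batch_cleanup, reduced to a self-contained sketch using only the standard library (toy names, not the gql API): callers enqueue (request, Future) pairs, a worker task collects the first request plus whatever arrives during the batch window, resolves all futures at once, and a None sentinel asks the worker to stop after finishing the batch it already started.

import asyncio
from typing import List, Optional, Tuple

BATCH_INTERVAL = 0.01  # seconds to wait for more requests
BATCH_MAX = 10         # upper bound on batch size


async def batch_worker(queue: asyncio.Queue, stopped: asyncio.Event) -> None:
    stop = False
    while not stop:
        batch: List[Tuple[str, asyncio.Future]] = []

        item: Optional[Tuple[str, asyncio.Future]] = await queue.get()
        if item is None:  # sentinel: nothing left to do
            break
        batch.append(item)

        # Give other coroutines a chance to enqueue, unless already full
        if queue.qsize() < BATCH_MAX - 1:
            await asyncio.sleep(BATCH_INTERVAL)

        # Drain whatever arrived during the window, without blocking
        for _ in range(BATCH_MAX - 1):
            try:
                item = queue.get_nowait()
            except asyncio.QueueEmpty:
                break
            if item is None:  # sentinel: finish this batch, then stop
                stop = True
                break
            batch.append(item)

        try:
            # Stand-in for one batched transport call
            results = [f"result for {req}" for req, _ in batch]
            for (_, fut), res in zip(batch, results):
                if not fut.cancelled():
                    fut.set_result(res)
        except Exception as exc:
            # Propagate a batch-level failure to every caller
            for _, fut in batch:
                if not fut.cancelled():
                    fut.set_exception(exc)

    stopped.set()


async def submit(queue: asyncio.Queue, request: str) -> str:
    # Mirrors _execute_future: enqueue, then wait for the worker's answer
    fut: asyncio.Future = asyncio.Future()
    await queue.put((request, fut))
    return await fut


async def main() -> None:
    queue: asyncio.Queue = asyncio.Queue()
    stopped = asyncio.Event()
    worker = asyncio.create_task(batch_worker(queue, stopped))

    print(await asyncio.gather(submit(queue, "q1"), submit(queue, "q2")))

    await queue.put(None)   # request shutdown, as _batch_cleanup does
    await stopped.wait()    # wait for remaining work to finish
    await worker


asyncio.run(main())
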
@@ -1954,6 +2080,23 @@ async def stop_connecting_task(self):
             self._connect_task.cancel()
             self._connect_task = None
 
+    async def connect(self):
+        """Start the connect task and initialize the batch task loop if batching
+        is enabled."""
+
+        await self._batch_init()
+
+        await self.start_connecting_task()
+
+    async def close(self):
+        """Stop the connect task and clean up the batching task
+        if batching is enabled."""
+        await self._batch_cleanup()
+
+        await self.stop_connecting_task()
+
+        await self.transport.close()
+
     async def _execute_once(
         self,
         request: GraphQLRequest,
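
For comparison, the explicit batch API (execute_batch, whose private counterpart _execute_batch the loop above calls) can still be invoked directly. A sketch, assuming gql's GraphQLRequest wrapper takes a parsed document; the queries are placeholders and the snippet belongs inside the async with block shown earlier:

from gql import GraphQLRequest, gql

requests = [
    GraphQLRequest(gql("query { a }")),  # placeholder query
    GraphQLRequest(gql("query { b }")),  # placeholder query
]

# One POST carrying both operations; returns one data dict per request
datas = await session.execute_batch(requests)
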
gql/transport/aiohttp.py: 48 changes (30 additions, 18 deletions)
@@ -274,22 +274,35 @@ def _prepare_file_uploads(
 
         return post_args
 
-    async def raise_response_error(
-        self,
+    @staticmethod
+    def _raise_transport_server_error_if_status_more_than_400(
         resp: aiohttp.ClientResponse,
-        reason: str,
     ) -> None:
-        # We raise a TransportServerError if status code is 400 or higher
-        # We raise a TransportProtocolError in the other cases
-
+        # If the status code is 400 or higher,
+        # we need to raise a TransportServerError
         try:
             # Raise ClientResponseError if response status is 400 or higher
             resp.raise_for_status()
         except ClientResponseError as e:
             raise TransportServerError(str(e), e.status) from e
 
+    @classmethod
+    async def _raise_response_error(
+        cls,
+        resp: aiohttp.ClientResponse,
+        reason: str,
+    ) -> None:
+        # We raise a TransportServerError if status code is 400 or higher
+        # We raise a TransportProtocolError in the other cases
+
+        cls._raise_transport_server_error_if_status_more_than_400(resp)
+
         result_text = await resp.text()
-        self._raise_invalid_result(result_text, reason)
+        raise TransportProtocolError(
+            f"Server did not return a valid GraphQL result: "
+            f"{reason}: "
+            f"{result_text}"
+        )
 
     async def _get_json_result(self, response: aiohttp.ClientResponse) -> Any:
 
@@ -304,10 +317,10 @@ async def _get_json_result(self, response: aiohttp.ClientResponse) -> Any:
             log.debug("<<< %s", result_text)
 
         except Exception:
-            await self.raise_response_error(response, "Not a JSON answer")
+            await self._raise_response_error(response, "Not a JSON answer")
 
         if result is None:
-            await self.raise_response_error(response, "Not a JSON answer")
+            await self._raise_response_error(response, "Not a JSON answer")
 
         return result
 
@@ -318,7 +331,7 @@ async def _prepare_result(
         result = await self._get_json_result(response)
 
         if "errors" not in result and "data" not in result:
-            await self.raise_response_error(
+            await self._raise_response_error(
                 response, 'No "data" or "errors" keys in answer'
             )
 
@@ -336,14 +349,13 @@ async def _prepare_batch_result(
 
         answers = await self._get_json_result(response)
 
-        return get_batch_execution_result_list(reqs, answers)
-
-    def _raise_invalid_result(self, result_text: str, reason: str) -> None:
-        raise TransportProtocolError(
-            f"Server did not return a valid GraphQL result: "
-            f"{reason}: "
-            f"{result_text}"
-        )
+        try:
+            return get_batch_execution_result_list(reqs, answers)
+        except TransportProtocolError:
+            # Raise a TransportServerError if status code is 400 or higher
+            self._raise_transport_server_error_if_status_more_than_400(response)
+            # In other cases, raise a TransportProtocolError
+            raise
 
     async def execute(
         self,
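
What the refactor means for callers, sketched with gql's existing exception classes: an HTTP error status now consistently surfaces as TransportServerError, including for batch responses whose body does not parse as a GraphQL result list, while an OK-status response with an invalid body remains a TransportProtocolError. A sketch, assuming the exception's .code attribute carries the HTTP status:

from gql.transport.exceptions import TransportProtocolError, TransportServerError

try:
    results = await session.execute_batch(requests)
except TransportServerError as e:
    # HTTP status was 400 or higher (e.g. a 503 HTML error page);
    # before this change such a batch response could surface as a
    # TransportProtocolError instead
    print(f"server returned HTTP {e.code}")
except TransportProtocolError:
    # OK status, but the body is not a valid GraphQL (batch) result
    print("server did not return a GraphQL result")
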
gql/transport/httpx.py: 29 changes (22 additions, 7 deletions)
@@ -195,18 +195,33 @@ def _prepare_batch_result(
 
         answers = self._get_json_result(response)
 
-        return get_batch_execution_result_list(reqs, answers)
-
-    def _raise_response_error(self, response: httpx.Response, reason: str) -> NoReturn:
-        # We raise a TransportServerError if the status code is 400 or higher
-        # We raise a TransportProtocolError in the other cases
-
         try:
-            # Raise a HTTPError if response status is 400 or higher
+            return get_batch_execution_result_list(reqs, answers)
+        except TransportProtocolError:
+            # Raise a TransportServerError if status code is 400 or higher
+            self._raise_transport_server_error_if_status_more_than_400(response)
+            # In other cases, raise a TransportProtocolError
+            raise
+
+    @staticmethod
+    def _raise_transport_server_error_if_status_more_than_400(
+        response: httpx.Response,
+    ) -> None:
+        # If the status code is 400 or higher,
+        # we need to raise a TransportServerError
+        try:
+            # Raise a HTTPStatusError if response status is 400 or higher
             response.raise_for_status()
         except httpx.HTTPStatusError as e:
             raise TransportServerError(str(e), e.response.status_code) from e
 
+    @classmethod
+    def _raise_response_error(cls, response: httpx.Response, reason: str) -> NoReturn:
+        # We raise a TransportServerError if the status code is 400 or higher
+        # We raise a TransportProtocolError in the other cases
+
+        cls._raise_transport_server_error_if_status_more_than_400(response)
+
         raise TransportProtocolError(
             f"Server did not return a GraphQL result: " f"{reason}: " f"{response.text}"
         )
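
Both helpers lean on the HTTP library's raise_for_status(), which raises for any status of 400 or higher; the helper converts that exception into a TransportServerError carrying the numeric code. A standalone illustration of the underlying httpx mechanism (hypothetical response, no gql involved):

import httpx

request = httpx.Request("POST", "https://example.com/graphql")  # placeholder
response = httpx.Response(503, request=request, text="Service Unavailable")

try:
    response.raise_for_status()
except httpx.HTTPStatusError as e:
    # This is the exception the helper maps to TransportServerError
    print(e.response.status_code)  # 503
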