diff --git a/.github/workflows/deploy.yml b/.github/workflows/deploy.yml index 69c11d2a..1b489a95 100644 --- a/.github/workflows/deploy.yml +++ b/.github/workflows/deploy.yml @@ -15,12 +15,14 @@ jobs: uses: actions/setup-python@v5 with: python-version: 3.12 - - name: Build wheel and source tarball + - name: Install build dependencies run: | - pip install wheel setuptools - python setup.py sdist bdist_wheel + python -m pip install --upgrade pip + pip install build wheel + - name: Build package + run: | + python -m build - name: Publish a Python distribution to PyPI - uses: pypa/gh-action-pypi-publish@v1.1.0 + uses: pypa/gh-action-pypi-publish@release/v1 with: - user: __token__ password: ${{ secrets.pypi_password }} diff --git a/Makefile b/Makefile index 59d08bac..9af372f7 100644 --- a/Makefile +++ b/Makefile @@ -24,7 +24,7 @@ tests_websockets: pytest tests --websockets-only check: - isort --recursive $(SRC_PYTHON) + isort $(SRC_PYTHON) black $(SRC_PYTHON) flake8 $(SRC_PYTHON) mypy $(SRC_PYTHON) diff --git a/README.md b/README.md index cbc53af6..e79a63d2 100644 --- a/README.md +++ b/README.md @@ -40,6 +40,7 @@ The complete documentation for GQL can be found at * Supports [sync or async usage](https://gql.readthedocs.io/en/latest/async/index.html), [allowing concurrent requests](https://gql.readthedocs.io/en/latest/advanced/async_advanced_usage.html#async-advanced-usage) * Supports [File uploads](https://gql.readthedocs.io/en/latest/usage/file_upload.html) * Supports [Custom scalars / Enums](https://gql.readthedocs.io/en/latest/usage/custom_scalars_and_enums.html) +* Supports [Batching requests](https://gql.readthedocs.io/en/latest/advanced/batching_requests.html) * [gql-cli script](https://gql.readthedocs.io/en/latest/gql-cli/intro.html) to execute GraphQL queries or download schemas from the command line * [DSL module](https://gql.readthedocs.io/en/latest/advanced/dsl_module.html) to compose GraphQL queries dynamically diff --git 
a/docs/advanced/batching_requests.rst b/docs/advanced/batching_requests.rst new file mode 100644 index 00000000..7c9fc9b6 --- /dev/null +++ b/docs/advanced/batching_requests.rst @@ -0,0 +1,96 @@ +.. _batching_requests: + +Batching requests +================= + +If you need to send multiple GraphQL queries to a backend, +and if the backend supports batch requests, +then you might want to send those requests in a batch instead of +making multiple execution requests. + +.. warning:: + - Some backends do not support batch requests + - File uploads and subscriptions are not supported with batch requests + +Batching requests manually +^^^^^^^^^^^^^^^^^^^^^^^^^^ + +To execute a batch of requests manually: + +- First Make a list of :class:`GraphQLRequest ` objects, containing: + * your GraphQL query + * Optional variable_values + * Optional operation_name + +.. code-block:: python + + request1 = gql(""" + query getContinents { + continents { + code + name + } + } + """ + ) + + request2 = GraphQLRequest(""" + query getContinentName ($code: ID!) { + continent (code: $code) { + name + } + } + """, + variable_values={ + "code": "AF", + }, + ) + + requests = [request1, request2] + +- Then use one of the `execute_batch` methods, either on Client, + or in a sync or async session + +**Sync**: + +.. code-block:: python + + transport = RequestsHTTPTransport(url=url) + # Or transport = HTTPXTransport(url=url) + + with Client(transport=transport) as session: + + results = session.execute_batch(requests) + + result1 = results[0] + result2 = results[1] + +**Async**: + +.. code-block:: python + + transport = AIOHTTPTransport(url=url) + # Or transport = HTTPXAsyncTransport(url=url) + + async with Client(transport=transport) as session: + + results = await session.execute_batch(requests) + + result1 = results[0] + result2 = results[1] + +.. note:: + If any request in the batch returns an error, then a TransportQueryError will be raised + with the first error found. 
+ +Automatic Batching of requests +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +If your code executes multiple requests independently in a short time +(either from different threads in sync code, or from different asyncio tasks in async code), +then you can use gql automatic batching of requests functionality. + +You define a :code:`batching_interval` in your :class:`Client ` +and each time a new execution request is received through an `execute` method, +we will wait that interval (in seconds) for other requests to arrive +before sending all the requests received in that interval in a single batch. diff --git a/docs/advanced/error_handling.rst b/docs/advanced/error_handling.rst index 4e6618c9..458f2667 100644 --- a/docs/advanced/error_handling.rst +++ b/docs/advanced/error_handling.rst @@ -46,6 +46,11 @@ Here are the possible Transport Errors: If you don't need the schema, you can try to create the client with :code:`fetch_schema_from_transport=False` +- :class:`TransportConnectionFailed `: + This exception is generated when an unexpected Exception is received from the + transport dependency when trying to connect or to send the request. + For example in case of an SSL error, or if a websocket connection suddenly fails. + - :class:`TransportClosed `: This exception is generated when the client is trying to use the transport while the transport was previously closed. diff --git a/docs/advanced/index.rst b/docs/advanced/index.rst index baae9276..ef14defd 100644 --- a/docs/advanced/index.rst +++ b/docs/advanced/index.rst @@ -6,6 +6,7 @@ Advanced async_advanced_usage async_permanent_session + batching_requests logging error_handling local_schema diff --git a/docs/advanced/logging.rst b/docs/advanced/logging.rst index 02fdf3fd..f75c5f32 100644 --- a/docs/advanced/logging.rst +++ b/docs/advanced/logging.rst @@ -4,14 +4,7 @@ Logging GQL uses the python `logging`_ module.
In order to debug a problem, you can enable logging to see the messages exchanged between the client and the server. -To do that, set the loglevel at **INFO** at the beginning of your code: - -.. code-block:: python - - import logging - logging.basicConfig(level=logging.INFO) - -For even more logs, you can set the loglevel at **DEBUG**: +To do that, set the loglevel at **DEBUG** at the beginning of your code: .. code-block:: python @@ -21,10 +14,7 @@ For even more logs, you can set the loglevel at **DEBUG**: Disabling logs -------------- -By default, the logs for the transports are quite verbose. - -On the **INFO** level, all the messages between the frontend and the backend are logged which can -be difficult to read especially when it fetches the schema from the transport. +On the **DEBUG** log level, the logs for the transports are quite verbose. It is possible to disable the logs only for a specific gql transport by setting a higher log level for this transport (**WARNING** for example) so that the other logs of your program are not affected. 
diff --git a/docs/code_examples/aiohttp_async_dsl.py b/docs/code_examples/aiohttp_async_dsl.py index 958ea490..2c4804db 100644 --- a/docs/code_examples/aiohttp_async_dsl.py +++ b/docs/code_examples/aiohttp_async_dsl.py @@ -17,6 +17,8 @@ async def main(): # GQL will fetch the schema just after the establishment of the first session async with client as session: + assert client.schema is not None + # Instantiate the root of the DSL Schema as ds ds = DSLSchema(client.schema) diff --git a/docs/code_examples/appsync/mutation_api_key.py b/docs/code_examples/appsync/mutation_api_key.py index 634e2439..47067aca 100644 --- a/docs/code_examples/appsync/mutation_api_key.py +++ b/docs/code_examples/appsync/mutation_api_key.py @@ -46,9 +46,9 @@ async def main(): }""" ) - variable_values = {"message": "Hello world!"} + query.variable_values = {"message": "Hello world!"} - result = await session.execute(query, variable_values=variable_values) + result = await session.execute(query) print(result) diff --git a/docs/code_examples/appsync/mutation_iam.py b/docs/code_examples/appsync/mutation_iam.py index 3cc04a5a..efe9889b 100644 --- a/docs/code_examples/appsync/mutation_iam.py +++ b/docs/code_examples/appsync/mutation_iam.py @@ -45,9 +45,9 @@ async def main(): }""" ) - variable_values = {"message": "Hello world!"} + query.variable_values = {"message": "Hello world!"} - result = await session.execute(query, variable_values=variable_values) + result = await session.execute(query) print(result) diff --git a/docs/code_examples/console_async.py b/docs/code_examples/console_async.py index 9a5e94e5..69c71bce 100644 --- a/docs/code_examples/console_async.py +++ b/docs/code_examples/console_async.py @@ -1,8 +1,11 @@ import asyncio import logging +from typing import Optional from aioconsole import ainput + from gql import Client, gql +from gql.client import AsyncClientSession from gql.transport.aiohttp import AIOHTTPTransport logging.basicConfig(level=logging.INFO) @@ -21,7 +24,7 @@ def 
__init__(self): self._client = Client( transport=AIOHTTPTransport(url="https://countries.trevorblades.com/") ) - self._session = None + self._session: Optional[AsyncClientSession] = None self.get_continent_name_query = gql(GET_CONTINENT_NAME) @@ -32,13 +35,13 @@ async def close(self): await self._client.close_async() async def get_continent_name(self, code): - params = {"code": code} + self.get_continent_name_query.variable_values = {"code": code} - answer = await self._session.execute( - self.get_continent_name_query, variable_values=params - ) + assert self._session is not None + + answer = await self._session.execute(self.get_continent_name_query) - return answer.get("continent").get("name") + return answer.get("continent").get("name") # type: ignore async def main(): diff --git a/docs/code_examples/fastapi_async.py b/docs/code_examples/fastapi_async.py index 80920252..0b174fe5 100644 --- a/docs/code_examples/fastapi_async.py +++ b/docs/code_examples/fastapi_async.py @@ -10,7 +10,9 @@ from fastapi import FastAPI, HTTPException from fastapi.responses import HTMLResponse + from gql import Client, gql +from gql.client import ReconnectingAsyncClientSession from gql.transport.aiohttp import AIOHTTPTransport logging.basicConfig(level=logging.DEBUG) @@ -90,9 +92,9 @@ async def get_continent(continent_code): raise HTTPException(status_code=404, detail="Continent not found") try: - result = await client.session.execute( - query, variable_values={"code": continent_code} - ) + assert isinstance(client.session, ReconnectingAsyncClientSession) + query.variable_values = {"code": continent_code} + result = await client.session.execute(query) except Exception as e: log.debug(f"get_continent Error: {e}") raise HTTPException(status_code=503, detail="GraphQL backend unavailable") diff --git a/docs/code_examples/httpx_async_trio.py b/docs/code_examples/httpx_async_trio.py index b76dab42..058b952b 100644 --- a/docs/code_examples/httpx_async_trio.py +++ 
b/docs/code_examples/httpx_async_trio.py @@ -1,4 +1,5 @@ import trio + from gql import Client, gql from gql.transport.httpx import HTTPXAsyncTransport diff --git a/docs/code_examples/reconnecting_mutation_http.py b/docs/code_examples/reconnecting_mutation_http.py index f4329c8b..5deb5063 100644 --- a/docs/code_examples/reconnecting_mutation_http.py +++ b/docs/code_examples/reconnecting_mutation_http.py @@ -33,10 +33,10 @@ async def main(): # Execute single query query = gql("mutation ($message: String!) {sendMessage(message: $message)}") - params = {"message": f"test {num}"} + query.variable_values = {"message": f"test {num}"} try: - result = await session.execute(query, variable_values=params) + result = await session.execute(query) print(result) except Exception as e: print(f"Received exception {e}") diff --git a/docs/code_examples/reconnecting_mutation_ws.py b/docs/code_examples/reconnecting_mutation_ws.py index 7d7c8f8a..d7e7cfe2 100644 --- a/docs/code_examples/reconnecting_mutation_ws.py +++ b/docs/code_examples/reconnecting_mutation_ws.py @@ -33,10 +33,10 @@ async def main(): # Execute single query query = gql("mutation ($message: String!) 
{sendMessage(message: $message)}") - params = {"message": f"test {num}"} + query.variable_values = {"message": f"test {num}"} try: - result = await session.execute(query, variable_values=params) + result = await session.execute(query) print(result) except Exception as e: print(f"Received exception {e}") diff --git a/docs/conf.py b/docs/conf.py index 94daf942..8289ef4b 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -83,11 +83,11 @@ intersphinx_mapping = { 'aiohttp': ('https://docs.aiohttp.org/en/stable/', None), 'graphql': ('https://graphql-core-3.readthedocs.io/en/latest/', None), - 'multidict': ('https://multidict.readthedocs.io/en/stable/', None), + 'multidict': ('https://multidict.aio-libs.org/en/stable/', None), 'python': ('https://docs.python.org/3/', None), 'requests': ('https://requests.readthedocs.io/en/latest/', None), 'websockets': ('https://websockets.readthedocs.io/en/11.0.3/', None), - 'yarl': ('https://yarl.readthedocs.io/en/stable/', None), + 'yarl': ('https://yarl.aio-libs.org/en/stable/', None), } nitpick_ignore = [ @@ -100,6 +100,8 @@ ('py:class', 'asyncio.locks.Event'), # aiohttp: should be fixed + # See issue: https://github.com/aio-libs/aiohttp/issues/10468 + ('py:class', 'aiohttp.client.ClientSession'), ('py:class', 'aiohttp.client_reqrep.Fingerprint'), ('py:class', 'aiohttp.helpers.BasicAuth'), diff --git a/docs/gql-cli/intro.rst b/docs/gql-cli/intro.rst index f88b60a1..c3237093 100644 --- a/docs/gql-cli/intro.rst +++ b/docs/gql-cli/intro.rst @@ -79,12 +79,14 @@ Print the GraphQL schema in a file $ gql-cli https://countries.trevorblades.com/graphql --print-schema > schema.graphql -.. note:: - - By default, deprecated input fields are not requested from the backend. - You can add :code:`--schema-download input_value_deprecation:true` to request them. - .. note:: You can add :code:`--schema-download descriptions:false` to request a compact schema without comments. + +.. 
warning:: + + By default, from gql version 4.0, deprecated input fields are requested from the backend. + It is possible that some old backends do not support this feature. In that case + you can add :code:`--schema-download input_value_deprecation:false` to go back + to the previous behavior. diff --git a/docs/modules/gql.rst b/docs/modules/gql.rst index b7c13c7c..035f196f 100644 --- a/docs/modules/gql.rst +++ b/docs/modules/gql.rst @@ -24,11 +24,15 @@ Sub-Packages transport_aiohttp_websockets transport_appsync_auth transport_appsync_websockets + transport_common_base + transport_common_adapters_connection + transport_common_adapters_aiohttp + transport_common_adapters_websockets transport_exceptions transport_phoenix_channel_websockets transport_requests transport_httpx transport_websockets - transport_websockets_base + transport_websockets_protocol dsl utilities diff --git a/docs/modules/transport_common_adapters_aiohttp.rst b/docs/modules/transport_common_adapters_aiohttp.rst new file mode 100644 index 00000000..537c8673 --- /dev/null +++ b/docs/modules/transport_common_adapters_aiohttp.rst @@ -0,0 +1,7 @@ +gql.transport.common.adapters.aiohttp +===================================== + +.. currentmodule:: gql.transport.common.adapters.aiohttp + +.. automodule:: gql.transport.common.adapters.aiohttp + :member-order: bysource diff --git a/docs/modules/transport_common_adapters_connection.rst b/docs/modules/transport_common_adapters_connection.rst new file mode 100644 index 00000000..ffa1a1b3 --- /dev/null +++ b/docs/modules/transport_common_adapters_connection.rst @@ -0,0 +1,7 @@ +gql.transport.common.adapters.connection +======================================== + +.. currentmodule:: gql.transport.common.adapters.connection + +.. 
automodule:: gql.transport.common.adapters.connection + :member-order: bysource diff --git a/docs/modules/transport_common_adapters_websockets.rst b/docs/modules/transport_common_adapters_websockets.rst new file mode 100644 index 00000000..4005694c --- /dev/null +++ b/docs/modules/transport_common_adapters_websockets.rst @@ -0,0 +1,7 @@ +gql.transport.common.adapters.websockets +======================================== + +.. currentmodule:: gql.transport.common.adapters.websockets + +.. automodule:: gql.transport.common.adapters.websockets + :member-order: bysource diff --git a/docs/modules/transport_common_base.rst b/docs/modules/transport_common_base.rst new file mode 100644 index 00000000..4a7ec15a --- /dev/null +++ b/docs/modules/transport_common_base.rst @@ -0,0 +1,7 @@ +gql.transport.common.base +========================= + +.. currentmodule:: gql.transport.common.base + +.. automodule:: gql.transport.common.base + :member-order: bysource diff --git a/docs/modules/transport_websockets_base.rst b/docs/modules/transport_websockets_base.rst deleted file mode 100644 index 548351eb..00000000 --- a/docs/modules/transport_websockets_base.rst +++ /dev/null @@ -1,7 +0,0 @@ -gql.transport.websockets_base -============================= - -.. currentmodule:: gql.transport.websockets_base - -.. automodule:: gql.transport.websockets_base - :member-order: bysource diff --git a/docs/modules/transport_websockets_protocol.rst b/docs/modules/transport_websockets_protocol.rst new file mode 100644 index 00000000..b835abee --- /dev/null +++ b/docs/modules/transport_websockets_protocol.rst @@ -0,0 +1,7 @@ +gql.transport.websockets_protocol +================================= + +.. currentmodule:: gql.transport.websockets_protocol + +.. 
automodule:: gql.transport.websockets_protocol + :member-order: bysource diff --git a/docs/usage/custom_scalars_and_enums.rst b/docs/usage/custom_scalars_and_enums.rst index fc9008d8..f85b583a 100644 --- a/docs/usage/custom_scalars_and_enums.rst +++ b/docs/usage/custom_scalars_and_enums.rst @@ -203,11 +203,11 @@ In a variable query = gql("query shift5days($time: Datetime) {shiftDays(time: $time, days: 5)}") - variable_values = { + query.variable_values = { "time": "2021-11-12T11:58:13.461161", } - result = client.execute(query, variable_values=variable_values) + result = client.execute(query) - enum: @@ -220,11 +220,11 @@ In a variable }""" ) - variable_values = { + query.variable_values = { "color": 'RED', } - result = client.execute(query, variable_values=variable_values) + result = client.execute(query) Automatically ^^^^^^^^^^^^^ @@ -256,12 +256,10 @@ Examples: query = gql("query shift5days($time: Datetime) {shiftDays(time: $time, days: 5)}") # the argument for time is a datetime instance - variable_values = {"time": datetime.now()} + query.variable_values = {"time": datetime.now()} # we execute the query with serialize_variables set to True - result = await session.execute( - query, variable_values=variable_values, serialize_variables=True - ) + result = await session.execute(query, serialize_variables=True) - enums: @@ -285,14 +283,12 @@ Examples: ) # the argument for time is an instance of our Enum type - variable_values = { + query.variable_values = { "color": Color.RED, } # we execute the query with serialize_variables set to True - result = client.execute( - query, variable_values=variable_values, serialize_variables=True - ) + result = client.execute(query, serialize_variables=True) Parsing output -------------- @@ -319,11 +315,10 @@ Same example as above, with result parsing enabled: query = gql("query shift5days($time: Datetime) {shiftDays(time: $time, days: 5)}") - variable_values = {"time": datetime.now()} + query.variable_values = {"time": 
datetime.now()} result = await session.execute( query, - variable_values=variable_values, serialize_variables=True, parse_result=True, ) diff --git a/docs/usage/file_upload.rst b/docs/usage/file_upload.rst index 10903585..09d51742 100644 --- a/docs/usage/file_upload.rst +++ b/docs/usage/file_upload.rst @@ -14,11 +14,14 @@ Single File In order to upload a single file, you need to: * set the file as a variable value in the mutation -* provide the opened file to the `variable_values` argument of `execute` +* create a :class:`FileVar ` object with your file path +* provide the `FileVar` instance to the `variable_values` attribute of your query * set the `upload_files` argument to True .. code-block:: python + from gql import client, gql, FileVar + transport = AIOHTTPTransport(url='YOUR_URL') # Or transport = RequestsHTTPTransport(url='YOUR_URL') # Or transport = HTTPXTransport(url='YOUR_URL') @@ -34,32 +37,36 @@ In order to upload a single file, you need to: } ''') - with open("YOUR_FILE_PATH", "rb") as f: - - params = {"file": f} + query.variable_values = {"file": FileVar("YOUR_FILE_PATH")} - result = client.execute( - query, variable_values=params, upload_files=True - ) + result = client.execute(query, upload_files=True) Setting the content-type ^^^^^^^^^^^^^^^^^^^^^^^^ If you need to set a specific Content-Type attribute to a file, -you can set the :code:`content_type` attribute of the file like this: +you can set the :code:`content_type` attribute of :class:`FileVar `: .. code-block:: python - with open("YOUR_FILE_PATH", "rb") as f: + # Setting the content-type to a pdf file for example + filevar = FileVar( + "YOUR_FILE_PATH", + content_type="application/pdf", + ) + +Setting the uploaded file name +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - # Setting the content-type to a pdf file for example - f.content_type = "application/pdf" +To modify the uploaded filename, use the :code:`filename` attribute of :class:`FileVar `: - params = {"file": f} +.. 
code-block:: python - result = client.execute( - query, variable_values=params, upload_files=True - ) + # Setting the content-type to a pdf file for example + filevar = FileVar( + "YOUR_FILE_PATH", + filename="filename1.txt", + ) File list --------- @@ -68,6 +75,8 @@ It is also possible to upload multiple files using a list. .. code-block:: python + from gql import client, gql, FileVar + transport = AIOHTTPTransport(url='YOUR_URL') # Or transport = RequestsHTTPTransport(url='YOUR_URL') # Or transport = HTTPXTransport(url='YOUR_URL') @@ -83,17 +92,12 @@ It is also possible to upload multiple files using a list. } ''') - f1 = open("YOUR_FILE_PATH_1", "rb") - f2 = open("YOUR_FILE_PATH_2", "rb") - - params = {"files": [f1, f2]} + f1 = FileVar("YOUR_FILE_PATH_1") + f2 = FileVar("YOUR_FILE_PATH_2") - result = client.execute( - query, variable_values=params, upload_files=True - ) + query.variable_values = {"files": [f1, f2]} - f1.close() - f2.close() + result = client.execute(query, upload_files=True) Streaming @@ -120,18 +124,8 @@ Streaming local files aiohttp allows to upload files using an asynchronous generator. See `Streaming uploads on aiohttp docs`_. - -In order to stream local files, instead of providing opened files to the -`variable_values` argument of `execute`, you need to provide an async generator -which will provide parts of the files. - -You can use `aiofiles`_ -to read the files in chunks and create this asynchronous generator. - -.. _Streaming uploads on aiohttp docs: https://docs.aiohttp.org/en/stable/client_quickstart.html#streaming-uploads -.. _aiofiles: https://github.com/Tinche/aiofiles - -Example: +From gql version 4.0, it is possible to activate file streaming simply by +setting the `streaming` argument of :class:`FileVar ` to `True` .. 
code-block:: python @@ -147,18 +141,34 @@ Example: } ''') + f1 = FileVar( + file_name='YOUR_FILE_PATH', + streaming=True, + ) + + query.variable_values = {"file": f1} + + result = client.execute(query, upload_files=True) + +Another option is to use an async generator to provide parts of the file. + +You can use `aiofiles`_ +to read the files in chunks and create this asynchronous generator. + +.. _Streaming uploads on aiohttp docs: https://docs.aiohttp.org/en/stable/client_quickstart.html#streaming-uploads +.. _aiofiles: https://github.com/Tinche/aiofiles + +.. code-block:: python + async def file_sender(file_name): async with aiofiles.open(file_name, 'rb') as f: - chunk = await f.read(64*1024) - while chunk: - yield chunk - chunk = await f.read(64*1024) + while chunk := await f.read(64*1024): + yield chunk - params = {"file": file_sender(file_name='YOUR_FILE_PATH')} + f1 = FileVar(file_sender(file_name='YOUR_FILE_PATH')) + query.variable_values = {"file": f1} - result = client.execute( - query, variable_values=params, upload_files=True - ) + result = client.execute(query, upload_files=True) Streaming downloaded files ^^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -175,7 +185,7 @@ In order to do that, you need to: * get the response from an aiohttp request and then get the StreamReader instance from `resp.content` -* provide the StreamReader instance to the `variable_values` argument of `execute` +* provide the StreamReader instance to the `variable_values` attribute of your query Example: @@ -186,7 +196,7 @@ Example: async with http_client.get('YOUR_DOWNLOAD_URL') as resp: # We now have a StreamReader instance in resp.content - # and we provide it to the variable_values argument of execute + # and we provide it to the variable_values attribute of the query transport = AIOHTTPTransport(url='YOUR_GRAPHQL_URL') @@ -200,8 +210,6 @@ Example: } ''') - params = {"file": resp.content} + query.variable_values = {"file": FileVar(resp.content)} - result = client.execute( - query, 
variable_values=params, upload_files=True - ) + result = client.execute(query, upload_files=True) diff --git a/docs/usage/validation.rst b/docs/usage/validation.rst index f9711f31..18b1cda1 100644 --- a/docs/usage/validation.rst +++ b/docs/usage/validation.rst @@ -24,7 +24,7 @@ The schema can be provided as a String (which is usually stored in a .graphql fi .. note:: You can download a schema from a server by using :ref:`gql-cli ` - :code:`$ gql-cli https://SERVER_URL/graphql --print-schema --schema-download input_value_deprecation:true > schema.graphql` + :code:`$ gql-cli https://SERVER_URL/graphql --print-schema > schema.graphql` OR can be created using python classes: diff --git a/docs/usage/variables.rst b/docs/usage/variables.rst index 81924c6e..1eddd042 100644 --- a/docs/usage/variables.rst +++ b/docs/usage/variables.rst @@ -2,7 +2,7 @@ Using variables =============== It is possible to provide variable values with your query by providing a Dict to -the variable_values argument of the `execute` or the `subscribe` methods. +the variable_values attribute of your query. The variable values will be sent alongside the query in the transport message (there is no local substitution). 
@@ -19,14 +19,14 @@ The variable values will be sent alongside the query in the transport message """ ) - params = {"code": "EU"} + query.variable_values = {"code": "EU"} # Get name of continent with code "EU" - result = client.execute(query, variable_values=params) + result = client.execute(query) print(result) - params = {"code": "AF"} + query.variable_values = {"code": "AF"} # Get name of continent with code "AF" - result = client.execute(query, variable_values=params) + result = client.execute(query) print(result) diff --git a/gql/__init__.py b/gql/__init__.py index 8eaa0b7c..4c9a6aa0 100644 --- a/gql/__init__.py +++ b/gql/__init__.py @@ -11,10 +11,12 @@ from .client import Client from .gql import gql from .graphql_request import GraphQLRequest +from .transport.file_upload import FileVar __all__ = [ "__version__", "gql", "Client", "GraphQLRequest", + "FileVar", ] diff --git a/gql/__version__.py b/gql/__version__.py index cfe6b54e..65ac68de 100644 --- a/gql/__version__.py +++ b/gql/__version__.py @@ -1 +1 @@ -__version__ = "3.6.0b4" +__version__ = "4.0.0b0" diff --git a/gql/cli.py b/gql/cli.py index 91c67873..37be3656 100644 --- a/gql/cli.py +++ b/gql/cli.py @@ -132,12 +132,12 @@ def get_parser(with_examples: bool = False) -> ArgumentParser: By default, it will: - request field descriptions - - not request deprecated input fields + - request deprecated input fields Possible options: - descriptions:false for a compact schema without comments - - input_value_deprecation:true to download deprecated input fields + - input_value_deprecation:false to omit deprecated input fields - specified_by_url:true - schema_description:true - directive_is_repeatable:true""" @@ -391,9 +391,10 @@ def get_transport(args: Namespace) -> Optional[AsyncTransport]: auth = AppSyncJWTAuthentication(host=url.host, jwt=args.jwt) else: - from gql.transport.appsync_auth import AppSyncIAMAuthentication from botocore.exceptions import NoRegionError + from gql.transport.appsync_auth import 
AppSyncIAMAuthentication + try: auth = AppSyncIAMAuthentication(host=url.host) except NoRegionError: diff --git a/gql/client.py b/gql/client.py index c52a00b2..e17a0b7c 100644 --- a/gql/client.py +++ b/gql/client.py @@ -24,7 +24,6 @@ import backoff from anyio import fail_after from graphql import ( - DocumentNode, ExecutionResult, GraphQLSchema, IntrospectionQuery, @@ -33,14 +32,13 @@ validate, ) -from .graphql_request import GraphQLRequest +from .graphql_request import GraphQLRequest, support_deprecated_request from .transport.async_transport import AsyncTransport -from .transport.exceptions import TransportClosed, TransportQueryError +from .transport.exceptions import TransportConnectionFailed, TransportQueryError from .transport.local_schema import LocalSchemaTransport from .transport.transport import Transport from .utilities import build_client_schema, get_introspection_query_ast from .utilities import parse_result as parse_result_fn -from .utilities import serialize_variable_values from .utils import str_first_element log = logging.getLogger(__name__) @@ -68,6 +66,7 @@ class Client: def __init__( self, + *, schema: Optional[Union[str, GraphQLSchema]] = None, introspection: Optional[IntrospectionQuery] = None, transport: Optional[Union[Transport, AsyncTransport]] = None, @@ -131,7 +130,10 @@ def __init__( self.introspection: Optional[IntrospectionQuery] = introspection # GraphQL transport chosen - self.transport: Optional[Union[Transport, AsyncTransport]] = transport + assert ( + transport is not None + ), "You need to provide either a transport or a schema to the Client." 
+ self.transport: Union[Transport, AsyncTransport] = transport # Flag to indicate that we need to fetch the schema from the transport # On async transports, we fetch the schema before executing the first query @@ -149,20 +151,22 @@ def __init__( self.batch_max = batch_max @property - def batching_enabled(self): + def batching_enabled(self) -> bool: return self.batch_interval != 0 - def validate(self, document: DocumentNode): + def validate(self, request: GraphQLRequest) -> None: """:meta private:""" assert ( self.schema ), "Cannot validate the document locally, you need to pass a schema." - validation_errors = validate(self.schema, document) + validation_errors = validate(self.schema, request.document) if validation_errors: raise validation_errors[0] - def _build_schema_from_introspection(self, execution_result: ExecutionResult): + def _build_schema_from_introspection( + self, execution_result: ExecutionResult + ) -> None: if execution_result.errors: raise TransportQueryError( ( @@ -179,64 +183,70 @@ def _build_schema_from_introspection(self, execution_result: ExecutionResult): self.introspection = cast(IntrospectionQuery, execution_result.data) self.schema = build_client_schema(self.introspection) + @staticmethod + def _get_event_loop() -> asyncio.AbstractEventLoop: + """Get the current asyncio event loop. + + Or create a new event loop if there isn't one (in a new Thread). 
+ """ + try: + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", message="There is no current event loop" + ) + loop = asyncio.get_event_loop() + except RuntimeError: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + return loop + @overload def execute_sync( self, - document: DocumentNode, - variable_values: Optional[Dict[str, Any]] = ..., - operation_name: Optional[str] = ..., + request: GraphQLRequest, + *, serialize_variables: Optional[bool] = ..., parse_result: Optional[bool] = ..., - *, # https://github.com/python/mypy/issues/7333#issuecomment-788255229 get_execution_result: Literal[False] = ..., - **kwargs, - ) -> Dict[str, Any]: - ... # pragma: no cover + **kwargs: Any, + ) -> Dict[str, Any]: ... # pragma: no cover @overload def execute_sync( self, - document: DocumentNode, - variable_values: Optional[Dict[str, Any]] = ..., - operation_name: Optional[str] = ..., + request: GraphQLRequest, + *, serialize_variables: Optional[bool] = ..., parse_result: Optional[bool] = ..., - *, get_execution_result: Literal[True], - **kwargs, - ) -> ExecutionResult: - ... # pragma: no cover + **kwargs: Any, + ) -> ExecutionResult: ... # pragma: no cover @overload def execute_sync( self, - document: DocumentNode, - variable_values: Optional[Dict[str, Any]] = ..., - operation_name: Optional[str] = ..., + request: GraphQLRequest, + *, serialize_variables: Optional[bool] = ..., parse_result: Optional[bool] = ..., - *, get_execution_result: bool, - **kwargs, - ) -> Union[Dict[str, Any], ExecutionResult]: - ... # pragma: no cover + **kwargs: Any, + ) -> Union[Dict[str, Any], ExecutionResult]: ... 
# pragma: no cover def execute_sync( self, - document: DocumentNode, - variable_values: Optional[Dict[str, Any]] = None, - operation_name: Optional[str] = None, + request: GraphQLRequest, + *, serialize_variables: Optional[bool] = None, parse_result: Optional[bool] = None, get_execution_result: bool = False, - **kwargs, + **kwargs: Any, ) -> Union[Dict[str, Any], ExecutionResult]: """:meta private:""" with self as session: return session.execute( - document, - variable_values=variable_values, - operation_name=operation_name, + request, serialize_variables=serialize_variables, parse_result=parse_result, get_execution_result=get_execution_result, @@ -251,9 +261,8 @@ def execute_batch_sync( serialize_variables: Optional[bool] = None, parse_result: Optional[bool] = None, get_execution_result: Literal[False] = ..., - **kwargs, - ) -> List[Dict[str, Any]]: - ... # pragma: no cover + **kwargs: Any, + ) -> List[Dict[str, Any]]: ... # pragma: no cover @overload def execute_batch_sync( @@ -263,9 +272,8 @@ def execute_batch_sync( serialize_variables: Optional[bool] = None, parse_result: Optional[bool] = None, get_execution_result: Literal[True], - **kwargs, - ) -> List[ExecutionResult]: - ... # pragma: no cover + **kwargs: Any, + ) -> List[ExecutionResult]: ... # pragma: no cover @overload def execute_batch_sync( @@ -275,9 +283,8 @@ def execute_batch_sync( serialize_variables: Optional[bool] = None, parse_result: Optional[bool] = None, get_execution_result: bool, - **kwargs, - ) -> Union[List[Dict[str, Any]], List[ExecutionResult]]: - ... # pragma: no cover + **kwargs: Any, + ) -> Union[List[Dict[str, Any]], List[ExecutionResult]]: ... 
# pragma: no cover def execute_batch_sync( self, @@ -286,7 +293,7 @@ def execute_batch_sync( serialize_variables: Optional[bool] = None, parse_result: Optional[bool] = None, get_execution_result: bool = False, - **kwargs, + **kwargs: Any, ) -> Union[List[Dict[str, Any]], List[ExecutionResult]]: """:meta private:""" with self as session: @@ -301,61 +308,101 @@ def execute_batch_sync( @overload async def execute_async( self, - document: DocumentNode, - variable_values: Optional[Dict[str, Any]] = ..., - operation_name: Optional[str] = ..., + request: GraphQLRequest, + *, serialize_variables: Optional[bool] = ..., parse_result: Optional[bool] = ..., - *, # https://github.com/python/mypy/issues/7333#issuecomment-788255229 get_execution_result: Literal[False] = ..., - **kwargs, - ) -> Dict[str, Any]: - ... # pragma: no cover + **kwargs: Any, + ) -> Dict[str, Any]: ... # pragma: no cover @overload async def execute_async( self, - document: DocumentNode, - variable_values: Optional[Dict[str, Any]] = ..., - operation_name: Optional[str] = ..., + request: GraphQLRequest, + *, serialize_variables: Optional[bool] = ..., parse_result: Optional[bool] = ..., - *, get_execution_result: Literal[True], - **kwargs, - ) -> ExecutionResult: - ... # pragma: no cover + **kwargs: Any, + ) -> ExecutionResult: ... # pragma: no cover @overload async def execute_async( self, - document: DocumentNode, - variable_values: Optional[Dict[str, Any]] = ..., - operation_name: Optional[str] = ..., + request: GraphQLRequest, + *, serialize_variables: Optional[bool] = ..., parse_result: Optional[bool] = ..., - *, get_execution_result: bool, - **kwargs, - ) -> Union[Dict[str, Any], ExecutionResult]: - ... # pragma: no cover + **kwargs: Any, + ) -> Union[Dict[str, Any], ExecutionResult]: ... 
# pragma: no cover async def execute_async( self, - document: DocumentNode, - variable_values: Optional[Dict[str, Any]] = None, - operation_name: Optional[str] = None, + request: GraphQLRequest, + *, serialize_variables: Optional[bool] = None, parse_result: Optional[bool] = None, get_execution_result: bool = False, - **kwargs, + **kwargs: Any, ) -> Union[Dict[str, Any], ExecutionResult]: """:meta private:""" async with self as session: return await session.execute( - document, - variable_values=variable_values, - operation_name=operation_name, + request, + serialize_variables=serialize_variables, + parse_result=parse_result, + get_execution_result=get_execution_result, + **kwargs, + ) + + @overload + async def execute_batch_async( + self, + requests: List[GraphQLRequest], + *, + serialize_variables: Optional[bool] = None, + parse_result: Optional[bool] = None, + get_execution_result: Literal[False] = ..., + **kwargs: Any, + ) -> List[Dict[str, Any]]: ... # pragma: no cover + + @overload + async def execute_batch_async( + self, + requests: List[GraphQLRequest], + *, + serialize_variables: Optional[bool] = None, + parse_result: Optional[bool] = None, + get_execution_result: Literal[True], + **kwargs: Any, + ) -> List[ExecutionResult]: ... # pragma: no cover + + @overload + async def execute_batch_async( + self, + requests: List[GraphQLRequest], + *, + serialize_variables: Optional[bool] = None, + parse_result: Optional[bool] = None, + get_execution_result: bool, + **kwargs: Any, + ) -> Union[List[Dict[str, Any]], List[ExecutionResult]]: ... 
# pragma: no cover + + async def execute_batch_async( + self, + requests: List[GraphQLRequest], + *, + serialize_variables: Optional[bool] = None, + parse_result: Optional[bool] = None, + get_execution_result: bool = False, + **kwargs: Any, + ) -> Union[List[Dict[str, Any]], List[ExecutionResult]]: + """:meta private:""" + async with self as session: + return await session.execute_batch( + requests, serialize_variables=serialize_variables, parse_result=parse_result, get_execution_result=get_execution_result, @@ -365,56 +412,46 @@ async def execute_async( @overload def execute( self, - document: DocumentNode, - variable_values: Optional[Dict[str, Any]] = ..., - operation_name: Optional[str] = ..., + request: GraphQLRequest, + *, serialize_variables: Optional[bool] = ..., parse_result: Optional[bool] = ..., - *, # https://github.com/python/mypy/issues/7333#issuecomment-788255229 get_execution_result: Literal[False] = ..., - **kwargs, - ) -> Dict[str, Any]: - ... # pragma: no cover + **kwargs: Any, + ) -> Dict[str, Any]: ... # pragma: no cover @overload def execute( self, - document: DocumentNode, - variable_values: Optional[Dict[str, Any]] = ..., - operation_name: Optional[str] = ..., + request: GraphQLRequest, + *, serialize_variables: Optional[bool] = ..., parse_result: Optional[bool] = ..., - *, get_execution_result: Literal[True], - **kwargs, - ) -> ExecutionResult: - ... # pragma: no cover + **kwargs: Any, + ) -> ExecutionResult: ... # pragma: no cover @overload def execute( self, - document: DocumentNode, - variable_values: Optional[Dict[str, Any]] = ..., - operation_name: Optional[str] = ..., + request: GraphQLRequest, + *, serialize_variables: Optional[bool] = ..., parse_result: Optional[bool] = ..., - *, get_execution_result: bool, - **kwargs, - ) -> Union[Dict[str, Any], ExecutionResult]: - ... # pragma: no cover + **kwargs: Any, + ) -> Union[Dict[str, Any], ExecutionResult]: ... 
# pragma: no cover def execute( self, - document: DocumentNode, - variable_values: Optional[Dict[str, Any]] = None, - operation_name: Optional[str] = None, + request: GraphQLRequest, + *, serialize_variables: Optional[bool] = None, parse_result: Optional[bool] = None, get_execution_result: bool = False, - **kwargs, + **kwargs: Any, ) -> Union[Dict[str, Any], ExecutionResult]: - """Execute the provided document AST against the remote server using + """Execute the provided request against the remote server using the transport provided during init. This function **WILL BLOCK** until the result is received from the server. @@ -437,17 +474,7 @@ def execute( """ if isinstance(self.transport, AsyncTransport): - # Get the current asyncio event loop - # Or create a new event loop if there isn't one (in a new Thread) - try: - with warnings.catch_warnings(): - warnings.filterwarnings( - "ignore", message="There is no current event loop" - ) - loop = asyncio.get_event_loop() - except RuntimeError: - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) + loop = self._get_event_loop() assert not loop.is_running(), ( "Cannot run client.execute(query) if an asyncio loop is running." @@ -456,9 +483,7 @@ def execute( data = loop.run_until_complete( self.execute_async( - document, - variable_values=variable_values, - operation_name=operation_name, + request, serialize_variables=serialize_variables, parse_result=parse_result, get_execution_result=get_execution_result, @@ -470,9 +495,7 @@ def execute( else: # Sync transports return self.execute_sync( - document, - variable_values=variable_values, - operation_name=operation_name, + request, serialize_variables=serialize_variables, parse_result=parse_result, get_execution_result=get_execution_result, @@ -487,9 +510,8 @@ def execute_batch( serialize_variables: Optional[bool] = None, parse_result: Optional[bool] = None, get_execution_result: Literal[False] = ..., - **kwargs, - ) -> List[Dict[str, Any]]: - ... 
# pragma: no cover + **kwargs: Any, + ) -> List[Dict[str, Any]]: ... # pragma: no cover @overload def execute_batch( @@ -499,9 +521,8 @@ def execute_batch( serialize_variables: Optional[bool] = None, parse_result: Optional[bool] = None, get_execution_result: Literal[True], - **kwargs, - ) -> List[ExecutionResult]: - ... # pragma: no cover + **kwargs: Any, + ) -> List[ExecutionResult]: ... # pragma: no cover @overload def execute_batch( @@ -511,9 +532,8 @@ def execute_batch( serialize_variables: Optional[bool] = None, parse_result: Optional[bool] = None, get_execution_result: bool, - **kwargs, - ) -> Union[List[Dict[str, Any]], List[ExecutionResult]]: - ... # pragma: no cover + **kwargs: Any, + ) -> Union[List[Dict[str, Any]], List[ExecutionResult]]: ... # pragma: no cover def execute_batch( self, @@ -522,7 +542,7 @@ def execute_batch( serialize_variables: Optional[bool] = None, parse_result: Optional[bool] = None, get_execution_result: bool = False, - **kwargs, + **kwargs: Any, ) -> Union[List[Dict[str, Any]], List[ExecutionResult]]: """Execute multiple GraphQL requests in a batch against the remote server using the transport provided during init. @@ -547,7 +567,24 @@ def execute_batch( """ if isinstance(self.transport, AsyncTransport): - raise NotImplementedError("Batching is not implemented for async yet.") + loop = self._get_event_loop() + + assert not loop.is_running(), ( + "Cannot run client.execute_batch(query) if an asyncio loop is running." + " Use 'await client.execute_batch(query)' instead." 
+ ) + + data = loop.run_until_complete( + self.execute_batch_async( + requests, + serialize_variables=serialize_variables, + parse_result=parse_result, + get_execution_result=get_execution_result, + **kwargs, + ) + ) + + return data else: # Sync transports return self.execute_batch_sync( @@ -561,65 +598,53 @@ def execute_batch( @overload def subscribe_async( self, - document: DocumentNode, - variable_values: Optional[Dict[str, Any]] = ..., - operation_name: Optional[str] = ..., + request: GraphQLRequest, + *, serialize_variables: Optional[bool] = ..., parse_result: Optional[bool] = ..., - *, get_execution_result: Literal[False] = ..., - **kwargs, - ) -> AsyncGenerator[Dict[str, Any], None]: - ... # pragma: no cover + **kwargs: Any, + ) -> AsyncGenerator[Dict[str, Any], None]: ... # pragma: no cover @overload def subscribe_async( self, - document: DocumentNode, - variable_values: Optional[Dict[str, Any]] = ..., - operation_name: Optional[str] = ..., + request: GraphQLRequest, + *, serialize_variables: Optional[bool] = ..., parse_result: Optional[bool] = ..., - *, get_execution_result: Literal[True], - **kwargs, - ) -> AsyncGenerator[ExecutionResult, None]: - ... # pragma: no cover + **kwargs: Any, + ) -> AsyncGenerator[ExecutionResult, None]: ... # pragma: no cover @overload def subscribe_async( self, - document: DocumentNode, - variable_values: Optional[Dict[str, Any]] = ..., - operation_name: Optional[str] = ..., + request: GraphQLRequest, + *, serialize_variables: Optional[bool] = ..., parse_result: Optional[bool] = ..., - *, get_execution_result: bool, - **kwargs, + **kwargs: Any, ) -> Union[ AsyncGenerator[Dict[str, Any], None], AsyncGenerator[ExecutionResult, None] - ]: - ... # pragma: no cover + ]: ... 
# pragma: no cover async def subscribe_async( self, - document: DocumentNode, - variable_values: Optional[Dict[str, Any]] = None, - operation_name: Optional[str] = None, + request: GraphQLRequest, + *, serialize_variables: Optional[bool] = None, parse_result: Optional[bool] = None, get_execution_result: bool = False, - **kwargs, + **kwargs: Any, ) -> Union[ AsyncGenerator[Dict[str, Any], None], AsyncGenerator[ExecutionResult, None] ]: """:meta private:""" async with self as session: generator = session.subscribe( - document, - variable_values=variable_values, - operation_name=operation_name, + request, serialize_variables=serialize_variables, parse_result=parse_result, get_execution_result=get_execution_result, @@ -632,57 +657,46 @@ async def subscribe_async( @overload def subscribe( self, - document: DocumentNode, - variable_values: Optional[Dict[str, Any]] = ..., - operation_name: Optional[str] = ..., + request: GraphQLRequest, + *, serialize_variables: Optional[bool] = ..., parse_result: Optional[bool] = ..., - *, get_execution_result: Literal[False] = ..., - **kwargs, - ) -> Generator[Dict[str, Any], None, None]: - ... # pragma: no cover + **kwargs: Any, + ) -> Generator[Dict[str, Any], None, None]: ... # pragma: no cover @overload def subscribe( self, - document: DocumentNode, - variable_values: Optional[Dict[str, Any]] = ..., - operation_name: Optional[str] = ..., + request: GraphQLRequest, + *, serialize_variables: Optional[bool] = ..., parse_result: Optional[bool] = ..., - *, get_execution_result: Literal[True], - **kwargs, - ) -> Generator[ExecutionResult, None, None]: - ... # pragma: no cover + **kwargs: Any, + ) -> Generator[ExecutionResult, None, None]: ... 
# pragma: no cover @overload def subscribe( self, - document: DocumentNode, - variable_values: Optional[Dict[str, Any]] = ..., - operation_name: Optional[str] = ..., + request: GraphQLRequest, + *, serialize_variables: Optional[bool] = ..., parse_result: Optional[bool] = ..., - *, get_execution_result: bool, - **kwargs, + **kwargs: Any, ) -> Union[ Generator[Dict[str, Any], None, None], Generator[ExecutionResult, None, None] - ]: - ... # pragma: no cover + ]: ... # pragma: no cover def subscribe( self, - document: DocumentNode, - variable_values: Optional[Dict[str, Any]] = None, - operation_name: Optional[str] = None, + request: GraphQLRequest, + *, serialize_variables: Optional[bool] = None, parse_result: Optional[bool] = None, - *, get_execution_result: bool = False, - **kwargs, + **kwargs: Any, ) -> Union[ Generator[Dict[str, Any], None, None], Generator[ExecutionResult, None, None] ]: @@ -691,35 +705,23 @@ def subscribe( We need an async transport for this functionality. """ - # Get the current asyncio event loop - # Or create a new event loop if there isn't one (in a new Thread) - try: - with warnings.catch_warnings(): - warnings.filterwarnings( - "ignore", message="There is no current event loop" - ) - loop = asyncio.get_event_loop() - except RuntimeError: - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) + loop = self._get_event_loop() + + assert not loop.is_running(), ( + "Cannot run client.subscribe(query) if an asyncio loop is running." + " Use 'await client.subscribe_async(query)' instead." + ) async_generator: Union[ AsyncGenerator[Dict[str, Any], None], AsyncGenerator[ExecutionResult, None] ] = self.subscribe_async( - document, - variable_values=variable_values, - operation_name=operation_name, + request, serialize_variables=serialize_variables, parse_result=parse_result, get_execution_result=get_execution_result, **kwargs, ) - assert not loop.is_running(), ( - "Cannot run client.subscribe(query) if an asyncio loop is running." 
- " Use 'await client.subscribe_async(query)' instead." - ) - try: while True: # Note: we need to create a task here in order to be able to close @@ -770,17 +772,15 @@ async def connect_async(self, reconnecting=False, **kwargs): self.transport, AsyncTransport ), "Only a transport of type AsyncTransport can be used asynchronously" + self.session: Union[AsyncClientSession, SyncClientSession] + if reconnecting: self.session = ReconnectingAsyncClientSession(client=self, **kwargs) - await self.session.start_connecting_task() else: - try: - await self.transport.connect() - except Exception as e: - await self.transport.close() - raise e self.session = AsyncClientSession(client=self) + await self.session.connect() + # Get schema from transport if needed try: if self.fetch_schema_from_transport and not self.schema: @@ -789,7 +789,7 @@ async def connect_async(self, reconnecting=False, **kwargs): # we don't know what type of exception is thrown here because it # depends on the underlying transport; we just make sure that the # transport is closed and re-raise the exception - await self.transport.close() + await self.session.close() raise return self.session @@ -797,10 +797,7 @@ async def connect_async(self, reconnecting=False, **kwargs): async def close_async(self): """Close the async transport and stop the optional reconnecting task.""" - if isinstance(self.session, ReconnectingAsyncClientSession): - await self.session.stop_connecting_task() - - await self.transport.close() + await self.session.close() async def __aenter__(self): return await self.connect_async() @@ -825,6 +822,8 @@ def connect_sync(self): if not hasattr(self, "session"): self.session = SyncClientSession(client=self) + assert isinstance(self.session, SyncClientSession) + self.session.connect() # Get schema from transport if needed @@ -846,6 +845,8 @@ def close_sync(self): If batching is enabled, this will block until the remaining queries in the batching queue have been processed. 
""" + assert isinstance(self.session, SyncClientSession) + self.session.close() def __enter__(self): @@ -868,19 +869,17 @@ def __init__(self, client: Client): def _execute( self, - document: DocumentNode, - variable_values: Optional[Dict[str, Any]] = None, - operation_name: Optional[str] = None, + request: GraphQLRequest, + *, serialize_variables: Optional[bool] = None, parse_result: Optional[bool] = None, - **kwargs, + **kwargs: Any, ) -> ExecutionResult: - """Execute the provided document AST synchronously using + """Execute the provided request synchronously using the sync transport, returning an ExecutionResult object. - :param document: GraphQL query as AST Node object. - :param variable_values: Dictionary of input parameters. - :param operation_name: Name of the operation that shall be executed. + :param request: GraphQL request as a + :class:`GraphQLRequest ` object. :param serialize_variables: whether the variable values should be serialized. Used for custom scalars and/or enums. By default use the serialize_variables argument of the client. 
@@ -889,36 +888,28 @@ def _execute( The extra arguments are passed to the transport execute method.""" + # Still supporting for now old method of providing + # variable_values and operation_name + request = support_deprecated_request(request, kwargs) + # Validate document if self.client.schema: - self.client.validate(document) + self.client.validate(request) # Parse variable values for custom scalars if requested - if variable_values is not None: + if request.variable_values is not None: if serialize_variables or ( serialize_variables is None and self.client.serialize_variables ): - variable_values = serialize_variable_values( - self.client.schema, - document, - variable_values, - operation_name=operation_name, - ) + request = request.serialize_variable_values(self.client.schema) if self.client.batching_enabled: - request = GraphQLRequest( - document, - variable_values=variable_values, - operation_name=operation_name, - ) future_result = self._execute_future(request) result = future_result.result() else: result = self.transport.execute( - document, - variable_values=variable_values, - operation_name=operation_name, + request, **kwargs, ) @@ -927,9 +918,9 @@ def _execute( if parse_result or (parse_result is None and self.client.parse_results): result.data = parse_result_fn( self.client.schema, - document, + request.document, result.data, - operation_name=operation_name, + operation_name=request.operation_name, ) return result @@ -937,64 +928,52 @@ def _execute( @overload def execute( self, - document: DocumentNode, - variable_values: Optional[Dict[str, Any]] = ..., - operation_name: Optional[str] = ..., + request: GraphQLRequest, + *, serialize_variables: Optional[bool] = ..., parse_result: Optional[bool] = ..., - *, get_execution_result: Literal[False] = ..., - **kwargs, - ) -> Dict[str, Any]: - ... # pragma: no cover + **kwargs: Any, + ) -> Dict[str, Any]: ... 
# pragma: no cover @overload def execute( self, - document: DocumentNode, - variable_values: Optional[Dict[str, Any]] = ..., - operation_name: Optional[str] = ..., + request: GraphQLRequest, + *, serialize_variables: Optional[bool] = ..., parse_result: Optional[bool] = ..., - *, get_execution_result: Literal[True], - **kwargs, - ) -> ExecutionResult: - ... # pragma: no cover + **kwargs: Any, + ) -> ExecutionResult: ... # pragma: no cover @overload def execute( self, - document: DocumentNode, - variable_values: Optional[Dict[str, Any]] = ..., - operation_name: Optional[str] = ..., + request: GraphQLRequest, + *, serialize_variables: Optional[bool] = ..., parse_result: Optional[bool] = ..., - *, get_execution_result: bool, - **kwargs, - ) -> Union[Dict[str, Any], ExecutionResult]: - ... # pragma: no cover + **kwargs: Any, + ) -> Union[Dict[str, Any], ExecutionResult]: ... # pragma: no cover def execute( self, - document: DocumentNode, - variable_values: Optional[Dict[str, Any]] = None, - operation_name: Optional[str] = None, + request: GraphQLRequest, + *, serialize_variables: Optional[bool] = None, parse_result: Optional[bool] = None, get_execution_result: bool = False, - **kwargs, + **kwargs: Any, ) -> Union[Dict[str, Any], ExecutionResult]: - """Execute the provided document AST synchronously using + """Execute the provided request synchronously using the sync transport. Raises a TransportQueryError if an error has been returned in the ExecutionResult. - :param document: GraphQL query as AST Node object. - :param variable_values: Dictionary of input parameters. - :param operation_name: Name of the operation that shall be executed. + :param request: GraphQL query as :class:`GraphQLRequest `. :param serialize_variables: whether the variable values should be serialized. Used for custom scalars and/or enums. By default use the serialize_variables argument of the client. 
@@ -1007,9 +986,7 @@ def execute( # Validate and execute on the transport result = self._execute( - document, - variable_values=variable_values, - operation_name=operation_name, + request, serialize_variables=serialize_variables, parse_result=parse_result, **kwargs, @@ -1040,7 +1017,7 @@ def _execute_batch( serialize_variables: Optional[bool] = None, parse_result: Optional[bool] = None, validate_document: Optional[bool] = True, - **kwargs, + **kwargs: Any, ) -> List[ExecutionResult]: """Execute multiple GraphQL requests in a batch, using the sync transport, returning a list of ExecutionResult objects. @@ -1060,16 +1037,18 @@ def _execute_batch( if validate_document: for req in requests: - self.client.validate(req.document) + self.client.validate(req) # Parse variable values for custom scalars if requested if serialize_variables or ( serialize_variables is None and self.client.serialize_variables ): requests = [ - req.serialize_variable_values(self.client.schema) - if req.variable_values is not None - else req + ( + req.serialize_variable_values(self.client.schema) + if req.variable_values is not None + else req + ) for req in requests ] @@ -1096,9 +1075,8 @@ def execute_batch( serialize_variables: Optional[bool] = None, parse_result: Optional[bool] = None, get_execution_result: Literal[False] = ..., - **kwargs, - ) -> List[Dict[str, Any]]: - ... # pragma: no cover + **kwargs: Any, + ) -> List[Dict[str, Any]]: ... # pragma: no cover @overload def execute_batch( @@ -1108,9 +1086,8 @@ def execute_batch( serialize_variables: Optional[bool] = None, parse_result: Optional[bool] = None, get_execution_result: Literal[True], - **kwargs, - ) -> List[ExecutionResult]: - ... # pragma: no cover + **kwargs: Any, + ) -> List[ExecutionResult]: ... 
# pragma: no cover @overload def execute_batch( @@ -1120,9 +1097,8 @@ def execute_batch( serialize_variables: Optional[bool] = None, parse_result: Optional[bool] = None, get_execution_result: bool, - **kwargs, - ) -> Union[List[Dict[str, Any]], List[ExecutionResult]]: - ... # pragma: no cover + **kwargs: Any, + ) -> Union[List[Dict[str, Any]], List[ExecutionResult]]: ... # pragma: no cover def execute_batch( self, @@ -1131,7 +1107,7 @@ def execute_batch( serialize_variables: Optional[bool] = None, parse_result: Optional[bool] = None, get_execution_result: bool = False, - **kwargs, + **kwargs: Any, ) -> Union[List[Dict[str, Any]], List[ExecutionResult]]: """Execute multiple GraphQL requests in a batch, using the sync transport. This method sends the requests to the server all at once. @@ -1284,7 +1260,7 @@ def fetch_schema(self) -> None: introspection_query = get_introspection_query_ast( **self.client.introspection_args ) - execution_result = self.transport.execute(introspection_query) + execution_result = self.transport.execute(GraphQLRequest(introspection_query)) self.client._build_schema_from_introspection(execution_result) @@ -1307,23 +1283,21 @@ def __init__(self, client: Client): async def _subscribe( self, - document: DocumentNode, - variable_values: Optional[Dict[str, Any]] = None, - operation_name: Optional[str] = None, + request: GraphQLRequest, + *, serialize_variables: Optional[bool] = None, parse_result: Optional[bool] = None, - **kwargs, + **kwargs: Any, ) -> AsyncGenerator[ExecutionResult, None]: - """Coroutine to subscribe asynchronously to the provided document AST + """Coroutine to subscribe asynchronously to the provided request asynchronously using the async transport, returning an async generator producing ExecutionResult objects. * Validate the query with the schema if provided. * Serialize the variable_values if requested. - :param document: GraphQL query as AST Node object. - :param variable_values: Dictionary of input parameters. 
- :param operation_name: Name of the operation that shall be executed. + :param request: GraphQL request as a + :class:`GraphQLRequest ` object. :param serialize_variables: whether the variable values should be serialized. Used for custom scalars and/or enums. By default use the serialize_variables argument of the client. @@ -1332,30 +1306,27 @@ async def _subscribe( The extra arguments are passed to the transport subscribe method.""" + # Still supporting for now old method of providing + # variable_values and operation_name + request = support_deprecated_request(request, kwargs) + # Validate document if self.client.schema: - self.client.validate(document) + self.client.validate(request) # Parse variable values for custom scalars if requested - if variable_values is not None: + if request.variable_values is not None: if serialize_variables or ( serialize_variables is None and self.client.serialize_variables ): - variable_values = serialize_variable_values( - self.client.schema, - document, - variable_values, - operation_name=operation_name, - ) + request = request.serialize_variable_values(self.client.schema) # Subscribe to the transport - inner_generator: AsyncGenerator[ - ExecutionResult, None - ] = self.transport.subscribe( - document, - variable_values=variable_values, - operation_name=operation_name, - **kwargs, + inner_generator: AsyncGenerator[ExecutionResult, None] = ( + self.transport.subscribe( + request, + **kwargs, + ) ) # Keep a reference to the inner generator @@ -1370,9 +1341,9 @@ async def _subscribe( ): result.data = parse_result_fn( self.client.schema, - document, + request.document, result.data, - operation_name=operation_name, + operation_name=request.operation_name, ) yield result @@ -1383,68 +1354,56 @@ async def _subscribe( @overload def subscribe( self, - document: DocumentNode, - variable_values: Optional[Dict[str, Any]] = ..., - operation_name: Optional[str] = ..., + request: GraphQLRequest, + *, serialize_variables: Optional[bool] = ..., 
parse_result: Optional[bool] = ..., - *, get_execution_result: Literal[False] = ..., - **kwargs, - ) -> AsyncGenerator[Dict[str, Any], None]: - ... # pragma: no cover + **kwargs: Any, + ) -> AsyncGenerator[Dict[str, Any], None]: ... # pragma: no cover @overload def subscribe( self, - document: DocumentNode, - variable_values: Optional[Dict[str, Any]] = ..., - operation_name: Optional[str] = ..., + request: GraphQLRequest, + *, serialize_variables: Optional[bool] = ..., parse_result: Optional[bool] = ..., - *, get_execution_result: Literal[True], - **kwargs, - ) -> AsyncGenerator[ExecutionResult, None]: - ... # pragma: no cover + **kwargs: Any, + ) -> AsyncGenerator[ExecutionResult, None]: ... # pragma: no cover @overload def subscribe( self, - document: DocumentNode, - variable_values: Optional[Dict[str, Any]] = ..., - operation_name: Optional[str] = ..., + request: GraphQLRequest, + *, serialize_variables: Optional[bool] = ..., parse_result: Optional[bool] = ..., - *, get_execution_result: bool, - **kwargs, + **kwargs: Any, ) -> Union[ AsyncGenerator[Dict[str, Any], None], AsyncGenerator[ExecutionResult, None] - ]: - ... # pragma: no cover + ]: ... # pragma: no cover async def subscribe( self, - document: DocumentNode, - variable_values: Optional[Dict[str, Any]] = None, - operation_name: Optional[str] = None, + request: GraphQLRequest, + *, serialize_variables: Optional[bool] = None, parse_result: Optional[bool] = None, get_execution_result: bool = False, - **kwargs, + **kwargs: Any, ) -> Union[ AsyncGenerator[Dict[str, Any], None], AsyncGenerator[ExecutionResult, None] ]: - """Coroutine to subscribe asynchronously to the provided document AST + """Coroutine to subscribe asynchronously to the provided request asynchronously using the async transport. Raises a TransportQueryError if an error has been returned in the ExecutionResult. - :param document: GraphQL query as AST Node object. - :param variable_values: Dictionary of input parameters. 
- :param operation_name: Name of the operation that shall be executed. + :param request: GraphQL query as :class:`GraphQLRequest `. :param serialize_variables: whether the variable values should be serialized. Used for custom scalars and/or enums. By default use the serialize_variables argument of the client. @@ -1456,9 +1415,7 @@ async def subscribe( The extra arguments are passed to the transport subscribe method.""" inner_generator: AsyncGenerator[ExecutionResult, None] = self._subscribe( - document, - variable_values=variable_values, - operation_name=operation_name, + request, serialize_variables=serialize_variables, parse_result=parse_result, **kwargs, @@ -1486,22 +1443,20 @@ async def subscribe( async def _execute( self, - document: DocumentNode, - variable_values: Optional[Dict[str, Any]] = None, - operation_name: Optional[str] = None, + request: GraphQLRequest, + *, serialize_variables: Optional[bool] = None, parse_result: Optional[bool] = None, - **kwargs, + **kwargs: Any, ) -> ExecutionResult: - """Coroutine to execute the provided document AST asynchronously using + """Coroutine to execute the provided request asynchronously using the async transport, returning an ExecutionResult object. * Validate the query with the schema if provided. * Serialize the variable_values if requested. - :param document: GraphQL query as AST Node object. - :param variable_values: Dictionary of input parameters. - :param operation_name: Name of the operation that shall be executed. + :param request: GraphQL request as a + :class:`GraphQLRequest ` object. :param serialize_variables: whether the variable values should be serialized. Used for custom scalars and/or enums. By default use the serialize_variables argument of the client. 
@@ -1510,39 +1465,41 @@ async def _execute( The extra arguments are passed to the transport execute method.""" + # Still supporting for now old method of providing + # variable_values and operation_name + request = support_deprecated_request(request, kwargs) + # Validate document if self.client.schema: - self.client.validate(document) + self.client.validate(request) # Parse variable values for custom scalars if requested - if variable_values is not None: + if request.variable_values is not None: if serialize_variables or ( serialize_variables is None and self.client.serialize_variables ): - variable_values = serialize_variable_values( - self.client.schema, - document, - variable_values, - operation_name=operation_name, - ) + request = request.serialize_variable_values(self.client.schema) - # Execute the query with the transport with a timeout - with fail_after(self.client.execute_timeout): - result = await self.transport.execute( - document, - variable_values=variable_values, - operation_name=operation_name, - **kwargs, - ) + # Check if batching is enabled + if self.client.batching_enabled: + future_result = await self._execute_future(request) + result = await future_result + else: + # Execute the query with the transport with a timeout + with fail_after(self.client.execute_timeout): + result = await self.transport.execute( + request, + **kwargs, + ) # Unserialize the result if requested if self.client.schema: if parse_result or (parse_result is None and self.client.parse_results): result.data = parse_result_fn( self.client.schema, - document, + request.document, result.data, - operation_name=operation_name, + operation_name=request.operation_name, ) return result @@ -1550,64 +1507,52 @@ async def _execute( @overload async def execute( self, - document: DocumentNode, - variable_values: Optional[Dict[str, Any]] = ..., - operation_name: Optional[str] = ..., + request: GraphQLRequest, + *, serialize_variables: Optional[bool] = ..., parse_result: Optional[bool] = ..., 
- *, get_execution_result: Literal[False] = ..., - **kwargs, - ) -> Dict[str, Any]: - ... # pragma: no cover + **kwargs: Any, + ) -> Dict[str, Any]: ... # pragma: no cover @overload async def execute( self, - document: DocumentNode, - variable_values: Optional[Dict[str, Any]] = ..., - operation_name: Optional[str] = ..., + request: GraphQLRequest, + *, serialize_variables: Optional[bool] = ..., parse_result: Optional[bool] = ..., - *, get_execution_result: Literal[True], - **kwargs, - ) -> ExecutionResult: - ... # pragma: no cover + **kwargs: Any, + ) -> ExecutionResult: ... # pragma: no cover @overload async def execute( self, - document: DocumentNode, - variable_values: Optional[Dict[str, Any]] = ..., - operation_name: Optional[str] = ..., + request: GraphQLRequest, + *, serialize_variables: Optional[bool] = ..., parse_result: Optional[bool] = ..., - *, get_execution_result: bool, - **kwargs, - ) -> Union[Dict[str, Any], ExecutionResult]: - ... # pragma: no cover + **kwargs: Any, + ) -> Union[Dict[str, Any], ExecutionResult]: ... # pragma: no cover async def execute( self, - document: DocumentNode, - variable_values: Optional[Dict[str, Any]] = None, - operation_name: Optional[str] = None, + request: GraphQLRequest, + *, serialize_variables: Optional[bool] = None, parse_result: Optional[bool] = None, get_execution_result: bool = False, - **kwargs, + **kwargs: Any, ) -> Union[Dict[str, Any], ExecutionResult]: - """Coroutine to execute the provided document AST asynchronously using + """Coroutine to execute the provided request asynchronously using the async transport. Raises a TransportQueryError if an error has been returned in the ExecutionResult. - :param document: GraphQL query as AST Node object. - :param variable_values: Dictionary of input parameters. - :param operation_name: Name of the operation that shall be executed. + :param request: GraphQL query as :class:`GraphQLRequest `. :param serialize_variables: whether the variable values should be serialized. 
Used for custom scalars and/or enums. By default use the serialize_variables argument of the client. @@ -1620,9 +1565,7 @@ async def execute( # Validate and execute on the transport result = await self._execute( - document, - variable_values=variable_values, - operation_name=operation_name, + request, serialize_variables=serialize_variables, parse_result=parse_result, **kwargs, @@ -1646,6 +1589,277 @@ async def execute( return result.data + async def _execute_batch( + self, + requests: List[GraphQLRequest], + *, + serialize_variables: Optional[bool] = None, + parse_result: Optional[bool] = None, + validate_document: Optional[bool] = True, + **kwargs: Any, + ) -> List[ExecutionResult]: + """Execute multiple GraphQL requests in a batch, using + the async transport, returning a list of ExecutionResult objects. + + :param requests: List of requests that will be executed. + :param serialize_variables: whether the variable values should be + serialized. Used for custom scalars and/or enums. + By default use the serialize_variables argument of the client. + :param parse_result: Whether gql will deserialize the result. + By default use the parse_results argument of the client. + :param validate_document: Whether we still need to validate the document. 
+
+        The extra arguments are passed to the transport execute_batch method."""
+
+        # Validate document
+        if self.client.schema:
+
+            if validate_document:
+                for req in requests:
+                    self.client.validate(req)
+
+            # Parse variable values for custom scalars if requested
+            if serialize_variables or (
+                serialize_variables is None and self.client.serialize_variables
+            ):
+                requests = [
+                    (
+                        req.serialize_variable_values(self.client.schema)
+                        if req.variable_values is not None
+                        else req
+                    )
+                    for req in requests
+                ]
+
+        results = await self.transport.execute_batch(requests, **kwargs)
+
+        # Unserialize the result if requested
+        if self.client.schema:
+            if parse_result or (parse_result is None and self.client.parse_results):
+                for result, request in zip(results, requests):
+                    result.data = parse_result_fn(
+                        self.client.schema,
+                        request.document,
+                        result.data,
+                        operation_name=request.operation_name,
+                    )
+
+        return results
+
+    @overload
+    async def execute_batch(
+        self,
+        requests: List[GraphQLRequest],
+        *,
+        serialize_variables: Optional[bool] = None,
+        parse_result: Optional[bool] = None,
+        get_execution_result: Literal[False] = ...,
+        **kwargs: Any,
+    ) -> List[Dict[str, Any]]: ...  # pragma: no cover
+
+    @overload
+    async def execute_batch(
+        self,
+        requests: List[GraphQLRequest],
+        *,
+        serialize_variables: Optional[bool] = None,
+        parse_result: Optional[bool] = None,
+        get_execution_result: Literal[True],
+        **kwargs: Any,
+    ) -> List[ExecutionResult]: ...  # pragma: no cover
+
+    @overload
+    async def execute_batch(
+        self,
+        requests: List[GraphQLRequest],
+        *,
+        serialize_variables: Optional[bool] = None,
+        parse_result: Optional[bool] = None,
+        get_execution_result: bool,
+        **kwargs: Any,
+    ) -> Union[List[Dict[str, Any]], List[ExecutionResult]]: ...
# pragma: no cover + + async def execute_batch( + self, + requests: List[GraphQLRequest], + *, + serialize_variables: Optional[bool] = None, + parse_result: Optional[bool] = None, + get_execution_result: bool = False, + **kwargs: Any, + ) -> Union[List[Dict[str, Any]], List[ExecutionResult]]: + """Execute multiple GraphQL requests in a batch, using + the async transport. This method sends the requests to the server all at once. + + Raises a TransportQueryError if an error has been returned in any + ExecutionResult. + + :param requests: List of requests that will be executed. + :param serialize_variables: whether the variable values should be + serialized. Used for custom scalars and/or enums. + By default use the serialize_variables argument of the client. + :param parse_result: Whether gql will deserialize the result. + By default use the parse_results argument of the client. + :param get_execution_result: return the full ExecutionResult instance instead of + only the "data" field. Necessary if you want to get the "extensions" field. 
+ + The extra arguments are passed to the transport execute method.""" + + # Validate and execute on the transport + results = await self._execute_batch( + requests, + serialize_variables=serialize_variables, + parse_result=parse_result, + **kwargs, + ) + + for result in results: + # Raise an error if an error is returned in the ExecutionResult object + if result.errors: + raise TransportQueryError( + str_first_element(result.errors), + errors=result.errors, + data=result.data, + extensions=result.extensions, + ) + + assert ( + result.data is not None + ), "Transport returned an ExecutionResult without data or errors" + + if get_execution_result: + return results + + return cast(List[Dict[str, Any]], [result.data for result in results]) + + async def _batch_loop(self) -> None: + """Main loop of the task used to wait for requests + to execute them in a batch""" + + stop_loop = False + + while not stop_loop: + # First wait for a first request in from the batch queue + requests_and_futures: List[Tuple[GraphQLRequest, asyncio.Future]] = [] + + # Wait for the first request + request_and_future: Optional[Tuple[GraphQLRequest, asyncio.Future]] = ( + await self.batch_queue.get() + ) + + if request_and_future is None: + # None is our sentinel value to stop the loop + break + + requests_and_futures.append(request_and_future) + + # Then wait the requested batch interval except if we already + # have the maximum number of requests in the queue + if self.batch_queue.qsize() < self.client.batch_max - 1: + # Wait for the batch interval + await asyncio.sleep(self.client.batch_interval) + + # Then get the requests which had been made during that wait interval + for _ in range(self.client.batch_max - 1): + try: + # Use get_nowait since we don't want to wait here + request_and_future = self.batch_queue.get_nowait() + + if request_and_future is None: + # Sentinel value - stop after processing current batch + stop_loop = True + break + + requests_and_futures.append(request_and_future) 
+ + except asyncio.QueueEmpty: + # No more requests in queue, that's fine + break + + # Extract requests and futures + requests = [request for request, _ in requests_and_futures] + futures = [future for _, future in requests_and_futures] + + # Execute the batch + try: + results: List[ExecutionResult] = await self._execute_batch( + requests, + serialize_variables=False, # already done + parse_result=False, # will be done later + validate_document=False, # already validated + ) + + # Set the result for each future + for result, future in zip(results, futures): + if not future.cancelled(): + future.set_result(result) + + except Exception as exc: + # If batch execution fails, propagate the error to all futures + for future in futures: + if not future.cancelled(): + future.set_exception(exc) + + # Signal that the task has stopped + self._batch_task_stopped_event.set() + + async def _execute_future( + self, + request: GraphQLRequest, + ) -> asyncio.Future: + """If batching is enabled, this method will put a request in the batching queue + instead of executing it directly so that the requests could be put in a batch. 
+ """ + + assert hasattr(self, "batch_queue"), "Batching is not enabled" + assert not self._batch_task_stop_requested, "Batching task has been stopped" + + future: asyncio.Future = asyncio.Future() + await self.batch_queue.put((request, future)) + + return future + + async def _batch_init(self): + """Initialize the batch task loop if batching is enabled.""" + if self.client.batching_enabled: + self.batch_queue: asyncio.Queue = asyncio.Queue() + self._batch_task_stop_requested = False + self._batch_task_stopped_event = asyncio.Event() + self._batch_task = asyncio.create_task(self._batch_loop()) + + async def _batch_cleanup(self): + """Cleanup the batching task if batching is enabled.""" + if hasattr(self, "_batch_task_stopped_event"): + # Send a None in the queue to indicate that the batching task must stop + # after having processed the remaining requests in the queue + self._batch_task_stop_requested = True + await self.batch_queue.put(None) + + # Wait for the task to process remaining requests and stop + await self._batch_task_stopped_event.wait() + + async def connect(self): + """Connect the transport and initialize the batch task loop if batching + is enabled.""" + + await self._batch_init() + + try: + await self.transport.connect() + except Exception as e: + await self.transport.close() + raise e + + async def close(self): + """Close the transport and cleanup the batching task if batching is enabled. + + Will wait until all the remaining requests in the batch processing queue + have been executed. + """ + await self._batch_cleanup() + + await self.transport.close() + async def fetch_schema(self) -> None: """Fetch the GraphQL schema explicitly using introspection. 
@@ -1654,7 +1868,9 @@ async def fetch_schema(self) -> None: introspection_query = get_introspection_query_ast( **self.client.introspection_args ) - execution_result = await self.transport.execute(introspection_query) + execution_result = await self.transport.execute( + GraphQLRequest(introspection_query) + ) self.client._build_schema_from_introspection(execution_result) @@ -1679,6 +1895,7 @@ class ReconnectingAsyncClientSession(AsyncClientSession): def __init__( self, client: Client, + *, retry_connect: Union[bool, _Decorator] = True, retry_execute: Union[bool, _Decorator] = True, ): @@ -1750,6 +1967,7 @@ async def _connection_loop(self): # Then wait for the reconnect event self._reconnect_request_event.clear() await self._reconnect_request_event.wait() + await self.transport.close() async def start_connecting_task(self): """Start the task responsible to restart the connection @@ -1768,29 +1986,43 @@ async def stop_connecting_task(self): self._connect_task.cancel() self._connect_task = None + async def connect(self): + """Start the connect task and initialize the batch task loop if batching + is enabled.""" + + await self._batch_init() + + await self.start_connecting_task() + + async def close(self): + """Stop the connect task and cleanup the batching task + if batching is enabled.""" + await self._batch_cleanup() + + await self.stop_connecting_task() + + await self.transport.close() + async def _execute_once( self, - document: DocumentNode, - variable_values: Optional[Dict[str, Any]] = None, - operation_name: Optional[str] = None, + request: GraphQLRequest, + *, serialize_variables: Optional[bool] = None, parse_result: Optional[bool] = None, - **kwargs, + **kwargs: Any, ) -> ExecutionResult: """Same Coroutine as parent method _execute but requesting a - reconnection if we receive a TransportClosed exception. + reconnection if we receive a TransportConnectionFailed exception. 
""" try: answer = await super()._execute( - document, - variable_values=variable_values, - operation_name=operation_name, + request, serialize_variables=serialize_variables, parse_result=parse_result, **kwargs, ) - except TransportClosed: + except TransportConnectionFailed: self._reconnect_request_event.set() raise @@ -1798,21 +2030,19 @@ async def _execute_once( async def _execute( self, - document: DocumentNode, - variable_values: Optional[Dict[str, Any]] = None, - operation_name: Optional[str] = None, + request: GraphQLRequest, + *, serialize_variables: Optional[bool] = None, parse_result: Optional[bool] = None, - **kwargs, + **kwargs: Any, ) -> ExecutionResult: """Same Coroutine as parent, but with optional retries - and requesting a reconnection if we receive a TransportClosed exception. + and requesting a reconnection if we receive a + TransportConnectionFailed exception. """ return await self._execute_with_retries( - document, - variable_values=variable_values, - operation_name=operation_name, + request, serialize_variables=serialize_variables, parse_result=parse_result, **kwargs, @@ -1820,21 +2050,18 @@ async def _execute( async def _subscribe( self, - document: DocumentNode, - variable_values: Optional[Dict[str, Any]] = None, - operation_name: Optional[str] = None, + request: GraphQLRequest, + *, serialize_variables: Optional[bool] = None, parse_result: Optional[bool] = None, - **kwargs, + **kwargs: Any, ) -> AsyncGenerator[ExecutionResult, None]: """Same Async generator as parent method _subscribe but requesting a - reconnection if we receive a TransportClosed exception. + reconnection if we receive a TransportConnectionFailed exception. 
""" inner_generator: AsyncGenerator[ExecutionResult, None] = super()._subscribe( - document, - variable_values=variable_values, - operation_name=operation_name, + request, serialize_variables=serialize_variables, parse_result=parse_result, **kwargs, @@ -1844,7 +2071,7 @@ async def _subscribe( async for result in inner_generator: yield result - except TransportClosed: + except TransportConnectionFailed: self._reconnect_request_event.set() raise diff --git a/gql/dsl.py b/gql/dsl.py index be2b5a7e..1a8716c2 100644 --- a/gql/dsl.py +++ b/gql/dsl.py @@ -2,6 +2,7 @@ .. image:: http://www.plantuml.com/plantuml/png/ZLAzJWCn3Dxz51vXw1im50ag8L4XwC1OkLTJ8gMvAd4GwEYxGuC8pTbKtUxy_TZEvsaIYfAt7e1MII9rWfsdbF1cSRzWpvtq4GT0JENduX8GXr_g7brQlf5tw-MBOx_-HlS0LV_Kzp8xr1kZav9PfCsMWvolEA_1VylHoZCExKwKv4Tg2s_VkSkca2kof2JDb0yxZYIk3qMZYUe1B1uUZOROXn96pQMugEMUdRnUUqUf6DBXQyIz2zu5RlgUQAFVNYaeRfBI79_JrUTaeg9JZFQj5MmUc69PDmNGE2iU61fDgfri3x36gxHw3gDHD6xqqQ7P4vjKqz2-602xtkO7uo17SCLhVSv25VjRjUAFcUE73Sspb8ADBl8gTT7j2cFAOPst_Wi0 # noqa :alt: UML diagram """ + import logging import re from abc import ABC, abstractmethod @@ -63,6 +64,7 @@ ) from graphql.pyutils import inspect +from .graphql_request import GraphQLRequest from .utils import to_camel_case log = logging.getLogger(__name__) @@ -213,7 +215,7 @@ def ast_from_value(value: Any, type_: GraphQLInputType) -> Optional[ValueNode]: def dsl_gql( *operations: "DSLExecutable", **operations_with_name: "DSLExecutable" -) -> DocumentNode: +) -> GraphQLRequest: r"""Given arguments instances of :class:`DSLExecutable` containing GraphQL operations or fragments, generate a Document which can be executed later in a @@ -230,7 +232,8 @@ def dsl_gql( :param \**operations_with_name: the GraphQL operations with an operation name :type \**operations_with_name: DSLQuery, DSLMutation, DSLSubscription - :return: a Document which can be later executed or subscribed by a + :return: a :class:`GraphQLRequest ` + which can be later executed or subscribed by a :class:`Client 
`, by an :class:`async session ` or by a :class:`sync session ` @@ -258,10 +261,12 @@ def dsl_gql( f"Received: {type(operation)}." ) - return DocumentNode( + document = DocumentNode( definitions=[operation.executable_ast for operation in all_operations] ) + return GraphQLRequest(document) + class DSLSchema: """The DSLSchema is the root of the DSL code. @@ -338,7 +343,7 @@ def select( self, *fields: "DSLSelectable", **fields_with_alias: "DSLSelectableWithAlias", - ): + ) -> Any: r"""Select the fields which should be added. :param \*fields: fields or fragments @@ -595,9 +600,11 @@ def get_ast_definitions(self) -> Tuple[VariableDefinitionNode, ...]: VariableDefinitionNode( type=var.ast_variable_type, variable=var.ast_variable_name, - default_value=None - if var.default_value is None - else ast_from_value(var.default_value, var.type), + default_value=( + None + if var.default_value is None + else ast_from_value(var.default_value, var.type) + ), directives=(), ) for var in self.variables.values() @@ -836,10 +843,10 @@ def name(self): """:meta private:""" return self.ast_field.name.value - def __call__(self, **kwargs) -> "DSLField": + def __call__(self, **kwargs: Any) -> "DSLField": return self.args(**kwargs) - def args(self, **kwargs) -> "DSLField": + def args(self, **kwargs: Any) -> "DSLField": r"""Set the arguments of a field The arguments are parsed to be stored in the AST of this field. diff --git a/gql/gql.py b/gql/gql.py index e9705947..8a5a1b32 100644 --- a/gql/gql.py +++ b/gql/gql.py @@ -1,24 +1,16 @@ -from __future__ import annotations +from .graphql_request import GraphQLRequest -from graphql import DocumentNode, Source, parse - -def gql(request_string: str | Source) -> DocumentNode: - """Given a string containing a GraphQL request, parse it into a Document. +def gql(request_string: str) -> GraphQLRequest: + """Given a string containing a GraphQL request, + parse it into a Document and put it into a GraphQLRequest object. 
:param request_string: the GraphQL request as a String - :type request_string: str | Source - :return: a Document which can be later executed or subscribed by a - :class:`Client `, by an - :class:`async session ` or by a - :class:`sync session ` - + :return: a :class:`GraphQLRequest ` + which can be later executed or subscribed by a + :class:`Client `, by an + :class:`async session ` or by a + :class:`sync session ` :raises graphql.error.GraphQLError: if a syntax error is encountered. """ - if isinstance(request_string, Source): - source = request_string - elif isinstance(request_string, str): - source = Source(request_string, "GraphQL request") - else: - raise TypeError("Request must be passed as a string or Source object.") - return parse(source) + return GraphQLRequest(request_string) diff --git a/gql/graphql_request.py b/gql/graphql_request.py index b0c68f5c..5e6f3ee4 100644 --- a/gql/graphql_request.py +++ b/gql/graphql_request.py @@ -1,32 +1,59 @@ -from dataclasses import dataclass -from typing import Any, Dict, Optional +import warnings +from typing import Any, Dict, Optional, Union -from graphql import DocumentNode, GraphQLSchema +from graphql import DocumentNode, GraphQLSchema, Source, parse, print_ast -from .utilities import serialize_variable_values - -@dataclass(frozen=True) class GraphQLRequest: """GraphQL Request to be executed.""" - document: DocumentNode - """GraphQL query as AST Node object.""" + def __init__( + self, + request: Union[DocumentNode, "GraphQLRequest", str], + *, + variable_values: Optional[Dict[str, Any]] = None, + operation_name: Optional[str] = None, + ): + """Initialize a GraphQL request. - variable_values: Optional[Dict[str, Any]] = None - """Dictionary of input parameters (Default: None).""" + :param request: GraphQL request as DocumentNode object or as a string. + If string, it will be converted to DocumentNode. + :param variable_values: Dictionary of input parameters (Default: None). 
+ :param operation_name: Name of the operation that shall be executed. + Only required in multi-operation documents (Default: None). + :return: a :class:`GraphQLRequest ` + which can be later executed or subscribed by a + :class:`Client `, by an + :class:`async session ` or by a + :class:`sync session ` + :raises graphql.error.GraphQLError: if a syntax error is encountered. + """ + if isinstance(request, str): + source = Source(request, "GraphQL request") + self.document = parse(source) + elif isinstance(request, DocumentNode): + self.document = request + elif not isinstance(request, GraphQLRequest): + raise TypeError(f"Unexpected type for GraphQLRequest: {type(request)}") - operation_name: Optional[str] = None - """ - Name of the operation that shall be executed. - Only required in multi-operation documents (Default: None). - """ + if isinstance(request, GraphQLRequest): + self.document = request.document + if variable_values is None: + variable_values = request.variable_values + if operation_name is None: + operation_name = request.operation_name + + self.variable_values: Optional[Dict[str, Any]] = variable_values + self.operation_name: Optional[str] = operation_name def serialize_variable_values(self, schema: GraphQLSchema) -> "GraphQLRequest": + + from .utilities.serialize_variable_values import serialize_variable_values + assert self.variable_values return GraphQLRequest( - document=self.document, + self.document, variable_values=serialize_variable_values( schema=schema, document=self.document, @@ -35,3 +62,63 @@ def serialize_variable_values(self, schema: GraphQLSchema) -> "GraphQLRequest": ), operation_name=self.operation_name, ) + + @property + def payload(self) -> Dict[str, Any]: + query_str = print_ast(self.document) + payload: Dict[str, Any] = {"query": query_str} + + if self.operation_name: + payload["operationName"] = self.operation_name + + if self.variable_values: + payload["variables"] = self.variable_values + + return payload + + def __str__(self): 
+ return str(self.payload) + + +def support_deprecated_request( + request: Union[GraphQLRequest, DocumentNode], + kwargs: Dict, +) -> GraphQLRequest: + """This methods is there temporarily to convert the old style of calling + execute and subscribe methods with a DocumentNode, + variable_values and operation_name arguments. + """ + + if isinstance(request, DocumentNode): + warnings.warn( + ( + "Using a DocumentNode is deprecated. Please use a " + "GraphQLRequest instead." + ), + DeprecationWarning, + stacklevel=2, + ) + request = GraphQLRequest(request) + + if not isinstance(request, GraphQLRequest): + raise TypeError("request should be a GraphQLRequest object") + + variable_values = kwargs.pop("variable_values", None) + operation_name = kwargs.pop("operation_name", None) + + if variable_values or operation_name: + warnings.warn( + ( + "Using variable_values and operation_name arguments of " + "execute and subscribe methods is deprecated. Instead, " + "please use the variable_values and operation_name properties " + "of GraphQLRequest" + ), + DeprecationWarning, + stacklevel=2, + ) + + request.variable_values = variable_values + request.operation_name = operation_name + + return request diff --git a/gql/transport/aiohttp.py b/gql/transport/aiohttp.py index 0c332205..e3bfdb3b 100644 --- a/gql/transport/aiohttp.py +++ b/gql/transport/aiohttp.py @@ -1,20 +1,18 @@ import asyncio -import functools import io import json import logging -import warnings from ssl import SSLContext from typing import ( Any, AsyncGenerator, Callable, Dict, + List, Optional, Tuple, Type, Union, - cast, ) import aiohttp @@ -22,18 +20,23 @@ from aiohttp.client_reqrep import Fingerprint from aiohttp.helpers import BasicAuth from aiohttp.typedefs import LooseCookies, LooseHeaders -from graphql import DocumentNode, ExecutionResult, print_ast +from graphql import ExecutionResult from multidict import CIMultiDictProxy -from ..utils import extract_files +from ..graphql_request import GraphQLRequest 
from .appsync_auth import AppSyncAuthentication from .async_transport import AsyncTransport +from .common.aiohttp_closed_event import create_aiohttp_closed_event +from .common.batch import get_batch_execution_result_list from .exceptions import ( TransportAlreadyConnected, TransportClosed, + TransportConnectionFailed, + TransportError, TransportProtocolError, TransportServerError, ) +from .file_upload import FileVar, close_files, extract_files, open_files log = logging.getLogger(__name__) @@ -57,7 +60,7 @@ def __init__( headers: Optional[LooseHeaders] = None, cookies: Optional[LooseCookies] = None, auth: Optional[Union[BasicAuth, "AppSyncAuthentication"]] = None, - ssl: Union[SSLContext, bool, Fingerprint, str] = "ssl_warning", + ssl: Union[SSLContext, bool, Fingerprint] = True, timeout: Optional[int] = None, ssl_close_timeout: Optional[Union[int, float]] = 10, json_serialize: Callable = json.dumps, @@ -71,7 +74,8 @@ def __init__( :param cookies: Dict of HTTP cookies. :param auth: BasicAuth object to enable Basic HTTP auth if needed Or Appsync Authentication class - :param ssl: ssl_context of the connection. Use ssl=False to disable encryption + :param ssl: ssl_context of the connection. + Use ssl=False to not verify ssl certificates. :param ssl_close_timeout: Timeout in seconds to wait for the ssl connection to close properly :param json_serialize: Json serializer callable. @@ -88,20 +92,7 @@ def __init__( self.headers: Optional[LooseHeaders] = headers self.cookies: Optional[LooseCookies] = cookies self.auth: Optional[Union[BasicAuth, "AppSyncAuthentication"]] = auth - - if ssl == "ssl_warning": - ssl = False - if str(url).startswith("https"): - warnings.warn( - "WARNING: By default, AIOHTTPTransport does not verify" - " ssl certificates. This will be fixed in the next major version." 
- " You can set ssl=True to force the ssl certificate verification" - " or ssl=False to disable this warning" - ) - - self.ssl: Union[SSLContext, bool, Fingerprint] = cast( - Union[SSLContext, bool, Fingerprint], ssl - ) + self.ssl: Union[SSLContext, bool, Fingerprint] = ssl self.timeout: Optional[int] = timeout self.ssl_close_timeout: Optional[Union[int, float]] = ssl_close_timeout self.client_session_args = client_session_args @@ -125,9 +116,9 @@ async def connect(self) -> None: client_session_args: Dict[str, Any] = { "cookies": self.cookies, "headers": self.headers, - "auth": None - if isinstance(self.auth, AppSyncAuthentication) - else self.auth, + "auth": ( + None if isinstance(self.auth, AppSyncAuthentication) else self.auth + ), "json_serialize": self.json_serialize, } @@ -138,7 +129,7 @@ async def connect(self) -> None: # Adding custom parameters passed from init if self.client_session_args: - client_session_args.update(self.client_session_args) # type: ignore + client_session_args.update(self.client_session_args) log.debug("Connecting transport") @@ -147,59 +138,6 @@ async def connect(self) -> None: else: raise TransportAlreadyConnected("Transport is already connected") - @staticmethod - def create_aiohttp_closed_event(session) -> asyncio.Event: - """Work around aiohttp issue that doesn't properly close transports on exit. - - See https://github.com/aio-libs/aiohttp/issues/1925#issuecomment-639080209 - - Returns: - An event that will be set once all transports have been properly closed. - """ - - ssl_transports = 0 - all_is_lost = asyncio.Event() - - def connection_lost(exc, orig_lost): - nonlocal ssl_transports - - try: - orig_lost(exc) - finally: - ssl_transports -= 1 - if ssl_transports == 0: - all_is_lost.set() - - def eof_received(orig_eof_received): - try: # pragma: no cover - orig_eof_received() - except AttributeError: # pragma: no cover - # It may happen that eof_received() is called after - # _app_protocol and _transport are set to None. 
- pass - - for conn in session.connector._conns.values(): - for handler, _ in conn: - proto = getattr(handler.transport, "_ssl_protocol", None) - if proto is None: - continue - - ssl_transports += 1 - orig_lost = proto.connection_lost - orig_eof_received = proto.eof_received - - proto.connection_lost = functools.partial( - connection_lost, orig_lost=orig_lost - ) - proto.eof_received = functools.partial( - eof_received, orig_eof_received=orig_eof_received - ) - - if ssl_transports == 0: - all_is_lost.set() - - return all_is_lost - async def close(self) -> None: """Coroutine which will close the aiohttp session. @@ -219,7 +157,7 @@ async def close(self) -> None: log.debug("connector_owner is False -> not closing connector") else: - closed_event = self.create_aiohttp_closed_event(self.session) + closed_event = create_aiohttp_closed_event(self.session) await self.session.close() try: await asyncio.wait_for(closed_event.wait(), self.ssl_close_timeout) @@ -228,160 +166,264 @@ async def close(self) -> None: self.session = None - async def execute( + def _prepare_request( self, - document: DocumentNode, - variable_values: Optional[Dict[str, Any]] = None, - operation_name: Optional[str] = None, + request: Union[GraphQLRequest, List[GraphQLRequest]], extra_args: Optional[Dict[str, Any]] = None, upload_files: bool = False, - ) -> ExecutionResult: - """Execute the provided document AST against the configured remote server - using the current session. - This uses the aiohttp library to perform a HTTP POST request asynchronously - to the remote server. + ) -> Dict[str, Any]: - Don't call this coroutine directly on the transport, instead use - :code:`execute` on a client or a session. 
+ payload: Dict | List + if isinstance(request, GraphQLRequest): + payload = request.payload + else: + payload = [req.payload for req in request] - :param document: the parsed GraphQL request - :param variable_values: An optional Dict of variable values - :param operation_name: An optional Operation name for the request - :param extra_args: additional arguments to send to the aiohttp post method - :param upload_files: Set to True if you want to put files in the variable values - :returns: an ExecutionResult object. - """ + if upload_files: + assert isinstance(payload, Dict) + assert isinstance(request, GraphQLRequest) + post_args = self._prepare_file_uploads(request, payload) + else: + post_args = {"json": payload} - query_str = print_ast(document) + # Log the payload + if log.isEnabledFor(logging.DEBUG): + log.debug(">>> %s", self.json_serialize(payload)) - payload: Dict[str, Any] = { - "query": query_str, - } + # Pass post_args to aiohttp post method + if extra_args: + post_args.update(extra_args) - if operation_name: - payload["operationName"] = operation_name + # Add headers for AppSync if requested + if isinstance(self.auth, AppSyncAuthentication): + post_args["headers"] = self.auth.get_headers( + self.json_serialize(payload), + {"content-type": "application/json"}, + ) - if upload_files: + return post_args - # If the upload_files flag is set, then we need variable_values - assert variable_values is not None + def _prepare_file_uploads( + self, request: GraphQLRequest, payload: Dict[str, Any] + ) -> Dict[str, Any]: - # If we upload files, we will extract the files present in the - # variable_values dict and replace them by null values - nulled_variable_values, files = extract_files( - variables=variable_values, - file_classes=self.file_classes, - ) + # If the upload_files flag is set, then we need variable_values + variable_values = request.variable_values + assert variable_values is not None + + # If we upload files, we will extract the files present in the + 
# variable_values dict and replace them by null values + nulled_variable_values, files = extract_files( + variables=variable_values, + file_classes=self.file_classes, + ) + + # Opening the files using the FileVar parameters + open_files(list(files.values()), transport_supports_streaming=True) + self.files = files + + # Save the nulled variable values in the payload + payload["variables"] = nulled_variable_values + + # Prepare aiohttp to send multipart-encoded data + data = aiohttp.FormData() + + # Generate the file map + # path is nested in a list because the spec allows multiple pointers + # to the same file. But we don't support that. + # Will generate something like {"0": ["variables.file"]} + file_map = {str(i): [path] for i, path in enumerate(files)} - # Save the nulled variable values in the payload - payload["variables"] = nulled_variable_values + # Enumerate the file streams + # Will generate something like {'0': FileVar object} + file_vars = {str(i): files[path] for i, path in enumerate(files)} - # Prepare aiohttp to send multipart-encoded data - data = aiohttp.FormData() + # Add the payload to the operations field + operations_str = self.json_serialize(payload) + log.debug("operations %s", operations_str) + data.add_field("operations", operations_str, content_type="application/json") - # Generate the file map - # path is nested in a list because the spec allows multiple pointers - # to the same file. But we don't support that. 
- # Will generate something like {"0": ["variables.file"]} - file_map = {str(i): [path] for i, path in enumerate(files)} + # Add the file map field + file_map_str = self.json_serialize(file_map) + log.debug("file_map %s", file_map_str) + data.add_field("map", file_map_str, content_type="application/json") - # Enumerate the file streams - # Will generate something like {'0': <_io.BufferedReader ...>} - file_streams = {str(i): files[path] for i, path in enumerate(files)} + for k, file_var in file_vars.items(): + assert isinstance(file_var, FileVar) - # Add the payload to the operations field - operations_str = self.json_serialize(payload) - log.debug("operations %s", operations_str) data.add_field( - "operations", operations_str, content_type="application/json" + k, + file_var.f, + filename=file_var.filename, + content_type=file_var.content_type, ) - # Add the file map field - file_map_str = self.json_serialize(file_map) - log.debug("file_map %s", file_map_str) - data.add_field("map", file_map_str, content_type="application/json") + post_args: Dict[str, Any] = {"data": data} - # Add the extracted files as remaining fields - for k, f in file_streams.items(): - name = getattr(f, "name", k) - content_type = getattr(f, "content_type", None) + return post_args - data.add_field(k, f, filename=name, content_type=content_type) + @staticmethod + def _raise_transport_server_error_if_status_more_than_400( + resp: aiohttp.ClientResponse, + ) -> None: + # If the status is >400, + # then we need to raise a TransportServerError + try: + # Raise ClientResponseError if response status is 400 or higher + resp.raise_for_status() + except ClientResponseError as e: + raise TransportServerError(str(e), e.status) from e + + @classmethod + async def _raise_response_error( + cls, + resp: aiohttp.ClientResponse, + reason: str, + ) -> None: + # We raise a TransportServerError if status code is 400 or higher + # We raise a TransportProtocolError in the other cases - post_args: Dict[str, Any] = 
{"data": data} + cls._raise_transport_server_error_if_status_more_than_400(resp) - else: - if variable_values: - payload["variables"] = variable_values + result_text = await resp.text() + raise TransportProtocolError( + f"Server did not return a valid GraphQL result: " + f"{reason}: " + f"{result_text}" + ) - if log.isEnabledFor(logging.INFO): - log.info(">>> %s", self.json_serialize(payload)) + async def _get_json_result(self, response: aiohttp.ClientResponse) -> Any: - post_args = {"json": payload} + # Saving latest response headers in the transport + self.response_headers = response.headers - # Pass post_args to aiohttp post method - if extra_args: - post_args.update(extra_args) + try: + result = await response.json(loads=self.json_deserialize, content_type=None) - # Add headers for AppSync if requested - if isinstance(self.auth, AppSyncAuthentication): - post_args["headers"] = self.auth.get_headers( - self.json_serialize(payload), - {"content-type": "application/json"}, + if log.isEnabledFor(logging.DEBUG): + result_text = await response.text() + log.debug("<<< %s", result_text) + + except Exception: + await self._raise_response_error(response, "Not a JSON answer") + + if result is None: + await self._raise_response_error(response, "Not a JSON answer") + + return result + + async def _prepare_result( + self, response: aiohttp.ClientResponse + ) -> ExecutionResult: + + result = await self._get_json_result(response) + + if "errors" not in result and "data" not in result: + await self._raise_response_error( + response, 'No "data" or "errors" keys in answer' ) - if self.session is None: - raise TransportClosed("Transport is not connected") + return ExecutionResult( + errors=result.get("errors"), + data=result.get("data"), + extensions=result.get("extensions"), + ) - async with self.session.post(self.url, ssl=self.ssl, **post_args) as resp: + async def _prepare_batch_result( + self, + reqs: List[GraphQLRequest], + response: aiohttp.ClientResponse, + ) -> 
List[ExecutionResult]: - # Saving latest response headers in the transport - self.response_headers = resp.headers + answers = await self._get_json_result(response) - async def raise_response_error(resp: aiohttp.ClientResponse, reason: str): - # We raise a TransportServerError if the status code is 400 or higher - # We raise a TransportProtocolError in the other cases + try: + return get_batch_execution_result_list(reqs, answers) + except TransportProtocolError: + # Raise a TransportServerError if status > 400 + self._raise_transport_server_error_if_status_more_than_400(response) + # In other cases, raise a TransportProtocolError + raise - try: - # Raise a ClientResponseError if response status is 400 or higher - resp.raise_for_status() - except ClientResponseError as e: - raise TransportServerError(str(e), e.status) from e - - result_text = await resp.text() - raise TransportProtocolError( - f"Server did not return a GraphQL result: " - f"{reason}: " - f"{result_text}" - ) + async def execute( + self, + request: GraphQLRequest, + *, + extra_args: Optional[Dict[str, Any]] = None, + upload_files: bool = False, + ) -> ExecutionResult: + """Execute the provided request against the configured remote server + using the current session. + This uses the aiohttp library to perform a HTTP POST request asynchronously + to the remote server. - try: - result = await resp.json(loads=self.json_deserialize, content_type=None) + Don't call this coroutine directly on the transport, instead use + :code:`execute` on a client or a session. - if log.isEnabledFor(logging.INFO): - result_text = await resp.text() - log.info("<<< %s", result_text) + :param request: GraphQL request as a + :class:`GraphQLRequest ` object. + :param extra_args: additional arguments to send to the aiohttp post method + :param upload_files: Set to True if you want to put files in the variable values + :returns: an ExecutionResult object. 
+ """ - except Exception: - await raise_response_error(resp, "Not a JSON answer") + if self.session is None: + raise TransportClosed("Transport is not connected") - if result is None: - await raise_response_error(resp, "Not a JSON answer") + post_args = self._prepare_request( + request, + extra_args, + upload_files, + ) - if "errors" not in result and "data" not in result: - await raise_response_error(resp, 'No "data" or "errors" keys in answer') + try: + async with self.session.post(self.url, ssl=self.ssl, **post_args) as resp: + return await self._prepare_result(resp) + except TransportError: + raise + except Exception as e: + raise TransportConnectionFailed(str(e)) from e + finally: + if upload_files: + close_files(list(self.files.values())) + + async def execute_batch( + self, + reqs: List[GraphQLRequest], + extra_args: Optional[Dict[str, Any]] = None, + ) -> List[ExecutionResult]: + """Execute multiple GraphQL requests in a batch. - return ExecutionResult( - errors=result.get("errors"), - data=result.get("data"), - extensions=result.get("extensions"), - ) + Don't call this coroutine directly on the transport, instead use + :code:`execute_batch` on a client or a session. + + :param reqs: GraphQL requests as a list of GraphQLRequest objects. + :param extra_args: additional arguments to send to the aiohttp post method + :return: A list of results of execution. + For every result `data` is the result of executing the query, + `errors` is null if no errors occurred, and is a non-empty array + if an error occurred. 
+ """ + + if self.session is None: + raise TransportClosed("Transport is not connected") + + post_args = self._prepare_request( + reqs, + extra_args, + ) + + try: + async with self.session.post(self.url, ssl=self.ssl, **post_args) as resp: + return await self._prepare_batch_result(reqs, resp) + except TransportError: + raise + except Exception as e: + raise TransportConnectionFailed(str(e)) from e def subscribe( self, - document: DocumentNode, - variable_values: Optional[Dict[str, Any]] = None, - operation_name: Optional[str] = None, + request: GraphQLRequest, ) -> AsyncGenerator[ExecutionResult, None]: """Subscribe is not supported on HTTP. diff --git a/gql/transport/aiohttp_websockets.py b/gql/transport/aiohttp_websockets.py index 18699b5e..59d870f6 100644 --- a/gql/transport/aiohttp_websockets.py +++ b/gql/transport/aiohttp_websockets.py @@ -1,106 +1,26 @@ -import asyncio -import json -import logging -import warnings -from contextlib import suppress from ssl import SSLContext -from typing import ( - Any, - AsyncGenerator, - Collection, - Dict, - Literal, - Mapping, - Optional, - Tuple, - Union, -) +from typing import Any, Dict, List, Literal, Mapping, Optional, Union -import aiohttp -from aiohttp import BasicAuth, Fingerprint, WSMsgType +from aiohttp import BasicAuth, ClientSession, Fingerprint from aiohttp.typedefs import LooseHeaders, StrOrURL -from graphql import DocumentNode, ExecutionResult, print_ast -from multidict import CIMultiDictProxy -from gql.transport.aiohttp import AIOHTTPTransport -from gql.transport.async_transport import AsyncTransport -from gql.transport.exceptions import ( - TransportAlreadyConnected, - TransportClosed, - TransportProtocolError, - TransportQueryError, - TransportServerError, -) +from .common.adapters.aiohttp import AIOHTTPWebSocketsAdapter +from .websockets_protocol import WebsocketsProtocolTransportBase -log = logging.getLogger("gql.transport.aiohttp_websockets") -ParsedAnswer = Tuple[str, Optional[ExecutionResult]] +class 
AIOHTTPWebsocketsTransport(WebsocketsProtocolTransportBase): + """:ref:`Async Transport ` used to execute GraphQL queries on + remote servers with websocket connection. - -class ListenerQueue: - """Special queue used for each query waiting for server answers - - If the server is stopped while the listener is still waiting, - Then we send an exception to the queue and this exception will be raised - to the consumer once all the previous messages have been consumed from the queue + This transport uses asyncio and the provided aiohttp adapter library + in order to send requests on a websocket connection. """ - def __init__(self, query_id: int, send_stop: bool) -> None: - self.query_id: int = query_id - self.send_stop: bool = send_stop - self._queue: asyncio.Queue = asyncio.Queue() - self._closed: bool = False - - async def get(self) -> ParsedAnswer: - - try: - item = self._queue.get_nowait() - except asyncio.QueueEmpty: - item = await self._queue.get() - - self._queue.task_done() - - # If we receive an exception when reading the queue, we raise it - if isinstance(item, Exception): - self._closed = True - raise item - - # Don't need to save new answers or - # send the stop message if we already received the complete message - answer_type, execution_result = item - if answer_type == "complete": - self.send_stop = False - self._closed = True - - return item - - async def put(self, item: ParsedAnswer) -> None: - - if not self._closed: - await self._queue.put(item) - - async def set_exception(self, exception: Exception) -> None: - - # Put the exception in the queue - await self._queue.put(exception) - - # Don't need to send stop messages in case of error - self.send_stop = False - self._closed = True - - -class AIOHTTPWebsocketsTransport(AsyncTransport): - - # This transport supports two subprotocols and will autodetect the - # subprotocol supported on the server - APOLLO_SUBPROTOCOL: str = "graphql-ws" - GRAPHQLWS_SUBPROTOCOL: str = "graphql-transport-ws" - def __init__( 
self, url: StrOrURL, *, - subprotocols: Optional[Collection[str]] = None, + subprotocols: Optional[List[str]] = None, heartbeat: Optional[float] = None, auth: Optional[BasicAuth] = None, origin: Optional[str] = None, @@ -121,8 +41,9 @@ def __init__( ping_interval: Optional[Union[int, float]] = None, pong_timeout: Optional[Union[int, float]] = None, answer_pings: bool = True, + session: Optional[ClientSession] = None, client_session_args: Optional[Dict[str, Any]] = None, - connect_args: Dict[str, Any] = {}, + connect_args: Optional[Dict[str, Any]] = None, ) -> None: """Initialize the transport with the given parameters. @@ -193,6 +114,7 @@ def __init__( :param answer_pings: Whether the client answers the pings from the backend (for the graphql-ws protocol). By default: True + :param session: Optional aiohttp.ClientSession instance. :param client_session_args: Dict of extra args passed to `aiohttp.ClientSession`_ :param connect_args: Dict of extra args passed to @@ -203,986 +125,46 @@ def __init__( .. 
_aiohttp.ClientSession: https://docs.aiohttp.org/en/stable/client_reference.html#aiohttp.ClientSession """ - self.url: StrOrURL = url - self.heartbeat: Optional[float] = heartbeat - self.auth: Optional[BasicAuth] = auth - self.origin: Optional[str] = origin - self.params: Optional[Mapping[str, str]] = params - self.headers: Optional[LooseHeaders] = headers - - self.proxy: Optional[StrOrURL] = proxy - self.proxy_auth: Optional[BasicAuth] = proxy_auth - self.proxy_headers: Optional[LooseHeaders] = proxy_headers - - self.ssl: Optional[Union[SSLContext, Literal[False], Fingerprint]] = ssl - - self.websocket_close_timeout: float = websocket_close_timeout - self.receive_timeout: Optional[float] = receive_timeout - - self.ssl_close_timeout: Optional[Union[int, float]] = ssl_close_timeout - self.connect_timeout: Optional[Union[int, float]] = connect_timeout - self.close_timeout: Optional[Union[int, float]] = close_timeout - self.ack_timeout: Optional[Union[int, float]] = ack_timeout - self.keep_alive_timeout: Optional[Union[int, float]] = keep_alive_timeout - - self.init_payload: Dict[str, Any] = init_payload - - # We need to set an event loop here if there is none - # Or else we will not be able to create an asyncio.Event() - try: - with warnings.catch_warnings(): - warnings.filterwarnings( - "ignore", message="There is no current event loop" - ) - self._loop = asyncio.get_event_loop() - except RuntimeError: - self._loop = asyncio.new_event_loop() - asyncio.set_event_loop(self._loop) - - self._next_keep_alive_message: asyncio.Event = asyncio.Event() - self._next_keep_alive_message.set() - - self.session: Optional[aiohttp.ClientSession] = None - self.websocket: Optional[aiohttp.ClientWebSocketResponse] = None - self.next_query_id: int = 1 - self.listeners: Dict[int, ListenerQueue] = {} - self._connecting: bool = False - self.response_headers: Optional[CIMultiDictProxy[str]] = None - - self.receive_data_task: Optional[asyncio.Future] = None - self.check_keep_alive_task: 
Optional[asyncio.Future] = None - self.close_task: Optional[asyncio.Future] = None - - self._wait_closed: asyncio.Event = asyncio.Event() - self._wait_closed.set() - - self._no_more_listeners: asyncio.Event = asyncio.Event() - self._no_more_listeners.set() - - self.payloads: Dict[str, Any] = {} - - self.ping_interval: Optional[Union[int, float]] = ping_interval - self.pong_timeout: Optional[Union[int, float]] - self.answer_pings: bool = answer_pings - - if ping_interval is not None: - if pong_timeout is None: - self.pong_timeout = ping_interval / 2 - else: - self.pong_timeout = pong_timeout - - self.send_ping_task: Optional[asyncio.Future] = None - - self.ping_received: asyncio.Event = asyncio.Event() - """ping_received is an asyncio Event which will fire each time - a ping is received with the graphql-ws protocol""" - - self.pong_received: asyncio.Event = asyncio.Event() - """pong_received is an asyncio Event which will fire each time - a pong is received with the graphql-ws protocol""" - - self.supported_subprotocols: Collection[str] = subprotocols or ( - self.APOLLO_SUBPROTOCOL, - self.GRAPHQLWS_SUBPROTOCOL, - ) - - self.close_exception: Optional[Exception] = None - - self.client_session_args = client_session_args - self.connect_args = connect_args - - def _parse_answer_graphqlws( - self, answer: Dict[str, Any] - ) -> Tuple[str, Optional[int], Optional[ExecutionResult]]: - """Parse the answer received from the server if the server supports the - graphql-ws protocol. 
- - Returns a list consisting of: - - the answer_type (between: - 'connection_ack', 'ping', 'pong', 'data', 'error', 'complete') - - the answer id (Integer) if received or None - - an execution Result if the answer_type is 'data' or None - - Differences with the apollo websockets protocol (superclass): - - the "data" message is now called "next" - - the "stop" message is now called "complete" - - there is no connection_terminate or connection_error messages - - instead of a unidirectional keep-alive (ka) message from server to client, - there is now the possibility to send bidirectional ping/pong messages - - connection_ack has an optional payload - - the 'error' answer type returns a list of errors instead of a single error - """ - - answer_type: str = "" - answer_id: Optional[int] = None - execution_result: Optional[ExecutionResult] = None - - try: - answer_type = str(answer.get("type")) - - if answer_type in ["next", "error", "complete"]: - answer_id = int(str(answer.get("id"))) - - if answer_type == "next" or answer_type == "error": - - payload = answer.get("payload") - - if answer_type == "next": - - if not isinstance(payload, dict): - raise ValueError("payload is not a dict") - - if "errors" not in payload and "data" not in payload: - raise ValueError( - "payload does not contain 'data' or 'errors' fields" - ) - - execution_result = ExecutionResult( - errors=payload.get("errors"), - data=payload.get("data"), - extensions=payload.get("extensions"), - ) - - # Saving answer_type as 'data' to be understood with superclass - answer_type = "data" - - elif answer_type == "error": - - if not isinstance(payload, list): - raise ValueError("payload is not a list") - - raise TransportQueryError( - str(payload[0]), query_id=answer_id, errors=payload - ) - - elif answer_type in ["ping", "pong", "connection_ack"]: - self.payloads[answer_type] = answer.get("payload", None) - - else: - raise ValueError - - if self.check_keep_alive_task is not None: - 
self._next_keep_alive_message.set() - - except ValueError as e: - raise TransportProtocolError( - f"Server did not return a GraphQL result: {answer}" - ) from e - - return answer_type, answer_id, execution_result - - def _parse_answer_apollo( - self, answer: Dict[str, Any] - ) -> Tuple[str, Optional[int], Optional[ExecutionResult]]: - """Parse the answer received from the server if the server supports the - apollo websockets protocol. - - Returns a list consisting of: - - the answer_type (between: - 'connection_ack', 'ka', 'connection_error', 'data', 'error', 'complete') - - the answer id (Integer) if received or None - - an execution Result if the answer_type is 'data' or None - """ - - answer_type: str = "" - answer_id: Optional[int] = None - execution_result: Optional[ExecutionResult] = None - - try: - answer_type = str(answer.get("type")) - - if answer_type in ["data", "error", "complete"]: - answer_id = int(str(answer.get("id"))) - - if answer_type == "data" or answer_type == "error": - - payload = answer.get("payload") - - if not isinstance(payload, dict): - raise ValueError("payload is not a dict") - - if answer_type == "data": - - if "errors" not in payload and "data" not in payload: - raise ValueError( - "payload does not contain 'data' or 'errors' fields" - ) - - execution_result = ExecutionResult( - errors=payload.get("errors"), - data=payload.get("data"), - extensions=payload.get("extensions"), - ) - - elif answer_type == "error": - - raise TransportQueryError( - str(payload), query_id=answer_id, errors=[payload] - ) - - elif answer_type == "ka": - # Keep-alive message - if self.check_keep_alive_task is not None: - self._next_keep_alive_message.set() - elif answer_type == "connection_ack": - pass - elif answer_type == "connection_error": - error_payload = answer.get("payload") - raise TransportServerError(f"Server error: '{repr(error_payload)}'") - else: - raise ValueError - - except ValueError as e: - raise TransportProtocolError( - f"Server did not 
return a GraphQL result: {answer}" - ) from e - - return answer_type, answer_id, execution_result - - def _parse_answer( - self, answer: str - ) -> Tuple[str, Optional[int], Optional[ExecutionResult]]: - """Parse the answer received from the server depending on - the detected subprotocol. - """ - try: - json_answer = json.loads(answer) - except ValueError: - raise TransportProtocolError( - f"Server did not return a GraphQL result: {answer}" - ) - - if self.subprotocol == self.GRAPHQLWS_SUBPROTOCOL: - return self._parse_answer_graphqlws(json_answer) - - return self._parse_answer_apollo(json_answer) - - async def _wait_ack(self) -> None: - """Wait for the connection_ack message. Keep alive messages are ignored""" - - while True: - init_answer = await self._receive() - - answer_type, _, _ = self._parse_answer(init_answer) - - if answer_type == "connection_ack": - return - - if answer_type != "ka": - raise TransportProtocolError( - "Websocket server did not return a connection ack" - ) - - async def _send_init_message_and_wait_ack(self) -> None: - """Send init message to the provided websocket and wait for the connection ACK. - - If the answer is not a connection_ack message, we will return an Exception. - """ - - init_message = {"type": "connection_init", "payload": self.init_payload} - - await self._send(init_message) - - # Wait for the connection_ack message or raise a TimeoutError - await asyncio.wait_for(self._wait_ack(), self.ack_timeout) - - async def _initialize(self): - """Hook to send the initialization messages after the connection - and potentially wait for the backend ack. - """ - await self._send_init_message_and_wait_ack() - - async def _stop_listener(self, query_id: int): - """Hook to stop to listen to a specific query. - Will send a stop message in some subclasses. 
- """ - log.debug(f"stop listener {query_id}") - - if self.subprotocol == self.GRAPHQLWS_SUBPROTOCOL: - await self._send_complete_message(query_id) - await self.listeners[query_id].put(("complete", None)) - else: - await self._send_stop_message(query_id) - - async def _after_connect(self): - """Hook to add custom code for subclasses after the connection - has been established. - """ - # Find the backend subprotocol returned in the response headers - response_headers = self.websocket._response.headers - log.debug(f"Response headers: {response_headers!r}") - try: - self.subprotocol = response_headers["Sec-WebSocket-Protocol"] - except KeyError: - self.subprotocol = self.APOLLO_SUBPROTOCOL - - log.debug(f"backend subprotocol returned: {self.subprotocol!r}") - - async def send_ping(self, payload: Optional[Any] = None) -> None: - """Send a ping message for the graphql-ws protocol""" - - ping_message = {"type": "ping"} - - if payload is not None: - ping_message["payload"] = payload - - await self._send(ping_message) - - async def send_pong(self, payload: Optional[Any] = None) -> None: - """Send a pong message for the graphql-ws protocol""" - - pong_message = {"type": "pong"} - - if payload is not None: - pong_message["payload"] = payload - - await self._send(pong_message) - - async def _send_stop_message(self, query_id: int) -> None: - """Send stop message to the provided websocket connection and query_id. - - The server should afterwards return a 'complete' message. - """ - - stop_message = {"id": str(query_id), "type": "stop"} - - await self._send(stop_message) - - async def _send_complete_message(self, query_id: int) -> None: - """Send a complete message for the provided query_id. - - This is only for the graphql-ws protocol. - """ - - complete_message = {"id": str(query_id), "type": "complete"} - - await self._send(complete_message) - - async def _send_ping_coro(self) -> None: - """Coroutine to periodically send a ping from the client to the backend. 
- - Only used for the graphql-ws protocol. - - Send a ping every ping_interval seconds. - Close the connection if a pong is not received within pong_timeout seconds. - """ - - assert self.ping_interval is not None - - try: - while True: - await asyncio.sleep(self.ping_interval) - - await self.send_ping() - - await asyncio.wait_for(self.pong_received.wait(), self.pong_timeout) - - # Reset for the next iteration - self.pong_received.clear() - - except asyncio.TimeoutError: - # No pong received in the appriopriate time, close with error - # If the timeout happens during a close already in progress, do nothing - if self.close_task is None: - await self._fail( - TransportServerError( - f"No pong received after {self.pong_timeout!r} seconds" - ), - clean_close=False, - ) - - async def _after_initialize(self): - """Hook to add custom code for subclasses after the initialization - has been done. - """ - - # If requested, create a task to send periodic pings to the backend - if ( - self.subprotocol == self.GRAPHQLWS_SUBPROTOCOL - and self.ping_interval is not None - ): - - self.send_ping_task = asyncio.ensure_future(self._send_ping_coro()) - - async def _close_hook(self): - """Hook to add custom code for subclasses for the connection close""" - # Properly shut down the send ping task if enabled - if self.send_ping_task is not None: - self.send_ping_task.cancel() - with suppress(asyncio.CancelledError): - await self.send_ping_task - self.send_ping_task = None - - async def _connection_terminate(self): - """Hook to add custom code for subclasses after the initialization - has been done. - """ - if self.subprotocol == self.APOLLO_SUBPROTOCOL: - await self._send_connection_terminate_message() - async def _send_connection_terminate_message(self) -> None: - """Send a connection_terminate message to the provided websocket connection. - - This message indicates that the connection will disconnect. 
- """ - - connection_terminate_message = {"type": "connection_terminate"} - - await self._send(connection_terminate_message) - - async def _send_query( - self, - document: DocumentNode, - variable_values: Optional[Dict[str, Any]] = None, - operation_name: Optional[str] = None, - ) -> int: - """Send a query to the provided websocket connection. - - We use an incremented id to reference the query. - - Returns the used id for this query. - """ - - query_id = self.next_query_id - self.next_query_id += 1 - - payload: Dict[str, Any] = {"query": print_ast(document)} - if variable_values: - payload["variables"] = variable_values - if operation_name: - payload["operationName"] = operation_name - - query_type = "start" - - if self.subprotocol == self.GRAPHQLWS_SUBPROTOCOL: - query_type = "subscribe" - - query = {"id": str(query_id), "type": query_type, "payload": payload} - - await self._send(query) - - return query_id - - async def _send(self, message: Dict[str, Any]) -> None: - """Send the provided message to the websocket connection and log the message""" - - if self.websocket is None: - raise TransportClosed("WebSocket connection is closed") - - try: - await self.websocket.send_json(message) - log.info(">>> %s", message) - except ConnectionResetError as e: - await self._fail(e, clean_close=False) - raise e - - async def _receive(self) -> str: - """Wait the next message from the websocket connection and log the answer""" - - # It is possible that the websocket has been already closed in another task - if self.websocket is None: - raise TransportClosed("Transport is already closed") - - while True: - ws_message = await self.websocket.receive() - - # Ignore low-level ping and pong received - if ws_message.type not in (WSMsgType.PING, WSMsgType.PONG): - break - - if ws_message.type in ( - WSMsgType.CLOSE, - WSMsgType.CLOSED, - WSMsgType.CLOSING, - WSMsgType.ERROR, - ): - raise ConnectionResetError - elif ws_message.type is WSMsgType.BINARY: - raise 
TransportProtocolError("Binary data received in the websocket") - - assert ws_message.type is WSMsgType.TEXT - - answer: str = ws_message.data - - log.info("<<< %s", answer) - - return answer - - def _remove_listener(self, query_id) -> None: - """After exiting from a subscription, remove the listener and - signal an event if this was the last listener for the client. - """ - if query_id in self.listeners: - del self.listeners[query_id] - - remaining = len(self.listeners) - log.debug(f"listener {query_id} deleted, {remaining} remaining") - - if remaining == 0: - self._no_more_listeners.set() - - async def _check_ws_liveness(self) -> None: - """Coroutine which will periodically check the liveness of the connection - through keep-alive messages - """ - - try: - while True: - await asyncio.wait_for( - self._next_keep_alive_message.wait(), self.keep_alive_timeout - ) - - # Reset for the next iteration - self._next_keep_alive_message.clear() - - except asyncio.TimeoutError: - # No keep-alive message in the appriopriate interval, close with error - # while trying to notify the server of a proper close (in case - # the keep-alive interval of the client or server was not aligned - # the connection still remains) - - # If the timeout happens during a close already in progress, do nothing - if self.close_task is None: - await self._fail( - TransportServerError( - "No keep-alive message has been received within " - "the expected interval ('keep_alive_timeout' parameter)" - ), - clean_close=False, - ) - - except asyncio.CancelledError: - # The client is probably closing, handle it properly - pass - - async def _handle_answer( - self, - answer_type: str, - answer_id: Optional[int], - execution_result: Optional[ExecutionResult], - ) -> None: - - try: - # Put the answer in the queue - if answer_id is not None: - await self.listeners[answer_id].put((answer_type, execution_result)) - except KeyError: - # Do nothing if no one is listening to this query_id. 
- pass - - # Answer pong to ping for graphql-ws protocol - if answer_type == "ping": - self.ping_received.set() - if self.answer_pings: - await self.send_pong() - - elif answer_type == "pong": - self.pong_received.set() - - async def _receive_data_loop(self) -> None: - """Main asyncio task which will listen to the incoming messages and will - call the parse_answer and handle_answer methods of the subclass.""" - log.debug("Entering _receive_data_loop()") - - try: - while True: - - # Wait the next answer from the websocket server - try: - answer = await self._receive() - except (ConnectionResetError, TransportProtocolError) as e: - await self._fail(e, clean_close=False) - break - except TransportClosed as e: - await self._fail(e, clean_close=False) - raise e - - # Parse the answer - try: - answer_type, answer_id, execution_result = self._parse_answer( - answer - ) - except TransportQueryError as e: - # Received an exception for a specific query - # ==> Add an exception to this query queue - # The exception is raised for this specific query, - # but the transport is not closed. - assert isinstance( - e.query_id, int - ), "TransportQueryError should have a query_id defined here" - try: - await self.listeners[e.query_id].set_exception(e) - except KeyError: - # Do nothing if no one is listening to this query_id - pass - - continue - - except (TransportServerError, TransportProtocolError) as e: - # Received a global exception for this transport - # ==> close the transport - # The exception will be raised for all current queries. 
- await self._fail(e, clean_close=False) - break - - await self._handle_answer(answer_type, answer_id, execution_result) - - finally: - log.debug("Exiting _receive_data_loop()") - - async def connect(self) -> None: - log.debug("connect: starting") - - if self.session is None: - client_session_args: Dict[str, Any] = {} - - # Adding custom parameters passed from init - if self.client_session_args: - client_session_args.update(self.client_session_args) # type: ignore - - self.session = aiohttp.ClientSession(**client_session_args) - - if self.websocket is None and not self._connecting: - self._connecting = True - - connect_args: Dict[str, Any] = { - "url": self.url, - "headers": self.headers, - "auth": self.auth, - "heartbeat": self.heartbeat, - "origin": self.origin, - "params": self.params, - "protocols": self.supported_subprotocols, - "proxy": self.proxy, - "proxy_auth": self.proxy_auth, - "proxy_headers": self.proxy_headers, - "timeout": self.websocket_close_timeout, - "receive_timeout": self.receive_timeout, - } - - if self.ssl is not None: - connect_args.update( - { - "ssl": self.ssl, - } - ) - - # Adding custom parameters passed from init - if self.connect_args: - connect_args.update(self.connect_args) - - try: - # Connection to the specified url - # Generate a TimeoutError if taking more than connect_timeout seconds - # Set the _connecting flag to False after in all cases - self.websocket = await asyncio.wait_for( - self.session.ws_connect( - **connect_args, - ), - self.connect_timeout, - ) - finally: - self._connecting = False - - self.response_headers = self.websocket._response.headers - - await self._after_connect() - - self.next_query_id = 1 - self.close_exception = None - self._wait_closed.clear() - - # Send the init message and wait for the ack from the server - # Note: This should generate a TimeoutError - # if no ACKs are received within the ack_timeout - try: - await self._initialize() - except ConnectionResetError as e: - raise e - except ( - 
TransportProtocolError, - TransportServerError, - asyncio.TimeoutError, - ) as e: - await self._fail(e, clean_close=False) - raise e - - # Run the after_init hook of the subclass - await self._after_initialize() - - # If specified, create a task to check liveness of the connection - # through keep-alive messages - if self.keep_alive_timeout is not None: - self.check_keep_alive_task = asyncio.ensure_future( - self._check_ws_liveness() - ) - - # Create a task to listen to the incoming websocket messages - self.receive_data_task = asyncio.ensure_future(self._receive_data_loop()) - - else: - raise TransportAlreadyConnected("Transport is already connected") - - log.debug("connect: done") - - async def _clean_close(self) -> None: - """Coroutine which will: - - - send stop messages for each active subscription to the server - - send the connection terminate message - """ - log.debug(f"Listeners: {self.listeners}") - - # Send 'stop' message for all current queries - for query_id, listener in self.listeners.items(): - print(f"Listener {query_id} send_stop: {listener.send_stop}") - - if listener.send_stop: - await self._stop_listener(query_id) - listener.send_stop = False - - # Wait that there is no more listeners (we received 'complete' for all queries) - try: - await asyncio.wait_for(self._no_more_listeners.wait(), self.close_timeout) - except asyncio.TimeoutError: # pragma: no cover - log.debug("Timer close_timeout fired") - - # Calling the subclass hook - await self._connection_terminate() - - async def _close_coro(self, e: Exception, clean_close: bool = True) -> None: - """Coroutine which will: - - - do a clean_close if possible: - - send stop messages for each active query to the server - - send the connection terminate message - - close the websocket connection - - send the exception to all the remaining listeners - """ - - log.debug("_close_coro: starting") - - try: - - try: - # Properly shut down liveness checker if enabled - if self.check_keep_alive_task is not 
None: - # More info: https://stackoverflow.com/a/43810272/1113207 - self.check_keep_alive_task.cancel() - with suppress(asyncio.CancelledError): - await self.check_keep_alive_task - except Exception as exc: # pragma: no cover - log.warning( - "_close_coro cancel keep alive task exception: " + repr(exc) - ) - - try: - # Calling the subclass close hook - await self._close_hook() - except Exception as exc: # pragma: no cover - log.warning("_close_coro close_hook exception: " + repr(exc)) - - # Saving exception to raise it later if trying to use the transport - # after it has already closed. - self.close_exception = e - - if clean_close: - log.debug("_close_coro: starting clean_close") - try: - await self._clean_close() - except Exception as exc: # pragma: no cover - log.warning("Ignoring exception in _clean_close: " + repr(exc)) - - log.debug("_close_coro: sending exception to listeners") - - # Send an exception to all remaining listeners - for query_id, listener in self.listeners.items(): - await listener.set_exception(e) - - log.debug("_close_coro: close websocket connection") - - try: - assert self.websocket is not None - - await self.websocket.close() - self.websocket = None - except Exception as exc: - log.warning("_close_coro websocket close exception: " + repr(exc)) - - log.debug("_close_coro: close aiohttp session") - - if ( - self.client_session_args - and self.client_session_args.get("connector_owner") is False - ): - - log.debug("connector_owner is False -> not closing connector") - - else: - try: - assert self.session is not None - - closed_event = AIOHTTPTransport.create_aiohttp_closed_event( - self.session - ) - await self.session.close() - try: - await asyncio.wait_for( - closed_event.wait(), self.ssl_close_timeout - ) - except asyncio.TimeoutError: - pass - except Exception as exc: # pragma: no cover - log.warning("_close_coro session close exception: " + repr(exc)) - - self.session = None - - log.debug("_close_coro: aiohttp session closed") - - try: - 
assert self.receive_data_task is not None - - self.receive_data_task.cancel() - with suppress(asyncio.CancelledError): - await self.receive_data_task - except Exception as exc: # pragma: no cover - log.warning( - "_close_coro cancel receive data task exception: " + repr(exc) - ) - - except Exception as exc: # pragma: no cover - log.warning("Exception catched in _close_coro: " + repr(exc)) - - finally: - - log.debug("_close_coro: final cleanup") - - self.websocket = None - self.close_task = None - self.check_keep_alive_task = None - self.receive_data_task = None - self._wait_closed.set() - - log.debug("_close_coro: exiting") - - async def _fail(self, e: Exception, clean_close: bool = True) -> None: - log.debug("_fail: starting with exception: " + repr(e)) - - if self.close_task is None: - - if self._wait_closed.is_set(): - log.debug("_fail started but transport is already closed") - else: - self.close_task = asyncio.shield( - asyncio.ensure_future(self._close_coro(e, clean_close=clean_close)) - ) - else: - log.debug( - "close_task is not None in _fail. Previous exception is: " - + repr(self.close_exception) - + " New exception is: " - + repr(e) - ) - - async def close(self) -> None: - log.debug("close: starting") - - await self._fail(TransportClosed("Websocket GraphQL transport closed by user")) - await self.wait_closed() - - log.debug("close: done") - - async def wait_closed(self) -> None: - log.debug("wait_close: starting") - - if not self._wait_closed.is_set(): - await self._wait_closed.wait() - - log.debug("wait_close: done") - - async def execute( - self, - document: DocumentNode, - variable_values: Optional[Dict[str, Any]] = None, - operation_name: Optional[str] = None, - ) -> ExecutionResult: - """Execute the provided document AST against the configured remote server - using the current session. - - Send a query but close the async generator as soon as we have the first answer. - - The result is sent as an ExecutionResult object. 
- """ - first_result = None - - generator = self.subscribe( - document, variable_values, operation_name, send_stop=False + # Instanciate a AIOHTTPWebSocketAdapter to indicate the use + # of the aiohttp dependency for this transport + self.adapter: AIOHTTPWebSocketsAdapter = AIOHTTPWebSocketsAdapter( + url=url, + headers=headers, + ssl=ssl, + session=session, + client_session_args=client_session_args, + connect_args=connect_args, + heartbeat=heartbeat, + auth=auth, + origin=origin, + params=params, + proxy=proxy, + proxy_auth=proxy_auth, + proxy_headers=proxy_headers, + websocket_close_timeout=websocket_close_timeout, + receive_timeout=receive_timeout, + ssl_close_timeout=ssl_close_timeout, ) - async for result in generator: - first_result = result - break - - if first_result is None: - raise TransportQueryError( - "Query completed without any answer received from the server" - ) - - return first_result - - async def subscribe( - self, - document: DocumentNode, - variable_values: Optional[Dict[str, Any]] = None, - operation_name: Optional[str] = None, - send_stop: Optional[bool] = True, - ) -> AsyncGenerator[ExecutionResult, None]: - """Send a query and receive the results using a python async generator. - - The query can be a graphql query, mutation or subscription. - - The results are sent as an ExecutionResult object. 
- """ - - # Send the query and receive the id - query_id: int = await self._send_query( - document, variable_values, operation_name + # Initialize the WebsocketsProtocolTransportBase parent class + super().__init__( + adapter=self.adapter, + init_payload=init_payload, + connect_timeout=connect_timeout, + close_timeout=close_timeout, + ack_timeout=ack_timeout, + keep_alive_timeout=keep_alive_timeout, + ping_interval=ping_interval, + pong_timeout=pong_timeout, + answer_pings=answer_pings, + subprotocols=subprotocols, ) - # Create a queue to receive the answers for this query_id - listener = ListenerQueue(query_id, send_stop=(send_stop is True)) - self.listeners[query_id] = listener - - # We will need to wait at close for this query to clean properly - self._no_more_listeners.clear() - - try: - # Loop over the received answers - while True: - - # Wait for the answer from the queue of this query_id - # This can raise a TransportError or ConnectionClosed exception. - answer_type, execution_result = await listener.get() - - # If the received answer contains data, - # Then we will yield the results back as an ExecutionResult object - if execution_result is not None: - yield execution_result - - # If we receive a 'complete' answer from the server, - # Then we will end this async generator output without errors - elif answer_type == "complete": - log.debug( - f"Complete received for query {query_id} --> exit without error" - ) - break - - except (asyncio.CancelledError, GeneratorExit) as e: - log.debug(f"Exception in subscribe: {e!r}") - if listener.send_stop: - await self._stop_listener(query_id) - listener.send_stop = False + @property + def headers(self) -> Optional[LooseHeaders]: + return self.adapter.headers - finally: - log.debug(f"In subscribe finally for query_id {query_id}") - self._remove_listener(query_id) + @property + def ssl(self) -> Optional[Union[SSLContext, Literal[False], Fingerprint]]: + return self.adapter.ssl diff --git 
a/gql/transport/appsync_websockets.py b/gql/transport/appsync_websockets.py index 66091747..e2ab4f96 100644 --- a/gql/transport/appsync_websockets.py +++ b/gql/transport/appsync_websockets.py @@ -4,11 +4,14 @@ from typing import Any, Dict, Optional, Tuple, Union, cast from urllib.parse import urlparse -from graphql import DocumentNode, ExecutionResult, print_ast +from graphql import ExecutionResult +from ..graphql_request import GraphQLRequest from .appsync_auth import AppSyncAuthentication, AppSyncIAMAuthentication +from .common.adapters.websockets import WebSocketsAdapter +from .common.base import SubscriptionTransportBase from .exceptions import TransportProtocolError, TransportServerError -from .websockets import WebsocketsTransport, WebsocketsTransportBase +from .websockets import WebsocketsTransport log = logging.getLogger("gql.transport.appsync") @@ -19,7 +22,7 @@ pass -class AppSyncWebsocketsTransport(WebsocketsTransportBase): +class AppSyncWebsocketsTransport(SubscriptionTransportBase): """:ref:`Async Transport ` used to execute GraphQL subscription on AWS appsync realtime endpoint. @@ -27,11 +30,12 @@ class AppSyncWebsocketsTransport(WebsocketsTransportBase): on a websocket connection. 
""" - auth: Optional[AppSyncAuthentication] + auth: AppSyncAuthentication def __init__( self, url: str, + *, auth: Optional[AppSyncAuthentication] = None, session: Optional["botocore.session.Session"] = None, ssl: Union[SSLContext, bool] = False, @@ -69,22 +73,30 @@ def __init__( # May raise NoRegionError or NoCredentialsError or ImportError auth = AppSyncIAMAuthentication(host=host, session=session) - self.auth = auth + self.auth: AppSyncAuthentication = auth + self.ack_timeout: Optional[Union[int, float]] = ack_timeout + self.init_payload: Dict[str, Any] = {} url = self.auth.get_auth_url(url) - super().__init__( - url, + # Instanciate a WebSocketAdapter to indicate the use + # of the websockets dependency for this transport + self.adapter: WebSocketsAdapter = WebSocketsAdapter( + url=url, ssl=ssl, + connect_args=connect_args, + ) + + # Initialize the generic SubscriptionTransportBase parent class + super().__init__( + adapter=self.adapter, connect_timeout=connect_timeout, close_timeout=close_timeout, - ack_timeout=ack_timeout, keep_alive_timeout=keep_alive_timeout, - connect_args=connect_args, ) # Using the same 'graphql-ws' protocol as the apollo protocol - self.supported_subprotocols = [ + self.adapter.subprotocols = [ WebsocketsTransport.APOLLO_SUBPROTOCOL, ] self.subprotocol = WebsocketsTransport.APOLLO_SUBPROTOCOL @@ -139,22 +151,14 @@ def _parse_answer( async def _send_query( self, - document: DocumentNode, - variable_values: Optional[Dict[str, Any]] = None, - operation_name: Optional[str] = None, + request: GraphQLRequest, ) -> int: query_id = self.next_query_id self.next_query_id += 1 - data: Dict = {"query": print_ast(document)} - - if variable_values: - data["variables"] = variable_values - - if operation_name: - data["operationName"] = operation_name + data: Dict[str, Any] = request.payload serialized_data = json.dumps(data, separators=(",", ":")) @@ -181,7 +185,7 @@ async def _send_query( return query_id - subscribe = WebsocketsTransportBase.subscribe 
+ subscribe = SubscriptionTransportBase.subscribe # type: ignore[assignment] """Send a subscription query and receive the results using a python async generator. @@ -192,9 +196,7 @@ async def _send_query( async def execute( self, - document: DocumentNode, - variable_values: Optional[Dict[str, Any]] = None, - operation_name: Optional[str] = None, + request: GraphQLRequest, ) -> ExecutionResult: """This method is not available. @@ -212,3 +214,7 @@ async def execute( WebsocketsTransport._send_init_message_and_wait_ack ) _wait_ack = WebsocketsTransport._wait_ack + + @property + def ssl(self) -> Union[SSLContext, bool]: + return self.adapter.ssl diff --git a/gql/transport/async_transport.py b/gql/transport/async_transport.py index 4cecc9f9..526c97ba 100644 --- a/gql/transport/async_transport.py +++ b/gql/transport/async_transport.py @@ -1,7 +1,9 @@ import abc -from typing import Any, AsyncGenerator, Dict, Optional +from typing import Any, AsyncGenerator, List -from graphql import DocumentNode, ExecutionResult +from graphql import ExecutionResult + +from ..graphql_request import GraphQLRequest class AsyncTransport(abc.ABC): @@ -22,22 +24,35 @@ async def close(self): @abc.abstractmethod async def execute( self, - document: DocumentNode, - variable_values: Optional[Dict[str, Any]] = None, - operation_name: Optional[str] = None, + request: GraphQLRequest, ) -> ExecutionResult: - """Execute the provided document AST for either a remote or local GraphQL + """Execute the provided request for either a remote or local GraphQL Schema.""" raise NotImplementedError( "Any AsyncTransport subclass must implement execute method" ) # pragma: no cover + async def execute_batch( + self, + reqs: List[GraphQLRequest], + *args: Any, + **kwargs: Any, + ) -> List[ExecutionResult]: + """Execute multiple GraphQL requests in a batch. + + Execute the provided requests for either a remote or local GraphQL Schema. + + :param reqs: GraphQL requests as a list of GraphQLRequest objects. 
+ :return: a list of ExecutionResult objects + """ + raise NotImplementedError( + "This Transport has not implemented the execute_batch method" + ) # pragma: no cover + @abc.abstractmethod def subscribe( self, - document: DocumentNode, - variable_values: Optional[Dict[str, Any]] = None, - operation_name: Optional[str] = None, + request: GraphQLRequest, ) -> AsyncGenerator[ExecutionResult, None]: """Send a query and receive the results using an async generator diff --git a/gql/transport/common/__init__.py b/gql/transport/common/__init__.py new file mode 100644 index 00000000..a60ce0b0 --- /dev/null +++ b/gql/transport/common/__init__.py @@ -0,0 +1,10 @@ +from .adapters import AdapterConnection +from .base import SubscriptionTransportBase +from .listener_queue import ListenerQueue, ParsedAnswer + +__all__ = [ + "AdapterConnection", + "ListenerQueue", + "ParsedAnswer", + "SubscriptionTransportBase", +] diff --git a/gql/transport/common/adapters/__init__.py b/gql/transport/common/adapters/__init__.py new file mode 100644 index 00000000..593c46b6 --- /dev/null +++ b/gql/transport/common/adapters/__init__.py @@ -0,0 +1,3 @@ +from .connection import AdapterConnection + +__all__ = ["AdapterConnection"] diff --git a/gql/transport/common/adapters/aiohttp.py b/gql/transport/common/adapters/aiohttp.py new file mode 100644 index 00000000..d2e1a346 --- /dev/null +++ b/gql/transport/common/adapters/aiohttp.py @@ -0,0 +1,278 @@ +import asyncio +import logging +from ssl import SSLContext +from typing import Any, Dict, Literal, Mapping, Optional, Union + +import aiohttp +from aiohttp import BasicAuth, ClientWSTimeout, Fingerprint, WSMsgType +from aiohttp.typedefs import LooseHeaders, StrOrURL +from multidict import CIMultiDictProxy + +from ...exceptions import TransportConnectionFailed, TransportProtocolError +from ..aiohttp_closed_event import create_aiohttp_closed_event +from .connection import AdapterConnection + +log = logging.getLogger("gql.transport.common.adapters.aiohttp") + 
+ +class AIOHTTPWebSocketsAdapter(AdapterConnection): + """AdapterConnection implementation using the aiohttp library.""" + + def __init__( + self, + url: StrOrURL, + *, + headers: Optional[LooseHeaders] = None, + ssl: Optional[Union[SSLContext, Literal[False], Fingerprint]] = None, + session: Optional[aiohttp.ClientSession] = None, + client_session_args: Optional[Dict[str, Any]] = None, + connect_args: Optional[Dict[str, Any]] = None, + heartbeat: Optional[float] = None, + auth: Optional[BasicAuth] = None, + origin: Optional[str] = None, + params: Optional[Mapping[str, str]] = None, + proxy: Optional[StrOrURL] = None, + proxy_auth: Optional[BasicAuth] = None, + proxy_headers: Optional[LooseHeaders] = None, + websocket_close_timeout: float = 10.0, + receive_timeout: Optional[float] = None, + ssl_close_timeout: Optional[Union[int, float]] = 10, + ) -> None: + """Initialize the transport with the given parameters. + + :param url: The GraphQL server URL. Example: 'wss://server.com:PORT/graphql'. + :param headers: Dict of HTTP Headers. + :param ssl: SSL validation mode. ``True`` for default SSL check + (:func:`ssl.create_default_context` is used), + ``False`` for skip SSL certificate validation, + :class:`aiohttp.Fingerprint` for fingerprint + validation, :class:`ssl.SSLContext` for custom SSL + certificate validation. + :param session: Optional aiohttp opened session. + :param client_session_args: Dict of extra args passed to + :class:`aiohttp.ClientSession` + :param connect_args: Dict of extra args passed to + :meth:`aiohttp.ClientSession.ws_connect` + + :param float heartbeat: Send low level `ping` message every `heartbeat` + seconds and wait `pong` response, close + connection if `pong` response is not + received. The timer is reset on any data reception. + :param auth: An object that represents HTTP Basic Authorization. 
+ :class:`~aiohttp.BasicAuth` (optional) + :param str origin: Origin header to send to server (optional) + :param params: Mapping, iterable of tuple of *key*/*value* pairs or + string to be sent as parameters in the query + string of the new request. Ignored for subsequent + redirected requests (optional) + + Allowed values are: + + - :class:`collections.abc.Mapping` e.g. :class:`dict`, + :class:`multidict.MultiDict` or + :class:`multidict.MultiDictProxy` + - :class:`collections.abc.Iterable` e.g. :class:`tuple` or + :class:`list` + - :class:`str` with preferably url-encoded content + (**Warning:** content will not be encoded by *aiohttp*) + :param proxy: Proxy URL, :class:`str` or :class:`~yarl.URL` (optional) + :param aiohttp.BasicAuth proxy_auth: an object that represents proxy HTTP + Basic Authorization (optional) + :param float websocket_close_timeout: Timeout for websocket to close. + ``10`` seconds by default + :param float receive_timeout: Timeout for websocket to receive + complete message. 
``None`` (unlimited) + seconds by default + :param ssl_close_timeout: Timeout in seconds to wait for the ssl connection + to close properly + """ + super().__init__( + url=str(url), + connect_args=connect_args, + ) + + self._headers: Optional[LooseHeaders] = headers + self.ssl: Optional[Union[SSLContext, Literal[False], Fingerprint]] = ssl + + self.session: Optional[aiohttp.ClientSession] = session + self._using_external_session = True if self.session else False + + if client_session_args is None: + client_session_args = {} + self.client_session_args = client_session_args + + self.heartbeat: Optional[float] = heartbeat + self.auth: Optional[BasicAuth] = auth + self.origin: Optional[str] = origin + self.params: Optional[Mapping[str, str]] = params + + self.proxy: Optional[StrOrURL] = proxy + self.proxy_auth: Optional[BasicAuth] = proxy_auth + self.proxy_headers: Optional[LooseHeaders] = proxy_headers + + self.websocket_close_timeout: float = websocket_close_timeout + self.receive_timeout: Optional[float] = receive_timeout + + self.ssl_close_timeout: Optional[Union[int, float]] = ssl_close_timeout + + self.websocket: Optional[aiohttp.ClientWebSocketResponse] = None + self._response_headers: Optional[CIMultiDictProxy[str]] = None + + async def connect(self) -> None: + """Connect to the WebSocket server.""" + + assert self.websocket is None + + # Create a session if necessary + if self.session is None: + client_session_args: Dict[str, Any] = {} + + # Adding custom parameters passed from init + client_session_args.update(self.client_session_args) # type: ignore + + self.session = aiohttp.ClientSession(**client_session_args) + + ws_timeout = ClientWSTimeout( + ws_receive=self.receive_timeout, + ws_close=self.websocket_close_timeout, + ) + + connect_args: Dict[str, Any] = { + "url": self.url, + "headers": self.headers, + "auth": self.auth, + "heartbeat": self.heartbeat, + "origin": self.origin, + "params": self.params, + "proxy": self.proxy, + "proxy_auth": 
self.proxy_auth, + "proxy_headers": self.proxy_headers, + "timeout": ws_timeout, + } + + if self.subprotocols: + connect_args["protocols"] = self.subprotocols + + if self.ssl is not None: + connect_args["ssl"] = self.ssl + + # Adding custom parameters passed from init + connect_args.update(self.connect_args) + + try: + self.websocket = await self.session.ws_connect( + **connect_args, + ) + except Exception as e: + raise TransportConnectionFailed("Connect failed") from e + + self._response_headers = self.websocket._response.headers + + async def send(self, message: str) -> None: + """Send message to the WebSocket server. + + Args: + message: String message to send + + Raises: + TransportConnectionFailed: If connection closed + """ + if self.websocket is None: + raise TransportConnectionFailed("WebSocket connection is already closed") + + try: + await self.websocket.send_str(message) + except Exception as e: + raise TransportConnectionFailed( + f"Error trying to send data: {type(e).__name__}" + ) from e + + async def receive(self) -> str: + """Receive message from the WebSocket server. 
+ + Returns: + String message received + + Raises: + TransportConnectionFailed: If connection closed + TransportProtocolError: If protocol error or binary data received + """ + # It is possible that the websocket has been already closed in another task + if self.websocket is None: + raise TransportConnectionFailed("Connection is already closed") + + while True: + # Should not raise any exception: + # https://docs.aiohttp.org/en/stable/_modules/aiohttp/client_ws.html + # #ClientWebSocketResponse.receive + ws_message = await self.websocket.receive() + + # Ignore low-level ping and pong received + if ws_message.type not in (WSMsgType.PING, WSMsgType.PONG): + break + + if ws_message.type in ( + WSMsgType.CLOSE, + WSMsgType.CLOSED, + WSMsgType.CLOSING, + WSMsgType.ERROR, + ): + raise TransportConnectionFailed("Connection was closed") + elif ws_message.type is WSMsgType.BINARY: + raise TransportProtocolError("Binary data received in the websocket") + + assert ws_message.type is WSMsgType.TEXT + + answer: str = ws_message.data + + return answer + + async def _close_session(self) -> None: + """Close the aiohttp session.""" + + assert self.session is not None + + closed_event = create_aiohttp_closed_event(self.session) + await self.session.close() + try: + await asyncio.wait_for(closed_event.wait(), self.ssl_close_timeout) + except asyncio.TimeoutError: + pass + finally: + self.session = None + + async def close(self) -> None: + """Close the WebSocket connection.""" + + if self.websocket: + websocket = self.websocket + self.websocket = None + try: + await websocket.close() + except Exception as exc: # pragma: no cover + log.warning("websocket.close() exception: " + repr(exc)) + + if self.session and not self._using_external_session: + await self._close_session() + + @property + def headers(self) -> Optional[LooseHeaders]: + """Get the response headers from the WebSocket connection. 
+ + Returns: + Dictionary of response headers + """ + if self._headers: + return self._headers + return {} + + @property + def response_headers(self) -> Dict[str, str]: + """Get the response headers from the WebSocket connection. + + Returns: + Dictionary of response headers + """ + if self._response_headers: + return dict(self._response_headers) + return {} diff --git a/gql/transport/common/adapters/connection.py b/gql/transport/common/adapters/connection.py new file mode 100644 index 00000000..ac178bc6 --- /dev/null +++ b/gql/transport/common/adapters/connection.py @@ -0,0 +1,68 @@ +import abc +from typing import Any, Dict, List, Optional + + +class AdapterConnection(abc.ABC): + """Abstract interface for subscription connections. + + This allows different WebSocket implementations to be used interchangeably. + """ + + url: str + connect_args: Dict[str, Any] + subprotocols: Optional[List[str]] + + def __init__(self, url: str, connect_args: Optional[Dict[str, Any]]): + """Initialize the connection adapter.""" + self.url: str = url + + if connect_args is None: + connect_args = {} + self.connect_args = connect_args + + self.subprotocols = None + + @abc.abstractmethod + async def connect(self) -> None: + """Connect to the server.""" + pass # pragma: no cover + + @abc.abstractmethod + async def send(self, message: str) -> None: + """Send message to the server. + + Args: + message: String message to send + + Raises: + TransportConnectionFailed: If connection closed + """ + pass # pragma: no cover + + @abc.abstractmethod + async def receive(self) -> str: + """Receive message from the server. 
+ + Returns: + String message received + + Raises: + TransportConnectionFailed: If connection closed + TransportProtocolError: If protocol error or binary data received + """ + pass # pragma: no cover + + @abc.abstractmethod + async def close(self) -> None: + """Close the connection.""" + pass # pragma: no cover + + @property + @abc.abstractmethod + def response_headers(self) -> Dict[str, str]: + """Get the response headers from the connection. + + Returns: + Dictionary of response headers + """ + pass # pragma: no cover diff --git a/gql/transport/common/adapters/websockets.py b/gql/transport/common/adapters/websockets.py new file mode 100644 index 00000000..6d248e71 --- /dev/null +++ b/gql/transport/common/adapters/websockets.py @@ -0,0 +1,156 @@ +import logging +from ssl import SSLContext +from typing import Any, Dict, Optional, Union + +import websockets +from websockets import ClientConnection +from websockets.datastructures import Headers, HeadersLike + +from ...exceptions import TransportConnectionFailed, TransportProtocolError +from .connection import AdapterConnection + +log = logging.getLogger("gql.transport.common.adapters.websockets") + + +class WebSocketsAdapter(AdapterConnection): + """AdapterConnection implementation using the websockets library.""" + + def __init__( + self, + url: str, + *, + headers: Optional[HeadersLike] = None, + ssl: Union[SSLContext, bool] = False, + connect_args: Optional[Dict[str, Any]] = None, + ) -> None: + """Initialize the transport with the given parameters. + + :param url: The GraphQL server URL. Example: 'wss://server.com:PORT/graphql'. + :param headers: Dict of HTTP Headers. + :param ssl: ssl_context of the connection. 
Use ssl=False to disable encryption + :param connect_args: Other parameters forwarded to + `websockets.connect `_ + """ + super().__init__( + url=url, + connect_args=connect_args, + ) + + self._headers: Optional[HeadersLike] = headers + self.ssl = ssl + + self.websocket: Optional[ClientConnection] = None + self._response_headers: Optional[Headers] = None + + async def connect(self) -> None: + """Connect to the WebSocket server.""" + + assert self.websocket is None + + ssl: Optional[Union[SSLContext, bool]] + if self.ssl: + ssl = self.ssl + else: + ssl = True if self.url.startswith("wss") else None + + # Set default arguments used in the websockets.connect call + connect_args: Dict[str, Any] = { + "ssl": ssl, + "additional_headers": self.headers, + } + + if self.subprotocols: + connect_args["subprotocols"] = self.subprotocols + + # Adding custom parameters passed from init + connect_args.update(self.connect_args) + + # Connection to the specified url + try: + self.websocket = await websockets.connect(self.url, **connect_args) + except Exception as e: + raise TransportConnectionFailed("Connect failed") from e + + assert self.websocket.response is not None + + self._response_headers = self.websocket.response.headers + + async def send(self, message: str) -> None: + """Send message to the WebSocket server. + + Args: + message: String message to send + + Raises: + TransportConnectionFailed: If connection closed + """ + if self.websocket is None: + raise TransportConnectionFailed("WebSocket connection is already closed") + + try: + await self.websocket.send(message) + except Exception as e: + raise TransportConnectionFailed( + f"Error trying to send data: {type(e).__name__}" + ) from e + + async def receive(self) -> str: + """Receive message from the WebSocket server. 
+ + Returns: + String message received + + Raises: + TransportConnectionFailed: If connection closed + TransportProtocolError: If protocol error or binary data received + """ + # It is possible that the websocket has been already closed in another task + if self.websocket is None: + raise TransportConnectionFailed("Connection is already closed") + + # Wait for the next websocket frame. Can raise ConnectionClosed + try: + data = await self.websocket.recv() + except Exception as e: + raise TransportConnectionFailed( + f"Error trying to receive data: {type(e).__name__}" + ) from e + + # websocket.recv() can return either str or bytes + # In our case, we should receive only str here + if not isinstance(data, str): + raise TransportProtocolError("Binary data received in the websocket") + + answer: str = data + + return answer + + async def close(self) -> None: + """Close the WebSocket connection.""" + if self.websocket: + websocket = self.websocket + self.websocket = None + await websocket.close() + + @property + def headers(self) -> Optional[HeadersLike]: + """Get the response headers from the WebSocket connection. + + Returns: + Dictionary of response headers + """ + if self._headers: + return self._headers + return {} + + @property + def response_headers(self) -> Dict[str, str]: + """Get the response headers from the WebSocket connection. + + Returns: + Dictionary of response headers + """ + if self._response_headers: + return dict(self._response_headers.raw_items()) + return {} diff --git a/gql/transport/common/aiohttp_closed_event.py b/gql/transport/common/aiohttp_closed_event.py new file mode 100644 index 00000000..412448f9 --- /dev/null +++ b/gql/transport/common/aiohttp_closed_event.py @@ -0,0 +1,59 @@ +import asyncio +import functools + +from aiohttp import ClientSession + + +def create_aiohttp_closed_event(session: ClientSession) -> asyncio.Event: + """Work around aiohttp issue that doesn't properly close transports on exit. 
+ + See https://github.com/aio-libs/aiohttp/issues/1925#issuecomment-639080209 + + Returns: + An event that will be set once all transports have been properly closed. + """ + + ssl_transports = 0 + all_is_lost = asyncio.Event() + + def connection_lost(exc, orig_lost): + nonlocal ssl_transports + + try: + orig_lost(exc) + finally: + ssl_transports -= 1 + if ssl_transports == 0: + all_is_lost.set() + + def eof_received(orig_eof_received): + try: # pragma: no cover + orig_eof_received() + except AttributeError: # pragma: no cover + # It may happen that eof_received() is called after + # _app_protocol and _transport are set to None. + pass + + assert session.connector is not None + + for conn in session.connector._conns.values(): + for handler, _ in conn: + proto = getattr(handler.transport, "_ssl_protocol", None) + if proto is None: + continue + + ssl_transports += 1 + orig_lost = proto.connection_lost + orig_eof_received = proto.eof_received + + proto.connection_lost = functools.partial( + connection_lost, orig_lost=orig_lost + ) + proto.eof_received = functools.partial( + eof_received, orig_eof_received=orig_eof_received + ) + + if ssl_transports == 0: + all_is_lost.set() + + return all_is_lost diff --git a/gql/transport/websockets_base.py b/gql/transport/common/base.py similarity index 67% rename from gql/transport/websockets_base.py rename to gql/transport/common/base.py index accca275..734c393b 100644 --- a/gql/transport/websockets_base.py +++ b/gql/transport/common/base.py @@ -3,132 +3,55 @@ import warnings from abc import abstractmethod from contextlib import suppress -from ssl import SSLContext -from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union, cast - -import websockets -from graphql import DocumentNode, ExecutionResult -from websockets.client import WebSocketClientProtocol -from websockets.datastructures import Headers, HeadersLike -from websockets.exceptions import ConnectionClosed -from websockets.typing import Data, Subprotocol - 
-from .async_transport import AsyncTransport -from .exceptions import ( +from typing import Any, AsyncGenerator, Dict, Optional, Tuple, Union + +from graphql import ExecutionResult + +from ...graphql_request import GraphQLRequest +from ..async_transport import AsyncTransport +from ..exceptions import ( TransportAlreadyConnected, TransportClosed, + TransportConnectionFailed, TransportProtocolError, TransportQueryError, TransportServerError, ) +from .adapters import AdapterConnection +from .listener_queue import ListenerQueue -log = logging.getLogger("gql.transport.websockets") - -ParsedAnswer = Tuple[str, Optional[ExecutionResult]] - - -class ListenerQueue: - """Special queue used for each query waiting for server answers - - If the server is stopped while the listener is still waiting, - Then we send an exception to the queue and this exception will be raised - to the consumer once all the previous messages have been consumed from the queue - """ - - def __init__(self, query_id: int, send_stop: bool) -> None: - self.query_id: int = query_id - self.send_stop: bool = send_stop - self._queue: asyncio.Queue = asyncio.Queue() - self._closed: bool = False - - async def get(self) -> ParsedAnswer: - - try: - item = self._queue.get_nowait() - except asyncio.QueueEmpty: - item = await self._queue.get() - - self._queue.task_done() - - # If we receive an exception when reading the queue, we raise it - if isinstance(item, Exception): - self._closed = True - raise item - - # Don't need to save new answers or - # send the stop message if we already received the complete message - answer_type, execution_result = item - if answer_type == "complete": - self.send_stop = False - self._closed = True +log = logging.getLogger("gql.transport.common.base") - return item - async def put(self, item: ParsedAnswer) -> None: - - if not self._closed: - await self._queue.put(item) - - async def set_exception(self, exception: Exception) -> None: - - # Put the exception in the queue - await 
self._queue.put(exception) - - # Don't need to send stop messages in case of error - self.send_stop = False - self._closed = True - - -class WebsocketsTransportBase(AsyncTransport): +class SubscriptionTransportBase(AsyncTransport): """abstract :ref:`Async Transport ` used to implement - different websockets protocols. - - This transport uses asyncio and the websockets library in order to send requests - on a websocket connection. + different subscription protocols (mainly websockets). """ def __init__( self, - url: str, - headers: Optional[HeadersLike] = None, - ssl: Union[SSLContext, bool] = False, - init_payload: Dict[str, Any] = {}, + *, + adapter: AdapterConnection, connect_timeout: Optional[Union[int, float]] = 10, close_timeout: Optional[Union[int, float]] = 10, - ack_timeout: Optional[Union[int, float]] = 10, keep_alive_timeout: Optional[Union[int, float]] = None, - connect_args: Dict[str, Any] = {}, ) -> None: """Initialize the transport with the given parameters. - :param url: The GraphQL server URL. Example: 'wss://server.com:PORT/graphql'. - :param headers: Dict of HTTP Headers. - :param ssl: ssl_context of the connection. Use ssl=False to disable encryption - :param init_payload: Dict of the payload sent in the connection_init message. + :param adapter: The connection dependency adapter :param connect_timeout: Timeout in seconds for the establishment - of the websocket connection. If None is provided this will wait forever. + of the connection. If None is provided this will wait forever. :param close_timeout: Timeout in seconds for the close. If None is provided this will wait forever. - :param ack_timeout: Timeout in seconds to wait for the connection_ack message - from the server. If None is provided this will wait forever. :param keep_alive_timeout: Optional Timeout in seconds to receive a sign of liveness from the server. 
- :param connect_args: Other parameters forwarded to websockets.connect """ - self.url: str = url - self.headers: Optional[HeadersLike] = headers - self.ssl: Union[SSLContext, bool] = ssl - self.init_payload: Dict[str, Any] = init_payload - self.connect_timeout: Optional[Union[int, float]] = connect_timeout self.close_timeout: Optional[Union[int, float]] = close_timeout - self.ack_timeout: Optional[Union[int, float]] = ack_timeout self.keep_alive_timeout: Optional[Union[int, float]] = keep_alive_timeout + self.adapter: AdapterConnection = adapter - self.connect_args = connect_args - - self.websocket: Optional[WebSocketClientProtocol] = None self.next_query_id: int = 1 self.listeners: Dict[int, ListenerQueue] = {} @@ -158,18 +81,14 @@ def __init__( self._next_keep_alive_message: asyncio.Event = asyncio.Event() self._next_keep_alive_message.set() - self.payloads: Dict[str, Any] = {} - """payloads is a dict which will contain the payloads received - for example with the graphql-ws protocol: 'ping', 'pong', 'connection_ack'""" - self._connecting: bool = False + self._connected: bool = False self.close_exception: Optional[Exception] = None - # The list of supported subprotocols should be defined in the subclass - self.supported_subprotocols: List[Subprotocol] = [] - - self.response_headers: Optional[Headers] = None + @property + def response_headers(self) -> Dict[str, str]: + return self.adapter.response_headers async def _initialize(self): """Hook to send the initialization messages after the connection @@ -177,76 +96,70 @@ async def _initialize(self): """ pass # pragma: no cover - async def _stop_listener(self, query_id: int): + async def _stop_listener(self, query_id: int) -> None: """Hook to stop to listen to a specific query. Will send a stop message in some subclasses. """ pass # pragma: no cover - async def _after_connect(self): + async def _after_connect(self) -> None: """Hook to add custom code for subclasses after the connection has been established. 
""" pass # pragma: no cover - async def _after_initialize(self): + async def _after_initialize(self) -> None: """Hook to add custom code for subclasses after the initialization has been done. """ pass # pragma: no cover - async def _close_hook(self): + async def _close_hook(self) -> None: """Hook to add custom code for subclasses for the connection close""" pass # pragma: no cover - async def _connection_terminate(self): + async def _connection_terminate(self) -> None: """Hook to add custom code for subclasses after the initialization has been done. """ pass # pragma: no cover async def _send(self, message: str) -> None: - """Send the provided message to the websocket connection and log the message""" + """Send the provided message to the adapter connection and log the message""" - if not self.websocket: - raise TransportClosed( - "Transport is not connected" - ) from self.close_exception + if not self._connected: + if isinstance(self.close_exception, TransportConnectionFailed): + raise self.close_exception + else: + raise TransportConnectionFailed() from self.close_exception try: - await self.websocket.send(message) - log.info(">>> %s", message) - except ConnectionClosed as e: + # Can raise TransportConnectionFailed + await self.adapter.send(message) + log.debug(">>> %s", message) + except TransportConnectionFailed as e: await self._fail(e, clean_close=False) raise e async def _receive(self) -> str: - """Wait the next message from the websocket connection and log the answer""" - - # It is possible that the websocket has been already closed in another task - if self.websocket is None: - raise TransportClosed("Transport is already closed") - - # Wait for the next websocket frame. 
Can raise ConnectionClosed - data: Data = await self.websocket.recv() + """Wait the next message from the connection and log the answer""" - # websocket.recv() can return either str or bytes - # In our case, we should receive only str here - if not isinstance(data, str): - raise TransportProtocolError("Binary data received in the websocket") + # It is possible that the connection has been already closed in another task + if not self._connected: + raise TransportConnectionFailed() from self.close_exception - answer: str = data + # Wait for the next frame. + # Can raise TransportConnectionFailed or TransportProtocolError + answer: str = await self.adapter.receive() - log.info("<<< %s", answer) + log.debug("<<< %s", answer) return answer @abstractmethod async def _send_query( self, - document: DocumentNode, - variable_values: Optional[Dict[str, Any]] = None, - operation_name: Optional[str] = None, + request: GraphQLRequest, ) -> int: raise NotImplementedError # pragma: no cover @@ -296,14 +209,12 @@ async def _receive_data_loop(self) -> None: try: while True: - # Wait the next answer from the websocket server + # Wait the next answer from the server try: answer = await self._receive() - except (ConnectionClosed, TransportProtocolError) as e: + except (TransportConnectionFailed, TransportProtocolError) as e: await self._fail(e, clean_close=False) break - except TransportClosed: - break # Parse the answer try: @@ -355,9 +266,8 @@ async def _handle_answer( async def subscribe( self, - document: DocumentNode, - variable_values: Optional[Dict[str, Any]] = None, - operation_name: Optional[str] = None, + request: GraphQLRequest, + *, send_stop: Optional[bool] = True, ) -> AsyncGenerator[ExecutionResult, None]: """Send a query and receive the results using a python async generator. 
@@ -369,7 +279,7 @@ async def subscribe( # Send the query and receive the id query_id: int = await self._send_query( - document, variable_values, operation_name + request, ) # Create a queue to receive the answers for this query_id @@ -384,7 +294,7 @@ async def subscribe( while True: # Wait for the answer from the queue of this query_id - # This can raise a TransportError or ConnectionClosed exception. + # This can raise TransportError or TransportConnectionFailed answer_type, execution_result = await listener.get() # If the received answer contains data, @@ -405,6 +315,7 @@ async def subscribe( if listener.send_stop: await self._stop_listener(query_id) listener.send_stop = False + raise e finally: log.debug(f"In subscribe finally for query_id {query_id}") @@ -412,11 +323,9 @@ async def subscribe( async def execute( self, - document: DocumentNode, - variable_values: Optional[Dict[str, Any]] = None, - operation_name: Optional[str] = None, + request: GraphQLRequest, ) -> ExecutionResult: - """Execute the provided document AST against the configured remote server + """Execute the provided request against the configured remote server using the current session. Send a query but close the async generator as soon as we have the first answer. 
@@ -426,13 +335,19 @@ async def execute( first_result = None generator = self.subscribe( - document, variable_values, operation_name, send_stop=False + request, + send_stop=False, ) async for result in generator: first_result = result break + # Apparently, on pypy the GeneratorExit exception is not raised after a break + # --> the clean_close has to time out + # We still need to manually close the async generator + await generator.aclose() + if first_result is None: raise TransportQueryError( "Query completed without any answer received from the server" @@ -447,52 +362,30 @@ async def connect(self) -> None: - send the init message - wait for the connection acknowledge from the server - create an asyncio task which will be used to receive - and parse the websocket answers + and parse the answers Should be cleaned with a call to the close coroutine """ log.debug("connect: starting") - if self.websocket is None and not self._connecting: + if not self._connected and not self._connecting: # Set connecting to True to avoid a race condition if user is trying # to connect twice using the same client at the same time self._connecting = True - # If the ssl parameter is not provided, - # generate the ssl value depending on the url - ssl: Optional[Union[SSLContext, bool]] - if self.ssl: - ssl = self.ssl - else: - ssl = True if self.url.startswith("wss") else None - - # Set default arguments used in the websockets.connect call - connect_args: Dict[str, Any] = { - "ssl": ssl, - "extra_headers": self.headers, - "subprotocols": self.supported_subprotocols, - } - - # Adding custom parameters passed from init - connect_args.update(self.connect_args) - - # Connection to the specified url # Generate a TimeoutError if taking more than connect_timeout seconds # Set the _connecting flag to False after in all cases try: - self.websocket = await asyncio.wait_for( - websockets.client.connect(self.url, **connect_args), + await asyncio.wait_for( + self.adapter.connect(), self.connect_timeout, 
) + self._connected = True finally: self._connecting = False - self.websocket = cast(WebSocketClientProtocol, self.websocket) - - self.response_headers = self.websocket.response_headers - # Run the after_connect hook of the subclass await self._after_connect() @@ -505,7 +398,7 @@ async def connect(self) -> None: # if no ACKs are received within the ack_timeout try: await self._initialize() - except ConnectionClosed as e: + except TransportConnectionFailed as e: raise e except ( TransportProtocolError, @@ -533,7 +426,7 @@ async def connect(self) -> None: log.debug("connect: done") - def _remove_listener(self, query_id) -> None: + def _remove_listener(self, query_id: int) -> None: """After exiting from a subscription, remove the listener and signal an event if this was the last listener for the client. """ @@ -555,7 +448,6 @@ async def _clean_close(self, e: Exception) -> None: # Send 'stop' message for all current queries for query_id, listener in self.listeners.items(): - if listener.send_stop: await self._stop_listener(query_id) listener.send_stop = False @@ -584,7 +476,11 @@ async def _close_coro(self, e: Exception, clean_close: bool = True) -> None: try: # We should always have an active websocket connection here - assert self.websocket is not None + assert self._connected + + # Saving exception to raise it later if trying to use the transport + # after it has already closed. + self.close_exception = e # Properly shut down liveness checker if enabled if self.check_keep_alive_task is not None: @@ -596,10 +492,6 @@ async def _close_coro(self, e: Exception, clean_close: bool = True) -> None: # Calling the subclass close hook await self._close_hook() - # Saving exception to raise it later if trying to use the transport - # after it has already closed. 
- self.close_exception = e - if clean_close: log.debug("_close_coro: starting clean_close") try: @@ -607,17 +499,20 @@ async def _close_coro(self, e: Exception, clean_close: bool = True) -> None: except Exception as exc: # pragma: no cover log.warning("Ignoring exception in _clean_close: " + repr(exc)) - log.debug("_close_coro: sending exception to listeners") + if log.isEnabledFor(logging.DEBUG): + log.debug( + f"_close_coro: sending exception to {len(self.listeners)} listeners" + ) # Send an exception to all remaining listeners for query_id, listener in self.listeners.items(): await listener.set_exception(e) - log.debug("_close_coro: close websocket connection") + log.debug("_close_coro: close connection") - await self.websocket.close() + await self.adapter.close() - log.debug("_close_coro: websocket connection closed") + log.debug("_close_coro: connection closed") except Exception as exc: # pragma: no cover log.warning("Exception catched in _close_coro: " + repr(exc)) @@ -626,7 +521,7 @@ async def _close_coro(self, e: Exception, clean_close: bool = True) -> None: log.debug("_close_coro: start cleanup") - self.websocket = None + self._connected = False self.close_task = None self.check_keep_alive_task = None self._wait_closed.set() @@ -634,16 +529,24 @@ async def _close_coro(self, e: Exception, clean_close: bool = True) -> None: log.debug("_close_coro: exiting") async def _fail(self, e: Exception, clean_close: bool = True) -> None: - log.debug("_fail: starting with exception: " + repr(e)) + if log.isEnabledFor(logging.DEBUG): + import inspect + + current_frame = inspect.currentframe() + assert current_frame is not None + caller_frame = current_frame.f_back + assert caller_frame is not None + caller_name = inspect.getframeinfo(caller_frame).function + log.debug(f"_fail from {caller_name}: " + repr(e)) if self.close_task is None: - if self.websocket is None: - log.debug("_fail started with self.websocket == None -> already closed") - else: + if self._connected: 
self.close_task = asyncio.shield( asyncio.ensure_future(self._close_coro(e, clean_close=clean_close)) ) + else: + log.debug("_fail started with self._connected:False -> already closed") else: log.debug( "close_task is not None in _fail. Previous exception is: " @@ -655,7 +558,7 @@ async def _fail(self, e: Exception, clean_close: bool = True) -> None: async def close(self) -> None: log.debug("close: starting") - await self._fail(TransportClosed("Websocket GraphQL transport closed by user")) + await self._fail(TransportClosed("Transport closed by user")) await self.wait_closed() log.debug("close: done") @@ -663,6 +566,17 @@ async def close(self) -> None: async def wait_closed(self) -> None: log.debug("wait_close: starting") - await self._wait_closed.wait() + try: + await asyncio.wait_for(self._wait_closed.wait(), self.close_timeout) + except asyncio.TimeoutError: + log.warning("Timer close_timeout fired in wait_closed") log.debug("wait_close: done") + + @property + def url(self) -> str: + return self.adapter.url + + @property + def connect_args(self) -> Dict[str, Any]: + return self.adapter.connect_args diff --git a/gql/transport/common/batch.py b/gql/transport/common/batch.py new file mode 100644 index 00000000..4feadee6 --- /dev/null +++ b/gql/transport/common/batch.py @@ -0,0 +1,76 @@ +from typing import ( + Any, + Dict, + List, +) + +from graphql import ExecutionResult + +from ...graphql_request import GraphQLRequest +from ..exceptions import ( + TransportProtocolError, +) + + +def _raise_protocol_error(result_text: str, reason: str) -> None: + raise TransportProtocolError( + f"Server did not return a valid GraphQL result: " f"{reason}: " f"{result_text}" + ) + + +def _validate_answer_is_a_list(results: Any) -> None: + if not isinstance(results, list): + _raise_protocol_error( + str(results), + "Answer is not a list", + ) + + +def _validate_data_and_errors_keys_in_answers(results: List[Dict[str, Any]]) -> None: + for result in results: + if "errors" not in result 
and "data" not in result: + _raise_protocol_error( + str(results), + 'No "data" or "errors" keys in answer', + ) + + +def _validate_every_answer_is_a_dict(results: List[Dict[str, Any]]) -> None: + for result in results: + if not isinstance(result, dict): + _raise_protocol_error(str(results), "Not every answer is dict") + + +def _validate_num_of_answers_same_as_requests( + reqs: List[GraphQLRequest], + results: List[Dict[str, Any]], +) -> None: + if len(reqs) != len(results): + _raise_protocol_error( + str(results), + ( + "Invalid number of answers: " + f"{len(results)} answers received for {len(reqs)} requests" + ), + ) + + +def _answer_to_execution_result(result: Dict[str, Any]) -> ExecutionResult: + return ExecutionResult( + errors=result.get("errors"), + data=result.get("data"), + extensions=result.get("extensions"), + ) + + +def get_batch_execution_result_list( + reqs: List[GraphQLRequest], + answers: List, +) -> List[ExecutionResult]: + + _validate_answer_is_a_list(answers) + _validate_num_of_answers_same_as_requests(reqs, answers) + _validate_every_answer_is_a_dict(answers) + _validate_data_and_errors_keys_in_answers(answers) + + return [_answer_to_execution_result(answer) for answer in answers] diff --git a/gql/transport/common/listener_queue.py b/gql/transport/common/listener_queue.py new file mode 100644 index 00000000..54aa650f --- /dev/null +++ b/gql/transport/common/listener_queue.py @@ -0,0 +1,58 @@ +import asyncio +from typing import Optional, Tuple + +from graphql import ExecutionResult + +ParsedAnswer = Tuple[str, Optional[ExecutionResult]] + + +class ListenerQueue: + """Special queue used for each query waiting for server answers + + If the server is stopped while the listener is still waiting, + Then we send an exception to the queue and this exception will be raised + to the consumer once all the previous messages have been consumed from the queue + """ + + def __init__(self, query_id: int, send_stop: bool) -> None: + self.query_id: int = 
query_id + self.send_stop: bool = send_stop + self._queue: asyncio.Queue = asyncio.Queue() + self._closed: bool = False + + async def get(self) -> ParsedAnswer: + + try: + item = self._queue.get_nowait() + except asyncio.QueueEmpty: + item = await self._queue.get() + + self._queue.task_done() + + # If we receive an exception when reading the queue, we raise it + if isinstance(item, Exception): + self._closed = True + raise item + + # Don't need to save new answers or + # send the stop message if we already received the complete message + answer_type, execution_result = item + if answer_type == "complete": + self.send_stop = False + self._closed = True + + return item + + async def put(self, item: ParsedAnswer) -> None: + + if not self._closed: + await self._queue.put(item) + + async def set_exception(self, exception: Exception) -> None: + + # Put the exception in the queue + await self._queue.put(exception) + + # Don't need to send stop messages in case of error + self.send_stop = False + self._closed = True diff --git a/gql/transport/exceptions.py b/gql/transport/exceptions.py index 7ec27a33..0049d5c2 100644 --- a/gql/transport/exceptions.py +++ b/gql/transport/exceptions.py @@ -61,6 +61,14 @@ class TransportClosed(TransportError): """ +class TransportConnectionFailed(TransportError): + """Transport connection failed. + + This exception is by the connection adapter code when a connection closed + or if an unexpected Exception was received when trying to send a request. + """ + + class TransportAlreadyConnected(TransportError): """Transport is already connected. 
diff --git a/gql/transport/file_upload.py b/gql/transport/file_upload.py new file mode 100644 index 00000000..8673ab60 --- /dev/null +++ b/gql/transport/file_upload.py @@ -0,0 +1,126 @@ +import io +import os +import warnings +from typing import Any, Dict, List, Optional, Tuple, Type + + +class FileVar: + def __init__( + self, + f: Any, # str | io.IOBase | aiohttp.StreamReader | AsyncGenerator + *, + filename: Optional[str] = None, + content_type: Optional[str] = None, + streaming: bool = False, + streaming_block_size: int = 64 * 1024, + ): + self.f = f + self.filename = filename + self.content_type = content_type + self.streaming = streaming + self.streaming_block_size = streaming_block_size + + self._file_opened: bool = False + + def open_file( + self, + transport_supports_streaming: bool = False, + ) -> None: + assert self._file_opened is False + + if self.streaming: + assert ( + transport_supports_streaming + ), "streaming not supported on this transport" + self._make_file_streamer() + else: + if isinstance(self.f, str): + if self.filename is None: + # By default we set the filename to the basename + # of the opened file + self.filename = os.path.basename(self.f) + self.f = open(self.f, "rb") + self._file_opened = True + + def close_file(self) -> None: + if self._file_opened: + assert isinstance(self.f, io.IOBase) + self.f.close() + self._file_opened = False + + def _make_file_streamer(self) -> None: + assert isinstance(self.f, str), "streaming option needs a filepath str" + + import aiofiles + + async def file_sender(file_name): + async with aiofiles.open(file_name, "rb") as f: + while chunk := await f.read(self.streaming_block_size): + yield chunk + + self.f = file_sender(self.f) + + +def open_files( + filevars: List[FileVar], + transport_supports_streaming: bool = False, +) -> None: + + for filevar in filevars: + filevar.open_file(transport_supports_streaming=transport_supports_streaming) + + +def close_files(filevars: List[FileVar]) -> None: + for filevar in 
filevars: + filevar.close_file() + + +FILE_UPLOAD_DOCS = "https://gql.readthedocs.io/en/latest/usage/file_upload.html" + + +def extract_files( + variables: Dict, file_classes: Tuple[Type[Any], ...] +) -> Tuple[Dict, Dict[str, FileVar]]: + files: Dict[str, FileVar] = {} + + def recurse_extract(path, obj): + """ + recursively traverse obj, doing a deepcopy, but + replacing any file-like objects with nulls and + shunting the originals off to the side. + """ + nonlocal files + if isinstance(obj, list): + nulled_list = [] + for key, value in enumerate(obj): + value = recurse_extract(f"{path}.{key}", value) + nulled_list.append(value) + return nulled_list + elif isinstance(obj, dict): + nulled_dict = {} + for key, value in obj.items(): + value = recurse_extract(f"{path}.{key}", value) + nulled_dict[key] = value + return nulled_dict + elif isinstance(obj, file_classes): + # extract obj from its parent and put it into files instead. + warnings.warn( + "Not using FileVar for file upload is deprecated. " + f"See {FILE_UPLOAD_DOCS} for details.", + DeprecationWarning, + ) + name = getattr(obj, "name", None) + content_type = getattr(obj, "content_type", None) + files[path] = FileVar(obj, filename=name, content_type=content_type) + return None + elif isinstance(obj, FileVar): + # extract obj from its parent and put it into files instead. + files[path] = obj + return None + else: + # base case: pass through unchanged + return obj + + nulled_variables = recurse_extract("variables", variables) + + return nulled_variables, files diff --git a/gql/transport/httpx.py b/gql/transport/httpx.py index 811601b8..0a338639 100644 --- a/gql/transport/httpx.py +++ b/gql/transport/httpx.py @@ -7,24 +7,27 @@ Callable, Dict, List, + NoReturn, Optional, Tuple, Type, Union, - cast, ) import httpx -from graphql import DocumentNode, ExecutionResult, print_ast +from graphql import ExecutionResult -from ..utils import extract_files +from ..graphql_request import GraphQLRequest from . 
import AsyncTransport, Transport +from .common.batch import get_batch_execution_result_list from .exceptions import ( TransportAlreadyConnected, TransportClosed, + TransportConnectionFailed, TransportProtocolError, TransportServerError, ) +from .file_upload import close_files, extract_files, open_files log = logging.getLogger(__name__) @@ -39,7 +42,7 @@ def __init__( url: Union[str, httpx.URL], json_serialize: Callable = json.dumps, json_deserialize: Callable = json.loads, - **kwargs, + **kwargs: Any, ): """Initialize the transport with the given httpx parameters. @@ -57,30 +60,23 @@ def __init__( def _prepare_request( self, - document: DocumentNode, - variable_values: Optional[Dict[str, Any]] = None, - operation_name: Optional[str] = None, + request: Union[GraphQLRequest, List[GraphQLRequest]], + *, extra_args: Optional[Dict[str, Any]] = None, upload_files: bool = False, ) -> Dict[str, Any]: - query_str = print_ast(document) - payload: Dict[str, Any] = { - "query": query_str, - } - - if operation_name: - payload["operationName"] = operation_name + payload: Dict | List + if isinstance(request, GraphQLRequest): + payload = request.payload + else: + payload = [req.payload for req in request] if upload_files: - # If the upload_files flag is set, then we need variable_values - assert variable_values is not None - - post_args = self._prepare_file_uploads(variable_values, payload) + assert isinstance(payload, Dict) + assert isinstance(request, GraphQLRequest) + post_args = self._prepare_file_uploads(request, payload) else: - if variable_values: - payload["variables"] = variable_values - post_args = {"json": payload} # Log the payload @@ -93,7 +89,17 @@ def _prepare_request( return post_args - def _prepare_file_uploads(self, variable_values, payload) -> Dict[str, Any]: + def _prepare_file_uploads( + self, + request: GraphQLRequest, + payload: Dict[str, Any], + ) -> Dict[str, Any]: + + variable_values = request.variable_values + + # If the upload_files flag is set, then we 
need variable_values + assert variable_values is not None + # If we upload files, we will extract the files present in the # variable_values dict and replace them by null values nulled_variable_values, files = extract_files( @@ -101,6 +107,10 @@ def _prepare_file_uploads(self, variable_values, payload) -> Dict[str, Any]: file_classes=self.file_classes, ) + # Opening the files using the FileVar parameters + open_files(list(files.values())) + self.files = files + # Save the nulled variable values in the payload payload["variables"] = nulled_variable_values @@ -109,7 +119,7 @@ def _prepare_file_uploads(self, variable_values, payload) -> Dict[str, Any]: file_map: Dict[str, List[str]] = {} file_streams: Dict[str, Tuple[str, ...]] = {} - for i, (path, f) in enumerate(files.items()): + for i, (path, file_var) in enumerate(files.items()): key = str(i) # Generate the file map @@ -118,16 +128,12 @@ def _prepare_file_uploads(self, variable_values, payload) -> Dict[str, Any]: # Will generate something like {"0": ["variables.file"]} file_map[key] = [path] - # Generate the file streams - # Will generate something like - # {"0": ("variables.file", <_io.BufferedReader ...>)} - name = cast(str, getattr(f, "name", key)) - content_type = getattr(f, "content_type", None) + name = key if file_var.filename is None else file_var.filename - if content_type is None: - file_streams[key] = (name, f) + if file_var.content_type is None: + file_streams[key] = (name, file_var.f) else: - file_streams[key] = (name, f, content_type) + file_streams[key] = (name, file_var.f, file_var.content_type) # Add the payload to the operations field operations_str = self.json_serialize(payload) @@ -141,8 +147,9 @@ def _prepare_file_uploads(self, variable_values, payload) -> Dict[str, Any]: return {"data": data, "files": file_streams} - def _prepare_result(self, response: httpx.Response) -> ExecutionResult: - # Save latest response headers in transport + def _get_json_result(self, response: httpx.Response) -> 
Any: + + # Saving latest response headers in the transport self.response_headers = response.headers if log.isEnabledFor(logging.DEBUG): @@ -150,10 +157,15 @@ def _prepare_result(self, response: httpx.Response) -> ExecutionResult: try: result: Dict[str, Any] = self.json_deserialize(response.content) - except Exception: self._raise_response_error(response, "Not a JSON answer") + return result + + def _prepare_result(self, response: httpx.Response) -> ExecutionResult: + + result = self._get_json_result(response) + if "errors" not in result and "data" not in result: self._raise_response_error(response, 'No "data" or "errors" keys in answer') @@ -163,16 +175,41 @@ def _prepare_result(self, response: httpx.Response) -> ExecutionResult: extensions=result.get("extensions"), ) - def _raise_response_error(self, response: httpx.Response, reason: str): - # We raise a TransportServerError if the status code is 400 or higher - # We raise a TransportProtocolError in the other cases + def _prepare_batch_result( + self, + reqs: List[GraphQLRequest], + response: httpx.Response, + ) -> List[ExecutionResult]: + + answers = self._get_json_result(response) try: - # Raise a HTTPError if response status is 400 or higher + return get_batch_execution_result_list(reqs, answers) + except TransportProtocolError: + # Raise a TransportServerError if status > 400 + self._raise_transport_server_error_if_status_more_than_400(response) + # In other cases, raise a TransportProtocolError + raise + + @staticmethod + def _raise_transport_server_error_if_status_more_than_400( + response: httpx.Response, + ) -> None: + # If the status is >400, + # then we need to raise a TransportServerError + try: + # Raise a HTTPStatusError if response status is 400 or higher response.raise_for_status() except httpx.HTTPStatusError as e: raise TransportServerError(str(e), e.response.status_code) from e + @classmethod + def _raise_response_error(cls, response: httpx.Response, reason: str) -> NoReturn: + # We raise a 
TransportServerError if the status code is 400 or higher + # We raise a TransportProtocolError in the other cases + + cls._raise_transport_server_error_if_status_more_than_400(response) + raise TransportProtocolError( f"Server did not return a GraphQL result: " f"{reason}: " f"{response.text}" ) @@ -195,23 +232,20 @@ def connect(self): self.client = httpx.Client(**self.kwargs) - def execute( # type: ignore + def execute( self, - document: DocumentNode, - variable_values: Optional[Dict[str, Any]] = None, - operation_name: Optional[str] = None, + request: GraphQLRequest, + *, extra_args: Optional[Dict[str, Any]] = None, upload_files: bool = False, ) -> ExecutionResult: """Execute GraphQL query. - Execute the provided document AST against the configured remote server. This + Execute the provided request against the configured remote server. This uses the httpx library to perform a HTTP POST request to the remote server. - :param document: GraphQL query as AST Node object. - :param variable_values: Dictionary of input parameters (Default: None). - :param operation_name: Name of the operation that shall be executed. - Only required in multi-operation documents (Default: None). + :param request: GraphQL request as a + :class:`GraphQLRequest ` object. :param extra_args: additional arguments to send to the httpx post method :param upload_files: Set to True if you want to put files in the variable values :return: The result of execution. 
@@ -222,17 +256,54 @@ def execute( # type: ignore raise TransportClosed("Transport is not connected") post_args = self._prepare_request( - document, - variable_values, - operation_name, - extra_args, - upload_files, + request, + extra_args=extra_args, + upload_files=upload_files, ) - response = self.client.post(self.url, **post_args) + try: + response = self.client.post(self.url, **post_args) + except Exception as e: + raise TransportConnectionFailed(str(e)) from e + finally: + if upload_files: + close_files(list(self.files.values())) return self._prepare_result(response) + def execute_batch( + self, + reqs: List[GraphQLRequest], + extra_args: Optional[Dict[str, Any]] = None, + ) -> List[ExecutionResult]: + """Execute multiple GraphQL requests in a batch. + + Don't call this coroutine directly on the transport, instead use + :code:`execute_batch` on a client or a session. + + :param reqs: GraphQL requests as a list of GraphQLRequest objects. + :param extra_args: additional arguments to send to the httpx post method + :return: A list of results of execution. + For every result `data` is the result of executing the query, + `errors` is null if no errors occurred, and is a non-empty array + if an error occurred. 
+ """ + + if not self.client: + raise TransportClosed("Transport is not connected") + + post_args = self._prepare_request( + reqs, + extra_args=extra_args, + ) + + try: + response = self.client.post(self.url, **post_args) + except Exception as e: + raise TransportConnectionFailed(str(e)) from e + + return self._prepare_batch_result(reqs, response) + def close(self): """Closing the transport by closing the inner session""" if self.client: @@ -259,22 +330,19 @@ async def connect(self): async def execute( self, - document: DocumentNode, - variable_values: Optional[Dict[str, Any]] = None, - operation_name: Optional[str] = None, + request: GraphQLRequest, + *, extra_args: Optional[Dict[str, Any]] = None, upload_files: bool = False, ) -> ExecutionResult: """Execute GraphQL query. - Execute the provided document AST against the configured remote server. This + Execute the provided request against the configured remote server. This uses the httpx library to perform a HTTP POST request asynchronously to the remote server. - :param document: GraphQL query as AST Node object. - :param variable_values: Dictionary of input parameters (Default: None). - :param operation_name: Name of the operation that shall be executed. - Only required in multi-operation documents (Default: None). + :param request: GraphQL request as a + :class:`GraphQLRequest ` object. :param extra_args: additional arguments to send to the httpx post method :param upload_files: Set to True if you want to put files in the variable values :return: The result of execution. 
@@ -285,31 +353,66 @@ async def execute( raise TransportClosed("Transport is not connected") post_args = self._prepare_request( - document, - variable_values, - operation_name, - extra_args, - upload_files, + request, + extra_args=extra_args, + upload_files=upload_files, ) - response = await self.client.post(self.url, **post_args) + try: + response = await self.client.post(self.url, **post_args) + except Exception as e: + raise TransportConnectionFailed(str(e)) from e + finally: + if upload_files: + close_files(list(self.files.values())) return self._prepare_result(response) - async def close(self): - """Closing the transport by closing the inner session""" - if self.client: - await self.client.aclose() - self.client = None + async def execute_batch( + self, + reqs: List[GraphQLRequest], + extra_args: Optional[Dict[str, Any]] = None, + ) -> List[ExecutionResult]: + """Execute multiple GraphQL requests in a batch. + + Don't call this coroutine directly on the transport, instead use + :code:`execute_batch` on a client or a session. + + :param reqs: GraphQL requests as a list of GraphQLRequest objects. + :param extra_args: additional arguments to send to the httpx post method + :return: A list of results of execution. + For every result `data` is the result of executing the query, + `errors` is null if no errors occurred, and is a non-empty array + if an error occurred. + """ + + if not self.client: + raise TransportClosed("Transport is not connected") + + post_args = self._prepare_request( + reqs, + extra_args=extra_args, + ) + + try: + response = await self.client.post(self.url, **post_args) + except Exception as e: + raise TransportConnectionFailed(str(e)) from e + + return self._prepare_batch_result(reqs, response) def subscribe( self, - document: DocumentNode, - variable_values: Optional[Dict[str, Any]] = None, - operation_name: Optional[str] = None, + request: GraphQLRequest, ) -> AsyncGenerator[ExecutionResult, None]: """Subscribe is not supported on HTTP. 
:meta private: """ raise NotImplementedError("The HTTP transport does not support subscriptions") + + async def close(self): + """Closing the transport by closing the inner session""" + if self.client: + await self.client.aclose() + self.client = None diff --git a/gql/transport/local_schema.py b/gql/transport/local_schema.py index 04ed4ff1..f87854e2 100644 --- a/gql/transport/local_schema.py +++ b/gql/transport/local_schema.py @@ -1,11 +1,13 @@ import asyncio from inspect import isawaitable -from typing import AsyncGenerator, Awaitable, cast +from typing import Any, AsyncGenerator, Awaitable, cast -from graphql import DocumentNode, ExecutionResult, GraphQLSchema, execute, subscribe +from graphql import ExecutionResult, GraphQLSchema, execute, subscribe from gql.transport import AsyncTransport +from ..graphql_request import GraphQLRequest + class LocalSchemaTransport(AsyncTransport): """A transport for executing GraphQL queries against a local schema.""" @@ -30,13 +32,24 @@ async def close(self): async def execute( self, - document: DocumentNode, - *args, - **kwargs, + request: GraphQLRequest, + *args: Any, + **kwargs: Any, ) -> ExecutionResult: - """Execute the provided document AST for on a local GraphQL Schema.""" - - result_or_awaitable = execute(self.schema, document, *args, **kwargs) + """Execute the provided request for on a local GraphQL Schema.""" + + inner_kwargs = { + "variable_values": request.variable_values, + "operation_name": request.operation_name, + **kwargs, + } + + result_or_awaitable = execute( + self.schema, + request.document, + *args, + **inner_kwargs, + ) execution_result: ExecutionResult @@ -57,17 +70,28 @@ async def _await_if_necessary(obj): async def subscribe( self, - document: DocumentNode, - *args, - **kwargs, + request: GraphQLRequest, + *args: Any, + **kwargs: Any, ) -> AsyncGenerator[ExecutionResult, None]: """Send a subscription and receive the results using an async generator The results are sent as an ExecutionResult object """ + 
inner_kwargs = { + "variable_values": request.variable_values, + "operation_name": request.operation_name, + **kwargs, + } + subscribe_result = await self._await_if_necessary( - subscribe(self.schema, document, *args, **kwargs) + subscribe( + self.schema, + request.document, + *args, + **inner_kwargs, + ) ) if isinstance(subscribe_result, ExecutionResult): diff --git a/gql/transport/phoenix_channel_websockets.py b/gql/transport/phoenix_channel_websockets.py index 08cde8cc..8e7455e2 100644 --- a/gql/transport/phoenix_channel_websockets.py +++ b/gql/transport/phoenix_channel_websockets.py @@ -1,17 +1,19 @@ import asyncio import json import logging -from typing import Any, Dict, Optional, Tuple +from typing import Any, Dict, Optional, Tuple, Union -from graphql import DocumentNode, ExecutionResult, print_ast -from websockets.exceptions import ConnectionClosed +from graphql import ExecutionResult, print_ast +from ..graphql_request import GraphQLRequest +from .common.adapters.websockets import WebSocketsAdapter +from .common.base import SubscriptionTransportBase from .exceptions import ( + TransportConnectionFailed, TransportProtocolError, TransportQueryError, TransportServerError, ) -from .websockets_base import WebsocketsTransportBase log = logging.getLogger(__name__) @@ -24,7 +26,7 @@ def __init__(self, query_id: int) -> None: self.unsubscribe_id: Optional[int] = None -class PhoenixChannelWebsocketsTransport(WebsocketsTransportBase): +class PhoenixChannelWebsocketsTransport(SubscriptionTransportBase): """The PhoenixChannelWebsocketsTransport is an async transport which allows you to execute queries and subscriptions against an `Absinthe`_ backend using the `Phoenix`_ framework `channels`_. 
@@ -36,23 +38,48 @@ class PhoenixChannelWebsocketsTransport(WebsocketsTransportBase): def __init__( self, + url: str, + *, channel_name: str = "__absinthe__:control", heartbeat_interval: float = 30, - *args, - **kwargs, + ack_timeout: Optional[Union[int, float]] = 10, + **kwargs: Any, ) -> None: """Initialize the transport with the given parameters. + :param url: The server URL.'. :param channel_name: Channel on the server this transport will join. The default for Absinthe servers is "__absinthe__:control" :param heartbeat_interval: Interval in second between each heartbeat messages sent by the client + :param ack_timeout: Timeout in seconds to wait for the reply message + from the server. """ self.channel_name: str = channel_name self.heartbeat_interval: float = heartbeat_interval self.heartbeat_task: Optional[asyncio.Future] = None self.subscriptions: Dict[str, Subscription] = {} - super().__init__(*args, **kwargs) + self.ack_timeout: Optional[Union[int, float]] = ack_timeout + + # Instanciate a WebSocketAdapter to indicate the use + # of the websockets dependency for this transport + ws_adapter_args = {} + for ws_arg in ["headers", "ssl", "connect_args"]: + try: + ws_adapter_args[ws_arg] = kwargs.pop(ws_arg) + except KeyError: + pass + + self.adapter: WebSocketsAdapter = WebSocketsAdapter( + url=url, + **ws_adapter_args, + ) + + # Initialize the generic SubscriptionTransportBase parent class + super().__init__( + adapter=self.adapter, + **kwargs, + ) async def _initialize(self) -> None: """Join the specified channel and wait for the connection ACK. 
@@ -101,7 +128,7 @@ async def heartbeat_coro(): } ) ) - except ConnectionClosed: # pragma: no cover + except TransportConnectionFailed: # pragma: no cover return self.heartbeat_task = asyncio.ensure_future(heartbeat_coro()) @@ -156,9 +183,7 @@ async def _connection_terminate(self): async def _send_query( self, - document: DocumentNode, - variable_values: Optional[Dict[str, Any]] = None, - operation_name: Optional[str] = None, + request: GraphQLRequest, ) -> int: """Send a query to the provided websocket connection. @@ -175,8 +200,8 @@ async def _send_query( "topic": self.channel_name, "event": "doc", "payload": { - "query": print_ast(document), - "variables": variable_values or {}, + "query": print_ast(request.document), + "variables": request.variable_values or {}, }, "ref": query_id, } @@ -218,7 +243,7 @@ def _required_value(d: Any, key: str, label: str) -> Any: return value def _required_subscription_id( - d: Any, label: str, must_exist: bool = False, must_not_exist=False + d: Any, label: str, must_exist: bool = False, must_not_exist: bool = False ) -> str: subscription_id = str(_required_value(d, "subscriptionId", label)) if must_exist and (subscription_id not in self.subscriptions): @@ -370,7 +395,7 @@ async def _handle_answer( execution_result: Optional[ExecutionResult], ) -> None: if answer_type == "close": - await self.close() + pass else: await super()._handle_answer(answer_type, answer_id, execution_result) diff --git a/gql/transport/requests.py b/gql/transport/requests.py index bd370908..a29f7f0f 100644 --- a/gql/transport/requests.py +++ b/gql/transport/requests.py @@ -1,25 +1,39 @@ import io import json import logging -from typing import Any, Callable, Collection, Dict, List, Optional, Tuple, Type, Union +from typing import ( + Any, + Callable, + Collection, + Dict, + List, + NoReturn, + Optional, + Tuple, + Type, + Union, +) import requests -from graphql import DocumentNode, ExecutionResult, print_ast +from graphql import ExecutionResult from 
requests.adapters import HTTPAdapter, Retry from requests.auth import AuthBase from requests.cookies import RequestsCookieJar +from requests.structures import CaseInsensitiveDict from requests_toolbelt.multipart.encoder import MultipartEncoder from gql.transport import Transport from ..graphql_request import GraphQLRequest -from ..utils import extract_files +from .common.batch import get_batch_execution_result_list from .exceptions import ( TransportAlreadyConnected, TransportClosed, + TransportConnectionFailed, TransportProtocolError, TransportServerError, ) +from .file_upload import FileVar, close_files, extract_files, open_files log = logging.getLogger(__name__) @@ -100,9 +114,9 @@ def __init__( self.json_deserialize: Callable = json_deserialize self.kwargs = kwargs - self.session = None + self.session: Optional[requests.Session] = None - self.response_headers = None + self.response_headers: Optional[CaseInsensitiveDict[str]] = None def connect(self): if self.session is None: @@ -124,42 +138,22 @@ def connect(self): else: raise TransportAlreadyConnected("Transport is already connected") - def execute( # type: ignore + def _prepare_request( self, - document: DocumentNode, - variable_values: Optional[Dict[str, Any]] = None, - operation_name: Optional[str] = None, + request: Union[GraphQLRequest, List[GraphQLRequest]], + *, timeout: Optional[int] = None, extra_args: Optional[Dict[str, Any]] = None, upload_files: bool = False, - ) -> ExecutionResult: - """Execute GraphQL query. - - Execute the provided document AST against the configured remote server. This - uses the requests library to perform a HTTP POST request to the remote server. - - :param document: GraphQL query as AST Node object. - :param variable_values: Dictionary of input parameters (Default: None). - :param operation_name: Name of the operation that shall be executed. - Only required in multi-operation documents (Default: None). 
- :param timeout: Specifies a default timeout for requests (Default: None). - :param extra_args: additional arguments to send to the requests post method - :param upload_files: Set to True if you want to put files in the variable values - :return: The result of execution. - `data` is the result of executing the query, `errors` is null - if no errors occurred, and is a non-empty array if an error occurred. - """ - - if not self.session: - raise TransportClosed("Transport is not connected") - - query_str = print_ast(document) - payload: Dict[str, Any] = {"query": query_str} + ) -> Dict[str, Any]: - if operation_name: - payload["operationName"] = operation_name + payload: Dict | List + if isinstance(request, GraphQLRequest): + payload = request.payload + else: + payload = [req.payload for req in request] - post_args = { + post_args: Dict[str, Any] = { "headers": self.headers, "auth": self.auth, "cookies": self.cookies, @@ -168,124 +162,168 @@ def execute( # type: ignore } if upload_files: - # If the upload_files flag is set, then we need variable_values - assert variable_values is not None - - # If we upload files, we will extract the files present in the - # variable_values dict and replace them by null values - nulled_variable_values, files = extract_files( - variables=variable_values, - file_classes=self.file_classes, + assert isinstance(payload, Dict) + assert isinstance(request, GraphQLRequest) + post_args = self._prepare_file_uploads( + request=request, + payload=payload, + post_args=post_args, ) - # Save the nulled variable values in the payload - payload["variables"] = nulled_variable_values + else: + data_key = "json" if self.use_json else "data" + post_args[data_key] = payload - # Add the payload to the operations field - operations_str = self.json_serialize(payload) - log.debug("operations %s", operations_str) + # Log the payload + if log.isEnabledFor(logging.DEBUG): + log.debug(">>> %s", self.json_serialize(payload)) + + # Pass kwargs to requests post 
method + post_args.update(self.kwargs) + + # Pass post_args to requests post method + if extra_args: + post_args.update(extra_args) + + return post_args + + def _prepare_file_uploads( + self, + request: GraphQLRequest, + *, + payload: Dict[str, Any], + post_args: Dict[str, Any], + ) -> Dict[str, Any]: + # If the upload_files flag is set, then we need variable_values + assert request.variable_values is not None + + # If we upload files, we will extract the files present in the + # variable_values dict and replace them by null values + nulled_variable_values, files = extract_files( + variables=request.variable_values, + file_classes=self.file_classes, + ) - # Generate the file map - # path is nested in a list because the spec allows multiple pointers - # to the same file. But we don't support that. - # Will generate something like {"0": ["variables.file"]} - file_map = {str(i): [path] for i, path in enumerate(files)} + # Opening the files using the FileVar parameters + open_files(list(files.values())) + self.files = files - # Enumerate the file streams - # Will generate something like {'0': <_io.BufferedReader ...>} - file_streams = {str(i): files[path] for i, path in enumerate(files)} + # Save the nulled variable values in the payload + payload["variables"] = nulled_variable_values - # Add the file map field - file_map_str = self.json_serialize(file_map) - log.debug("file_map %s", file_map_str) + # Add the payload to the operations field + operations_str = self.json_serialize(payload) + log.debug("operations %s", operations_str) - fields = {"operations": operations_str, "map": file_map_str} + # Generate the file map + # path is nested in a list because the spec allows multiple pointers + # to the same file. But we don't support that. 
+ # Will generate something like {"0": ["variables.file"]} + file_map = {str(i): [path] for i, path in enumerate(files)} - # Add the extracted files as remaining fields - for k, f in file_streams.items(): - name = getattr(f, "name", k) - content_type = getattr(f, "content_type", None) + # Enumerate the file streams + # Will generate something like {'0': FileVar object} + file_vars = {str(i): files[path] for i, path in enumerate(files)} - if content_type is None: - fields[k] = (name, f) - else: - fields[k] = (name, f, content_type) + # Add the file map field + file_map_str = self.json_serialize(file_map) + log.debug("file_map %s", file_map_str) - # Prepare requests http to send multipart-encoded data - data = MultipartEncoder(fields=fields) + fields = {"operations": operations_str, "map": file_map_str} - post_args["data"] = data + # Add the extracted files as remaining fields + for k, file_var in file_vars.items(): + assert isinstance(file_var, FileVar) + name = k if file_var.filename is None else file_var.filename - if post_args["headers"] is None: - post_args["headers"] = {} + if file_var.content_type is None: + fields[k] = (name, file_var.f) else: - post_args["headers"] = {**post_args["headers"]} + fields[k] = (name, file_var.f, file_var.content_type) + + # Prepare requests http to send multipart-encoded data + data = MultipartEncoder(fields=fields) - post_args["headers"]["Content-Type"] = data.content_type + post_args["data"] = data + if post_args["headers"] is None: + post_args["headers"] = {} else: - if variable_values: - payload["variables"] = variable_values + post_args["headers"] = dict(post_args["headers"]) - data_key = "json" if self.use_json else "data" - post_args[data_key] = payload + post_args["headers"]["Content-Type"] = data.content_type - # Log the payload - if log.isEnabledFor(logging.INFO): - log.info(">>> %s", self.json_serialize(payload)) + return post_args - # Pass kwargs to requests post method - post_args.update(self.kwargs) + def execute( + 
self, + request: GraphQLRequest, + timeout: Optional[int] = None, + extra_args: Optional[Dict[str, Any]] = None, + upload_files: bool = False, + ) -> ExecutionResult: + """Execute GraphQL query. - # Pass post_args to requests post method - if extra_args: - post_args.update(extra_args) + Execute the provided request against the configured remote server. This + uses the requests library to perform a HTTP POST request to the remote server. - # Using the created session to perform requests - response = self.session.request( - self.method, self.url, **post_args # type: ignore - ) - self.response_headers = response.headers + :param request: GraphQL request as a + :class:`GraphQLRequest ` object. + :param timeout: Specifies a default timeout for requests (Default: None). + :param extra_args: additional arguments to send to the requests post method + :param upload_files: Set to True if you want to put files in the variable values + :return: The result of execution. + `data` is the result of executing the query, `errors` is null + if no errors occurred, and is a non-empty array if an error occurred. 
+ """ - def raise_response_error(resp: requests.Response, reason: str): - # We raise a TransportServerError if the status code is 400 or higher - # We raise a TransportProtocolError in the other cases - - try: - # Raise a HTTPError if response status is 400 or higher - resp.raise_for_status() - except requests.HTTPError as e: - raise TransportServerError(str(e), e.response.status_code) from e - - result_text = resp.text - raise TransportProtocolError( - f"Server did not return a GraphQL result: " - f"{reason}: " - f"{result_text}" - ) + if not self.session: + raise TransportClosed("Transport is not connected") - try: - if self.json_deserialize == json.loads: - result = response.json() - else: - result = self.json_deserialize(response.text) + post_args = self._prepare_request( + request, + timeout=timeout, + extra_args=extra_args, + upload_files=upload_files, + ) - if log.isEnabledFor(logging.INFO): - log.info("<<< %s", response.text) + # Using the created session to perform requests + try: + response = self.session.request(self.method, self.url, **post_args) + except Exception as e: + raise TransportConnectionFailed(str(e)) from e + finally: + if upload_files: + close_files(list(self.files.values())) + + return self._prepare_result(response) + + @staticmethod + def _raise_transport_server_error_if_status_more_than_400( + response: requests.Response, + ) -> None: + # If the status is >400, + # then we need to raise a TransportServerError + try: + # Raise a HTTPError if response status is 400 or higher + response.raise_for_status() + except requests.HTTPError as e: + status_code = e.response.status_code if e.response is not None else None + raise TransportServerError(str(e), status_code) from e - except Exception: - raise_response_error(response, "Not a JSON answer") + @classmethod + def _raise_response_error(cls, resp: requests.Response, reason: str) -> NoReturn: + # We raise a TransportServerError if the status code is 400 or higher + # We raise a 
TransportProtocolError in the other cases - if "errors" not in result and "data" not in result: - raise_response_error(response, 'No "data" or "errors" keys in answer') + cls._raise_transport_server_error_if_status_more_than_400(resp) - return ExecutionResult( - errors=result.get("errors"), - data=result.get("data"), - extensions=result.get("extensions"), + result_text = resp.text + raise TransportProtocolError( + f"Server did not return a GraphQL result: " f"{reason}: " f"{result_text}" ) - def execute_batch( # type: ignore + def execute_batch( self, reqs: List[GraphQLRequest], timeout: Optional[int] = None, @@ -308,129 +346,67 @@ def execute_batch( # type: ignore if not self.session: raise TransportClosed("Transport is not connected") - # Using the created session to perform requests - response = self.session.request( - self.method, - self.url, - **self._build_batch_post_args(reqs, timeout, extra_args), + post_args = self._prepare_request( + reqs, + timeout=timeout, + extra_args=extra_args, ) - self.response_headers = response.headers - - answers = self._extract_response(response) - - self._validate_answer_is_a_list(answers) - self._validate_num_of_answers_same_as_requests(reqs, answers) - self._validate_every_answer_is_a_dict(answers) - self._validate_data_and_errors_keys_in_answers(answers) - - return [self._answer_to_execution_result(answer) for answer in answers] - def _answer_to_execution_result(self, result: Dict[str, Any]) -> ExecutionResult: - return ExecutionResult( - errors=result.get("errors"), - data=result.get("data"), - extensions=result.get("extensions"), - ) - - def _validate_answer_is_a_list(self, results: Any) -> None: - if not isinstance(results, list): - self._raise_invalid_result( - str(results), - "Answer is not a list", + try: + response = self.session.request( + self.method, + self.url, + **post_args, ) + except Exception as e: + raise TransportConnectionFailed(str(e)) from e - def _validate_data_and_errors_keys_in_answers( - self, 
results: List[Dict[str, Any]] - ) -> None: - for result in results: - if "errors" not in result and "data" not in result: - self._raise_invalid_result( - str(results), - 'No "data" or "errors" keys in answer', - ) + return self._prepare_batch_result(reqs, response) - def _validate_every_answer_is_a_dict(self, results: List[Dict[str, Any]]) -> None: - for result in results: - if not isinstance(result, dict): - self._raise_invalid_result(str(results), "Not every answer is dict") + def _get_json_result(self, response: requests.Response) -> Any: - def _validate_num_of_answers_same_as_requests( - self, - reqs: List[GraphQLRequest], - results: List[Dict[str, Any]], - ) -> None: - if len(reqs) != len(results): - self._raise_invalid_result( - str(results), - "Invalid answer length", - ) - - def _raise_invalid_result(self, result_text: str, reason: str) -> None: - raise TransportProtocolError( - f"Server did not return a valid GraphQL result: " - f"{reason}: " - f"{result_text}" - ) + # Saving latest response headers in the transport + self.response_headers = response.headers - def _extract_response(self, response: requests.Response) -> Any: try: - response.raise_for_status() - result = response.json() - - if log.isEnabledFor(logging.INFO): - log.info("<<< %s", response.text) + result = self.json_deserialize(response.text) - except requests.HTTPError as e: - raise TransportServerError( - str(e), e.response.status_code if e.response is not None else None - ) from e + if log.isEnabledFor(logging.DEBUG): + log.debug("<<< %s", response.text) except Exception: - self._raise_invalid_result(str(response.text), "Not a JSON answer") + self._raise_response_error(response, "Not a JSON answer") return result - def _build_batch_post_args( - self, - reqs: List[GraphQLRequest], - timeout: Optional[int] = None, - extra_args: Optional[Dict[str, Any]] = None, - ) -> Dict[str, Any]: - post_args: Dict[str, Any] = { - "headers": self.headers, - "auth": self.auth, - "cookies": self.cookies, - 
"timeout": timeout or self.default_timeout, - "verify": self.verify, - } - - data_key = "json" if self.use_json else "data" - post_args[data_key] = [self._build_data(req) for req in reqs] + def _prepare_result(self, response: requests.Response) -> ExecutionResult: - # Log the payload - if log.isEnabledFor(logging.INFO): - log.info(">>> %s", self.json_serialize(post_args[data_key])) + result = self._get_json_result(response) - # Pass kwargs to requests post method - post_args.update(self.kwargs) - - # Pass post_args to requests post method - if extra_args: - post_args.update(extra_args) - - return post_args + if "errors" not in result and "data" not in result: + self._raise_response_error(response, 'No "data" or "errors" keys in answer') - def _build_data(self, req: GraphQLRequest) -> Dict[str, Any]: - query_str = print_ast(req.document) - payload: Dict[str, Any] = {"query": query_str} + return ExecutionResult( + errors=result.get("errors"), + data=result.get("data"), + extensions=result.get("extensions"), + ) - if req.operation_name: - payload["operationName"] = req.operation_name + def _prepare_batch_result( + self, + reqs: List[GraphQLRequest], + response: requests.Response, + ) -> List[ExecutionResult]: - if req.variable_values: - payload["variables"] = req.variable_values + answers = self._get_json_result(response) - return payload + try: + return get_batch_execution_result_list(reqs, answers) + except TransportProtocolError: + # Raise a TransportServerError if status > 400 + self._raise_transport_server_error_if_status_more_than_400(response) + # In other cases, raise a TransportProtocolError + raise def close(self): """Closing the transport by closing the inner session""" diff --git a/gql/transport/transport.py b/gql/transport/transport.py index a5bd7100..7a72f9a6 100644 --- a/gql/transport/transport.py +++ b/gql/transport/transport.py @@ -1,19 +1,24 @@ import abc -from typing import List +from typing import Any, List -from graphql import DocumentNode, 
ExecutionResult +from graphql import ExecutionResult from ..graphql_request import GraphQLRequest class Transport(abc.ABC): @abc.abstractmethod - def execute(self, document: DocumentNode, *args, **kwargs) -> ExecutionResult: + def execute( + self, + request: GraphQLRequest, + *args: Any, + **kwargs: Any, + ) -> ExecutionResult: """Execute GraphQL query. - Execute the provided document AST for either a remote or local GraphQL Schema. + Execute the provided request for either a remote or local GraphQL Schema. - :param document: GraphQL query as AST Node or Document object. + :param request: GraphQL request as a GraphQLRequest object. :return: ExecutionResult """ raise NotImplementedError( @@ -23,8 +28,8 @@ def execute(self, document: DocumentNode, *args, **kwargs) -> ExecutionResult: def execute_batch( self, reqs: List[GraphQLRequest], - *args, - **kwargs, + *args: Any, + **kwargs: Any, ) -> List[ExecutionResult]: """Execute multiple GraphQL requests in a batch. @@ -35,7 +40,7 @@ def execute_batch( """ raise NotImplementedError( "This Transport has not implemented the execute_batch method" - ) # pragma: no cover + ) def connect(self): """Establish a session with the transport.""" diff --git a/gql/transport/websockets.py b/gql/transport/websockets.py index 02abb61f..7a0ce10a 100644 --- a/gql/transport/websockets.py +++ b/gql/transport/websockets.py @@ -1,26 +1,13 @@ -import asyncio -import json -import logging -from contextlib import suppress from ssl import SSLContext -from typing import Any, Dict, List, Optional, Tuple, Union, cast +from typing import Any, Dict, List, Optional, Union -from graphql import DocumentNode, ExecutionResult, print_ast from websockets.datastructures import HeadersLike -from websockets.exceptions import ConnectionClosed -from websockets.typing import Subprotocol -from .exceptions import ( - TransportProtocolError, - TransportQueryError, - TransportServerError, -) -from .websockets_base import WebsocketsTransportBase +from 
.common.adapters.websockets import WebSocketsAdapter +from .websockets_protocol import WebsocketsProtocolTransportBase -log = logging.getLogger(__name__) - -class WebsocketsTransport(WebsocketsTransportBase): +class WebsocketsTransport(WebsocketsProtocolTransportBase): """:ref:`Async Transport ` used to execute GraphQL queries on remote servers with websocket connection. @@ -28,17 +15,13 @@ class WebsocketsTransport(WebsocketsTransportBase): on a websocket connection. """ - # This transport supports two subprotocols and will autodetect the - # subprotocol supported on the server - APOLLO_SUBPROTOCOL = cast(Subprotocol, "graphql-ws") - GRAPHQLWS_SUBPROTOCOL = cast(Subprotocol, "graphql-transport-ws") - def __init__( self, url: str, + *, headers: Optional[HeadersLike] = None, ssl: Union[SSLContext, bool] = False, - init_payload: Dict[str, Any] = {}, + init_payload: Optional[Dict[str, Any]] = None, connect_timeout: Optional[Union[int, float]] = 10, close_timeout: Optional[Union[int, float]] = 10, ack_timeout: Optional[Union[int, float]] = 10, @@ -46,8 +29,8 @@ def __init__( ping_interval: Optional[Union[int, float]] = None, pong_timeout: Optional[Union[int, float]] = None, answer_pings: bool = True, - connect_args: Dict[str, Any] = {}, - subprotocols: Optional[List[Subprotocol]] = None, + connect_args: Optional[Dict[str, Any]] = None, + subprotocols: Optional[List[str]] = None, ) -> None: """Initialize the transport with the given parameters. @@ -83,438 +66,33 @@ def __init__( By default: both apollo and graphql-ws subprotocols. 
""" - super().__init__( - url, - headers, - ssl, - init_payload, - connect_timeout, - close_timeout, - ack_timeout, - keep_alive_timeout, - connect_args, + # Instanciate a WebSocketAdapter to indicate the use + # of the websockets dependency for this transport + self.adapter: WebSocketsAdapter = WebSocketsAdapter( + url=url, + headers=headers, + ssl=ssl, + connect_args=connect_args, ) - self.ping_interval: Optional[Union[int, float]] = ping_interval - self.pong_timeout: Optional[Union[int, float]] - self.answer_pings: bool = answer_pings - - if ping_interval is not None: - if pong_timeout is None: - self.pong_timeout = ping_interval / 2 - else: - self.pong_timeout = pong_timeout - - self.send_ping_task: Optional[asyncio.Future] = None - - self.ping_received: asyncio.Event = asyncio.Event() - """ping_received is an asyncio Event which will fire each time - a ping is received with the graphql-ws protocol""" - - self.pong_received: asyncio.Event = asyncio.Event() - """pong_received is an asyncio Event which will fire each time - a pong is received with the graphql-ws protocol""" - - if subprotocols is None: - self.supported_subprotocols = [ - self.APOLLO_SUBPROTOCOL, - self.GRAPHQLWS_SUBPROTOCOL, - ] - else: - self.supported_subprotocols = subprotocols - - async def _wait_ack(self) -> None: - """Wait for the connection_ack message. Keep alive messages are ignored""" - - while True: - init_answer = await self._receive() - - answer_type, answer_id, execution_result = self._parse_answer(init_answer) - - if answer_type == "connection_ack": - return - - if answer_type != "ka": - raise TransportProtocolError( - "Websocket server did not return a connection ack" - ) - - async def _send_init_message_and_wait_ack(self) -> None: - """Send init message to the provided websocket and wait for the connection ACK. - - If the answer is not a connection_ack message, we will return an Exception. 
- """ - - init_message = json.dumps( - {"type": "connection_init", "payload": self.init_payload} - ) - - await self._send(init_message) - - # Wait for the connection_ack message or raise a TimeoutError - await asyncio.wait_for(self._wait_ack(), self.ack_timeout) - - async def _initialize(self): - await self._send_init_message_and_wait_ack() - - async def send_ping(self, payload: Optional[Any] = None) -> None: - """Send a ping message for the graphql-ws protocol""" - - ping_message = {"type": "ping"} - - if payload is not None: - ping_message["payload"] = payload - - await self._send(json.dumps(ping_message)) - - async def send_pong(self, payload: Optional[Any] = None) -> None: - """Send a pong message for the graphql-ws protocol""" - - pong_message = {"type": "pong"} - - if payload is not None: - pong_message["payload"] = payload - - await self._send(json.dumps(pong_message)) - - async def _send_stop_message(self, query_id: int) -> None: - """Send stop message to the provided websocket connection and query_id. - - The server should afterwards return a 'complete' message. - """ - - stop_message = json.dumps({"id": str(query_id), "type": "stop"}) - - await self._send(stop_message) - - async def _send_complete_message(self, query_id: int) -> None: - """Send a complete message for the provided query_id. - - This is only for the graphql-ws protocol. - """ - - complete_message = json.dumps({"id": str(query_id), "type": "complete"}) - - await self._send(complete_message) - - async def _stop_listener(self, query_id: int): - """Stop the listener corresponding to the query_id depending on the - detected backend protocol. 
- - For apollo: send a "stop" message - (a "complete" message will be sent from the backend) - - For graphql-ws: send a "complete" message and simulate the reception - of a "complete" message from the backend - """ - log.debug(f"stop listener {query_id}") - - if self.subprotocol == self.GRAPHQLWS_SUBPROTOCOL: - await self._send_complete_message(query_id) - await self.listeners[query_id].put(("complete", None)) - else: - await self._send_stop_message(query_id) - - async def _send_connection_terminate_message(self) -> None: - """Send a connection_terminate message to the provided websocket connection. - - This message indicates that the connection will disconnect. - """ - - connection_terminate_message = json.dumps({"type": "connection_terminate"}) - - await self._send(connection_terminate_message) - - async def _send_query( - self, - document: DocumentNode, - variable_values: Optional[Dict[str, Any]] = None, - operation_name: Optional[str] = None, - ) -> int: - """Send a query to the provided websocket connection. - - We use an incremented id to reference the query. - - Returns the used id for this query. 
- """ - - query_id = self.next_query_id - self.next_query_id += 1 - - payload: Dict[str, Any] = {"query": print_ast(document)} - if variable_values: - payload["variables"] = variable_values - if operation_name: - payload["operationName"] = operation_name - - query_type = "start" - - if self.subprotocol == self.GRAPHQLWS_SUBPROTOCOL: - query_type = "subscribe" - - query_str = json.dumps( - {"id": str(query_id), "type": query_type, "payload": payload} + # Initialize the WebsocketsProtocolTransportBase parent class + super().__init__( + adapter=self.adapter, + init_payload=init_payload, + connect_timeout=connect_timeout, + close_timeout=close_timeout, + ack_timeout=ack_timeout, + keep_alive_timeout=keep_alive_timeout, + ping_interval=ping_interval, + pong_timeout=pong_timeout, + answer_pings=answer_pings, + subprotocols=subprotocols, ) - await self._send(query_str) - - return query_id - - async def _connection_terminate(self): - if self.subprotocol == self.APOLLO_SUBPROTOCOL: - await self._send_connection_terminate_message() - - def _parse_answer_graphqlws( - self, json_answer: Dict[str, Any] - ) -> Tuple[str, Optional[int], Optional[ExecutionResult]]: - """Parse the answer received from the server if the server supports the - graphql-ws protocol. 
- - Returns a list consisting of: - - the answer_type (between: - 'connection_ack', 'ping', 'pong', 'data', 'error', 'complete') - - the answer id (Integer) if received or None - - an execution Result if the answer_type is 'data' or None - - Differences with the apollo websockets protocol (superclass): - - the "data" message is now called "next" - - the "stop" message is now called "complete" - - there is no connection_terminate or connection_error messages - - instead of a unidirectional keep-alive (ka) message from server to client, - there is now the possibility to send bidirectional ping/pong messages - - connection_ack has an optional payload - - the 'error' answer type returns a list of errors instead of a single error - """ - - answer_type: str = "" - answer_id: Optional[int] = None - execution_result: Optional[ExecutionResult] = None - - try: - answer_type = str(json_answer.get("type")) - - if answer_type in ["next", "error", "complete"]: - answer_id = int(str(json_answer.get("id"))) - - if answer_type == "next" or answer_type == "error": - - payload = json_answer.get("payload") - - if answer_type == "next": - - if not isinstance(payload, dict): - raise ValueError("payload is not a dict") - - if "errors" not in payload and "data" not in payload: - raise ValueError( - "payload does not contain 'data' or 'errors' fields" - ) - - execution_result = ExecutionResult( - errors=payload.get("errors"), - data=payload.get("data"), - extensions=payload.get("extensions"), - ) - - # Saving answer_type as 'data' to be understood with superclass - answer_type = "data" - - elif answer_type == "error": - - if not isinstance(payload, list): - raise ValueError("payload is not a list") - - raise TransportQueryError( - str(payload[0]), query_id=answer_id, errors=payload - ) - - elif answer_type in ["ping", "pong", "connection_ack"]: - self.payloads[answer_type] = json_answer.get("payload", None) - - else: - raise ValueError - - if self.check_keep_alive_task is not None: - 
self._next_keep_alive_message.set() - - except ValueError as e: - raise TransportProtocolError( - f"Server did not return a GraphQL result: {json_answer}" - ) from e - - return answer_type, answer_id, execution_result - - def _parse_answer_apollo( - self, json_answer: Dict[str, Any] - ) -> Tuple[str, Optional[int], Optional[ExecutionResult]]: - """Parse the answer received from the server if the server supports the - apollo websockets protocol. - - Returns a list consisting of: - - the answer_type (between: - 'connection_ack', 'ka', 'connection_error', 'data', 'error', 'complete') - - the answer id (Integer) if received or None - - an execution Result if the answer_type is 'data' or None - """ - - answer_type: str = "" - answer_id: Optional[int] = None - execution_result: Optional[ExecutionResult] = None - - try: - answer_type = str(json_answer.get("type")) - - if answer_type in ["data", "error", "complete"]: - answer_id = int(str(json_answer.get("id"))) - - if answer_type == "data" or answer_type == "error": - - payload = json_answer.get("payload") - - if not isinstance(payload, dict): - raise ValueError("payload is not a dict") - - if answer_type == "data": - - if "errors" not in payload and "data" not in payload: - raise ValueError( - "payload does not contain 'data' or 'errors' fields" - ) - - execution_result = ExecutionResult( - errors=payload.get("errors"), - data=payload.get("data"), - extensions=payload.get("extensions"), - ) - - elif answer_type == "error": - - raise TransportQueryError( - str(payload), query_id=answer_id, errors=[payload] - ) - - elif answer_type == "ka": - # Keep-alive message - if self.check_keep_alive_task is not None: - self._next_keep_alive_message.set() - elif answer_type == "connection_ack": - pass - elif answer_type == "connection_error": - error_payload = json_answer.get("payload") - raise TransportServerError(f"Server error: '{repr(error_payload)}'") - else: - raise ValueError - - except ValueError as e: - raise 
TransportProtocolError( - f"Server did not return a GraphQL result: {json_answer}" - ) from e - - return answer_type, answer_id, execution_result - - def _parse_answer( - self, answer: str - ) -> Tuple[str, Optional[int], Optional[ExecutionResult]]: - """Parse the answer received from the server depending on - the detected subprotocol. - """ - try: - json_answer = json.loads(answer) - except ValueError: - raise TransportProtocolError( - f"Server did not return a GraphQL result: {answer}" - ) - - if self.subprotocol == self.GRAPHQLWS_SUBPROTOCOL: - return self._parse_answer_graphqlws(json_answer) - - return self._parse_answer_apollo(json_answer) - - async def _send_ping_coro(self) -> None: - """Coroutine to periodically send a ping from the client to the backend. - - Only used for the graphql-ws protocol. - - Send a ping every ping_interval seconds. - Close the connection if a pong is not received within pong_timeout seconds. - """ - - assert self.ping_interval is not None - - try: - while True: - await asyncio.sleep(self.ping_interval) - - await self.send_ping() - - await asyncio.wait_for(self.pong_received.wait(), self.pong_timeout) - - # Reset for the next iteration - self.pong_received.clear() - - except asyncio.TimeoutError: - # No pong received in the appriopriate time, close with error - # If the timeout happens during a close already in progress, do nothing - if self.close_task is None: - await self._fail( - TransportServerError( - f"No pong received after {self.pong_timeout!r} seconds" - ), - clean_close=False, - ) - - async def _handle_answer( - self, - answer_type: str, - answer_id: Optional[int], - execution_result: Optional[ExecutionResult], - ) -> None: - - # Put the answer in the queue - await super()._handle_answer(answer_type, answer_id, execution_result) - - # Answer pong to ping for graphql-ws protocol - if answer_type == "ping": - self.ping_received.set() - if self.answer_pings: - await self.send_pong() - - elif answer_type == "pong": - 
self.pong_received.set() - - async def _after_connect(self): - - # Find the backend subprotocol returned in the response headers - response_headers = self.websocket.response_headers - try: - self.subprotocol = response_headers["Sec-WebSocket-Protocol"] - except KeyError: - # If the server does not send the subprotocol header, using - # the apollo subprotocol by default - self.subprotocol = self.APOLLO_SUBPROTOCOL - - log.debug(f"backend subprotocol returned: {self.subprotocol!r}") - - async def _after_initialize(self): - - # If requested, create a task to send periodic pings to the backend - if ( - self.subprotocol == self.GRAPHQLWS_SUBPROTOCOL - and self.ping_interval is not None - ): - - self.send_ping_task = asyncio.ensure_future(self._send_ping_coro()) - - async def _close_hook(self): - log.debug("_close_hook: start") - - # Properly shut down the send ping task if enabled - if self.send_ping_task is not None: - log.debug("_close_hook: cancelling send_ping_task") - self.send_ping_task.cancel() - with suppress(asyncio.CancelledError, ConnectionClosed): - log.debug("_close_hook: awaiting send_ping_task") - await self.send_ping_task - self.send_ping_task = None + @property + def headers(self) -> Optional[HeadersLike]: + return self.adapter.headers - log.debug("_close_hook: end") + @property + def ssl(self) -> Union[SSLContext, bool]: + return self.adapter.ssl diff --git a/gql/transport/websockets_protocol.py b/gql/transport/websockets_protocol.py new file mode 100644 index 00000000..3b66a0cb --- /dev/null +++ b/gql/transport/websockets_protocol.py @@ -0,0 +1,511 @@ +import asyncio +import json +import logging +from contextlib import suppress +from typing import Any, Dict, List, Optional, Tuple, Union + +from graphql import ExecutionResult + +from ..graphql_request import GraphQLRequest +from .common.adapters.connection import AdapterConnection +from .common.base import SubscriptionTransportBase +from .exceptions import ( + TransportConnectionFailed, + 
TransportProtocolError, + TransportQueryError, + TransportServerError, +) + +log = logging.getLogger("gql.transport.websockets") + + +class WebsocketsProtocolTransportBase(SubscriptionTransportBase): + """:ref:`Async Transport ` used to execute GraphQL queries on + remote servers with websocket connection. + + This transport uses asyncio and the provided websockets adapter library + in order to send requests on a websocket connection. + """ + + # This transport supports two subprotocols and will autodetect the + # subprotocol supported on the server + APOLLO_SUBPROTOCOL = "graphql-ws" + GRAPHQLWS_SUBPROTOCOL = "graphql-transport-ws" + + def __init__( + self, + *, + adapter: AdapterConnection, + init_payload: Optional[Dict[str, Any]] = None, + connect_timeout: Optional[Union[int, float]] = 10, + close_timeout: Optional[Union[int, float]] = 10, + ack_timeout: Optional[Union[int, float]] = 10, + keep_alive_timeout: Optional[Union[int, float]] = None, + ping_interval: Optional[Union[int, float]] = None, + pong_timeout: Optional[Union[int, float]] = None, + answer_pings: bool = True, + subprotocols: Optional[List[str]] = None, + ) -> None: + """Initialize the transport with the given parameters. + + :param adapter: The connection dependency adapter + :param init_payload: Dict of the payload sent in the connection_init message. + :param connect_timeout: Timeout in seconds for the establishment + of the websocket connection. If None is provided this will wait forever. + :param close_timeout: Timeout in seconds for the close. If None is provided + this will wait forever. + :param ack_timeout: Timeout in seconds to wait for the connection_ack message + from the server. If None is provided this will wait forever. + :param keep_alive_timeout: Optional Timeout in seconds to receive + a sign of liveness from the server. + :param ping_interval: Delay in seconds between pings sent by the client to + the backend for the graphql-ws protocol. 
None (by default) means that + we don't send pings. Note: there are also pings sent by the underlying + websockets protocol. See the + :ref:`keepalive documentation ` + for more information about this. + :param pong_timeout: Delay in seconds to receive a pong from the backend + after we sent a ping (only for the graphql-ws protocol). + By default equal to half of the ping_interval. + :param answer_pings: Whether the client answers the pings from the backend + (for the graphql-ws protocol). + By default: True + :param subprotocols: list of subprotocols sent to the + backend in the 'subprotocols' http header. + By default: both apollo and graphql-ws subprotocols. + """ + + if subprotocols is None: + subprotocols = [ + self.APOLLO_SUBPROTOCOL, + self.GRAPHQLWS_SUBPROTOCOL, + ] + + self.adapter.subprotocols = subprotocols + + # Initialize the generic SubscriptionTransportBase parent class + super().__init__( + adapter=self.adapter, + connect_timeout=connect_timeout, + close_timeout=close_timeout, + keep_alive_timeout=keep_alive_timeout, + ) + + if init_payload is None: + init_payload = {} + + self.init_payload: Dict[str, Any] = init_payload + self.ack_timeout: Optional[Union[int, float]] = ack_timeout + + self.payloads: Dict[str, Any] = {} + """payloads is a dict which will contain the payloads received + for example with the graphql-ws protocol: 'ping', 'pong', 'connection_ack'""" + + self.ping_interval: Optional[Union[int, float]] = ping_interval + self.pong_timeout: Optional[Union[int, float]] + self.answer_pings: bool = answer_pings + + if ping_interval is not None: + if pong_timeout is None: + self.pong_timeout = ping_interval / 2 + else: + self.pong_timeout = pong_timeout + + self.send_ping_task: Optional[asyncio.Future] = None + + self.ping_received: asyncio.Event = asyncio.Event() + """ping_received is an asyncio Event which will fire each time + a ping is received with the graphql-ws protocol""" + + self.pong_received: asyncio.Event = asyncio.Event() + 
"""pong_received is an asyncio Event which will fire each time + a pong is received with the graphql-ws protocol""" + + async def _wait_ack(self) -> None: + """Wait for the connection_ack message. Keep alive messages are ignored""" + + while True: + init_answer = await self._receive() + + answer_type, answer_id, execution_result = self._parse_answer(init_answer) + + if answer_type == "connection_ack": + return + + if answer_type != "ka": + raise TransportProtocolError( + "Websocket server did not return a connection ack" + ) + + async def _send_init_message_and_wait_ack(self) -> None: + """Send init message to the provided websocket and wait for the connection ACK. + + If the answer is not a connection_ack message, we will return an Exception. + """ + + init_message = json.dumps( + {"type": "connection_init", "payload": self.init_payload} + ) + + await self._send(init_message) + + # Wait for the connection_ack message or raise a TimeoutError + await asyncio.wait_for(self._wait_ack(), self.ack_timeout) + + async def _initialize(self): + await self._send_init_message_and_wait_ack() + + async def send_ping(self, payload: Optional[Any] = None) -> None: + """Send a ping message for the graphql-ws protocol""" + + ping_message = {"type": "ping"} + + if payload is not None: + ping_message["payload"] = payload + + await self._send(json.dumps(ping_message)) + + async def send_pong(self, payload: Optional[Any] = None) -> None: + """Send a pong message for the graphql-ws protocol""" + + pong_message = {"type": "pong"} + + if payload is not None: + pong_message["payload"] = payload + + await self._send(json.dumps(pong_message)) + + async def _send_stop_message(self, query_id: int) -> None: + """Send stop message to the provided websocket connection and query_id. + + The server should afterwards return a 'complete' message. 
+ """ + + stop_message = json.dumps({"id": str(query_id), "type": "stop"}) + + await self._send(stop_message) + + async def _send_complete_message(self, query_id: int) -> None: + """Send a complete message for the provided query_id. + + This is only for the graphql-ws protocol. + """ + + complete_message = json.dumps({"id": str(query_id), "type": "complete"}) + + await self._send(complete_message) + + async def _stop_listener(self, query_id: int) -> None: + """Stop the listener corresponding to the query_id depending on the + detected backend protocol. + + For apollo: send a "stop" message + (a "complete" message will be sent from the backend) + + For graphql-ws: send a "complete" message and simulate the reception + of a "complete" message from the backend + """ + log.debug(f"stop listener {query_id}") + + if self.subprotocol == self.GRAPHQLWS_SUBPROTOCOL: + await self._send_complete_message(query_id) + await self.listeners[query_id].put(("complete", None)) + else: + await self._send_stop_message(query_id) + + async def _send_connection_terminate_message(self) -> None: + """Send a connection_terminate message to the provided websocket connection. + + This message indicates that the connection will disconnect. + """ + + connection_terminate_message = json.dumps({"type": "connection_terminate"}) + + await self._send(connection_terminate_message) + + async def _send_query( + self, + request: GraphQLRequest, + ) -> int: + """Send a query to the provided websocket connection. + + We use an incremented id to reference the query. + + Returns the used id for this query. 
+ """ + + query_id = self.next_query_id + self.next_query_id += 1 + + payload: Dict[str, Any] = request.payload + + query_type = "start" + + if self.subprotocol == self.GRAPHQLWS_SUBPROTOCOL: + query_type = "subscribe" + + query_str = json.dumps( + {"id": str(query_id), "type": query_type, "payload": payload} + ) + + await self._send(query_str) + + return query_id + + async def _connection_terminate(self): + if self.subprotocol == self.APOLLO_SUBPROTOCOL: + await self._send_connection_terminate_message() + + def _parse_answer_graphqlws( + self, json_answer: Dict[str, Any] + ) -> Tuple[str, Optional[int], Optional[ExecutionResult]]: + """Parse the answer received from the server if the server supports the + graphql-ws protocol. + + Returns a list consisting of: + - the answer_type (between: + 'connection_ack', 'ping', 'pong', 'data', 'error', 'complete') + - the answer id (Integer) if received or None + - an execution Result if the answer_type is 'data' or None + + Differences with the apollo websockets protocol (superclass): + - the "data" message is now called "next" + - the "stop" message is now called "complete" + - there is no connection_terminate or connection_error messages + - instead of a unidirectional keep-alive (ka) message from server to client, + there is now the possibility to send bidirectional ping/pong messages + - connection_ack has an optional payload + - the 'error' answer type returns a list of errors instead of a single error + """ + + answer_type: str = "" + answer_id: Optional[int] = None + execution_result: Optional[ExecutionResult] = None + + try: + answer_type = str(json_answer.get("type")) + + if answer_type in ["next", "error", "complete"]: + answer_id = int(str(json_answer.get("id"))) + + if answer_type == "next" or answer_type == "error": + + payload = json_answer.get("payload") + + if answer_type == "next": + + if not isinstance(payload, dict): + raise ValueError("payload is not a dict") + + if "errors" not in payload and "data" not 
in payload: + raise ValueError( + "payload does not contain 'data' or 'errors' fields" + ) + + execution_result = ExecutionResult( + errors=payload.get("errors"), + data=payload.get("data"), + extensions=payload.get("extensions"), + ) + + # Saving answer_type as 'data' to be understood with superclass + answer_type = "data" + + elif answer_type == "error": + + if not isinstance(payload, list): + raise ValueError("payload is not a list") + + raise TransportQueryError( + str(payload[0]), query_id=answer_id, errors=payload + ) + + elif answer_type in ["ping", "pong", "connection_ack"]: + self.payloads[answer_type] = json_answer.get("payload", None) + + else: + raise ValueError + + if self.check_keep_alive_task is not None: + self._next_keep_alive_message.set() + + except ValueError as e: + raise TransportProtocolError( + f"Server did not return a GraphQL result: {json_answer}" + ) from e + + return answer_type, answer_id, execution_result + + def _parse_answer_apollo( + self, json_answer: Dict[str, Any] + ) -> Tuple[str, Optional[int], Optional[ExecutionResult]]: + """Parse the answer received from the server if the server supports the + apollo websockets protocol. 
+ + Returns a list consisting of: + - the answer_type (between: + 'connection_ack', 'ka', 'connection_error', 'data', 'error', 'complete') + - the answer id (Integer) if received or None + - an execution Result if the answer_type is 'data' or None + """ + + answer_type: str = "" + answer_id: Optional[int] = None + execution_result: Optional[ExecutionResult] = None + + try: + answer_type = str(json_answer.get("type")) + + if answer_type in ["data", "error", "complete"]: + answer_id = int(str(json_answer.get("id"))) + + if answer_type == "data" or answer_type == "error": + + payload = json_answer.get("payload") + + if not isinstance(payload, dict): + raise ValueError("payload is not a dict") + + if answer_type == "data": + + if "errors" not in payload and "data" not in payload: + raise ValueError( + "payload does not contain 'data' or 'errors' fields" + ) + + execution_result = ExecutionResult( + errors=payload.get("errors"), + data=payload.get("data"), + extensions=payload.get("extensions"), + ) + + elif answer_type == "error": + + raise TransportQueryError( + str(payload), query_id=answer_id, errors=[payload] + ) + + elif answer_type == "ka": + # Keep-alive message + if self.check_keep_alive_task is not None: + self._next_keep_alive_message.set() + elif answer_type == "connection_ack": + pass + elif answer_type == "connection_error": + error_payload = json_answer.get("payload") + raise TransportServerError(f"Server error: '{repr(error_payload)}'") + else: + raise ValueError + + except ValueError as e: + raise TransportProtocolError( + f"Server did not return a GraphQL result: {json_answer}" + ) from e + + return answer_type, answer_id, execution_result + + def _parse_answer( + self, answer: str + ) -> Tuple[str, Optional[int], Optional[ExecutionResult]]: + """Parse the answer received from the server depending on + the detected subprotocol. 
+ """ + try: + json_answer = json.loads(answer) + except ValueError: + raise TransportProtocolError( + f"Server did not return a GraphQL result: {answer}" + ) + + if self.subprotocol == self.GRAPHQLWS_SUBPROTOCOL: + return self._parse_answer_graphqlws(json_answer) + + return self._parse_answer_apollo(json_answer) + + async def _send_ping_coro(self) -> None: + """Coroutine to periodically send a ping from the client to the backend. + + Only used for the graphql-ws protocol. + + Send a ping every ping_interval seconds. + Close the connection if a pong is not received within pong_timeout seconds. + """ + + assert self.ping_interval is not None + + try: + while True: + await asyncio.sleep(self.ping_interval) + + await self.send_ping() + + await asyncio.wait_for(self.pong_received.wait(), self.pong_timeout) + + # Reset for the next iteration + self.pong_received.clear() + + except asyncio.TimeoutError: + # No pong received in the appriopriate time, close with error + # If the timeout happens during a close already in progress, do nothing + if self.close_task is None: + await self._fail( + TransportServerError( + f"No pong received after {self.pong_timeout!r} seconds" + ), + clean_close=False, + ) + + async def _handle_answer( + self, + answer_type: str, + answer_id: Optional[int], + execution_result: Optional[ExecutionResult], + ) -> None: + + # Put the answer in the queue + await super()._handle_answer(answer_type, answer_id, execution_result) + + # Answer pong to ping for graphql-ws protocol + if answer_type == "ping": + self.ping_received.set() + if self.answer_pings: + await self.send_pong() + + elif answer_type == "pong": + self.pong_received.set() + + async def _after_connect(self): + + # Find the backend subprotocol returned in the response headers + try: + self.subprotocol = self.response_headers["Sec-WebSocket-Protocol"] + except KeyError: + # If the server does not send the subprotocol header, using + # the apollo subprotocol by default + self.subprotocol = 
self.APOLLO_SUBPROTOCOL + + log.debug(f"backend subprotocol returned: {self.subprotocol!r}") + + async def _after_initialize(self): + + # If requested, create a task to send periodic pings to the backend + if ( + self.subprotocol == self.GRAPHQLWS_SUBPROTOCOL + and self.ping_interval is not None + ): + + self.send_ping_task = asyncio.ensure_future(self._send_ping_coro()) + + async def _close_hook(self): + log.debug("_close_hook: start") + + # Properly shut down the send ping task if enabled + if self.send_ping_task is not None: + log.debug("_close_hook: cancelling send_ping_task") + self.send_ping_task.cancel() + with suppress(asyncio.CancelledError, TransportConnectionFailed): + log.debug("_close_hook: awaiting send_ping_task") + await self.send_ping_task + self.send_ping_task = None + + log.debug("_close_hook: end") diff --git a/gql/utilities/get_introspection_query_ast.py b/gql/utilities/get_introspection_query_ast.py index 975ccc83..0422a225 100644 --- a/gql/utilities/get_introspection_query_ast.py +++ b/gql/utilities/get_introspection_query_ast.py @@ -10,7 +10,7 @@ def get_introspection_query_ast( specified_by_url: bool = False, directive_is_repeatable: bool = False, schema_description: bool = False, - input_value_deprecation: bool = False, + input_value_deprecation: bool = True, type_recursion_level: int = 7, ) -> DocumentNode: """Get a query for introspection as a document using the DSL module. 
@@ -139,4 +139,4 @@ def get_introspection_query_ast( dsl_query = dsl_gql(query, fragment_FullType, fragment_InputValue, fragment_TypeRef) - return dsl_query + return dsl_query.document diff --git a/gql/utilities/node_tree.py b/gql/utilities/node_tree.py index 4313188e..08fb1bf5 100644 --- a/gql/utilities/node_tree.py +++ b/gql/utilities/node_tree.py @@ -8,7 +8,7 @@ def _node_tree_recursive( *, indent: int = 0, ignored_keys: List, -): +) -> str: assert ignored_keys is not None @@ -65,7 +65,7 @@ def node_tree( ignore_loc: bool = True, ignore_block: bool = True, ignored_keys: Optional[List] = None, -): +) -> str: """Method which returns a tree of Node elements as a String. Useful to debug deep DocumentNode instances created by gql or dsl_gql. diff --git a/gql/utilities/parse_result.py b/gql/utilities/parse_result.py index 02355425..f9bc2e0c 100644 --- a/gql/utilities/parse_result.py +++ b/gql/utilities/parse_result.py @@ -44,7 +44,7 @@ } -def _ignore_non_null(type_: GraphQLType): +def _ignore_non_null(type_: GraphQLType) -> GraphQLType: """Removes the GraphQLNonNull wrappings around types.""" if isinstance(type_, GraphQLNonNull): return type_.of_type @@ -153,6 +153,8 @@ def get_current_result_type(self, path): list_level = self.inside_list_level + assert field_type is not None + result_type = _ignore_non_null(field_type) if self.in_first_field(path): diff --git a/gql/utilities/update_schema_enum.py b/gql/utilities/update_schema_enum.py index 80c73862..6f7ba0ce 100644 --- a/gql/utilities/update_schema_enum.py +++ b/gql/utilities/update_schema_enum.py @@ -9,7 +9,7 @@ def update_schema_enum( name: str, values: Union[Dict[str, Any], Type[Enum]], use_enum_values: bool = False, -): +) -> None: """Update in the schema the GraphQLEnumType corresponding to the given name. 
Example:: diff --git a/gql/utilities/update_schema_scalars.py b/gql/utilities/update_schema_scalars.py index db3adb17..c2c1b4e8 100644 --- a/gql/utilities/update_schema_scalars.py +++ b/gql/utilities/update_schema_scalars.py @@ -3,7 +3,9 @@ from graphql import GraphQLScalarType, GraphQLSchema -def update_schema_scalar(schema: GraphQLSchema, name: str, scalar: GraphQLScalarType): +def update_schema_scalar( + schema: GraphQLSchema, name: str, scalar: GraphQLScalarType +) -> None: """Update the scalar in a schema with the scalar provided. :param schema: the GraphQL schema @@ -36,7 +38,9 @@ def update_schema_scalar(schema: GraphQLSchema, name: str, scalar: GraphQLScalar setattr(schema_scalar, "parse_literal", scalar.parse_literal) -def update_schema_scalars(schema: GraphQLSchema, scalars: List[GraphQLScalarType]): +def update_schema_scalars( + schema: GraphQLSchema, scalars: List[GraphQLScalarType] +) -> None: """Update the scalars in a schema with the scalars provided. :param schema: the GraphQL schema diff --git a/gql/utils.py b/gql/utils.py index b4265ce1..f7f0f5a7 100644 --- a/gql/utils.py +++ b/gql/utils.py @@ -1,6 +1,6 @@ """Utilities to manipulate several python objects.""" -from typing import Any, Dict, List, Tuple, Type +from typing import List # From this response in Stackoverflow @@ -12,43 +12,6 @@ def to_camel_case(snake_str): return components[0] + "".join(x.title() if x else "_" for x in components[1:]) -def extract_files( - variables: Dict, file_classes: Tuple[Type[Any], ...] -) -> Tuple[Dict, Dict]: - files = {} - - def recurse_extract(path, obj): - """ - recursively traverse obj, doing a deepcopy, but - replacing any file-like objects with nulls and - shunting the originals off to the side. 
- """ - nonlocal files - if isinstance(obj, list): - nulled_obj = [] - for key, value in enumerate(obj): - value = recurse_extract(f"{path}.{key}", value) - nulled_obj.append(value) - return nulled_obj - elif isinstance(obj, dict): - nulled_obj = {} - for key, value in obj.items(): - value = recurse_extract(f"{path}.{key}", value) - nulled_obj[key] = value - return nulled_obj - elif isinstance(obj, file_classes): - # extract obj from its parent and put it into files instead. - files[path] = obj - return None - else: - # base case: pass through unchanged - return obj - - nulled_variables = recurse_extract("variables", variables) - - return nulled_variables, files - - def str_first_element(errors: List) -> str: try: first_error = errors[0] diff --git a/pyproject.toml b/pyproject.toml index 9b631e08..f5eb5c8d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,3 +7,16 @@ dynamic = ["authors", "classifiers", "dependencies", "description", "entry-point [build-system] requires = ["setuptools"] build-backend = "setuptools.build_meta" + +[tool.isort] +extra_standard_library = "ssl" +known_first_party = "gql" +profile = "black" + +[tool.pytest.ini_options] +asyncio_default_fixture_loop_scope = "function" + +[tool.mypy] +ignore_missing_imports = true +check_untyped_defs = true +disallow_incomplete_defs = true diff --git a/setup.cfg b/setup.cfg index 66380493..533b80f1 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,16 +1,5 @@ [flake8] max-line-length = 88 -[isort] -known_standard_library = ssl -known_first_party = gql -multi_line_output = 3 -include_trailing_comma = True -line_length = 88 -not_skip = __init__.py - -[mypy] -ignore_missing_imports = true - [tool:pytest] norecursedirs = venv .venv .tox .git .cache .mypy_cache .pytest_cache diff --git a/setup.py b/setup.py index a44c2e01..3db1c9f8 100644 --- a/setup.py +++ b/setup.py @@ -14,32 +14,31 @@ ] tests_requires = [ - "parse==1.15.0", - "pytest==7.4.2", - "pytest-asyncio==0.21.1", + "parse==1.20.2", + "pytest==8.3.4", + 
"pytest-asyncio==0.25.3", "pytest-console-scripts==1.4.1", - "pytest-cov==5.0.0", + "pytest-cov==6.0.0", "vcrpy==7.0.0", "aiofiles", ] dev_requires = [ - "black==22.3.0", + "black==25.1.0", "check-manifest>=0.42,<1", - "flake8==7.1.1", - "isort==4.3.21", - "mypy==1.10", + "flake8==7.1.2", + "isort==6.0.1", + "mypy==1.15", "sphinx>=7.0.0,<8;python_version<='3.9'", "sphinx>=8.1.0,<9;python_version>'3.9'", "sphinx_rtd_theme>=3.0.2,<4", - "sphinx-argparse==0.4.0", + "sphinx-argparse==0.5.2", "types-aiofiles", "types-requests", ] + tests_requires install_aiohttp_requires = [ - "aiohttp>=3.8.0,<4;python_version<='3.11'", - "aiohttp>=3.9.0b0,<4;python_version>'3.11'", + "aiohttp>=3.11.2,<4", ] install_requests_requires = [ @@ -52,15 +51,19 @@ ] install_websockets_requires = [ - "websockets>=10.1,<14", + "websockets>=14.2,<16", ] install_botocore_requires = [ "botocore>=1.21,<2", ] +install_aiofiles_requires = [ + "aiofiles", +] + install_all_requires = ( - install_aiohttp_requires + install_requests_requires + install_httpx_requires + install_websockets_requires + install_botocore_requires + install_aiohttp_requires + install_requests_requires + install_httpx_requires + install_websockets_requires + install_botocore_requires + install_aiofiles_requires ) # Get version from __version__.py file @@ -83,7 +86,6 @@ "Development Status :: 5 - Production/Stable", "Intended Audience :: Developers", "Topic :: Software Development :: Libraries", - "License :: OSI Approved :: MIT License", "Programming Language :: Python :: 3", "Programming Language :: Python :: 3 :: Only", "Programming Language :: Python :: 3.9", @@ -108,6 +110,7 @@ "httpx": install_httpx_requires, "websockets": install_websockets_requires, "botocore": install_botocore_requires, + "aiofiles": install_aiofiles_requires, }, include_package_data=True, zip_safe=False, diff --git a/tests/conftest.py b/tests/conftest.py index b0103a99..cef561f7 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,22 +3,27 @@ 
import logging import os import pathlib +import platform import re import ssl import sys import tempfile import types from concurrent.futures import ThreadPoolExecutor -from typing import Union +from typing import Callable, Iterable, List, Union, cast import pytest import pytest_asyncio +from _pytest.fixtures import SubRequest from gql import Client all_transport_dependencies = ["aiohttp", "requests", "httpx", "websockets", "botocore"] +PyPy = platform.python_implementation() == "PyPy" + + def pytest_addoption(parser): parser.addoption( "--run-online", @@ -121,9 +126,10 @@ async def ssl_aiohttp_server(): "gql.transport.aiohttp", "gql.transport.aiohttp_websockets", "gql.transport.appsync", + "gql.transport.common.base", + "gql.transport.httpx", "gql.transport.phoenix_channel_websockets", "gql.transport.requests", - "gql.transport.httpx", "gql.transport.websockets", "gql.dsl", "gql.utilities.parse_result", @@ -191,7 +197,7 @@ def __init__(self, with_ssl: bool = False): async def start(self, handler, extra_serve_args=None): - import websockets.server + import websockets print("Starting server") @@ -203,18 +209,23 @@ async def start(self, handler, extra_serve_args=None): extra_serve_args["ssl"] = ssl_context # Adding dummy response headers - extra_serve_args["extra_headers"] = {"dummy": "test1234"} + extra_headers = {"dummy": "test1234"} + + def process_response(connection, request, response): + response.headers.update(extra_headers) + return response # Start a server with a random open port - self.start_server = websockets.server.serve( - handler, "127.0.0.1", 0, **extra_serve_args + self.server = await websockets.serve( + handler, + "127.0.0.1", + 0, + process_response=process_response, + **extra_serve_args, ) - # Wait that the server is started - self.server = await self.start_server - # Get hostname and port - hostname, port = self.server.sockets[0].getsockname()[:2] + hostname, port = self.server.sockets[0].getsockname()[:2] # type: ignore assert hostname == 
"127.0.0.1" self.hostname = hostname @@ -245,7 +256,7 @@ def __init__(self, with_ssl=False): if with_ssl: _, self.ssl_context = get_localhost_ssl_context() - def get_default_server_handler(answers): + def get_default_server_handler(answers: Iterable[str]) -> Callable: async def default_server_handler(request): import aiohttp @@ -286,7 +297,7 @@ async def default_server_handler(request): elif msg.type == WSMsgType.ERROR: print(f"WebSocket connection closed with: {ws.exception()}") - raise ws.exception() + raise ws.exception() # type: ignore elif msg.type in ( WSMsgType.CLOSE, WSMsgType.CLOSED, @@ -336,7 +347,8 @@ async def start(self, handler): await self.site.start() # Retrieve the actual port the server is listening on - sockets = self.site._server.sockets + assert self.site._server is not None + sockets = self.site._server.sockets # type: ignore if sockets: self.port = sockets[0].getsockname()[1] protocol = "https" if self.with_ssl else "http" @@ -443,7 +455,7 @@ async def send_connection_ack(ws): class TemporaryFile: """Class used to generate temporary files for the tests""" - def __init__(self, content: Union[str, bytearray]): + def __init__(self, content: Union[str, bytearray, bytes]): mode = "w" if isinstance(content, str) else "wb" @@ -469,24 +481,30 @@ def __exit__(self, type, value, traceback): os.unlink(self.filename) -def get_aiohttp_ws_server_handler(request): +def get_aiohttp_ws_server_handler( + request: SubRequest, +) -> Callable: """Get the server handler for the aiohttp websocket server. Either get it from test or use the default server handler if the test provides only an array of answers. 
""" + server_handler: Callable + if isinstance(request.param, types.FunctionType): server_handler = request.param else: - answers = request.param + answers = cast(List[str], request.param) server_handler = AIOHTTPWebsocketServer.get_default_server_handler(answers) return server_handler -def get_server_handler(request): +def get_server_handler( + request: SubRequest, +) -> Callable: """Get the server handler. Either get it from test or use the default server handler @@ -496,7 +514,7 @@ def get_server_handler(request): from websockets.exceptions import ConnectionClosed if isinstance(request.param, types.FunctionType): - server_handler = request.param + server_handler: Callable = request.param else: answers = request.param @@ -590,24 +608,6 @@ async def graphqlws_server(request): subprotocol = "graphql-transport-ws" - from websockets.server import WebSocketServerProtocol - - class CustomSubprotocol(WebSocketServerProtocol): - def select_subprotocol(self, client_subprotocols, server_subprotocols): - print(f"Client subprotocols: {client_subprotocols!r}") - print(f"Server subprotocols: {server_subprotocols!r}") - - return subprotocol - - def process_subprotocol(self, headers, available_subprotocols): - # Overwriting available subprotocols - available_subprotocols = [subprotocol] - - print(f"headers: {headers!r}") - # print (f"Available subprotocols: {available_subprotocols!r}") - - return super().process_subprotocol(headers, available_subprotocols) - server_handler = get_server_handler(request) try: @@ -615,7 +615,7 @@ def process_subprotocol(self, headers, available_subprotocols): # Starting the server with the fixture param as the handler function await test_server.start( - server_handler, extra_serve_args={"create_protocol": CustomSubprotocol} + server_handler, extra_serve_args={"subprotocols": [subprotocol]} ) yield test_server @@ -634,9 +634,9 @@ async def client_and_server(server): # Generate transport to connect to the server fixture path = "/graphql" url = 
f"ws://{server.hostname}:{server.port}{path}" - sample_transport = WebsocketsTransport(url=url) + transport = WebsocketsTransport(url=url) - async with Client(transport=sample_transport) as session: + async with Client(transport=transport) as session: # Yield both client session and server yield session, server @@ -654,9 +654,9 @@ async def aiohttp_client_and_server(server): # Generate transport to connect to the server fixture path = "/graphql" url = f"ws://{server.hostname}:{server.port}{path}" - sample_transport = AIOHTTPWebsocketsTransport(url=url) + transport = AIOHTTPWebsocketsTransport(url=url) - async with Client(transport=sample_transport) as session: + async with Client(transport=transport) as session: # Yield both client session and server yield session, server @@ -676,9 +676,9 @@ async def aiohttp_client_and_aiohttp_ws_server(aiohttp_ws_server): # Generate transport to connect to the server fixture path = "/graphql" url = f"ws://{server.hostname}:{server.port}{path}" - sample_transport = AIOHTTPWebsocketsTransport(url=url) + transport = AIOHTTPWebsocketsTransport(url=url) - async with Client(transport=sample_transport) as session: + async with Client(transport=transport) as session: # Yield both client session and server yield session, server @@ -694,12 +694,12 @@ async def client_and_graphqlws_server(graphqlws_server): # Generate transport to connect to the server fixture path = "/graphql" url = f"ws://{graphqlws_server.hostname}:{graphqlws_server.port}{path}" - sample_transport = WebsocketsTransport( + transport = WebsocketsTransport( url=url, subprotocols=[WebsocketsTransport.GRAPHQLWS_SUBPROTOCOL], ) - async with Client(transport=sample_transport) as session: + async with Client(transport=transport) as session: # Yield both client session and server yield session, graphqlws_server @@ -715,12 +715,12 @@ async def client_and_aiohttp_websocket_graphql_server(graphqlws_server): # Generate transport to connect to the server fixture path = "/graphql" url 
= f"ws://{graphqlws_server.hostname}:{graphqlws_server.port}{path}" - sample_transport = AIOHTTPWebsocketsTransport( + transport = AIOHTTPWebsocketsTransport( url=url, subprotocols=[AIOHTTPWebsocketsTransport.GRAPHQLWS_SUBPROTOCOL], ) - async with Client(transport=sample_transport) as session: + async with Client(transport=transport) as session: # Yield both client session and server yield session, graphqlws_server @@ -728,11 +728,12 @@ async def client_and_aiohttp_websocket_graphql_server(graphqlws_server): @pytest_asyncio.fixture async def run_sync_test(): - async def run_sync_test_inner(event_loop, server, test_function): + async def run_sync_test_inner(server, test_function): """This function will run the test in a different Thread. This allows us to run sync code while aiohttp server can still run. """ + event_loop = asyncio.get_running_loop() executor = ThreadPoolExecutor(max_workers=2) test_task = event_loop.run_in_executor(executor, test_function) @@ -762,3 +763,62 @@ def strip_braces_spaces(s): strip_back = re.sub(r"([^\s]) }", r"\1}", strip_front) return strip_back + + +def make_upload_handler( + nb_files=1, + filenames=None, + request_headers=None, + file_headers=None, + binary=False, + expected_contents=None, + expected_operations=None, + expected_map=None, + server_answer='{"data":{"success":true}}', +): + assert expected_contents is not None + assert expected_operations is not None + assert expected_map is not None + + async def single_upload_handler(request): + from aiohttp import web + + reader = await request.multipart() + + if request_headers is not None: + for k, v in request_headers.items(): + assert request.headers[k] == v + + field_0 = await reader.next() + assert field_0.name == "operations" + field_0_text = await field_0.text() + assert strip_braces_spaces(field_0_text) == expected_operations + + field_1 = await reader.next() + assert field_1.name == "map" + field_1_text = await field_1.text() + assert field_1_text == expected_map + + for i 
in range(nb_files): + field = await reader.next() + assert field.name == str(i) + if filenames is not None: + assert field.filename == filenames[i] + + if binary: + field_content = await field.read() + assert field_content == expected_contents[i] + else: + field_text = await field.text() + assert field_text == expected_contents[i] + + if file_headers is not None: + for k, v in file_headers[i].items(): + assert field.headers[k] == v + + final_field = await reader.next() + assert final_field is None + + return web.Response(text=server_answer, content_type="application/json") + + return single_upload_handler diff --git a/tests/custom_scalars/test_datetime.py b/tests/custom_scalars/test_datetime.py index 5a36669c..4d9589f1 100644 --- a/tests/custom_scalars/test_datetime.py +++ b/tests/custom_scalars/test_datetime.py @@ -117,11 +117,11 @@ def test_shift_days(): query = gql("query shift5days($time: Datetime) {shiftDays(time: $time, days: 5)}") - variable_values = { + query.variable_values = { "time": now, } - result = client.execute(query, variable_values=variable_values) + result = client.execute(query) print(result) @@ -151,11 +151,11 @@ def test_shift_days_serialized_manually_in_variables(): query = gql("query shift5days($time: Datetime) {shiftDays(time: $time, days: 5)}") - variable_values = { + query.variable_values = { "time": "2021-11-12T11:58:13.461161", } - result = client.execute(query, variable_values=variable_values) + result = client.execute(query) print(result) @@ -171,13 +171,11 @@ def test_latest(): query = gql("query latest($times: [Datetime!]!) 
{latest(times: $times)}") - variable_values = { + query.variable_values = { "times": [now, in_five_days], } - result = client.execute( - query, variable_values=variable_values, serialize_variables=True - ) + result = client.execute(query, serialize_variables=True) print(result) @@ -194,11 +192,9 @@ def test_seconds(): "query seconds($interval: IntervalInput) {seconds(interval: $interval)}" ) - variable_values = {"interval": {"start": now, "end": in_five_days}} + query.variable_values = {"interval": {"start": now, "end": in_five_days}} - result = client.execute( - query, variable_values=variable_values, serialize_variables=True - ) + result = client.execute(query, serialize_variables=True) print(result) @@ -214,11 +210,9 @@ def test_seconds_omit_optional_start_argument(): "query seconds($interval: IntervalInput) {seconds(interval: $interval)}" ) - variable_values = {"interval": {"end": in_five_days}} + query.variable_values = {"interval": {"end": in_five_days}} - result = client.execute( - query, variable_values=variable_values, serialize_variables=True - ) + result = client.execute(query, serialize_variables=True) print(result) diff --git a/tests/custom_scalars/test_enum_colors.py b/tests/custom_scalars/test_enum_colors.py index 2f15a8ca..ff893571 100644 --- a/tests/custom_scalars/test_enum_colors.py +++ b/tests/custom_scalars/test_enum_colors.py @@ -1,4 +1,5 @@ from enum import Enum +from typing import Optional import pytest from graphql import ( @@ -6,6 +7,7 @@ GraphQLEnumType, GraphQLField, GraphQLList, + GraphQLNamedType, GraphQLNonNull, GraphQLObjectType, GraphQLSchema, @@ -163,11 +165,11 @@ def test_opposite_color_variable_serialized_manually(): }""" ) - variable_values = { + query.variable_values = { "color": "RED", } - result = client.execute(query, variable_values=variable_values) + result = client.execute(query) print(result) @@ -188,13 +190,11 @@ def test_opposite_color_variable_serialized_by_gql(): }""" ) - variable_values = { + query.variable_values = 
{ "color": RED, } - result = client.execute( - query, variable_values=variable_values, serialize_variables=True - ) + result = client.execute(query, serialize_variables=True) print(result) @@ -251,19 +251,30 @@ def test_list_of_list_of_list(): def test_update_schema_enum(): - assert schema.get_type("Color").parse_value("RED") == Color.RED + color_type: Optional[GraphQLNamedType] + + color_type = schema.get_type("Color") + assert isinstance(color_type, GraphQLEnumType) + assert color_type is not None + assert color_type.parse_value("RED") == Color.RED # Using values update_schema_enum(schema, "Color", Color, use_enum_values=True) - assert schema.get_type("Color").parse_value("RED") == 0 - assert schema.get_type("Color").serialize(1) == "GREEN" + color_type = schema.get_type("Color") + assert isinstance(color_type, GraphQLEnumType) + assert color_type is not None + assert color_type.parse_value("RED") == 0 + assert color_type.serialize(1) == "GREEN" update_schema_enum(schema, "Color", Color) - assert schema.get_type("Color").parse_value("RED") == Color.RED - assert schema.get_type("Color").serialize(Color.RED) == "RED" + color_type = schema.get_type("Color") + assert isinstance(color_type, GraphQLEnumType) + assert color_type is not None + assert color_type.parse_value("RED") == Color.RED + assert color_type.serialize(Color.RED) == "RED" def test_update_schema_enum_errors(): @@ -273,20 +284,22 @@ def test_update_schema_enum_errors(): assert "Enum Corlo not found in schema!" 
in str(exc_info) - with pytest.raises(TypeError) as exc_info: - update_schema_enum(schema, "Color", 6) + with pytest.raises(TypeError) as exc_info2: + update_schema_enum(schema, "Color", 6) # type: ignore - assert "Invalid type for enum values: " in str(exc_info) + assert "Invalid type for enum values: " in str(exc_info2) - with pytest.raises(TypeError) as exc_info: + with pytest.raises(TypeError) as exc_info3: update_schema_enum(schema, "RootQueryType", Color) - assert 'The type "RootQueryType" is not a GraphQLEnumType, it is a' in str(exc_info) + assert 'The type "RootQueryType" is not a GraphQLEnumType, it is a' in str( + exc_info3 + ) - with pytest.raises(KeyError) as exc_info: + with pytest.raises(KeyError) as exc_info4: update_schema_enum(schema, "Color", {"RED": Color.RED}) - assert 'Enum key "GREEN" not found in provided values!' in str(exc_info) + assert 'Enum key "GREEN" not found in provided values!' in str(exc_info4) def test_parse_results_with_operation_type(): @@ -313,13 +326,12 @@ def test_parse_results_with_operation_type(): """ ) - variable_values = { + query.variable_values = { "color": "RED", } + query.operation_name = "GetOppositeColor" - result = client.execute( - query, variable_values=variable_values, operation_name="GetOppositeColor" - ) + result = client.execute(query) print(result) diff --git a/tests/custom_scalars/test_json.py b/tests/custom_scalars/test_json.py index d3eae3b8..903dfa6d 100644 --- a/tests/custom_scalars/test_json.py +++ b/tests/custom_scalars/test_json.py @@ -166,7 +166,7 @@ def test_json_value_input_in_ast_with_variables(): }""" ) - variable_values = { + query.variable_values = { "name": "Barbara", "level": 1, "is_connected": False, @@ -174,9 +174,7 @@ def test_json_value_input_in_ast_with_variables(): "friends": ["Alex", "John"], } - result = client.execute( - query, variable_values=variable_values, root_value=root_value - ) + result = client.execute(query, root_value=root_value) print(result) diff --git 
a/tests/custom_scalars/test_money.py b/tests/custom_scalars/test_money.py index 374c70e6..55a6577a 100644 --- a/tests/custom_scalars/test_money.py +++ b/tests/custom_scalars/test_money.py @@ -20,7 +20,7 @@ ) from graphql.utilities import value_from_ast_untyped -from gql import Client, GraphQLRequest, gql +from gql import Client, gql from gql.transport.exceptions import TransportQueryError from gql.utilities import serialize_value, update_schema_scalar, update_schema_scalars @@ -275,11 +275,9 @@ def test_custom_scalar_in_input_variable_values(): money_value = {"amount": 10, "currency": "DM"} - variable_values = {"money": money_value} + query.variable_values = {"money": money_value} - result = client.execute( - query, variable_values=variable_values, root_value=root_value - ) + result = client.execute(query, root_value=root_value) assert result["toEuros"] == 5 @@ -292,11 +290,10 @@ def test_custom_scalar_in_input_variable_values_serialized(): money_value = Money(10, "DM") - variable_values = {"money": money_value} + query.variable_values = {"money": money_value} result = client.execute( query, - variable_values=variable_values, root_value=root_value, serialize_variables=True, ) @@ -312,14 +309,13 @@ def test_custom_scalar_in_input_variable_values_serialized_with_operation_name() money_value = Money(10, "DM") - variable_values = {"money": money_value} + query.variable_values = {"money": money_value} + query.operation_name = "myquery" result = client.execute( query, - variable_values=variable_values, root_value=root_value, serialize_variables=True, - operation_name="myquery", ) assert result["toEuros"] == 5 @@ -342,12 +338,11 @@ def test_serialize_variable_values_exception_multiple_ops_without_operation_name money_value = Money(10, "DM") - variable_values = {"money": money_value} + query.variable_values = {"money": money_value} with pytest.raises(GraphQLError) as exc_info: client.execute( query, - variable_values=variable_values, root_value=root_value, 
serialize_variables=True, ) @@ -374,15 +369,14 @@ def test_serialize_variable_values_exception_operation_name_not_found(): money_value = Money(10, "DM") - variable_values = {"money": money_value} + query.variable_values = {"money": money_value} + query.operation_name = "invalid_operation_name" with pytest.raises(GraphQLError) as exc_info: client.execute( query, - variable_values=variable_values, root_value=root_value, serialize_variables=True, - operation_name="invalid_operation_name", ) exception = exc_info.value @@ -398,13 +392,12 @@ def test_custom_scalar_subscribe_in_input_variable_values_serialized(): money_value = Money(10, "DM") - variable_values = {"money": money_value} + query.variable_values = {"money": money_value} expected_result = {"spend": Money(10, "DM")} for result in client.subscribe( query, - variable_values=variable_values, root_value=root_value, serialize_variables=True, parse_result=True, @@ -441,9 +434,9 @@ def handle_single(data: Dict[str, Any]) -> ExecutionResult: [ { "data": result.data, - "errors": [str(e) for e in result.errors] - if result.errors - else None, + "errors": ( + [str(e) for e in result.errors] if result.errors else None + ), } for result in results ] @@ -453,9 +446,9 @@ def handle_single(data: Dict[str, Any]) -> ExecutionResult: return web.json_response( { "data": result.data, - "errors": [str(e) for e in result.errors] - if result.errors - else None, + "errors": ( + [str(e) for e in result.errors] if result.errors else None + ), } ) @@ -491,7 +484,7 @@ async def make_sync_money_transport(aiohttp_server): @pytest.mark.asyncio -async def test_custom_scalar_in_output_with_transport(event_loop, aiohttp_server): +async def test_custom_scalar_in_output_with_transport(aiohttp_server): transport = await make_money_transport(aiohttp_server) @@ -509,7 +502,7 @@ async def test_custom_scalar_in_output_with_transport(event_loop, aiohttp_server @pytest.mark.asyncio -async def test_custom_scalar_in_input_query_with_transport(event_loop, 
aiohttp_server): +async def test_custom_scalar_in_input_query_with_transport(aiohttp_server): transport = await make_money_transport(aiohttp_server) @@ -531,9 +524,7 @@ async def test_custom_scalar_in_input_query_with_transport(event_loop, aiohttp_s @pytest.mark.asyncio -async def test_custom_scalar_in_input_variable_values_with_transport( - event_loop, aiohttp_server -): +async def test_custom_scalar_in_input_variable_values_with_transport(aiohttp_server): transport = await make_money_transport(aiohttp_server) @@ -546,9 +537,9 @@ async def test_custom_scalar_in_input_variable_values_with_transport( money_value = {"amount": 10, "currency": "DM"} # money_value = Money(10, "DM") - variable_values = {"money": money_value} + query.variable_values = {"money": money_value} - result = await session.execute(query, variable_values=variable_values) + result = await session.execute(query) print(f"result = {result!r}") assert result["toEuros"] == 5 @@ -556,7 +547,7 @@ async def test_custom_scalar_in_input_variable_values_with_transport( @pytest.mark.asyncio async def test_custom_scalar_in_input_variable_values_split_with_transport( - event_loop, aiohttp_server + aiohttp_server, ): transport = await make_money_transport(aiohttp_server) @@ -572,16 +563,16 @@ async def test_custom_scalar_in_input_variable_values_split_with_transport( }""" ) - variable_values = {"amount": 10, "currency": "DM"} + query.variable_values = {"amount": 10, "currency": "DM"} - result = await session.execute(query, variable_values=variable_values) + result = await session.execute(query) print(f"result = {result!r}") assert result["toEuros"] == 5 @pytest.mark.asyncio -async def test_custom_scalar_serialize_variables(event_loop, aiohttp_server): +async def test_custom_scalar_serialize_variables(aiohttp_server): transport = await make_money_transport(aiohttp_server) @@ -592,18 +583,16 @@ async def test_custom_scalar_serialize_variables(event_loop, aiohttp_server): query = gql("query myquery($money: Money) 
{toEuros(money: $money)}") - variable_values = {"money": Money(10, "DM")} + query.variable_values = {"money": Money(10, "DM")} - result = await session.execute( - query, variable_values=variable_values, serialize_variables=True - ) + result = await session.execute(query, serialize_variables=True) print(f"result = {result!r}") assert result["toEuros"] == 5 @pytest.mark.asyncio -async def test_custom_scalar_serialize_variables_no_schema(event_loop, aiohttp_server): +async def test_custom_scalar_serialize_variables_no_schema(aiohttp_server): transport = await make_money_transport(aiohttp_server) @@ -613,17 +602,15 @@ async def test_custom_scalar_serialize_variables_no_schema(event_loop, aiohttp_s query = gql("query myquery($money: Money) {toEuros(money: $money)}") - variable_values = {"money": Money(10, "DM")} + query.variable_values = {"money": Money(10, "DM")} with pytest.raises(TransportQueryError): - await session.execute( - query, variable_values=variable_values, serialize_variables=True - ) + await session.execute(query, serialize_variables=True) @pytest.mark.asyncio async def test_custom_scalar_serialize_variables_schema_from_introspection( - event_loop, aiohttp_server + aiohttp_server, ): transport = await make_money_transport(aiohttp_server) @@ -645,18 +632,16 @@ async def test_custom_scalar_serialize_variables_schema_from_introspection( query = gql("query myquery($money: Money) {toEuros(money: $money)}") - variable_values = {"money": Money(10, "DM")} + query.variable_values = {"money": Money(10, "DM")} - result = await session.execute( - query, variable_values=variable_values, serialize_variables=True - ) + result = await session.execute(query, serialize_variables=True) print(f"result = {result!r}") assert result["toEuros"] == 5 @pytest.mark.asyncio -async def test_update_schema_scalars(event_loop, aiohttp_server): +async def test_update_schema_scalars(aiohttp_server): transport = await make_money_transport(aiohttp_server) @@ -669,11 +654,9 @@ async def 
test_update_schema_scalars(event_loop, aiohttp_server): query = gql("query myquery($money: Money) {toEuros(money: $money)}") - variable_values = {"money": Money(10, "DM")} + query.variable_values = {"money": Money(10, "DM")} - result = await session.execute( - query, variable_values=variable_values, serialize_variables=True - ) + result = await session.execute(query, serialize_variables=True) print(f"result = {result!r}") assert result["toEuros"] == 5 @@ -682,14 +665,14 @@ async def test_update_schema_scalars(event_loop, aiohttp_server): def test_update_schema_scalars_invalid_scalar(): with pytest.raises(TypeError) as exc_info: - update_schema_scalars(schema, [int]) + update_schema_scalars(schema, [int]) # type: ignore exception = exc_info.value assert str(exception) == "Scalars should be instances of GraphQLScalarType." with pytest.raises(TypeError) as exc_info: - update_schema_scalar(schema, "test", int) + update_schema_scalar(schema, "test", int) # type: ignore exception = exc_info.value @@ -699,7 +682,7 @@ def test_update_schema_scalars_invalid_scalar(): def test_update_schema_scalars_invalid_scalar_argument(): with pytest.raises(TypeError) as exc_info: - update_schema_scalars(schema, MoneyScalar) + update_schema_scalars(schema, MoneyScalar) # type: ignore exception = exc_info.value @@ -735,7 +718,7 @@ def test_update_schema_scalars_scalar_type_is_not_a_scalar_in_schema(): @pytest.mark.asyncio @pytest.mark.requests async def test_custom_scalar_serialize_variables_sync_transport( - event_loop, aiohttp_server, run_sync_test + aiohttp_server, run_sync_test ): server, transport = await make_sync_money_transport(aiohttp_server) @@ -745,22 +728,20 @@ def test_code(): query = gql("query myquery($money: Money) {toEuros(money: $money)}") - variable_values = {"money": Money(10, "DM")} + query.variable_values = {"money": Money(10, "DM")} - result = session.execute( - query, variable_values=variable_values, serialize_variables=True - ) + result = session.execute(query, 
serialize_variables=True) print(f"result = {result!r}") assert result["toEuros"] == 5 - await run_sync_test(event_loop, server, test_code) + await run_sync_test(server, test_code) @pytest.mark.asyncio @pytest.mark.requests async def test_custom_scalar_serialize_variables_sync_transport_2( - event_loop, aiohttp_server, run_sync_test + aiohttp_server, run_sync_test ): server, transport = await make_sync_money_transport(aiohttp_server) @@ -769,12 +750,12 @@ def test_code(): query = gql("query myquery($money: Money) {toEuros(money: $money)}") - variable_values = {"money": Money(10, "DM")} + query.variable_values = {"money": Money(10, "DM")} results = session.execute_batch( [ - GraphQLRequest(document=query, variable_values=variable_values), - GraphQLRequest(document=query, variable_values=variable_values), + query, + query, ], serialize_variables=True, ) @@ -783,13 +764,39 @@ def test_code(): assert results[0]["toEuros"] == 5 assert results[1]["toEuros"] == 5 - await run_sync_test(event_loop, server, test_code) + await run_sync_test(server, test_code) + + +@pytest.mark.asyncio +@pytest.mark.aiohttp +async def test_custom_scalar_serialize_variables_async_transport(aiohttp_server): + transport = await make_money_transport(aiohttp_server) + + async with Client( + schema=schema, transport=transport, parse_results=True + ) as session: + + query = gql("query myquery($money: Money) {toEuros(money: $money)}") + + query.variable_values = {"money": Money(10, "DM")} + + results = await session.execute_batch( + [ + query, + query, + ], + serialize_variables=True, + ) + + print(f"result = {results!r}") + assert results[0]["toEuros"] == 5 + assert results[1]["toEuros"] == 5 def test_serialize_value_with_invalid_type(): with pytest.raises(GraphQLError) as exc_info: - serialize_value("Not a valid type", 50) + serialize_value("Not a valid type", 50) # type: ignore exception = exc_info.value @@ -818,7 +825,7 @@ def test_serialize_value_with_nullable_type(): @pytest.mark.asyncio -async 
def test_gql_cli_print_schema(event_loop, aiohttp_server, capsys): +async def test_gql_cli_print_schema(aiohttp_server, capsys): from gql.cli import get_parser, main diff --git a/tests/custom_scalars/test_parse_results.py b/tests/custom_scalars/test_parse_results.py index e3c6d6f6..32812818 100644 --- a/tests/custom_scalars/test_parse_results.py +++ b/tests/custom_scalars/test_parse_results.py @@ -93,6 +93,5 @@ def test_parse_results_null_mapping(): } }""" ) - assert client.execute(query, variable_values={"count": 2}) == { - "test": static_result - } + query.variable_values = {"count": 2} + assert client.execute(query) == {"test": static_result} diff --git a/tests/fixtures/aws/fake_signer.py b/tests/fixtures/aws/fake_signer.py index c0177a32..61e80fa0 100644 --- a/tests/fixtures/aws/fake_signer.py +++ b/tests/fixtures/aws/fake_signer.py @@ -12,10 +12,10 @@ def _fake_signer_factory(request=None): class FakeSigner: - def __init__(self, request=None) -> None: + def __init__(self, request=None): self.request = request - def add_auth(self, request) -> None: + def add_auth(self, request): """ A fake for getting a request object that :return: diff --git a/tests/regressions/issue_447_dsl_missing_directives/test_dsl_directives.py b/tests/regressions/issue_447_dsl_missing_directives/test_dsl_directives.py index b31ade7f..67c2e739 100644 --- a/tests/regressions/issue_447_dsl_missing_directives/test_dsl_directives.py +++ b/tests/regressions/issue_447_dsl_missing_directives/test_dsl_directives.py @@ -1,3 +1,5 @@ +from graphql import GraphQLSchema + from gql import Client, gql from gql.dsl import DSLFragment, DSLQuery, DSLSchema, dsl_gql, print_ast from gql.utilities import node_tree @@ -34,6 +36,9 @@ def test_issue_447(): client = Client(schema=schema_str) + + assert isinstance(client.schema, GraphQLSchema) + ds = DSLSchema(client.schema) sprite = DSLFragment("SpriteUnionAsSprite") @@ -60,10 +65,10 @@ def test_issue_447(): client.validate(q) # Creating a tree from the 
DocumentNode created by dsl_gql - dsl_tree = node_tree(q) + dsl_tree = node_tree(q.document) # Creating a tree from the DocumentNode created by gql - gql_tree = node_tree(gql(print_ast(q))) + gql_tree = node_tree(gql(print_ast(q.document)).document) print("=======") print(dsl_tree) diff --git a/tests/starwars/fixtures.py b/tests/starwars/fixtures.py index 59d7ddfa..1d179f60 100644 --- a/tests/starwars/fixtures.py +++ b/tests/starwars/fixtures.py @@ -148,9 +148,10 @@ def create_review(episode, review): async def make_starwars_backend(aiohttp_server): from aiohttp import web - from .schema import StarWarsSchema from graphql import graphql_sync + from .schema import StarWarsSchema + async def handler(request): data = await request.json() source = data["query"] diff --git a/tests/starwars/schema.py b/tests/starwars/schema.py index 4b672ad3..8f1efe99 100644 --- a/tests/starwars/schema.py +++ b/tests/starwars/schema.py @@ -1,4 +1,5 @@ import asyncio +from typing import cast from graphql import ( GraphQLArgument, @@ -14,6 +15,7 @@ GraphQLObjectType, GraphQLSchema, GraphQLString, + IntrospectionQuery, get_introspection_query, graphql_sync, print_schema, @@ -271,6 +273,8 @@ async def resolve_review(review, _info, **_args): ) -StarWarsIntrospection = graphql_sync(StarWarsSchema, get_introspection_query()).data +StarWarsIntrospection = cast( + IntrospectionQuery, graphql_sync(StarWarsSchema, get_introspection_query()).data +) StarWarsTypeDef = print_schema(StarWarsSchema) diff --git a/tests/starwars/test_dsl.py b/tests/starwars/test_dsl.py index 5cd051ba..e47a97d8 100644 --- a/tests/starwars/test_dsl.py +++ b/tests/starwars/test_dsl.py @@ -4,6 +4,7 @@ GraphQLError, GraphQLFloat, GraphQLID, + GraphQLInputObjectType, GraphQLInt, GraphQLList, GraphQLNonNull, @@ -53,6 +54,7 @@ def client(): def test_ast_from_value_with_input_type_and_not_mapping_value(): obj_type = StarWarsSchema.get_type("ReviewInput") + assert isinstance(obj_type, GraphQLInputObjectType) assert 
ast_from_value(8, obj_type) is None @@ -78,7 +80,7 @@ def test_ast_from_value_with_graphqlid(): def test_ast_from_value_with_invalid_type(): with pytest.raises(TypeError) as exc_info: - ast_from_value(4, None) + ast_from_value(4, None) # type: ignore assert "Unexpected input type: None." in str(exc_info.value) @@ -114,7 +116,10 @@ def test_ast_from_serialized_value_untyped_typeerror(): def test_variable_to_ast_type_passing_wrapping_type(): - wrapping_type = GraphQLNonNull(GraphQLList(StarWarsSchema.get_type("ReviewInput"))) + review_type = StarWarsSchema.get_type("ReviewInput") + assert isinstance(review_type, GraphQLInputObjectType) + + wrapping_type = GraphQLNonNull(GraphQLList(review_type)) variable = DSLVariable("review_input") ast = variable.to_ast_type(wrapping_type) assert ast == NonNullTypeNode( @@ -138,7 +143,7 @@ def test_use_variable_definition_multiple_times(ds): query = dsl_gql(op) assert ( - print_ast(query) + print_ast(query.document) == """mutation \ ($badReview: ReviewInput, $episode: Episode, $goodReview: ReviewInput) { badReview: createReview(review: $badReview, episode: $episode) { @@ -152,7 +157,9 @@ def test_use_variable_definition_multiple_times(ds): }""" ) - assert node_tree(query) == node_tree(gql(print_ast(query))) + assert node_tree(query.document) == node_tree( + gql(print_ast(query.document)).document + ) def test_add_variable_definitions(ds): @@ -166,7 +173,7 @@ def test_add_variable_definitions(ds): query = dsl_gql(op) assert ( - print_ast(query) + print_ast(query.document) == """mutation ($review: ReviewInput, $episode: Episode) { createReview(review: $review, episode: $episode) { stars @@ -175,7 +182,9 @@ def test_add_variable_definitions(ds): }""" ) - assert node_tree(query) == node_tree(gql(print_ast(query))) + assert node_tree(query.document) == node_tree( + gql(print_ast(query.document)).document + ) def test_add_variable_definitions_with_default_value_enum(ds): @@ -189,7 +198,7 @@ def 
test_add_variable_definitions_with_default_value_enum(ds): query = dsl_gql(op) assert ( - print_ast(query) + print_ast(query.document) == """mutation ($review: ReviewInput, $episode: Episode = NEWHOPE) { createReview(review: $review, episode: $episode) { stars @@ -211,7 +220,7 @@ def test_add_variable_definitions_with_default_value_input_object(ds): query = dsl_gql(op) assert ( - strip_braces_spaces(print_ast(query)) + strip_braces_spaces(print_ast(query.document)) == """ mutation ($review: ReviewInput = {stars: 5, commentary: "Wow!"}, $episode: Episode) { createReview(review: $review, episode: $episode) { @@ -221,7 +230,9 @@ def test_add_variable_definitions_with_default_value_input_object(ds): }""".strip() ) - assert node_tree(query) == node_tree(gql(print_ast(query))) + assert node_tree(query.document) == node_tree( + gql(print_ast(query.document)).document + ) def test_add_variable_definitions_in_input_object(ds): @@ -236,7 +247,7 @@ def test_add_variable_definitions_in_input_object(ds): query = dsl_gql(op) assert ( - strip_braces_spaces(print_ast(query)) + strip_braces_spaces(print_ast(query.document)) == """mutation ($stars: Int, $commentary: String, $episode: Episode) { createReview( review: {stars: $stars, commentary: $commentary} @@ -248,7 +259,9 @@ def test_add_variable_definitions_in_input_object(ds): }""" ) - assert node_tree(query) == node_tree(gql(print_ast(query))) + assert node_tree(query.document) == node_tree( + gql(print_ast(query.document)).document + ) def test_invalid_field_on_type_query(ds): @@ -383,7 +396,7 @@ def test_fetch_luke_aliased(ds): assert query == str(query_dsl) -def test_fetch_name_aliased(ds: DSLSchema): +def test_fetch_name_aliased(ds: DSLSchema) -> None: query = """ human(id: "1000") { my_name: name @@ -394,7 +407,7 @@ def test_fetch_name_aliased(ds: DSLSchema): assert query == str(query_dsl) -def test_fetch_name_aliased_as_kwargs(ds: DSLSchema): +def test_fetch_name_aliased_as_kwargs(ds: DSLSchema) -> None: query = """ 
human(id: "1000") { my_name: name @@ -411,7 +424,9 @@ def test_hero_name_query_result(ds, client): result = client.execute(query) expected = {"hero": {"name": "R2-D2"}} assert result == expected - assert node_tree(query) == node_tree(gql(print_ast(query))) + assert node_tree(query.document) == node_tree( + gql(print_ast(query.document)).document + ) def test_arg_serializer_list(ds, client): @@ -431,7 +446,9 @@ def test_arg_serializer_list(ds, client): ] } assert result == expected - assert node_tree(query) == node_tree(gql(print_ast(query))) + assert node_tree(query.document) == node_tree( + gql(print_ast(query.document)).document + ) def test_arg_serializer_enum(ds, client): @@ -439,7 +456,9 @@ def test_arg_serializer_enum(ds, client): result = client.execute(query) expected = {"hero": {"name": "Luke Skywalker"}} assert result == expected - assert node_tree(query) == node_tree(gql(print_ast(query))) + assert node_tree(query.document) == node_tree( + gql(print_ast(query.document)).document + ) def test_create_review_mutation_result(ds, client): @@ -454,7 +473,9 @@ def test_create_review_mutation_result(ds, client): result = client.execute(query) expected = {"createReview": {"stars": 5, "commentary": "This is a great movie!"}} assert result == expected - assert node_tree(query) == node_tree(gql(print_ast(query))) + assert node_tree(query.document) == node_tree( + gql(print_ast(query.document)).document + ) def test_subscription(ds): @@ -467,7 +488,7 @@ def test_subscription(ds): ) ) assert ( - print_ast(query) + print_ast(query.document) == """subscription { reviewAdded(episode: JEDI) { stars @@ -476,7 +497,9 @@ def test_subscription(ds): }""" ) - assert node_tree(query) == node_tree(gql(print_ast(query))) + assert node_tree(query.document) == node_tree( + gql(print_ast(query.document)).document + ) def test_field_does_not_exit_in_type(ds): @@ -517,7 +540,9 @@ def test_multiple_root_fields(ds, client): "hero_of_episode_5": {"name": "Luke Skywalker"}, } assert result 
== expected - assert node_tree(query) == node_tree(gql(print_ast(query))) + assert node_tree(query.document) == node_tree( + gql(print_ast(query.document)).document + ) def test_root_fields_aliased(ds, client): @@ -533,7 +558,9 @@ def test_root_fields_aliased(ds, client): "hero_of_episode_5": {"name": "Luke Skywalker"}, } assert result == expected - assert node_tree(query) == node_tree(gql(print_ast(query))) + assert node_tree(query.document) == node_tree( + gql(print_ast(query.document)).document + ) def test_operation_name(ds): @@ -544,7 +571,7 @@ def test_operation_name(ds): ) assert ( - print_ast(query) + print_ast(query.document) == """query GetHeroName { hero { name @@ -552,7 +579,9 @@ def test_operation_name(ds): }""" ) - assert node_tree(query) == node_tree(gql(print_ast(query))) + assert node_tree(query.document) == node_tree( + gql(print_ast(query.document)).document + ) def test_multiple_operations(ds): @@ -566,7 +595,7 @@ def test_multiple_operations(ds): ) assert ( - strip_braces_spaces(print_ast(query)) + strip_braces_spaces(print_ast(query.document)) == """query GetHeroName { hero { name @@ -584,7 +613,9 @@ def test_multiple_operations(ds): }""" ) - assert node_tree(query) == node_tree(gql(print_ast(query))) + assert node_tree(query.document) == node_tree( + gql(print_ast(query.document)).document + ) def test_inline_fragments(ds): @@ -651,12 +682,14 @@ def test_fragments(ds): query_dsl = DSLQuery(ds.Query.hero.select(name_and_appearances)) - document = dsl_gql(name_and_appearances, query_dsl) + request = dsl_gql(name_and_appearances, query_dsl) + + document = request.document print(print_ast(document)) assert query == print_ast(document) - assert node_tree(document) == node_tree(gql(print_ast(document))) + assert node_tree(document) == node_tree(gql(print_ast(document)).document) def test_fragment_without_type_condition_error(ds): @@ -748,12 +781,14 @@ def test_dsl_nested_query_with_fragment(ds): ) ) - document = dsl_gql(name_and_appearances, 
NestedQueryWithFragment=query_dsl) + request = dsl_gql(name_and_appearances, NestedQueryWithFragment=query_dsl) + + document = request.document print(print_ast(document)) assert query == print_ast(document) - assert node_tree(document) == node_tree(gql(print_ast(document))) + assert node_tree(document) == node_tree(gql(print_ast(document)).document) # Same thing, but incrementaly @@ -774,12 +809,14 @@ def test_dsl_nested_query_with_fragment(ds): query_dsl = DSLQuery(hero) - document = dsl_gql(name_and_appearances, NestedQueryWithFragment=query_dsl) + request = dsl_gql(name_and_appearances, NestedQueryWithFragment=query_dsl) + + document = request.document print(print_ast(document)) assert query == print_ast(document) - assert node_tree(document) == node_tree(gql(print_ast(document))) + assert node_tree(document) == node_tree(gql(print_ast(document)).document) def test_dsl_query_all_fields_should_be_instances_of_DSLField(): @@ -787,7 +824,7 @@ def test_dsl_query_all_fields_should_be_instances_of_DSLField(): TypeError, match="Fields should be instances of DSLSelectable. 
Received: ", ): - DSLQuery("I am a string") + DSLQuery("I am a string") # type: ignore def test_dsl_query_all_fields_should_correspond_to_the_root_type(ds): @@ -823,7 +860,7 @@ def test_dsl_root_type_not_default(): version } """ - assert print_ast(query) == expected_query.strip() + assert print_ast(query.document) == expected_query.strip() with pytest.raises(GraphQLError) as excinfo: DSLSubscription(ds.QueryNotDefault.version) @@ -832,14 +869,16 @@ def test_dsl_root_type_not_default(): "Invalid field for : " ) in str(excinfo.value) - assert node_tree(query) == node_tree(gql(print_ast(query))) + assert node_tree(query.document) == node_tree( + gql(print_ast(query.document)).document + ) def test_dsl_gql_all_arguments_should_be_operations_or_fragments(): with pytest.raises( TypeError, match="Operations should be instances of DSLExecutable " ): - dsl_gql("I am a string") + dsl_gql("I am a string") # type: ignore def test_DSLSchema_requires_a_schema(client): @@ -920,7 +959,7 @@ def test_type_hero_query(ds): ) query_dsl = DSLQuery(type_hero) - assert query == str(print_ast(dsl_gql(query_dsl))).strip() + assert query == str(print_ast(dsl_gql(query_dsl).document)).strip() def test_invalid_meta_field_selection(ds): @@ -995,9 +1034,11 @@ def test_get_introspection_query_ast(option): ) try: - assert print_ast(gql(introspection_query)) == print_ast(dsl_introspection_query) + assert print_ast(gql(introspection_query).document) == print_ast( + dsl_introspection_query + ) assert node_tree(dsl_introspection_query) == node_tree( - gql(print_ast(dsl_introspection_query)) + gql(print_ast(dsl_introspection_query)).document ) except AssertionError: @@ -1010,9 +1051,11 @@ def test_get_introspection_query_ast(option): input_value_deprecation=option, type_recursion_level=9, ) - assert print_ast(gql(introspection_query)) == print_ast(dsl_introspection_query) + assert print_ast(gql(introspection_query).document) == print_ast( + dsl_introspection_query + ) assert 
node_tree(dsl_introspection_query) == node_tree( - gql(print_ast(dsl_introspection_query)) + gql(print_ast(dsl_introspection_query)).document ) @@ -1042,7 +1085,7 @@ def test_node_tree_with_loc(ds): } }""".strip() - document = gql(query) + document = gql(query).document node_tree_result = """ DocumentNode @@ -1227,4 +1270,4 @@ def test_legacy_fragment_with_variables(ds): } } """.strip() - assert print_ast(query) == expected + assert print_ast(query.document) == expected diff --git a/tests/starwars/test_introspection.py b/tests/starwars/test_introspection.py index c3063808..9e5ff4aa 100644 --- a/tests/starwars/test_introspection.py +++ b/tests/starwars/test_introspection.py @@ -10,7 +10,7 @@ @pytest.mark.asyncio -async def test_starwars_introspection_args(event_loop, aiohttp_server): +async def test_starwars_introspection_args(aiohttp_server): transport = await make_starwars_transport(aiohttp_server) @@ -19,6 +19,9 @@ async def test_starwars_introspection_args(event_loop, aiohttp_server): async with Client( transport=transport, fetch_schema_from_transport=True, + introspection_args={ + "input_value_deprecation": False, + }, ) as session: schema_str = print_schema(session.client.schema) @@ -35,6 +38,7 @@ async def test_starwars_introspection_args(event_loop, aiohttp_server): fetch_schema_from_transport=True, introspection_args={ "descriptions": False, + "input_value_deprecation": False, }, ) as session: @@ -50,9 +54,6 @@ async def test_starwars_introspection_args(event_loop, aiohttp_server): async with Client( transport=transport, fetch_schema_from_transport=True, - introspection_args={ - "input_value_deprecation": True, - }, ) as session: schema_str = print_schema(session.client.schema) diff --git a/tests/starwars/test_parse_results.py b/tests/starwars/test_parse_results.py index e8f3f8d4..2ae94ea8 100644 --- a/tests/starwars/test_parse_results.py +++ b/tests/starwars/test_parse_results.py @@ -1,3 +1,5 @@ +from typing import Any, Dict + import pytest from graphql 
import GraphQLError @@ -20,6 +22,7 @@ def test_hero_name_and_friends_query(): } """ ) + result = { "hero": { "id": "2001", @@ -32,7 +35,7 @@ def test_hero_name_and_friends_query(): } } - parsed_result = parse_result(StarWarsSchema, query, result) + parsed_result = parse_result(StarWarsSchema, query.document, result) assert result == parsed_result @@ -56,6 +59,7 @@ def test_hero_name_and_friends_query_with_fragment(): } """ ) + result = { "hero": { "id": "2001", @@ -68,7 +72,7 @@ def test_hero_name_and_friends_query_with_fragment(): } } - parsed_result = parse_result(StarWarsSchema, query, result) + parsed_result = parse_result(StarWarsSchema, query.document, result) assert result == parsed_result @@ -87,9 +91,9 @@ def test_key_not_found_in_result(): # Backend returned an invalid result without the hero key # Should be impossible. In that case, we ignore the missing key - result = {} + result: Dict[str, Any] = {} - parsed_result = parse_result(StarWarsSchema, query, result) + parsed_result = parse_result(StarWarsSchema, query.document, result) assert result == parsed_result @@ -110,7 +114,7 @@ def test_invalid_result_raise_error(): with pytest.raises(GraphQLError) as exc_info: - parse_result(StarWarsSchema, query, result) + parse_result(StarWarsSchema, query.document, result) assert "Invalid result for container of field id: 5" in str(exc_info) @@ -139,7 +143,7 @@ def test_fragment(): "leia": {"name": "Leia Organa", "homePlanet": "Alderaan"}, } - parsed_result = parse_result(StarWarsSchema, query, result) + parsed_result = parse_result(StarWarsSchema, query.document, result) assert result == parsed_result @@ -162,7 +166,7 @@ def test_fragment_not_found(): with pytest.raises(GraphQLError) as exc_info: - parse_result(StarWarsSchema, query, result) + parse_result(StarWarsSchema, query.document, result) assert 'Fragment "HumanFragment" not found in document!' 
in str(exc_info) @@ -181,7 +185,7 @@ def test_return_none_if_result_is_none(): result = None - assert parse_result(StarWarsSchema, query, result) is None + assert parse_result(StarWarsSchema, query.document, result) is None def test_null_result_is_allowed(): @@ -198,7 +202,7 @@ def test_null_result_is_allowed(): result = {"hero": None} - parsed_result = parse_result(StarWarsSchema, query, result) + parsed_result = parse_result(StarWarsSchema, query.document, result) assert result == parsed_result @@ -222,6 +226,6 @@ def test_inline_fragment(): "luke": {"name": "Luke Skywalker", "homePlanet": "Tatooine"}, } - parsed_result = parse_result(StarWarsSchema, query, result) + parsed_result = parse_result(StarWarsSchema, query.document, result) assert result == parsed_result diff --git a/tests/starwars/test_query.py b/tests/starwars/test_query.py index bf15e11a..ff2af7d7 100644 --- a/tests/starwars/test_query.py +++ b/tests/starwars/test_query.py @@ -1,5 +1,5 @@ import pytest -from graphql import GraphQLError, Source +from graphql import GraphQLError from gql import Client, gql from tests.starwars.schema import StarWarsSchema @@ -136,11 +136,11 @@ def test_fetch_some_id_query(client): } """ ) - params = { + query.variable_values = { "someId": "1000", } expected = {"human": {"name": "Luke Skywalker"}} - result = client.execute(query, variable_values=params) + result = client.execute(query) assert result == expected @@ -154,11 +154,11 @@ def test_fetch_some_id_query2(client): } """ ) - params = { + query.variable_values = { "someId": "1002", } expected = {"human": {"name": "Han Solo"}} - result = client.execute(query, variable_values=params) + result = client.execute(query) assert result == expected @@ -172,11 +172,11 @@ def test_invalid_id_query(client): } """ ) - params = { + query.variable_values = { "id": "not a valid id", } expected = {"human": None} - result = client.execute(query, variable_values=params) + result = client.execute(query) assert result == expected @@ 
-316,24 +316,10 @@ def test_mutation_result(client): } """ ) - params = { + query.variable_values = { "ep": "JEDI", "review": {"stars": 5, "commentary": "This is a great movie!"}, } expected = {"createReview": {"stars": 5, "commentary": "This is a great movie!"}} - result = client.execute(query, variable_values=params) - assert result == expected - - -def test_query_from_source(client): - source = Source("{ hero { name } }") - query = gql(source) - expected = {"hero": {"name": "R2-D2"}} result = client.execute(query) assert result == expected - - -def test_already_parsed_query(client): - query = gql("{ hero { name } }") - with pytest.raises(TypeError, match="must be passed as a string"): - gql(query) diff --git a/tests/starwars/test_subscription.py b/tests/starwars/test_subscription.py index 0f412acc..4f5f425b 100644 --- a/tests/starwars/test_subscription.py +++ b/tests/starwars/test_subscription.py @@ -41,7 +41,7 @@ async def test_subscription_support(): expected = [{**review, "episode": "JEDI"} for review in reviews[6]] ai = await await_if_coroutine( - subscribe(StarWarsSchema, subs, variable_values=params) + subscribe(StarWarsSchema, subs.document, variable_values=params) ) result = [result.data["reviewAdded"] async for result in ai] @@ -59,14 +59,14 @@ async def test_subscription_support_using_client(): subs = gql(subscription_str) - params = {"ep": "JEDI"} + subs.variable_values = {"ep": "JEDI"} expected = [{**review, "episode": "JEDI"} for review in reviews[6]] async with Client(schema=StarWarsSchema) as session: results = [ result["reviewAdded"] async for result in await await_if_coroutine( - session.subscribe(subs, variable_values=params, parse_result=False) + session.subscribe(subs, parse_result=False) ) ] @@ -85,7 +85,7 @@ async def test_subscription_support_using_client_invalid_field(): subs = gql(subscription_invalid_str) - params = {"ep": "JEDI"} + subs.variable_values = {"ep": "JEDI"} async with Client(schema=StarWarsSchema) as session: @@ -93,7 +93,7 
@@ async def test_subscription_support_using_client_invalid_field(): results = [ result async for result in await await_if_coroutine( - session.transport.subscribe(subs, variable_values=params) + session.transport.subscribe(subs) ) ] diff --git a/tests/starwars/test_validation.py b/tests/starwars/test_validation.py index 1ca8a2bb..75ce4162 100644 --- a/tests/starwars/test_validation.py +++ b/tests/starwars/test_validation.py @@ -1,3 +1,5 @@ +import copy + import pytest from gql import Client, gql @@ -62,7 +64,8 @@ def introspection_schema(): @pytest.fixture def introspection_schema_empty_directives(): - introspection = StarWarsIntrospection + # Create a deep copy to avoid modifying the original + introspection = copy.deepcopy(StarWarsIntrospection) # Simulate an empty dictionary for directives introspection["__schema"]["directives"] = [] @@ -72,10 +75,11 @@ def introspection_schema_empty_directives(): @pytest.fixture def introspection_schema_no_directives(): - introspection = StarWarsIntrospection + # Create a deep copy to avoid modifying the original + introspection = copy.deepcopy(StarWarsIntrospection) # Simulate no directives key - del introspection["__schema"]["directives"] + del introspection["__schema"]["directives"] # type: ignore return Client(introspection=introspection) @@ -104,7 +108,7 @@ def validation_errors(client, query): def test_incompatible_request_gql(client): with pytest.raises(TypeError): - gql(123) + gql(123) # type: ignore """ The error generated depends on graphql-core version @@ -249,7 +253,7 @@ def test_build_client_schema_invalid_introspection(): from gql.utilities import build_client_schema with pytest.raises(TypeError) as exc_info: - build_client_schema("blah") + build_client_schema("blah") # type: ignore assert ( "Invalid or incomplete introspection result. 
Ensure that you are passing the " diff --git a/tests/test_aiohttp.py b/tests/test_aiohttp.py index 81af20ff..506b04f4 100644 --- a/tests/test_aiohttp.py +++ b/tests/test_aiohttp.py @@ -1,14 +1,17 @@ import io import json +import os +import warnings from typing import Mapping import pytest -from gql import Client, gql +from gql import Client, FileVar, gql from gql.cli import get_parser, main from gql.transport.exceptions import ( TransportAlreadyConnected, TransportClosed, + TransportConnectionFailed, TransportProtocolError, TransportQueryError, TransportServerError, @@ -17,7 +20,7 @@ from .conftest import ( TemporaryFile, get_localhost_ssl_context_client, - strip_braces_spaces, + make_upload_handler, ) query1_str = """ @@ -45,8 +48,9 @@ @pytest.mark.asyncio -async def test_aiohttp_query(event_loop, aiohttp_server): +async def test_aiohttp_query(aiohttp_server): from aiohttp import web + from gql.transport.aiohttp import AIOHTTPTransport async def handler(request): @@ -84,8 +88,9 @@ async def handler(request): @pytest.mark.asyncio -async def test_aiohttp_ignore_backend_content_type(event_loop, aiohttp_server): +async def test_aiohttp_ignore_backend_content_type(aiohttp_server): from aiohttp import web + from gql.transport.aiohttp import AIOHTTPTransport async def handler(request): @@ -113,8 +118,9 @@ async def handler(request): @pytest.mark.asyncio -async def test_aiohttp_cookies(event_loop, aiohttp_server): +async def test_aiohttp_cookies(aiohttp_server): from aiohttp import web + from gql.transport.aiohttp import AIOHTTPTransport async def handler(request): @@ -146,8 +152,9 @@ async def handler(request): @pytest.mark.asyncio -async def test_aiohttp_error_code_401(event_loop, aiohttp_server): +async def test_aiohttp_error_code_401(aiohttp_server): from aiohttp import web + from gql.transport.aiohttp import AIOHTTPTransport async def handler(request): @@ -177,8 +184,9 @@ async def handler(request): @pytest.mark.asyncio -async def 
test_aiohttp_error_code_429(event_loop, aiohttp_server): +async def test_aiohttp_error_code_429(aiohttp_server): from aiohttp import web + from gql.transport.aiohttp import AIOHTTPTransport async def handler(request): @@ -224,8 +232,9 @@ async def handler(request): @pytest.mark.asyncio -async def test_aiohttp_error_code_500(event_loop, aiohttp_server): +async def test_aiohttp_error_code_500(aiohttp_server): from aiohttp import web + from gql.transport.aiohttp import AIOHTTPTransport async def handler(request): @@ -259,8 +268,9 @@ async def handler(request): @pytest.mark.asyncio @pytest.mark.parametrize("query_error", transport_query_error_responses) -async def test_aiohttp_error_code(event_loop, aiohttp_server, query_error): +async def test_aiohttp_error_code(aiohttp_server, query_error): from aiohttp import web + from gql.transport.aiohttp import AIOHTTPTransport async def handler(request): @@ -286,27 +296,28 @@ async def handler(request): { "response": "{}", "expected_exception": ( - "Server did not return a GraphQL result: " + "Server did not return a valid GraphQL result: " 'No "data" or "errors" keys in answer: {}' ), }, { "response": "qlsjfqsdlkj", "expected_exception": ( - "Server did not return a GraphQL result: Not a JSON answer: qlsjfqsdlkj" + "Server did not return a valid GraphQL result: " + "Not a JSON answer: qlsjfqsdlkj" ), }, { "response": '{"not_data_or_errors": 35}', "expected_exception": ( - "Server did not return a GraphQL result: " + "Server did not return a valid GraphQL result: " 'No "data" or "errors" keys in answer: {"not_data_or_errors": 35}' ), }, { "response": "", "expected_exception": ( - "Server did not return a GraphQL result: Not a JSON answer: " + "Server did not return a valid GraphQL result: Not a JSON answer: " ), }, ] @@ -314,8 +325,9 @@ async def handler(request): @pytest.mark.asyncio @pytest.mark.parametrize("param", invalid_protocol_responses) -async def test_aiohttp_invalid_protocol(event_loop, aiohttp_server, param): +async 
def test_aiohttp_invalid_protocol(aiohttp_server, param): from aiohttp import web + from gql.transport.aiohttp import AIOHTTPTransport response = param["response"] @@ -342,8 +354,9 @@ async def handler(request): @pytest.mark.asyncio -async def test_aiohttp_subscribe_not_supported(event_loop, aiohttp_server): +async def test_aiohttp_subscribe_not_supported(aiohttp_server): from aiohttp import web + from gql.transport.aiohttp import AIOHTTPTransport async def handler(request): @@ -367,8 +380,9 @@ async def handler(request): @pytest.mark.asyncio -async def test_aiohttp_cannot_connect_twice(event_loop, aiohttp_server): +async def test_aiohttp_cannot_connect_twice(aiohttp_server): from aiohttp import web + from gql.transport.aiohttp import AIOHTTPTransport async def handler(request): @@ -389,8 +403,9 @@ async def handler(request): @pytest.mark.asyncio -async def test_aiohttp_cannot_execute_if_not_connected(event_loop, aiohttp_server): +async def test_aiohttp_cannot_execute_if_not_connected(aiohttp_server): from aiohttp import web + from gql.transport.aiohttp import AIOHTTPTransport async def handler(request): @@ -411,8 +426,9 @@ async def handler(request): @pytest.mark.asyncio -async def test_aiohttp_extra_args(event_loop, aiohttp_server): +async def test_aiohttp_extra_args(aiohttp_server): from aiohttp import web + from gql.transport.aiohttp import AIOHTTPTransport async def handler(request): @@ -458,8 +474,9 @@ async def handler(request): @pytest.mark.asyncio -async def test_aiohttp_query_variable_values(event_loop, aiohttp_server): +async def test_aiohttp_query_variable_values(aiohttp_server): from aiohttp import web + from gql.transport.aiohttp import AIOHTTPTransport async def handler(request): @@ -475,14 +492,13 @@ async def handler(request): async with Client(transport=transport) as session: - params = {"code": "EU"} - query = gql(query2_str) + query.variable_values = {"code": "EU"} + query.operation_name = "getEurope" + # Execute query asynchronously - result = 
await session.execute( - query, variable_values=params, operation_name="getEurope" - ) + result = await session.execute(query) continent = result["continent"] @@ -490,12 +506,13 @@ async def handler(request): @pytest.mark.asyncio -async def test_aiohttp_query_variable_values_fix_issue_292(event_loop, aiohttp_server): +async def test_aiohttp_query_variable_values_fix_issue_292(aiohttp_server): """Allow to specify variable_values without keyword. See https://github.com/graphql-python/gql/issues/292""" from aiohttp import web + from gql.transport.aiohttp import AIOHTTPTransport async def handler(request): @@ -511,12 +528,13 @@ async def handler(request): async with Client(transport=transport) as session: - params = {"code": "EU"} - query = gql(query2_str) + query.variable_values = {"code": "EU"} + query.operation_name = "getEurope" + # Execute query asynchronously - result = await session.execute(query, params, operation_name="getEurope") + result = await session.execute(query) continent = result["continent"] @@ -524,10 +542,9 @@ async def handler(request): @pytest.mark.asyncio -async def test_aiohttp_execute_running_in_thread( - event_loop, aiohttp_server, run_sync_test -): +async def test_aiohttp_execute_running_in_thread(aiohttp_server, run_sync_test): from aiohttp import web + from gql.transport.aiohttp import AIOHTTPTransport async def handler(request): @@ -548,14 +565,13 @@ def test_code(): client.execute(query) - await run_sync_test(event_loop, server, test_code) + await run_sync_test(server, test_code) @pytest.mark.asyncio -async def test_aiohttp_subscribe_running_in_thread( - event_loop, aiohttp_server, run_sync_test -): +async def test_aiohttp_subscribe_running_in_thread(aiohttp_server, run_sync_test): from aiohttp import web + from gql.transport.aiohttp import AIOHTTPTransport async def handler(request): @@ -585,10 +601,8 @@ def test_code(): for result in client.subscribe(query): pass - await run_sync_test(event_loop, server, test_code) - + await 
run_sync_test(server, test_code) -file_upload_server_answer = '{"data":{"success":true}}' file_upload_mutation_1 = """ mutation($file: Upload!) { @@ -612,40 +626,22 @@ def test_code(): """ -async def single_upload_handler(request): - - from aiohttp import web - - reader = await request.multipart() - - field_0 = await reader.next() - assert field_0.name == "operations" - field_0_text = await field_0.text() - assert strip_braces_spaces(field_0_text) == file_upload_mutation_1_operations - - field_1 = await reader.next() - assert field_1.name == "map" - field_1_text = await field_1.text() - assert field_1_text == file_upload_mutation_1_map - - field_2 = await reader.next() - assert field_2.name == "0" - field_2_text = await field_2.text() - assert field_2_text == file_1_content - - field_3 = await reader.next() - assert field_3 is None - - return web.Response(text=file_upload_server_answer, content_type="application/json") - - @pytest.mark.asyncio -async def test_aiohttp_file_upload(event_loop, aiohttp_server): +async def test_aiohttp_file_upload(aiohttp_server): from aiohttp import web + from gql.transport.aiohttp import AIOHTTPTransport app = web.Application() - app.router.add_route("POST", "/", single_upload_handler) + app.router.add_route( + "POST", + "/", + make_upload_handler( + expected_map=file_upload_mutation_1_map, + expected_operations=file_upload_mutation_1_operations, + expected_contents=[file_1_content], + ), + ) server = await aiohttp_server(app) url = server.make_url("/") @@ -660,57 +656,59 @@ async def test_aiohttp_file_upload(event_loop, aiohttp_server): file_path = test_file.filename + # Using an opened file with open(file_path, "rb") as f: - params = {"file": f, "other_var": 42} + query.variable_values = {"file": f, "other_var": 42} # Execute query asynchronously - result = await session.execute( - query, variable_values=params, upload_files=True - ) + with pytest.warns( + DeprecationWarning, + match="Not using FileVar for file upload is 
deprecated", + ): + result = await session.execute(query, upload_files=True) success = result["success"] - assert success + # Using an opened file inside a FileVar object + with open(file_path, "rb") as f: -async def single_upload_handler_with_content_type(request): - - from aiohttp import web - - reader = await request.multipart() - - field_0 = await reader.next() - assert field_0.name == "operations" - field_0_text = await field_0.text() - assert strip_braces_spaces(field_0_text) == file_upload_mutation_1_operations + query.variable_values = {"file": FileVar(f), "other_var": 42} - field_1 = await reader.next() - assert field_1.name == "map" - field_1_text = await field_1.text() - assert field_1_text == file_upload_mutation_1_map + with warnings.catch_warnings(): + warnings.simplefilter("error") # Turn warnings into errors + result = await session.execute(query, upload_files=True) - field_2 = await reader.next() - assert field_2.name == "0" - field_2_text = await field_2.text() - assert field_2_text == file_1_content + success = result["success"] + assert success - # Verifying the content_type - assert field_2.headers["Content-Type"] == "application/pdf" + # Using an filename string inside a FileVar object + query.variable_values = {"file": FileVar(file_path), "other_var": 42} - field_3 = await reader.next() - assert field_3 is None + result = await session.execute(query, upload_files=True) - return web.Response(text=file_upload_server_answer, content_type="application/json") + success = result["success"] + assert success @pytest.mark.asyncio -async def test_aiohttp_file_upload_with_content_type(event_loop, aiohttp_server): +async def test_aiohttp_file_upload_with_content_type(aiohttp_server): from aiohttp import web + from gql.transport.aiohttp import AIOHTTPTransport app = web.Application() - app.router.add_route("POST", "/", single_upload_handler_with_content_type) + app.router.add_route( + "POST", + "/", + make_upload_handler( + file_headers=[{"Content-Type": 
"application/pdf"}], + expected_map=file_upload_mutation_1_map, + expected_operations=file_upload_mutation_1_operations, + expected_contents=[file_1_content], + ), + ) server = await aiohttp_server(app) url = server.make_url("/") @@ -725,100 +723,205 @@ async def test_aiohttp_file_upload_with_content_type(event_loop, aiohttp_server) file_path = test_file.filename + # Using an opened file with open(file_path, "rb") as f: # Setting the content_type - f.content_type = "application/pdf" + f.content_type = "application/pdf" # type: ignore - params = {"file": f, "other_var": 42} + query.variable_values = {"file": f, "other_var": 42} - # Execute query asynchronously - result = await session.execute( - query, variable_values=params, upload_files=True - ) + with pytest.warns( + DeprecationWarning, + match="Not using FileVar for file upload is deprecated", + ): + result = await session.execute(query, upload_files=True) success = result["success"] + assert success + + # Using an opened file inside a FileVar object + with open(file_path, "rb") as f: + + query.variable_values = { + "file": FileVar( + f, + content_type="application/pdf", + ), + "other_var": 42, + } + + result = await session.execute(query, upload_files=True) + + success = result["success"] + assert success + + # Using an filename string inside a FileVar object + query.variable_values = { + "file": FileVar( + file_path, + content_type="application/pdf", + ), + "other_var": 42, + } + result = await session.execute(query, upload_files=True) + + success = result["success"] assert success @pytest.mark.asyncio -async def test_aiohttp_file_upload_without_session( - event_loop, aiohttp_server, run_sync_test -): +async def test_aiohttp_file_upload_default_filename_is_basename(aiohttp_server): from aiohttp import web + from gql.transport.aiohttp import AIOHTTPTransport app = web.Application() - app.router.add_route("POST", "/", single_upload_handler) - server = await aiohttp_server(app) - url = server.make_url("/") + with 
TemporaryFile(file_1_content) as test_file: + file_path = test_file.filename + file_basename = os.path.basename(file_path) + + app.router.add_route( + "POST", + "/", + make_upload_handler( + filenames=[file_basename], + expected_map=file_upload_mutation_1_map, + expected_operations=file_upload_mutation_1_operations, + expected_contents=[file_1_content], + ), + ) + server = await aiohttp_server(app) - def test_code(): - transport = AIOHTTPTransport(url=url, timeout=10) + url = server.make_url("/") - with TemporaryFile(file_1_content) as test_file: + transport = AIOHTTPTransport(url=url, timeout=10) - client = Client(transport=transport) + async with Client(transport=transport) as session: query = gql(file_upload_mutation_1) - file_path = test_file.filename + query.variable_values = { + "file": FileVar( + file_path, + ), + "other_var": 42, + } - with open(file_path, "rb") as f: + result = await session.execute(query, upload_files=True) - params = {"file": f, "other_var": 42} + success = result["success"] + assert success - result = client.execute( - query, variable_values=params, upload_files=True - ) - success = result["success"] +@pytest.mark.asyncio +async def test_aiohttp_file_upload_with_filename(aiohttp_server): + from aiohttp import web - assert success + from gql.transport.aiohttp import AIOHTTPTransport + + app = web.Application() - await run_sync_test(event_loop, server, test_code) + with TemporaryFile(file_1_content) as test_file: + file_path = test_file.filename + + app.router.add_route( + "POST", + "/", + make_upload_handler( + filenames=["filename1.txt"], + expected_map=file_upload_mutation_1_map, + expected_operations=file_upload_mutation_1_operations, + expected_contents=[file_1_content], + ), + ) + server = await aiohttp_server(app) + url = server.make_url("/") -# This is a sample binary file content containing all possible byte values -binary_file_content = bytes(range(0, 256)) + transport = AIOHTTPTransport(url=url, timeout=10) + + async with 
Client(transport=transport) as session: + + query = gql(file_upload_mutation_1) + query.variable_values = { + "file": FileVar( + file_path, + filename="filename1.txt", + ), + "other_var": 42, + } -async def binary_upload_handler(request): + result = await session.execute(query, upload_files=True) + success = result["success"] + assert success + + +@pytest.mark.asyncio +async def test_aiohttp_file_upload_without_session(aiohttp_server, run_sync_test): from aiohttp import web - reader = await request.multipart() + from gql.transport.aiohttp import AIOHTTPTransport + + app = web.Application() + app.router.add_route( + "POST", + "/", + make_upload_handler( + expected_map=file_upload_mutation_1_map, + expected_operations=file_upload_mutation_1_operations, + expected_contents=[file_1_content], + ), + ) + server = await aiohttp_server(app) + + url = server.make_url("/") - field_0 = await reader.next() - assert field_0.name == "operations" - field_0_text = await field_0.text() - assert strip_braces_spaces(field_0_text) == file_upload_mutation_1_operations + def test_code(): + transport = AIOHTTPTransport(url=url, timeout=10) - field_1 = await reader.next() - assert field_1.name == "map" - field_1_text = await field_1.text() - assert field_1_text == file_upload_mutation_1_map + with TemporaryFile(file_1_content) as test_file: - field_2 = await reader.next() - assert field_2.name == "0" - field_2_binary = await field_2.read() - assert field_2_binary == binary_file_content + client = Client(transport=transport) - field_3 = await reader.next() - assert field_3 is None + query = gql(file_upload_mutation_1) - return web.Response(text=file_upload_server_answer, content_type="application/json") + file_path = test_file.filename + + query.variable_values = {"file": FileVar(file_path), "other_var": 42} + + result = client.execute(query, upload_files=True) + + success = result["success"] + assert success + + await run_sync_test(server, test_code) @pytest.mark.asyncio -async def 
test_aiohttp_binary_file_upload(event_loop, aiohttp_server): +async def test_aiohttp_binary_file_upload(aiohttp_server): from aiohttp import web + from gql.transport.aiohttp import AIOHTTPTransport + # This is a sample binary file content containing all possible byte values + binary_file_content = bytes(range(0, 256)) + app = web.Application() - app.router.add_route("POST", "/", binary_upload_handler) + app.router.add_route( + "POST", + "/", + make_upload_handler( + binary=True, + expected_contents=[binary_file_content], + expected_map=file_upload_mutation_1_map, + expected_operations=file_upload_mutation_1_operations, + ), + ) server = await aiohttp_server(app) url = server.make_url("/") @@ -833,14 +936,10 @@ async def test_aiohttp_binary_file_upload(event_loop, aiohttp_server): file_path = test_file.filename - with open(file_path, "rb") as f: - - params = {"file": f, "other_var": 42} + query.variable_values = {"file": FileVar(file_path), "other_var": 42} - # Execute query asynchronously - result = await session.execute( - query, variable_values=params, upload_files=True - ) + # Execute query asynchronously + result = await session.execute(query, upload_files=True) success = result["success"] @@ -848,17 +947,30 @@ async def test_aiohttp_binary_file_upload(event_loop, aiohttp_server): @pytest.mark.asyncio -async def test_aiohttp_stream_reader_upload(event_loop, aiohttp_server): - from aiohttp import web, ClientSession +async def test_aiohttp_stream_reader_upload(aiohttp_server): + from aiohttp import ClientSession, web + from gql.transport.aiohttp import AIOHTTPTransport + # This is a sample binary file content containing all possible byte values + binary_file_content = bytes(range(0, 256)) + async def binary_data_handler(request): return web.Response( body=binary_file_content, content_type="binary/octet-stream" ) app = web.Application() - app.router.add_route("POST", "/", binary_upload_handler) + app.router.add_route( + "POST", + "/", + make_upload_handler( + 
binary=True, + expected_contents=[binary_file_content], + expected_map=file_upload_mutation_1_map, + expected_operations=file_upload_mutation_1_operations, + ), + ) app.router.add_route("GET", "/binary_data", binary_data_handler) server = await aiohttp_server(app) @@ -868,123 +980,156 @@ async def binary_data_handler(request): transport = AIOHTTPTransport(url=url, timeout=10) + # Not using FileVar async with Client(transport=transport) as session: query = gql(file_upload_mutation_1) async with ClientSession() as client: async with client.get(binary_data_url) as resp: - params = {"file": resp.content, "other_var": 42} + query.variable_values = {"file": resp.content, "other_var": 42} - # Execute query asynchronously - result = await session.execute( - query, variable_values=params, upload_files=True - ) + with pytest.warns( + DeprecationWarning, + match="Not using FileVar for file upload is deprecated", + ): + result = await session.execute(query, upload_files=True) success = result["success"] + assert success + # Using FileVar + async with Client(transport=transport) as session: + query = gql(file_upload_mutation_1) + async with ClientSession() as client: + async with client.get(binary_data_url) as resp: + query.variable_values = {"file": FileVar(resp.content), "other_var": 42} + + result = await session.execute(query, upload_files=True) + + success = result["success"] assert success @pytest.mark.asyncio -async def test_aiohttp_async_generator_upload(event_loop, aiohttp_server): +async def test_aiohttp_async_generator_upload(aiohttp_server): import aiofiles from aiohttp import web + from gql.transport.aiohttp import AIOHTTPTransport + # This is a sample binary file content containing all possible byte values + binary_file_content = bytes(range(0, 256)) + app = web.Application() - app.router.add_route("POST", "/", binary_upload_handler) + app.router.add_route( + "POST", + "/", + make_upload_handler( + binary=True, + expected_contents=[binary_file_content], + 
expected_map=file_upload_mutation_1_map, + expected_operations=file_upload_mutation_1_operations, + ), + ) server = await aiohttp_server(app) url = server.make_url("/") transport = AIOHTTPTransport(url=url, timeout=10) + query = gql(file_upload_mutation_1) + with TemporaryFile(binary_file_content) as test_file: + file_path = test_file.filename + + async def file_sender(file_name): + async with aiofiles.open(file_name, "rb") as f: + chunk = await f.read(64 * 1024) + while chunk: + yield chunk + chunk = await f.read(64 * 1024) + + # Not using FileVar async with Client(transport=transport) as session: - query = gql(file_upload_mutation_1) + query.variable_values = {"file": file_sender(file_path), "other_var": 42} - file_path = test_file.filename + with pytest.warns( + DeprecationWarning, + match="Not using FileVar for file upload is deprecated", + ): + result = await session.execute(query, upload_files=True) - async def file_sender(file_name): - async with aiofiles.open(file_name, "rb") as f: - chunk = await f.read(64 * 1024) - while chunk: - yield chunk - chunk = await f.read(64 * 1024) + success = result["success"] + assert success - params = {"file": file_sender(file_path), "other_var": 42} + # Using FileVar + async with Client(transport=transport) as session: + + query.variable_values = { + "file": FileVar(file_sender(file_path)), + "other_var": 42, + } # Execute query asynchronously - result = await session.execute( - query, variable_values=params, upload_files=True - ) + result = await session.execute(query, upload_files=True) success = result["success"] - assert success + # Using FileVar with new streaming support + async with Client(transport=transport) as session: -file_upload_mutation_2 = """ - mutation($file1: Upload!, $file2: Upload!) 
{ - uploadFile(input:{file1:$file, file2:$file}) { - success - } - } -""" + query.variable_values = { + "file": FileVar(file_path, streaming=True), + "other_var": 42, + } -file_upload_mutation_2_operations = ( - '{"query": "mutation ($file1: Upload!, $file2: Upload!) {\\n ' - 'uploadFile(input: {file1: $file, file2: $file}) {\\n success\\n }\\n}", ' - '"variables": {"file1": null, "file2": null}}' -) - -file_upload_mutation_2_map = '{"0": ["variables.file1"], "1": ["variables.file2"]}' + # Execute query asynchronously + result = await session.execute(query, upload_files=True) -file_2_content = """ -This is a second test file -This file will also be sent in the GraphQL mutation -""" + success = result["success"] + assert success @pytest.mark.asyncio -async def test_aiohttp_file_upload_two_files(event_loop, aiohttp_server): +async def test_aiohttp_file_upload_two_files(aiohttp_server): from aiohttp import web - from gql.transport.aiohttp import AIOHTTPTransport - async def handler(request): - - reader = await request.multipart() - - field_0 = await reader.next() - assert field_0.name == "operations" - field_0_text = await field_0.text() - assert strip_braces_spaces(field_0_text) == file_upload_mutation_2_operations - - field_1 = await reader.next() - assert field_1.name == "map" - field_1_text = await field_1.text() - assert field_1_text == file_upload_mutation_2_map + from gql.transport.aiohttp import AIOHTTPTransport - field_2 = await reader.next() - assert field_2.name == "0" - field_2_text = await field_2.text() - assert field_2_text == file_1_content + file_upload_mutation_2 = """ + mutation($file1: Upload!, $file2: Upload!) { + uploadFile(input:{file1:$file, file2:$file}) { + success + } + } + """ - field_3 = await reader.next() - assert field_3.name == "1" - field_3_text = await field_3.text() - assert field_3_text == file_2_content + file_upload_mutation_2_operations = ( + '{"query": "mutation ($file1: Upload!, $file2: Upload!) 
{\\n ' + 'uploadFile(input: {file1: $file, file2: $file}) {\\n success\\n }\\n}", ' + '"variables": {"file1": null, "file2": null}}' + ) - field_4 = await reader.next() - assert field_4 is None + file_upload_mutation_2_map = '{"0": ["variables.file1"], "1": ["variables.file2"]}' - return web.Response( - text=file_upload_server_answer, content_type="application/json" - ) + file_2_content = """ + This is a second test file + This file will also be sent in the GraphQL mutation + """ app = web.Application() - app.router.add_route("POST", "/", handler) + app.router.add_route( + "POST", + "/", + make_upload_handler( + nb_files=2, + expected_map=file_upload_mutation_2_map, + expected_operations=file_upload_mutation_2_operations, + expected_contents=[file_1_content, file_2_content], + ), + ) server = await aiohttp_server(app) url = server.make_url("/") @@ -1001,81 +1146,58 @@ async def handler(request): file_path_1 = test_file_1.filename file_path_2 = test_file_2.filename - f1 = open(file_path_1, "rb") - f2 = open(file_path_2, "rb") - - params = { - "file1": f1, - "file2": f2, + query.variable_values = { + "file1": FileVar(file_path_1), + "file2": FileVar(file_path_2), } - result = await session.execute( - query, variable_values=params, upload_files=True - ) - - f1.close() - f2.close() + result = await session.execute(query, upload_files=True) success = result["success"] assert success -file_upload_mutation_3 = """ +@pytest.mark.asyncio +async def test_aiohttp_file_upload_list_of_two_files(aiohttp_server): + from aiohttp import web + + from gql.transport.aiohttp import AIOHTTPTransport + + file_upload_mutation_3 = """ mutation($files: [Upload!]!) { uploadFiles(input:{files:$files}) { success } } -""" - -file_upload_mutation_3_operations = ( - '{"query": "mutation ($files: [Upload!]!) 
{\\n uploadFiles(' - "input: {files: $files})" - ' {\\n success\\n }\\n}", "variables": {"files": [null, null]}}' -) - -file_upload_mutation_3_map = '{"0": ["variables.files.0"], "1": ["variables.files.1"]}' - - -@pytest.mark.asyncio -async def test_aiohttp_file_upload_list_of_two_files(event_loop, aiohttp_server): - from aiohttp import web - from gql.transport.aiohttp import AIOHTTPTransport - - async def handler(request): - - reader = await request.multipart() - - field_0 = await reader.next() - assert field_0.name == "operations" - field_0_text = await field_0.text() - assert strip_braces_spaces(field_0_text) == file_upload_mutation_3_operations + """ - field_1 = await reader.next() - assert field_1.name == "map" - field_1_text = await field_1.text() - assert field_1_text == file_upload_mutation_3_map - - field_2 = await reader.next() - assert field_2.name == "0" - field_2_text = await field_2.text() - assert field_2_text == file_1_content - - field_3 = await reader.next() - assert field_3.name == "1" - field_3_text = await field_3.text() - assert field_3_text == file_2_content + file_upload_mutation_3_operations = ( + '{"query": "mutation ($files: [Upload!]!) 
{\\n uploadFiles' + "(input: {files: $files})" + ' {\\n success\\n }\\n}", "variables": {"files": [null, null]}}' + ) - field_4 = await reader.next() - assert field_4 is None + file_upload_mutation_3_map = ( + '{"0": ["variables.files.0"], "1": ["variables.files.1"]}' + ) - return web.Response( - text=file_upload_server_answer, content_type="application/json" - ) + file_2_content = """ + This is a second test file + This file will also be sent in the GraphQL mutation + """ app = web.Application() - app.router.add_route("POST", "/", handler) + app.router.add_route( + "POST", + "/", + make_upload_handler( + nb_files=2, + expected_map=file_upload_mutation_3_map, + expected_operations=file_upload_mutation_3_operations, + expected_contents=[file_1_content, file_2_content], + ), + ) server = await aiohttp_server(app) url = server.make_url("/") @@ -1092,18 +1214,15 @@ async def handler(request): file_path_1 = test_file_1.filename file_path_2 = test_file_2.filename - f1 = open(file_path_1, "rb") - f2 = open(file_path_2, "rb") - - params = {"files": [f1, f2]} + query.variable_values = { + "files": [ + FileVar(file_path_1), + FileVar(file_path_2), + ], + } # Execute query asynchronously - result = await session.execute( - query, variable_values=params, upload_files=True - ) - - f1.close() - f2.close() + result = await session.execute(query, upload_files=True) success = result["success"] @@ -1111,7 +1230,7 @@ async def handler(request): @pytest.mark.asyncio -async def test_aiohttp_using_cli(event_loop, aiohttp_server, monkeypatch, capsys): +async def test_aiohttp_using_cli(aiohttp_server, monkeypatch, capsys): from aiohttp import web async def handler(request): @@ -1148,7 +1267,7 @@ async def handler(request): @pytest.mark.asyncio @pytest.mark.script_launch_mode("subprocess") async def test_aiohttp_using_cli_ep( - event_loop, aiohttp_server, monkeypatch, script_runner, run_sync_test + aiohttp_server, monkeypatch, script_runner, run_sync_test ): from aiohttp import web @@ 
-1181,13 +1300,11 @@ def test_code(): assert received_answer == expected_answer - await run_sync_test(event_loop, server, test_code) + await run_sync_test(server, test_code) @pytest.mark.asyncio -async def test_aiohttp_using_cli_invalid_param( - event_loop, aiohttp_server, monkeypatch, capsys -): +async def test_aiohttp_using_cli_invalid_param(aiohttp_server, monkeypatch, capsys): from aiohttp import web async def handler(request): @@ -1221,9 +1338,7 @@ async def handler(request): @pytest.mark.asyncio -async def test_aiohttp_using_cli_invalid_query( - event_loop, aiohttp_server, monkeypatch, capsys -): +async def test_aiohttp_using_cli_invalid_query(aiohttp_server, monkeypatch, capsys): from aiohttp import web async def handler(request): @@ -1261,8 +1376,9 @@ async def handler(request): @pytest.mark.asyncio -async def test_aiohttp_query_with_extensions(event_loop, aiohttp_server): +async def test_aiohttp_query_with_extensions(aiohttp_server): from aiohttp import web + from gql.transport.aiohttp import AIOHTTPTransport async def handler(request): @@ -1290,10 +1406,9 @@ async def handler(request): @pytest.mark.asyncio @pytest.mark.parametrize("ssl_close_timeout", [0, 10]) @pytest.mark.parametrize("verify_https", ["disabled", "cert_provided"]) -async def test_aiohttp_query_https( - event_loop, ssl_aiohttp_server, ssl_close_timeout, verify_https -): +async def test_aiohttp_query_https(ssl_aiohttp_server, ssl_close_timeout, verify_https): from aiohttp import web + from gql.transport.aiohttp import AIOHTTPTransport async def handler(request): @@ -1337,12 +1452,11 @@ async def handler(request): assert africa["code"] == "AF" -@pytest.mark.skip(reason="We will change the default to fix this in a future version") @pytest.mark.asyncio -async def test_aiohttp_query_https_self_cert_fail(event_loop, ssl_aiohttp_server): +async def test_aiohttp_query_https_self_cert_fail(ssl_aiohttp_server): """By default, we should verify the ssl certificate""" - from aiohttp.client_exceptions 
import ClientConnectorCertificateError from aiohttp import web + from gql.transport.aiohttp import AIOHTTPTransport async def handler(request): @@ -1358,22 +1472,29 @@ async def handler(request): transport = AIOHTTPTransport(url=url, timeout=10) - with pytest.raises(ClientConnectorCertificateError) as exc_info: - async with Client(transport=transport) as session: - query = gql(query1_str) + query = gql(query1_str) - # Execute query asynchronously + expected_error = "certificate verify failed: self-signed certificate" + + with pytest.raises(TransportConnectionFailed) as exc_info: + async with Client(transport=transport) as session: await session.execute(query) - expected_error = "certificate verify failed: self-signed certificate" + assert expected_error in str(exc_info.value) + + with pytest.raises(TransportConnectionFailed) as exc_info: + async with Client(transport=transport) as session: + await session.execute_batch([query]) assert expected_error in str(exc_info.value) + assert transport.session is None @pytest.mark.asyncio -async def test_aiohttp_query_https_self_cert_warn(event_loop, ssl_aiohttp_server): +async def test_aiohttp_query_https_self_cert_default(ssl_aiohttp_server): from aiohttp import web + from gql.transport.aiohttp import AIOHTTPTransport async def handler(request): @@ -1387,18 +1508,15 @@ async def handler(request): assert str(url).startswith("https://") - expected_warning = ( - "WARNING: By default, AIOHTTPTransport does not verify ssl certificates." - " This will be fixed in the next major version." 
- ) + transport = AIOHTTPTransport(url=url) - with pytest.warns(Warning, match=expected_warning): - AIOHTTPTransport(url=url, timeout=10) + assert transport.ssl is True @pytest.mark.asyncio -async def test_aiohttp_error_fetching_schema(event_loop, aiohttp_server): +async def test_aiohttp_error_fetching_schema(aiohttp_server): from aiohttp import web + from gql.transport.aiohttp import AIOHTTPTransport error_answer = """ @@ -1440,8 +1558,9 @@ async def handler(request): @pytest.mark.asyncio -async def test_aiohttp_reconnecting_session(event_loop, aiohttp_server): +async def test_aiohttp_reconnecting_session(aiohttp_server): from aiohttp import web + from gql.transport.aiohttp import AIOHTTPTransport async def handler(request): @@ -1478,10 +1597,9 @@ async def handler(request): @pytest.mark.asyncio @pytest.mark.parametrize("retries", [False, lambda e: e]) -async def test_aiohttp_reconnecting_session_retries( - event_loop, aiohttp_server, retries -): +async def test_aiohttp_reconnecting_session_retries(aiohttp_server, retries): from aiohttp import web + from gql.transport.aiohttp import AIOHTTPTransport async def handler(request): @@ -1512,9 +1630,10 @@ async def handler(request): @pytest.mark.asyncio async def test_aiohttp_reconnecting_session_start_connecting_task_twice( - event_loop, aiohttp_server, caplog + aiohttp_server, caplog ): from aiohttp import web + from gql.transport.aiohttp import AIOHTTPTransport async def handler(request): @@ -1546,8 +1665,9 @@ async def handler(request): @pytest.mark.asyncio -async def test_aiohttp_json_serializer(event_loop, aiohttp_server, caplog): +async def test_aiohttp_json_serializer(aiohttp_server, caplog): from aiohttp import web + from gql.transport.aiohttp import AIOHTTPTransport async def handler(request): @@ -1602,10 +1722,12 @@ async def handler(request): @pytest.mark.asyncio -async def test_aiohttp_json_deserializer(event_loop, aiohttp_server): - from aiohttp import web +async def 
test_aiohttp_json_deserializer(aiohttp_server): from decimal import Decimal from functools import partial + + from aiohttp import web + from gql.transport.aiohttp import AIOHTTPTransport async def handler(request): @@ -1641,8 +1763,9 @@ async def handler(request): @pytest.mark.asyncio -async def test_aiohttp_connector_owner_false(event_loop, aiohttp_server): - from aiohttp import web, TCPConnector +async def test_aiohttp_connector_owner_false(aiohttp_server): + from aiohttp import TCPConnector, web + from gql.transport.aiohttp import AIOHTTPTransport async def handler(request): @@ -1682,3 +1805,104 @@ async def handler(request): assert africa["code"] == "AF" await connector.close() + + +@pytest.mark.asyncio +async def test_aiohttp_deprecation_warning_using_document_node_execute(aiohttp_server): + from aiohttp import web + + from gql.transport.aiohttp import AIOHTTPTransport + + async def handler(request): + return web.Response( + text=query1_server_answer, + content_type="application/json", + ) + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = server.make_url("/") + + transport = AIOHTTPTransport(url=url, timeout=10) + + async with Client(transport=transport) as session: + + query = gql(query1_str) + + with pytest.warns( + DeprecationWarning, + match="Using a DocumentNode is deprecated", + ): + result = await session.execute(query.document) + + continents = result["continents"] + + africa = continents[0] + + assert africa["code"] == "AF" + + +@pytest.mark.asyncio +async def test_aiohttp_deprecation_warning_execute_variable_values(aiohttp_server): + from aiohttp import web + + from gql.transport.aiohttp import AIOHTTPTransport + + async def handler(request): + return web.Response(text=query2_server_answer, content_type="application/json") + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = server.make_url("/") + + transport = 
AIOHTTPTransport(url=url, timeout=10) + + async with Client(transport=transport) as session: + + query = gql(query2_str) + + with pytest.warns( + DeprecationWarning, + match=( + "Using variable_values and operation_name arguments of " + "execute and subscribe methods is deprecated" + ), + ): + result = await session.execute( + query, + variable_values={"code": "EU"}, + operation_name="getEurope", + ) + + continent = result["continent"] + + assert continent["name"] == "Europe" + + +@pytest.mark.asyncio +async def test_aiohttp_type_error_execute(aiohttp_server): + from aiohttp import web + + from gql.transport.aiohttp import AIOHTTPTransport + + async def handler(request): + return web.Response(text=query2_server_answer, content_type="application/json") + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = server.make_url("/") + + transport = AIOHTTPTransport(url=url, timeout=10) + + async with Client(transport=transport) as session: + + with pytest.raises(TypeError) as exc_info: + await session.execute("qmlsdkfj") + + assert "request should be a GraphQLRequest object" in str(exc_info.value) diff --git a/tests/test_aiohttp_batch.py b/tests/test_aiohttp_batch.py new file mode 100644 index 00000000..ad9924a0 --- /dev/null +++ b/tests/test_aiohttp_batch.py @@ -0,0 +1,523 @@ +import asyncio +from typing import Mapping + +import pytest + +from gql import Client, GraphQLRequest, gql +from gql.transport.exceptions import ( + TransportClosed, + TransportProtocolError, + TransportQueryError, + TransportServerError, +) + +# Marking all tests in this file with the aiohttp marker +pytestmark = pytest.mark.aiohttp + +query1_str = """ + query getContinents { + continents { + code + name + } + } +""" + +query1_server_answer_list = ( + '[{"data":{"continents":[' + '{"code":"AF","name":"Africa"},{"code":"AN","name":"Antarctica"},' + '{"code":"AS","name":"Asia"},{"code":"EU","name":"Europe"},' + 
'{"code":"NA","name":"North America"},{"code":"OC","name":"Oceania"},' + '{"code":"SA","name":"South America"}]}}]' +) + +query1_server_answer_twice_list = ( + "[" + '{"data":{"continents":[' + '{"code":"AF","name":"Africa"},{"code":"AN","name":"Antarctica"},' + '{"code":"AS","name":"Asia"},{"code":"EU","name":"Europe"},' + '{"code":"NA","name":"North America"},{"code":"OC","name":"Oceania"},' + '{"code":"SA","name":"South America"}]}},' + '{"data":{"continents":[' + '{"code":"AF","name":"Africa"},{"code":"AN","name":"Antarctica"},' + '{"code":"AS","name":"Asia"},{"code":"EU","name":"Europe"},' + '{"code":"NA","name":"North America"},{"code":"OC","name":"Oceania"},' + '{"code":"SA","name":"South America"}]}}' + "]" +) + + +@pytest.mark.asyncio +async def test_aiohttp_batch_query(aiohttp_server): + from aiohttp import web + + from gql.transport.aiohttp import AIOHTTPTransport + + async def handler(request): + return web.Response( + text=query1_server_answer_list, + content_type="application/json", + headers={"dummy": "test1234"}, + ) + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = server.make_url("/") + + transport = AIOHTTPTransport(url=url, timeout=10) + + async with Client(transport=transport) as session: + + query = [GraphQLRequest(query1_str)] + + # Execute query asynchronously + results = await session.execute_batch(query) + + result = results[0] + + continents = result["continents"] + + africa = continents[0] + + assert africa["code"] == "AF" + + # Checking response headers are saved in the transport + assert hasattr(transport, "response_headers") + assert isinstance(transport.response_headers, Mapping) + assert transport.response_headers["dummy"] == "test1234" + + +@pytest.mark.asyncio +async def test_aiohttp_batch_query_auto_batch_enabled(aiohttp_server, run_sync_test): + from aiohttp import web + + from gql.transport.aiohttp import AIOHTTPTransport + + async def handler(request): + 
return web.Response( + text=query1_server_answer_list, + content_type="application/json", + headers={"dummy": "test1234"}, + ) + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = server.make_url("/") + + transport = AIOHTTPTransport(url=url, timeout=10) + + async with Client( + transport=transport, + batch_interval=0.01, # 10ms batch interval + ) as session: + + query = gql(query1_str) + + result = await session.execute(query) + + continents = result["continents"] + + africa = continents[0] + + assert africa["code"] == "AF" + + # Checking response headers are saved in the transport + assert hasattr(transport, "response_headers") + assert isinstance(transport.response_headers, Mapping) + assert transport.response_headers["dummy"] == "test1234" + + +@pytest.mark.asyncio +async def test_aiohttp_batch_auto_two_requests(aiohttp_server): + from aiohttp import web + + from gql.transport.aiohttp import AIOHTTPTransport + + async def handler(request): + return web.Response( + text=query1_server_answer_twice_list, + content_type="application/json", + headers={"dummy": "test1234"}, + ) + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = server.make_url("/") + transport = AIOHTTPTransport(url=url, timeout=10) + + async with Client( + transport=transport, + batch_interval=0.01, + ) as session: + + async def test_coroutine(): + query = gql(query1_str) + + # Execute query asynchronously + result = await session.execute(query) + + continents = result["continents"] + + africa = continents[0] + + assert africa["code"] == "AF" + + # Create two concurrent tasks that will be batched together + tasks = [] + for _ in range(2): + task = asyncio.create_task(test_coroutine()) + tasks.append(task) + + # Wait for all tasks to complete + await asyncio.gather(*tasks) + + +@pytest.mark.asyncio +async def 
test_aiohttp_batch_auto_two_requests_close_session_directly(aiohttp_server): + from aiohttp import web + + from gql.transport.aiohttp import AIOHTTPTransport + + async def handler(request): + return web.Response( + text=query1_server_answer_twice_list, + content_type="application/json", + headers={"dummy": "test1234"}, + ) + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = server.make_url("/") + transport = AIOHTTPTransport(url=url, timeout=10) + + async with Client( + transport=transport, + batch_interval=0.1, + ) as session: + + async def test_coroutine(): + query = gql(query1_str) + + # Execute query asynchronously + result = await session.execute(query) + + continents = result["continents"] + + africa = continents[0] + + assert africa["code"] == "AF" + + # Create two concurrent tasks that will be batched together + tasks = [] + for _ in range(2): + task = asyncio.create_task(test_coroutine()) + tasks.append(task) + + await asyncio.sleep(0.01) + + # Wait for all tasks to complete + await asyncio.gather(*tasks) + + +@pytest.mark.asyncio +async def test_aiohttp_batch_error_code_401(aiohttp_server): + from aiohttp import web + + from gql.transport.aiohttp import AIOHTTPTransport + + async def handler(request): + # Will generate http error code 401 + return web.Response( + text='{"error":"Unauthorized","message":"401 Client Error: Unauthorized"}', + content_type="application/json", + status=401, + ) + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = server.make_url("/") + + transport = AIOHTTPTransport(url=url, timeout=10) + + async with Client( + transport=transport, + batch_interval=0.01, # 10ms batch interval + ) as session: + + query = gql(query1_str) + + with pytest.raises(TransportServerError) as exc_info: + await session.execute(query) + + assert "401, message='Unauthorized'" in str(exc_info.value) + + +@pytest.mark.asyncio 
+async def test_aiohttp_batch_query_without_session(aiohttp_server, run_sync_test): + from aiohttp import web + + from gql.transport.aiohttp import AIOHTTPTransport + + async def handler(request): + return web.Response( + text=query1_server_answer_list, + content_type="application/json", + headers={"dummy": "test1234"}, + ) + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = server.make_url("/") + + def test_code(): + transport = AIOHTTPTransport(url=url, timeout=10) + + client = Client(transport=transport) + + query = [GraphQLRequest(query1_str)] + + results = client.execute_batch(query) + + result = results[0] + + continents = result["continents"] + + africa = continents[0] + + assert africa["code"] == "AF" + + # Checking response headers are saved in the transport + assert hasattr(transport, "response_headers") + assert isinstance(transport.response_headers, Mapping) + assert transport.response_headers["dummy"] == "test1234" + + await run_sync_test(server, test_code) + + +query1_server_error_answer_list = '[{"errors": ["Error 1", "Error 2"]}]' + + +@pytest.mark.asyncio +async def test_aiohttp_batch_error_code(aiohttp_server): + from aiohttp import web + + from gql.transport.aiohttp import AIOHTTPTransport + + async def handler(request): + return web.Response( + text=query1_server_error_answer_list, content_type="application/json" + ) + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = server.make_url("/") + + transport = AIOHTTPTransport(url=url) + + async with Client(transport=transport) as session: + + query = [GraphQLRequest(query1_str)] + + with pytest.raises(TransportQueryError): + await session.execute_batch(query) + + +invalid_protocol_responses = [ + "{}", + "qlsjfqsdlkj", + '{"not_data_or_errors": 35}', + "[{}]", + "[qlsjfqsdlkj]", + '[{"not_data_or_errors": 35}]', + "[]", + "[1]", +] + + +@pytest.mark.asyncio 
+@pytest.mark.parametrize("response", invalid_protocol_responses) +async def test_aiohttp_batch_invalid_protocol(aiohttp_server, response): + from aiohttp import web + + from gql.transport.aiohttp import AIOHTTPTransport + + async def handler(request): + return web.Response(text=response, content_type="application/json") + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = server.make_url("/") + + transport = AIOHTTPTransport(url=url, timeout=10) + + async with Client(transport=transport) as session: + + query = [GraphQLRequest(query1_str)] + + with pytest.raises(TransportProtocolError): + await session.execute_batch(query) + + +@pytest.mark.aiohttp +@pytest.mark.asyncio +async def test_aiohttp_batch_cannot_execute_if_not_connected( + aiohttp_server, run_sync_test +): + from aiohttp import web + + from gql.transport.aiohttp import AIOHTTPTransport + + async def handler(request): + return web.Response( + text=query1_server_answer_list, + content_type="application/json", + headers={"dummy": "test1234"}, + ) + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = server.make_url("/") + + transport = AIOHTTPTransport(url=url, timeout=10) + + query = [GraphQLRequest(query1_str)] + + with pytest.raises(TransportClosed): + await transport.execute_batch(query) + + +@pytest.mark.asyncio +async def test_aiohttp_batch_extra_args(aiohttp_server): + from aiohttp import web + + from gql.transport.aiohttp import AIOHTTPTransport + + async def handler(request): + return web.Response( + text=query1_server_answer_list, + content_type="application/json", + headers={"dummy": "test1234"}, + ) + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = server.make_url("/") + + # passing extra arguments to aiohttp.ClientSession + from aiohttp import DummyCookieJar + + jar = DummyCookieJar() + 
transport = AIOHTTPTransport( + url=url, timeout=10, client_session_args={"version": "1.1", "cookie_jar": jar} + ) + + async with Client(transport=transport) as session: + + query = [GraphQLRequest(query1_str)] + + # Passing extra arguments to the post method of aiohttp + results = await session.execute_batch( + query, extra_args={"allow_redirects": False} + ) + + result = results[0] + + continents = result["continents"] + + africa = continents[0] + + assert africa["code"] == "AF" + + +query1_server_answer_with_extensions_list = ( + '[{"data":{"continents":[' + '{"code":"AF","name":"Africa"},{"code":"AN","name":"Antarctica"},' + '{"code":"AS","name":"Asia"},{"code":"EU","name":"Europe"},' + '{"code":"NA","name":"North America"},{"code":"OC","name":"Oceania"},' + '{"code":"SA","name":"South America"}]},' + '"extensions": {"key1": "val1"}' + "}]" +) + + +@pytest.mark.asyncio +async def test_aiohttp_batch_query_with_extensions(aiohttp_server): + from aiohttp import web + + from gql.transport.aiohttp import AIOHTTPTransport + + async def handler(request): + return web.Response( + text=query1_server_answer_with_extensions_list, + content_type="application/json", + ) + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = server.make_url("/") + + transport = AIOHTTPTransport(url=url) + + query = [GraphQLRequest(query1_str)] + + async with Client(transport=transport) as session: + + execution_results = await session.execute_batch( + query, get_execution_result=True + ) + + assert execution_results[0].extensions["key1"] == "val1" + + +ONLINE_URL = "https://countries.trevorblades.workers.dev/graphql" + + +@pytest.mark.online +@pytest.mark.asyncio +async def test_aiohttp_batch_online_manual(): + + from gql.transport.aiohttp import AIOHTTPTransport + + client = Client( + transport=AIOHTTPTransport(url=ONLINE_URL, timeout=10), + ) + + query = """ + query getContinentName($continent_code: ID!) 
{ + continent(code: $continent_code) { + name + } + } + """ + + async with client as session: + + request_eu = GraphQLRequest(query, variable_values={"continent_code": "EU"}) + request_af = GraphQLRequest(query, variable_values={"continent_code": "AF"}) + + result_eu, result_af = await session.execute_batch([request_eu, request_af]) + + assert result_eu["continent"]["name"] == "Europe" + assert result_af["continent"]["name"] == "Africa" diff --git a/tests/test_aiohttp_online.py b/tests/test_aiohttp_online.py index 39b8a9d2..a4f2480c 100644 --- a/tests/test_aiohttp_online.py +++ b/tests/test_aiohttp_online.py @@ -11,7 +11,7 @@ @pytest.mark.aiohttp @pytest.mark.online @pytest.mark.asyncio -async def test_aiohttp_simple_query(event_loop): +async def test_aiohttp_simple_query(): from gql.transport.aiohttp import AIOHTTPTransport @@ -19,10 +19,10 @@ async def test_aiohttp_simple_query(event_loop): url = "https://countries.trevorblades.com/graphql" # Get transport - sample_transport = AIOHTTPTransport(url=url) + transport = AIOHTTPTransport(url=url) # Instanciate client - async with Client(transport=sample_transport) as session: + async with Client(transport=transport) as session: query = gql( """ @@ -56,15 +56,13 @@ async def test_aiohttp_simple_query(event_loop): @pytest.mark.aiohttp @pytest.mark.online @pytest.mark.asyncio -async def test_aiohttp_invalid_query(event_loop): +async def test_aiohttp_invalid_query(): from gql.transport.aiohttp import AIOHTTPTransport - sample_transport = AIOHTTPTransport( - url="https://countries.trevorblades.com/graphql" - ) + transport = AIOHTTPTransport(url="https://countries.trevorblades.com/graphql") - async with Client(transport=sample_transport) as session: + async with Client(transport=transport) as session: query = gql( """ @@ -85,16 +83,16 @@ async def test_aiohttp_invalid_query(event_loop): @pytest.mark.online @pytest.mark.skipif(sys.version_info < (3, 8), reason="requires python3.8 or higher") @pytest.mark.asyncio -async def 
test_aiohttp_two_queries_in_parallel_using_two_tasks(event_loop): +async def test_aiohttp_two_queries_in_parallel_using_two_tasks(): from gql.transport.aiohttp import AIOHTTPTransport - sample_transport = AIOHTTPTransport( + transport = AIOHTTPTransport( url="https://countries.trevorblades.com/graphql", ) # Instanciate client - async with Client(transport=sample_transport) as session: + async with Client(transport=transport) as session: query1 = gql( """ diff --git a/tests/test_aiohttp_websocket_exceptions.py b/tests/test_aiohttp_websocket_exceptions.py index 8ee44d2c..2fb6722c 100644 --- a/tests/test_aiohttp_websocket_exceptions.py +++ b/tests/test_aiohttp_websocket_exceptions.py @@ -7,7 +7,7 @@ from gql import Client, gql from gql.transport.exceptions import ( - TransportClosed, + TransportConnectionFailed, TransportProtocolError, TransportQueryError, ) @@ -40,9 +40,7 @@ @pytest.mark.asyncio @pytest.mark.parametrize("server", [invalid_query1_server], indirect=True) @pytest.mark.parametrize("query_str", [invalid_query_str]) -async def test_aiohttp_websocket_invalid_query( - event_loop, aiohttp_client_and_server, query_str -): +async def test_aiohttp_websocket_invalid_query(aiohttp_client_and_server, query_str): session, server = aiohttp_client_and_server @@ -82,7 +80,7 @@ async def server_invalid_subscription(ws): @pytest.mark.parametrize("server", [server_invalid_subscription], indirect=True) @pytest.mark.parametrize("query_str", [invalid_subscription_str]) async def test_aiohttp_websocket_invalid_subscription( - event_loop, aiohttp_client_and_server, query_str + aiohttp_client_and_server, query_str ): session, server = aiohttp_client_and_server @@ -115,17 +113,15 @@ async def server_no_ack(ws): @pytest.mark.asyncio @pytest.mark.parametrize("server", [server_no_ack], indirect=True) @pytest.mark.parametrize("query_str", [invalid_query_str]) -async def test_aiohttp_websocket_server_does_not_send_ack( - event_loop, server, query_str -): +async def 
test_aiohttp_websocket_server_does_not_send_ack(server, query_str): from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport url = f"ws://{server.hostname}:{server.port}/graphql" - sample_transport = AIOHTTPWebsocketsTransport(url=url, ack_timeout=1) + transport = AIOHTTPWebsocketsTransport(url=url, ack_timeout=0.1) with pytest.raises(asyncio.TimeoutError): - async with Client(transport=sample_transport): + async with Client(transport=transport): pass @@ -141,14 +137,14 @@ async def server_connection_error(ws): @pytest.mark.parametrize("server", [server_connection_error], indirect=True) @pytest.mark.parametrize("query_str", [invalid_query_str]) async def test_aiohttp_websocket_sending_invalid_data( - event_loop, aiohttp_client_and_server, query_str + aiohttp_client_and_server, query_str ): session, server = aiohttp_client_and_server invalid_data = "QSDF" print(f">>> {invalid_data}") - await session.transport.websocket.send_str(invalid_data) + await session.transport.adapter.websocket.send_str(invalid_data) await asyncio.sleep(2 * MS) @@ -171,7 +167,7 @@ async def server_invalid_payload(ws): @pytest.mark.parametrize("server", [server_invalid_payload], indirect=True) @pytest.mark.parametrize("query_str", [invalid_query_str]) async def test_aiohttp_websocket_sending_invalid_payload( - event_loop, aiohttp_client_and_server, query_str + aiohttp_client_and_server, query_str ): session, server = aiohttp_client_and_server @@ -183,7 +179,7 @@ async def monkey_patch_send_query( document, variable_values=None, operation_name=None, - ) -> int: + ): query_id = self.next_query_id self.next_query_id += 1 @@ -241,9 +237,7 @@ async def monkey_patch_send_query( ], indirect=True, ) -async def test_aiohttp_websocket_transport_protocol_errors( - event_loop, aiohttp_client_and_server -): +async def test_aiohttp_websocket_transport_protocol_errors(aiohttp_client_and_server): session, server = aiohttp_client_and_server @@ -261,16 +255,16 @@ async def server_without_ack(ws): 
@pytest.mark.asyncio @pytest.mark.parametrize("server", [server_without_ack], indirect=True) -async def test_aiohttp_websocket_server_does_not_ack(event_loop, server): +async def test_aiohttp_websocket_server_does_not_ack(server): from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport url = f"ws://{server.hostname}:{server.port}/graphql" print(f"url = {url}") - sample_transport = AIOHTTPWebsocketsTransport(url=url) + transport = AIOHTTPWebsocketsTransport(url=url) with pytest.raises(TransportProtocolError): - async with Client(transport=sample_transport): + async with Client(transport=transport): pass @@ -280,17 +274,17 @@ async def server_closing_directly(ws): @pytest.mark.asyncio @pytest.mark.parametrize("server", [server_closing_directly], indirect=True) -async def test_aiohttp_websocket_server_closing_directly(event_loop, server): +async def test_aiohttp_websocket_server_closing_directly(server): from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport url = f"ws://{server.hostname}:{server.port}/graphql" print(f"url = {url}") - sample_transport = AIOHTTPWebsocketsTransport(url=url) + transport = AIOHTTPWebsocketsTransport(url=url) - with pytest.raises(ConnectionResetError): - async with Client(transport=sample_transport): + with pytest.raises(TransportConnectionFailed): + async with Client(transport=transport): pass @@ -301,15 +295,22 @@ async def server_closing_after_ack(ws): @pytest.mark.asyncio @pytest.mark.parametrize("server", [server_closing_after_ack], indirect=True) -async def test_aiohttp_websocket_server_closing_after_ack( - event_loop, aiohttp_client_and_server -): +async def test_aiohttp_websocket_server_closing_after_ack(aiohttp_client_and_server): session, server = aiohttp_client_and_server query = gql("query { hello }") - with pytest.raises(TransportClosed): + print("\n Trying to execute first query.\n") + + with pytest.raises(TransportConnectionFailed): + await session.execute(query) + + await 
session.transport.wait_closed() + + print("\n Trying to execute second query.\n") + + with pytest.raises(TransportConnectionFailed): await session.execute(query) @@ -325,24 +326,22 @@ async def server_sending_invalid_query_errors(ws): @pytest.mark.asyncio @pytest.mark.parametrize("server", [server_sending_invalid_query_errors], indirect=True) -async def test_aiohttp_websocket_server_sending_invalid_query_errors( - event_loop, server -): +async def test_aiohttp_websocket_server_sending_invalid_query_errors(server): from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport url = f"ws://{server.hostname}:{server.port}/graphql" print(f"url = {url}") - sample_transport = AIOHTTPWebsocketsTransport(url=url) + transport = AIOHTTPWebsocketsTransport(url=url) # Invalid server message is ignored - async with Client(transport=sample_transport): + async with Client(transport=transport): await asyncio.sleep(2 * MS) @pytest.mark.asyncio @pytest.mark.parametrize("server", [server_sending_invalid_query_errors], indirect=True) -async def test_aiohttp_websocket_non_regression_bug_105(event_loop, server): +async def test_aiohttp_websocket_non_regression_bug_105(server): from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport # This test will check a fix to a race condition which happens if the user is trying @@ -352,9 +351,9 @@ async def test_aiohttp_websocket_non_regression_bug_105(event_loop, server): url = f"ws://{server.hostname}:{server.port}/graphql" print(f"url = {url}") - sample_transport = AIOHTTPWebsocketsTransport(url=url) + transport = AIOHTTPWebsocketsTransport(url=url) - client = Client(transport=sample_transport) + client = Client(transport=transport) # Create a coroutine which start the connection with the transport but does nothing async def client_connect(client): @@ -373,9 +372,7 @@ async def client_connect(client): @pytest.mark.asyncio @pytest.mark.parametrize("server", [invalid_query1_server], indirect=True) -async def 
test_aiohttp_websocket_using_cli_invalid_query( - event_loop, server, monkeypatch, capsys -): +async def test_aiohttp_websocket_using_cli_invalid_query(server, monkeypatch, capsys): url = f"ws://{server.hostname}:{server.port}/graphql" print(f"url = {url}") diff --git a/tests/test_aiohttp_websocket_graphqlws_exceptions.py b/tests/test_aiohttp_websocket_graphqlws_exceptions.py index b234d296..52bc27a4 100644 --- a/tests/test_aiohttp_websocket_graphqlws_exceptions.py +++ b/tests/test_aiohttp_websocket_graphqlws_exceptions.py @@ -5,7 +5,7 @@ from gql import Client, gql from gql.transport.exceptions import ( - TransportClosed, + TransportConnectionFailed, TransportProtocolError, TransportQueryError, ) @@ -39,7 +39,7 @@ @pytest.mark.parametrize("graphqlws_server", [invalid_query1_server], indirect=True) @pytest.mark.parametrize("query_str", [invalid_query_str]) async def test_aiohttp_websocket_graphqlws_invalid_query( - event_loop, client_and_aiohttp_websocket_graphql_server, query_str + client_and_aiohttp_websocket_graphql_server, query_str ): session, server = client_and_aiohttp_websocket_graphql_server @@ -82,7 +82,7 @@ async def server_invalid_subscription(ws): ) @pytest.mark.parametrize("query_str", [invalid_subscription_str]) async def test_aiohttp_websocket_graphqlws_invalid_subscription( - event_loop, client_and_aiohttp_websocket_graphql_server, query_str + client_and_aiohttp_websocket_graphql_server, query_str ): session, server = client_and_aiohttp_websocket_graphql_server @@ -110,13 +110,13 @@ async def server_no_ack(ws): @pytest.mark.parametrize("graphqlws_server", [server_no_ack], indirect=True) @pytest.mark.parametrize("query_str", [invalid_query_str]) async def test_aiohttp_websocket_graphqlws_server_does_not_send_ack( - event_loop, graphqlws_server, query_str + graphqlws_server, query_str ): from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport url = f"ws://{graphqlws_server.hostname}:{graphqlws_server.port}/graphql" - transport = 
AIOHTTPWebsocketsTransport(url=url, ack_timeout=1) + transport = AIOHTTPWebsocketsTransport(url=url, ack_timeout=0.1) with pytest.raises(asyncio.TimeoutError): async with Client(transport=transport): @@ -142,7 +142,7 @@ async def server_invalid_query(ws): @pytest.mark.asyncio @pytest.mark.parametrize("graphqlws_server", [server_invalid_query], indirect=True) async def test_aiohttp_websocket_graphqlws_sending_invalid_query( - event_loop, client_and_aiohttp_websocket_graphql_server + client_and_aiohttp_websocket_graphql_server, ): session, server = client_and_aiohttp_websocket_graphql_server @@ -196,7 +196,7 @@ async def test_aiohttp_websocket_graphqlws_sending_invalid_query( indirect=True, ) async def test_aiohttp_websocket_graphqlws_transport_protocol_errors( - event_loop, client_and_aiohttp_websocket_graphql_server + client_and_aiohttp_websocket_graphql_server, ): session, server = client_and_aiohttp_websocket_graphql_server @@ -215,9 +215,7 @@ async def server_without_ack(ws): @pytest.mark.asyncio @pytest.mark.parametrize("graphqlws_server", [server_without_ack], indirect=True) -async def test_aiohttp_websocket_graphqlws_server_does_not_ack( - event_loop, graphqlws_server -): +async def test_aiohttp_websocket_graphqlws_server_does_not_ack(graphqlws_server): from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport url = f"ws://{graphqlws_server.hostname}:{graphqlws_server.port}/graphql" @@ -236,9 +234,7 @@ async def server_closing_directly(ws): @pytest.mark.asyncio @pytest.mark.parametrize("graphqlws_server", [server_closing_directly], indirect=True) -async def test_aiohttp_websocket_graphqlws_server_closing_directly( - event_loop, graphqlws_server -): +async def test_aiohttp_websocket_graphqlws_server_closing_directly(graphqlws_server): from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport @@ -247,7 +243,7 @@ async def test_aiohttp_websocket_graphqlws_server_closing_directly( transport = AIOHTTPWebsocketsTransport(url=url) - with 
pytest.raises(ConnectionResetError): + with pytest.raises(TransportConnectionFailed): async with Client(transport=transport): pass @@ -260,17 +256,21 @@ async def server_closing_after_ack(ws): @pytest.mark.asyncio @pytest.mark.parametrize("graphqlws_server", [server_closing_after_ack], indirect=True) async def test_aiohttp_websocket_graphqlws_server_closing_after_ack( - event_loop, client_and_aiohttp_websocket_graphql_server + client_and_aiohttp_websocket_graphql_server, ): session, _ = client_and_aiohttp_websocket_graphql_server query = gql("query { hello }") - with pytest.raises(TransportClosed): + print("\n Trying to execute first query.\n") + + with pytest.raises(TransportConnectionFailed): await session.execute(query) await session.transport.wait_closed() - with pytest.raises(TransportClosed): + print("\n Trying to execute second query.\n") + + with pytest.raises(TransportConnectionFailed): await session.execute(query) diff --git a/tests/test_aiohttp_websocket_graphqlws_subscription.py b/tests/test_aiohttp_websocket_graphqlws_subscription.py index d40d15ce..e03ad8f9 100644 --- a/tests/test_aiohttp_websocket_graphqlws_subscription.py +++ b/tests/test_aiohttp_websocket_graphqlws_subscription.py @@ -8,7 +8,8 @@ from parse import search from gql import Client, gql -from gql.transport.exceptions import TransportServerError +from gql.client import AsyncClientSession +from gql.transport.exceptions import TransportConnectionFailed, TransportServerError from .conftest import MS, WebSocketServerHelper @@ -229,7 +230,7 @@ async def server_countdown_disconnect(ws): @pytest.mark.parametrize("graphqlws_server", [server_countdown], indirect=True) @pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) async def test_aiohttp_websocket_graphqlws_subscription( - event_loop, client_and_aiohttp_websocket_graphql_server, subscription_str + client_and_aiohttp_websocket_graphql_server, subscription_str ): session, server = 
client_and_aiohttp_websocket_graphql_server @@ -252,7 +253,7 @@ async def test_aiohttp_websocket_graphqlws_subscription( @pytest.mark.parametrize("graphqlws_server", [server_countdown], indirect=True) @pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) async def test_aiohttp_websocket_graphqlws_subscription_break( - event_loop, client_and_aiohttp_websocket_graphql_server, subscription_str + client_and_aiohttp_websocket_graphql_server, subscription_str ): session, server = client_and_aiohttp_websocket_graphql_server @@ -260,7 +261,8 @@ async def test_aiohttp_websocket_graphqlws_subscription_break( count = 10 subscription = gql(subscription_str.format(count=count)) - async for result in session.subscribe(subscription): + generator = session.subscribe(subscription) + async for result in generator: number = result["number"] print(f"Number received: {number}") @@ -274,12 +276,15 @@ async def test_aiohttp_websocket_graphqlws_subscription_break( assert count == 5 + # Using aclose here to make it stop cleanly on pypy + await generator.aclose() + @pytest.mark.asyncio @pytest.mark.parametrize("graphqlws_server", [server_countdown], indirect=True) @pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) async def test_aiohttp_websocket_graphqlws_subscription_task_cancel( - event_loop, client_and_aiohttp_websocket_graphql_server, subscription_str + client_and_aiohttp_websocket_graphql_server, subscription_str ): session, server = client_and_aiohttp_websocket_graphql_server @@ -287,16 +292,24 @@ async def test_aiohttp_websocket_graphqlws_subscription_task_cancel( count = 10 subscription = gql(subscription_str.format(count=count)) + task_cancelled = False + async def task_coro(): nonlocal count - async for result in session.subscribe(subscription): + nonlocal task_cancelled - number = result["number"] - print(f"Number received: {number}") + try: + async for result in session.subscribe(subscription): - assert number == count + number = 
result["number"] + print(f"Number received: {number}") - count -= 1 + assert number == count + + count -= 1 + except asyncio.CancelledError: + print("Inside task cancelled") + task_cancelled = True task = asyncio.ensure_future(task_coro()) @@ -312,13 +325,14 @@ async def cancel_task_coro(): await asyncio.gather(task, cancel_task) assert count > 0 + assert task_cancelled is True @pytest.mark.asyncio @pytest.mark.parametrize("graphqlws_server", [server_countdown], indirect=True) @pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) async def test_aiohttp_websocket_graphqlws_subscription_close_transport( - event_loop, client_and_aiohttp_websocket_graphql_server, subscription_str + client_and_aiohttp_websocket_graphql_server, subscription_str ): session, server = client_and_aiohttp_websocket_graphql_server @@ -383,14 +397,14 @@ async def server_countdown_close_connection_in_middle(ws): ) @pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) async def test_aiohttp_websocket_graphqlws_subscription_server_connection_closed( - event_loop, client_and_aiohttp_websocket_graphql_server, subscription_str + client_and_aiohttp_websocket_graphql_server, subscription_str ): session, _ = client_and_aiohttp_websocket_graphql_server count = 10 subscription = gql(subscription_str.format(count=count)) - with pytest.raises(ConnectionResetError): + with pytest.raises(TransportConnectionFailed): async for result in session.subscribe(subscription): number = result["number"] print(f"Number received: {number}") @@ -404,17 +418,16 @@ async def test_aiohttp_websocket_graphqlws_subscription_server_connection_closed @pytest.mark.parametrize("graphqlws_server", [server_countdown], indirect=True) @pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) async def test_aiohttp_websocket_graphqlws_subscription_with_operation_name( - event_loop, client_and_aiohttp_websocket_graphql_server, subscription_str + 
client_and_aiohttp_websocket_graphql_server, subscription_str ): session, server = client_and_aiohttp_websocket_graphql_server count = 10 subscription = gql(subscription_str.format(count=count)) + subscription.operation_name = "CountdownSubscription" - async for result in session.subscribe( - subscription, operation_name="CountdownSubscription" - ): + async for result in session.subscribe(subscription): number = result["number"] print(f"Number received: {number}") @@ -434,7 +447,7 @@ async def test_aiohttp_websocket_graphqlws_subscription_with_operation_name( ) @pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) async def test_aiohttp_websocket_graphqlws_subscription_with_keepalive( - event_loop, client_and_aiohttp_websocket_graphql_server, subscription_str + client_and_aiohttp_websocket_graphql_server, subscription_str ): session, server = client_and_aiohttp_websocket_graphql_server @@ -464,7 +477,7 @@ async def test_aiohttp_websocket_graphqlws_subscription_with_keepalive( ) @pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) async def test_aiohttp_websocket_graphqlws_subscription_with_keepalive_with_timeout_ok( - event_loop, graphqlws_server, subscription_str + graphqlws_server, subscription_str ): from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport @@ -498,7 +511,7 @@ async def test_aiohttp_websocket_graphqlws_subscription_with_keepalive_with_time ) @pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) async def test_aiohttp_websocket_graphqlws_subscription_with_keepalive_with_timeout_nok( - event_loop, graphqlws_server, subscription_str + graphqlws_server, subscription_str ): from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport @@ -533,7 +546,7 @@ async def test_aiohttp_websocket_graphqlws_subscription_with_keepalive_with_time ) @pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) async def 
test_aiohttp_websocket_graphqlws_subscription_with_ping_interval_ok( - event_loop, graphqlws_server, subscription_str + graphqlws_server, subscription_str ): from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport @@ -569,7 +582,7 @@ async def test_aiohttp_websocket_graphqlws_subscription_with_ping_interval_ok( ) @pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) async def test_aiohttp_websocket_graphqlws_subscription_with_ping_interval_nok( - event_loop, graphqlws_server, subscription_str + graphqlws_server, subscription_str ): from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport @@ -602,7 +615,7 @@ async def test_aiohttp_websocket_graphqlws_subscription_with_ping_interval_nok( ) @pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) async def test_aiohttp_websocket_graphqlws_subscription_manual_pings_with_payload( - event_loop, graphqlws_server, subscription_str + graphqlws_server, subscription_str ): from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport @@ -644,7 +657,7 @@ async def test_aiohttp_websocket_graphqlws_subscription_manual_pings_with_payloa ) @pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) async def test_aiohttp_websocket_graphqlws_subscription_manual_pong_with_payload( - event_loop, graphqlws_server, subscription_str + graphqlws_server, subscription_str ): from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport @@ -759,6 +772,7 @@ def test_aiohttp_websocket_graphqlws_subscription_sync_graceful_shutdown( warnings.filterwarnings( "ignore", message="There is no current event loop" ) + assert isinstance(client.session, AsyncClientSession) asyncio.ensure_future( client.session._generator.athrow(KeyboardInterrupt) ) @@ -777,7 +791,7 @@ def test_aiohttp_websocket_graphqlws_subscription_sync_graceful_shutdown( ) @pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) async def 
test_aiohttp_websocket_graphqlws_subscription_running_in_thread( - event_loop, graphqlws_server, subscription_str, run_sync_test + graphqlws_server, subscription_str, run_sync_test ): from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport @@ -801,7 +815,7 @@ def test_code(): assert count == -1 - await run_sync_test(event_loop, graphqlws_server, test_code) + await run_sync_test(graphqlws_server, test_code) @pytest.mark.asyncio @@ -811,10 +825,9 @@ def test_code(): @pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) @pytest.mark.parametrize("execute_instead_of_subscribe", [False, True]) async def test_aiohttp_websocket_graphqlws_subscription_reconnecting_session( - event_loop, graphqlws_server, subscription_str, execute_instead_of_subscribe + graphqlws_server, subscription_str, execute_instead_of_subscribe ): - from gql.transport.exceptions import TransportClosed from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport path = "/graphql" @@ -833,44 +846,62 @@ async def test_aiohttp_websocket_graphqlws_subscription_reconnecting_session( reconnecting=True, retry_connect=False, retry_execute=False ) - # First we make a subscription which will cause a disconnect in the backend - # (count=8) - try: - print("\nSUBSCRIPTION_1_WITH_DISCONNECT\n") - async for result in session.subscribe(subscription_with_disconnect): - pass - except ConnectionResetError: - pass - - await asyncio.sleep(50 * MS) - - # Then with the same session handle, we make a subscription or an execute - # which will detect that the transport is closed so that the client could - # try to reconnect + # First we make a query or subscription which will cause a disconnect + # in the backend (count=8) try: if execute_instead_of_subscribe: - print("\nEXECUTION_2\n") - await session.execute(subscription) + print("\nEXECUTION_1\n") + await session.execute(subscription_with_disconnect) else: - print("\nSUBSCRIPTION_2\n") - async for result in 
session.subscribe(subscription): + print("\nSUBSCRIPTION_1_WITH_DISCONNECT\n") + async for result in session.subscribe(subscription_with_disconnect): pass - except TransportClosed: + except TransportConnectionFailed: pass - await asyncio.sleep(50 * MS) + # Wait for disconnect + for i in range(200): + await asyncio.sleep(1 * MS) + if not transport._connected: + print(f"\nDisconnected in {i+1} MS") + break - # And finally with the same session handle, we make a subscription - # which works correctly - print("\nSUBSCRIPTION_3\n") - async for result in session.subscribe(subscription): + # Wait for reconnect + for i in range(200): + await asyncio.sleep(1 * MS) + if transport._connected: + print(f"\nConnected again in {i+1} MS") + break - number = result["number"] - print(f"Number received: {number}") + assert transport._connected is True + + # Then after the reconnection, we make a query or a subscription + if execute_instead_of_subscribe: + print("\nEXECUTION_2\n") + result = await session.execute(subscription) + assert result["number"] == 10 + else: + print("\nSUBSCRIPTION_2\n") + generator = session.subscribe(subscription) + async for result in generator: + number = result["number"] + print(f"Number received: {number}") - assert number == count - count -= 1 + assert number == count + count -= 1 - assert count == -1 + await generator.aclose() + assert count == -1 + + # Close the reconnecting session await client.close_async() + + # Wait for disconnect + for i in range(200): + await asyncio.sleep(1 * MS) + if not transport._connected: + print(f"\nDisconnected in {i+1} MS") + break + + assert transport._connected is False diff --git a/tests/test_aiohttp_websocket_query.py b/tests/test_aiohttp_websocket_query.py index d76d646f..a3087d78 100644 --- a/tests/test_aiohttp_websocket_query.py +++ b/tests/test_aiohttp_websocket_query.py @@ -1,14 +1,14 @@ import asyncio import json import sys -from typing import Dict, Mapping +from typing import Any, Dict, Mapping import pytest 
from gql import Client, gql from gql.transport.exceptions import ( TransportAlreadyConnected, - TransportClosed, + TransportConnectionFailed, TransportQueryError, TransportServerError, ) @@ -50,9 +50,7 @@ @pytest.mark.asyncio @pytest.mark.parametrize("aiohttp_ws_server", [server1_answers], indirect=True) -async def test_aiohttp_websocket_starting_client_in_context_manager( - event_loop, aiohttp_ws_server -): +async def test_aiohttp_websocket_starting_client_in_context_manager(aiohttp_ws_server): server = aiohttp_ws_server from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport @@ -60,7 +58,15 @@ async def test_aiohttp_websocket_starting_client_in_context_manager( url = f"ws://{server.hostname}:{server.port}/graphql" print(f"url = {url}") - transport = AIOHTTPWebsocketsTransport(url=url, websocket_close_timeout=10) + transport = AIOHTTPWebsocketsTransport( + url=url, + websocket_close_timeout=10, + headers={"test": "1234"}, + ) + + assert transport.response_headers == {} + assert isinstance(transport.headers, Mapping) + assert transport.headers["test"] == "1234" # type: ignore async with Client(transport=transport) as session: @@ -84,7 +90,7 @@ async def test_aiohttp_websocket_starting_client_in_context_manager( assert transport.response_headers["dummy"] == "test1234" # Check client is disconnect here - assert transport.websocket is None + assert transport._connected is False @pytest.mark.asyncio @@ -93,7 +99,7 @@ async def test_aiohttp_websocket_starting_client_in_context_manager( @pytest.mark.parametrize("ssl_close_timeout", [0, 10]) @pytest.mark.parametrize("verify_https", ["disabled", "cert_provided"]) async def test_aiohttp_websocket_using_ssl_connection( - event_loop, ws_ssl_server, ssl_close_timeout, verify_https + ws_ssl_server, ssl_close_timeout, verify_https ): from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport @@ -135,7 +141,7 @@ async def test_aiohttp_websocket_using_ssl_connection( assert africa["code"] == "AF" # 
Check client is disconnect here - assert transport.websocket is None + assert transport._connected is False @pytest.mark.asyncio @@ -144,10 +150,11 @@ async def test_aiohttp_websocket_using_ssl_connection( @pytest.mark.parametrize("ssl_close_timeout", [10]) @pytest.mark.parametrize("verify_https", ["explicitely_enabled", "default"]) async def test_aiohttp_websocket_using_ssl_connection_self_cert_fail( - event_loop, ws_ssl_server, ssl_close_timeout, verify_https + ws_ssl_server, ssl_close_timeout, verify_https ): from aiohttp.client_exceptions import ClientConnectorCertificateError + from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport server = ws_ssl_server @@ -155,7 +162,7 @@ async def test_aiohttp_websocket_using_ssl_connection_self_cert_fail( url = f"wss://{server.hostname}:{server.port}/graphql" print(f"url = {url}") - extra_args = {} + extra_args: Dict[str, Any] = {} if verify_https == "explicitely_enabled": extra_args["ssl"] = True @@ -166,28 +173,33 @@ async def test_aiohttp_websocket_using_ssl_connection_self_cert_fail( **extra_args, ) - with pytest.raises(ClientConnectorCertificateError) as exc_info: + if verify_https == "explicitely_enabled": + assert transport.ssl is True + + with pytest.raises(TransportConnectionFailed) as exc_info: async with Client(transport=transport) as session: query1 = gql(query1_str) await session.execute(query1) + cause = exc_info.value.__cause__ + + assert isinstance(cause, ClientConnectorCertificateError) + expected_error = "certificate verify failed: self-signed certificate" - assert expected_error in str(exc_info.value) + assert expected_error in str(cause) # Check client is disconnect here - assert transport.websocket is None + assert transport._connected is False @pytest.mark.asyncio @pytest.mark.websockets @pytest.mark.parametrize("server", [server1_answers], indirect=True) @pytest.mark.parametrize("query_str", [query1_str]) -async def test_aiohttp_websocket_simple_query( - event_loop, 
aiohttp_client_and_server, query_str -): +async def test_aiohttp_websocket_simple_query(aiohttp_client_and_server, query_str): session, server = aiohttp_client_and_server @@ -210,7 +222,7 @@ async def test_aiohttp_websocket_simple_query( ) @pytest.mark.parametrize("query_str", [query1_str]) async def test_aiohttp_websocket_two_queries_in_series( - event_loop, aiohttp_client_and_aiohttp_ws_server, query_str + aiohttp_client_and_aiohttp_ws_server, query_str ): session, server = aiohttp_client_and_aiohttp_ws_server @@ -247,7 +259,7 @@ async def server1_two_queries_in_parallel(ws): @pytest.mark.parametrize("server", [server1_two_queries_in_parallel], indirect=True) @pytest.mark.parametrize("query_str", [query1_str]) async def test_aiohttp_websocket_two_queries_in_parallel( - event_loop, aiohttp_client_and_server, query_str + aiohttp_client_and_server, query_str ): session, server = aiohttp_client_and_server @@ -295,7 +307,7 @@ async def server_closing_while_we_are_doing_something_else(ws): ) @pytest.mark.parametrize("query_str", [query1_str]) async def test_aiohttp_websocket_server_closing_after_first_query( - event_loop, aiohttp_client_and_server, query_str + aiohttp_client_and_server, query_str ): session, server = aiohttp_client_and_server @@ -306,11 +318,11 @@ async def test_aiohttp_websocket_server_closing_after_first_query( await session.execute(query) # Then we do other things - await asyncio.sleep(1000 * MS) + await asyncio.sleep(10 * MS) # Now the server is closed but we don't know it yet, we have to send a query # to notice it and to receive the exception - with pytest.raises(TransportClosed): + with pytest.raises(TransportConnectionFailed): await session.execute(query) @@ -327,7 +339,7 @@ async def test_aiohttp_websocket_server_closing_after_first_query( ) @pytest.mark.parametrize("query_str", [query1_str]) async def test_aiohttp_websocket_ignore_invalid_id( - event_loop, aiohttp_client_and_aiohttp_ws_server, query_str + aiohttp_client_and_aiohttp_ws_server, 
query_str ): session, server = aiohttp_client_and_aiohttp_ws_server @@ -363,9 +375,7 @@ async def assert_client_is_working(session): @pytest.mark.asyncio @pytest.mark.parametrize("aiohttp_ws_server", [server1_answers], indirect=True) -async def test_aiohttp_websocket_multiple_connections_in_series( - event_loop, aiohttp_ws_server -): +async def test_aiohttp_websocket_multiple_connections_in_series(aiohttp_ws_server): server = aiohttp_ws_server @@ -380,20 +390,18 @@ async def test_aiohttp_websocket_multiple_connections_in_series( await assert_client_is_working(session) # Check client is disconnect here - assert transport.websocket is None + assert transport._connected is False async with Client(transport=transport) as session: await assert_client_is_working(session) # Check client is disconnect here - assert transport.websocket is None + assert transport._connected is False @pytest.mark.asyncio @pytest.mark.parametrize("aiohttp_ws_server", [server1_answers], indirect=True) -async def test_aiohttp_websocket_multiple_connections_in_parallel( - event_loop, aiohttp_ws_server -): +async def test_aiohttp_websocket_multiple_connections_in_parallel(aiohttp_ws_server): server = aiohttp_ws_server @@ -416,7 +424,7 @@ async def task_coro(): @pytest.mark.asyncio @pytest.mark.parametrize("aiohttp_ws_server", [server1_answers], indirect=True) async def test_aiohttp_websocket_trying_to_connect_to_already_connected_transport( - event_loop, aiohttp_ws_server + aiohttp_ws_server, ): server = aiohttp_ws_server @@ -467,7 +475,7 @@ async def server_with_authentication_in_connection_init_payload(ws): ) @pytest.mark.parametrize("query_str", [query1_str]) async def test_aiohttp_websocket_connect_success_with_authentication_in_connection_init( - event_loop, server, query_str + server, query_str ): from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport @@ -503,7 +511,7 @@ async def test_aiohttp_websocket_connect_success_with_authentication_in_connecti 
@pytest.mark.parametrize("query_str", [query1_str]) @pytest.mark.parametrize("init_payload", [{}, {"Authorization": "invalid_code"}]) async def test_aiohttp_websocket_connect_failed_with_authentication_in_connection_init( - event_loop, server, query_str, init_payload + server, query_str, init_payload ): from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport @@ -519,8 +527,8 @@ async def test_aiohttp_websocket_connect_failed_with_authentication_in_connectio await session.execute(query1) - assert transport.session is None - assert transport.websocket is None + assert transport.adapter.session is None + assert transport._connected is False @pytest.mark.parametrize("aiohttp_ws_server", [server1_answers], indirect=True) @@ -564,14 +572,12 @@ def test_aiohttp_websocket_execute_sync(aiohttp_ws_server): assert africa["code"] == "AF" # Check client is disconnect here - assert transport.websocket is None + assert transport._connected is False @pytest.mark.asyncio @pytest.mark.parametrize("aiohttp_ws_server", [server1_answers], indirect=True) -async def test_aiohttp_websocket_add_extra_parameters_to_connect( - event_loop, aiohttp_ws_server -): +async def test_aiohttp_websocket_add_extra_parameters_to_connect(aiohttp_ws_server): server = aiohttp_ws_server @@ -613,7 +619,7 @@ async def server_sending_keep_alive_before_connection_ack(ws): ) @pytest.mark.parametrize("query_str", [query1_str]) async def test_aiohttp_websocket_non_regression_bug_108( - event_loop, aiohttp_client_and_server, query_str + aiohttp_client_and_server, query_str ): # This test will check that we now ignore keepalive message @@ -638,9 +644,8 @@ async def test_aiohttp_websocket_non_regression_bug_108( @pytest.mark.parametrize("aiohttp_ws_server", [server1_answers], indirect=True) @pytest.mark.parametrize("transport_arg", [[], ["--transport=aiohttp_websockets"]]) async def test_aiohttp_websocket_using_cli( - event_loop, aiohttp_ws_server, transport_arg, monkeypatch, capsys + 
aiohttp_ws_server, transport_arg, monkeypatch, capsys ): - """ Note: depending on the transport_arg parameter, if there is no transport argument, then we will use WebsocketsTransport if the websockets dependency is installed, @@ -702,7 +707,7 @@ async def test_aiohttp_websocket_using_cli( ) @pytest.mark.parametrize("query_str", [query1_str]) async def test_aiohttp_websocket_simple_query_with_extensions( - event_loop, aiohttp_client_and_aiohttp_ws_server, query_str + aiohttp_client_and_aiohttp_ws_server, query_str ): session, server = aiohttp_client_and_aiohttp_ws_server @@ -716,7 +721,7 @@ async def test_aiohttp_websocket_simple_query_with_extensions( @pytest.mark.asyncio @pytest.mark.parametrize("aiohttp_ws_server", [server1_answers], indirect=True) -async def test_aiohttp_websocket_connector_owner_false(event_loop, aiohttp_ws_server): +async def test_aiohttp_websocket_connector_owner_false(aiohttp_ws_server): server = aiohttp_ws_server @@ -753,6 +758,6 @@ async def test_aiohttp_websocket_connector_owner_false(event_loop, aiohttp_ws_se assert africa["code"] == "AF" # Check client is disconnect here - assert transport.websocket is None + assert transport._connected is False await connector.close() diff --git a/tests/test_aiohttp_websocket_subscription.py b/tests/test_aiohttp_websocket_subscription.py index 9d2d652b..f06046df 100644 --- a/tests/test_aiohttp_websocket_subscription.py +++ b/tests/test_aiohttp_websocket_subscription.py @@ -9,7 +9,8 @@ from parse import search from gql import Client, gql -from gql.transport.exceptions import TransportClosed, TransportServerError +from gql.client import AsyncClientSession +from gql.transport.exceptions import TransportConnectionFailed, TransportServerError from .conftest import MS, WebSocketServerHelper from .starwars.schema import StarWarsIntrospection, StarWarsSchema, StarWarsTypeDef @@ -194,7 +195,7 @@ async def keepalive_coro(): @pytest.mark.parametrize("server", [server_countdown], indirect=True) 
@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) async def test_aiohttp_websocket_subscription( - event_loop, aiohttp_client_and_server, subscription_str + aiohttp_client_and_server, subscription_str ): session, server = aiohttp_client_and_server @@ -217,7 +218,7 @@ async def test_aiohttp_websocket_subscription( @pytest.mark.parametrize("server", [server_countdown], indirect=True) @pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) async def test_aiohttp_websocket_subscription_get_execution_result( - event_loop, aiohttp_client_and_server, subscription_str + aiohttp_client_and_server, subscription_str ): session, server = aiohttp_client_and_server @@ -228,6 +229,7 @@ async def test_aiohttp_websocket_subscription_get_execution_result( async for result in session.subscribe(subscription, get_execution_result=True): assert isinstance(result, ExecutionResult) + assert result.data is not None number = result.data["number"] print(f"Number received: {number}") @@ -242,7 +244,7 @@ async def test_aiohttp_websocket_subscription_get_execution_result( @pytest.mark.parametrize("server", [server_countdown], indirect=True) @pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) async def test_aiohttp_websocket_subscription_break( - event_loop, aiohttp_client_and_server, subscription_str + aiohttp_client_and_server, subscription_str ): session, server = aiohttp_client_and_server @@ -250,7 +252,8 @@ async def test_aiohttp_websocket_subscription_break( count = 10 subscription = gql(subscription_str.format(count=count)) - async for result in session.subscribe(subscription): + generator = session.subscribe(subscription) + async for result in generator: number = result["number"] print(f"Number received: {number}") @@ -264,12 +267,15 @@ async def test_aiohttp_websocket_subscription_break( assert count == 5 + # Using aclose here to make it stop cleanly on pypy + await generator.aclose() + @pytest.mark.asyncio 
@pytest.mark.parametrize("server", [server_countdown], indirect=True) @pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) async def test_aiohttp_websocket_subscription_task_cancel( - event_loop, aiohttp_client_and_server, subscription_str + aiohttp_client_and_server, subscription_str ): session, server = aiohttp_client_and_server @@ -277,16 +283,24 @@ async def test_aiohttp_websocket_subscription_task_cancel( count = 10 subscription = gql(subscription_str.format(count=count)) + task_cancelled = False + async def task_coro(): nonlocal count - async for result in session.subscribe(subscription): + nonlocal task_cancelled - number = result["number"] - print(f"Number received: {number}") + try: + async for result in session.subscribe(subscription): - assert number == count + number = result["number"] + print(f"Number received: {number}") - count -= 1 + assert number == count + + count -= 1 + except asyncio.CancelledError: + print("Inside task cancelled") + task_cancelled = True task = asyncio.ensure_future(task_coro()) @@ -302,13 +316,14 @@ async def cancel_task_coro(): await asyncio.gather(task, cancel_task) assert count > 0 + assert task_cancelled is True @pytest.mark.asyncio @pytest.mark.parametrize("server", [server_countdown], indirect=True) @pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) async def test_aiohttp_websocket_subscription_close_transport( - event_loop, aiohttp_client_and_server, subscription_str + aiohttp_client_and_server, subscription_str ): session, _ = aiohttp_client_and_server @@ -373,7 +388,7 @@ async def server_countdown_close_connection_in_middle(ws): ) @pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) async def test_aiohttp_websocket_subscription_server_connection_closed( - event_loop, aiohttp_client_and_server, subscription_str + aiohttp_client_and_server, subscription_str ): session, server = aiohttp_client_and_server @@ -381,7 +396,7 @@ async def 
test_aiohttp_websocket_subscription_server_connection_closed( count = 10 subscription = gql(subscription_str.format(count=count)) - with pytest.raises(ConnectionResetError): + with pytest.raises(TransportConnectionFailed): async for result in session.subscribe(subscription): @@ -397,7 +412,7 @@ async def test_aiohttp_websocket_subscription_server_connection_closed( @pytest.mark.parametrize("server", [server_countdown], indirect=True) @pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) async def test_aiohttp_websocket_subscription_slow_consumer( - event_loop, aiohttp_client_and_server, subscription_str + aiohttp_client_and_server, subscription_str ): session, server = aiohttp_client_and_server @@ -422,17 +437,16 @@ async def test_aiohttp_websocket_subscription_slow_consumer( @pytest.mark.parametrize("server", [server_countdown], indirect=True) @pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) async def test_aiohttp_websocket_subscription_with_operation_name( - event_loop, aiohttp_client_and_server, subscription_str + aiohttp_client_and_server, subscription_str ): session, server = aiohttp_client_and_server count = 10 subscription = gql(subscription_str.format(count=count)) + subscription.operation_name = "CountdownSubscription" - async for result in session.subscribe( - subscription, operation_name="CountdownSubscription" - ): + async for result in session.subscribe(subscription): number = result["number"] print(f"Number received: {number}") @@ -453,7 +467,7 @@ async def test_aiohttp_websocket_subscription_with_operation_name( @pytest.mark.parametrize("server", [server_countdown], indirect=True) @pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) async def test_aiohttp_websocket_subscription_with_keepalive( - event_loop, aiohttp_client_and_server, subscription_str + aiohttp_client_and_server, subscription_str ): session, server = aiohttp_client_and_server @@ -476,7 +490,7 @@ async def 
test_aiohttp_websocket_subscription_with_keepalive( @pytest.mark.parametrize("server", [server_countdown], indirect=True) @pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) async def test_aiohttp_websocket_subscription_with_keepalive_with_timeout_ok( - event_loop, server, subscription_str + server, subscription_str ): from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport @@ -506,7 +520,7 @@ async def test_aiohttp_websocket_subscription_with_keepalive_with_timeout_ok( @pytest.mark.parametrize("server", [server_countdown], indirect=True) @pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) async def test_aiohttp_websocket_subscription_with_keepalive_with_timeout_nok( - event_loop, server, subscription_str + server, subscription_str ): from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport @@ -665,6 +679,7 @@ def test_aiohttp_websocket_subscription_sync_graceful_shutdown( warnings.filterwarnings( "ignore", message="There is no current event loop" ) + assert isinstance(client.session, AsyncClientSession) interrupt_task = asyncio.ensure_future( client.session._generator.athrow(KeyboardInterrupt) ) @@ -674,6 +689,7 @@ def test_aiohttp_websocket_subscription_sync_graceful_shutdown( assert count == 4 # Catch interrupt_task exception to remove warning + assert interrupt_task is not None interrupt_task.exception() # Check that the server received a connection_terminate message last @@ -684,7 +700,7 @@ def test_aiohttp_websocket_subscription_sync_graceful_shutdown( @pytest.mark.parametrize("server", [server_countdown], indirect=True) @pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) async def test_aiohttp_websocket_subscription_running_in_thread( - event_loop, server, subscription_str, run_sync_test + server, subscription_str, run_sync_test ): from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport @@ -708,7 +724,7 @@ def test_code(): assert count == -1 - 
await run_sync_test(event_loop, server, test_code) + await run_sync_test(server, test_code) @pytest.mark.asyncio @@ -722,9 +738,7 @@ def test_code(): {"schema": StarWarsTypeDef}, ], ) -async def test_async_aiohttp_client_validation( - event_loop, server, subscription_str, client_params -): +async def test_async_aiohttp_client_validation(server, subscription_str, client_params): from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport @@ -736,15 +750,13 @@ async def test_async_aiohttp_client_validation( async with client as session: - variable_values = {"ep": "JEDI"} - subscription = gql(subscription_str) + subscription.variable_values = {"ep": "JEDI"} + expected = [] - async for result in session.subscribe( - subscription, variable_values=variable_values, parse_result=False - ): + async for result in session.subscribe(subscription, parse_result=False): review = result["reviewAdded"] expected.append(review) @@ -759,7 +771,7 @@ async def test_async_aiohttp_client_validation( @pytest.mark.asyncio @pytest.mark.parametrize("server", [server_countdown], indirect=True) @pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) -async def test_subscribe_on_closing_transport(event_loop, server, subscription_str): +async def test_subscribe_on_closing_transport(server, subscription_str): from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport @@ -772,19 +784,17 @@ async def test_subscribe_on_closing_transport(event_loop, server, subscription_s subscription = gql(subscription_str.format(count=count)) async with client as session: - session.transport.websocket._writer._closing = True + session.transport.adapter.websocket._writer._closing = True - with pytest.raises(ConnectionResetError) as e: + with pytest.raises(TransportConnectionFailed): async for _ in session.subscribe(subscription): pass - assert e.value.args[0] == "Cannot write to closing transport" - @pytest.mark.asyncio @pytest.mark.parametrize("server", [server_countdown], 
indirect=True) @pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) -async def test_subscribe_on_null_transport(event_loop, server, subscription_str): +async def test_subscribe_on_null_transport(server, subscription_str): from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport @@ -798,9 +808,7 @@ async def test_subscribe_on_null_transport(event_loop, server, subscription_str) async with client as session: - session.transport.websocket = None - with pytest.raises(TransportClosed) as e: + session.transport.adapter.websocket = None + with pytest.raises(TransportConnectionFailed): async for _ in session.subscribe(subscription): pass - - assert e.value.args[0] == "WebSocket connection is closed" diff --git a/tests/test_appsync_auth.py b/tests/test_appsync_auth.py index cb279ae5..94eaed2b 100644 --- a/tests/test_appsync_auth.py +++ b/tests/test_appsync_auth.py @@ -9,28 +9,29 @@ def test_appsync_init_with_minimal_args(fake_session_factory): from gql.transport.appsync_auth import AppSyncIAMAuthentication from gql.transport.appsync_websockets import AppSyncWebsocketsTransport - sample_transport = AppSyncWebsocketsTransport( + transport = AppSyncWebsocketsTransport( url=mock_transport_url, session=fake_session_factory() ) - assert isinstance(sample_transport.auth, AppSyncIAMAuthentication) - assert sample_transport.connect_timeout == 10 - assert sample_transport.close_timeout == 10 - assert sample_transport.ack_timeout == 10 - assert sample_transport.ssl is False - assert sample_transport.connect_args == {} + assert isinstance(transport.auth, AppSyncIAMAuthentication) + assert transport.connect_timeout == 10 + assert transport.close_timeout == 10 + assert transport.ack_timeout == 10 + assert transport.ssl is False + assert transport.connect_args == {} @pytest.mark.botocore def test_appsync_init_with_no_credentials(caplog, fake_session_factory): import botocore.exceptions + from gql.transport.appsync_websockets import 
AppSyncWebsocketsTransport with pytest.raises(botocore.exceptions.NoCredentialsError): - sample_transport = AppSyncWebsocketsTransport( + transport = AppSyncWebsocketsTransport( url=mock_transport_url, session=fake_session_factory(credentials=None), ) - assert sample_transport.auth is None + assert transport.auth is None expected_error = "Credentials not found" @@ -45,8 +46,8 @@ def test_appsync_init_with_jwt_auth(): from gql.transport.appsync_websockets import AppSyncWebsocketsTransport auth = AppSyncJWTAuthentication(host=mock_transport_host, jwt="some-jwt") - sample_transport = AppSyncWebsocketsTransport(url=mock_transport_url, auth=auth) - assert sample_transport.auth is auth + transport = AppSyncWebsocketsTransport(url=mock_transport_url, auth=auth) + assert transport.auth is auth assert auth.get_headers() == { "host": mock_transport_host, @@ -60,8 +61,8 @@ def test_appsync_init_with_apikey_auth(): from gql.transport.appsync_websockets import AppSyncWebsocketsTransport auth = AppSyncApiKeyAuthentication(host=mock_transport_host, api_key="some-api-key") - sample_transport = AppSyncWebsocketsTransport(url=mock_transport_url, auth=auth) - assert sample_transport.auth is auth + transport = AppSyncWebsocketsTransport(url=mock_transport_url, auth=auth) + assert transport.auth is auth assert auth.get_headers() == { "host": mock_transport_host, @@ -72,6 +73,7 @@ def test_appsync_init_with_apikey_auth(): @pytest.mark.botocore def test_appsync_init_with_iam_auth_without_creds(fake_session_factory): import botocore.exceptions + from gql.transport.appsync_auth import AppSyncIAMAuthentication from gql.transport.appsync_websockets import AppSyncWebsocketsTransport @@ -93,8 +95,8 @@ def test_appsync_init_with_iam_auth_with_creds(fake_credentials_factory): credentials=fake_credentials_factory(), region_name="us-east-1", ) - sample_transport = AppSyncWebsocketsTransport(url=mock_transport_url, auth=auth) - assert sample_transport.auth is auth + transport = 
AppSyncWebsocketsTransport(url=mock_transport_url, auth=auth) + assert transport.auth is auth @pytest.mark.botocore @@ -108,10 +110,13 @@ def test_appsync_init_with_iam_auth_and_no_region( - you have the AWS_DEFAULT_REGION environment variable set """ - from gql.transport.appsync_websockets import AppSyncWebsocketsTransport - from botocore.exceptions import NoRegionError import logging + from botocore.exceptions import NoRegionError + + from gql.transport.appsync_auth import AppSyncIAMAuthentication + from gql.transport.appsync_websockets import AppSyncWebsocketsTransport + caplog.set_level(logging.WARNING) with pytest.raises(NoRegionError): @@ -120,6 +125,8 @@ def test_appsync_init_with_iam_auth_and_no_region( session._credentials.region = None transport = AppSyncWebsocketsTransport(url=mock_transport_url, session=session) + assert isinstance(transport.auth, AppSyncIAMAuthentication) + # prints the region name in case the test fails print(f"Region found: {transport.auth._region_name}") @@ -146,7 +153,7 @@ def test_munge_url(fake_signer_factory, fake_request_factory): signer=fake_signer_factory(), request_creator=fake_request_factory, ) - sample_transport = AppSyncWebsocketsTransport(url=test_url, auth=auth) + transport = AppSyncWebsocketsTransport(url=test_url, auth=auth) header_string = ( "eyJGYWtlQXV0aG9yaXphdGlvbiI6ImEiLCJGYWtlVGltZSI6InRvZGF5" @@ -157,7 +164,7 @@ def test_munge_url(fake_signer_factory, fake_request_factory): "wss://appsync-realtime-api.aws.example.org/" f"some-other-params?header={header_string}&payload=e30=" ) - assert sample_transport.url == expected_url + assert transport.url == expected_url @pytest.mark.botocore diff --git a/tests/test_appsync_http.py b/tests/test_appsync_http.py index ca3a3fcb..168924bc 100644 --- a/tests/test_appsync_http.py +++ b/tests/test_appsync_http.py @@ -8,13 +8,13 @@ @pytest.mark.asyncio @pytest.mark.aiohttp @pytest.mark.botocore -async def test_appsync_iam_mutation( - event_loop, aiohttp_server, 
fake_credentials_factory -): +async def test_appsync_iam_mutation(aiohttp_server, fake_credentials_factory): + from urllib.parse import urlparse + from aiohttp import web + from gql.transport.aiohttp import AIOHTTPTransport from gql.transport.appsync_auth import AppSyncIAMAuthentication - from urllib.parse import urlparse async def handler(request): data = { @@ -49,9 +49,9 @@ async def handler(request): region_name="us-east-1", ) - sample_transport = AIOHTTPTransport(url=url, auth=auth) + transport = AIOHTTPTransport(url=url, auth=auth) - async with Client(transport=sample_transport) as session: + async with Client(transport=transport) as session: query = gql( """ diff --git a/tests/test_appsync_websockets.py b/tests/test_appsync_websockets.py index 88bae8b6..b2299960 100644 --- a/tests/test_appsync_websockets.py +++ b/tests/test_appsync_websockets.py @@ -139,7 +139,7 @@ async def realtime_appsync_server_template(ws): ) return - path = ws.path + path = ws.request.path print(f"path = {path}") @@ -404,7 +404,7 @@ async def default_transport_test(transport): @pytest.mark.asyncio @pytest.mark.parametrize("server", [realtime_appsync_server_keepalive], indirect=True) -async def test_appsync_subscription_api_key(event_loop, server): +async def test_appsync_subscription_api_key(server): from gql.transport.appsync_auth import AppSyncApiKeyAuthentication from gql.transport.appsync_websockets import AppSyncWebsocketsTransport @@ -424,11 +424,12 @@ async def test_appsync_subscription_api_key(event_loop, server): @pytest.mark.asyncio @pytest.mark.botocore @pytest.mark.parametrize("server", [realtime_appsync_server], indirect=True) -async def test_appsync_subscription_iam_with_token(event_loop, server): +async def test_appsync_subscription_iam_with_token(server): + + from botocore.credentials import Credentials from gql.transport.appsync_auth import AppSyncIAMAuthentication from gql.transport.appsync_websockets import AppSyncWebsocketsTransport - from botocore.credentials import 
Credentials path = "/graphql" url = f"ws://{server.hostname}:{server.port}{path}" @@ -451,11 +452,12 @@ async def test_appsync_subscription_iam_with_token(event_loop, server): @pytest.mark.asyncio @pytest.mark.botocore @pytest.mark.parametrize("server", [realtime_appsync_server], indirect=True) -async def test_appsync_subscription_iam_without_token(event_loop, server): +async def test_appsync_subscription_iam_without_token(server): + + from botocore.credentials import Credentials from gql.transport.appsync_auth import AppSyncIAMAuthentication from gql.transport.appsync_websockets import AppSyncWebsocketsTransport - from botocore.credentials import Credentials path = "/graphql" url = f"ws://{server.hostname}:{server.port}{path}" @@ -477,11 +479,12 @@ async def test_appsync_subscription_iam_without_token(event_loop, server): @pytest.mark.asyncio @pytest.mark.botocore @pytest.mark.parametrize("server", [realtime_appsync_server], indirect=True) -async def test_appsync_execute_method_not_allowed(event_loop, server): +async def test_appsync_execute_method_not_allowed(server): + + from botocore.credentials import Credentials from gql.transport.appsync_auth import AppSyncIAMAuthentication from gql.transport.appsync_websockets import AppSyncWebsocketsTransport - from botocore.credentials import Credentials path = "/graphql" url = f"ws://{server.hostname}:{server.port}{path}" @@ -511,10 +514,10 @@ async def test_appsync_execute_method_not_allowed(event_loop, server): }""" ) - variable_values = {"message": "Hello world!"} + query.variable_values = {"message": "Hello world!"} with pytest.raises(AssertionError) as exc_info: - await session.execute(query, variable_values=variable_values) + await session.execute(query) assert ( "execute method is not allowed for AppSyncWebsocketsTransport " @@ -524,11 +527,12 @@ async def test_appsync_execute_method_not_allowed(event_loop, server): @pytest.mark.asyncio @pytest.mark.botocore -async def 
test_appsync_fetch_schema_from_transport_not_allowed(event_loop): +async def test_appsync_fetch_schema_from_transport_not_allowed(): + + from botocore.credentials import Credentials from gql.transport.appsync_auth import AppSyncIAMAuthentication from gql.transport.appsync_websockets import AppSyncWebsocketsTransport - from botocore.credentials import Credentials dummy_credentials = Credentials( access_key=DUMMY_ACCESS_KEY_ID, @@ -552,7 +556,7 @@ async def test_appsync_fetch_schema_from_transport_not_allowed(event_loop): @pytest.mark.asyncio @pytest.mark.parametrize("server", [realtime_appsync_server], indirect=True) -async def test_appsync_subscription_api_key_unauthorized(event_loop, server): +async def test_appsync_subscription_api_key_unauthorized(server): from gql.transport.appsync_auth import AppSyncApiKeyAuthentication from gql.transport.appsync_websockets import AppSyncWebsocketsTransport @@ -577,12 +581,13 @@ async def test_appsync_subscription_api_key_unauthorized(event_loop, server): @pytest.mark.asyncio @pytest.mark.botocore @pytest.mark.parametrize("server", [realtime_appsync_server], indirect=True) -async def test_appsync_subscription_iam_not_allowed(event_loop, server): +async def test_appsync_subscription_iam_not_allowed(server): + + from botocore.credentials import Credentials from gql.transport.appsync_auth import AppSyncIAMAuthentication from gql.transport.appsync_websockets import AppSyncWebsocketsTransport from gql.transport.exceptions import TransportQueryError - from botocore.credentials import Credentials path = "/graphql" url = f"ws://{server.hostname}:{server.port}{path}" @@ -616,9 +621,7 @@ async def test_appsync_subscription_iam_not_allowed(event_loop, server): @pytest.mark.parametrize( "server", [realtime_appsync_server_not_json_answer], indirect=True ) -async def test_appsync_subscription_server_sending_a_not_json_answer( - event_loop, server -): +async def test_appsync_subscription_server_sending_a_not_json_answer(server): from 
gql.transport.appsync_auth import AppSyncApiKeyAuthentication from gql.transport.appsync_websockets import AppSyncWebsocketsTransport @@ -644,9 +647,7 @@ async def test_appsync_subscription_server_sending_a_not_json_answer( @pytest.mark.parametrize( "server", [realtime_appsync_server_error_without_id], indirect=True ) -async def test_appsync_subscription_server_sending_an_error_without_an_id( - event_loop, server -): +async def test_appsync_subscription_server_sending_an_error_without_an_id(server): from gql.transport.appsync_auth import AppSyncApiKeyAuthentication from gql.transport.appsync_websockets import AppSyncWebsocketsTransport @@ -670,9 +671,7 @@ async def test_appsync_subscription_server_sending_an_error_without_an_id( @pytest.mark.asyncio @pytest.mark.parametrize("server", [realtime_appsync_server_keepalive], indirect=True) -async def test_appsync_subscription_variable_values_and_operation_name( - event_loop, server -): +async def test_appsync_subscription_variable_values_and_operation_name(server): from gql.transport.appsync_auth import AppSyncApiKeyAuthentication from gql.transport.appsync_websockets import AppSyncWebsocketsTransport @@ -694,10 +693,11 @@ async def test_appsync_subscription_variable_values_and_operation_name( async with client as session: subscription = gql(on_create_message_subscription_str) + subscription.variable_values = {"key1": "val1"} + subscription.operation_name = "onCreateMessage" + async for execution_result in session.subscribe( subscription, - operation_name="onCreateMessage", - variable_values={"key1": "val1"}, get_execution_result=True, ): diff --git a/tests/test_async_client_validation.py b/tests/test_async_client_validation.py index acfabe0e..ec73593e 100644 --- a/tests/test_async_client_validation.py +++ b/tests/test_async_client_validation.py @@ -85,29 +85,25 @@ async def server_starwars(ws): {"schema": StarWarsTypeDef}, ], ) -async def test_async_client_validation( - event_loop, server, subscription_str, 
client_params -): +async def test_async_client_validation(server, subscription_str, client_params): from gql.transport.websockets import WebsocketsTransport url = f"ws://{server.hostname}:{server.port}/graphql" - sample_transport = WebsocketsTransport(url=url) + transport = WebsocketsTransport(url=url) - client = Client(transport=sample_transport, **client_params) + client = Client(transport=transport, **client_params) async with client as session: - variable_values = {"ep": "JEDI"} - subscription = gql(subscription_str) + subscription.variable_values = {"ep": "JEDI"} + expected = [] - async for result in session.subscribe( - subscription, variable_values=variable_values, parse_result=False - ): + async for result in session.subscribe(subscription, parse_result=False): review = result["reviewAdded"] expected.append(review) @@ -133,27 +129,25 @@ async def test_async_client_validation( ], ) async def test_async_client_validation_invalid_query( - event_loop, server, subscription_str, client_params + server, subscription_str, client_params ): from gql.transport.websockets import WebsocketsTransport url = f"ws://{server.hostname}:{server.port}/graphql" - sample_transport = WebsocketsTransport(url=url) + transport = WebsocketsTransport(url=url) - client = Client(transport=sample_transport, **client_params) + client = Client(transport=transport, **client_params) async with client as session: - variable_values = {"ep": "JEDI"} - subscription = gql(subscription_str) + subscription.variable_values = {"ep": "JEDI"} + with pytest.raises(graphql.error.GraphQLError): - async for _result in session.subscribe( - subscription, variable_values=variable_values - ): + async for _result in session.subscribe(subscription): pass @@ -166,17 +160,17 @@ async def test_async_client_validation_invalid_query( [{"schema": StarWarsSchema, "introspection": StarWarsIntrospection}], ) async def test_async_client_validation_different_schemas_parameters_forbidden( - event_loop, server, 
subscription_str, client_params + server, subscription_str, client_params ): from gql.transport.websockets import WebsocketsTransport url = f"ws://{server.hostname}:{server.port}/graphql" - sample_transport = WebsocketsTransport(url=url) + transport = WebsocketsTransport(url=url) with pytest.raises(AssertionError): - async with Client(transport=sample_transport, **client_params): + async with Client(transport=transport, **client_params): pass @@ -192,7 +186,7 @@ async def test_async_client_validation_different_schemas_parameters_forbidden( @pytest.mark.asyncio @pytest.mark.parametrize("server", [hero_server_answers], indirect=True) async def test_async_client_validation_fetch_schema_from_server_valid_query( - event_loop, client_and_server + client_and_server, ): session, server = client_and_server client = session.client @@ -230,7 +224,7 @@ async def test_async_client_validation_fetch_schema_from_server_valid_query( @pytest.mark.asyncio @pytest.mark.parametrize("server", [hero_server_answers], indirect=True) async def test_async_client_validation_fetch_schema_from_server_invalid_query( - event_loop, client_and_server + client_and_server, ): session, server = client_and_server @@ -256,17 +250,17 @@ async def test_async_client_validation_fetch_schema_from_server_invalid_query( @pytest.mark.asyncio @pytest.mark.parametrize("server", [hero_server_answers], indirect=True) async def test_async_client_validation_fetch_schema_from_server_with_client_argument( - event_loop, server + server, ): from gql.transport.websockets import WebsocketsTransport url = f"ws://{server.hostname}:{server.port}/graphql" - sample_transport = WebsocketsTransport(url=url) + transport = WebsocketsTransport(url=url) async with Client( - transport=sample_transport, + transport=transport, fetch_schema_from_transport=True, ) as session: diff --git a/tests/test_cli.py b/tests/test_cli.py index dccfcb5a..4c6b7d15 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -286,8 +286,8 @@ async def 
test_cli_main_appsync_websockets_iam(parser, url): ) def test_cli_get_transport_appsync_websockets_api_key(parser, url): - from gql.transport.appsync_websockets import AppSyncWebsocketsTransport from gql.transport.appsync_auth import AppSyncApiKeyAuthentication + from gql.transport.appsync_websockets import AppSyncWebsocketsTransport args = parser.parse_args( [url, "--transport", "appsync_websockets", "--api-key", "test-api-key"] @@ -307,8 +307,8 @@ def test_cli_get_transport_appsync_websockets_api_key(parser, url): ) def test_cli_get_transport_appsync_websockets_jwt(parser, url): - from gql.transport.appsync_websockets import AppSyncWebsocketsTransport from gql.transport.appsync_auth import AppSyncJWTAuthentication + from gql.transport.appsync_websockets import AppSyncWebsocketsTransport args = parser.parse_args( [url, "--transport", "appsync_websockets", "--jwt", "test-jwt"] diff --git a/tests/test_client.py b/tests/test_client.py index 1e794558..4e2e9bca 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -1,9 +1,10 @@ import os from contextlib import suppress +from typing import Any from unittest import mock import pytest -from graphql import build_ast_schema, parse +from graphql import ExecutionResult, build_ast_schema, parse from gql import Client, GraphQLRequest, gql from gql.transport import Transport @@ -29,35 +30,30 @@ def http_transport_query(): def test_request_transport_not_implemented(http_transport_query): class RandomTransport(Transport): - def execute(self): - super().execute(http_transport_query) + pass + + with pytest.raises(TypeError) as exc_info: + RandomTransport() # type: ignore - with pytest.raises(NotImplementedError) as exc_info: - RandomTransport().execute() + assert "Can't instantiate abstract class RandomTransport" in str(exc_info.value) - assert "Any Transport subclass must implement execute method" == str(exc_info.value) + class RandomTransport2(Transport): + def execute( + self, + request: GraphQLRequest, + *args: Any, + 
**kwargs: Any, + ) -> ExecutionResult: + return ExecutionResult() - with pytest.raises(NotImplementedError) as exc_info: - RandomTransport().execute_batch([]) + with pytest.raises(NotImplementedError) as exc_info2: + RandomTransport2().execute_batch([]) assert "This Transport has not implemented the execute_batch method" == str( - exc_info.value + exc_info2.value ) -@pytest.mark.aiohttp -def test_request_async_execute_batch_not_implemented_yet(): - from gql.transport.aiohttp import AIOHTTPTransport - - transport = AIOHTTPTransport(url="http://localhost/") - client = Client(transport=transport) - - with pytest.raises(NotImplementedError) as exc_info: - client.execute_batch([GraphQLRequest(document=gql("{dummy}"))]) - - assert "Batching is not implemented for async yet." == str(exc_info.value) - - @pytest.mark.requests @mock.patch("urllib3.connection.HTTPConnection._new_conn") def test_retries_on_transport(execute_mock): @@ -70,7 +66,7 @@ def test_retries_on_transport(execute_mock): expected_retries = 3 execute_mock.side_effect = NewConnectionError( - "Should be HTTPConnection", "Fake connection error" + "Should be HTTPConnection", "Fake connection error" # type: ignore ) transport = RequestsHTTPTransport( url="http://127.0.0.1:8000/graphql", @@ -98,7 +94,7 @@ def test_retries_on_transport(execute_mock): assert execute_mock.call_count == expected_retries + 1 execute_mock.reset_mock() - queries = map(lambda d: GraphQLRequest(document=d), [query, query, query]) + queries = [query, query, query] with client as session: # We're using the client as context manager with pytest.raises(Exception): @@ -109,11 +105,10 @@ def test_retries_on_transport(execute_mock): assert execute_mock.call_count == expected_retries + 1 -def test_no_schema_exception(): +def test_no_schema_no_transport_exception(): with pytest.raises(AssertionError) as exc_info: - client = Client() - client.validate("") - assert "Cannot validate the document locally, you need to pass a schema." 
in str( + Client() + assert "You need to provide either a transport or a schema to the Client." in str( exc_info.value ) @@ -148,7 +143,7 @@ def test_execute_result_error(): Batching is not supported anymore on countries backend with pytest.raises(TransportQueryError) as exc_info: - client.execute_batch([GraphQLRequest(document=failing_query)]) + client.execute_batch([GraphQLRequest(failing_query)]) assert 'Cannot query field "id" on type "Continent".' in str(exc_info.value) """ @@ -176,7 +171,7 @@ def test_http_transport_verify_error(http_transport_query): Batching is not supported anymore on countries backend with pytest.warns(Warning) as record: - client.execute_batch([GraphQLRequest(document=http_transport_query)]) + client.execute_batch([GraphQLRequest(http_transport_query)]) assert len(record) == 1 assert "Unverified HTTPS request is being made to host" in str( @@ -202,7 +197,7 @@ def test_http_transport_specify_method_valid(http_transport_query): """ Batching is not supported anymore on countries backend - result = client.execute_batch([GraphQLRequest(document=http_transport_query)]) + result = client.execute_batch([GraphQLRequest(http_transport_query)]) assert result is not None """ @@ -255,6 +250,7 @@ def test_sync_transport_close_on_schema_retrieval_failure(): # transport is closed afterwards pass + assert isinstance(client.transport, RequestsHTTPTransport) assert client.transport.session is None @@ -279,4 +275,9 @@ async def test_async_transport_close_on_schema_retrieval_failure(): # transport is closed afterwards pass + assert isinstance(client.transport, AIOHTTPTransport) assert client.transport.session is None + + import asyncio + + await asyncio.sleep(1) diff --git a/tests/test_graphql_request.py b/tests/test_graphql_request.py index 4c9e7d76..ea255c7d 100644 --- a/tests/test_graphql_request.py +++ b/tests/test_graphql_request.py @@ -18,9 +18,9 @@ ) from graphql.utilities import value_from_ast_untyped -from gql import GraphQLRequest, gql +from gql 
import GraphQLRequest -from .conftest import MS +from .conftest import MS, strip_braces_spaces # Marking all tests in this file with the aiohttp marker pytestmark = pytest.mark.aiohttp @@ -188,15 +188,51 @@ async def subscribe_spend_all(_root, _info, money): def test_serialize_variables_using_money_example(): - req = GraphQLRequest(document=gql("{balance}")) + req = GraphQLRequest("{balance}") money_value = Money(10, "DM") req = GraphQLRequest( - document=gql("query myquery($money: Money) {toEuros(money: $money)}"), + "query myquery($money: Money) {toEuros(money: $money)}", variable_values={"money": money_value}, ) req = req.serialize_variable_values(schema) assert req.variable_values == {"money": {"amount": 10, "currency": "DM"}} + + +def test_graphql_request_using_string_instead_of_document(): + request = GraphQLRequest("{balance}") + + expected_payload = "{'query': '{\\n balance\\n}'}" + + print(request) + + assert str(request) == strip_braces_spaces(expected_payload) + + +def test_graphql_request_init_with_graphql_request(): + money_value_1 = Money(10, "DM") + money_value_2 = Money(20, "DM") + + request_1 = GraphQLRequest( + "query myquery($money: Money) {toEuros(money: $money)}", + variable_values={"money": money_value_1}, + ) + request_2 = GraphQLRequest( + request_1, + ) + request_3 = GraphQLRequest( + request_1, + variable_values={"money": money_value_2}, + ) + + assert request_1.document == request_2.document + assert request_2.document == request_3.document + assert isinstance(request_1.variable_values, Dict) + assert isinstance(request_2.variable_values, Dict) + assert isinstance(request_3.variable_values, Dict) + assert request_1.variable_values["money"] == money_value_1 + assert request_2.variable_values["money"] == money_value_1 + assert request_3.variable_values["money"] == money_value_2 diff --git a/tests/test_graphqlws_exceptions.py b/tests/test_graphqlws_exceptions.py index befeeb4e..6f30c8da 100644 --- a/tests/test_graphqlws_exceptions.py +++ 
b/tests/test_graphqlws_exceptions.py @@ -5,7 +5,7 @@ from gql import Client, gql from gql.transport.exceptions import ( - TransportClosed, + TransportConnectionFailed, TransportProtocolError, TransportQueryError, ) @@ -38,9 +38,7 @@ @pytest.mark.asyncio @pytest.mark.parametrize("graphqlws_server", [invalid_query1_server], indirect=True) @pytest.mark.parametrize("query_str", [invalid_query_str]) -async def test_graphqlws_invalid_query( - event_loop, client_and_graphqlws_server, query_str -): +async def test_graphqlws_invalid_query(client_and_graphqlws_server, query_str): session, server = client_and_graphqlws_server @@ -81,9 +79,7 @@ async def server_invalid_subscription(ws): "graphqlws_server", [server_invalid_subscription], indirect=True ) @pytest.mark.parametrize("query_str", [invalid_subscription_str]) -async def test_graphqlws_invalid_subscription( - event_loop, client_and_graphqlws_server, query_str -): +async def test_graphqlws_invalid_subscription(client_and_graphqlws_server, query_str): session, server = client_and_graphqlws_server @@ -109,17 +105,15 @@ async def server_no_ack(ws): @pytest.mark.asyncio @pytest.mark.parametrize("graphqlws_server", [server_no_ack], indirect=True) @pytest.mark.parametrize("query_str", [invalid_query_str]) -async def test_graphqlws_server_does_not_send_ack( - event_loop, graphqlws_server, query_str -): +async def test_graphqlws_server_does_not_send_ack(graphqlws_server, query_str): from gql.transport.websockets import WebsocketsTransport url = f"ws://{graphqlws_server.hostname}:{graphqlws_server.port}/graphql" - sample_transport = WebsocketsTransport(url=url, ack_timeout=1) + transport = WebsocketsTransport(url=url, ack_timeout=0.1) with pytest.raises(asyncio.TimeoutError): - async with Client(transport=sample_transport): + async with Client(transport=transport): pass @@ -141,7 +135,7 @@ async def server_invalid_query(ws): @pytest.mark.asyncio @pytest.mark.parametrize("graphqlws_server", [server_invalid_query], indirect=True) 
-async def test_graphqlws_sending_invalid_query(event_loop, client_and_graphqlws_server): +async def test_graphqlws_sending_invalid_query(client_and_graphqlws_server): session, server = client_and_graphqlws_server @@ -193,9 +187,7 @@ async def test_graphqlws_sending_invalid_query(event_loop, client_and_graphqlws_ ], indirect=True, ) -async def test_graphqlws_transport_protocol_errors( - event_loop, client_and_graphqlws_server -): +async def test_graphqlws_transport_protocol_errors(client_and_graphqlws_server): session, server = client_and_graphqlws_server @@ -213,16 +205,16 @@ async def server_without_ack(ws): @pytest.mark.asyncio @pytest.mark.parametrize("graphqlws_server", [server_without_ack], indirect=True) -async def test_graphqlws_server_does_not_ack(event_loop, graphqlws_server): +async def test_graphqlws_server_does_not_ack(graphqlws_server): from gql.transport.websockets import WebsocketsTransport url = f"ws://{graphqlws_server.hostname}:{graphqlws_server.port}/graphql" print(f"url = {url}") - sample_transport = WebsocketsTransport(url=url) + transport = WebsocketsTransport(url=url) with pytest.raises(TransportProtocolError): - async with Client(transport=sample_transport): + async with Client(transport=transport): pass @@ -232,17 +224,16 @@ async def server_closing_directly(ws): @pytest.mark.asyncio @pytest.mark.parametrize("graphqlws_server", [server_closing_directly], indirect=True) -async def test_graphqlws_server_closing_directly(event_loop, graphqlws_server): - import websockets +async def test_graphqlws_server_closing_directly(graphqlws_server): from gql.transport.websockets import WebsocketsTransport url = f"ws://{graphqlws_server.hostname}:{graphqlws_server.port}/graphql" print(f"url = {url}") - sample_transport = WebsocketsTransport(url=url) + transport = WebsocketsTransport(url=url) - with pytest.raises(websockets.exceptions.ConnectionClosed): - async with Client(transport=sample_transport): + with pytest.raises(TransportConnectionFailed): + 
async with Client(transport=transport): pass @@ -253,20 +244,38 @@ async def server_closing_after_ack(ws): @pytest.mark.asyncio @pytest.mark.parametrize("graphqlws_server", [server_closing_after_ack], indirect=True) -async def test_graphqlws_server_closing_after_ack( - event_loop, client_and_graphqlws_server -): - - import websockets +async def test_graphqlws_server_closing_after_ack(client_and_graphqlws_server): session, server = client_and_graphqlws_server query = gql("query { hello }") - with pytest.raises(websockets.exceptions.ConnectionClosed): + print("\n Trying to execute first query.\n") + + with pytest.raises(TransportConnectionFailed) as exc1: await session.execute(query) + exc1_cause = exc1.value.__cause__ + exc1_cause_str = f"{type(exc1_cause).__name__}:{exc1_cause!s}" + + print(f"\n First query Exception cause: {exc1_cause_str}\n") + + assert ( + exc1_cause_str == "ConnectionClosedOK:received 1000 (OK); then sent 1000 (OK)" + ) + await session.transport.wait_closed() - with pytest.raises(TransportClosed): + print("\n Trying to execute second query.\n") + + with pytest.raises(TransportConnectionFailed) as exc2: await session.execute(query) + + exc2_cause = exc2.value.__cause__ + exc2_cause_str = f"{type(exc2_cause).__name__}:{exc2_cause!s}" + + print(f" Second query Exception cause: {exc2_cause_str}\n") + + assert ( + exc2_cause_str == "ConnectionClosedOK:received 1000 (OK); then sent 1000 (OK)" + ) diff --git a/tests/test_graphqlws_subscription.py b/tests/test_graphqlws_subscription.py index 683da43a..416726aa 100644 --- a/tests/test_graphqlws_subscription.py +++ b/tests/test_graphqlws_subscription.py @@ -8,7 +8,8 @@ from parse import search from gql import Client, gql -from gql.transport.exceptions import TransportServerError +from gql.client import AsyncClientSession +from gql.transport.exceptions import TransportConnectionFailed, TransportServerError from .conftest import MS, WebSocketServerHelper @@ -228,9 +229,7 @@ async def 
server_countdown_disconnect(ws): @pytest.mark.asyncio @pytest.mark.parametrize("graphqlws_server", [server_countdown], indirect=True) @pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) -async def test_graphqlws_subscription( - event_loop, client_and_graphqlws_server, subscription_str -): +async def test_graphqlws_subscription(client_and_graphqlws_server, subscription_str): session, server = client_and_graphqlws_server @@ -252,7 +251,7 @@ async def test_graphqlws_subscription( @pytest.mark.parametrize("graphqlws_server", [server_countdown], indirect=True) @pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) async def test_graphqlws_subscription_break( - event_loop, client_and_graphqlws_server, subscription_str + client_and_graphqlws_server, subscription_str ): session, server = client_and_graphqlws_server @@ -260,7 +259,8 @@ async def test_graphqlws_subscription_break( count = 10 subscription = gql(subscription_str.format(count=count)) - async for result in session.subscribe(subscription): + generator = session.subscribe(subscription) + async for result in generator: number = result["number"] print(f"Number received: {number}") @@ -274,12 +274,15 @@ async def test_graphqlws_subscription_break( assert count == 5 + # Using aclose here to make it stop cleanly on pypy + await generator.aclose() + @pytest.mark.asyncio @pytest.mark.parametrize("graphqlws_server", [server_countdown], indirect=True) @pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) async def test_graphqlws_subscription_task_cancel( - event_loop, client_and_graphqlws_server, subscription_str + client_and_graphqlws_server, subscription_str ): session, server = client_and_graphqlws_server @@ -287,16 +290,24 @@ async def test_graphqlws_subscription_task_cancel( count = 10 subscription = gql(subscription_str.format(count=count)) + task_cancelled = False + async def task_coro(): nonlocal count - async for result in 
session.subscribe(subscription): + nonlocal task_cancelled - number = result["number"] - print(f"Number received: {number}") + try: + async for result in session.subscribe(subscription): - assert number == count + number = result["number"] + print(f"Number received: {number}") - count -= 1 + assert number == count + + count -= 1 + except asyncio.CancelledError: + print("Inside task cancelled") + task_cancelled = True task = asyncio.ensure_future(task_coro()) @@ -312,13 +323,14 @@ async def cancel_task_coro(): await asyncio.gather(task, cancel_task) assert count > 0 + assert task_cancelled is True @pytest.mark.asyncio @pytest.mark.parametrize("graphqlws_server", [server_countdown], indirect=True) @pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) async def test_graphqlws_subscription_close_transport( - event_loop, client_and_graphqlws_server, subscription_str + client_and_graphqlws_server, subscription_str ): session, server = client_and_graphqlws_server @@ -383,16 +395,14 @@ async def server_countdown_close_connection_in_middle(ws): ) @pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) async def test_graphqlws_subscription_server_connection_closed( - event_loop, client_and_graphqlws_server, subscription_str + client_and_graphqlws_server, subscription_str ): - import websockets - session, server = client_and_graphqlws_server count = 10 subscription = gql(subscription_str.format(count=count)) - with pytest.raises(websockets.exceptions.ConnectionClosedOK): + with pytest.raises(TransportConnectionFailed): async for result in session.subscribe(subscription): @@ -408,17 +418,16 @@ async def test_graphqlws_subscription_server_connection_closed( @pytest.mark.parametrize("graphqlws_server", [server_countdown], indirect=True) @pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) async def test_graphqlws_subscription_with_operation_name( - event_loop, client_and_graphqlws_server, subscription_str + 
client_and_graphqlws_server, subscription_str ): session, server = client_and_graphqlws_server count = 10 subscription = gql(subscription_str.format(count=count)) + subscription.operation_name = "CountdownSubscription" - async for result in session.subscribe( - subscription, operation_name="CountdownSubscription" - ): + async for result in session.subscribe(subscription): number = result["number"] print(f"Number received: {number}") @@ -438,7 +447,7 @@ async def test_graphqlws_subscription_with_operation_name( ) @pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) async def test_graphqlws_subscription_with_keepalive( - event_loop, client_and_graphqlws_server, subscription_str + client_and_graphqlws_server, subscription_str ): session, server = client_and_graphqlws_server @@ -468,7 +477,7 @@ async def test_graphqlws_subscription_with_keepalive( ) @pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) async def test_graphqlws_subscription_with_keepalive_with_timeout_ok( - event_loop, graphqlws_server, subscription_str + graphqlws_server, subscription_str ): from gql.transport.websockets import WebsocketsTransport @@ -500,7 +509,7 @@ async def test_graphqlws_subscription_with_keepalive_with_timeout_ok( ) @pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) async def test_graphqlws_subscription_with_keepalive_with_timeout_nok( - event_loop, graphqlws_server, subscription_str + graphqlws_server, subscription_str ): from gql.transport.websockets import WebsocketsTransport @@ -533,7 +542,7 @@ async def test_graphqlws_subscription_with_keepalive_with_timeout_nok( ) @pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) async def test_graphqlws_subscription_with_ping_interval_ok( - event_loop, graphqlws_server, subscription_str + graphqlws_server, subscription_str ): from gql.transport.websockets import WebsocketsTransport @@ -569,7 +578,7 @@ async def 
test_graphqlws_subscription_with_ping_interval_ok( ) @pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) async def test_graphqlws_subscription_with_ping_interval_nok( - event_loop, graphqlws_server, subscription_str + graphqlws_server, subscription_str ): from gql.transport.websockets import WebsocketsTransport @@ -602,7 +611,7 @@ async def test_graphqlws_subscription_with_ping_interval_nok( ) @pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) async def test_graphqlws_subscription_manual_pings_with_payload( - event_loop, graphqlws_server, subscription_str + graphqlws_server, subscription_str ): from gql.transport.websockets import WebsocketsTransport @@ -644,7 +653,7 @@ async def test_graphqlws_subscription_manual_pings_with_payload( ) @pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) async def test_graphqlws_subscription_manual_pong_answers_with_payload( - event_loop, graphqlws_server, subscription_str + graphqlws_server, subscription_str ): from gql.transport.websockets import WebsocketsTransport @@ -757,6 +766,7 @@ def test_graphqlws_subscription_sync_graceful_shutdown( warnings.filterwarnings( "ignore", message="There is no current event loop" ) + assert isinstance(client.session, AsyncClientSession) asyncio.ensure_future( client.session._generator.athrow(KeyboardInterrupt) ) @@ -775,7 +785,7 @@ def test_graphqlws_subscription_sync_graceful_shutdown( ) @pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) async def test_graphqlws_subscription_running_in_thread( - event_loop, graphqlws_server, subscription_str, run_sync_test + graphqlws_server, subscription_str, run_sync_test ): from gql.transport.websockets import WebsocketsTransport @@ -799,7 +809,7 @@ def test_code(): assert count == -1 - await run_sync_test(event_loop, graphqlws_server, test_code) + await run_sync_test(graphqlws_server, test_code) @pytest.mark.asyncio @@ -809,12 +819,10 @@ def test_code(): 
@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) @pytest.mark.parametrize("execute_instead_of_subscribe", [False, True]) async def test_graphqlws_subscription_reconnecting_session( - event_loop, graphqlws_server, subscription_str, execute_instead_of_subscribe + graphqlws_server, subscription_str, execute_instead_of_subscribe ): - import websockets from gql.transport.websockets import WebsocketsTransport - from gql.transport.exceptions import TransportClosed path = "/graphql" url = f"ws://{graphqlws_server.hostname}:{graphqlws_server.port}{path}" @@ -832,44 +840,62 @@ async def test_graphqlws_subscription_reconnecting_session( reconnecting=True, retry_connect=False, retry_execute=False ) - # First we make a subscription which will cause a disconnect in the backend - # (count=8) - try: - print("\nSUBSCRIPTION_1_WITH_DISCONNECT\n") - async for result in session.subscribe(subscription_with_disconnect): - pass - except websockets.exceptions.ConnectionClosedOK: - pass - - await asyncio.sleep(50 * MS) - - # Then with the same session handle, we make a subscription or an execute - # which will detect that the transport is closed so that the client could - # try to reconnect + # First we make a query or subscription which will cause a disconnect + # in the backend (count=8) try: if execute_instead_of_subscribe: - print("\nEXECUTION_2\n") - await session.execute(subscription) + print("\nEXECUTION_1\n") + await session.execute(subscription_with_disconnect) else: - print("\nSUBSCRIPTION_2\n") - async for result in session.subscribe(subscription): + print("\nSUBSCRIPTION_1_WITH_DISCONNECT\n") + async for result in session.subscribe(subscription_with_disconnect): pass - except TransportClosed: + except TransportConnectionFailed: pass - await asyncio.sleep(50 * MS) + # Wait for disconnect + for i in range(200): + await asyncio.sleep(1 * MS) + if not transport._connected: + print(f"\nDisconnected in {i+1} MS") + break - # And finally with the same session 
handle, we make a subscription - # which works correctly - print("\nSUBSCRIPTION_3\n") - async for result in session.subscribe(subscription): + # Wait for reconnect + for i in range(200): + await asyncio.sleep(1 * MS) + if transport._connected: + print(f"\nConnected again in {i+1} MS") + break - number = result["number"] - print(f"Number received: {number}") + assert transport._connected is True + + # Then after the reconnection, we make a query or a subscription + if execute_instead_of_subscribe: + print("\nEXECUTION_2\n") + result = await session.execute(subscription) + assert result["number"] == 10 + else: + print("\nSUBSCRIPTION_2\n") + generator = session.subscribe(subscription) + async for result in generator: + number = result["number"] + print(f"Number received: {number}") - assert number == count - count -= 1 + assert number == count + count -= 1 - assert count == -1 + await generator.aclose() + assert count == -1 + + # Close the reconnecting session await client.close_async() + + # Wait for disconnect + for i in range(200): + await asyncio.sleep(1 * MS) + if not transport._connected: + print(f"\nDisconnected in {i+1} MS") + break + + assert transport._connected is False diff --git a/tests/test_http_async_sync.py b/tests/test_http_async_sync.py index 19b6cfa2..61dc1809 100644 --- a/tests/test_http_async_sync.py +++ b/tests/test_http_async_sync.py @@ -7,7 +7,7 @@ @pytest.mark.online @pytest.mark.asyncio @pytest.mark.parametrize("fetch_schema_from_transport", [True, False]) -async def test_async_client_async_transport(event_loop, fetch_schema_from_transport): +async def test_async_client_async_transport(fetch_schema_from_transport): from gql.transport.aiohttp import AIOHTTPTransport @@ -15,11 +15,11 @@ async def test_async_client_async_transport(event_loop, fetch_schema_from_transp url = "https://countries.trevorblades.com/graphql" # Get async transport - sample_transport = AIOHTTPTransport(url=url) + transport = AIOHTTPTransport(url=url) # Instantiate 
client async with Client( - transport=sample_transport, + transport=transport, fetch_schema_from_transport=fetch_schema_from_transport, ) as session: @@ -51,24 +51,24 @@ async def test_async_client_async_transport(event_loop, fetch_schema_from_transp @pytest.mark.online @pytest.mark.asyncio @pytest.mark.parametrize("fetch_schema_from_transport", [True, False]) -async def test_async_client_sync_transport(event_loop, fetch_schema_from_transport): +async def test_async_client_sync_transport(fetch_schema_from_transport): from gql.transport.requests import RequestsHTTPTransport url = "http://countries.trevorblades.com/graphql" # Get sync transport - sample_transport = RequestsHTTPTransport(url=url, use_json=True) + transport = RequestsHTTPTransport(url=url, use_json=True) # Impossible to use a sync transport asynchronously with pytest.raises(AssertionError): async with Client( - transport=sample_transport, + transport=transport, fetch_schema_from_transport=fetch_schema_from_transport, ): pass - sample_transport.close() + transport.close() @pytest.mark.aiohttp @@ -82,11 +82,11 @@ def test_sync_client_async_transport(fetch_schema_from_transport): url = "https://countries.trevorblades.com/graphql" # Get async transport - sample_transport = AIOHTTPTransport(url=url) + transport = AIOHTTPTransport(url=url) # Instanciate client client = Client( - transport=sample_transport, + transport=transport, fetch_schema_from_transport=fetch_schema_from_transport, ) @@ -125,11 +125,11 @@ def test_sync_client_sync_transport(fetch_schema_from_transport): url = "https://countries.trevorblades.com/graphql" # Get sync transport - sample_transport = RequestsHTTPTransport(url=url, use_json=True) + transport = RequestsHTTPTransport(url=url, use_json=True) # Instanciate client client = Client( - transport=sample_transport, + transport=transport, fetch_schema_from_transport=fetch_schema_from_transport, ) diff --git a/tests/test_httpx.py b/tests/test_httpx.py index 8ef57a84..0411294b 100644 --- 
a/tests/test_httpx.py +++ b/tests/test_httpx.py @@ -1,17 +1,23 @@ -from typing import Mapping +import os +from typing import Any, Dict, Mapping import pytest -from gql import Client, gql +from gql import Client, FileVar, gql from gql.transport.exceptions import ( TransportAlreadyConnected, TransportClosed, + TransportConnectionFailed, TransportProtocolError, TransportQueryError, TransportServerError, ) -from .conftest import TemporaryFile, get_localhost_ssl_context, strip_braces_spaces +from .conftest import ( + TemporaryFile, + get_localhost_ssl_context_client, + make_upload_handler, +) # Marking all tests in this file with the httpx marker pytestmark = pytest.mark.httpx @@ -36,8 +42,9 @@ @pytest.mark.aiohttp @pytest.mark.asyncio -async def test_httpx_query(event_loop, aiohttp_server, run_sync_test): +async def test_httpx_query(aiohttp_server, run_sync_test): from aiohttp import web + from gql.transport.httpx import HTTPXTransport async def handler(request): @@ -74,16 +81,15 @@ def test_code(): assert isinstance(transport.response_headers, Mapping) assert transport.response_headers["dummy"] == "test1234" - await run_sync_test(event_loop, server, test_code) + await run_sync_test(server, test_code) @pytest.mark.aiohttp @pytest.mark.asyncio @pytest.mark.parametrize("verify_https", ["disabled", "cert_provided"]) -async def test_httpx_query_https( - event_loop, ssl_aiohttp_server, run_sync_test, verify_https -): +async def test_httpx_query_https(ssl_aiohttp_server, run_sync_test, verify_https): from aiohttp import web + from gql.transport.httpx import HTTPXTransport async def handler(request): @@ -105,9 +111,9 @@ def test_code(): extra_args = {} if verify_https == "cert_provided": - cert, _ = get_localhost_ssl_context() + _, ssl_context = get_localhost_ssl_context_client() - extra_args["verify"] = cert.decode() + extra_args["verify"] = ssl_context elif verify_https == "disabled": extra_args["verify"] = False @@ -134,18 +140,18 @@ def test_code(): assert 
isinstance(transport.response_headers, Mapping) assert transport.response_headers["dummy"] == "test1234" - await run_sync_test(event_loop, server, test_code) + await run_sync_test(server, test_code) @pytest.mark.aiohttp @pytest.mark.asyncio @pytest.mark.parametrize("verify_https", ["explicitely_enabled", "default"]) async def test_httpx_query_https_self_cert_fail( - event_loop, ssl_aiohttp_server, run_sync_test, verify_https + ssl_aiohttp_server, run_sync_test, verify_https ): """By default, we should verify the ssl certificate""" from aiohttp import web - from httpx import ConnectError + from gql.transport.httpx import HTTPXTransport async def handler(request): @@ -164,7 +170,7 @@ async def handler(request): assert str(url).startswith("https://") def test_code(): - extra_args = {} + extra_args: Dict[str, Any] = {} if verify_https == "explicitely_enabled": extra_args["verify"] = True @@ -174,25 +180,30 @@ def test_code(): **extra_args, ) - with pytest.raises(ConnectError) as exc_info: - with Client(transport=transport) as session: + query = gql(query1_str) - query = gql(query1_str) + expected_error = "certificate verify failed: self-signed certificate" - # Execute query synchronously + with pytest.raises(TransportConnectionFailed) as exc_info: + with Client(transport=transport) as session: session.execute(query) - expected_error = "certificate verify failed: self-signed certificate" + assert expected_error in str(exc_info.value) + + with pytest.raises(TransportConnectionFailed) as exc_info: + with Client(transport=transport) as session: + session.execute_batch([query]) assert expected_error in str(exc_info.value) - await run_sync_test(event_loop, server, test_code) + await run_sync_test(server, test_code) @pytest.mark.aiohttp @pytest.mark.asyncio -async def test_httpx_cookies(event_loop, aiohttp_server, run_sync_test): +async def test_httpx_cookies(aiohttp_server, run_sync_test): from aiohttp import web + from gql.transport.httpx import HTTPXTransport async def 
handler(request): @@ -223,13 +234,14 @@ def test_code(): assert africa["code"] == "AF" - await run_sync_test(event_loop, server, test_code) + await run_sync_test(server, test_code) @pytest.mark.aiohttp @pytest.mark.asyncio -async def test_httpx_error_code_401(event_loop, aiohttp_server, run_sync_test): +async def test_httpx_error_code_401(aiohttp_server, run_sync_test): from aiohttp import web + from gql.transport.httpx import HTTPXTransport async def handler(request): @@ -258,13 +270,14 @@ def test_code(): assert "Client error '401 Unauthorized'" in str(exc_info.value) - await run_sync_test(event_loop, server, test_code) + await run_sync_test(server, test_code) @pytest.mark.aiohttp @pytest.mark.asyncio -async def test_httpx_error_code_429(event_loop, aiohttp_server, run_sync_test): +async def test_httpx_error_code_429(aiohttp_server, run_sync_test): from aiohttp import web + from gql.transport.httpx import HTTPXTransport async def handler(request): @@ -312,8 +325,9 @@ def test_code(): @pytest.mark.aiohttp @pytest.mark.asyncio -async def test_httpx_error_code_500(event_loop, aiohttp_server, run_sync_test): +async def test_httpx_error_code_500(aiohttp_server, run_sync_test): from aiohttp import web + from gql.transport.httpx import HTTPXTransport async def handler(request): @@ -336,7 +350,7 @@ def test_code(): with pytest.raises(TransportServerError): session.execute(query) - await run_sync_test(event_loop, server, test_code) + await run_sync_test(server, test_code) query1_server_error_answer = '{"errors": ["Error 1", "Error 2"]}' @@ -344,8 +358,9 @@ def test_code(): @pytest.mark.aiohttp @pytest.mark.asyncio -async def test_httpx_error_code(event_loop, aiohttp_server, run_sync_test): +async def test_httpx_error_code(aiohttp_server, run_sync_test): from aiohttp import web + from gql.transport.httpx import HTTPXTransport async def handler(request): @@ -369,23 +384,23 @@ def test_code(): with pytest.raises(TransportQueryError): session.execute(query) - await 
run_sync_test(event_loop, server, test_code) + await run_sync_test(server, test_code) invalid_protocol_responses = [ "{}", "qlsjfqsdlkj", '{"not_data_or_errors": 35}', + "", ] @pytest.mark.aiohttp @pytest.mark.asyncio @pytest.mark.parametrize("response", invalid_protocol_responses) -async def test_httpx_invalid_protocol( - event_loop, aiohttp_server, response, run_sync_test -): +async def test_httpx_invalid_protocol(aiohttp_server, response, run_sync_test): from aiohttp import web + from gql.transport.httpx import HTTPXTransport async def handler(request): @@ -407,13 +422,14 @@ def test_code(): with pytest.raises(TransportProtocolError): session.execute(query) - await run_sync_test(event_loop, server, test_code) + await run_sync_test(server, test_code) @pytest.mark.aiohttp @pytest.mark.asyncio -async def test_httpx_cannot_connect_twice(event_loop, aiohttp_server, run_sync_test): +async def test_httpx_cannot_connect_twice(aiohttp_server, run_sync_test): from aiohttp import web + from gql.transport.httpx import HTTPXTransport async def handler(request): @@ -433,15 +449,14 @@ def test_code(): with pytest.raises(TransportAlreadyConnected): session.transport.connect() - await run_sync_test(event_loop, server, test_code) + await run_sync_test(server, test_code) @pytest.mark.aiohttp @pytest.mark.asyncio -async def test_httpx_cannot_execute_if_not_connected( - event_loop, aiohttp_server, run_sync_test -): +async def test_httpx_cannot_execute_if_not_connected(aiohttp_server, run_sync_test): from aiohttp import web + from gql.transport.httpx import HTTPXTransport async def handler(request): @@ -461,7 +476,7 @@ def test_code(): with pytest.raises(TransportClosed): transport.execute(query) - await run_sync_test(event_loop, server, test_code) + await run_sync_test(server, test_code) query1_server_answer_with_extensions = ( @@ -477,8 +492,9 @@ def test_code(): @pytest.mark.aiohttp @pytest.mark.asyncio -async def test_httpx_query_with_extensions(event_loop, aiohttp_server, 
run_sync_test): +async def test_httpx_query_with_extensions(aiohttp_server, run_sync_test): from aiohttp import web + from gql.transport.httpx import HTTPXTransport async def handler(request): @@ -503,11 +519,9 @@ def test_code(): assert execution_result.extensions["key1"] == "val1" - await run_sync_test(event_loop, server, test_code) + await run_sync_test(server, test_code) -file_upload_server_answer = '{"data":{"success":true}}' - file_upload_mutation_1 = """ mutation($file: Upload!) { uploadFile(input:{other_var:$other_var, file:$file}) { @@ -532,39 +546,21 @@ def test_code(): @pytest.mark.aiohttp @pytest.mark.asyncio -async def test_httpx_file_upload(event_loop, aiohttp_server, run_sync_test): +async def test_httpx_file_upload(aiohttp_server, run_sync_test): from aiohttp import web - from gql.transport.httpx import HTTPXTransport - - async def single_upload_handler(request): - from aiohttp import web - - reader = await request.multipart() - - field_0 = await reader.next() - assert field_0.name == "operations" - field_0_text = await field_0.text() - assert strip_braces_spaces(field_0_text) == file_upload_mutation_1_operations - - field_1 = await reader.next() - assert field_1.name == "map" - field_1_text = await field_1.text() - assert field_1_text == file_upload_mutation_1_map - - field_2 = await reader.next() - assert field_2.name == "0" - field_2_text = await field_2.text() - assert field_2_text == file_1_content - - field_3 = await reader.next() - assert field_3 is None - return web.Response( - text=file_upload_server_answer, content_type="application/json" - ) + from gql.transport.httpx import HTTPXTransport app = web.Application() - app.router.add_route("POST", "/", single_upload_handler) + app.router.add_route( + "POST", + "/", + make_upload_handler( + expected_map=file_upload_mutation_1_map, + expected_operations=file_upload_mutation_1_operations, + expected_contents=[file_1_content], + ), + ) server = await aiohttp_server(app) url = 
str(server.make_url("/")) @@ -578,58 +574,56 @@ def test_code(): file_path = test_file.filename + # Using an opened file with open(file_path, "rb") as f: - params = {"file": f, "other_var": 42} - execution_result = session._execute( - query, variable_values=params, upload_files=True - ) + query.variable_values = {"file": f, "other_var": 42} + with pytest.warns( + DeprecationWarning, + match="Not using FileVar for file upload is deprecated", + ): + execution_result = session.execute(query, upload_files=True) - assert execution_result.data["success"] + assert execution_result["success"] - await run_sync_test(event_loop, server, test_code) - - -@pytest.mark.aiohttp -@pytest.mark.asyncio -async def test_httpx_file_upload_with_content_type( - event_loop, aiohttp_server, run_sync_test -): - from aiohttp import web - from gql.transport.httpx import HTTPXTransport + # Using an opened file inside a FileVar object + with open(file_path, "rb") as f: - async def single_upload_handler(request): - from aiohttp import web + query.variable_values = {"file": FileVar(f), "other_var": 42} + execution_result = session.execute(query, upload_files=True) - reader = await request.multipart() + assert execution_result["success"] - field_0 = await reader.next() - assert field_0.name == "operations" - field_0_text = await field_0.text() - assert strip_braces_spaces(field_0_text) == file_upload_mutation_1_operations + # Using an filename string inside a FileVar object + query.variable_values = { + "file": FileVar(file_path), + "other_var": 42, + } + execution_result = session.execute(query, upload_files=True) - field_1 = await reader.next() - assert field_1.name == "map" - field_1_text = await field_1.text() - assert field_1_text == file_upload_mutation_1_map + assert execution_result["success"] - field_2 = await reader.next() - assert field_2.name == "0" - field_2_text = await field_2.text() - assert field_2_text == file_1_content + await run_sync_test(server, test_code) - # Verifying the 
content_type - assert field_2.headers["Content-Type"] == "application/pdf" - field_3 = await reader.next() - assert field_3 is None +@pytest.mark.aiohttp +@pytest.mark.asyncio +async def test_httpx_file_upload_with_content_type(aiohttp_server, run_sync_test): + from aiohttp import web - return web.Response( - text=file_upload_server_answer, content_type="application/json" - ) + from gql.transport.httpx import HTTPXTransport app = web.Application() - app.router.add_route("POST", "/", single_upload_handler) + app.router.add_route( + "POST", + "/", + make_upload_handler( + file_headers=[{"Content-Type": "application/pdf"}], + expected_map=file_upload_mutation_1_map, + expected_operations=file_upload_mutation_1_operations, + expected_contents=[file_1_content], + ), + ) server = await aiohttp_server(app) url = str(server.make_url("/")) @@ -643,60 +637,98 @@ def test_code(): file_path = test_file.filename + # Using an opened file with open(file_path, "rb") as f: # Setting the content_type - f.content_type = "application/pdf" + f.content_type = "application/pdf" # type: ignore + + query.variable_values = {"file": f, "other_var": 42} + with pytest.warns( + DeprecationWarning, + match="Not using FileVar for file upload is deprecated", + ): + execution_result = session.execute(query, upload_files=True) - params = {"file": f, "other_var": 42} - execution_result = session._execute( - query, variable_values=params, upload_files=True - ) + assert execution_result["success"] - assert execution_result.data["success"] + # Using FileVar + query.variable_values = { + "file": FileVar(file_path, content_type="application/pdf"), + "other_var": 42, + } + execution_result = session.execute(query, upload_files=True) - await run_sync_test(event_loop, server, test_code) + assert execution_result["success"] + + await run_sync_test(server, test_code) @pytest.mark.aiohttp @pytest.mark.asyncio -async def test_httpx_file_upload_additional_headers( - event_loop, aiohttp_server, run_sync_test 
+async def test_httpx_file_upload_default_filename_is_basename( + aiohttp_server, run_sync_test ): from aiohttp import web + from gql.transport.httpx import HTTPXTransport - async def single_upload_handler(request): - from aiohttp import web + app = web.Application() + + with TemporaryFile(file_1_content) as test_file: + file_path = test_file.filename + file_basename = os.path.basename(file_path) + + app.router.add_route( + "POST", + "/", + make_upload_handler( + filenames=[file_basename], + expected_map=file_upload_mutation_1_map, + expected_operations=file_upload_mutation_1_operations, + expected_contents=[file_1_content], + ), + ) + server = await aiohttp_server(app) - assert request.headers["X-Auth"] == "foobar" + url = str(server.make_url("/")) - reader = await request.multipart() + def test_code(): + transport = HTTPXTransport(url=url) - field_0 = await reader.next() - assert field_0.name == "operations" - field_0_text = await field_0.text() - assert strip_braces_spaces(field_0_text) == file_upload_mutation_1_operations + with Client(transport=transport) as session: + query = gql(file_upload_mutation_1) - field_1 = await reader.next() - assert field_1.name == "map" - field_1_text = await field_1.text() - assert field_1_text == file_upload_mutation_1_map + # Using FileVar + query.variable_values = { + "file": FileVar(file_path), + "other_var": 42, + } + execution_result = session.execute(query, upload_files=True) - field_2 = await reader.next() - assert field_2.name == "0" - field_2_text = await field_2.text() - assert field_2_text == file_1_content + assert execution_result["success"] - field_3 = await reader.next() - assert field_3 is None + await run_sync_test(server, test_code) - return web.Response( - text=file_upload_server_answer, content_type="application/json" - ) + +@pytest.mark.aiohttp +@pytest.mark.asyncio +async def test_httpx_file_upload_additional_headers(aiohttp_server, run_sync_test): + from aiohttp import web + + from gql.transport.httpx 
import HTTPXTransport app = web.Application() - app.router.add_route("POST", "/", single_upload_handler) + app.router.add_route( + "POST", + "/", + make_upload_handler( + request_headers={"X-Auth": "foobar"}, + expected_map=file_upload_mutation_1_map, + expected_operations=file_upload_mutation_1_operations, + expected_contents=[file_1_content], + ), + ) server = await aiohttp_server(app) url = str(server.make_url("/")) @@ -710,57 +742,35 @@ def test_code(): file_path = test_file.filename - with open(file_path, "rb") as f: + query.variable_values = {"file": FileVar(file_path), "other_var": 42} + execution_result = session.execute(query, upload_files=True) - params = {"file": f, "other_var": 42} - execution_result = session._execute( - query, variable_values=params, upload_files=True - ) + assert execution_result["success"] - assert execution_result.data["success"] - - await run_sync_test(event_loop, server, test_code) + await run_sync_test(server, test_code) @pytest.mark.aiohttp @pytest.mark.asyncio -async def test_httpx_binary_file_upload(event_loop, aiohttp_server, run_sync_test): +async def test_httpx_binary_file_upload(aiohttp_server, run_sync_test): from aiohttp import web + from gql.transport.httpx import HTTPXTransport # This is a sample binary file content containing all possible byte values binary_file_content = bytes(range(0, 256)) - async def binary_upload_handler(request): - - from aiohttp import web - - reader = await request.multipart() - - field_0 = await reader.next() - assert field_0.name == "operations" - field_0_text = await field_0.text() - assert strip_braces_spaces(field_0_text) == file_upload_mutation_1_operations - - field_1 = await reader.next() - assert field_1.name == "map" - field_1_text = await field_1.text() - assert field_1_text == file_upload_mutation_1_map - - field_2 = await reader.next() - assert field_2.name == "0" - field_2_binary = await field_2.read() - assert field_2_binary == binary_file_content - - field_3 = await 
reader.next() - assert field_3 is None - - return web.Response( - text=file_upload_server_answer, content_type="application/json" - ) - app = web.Application() - app.router.add_route("POST", "/", binary_upload_handler) + app.router.add_route( + "POST", + "/", + make_upload_handler( + binary=True, + expected_contents=[binary_file_content], + expected_map=file_upload_mutation_1_map, + expected_operations=file_upload_mutation_1_operations, + ), + ) server = await aiohttp_server(app) url = str(server.make_url("/")) @@ -775,30 +785,20 @@ def test_code(): file_path = test_file.filename - with open(file_path, "rb") as f: - - params = {"file": f, "other_var": 42} - - execution_result = session._execute( - query, variable_values=params, upload_files=True - ) + query.variable_values = {"file": FileVar(file_path), "other_var": 42} - assert execution_result.data["success"] + execution_result = session.execute(query, upload_files=True) - await run_sync_test(event_loop, server, test_code) + assert execution_result["success"] - -file_upload_mutation_2_operations = ( - '{"query": "mutation ($file1: Upload!, $file2: Upload!) {\\n ' - 'uploadFile(input: {file1: $file, file2: $file}) {\\n success\\n }\\n}", ' - '"variables": {"file1": null, "file2": null}}' -) + await run_sync_test(server, test_code) @pytest.mark.aiohttp @pytest.mark.asyncio -async def test_httpx_file_upload_two_files(event_loop, aiohttp_server, run_sync_test): +async def test_httpx_file_upload_two_files(aiohttp_server, run_sync_test): from aiohttp import web + from gql.transport.httpx import HTTPXTransport file_upload_mutation_2 = """ @@ -809,6 +809,12 @@ async def test_httpx_file_upload_two_files(event_loop, aiohttp_server, run_sync_ } """ + file_upload_mutation_2_operations = ( + '{"query": "mutation ($file1: Upload!, $file2: Upload!) 
{\\n ' + 'uploadFile(input: {file1: $file, file2: $file}) {\\n success\\n }\\n}", ' + '"variables": {"file1": null, "file2": null}}' + ) + file_upload_mutation_2_map = '{"0": ["variables.file1"], "1": ["variables.file2"]}' file_2_content = """ @@ -816,39 +822,17 @@ async def test_httpx_file_upload_two_files(event_loop, aiohttp_server, run_sync_ This file will also be sent in the GraphQL mutation """ - async def handler(request): - - reader = await request.multipart() - - field_0 = await reader.next() - assert field_0.name == "operations" - field_0_text = await field_0.text() - assert strip_braces_spaces(field_0_text) == file_upload_mutation_2_operations - - field_1 = await reader.next() - assert field_1.name == "map" - field_1_text = await field_1.text() - assert field_1_text == file_upload_mutation_2_map - - field_2 = await reader.next() - assert field_2.name == "0" - field_2_text = await field_2.text() - assert field_2_text == file_1_content - - field_3 = await reader.next() - assert field_3.name == "1" - field_3_text = await field_3.text() - assert field_3_text == file_2_content - - field_4 = await reader.next() - assert field_4 is None - - return web.Response( - text=file_upload_server_answer, content_type="application/json" - ) - app = web.Application() - app.router.add_route("POST", "/", handler) + app.router.add_route( + "POST", + "/", + make_upload_handler( + nb_files=2, + expected_map=file_upload_mutation_2_map, + expected_operations=file_upload_mutation_2_operations, + expected_contents=[file_1_content, file_2_content], + ), + ) server = await aiohttp_server(app) url = str(server.make_url("/")) @@ -866,39 +850,23 @@ def test_code(): file_path_1 = test_file_1.filename file_path_2 = test_file_2.filename - f1 = open(file_path_1, "rb") - f2 = open(file_path_2, "rb") - - params = { - "file1": f1, - "file2": f2, + query.variable_values = { + "file1": FileVar(file_path_1), + "file2": FileVar(file_path_2), } - execution_result = session._execute( - query, 
variable_values=params, upload_files=True - ) - - assert execution_result.data["success"] + execution_result = session.execute(query, upload_files=True) - f1.close() - f2.close() + assert execution_result["success"] - await run_sync_test(event_loop, server, test_code) - - -file_upload_mutation_3_operations = ( - '{"query": "mutation ($files: [Upload!]!) {\\n uploadFiles' - "(input: {files: $files})" - ' {\\n success\\n }\\n}", "variables": {"files": [null, null]}}' -) + await run_sync_test(server, test_code) @pytest.mark.aiohttp @pytest.mark.asyncio -async def test_httpx_file_upload_list_of_two_files( - event_loop, aiohttp_server, run_sync_test -): +async def test_httpx_file_upload_list_of_two_files(aiohttp_server, run_sync_test): from aiohttp import web + from gql.transport.httpx import HTTPXTransport file_upload_mutation_3 = """ @@ -909,6 +877,12 @@ async def test_httpx_file_upload_list_of_two_files( } """ + file_upload_mutation_3_operations = ( + '{"query": "mutation ($files: [Upload!]!) 
{\\n uploadFiles' + "(input: {files: $files})" + ' {\\n success\\n }\\n}", "variables": {"files": [null, null]}}' + ) + file_upload_mutation_3_map = ( '{"0": ["variables.files.0"], "1": ["variables.files.1"]}' ) @@ -918,39 +892,17 @@ async def test_httpx_file_upload_list_of_two_files( This file will also be sent in the GraphQL mutation """ - async def handler(request): - - reader = await request.multipart() - - field_0 = await reader.next() - assert field_0.name == "operations" - field_0_text = await field_0.text() - assert strip_braces_spaces(field_0_text) == file_upload_mutation_3_operations - - field_1 = await reader.next() - assert field_1.name == "map" - field_1_text = await field_1.text() - assert field_1_text == file_upload_mutation_3_map - - field_2 = await reader.next() - assert field_2.name == "0" - field_2_text = await field_2.text() - assert field_2_text == file_1_content - - field_3 = await reader.next() - assert field_3.name == "1" - field_3_text = await field_3.text() - assert field_3_text == file_2_content - - field_4 = await reader.next() - assert field_4 is None - - return web.Response( - text=file_upload_server_answer, content_type="application/json" - ) - app = web.Application() - app.router.add_route("POST", "/", handler) + app.router.add_route( + "POST", + "/", + make_upload_handler( + nb_files=2, + expected_map=file_upload_mutation_3_map, + expected_operations=file_upload_mutation_3_operations, + expected_contents=[file_1_content, file_2_content], + ), + ) server = await aiohttp_server(app) url = str(server.make_url("/")) @@ -967,27 +919,25 @@ def test_code(): file_path_1 = test_file_1.filename file_path_2 = test_file_2.filename - f1 = open(file_path_1, "rb") - f2 = open(file_path_2, "rb") - - params = {"files": [f1, f2]} - - execution_result = session._execute( - query, variable_values=params, upload_files=True - ) + query.variable_values = { + "files": [ + FileVar(file_path_1), + FileVar(file_path_2), + ], + } - assert 
execution_result.data["success"] + execution_result = session.execute(query, upload_files=True) - f1.close() - f2.close() + assert execution_result["success"] - await run_sync_test(event_loop, server, test_code) + await run_sync_test(server, test_code) @pytest.mark.aiohttp @pytest.mark.asyncio -async def test_httpx_error_fetching_schema(event_loop, aiohttp_server, run_sync_test): +async def test_httpx_error_fetching_schema(aiohttp_server, run_sync_test): from aiohttp import web + from gql.transport.httpx import HTTPXTransport error_answer = """ @@ -1028,4 +978,4 @@ def test_code(): assert expected_error in str(exc_info.value) assert transport.client is None - await run_sync_test(event_loop, server, test_code) + await run_sync_test(server, test_code) diff --git a/tests/test_httpx_async.py b/tests/test_httpx_async.py index 47744538..690b3ee7 100644 --- a/tests/test_httpx_async.py +++ b/tests/test_httpx_async.py @@ -1,14 +1,15 @@ import io import json -from typing import Mapping +from typing import Any, Dict, Mapping import pytest -from gql import Client, gql +from gql import Client, FileVar, gql from gql.cli import get_parser, main from gql.transport.exceptions import ( TransportAlreadyConnected, TransportClosed, + TransportConnectionFailed, TransportProtocolError, TransportQueryError, TransportServerError, @@ -17,7 +18,7 @@ from .conftest import ( TemporaryFile, get_localhost_ssl_context_client, - strip_braces_spaces, + make_upload_handler, ) query1_str = """ @@ -46,8 +47,9 @@ @pytest.mark.aiohttp @pytest.mark.asyncio -async def test_httpx_query(event_loop, aiohttp_server): +async def test_httpx_query(aiohttp_server): from aiohttp import web + from gql.transport.httpx import HTTPXAsyncTransport async def handler(request): @@ -86,8 +88,9 @@ async def handler(request): @pytest.mark.aiohttp @pytest.mark.asyncio -async def test_httpx_ignore_backend_content_type(event_loop, aiohttp_server): +async def test_httpx_ignore_backend_content_type(aiohttp_server): from aiohttp 
import web + from gql.transport.httpx import HTTPXAsyncTransport async def handler(request): @@ -116,8 +119,9 @@ async def handler(request): @pytest.mark.aiohttp @pytest.mark.asyncio -async def test_httpx_cookies(event_loop, aiohttp_server): +async def test_httpx_cookies(aiohttp_server): from aiohttp import web + from gql.transport.httpx import HTTPXAsyncTransport async def handler(request): @@ -150,8 +154,9 @@ async def handler(request): @pytest.mark.aiohttp @pytest.mark.asyncio -async def test_httpx_error_code_401(event_loop, aiohttp_server): +async def test_httpx_error_code_401(aiohttp_server): from aiohttp import web + from gql.transport.httpx import HTTPXAsyncTransport async def handler(request): @@ -182,8 +187,9 @@ async def handler(request): @pytest.mark.aiohttp @pytest.mark.asyncio -async def test_httpx_error_code_429(event_loop, aiohttp_server): +async def test_httpx_error_code_429(aiohttp_server): from aiohttp import web + from gql.transport.httpx import HTTPXAsyncTransport async def handler(request): @@ -230,8 +236,9 @@ async def handler(request): @pytest.mark.aiohttp @pytest.mark.asyncio -async def test_httpx_error_code_500(event_loop, aiohttp_server): +async def test_httpx_error_code_500(aiohttp_server): from aiohttp import web + from gql.transport.httpx import HTTPXAsyncTransport async def handler(request): @@ -266,8 +273,9 @@ async def handler(request): @pytest.mark.aiohttp @pytest.mark.asyncio @pytest.mark.parametrize("query_error", transport_query_error_responses) -async def test_httpx_error_code(event_loop, aiohttp_server, query_error): +async def test_httpx_error_code(aiohttp_server, query_error): from aiohttp import web + from gql.transport.httpx import HTTPXAsyncTransport async def handler(request): @@ -322,8 +330,9 @@ async def handler(request): @pytest.mark.aiohttp @pytest.mark.asyncio @pytest.mark.parametrize("param", invalid_protocol_responses) -async def test_httpx_invalid_protocol(event_loop, aiohttp_server, param): +async def 
test_httpx_invalid_protocol(aiohttp_server, param): from aiohttp import web + from gql.transport.httpx import HTTPXAsyncTransport response = param["response"] @@ -351,8 +360,9 @@ async def handler(request): @pytest.mark.aiohttp @pytest.mark.asyncio -async def test_httpx_subscribe_not_supported(event_loop, aiohttp_server): +async def test_httpx_subscribe_not_supported(aiohttp_server): from aiohttp import web + from gql.transport.httpx import HTTPXAsyncTransport async def handler(request): @@ -377,8 +387,9 @@ async def handler(request): @pytest.mark.aiohttp @pytest.mark.asyncio -async def test_httpx_cannot_connect_twice(event_loop, aiohttp_server): +async def test_httpx_cannot_connect_twice(aiohttp_server): from aiohttp import web + from gql.transport.httpx import HTTPXAsyncTransport async def handler(request): @@ -400,8 +411,9 @@ async def handler(request): @pytest.mark.aiohttp @pytest.mark.asyncio -async def test_httpx_cannot_execute_if_not_connected(event_loop, aiohttp_server): +async def test_httpx_cannot_execute_if_not_connected(aiohttp_server): from aiohttp import web + from gql.transport.httpx import HTTPXAsyncTransport async def handler(request): @@ -423,10 +435,11 @@ async def handler(request): @pytest.mark.aiohttp @pytest.mark.asyncio -async def test_httpx_extra_args(event_loop, aiohttp_server): +async def test_httpx_extra_args(aiohttp_server): + import httpx from aiohttp import web + from gql.transport.httpx import HTTPXAsyncTransport - import httpx async def handler(request): return web.Response(text=query1_server_answer, content_type="application/json") @@ -438,14 +451,14 @@ async def handler(request): url = str(server.make_url("/")) # passing extra arguments to httpx.AsyncClient - transport = httpx.AsyncHTTPTransport(retries=2) - transport = HTTPXAsyncTransport(url=url, max_redirects=2, transport=transport) + inner_transport = httpx.AsyncHTTPTransport(retries=2) + transport = HTTPXAsyncTransport(url=url, max_redirects=2, transport=inner_transport) async 
with Client(transport=transport) as session: query = gql(query1_str) - # Passing extra arguments to the post method of aiohttp + # Passing extra arguments to the post method result = await session.execute(query, extra_args={"follow_redirects": True}) continents = result["continents"] @@ -468,8 +481,9 @@ async def handler(request): @pytest.mark.aiohttp @pytest.mark.asyncio -async def test_httpx_query_variable_values(event_loop, aiohttp_server): +async def test_httpx_query_variable_values(aiohttp_server): from aiohttp import web + from gql.transport.httpx import HTTPXAsyncTransport async def handler(request): @@ -485,14 +499,13 @@ async def handler(request): async with Client(transport=transport) as session: - params = {"code": "EU"} - query = gql(query2_str) + query.variable_values = {"code": "EU"} + query.operation_name = "getEurope" + # Execute query asynchronously - result = await session.execute( - query, variable_values=params, operation_name="getEurope" - ) + result = await session.execute(query) continent = result["continent"] @@ -501,12 +514,13 @@ async def handler(request): @pytest.mark.aiohttp @pytest.mark.asyncio -async def test_httpx_query_variable_values_fix_issue_292(event_loop, aiohttp_server): +async def test_httpx_query_variable_values_fix_issue_292(aiohttp_server): """Allow to specify variable_values without keyword. 
See https://github.com/graphql-python/gql/issues/292""" from aiohttp import web + from gql.transport.httpx import HTTPXAsyncTransport async def handler(request): @@ -522,12 +536,13 @@ async def handler(request): async with Client(transport=transport) as session: - params = {"code": "EU"} - query = gql(query2_str) + query.variable_values = {"code": "EU"} + query.operation_name = "getEurope" + # Execute query asynchronously - result = await session.execute(query, params, operation_name="getEurope") + result = await session.execute(query) continent = result["continent"] @@ -536,10 +551,9 @@ async def handler(request): @pytest.mark.aiohttp @pytest.mark.asyncio -async def test_httpx_execute_running_in_thread( - event_loop, aiohttp_server, run_sync_test -): +async def test_httpx_execute_running_in_thread(aiohttp_server, run_sync_test): from aiohttp import web + from gql.transport.httpx import HTTPXAsyncTransport async def handler(request): @@ -560,15 +574,14 @@ def test_code(): client.execute(query) - await run_sync_test(event_loop, server, test_code) + await run_sync_test(server, test_code) @pytest.mark.aiohttp @pytest.mark.asyncio -async def test_httpx_subscribe_running_in_thread( - event_loop, aiohttp_server, run_sync_test -): +async def test_httpx_subscribe_running_in_thread(aiohttp_server, run_sync_test): from aiohttp import web + from gql.transport.httpx import HTTPXAsyncTransport async def handler(request): @@ -598,10 +611,8 @@ def test_code(): for result in client.subscribe(query): pass - await run_sync_test(event_loop, server, test_code) - + await run_sync_test(server, test_code) -file_upload_server_answer = '{"data":{"success":true}}' file_upload_mutation_1 = """ mutation($file: Upload!) 
{ @@ -625,41 +636,23 @@ def test_code(): """ -async def single_upload_handler(request): - - from aiohttp import web - - reader = await request.multipart() - - field_0 = await reader.next() - assert field_0.name == "operations" - field_0_text = await field_0.text() - assert strip_braces_spaces(field_0_text) == file_upload_mutation_1_operations - - field_1 = await reader.next() - assert field_1.name == "map" - field_1_text = await field_1.text() - assert field_1_text == file_upload_mutation_1_map - - field_2 = await reader.next() - assert field_2.name == "0" - field_2_text = await field_2.text() - assert field_2_text == file_1_content - - field_3 = await reader.next() - assert field_3 is None - - return web.Response(text=file_upload_server_answer, content_type="application/json") - - @pytest.mark.aiohttp @pytest.mark.asyncio -async def test_httpx_file_upload(event_loop, aiohttp_server): +async def test_httpx_file_upload(aiohttp_server): from aiohttp import web + from gql.transport.httpx import HTTPXAsyncTransport app = web.Application() - app.router.add_route("POST", "/", single_upload_handler) + app.router.add_route( + "POST", + "/", + make_upload_handler( + expected_map=file_upload_mutation_1_map, + expected_operations=file_upload_mutation_1_operations, + expected_contents=[file_1_content], + ), + ) server = await aiohttp_server(app) url = str(server.make_url("/")) @@ -674,30 +667,59 @@ async def test_httpx_file_upload(event_loop, aiohttp_server): file_path = test_file.filename + # Using an opened file + with open(file_path, "rb") as f: + + query.variable_values = {"file": f, "other_var": 42} + + # Execute query asynchronously + with pytest.warns( + DeprecationWarning, + match="Not using FileVar for file upload is deprecated", + ): + result = await session.execute(query, upload_files=True) + + success = result["success"] + assert success + + # Using an opened file inside a FileVar object with open(file_path, "rb") as f: - params = {"file": f, "other_var": 42} + 
query.variable_values = {"file": FileVar(f), "other_var": 42} # Execute query asynchronously - result = await session.execute( - query, variable_values=params, upload_files=True - ) + result = await session.execute(query, upload_files=True) success = result["success"] + assert success + # Using an filename string inside a FileVar object + query.variable_values = {"file": FileVar(file_path), "other_var": 42} + + # Execute query asynchronously + result = await session.execute(query, upload_files=True) + + success = result["success"] assert success @pytest.mark.aiohttp @pytest.mark.asyncio -async def test_httpx_file_upload_without_session( - event_loop, aiohttp_server, run_sync_test -): +async def test_httpx_file_upload_without_session(aiohttp_server, run_sync_test): from aiohttp import web + from gql.transport.httpx import HTTPXAsyncTransport app = web.Application() - app.router.add_route("POST", "/", single_upload_handler) + app.router.add_route( + "POST", + "/", + make_upload_handler( + expected_map=file_upload_mutation_1_map, + expected_operations=file_upload_mutation_1_operations, + expected_contents=[file_1_content], + ), + ) server = await aiohttp_server(app) url = str(server.make_url("/")) @@ -713,60 +735,38 @@ def test_code(): file_path = test_file.filename - with open(file_path, "rb") as f: - - params = {"file": f, "other_var": 42} - - result = client.execute( - query, variable_values=params, upload_files=True - ) - - success = result["success"] - - assert success - - await run_sync_test(event_loop, server, test_code) + query.variable_values = {"file": FileVar(file_path), "other_var": 42} + result = client.execute(query, upload_files=True) -# This is a sample binary file content containing all possible byte values -binary_file_content = bytes(range(0, 256)) - - -async def binary_upload_handler(request): - - from aiohttp import web - - reader = await request.multipart() - - field_0 = await reader.next() - assert field_0.name == "operations" - field_0_text = 
await field_0.text() - assert strip_braces_spaces(field_0_text) == file_upload_mutation_1_operations - - field_1 = await reader.next() - assert field_1.name == "map" - field_1_text = await field_1.text() - assert field_1_text == file_upload_mutation_1_map - - field_2 = await reader.next() - assert field_2.name == "0" - field_2_binary = await field_2.read() - assert field_2_binary == binary_file_content + success = result["success"] - field_3 = await reader.next() - assert field_3 is None + assert success - return web.Response(text=file_upload_server_answer, content_type="application/json") + await run_sync_test(server, test_code) @pytest.mark.aiohttp @pytest.mark.asyncio -async def test_httpx_binary_file_upload(event_loop, aiohttp_server): +async def test_httpx_binary_file_upload(aiohttp_server): from aiohttp import web + from gql.transport.httpx import HTTPXAsyncTransport + # This is a sample binary file content containing all possible byte values + binary_file_content = bytes(range(0, 256)) + app = web.Application() - app.router.add_route("POST", "/", binary_upload_handler) + app.router.add_route( + "POST", + "/", + make_upload_handler( + binary=True, + expected_contents=[binary_file_content], + expected_map=file_upload_mutation_1_map, + expected_operations=file_upload_mutation_1_operations, + ), + ) server = await aiohttp_server(app) url = str(server.make_url("/")) @@ -781,81 +781,55 @@ async def test_httpx_binary_file_upload(event_loop, aiohttp_server): file_path = test_file.filename - with open(file_path, "rb") as f: - - params = {"file": f, "other_var": 42} + query.variable_values = {"file": FileVar(file_path), "other_var": 42} - # Execute query asynchronously - result = await session.execute( - query, variable_values=params, upload_files=True - ) + # Execute query asynchronously + result = await session.execute(query, upload_files=True) success = result["success"] assert success -file_upload_mutation_2 = """ - mutation($file1: Upload!, $file2: Upload!) 
{ - uploadFile(input:{file1:$file, file2:$file}) { - success - } - } -""" - -file_upload_mutation_2_operations = ( - '{"query": "mutation ($file1: Upload!, $file2: Upload!) {\\n ' - 'uploadFile(input: {file1: $file, file2: $file}) {\\n success\\n }\\n}", ' - '"variables": {"file1": null, "file2": null}}' -) - -file_upload_mutation_2_map = '{"0": ["variables.file1"], "1": ["variables.file2"]}' - -file_2_content = """ -This is a second test file -This file will also be sent in the GraphQL mutation -""" - - @pytest.mark.aiohttp @pytest.mark.asyncio -async def test_httpx_file_upload_two_files(event_loop, aiohttp_server): +async def test_httpx_file_upload_two_files(aiohttp_server): from aiohttp import web - from gql.transport.httpx import HTTPXAsyncTransport - - async def handler(request): - - reader = await request.multipart() - - field_0 = await reader.next() - assert field_0.name == "operations" - field_0_text = await field_0.text() - assert strip_braces_spaces(field_0_text) == file_upload_mutation_2_operations - field_1 = await reader.next() - assert field_1.name == "map" - field_1_text = await field_1.text() - assert field_1_text == file_upload_mutation_2_map + from gql.transport.httpx import HTTPXAsyncTransport - field_2 = await reader.next() - assert field_2.name == "0" - field_2_text = await field_2.text() - assert field_2_text == file_1_content + file_upload_mutation_2 = """ + mutation($file1: Upload!, $file2: Upload!) { + uploadFile(input:{file1:$file, file2:$file}) { + success + } + } + """ - field_3 = await reader.next() - assert field_3.name == "1" - field_3_text = await field_3.text() - assert field_3_text == file_2_content + file_upload_mutation_2_operations = ( + '{"query": "mutation ($file1: Upload!, $file2: Upload!) 
{\\n ' + 'uploadFile(input: {file1: $file, file2: $file}) {\\n success\\n }\\n}", ' + '"variables": {"file1": null, "file2": null}}' + ) - field_4 = await reader.next() - assert field_4 is None + file_upload_mutation_2_map = '{"0": ["variables.file1"], "1": ["variables.file2"]}' - return web.Response( - text=file_upload_server_answer, content_type="application/json" - ) + file_2_content = """ + This is a second test file + This file will also be sent in the GraphQL mutation + """ app = web.Application() - app.router.add_route("POST", "/", handler) + app.router.add_route( + "POST", + "/", + make_upload_handler( + nb_files=2, + expected_map=file_upload_mutation_2_map, + expected_operations=file_upload_mutation_2_operations, + expected_contents=[file_1_content, file_2_content], + ), + ) server = await aiohttp_server(app) url = str(server.make_url("/")) @@ -872,82 +846,58 @@ async def handler(request): file_path_1 = test_file_1.filename file_path_2 = test_file_2.filename - f1 = open(file_path_1, "rb") - f2 = open(file_path_2, "rb") - - params = { - "file1": f1, - "file2": f2, + query.variable_values = { + "file1": FileVar(file_path_1), + "file2": FileVar(file_path_2), } - result = await session.execute( - query, variable_values=params, upload_files=True - ) - - f1.close() - f2.close() + result = await session.execute(query, upload_files=True) success = result["success"] - assert success -file_upload_mutation_3 = """ - mutation($files: [Upload!]!) { - uploadFiles(input:{files:$files}) { - success - } - } -""" - -file_upload_mutation_3_operations = ( - '{"query": "mutation ($files: [Upload!]!) 
{\\n uploadFiles(' - "input: {files: $files})" - ' {\\n success\\n }\\n}", "variables": {"files": [null, null]}}' -) - -file_upload_mutation_3_map = '{"0": ["variables.files.0"], "1": ["variables.files.1"]}' - - @pytest.mark.aiohttp @pytest.mark.asyncio -async def test_httpx_file_upload_list_of_two_files(event_loop, aiohttp_server): +async def test_httpx_file_upload_list_of_two_files(aiohttp_server): from aiohttp import web - from gql.transport.httpx import HTTPXAsyncTransport - - async def handler(request): - - reader = await request.multipart() - field_0 = await reader.next() - assert field_0.name == "operations" - field_0_text = await field_0.text() - assert strip_braces_spaces(field_0_text) == file_upload_mutation_3_operations - - field_1 = await reader.next() - assert field_1.name == "map" - field_1_text = await field_1.text() - assert field_1_text == file_upload_mutation_3_map + from gql.transport.httpx import HTTPXAsyncTransport - field_2 = await reader.next() - assert field_2.name == "0" - field_2_text = await field_2.text() - assert field_2_text == file_1_content + file_upload_mutation_3 = """ + mutation($files: [Upload!]!) { + uploadFiles(input:{files:$files}) { + success + } + } + """ - field_3 = await reader.next() - assert field_3.name == "1" - field_3_text = await field_3.text() - assert field_3_text == file_2_content + file_upload_mutation_3_operations = ( + '{"query": "mutation ($files: [Upload!]!) 
{\\n uploadFiles' + "(input: {files: $files})" + ' {\\n success\\n }\\n}", "variables": {"files": [null, null]}}' + ) - field_4 = await reader.next() - assert field_4 is None + file_upload_mutation_3_map = ( + '{"0": ["variables.files.0"], "1": ["variables.files.1"]}' + ) - return web.Response( - text=file_upload_server_answer, content_type="application/json" - ) + file_2_content = """ + This is a second test file + This file will also be sent in the GraphQL mutation + """ app = web.Application() - app.router.add_route("POST", "/", handler) + app.router.add_route( + "POST", + "/", + make_upload_handler( + nb_files=2, + expected_map=file_upload_mutation_3_map, + expected_operations=file_upload_mutation_3_operations, + expected_contents=[file_1_content, file_2_content], + ), + ) server = await aiohttp_server(app) url = str(server.make_url("/")) @@ -964,27 +914,23 @@ async def handler(request): file_path_1 = test_file_1.filename file_path_2 = test_file_2.filename - f1 = open(file_path_1, "rb") - f2 = open(file_path_2, "rb") - - params = {"files": [f1, f2]} + query.variable_values = { + "files": [ + FileVar(file_path_1), + FileVar(file_path_2), + ], + } # Execute query asynchronously - result = await session.execute( - query, variable_values=params, upload_files=True - ) - - f1.close() - f2.close() + result = await session.execute(query, upload_files=True) success = result["success"] - assert success @pytest.mark.aiohttp @pytest.mark.asyncio -async def test_httpx_using_cli(event_loop, aiohttp_server, monkeypatch, capsys): +async def test_httpx_using_cli(aiohttp_server, monkeypatch, capsys): from aiohttp import web async def handler(request): @@ -1022,7 +968,7 @@ async def handler(request): @pytest.mark.asyncio @pytest.mark.script_launch_mode("subprocess") async def test_httpx_using_cli_ep( - event_loop, aiohttp_server, monkeypatch, script_runner, run_sync_test + aiohttp_server, monkeypatch, script_runner, run_sync_test ): from aiohttp import web @@ -1055,14 +1001,12 @@ 
def test_code(): assert received_answer == expected_answer - await run_sync_test(event_loop, server, test_code) + await run_sync_test(server, test_code) @pytest.mark.aiohttp @pytest.mark.asyncio -async def test_httpx_using_cli_invalid_param( - event_loop, aiohttp_server, monkeypatch, capsys -): +async def test_httpx_using_cli_invalid_param(aiohttp_server, monkeypatch, capsys): from aiohttp import web async def handler(request): @@ -1097,9 +1041,7 @@ async def handler(request): @pytest.mark.aiohttp @pytest.mark.asyncio -async def test_httpx_using_cli_invalid_query( - event_loop, aiohttp_server, monkeypatch, capsys -): +async def test_httpx_using_cli_invalid_query(aiohttp_server, monkeypatch, capsys): from aiohttp import web async def handler(request): @@ -1138,8 +1080,9 @@ async def handler(request): @pytest.mark.aiohttp @pytest.mark.asyncio -async def test_httpx_query_with_extensions(event_loop, aiohttp_server): +async def test_httpx_query_with_extensions(aiohttp_server): from aiohttp import web + from gql.transport.httpx import HTTPXAsyncTransport async def handler(request): @@ -1167,8 +1110,9 @@ async def handler(request): @pytest.mark.aiohttp @pytest.mark.asyncio @pytest.mark.parametrize("verify_https", ["disabled", "cert_provided"]) -async def test_httpx_query_https(event_loop, ssl_aiohttp_server, verify_https): +async def test_httpx_query_https(ssl_aiohttp_server, verify_https): from aiohttp import web + from gql.transport.httpx import HTTPXAsyncTransport async def handler(request): @@ -1210,12 +1154,10 @@ async def handler(request): @pytest.mark.aiohttp @pytest.mark.asyncio @pytest.mark.parametrize("verify_https", ["explicitely_enabled", "default"]) -async def test_httpx_query_https_self_cert_fail( - event_loop, ssl_aiohttp_server, verify_https -): +async def test_httpx_query_https_self_cert_fail(ssl_aiohttp_server, verify_https): from aiohttp import web + from gql.transport.httpx import HTTPXAsyncTransport - from httpx import ConnectError async def 
handler(request): return web.Response(text=query1_server_answer, content_type="application/json") @@ -1228,30 +1170,35 @@ async def handler(request): assert url.startswith("https://") - extra_args = {} + extra_args: Dict[str, Any] = {} if verify_https == "explicitely_enabled": extra_args["verify"] = True transport = HTTPXAsyncTransport(url=url, timeout=10, **extra_args) - with pytest.raises(ConnectError) as exc_info: - async with Client(transport=transport) as session: + query = gql(query1_str) - query = gql(query1_str) + expected_error = "certificate verify failed: self-signed certificate" - # Execute query asynchronously + with pytest.raises(TransportConnectionFailed) as exc_info: + async with Client(transport=transport) as session: await session.execute(query) - expected_error = "certificate verify failed: self-signed certificate" + assert expected_error in str(exc_info.value) + + with pytest.raises(TransportConnectionFailed) as exc_info: + async with Client(transport=transport) as session: + await session.execute_batch([query]) assert expected_error in str(exc_info.value) @pytest.mark.aiohttp @pytest.mark.asyncio -async def test_httpx_error_fetching_schema(event_loop, aiohttp_server): +async def test_httpx_error_fetching_schema(aiohttp_server): from aiohttp import web + from gql.transport.httpx import HTTPXAsyncTransport error_answer = """ @@ -1294,8 +1241,9 @@ async def handler(request): @pytest.mark.aiohttp @pytest.mark.asyncio -async def test_httpx_reconnecting_session(event_loop, aiohttp_server): +async def test_httpx_reconnecting_session(aiohttp_server): from aiohttp import web + from gql.transport.httpx import HTTPXAsyncTransport async def handler(request): @@ -1333,8 +1281,9 @@ async def handler(request): @pytest.mark.aiohttp @pytest.mark.asyncio @pytest.mark.parametrize("retries", [False, lambda e: e]) -async def test_httpx_reconnecting_session_retries(event_loop, aiohttp_server, retries): +async def 
test_httpx_reconnecting_session_retries(aiohttp_server, retries): from aiohttp import web + from gql.transport.httpx import HTTPXAsyncTransport async def handler(request): @@ -1366,9 +1315,10 @@ async def handler(request): @pytest.mark.aiohttp @pytest.mark.asyncio async def test_httpx_reconnecting_session_start_connecting_task_twice( - event_loop, aiohttp_server, caplog + aiohttp_server, caplog ): from aiohttp import web + from gql.transport.httpx import HTTPXAsyncTransport async def handler(request): @@ -1401,8 +1351,9 @@ async def handler(request): @pytest.mark.aiohttp @pytest.mark.asyncio -async def test_httpx_json_serializer(event_loop, aiohttp_server, caplog): +async def test_httpx_json_serializer(aiohttp_server, caplog): from aiohttp import web + from gql.transport.httpx import HTTPXAsyncTransport async def handler(request): @@ -1458,10 +1409,12 @@ async def handler(request): @pytest.mark.aiohttp @pytest.mark.asyncio -async def test_httpx_json_deserializer(event_loop, aiohttp_server): - from aiohttp import web +async def test_httpx_json_deserializer(aiohttp_server): from decimal import Decimal from functools import partial + + from aiohttp import web + from gql.transport.httpx import HTTPXAsyncTransport async def handler(request): diff --git a/tests/test_httpx_batch.py b/tests/test_httpx_batch.py new file mode 100644 index 00000000..63472dab --- /dev/null +++ b/tests/test_httpx_batch.py @@ -0,0 +1,436 @@ +from typing import Mapping + +import pytest + +from gql import Client, GraphQLRequest +from gql.transport.exceptions import ( + TransportClosed, + TransportProtocolError, + TransportQueryError, +) + +# Marking all tests in this file with the httpx marker +pytestmark = pytest.mark.httpx + +query1_str = """ + query getContinents { + continents { + code + name + } + } +""" + +query1_server_answer_list = ( + '[{"data":{"continents":[' + '{"code":"AF","name":"Africa"},{"code":"AN","name":"Antarctica"},' + 
'{"code":"AS","name":"Asia"},{"code":"EU","name":"Europe"},' + '{"code":"NA","name":"North America"},{"code":"OC","name":"Oceania"},' + '{"code":"SA","name":"South America"}]}}]' +) + + +@pytest.mark.aiohttp +@pytest.mark.asyncio +async def test_httpx_async_batch_query(aiohttp_server): + from aiohttp import web + + from gql.transport.httpx import HTTPXAsyncTransport + + async def handler(request): + return web.Response( + text=query1_server_answer_list, + content_type="application/json", + headers={"dummy": "test1234"}, + ) + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = str(server.make_url("/")) + + transport = HTTPXAsyncTransport(url=url, timeout=10) + + async with Client(transport=transport) as session: + + query = [GraphQLRequest(query1_str)] + + # Execute query asynchronously + results = await session.execute_batch(query) + + result = results[0] + + continents = result["continents"] + + africa = continents[0] + + assert africa["code"] == "AF" + + # Checking response headers are saved in the transport + assert hasattr(transport, "response_headers") + assert isinstance(transport.response_headers, Mapping) + assert transport.response_headers["dummy"] == "test1234" + + +@pytest.mark.aiohttp +@pytest.mark.asyncio +async def test_httpx_sync_batch_query(aiohttp_server, run_sync_test): + from aiohttp import web + + from gql.transport.httpx import HTTPXTransport + + async def handler(request): + return web.Response( + text=query1_server_answer_list, + content_type="application/json", + headers={"dummy": "test1234"}, + ) + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = str(server.make_url("/")) + + transport = HTTPXTransport(url=url, timeout=10) + + def test_code(): + with Client(transport=transport) as session: + + query = [GraphQLRequest(query1_str)] + + results = session.execute_batch(query) + + result = results[0] + + 
continents = result["continents"] + + africa = continents[0] + + assert africa["code"] == "AF" + + # Checking response headers are saved in the transport + assert hasattr(transport, "response_headers") + assert isinstance(transport.response_headers, Mapping) + assert transport.response_headers["dummy"] == "test1234" + + await run_sync_test(server, test_code) + + +@pytest.mark.aiohttp +@pytest.mark.asyncio +async def test_httpx_async_batch_query_without_session(aiohttp_server, run_sync_test): + from aiohttp import web + + from gql.transport.httpx import HTTPXAsyncTransport + + async def handler(request): + return web.Response( + text=query1_server_answer_list, + content_type="application/json", + headers={"dummy": "test1234"}, + ) + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = str(server.make_url("/")) + + def test_code(): + transport = HTTPXAsyncTransport(url=url, timeout=10) + + client = Client(transport=transport) + + query = [GraphQLRequest(query1_str)] + + results = client.execute_batch(query) + + result = results[0] + + continents = result["continents"] + + africa = continents[0] + + assert africa["code"] == "AF" + + # Checking response headers are saved in the transport + assert hasattr(transport, "response_headers") + assert isinstance(transport.response_headers, Mapping) + assert transport.response_headers["dummy"] == "test1234" + + await run_sync_test(server, test_code) + + +query1_server_error_answer_list = '[{"errors": ["Error 1", "Error 2"]}]' + + +@pytest.mark.aiohttp +@pytest.mark.asyncio +async def test_httpx_async_batch_error_code(aiohttp_server): + from aiohttp import web + + from gql.transport.httpx import HTTPXAsyncTransport + + async def handler(request): + return web.Response( + text=query1_server_error_answer_list, content_type="application/json" + ) + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = 
str(server.make_url("/")) + + transport = HTTPXAsyncTransport(url=url) + + async with Client(transport=transport) as session: + + query = [GraphQLRequest(query1_str)] + + with pytest.raises(TransportQueryError): + await session.execute_batch(query) + + +invalid_protocol_responses = [ + "{}", + "qlsjfqsdlkj", + '{"not_data_or_errors": 35}', + "[{}]", + "[qlsjfqsdlkj]", + '[{"not_data_or_errors": 35}]', + "[]", + "[1]", +] + + +@pytest.mark.aiohttp +@pytest.mark.asyncio +@pytest.mark.parametrize("response", invalid_protocol_responses) +async def test_httpx_async_batch_invalid_protocol(aiohttp_server, response): + from aiohttp import web + + from gql.transport.httpx import HTTPXAsyncTransport + + async def handler(request): + return web.Response(text=response, content_type="application/json") + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = str(server.make_url("/")) + + transport = HTTPXAsyncTransport(url=url, timeout=10) + + async with Client(transport=transport) as session: + + query = [GraphQLRequest(query1_str)] + + with pytest.raises(TransportProtocolError): + await session.execute_batch(query) + + +@pytest.mark.aiohttp +@pytest.mark.asyncio +async def test_httpx_async_batch_cannot_execute_if_not_connected(aiohttp_server): + from aiohttp import web + + from gql.transport.httpx import HTTPXAsyncTransport + + async def handler(request): + return web.Response( + text=query1_server_answer_list, + content_type="application/json", + headers={"dummy": "test1234"}, + ) + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = str(server.make_url("/")) + + transport = HTTPXAsyncTransport(url=url, timeout=10) + + query = [GraphQLRequest(query1_str)] + + with pytest.raises(TransportClosed): + await transport.execute_batch(query) + + +@pytest.mark.aiohttp +@pytest.mark.asyncio +async def 
test_httpx_sync_batch_cannot_execute_if_not_connected(aiohttp_server): + from aiohttp import web + + from gql.transport.httpx import HTTPXTransport + + async def handler(request): + return web.Response( + text=query1_server_answer_list, + content_type="application/json", + headers={"dummy": "test1234"}, + ) + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = str(server.make_url("/")) + + transport = HTTPXTransport(url=url, timeout=10) + + query = [GraphQLRequest(query1_str)] + + with pytest.raises(TransportClosed): + transport.execute_batch(query) + + +@pytest.mark.aiohttp +@pytest.mark.asyncio +async def test_httpx_async_batch_extra_args(aiohttp_server): + import httpx + from aiohttp import web + + from gql.transport.httpx import HTTPXAsyncTransport + + async def handler(request): + return web.Response( + text=query1_server_answer_list, + content_type="application/json", + headers={"dummy": "test1234"}, + ) + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = str(server.make_url("/")) + + # passing extra arguments to httpx.AsyncClient + inner_transport = httpx.AsyncHTTPTransport(retries=2) + transport = HTTPXAsyncTransport(url=url, max_redirects=2, transport=inner_transport) + + async with Client(transport=transport) as session: + + query = [GraphQLRequest(query1_str)] + + # Passing extra arguments to the post method + results = await session.execute_batch( + query, extra_args={"follow_redirects": True} + ) + + result = results[0] + + continents = result["continents"] + + africa = continents[0] + + assert africa["code"] == "AF" + + +query1_server_answer_with_extensions_list = ( + '[{"data":{"continents":[' + '{"code":"AF","name":"Africa"},{"code":"AN","name":"Antarctica"},' + '{"code":"AS","name":"Asia"},{"code":"EU","name":"Europe"},' + '{"code":"NA","name":"North America"},{"code":"OC","name":"Oceania"},' + 
'{"code":"SA","name":"South America"}]},' + '"extensions": {"key1": "val1"}' + "}]" +) + + +@pytest.mark.aiohttp +@pytest.mark.asyncio +async def test_httpx_async_batch_query_with_extensions(aiohttp_server): + from aiohttp import web + + from gql.transport.httpx import HTTPXAsyncTransport + + async def handler(request): + return web.Response( + text=query1_server_answer_with_extensions_list, + content_type="application/json", + ) + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = str(server.make_url("/")) + + transport = HTTPXAsyncTransport(url=url) + + query = [GraphQLRequest(query1_str)] + + async with Client(transport=transport) as session: + + execution_results = await session.execute_batch( + query, get_execution_result=True + ) + + assert execution_results[0].extensions["key1"] == "val1" + + +ONLINE_URL = "https://countries.trevorblades.workers.dev/graphql" + + +@pytest.mark.online +@pytest.mark.asyncio +async def test_httpx_batch_online_async_manual(): + + from gql.transport.httpx import HTTPXAsyncTransport + + client = Client( + transport=HTTPXAsyncTransport(url=ONLINE_URL), + ) + + query = """ + query getContinentName($continent_code: ID!) { + continent(code: $continent_code) { + name + } + } + """ + + async with client as session: + + request_eu = GraphQLRequest(query, variable_values={"continent_code": "EU"}) + request_af = GraphQLRequest(query, variable_values={"continent_code": "AF"}) + + result_eu, result_af = await session.execute_batch([request_eu, request_af]) + + assert result_eu["continent"]["name"] == "Europe" + assert result_af["continent"]["name"] == "Africa" + + +@pytest.mark.online +@pytest.mark.asyncio +async def test_httpx_batch_online_sync_manual(): + + from gql.transport.httpx import HTTPXTransport + + client = Client( + transport=HTTPXTransport(url=ONLINE_URL), + ) + + query = """ + query getContinentName($continent_code: ID!) 
{ + continent(code: $continent_code) { + name + } + } + """ + + with client as session: + + request_eu = GraphQLRequest(query, variable_values={"continent_code": "EU"}) + request_af = GraphQLRequest(query, variable_values={"continent_code": "AF"}) + + result_eu, result_af = session.execute_batch([request_eu, request_af]) + + assert result_eu["continent"]["name"] == "Europe" + assert result_af["continent"]["name"] == "Africa" diff --git a/tests/test_httpx_online.py b/tests/test_httpx_online.py index 23d28dcc..c6e84368 100644 --- a/tests/test_httpx_online.py +++ b/tests/test_httpx_online.py @@ -11,7 +11,7 @@ @pytest.mark.httpx @pytest.mark.online @pytest.mark.asyncio -async def test_httpx_simple_query(event_loop): +async def test_httpx_simple_query(): from gql.transport.httpx import HTTPXAsyncTransport @@ -19,10 +19,10 @@ async def test_httpx_simple_query(event_loop): url = "https://countries.trevorblades.com/graphql" # Get transport - sample_transport = HTTPXAsyncTransport(url=url) + transport = HTTPXAsyncTransport(url=url) # Instanciate client - async with Client(transport=sample_transport) as session: + async with Client(transport=transport) as session: query = gql( """ @@ -56,15 +56,13 @@ async def test_httpx_simple_query(event_loop): @pytest.mark.httpx @pytest.mark.online @pytest.mark.asyncio -async def test_httpx_invalid_query(event_loop): +async def test_httpx_invalid_query(): from gql.transport.httpx import HTTPXAsyncTransport - sample_transport = HTTPXAsyncTransport( - url="https://countries.trevorblades.com/graphql" - ) + transport = HTTPXAsyncTransport(url="https://countries.trevorblades.com/graphql") - async with Client(transport=sample_transport) as session: + async with Client(transport=transport) as session: query = gql( """ @@ -85,16 +83,16 @@ async def test_httpx_invalid_query(event_loop): @pytest.mark.online @pytest.mark.skipif(sys.version_info < (3, 8), reason="requires python3.8 or higher") @pytest.mark.asyncio -async def 
test_httpx_two_queries_in_parallel_using_two_tasks(event_loop): +async def test_httpx_two_queries_in_parallel_using_two_tasks(): from gql.transport.httpx import HTTPXAsyncTransport - sample_transport = HTTPXAsyncTransport( + transport = HTTPXAsyncTransport( url="https://countries.trevorblades.com/graphql", ) # Instanciate client - async with Client(transport=sample_transport) as session: + async with Client(transport=transport) as session: query1 = gql( """ diff --git a/tests/test_phoenix_channel_exceptions.py b/tests/test_phoenix_channel_exceptions.py index c042ce01..b7f11dcb 100644 --- a/tests/test_phoenix_channel_exceptions.py +++ b/tests/test_phoenix_channel_exceptions.py @@ -19,9 +19,7 @@ def ensure_list(s): return ( s if s is None or isinstance(s, list) - else list(s) - if isinstance(s, tuple) - else [s] + else list(s) if isinstance(s, tuple) else [s] ) @@ -161,7 +159,7 @@ async def no_connection_ack_phoenix_server(ws): indirect=True, ) @pytest.mark.parametrize("query_str", [query1_str]) -async def test_phoenix_channel_query_protocol_error(event_loop, server, query_str): +async def test_phoenix_channel_query_protocol_error(server, query_str): from gql.transport.phoenix_channel_websockets import ( PhoenixChannelWebsocketsTransport, @@ -169,13 +167,11 @@ async def test_phoenix_channel_query_protocol_error(event_loop, server, query_st path = "/graphql" url = f"ws://{server.hostname}:{server.port}{path}" - sample_transport = PhoenixChannelWebsocketsTransport( - channel_name="test_channel", url=url - ) + transport = PhoenixChannelWebsocketsTransport(channel_name="test_channel", url=url) query = gql(query_str) with pytest.raises(TransportProtocolError): - async with Client(transport=sample_transport) as session: + async with Client(transport=transport) as session: await session.execute(query) @@ -191,7 +187,7 @@ async def test_phoenix_channel_query_protocol_error(event_loop, server, query_st indirect=True, ) @pytest.mark.parametrize("query_str", [query1_str]) 
-async def test_phoenix_channel_query_error(event_loop, server, query_str): +async def test_phoenix_channel_query_error(server, query_str): from gql.transport.phoenix_channel_websockets import ( PhoenixChannelWebsocketsTransport, @@ -199,13 +195,11 @@ async def test_phoenix_channel_query_error(event_loop, server, query_str): path = "/graphql" url = f"ws://{server.hostname}:{server.port}{path}" - sample_transport = PhoenixChannelWebsocketsTransport( - channel_name="test_channel", url=url - ) + transport = PhoenixChannelWebsocketsTransport(channel_name="test_channel", url=url) query = gql(query_str) with pytest.raises(TransportQueryError): - async with Client(transport=sample_transport) as session: + async with Client(transport=transport) as session: await session.execute(query) @@ -360,9 +354,10 @@ def subscription_server( data_answers=default_subscription_data_answer, unsubscribe_answers=default_subscription_unsubscribe_answer, ): - from .conftest import PhoenixChannelServerHelper import json + from .conftest import PhoenixChannelServerHelper + async def phoenix_server(ws): await PhoenixChannelServerHelper.send_connection_ack(ws) await ws.recv() @@ -407,9 +402,7 @@ async def phoenix_server(ws): indirect=True, ) @pytest.mark.parametrize("query_str", [query2_str]) -async def test_phoenix_channel_subscription_protocol_error( - event_loop, server, query_str -): +async def test_phoenix_channel_subscription_protocol_error(server, query_str): from gql.transport.phoenix_channel_websockets import ( PhoenixChannelWebsocketsTransport, @@ -417,13 +410,11 @@ async def test_phoenix_channel_subscription_protocol_error( path = "/graphql" url = f"ws://{server.hostname}:{server.port}{path}" - sample_transport = PhoenixChannelWebsocketsTransport( - channel_name="test_channel", url=url - ) + transport = PhoenixChannelWebsocketsTransport(channel_name="test_channel", url=url) query = gql(query_str) with pytest.raises(TransportProtocolError): - async with 
Client(transport=sample_transport) as session: + async with Client(transport=transport) as session: async for _result in session.subscribe(query): await asyncio.sleep(10 * MS) break @@ -439,7 +430,7 @@ async def test_phoenix_channel_subscription_protocol_error( indirect=True, ) @pytest.mark.parametrize("query_str", [query1_str]) -async def test_phoenix_channel_server_error(event_loop, server, query_str): +async def test_phoenix_channel_server_error(server, query_str): from gql.transport.phoenix_channel_websockets import ( PhoenixChannelWebsocketsTransport, @@ -447,13 +438,11 @@ async def test_phoenix_channel_server_error(event_loop, server, query_str): path = "/graphql" url = f"ws://{server.hostname}:{server.port}{path}" - sample_transport = PhoenixChannelWebsocketsTransport( - channel_name="test_channel", url=url - ) + transport = PhoenixChannelWebsocketsTransport(channel_name="test_channel", url=url) query = gql(query_str) with pytest.raises(TransportServerError): - async with Client(transport=sample_transport) as session: + async with Client(transport=transport) as session: await session.execute(query) @@ -468,7 +457,7 @@ async def test_phoenix_channel_server_error(event_loop, server, query_str): indirect=True, ) @pytest.mark.parametrize("query_str", [query2_str]) -async def test_phoenix_channel_unsubscribe_error(event_loop, server, query_str): +async def test_phoenix_channel_unsubscribe_error(server, query_str): from gql.transport.phoenix_channel_websockets import ( PhoenixChannelWebsocketsTransport, @@ -479,12 +468,12 @@ async def test_phoenix_channel_unsubscribe_error(event_loop, server, query_str): # Reduce close_timeout. These tests will wait for an unsubscribe # reply that will never come... 
- sample_transport = PhoenixChannelWebsocketsTransport( + transport = PhoenixChannelWebsocketsTransport( channel_name="test_channel", url=url, close_timeout=1 ) query = gql(query_str) - async with Client(transport=sample_transport) as session: + async with Client(transport=transport) as session: async for _result in session.subscribe(query): break @@ -498,7 +487,7 @@ async def test_phoenix_channel_unsubscribe_error(event_loop, server, query_str): indirect=True, ) @pytest.mark.parametrize("query_str", [query2_str]) -async def test_phoenix_channel_unsubscribe_error_forcing(event_loop, server, query_str): +async def test_phoenix_channel_unsubscribe_error_forcing(server, query_str): from gql.transport.phoenix_channel_websockets import ( PhoenixChannelWebsocketsTransport, @@ -507,13 +496,13 @@ async def test_phoenix_channel_unsubscribe_error_forcing(event_loop, server, que path = "/graphql" url = f"ws://{server.hostname}:{server.port}{path}" - sample_transport = PhoenixChannelWebsocketsTransport( + transport = PhoenixChannelWebsocketsTransport( channel_name="test_channel", url=url, close_timeout=1 ) query = gql(query_str) with pytest.raises(TransportProtocolError): - async with Client(transport=sample_transport) as session: + async with Client(transport=transport) as session: async for _result in session.subscribe(query): await session.transport._send_stop_message(2) await asyncio.sleep(10 * MS) diff --git a/tests/test_phoenix_channel_query.py b/tests/test_phoenix_channel_query.py index f39edacb..7dff7062 100644 --- a/tests/test_phoenix_channel_query.py +++ b/tests/test_phoenix_channel_query.py @@ -1,6 +1,7 @@ import pytest from gql import Client, gql +from gql.transport.exceptions import TransportConnectionFailed from .conftest import get_localhost_ssl_context_client @@ -51,7 +52,7 @@ async def query_server(ws): @pytest.mark.asyncio @pytest.mark.parametrize("server", [query_server], indirect=True) @pytest.mark.parametrize("query_str", [query1_str]) -async def 
test_phoenix_channel_query(event_loop, server, query_str): +async def test_phoenix_channel_query(server, query_str): from gql.transport.phoenix_channel_websockets import ( PhoenixChannelWebsocketsTransport, ) @@ -65,16 +66,16 @@ async def test_phoenix_channel_query(event_loop, server, query_str): result = await session.execute(query) print("Client received:", result) + continents = result["continents"] + print("Continents received:", continents) + africa = continents[0] + assert africa["code"] == "AF" -@pytest.mark.skip(reason="ssl=False is not working for now") @pytest.mark.asyncio @pytest.mark.parametrize("ws_ssl_server", [query_server], indirect=True) @pytest.mark.parametrize("query_str", [query1_str]) -@pytest.mark.parametrize("verify_https", ["disabled", "cert_provided"]) -async def test_phoenix_channel_query_ssl( - event_loop, ws_ssl_server, query_str, verify_https -): +async def test_phoenix_channel_query_ssl(ws_ssl_server, query_str): from gql.transport.phoenix_channel_websockets import ( PhoenixChannelWebsocketsTransport, ) @@ -85,12 +86,9 @@ async def test_phoenix_channel_query_ssl( extra_args = {} - if verify_https == "cert_provided": - _, ssl_context = get_localhost_ssl_context_client() + _, ssl_context = get_localhost_ssl_context_client() - extra_args["ssl"] = ssl_context - elif verify_https == "disabled": - extra_args["ssl"] = False + extra_args["ssl"] = ssl_context transport = PhoenixChannelWebsocketsTransport( channel_name="test_channel", @@ -110,12 +108,13 @@ async def test_phoenix_channel_query_ssl( @pytest.mark.parametrize("query_str", [query1_str]) @pytest.mark.parametrize("verify_https", ["explicitely_enabled", "default"]) async def test_phoenix_channel_query_ssl_self_cert_fail( - event_loop, ws_ssl_server, query_str, verify_https + ws_ssl_server, query_str, verify_https ): + from ssl import SSLCertVerificationError + from gql.transport.phoenix_channel_websockets import ( PhoenixChannelWebsocketsTransport, ) - from ssl import 
SSLCertVerificationError path = "/graphql" server = ws_ssl_server @@ -134,13 +133,17 @@ async def test_phoenix_channel_query_ssl_self_cert_fail( query = gql(query_str) - with pytest.raises(SSLCertVerificationError) as exc_info: + with pytest.raises(TransportConnectionFailed) as exc_info: async with Client(transport=transport) as session: await session.execute(query) + cause = exc_info.value.__cause__ + + assert isinstance(cause, SSLCertVerificationError) + expected_error = "certificate verify failed: self-signed certificate" - assert expected_error in str(exc_info.value) + assert expected_error in str(cause) query2_str = """ @@ -202,7 +205,7 @@ async def subscription_server(ws): @pytest.mark.asyncio @pytest.mark.parametrize("server", [subscription_server], indirect=True) @pytest.mark.parametrize("query_str", [query2_str]) -async def test_phoenix_channel_subscription(event_loop, server, query_str): +async def test_phoenix_channel_subscription(server, query_str): from gql.transport.phoenix_channel_websockets import ( PhoenixChannelWebsocketsTransport, ) @@ -214,8 +217,12 @@ async def test_phoenix_channel_subscription(event_loop, server, query_str): first_result = None query = gql(query_str) async with Client(transport=transport) as session: - async for result in session.subscribe(query): + generator = session.subscribe(query) + async for result in generator: first_result = result break + # Using aclose here to make it stop cleanly on pypy + await generator.aclose() + print("Client received:", first_result) diff --git a/tests/test_phoenix_channel_subscription.py b/tests/test_phoenix_channel_subscription.py index 6193c658..ecda9c38 100644 --- a/tests/test_phoenix_channel_subscription.py +++ b/tests/test_phoenix_channel_subscription.py @@ -173,9 +173,7 @@ async def stopping_coro(): @pytest.mark.parametrize("server", [server_countdown], indirect=True) @pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) @pytest.mark.parametrize("end_count", [0, 5]) 
-async def test_phoenix_channel_subscription( - event_loop, server, subscription_str, end_count -): +async def test_phoenix_channel_subscription(server, subscription_str, end_count): """Parameterized test. :param end_count: Target count at which the test will 'break' to unsubscribe. @@ -186,22 +184,24 @@ async def test_phoenix_channel_subscription( PhoenixChannelWebsocketsTransport, ) from gql.transport.phoenix_channel_websockets import log as phoenix_logger - from gql.transport.websockets import log as websockets_logger + from gql.transport.websockets_protocol import log as websockets_logger websockets_logger.setLevel(logging.DEBUG) phoenix_logger.setLevel(logging.DEBUG) path = "/graphql" url = f"ws://{server.hostname}:{server.port}{path}" - sample_transport = PhoenixChannelWebsocketsTransport( + transport = PhoenixChannelWebsocketsTransport( channel_name=test_channel, url=url, close_timeout=5 ) count = 10 subscription = gql(subscription_str.format(count=count)) - async with Client(transport=sample_transport) as session: - async for result in session.subscribe(subscription): + async with Client(transport=transport) as session: + + generator = session.subscribe(subscription) + async for result in generator: number = result["countdown"]["number"] print(f"Number received: {number}") @@ -212,22 +212,23 @@ async def test_phoenix_channel_subscription( count -= 1 + # Using aclose here to make it stop cleanly on pypy + await generator.aclose() + assert count == end_count @pytest.mark.asyncio @pytest.mark.parametrize("server", [server_countdown], indirect=True) @pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) -async def test_phoenix_channel_subscription_no_break( - event_loop, server, subscription_str -): +async def test_phoenix_channel_subscription_no_break(server, subscription_str): import logging from gql.transport.phoenix_channel_websockets import ( PhoenixChannelWebsocketsTransport, ) from gql.transport.phoenix_channel_websockets import log 
as phoenix_logger - from gql.transport.websockets import log as websockets_logger + from gql.transport.websockets_protocol import log as websockets_logger from .conftest import MS @@ -239,14 +240,14 @@ async def test_phoenix_channel_subscription_no_break( async def testing_stopping_without_break(): - sample_transport = PhoenixChannelWebsocketsTransport( + transport = PhoenixChannelWebsocketsTransport( channel_name=test_channel, url=url, close_timeout=(5000 * MS) ) count = 10 subscription = gql(subscription_str.format(count=count)) - async with Client(transport=sample_transport) as session: + async with Client(transport=transport) as session: async for result in session.subscribe(subscription): number = result["countdown"]["number"] print(f"Number received: {number}") @@ -364,21 +365,22 @@ async def heartbeat_coro(): @pytest.mark.asyncio @pytest.mark.parametrize("server", [phoenix_heartbeat_server], indirect=True) @pytest.mark.parametrize("subscription_str", [heartbeat_subscription_str]) -async def test_phoenix_channel_heartbeat(event_loop, server, subscription_str): +async def test_phoenix_channel_heartbeat(server, subscription_str): from gql.transport.phoenix_channel_websockets import ( PhoenixChannelWebsocketsTransport, ) path = "/graphql" url = f"ws://{server.hostname}:{server.port}{path}" - sample_transport = PhoenixChannelWebsocketsTransport( + transport = PhoenixChannelWebsocketsTransport( channel_name=test_channel, url=url, heartbeat_interval=0.1 ) subscription = gql(heartbeat_subscription_str) - async with Client(transport=sample_transport) as session: + async with Client(transport=transport) as session: i = 0 - async for result in session.subscribe(subscription): + generator = session.subscribe(subscription) + async for result in generator: heartbeat_count = result["heartbeat"]["heartbeat_count"] print(f"Heartbeat count received: {heartbeat_count}") @@ -387,3 +389,6 @@ async def test_phoenix_channel_heartbeat(event_loop, server, subscription_str): break i 
+= 1 + + # Using aclose here to make it stop cleanly on pypy + await generator.aclose() diff --git a/tests/test_requests.py b/tests/test_requests.py index 95db0b3f..fe57f5e3 100644 --- a/tests/test_requests.py +++ b/tests/test_requests.py @@ -1,11 +1,14 @@ -from typing import Mapping +import os +import warnings +from typing import Any, Dict, Mapping import pytest -from gql import Client, gql +from gql import Client, FileVar, gql from gql.transport.exceptions import ( TransportAlreadyConnected, TransportClosed, + TransportConnectionFailed, TransportProtocolError, TransportQueryError, TransportServerError, @@ -14,7 +17,7 @@ from .conftest import ( TemporaryFile, get_localhost_ssl_context_client, - strip_braces_spaces, + make_upload_handler, ) # Marking all tests in this file with the requests marker @@ -40,8 +43,9 @@ @pytest.mark.aiohttp @pytest.mark.asyncio -async def test_requests_query(event_loop, aiohttp_server, run_sync_test): +async def test_requests_query(aiohttp_server, run_sync_test): from aiohttp import web + from gql.transport.requests import RequestsHTTPTransport async def handler(request): @@ -78,18 +82,16 @@ def test_code(): assert isinstance(transport.response_headers, Mapping) assert transport.response_headers["dummy"] == "test1234" - await run_sync_test(event_loop, server, test_code) + await run_sync_test(server, test_code) @pytest.mark.aiohttp @pytest.mark.asyncio @pytest.mark.parametrize("verify_https", ["disabled", "cert_provided"]) -async def test_requests_query_https( - event_loop, ssl_aiohttp_server, run_sync_test, verify_https -): +async def test_requests_query_https(ssl_aiohttp_server, run_sync_test, verify_https): from aiohttp import web + from gql.transport.requests import RequestsHTTPTransport - import warnings async def handler(request): return web.Response( @@ -142,19 +144,19 @@ def test_code(): assert isinstance(transport.response_headers, Mapping) assert transport.response_headers["dummy"] == "test1234" - await 
run_sync_test(event_loop, server, test_code) + await run_sync_test(server, test_code) @pytest.mark.aiohttp @pytest.mark.asyncio @pytest.mark.parametrize("verify_https", ["explicitely_enabled", "default"]) async def test_requests_query_https_self_cert_fail( - event_loop, ssl_aiohttp_server, run_sync_test, verify_https + ssl_aiohttp_server, run_sync_test, verify_https ): """By default, we should verify the ssl certificate""" from aiohttp import web + from gql.transport.requests import RequestsHTTPTransport - from requests.exceptions import SSLError async def handler(request): return web.Response( @@ -170,7 +172,7 @@ async def handler(request): url = server.make_url("/") def test_code(): - extra_args = {} + extra_args: Dict[str, Any] = {} if verify_https == "explicitely_enabled": extra_args["verify"] = True @@ -180,7 +182,7 @@ def test_code(): **extra_args, ) - with pytest.raises(SSLError) as exc_info: + with pytest.raises(TransportConnectionFailed) as exc_info: with Client(transport=transport) as session: query = gql(query1_str) @@ -192,13 +194,14 @@ def test_code(): assert expected_error in str(exc_info.value) - await run_sync_test(event_loop, server, test_code) + await run_sync_test(server, test_code) @pytest.mark.aiohttp @pytest.mark.asyncio -async def test_requests_cookies(event_loop, aiohttp_server, run_sync_test): +async def test_requests_cookies(aiohttp_server, run_sync_test): from aiohttp import web + from gql.transport.requests import RequestsHTTPTransport async def handler(request): @@ -229,13 +232,14 @@ def test_code(): assert africa["code"] == "AF" - await run_sync_test(event_loop, server, test_code) + await run_sync_test(server, test_code) @pytest.mark.aiohttp @pytest.mark.asyncio -async def test_requests_error_code_401(event_loop, aiohttp_server, run_sync_test): +async def test_requests_error_code_401(aiohttp_server, run_sync_test): from aiohttp import web + from gql.transport.requests import RequestsHTTPTransport async def handler(request): @@ -264,13 
+268,14 @@ def test_code(): assert "401 Client Error: Unauthorized" in str(exc_info.value) - await run_sync_test(event_loop, server, test_code) + await run_sync_test(server, test_code) @pytest.mark.aiohttp @pytest.mark.asyncio -async def test_requests_error_code_429(event_loop, aiohttp_server, run_sync_test): +async def test_requests_error_code_429(aiohttp_server, run_sync_test): from aiohttp import web + from gql.transport.requests import RequestsHTTPTransport async def handler(request): @@ -318,8 +323,9 @@ def test_code(): @pytest.mark.aiohttp @pytest.mark.asyncio -async def test_requests_error_code_500(event_loop, aiohttp_server, run_sync_test): +async def test_requests_error_code_500(aiohttp_server, run_sync_test): from aiohttp import web + from gql.transport.requests import RequestsHTTPTransport async def handler(request): @@ -342,7 +348,7 @@ def test_code(): with pytest.raises(TransportServerError): session.execute(query) - await run_sync_test(event_loop, server, test_code) + await run_sync_test(server, test_code) query1_server_error_answer = '{"errors": ["Error 1", "Error 2"]}' @@ -350,8 +356,9 @@ def test_code(): @pytest.mark.aiohttp @pytest.mark.asyncio -async def test_requests_error_code(event_loop, aiohttp_server, run_sync_test): +async def test_requests_error_code(aiohttp_server, run_sync_test): from aiohttp import web + from gql.transport.requests import RequestsHTTPTransport async def handler(request): @@ -375,7 +382,7 @@ def test_code(): with pytest.raises(TransportQueryError): session.execute(query) - await run_sync_test(event_loop, server, test_code) + await run_sync_test(server, test_code) invalid_protocol_responses = [ @@ -388,10 +395,9 @@ def test_code(): @pytest.mark.aiohttp @pytest.mark.asyncio @pytest.mark.parametrize("response", invalid_protocol_responses) -async def test_requests_invalid_protocol( - event_loop, aiohttp_server, response, run_sync_test -): +async def test_requests_invalid_protocol(aiohttp_server, response, run_sync_test): 
from aiohttp import web + from gql.transport.requests import RequestsHTTPTransport async def handler(request): @@ -413,13 +419,14 @@ def test_code(): with pytest.raises(TransportProtocolError): session.execute(query) - await run_sync_test(event_loop, server, test_code) + await run_sync_test(server, test_code) @pytest.mark.aiohttp @pytest.mark.asyncio -async def test_requests_cannot_connect_twice(event_loop, aiohttp_server, run_sync_test): +async def test_requests_cannot_connect_twice(aiohttp_server, run_sync_test): from aiohttp import web + from gql.transport.requests import RequestsHTTPTransport async def handler(request): @@ -439,15 +446,14 @@ def test_code(): with pytest.raises(TransportAlreadyConnected): session.transport.connect() - await run_sync_test(event_loop, server, test_code) + await run_sync_test(server, test_code) @pytest.mark.aiohttp @pytest.mark.asyncio -async def test_requests_cannot_execute_if_not_connected( - event_loop, aiohttp_server, run_sync_test -): +async def test_requests_cannot_execute_if_not_connected(aiohttp_server, run_sync_test): from aiohttp import web + from gql.transport.requests import RequestsHTTPTransport async def handler(request): @@ -467,7 +473,7 @@ def test_code(): with pytest.raises(TransportClosed): transport.execute(query) - await run_sync_test(event_loop, server, test_code) + await run_sync_test(server, test_code) query1_server_answer_with_extensions = ( @@ -483,10 +489,9 @@ def test_code(): @pytest.mark.aiohttp @pytest.mark.asyncio -async def test_requests_query_with_extensions( - event_loop, aiohttp_server, run_sync_test -): +async def test_requests_query_with_extensions(aiohttp_server, run_sync_test): from aiohttp import web + from gql.transport.requests import RequestsHTTPTransport async def handler(request): @@ -511,11 +516,9 @@ def test_code(): assert execution_result.extensions["key1"] == "val1" - await run_sync_test(event_loop, server, test_code) + await run_sync_test(server, test_code) -file_upload_server_answer 
= '{"data":{"success":true}}' - file_upload_mutation_1 = """ mutation($file: Upload!) { uploadFile(input:{other_var:$other_var, file:$file}) { @@ -540,39 +543,21 @@ def test_code(): @pytest.mark.aiohttp @pytest.mark.asyncio -async def test_requests_file_upload(event_loop, aiohttp_server, run_sync_test): +async def test_requests_file_upload(aiohttp_server, run_sync_test): from aiohttp import web - from gql.transport.requests import RequestsHTTPTransport - - async def single_upload_handler(request): - from aiohttp import web - - reader = await request.multipart() - - field_0 = await reader.next() - assert field_0.name == "operations" - field_0_text = await field_0.text() - assert strip_braces_spaces(field_0_text) == file_upload_mutation_1_operations - field_1 = await reader.next() - assert field_1.name == "map" - field_1_text = await field_1.text() - assert field_1_text == file_upload_mutation_1_map - - field_2 = await reader.next() - assert field_2.name == "0" - field_2_text = await field_2.text() - assert field_2_text == file_1_content - - field_3 = await reader.next() - assert field_3 is None - - return web.Response( - text=file_upload_server_answer, content_type="application/json" - ) + from gql.transport.requests import RequestsHTTPTransport app = web.Application() - app.router.add_route("POST", "/", single_upload_handler) + app.router.add_route( + "POST", + "/", + make_upload_handler( + expected_map=file_upload_mutation_1_map, + expected_operations=file_upload_mutation_1_operations, + expected_contents=[file_1_content], + ), + ) server = await aiohttp_server(app) url = server.make_url("/") @@ -586,58 +571,56 @@ def test_code(): file_path = test_file.filename + # Using an opened file with open(file_path, "rb") as f: - params = {"file": f, "other_var": 42} - execution_result = session._execute( - query, variable_values=params, upload_files=True - ) - - assert execution_result.data["success"] + query.variable_values = {"file": f, "other_var": 42} - await 
run_sync_test(event_loop, server, test_code) + await run_sync_test(server, test_code) + with pytest.warns( + DeprecationWarning, + match="Not using FileVar for file upload is deprecated", + ): + execution_result = session.execute(query, upload_files=True) + assert execution_result["success"] -@pytest.mark.aiohttp -@pytest.mark.asyncio -async def test_requests_file_upload_with_content_type( - event_loop, aiohttp_server, run_sync_test -): - from aiohttp import web - from gql.transport.requests import RequestsHTTPTransport + # Using an opened file inside a FileVar object - async def single_upload_handler(request): + with open(file_path, "rb") as f: - from aiohttp import web + query.variable_values = {"file": FileVar(f), "other_var": 42} + with warnings.catch_warnings(): + warnings.simplefilter("error") # Turn warnings into errors + execution_result = session.execute(query, upload_files=True) - reader = await request.multipart() + assert execution_result["success"] - field_0 = await reader.next() - assert field_0.name == "operations" - field_0_text = await field_0.text() - assert strip_braces_spaces(field_0_text) == file_upload_mutation_1_operations + # Using a filename string inside a FileVar object + query.variable_values = {"file": FileVar(file_path), "other_var": 42} + execution_result = session.execute(query, upload_files=True) - field_1 = await reader.next() - assert field_1.name == "map" - field_1_text = await field_1.text() - assert field_1_text == file_upload_mutation_1_map + assert execution_result["success"] - field_2 = await reader.next() - assert field_2.name == "0" - field_2_text = await field_2.text() - assert field_2_text == file_1_content + await run_sync_test(server, test_code) - # Verifying the content_type - assert field_2.headers["Content-Type"] == "application/pdf" - field_3 = await reader.next() - assert field_3 is None +@pytest.mark.aiohttp +@pytest.mark.asyncio +async def test_requests_file_upload_with_content_type(aiohttp_server, run_sync_test): + from aiohttp import web - return 
web.Response( - text=file_upload_server_answer, content_type="application/json" - ) + from gql.transport.requests import RequestsHTTPTransport app = web.Application() - app.router.add_route("POST", "/", single_upload_handler) + app.router.add_route( + "POST", + "/", + make_upload_handler( + file_headers=[{"Content-Type": "application/pdf"}], + expected_map=file_upload_mutation_1_map, + expected_operations=file_upload_mutation_1_operations, + expected_contents=[file_1_content], + ), + ) server = await aiohttp_server(app) url = server.make_url("/") @@ -651,66 +634,107 @@ def test_code(): file_path = test_file.filename + # Using an opened file with open(file_path, "rb") as f: # Setting the content_type - f.content_type = "application/pdf" + f.content_type = "application/pdf" # type: ignore + + query.variable_values = {"file": f, "other_var": 42} + with pytest.warns( + DeprecationWarning, + match="Not using FileVar for file upload is deprecated", + ): + execution_result = session.execute(query, upload_files=True) + + assert execution_result["success"] - params = {"file": f, "other_var": 42} - execution_result = session._execute( - query, variable_values=params, upload_files=True - ) + # Using an opened file inside a FileVar object + with open(file_path, "rb") as f: + + query.variable_values = { + "file": FileVar(f, content_type="application/pdf"), + "other_var": 42, + } + execution_result = session.execute(query, upload_files=True) - assert execution_result.data["success"] + assert execution_result["success"] - await run_sync_test(event_loop, server, test_code) + await run_sync_test(server, test_code) @pytest.mark.aiohttp @pytest.mark.asyncio -async def test_requests_file_upload_additional_headers( - event_loop, aiohttp_server, run_sync_test +async def test_requests_file_upload_default_filename_is_basename( + aiohttp_server, run_sync_test ): from aiohttp import web + from gql.transport.requests import RequestsHTTPTransport - async def single_upload_handler(request): - 
from aiohttp import web + app = web.Application() - assert request.headers["X-Auth"] == "foobar" + with TemporaryFile(file_1_content) as test_file: + file_path = test_file.filename + file_basename = os.path.basename(file_path) + + app.router.add_route( + "POST", + "/", + make_upload_handler( + filenames=[file_basename], + expected_map=file_upload_mutation_1_map, + expected_operations=file_upload_mutation_1_operations, + expected_contents=[file_1_content], + ), + ) + server = await aiohttp_server(app) - reader = await request.multipart() + url = server.make_url("/") - field_0 = await reader.next() - assert field_0.name == "operations" - field_0_text = await field_0.text() - assert strip_braces_spaces(field_0_text) == file_upload_mutation_1_operations + def test_code(): - field_1 = await reader.next() - assert field_1.name == "map" - field_1_text = await field_1.text() - assert field_1_text == file_upload_mutation_1_map + transport = RequestsHTTPTransport(url=url) - field_2 = await reader.next() - assert field_2.name == "0" - field_2_text = await field_2.text() - assert field_2_text == file_1_content + with Client(transport=transport) as session: + query = gql(file_upload_mutation_1) - field_3 = await reader.next() - assert field_3 is None + query.variable_values = { + "file": FileVar(file_path), + "other_var": 42, + } + execution_result = session.execute(query, upload_files=True) - return web.Response( - text=file_upload_server_answer, content_type="application/json" - ) + assert execution_result["success"] + + await run_sync_test(server, test_code) + + +@pytest.mark.aiohttp +@pytest.mark.asyncio +async def test_requests_file_upload_with_filename(aiohttp_server, run_sync_test): + from aiohttp import web + + from gql.transport.requests import RequestsHTTPTransport app = web.Application() - app.router.add_route("POST", "/", single_upload_handler) + app.router.add_route( + "POST", + "/", + make_upload_handler( + filenames=["filename1.txt"], + 
expected_map=file_upload_mutation_1_map, + expected_operations=file_upload_mutation_1_operations, + expected_contents=[file_1_content], + ), + ) server = await aiohttp_server(app) url = server.make_url("/") def test_code(): - transport = RequestsHTTPTransport(url=url, headers={"X-Auth": "foobar"}) + + transport = RequestsHTTPTransport(url=url) with TemporaryFile(file_1_content) as test_file: with Client(transport=transport) as session: @@ -720,55 +744,83 @@ def test_code(): with open(file_path, "rb") as f: - params = {"file": f, "other_var": 42} - execution_result = session._execute( - query, variable_values=params, upload_files=True - ) + query.variable_values = { + "file": FileVar(f, filename="filename1.txt"), + "other_var": 42, + } + execution_result = session.execute(query, upload_files=True) - assert execution_result.data["success"] + assert execution_result["success"] - await run_sync_test(event_loop, server, test_code) + await run_sync_test(server, test_code) @pytest.mark.aiohttp @pytest.mark.asyncio -async def test_requests_binary_file_upload(event_loop, aiohttp_server, run_sync_test): +async def test_requests_file_upload_additional_headers(aiohttp_server, run_sync_test): from aiohttp import web + from gql.transport.requests import RequestsHTTPTransport - # This is a sample binary file content containing all possible byte values - binary_file_content = bytes(range(0, 256)) + app = web.Application() + app.router.add_route( + "POST", + "/", + make_upload_handler( + request_headers={"X-Auth": "foobar"}, + expected_map=file_upload_mutation_1_map, + expected_operations=file_upload_mutation_1_operations, + expected_contents=[file_1_content], + ), + ) + server = await aiohttp_server(app) - async def binary_upload_handler(request): + url = server.make_url("/") - from aiohttp import web + def test_code(): + transport = RequestsHTTPTransport(url=url, headers={"X-Auth": "foobar"}) - reader = await request.multipart() + with TemporaryFile(file_1_content) as test_file: 
+ with Client(transport=transport) as session: + query = gql(file_upload_mutation_1) - field_0 = await reader.next() - assert field_0.name == "operations" - field_0_text = await field_0.text() - assert strip_braces_spaces(field_0_text) == file_upload_mutation_1_operations + file_path = test_file.filename - field_1 = await reader.next() - assert field_1.name == "map" - field_1_text = await field_1.text() - assert field_1_text == file_upload_mutation_1_map + with open(file_path, "rb") as f: - field_2 = await reader.next() - assert field_2.name == "0" - field_2_binary = await field_2.read() - assert field_2_binary == binary_file_content + query.variable_values = {"file": f, "other_var": 42} + with pytest.warns( + DeprecationWarning, + match="Not using FileVar for file upload is deprecated", + ): + execution_result = session.execute(query, upload_files=True) - field_3 = await reader.next() - assert field_3 is None + assert execution_result["success"] - return web.Response( - text=file_upload_server_answer, content_type="application/json" - ) + await run_sync_test(server, test_code) + + +@pytest.mark.aiohttp +@pytest.mark.asyncio +async def test_requests_binary_file_upload(aiohttp_server, run_sync_test): + from aiohttp import web + + from gql.transport.requests import RequestsHTTPTransport + + # This is a sample binary file content containing all possible byte values + binary_file_content = bytes(range(0, 256)) app = web.Application() - app.router.add_route("POST", "/", binary_upload_handler) + app.router.add_route( + "POST", + "/", + make_upload_handler( + binary=True, + expected_contents=[binary_file_content], + expected_map=file_upload_mutation_1_map, + expected_operations=file_upload_mutation_1_operations, + ), + ) server = await aiohttp_server(app) url = server.make_url("/") @@ -785,30 +837,24 @@ def test_code(): with open(file_path, "rb") as f: - params = {"file": f, "other_var": 42} - - execution_result = session._execute( - query, variable_values=params, 
upload_files=True - ) + query.variable_values = {"file": f, "other_var": 42} - assert execution_result.data["success"] + with pytest.warns( + DeprecationWarning, + match="Not using FileVar for file upload is deprecated", + ): + execution_result = session.execute(query, upload_files=True) - await run_sync_test(event_loop, server, test_code) + assert execution_result["success"] - -file_upload_mutation_2_operations = ( - '{"query": "mutation ($file1: Upload!, $file2: Upload!) {\\n ' - 'uploadFile(input: {file1: $file, file2: $file}) {\\n success\\n }\\n}", ' - '"variables": {"file1": null, "file2": null}}' -) + await run_sync_test(server, test_code) @pytest.mark.aiohttp @pytest.mark.asyncio -async def test_requests_file_upload_two_files( - event_loop, aiohttp_server, run_sync_test -): +async def test_requests_file_upload_two_files(aiohttp_server, run_sync_test): from aiohttp import web + from gql.transport.requests import RequestsHTTPTransport file_upload_mutation_2 = """ @@ -819,6 +865,12 @@ async def test_requests_file_upload_two_files( } """ + file_upload_mutation_2_operations = ( + '{"query": "mutation ($file1: Upload!, $file2: Upload!) 
{\\n ' + 'uploadFile(input: {file1: $file, file2: $file}) {\\n success\\n }\\n}", ' + '"variables": {"file1": null, "file2": null}}' + ) + file_upload_mutation_2_map = '{"0": ["variables.file1"], "1": ["variables.file2"]}' file_2_content = """ @@ -826,39 +878,17 @@ async def test_requests_file_upload_two_files( This file will also be sent in the GraphQL mutation """ - async def handler(request): - - reader = await request.multipart() - - field_0 = await reader.next() - assert field_0.name == "operations" - field_0_text = await field_0.text() - assert strip_braces_spaces(field_0_text) == file_upload_mutation_2_operations - - field_1 = await reader.next() - assert field_1.name == "map" - field_1_text = await field_1.text() - assert field_1_text == file_upload_mutation_2_map - - field_2 = await reader.next() - assert field_2.name == "0" - field_2_text = await field_2.text() - assert field_2_text == file_1_content - - field_3 = await reader.next() - assert field_3.name == "1" - field_3_text = await field_3.text() - assert field_3_text == file_2_content - - field_4 = await reader.next() - assert field_4 is None - - return web.Response( - text=file_upload_server_answer, content_type="application/json" - ) - app = web.Application() - app.router.add_route("POST", "/", handler) + app.router.add_route( + "POST", + "/", + make_upload_handler( + nb_files=2, + expected_map=file_upload_mutation_2_map, + expected_operations=file_upload_mutation_2_operations, + expected_contents=[file_1_content, file_2_content], + ), + ) server = await aiohttp_server(app) url = server.make_url("/") @@ -873,42 +903,56 @@ def test_code(): query = gql(file_upload_mutation_2) + # Old method file_path_1 = test_file_1.filename file_path_2 = test_file_2.filename f1 = open(file_path_1, "rb") f2 = open(file_path_2, "rb") - params = { + query.variable_values = { "file1": f1, "file2": f2, } - execution_result = session._execute( - query, variable_values=params, upload_files=True - ) + with pytest.warns( + 
DeprecationWarning, + match="Not using FileVar for file upload is deprecated", + ): + execution_result = session.execute(query, upload_files=True) - assert execution_result.data["success"] + assert execution_result["success"] f1.close() f2.close() - await run_sync_test(event_loop, server, test_code) + # Using FileVar + file_path_1 = test_file_1.filename + file_path_2 = test_file_2.filename + + f1 = open(file_path_1, "rb") + f2 = open(file_path_2, "rb") + + query.variable_values = { + "file1": FileVar(f1), + "file2": FileVar(f2), + } + execution_result = session.execute(query, upload_files=True) -file_upload_mutation_3_operations = ( - '{"query": "mutation ($files: [Upload!]!) {\\n uploadFiles' - "(input: {files: $files})" - ' {\\n success\\n }\\n}", "variables": {"files": [null, null]}}' -) + assert execution_result["success"] + + f1.close() + f2.close() + + await run_sync_test(server, test_code) @pytest.mark.aiohttp @pytest.mark.asyncio -async def test_requests_file_upload_list_of_two_files( - event_loop, aiohttp_server, run_sync_test -): +async def test_requests_file_upload_list_of_two_files(aiohttp_server, run_sync_test): from aiohttp import web + from gql.transport.requests import RequestsHTTPTransport file_upload_mutation_3 = """ @@ -919,6 +963,12 @@ async def test_requests_file_upload_list_of_two_files( } """ + file_upload_mutation_3_operations = ( + '{"query": "mutation ($files: [Upload!]!) 
{\\n uploadFiles' + "(input: {files: $files})" + ' {\\n success\\n }\\n}", "variables": {"files": [null, null]}}' + ) + file_upload_mutation_3_map = ( '{"0": ["variables.files.0"], "1": ["variables.files.1"]}' ) @@ -928,39 +978,17 @@ async def test_requests_file_upload_list_of_two_files( This file will also be sent in the GraphQL mutation """ - async def handler(request): - - reader = await request.multipart() - - field_0 = await reader.next() - assert field_0.name == "operations" - field_0_text = await field_0.text() - assert strip_braces_spaces(field_0_text) == file_upload_mutation_3_operations - - field_1 = await reader.next() - assert field_1.name == "map" - field_1_text = await field_1.text() - assert field_1_text == file_upload_mutation_3_map - - field_2 = await reader.next() - assert field_2.name == "0" - field_2_text = await field_2.text() - assert field_2_text == file_1_content - - field_3 = await reader.next() - assert field_3.name == "1" - field_3_text = await field_3.text() - assert field_3_text == file_2_content - - field_4 = await reader.next() - assert field_4 is None - - return web.Response( - text=file_upload_server_answer, content_type="application/json" - ) - app = web.Application() - app.router.add_route("POST", "/", handler) + app.router.add_route( + "POST", + "/", + make_upload_handler( + nb_files=2, + expected_map=file_upload_mutation_3_map, + expected_operations=file_upload_mutation_3_operations, + expected_contents=[file_1_content, file_2_content], + ), + ) server = await aiohttp_server(app) url = server.make_url("/") @@ -974,32 +1002,50 @@ def test_code(): query = gql(file_upload_mutation_3) + # Old method file_path_1 = test_file_1.filename file_path_2 = test_file_2.filename f1 = open(file_path_1, "rb") f2 = open(file_path_2, "rb") - params = {"files": [f1, f2]} + query.variable_values = {"files": [f1, f2]} - execution_result = session._execute( - query, variable_values=params, upload_files=True - ) + with pytest.warns( + 
DeprecationWarning, + match="Not using FileVar for file upload is deprecated", + ): + execution_result = session.execute(query, upload_files=True) - assert execution_result.data["success"] + assert execution_result["success"] f1.close() f2.close() - await run_sync_test(event_loop, server, test_code) + # Using FileVar + file_path_1 = test_file_1.filename + file_path_2 = test_file_2.filename + + f1 = open(file_path_1, "rb") + f2 = open(file_path_2, "rb") + + query.variable_values = {"files": [FileVar(f1), FileVar(f2)]} + + execution_result = session.execute(query, upload_files=True) + + assert execution_result["success"] + + f1.close() + f2.close() + + await run_sync_test(server, test_code) @pytest.mark.aiohttp @pytest.mark.asyncio -async def test_requests_error_fetching_schema( - event_loop, aiohttp_server, run_sync_test -): +async def test_requests_error_fetching_schema(aiohttp_server, run_sync_test): from aiohttp import web + from gql.transport.requests import RequestsHTTPTransport error_answer = """ @@ -1040,16 +1086,16 @@ def test_code(): assert expected_error in str(exc_info.value) assert transport.session is None - await run_sync_test(event_loop, server, test_code) + await run_sync_test(server, test_code) @pytest.mark.aiohttp @pytest.mark.asyncio -async def test_requests_json_serializer( - event_loop, aiohttp_server, run_sync_test, caplog -): +async def test_requests_json_serializer(aiohttp_server, run_sync_test, caplog): import json + from aiohttp import web + from gql.transport.requests import RequestsHTTPTransport async def handler(request): @@ -1091,7 +1137,7 @@ def test_code(): expected_log = '"query":"query getContinents' assert expected_log in caplog.text - await run_sync_test(event_loop, server, test_code) + await run_sync_test(server, test_code) query_float_str = """ @@ -1107,11 +1153,13 @@ def test_code(): @pytest.mark.aiohttp @pytest.mark.asyncio -async def test_requests_json_deserializer(event_loop, aiohttp_server, run_sync_test): +async def 
test_requests_json_deserializer(aiohttp_server, run_sync_test): import json - from aiohttp import web from decimal import Decimal from functools import partial + + from aiohttp import web + from gql.transport.requests import RequestsHTTPTransport async def handler(request): @@ -1146,4 +1194,4 @@ def test_code(): assert pi == Decimal("3.141592653589793238462643383279502884197") - await run_sync_test(event_loop, server, test_code) + await run_sync_test(server, test_code) diff --git a/tests/test_requests_batch.py b/tests/test_requests_batch.py index 4d8bf27e..a2f0cdbf 100644 --- a/tests/test_requests_batch.py +++ b/tests/test_requests_batch.py @@ -48,8 +48,9 @@ @pytest.mark.aiohttp @pytest.mark.asyncio -async def test_requests_query(event_loop, aiohttp_server, run_sync_test): +async def test_requests_query(aiohttp_server, run_sync_test): from aiohttp import web + from gql.transport.requests import RequestsHTTPTransport async def handler(request): @@ -70,7 +71,7 @@ def test_code(): with Client(transport=transport) as session: - query = [GraphQLRequest(document=gql(query1_str))] + query = [GraphQLRequest(query1_str)] # Execute query synchronously results = session.execute_batch(query) @@ -86,15 +87,14 @@ def test_code(): assert isinstance(transport.response_headers, Mapping) assert transport.response_headers["dummy"] == "test1234" - await run_sync_test(event_loop, server, test_code) + await run_sync_test(server, test_code) @pytest.mark.aiohttp @pytest.mark.asyncio -async def test_requests_query_auto_batch_enabled( - event_loop, aiohttp_server, run_sync_test -): +async def test_requests_query_auto_batch_enabled(aiohttp_server, run_sync_test): from aiohttp import web + from gql.transport.requests import RequestsHTTPTransport async def handler(request): @@ -134,17 +134,19 @@ def test_code(): assert isinstance(transport.response_headers, Mapping) assert transport.response_headers["dummy"] == "test1234" - await run_sync_test(event_loop, server, test_code) + await 
run_sync_test(server, test_code) @pytest.mark.aiohttp @pytest.mark.asyncio async def test_requests_query_auto_batch_enabled_two_requests( - event_loop, aiohttp_server, run_sync_test + aiohttp_server, run_sync_test ): + from threading import Thread + from aiohttp import web + from gql.transport.requests import RequestsHTTPTransport - from threading import Thread async def handler(request): return web.Response( @@ -194,13 +196,14 @@ def test_thread(): for thread in threads: thread.join() - await run_sync_test(event_loop, server, test_code) + await run_sync_test(server, test_code) @pytest.mark.aiohttp @pytest.mark.asyncio -async def test_requests_cookies(event_loop, aiohttp_server, run_sync_test): +async def test_requests_cookies(aiohttp_server, run_sync_test): from aiohttp import web + from gql.transport.requests import RequestsHTTPTransport async def handler(request): @@ -222,7 +225,7 @@ def test_code(): with Client(transport=transport) as session: - query = [GraphQLRequest(document=gql(query1_str))] + query = [GraphQLRequest(query1_str)] # Execute query synchronously results = session.execute_batch(query) @@ -233,13 +236,14 @@ def test_code(): assert africa["code"] == "AF" - await run_sync_test(event_loop, server, test_code) + await run_sync_test(server, test_code) @pytest.mark.aiohttp @pytest.mark.asyncio -async def test_requests_error_code_401(event_loop, aiohttp_server, run_sync_test): +async def test_requests_error_code_401(aiohttp_server, run_sync_test): from aiohttp import web + from gql.transport.requests import RequestsHTTPTransport async def handler(request): @@ -261,22 +265,23 @@ def test_code(): with Client(transport=transport) as session: - query = [GraphQLRequest(document=gql(query1_str))] + query = [GraphQLRequest(query1_str)] with pytest.raises(TransportServerError) as exc_info: session.execute_batch(query) assert "401 Client Error: Unauthorized" in str(exc_info.value) - await run_sync_test(event_loop, server, test_code) + await run_sync_test(server, 
test_code) @pytest.mark.aiohttp @pytest.mark.asyncio async def test_requests_error_code_401_auto_batch_enabled( - event_loop, aiohttp_server, run_sync_test + aiohttp_server, run_sync_test ): from aiohttp import web + from gql.transport.requests import RequestsHTTPTransport async def handler(request): @@ -308,13 +313,14 @@ def test_code(): assert "401 Client Error: Unauthorized" in str(exc_info.value) - await run_sync_test(event_loop, server, test_code) + await run_sync_test(server, test_code) @pytest.mark.aiohttp @pytest.mark.asyncio -async def test_requests_error_code_429(event_loop, aiohttp_server, run_sync_test): +async def test_requests_error_code_429(aiohttp_server, run_sync_test): from aiohttp import web + from gql.transport.requests import RequestsHTTPTransport async def handler(request): @@ -347,7 +353,7 @@ def test_code(): with Client(transport=transport) as session: - query = [GraphQLRequest(document=gql(query1_str))] + query = [GraphQLRequest(query1_str)] with pytest.raises(TransportServerError) as exc_info: session.execute_batch(query) @@ -362,8 +368,9 @@ def test_code(): @pytest.mark.aiohttp @pytest.mark.asyncio -async def test_requests_error_code_500(event_loop, aiohttp_server, run_sync_test): +async def test_requests_error_code_500(aiohttp_server, run_sync_test): from aiohttp import web + from gql.transport.requests import RequestsHTTPTransport async def handler(request): @@ -381,12 +388,12 @@ def test_code(): with Client(transport=transport) as session: - query = [GraphQLRequest(document=gql(query1_str))] + query = [GraphQLRequest(query1_str)] with pytest.raises(TransportServerError): session.execute_batch(query) - await run_sync_test(event_loop, server, test_code) + await run_sync_test(server, test_code) query1_server_error_answer_list = '[{"errors": ["Error 1", "Error 2"]}]' @@ -394,8 +401,9 @@ def test_code(): @pytest.mark.aiohttp @pytest.mark.asyncio -async def test_requests_error_code(event_loop, aiohttp_server, run_sync_test): +async def 
test_requests_error_code(aiohttp_server, run_sync_test): from aiohttp import web + from gql.transport.requests import RequestsHTTPTransport async def handler(request): @@ -414,12 +422,12 @@ def test_code(): with Client(transport=transport) as session: - query = [GraphQLRequest(document=gql(query1_str))] + query = [GraphQLRequest(query1_str)] with pytest.raises(TransportQueryError): session.execute_batch(query) - await run_sync_test(event_loop, server, test_code) + await run_sync_test(server, test_code) invalid_protocol_responses = [ @@ -437,10 +445,9 @@ def test_code(): @pytest.mark.aiohttp @pytest.mark.asyncio @pytest.mark.parametrize("response", invalid_protocol_responses) -async def test_requests_invalid_protocol( - event_loop, aiohttp_server, response, run_sync_test -): +async def test_requests_invalid_protocol(aiohttp_server, response, run_sync_test): from aiohttp import web + from gql.transport.requests import RequestsHTTPTransport async def handler(request): @@ -457,20 +464,19 @@ def test_code(): with Client(transport=transport) as session: - query = [GraphQLRequest(document=gql(query1_str))] + query = [GraphQLRequest(query1_str)] with pytest.raises(TransportProtocolError): session.execute_batch(query) - await run_sync_test(event_loop, server, test_code) + await run_sync_test(server, test_code) @pytest.mark.aiohttp @pytest.mark.asyncio -async def test_requests_cannot_execute_if_not_connected( - event_loop, aiohttp_server, run_sync_test -): +async def test_requests_cannot_execute_if_not_connected(aiohttp_server, run_sync_test): from aiohttp import web + from gql.transport.requests import RequestsHTTPTransport async def handler(request): @@ -487,12 +493,12 @@ async def handler(request): def test_code(): transport = RequestsHTTPTransport(url=url) - query = [GraphQLRequest(document=gql(query1_str))] + query = [GraphQLRequest(query1_str)] with pytest.raises(TransportClosed): transport.execute_batch(query) - await run_sync_test(event_loop, server, test_code) + 
await run_sync_test(server, test_code) query1_server_answer_with_extensions_list = ( @@ -508,10 +514,9 @@ def test_code(): @pytest.mark.aiohttp @pytest.mark.asyncio -async def test_requests_query_with_extensions( - event_loop, aiohttp_server, run_sync_test -): +async def test_requests_query_with_extensions(aiohttp_server, run_sync_test): from aiohttp import web + from gql.transport.requests import RequestsHTTPTransport async def handler(request): @@ -531,26 +536,24 @@ def test_code(): with Client(transport=transport) as session: - query = [GraphQLRequest(document=gql(query1_str))] + query = [GraphQLRequest(query1_str)] execution_results = session.execute_batch(query, get_execution_result=True) assert execution_results[0].extensions["key1"] == "val1" - await run_sync_test(event_loop, server, test_code) + await run_sync_test(server, test_code) -ONLINE_URL = "https://countries.trevorblades.com/" - -skip_reason = "backend does not support batching anymore..." +ONLINE_URL = "https://countries.trevorblades.workers.dev/graphql" @pytest.mark.online @pytest.mark.requests -@pytest.mark.skip(reason=skip_reason) def test_requests_sync_batch_auto(): from threading import Thread + from gql.transport.requests import RequestsHTTPTransport client = Client( @@ -613,7 +616,6 @@ def get_continent_name(session, continent_code): @pytest.mark.online @pytest.mark.requests -@pytest.mark.skip(reason=skip_reason) def test_requests_sync_batch_auto_execute_future(): from gql.transport.requests import RequestsHTTPTransport @@ -624,15 +626,13 @@ def test_requests_sync_batch_auto_execute_future(): batch_max=3, ) - query = gql( - """ + query = """ query getContinentName($continent_code: ID!) 
{ continent(code: $continent_code) { name } } - """ - ) + """ with client as session: @@ -651,7 +651,6 @@ def test_requests_sync_batch_auto_execute_future(): @pytest.mark.online @pytest.mark.requests -@pytest.mark.skip(reason=skip_reason) def test_requests_sync_batch_manual(): from gql.transport.requests import RequestsHTTPTransport @@ -660,15 +659,13 @@ def test_requests_sync_batch_manual(): transport=RequestsHTTPTransport(url=ONLINE_URL), ) - query = gql( - """ + query = """ query getContinentName($continent_code: ID!) { continent(code: $continent_code) { name } } - """ - ) + """ with client as session: diff --git a/tests/test_transport.py b/tests/test_transport.py index d9a3eced..7c2a5a8f 100644 --- a/tests/test_transport.py +++ b/tests/test_transport.py @@ -28,6 +28,7 @@ def use_cassette(name): @pytest.fixture def client(): import requests + from gql.transport.requests import RequestsHTTPTransport with use_cassette("client"): @@ -42,6 +43,9 @@ def client(): url=URL, cookies={"csrftoken": csrf}, headers={"x-csrftoken": csrf} ), fetch_schema_from_transport=True, + introspection_args={ + "input_value_deprecation": False, + }, ) @@ -96,9 +100,10 @@ def test_query_with_variable(client): } """ ) + query.variable_values = {"id": "UGxhbmV0OjEw"} expected = {"planet": {"id": "UGxhbmV0OjEw", "name": "Kamino"}} with use_cassette("queries"): - result = client.execute(query, variable_values={"id": "UGxhbmV0OjEw"}) + result = client.execute(query) assert result == expected @@ -119,9 +124,10 @@ def test_named_query(client): } """ ) + query.operation_name = "Planet2" expected = {"planet": {"id": "UGxhbmV0OjEx", "name": "Geonosis"}} with use_cassette("queries"): - result = client.execute(query, operation_name="Planet2") + result = client.execute(query) assert result == expected diff --git a/tests/test_transport_batch.py b/tests/test_transport_batch.py index a9b21e6a..671858e7 100644 --- a/tests/test_transport_batch.py +++ b/tests/test_transport_batch.py @@ -2,7 +2,7 @@ import 
pytest -from gql import Client, GraphQLRequest, gql +from gql import Client, gql # We serve https://github.com/graphql-python/swapi-graphene locally: URL = "http://127.0.0.1:8000/graphql" @@ -28,6 +28,7 @@ def use_cassette(name): @pytest.fixture def client(): import requests + from gql.transport.requests import RequestsHTTPTransport with use_cassette("client"): @@ -42,6 +43,9 @@ def client(): url=URL, cookies={"csrftoken": csrf}, headers={"x-csrftoken": csrf} ), fetch_schema_from_transport=True, + introspection_args={ + "input_value_deprecation": False, + }, ) @@ -83,7 +87,7 @@ def test_hero_name_query(client): } ] with use_cassette("queries_batch"): - results = client.execute_batch([GraphQLRequest(document=query)]) + results = client.execute_batch([query]) assert results == expected @@ -98,11 +102,10 @@ def test_query_with_variable(client): } """ ) + query.variable_values = {"id": "UGxhbmV0OjEw"} expected = [{"planet": {"id": "UGxhbmV0OjEw", "name": "Kamino"}}] with use_cassette("queries_batch"): - results = client.execute_batch( - [GraphQLRequest(document=query, variable_values={"id": "UGxhbmV0OjEw"})] - ) + results = client.execute_batch([query]) assert results == expected @@ -123,11 +126,10 @@ def test_named_query(client): } """ ) + query.operation_name = "Planet2" expected = [{"planet": {"id": "UGxhbmV0OjEx", "name": "Geonosis"}}] with use_cassette("queries_batch"): - results = client.execute_batch( - [GraphQLRequest(document=query, operation_name="Planet2")] - ) + results = client.execute_batch([query]) assert results == expected @@ -145,7 +147,7 @@ def test_header_query(client): expected = [{"planet": {"id": "UGxhbmV0OjEx", "name": "Geonosis"}}] with use_cassette("queries_batch"): results = client.execute_batch( - [GraphQLRequest(document=query)], + [query], extra_args={"headers": {"authorization": "xxx-123"}}, ) assert results == expected diff --git a/tests/test_websocket_exceptions.py b/tests/test_websocket_exceptions.py index cb9e7274..b6169468 100644 --- 
a/tests/test_websocket_exceptions.py +++ b/tests/test_websocket_exceptions.py @@ -8,7 +8,7 @@ from gql import Client, gql from gql.transport.exceptions import ( TransportAlreadyConnected, - TransportClosed, + TransportConnectionFailed, TransportProtocolError, TransportQueryError, ) @@ -41,7 +41,7 @@ @pytest.mark.asyncio @pytest.mark.parametrize("server", [invalid_query1_server], indirect=True) @pytest.mark.parametrize("query_str", [invalid_query_str]) -async def test_websocket_invalid_query(event_loop, client_and_server, query_str): +async def test_websocket_invalid_query(client_and_server, query_str): session, server = client_and_server @@ -80,7 +80,7 @@ async def server_invalid_subscription(ws): @pytest.mark.asyncio @pytest.mark.parametrize("server", [server_invalid_subscription], indirect=True) @pytest.mark.parametrize("query_str", [invalid_subscription_str]) -async def test_websocket_invalid_subscription(event_loop, client_and_server, query_str): +async def test_websocket_invalid_subscription(client_and_server, query_str): session, server = client_and_server @@ -112,15 +112,15 @@ async def server_no_ack(ws): @pytest.mark.asyncio @pytest.mark.parametrize("server", [server_no_ack], indirect=True) @pytest.mark.parametrize("query_str", [invalid_query_str]) -async def test_websocket_server_does_not_send_ack(event_loop, server, query_str): +async def test_websocket_server_does_not_send_ack(server, query_str): from gql.transport.websockets import WebsocketsTransport url = f"ws://{server.hostname}:{server.port}/graphql" - sample_transport = WebsocketsTransport(url=url, ack_timeout=1) + transport = WebsocketsTransport(url=url, ack_timeout=0.1) with pytest.raises(asyncio.TimeoutError): - async with Client(transport=sample_transport): + async with Client(transport=transport): pass @@ -135,13 +135,13 @@ async def server_connection_error(ws): @pytest.mark.asyncio @pytest.mark.parametrize("server", [server_connection_error], indirect=True) 
@pytest.mark.parametrize("query_str", [invalid_query_str]) -async def test_websocket_sending_invalid_data(event_loop, client_and_server, query_str): +async def test_websocket_sending_invalid_data(client_and_server, query_str): session, server = client_and_server invalid_data = "QSDF" print(f">>> {invalid_data}") - await session.transport.websocket.send(invalid_data) + await session.transport.adapter.websocket.send(invalid_data) await asyncio.sleep(2 * MS) @@ -163,9 +163,7 @@ async def server_invalid_payload(ws): @pytest.mark.asyncio @pytest.mark.parametrize("server", [server_invalid_payload], indirect=True) @pytest.mark.parametrize("query_str", [invalid_query_str]) -async def test_websocket_sending_invalid_payload( - event_loop, client_and_server, query_str -): +async def test_websocket_sending_invalid_payload(client_and_server, query_str): session, server = client_and_server @@ -176,7 +174,7 @@ async def monkey_patch_send_query( document, variable_values=None, operation_name=None, - ) -> int: + ): query_id = self.next_query_id self.next_query_id += 1 @@ -234,7 +232,7 @@ async def monkey_patch_send_query( ], indirect=True, ) -async def test_websocket_transport_protocol_errors(event_loop, client_and_server): +async def test_websocket_transport_protocol_errors(client_and_server): session, server = client_and_server @@ -252,16 +250,16 @@ async def server_without_ack(ws): @pytest.mark.asyncio @pytest.mark.parametrize("server", [server_without_ack], indirect=True) -async def test_websocket_server_does_not_ack(event_loop, server): +async def test_websocket_server_does_not_ack(server): from gql.transport.websockets import WebsocketsTransport url = f"ws://{server.hostname}:{server.port}/graphql" print(f"url = {url}") - sample_transport = WebsocketsTransport(url=url) + transport = WebsocketsTransport(url=url) with pytest.raises(TransportProtocolError): - async with Client(transport=sample_transport): + async with Client(transport=transport): pass @@ -271,17 +269,16 @@ async 
def server_closing_directly(ws): @pytest.mark.asyncio @pytest.mark.parametrize("server", [server_closing_directly], indirect=True) -async def test_websocket_server_closing_directly(event_loop, server): - import websockets +async def test_websocket_server_closing_directly(server): from gql.transport.websockets import WebsocketsTransport url = f"ws://{server.hostname}:{server.port}/graphql" print(f"url = {url}") - sample_transport = WebsocketsTransport(url=url) + transport = WebsocketsTransport(url=url) - with pytest.raises(websockets.exceptions.ConnectionClosed): - async with Client(transport=sample_transport): + with pytest.raises(TransportConnectionFailed): + async with Client(transport=transport): pass @@ -292,22 +289,42 @@ async def server_closing_after_ack(ws): @pytest.mark.asyncio @pytest.mark.parametrize("server", [server_closing_after_ack], indirect=True) -async def test_websocket_server_closing_after_ack(event_loop, client_and_server): - - import websockets +async def test_websocket_server_closing_after_ack(client_and_server): session, server = client_and_server query = gql("query { hello }") - with pytest.raises(websockets.exceptions.ConnectionClosed): + print("\n Trying to execute first query.\n") + + with pytest.raises(TransportConnectionFailed) as exc1: await session.execute(query) + exc1_cause = exc1.value.__cause__ + exc1_cause_str = f"{type(exc1_cause).__name__}:{exc1_cause!s}" + + print(f"\n First query Exception cause: {exc1_cause_str}\n") + + assert ( + exc1_cause_str == "ConnectionClosedOK:received 1000 (OK); then sent 1000 (OK)" + ) + await session.transport.wait_closed() - with pytest.raises(TransportClosed): + print("\n Trying to execute second query.\n") + + with pytest.raises(TransportConnectionFailed) as exc2: await session.execute(query) + exc2_cause = exc2.value.__cause__ + exc2_cause_str = f"{type(exc2_cause).__name__}:{exc2_cause!s}" + + print(f" Second query Exception cause: {exc2_cause_str}\n") + + assert ( + exc2_cause_str == 
"ConnectionClosedOK:received 1000 (OK); then sent 1000 (OK)" + ) + async def server_sending_invalid_query_errors(ws): await WebSocketServerHelper.send_connection_ack(ws) @@ -321,22 +338,22 @@ async def server_sending_invalid_query_errors(ws): @pytest.mark.asyncio @pytest.mark.parametrize("server", [server_sending_invalid_query_errors], indirect=True) -async def test_websocket_server_sending_invalid_query_errors(event_loop, server): +async def test_websocket_server_sending_invalid_query_errors(server): from gql.transport.websockets import WebsocketsTransport url = f"ws://{server.hostname}:{server.port}/graphql" print(f"url = {url}") - sample_transport = WebsocketsTransport(url=url) + transport = WebsocketsTransport(url=url) # Invalid server message is ignored - async with Client(transport=sample_transport): + async with Client(transport=transport): await asyncio.sleep(2 * MS) @pytest.mark.asyncio @pytest.mark.parametrize("server", [server_sending_invalid_query_errors], indirect=True) -async def test_websocket_non_regression_bug_105(event_loop, server): +async def test_websocket_non_regression_bug_105(server): from gql.transport.websockets import WebsocketsTransport # This test will check a fix to a race condition which happens if the user is trying @@ -346,9 +363,9 @@ async def test_websocket_non_regression_bug_105(event_loop, server): url = f"ws://{server.hostname}:{server.port}/graphql" print(f"url = {url}") - sample_transport = WebsocketsTransport(url=url) + transport = WebsocketsTransport(url=url) - client = Client(transport=sample_transport) + client = Client(transport=transport) # Create a coroutine which start the connection with the transport but does nothing async def client_connect(client): @@ -365,16 +382,15 @@ async def client_connect(client): @pytest.mark.asyncio @pytest.mark.parametrize("server", [invalid_query1_server], indirect=True) -async def test_websocket_using_cli_invalid_query( - event_loop, server, monkeypatch, capsys -): +async def 
test_websocket_using_cli_invalid_query(server, monkeypatch, capsys): url = f"ws://{server.hostname}:{server.port}/graphql" print(f"url = {url}") - from gql.cli import main, get_parser import io + from gql.cli import get_parser, main + parser = get_parser(with_examples=True) args = parser.parse_args([url]) diff --git a/tests/test_websocket_online.py b/tests/test_websocket_online.py index fa288b6d..c53be5f4 100644 --- a/tests/test_websocket_online.py +++ b/tests/test_websocket_online.py @@ -27,12 +27,10 @@ async def test_websocket_simple_query(): from gql.transport.websockets import WebsocketsTransport # Get Websockets transport - sample_transport = WebsocketsTransport( - url="wss://countries.trevorblades.com/graphql" - ) + transport = WebsocketsTransport(url="wss://countries.trevorblades.com/graphql") # Instanciate client - async with Client(transport=sample_transport) as session: + async with Client(transport=transport) as session: query = gql( """ @@ -68,12 +66,12 @@ async def test_websocket_invalid_query(): from gql.transport.websockets import WebsocketsTransport # Get Websockets transport - sample_transport = WebsocketsTransport( + transport = WebsocketsTransport( url="wss://countries.trevorblades.com/graphql", ssl=True ) # Instanciate client - async with Client(transport=sample_transport) as session: + async with Client(transport=transport) as session: query = gql( """ @@ -98,12 +96,12 @@ async def test_websocket_sending_invalid_data(): from gql.transport.websockets import WebsocketsTransport # Get Websockets transport - sample_transport = WebsocketsTransport( + transport = WebsocketsTransport( url="wss://countries.trevorblades.com/graphql", ssl=True ) # Instanciate client - async with Client(transport=sample_transport) as session: + async with Client(transport=transport) as session: query = gql( """ @@ -122,7 +120,8 @@ async def test_websocket_sending_invalid_data(): invalid_data = "QSDF" print(f">>> {invalid_data}") - await 
sample_transport.websocket.send(invalid_data) + assert transport.adapter.websocket is not None + await transport.adapter.websocket.send(invalid_data) await asyncio.sleep(2) @@ -134,17 +133,18 @@ async def test_websocket_sending_invalid_payload(): from gql.transport.websockets import WebsocketsTransport # Get Websockets transport - sample_transport = WebsocketsTransport( + transport = WebsocketsTransport( url="wss://countries.trevorblades.com/graphql", ssl=True ) # Instanciate client - async with Client(transport=sample_transport): + async with Client(transport=transport): invalid_payload = '{"id": "1", "type": "start", "payload": "BLAHBLAH"}' print(f">>> {invalid_payload}") - await sample_transport.websocket.send(invalid_payload) + assert transport.adapter.websocket is not None + await transport.adapter.websocket.send(invalid_payload) await asyncio.sleep(2) @@ -156,12 +156,12 @@ async def test_websocket_sending_invalid_data_while_other_query_is_running(): from gql.transport.websockets import WebsocketsTransport # Get Websockets transport - sample_transport = WebsocketsTransport( + transport = WebsocketsTransport( url="wss://countries.trevorblades.com/graphql", ssl=True ) # Instanciate client - async with Client(transport=sample_transport) as session: + async with Client(transport=transport) as session: query = gql( """ @@ -190,7 +190,8 @@ async def query_task2(): invalid_data = "QSDF" print(f">>> {invalid_data}") - await sample_transport.websocket.send(invalid_data) + assert transport.adapter.websocket is not None + await transport.adapter.websocket.send(invalid_data) task1 = asyncio.create_task(query_task1()) task2 = asyncio.create_task(query_task2()) @@ -207,12 +208,12 @@ async def test_websocket_two_queries_in_parallel_using_two_tasks(): from gql.transport.websockets import WebsocketsTransport # Get Websockets transport - sample_transport = WebsocketsTransport( + transport = WebsocketsTransport( url="wss://countries.trevorblades.com/graphql", ssl=True ) # 
Instanciate client - async with Client(transport=sample_transport) as session: + async with Client(transport=transport) as session: query1 = gql( """ diff --git a/tests/test_websocket_query.py b/tests/test_websocket_query.py index 2c723b3f..979bb99b 100644 --- a/tests/test_websocket_query.py +++ b/tests/test_websocket_query.py @@ -1,14 +1,14 @@ import asyncio import json import sys -from typing import Dict, Mapping +from typing import Any, Dict, Mapping import pytest from gql import Client, gql from gql.transport.exceptions import ( TransportAlreadyConnected, - TransportClosed, + TransportConnectionFailed, TransportQueryError, TransportServerError, ) @@ -50,20 +50,21 @@ @pytest.mark.asyncio @pytest.mark.parametrize("server", [server1_answers], indirect=True) -async def test_websocket_starting_client_in_context_manager(event_loop, server): - import websockets +async def test_websocket_starting_client_in_context_manager(server): from gql.transport.websockets import WebsocketsTransport url = f"ws://{server.hostname}:{server.port}/graphql" print(f"url = {url}") - transport = WebsocketsTransport(url=url) + transport = WebsocketsTransport(url=url, headers={"test": "1234"}) + + assert transport.response_headers == {} + assert isinstance(transport.headers, Mapping) + assert transport.headers["test"] == "1234" async with Client(transport=transport) as session: - assert isinstance( - transport.websocket, websockets.client.WebSocketClientProtocol - ) + assert transport._connected is True query1 = gql(query1_str) @@ -85,15 +86,14 @@ async def test_websocket_starting_client_in_context_manager(event_loop, server): assert transport.response_headers["dummy"] == "test1234" # Check client is disconnect here - assert transport.websocket is None + assert transport._connected is False -@pytest.mark.skip(reason="ssl=False is not working for now") @pytest.mark.asyncio @pytest.mark.parametrize("ws_ssl_server", [server1_answers], indirect=True) -@pytest.mark.parametrize("verify_https", 
["disabled", "cert_provided"]) -async def test_websocket_using_ssl_connection(event_loop, ws_ssl_server, verify_https): +async def test_websocket_using_ssl_connection(ws_ssl_server): import websockets + from gql.transport.websockets import WebsocketsTransport server = ws_ssl_server @@ -103,20 +103,15 @@ async def test_websocket_using_ssl_connection(event_loop, ws_ssl_server, verify_ extra_args = {} - if verify_https == "cert_provided": - _, ssl_context = get_localhost_ssl_context_client() + _, ssl_context = get_localhost_ssl_context_client() - extra_args["ssl"] = ssl_context - elif verify_https == "disabled": - extra_args["ssl"] = False + extra_args["ssl"] = ssl_context transport = WebsocketsTransport(url=url, **extra_args) async with Client(transport=transport) as session: - assert isinstance( - transport.websocket, websockets.client.WebSocketClientProtocol - ) + assert isinstance(transport.adapter.websocket, websockets.ClientConnection) query1 = gql(query1_str) @@ -133,49 +128,57 @@ async def test_websocket_using_ssl_connection(event_loop, ws_ssl_server, verify_ assert africa["code"] == "AF" # Check client is disconnect here - assert transport.websocket is None + assert transport._connected is False @pytest.mark.asyncio @pytest.mark.parametrize("ws_ssl_server", [server1_answers], indirect=True) @pytest.mark.parametrize("verify_https", ["explicitely_enabled", "default"]) async def test_websocket_using_ssl_connection_self_cert_fail( - event_loop, ws_ssl_server, verify_https + ws_ssl_server, verify_https ): - from gql.transport.websockets import WebsocketsTransport from ssl import SSLCertVerificationError + from gql.transport.websockets import WebsocketsTransport + server = ws_ssl_server url = f"wss://{server.hostname}:{server.port}/graphql" print(f"url = {url}") - extra_args = {} + extra_args: Dict[str, Any] = {} if verify_https == "explicitely_enabled": extra_args["ssl"] = True transport = WebsocketsTransport(url=url, **extra_args) - with 
pytest.raises(SSLCertVerificationError) as exc_info: + if verify_https == "explicitely_enabled": + assert transport.ssl is True + + with pytest.raises(TransportConnectionFailed) as exc_info: async with Client(transport=transport) as session: query1 = gql(query1_str) await session.execute(query1) + cause = exc_info.value.__cause__ + + assert isinstance(cause, SSLCertVerificationError) + expected_error = "certificate verify failed: self-signed certificate" - assert expected_error in str(exc_info.value) + assert expected_error in str(cause) # Check client is disconnect here - assert transport.websocket is None + assert transport._connected is False @pytest.mark.asyncio @pytest.mark.parametrize("server", [server1_answers], indirect=True) @pytest.mark.parametrize("query_str", [query1_str]) -async def test_websocket_simple_query(event_loop, client_and_server, query_str): +async def test_websocket_simple_query(client_and_server, query_str): session, server = client_and_server @@ -195,9 +198,7 @@ async def test_websocket_simple_query(event_loop, client_and_server, query_str): @pytest.mark.asyncio @pytest.mark.parametrize("server", [server1_two_answers_in_series], indirect=True) @pytest.mark.parametrize("query_str", [query1_str]) -async def test_websocket_two_queries_in_series( - event_loop, client_and_server, query_str -): +async def test_websocket_two_queries_in_series(client_and_server, query_str): session, server = client_and_server @@ -231,9 +232,7 @@ async def server1_two_queries_in_parallel(ws): @pytest.mark.asyncio @pytest.mark.parametrize("server", [server1_two_queries_in_parallel], indirect=True) @pytest.mark.parametrize("query_str", [query1_str]) -async def test_websocket_two_queries_in_parallel( - event_loop, client_and_server, query_str -): +async def test_websocket_two_queries_in_parallel(client_and_server, query_str): session, server = client_and_server @@ -278,9 +277,7 @@ async def server_closing_while_we_are_doing_something_else(ws): "server", 
[server_closing_while_we_are_doing_something_else], indirect=True ) @pytest.mark.parametrize("query_str", [query1_str]) -async def test_websocket_server_closing_after_first_query( - event_loop, client_and_server, query_str -): +async def test_websocket_server_closing_after_first_query(client_and_server, query_str): session, server = client_and_server @@ -290,11 +287,11 @@ async def test_websocket_server_closing_after_first_query( await session.execute(query) # Then we do other things - await asyncio.sleep(100 * MS) + await asyncio.sleep(10 * MS) # Now the server is closed but we don't know it yet, we have to send a query # to notice it and to receive the exception - with pytest.raises(TransportClosed): + with pytest.raises(TransportConnectionFailed): await session.execute(query) @@ -308,7 +305,7 @@ async def test_websocket_server_closing_after_first_query( @pytest.mark.asyncio @pytest.mark.parametrize("server", [ignore_invalid_id_answers], indirect=True) @pytest.mark.parametrize("query_str", [query1_str]) -async def test_websocket_ignore_invalid_id(event_loop, client_and_server, query_str): +async def test_websocket_ignore_invalid_id(client_and_server, query_str): session, server = client_and_server @@ -343,7 +340,7 @@ async def assert_client_is_working(session): @pytest.mark.asyncio @pytest.mark.parametrize("server", [server1_answers], indirect=True) -async def test_websocket_multiple_connections_in_series(event_loop, server): +async def test_websocket_multiple_connections_in_series(server): from gql.transport.websockets import WebsocketsTransport url = f"ws://{server.hostname}:{server.port}/graphql" @@ -355,18 +352,18 @@ async def test_websocket_multiple_connections_in_series(event_loop, server): await assert_client_is_working(session) # Check client is disconnect here - assert transport.websocket is None + assert transport._connected is False async with Client(transport=transport) as session: await assert_client_is_working(session) # Check client is disconnect 
here - assert transport.websocket is None + assert transport._connected is False @pytest.mark.asyncio @pytest.mark.parametrize("server", [server1_answers], indirect=True) -async def test_websocket_multiple_connections_in_parallel(event_loop, server): +async def test_websocket_multiple_connections_in_parallel(server): from gql.transport.websockets import WebsocketsTransport url = f"ws://{server.hostname}:{server.port}/graphql" @@ -385,9 +382,7 @@ async def task_coro(): @pytest.mark.asyncio @pytest.mark.parametrize("server", [server1_answers], indirect=True) -async def test_websocket_trying_to_connect_to_already_connected_transport( - event_loop, server -): +async def test_websocket_trying_to_connect_to_already_connected_transport(server): from gql.transport.websockets import WebsocketsTransport url = f"ws://{server.hostname}:{server.port}/graphql" @@ -434,7 +429,7 @@ async def server_with_authentication_in_connection_init_payload(ws): ) @pytest.mark.parametrize("query_str", [query1_str]) async def test_websocket_connect_success_with_authentication_in_connection_init( - event_loop, server, query_str + server, query_str ): from gql.transport.websockets import WebsocketsTransport @@ -469,7 +464,7 @@ async def test_websocket_connect_success_with_authentication_in_connection_init( @pytest.mark.parametrize("query_str", [query1_str]) @pytest.mark.parametrize("init_payload", [{}, {"Authorization": "invalid_code"}]) async def test_websocket_connect_failed_with_authentication_in_connection_init( - event_loop, server, query_str, init_payload + server, query_str, init_payload ): from gql.transport.websockets import WebsocketsTransport @@ -484,7 +479,7 @@ async def test_websocket_connect_failed_with_authentication_in_connection_init( await session.execute(query1) - assert transport.websocket is None + assert transport._connected is False @pytest.mark.parametrize("server", [server1_answers], indirect=True) @@ -526,12 +521,12 @@ def test_websocket_execute_sync(server): assert 
africa["code"] == "AF" # Check client is disconnect here - assert transport.websocket is None + assert transport._connected is False @pytest.mark.asyncio @pytest.mark.parametrize("server", [server1_answers], indirect=True) -async def test_websocket_add_extra_parameters_to_connect(event_loop, server): +async def test_websocket_add_extra_parameters_to_connect(server): from gql.transport.websockets import WebsocketsTransport url = f"ws://{server.hostname}:{server.port}/graphql" @@ -563,9 +558,7 @@ async def server_sending_keep_alive_before_connection_ack(ws): "server", [server_sending_keep_alive_before_connection_ack], indirect=True ) @pytest.mark.parametrize("query_str", [query1_str]) -async def test_websocket_non_regression_bug_108( - event_loop, client_and_server, query_str -): +async def test_websocket_non_regression_bug_108(client_and_server, query_str): # This test will check that we now ignore keepalive message # arriving before the connection_ack @@ -587,15 +580,16 @@ async def test_websocket_non_regression_bug_108( @pytest.mark.asyncio @pytest.mark.parametrize("server", [server1_answers], indirect=True) -async def test_websocket_using_cli(event_loop, server, monkeypatch, capsys): +async def test_websocket_using_cli(server, monkeypatch, capsys): url = f"ws://{server.hostname}:{server.port}/graphql" print(f"url = {url}") - from gql.cli import main, get_parser import io import json + from gql.cli import get_parser, main + parser = get_parser(with_examples=True) args = parser.parse_args([url]) @@ -638,9 +632,7 @@ async def test_websocket_using_cli(event_loop, server, monkeypatch, capsys): @pytest.mark.asyncio @pytest.mark.parametrize("server", [server1_answers_with_extensions], indirect=True) @pytest.mark.parametrize("query_str", [query1_str]) -async def test_websocket_simple_query_with_extensions( - event_loop, client_and_server, query_str -): +async def test_websocket_simple_query_with_extensions(client_and_server, query_str): session, server = 
client_and_server @@ -649,3 +641,52 @@ async def test_websocket_simple_query_with_extensions( execution_result = await session.execute(query, get_execution_result=True) assert execution_result.extensions["key1"] == "val1" + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [server1_answers], indirect=True) +async def test_websocket_adapter_connection_closed(server): + from gql.transport.websockets import WebsocketsTransport + + url = f"ws://{server.hostname}:{server.port}/graphql" + print(f"url = {url}") + + transport = WebsocketsTransport(url=url, headers={"test": "1234"}) + + async with Client(transport=transport) as session: + + query1 = gql(query1_str) + + # Close adapter connection manually (should not be done) + await transport.adapter.close() + + with pytest.raises(TransportConnectionFailed): + await session.execute(query1) + + # Check client is disconnect here + assert transport._connected is False + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [server1_answers], indirect=True) +async def test_websocket_transport_closed_in_receive(server): + from gql.transport.websockets import WebsocketsTransport + + url = f"ws://{server.hostname}:{server.port}/graphql" + print(f"url = {url}") + + transport = WebsocketsTransport( + url=url, + close_timeout=0.1, + ) + + async with Client(transport=transport) as session: + + query1 = gql(query1_str) + + # Close adapter connection manually (should not be done) + # await transport.adapter.close() + transport._connected = False + + with pytest.raises(TransportConnectionFailed): + await session.execute(query1) diff --git a/tests/test_websocket_subscription.py b/tests/test_websocket_subscription.py index 5af44d59..5baa0b4e 100644 --- a/tests/test_websocket_subscription.py +++ b/tests/test_websocket_subscription.py @@ -9,9 +9,10 @@ from parse import search from gql import Client, gql -from gql.transport.exceptions import TransportServerError +from gql.client import AsyncClientSession +from 
gql.transport.exceptions import TransportConnectionFailed, TransportServerError -from .conftest import MS, WebSocketServerHelper +from .conftest import MS, PyPy, WebSocketServerHelper # Marking all tests in this file with the websockets marker pytestmark = pytest.mark.websockets @@ -126,7 +127,7 @@ async def keepalive_coro(): @pytest.mark.asyncio @pytest.mark.parametrize("server", [server_countdown], indirect=True) @pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) -async def test_websocket_subscription(event_loop, client_and_server, subscription_str): +async def test_websocket_subscription(client_and_server, subscription_str): session, server = client_and_server @@ -148,7 +149,7 @@ async def test_websocket_subscription(event_loop, client_and_server, subscriptio @pytest.mark.parametrize("server", [server_countdown], indirect=True) @pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) async def test_websocket_subscription_get_execution_result( - event_loop, client_and_server, subscription_str + client_and_server, subscription_str ): session, server = client_and_server @@ -160,6 +161,7 @@ async def test_websocket_subscription_get_execution_result( assert isinstance(result, ExecutionResult) + assert result.data is not None number = result.data["number"] print(f"Number received: {number}") @@ -172,16 +174,15 @@ async def test_websocket_subscription_get_execution_result( @pytest.mark.asyncio @pytest.mark.parametrize("server", [server_countdown], indirect=True) @pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) -async def test_websocket_subscription_break( - event_loop, client_and_server, subscription_str -): +async def test_websocket_subscription_break(client_and_server, subscription_str): session, server = client_and_server count = 10 subscription = gql(subscription_str.format(count=count)) - async for result in session.subscribe(subscription): + generator = session.subscribe(subscription) + async for 
result in generator: number = result["number"] print(f"Number received: {number}") @@ -195,29 +196,38 @@ async def test_websocket_subscription_break( assert count == 5 + # Using aclose here to make it stop cleanly on pypy + await generator.aclose() + @pytest.mark.asyncio @pytest.mark.parametrize("server", [server_countdown], indirect=True) @pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) -async def test_websocket_subscription_task_cancel( - event_loop, client_and_server, subscription_str -): +async def test_websocket_subscription_task_cancel(client_and_server, subscription_str): session, server = client_and_server count = 10 subscription = gql(subscription_str.format(count=count)) + task_cancelled = False + async def task_coro(): nonlocal count - async for result in session.subscribe(subscription): + nonlocal task_cancelled - number = result["number"] - print(f"Number received: {number}") + try: + async for result in session.subscribe(subscription): - assert number == count + number = result["number"] + print(f"Number received: {number}") - count -= 1 + assert number == count + + count -= 1 + except asyncio.CancelledError: + print("Inside task cancelled") + task_cancelled = True task = asyncio.ensure_future(task_coro()) @@ -233,13 +243,14 @@ async def cancel_task_coro(): await asyncio.gather(task, cancel_task) assert count > 0 + assert task_cancelled is True @pytest.mark.asyncio @pytest.mark.parametrize("server", [server_countdown], indirect=True) @pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) async def test_websocket_subscription_close_transport( - event_loop, client_and_server, subscription_str + client_and_server, subscription_str ): session, server = client_and_server @@ -304,16 +315,14 @@ async def server_countdown_close_connection_in_middle(ws): ) @pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) async def test_websocket_subscription_server_connection_closed( - event_loop, 
client_and_server, subscription_str + client_and_server, subscription_str ): - import websockets - session, server = client_and_server count = 10 subscription = gql(subscription_str.format(count=count)) - with pytest.raises(websockets.exceptions.ConnectionClosedOK): + with pytest.raises(TransportConnectionFailed): async for result in session.subscribe(subscription): @@ -329,7 +338,7 @@ async def test_websocket_subscription_server_connection_closed( @pytest.mark.parametrize("server", [server_countdown], indirect=True) @pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) async def test_websocket_subscription_slow_consumer( - event_loop, client_and_server, subscription_str + client_and_server, subscription_str ): session, server = client_and_server @@ -354,17 +363,16 @@ async def test_websocket_subscription_slow_consumer( @pytest.mark.parametrize("server", [server_countdown], indirect=True) @pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) async def test_websocket_subscription_with_operation_name( - event_loop, client_and_server, subscription_str + client_and_server, subscription_str ): session, server = client_and_server count = 10 subscription = gql(subscription_str.format(count=count)) + subscription.operation_name = "CountdownSubscription" - async for result in session.subscribe( - subscription, operation_name="CountdownSubscription" - ): + async for result in session.subscribe(subscription): number = result["number"] print(f"Number received: {number}") @@ -385,7 +393,7 @@ async def test_websocket_subscription_with_operation_name( @pytest.mark.parametrize("server", [server_countdown], indirect=True) @pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) async def test_websocket_subscription_with_keepalive( - event_loop, client_and_server, subscription_str + client_and_server, subscription_str ): session, server = client_and_server @@ -408,16 +416,21 @@ async def 
test_websocket_subscription_with_keepalive( @pytest.mark.parametrize("server", [server_countdown], indirect=True) @pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) async def test_websocket_subscription_with_keepalive_with_timeout_ok( - event_loop, server, subscription_str + server, subscription_str ): from gql.transport.websockets import WebsocketsTransport path = "/graphql" url = f"ws://{server.hostname}:{server.port}{path}" - sample_transport = WebsocketsTransport(url=url, keep_alive_timeout=(20 * MS)) - client = Client(transport=sample_transport) + keep_alive_timeout = 20 * MS + if PyPy: + keep_alive_timeout = 200 * MS + + transport = WebsocketsTransport(url=url, keep_alive_timeout=keep_alive_timeout) + + client = Client(transport=transport) count = 10 subscription = gql(subscription_str.format(count=count)) @@ -438,16 +451,16 @@ async def test_websocket_subscription_with_keepalive_with_timeout_ok( @pytest.mark.parametrize("server", [server_countdown], indirect=True) @pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) async def test_websocket_subscription_with_keepalive_with_timeout_nok( - event_loop, server, subscription_str + server, subscription_str ): from gql.transport.websockets import WebsocketsTransport path = "/graphql" url = f"ws://{server.hostname}:{server.port}{path}" - sample_transport = WebsocketsTransport(url=url, keep_alive_timeout=(1 * MS)) + transport = WebsocketsTransport(url=url, keep_alive_timeout=(1 * MS)) - client = Client(transport=sample_transport) + client = Client(transport=transport) count = 10 subscription = gql(subscription_str.format(count=count)) @@ -473,9 +486,9 @@ def test_websocket_subscription_sync(server, subscription_str): url = f"ws://{server.hostname}:{server.port}/graphql" print(f"url = {url}") - sample_transport = WebsocketsTransport(url=url) + transport = WebsocketsTransport(url=url) - client = Client(transport=sample_transport) + client = Client(transport=transport) count 
= 10 subscription = gql(subscription_str.format(count=count)) @@ -499,9 +512,9 @@ def test_websocket_subscription_sync_user_exception(server, subscription_str): url = f"ws://{server.hostname}:{server.port}/graphql" print(f"url = {url}") - sample_transport = WebsocketsTransport(url=url) + transport = WebsocketsTransport(url=url) - client = Client(transport=sample_transport) + client = Client(transport=transport) count = 10 subscription = gql(subscription_str.format(count=count)) @@ -530,9 +543,9 @@ def test_websocket_subscription_sync_break(server, subscription_str): url = f"ws://{server.hostname}:{server.port}/graphql" print(f"url = {url}") - sample_transport = WebsocketsTransport(url=url) + transport = WebsocketsTransport(url=url) - client = Client(transport=sample_transport) + client = Client(transport=transport) count = 10 subscription = gql(subscription_str.format(count=count)) @@ -571,9 +584,9 @@ def test_websocket_subscription_sync_graceful_shutdown(server, subscription_str) url = f"ws://{server.hostname}:{server.port}/graphql" print(f"url = {url}") - sample_transport = WebsocketsTransport(url=url) + transport = WebsocketsTransport(url=url) - client = Client(transport=sample_transport) + client = Client(transport=transport) count = 10 subscription = gql(subscription_str.format(count=count)) @@ -595,6 +608,7 @@ def test_websocket_subscription_sync_graceful_shutdown(server, subscription_str) warnings.filterwarnings( "ignore", message="There is no current event loop" ) + assert isinstance(client.session, AsyncClientSession) interrupt_task = asyncio.ensure_future( client.session._generator.athrow(KeyboardInterrupt) ) @@ -604,6 +618,7 @@ def test_websocket_subscription_sync_graceful_shutdown(server, subscription_str) assert count == 4 # Catch interrupt_task exception to remove warning + assert interrupt_task is not None interrupt_task.exception() # Check that the server received a connection_terminate message last @@ -614,16 +629,16 @@ def 
test_websocket_subscription_sync_graceful_shutdown(server, subscription_str) @pytest.mark.parametrize("server", [server_countdown], indirect=True) @pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) async def test_websocket_subscription_running_in_thread( - event_loop, server, subscription_str, run_sync_test + server, subscription_str, run_sync_test ): from gql.transport.websockets import WebsocketsTransport def test_code(): path = "/graphql" url = f"ws://{server.hostname}:{server.port}{path}" - sample_transport = WebsocketsTransport(url=url) + transport = WebsocketsTransport(url=url) - client = Client(transport=sample_transport) + client = Client(transport=transport) count = 10 subscription = gql(subscription_str.format(count=count)) @@ -638,4 +653,4 @@ def test_code(): assert count == -1 - await run_sync_test(event_loop, server, test_code) + await run_sync_test(server, test_code) diff --git a/tests/test_websockets_adapter.py b/tests/test_websockets_adapter.py new file mode 100644 index 00000000..31422487 --- /dev/null +++ b/tests/test_websockets_adapter.py @@ -0,0 +1,100 @@ +import json +from typing import Mapping + +import pytest +from graphql import print_ast + +from gql import gql +from gql.transport.exceptions import TransportConnectionFailed + +# Marking all tests in this file with the websockets marker +pytestmark = pytest.mark.websockets + +query1_str = """ + query getContinents { + continents { + code + name + } + } +""" + +query1_server_answer = ( + '{{"type":"data","id":"{query_id}","payload":{{"data":{{"continents":[' + '{{"code":"AF","name":"Africa"}},{{"code":"AN","name":"Antarctica"}},' + '{{"code":"AS","name":"Asia"}},{{"code":"EU","name":"Europe"}},' + '{{"code":"NA","name":"North America"}},{{"code":"OC","name":"Oceania"}},' + '{{"code":"SA","name":"South America"}}]}}}}}}' +) + +server1_answers = [ + query1_server_answer, +] + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [server1_answers], indirect=True) 
+async def test_websockets_adapter_simple_query(server): + from gql.transport.common.adapters.websockets import WebSocketsAdapter + + url = f"ws://{server.hostname}:{server.port}/graphql" + + query = print_ast(gql(query1_str).document) + print("query=", query) + + adapter = WebSocketsAdapter(url) + + await adapter.connect() + + init_message = json.dumps({"type": "connection_init", "payload": {}}) + + await adapter.send(init_message) + + result = await adapter.receive() + print(f"result={result}") + + payload = json.dumps({"query": query}) + query_message = json.dumps({"id": 1, "type": "start", "payload": payload}) + + await adapter.send(query_message) + + result = await adapter.receive() + print(f"result={result}") + + await adapter.close() + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [server1_answers], indirect=True) +async def test_websockets_adapter_edge_cases(server): + from gql.transport.common.adapters.websockets import WebSocketsAdapter + + url = f"ws://{server.hostname}:{server.port}/graphql" + + query = print_ast(gql(query1_str).document) + print("query=", query) + + adapter = WebSocketsAdapter(url, headers={"a": "r1"}, ssl=False, connect_args={}) + + await adapter.connect() + + assert isinstance(adapter.headers, Mapping) + assert adapter.headers["a"] == "r1" + assert adapter.ssl is False + assert adapter.connect_args == {} + assert adapter.response_headers["dummy"] == "test1234" + + # Connect twice causes AssertionError + with pytest.raises(AssertionError): + await adapter.connect() + + await adapter.close() + + # Second close call is ignored + await adapter.close() + + with pytest.raises(TransportConnectionFailed): + await adapter.send("Blah") + + with pytest.raises(TransportConnectionFailed): + await adapter.receive() diff --git a/tox.ini b/tox.ini index 8796357b..f6d4b48e 100644 --- a/tox.ini +++ b/tox.ini @@ -47,7 +47,7 @@ commands = basepython = python deps = -e.[dev] commands = - isort --recursive --check-only --diff gql tests + 
isort --check-only --diff gql tests [testenv:mypy] basepython = python