diff --git a/graphql_server/aiohttp/graphqlview.py b/graphql_server/aiohttp/graphqlview.py index fa4d998..46a71df 100644 --- a/graphql_server/aiohttp/graphqlview.py +++ b/graphql_server/aiohttp/graphqlview.py @@ -69,7 +69,7 @@ def get_root_value(self): def get_context(self, request): context = ( copy.copy(self.context) - if self.context and isinstance(self.context, MutableMapping) + if self.context is not None and isinstance(self.context, MutableMapping) else {} ) if isinstance(context, MutableMapping) and "request" not in context: diff --git a/graphql_server/flask/graphqlview.py b/graphql_server/flask/graphqlview.py index 4bb4665..faff4e7 100644 --- a/graphql_server/flask/graphqlview.py +++ b/graphql_server/flask/graphqlview.py @@ -68,7 +68,7 @@ def get_root_value(self): def get_context(self): context = ( copy.copy(self.context) - if self.context and isinstance(self.context, MutableMapping) + if self.context is not None and isinstance(self.context, MutableMapping) else {} ) if isinstance(context, MutableMapping) and "request" not in context: diff --git a/graphql_server/quart/graphqlview.py b/graphql_server/quart/graphqlview.py index 2ac624b..51c00df 100644 --- a/graphql_server/quart/graphqlview.py +++ b/graphql_server/quart/graphqlview.py @@ -69,7 +69,7 @@ def get_root_value(self): def get_context(self): context = ( copy.copy(self.context) - if self.context and isinstance(self.context, MutableMapping) + if self.context is not None and isinstance(self.context, MutableMapping) else {} ) if isinstance(context, MutableMapping) and "request" not in context: diff --git a/graphql_server/sanic/graphqlview.py b/graphql_server/sanic/graphqlview.py index 7bea500..8f8c87e 100644 --- a/graphql_server/sanic/graphqlview.py +++ b/graphql_server/sanic/graphqlview.py @@ -71,7 +71,7 @@ def get_root_value(self): def get_context(self, request): context = ( copy.copy(self.context) - if self.context and isinstance(self.context, MutableMapping) + if self.context is not None and isinstance(self.context, MutableMapping) else {} ) if isinstance(context, MutableMapping) and "request" not in context: diff --git a/graphql_server/webob/graphqlview.py b/graphql_server/webob/graphqlview.py index 1048e49..8b43abc 100644 --- a/graphql_server/webob/graphqlview.py +++ b/graphql_server/webob/graphqlview.py @@ -67,7 +67,7 @@ def get_root_value(self): def get_context(self, request): context = ( copy.copy(self.context) - if self.context and isinstance(self.context, MutableMapping) + if self.context is not None and isinstance(self.context, MutableMapping) else {} ) if isinstance(context, MutableMapping) and "request" not in context: diff --git a/tests/aiohttp/schema.py b/tests/aiohttp/schema.py index 54e0d10..e94088f 100644 --- a/tests/aiohttp/schema.py +++ b/tests/aiohttp/schema.py @@ -35,6 +35,9 @@ def resolve_raises(*_): GraphQLNonNull(GraphQLString), resolve=lambda obj, info: info.context["request"], ), + "property": GraphQLField( + GraphQLString, resolve=lambda obj, info: info.context.property + ), }, ), resolve=lambda obj, info: info.context, diff --git a/tests/aiohttp/test_graphqlview.py b/tests/aiohttp/test_graphqlview.py index 41e31d3..dc7cbf7 100644 --- a/tests/aiohttp/test_graphqlview.py +++ b/tests/aiohttp/test_graphqlview.py @@ -555,6 +555,24 @@ async def test_context_remapped_if_not_mapping(app, client): assert "Request" in _json["data"]["context"]["request"] +class CustomContext(dict): + property = "A custom property" + + +@pytest.mark.asyncio +@pytest.mark.parametrize("app", [create_app(context=CustomContext())]) +async def test_allow_empty_custom_context(app, client): + response = await client.get(url_string(query="{context { property request }}")) + + _json = await response.json() + assert response.status == 200 + assert "data" in _json + assert "request" in _json["data"]["context"] + assert "property" in _json["data"]["context"] + assert "A custom property" == _json["data"]["context"]["property"] + assert "Request" in _json["data"]["context"]["request"] + + @pytest.mark.asyncio @pytest.mark.parametrize("app", [create_app(context={"request": "test"})]) async def test_request_not_replaced(app, client): diff --git a/tests/flask/schema.py b/tests/flask/schema.py index eb51e26..fc056fa 100644 --- a/tests/flask/schema.py +++ b/tests/flask/schema.py @@ -29,6 +29,9 @@ def resolve_raises(*_): GraphQLNonNull(GraphQLString), resolve=lambda obj, info: info.context["request"], ), + "property": GraphQLField( + GraphQLString, resolve=lambda obj, info: info.context.property + ), }, ), resolve=lambda obj, info: info.context, diff --git a/tests/flask/test_graphqlview.py b/tests/flask/test_graphqlview.py index 5dc3be0..aed8392 100644 --- a/tests/flask/test_graphqlview.py +++ b/tests/flask/test_graphqlview.py @@ -518,6 +518,23 @@ def test_context_remapped_if_not_mapping(app, client): assert "Request" in res["data"]["context"]["request"] +class CustomContext(dict): + property = "A custom property" + + +@pytest.mark.parametrize("app", [create_app(context=CustomContext())]) +def test_allow_empty_custom_context(app, client): + response = client.get(url_string(app, query="{context { property request }}")) + + assert response.status_code == 200 + res = response_json(response) + assert "data" in res + assert "request" in res["data"]["context"] + assert "property" in res["data"]["context"] + assert "A custom property" == res["data"]["context"]["property"] + assert "Request" in res["data"]["context"]["request"] + + def test_post_multipart_data(app, client): query = "mutation TestMutation { writeTest { test } }" response = client.post( diff --git a/tests/quart/schema.py b/tests/quart/schema.py index eb51e26..fc056fa 100644 --- a/tests/quart/schema.py +++ b/tests/quart/schema.py @@ -29,6 +29,9 @@ def resolve_raises(*_): GraphQLNonNull(GraphQLString), resolve=lambda obj, info: info.context["request"], ), + "property": GraphQLField( + GraphQLString, resolve=lambda obj, info: info.context.property + ), }, ), resolve=lambda obj, info: info.context, diff --git a/tests/quart/test_graphqlview.py b/tests/quart/test_graphqlview.py index d0da414..6db7bde 100644 --- a/tests/quart/test_graphqlview.py +++ b/tests/quart/test_graphqlview.py @@ -648,6 +648,25 @@ async def test_context_remapped_if_not_mapping(app: Quart, client: TestClientPro assert "Request" in res["data"]["context"]["request"] +class CustomContext(dict): + property = "A custom property" + + +@pytest.mark.asyncio +@pytest.mark.parametrize("app", [create_app(context=CustomContext())]) +async def test_allow_empty_custom_context(app: Quart, client: TestClientProtocol): + response = await execute_client(app, client, query="{context { property request }}") + + assert response.status_code == 200 + result = await response.get_data(as_text=True) + res = response_json(result) + assert "data" in res + assert "request" in res["data"]["context"] + assert "property" in res["data"]["context"] + assert "A custom property" == res["data"]["context"]["property"] + assert "Request" in res["data"]["context"]["request"] + + # @pytest.mark.asyncio # async def test_post_multipart_data(app: Quart, client: TestClientProtocol): # query = "mutation TestMutation { writeTest { test } }" diff --git a/tests/sanic/schema.py b/tests/sanic/schema.py index 3c3298f..5df9a20 100644 --- a/tests/sanic/schema.py +++ b/tests/sanic/schema.py @@ -32,6 +32,9 @@ def resolve_raises(*_): GraphQLNonNull(GraphQLString), resolve=lambda obj, info: info.context["request"], ), + "property": GraphQLField( + GraphQLString, resolve=lambda obj, info: info.context.property + ), }, ), resolve=lambda obj, info: info.context, diff --git a/tests/sanic/test_graphqlview.py b/tests/sanic/test_graphqlview.py index d71fa62..8620762 100644 --- a/tests/sanic/test_graphqlview.py +++ b/tests/sanic/test_graphqlview.py @@ -498,6 +498,25 @@ def test_passes_custom_context_into_context(app): assert "Request" in res["data"]["context"]["request"] +class CustomContext(dict): + property = "A custom property" + + +@pytest.mark.parametrize("app", [create_app(context=CustomContext())]) +def test_allow_empty_custom_context(app): + _, response = app.test_client.get( + uri=url_string(query="{context { property request }}") + ) + + assert response.status_code == 200 + res = response_json(response) + assert "data" in res + assert "request" in res["data"]["context"] + assert "property" in res["data"]["context"] + assert "A custom property" == res["data"]["context"]["property"] + assert "Request" in res["data"]["context"]["request"] + + @pytest.mark.parametrize("app", [create_app(context="CUSTOM CONTEXT")]) def test_context_remapped_if_not_mapping(app): _, response = app.test_client.get( diff --git a/tests/webob/schema.py b/tests/webob/schema.py index e6aa93f..e94f596 100644 --- a/tests/webob/schema.py +++ b/tests/webob/schema.py @@ -30,6 +30,9 @@ def resolve_raises(*_): GraphQLNonNull(GraphQLString), resolve=lambda obj, info: info.context["request"], ), + "property": GraphQLField( + GraphQLString, resolve=lambda obj, info: info.context.property + ), }, ), resolve=lambda obj, info: info.context, diff --git a/tests/webob/test_graphqlview.py b/tests/webob/test_graphqlview.py index 53e9680..f2b761f 100644 --- a/tests/webob/test_graphqlview.py +++ b/tests/webob/test_graphqlview.py @@ -471,6 +471,23 @@ def test_context_remapped_if_not_mapping(client, settings): assert "request" in res["data"]["context"]["request"] +class CustomContext(dict): + property = "A custom property" + + +@pytest.mark.parametrize("settings", [dict(context=CustomContext())]) +def test_allow_empty_custom_context(client, settings): + response = client.get(url_string(query="{context { property request }}")) + + assert response.status_code == 200 + res = response_json(response) + assert "data" in res + assert "request" in res["data"]["context"] + assert "property" in res["data"]["context"] + assert "A custom property" == res["data"]["context"]["property"] + assert "request" in res["data"]["context"]["request"] + + def test_post_multipart_data(client): query = "mutation TestMutation { writeTest { test } }" data = (