8000 subscribe-test: extract subscribe_with_bad_fn function · graphql-python/graphql-core@47ecdb3 · GitHub
[go: up one dir, main page]

Skip to content

Commit 47ecdb3

Browse files
committed
subscribe-test: extract subscribe_with_bad_fn function
Replicates graphql/graphql-js@2deb272
1 parent 1d6e008 commit 47ecdb3

Original file line numberDiff line numberDiff line change
@@ -3,7 +3,12 @@
33

44
from pytest import mark, raises
55 8000

6-
from graphql.execution import MapAsyncIterator, create_source_event_stream, subscribe
6+
from graphql.execution import (
7+
ExecutionResult,
8+
MapAsyncIterator,
9+
create_source_event_stream,
10+
subscribe,
11+
)
712
from graphql.language import parse
813
from graphql.pyutils import SimplePubSub
914
from graphql.type import (
@@ -132,6 +137,22 @@ def transform(new_email):
132137
DummyQueryType = GraphQLObjectType("Query", {"dummy": GraphQLField(GraphQLString)})
133138

134139

140+
async def subscribe_with_bad_fn(subscribe_fn: Callable) -> ExecutionResult:
141+
schema = GraphQLSchema(
142+
query=DummyQueryType,
143+
subscription=GraphQLObjectType(
144+
"Subscription",
145+
{"foo": GraphQLField(GraphQLString, subscribe=subscribe_fn)},
146+
),
147+
)
148+
document = parse("subscription { foo }")
149+
result = await subscribe(schema, document)
150+
151+
assert isinstance(result, ExecutionResult)
152+
assert await create_source_event_stream(schema, document) == result
153+
return result
154+
155+
135156
# Check all error cases when initializing the subscription.
136157
def describe_subscription_initialization_phase():
137158
@mark.asyncio
@@ -333,43 +354,15 @@ async def should_pass_through_unexpected_errors_thrown_in_subscribe():
333354
@mark.asyncio
334355
@mark.filterwarnings("ignore:.* was never awaited:RuntimeWarning")
335356
async def throws_an_error_if_subscribe_does_not_return_an_iterator():
336-
schema = GraphQLSchema(
337-
query=DummyQueryType,
338-
subscription=GraphQLObjectType(
339-
"Subscription",
340-
{
341-
"foo": GraphQLField(
342-
GraphQLString, subscribe=lambda _obj, _info: "test"
343-
)
344-
},
345-
),
346-
)
347-
348-
document = parse("subscription { foo }")
349-
350357
with raises(TypeError) as exc_info:
351-
await subscribe(schema, document)
358+
await subscribe_with_bad_fn(lambda _obj, _info: "test")
352359

353360
assert str(exc_info.value) == (
354361
"Subscription field must return AsyncIterable. Received: 'test'."
355362
)
356363

357364
@mark.asyncio
358365
async def resolves_to_an_error_for_subscription_resolver_errors():
359-
async def subscribe_with_fn(subscribe_fn: Callable):
360-
schema = GraphQLSchema(
361-
query=DummyQueryType,
362-
subscription=GraphQLObjectType(
363-
"Subscription",
364-
{"foo": GraphQLField(GraphQLString, subscribe=subscribe_fn)},
365-
),
366-
)
367-
document = parse("subscription { foo }")
368-
result = await subscribe(schema, document)
369-
370-
assert await create_source_event_stream(schema, document) == result
371-
return result
372-
373366
expected_result = (
374367
None,
375368
[
@@ -385,25 +378,25 @@ async def subscribe_with_fn(subscribe_fn: Callable):
385378
def return_error(_obj, _info):
386379
return TypeError("test error")
387380

388-
assert await subscribe_with_fn(return_error) == expected_result
381+
assert await subscribe_with_bad_fn(return_error) == expected_result
389382

390383
# Throwing an error
391384
def throw_error(*_args):
392385
raise TypeError("test error")
393386

394-
assert await subscribe_with_fn(throw_error) == expected_result
387+
assert await subscribe_with_bad_fn(throw_error) == expected_result
395388

396389
# Resolving to an error
397390
async def resolve_error(*_args):
398391
return TypeError("test error")
399392

400-
assert await subscribe_with_fn(resolve_error) == expected_result
393+
assert await subscribe_with_bad_fn(resolve_error) == expected_result
401394

402395
# Rejecting with an error
403396
async def reject_error(*_args):
404397
return TypeError("test error")
405398

406-
assert await subscribe_with_fn(reject_error) == expected_result
399+
assert await subscribe_with_bad_fn(reject_error) == expected_result
407400

408401
@mark.asyncio
409402
async def resolves_to_an_error_if_variables_were_wrong_type():