3
3
4
4
from pytest import mark , raises
5
5
8000
6
- from graphql .execution import MapAsyncIterator , create_source_event_stream , subscribe
6
+ from graphql .execution import (
7
+ ExecutionResult ,
8
+ MapAsyncIterator ,
9
+ create_source_event_stream ,
10
+ subscribe ,
11
+ )
7
12
from graphql .language import parse
8
13
from graphql .pyutils import SimplePubSub
9
14
from graphql .type import (
@@ -132,6 +137,22 @@ def transform(new_email):
132
137
DummyQueryType = GraphQLObjectType ("Query" , {"dummy" : GraphQLField (GraphQLString )})
133
138
134
139
140
+ async def subscribe_with_bad_fn (subscribe_fn : Callable ) -> ExecutionResult :
141
+ schema = GraphQLSchema (
142
+ query = DummyQueryType ,
143
+ subscription = GraphQLObjectType (
144
+ "Subscription" ,
145
+ {"foo" : GraphQLField (GraphQLString , subscribe = subscribe_fn )},
146
+ ),
147
+ )
148
+ document = parse ("subscription { foo }" )
149
+ result = await subscribe (schema , document )
150
+
151
+ assert isinstance (result , ExecutionResult )
152
+ assert await create_source_event_stream (schema , document ) == result
153
+ return result
154
+
155
+
135
156
# Check all error cases when initializing the subscription.
136
157
def describe_subscription_initialization_phase ():
137
158
@mark .asyncio
@@ -333,43 +354,15 @@ async def should_pass_through_unexpected_errors_thrown_in_subscribe():
333
354
@mark .asyncio
334
355
@mark .filterwarnings ("ignore:.* was never awaited:RuntimeWarning" )
335
356
async def throws_an_error_if_subscribe_does_not_return_an_iterator ():
336
- schema = GraphQLSchema (
337
- query = DummyQueryType ,
338
- subscription = GraphQLObjectType (
339
- "Subscription" ,
340
- {
341
- "foo" : GraphQLField (
342
- GraphQLString , subscribe = lambda _obj , _info : "test"
343
- )
344
- },
345
- ),
346
- )
347
-
348
- document = parse ("subscription { foo }" )
349
-
350
357
with raises (TypeError ) as exc_info :
351
- await subscribe ( schema , document )
358
+ await subscribe_with_bad_fn ( lambda _obj , _info : "test" )
352
359
353
360
assert str (exc_info .value ) == (
354
361
"Subscription field must return AsyncIterable. Received: 'test'."
355
362
)
356
363
357
364
@mark .asyncio
358
365
async def resolves_to_an_error_for_subscription_resolver_errors ():
359
- async def subscribe_with_fn (subscribe_fn : Callable ):
360
- schema = GraphQLSchema (
361
- query = DummyQueryType ,
362
- subscription = GraphQLObjectType (
363
- "Subscription" ,
364
- {"foo" : GraphQLField (GraphQLString , subscribe = subscribe_fn )},
365
- ),
366
- )
367
- document = parse ("subscription { foo }" )
368
- result = await subscribe (schema , document )
369
-
370
- assert await create_source_event_stream (schema , document ) == result
371
- return result
372
-
373
366
expected_result = (
374
367
None ,
375
368
[
@@ -385,25 +378,25 @@ async def subscribe_with_fn(subscribe_fn: Callable):
385
378
def return_error (_obj , _info ):
386
379
return TypeError ("test error" )
387
380
388
- assert await subscribe_with_fn (return_error ) == expected_result
381
+ assert await subscribe_with_bad_fn (return_error ) == expected_result
389
382
390
383
# Throwing an error
391
384
def throw_error (* _args ):
392
385
raise TypeError ("test error" )
393
386
394
- assert await subscribe_with_fn (throw_error ) == expected_result
387
+ assert await subscribe_with_bad_fn (throw_error ) == expected_result
395
388
396
389
# Resolving to an error
397
390
async def resolve_error (* _args ):
398
391
return TypeError ("test error" )
399
392
400
- assert await subscribe_with_fn (resolve_error ) == expected_result
393
+ assert await subscribe_with_bad_fn (resolve_error ) == expected_result
401
394
402
395
# Rejecting with an error
403
396
async def reject_error (* _args ):
404
397
return TypeError ("test error" )
405
398
406
- assert await subscribe_with_fn (reject_error ) == expected_result
399
+ assert await subscribe_with_bad_fn (reject_error ) == expected_result
407
400
408
401
@mark .asyncio
409
402
async def resolves_to_an_error_if_variables_were_wrong_type ():
0 commit comments