@@ -372,11 +372,8 @@ def tearDown(self):
372
372
self .loop = None
373
373
asyncio .set_event_loop_policy (None )
374
374
375
- def test_async_gen_anext (self ):
376
- async def gen ():
377
- yield 1
378
- yield 2
379
- g = gen ()
375
+ def check_async_iterator_anext (self , ait_class ):
376
+ g = ait_class ()
380
377
async def consume ():
381
378
results = []
382
379
results .append (await anext (g ))
@@ -388,6 +385,66 @@ async def consume():
388
385
with self .assertRaises (StopAsyncIteration ):
389
386
self .loop .run_until_complete (consume ())
390
387
388
+ async def test_2 ():
389
+ g1 = ait_class ()
390
+ self .assertEqual (await anext (g1 ), 1 )
391
+ self .assertEqual (await anext (g1 ), 2 )
392
+ with self .assertRaises (StopAsyncIteration ):
393
+ await anext (g1 )
394
+ with self .assertRaises (StopAsyncIteration ):
395
+ await anext (g1 )
396
+
397
+ g2 = ait_class ()
398
+ self .assertEqual (await anext (g2 , "default" ), 1 )
399
+ self .assertEqual (await anext (g2 , "default" ), 2 )
400
+ self .assertEqual (await anext (g2 , "default" ), "default" )
401
+ self .assertEqual (await anext (g2 , "default" ), "default" )
402
+
403
+ return "completed"
404
+
405
+ result = self .loop .run_until_complete (test_2 ())
406
+ self .assertEqual (result , "completed" )
407
+
408
+ def test_async_generator_anext (self ):
409
+ async def agen ():
410
+ yield 1
411
+ yield 2
412
+ self .check_async_iterator_anext (agen )
413
+
414
+ def test_python_async_iterator_anext (self ):
415
+ class MyAsyncIter :
416
+ """Asynchronously yield 1, then 2."""
417
+ def __init__ (self ):
418
+ self .yielded = 0
419
+ def __aiter__ (self ):
420
+ return self
421
+ async def __anext__ (self ):
422
+ if self .yielded >= 2 :
423
+ raise StopAsyncIteration ()
424
+ else :
425
+ self .yielded += 1
426
+ return self .yielded
427
+ self .check_async_iterator_anext (MyAsyncIter )
428
+
429
+ def test_python_async_iterator_types_coroutine_anext (self ):
430
+ import types
431
+ class MyAsyncIterWithTypesCoro :
432
+ """Asynchronously yield 1, then 2."""
433
+ def __init__ (self ):
434
+ self .yielded = 0
435
+ def __aiter__ (self ):
436
+ return self
437
+ @types .coroutine
438
+ def __anext__ (self ):
439
+ if False :
440
+ yield "this is a generator-based coroutine"
441
+ if self .yielded >= 2 :
442
+ raise StopAsyncIteration ()
443
+ else :
444
+ self .yielded += 1
445
+ return self .yielded
446
+ self .check_async_iterator_anext (MyAsyncIterWithTypesCoro )
447
+
391
448
def test_async_gen_aiter (self ):
392
449
async def gen ():
393
450
yield 1
@@ -431,12 +488,85 @@ async def call_with_too_many_args():
431
488
await anext (gen (), 1 , 3 )
432
489
async def call_with_wrong_type_args ():
433
490
await anext (1 , gen ())
491
+ async def call_with_kwarg ():
492
+ await anext (aiterator = gen ())
434
493
with self .assertRaises (TypeError ):
435
494
self .loop .run_until_complete (call_with_too_few_args ())
436
495
with self .assertRaises (TypeError ):
437
496
self .loop .run_until_complete (call_with_too_many_args ())
438
497
with self .assertRaises (TypeError ):
439
498
self .loop .run_until_complete (call_with_wrong_type_args ())
499
+ with self .assertRaises (TypeError ):
500
+ self .loop .run_until_complete (call_with_kwarg ())
501
+
502
+ def test_anext_bad_await (self ):
503
+ async def bad_awaitable ():
504
+ class BadAwaitable :
505
+ def __await__ (self ):
506
+ return 42
507
+ class MyAsyncIter :
508
+ def __aiter__ (self ):
509
+ return self
510
+ def __anext__ (self ):
511
+ return BadAwaitable ()
512
+ regex = r"__await__.*iterator"
513
+ awaitable = anext (MyAsyncIter (), "default" )
514
+ with self .assertRaisesRegex (TypeError , regex ):
515
+ await awaitable
516
+ awaitable = anext (MyAsyncIter ())
517
+ with self .assertRaisesRegex (TypeError , regex ):
518
+ await awaitable
519
+ return "completed"
520
+ result = self .loop .run_until_complete (bad_awaitable ())
521
+ self .assertEqual (result , "completed" )
522
+
523
+ async def check_anext_returning_iterator (self , aiter_class ):
524
+ awaitable = anext (aiter_class (), "default" )
525
+ with self .assertRaises (TypeError ):
526
+ await awaitable
527
+ awaitable = anext (aiter_class ())
528
+ with self .assertRaises (TypeError ):
529
+ await awaitable
530
+ return "completed"
531
+
532
+ def test_anext_return_iterator (self ):
533
+ class WithIterAnext :
534
+ def __aiter__ (self ):
535
+ return self
536
+ def __anext__ (self ):
537
+ return iter ("abc" )
538
+ result = self .loop .run_until_complete (self .check_anext_returning_iterator (WithIterAnext ))
539
+ self .assertEqual (result , "completed" )
540
+
541
+ def test_anext_return_generator (self ):
542
+ class WithGenAnext :
543
+ def __aiter__ (self ):
544
+ return self
545
+ def __anext__ (self ):
546
+ yield
547
+ result = self .loop .run_until_complete (self .check_anext_returning_iterator (WithGenAnext ))
548
+ self .assertEqual (result , "completed" )
549
+
550
+ def test_anext_await_raises (self ):
551
+ class RaisingAwaitable :
552
+ def __await__ (self ):
553
+ raise ZeroDivisionError ()
554
+ yield
555
+ class WithRaisingAwaitableAnext :
556
+ def __aiter__ (self ):
557
+ return self
558
+ def __anext__ (self ):
559
+ return RaisingAwaitable ()
560
+ async def do_test ():
561
+ awaitable = anext (WithRaisingAwaitableAnext ())
562
+ with self .assertRaises (ZeroDivisionError ):
563
+ await awaitable
564
+ awaitable = anext (WithRaisingAwaitableAnext (), "default" )
565
+ with self .assertRaises (ZeroDivisionError ):
566
+ await awaitable
567
+ return "completed"
568
+ result = self .loop .run_until_complete (do_test ())
569
+ self .assertEqual (result , "completed" )
440
570
441
571
def test_aiter_bad_args (self ):
442
572
async def gen ():
0 commit comments