@@ -341,6 +341,104 @@ def get_generator_genfunc(obj):
341
341
self .process_tests (get_generator_genfunc )
342
342
343
343
344
+ class SequenceClass :
345
+ def __init__ (self , n ):
346
+ self .n = n
347
+ def __getitem__ (self , i ):
348
+ if 0 <= i < self .n :
349
+ return i
350
+ else :
351
+ raise IndexError
352
+
353
+
354
+ class IncorrectIterable :
355
+ def __iter__ (self ):
356
+ return 123
357
+
358
+
359
+ class IncorrectAsyncIterable :
360
+ def __aiter__ (self ):
361
+ return 123
362
+
363
+
364
+ class CheckIterableTest (unittest .TestCase ):
365
+ sequences = (
366
+ (1 , 2 ),
367
+ [1 , 2 , 3 ],
368
+ range (42 ),
369
+ SequenceClass (10 ),
370
+ )
371
+
372
+ non_iterables = (
373
+ None ,
374
+ 42 ,
375
+ 13.0 ,
376
+ )
377
+
378
+ err_msg_sync = "'.*' object is not iterable"
379
+ err_msg_async = "'async for' requires an object with " \
380
+ "__aiter__ method, got .*"
381
+
382
+ def test_sequences (self ):
383
+ for seq in self .sequences :
384
+ (x for x in seq )
385
+ (x for x in iter (seq ))
386
+ with self .assertRaisesRegex (TypeError , self .err_msg_async ):
387
+ (x async for x in seq )
388
+ with self .assertRaisesRegex (TypeError , self .err_msg_async ):
389
+ (x async for x in iter (seq ))
390
+
391
+ def test_non_iterables (self ):
392
+ for obj in self .non_iterables :
393
+ with self .assertRaisesRegex (TypeError , self .err_msg_sync ):
394
+ (x for x in obj )
395
+ with self .assertRaisesRegex (TypeError , self .err_msg_async ):
396
+ (x async for x in obj )
397
+
398
+ def test_generators (self ):
399
+ def gen ():
400
+ yield 1
401
+
402
+ (x for x in gen ())
403
+ (x for x in iter (gen ()))
404
+
405
+ with self .assertRaisesRegex (TypeError , self .err_msg_async ):
406
+ (x async for x in gen ())
407
+ with self .assertRaisesRegex (TypeError , self .err_msg_async ):
408
+ (x async for x in iter (gen ()))
409
+
410
+ def test_async_generators (self ):
411
+ async def agen ():
412
+ yield 1
413
+ yield 2
414
+
415
+ with self .assertRaisesRegex (TypeError , self .err_msg_sync ):
416
+ (x for x in agen ())
417
+ with self .assertRaisesRegex (TypeError , self .err_msg_sync ):
418
+ (x for x in aiter (agen ()))
419
+
420
+ (x async for x in agen ())
421
+ (x async for x in aiter (agen ()))
422
+
423
+ def test_incorrect_iterable (self ):
424
+ g = (x for x in IncorrectIterable ())
425
+ err_msg = ".* returned non-iterator of type '.*'"
426
+ with self .assertRaisesRegex (TypeError , err_msg ):
427
+ list (g )
428
+
429
+ def test_incorrect_async_iterable (self ):
430
+ g = (x async for x in IncorrectAsyncIterable ())
431
+
432
+ async def coroutine ():
433
+ async for x in g :
434
+ pass
435
+
436
+ err_msg = "'async for' received an object from __aiter__ " \
437
+ "that does not implement __anext__: .*"
438
+ with self .assertRaisesRegex (TypeError , err_msg ):
439
+ coroutine ().send (None )
440
+
441
+
344
442
class ExceptionTest (unittest .TestCase ):
345
443
# Tests for the issue #23353: check that the currently handled exception
346
444
# is correctly saved/restored in PyEval_EvalFrameEx().
0 commit comments