8000 bpo-43751: Fix anext() bug where it erroneously returned None (GH-25238) · python/cpython@dfb4532 · GitHub
[go: up one dir, main page]

Skip to content

Commit dfb4532

Browse files
authored
bpo-43751: Fix anext() bug where it erroneously returned None (GH-25238)
1 parent 9045919 commit dfb4532

File tree

3 files changed

+182
-6
lines changed

3 files changed

+182
-6
lines changed

Lib/test/test_asyncgen.py

Lines changed: 135 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -372,11 +372,8 @@ def tearDown(self):
372372
self.loop = None
373373
asyncio.set_event_loop_policy(None)
374374

375-
def test_async_gen_anext(self):
376-
async def gen():
377-
yield 1
378-
yield 2
379-
g = gen()
375+
def check_async_iterator_anext(self, ait_class):
376+
g = ait_class()
380377
async def consume():
381378
results = []
382379
results.append(await anext(g))
@@ -388,6 +385,66 @@ async def consume():
388385
with self.assertRaises(StopAsyncIteration):
389386
self.loop.run_until_complete(consume())
390387

388+
async def test_2():
389+
g1 = ait_class()
390+
self.assertEqual(await anext(g1), 1)
391+
self.assertEqual(await anext(g1), 2)
392+
with self.assertRaises(StopAsyncIteration):
393+
await anext(g1)
394+
with self.assertRaises(StopAsyncIteration):
395+
await anext(g1)
396+
397+
g2 = ait_class()
398+
self.assertEqual(await anext(g2, "default"), 1)
399+
self.assertEqual(await anext(g2, "default"), 2)
400+
self.assertEqual(await anext(g2, "default"), "default")
401+
self.assertEqual(await anext(g2, "default"), "default")
402+
403+
return "completed"
404+
405+
result = self.loop.run_until_complete(test_2())
406+
self.assertEqual(result, "completed")
407+
408+
def test_async_generator_anext(self):
409+
async def agen():
410+
yield 1
411+
yield 2
412+
self.check_async_iterator_anext(agen)
413+
414+
def test_python_async_iterator_anext(self):
415+
class MyAsyncIter:
416+
"""Asynchronously yield 1, then 2."""
417+
def __init__(self):
418+
self.yielded = 0
419+
def __aiter__(self):
420+
return self
421+
async def __anext__(self):
422+
if self.yielded >= 2:
423+
raise StopAsyncIteration()
424+
else:
425+
self.yielded += 1
426+
return self.yielded
427+
self.check_async_iterator_anext(MyAsyncIter)
428+
429+
def test_python_async_iterator_types_coroutine_anext(self):
430+
import types
431+
class MyAsyncIterWithTypesCoro:
432+
"""Asynchronously yield 1, then 2."""
433+
def __init__(self):
434+
self.yielded = 0
435+
def __aiter__(self):
436+
return self
437+
@types.coroutine
438+
def __anext__(self):
439+
if False:
440+
yield "this is a generator-based coroutine"
441+
if self.yielded >= 2:
442+
raise StopAsyncIteration()
443+
else:
444+
self.yielded += 1
445+
return self.yielded
446+
self.check_async_iterator_anext(MyAsyncIterWithTypesCoro)
447+
391448
def test_async_gen_aiter(self):
392449
async def gen():
393450
yield 1
@@ -431,12 +488,85 @@ async def call_with_too_many_args():
431488
await anext(gen(), 1, 3)
432489
async def call_with_wrong_type_args():
433490
await anext(1, gen())
491+
async def call_with_kwarg():
492+
await anext(aiterator=gen())
434493
with self.assertRaises(TypeError):
435494
self.loop.run_until_complete(call_with_too_few_args())
436495
with self.assertRaises(TypeError):
437496
self.loop.run_until_complete(call_with_too_many_args())
438497
with self.assertRaises(TypeError):
439498
self.loop.run_until_complete(call_with_wrong_type_args())
499+
with self.assertRaises(TypeError):
500+
self.loop.run_until_complete(call_with_kwarg())
501+
502+
def test_anext_bad_await(self):
503+
async def bad_awaitable():
504+
class BadAwaitable:
505+
def __await__(self):
506+
return 42
507+
class MyAsyncIter:
508+
def __aiter__(self):
509+
return self
510+
def __anext__(self):
511+
return BadAwaitable()
512+
regex = r"__await__.*iterator"
513+
awaitable = anext(MyAsyncIter(), "default")
514+
with self.assertRaisesRegex(TypeError, regex):
515+
await awaitable
516+
awaitable = anext(MyAsyncIter())
517+
with self.assertRaisesRegex(TypeError, regex):
518+
await awaitable
519+
return "completed"
520+
result = self.loop.run_until_complete(bad_awaitable())
521+
self.assertEqual(result, "completed")
522+
523+
async def check_anext_returning_iterator(self, aiter_class):
524+
awaitable = anext(aiter_class(), "default")
525+
with self.assertRaises(TypeError):
526+
await awaitable
527+
awaitable = anext(aiter_class())
528+
with self.assertRaises(TypeError):
529+
await awaitable
530+
return "completed"
531+
532+
def test_anext_return_iterator(self):
533+
class WithIterAnext:
534+
def __aiter__(self):
535+
return self
536+
def __anext__(self):
537+
return iter("abc")
538+
result = self.loop.run_until_complete(self.check_anext_returning_iterator(WithIterAnext))
539+
self.assertEqual(result, "completed")
540+
541+
def test_anext_return_generator(self):
542+
class WithGenAnext:
543+
def __aiter__(self):
544+
return self
545+
def __anext__(self):
546+
yield
547+
result = self.loop.run_until_complete(self.check_anext_returning_iterator(WithGenAnext))
548+
self.assertEqual(result, "completed")
549+
550+
def test_anext_await_raises(self):
551+
class RaisingAwaitable:
552+
def __await__(self):
553+
raise ZeroDivisionError()
554+
yield
555+
class WithRaisingAwaitableAnext:
556+
def __aiter__(self):
557+
return self
558+
def __anext__(self):
559+
return RaisingAwaitable()
560+
async def do_test():
561+
awaitable = anext(WithRaisingAwaitableAnext())
562+
with self.assertRaises(ZeroDivisionError):
563+
await awaitable
564+
awaitable = anext(WithRaisingAwaitableAnext(), "default")
565+
with self.assertRaises(ZeroDivisionError):
566+
await awaitable
567+
return "completed"
568+
result = self.loop.run_until_complete(do_test())
569+
self.assertEqual(result, "completed")
440570

441571
def test_aiter_bad_args(self):
442572
async def gen():
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Fixed a bug where ``anext(ait, default)`` would erroneously return None.

Objects/iterobject.c

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -316,7 +316,52 @@ anextawaitable_traverse(anextawaitableobject *obj, visitproc visit, void *arg)
316316
static PyObject *
317317
anextawaitable_iternext(anextawaitableobject *obj)
318318
{
319-
PyObject *result = PyIter_Next(obj->wrapped);
319+
/* Consider the following class:
320+
*
321+
* class A:
322+
* async def __anext__(self):
323+
* ...
324+
* a = A()
325+
*
326+
* Then `await anext(a)` should call
327+
* a.__anext__().__await__().__next__()
328+
*
329+
* On the other hand, given
330+
*
331+
* async def agen():
332+
* yield 1
333+
* yield 2
334+
* gen = agen()
335+
*
336+
* Then `await anext(gen)` can just call
337+
* gen.__anext__().__next__()
338+
*/
339+
assert(obj->wrapped != NULL);
340+
PyObject *awaitable = _PyCoro_GetAwaitableIter(obj->wrapped);
341+
if (awaitable == NULL) {
342+
return NULL;
343+
}
344+
if (Py_TYPE(awaitable)->tp_iternext == NULL) {
345+
/* _PyCoro_GetAwaitableIter returns a Coroutine, a Generator,
346+
* or an iterator. Of these, only coroutines lack tp_iternext.
347+
*/
348+
assert(PyCoro_CheckExact(awaitable));
349+
unaryfunc getter = Py_TYPE(awaitable)->tp_as_async->am_await;
350+
PyObject *new_awaitable = getter(awaitable);
351+
if (new_awaitable == NULL) {
352+
Py_DECREF(awaitable);
353+
return NULL;
354+
}
355+
Py_SETREF(awaitable, new_awaitable);
356+
if (Py_TYPE(awaitable)->tp_iternext == NULL) {
357+
PyErr_SetString(PyExc_TypeError,
358+
"__await__ returned a non-iterable");
359+
Py_DECREF(awaitable);
360+
return NULL;
361+
}
362+
}
363+
PyObject *result = (*Py_TYPE(awaitable)->tp_iternext)(awaitable);
364+
Py_DECREF(awaitable);
320365
if (result != NULL) {
321366
return result;
322367
}

0 commit comments

Comments
 (0)
0