8000 bpo-43751: Fix anext() bug where it erroneously returned None by sweeneyde · Pull Request #25238 · python/cpython · GitHub
[go: up one dir, main page]

Skip to content

bpo-43751: Fix anext() bug where it erroneously returned None #25238

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 16 commits into from
Apr 11, 2021
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
140 changes: 135 additions & 5 deletions Lib/test/test_asyncgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,11 +372,8 @@ def tearDown(self):
self.loop = None
asyncio.set_event_loop_policy(None)

def test_async_gen_anext(self):
async def gen():
yield 1
yield 2
g = gen()
def check_async_iterator_anext(self, ait_class):
g = ait_class()
async def consume():
results = []
results.append(await anext(g))
Expand All @@ -388,6 +385,66 @@ async def consume():
with self.assertRaises(StopAsyncIteration):
self.loop.run_until_complete(consume())

async def test_2():
g1 = ait_class()
self.assertEqual(await anext(g1), 1)
self.assertEqual(await anext(g1), 2)
with self.assertRaises(StopAsyncIteration):
await anext(g1)
with self.assertRaises(StopAsyncIteration):
await anext(g1)

g2 = ait_class()
self.assertEqual(await anext(g2, "default"), 1)
self.assertEqual(await anext(g2, "default"), 2)
self.assertEqual(await anext(g2, "default"), "default")
self.assertEqual(await anext(g2, "default"), "default")

return "completed"

result = self.loop.run_until_complete(test_2())
self.assertEqual(result, "completed")

def test_async_generator_anext(self):
async def agen():
yield 1
yield 2
self.check_async_iterator_anext(agen)

def test_python_async_iterator_anext(self):
class MyAsyncIter:
"""Asynchronously yield 1, then 2."""
def __init__(self):
self.yielded = 0
def __aiter__(self):
return self
async def __anext__(self):
if self.yielded >= 2:
raise StopAsyncIteration()
else:
self.yielded += 1
return self.yielded
self.check_async_iterator_anext(MyAsyncIter)

def test_python_async_iterator_types_coroutine_anext(self):
import types
class MyAsyncIterWithTypesCoro:
"""Asynchronously yield 1, then 2."""
def __init__(self):
self.yielded = 0
def __aiter__(self):
return self
@types.coroutine
def __anext__(self):
if False:
yield "this is a generator-based coroutine"
if self.yielded >= 2:
raise StopAsyncIteration()
else:
self.yielded += 1
return self.yielded
self.check_async_iterator_anext(MyAsyncIterWithTypesCoro)

def test_async_gen_aiter(self):
async def gen():
yield 1
Expand Down Expand Up @@ -431,12 +488,85 @@ async def call_with_too_many_args():
await anext(gen(), 1, 3)
async def call_with_wrong_type_args():
await anext(1, gen())
async def call_with_kwarg():
await anext(aiterator=gen())
with self.assertRaises(TypeError):
self.loop.run_until_complete(call_with_too_few_args())
with self.assertRaises(TypeError):
self.loop.run_until_complete(call_with_too_many_args())
with self.assertRaises(TypeError):
self.loop.run_until_complete(call_with_wrong_type_args())
with self.assertRaises(TypeError):
self.loop.run_until_complete(call_with_kwarg())

def test_anext_bad_await(self):
async def bad_awaitable():
class BadAwaitable:
def __await__(self):
return 42
class MyAsyncIter:
def __aiter__(self):
return self
def __anext__(self):
return BadAwaitable()
regex = r"__await__.*iterator"
awaitable = anext(MyAsyncIter(), "default")
with self.assertRaisesRegex(TypeError, regex):
await awaitable
awaitable = anext(MyAsyncIter())
with self.assertRaisesRegex(TypeError, regex):
await awaitable
return "completed"
result = self.loop.run_until_complete(bad_awaitable())
self.assertEqual(result, "completed")

async def check_anext_returning_iterator(self, aiter_class):
awaitable = anext(aiter_class(), "default")
with self.assertRaises(TypeError):
await awaitable
awaitable = anext(aiter_class())
with self.assertRaises(TypeError):
await awaitable
return "completed"

def test_anext_return_iterator(self):
class WithIterAnext:
def __aiter__(self):
return self
def __anext__(self):
return iter("abc")
result = self.loop.run_until_complete(self.check_anext_returning_iterator(WithIterAnext))
self.assertEqual(result, "completed")

def test_anext_return_generator(self):
class WithGenAnext:
def __aiter__(self):
return self
def __anext__(self):
yield
result = self.loop.run_until_complete(self.check_anext_returning_iterator(WithGenAnext))
self.assertEqual(result, "completed")

def test_anext_await_raises(self):
class RaisingAwaitable:
def __await__(self):
raise ZeroDivisionError()
yield
class WithRaisingAwaitableAnext:
def __aiter__(self):
return self
def __anext__(self):
return RaisingAwaitable()
async def do_test():
awaitable = anext(WithRaisingAwaitableAnext())
with self.assertRaises(ZeroDivisionError):
await awaitable
awaitable = anext(WithRaisingAwaitableAnext(), "default")
with self.assertRaises(ZeroDivisionError):
await awaitable
return "completed"
result = self.loop.run_until_complete(do_test())
self.assertEqual(result, "completed")

def test_aiter_bad_args(self):
async def gen():
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fixed a bug where ``anext(ait, default)`` would erroneously return None.
45 changes: 44 additions & 1 deletion Objects/iterobject.c
Original file line number Diff line number Diff line change
Expand Up @@ -316,7 +316,50 @@ anextawaitable_traverse(anextawaitableobject *obj, visitproc visit, void *arg)
static PyObject *
anextawaitable_iternext(anextawaitableobject *obj)
{
PyObject *result = PyIter_Next(obj->wrapped);
/* Consider the following class:
*
* class A:
* async def __anext__(self):
* ...
* a = A()
*
* Then `await anext(a)` should call
* a.__anext__().__await__().__next__()
*
* On the other hand, given
*
* async def agen():
* yield 1
* yield 2
* gen = agen()
*
* Then `await anext(gen)` can just call
* gen.__anext__().__next__()
*/
assert(obj->wrapped != NULL);
PyObject *awaitable = _PyCoro_GetAwaitableIter(obj->wrapped);
if (awaitable == NULL) {
return NULL;
}
if (Py_TYPE(awaitable)->tp_iternext == NULL) {
/* _PyCoro_GetAwaitableIter returns a Coroutine, a Generator,
* or an iterator. Of these, only coroutines lack tp_iternext.
*/
assert(PyCoro_CheckExact(awaitable));
unaryfunc getter = Py_TYPE(awaitable)->tp_as_async->am_await;
PyObject *new_awaitable = getter(awaitable);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This can raise, you need to check for errors and propagate accordingly

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, can you add a test that checks this code path?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added the check, but I couldn't figure out how to make a test to check the code path, since awaitable is always a coroutine, so getter points to this in genobject.c:

static PyObject *
coro_await(PyCoroObject *coro)
{
    PyCoroWrapper *cw = PyObject_GC_New(PyCoroWrapper, &_PyCoroWrapper_Type);
    if (cw == NULL) {
        return NULL;
    }
    Py_INCREF(coro);
    cw->cw_coroutine = coro;
    _PyObject_GC_TRACK(cw);
    return (PyObject *)cw;
}

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You need to add a test to the _testcapi module, but this is going to be a bit too much for just this, so is fine if we don't add such a test.

A3E2

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wait, isn't this something you can get with a custom __await__ in a class?
 

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

_PyCoro_GetAwaitableIter calls am_await and ensures it's iterable; this if block is only for coroutines.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

test_anext_bad_await covers __await__ methods that raise

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

if (new_awaitable == NULL) {
Py_DECREF(awaitable);
return NULL;
}
Py_SETREF(awaitable, new_awaitable);
if (Py_TYPE(awaitable)->tp_iternext == NULL) {
PyErr_SetString(PyExc_TypeError,
Copy link
Member
@pablogsal pablogsal Apr 11, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is leaking awaitable.

"__await__ returned a non-iterable");
return NULL;
}
}
PyObject *result = (*Py_TYPE(awaitable)->tp_iternext)(awaitable);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You need to decrement awaitable once you are done with it

if (result != NULL) {
return result;
}
Expand Down
0