diff --git a/Lib/asyncio/streams.py b/Lib/asyncio/streams.py index 4517ca22d74637..177f7db75a4e1c 100644 --- a/Lib/asyncio/streams.py +++ b/Lib/asyncio/streams.py @@ -2,7 +2,7 @@ 'StreamReader', 'StreamWriter', 'StreamReaderProtocol', 'open_connection', 'start_server') -import collections +import collections.abc import socket import sys import warnings @@ -597,7 +597,7 @@ async def readuntil(self, separator=b'\n'): the shortest possible separator is considered to be the one that matched. """ - if isinstance(separator, bytes): + if isinstance(separator, collections.abc.Buffer): separator = [separator] else: # Makes sure shortest matches wins, and supports arbitrary iterables diff --git a/Lib/test/test_asyncio/test_streams.py b/Lib/test/test_asyncio/test_streams.py index 792e88761acdc2..8af06a060cd4fb 100644 --- a/Lib/test/test_asyncio/test_streams.py +++ b/Lib/test/test_asyncio/test_streams.py @@ -512,6 +512,13 @@ def test_readuntil_multi_separator_negative_offset(self): self.assertEqual(b'dataZA', data) self.assertEqual(b'aaa', stream._buffer) + def test_readuntil_bytearray(self): + stream = asyncio.StreamReader(loop=self.loop) + stream.feed_data(b'some data\r\n') + data = self.loop.run_until_complete(stream.readuntil(bytearray(b'\r\n'))) + self.assertEqual(b'some data\r\n', data) + self.assertEqual(b'', stream._buffer) + def test_readexactly_zero_or_less(self): # Read exact number of bytes (zero or less). stream = asyncio.StreamReader(loop=self.loop)